Unverified Commit f26fcdfb, authored by Stig-Arne Grönroos, committed by GitHub
Browse files

[Bugfix][ROCm] Fix lru_cache on paged_mqa_logits_module (#37547)


Signed-off-by: Stig-Arne Grönroos <stig-arne.gronroos@amd.com>
parent bc9c6fbb
...@@ -273,6 +273,25 @@ def fp8_paged_mqa_logits_torch( ...@@ -273,6 +273,25 @@ def fp8_paged_mqa_logits_torch(
return logits return logits
@functools.lru_cache
def paged_mqa_logits_module():
    """Locate and import aiter's Triton paged MQA logits module.

    Two dotted paths are probed in order:
    ``aiter.ops.triton.pa_mqa_logits`` and then
    ``aiter.ops.triton.attention.pa_mqa_logits``. The first path whose spec
    resolves is imported.

    The result (including ``None``) is memoized by ``functools.lru_cache``,
    so the probe runs at most once per process.

    Returns:
        The imported module, or ``None`` when neither path resolves or when
        importing the resolved path raises ``ImportError``.
    """
    candidate_paths = (
        "aiter.ops.triton.pa_mqa_logits",
        "aiter.ops.triton.attention.pa_mqa_logits",
    )
    for module_path in candidate_paths:
        try:
            # find_spec() imports the parent packages of a dotted name and
            # raises ModuleNotFoundError (an ImportError) when one of them is
            # missing. Treat that as "not available" instead of crashing.
            spec = importlib.util.find_spec(module_path)
        except ImportError:
            spec = None
        if spec is None:
            continue
        try:
            return importlib.import_module(module_path)
        except ImportError:
            # The spec resolved but the module itself failed to import;
            # report the kernel as unavailable.
            return None
    return None
def rocm_fp8_paged_mqa_logits( def rocm_fp8_paged_mqa_logits(
q_fp8: torch.Tensor, q_fp8: torch.Tensor,
kv_cache_fp8: torch.Tensor, kv_cache_fp8: torch.Tensor,
...@@ -305,25 +324,6 @@ def rocm_fp8_paged_mqa_logits( ...@@ -305,25 +324,6 @@ def rocm_fp8_paged_mqa_logits(
""" """
from vllm._aiter_ops import rocm_aiter_ops from vllm._aiter_ops import rocm_aiter_ops
# Pre-change copy of paged_mqa_logits_module (removed side of this diff).
# Probes two dotted paths for aiter's paged MQA logits module and returns the
# imported module, or None when no spec is found or the import fails.
# NOTE(review): per the hunk header this copy sits inside
# rocm_fp8_paged_mqa_logits; if so, the lru_cache wrapper is rebuilt on every
# call of that function and never produces a cache hit -- confirm.
@functools.lru_cache
def paged_mqa_logits_module():
paged_mqa_logits_module_path = None
# First candidate path.
if importlib.util.find_spec("aiter.ops.triton.pa_mqa_logits") is not None:
paged_mqa_logits_module_path = "aiter.ops.triton.pa_mqa_logits"
# Second candidate path, checked only when the first spec is missing.
elif (
importlib.util.find_spec("aiter.ops.triton.attention.pa_mqa_logits")
is not None
):
paged_mqa_logits_module_path = "aiter.ops.triton.attention.pa_mqa_logits"
if paged_mqa_logits_module_path is not None:
try:
module = importlib.import_module(paged_mqa_logits_module_path)
return module
except ImportError:
# A resolvable spec that still fails to import is reported as None.
return None
return None
aiter_paged_mqa_logits_module = None aiter_paged_mqa_logits_module = None
if rocm_aiter_ops.is_enabled(): if rocm_aiter_ops.is_enabled():
aiter_paged_mqa_logits_module = paged_mqa_logits_module() aiter_paged_mqa_logits_module = paged_mqa_logits_module()
...@@ -400,6 +400,26 @@ def fp8_mqa_logits_torch( ...@@ -400,6 +400,26 @@ def fp8_mqa_logits_torch(
return logits return logits
@functools.lru_cache
def mqa_logits_module():
    """Locate and import aiter's Triton FP8 MQA logits module.

    Two dotted paths are probed in order:
    ``aiter.ops.triton.fp8_mqa_logits`` and then
    ``aiter.ops.triton.attention.fp8_mqa_logits``. The first path whose spec
    resolves is imported.

    The result (including ``None``) is memoized by ``functools.lru_cache``,
    so the probe runs at most once per process.

    Returns:
        The imported module, or ``None`` when neither path resolves or when
        importing the resolved path raises ``ImportError``.
    """
    candidate_paths = (
        "aiter.ops.triton.fp8_mqa_logits",
        "aiter.ops.triton.attention.fp8_mqa_logits",
    )
    for module_path in candidate_paths:
        try:
            # find_spec() imports the parent packages of a dotted name and
            # raises ModuleNotFoundError (an ImportError) when one of them is
            # missing. Treat that as "not available" instead of crashing.
            spec = importlib.util.find_spec(module_path)
        except ImportError:
            spec = None
        if spec is None:
            continue
        try:
            return importlib.import_module(module_path)
        except ImportError:
            # The spec resolved but the module itself failed to import;
            # report the kernel as unavailable.
            return None
    return None
def rocm_fp8_mqa_logits( def rocm_fp8_mqa_logits(
q: torch.Tensor, q: torch.Tensor,
kv: tuple[torch.Tensor, torch.Tensor], kv: tuple[torch.Tensor, torch.Tensor],
...@@ -429,25 +449,6 @@ def rocm_fp8_mqa_logits( ...@@ -429,25 +449,6 @@ def rocm_fp8_mqa_logits(
# path after aiter merge this kernel into main # path after aiter merge this kernel into main
from vllm._aiter_ops import rocm_aiter_ops from vllm._aiter_ops import rocm_aiter_ops
# Pre-change copy of mqa_logits_module (removed side of this diff).
# Probes two dotted paths for aiter's FP8 MQA logits module and returns the
# imported module, or None when no spec is found or the import fails.
# NOTE(review): per the hunk header this copy sits inside rocm_fp8_mqa_logits;
# if so, the lru_cache wrapper is rebuilt on every call of that function and
# never produces a cache hit -- confirm.
@functools.lru_cache
def mqa_logits_module():
mqa_logits_module_path = None
# First candidate path.
if importlib.util.find_spec("aiter.ops.triton.fp8_mqa_logits") is not None:
mqa_logits_module_path = "aiter.ops.triton.fp8_mqa_logits"
# Second candidate path, checked only when the first spec is missing.
elif (
importlib.util.find_spec("aiter.ops.triton.attention.fp8_mqa_logits")
is not None
):
mqa_logits_module_path = "aiter.ops.triton.attention.fp8_mqa_logits"
if mqa_logits_module_path is not None:
try:
module = importlib.import_module(mqa_logits_module_path)
return module
except ImportError:
# A resolvable spec that still fails to import is reported as None.
return None
return None
aiter_mqa_logits_module = None aiter_mqa_logits_module = None
if rocm_aiter_ops.is_enabled(): if rocm_aiter_ops.is_enabled():
aiter_mqa_logits_module = mqa_logits_module() aiter_mqa_logits_module = mqa_logits_module()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment