Unverified Commit 92b9afee authored by Chendi.Xue's avatar Chendi.Xue Committed by GitHub
Browse files

[XPU] Quick fix for TritonMLA to remove cuda hardcode (#39088)


Signed-off-by: default avatarChendi Xue <chendi.xue@intel.com>
Co-authored-by: default avatarKunshang Ji <kunshang.ji@intel.com>
parent 73105554
...@@ -222,7 +222,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -222,7 +222,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
self.cpu_fused_moe = cpu_fused_moe.CPUFusedMOE(layer) self.cpu_fused_moe = cpu_fused_moe.CPUFusedMOE(layer)
else: else:
self.cpu_fused_moe = cpu_fused_moe.CPUFusedMOE(layer) self.cpu_fused_moe = cpu_fused_moe.CPUFusedMOE(layer)
elif current_platform.is_xpu(): elif self.unquantized_backend == UnquantizedMoeBackend.XPU:
w13 = layer.w13_weight w13 = layer.w13_weight
w2 = layer.w2_weight w2 = layer.w2_weight
......
...@@ -13,6 +13,7 @@ from vllm.model_executor.layers.attention.mla_attention import ( ...@@ -13,6 +13,7 @@ from vllm.model_executor.layers.attention.mla_attention import (
MLACommonImpl, MLACommonImpl,
MLACommonMetadata, MLACommonMetadata,
) )
from vllm.platforms import current_platform
from vllm.platforms.interface import DeviceCapability from vllm.platforms.interface import DeviceCapability
from vllm.triton_utils import triton from vllm.triton_utils import triton
from vllm.utils.torch_utils import is_quantized_kv_cache from vllm.utils.torch_utils import is_quantized_kv_cache
...@@ -116,7 +117,7 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]): ...@@ -116,7 +117,7 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
if is_quantized_kv_cache(self.kv_cache_dtype): if is_quantized_kv_cache(self.kv_cache_dtype):
self.supports_quant_query_input = False self.supports_quant_query_input = False
self._sm_count = torch.cuda.get_device_properties(0).multi_processor_count self._sm_count = current_platform.num_compute_units()
def _flash_attn_varlen_diff_headdims( def _flash_attn_varlen_diff_headdims(
self, q, k, v, return_softmax_lse=False, softmax_scale=None, **kwargs self, q, k, v, return_softmax_lse=False, softmax_scale=None, **kwargs
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment