Commit daefd764 authored by zhuwenwen's avatar zhuwenwen
Browse files

skip is_rocm_aiter_moe_enabled and add mla pad

parent 43a52016
...@@ -1088,8 +1088,8 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): ...@@ -1088,8 +1088,8 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
# maybe_padded_v = torch.nn.functional.pad( # maybe_padded_v = torch.nn.functional.pad(
# v, [0, q.shape[-1] - v.shape[-1]], value=0) # v, [0, q.shape[-1] - v.shape[-1]], value=0)
maybe_padded_v = torch.nn.functional.pad( maybe_padded_v = torch.nn.functional.pad(
v, [0, q.shape[-1] - v.shape[-1]] - 32, value=0) v, [0, q.shape[-1] - v.shape[-1]- 32], value=0)
v_tmp = maybe_padded_v[..., :-32].reshape(v.shape[0], v.shape[1],v.shape[2]) maybe_padded_v = maybe_padded_v[..., :-32].reshape(v.shape[0], v.shape[1],v.shape[2])
if is_hip and envs.VLLM_USE_TRITON_FLASH_ATTN \ if is_hip and envs.VLLM_USE_TRITON_FLASH_ATTN \
and not return_softmax_lse: and not return_softmax_lse:
...@@ -1120,8 +1120,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): ...@@ -1120,8 +1120,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
attn_out = self.flash_attn_varlen_func( attn_out = self.flash_attn_varlen_func(
q=q, q=q,
k=k, k=k,
# v=maybe_padded_v, v = maybe_padded_v,
v = v_tmp,
return_attn_probs=return_softmax_lse, return_attn_probs=return_softmax_lse,
softmax_scale=softmax_scale, softmax_scale=softmax_scale,
**kwargs, **kwargs,
......
...@@ -23,7 +23,7 @@ from vllm.model_executor.layers.quantization.utils.int8_utils import ( ...@@ -23,7 +23,7 @@ from vllm.model_executor.layers.quantization.utils.int8_utils import (
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op from vllm.utils import direct_register_custom_op
from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled # from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled
logger = init_logger(__name__) logger = init_logger(__name__)
device_name = current_platform.get_device_name().replace(" ", "_") device_name = current_platform.get_device_name().replace(" ", "_")
...@@ -1141,9 +1141,9 @@ def vllm_topk_softmax(topk_weights: torch.Tensor, topk_indices: torch.Tensor, ...@@ -1141,9 +1141,9 @@ def vllm_topk_softmax(topk_weights: torch.Tensor, topk_indices: torch.Tensor,
def dispatch_topk_func() -> Callable[..., tuple[torch.Tensor, ...]]: def dispatch_topk_func() -> Callable[..., tuple[torch.Tensor, ...]]:
if is_rocm_aiter_moe_enabled(): # if is_rocm_aiter_moe_enabled():
from .rocm_aiter_fused_moe import rocm_aiter_topk_softmax # from .rocm_aiter_fused_moe import rocm_aiter_topk_softmax
return rocm_aiter_topk_softmax # return rocm_aiter_topk_softmax
return vllm_topk_softmax return vllm_topk_softmax
...@@ -1405,9 +1405,9 @@ def torch_vllm_outplace_fused_experts(**kwargs) -> torch.Tensor: ...@@ -1405,9 +1405,9 @@ def torch_vllm_outplace_fused_experts(**kwargs) -> torch.Tensor:
def dispatch_fused_experts_func(inplace: bool) -> Callable[..., torch.Tensor]: def dispatch_fused_experts_func(inplace: bool) -> Callable[..., torch.Tensor]:
if is_rocm_aiter_moe_enabled(): # if is_rocm_aiter_moe_enabled():
from .rocm_aiter_fused_moe import rocm_aiter_fused_experts # from .rocm_aiter_fused_moe import rocm_aiter_fused_experts
return rocm_aiter_fused_experts # return rocm_aiter_fused_experts
if inplace: if inplace:
return torch_vllm_inplace_fused_experts return torch_vllm_inplace_fused_experts
return torch_vllm_outplace_fused_experts return torch_vllm_outplace_fused_experts
......
...@@ -135,15 +135,15 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -135,15 +135,15 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
layer.w13_weight.data = self._maybe_pad_weight(layer.w13_weight.data) layer.w13_weight.data = self._maybe_pad_weight(layer.w13_weight.data)
layer.w2_weight.data = self._maybe_pad_weight(layer.w2_weight.data) layer.w2_weight.data = self._maybe_pad_weight(layer.w2_weight.data)
# Lazy import to avoid importing triton. # Lazy import to avoid importing triton.
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
is_rocm_aiter_moe_enabled, shuffle_weights) # is_rocm_aiter_moe_enabled, shuffle_weights)
if is_rocm_aiter_moe_enabled(): # if is_rocm_aiter_moe_enabled():
# reshaping weights is required for aiter moe kernel. # # reshaping weights is required for aiter moe kernel.
shuffled_w13, shuffled_w2 = shuffle_weights( # shuffled_w13, shuffled_w2 = shuffle_weights(
layer.w13_weight.data, layer.w2_weight.data) # layer.w13_weight.data, layer.w2_weight.data)
layer.w13_weight.data = shuffled_w13 # layer.w13_weight.data = shuffled_w13
layer.w2_weight.data = shuffled_w2 # layer.w2_weight.data = shuffled_w2
if current_platform.is_cpu(): if current_platform.is_cpu():
if current_platform.get_cpu_architecture() == CpuArchEnum.X86: if current_platform.get_cpu_architecture() == CpuArchEnum.X86:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment