Commit daefd764 authored by zhuwenwen
Browse files

skip is_rocm_aiter_moe_enabled and add mla pad

parent 43a52016
......@@ -1088,8 +1088,8 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
# maybe_padded_v = torch.nn.functional.pad(
# v, [0, q.shape[-1] - v.shape[-1]], value=0)
maybe_padded_v = torch.nn.functional.pad(
v, [0, q.shape[-1] - v.shape[-1]] - 32, value=0)
v_tmp = maybe_padded_v[..., :-32].reshape(v.shape[0], v.shape[1],v.shape[2])
v, [0, q.shape[-1] - v.shape[-1]- 32], value=0)
maybe_padded_v = maybe_padded_v[..., :-32].reshape(v.shape[0], v.shape[1],v.shape[2])
if is_hip and envs.VLLM_USE_TRITON_FLASH_ATTN \
and not return_softmax_lse:
......@@ -1120,8 +1120,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
attn_out = self.flash_attn_varlen_func(
q=q,
k=k,
# v=maybe_padded_v,
v = v_tmp,
v = maybe_padded_v,
return_attn_probs=return_softmax_lse,
softmax_scale=softmax_scale,
**kwargs,
......
......@@ -23,7 +23,7 @@ from vllm.model_executor.layers.quantization.utils.int8_utils import (
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op
from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled
# from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled
logger = init_logger(__name__)
device_name = current_platform.get_device_name().replace(" ", "_")
......@@ -1141,9 +1141,9 @@ def vllm_topk_softmax(topk_weights: torch.Tensor, topk_indices: torch.Tensor,
def dispatch_topk_func() -> Callable[..., tuple[torch.Tensor, ...]]:
if is_rocm_aiter_moe_enabled():
from .rocm_aiter_fused_moe import rocm_aiter_topk_softmax
return rocm_aiter_topk_softmax
# if is_rocm_aiter_moe_enabled():
# from .rocm_aiter_fused_moe import rocm_aiter_topk_softmax
# return rocm_aiter_topk_softmax
return vllm_topk_softmax
......@@ -1405,9 +1405,9 @@ def torch_vllm_outplace_fused_experts(**kwargs) -> torch.Tensor:
def dispatch_fused_experts_func(inplace: bool) -> Callable[..., torch.Tensor]:
if is_rocm_aiter_moe_enabled():
from .rocm_aiter_fused_moe import rocm_aiter_fused_experts
return rocm_aiter_fused_experts
# if is_rocm_aiter_moe_enabled():
# from .rocm_aiter_fused_moe import rocm_aiter_fused_experts
# return rocm_aiter_fused_experts
if inplace:
return torch_vllm_inplace_fused_experts
return torch_vllm_outplace_fused_experts
......
......@@ -135,15 +135,15 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
layer.w13_weight.data = self._maybe_pad_weight(layer.w13_weight.data)
layer.w2_weight.data = self._maybe_pad_weight(layer.w2_weight.data)
# Lazy import to avoid importing triton.
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
is_rocm_aiter_moe_enabled, shuffle_weights)
if is_rocm_aiter_moe_enabled():
# reshaping weights is required for aiter moe kernel.
shuffled_w13, shuffled_w2 = shuffle_weights(
layer.w13_weight.data, layer.w2_weight.data)
layer.w13_weight.data = shuffled_w13
layer.w2_weight.data = shuffled_w2
# from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
# is_rocm_aiter_moe_enabled, shuffle_weights)
# if is_rocm_aiter_moe_enabled():
# # reshaping weights is required for aiter moe kernel.
# shuffled_w13, shuffled_w2 = shuffle_weights(
# layer.w13_weight.data, layer.w2_weight.data)
# layer.w13_weight.data = shuffled_w13
# layer.w2_weight.data = shuffled_w2
if current_platform.is_cpu():
if current_platform.get_cpu_architecture() == CpuArchEnum.X86:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment