Unverified commit eef36472, authored by vllmellm, committed by GitHub
Browse files

[FEAT] [ROCm]: AITER Fused MOE V1 Support (#16752)


Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
Co-authored-by: tjtanaa <tunjian.tan@embeddedllm.com>
parent 0d6e187e
...@@ -11,6 +11,8 @@ from vllm.model_executor.layers.fused_moe.fused_moe import ( ...@@ -11,6 +11,8 @@ from vllm.model_executor.layers.fused_moe.fused_moe import (
dispatch_fused_experts_func, dispatch_topk_func, dispatch_fused_experts_func, dispatch_topk_func,
torch_vllm_inplace_fused_experts, torch_vllm_outplace_fused_experts, torch_vllm_inplace_fused_experts, torch_vllm_outplace_fused_experts,
vllm_topk_softmax) vllm_topk_softmax)
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
is_rocm_aiter_moe_enabled)
from vllm.model_executor.layers.layernorm import ( from vllm.model_executor.layers.layernorm import (
RMSNorm, dispatch_cuda_rmsnorm_func, fused_add_rms_norm, rms_norm, RMSNorm, dispatch_cuda_rmsnorm_func, fused_add_rms_norm, rms_norm,
rocm_aiter_fused_add_rms_norm, rocm_aiter_rms_norm) rocm_aiter_fused_add_rms_norm, rocm_aiter_rms_norm)
...@@ -100,11 +102,10 @@ def test_enabled_ops_invalid(env: str): ...@@ -100,11 +102,10 @@ def test_enabled_ops_invalid(env: str):
def test_topk_dispatch(use_rocm_aiter: str, monkeypatch):
    """Check which top-k function ``dispatch_topk_func`` returns.

    When the current platform is ROCm and ``use_rocm_aiter`` parses to a
    non-zero integer, the AITER top-k softmax is expected; otherwise the
    default ``vllm_topk_softmax`` is expected.

    Args:
        use_rocm_aiter: Value written to the ``VLLM_ROCM_USE_AITER``
            environment variable; it is parsed with ``int`` below, so it
            must be an integer string.
        monkeypatch: Fixture object whose ``setenv`` is used to set the
            environment variable for this test.
    """
    monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter)
    # Drop the memoized result *before* dispatching, in the same order as
    # test_fused_experts_dispatch. Clearing afterwards would let the
    # dispatch observe a value cached under a previous environment.
    # NOTE(review): assumes dispatch_topk_func consults
    # is_rocm_aiter_moe_enabled -- confirm against fused_moe.py.
    is_rocm_aiter_moe_enabled.cache_clear()
    topk_func = dispatch_topk_func()
    if current_platform.is_rocm() and int(use_rocm_aiter):
        # Imported lazily so the AITER module is only touched on the
        # branch that actually expects it.
        from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
            rocm_aiter_topk_softmax)
        assert topk_func == rocm_aiter_topk_softmax
    else:
        assert topk_func == vllm_topk_softmax
...@@ -116,11 +117,11 @@ def test_fused_experts_dispatch(use_rocm_aiter: str, inplace: bool, ...@@ -116,11 +117,11 @@ def test_fused_experts_dispatch(use_rocm_aiter: str, inplace: bool,
monkeypatch): monkeypatch):
monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter) monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter)
is_rocm_aiter_moe_enabled.cache_clear()
fused_experts_func = dispatch_fused_experts_func(inplace) fused_experts_func = dispatch_fused_experts_func(inplace)
if current_platform.is_rocm() and int(use_rocm_aiter): if current_platform.is_rocm() and int(use_rocm_aiter):
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
rocm_aiter_fused_experts) rocm_aiter_fused_experts)
assert fused_experts_func == rocm_aiter_fused_experts assert fused_experts_func == rocm_aiter_fused_experts
elif inplace: elif inplace:
assert fused_experts_func == torch_vllm_inplace_fused_experts assert fused_experts_func == torch_vllm_inplace_fused_experts
......
...@@ -304,9 +304,9 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -304,9 +304,9 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
e_score_correction_bias=e_score_correction_bias) e_score_correction_bias=e_score_correction_bias)
return self.fused_experts_func( return self.fused_experts_func(
x, hidden_states=x,
layer.w13_weight, w1=layer.w13_weight,
layer.w2_weight, w2=layer.w2_weight,
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
inplace=True, inplace=True,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment