Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
02cabff2
Unverified
Commit
02cabff2
authored
Jul 01, 2025
by
TJian
Committed by
GitHub
Jul 01, 2025
Browse files
[V1] [ROCm] Enable EP with AITER Fused MoE (#20270)
Signed-off-by:
tjtanaa
<
tunjian.tan@embeddedllm.com
>
parent
3d19d47d
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
15 additions
and
5 deletions
+15
-5
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+1
-1
vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
+9
-2
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
...quantization/compressed_tensors/compressed_tensors_moe.py
+2
-1
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+3
-1
No files found.
vllm/model_executor/layers/fused_moe/layer.py
View file @
02cabff2
...
@@ -646,13 +646,13 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
...
@@ -646,13 +646,13 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
indices_type
=
self
.
topk_indices_dtype
)
indices_type
=
self
.
topk_indices_dtype
)
if
self
.
rocm_aiter_moe_enabled
:
if
self
.
rocm_aiter_moe_enabled
:
assert
expert_map
is
None
return
self
.
rocm_aiter_fused_experts
(
return
self
.
rocm_aiter_fused_experts
(
hidden_states
=
x
,
hidden_states
=
x
,
w1
=
layer
.
w13_weight
,
w1
=
layer
.
w13_weight
,
w2
=
layer
.
w2_weight
,
w2
=
layer
.
w2_weight
,
topk_weights
=
topk_weights
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
topk_ids
=
topk_ids
,
expert_map
=
expert_map
,
activation
=
activation
,
activation
=
activation
,
apply_router_weight_on_input
=
apply_router_weight_on_input
)
apply_router_weight_on_input
=
apply_router_weight_on_input
)
else
:
else
:
...
...
vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py
View file @
02cabff2
...
@@ -315,7 +315,8 @@ def rocm_aiter_fused_experts(
...
@@ -315,7 +315,8 @@ def rocm_aiter_fused_experts(
w2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
block_shape
:
Optional
[
list
[
int
]]
=
None
)
->
torch
.
Tensor
:
block_shape
:
Optional
[
list
[
int
]]
=
None
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
activation_method
=
(
ActivationMethod
.
SILU
activation_method
=
(
ActivationMethod
.
SILU
if
activation
==
"silu"
else
ActivationMethod
.
GELU
)
if
activation
==
"silu"
else
ActivationMethod
.
GELU
)
...
@@ -323,6 +324,11 @@ def rocm_aiter_fused_experts(
...
@@ -323,6 +324,11 @@ def rocm_aiter_fused_experts(
topk_weights
=
topk_weights
.
to
(
torch
.
float32
)
topk_weights
=
topk_weights
.
to
(
torch
.
float32
)
topk_ids
=
topk_ids
.
to
(
torch
.
int32
)
topk_ids
=
topk_ids
.
to
(
torch
.
int32
)
if
expert_map
is
not
None
:
expert_mask
=
(
expert_map
>
-
1
).
to
(
torch
.
int32
)
else
:
expert_mask
=
None
# w8a8 per-channel quantization
# w8a8 per-channel quantization
if
per_channel_quant
and
apply_router_weight_on_input
and
use_fp8_w8a8
:
if
per_channel_quant
and
apply_router_weight_on_input
and
use_fp8_w8a8
:
# AITER tkw1 kernel for FP8 models with `apply_router_weight_on_input`
# AITER tkw1 kernel for FP8 models with `apply_router_weight_on_input`
...
@@ -346,7 +352,7 @@ def rocm_aiter_fused_experts(
...
@@ -346,7 +352,7 @@ def rocm_aiter_fused_experts(
fc2_smooth_scale
=
None
,
fc2_smooth_scale
=
None
,
a16
=
False
,
a16
=
False
,
per_tensor_quant_scale
=
None
,
per_tensor_quant_scale
=
None
,
expert_mask
=
None
,
expert_mask
=
expert_mask
,
activation_method
=
activation_method
)
activation_method
=
activation_method
)
else
:
else
:
...
@@ -378,6 +384,7 @@ def rocm_aiter_fused_experts(
...
@@ -378,6 +384,7 @@ def rocm_aiter_fused_experts(
w2
,
w2
,
topk_weights
,
topk_weights
,
topk_ids
,
topk_ids
,
expert_mask
=
expert_mask
,
quant_method
=
quant_method
,
quant_method
=
quant_method
,
activation_method
=
activation_method
,
activation_method
=
activation_method
,
w1_scale
=
w1_scale
,
w1_scale
=
w1_scale
,
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
View file @
02cabff2
...
@@ -633,7 +633,8 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
...
@@ -633,7 +633,8 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
w1_scale
=
layer
.
w13_weight_scale
,
w1_scale
=
layer
.
w13_weight_scale
,
w2_scale
=
layer
.
w2_weight_scale
,
w2_scale
=
layer
.
w2_weight_scale
,
a1_scale
=
layer
.
w13_input_scale
,
a1_scale
=
layer
.
w13_input_scale
,
a2_scale
=
layer
.
w2_input_scale
)
a2_scale
=
layer
.
w2_input_scale
,
expert_map
=
expert_map
)
if
self
.
use_marlin
:
if
self
.
use_marlin
:
assert
activation
==
"silu"
,
(
assert
activation
==
"silu"
,
(
f
"
{
activation
}
not supported for Marlin MoE."
)
f
"
{
activation
}
not supported for Marlin MoE."
)
...
...
vllm/model_executor/layers/quantization/fp8.py
View file @
02cabff2
...
@@ -442,6 +442,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
...
@@ -442,6 +442,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
"""
"""
def
__init__
(
self
,
quant_config
:
Fp8Config
):
def
__init__
(
self
,
quant_config
:
Fp8Config
):
from
vllm.model_executor.layers.fused_moe
import
fused_experts
from
vllm.model_executor.layers.fused_moe
import
fused_experts
self
.
quant_config
=
quant_config
self
.
quant_config
=
quant_config
self
.
block_quant
=
self
.
quant_config
.
weight_block_size
is
not
None
self
.
block_quant
=
self
.
quant_config
.
weight_block_size
is
not
None
...
@@ -879,7 +880,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
...
@@ -879,7 +880,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
if
self
.
block_quant
else
layer
.
w2_weight_scale
),
if
self
.
block_quant
else
layer
.
w2_weight_scale
),
a1_scale
=
layer
.
w13_input_scale
,
a1_scale
=
layer
.
w13_input_scale
,
a2_scale
=
layer
.
w2_input_scale
,
a2_scale
=
layer
.
w2_input_scale
,
block_shape
=
self
.
quant_config
.
weight_block_size
)
block_shape
=
self
.
quant_config
.
weight_block_size
,
expert_map
=
expert_map
)
elif
self
.
use_marlin
:
elif
self
.
use_marlin
:
assert
activation
==
"silu"
,
(
assert
activation
==
"silu"
,
(
f
"
{
activation
}
not supported for Marlin MoE."
)
f
"
{
activation
}
not supported for Marlin MoE."
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment