Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
31b96d1c
Unverified
Commit
31b96d1c
authored
Jul 10, 2025
by
Michael Goin
Committed by
GitHub
Jul 09, 2025
Browse files
Support Llama 4 for cutlass_moe_fp4 (#20453)
Signed-off-by:
mgoin
<
mgoin64@gmail.com
>
parent
e59ba9e1
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
80 additions
and
74 deletions
+80
-74
vllm/model_executor/layers/fused_moe/cutlass_moe.py
vllm/model_executor/layers/fused_moe/cutlass_moe.py
+28
-9
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
...quantization/compressed_tensors/compressed_tensors_moe.py
+20
-20
vllm/model_executor/layers/quantization/modelopt.py
vllm/model_executor/layers/quantization/modelopt.py
+32
-45
No files found.
vllm/model_executor/layers/fused_moe/cutlass_moe.py
View file @
31b96d1c
...
...
@@ -411,13 +411,23 @@ FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()
FLOAT8_E4M3_MAX
=
torch
.
finfo
(
torch
.
float8_e4m3fn
).
max
def
cutlass_moe_fp4
(
a
:
torch
.
Tensor
,
a1_gscale
:
torch
.
Tensor
,
w1_fp4
:
torch
.
Tensor
,
w1_blockscale
:
torch
.
Tensor
,
w1_alphas
:
torch
.
Tensor
,
a2_gscale
:
torch
.
Tensor
,
w2_fp4
:
torch
.
Tensor
,
w2_blockscale
:
torch
.
Tensor
,
w2_alphas
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
m
:
int
,
n
:
int
,
k
:
int
,
e
:
int
,
device
:
torch
.
device
):
def
cutlass_moe_fp4
(
a
:
torch
.
Tensor
,
a1_gscale
:
torch
.
Tensor
,
w1_fp4
:
torch
.
Tensor
,
w1_blockscale
:
torch
.
Tensor
,
w1_alphas
:
torch
.
Tensor
,
a2_gscale
:
torch
.
Tensor
,
w2_fp4
:
torch
.
Tensor
,
w2_blockscale
:
torch
.
Tensor
,
w2_alphas
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
m
:
int
,
n
:
int
,
k
:
int
,
e
:
int
,
device
:
torch
.
device
,
apply_router_weight_on_input
:
bool
=
False
):
"""
MoE implementation for FP4 Inputs
...
...
@@ -480,6 +490,12 @@ def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor,
a_map
=
torch
.
empty
((
topk_ids
.
numel
()),
dtype
=
torch
.
int32
,
device
=
device
)
c_map
=
torch
.
empty
((
topk_ids
.
numel
()),
dtype
=
torch
.
int32
,
device
=
device
)
if
apply_router_weight_on_input
:
# TODO: this only works for topK=1, will need to update for topK>1
assert
num_topk
==
1
,
\
"apply_router_weight_on_input is only implemented for topk=1"
a
.
mul_
(
topk_weights
.
to
(
out_dtype
))
# problem shapes should have [m, n, k]
# Note that problem sizes are based on logical number of elements.
ops
.
get_cutlass_moe_mm_data
(
topk_ids
,
expert_offsets
,
problem_sizes1
,
...
...
@@ -517,8 +533,11 @@ def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor,
del
int_fp4
,
int_blockscale
c2
=
ops
.
shuffle_rows
(
c2
,
c_map
)
out
=
(
c2
.
view
(
m
,
num_topk
,
k
)
*
topk_weights
.
view
(
m
,
num_topk
,
1
).
half
()).
sum
(
dim
=
1
)
if
not
apply_router_weight_on_input
:
out
=
(
c2
.
view
(
m
,
num_topk
,
k
)
*
topk_weights
.
view
(
m
,
num_topk
,
1
).
to
(
out_dtype
)).
sum
(
dim
=
1
)
else
:
out
=
c2
.
view
(
m
,
num_topk
,
k
).
sum
(
dim
=
1
)
return
out
.
to
(
dtype
=
out_dtype
)
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
View file @
31b96d1c
...
...
@@ -295,6 +295,7 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
if
enable_eplb
:
raise
NotImplementedError
(
"EPLB not supported for "
"`CompressedTensorsW4A4MoeMethod` yet."
)
assert
activation
==
"silu"
,
"Only SiLU activation is supported."
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
hidden_states
=
x
,
...
...
@@ -326,10 +327,6 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
)
assert
activation
==
"silu"
,
"Only SiLU activation is supported."
assert
not
apply_router_weight_on_input
,
(
"Router weight on input is not "
"supported for CompressedTensorsW4A4MoeMethod."
)
assert
expert_map
is
None
,
(
"Expert Parallelism / expert_map "
"is currently not supported for "
"CompressedTensorsW4A4MoeMethod."
)
...
...
@@ -339,22 +336,25 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
# Cutlass moe takes in activations in BF16/Half precision
# and fp4 quantized weights loaded from the checkpoint
return
cutlass_moe_fp4
(
a
=
x
,
w1_fp4
=
layer
.
w13_weight
,
w1_blockscale
=
layer
.
w13_blockscale_swizzled
,
w1_alphas
=
layer
.
g1_alphas
,
w2_fp4
=
layer
.
w2_weight
,
w2_blockscale
=
layer
.
w2_blockscale_swizzled
,
w2_alphas
=
layer
.
g2_alphas
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
m
=
x
.
shape
[
0
],
n
=
layer
.
w2_weight
.
shape
[
2
]
*
2
,
k
=
x
.
shape
[
1
],
e
=
layer
.
w13_weight
.
shape
[
0
],
a1_gscale
=
layer
.
w13_input_scale_quant
,
a2_gscale
=
layer
.
w2_input_scale_quant
,
device
=
x
.
device
).
to
(
x
.
dtype
)
return
cutlass_moe_fp4
(
a
=
x
,
w1_fp4
=
layer
.
w13_weight
,
w1_blockscale
=
layer
.
w13_blockscale_swizzled
,
w1_alphas
=
layer
.
g1_alphas
,
w2_fp4
=
layer
.
w2_weight
,
w2_blockscale
=
layer
.
w2_blockscale_swizzled
,
w2_alphas
=
layer
.
g2_alphas
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
m
=
x
.
shape
[
0
],
n
=
layer
.
w2_weight
.
shape
[
2
]
*
2
,
k
=
x
.
shape
[
1
],
e
=
layer
.
w13_weight
.
shape
[
0
],
a1_gscale
=
layer
.
w13_input_scale_quant
,
a2_gscale
=
layer
.
w2_input_scale_quant
,
device
=
x
.
device
,
apply_router_weight_on_input
=
apply_router_weight_on_input
).
to
(
x
.
dtype
)
class
CompressedTensorsW8A8Fp8MoEMethod
(
CompressedTensorsMoEMethod
):
...
...
vllm/model_executor/layers/quantization/modelopt.py
View file @
31b96d1c
...
...
@@ -673,21 +673,21 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
if
enable_eplb
:
raise
NotImplementedError
(
"EPLB not supported for `ModelOptNvFp4FusedMoE` yet."
)
assert
activation
==
"silu"
,
"Only SiLU activation is supported."
if
self
.
use_marlin
:
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
hidden_states
=
x
,
router_logits
=
router_logits
,
use_grouped_topk
=
use_grouped_topk
,
top_k
=
top_k
,
renormalize
=
renormalize
,
topk_group
=
topk_group
,
num_expert_group
=
num_expert_group
,
custom_routing_function
=
custom_routing_function
,
scoring_func
=
scoring_func
,
e_score_correction_bias
=
e_score_correction_bias
,
)
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
hidden_states
=
x
,
router_logits
=
router_logits
,
use_grouped_topk
=
use_grouped_topk
,
top_k
=
top_k
,
renormalize
=
renormalize
,
topk_group
=
topk_group
,
num_expert_group
=
num_expert_group
,
custom_routing_function
=
custom_routing_function
,
scoring_func
=
scoring_func
,
e_score_correction_bias
=
e_score_correction_bias
)
if
self
.
use_marlin
:
return
torch
.
ops
.
vllm
.
fused_marlin_moe
(
x
,
layer
.
w13_weight
,
...
...
@@ -704,44 +704,31 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
)
assert
activation
==
"silu"
,
"Only SiLU activation is supported."
assert
not
apply_router_weight_on_input
,
(
"Router weight on input is not "
"supported for ModelOptNvFp4FusedMoE."
)
assert
expert_map
is
None
,
(
"Expert Parallelism / expert_map "
"is currently not supported for "
"ModelOptNvFp4FusedMoE."
)
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
hidden_states
=
x
,
router_logits
=
router_logits
,
use_grouped_topk
=
use_grouped_topk
,
top_k
=
top_k
,
renormalize
=
renormalize
,
topk_group
=
topk_group
,
num_expert_group
=
num_expert_group
,
custom_routing_function
=
custom_routing_function
,
scoring_func
=
scoring_func
,
e_score_correction_bias
=
e_score_correction_bias
)
from
vllm.model_executor.layers.fused_moe.cutlass_moe
import
(
cutlass_moe_fp4
)
# Cutlass moe takes in activations in BF16/Half precision
# and fp4 quantized weights loaded from the checkpoint
return
cutlass_moe_fp4
(
a
=
x
,
w1_fp4
=
layer
.
w13_weight
,
w1_blockscale
=
layer
.
w13_blockscale_swizzled
,
w1_alphas
=
layer
.
g1_alphas
,
w2_fp4
=
layer
.
w2_weight
,
w2_blockscale
=
layer
.
w2_blockscale_swizzled
,
w2_alphas
=
layer
.
g2_alphas
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
m
=
x
.
shape
[
0
],
n
=
layer
.
w2_weight
.
shape
[
2
]
*
2
,
k
=
x
.
shape
[
1
],
e
=
layer
.
w13_weight
.
shape
[
0
],
a1_gscale
=
layer
.
w13_input_scale_quant
,
a2_gscale
=
layer
.
w2_input_scale_quant
,
device
=
x
.
device
).
to
(
x
.
dtype
)
return
cutlass_moe_fp4
(
a
=
x
,
w1_fp4
=
layer
.
w13_weight
,
w1_blockscale
=
layer
.
w13_blockscale_swizzled
,
w1_alphas
=
layer
.
g1_alphas
,
w2_fp4
=
layer
.
w2_weight
,
w2_blockscale
=
layer
.
w2_blockscale_swizzled
,
w2_alphas
=
layer
.
g2_alphas
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
m
=
x
.
shape
[
0
],
n
=
layer
.
w2_weight
.
shape
[
2
]
*
2
,
k
=
x
.
shape
[
1
],
e
=
layer
.
w13_weight
.
shape
[
0
],
a1_gscale
=
layer
.
w13_input_scale_quant
,
a2_gscale
=
layer
.
w2_input_scale_quant
,
device
=
x
.
device
,
apply_router_weight_on_input
=
apply_router_weight_on_input
).
to
(
x
.
dtype
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment