Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
3b3b778d
Unverified
Commit
3b3b778d
authored
Jul 13, 2025
by
ElizaWszola
Committed by
GitHub
Jul 12, 2025
Browse files
[Bugfix] Fix a couple PPLX+CUTLASS MoE bugs (#20825)
Signed-off-by:
ElizaWszola
<
ewszola@redhat.com
>
parent
42d440c2
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
37 additions
and
20 deletions
+37
-20
vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py
.../model_executor/layers/fused_moe/pplx_prepare_finalize.py
+2
-2
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
...quantization/compressed_tensors/compressed_tensors_moe.py
+35
-18
No files found.
vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py
View file @
3b3b778d
...
...
@@ -204,7 +204,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
out_expert_x_scale
=
expert_x_scale
,
dp_x
=
a1q
,
dp_x_scale
=
a1q_scale
,
indices
=
topk_ids
,
indices
=
topk_ids
.
view
(
dtype
=
torch
.
uint32
)
,
bound_m
=
bound_m
,
)
...
...
@@ -249,7 +249,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
topk_weights
=
torch
.
ones_like
(
topk_weights
)
self
.
a2a
.
combine
(
out_tokens
=
output
,
indices
=
topk_ids
,
indices
=
topk_ids
.
view
(
dtype
=
torch
.
uint32
)
,
weights
=
topk_weights
,
expert_y
=
fused_expert_output
,
bound_m
=
bound_m
)
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
View file @
3b3b778d
...
...
@@ -737,10 +737,8 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
"For FP8 Fused MoE layer, we require either per tensor or "
"channelwise, dynamic per token quantization."
)
from
vllm.model_executor.layers.fused_moe.cutlass_moe
import
(
cutlass_moe_fp8
)
self
.
topk_indices_dtype
=
None
self
.
fused_experts
=
cutlass_moe_fp8
# type: ignore
self
.
fused_experts
=
None
# type: ignore
self
.
disable_expert_map
=
False
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
num_experts
:
int
,
...
...
@@ -936,7 +934,11 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
per_act_token
=
a1_scale
.
numel
()
!=
1
if
a1_scale
is
not
None
else
(
a2_scale
.
numel
()
!=
1
if
a2_scale
is
not
None
else
False
)
return
self
.
fused_experts
(
if
self
.
fused_experts
is
None
:
# If no modular kernel is provided, use cutlass_moe_fp8
from
vllm.model_executor.layers.fused_moe.cutlass_moe
import
(
cutlass_moe_fp8
)
return
cutlass_moe_fp8
(
x
,
layer
.
w13_weight
,
layer
.
w2_weight
,
...
...
@@ -951,6 +953,21 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
a1_scale
=
a1_scale
,
a2_scale
=
a2_scale
,
)
else
:
return
self
.
fused_experts
(
x
,
layer
.
w13_weight
,
layer
.
w2_weight
,
topk_weights
,
topk_ids
,
activation
=
activation
,
global_num_experts
=
global_num_experts
,
expert_map
=
None
if
self
.
disable_expert_map
else
expert_map
,
w1_scale
=
layer
.
w13_weight_scale
,
w2_scale
=
layer
.
w2_weight_scale
,
a1_scale
=
layer
.
w13_input_scale
,
a2_scale
=
layer
.
w2_input_scale
,
)
class
CompressedTensorsW8A8Int8MoEMethod
(
CompressedTensorsMoEMethod
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment