Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
9a2a6357
Unverified
Commit
9a2a6357
authored
May 13, 2025
by
Michael Goin
Committed by
GitHub
May 13, 2025
Browse files
[Bugfix] Fix FP8 Marlin MoE and enable for compressed-tensors models (#18026)
Signed-off-by:
mgoin
<
mgoin64@gmail.com
>
parent
6266c57b
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
53 additions
and
6 deletions
+53
-6
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
...quantization/compressed_tensors/compressed_tensors_moe.py
+48
-6
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+4
-0
vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py
...el_executor/layers/quantization/utils/marlin_utils_fp8.py
+1
-0
No files found.
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
View file @
9a2a6357
...
@@ -9,6 +9,7 @@ from compressed_tensors import CompressionFormat
...
@@ -9,6 +9,7 @@ from compressed_tensors import CompressionFormat
from
compressed_tensors.quantization
import
(
ActivationOrdering
,
from
compressed_tensors.quantization
import
(
ActivationOrdering
,
QuantizationStrategy
)
QuantizationStrategy
)
import
vllm.envs
as
envs
import
vllm.model_executor.layers.fused_moe
# noqa
import
vllm.model_executor.layers.fused_moe
# noqa
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
...
@@ -20,10 +21,13 @@ from vllm.model_executor.layers.quantization.utils import replace_parameter
...
@@ -20,10 +21,13 @@ from vllm.model_executor.layers.quantization.utils import replace_parameter
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
check_moe_marlin_supports_layer
,
marlin_make_workspace_new
,
check_moe_marlin_supports_layer
,
marlin_make_workspace_new
,
marlin_moe_permute_scales
)
marlin_moe_permute_scales
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils_fp8
import
(
prepare_moe_fp8_layer_for_marlin
)
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
all_close_1d
,
normalize_e4m3fn_to_e4m3fnuz
,
per_tensor_dequantize
)
all_close_1d
,
normalize_e4m3fn_to_e4m3fnuz
,
per_tensor_dequantize
)
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.scalar_type
import
scalar_types
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -114,10 +118,24 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
...
@@ -114,10 +118,24 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
"For FP8 Fused MoE layer, we require either per tensor or "
"For FP8 Fused MoE layer, we require either per tensor or "
"channelwise, dynamic per token quantization."
)
"channelwise, dynamic per token quantization."
)
# For GPUs that lack FP8 hardware support, we can leverage the Marlin
# kernel for fast weight-only FP8 quantization
self
.
use_marlin
=
(
not
current_platform
.
has_device_capability
(
89
)
or
envs
.
VLLM_TEST_FORCE_FP8_MARLIN
)
# Disable marlin for rocm
if
current_platform
.
is_rocm
():
self
.
use_marlin
=
False
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
num_experts
:
int
,
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
num_experts
:
int
,
hidden_size
:
int
,
intermediate_size_per_partition
:
int
,
hidden_size
:
int
,
intermediate_size_per_partition
:
int
,
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
):
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
):
layer
.
intermediate_size_per_partition
=
intermediate_size_per_partition
layer
.
hidden_size
=
hidden_size
layer
.
num_experts
=
num_experts
layer
.
orig_dtype
=
params_dtype
layer
.
weight_block_size
=
None
params_dtype
=
torch
.
float8_e4m3fn
params_dtype
=
torch
.
float8_e4m3fn
# WEIGHTS
# WEIGHTS
...
@@ -280,6 +298,12 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
...
@@ -280,6 +298,12 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
from
vllm.model_executor.layers.fused_moe
import
fused_experts
from
vllm.model_executor.layers.fused_moe
import
fused_experts
self
.
fused_experts_func
=
fused_experts
self
.
fused_experts_func
=
fused_experts
if
self
.
use_marlin
:
prepare_moe_fp8_layer_for_marlin
(
layer
,
False
)
# Activations not quantized for marlin.
del
layer
.
w13_input_scale
del
layer
.
w2_input_scale
def
apply
(
def
apply
(
self
,
self
,
layer
:
torch
.
nn
.
Module
,
layer
:
torch
.
nn
.
Module
,
...
@@ -311,6 +335,24 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
...
@@ -311,6 +335,24 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
scoring_func
=
scoring_func
,
scoring_func
=
scoring_func
,
e_score_correction_bias
=
e_score_correction_bias
)
e_score_correction_bias
=
e_score_correction_bias
)
if
self
.
use_marlin
:
assert
activation
==
"silu"
,
(
f
"
{
activation
}
not supported for Marlin MoE."
)
assert
not
apply_router_weight_on_input
,
(
"Apply router weight on input not supported for Marlin MoE."
)
return
torch
.
ops
.
vllm
.
fused_marlin_moe
(
x
,
layer
.
w13_weight
,
layer
.
w2_weight
,
layer
.
w13_weight_scale
,
layer
.
w2_weight_scale
,
router_logits
,
topk_weights
,
topk_ids
,
quant_type_id
=
scalar_types
.
float8_e4m3fn
.
id
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
)
return
self
.
fused_experts_func
(
return
self
.
fused_experts_func
(
hidden_states
=
x
,
hidden_states
=
x
,
w1
=
layer
.
w13_weight
,
w1
=
layer
.
w13_weight
,
...
@@ -517,7 +559,8 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
...
@@ -517,7 +559,8 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
activation
:
str
=
"silu"
,
activation
:
str
=
"silu"
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
assert
activation
==
"silu"
assert
activation
==
"silu"
,
(
f
"
{
activation
}
not supported for Cutlass MoE."
)
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
hidden_states
=
x
,
hidden_states
=
x
,
...
@@ -942,11 +985,10 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
...
@@ -942,11 +985,10 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
apply_router_weight_on_input
:
bool
=
False
,
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
activation
:
str
=
"silu"
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
assert
activation
==
"silu"
,
"Only SiLU activation is supported."
assert
activation
==
"silu"
,
(
if
apply_router_weight_on_input
:
f
"
{
activation
}
not supported for Marlin MoE."
)
raise
NotImplementedError
(
assert
not
apply_router_weight_on_input
,
(
"Apply router weight on input is not supported for "
"Apply router weight on input not supported for Marlin MoE."
)
"fused Marlin MoE method."
)
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
hidden_states
=
x
,
hidden_states
=
x
,
...
...
vllm/model_executor/layers/quantization/fp8.py
View file @
9a2a6357
...
@@ -811,6 +811,10 @@ class Fp8MoEMethod(FusedMoEMethodBase):
...
@@ -811,6 +811,10 @@ class Fp8MoEMethod(FusedMoEMethodBase):
)
)
if
self
.
use_marlin
:
if
self
.
use_marlin
:
assert
activation
==
"silu"
,
(
f
"
{
activation
}
not supported for Marlin MoE."
)
assert
not
apply_router_weight_on_input
,
(
"Apply router weight on input not supported for Marlin MoE."
)
return
torch
.
ops
.
vllm
.
fused_marlin_moe
(
return
torch
.
ops
.
vllm
.
fused_marlin_moe
(
x
,
x
,
layer
.
w13_weight
,
layer
.
w13_weight
,
...
...
vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py
View file @
9a2a6357
...
@@ -268,6 +268,7 @@ def prepare_moe_fp8_layer_for_marlin(layer: torch.nn.Module,
...
@@ -268,6 +268,7 @@ def prepare_moe_fp8_layer_for_marlin(layer: torch.nn.Module,
tensor_list
.
append
(
marlin_scales
)
tensor_list
.
append
(
marlin_scales
)
scales
=
torch
.
cat
([
x
.
unsqueeze
(
0
)
for
x
in
tensor_list
],
0
)
scales
=
torch
.
cat
([
x
.
unsqueeze
(
0
)
for
x
in
tensor_list
],
0
)
scales
=
fp8_fused_exponent_bias_into_scales
(
scales
)
scales
=
torch
.
nn
.
Parameter
(
scales
,
requires_grad
=
False
)
scales
=
torch
.
nn
.
Parameter
(
scales
,
requires_grad
=
False
)
setattr
(
layer
,
name
+
"_weight_scale"
,
scales
)
setattr
(
layer
,
name
+
"_weight_scale"
,
scales
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment