Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
868c546d
Unverified
Commit
868c546d
authored
May 02, 2025
by
Michael Goin
Committed by
GitHub
May 02, 2025
Browse files
Support W8A8 INT8 MoE for compressed-tensors (#16745)
Signed-off-by:
mgoin
<
mgoin64@gmail.com
>
parent
99404f53
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
136 additions
and
1 deletion
+136
-1
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
...quantization/compressed_tensors/compressed_tensors_moe.py
+135
-0
vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py
...executor/layers/quantization/kernels/scaled_mm/cutlass.py
+1
-1
No files found.
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
View file @
868c546d
...
@@ -34,6 +34,7 @@ __all__ = [
...
@@ -34,6 +34,7 @@ __all__ = [
"CompressedTensorsMoEMethod"
,
"CompressedTensorsMoEMethod"
,
"CompressedTensorsW8A8Fp8MoEMethod"
,
"CompressedTensorsW8A8Fp8MoEMethod"
,
"CompressedTensorsW8A8Fp8MoECutlassMethod"
,
"CompressedTensorsW8A8Fp8MoECutlassMethod"
,
"CompressedTensorsW8A8Int8MoEMethod"
,
"CompressedTensorsWNA16MarlinMoEMethod"
,
"CompressedTensorsWNA16MarlinMoEMethod"
,
"CompressedTensorsWNA16MoEMethod"
,
"CompressedTensorsWNA16MoEMethod"
,
]
]
...
@@ -71,6 +72,8 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
...
@@ -71,6 +72,8 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
return
CompressedTensorsW8A8Fp8MoECutlassMethod
(
quant_config
)
return
CompressedTensorsW8A8Fp8MoECutlassMethod
(
quant_config
)
elif
quant_config
.
_is_fp8_w8a8
(
weight_quant
,
input_quant
):
elif
quant_config
.
_is_fp8_w8a8
(
weight_quant
,
input_quant
):
return
CompressedTensorsW8A8Fp8MoEMethod
(
quant_config
)
return
CompressedTensorsW8A8Fp8MoEMethod
(
quant_config
)
elif
quant_config
.
_is_dynamic_token_w8a8
(
weight_quant
,
input_quant
):
return
CompressedTensorsW8A8Int8MoEMethod
(
quant_config
)
else
:
else
:
raise
RuntimeError
(
raise
RuntimeError
(
f
"Unsupported FusedMoe scheme:
{
weight_quant
}
,
{
input_quant
}
"
)
f
"Unsupported FusedMoe scheme:
{
weight_quant
}
,
{
input_quant
}
"
)
...
@@ -545,6 +548,138 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
...
@@ -545,6 +548,138 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
)
)
class
CompressedTensorsW8A8Int8MoEMethod
(
CompressedTensorsMoEMethod
):
def
__init__
(
self
,
quant_config
:
"CompressedTensorsConfig"
# type: ignore # noqa E501
):
self
.
quant_config
=
quant_config
self
.
weight_quant
=
self
.
quant_config
.
target_scheme_map
[
"Linear"
].
get
(
"weights"
)
self
.
input_quant
=
self
.
quant_config
.
target_scheme_map
[
"Linear"
].
get
(
"input_activations"
)
per_channel
=
(
self
.
weight_quant
.
strategy
==
QuantizationStrategy
.
CHANNEL
and
self
.
input_quant
.
strategy
==
QuantizationStrategy
.
TOKEN
)
if
not
per_channel
:
raise
ValueError
(
"For INT8 Fused MoE layers, we require channelwise, "
"dynamic per token quantization. Found "
f
"
{
self
.
weight_quant
}
,
{
self
.
input_quant
}
"
)
self
.
static_input_scales
=
not
self
.
input_quant
.
dynamic
if
self
.
static_input_scales
:
raise
ValueError
(
"For INT8 Fused MoE layers, we require channelwise, "
"dynamic per token quantization. Found static input scales."
)
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
num_experts
:
int
,
hidden_size
:
int
,
intermediate_size_per_partition
:
int
,
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
):
params_dtype
=
torch
.
int8
# WEIGHTS
w13_weight
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
num_experts
,
2
*
intermediate_size_per_partition
,
hidden_size
,
dtype
=
params_dtype
),
requires_grad
=
False
)
layer
.
register_parameter
(
"w13_weight"
,
w13_weight
)
set_weight_attrs
(
w13_weight
,
extra_weight_attrs
)
w2_weight
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
num_experts
,
hidden_size
,
intermediate_size_per_partition
,
dtype
=
params_dtype
),
requires_grad
=
False
)
layer
.
register_parameter
(
"w2_weight"
,
w2_weight
)
set_weight_attrs
(
w2_weight
,
extra_weight_attrs
)
# WEIGHT_SCALES
assert
self
.
weight_quant
.
strategy
==
QuantizationStrategy
.
CHANNEL
w13_weight_scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
num_experts
,
2
*
intermediate_size_per_partition
,
1
,
dtype
=
torch
.
float32
),
requires_grad
=
False
)
layer
.
register_parameter
(
"w13_weight_scale"
,
w13_weight_scale
)
w2_weight_scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
num_experts
,
hidden_size
,
1
,
dtype
=
torch
.
float32
),
requires_grad
=
False
)
layer
.
register_parameter
(
"w2_weight_scale"
,
w2_weight_scale
)
# Add PER-CHANNEL quantization for FusedMoE.weight_loader.
extra_weight_attrs
.
update
(
{
"quant_method"
:
FusedMoeWeightScaleSupported
.
CHANNEL
.
value
})
set_weight_attrs
(
w13_weight_scale
,
extra_weight_attrs
)
set_weight_attrs
(
w2_weight_scale
,
extra_weight_attrs
)
# INPUT_SCALES
assert
not
self
.
static_input_scales
layer
.
w13_input_scale
=
None
layer
.
w2_input_scale
=
None
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
pass
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
top_k
:
int
,
renormalize
:
bool
,
use_grouped_topk
:
bool
=
False
,
topk_group
:
Optional
[
int
]
=
None
,
num_expert_group
:
Optional
[
int
]
=
None
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
)
->
torch
.
Tensor
:
from
vllm.model_executor.layers.fused_moe
import
fused_experts
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
hidden_states
=
x
,
router_logits
=
router_logits
,
use_grouped_topk
=
use_grouped_topk
,
top_k
=
top_k
,
renormalize
=
renormalize
,
topk_group
=
topk_group
,
num_expert_group
=
num_expert_group
,
custom_routing_function
=
custom_routing_function
,
scoring_func
=
scoring_func
,
e_score_correction_bias
=
e_score_correction_bias
)
return
fused_experts
(
hidden_states
=
x
,
w1
=
layer
.
w13_weight
,
w2
=
layer
.
w2_weight
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
inplace
=
True
,
activation
=
activation
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
use_int8_w8a8
=
True
,
per_channel_quant
=
True
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
,
w1_scale
=
layer
.
w13_weight_scale
,
w2_scale
=
layer
.
w2_weight_scale
,
a1_scale
=
layer
.
w13_input_scale
,
a2_scale
=
layer
.
w2_input_scale
)
class
CompressedTensorsWNA16MarlinMoEMethod
(
CompressedTensorsMoEMethod
):
class
CompressedTensorsWNA16MarlinMoEMethod
(
CompressedTensorsMoEMethod
):
def
__init__
(
def
__init__
(
...
...
vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py
View file @
868c546d
...
@@ -111,7 +111,7 @@ class CutlassScaledMMLinearKernel(ScaledMMLinearKernel):
...
@@ -111,7 +111,7 @@ class CutlassScaledMMLinearKernel(ScaledMMLinearKernel):
# * dynamic, i_s is None and x_s computed from x.
# * dynamic, i_s is None and x_s computed from x.
# * static, i_s is scalar and x_s is i_s.
# * static, i_s is scalar and x_s is i_s.
symmetric
=
azp_adj
is
None
symmetric
=
azp_adj
is
None
x_q
,
x_s
,
x_zp
=
ops
.
scaled_int8_quant
(
x
,
x_q
,
x_s
,
x_zp
=
ops
.
scaled_int8_quant
(
x
.
contiguous
()
,
i_s
,
i_s
,
i_zp
,
i_zp
,
symmetric
=
symmetric
)
symmetric
=
symmetric
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment