Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
d085a440
Unverified
Commit
d085a440
authored
Apr 13, 2025
by
Michael Goin
Committed by
GitHub
Apr 13, 2025
Browse files
Enable PTPC FP8 for CompressedTensorsW8A8Fp8MoEMethod (triton fused_moe) (#16537)
Signed-off-by:
mgoin
<
mgoin64@gmail.com
>
parent
f49e5aff
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
68 additions
and
39 deletions
+68
-39
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
...quantization/compressed_tensors/compressed_tensors_moe.py
+68
-39
No files found.
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
View file @
d085a440
...
...
@@ -88,14 +88,23 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
self
.
input_quant
=
self
.
quant_config
.
target_scheme_map
[
"Linear"
].
get
(
"input_activations"
)
if
not
(
self
.
weight_quant
.
strategy
==
QuantizationStrategy
.
TENSOR
and
self
.
input_quant
.
strategy
==
QuantizationStrategy
.
TENSOR
):
per_tensor
=
(
self
.
weight_quant
.
strategy
==
QuantizationStrategy
.
TENSOR
and
self
.
input_quant
.
strategy
==
QuantizationStrategy
.
TENSOR
)
per_channel
=
(
self
.
weight_quant
.
strategy
==
QuantizationStrategy
.
CHANNEL
and
self
.
input_quant
.
strategy
==
QuantizationStrategy
.
TOKEN
)
if
not
(
per_tensor
or
per_channel
):
raise
ValueError
(
"For FP8 Fused MoE layers,
only
per
-
tensor
scales
"
"
f
or
weights and ac
ti
v
ation
s are supported
. Found "
"For FP8 Fused MoE layers,
we require
per
tensor "
"or
channelwise, dynamic per token quan
ti
z
ation. Found "
f
"
{
self
.
weight_quant
}
,
{
self
.
input_quant
}
"
)
self
.
static_input_scales
=
not
self
.
input_quant
.
dynamic
if
self
.
static_input_scales
and
per_channel
:
raise
ValueError
(
"For FP8 Fused MoE layer, we require either per tensor or "
"channelwise, dynamic per token quantization."
)
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
num_experts
:
int
,
hidden_size
:
int
,
intermediate_size_per_partition
:
int
,
...
...
@@ -123,22 +132,38 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
set_weight_attrs
(
w2_weight
,
extra_weight_attrs
)
# WEIGHT_SCALES
if
self
.
weight_quant
.
strategy
==
QuantizationStrategy
.
TENSOR
:
# Allocate 2 scales for w1 and w3 respectively.
# They will be combined to a single scale after weight loading.
w13_weight_scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
num_experts
,
2
,
dtype
=
torch
.
float32
),
# They are combined to a single scale after weight loading.
w13_weight_scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
num_experts
,
2
,
dtype
=
torch
.
float32
),
requires_grad
=
False
)
layer
.
register_parameter
(
"w13_weight_scale"
,
w13_weight_scale
)
w2_weight_scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
num_experts
,
dtype
=
torch
.
float32
),
requires_grad
=
False
)
layer
.
register_parameter
(
"w2_weight_scale"
,
w2_weight_scale
)
# Add PER-TENSOR quantization for FusedMoE.weight_loader.
extra_weight_attrs
.
update
(
{
"quant_method"
:
FusedMoeWeightScaleSupported
.
TENSOR
.
value
})
set_weight_attrs
(
w13_weight_scale
,
extra_weight_attrs
)
set_weight_attrs
(
w2_weight_scale
,
extra_weight_attrs
)
w2_weight_scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
num_experts
,
elif
self
.
weight_quant
.
strategy
==
QuantizationStrategy
.
CHANNEL
:
w13_weight_scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
num_experts
,
2
*
intermediate_size_per_partition
,
1
,
dtype
=
torch
.
float32
),
requires_grad
=
False
)
layer
.
register_parameter
(
"w13_weight_scale"
,
w13_weight_scale
)
w2_weight_scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
num_experts
,
hidden_size
,
1
,
dtype
=
torch
.
float32
),
requires_grad
=
False
)
layer
.
register_parameter
(
"w2_weight_scale"
,
w2_weight_scale
)
# Add the quantization method used (per tensor/grouped/channel)
# to ensure the weight scales are loaded in properly
# Add PER-CHANNEL quantization for FusedMoE.weight_loader.
extra_weight_attrs
.
update
(
{
"quant_method"
:
FusedMoeWeightScaleSupported
.
TENSOR
.
value
})
{
"quant_method"
:
FusedMoeWeightScaleSupported
.
CHANNEL
.
value
})
set_weight_attrs
(
w13_weight_scale
,
extra_weight_attrs
)
set_weight_attrs
(
w2_weight_scale
,
extra_weight_attrs
)
...
...
@@ -163,6 +188,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
# Fp8 moe kernels require a single activation scale.
# We take the max of all the scales in case they differ.
if
self
.
static_input_scales
:
assert
self
.
input_quant
.
strategy
==
QuantizationStrategy
.
TENSOR
if
(
layer
.
w13_input_scale
is
None
or
layer
.
w2_input_scale
is
None
):
raise
ValueError
(
"QuantConfig has static quantization, but found "
...
...
@@ -204,8 +230,9 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
layer
.
w2_input_scale
=
torch
.
nn
.
Parameter
(
w2_input_scale
,
requires_grad
=
False
)
# Fp8 moe kernel needs single weight scale for w13 per expert.
# We take the max then dequant and requant each expert.
# For Per-TENSOR case, Fp8 moe kernel needs single weight scale
# for w13 per expert. Use max then dequant and requant each expert.
if
self
.
weight_quant
.
strategy
==
QuantizationStrategy
.
TENSOR
:
assert
layer
.
w13_weight_scale
is
not
None
shard_size
=
layer
.
intermediate_size_per_partition
max_w13_scales
=
layer
.
w13_weight_scale
.
max
(
dim
=
1
).
values
...
...
@@ -213,13 +240,13 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
start
=
0
for
shard_id
in
range
(
2
):
dq_weight
=
per_tensor_dequantize
(
layer
.
w13_weight
[
expert_id
][
start
:
start
+
shard_size
,
:],
layer
.
w13_weight
[
expert_id
][
start
:
start
+
shard_size
,
:],
layer
.
w13_weight_scale
[
expert_id
][
shard_id
])
layer
.
w13_weight
[
expert_id
][
start
:
start
+
shard_size
,
:],
_
=
ops
.
scaled_fp8_quant
(
dq_weight
,
max_w13_scales
[
expert_id
])
start
+=
shard_size
layer
.
w13_weight_scale
=
torch
.
nn
.
Parameter
(
max_w13_scales
,
requires_grad
=
False
)
...
...
@@ -265,6 +292,8 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
activation
=
activation
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
use_fp8_w8a8
=
True
,
per_channel_quant
=
self
.
weight_quant
.
strategy
==
QuantizationStrategy
.
CHANNEL
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
,
w1_scale
=
layer
.
w13_weight_scale
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment