Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f519ea41
Commit
f519ea41
authored
Apr 10, 2026
by
liuyunfei
Browse files
实现modelopt的w8a16量化算化
parent
49a30c70
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
62 additions
and
49 deletions
+62
-49
vllm/model_executor/layers/quantization/modelopt.py
vllm/model_executor/layers/quantization/modelopt.py
+61
-48
vllm/platforms/rocm.py
vllm/platforms/rocm.py
+1
-1
No files found.
vllm/model_executor/layers/quantization/modelopt.py
View file @
f519ea41
...
@@ -6,13 +6,14 @@ from typing import TYPE_CHECKING, Any, Callable, Optional, Union
...
@@ -6,13 +6,14 @@ from typing import TYPE_CHECKING, Any, Callable, Optional, Union
import
torch
import
torch
from
torch.nn
import
Module
from
torch.nn
import
Module
from
torch.nn.parameter
import
Parameter
from
torch.nn.parameter
import
Parameter
from
torch.nn
import
functional
as
F
,
init
import
vllm.envs
as
envs
import
vllm.envs
as
envs
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
from
vllm._custom_ops
import
cutlass_scaled_fp4_mm
,
scaled_fp4_quant
from
vllm._custom_ops
import
cutlass_scaled_fp4_mm
,
scaled_fp4_quant
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.fused_moe.config
import
(
from
vllm.model_executor.layers.fused_moe.config
import
(
FusedMoEConfig
,
FusedMoEQuantConfig
,
fp8_w8a8_moe_quant_config
,
FusedMoEConfig
,
FusedMoEQuantConfig
,
fp8_w8a8_moe_quant_config
,
int8_w8a16_moe_quant_config
,
nvfp4_moe_quant_config
)
nvfp4_moe_quant_config
)
from
vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe
import
(
from
vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe
import
(
is_valid_flashinfer_cutlass_fused_moe
)
is_valid_flashinfer_cutlass_fused_moe
)
...
@@ -40,7 +41,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
...
@@ -40,7 +41,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape
,
cutlass_fp4_supported
,
is_layer_skipped
,
swizzle_blockscale
)
GroupShape
,
cutlass_fp4_supported
,
is_layer_skipped
,
swizzle_blockscale
)
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
Fp8LinearOp
,
requantize_with_max_scale
)
Fp8LinearOp
,
requantize_with_max_scale
)
from
vllm.model_executor.parameter
import
(
ModelWeightParameter
,
from
vllm.model_executor.parameter
import
(
ModelWeightParameter
,
ChannelQuantScaleParameter
,
PerTensorScaleParameter
)
PerTensorScaleParameter
)
from
vllm.scalar_type
import
scalar_types
from
vllm.scalar_type
import
scalar_types
from
vllm.utils
import
next_power_of_2
from
vllm.utils
import
next_power_of_2
...
@@ -52,7 +53,7 @@ if TYPE_CHECKING:
...
@@ -52,7 +53,7 @@ if TYPE_CHECKING:
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
QUANT_ALGOS
=
[
"FP8"
,
"NVFP4"
]
QUANT_ALGOS
=
[
"FP8"
,
"NVFP4"
,
"W8A16"
]
KV_CACHE_QUANT_ALGOS
=
[
"FP8"
]
KV_CACHE_QUANT_ALGOS
=
[
"FP8"
]
...
@@ -145,6 +146,8 @@ class ModelOptFp8Config(QuantizationConfig):
...
@@ -145,6 +146,8 @@ class ModelOptFp8Config(QuantizationConfig):
quant_method
=
config
.
get
(
"quant_algo"
,
""
)
quant_method
=
config
.
get
(
"quant_algo"
,
""
)
kv_cache_quant_method
=
config
.
get
(
"kv_cache_quant_algo"
)
kv_cache_quant_method
=
config
.
get
(
"kv_cache_quant_algo"
)
exclude_modules
=
config
.
get
(
"exclude_modules"
)
exclude_modules
=
config
.
get
(
"exclude_modules"
)
if
not
exclude_modules
:
exclude_modules
=
config
.
get
(
"ignore"
)
if
quant_method
not
in
QUANT_ALGOS
:
if
quant_method
not
in
QUANT_ALGOS
:
raise
ValueError
(
raise
ValueError
(
...
@@ -152,7 +155,7 @@ class ModelOptFp8Config(QuantizationConfig):
...
@@ -152,7 +155,7 @@ class ModelOptFp8Config(QuantizationConfig):
"quantizations in vLLM. Please check the "
"quantizations in vLLM. Please check the "
"`hf_quant_config.json` file for your model's "
"`hf_quant_config.json` file for your model's "
"quant configuration."
)
"quant configuration."
)
is_checkpoint_fp8_serialized
=
(
"
FP8
"
in
quant_method
)
is_checkpoint_fp8_serialized
=
(
"
W8A16
"
in
quant_method
)
return
cls
(
is_checkpoint_fp8_serialized
,
kv_cache_quant_method
,
return
cls
(
is_checkpoint_fp8_serialized
,
kv_cache_quant_method
,
exclude_modules
)
exclude_modules
)
...
@@ -234,7 +237,7 @@ class ModelOptFp8LinearMethod(LinearMethodBase):
...
@@ -234,7 +237,7 @@ class ModelOptFp8LinearMethod(LinearMethodBase):
layer
.
logical_widths
=
output_partition_sizes
layer
.
logical_widths
=
output_partition_sizes
layer
.
input_size_per_partition
=
input_size_per_partition
layer
.
input_size_per_partition
=
input_size_per_partition
layer
.
output_size_per_partition
=
output_size_per_partition
layer
.
output_size_per_partition
=
output_size_per_partition
weight_dtype
=
(
torch
.
float8_e4m3fn
weight_dtype
=
(
torch
.
int8
if
self
.
quant_config
.
is_checkpoint_fp8_serialized
else
if
self
.
quant_config
.
is_checkpoint_fp8_serialized
else
params_dtype
)
params_dtype
)
weight
=
ModelWeightParameter
(
data
=
torch
.
empty
(
weight
=
ModelWeightParameter
(
data
=
torch
.
empty
(
...
@@ -248,29 +251,29 @@ class ModelOptFp8LinearMethod(LinearMethodBase):
...
@@ -248,29 +251,29 @@ class ModelOptFp8LinearMethod(LinearMethodBase):
if
self
.
quant_config
.
is_checkpoint_fp8_serialized
:
if
self
.
quant_config
.
is_checkpoint_fp8_serialized
:
# WEIGHT SCALE
# WEIGHT SCALE
weight_scale
=
PerTensor
ScaleParameter
(
data
=
torch
.
empty
(
weight_scale
=
ChannelQuant
ScaleParameter
(
output_dim
=
0
,
data
=
torch
.
empty
(
len
(
output_partition
_sizes
)
,
dtype
=
torch
.
float
32
),
output_
size_per_
partition
,
dtype
=
torch
.
float
16
),
weight_loader
=
weight_loader
)
weight_loader
=
weight_loader
)
weight_scale
[:]
=
torch
.
finfo
(
torch
.
float
32
).
min
weight_scale
[:]
=
torch
.
finfo
(
torch
.
float
16
).
min
layer
.
register_parameter
(
"weight_scale"
,
weight_scale
)
layer
.
register_parameter
(
"weight_scale"
,
weight_scale
)
# INPUT SCALE
#
#
INPUT SCALE
scale
=
PerTensorScaleParameter
(
data
=
torch
.
empty
(
#
scale = PerTensorScaleParameter(data=torch.empty(
len
(
output_partition_sizes
),
dtype
=
torch
.
float32
),
#
len(output_partition_sizes), dtype=torch.float32),
weight_loader
=
weight_loader
)
#
weight_loader=weight_loader)
scale
[:]
=
torch
.
finfo
(
torch
.
float32
).
min
#
scale[:] = torch.finfo(torch.float32).min
layer
.
register_parameter
(
"input_scale"
,
scale
)
#
layer.register_parameter("input_scale", scale)
def
process_weights_after_loading
(
self
,
layer
:
Module
)
->
None
:
def
process_weights_after_loading
(
self
,
layer
:
Module
)
->
None
:
weight
=
layer
.
weight
#
weight = layer.weight
max_w_scale
=
layer
.
weight_scale
.
max
()
#
max_w_scale = layer.weight_scale.max()
if
not
(
layer
.
weight_scale
==
layer
.
weight_scale
[
0
]).
all
():
#
if not (layer.weight_scale == layer.weight_scale[0]).all():
max_w_scale
,
weight
=
requantize_with_max_scale
(
#
max_w_scale, weight = requantize_with_max_scale(
layer
.
weight
,
layer
.
weight_scale
,
layer
.
logical_widths
)
#
layer.weight, layer.weight_scale, layer.logical_widths)
layer
.
weight
=
Parameter
(
weight
.
t
(),
requires_grad
=
False
)
layer
.
weight
=
Parameter
(
layer
.
weight
.
detach
().
clone
(),
requires_grad
=
False
)
layer
.
weight_scale
=
Parameter
(
max_w_scale
,
requires_grad
=
False
)
layer
.
weight_scale
=
Parameter
(
layer
.
weight_scale
.
detach
().
clone
()
,
requires_grad
=
False
)
layer
.
input_scale
=
Parameter
(
layer
.
input_scale
.
max
(),
#
layer.input_scale = Parameter(layer.input_scale.max(),
requires_grad
=
False
)
#
requires_grad=False)
def
apply
(
def
apply
(
self
,
self
,
...
@@ -278,11 +281,14 @@ class ModelOptFp8LinearMethod(LinearMethodBase):
...
@@ -278,11 +281,14 @@ class ModelOptFp8LinearMethod(LinearMethodBase):
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
return
self
.
fp8_linear
.
apply
(
input
=
x
,
# return self.fp8_linear.apply(input=x,
weight
=
layer
.
weight
,
# weight=layer.weight,
weight_scale
=
layer
.
weight_scale
,
# weight_scale=layer.weight_scale,
input_scale
=
layer
.
input_scale
,
# input_scale=layer.input_scale,
bias
=
bias
)
# bias=bias)
weight_scale
=
layer
.
weight_scale
.
unsqueeze
(
1
)
weights
=
layer
.
weight
.
view
(
torch
.
int8
).
to
(
x
.
dtype
)
*
weight_scale
.
to
(
x
.
dtype
)
return
F
.
linear
(
x
,
weights
,
bias
)
class
ModelOptFp8MoEMethod
(
FusedMoEMethodBase
):
class
ModelOptFp8MoEMethod
(
FusedMoEMethodBase
):
...
@@ -348,7 +354,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
...
@@ -348,7 +354,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
):
):
# Use FP8 dtype if checkpoint is serialized
# Use FP8 dtype if checkpoint is serialized
weight_dtype
=
(
torch
.
float8_e4m3fn
weight_dtype
=
(
torch
.
int8
if
self
.
quant_config
.
is_checkpoint_fp8_serialized
else
if
self
.
quant_config
.
is_checkpoint_fp8_serialized
else
params_dtype
)
params_dtype
)
weight_loader
=
extra_weight_attrs
.
get
(
"weight_loader"
)
weight_loader
=
extra_weight_attrs
.
get
(
"weight_loader"
)
...
@@ -381,14 +387,14 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
...
@@ -381,14 +387,14 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
# They will be combined to a single scale after weight loading.
# They will be combined to a single scale after weight loading.
w13_weight_scale
=
PerTensorScaleParameter
(
w13_weight_scale
=
PerTensorScaleParameter
(
data
=
torch
.
full
(
data
=
torch
.
full
(
(
num_experts
,
2
),
(
num_experts
,
2
,
intermediate_size_per_partition
),
1.0
,
1.0
,
dtype
=
torch
.
float
32
,
dtype
=
torch
.
float
16
,
),
),
weight_loader
=
weight_loader
,
weight_loader
=
weight_loader
,
)
)
w2_weight_scale
=
PerTensorScaleParameter
(
w2_weight_scale
=
PerTensorScaleParameter
(
data
=
torch
.
full
((
num_experts
,
),
1.0
,
dtype
=
torch
.
float
32
),
data
=
torch
.
full
((
num_experts
,
hidden_size
),
1.0
,
dtype
=
torch
.
float
16
),
weight_loader
=
weight_loader
,
weight_loader
=
weight_loader
,
)
)
layer
.
register_parameter
(
"w13_weight_scale"
,
w13_weight_scale
)
layer
.
register_parameter
(
"w13_weight_scale"
,
w13_weight_scale
)
...
@@ -399,16 +405,16 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
...
@@ -399,16 +405,16 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
{
"quant_method"
:
FusedMoeWeightScaleSupported
.
TENSOR
.
value
})
{
"quant_method"
:
FusedMoeWeightScaleSupported
.
TENSOR
.
value
})
# INPUT SCALES - Per-tensor scaling for ModelOpt
# INPUT SCALES - Per-tensor scaling for ModelOpt
w13_input_scale
=
PerTensorScaleParameter
(
#
w13_input_scale = PerTensorScaleParameter(
data
=
torch
.
full
((
num_experts
,
),
1.0
,
dtype
=
torch
.
float32
),
#
data=torch.full((num_experts, ), 1.0, dtype=torch.float32),
weight_loader
=
weight_loader
,
#
weight_loader=weight_loader,
)
#
)
w2_input_scale
=
PerTensorScaleParameter
(
#
w2_input_scale = PerTensorScaleParameter(
data
=
torch
.
full
((
num_experts
,
),
1.0
,
dtype
=
torch
.
float32
),
#
data=torch.full((num_experts, ), 1.0, dtype=torch.float32),
weight_loader
=
weight_loader
,
#
weight_loader=weight_loader,
)
#
)
layer
.
register_parameter
(
"w13_input_scale"
,
w13_input_scale
)
#
layer.register_parameter("w13_input_scale", w13_input_scale)
layer
.
register_parameter
(
"w2_input_scale"
,
w2_input_scale
)
#
layer.register_parameter("w2_input_scale", w2_input_scale)
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
"""Process FP8 MoE weights after loading from serialized checkpoint.
"""Process FP8 MoE weights after loading from serialized checkpoint.
...
@@ -462,7 +468,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
...
@@ -462,7 +468,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
layer
.
w13_weight_scale
=
Parameter
(
max_w13_scales
,
layer
.
w13_weight_scale
=
Parameter
(
max_w13_scales
,
requires_grad
=
False
)
requires_grad
=
False
)
else
:
else
:
layer
.
w13_weight_scale
=
Parameter
(
layer
.
w13_weight_scale
.
data
,
layer
.
w13_weight_scale
=
Parameter
(
layer
.
w13_weight_scale
.
data
.
flatten
(
start_dim
=
1
)
,
requires_grad
=
False
)
requires_grad
=
False
)
if
hasattr
(
layer
,
if
hasattr
(
layer
,
...
@@ -491,13 +497,19 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
...
@@ -491,13 +497,19 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
if
self
.
flashinfer_moe_backend
==
FlashinferMoeBackend
.
TENSORRT_LLM
:
if
self
.
flashinfer_moe_backend
==
FlashinferMoeBackend
.
TENSORRT_LLM
:
return
None
return
None
return
fp
8_w8a
8
_moe_quant_config
(
return
int
8_w8a
16
_moe_quant_config
(
w1_scale
=
layer
.
w13_weight_scale
,
w1_scale
=
layer
.
w13_weight_scale
,
w2_scale
=
layer
.
w2_weight_scale
,
w2_scale
=
layer
.
w2_weight_scale
,
a1_scale
=
layer
.
w13_input_scale
,
w1_zp
=
None
,
a2_scale
=
layer
.
w2_input_scale
,
w2_zp
=
None
per_act_token_quant
=
False
,
)
)
# return fp8_w8a8_moe_quant_config(
# w1_scale=layer.w13_weight_scale,
# w2_scale=layer.w2_weight_scale,
# a1_scale=layer.w13_input_scale,
# a2_scale=layer.w2_input_scale,
# per_act_token_quant=False,
# )
def
apply
(
def
apply
(
self
,
self
,
...
@@ -521,6 +533,8 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
...
@@ -521,6 +533,8 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
expert_load_view
:
Optional
[
torch
.
Tensor
]
=
None
,
expert_load_view
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_to_physical_map
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_to_physical_map
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_replica_count
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_replica_count
:
Optional
[
torch
.
Tensor
]
=
None
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
use_fused_gate
:
Optional
[
bool
]
=
False
,
)
->
Union
[
torch
.
Tensor
,
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]]:
)
->
Union
[
torch
.
Tensor
,
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]]:
if
enable_eplb
:
if
enable_eplb
:
raise
NotImplementedError
(
raise
NotImplementedError
(
...
@@ -594,7 +608,6 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
...
@@ -594,7 +608,6 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
from
vllm.model_executor.layers.fused_moe.fused_moe
import
(
from
vllm.model_executor.layers.fused_moe.fused_moe
import
(
fused_experts
)
fused_experts
)
assert
self
.
moe_quant_config
is
not
None
assert
self
.
moe_quant_config
is
not
None
return
fused_experts
(
return
fused_experts
(
x
,
x
,
layer
.
w13_weight
,
layer
.
w13_weight
,
...
...
vllm/platforms/rocm.py
View file @
f519ea41
...
@@ -190,7 +190,7 @@ class RocmPlatform(Platform):
...
@@ -190,7 +190,7 @@ class RocmPlatform(Platform):
supported_quantization
:
list
[
str
]
=
[
supported_quantization
:
list
[
str
]
=
[
"awq"
,
"gptq"
,
"fp8"
,
"compressed-tensors"
,
"fbgemm_fp8"
,
"gguf"
,
"awq"
,
"gptq"
,
"fp8"
,
"compressed-tensors"
,
"fbgemm_fp8"
,
"gguf"
,
"quark"
,
"ptpc_fp8"
,
"mxfp4"
,
"petit_nvfp4"
,
"torchao"
,
"quark"
,
"ptpc_fp8"
,
"mxfp4"
,
"petit_nvfp4"
,
"torchao"
,
"moe_wna16"
,
"slimquant_w4a8"
,
"w8a8_int8"
,
"awq_marlin"
,
"slimquant_w4a8_marlin"
,
"slimquant_compressed_tensors_marlin"
"moe_wna16"
,
"slimquant_w4a8"
,
"w8a8_int8"
,
"awq_marlin"
,
"slimquant_w4a8_marlin"
,
"slimquant_compressed_tensors_marlin"
,
"modelopt"
]
]
@
classmethod
@
classmethod
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment