Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
db593aa6
Unverified
Commit
db593aa6
authored
May 07, 2025
by
Bowen Bao
Committed by
GitHub
May 07, 2025
Browse files
[Quantization] Quark MXFP4 format loading (#16943)
parent
f98e3075
Changes
9
Show whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
289 additions
and
3 deletions
+289
-3
tests/models/quantization/test_mxfp4.py
tests/models/quantization/test_mxfp4.py
+40
-0
vllm/envs.py
vllm/envs.py
+9
-0
vllm/model_executor/layers/quantization/quark/quark.py
vllm/model_executor/layers/quantization/quark/quark.py
+55
-1
vllm/model_executor/layers/quantization/quark/schemes/__init__.py
...el_executor/layers/quantization/quark/schemes/__init__.py
+2
-1
vllm/model_executor/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py
...tor/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py
+125
-0
vllm/model_executor/layers/quantization/utils/mxfp4_utils.py
vllm/model_executor/layers/quantization/utils/mxfp4_utils.py
+45
-0
vllm/model_executor/model_loader/utils.py
vllm/model_executor/model_loader/utils.py
+1
-1
vllm/platforms/interface.py
vllm/platforms/interface.py
+7
-0
vllm/platforms/rocm.py
vllm/platforms/rocm.py
+5
-0
No files found.
tests/models/quantization/test_mxfp4.py
0 → 100644
View file @
db593aa6
# SPDX-License-Identifier: Apache-2.0
# flake8: noqa
"""Tests Quark mxfp4 models against ground truth generation
"""
import
pytest
from
vllm
import
LLM
,
SamplingParams
# Checkpoints exercised by this test module.
MODELS = ["amd/Llama-2-7b-chat-hf-wmxfp4-amxfp4-kvfp8-scale-uint8"]

# Ground-truth completions, keyed by model name. Entry ``i`` of each list is
# compared verbatim against the text generated for prompt ``i``.
EXPECTED_STRS_MAP = {
    "amd/Llama-2-7b-chat-hf-wmxfp4-amxfp4-kvfp8-scale-uint8": [
        '\n### Key Features\n\n* **High-throughput Inference**: vLL',
        '\nArtificial intelligence (AI) has evolved significantly since its inception in the 1',
        'Artificial intelligence (AI) and human intelligence (HI) are two distinct concepts that have been',
        'A neural network is a machine learning model inspired by the structure of the human brain. It consists of',
        '\nTitle: The Dreaming Robot\n\nAs the sun set on the bustling metropol',
        '\nThe COVID-19 pandemic has had a profound impact on global economic structures and business',
        'The Mona Lisa painting, created by Leonardo da Vinci in the early 16th',
        " everybody knows this proverbial saying, but did you know that it's not entirely accurate?",
    ]
}
@pytest.mark.skip(reason="Model to be released in the future")
@pytest.mark.quant_model
@pytest.mark.parametrize("model_name", MODELS)
def test_models(example_prompts, model_name) -> None:
    """Compare generated text with the stored ground truth for a model.

    The model is loaded with ``quantization="quark"`` and an ``fp8`` KV
    cache, then asked for at most 20 tokens per prompt at temperature 0.
    Each completion must equal the entry at the same index in
    ``EXPECTED_STRS_MAP[model_name]``.
    """
    params = SamplingParams(max_tokens=20, temperature=0)
    engine = LLM(
        model=model_name,
        kv_cache_dtype="fp8",
        quantization="quark",
    )

    generations = engine.generate(example_prompts, params)
    for idx, generation in enumerate(generations):
        # Only the first candidate of each request is checked.
        actual = generation.outputs[0].text
        wanted = EXPECTED_STRS_MAP[model_name][idx]
        assert wanted == actual, (
            f"Expected: {wanted!r}\nvLLM: {actual!r}")
vllm/envs.py
View file @
db593aa6
...
...
@@ -84,6 +84,7 @@ if TYPE_CHECKING:
VLLM_ROCM_FP8_PADDING
:
bool
=
True
VLLM_ROCM_MOE_PADDING
:
bool
=
True
VLLM_ROCM_CUSTOM_PAGED_ATTN
:
bool
=
True
VLLM_QUARK_EMU_MEM_OPT
:
bool
=
False
VLLM_ENABLE_V1_MULTIPROCESSING
:
bool
=
True
VLLM_LOG_BATCHSIZE_INTERVAL
:
float
=
-
1
VLLM_DISABLE_COMPILE_CACHE
:
bool
=
False
...
...
@@ -583,6 +584,14 @@ environment_variables: dict[str, Callable[[], Any]] = {
lambda
:
(
os
.
getenv
(
"VLLM_ROCM_CUSTOM_PAGED_ATTN"
,
"True"
).
lower
()
in
(
"true"
,
"1"
)),
# If set, when running in Quark emulation mode, do not dequantize the
# weights at load time. Instead, dequantize weights on-the-fly during
# kernel execution.
# This allows running larger models at the cost of slower inference.
# This flag has no effect when not running in Quark emulation mode.
"VLLM_QUARK_EMU_MEM_OPT"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_QUARK_EMU_MEM_OPT"
,
"0"
))),
# Divisor for dynamic query scale factor calculation for FP8 KV Cache
"Q_SCALE_CONSTANT"
:
lambda
:
int
(
os
.
getenv
(
"Q_SCALE_CONSTANT"
,
"200"
)),
...
...
vllm/model_executor/layers/quantization/quark/quark.py
View file @
db593aa6
...
...
@@ -5,6 +5,7 @@ from typing import Any, Dict, List, Optional, cast
import
torch
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
from
vllm.model_executor.layers.linear
import
(
LinearBase
,
LinearMethodBase
,
UnquantizedLinearMethod
)
...
...
@@ -15,13 +16,15 @@ from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from
vllm.model_executor.layers.quantization.quark.quark_moe
import
(
# noqa: E501
QuarkMoEMethod
)
from
vllm.model_executor.layers.quantization.quark.schemes
import
(
QuarkScheme
,
QuarkW8A8Fp8
,
QuarkW8A8Int8
)
QuarkScheme
,
QuarkW4A4MXFP4
,
QuarkW8A8Fp8
,
QuarkW8A8Int8
)
from
vllm.model_executor.layers.quantization.quark.utils
import
(
deep_compare
,
should_ignore_layer
)
from
vllm.platforms
import
current_platform
__all__
=
[
"QuarkLinearMethod"
]
logger
=
init_logger
(
__name__
)
class
QuarkConfig
(
QuantizationConfig
):
...
...
@@ -67,6 +70,7 @@ class QuarkConfig(QuantizationConfig):
return
QuarkLinearMethod
(
self
)
if
isinstance
(
layer
,
Attention
):
return
QuarkKVCacheMethod
(
self
)
if
isinstance
(
layer
,
FusedMoE
):
return
QuarkMoEMethod
.
get_moe_method
(
self
,
module
=
layer
,
...
...
@@ -205,6 +209,54 @@ class QuarkConfig(QuantizationConfig):
# Only symmetric weight quantization supported.
return
is_int8_dtype
and
is_tensor
and
is_weight_symmetric
and
is_static
def _is_mx_fp4(self, weight_quant: Optional[Dict[str, Any]],
               input_quant: Optional[Dict[str, Any]]) -> bool:
    """Tell whether a weight/input quantization spec pair is MX-FP4.

    Both specs must be present and must agree on ``dtype == "fp4"``,
    ``qscheme == "per_group"``, ``group_size == 32`` and
    ``scale_format == "e8m0"``. In addition the weight spec must not be
    explicitly dynamic and the input spec must not be explicitly static
    (a missing ``is_dynamic`` key is accepted on either side).

    The first requirement that fails is reported through ``logger.debug``
    and the method returns ``False``; otherwise it returns ``True``.
    """
    # Confirm weights and input quantized.
    if weight_quant is None or input_quant is None:
        logger.debug("Quark model is not in MX-FP4 format: "
                     "weight_quant or input_quant not set")
        return False

    specs = (weight_quant, input_quant)

    # (requirement satisfied?, message logged when it is not), listed in
    # the order in which a failure should be reported.
    requirements = (
        # Input and weight dtype needs to be fp4.
        (all(spec.get("dtype") == "fp4" for spec in specs),
         "Quark model is not in MX-FP4 format: dtype not fp4"),
        # Input and weight qscheme needs to be per group.
        (all(spec.get("qscheme") == "per_group" for spec in specs),
         "Quark model is not in MX-FP4 format: not per_group"),
        # Input and weight group size needs to be 32.
        (all(spec.get("group_size") == 32 for spec in specs),
         "Quark model is not in MX-FP4 format: not group_size=32"),
        # Weights need to use static quantization.
        (weight_quant.get("is_dynamic") is not True,
         "Quark model is not in MX-FP4 format: not weight static"),
        # Activations need to use dynamic quantization.
        (input_quant.get("is_dynamic") is not False,
         "Quark model is not in MX-FP4 format: not activation dynamic"),
        # Activations and weight scales need to be in e8m0 format.
        (all(spec.get("scale_format") == "e8m0" for spec in specs),
         "Quark model is not in MX-FP4 format: not scale_format e8m0"),
    )

    for satisfied, reason in requirements:
        if not satisfied:
            logger.debug(reason)
            return False

    return True
def
_find_matched_config
(
self
,
layer_name
:
str
,
module
:
torch
.
nn
.
Module
)
->
Dict
[
str
,
Any
]:
...
...
@@ -269,6 +321,8 @@ class QuarkConfig(QuantizationConfig):
return
QuarkW8A8Int8
(
qscheme
=
weight_qscheme
,
is_static_input_scheme
=
True
,
input_symmetric
=
input_config
.
get
(
"symmetric"
))
elif
self
.
_is_mx_fp4
(
weight_config
,
input_config
):
return
QuarkW4A4MXFP4
(
weight_config
,
input_config
)
raise
NotImplementedError
(
"No quark compatible scheme was found. "
f
"Weight config:
{
weight_config
}
, "
...
...
vllm/model_executor/layers/quantization/quark/schemes/__init__.py
View file @
db593aa6
# SPDX-License-Identifier: Apache-2.0
from
.quark_scheme
import
QuarkScheme
from
.quark_w4a4_mxfp4
import
QuarkW4A4MXFP4
from
.quark_w8a8_fp8
import
QuarkW8A8Fp8
from
.quark_w8a8_int8
import
QuarkW8A8Int8
__all__
=
[
"QuarkScheme"
,
"QuarkW8A8Fp8"
,
"QuarkW8A8Int8"
]
__all__
=
[
"QuarkScheme"
,
"QuarkW8A8Fp8"
,
"QuarkW8A8Int8"
,
"QuarkW4A4MXFP4"
]
vllm/model_executor/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py
0 → 100644
View file @
db593aa6
# SPDX-License-Identifier: Apache-2.0
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
import
torch
import
torch.nn.functional
as
F
import
vllm.envs
as
envs
from
vllm.model_executor.layers.quantization.quark.schemes
import
QuarkScheme
from
vllm.model_executor.layers.quantization.utils.mxfp4_utils
import
(
OCP_MX_BLOCK_SIZE
,
per_token_group_quant_mxfp4
)
from
vllm.model_executor.parameter
import
(
GroupQuantScaleParameter
,
PackedvLLMParameter
)
from
vllm.platforms
import
current_platform
__all__
=
[
"QuarkW4A4MXFP4"
]
class QuarkW4A4MXFP4(QuarkScheme):
    """Quark linear scheme with MX-FP4 weights and MX-FP4 activations.

    Weights are created as packed ``uint8`` tensors (two values per byte)
    with one ``uint8`` scale per ``OCP_MX_BLOCK_SIZE`` input elements.
    When the current platform does not report MX support the layer runs in
    emulation: weights are dequantized through amd-quark and activations go
    through ``per_token_group_quant_mxfp4`` before a regular ``F.linear``.
    The non-emulated path is not implemented.
    """

    def __init__(self, weight_quant_spec: Dict[str, Any],
                 input_quant_spec: Dict[str, Any]):
        """Store the quantization specs and decide whether to emulate."""
        self.weight_quant_spec = weight_quant_spec
        self.input_quant_spec = input_quant_spec
        self.qscheme = "per_group"
        # Dtype the dequantized weights are cast to.
        self.out_dtype = torch.get_default_dtype()
        # Emulate whenever the platform does not report MX support.
        self.emulate = not current_platform.supports_mx()

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum device capability for this scheme (70)."""
        return 70

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Finalize ``layer`` once its checkpoint tensors are loaded.

        ``weight`` and ``weight_scale`` are re-wrapped as frozen
        ``torch.nn.Parameter`` objects. In emulation mode an amd-quark
        quantizer is built from the weight spec and the loaded scales;
        the weight is then either dequantized right away or, when
        ``VLLM_QUARK_EMU_MEM_OPT`` is set, left packed and the quantizer is
        kept on ``self`` for use in ``apply_weights``.

        Raises:
            ImportError: if emulation is needed and amd-quark is missing.
        """
        for attr in ("weight", "weight_scale"):
            loaded = getattr(layer, attr)
            setattr(layer, attr,
                    torch.nn.Parameter(loaded.data, requires_grad=False))

        if not self.emulate:
            return

        try:
            from quark.torch.export.nn.modules import realquantizer
            from quark.torch.quantization.config.config import (
                QuantizationSpec)
        except ImportError as err:
            raise ImportError(
                "The package `amd-quark` is required to use AMD Quark "
                "MX-FP4 models. Please install it with `pip install "
                "amd-quark`.") from err

        spec = QuantizationSpec.from_dict(self.weight_quant_spec)
        dequantizer = realquantizer.get_real_quantizer(
            qspec=spec,
            quantizer=None,
            real_quantized=True,
            reorder=False,
            float_dtype=self.out_dtype,
            scale_shape=layer.weight_scale.shape,
            zero_point_shape=None,
        )
        # Hand the loaded scales to the quantizer.
        dequantizer.scale.data = layer.weight_scale.data

        if envs.VLLM_QUARK_EMU_MEM_OPT:
            # Keep the packed weight; it is dequantized per call instead.
            self.weight_quantizer = dequantizer
        else:
            dense = dequantizer(layer.weight.data).to(self.out_dtype)
            layer.weight = torch.nn.Parameter(dense, requires_grad=False)

        # NOTE(review): indentation was lost in the scraped source; this
        # reset is placed at the emulation level (after both branches) -
        # confirm against the upstream file.
        layer.weight_scale = None

        # This call is necessary to release the scales memory.
        torch.cuda.empty_cache()

    def create_weights(self, layer: torch.nn.Module,
                       output_partition_sizes: List[int],
                       input_size_per_partition: int,
                       params_dtype: torch.dtype, weight_loader: Callable,
                       **kwargs):
        """Register the packed ``weight`` and ``weight_scale`` parameters.

        Both parameters have ``sum(output_partition_sizes)`` rows and are
        ``uint8``. The weight has ``input_size_per_partition // 2``
        columns; the scale has
        ``input_size_per_partition // OCP_MX_BLOCK_SIZE`` columns.
        ``params_dtype`` and ``kwargs`` are accepted but unused here.
        """
        rows = sum(output_partition_sizes)
        layer.logical_widths = output_partition_sizes

        # WEIGHT
        packed_weight = torch.empty(
            rows,
            input_size_per_partition // 2,
            dtype=torch.uint8,
        )
        layer.register_parameter(
            "weight",
            PackedvLLMParameter(
                data=packed_weight,
                input_dim=1,
                output_dim=0,
                packed_dim=1,
                packed_factor=2,
                weight_loader=weight_loader,
            ))

        # WEIGHT SCALE
        block_scales = torch.empty(
            rows,
            input_size_per_partition // OCP_MX_BLOCK_SIZE,
            dtype=torch.uint8,
        )
        layer.register_parameter(
            "weight_scale",
            GroupQuantScaleParameter(
                data=block_scales,
                input_dim=1,
                output_dim=0,
                weight_loader=weight_loader,
            ))

    def apply_weights(self,
                      layer: torch.nn.Module,
                      x: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run the emulated MX-FP4 linear layer on ``x``.

        The weight is taken from ``layer.weight`` (dequantized on the fly
        when ``VLLM_QUARK_EMU_MEM_OPT`` is set), the activation is passed
        through ``per_token_group_quant_mxfp4`` and the two are combined
        with ``F.linear``.

        Raises:
            NotImplementedError: when emulation is not in use.
        """
        if not self.emulate:
            raise NotImplementedError()

        weight = (self.weight_quantizer(layer.weight).to(self.out_dtype)
                  if envs.VLLM_QUARK_EMU_MEM_OPT else layer.weight)
        activation, _ = per_token_group_quant_mxfp4(x, OCP_MX_BLOCK_SIZE)
        return F.linear(activation, weight, bias)
vllm/model_executor/layers/quantization/utils/mxfp4_utils.py
0 → 100644
View file @
db593aa6
# SPDX-License-Identifier: Apache-2.0
from
typing
import
Tuple
import
torch
OCP_MX_BLOCK_SIZE
=
32
def per_token_group_quant_mxfp4(x: torch.Tensor,
                                block_k: int,
                                scale_calculation_mode: str = "even"
                                ) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply MX-FP4 quantize-then-dequantize to ``x`` using amd-quark.

    ``x`` is grouped into blocks of ``block_k`` along its last axis via
    quark's ``reshape_to_blocks``, the absolute maximum of every block is
    taken, and a scale is derived from it with quark's
    ``even_round(amax, "fp4")``. The input is then passed through quark's
    ``fake_quantize_fp4_fp6_per_group_with_scale`` with those scales.

    Args:
        x: Tensor to fake-quantize.
        block_k: Group size along the last axis.
        scale_calculation_mode: Only ``"even"`` is supported.

    Returns:
        A ``(fake_quantized_x, scale)`` tuple.

    Raises:
        NotImplementedError: if ``scale_calculation_mode`` is not
            ``"even"``.
        ImportError: if the ``amd-quark`` package is not installed.
    """
    # Validate the arguments before anything else so an unsupported mode
    # is reported immediately, without importing quark or touching ``x``.
    # TODO: there are other rounding strategies supported in quark and in the
    # config.json that we do not check for here!
    if scale_calculation_mode != "even":
        raise NotImplementedError(
            f"Scale calculation mode {scale_calculation_mode} is not yet "
            "supported in MX-FP4 quantization")

    try:
        from quark.torch.kernel.hw_emulation.hw_emulation_interface import (
            fake_quantize_fp4_fp6_per_group_with_scale)
        from quark.torch.quantization.utils import (even_round,
                                                    reshape_to_blocks)
    except ImportError as err:
        raise ImportError("The package `amd-quark` is required to use "
                          "MX-FP4 models. Please install it with `pip install "
                          "amd-quark`.") from err

    axis = -1
    block_x = reshape_to_blocks(x, block_k, axis)
    # Per-block absolute maximum, with the reduced axis dropped again.
    amax, _ = torch.max(torch.abs(block_x), dim=-1, keepdim=True)
    amax = amax.squeeze(-1)

    scale = even_round(amax, "fp4")

    # Apply dequantize(quantize(x)).
    x = fake_quantize_fp4_fp6_per_group_with_scale(
        x,
        scale.to(x.device),
        axis=axis,
        group_size=block_k,
        quant_dtype="fp4",
    )

    return x, scale
vllm/model_executor/model_loader/utils.py
View file @
db593aa6
...
...
@@ -220,7 +220,7 @@ def get_model_architecture(
# Special handling for quantized Mixtral.
# FIXME(woosuk): This is a temporary hack.
mixtral_supported
=
[
"fp8"
,
"compressed-tensors"
,
"gptq_marlin"
,
"awq_marlin"
"fp8"
,
"compressed-tensors"
,
"gptq_marlin"
,
"awq_marlin"
,
"quark"
]
if
(
model_config
.
quantization
is
not
None
...
...
vllm/platforms/interface.py
View file @
db593aa6
...
...
@@ -339,6 +339,13 @@ class Platform:
"""
return
"vllm.distributed.device_communicators.base_device_communicator.DeviceCommunicatorBase"
# noqa
@classmethod
def supports_mx(cls) -> bool:
    """
    Report whether the current platform supports MX types.

    This implementation always returns ``False``.
    """
    return False
@
classmethod
def
supports_fp8
(
cls
)
->
bool
:
"""
...
...
vllm/platforms/rocm.py
View file @
db593aa6
...
...
@@ -327,6 +327,11 @@ class RocmPlatform(Platform):
def
get_device_communicator_cls
(
cls
)
->
str
:
return
"vllm.distributed.device_communicators.cuda_communicator.CudaCommunicator"
# noqa
@classmethod
def supports_mx(cls) -> bool:
    """Return True if device 0's GCN arch name contains ``"gfx95"``."""
    arch_name = torch.cuda.get_device_properties(0).gcnArchName
    for marker in ["gfx95"]:
        if marker in arch_name:
            return True
    return False
@
classmethod
def
supports_fp8
(
cls
)
->
bool
:
gcn_arch
=
torch
.
cuda
.
get_device_properties
(
0
).
gcnArchName
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment