Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
6b46c4b6
Unverified
Commit
6b46c4b6
authored
Jul 21, 2025
by
Zhiyu
Committed by
GitHub
Jul 21, 2025
Browse files
Add Nvidia ModelOpt config adaptation (#19815)
Signed-off-by:
Zhiyu Cheng
<
zhiyuc@nvidia.com
>
parent
d9784107
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
287 additions
and
32 deletions
+287
-32
tests/quantization/test_modelopt.py
tests/quantization/test_modelopt.py
+91
-0
vllm/config.py
vllm/config.py
+13
-7
vllm/model_executor/layers/quantization/modelopt.py
vllm/model_executor/layers/quantization/modelopt.py
+183
-25
No files found.
tests/quantization/test_modelopt.py
0 → 100644
View file @
6b46c4b6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Test ModelOpt quantization method setup and weight loading.
Run `pytest tests/quantization/test_modelopt.py`.
"""
import
os
import
pytest
import
torch
from
tests.quantization.utils
import
is_quant_method_supported
from
vllm.platforms
import
current_platform
@pytest.fixture(scope="function", autouse=True)
def use_v0_only(monkeypatch):
    """Force the V0 engine for every test in this module.

    The tests below inspect V0 internals, so ``VLLM_USE_V1`` is set to
    ``"0"`` for the duration of each test. On CPU platforms the
    environment is left untouched.
    """
    if current_platform.is_cpu():
        return
    monkeypatch.setenv("VLLM_USE_V1", "0")
@pytest.mark.skipif(not is_quant_method_supported("modelopt"),
                    reason="ModelOpt FP8 is not supported on this GPU type.")
def test_modelopt_fp8_checkpoint_setup(vllm_runner):
    """Load a ModelOpt FP8 checkpoint and validate the first decoder layer.

    The test is skipped when the local checkpoint directory is absent.
    Otherwise it checks that the four linear projections of layer 0 use
    ``ModelOptFp8LinearMethod``, hold FP8 weights and expose float32
    weight/input scales, and finally runs a short greedy generation.
    """
    # TODO: provide a small publicly available test checkpoint
    model_path = ("/home/scratch.omniml_data_1/zhiyu/ckpts/test_ckpts/"
                  "TinyLlama-1.1B-Chat-v1.0-fp8-0710")

    # Nothing to test without the checkpoint on disk.
    if not os.path.exists(model_path):
        pytest.skip(f"Test checkpoint not found at {model_path}. "
                    "This test requires a local ModelOpt FP8 checkpoint.")

    with vllm_runner(model_path, quantization="modelopt",
                     enforce_eager=True) as llm:

        def check_model(model):
            first_layer = model.model.layers[0]
            # Order matters only for which assertion fires first:
            # qkv, o, gate_up, down.
            projections = (
                first_layer.self_attn.qkv_proj,
                first_layer.self_attn.o_proj,
                first_layer.mlp.gate_up_proj,
                first_layer.mlp.down_proj,
            )

            # Imported lazily, after the layer attributes were resolved.
            from vllm.model_executor.layers.quantization.modelopt import (
                ModelOptFp8LinearMethod)

            # The ModelOpt quantization method must be attached everywhere.
            for proj in projections:
                assert isinstance(proj.quant_method, ModelOptFp8LinearMethod)

            # Weights must be stored as FP8.
            for proj in projections:
                assert proj.weight.dtype == torch.float8_e4m3fn

            # Scales must exist and be float32.
            for proj in projections:
                assert hasattr(proj, 'weight_scale')
                assert hasattr(proj, 'input_scale')
                assert proj.weight_scale.dtype == torch.float32
                assert proj.input_scale.dtype == torch.float32

        llm.apply_model(check_model)

        # Smoke-test generation to make sure the model actually runs.
        output = llm.generate_greedy(["Hello my name is"], max_tokens=20)
        assert output
        print(f"ModelOpt FP8 output: {output}")
vllm/config.py
View file @
6b46c4b6
...
...
@@ -1000,9 +1000,13 @@ class ModelConfig:
quant_cfg
=
self
.
_parse_quant_hf_config
()
if
quant_cfg
is
not
None
:
# Use the community standard 'quant_method'
quant_method
=
quant_cfg
.
get
(
"quant_method"
,
""
).
lower
()
# Normalize library names
quant_method
=
quant_method
.
replace
(
"compressed_tensors"
,
"compressed-tensors"
)
quant_cfg
[
"quant_method"
]
=
quant_method
# Quantization methods which are overrides (i.e. they have a
...
...
@@ -1017,6 +1021,8 @@ class ModelConfig:
"awq_marlin"
,
"ipex"
,
"moe_wna16"
,
"modelopt"
,
"modelopt_fp4"
,
]
quantization_methods
=
[
q
for
q
in
supported_quantization
if
q
not
in
overrides
...
...
vllm/model_executor/layers/quantization/modelopt.py
View file @
6b46c4b6
...
...
@@ -75,18 +75,62 @@ class ModelOptFp8Config(QuantizationConfig):
def
get_config_filenames
(
cls
)
->
list
[
str
]:
return
[
"hf_quant_config.json"
]
@classmethod
def override_quantization_method(
        cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
    """Detect whether this ModelOpt FP8 config applies to a checkpoint.

    Args:
        hf_quant_cfg: Parsed quantization config of the checkpoint, or
            ``None`` when the checkpoint carries none.
        user_quant: Quantization method requested by the user. It is not
            consulted by this detection.

    Returns:
        ``"modelopt"`` when ``quant_method`` is ``"modelopt"``
        (case-insensitive) and the ``quant_algo`` string contains
        ``"FP8"``; otherwise ``None``.
    """
    if hf_quant_cfg is None:
        return None

    # Use the community standard 'quant_method'; only proceed if the
    # method is explicitly "modelopt".
    if hf_quant_cfg.get("quant_method", "").lower() != "modelopt":
        return None

    if "quantization" in hf_quant_cfg:
        # ModelOpt layout: {"quantization": {"quant_algo": "..."}}
        section = hf_quant_cfg["quantization"]
        quant_algo = (section.get("quant_algo", "") if isinstance(
            section, dict) else "")
    else:
        # Compressed-tensors style layout with a top-level quant_algo.
        quant_algo = hf_quant_cfg.get("quant_algo", "")

    # Guard both layouts against a null / non-string quant_algo; a bare
    # `"FP8" in None` would raise TypeError instead of declining.
    if isinstance(quant_algo, str) and "FP8" in quant_algo:
        return "modelopt"

    return None
@
classmethod
def
from_config
(
cls
,
config
:
dict
[
str
,
Any
])
->
"ModelOptFp8Config"
:
# Handle both ModelOpt format and compressed-tensors style format
if
"quantization"
in
config
:
# ModelOpt format: {"quantization": {"quant_algo": "..."}}
quant_config
=
cls
.
get_from_keys
(
config
,
[
"quantization"
])
quant_method
=
quant_config
[
"quant_algo"
]
kv_cache_quant_method
=
cls
.
get_from_keys
(
config
,
[
"quantization"
]).
get
(
"kv_cache_quant_algo"
)
exclude_modules
=
cls
.
get_from_keys
(
config
,
[
"quantization"
]).
get
(
"exclude_modules"
)
if
not
isinstance
(
quant_config
,
dict
):
raise
ValueError
(
"Expected 'quantization' to be a dictionary in config"
)
quant_method
=
quant_config
.
get
(
"quant_algo"
,
""
)
if
not
quant_method
:
raise
ValueError
(
"Missing 'quant_algo' in quantization config"
)
kv_cache_quant_method
=
quant_config
.
get
(
"kv_cache_quant_algo"
)
exclude_modules
=
quant_config
.
get
(
"exclude_modules"
)
else
:
# Compressed-tensors style format:
# {"quant_algo": "...", "quant_method": "modelopt"}
quant_method
=
config
.
get
(
"quant_algo"
,
""
)
kv_cache_quant_method
=
config
.
get
(
"kv_cache_quant_algo"
)
exclude_modules
=
config
.
get
(
"exclude_modules"
)
if
quant_method
not
in
QUANT_ALGOS
:
raise
ValueError
(
f
"ModelOpt currently only supports:
{
QUANT_ALGOS
}
"
" quantizations in vLLM. Please check the "
raise
ValueError
(
f
"ModelOpt currently only supports:
{
QUANT_ALGOS
}
"
"quantizations in vLLM. Please check the "
"`hf_quant_config.json` file for your model's "
"quant configuration."
)
is_checkpoint_fp8_serialized
=
(
"FP8"
in
quant_method
)
...
...
@@ -434,7 +478,7 @@ class ModelOptNvFp4Config(QuantizationConfig):
def
__init__
(
self
,
is_checkpoint_nvfp4_serialized
:
bool
,
kv_cache_quant_algo
:
str
,
kv_cache_quant_algo
:
Optional
[
str
]
,
exclude_modules
:
list
[
str
],
group_size
:
int
=
16
,
)
->
None
:
...
...
@@ -465,24 +509,138 @@ class ModelOptNvFp4Config(QuantizationConfig):
def
get_config_filenames
(
cls
)
->
list
[
str
]:
return
[
"hf_quant_config.json"
]
@classmethod
def override_quantization_method(
        cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
    """Detect whether this ModelOpt FP4 config applies to a checkpoint.

    Args:
        hf_quant_cfg: Parsed quantization config of the checkpoint, or
            ``None`` when the checkpoint carries none.
        user_quant: Quantization method requested by the user. It is not
            consulted by this detection.

    Returns:
        ``"modelopt_fp4"`` when ``quant_method`` is ``"modelopt"``
        (case-insensitive) and ``quant_algo`` designates an FP4
        algorithm; otherwise ``None``.
    """
    if hf_quant_cfg is None:
        return None

    # Use the community standard 'quant_method'; only proceed if the
    # method is explicitly "modelopt".
    if hf_quant_cfg.get("quant_method", "").lower() != "modelopt":
        return None

    if "quantization" in hf_quant_cfg:
        # ModelOpt layout: {"quantization": {"quant_algo": "..."}}
        section = hf_quant_cfg["quantization"]
        if isinstance(section, dict):
            quant_algo = section.get("quant_algo", "")
            # The isinstance guard keeps a null / non-string quant_algo
            # from raising TypeError on the membership test.
            if isinstance(quant_algo, str) and "NVFP4" in quant_algo:
                return "modelopt_fp4"
    else:
        # Compressed-tensors style layout with a top-level quant_algo.
        # NOTE(review): this layout matches "FP4" case-insensitively
        # while the nested layout matches "NVFP4" case-sensitively;
        # the asymmetry is kept as found - confirm it is intended.
        quant_algo = hf_quant_cfg.get("quant_algo", "")
        if isinstance(quant_algo, str) and "FP4" in quant_algo.upper():
            return "modelopt_fp4"

    return None
@classmethod
def from_config(cls, config: dict[str, Any]) -> "ModelOptNvFp4Config":
    """Build the NVFP4 config from a parsed quantization config.

    Two layouts are accepted:

    * ModelOpt layout: ``{"quantization": {"quant_algo": "...", ...}}``
    * compressed-tensors style layout:
      ``{"quant_algo": "...", "quant_method": "modelopt", ...}``

    Args:
        config: The quantization config dictionary.

    Returns:
        A config built from the validated ``quant_algo``,
        ``kv_cache_quant_algo``, ``exclude_modules`` and ``group_size``
        (``group_size`` defaults to 16 when absent).

    Raises:
        ValueError: If ``quantization`` is not a dictionary, if
            ``quant_algo`` is missing from it or not listed in
            ``QUANT_ALGOS``, if a field has the wrong type, or if an
            NVFP4 checkpoint in the ModelOpt layout lacks one of the
            required fields.
    """
    uses_nested_layout = "quantization" in config
    if uses_nested_layout:
        # ModelOpt layout: all fields live under "quantization".
        section = cls.get_from_keys(config, ["quantization"])
        if not isinstance(section, dict):
            raise ValueError(
                "Expected 'quantization' to be a dictionary in config")
        quant_method = section.get("quant_algo", "")
        if not quant_method:
            raise ValueError("Missing 'quant_algo' in quantization config")
    else:
        # Compressed-tensors style layout: fields sit at the top level.
        section = config
        quant_method = config.get("quant_algo", "")

    # The field validation is identical for both layouts, so it is done
    # once on whichever dictionary holds the fields.

    # kv_cache_quant_algo: None means no KV cache quantization.
    kv_cache_quant_algo = section.get("kv_cache_quant_algo")
    if (kv_cache_quant_algo is not None
            and not isinstance(kv_cache_quant_algo, str)):
        raise ValueError(f"kv_cache_quant_algo must be a string, got "
                         f"{type(kv_cache_quant_algo)}")

    # group_size: default to 16, otherwise coerce to int.
    group_size_raw = section.get("group_size")
    if group_size_raw is None:
        group_size = 16
    elif isinstance(group_size_raw, int):
        group_size = group_size_raw
    else:
        try:
            group_size = int(group_size_raw)
        except (ValueError, TypeError):
            raise ValueError(f"group_size must be an integer, got "
                             f"{type(group_size_raw)}") from None

    exclude_modules = section.get("exclude_modules", [])
    if not isinstance(exclude_modules, list):
        raise ValueError(f"exclude_modules must be a list, got "
                         f"{type(exclude_modules)}")

    if quant_method not in QUANT_ALGOS:
        raise ValueError(
            f"ModelOpt currently only supports: {QUANT_ALGOS} "
            "quantizations in vLLM. Please check the "
            "`hf_quant_config.json` file for your model's "
            "quant configuration.")
    is_checkpoint_nvfp4_serialized = "NVFP4" in quant_method

    # For FP4 checkpoints in the ModelOpt layout these fields are
    # required. Each key is tested individually; chaining the names with
    # `and` would only test the last one.
    if is_checkpoint_nvfp4_serialized and uses_nested_layout:
        required_fields = [
            "group_size", "kv_cache_quant_algo", "exclude_modules"
        ]
        missing_fields = [
            field for field in required_fields if field not in section
        ]
        if missing_fields:
            raise ValueError(
                f"NVFP4 quantization requires the following fields in "
                f"hf_quant_config.json: {missing_fields}")

    return cls(is_checkpoint_nvfp4_serialized, kv_cache_quant_algo,
               exclude_modules, group_size)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment