Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
155cbb51
Unverified
Commit
155cbb51
authored
Oct 06, 2025
by
Zhiyu
Committed by
GitHub
Oct 06, 2025
Browse files
Enable native ModelOpt quantization support (1/3) (#7149)
Signed-off-by:
Zhiyu Cheng
<
zhiyuc@nvidia.com
>
parent
eb30b888
Changes
11
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
464 additions
and
42 deletions
+464
-42
python/sglang/srt/configs/model_config.py
python/sglang/srt/configs/model_config.py
+33
-31
python/sglang/srt/layers/modelopt_utils.py
python/sglang/srt/layers/modelopt_utils.py
+11
-0
python/sglang/srt/layers/quantization/__init__.py
python/sglang/srt/layers/quantization/__init__.py
+1
-1
python/sglang/srt/layers/quantization/modelopt_quant.py
python/sglang/srt/layers/quantization/modelopt_quant.py
+1
-1
python/sglang/srt/model_executor/model_runner.py
python/sglang/srt/model_executor/model_runner.py
+1
-1
python/sglang/srt/model_loader/__init__.py
python/sglang/srt/model_loader/__init__.py
+1
-1
python/sglang/srt/model_loader/loader.py
python/sglang/srt/model_loader/loader.py
+187
-6
python/sglang/srt/model_loader/weight_utils.py
python/sglang/srt/model_loader/weight_utils.py
+3
-0
python/sglang/srt/server_args.py
python/sglang/srt/server_args.py
+10
-1
test/srt/run_suite.py
test/srt/run_suite.py
+1
-0
test/srt/test_modelopt_loader.py
test/srt/test_modelopt_loader.py
+215
-0
No files found.
python/sglang/srt/configs/model_config.py
View file @
155cbb51
...
@@ -17,7 +17,7 @@ import logging
...
@@ -17,7 +17,7 @@ import logging
import
math
import
math
import
os
import
os
from
enum
import
Enum
,
IntEnum
,
auto
from
enum
import
Enum
,
IntEnum
,
auto
from
typing
import
List
,
Optional
,
Set
,
Union
from
typing
import
Dict
,
List
,
Optional
,
Set
,
Union
import
torch
import
torch
from
transformers
import
PretrainedConfig
from
transformers
import
PretrainedConfig
...
@@ -85,6 +85,7 @@ class ModelConfig:
...
@@ -85,6 +85,7 @@ class ModelConfig:
enable_multimodal
:
Optional
[
bool
]
=
None
,
enable_multimodal
:
Optional
[
bool
]
=
None
,
dtype
:
str
=
"auto"
,
dtype
:
str
=
"auto"
,
quantization
:
Optional
[
str
]
=
None
,
quantization
:
Optional
[
str
]
=
None
,
modelopt_quant
:
Optional
[
Union
[
str
,
Dict
]]
=
None
,
override_config_file
:
Optional
[
str
]
=
None
,
override_config_file
:
Optional
[
str
]
=
None
,
is_draft_model
:
bool
=
False
,
is_draft_model
:
bool
=
False
,
hybrid_kvcache_ratio
:
Optional
[
float
]
=
None
,
hybrid_kvcache_ratio
:
Optional
[
float
]
=
None
,
...
@@ -94,6 +95,7 @@ class ModelConfig:
...
@@ -94,6 +95,7 @@ class ModelConfig:
self
.
model_path
=
model_path
self
.
model_path
=
model_path
self
.
revision
=
revision
self
.
revision
=
revision
self
.
quantization
=
quantization
self
.
quantization
=
quantization
self
.
modelopt_quant
=
modelopt_quant
self
.
is_draft_model
=
is_draft_model
self
.
is_draft_model
=
is_draft_model
self
.
model_impl
=
model_impl
self
.
model_impl
=
model_impl
...
@@ -209,6 +211,7 @@ class ModelConfig:
...
@@ -209,6 +211,7 @@ class ModelConfig:
enable_multimodal
=
server_args
.
enable_multimodal
,
enable_multimodal
=
server_args
.
enable_multimodal
,
dtype
=
server_args
.
dtype
,
dtype
=
server_args
.
dtype
,
quantization
=
server_args
.
quantization
,
quantization
=
server_args
.
quantization
,
modelopt_quant
=
server_args
.
modelopt_quant
,
hybrid_kvcache_ratio
=
server_args
.
hybrid_kvcache_ratio
,
hybrid_kvcache_ratio
=
server_args
.
hybrid_kvcache_ratio
,
model_impl
=
server_args
.
model_impl
,
model_impl
=
server_args
.
model_impl
,
**
kwargs
,
**
kwargs
,
...
@@ -477,54 +480,52 @@ class ModelConfig:
...
@@ -477,54 +480,52 @@ class ModelConfig:
# example: https://huggingface.co/nvidia/Llama-3.1-8B-Instruct-FP8/tree/main
# example: https://huggingface.co/nvidia/Llama-3.1-8B-Instruct-FP8/tree/main
# example: https://huggingface.co/Barrrrry/DeepSeek-R1-W4AFP8/tree/main
# example: https://huggingface.co/Barrrrry/DeepSeek-R1-W4AFP8/tree/main
is_local
=
os
.
path
.
exists
(
self
.
model_path
)
is_local
=
os
.
path
.
exists
(
self
.
model_path
)
modelopt_quant_config
=
{
"quant_method"
:
"modelopt"
}
if
not
is_local
:
if
not
is_local
:
import
huggingface_hub
import
huggingface_hub
try
:
try
:
from
huggingface_hub
import
HfApi
from
huggingface_hub
import
HfApi
,
hf_hub_download
hf_api
=
HfApi
()
hf_api
=
HfApi
()
if
hf_api
.
file_exists
(
self
.
model_path
,
"hf_quant_config.json"
):
def
check_hf_quant_config
():
# Download and parse the quantization config for remote models
return
hf_api
.
file_exists
(
quant_config_file
=
hf_hub_download
(
self
.
model_path
,
"hf_quant_config.json"
repo_id
=
self
.
model_path
,
filename
=
"hf_quant_config.json"
,
revision
=
self
.
revision
,
)
)
with
open
(
quant_config_file
)
as
f
:
# Retry HF API call up to 3 times
quant_config_dict
=
json
.
load
(
f
)
file_exists
=
retry
(
quant_cfg
=
self
.
_parse_modelopt_quant_config
(
quant_config_dict
)
check_hf_quant_config
,
max_retry
=
2
,
initial_delay
=
1.0
,
max_delay
=
5.0
,
)
if
file_exists
:
quant_cfg
=
modelopt_quant_config
except
huggingface_hub
.
errors
.
OfflineModeIsEnabled
:
except
huggingface_hub
.
errors
.
OfflineModeIsEnabled
:
logger
.
warning
(
logger
.
warning
(
"Offline mode is enabled, skipping hf_quant_config.json check"
"Offline mode is enabled, skipping hf_quant_config.json check"
)
)
except
Exception
as
e
:
pass
logger
.
warning
(
f
"Failed to check hf_quant_config.json:
{
self
.
model_path
}
{
e
}
"
)
elif
os
.
path
.
exists
(
os
.
path
.
join
(
self
.
model_path
,
"hf_quant_config.json"
)):
elif
os
.
path
.
exists
(
os
.
path
.
join
(
self
.
model_path
,
"hf_quant_config.json"
)):
quant_config_file
=
os
.
path
.
join
(
quant_config_file
=
os
.
path
.
join
(
self
.
model_path
,
"hf_quant_config.json"
self
.
model_path
,
"hf_quant_config.json"
)
)
with
open
(
quant_config_file
)
as
f
:
with
open
(
quant_config_file
)
as
f
:
quant_config_dict
=
json
.
load
(
f
)
quant_config_dict
=
json
.
load
(
f
)
json_quant_configs
=
quant_config_dict
[
"quantization"
]
quant_cfg
=
self
.
_parse_modelopt_quant_config
(
quant_config_dict
)
quant_algo
=
json_quant_configs
.
get
(
"quant_algo"
,
None
)
if
quant_algo
==
"MIXED_PRECISION"
:
quant_cfg
=
{
"quant_method"
:
"w4afp8"
}
else
:
quant_cfg
=
modelopt_quant_config
return
quant_cfg
return
quant_cfg
def
_parse_modelopt_quant_config
(
self
,
quant_config_dict
:
dict
)
->
dict
:
"""Parse ModelOpt quantization config and return the appropriate quant_method."""
json_quant_configs
=
quant_config_dict
[
"quantization"
]
quant_algo
=
json_quant_configs
.
get
(
"quant_algo"
,
None
)
if
quant_algo
==
"MIXED_PRECISION"
:
return
{
"quant_method"
:
"w4afp8"
}
elif
quant_algo
and
(
"FP4"
in
quant_algo
or
"NVFP4"
in
quant_algo
):
return
{
"quant_method"
:
"modelopt_fp4"
}
elif
quant_algo
and
"FP8"
in
quant_algo
:
return
{
"quant_method"
:
"modelopt_fp8"
}
else
:
# Default to FP8 for backward compatibility
return
{
"quant_method"
:
"modelopt_fp8"
}
# adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
# adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
def
_verify_quantization
(
self
)
->
None
:
def
_verify_quantization
(
self
)
->
None
:
supported_quantization
=
[
*
QUANTIZATION_METHODS
]
supported_quantization
=
[
*
QUANTIZATION_METHODS
]
...
@@ -543,7 +544,8 @@ class ModelConfig:
...
@@ -543,7 +544,8 @@ class ModelConfig:
optimized_quantization_methods
=
[
optimized_quantization_methods
=
[
"fp8"
,
"fp8"
,
"marlin"
,
"marlin"
,
"modelopt"
,
"modelopt_fp8"
,
"modelopt_fp4"
,
"gptq_marlin_24"
,
"gptq_marlin_24"
,
"gptq_marlin"
,
"gptq_marlin"
,
"awq_marlin"
,
"awq_marlin"
,
...
...
python/sglang/srt/layers/modelopt_utils.py
0 → 100644
View file @
155cbb51
"""
ModelOpt related constants
"""
QUANT_CFG_CHOICES
=
{
"fp8"
:
"FP8_DEFAULT_CFG"
,
"int4_awq"
:
"INT4_AWQ_CFG"
,
# TODO: add support for int4_awq
"w4a8_awq"
:
"W4A8_AWQ_BETA_CFG"
,
# TODO: add support for w4a8_awq
"nvfp4"
:
"NVFP4_DEFAULT_CFG"
,
"nvfp4_awq"
:
"NVFP4_AWQ_LITE_CFG"
,
# TODO: add support for nvfp4_awq
}
python/sglang/srt/layers/quantization/__init__.py
View file @
155cbb51
...
@@ -72,7 +72,7 @@ if TYPE_CHECKING:
...
@@ -72,7 +72,7 @@ if TYPE_CHECKING:
BASE_QUANTIZATION_METHODS
:
Dict
[
str
,
Type
[
QuantizationConfig
]]
=
{
BASE_QUANTIZATION_METHODS
:
Dict
[
str
,
Type
[
QuantizationConfig
]]
=
{
"fp8"
:
Fp8Config
,
"fp8"
:
Fp8Config
,
"blockwise_int8"
:
BlockInt8Config
,
"blockwise_int8"
:
BlockInt8Config
,
"modelopt"
:
ModelOptFp8Config
,
"modelopt
_fp8
"
:
ModelOptFp8Config
,
"modelopt_fp4"
:
ModelOptFp4Config
,
"modelopt_fp4"
:
ModelOptFp4Config
,
"w8a8_int8"
:
W8A8Int8Config
,
"w8a8_int8"
:
W8A8Int8Config
,
"w8a8_fp8"
:
W8A8Fp8Config
,
"w8a8_fp8"
:
W8A8Fp8Config
,
...
...
python/sglang/srt/layers/quantization/modelopt_quant.py
View file @
155cbb51
...
@@ -113,7 +113,7 @@ class ModelOptFp8Config(QuantizationConfig):
...
@@ -113,7 +113,7 @@ class ModelOptFp8Config(QuantizationConfig):
@
classmethod
@
classmethod
def
get_name
(
cls
)
->
str
:
def
get_name
(
cls
)
->
str
:
return
"modelopt"
return
"modelopt
_fp8
"
@
classmethod
@
classmethod
def
get_supported_act_dtypes
(
cls
)
->
List
[
torch
.
dtype
]:
def
get_supported_act_dtypes
(
cls
)
->
List
[
torch
.
dtype
]:
...
...
python/sglang/srt/model_executor/model_runner.py
View file @
155cbb51
...
@@ -880,7 +880,7 @@ class ModelRunner:
...
@@ -880,7 +880,7 @@ class ModelRunner:
load_config
=
LoadConfig
(
load_format
=
load_format
)
load_config
=
LoadConfig
(
load_format
=
load_format
)
# Only support DefaultModelLoader for now
# Only support DefaultModelLoader for now
loader
=
get_model_loader
(
load_config
)
loader
=
get_model_loader
(
load_config
,
self
.
model_config
)
if
not
isinstance
(
loader
,
DefaultModelLoader
):
if
not
isinstance
(
loader
,
DefaultModelLoader
):
message
=
f
"Failed to get model loader:
{
loader
}
."
message
=
f
"Failed to get model loader:
{
loader
}
."
return
False
,
message
return
False
,
message
...
...
python/sglang/srt/model_loader/__init__.py
View file @
155cbb51
...
@@ -24,7 +24,7 @@ def get_model(
...
@@ -24,7 +24,7 @@ def get_model(
load_config
:
LoadConfig
,
load_config
:
LoadConfig
,
device_config
:
DeviceConfig
,
device_config
:
DeviceConfig
,
)
->
nn
.
Module
:
)
->
nn
.
Module
:
loader
=
get_model_loader
(
load_config
)
loader
=
get_model_loader
(
load_config
,
model_config
)
return
loader
.
load_model
(
return
loader
.
load_model
(
model_config
=
model_config
,
model_config
=
model_config
,
device_config
=
device_config
,
device_config
=
device_config
,
...
...
python/sglang/srt/model_loader/loader.py
View file @
155cbb51
...
@@ -37,10 +37,22 @@ import numpy as np
...
@@ -37,10 +37,22 @@ import numpy as np
import
requests
import
requests
import
safetensors.torch
import
safetensors.torch
import
torch
import
torch
# Try to import accelerate (optional dependency)
try
:
from
accelerate
import
infer_auto_device_map
,
init_empty_weights
from
accelerate.utils
import
get_max_memory
HAS_ACCELERATE
=
True
except
ImportError
:
HAS_ACCELERATE
=
False
infer_auto_device_map
=
None
init_empty_weights
=
None
get_max_memory
=
None
from
huggingface_hub
import
HfApi
,
hf_hub_download
from
huggingface_hub
import
HfApi
,
hf_hub_download
from
torch
import
nn
from
torch
import
nn
from
tqdm.auto
import
tqdm
from
transformers
import
AutoConfig
,
AutoModelForCausalLM
from
transformers
import
AutoModelForCausalLM
from
transformers.utils
import
SAFE_WEIGHTS_INDEX_NAME
from
transformers.utils
import
SAFE_WEIGHTS_INDEX_NAME
from
sglang.srt.configs.load_config
import
LoadConfig
,
LoadFormat
from
sglang.srt.configs.load_config
import
LoadConfig
,
LoadFormat
...
@@ -54,6 +66,8 @@ from sglang.srt.distributed import (
...
@@ -54,6 +66,8 @@ from sglang.srt.distributed import (
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
,
get_tensor_model_parallel_world_size
,
)
)
from
sglang.srt.layers.modelopt_utils
import
QUANT_CFG_CHOICES
from
sglang.srt.layers.quantization.base_config
import
QuantizationConfig
from
sglang.srt.model_loader.remote_instance_weight_loader_utils
import
(
from
sglang.srt.model_loader.remote_instance_weight_loader_utils
import
(
trigger_transferring_weights_request
,
trigger_transferring_weights_request
,
)
)
...
@@ -62,6 +76,11 @@ from sglang.srt.model_loader.utils import (
...
@@ -62,6 +76,11 @@ from sglang.srt.model_loader.utils import (
post_load_weights
,
post_load_weights
,
set_default_torch_dtype
,
set_default_torch_dtype
,
)
)
# Constants for memory management
DEFAULT_GPU_MEMORY_FRACTION_FOR_CALIBRATION
=
(
0.8
# Reserve 20% GPU memory headroom for ModelOpt calibration
)
from
sglang.srt.model_loader.weight_utils
import
(
from
sglang.srt.model_loader.weight_utils
import
(
_BAR_FORMAT
,
_BAR_FORMAT
,
default_weight_loader
,
default_weight_loader
,
...
@@ -94,6 +113,8 @@ if TYPE_CHECKING:
...
@@ -94,6 +113,8 @@ if TYPE_CHECKING:
from
sglang.srt.layers.quantization.base_config
import
QuantizationConfig
from
sglang.srt.layers.quantization.base_config
import
QuantizationConfig
_is_npu
=
is_npu
()
_is_npu
=
is_npu
()
# ModelOpt: QUANT_CFG_CHOICES is imported from modelopt_utils.py
# which contains the complete mapping of quantization config choices
@
contextmanager
@
contextmanager
...
@@ -477,12 +498,78 @@ class DefaultModelLoader(BaseModelLoader):
...
@@ -477,12 +498,78 @@ class DefaultModelLoader(BaseModelLoader):
model_config
.
model_path
,
model_config
.
revision
,
fall_back_to_pt
=
True
model_config
.
model_path
,
model_config
.
revision
,
fall_back_to_pt
=
True
)
)
def
_load_modelopt_base_model
(
self
,
model_config
:
ModelConfig
)
->
nn
.
Module
:
"""Load and prepare the base model for ModelOpt quantization.
This method handles the common model loading logic shared between
DefaultModelLoader (conditional) and ModelOptModelLoader (dedicated).
"""
if
not
HAS_ACCELERATE
:
raise
ImportError
(
"accelerate is required for ModelOpt quantization. "
"Please install it with: pip install accelerate"
)
hf_config
=
AutoConfig
.
from_pretrained
(
model_config
.
model_path
,
trust_remote_code
=
True
)
with
init_empty_weights
():
torch_dtype
=
getattr
(
hf_config
,
"torch_dtype"
,
torch
.
float16
)
model
=
AutoModelForCausalLM
.
from_config
(
hf_config
,
torch_dtype
=
torch_dtype
,
trust_remote_code
=
True
)
max_memory
=
get_max_memory
()
inferred_device_map
=
infer_auto_device_map
(
model
,
max_memory
=
max_memory
)
on_cpu
=
"cpu"
in
inferred_device_map
.
values
()
model_kwargs
=
{
"torch_dtype"
:
"auto"
}
device_map
=
"auto"
if
on_cpu
:
for
device
in
max_memory
.
keys
():
if
isinstance
(
device
,
int
):
max_memory
[
device
]
*=
DEFAULT_GPU_MEMORY_FRACTION_FOR_CALIBRATION
logger
.
warning
(
"Model does not fit to the GPU mem. "
f
"We apply the following memory limit for calibration:
\n
{
max_memory
}
\n
"
f
"If you hit GPU OOM issue, please adjust the memory fraction "
f
"(currently
{
DEFAULT_GPU_MEMORY_FRACTION_FOR_CALIBRATION
}
) or "
"reduce the calibration `batch_size` manually."
)
model_kwargs
[
"max_memory"
]
=
max_memory
model
=
AutoModelForCausalLM
.
from_pretrained
(
model_config
.
model_path
,
device_map
=
device_map
,
**
model_kwargs
,
trust_remote_code
=
True
,
)
logger
.
info
(
f
"ModelOpt quantization requested:
{
model_config
.
modelopt_quant
}
"
)
quant_choice_str
=
model_config
.
modelopt_quant
if
not
isinstance
(
quant_choice_str
,
str
):
raise
TypeError
(
f
"modelopt_quant must be a string preset key (e.g., 'fp8'), "
f
"got
{
type
(
quant_choice_str
)
}
"
)
return
model
def
load_model
(
def
load_model
(
self
,
self
,
*
,
*
,
model_config
:
ModelConfig
,
model_config
:
ModelConfig
,
device_config
:
DeviceConfig
,
device_config
:
DeviceConfig
,
)
->
nn
.
Module
:
)
->
nn
.
Module
:
if
hasattr
(
model_config
,
"modelopt_quant"
)
and
model_config
.
modelopt_quant
:
# Load base model using shared method
model
=
self
.
_load_modelopt_base_model
(
model_config
)
# Note: DefaultModelLoader doesn't do additional quantization processing
# For full ModelOpt quantization, use ModelOptModelLoader
return
model
.
eval
()
target_device
=
torch
.
device
(
device_config
.
device
)
target_device
=
torch
.
device
(
device_config
.
device
)
with
set_default_torch_dtype
(
model_config
.
dtype
):
with
set_default_torch_dtype
(
model_config
.
dtype
):
with
target_device
:
with
target_device
:
...
@@ -491,9 +578,9 @@ class DefaultModelLoader(BaseModelLoader):
...
@@ -491,9 +578,9 @@ class DefaultModelLoader(BaseModelLoader):
self
.
load_config
,
self
.
load_config
,
)
)
self
.
load_weights_and_postprocess
(
self
.
load_weights_and_postprocess
(
model
,
self
.
_get_all_weights
(
model_config
,
model
),
target_device
model
,
self
.
_get_all_weights
(
model_config
,
model
),
target_device
)
)
return
model
.
eval
()
return
model
.
eval
()
...
@@ -1668,9 +1755,103 @@ def load_model_with_cpu_quantization(
...
@@ -1668,9 +1755,103 @@ def load_model_with_cpu_quantization(
return
model
.
eval
()
return
model
.
eval
()
def
get_model_loader
(
load_config
:
LoadConfig
)
->
BaseModelLoader
:
class
ModelOptModelLoader
(
DefaultModelLoader
):
"""
Model loader that applies NVIDIA Model Optimizer quantization
"""
def
__init__
(
self
,
load_config
:
LoadConfig
):
super
().
__init__
(
load_config
)
# Any ModelOpt specific initialization if needed
def
load_model
(
self
,
*
,
model_config
:
ModelConfig
,
device_config
:
DeviceConfig
,
)
->
nn
.
Module
:
logger
.
info
(
"ModelOptModelLoader: Loading base model..."
)
# Use shared method from parent class to load base model
model
=
self
.
_load_modelopt_base_model
(
model_config
)
# Import ModelOpt modules (already done in _load_modelopt_base_model, but needed here for quantization)
try
:
import
modelopt.torch.quantization
as
mtq
from
modelopt.torch.utils.dataset_utils
import
create_forward_loop
except
ImportError
:
logger
.
error
(
"NVIDIA Model Optimizer (modelopt) library not found. "
"Please install it to use 'modelopt_quant' feature."
)
raise
quant_choice_str
=
model_config
.
modelopt_quant
quant_cfg_name
=
QUANT_CFG_CHOICES
.
get
(
quant_choice_str
)
if
not
quant_cfg_name
:
raise
ValueError
(
f
"Invalid modelopt_quant choice: '
{
quant_choice_str
}
'. "
f
"Available choices in QUANT_CFG_CHOICES:
{
list
(
QUANT_CFG_CHOICES
.
keys
())
}
. "
"Ensure QUANT_CFG_CHOICES is correctly defined with mappings to "
"attribute names of config objects in modelopt.torch.quantization."
)
try
:
# getattr will fetch the config object, e.g., mtq.FP8_DEFAULT_CFG
quant_cfg
=
getattr
(
mtq
,
quant_cfg_name
)
except
AttributeError
:
raise
AttributeError
(
f
"ModelOpt quantization config attribute '
{
quant_cfg_name
}
' "
f
"(from choice '
{
quant_choice_str
}
') not found in modelopt.torch.quantization module. "
"Please verify QUANT_CFG_CHOICES and the ModelOpt library."
)
# For now, assume no calibration. Calibration setup is a separate, more complex step.
use_calibration
=
False
# This would ideally be a configurable parameter
calib_dataloader
=
None
# This would need to be provided/configured
calibrate_loop
=
(
create_forward_loop
(
dataloader
=
calib_dataloader
)
if
use_calibration
else
None
)
if
use_calibration
and
calib_dataloader
is
None
:
logger
.
warning
(
"ModelOpt calibration requested but no calib_dataloader provided. "
"Proceeding without calibration. Quantization accuracy may be affected."
)
logger
.
info
(
f
"Quantizing model with ModelOpt using config attribute: mtq.
{
quant_cfg_name
}
"
)
try
:
model
=
mtq
.
quantize
(
model
,
quant_cfg
,
forward_loop
=
calibrate_loop
)
logger
.
info
(
"Model successfully quantized with ModelOpt."
)
except
Exception
as
e
:
logger
.
error
(
f
"Error during ModelOpt mtq.quantize call:
{
e
}
"
)
raise
mtq
.
print_quant_summary
(
model
)
return
model
.
eval
()
def
get_model_loader
(
load_config
:
LoadConfig
,
model_config
:
Optional
[
ModelConfig
]
=
None
)
->
BaseModelLoader
:
"""Get a model loader based on the load format."""
"""Get a model loader based on the load format."""
if
(
model_config
and
hasattr
(
model_config
,
"modelopt_quant"
)
and
model_config
.
modelopt_quant
):
logger
.
info
(
"Using ModelOptModelLoader due to 'modelopt_quant' config."
)
return
ModelOptModelLoader
(
load_config
)
if
isinstance
(
load_config
.
load_format
,
type
):
if
isinstance
(
load_config
.
load_format
,
type
):
return
load_config
.
load_format
(
load_config
)
return
load_config
.
load_format
(
load_config
)
...
...
python/sglang/srt/model_loader/weight_utils.py
View file @
155cbb51
...
@@ -226,6 +226,9 @@ def get_quant_config(
...
@@ -226,6 +226,9 @@ def get_quant_config(
return
ModelOptFp4Config
.
from_config
(
config
)
return
ModelOptFp4Config
.
from_config
(
config
)
else
:
else
:
return
quant_cls
.
from_config
(
config
)
return
quant_cls
.
from_config
(
config
)
elif
model_config
.
quantization
==
"modelopt_fp8"
:
if
config
[
"producer"
][
"name"
]
==
"modelopt_fp8"
:
return
quant_cls
.
from_config
(
config
)
else
:
else
:
raise
ValueError
(
raise
ValueError
(
f
"Unsupported quantization config"
f
"Unsupported quantization config"
...
...
python/sglang/srt/server_args.py
View file @
155cbb51
...
@@ -20,7 +20,7 @@ import logging
...
@@ -20,7 +20,7 @@ import logging
import
os
import
os
import
random
import
random
import
tempfile
import
tempfile
from
typing
import
List
,
Literal
,
Optional
,
Union
from
typing
import
Dict
,
List
,
Literal
,
Optional
,
Union
from
sglang.srt.connector
import
ConnectorType
from
sglang.srt.connector
import
ConnectorType
from
sglang.srt.function_call.function_call_parser
import
FunctionCallParser
from
sglang.srt.function_call.function_call_parser
import
FunctionCallParser
...
@@ -162,6 +162,7 @@ class ServerArgs:
...
@@ -162,6 +162,7 @@ class ServerArgs:
load_format
:
str
=
"auto"
load_format
:
str
=
"auto"
model_loader_extra_config
:
str
=
"{}"
model_loader_extra_config
:
str
=
"{}"
trust_remote_code
:
bool
=
False
trust_remote_code
:
bool
=
False
modelopt_quant
:
Optional
[
Union
[
str
,
Dict
]]
=
None
context_length
:
Optional
[
int
]
=
None
context_length
:
Optional
[
int
]
=
None
is_embedding
:
bool
=
False
is_embedding
:
bool
=
False
enable_multimodal
:
Optional
[
bool
]
=
None
enable_multimodal
:
Optional
[
bool
]
=
None
...
@@ -1455,6 +1456,14 @@ class ServerArgs:
...
@@ -1455,6 +1456,14 @@ class ServerArgs:
"KV cache dtype is FP8. Otherwise, KV cache scaling factors "
"KV cache dtype is FP8. Otherwise, KV cache scaling factors "
"default to 1.0, which may cause accuracy issues. "
,
"default to 1.0, which may cause accuracy issues. "
,
)
)
parser
.
add_argument
(
"--modelopt-quant"
,
type
=
str
,
default
=
ServerArgs
.
modelopt_quant
,
help
=
"The ModelOpt quantization configuration. "
"Supported values: 'fp8', 'int4_awq', 'w4a8_awq', 'nvfp4', 'nvfp4_awq'. "
"This requires the NVIDIA Model Optimizer library to be installed: pip install nvidia-modelopt"
,
)
parser
.
add_argument
(
parser
.
add_argument
(
"--kv-cache-dtype"
,
"--kv-cache-dtype"
,
type
=
str
,
type
=
str
,
...
...
test/srt/run_suite.py
View file @
155cbb51
...
@@ -125,6 +125,7 @@ suites = {
...
@@ -125,6 +125,7 @@ suites = {
TestFile
(
"test_vlm_input_format.py"
,
300
),
TestFile
(
"test_vlm_input_format.py"
,
300
),
TestFile
(
"test_vision_openai_server_a.py"
,
724
),
TestFile
(
"test_vision_openai_server_a.py"
,
724
),
TestFile
(
"test_vision_openai_server_b.py"
,
446
),
TestFile
(
"test_vision_openai_server_b.py"
,
446
),
TestFile
(
"test_modelopt_loader.py"
,
30
),
],
],
"per-commit-2-gpu"
:
[
"per-commit-2-gpu"
:
[
TestFile
(
"lora/test_lora_tp.py"
,
116
),
TestFile
(
"lora/test_lora_tp.py"
,
116
),
...
...
test/srt/test_modelopt_loader.py
0 → 100644
View file @
155cbb51
"""
Unit tests for ModelOptModelLoader class.
This test module verifies the functionality of ModelOptModelLoader, which
applies NVIDIA Model Optimizer quantization to models during loading.
"""
import
os
import
sys
import
unittest
from
unittest.mock
import
MagicMock
,
patch
import
torch.nn
as
nn
# Add the sglang path for testing
sys
.
path
.
insert
(
0
,
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
"../../python"
))
from
sglang.srt.configs.device_config
import
DeviceConfig
from
sglang.srt.configs.load_config
import
LoadConfig
from
sglang.srt.configs.model_config
import
ModelConfig
from
sglang.srt.layers.modelopt_utils
import
QUANT_CFG_CHOICES
from
sglang.srt.model_loader.loader
import
ModelOptModelLoader
from
sglang.test.test_utils
import
CustomTestCase
class
TestModelOptModelLoader
(
CustomTestCase
):
"""Test cases for ModelOptModelLoader functionality."""
def
setUp
(
self
):
"""Set up test fixtures."""
self
.
model_path
=
"TinyLlama/TinyLlama-1.1B-Chat-v1.0"
self
.
load_config
=
LoadConfig
()
self
.
device_config
=
DeviceConfig
(
device
=
"cuda"
)
# Create a basic model config with modelopt_quant
self
.
model_config
=
ModelConfig
(
model_path
=
self
.
model_path
,
modelopt_quant
=
"fp8"
)
# Mock base model
self
.
mock_base_model
=
MagicMock
(
spec
=
nn
.
Module
)
self
.
mock_base_model
.
eval
.
return_value
=
self
.
mock_base_model
@
patch
(
"sglang.srt.model_loader.loader.QUANT_CFG_CHOICES"
,
QUANT_CFG_CHOICES
)
@
patch
(
"sglang.srt.model_loader.loader.logger"
)
def
test_successful_fp8_quantization
(
self
,
mock_logger
):
"""Test successful FP8 quantization workflow."""
# Create loader instance
loader
=
ModelOptModelLoader
(
self
.
load_config
)
# Mock modelopt modules
mock_mtq
=
MagicMock
()
# Configure mtq mock with FP8_DEFAULT_CFG
mock_fp8_cfg
=
MagicMock
()
mock_mtq
.
FP8_DEFAULT_CFG
=
mock_fp8_cfg
mock_mtq
.
quantize
.
return_value
=
self
.
mock_base_model
mock_mtq
.
print_quant_summary
=
MagicMock
()
# Create a custom load_model method for testing that simulates the real logic
def
mock_load_model
(
*
,
model_config
,
device_config
):
mock_logger
.
info
(
"ModelOptModelLoader: Loading base model..."
)
# Simulate loading base model (this is already mocked)
model
=
self
.
mock_base_model
# Simulate the quantization config lookup
quant_choice_str
=
model_config
.
modelopt_quant
quant_cfg_name
=
QUANT_CFG_CHOICES
.
get
(
quant_choice_str
)
if
not
quant_cfg_name
:
raise
ValueError
(
f
"Invalid modelopt_quant choice: '
{
quant_choice_str
}
'"
)
# Simulate getattr call and quantization
if
quant_cfg_name
==
"FP8_DEFAULT_CFG"
:
quant_cfg
=
mock_fp8_cfg
mock_logger
.
info
(
f
"Quantizing model with ModelOpt using config attribute: mtq.
{
quant_cfg_name
}
"
)
# Simulate mtq.quantize call
quantized_model
=
mock_mtq
.
quantize
(
model
,
quant_cfg
,
forward_loop
=
None
)
mock_logger
.
info
(
"Model successfully quantized with ModelOpt."
)
# Simulate print_quant_summary call
mock_mtq
.
print_quant_summary
(
quantized_model
)
return
quantized_model
.
eval
()
return
model
.
eval
()
# Patch the load_model method with our custom implementation
with
patch
.
object
(
loader
,
"load_model"
,
side_effect
=
mock_load_model
):
# Execute the load_model method
result_model
=
loader
.
load_model
(
model_config
=
self
.
model_config
,
device_config
=
self
.
device_config
)
# Verify the quantization process
mock_mtq
.
quantize
.
assert_called_once_with
(
self
.
mock_base_model
,
mock_fp8_cfg
,
forward_loop
=
None
)
# Verify logging
mock_logger
.
info
.
assert_any_call
(
"ModelOptModelLoader: Loading base model..."
)
mock_logger
.
info
.
assert_any_call
(
"Quantizing model with ModelOpt using config attribute: mtq.FP8_DEFAULT_CFG"
)
mock_logger
.
info
.
assert_any_call
(
"Model successfully quantized with ModelOpt."
)
# Verify print_quant_summary was called
mock_mtq
.
print_quant_summary
.
assert_called_once_with
(
self
.
mock_base_model
)
# Verify eval() was called on the returned model
self
.
mock_base_model
.
eval
.
assert_called
()
# Verify we get back the expected model
self
.
assertEqual
(
result_model
,
self
.
mock_base_model
)
class
TestModelOptLoaderIntegration
(
CustomTestCase
):
"""Integration tests for ModelOptModelLoader with Engine API."""
@
patch
(
"sglang.srt.model_loader.loader.get_model_loader"
)
@
patch
(
"sglang.srt.entrypoints.engine.Engine.__init__"
)
def
test_engine_with_modelopt_quant_parameter
(
self
,
mock_engine_init
,
mock_get_model_loader
):
"""Test that Engine properly handles modelopt_quant parameter."""
# Mock the Engine.__init__ to avoid actual initialization
mock_engine_init
.
return_value
=
None
# Mock get_model_loader to return our ModelOptModelLoader
mock_loader
=
MagicMock
(
spec
=
ModelOptModelLoader
)
mock_get_model_loader
.
return_value
=
mock_loader
# Import here to avoid circular imports during test discovery
# import sglang as sgl # Commented out since not directly used
# Test that we can create an engine with modelopt_quant parameter
# This would normally trigger the ModelOptModelLoader selection
try
:
engine_args
=
{
"model_path"
:
"TinyLlama/TinyLlama-1.1B-Chat-v1.0"
,
"modelopt_quant"
:
"fp8"
,
"log_level"
:
"error"
,
# Suppress logs during testing
}
# This tests the parameter parsing and server args creation
from
sglang.srt.server_args
import
ServerArgs
server_args
=
ServerArgs
(
**
engine_args
)
# Verify that modelopt_quant is properly set
self
.
assertEqual
(
server_args
.
modelopt_quant
,
"fp8"
)
except
Exception
as
e
:
# If there are missing dependencies or initialization issues,
# we can still verify the parameter is accepted
if
"modelopt_quant"
not
in
str
(
e
):
# The parameter was accepted, which is what we want to test
pass
else
:
self
.
fail
(
f
"modelopt_quant parameter not properly handled:
{
e
}
"
)
@
patch
(
"sglang.srt.model_loader.loader.get_model_loader"
)
@
patch
(
"sglang.srt.entrypoints.engine.Engine.__init__"
)
def
test_engine_with_modelopt_quant_cli_argument
(
self
,
mock_engine_init
,
mock_get_model_loader
):
"""Test that CLI argument --modelopt-quant is properly parsed."""
# Mock the Engine.__init__ to avoid actual initialization
mock_engine_init
.
return_value
=
None
# Mock get_model_loader to return our ModelOptModelLoader
mock_loader
=
MagicMock
(
spec
=
ModelOptModelLoader
)
mock_get_model_loader
.
return_value
=
mock_loader
# Test CLI argument parsing
import
argparse
from
sglang.srt.server_args
import
ServerArgs
# Create parser and add arguments
parser
=
argparse
.
ArgumentParser
()
ServerArgs
.
add_cli_args
(
parser
)
# Test parsing with modelopt_quant argument
args
=
parser
.
parse_args
(
[
"--model-path"
,
"TinyLlama/TinyLlama-1.1B-Chat-v1.0"
,
"--modelopt-quant"
,
"fp8"
,
]
)
# Convert to ServerArgs using the proper from_cli_args method
server_args
=
ServerArgs
.
from_cli_args
(
args
)
# Verify that modelopt_quant was properly parsed
self
.
assertEqual
(
server_args
.
modelopt_quant
,
"fp8"
)
self
.
assertEqual
(
server_args
.
model_path
,
"TinyLlama/TinyLlama-1.1B-Chat-v1.0"
)
if
__name__
==
"__main__"
:
unittest
.
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment