sglang · Commit 129d2992 (unverified)
Authored Oct 11, 2025 by Zhiyu; committed by GitHub on Oct 11, 2025
Enable native ModelOpt quantization support (2/3) (#9991)

Signed-off-by: Zhiyu Cheng <zhiyuc@nvidia.com>

Parent: 8b85926a
Showing 3 changed files with 127 additions and 26 deletions:

- python/sglang/srt/configs/model_config.py (+2, -0)
- python/sglang/srt/model_loader/loader.py (+108, -26)
- python/sglang/srt/server_args.py (+17, -0)
python/sglang/srt/configs/model_config.py
```diff
@@ -86,6 +86,8 @@ class ModelConfig:
         dtype: str = "auto",
         quantization: Optional[str] = None,
         modelopt_quant: Optional[Union[str, Dict]] = None,
+        modelopt_checkpoint_restore_path: Optional[str] = None,
+        modelopt_checkpoint_save_path: Optional[str] = None,
         override_config_file: Optional[str] = None,
         is_draft_model: bool = False,
         hybrid_kvcache_ratio: Optional[float] = None,
```
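Together with the matching `ServerArgs` fields added below, these two parameters give the loader a quantize-once, reuse-later workflow: a restore path short-circuits calibration by loading a previously saved ModelOpt state, while a save path persists the quantizer state after calibration finishes.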
python/sglang/srt/model_loader/loader.py
```diff
@@ -18,7 +18,7 @@ import threading
 import time
 from abc import ABC, abstractmethod
 from concurrent.futures import ThreadPoolExecutor
-from contextlib import contextmanager
+from contextlib import contextmanager, suppress
 from typing import (
     TYPE_CHECKING,
     Any,
```
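The newly imported `suppress` is a standard-library context manager that swallows the listed exception types; the quantization setup added further down uses it when forcing `tokenizer.padding_side`. A minimal, self-contained illustration:

```python
from contextlib import suppress

# Equivalent to try/except/pass: the exception is raised inside the block,
# silently discarded, and execution continues after the block.
with suppress(ZeroDivisionError):
    _ = 1 / 0
print("still running")
```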
```diff
@@ -30,7 +30,6 @@ from typing import (
     Tuple,
     cast,
 )
-from urllib.parse import urlparse
 
 import huggingface_hub
 import numpy as np
```
```diff
@@ -52,7 +51,7 @@ except ImportError:
 from huggingface_hub import HfApi, hf_hub_download
 from torch import nn
-from transformers import AutoConfig, AutoModelForCausalLM
+from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
 from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
 
 from sglang.srt.configs.load_config import LoadConfig, LoadFormat
```
```diff
@@ -104,6 +103,7 @@ from sglang.srt.utils import (
     get_device_capability,
     is_npu,
     is_pin_memory_available,
+    rank0_log,
     set_weight_attrs,
 )
```
```diff
@@ -545,7 +545,7 @@ class DefaultModelLoader(BaseModelLoader):
             **model_kwargs,
             trust_remote_code=True,
         )
-        logger.info(f"ModelOpt quantization requested: {model_config.modelopt_quant}")
+        rank0_log(f"ModelOpt quantization requested: {model_config.modelopt_quant}")
         quant_choice_str = model_config.modelopt_quant
         if not isinstance(quant_choice_str, str):
```
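Switching `logger.info` to `rank0_log` keeps this message from being printed once per tensor-parallel worker. As a sketch only (the real helper lives in `sglang.srt.utils`, imported above), such a wrapper typically looks like:

```python
import logging

logger = logging.getLogger(__name__)

def get_tensor_model_parallel_rank() -> int:
    # Stub for illustration; sglang queries the actual TP process group.
    return 0

def rank0_log(msg: str) -> None:
    # Emit the message on rank 0 only, so multi-GPU runs log it once
    # rather than once per GPU.
    if get_tensor_model_parallel_rank() == 0:
        logger.info(msg)
```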
```diff
@@ -1764,6 +1764,96 @@ class ModelOptModelLoader(DefaultModelLoader):
         super().__init__(load_config)
         # Any ModelOpt specific initialization if needed
 
+    def _setup_modelopt_quantization(
+        self,
+        model,
+        tokenizer,
+        quant_cfg,
+        quantized_ckpt_restore_path: str | None = None,
+        quantized_ckpt_save_path: str | None = None,
+    ) -> None:
+        """
+        Set up ModelOpt quantization for the given model.
+
+        Args:
+            model: The model to quantize
+            tokenizer: The tokenizer associated with the model
+            quant_cfg: The quantization configuration
+            quantized_ckpt_restore_path: Path to restore quantized checkpoint from
+            quantized_ckpt_save_path: Path to save quantized checkpoint to
+
+        Raises:
+            ImportError: If ModelOpt is not available
+            Exception: If quantization setup fails
+        """
+        try:
+            import modelopt.torch.opt as mto
+            import modelopt.torch.quantization as mtq
+            from modelopt.torch.quantization.utils import is_quantized
+        except ImportError as e:
+            raise ImportError("ModelOpt is not available. Please install modelopt.") from e
+
+        if is_quantized(model):
+            rank0_log("Model is already quantized, skipping quantization setup.")
+            return
+
+        # Restore from checkpoint if provided
+        if quantized_ckpt_restore_path:
+            try:
+                mto.restore(model, quantized_ckpt_restore_path)
+                rank0_log(f"Restored quantized model from {quantized_ckpt_restore_path}")
+                return
+            except Exception as e:
+                logger.warning(f"Failed to restore from {quantized_ckpt_restore_path}: {e}")
+                rank0_log("Proceeding with calibration-based quantization...")
+
+        # Set up calibration-based quantization
+        try:
+            # Left padding tends to work better for batched generation with decoder-only LMs
+            with suppress(Exception):
+                tokenizer.padding_side = "left"
+
+            from modelopt.torch.utils.dataset_utils import (
+                create_forward_loop,
+                get_dataset_dataloader,
+            )
+
+            # Create calibration dataloader
+            calib_dataloader = get_dataset_dataloader(
+                dataset_name="cnn_dailymail",  # TODO: Consider making this configurable
+                tokenizer=tokenizer,
+                batch_size=36,  # TODO: Consider making this configurable
+                num_samples=512,  # TODO: Consider making this configurable
+                device=model.device,
+                include_labels=False,
+            )
+
+            calibrate_loop = create_forward_loop(dataloader=calib_dataloader)
+
+            # Apply quantization
+            mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)
+
+            if get_tensor_model_parallel_rank() == 0:
+                mtq.print_quant_summary(model)
+
+            # Save checkpoint if path provided
+            if quantized_ckpt_save_path:
+                try:
+                    mto.save(model, quantized_ckpt_save_path)
+                    rank0_log(f"Quantized model saved to {quantized_ckpt_save_path}")
+                except Exception as e:
+                    logger.warning(
+                        f"Failed to save quantized checkpoint to {quantized_ckpt_save_path}: {e}"
+                    )
+        except Exception as e:
+            raise Exception(f"Failed to set up ModelOpt quantization: {e}") from e
+
     def load_model(
         self,
         *,
```
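For orientation, here is a condensed sketch of the quantize/save/restore cycle the new method wraps, shown outside sglang. The `modelopt` calls mirror the ones in the diff; the model name, batch sizes, and paths are illustrative placeholders, and `mtq.FP8_DEFAULT_CFG` is assumed as the config behind the 'fp8' choice:

```python
import modelopt.torch.opt as mto
import modelopt.torch.quantization as mtq
from modelopt.torch.utils.dataset_utils import create_forward_loop, get_dataset_dataloader
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m").cuda()  # placeholder
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m", use_fast=True)
tokenizer.padding_side = "left"  # left padding for decoder-only calibration batches

# Calibrate on a small sample, then quantize the model in place.
calib_dataloader = get_dataset_dataloader(
    dataset_name="cnn_dailymail",
    tokenizer=tokenizer,
    batch_size=4,
    num_samples=32,
    device=model.device,
    include_labels=False,
)
mtq.quantize(
    model,
    mtq.FP8_DEFAULT_CFG,  # assumed mapping for the 'fp8' quant choice
    forward_loop=create_forward_loop(dataloader=calib_dataloader),
)

# Persist the quantizer state so a later run can skip calibration...
mto.save(model, "/tmp/opt125m_fp8_modelopt")
# ...and restore it in a fresh process onto a newly loaded base model:
# mto.restore(model, "/tmp/opt125m_fp8_modelopt")
```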
```diff
@@ -1779,7 +1869,6 @@ class ModelOptModelLoader(DefaultModelLoader):
         # Import ModelOpt modules (already done in _load_modelopt_base_model, but needed here for quantization)
         try:
             import modelopt.torch.quantization as mtq
-            from modelopt.torch.utils.dataset_utils import create_forward_loop
         except ImportError:
             logger.error(
                 "NVIDIA Model Optimizer (modelopt) library not found. "
```
```diff
@@ -1808,33 +1897,26 @@ class ModelOptModelLoader(DefaultModelLoader):
                 "Please verify QUANT_CFG_CHOICES and the ModelOpt library."
             )
 
-        # For now, assume no calibration. Calibration setup is a separate, more complex step.
-        use_calibration = False  # This would ideally be a configurable parameter
-        calib_dataloader = None  # This would need to be provided/configured
-
-        calibrate_loop = (
-            create_forward_loop(dataloader=calib_dataloader) if use_calibration else None
-        )
-
-        if use_calibration and calib_dataloader is None:
-            logger.warning(
-                "ModelOpt calibration requested but no calib_dataloader provided. "
-                "Proceeding without calibration. Quantization accuracy may be affected."
-            )
-
         logger.info(
             f"Quantizing model with ModelOpt using config attribute: mtq.{quant_cfg_name}"
         )
+        quantized_ckpt_restore_path = model_config.modelopt_checkpoint_restore_path
+        quantized_ckpt_save_path = model_config.modelopt_checkpoint_save_path
+        tokenizer = AutoTokenizer.from_pretrained(model_config.model_path, use_fast=True)
 
         try:
-            model = mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)
-            logger.info("Model successfully quantized with ModelOpt.")
+            self._setup_modelopt_quantization(
+                model,
+                tokenizer,
+                quant_cfg,
+                quantized_ckpt_restore_path=quantized_ckpt_restore_path,
+                quantized_ckpt_save_path=quantized_ckpt_save_path,
+            )
         except Exception as e:
-            logger.error(f"Error during ModelOpt mtq.quantize call: {e}")
-            raise
-
-        mtq.print_quant_summary(model)
+            logger.warning(f"ModelOpt quantization failed: {e}")
+            rank0_log("Proceeding without quantization...")
 
         return model.eval()
```
python/sglang/srt/server_args.py
```diff
@@ -178,6 +178,8 @@ class ServerArgs:
     model_loader_extra_config: str = "{}"
     trust_remote_code: bool = False
     modelopt_quant: Optional[Union[str, Dict]] = None
+    modelopt_checkpoint_restore_path: Optional[str] = None
+    modelopt_checkpoint_save_path: Optional[str] = None
     context_length: Optional[int] = None
     is_embedding: bool = False
     enable_multimodal: Optional[bool] = None
```
```diff
@@ -1504,6 +1506,21 @@ class ServerArgs:
             "Supported values: 'fp8', 'int4_awq', 'w4a8_awq', 'nvfp4', 'nvfp4_awq'. "
             "This requires the NVIDIA Model Optimizer library to be installed: pip install nvidia-modelopt",
         )
+        parser.add_argument(
+            "--modelopt-checkpoint-restore-path",
+            type=str,
+            default=ServerArgs.modelopt_checkpoint_restore_path,
+            help="Path to restore a previously saved ModelOpt quantized checkpoint. "
+            "If provided, the quantization process will be skipped and the model "
+            "will be loaded from this checkpoint.",
+        )
+        parser.add_argument(
+            "--modelopt-checkpoint-save-path",
+            type=str,
+            default=ServerArgs.modelopt_checkpoint_save_path,
+            help="Path to save the ModelOpt quantized checkpoint after quantization. "
+            "This allows reusing the quantized model in future runs.",
+        )
         parser.add_argument(
             "--kv-cache-dtype",
             type=str,
```
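Taken together, the new flags enable the quantize-once, restore-later flow end to end. A hypothetical configuration sketch (model and paths are placeholders; the dataclass fields correspond one-to-one to the CLI flags above):

```python
from sglang.srt.server_args import ServerArgs

# First launch: calibrate, quantize, and persist the ModelOpt state.
first_run = ServerArgs(
    model_path="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model
    modelopt_quant="fp8",
    modelopt_checkpoint_save_path="/ckpts/llama8b_fp8",
)

# Later launches: restore the saved state and skip calibration entirely.
later_run = ServerArgs(
    model_path="meta-llama/Llama-3.1-8B-Instruct",
    modelopt_quant="fp8",
    modelopt_checkpoint_restore_path="/ckpts/llama8b_fp8",
)
```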