change / sglang · Commit d6fee73d (Unverified)

Support nvidia/NVIDIA-Nemotron-Nano-9B-v2-FP8/NVFP4 (#11866)

Authored Oct 23, 2025 by Netanel Haber; committed by GitHub on Oct 23, 2025.
Parent: 36a4cad7

Showing 10 changed files with 207 additions and 127 deletions (+207 -127).
Files changed:
  python/sglang/srt/layers/quantization/modelopt_quant.py   +73 -60
  python/sglang/srt/model_loader/loader.py                    +4 -2
  python/sglang/srt/model_loader/weight_utils.py             +42 -29
  python/sglang/srt/models/nemotron_h.py                     +19 -22
  test/srt/layers/attention/mamba/test_causal_conv1d.py       +4 -0
  test/srt/layers/attention/mamba/test_mamba2_mixer.py        +5 -0
  test/srt/layers/attention/mamba/test_mamba_ssm.py           +5 -0
  test/srt/layers/attention/mamba/test_mamba_ssm_ssd.py      +34 -9
  test/srt/models/test_nvidia_nemotron_nano_v2.py            +16 -3
  test/srt/run_suite.py                                        +5 -2
python/sglang/srt/layers/quantization/modelopt_quant.py  (+73 -60)

@@ -90,7 +90,50 @@ CUTEDSL_MOE_NVFP4_DISPATCH = get_bool_env_var(
 ACTIVATION_SCHEMES = ["static"]


-class ModelOptFp8Config(QuantizationConfig):
+class ModelOptQuantConfig(QuantizationConfig):
+    def __init__(
+        self,
+        kv_cache_quant_algo: Optional[str],
+        exclude_modules: Optional[List[str]],
+        packed_modules_mapping: Optional[Dict[str, List[str]]],
+    ):
+        super().__init__()
+        self.packed_modules_mapping = packed_modules_mapping
+        self.exclude_modules = exclude_modules or []
+        self.kv_cache_quant_algo = kv_cache_quant_algo
+
+    def _get_quant_method(
+        self,
+        layer: torch.nn.Module,
+        prefix: str,
+        *,
+        Linear: type[LinearMethodBase],
+        Moe: type[FusedMoEMethodBase],
+    ) -> Optional[QuantizeMethodBase]:
+        from sglang.srt.layers.linear import LinearBase
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+
+        if isinstance(layer, LinearBase):
+            if is_layer_skipped(
+                prefix, self.exclude_modules, self.packed_modules_mapping
+            ) or self.is_layer_excluded(prefix):
+                return UnquantizedLinearMethod()
+            return Linear(self)
+        elif self.kv_cache_quant_algo and isinstance(layer, RadixAttention):
+            return ModelOptFp8KVCacheMethod(self)
+        elif isinstance(layer, FusedMoE):
+            return Moe(self)
+        return None
+
+    @classmethod
+    def get_config_filenames(cls) -> List[str]:
+        return ["hf_quant_config.json"]
+
+    def get_scaled_act_names(self) -> List[str]:
+        return []
+
+
+class ModelOptFp8Config(ModelOptQuantConfig):
     """Configuration for ModelOpt FP8 quantization, including serialization and compatibility checks."""

     def __init__(

@@ -98,14 +141,14 @@ class ModelOptFp8Config(QuantizationConfig):
         is_checkpoint_fp8_serialized: bool = False,
         kv_cache_quant_method: Optional[str] = None,
         exclude_modules: Optional[List[str]] = None,
+        packed_modules_mapping: Optional[Dict[str, List[str]]] = None,
     ) -> None:
         """
         Args:
             is_checkpoint_fp8_serialized (bool): Indicates if the checkpoint uses serialized FP8 format.
         """
+        super().__init__(kv_cache_quant_method, exclude_modules, packed_modules_mapping)
         self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
-        self.kv_cache_quant_method = kv_cache_quant_method
-        self.exclude_modules = exclude_modules
         if is_checkpoint_fp8_serialized:
             logger.warning(
                 "Detected ModelOpt FP8 checkpoint. The format is experimental and subject to change."

@@ -128,10 +171,6 @@ class ModelOptFp8Config(QuantizationConfig):
     def get_min_capability(cls) -> int:
         return 89  # Minimum hardware capability (e.g., Hopper GPUs).

-    @classmethod
-    def get_config_filenames(cls) -> List[str]:
-        return ["hf_quant_config.json"]
-
     @classmethod
     def from_config(cls, config: Dict[str, Any]) -> ModelOptFp8Config:
         # Handle two different config formats:

@@ -186,37 +225,27 @@ class ModelOptFp8Config(QuantizationConfig):
             is_checkpoint_fp8_serialized=True,
             kv_cache_quant_method=kv_cache_quant_method,
             exclude_modules=exclude_modules,
+            packed_modules_mapping=config.get("packed_modules_mapping"),
         )

-    def get_quant_method(
-        self, layer: torch.nn.Module, prefix: str
-    ) -> Optional[QuantizeMethodBase]:
-        from sglang.srt.layers.linear import LinearBase
-        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
-
-        if self.exclude_modules and any(
-            module in prefix
-            or (
-                prefix.startswith("language_model.")
-                and module in prefix.removeprefix("language_model.")
-            )
-            for module in self.exclude_modules
-        ):
-            return None
-
-        if isinstance(layer, LinearBase):
-            return ModelOptFp8LinearMethod(self)
-        if self.kv_cache_quant_method and isinstance(layer, RadixAttention):
-            return ModelOptFp8KVCacheMethod(self)
-        if isinstance(layer, FusedMoE):
-            return ModelOptFp8MoEMethod(self)
-        return None
-
-    def get_scaled_act_names(self) -> List[str]:
-        return []
+    def is_layer_excluded(self, prefix: str) -> bool:
+        if len(self.exclude_modules) == 0:
+            return False
+        return any(
+            module in prefix
+            or (
+                prefix.startswith("language_model.")
+                and module in prefix.removeprefix("language_model.")
+            )
+            for module in self.exclude_modules
+        )
+
+    def get_quant_method(
+        self, layer: torch.nn.Module, prefix: str
+    ) -> Optional[QuantizeMethodBase]:
+        return self._get_quant_method(
+            layer, prefix, Linear=ModelOptFp8LinearMethod, Moe=ModelOptFp8MoEMethod
+        )


 class ModelOptFp8LinearMethod(LinearMethodBase):

@@ -512,7 +541,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
         return self.runner.run(dispatch_output, quant_info)


-class ModelOptFp4Config(QuantizationConfig):
+class ModelOptFp4Config(ModelOptQuantConfig):
     """Config class for FP4."""

     def __init__(

@@ -521,7 +550,9 @@ class ModelOptFp4Config(QuantizationConfig):
         kv_cache_quant_algo: str = None,
         group_size: int = None,
         exclude_modules: List[str] = None,
+        packed_modules_mapping: Optional[Dict[str, List[str]]] = None,
     ) -> None:
+        super().__init__(kv_cache_quant_algo, exclude_modules, packed_modules_mapping)
         self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized
         if is_checkpoint_nvfp4_serialized:
             logger.warning(

@@ -529,8 +560,6 @@ class ModelOptFp4Config(QuantizationConfig):
                 "format is experimental and subject to change."
             )
         self.group_size = group_size
-        self.kv_cache_quant_algo = kv_cache_quant_algo
-        self.exclude_modules = exclude_modules

     @classmethod
     def override_quantization_method(cls, hf_quant_config, user_quant):

@@ -549,10 +578,6 @@ class ModelOptFp4Config(QuantizationConfig):
     def get_min_capability(cls) -> int:
         return 100

-    @classmethod
-    def get_config_filenames(cls) -> List[str]:
-        return ["hf_quant_config.json"]
-
     @staticmethod
     def common_group_size(cfg: dict) -> int:
         """Return the unique group_size across the config; raise if missing/mismatched."""

@@ -668,14 +693,15 @@ class ModelOptFp4Config(QuantizationConfig):
             kv_cache_quant_algo,
             group_size,
             exclude_modules,
+            config.get("packed_modules_mapping"),
         )

-    def is_layer_excluded(self, prefix: str, exclude_modules: list):
+    def is_layer_excluded(self, prefix: str):
         import regex as re

         fused_patterns = ["q_a_proj", "q_b_proj", "kv_a_proj_with_mqa", "kv_b_proj"]
         prefix_split = prefix.split(".")
-        for pattern in exclude_modules:
+        for pattern in self.exclude_modules:
             regex_str = pattern.replace(".", r"\.").replace("*", r".*")
             pattern_split = pattern.split(".")
             if re.fullmatch(regex_str, prefix):

@@ -691,30 +717,17 @@ class ModelOptFp4Config(QuantizationConfig):
                 return True
         return False

     def get_quant_method(
-        self, layer: torch.nn.Module, prefix: str):
-        from sglang.srt.layers.linear import LinearBase
-        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
-        from sglang.srt.layers.moe.fused_moe_triton.layer import FlashInferFP4MoE
-
-        if isinstance(layer, LinearBase):
-            if is_layer_skipped(prefix, self.exclude_modules) or self.is_layer_excluded(
-                prefix, self.exclude_modules
-            ):
-                return UnquantizedLinearMethod()
-            return ModelOptFp4LinearMethod(self)
-        if self.kv_cache_quant_algo and isinstance(layer, RadixAttention):
-            return ModelOptFp8KVCacheMethod(self)
-        elif isinstance(layer, FlashInferFP4MoE):
-            # FlashInferFP4MoE needs the same quantization method but with compatible attribute handling
-            return ModelOptNvFp4FusedMoEMethod(self)
-        elif isinstance(layer, FusedMoE):
-            return ModelOptNvFp4FusedMoEMethod(self)
-        return None
-
-    def get_scaled_act_names(self) -> List[str]:
-        return []
+        self, layer: torch.nn.Module, prefix: str
+    ) -> Optional[QuantizeMethodBase]:
+        from sglang.srt.layers.moe.fused_moe_triton.layer import FlashInferFP4MoE
+
+        Moe = (
+            FlashInferFP4MoE
+            # FlashInferFP4MoE needs the same quantization method but with compatible attribute handling
+            if isinstance(layer, FlashInferFP4MoE)
+            else ModelOptNvFp4FusedMoEMethod
+        )
+        return self._get_quant_method(
+            layer, prefix, Linear=ModelOptFp4LinearMethod, Moe=Moe
+        )


 class ModelOptFp4LinearMethod(LinearMethodBase):
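The main idea of this file is that the FP8 and NVFP4 configs now share a ModelOptQuantConfig base whose _get_quant_method does the layer-type dispatch, while each subclass only plugs in its own Linear/Moe method classes. A minimal, self-contained sketch of that dispatch pattern follows; the layer and method classes below are placeholders for illustration only, not the real sglang types.

    from typing import Optional, Type

    class LinearLayer: ...   # stand-in for sglang's LinearBase layers
    class MoELayer: ...      # stand-in for FusedMoE layers

    class QuantMethod:
        def __init__(self, config): self.config = config

    class Fp8Linear(QuantMethod): ...
    class Fp8Moe(QuantMethod): ...

    class BaseQuantConfig:
        def __init__(self, exclude_modules=None):
            self.exclude_modules = exclude_modules or []

        def is_layer_excluded(self, prefix: str) -> bool:
            return any(m in prefix for m in self.exclude_modules)

        def _get_quant_method(
            self, layer, prefix: str, *, Linear: Type[QuantMethod], Moe: Type[QuantMethod]
        ) -> Optional[QuantMethod]:
            # Shared dispatch: subclasses only decide which method classes to plug in.
            if isinstance(layer, LinearLayer):
                if self.is_layer_excluded(prefix):
                    return None  # stand-in for returning UnquantizedLinearMethod()
                return Linear(self)
            if isinstance(layer, MoELayer):
                return Moe(self)
            return None

    class Fp8QuantConfig(BaseQuantConfig):
        def get_quant_method(self, layer, prefix: str):
            return self._get_quant_method(layer, prefix, Linear=Fp8Linear, Moe=Fp8Moe)

    cfg = Fp8QuantConfig(exclude_modules=["lm_head"])
    print(type(cfg.get_quant_method(LinearLayer(), "model.layers.0.qkv_proj")).__name__)  # Fp8Linear
    print(cfg.get_quant_method(LinearLayer(), "lm_head"))  # None (excluded layer stays unquantized)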
python/sglang/srt/model_loader/loader.py  (+4 -2)

@@ -180,11 +180,12 @@ def _get_quantization_config(
     model_config: ModelConfig,
     load_config: LoadConfig,
     packed_modules_mapping: Dict[str, List[str]],
+    remap_prefix: Dict[str, str] | None = None,
 ) -> Optional[QuantizationConfig]:
     """Get the quantization config."""
     if model_config.quantization is not None:
         quant_config = get_quant_config(
-            model_config, load_config, packed_modules_mapping
+            model_config, load_config, packed_modules_mapping, remap_prefix
         )
         # (yizhang2077) workaround for nvidia/Llama-4-Maverick-17B-128E-Eagle3
         if quant_config is None:

@@ -220,6 +221,7 @@ def _initialize_model(
     """Initialize a model with the given configurations."""
     model_class, _ = get_model_architecture(model_config)
     packed_modules_mapping = getattr(model_class, "packed_modules_mapping", {})
+    remap_prefix = getattr(model_class, "remap_prefix", None)
     if _is_npu:
         packed_modules_mapping.update(
             {

@@ -243,7 +245,7 @@ def _initialize_model(
     )
     quant_config = _get_quantization_config(
-        model_config, load_config, packed_modules_mapping
+        model_config, load_config, packed_modules_mapping, remap_prefix
     )
     # Build kwargs conditionally
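The loader picks up the new attribute the same way it already picks up packed_modules_mapping: it reads an optional class attribute off the model class and forwards it, so architectures that do not define remap_prefix are unaffected. A tiny sketch of that getattr plumbing, with made-up model classes standing in for the real ones:

    class FakeNemotronH:
        packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}
        remap_prefix = {"backbone": "model"}

    class FakeLlama:
        pass  # defines neither attribute

    for model_class in (FakeNemotronH, FakeLlama):
        packed = getattr(model_class, "packed_modules_mapping", {})  # default: empty mapping
        remap = getattr(model_class, "remap_prefix", None)           # default: no remapping
        print(model_class.__name__, packed, remap)
    # FakeNemotronH {'qkv_proj': ['q_proj', 'k_proj', 'v_proj']} {'backbone': 'model'}
    # FakeLlama {} None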
python/sglang/srt/model_loader/weight_utils.py  (+42 -29)

@@ -37,7 +37,10 @@ from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.distributed import get_tensor_model_parallel_rank
 from sglang.srt.layers.dp_attention import get_attention_tp_rank
 from sglang.srt.layers.quantization import QuantizationConfig, get_quantization_config
-from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp4Config
+from sglang.srt.layers.quantization.modelopt_quant import (
+    ModelOptFp4Config,
+    ModelOptFp8Config,
+)
 from sglang.srt.utils import find_local_repo_dir, log_info_on_rank0, print_warning_once
 from sglang.utils import is_in_ci

@@ -135,11 +138,26 @@ def convert_bin_to_safetensor_file(
         raise RuntimeError(f"The output tensors do not match for key {k}")


+def replace_prefix(key: str, prefix_mapping: dict[str, str]) -> str:
+    for prefix, new_prefix in prefix_mapping.items():
+        if key.startswith(prefix):
+            key = key.replace(prefix, new_prefix, 1)
+    return key
+
+
+def replace_substrings(key: str, substring_mapping: dict[str, str]) -> str:
+    for substr, new_substr in substring_mapping.items():
+        if substr in key:
+            key = key.replace(substr, new_substr)
+    return key
+
+
 # TODO(woosuk): Move this to other place.
 def get_quant_config(
     model_config: ModelConfig,
     load_config: LoadConfig,
     packed_modules_mapping: Dict[str, List[str]],
+    remap_prefix: Dict[str, str] | None = None,
 ) -> QuantizationConfig:
     quant_cls = get_quantization_config(model_config.quantization)

@@ -209,37 +227,32 @@ def get_quant_config(
     quant_config_file = quant_config_files[0]
     with open(quant_config_file) as f:
         config = json.load(f)

+        if remap_prefix is not None:
+            exclude_modules = [
+                replace_prefix(key, remap_prefix)
+                for key in config["quantization"]["exclude_modules"]
+            ]
+            config["quantization"]["exclude_modules"] = exclude_modules
+        config["packed_modules_mapping"] = packed_modules_mapping
         if model_config.quantization == "bitsandbytes":
             config["adapter_name_or_path"] = model_name_or_path
-        elif model_config.quantization == "modelopt":
-            if config["producer"]["name"] == "modelopt":
-                # (yizhang2077) workaround for nvidia/Llama-4-Maverick-17B-128E-Eagle3
-                if config["quantization"]["quant_algo"] is None:
-                    if (
-                        model_config.hf_config.architectures[0]
-                        != "LlamaForCausalLMEagle3"
-                    ):
-                        raise ValueError(
-                            f"Invalid quant_config, quantization method: {model_config.quantization},"
-                            f"hf architectures: {model_config.hf_config.architectures[0]}. "
-                        )
-                    return None
-                if "FP4" in config["quantization"]["quant_algo"]:
-                    return ModelOptFp4Config.from_config(config)
-                else:
-                    return quant_cls.from_config(config)
-        elif model_config.quantization == "modelopt_fp8":
-            if config["producer"]["name"] == "modelopt_fp8":
-                return quant_cls.from_config(config)
-            else:
-                raise ValueError(
-                    f"Unsupported quantization config"
-                    f" found for {model_config.quantization} in {f}."
-                )
-        elif model_config.quantization == "w8a8_int8":
-            config["packed_modules_mapping"] = packed_modules_mapping
+        elif model_config.quantization.startswith("modelopt") and (
+            config["producer"]["name"].startswith("modelopt")
+        ):
+            quant_algo = config["quantization"]["quant_algo"]
+            if quant_algo is None:
+                # (yizhang2077) workaround for nvidia/Llama-4-Maverick-17B-128E-Eagle3
+                if model_config.hf_config.architectures[0] != "LlamaForCausalLMEagle3":
+                    raise ValueError(
+                        f"Invalid quant_config, quantization method: {model_config.quantization},"
+                        f"hf architectures: {model_config.hf_config.architectures[0]}. "
+                    )
+                return None
+            elif quant_algo == "FP8" or model_config.quantization == "modelopt_fp8":
+                return ModelOptFp8Config.from_config(config)
+            elif "FP4" in quant_algo:
+                return ModelOptFp4Config.from_config(config)

     return quant_cls.from_config(config)
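The two new helpers let the quant config loader translate the exclude_modules names written in hf_quant_config.json (which use checkpoint-style prefixes) into the runtime module names. A quick usage sketch; the helper bodies match the diff above, and the example keys and the backbone-to-model mapping mirror the Nemotron-H case but are otherwise illustrative.

    def replace_prefix(key: str, prefix_mapping: dict[str, str]) -> str:
        # Rewrite only a leading prefix, at most once per mapping entry.
        for prefix, new_prefix in prefix_mapping.items():
            if key.startswith(prefix):
                key = key.replace(prefix, new_prefix, 1)
        return key

    def replace_substrings(key: str, substring_mapping: dict[str, str]) -> str:
        # Rewrite every occurrence of each mapped substring.
        for substr, new_substr in substring_mapping.items():
            if substr in key:
                key = key.replace(substr, new_substr)
        return key

    exclude_modules = ["backbone.layers.0.mixer", "lm_head"]
    print([replace_prefix(k, {"backbone": "model"}) for k in exclude_modules])
    # ['model.layers.0.mixer', 'lm_head']
    print(replace_substrings("model.embeddings.weight", {"embeddings": "embed_tokens"}))
    # model.embed_tokens.weight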
python/sglang/srt/models/nemotron_h.py  (+19 -22)

@@ -48,6 +48,8 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTe
 from sglang.srt.model_loader.weight_utils import (
     default_weight_loader,
     maybe_remap_kv_scale_name,
+    replace_prefix,
+    replace_substrings,
 )
 from sglang.srt.utils import add_prefix, make_layers_non_pp
 from sglang.utils import logger

@@ -155,6 +157,7 @@ class NemotronHMambaDecoderLayer(nn.Module):
             rms_norm_eps=config.rms_norm_eps,
             activation=config.mamba_hidden_act,
             quant_config=quant_config,
+            prefix=f"{prefix}.mixer",
         )
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

@@ -381,16 +384,19 @@ class NemotronHModel(nn.Module):
 class NemotronHForCausalLM(nn.Module):
+    stacked_params_mapping = [
+        # (param_name, shard_name, shard_id)
+        ("qkv_proj", "q_proj", "q"),
+        ("qkv_proj", "k_proj", "k"),
+        ("qkv_proj", "v_proj", "v"),
+    ]
+    packed_modules_mapping = {
+        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
+    }
     remap_prefix = {"backbone": "model"}
     remap_substr = {"A_log": "A", "embeddings": "embed_tokens"}

-    # LoRA specific attributes
-    embedding_modules = {
-        "embed_tokens": "input_embeddings",
-        "lm_head": "output_embeddings",
-    }
-    embedding_padding_modules = ["lm_head"]
-
     def __init__(
         self,
         *,

@@ -432,7 +438,9 @@ class NemotronHForCausalLM(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ):
-        return NemotronHModel(config=config, quant_config=quant_config, prefix=prefix)
+        return NemotronHModel(
+            config=config, quant_config=quant_config, prefix=add_prefix("model", prefix)
+        )

     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.get_input_embeddings(input_ids)

@@ -460,21 +468,10 @@ class NemotronHForCausalLM(nn.Module):
         return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)

     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> None:
-        stacked_params_mapping = [
-            # (param_name, shard_name, shard_id)
-            ("qkv_proj", "q_proj", "q"),
-            ("qkv_proj", "k_proj", "k"),
-            ("qkv_proj", "v_proj", "v"),
-        ]
         updated_weights = []
         for name, loaded_weight in weights:
-            for prefix, new_key in self.remap_prefix.items():
-                if name.startswith(prefix):
-                    name = name.replace(prefix, new_key)
-            for substr, new_key in self.remap_substr.items():
-                if substr in name:
-                    name = name.replace(substr, new_key)
+            name = replace_prefix(name, self.remap_prefix)
+            name = replace_substrings(name, self.remap_substr)
             updated_weights.append((name, loaded_weight))

         params_dict = dict(self.named_parameters())

@@ -484,7 +481,7 @@ class NemotronHForCausalLM(nn.Module):
             if name is None:
                 continue
-            for param_name, weight_name, shard_id in stacked_params_mapping:
+            for param_name, weight_name, shard_id in self.stacked_params_mapping:
                 if weight_name not in name:
                     continue
                 name = name.replace(weight_name, param_name)
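With stacked_params_mapping hoisted to the class, load_weights remaps checkpoint names through the shared helpers and then routes q/k/v shards into the fused qkv_proj parameter. A minimal sketch of that renaming loop; the weight names are made up and no real tensors or parameters are involved.

    stacked_params_mapping = [
        # (param_name, shard_name, shard_id)
        ("qkv_proj", "q_proj", "q"),
        ("qkv_proj", "k_proj", "k"),
        ("qkv_proj", "v_proj", "v"),
    ]

    checkpoint_names = [
        "model.layers.1.self_attn.q_proj.weight",
        "model.layers.1.self_attn.k_proj.weight",
        "model.layers.1.mixer.in_proj.weight",
    ]

    for name in checkpoint_names:
        for param_name, weight_name, shard_id in stacked_params_mapping:
            if weight_name not in name:
                continue
            # the fused parameter receives this checkpoint tensor as one shard
            fused_name = name.replace(weight_name, param_name)
            print(f"{fused_name} <- {name} (shard {shard_id})")
            break
        else:
            print(f"{name} loaded as-is")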
test/srt/layers/attention/mamba/test_causal_conv1d.py  (+4 -0)

@@ -373,3 +373,7 @@ def test_causal_conv1d_varlen(
     )
     unpadded_out = out[:, : out_ref_tensor.shape[-1]]
     assert torch.allclose(unpadded_out, out_ref_tensor, rtol=rtol, atol=atol)
+
+
+if __name__ == "__main__":
+    pytest.main([__file__])
test/srt/layers/attention/mamba/test_mamba2_mixer.py  (+5 -0)

 # Adapted from https://github.com/vllm-project/vllm/blob/2c58742dff8613a3bd7496f2008ce927e18d38d1/tests/kernels/mamba/test_mamba_mixer2.py
 from unittest.mock import patch

 import pytest
...
@@ -136,3 +137,7 @@ def mixer2_gated_norm_tensor_parallel(
         atol=5e-3,
         rtol=1e-3,
     )
+
+
+if __name__ == "__main__":
+    pytest.main([__file__])
test/srt/layers/attention/mamba/test_mamba_ssm.py  (+5 -0)

 # Adapted from https://github.com/vllm-project/vllm/blob/633f943e30a4444d890d26b81850f7217736f840/tests/kernels/mamba/test_mamba_ssm_ssd.py
 import pytest
 import torch
 import torch.nn.functional as F
...
@@ -289,3 +290,7 @@ def test_selective_state_update_with_heads_with_batch_indices(
     print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
     assert torch.allclose(state[state_indices, :], state_ref, rtol=rtol, atol=atol)
     assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
+
+
+if __name__ == "__main__":
+    pytest.main([__file__])
test/srt/layers/attention/mamba/test_mamba_ssm_ssd.py  (+34 -9)

@@ -8,13 +8,12 @@ from einops import rearrange, repeat
 from sglang.srt.layers.attention.mamba.mamba2_metadata import Mamba2Metadata
 from sglang.srt.layers.attention.mamba.ops import mamba_chunk_scan_combined
+from sglang.utils import is_in_ci

 # Added by the IBM Team, 2024
 # Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/modules/ssd_minimal.py
-# TODO: These take a long time to run - we should cut down on some of the parameterized matrix.

 # this is the segsum implementation taken from above
 def segsum(x):

@@ -191,10 +190,22 @@ def generate_continuous_batched_examples(
     )


-@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
-@pytest.mark.parametrize("n_heads", [3, 4, 11, 16, 32])
-@pytest.mark.parametrize("d_head", [5, 8, 19, 32, 128])
-@pytest.mark.parametrize("seq_len_chunk_size", [(112, 16), (128, 32)])
+SINGLE_ITYPE = [torch.float32, torch.float16, torch.bfloat16]
+SINGLE_NHEADS = [3, 4, 11, 16, 32]
+SINGLE_DHEAD = [5, 8, 19, 32, 128]
+SINGLE_SEQ_LEN_CHUNK_SIZE = [(112, 16), (128, 32)]
+if is_in_ci():
+    SINGLE_ITYPE = [torch.float32, torch.bfloat16]
+    SINGLE_NHEADS = [3, 32]
+    SINGLE_DHEAD = [5, 128]
+    SINGLE_SEQ_LEN_CHUNK_SIZE = [(112, 16)]
+
+
+@pytest.mark.parametrize("itype", SINGLE_ITYPE)
+@pytest.mark.parametrize("n_heads", SINGLE_NHEADS)
+@pytest.mark.parametrize("d_head", SINGLE_DHEAD)
+@pytest.mark.parametrize("seq_len_chunk_size", SINGLE_SEQ_LEN_CHUNK_SIZE)
 def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size, itype):
     if not torch.cuda.is_available():
         pytest.skip("CUDA device not available")

@@ -238,9 +249,19 @@ def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size, it
     )


-@pytest.mark.parametrize("itype", [torch.float32, torch.float16])
-@pytest.mark.parametrize("n_heads", [4, 8, 13])
-@pytest.mark.parametrize("d_head", [5, 16, 21, 32])
+BATCHED_ITYPE = [torch.float32, torch.float16]
+BATCHED_NHEADS = [4, 8, 13]
+BATCHED_DHEAD = [5, 16, 21, 32]
+if is_in_ci():
+    BATCHED_ITYPE = [torch.float32]
+    BATCHED_NHEADS = [4, 13]
+    BATCHED_DHEAD = [5, 32]
+
+
+@pytest.mark.parametrize("itype", BATCHED_ITYPE)
+@pytest.mark.parametrize("n_heads", BATCHED_NHEADS)
+@pytest.mark.parametrize("d_head", BATCHED_DHEAD)
 @pytest.mark.parametrize(
     "seq_len_chunk_size_cases",
     [

@@ -579,3 +600,7 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
                 rtol=rtol,
                 msg=lambda x: f"seq{i} state " + x,
             )  # noqa: B023
+
+
+if __name__ == "__main__":
+    pytest.main([__file__])
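The pattern here is to keep the full parameter grids for local runs and shrink them when the suite detects CI, which is what replaces the old "these take a long time" TODO. A standalone sketch of the same idea; is_in_ci below is a stand-in that reads an environment variable rather than the real sglang.utils.is_in_ci.

    import os

    import pytest
    import torch

    def is_in_ci() -> bool:
        # stand-in for sglang.utils.is_in_ci
        return os.environ.get("CI", "").lower() in ("1", "true")

    ITYPE = [torch.float32, torch.float16, torch.bfloat16]
    NHEADS = [3, 4, 11, 16, 32]
    if is_in_ci():
        # keep only the boundary cases to cut CI time
        ITYPE = [torch.float32, torch.bfloat16]
        NHEADS = [3, 32]

    @pytest.mark.parametrize("itype", ITYPE)
    @pytest.mark.parametrize("n_heads", NHEADS)
    def test_shapes(itype, n_heads):
        x = torch.zeros(n_heads, dtype=itype)
        assert x.shape[0] == n_heads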
test/srt/models/test_nvidia_nemotron_nano_v2.py  (+16 -3)

 import unittest
 from types import SimpleNamespace

-from sglang.srt.utils import kill_process_tree
+from sglang.srt.utils import is_blackwell, kill_process_tree
 from sglang.test.few_shot_gsm8k import run_eval
 from sglang.test.test_utils import (
     DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,

@@ -12,9 +12,11 @@ from sglang.test.test_utils import (
 class TestNvidiaNemotronNanoV2(CustomTestCase):
+    model = "nvidia/NVIDIA-Nemotron-Nano-9B-v2"
+    accuracy = 0.87
+
     @classmethod
     def setUpClass(cls):
-        cls.model = "nvidia/NVIDIA-Nemotron-Nano-9B-v2"
         cls.base_url = DEFAULT_URL_FOR_TEST
         cls.process = popen_launch_server(
             cls.model,

@@ -42,7 +44,18 @@ class TestNvidiaNemotronNanoV2(CustomTestCase):
         )

         metrics = run_eval(args)
         print(f"{metrics=}")
-        self.assertGreater(metrics["accuracy"], 0.87)
+        self.assertGreaterEqual(metrics["accuracy"], self.accuracy)
+
+
+class TestNvidiaNemotronNanoV2FP8(TestNvidiaNemotronNanoV2):
+    accuracy = 0.87
+    model = "nvidia/NVIDIA-Nemotron-Nano-9B-v2-FP8"
+
+
+@unittest.skipIf(not is_blackwell(), "NVFP4 only supported on blackwell")
+class TestNvidiaNemotronNanoV2NVFP4(TestNvidiaNemotronNanoV2):
+    accuracy = 0.855
+    model = "nvidia/NVIDIA-Nemotron-Nano-9B-v2-NVFP4"


 if __name__ == "__main__":
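The FP8 and NVFP4 variants reuse the BF16 test by overriding class attributes: setUpClass launches whatever cls.model names, the GSM8K check compares against cls.accuracy, and the NVFP4 case is skipped off Blackwell hardware. A stripped-down sketch of the same subclassing pattern; it launches no server, the measured score is a placeholder, and IS_BLACKWELL stands in for sglang.srt.utils.is_blackwell().

    import unittest

    IS_BLACKWELL = False  # stand-in for sglang.srt.utils.is_blackwell()

    class TestBase(unittest.TestCase):
        model = "nvidia/NVIDIA-Nemotron-Nano-9B-v2"
        accuracy = 0.87

        def test_threshold(self):
            measured = 0.88  # placeholder for the GSM8K eval result on self.model
            self.assertGreaterEqual(measured, self.accuracy)

    class TestFP8(TestBase):
        model = "nvidia/NVIDIA-Nemotron-Nano-9B-v2-FP8"
        accuracy = 0.87

    @unittest.skipIf(not IS_BLACKWELL, "NVFP4 only supported on blackwell")
    class TestNVFP4(TestBase):
        model = "nvidia/NVIDIA-Nemotron-Nano-9B-v2-NVFP4"
        accuracy = 0.855

    if __name__ == "__main__":
        unittest.main()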
test/srt/run_suite.py  (+5 -2)

@@ -19,6 +19,9 @@ suites = {
     TestFile("hicache/test_hicache_eagle.py", 150),
     TestFile("hicache/test_hicache_mla.py", 127),
     TestFile("hicache/test_hicache_storage.py", 127),
+    TestFile("layers/attention/mamba/test_causal_conv1d.py", 25),
+    TestFile("layers/attention/mamba/test_mamba_ssm.py", 50),
+    TestFile("layers/attention/mamba/test_mamba_ssm_ssd.py", 70),
     TestFile("lora/test_lora.py", 200),
     TestFile("lora/test_lora_eviction.py", 200),
     TestFile("lora/test_lora_eviction_policy.py", 200),

@@ -34,7 +37,7 @@ suites = {
     TestFile("models/test_embedding_models.py", 73),
     TestFile("models/test_encoder_embedding_models.py", 460),
     TestFile("models/test_generation_models.py", 103),
-    TestFile("models/test_nvidia_nemotron_nano_v2.py", 180),
+    TestFile("models/test_nvidia_nemotron_nano_v2.py", 300),
     TestFile("models/test_qwen_models.py", 82),
     TestFile("batch_invariant/test_batch_invariant_ops.py", 10),
     TestFile("models/test_reward_models.py", 132),

@@ -143,7 +146,7 @@ suites = {
     TestFile("hicache/test_hicache_storage_3fs_backend.py", 200),
     TestFile("hicache/test_hicache_storage_file_backend.py", 200),
     TestFile("hicache/test_hicache_storage_mooncake_backend.py", 400),
-    TestFile("layers/attention/mamba/test_mamba2_mixer.py", 110),
+    TestFile("layers/attention/mamba/test_mamba2_mixer.py", 50),
     TestFile("lora/test_lora_tp.py", 116),
     TestFile("models/test_glm4_moe_models.py", 100),
     TestFile("rl/test_update_weights_from_distributed.py", 103),