sglang · Commits · 6c18addb

Unverified commit 6c18addb, authored Oct 23, 2025 by Liangsheng Yin, committed by GitHub on Oct 23, 2025.

Revert "Support nvidia/NVIDIA-Nemotron-Nano-9B-v2-FP8/NVFP4" (#12015)

parent 32852fe9
Showing 10 changed files with 127 additions and 207 deletions (+127 −207).
python/sglang/srt/layers/quantization/modelopt_quant.py   +60 −73
python/sglang/srt/model_loader/loader.py                    +2 −4
python/sglang/srt/model_loader/weight_utils.py             +29 −42
python/sglang/srt/models/nemotron_h.py                     +22 −19
test/srt/layers/attention/mamba/test_causal_conv1d.py       +0 −4
test/srt/layers/attention/mamba/test_mamba2_mixer.py        +0 −5
test/srt/layers/attention/mamba/test_mamba_ssm.py           +0 −5
test/srt/layers/attention/mamba/test_mamba_ssm_ssd.py       +9 −34
test/srt/models/test_nvidia_nemotron_nano_v2.py             +3 −16
test/srt/run_suite.py                                        +2 −5
python/sglang/srt/layers/quantization/modelopt_quant.py  (+60 −73)

@@ -90,50 +90,7 @@ CUTEDSL_MOE_NVFP4_DISPATCH = get_bool_env_var(
 ACTIVATION_SCHEMES = ["static"]
 
-class ModelOptQuantConfig(QuantizationConfig):
-    def __init__(
-        self,
-        kv_cache_quant_algo: Optional[str],
-        exclude_modules: Optional[List[str]],
-        packed_modules_mapping: Optional[Dict[str, List[str]]],
-    ):
-        super().__init__()
-        self.packed_modules_mapping = packed_modules_mapping
-        self.exclude_modules = exclude_modules or []
-        self.kv_cache_quant_algo = kv_cache_quant_algo
-
-    def _get_quant_method(
-        self,
-        layer: torch.nn.Module,
-        prefix: str,
-        *,
-        Linear: type[LinearMethodBase],
-        Moe: type[FusedMoEMethodBase],
-    ) -> Optional[QuantizeMethodBase]:
-        from sglang.srt.layers.linear import LinearBase
-        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
-
-        if isinstance(layer, LinearBase):
-            if is_layer_skipped(
-                prefix, self.exclude_modules, self.packed_modules_mapping
-            ) or self.is_layer_excluded(prefix):
-                return UnquantizedLinearMethod()
-            return Linear(self)
-        elif self.kv_cache_quant_algo and isinstance(layer, RadixAttention):
-            return ModelOptFp8KVCacheMethod(self)
-        elif isinstance(layer, FusedMoE):
-            return Moe(self)
-        return None
-
-    @classmethod
-    def get_config_filenames(cls) -> List[str]:
-        return ["hf_quant_config.json"]
-
-    def get_scaled_act_names(self) -> List[str]:
-        return []
-
-
-class ModelOptFp8Config(ModelOptQuantConfig):
+class ModelOptFp8Config(QuantizationConfig):
     """Configuration for ModelOpt FP8 quantization, including serialization and compatibility checks."""
 
     def __init__(

@@ -141,14 +98,14 @@ class ModelOptFp8Config(ModelOptQuantConfig):
         is_checkpoint_fp8_serialized: bool = False,
         kv_cache_quant_method: Optional[str] = None,
         exclude_modules: Optional[List[str]] = None,
-        packed_modules_mapping: Optional[Dict[str, List[str]]] = None,
     ) -> None:
         """
         Args:
            is_checkpoint_fp8_serialized (bool): Indicates if the checkpoint uses serialized FP8 format.
         """
-        super().__init__(
-            kv_cache_quant_method, exclude_modules, packed_modules_mapping
-        )
         self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
+        self.kv_cache_quant_method = kv_cache_quant_method
+        self.exclude_modules = exclude_modules
         if is_checkpoint_fp8_serialized:
             logger.warning(
                 "Detected ModelOpt FP8 checkpoint. The format is experimental and subject to change."

@@ -171,6 +128,10 @@ class ModelOptFp8Config(ModelOptQuantConfig):
     def get_min_capability(cls) -> int:
         return 89  # Minimum hardware capability (e.g., Hopper GPUs).
 
+    @classmethod
+    def get_config_filenames(cls) -> List[str]:
+        return ["hf_quant_config.json"]
+
     @classmethod
     def from_config(cls, config: Dict[str, Any]) -> ModelOptFp8Config:
         # Handle two different config formats:

@@ -225,27 +186,37 @@ class ModelOptFp8Config(ModelOptQuantConfig):
             is_checkpoint_fp8_serialized=True,
             kv_cache_quant_method=kv_cache_quant_method,
             exclude_modules=exclude_modules,
-            packed_modules_mapping=config.get("packed_modules_mapping"),
         )
 
-    def is_layer_excluded(self, prefix: str) -> bool:
-        if len(self.exclude_modules) == 0:
-            return False
-        return any(
+    def get_quant_method(
+        self, layer: torch.nn.Module, prefix: str
+    ) -> Optional[QuantizeMethodBase]:
+        from sglang.srt.layers.linear import LinearBase
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+
+        if self.exclude_modules and any(
             module in prefix
             or (
                 prefix.startswith("language_model.")
                 and module in prefix.removeprefix("language_model.")
            )
             for module in self.exclude_modules
-        )
+        ):
+            return None
 
-    def get_quant_method(
-        self, layer: torch.nn.Module, prefix: str
-    ) -> Optional[QuantizeMethodBase]:
-        return self._get_quant_method(
-            layer, prefix, Linear=ModelOptFp8LinearMethod, Moe=ModelOptFp8MoEMethod
-        )
+        if isinstance(layer, LinearBase):
+            return ModelOptFp8LinearMethod(self)
+        if self.kv_cache_quant_method and isinstance(layer, RadixAttention):
+            return ModelOptFp8KVCacheMethod(self)
+        if isinstance(layer, FusedMoE):
+            return ModelOptFp8MoEMethod(self)
+        return None
+
+    def get_scaled_act_names(self) -> List[str]:
+        return []
 
 
 class ModelOptFp8LinearMethod(LinearMethodBase):

@@ -541,7 +512,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
         return self.runner.run(dispatch_output, quant_info)
 
 
-class ModelOptFp4Config(ModelOptQuantConfig):
+class ModelOptFp4Config(QuantizationConfig):
     """Config class for FP4."""
 
     def __init__(

@@ -550,9 +521,7 @@ class ModelOptFp4Config(ModelOptQuantConfig):
         kv_cache_quant_algo: str = None,
         group_size: int = None,
         exclude_modules: List[str] = None,
-        packed_modules_mapping: Optional[Dict[str, List[str]]] = None,
     ) -> None:
-        super().__init__(kv_cache_quant_algo, exclude_modules, packed_modules_mapping)
         self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized
         if is_checkpoint_nvfp4_serialized:
             logger.warning(

@@ -560,6 +529,8 @@ class ModelOptFp4Config(ModelOptQuantConfig):
                 "format is experimental and subject to change."
             )
         self.group_size = group_size
+        self.kv_cache_quant_algo = kv_cache_quant_algo
+        self.exclude_modules = exclude_modules
 
     @classmethod
     def override_quantization_method(cls, hf_quant_config, user_quant):

@@ -578,6 +549,10 @@ class ModelOptFp4Config(ModelOptQuantConfig):
     def get_min_capability(cls) -> int:
         return 100
 
+    @classmethod
+    def get_config_filenames(cls) -> List[str]:
+        return ["hf_quant_config.json"]
+
     @staticmethod
     def common_group_size(cfg: dict) -> int:
         """Return the unique group_size across the config; raise if missing/mismatched."""

@@ -693,15 +668,14 @@ class ModelOptFp4Config(ModelOptQuantConfig):
             kv_cache_quant_algo,
             group_size,
             exclude_modules,
-            config.get("packed_modules_mapping"),
         )
 
-    def is_layer_excluded(self, prefix: str):
+    def is_layer_excluded(self, prefix: str, exclude_modules: list):
         import regex as re
 
         fused_patterns = ["q_a_proj", "q_b_proj", "kv_a_proj_with_mqa", "kv_b_proj"]
         prefix_split = prefix.split(".")
-        for pattern in self.exclude_modules:
+        for pattern in exclude_modules:
             regex_str = pattern.replace(".", r"\.").replace("*", r".*")
             pattern_split = pattern.split(".")
             if re.fullmatch(regex_str, prefix):
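The FP4 path's is_layer_excluded (directly above) turns glob-style exclude patterns from hf_quant_config.json into regular expressions and full-matches them against layer prefixes. A standalone sketch of that matching rule follows; it uses the stdlib re module instead of the regex package imported by the source, and the example prefixes are illustrative only.

    # Standalone sketch (not the sglang implementation) of the glob-to-regex
    # matching used by is_layer_excluded: escape literal dots, turn "*" into a
    # wildcard, then require a full match against the layer prefix.
    import re


    def matches_exclude_pattern(prefix: str, pattern: str) -> bool:
        # "model.layers.1.mixer" matches "model.layers.*.mixer", but
        # "model.layers.1.mixer.in_proj" does not, because fullmatch is used.
        regex_str = pattern.replace(".", r"\.").replace("*", r".*")
        return re.fullmatch(regex_str, prefix) is not None


    if __name__ == "__main__":
        print(matches_exclude_pattern("model.layers.1.mixer", "model.layers.*.mixer"))  # True
        print(matches_exclude_pattern("model.layers.1.mlp", "model.layers.*.mixer"))    # False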
@@ -717,17 +691,30 @@ class ModelOptFp4Config(ModelOptQuantConfig):
                 return True
         return False
 
-    def get_quant_method(self, layer: torch.nn.Module, prefix: str):
+    def get_quant_method(
+        self, layer: torch.nn.Module, prefix: str
+    ) -> Optional[QuantizeMethodBase]:
         from sglang.srt.layers.linear import LinearBase
         from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
         from sglang.srt.layers.moe.fused_moe_triton.layer import FlashInferFP4MoE
 
-        Moe = (
-            FlashInferFP4MoE  # FlashInferFP4MoE needs the same quantization method but with compatible attribute handling
-            if isinstance(layer, FlashInferFP4MoE)
-            else ModelOptNvFp4FusedMoEMethod
-        )
-        return self._get_quant_method(layer, prefix, Linear=ModelOptFp4LinearMethod, Moe=Moe)
+        if isinstance(layer, LinearBase):
+            if is_layer_skipped(
+                prefix, self.exclude_modules
+            ) or self.is_layer_excluded(prefix, self.exclude_modules):
+                return UnquantizedLinearMethod()
+            return ModelOptFp4LinearMethod(self)
+        if self.kv_cache_quant_algo and isinstance(layer, RadixAttention):
+            return ModelOptFp8KVCacheMethod(self)
+        elif isinstance(layer, FlashInferFP4MoE):
+            # FlashInferFP4MoE needs the same quantization method but with compatible attribute handling
+            return ModelOptNvFp4FusedMoEMethod(self)
+        elif isinstance(layer, FusedMoE):
+            return ModelOptNvFp4FusedMoEMethod(self)
+        return None
+
+    def get_scaled_act_names(self) -> List[str]:
+        return []
 
 
 class ModelOptFp4LinearMethod(LinearMethodBase):
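Both sides of this file's diff implement the same idea: map a layer (linear, attention KV cache, or fused MoE) and its prefix to a quantize-method object, or None. The reverted code centralized that mapping in ModelOptQuantConfig._get_quant_method and let each subclass inject its concrete Linear/Moe method classes; the restored code spells out the isinstance chain per config class. A minimal, self-contained sketch of the injected-dispatch variant follows; every class in it is a placeholder, not an sglang type.

    # Minimal sketch of the dispatch pattern removed by this revert.
    # All classes here are placeholders, not sglang types.
    from typing import Optional, Type


    class LinearLayer: ...
    class MoELayer: ...
    class AttentionLayer: ...


    class QuantMethod: ...
    class Fp8Linear(QuantMethod): ...
    class Fp8Moe(QuantMethod): ...
    class KvCacheFp8(QuantMethod): ...


    class BaseQuantConfig:
        def __init__(self, kv_cache_quant_algo: Optional[str], exclude: list[str]):
            self.kv_cache_quant_algo = kv_cache_quant_algo
            self.exclude = exclude

        def _get_quant_method(
            self, layer: object, prefix: str, *, Linear: Type[QuantMethod], Moe: Type[QuantMethod]
        ) -> Optional[QuantMethod]:
            # Same shape as the reverted ModelOptQuantConfig._get_quant_method:
            # skip excluded prefixes, then dispatch on layer type.
            if isinstance(layer, LinearLayer):
                if any(prefix.startswith(e) for e in self.exclude):
                    return None
                return Linear()
            if self.kv_cache_quant_algo and isinstance(layer, AttentionLayer):
                return KvCacheFp8()
            if isinstance(layer, MoELayer):
                return Moe()
            return None


    class Fp8Config(BaseQuantConfig):
        def get_quant_method(self, layer: object, prefix: str) -> Optional[QuantMethod]:
            # Subclass only chooses which concrete method classes to inject.
            return self._get_quant_method(layer, prefix, Linear=Fp8Linear, Moe=Fp8Moe)


    if __name__ == "__main__":
        cfg = Fp8Config(kv_cache_quant_algo="FP8", exclude=["lm_head"])
        print(type(cfg.get_quant_method(LinearLayer(), "model.layers.0.qkv_proj")).__name__)  # Fp8Linear
        print(cfg.get_quant_method(LinearLayer(), "lm_head"))  # None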
python/sglang/srt/model_loader/loader.py  (+2 −4)

@@ -180,12 +180,11 @@ def _get_quantization_config(
     model_config: ModelConfig,
     load_config: LoadConfig,
     packed_modules_mapping: Dict[str, List[str]],
-    remap_prefix: Dict[str, str] | None = None,
 ) -> Optional[QuantizationConfig]:
     """Get the quantization config."""
     if model_config.quantization is not None:
         quant_config = get_quant_config(
-            model_config, load_config, packed_modules_mapping, remap_prefix
+            model_config, load_config, packed_modules_mapping
         )
         # (yizhang2077) workaround for nvidia/Llama-4-Maverick-17B-128E-Eagle3
         if quant_config is None:

@@ -221,7 +220,6 @@ def _initialize_model(
     """Initialize a model with the given configurations."""
     model_class, _ = get_model_architecture(model_config)
     packed_modules_mapping = getattr(model_class, "packed_modules_mapping", {})
-    remap_prefix = getattr(model_class, "remap_prefix", None)
     if _is_npu:
         packed_modules_mapping.update(
             {

@@ -245,7 +243,7 @@ def _initialize_model(
     )
     quant_config = _get_quantization_config(
-        model_config, load_config, packed_modules_mapping, remap_prefix
+        model_config, load_config, packed_modules_mapping
     )
 
     # Build kwargs conditionally
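The loader hunks above show how per-model metadata reaches the quantization config: _initialize_model reads optional class attributes off the resolved architecture class with getattr and forwards them, and the revert drops the remap_prefix leg of that plumbing. A small sketch of the getattr pattern, with a hypothetical DummyModel standing in for a real architecture class:

    # Sketch of the getattr-based plumbing in _initialize_model: per-model
    # metadata is declared as optional class attributes and read off the
    # architecture class with a default. DummyModel is a placeholder.
    class DummyModel:
        packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}


    def collect_loader_metadata(model_class) -> dict:
        return {
            "packed_modules_mapping": getattr(model_class, "packed_modules_mapping", {}),
            # The reverted code also read a "remap_prefix" attribute here and
            # passed it down to get_quant_config; after the revert it is gone.
            "remap_prefix": getattr(model_class, "remap_prefix", None),
        }


    if __name__ == "__main__":
        print(collect_loader_metadata(DummyModel))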
python/sglang/srt/model_loader/weight_utils.py  (+29 −42)

@@ -37,10 +37,7 @@ from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.distributed import get_tensor_model_parallel_rank
 from sglang.srt.layers.dp_attention import get_attention_tp_rank
 from sglang.srt.layers.quantization import QuantizationConfig, get_quantization_config
-from sglang.srt.layers.quantization.modelopt_quant import (
-    ModelOptFp4Config,
-    ModelOptFp8Config,
-)
+from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp4Config
 from sglang.srt.utils import find_local_repo_dir, log_info_on_rank0, print_warning_once
 from sglang.utils import is_in_ci

@@ -138,26 +135,11 @@ def convert_bin_to_safetensor_file(
         raise RuntimeError(f"The output tensors do not match for key {k}")
 
 
-def replace_prefix(key: str, prefix_mapping: dict[str, str]) -> str:
-    for prefix, new_prefix in prefix_mapping.items():
-        if key.startswith(prefix):
-            key = key.replace(prefix, new_prefix, 1)
-    return key
-
-
-def replace_substrings(key: str, substring_mapping: dict[str, str]) -> str:
-    for substr, new_substr in substring_mapping.items():
-        if substr in key:
-            key = key.replace(substr, new_substr)
-    return key
-
-
 # TODO(woosuk): Move this to other place.
 def get_quant_config(
     model_config: ModelConfig,
     load_config: LoadConfig,
     packed_modules_mapping: Dict[str, List[str]],
-    remap_prefix: Dict[str, str] | None = None,
 ) -> QuantizationConfig:
     quant_cls = get_quantization_config(model_config.quantization)

@@ -227,33 +209,38 @@ def get_quant_config(
     quant_config_file = quant_config_files[0]
     with open(quant_config_file) as f:
         config = json.load(f)
 
-        if remap_prefix is not None:
-            exclude_modules = [
-                replace_prefix(key, remap_prefix)
-                for key in config["quantization"]["exclude_modules"]
-            ]
-            config["quantization"]["exclude_modules"] = exclude_modules
-
-        config["packed_modules_mapping"] = packed_modules_mapping
         if model_config.quantization == "bitsandbytes":
             config["adapter_name_or_path"] = model_name_or_path
-        elif model_config.quantization.startswith("modelopt") and (
-            config["producer"]["name"].startswith("modelopt")
-        ):
-            quant_algo = config["quantization"]["quant_algo"]
-            if quant_algo is None:
+        elif model_config.quantization == "modelopt":
+            if config["producer"]["name"] == "modelopt":
                 # (yizhang2077) workaround for nvidia/Llama-4-Maverick-17B-128E-Eagle3
-                if model_config.hf_config.architectures[0] != "LlamaForCausalLMEagle3":
-                    raise ValueError(
-                        f"Invalid quant_config, quantization method: {model_config.quantization},"
-                        f"hf architectures: {model_config.hf_config.architectures[0]}. "
-                    )
-                return None
-            elif quant_algo == "FP8" or model_config.quantization == "modelopt_fp8":
-                return ModelOptFp8Config.from_config(config)
-            elif "FP4" in quant_algo:
-                return ModelOptFp4Config.from_config(config)
-    return quant_cls.from_config(config)
+                if config["quantization"]["quant_algo"] is None:
+                    if (
+                        model_config.hf_config.architectures[0]
+                        != "LlamaForCausalLMEagle3"
+                    ):
+                        raise ValueError(
+                            f"Invalid quant_config, quantization method: {model_config.quantization},"
+                            f"hf architectures: {model_config.hf_config.architectures[0]}. "
+                        )
+                    return None
+                if "FP4" in config["quantization"]["quant_algo"]:
+                    return ModelOptFp4Config.from_config(config)
+                else:
+                    return quant_cls.from_config(config)
+        elif model_config.quantization == "modelopt_fp8":
+            if config["producer"]["name"] == "modelopt_fp8":
+                return quant_cls.from_config(config)
             else:
                 raise ValueError(
                     f"Unsupported quantization config"
                     f" found for {model_config.quantization} in {f}."
                 )
         elif model_config.quantization == "w8a8_int8":
             config["packed_modules_mapping"] = packed_modules_mapping
 
     return quant_cls.from_config(config)
 
 
 def find_local_hf_snapshot_dir(
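After this revert, get_quant_config goes back to branching on the producer name and quant_algo fields of hf_quant_config.json to choose a config class. A simplified, self-contained sketch of that dispatch follows; GenericModelOptConfig and Fp4Config are placeholders for the real quant_cls / ModelOptFp4Config, and the sample JSON is illustrative.

    # Simplified, standalone sketch of the quant-config dispatch performed by
    # get_quant_config after this revert. Classes below are placeholders.
    import json


    class GenericModelOptConfig:
        @classmethod
        def from_config(cls, cfg: dict) -> str:
            return "generic modelopt config"


    class Fp4Config:
        @classmethod
        def from_config(cls, cfg: dict) -> str:
            return "FP4 config"


    def pick_quant_config(raw: str):
        cfg = json.loads(raw)
        if cfg["producer"]["name"] != "modelopt":
            raise ValueError("unsupported quantization config")
        quant_algo = cfg["quantization"]["quant_algo"]
        if quant_algo is None:
            return None  # e.g. the Eagle3 draft-model workaround in the real code
        if "FP4" in quant_algo:
            return Fp4Config.from_config(cfg)
        return GenericModelOptConfig.from_config(cfg)


    if __name__ == "__main__":
        sample = '{"producer": {"name": "modelopt"}, "quantization": {"quant_algo": "NVFP4"}}'
        print(pick_quant_config(sample))  # FP4 config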
python/sglang/srt/models/nemotron_h.py  (+22 −19)

@@ -48,8 +48,6 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import (
     default_weight_loader,
     maybe_remap_kv_scale_name,
-    replace_prefix,
-    replace_substrings,
 )
 from sglang.srt.utils import add_prefix, make_layers_non_pp
 from sglang.utils import logger

@@ -157,7 +155,6 @@ class NemotronHMambaDecoderLayer(nn.Module):
             rms_norm_eps=config.rms_norm_eps,
             activation=config.mamba_hidden_act,
-            quant_config=quant_config,
             prefix=f"{prefix}.mixer",
         )
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

@@ -384,19 +381,16 @@ class NemotronHModel(nn.Module):
 class NemotronHForCausalLM(nn.Module):
-    stacked_params_mapping = [
-        # (param_name, shard_name, shard_id)
-        ("qkv_proj", "q_proj", "q"),
-        ("qkv_proj", "k_proj", "k"),
-        ("qkv_proj", "v_proj", "v"),
-    ]
     packed_modules_mapping = {
         "qkv_proj": ["q_proj", "k_proj", "v_proj"],
     }
     remap_prefix = {"backbone": "model"}
     remap_substr = {"A_log": "A", "embeddings": "embed_tokens"}
 
     # LoRA specific attributes
     embedding_modules = {
         "embed_tokens": "input_embeddings",
         "lm_head": "output_embeddings",
     }
     embedding_padding_modules = ["lm_head"]
 
     def __init__(
         self,
         *,

@@ -438,9 +432,7 @@ class NemotronHForCausalLM(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ):
-        return NemotronHModel(
-            config=config, quant_config=quant_config, prefix=add_prefix("model", prefix)
-        )
+        return NemotronHModel(config=config, quant_config=quant_config, prefix=prefix)
 
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.get_input_embeddings(input_ids)

@@ -468,10 +460,21 @@ class NemotronHForCausalLM(nn.Module):
         return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
 
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> None:
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+        ]
         updated_weights = []
         for name, loaded_weight in weights:
-            name = replace_prefix(name, self.remap_prefix)
-            name = replace_substrings(name, self.remap_substr)
+            for prefix, new_key in self.remap_prefix.items():
+                if name.startswith(prefix):
+                    name = name.replace(prefix, new_key)
+            for substr, new_key in self.remap_substr.items():
+                if substr in name:
+                    name = name.replace(substr, new_key)
             updated_weights.append((name, loaded_weight))
         params_dict = dict(self.named_parameters())

@@ -481,7 +484,7 @@ class NemotronHForCausalLM(nn.Module):
             if name is None:
                 continue
-            for param_name, weight_name, shard_id in self.stacked_params_mapping:
+            for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
                 name = name.replace(weight_name, param_name)
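The restored load_weights applies two small renaming passes to every checkpoint key: a prefix map (backbone -> model) gated by startswith, then a substring map (A_log -> A, embeddings -> embed_tokens) applied anywhere in the name. A standalone sketch of that remapping, with illustrative weight names:

    # Standalone sketch of the checkpoint-name remapping done in load_weights
    # after this revert. The example names are illustrative.
    REMAP_PREFIX = {"backbone": "model"}
    REMAP_SUBSTR = {"A_log": "A", "embeddings": "embed_tokens"}


    def remap_weight_name(name: str) -> str:
        # Prefix pass: only applied when the name starts with the prefix.
        for prefix, new_key in REMAP_PREFIX.items():
            if name.startswith(prefix):
                name = name.replace(prefix, new_key)
        # Substring pass: applied wherever the substring occurs.
        for substr, new_key in REMAP_SUBSTR.items():
            if substr in name:
                name = name.replace(substr, new_key)
        return name


    if __name__ == "__main__":
        print(remap_weight_name("backbone.layers.3.mixer.A_log"))  # model.layers.3.mixer.A
        print(remap_weight_name("backbone.embeddings.weight"))     # model.embed_tokens.weight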
test/srt/layers/attention/mamba/test_causal_conv1d.py  (+0 −4)

@@ -373,7 +373,3 @@ def test_causal_conv1d_varlen(
     )
     unpadded_out = out[:, : out_ref_tensor.shape[-1]]
     assert torch.allclose(unpadded_out, out_ref_tensor, rtol=rtol, atol=atol)
-
-
-if __name__ == "__main__":
-    pytest.main([__file__])
test/srt/layers/attention/mamba/test_mamba2_mixer.py  (+0 −5)

 # Adapted from https://github.com/vllm-project/vllm/blob/2c58742dff8613a3bd7496f2008ce927e18d38d1/tests/kernels/mamba/test_mamba_mixer2.py
 from unittest.mock import patch
 
 import pytest

@@ -137,7 +136,3 @@ def mixer2_gated_norm_tensor_parallel(
         atol=5e-3,
         rtol=1e-3,
     )
-
-
-if __name__ == "__main__":
-    pytest.main([__file__])
test/srt/layers/attention/mamba/test_mamba_ssm.py  (+0 −5)

 # Adapted from https://github.com/vllm-project/vllm/blob/633f943e30a4444d890d26b81850f7217736f840/tests/kernels/mamba/test_mamba_ssm_ssd.py
 import pytest
 import torch
 import torch.nn.functional as F

@@ -290,7 +289,3 @@ def test_selective_state_update_with_heads_with_batch_indices(
     print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
     assert torch.allclose(state[state_indices, :], state_ref, rtol=rtol, atol=atol)
     assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
-
-
-if __name__ == "__main__":
-    pytest.main([__file__])
test/srt/layers/attention/mamba/test_mamba_ssm_ssd.py  (+9 −34)

@@ -8,12 +8,13 @@ from einops import rearrange, repeat
 from sglang.srt.layers.attention.mamba.mamba2_metadata import Mamba2Metadata
 from sglang.srt.layers.attention.mamba.ops import mamba_chunk_scan_combined
-from sglang.utils import is_in_ci
 
 # Added by the IBM Team, 2024
 # Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/modules/ssd_minimal.py
 
+# TODO: These take a long time to run - we should cut down on some of the parameterized matrix.
 
 # this is the segsum implementation taken from above
 def segsum(x):

@@ -190,22 +191,10 @@ def generate_continuous_batched_examples(
     )
 
 
-SINGLE_ITYPE = [torch.float32, torch.float16, torch.bfloat16]
-SINGLE_NHEADS = [3, 4, 11, 16, 32]
-SINGLE_DHEAD = [5, 8, 19, 32, 128]
-SINGLE_SEQ_LEN_CHUNK_SIZE = [(112, 16), (128, 32)]
-if is_in_ci():
-    SINGLE_ITYPE = [torch.float32, torch.bfloat16]
-    SINGLE_NHEADS = [3, 32]
-    SINGLE_DHEAD = [5, 128]
-    SINGLE_SEQ_LEN_CHUNK_SIZE = [(112, 16)]
-
-
-@pytest.mark.parametrize("itype", SINGLE_ITYPE)
-@pytest.mark.parametrize("n_heads", SINGLE_NHEADS)
-@pytest.mark.parametrize("d_head", SINGLE_DHEAD)
-@pytest.mark.parametrize("seq_len_chunk_size", SINGLE_SEQ_LEN_CHUNK_SIZE)
+@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
+@pytest.mark.parametrize("n_heads", [3, 4, 11, 16, 32])
+@pytest.mark.parametrize("d_head", [5, 8, 19, 32, 128])
+@pytest.mark.parametrize("seq_len_chunk_size", [(112, 16), (128, 32)])
 def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size, itype):
     if not torch.cuda.is_available():
         pytest.skip("CUDA device not available")

@@ -249,19 +238,9 @@ def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size, itype):
     )
 
 
-BATCHED_ITYPE = [torch.float32, torch.float16]
-BATCHED_NHEADS = [4, 8, 13]
-BATCHED_DHEAD = [5, 16, 21, 32]
-if is_in_ci():
-    BATCHED_ITYPE = [torch.float32]
-    BATCHED_NHEADS = [4, 13]
-    BATCHED_DHEAD = [5, 32]
-
-
-@pytest.mark.parametrize("itype", BATCHED_ITYPE)
-@pytest.mark.parametrize("n_heads", BATCHED_NHEADS)
-@pytest.mark.parametrize("d_head", BATCHED_DHEAD)
+@pytest.mark.parametrize("itype", [torch.float32, torch.float16])
+@pytest.mark.parametrize("n_heads", [4, 8, 13])
+@pytest.mark.parametrize("d_head", [5, 16, 21, 32])
 @pytest.mark.parametrize(
     "seq_len_chunk_size_cases",
     [

@@ -600,7 +579,3 @@ def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
             rtol=rtol,
             msg=lambda x: f"seq {i} state " + x,  # noqa: B023
         )
-
-
-if __name__ == "__main__":
-    pytest.main([__file__])
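The two middle hunks above revert a CI-only reduction of the pytest parameter matrix: module-level lists were shrunk under is_in_ci() and fed to @pytest.mark.parametrize, whereas the restored code inlines the full literals. A self-contained sketch of the reverted pattern follows; running_in_ci is a hypothetical stand-in for sglang.utils.is_in_ci, and the dummy test only checks tensor shapes.

    # Sketch of the CI-reduction pattern removed by this revert: parameter lists
    # are defined as module-level constants and shrunk when running in CI,
    # instead of being inlined into @pytest.mark.parametrize.
    import os

    import pytest
    import torch


    def running_in_ci() -> bool:
        # Hypothetical stand-in for sglang.utils.is_in_ci.
        return os.environ.get("CI", "").lower() in ("1", "true")


    ITYPES = [torch.float32, torch.float16, torch.bfloat16]
    NHEADS = [3, 4, 11, 16, 32]

    if running_in_ci():
        # Keep only the endpoints of each axis to bound CI runtime.
        ITYPES = [torch.float32, torch.bfloat16]
        NHEADS = [3, 32]


    @pytest.mark.parametrize("itype", ITYPES)
    @pytest.mark.parametrize("n_heads", NHEADS)
    def test_shapes(itype, n_heads):
        x = torch.zeros(n_heads, 4, dtype=itype)
        assert x.shape == (n_heads, 4)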
test/srt/models/test_nvidia_nemotron_nano_v2.py  (+3 −16)

 import unittest
 from types import SimpleNamespace
 
-from sglang.srt.utils import is_blackwell, kill_process_tree
+from sglang.srt.utils import kill_process_tree
 from sglang.test.few_shot_gsm8k import run_eval
 from sglang.test.test_utils import (
     DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,

@@ -12,11 +12,9 @@ from sglang.test.test_utils import (
 class TestNvidiaNemotronNanoV2(CustomTestCase):
-    model = "nvidia/NVIDIA-Nemotron-Nano-9B-v2"
-    accuracy = 0.87
-
     @classmethod
     def setUpClass(cls):
+        cls.model = "nvidia/NVIDIA-Nemotron-Nano-9B-v2"
         cls.base_url = DEFAULT_URL_FOR_TEST
         cls.process = popen_launch_server(
             cls.model,

@@ -44,18 +42,7 @@ class TestNvidiaNemotronNanoV2(CustomTestCase):
         )
         metrics = run_eval(args)
         print(f"{metrics=}")
-        self.assertGreaterEqual(metrics["accuracy"], self.accuracy)
-
-
-class TestNvidiaNemotronNanoV2FP8(TestNvidiaNemotronNanoV2):
-    accuracy = 0.87
-    model = "nvidia/NVIDIA-Nemotron-Nano-9B-v2-FP8"
-
-
-@unittest.skipIf(not is_blackwell(), "NVFP4 only supported on blackwell")
-class TestNvidiaNemotronNanoV2NVFP4(TestNvidiaNemotronNanoV2):
-    accuracy = 0.855
-    model = "nvidia/NVIDIA-Nemotron-Nano-9B-v2-NVFP4"
+        self.assertGreater(metrics["accuracy"], 0.87)
 
 
 if __name__ == "__main__":
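The deleted test classes used a subclass-per-checkpoint layout: the base case carries the model name and accuracy threshold as class attributes, and the FP8/NVFP4 variants override just those two. A minimal sketch of that layout with unittest follows; evaluate() is a placeholder for the real GSM8K run, while the model names and thresholds echo the deleted code.

    # Sketch of the subclass-per-checkpoint test layout removed by this revert.
    # evaluate() is a placeholder for the real server launch + GSM8K eval.
    import unittest


    def evaluate(model: str) -> float:
        return 0.9  # placeholder score


    class TestBaseModel(unittest.TestCase):
        model = "nvidia/NVIDIA-Nemotron-Nano-9B-v2"
        accuracy = 0.87

        def test_gsm8k(self):
            self.assertGreaterEqual(evaluate(self.model), self.accuracy)


    class TestFP8(TestBaseModel):
        model = "nvidia/NVIDIA-Nemotron-Nano-9B-v2-FP8"
        accuracy = 0.87


    class TestNVFP4(TestBaseModel):
        model = "nvidia/NVIDIA-Nemotron-Nano-9B-v2-NVFP4"
        accuracy = 0.855


    if __name__ == "__main__":
        unittest.main()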
test/srt/run_suite.py  (+2 −5)

@@ -19,9 +19,6 @@ suites = {
         TestFile("hicache/test_hicache_eagle.py", 150),
         TestFile("hicache/test_hicache_mla.py", 127),
         TestFile("hicache/test_hicache_storage.py", 127),
-        TestFile("layers/attention/mamba/test_causal_conv1d.py", 25),
-        TestFile("layers/attention/mamba/test_mamba_ssm.py", 50),
-        TestFile("layers/attention/mamba/test_mamba_ssm_ssd.py", 70),
         TestFile("lora/test_lora.py", 200),
         TestFile("lora/test_lora_eviction.py", 200),
         TestFile("lora/test_lora_eviction_policy.py", 200),

@@ -37,7 +34,7 @@ suites = {
         TestFile("models/test_embedding_models.py", 73),
         TestFile("models/test_encoder_embedding_models.py", 460),
         TestFile("models/test_generation_models.py", 103),
-        TestFile("models/test_nvidia_nemotron_nano_v2.py", 300),
+        TestFile("models/test_nvidia_nemotron_nano_v2.py", 180),
         TestFile("models/test_qwen_models.py", 82),
         TestFile("batch_invariant/test_batch_invariant_ops.py", 10),
         TestFile("models/test_reward_models.py", 132),

@@ -146,7 +143,7 @@ suites = {
         TestFile("hicache/test_hicache_storage_3fs_backend.py", 200),
         TestFile("hicache/test_hicache_storage_file_backend.py", 200),
         TestFile("hicache/test_hicache_storage_mooncake_backend.py", 400),
-        TestFile("layers/attention/mamba/test_mamba2_mixer.py", 50),
+        TestFile("layers/attention/mamba/test_mamba2_mixer.py", 110),
         TestFile("lora/test_lora_tp.py", 116),
         TestFile("models/test_glm4_moe_models.py", 100),
         TestFile("rl/test_update_weights_from_distributed.py", 103),