sglang · Commit 659907e3 (unverified)
Authored by Zhiyu on Jul 08, 2025; committed via GitHub on Jul 08, 2025

Enable ModelOpt Llama4 fp8 checkpoint deployment in SGLang (#7129)

Parent: cb9d91ea
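For context on what the commit enables, below is a minimal offline-engine sketch of serving such a checkpoint. This is not taken from the diff: the checkpoint path is a placeholder and the `quantization`/`tp_size` keyword arguments are assumptions about SGLang's server arguments.

import sglang as sgl

# Hypothetical usage sketch (not part of this commit): load a ModelOpt-exported
# Llama4 FP8 checkpoint with the offline Engine API. Path and kwargs are assumed.
llm = sgl.Engine(
    model_path="/checkpoints/llama4-modelopt-fp8",  # placeholder path
    quantization="modelopt",  # assumed quantization flag value
    tp_size=8,
)
print(llm.generate("Hello, my name is", {"max_new_tokens": 16}))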
Showing 3 changed files with 643 additions and 81 deletions (+643, -81)
python/sglang/srt/layers/moe/fused_moe_triton/layer.py    +39  -1
python/sglang/srt/layers/quantization/modelopt_quant.py   +244 -1
python/sglang/srt/models/mllama4.py                       +360 -79
python/sglang/srt/layers/moe/fused_moe_triton/layer.py (view file @ 659907e3)
...
@@ -649,6 +649,27 @@ class FusedMoE(torch.nn.Module):
         loaded_weight: torch.tensor,
         tp_rank: int,
     ):
+        """Load w2 weights for down projection.
+
+        Args:
+            expert_data: The expert data tensor to load into
+            shard_dim: The dimension to shard along
+            shard_id: The shard ID (must be "w2")
+            loaded_weight: The weight tensor to load from
+            tp_rank: The tensor parallel rank
+        """
+        if not isinstance(expert_data, torch.Tensor) or not isinstance(
+            loaded_weight, torch.Tensor
+        ):
+            raise ValueError("expert_data and loaded_weight must be torch.Tensor")
+
+        if expert_data.dim() != 2 or loaded_weight.dim() != 2:
+            raise ValueError(
+                f"Expected 2D tensors, got expert_data shape {expert_data.shape} and loaded_weight shape {loaded_weight.shape}"
+            )
+
+        if shard_id != "w2":
+            raise ValueError(f"shard_id must be 'w2', got {shard_id}")

         # Index the loaded weight for tp sharding.
         # down_proj: "RowParallel" so tp sharding on input_dim
...
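To illustrate the new guards in isolation, here is a small standalone sketch with toy tensors; the helper name `check_w2_inputs` is made up for the example and is not part of the diff.

import torch

def check_w2_inputs(expert_data, loaded_weight, shard_id):
    # Mirrors the validation added above, outside of FusedMoE.
    if not isinstance(expert_data, torch.Tensor) or not isinstance(
        loaded_weight, torch.Tensor
    ):
        raise ValueError("expert_data and loaded_weight must be torch.Tensor")
    if expert_data.dim() != 2 or loaded_weight.dim() != 2:
        raise ValueError("Expected 2D tensors")
    if shard_id != "w2":
        raise ValueError(f"shard_id must be 'w2', got {shard_id}")

check_w2_inputs(torch.empty(256, 64), torch.empty(256, 128), "w2")  # passes
# check_w2_inputs(torch.empty(256, 64), torch.empty(256, 128), "w1")  # would raise ValueError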
@@ -669,6 +690,10 @@ class FusedMoE(torch.nn.Module):
         if not self.use_presharded_weights:
             if self.use_triton_kernels:
                 loaded_weight = loaded_weight.transpose(-2, -1)
+            if shard_size * tp_rank + shard_size > loaded_weight.shape[shard_dim]:
+                raise ValueError(
+                    f"Shard size {shard_size} at rank {tp_rank} exceeds loaded_weight dimension {loaded_weight.shape[shard_dim]}"
+                )
             loaded_weight = loaded_weight.narrow(
                 shard_dim, shard_size * tp_rank, shard_size
             )
...
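The bounds check above guards the `narrow` call that slices out one tensor-parallel shard; a toy sketch of that slicing follows (shapes made up for the example, plain torch rather than SGLang code).

import torch

tp_size, tp_rank = 4, 2
loaded_weight = torch.randn(1024, 256)  # stand-in for a full down_proj weight
shard_dim = 0  # "RowParallel": shard along the input dim
shard_size = loaded_weight.shape[shard_dim] // tp_size

# The new check guards exactly this slice from running past the end of the tensor.
assert shard_size * tp_rank + shard_size <= loaded_weight.shape[shard_dim]
shard = loaded_weight.narrow(shard_dim, shard_size * tp_rank, shard_size)
print(shard.shape)  # torch.Size([256, 256])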
@@ -795,8 +820,21 @@ class FusedMoE(torch.nn.Module):
                 tp_rank=tp_rank,
             )
             return
         if "ModelOpt" in self.quant_method.__class__.__name__:
-            if "weight_scale_2" in weight_name or "input_scale" in weight_name:
+            # Determine per-tensor weight scale patterns based on variant
+            is_fp4_variant = (
+                "ModelOptNvFp4FusedMoEMethod" in self.quant_method.__class__.__name__
+            )
+            # FP4 uses "weight_scale_2" for per-tensor, FP8 uses "weight_scale" for per-tensor
+            per_tensor_conditions = (
+                "weight_scale_2" in weight_name
+                if is_fp4_variant
+                else "weight_scale" in weight_name
+            ) or "input_scale" in weight_name
+            if per_tensor_conditions:
                 self._load_per_tensor_weight_scale(
                     shard_id=shard_id,
                     param=param,
...
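To make the branching above concrete, a small evaluation sketch of the per-tensor condition; the weight names are illustrative, not taken from a real checkpoint.

def per_tensor_condition(weight_name: str, is_fp4_variant: bool) -> bool:
    # Same expression as per_tensor_conditions above, as a free function.
    return (
        "weight_scale_2" in weight_name
        if is_fp4_variant
        else "weight_scale" in weight_name
    ) or "input_scale" in weight_name

print(per_tensor_condition("experts.w13_weight_scale", is_fp4_variant=False))   # True: FP8 per-tensor scale
print(per_tensor_condition("experts.w13_weight_scale", is_fp4_variant=True))    # False: FP4 expects weight_scale_2
print(per_tensor_condition("experts.w13_weight_scale_2", is_fp4_variant=True))  # True
print(per_tensor_condition("experts.w2_input_scale", is_fp4_variant=False))     # True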
python/sglang/srt/layers/quantization/modelopt_quant.py (view file @ 659907e3)
...
@@ -26,6 +26,7 @@ from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
 from sglang.srt.layers.quantization.utils import (
     convert_to_channelwise,
     is_layer_skipped,
+    per_tensor_dequantize,
     requantize_with_max_scale,
 )
 from sglang.srt.layers.radix_attention import RadixAttention
...
@@ -110,7 +111,12 @@ class ModelOptFp8Config(QuantizationConfig):
         self, layer: torch.nn.Module, prefix: str
     ) -> Optional["QuantizeMethodBase"]:
         if self.exclude_modules and any(
-            module in prefix for module in self.exclude_modules
+            module in prefix
+            or (
+                prefix.startswith("language_model.")
+                and module in prefix.removeprefix("language_model.")
+            )
+            for module in self.exclude_modules
         ):
             return None
...
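A small sketch of the exclusion check with the added `language_model.` handling, presumably because the Llama4 multimodal wrapper prefixes language-model submodules with `language_model.`. The module names below are made up for the example and are not SGLang code.

exclude_modules = ["vision_model", "router"]  # illustrative exclusion patterns

def is_excluded(prefix: str) -> bool:
    # Same condition as in get_quant_method above, as a free function.
    return any(
        module in prefix
        or (
            prefix.startswith("language_model.")
            and module in prefix.removeprefix("language_model.")
        )
        for module in exclude_modules
    )

print(is_excluded("vision_model.patch_embedding"))                       # True  -> skip quantization
print(is_excluded("language_model.model.layers.0.feed_forward.router"))  # True  -> skip quantization
print(is_excluded("language_model.model.layers.0.self_attn.qkv_proj"))   # False -> quantize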
@@ -119,6 +125,12 @@ class ModelOptFp8Config(QuantizationConfig):
         if self.kv_cache_quant_method and isinstance(layer, RadixAttention):
             return ModelOptFp8KVCacheMethod(self)
+
+        # Add MoE support
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+
+        if isinstance(layer, FusedMoE):
+            return ModelOptFp8MoEMethod(self)

         return None

     def get_scaled_act_names(self) -> List[str]:
...
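The dispatch order above, in miniature: KV-cache quantization for attention layers, the new MoE method for FusedMoE layers, and None otherwise. The stand-in classes below are mocks for illustration only, not SGLang types.

class RadixAttention:  # mock stand-in
    pass

class FusedMoE:  # mock stand-in
    pass

def get_quant_method_sketch(layer):
    if isinstance(layer, RadixAttention):
        return "ModelOptFp8KVCacheMethod"
    if isinstance(layer, FusedMoE):
        return "ModelOptFp8MoEMethod"  # new in this commit
    return None  # other layer types are handled elsewhere in the real config

print(get_quant_method_sketch(FusedMoE()))        # ModelOptFp8MoEMethod
print(get_quant_method_sketch(RadixAttention()))  # ModelOptFp8KVCacheMethod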
@@ -234,6 +246,237 @@ class ModelOptFp8KVCacheMethod(BaseKVCacheMethod):
         super().__init__(quant_config)
+
+
+class ModelOptFp8MoEMethod:
+    """MoE method for ModelOpt FP8.
+
+    Supports loading FP8 checkpoints with static weight scale and activation scale.
+
+    Args:
+        quant_config: The ModelOpt quantization config.
+    """
+
+    def __new__(cls, *args, **kwargs):
+        """
+        Dynamic class composition pattern.
+
+        This allows us to effectively "inject" FusedMoEMethodBase as a parent class
+        at runtime while avoiding circular import issues.
+        """
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase
+
+        if not hasattr(cls, "_initialized"):
+            original_init = cls.__init__
+            new_cls = type(
+                cls.__name__,
+                (FusedMoEMethodBase,),
+                {
+                    "__init__": original_init,
+                    **{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
+                },
+            )
+            obj = super(new_cls, new_cls).__new__(new_cls)
+            obj.__init__(*args, **kwargs)
+            return obj
+        return super().__new__(cls)
+
+    def __init__(self, quant_config: ModelOptFp8Config):
+        self.quant_config = quant_config
+        self.cutlass_fp8_supported = cutlass_fp8_supported()
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        num_experts: int,
+        hidden_size: int,
+        intermediate_size: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
+
+        # Use FP8 dtype if checkpoint is serialized, otherwise use the default dtype
+        weight_dtype = (
+            torch.float8_e4m3fn
+            if self.quant_config.is_checkpoint_fp8_serialized
+            else params_dtype
+        )
+        weight_loader = extra_weight_attrs.get("weight_loader")
+
+        w13_weight = ModelWeightParameter(
+            data=torch.empty(
+                num_experts, 2 * intermediate_size, hidden_size, dtype=weight_dtype
+            ),
+            input_dim=2,
+            output_dim=1,
+            weight_loader=weight_loader,
+        )
+        layer.register_parameter("w13_weight", w13_weight)
+
+        w2_weight = ModelWeightParameter(
+            data=torch.empty(
+                num_experts, hidden_size, intermediate_size, dtype=weight_dtype
+            ),
+            input_dim=2,
+            output_dim=1,
+            weight_loader=weight_loader,
+        )
+        layer.register_parameter("w2_weight", w2_weight)
+
+        if self.quant_config.is_checkpoint_fp8_serialized:
+            # WEIGHT SCALES - Per-tensor scaling for ModelOpt
+            # Allocate 2 scales for w1 and w3 respectively.
+            # They will be combined to a single scale after weight loading.
+            w13_weight_scale = PerTensorScaleParameter(
+                data=torch.full(
+                    (num_experts, 2),
+                    torch.finfo(torch.float32).min,
+                    dtype=torch.float32,
+                ),
+                weight_loader=weight_loader,
+            )
+            w2_weight_scale = PerTensorScaleParameter(
+                data=torch.full(
+                    (num_experts,), torch.finfo(torch.float32).min, dtype=torch.float32
+                ),
+                weight_loader=weight_loader,
+            )
+            layer.register_parameter("w13_weight_scale", w13_weight_scale)
+            layer.register_parameter("w2_weight_scale", w2_weight_scale)
+
+            # Set weight loader attributes for scales
+            extra_weight_attrs.update(
+                {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
+            )
+
+            # INPUT SCALES - Per-tensor scaling for ModelOpt
+            w13_input_scale = PerTensorScaleParameter(
+                data=torch.full((num_experts,), 1.0, dtype=torch.float32),
+                weight_loader=weight_loader,
+            )
+            w2_input_scale = PerTensorScaleParameter(
+                data=torch.full((num_experts,), 1.0, dtype=torch.float32),
+                weight_loader=weight_loader,
+            )
+            layer.register_parameter("w13_input_scale", w13_input_scale)
+            layer.register_parameter("w2_input_scale", w2_input_scale)
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        """Process FP8 MoE weights after loading from serialized checkpoint.
+
+        Only supports pre-quantized checkpoints with FP8 weights and scales.
+        """
+        layer.w13_weight = Parameter(layer.w13_weight.data, requires_grad=False)
+        layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)
+
+        # Handle scale parameters
+        if hasattr(layer, "w13_weight_scale") and layer.w13_weight_scale is not None:
+            # Fp8 moe kernel needs single weight scale for w13 per expert.
+            # We take the max of the w1 and w3 scales then dequant and requant each expert.
+            if layer.w13_weight_scale.dim() == 2:  # Shape: (num_experts, 2)
+                from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
+
+                # Get the maximum scale across w1 and w3 for each expert
+                max_w13_scales = layer.w13_weight_scale.max(dim=1).values
+
+                # Requantize each expert's weights using the combined scale
+                # w13_weight has shape (num_experts, 2 * intermediate_size, hidden_size)
+                # where the first intermediate_size rows are w1, the next are w3
+                intermediate_size = layer.w13_weight.shape[1] // 2
+                for expert_id in range(layer.w13_weight.shape[0]):
+                    start = 0
+                    for shard_id in range(2):  # w1 and w3
+                        # Dequantize using the original scale for this shard
+                        dq_weight = per_tensor_dequantize(
+                            layer.w13_weight[expert_id][
+                                start : start + intermediate_size, :
+                            ],
+                            layer.w13_weight_scale[expert_id][shard_id],
+                        )
+                        # Requantize using the combined max scale
+                        (
+                            layer.w13_weight[expert_id][
+                                start : start + intermediate_size, :
+                            ],
+                            _,
+                        ) = scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
+                        start += intermediate_size
+
+                # Update the scale parameter to be per-expert instead of per-shard
+                layer.w13_weight_scale = Parameter(max_w13_scales, requires_grad=False)
+            else:
+                layer.w13_weight_scale = Parameter(
+                    layer.w13_weight_scale.data, requires_grad=False
+                )
+        if hasattr(layer, "w2_weight_scale") and layer.w2_weight_scale is not None:
+            layer.w2_weight_scale = Parameter(
+                layer.w2_weight_scale.data, requires_grad=False
+            )
+        if hasattr(layer, "w13_input_scale") and layer.w13_input_scale is not None:
+            layer.w13_input_scale = Parameter(
+                layer.w13_input_scale.max(), requires_grad=False
+            )
+        if hasattr(layer, "w2_input_scale") and layer.w2_input_scale is not None:
+            layer.w2_input_scale = Parameter(
+                layer.w2_input_scale.max(), requires_grad=False
+            )
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool,
+        use_grouped_topk: bool,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        num_fused_shared_experts: Optional[int] = None,
+        custom_routing_function: Optional[Callable] = None,
+        correction_bias: Optional[torch.Tensor] = None,
+        activation: str = "silu",
+        apply_router_weight_on_input: bool = False,
+        inplace: bool = True,
+        no_combine: bool = False,
+        routed_scaling_factor: Optional[float] = None,
+    ) -> torch.Tensor:
+        from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
+        from sglang.srt.layers.moe.topk import select_experts
+
+        # Expert selection
+        topk_weights, topk_ids = select_experts(
+            hidden_states=x,
+            router_logits=router_logits,
+            use_grouped_topk=use_grouped_topk,
+            top_k=top_k,
+            renormalize=renormalize,
+            topk_group=topk_group,
+            num_expert_group=num_expert_group,
+            num_fused_shared_experts=num_fused_shared_experts,
+            custom_routing_function=custom_routing_function,
+            correction_bias=correction_bias,
+            routed_scaling_factor=routed_scaling_factor,
+        )
+
+        return fused_experts(
+            x,
+            layer.w13_weight,
+            layer.w2_weight,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            inplace=inplace,
+            activation=activation,
+            use_fp8_w8a8=True,
+            per_channel_quant=False,  # ModelOpt uses per-tensor quantization
+            w1_scale=layer.w13_weight_scale,
+            w2_scale=layer.w2_weight_scale,
+            a1_scale=layer.w13_input_scale,
+            a2_scale=layer.w2_input_scale,
+            no_combine=no_combine,
+        )
+
+
 class ModelOptFp4Config(QuantizationConfig):
     """Config class for FP4."""
...
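The heart of `process_weights_after_loading` is collapsing the separate w1/w3 scales into one per-expert scale by dequantizing with the old scale and requantizing with the max. Below is a toy numeric sketch of that step using plain torch as a stand-in for `per_tensor_dequantize` and `scaled_fp8_quant`; the shapes and scale values are made up, and only one expert is shown.

import torch

fp8 = torch.float8_e4m3fn
fp8_max = torch.finfo(fp8).max  # 448.0

# One expert's w13 weight: w1 and w3 stacked along dim 0, each 4 x 8 (toy shapes).
w13_fp8 = torch.randn(8, 8).clamp(-1, 1).to(fp8)  # pretend already-quantized weight
w13_scale = torch.tensor([0.01, 0.04])            # per-shard scales for w1 and w3

max_scale = w13_scale.max()  # single per-expert scale the fused kernel will use
intermediate_size = w13_fp8.shape[0] // 2
requantized_shards = []
for shard_id in range(2):  # w1, then w3
    block = w13_fp8[shard_id * intermediate_size : (shard_id + 1) * intermediate_size]
    dq = block.to(torch.float32) * w13_scale[shard_id]     # dequantize with the old scale
    q = (dq / max_scale).clamp(-fp8_max, fp8_max).to(fp8)  # requantize with the max scale
    requantized_shards.append(q)

print(max_scale)  # replaces the per-shard (w1, w3) scales for this expert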
python/sglang/srt/models/mllama4.py (view file @ 659907e3)
This diff is collapsed.