13bc39c5
Unverified
Commit
13bc39c5
authored
Mar 06, 2025
by
HAI
Committed by
GitHub
Mar 06, 2025
Browse files
ROCm: enable trillion-parameter MoE models with INT4-FP8 single node (#4152)
parent 9854a18a

Showing 3 changed files with 124 additions and 23 deletions (+124, -23)
python/sglang/srt/layers/moe/fused_moe_triton/layer.py (+12, -0)
python/sglang/srt/layers/quantization/fp8.py (+110, -22)
python/sglang/srt/utils.py (+2, -1)
python/sglang/srt/layers/moe/fused_moe_triton/layer.py
@@ -513,6 +513,10 @@ class FusedMoE(torch.nn.Module):
         # Case input scale: input_scale loading is only supported for fp8
         if "input_scale" in weight_name:
+            # INT4-FP8 (INT4 MoE Weight, FP8 Compute): Adjust input_scale for e4m3fnuz (AMD)
+            if is_hip_ and get_bool_env_var("USE_INT4_WEIGHT"):
+                loaded_weight = loaded_weight * 2.0
             # this is needed for compressed-tensors only
             loaded_weight = loaded_weight.to(param.data.device)
@@ -551,6 +555,10 @@ class FusedMoE(torch.nn.Module):
         # specific to each case
         quant_method = getattr(param, "quant_method", None)
         if quant_method == FusedMoeWeightScaleSupported.CHANNEL.value:
+            # INT4-FP8 (INT4 MoE Weight, FP8 Compute): Adjust INT4 column-wise scaling number to e4m3fnuz (AMD)
+            if is_hip_ and get_bool_env_var("USE_INT4_WEIGHT"):
+                loaded_weight = loaded_weight * 0.5
             self._load_per_channel_weight_scale(
                 shard_id=shard_id,
                 shard_dim=shard_dim,
@@ -570,6 +578,10 @@ class FusedMoE(torch.nn.Module):
                 tp_rank=tp_rank,
             )
         elif quant_method == FusedMoeWeightScaleSupported.TENSOR.value:
+            # INT4-FP8 (INT4 MoE Weight, FP8 Compute): Adjust FP8 per-tensor scaling number for e4m3fnuz (AMD)
+            if is_hip_ and get_bool_env_var("USE_INT4_WEIGHT"):
+                loaded_weight = loaded_weight * 2.0
             self._load_per_tensor_weight_scale(
                 shard_id=shard_id,
                 param=param,
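A note on the 2.0 / 0.5 factors above: ROCm hardware uses float8_e4m3fnuz, whose exponent bias is one larger than the OCP float8_e4m3fn format FP8 checkpoints are serialized in, so the same bit pattern decodes to half the value. Doubling the loaded FP8 input/weight scales compensates; the per-channel INT4 scales are halved instead, presumably so that the product weight_scale1 * weight_scale that asm_moe applies after the GEMM stays as intended (see the process_weights_after_loading hunk below). A minimal sketch, assuming a PyTorch build that ships both FP8 dtypes:

import torch

# Hedged sketch: why scales are adjusted when an e4m3fn checkpoint is loaded on
# ROCm (e4m3fnuz). Dtype availability and view() support depend on the PyTorch build.
bits = torch.tensor([0x40], dtype=torch.uint8)            # an arbitrary fp8 bit pattern
as_fn = bits.view(torch.float8_e4m3fn).to(torch.float32)
as_fnuz = bits.view(torch.float8_e4m3fnuz).to(torch.float32)
print(as_fn.item(), as_fnuz.item())  # 2.0 vs 1.0: fnuz decodes the same bits to half the value

# Dequantization is value = q * scale. If q silently loses a factor of 2 when the
# bits are reinterpreted as fnuz, multiplying the loaded scale by 2.0 keeps
# q * scale unchanged, which is what the hunks above do for the FP8 scales.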
python/sglang/srt/layers/quantization/fp8.py
@@ -460,7 +460,11 @@ Fp8MoEMethod:
         from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported

         if self.quant_config.is_checkpoint_fp8_serialized:
-            params_dtype = torch.float8_e4m3fn
+            params_dtype = (
+                torch.int32
+                if get_bool_env_var("USE_INT4_WEIGHT")
+                else torch.float8_e4m3fn
+            )
         tp_size = get_tensor_model_parallel_world_size()
         if self.block_quant:
             block_n, block_k = (
@@ -485,21 +489,40 @@ class Fp8MoEMethod:
...
@@ -485,21 +489,40 @@ class Fp8MoEMethod:
)
)
# WEIGHTS
# WEIGHTS
w13_weight
=
torch
.
nn
.
Parameter
(
if
get_bool_env_var
(
"USE_INT4_WEIGHT"
):
torch
.
empty
(
# INT4 MoE weight - INT32 packed
num_experts
,
2
*
intermediate_size
,
hidden_size
,
dtype
=
params_dtype
w13_weight
=
torch
.
nn
.
Parameter
(
),
torch
.
empty
(
requires_grad
=
False
,
num_experts
,
)
2
*
intermediate_size
,
hidden_size
//
8
,
dtype
=
params_dtype
,
),
requires_grad
=
False
,
)
w2_weight
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
num_experts
,
hidden_size
,
intermediate_size
//
8
,
dtype
=
params_dtype
),
requires_grad
=
False
,
)
else
:
w13_weight
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
num_experts
,
2
*
intermediate_size
,
hidden_size
,
dtype
=
params_dtype
),
requires_grad
=
False
,
)
w2_weight
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
num_experts
,
hidden_size
,
intermediate_size
,
dtype
=
params_dtype
),
requires_grad
=
False
,
)
layer
.
register_parameter
(
"w13_weight"
,
w13_weight
)
layer
.
register_parameter
(
"w13_weight"
,
w13_weight
)
set_weight_attrs
(
w13_weight
,
extra_weight_attrs
)
set_weight_attrs
(
w13_weight
,
extra_weight_attrs
)
w2_weight
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
num_experts
,
hidden_size
,
intermediate_size
,
dtype
=
params_dtype
),
requires_grad
=
False
,
)
layer
.
register_parameter
(
"w2_weight"
,
w2_weight
)
layer
.
register_parameter
(
"w2_weight"
,
w2_weight
)
set_weight_attrs
(
w2_weight
,
extra_weight_attrs
)
set_weight_attrs
(
w2_weight
,
extra_weight_attrs
)
...
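With USE_INT4_WEIGHT set, params_dtype becomes torch.int32 and the last weight dimension shrinks by 8 because eight 4-bit values fit in one 32-bit word. The sketch below only illustrates that shape arithmetic; the pack_int4/unpack_int4 helpers and the nibble order are hypothetical, and the real checkpoints ship weights already packed in whatever layout the ROCm kernel expects:

import torch

def pack_int4(nibbles: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: pack a [..., K] tensor of values in [0, 15] into
    # [..., K // 8] int32 words (two nibbles per byte, four bytes per word).
    assert nibbles.shape[-1] % 8 == 0
    pairs = nibbles.to(torch.uint8).reshape(*nibbles.shape[:-1], -1, 2)
    packed_bytes = (pairs[..., 0] | (pairs[..., 1] << 4)).contiguous()
    return packed_bytes.view(torch.int32)

def unpack_int4(packed: torch.Tensor) -> torch.Tensor:
    # Inverse of pack_int4, recovering the [..., K] nibble tensor (uint8).
    packed_bytes = packed.contiguous().view(torch.uint8)
    lo, hi = packed_bytes & 0xF, packed_bytes >> 4
    return torch.stack([lo, hi], dim=-1).reshape(*packed.shape[:-1], -1)

# Shape arithmetic matching the hunk above: w2 becomes [E, hidden, intermediate // 8].
E, hidden, intermediate = 2, 16, 32
w2_nibbles = torch.randint(0, 16, (E, hidden, intermediate), dtype=torch.uint8)
w2_packed = pack_int4(w2_nibbles)
assert w2_packed.shape == (E, hidden, intermediate // 8)
assert w2_packed.dtype == torch.int32
assert torch.equal(unpack_int4(w2_packed), w2_nibbles)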
@@ -538,7 +561,9 @@ class Fp8MoEMethod:
...
@@ -538,7 +561,9 @@ class Fp8MoEMethod:
layer
.
register_parameter
(
"w13_weight_scale"
,
w13_weight_scale
)
layer
.
register_parameter
(
"w13_weight_scale"
,
w13_weight_scale
)
layer
.
register_parameter
(
"w2_weight_scale"
,
w2_weight_scale
)
layer
.
register_parameter
(
"w2_weight_scale"
,
w2_weight_scale
)
if
is_hip_
and
get_bool_env_var
(
"CK_MOE"
):
if
(
is_hip_
):
# and get_bool_env_var("CK_MOE"): TODO: add check back after triton kernel
# ROCm - using column scaling, duplicate scaling numbers in case per tensor scaling
# ROCm - using column scaling, duplicate scaling numbers in case per tensor scaling
w13_weight_scale1
=
torch
.
nn
.
Parameter
(
w13_weight_scale1
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
num_experts
,
2
*
intermediate_size
,
dtype
=
torch
.
float32
),
torch
.
ones
(
num_experts
,
2
*
intermediate_size
,
dtype
=
torch
.
float32
),
...
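w13_weight_scale1 and w2_weight_scale1 hold one float32 scale per expert per output channel ("column scaling"), in addition to the per-tensor FP8 weight_scale. A minimal sketch of what that per-channel scaling means mathematically, on hypothetical tensors (the real asm_moe kernel applies it post-GEMM):

import torch

# Hedged sketch: per-output-channel scaling as used by w13_weight_scale1.
num_experts, out_features, in_features = 2, 8, 16
w_q = torch.randint(-8, 8, (num_experts, out_features, in_features)).float()  # fake quantized weights
scale1 = torch.rand(num_experts, out_features)  # one scale per expert per output channel

w_dequant = w_q * scale1.unsqueeze(-1)          # broadcast over the input dimension
x = torch.randn(in_features)
# Scaling the columns of W is equivalent to scaling the GEMM output per channel,
# which is why the kernel can apply scale1 after the matmul:
y1 = torch.einsum("eoi,i->eo", w_dequant, x)
y2 = torch.einsum("eoi,i->eo", w_q, x) * scale1
assert torch.allclose(y1, y2, atol=1e-5)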
@@ -565,6 +590,13 @@ class Fp8MoEMethod:
...
@@ -565,6 +590,13 @@ class Fp8MoEMethod:
set_weight_attrs
(
w13_weight_scale
,
extra_weight_attrs
)
set_weight_attrs
(
w13_weight_scale
,
extra_weight_attrs
)
set_weight_attrs
(
w2_weight_scale
,
extra_weight_attrs
)
set_weight_attrs
(
w2_weight_scale
,
extra_weight_attrs
)
if
get_bool_env_var
(
"USE_INT4_WEIGHT"
):
extra_weight_attrs
.
update
(
{
"quant_method"
:
FusedMoeWeightScaleSupported
.
CHANNEL
.
value
}
)
set_weight_attrs
(
w13_weight_scale1
,
extra_weight_attrs
)
set_weight_attrs
(
w2_weight_scale1
,
extra_weight_attrs
)
# INPUT_SCALES
# INPUT_SCALES
if
self
.
quant_config
.
activation_scheme
==
"static"
:
if
self
.
quant_config
.
activation_scheme
==
"static"
:
if
not
self
.
quant_config
.
is_checkpoint_fp8_serialized
:
if
not
self
.
quant_config
.
is_checkpoint_fp8_serialized
:
...
@@ -590,6 +622,53 @@ class Fp8MoEMethod:
...
@@ -590,6 +622,53 @@ class Fp8MoEMethod:
layer
.
w2_input_scale
=
None
layer
.
w2_input_scale
=
None
def
process_weights_after_loading
(
self
,
layer
:
Module
)
->
None
:
def
process_weights_after_loading
(
self
,
layer
:
Module
)
->
None
:
if
get_bool_env_var
(
"USE_INT4_WEIGHT"
):
# TODO: and get_bool_env_var("CK_MOE"): add after triton kernel added
# INT4-FP8 (INT4 MoE Weight, FP8 Compute)
# Weight Permutation
layer
.
w13_weight
=
torch
.
nn
.
Parameter
(
permute_weight
(
layer
.
w13_weight
.
data
),
requires_grad
=
False
,
)
torch
.
cuda
.
empty_cache
()
layer
.
w2_weight
=
torch
.
nn
.
Parameter
(
permute_weight
(
layer
.
w2_weight
.
data
),
requires_grad
=
False
,
)
torch
.
cuda
.
empty_cache
()
# INT4-FP8 : offset INT4 w13_weight_scale1 to single w13_weight_scale
# Fp8 moe kernel needs single fp8 w13_weight_scale for w13 per expert.
# We won't do requant each expert's fp8 weight (not direct available),
# instead we adjust half of INT4 w13_weight_scale1 numbers
assert
layer
.
w13_weight_scale
is
not
None
shard_size
=
layer
.
intermediate_size_per_partition
max_w13_scales
=
layer
.
w13_weight_scale
.
max
(
dim
=
1
).
values
for
expert_id
in
range
(
layer
.
num_experts
):
start
=
0
max_w13_scale_fp8
=
max_w13_scales
[
expert_id
]
for
shard_id
in
range
(
2
):
if
layer
.
w13_weight_scale
[
expert_id
][
shard_id
]
!=
max_w13_scale_fp8
:
int4_rescale
=
(
layer
.
w13_weight_scale
[
expert_id
][
shard_id
]
/
max_w13_scale_fp8
)
layer
.
w13_weight_scale1
[
expert_id
][
start
:
start
+
shard_size
]
*=
int4_rescale
start
+=
shard_size
layer
.
w13_weight_scale
=
torch
.
nn
.
Parameter
(
max_w13_scales
,
requires_grad
=
False
)
# special hack to asm_moe, which takes (weight_scale1 * weight_scale) as post GEMM scaling
# optimal design - shall apply per-column weight_scale1 before GEMM, and weight_scale post
for
expert_id
in
range
(
layer
.
num_experts
):
layer
.
w13_weight_scale1
[
expert_id
]
*=
max_w13_scales
[
expert_id
]
layer
.
w2_weight_scale1
[
expert_id
]
*=
layer
.
w2_weight_scale
[
expert_id
]
return
from
sglang.srt.layers.moe.fused_moe_triton.fused_moe
import
(
from
sglang.srt.layers.moe.fused_moe_triton.fused_moe
import
(
padding_size
,
# Avoid circular import
padding_size
,
# Avoid circular import
)
)
...
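The loop above merges the two per-shard FP8 scales of w13 (gate and up projections) into one per-expert scale, pushes the ratio into the per-channel INT4 scales, and then folds weight_scale into weight_scale1 because asm_moe consumes their product as the post-GEMM scale. A small numeric sketch of the same bookkeeping on made-up tensors:

import torch

# Hedged sketch of the scale folding in process_weights_after_loading,
# with 2 experts and shard_size = 4 output channels per shard.
num_experts, shard_size = 2, 4
w13_weight_scale = torch.tensor([[0.5, 1.0], [2.0, 2.0]])     # per-expert, per-shard FP8 scales
w13_weight_scale1 = torch.rand(num_experts, 2 * shard_size)   # per-channel INT4 scales
reference = w13_weight_scale1.clone()

max_w13_scales = w13_weight_scale.max(dim=1).values           # one FP8 scale per expert
for expert_id in range(num_experts):
    start = 0
    for shard_id in range(2):
        if w13_weight_scale[expert_id][shard_id] != max_w13_scales[expert_id]:
            rescale = w13_weight_scale[expert_id][shard_id] / max_w13_scales[expert_id]
            w13_weight_scale1[expert_id][start : start + shard_size] *= rescale
        start += shard_size

# Per channel, (adjusted scale1) * (merged fp8 scale) equals
# (original scale1) * (original shard scale): the effective factor is unchanged.
effective = w13_weight_scale1 * max_w13_scales.unsqueeze(1)
expected = reference * w13_weight_scale.repeat_interleave(shard_size, dim=1)
assert torch.allclose(effective, expected)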
@@ -823,8 +902,24 @@ class Fp8MoEMethod:
...
@@ -823,8 +902,24 @@ class Fp8MoEMethod:
correction_bias
=
correction_bias
,
correction_bias
=
correction_bias
,
)
)
if
is_hip_
and
get_bool_env_var
(
"CK_MOE"
)
and
activation
==
"silu"
:
if
is_hip_
and
get_bool_env_var
(
"USE_INT4_WEIGHT"
):
# TODO: add triton kernel and add check get_bool_env_var("CK_MOE")
assert
not
no_combine
,
f
"
{
no_combine
=
}
is not supported."
return
asm_moe
(
x
,
layer
.
w13_weight
,
layer
.
w2_weight
,
topk_weights
,
topk_ids
,
layer
.
w13_weight_scale1
,
layer
.
w2_weight_scale1
,
activation
=
activation
,
)
if
is_hip_
and
get_bool_env_var
(
"CK_MOE"
):
# TODO(CK_MOE): FP8 or FP8 block_quant only supports 'silu' for the time-being.
# TODO(CK_MOE): FP8 or FP8 block_quant only supports 'silu' for the time-being.
assert
(
activation
==
"silu"
),
f
"CK_MOE: FP8 and/or FP8 bloack_quant
{
activation
=
}
will be supported later, unset CK_MOE"
assert
not
no_combine
,
f
"
{
no_combine
=
}
is not supported."
assert
not
no_combine
,
f
"
{
no_combine
=
}
is not supported."
if
self
.
block_quant
:
if
self
.
block_quant
:
return
asm_moe
(
return
asm_moe
(
...
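The whole INT4-FP8 path is opt-in and gated by environment variables read through get_bool_env_var. A hedged sketch of enabling it from Python before the server starts; the exact truthy values accepted depend on get_bool_env_var in sglang.srt.utils:

import os

# Hedged sketch: the gates must be set before the quantization layers are
# imported/initialized, since they are read at weight-creation and dispatch time.
os.environ["USE_INT4_WEIGHT"] = "1"   # enable INT4 MoE weights with FP8 compute
# os.environ["CK_MOE"] = "1"          # per the TODOs above, the CK_MOE check is
#                                     # temporarily relaxed until a triton kernel lands

# ... then launch sglang as usual, e.g. `python -m sglang.launch_server ...`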
@@ -835,10 +930,6 @@ class Fp8MoEMethod:
...
@@ -835,10 +930,6 @@ class Fp8MoEMethod:
topk_ids
,
topk_ids
,
layer
.
w13_weight_scale_inv
,
layer
.
w13_weight_scale_inv
,
layer
.
w2_weight_scale_inv
,
layer
.
w2_weight_scale_inv
,
None
,
None
,
False
,
None
,
block_shape
=
tuple
(
self
.
quant_config
.
weight_block_size
),
block_shape
=
tuple
(
self
.
quant_config
.
weight_block_size
),
expert_mask
=
None
,
expert_mask
=
None
,
)
)
...
@@ -851,9 +942,6 @@ class Fp8MoEMethod:
...
@@ -851,9 +942,6 @@ class Fp8MoEMethod:
topk_ids
,
topk_ids
,
layer
.
w13_weight_scale1
,
layer
.
w13_weight_scale1
,
layer
.
w2_weight_scale1
,
layer
.
w2_weight_scale1
,
None
,
None
,
False
,
)
)
else
:
else
:
# Expert fusion with FP8 quantization
# Expert fusion with FP8 quantization
...
...
python/sglang/srt/utils.py
@@ -1269,7 +1269,8 @@ def permute_weight(x: torch.Tensor) -> torch.Tensor:
     elif x.dtype == torch.float8_e4m3fnuz or x.dtype == torch.int8:
         x_ = x_.view(int(b_), int(n_ / 16), 16, int(k_ / 64), 4, 16)
     else:
-        return x_
+        # return x_
+        x_ = x_.view(int(b_), int(n_ / 16), 16, int(k_ / 8), 2, 4)

     x_ = x_.permute(0, 1, 3, 4, 2, 5)
     x_ = x_.contiguous()
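With this change, dtypes other than fp8/int8 (notably the int32-packed INT4 weights) no longer fall through permute_weight untouched; they get a (16, k/8, 2, 4) tiling instead. A self-contained sketch of that view/permute on a dummy tensor, just to show it is a pure layout shuffle (the real function's surrounding reshapes are not part of this hunk):

import torch

# Hedged sketch: the tiling applied by permute_weight's new else-branch,
# reproduced on a dummy tensor. b_, n_, k_ correspond to the dimensions of the
# packed weight; n_ must be divisible by 16 and k_ by 8.
b_, n_, k_ = 1, 32, 16
x_ = torch.arange(b_ * n_ * k_, dtype=torch.int32).view(b_, n_, k_)

tiled = x_.view(int(b_), int(n_ / 16), 16, int(k_ / 8), 2, 4)
tiled = tiled.permute(0, 1, 3, 4, 2, 5)   # same permutation as in the diff
tiled = tiled.contiguous()

# Pure layout change: same elements, same count, different memory order.
assert tiled.numel() == x_.numel()
assert torch.equal(tiled.reshape(-1).sort().values, x_.reshape(-1))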