Commit e8a69e4d (unverified), authored Mar 09, 2025 by Lianmin Zheng, committed via GitHub on Mar 09, 2025

Clean up fp8 support (#4230)

Parent: fbd56002
Changes: 5 changed files with 86 additions and 110 deletions (+86, -110)

.github/workflows/pr-test-amd.yml                       +1   -0
python/sglang/srt/layers/quantization/fp8.py            +80  -101
python/sglang/srt/layers/quantization/fp8_utils.py      +4   -7
python/sglang/srt/layers/vocab_parallel_embedding.py    +0   -1
test/srt/models/test_qwen_models.py                     +1   -1
.github/workflows/pr-test-amd.yml (view file @ e8a69e4d)

@@ -55,6 +55,7 @@ jobs:
       timeout-minutes: 20
       run: |
         docker exec -w /sglang-checkout/test/srt ci_sglang python3 test_eval_accuracy_large.py
+        docker exec -w /sglang-checkout/test/srt ci_sglang python3 models/test_qwen_models.py

   mla-test-1-gpu-amd:
     if: github.event.pull_request.head.repo.fork == false && github.event.pull_request.draft == false
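The added CI step runs the Qwen model tests inside the ci_sglang container. A rough local equivalent, not part of the diff, assuming an sglang checkout and a suitable GPU:

# Hypothetical local equivalent of the new CI step; CI runs the same script
# inside the ci_sglang Docker container. The cwd is relative to the checkout root.
import subprocess

subprocess.run(
    ["python3", "models/test_qwen_models.py"],
    cwd="test/srt",
    check=True,
)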
python/sglang/srt/layers/quantization/fp8.py (view file @ e8a69e4d)

@@ -347,7 +347,7 @@ class Fp8LinearMethod(LinearMethodBase):
             weight = layer.weight
             weight_scale = layer.weight_scale
             # If ROCm, normalize the weights and scales to e4m3fnuz
-            if is_hip():
+            if is_hip_:
                 weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
                     weight=weight,
                     weight_scale=weight_scale,
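This hunk swaps the per-call is_hip() check for the module-level is_hip_ flag, matching the pattern fp8_utils.py already uses. A minimal sketch of that caching pattern, not part of the diff, with a stand-in is_hip() since the real helper lives in sglang.srt.utils:

# Sketch of caching a platform check once at import time; the real is_hip()
# lives in sglang.srt.utils, this stand-in only illustrates the pattern.
import torch

def is_hip() -> bool:
    # True when PyTorch was built for ROCm (AMD GPUs).
    return torch.version.hip is not None

is_hip_ = is_hip()  # evaluated once, reused on every weight-processing call

if is_hip_:
    pass  # e.g. normalize fp8 weights/scales to e4m3fnuz for MI300-class hardware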
@@ -624,56 +624,9 @@ class Fp8MoEMethod:
     def process_weights_after_loading(self, layer: Module) -> None:
         if get_bool_env_var("USE_INT4_WEIGHT"):
-            # TODO: and get_bool_env_var("CK_MOE"): add after triton kernel added
-            # INT4-FP8 (INT4 MoE Weight, FP8 Compute)
-            # Weight Permutation
-            layer.w13_weight = torch.nn.Parameter(
-                permute_weight(layer.w13_weight.data),
-                requires_grad=False,
-            )
-            torch.cuda.empty_cache()
-            layer.w2_weight = torch.nn.Parameter(
-                permute_weight(layer.w2_weight.data),
-                requires_grad=False,
-            )
-            torch.cuda.empty_cache()
-
-            # INT4-FP8 : offset INT4 w13_weight_scale1 to single w13_weight_scale
-            # Fp8 moe kernel needs single fp8 w13_weight_scale for w13 per expert.
-            # We won't do requant each expert's fp8 weight (not direct available),
-            # instead we adjust half of INT4 w13_weight_scale1 numbers
-            assert layer.w13_weight_scale is not None
-            shard_size = layer.intermediate_size_per_partition
-            max_w13_scales = layer.w13_weight_scale.max(dim=1).values
-            for expert_id in range(layer.num_experts):
-                start = 0
-                max_w13_scale_fp8 = max_w13_scales[expert_id]
-                for shard_id in range(2):
-                    if layer.w13_weight_scale[expert_id][shard_id] != max_w13_scale_fp8:
-                        int4_rescale = (
-                            layer.w13_weight_scale[expert_id][shard_id] / max_w13_scale_fp8
-                        )
-                        layer.w13_weight_scale1[expert_id][
-                            start : start + shard_size
-                        ] *= int4_rescale
-                    start += shard_size
-
-            layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales, requires_grad=False)
-
-            # special hack to asm_moe, which takes (weight_scale1 * weight_scale) as post GEMM scaling
-            # optimal design - shall apply per-column weight_scale1 before GEMM, and weight_scale post
-            for expert_id in range(layer.num_experts):
-                layer.w13_weight_scale1[expert_id] *= max_w13_scales[expert_id]
-                layer.w2_weight_scale1[expert_id] *= layer.w2_weight_scale[expert_id]
+            self.process_weights_hip_int4(layer)
             return

-        from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
-            padding_size,  # Avoid circular import
-        )
-
         # Block quant doesn't need to process weights after loading
         if self.block_quant:
             # If ROCm, normalize the weights and scales to e4m3fnuz
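The INT4-FP8 scale handling removed here moves into the new process_weights_hip_int4 helper (reconstructed in a later hunk). A self-contained toy sketch of the arithmetic it performs, with made-up shapes rather than the real layer tensors:

# Toy sketch of folding the two per-shard fp8 scales of w13 into one
# per-expert scale, compensating through the per-column INT4 scale1 tensor.
# Shapes are illustrative; the real tensors live on the MoE layer.
import torch

num_experts, shard_size = 2, 4
w13_weight_scale = torch.tensor([[0.5, 1.0], [2.0, 2.0]])    # (num_experts, 2 shards)
w13_weight_scale1 = torch.ones(num_experts, 2 * shard_size)  # per-column INT4 scales

max_w13_scales = w13_weight_scale.max(dim=1).values          # single fp8 scale per expert
for expert_id in range(num_experts):
    start = 0
    for shard_id in range(2):
        if w13_weight_scale[expert_id][shard_id] != max_w13_scales[expert_id]:
            rescale = w13_weight_scale[expert_id][shard_id] / max_w13_scales[expert_id]
            w13_weight_scale1[expert_id][start : start + shard_size] *= rescale
        start += shard_size

print(w13_weight_scale1)  # columns of expert 0, shard 0 are now scaled by 0.5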
@@ -710,6 +663,7 @@ class Fp8MoEMethod:
                     layer.w2_weight.contiguous(), (16, 16)
                 )
             )
             return

         # If checkpoint is fp16 or bfloat16, quantize in place.
         if not self.quant_config.is_checkpoint_fp8_serialized:
             # If ROCm, use float8_e4m3fnuz instead (MI300x HW)
@@ -736,32 +690,7 @@ class Fp8MoEMethod:
             layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
             if is_hip_:
-                if get_bool_env_var("CK_MOE"):
-                    layer.w13_weight = torch.nn.Parameter(
-                        permute_weight(layer.w13_weight.data),
-                        requires_grad=False,
-                    )
-                    torch.cuda.empty_cache()
-                    layer.w2_weight = torch.nn.Parameter(
-                        permute_weight(layer.w2_weight.data),
-                        requires_grad=False,
-                    )
-                    torch.cuda.empty_cache()
-                    # ROCm (CK_MOE): using column-wise scaling
-                    layer.w13_weight_scale1 *= layer.w13_weight_scale.unsqueeze(-1)
-                    layer.w2_weight_scale1 *= layer.w2_weight_scale.unsqueeze(-1)
-                elif get_bool_env_var("MOE_PADDING"):
-                    # If ROCm, apply weight padding (min. Mem channel contention) only if set
-                    layer.w13_weight = torch.nn.Parameter(
-                        F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
-                        requires_grad=False,
-                    )
-                    torch.cuda.empty_cache()
-                    layer.w2_weight = torch.nn.Parameter(
-                        F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
-                        requires_grad=False,
-                    )
-                    torch.cuda.empty_cache()
+                self.process_weights_hip_scale_padding(layer)
             return

         # If checkpoint is fp8, we need to handle that the
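The MOE_PADDING branch removed above (now inside process_weights_hip_scale_padding) zero-pads the last weight dimension. A short sketch of the shape effect, not part of the commit, with an arbitrary padding_size and a made-up weight shape; the real padding_size is imported from sglang.srt.layers.moe.fused_moe_triton.fused_moe:

# Sketch of the MOE_PADDING path: zero-pad the last dim of an expert weight.
# padding_size and the weight shape below are illustrative only.
import torch
import torch.nn.functional as F

padding_size = 128
w2_weight = torch.randn(8, 4096, 1408)  # (num_experts, hidden, intermediate), made up

padded = torch.nn.Parameter(
    F.pad(w2_weight, (0, padding_size), "constant", 0),  # pads only the last dimension
    requires_grad=False,
)
print(padded.shape)  # torch.Size([8, 4096, 1536])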
@@ -843,34 +772,84 @@ class Fp8MoEMethod:
             )

             if is_hip_:
-                if get_bool_env_var("CK_MOE"):
-                    layer.w13_weight = torch.nn.Parameter(
-                        permute_weight(layer.w13_weight.data),
-                        requires_grad=False,
-                    )
-                    torch.cuda.empty_cache()
-                    layer.w2_weight = torch.nn.Parameter(
-                        permute_weight(layer.w2_weight.data),
-                        requires_grad=False,
-                    )
-                    torch.cuda.empty_cache()
-                    # ROCm (CK_MOE): using column-wise scaling
-                    layer.w13_weight_scale1 *= layer.w13_weight_scale.unsqueeze(-1)
-                    layer.w2_weight_scale1 *= layer.w2_weight_scale.unsqueeze(-1)
-                elif get_bool_env_var("MOE_PADDING"):
-                    # If ROCm, apply weight padding (min. Mem channel contention) only if set
-                    layer.w13_weight = torch.nn.Parameter(
-                        F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
-                        requires_grad=False,
-                    )
-                    torch.cuda.empty_cache()
-                    layer.w2_weight = torch.nn.Parameter(
-                        F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
-                        requires_grad=False,
-                    )
-                    torch.cuda.empty_cache()
+                self.process_weights_hip_scale_padding(layer)
             return

+    def process_weights_hip_int4(self, layer: Module):
+        # TODO: and get_bool_env_var("CK_MOE"): add after triton kernel added
+        # INT4-FP8 (INT4 MoE Weight, FP8 Compute)
+        # Weight Permutation
+        layer.w13_weight = torch.nn.Parameter(
+            permute_weight(layer.w13_weight.data),
+            requires_grad=False,
+        )
+        torch.cuda.empty_cache()
+        layer.w2_weight = torch.nn.Parameter(
+            permute_weight(layer.w2_weight.data),
+            requires_grad=False,
+        )
+        torch.cuda.empty_cache()
+
+        # INT4-FP8 : offset INT4 w13_weight_scale1 to single w13_weight_scale
+        # Fp8 moe kernel needs single fp8 w13_weight_scale for w13 per expert.
+        # We won't do requant each expert's fp8 weight (not direct available),
+        # instead we adjust half of INT4 w13_weight_scale1 numbers
+        assert layer.w13_weight_scale is not None
+        shard_size = layer.intermediate_size_per_partition
+        max_w13_scales = layer.w13_weight_scale.max(dim=1).values
+        for expert_id in range(layer.num_experts):
+            start = 0
+            max_w13_scale_fp8 = max_w13_scales[expert_id]
+            for shard_id in range(2):
+                if layer.w13_weight_scale[expert_id][shard_id] != max_w13_scale_fp8:
+                    int4_rescale = (
+                        layer.w13_weight_scale[expert_id][shard_id] / max_w13_scale_fp8
+                    )
+                    layer.w13_weight_scale1[expert_id][
+                        start : start + shard_size
+                    ] *= int4_rescale
+                start += shard_size
+
+        layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales, requires_grad=False)
+
+        # special hack to asm_moe, which takes (weight_scale1 * weight_scale) as post GEMM scaling
+        # optimal design - shall apply per-column weight_scale1 before GEMM, and weight_scale post
+        for expert_id in range(layer.num_experts):
+            layer.w13_weight_scale1[expert_id] *= max_w13_scales[expert_id]
+            layer.w2_weight_scale1[expert_id] *= layer.w2_weight_scale[expert_id]
+
+    def process_weights_hip_scale_padding(self, layer: Module, padding_size: int):
+        from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
+            padding_size,  # Avoid circular import
+        )
+
+        if get_bool_env_var("CK_MOE"):
+            layer.w13_weight = torch.nn.Parameter(
+                permute_weight(layer.w13_weight.data),
+                requires_grad=False,
+            )
+            torch.cuda.empty_cache()
+            layer.w2_weight = torch.nn.Parameter(
+                permute_weight(layer.w2_weight.data),
+                requires_grad=False,
+            )
+            torch.cuda.empty_cache()
+            # ROCm (CK_MOE): using column-wise scaling
+            layer.w13_weight_scale1 *= layer.w13_weight_scale.unsqueeze(-1)
+            layer.w2_weight_scale1 *= layer.w2_weight_scale.unsqueeze(-1)
+        elif get_bool_env_var("MOE_PADDING"):
+            # If ROCm, apply weight padding (min. Mem channel contention) only if set
+            layer.w13_weight = torch.nn.Parameter(
+                F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
+                requires_grad=False,
+            )
+            torch.cuda.empty_cache()
+            layer.w2_weight = torch.nn.Parameter(
+                F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
+                requires_grad=False,
+            )
+            torch.cuda.empty_cache()
+
     def apply(
         self,
         layer: torch.nn.Module,
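Inside the new process_weights_hip_scale_padding helper, the CK_MOE branch folds the per-expert fp8 weight scale into the per-column scale1 by broadcasting. A minimal shape-level sketch, not part of the commit, with illustrative dimensions:

# Sketch of the CK_MOE column-wise scale folding: broadcast one scalar scale
# per expert across that expert's per-column scale1 entries.
import torch

num_experts, num_cols = 4, 16
w13_weight_scale1 = torch.ones(num_experts, num_cols)  # per-expert, per-column
w13_weight_scale = torch.rand(num_experts)             # one scalar per expert

# unsqueeze(-1): (num_experts,) -> (num_experts, 1), so it broadcasts over columns
w13_weight_scale1 *= w13_weight_scale.unsqueeze(-1)
print(w13_weight_scale1[0])  # every column of expert 0 now equals w13_weight_scale[0]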
python/sglang/srt/layers/quantization/fp8_utils.py (view file @ e8a69e4d)

-import os
 from typing import List, Optional, Tuple

 import torch
 from packaging.version import Version

 from sglang.srt.layers.quantization.fp8_kernel import (
     per_token_group_quant_fp8,

@@ -13,18 +11,17 @@ from sglang.srt.utils import (
     get_bool_env_var,
     get_cuda_version,
     get_device_capability,
+    is_cuda,
     is_hip,
 )

-use_vllm_cutlass_w8a8_fp8_kernel = os.environ.get("USE_VLLM_CUTLASS_W8A8_FP8_KERNEL")
+use_vllm_cutlass_w8a8_fp8_kernel = get_bool_env_var("USE_VLLM_CUTLASS_W8A8_FP8_KERNEL", default=False)

 is_hip_ = is_hip()
 if is_hip_ and get_bool_env_var("CK_MOE"):
     from aiter import gemm_a8w8_blockscale

-_is_cuda = torch.cuda.is_available() and torch.version.cuda
+_is_cuda = is_cuda()
 if _is_cuda:
     from sgl_kernel import fp8_blockwise_scaled_mm

@@ -73,7 +70,7 @@ def normalize_e4m3fn_to_e4m3fnuz(

 def cutlass_block_fp8_supported() -> bool:
-    if os.environ.get("SUPPORT_CUTLASS_BLOCK_FP8") is None:
+    if get_bool_env_var("SUPPORT_CUTLASS_BLOCK_FP8"):
         return False
     if _is_cuda:
         major, minor = torch.cuda.get_device_capability()
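The direct os.environ.get checks above are replaced with get_bool_env_var from sglang.srt.utils, whose implementation is not part of this diff. A plausible sketch of such a helper, labeled as an assumption since the real one may accept different truthy values and defaults:

# Hypothetical get_bool_env_var; the real helper lives in sglang.srt.utils
# and may differ. Shown only to illustrate why os.environ.get was replaced.
import os

def get_bool_env_var(name: str, default="false") -> bool:
    value = os.environ.get(name, default)
    if isinstance(value, bool):
        return value
    return str(value).strip().lower() in ("1", "true", "yes", "on")

# Usage mirroring the diff: unset or false-y values collapse to False.
use_vllm_cutlass_w8a8_fp8_kernel = get_bool_env_var(
    "USE_VLLM_CUTLASS_W8A8_FP8_KERNEL", default=False
)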
python/sglang/srt/layers/vocab_parallel_embedding.py (view file @ e8a69e4d)

@@ -264,7 +264,6 @@ class VocabParallelEmbedding(torch.nn.Module):
         quant_method = None
         if quant_config is not None:
             quant_method = quant_config.get_quant_method(self, prefix=prefix)
-            print("quant_method", quant_method)
         if quant_method is None:
             quant_method = UnquantizedEmbeddingMethod()
test/srt/models/test_qwen_models.py (view file @ e8a69e4d)

@@ -69,7 +69,7 @@ class TestQwen2FP8(unittest.TestCase):
         )

         metrics = run_eval(args)
         print(f"{metrics=}")
-        self.assertGreater(metrics["accuracy"], 0.79)
+        self.assertGreater(metrics["accuracy"], 0.78)


 if __name__ == "__main__":