zhaoyu6 / sglang · Commits
"vscode:/vscode.git/clone" did not exist on "056859fa1f671a5c4982330cd064e8cfd0ad0308"
Commit e8a69e4d (Unverified)
Authored Mar 09, 2025 by Lianmin Zheng; committed via GitHub on Mar 09, 2025

Clean up fp8 support (#4230)

Parent: fbd56002
Showing 5 changed files with 86 additions and 110 deletions (+86 −110)
.github/workflows/pr-test-amd.yml                      +1   −0
python/sglang/srt/layers/quantization/fp8.py           +80  −101
python/sglang/srt/layers/quantization/fp8_utils.py     +4   −7
python/sglang/srt/layers/vocab_parallel_embedding.py   +0   −1
test/srt/models/test_qwen_models.py                    +1   −1
.github/workflows/pr-test-amd.yml

@@ -55,6 +55,7 @@ jobs:
         timeout-minutes: 20
         run: |
           docker exec -w /sglang-checkout/test/srt ci_sglang python3 test_eval_accuracy_large.py
+          docker exec -w /sglang-checkout/test/srt ci_sglang python3 models/test_qwen_models.py

   mla-test-1-gpu-amd:
     if: github.event.pull_request.head.repo.fork == false && github.event.pull_request.draft == false
...
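Note: the added CI step can be reproduced outside the workflow. A minimal sketch in Python, assuming Docker is installed and a container named ci_sglang is already running with the repository checked out at /sglang-checkout (both assumptions are taken from the workflow above, not from this commit):

# Hypothetical local reproduction of the new CI step; assumes a running
# container named "ci_sglang" with the repo mounted at /sglang-checkout.
import subprocess

def run_qwen_model_tests() -> None:
    # Mirrors: docker exec -w /sglang-checkout/test/srt ci_sglang python3 models/test_qwen_models.py
    subprocess.run(
        [
            "docker", "exec",
            "-w", "/sglang-checkout/test/srt",
            "ci_sglang",
            "python3", "models/test_qwen_models.py",
        ],
        check=True,  # raise if the test process exits non-zero
    )

if __name__ == "__main__":
    run_qwen_model_tests()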
python/sglang/srt/layers/quantization/fp8.py

@@ -347,7 +347,7 @@ class Fp8LinearMethod(LinearMethodBase):
             weight = layer.weight
             weight_scale = layer.weight_scale

             # If ROCm, normalize the weights and scales to e4m3fnuz
-            if is_hip():
+            if is_hip_:
                 weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
                     weight=weight,
                     weight_scale=weight_scale,
...
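Note on the normalization called above: e4m3fn and e4m3fnuz share a bit layout but differ in exponent bias, so the same bits decode to half the value under e4m3fnuz, and fnuz reuses the e4m3fn negative-zero pattern for NaN. A minimal illustrative sketch of such a conversion follows; it is not the sglang implementation (the real helper lives in sglang/srt/layers/quantization/fp8_utils.py) and it assumes a PyTorch build that provides the float8 dtypes:

# Illustrative sketch only, not the sglang helper. Assumes torch >= 2.1 with
# float8_e4m3fn / float8_e4m3fnuz dtypes available.
import torch

def normalize_e4m3fn_to_e4m3fnuz_sketch(weight, weight_scale, input_scale=None):
    assert weight.dtype == torch.float8_e4m3fn
    # Bit pattern 0b1000_0000 is -0.0 in e4m3fn but NaN in e4m3fnuz,
    # so clear it before reinterpreting the storage.
    as_int8 = weight.view(torch.int8)
    as_int8[as_int8 == -128] = 0
    weight = as_int8.view(torch.float8_e4m3fnuz)
    # The same bits decode to half the value under e4m3fnuz (exponent bias 8
    # vs. 7), so double the scales to keep dequantized values unchanged.
    weight_scale = weight_scale * 2.0
    if input_scale is not None:
        input_scale = input_scale * 2.0
    return weight, weight_scale, input_scale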
@@ -624,56 +624,9 @@ class Fp8MoEMethod:
     def process_weights_after_loading(self, layer: Module) -> None:
         if get_bool_env_var("USE_INT4_WEIGHT"):
-            # TODO: and get_bool_env_var("CK_MOE"): add after triton kernel added
-            # INT4-FP8 (INT4 MoE Weight, FP8 Compute)
-            # Weight Permutation
-            layer.w13_weight = torch.nn.Parameter(
-                permute_weight(layer.w13_weight.data),
-                requires_grad=False,
-            )
-            torch.cuda.empty_cache()
-            layer.w2_weight = torch.nn.Parameter(
-                permute_weight(layer.w2_weight.data),
-                requires_grad=False,
-            )
-            torch.cuda.empty_cache()
-
-            # INT4-FP8 : offset INT4 w13_weight_scale1 to single w13_weight_scale
-            # Fp8 moe kernel needs single fp8 w13_weight_scale for w13 per expert.
-            # We won't do requant each expert's fp8 weight (not direct available),
-            # instead we adjust half of INT4 w13_weight_scale1 numbers
-            assert layer.w13_weight_scale is not None
-            shard_size = layer.intermediate_size_per_partition
-            max_w13_scales = layer.w13_weight_scale.max(dim=1).values
-            for expert_id in range(layer.num_experts):
-                start = 0
-                max_w13_scale_fp8 = max_w13_scales[expert_id]
-                for shard_id in range(2):
-                    if layer.w13_weight_scale[expert_id][shard_id] != max_w13_scale_fp8:
-                        int4_rescale = (
-                            layer.w13_weight_scale[expert_id][shard_id] / max_w13_scale_fp8
-                        )
-                        layer.w13_weight_scale1[expert_id][
-                            start : start + shard_size
-                        ] *= int4_rescale
-                    start += shard_size
-
-            layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales, requires_grad=False)
-
-            # special hack to asm_moe, which takes (weight_scale1 * weight_scale) as post GEMM scaling
-            # optimal design - shall apply per-column weight_scale1 before GEMM, and weight_scale post
-            for expert_id in range(layer.num_experts):
-                layer.w13_weight_scale1[expert_id] *= max_w13_scales[expert_id]
-                layer.w2_weight_scale1[expert_id] *= layer.w2_weight_scale[expert_id]
+            self.process_weights_hip_int4(layer)
             return

-        from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
-            padding_size,  # Avoid circular import
-        )
-
         # Block quant doesn't need to process weights after loading
         if self.block_quant:
             # If ROCm, normalize the weights and scales to e4m3fnuz
...
@@ -710,6 +663,7 @@ class Fp8MoEMethod:
                     layer.w2_weight.contiguous(), (16, 16)
                 )
             return

         # If checkpoint is fp16 or bfloat16, quantize in place.
         if not self.quant_config.is_checkpoint_fp8_serialized:
             # If ROCm, use float8_e4m3fnuz instead (MI300x HW)
...
@@ -736,32 +690,7 @@ class Fp8MoEMethod:
             layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
             if is_hip_:
-                if get_bool_env_var("CK_MOE"):
-                    layer.w13_weight = torch.nn.Parameter(
-                        permute_weight(layer.w13_weight.data),
-                        requires_grad=False,
-                    )
-                    torch.cuda.empty_cache()
-                    layer.w2_weight = torch.nn.Parameter(
-                        permute_weight(layer.w2_weight.data),
-                        requires_grad=False,
-                    )
-                    torch.cuda.empty_cache()
-                    # ROCm (CK_MOE): using column-wise scaling
-                    layer.w13_weight_scale1 *= layer.w13_weight_scale.unsqueeze(-1)
-                    layer.w2_weight_scale1 *= layer.w2_weight_scale.unsqueeze(-1)
-                elif get_bool_env_var("MOE_PADDING"):
-                    # If ROCm, apply weight padding (min. Mem channel contention) only if set
-                    layer.w13_weight = torch.nn.Parameter(
-                        F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
-                        requires_grad=False,
-                    )
-                    torch.cuda.empty_cache()
-                    layer.w2_weight = torch.nn.Parameter(
-                        F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
-                        requires_grad=False,
-                    )
-                    torch.cuda.empty_cache()
+                self.process_weights_hip_scale_padding(layer)
             return

         # If checkpoint is fp8, we need to handle that the
...
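Note on the padding branch removed above (now housed in process_weights_hip_scale_padding): F.pad with a (0, padding_size) spec appends zero columns to the last dimension only. A small self-contained check, with toy shapes and an assumed padding_size value:

# Small check of the F.pad call pattern used above: (0, padding_size) pads only
# the last dimension, on the right, with zeros. Shapes and value are made up.
import torch
import torch.nn.functional as F

padding_size = 8                 # illustrative; the real value comes from fused_moe_triton
w = torch.randn(2, 16, 64)       # (num_experts, out_dim, in_dim), toy shape
w_padded = F.pad(w, (0, padding_size), "constant", 0)

assert w_padded.shape == (2, 16, 64 + padding_size)
assert torch.all(w_padded[..., -padding_size:] == 0)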
@@ -843,6 +772,57 @@ class Fp8MoEMethod:
             )
         if is_hip_:
+            self.process_weights_hip_scale_padding(layer)
+        return
+
+    def process_weights_hip_int4(self, layer: Module):
+        # TODO: and get_bool_env_var("CK_MOE"): add after triton kernel added
+        # INT4-FP8 (INT4 MoE Weight, FP8 Compute)
+        # Weight Permutation
+        layer.w13_weight = torch.nn.Parameter(
+            permute_weight(layer.w13_weight.data),
+            requires_grad=False,
+        )
+        torch.cuda.empty_cache()
+        layer.w2_weight = torch.nn.Parameter(
+            permute_weight(layer.w2_weight.data),
+            requires_grad=False,
+        )
+        torch.cuda.empty_cache()
+
+        # INT4-FP8 : offset INT4 w13_weight_scale1 to single w13_weight_scale
+        # Fp8 moe kernel needs single fp8 w13_weight_scale for w13 per expert.
+        # We won't do requant each expert's fp8 weight (not direct available),
+        # instead we adjust half of INT4 w13_weight_scale1 numbers
+        assert layer.w13_weight_scale is not None
+        shard_size = layer.intermediate_size_per_partition
+        max_w13_scales = layer.w13_weight_scale.max(dim=1).values
+        for expert_id in range(layer.num_experts):
+            start = 0
+            max_w13_scale_fp8 = max_w13_scales[expert_id]
+            for shard_id in range(2):
+                if layer.w13_weight_scale[expert_id][shard_id] != max_w13_scale_fp8:
+                    int4_rescale = (
+                        layer.w13_weight_scale[expert_id][shard_id] / max_w13_scale_fp8
+                    )
+                    layer.w13_weight_scale1[expert_id][
+                        start : start + shard_size
+                    ] *= int4_rescale
+                start += shard_size
+
+        layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales, requires_grad=False)
+
+        # special hack to asm_moe, which takes (weight_scale1 * weight_scale) as post GEMM scaling
+        # optimal design - shall apply per-column weight_scale1 before GEMM, and weight_scale post
+        for expert_id in range(layer.num_experts):
+            layer.w13_weight_scale1[expert_id] *= max_w13_scales[expert_id]
+            layer.w2_weight_scale1[expert_id] *= layer.w2_weight_scale[expert_id]
+
+    def process_weights_hip_scale_padding(self, layer: Module, padding_size: int):
+        from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
+            padding_size,  # Avoid circular import
+        )
+
+        if get_bool_env_var("CK_MOE"):
+            layer.w13_weight = torch.nn.Parameter(
+                permute_weight(layer.w13_weight.data),
...
@@ -869,7 +849,6 @@ class Fp8MoEMethod:
                 requires_grad=False,
             )
             torch.cuda.empty_cache()
-        return

     def apply(
         self,
...
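Note on the scale folding in process_weights_hip_int4 above: each expert's two per-shard w13 scales are collapsed into their maximum, and the ratio is pushed into the per-column weight_scale1 entries so the combined per-column scaling is unchanged. A standalone sketch on toy tensors (shapes and values are stand-ins, not the layer's real attributes):

# Toy reproduction of the per-expert scale folding shown above. All tensors are
# made up; the real ones live on the MoE layer (w13_weight_scale, w13_weight_scale1).
import torch

num_experts, shard_size = 4, 6
w13_weight_scale = torch.rand(num_experts, 2) + 0.5           # one fp8 scale per shard (w1, w3)
w13_weight_scale1 = torch.rand(num_experts, 2 * shard_size) + 0.5  # per-column INT4 scales
orig_scale1 = w13_weight_scale1.clone()

# Fold the two per-shard scales into one per-expert scale (the max), pushing the
# ratio into the per-column scales.
max_w13_scales = w13_weight_scale.max(dim=1).values
for expert_id in range(num_experts):
    start = 0
    for shard_id in range(2):
        if w13_weight_scale[expert_id][shard_id] != max_w13_scales[expert_id]:
            rescale = w13_weight_scale[expert_id][shard_id] / max_w13_scales[expert_id]
            w13_weight_scale1[expert_id][start : start + shard_size] *= rescale
        start += shard_size

# The combined scaling per column is unchanged by the folding.
for expert_id in range(num_experts):
    for shard_id in range(2):
        cols = slice(shard_id * shard_size, (shard_id + 1) * shard_size)
        before = orig_scale1[expert_id][cols] * w13_weight_scale[expert_id][shard_id]
        after = w13_weight_scale1[expert_id][cols] * max_w13_scales[expert_id]
        assert torch.allclose(before, after)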
python/sglang/srt/layers/quantization/fp8_utils.py

-import os
 from typing import List, Optional, Tuple

 import torch
-from packaging.version import Version

 from sglang.srt.layers.quantization.fp8_kernel import (
     per_token_group_quant_fp8,
...
@@ -13,18 +11,17 @@ from sglang.srt.utils import (
     get_bool_env_var,
     get_cuda_version,
     get_device_capability,
+    is_cuda,
     is_hip,
 )

-use_vllm_cutlass_w8a8_fp8_kernel = os.environ.get(
-    "USE_VLLM_CUTLASS_W8A8_FP8_KERNEL", default=False
-)
+use_vllm_cutlass_w8a8_fp8_kernel = get_bool_env_var("USE_VLLM_CUTLASS_W8A8_FP8_KERNEL")

 is_hip_ = is_hip()
 if is_hip_ and get_bool_env_var("CK_MOE"):
     from aiter import gemm_a8w8_blockscale

-_is_cuda = torch.cuda.is_available() and torch.version.cuda
+_is_cuda = is_cuda()

 if _is_cuda:
     from sgl_kernel import fp8_blockwise_scaled_mm
...
@@ -73,7 +70,7 @@ def normalize_e4m3fn_to_e4m3fnuz(

 def cutlass_block_fp8_supported() -> bool:
-    if os.environ.get("SUPPORT_CUTLASS_BLOCK_FP8") is None:
+    if not get_bool_env_var("SUPPORT_CUTLASS_BLOCK_FP8"):
         return False
     if _is_cuda:
         major, minor = torch.cuda.get_device_capability()
...
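Note on the switch to get_bool_env_var above: os.environ.get returns raw strings, so a value like "0" or "false" is still truthy, and the removed default=False only applied when the variable was unset. A sketch of the difference using a simplified stand-in for sglang's helper (the real one lives in sglang.srt.utils and may differ in detail):

# Simplified stand-in for sglang.srt.utils.get_bool_env_var, to illustrate why
# raw os.environ.get is error-prone for boolean flags. Not the real implementation.
import os

def get_bool_env_var_sketch(name: str, default: str = "false") -> bool:
    return os.environ.get(name, default).strip().lower() in ("true", "1")

os.environ["USE_VLLM_CUTLASS_W8A8_FP8_KERNEL"] = "0"

# Raw string lookup: "0" is a non-empty string, hence truthy.
print(bool(os.environ.get("USE_VLLM_CUTLASS_W8A8_FP8_KERNEL")))    # True
# Boolean helper: "0" is parsed as false.
print(get_bool_env_var_sketch("USE_VLLM_CUTLASS_W8A8_FP8_KERNEL"))  # False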
python/sglang/srt/layers/vocab_parallel_embedding.py

@@ -264,7 +264,6 @@ class VocabParallelEmbedding(torch.nn.Module):
         quant_method = None
         if quant_config is not None:
             quant_method = quant_config.get_quant_method(self, prefix=prefix)
-            print("quant_method", quant_method)
         if quant_method is None:
             quant_method = UnquantizedEmbeddingMethod()
...
test/srt/models/test_qwen_models.py

@@ -69,7 +69,7 @@ class TestQwen2FP8(unittest.TestCase):
         )

         metrics = run_eval(args)
         print(f"{metrics=}")
-        self.assertGreater(metrics["accuracy"], 0.79)
+        self.assertGreater(metrics["accuracy"], 0.78)


 if __name__ == "__main__":
...