sglang / Commits / b70957fc

[refactor] slightly tidy fp8 module (#5993)

Unverified commit b70957fc, authored May 08, 2025 by JieXin Liang, committed by GitHub May 07, 2025.
Parent: e444c13f

Changes: 12 — showing 12 changed files with 239 additions and 232 deletions (+239 −232).
python/sglang/srt/layers/moe/ep_moe/kernels.py (+2 −5)
python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py (+2 −4)
python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py (+2 −1)
python/sglang/srt/layers/quantization/fp8.py (+106 −91)
python/sglang/srt/layers/quantization/fp8_kernel.py (+73 −55)
python/sglang/srt/layers/quantization/fp8_utils.py (+36 −23)
python/sglang/srt/layers/quantization/kv_cache.py (+3 −10)
python/sglang/srt/layers/quantization/utils.py (+0 −5)
python/sglang/srt/layers/quantization/w8a8_fp8.py (+8 −10)
python/sglang/srt/models/deepseek_nextn.py (+1 −20)
python/sglang/srt/models/deepseek_v2.py (+4 −6)
python/sglang/test/test_block_fp8.py (+2 −2)
python/sglang/srt/layers/moe/ep_moe/kernels.py

```diff
@@ -12,7 +12,7 @@ from sglang.srt.utils import is_cuda
 _is_cuda = is_cuda()
 if _is_cuda:
     from sglang.srt.layers.quantization.fp8_kernel import (
-        sglang_per_token_group_quant_fp8,
+        sglang_per_token_group_quant_fp8 as per_token_group_quant_fp8,
     )

 logger = logging.getLogger(__name__)
@@ -654,10 +654,7 @@ def grouped_gemm_triton(
     if block_shape is not None:
         assert len(block_shape) == 2
         block_n, block_k = block_shape[0], block_shape[1]
-        if _is_cuda:
-            a, scale_a = sglang_per_token_group_quant_fp8(a, block_k)
-        else:
-            a, scale_a = per_token_group_quant_fp8(a, block_k)
+        a, scale_a = per_token_group_quant_fp8(a, block_k)

         assert triton.cdiv(a.shape[-1], block_k) == scale_a.shape[-1]
         assert triton.cdiv(b.shape[-2], block_n) == scale_b.shape[-2]
```
python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py

```diff
@@ -10,16 +10,14 @@ import torch
 from compressed_tensors import CompressionFormat
 from compressed_tensors.quantization import QuantizationStrategy

-from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
+from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz, scaled_fp8_quant
 from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
 from sglang.srt.layers.quantization.utils import (
     all_close_1d,
-    is_cuda,
-    is_fp8_fnuz,
     per_tensor_dequantize,
     replace_parameter,
 )
-from sglang.srt.utils import set_weight_attrs
+from sglang.srt.utils import is_cuda, set_weight_attrs

 _is_cuda = is_cuda()
```
python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py

```diff
@@ -15,11 +15,12 @@ from sglang.srt.layers.parameter import (
 from sglang.srt.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme,
 )
+from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz
 from sglang.srt.layers.quantization.fp8_utils import (
     apply_fp8_linear,
     normalize_e4m3fn_to_e4m3fnuz,
 )
-from sglang.srt.layers.quantization.utils import is_fp8_fnuz, requantize_with_max_scale
+from sglang.srt.layers.quantization.utils import requantize_with_max_scale

 __all__ = ["CompressedTensorsW8A8Fp8"]
```
python/sglang/srt/layers/quantization/fp8.py

```diff
@@ -42,6 +42,8 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizeMethodBase,
 )
 from sglang.srt.layers.quantization.fp8_kernel import (
+    fp8_dtype,
+    is_fp8_fnuz,
     per_token_group_quant_fp8,
     scaled_fp8_quant,
 )
@@ -71,6 +73,11 @@ from sglang.srt.utils import (
 _is_hip = is_hip()
 _is_cuda = is_cuda()
+_is_fp8_fnuz = is_fp8_fnuz()
+
+use_hip_int4 = get_bool_env_var("SGLANG_INT4_WEIGHT")
+use_aiter_moe = get_bool_env_var("SGLANG_AITER_MOE")
+
 if _is_hip:
     from aiter import ActivationType, QuantType
     from aiter.fused_moe_bf16_asm import asm_moe, ck_moe_2stages
@@ -306,25 +313,21 @@ class Fp8LinearMethod(LinearMethodBase):
         # Block quant doesn't need to process weights after loading
         if self.block_quant:
             # If ROCm, normalize the weights and scales to e4m3fnuz
-            if _is_hip:
+            if _is_fp8_fnuz:
                 # activation_scheme: dynamic
                 weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
                     weight=layer.weight,
                     weight_scale=layer.weight_scale_inv,
                     input_scale=None,
                 )
-                layer.weight = torch.nn.Parameter(weight, requires_grad=False)
-                layer.weight_scale_inv = torch.nn.Parameter(weight_scale, requires_grad=False)
                 layer.input_scale = None
             else:
-                layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)
-                layer.weight_scale_inv = torch.nn.Parameter(layer.weight_scale_inv.data, requires_grad=False)
+                weight, weight_scale = layer.weight.data, layer.weight_scale_inv.data
+            layer.weight = torch.nn.Parameter(weight, requires_grad=False)
+            layer.weight_scale_inv = torch.nn.Parameter(weight_scale, requires_grad=False)
             return

         layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)
@@ -368,7 +371,7 @@ class Fp8LinearMethod(LinearMethodBase):
         weight = layer.weight
         weight_scale = layer.weight_scale
         # If ROCm, normalize the weights and scales to e4m3fnuz
-        if _is_hip:
+        if _is_fp8_fnuz:
             weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
                 weight=weight,
                 weight_scale=weight_scale,
@@ -482,11 +485,7 @@ class Fp8MoEMethod:
         from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported

         if self.quant_config.is_checkpoint_fp8_serialized:
-            params_dtype = (
-                torch.uint32
-                if get_bool_env_var("SGLANG_INT4_WEIGHT")
-                else torch.float8_e4m3fn
-            )
+            params_dtype = torch.uint32 if use_hip_int4 else torch.float8_e4m3fn
         tp_size = get_tensor_model_parallel_world_size()
         if self.block_quant:
             block_n, block_k = (
@@ -511,7 +510,7 @@ class Fp8MoEMethod:
         )

         # WEIGHTS
-        if _is_hip and get_bool_env_var("SGLANG_INT4_WEIGHT"):
+        if _is_hip and use_hip_int4:
             # INT4 MoE weight - INT32 packed
             w13_weight = torch.nn.Parameter(
                 torch.empty(
@@ -583,9 +582,7 @@ class Fp8MoEMethod:
         layer.register_parameter("w13_weight_scale", w13_weight_scale)
         layer.register_parameter("w2_weight_scale", w2_weight_scale)

-        if (
-            _is_hip
-        ):  # and get_bool_env_var("SGLANG_AITER_MOE"): TODO: add check back after triton kernel
+        if _is_hip:  # and use_aiter_moe: TODO: add check back after triton kernel
             # ROCm - using column scaling, duplicate scaling numbers in case per tensor scaling
             w13_weight_scale1 = torch.nn.Parameter(
                 torch.ones(num_experts, 2 * intermediate_size, dtype=torch.float32),
@@ -612,7 +609,7 @@ class Fp8MoEMethod:
         set_weight_attrs(w13_weight_scale, extra_weight_attrs)
         set_weight_attrs(w2_weight_scale, extra_weight_attrs)

-        if _is_hip and get_bool_env_var("SGLANG_INT4_WEIGHT"):
+        if _is_hip and use_hip_int4:
             extra_weight_attrs.update(
                 {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}
             )
@@ -644,14 +641,14 @@ class Fp8MoEMethod:
             layer.w2_input_scale = None

     def process_weights_after_loading(self, layer: Module) -> None:
-        if _is_hip and get_bool_env_var("SGLANG_INT4_WEIGHT"):
+        if _is_hip and use_hip_int4:
             self.process_weights_hip_int4(layer)
             return

         # Block quant doesn't need to process weights after loading
         if self.block_quant:
             # If ROCm, normalize the weights and scales to e4m3fnuz
-            if _is_hip:
+            if _is_fp8_fnuz:
                 # activation_scheme: dynamic
                 w13_weight, w13_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
                     weight=layer.w13_weight,
@@ -675,20 +672,19 @@ class Fp8MoEMethod:
                 )
                 layer.w2_input_scale = None

-            if get_bool_env_var("SGLANG_AITER_MOE"):
+            if _is_hip and use_aiter_moe:
                 # Pre-shuffle weights
                 layer.w13_weight.data = shuffle_weight(layer.w13_weight.contiguous(), (16, 16))
                 layer.w2_weight.data = shuffle_weight(layer.w2_weight.contiguous(), (16, 16))
             return

         # If checkpoint is fp16 or bfloat16, quantize in place.
         if not self.quant_config.is_checkpoint_fp8_serialized:
-            # If ROCm, use float8_e4m3fnuz instead (MI300x HW)
-            fp8_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
+            # If ROCm, fp8_dtype will be float8_e4m3fnuz (MI300x HW)
             w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
             w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)
@@ -742,7 +738,7 @@ class Fp8MoEMethod:
             )

         # If ROCm, normalize the weights and scales to e4m3fnuz
-        if _is_hip:
+        if _is_fp8_fnuz:
             # Normalize the weights and scales
             w13_weight, w13_weight_scale, w13_input_scale = (
                 normalize_e4m3fn_to_e4m3fnuz(
@@ -798,7 +794,7 @@ class Fp8MoEMethod:
         return

     def process_weights_hip_int4(self, layer: Module):
-        # TODO: and get_bool_env_var("SGLANG_AITER_MOE"): add after triton kernel added
+        # TODO: and use_aiter_moe: add after triton kernel added
         # INT4-FP8 (INT4 MoE Weight, FP8 Compute)
         # Weight Permutation
         layer.w13_weight = torch.nn.Parameter(
@@ -845,7 +841,7 @@ class Fp8MoEMethod:
             padding_size,  # Avoid circular import
         )

-        if get_bool_env_var("SGLANG_AITER_MOE"):
+        if use_aiter_moe:
             layer.w13_weight = torch.nn.Parameter(
                 shuffle_weight(layer.w13_weight.data, (16, 16)),
                 requires_grad=False,
@@ -856,7 +852,7 @@ class Fp8MoEMethod:
                 requires_grad=False,
             )
             torch.cuda.empty_cache()

-            # ROCm (SGLANG_AITER_MOE): using column-wise scaling
+            # ROCm (use_aiter_moe): using column-wise scaling
             layer.w13_weight_scale1 *= layer.w13_weight_scale.unsqueeze(-1)
             layer.w2_weight_scale1 *= layer.w2_weight_scale.unsqueeze(-1)
         elif get_bool_env_var("SGLANG_MOE_PADDING"):
@@ -908,59 +904,16 @@ class Fp8MoEMethod:
         )

-        if _is_hip:
-            if get_bool_env_var("SGLANG_INT4_WEIGHT"):
-                # TODO: add triton kernel and add check get_bool_env_var("SGLANG_AITER_MOE")
-                assert not no_combine, f"{no_combine=} is not supported."
-                return ck_moe_2stages(
-                    x,
-                    layer.w13_weight,
-                    layer.w2_weight,
-                    topk_weights,
-                    topk_ids,
-                    QuantType.per_Token,
-                    layer.w13_weight_scale1,
-                    layer.w2_weight_scale1,
-                    activation=(
-                        ActivationType.Silu if activation == "silu" else ActivationType.Gelu
-                    ),
-                )
-
-            if get_bool_env_var("SGLANG_AITER_MOE"):
-                assert not no_combine, f"{no_combine=} is not supported."
-                if self.block_quant:
-                    # TODO(SGLANG_AITER_MOE): FP8 block_quant only supports 'silu' for the time-being.
-                    assert (
-                        activation == "silu"
-                    ), f"SGLANG_AITER_MOE: FP8 bloack_quant {activation=} will be supported later, unset SGLANG_AITER_MOE"
-                    return asm_moe(
-                        x,
-                        layer.w13_weight,
-                        layer.w2_weight,
-                        topk_weights,
-                        topk_ids,
-                        layer.w13_weight_scale_inv,
-                        layer.w2_weight_scale_inv,
-                        block_shape=tuple(self.quant_config.weight_block_size),
-                        expert_mask=None,
-                    )
-                else:
-                    return ck_moe_2stages(
-                        x,
-                        layer.w13_weight,
-                        layer.w2_weight,
-                        topk_weights,
-                        topk_ids,
-                        QuantType.per_Token,
-                        layer.w13_weight_scale1,
-                        layer.w2_weight_scale1,
-                        activation=(
-                            ActivationType.Silu if activation == "silu" else ActivationType.Gelu
-                        ),
-                    )
+        ret = self.maybe_apply_hip_fused_experts(
+            layer,
+            x,
+            topk_weights,
+            topk_ids,
+            activation,
+            no_combine,
+        )
+        if ret is not None:
+            return ret

         # Expert fusion with FP8 quantization
         return fused_experts(
@@ -987,6 +940,68 @@ class Fp8MoEMethod:
             no_combine=no_combine,
         )

+    def maybe_apply_hip_fused_experts(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        activation: str = "silu",
+        no_combine: bool = False,
+    ) -> Optional[torch.Tensor]:
+        if use_hip_int4:
+            # TODO: add triton kernel and add check use_aiter_moe
+            assert not no_combine, f"{no_combine=} is not supported."
+            return ck_moe_2stages(
+                x,
+                layer.w13_weight,
+                layer.w2_weight,
+                topk_weights,
+                topk_ids,
+                QuantType.per_Token,
+                layer.w13_weight_scale1,
+                layer.w2_weight_scale1,
+                activation=(
+                    ActivationType.Silu if activation == "silu" else ActivationType.Gelu
+                ),
+            )
+
+        if use_aiter_moe:
+            assert not no_combine, f"{no_combine=} is not supported."
+            if self.block_quant:
+                # TODO(use_aiter_moe): FP8 block_quant only supports 'silu' for the time-being.
+                assert (
+                    activation == "silu"
+                ), f"use_aiter_moe: FP8 bloack_quant {activation=} will be supported later, unset use_aiter_moe"
+                return asm_moe(
+                    x,
+                    layer.w13_weight,
+                    layer.w2_weight,
+                    topk_weights,
+                    topk_ids,
+                    layer.w13_weight_scale_inv,
+                    layer.w2_weight_scale_inv,
+                    block_shape=tuple(self.quant_config.weight_block_size),
+                    expert_mask=None,
+                )
+            else:
+                return ck_moe_2stages(
+                    x,
+                    layer.w13_weight,
+                    layer.w2_weight,
+                    topk_weights,
+                    topk_ids,
+                    QuantType.per_Token,
+                    layer.w13_weight_scale1,
+                    layer.w2_weight_scale1,
+                    activation=(
+                        ActivationType.Silu if activation == "silu" else ActivationType.Gelu
+                    ),
+                )
+        return None
+

 class Fp8KVCacheMethod(BaseKVCacheMethod):
     """
```
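The main structural change in fp8.py is that the ROCm-specific fused-MoE branches of `Fp8MoEMethod.apply` move into the early-returning helper `maybe_apply_hip_fused_experts`, with the `SGLANG_INT4_WEIGHT` / `SGLANG_AITER_MOE` environment flags read once at module import. Below is a minimal standalone sketch of that dispatch pattern; the names and the toy arithmetic are hypothetical and stand in for the sglang kernels, not the actual implementation:

```python
# Sketch of the "read flags once, dispatch via an early-returning helper" pattern.
# ck_moe_2stages / asm_moe / fused_experts are replaced by trivial stand-ins.
import os
from typing import Optional


def get_bool_env_var(name: str) -> bool:
    return os.environ.get(name, "").lower() in ("1", "true")


# Evaluated once at import, mirroring use_hip_int4 / use_aiter_moe in fp8.py.
USE_HIP_INT4 = get_bool_env_var("SGLANG_INT4_WEIGHT")
USE_AITER_MOE = get_bool_env_var("SGLANG_AITER_MOE")


def maybe_apply_special_path(x: list[float]) -> Optional[list[float]]:
    """Return a result if a platform-specific fast path applies, else None."""
    if USE_HIP_INT4:
        return [v * 2 for v in x]  # stand-in for the INT4-FP8 ck_moe_2stages path
    if USE_AITER_MOE:
        return [v + 1 for v in x]  # stand-in for the asm_moe / ck_moe_2stages path
    return None


def apply(x: list[float]) -> list[float]:
    ret = maybe_apply_special_path(x)
    if ret is not None:
        return ret
    return list(x)  # generic fallback, analogous to fused_experts(...)
```

The caller stays the same whether or not a platform-specific path exists, which is what lets the HIP branches be grouped or extended without touching the generic FP8 path.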
python/sglang/srt/layers/quantization/fp8_kernel.py

```diff
@@ -16,6 +16,7 @@ import functools
 import json
 import logging
 import os
+from functools import lru_cache
 from typing import Any, Dict, List, Optional, Tuple

 import torch
@@ -34,12 +35,6 @@ from sglang.srt.utils import (
 _is_hip = is_hip()
 _is_cuda = is_cuda()

-_fp8_type = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
-if _is_hip:
-    fp8_max = 224.0
-else:
-    fp8_max = torch.finfo(_fp8_type).max
-fp8_min = -fp8_max

 if _is_cuda:
     from sgl_kernel import (
@@ -54,6 +49,24 @@ if _is_cuda:
 logger = logging.getLogger(__name__)

+@lru_cache()
+def is_fp8_fnuz() -> bool:
+    if _is_hip:
+        # only device 0 is checked, this assumes MI300 platforms are homogeneous
+        return "gfx94" in torch.cuda.get_device_properties(0).gcnArchName
+    return False
+
+
+if is_fp8_fnuz():
+    fp8_dtype = torch.float8_e4m3fnuz
+    fp8_max = 224.0
+else:
+    fp8_dtype = torch.float8_e4m3fn
+    fp8_max = torch.finfo(fp8_dtype).max
+fp8_min = -fp8_max
+

 if supports_custom_op():

     def deep_gemm_fp8_fp8_bf16_nt(
@@ -198,7 +211,7 @@ def per_token_group_quant_fp8(
     ), "the last dimension of `x` cannot be divisible by `group_size`"
     assert x.is_contiguous(), "`x` is not contiguous"

-    x_q = torch.empty_like(x, device=x.device, dtype=_fp8_type)
+    x_q = torch.empty_like(x, device=x.device, dtype=fp8_dtype)
     M = x.numel() // group_size
     N = group_size
     if column_major_scales:
@@ -272,7 +285,7 @@ def sglang_per_token_group_quant_fp8(
     ), "the last dimension of `x` cannot be divisible by `group_size`"
     assert x.is_contiguous(), "`x` is not contiguous"

-    x_q = torch.empty_like(x, device=x.device, dtype=_fp8_type)
+    x_q = torch.empty_like(x, device=x.device, dtype=fp8_dtype)
     if column_major_scales:
         if scale_tma_aligned:
             # aligned to 4 * sizeof(float)
@@ -302,7 +315,7 @@ def sglang_per_token_group_quant_fp8(
 def sglang_per_token_quant_fp8(
     x: torch.Tensor,
-    dtype: torch.dtype = _fp8_type,
+    dtype: torch.dtype = fp8_dtype,
 ):
     assert x.is_contiguous(), "`x` is not contiguous"
@@ -384,7 +397,7 @@ def static_quant_fp8(
     assert x.is_contiguous(), "`x` is not contiguous"
     assert x_s.numel() == 1, "only supports per-tensor scale"

-    x_q = torch.empty_like(x, device=x.device, dtype=_fp8_type)
+    x_q = torch.empty_like(x, device=x.device, dtype=fp8_dtype)
     M = x.numel() // x.shape[-1]
     N = x.shape[-1]
     if repeat_scale:
@@ -704,6 +717,28 @@ def get_w8a8_block_fp8_configs(
     return None

+def select_w8a8_block_fp8_matmul_kernel(M, N, META):
+    return _w8a8_block_fp8_matmul
+
+
+if _is_hip:
+
+    def use_w8a8_block_fp8_matmul_unrolledx4(M, N, META):
+        # Use manually unrolledx4 kernel on AMD GPU when the grid size is small.
+        # Empirical testing shows the sweet spot lies when it's less than the # of
+        # compute units available on the device.
+        num_workgroups = triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(
+            N, META["BLOCK_SIZE_N"]
+        )
+        return num_workgroups <= get_device_core_count()
+
+    def select_w8a8_block_fp8_matmul_kernel(M, N, META):
+        if use_w8a8_block_fp8_matmul_unrolledx4(M, N, META):
+            return _w8a8_block_fp8_matmul_unrolledx4
+        else:
+            return _w8a8_block_fp8_matmul
+
+
 def w8a8_block_fp8_matmul(
     A: torch.Tensor,
     B: torch.Tensor,
@@ -744,35 +779,6 @@ def w8a8_block_fp8_matmul(
     C_shape = A.shape[:-1] + (N,)
     C = A.new_empty(C_shape, dtype=output_dtype)

-    configs = get_w8a8_block_fp8_configs(N, K, block_size[0], block_size[1])
-    if configs:
-        # If an optimal configuration map has been found, look up the
-        # optimal config
-        config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
-    else:
-        # Default config
-        # Block-wise quant: BLOCK_SIZE_K must be divisable by block_size[1]
-        config = {
-            "BLOCK_SIZE_M": 64,
-            "BLOCK_SIZE_N": block_size[0],
-            "BLOCK_SIZE_K": block_size[1],
-            "GROUP_SIZE_M": 32,
-            "num_warps": 4,
-            "num_stages": 3,
-        }
-
-    def grid(META):
-        return (
-            triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
-        )
-
-    # Use manually unrolledx4 kernel on AMD GPU when the grid size is small.
-    # Empirical testing shows the sweet spot lies when it's less than the # of
-    # compute units available on the device.
-    num_workgroups = triton.cdiv(M, config["BLOCK_SIZE_M"]) * triton.cdiv(
-        N, config["BLOCK_SIZE_N"]
-    )
-
     # deepgemm only support bf16
     if C.dtype == torch.bfloat16 and _ENABLE_JIT_DEEPGEMM:
         if supports_custom_op():
@@ -780,11 +786,30 @@ def w8a8_block_fp8_matmul(
         else:
             deep_gemm_gemm_nt_f8f8bf16((A, As), (B, Bs), C)
     else:
-        kernel = (
-            _w8a8_block_fp8_matmul_unrolledx4
-            if (_is_hip == True and num_workgroups <= get_device_core_count())
-            else _w8a8_block_fp8_matmul
-        )
+        configs = get_w8a8_block_fp8_configs(N, K, block_size[0], block_size[1])
+        if configs:
+            # If an optimal configuration map has been found, look up the
+            # optimal config
+            config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
+        else:
+            # Default config
+            # Block-wise quant: BLOCK_SIZE_K must be divisable by block_size[1]
+            config = {
+                "BLOCK_SIZE_M": 64,
+                "BLOCK_SIZE_N": block_size[0],
+                "BLOCK_SIZE_K": block_size[1],
+                "GROUP_SIZE_M": 32,
+                "num_warps": 4,
+                "num_stages": 3,
+            }
+
+        def grid(META):
+            return (
+                triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
+            )
+
+        kernel = select_w8a8_block_fp8_matmul_kernel(M, N, config)

         kernel[grid](
             A,
@@ -879,7 +904,7 @@ def per_tensor_quant_mla_fp8(
         and x_s_out.device == x.device
     )

-    x_q = x.new_empty(x.size(), dtype=_fp8_type)
+    x_q = x.new_empty(x.size(), dtype=fp8_dtype)

     num_head, num_seq, head_size = x.shape
     BLOCK_SIZE = triton.next_power_of_2(head_size)
@@ -961,11 +986,11 @@ def _per_token_group_quant_mla_deep_gemm_masked_fp8(
     tl.store(y_s_ptr + gid * y_s_stride_g, y_s)

-def per_tensor_quant_mla_deep_gemm_masked_fp8(
+def per_token_group_quant_mla_deep_gemm_masked_fp8(
     x: torch.Tensor,
     group_size: int = 128,
     eps: float = 1e-12,
-    dtype: torch.dtype = torch.float8_e4m3fn,
+    dtype: torch.dtype = fp8_dtype,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """
     This function quantizes input values to float8 values with per-token-group-quantization
@@ -973,12 +998,6 @@ def per_tensor_quant_mla_deep_gemm_masked_fp8(
     """
     assert x.dim() == 3, "`x` is not a 3d-tensor"

-    finfo = torch.finfo(dtype)
-    fp8_max = finfo.max
-    if _is_hip:
-        dtype = torch.float8_e4m3fnuz
-        fp8_max = 224.0
-
     b, m, k = x.shape
     aligned_m = (m + 255) // 256 * 256  # 256 is the max block_m of the gemm kernel
     num_tiles_k = k // group_size
@@ -1043,10 +1062,9 @@ def scaled_fp8_quant(
     """
     assert input.ndim == 2, f"Expected 2D input tensor, got {input.ndim}D"
     shape = input.shape
-    out_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
     if num_token_padding:
         shape = (max(num_token_padding, input.shape[0]), shape[1])
-    output = torch.empty(shape, device=input.device, dtype=out_dtype)
+    output = torch.empty(shape, device=input.device, dtype=fp8_dtype)

     if scale is None:
         # Dynamic scaling
```
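fp8_kernel.py now owns the platform probe: `is_fp8_fnuz()` is cached with `lru_cache`, and `fp8_dtype` / `fp8_max` / `fp8_min` are derived from it once at import, so the quantization helpers (and the modules that import them) default to the shared dtype instead of re-evaluating a per-call HIP branch. A rough standalone sketch of that pattern follows; it assumes a CUDA or ROCm build of PyTorch, and the `quantize_per_tensor` helper is purely illustrative, not part of sglang:

```python
# Cache the device probe and derive the FP8 dtype/range once at import time.
from functools import lru_cache

import torch


@lru_cache()
def is_fp8_fnuz() -> bool:
    # gfx94x (MI300-class) ROCm devices use the e4m3fnuz encoding; only device 0 is checked.
    if torch.version.hip is not None and torch.cuda.is_available():
        return "gfx94" in torch.cuda.get_device_properties(0).gcnArchName
    return False


if is_fp8_fnuz():
    fp8_dtype = torch.float8_e4m3fnuz
    fp8_max = 224.0  # usable max on fnuz hardware
else:
    fp8_dtype = torch.float8_e4m3fn
    fp8_max = torch.finfo(fp8_dtype).max
fp8_min = -fp8_max


def quantize_per_tensor(x: torch.Tensor):
    """Tensor-wise FP8 quantization using the shared module-level constants."""
    scale = x.abs().amax().clamp(min=1e-12) / fp8_max
    x_q = (x / scale).clamp(fp8_min, fp8_max).to(fp8_dtype)
    return x_q, scale
```

Because the probe is cached and the constants are module-level, every caller sees the same dtype decision without repeating the `torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn` expression the diff removes.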
python/sglang/srt/layers/quantization/fp8_utils.py

```diff
@@ -14,6 +14,9 @@ except ImportError:
 from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
 from sglang.srt.layers.quantization.fp8_kernel import (
+    fp8_dtype,
+    fp8_max,
+    is_fp8_fnuz,
     per_token_group_quant_fp8,
     scaled_fp8_quant,
     sglang_per_token_quant_fp8,
@@ -30,8 +33,11 @@ from sglang.srt.utils import (
 _is_hip = is_hip()
 _is_cuda = is_cuda()
+_is_fp8_fnuz = is_fp8_fnuz()

-if _is_hip and get_bool_env_var("SGLANG_AITER_MOE"):
+use_aiter_moe = get_bool_env_var("SGLANG_AITER_MOE")
+
+if _is_hip and use_aiter_moe:
     from aiter import gemm_a8w8_blockscale

 if _is_cuda:
@@ -43,19 +49,23 @@ use_vllm_cutlass_w8a8_fp8_kernel = get_bool_env_var("USE_VLLM_CUTLASS_W8A8_FP8_K
 # from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
 TORCH_DEVICE_IDENTITY = None

-_TORCH_VERSION = torch.__version__.split("+")[0]
-try:
-    _TORCH_VERSION_TUPLE = tuple(map(int, _TORCH_VERSION.split(".")[:3]))
-except ValueError:
-    _TORCH_VERSION_TUPLE = (0, 0, 0)
-
-# The condition to determine if it is on a platform that supports
-# torch._scaled_mm rowwise feature.
-# The condition is determined once as the operations
-# are time consuming.
-USE_ROWWISE_TORCH_SCALED_MM = (
-    _is_hip and get_device_capability() >= (9, 4) and _TORCH_VERSION_TUPLE >= (2, 7, 0)
-)
+
+def use_rowwise_torch_scaled_mm():
+    _TORCH_VERSION = torch.__version__.split("+")[0]
+    try:
+        _TORCH_VERSION_TUPLE = tuple(map(int, _TORCH_VERSION.split(".")[:3]))
+    except ValueError:
+        _TORCH_VERSION_TUPLE = (0, 0, 0)
+    if _is_hip:
+        # The condition to determine if it is on a platform that supports
+        # torch._scaled_mm rowwise feature.
+        # The condition is determined once as the operations
+        # are time consuming.
+        return get_device_capability() >= (9, 4) and _TORCH_VERSION_TUPLE >= (2, 7, 0)
+    return False
+
+
+USE_ROWWISE_TORCH_SCALED_MM = use_rowwise_torch_scaled_mm()

 def cutlass_fp8_supported():
@@ -132,7 +142,7 @@ def apply_w8a8_block_fp8_linear(
         output = fp8_blockwise_scaled_mm(
             q_input, weight.T, x_scale, weight_scale.T, out_dtype=input.dtype
         )
-    elif _is_hip and get_bool_env_var("SGLANG_AITER_MOE"):
+    elif _is_hip and use_aiter_moe:
         q_input, x_scale = per_token_group_quant_fp8(
             input_2d, block_size[1], column_major_scales=False
         )
@@ -164,18 +174,21 @@ def apply_w8a8_block_fp8_linear(
 def input_to_float8(
-    x: torch.Tensor, dtype: torch.dtype = torch.float8_e4m3fn
+    x: torch.Tensor, dtype: torch.dtype = fp8_dtype
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """This function quantizes input values to float8 values with tensor-wise quantization."""
-    finfo = torch.finfo(dtype)
     min_val, max_val = x.aminmax()
     amax = torch.maximum(min_val.abs(), max_val.abs()).float().clamp(min=1e-12)
-    fp8_max = finfo.max
-    if _is_hip:
-        dtype = torch.float8_e4m3fnuz
-        fp8_max = 224.0
-    scale = fp8_max / amax
-    x_scl_sat = (x.float() * scale).clamp(min=-fp8_max, max=fp8_max)
+    if _is_fp8_fnuz:
+        dtype = fp8_dtype
+        fp_max = fp8_max
+    else:
+        finfo = torch.finfo(dtype)
+        fp_max = finfo.max
+    scale = fp_max / amax
+    x_scl_sat = (x.float() * scale).clamp(min=-fp_max, max=fp_max)
     return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()
```
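In fp8_utils.py the torch-version parsing and the rowwise `torch._scaled_mm` capability check move into `use_rowwise_torch_scaled_mm()`, which is still evaluated exactly once into the module constant `USE_ROWWISE_TORCH_SCALED_MM`; on non-HIP builds the function now returns False before touching the device-capability query. A simplified standalone sketch of the same idea (the capability check here is illustrative, not sglang's exact condition):

```python
# Compute a platform/feature flag once at import by calling a small helper.
import torch


def _torch_version_tuple() -> tuple:
    version = torch.__version__.split("+")[0]
    try:
        return tuple(map(int, version.split(".")[:3]))
    except ValueError:
        return (0, 0, 0)


def use_rowwise_torch_scaled_mm() -> bool:
    # Only worth enabling on ROCm gfx94x-class devices with a new enough torch;
    # everywhere else the answer is a constant False and no device query is made.
    if torch.version.hip is not None and torch.cuda.is_available():
        return torch.cuda.get_device_capability() >= (9, 4) and _torch_version_tuple() >= (2, 7, 0)
    return False


USE_ROWWISE_TORCH_SCALED_MM = use_rowwise_torch_scaled_mm()
```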
python/sglang/srt/layers/quantization/kv_cache.py

```diff
@@ -8,10 +8,8 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
+from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.utils import is_hip
-
-_is_hip = is_hip()

 logger = logging.getLogger(__name__)
@@ -44,11 +42,6 @@ class BaseKVCacheMethod(QuantizeMethodBase):
             torch.tensor(-1.0, dtype=torch.float32), requires_grad=False
         )

-    @classmethod
-    def is_fp8_fnuz(cls) -> bool:
-        # only device 0 is checked, this assumes MI300 platforms are homogeneous
-        return "gfx94" in torch.cuda.get_device_properties(0).gcnArchName
-
     def apply(self, layer: torch.nn.Module) -> torch.Tensor:
         raise RuntimeError(f"{self.__class__.__name__}.apply should not be called.")
@@ -57,7 +50,7 @@ class BaseKVCacheMethod(QuantizeMethodBase):
             # We prefer to use separate k_scale and v_scale if present
             k_scale = layer.k_scale.to("cpu").tolist()
             v_scale = layer.v_scale.to("cpu").tolist()
-            if _is_hip and self.is_fp8_fnuz():
+            if is_fp8_fnuz():
                 k_scale *= 2
                 v_scale *= 2
         elif layer.k_scale < 0.0 and layer.v_scale < 0.0:
@@ -73,7 +66,7 @@ class BaseKVCacheMethod(QuantizeMethodBase):
             scale_to_duplicate = max(layer.k_scale, layer.v_scale)
             k_scale = scale_to_duplicate.to("cpu").tolist()
             v_scale = scale_to_duplicate.to("cpu").tolist()
-            if _is_hip and self.is_fp8_fnuz():
+            if is_fp8_fnuz():
                 k_scale *= 2
                 v_scale *= 2
```
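kv_cache.py drops its private `is_fp8_fnuz` classmethod and the `_is_hip` guard in favor of the shared `is_fp8_fnuz()` from fp8_kernel.py; the ×2 adjustment of the checkpoint k/v scales on fnuz hardware is unchanged (e4m3fnuz covers roughly half the numeric range of e4m3fn). A small standalone sketch of that scale-export step, with a local `is_fp8_fnuz` stand-in and a hypothetical `export_kv_scales` helper that is not part of sglang:

```python
# Illustrative only: double checkpoint k/v scales when the platform uses e4m3fnuz.
import torch


def is_fp8_fnuz() -> bool:
    # Mirrors the shared helper in fp8_kernel.py (gfx94x ROCm devices only).
    if torch.version.hip is not None and torch.cuda.is_available():
        return "gfx94" in torch.cuda.get_device_properties(0).gcnArchName
    return False


def export_kv_scales(k_scale: torch.Tensor, v_scale: torch.Tensor):
    """Convert 0-dim k/v scale tensors to floats, doubling them on fnuz hardware."""
    k, v = k_scale.to("cpu").item(), v_scale.to("cpu").item()
    if is_fp8_fnuz():
        k *= 2
        v *= 2
    return k, v
```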
python/sglang/srt/layers/quantization/utils.py

```diff
@@ -14,11 +14,6 @@ if not _is_cuda:
     from vllm._custom_ops import scaled_fp8_quant

-def is_fp8_fnuz() -> bool:
-    # only device 0 is checked, this assumes MI300 platforms are homogeneous
-    return "gfx94" in torch.cuda.get_device_properties(0).gcnArchName
-
-
 def is_layer_skipped(
     prefix: str,
     ignored_layers: List[str],
```
python/sglang/srt/layers/quantization/w8a8_fp8.py

```diff
@@ -9,16 +9,20 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
-from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
+from sglang.srt.layers.quantization.fp8_kernel import (
+    fp8_dtype,
+    is_fp8_fnuz,
+    per_token_group_quant_fp8,
+)
 from sglang.srt.layers.quantization.fp8_utils import (
     apply_fp8_linear,
     cutlass_fp8_supported,
     input_to_float8,
     normalize_e4m3fn_to_e4m3fnuz,
 )
-from sglang.srt.utils import is_hip, set_weight_attrs
+from sglang.srt.utils import set_weight_attrs

-_is_hip = is_hip()
+_is_fp8_fnuz = is_fp8_fnuz()


 class W8A8Fp8Config(QuantizationConfig):
@@ -97,7 +101,7 @@ class W8A8Fp8LinearMethod(LinearMethodBase):
         if self.quantization_config.is_checkpoint_fp8_serialized:
             weight_scale = layer.weight_scale.detach()
             # If checkpoint offline quantized with w8a8_fp8, load the weight and weight_scale directly.
-            if _is_hip:
+            if _is_fp8_fnuz:
                 weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
                     weight=weight, weight_scale=weight_scale
                 )
@@ -113,14 +117,9 @@ class W8A8Fp8LinearMethod(LinearMethodBase):
                     layer.weight, layer.weight.shape[-1]
                 )
                 weight_scale = weight_scale.t().contiguous()
-                if _is_hip:
-                    weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
-                        weight=weight, weight_scale=weight_scale
-                    )
             else:
                 # if cutlass not supported, we fall back to use torch._scaled_mm
                 # which requires per tensor quantization on weight
-                fp8_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
                 qweight, weight_scale = input_to_float8(layer.weight, dtype=fp8_dtype)

             # Update the layer with the new values.
@@ -227,7 +226,6 @@ class W8A8FP8MoEMethod:
     ):
         from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported

-        fp8_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
         # WEIGHTS
         w13_weight = torch.nn.Parameter(
             torch.empty(
```
python/sglang/srt/models/deepseek_nextn.py

```diff
@@ -24,34 +24,15 @@ from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import ReplicatedLinear
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe.ep_moe.layer import EPMoE
-from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.layers.quantization.fp8_utils import (
-    block_quant_to_tensor_quant,
-    normalize_e4m3fn_to_e4m3fnuz,
-)
-from sglang.srt.layers.quantization.int8_utils import (
-    block_dequant as int8_block_dequant,
-)
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.deepseek_v2 import DeepseekV2DecoderLayer, DeepseekV3ForCausalLM
-from sglang.srt.utils import BumpAllocator, add_prefix, is_cuda, is_hip
-
-_is_hip = is_hip()
-_is_cuda = is_cuda()
-
-if _is_cuda:
-    from sgl_kernel import awq_dequantize
-else:
-    from vllm._custom_ops import awq_dequantize
+from sglang.srt.utils import BumpAllocator, add_prefix

 logger = logging.getLogger(__name__)
```
python/sglang/srt/models/deepseek_v2.py

```diff
@@ -59,8 +59,8 @@ from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
 from sglang.srt.layers.quantization.fp8_kernel import (
-    per_tensor_quant_mla_deep_gemm_masked_fp8,
     per_tensor_quant_mla_fp8,
+    per_token_group_quant_mla_deep_gemm_masked_fp8,
 )
 from sglang.srt.layers.quantization.fp8_utils import (
     block_quant_to_tensor_quant,
@@ -738,9 +738,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         if self.use_deep_gemm_bmm:
             q_nope_val, q_nope_scale, masked_m, expected_m, aligned_m = (
-                per_tensor_quant_mla_deep_gemm_masked_fp8(
-                    q_nope.transpose(0, 1), dtype=torch.float8_e4m3fn
-                )
+                per_token_group_quant_mla_deep_gemm_masked_fp8(q_nope.transpose(0, 1))
             )
             q_nope_out = q_nope.new_empty(
                 (self.num_local_heads, aligned_m, self.kv_lora_rank)
@@ -785,8 +783,8 @@ class DeepseekV2AttentionMLA(nn.Module):
         if self.use_deep_gemm_bmm:
             attn_output_val, attn_output_scale, masked_m, expected_m, aligned_m = (
-                per_tensor_quant_mla_deep_gemm_masked_fp8(
-                    attn_output.transpose(0, 1), dtype=torch.float8_e4m3fn
-                )
+                per_token_group_quant_mla_deep_gemm_masked_fp8(
+                    attn_output.transpose(0, 1)
+                )
             )
             attn_bmm_output = attn_output.new_empty(
```
python/sglang/test/test_block_fp8.py

```diff
@@ -7,9 +7,9 @@ import torch
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
 from sglang.srt.layers.quantization.fp8_kernel import (
-    per_tensor_quant_mla_deep_gemm_masked_fp8,
     per_tensor_quant_mla_fp8,
     per_token_group_quant_fp8,
+    per_token_group_quant_mla_deep_gemm_masked_fp8,
     static_quant_fp8,
     w8a8_block_fp8_matmul,
 )
@@ -236,7 +236,7 @@ class TestPerTokenGroupQuantMlaDeepGemmMaskedFP8(CustomTestCase):
         with torch.inference_mode():
             ref_out, ref_scale = native_per_token_group_quant_fp8(x, group_size, 1e-12)
-            out, scale, _, _, _ = per_tensor_quant_mla_deep_gemm_masked_fp8(
+            out, scale, _, _, _ = per_token_group_quant_mla_deep_gemm_masked_fp8(
                 x, group_size
             )
             out = out[:, :num_tokens, :]
```