refactor apply_w8a8_block_fp8_linear in fp (#6545)

Commit 485a023b (unverified)
Authored May 29, 2025 by ChangyiYang; committed by GitHub on May 29, 2025
Parent: 7e412900
Showing 5 changed files with 281 additions and 118 deletions (+281 −118).
Files changed:

  benchmark/kernels/deepseek/benchmark_deepgemm_fp8_gemm.py   +3   −1
  python/sglang/srt/layers/quantization/fp8.py                +4   −2
  python/sglang/srt/layers/quantization/fp8_kernel.py         +118 −66
  python/sglang/srt/layers/quantization/fp8_utils.py          +153 −48
  sgl-kernel/benchmark/bench_fp8_blockwise_gemm.py            +3   −1
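The commit splits the single apply_w8a8_block_fp8_linear entry point into per-backend functions in fp8_utils.py that are selected once by dispatch_w8a8_block_fp8_linear(), and splits the fp8_kernel matmul into w8a8_block_fp8_matmul_triton and w8a8_block_fp8_matmul_deepgemm behind a shared prepare_block_fp8_matmul_inputs helper. A minimal, hypothetical sketch of the "resolve the backend once, call it many times" pattern the diff introduces (the names below are placeholders, not the sglang implementations):

# Hypothetical sketch of the dispatch-once pattern: the backend is chosen at
# construction time and the returned callable is reused for every forward pass.
from typing import Callable

import torch


def _triton_linear(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # stand-in for triton_w8a8_block_fp8_linear
    return x @ w.T


def _cutlass_linear(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # stand-in for cutlass_w8a8_block_fp8_linear_with_fallback
    return x @ w.T


def dispatch_linear(cutlass_supported: bool) -> Callable:
    # mirrors dispatch_w8a8_block_fp8_linear(): return a function, not a result
    return _cutlass_linear if cutlass_supported else _triton_linear


linear_fn = dispatch_linear(cutlass_supported=False)   # done once, e.g. in __init__
y = linear_fn(torch.randn(4, 8), torch.randn(16, 8))   # reused on every call

Resolving the callable up front moves the environment and hardware checks out of the per-call forward path.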
benchmark/kernels/deepseek/benchmark_deepgemm_fp8_gemm.py

@@ -10,7 +10,9 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
     w8a8_block_fp8_matmul as vllm_w8a8_block_fp8_matmul,
 )
-from sglang.srt.layers.quantization.fp8_kernel import w8a8_block_fp8_matmul
+from sglang.srt.layers.quantization.fp8_kernel import (
+    w8a8_block_fp8_matmul_deepgemm as w8a8_block_fp8_matmul,
+)

 # Adapted from https://github.com/tile-ai/tilelang/blob/a8cfdce92795cb861c9033573534653ee040b5ed/examples/deepseek_deepgemm/example_deepgemm_fp8_2xAcc.py#L1
python/sglang/srt/layers/quantization/fp8.py

@@ -49,8 +49,8 @@ from sglang.srt.layers.quantization.fp8_kernel import (
 )
 from sglang.srt.layers.quantization.fp8_utils import (
     apply_fp8_linear,
-    apply_w8a8_block_fp8_linear,
     cutlass_fp8_supported,
+    dispatch_w8a8_block_fp8_linear,
     input_to_float8,
     is_sm100_supported,
     normalize_e4m3fn_to_e4m3fnuz,

@@ -209,6 +209,8 @@ class Fp8LinearMethod(LinearMethodBase):
             # Marlin doesn't support block-wise fp8
             self.use_marlin = False

+        self.w8a8_block_fp8_linear = dispatch_w8a8_block_fp8_linear()
+
     def create_weights(
         self,
         layer: torch.nn.Module,

@@ -417,7 +419,7 @@ class Fp8LinearMethod(LinearMethodBase):
         )

         if self.block_quant:
-            return apply_w8a8_block_fp8_linear(
+            return self.w8a8_block_fp8_linear(
                 input=x,
                 weight=layer.weight,
                 block_size=self.quant_config.weight_block_size,
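With the dispatcher in place, Fp8LinearMethod resolves its block-fp8 linear implementation once in __init__ and then calls self.w8a8_block_fp8_linear(...) in apply(). A hedged way to inspect which backend gets picked on a given machine (assumes an environment where sglang with this commit is importable; the snippet is illustrative and not part of the diff):

from sglang.srt.layers.quantization.fp8_utils import (
    dispatch_w8a8_block_fp8_linear,
    triton_w8a8_block_fp8_linear,
)

backend = dispatch_w8a8_block_fp8_linear()
assert callable(backend)
# Without FlashInfer, CUTLASS block-fp8, ROCm aiter, or JIT DeepGEMM support,
# the dispatcher is expected to fall back to the Triton implementation.
print(backend.__name__, backend is triton_w8a8_block_fp8_linear)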
python/sglang/srt/layers/quantization/fp8_kernel.py

@@ -740,7 +740,59 @@ if _is_hip:
The w8a8_block_fp8_matmul definition is renamed to w8a8_block_fp8_matmul_triton, and two helpers are inserted in front of it:

         return _w8a8_block_fp8_matmul


-def w8a8_block_fp8_matmul(
+def prepare_block_fp8_matmul_inputs(
+    A: torch.Tensor,
+    B: torch.Tensor,
+    As: torch.Tensor,
+    Bs: torch.Tensor,
+    block_size: List[int],
+    output_dtype: torch.dtype = torch.float16,
+) -> Tuple[int, int, int]:
+    assert len(block_size) == 2
+    block_n, block_k = block_size[0], block_size[1]
+
+    assert A.shape[-1] == B.shape[-1]
+    assert A.shape[:-1] == As.shape[:-1]
+    assert A.is_contiguous()
+    assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1]
+
+    M = A.numel() // A.shape[-1]
+
+    assert B.ndim == 2
+    assert B.is_contiguous()
+    assert Bs.ndim == 2
+    N, K = B.shape
+    assert triton.cdiv(N, block_n) == Bs.shape[0]
+    assert triton.cdiv(K, block_k) == Bs.shape[1]
+
+    C_shape = A.shape[:-1] + (N,)
+    C = A.new_empty(C_shape, dtype=output_dtype)
+
+    return M, N, K, C
+
+
+def w8a8_block_fp8_matmul_deepgemm(
+    A: torch.Tensor,
+    B: torch.Tensor,
+    As: torch.Tensor,
+    Bs: torch.Tensor,
+    block_size: List[int],
+    output_dtype: torch.dtype,
+) -> torch.Tensor:
+    M, N, K, C = prepare_block_fp8_matmul_inputs(A, B, As, Bs, block_size, output_dtype)
+
+    # Deepgemm only supports output tensor type as bfloat16
+    assert C.dtype == torch.bfloat16 and _ENABLE_JIT_DEEPGEMM
+
+    if supports_custom_op():
+        torch.ops.sglang.deep_gemm_fp8_fp8_bf16_nt(A, As, B, Bs, C)
+    else:
+        deep_gemm_gemm_nt_f8f8bf16((A, As), (B, Bs), C)
+
+    return C
+
+
+def w8a8_block_fp8_matmul_triton(
     A: torch.Tensor,
     B: torch.Tensor,
     As: torch.Tensor,

@@ -764,81 +816,81 @@ def w8a8_block_fp8_matmul(
In the body (now w8a8_block_fp8_matmul_triton), the inline shape assertions, the output allocation, and the bfloat16/DeepGEMM branch are removed; they are covered by prepare_block_fp8_matmul_inputs and w8a8_block_fp8_matmul_deepgemm above. The Triton tuning and launch code is kept, de-indented out of the old else branch:

    Returns:
        torch.Tensor: The result of matmul.
    """
    M, N, K, C = prepare_block_fp8_matmul_inputs(A, B, As, Bs, block_size, output_dtype)

    block_n, block_k = block_size

    configs = get_w8a8_block_fp8_configs(N, K, block_size[0], block_size[1])
    if configs:
        # If an optimal configuration map has been found, look up the
        # optimal config
        config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
    else:
        # Default config
        # Block-wise quant: BLOCK_SIZE_K must be divisible by block_size[1]
        config = {
            "BLOCK_SIZE_M": 64,
            "BLOCK_SIZE_N": block_size[0],
            "BLOCK_SIZE_K": block_size[1],
            "GROUP_SIZE_M": 32,
            "num_warps": 4,
            "num_stages": 3,
        }

    def grid(META):
        return (
            triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
        )

    kernel = select_w8a8_block_fp8_matmul_kernel(M, N, config)

    kernel[grid](
        A, B, C, As, Bs,
        M, N, K,
        block_n, block_k,
        A.stride(-2), A.stride(-1),
        B.stride(1), B.stride(0),
        C.stride(-2), C.stride(-1),
        As.stride(-2), As.stride(-1),
        Bs.stride(1), Bs.stride(0),
        **config,
    )

    return C

A dtype-based wrapper under the old name is appended after the Triton function:

# universal entry point, for testing purposes
def w8a8_block_fp8_matmul(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    block_size: List[int],
    output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
    if output_dtype == torch.bfloat16 and _ENABLE_JIT_DEEPGEMM:
        return w8a8_block_fp8_matmul_deepgemm(
            A, B, As, Bs, block_size, output_dtype=output_dtype
        )

    return w8a8_block_fp8_matmul_triton(
        A, B, As, Bs, block_size, output_dtype=output_dtype
    )


@triton.jit
def _per_tensor_quant_mla_fp8_stage1(
    x_ptr,
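For reference, the shape contract that prepare_block_fp8_matmul_inputs enforces (A is a (..., K) fp8 tensor with per-token, per-K-block scales As; B is an (N, K) fp8 tensor with per-(N-block, K-block) scales Bs) can be exercised through the universal entry point. A hedged sketch, assuming a CUDA device with this sglang build available and 128x128 weight blocks; the scale values here are dummies:

import torch
from sglang.srt.layers.quantization.fp8_kernel import w8a8_block_fp8_matmul

M, N, K = 16, 512, 256
block_n, block_k = 128, 128

A = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)                     # activations, (M, K)
B = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn)                     # weight, (N, K)
As = torch.ones(M, K // block_k, device="cuda", dtype=torch.float32)             # (M, cdiv(K, block_k))
Bs = torch.ones(N // block_n, K // block_k, device="cuda", dtype=torch.float32)  # (cdiv(N, block_n), cdiv(K, block_k))

# A bfloat16 output routes to the DeepGEMM path when _ENABLE_JIT_DEEPGEMM is set;
# otherwise (or for float16) the Triton kernel is used.
C = w8a8_block_fp8_matmul(A, B, As, Bs, [block_n, block_k], output_dtype=torch.bfloat16)
print(C.shape)  # torch.Size([16, 512])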
python/sglang/srt/layers/quantization/fp8_utils.py

At the top of the file, the typing import gains Callable, and a from curses import flash line is added:

 import os
-from typing import List, Optional, Tuple
+from curses import flash
+from typing import Callable, List, Optional, Tuple

 import torch

@@ -21,7 +22,8 @@ from sglang.srt.layers.quantization.fp8_kernel import (
     scaled_fp8_quant,
     sglang_per_token_quant_fp8,
     static_quant_fp8,
-    w8a8_block_fp8_matmul,
+    w8a8_block_fp8_matmul_deepgemm,
+    w8a8_block_fp8_matmul_triton,
 )
 from sglang.srt.utils import (
     get_bool_env_var,

@@ -134,7 +136,20 @@ if ENABLE_FLASHINFER_GEMM:
     from flashinfer.gemm import gemm_fp8_nt_groupwise


-def apply_w8a8_block_fp8_linear(
+def dispatch_w8a8_block_fp8_linear() -> Callable:
+    if ENABLE_FLASHINFER_GEMM:
+        return flashinfer_gemm_w8a8_block_fp8_linear
+    elif CUTLASS_BLOCK_FP8_SUPPORTED:
+        return cutlass_w8a8_block_fp8_linear_with_fallback
+    elif _is_hip and use_aiter_moe:
+        return aiter_w8a8_block_fp8_linear
+    elif _ENABLE_JIT_DEEPGEMM:
+        return deepgemm_w8a8_block_fp8_linear_with_fallback
+    else:
+        return triton_w8a8_block_fp8_linear
+
+
+def flashinfer_gemm_w8a8_block_fp8_linear(

@@ -143,58 +158,148 @@ def apply_w8a8_block_fp8_linear(
The rest of the old apply_w8a8_block_fp8_linear body, which picked the FlashInfer, CUTLASS, aiter, DeepGEMM, or Triton path in an if/elif chain on every call, is removed. Its branches become the standalone per-backend functions below (the flashinfer signature continues from the hunk above; the full functions are reproduced here for readability):

def flashinfer_gemm_w8a8_block_fp8_linear(
    input: torch.Tensor,
    weight: torch.Tensor,
    block_size: List[int],
    weight_scale: torch.Tensor,
    input_scale: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    assert input_scale is None

    input_2d = input.view(-1, input.shape[-1])
    output_shape = [*input.shape[:-1], weight.shape[0]]

    q_input, x_scale = sglang_per_token_group_quant_fp8(
        input_2d, block_size[1], column_major_scales=False
    )

    x_scale_input = x_scale.T.contiguous()
    weight_scale_input = weight_scale.T.contiguous()

    output = gemm_fp8_nt_groupwise(
        q_input, weight, x_scale_input, weight_scale_input, out_dtype=input_2d.dtype
    )

    if bias is not None:
        output += bias

    return output.to(dtype=input_2d.dtype).view(*output_shape)


def cutlass_w8a8_block_fp8_linear_with_fallback(
    input: torch.Tensor,
    weight: torch.Tensor,
    block_size: List[int],
    weight_scale: torch.Tensor,
    input_scale: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    assert input_scale is None

    # TODO: add more robust shape check here
    shape_supported = weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0

    if not shape_supported:
        # fallback to triton
        return triton_w8a8_block_fp8_linear(
            input, weight, block_size, weight_scale, input_scale, bias
        )

    input_2d = input.view(-1, input.shape[-1])
    output_shape = [*input.shape[:-1], weight.shape[0]]

    q_input, x_scale = per_token_group_quant_fp8(
        input_2d, block_size[1], column_major_scales=True
    )
    output = fp8_blockwise_scaled_mm(
        q_input, weight.T, x_scale, weight_scale.T, out_dtype=input_2d.dtype
    )
    if bias is not None:
        output += bias
    return output.to(dtype=input_2d.dtype).view(*output_shape)


def deepgemm_w8a8_block_fp8_linear_with_fallback(
    input: torch.Tensor,
    weight: torch.Tensor,
    block_size: List[int],
    weight_scale: torch.Tensor,
    input_scale: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    assert input_scale is None

    output_dtype = input.dtype
    dtype_supported = output_dtype == torch.bfloat16

    # TODO: add more robust shape check here
    shape_supported = weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0

    if not (shape_supported and dtype_supported):
        # fall back to triton
        return triton_w8a8_block_fp8_linear(
            input, weight, block_size, weight_scale, input_scale, bias
        )

    input_2d = input.view(-1, input.shape[-1])
    output_shape = [*input.shape[:-1], weight.shape[0]]

    q_input, x_scale = sglang_per_token_group_quant_fp8(
        input_2d,
        block_size[1],
        column_major_scales=True,
        scale_tma_aligned=True,
    )
    output = w8a8_block_fp8_matmul_deepgemm(
        q_input, weight, x_scale, weight_scale, block_size, output_dtype=output_dtype
    )
    if bias is not None:
        output = output + bias
    return output.to(dtype=output_dtype).view(*output_shape)


def aiter_w8a8_block_fp8_linear(
    input: torch.Tensor,
    weight: torch.Tensor,
    block_size: List[int],
    weight_scale: torch.Tensor,
    input_scale: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    assert input_scale is None
    input_2d = input.view(-1, input.shape[-1])
    output_shape = [*input.shape[:-1], weight.shape[0]]

    q_input, x_scale = per_token_group_quant_fp8(
        input_2d, block_size[1], column_major_scales=False
    )
    output = torch.zeros(
        [q_input.shape[0], weight.shape[0]],
        dtype=input_2d.dtype,
        device=q_input.device,
    )
    gemm_a8w8_blockscale(q_input, weight, x_scale, weight_scale, output)
    if bias is not None:
        output += bias
    return output.to(dtype=input_2d.dtype).view(*output_shape)


def triton_w8a8_block_fp8_linear(
    input: torch.Tensor,
    weight: torch.Tensor,
    block_size: List[int],
    weight_scale: torch.Tensor,
    input_scale: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    assert input_scale is None
    input_2d = input.view(-1, input.shape[-1])
    output_shape = [*input.shape[:-1], weight.shape[0]]

    q_input, x_scale = per_token_group_quant_fp8(
        input_2d, block_size[1], column_major_scales=False
    )
    output = w8a8_block_fp8_matmul_triton(
        q_input, weight, x_scale, weight_scale, block_size, output_dtype=input_2d.dtype
    )
    if bias is not None:
        output += bias
    return output.to(dtype=input_2d.dtype).view(*output_shape)


def input_to_float8(
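Each backend wrapper follows the same skeleton: flatten the activation to 2D, per-token-group quantize along K, run the blockwise fp8 GEMM, add the bias, and reshape back. A hedged, self-contained sketch of that skeleton with the quantization and GEMM stubbed out (the helper names here are placeholders, not sglang APIs):

# Hypothetical skeleton of the per-backend wrappers added in this file; the real
# functions call sglang's per_token_group_quant_fp8 and a blockwise fp8 GEMM.
from typing import List, Optional

import torch


def _fake_quant(x_2d: torch.Tensor, group_k: int):
    # placeholder: real code returns (fp8 tensor, per-group scales)
    scales = torch.ones(x_2d.shape[0], x_2d.shape[1] // group_k)
    return x_2d, scales


def _fake_block_gemm(q, w, x_scale, w_scale, out_dtype):
    # placeholder for the triton / deepgemm / cutlass / aiter blockwise matmul
    return (q @ w.T).to(out_dtype)


def block_fp8_linear_skeleton(
    input: torch.Tensor,
    weight: torch.Tensor,
    block_size: List[int],
    weight_scale: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    input_2d = input.view(-1, input.shape[-1])               # flatten leading dims
    output_shape = [*input.shape[:-1], weight.shape[0]]
    q_input, x_scale = _fake_quant(input_2d, block_size[1])  # quantize activations
    output = _fake_block_gemm(q_input, weight, x_scale, weight_scale, input_2d.dtype)
    if bias is not None:
        output += bias
    return output.to(dtype=input_2d.dtype).view(*output_shape)


y = block_fp8_linear_skeleton(
    torch.randn(2, 4, 256), torch.randn(512, 256), [128, 128], torch.ones(4, 2)
)
print(y.shape)  # torch.Size([2, 4, 512])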
sgl-kernel/benchmark/bench_fp8_blockwise_gemm.py

@@ -9,7 +9,9 @@ from deep_gemm import get_col_major_tma_aligned_tensor
 from sgl_kernel import fp8_blockwise_scaled_mm
 from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm

-from sglang.srt.layers.quantization.fp8_kernel import w8a8_block_fp8_matmul
+from sglang.srt.layers.quantization.fp8_kernel import (
+    w8a8_block_fp8_matmul_triton as w8a8_block_fp8_matmul,
+)


 def get_weight_shapes(args):
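Both benchmarks keep their existing call sites by aliasing a new explicit kernel back to the old name: the DeepSeek DeepGEMM benchmark binds w8a8_block_fp8_matmul_deepgemm, while this CUTLASS benchmark binds w8a8_block_fp8_matmul_triton. A hedged micro-benchmark sketch in the spirit of these scripts (the real benchmarks sweep model weight shapes and also time the CUTLASS and vLLM kernels; assumes a CUDA device with this sglang build):

# Hypothetical timing of the renamed Triton kernel; numbers depend on hardware.
import torch
from sglang.srt.layers.quantization.fp8_kernel import w8a8_block_fp8_matmul_triton

M, N, K, blk = 1024, 4096, 4096, 128
A = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)
B = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn)
As = torch.rand(M, K // blk, device="cuda", dtype=torch.float32) + 0.5
Bs = torch.rand(N // blk, K // blk, device="cuda", dtype=torch.float32) + 0.5

for _ in range(5):  # warmup (also triggers Triton JIT compilation)
    w8a8_block_fp8_matmul_triton(A, B, As, Bs, [blk, blk], output_dtype=torch.float16)

start, end = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
start.record()
for _ in range(20):
    w8a8_block_fp8_matmul_triton(A, B, As, Bs, [blk, blk], output_dtype=torch.float16)
end.record()
torch.cuda.synchronize()
print(f"avg {start.elapsed_time(end) / 20:.3f} ms")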