change / sglang / Commits / 98eecbda

Commit 98eecbda (unverified), authored Feb 13, 2025 by yizhang2077; committed by GitHub, Feb 13, 2025
Parent: 4430c0a5

    integrate blockwise fp8 kernel (#3529)

Showing 3 changed files with 124 additions and 23 deletions (+124, -23):
- python/pyproject.toml (+1, -1)
- python/sglang/srt/layers/quantization/fp8_kernel.py (+90, -18)
- python/sglang/srt/layers/quantization/fp8_utils.py (+33, -4)
python/pyproject.toml  (view file @ 98eecbda)

...
@@ -25,7 +25,7 @@ runtime_common = [
 ]
 srt = [
     "sglang[runtime_common]", "cuda-python",
-    "sgl-kernel>=0.0.3.post4", "torch", "vllm>=0.6.4.post1,<=0.7.2",
+    "sgl-kernel>=0.0.3.post5", "torch", "vllm>=0.6.4.post1,<=0.7.2",
     "flashinfer_python>=0.2.0.post2", "outlines>=0.0.44,<=0.1.11"
 ]
...
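The dependency bump matters because the new cutlass-backed path in fp8_utils.py imports its GEMM from sgl-kernel. A quick, illustrative check (not part of the commit) that the installed build actually exposes that symbol:

```python
# Illustrative only: confirm the installed sgl-kernel provides the blockwise FP8 GEMM
# imported by fp8_utils.py below (requires sgl-kernel>=0.0.3.post5 per this commit).
try:
    from sgl_kernel import fp8_blockwise_scaled_mm
    print("fp8_blockwise_scaled_mm available")
except ImportError as err:
    print("blockwise FP8 kernel not available:", err)
```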
python/sglang/srt/layers/quantization/fp8_kernel.py  (view file @ 98eecbda)

...
@@ -76,11 +76,60 @@ def _per_token_group_quant_fp8(
     tl.store(y_s_ptr, y_s)
 
 
+@triton.jit
+def _per_token_group_quant_fp8_colmajor(
+    # Pointers to inputs and output
+    y_ptr,
+    y_q_ptr,
+    y_s_ptr,
+    group_size,
+    # Num columns of y
+    y_num_columns,
+    # Stride from one column to the next of y_s
+    y_s_col_stride,
+    # Avoid to divide zero
+    eps,
+    # Information for float8
+    fp8_min,
+    fp8_max,
+    # Meta-parameters
+    BLOCK: tl.constexpr,
+):
+    """A Triton-accelerated function to perform per-token-group
+    quantization on a tensor.
+    This function converts the tensor values into float8 values.
+    """
+    # Map the program id to the row of X and Y it should compute.
+    g_id = tl.program_id(0)
+    y_ptr += g_id * group_size
+    y_q_ptr += g_id * group_size
+
+    # Convert g_id the flattened block coordinate to 2D so we can index
+    # into the output y_scales matrix
+    blocks_per_row = y_num_columns // group_size
+    scale_col = g_id % blocks_per_row
+    scale_row = g_id // blocks_per_row
+    y_s_ptr += scale_col * y_s_col_stride + scale_row
+
+    cols = tl.arange(0, BLOCK)  # group_size <= BLOCK
+    mask = cols < group_size
+
+    y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
+    # Quant
+    _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
+    y_s = _absmax / fp8_max
+    y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
+
+    tl.store(y_q_ptr + cols, y_q, mask=mask)
+    tl.store(y_s_ptr, y_s)
+
+
 def per_token_group_quant_fp8(
     x: torch.Tensor,
     group_size: int,
     eps: float = 1e-10,
     dtype: torch.dtype = fp8_type_,
+    column_major_scales: bool = False,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """Function to perform per-token-group quantization on an input tensor `x`.
...
@@ -112,29 +161,52 @@ def per_token_group_quant_fp8(
     x_q = torch.empty_like(x, device=x.device, dtype=dtype)
     M = x.numel() // group_size
     N = group_size
-    x_s = torch.empty(
-        x.shape[:-1] + (x.shape[-1] // group_size,),
-        device=x.device,
-        dtype=torch.float32,
-    )
+    if column_major_scales:
+        x_s = torch.empty(
+            (x.shape[-1] // group_size,) + x.shape[:-1],
+            device=x.device,
+            dtype=torch.float32,
+        ).permute(-1, -2)
+    else:
+        x_s = torch.empty(
+            x.shape[:-1] + (x.shape[-1] // group_size,),
+            device=x.device,
+            dtype=torch.float32,
+        )
 
     BLOCK = triton.next_power_of_2(N)
     # heuristics for number of warps
     num_warps = min(max(BLOCK // 256, 1), 8)
     num_stages = 1
-    _per_token_group_quant_fp8[(M,)](
-        x,
-        x_q,
-        x_s,
-        group_size,
-        N,
-        eps,
-        fp8_min=fp8_min,
-        fp8_max=fp8_max,
-        BLOCK=BLOCK,
-        num_warps=num_warps,
-        num_stages=num_stages,
-    )
+    if column_major_scales:
+        _per_token_group_quant_fp8_colmajor[(M,)](
+            x,
+            x_q,
+            x_s,
+            group_size,
+            x.shape[1],
+            x_s.stride(1),
+            eps,
+            fp8_min=fp8_min,
+            fp8_max=fp8_max,
+            BLOCK=BLOCK,
+            num_warps=num_warps,
+            num_stages=num_stages,
+        )
+    else:
+        _per_token_group_quant_fp8[(M,)](
+            x,
+            x_q,
+            x_s,
+            group_size,
+            N,
+            eps,
+            fp8_min=fp8_min,
+            fp8_max=fp8_max,
+            BLOCK=BLOCK,
+            num_warps=num_warps,
+            num_stages=num_stages,
+        )
 
     return x_q, x_s
...
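The column_major_scales branch allocates the scale tensor with its two dimensions swapped and then permutes them back, so the logical shape matches the row-major case while the underlying memory is column-major, which is the layout the cutlass blockwise GEMM consumes. A small standalone illustration of that trick, with made-up sizes:

```python
import torch

M, K, group_size = 4, 256, 128   # assumed sizes, for illustration only
row_major = torch.empty(M, K // group_size, dtype=torch.float32)
col_major = torch.empty(K // group_size, M, dtype=torch.float32).permute(-1, -2)

print(row_major.shape, row_major.stride())   # torch.Size([4, 2]) (2, 1)
print(col_major.shape, col_major.stride())   # torch.Size([4, 2]) (1, 4)
```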
python/sglang/srt/layers/quantization/fp8_utils.py  (view file @ 98eecbda)

...
@@ -10,6 +10,9 @@ from sglang.srt.layers.quantization.fp8_kernel import (
 from sglang.srt.utils import is_hip
 
 is_hip_ = is_hip()
+_is_cuda = torch.cuda.is_available() and torch.version.cuda
+if _is_cuda:
+    from sgl_kernel import fp8_blockwise_scaled_mm
 
 
 def normalize_e4m3fn_to_e4m3fnuz(
...
@@ -36,6 +39,19 @@ def normalize_e4m3fn_to_e4m3fnuz(
     return weight, weight_scale, input_scale
 
 
+def cutlass_block_fp8_supported() -> bool:
+    if _is_cuda:
+        major, minor = torch.cuda.get_device_capability()
+        sm_version = major * 10 + minor
+        cuda_version = tuple(map(int, torch.version.cuda.split(".")))
+        if cuda_version >= (12, 0) and sm_version >= 90:
+            return True
+    return False
+
+
+CUTLASS_BLOCK_FP8_SUPPORTED = cutlass_block_fp8_supported()
+
+
 def apply_w8a8_block_fp8_linear(
     input: torch.Tensor,
     weight: torch.Tensor,
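cutlass_block_fp8_supported gates the new path on CUDA 12+ and SM 90+ (Hopper-class GPUs). A quick standalone check mirroring that logic, useful for confirming which path a given machine will take; illustrative only, not part of the commit:

```python
import torch

if torch.cuda.is_available() and torch.version.cuda:
    major, minor = torch.cuda.get_device_capability()
    sm = major * 10 + minor
    cuda = tuple(map(int, torch.version.cuda.split(".")))
    eligible = cuda >= (12, 0) and sm >= 90
    print(f"SM {sm}, CUDA {torch.version.cuda}:",
          "cutlass blockwise FP8 eligible" if eligible else "triton fallback")
else:
    print("no CUDA runtime: triton fallback path")
```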
...
@@ -48,11 +64,24 @@ def apply_w8a8_block_fp8_linear(
     # View input as 2D matrix for fp8 methods
     input_2d = input.view(-1, input.shape[-1])
     output_shape = [*input.shape[:-1], weight.shape[0]]
-    q_input, x_scale = per_token_group_quant_fp8(input_2d, block_size[1])
-    output = w8a8_block_fp8_matmul(
-        q_input, weight, x_scale, weight_scale, block_size, output_dtype=input.dtype
-    )
+    # TODO: add more robust shape check here
+    shape_supported_by_cutlass = (
+        weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0
+    )
+    if CUTLASS_BLOCK_FP8_SUPPORTED and shape_supported_by_cutlass:
+        q_input, x_scale = per_token_group_quant_fp8(
+            input_2d, block_size[1], column_major_scales=True
+        )
+        output = fp8_blockwise_scaled_mm(
+            q_input, weight.T, x_scale, weight_scale.T, out_dtype=input.dtype
+        )
+    else:
+        q_input, x_scale = per_token_group_quant_fp8(
+            input_2d, block_size[1], column_major_scales=False
+        )
+        output = w8a8_block_fp8_matmul(
+            q_input, weight, x_scale, weight_scale, block_size, output_dtype=input.dtype
+        )
 
     if bias is not None:
         output = output + bias
...
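To summarize the new dispatch: both paths quantize the activations per token-group along K, but the cutlass path wants column-major activation scales plus the transposed weight and weight scales, while the triton path keeps the original layouts. A rough shape sketch under assumed sizes (block_size = [128, 128], the usual DeepSeek-style blockwise FP8 scheme); the exact weight-scale layout here is an assumption for illustration, not taken from the diff:

```python
# Assumed sizes, for illustration only.
M, N, K = 16, 256, 512
block_size = [128, 128]

x_scale_shape = (M, K // block_size[1])                   # one scale per (token, K-group)
w_scale_shape = (N // block_size[0], K // block_size[1])  # one scale per 128x128 weight block

# cutlass path: fp8_blockwise_scaled_mm(q_input, weight.T, x_scale, weight_scale.T, out_dtype=...)
# triton path:  w8a8_block_fp8_matmul(q_input, weight, x_scale, weight_scale, block_size, output_dtype=...)
print(x_scale_shape, w_scale_shape)   # (16, 4) (2, 4)
```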