Commit 98eecbda (unverified), authored Feb 13, 2025 by yizhang2077; committed via GitHub on Feb 13, 2025.

integrate blockwise fp8 kernel (#3529)

Parent: 4430c0a5
Showing 3 changed files with 124 additions and 23 deletions (+124, -23).
Changed files:
- python/pyproject.toml (+1, -1)
- python/sglang/srt/layers/quantization/fp8_kernel.py (+90, -18)
- python/sglang/srt/layers/quantization/fp8_utils.py (+33, -4)
python/pyproject.toml (+1, -1)

```diff
@@ -25,7 +25,7 @@ runtime_common = [
 ]
 srt = [
     "sglang[runtime_common]", "cuda-python",
-    "sgl-kernel>=0.0.3.post4", "torch", "vllm>=0.6.4.post1,<=0.7.2",
+    "sgl-kernel>=0.0.3.post5", "torch", "vllm>=0.6.4.post1,<=0.7.2",
     "flashinfer_python>=0.2.0.post2", "outlines>=0.0.44,<=0.1.11"
 ]
```
...
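The only packaging change is bumping `sgl-kernel` from `>=0.0.3.post4` to `>=0.0.3.post5`, the release that exposes the `fp8_blockwise_scaled_mm` CUTLASS entry point used further down. As a hedged sketch (not part of the commit; the helper name is made up for illustration), one way to confirm an environment actually provides the new kernel is to check both the installed package and the import:

```python
# Illustrative helper (not from the commit): confirm the environment satisfies
# the bumped dependency and exposes the blockwise fp8 GEMM entry point.
import importlib.metadata as importlib_metadata


def has_blockwise_fp8_kernel() -> bool:
    try:
        importlib_metadata.version("sgl-kernel")        # raises if not installed
        from sgl_kernel import fp8_blockwise_scaled_mm  # noqa: F401
    except (importlib_metadata.PackageNotFoundError, ImportError):
        return False
    return True


print(has_blockwise_fp8_kernel())
```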
python/sglang/srt/layers/quantization/fp8_kernel.py (+90, -18)

@@ -76,11 +76,60 @@ def _per_token_group_quant_fp8(

```python
    tl.store(y_s_ptr, y_s)


@triton.jit
def _per_token_group_quant_fp8_colmajor(
    # Pointers to inputs and output
    y_ptr,
    y_q_ptr,
    y_s_ptr,
    group_size,
    # Num columns of y
    y_num_columns,
    # Stride from one column to the next of y_s
    y_s_col_stride,
    # Avoid to divide zero
    eps,
    # Information for float8
    fp8_min,
    fp8_max,
    # Meta-parameters
    BLOCK: tl.constexpr,
):
    """A Triton-accelerated function to perform per-token-group
    quantization on a tensor.

    This function converts the tensor values into float8 values.
    """
    # Map the program id to the row of X and Y it should compute.
    g_id = tl.program_id(0)
    y_ptr += g_id * group_size
    y_q_ptr += g_id * group_size

    # Convert g_id the flattened block coordinate to 2D so we can index
    # into the output y_scales matrix
    blocks_per_row = y_num_columns // group_size
    scale_col = g_id % blocks_per_row
    scale_row = g_id // blocks_per_row
    y_s_ptr += scale_col * y_s_col_stride + scale_row

    cols = tl.arange(0, BLOCK)  # group_size <= BLOCK
    mask = cols < group_size

    y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
    # Quant
    _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
    y_s = _absmax / fp8_max
    y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)

    tl.store(y_q_ptr + cols, y_q, mask=mask)
    tl.store(y_s_ptr, y_s)


def per_token_group_quant_fp8(
    x: torch.Tensor,
    group_size: int,
    eps: float = 1e-10,
    dtype: torch.dtype = fp8_type_,
    column_major_scales: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Function to perform per-token-group quantization on an input tensor `x`.
```
...
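Each Triton program quantizes one contiguous group of `group_size` elements: it takes the group's absmax, derives one float32 scale as `absmax / fp8_max`, and clamps the rescaled values into the fp8 range. The only difference from the existing row-major kernel is where the scale is written: the flat group id is converted to a `(scale_row, scale_col)` coordinate and stored through `y_s_col_stride`, so the scale tensor ends up column-major, which is the layout the CUTLASS blockwise GEMM consumes. A minimal pure-PyTorch reference (not part of the commit; names are illustrative and it assumes torch>=2.1 for the float8 dtype) can be used to sanity-check the kernel on small inputs:

```python
import torch


def per_token_group_quant_fp8_reference(
    x: torch.Tensor,
    group_size: int,
    eps: float = 1e-10,
    fp8_dtype: torch.dtype = torch.float8_e4m3fn,
):
    """Reference per-token-group fp8 quantization with row-major scales."""
    finfo = torch.finfo(fp8_dtype)
    groups = x.reshape(-1, group_size).float()
    absmax = groups.abs().amax(dim=-1, keepdim=True).clamp_min(eps)
    scales = absmax / finfo.max                                   # one scale per group
    q = (groups / scales).clamp(finfo.min, finfo.max).to(fp8_dtype)
    x_q = q.reshape(x.shape)
    x_s = scales.reshape(*x.shape[:-1], x.shape[-1] // group_size)
    # The colmajor kernel produces the same values but stores x_s column-major.
    return x_q, x_s
```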
@@ -112,6 +161,13 @@ def per_token_group_quant_fp8(

```python
    x_q = torch.empty_like(x, device=x.device, dtype=dtype)
    M = x.numel() // group_size
    N = group_size
    if column_major_scales:
        x_s = torch.empty(
            (x.shape[-1] // group_size,) + x.shape[:-1],
            device=x.device,
            dtype=torch.float32,
        ).permute(-1, -2)
    else:
        x_s = torch.empty(
            x.shape[:-1] + (x.shape[-1] // group_size,),
            device=x.device,
```
...
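The column-major allocation is subtle: the scale buffer is allocated with the column dimension first and then `permute(-1, -2)`d, so its logical shape is still `(tokens, columns)` while its memory is column-major, and `x_s.stride(1)` (passed to the kernel as `y_s_col_stride`) equals the number of tokens. A small illustrative snippet (not from the commit; example shapes are arbitrary) showing the effect:

```python
import torch

tokens, hidden, group_size = 4, 256, 128
cols = hidden // group_size                      # scale columns per token -> 2

# What the column_major_scales branch does for a 2D input:
x_s = torch.empty((cols, tokens), dtype=torch.float32).permute(-1, -2)

print(x_s.shape)     # torch.Size([4, 2])  logical (tokens, cols)
print(x_s.stride())  # (1, 4)              column-major: stride(1) == tokens
```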
@@ -122,6 +178,22 @@ def per_token_group_quant_fp8(

```python
    # heuristics for number of warps
    num_warps = min(max(BLOCK // 256, 1), 8)
    num_stages = 1

    if column_major_scales:
        _per_token_group_quant_fp8_colmajor[(M,)](
            x,
            x_q,
            x_s,
            group_size,
            x.shape[1],
            x_s.stride(1),
            eps,
            fp8_min=fp8_min,
            fp8_max=fp8_max,
            BLOCK=BLOCK,
            num_warps=num_warps,
            num_stages=num_stages,
        )
    else:
        _per_token_group_quant_fp8[(M,)](
            x,
            x_q,
```
...
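With the new flag, callers choose the scale layout at quantization time; the values are the same either way, only the memory layout of the scales changes. A hedged usage sketch (not part of the commit; assumes a CUDA device with an fp8-capable PyTorch/Triton build):

```python
import torch
from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8

x = torch.randn(16, 512, device="cuda", dtype=torch.bfloat16)

# Row-major scales (default), consumed by the Triton matmul path.
x_q_row, s_row = per_token_group_quant_fp8(x, group_size=128)

# Column-major scales, consumed by the CUTLASS fp8_blockwise_scaled_mm path.
x_q_col, s_col = per_token_group_quant_fp8(x, group_size=128, column_major_scales=True)

# Same scale values, different underlying strides.
assert torch.allclose(s_row, s_col)
```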
python/sglang/srt/layers/quantization/fp8_utils.py (+33, -4)

@@ -10,6 +10,9 @@ from sglang.srt.layers.quantization.fp8_kernel import (

```python
from sglang.srt.utils import is_hip

is_hip_ = is_hip()

_is_cuda = torch.cuda.is_available() and torch.version.cuda
if _is_cuda:
    from sgl_kernel import fp8_blockwise_scaled_mm


def normalize_e4m3fn_to_e4m3fnuz(
```
...
@@ -36,6 +39,19 @@ def normalize_e4m3fn_to_e4m3fnuz(

```python
    return weight, weight_scale, input_scale


def cutlass_block_fp8_supported() -> bool:
    if _is_cuda:
        major, minor = torch.cuda.get_device_capability()
        sm_version = major * 10 + minor
        cuda_version = tuple(map(int, torch.version.cuda.split(".")))
        if cuda_version >= (12, 0) and sm_version >= 90:
            return True
    return False


CUTLASS_BLOCK_FP8_SUPPORTED = cutlass_block_fp8_supported()


def apply_w8a8_block_fp8_linear(
    input: torch.Tensor,
    weight: torch.Tensor,
```
...
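`cutlass_block_fp8_supported()` gates the new GEMM on both the toolkit and the GPU: CUDA 12.0 or newer and compute capability 9.0 or newer (sm_90, i.e. Hopper-class hardware). It is evaluated once at import time into `CUTLASS_BLOCK_FP8_SUPPORTED`. A worked example of the check (illustrative values, not from the commit):

```python
# e.g. an H100 on CUDA 12.4
major, minor = 9, 0                                  # torch.cuda.get_device_capability()
sm_version = major * 10 + minor                      # 90
cuda_version = tuple(map(int, "12.4".split(".")))    # torch.version.cuda -> (12, 4)

print(cuda_version >= (12, 0) and sm_version >= 90)  # True  -> CUTLASS path available

# e.g. an A100 (sm_80) on the same toolkit
print(tuple(map(int, "12.4".split("."))) >= (12, 0) and (8 * 10 + 0) >= 90)  # False
```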
@@ -48,8 +64,21 @@ def apply_w8a8_block_fp8_linear(

```diff
     # View input as 2D matrix for fp8 methods
     input_2d = input.view(-1, input.shape[-1])
     output_shape = [*input.shape[:-1], weight.shape[0]]
-    q_input, x_scale = per_token_group_quant_fp8(input_2d, block_size[1])
+    # TODO: add more robust shape check here
+    shape_supported_by_cutlass = (
+        weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0
+    )
+    if CUTLASS_BLOCK_FP8_SUPPORTED and shape_supported_by_cutlass:
+        q_input, x_scale = per_token_group_quant_fp8(
+            input_2d, block_size[1], column_major_scales=True
+        )
+        output = fp8_blockwise_scaled_mm(
+            q_input, weight.T, x_scale, weight_scale.T, out_dtype=input.dtype
+        )
+    else:
+        q_input, x_scale = per_token_group_quant_fp8(
+            input_2d, block_size[1], column_major_scales=False
+        )
+        output = w8a8_block_fp8_matmul(
+            q_input, weight, x_scale, weight_scale, block_size, output_dtype=input.dtype
+        )
```
...
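At runtime the linear layer therefore dispatches per call: if CUTLASS blockwise fp8 is supported and both weight dimensions are multiples of 128, it quantizes the activations with column-major scales and calls `fp8_blockwise_scaled_mm`; otherwise it falls back to the existing Triton `w8a8_block_fp8_matmul` with row-major scales. A small illustrative check of the shape gate (not from the commit; the helper name and shapes are made up):

```python
def takes_cutlass_path(weight_shape: tuple[int, int], cutlass_supported: bool = True) -> bool:
    """Mirror the shape gate in apply_w8a8_block_fp8_linear."""
    n, k = weight_shape
    return cutlass_supported and n % 128 == 0 and k % 128 == 0


print(takes_cutlass_path((4096, 4096)))  # True  -> fp8_blockwise_scaled_mm
print(takes_cutlass_path((1000, 4096)))  # False -> Triton w8a8_block_fp8_matmul fallback
```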