change / sglang · Commit 82392da8

support w8a8 fp8 kernel with CUTLASS (#3047)

Unverified commit 82392da8, authored Jan 26, 2025 by HandH1998; committed by GitHub on Jan 26, 2025.
Co-authored-by: yych0745 <1398089567@qq.com>
Parent: 95f789ad

Showing 8 changed files with 881 additions and 0 deletions (+881 -0):
  sgl-kernel/benchmark/bench_fp8_gemm.py                +164 -0
  sgl-kernel/setup.py                                     +2 -0
  sgl-kernel/src/sgl-kernel/__init__.py                   +2 -0
  sgl-kernel/src/sgl-kernel/csrc/fp8_gemm_kernel.cu     +624 -0
  sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h     +5 -0
  sgl-kernel/src/sgl-kernel/ops/__init__.py              +11 -0
  sgl-kernel/src/sgl-kernel/torch_extension.cc            +6 -0
  sgl-kernel/tests/test_fp8_gemm.py                      +67 -0
sgl-kernel/benchmark/bench_fp8_gemm.py (new file, mode 100644)
import argparse
import copy
import itertools

import torch
import triton
from sgl_kernel import fp8_scaled_mm as sgl_scaled_mm
from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm
from vllm._custom_ops import scaled_fp8_quant as vllm_scaled_fp8_quant

# Weight Shapes are in the format
# ([K, N], TP_SPLIT_DIM)
# Example:
# A shape of ([14336, 4096], 0) indicates the following GEMM shape,
# - TP1 : K = 14336, N = 4096
# - TP2 : K = 7168, N = 4096
# A shape of ([4096, 6144], 1) indicates the following GEMM shape,
# - TP1 : K = 4096, N = 6144
# - TP4 : K = 4096, N = 1536

# TP1 shapes
WEIGHT_SHAPES = {
    "meta-llama/Llama-3.1-8B-Instruct": [
        ([4096, 6144], 1),
        ([4096, 4096], 0),
        ([4096, 28672], 1),
        ([14336, 4096], 0),
    ],
    "meta-llama/Llama-3.3-70B-Instruct": [
        ([8192, 10240], 1),
        ([8192, 8192], 0),
        ([8192, 57344], 1),
        ([28672, 8192], 0),
    ],
    "mistralai/Mistral-Large-Instruct-2407": [
        ([12288, 14336], 1),
        ([12288, 12288], 0),
        ([12288, 57344], 1),
        ([28672, 12288], 0),
    ],
    "Qwen/Qwen2.5-7B-Instruct": [
        ([3584, 4608], 1),
        ([3584, 3584], 0),
        ([3584, 37888], 1),
        ([18944, 3584], 0),
    ],
    "Qwen/Qwen2.5-32B-Instruct": [
        ([5120, 7168], 1),
        ([5120, 5120], 0),
        ([5120, 55296], 1),
        ([27648, 5120], 0),
    ],
    "Qwen/Qwen2.5-72B-Instruct": [
        ([8192, 10240], 1),
        ([8192, 8192], 0),
        ([8192, 59136], 1),
        ([29568, 8192], 0),
    ],
    "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct": [
        ([2048, 3072], 1),
        ([2048, 4096], 1),
        ([2048, 2048], 0),
        ([2048, 576], 0),
        ([2048, 21888], 1),
        ([10944, 2048], 0),
        ([2048, 2816], 1),
        ([1408, 2048], 0),
    ],
}


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size"],
        x_vals=[1, 16, 64, 128, 256, 512, 1024, 2048],
        x_log=False,
        line_arg="provider",
        line_vals=[
            "vllm-fp8-fp16",
            "vllm-fp8-bf16",
            "sglang-fp8-fp16",
            "sglang-fp8-bf16",
        ],
        line_names=[
            "vllm-fp8-fp16",
            "vllm-fp8-bf16",
            "sglang-fp8-fp16",
            "sglang-fp8-bf16",
        ],
        styles=[("green", "-"), ("green", "--"), ("blue", "-"), ("blue", "--")],
        ylabel="GB/s",
        plot_name="fp8 scaled matmul",
        args={},
    )
)
def benchmark(batch_size, provider, N, K):
    # M, N, K = batch_size, 4096, 8192
    M = batch_size
    a = torch.ones((M, K), device="cuda") * 5.0
    b = torch.ones((N, K), device="cuda") * 5.0
    scale_a = torch.randn((M,), device="cuda", dtype=torch.float32)
    scale_b = torch.randn((N,), device="cuda", dtype=torch.float32)
    a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a)
    b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b)
    b_fp8 = b_fp8.t()
    quantiles = [0.5, 0.2, 0.8]
    dtype = torch.float16 if "fp16" in provider else torch.bfloat16
    if "vllm-fp8" in provider:
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype),
            quantiles=quantiles,
        )
    elif "sglang-fp8" in provider:
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: sgl_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype, bias=None),
            quantiles=quantiles,
        )
    gbps = lambda ms: (2 * M * N * K + M * N) * a.element_size() * 1e-9 / (ms * 1e-3)
    return gbps(ms), gbps(max_ms), gbps(min_ms)


def prepare_shapes(args):
    KN_model_names = []
    models_tps = list(itertools.product(args.models, args.tp_sizes))
    for model, tp_size in models_tps:
        assert model in WEIGHT_SHAPES
        for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model]):
            KN[tp_split_dim] = KN[tp_split_dim] // tp_size
            KN.append(model)
            KN_model_names.append(KN)
    return KN_model_names


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--models",
        nargs="+",
        type=str,
        default=["meta-llama/Llama-3.1-8B-Instruct"],
        help="List of models to benchmark",
    )
    parser.add_argument(
        "--tp-sizes",
        nargs="+",
        type=int,
        default=[1],
        help="List of tensor parallel sizes",
    )
    args = parser.parse_args()

    KN_model_names = prepare_shapes(args)
    for K, N, model_name in KN_model_names:
        print(f"{model_name} N={N} K={K}: ")
        benchmark.run(
            print_data=True,
            show_plots=True,
            save_path="bench_fp8_res",
            N=N,
            K=K,
        )

    print("Benchmark finished!")
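Assuming a CUDA device and a vLLM installation (the script imports vllm._custom_ops for the baseline), the benchmark should be runnable as, for example, python sgl-kernel/benchmark/bench_fp8_gemm.py --models meta-llama/Llama-3.1-8B-Instruct --tp-sizes 1, matching the argparse defaults above; it sweeps batch sizes from 1 to 2048 over that model's TP-split GEMM shapes and saves plots under bench_fp8_res.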
sgl-kernel/setup.py
...
@@ -56,6 +56,7 @@ include_dirs = [
     turbomind.resolve(),
     turbomind.resolve() / "src",
 ]
 nvcc_flags = [
     "-DNDEBUG",
     f"-DOPERATOR_NAMESPACE={operator_namespace}",
...
@@ -82,6 +83,7 @@ sources = [
     "src/sgl-kernel/csrc/trt_reduce_kernel.cu",
     "src/sgl-kernel/csrc/moe_align_kernel.cu",
     "src/sgl-kernel/csrc/int8_gemm_kernel.cu",
+    "src/sgl-kernel/csrc/fp8_gemm_kernel.cu",
     "src/sgl-kernel/csrc/lightning_attention_decode_kernel.cu",
     "src/sgl-kernel/csrc/rotary_embedding.cu",
     "3rdparty/flashinfer/csrc/activation.cu",
...
sgl-kernel/src/sgl-kernel/__init__.py
...
@@ -2,6 +2,7 @@ from sgl_kernel.ops import (
     bmm_fp8,
     custom_dispose,
     custom_reduce,
+    fp8_scaled_mm,
     fused_add_rmsnorm,
     gelu_and_mul,
     gelu_tanh_and_mul,
...
@@ -27,6 +28,7 @@ __all__ = [
     "bmm_fp8",
     "custom_dispose",
     "custom_reduce",
+    "fp8_scaled_mm",
     "fused_add_rmsnorm",
     "gelu_and_mul",
     "gelu_tanh_and_mul",
...
sgl-kernel/src/sgl-kernel/csrc/fp8_gemm_kernel.cu (new file, mode 100644) — diff collapsed in this view (+624 lines not shown)
sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h
...
@@ -40,6 +40,11 @@ torch::Tensor int8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& ma
                              const torch::Tensor& scales_b, const torch::Dtype& out_dtype,
                              const c10::optional<torch::Tensor>& bias);
 
+// fp8_scaled_mm
+torch::Tensor fp8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& mat_b, const torch::Tensor& scales_a,
+                            const torch::Tensor& scales_b, const torch::Dtype& out_dtype,
+                            const c10::optional<torch::Tensor>& bias);
+
 // lightning_attention_decode
 void lightning_attention_decode(const torch::Tensor& q, const torch::Tensor& k, const torch::Tensor& v,
                                 const torch::Tensor& past_kv, const torch::Tensor& slope, torch::Tensor output,
...
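For reference, the semantics of the declared op (per the torch_scaled_mm reference implementation added in sgl-kernel/tests/test_fp8_gemm.py below): given FP8 inputs mat_a of shape (M, K) and mat_b of shape (K, N), plus float32 scales_a of shape (M,) and scales_b of shape (N,), it computes roughly out = (mat_a @ mat_b) * scales_a[:, None] * scales_b[None, :], optionally adds a bias broadcast over rows, and casts the result to out_dtype.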
sgl-kernel/src/sgl-kernel/ops/__init__.py
...
@@ -71,6 +71,17 @@ def int8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
     )
 
 
+def fp8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
+    return torch.ops.sgl_kernels.fp8_scaled_mm(
+        mat_a,
+        mat_b,
+        scales_a,
+        scales_b,
+        out_dtype,
+        bias,
+    )
+
+
 def lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv):
     torch.ops.sgl_kernels.lightning_attention_decode(
         q, k, v, past_kv, slope, output, new_kv
...
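To illustrate the Python entry point added above, here is a minimal usage sketch. It mirrors the call pattern in the new test (FP8 e4m3 inputs, a row-major activation matrix, a weight matrix created as (N, K) and transposed, and per-row/per-column float32 scales); the shapes and scale magnitudes are arbitrary, and it assumes a CUDA GPU with FP8 support.

import torch
from sgl_kernel import fp8_scaled_mm

M, N, K = 16, 4096, 4096
# Cast small random values to FP8 e4m3 (a real workload would quantize with proper scales).
a_fp8 = torch.randn(M, K, device="cuda", dtype=torch.float32).to(torch.float8_e4m3fn)
b_fp8 = torch.randn(N, K, device="cuda", dtype=torch.float32).to(torch.float8_e4m3fn).t()
# Per-row scale for the activations, per-column scale for the weights.
scale_a = torch.rand(M, device="cuda", dtype=torch.float32) * 0.01
scale_b = torch.rand(N, device="cuda", dtype=torch.float32) * 0.01
out = fp8_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, torch.bfloat16, bias=None)
print(out.shape, out.dtype)  # torch.Size([16, 4096]) torch.bfloat16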
sgl-kernel/src/sgl-kernel/torch_extension.cc
...
@@ -34,6 +34,12 @@ TORCH_LIBRARY_EXPAND(sgl_kernels, m) {
       "bias) -> Tensor");
   m.impl("int8_scaled_mm", torch::kCUDA, &int8_scaled_mm);
 
+  // fp8_scaled_mm
+  m.def(
+      "fp8_scaled_mm(Tensor mat_a, Tensor mat_b, Tensor scales_a, Tensor scales_b, ScalarType out_dtype, Tensor? "
+      "bias) -> Tensor");
+  m.impl("fp8_scaled_mm", torch::kCUDA, &fp8_scaled_mm);
+
   // lightning_attention_decode
   m.def(
       "lightning_attention_decode(Tensor q, Tensor k, Tensor v, Tensor past_kv, Tensor slope, Tensor! output, Tensor! "
...
sgl-kernel/tests/test_fp8_gemm.py (new file, mode 100644)
import unittest

import torch
from sgl_kernel import fp8_scaled_mm


def torch_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias):
    o = torch.matmul(a.to(torch.float32), b.to(torch.float32))
    o = o.to(torch.float32)
    temp1 = o * scale_a.view(-1, 1)
    temp2 = temp1 * scale_b.view(1, -1)
    final = temp2.to(out_dtype)
    if bias is not None:
        final = final + bias.view(1, -1)

    return final


class TestFp8Gemm(unittest.TestCase):
    def _test_accuracy_once(self, M, N, K, with_bias, out_dtype, device):
        fp8_info = torch.finfo(torch.float8_e4m3fn)
        fp8_max, fp8_min = fp8_info.max, fp8_info.min

        a_fp32 = (
            (torch.rand(M, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max
        )
        a_fp8 = a_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)

        b_fp32 = (
            (torch.rand(N, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max
        )
        b_fp8 = b_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)

        scale_a = torch.randn((M,), device=device, dtype=torch.float32) * 0.001
        scale_b = torch.randn((N,), device=device, dtype=torch.float32) * 0.001
        if with_bias:
            bias = torch.randn((N,), device=device, dtype=out_dtype)
        else:
            bias = None
        o1 = torch.empty((M, N), device=device, dtype=torch.bfloat16)
        b_fp8 = b_fp8.t()
        o = torch_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, out_dtype, bias)
        o1 = fp8_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, out_dtype, bias)
        rtol = 0.02
        atol = 1
        torch.testing.assert_close(o, o1, rtol=rtol, atol=atol)
        print(f"M={M}, N={N}, K={K}, with_bias={with_bias}, out_dtype={out_dtype}: OK")

    def test_accuracy(self):
        Ms = [1, 128, 512, 1024, 4096]
        Ns = [16, 128, 512, 1024, 4096]
        Ks = [512, 1024, 4096, 8192, 16384]
        bias_opts = [True, False]
        out_dtypes = [torch.bfloat16, torch.float16]
        for M in Ms:
            for N in Ns:
                for K in Ks:
                    for with_bias in bias_opts:
                        for out_dtype in out_dtypes:
                            self._test_accuracy_once(
                                M, N, K, with_bias, out_dtype, "cuda"
                            )


if __name__ == "__main__":
    unittest.main()
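Since the file ends with unittest.main(), the accuracy test can be launched directly, e.g. python sgl-kernel/tests/test_fp8_gemm.py, on a machine with an FP8-capable CUDA GPU; it sweeps the listed M/N/K combinations with and without bias for bfloat16 and float16 outputs, comparing the CUTLASS kernel against the torch_scaled_mm reference.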