Unverified commit bf669606, authored Jan 23, 2025 by Yineng Zhang, committed via GitHub on Jan 23, 2025.

feat: integrate bmm_fp8 kernel into sgl-kernel (#3056)
parent b2bd8f44

Showing 6 changed files with 131 additions and 12 deletions (+131 -12)
sgl-kernel/setup.py                                +11  -1
sgl-kernel/src/sgl-kernel/__init__.py               +2  -0
sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu    +6  -0
sgl-kernel/src/sgl-kernel/ops/__init__.py          +50 -11
sgl-kernel/src/sgl-kernel/ops/utils.py             +19  -0
sgl-kernel/tests/test_bmm_fp8.py                   +43  -0
sgl-kernel/setup.py
@@ -62,12 +62,22 @@ nvcc_flags = [
     "-std=c++17",
     "-use_fast_math",
     "-DFLASHINFER_ENABLE_F16",
-    "-DFLASHINFER_ENABLE_BF16",
 ]
 if cuda_version >= (12, 0) and sm_version >= 90:
     nvcc_flags.append("-gencode=arch=compute_90a,code=sm_90a")
+if sm_version >= 90:
+    nvcc_flags.extend(
+        [
+            "-DFLASHINFER_ENABLE_FP8",
+            "-DFLASHINFER_ENABLE_FP8_E4M3",
+            "-DFLASHINFER_ENABLE_FP8_E5M2",
+        ]
+    )
+if sm_version >= 80:
+    nvcc_flags.append("-DFLASHINFER_ENABLE_BF16")
 for flag in [
     "-D__CUDA_NO_HALF_OPERATORS__",
     "-D__CUDA_NO_HALF_CONVERSIONS__",
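The cuda_version and sm_version variables used above are defined earlier in setup.py, outside this hunk. As a hedged sketch only, not the commit's actual code, one common way to derive such values with standard torch APIs:

    import torch

    # Hypothetical helper (not part of this commit): derive the values
    # that gate the FP8/BF16 compile flags above.
    def detect_cuda_and_sm_version():
        # torch.version.cuda is a string such as "12.4" -> (12, 4)
        cuda_version = tuple(int(v) for v in torch.version.cuda.split("."))
        # Compute capability of device 0, e.g. (9, 0) -> 90 (Hopper)
        major, minor = torch.cuda.get_device_capability(0)
        return cuda_version, major * 10 + minor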
sgl-kernel/src/sgl-kernel/__init__.py
 from sgl_kernel.ops import (
+    bmm_fp8,
     custom_dispose,
     custom_reduce,
     fused_add_rmsnorm,
@@ -18,6 +19,7 @@ from sgl_kernel.ops import (
 )

 __all__ = [
+    "bmm_fp8",
     "custom_dispose",
     "custom_reduce",
     "fused_add_rmsnorm",
sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu
@@ -52,6 +52,10 @@ void gelu_tanh_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
 // gelu and mul
 void gelu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);

+// bmm fp8
+void bmm_fp8(at::Tensor A, at::Tensor B, at::Tensor D, at::Tensor A_scale, at::Tensor B_scale,
+             at::Tensor workspace_buffer, int64_t cublas_handle, int64_t cuda_stream);
+
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   // trt_reduce
   m.def("init_custom_ar", &init_custom_ar, "init custom allreduce meta (CUDA)");
@@ -81,4 +85,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("gelu_tanh_and_mul", &gelu_tanh_and_mul, "Gelu Tanh and Mul (CUDA)");
   // gelu and mul
   m.def("gelu_and_mul", &gelu_and_mul, "Gelu and Mul (CUDA)");
+  // bmm fp8
+  m.def("bmm_fp8", &bmm_fp8, "BMM FP8 (CUDA)");
 }
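Note that the cuBLAS handle and CUDA stream cross the pybind11 boundary as raw int64_t addresses rather than typed objects. A short sketch of how the Python side obtains those integers with standard torch calls (mirroring _bmm_fp8_internal in ops/__init__.py below):

    import torch

    device = torch.device("cuda:0")
    # Address of the cuBLAS handle PyTorch owns for the current device
    cublas_handle = torch.cuda.current_blas_handle()
    # Raw cudaStream_t of the current stream on `device`, as an int
    cuda_stream = torch.cuda.current_stream(device).cuda_stream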
sgl-kernel/src/sgl-kernel/ops/__init__.py
@@ -2,6 +2,7 @@ from typing import Optional
 import torch

 from sgl_kernel.ops._kernels import all_reduce as _all_reduce
+from sgl_kernel.ops._kernels import bmm_fp8 as _bmm_fp8
 from sgl_kernel.ops._kernels import dispose as _dispose
 from sgl_kernel.ops._kernels import fused_add_rmsnorm as _fused_add_rmsnorm
 from sgl_kernel.ops._kernels import gelu_and_mul as _gelu_and_mul
@@ -21,10 +22,7 @@ from sgl_kernel.ops._kernels import (
     sampling_scaling_penalties as _sampling_scaling_penalties,
 )
 from sgl_kernel.ops._kernels import silu_and_mul as _silu_and_mul
-
-
-def get_cuda_stream(device: torch.device) -> int:
-    return torch.cuda.current_stream(device).cuda_stream
+from sgl_kernel.ops.utils import _get_cache_buf, _get_cuda_stream


 def init_custom_reduce(
@@ -101,7 +99,7 @@ def rmsnorm(
     with input.device as device:
         if out is None:
             out = torch.empty_like(input)
-        _rmsnorm(out, input, weight, eps, get_cuda_stream(device))
+        _rmsnorm(out, input, weight, eps, _get_cuda_stream(device))
         return out
@@ -109,7 +107,7 @@ def fused_add_rmsnorm(
     input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
 ) -> None:
     with input.device as device:
-        _fused_add_rmsnorm(input, residual, weight, eps, get_cuda_stream(device))
+        _fused_add_rmsnorm(input, residual, weight, eps, _get_cuda_stream(device))


 def gemma_rmsnorm(
@@ -121,7 +119,7 @@ def gemma_rmsnorm(
     with input.device as device:
         if out is None:
             out = torch.empty_like(input)
-        _gemma_rmsnorm(out, input, weight, eps, get_cuda_stream(device))
+        _gemma_rmsnorm(out, input, weight, eps, _get_cuda_stream(device))
         return out
@@ -129,7 +127,7 @@ def gemma_fused_add_rmsnorm(
     input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
 ) -> None:
     with input.device as device:
-        _gemma_fused_add_rmsnorm(input, residual, weight, eps, get_cuda_stream(device))
+        _gemma_fused_add_rmsnorm(input, residual, weight, eps, _get_cuda_stream(device))


 def _check_shape(input: torch.Tensor, output: torch.Tensor) -> None:
@@ -154,7 +152,7 @@ def silu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
         dtype=input.dtype,
     )
     with input.device as device:
-        _silu_and_mul(out, input, get_cuda_stream(device))
+        _silu_and_mul(out, input, _get_cuda_stream(device))
     return out
@@ -170,7 +168,7 @@ def gelu_tanh_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Te
         dtype=input.dtype,
     )
     with input.device as device:
-        _gelu_tanh_and_mul(out, input, get_cuda_stream(device))
+        _gelu_tanh_and_mul(out, input, _get_cuda_stream(device))
     return out
@@ -186,5 +184,46 @@ def gelu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
         dtype=input.dtype,
     )
     with input.device as device:
-        _gelu_and_mul(out, input, get_cuda_stream(device))
+        _gelu_and_mul(out, input, _get_cuda_stream(device))
     return out
+
+
+def _bmm_fp8_internal(
+    workspace_buffer: torch.Tensor,
+    A: torch.Tensor,
+    B: torch.Tensor,
+    D: torch.Tensor,
+    A_scale: torch.Tensor,
+    B_scale: torch.Tensor,
+) -> None:
+    with A.device as device:
+        cublas_handle = torch.cuda.current_blas_handle()
+        _bmm_fp8(
+            A,
+            B,
+            D,
+            A_scale,
+            B_scale,
+            workspace_buffer,
+            cublas_handle,
+            _get_cuda_stream(device),
+        )
+
+
+def bmm_fp8(
+    A: torch.Tensor,
+    B: torch.Tensor,
+    A_scale: torch.Tensor,
+    B_scale: torch.Tensor,
+    dtype: torch.dtype,
+    out: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    if out is None:
+        out = torch.empty(
+            (A.shape[0], A.shape[1], B.shape[2]),
+            device=A.device,
+            dtype=dtype,
+        )
+    workspace_buffer = _get_cache_buf("bmm_fp8_workspace", 32 * 1024 * 1024, A.device)
+    _bmm_fp8_internal(workspace_buffer, A, B, out, A_scale, B_scale)
+    return out
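A minimal usage sketch of the new bmm_fp8 wrapper, assuming unit per-tensor scales for brevity (real callers should quantize with proper scales, as the test below does with its to_float8 helper):

    import torch
    from sgl_kernel import bmm_fp8

    # A: [batch, M, K] in e4m3; B: [batch, K, N] in e4m3 with column-major
    # layout (hence the transpose trick; see the test below).
    a = torch.randn(16, 48, 64, device="cuda", dtype=torch.bfloat16)
    b = torch.randn(16, 80, 64, device="cuda", dtype=torch.bfloat16).transpose(-2, -1)
    a_fp8 = a.to(torch.float8_e4m3fn)  # sketch only: unscaled cast
    b_fp8 = b.to(torch.float8_e4m3fn)
    scale = torch.ones((), device="cuda", dtype=torch.float32)
    out = bmm_fp8(a_fp8, b_fp8, scale, scale, torch.bfloat16)  # [16, 48, 80]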
sgl-kernel/src/sgl-kernel/ops/utils.py (new file, mode 100644)
from typing import Dict, Tuple

import torch


def _get_cuda_stream(device: torch.device) -> int:
    return torch.cuda.current_stream(device).cuda_stream


_cache_buf: Dict[Tuple[str, torch.device], torch.Tensor] = {}


def _get_cache_buf(name: str, bytes: int, device: torch.device) -> torch.Tensor:
    key = (name, device)
    buf = _cache_buf.get(key)
    if buf is None:
        buf = torch.empty(bytes, dtype=torch.uint8, device=device)
        _cache_buf[key] = buf
    return buf
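_get_cache_buf memoizes one tensor per (name, device) key, so repeated bmm_fp8 calls on a device reuse a single 32 MiB workspace instead of reallocating. Note the requested byte size is not part of the key, so a given name should always be requested at a consistent size. A quick illustrative check (assuming a CUDA device is available):

    import torch
    from sgl_kernel.ops.utils import _get_cache_buf

    dev = torch.device("cuda:0")
    buf1 = _get_cache_buf("bmm_fp8_workspace", 32 * 1024 * 1024, dev)
    buf2 = _get_cache_buf("bmm_fp8_workspace", 32 * 1024 * 1024, dev)
    assert buf1.data_ptr() == buf2.data_ptr()  # same cached allocation reused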
sgl-kernel/tests/test_bmm_fp8.py (new file, mode 100644)
# Adapted from https://github.com/flashinfer-ai/flashinfer/blob/4e8eb1879f9c3ba6d75511e5893183bf8f289a62/tests/test_bmm_fp8.py
import pytest
import torch
import torch.nn.functional as F
from sgl_kernel import bmm_fp8


def to_float8(x, dtype=torch.float8_e4m3fn):
    finfo = torch.finfo(dtype)
    min_val, max_val = x.aminmax()
    amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
    scale = finfo.max / amax
    x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
    return x_scl_sat.to(dtype), scale.float().reciprocal()


@pytest.mark.parametrize("input_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
@pytest.mark.parametrize("mat2_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
@pytest.mark.parametrize("res_dtype", [torch.bfloat16, torch.float16])
def test_bmm_fp8(input_dtype, mat2_dtype, res_dtype):
    if input_dtype == torch.float8_e5m2 and mat2_dtype == torch.float8_e5m2:
        pytest.skip("Invalid combination: both input and mat2 are e5m2")

    input = torch.randn([16, 48, 64], device="cuda", dtype=torch.bfloat16)
    input_fp8, input_inv_s = to_float8(input, dtype=input_dtype)

    # mat2 row major -> column major
    mat2 = torch.randn([16, 80, 64], device="cuda", dtype=torch.bfloat16).transpose(-2, -1)
    mat2_fp8, mat2_inv_s = to_float8(mat2, dtype=mat2_dtype)

    res = torch.empty([16, 48, 80], device="cuda", dtype=res_dtype)
    bmm_fp8(input_fp8, mat2_fp8, input_inv_s, mat2_inv_s, res_dtype, res)

    reference = torch.bmm(input, mat2)
    cos_sim = F.cosine_similarity(reference.reshape(-1), res.reshape(-1), dim=0)
    assert cos_sim > 0.99


if __name__ == "__main__":
    pytest.main([__file__])
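For reference, to_float8 returns the reciprocal of its quantization scale, i.e. a dequantization factor, which is what bmm_fp8 expects for A_scale and B_scale. A standalone sketch of that round-trip convention:

    import torch

    x = torch.randn(8, device="cuda", dtype=torch.bfloat16)
    finfo = torch.finfo(torch.float8_e4m3fn)
    scale = finfo.max / x.abs().max().clamp(min=1e-12)  # quantization scale
    x_fp8 = (x * scale).clamp(min=finfo.min, max=finfo.max).to(torch.float8_e4m3fn)
    inv_scale = scale.float().reciprocal()              # what to_float8 returns
    x_rec = x_fp8.float() * inv_scale                   # ~= x, up to fp8 rounding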