sglang · Commits · 45e3a7bc

Commit 45e3a7bc (unverified)
Authored Feb 12, 2025 by Xiaoyu Zhang; committed via GitHub on Feb 12, 2025.

use sgl_per_token_group_quant_fp8 kernel (#3493)
Parent: b96e92e6

Showing 3 changed files with 43 additions and 2 deletions:

  python/pyproject.toml (+1, -1)
  python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py (+8, -1)
  python/sglang/srt/layers/quantization/fp8_kernel.py (+34, -0)
python/pyproject.toml

@@ -25,7 +25,7 @@ runtime_common = [
 ]
 srt = [
     "sglang[runtime_common]", "cuda-python",
-    "sgl-kernel>=0.0.3.post3", "torch", "vllm>=0.6.4.post1,<=0.7.2",
+    "sgl-kernel>=0.0.3.post4", "torch", "vllm>=0.6.4.post1,<=0.7.2",
     "flashinfer_python>=0.2.0.post2",
     "outlines>=0.0.44,<=0.1.11"
 ]
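The only change here is bumping the sgl-kernel pin from 0.0.3.post3 to 0.0.3.post4, presumably the first release that exports the sgl_per_token_group_quant_fp8 kernel imported below. A quick environment check (a sketch, not part of the commit; it assumes sgl-kernel is installed locally):

    import importlib.metadata
    from packaging.version import Version

    # Fail early if the installed sgl-kernel predates the pin bumped in this commit.
    assert Version(importlib.metadata.version("sgl-kernel")) >= Version("0.0.3.post4")
    # Older wheels raise ImportError here because they do not export the new kernel.
    from sgl_kernel import sgl_per_token_group_quant_fp8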
python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py

@@ -33,6 +33,10 @@ _is_rocm = torch.cuda.is_available() and torch.version.hip
 if _is_cuda:
     from sgl_kernel import gelu_and_mul, silu_and_mul
+
+    from sglang.srt.layers.quantization.fp8_kernel import (
+        sglang_per_token_group_quant_fp8,
+    )

 if _is_cuda or _is_rocm:
     from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size

@@ -488,7 +492,10 @@ def invoke_fused_moe_kernel(
     else:
         assert len(block_shape) == 2
         block_n, block_k = block_shape[0], block_shape[1]
-        A, A_scale = per_token_group_quant_fp8(A, block_k)
+        if _is_cuda:
+            A, A_scale = sglang_per_token_group_quant_fp8(A, block_k)
+        else:
+            A, A_scale = per_token_group_quant_fp8(A, block_k)
         assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
         assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2]
         assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1]
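The second hunk routes activation quantization through the fused sgl-kernel op on CUDA and keeps the existing Triton per_token_group_quant_fp8 as the fallback. Either way the result is the FP8 activation tensor plus one float32 scale per group of block_k values, which is the shape relation the triton.cdiv asserts verify. A CPU-only reference sketch of that computation (per_token_group_quant_fp8_reference is an illustrative name, not code from this commit):

    import torch

    def per_token_group_quant_fp8_reference(x: torch.Tensor, group_size: int, eps: float = 1e-10):
        # Quantize each contiguous group of `group_size` values in the last dim to FP8,
        # with one float32 scale per group (absmax / fp8_max), mirroring the kernels' layout.
        assert x.shape[-1] % group_size == 0
        fp8 = torch.float8_e4m3fn
        fp8_max = torch.finfo(fp8).max
        groups = x.reshape(*x.shape[:-1], x.shape[-1] // group_size, group_size).float()
        scale = groups.abs().amax(dim=-1).clamp(min=eps) / fp8_max
        x_q = (groups / scale.unsqueeze(-1)).clamp(-fp8_max, fp8_max).to(fp8)
        return x_q.reshape(x.shape), scale

    A = torch.randn(4, 256)
    A_q, A_scale = per_token_group_quant_fp8_reference(A, group_size=128)
    print(A_q.shape, A_scale.shape)  # torch.Size([4, 256]) torch.Size([4, 2])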
python/sglang/srt/layers/quantization/fp8_kernel.py

@@ -27,6 +27,10 @@ from sglang.srt.utils import get_device_core_count, get_device_name, is_hip
 is_hip_ = is_hip()
 fp8_type_ = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn

+_is_cuda = torch.cuda.is_available() and torch.version.cuda
+if _is_cuda:
+    from sgl_kernel import sgl_per_token_group_quant_fp8
+
 logger = logging.getLogger(__name__)

@@ -135,6 +139,36 @@ def per_token_group_quant_fp8(
     return x_q, x_s


+def sglang_per_token_group_quant_fp8(
+    x: torch.Tensor,
+    group_size: int,
+    eps: float = 1e-10,
+    dtype: torch.dtype = fp8_type_,
+):
+    assert (
+        x.shape[-1] % group_size == 0
+    ), "the last dimension of `x` cannot be divisible by `group_size`"
+    assert x.is_contiguous(), "`x` is not contiguous"
+
+    finfo = torch.finfo(dtype)
+    fp8_max = finfo.max
+    fp8_min = -fp8_max
+
+    x_q = torch.empty_like(x, device=x.device, dtype=dtype)
+    M = x.numel() // group_size
+    N = group_size
+    x_s = torch.empty(
+        x.shape[:-1] + (x.shape[-1] // group_size,),
+        device=x.device,
+        dtype=torch.float32,
+    )
+
+    sgl_per_token_group_quant_fp8(x, x_q, x_s, group_size, eps, fp8_min, fp8_max)
+
+    return x_q, x_s
+
+
 @triton.jit
 def _w8a8_block_fp8_matmul(
     # Pointers to inputs and output
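The new sglang_per_token_group_quant_fp8 wrapper allocates the FP8 output and the float32 per-group scales, then launches the fused sgl-kernel CUDA op in place of the Triton kernel. A usage sketch (not from the commit; it assumes a CUDA device and an sgl-kernel build that exports sgl_per_token_group_quant_fp8):

    import torch

    from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_fp8

    # 8 tokens with hidden size 256, quantized in groups of 128 values each.
    x = torch.randn(8, 256, dtype=torch.bfloat16, device="cuda")
    x_q, x_s = sglang_per_token_group_quant_fp8(x, group_size=128)
    print(x_q.shape, x_q.dtype)  # torch.Size([8, 256]) torch.float8_e4m3fn
    print(x_s.shape, x_s.dtype)  # torch.Size([8, 2]) torch.float32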