change / sglang · Commits · Commit 5f1a485d (unverified)

Authored Feb 17, 2025 by Yineng Zhang; committed via GitHub on Feb 17, 2025.
Revert "[ROCm] Use `tl.range()` in block GEMM kernels with `num_stage… (#3632)
Parent: c9565e49
Changes: showing 1 changed file with 6 additions and 101 deletions.
python/sglang/srt/layers/quantization/fp8_kernel.py (+6, -101) @ 5f1a485d
```diff
@@ -272,7 +272,6 @@ def _w8a8_block_fp8_matmul(
     BLOCK_SIZE_N: tl.constexpr,
     BLOCK_SIZE_K: tl.constexpr,
     GROUP_SIZE_M: tl.constexpr,
-    num_stages: tl.constexpr,
 ):
     """Triton-accelerated function used to perform linear operations (dot
     product) on input tensors `A` and `B` with block-wise quantization, and store the result in output
```
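The parameter deleted here fed Triton's per-loop software-pipelining override. A minimal, illustrative kernel (not from this repo; it assumes a Triton release whose `tl.range()` accepts a `num_stages` argument) shows the loop form this revert removes and the plain `range()` form it restores:

```python
import triton
import triton.language as tl


@triton.jit
def _blockwise_sum(X, Out, K, BLOCK_K: tl.constexpr, num_stages: tl.constexpr):
    # Hypothetical kernel for illustration only: sums K elements in blocks.
    offs = tl.arange(0, BLOCK_K)
    acc = tl.zeros((BLOCK_K,), dtype=tl.float32)
    # Form removed by this revert: tl.range() pins the software-pipeline
    # depth of this specific loop via its num_stages keyword.
    for k in tl.range(0, tl.cdiv(K, BLOCK_K), num_stages=num_stages):
        acc += tl.load(X + k * BLOCK_K + offs, mask=k * BLOCK_K + offs < K, other=0.0)
    # Form restored by this revert (compiler/autotuner decides pipelining):
    #   for k in range(0, tl.cdiv(K, BLOCK_K)):
    #       ...
    tl.store(Out + offs, acc)
```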
```diff
@@ -358,7 +357,6 @@ def _w8a8_block_fp8_matmul_unrolledx4(
     BLOCK_SIZE_N: tl.constexpr,
     BLOCK_SIZE_K: tl.constexpr,
     GROUP_SIZE_M: tl.constexpr,
-    num_stages: tl.constexpr,
 ):
     """Triton-accelerated function used to perform linear operations (dot
     product) on input tensors `A` and `B` with block-wise quantization, and store the result in output
```
```diff
@@ -388,9 +386,7 @@ def _w8a8_block_fp8_matmul_unrolledx4(
     accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

     # manually unroll to 4 iterations
     UNROLL_FACTOR = 4
-    for k in tl.range(
-        0, tl.cdiv(K, BLOCK_SIZE_K * UNROLL_FACTOR), num_stages=num_stages
-    ):
+    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K * UNROLL_FACTOR)):
         # 1st iteration
         a = tl.load(a_ptrs,
```
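For context on the "manually unroll to 4 iterations" comment: each outer iteration of this loop consumes four K blocks back to back, giving the compiler a longer straight-line body in which to overlap loads with `tl.dot` math. A plain-Python analogue of the index arithmetic (illustrative only, not code from this file):

```python
# Which K-block indices does each unrolled outer iteration touch?
UNROLL_FACTOR = 4

def cdiv(a: int, b: int) -> int:
    return -(-a // b)  # ceiling division, like tl.cdiv

def unrolled_k_blocks(K: int, block_k: int, unroll: int = UNROLL_FACTOR):
    """Yield the K-block indices covered by each unrolled outer iteration."""
    for outer in range(cdiv(K, block_k * unroll)):
        base = outer * unroll
        # the four "iterations" inside one loop body; tail blocks past K
        # are handled in the real kernel by load masks
        yield [base + i for i in range(unroll) if (base + i) * block_k < K]

# e.g. K=1024, BLOCK_SIZE_K=128 -> [[0, 1, 2, 3], [4, 5, 6, 7]]
print(list(unrolled_k_blocks(1024, 128)))
```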
```diff
@@ -489,92 +485,6 @@ def _w8a8_block_fp8_matmul_unrolledx4(
     tl.store(c_ptrs, c, mask=c_mask)


-@triton.jit
-def _w8a8_block_fp8_matmul_hip(
-    # Pointers to inputs and output
-    A,
-    B,
-    C,
-    As,
-    Bs,
-    # Shape for matmul
-    M,
-    N,
-    K,
-    # Block size for block-wise quantization
-    group_n,
-    group_k,
-    # Stride for inputs and output
-    stride_am,
-    stride_ak,
-    stride_bk,
-    stride_bn,
-    stride_cm,
-    stride_cn,
-    stride_As_m,
-    stride_As_k,
-    stride_Bs_k,
-    stride_Bs_n,
-    # Meta-parameters
-    BLOCK_SIZE_M: tl.constexpr,
-    BLOCK_SIZE_N: tl.constexpr,
-    BLOCK_SIZE_K: tl.constexpr,
-    GROUP_SIZE_M: tl.constexpr,
-    num_stages: tl.constexpr,
-):
-    """Triton-accelerated function used to perform linear operations (dot
-    product) on input tensors `A` and `B` with block-wise quantization, and store the result in output
-    tensor `C`.
-    """
-
-    pid = tl.program_id(axis=0)
-    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
-    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
-    num_pid_in_group = GROUP_SIZE_M * num_pid_n
-    group_id = pid // num_pid_in_group
-    first_pid_m = group_id * GROUP_SIZE_M
-    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
-    pid_m = first_pid_m + (pid % group_size_m)
-    pid_n = (pid % num_pid_in_group) // group_size_m
-
-    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
-    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
-    offs_k = tl.arange(0, BLOCK_SIZE_K)
-    a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
-    b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
-
-    As_ptrs = As + offs_am * stride_As_m
-    offs_bsn = offs_bn // group_n
-    Bs_ptrs = Bs + offs_bsn * stride_Bs_n
-
-    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
-    for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K), num_stages=num_stages):
-        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
-        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
-
-        k_start = k * BLOCK_SIZE_K
-        offs_ks = k_start // group_k
-        a_s = tl.load(As_ptrs + offs_ks * stride_As_k)
-        b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k)
-
-        accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
-        a_ptrs += BLOCK_SIZE_K * stride_ak
-        b_ptrs += BLOCK_SIZE_K * stride_bk
-
-    if C.dtype.element_ty == tl.bfloat16:
-        c = accumulator.to(tl.bfloat16)
-    elif C.dtype.element_ty == tl.float16:
-        c = accumulator.to(tl.float16)
-    else:
-        c = accumulator.to(tl.float32)
-
-    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
-    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
-    c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
-    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
-
-    tl.store(c_ptrs, c, mask=c_mask)
-
-
 @functools.lru_cache
 def get_w8a8_block_fp8_configs(
     N: int, K: int, block_n: int, block_k: int
```
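The core update of the deleted kernel, `accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]`, dequantizes on the fly: each K-block partial product is rescaled by A's per-row scale and B's per-N-block scale for that K group. A NumPy reference of that math (a sketch under assumed shapes, not code from this repo; it assumes the matmul K block equals the quantization group size and that K, N divide evenly):

```python
import numpy as np

def blockwise_fp8_matmul_reference(A, B, As, Bs, block_k, block_n):
    """Block-wise dequantized matmul reference.

    A:  (M, K) quantized values (plain floats here for simplicity)
    As: (M, K // block_k) per-(row, K-block) scales of A
    B:  (K, N) quantized values
    Bs: (K // block_k, N // block_n) per-(K-block, N-block) scales of B
    """
    M, K = A.shape
    _, N = B.shape
    C = np.zeros((M, N), dtype=np.float32)
    for kb in range(K // block_k):
        ks = slice(kb * block_k, (kb + 1) * block_k)
        # partial product over one K block, accumulated in float32
        partial = A[:, ks].astype(np.float32) @ B[ks, :].astype(np.float32)
        a_s = As[:, kb][:, None]                   # (M, 1) row scales
        b_s = np.repeat(Bs[kb], block_n)[None, :]  # (1, N) column-block scales
        C += partial * a_s * b_s
    return C
```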
```diff
@@ -685,16 +595,11 @@ def w8a8_block_fp8_matmul(
     num_workgroups = triton.cdiv(M, config["BLOCK_SIZE_M"]) * triton.cdiv(
         N, config["BLOCK_SIZE_N"]
     )
-    kernel = _w8a8_block_fp8_matmul
-
-    # On AMD GPU, use kernels where software pipelining with num_stages is
-    # explicitly specified in the hot loop.
-    if is_hip_ == True:
-        if num_workgroups <= get_device_core_count():
-            kernel = _w8a8_block_fp8_matmul_unrolledx4
-        else:
-            kernel = _w8a8_block_fp8_matmul_hip
+    kernel = (
+        _w8a8_block_fp8_matmul_unrolledx4
+        if (is_hip_ == True and num_workgroups <= get_device_core_count())
+        else _w8a8_block_fp8_matmul
+    )

     kernel[grid](
         A,
```
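With the `_hip` kernel gone, dispatch collapses back to a single conditional expression: on HIP, the 4x-unrolled variant is chosen only when the launch grid fits within the device's core count, so each workgroup can map onto its own compute unit; otherwise the generic kernel runs. A plain-Python sketch of that restored logic (illustrative only; the core count below is an example figure):

```python
# Sketch of the restored kernel-selection logic from this hunk.
def select_kernel(is_hip: bool, num_workgroups: int, device_core_count: int) -> str:
    return (
        "_w8a8_block_fp8_matmul_unrolledx4"
        if (is_hip and num_workgroups <= device_core_count)
        else "_w8a8_block_fp8_matmul"
    )

# e.g. on a hypothetical 304-CU HIP device:
print(select_kernel(True, 256, 304))   # _w8a8_block_fp8_matmul_unrolledx4
print(select_kernel(True, 4096, 304))  # _w8a8_block_fp8_matmul
print(select_kernel(False, 256, 304))  # _w8a8_block_fp8_matmul
```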