Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
32de54ed
Unverified
Commit
32de54ed
Authored Feb 05, 2025 by Wen-Heng (Jack) Chung
Committed by GitHub, Feb 05, 2025
Browse files
[ROCm] Fix fp8 unrolledx4 matmul kernel. (#3325)
Co-authored-by: HAI <hixiao@gmail.com>
parent
2d9c3195
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing 1 changed file with 43 additions and 10 deletions
+43
-10
python/sglang/srt/layers/quantization/fp8_kernel.py
python/sglang/srt/layers/quantization/fp8_kernel.py
+43
-10
No files found.
python/sglang/srt/layers/quantization/fp8_kernel.py
View file @
32de54ed
...
...
@@ -279,12 +279,21 @@ def _w8a8_block_fp8_matmul_unrolledx4(
accumulator
=
tl
.
zeros
((
BLOCK_SIZE_M
,
BLOCK_SIZE_N
),
dtype
=
tl
.
float32
)
# manually unroll to 4 iterations
for
k
in
range
(
0
,
tl
.
cdiv
(
K
,
BLOCK_SIZE_K
)
//
4
):
UNROLL_FACTOR
=
4
for
k
in
range
(
0
,
tl
.
cdiv
(
K
,
BLOCK_SIZE_K
*
UNROLL_FACTOR
)):
# 1st iteration
a
=
tl
.
load
(
a_ptrs
,
mask
=
offs_k
[
None
,
:]
<
K
-
k
*
BLOCK_SIZE_K
,
other
=
0.0
)
b
=
tl
.
load
(
b_ptrs
,
mask
=
offs_k
[:,
None
]
<
K
-
k
*
BLOCK_SIZE_K
,
other
=
0.0
)
a
=
tl
.
load
(
a_ptrs
,
mask
=
offs_k
[
None
,
:]
<
K
-
(
k
*
UNROLL_FACTOR
)
*
BLOCK_SIZE_K
,
other
=
0.0
,
)
b
=
tl
.
load
(
b_ptrs
,
mask
=
offs_k
[:,
None
]
<
K
-
(
k
*
UNROLL_FACTOR
)
*
BLOCK_SIZE_K
,
other
=
0.0
,
)
k_start
=
k
*
BLOCK_SIZE_K
k_start
=
(
k
*
UNROLL_FACTOR
)
*
BLOCK_SIZE_K
offs_ks
=
k_start
//
group_k
a_s
=
tl
.
load
(
As_ptrs
+
offs_ks
*
stride_As_k
)
b_s
=
tl
.
load
(
Bs_ptrs
+
offs_ks
*
stride_Bs_k
)
...
...
@@ -294,8 +303,16 @@ def _w8a8_block_fp8_matmul_unrolledx4(
b_ptrs
+=
BLOCK_SIZE_K
*
stride_bk
# 2nd iteration
a
=
tl
.
load
(
a_ptrs
,
mask
=
offs_k
[
None
,
:]
<
K
-
k
*
BLOCK_SIZE_K
,
other
=
0.0
)
b
=
tl
.
load
(
b_ptrs
,
mask
=
offs_k
[:,
None
]
<
K
-
k
*
BLOCK_SIZE_K
,
other
=
0.0
)
a
=
tl
.
load
(
a_ptrs
,
mask
=
offs_k
[
None
,
:]
<
K
-
(
k
*
UNROLL_FACTOR
+
1
)
*
BLOCK_SIZE_K
,
other
=
0.0
,
)
b
=
tl
.
load
(
b_ptrs
,
mask
=
offs_k
[:,
None
]
<
K
-
(
k
*
UNROLL_FACTOR
+
1
)
*
BLOCK_SIZE_K
,
other
=
0.0
,
)
k_start
=
k_start
+
BLOCK_SIZE_K
offs_ks
=
k_start
//
group_k
...
...
@@ -307,8 +324,16 @@ def _w8a8_block_fp8_matmul_unrolledx4(
b_ptrs
+=
BLOCK_SIZE_K
*
stride_bk
# 3rd iteration
a
=
tl
.
load
(
a_ptrs
,
mask
=
offs_k
[
None
,
:]
<
K
-
k
*
BLOCK_SIZE_K
,
other
=
0.0
)
b
=
tl
.
load
(
b_ptrs
,
mask
=
offs_k
[:,
None
]
<
K
-
k
*
BLOCK_SIZE_K
,
other
=
0.0
)
a
=
tl
.
load
(
a_ptrs
,
mask
=
offs_k
[
None
,
:]
<
K
-
(
k
*
UNROLL_FACTOR
+
2
)
*
BLOCK_SIZE_K
,
other
=
0.0
,
)
b
=
tl
.
load
(
b_ptrs
,
mask
=
offs_k
[:,
None
]
<
K
-
(
k
*
UNROLL_FACTOR
+
2
)
*
BLOCK_SIZE_K
,
other
=
0.0
,
)
k_start
=
k_start
+
BLOCK_SIZE_K
offs_ks
=
k_start
//
group_k
...
...
@@ -320,8 +345,16 @@ def _w8a8_block_fp8_matmul_unrolledx4(
b_ptrs
+=
BLOCK_SIZE_K
*
stride_bk
# 4th iteration
a
=
tl
.
load
(
a_ptrs
,
mask
=
offs_k
[
None
,
:]
<
K
-
k
*
BLOCK_SIZE_K
,
other
=
0.0
)
b
=
tl
.
load
(
b_ptrs
,
mask
=
offs_k
[:,
None
]
<
K
-
k
*
BLOCK_SIZE_K
,
other
=
0.0
)
a
=
tl
.
load
(
a_ptrs
,
mask
=
offs_k
[
None
,
:]
<
K
-
(
k
*
UNROLL_FACTOR
+
3
)
*
BLOCK_SIZE_K
,
other
=
0.0
,
)
b
=
tl
.
load
(
b_ptrs
,
mask
=
offs_k
[:,
None
]
<
K
-
(
k
*
UNROLL_FACTOR
+
3
)
*
BLOCK_SIZE_K
,
other
=
0.0
,
)
k_start
=
k_start
+
BLOCK_SIZE_K
offs_ks
=
k_start
//
group_k
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please register or sign in to comment.