sglang commit 087ab832 (Unverified)
Authored Nov 10, 2024 by HAI; committed by GitHub on Nov 10, 2024
[Performance, Triton] Optimize over mask compute to tl.load in fused_moe_kernel (#1980)
parent 8169c6f4

Showing 2 changed files with 30 additions and 7 deletions
python/sglang/srt/layers/attention/triton_ops/decode_attention.py  +7 -0
python/sglang/srt/layers/fused_moe/fused_moe.py  +23 -7
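The gist of the change, as a minimal standalone sketch (the row_sum kernel and launch_row_sum helper below are illustrative, not part of this commit): when the host can prove that K is a multiple of BLOCK_SIZE_K, the kernel is compiled with a tl.constexpr flag that selects an unmasked tl.load, so the per-iteration boundary mask, and the comparison that builds it, disappears from the hot loop.

import torch
import triton
import triton.language as tl


@triton.jit
def row_sum(x_ptr, out_ptr, K, BLOCK_K: tl.constexpr, EVEN_K: tl.constexpr):
    # One program instance reduces one contiguous row of length K.
    row = tl.program_id(0)
    offs = tl.arange(0, BLOCK_K)
    acc = tl.zeros((BLOCK_K,), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_K)):
        idx = row * K + k * BLOCK_K + offs
        if EVEN_K:
            # K % BLOCK_K == 0: every block is full, so no bounds mask is needed.
            acc += tl.load(x_ptr + idx).to(tl.float32)
        else:
            # Generic path: mask off the tail of the last, possibly partial block.
            acc += tl.load(x_ptr + idx, mask=k * BLOCK_K + offs < K, other=0.0).to(tl.float32)
    tl.store(out_ptr + row, tl.sum(acc, axis=0))


def launch_row_sum(x: torch.Tensor, block_k: int = 256) -> torch.Tensor:
    # x: 2-D, contiguous CUDA tensor; rows are reduced independently.
    rows, K = x.shape
    out = torch.empty(rows, device=x.device, dtype=torch.float32)
    even_k = K % block_k == 0  # host-side check, analogous to even_ks in the diff below
    row_sum[(rows,)](x, out, K, BLOCK_K=block_k, EVEN_K=even_k)
    return out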
python/sglang/srt/layers/attention/triton_ops/decode_attention.py

@@ -507,6 +507,12 @@ def _decode_grouped_att_m_fwd(
     num_warps = 4

+    extra_kargs = {}
+    if is_hip():
+        # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
+        # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
+        extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2}
+
     _fwd_grouped_kernel_stage1[grid](
         q,
         k_buffer,
@@ -532,6 +538,7 @@ def _decode_grouped_att_m_fwd(
         num_warps=num_warps,
         num_stages=1,
         Lk=Lk,
+        **extra_kargs,
     )
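The decode-attention change is a host-side pattern: ROCm-specific Triton launch options are collected in a dict that stays empty on CUDA and is splatted into the kernel launch. A hedged sketch of that pattern (the _is_hip helper and the trivial _scale kernel are assumptions for illustration; the option values mirror the diff above):

import torch
import triton
import triton.language as tl


@triton.jit
def _scale(x_ptr, n, BLOCK: tl.constexpr):
    # Multiply every element of a flat buffer by 2, in blocks of BLOCK.
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask)
    tl.store(x_ptr + offs, x * 2, mask=mask)


def _is_hip() -> bool:
    # Assumption: ROCm builds of PyTorch report a non-None torch.version.hip.
    return torch.version.hip is not None


def scale_inplace(x: torch.Tensor, block: int = 1024) -> None:
    extra_kargs = {}
    if _is_hip():
        # AMD-only compiler knobs, same values as in the diff; never passed on CUDA.
        extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2}
    grid = (triton.cdiv(x.numel(), block),)
    _scale[grid](x, x.numel(), BLOCK=block, num_warps=4, num_stages=1, **extra_kargs)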
python/sglang/srt/layers/fused_moe/fused_moe.py

@@ -54,6 +54,7 @@ def fused_moe_kernel(
     top_k: tl.constexpr,
     compute_type: tl.constexpr,
     use_fp8: tl.constexpr,
+    even_Ks: tl.constexpr,
 ):
     """
     Implements the fused computation for a Mixture of Experts (MOE) using
@@ -130,16 +131,24 @@ def fused_moe_kernel(
     # of fp32 values for higher accuracy.
     # `accumulator` will be converted back to fp16 after the loop.
     accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
     for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
         # Load the next block of A and B, generate a mask by checking the
         # K dimension.
-        a = tl.load(
-            a_ptrs,
-            mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
-            other=0.0,
-        )
-        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
+        if even_Ks:
+            a = tl.load(
+                a_ptrs,
+                mask=token_mask[:, None],
+                other=0.0,
+            )
+            b = tl.load(b_ptrs)
+        else:
+            a = tl.load(
+                a_ptrs,
+                mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
+                other=0.0,
+            )
+            b = tl.load(
+                b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0
+            )

         # We accumulate along the K dimension.
         if use_fp8:
             accumulator = tl.dot(a, b, acc=accumulator)
@@ -253,6 +262,12 @@ def invoke_fused_moe_kernel(
         * triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]),
     )

+    K = B.shape[2] - padding_size
+    if K % config["BLOCK_SIZE_K"] == 0:
+        even_ks = True
+    else:
+        even_ks = False
+
     fused_moe_kernel[grid](
         A,
         B,
@@ -278,6 +293,7 @@ def invoke_fused_moe_kernel(
         top_k=top_k,
         compute_type=compute_type,
         use_fp8=use_fp8,
+        even_Ks=even_ks,
         **config,
     )
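To check whether dropping the mask actually pays off on a given GPU, one could time the two specializations of the row_sum sketch from the top of this page with triton.testing.do_bench (again a sketch under that assumption, not a benchmark shipped with this commit). Picking K as an exact multiple of the block size makes both variants compute the same result, so the timing difference isolates the mask overhead.

import torch
from triton.testing import do_bench


def bench_even_k(rows: int = 4096, K: int = 4096, block_k: int = 256) -> None:
    # row_sum is the illustrative kernel from the earlier sketch on this page.
    x = torch.randn(rows, K, device="cuda", dtype=torch.float32)
    out = torch.empty(rows, device="cuda", dtype=torch.float32)
    grid = (rows,)
    t_masked = do_bench(lambda: row_sum[grid](x, out, K, BLOCK_K=block_k, EVEN_K=False))
    t_even = do_bench(lambda: row_sum[grid](x, out, K, BLOCK_K=block_k, EVEN_K=True))
    print(f"masked load: {t_masked:.3f} ms, unmasked (even K): {t_even:.3f} ms")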