Unverified commit 82653f66, authored May 06, 2025 by DefTruth, committed by GitHub May 05, 2025

feat: Add a unified merge_state API (#5428)

parent 22da3d97
Showing 2 changed files with 142 additions and 0 deletions:

python/sglang/srt/layers/attention/merge_state.py (+46, -0)
python/sglang/srt/layers/attention/triton_ops/merge_state.py (+96, -0)
python/sglang/srt/layers/attention/merge_state.py  0 → 100644
from typing import Optional, Tuple

import torch
from sgl_kernel import merge_state_v2

from sglang.srt.layers.attention.triton_ops.merge_state import merge_state_triton
from sglang.srt.utils import is_cuda

_is_cuda = is_cuda()


# Automatically fall back to the Triton kernel in some cases
# (e.g., for AMD GPUs, when the head dimension is not a multiple
# of 4 or 8, and in FP8 precision).
def _supported_dtypes(o: torch.Tensor) -> bool:
    return o.dtype in [torch.float32, torch.half, torch.bfloat16]


def _supported_headdim(o: torch.Tensor) -> bool:
    headdim = o.shape[2]  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
    if o.dtype == torch.float32:
        return headdim % 4 == 0
    return headdim % 8 == 0


def merge_state(
    prefix_output: torch.Tensor,
    prefix_lse: torch.Tensor,
    suffix_output: torch.Tensor,
    suffix_lse: torch.Tensor,
    output: Optional[torch.Tensor] = None,
    output_lse: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    if (
        _is_cuda
        and _supported_dtypes(prefix_output)
        and _supported_headdim(prefix_output)
    ):
        return merge_state_v2(
            prefix_output, prefix_lse, suffix_output, suffix_lse, output, output_lse
        )
    else:
        # Fallback to Triton kernel
        return merge_state_triton(
            prefix_output, prefix_lse, suffix_output, suffix_lse, output, output_lse
        )
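As an illustration of how the new unified API might be called, here is a minimal usage sketch (not part of this commit): the tensor shapes, dtypes, and CUDA device are assumptions chosen for the example; merge_state itself decides whether to dispatch to merge_state_v2 or to the Triton fallback.

# Illustrative usage sketch; shapes, dtypes, and device are assumptions, not part of this commit.
import torch

from sglang.srt.layers.attention.merge_state import merge_state

NUM_TOKENS, NUM_HEADS, HEAD_SIZE = 8, 32, 128
prefix_output = torch.randn(NUM_TOKENS, NUM_HEADS, HEAD_SIZE, dtype=torch.half, device="cuda")
suffix_output = torch.randn(NUM_TOKENS, NUM_HEADS, HEAD_SIZE, dtype=torch.half, device="cuda")
prefix_lse = torch.randn(NUM_TOKENS, NUM_HEADS, dtype=torch.float32, device="cuda")
suffix_lse = torch.randn(NUM_TOKENS, NUM_HEADS, dtype=torch.float32, device="cuda")

# Dispatches to merge_state_v2 when the CUDA, dtype, and head-dim checks pass;
# otherwise falls back to merge_state_triton.
output, output_lse = merge_state(prefix_output, prefix_lse, suffix_output, suffix_lse)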
python/sglang/srt/layers/attention/triton_ops/merge_state.py  0 → 100644
from typing import Optional, Tuple

import torch
import triton
import triton.language as tl


@triton.jit
def merge_state_kernel(
    output,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] v_merged
    output_lse,  # [NUM_TOKENS, NUM_HEADS] s_merged
    prefix_output,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] v_a
    prefix_lse,  # [NUM_TOKENS, NUM_HEADS] s_a
    suffix_output,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] v_b
    suffix_lse,  # [NUM_TOKENS, NUM_HEADS] s_b
    HEAD_SIZE: tl.constexpr,
    PADDED_HEAD_SIZE: tl.constexpr,
    OUTPUT_LSE: tl.constexpr,
):
    token_idx = tl.program_id(0)
    num_tokens = tl.num_programs(0)
    head_idx = tl.program_id(1)
    num_heads = tl.num_programs(1)

    p_lse = tl.load(prefix_lse + token_idx * num_heads + head_idx)
    s_lse = tl.load(suffix_lse + token_idx * num_heads + head_idx)
    p_lse = float("-inf") if p_lse == float("inf") else p_lse
    s_lse = float("-inf") if s_lse == float("inf") else s_lse

    max_lse = tl.maximum(p_lse, s_lse)
    p_lse = p_lse - max_lse
    s_lse = s_lse - max_lse
    out_se = tl.exp(p_lse) + tl.exp(s_lse)

    if OUTPUT_LSE:
        out_lse = tl.log(out_se) + max_lse
        tl.store(output_lse + token_idx * num_heads + head_idx, out_lse)

    head_arange = tl.arange(0, PADDED_HEAD_SIZE)
    head_mask = head_arange < HEAD_SIZE
    p_out = tl.load(
        prefix_output
        + token_idx * num_heads * HEAD_SIZE
        + head_idx * HEAD_SIZE
        + head_arange,
        mask=head_mask,
    )
    s_out = tl.load(
        suffix_output
        + token_idx * num_heads * HEAD_SIZE
        + head_idx * HEAD_SIZE
        + head_arange,
        mask=head_mask,
    )

    p_scale = tl.exp(p_lse) / out_se
    s_scale = tl.exp(s_lse) / out_se
    out = p_out * p_scale + s_out * s_scale
    tl.store(
        output
        + token_idx * num_heads * HEAD_SIZE
        + head_idx * HEAD_SIZE
        + head_arange,
        out,
        mask=head_mask,
    )


def merge_state_triton(
    prefix_output: torch.Tensor,
    prefix_lse: torch.Tensor,
    suffix_output: torch.Tensor,
    suffix_lse: torch.Tensor,
    output: Optional[torch.Tensor] = None,
    output_lse: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    # Avoid creating new tensors if they are already provided
    if output is None:
        output = torch.empty_like(prefix_output)
    if output_lse is None:
        output_lse = torch.empty_like(prefix_lse)

    num_tokens = output.shape[0]
    num_query_heads = output.shape[1]
    head_size = output.shape[2]
    padded_head_size = triton.next_power_of_2(head_size)

    merge_state_kernel[(num_tokens, num_query_heads)](
        output,
        output_lse,
        prefix_output,
        prefix_lse,
        suffix_output,
        suffix_lse,
        head_size,
        padded_head_size,
        output_lse is not None,
    )

    return output, output_lse
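For intuition (and as a possible unit-test oracle), a pure-PyTorch sketch of the same log-sum-exp merge follows. It mirrors the kernel's math under the [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] / [NUM_TOKENS, NUM_HEADS] layouts noted in the kernel's comments; the function name and the sketch itself are assumptions, not part of this commit.

# Reference-only sketch of the same merge math; not part of this commit.
import torch

def merge_state_reference(prefix_output, prefix_lse, suffix_output, suffix_lse):
    # Mirror the kernel: map +inf LSE entries to -inf before merging.
    p_lse = torch.where(prefix_lse == float("inf"), torch.full_like(prefix_lse, float("-inf")), prefix_lse)
    s_lse = torch.where(suffix_lse == float("inf"), torch.full_like(suffix_lse, float("-inf")), suffix_lse)
    max_lse = torch.maximum(p_lse, s_lse)
    p_se = torch.exp(p_lse - max_lse)
    s_se = torch.exp(s_lse - max_lse)
    out_se = p_se + s_se
    output_lse = torch.log(out_se) + max_lse
    # Per-(token, head) scales are broadcast over the head dimension.
    output = prefix_output * (p_se / out_se).unsqueeze(-1) + suffix_output * (s_se / out_se).unsqueeze(-1)
    return output.to(prefix_output.dtype), output_lse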