sglang · Commit dc491b39 (unverified)

add flash linear attention triton kernel (#10239)

Authored Sep 11, 2025 by Yi Zhang; committed by GitHub Sep 10, 2025. Parent: 5b64f006
Showing 14 changed files with 3590 additions and 0 deletions (+3590 -0)

python/sglang/srt/layers/attention/fla/chunk.py                            +242 -0
python/sglang/srt/layers/attention/fla/chunk_delta_h.py                    +314 -0
python/sglang/srt/layers/attention/fla/chunk_o.py                          +178 -0
python/sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py             +151 -0
python/sglang/srt/layers/attention/fla/cumsum.py                           +300 -0
python/sglang/srt/layers/attention/fla/fused_recurrent.py                  +640 -0
python/sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py   +232 -0
python/sglang/srt/layers/attention/fla/index.py                            +37 -0
python/sglang/srt/layers/attention/fla/l2norm.py                           +150 -0
python/sglang/srt/layers/attention/fla/layernorm_gated.py                  +326 -0
python/sglang/srt/layers/attention/fla/op.py                               +66 -0
python/sglang/srt/layers/attention/fla/solve_tril.py                       +465 -0
python/sglang/srt/layers/attention/fla/utils.py                            +331 -0
python/sglang/srt/layers/attention/fla/wy_fast.py                          +158 -0
python/sglang/srt/layers/attention/fla/chunk.py (new file, mode 100644)

# Adapted from https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/gated_delta_rule/chunk.py
# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

import warnings
from typing import Optional

import torch
from einops import rearrange

from sglang.srt.layers.attention.fla.chunk_delta_h import chunk_gated_delta_rule_fwd_h
from sglang.srt.layers.attention.fla.chunk_o import chunk_fwd_o
from sglang.srt.layers.attention.fla.chunk_scaled_dot_kkt import (
    chunk_scaled_dot_kkt_fwd,
)
from sglang.srt.layers.attention.fla.cumsum import chunk_local_cumsum
from sglang.srt.layers.attention.fla.l2norm import l2norm_fwd
from sglang.srt.layers.attention.fla.solve_tril import solve_tril
from sglang.srt.layers.attention.fla.utils import (
    SUPPRESS_LEVEL,
    autocast_custom_fwd,
    input_guard,
)
from sglang.srt.layers.attention.fla.wy_fast import recompute_w_u_fwd


def chunk_gated_delta_rule_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    beta: torch.Tensor,
    scale: float,
    initial_state: torch.Tensor,
    output_final_state: bool,
    cu_seqlens: Optional[torch.LongTensor] = None,
):
    g = chunk_local_cumsum(g, chunk_size=64, cu_seqlens=cu_seqlens)
    # obtain WY representation. u is actually the new v.
    A = chunk_scaled_dot_kkt_fwd(
        k=k, beta=beta, g_cumsum=g, cu_seqlens=cu_seqlens, output_dtype=torch.float32
    )
    A = solve_tril(A=A, cu_seqlens=cu_seqlens, output_dtype=k.dtype)
    w, u = recompute_w_u_fwd(
        k=k,
        v=v,
        beta=beta,
        A=A,
        g_cumsum=g,
        cu_seqlens=cu_seqlens,
    )
    h, v_new, final_state = chunk_gated_delta_rule_fwd_h(
        k=k,
        w=w,
        u=u,
        g=g,
        initial_state=initial_state,
        output_final_state=output_final_state,
        cu_seqlens=cu_seqlens,
    )
    o = chunk_fwd_o(
        q=q,
        k=k,
        v=v_new,
        h=h,
        g=g,
        scale=scale,
        cu_seqlens=cu_seqlens,
    )
    if SUPPRESS_LEVEL < 3:
        return g, o, A, final_state, None, None, None
    elif SUPPRESS_LEVEL >= 3:
        return g, o, A, final_state, w, h, v_new


class ChunkGatedDeltaRuleFunction(torch.autograd.Function):

    @staticmethod
    @input_guard
    @autocast_custom_fwd
    def forward(
        ctx,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        g: torch.Tensor,
        beta: torch.Tensor,
        scale: float,
        initial_state: torch.Tensor,
        output_final_state: bool,
        cu_seqlens: Optional[torch.LongTensor] = None,
        use_qk_l2norm_in_kernel: bool = False,
    ):
        q_orig = q
        k_orig = k

        if use_qk_l2norm_in_kernel:
            q = l2norm_fwd(q)
            k = l2norm_fwd(k)

        g, o, A, final_state, w, h, v_new = chunk_gated_delta_rule_fwd(
            q=q,
            k=k,
            v=v,
            g=g,
            beta=beta,
            scale=scale,
            initial_state=initial_state,
            output_final_state=output_final_state,
            cu_seqlens=cu_seqlens,
        )
        return o.to(q.dtype), final_state


@torch.compiler.disable
def chunk_gated_delta_rule(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    beta: torch.Tensor,
    scale: float = None,
    initial_state: torch.Tensor = None,
    output_final_state: bool = False,
    cu_seqlens: Optional[torch.LongTensor] = None,
    head_first: bool = False,
    use_qk_l2norm_in_kernel: bool = False,
):
    r"""
    Args:
        q (torch.Tensor):
            queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
        k (torch.Tensor):
            keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
        v (torch.Tensor):
            values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
        g (torch.Tensor):
            (forget) gating tensor (in log space!) of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`.
        beta (torch.Tensor):
            betas of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`.
        scale (Optional[int]):
            Scale factor for the RetNet attention scores.
            If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
        initial_state (Optional[torch.Tensor]):
            Initial state of shape `[N, H, K, V]` for `N` input sequences.
            For equal-length input sequences, `N` equals the batch size `B`.
            Default: `None`.
        output_final_state (Optional[bool]):
            Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
        cu_seqlens (torch.LongTensor):
            Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
            consistent with the FlashAttention API.
        head_first (Optional[bool]):
            Whether the inputs are in the head-first format, which is not supported for variable-length inputs.
            Default: `False`.

    Returns:
        o (torch.Tensor):
            Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
        final_state (torch.Tensor):
            Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`.

    Examples::
        >>> import torch
        >>> import torch.nn.functional as F
        >>> from einops import rearrange
        >>> from fla.ops.gated_delta_rule import chunk_gated_delta_rule
        # inputs with equal lengths
        >>> B, T, H, K, V = 4, 2048, 4, 512, 512
        >>> q = torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda')
        >>> k = F.normalize(torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda'), p=2, dim=-1)
        >>> v = torch.randn(B, T, H, V, dtype=torch.bfloat16, device='cuda')
        >>> beta = torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda').sigmoid()
        >>> g = F.logsigmoid(torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda'))
        >>> h0 = torch.randn(B, H, K, V, dtype=torch.bfloat16, device='cuda')
        >>> o, ht = chunk_gated_delta_rule(
            q, k, v, g, beta,
            initial_state=h0,
            output_final_state=True
        )
        # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
        >>> q, k, v, beta, g = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, beta, g))
        # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
        >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
        >>> o_var, ht_var = chunk_gated_delta_rule(
            q, k, v, g, beta,
            initial_state=h0,
            output_final_state=True,
            cu_seqlens=cu_seqlens
        )
    """
    assert q.dtype == k.dtype == v.dtype
    assert (
        q.dtype != torch.float32
    ), "ChunkGatedDeltaRuleFunction does not support float32. Please use bfloat16."
    assert (
        len(beta.shape) == 3
    ), "beta must be of shape [B, T, H] if head_first=False, or [B, H, T] otherwise."

    if head_first:
        raise DeprecationWarning(
            "head_first is deprecated and will be removed in a future version. "
            "Please use head_first=False for now instead."
        )
        q, k, v, beta, g = map(
            lambda x: rearrange(x, "b h t ... -> b t h ..."), (q, k, v, beta, g)
        )
    # if not head_first and q.shape[1] < q.shape[2]:
    #     warnings.warn(
    #         f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
    #         "This may indicate the inputs were passed in head-first format [B, H, T, ...] "
    #         "when head_first=False was specified. "
    #         "Please verify your input tensor format matches the expected shape [B, T, H, ...]."
    #     )
    if cu_seqlens is not None:
        if q.shape[0] != 1:
            raise ValueError(
                f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
                f"Please flatten variable-length inputs before processing."
            )
        if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
            raise ValueError(
                f"The number of initial states is expected to be equal to the number of input sequences, "
                f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
            )
    if scale is None:
        scale = k.shape[-1] ** -0.5
    o, final_state = ChunkGatedDeltaRuleFunction.apply(
        q,
        k,
        v,
        g,
        beta,
        scale,
        initial_state,
        output_final_state,
        cu_seqlens,
        use_qk_l2norm_in_kernel,
    )
    if head_first:
        o = rearrange(o, "b t h ... -> b h t ...")
    return o, final_state
python/sglang/srt/layers/attention/fla/chunk_delta_h.py (new file, mode 100644)

# Adapted from https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/common/chunk_delta_h.py
# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

from typing import Optional, Tuple

import torch
import triton
import triton.language as tl

from sglang.srt.layers.attention.fla.index import (
    prepare_chunk_indices,
    prepare_chunk_offsets,
)
from sglang.srt.layers.attention.fla.op import exp, safe_exp
from sglang.srt.layers.attention.fla.utils import is_nvidia_hopper

NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8, 16]


@triton.heuristics(
    {
        "USE_G": lambda args: args["g"] is not None,
        "USE_INITIAL_STATE": lambda args: args["h0"] is not None,
        "STORE_FINAL_STATE": lambda args: args["ht"] is not None,
        "SAVE_NEW_VALUE": lambda args: args["v_new"] is not None,
        "IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
    }
)
# @triton.autotune(
#     configs=[
#         triton.Config({"BV": BV}, num_warps=num_warps, num_stages=num_stages)
#         for num_warps in [2, 4]
#         for num_stages in [2, 3, 4]
#         for BV in [32, 64]
#     ],
#     key=["H", "K", "V", "BT", "USE_G"],
#     use_cuda_graph=use_cuda_graph,
# )
@triton.jit(do_not_specialize=["T"])
def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
    k,
    v,
    w,
    v_new,
    g,
    h,
    h0,
    ht,
    cu_seqlens,
    chunk_offsets,
    T,
    H: tl.constexpr,
    Hg: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BV: tl.constexpr,
    USE_G: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,
    STORE_FINAL_STATE: tl.constexpr,
    SAVE_NEW_VALUE: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    i_v, i_nh = tl.program_id(0), tl.program_id(1)
    i_n, i_h = i_nh // H, i_nh % H
    if IS_VARLEN:
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(
            cu_seqlens + i_n + 1
        ).to(tl.int32)
        T = eos - bos
        NT = tl.cdiv(T, BT)
        boh = tl.load(chunk_offsets + i_n).to(tl.int32)
    else:
        bos, eos = i_n * T, i_n * T + T
        NT = tl.cdiv(T, BT)
        boh = i_n * NT

    # [BK, BV]
    b_h1 = tl.zeros([64, BV], dtype=tl.float32)
    if K > 64:
        b_h2 = tl.zeros([64, BV], dtype=tl.float32)
    if K > 128:
        b_h3 = tl.zeros([64, BV], dtype=tl.float32)
    if K > 192:
        b_h4 = tl.zeros([64, BV], dtype=tl.float32)

    # calculate offset
    h += (boh * H + i_h) * K * V
    v += (bos * H + i_h) * V
    k += (bos * Hg + i_h // (H // Hg)) * K
    w += (bos * H + i_h) * K
    if SAVE_NEW_VALUE:
        v_new += (bos * H + i_h) * V
    stride_v = H * V
    stride_h = H * K * V
    stride_k = Hg * K
    stride_w = H * K
    if USE_INITIAL_STATE:
        h0 = h0 + i_nh * K * V
    if STORE_FINAL_STATE:
        ht = ht + i_nh * K * V

    # load initial state
    if USE_INITIAL_STATE:
        p_h0_1 = tl.make_block_ptr(h0, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0))
        b_h1 += tl.load(p_h0_1, boundary_check=(0, 1)).to(tl.float32)
        if K > 64:
            p_h0_2 = tl.make_block_ptr(
                h0, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0)
            )
            b_h2 += tl.load(p_h0_2, boundary_check=(0, 1)).to(tl.float32)
        if K > 128:
            p_h0_3 = tl.make_block_ptr(
                h0, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0)
            )
            b_h3 += tl.load(p_h0_3, boundary_check=(0, 1)).to(tl.float32)
        if K > 192:
            p_h0_4 = tl.make_block_ptr(
                h0, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0)
            )
            b_h4 += tl.load(p_h0_4, boundary_check=(0, 1)).to(tl.float32)

    # main recurrence
    for i_t in range(NT):
        p_h1 = tl.make_block_ptr(
            h + i_t * stride_h, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0)
        )
        tl.store(p_h1, b_h1.to(p_h1.dtype.element_ty), boundary_check=(0, 1))
        if K > 64:
            p_h2 = tl.make_block_ptr(
                h + i_t * stride_h, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0)
            )
            tl.store(p_h2, b_h2.to(p_h2.dtype.element_ty), boundary_check=(0, 1))
        if K > 128:
            p_h3 = tl.make_block_ptr(
                h + i_t * stride_h, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0)
            )
            tl.store(p_h3, b_h3.to(p_h3.dtype.element_ty), boundary_check=(0, 1))
        if K > 192:
            p_h4 = tl.make_block_ptr(
                h + i_t * stride_h, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0)
            )
            tl.store(p_h4, b_h4.to(p_h4.dtype.element_ty), boundary_check=(0, 1))

        p_v = tl.make_block_ptr(
            v, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
        )
        p_v_new = (
            tl.make_block_ptr(
                v_new, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
            )
            if SAVE_NEW_VALUE
            else None
        )
        b_v_new = tl.zeros([BT, BV], dtype=tl.float32)
        p_w = tl.make_block_ptr(
            w, (T, K), (stride_w, 1), (i_t * BT, 0), (BT, 64), (1, 0)
        )
        b_w = tl.load(p_w, boundary_check=(0, 1))
        b_v_new += tl.dot(b_w, b_h1.to(b_w.dtype))
        if K > 64:
            p_w = tl.make_block_ptr(
                w, (T, K), (stride_w, 1), (i_t * BT, 64), (BT, 64), (1, 0)
            )
            b_w = tl.load(p_w, boundary_check=(0, 1))
            b_v_new += tl.dot(b_w, b_h2.to(b_w.dtype))
        if K > 128:
            p_w = tl.make_block_ptr(
                w, (T, K), (stride_w, 1), (i_t * BT, 128), (BT, 64), (1, 0)
            )
            b_w = tl.load(p_w, boundary_check=(0, 1))
            b_v_new += tl.dot(b_w, b_h3.to(b_w.dtype))
        if K > 192:
            p_w = tl.make_block_ptr(
                w, (T, K), (stride_w, 1), (i_t * BT, 192), (BT, 64), (1, 0)
            )
            b_w = tl.load(p_w, boundary_check=(0, 1))
            b_v_new += tl.dot(b_w, b_h4.to(b_w.dtype))
        b_v_new = -b_v_new + tl.load(p_v, boundary_check=(0, 1))

        if SAVE_NEW_VALUE:
            p_v_new = tl.make_block_ptr(
                v_new, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
            )
            tl.store(
                p_v_new, b_v_new.to(p_v_new.dtype.element_ty), boundary_check=(0, 1)
            )

        if USE_G:
            last_idx = min((i_t + 1) * BT, T) - 1
            b_g_last = tl.load(g + bos * H + last_idx * H + i_h)
            p_g = tl.make_block_ptr(
                g + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)
            )
            b_g = tl.load(p_g, boundary_check=(0,))
            b_v_new = b_v_new * safe_exp(b_g_last - b_g)[:, None]
            b_g_last = exp(b_g_last)
            b_h1 = b_h1 * b_g_last
            if K > 64:
                b_h2 = b_h2 * b_g_last
            if K > 128:
                b_h3 = b_h3 * b_g_last
            if K > 192:
                b_h4 = b_h4 * b_g_last

        b_v_new = b_v_new.to(k.dtype.element_ty)
        p_k = tl.make_block_ptr(
            k, (K, T), (1, stride_k), (0, i_t * BT), (64, BT), (0, 1)
        )
        b_k = tl.load(p_k, boundary_check=(0, 1))
        b_h1 += tl.dot(b_k, b_v_new)
        if K > 64:
            p_k = tl.make_block_ptr(
                k, (K, T), (1, stride_k), (64, i_t * BT), (64, BT), (0, 1)
            )
            b_k = tl.load(p_k, boundary_check=(0, 1))
            b_h2 += tl.dot(b_k, b_v_new)
        if K > 128:
            p_k = tl.make_block_ptr(
                k, (K, T), (1, stride_k), (128, i_t * BT), (64, BT), (0, 1)
            )
            b_k = tl.load(p_k, boundary_check=(0, 1))
            b_h3 += tl.dot(b_k, b_v_new)
        if K > 192:
            p_k = tl.make_block_ptr(
                k, (K, T), (1, stride_k), (192, i_t * BT), (64, BT), (0, 1)
            )
            b_k = tl.load(p_k, boundary_check=(0, 1))
            b_h4 += tl.dot(b_k, b_v_new)

    # epilogue
    if STORE_FINAL_STATE:
        p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0))
        tl.store(p_ht, b_h1.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
        if K > 64:
            p_ht = tl.make_block_ptr(
                ht, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0)
            )
            tl.store(p_ht, b_h2.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
        if K > 128:
            p_ht = tl.make_block_ptr(
                ht, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0)
            )
            tl.store(p_ht, b_h3.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
        if K > 192:
            p_ht = tl.make_block_ptr(
                ht, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0)
            )
            tl.store(p_ht, b_h4.to(p_ht.dtype.element_ty), boundary_check=(0, 1))


def chunk_gated_delta_rule_fwd_h(
    k: torch.Tensor,
    w: torch.Tensor,
    u: torch.Tensor,
    g: Optional[torch.Tensor] = None,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: bool = False,
    chunk_size: int = 64,  # SY: remove this argument and force chunk size 64?
    save_new_value: bool = True,
    cu_seqlens: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    B, T, Hg, K, V = *k.shape, u.shape[-1]
    H = u.shape[-2]
    BT = chunk_size
    chunk_indices = (
        prepare_chunk_indices(cu_seqlens, chunk_size)
        if cu_seqlens is not None
        else None
    )
    # N: the actual number of sequences in the batch with either equal or variable lengths
    if cu_seqlens is None:
        N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
    else:
        N, NT, chunk_offsets = (
            len(cu_seqlens) - 1,
            len(chunk_indices),
            prepare_chunk_offsets(cu_seqlens, BT),
        )
    assert K <= 256, "current kernel does not support head dimension larger than 256."

    h = k.new_empty(B, NT, H, K, V)
    final_state = (
        k.new_empty(N, H, K, V, dtype=torch.float32) if output_final_state else None
    )
    v_new = torch.empty_like(u) if save_new_value else None

    def grid(meta):
        return (triton.cdiv(V, meta["BV"]), N * H)

    chunk_gated_delta_rule_fwd_kernel_h_blockdim64[grid](
        k=k,
        v=u,
        w=w,
        v_new=v_new,
        g=g,
        h=h,
        h0=initial_state,
        ht=final_state,
        cu_seqlens=cu_seqlens,
        chunk_offsets=chunk_offsets,
        T=T,
        H=H,
        Hg=Hg,
        K=K,
        V=V,
        BT=BT,
        BV=32,
        num_warps=4,
        num_stages=2,
    )
    return h, v_new, final_state
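
In plain PyTorch terms, the chunk-level state recurrence implemented by chunk_gated_delta_rule_fwd_h can be sketched as below for a single head and a single sequence whose length is a multiple of the chunk size. The helper name and the simplifications (no K/V tiling, no variable-length handling) are illustrative, not part of the kernel above.

# Illustrative reference for chunk_gated_delta_rule_fwd_h (single head, single
# sequence, T divisible by BT); the real kernel splits K into 64-wide blocks.
import torch


def chunk_gated_delta_rule_fwd_h_ref(k, w, u, g, h0=None, BT=64):
    # k, w: [T, K], u: [T, V], g: [T] (chunk-local cumulative log decay)
    T, K = k.shape
    V = u.shape[-1]
    NT = T // BT
    h = k.new_zeros(NT, K, V, dtype=torch.float32)
    b_h = h0.float() if h0 is not None else k.new_zeros(K, V, dtype=torch.float32)
    v_new = torch.empty_like(u)
    for t in range(NT):
        s = slice(t * BT, (t + 1) * BT)
        h[t] = b_h                                     # state entering this chunk
        b_v = u[s].float() - w[s].float() @ b_h        # v_new = u - w @ h
        v_new[s] = b_v.to(u.dtype)
        g_last = g[s][-1]
        b_v = b_v * torch.exp(g_last - g[s])[:, None]  # decay within the chunk
        b_h = b_h * torch.exp(g_last) + k[s].float().t() @ b_v  # state update
    return h, v_new, b_h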
python/sglang/srt/layers/attention/fla/chunk_o.py (new file, mode 100644)

# Adapted from https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/common/chunk_o.py
# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

from typing import Optional, Tuple

import torch
import triton
import triton.language as tl

from sglang.srt.layers.attention.fla.index import prepare_chunk_indices
from sglang.srt.layers.attention.fla.op import exp, safe_exp
from sglang.srt.layers.attention.fla.utils import check_shared_mem, is_nvidia_hopper

BKV_LIST = [64, 128] if check_shared_mem() else [32, 64]
NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8]


@triton.heuristics(
    {
        "USE_G": lambda args: args["g"] is not None,
        "IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
    }
)
# @triton.autotune(
#     configs=[
#         triton.Config({"BK": BK, "BV": BV}, num_warps=num_warps, num_stages=num_stages)
#         for BK in BKV_LIST
#         for BV in BKV_LIST
#         for num_warps in NUM_WARPS
#         for num_stages in [2, 3, 4]
#     ],
#     key=["H", "K", "V", "BT"],
# )
@triton.jit(do_not_specialize=["T"])
def chunk_fwd_kernel_o(
    q,
    k,
    v,
    h,
    g,
    o,
    cu_seqlens,
    chunk_indices,
    scale,
    T,
    H: tl.constexpr,
    Hg: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_G: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        i_tg = i_t
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(
            chunk_indices + i_t * 2 + 1
        ).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(
            cu_seqlens + i_n + 1
        ).to(tl.int32)
        T = eos - bos
        NT = tl.cdiv(T, BT)
    else:
        NT = tl.cdiv(T, BT)
        i_tg = i_b * NT + i_t
        bos, eos = i_b * T, i_b * T + T

    # offset calculation
    q += (bos * Hg + i_h // (H // Hg)) * K
    k += (bos * Hg + i_h // (H // Hg)) * K
    v += (bos * H + i_h) * V
    o += (bos * H + i_h) * V
    h += (i_tg * H + i_h).to(tl.int64) * K * V

    b_o = tl.zeros([BT, BV], dtype=tl.float32)
    b_A = tl.zeros([BT, BT], dtype=tl.float32)
    for i_k in range(tl.cdiv(K, BK)):
        p_q = tl.make_block_ptr(
            q, (T, K), (Hg * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)
        )
        p_k = tl.make_block_ptr(
            k, (K, T), (1, Hg * K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)
        )
        p_h = tl.make_block_ptr(
            h, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)
        )
        # [BT, BK]
        b_q = tl.load(p_q, boundary_check=(0, 1))
        # [BK, BT]
        b_k = tl.load(p_k, boundary_check=(0, 1))
        # [BK, BV]
        b_h = tl.load(p_h, boundary_check=(0, 1))

        # [BT, BK] @ [BK, BV] -> [BT, BV]
        b_o += tl.dot(b_q, b_h)
        # [BT, BK] @ [BK, BT] -> [BT, BT]
        b_A += tl.dot(b_q, b_k)

    if USE_G:
        g += bos * H + i_h
        p_g = tl.make_block_ptr(g, (T,), (H,), (i_t * BT,), (BT,), (0,))
        b_g = tl.load(p_g, boundary_check=(0,))
        b_o = b_o * exp(b_g)[:, None]
        b_A = b_A * safe_exp(b_g[:, None] - b_g[None, :])

    o_i = tl.arange(0, BT)
    m_A = o_i[:, None] >= o_i[None, :]
    b_A = tl.where(m_A, b_A, 0)

    p_v = tl.make_block_ptr(
        v, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
    )
    p_o = tl.make_block_ptr(
        o, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
    )
    b_v = tl.load(p_v, boundary_check=(0, 1))

    # to fix mma -> mma layout conversion
    # already solved by triton v3.2 or higher
    b_o = b_o * scale + tl.dot(b_A.to(b_v.dtype), b_v) * scale
    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))


def chunk_fwd_o(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    h: torch.Tensor,
    g: Optional[torch.Tensor] = None,  # cumsum of log decay
    scale: Optional[float] = None,
    cu_seqlens: Optional[torch.LongTensor] = None,
    chunk_size: int = 64,
) -> torch.Tensor:
    B, T, Hg, K, V = *q.shape, v.shape[-1]
    H = v.shape[-2]
    BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
    chunk_indices = (
        prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
    )
    NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
    if scale is None:
        scale = k.shape[-1] ** -0.5

    o = torch.empty_like(v)

    def grid(meta):
        return (triton.cdiv(V, meta["BV"]), NT, B * H)

    chunk_fwd_kernel_o[grid](
        q,
        k,
        v,
        h,
        g,
        o,
        cu_seqlens,
        chunk_indices,
        scale,
        T=T,
        H=H,
        Hg=Hg,
        K=K,
        V=V,
        BT=BT,
        BK=128,
        BV=64,
        num_warps=4,
        num_stages=2,
    )
    return o
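
For one chunk of one head, chunk_fwd_o combines an inter-chunk term read from the state h with a causal intra-chunk attention term. A rough eager-mode sketch (names and the absence of K/V tiling are illustrative only):

# Illustrative per-chunk reference for chunk_fwd_o (one head, one chunk).
import torch


def chunk_fwd_o_ref(q, k, v_new, h, g, scale):
    # q, k: [BT, K], v_new: [BT, V], h: [K, V] (state at chunk start), g: [BT]
    o_inter = (q.float() @ h.float()) * torch.exp(g)[:, None]       # past chunks via h
    A = (q.float() @ k.float().t()) * torch.exp(g[:, None] - g[None, :])
    A = A.tril()                                                    # causal intra-chunk part
    return (o_inter + A @ v_new.float()) * scale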
python/sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py (new file, mode 100644)

# Adapted from https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/common/chunk_scaled_dot_kkt.py
# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

from typing import Optional

import torch
import triton
import triton.language as tl

from sglang.srt.layers.attention.fla.index import prepare_chunk_indices
from sglang.srt.layers.attention.fla.op import safe_exp


@triton.heuristics(
    {
        "IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
        "USE_G": lambda args: args["g_cumsum"] is not None,
    }
)
# @triton.autotune(
#     configs=[
#         triton.Config({"BK": BK}, num_warps=num_warps, num_stages=num_stages)
#         for BK in [32, 64, 128]
#         for num_warps in [2, 4, 8]
#         for num_stages in [2, 3, 4]
#     ],
#     key=["H", "K", "BT", "IS_VARLEN"],
# )
@triton.jit(do_not_specialize=["T"])
def chunk_scaled_dot_kkt_fwd_kernel(
    k,
    beta,
    g_cumsum,
    A,
    cu_seqlens,
    chunk_indices,
    T,
    H: tl.constexpr,
    Hg: tl.constexpr,
    K: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    USE_G: tl.constexpr,
):
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(
            chunk_indices + i_t * 2 + 1
        ).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(
            cu_seqlens + i_n + 1
        ).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T
    o_t = tl.arange(0, BT)

    p_beta = tl.make_block_ptr(
        beta + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)
    )
    b_beta = tl.load(p_beta, boundary_check=(0,))

    b_A = tl.zeros([BT, BT], dtype=tl.float32)
    for i_k in range(tl.cdiv(K, BK)):
        p_k = tl.make_block_ptr(
            k + (bos * Hg + i_h // (H // Hg)) * K,
            (T, K),
            (Hg * K, 1),
            (i_t * BT, i_k * BK),
            (BT, BK),
            (1, 0),
        )
        b_k = tl.load(p_k, boundary_check=(0, 1))
        b_kb = b_k * b_beta[:, None]
        b_A += tl.dot(b_kb.to(b_k.dtype), tl.trans(b_k))

    if USE_G:
        p_g = tl.make_block_ptr(
            g_cumsum + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)
        )
        b_g = tl.load(p_g, boundary_check=(0,))
        b_g_diff = b_g[:, None] - b_g[None, :]
        b_A = b_A * safe_exp(b_g_diff)

    b_A = tl.where(o_t[:, None] > o_t[None, :], b_A, 0)
    p_A = tl.make_block_ptr(
        A + (bos * H + i_h) * BT, (T, BT), (BT * H, 1), (i_t * BT, 0), (BT, BT), (1, 0)
    )
    tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1))


def chunk_scaled_dot_kkt_fwd(
    k: torch.Tensor,
    beta: torch.Tensor,
    g_cumsum: Optional[torch.Tensor] = None,
    cu_seqlens: Optional[torch.LongTensor] = None,
    chunk_size: int = 64,
    output_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    r"""
    Compute beta * K * K^T.

    Args:
        k (torch.Tensor):
            The key tensor of shape `[B, T, H, K]`.
        beta (torch.Tensor):
            The beta tensor of shape `[B, T, H]`.
        g_cumsum (torch.Tensor):
            The cumulative sum of the gate tensor of shape `[B, T, H]`.
            Default: None
        cu_seqlens (torch.LongTensor):
            The cumulative sequence lengths of the input tensor.
            Default: None
        chunk_size (int):
            The chunk size. Default: 64.
        output_dtype (torch.dtype):
            The dtype of the output tensor. Default: `torch.float32`

    Returns:
        beta * K * K^T of shape `[B, T, H, BT]` where `BT` is the chunk size.
    """
    B, T, Hg, K = k.shape
    H = beta.shape[-1]
    BT = chunk_size
    chunk_indices = (
        prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
    )
    NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
    A = torch.empty(B, T, H, BT, device=k.device, dtype=output_dtype)
    chunk_scaled_dot_kkt_fwd_kernel[(NT, B * H)](
        k=k,
        beta=beta,
        g_cumsum=g_cumsum,
        A=A,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        H=H,
        Hg=Hg,
        K=K,
        BT=BT,
        BK=64,
        num_warps=8,
        num_stages=3,
    )
    return A
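
For a single chunk of a single head, the matrix stored by chunk_scaled_dot_kkt_fwd corresponds to the following eager-mode sketch (illustrative only; the kernel tiles over K and supports variable-length batches):

# Illustrative per-chunk reference for chunk_scaled_dot_kkt_fwd (one head, one chunk).
import torch


def chunk_scaled_dot_kkt_ref(k, beta, g_cumsum):
    # k: [BT, K], beta: [BT], g_cumsum: [BT]
    A = (beta[:, None] * k.float()) @ k.float().t()
    A = A * torch.exp(g_cumsum[:, None] - g_cumsum[None, :])
    return A.tril(-1)  # strictly lower triangular, matching the o_t[:, None] > o_t[None, :] mask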
python/sglang/srt/layers/attention/fla/cumsum.py (new file, mode 100644)

# Adapt from https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/utils/cumsum.py
# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

from typing import Optional

import torch
import triton
import triton.language as tl

from sglang.srt.layers.attention.fla.index import prepare_chunk_indices
from sglang.srt.layers.attention.fla.utils import check_shared_mem, input_guard

BS_LIST = [32, 64] if check_shared_mem() else [16, 32]


@triton.heuristics(
    {
        "HAS_SCALE": lambda args: args["scale"] is not None,
        "IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
    }
)
# @triton.autotune(
#     configs=[triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8]],
#     key=["B", "H", "BT", "IS_VARLEN", "REVERSE"],
# )
@triton.jit(do_not_specialize=["T"])
def chunk_local_cumsum_scalar_kernel(
    s,
    o,
    scale,
    cu_seqlens,
    chunk_indices,
    T,
    B: tl.constexpr,
    H: tl.constexpr,
    BT: tl.constexpr,
    REVERSE: tl.constexpr,
    HAS_SCALE: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    HEAD_FIRST: tl.constexpr,
):
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(
            chunk_indices + i_t * 2 + 1
        ).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(
            cu_seqlens + i_n + 1
        ).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T
    if HEAD_FIRST:
        p_s = tl.make_block_ptr(
            s + bos * H + i_h * T, (T,), (1,), (i_t * BT,), (BT,), (0,)
        )
        p_o = tl.make_block_ptr(
            o + bos * H + i_h * T, (T,), (1,), (i_t * BT,), (BT,), (0,)
        )
    else:
        p_s = tl.make_block_ptr(s + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
        p_o = tl.make_block_ptr(o + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
    # [BT]
    b_s = tl.load(p_s, boundary_check=(0,)).to(tl.float32)
    b_o = tl.cumsum(b_s, axis=0)
    if REVERSE:
        b_z = tl.sum(b_s, axis=0)
        b_o = -b_o + b_z[None] + b_s
    if HAS_SCALE:
        b_o *= scale
    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0,))


@triton.heuristics(
    {
        "HAS_SCALE": lambda args: args["scale"] is not None,
        "IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
    }
)
@triton.autotune(
    configs=[
        triton.Config({"BS": BS}, num_warps=num_warps)
        for BS in BS_LIST
        for num_warps in [2, 4, 8]
    ],
    key=["B", "H", "S", "BT", "IS_VARLEN", "REVERSE"],
)
@triton.jit(do_not_specialize=["T"])
def chunk_local_cumsum_vector_kernel(
    s,
    o,
    scale,
    cu_seqlens,
    chunk_indices,
    T,
    B: tl.constexpr,
    H: tl.constexpr,
    S: tl.constexpr,
    BT: tl.constexpr,
    BS: tl.constexpr,
    REVERSE: tl.constexpr,
    HAS_SCALE: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    HEAD_FIRST: tl.constexpr,
):
    i_s, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(
            chunk_indices + i_t * 2 + 1
        ).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(
            cu_seqlens + i_n + 1
        ).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    o_i = tl.arange(0, BT)
    if REVERSE:
        m_s = tl.where(o_i[:, None] <= o_i[None, :], 1.0, 0.0)
    else:
        m_s = tl.where(o_i[:, None] >= o_i[None, :], 1.0, 0.0)

    if HEAD_FIRST:
        p_s = tl.make_block_ptr(
            s + (bos * H + i_h * T) * S,
            (T, S),
            (S, 1),
            (i_t * BT, i_s * BS),
            (BT, BS),
            (1, 0),
        )
        p_o = tl.make_block_ptr(
            o + (bos * H + i_h * T) * S,
            (T, S),
            (S, 1),
            (i_t * BT, i_s * BS),
            (BT, BS),
            (1, 0),
        )
    else:
        p_s = tl.make_block_ptr(
            s + (bos * H + i_h) * S,
            (T, S),
            (H * S, 1),
            (i_t * BT, i_s * BS),
            (BT, BS),
            (1, 0),
        )
        p_o = tl.make_block_ptr(
            o + (bos * H + i_h) * S,
            (T, S),
            (H * S, 1),
            (i_t * BT, i_s * BS),
            (BT, BS),
            (1, 0),
        )
    # [BT, BS]
    b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32)
    b_o = tl.dot(m_s, b_s, allow_tf32=False)
    if HAS_SCALE:
        b_o *= scale
    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))


def chunk_local_cumsum_scalar(
    g: torch.Tensor,
    chunk_size: int,
    reverse: bool = False,
    scale: float = None,
    cu_seqlens: Optional[torch.Tensor] = None,
    head_first: bool = False,
    output_dtype: Optional[torch.dtype] = torch.float,
) -> torch.Tensor:
    if head_first:
        B, H, T = g.shape
    else:
        B, T, H = g.shape
    assert chunk_size == 2 ** (
        chunk_size.bit_length() - 1
    ), "chunk_size must be a power of 2"
    BT = chunk_size
    chunk_indices = (
        prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
    )
    NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
    g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype)
    grid = (NT, B * H)
    chunk_local_cumsum_scalar_kernel[grid](
        s=g_org,
        o=g,
        scale=scale,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        B=B,
        H=H,
        BT=BT,
        HEAD_FIRST=head_first,
        REVERSE=reverse,
        num_warps=8,
        num_stages=3,
    )
    return g


def chunk_local_cumsum_vector(
    g: torch.Tensor,
    chunk_size: int,
    reverse: bool = False,
    scale: float = None,
    cu_seqlens: Optional[torch.Tensor] = None,
    head_first: bool = False,
    output_dtype: Optional[torch.dtype] = torch.float,
) -> torch.Tensor:
    if head_first:
        B, H, T, S = g.shape
    else:
        B, T, H, S = g.shape
    BT = chunk_size
    chunk_indices = (
        prepare_chunk_indices(cu_seqlens, chunk_size)
        if cu_seqlens is not None
        else None
    )
    NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
    assert chunk_size == 2 ** (
        chunk_size.bit_length() - 1
    ), "chunk_size must be a power of 2"

    g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype)

    def grid(meta):
        return (triton.cdiv(meta["S"], meta["BS"]), NT, B * H)

    # keep cumulative normalizer in fp32
    # this kernel is equivalent to
    # g = g.view(B, H, NT, BT, -1).cumsum(-2).view(B, H, T, -1)
    chunk_local_cumsum_vector_kernel[grid](
        s=g_org,
        o=g,
        scale=scale,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        B=B,
        H=H,
        S=S,
        BT=BT,
        HEAD_FIRST=head_first,
        REVERSE=reverse,
    )
    return g


@input_guard
def chunk_local_cumsum(
    g: torch.Tensor,
    chunk_size: int,
    reverse: bool = False,
    scale: float = None,
    cu_seqlens: Optional[torch.Tensor] = None,
    head_first: bool = False,
    output_dtype: Optional[torch.dtype] = torch.float,
    **kwargs,
) -> torch.Tensor:
    if cu_seqlens is not None:
        assert (
            g.shape[0] == 1
        ), "Only batch size 1 is supported when cu_seqlens are provided"
    if len(g.shape) == 3:
        return chunk_local_cumsum_scalar(
            g=g,
            chunk_size=chunk_size,
            reverse=reverse,
            scale=scale,
            cu_seqlens=cu_seqlens,
            head_first=head_first,
            output_dtype=output_dtype,
        )
    elif len(g.shape) == 4:
        return chunk_local_cumsum_vector(
            g=g,
            chunk_size=chunk_size,
            reverse=reverse,
            scale=scale,
            cu_seqlens=cu_seqlens,
            head_first=head_first,
            output_dtype=output_dtype,
        )
    else:
        raise ValueError(
            f"Unsupported input shape {g.shape}, "
            f"which should be (B, T, H, D) if `head_first=False` "
            f"or (B, H, T, D) otherwise"
        )
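
As the comment above the vector kernel launch notes, the local cumsum is simply a per-chunk cumulative sum along the time axis. A minimal equal-length sanity check, assuming a CUDA device and a sequence length divisible by the chunk size:

# Sanity-check sketch: per-chunk cumulative sum for equal-length inputs.
import torch

from sglang.srt.layers.attention.fla.cumsum import chunk_local_cumsum

B, T, H, chunk_size = 2, 128, 4, 64
g = torch.randn(B, T, H, device="cuda", dtype=torch.float32)
ref = g.view(B, T // chunk_size, chunk_size, H).cumsum(2).view(B, T, H)
out = chunk_local_cumsum(g, chunk_size=chunk_size)
torch.testing.assert_close(out, ref, rtol=1e-4, atol=1e-4)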
python/sglang/srt/layers/attention/fla/fused_recurrent.py (new file, mode 100644; diff collapsed in this view, contents not shown)

python/sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py (new file, mode 100644)

from typing import Optional

import torch
import triton
import triton.language as tl

from sglang.srt.layers.attention.fla.utils import input_guard


@triton.heuristics(
    {
        "USE_INITIAL_STATE": lambda args: args["h0_source"] is not None,
        "IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
    }
)
@triton.jit(do_not_specialize=["T"])
def fused_sigmoid_gating_delta_rule_update_kernel(
    A_log,
    a,
    dt_bias,
    softplus_beta,
    softplus_threshold,
    q,
    k,
    v,
    b,
    o,
    h0_source,
    h0_indices,
    cu_seqlens,
    scale,
    T,
    B: tl.constexpr,
    H: tl.constexpr,
    HV: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,
    USE_QK_L2NORM_IN_KERNEL: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    """
    Fused kernel that combines sigmoid gating computation with recurrent delta rule update.
    """
    i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_n, i_hv = i_nh // HV, i_nh % HV
    i_h = i_hv // (HV // H)
    if IS_VARLEN:
        bos, eos = (
            tl.load(cu_seqlens + i_n).to(tl.int64),
            tl.load(cu_seqlens + i_n + 1).to(tl.int64),
        )
        all = T
        T = eos - bos
    else:
        bos, eos = i_n * T, i_n * T + T
        all = B * T

    o_k = i_k * BK + tl.arange(0, BK)
    o_v = i_v * BV + tl.arange(0, BV)

    p_q = q + (bos * H + i_h) * K + o_k
    p_k = k + (bos * H + i_h) * K + o_k
    p_v = v + (bos * HV + i_hv) * V + o_v
    p_b = b + bos * HV + i_hv
    p_o = o + ((i_k * all + bos) * HV + i_hv) * V + o_v

    # Gating computation pointers
    p_A_log = A_log + i_hv
    p_a = a + bos * HV + i_hv
    p_dt_bias = dt_bias + i_hv

    mask_k = o_k < K
    mask_v = o_v < V
    mask_h = mask_k[:, None] & mask_v[None, :]

    b_h = tl.zeros([BK, BV], dtype=tl.float32)
    if USE_INITIAL_STATE:
        idx = tl.load(h0_indices + i_n)
        if idx >= 0:
            p_h0 = (
                h0_source
                + idx * HV * K * V
                + i_hv * K * V
                + o_k[:, None] * V
                + o_v[None, :]
            )
            b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)

    for _ in range(0, T):
        # Load inputs
        b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32)
        b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
        b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
        b_b = tl.load(p_b).to(tl.float32)

        # Compute sigmoid gating
        # Load gating parameters
        b_A_log = tl.load(p_A_log).to(tl.float32)
        b_a = tl.load(p_a).to(tl.float32)
        b_dt_bias = tl.load(p_dt_bias).to(tl.float32)

        # Compute g = -exp(A_log) * softplus(a + dt_bias)
        x = b_a + b_dt_bias
        beta_x = softplus_beta * x
        # Apply softplus with numerical stability
        softplus_x = tl.where(
            beta_x <= softplus_threshold,
            (1.0 / softplus_beta) * tl.log(1.0 + tl.exp(beta_x)),
            x,
        )
        b_g = -tl.exp(b_A_log) * softplus_x

        # Compute beta = sigmoid(b)
        b_beta = 1.0 / (1.0 + tl.exp(-b_b))

        # Apply L2 normalization if enabled
        if USE_QK_L2NORM_IN_KERNEL:
            b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q)) + 1e-6)
            b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k)) + 1e-6)
        b_q = b_q * scale

        # Apply gating to hidden state: h *= exp(g)
        b_h *= tl.exp(b_g)
        # Delta rule: v -= sum(h * k, dim=0)
        b_v -= tl.sum(b_h * b_k[:, None], 0)
        # Apply beta gating: v *= beta
        b_v *= b_beta
        # Update hidden state: h += k[:, None] * v[None, :]
        b_h += b_k[:, None] * b_v[None, :]
        # Compute output: o = sum(h * q, dim=0)
        b_o = tl.sum(b_h * b_q[:, None], 0)
        tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)

        # Update pointers for next timestep
        p_q += H * K
        p_k += H * K
        p_o += HV * V
        p_v += HV * V
        p_b += HV
        p_a += HV

    # Store final state back to h0_source with bounds checking
    if USE_INITIAL_STATE:
        idx = tl.load(h0_indices + i_n)
        if idx >= 0:
            p_h0 = (
                h0_source
                + idx * HV * K * V
                + i_hv * K * V
                + o_k[:, None] * V
                + o_v[None, :]
            )
            tl.store(p_h0, b_h.to(p_h0.dtype.element_ty), mask=mask_h)


@input_guard
def fused_sigmoid_gating_delta_rule_update(
    A_log: torch.Tensor,
    a: torch.Tensor,
    dt_bias: torch.Tensor,
    softplus_beta: float,
    softplus_threshold: float,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    b: torch.Tensor,
    initial_state_source: torch.Tensor,
    initial_state_indices: torch.Tensor,
    scale: Optional[float] = None,
    use_qk_l2norm_in_kernel: bool = False,
    cu_seqlens: Optional[torch.Tensor] = None,
):
    """
    Fused triton implementation of sigmoid gating delta rule update.
    This function uses a single fused kernel that combines both sigmoid gating computation
    and the recurrent delta rule update for better performance.
    """
    B, T, H, K, V = *k.shape, v.shape[-1]
    HV = v.shape[2]
    N = B if cu_seqlens is None else len(cu_seqlens) - 1
    BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 8)
    NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
    assert NK == 1, "NK > 1 is not supported yet"
    num_stages = 3
    num_warps = 1
    if scale is None:
        scale = k.shape[-1] ** -0.5
    else:
        assert scale > 0, "scale must be positive"

    o = q.new_empty(NK, *v.shape)
    grid = (NK, NV, N * HV)
    fused_sigmoid_gating_delta_rule_update_kernel[grid](
        A_log=A_log,
        a=a,
        dt_bias=dt_bias,
        softplus_beta=softplus_beta,
        softplus_threshold=softplus_threshold,
        q=q,
        k=k,
        v=v,
        b=b,
        o=o,
        h0_source=initial_state_source,
        h0_indices=initial_state_indices,
        cu_seqlens=cu_seqlens,
        scale=scale,
        T=T,
        B=B,
        H=H,
        HV=HV,
        K=K,
        V=V,
        BK=BK,
        BV=BV,
        USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel,
        num_warps=num_warps,
        num_stages=num_stages,
    )
    o = o.squeeze(0)
    return o
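
The per-timestep update inside the fused kernel mirrors the commented steps above. For a single value head it can be written as the following eager-mode sketch; the function name and the scalar treatment of a, b, A_log, and dt_bias are illustrative, and the kernel additionally handles GQA head mapping, variable-length batches, and in-place state storage through h0_indices:

# Illustrative single-timestep, single-head reference of the fused update.
import torch
import torch.nn.functional as F


def gating_delta_rule_step_ref(h, q, k, v, a, b, A_log, dt_bias, scale,
                               softplus_beta=1.0, softplus_threshold=20.0,
                               use_qk_l2norm=True):
    # h: [K, V] running state; q, k: [K]; v: [V]; a, b, A_log, dt_bias: scalars
    g = -torch.exp(A_log) * F.softplus(a + dt_bias, softplus_beta, softplus_threshold)
    beta = torch.sigmoid(b)
    if use_qk_l2norm:
        q = q / (q.norm() + 1e-6)
        k = k / (k.norm() + 1e-6)
    q = q * scale
    h = h * torch.exp(g)            # h *= exp(g)
    v = v - h.t() @ k               # v -= sum(h * k, dim=0)
    v = v * beta                    # v *= beta
    h = h + torch.outer(k, v)       # h += k[:, None] * v[None, :]
    o = h.t() @ q                   # o = sum(h * q, dim=0)
    return h, o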
python/sglang/srt/layers/attention/fla/index.py (new file, mode 100644)

# Adapt from https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/utils/index.py
# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

import torch
import torch.nn.functional as F
import triton
import triton.language as tl

from sglang.srt.layers.attention.fla.utils import tensor_cache


@tensor_cache
def prepare_lens(cu_seqlens: torch.LongTensor) -> torch.LongTensor:
    return cu_seqlens[1:] - cu_seqlens[:-1]


@tensor_cache
def prepare_chunk_indices(
    cu_seqlens: torch.LongTensor, chunk_size: int
) -> torch.LongTensor:
    indices = torch.cat(
        [
            torch.arange(n)
            for n in triton.cdiv(prepare_lens(cu_seqlens), chunk_size).tolist()
        ]
    )
    return torch.stack([indices.eq(0).cumsum(0) - 1, indices], 1).to(cu_seqlens)


@tensor_cache
def prepare_chunk_offsets(
    cu_seqlens: torch.LongTensor, chunk_size: int
) -> torch.LongTensor:
    return torch.cat(
        [cu_seqlens.new_tensor([0]), triton.cdiv(prepare_lens(cu_seqlens), chunk_size)]
    ).cumsum(-1)
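
A small worked example of the two chunk-index helpers, with values computed by hand for a batch of two sequences of lengths 100 and 156 and 64-token chunks:

# Worked example for cu_seqlens = [0, 100, 256] and chunk_size = 64.
import torch

cu_seqlens = torch.tensor([0, 100, 256])
# prepare_lens(cu_seqlens)              -> tensor([100, 156])
# prepare_chunk_indices(cu_seqlens, 64) -> tensor([[0, 0],
#                                                  [0, 1],
#                                                  [1, 0],
#                                                  [1, 1],
#                                                  [1, 2]])
#   each row is (sequence index, chunk index within that sequence)
# prepare_chunk_offsets(cu_seqlens, 64) -> tensor([0, 2, 5])
#   cumulative number of chunks before each sequence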
python/sglang/srt/layers/attention/fla/l2norm.py (new file, mode 100644)

# Adapt from https://github.com/fla-org/flash-linear-attention/blob/main/fla/modules/l2norm.py
# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

from typing import Optional

import torch
import torch.nn as nn
import triton
import triton.language as tl

from sglang.srt.layers.attention.fla.utils import input_guard

BT_LIST = [8, 16, 32, 64, 128]


# @triton.autotune(
#     configs=[
#         triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8, 16, 32]
#     ],
#     key=["D"],
# )
@triton.jit
def l2norm_fwd_kernel1(
    x,
    y,
    D,
    BD: tl.constexpr,
    eps,
):
    i_t = tl.program_id(0)
    x += i_t * D
    y += i_t * D
    # Compute mean and variance
    cols = tl.arange(0, BD)
    mask = cols < D
    b_x = tl.load(x + cols, mask=mask, other=0.0).to(tl.float32)
    b_var = tl.sum(b_x * b_x, axis=0)
    b_rstd = 1 / tl.sqrt(b_var + eps)
    # tl.store(Rstd + i_t, rstd)
    # Normalize and apply linear transformation
    b_y = b_x * b_rstd
    tl.store(y + cols, b_y, mask=mask)


# @triton.autotune(
#     configs=[
#         triton.Config({"BT": BT}, num_warps=num_warps)
#         for num_warps in [1, 2, 4, 8, 16]
#         for BT in BT_LIST
#     ],
#     key=["D", "NB"],
# )
@triton.jit
def l2norm_fwd_kernel(
    x,
    y,
    eps,
    NB: tl.constexpr,
    T: tl.constexpr,
    D: tl.constexpr,
    BT: tl.constexpr,
    BD: tl.constexpr,
):
    i_t = tl.program_id(0)
    p_x = tl.make_block_ptr(x, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0))
    b_x = tl.load(p_x, boundary_check=(0, 1)).to(tl.float32)
    b_var = tl.sum(b_x * b_x, axis=1)
    b_y = b_x / tl.sqrt(b_var + eps)[:, None]
    p_y = tl.make_block_ptr(y, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0))
    tl.store(p_y, b_y.to(p_y.dtype.element_ty), boundary_check=(0, 1))


def l2norm_fwd(
    x: torch.Tensor, eps: float = 1e-6, output_dtype: Optional[torch.dtype] = None
):
    x_shape_og = x.shape
    x = x.view(-1, x.shape[-1])
    # allocate output
    if output_dtype is None:
        y = torch.empty_like(x)
    else:
        y = torch.empty_like(x, dtype=output_dtype)
    assert y.stride(-1) == 1
    T, D = x.shape[0], x.shape[-1]
    # rstd = torch.empty((T,), dtype=torch.float32, device=x.device)
    # Less than 64KB per feature: enqueue fused kernel
    MAX_FUSED_SIZE = 65536 // x.element_size()
    BD = min(MAX_FUSED_SIZE, triton.next_power_of_2(D))
    if D > BD:
        raise RuntimeError("This layer doesn't support feature dim >= 64KB.")
    if D <= 512:
        NB = triton.cdiv(T, 2048)

        def grid(meta):
            return (triton.cdiv(T, meta["BT"]),)

        l2norm_fwd_kernel[grid](
            x,
            y,
            eps,
            NB=NB,
            T=T,
            D=D,
            BD=BD,
            BT=16,
            num_warps=8,
            num_stages=3,
        )
    else:
        l2norm_fwd_kernel1[(T,)](
            x,
            y,
            eps=eps,
            D=D,
            BD=BD,
            num_warps=8,
            num_stages=3,
        )
    return y.view(x_shape_og)


class L2NormFunction(torch.autograd.Function):

    @staticmethod
    @input_guard
    def forward(ctx, x, eps=1e-6, output_dtype=None):
        return l2norm_fwd(x, eps, output_dtype)


def l2norm(
    x: torch.Tensor, eps: float = 1e-6, output_dtype: Optional[torch.dtype] = None
) -> torch.Tensor:
    return L2NormFunction.apply(x, eps, output_dtype)


l2_norm = l2norm


class L2Norm(nn.Module):

    def __init__(self, eps: float = 1e-6, output_dtype: Optional[torch.dtype] = None):
        super().__init__()
        self.eps = eps
        self.output_dtype = output_dtype

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return l2norm(x, self.eps, self.output_dtype)
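
Functionally, l2norm_fwd matches the following one-liner up to kernel tiling; note that eps is added inside the square root rather than used to clamp the norm, so values near zero differ slightly from torch.nn.functional.normalize:

# Eager-mode sketch of l2norm_fwd's math.
import torch


def l2norm_ref(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    return x / torch.sqrt(x.pow(2).sum(dim=-1, keepdim=True) + eps)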
python/sglang/srt/layers/attention/fla/layernorm_gated.py (new file, mode 100644)

# Adapt from https://github.com/fla-org/flash-linear-attention/blob/main/fla/modules/layernorm_gated.py
# Copyright (c) 2024, Tri Dao.
# Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
# For the backward pass, we keep weight_grad and bias_grad in registers and accumulate.
# This backward pass is faster for dimensions up to 8k, but after that it's much slower due to register spilling.
# The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.

import math

import torch
import torch.nn.functional as F
import triton
import triton.language as tl
from einops import rearrange


def rms_norm_ref(
    x,
    weight,
    bias,
    z=None,
    eps=1e-6,
    group_size=None,
    norm_before_gate=True,
    upcast=True,
):
    dtype = x.dtype
    N = x.shape[-1]
    weight = weight.float()
    bias = bias.float() if bias is not None else None
    if upcast:
        x = x.float()
        z = z.float() if z is not None else z
    if z is not None and not norm_before_gate:
        x = x * F.silu(z)
    if group_size is None:
        rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
        out = (x * rstd * weight) + bias if bias is not None else (x * rstd * weight)
    else:
        x_group = rearrange(x, "... (g d) -> ... g d", d=group_size)
        rstd = 1 / torch.sqrt((x_group.square()).mean(dim=-1, keepdim=True) + eps)
        out = rearrange(x_group * rstd, "... g d -> ... (g d)") * weight
        if bias is not None:
            out = out + bias
    if z is not None and norm_before_gate:
        out *= F.silu(z)
    return out.to(dtype)


@triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
@triton.heuristics({"HAS_Z": lambda args: args["Z"] is not None})
@triton.jit
def _layer_norm_fwd_1pass_kernel(
    X,  # pointer to the input
    Y,  # pointer to the output
    W,  # pointer to the weights
    B,  # pointer to the biases
    Z,  # pointer to the other branch
    Mean,  # pointer to the mean
    Rstd,  # pointer to the 1/std
    stride_x_row,  # how much to increase the pointer when moving by 1 row
    stride_y_row,
    stride_z_row,
    M,  # number of rows in X
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    BLOCK_N: tl.constexpr,
    HAS_BIAS: tl.constexpr,
    HAS_Z: tl.constexpr,
    NORM_BEFORE_GATE: tl.constexpr,
    IS_RMS_NORM: tl.constexpr,
):
    # Map the program id to the row of X and Y it should compute.
    row = tl.program_id(0)
    group = tl.program_id(1)
    X += row * stride_x_row + group * N
    Y += row * stride_y_row + group * N
    if HAS_Z:
        Z += row * stride_z_row + group * N
    if not IS_RMS_NORM:
        Mean += group * M
    Rstd += group * M
    W += group * N
    if HAS_BIAS:
        B += group * N
    # Compute mean and variance
    cols = tl.arange(0, BLOCK_N)
    x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
    if HAS_Z and not NORM_BEFORE_GATE:
        z = tl.load(Z + cols, mask=cols < N).to(tl.float32)
        x *= z * tl.sigmoid(z)
    if not IS_RMS_NORM:
        mean = tl.sum(x, axis=0) / N
        tl.store(Mean + row, mean)
        xbar = tl.where(cols < N, x - mean, 0.0)
        var = tl.sum(xbar * xbar, axis=0) / N
    else:
        xbar = tl.where(cols < N, x, 0.0)
        var = tl.sum(xbar * xbar, axis=0) / N
    rstd = 1 / tl.sqrt(var + eps)
    tl.store(Rstd + row, rstd)
    # Normalize and apply linear transformation
    mask = cols < N
    w = tl.load(W + cols, mask=mask).to(tl.float32)
    if HAS_BIAS:
        b = tl.load(B + cols, mask=mask).to(tl.float32)
    x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
    y = x_hat * w + b if HAS_BIAS else x_hat * w
    if HAS_Z and NORM_BEFORE_GATE:
        z = tl.load(Z + cols, mask=mask).to(tl.float32)
        y *= z * tl.sigmoid(z)
    # Write output
    tl.store(Y + cols, y, mask=mask)


def _layer_norm_fwd(
    x,
    weight,
    bias,
    eps,
    z=None,
    out=None,
    group_size=None,
    norm_before_gate=True,
    is_rms_norm=False,
):
    M, N = x.shape
    if group_size is None:
        group_size = N
    assert N % group_size == 0
    ngroups = N // group_size
    assert x.stride(-1) == 1
    if z is not None:
        assert z.stride(-1) == 1
        assert z.shape == (M, N)
    assert weight.shape == (N,)
    assert weight.stride(-1) == 1
    if bias is not None:
        assert bias.stride(-1) == 1
        assert bias.shape == (N,)
    # allocate output
    if out is not None:
        assert out.shape == x.shape
    else:
        out = torch.empty_like(x)
    assert out.stride(-1) == 1
    mean = (
        torch.empty((ngroups * M,), dtype=torch.float32, device=x.device)
        if not is_rms_norm
        else None
    )
    rstd = torch.empty((ngroups * M,), dtype=torch.float32, device=x.device)
    # Less than 64KB per feature: enqueue fused kernel
    MAX_FUSED_SIZE = 65536 // x.element_size()
    BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size))
    if group_size > BLOCK_N:
        raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
    # heuristics for number of warps
    num_warps = min(max(BLOCK_N // 256, 1), 8)
    grid = (M, ngroups)
    with torch.cuda.device(x.device.index):
        _layer_norm_fwd_1pass_kernel[grid](
            x,
            out,
            weight,
            bias,
            z,
            mean,
            rstd,
            x.stride(0),
            out.stride(0),
            z.stride(0) if z is not None else 0,
            M,
            group_size,
            eps,
            BLOCK_N=BLOCK_N,
            NORM_BEFORE_GATE=norm_before_gate,
            IS_RMS_NORM=is_rms_norm,
            num_warps=num_warps,
        )
    return out, mean, rstd


class LayerNormFn(torch.autograd.Function):

    @staticmethod
    def forward(
        ctx,
        x,
        weight,
        bias,
        z=None,
        eps=1e-6,
        group_size=None,
        norm_before_gate=True,
        is_rms_norm=False,
    ):
        """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""
        x_shape_og = x.shape
        # reshape input data into 2D tensor
        x = x.reshape(-1, x.shape[-1])
        if x.stride(-1) != 1:
            x = x.contiguous()
        if z is not None:
            assert z.shape == x_shape_og
            z = z.reshape(-1, z.shape[-1])
            if z.stride(-1) != 1:
                z = z.contiguous()
        weight = weight.contiguous()
        if bias is not None:
            bias = bias.contiguous()
        y, mean, rstd = _layer_norm_fwd(
            x,
            weight,
            bias,
            eps,
            z=z,
            group_size=group_size,
            norm_before_gate=norm_before_gate,
            is_rms_norm=is_rms_norm,
        )
        return y.reshape(x_shape_og)


def layernorm_fn(
    x,
    weight,
    bias,
    z=None,
    eps=1e-6,
    group_size=None,
    norm_before_gate=True,
    is_rms_norm=False,
):
    return LayerNormFn.apply(
        x, weight, bias, z, eps, group_size, norm_before_gate, is_rms_norm
    )


def rmsnorm_fn(
    x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True
):
    return LayerNormFn.apply(
        x, weight, bias, z, eps, group_size, norm_before_gate, True
    )


class LayerNorm(torch.nn.Module):

    def __init__(
        self,
        hidden_size,
        eps=1e-5,
        group_size=None,
        norm_before_gate=True,
        device=None,
        dtype=None,
    ):
        """If group_size is not None, we do GroupNorm with each group having group_size elements.
        group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group).
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        self.bias = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        self.group_size = group_size
        self.norm_before_gate = norm_before_gate
        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.ones_(self.weight)
        torch.nn.init.zeros_(self.bias)

    def forward(self, x, z=None):
        """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""
        return layernorm_fn(
            x,
            self.weight,
            self.bias,
            z=z,
            group_size=self.group_size,
            eps=self.eps,
            norm_before_gate=self.norm_before_gate,
        )


class RMSNorm(torch.nn.Module):

    def __init__(
        self,
        hidden_size,
        eps=1e-5,
        group_size=None,
        norm_before_gate=True,
        device=None,
        dtype=None,
    ):
        """If group_size is not None, we do GroupNorm with each group having group_size elements.
        group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group).
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        self.register_parameter("bias", None)
        self.group_size = group_size
        self.norm_before_gate = norm_before_gate
        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.ones_(self.weight)

    def forward(self, x, z=None):
        """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""
        return rmsnorm_fn(
            x,
            self.weight,
            self.bias,
            z=z,
            eps=self.eps,
            group_size=self.group_size,
            norm_before_gate=self.norm_before_gate,
        )
python/sglang/srt/layers/attention/fla/op.py (new file, mode 100644)

# Adapt from https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/utils/op.py
# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

import os

import triton
import triton.language as tl
import triton.language.extra.libdevice as tldevice

from sglang.srt.layers.attention.fla.utils import is_gather_supported

if os.environ.get("FLA_USE_FAST_OPS", "0") == "1":
    exp = tldevice.fast_expf
    exp2 = tldevice.exp2
    log = tldevice.fast_logf
    log2 = tldevice.fast_log2f
else:
    exp = tl.exp
    exp2 = tl.math.exp2
    log = tl.log
    log2 = tl.log2


@triton.jit
def safe_exp(x):
    return exp(tl.where(x <= 0, x, float("-inf")))


if not is_gather_supported:

    @triton.jit
    def gather(src, index, axis, _builder=None):
        """
        Gather operation that works when tl.gather is not supported.
        This is a fallback implementation that returns None.
        Just to make triton compiler happy.
        """
        return None

else:
    gather = tl.gather

if hasattr(triton.language, "_experimental_make_tensor_descriptor"):
    # For Triton 3.3.x
    make_tensor_descriptor = triton.language._experimental_make_tensor_descriptor
elif hasattr(triton.language, "make_tensor_descriptor"):
    # For Triton 3.4.x and later
    make_tensor_descriptor = triton.language.make_tensor_descriptor
else:
    """
    Fallback implementation when TMA is not supported.
    Returns None to indicate TMA descriptors are unavailable.
    Just make triton compiler happy.
    """

    @triton.jit
    def make_tensor_descriptor(
        base,
        shape,
        strides,
        block_shape,
        _builder=None,
    ):
        return None
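One detail worth calling out: safe_exp clamps positive arguments before exponentiating (exp(-inf) = 0), so gating terms can never overflow. The toy kernel below is only an illustration of how it drops in for tl.exp inside jitted code; the kernel name and block size are made up for the sketch.

import torch
import triton
import triton.language as tl

from sglang.srt.layers.attention.fla.op import safe_exp


@triton.jit
def _safe_exp_demo(x_ptr, y_ptr, N, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < N
    x = tl.load(x_ptr + offs, mask=mask)
    tl.store(y_ptr + offs, safe_exp(x), mask=mask)


x = torch.tensor([-2.0, -0.5, 0.0, 3.0], device="cuda")
y = torch.empty_like(x)
_safe_exp_demo[(1,)](x, y, x.numel(), BLOCK=4)
# y == [exp(-2), exp(-0.5), 1.0, 0.0]; the positive input is mapped to 0 instead of overflowing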
python/sglang/srt/layers/attention/fla/solve_tril.py
0 → 100644
# Adapt from https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/utils/solve_tril.py
# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

from typing import Optional

import torch
import triton
import triton.language as tl

from sglang.srt.layers.attention.fla.index import prepare_chunk_indices
from sglang.srt.layers.attention.fla.utils import input_guard


@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
# @triton.autotune(
#     configs=[
#         triton.Config({}, num_warps=num_warps, num_stages=num_stages)
#         for num_warps in [1, 2, 4, 8]
#         for num_stages in [2, 3, 4, 5]
#     ],
#     key=["BT"],
# )
@triton.jit(do_not_specialize=["T"])
def solve_tril_16x16_kernel(
    A,
    Ad,
    cu_seqlens,
    chunk_indices,
    T,
    H: tl.constexpr,
    BT: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(
            chunk_indices + i_t * 2 + 1
        ).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(
            cu_seqlens + i_n + 1
        ).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    A = A + (bos * H + i_h) * BT
    Ad = Ad + (bos * H + i_h) * 16
    offset = (i_t * 16) % BT
    p_A = tl.make_block_ptr(
        A, (T, BT), (H * BT, 1), (i_t * 16, offset), (16, 16), (1, 0)
    )
    p_Ai = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 16, 0), (16, 16), (1, 0))
    b_A = tl.load(p_A, boundary_check=(0, 1)).to(tl.float32)
    b_A = -tl.where(
        tl.arange(0, 16)[:, None] > tl.arange(0, 16)[None, :], b_A, 0
    )
    o_i = tl.arange(0, 16)
    for i in range(1, min(16, T - i_t * 16)):
        b_a = -tl.load(A + (i_t * 16 + i) * H * BT + o_i + offset)
        b_a = b_a + tl.sum(b_a[:, None] * b_A, 0)
        mask = o_i == i
        b_A = tl.where(mask[:, None], b_a, b_A)
    b_A += o_i[:, None] == o_i[None, :]
    tl.store(
        p_Ai,
        b_A.to(p_Ai.dtype.element_ty, fp_downcast_rounding="rtne"),
        boundary_check=(0, 1),
    )
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
# @triton.autotune(
#     configs=[
#         triton.Config({}, num_warps=num_warps, num_stages=num_stages)
#         for num_warps in [1, 2, 4, 8]
#         for num_stages in [2, 3, 4, 5]
#     ],
#     key=["H", "BT", "IS_VARLEN"],
# )
@triton.jit(do_not_specialize=["T"])
def merge_16x16_to_32x32_inverse_kernel(
    A,
    Ad,
    Ai,
    cu_seqlens,
    chunk_indices,
    T,
    H: tl.constexpr,
    BT: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(
            chunk_indices + i_t * 2 + 1
        ).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(
            cu_seqlens + i_n + 1
        ).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    A += (bos * H + i_h) * 32
    Ad += (bos * H + i_h) * 16
    Ai += (bos * H + i_h) * 32
    p_A_21 = tl.make_block_ptr(A, (T, 32), (H * 32, 1), (i_t * 32 + 16, 0), (16, 16), (1, 0))
    p_Ad_11 = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 32, 0), (16, 16), (1, 0))
    p_Ad_22 = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 32 + 16, 0), (16, 16), (1, 0))
    p_Ai_11 = tl.make_block_ptr(Ai, (T, 32), (H * 32, 1), (i_t * 32, 0), (16, 16), (1, 0))
    p_Ai_22 = tl.make_block_ptr(Ai, (T, 32), (H * 32, 1), (i_t * 32 + 16, 16), (16, 16), (1, 0))
    p_Ai_21 = tl.make_block_ptr(Ai, (T, 32), (H * 32, 1), (i_t * 32 + 16, 0), (16, 16), (1, 0))

    A_21 = tl.load(p_A_21, boundary_check=(0, 1)).to(tl.float32)
    Ai_11 = tl.load(p_Ad_11, boundary_check=(0, 1)).to(tl.float32)
    Ai_22 = tl.load(p_Ad_22, boundary_check=(0, 1)).to(tl.float32)
    Ai_21 = -tl.dot(
        tl.dot(Ai_22, A_21, input_precision="ieee"), Ai_11, input_precision="ieee"
    )
    tl.store(
        p_Ai_11,
        Ai_11.to(p_Ai_11.dtype.element_ty, fp_downcast_rounding="rtne"),
        boundary_check=(0, 1),
    )
    tl.store(
        p_Ai_22,
        Ai_22.to(p_Ai_22.dtype.element_ty, fp_downcast_rounding="rtne"),
        boundary_check=(0, 1),
    )
    tl.store(
        p_Ai_21,
        Ai_21.to(p_Ai_21.dtype.element_ty, fp_downcast_rounding="rtne"),
        boundary_check=(0, 1),
    )
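The merge kernels above and below rely on the standard identity for inverting a block lower-triangular matrix, which is why the off-diagonal block is computed as -Ai_22 @ A_21 @ Ai_11. Restated here for reference (not part of the commit):

\begin{pmatrix} A_{11} & 0 \\ A_{21} & A_{22} \end{pmatrix}^{-1}
=
\begin{pmatrix} A_{11}^{-1} & 0 \\ -A_{22}^{-1} A_{21} A_{11}^{-1} & A_{22}^{-1} \end{pmatrix}

The 64x64 variant applies the same recurrence block-row by block-row, so every lower block of the inverse is built from already-inverted 16x16 diagonal blocks.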
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
# @triton.autotune(
#     configs=[
#         triton.Config({}, num_warps=num_warps, num_stages=num_stages)
#         for num_warps in [2, 4, 8]
#         for num_stages in [2, 3, 4, 5]
#     ],
#     key=["H", "BT", "IS_VARLEN"],
# )
@triton.jit(do_not_specialize=["T"])
def merge_16x16_to_64x64_inverse_kernel(
    A,
    Ad,
    Ai,
    cu_seqlens,
    chunk_indices,
    T,
    H: tl.constexpr,
    BT: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(
            chunk_indices + i_t * 2 + 1
        ).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(
            cu_seqlens + i_n + 1
        ).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    A += (bos * H + i_h) * 64
    Ad += (bos * H + i_h) * 16
    Ai += (bos * H + i_h) * 64

    p_A_21 = tl.make_block_ptr(A, (T, 64), (H * 64, 1), (i_t * 64 + 16, 0), (16, 16), (1, 0))
    p_A_32 = tl.make_block_ptr(A, (T, 64), (H * 64, 1), (i_t * 64 + 32, 16), (16, 16), (1, 0))
    p_A_31 = tl.make_block_ptr(A, (T, 64), (H * 64, 1), (i_t * 64 + 32, 0), (16, 16), (1, 0))
    p_A_43 = tl.make_block_ptr(A, (T, 64), (H * 64, 1), (i_t * 64 + 48, 32), (16, 16), (1, 0))
    p_A_42 = tl.make_block_ptr(A, (T, 64), (H * 64, 1), (i_t * 64 + 48, 16), (16, 16), (1, 0))
    p_A_41 = tl.make_block_ptr(A, (T, 64), (H * 64, 1), (i_t * 64 + 48, 0), (16, 16), (1, 0))
    p_Ad_11 = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 64, 0), (16, 16), (1, 0))
    p_Ad_22 = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 64 + 16, 0), (16, 16), (1, 0))
    p_Ad_33 = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 64 + 32, 0), (16, 16), (1, 0))
    p_Ad_44 = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 64 + 48, 0), (16, 16), (1, 0))

    A_21 = tl.load(p_A_21, boundary_check=(0, 1)).to(tl.float32)
    A_32 = tl.load(p_A_32, boundary_check=(0, 1)).to(tl.float32)
    A_31 = tl.load(p_A_31, boundary_check=(0, 1)).to(tl.float32)
    A_43 = tl.load(p_A_43, boundary_check=(0, 1)).to(tl.float32)
    A_42 = tl.load(p_A_42, boundary_check=(0, 1)).to(tl.float32)
    A_41 = tl.load(p_A_41, boundary_check=(0, 1)).to(tl.float32)
    Ai_11 = tl.load(p_Ad_11, boundary_check=(0, 1)).to(tl.float32)
    Ai_22 = tl.load(p_Ad_22, boundary_check=(0, 1)).to(tl.float32)
    Ai_33 = tl.load(p_Ad_33, boundary_check=(0, 1)).to(tl.float32)
    Ai_44 = tl.load(p_Ad_44, boundary_check=(0, 1)).to(tl.float32)

    Ai_21 = -tl.dot(
        tl.dot(Ai_22, A_21, input_precision="ieee"), Ai_11, input_precision="ieee"
    )
    Ai_32 = -tl.dot(
        tl.dot(Ai_33, A_32, input_precision="ieee"), Ai_22, input_precision="ieee"
    )
    Ai_43 = -tl.dot(
        tl.dot(Ai_44, A_43, input_precision="ieee"), Ai_33, input_precision="ieee"
    )
    Ai_31 = -tl.dot(
        Ai_33,
        tl.dot(A_31, Ai_11, input_precision="ieee")
        + tl.dot(A_32, Ai_21, input_precision="ieee"),
        input_precision="ieee",
    )
    Ai_42 = -tl.dot(
        Ai_44,
        tl.dot(A_42, Ai_22, input_precision="ieee")
        + tl.dot(A_43, Ai_32, input_precision="ieee"),
        input_precision="ieee",
    )
    Ai_41 = -tl.dot(
        Ai_44,
        tl.dot(A_41, Ai_11, input_precision="ieee")
        + tl.dot(A_42, Ai_21, input_precision="ieee")
        + tl.dot(A_43, Ai_31, input_precision="ieee"),
        input_precision="ieee",
    )

    p_Ai_11 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64, 0), (16, 16), (1, 0))
    p_Ai_22 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 16, 16), (16, 16), (1, 0))
    p_Ai_33 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 32, 32), (16, 16), (1, 0))
    p_Ai_44 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 48, 48), (16, 16), (1, 0))
    p_Ai_21 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 16, 0), (16, 16), (1, 0))
    p_Ai_31 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 32, 0), (16, 16), (1, 0))
    p_Ai_32 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 32, 16), (16, 16), (1, 0))
    p_Ai_41 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 48, 0), (16, 16), (1, 0))
    p_Ai_42 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 48, 16), (16, 16), (1, 0))
    p_Ai_43 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 48, 32), (16, 16), (1, 0))

    tl.store(
        p_Ai_11,
        Ai_11.to(p_Ai_11.dtype.element_ty, fp_downcast_rounding="rtne"),
        boundary_check=(0, 1),
    )
    tl.store(
        p_Ai_22,
        Ai_22.to(p_Ai_22.dtype.element_ty, fp_downcast_rounding="rtne"),
        boundary_check=(0, 1),
    )
    tl.store(
        p_Ai_33,
        Ai_33.to(p_Ai_33.dtype.element_ty, fp_downcast_rounding="rtne"),
        boundary_check=(0, 1),
    )
    tl.store(
        p_Ai_44,
        Ai_44.to(p_Ai_44.dtype.element_ty, fp_downcast_rounding="rtne"),
        boundary_check=(0, 1),
    )
    tl.store(
        p_Ai_21,
        Ai_21.to(p_Ai_21.dtype.element_ty, fp_downcast_rounding="rtne"),
        boundary_check=(0, 1),
    )
    tl.store(
        p_Ai_31,
        Ai_31.to(p_Ai_31.dtype.element_ty, fp_downcast_rounding="rtne"),
        boundary_check=(0, 1),
    )
    tl.store(
        p_Ai_32,
        Ai_32.to(p_Ai_32.dtype.element_ty, fp_downcast_rounding="rtne"),
        boundary_check=(0, 1),
    )
    tl.store(
        p_Ai_41,
        Ai_41.to(p_Ai_41.dtype.element_ty, fp_downcast_rounding="rtne"),
        boundary_check=(0, 1),
    )
    tl.store(
        p_Ai_42,
        Ai_42.to(p_Ai_42.dtype.element_ty, fp_downcast_rounding="rtne"),
        boundary_check=(0, 1),
    )
    tl.store(
        p_Ai_43,
        Ai_43.to(p_Ai_43.dtype.element_ty, fp_downcast_rounding="rtne"),
        boundary_check=(0, 1),
    )

    fill_zeros = tl.zeros((16, 16), dtype=tl.float32)
    p_Ai_12 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64, 16), (16, 16), (1, 0))
    p_Ai_13 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64, 32), (16, 16), (1, 0))
    p_Ai_14 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64, 48), (16, 16), (1, 0))
    p_Ai_23 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 16, 32), (16, 16), (1, 0))
    p_Ai_24 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 16, 48), (16, 16), (1, 0))
    p_Ai_34 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 32, 48), (16, 16), (1, 0))
    tl.store(
        p_Ai_12,
        fill_zeros.to(p_Ai_12.dtype.element_ty, fp_downcast_rounding="rtne"),
        boundary_check=(0, 1),
    )
    tl.store(
        p_Ai_13,
        fill_zeros.to(p_Ai_13.dtype.element_ty, fp_downcast_rounding="rtne"),
        boundary_check=(0, 1),
    )
    tl.store(
        p_Ai_14,
        fill_zeros.to(p_Ai_14.dtype.element_ty, fp_downcast_rounding="rtne"),
        boundary_check=(0, 1),
    )
    tl.store(
        p_Ai_23,
        fill_zeros.to(p_Ai_23.dtype.element_ty, fp_downcast_rounding="rtne"),
        boundary_check=(0, 1),
    )
    tl.store(
        p_Ai_24,
        fill_zeros.to(p_Ai_24.dtype.element_ty, fp_downcast_rounding="rtne"),
        boundary_check=(0, 1),
    )
    tl.store(
        p_Ai_34,
        fill_zeros.to(p_Ai_34.dtype.element_ty, fp_downcast_rounding="rtne"),
        boundary_check=(0, 1),
    )
@input_guard
def solve_tril(
    A: torch.Tensor,
    cu_seqlens: Optional[torch.Tensor] = None,
    output_dtype: torch.dtype = torch.float,
) -> torch.Tensor:
    """
    Compute the inverse of the lower triangular matrix
    A should be strictly lower triangular, i.e., A.triu() == 0.

    Args:
        A (torch.Tensor):
            [B, T, H, K]
        cu_seqlens (torch.Tensor):
            The cumulative sequence lengths of the input tensor.
            Default: None.
        output_dtype (torch.dtype):
            The dtype of the output tensor. Default: `torch.float`

    Returns:
        (I + A)^-1 with the same shape as A
    """
    assert A.shape[-1] in [16, 32, 64]

    B, T, H, BT = A.shape
    Ad = torch.empty(
        B, T, H, 16, device=A.device, dtype=torch.float if BT != 16 else output_dtype
    )

    chunk_indices = (
        prepare_chunk_indices(cu_seqlens, 16) if cu_seqlens is not None else None
    )
    NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv(T, 16)
    solve_tril_16x16_kernel[NT, B * H](
        A=A,
        Ad=Ad,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        H=H,
        BT=BT,
        num_warps=1,
        num_stages=4,
    )
    if BT == 16:
        return Ad

    Ai = torch.empty(B, T, H, BT, device=A.device, dtype=output_dtype)
    merge_fn = (
        merge_16x16_to_32x32_inverse_kernel
        if BT == 32
        else merge_16x16_to_64x64_inverse_kernel
    )
    chunk_indices = (
        prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
    )
    NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv(T, BT)
    merge_fn[NT, B * H](
        A=A,
        Ad=Ad,
        Ai=Ai,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        H=H,
        BT=BT,
        num_warps=4,
        num_stages=3,
    )
    return Ai
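A rough usage sketch of solve_tril, not part of the commit: the shapes, the equal-length case (cu_seqlens=None), and the masking code are illustrative assumptions showing how a strictly lower-triangular per-chunk input is prepared and inverted.

import torch

from sglang.srt.layers.attention.fla.solve_tril import solve_tril

B, T, H, BT = 1, 128, 2, 64
A = torch.randn(B, T, H, BT, device="cuda", dtype=torch.float32)
row = torch.arange(T, device="cuda") % BT               # row position inside its chunk
col = torch.arange(BT, device="cuda")
A = A * (col[None, :] < row[:, None]).view(1, T, 1, BT)  # keep only the strictly lower part
Ai = solve_tril(A)                                       # [B, T, H, BT], (I + A)^-1 per chunk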
python/sglang/srt/layers/attention/fla/utils.py
0 → 100644
# Adapt from https://github.com/fla-org/flash-linear-attention/blob/main/fla/utils.py
# -*- coding: utf-8 -*-
import contextlib
import functools
import logging
import os
import sys
from enum import Enum
from functools import lru_cache
from typing import Any, Callable, Dict, Literal, Optional, Tuple

import torch
import triton
from packaging import version

logger = logging.getLogger(__name__)

COMPILER_MODE = os.getenv("FLA_COMPILER_MODE") == "1"
FLA_CI_ENV = os.getenv("FLA_CI_ENV") == "1"


@lru_cache(maxsize=1)
def check_environments():
    """
    Checks the current operating system, Triton version, and Python version,
    issuing warnings if they don't meet recommendations.
    This function's body only runs once due to lru_cache.
    """
    # Check Operating System
    if sys.platform == "win32":
        logger.warning(
            "Detected Windows operating system. Triton does not have an official Windows release, "
            "thus FLA will not be adapted for Windows, and any potential errors will not be fixed. "
            "Please consider using a Linux environment for compatibility."
        )

    triton_version = version.parse(triton.__version__)
    required_triton_version = version.parse("3.2.0")
    if triton_version < required_triton_version:
        logger.warning(
            f"Current Triton version {triton_version} is below the recommended 3.2.0 version. "
            "Errors may occur and these issues will not be fixed. "
            "Please consider upgrading Triton."
        )

    # Check Python version
    py_version = version.parse(f"{sys.version_info.major}.{sys.version_info.minor}")
    required_py_version = version.parse("3.11")
    if py_version < required_py_version:
        logger.warning(
            f"Current Python version {py_version} is below the recommended 3.11 version. "
            "It is recommended to upgrade to Python 3.11 or higher for the best experience."
        )

    return None


check_environments()
def get_abs_err(x, y):
    return (x.detach() - y.detach()).flatten().abs().max().item()


def get_err_ratio(x, y):
    err = (x.detach() - y.detach()).flatten().square().mean().sqrt().item()
    base = (x.detach()).flatten().square().mean().sqrt().item()
    return err / (base + 1e-8)


def assert_close(prefix, ref, tri, ratio, warning=False, err_atol=1e-6):
    abs_atol = get_abs_err(ref, tri)
    msg = f"{prefix} diff: {abs_atol:.6f} ratio: {get_err_ratio(ref, tri):.6f}"
    logger.info(msg)
    error_rate = get_err_ratio(ref, tri)
    if abs_atol <= err_atol:
        return
    if warning or (FLA_CI_ENV and (error_rate < 0.01 or abs_atol <= 0.3)):
        if error_rate > ratio:
            import warnings

            warnings.warn(msg)
    else:
        assert error_rate < ratio, msg
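For reference, assert_close is typically driven like this in tests; the tensors and the 0.005 ratio below are invented for the sketch.

import torch

ref = torch.randn(1024, device="cuda")
tri = ref + 1e-4 * torch.randn_like(ref)
assert_close("o", ref, tri, ratio=0.005)  # logs the abs diff and error ratio, asserts the ratio stays below 0.005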
SUPPRESS_LEVEL = int(os.getenv("GDN_RECOMPUTE_SUPPRESS_LEVEL", "0"))


def tensor_cache(fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
    """
    A decorator that caches the most recent results of a function with tensor inputs.

    This decorator will store the output of the decorated function for the most recent set of input tensors.
    The cache is limited to a fixed size (default is 4). When the cache is full, the oldest entry will be removed.

    Args:
        fn (Callable[..., torch.Tensor]):
            The function to be decorated. It should take tensor inputs and return tensor outputs.

    Returns:
        Callable[..., torch.Tensor]:
            A wrapped version of the input function with single-entry caching.
    """
    cache_entries: Tuple[Optional[Tuple], Optional[Dict], Any] = []
    cache_size = 4

    @functools.wraps(fn)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        nonlocal cache_entries, cache_size
        for i, entry in enumerate(cache_entries):
            last_args, last_kwargs, last_result = entry
            if len(args) == len(last_args) and len(kwargs) == len(last_kwargs):
                if all(a is b for a, b in zip(args, last_args)) and all(
                    k in last_kwargs and v is last_kwargs[k] for k, v in kwargs.items()
                ):
                    cache_entries = (
                        cache_entries[:i]
                        + cache_entries[i + 1 :]
                        + [(args, kwargs, last_result)]
                    )
                    return last_result

        result = fn(*args, **kwargs)

        if len(cache_entries) >= cache_size:
            cache_entries = cache_entries[1:]
        cache_entries.append((args, kwargs, result))
        return result

    return wrapper
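A small sketch of how tensor_cache is meant to be used: the cache key is tensor identity (is), not value equality, so repeated calls with the same tensor objects reuse the previous result. The helper function below is hypothetical.

import torch

@tensor_cache
def build_mask(lengths: torch.Tensor) -> torch.Tensor:
    # recomputed only when a new `lengths` tensor object is passed in
    return lengths[:, None] > torch.arange(int(lengths.max()), device=lengths.device)

lengths = torch.tensor([3, 5], device="cuda")
m1 = build_mask(lengths)
m2 = build_mask(lengths)  # cache hit: m2 is m1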
def input_guard(fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
    """
    A decorator to make sure all input tensors are contiguous and set the device based on input tensors.
    """

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        contiguous_args = (
            i if not isinstance(i, torch.Tensor) else i.contiguous() for i in args
        )
        contiguous_kwargs = {
            k: (v if not isinstance(v, torch.Tensor) else v.contiguous())
            for k, v in kwargs.items()
        }

        tensor = None
        for arg in args:
            if isinstance(arg, torch.Tensor):
                tensor = arg
                break
        if tensor is None:
            for value in kwargs.values():
                if isinstance(value, torch.Tensor):
                    tensor = value
                    break

        if tensor is not None:
            ctx = custom_device_ctx(tensor.device.index)
        else:
            ctx = contextlib.nullcontext()

        with ctx:
            return fn(*contiguous_args, **contiguous_kwargs)

    return wrapper


contiguous = input_guard
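input_guard (aliased as contiguous) is used as a plain decorator. The toy function below is hypothetical and only shows that a non-contiguous tensor argument arrives contiguous and that the device context follows the first tensor seen.

import torch

@input_guard
def scale(x: torch.Tensor, factor: float) -> torch.Tensor:
    return x * factor

y = scale(torch.randn(8, 8, device="cuda").t(), 2.0)  # the transposed (non-contiguous) input is made contiguous first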
def require_version(version, hint):
    """
    Perform a runtime check of the dependency versions, using the exact same syntax used by pip.
    """

    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(ctx, *args, **kwargs):
            from transformers.utils.versions import require_version

            require_version(version, hint)
            return fn(
                ctx,
                *(
                    i if not isinstance(i, torch.Tensor) else i.contiguous()
                    for i in args
                ),
                **{
                    k: (v if not isinstance(v, torch.Tensor) else v.contiguous())
                    for k, v in kwargs.items()
                },
            )

        return wrapper

    return decorator


def checkpoint(fn):
    def wrapper(*args, **kwargs):
        return torch.utils.checkpoint.checkpoint(fn, *args, **kwargs)

    return wrapper
@lru_cache(maxsize=None)
def check_pytorch_version(version_s: str = "2.4") -> bool:
    return version.parse(torch.__version__) >= version.parse(version_s)


def _cpu_device_warning():
    import warnings

    warnings.warn(
        ("Triton is not supported on current platform, roll back to CPU."),
        stacklevel=1,
    )


@lru_cache(maxsize=None)
def get_multiprocessor_count(tensor_idx: int = 0) -> int:
    try:
        return triton.runtime.driver.active.utils.get_device_properties(tensor_idx)[
            "multiprocessor_count"
        ]
    except BaseException:
        _cpu_device_warning()
        return -1


@lru_cache(maxsize=None)
def get_available_device() -> str:
    try:
        return triton.runtime.driver.active.get_current_target().backend
    except BaseException:
        _cpu_device_warning()
        return "cpu"


@lru_cache(maxsize=None)
def _check_platform() -> Literal["nvidia", "amd", "intel", "musa"]:
    device = get_available_device()
    if device == "cuda":
        return "nvidia"
    elif device == "hip":
        return "amd"
    elif device == "xpu":
        return "intel"
    else:
        return device


# For AMD GPUs, the triton backend is 'hip', while for Nvidia GPUs, the triton backend is 'cuda'.
# However, the torch backend is 'cuda' for both Nvidia and AMD GPUs.
# Therefore, we need to check the triton backend to determine the actual GPU vendor.
device = get_available_device() if get_available_device() != "hip" else "cuda"
device_torch_lib = getattr(torch, device)
device_platform = _check_platform()

is_amd = device_platform == "amd"
is_intel = device_platform == "intel"
is_nvidia = device_platform == "nvidia"
is_intel_alchemist = is_intel and "Intel(R) Arc(TM) A" in torch.xpu.get_device_name(0)
is_nvidia_hopper = is_nvidia and (
    "NVIDIA H" in torch.cuda.get_device_name(0)
    or torch.cuda.get_device_capability()[0] >= 9
)
use_cuda_graph = is_nvidia and os.environ.get("FLA_USE_CUDA_GRAPH", "0") == "1"

# Nvidia Ampere or newer, haven't check AMD and intel yet.
is_tf32_supported = is_nvidia and torch.cuda.get_device_capability(0)[0] >= 8
is_gather_supported = hasattr(triton.language, "gather")


def get_all_max_shared_mem():
    try:
        return [
            triton.runtime.driver.active.utils.get_device_properties(i)[
                "max_shared_mem"
            ]
            for i in range(device_torch_lib.device_count())
        ]
    except BaseException:
        _cpu_device_warning()
        return [-1]


class Backend(Enum):
    ADA = 101376  # RTX 4090
    AMPERE = 166912  # A100
    HOPPER = 232448  # H100
    DEFAULT = 102400  # Default

    @classmethod
    def get_shared_memory(cls, arch: str) -> int:
        try:
            return cls[arch.upper()].value
        except KeyError:
            return cls.DEFAULT.value


@lru_cache(maxsize=None)
def check_shared_mem(arch: str = "none", tensor_idx: int = 0) -> bool:
    try:
        device_shared_mem_list = get_all_max_shared_mem()
        max_shared_memory = device_shared_mem_list[tensor_idx]
        return max_shared_memory >= Backend.get_shared_memory(arch)
    except Exception:
        return False


if check_pytorch_version("2.4"):
    device = "cuda" if device == "cpu" else device
    autocast_custom_fwd = functools.partial(torch.amp.custom_fwd, device_type=device)
    autocast_custom_bwd = functools.partial(torch.amp.custom_bwd, device_type=device)

    def custom_device_ctx(index: int):
        return device_torch_lib.device(index)

else:
    assert (
        device == "cuda"
    ), "Only cuda device is supported for PyTorch version < 2.4.0."
    autocast_custom_fwd = device_torch_lib.amp.custom_fwd
    autocast_custom_bwd = device_torch_lib.amp.custom_bwd

    def custom_device_ctx(index: int):
        return torch.cuda.device(index)
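Callers typically use check_shared_mem to gate kernel configurations on per-SM shared memory, with the thresholds coming from the Backend enum above. The tile sizes below are illustrative, not values used by this commit.

BK = 128 if check_shared_mem("hopper") else 64        # hypothetical block-size choice
num_stages = 4 if check_shared_mem("ampere") else 2   # hypothetical pipelining depth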
python/sglang/srt/layers/attention/fla/wy_fast.py
0 → 100644
# Adapt from https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/gated_delta_rule/wy_fast.py
# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

from typing import Optional, Tuple

import torch
import triton
import triton.language as tl

from sglang.srt.layers.attention.fla.index import prepare_chunk_indices
from sglang.srt.layers.attention.fla.op import safe_exp
from sglang.srt.layers.attention.fla.utils import check_shared_mem


@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
# @triton.autotune(
#     configs=[
#         triton.Config({}, num_warps=num_warps, num_stages=num_stages)
#         for num_warps in [2, 4, 8]
#         for num_stages in [2, 3, 4]
#     ],
#     key=["H", "K", "V", "BT", "BK", "BV", "IS_VARLEN"],
# )
@triton.jit(do_not_specialize=["T"])
def recompute_w_u_fwd_kernel(
    k,
    v,
    beta,
    w,
    u,
    A,
    g,
    cu_seqlens,
    chunk_indices,
    T,
    H: tl.constexpr,
    Hg: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(
            chunk_indices + i_t * 2 + 1
        ).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(
            cu_seqlens + i_n + 1
        ).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T
    p_beta = tl.make_block_ptr(
        beta + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)
    )
    p_g = tl.make_block_ptr(g + (bos * H + i_h), (T,), (H,), (i_t * BT,), (BT,), (0,))
    p_A = tl.make_block_ptr(
        A + (bos * H + i_h) * BT, (T, BT), (H * BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)
    )
    b_beta = tl.load(p_beta, boundary_check=(0,))
    b_A = tl.load(p_A, boundary_check=(0, 1))
    b_g = tl.exp(tl.load(p_g, boundary_check=(0,)))

    for i_v in range(tl.cdiv(V, BV)):
        p_v = tl.make_block_ptr(
            v + (bos * H + i_h) * V,
            (T, V),
            (H * V, 1),
            (i_t * BT, i_v * BV),
            (BT, BV),
            (1, 0),
        )
        p_u = tl.make_block_ptr(
            u + (bos * H + i_h) * V,
            (T, V),
            (H * V, 1),
            (i_t * BT, i_v * BV),
            (BT, BV),
            (1, 0),
        )
        b_v = tl.load(p_v, boundary_check=(0, 1))
        b_vb = (b_v * b_beta[:, None]).to(b_v.dtype)
        b_u = tl.dot(b_A, b_vb, allow_tf32=False)
        tl.store(p_u, b_u.to(p_u.dtype.element_ty), boundary_check=(0, 1))

    for i_k in range(tl.cdiv(K, BK)):
        p_k = tl.make_block_ptr(
            k + (bos * Hg + i_h // (H // Hg)) * K,
            (T, K),
            (Hg * K, 1),
            (i_t * BT, i_k * BK),
            (BT, BK),
            (1, 0),
        )
        p_w = tl.make_block_ptr(
            w + (bos * H + i_h) * K,
            (T, K),
            (H * K, 1),
            (i_t * BT, i_k * BK),
            (BT, BK),
            (1, 0),
        )
        b_k = tl.load(p_k, boundary_check=(0, 1))
        b_kb = (b_k * b_beta[:, None] * b_g[:, None]).to(b_k.dtype)
        b_w = tl.dot(b_A, b_kb)
        tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1))
def recompute_w_u_fwd(
    k: torch.Tensor,
    v: torch.Tensor,
    beta: torch.Tensor,
    g_cumsum: torch.Tensor,
    A: torch.Tensor,
    cu_seqlens: Optional[torch.LongTensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
    B, T, Hg, K, V = *k.shape, v.shape[-1]
    H = v.shape[-2]
    BT = A.shape[-1]

    chunk_indices = (
        prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
    )
    NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
    BK = 64
    BV = 64
    u = torch.empty_like(v)
    w = k.new_empty(B, T, H, K)
    recompute_w_u_fwd_kernel[(NT, B * H)](
        k=k,
        v=v,
        beta=beta,
        w=w,
        u=u,
        A=A,
        g=g_cumsum,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        H=H,
        Hg=Hg,
        K=K,
        V=V,
        BT=BT,
        BK=BK,
        BV=BV,
        num_warps=4,
        num_stages=3,
    )
    return w, u


fwd_recompute_w_u = recompute_w_u_fwd
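A shape-only sketch of calling recompute_w_u_fwd with grouped key heads (Hg key heads shared across H value heads). Every dimension and the random inputs below are illustrative assumptions, not values from the commit.

import torch

B, T, Hg, H, K, V, BT = 1, 256, 2, 8, 128, 128, 64
k = torch.randn(B, T, Hg, K, device="cuda", dtype=torch.bfloat16)
v = torch.randn(B, T, H, V, device="cuda", dtype=torch.bfloat16)
beta = torch.rand(B, T, H, device="cuda", dtype=torch.bfloat16)
g_cumsum = torch.zeros(B, T, H, device="cuda", dtype=torch.float32)
A = torch.randn(B, T, H, BT, device="cuda", dtype=torch.bfloat16)
w, u = recompute_w_u_fwd(k, v, beta, g_cumsum, A, cu_seqlens=None)
# w: [B, T, H, K], u: [B, T, H, V]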