zhaoyu6/sglang, commit dc491b39 (unverified)

add flash linear attention triton kernel (#10239)

Authored by Yi Zhang, Sep 11, 2025; committed via GitHub, Sep 10, 2025
Parent: 5b64f006

Showing 14 changed files with 3590 additions and 0 deletions
python/sglang/srt/layers/attention/fla/chunk.py                            +242  -0
python/sglang/srt/layers/attention/fla/chunk_delta_h.py                    +314  -0
python/sglang/srt/layers/attention/fla/chunk_o.py                          +178  -0
python/sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py             +151  -0
python/sglang/srt/layers/attention/fla/cumsum.py                           +300  -0
python/sglang/srt/layers/attention/fla/fused_recurrent.py                  +640  -0
python/sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py   +232  -0
python/sglang/srt/layers/attention/fla/index.py                            +37   -0
python/sglang/srt/layers/attention/fla/l2norm.py                           +150  -0
python/sglang/srt/layers/attention/fla/layernorm_gated.py                  +326  -0
python/sglang/srt/layers/attention/fla/op.py                               +66   -0
python/sglang/srt/layers/attention/fla/solve_tril.py                       +465  -0
python/sglang/srt/layers/attention/fla/utils.py                            +331  -0
python/sglang/srt/layers/attention/fla/wy_fast.py                          +158  -0
python/sglang/srt/layers/attention/fla/chunk.py (new file, mode 100644)
# Adapted from https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/gated_delta_rule/chunk.py
# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

import warnings
from typing import Optional

import torch
from einops import rearrange

from sglang.srt.layers.attention.fla.chunk_delta_h import chunk_gated_delta_rule_fwd_h
from sglang.srt.layers.attention.fla.chunk_o import chunk_fwd_o
from sglang.srt.layers.attention.fla.chunk_scaled_dot_kkt import (
    chunk_scaled_dot_kkt_fwd,
)
from sglang.srt.layers.attention.fla.cumsum import chunk_local_cumsum
from sglang.srt.layers.attention.fla.l2norm import l2norm_fwd
from sglang.srt.layers.attention.fla.solve_tril import solve_tril
from sglang.srt.layers.attention.fla.utils import (
    SUPPRESS_LEVEL,
    autocast_custom_fwd,
    input_guard,
)
from sglang.srt.layers.attention.fla.wy_fast import recompute_w_u_fwd


def chunk_gated_delta_rule_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    beta: torch.Tensor,
    scale: float,
    initial_state: torch.Tensor,
    output_final_state: bool,
    cu_seqlens: Optional[torch.LongTensor] = None,
):
    g = chunk_local_cumsum(g, chunk_size=64, cu_seqlens=cu_seqlens)
    # obtain WY representation. u is actually the new v.
    A = chunk_scaled_dot_kkt_fwd(
        k=k, beta=beta, g_cumsum=g, cu_seqlens=cu_seqlens, output_dtype=torch.float32
    )
    A = solve_tril(A=A, cu_seqlens=cu_seqlens, output_dtype=k.dtype)
    w, u = recompute_w_u_fwd(
        k=k,
        v=v,
        beta=beta,
        A=A,
        g_cumsum=g,
        cu_seqlens=cu_seqlens,
    )
    h, v_new, final_state = chunk_gated_delta_rule_fwd_h(
        k=k,
        w=w,
        u=u,
        g=g,
        initial_state=initial_state,
        output_final_state=output_final_state,
        cu_seqlens=cu_seqlens,
    )
    o = chunk_fwd_o(
        q=q,
        k=k,
        v=v_new,
        h=h,
        g=g,
        scale=scale,
        cu_seqlens=cu_seqlens,
    )
    if SUPPRESS_LEVEL < 3:
        return g, o, A, final_state, None, None, None
    elif SUPPRESS_LEVEL >= 3:
        return g, o, A, final_state, w, h, v_new


class ChunkGatedDeltaRuleFunction(torch.autograd.Function):

    @staticmethod
    @input_guard
    @autocast_custom_fwd
    def forward(
        ctx,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        g: torch.Tensor,
        beta: torch.Tensor,
        scale: float,
        initial_state: torch.Tensor,
        output_final_state: bool,
        cu_seqlens: Optional[torch.LongTensor] = None,
        use_qk_l2norm_in_kernel: bool = False,
    ):
        q_orig = q
        k_orig = k

        if use_qk_l2norm_in_kernel:
            q = l2norm_fwd(q)
            k = l2norm_fwd(k)

        g, o, A, final_state, w, h, v_new = chunk_gated_delta_rule_fwd(
            q=q,
            k=k,
            v=v,
            g=g,
            beta=beta,
            scale=scale,
            initial_state=initial_state,
            output_final_state=output_final_state,
            cu_seqlens=cu_seqlens,
        )
        return o.to(q.dtype), final_state


@torch.compiler.disable
def chunk_gated_delta_rule(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    beta: torch.Tensor,
    scale: float = None,
    initial_state: torch.Tensor = None,
    output_final_state: bool = False,
    cu_seqlens: Optional[torch.LongTensor] = None,
    head_first: bool = False,
    use_qk_l2norm_in_kernel: bool = False,
):
    r"""
    Args:
        q (torch.Tensor):
            queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
        k (torch.Tensor):
            keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
        v (torch.Tensor):
            values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
        g (torch.Tensor):
            (forget) gating tensor (in log space!) of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`.
        beta (torch.Tensor):
            betas of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`.
        scale (Optional[int]):
            Scale factor for the RetNet attention scores.
            If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
        initial_state (Optional[torch.Tensor]):
            Initial state of shape `[N, H, K, V]` for `N` input sequences.
            For equal-length input sequences, `N` equals the batch size `B`.
            Default: `None`.
        output_final_state (Optional[bool]):
            Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
        cu_seqlens (torch.LongTensor):
            Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
            consistent with the FlashAttention API.
        head_first (Optional[bool]):
            Whether the inputs are in the head-first format, which is not supported for variable-length inputs.
            Default: `False`.

    Returns:
        o (torch.Tensor):
            Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
        final_state (torch.Tensor):
            Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`.

    Examples::
        >>> import torch
        >>> import torch.nn.functional as F
        >>> from einops import rearrange
        >>> from fla.ops.gated_delta_rule import chunk_gated_delta_rule
        # inputs with equal lengths
        >>> B, T, H, K, V = 4, 2048, 4, 512, 512
        >>> q = torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda')
        >>> k = F.normalize(torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda'), p=2, dim=-1)
        >>> v = torch.randn(B, T, H, V, dtype=torch.bfloat16, device='cuda')
        >>> beta = torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda').sigmoid()
        >>> g = F.logsigmoid(torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda'))
        >>> h0 = torch.randn(B, H, K, V, dtype=torch.bfloat16, device='cuda')
        >>> o, ht = chunk_gated_delta_rule(
            q, k, v, g, beta,
            initial_state=h0,
            output_final_state=True
        )
        # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
        >>> q, k, v, beta, g = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, beta, g))
        # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
        >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
        >>> o_var, ht_var = chunk_gated_delta_rule(
            q, k, v, g, beta,
            initial_state=h0,
            output_final_state=True,
            cu_seqlens=cu_seqlens
        )
    """
    assert q.dtype == k.dtype == v.dtype
    assert (
        q.dtype != torch.float32
    ), "ChunkGatedDeltaRuleFunction does not support float32. Please use bfloat16."
    assert (
        len(beta.shape) == 3
    ), "beta must be of shape [B, T, H] if head_first=False, or [B, H, T] otherwise."

    if head_first:
        raise DeprecationWarning(
            "head_first is deprecated and will be removed in a future version. "
            "Please use head_first=False for now instead."
        )
        q, k, v, beta, g = map(
            lambda x: rearrange(x, "b h t ... -> b t h ..."), (q, k, v, beta, g)
        )
    # if not head_first and q.shape[1] < q.shape[2]:
    #     warnings.warn(
    #         f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
    #         "This may indicate the inputs were passed in head-first format [B, H, T, ...] "
    #         "when head_first=False was specified. "
    #         "Please verify your input tensor format matches the expected shape [B, T, H, ...]."
    #     )
    if cu_seqlens is not None:
        if q.shape[0] != 1:
            raise ValueError(
                f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
                f"Please flatten variable-length inputs before processing."
            )
        if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
            raise ValueError(
                f"The number of initial states is expected to be equal to the number of input sequences, "
                f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
            )
    if scale is None:
        scale = k.shape[-1] ** -0.5
    o, final_state = ChunkGatedDeltaRuleFunction.apply(
        q,
        k,
        v,
        g,
        beta,
        scale,
        initial_state,
        output_final_state,
        cu_seqlens,
        use_qk_l2norm_in_kernel,
    )
    if head_first:
        o = rearrange(o, "b t h ... -> b h t ...")
    return o, final_state
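For reference, the chunked pipeline above computes (up to numerical differences) the same result as the following naive per-timestep recurrence. This is a minimal single-head sketch written for this page; gated_delta_rule_ref is not part of the commit. Each step decays the [K, V] state by exp(g_t), applies a beta-scaled delta-rule correction, and reads the state out with the scaled query:

import torch

def gated_delta_rule_ref(q, k, v, g, beta, scale=None):
    """q, k: [T, K]; v: [T, V]; g, beta: [T]. Returns o: [T, V] and the final state h: [K, V]."""
    T, K = q.shape
    if scale is None:
        scale = K ** -0.5
    h = q.new_zeros(K, v.shape[-1], dtype=torch.float32)
    o = torch.empty_like(v, dtype=torch.float32)
    for t in range(T):
        h = h * g[t].float().exp()                                     # gated decay of the [K, V] state
        v_new = beta[t].float() * (v[t].float() - h.T @ k[t].float())  # delta-rule correction
        h = h + torch.outer(k[t].float(), v_new)                       # rank-1 state update
        o[t] = h.T @ (q[t].float() * scale)                            # read out with the scaled query
    return o, h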
python/sglang/srt/layers/attention/fla/chunk_delta_h.py (new file, mode 100644)
# Adapted from https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/common/chunk_delta_h.py
# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

from typing import Optional, Tuple

import torch
import triton
import triton.language as tl

from sglang.srt.layers.attention.fla.index import (
    prepare_chunk_indices,
    prepare_chunk_offsets,
)
from sglang.srt.layers.attention.fla.op import exp, safe_exp
from sglang.srt.layers.attention.fla.utils import is_nvidia_hopper

NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8, 16]


@triton.heuristics(
    {
        "USE_G": lambda args: args["g"] is not None,
        "USE_INITIAL_STATE": lambda args: args["h0"] is not None,
        "STORE_FINAL_STATE": lambda args: args["ht"] is not None,
        "SAVE_NEW_VALUE": lambda args: args["v_new"] is not None,
        "IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
    }
)
# @triton.autotune(
#     configs=[
#         triton.Config({"BV": BV}, num_warps=num_warps, num_stages=num_stages)
#         for num_warps in [2, 4]
#         for num_stages in [2, 3, 4]
#         for BV in [32, 64]
#     ],
#     key=["H", "K", "V", "BT", "USE_G"],
#     use_cuda_graph=use_cuda_graph,
# )
@triton.jit(do_not_specialize=["T"])
def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
    k,
    v,
    w,
    v_new,
    g,
    h,
    h0,
    ht,
    cu_seqlens,
    chunk_offsets,
    T,
    H: tl.constexpr,
    Hg: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BV: tl.constexpr,
    USE_G: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,
    STORE_FINAL_STATE: tl.constexpr,
    SAVE_NEW_VALUE: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    i_v, i_nh = tl.program_id(0), tl.program_id(1)
    i_n, i_h = i_nh // H, i_nh % H
    if IS_VARLEN:
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(
            cu_seqlens + i_n + 1
        ).to(tl.int32)
        T = eos - bos
        NT = tl.cdiv(T, BT)
        boh = tl.load(chunk_offsets + i_n).to(tl.int32)
    else:
        bos, eos = i_n * T, i_n * T + T
        NT = tl.cdiv(T, BT)
        boh = i_n * NT

    # [BK, BV]
    b_h1 = tl.zeros([64, BV], dtype=tl.float32)
    if K > 64:
        b_h2 = tl.zeros([64, BV], dtype=tl.float32)
    if K > 128:
        b_h3 = tl.zeros([64, BV], dtype=tl.float32)
    if K > 192:
        b_h4 = tl.zeros([64, BV], dtype=tl.float32)

    # calculate offset
    h += (boh * H + i_h) * K * V
    v += (bos * H + i_h) * V
    k += (bos * Hg + i_h // (H // Hg)) * K
    w += (bos * H + i_h) * K
    if SAVE_NEW_VALUE:
        v_new += (bos * H + i_h) * V
    stride_v = H * V
    stride_h = H * K * V
    stride_k = Hg * K
    stride_w = H * K
    if USE_INITIAL_STATE:
        h0 = h0 + i_nh * K * V
    if STORE_FINAL_STATE:
        ht = ht + i_nh * K * V

    # load initial state
    if USE_INITIAL_STATE:
        p_h0_1 = tl.make_block_ptr(h0, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0))
        b_h1 += tl.load(p_h0_1, boundary_check=(0, 1)).to(tl.float32)
        if K > 64:
            p_h0_2 = tl.make_block_ptr(
                h0, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0)
            )
            b_h2 += tl.load(p_h0_2, boundary_check=(0, 1)).to(tl.float32)
        if K > 128:
            p_h0_3 = tl.make_block_ptr(
                h0, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0)
            )
            b_h3 += tl.load(p_h0_3, boundary_check=(0, 1)).to(tl.float32)
        if K > 192:
            p_h0_4 = tl.make_block_ptr(
                h0, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0)
            )
            b_h4 += tl.load(p_h0_4, boundary_check=(0, 1)).to(tl.float32)

    # main recurrence
    for i_t in range(NT):
        p_h1 = tl.make_block_ptr(
            h + i_t * stride_h, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0)
        )
        tl.store(p_h1, b_h1.to(p_h1.dtype.element_ty), boundary_check=(0, 1))
        if K > 64:
            p_h2 = tl.make_block_ptr(
                h + i_t * stride_h, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0)
            )
            tl.store(p_h2, b_h2.to(p_h2.dtype.element_ty), boundary_check=(0, 1))
        if K > 128:
            p_h3 = tl.make_block_ptr(
                h + i_t * stride_h, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0)
            )
            tl.store(p_h3, b_h3.to(p_h3.dtype.element_ty), boundary_check=(0, 1))
        if K > 192:
            p_h4 = tl.make_block_ptr(
                h + i_t * stride_h, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0)
            )
            tl.store(p_h4, b_h4.to(p_h4.dtype.element_ty), boundary_check=(0, 1))

        p_v = tl.make_block_ptr(
            v, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
        )
        p_v_new = (
            tl.make_block_ptr(
                v_new, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
            )
            if SAVE_NEW_VALUE
            else None
        )
        b_v_new = tl.zeros([BT, BV], dtype=tl.float32)
        p_w = tl.make_block_ptr(
            w, (T, K), (stride_w, 1), (i_t * BT, 0), (BT, 64), (1, 0)
        )
        b_w = tl.load(p_w, boundary_check=(0, 1))
        b_v_new += tl.dot(b_w, b_h1.to(b_w.dtype))
        if K > 64:
            p_w = tl.make_block_ptr(
                w, (T, K), (stride_w, 1), (i_t * BT, 64), (BT, 64), (1, 0)
            )
            b_w = tl.load(p_w, boundary_check=(0, 1))
            b_v_new += tl.dot(b_w, b_h2.to(b_w.dtype))
        if K > 128:
            p_w = tl.make_block_ptr(
                w, (T, K), (stride_w, 1), (i_t * BT, 128), (BT, 64), (1, 0)
            )
            b_w = tl.load(p_w, boundary_check=(0, 1))
            b_v_new += tl.dot(b_w, b_h3.to(b_w.dtype))
        if K > 192:
            p_w = tl.make_block_ptr(
                w, (T, K), (stride_w, 1), (i_t * BT, 192), (BT, 64), (1, 0)
            )
            b_w = tl.load(p_w, boundary_check=(0, 1))
            b_v_new += tl.dot(b_w, b_h4.to(b_w.dtype))
        b_v_new = -b_v_new + tl.load(p_v, boundary_check=(0, 1))

        if SAVE_NEW_VALUE:
            p_v_new = tl.make_block_ptr(
                v_new, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
            )
            tl.store(
                p_v_new, b_v_new.to(p_v_new.dtype.element_ty), boundary_check=(0, 1)
            )

        if USE_G:
            last_idx = min((i_t + 1) * BT, T) - 1
            b_g_last = tl.load(g + bos * H + last_idx * H + i_h)
            p_g = tl.make_block_ptr(
                g + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)
            )
            b_g = tl.load(p_g, boundary_check=(0,))
            b_v_new = b_v_new * safe_exp(b_g_last - b_g)[:, None]
            b_g_last = exp(b_g_last)
            b_h1 = b_h1 * b_g_last
            if K > 64:
                b_h2 = b_h2 * b_g_last
            if K > 128:
                b_h3 = b_h3 * b_g_last
            if K > 192:
                b_h4 = b_h4 * b_g_last

        b_v_new = b_v_new.to(k.dtype.element_ty)
        p_k = tl.make_block_ptr(
            k, (K, T), (1, stride_k), (0, i_t * BT), (64, BT), (0, 1)
        )
        b_k = tl.load(p_k, boundary_check=(0, 1))
        b_h1 += tl.dot(b_k, b_v_new)
        if K > 64:
            p_k = tl.make_block_ptr(
                k, (K, T), (1, stride_k), (64, i_t * BT), (64, BT), (0, 1)
            )
            b_k = tl.load(p_k, boundary_check=(0, 1))
            b_h2 += tl.dot(b_k, b_v_new)
        if K > 128:
            p_k = tl.make_block_ptr(
                k, (K, T), (1, stride_k), (128, i_t * BT), (64, BT), (0, 1)
            )
            b_k = tl.load(p_k, boundary_check=(0, 1))
            b_h3 += tl.dot(b_k, b_v_new)
        if K > 192:
            p_k = tl.make_block_ptr(
                k, (K, T), (1, stride_k), (192, i_t * BT), (64, BT), (0, 1)
            )
            b_k = tl.load(p_k, boundary_check=(0, 1))
            b_h4 += tl.dot(b_k, b_v_new)

    # epilogue
    if STORE_FINAL_STATE:
        p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0))
        tl.store(p_ht, b_h1.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
        if K > 64:
            p_ht = tl.make_block_ptr(
                ht, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0)
            )
            tl.store(p_ht, b_h2.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
        if K > 128:
            p_ht = tl.make_block_ptr(
                ht, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0)
            )
            tl.store(p_ht, b_h3.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
        if K > 192:
            p_ht = tl.make_block_ptr(
                ht, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0)
            )
            tl.store(p_ht, b_h4.to(p_ht.dtype.element_ty), boundary_check=(0, 1))


def chunk_gated_delta_rule_fwd_h(
    k: torch.Tensor,
    w: torch.Tensor,
    u: torch.Tensor,
    g: Optional[torch.Tensor] = None,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: bool = False,
    chunk_size: int = 64,  # SY: remove this argument and force chunk size 64?
    save_new_value: bool = True,
    cu_seqlens: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    B, T, Hg, K, V = *k.shape, u.shape[-1]
    H = u.shape[-2]
    BT = chunk_size
    chunk_indices = (
        prepare_chunk_indices(cu_seqlens, chunk_size)
        if cu_seqlens is not None
        else None
    )
    # N: the actual number of sequences in the batch with either equal or variable lengths
    if cu_seqlens is None:
        N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
    else:
        N, NT, chunk_offsets = (
            len(cu_seqlens) - 1,
            len(chunk_indices),
            prepare_chunk_offsets(cu_seqlens, BT),
        )
    assert K <= 256, "current kernel does not support head dimension larger than 256."

    h = k.new_empty(B, NT, H, K, V)
    final_state = (
        k.new_empty(N, H, K, V, dtype=torch.float32) if output_final_state else None
    )
    v_new = torch.empty_like(u) if save_new_value else None

    def grid(meta):
        return (triton.cdiv(V, meta["BV"]), N * H)

    chunk_gated_delta_rule_fwd_kernel_h_blockdim64[grid](
        k=k,
        v=u,
        w=w,
        v_new=v_new,
        g=g,
        h=h,
        h0=initial_state,
        ht=final_state,
        cu_seqlens=cu_seqlens,
        chunk_offsets=chunk_offsets,
        T=T,
        H=H,
        Hg=Hg,
        K=K,
        V=V,
        BT=BT,
        BV=32,
        num_warps=4,
        num_stages=2,
    )
    return h, v_new, final_state
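A shape-level reference for what the kernel above computes per head, assuming a single equal-length sequence and K <= 64 so that only b_h1 is live. chunk_fwd_h_ref is a hypothetical helper written for this page, not part of the commit:

import torch

def chunk_fwd_h_ref(k, w, u, g, h0, BT=64):
    """k, w: [T, K]; u: [T, V]; g: [T] chunk-local log-decay cumsum; h0: [K, V]."""
    T = k.shape[0]
    NT = (T + BT - 1) // BT
    h = h0.clone().float()
    hs = []
    v_new = torch.empty_like(u, dtype=torch.float32)
    for i in range(NT):
        s = slice(i * BT, min((i + 1) * BT, T))
        hs.append(h.clone())                   # state exposed to chunk i (what the kernel stores into h)
        v_c = u[s].float() - w[s].float() @ h  # b_v_new = -(w @ h) + u
        v_new[s] = v_c                         # stored before the decay rescaling, as in the kernel
        g_c, g_last = g[s].float(), g[s][-1].float()
        # decay the state to the chunk end and absorb the corrected values
        h = h * g_last.exp() + k[s].float().T @ (v_c * (g_last - g_c).exp()[:, None])
    return torch.stack(hs), v_new, h           # per-chunk states, new values, final state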
python/sglang/srt/layers/attention/fla/chunk_o.py (new file, mode 100644)
# Adapted from https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/common/chunk_o.py
# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

from typing import Optional, Tuple

import torch
import triton
import triton.language as tl

from sglang.srt.layers.attention.fla.index import prepare_chunk_indices
from sglang.srt.layers.attention.fla.op import exp, safe_exp
from sglang.srt.layers.attention.fla.utils import check_shared_mem, is_nvidia_hopper

BKV_LIST = [64, 128] if check_shared_mem() else [32, 64]
NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8]


@triton.heuristics(
    {
        "USE_G": lambda args: args["g"] is not None,
        "IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
    }
)
# @triton.autotune(
#     configs=[
#         triton.Config({"BK": BK, "BV": BV}, num_warps=num_warps, num_stages=num_stages)
#         for BK in BKV_LIST
#         for BV in BKV_LIST
#         for num_warps in NUM_WARPS
#         for num_stages in [2, 3, 4]
#     ],
#     key=["H", "K", "V", "BT"],
# )
@triton.jit(do_not_specialize=["T"])
def chunk_fwd_kernel_o(
    q,
    k,
    v,
    h,
    g,
    o,
    cu_seqlens,
    chunk_indices,
    scale,
    T,
    H: tl.constexpr,
    Hg: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_G: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        i_tg = i_t
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(
            chunk_indices + i_t * 2 + 1
        ).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(
            cu_seqlens + i_n + 1
        ).to(tl.int32)
        T = eos - bos
        NT = tl.cdiv(T, BT)
    else:
        NT = tl.cdiv(T, BT)
        i_tg = i_b * NT + i_t
        bos, eos = i_b * T, i_b * T + T

    # offset calculation
    q += (bos * Hg + i_h // (H // Hg)) * K
    k += (bos * Hg + i_h // (H // Hg)) * K
    v += (bos * H + i_h) * V
    o += (bos * H + i_h) * V
    h += (i_tg * H + i_h).to(tl.int64) * K * V

    b_o = tl.zeros([BT, BV], dtype=tl.float32)
    b_A = tl.zeros([BT, BT], dtype=tl.float32)
    for i_k in range(tl.cdiv(K, BK)):
        p_q = tl.make_block_ptr(
            q, (T, K), (Hg * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)
        )
        p_k = tl.make_block_ptr(
            k, (K, T), (1, Hg * K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)
        )
        p_h = tl.make_block_ptr(
            h, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)
        )
        # [BT, BK]
        b_q = tl.load(p_q, boundary_check=(0, 1))
        # [BK, BT]
        b_k = tl.load(p_k, boundary_check=(0, 1))
        # [BK, BV]
        b_h = tl.load(p_h, boundary_check=(0, 1))
        # [BT, BK] @ [BK, BV] -> [BT, BV]
        b_o += tl.dot(b_q, b_h)
        # [BT, BK] @ [BK, BT] -> [BT, BT]
        b_A += tl.dot(b_q, b_k)

    if USE_G:
        g += bos * H + i_h
        p_g = tl.make_block_ptr(g, (T,), (H,), (i_t * BT,), (BT,), (0,))
        b_g = tl.load(p_g, boundary_check=(0,))
        b_o = b_o * exp(b_g)[:, None]
        b_A = b_A * safe_exp(b_g[:, None] - b_g[None, :])

    o_i = tl.arange(0, BT)
    m_A = o_i[:, None] >= o_i[None, :]
    b_A = tl.where(m_A, b_A, 0)

    p_v = tl.make_block_ptr(
        v, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
    )
    p_o = tl.make_block_ptr(
        o, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
    )
    b_v = tl.load(p_v, boundary_check=(0, 1))

    # to fix mma -> mma layout conversion
    # already solved by triton v3.2 or higher
    b_o = b_o * scale + tl.dot(b_A.to(b_v.dtype), b_v) * scale
    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))


def chunk_fwd_o(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    h: torch.Tensor,
    g: Optional[torch.Tensor] = None,  # cumsum of log decay
    scale: Optional[float] = None,
    cu_seqlens: Optional[torch.LongTensor] = None,
    chunk_size: int = 64,
) -> torch.Tensor:
    B, T, Hg, K, V = *q.shape, v.shape[-1]
    H = v.shape[-2]
    BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
    chunk_indices = (
        prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
    )
    NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
    if scale is None:
        scale = k.shape[-1] ** -0.5
    o = torch.empty_like(v)

    def grid(meta):
        return (triton.cdiv(V, meta["BV"]), NT, B * H)

    chunk_fwd_kernel_o[grid](
        q,
        k,
        v,
        h,
        g,
        o,
        cu_seqlens,
        chunk_indices,
        scale,
        T=T,
        H=H,
        Hg=Hg,
        K=K,
        V=V,
        BT=BT,
        BK=128,
        BV=64,
        num_warps=4,
        num_stages=2,
    )
    return o
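A per-chunk reference for the output kernel, for one head and one chunk. This is a sketch written for this page, assuming g holds the chunk-local log-decay cumsum as elsewhere in this commit; chunk_o_ref is not part of the commit:

import torch

def chunk_o_ref(q, k, v, h, g, scale):
    """q, k: [BT, K]; v: [BT, V] (the new values); h: [K, V] state entering the chunk; g: [BT]."""
    A = (q.float() @ k.float().T) * (g[:, None] - g[None, :]).float().exp()
    A = A.tril()                                                 # causal mask within the chunk (i >= j)
    inter = (q.float() @ h.float()) * g.float().exp()[:, None]   # contribution from earlier chunks
    return (inter + A @ v.float()) * scale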
python/sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py (new file, mode 100644)
# Adapted from https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/common/chunk_scaled_dot_kkt.py
# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

from typing import Optional

import torch
import triton
import triton.language as tl

from sglang.srt.layers.attention.fla.index import prepare_chunk_indices
from sglang.srt.layers.attention.fla.op import safe_exp


@triton.heuristics(
    {
        "IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
        "USE_G": lambda args: args["g_cumsum"] is not None,
    }
)
# @triton.autotune(
#     configs=[
#         triton.Config({"BK": BK}, num_warps=num_warps, num_stages=num_stages)
#         for BK in [32, 64, 128]
#         for num_warps in [2, 4, 8]
#         for num_stages in [2, 3, 4]
#     ],
#     key=["H", "K", "BT", "IS_VARLEN"],
# )
@triton.jit(do_not_specialize=["T"])
def chunk_scaled_dot_kkt_fwd_kernel(
    k,
    beta,
    g_cumsum,
    A,
    cu_seqlens,
    chunk_indices,
    T,
    H: tl.constexpr,
    Hg: tl.constexpr,
    K: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    USE_G: tl.constexpr,
):
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(
            chunk_indices + i_t * 2 + 1
        ).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(
            cu_seqlens + i_n + 1
        ).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    o_t = tl.arange(0, BT)
    p_beta = tl.make_block_ptr(
        beta + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)
    )
    b_beta = tl.load(p_beta, boundary_check=(0,))

    b_A = tl.zeros([BT, BT], dtype=tl.float32)
    for i_k in range(tl.cdiv(K, BK)):
        p_k = tl.make_block_ptr(
            k + (bos * Hg + i_h // (H // Hg)) * K,
            (T, K),
            (Hg * K, 1),
            (i_t * BT, i_k * BK),
            (BT, BK),
            (1, 0),
        )
        b_k = tl.load(p_k, boundary_check=(0, 1))
        b_kb = b_k * b_beta[:, None]
        b_A += tl.dot(b_kb.to(b_k.dtype), tl.trans(b_k))

    if USE_G:
        p_g = tl.make_block_ptr(
            g_cumsum + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)
        )
        b_g = tl.load(p_g, boundary_check=(0,))
        b_g_diff = b_g[:, None] - b_g[None, :]
        b_A = b_A * safe_exp(b_g_diff)

    b_A = tl.where(o_t[:, None] > o_t[None, :], b_A, 0)
    p_A = tl.make_block_ptr(
        A + (bos * H + i_h) * BT, (T, BT), (BT * H, 1), (i_t * BT, 0), (BT, BT), (1, 0)
    )
    tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1))


def chunk_scaled_dot_kkt_fwd(
    k: torch.Tensor,
    beta: torch.Tensor,
    g_cumsum: Optional[torch.Tensor] = None,
    cu_seqlens: Optional[torch.LongTensor] = None,
    chunk_size: int = 64,
    output_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    r"""
    Compute beta * K * K^T.

    Args:
        k (torch.Tensor):
            The key tensor of shape `[B, T, H, K]`.
        beta (torch.Tensor):
            The beta tensor of shape `[B, T, H]`.
        g_cumsum (torch.Tensor):
            The cumulative sum of the gate tensor of shape `[B, T, H]`.
            Default: None
        cu_seqlens (torch.LongTensor):
            The cumulative sequence lengths of the input tensor.
            Default: None
        chunk_size (int):
            The chunk size. Default: 64.
        output_dtype (torch.dtype):
            The dtype of the output tensor. Default: `torch.float32`

    Returns:
        beta * K * K^T of shape `[B, T, H, BT]` where `BT` is the chunk size.
    """
    B, T, Hg, K = k.shape
    H = beta.shape[-1]
    BT = chunk_size
    chunk_indices = (
        prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
    )
    NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
    A = torch.empty(B, T, H, BT, device=k.device, dtype=output_dtype)
    chunk_scaled_dot_kkt_fwd_kernel[(NT, B * H)](
        k=k,
        beta=beta,
        g_cumsum=g_cumsum,
        A=A,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        H=H,
        Hg=Hg,
        K=K,
        BT=BT,
        BK=64,
        num_warps=8,
        num_stages=3,
    )
    return A
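A single-chunk reference for the kernel above, written for this page as a sketch. The strict tril(-1) matches the kernel's `o_t[:, None] > o_t[None, :]` mask:

import torch

def scaled_dot_kkt_ref(k, beta, g_cumsum=None):
    """k: [BT, K]; beta, g_cumsum: [BT]. Returns the strictly lower-triangular A: [BT, BT]."""
    A = (k.float() * beta.float()[:, None]) @ k.float().T
    if g_cumsum is not None:
        A = A * (g_cumsum.float()[:, None] - g_cumsum.float()[None, :]).exp()
    return A.tril(-1)  # keep i > j only, matching the kernel's mask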
python/sglang/srt/layers/attention/fla/cumsum.py (new file, mode 100644)
# Adapt from https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/utils/cumsum.py
# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

from typing import Optional

import torch
import triton
import triton.language as tl

from sglang.srt.layers.attention.fla.index import prepare_chunk_indices
from sglang.srt.layers.attention.fla.utils import check_shared_mem, input_guard

BS_LIST = [32, 64] if check_shared_mem() else [16, 32]


@triton.heuristics(
    {
        "HAS_SCALE": lambda args: args["scale"] is not None,
        "IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
    }
)
# @triton.autotune(
#     configs=[triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8]],
#     key=["B", "H", "BT", "IS_VARLEN", "REVERSE"],
# )
@triton.jit(do_not_specialize=["T"])
def chunk_local_cumsum_scalar_kernel(
    s,
    o,
    scale,
    cu_seqlens,
    chunk_indices,
    T,
    B: tl.constexpr,
    H: tl.constexpr,
    BT: tl.constexpr,
    REVERSE: tl.constexpr,
    HAS_SCALE: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    HEAD_FIRST: tl.constexpr,
):
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(
            chunk_indices + i_t * 2 + 1
        ).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(
            cu_seqlens + i_n + 1
        ).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T
    if HEAD_FIRST:
        p_s = tl.make_block_ptr(
            s + bos * H + i_h * T, (T,), (1,), (i_t * BT,), (BT,), (0,)
        )
        p_o = tl.make_block_ptr(
            o + bos * H + i_h * T, (T,), (1,), (i_t * BT,), (BT,), (0,)
        )
    else:
        p_s = tl.make_block_ptr(s + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
        p_o = tl.make_block_ptr(o + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
    # [BT]
    b_s = tl.load(p_s, boundary_check=(0,)).to(tl.float32)
    b_o = tl.cumsum(b_s, axis=0)
    if REVERSE:
        b_z = tl.sum(b_s, axis=0)
        b_o = -b_o + b_z[None] + b_s
    if HAS_SCALE:
        b_o *= scale
    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0,))


@triton.heuristics(
    {
        "HAS_SCALE": lambda args: args["scale"] is not None,
        "IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
    }
)
@triton.autotune(
    configs=[
        triton.Config({"BS": BS}, num_warps=num_warps)
        for BS in BS_LIST
        for num_warps in [2, 4, 8]
    ],
    key=["B", "H", "S", "BT", "IS_VARLEN", "REVERSE"],
)
@triton.jit(do_not_specialize=["T"])
def chunk_local_cumsum_vector_kernel(
    s,
    o,
    scale,
    cu_seqlens,
    chunk_indices,
    T,
    B: tl.constexpr,
    H: tl.constexpr,
    S: tl.constexpr,
    BT: tl.constexpr,
    BS: tl.constexpr,
    REVERSE: tl.constexpr,
    HAS_SCALE: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    HEAD_FIRST: tl.constexpr,
):
    i_s, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(
            chunk_indices + i_t * 2 + 1
        ).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(
            cu_seqlens + i_n + 1
        ).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T
    o_i = tl.arange(0, BT)
    if REVERSE:
        m_s = tl.where(o_i[:, None] <= o_i[None, :], 1.0, 0.0)
    else:
        m_s = tl.where(o_i[:, None] >= o_i[None, :], 1.0, 0.0)
    if HEAD_FIRST:
        p_s = tl.make_block_ptr(
            s + (bos * H + i_h * T) * S,
            (T, S),
            (S, 1),
            (i_t * BT, i_s * BS),
            (BT, BS),
            (1, 0),
        )
        p_o = tl.make_block_ptr(
            o + (bos * H + i_h * T) * S,
            (T, S),
            (S, 1),
            (i_t * BT, i_s * BS),
            (BT, BS),
            (1, 0),
        )
    else:
        p_s = tl.make_block_ptr(
            s + (bos * H + i_h) * S,
            (T, S),
            (H * S, 1),
            (i_t * BT, i_s * BS),
            (BT, BS),
            (1, 0),
        )
        p_o = tl.make_block_ptr(
            o + (bos * H + i_h) * S,
            (T, S),
            (H * S, 1),
            (i_t * BT, i_s * BS),
            (BT, BS),
            (1, 0),
        )
    # [BT, BS]
    b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32)
    b_o = tl.dot(m_s, b_s, allow_tf32=False)
    if HAS_SCALE:
        b_o *= scale
    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))


def chunk_local_cumsum_scalar(
    g: torch.Tensor,
    chunk_size: int,
    reverse: bool = False,
    scale: float = None,
    cu_seqlens: Optional[torch.Tensor] = None,
    head_first: bool = False,
    output_dtype: Optional[torch.dtype] = torch.float,
) -> torch.Tensor:
    if head_first:
        B, H, T = g.shape
    else:
        B, T, H = g.shape
    assert chunk_size == 2 ** (
        chunk_size.bit_length() - 1
    ), "chunk_size must be a power of 2"
    BT = chunk_size
    chunk_indices = (
        prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
    )
    NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
    g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype)
    grid = (NT, B * H)
    chunk_local_cumsum_scalar_kernel[grid](
        s=g_org,
        o=g,
        scale=scale,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        B=B,
        H=H,
        BT=BT,
        HEAD_FIRST=head_first,
        REVERSE=reverse,
        num_warps=8,
        num_stages=3,
    )
    return g


def chunk_local_cumsum_vector(
    g: torch.Tensor,
    chunk_size: int,
    reverse: bool = False,
    scale: float = None,
    cu_seqlens: Optional[torch.Tensor] = None,
    head_first: bool = False,
    output_dtype: Optional[torch.dtype] = torch.float,
) -> torch.Tensor:
    if head_first:
        B, H, T, S = g.shape
    else:
        B, T, H, S = g.shape
    BT = chunk_size
    chunk_indices = (
        prepare_chunk_indices(cu_seqlens, chunk_size)
        if cu_seqlens is not None
        else None
    )
    NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
    assert chunk_size == 2 ** (
        chunk_size.bit_length() - 1
    ), "chunk_size must be a power of 2"
    g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype)

    def grid(meta):
        return (triton.cdiv(meta["S"], meta["BS"]), NT, B * H)

    # keep cumulative normalizer in fp32
    # this kernel is equivalent to
    # g = g.view(B, H, NT, BT, -1).cumsum(-2).view(B, H, T, -1)
    chunk_local_cumsum_vector_kernel[grid](
        s=g_org,
        o=g,
        scale=scale,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        B=B,
        H=H,
        S=S,
        BT=BT,
        HEAD_FIRST=head_first,
        REVERSE=reverse,
    )
    return g


@input_guard
def chunk_local_cumsum(
    g: torch.Tensor,
    chunk_size: int,
    reverse: bool = False,
    scale: float = None,
    cu_seqlens: Optional[torch.Tensor] = None,
    head_first: bool = False,
    output_dtype: Optional[torch.dtype] = torch.float,
    **kwargs,
) -> torch.Tensor:
    if cu_seqlens is not None:
        assert (
            g.shape[0] == 1
        ), "Only batch size 1 is supported when cu_seqlens are provided"
    if len(g.shape) == 3:
        return chunk_local_cumsum_scalar(
            g=g,
            chunk_size=chunk_size,
            reverse=reverse,
            scale=scale,
            cu_seqlens=cu_seqlens,
            head_first=head_first,
            output_dtype=output_dtype,
        )
    elif len(g.shape) == 4:
        return chunk_local_cumsum_vector(
            g=g,
            chunk_size=chunk_size,
            reverse=reverse,
            scale=scale,
            cu_seqlens=cu_seqlens,
            head_first=head_first,
            output_dtype=output_dtype,
        )
    else:
        raise ValueError(
            f"Unsupported input shape {g.shape}, "
            f"which should be (B, T, H, D) if `head_first=False` "
            f"or (B, H, T, D) otherwise"
        )
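A quick equivalence sketch for the scalar path with equal-length sequences and head_first=False, mirroring the in-file comment about the view/cumsum formulation. Written for this page; requires a GPU for the Triton kernel, and agreement is up to numerical tolerance:

import torch

from sglang.srt.layers.attention.fla.cumsum import chunk_local_cumsum

B, T, H, BT = 2, 256, 4, 64
g = torch.randn(B, T, H, device="cuda")
out = chunk_local_cumsum(g, chunk_size=BT)
# the cumulative sum restarts at every 64-token chunk boundary
ref = g.view(B, T // BT, BT, H).cumsum(2).view(B, T, H)
torch.testing.assert_close(out, ref)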
python/sglang/srt/layers/attention/fla/fused_recurrent.py (new file, mode 100644)
(This diff is collapsed in the original view; its contents are not shown.)
python/sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py (new file, mode 100644)
from typing import Optional

import torch
import triton
import triton.language as tl

from sglang.srt.layers.attention.fla.utils import input_guard


@triton.heuristics(
    {
        "USE_INITIAL_STATE": lambda args: args["h0_source"] is not None,
        "IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
    }
)
@triton.jit(do_not_specialize=["T"])
def fused_sigmoid_gating_delta_rule_update_kernel(
    A_log,
    a,
    dt_bias,
    softplus_beta,
    softplus_threshold,
    q,
    k,
    v,
    b,
    o,
    h0_source,
    h0_indices,
    cu_seqlens,
    scale,
    T,
    B: tl.constexpr,
    H: tl.constexpr,
    HV: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,
    USE_QK_L2NORM_IN_KERNEL: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    """
    Fused kernel that combines sigmoid gating computation with recurrent delta rule update.
    """
    i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_n, i_hv = i_nh // HV, i_nh % HV
    i_h = i_hv // (HV // H)
    if IS_VARLEN:
        bos, eos = (
            tl.load(cu_seqlens + i_n).to(tl.int64),
            tl.load(cu_seqlens + i_n + 1).to(tl.int64),
        )
        all = T
        T = eos - bos
    else:
        bos, eos = i_n * T, i_n * T + T
        all = B * T
    o_k = i_k * BK + tl.arange(0, BK)
    o_v = i_v * BV + tl.arange(0, BV)

    p_q = q + (bos * H + i_h) * K + o_k
    p_k = k + (bos * H + i_h) * K + o_k
    p_v = v + (bos * HV + i_hv) * V + o_v
    p_b = b + bos * HV + i_hv
    p_o = o + ((i_k * all + bos) * HV + i_hv) * V + o_v

    # Gating computation pointers
    p_A_log = A_log + i_hv
    p_a = a + bos * HV + i_hv
    p_dt_bias = dt_bias + i_hv

    mask_k = o_k < K
    mask_v = o_v < V
    mask_h = mask_k[:, None] & mask_v[None, :]

    b_h = tl.zeros([BK, BV], dtype=tl.float32)
    if USE_INITIAL_STATE:
        idx = tl.load(h0_indices + i_n)
        if idx >= 0:
            p_h0 = (
                h0_source
                + idx * HV * K * V
                + i_hv * K * V
                + o_k[:, None] * V
                + o_v[None, :]
            )
            b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)

    for _ in range(0, T):
        # Load inputs
        b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32)
        b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
        b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
        b_b = tl.load(p_b).to(tl.float32)

        # Compute sigmoid gating
        # Load gating parameters
        b_A_log = tl.load(p_A_log).to(tl.float32)
        b_a = tl.load(p_a).to(tl.float32)
        b_dt_bias = tl.load(p_dt_bias).to(tl.float32)

        # Compute g = -exp(A_log) * softplus(a + dt_bias)
        x = b_a + b_dt_bias
        beta_x = softplus_beta * x
        # Apply softplus with numerical stability
        softplus_x = tl.where(
            beta_x <= softplus_threshold,
            (1.0 / softplus_beta) * tl.log(1.0 + tl.exp(beta_x)),
            x,
        )
        b_g = -tl.exp(b_A_log) * softplus_x

        # Compute beta = sigmoid(b)
        b_beta = 1.0 / (1.0 + tl.exp(-b_b))

        # Apply L2 normalization if enabled
        if USE_QK_L2NORM_IN_KERNEL:
            b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q)) + 1e-6)
            b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k)) + 1e-6)
        b_q = b_q * scale

        # Apply gating to hidden state: h *= exp(g)
        b_h *= tl.exp(b_g)
        # Delta rule: v -= sum(h * k, dim=0)
        b_v -= tl.sum(b_h * b_k[:, None], 0)
        # Apply beta gating: v *= beta
        b_v *= b_beta
        # Update hidden state: h += k[:, None] * v[None, :]
        b_h += b_k[:, None] * b_v[None, :]
        # Compute output: o = sum(h * q, dim=0)
        b_o = tl.sum(b_h * b_q[:, None], 0)
        tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)

        # Update pointers for next timestep
        p_q += H * K
        p_k += H * K
        p_o += HV * V
        p_v += HV * V
        p_b += HV
        p_a += HV

    # Store final state back to h0_source with bounds checking
    if USE_INITIAL_STATE:
        idx = tl.load(h0_indices + i_n)
        if idx >= 0:
            p_h0 = (
                h0_source
                + idx * HV * K * V
                + i_hv * K * V
                + o_k[:, None] * V
                + o_v[None, :]
            )
            tl.store(p_h0, b_h.to(p_h0.dtype.element_ty), mask=mask_h)


@input_guard
def fused_sigmoid_gating_delta_rule_update(
    A_log: torch.Tensor,
    a: torch.Tensor,
    dt_bias: torch.Tensor,
    softplus_beta: float,
    softplus_threshold: float,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    b: torch.Tensor,
    initial_state_source: torch.Tensor,
    initial_state_indices: torch.Tensor,
    scale: Optional[float] = None,
    use_qk_l2norm_in_kernel: bool = False,
    cu_seqlens: Optional[torch.Tensor] = None,
):
    """
    Fused triton implementation of sigmoid gating delta rule update.
    This function uses a single fused kernel that combines both sigmoid gating computation
    and the recurrent delta rule update for better performance.
    """
    B, T, H, K, V = *k.shape, v.shape[-1]
    HV = v.shape[2]
    N = B if cu_seqlens is None else len(cu_seqlens) - 1
    BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 8)
    NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
    assert NK == 1, "NK > 1 is not supported yet"
    num_stages = 3
    num_warps = 1
    if scale is None:
        scale = k.shape[-1] ** -0.5
    else:
        assert scale > 0, "scale must be positive"

    o = q.new_empty(NK, *v.shape)
    grid = (NK, NV, N * HV)
    fused_sigmoid_gating_delta_rule_update_kernel[grid](
        A_log=A_log,
        a=a,
        dt_bias=dt_bias,
        softplus_beta=softplus_beta,
        softplus_threshold=softplus_threshold,
        q=q,
        k=k,
        v=v,
        b=b,
        o=o,
        h0_source=initial_state_source,
        h0_indices=initial_state_indices,
        cu_seqlens=cu_seqlens,
        scale=scale,
        T=T,
        B=B,
        H=H,
        HV=HV,
        K=K,
        V=V,
        BK=BK,
        BV=BV,
        USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel,
        num_warps=num_warps,
        num_stages=num_stages,
    )
    o = o.squeeze(0)
    return o
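The gating terms the kernel computes each step, restated in plain PyTorch. This is a sketch of the scalar math only, not of the recurrence; sigmoid_gating_terms is written for this page and is not part of the commit:

import torch
import torch.nn.functional as F

def sigmoid_gating_terms(A_log, a, dt_bias, b, softplus_beta=1.0, softplus_threshold=20.0):
    # g = -exp(A_log) * softplus(a + dt_bias); F.softplus applies the same
    # linear fallback above beta * x > threshold as the kernel does.
    g = -A_log.exp() * F.softplus(a + dt_bias, beta=softplus_beta, threshold=softplus_threshold)
    beta = torch.sigmoid(b)  # per-step update strength for the delta rule
    return g, beta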
python/sglang/srt/layers/attention/fla/index.py (new file, mode 100644)
# Adapt from https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/utils/index.py
# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

import torch
import torch.nn.functional as F
import triton
import triton.language as tl

from sglang.srt.layers.attention.fla.utils import tensor_cache


@tensor_cache
def prepare_lens(cu_seqlens: torch.LongTensor) -> torch.LongTensor:
    return cu_seqlens[1:] - cu_seqlens[:-1]


@tensor_cache
def prepare_chunk_indices(
    cu_seqlens: torch.LongTensor, chunk_size: int
) -> torch.LongTensor:
    indices = torch.cat(
        [
            torch.arange(n)
            for n in triton.cdiv(prepare_lens(cu_seqlens), chunk_size).tolist()
        ]
    )
    return torch.stack([indices.eq(0).cumsum(0) - 1, indices], 1).to(cu_seqlens)


@tensor_cache
def prepare_chunk_offsets(
    cu_seqlens: torch.LongTensor, chunk_size: int
) -> torch.LongTensor:
    return torch.cat(
        [cu_seqlens.new_tensor([0]), triton.cdiv(prepare_lens(cu_seqlens), chunk_size)]
    ).cumsum(-1)
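A small example of what the indexing helpers return, for two sequences of lengths 100 and 128 with chunk_size=64 (a sketch written for this page; the expected outputs are shown in the comments):

import torch

from sglang.srt.layers.attention.fla.index import (
    prepare_chunk_indices,
    prepare_chunk_offsets,
    prepare_lens,
)

cu_seqlens = torch.tensor([0, 100, 228])      # two sequences, of lengths 100 and 128
print(prepare_lens(cu_seqlens))               # tensor([100, 128])
print(prepare_chunk_offsets(cu_seqlens, 64))  # tensor([0, 2, 4]): cumulative 64-token chunk counts
print(prepare_chunk_indices(cu_seqlens, 64))
# tensor([[0, 0],
#         [0, 1],
#         [1, 0],
#         [1, 1]])  each row is (sequence id, chunk id within that sequence)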
python/sglang/srt/layers/attention/fla/l2norm.py (new file, mode 100644)
# Adapt from https://github.com/fla-org/flash-linear-attention/blob/main/fla/modules/l2norm.py
# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

from typing import Optional

import torch
import torch.nn as nn
import triton
import triton.language as tl

from sglang.srt.layers.attention.fla.utils import input_guard

BT_LIST = [8, 16, 32, 64, 128]


# @triton.autotune(
#     configs=[
#         triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8, 16, 32]
#     ],
#     key=["D"],
# )
@triton.jit
def l2norm_fwd_kernel1(
    x,
    y,
    D,
    BD: tl.constexpr,
    eps,
):
    i_t = tl.program_id(0)
    x += i_t * D
    y += i_t * D
    # Compute mean and variance
    cols = tl.arange(0, BD)
    mask = cols < D
    b_x = tl.load(x + cols, mask=mask, other=0.0).to(tl.float32)
    b_var = tl.sum(b_x * b_x, axis=0)
    b_rstd = 1 / tl.sqrt(b_var + eps)
    # tl.store(Rstd + i_t, rstd)
    # Normalize and apply linear transformation
    b_y = b_x * b_rstd
    tl.store(y + cols, b_y, mask=mask)


# @triton.autotune(
#     configs=[
#         triton.Config({"BT": BT}, num_warps=num_warps)
#         for num_warps in [1, 2, 4, 8, 16]
#         for BT in BT_LIST
#     ],
#     key=["D", "NB"],
# )
@triton.jit
def l2norm_fwd_kernel(
    x,
    y,
    eps,
    NB: tl.constexpr,
    T: tl.constexpr,
    D: tl.constexpr,
    BT: tl.constexpr,
    BD: tl.constexpr,
):
    i_t = tl.program_id(0)
    p_x = tl.make_block_ptr(x, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0))
    b_x = tl.load(p_x, boundary_check=(0, 1)).to(tl.float32)
    b_var = tl.sum(b_x * b_x, axis=1)
    b_y = b_x / tl.sqrt(b_var + eps)[:, None]
    p_y = tl.make_block_ptr(y, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0))
    tl.store(p_y, b_y.to(p_y.dtype.element_ty), boundary_check=(0, 1))


def l2norm_fwd(
    x: torch.Tensor, eps: float = 1e-6, output_dtype: Optional[torch.dtype] = None
):
    x_shape_og = x.shape
    x = x.view(-1, x.shape[-1])
    # allocate output
    if output_dtype is None:
        y = torch.empty_like(x)
    else:
        y = torch.empty_like(x, dtype=output_dtype)
    assert y.stride(-1) == 1
    T, D = x.shape[0], x.shape[-1]
    # rstd = torch.empty((T,), dtype=torch.float32, device=x.device)
    # Less than 64KB per feature: enqueue fused kernel
    MAX_FUSED_SIZE = 65536 // x.element_size()
    BD = min(MAX_FUSED_SIZE, triton.next_power_of_2(D))
    if D > BD:
        raise RuntimeError("This layer doesn't support feature dim >= 64KB.")
    if D <= 512:
        NB = triton.cdiv(T, 2048)

        def grid(meta):
            return (triton.cdiv(T, meta["BT"]),)

        l2norm_fwd_kernel[grid](
            x,
            y,
            eps,
            NB=NB,
            T=T,
            D=D,
            BD=BD,
            BT=16,
            num_warps=8,
            num_stages=3,
        )
    else:
        l2norm_fwd_kernel1[(T,)](
            x,
            y,
            eps=eps,
            D=D,
            BD=BD,
            num_warps=8,
            num_stages=3,
        )
    return y.view(x_shape_og)


class L2NormFunction(torch.autograd.Function):

    @staticmethod
    @input_guard
    def forward(ctx, x, eps=1e-6, output_dtype=None):
        return l2norm_fwd(x, eps, output_dtype)


def l2norm(
    x: torch.Tensor, eps: float = 1e-6, output_dtype: Optional[torch.dtype] = None
) -> torch.Tensor:
    return L2NormFunction.apply(x, eps, output_dtype)


l2_norm = l2norm


class L2Norm(nn.Module):

    def __init__(self, eps: float = 1e-6, output_dtype: Optional[torch.dtype] = None):
        super().__init__()
        self.eps = eps
        self.output_dtype = output_dtype

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return l2norm(x, self.eps, self.output_dtype)
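l2norm_fwd normalizes each row by 1 / sqrt(sum(x^2) + eps), with eps inside the square root. A quick check against an eager reference, written for this page (requires a GPU for the Triton kernel):

import torch

from sglang.srt.layers.attention.fla.l2norm import l2norm_fwd

x = torch.randn(8, 512, device="cuda")
y = l2norm_fwd(x)
ref = x / (x.square().sum(-1, keepdim=True) + 1e-6).sqrt()
torch.testing.assert_close(y, ref)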
python/sglang/srt/layers/attention/fla/layernorm_gated.py (new file, mode 100644)
# Adapt from https://github.com/fla-org/flash-linear-attention/blob/main/fla/modules/layernorm_gated.py
# Copyright (c) 2024, Tri Dao.
# Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
# For the backward pass, we keep weight_grad and bias_grad in registers and accumulate.
# This backward pass is faster for dimensions up to 8k, but after that it's much slower due to register spilling.
# The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.

import math

import torch
import torch.nn.functional as F
import triton
import triton.language as tl
from einops import rearrange


def rms_norm_ref(
    x,
    weight,
    bias,
    z=None,
    eps=1e-6,
    group_size=None,
    norm_before_gate=True,
    upcast=True,
):
    dtype = x.dtype
    N = x.shape[-1]
    weight = weight.float()
    bias = bias.float() if bias is not None else None
    if upcast:
        x = x.float()
        z = z.float() if z is not None else z
    if z is not None and not norm_before_gate:
        x = x * F.silu(z)
    if group_size is None:
        rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
        out = (x * rstd * weight) + bias if bias is not None else (x * rstd * weight)
    else:
        x_group = rearrange(x, "... (g d) -> ... g d", d=group_size)
        rstd = 1 / torch.sqrt((x_group.square()).mean(dim=-1, keepdim=True) + eps)
        out = rearrange(x_group * rstd, "... g d -> ... (g d)") * weight
        if bias is not None:
            out = out + bias
    if z is not None and norm_before_gate:
        out *= F.silu(z)
    return out.to(dtype)


@triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
@triton.heuristics({"HAS_Z": lambda args: args["Z"] is not None})
@triton.jit
def _layer_norm_fwd_1pass_kernel(
    X,  # pointer to the input
    Y,  # pointer to the output
    W,  # pointer to the weights
    B,  # pointer to the biases
    Z,  # pointer to the other branch
    Mean,  # pointer to the mean
    Rstd,  # pointer to the 1/std
    stride_x_row,  # how much to increase the pointer when moving by 1 row
    stride_y_row,
    stride_z_row,
    M,  # number of rows in X
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    BLOCK_N: tl.constexpr,
    HAS_BIAS: tl.constexpr,
    HAS_Z: tl.constexpr,
    NORM_BEFORE_GATE: tl.constexpr,
    IS_RMS_NORM: tl.constexpr,
):
    # Map the program id to the row of X and Y it should compute.
    row = tl.program_id(0)
    group = tl.program_id(1)
    X += row * stride_x_row + group * N
    Y += row * stride_y_row + group * N
    if HAS_Z:
        Z += row * stride_z_row + group * N
    if not IS_RMS_NORM:
        Mean += group * M
    Rstd += group * M
    W += group * N
    if HAS_BIAS:
        B += group * N
    # Compute mean and variance
    cols = tl.arange(0, BLOCK_N)
    x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
    if HAS_Z and not NORM_BEFORE_GATE:
        z = tl.load(Z + cols, mask=cols < N).to(tl.float32)
        x *= z * tl.sigmoid(z)
    if not IS_RMS_NORM:
        mean = tl.sum(x, axis=0) / N
        tl.store(Mean + row, mean)
        xbar = tl.where(cols < N, x - mean, 0.0)
        var = tl.sum(xbar * xbar, axis=0) / N
    else:
        xbar = tl.where(cols < N, x, 0.0)
        var = tl.sum(xbar * xbar, axis=0) / N
    rstd = 1 / tl.sqrt(var + eps)
    tl.store(Rstd + row, rstd)
    # Normalize and apply linear transformation
    mask = cols < N
    w = tl.load(W + cols, mask=mask).to(tl.float32)
    if HAS_BIAS:
        b = tl.load(B + cols, mask=mask).to(tl.float32)
    x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
    y = x_hat * w + b if HAS_BIAS else x_hat * w
    if HAS_Z and NORM_BEFORE_GATE:
        z = tl.load(Z + cols, mask=mask).to(tl.float32)
        y *= z * tl.sigmoid(z)
    # Write output
    tl.store(Y + cols, y, mask=mask)


def _layer_norm_fwd(
    x,
    weight,
    bias,
    eps,
    z=None,
    out=None,
    group_size=None,
    norm_before_gate=True,
    is_rms_norm=False,
):
    M, N = x.shape
    if group_size is None:
        group_size = N
    assert N % group_size == 0
    ngroups = N // group_size
    assert x.stride(-1) == 1
    if z is not None:
        assert z.stride(-1) == 1
        assert z.shape == (M, N)
    assert weight.shape == (N,)
    assert weight.stride(-1) == 1
    if bias is not None:
        assert bias.stride(-1) == 1
        assert bias.shape == (N,)
    # allocate output
    if out is not None:
        assert out.shape == x.shape
    else:
        out = torch.empty_like(x)
    assert out.stride(-1) == 1
    mean = (
        torch.empty((ngroups * M,), dtype=torch.float32, device=x.device)
        if not is_rms_norm
        else None
    )
    rstd = torch.empty((ngroups * M,), dtype=torch.float32, device=x.device)
    # Less than 64KB per feature: enqueue fused kernel
    MAX_FUSED_SIZE = 65536 // x.element_size()
    BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size))
    if group_size > BLOCK_N:
        raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
    # heuristics for number of warps
    num_warps = min(max(BLOCK_N // 256, 1), 8)
    grid = (M, ngroups)
    with torch.cuda.device(x.device.index):
        _layer_norm_fwd_1pass_kernel[grid](
            x,
            out,
            weight,
            bias,
            z,
            mean,
            rstd,
            x.stride(0),
            out.stride(0),
            z.stride(0) if z is not None else 0,
            M,
            group_size,
            eps,
            BLOCK_N=BLOCK_N,
            NORM_BEFORE_GATE=norm_before_gate,
            IS_RMS_NORM=is_rms_norm,
            num_warps=num_warps,
        )
    return out, mean, rstd


class LayerNormFn(torch.autograd.Function):

    @staticmethod
    def forward(
        ctx,
        x,
        weight,
        bias,
        z=None,
        eps=1e-6,
        group_size=None,
        norm_before_gate=True,
        is_rms_norm=False,
    ):
        """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""
        x_shape_og = x.shape
        # reshape input data into 2D tensor
        x = x.reshape(-1, x.shape[-1])
        if x.stride(-1) != 1:
            x = x.contiguous()
        if z is not None:
            assert z.shape == x_shape_og
            z = z.reshape(-1, z.shape[-1])
            if z.stride(-1) != 1:
                z = z.contiguous()
        weight = weight.contiguous()
        if bias is not None:
            bias = bias.contiguous()
        y, mean, rstd = _layer_norm_fwd(
            x,
            weight,
            bias,
            eps,
            z=z,
            group_size=group_size,
            norm_before_gate=norm_before_gate,
            is_rms_norm=is_rms_norm,
        )
        return y.reshape(x_shape_og)


def layernorm_fn(
    x,
    weight,
    bias,
    z=None,
    eps=1e-6,
    group_size=None,
    norm_before_gate=True,
    is_rms_norm=False,
):
    return LayerNormFn.apply(
        x, weight, bias, z, eps, group_size, norm_before_gate, is_rms_norm
    )


def rmsnorm_fn(x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True):
    return LayerNormFn.apply(x, weight, bias, z, eps, group_size, norm_before_gate, True)


class LayerNorm(torch.nn.Module):

    def __init__(
        self,
        hidden_size,
        eps=1e-5,
        group_size=None,
        norm_before_gate=True,
        device=None,
        dtype=None,
    ):
        """If group_size is not None, we do GroupNorm with each group having group_size elements.
        group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group).
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        self.bias = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        self.group_size = group_size
        self.norm_before_gate = norm_before_gate
        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.ones_(self.weight)
        torch.nn.init.zeros_(self.bias)

    def forward(self, x, z=None):
        """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""
        return layernorm_fn(
            x,
            self.weight,
            self.bias,
            z=z,
            group_size=self.group_size,
            eps=self.eps,
            norm_before_gate=self.norm_before_gate,
        )


class RMSNorm(torch.nn.Module):

    def __init__(
        self,
        hidden_size,
        eps=1e-5,
        group_size=None,
        norm_before_gate=True,
        device=None,
        dtype=None,
    ):
        """If group_size is not None, we do GroupNorm with each group having group_size elements.
        group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group).
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        self.register_parameter("bias", None)
        self.group_size = group_size
        self.norm_before_gate = norm_before_gate
        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.ones_(self.weight)

    def forward(self, x, z=None):
        """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""
        return rmsnorm_fn(
            x,
            self.weight,
            self.bias,
            z=z,
            eps=self.eps,
            group_size=self.group_size,
            norm_before_gate=self.norm_before_gate,
        )
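The Triton forward path can be checked against the file's own rms_norm_ref. A sketch written for this page; requires a GPU, and agreement is up to floating-point tolerance:

import torch

from sglang.srt.layers.attention.fla.layernorm_gated import rms_norm_ref, rmsnorm_fn

x = torch.randn(4, 1024, device="cuda")
z = torch.randn_like(x)
w = torch.ones(1024, device="cuda")
out = rmsnorm_fn(x, w, None, z=z, eps=1e-6, norm_before_gate=True)
ref = rms_norm_ref(x, w, None, z=z, eps=1e-6, norm_before_gate=True)
torch.testing.assert_close(out, ref)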
python/sglang/srt/layers/attention/fla/op.py (new file, mode 100644)
# Adapt from https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/utils/op.py
# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
import os

import triton
import triton.language as tl
import triton.language.extra.libdevice as tldevice

from sglang.srt.layers.attention.fla.utils import is_gather_supported

if os.environ.get("FLA_USE_FAST_OPS", "0") == "1":
    exp = tldevice.fast_expf
    exp2 = tldevice.exp2
    log = tldevice.fast_logf
    log2 = tldevice.fast_log2f
else:
    exp = tl.exp
    exp2 = tl.math.exp2
    log = tl.log
    log2 = tl.log2


@triton.jit
def safe_exp(x):
    return exp(tl.where(x <= 0, x, float("-inf")))


if not is_gather_supported:

    @triton.jit
    def gather(src, index, axis, _builder=None):
        """
        Gather operation that works when tl.gather is not supported.
        This is a fallback implementation that returns None.
        It exists only to keep the Triton compiler happy.
        """
        return None

else:
    gather = tl.gather

if hasattr(triton.language, "_experimental_make_tensor_descriptor"):
    # For Triton 3.3.x
    make_tensor_descriptor = triton.language._experimental_make_tensor_descriptor
elif hasattr(triton.language, "make_tensor_descriptor"):
    # For Triton 3.4.x and later
    make_tensor_descriptor = triton.language.make_tensor_descriptor
else:
    """
    Fallback implementation when TMA is not supported.
    Returns None to indicate TMA descriptors are unavailable.
    It exists only to keep the Triton compiler happy.
    """

    @triton.jit
    def make_tensor_descriptor(
        base,
        shape,
        strides,
        block_shape,
        _builder=None,
    ):
        return None
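For intuition, `safe_exp` clamps positive inputs to -inf before exponentiating, so it returns exp(x) for x <= 0 and exactly 0 otherwise, which keeps exponentials of cumulative log-gates from overflowing. A plain-PyTorch reference of the same semantics (`safe_exp_ref` is a hypothetical name, not part of the file):

import torch

def safe_exp_ref(x: torch.Tensor) -> torch.Tensor:
    # exp(x) for x <= 0; exp(-inf) == 0 for x > 0
    return torch.where(x <= 0, x, torch.full_like(x, float("-inf"))).exp()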
python/sglang/srt/layers/attention/fla/solve_tril.py
0 → 100644
# Adapted from https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/utils/solve_tril.py
# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
from typing import Optional

import torch
import triton
import triton.language as tl

from sglang.srt.layers.attention.fla.index import prepare_chunk_indices
from sglang.srt.layers.attention.fla.utils import input_guard


@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
# @triton.autotune(
#     configs=[
#         triton.Config({}, num_warps=num_warps, num_stages=num_stages)
#         for num_warps in [1, 2, 4, 8]
#         for num_stages in [2, 3, 4, 5]
#     ],
#     key=["BT"],
# )
@triton.jit(do_not_specialize=["T"])
def solve_tril_16x16_kernel(
    A,
    Ad,
    cu_seqlens,
    chunk_indices,
    T,
    H: tl.constexpr,
    BT: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(
            chunk_indices + i_t * 2 + 1
        ).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(
            cu_seqlens + i_n + 1
        ).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    A = A + (bos * H + i_h) * BT
    Ad = Ad + (bos * H + i_h) * 16
    offset = (i_t * 16) % BT
    p_A = tl.make_block_ptr(
        A, (T, BT), (H * BT, 1), (i_t * 16, offset), (16, 16), (1, 0)
    )
    p_Ai = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 16, 0), (16, 16), (1, 0))
    b_A = tl.load(p_A, boundary_check=(0, 1)).to(tl.float32)
    b_A = -tl.where(tl.arange(0, 16)[:, None] > tl.arange(0, 16)[None, :], b_A, 0)
    o_i = tl.arange(0, 16)
    for i in range(1, min(16, T - i_t * 16)):
        b_a = -tl.load(A + (i_t * 16 + i) * H * BT + o_i + offset)
        b_a = b_a + tl.sum(b_a[:, None] * b_A, 0)
        mask = o_i == i
        b_A = tl.where(mask[:, None], b_a, b_A)
    b_A += o_i[:, None] == o_i[None, :]
    tl.store(
        p_Ai,
        b_A.to(p_Ai.dtype.element_ty, fp_downcast_rounding="rtne"),
        boundary_check=(0, 1),
    )
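The kernel above inverts each unit lower-triangular 16x16 block (I + A) by forward substitution, one row at a time. A small PyTorch reference of the same recurrence (a hypothetical check, in float64 so the comparison against a dense inverse stays tight):

import torch

A = torch.randn(16, 16, dtype=torch.float64).tril(-1)  # strictly lower triangular
I = torch.eye(16, dtype=torch.float64)
inv_fs = I.clone()
for i in range(1, 16):
    # row i of (I + A)^-1 depends only on the rows computed before it
    inv_fs[i, :i] = -(A[i, :i] @ inv_fs[:i, :i])
assert torch.allclose(inv_fs, torch.linalg.inv(I + A))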
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
# @triton.autotune(
#     configs=[
#         triton.Config({}, num_warps=num_warps, num_stages=num_stages)
#         for num_warps in [1, 2, 4, 8]
#         for num_stages in [2, 3, 4, 5]
#     ],
#     key=["H", "BT", "IS_VARLEN"],
# )
@triton.jit(do_not_specialize=["T"])
def merge_16x16_to_32x32_inverse_kernel(
    A,
    Ad,
    Ai,
    cu_seqlens,
    chunk_indices,
    T,
    H: tl.constexpr,
    BT: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(
            chunk_indices + i_t * 2 + 1
        ).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(
            cu_seqlens + i_n + 1
        ).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    A += (bos * H + i_h) * 32
    Ad += (bos * H + i_h) * 16
    Ai += (bos * H + i_h) * 32
    p_A_21 = tl.make_block_ptr(
        A, (T, 32), (H * 32, 1), (i_t * 32 + 16, 0), (16, 16), (1, 0)
    )
    p_Ad_11 = tl.make_block_ptr(
        Ad, (T, 16), (H * 16, 1), (i_t * 32, 0), (16, 16), (1, 0)
    )
    p_Ad_22 = tl.make_block_ptr(
        Ad, (T, 16), (H * 16, 1), (i_t * 32 + 16, 0), (16, 16), (1, 0)
    )
    p_Ai_11 = tl.make_block_ptr(
        Ai, (T, 32), (H * 32, 1), (i_t * 32, 0), (16, 16), (1, 0)
    )
    p_Ai_22 = tl.make_block_ptr(
        Ai, (T, 32), (H * 32, 1), (i_t * 32 + 16, 16), (16, 16), (1, 0)
    )
    p_Ai_21 = tl.make_block_ptr(
        Ai, (T, 32), (H * 32, 1), (i_t * 32 + 16, 0), (16, 16), (1, 0)
    )
    A_21 = tl.load(p_A_21, boundary_check=(0, 1)).to(tl.float32)
    Ai_11 = tl.load(p_Ad_11, boundary_check=(0, 1)).to(tl.float32)
    Ai_22 = tl.load(p_Ad_22, boundary_check=(0, 1)).to(tl.float32)
    Ai_21 = -tl.dot(
        tl.dot(Ai_22, A_21, input_precision="ieee"), Ai_11, input_precision="ieee"
    )
    tl.store(
        p_Ai_11,
        Ai_11.to(p_Ai_11.dtype.element_ty, fp_downcast_rounding="rtne"),
        boundary_check=(0, 1),
    )
    tl.store(
        p_Ai_22,
        Ai_22.to(p_Ai_22.dtype.element_ty, fp_downcast_rounding="rtne"),
        boundary_check=(0, 1),
    )
    tl.store(
        p_Ai_21,
        Ai_21.to(p_Ai_21.dtype.element_ty, fp_downcast_rounding="rtne"),
        boundary_check=(0, 1),
    )
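The merge step is the textbook inverse of a 2x2 lower block-triangular matrix. With the diagonal inverses already produced by the 16x16 stage,

$$
\begin{pmatrix} A_{11} & 0 \\ A_{21} & A_{22} \end{pmatrix}^{-1}
= \begin{pmatrix} A_{11}^{-1} & 0 \\ -A_{22}^{-1} A_{21} A_{11}^{-1} & A_{22}^{-1} \end{pmatrix},
$$

which is exactly the `Ai_21 = -tl.dot(tl.dot(Ai_22, A_21), Ai_11)` computation above. The 64x64 kernel below applies the same identity recursively over a 4x4 grid of 16x16 tiles, accumulating the cross terms (e.g. A_31 Ai_11 + A_32 Ai_21 for `Ai_31`) before the final negation.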
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
# @triton.autotune(
#     configs=[
#         triton.Config({}, num_warps=num_warps, num_stages=num_stages)
#         for num_warps in [2, 4, 8]
#         for num_stages in [2, 3, 4, 5]
#     ],
#     key=["H", "BT", "IS_VARLEN"],
# )
@triton.jit(do_not_specialize=["T"])
def merge_16x16_to_64x64_inverse_kernel(
    A,
    Ad,
    Ai,
    cu_seqlens,
    chunk_indices,
    T,
    H: tl.constexpr,
    BT: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(
            chunk_indices + i_t * 2 + 1
        ).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(
            cu_seqlens + i_n + 1
        ).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    A += (bos * H + i_h) * 64
    Ad += (bos * H + i_h) * 16
    Ai += (bos * H + i_h) * 64
    p_A_21 = tl.make_block_ptr(
        A, (T, 64), (H * 64, 1), (i_t * 64 + 16, 0), (16, 16), (1, 0)
    )
    p_A_32 = tl.make_block_ptr(
        A, (T, 64), (H * 64, 1), (i_t * 64 + 32, 16), (16, 16), (1, 0)
    )
    p_A_31 = tl.make_block_ptr(
        A, (T, 64), (H * 64, 1), (i_t * 64 + 32, 0), (16, 16), (1, 0)
    )
    p_A_43 = tl.make_block_ptr(
        A, (T, 64), (H * 64, 1), (i_t * 64 + 48, 32), (16, 16), (1, 0)
    )
    p_A_42 = tl.make_block_ptr(
        A, (T, 64), (H * 64, 1), (i_t * 64 + 48, 16), (16, 16), (1, 0)
    )
    p_A_41 = tl.make_block_ptr(
        A, (T, 64), (H * 64, 1), (i_t * 64 + 48, 0), (16, 16), (1, 0)
    )
    p_Ad_11 = tl.make_block_ptr(
        Ad, (T, 16), (H * 16, 1), (i_t * 64, 0), (16, 16), (1, 0)
    )
    p_Ad_22 = tl.make_block_ptr(
        Ad, (T, 16), (H * 16, 1), (i_t * 64 + 16, 0), (16, 16), (1, 0)
    )
    p_Ad_33 = tl.make_block_ptr(
        Ad, (T, 16), (H * 16, 1), (i_t * 64 + 32, 0), (16, 16), (1, 0)
    )
    p_Ad_44 = tl.make_block_ptr(
        Ad, (T, 16), (H * 16, 1), (i_t * 64 + 48, 0), (16, 16), (1, 0)
    )
    A_21 = tl.load(p_A_21, boundary_check=(0, 1)).to(tl.float32)
    A_32 = tl.load(p_A_32, boundary_check=(0, 1)).to(tl.float32)
    A_31 = tl.load(p_A_31, boundary_check=(0, 1)).to(tl.float32)
    A_43 = tl.load(p_A_43, boundary_check=(0, 1)).to(tl.float32)
    A_42 = tl.load(p_A_42, boundary_check=(0, 1)).to(tl.float32)
    A_41 = tl.load(p_A_41, boundary_check=(0, 1)).to(tl.float32)
    Ai_11 = tl.load(p_Ad_11, boundary_check=(0, 1)).to(tl.float32)
    Ai_22 = tl.load(p_Ad_22, boundary_check=(0, 1)).to(tl.float32)
    Ai_33 = tl.load(p_Ad_33, boundary_check=(0, 1)).to(tl.float32)
    Ai_44 = tl.load(p_Ad_44, boundary_check=(0, 1)).to(tl.float32)
    Ai_21 = -tl.dot(
        tl.dot(Ai_22, A_21, input_precision="ieee"), Ai_11, input_precision="ieee"
    )
    Ai_32 = -tl.dot(
        tl.dot(Ai_33, A_32, input_precision="ieee"), Ai_22, input_precision="ieee"
    )
    Ai_43 = -tl.dot(
        tl.dot(Ai_44, A_43, input_precision="ieee"), Ai_33, input_precision="ieee"
    )
    Ai_31 = -tl.dot(
        Ai_33,
        tl.dot(A_31, Ai_11, input_precision="ieee")
        + tl.dot(A_32, Ai_21, input_precision="ieee"),
        input_precision="ieee",
    )
    Ai_42 = -tl.dot(
        Ai_44,
        tl.dot(A_42, Ai_22, input_precision="ieee")
        + tl.dot(A_43, Ai_32, input_precision="ieee"),
        input_precision="ieee",
    )
    Ai_41 = -tl.dot(
        Ai_44,
        tl.dot(A_41, Ai_11, input_precision="ieee")
        + tl.dot(A_42, Ai_21, input_precision="ieee")
        + tl.dot(A_43, Ai_31, input_precision="ieee"),
        input_precision="ieee",
    )
    p_Ai_11 = tl.make_block_ptr(
        Ai, (T, 64), (H * 64, 1), (i_t * 64, 0), (16, 16), (1, 0)
    )
    p_Ai_22 = tl.make_block_ptr(
        Ai, (T, 64), (H * 64, 1), (i_t * 64 + 16, 16), (16, 16), (1, 0)
    )
    p_Ai_33 = tl.make_block_ptr(
        Ai, (T, 64), (H * 64, 1), (i_t * 64 + 32, 32), (16, 16), (1, 0)
    )
    p_Ai_44 = tl.make_block_ptr(
        Ai, (T, 64), (H * 64, 1), (i_t * 64 + 48, 48), (16, 16), (1, 0)
    )
    p_Ai_21 = tl.make_block_ptr(
        Ai, (T, 64), (H * 64, 1), (i_t * 64 + 16, 0), (16, 16), (1, 0)
    )
    p_Ai_31 = tl.make_block_ptr(
        Ai, (T, 64), (H * 64, 1), (i_t * 64 + 32, 0), (16, 16), (1, 0)
    )
    p_Ai_32 = tl.make_block_ptr(
        Ai, (T, 64), (H * 64, 1), (i_t * 64 + 32, 16), (16, 16), (1, 0)
    )
    p_Ai_41 = tl.make_block_ptr(
        Ai, (T, 64), (H * 64, 1), (i_t * 64 + 48, 0), (16, 16), (1, 0)
    )
    p_Ai_42 = tl.make_block_ptr(
        Ai, (T, 64), (H * 64, 1), (i_t * 64 + 48, 16), (16, 16), (1, 0)
    )
    p_Ai_43 = tl.make_block_ptr(
        Ai, (T, 64), (H * 64, 1), (i_t * 64 + 48, 32), (16, 16), (1, 0)
    )
    tl.store(
        p_Ai_11,
        Ai_11.to(p_Ai_11.dtype.element_ty, fp_downcast_rounding="rtne"),
        boundary_check=(0, 1),
    )
    tl.store(
        p_Ai_22,
        Ai_22.to(p_Ai_22.dtype.element_ty, fp_downcast_rounding="rtne"),
        boundary_check=(0, 1),
    )
    tl.store(
        p_Ai_33,
        Ai_33.to(p_Ai_33.dtype.element_ty, fp_downcast_rounding="rtne"),
        boundary_check=(0, 1),
    )
    tl.store(
        p_Ai_44,
        Ai_44.to(p_Ai_44.dtype.element_ty, fp_downcast_rounding="rtne"),
        boundary_check=(0, 1),
    )
    tl.store(
        p_Ai_21,
        Ai_21.to(p_Ai_21.dtype.element_ty, fp_downcast_rounding="rtne"),
        boundary_check=(0, 1),
    )
    tl.store(
        p_Ai_31,
        Ai_31.to(p_Ai_31.dtype.element_ty, fp_downcast_rounding="rtne"),
        boundary_check=(0, 1),
    )
    tl.store(
        p_Ai_32,
        Ai_32.to(p_Ai_32.dtype.element_ty, fp_downcast_rounding="rtne"),
        boundary_check=(0, 1),
    )
    tl.store(
        p_Ai_41,
        Ai_41.to(p_Ai_41.dtype.element_ty, fp_downcast_rounding="rtne"),
        boundary_check=(0, 1),
    )
    tl.store(
        p_Ai_42,
        Ai_42.to(p_Ai_42.dtype.element_ty, fp_downcast_rounding="rtne"),
        boundary_check=(0, 1),
    )
    tl.store(
        p_Ai_43,
        Ai_43.to(p_Ai_43.dtype.element_ty, fp_downcast_rounding="rtne"),
        boundary_check=(0, 1),
    )
    fill_zeros = tl.zeros((16, 16), dtype=tl.float32)
    p_Ai_12 = tl.make_block_ptr(
        Ai, (T, 64), (H * 64, 1), (i_t * 64, 16), (16, 16), (1, 0)
    )
    p_Ai_13 = tl.make_block_ptr(
        Ai, (T, 64), (H * 64, 1), (i_t * 64, 32), (16, 16), (1, 0)
    )
    p_Ai_14 = tl.make_block_ptr(
        Ai, (T, 64), (H * 64, 1), (i_t * 64, 48), (16, 16), (1, 0)
    )
    p_Ai_23 = tl.make_block_ptr(
        Ai, (T, 64), (H * 64, 1), (i_t * 64 + 16, 32), (16, 16), (1, 0)
    )
    p_Ai_24 = tl.make_block_ptr(
        Ai, (T, 64), (H * 64, 1), (i_t * 64 + 16, 48), (16, 16), (1, 0)
    )
    p_Ai_34 = tl.make_block_ptr(
        Ai, (T, 64), (H * 64, 1), (i_t * 64 + 32, 48), (16, 16), (1, 0)
    )
    tl.store(
        p_Ai_12,
        fill_zeros.to(p_Ai_12.dtype.element_ty, fp_downcast_rounding="rtne"),
        boundary_check=(0, 1),
    )
    tl.store(
        p_Ai_13,
        fill_zeros.to(p_Ai_13.dtype.element_ty, fp_downcast_rounding="rtne"),
        boundary_check=(0, 1),
    )
    tl.store(
        p_Ai_14,
        fill_zeros.to(p_Ai_14.dtype.element_ty, fp_downcast_rounding="rtne"),
        boundary_check=(0, 1),
    )
    tl.store(
        p_Ai_23,
        fill_zeros.to(p_Ai_23.dtype.element_ty, fp_downcast_rounding="rtne"),
        boundary_check=(0, 1),
    )
    tl.store(
        p_Ai_24,
        fill_zeros.to(p_Ai_24.dtype.element_ty, fp_downcast_rounding="rtne"),
        boundary_check=(0, 1),
    )
    tl.store(
        p_Ai_34,
        fill_zeros.to(p_Ai_34.dtype.element_ty, fp_downcast_rounding="rtne"),
        boundary_check=(0, 1),
    )
@input_guard
def solve_tril(
    A: torch.Tensor,
    cu_seqlens: Optional[torch.Tensor] = None,
    output_dtype: torch.dtype = torch.float,
) -> torch.Tensor:
    """
    Compute the inverse of the lower triangular matrix (I + A).
    A should be strictly lower triangular, i.e., A.triu() == 0.

    Args:
        A (torch.Tensor):
            [B, T, H, BT]
        cu_seqlens (torch.Tensor):
            The cumulative sequence lengths of the input tensor.
            Default: None.
        output_dtype (torch.dtype):
            The dtype of the output tensor. Default: `torch.float`

    Returns:
        (I + A)^-1 with the same shape as A
    """
    assert A.shape[-1] in [16, 32, 64]
    B, T, H, BT = A.shape
    Ad = torch.empty(
        B, T, H, 16, device=A.device, dtype=torch.float if BT != 16 else output_dtype
    )

    chunk_indices = (
        prepare_chunk_indices(cu_seqlens, 16) if cu_seqlens is not None else None
    )
    NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv(T, 16)
    solve_tril_16x16_kernel[NT, B * H](
        A=A,
        Ad=Ad,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        H=H,
        BT=BT,
        num_warps=1,
        num_stages=4,
    )
    if BT == 16:
        return Ad

    Ai = torch.empty(B, T, H, BT, device=A.device, dtype=output_dtype)
    merge_fn = (
        merge_16x16_to_32x32_inverse_kernel
        if BT == 32
        else merge_16x16_to_64x64_inverse_kernel
    )
    chunk_indices = (
        prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
    )
    NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv(T, BT)
    merge_fn[NT, B * H](
        A=A,
        Ad=Ad,
        Ai=Ai,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        H=H,
        BT=BT,
        num_warps=4,
        num_stages=3,
    )
    return Ai
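A usage sketch for the launcher (hypothetical shapes and layout, assuming a CUDA device): each chunk of `BT` rows along `T` carries one strictly lower-triangular `BT x BT` block, stored as `[B, T, H, BT]`.

import torch

B, T, H, BT = 2, 256, 4, 64
blocks = torch.randn(B, T // BT, H, BT, BT, device="cuda").tril(-1)
A = blocks.permute(0, 1, 3, 2, 4).reshape(B, T, H, BT).contiguous()
Ai = solve_tril(A)  # (I + A)^-1 per chunk, same [B, T, H, BT] layout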
python/sglang/srt/layers/attention/fla/utils.py
0 → 100644
# Adapted from https://github.com/fla-org/flash-linear-attention/blob/main/fla/utils.py
# -*- coding: utf-8 -*-
import contextlib
import functools
import logging
import os
import sys
from enum import Enum
from functools import lru_cache
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple

import torch
import triton
from packaging import version

logger = logging.getLogger(__name__)

COMPILER_MODE = os.getenv("FLA_COMPILER_MODE") == "1"
FLA_CI_ENV = os.getenv("FLA_CI_ENV") == "1"


@lru_cache(maxsize=1)
def check_environments():
    """
    Checks the current operating system, Triton version, and Python version,
    issuing warnings if they don't meet recommendations.
    This function's body only runs once due to lru_cache.
    """
    # Check operating system
    if sys.platform == "win32":
        logger.warning(
            "Detected Windows operating system. Triton does not have an official Windows release, "
            "thus FLA will not be adapted for Windows, and any potential errors will not be fixed. "
            "Please consider using a Linux environment for compatibility."
        )

    triton_version = version.parse(triton.__version__)
    required_triton_version = version.parse("3.2.0")
    if triton_version < required_triton_version:
        logger.warning(
            f"Current Triton version {triton_version} is below the recommended 3.2.0 version. "
            "Errors may occur and these issues will not be fixed. "
            "Please consider upgrading Triton."
        )

    # Check Python version
    py_version = version.parse(f"{sys.version_info.major}.{sys.version_info.minor}")
    required_py_version = version.parse("3.11")
    if py_version < required_py_version:
        logger.warning(
            f"Current Python version {py_version} is below the recommended 3.11 version. "
            "It is recommended to upgrade to Python 3.11 or higher for the best experience."
        )

    return None


check_environments()


def get_abs_err(x, y):
    return (x.detach() - y.detach()).flatten().abs().max().item()


def get_err_ratio(x, y):
    err = (x.detach() - y.detach()).flatten().square().mean().sqrt().item()
    base = (x.detach()).flatten().square().mean().sqrt().item()
    return err / (base + 1e-8)


def assert_close(prefix, ref, tri, ratio, warning=False, err_atol=1e-6):
    abs_atol = get_abs_err(ref, tri)
    msg = f"{prefix} diff: {abs_atol:.6f} ratio: {get_err_ratio(ref, tri):.6f}"
    logger.info(msg)
    error_rate = get_err_ratio(ref, tri)
    if abs_atol <= err_atol:
        return
    if warning or (FLA_CI_ENV and (error_rate < 0.01 or abs_atol <= 0.3)):
        if error_rate > ratio:
            import warnings

            warnings.warn(msg)
    else:
        assert error_rate < ratio, msg
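A call-site sketch (names illustrative): with `ratio=0.004` the check passes when the relative RMS error of `o_tri` against `o_ref` stays below 0.4%, while tiny absolute diffs (below `err_atol`) short-circuit the check entirely.

assert_close("o", o_ref, o_tri, ratio=0.004)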
SUPPRESS_LEVEL = int(os.getenv("GDN_RECOMPUTE_SUPPRESS_LEVEL", "0"))


def tensor_cache(fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
    """
    A decorator that caches the most recent results of a function with tensor inputs.

    This decorator will store the output of the decorated function for the most recent set of input tensors.
    The cache is limited to a fixed size (default is 4). When the cache is full, the oldest entry will be removed.

    Args:
        fn (Callable[..., torch.Tensor]):
            The function to be decorated. It should take tensor inputs and return tensor outputs.

    Returns:
        Callable[..., torch.Tensor]:
            A wrapped version of the input function with caching of recent results.
    """
    cache_entries: List[Tuple[Optional[Tuple], Optional[Dict], Any]] = []
    cache_size = 4

    @functools.wraps(fn)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        nonlocal cache_entries, cache_size
        for i, entry in enumerate(cache_entries):
            last_args, last_kwargs, last_result = entry
            if len(args) == len(last_args) and len(kwargs) == len(last_kwargs):
                if all(a is b for a, b in zip(args, last_args)) and all(
                    k in last_kwargs and v is last_kwargs[k] for k, v in kwargs.items()
                ):
                    cache_entries = (
                        cache_entries[:i]
                        + cache_entries[i + 1 :]
                        + [(args, kwargs, last_result)]
                    )
                    return last_result

        result = fn(*args, **kwargs)
        if len(cache_entries) >= cache_size:
            cache_entries = cache_entries[1:]
        cache_entries.append((args, kwargs, result))
        return result

    return wrapper
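A usage sketch: because the cache keys on tensor identity (`is`), this pays off for values like `cu_seqlens` that are reused across many kernel launches in one forward step (`lengths_from_cu_seqlens` is a hypothetical helper, not part of the file):

import torch

@tensor_cache
def lengths_from_cu_seqlens(cu_seqlens: torch.Tensor) -> torch.Tensor:
    # recomputed only when a different cu_seqlens object is passed
    return cu_seqlens[1:] - cu_seqlens[:-1]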
def input_guard(fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
    """
    A decorator to make sure all input tensors are contiguous and set the device based on input tensors.
    """

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        contiguous_args = (
            i if not isinstance(i, torch.Tensor) else i.contiguous() for i in args
        )
        contiguous_kwargs = {
            k: (v if not isinstance(v, torch.Tensor) else v.contiguous())
            for k, v in kwargs.items()
        }

        tensor = None
        for arg in args:
            if isinstance(arg, torch.Tensor):
                tensor = arg
                break
        if tensor is None:
            for value in kwargs.values():
                if isinstance(value, torch.Tensor):
                    tensor = value
                    break

        if tensor is not None:
            ctx = custom_device_ctx(tensor.device.index)
        else:
            ctx = contextlib.nullcontext()

        with ctx:
            return fn(*contiguous_args, **contiguous_kwargs)

    return wrapper


contiguous = input_guard
def require_version(version, hint):
    """
    Perform a runtime check of the dependency versions, using the exact same syntax used by pip.
    """

    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(ctx, *args, **kwargs):
            from transformers.utils.versions import require_version

            require_version(version, hint)
            return fn(
                ctx,
                *(
                    i if not isinstance(i, torch.Tensor) else i.contiguous()
                    for i in args
                ),
                **{
                    k: (v if not isinstance(v, torch.Tensor) else v.contiguous())
                    for k, v in kwargs.items()
                },
            )

        return wrapper

    return decorator


def checkpoint(fn):
    def wrapper(*args, **kwargs):
        return torch.utils.checkpoint.checkpoint(fn, *args, **kwargs)

    return wrapper
@lru_cache(maxsize=None)
def check_pytorch_version(version_s: str = "2.4") -> bool:
    return version.parse(torch.__version__) >= version.parse(version_s)


def _cpu_device_warning():
    import warnings

    warnings.warn(
        ("Triton is not supported on current platform, roll back to CPU."),
        stacklevel=1,
    )


@lru_cache(maxsize=None)
def get_multiprocessor_count(tensor_idx: int = 0) -> int:
    try:
        return triton.runtime.driver.active.utils.get_device_properties(tensor_idx)[
            "multiprocessor_count"
        ]
    except BaseException:
        _cpu_device_warning()
        return -1


@lru_cache(maxsize=None)
def get_available_device() -> str:
    try:
        return triton.runtime.driver.active.get_current_target().backend
    except BaseException:
        _cpu_device_warning()
        return "cpu"


@lru_cache(maxsize=None)
def _check_platform() -> Literal["nvidia", "amd", "intel", "musa"]:
    device = get_available_device()
    if device == "cuda":
        return "nvidia"
    elif device == "hip":
        return "amd"
    elif device == "xpu":
        return "intel"
    else:
        return device


# For AMD GPUs, the triton backend is 'hip', while for Nvidia GPUs, the triton backend is 'cuda'.
# However, the torch backend is 'cuda' for both Nvidia and AMD GPUs.
# Therefore, we need to check the triton backend to determine the actual GPU vendor.
device = get_available_device() if get_available_device() != "hip" else "cuda"
device_torch_lib = getattr(torch, device)
device_platform = _check_platform()

is_amd = device_platform == "amd"
is_intel = device_platform == "intel"
is_nvidia = device_platform == "nvidia"
is_intel_alchemist = is_intel and "Intel(R) Arc(TM) A" in torch.xpu.get_device_name(0)
is_nvidia_hopper = is_nvidia and (
    "NVIDIA H" in torch.cuda.get_device_name(0)
    or torch.cuda.get_device_capability()[0] >= 9
)
use_cuda_graph = is_nvidia and os.environ.get("FLA_USE_CUDA_GRAPH", "0") == "1"

# Nvidia Ampere or newer; AMD and Intel haven't been checked yet.
is_tf32_supported = is_nvidia and torch.cuda.get_device_capability(0)[0] >= 8
is_gather_supported = hasattr(triton.language, "gather")


def get_all_max_shared_mem():
    try:
        return [
            triton.runtime.driver.active.utils.get_device_properties(i)[
                "max_shared_mem"
            ]
            for i in range(device_torch_lib.device_count())
        ]
    except BaseException:
        _cpu_device_warning()
        return [-1]
class Backend(Enum):
    ADA = 101376  # RTX 4090
    AMPERE = 166912  # A100
    HOPPER = 232448  # H100
    DEFAULT = 102400  # Default

    @classmethod
    def get_shared_memory(cls, arch: str) -> int:
        try:
            return cls[arch.upper()].value
        except KeyError:
            return cls.DEFAULT.value


@lru_cache(maxsize=None)
def check_shared_mem(arch: str = "none", tensor_idx: int = 0) -> bool:
    try:
        device_shared_mem_list = get_all_max_shared_mem()
        max_shared_memory = device_shared_mem_list[tensor_idx]
        return max_shared_memory >= Backend.get_shared_memory(arch)
    except Exception:
        return False
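Downstream kernels typically use this check to choose tile sizes; a sketch (threshold and sizes illustrative, not from this file):

BV = 128 if check_shared_mem("hopper") else 64  # larger value blocks on Hopper-class (~227 KB) shared memory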
if check_pytorch_version("2.4"):
    device = "cuda" if device == "cpu" else device
    autocast_custom_fwd = functools.partial(torch.amp.custom_fwd, device_type=device)
    autocast_custom_bwd = functools.partial(torch.amp.custom_bwd, device_type=device)

    def custom_device_ctx(index: int):
        return device_torch_lib.device(index)

else:
    assert (
        device == "cuda"
    ), "Only cuda device is supported for PyTorch version < 2.4.0."
    autocast_custom_fwd = device_torch_lib.amp.custom_fwd
    autocast_custom_bwd = device_torch_lib.amp.custom_bwd

    def custom_device_ctx(index: int):
        return torch.cuda.device(index)
python/sglang/srt/layers/attention/fla/wy_fast.py
0 → 100644
# Adapted from https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/gated_delta_rule/wy_fast.py
# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
from typing import Optional, Tuple

import torch
import triton
import triton.language as tl

from sglang.srt.layers.attention.fla.index import prepare_chunk_indices
from sglang.srt.layers.attention.fla.op import safe_exp
from sglang.srt.layers.attention.fla.utils import check_shared_mem


@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
# @triton.autotune(
#     configs=[
#         triton.Config({}, num_warps=num_warps, num_stages=num_stages)
#         for num_warps in [2, 4, 8]
#         for num_stages in [2, 3, 4]
#     ],
#     key=["H", "K", "V", "BT", "BK", "BV", "IS_VARLEN"],
# )
@triton.jit(do_not_specialize=["T"])
def recompute_w_u_fwd_kernel(
    k,
    v,
    beta,
    w,
    u,
    A,
    g,
    cu_seqlens,
    chunk_indices,
    T,
    H: tl.constexpr,
    Hg: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(
            chunk_indices + i_t * 2 + 1
        ).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(
            cu_seqlens + i_n + 1
        ).to(tl.int32)
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T
    p_beta = tl.make_block_ptr(
        beta + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)
    )
    p_g = tl.make_block_ptr(g + (bos * H + i_h), (T,), (H,), (i_t * BT,), (BT,), (0,))
    p_A = tl.make_block_ptr(
        A + (bos * H + i_h) * BT, (T, BT), (H * BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)
    )
    b_beta = tl.load(p_beta, boundary_check=(0,))
    b_A = tl.load(p_A, boundary_check=(0, 1))
    b_g = tl.exp(tl.load(p_g, boundary_check=(0,)))

    for i_v in range(tl.cdiv(V, BV)):
        p_v = tl.make_block_ptr(
            v + (bos * H + i_h) * V,
            (T, V),
            (H * V, 1),
            (i_t * BT, i_v * BV),
            (BT, BV),
            (1, 0),
        )
        p_u = tl.make_block_ptr(
            u + (bos * H + i_h) * V,
            (T, V),
            (H * V, 1),
            (i_t * BT, i_v * BV),
            (BT, BV),
            (1, 0),
        )
        b_v = tl.load(p_v, boundary_check=(0, 1))
        b_vb = (b_v * b_beta[:, None]).to(b_v.dtype)
        b_u = tl.dot(b_A, b_vb, allow_tf32=False)
        tl.store(p_u, b_u.to(p_u.dtype.element_ty), boundary_check=(0, 1))

    for i_k in range(tl.cdiv(K, BK)):
        p_k = tl.make_block_ptr(
            k + (bos * Hg + i_h // (H // Hg)) * K,
            (T, K),
            (Hg * K, 1),
            (i_t * BT, i_k * BK),
            (BT, BK),
            (1, 0),
        )
        p_w = tl.make_block_ptr(
            w + (bos * H + i_h) * K,
            (T, K),
            (H * K, 1),
            (i_t * BT, i_k * BK),
            (BT, BK),
            (1, 0),
        )
        b_k = tl.load(p_k, boundary_check=(0, 1))
        b_kb = (b_k * b_beta[:, None] * b_g[:, None]).to(b_k.dtype)
        b_w = tl.dot(b_A, b_kb)
        tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1))


def recompute_w_u_fwd(
    k: torch.Tensor,
    v: torch.Tensor,
    beta: torch.Tensor,
    g_cumsum: torch.Tensor,
    A: torch.Tensor,
    cu_seqlens: Optional[torch.LongTensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
    B, T, Hg, K, V = *k.shape, v.shape[-1]
    H = v.shape[-2]
    BT = A.shape[-1]

    chunk_indices = (
        prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
    )
    NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
    BK = 64
    BV = 64
    u = torch.empty_like(v)
    w = k.new_empty(B, T, H, K)
    recompute_w_u_fwd_kernel[(NT, B * H)](
        k=k,
        v=v,
        beta=beta,
        w=w,
        u=u,
        A=A,
        g=g_cumsum,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        H=H,
        Hg=Hg,
        K=K,
        V=V,
        BT=BT,
        BK=BK,
        BV=BV,
        num_warps=4,
        num_stages=3,
    )
    return w, u


fwd_recompute_w_u = recompute_w_u_fwd
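A shape sketch for the launcher (hypothetical sizes, assuming a CUDA device; in the real pipeline `A` comes from `solve_tril` and `g` from `chunk_local_cumsum`):

import torch

B, T, H, K, V, BT = 1, 128, 4, 64, 128, 64
k = torch.randn(B, T, H, K, device="cuda", dtype=torch.bfloat16)
v = torch.randn(B, T, H, V, device="cuda", dtype=torch.bfloat16)
beta = torch.rand(B, T, H, device="cuda", dtype=torch.bfloat16)
g = torch.zeros(B, T, H, device="cuda", dtype=torch.float)  # cumulative log-gates
A = torch.randn(B, T, H, BT, device="cuda", dtype=torch.bfloat16)
w, u = recompute_w_u_fwd(k, v, beta, g, A, cu_seqlens=None)  # [B, T, H, K], [B, T, H, V]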