Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
1aa427fd
Unverified
Commit
1aa427fd
authored
Sep 10, 2025
by
youkaichao
Committed by
GitHub
Sep 10, 2025
Browse files
[Kernels] Add Flash Linear Attention Kernels (#24518)
Signed-off-by:
youkaichao
<
youkaichao@gmail.com
>
parent
1c63a16b
Changes
17
Hide whitespace changes
Inline
Side-by-side
Showing
17 changed files
with
2671 additions
and
2 deletions
+2671
-2
tools/mypy.sh
tools/mypy.sh
+1
-1
vllm/model_executor/layers/fla/__init__.py
vllm/model_executor/layers/fla/__init__.py
+8
-0
vllm/model_executor/layers/fla/ops/__init__.py
vllm/model_executor/layers/fla/ops/__init__.py
+17
-0
vllm/model_executor/layers/fla/ops/chunk.py
vllm/model_executor/layers/fla/ops/chunk.py
+225
-0
vllm/model_executor/layers/fla/ops/chunk_delta_h.py
vllm/model_executor/layers/fla/ops/chunk_delta_h.py
+289
-0
vllm/model_executor/layers/fla/ops/chunk_o.py
vllm/model_executor/layers/fla/ops/chunk_o.py
+176
-0
vllm/model_executor/layers/fla/ops/chunk_scaled_dot_kkt.py
vllm/model_executor/layers/fla/ops/chunk_scaled_dot_kkt.py
+138
-0
vllm/model_executor/layers/fla/ops/cumsum.py
vllm/model_executor/layers/fla/ops/cumsum.py
+226
-0
vllm/model_executor/layers/fla/ops/fused_recurrent.py
vllm/model_executor/layers/fla/ops/fused_recurrent.py
+366
-0
vllm/model_executor/layers/fla/ops/index.py
vllm/model_executor/layers/fla/ops/index.py
+39
-0
vllm/model_executor/layers/fla/ops/l2norm.py
vllm/model_executor/layers/fla/ops/l2norm.py
+143
-0
vllm/model_executor/layers/fla/ops/layernorm_guard.py
vllm/model_executor/layers/fla/ops/layernorm_guard.py
+337
-0
vllm/model_executor/layers/fla/ops/op.py
vllm/model_executor/layers/fla/ops/op.py
+44
-0
vllm/model_executor/layers/fla/ops/solve_tril.py
vllm/model_executor/layers/fla/ops/solve_tril.py
+365
-0
vllm/model_executor/layers/fla/ops/utils.py
vllm/model_executor/layers/fla/ops/utils.py
+180
-0
vllm/model_executor/layers/fla/ops/wy_fast.py
vllm/model_executor/layers/fla/ops/wy_fast.py
+114
-0
vllm/triton_utils/__init__.py
vllm/triton_utils/__init__.py
+3
-1
No files found.
tools/mypy.sh
View file @
1aa427fd
...
@@ -29,7 +29,7 @@ run_mypy vllm/engine
...
@@ -29,7 +29,7 @@ run_mypy vllm/engine
run_mypy vllm/executor
run_mypy vllm/executor
run_mypy vllm/inputs
run_mypy vllm/inputs
run_mypy vllm/lora
run_mypy vllm/lora
run_mypy vllm/model_executor
run_mypy
--exclude
'vllm/model_executor/layers/fla/ops'
vllm/model_executor
run_mypy vllm/plugins
run_mypy vllm/plugins
run_mypy vllm/worker
run_mypy vllm/worker
run_mypy vllm/v1
run_mypy vllm/v1
vllm/model_executor/layers/fla/__init__.py
0 → 100644
View file @
1aa427fd
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
vllm/model_executor/layers/fla/ops/__init__.py
0 → 100644
View file @
1aa427fd
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
from
.chunk
import
chunk_gated_delta_rule
from
.fused_recurrent
import
fused_recurrent_gated_delta_rule
from
.layernorm_guard
import
RMSNormGated
__all__
=
[
"RMSNormGated"
,
"chunk_gated_delta_rule"
,
"fused_recurrent_gated_delta_rule"
,
]
vllm/model_executor/layers/fla/ops/chunk.py
0 → 100644
View file @
1aa427fd
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# ruff: noqa: E501
import
warnings
from
typing
import
Optional
import
torch
from
einops
import
rearrange
from
.chunk_delta_h
import
chunk_gated_delta_rule_fwd_h
from
.chunk_o
import
chunk_fwd_o
from
.chunk_scaled_dot_kkt
import
chunk_scaled_dot_kkt_fwd
from
.cumsum
import
chunk_local_cumsum
from
.l2norm
import
l2norm_fwd
from
.solve_tril
import
solve_tril
from
.utils
import
SUPPRESS_LEVEL
,
input_guard
from
.wy_fast
import
recompute_w_u_fwd
def
chunk_gated_delta_rule_fwd
(
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
g
:
torch
.
Tensor
,
beta
:
torch
.
Tensor
,
scale
:
float
,
initial_state
:
torch
.
Tensor
,
output_final_state
:
bool
,
cu_seqlens
:
Optional
[
torch
.
LongTensor
]
=
None
):
g
=
chunk_local_cumsum
(
g
,
chunk_size
=
64
,
cu_seqlens
=
cu_seqlens
)
# obtain WY representation. u is actually the new v.
A
=
chunk_scaled_dot_kkt_fwd
(
k
=
k
,
beta
=
beta
,
g_cumsum
=
g
,
cu_seqlens
=
cu_seqlens
,
output_dtype
=
torch
.
float32
)
A
=
solve_tril
(
A
=
A
,
cu_seqlens
=
cu_seqlens
,
output_dtype
=
k
.
dtype
)
w
,
u
=
recompute_w_u_fwd
(
k
=
k
,
v
=
v
,
beta
=
beta
,
A
=
A
,
g_cumsum
=
g
,
cu_seqlens
=
cu_seqlens
,
)
h
,
v_new
,
final_state
=
chunk_gated_delta_rule_fwd_h
(
k
=
k
,
w
=
w
,
u
=
u
,
g
=
g
,
initial_state
=
initial_state
,
output_final_state
=
output_final_state
,
cu_seqlens
=
cu_seqlens
,
)
o
=
chunk_fwd_o
(
q
=
q
,
k
=
k
,
v
=
v_new
,
h
=
h
,
g
=
g
,
scale
=
scale
,
cu_seqlens
=
cu_seqlens
,
)
if
SUPPRESS_LEVEL
<
3
:
return
g
,
o
,
A
,
final_state
,
None
,
None
,
None
elif
SUPPRESS_LEVEL
>=
3
:
return
g
,
o
,
A
,
final_state
,
w
,
h
,
v_new
class
ChunkGatedDeltaRuleFunction
(
torch
.
autograd
.
Function
):
@
staticmethod
@
input_guard
@
torch
.
amp
.
custom_fwd
(
device_type
=
'cuda'
)
def
forward
(
ctx
,
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
g
:
torch
.
Tensor
,
beta
:
torch
.
Tensor
,
scale
:
float
,
initial_state
:
torch
.
Tensor
,
output_final_state
:
bool
,
cu_seqlens
:
Optional
[
torch
.
LongTensor
]
=
None
,
use_qk_l2norm_in_kernel
:
bool
=
False
):
if
use_qk_l2norm_in_kernel
:
q
=
l2norm_fwd
(
q
)
k
=
l2norm_fwd
(
k
)
g
,
o
,
A
,
final_state
,
w
,
h
,
v_new
=
chunk_gated_delta_rule_fwd
(
q
=
q
,
k
=
k
,
v
=
v
,
g
=
g
,
beta
=
beta
,
scale
=
scale
,
initial_state
=
initial_state
,
output_final_state
=
output_final_state
,
cu_seqlens
=
cu_seqlens
,
)
ctx
.
scale
=
scale
ctx
.
use_qk_l2norm_in_kernel
=
use_qk_l2norm_in_kernel
return
o
.
to
(
q
.
dtype
),
final_state
@
torch
.
compiler
.
disable
def
chunk_gated_delta_rule
(
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
g
:
torch
.
Tensor
,
beta
:
torch
.
Tensor
,
scale
:
float
=
None
,
initial_state
:
torch
.
Tensor
=
None
,
output_final_state
:
bool
=
False
,
cu_seqlens
:
Optional
[
torch
.
LongTensor
]
=
None
,
head_first
:
bool
=
False
,
use_qk_l2norm_in_kernel
:
bool
=
False
):
r
"""
Args:
q (torch.Tensor):
queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
k (torch.Tensor):
keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
v (torch.Tensor):
values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
g (torch.Tensor):
(forget) gating tensor (in log space!) of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`.
beta (torch.Tensor):
betas of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`.
scale (Optional[int]):
Scale factor for the RetNet attention scores.
If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
initial_state (Optional[torch.Tensor]):
Initial state of shape `[N, H, K, V]` for `N` input sequences.
For equal-length input sequences, `N` equals the batch size `B`.
Default: `None`.
output_final_state (Optional[bool]):
Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
cu_seqlens (torch.LongTensor):
Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
consistent with the FlashAttention API.
head_first (Optional[bool]):
Whether the inputs are in the head-first format, which is not supported for variable-length inputs.
Default: `False`.
Returns:
o (torch.Tensor):
Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
final_state (torch.Tensor):
Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`.
Examples::
>>> import torch
>>> import torch.nn.functional as F
>>> from einops import rearrange
>>> from fla.ops.gated_delta_rule import chunk_gated_delta_rule
# inputs with equal lengths
>>> B, T, H, K, V = 4, 2048, 4, 512, 512
>>> q = torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda')
>>> k = F.normalize(torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda'), p=2, dim=-1)
>>> v = torch.randn(B, T, H, V, dtype=torch.bfloat16, device='cuda')
>>> beta = torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda').sigmoid()
>>> g = F.logsigmoid(torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda'))
>>> h0 = torch.randn(B, H, K, V, dtype=torch.bfloat16, device='cuda')
>>> o, ht = chunk_gated_delta_rule(
q, k, v, g, beta,
initial_state=h0,
output_final_state=True
)
# for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
>>> q, k, v, beta, g = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, beta, g))
# for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
>>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
>>> o_var, ht_var = chunk_gated_delta_rule(
q, k, v, g, beta,
initial_state=h0,
output_final_state=True,
cu_seqlens=cu_seqlens
)
"""
assert
q
.
dtype
==
k
.
dtype
==
v
.
dtype
assert
q
.
dtype
!=
torch
.
float32
,
"ChunkGatedDeltaRuleFunction does not support float32. Please use bfloat16."
assert
len
(
beta
.
shape
)
==
3
,
"beta must be of shape [B, T, H] if head_first=False, or [B, H, T] otherwise."
if
head_first
:
raise
DeprecationWarning
(
"head_first is deprecated and will be removed in a future version. "
"Please use head_first=False for now instead."
,
stacklevel
=
2
)
q
,
k
,
v
,
beta
,
g
=
map
(
lambda
x
:
rearrange
(
x
,
'b h t ... -> b t h ...'
),
(
q
,
k
,
v
,
beta
,
g
))
if
not
head_first
and
q
.
shape
[
1
]
<
q
.
shape
[
2
]:
warnings
.
warn
(
f
"Input tensor shape suggests potential format mismatch: seq_len (
{
q
.
shape
[
1
]
}
) < num_heads (
{
q
.
shape
[
2
]
}
). "
"This may indicate the inputs were passed in head-first format [B, H, T, ...] "
"when head_first=False was specified. "
"Please verify your input tensor format matches the expected shape [B, T, H, ...]."
,
stacklevel
=
2
)
if
cu_seqlens
is
not
None
:
if
q
.
shape
[
0
]
!=
1
:
raise
ValueError
(
f
"The batch size is expected to be 1 rather than
{
q
.
shape
[
0
]
}
when using `cu_seqlens`."
f
"Please flatten variable-length inputs before processing."
)
if
initial_state
is
not
None
and
initial_state
.
shape
[
0
]
!=
len
(
cu_seqlens
)
-
1
:
raise
ValueError
(
f
"The number of initial states is expected to be equal to the number of input sequences, "
f
"i.e.,
{
len
(
cu_seqlens
)
-
1
}
rather than
{
initial_state
.
shape
[
0
]
}
."
)
if
scale
is
None
:
scale
=
k
.
shape
[
-
1
]
**-
0.5
o
,
final_state
=
ChunkGatedDeltaRuleFunction
.
apply
(
q
,
k
,
v
,
g
,
beta
,
scale
,
initial_state
,
output_final_state
,
cu_seqlens
,
use_qk_l2norm_in_kernel
)
if
head_first
:
o
=
rearrange
(
o
,
'b t h ... -> b h t ...'
)
return
o
,
final_state
vllm/model_executor/layers/fla/ops/chunk_delta_h.py
0 → 100644
View file @
1aa427fd
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# ruff: noqa: E501
from
typing
import
Optional
import
torch
from
vllm.triton_utils
import
tl
,
triton
from
.index
import
prepare_chunk_indices
,
prepare_chunk_offsets
from
.op
import
exp
,
safe_exp
from
.utils
import
is_nvidia_hopper
,
use_cuda_graph
NUM_WARPS
=
[
2
,
4
]
if
is_nvidia_hopper
else
[
2
,
4
,
8
,
16
]
@
triton
.
heuristics
({
'USE_G'
:
lambda
args
:
args
[
'g'
]
is
not
None
,
'USE_INITIAL_STATE'
:
lambda
args
:
args
[
'h0'
]
is
not
None
,
'STORE_FINAL_STATE'
:
lambda
args
:
args
[
'ht'
]
is
not
None
,
'SAVE_NEW_VALUE'
:
lambda
args
:
args
[
'v_new'
]
is
not
None
,
'IS_VARLEN'
:
lambda
args
:
args
[
'cu_seqlens'
]
is
not
None
,
})
@
triton
.
autotune
(
configs
=
[
triton
.
Config
({
'BV'
:
BV
},
num_warps
=
num_warps
,
num_stages
=
num_stages
)
for
num_warps
in
[
2
,
4
]
for
num_stages
in
[
2
,
3
,
4
]
for
BV
in
[
32
,
64
]
],
key
=
[
'H'
,
'K'
,
'V'
,
'BT'
,
'USE_G'
],
use_cuda_graph
=
use_cuda_graph
,
)
@
triton
.
jit
(
do_not_specialize
=
[
'T'
])
def
chunk_gated_delta_rule_fwd_kernel_h_blockdim64
(
k
,
v
,
w
,
v_new
,
g
,
h
,
h0
,
ht
,
cu_seqlens
,
chunk_offsets
,
T
,
H
:
tl
.
constexpr
,
Hg
:
tl
.
constexpr
,
K
:
tl
.
constexpr
,
V
:
tl
.
constexpr
,
BT
:
tl
.
constexpr
,
BV
:
tl
.
constexpr
,
USE_G
:
tl
.
constexpr
,
USE_INITIAL_STATE
:
tl
.
constexpr
,
STORE_FINAL_STATE
:
tl
.
constexpr
,
SAVE_NEW_VALUE
:
tl
.
constexpr
,
IS_VARLEN
:
tl
.
constexpr
,
):
i_v
,
i_nh
=
tl
.
program_id
(
0
),
tl
.
program_id
(
1
)
i_n
,
i_h
=
i_nh
//
H
,
i_nh
%
H
if
IS_VARLEN
:
bos
,
eos
=
tl
.
load
(
cu_seqlens
+
i_n
).
to
(
tl
.
int32
),
tl
.
load
(
cu_seqlens
+
i_n
+
1
).
to
(
tl
.
int32
)
T
=
eos
-
bos
NT
=
tl
.
cdiv
(
T
,
BT
)
boh
=
tl
.
load
(
chunk_offsets
+
i_n
).
to
(
tl
.
int32
)
else
:
bos
,
eos
=
i_n
*
T
,
i_n
*
T
+
T
NT
=
tl
.
cdiv
(
T
,
BT
)
boh
=
i_n
*
NT
# [BK, BV]
b_h1
=
tl
.
zeros
([
64
,
BV
],
dtype
=
tl
.
float32
)
if
K
>
64
:
b_h2
=
tl
.
zeros
([
64
,
BV
],
dtype
=
tl
.
float32
)
if
K
>
128
:
b_h3
=
tl
.
zeros
([
64
,
BV
],
dtype
=
tl
.
float32
)
if
K
>
192
:
b_h4
=
tl
.
zeros
([
64
,
BV
],
dtype
=
tl
.
float32
)
# calculate offset
h
+=
(
boh
*
H
+
i_h
)
*
K
*
V
v
+=
(
bos
*
H
+
i_h
)
*
V
k
+=
(
bos
*
Hg
+
i_h
//
(
H
//
Hg
))
*
K
w
+=
(
bos
*
H
+
i_h
)
*
K
if
SAVE_NEW_VALUE
:
v_new
+=
(
bos
*
H
+
i_h
)
*
V
stride_v
=
H
*
V
stride_h
=
H
*
K
*
V
stride_k
=
Hg
*
K
stride_w
=
H
*
K
if
USE_INITIAL_STATE
:
h0
=
h0
+
i_nh
*
K
*
V
if
STORE_FINAL_STATE
:
ht
=
ht
+
i_nh
*
K
*
V
# load initial state
if
USE_INITIAL_STATE
:
p_h0_1
=
tl
.
make_block_ptr
(
h0
,
(
K
,
V
),
(
V
,
1
),
(
0
,
i_v
*
BV
),
(
64
,
BV
),
(
1
,
0
))
b_h1
+=
tl
.
load
(
p_h0_1
,
boundary_check
=
(
0
,
1
)).
to
(
tl
.
float32
)
if
K
>
64
:
p_h0_2
=
tl
.
make_block_ptr
(
h0
,
(
K
,
V
),
(
V
,
1
),
(
64
,
i_v
*
BV
),
(
64
,
BV
),
(
1
,
0
))
b_h2
+=
tl
.
load
(
p_h0_2
,
boundary_check
=
(
0
,
1
)).
to
(
tl
.
float32
)
if
K
>
128
:
p_h0_3
=
tl
.
make_block_ptr
(
h0
,
(
K
,
V
),
(
V
,
1
),
(
128
,
i_v
*
BV
),
(
64
,
BV
),
(
1
,
0
))
b_h3
+=
tl
.
load
(
p_h0_3
,
boundary_check
=
(
0
,
1
)).
to
(
tl
.
float32
)
if
K
>
192
:
p_h0_4
=
tl
.
make_block_ptr
(
h0
,
(
K
,
V
),
(
V
,
1
),
(
192
,
i_v
*
BV
),
(
64
,
BV
),
(
1
,
0
))
b_h4
+=
tl
.
load
(
p_h0_4
,
boundary_check
=
(
0
,
1
)).
to
(
tl
.
float32
)
# main recurrence
for
i_t
in
range
(
NT
):
p_h1
=
tl
.
make_block_ptr
(
h
+
i_t
*
stride_h
,
(
K
,
V
),
(
V
,
1
),
(
0
,
i_v
*
BV
),
(
64
,
BV
),
(
1
,
0
))
tl
.
store
(
p_h1
,
b_h1
.
to
(
p_h1
.
dtype
.
element_ty
),
boundary_check
=
(
0
,
1
))
if
K
>
64
:
p_h2
=
tl
.
make_block_ptr
(
h
+
i_t
*
stride_h
,
(
K
,
V
),
(
V
,
1
),
(
64
,
i_v
*
BV
),
(
64
,
BV
),
(
1
,
0
))
tl
.
store
(
p_h2
,
b_h2
.
to
(
p_h2
.
dtype
.
element_ty
),
boundary_check
=
(
0
,
1
))
if
K
>
128
:
p_h3
=
tl
.
make_block_ptr
(
h
+
i_t
*
stride_h
,
(
K
,
V
),
(
V
,
1
),
(
128
,
i_v
*
BV
),
(
64
,
BV
),
(
1
,
0
))
tl
.
store
(
p_h3
,
b_h3
.
to
(
p_h3
.
dtype
.
element_ty
),
boundary_check
=
(
0
,
1
))
if
K
>
192
:
p_h4
=
tl
.
make_block_ptr
(
h
+
i_t
*
stride_h
,
(
K
,
V
),
(
V
,
1
),
(
192
,
i_v
*
BV
),
(
64
,
BV
),
(
1
,
0
))
tl
.
store
(
p_h4
,
b_h4
.
to
(
p_h4
.
dtype
.
element_ty
),
boundary_check
=
(
0
,
1
))
p_v
=
tl
.
make_block_ptr
(
v
,
(
T
,
V
),
(
stride_v
,
1
),
(
i_t
*
BT
,
i_v
*
BV
),
(
BT
,
BV
),
(
1
,
0
))
p_v_new
=
tl
.
make_block_ptr
(
v_new
,
(
T
,
V
),
(
stride_v
,
1
),
(
i_t
*
BT
,
i_v
*
BV
),
(
BT
,
BV
),
(
1
,
0
))
if
SAVE_NEW_VALUE
else
None
b_v_new
=
tl
.
zeros
([
BT
,
BV
],
dtype
=
tl
.
float32
)
p_w
=
tl
.
make_block_ptr
(
w
,
(
T
,
K
),
(
stride_w
,
1
),
(
i_t
*
BT
,
0
),
(
BT
,
64
),
(
1
,
0
))
b_w
=
tl
.
load
(
p_w
,
boundary_check
=
(
0
,
1
))
b_v_new
+=
tl
.
dot
(
b_w
,
b_h1
.
to
(
b_w
.
dtype
))
if
K
>
64
:
p_w
=
tl
.
make_block_ptr
(
w
,
(
T
,
K
),
(
stride_w
,
1
),
(
i_t
*
BT
,
64
),
(
BT
,
64
),
(
1
,
0
))
b_w
=
tl
.
load
(
p_w
,
boundary_check
=
(
0
,
1
))
b_v_new
+=
tl
.
dot
(
b_w
,
b_h2
.
to
(
b_w
.
dtype
))
if
K
>
128
:
p_w
=
tl
.
make_block_ptr
(
w
,
(
T
,
K
),
(
stride_w
,
1
),
(
i_t
*
BT
,
128
),
(
BT
,
64
),
(
1
,
0
))
b_w
=
tl
.
load
(
p_w
,
boundary_check
=
(
0
,
1
))
b_v_new
+=
tl
.
dot
(
b_w
,
b_h3
.
to
(
b_w
.
dtype
))
if
K
>
192
:
p_w
=
tl
.
make_block_ptr
(
w
,
(
T
,
K
),
(
stride_w
,
1
),
(
i_t
*
BT
,
192
),
(
BT
,
64
),
(
1
,
0
))
b_w
=
tl
.
load
(
p_w
,
boundary_check
=
(
0
,
1
))
b_v_new
+=
tl
.
dot
(
b_w
,
b_h4
.
to
(
b_w
.
dtype
))
b_v_new
=
-
b_v_new
+
tl
.
load
(
p_v
,
boundary_check
=
(
0
,
1
))
if
SAVE_NEW_VALUE
:
p_v_new
=
tl
.
make_block_ptr
(
v_new
,
(
T
,
V
),
(
stride_v
,
1
),
(
i_t
*
BT
,
i_v
*
BV
),
(
BT
,
BV
),
(
1
,
0
))
tl
.
store
(
p_v_new
,
b_v_new
.
to
(
p_v_new
.
dtype
.
element_ty
),
boundary_check
=
(
0
,
1
))
if
USE_G
:
last_idx
=
min
((
i_t
+
1
)
*
BT
,
T
)
-
1
b_g_last
=
tl
.
load
(
g
+
bos
*
H
+
last_idx
*
H
+
i_h
)
p_g
=
tl
.
make_block_ptr
(
g
+
bos
*
H
+
i_h
,
(
T
,
),
(
H
,
),
(
i_t
*
BT
,
),
(
BT
,
),
(
0
,
))
b_g
=
tl
.
load
(
p_g
,
boundary_check
=
(
0
,
))
b_v_new
=
b_v_new
*
safe_exp
(
b_g_last
-
b_g
)[:,
None
]
b_g_last
=
exp
(
b_g_last
)
b_h1
=
b_h1
*
b_g_last
if
K
>
64
:
b_h2
=
b_h2
*
b_g_last
if
K
>
128
:
b_h3
=
b_h3
*
b_g_last
if
K
>
192
:
b_h4
=
b_h4
*
b_g_last
b_v_new
=
b_v_new
.
to
(
k
.
dtype
.
element_ty
)
p_k
=
tl
.
make_block_ptr
(
k
,
(
K
,
T
),
(
1
,
stride_k
),
(
0
,
i_t
*
BT
),
(
64
,
BT
),
(
0
,
1
))
b_k
=
tl
.
load
(
p_k
,
boundary_check
=
(
0
,
1
))
b_h1
+=
tl
.
dot
(
b_k
,
b_v_new
)
if
K
>
64
:
p_k
=
tl
.
make_block_ptr
(
k
,
(
K
,
T
),
(
1
,
stride_k
),
(
64
,
i_t
*
BT
),
(
64
,
BT
),
(
0
,
1
))
b_k
=
tl
.
load
(
p_k
,
boundary_check
=
(
0
,
1
))
b_h2
+=
tl
.
dot
(
b_k
,
b_v_new
)
if
K
>
128
:
p_k
=
tl
.
make_block_ptr
(
k
,
(
K
,
T
),
(
1
,
stride_k
),
(
128
,
i_t
*
BT
),
(
64
,
BT
),
(
0
,
1
))
b_k
=
tl
.
load
(
p_k
,
boundary_check
=
(
0
,
1
))
b_h3
+=
tl
.
dot
(
b_k
,
b_v_new
)
if
K
>
192
:
p_k
=
tl
.
make_block_ptr
(
k
,
(
K
,
T
),
(
1
,
stride_k
),
(
192
,
i_t
*
BT
),
(
64
,
BT
),
(
0
,
1
))
b_k
=
tl
.
load
(
p_k
,
boundary_check
=
(
0
,
1
))
b_h4
+=
tl
.
dot
(
b_k
,
b_v_new
)
# epilogue
if
STORE_FINAL_STATE
:
p_ht
=
tl
.
make_block_ptr
(
ht
,
(
K
,
V
),
(
V
,
1
),
(
0
,
i_v
*
BV
),
(
64
,
BV
),
(
1
,
0
))
tl
.
store
(
p_ht
,
b_h1
.
to
(
p_ht
.
dtype
.
element_ty
),
boundary_check
=
(
0
,
1
))
if
K
>
64
:
p_ht
=
tl
.
make_block_ptr
(
ht
,
(
K
,
V
),
(
V
,
1
),
(
64
,
i_v
*
BV
),
(
64
,
BV
),
(
1
,
0
))
tl
.
store
(
p_ht
,
b_h2
.
to
(
p_ht
.
dtype
.
element_ty
),
boundary_check
=
(
0
,
1
))
if
K
>
128
:
p_ht
=
tl
.
make_block_ptr
(
ht
,
(
K
,
V
),
(
V
,
1
),
(
128
,
i_v
*
BV
),
(
64
,
BV
),
(
1
,
0
))
tl
.
store
(
p_ht
,
b_h3
.
to
(
p_ht
.
dtype
.
element_ty
),
boundary_check
=
(
0
,
1
))
if
K
>
192
:
p_ht
=
tl
.
make_block_ptr
(
ht
,
(
K
,
V
),
(
V
,
1
),
(
192
,
i_v
*
BV
),
(
64
,
BV
),
(
1
,
0
))
tl
.
store
(
p_ht
,
b_h4
.
to
(
p_ht
.
dtype
.
element_ty
),
boundary_check
=
(
0
,
1
))
def
chunk_gated_delta_rule_fwd_h
(
k
:
torch
.
Tensor
,
w
:
torch
.
Tensor
,
u
:
torch
.
Tensor
,
g
:
Optional
[
torch
.
Tensor
]
=
None
,
initial_state
:
Optional
[
torch
.
Tensor
]
=
None
,
output_final_state
:
bool
=
False
,
chunk_size
:
int
=
64
,
# SY: remove this argument and force chunk size 64?
save_new_value
:
bool
=
True
,
cu_seqlens
:
Optional
[
torch
.
LongTensor
]
=
None
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
B
,
T
,
Hg
,
K
,
V
=
*
k
.
shape
,
u
.
shape
[
-
1
]
H
=
u
.
shape
[
-
2
]
BT
=
chunk_size
chunk_indices
=
prepare_chunk_indices
(
cu_seqlens
,
chunk_size
)
if
cu_seqlens
is
not
None
else
None
# N: the actual number of sequences in the batch with either equal or variable lengths
if
cu_seqlens
is
None
:
N
,
NT
,
chunk_offsets
=
B
,
triton
.
cdiv
(
T
,
BT
),
None
else
:
N
,
NT
,
chunk_offsets
=
len
(
cu_seqlens
)
-
1
,
len
(
chunk_indices
),
prepare_chunk_offsets
(
cu_seqlens
,
BT
)
assert
K
<=
256
,
"current kernel does not support head dimension larger than 256."
h
=
k
.
new_empty
(
B
,
NT
,
H
,
K
,
V
)
final_state
=
k
.
new_empty
(
N
,
H
,
K
,
V
,
dtype
=
torch
.
float32
)
if
output_final_state
else
None
v_new
=
torch
.
empty_like
(
u
)
if
save_new_value
else
None
def
grid
(
meta
):
return
(
triton
.
cdiv
(
V
,
meta
[
'BV'
]),
N
*
H
)
chunk_gated_delta_rule_fwd_kernel_h_blockdim64
[
grid
](
k
=
k
,
v
=
u
,
w
=
w
,
v_new
=
v_new
,
g
=
g
,
h
=
h
,
h0
=
initial_state
,
ht
=
final_state
,
cu_seqlens
=
cu_seqlens
,
chunk_offsets
=
chunk_offsets
,
T
=
T
,
H
=
H
,
Hg
=
Hg
,
K
=
K
,
V
=
V
,
BT
=
BT
)
return
h
,
v_new
,
final_state
vllm/model_executor/layers/fla/ops/chunk_o.py
0 → 100644
View file @
1aa427fd
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# ruff: noqa: E501
from
typing
import
Optional
import
torch
from
vllm.triton_utils
import
tl
,
triton
from
.index
import
prepare_chunk_indices
from
.op
import
exp
,
safe_exp
from
.utils
import
FLA_GDN_FIX_BT
,
check_shared_mem
,
is_nvidia_hopper
BKV_LIST
=
[
64
,
128
]
if
check_shared_mem
()
else
[
32
,
64
]
NUM_WARPS
=
[
2
,
4
]
if
is_nvidia_hopper
else
[
2
,
4
,
8
]
@
triton
.
heuristics
({
'USE_G'
:
lambda
args
:
args
[
'g'
]
is
not
None
,
'IS_VARLEN'
:
lambda
args
:
args
[
'cu_seqlens'
]
is
not
None
})
@
triton
.
autotune
(
configs
=
[
triton
.
Config
({
'BK'
:
BK
,
'BV'
:
BV
},
num_warps
=
num_warps
,
num_stages
=
num_stages
)
for
BK
in
BKV_LIST
for
BV
in
BKV_LIST
for
num_warps
in
NUM_WARPS
for
num_stages
in
[
2
,
3
,
4
]
],
key
=
[
'H'
,
'K'
,
'V'
,
'BT'
],
)
@
triton
.
jit
(
do_not_specialize
=
[
'T'
])
def
chunk_fwd_kernel_o
(
q
,
k
,
v
,
h
,
g
,
o
,
cu_seqlens
,
chunk_indices
,
scale
,
T
,
H
:
tl
.
constexpr
,
Hg
:
tl
.
constexpr
,
K
:
tl
.
constexpr
,
V
:
tl
.
constexpr
,
BT
:
tl
.
constexpr
,
BK
:
tl
.
constexpr
,
BV
:
tl
.
constexpr
,
USE_G
:
tl
.
constexpr
,
IS_VARLEN
:
tl
.
constexpr
,
):
i_v
,
i_t
,
i_bh
=
tl
.
program_id
(
0
),
tl
.
program_id
(
1
),
tl
.
program_id
(
2
)
i_b
,
i_h
=
i_bh
//
H
,
i_bh
%
H
if
IS_VARLEN
:
i_tg
=
i_t
i_n
,
i_t
=
tl
.
load
(
chunk_indices
+
i_t
*
2
).
to
(
tl
.
int32
),
tl
.
load
(
chunk_indices
+
i_t
*
2
+
1
).
to
(
tl
.
int32
)
bos
,
eos
=
tl
.
load
(
cu_seqlens
+
i_n
).
to
(
tl
.
int32
),
tl
.
load
(
cu_seqlens
+
i_n
+
1
).
to
(
tl
.
int32
)
T
=
eos
-
bos
NT
=
tl
.
cdiv
(
T
,
BT
)
else
:
NT
=
tl
.
cdiv
(
T
,
BT
)
i_tg
=
i_b
*
NT
+
i_t
bos
,
eos
=
i_b
*
T
,
i_b
*
T
+
T
# offset calculation
q
+=
(
bos
*
Hg
+
i_h
//
(
H
//
Hg
))
*
K
k
+=
(
bos
*
Hg
+
i_h
//
(
H
//
Hg
))
*
K
v
+=
(
bos
*
H
+
i_h
)
*
V
o
+=
(
bos
*
H
+
i_h
)
*
V
h
+=
(
i_tg
*
H
+
i_h
).
to
(
tl
.
int64
)
*
K
*
V
b_o
=
tl
.
zeros
([
BT
,
BV
],
dtype
=
tl
.
float32
)
b_A
=
tl
.
zeros
([
BT
,
BT
],
dtype
=
tl
.
float32
)
for
i_k
in
range
(
tl
.
cdiv
(
K
,
BK
)):
p_q
=
tl
.
make_block_ptr
(
q
,
(
T
,
K
),
(
Hg
*
K
,
1
),
(
i_t
*
BT
,
i_k
*
BK
),
(
BT
,
BK
),
(
1
,
0
))
p_k
=
tl
.
make_block_ptr
(
k
,
(
K
,
T
),
(
1
,
Hg
*
K
),
(
i_k
*
BK
,
i_t
*
BT
),
(
BK
,
BT
),
(
0
,
1
))
p_h
=
tl
.
make_block_ptr
(
h
,
(
K
,
V
),
(
V
,
1
),
(
i_k
*
BK
,
i_v
*
BV
),
(
BK
,
BV
),
(
1
,
0
))
# [BT, BK]
b_q
=
tl
.
load
(
p_q
,
boundary_check
=
(
0
,
1
))
# [BK, BT]
b_k
=
tl
.
load
(
p_k
,
boundary_check
=
(
0
,
1
))
# [BK, BV]
b_h
=
tl
.
load
(
p_h
,
boundary_check
=
(
0
,
1
))
# [BT, BK] @ [BK, BV] -> [BT, BV]
b_o
+=
tl
.
dot
(
b_q
,
b_h
)
# [BT, BK] @ [BK, BT] -> [BT, BT]
b_A
+=
tl
.
dot
(
b_q
,
b_k
)
if
USE_G
:
g
+=
bos
*
H
+
i_h
p_g
=
tl
.
make_block_ptr
(
g
,
(
T
,
),
(
H
,
),
(
i_t
*
BT
,
),
(
BT
,
),
(
0
,
))
b_g
=
tl
.
load
(
p_g
,
boundary_check
=
(
0
,
))
b_o
=
b_o
*
exp
(
b_g
)[:,
None
]
b_A
=
b_A
*
safe_exp
(
b_g
[:,
None
]
-
b_g
[
None
,
:])
o_i
=
tl
.
arange
(
0
,
BT
)
m_A
=
o_i
[:,
None
]
>=
o_i
[
None
,
:]
b_A
=
tl
.
where
(
m_A
,
b_A
,
0
)
p_v
=
tl
.
make_block_ptr
(
v
,
(
T
,
V
),
(
H
*
V
,
1
),
(
i_t
*
BT
,
i_v
*
BV
),
(
BT
,
BV
),
(
1
,
0
))
p_o
=
tl
.
make_block_ptr
(
o
,
(
T
,
V
),
(
H
*
V
,
1
),
(
i_t
*
BT
,
i_v
*
BV
),
(
BT
,
BV
),
(
1
,
0
))
b_v
=
tl
.
load
(
p_v
,
boundary_check
=
(
0
,
1
))
# to fix mma -> mma layout conversion
# already solved by triton v3.2 or higher
b_o
=
b_o
*
scale
+
tl
.
dot
(
b_A
.
to
(
b_v
.
dtype
),
b_v
)
*
scale
tl
.
store
(
p_o
,
b_o
.
to
(
p_o
.
dtype
.
element_ty
),
boundary_check
=
(
0
,
1
))
def
chunk_fwd_o
(
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
h
:
torch
.
Tensor
,
g
:
Optional
[
torch
.
Tensor
]
=
None
,
# cumsum of log decay
scale
:
Optional
[
float
]
=
None
,
cu_seqlens
:
Optional
[
torch
.
LongTensor
]
=
None
,
chunk_size
:
int
=
64
)
->
torch
.
Tensor
:
B
,
T
,
Hg
,
K
,
V
=
*
q
.
shape
,
v
.
shape
[
-
1
]
H
=
v
.
shape
[
-
2
]
if
FLA_GDN_FIX_BT
:
BT
=
64
else
:
BT
=
min
(
chunk_size
,
max
(
16
,
triton
.
next_power_of_2
(
T
)))
chunk_indices
=
prepare_chunk_indices
(
cu_seqlens
,
BT
)
if
cu_seqlens
is
not
None
else
None
NT
=
triton
.
cdiv
(
T
,
BT
)
if
cu_seqlens
is
None
else
len
(
chunk_indices
)
if
scale
is
None
:
scale
=
k
.
shape
[
-
1
]
**-
0.5
o
=
torch
.
empty_like
(
v
)
def
grid
(
meta
):
return
(
triton
.
cdiv
(
V
,
meta
[
'BV'
]),
NT
,
B
*
H
)
chunk_fwd_kernel_o
[
grid
](
q
,
k
,
v
,
h
,
g
,
o
,
cu_seqlens
,
chunk_indices
,
scale
,
T
=
T
,
H
=
H
,
Hg
=
Hg
,
K
=
K
,
V
=
V
,
BT
=
BT
,
)
return
o
vllm/model_executor/layers/fla/ops/chunk_scaled_dot_kkt.py
0 → 100644
View file @
1aa427fd
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# ruff: noqa: E501
from
typing
import
Optional
import
torch
from
vllm.triton_utils
import
tl
,
triton
from
.index
import
prepare_chunk_indices
from
.op
import
safe_exp
@
triton
.
heuristics
({
'IS_VARLEN'
:
lambda
args
:
args
[
'cu_seqlens'
]
is
not
None
,
'USE_G'
:
lambda
args
:
args
[
'g_cumsum'
]
is
not
None
})
@
triton
.
autotune
(
configs
=
[
triton
.
Config
({
'BK'
:
BK
},
num_warps
=
num_warps
,
num_stages
=
num_stages
)
for
BK
in
[
32
,
64
,
128
]
for
num_warps
in
[
2
,
4
,
8
]
for
num_stages
in
[
2
,
3
,
4
]
],
key
=
[
'H'
,
'K'
,
'BT'
,
'IS_VARLEN'
],
)
@
triton
.
jit
(
do_not_specialize
=
[
'T'
])
def
chunk_scaled_dot_kkt_fwd_kernel
(
k
,
beta
,
g_cumsum
,
A
,
cu_seqlens
,
chunk_indices
,
T
,
H
:
tl
.
constexpr
,
Hg
:
tl
.
constexpr
,
K
:
tl
.
constexpr
,
BT
:
tl
.
constexpr
,
BK
:
tl
.
constexpr
,
IS_VARLEN
:
tl
.
constexpr
,
USE_G
:
tl
.
constexpr
,
):
i_t
,
i_bh
=
tl
.
program_id
(
0
),
tl
.
program_id
(
1
)
i_b
,
i_h
=
i_bh
//
H
,
i_bh
%
H
if
IS_VARLEN
:
i_n
,
i_t
=
tl
.
load
(
chunk_indices
+
i_t
*
2
).
to
(
tl
.
int32
),
tl
.
load
(
chunk_indices
+
i_t
*
2
+
1
).
to
(
tl
.
int32
)
bos
,
eos
=
tl
.
load
(
cu_seqlens
+
i_n
).
to
(
tl
.
int32
),
tl
.
load
(
cu_seqlens
+
i_n
+
1
).
to
(
tl
.
int32
)
T
=
eos
-
bos
else
:
bos
,
eos
=
i_b
*
T
,
i_b
*
T
+
T
o_t
=
tl
.
arange
(
0
,
BT
)
p_beta
=
tl
.
make_block_ptr
(
beta
+
bos
*
H
+
i_h
,
(
T
,
),
(
H
,
),
(
i_t
*
BT
,
),
(
BT
,
),
(
0
,
))
b_beta
=
tl
.
load
(
p_beta
,
boundary_check
=
(
0
,
))
b_A
=
tl
.
zeros
([
BT
,
BT
],
dtype
=
tl
.
float32
)
for
i_k
in
range
(
tl
.
cdiv
(
K
,
BK
)):
p_k
=
tl
.
make_block_ptr
(
k
+
(
bos
*
Hg
+
i_h
//
(
H
//
Hg
))
*
K
,
(
T
,
K
),
(
Hg
*
K
,
1
),
(
i_t
*
BT
,
i_k
*
BK
),
(
BT
,
BK
),
(
1
,
0
))
b_k
=
tl
.
load
(
p_k
,
boundary_check
=
(
0
,
1
))
b_kb
=
b_k
*
b_beta
[:,
None
]
b_A
+=
tl
.
dot
(
b_kb
.
to
(
b_k
.
dtype
),
tl
.
trans
(
b_k
))
if
USE_G
:
p_g
=
tl
.
make_block_ptr
(
g_cumsum
+
bos
*
H
+
i_h
,
(
T
,
),
(
H
,
),
(
i_t
*
BT
,
),
(
BT
,
),
(
0
,
))
b_g
=
tl
.
load
(
p_g
,
boundary_check
=
(
0
,
))
b_g_diff
=
b_g
[:,
None
]
-
b_g
[
None
,
:]
b_A
=
b_A
*
safe_exp
(
b_g_diff
)
b_A
=
tl
.
where
(
o_t
[:,
None
]
>
o_t
[
None
,
:],
b_A
,
0
)
p_A
=
tl
.
make_block_ptr
(
A
+
(
bos
*
H
+
i_h
)
*
BT
,
(
T
,
BT
),
(
BT
*
H
,
1
),
(
i_t
*
BT
,
0
),
(
BT
,
BT
),
(
1
,
0
))
tl
.
store
(
p_A
,
b_A
.
to
(
p_A
.
dtype
.
element_ty
),
boundary_check
=
(
0
,
1
))
def
chunk_scaled_dot_kkt_fwd
(
k
:
torch
.
Tensor
,
beta
:
torch
.
Tensor
,
g_cumsum
:
Optional
[
torch
.
Tensor
]
=
None
,
cu_seqlens
:
Optional
[
torch
.
LongTensor
]
=
None
,
chunk_size
:
int
=
64
,
output_dtype
:
torch
.
dtype
=
torch
.
float32
)
->
torch
.
Tensor
:
r
"""
Compute beta * K * K^T.
Args:
k (torch.Tensor):
The key tensor of shape `[B, T, H, K]`.
beta (torch.Tensor):
The beta tensor of shape `[B, T, H]`.
g_cumsum (torch.Tensor):
The cumulative sum of the gate tensor of shape `[B, T, H]`.
Default: None
cu_seqlens (torch.LongTensor):
The cumulative sequence lengths of the input tensor.
Default: None
chunk_size (int):
The chunk size. Default: 64.
output_dtype (torch.dtype):
The dtype of the output tensor. Default: `torch.float32`
Returns:
beta * K * K^T of shape `[B, T, H, BT]` where `BT` is the chunk size.
"""
B
,
T
,
Hg
,
K
=
k
.
shape
H
=
beta
.
shape
[
-
1
]
BT
=
chunk_size
chunk_indices
=
prepare_chunk_indices
(
cu_seqlens
,
BT
)
if
cu_seqlens
is
not
None
else
None
NT
=
triton
.
cdiv
(
T
,
BT
)
if
cu_seqlens
is
None
else
len
(
chunk_indices
)
A
=
torch
.
empty
(
B
,
T
,
H
,
BT
,
device
=
k
.
device
,
dtype
=
output_dtype
)
chunk_scaled_dot_kkt_fwd_kernel
[(
NT
,
B
*
H
)](
k
=
k
,
beta
=
beta
,
g_cumsum
=
g_cumsum
,
A
=
A
,
cu_seqlens
=
cu_seqlens
,
chunk_indices
=
chunk_indices
,
T
=
T
,
H
=
H
,
Hg
=
Hg
,
K
=
K
,
BT
=
BT
,
)
return
A
vllm/model_executor/layers/fla/ops/cumsum.py
0 → 100644
View file @
1aa427fd
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# ruff: noqa: E501
import
warnings
from
typing
import
Optional
import
torch
from
vllm.triton_utils
import
tl
,
triton
from
.index
import
prepare_chunk_indices
from
.utils
import
check_shared_mem
,
input_guard
BS_LIST
=
[
32
,
64
]
if
check_shared_mem
()
else
[
16
,
32
]
@
triton
.
heuristics
({
'IS_VARLEN'
:
lambda
args
:
args
[
'cu_seqlens'
]
is
not
None
})
@
triton
.
autotune
(
configs
=
[
triton
.
Config
({},
num_warps
=
num_warps
)
for
num_warps
in
[
1
,
2
,
4
,
8
]
],
key
=
[
'B'
,
'H'
,
'BT'
,
'IS_VARLEN'
,
'REVERSE'
])
@
triton
.
jit
(
do_not_specialize
=
[
'T'
])
def
chunk_local_cumsum_scalar_kernel
(
s
,
o
,
cu_seqlens
,
chunk_indices
,
T
,
B
:
tl
.
constexpr
,
H
:
tl
.
constexpr
,
BT
:
tl
.
constexpr
,
REVERSE
:
tl
.
constexpr
,
IS_VARLEN
:
tl
.
constexpr
,
HEAD_FIRST
:
tl
.
constexpr
,
):
i_t
,
i_bh
=
tl
.
program_id
(
0
),
tl
.
program_id
(
1
)
i_b
,
i_h
=
i_bh
//
H
,
i_bh
%
H
if
IS_VARLEN
:
i_n
,
i_t
=
tl
.
load
(
chunk_indices
+
i_t
*
2
).
to
(
tl
.
int32
),
tl
.
load
(
chunk_indices
+
i_t
*
2
+
1
).
to
(
tl
.
int32
)
bos
,
eos
=
tl
.
load
(
cu_seqlens
+
i_n
).
to
(
tl
.
int32
),
tl
.
load
(
cu_seqlens
+
i_n
+
1
).
to
(
tl
.
int32
)
T
=
eos
-
bos
else
:
bos
,
eos
=
i_b
*
T
,
i_b
*
T
+
T
if
HEAD_FIRST
:
p_s
=
tl
.
make_block_ptr
(
s
+
bos
*
H
+
i_h
*
T
,
(
T
,
),
(
1
,
),
(
i_t
*
BT
,
),
(
BT
,
),
(
0
,
))
p_o
=
tl
.
make_block_ptr
(
o
+
bos
*
H
+
i_h
*
T
,
(
T
,
),
(
1
,
),
(
i_t
*
BT
,
),
(
BT
,
),
(
0
,
))
else
:
p_s
=
tl
.
make_block_ptr
(
s
+
bos
*
H
+
i_h
,
(
T
,
),
(
H
,
),
(
i_t
*
BT
,
),
(
BT
,
),
(
0
,
))
p_o
=
tl
.
make_block_ptr
(
o
+
bos
*
H
+
i_h
,
(
T
,
),
(
H
,
),
(
i_t
*
BT
,
),
(
BT
,
),
(
0
,
))
# [BT]
b_s
=
tl
.
load
(
p_s
,
boundary_check
=
(
0
,
)).
to
(
tl
.
float32
)
b_o
=
tl
.
cumsum
(
b_s
,
axis
=
0
)
if
REVERSE
:
b_z
=
tl
.
sum
(
b_s
,
axis
=
0
)
b_o
=
-
b_o
+
b_z
[
None
]
+
b_s
tl
.
store
(
p_o
,
b_o
.
to
(
p_o
.
dtype
.
element_ty
),
boundary_check
=
(
0
,
))
@
triton
.
heuristics
({
'IS_VARLEN'
:
lambda
args
:
args
[
'cu_seqlens'
]
is
not
None
})
@
triton
.
autotune
(
configs
=
[
triton
.
Config
({
'BS'
:
BS
},
num_warps
=
num_warps
)
for
BS
in
BS_LIST
for
num_warps
in
[
2
,
4
,
8
]
],
key
=
[
'B'
,
'H'
,
'S'
,
'BT'
,
'IS_VARLEN'
,
'REVERSE'
])
@
triton
.
jit
(
do_not_specialize
=
[
'T'
])
def
chunk_local_cumsum_vector_kernel
(
s
,
o
,
cu_seqlens
,
chunk_indices
,
T
,
B
:
tl
.
constexpr
,
H
:
tl
.
constexpr
,
S
:
tl
.
constexpr
,
BT
:
tl
.
constexpr
,
BS
:
tl
.
constexpr
,
REVERSE
:
tl
.
constexpr
,
IS_VARLEN
:
tl
.
constexpr
,
HEAD_FIRST
:
tl
.
constexpr
,
):
i_s
,
i_t
,
i_bh
=
tl
.
program_id
(
0
),
tl
.
program_id
(
1
),
tl
.
program_id
(
2
)
i_b
,
i_h
=
i_bh
//
H
,
i_bh
%
H
if
IS_VARLEN
:
i_n
,
i_t
=
tl
.
load
(
chunk_indices
+
i_t
*
2
).
to
(
tl
.
int32
),
tl
.
load
(
chunk_indices
+
i_t
*
2
+
1
).
to
(
tl
.
int32
)
bos
,
eos
=
tl
.
load
(
cu_seqlens
+
i_n
).
to
(
tl
.
int32
),
tl
.
load
(
cu_seqlens
+
i_n
+
1
).
to
(
tl
.
int32
)
T
=
eos
-
bos
else
:
bos
,
eos
=
i_b
*
T
,
i_b
*
T
+
T
o_i
=
tl
.
arange
(
0
,
BT
)
if
REVERSE
:
m_s
=
tl
.
where
(
o_i
[:,
None
]
<=
o_i
[
None
,
:],
1.
,
0.
)
else
:
m_s
=
tl
.
where
(
o_i
[:,
None
]
>=
o_i
[
None
,
:],
1.
,
0.
)
if
HEAD_FIRST
:
p_s
=
tl
.
make_block_ptr
(
s
+
(
bos
*
H
+
i_h
*
T
)
*
S
,
(
T
,
S
),
(
S
,
1
),
(
i_t
*
BT
,
i_s
*
BS
),
(
BT
,
BS
),
(
1
,
0
))
p_o
=
tl
.
make_block_ptr
(
o
+
(
bos
*
H
+
i_h
*
T
)
*
S
,
(
T
,
S
),
(
S
,
1
),
(
i_t
*
BT
,
i_s
*
BS
),
(
BT
,
BS
),
(
1
,
0
))
else
:
p_s
=
tl
.
make_block_ptr
(
s
+
(
bos
*
H
+
i_h
)
*
S
,
(
T
,
S
),
(
H
*
S
,
1
),
(
i_t
*
BT
,
i_s
*
BS
),
(
BT
,
BS
),
(
1
,
0
))
p_o
=
tl
.
make_block_ptr
(
o
+
(
bos
*
H
+
i_h
)
*
S
,
(
T
,
S
),
(
H
*
S
,
1
),
(
i_t
*
BT
,
i_s
*
BS
),
(
BT
,
BS
),
(
1
,
0
))
# [BT, BS]
b_s
=
tl
.
load
(
p_s
,
boundary_check
=
(
0
,
1
)).
to
(
tl
.
float32
)
b_o
=
tl
.
dot
(
m_s
,
b_s
,
allow_tf32
=
False
)
tl
.
store
(
p_o
,
b_o
.
to
(
p_o
.
dtype
.
element_ty
),
boundary_check
=
(
0
,
1
))
def
chunk_local_cumsum_scalar
(
g
:
torch
.
Tensor
,
chunk_size
:
int
,
reverse
:
bool
=
False
,
cu_seqlens
:
Optional
[
torch
.
Tensor
]
=
None
,
head_first
:
bool
=
False
,
output_dtype
:
Optional
[
torch
.
dtype
]
=
torch
.
float
)
->
torch
.
Tensor
:
if
head_first
:
B
,
H
,
T
=
g
.
shape
else
:
B
,
T
,
H
=
g
.
shape
assert
chunk_size
==
2
**
(
chunk_size
.
bit_length
()
-
1
),
"chunk_size must be a power of 2"
BT
=
chunk_size
chunk_indices
=
prepare_chunk_indices
(
cu_seqlens
,
BT
)
if
cu_seqlens
is
not
None
else
None
NT
=
triton
.
cdiv
(
T
,
BT
)
if
cu_seqlens
is
None
else
len
(
chunk_indices
)
g_org
,
g
=
g
,
torch
.
empty_like
(
g
,
dtype
=
output_dtype
or
g
.
dtype
)
grid
=
(
NT
,
B
*
H
)
chunk_local_cumsum_scalar_kernel
[
grid
](
g_org
,
g
,
cu_seqlens
,
chunk_indices
,
T
=
T
,
B
=
B
,
H
=
H
,
BT
=
BT
,
HEAD_FIRST
=
head_first
,
REVERSE
=
reverse
)
return
g
def
chunk_local_cumsum_vector
(
g
:
torch
.
Tensor
,
chunk_size
:
int
,
reverse
:
bool
=
False
,
cu_seqlens
:
Optional
[
torch
.
Tensor
]
=
None
,
head_first
:
bool
=
False
,
output_dtype
:
Optional
[
torch
.
dtype
]
=
torch
.
float
)
->
torch
.
Tensor
:
if
head_first
:
B
,
H
,
T
,
S
=
g
.
shape
else
:
B
,
T
,
H
,
S
=
g
.
shape
BT
=
chunk_size
chunk_indices
=
prepare_chunk_indices
(
cu_seqlens
,
chunk_size
)
if
cu_seqlens
is
not
None
else
None
NT
=
triton
.
cdiv
(
T
,
BT
)
if
cu_seqlens
is
None
else
len
(
chunk_indices
)
assert
chunk_size
==
2
**
(
chunk_size
.
bit_length
()
-
1
),
"chunk_size must be a power of 2"
g_org
,
g
=
g
,
torch
.
empty_like
(
g
,
dtype
=
output_dtype
or
g
.
dtype
)
def
grid
(
meta
):
return
(
triton
.
cdiv
(
meta
[
'S'
],
meta
[
'BS'
]),
NT
,
B
*
H
)
# keep cummulative normalizer in fp32
# this kernel is equivalent to
# g = g.view(B, H, NT, BT, -1).cumsum(-2).view(B, H, T, -1)
chunk_local_cumsum_vector_kernel
[
grid
](
g_org
,
g
,
cu_seqlens
,
chunk_indices
,
T
=
T
,
B
=
B
,
H
=
H
,
S
=
S
,
BT
=
BT
,
HEAD_FIRST
=
head_first
,
REVERSE
=
reverse
)
return
g
@
input_guard
def
chunk_local_cumsum
(
g
:
torch
.
Tensor
,
chunk_size
:
int
,
reverse
:
bool
=
False
,
cu_seqlens
:
Optional
[
torch
.
Tensor
]
=
None
,
head_first
:
bool
=
False
,
output_dtype
:
Optional
[
torch
.
dtype
]
=
torch
.
float
,
**
kwargs
)
->
torch
.
Tensor
:
if
not
head_first
and
g
.
shape
[
1
]
<
g
.
shape
[
2
]:
warnings
.
warn
(
f
"Input tensor shape suggests potential format mismatch: seq_len (
{
g
.
shape
[
1
]
}
) < num_heads (
{
g
.
shape
[
2
]
}
). "
"This may indicate the inputs were passed in head-first format [B, H, T, ...] "
"when head_first=False was specified. "
"Please verify your input tensor format matches the expected shape [B, T, H, ...]."
,
stacklevel
=
2
)
if
cu_seqlens
is
not
None
:
assert
g
.
shape
[
0
]
==
1
,
"Only batch size 1 is supported when cu_seqlens are provided"
if
len
(
g
.
shape
)
==
3
:
return
chunk_local_cumsum_scalar
(
g
,
chunk_size
,
reverse
,
cu_seqlens
,
head_first
,
output_dtype
)
elif
len
(
g
.
shape
)
==
4
:
return
chunk_local_cumsum_vector
(
g
,
chunk_size
,
reverse
,
cu_seqlens
,
head_first
,
output_dtype
)
else
:
raise
ValueError
(
f
"Unsupported input shape
{
g
.
shape
}
. "
f
"which should be (B, T, H, D) if `head_first=False` "
f
"or (B, H, T, D) otherwise"
)
vllm/model_executor/layers/fla/ops/fused_recurrent.py
0 → 100644
View file @
1aa427fd
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# ruff: noqa: E501
from
typing
import
Optional
import
torch
from
vllm.triton_utils
import
tl
,
triton
from
.op
import
exp
@
triton
.
heuristics
({
'USE_INITIAL_STATE'
:
lambda
args
:
args
[
'h0'
]
is
not
None
,
'IS_VARLEN'
:
lambda
args
:
args
[
'cu_seqlens'
]
is
not
None
,
"IS_CONTINUOUS_BATCHING"
:
lambda
args
:
args
[
'ssm_state_indices'
]
is
not
None
,
"IS_SPEC_DECODING"
:
lambda
args
:
args
[
'num_accepted_tokens'
]
is
not
None
,
})
@
triton
.
jit
(
do_not_specialize
=
[
'N'
,
'T'
])
def
fused_recurrent_gated_delta_rule_fwd_kernel
(
q
,
k
,
v
,
g
,
beta
,
o
,
h0
,
ht
,
cu_seqlens
,
ssm_state_indices
,
num_accepted_tokens
,
scale
,
N
:
tl
.
constexpr
,
# num of sequences
T
:
tl
.
constexpr
,
# num of tokens
B
:
tl
.
constexpr
,
H
:
tl
.
constexpr
,
HV
:
tl
.
constexpr
,
K
:
tl
.
constexpr
,
V
:
tl
.
constexpr
,
BK
:
tl
.
constexpr
,
BV
:
tl
.
constexpr
,
stride_init_state_token
:
tl
.
constexpr
,
stride_final_state_token
:
tl
.
constexpr
,
stride_indices_seq
:
tl
.
constexpr
,
stride_indices_tok
:
tl
.
constexpr
,
USE_INITIAL_STATE
:
tl
.
constexpr
,
# whether to use initial state
INPLACE_FINAL_STATE
:
tl
.
constexpr
,
# whether to store final state inplace
IS_BETA_HEADWISE
:
tl
.
constexpr
,
# whether beta is headwise vector or scalar,
USE_QK_L2NORM_IN_KERNEL
:
tl
.
constexpr
,
IS_VARLEN
:
tl
.
constexpr
,
IS_CONTINUOUS_BATCHING
:
tl
.
constexpr
,
IS_SPEC_DECODING
:
tl
.
constexpr
,
):
i_k
,
i_v
,
i_nh
=
tl
.
program_id
(
0
),
tl
.
program_id
(
1
),
tl
.
program_id
(
2
)
i_n
,
i_hv
=
i_nh
//
HV
,
i_nh
%
HV
i_h
=
i_hv
//
(
HV
//
H
)
if
IS_VARLEN
:
bos
,
eos
=
tl
.
load
(
cu_seqlens
+
i_n
).
to
(
tl
.
int64
),
tl
.
load
(
cu_seqlens
+
i_n
+
1
).
to
(
tl
.
int64
)
all
=
T
T
=
eos
-
bos
else
:
bos
,
eos
=
i_n
*
T
,
i_n
*
T
+
T
all
=
B
*
T
if
T
==
0
:
# no tokens to process for this sequence
return
o_k
=
i_k
*
BK
+
tl
.
arange
(
0
,
BK
)
o_v
=
i_v
*
BV
+
tl
.
arange
(
0
,
BV
)
p_q
=
q
+
(
bos
*
H
+
i_h
)
*
K
+
o_k
p_k
=
k
+
(
bos
*
H
+
i_h
)
*
K
+
o_k
p_v
=
v
+
(
bos
*
HV
+
i_hv
)
*
V
+
o_v
if
IS_BETA_HEADWISE
:
p_beta
=
beta
+
(
bos
*
HV
+
i_hv
)
*
V
+
o_v
else
:
p_beta
=
beta
+
bos
*
HV
+
i_hv
p_g
=
g
+
bos
*
HV
+
i_hv
p_o
=
o
+
((
i_k
*
all
+
bos
)
*
HV
+
i_hv
)
*
V
+
o_v
mask_k
=
o_k
<
K
mask_v
=
o_v
<
V
mask_h
=
mask_k
[:,
None
]
&
mask_v
[
None
,
:]
b_h
=
tl
.
zeros
([
BK
,
BV
],
dtype
=
tl
.
float32
)
if
USE_INITIAL_STATE
:
if
IS_CONTINUOUS_BATCHING
:
if
IS_SPEC_DECODING
:
i_t
=
tl
.
load
(
num_accepted_tokens
+
i_n
).
to
(
tl
.
int64
)
-
1
else
:
i_t
=
0
p_h0
=
h0
+
tl
.
load
(
ssm_state_indices
+
i_n
*
stride_indices_seq
+
i_t
).
to
(
tl
.
int64
)
*
stride_init_state_token
else
:
p_h0
=
h0
+
bos
*
HV
*
K
*
V
p_h0
=
p_h0
+
i_hv
*
K
*
V
+
o_k
[:,
None
]
*
V
+
o_v
[
None
,
:]
b_h
+=
tl
.
load
(
p_h0
,
mask
=
mask_h
,
other
=
0
).
to
(
tl
.
float32
)
for
i_t
in
range
(
0
,
T
):
b_q
=
tl
.
load
(
p_q
,
mask
=
mask_k
,
other
=
0
).
to
(
tl
.
float32
)
b_k
=
tl
.
load
(
p_k
,
mask
=
mask_k
,
other
=
0
).
to
(
tl
.
float32
)
b_v
=
tl
.
load
(
p_v
,
mask
=
mask_v
,
other
=
0
).
to
(
tl
.
float32
)
b_g
=
tl
.
load
(
p_g
).
to
(
tl
.
float32
)
if
USE_QK_L2NORM_IN_KERNEL
:
b_q
=
b_q
/
(
tl
.
sqrt
(
tl
.
sum
(
b_q
*
b_q
))
+
1e-6
)
b_k
=
b_k
/
(
tl
.
sqrt
(
tl
.
sum
(
b_k
*
b_k
))
+
1e-6
)
b_q
=
b_q
*
scale
# [BK, BV]
b_h
*=
exp
(
b_g
)
# [BV]
b_v
-=
tl
.
sum
(
b_h
*
b_k
[:,
None
],
0
)
if
IS_BETA_HEADWISE
:
b_beta
=
tl
.
load
(
p_beta
,
mask
=
mask_v
,
other
=
0
).
to
(
tl
.
float32
)
else
:
b_beta
=
tl
.
load
(
p_beta
).
to
(
tl
.
float32
)
b_v
*=
b_beta
# [BK, BV]
b_h
+=
b_k
[:,
None
]
*
b_v
[
None
,
:]
# [BV]
b_o
=
tl
.
sum
(
b_h
*
b_q
[:,
None
],
0
)
tl
.
store
(
p_o
,
b_o
.
to
(
p_o
.
dtype
.
element_ty
),
mask
=
mask_v
)
# keep the states for multi-query tokens
if
INPLACE_FINAL_STATE
:
p_ht
=
ht
+
tl
.
load
(
ssm_state_indices
+
i_n
*
stride_indices_seq
+
i_t
).
to
(
tl
.
int64
)
*
stride_final_state_token
else
:
p_ht
=
ht
+
(
bos
+
i_t
)
*
stride_final_state_token
p_ht
=
p_ht
+
i_hv
*
K
*
V
+
o_k
[:,
None
]
*
V
+
o_v
[
None
,
:]
tl
.
store
(
p_ht
,
b_h
.
to
(
p_ht
.
dtype
.
element_ty
),
mask
=
mask_h
)
p_q
+=
H
*
K
p_k
+=
H
*
K
p_o
+=
HV
*
V
p_v
+=
HV
*
V
p_g
+=
HV
p_beta
+=
HV
*
(
V
if
IS_BETA_HEADWISE
else
1
)
def
fused_recurrent_gated_delta_rule_fwd
(
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
g
:
torch
.
Tensor
,
beta
:
torch
.
Tensor
,
scale
:
float
,
initial_state
:
torch
.
Tensor
,
inplace_final_state
:
bool
=
True
,
cu_seqlens
:
Optional
[
torch
.
LongTensor
]
=
None
,
ssm_state_indices
:
Optional
[
torch
.
Tensor
]
=
None
,
num_accepted_tokens
:
Optional
[
torch
.
Tensor
]
=
None
,
use_qk_l2norm_in_kernel
:
bool
=
False
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
B
,
T
,
H
,
K
,
V
=
*
k
.
shape
,
v
.
shape
[
-
1
]
HV
=
v
.
shape
[
2
]
N
=
B
if
cu_seqlens
is
None
else
len
(
cu_seqlens
)
-
1
BK
,
BV
=
triton
.
next_power_of_2
(
K
),
min
(
triton
.
next_power_of_2
(
V
),
8
)
NK
,
NV
=
triton
.
cdiv
(
K
,
BK
),
triton
.
cdiv
(
V
,
BV
)
assert
NK
==
1
,
"NK > 1 is not supported yet"
num_stages
=
3
num_warps
=
1
o
=
q
.
new_empty
(
NK
,
*
v
.
shape
)
if
inplace_final_state
:
final_state
=
initial_state
else
:
final_state
=
q
.
new_empty
(
T
,
HV
,
K
,
V
,
dtype
=
initial_state
.
dtype
)
stride_init_state_token
=
initial_state
.
stride
(
0
)
stride_final_state_token
=
final_state
.
stride
(
0
)
if
ssm_state_indices
is
None
:
stride_indices_seq
,
stride_indices_tok
=
1
,
1
elif
ssm_state_indices
.
ndim
==
1
:
stride_indices_seq
,
stride_indices_tok
=
ssm_state_indices
.
stride
(
0
),
1
else
:
stride_indices_seq
,
stride_indices_tok
=
ssm_state_indices
.
stride
()
grid
=
(
NK
,
NV
,
N
*
HV
)
fused_recurrent_gated_delta_rule_fwd_kernel
[
grid
](
q
=
q
,
k
=
k
,
v
=
v
,
g
=
g
,
beta
=
beta
,
o
=
o
,
h0
=
initial_state
,
ht
=
final_state
,
cu_seqlens
=
cu_seqlens
,
ssm_state_indices
=
ssm_state_indices
,
num_accepted_tokens
=
num_accepted_tokens
,
scale
=
scale
,
N
=
N
,
T
=
T
,
B
=
B
,
H
=
H
,
HV
=
HV
,
K
=
K
,
V
=
V
,
BK
=
BK
,
BV
=
BV
,
stride_init_state_token
=
stride_init_state_token
,
stride_final_state_token
=
stride_final_state_token
,
stride_indices_seq
=
stride_indices_seq
,
stride_indices_tok
=
stride_indices_tok
,
IS_BETA_HEADWISE
=
beta
.
ndim
==
v
.
ndim
,
USE_QK_L2NORM_IN_KERNEL
=
use_qk_l2norm_in_kernel
,
INPLACE_FINAL_STATE
=
inplace_final_state
,
num_warps
=
num_warps
,
num_stages
=
num_stages
,
)
o
=
o
.
squeeze
(
0
)
return
o
,
final_state
class
FusedRecurrentFunction
(
torch
.
autograd
.
Function
):
@
staticmethod
def
forward
(
ctx
,
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
g
:
torch
.
Tensor
,
beta
:
torch
.
Tensor
,
scale
:
float
,
initial_state
:
torch
.
Tensor
,
inplace_final_state
:
bool
=
True
,
cu_seqlens
:
Optional
[
torch
.
LongTensor
]
=
None
,
ssm_state_indices
:
Optional
[
torch
.
Tensor
]
=
None
,
num_accepted_tokens
:
Optional
[
torch
.
Tensor
]
=
None
,
use_qk_l2norm_in_kernel
:
bool
=
False
):
o
,
final_state
=
fused_recurrent_gated_delta_rule_fwd
(
q
=
q
.
contiguous
(),
k
=
k
.
contiguous
(),
v
=
v
.
contiguous
(),
g
=
g
.
contiguous
(),
beta
=
beta
.
contiguous
(),
scale
=
scale
,
initial_state
=
initial_state
,
inplace_final_state
=
inplace_final_state
,
cu_seqlens
=
cu_seqlens
,
ssm_state_indices
=
ssm_state_indices
,
num_accepted_tokens
=
num_accepted_tokens
,
use_qk_l2norm_in_kernel
=
use_qk_l2norm_in_kernel
,
)
return
o
,
final_state
def
fused_recurrent_gated_delta_rule
(
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
g
:
torch
.
Tensor
,
beta
:
torch
.
Tensor
=
None
,
scale
:
float
=
None
,
initial_state
:
torch
.
Tensor
=
None
,
inplace_final_state
:
bool
=
True
,
cu_seqlens
:
Optional
[
torch
.
LongTensor
]
=
None
,
ssm_state_indices
:
Optional
[
torch
.
Tensor
]
=
None
,
num_accepted_tokens
:
Optional
[
torch
.
Tensor
]
=
None
,
use_qk_l2norm_in_kernel
:
bool
=
False
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
r
"""
Args:
q (torch.Tensor):
queries of shape `[B, T, H, K]`.
k (torch.Tensor):
keys of shape `[B, T, H, K]`.
v (torch.Tensor):
values of shape `[B, T, HV, V]`.
GVA is applied if `HV > H`.
g (torch.Tensor):
g (decays) of shape `[B, T, HV]`.
beta (torch.Tensor):
betas of shape `[B, T, HV]`.
scale (Optional[int]):
Scale factor for the RetNet attention scores.
If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
initial_state (Optional[torch.Tensor]):
Initial state of shape `[N, HV, K, V]` for `N` input sequences.
For equal-length input sequences, `N` equals the batch size `B`.
Default: `None`.
inplace_final_state: bool:
Whether to store the final state in-place to save memory.
Default: `True`.
cu_seqlens (torch.LongTensor):
Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
consistent with the FlashAttention API.
ssm_state_indices (Optional[torch.Tensor]):
Indices to map the input sequences to the initial/final states.
num_accepted_tokens (Optional[torch.Tensor]):
Number of accepted tokens for each sequence during decoding.
Returns:
o (torch.Tensor):
Outputs of shape `[B, T, HV, V]`.
final_state (torch.Tensor):
Final state of shape `[N, HV, K, V]`.
Examples::
>>> import torch
>>> import torch.nn.functional as F
>>> from einops import rearrange
>>> from fla.ops.gated_delta_rule import fused_recurrent_gated_delta_rule
# inputs with equal lengths
>>> B, T, H, HV, K, V = 4, 2048, 4, 8, 512, 512
>>> q = torch.randn(B, T, H, K, device='cuda')
>>> k = F.normalize(torch.randn(B, T, H, K, device='cuda'), p=2, dim=-1)
>>> v = torch.randn(B, T, HV, V, device='cuda')
>>> g = F.logsigmoid(torch.rand(B, T, HV, device='cuda'))
>>> beta = torch.rand(B, T, HV, device='cuda').sigmoid()
>>> h0 = torch.randn(B, HV, K, V, device='cuda')
>>> o, ht = fused_gated_recurrent_delta_rule(
q, k, v, g, beta,
initial_state=h0,
)
# for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
>>> q, k, v, g, beta = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, g, beta))
# for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
>>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
>>> o_var, ht_var = fused_gated_recurrent_delta_rule(
q, k, v, g, beta,
initial_state=h0,
cu_seqlens=cu_seqlens
)
"""
if
cu_seqlens
is
not
None
and
q
.
shape
[
0
]
!=
1
:
raise
ValueError
(
f
"The batch size is expected to be 1 rather than
{
q
.
shape
[
0
]
}
when using `cu_seqlens`."
f
"Please flatten variable-length inputs before processing."
)
if
scale
is
None
:
scale
=
k
.
shape
[
-
1
]
**-
0.5
else
:
assert
scale
>
0
,
"scale must be positive"
if
beta
is
None
:
beta
=
torch
.
ones_like
(
q
[...,
0
])
o
,
final_state
=
FusedRecurrentFunction
.
apply
(
q
,
k
,
v
,
g
,
beta
,
scale
,
initial_state
,
inplace_final_state
,
cu_seqlens
,
ssm_state_indices
,
num_accepted_tokens
,
use_qk_l2norm_in_kernel
,
)
return
o
,
final_state
vllm/model_executor/layers/fla/ops/index.py
0 → 100644
View file @
1aa427fd
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# ruff: noqa: E501
import
torch
from
vllm.triton_utils
import
triton
from
.utils
import
tensor_cache
@
tensor_cache
def
prepare_lens
(
cu_seqlens
:
torch
.
LongTensor
)
->
torch
.
LongTensor
:
return
cu_seqlens
[
1
:]
-
cu_seqlens
[:
-
1
]
@
tensor_cache
def
prepare_chunk_indices
(
cu_seqlens
:
torch
.
LongTensor
,
chunk_size
:
int
)
->
torch
.
LongTensor
:
indices
=
torch
.
cat
([
torch
.
arange
(
n
)
for
n
in
triton
.
cdiv
(
prepare_lens
(
cu_seqlens
),
chunk_size
).
tolist
()
])
return
torch
.
stack
([
indices
.
eq
(
0
).
cumsum
(
0
)
-
1
,
indices
],
1
).
to
(
cu_seqlens
)
@
tensor_cache
def
prepare_chunk_offsets
(
cu_seqlens
:
torch
.
LongTensor
,
chunk_size
:
int
)
->
torch
.
LongTensor
:
return
torch
.
cat
([
cu_seqlens
.
new_tensor
([
0
]),
triton
.
cdiv
(
prepare_lens
(
cu_seqlens
),
chunk_size
)
]).
cumsum
(
-
1
)
vllm/model_executor/layers/fla/ops/l2norm.py
0 → 100644
View file @
1aa427fd
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
import
os
from
typing
import
Optional
import
torch
from
vllm.triton_utils
import
tl
,
triton
BT_LIST
=
[
8
,
16
,
32
,
64
,
128
]
USE_DEFAULT_FLA_NORM
=
int
(
os
.
getenv
(
"USE_DEFAULT_FLA_NORM"
,
"0"
))
@
triton
.
autotune
(
configs
=
[
triton
.
Config
({},
num_warps
=
num_warps
)
for
num_warps
in
[
1
,
2
,
4
,
8
,
16
,
32
]
],
key
=
[
'D'
])
@
triton
.
jit
def
l2norm_fwd_kernel1
(
x
,
y
,
D
,
BD
:
tl
.
constexpr
,
eps
,
):
i_t
=
tl
.
program_id
(
0
)
x
+=
i_t
*
D
y
+=
i_t
*
D
# Compute mean and variance
cols
=
tl
.
arange
(
0
,
BD
)
mask
=
cols
<
D
b_x
=
tl
.
load
(
x
+
cols
,
mask
=
mask
,
other
=
0.0
).
to
(
tl
.
float32
)
b_var
=
tl
.
sum
(
b_x
*
b_x
,
axis
=
0
)
b_rstd
=
1
/
tl
.
sqrt
(
b_var
+
eps
)
# tl.store(Rstd + i_t, rstd)
# Normalize and apply linear transformation
b_y
=
b_x
*
b_rstd
tl
.
store
(
y
+
cols
,
b_y
,
mask
=
mask
)
@
triton
.
autotune
(
configs
=
[
triton
.
Config
({
'BT'
:
BT
},
num_warps
=
num_warps
)
for
num_warps
in
[
1
,
2
,
4
,
8
,
16
]
for
BT
in
BT_LIST
],
key
=
[
'D'
])
@
triton
.
jit
(
do_not_specialize
=
[
"NB"
])
def
l2norm_fwd_kernel
(
x
,
y
,
eps
,
NB
,
T
,
D
:
tl
.
constexpr
,
BT
:
tl
.
constexpr
,
BD
:
tl
.
constexpr
,
):
i_t
=
tl
.
program_id
(
0
)
p_x
=
tl
.
make_block_ptr
(
x
,
(
T
,
D
),
(
D
,
1
),
(
i_t
*
BT
,
0
),
(
BT
,
BD
),
(
1
,
0
))
b_x
=
tl
.
load
(
p_x
,
boundary_check
=
(
0
,
1
)).
to
(
tl
.
float32
)
b_var
=
tl
.
sum
(
b_x
*
b_x
,
axis
=
1
)
b_y
=
b_x
/
tl
.
sqrt
(
b_var
+
eps
)[:,
None
]
p_y
=
tl
.
make_block_ptr
(
y
,
(
T
,
D
),
(
D
,
1
),
(
i_t
*
BT
,
0
),
(
BT
,
BD
),
(
1
,
0
))
tl
.
store
(
p_y
,
b_y
.
to
(
p_y
.
dtype
.
element_ty
),
boundary_check
=
(
0
,
1
))
@
triton
.
jit
def
l2norm_fwd_kernel2
(
X
,
Y
,
eps
,
M
,
N
:
tl
.
constexpr
,
MBLOCK
:
tl
.
constexpr
):
xoffset
=
tl
.
program_id
(
0
)
*
MBLOCK
row_idx
=
xoffset
+
tl
.
arange
(
0
,
MBLOCK
)[:,
None
]
xmask
=
row_idx
<
M
rindex
=
tl
.
arange
(
0
,
N
)[
None
,
:]
xs
=
tl
.
load
(
X
+
(
rindex
+
N
*
row_idx
),
None
).
to
(
tl
.
float32
)
square
=
tl
.
broadcast_to
(
xs
*
xs
,
[
MBLOCK
,
N
])
square_sum
=
tl
.
sum
(
tl
.
where
(
xmask
,
square
,
0
),
1
)[:,
None
]
rsqrt
=
tl
.
rsqrt
(
square_sum
+
eps
)
tl
.
store
(
Y
+
(
rindex
+
N
*
row_idx
),
xs
*
rsqrt
,
xmask
)
def
l2norm_fwd
(
x
:
torch
.
Tensor
,
eps
:
float
=
1e-6
,
output_dtype
:
Optional
[
torch
.
dtype
]
=
None
):
x_shape_og
=
x
.
shape
x
=
x
.
view
(
-
1
,
x
.
shape
[
-
1
])
# allocate output
if
output_dtype
is
None
:
y
=
torch
.
empty_like
(
x
)
else
:
y
=
torch
.
empty_like
(
x
,
dtype
=
output_dtype
)
assert
y
.
stride
(
-
1
)
==
1
T
,
D
=
x
.
shape
[
0
],
x
.
shape
[
-
1
]
# rstd = torch.empty((T,), dtype=torch.float32, device=x.device)
# Less than 64KB per feature: enqueue fused kernel
MAX_FUSED_SIZE
=
65536
//
x
.
element_size
()
BD
=
min
(
MAX_FUSED_SIZE
,
triton
.
next_power_of_2
(
D
))
if
D
>
BD
:
raise
RuntimeError
(
"This layer doesn't support feature dim >= 64KB."
)
if
not
USE_DEFAULT_FLA_NORM
:
MBLOCK
=
32
# M, N = x.shape
l2norm_fwd_kernel2
[(
triton
.
cdiv
(
T
,
MBLOCK
),
)](
x
,
y
,
eps
,
T
,
D
,
MBLOCK
,
)
else
:
if
D
<=
512
:
NB
=
triton
.
cdiv
(
T
,
2048
)
def
grid
(
meta
):
return
(
triton
.
cdiv
(
T
,
meta
[
'BT'
]),
)
l2norm_fwd_kernel
[
grid
](
x
,
y
,
eps
,
NB
=
NB
,
T
=
T
,
D
=
D
,
BD
=
BD
,
)
else
:
l2norm_fwd_kernel1
[(
T
,
)](
x
,
y
,
eps
=
eps
,
D
=
D
,
BD
=
BD
,
)
return
y
.
view
(
x_shape_og
)
vllm/model_executor/layers/fla/ops/layernorm_guard.py
0 → 100644
View file @
1aa427fd
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Tri Dao
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2024, Tri Dao.
# ruff: noqa: E501
# Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
# For the backward pass, we keep weight_grad and bias_grad in registers and accumulate.
# This backward pass is faster for dimensions up to 8k, but after that it's much slower due to register spilling.
# The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.
from
typing
import
Optional
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
einops
import
rearrange
from
vllm.triton_utils
import
tl
,
triton
from
.utils
import
input_guard
def
rms_norm_ref
(
x
,
weight
,
bias
,
z
=
None
,
eps
=
1e-6
,
group_size
=
None
,
norm_before_gate
=
True
,
upcast
=
True
):
dtype
=
x
.
dtype
weight
=
weight
.
float
()
bias
=
bias
.
float
()
if
bias
is
not
None
else
None
if
upcast
:
x
=
x
.
float
()
z
=
z
.
float
()
if
z
is
not
None
else
z
if
z
is
not
None
and
not
norm_before_gate
:
x
=
x
*
F
.
silu
(
z
)
if
group_size
is
None
:
rstd
=
1
/
torch
.
sqrt
((
x
.
square
()).
mean
(
dim
=-
1
,
keepdim
=
True
)
+
eps
)
out
=
(
x
*
rstd
*
weight
)
+
bias
if
bias
is
not
None
else
(
x
*
rstd
*
weight
)
else
:
x_group
=
rearrange
(
x
,
"... (g d) -> ... g d"
,
d
=
group_size
)
rstd
=
1
/
torch
.
sqrt
((
x_group
.
square
()).
mean
(
dim
=-
1
,
keepdim
=
True
)
+
eps
)
out
=
rearrange
(
x_group
*
rstd
,
"... g d -> ... (g d)"
)
*
weight
if
bias
is
not
None
:
out
=
out
+
bias
if
z
is
not
None
and
norm_before_gate
:
out
*=
F
.
silu
(
z
)
return
out
.
to
(
dtype
)
@
triton
.
heuristics
({
"HAS_BIAS"
:
lambda
args
:
args
[
"B"
]
is
not
None
,
"HAS_Z"
:
lambda
args
:
args
[
"Z"
]
is
not
None
,
})
@
triton
.
jit
def
layer_norm_fwd_kernel
(
X
,
# pointer to the input
Y
,
# pointer to the output
W
,
# pointer to the weights
B
,
# pointer to the biases
Z
,
# pointer to the other branch
Mean
,
# pointer to the mean
Rstd
,
# pointer to the 1/std
stride_x_row
,
# how much to increase the pointer when moving by 1 row
stride_y_row
,
stride_z_row
,
M
,
# number of rows in X
N
,
# number of columns in X
eps
,
# epsilon to avoid division by zero
BLOCK_N
:
tl
.
constexpr
,
HAS_BIAS
:
tl
.
constexpr
,
HAS_Z
:
tl
.
constexpr
,
NORM_BEFORE_GATE
:
tl
.
constexpr
,
IS_RMS_NORM
:
tl
.
constexpr
,
):
# Map the program id to the row of X and Y it should compute.
row
=
tl
.
program_id
(
0
)
group
=
tl
.
program_id
(
1
)
X
+=
row
*
stride_x_row
+
group
*
N
Y
+=
row
*
stride_y_row
+
group
*
N
if
HAS_Z
:
Z
+=
row
*
stride_z_row
+
group
*
N
if
not
IS_RMS_NORM
:
Mean
+=
group
*
M
Rstd
+=
group
*
M
W
+=
group
*
N
if
HAS_BIAS
:
B
+=
group
*
N
# Compute mean and variance
cols
=
tl
.
arange
(
0
,
BLOCK_N
)
x
=
tl
.
load
(
X
+
cols
,
mask
=
cols
<
N
,
other
=
0.
).
to
(
tl
.
float32
)
if
HAS_Z
and
not
NORM_BEFORE_GATE
:
z
=
tl
.
load
(
Z
+
cols
,
mask
=
cols
<
N
).
to
(
tl
.
float32
)
x
*=
z
*
tl
.
sigmoid
(
z
)
if
not
IS_RMS_NORM
:
mean
=
tl
.
sum
(
x
,
axis
=
0
)
/
N
tl
.
store
(
Mean
+
row
,
mean
)
xbar
=
tl
.
where
(
cols
<
N
,
x
-
mean
,
0.
)
var
=
tl
.
sum
(
xbar
*
xbar
,
axis
=
0
)
/
N
else
:
xbar
=
tl
.
where
(
cols
<
N
,
x
,
0.
)
var
=
tl
.
sum
(
xbar
*
xbar
,
axis
=
0
)
/
N
rstd
=
1
/
tl
.
sqrt
(
var
+
eps
)
tl
.
store
(
Rstd
+
row
,
rstd
)
# Normalize and apply linear transformation
mask
=
cols
<
N
w
=
tl
.
load
(
W
+
cols
,
mask
=
mask
).
to
(
tl
.
float32
)
if
HAS_BIAS
:
b
=
tl
.
load
(
B
+
cols
,
mask
=
mask
).
to
(
tl
.
float32
)
x_hat
=
(
x
-
mean
)
*
rstd
if
not
IS_RMS_NORM
else
x
*
rstd
y
=
x_hat
*
w
+
b
if
HAS_BIAS
else
x_hat
*
w
if
HAS_Z
and
NORM_BEFORE_GATE
:
z
=
tl
.
load
(
Z
+
cols
,
mask
=
mask
).
to
(
tl
.
float32
)
y
*=
z
*
tl
.
sigmoid
(
z
)
# Write output
tl
.
store
(
Y
+
cols
,
y
,
mask
=
mask
)
def
layer_norm_fwd
(
x
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
bias
:
torch
.
Tensor
,
eps
:
float
,
z
:
torch
.
Tensor
=
None
,
out
:
torch
.
Tensor
=
None
,
group_size
:
int
=
None
,
norm_before_gate
:
bool
=
True
,
is_rms_norm
:
bool
=
False
,
):
M
,
N
=
x
.
shape
if
group_size
is
None
:
group_size
=
N
assert
N
%
group_size
==
0
ngroups
=
N
//
group_size
assert
x
.
stride
(
-
1
)
==
1
if
z
is
not
None
:
assert
z
.
stride
(
-
1
)
==
1
assert
z
.
shape
==
(
M
,
N
)
assert
weight
.
shape
==
(
N
,
)
assert
weight
.
stride
(
-
1
)
==
1
if
bias
is
not
None
:
assert
bias
.
stride
(
-
1
)
==
1
assert
bias
.
shape
==
(
N
,
)
# allocate output
if
out
is
not
None
:
assert
out
.
shape
==
x
.
shape
else
:
out
=
torch
.
empty_like
(
x
)
assert
out
.
stride
(
-
1
)
==
1
mean
=
torch
.
empty
((
ngroups
*
M
,
),
dtype
=
torch
.
float32
,
device
=
x
.
device
)
if
not
is_rms_norm
else
None
rstd
=
torch
.
empty
((
ngroups
*
M
,
),
dtype
=
torch
.
float32
,
device
=
x
.
device
)
# Less than 64KB per feature: enqueue fused kernel
MAX_FUSED_SIZE
=
65536
//
x
.
element_size
()
BLOCK_N
=
min
(
MAX_FUSED_SIZE
,
triton
.
next_power_of_2
(
group_size
))
if
group_size
>
BLOCK_N
:
raise
RuntimeError
(
"This layer norm doesn't support feature dim >= 64KB."
)
# heuristics for number of warps
num_warps
=
min
(
max
(
BLOCK_N
//
256
,
1
),
8
)
grid
=
(
M
,
ngroups
)
layer_norm_fwd_kernel
[
grid
](
x
,
out
,
weight
,
bias
,
z
,
mean
,
rstd
,
x
.
stride
(
0
),
out
.
stride
(
0
),
z
.
stride
(
0
)
if
z
is
not
None
else
0
,
M
,
group_size
,
eps
,
BLOCK_N
=
BLOCK_N
,
NORM_BEFORE_GATE
=
norm_before_gate
,
IS_RMS_NORM
=
is_rms_norm
,
num_warps
=
num_warps
)
return
out
,
mean
,
rstd
class
LayerNormFn
(
torch
.
autograd
.
Function
):
@
input_guard
@
staticmethod
def
forward
(
ctx
,
x
,
weight
,
bias
,
z
=
None
,
eps
=
1e-6
,
group_size
=
None
,
norm_before_gate
=
True
,
is_rms_norm
=
False
):
"""If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))
"""
x_shape_og
=
x
.
shape
# reshape input data into 2D tensor
x
=
x
.
reshape
(
-
1
,
x
.
shape
[
-
1
])
if
x
.
stride
(
-
1
)
!=
1
:
x
=
x
.
contiguous
()
if
z
is
not
None
:
assert
z
.
shape
==
x_shape_og
z
=
z
.
reshape
(
-
1
,
z
.
shape
[
-
1
])
if
z
.
stride
(
-
1
)
!=
1
:
z
=
z
.
contiguous
()
weight
=
weight
.
contiguous
()
if
bias
is
not
None
:
bias
=
bias
.
contiguous
()
y
,
mean
,
rstd
=
layer_norm_fwd
(
x
,
weight
,
bias
,
eps
,
z
=
z
,
group_size
=
group_size
,
norm_before_gate
=
norm_before_gate
,
is_rms_norm
=
is_rms_norm
,
)
ctx
.
save_for_backward
(
x
,
weight
,
bias
,
mean
,
rstd
,
z
)
ctx
.
x_shape_og
=
x_shape_og
ctx
.
eps
=
eps
ctx
.
group_size
=
group_size
ctx
.
norm_before_gate
=
norm_before_gate
ctx
.
is_rms_norm
=
is_rms_norm
return
y
.
reshape
(
x_shape_og
)
def
layernorm_fn
(
x
,
weight
,
bias
,
z
=
None
,
eps
=
1e-6
,
group_size
=
None
,
norm_before_gate
=
True
,
is_rms_norm
=
False
):
return
LayerNormFn
.
apply
(
x
,
weight
,
bias
,
z
,
eps
,
group_size
,
norm_before_gate
,
is_rms_norm
)
def
rmsnorm_fn
(
x
,
weight
,
bias
,
z
=
None
,
eps
=
1e-6
,
group_size
=
None
,
norm_before_gate
=
True
):
return
LayerNormFn
.
apply
(
x
,
weight
,
bias
,
z
,
eps
,
group_size
,
norm_before_gate
,
True
)
class
LayerNormGated
(
nn
.
Module
):
def
__init__
(
self
,
hidden_size
,
eps
:
float
=
1e-5
,
group_size
:
Optional
[
int
]
=
None
,
norm_before_gate
:
bool
=
True
,
device
:
Optional
[
torch
.
device
]
=
None
,
dtype
:
Optional
[
torch
.
dtype
]
=
None
,
):
"""If group_size is not None, we do GroupNorm with each group having group_size elements.
group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group).
"""
factory_kwargs
=
{
"device"
:
device
,
"dtype"
:
dtype
}
super
().
__init__
()
self
.
eps
=
eps
self
.
weight
=
nn
.
Parameter
(
torch
.
empty
(
hidden_size
,
**
factory_kwargs
))
self
.
bias
=
nn
.
Parameter
(
torch
.
empty
(
hidden_size
,
**
factory_kwargs
))
self
.
group_size
=
group_size
self
.
norm_before_gate
=
norm_before_gate
self
.
reset_parameters
()
def
reset_parameters
(
self
):
torch
.
nn
.
init
.
ones_
(
self
.
weight
)
torch
.
nn
.
init
.
zeros_
(
self
.
bias
)
def
forward
(
self
,
x
,
z
=
None
):
"""If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))
"""
return
layernorm_fn
(
x
,
self
.
weight
,
self
.
bias
,
z
=
z
,
group_size
=
self
.
group_size
,
eps
=
self
.
eps
,
norm_before_gate
=
self
.
norm_before_gate
)
class
RMSNormGated
(
nn
.
Module
):
def
__init__
(
self
,
hidden_size
,
eps
:
float
=
1e-5
,
group_size
:
Optional
[
int
]
=
None
,
norm_before_gate
:
bool
=
False
,
device
:
Optional
[
torch
.
device
]
=
None
,
dtype
:
Optional
[
torch
.
dtype
]
=
None
,
):
"""If group_size is not None, we do GroupNorm with each group having group_size elements.
group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group).
"""
factory_kwargs
=
{
"device"
:
device
,
"dtype"
:
dtype
}
super
().
__init__
()
self
.
eps
=
eps
self
.
weight
=
nn
.
Parameter
(
torch
.
empty
(
hidden_size
,
**
factory_kwargs
))
self
.
register_parameter
(
"bias"
,
None
)
self
.
group_size
=
group_size
self
.
norm_before_gate
=
norm_before_gate
self
.
reset_parameters
()
def
reset_parameters
(
self
):
torch
.
nn
.
init
.
ones_
(
self
.
weight
)
def
forward
(
self
,
x
,
z
=
None
):
"""If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))
"""
return
rmsnorm_fn
(
x
,
self
.
weight
,
self
.
bias
,
z
=
z
,
eps
=
self
.
eps
,
group_size
=
self
.
group_size
,
norm_before_gate
=
self
.
norm_before_gate
)
vllm/model_executor/layers/fla/ops/op.py
0 → 100644
View file @
1aa427fd
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
import
os
from
vllm.triton_utils
import
tl
,
tldevice
,
triton
if
os
.
environ
.
get
(
'FLA_USE_FAST_OPS'
,
'0'
)
==
'1'
:
div
=
tldevice
.
fast_dividef
exp
=
tldevice
.
fast_expf
log
=
tldevice
.
fast_logf
log2
=
tldevice
.
fast_log2f
else
:
@
triton
.
jit
def
div_normal
(
x
,
y
):
return
x
/
y
div
=
div_normal
exp
=
tl
.
exp
log
=
tl
.
log
log2
=
tl
.
log2
@
triton
.
jit
def
safe_exp
(
x
):
return
exp
(
tl
.
where
(
x
<=
0
,
x
,
float
(
'-inf'
)))
if
not
hasattr
(
tl
,
'gather'
):
@
triton
.
jit
def
gather
(
src
,
index
,
axis
,
_builder
=
None
):
# This is a fallback implementation when tl.gather is not supported
# In order to pass triton compiler, there is no actual gather operation
return
src
else
:
gather
=
tl
.
gather
vllm/model_executor/layers/fla/ops/solve_tril.py
0 → 100644
View file @
1aa427fd
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# ruff: noqa: E501
from
typing
import
Optional
import
torch
from
vllm.triton_utils
import
tl
,
triton
from
.index
import
prepare_chunk_indices
from
.utils
import
input_guard
@
triton
.
heuristics
({
'IS_VARLEN'
:
lambda
args
:
args
[
'cu_seqlens'
]
is
not
None
})
@
triton
.
autotune
(
configs
=
[
triton
.
Config
({},
num_warps
=
num_warps
,
num_stages
=
num_stages
)
for
num_warps
in
[
1
,
2
,
4
,
8
]
for
num_stages
in
[
2
,
3
,
4
,
5
]
],
key
=
[
'BT'
],
)
@
triton
.
jit
(
do_not_specialize
=
[
'T'
])
def
solve_tril_16x16_kernel
(
A
,
Ad
,
cu_seqlens
,
chunk_indices
,
T
,
H
:
tl
.
constexpr
,
BT
:
tl
.
constexpr
,
IS_VARLEN
:
tl
.
constexpr
,
):
i_t
,
i_bh
=
tl
.
program_id
(
0
),
tl
.
program_id
(
1
)
i_b
,
i_h
=
i_bh
//
H
,
i_bh
%
H
if
IS_VARLEN
:
i_n
,
i_t
=
tl
.
load
(
chunk_indices
+
i_t
*
2
).
to
(
tl
.
int32
),
tl
.
load
(
chunk_indices
+
i_t
*
2
+
1
).
to
(
tl
.
int32
)
bos
,
eos
=
tl
.
load
(
cu_seqlens
+
i_n
).
to
(
tl
.
int32
),
tl
.
load
(
cu_seqlens
+
i_n
+
1
).
to
(
tl
.
int32
)
T
=
eos
-
bos
else
:
bos
,
eos
=
i_b
*
T
,
i_b
*
T
+
T
A
=
A
+
(
bos
*
H
+
i_h
)
*
BT
Ad
=
Ad
+
(
bos
*
H
+
i_h
)
*
16
offset
=
(
i_t
*
16
)
%
BT
p_A
=
tl
.
make_block_ptr
(
A
,
(
T
,
BT
),
(
H
*
BT
,
1
),
(
i_t
*
16
,
offset
),
(
16
,
16
),
(
1
,
0
))
p_Ai
=
tl
.
make_block_ptr
(
Ad
,
(
T
,
16
),
(
H
*
16
,
1
),
(
i_t
*
16
,
0
),
(
16
,
16
),
(
1
,
0
))
b_A
=
tl
.
load
(
p_A
,
boundary_check
=
(
0
,
1
)).
to
(
tl
.
float32
)
b_A
=
-
tl
.
where
(
tl
.
arange
(
0
,
16
)[:,
None
]
>
tl
.
arange
(
0
,
16
)[
None
,
:],
b_A
,
0
)
o_i
=
tl
.
arange
(
0
,
16
)
for
i
in
range
(
1
,
min
(
16
,
T
-
i_t
*
16
)):
b_a
=
-
tl
.
load
(
A
+
(
i_t
*
16
+
i
)
*
H
*
BT
+
o_i
+
offset
)
b_a
=
b_a
+
tl
.
sum
(
b_a
[:,
None
]
*
b_A
,
0
)
mask
=
o_i
==
i
b_A
=
tl
.
where
(
mask
[:,
None
],
b_a
,
b_A
)
b_A
+=
o_i
[:,
None
]
==
o_i
[
None
,
:]
tl
.
store
(
p_Ai
,
b_A
.
to
(
p_Ai
.
dtype
.
element_ty
,
fp_downcast_rounding
=
"rtne"
),
boundary_check
=
(
0
,
1
))
@
triton
.
heuristics
({
'IS_VARLEN'
:
lambda
args
:
args
[
'cu_seqlens'
]
is
not
None
})
@
triton
.
autotune
(
configs
=
[
triton
.
Config
({},
num_warps
=
num_warps
,
num_stages
=
num_stages
)
for
num_warps
in
[
1
,
2
,
4
,
8
]
for
num_stages
in
[
2
,
3
,
4
,
5
]
],
key
=
[
'H'
,
'BT'
,
'IS_VARLEN'
],
)
@
triton
.
jit
(
do_not_specialize
=
[
'T'
])
def
merge_16x16_to_32x32_inverse_kernel
(
A
,
Ad
,
Ai
,
cu_seqlens
,
chunk_indices
,
T
,
H
:
tl
.
constexpr
,
BT
:
tl
.
constexpr
,
IS_VARLEN
:
tl
.
constexpr
):
i_t
,
i_bh
=
tl
.
program_id
(
0
),
tl
.
program_id
(
1
)
i_b
,
i_h
=
i_bh
//
H
,
i_bh
%
H
if
IS_VARLEN
:
i_n
,
i_t
=
tl
.
load
(
chunk_indices
+
i_t
*
2
).
to
(
tl
.
int32
),
tl
.
load
(
chunk_indices
+
i_t
*
2
+
1
).
to
(
tl
.
int32
)
bos
,
eos
=
tl
.
load
(
cu_seqlens
+
i_n
).
to
(
tl
.
int32
),
tl
.
load
(
cu_seqlens
+
i_n
+
1
).
to
(
tl
.
int32
)
T
=
eos
-
bos
else
:
bos
,
eos
=
i_b
*
T
,
i_b
*
T
+
T
A
+=
(
bos
*
H
+
i_h
)
*
32
Ad
+=
(
bos
*
H
+
i_h
)
*
16
Ai
+=
(
bos
*
H
+
i_h
)
*
32
p_A_21
=
tl
.
make_block_ptr
(
A
,
(
T
,
32
),
(
H
*
32
,
1
),
(
i_t
*
32
+
16
,
0
),
(
16
,
16
),
(
1
,
0
))
p_Ad_11
=
tl
.
make_block_ptr
(
Ad
,
(
T
,
16
),
(
H
*
16
,
1
),
(
i_t
*
32
,
0
),
(
16
,
16
),
(
1
,
0
))
p_Ad_22
=
tl
.
make_block_ptr
(
Ad
,
(
T
,
16
),
(
H
*
16
,
1
),
(
i_t
*
32
+
16
,
0
),
(
16
,
16
),
(
1
,
0
))
p_Ai_11
=
tl
.
make_block_ptr
(
Ai
,
(
T
,
32
),
(
H
*
32
,
1
),
(
i_t
*
32
,
0
),
(
16
,
16
),
(
1
,
0
))
p_Ai_22
=
tl
.
make_block_ptr
(
Ai
,
(
T
,
32
),
(
H
*
32
,
1
),
(
i_t
*
32
+
16
,
16
),
(
16
,
16
),
(
1
,
0
))
p_Ai_21
=
tl
.
make_block_ptr
(
Ai
,
(
T
,
32
),
(
H
*
32
,
1
),
(
i_t
*
32
+
16
,
0
),
(
16
,
16
),
(
1
,
0
))
A_21
=
tl
.
load
(
p_A_21
,
boundary_check
=
(
0
,
1
)).
to
(
tl
.
float32
)
Ai_11
=
tl
.
load
(
p_Ad_11
,
boundary_check
=
(
0
,
1
)).
to
(
tl
.
float32
)
Ai_22
=
tl
.
load
(
p_Ad_22
,
boundary_check
=
(
0
,
1
)).
to
(
tl
.
float32
)
Ai_21
=
-
tl
.
dot
(
tl
.
dot
(
Ai_22
,
A_21
,
input_precision
=
'ieee'
),
Ai_11
,
input_precision
=
'ieee'
)
tl
.
store
(
p_Ai_11
,
Ai_11
.
to
(
p_Ai_11
.
dtype
.
element_ty
,
fp_downcast_rounding
=
"rtne"
),
boundary_check
=
(
0
,
1
))
tl
.
store
(
p_Ai_22
,
Ai_22
.
to
(
p_Ai_22
.
dtype
.
element_ty
,
fp_downcast_rounding
=
"rtne"
),
boundary_check
=
(
0
,
1
))
tl
.
store
(
p_Ai_21
,
Ai_21
.
to
(
p_Ai_21
.
dtype
.
element_ty
,
fp_downcast_rounding
=
"rtne"
),
boundary_check
=
(
0
,
1
))
@
triton
.
heuristics
({
'IS_VARLEN'
:
lambda
args
:
args
[
'cu_seqlens'
]
is
not
None
})
@
triton
.
autotune
(
configs
=
[
triton
.
Config
({},
num_warps
=
num_warps
,
num_stages
=
num_stages
)
for
num_warps
in
[
2
,
4
,
8
]
for
num_stages
in
[
2
,
3
,
4
,
5
]
],
key
=
[
'H'
,
'BT'
,
'IS_VARLEN'
],
)
@
triton
.
jit
(
do_not_specialize
=
[
'T'
])
def
merge_16x16_to_64x64_inverse_kernel
(
A
,
Ad
,
Ai
,
cu_seqlens
,
chunk_indices
,
T
,
H
:
tl
.
constexpr
,
BT
:
tl
.
constexpr
,
IS_VARLEN
:
tl
.
constexpr
):
i_t
,
i_bh
=
tl
.
program_id
(
0
),
tl
.
program_id
(
1
)
i_b
,
i_h
=
i_bh
//
H
,
i_bh
%
H
if
IS_VARLEN
:
i_n
,
i_t
=
tl
.
load
(
chunk_indices
+
i_t
*
2
).
to
(
tl
.
int32
),
tl
.
load
(
chunk_indices
+
i_t
*
2
+
1
).
to
(
tl
.
int32
)
bos
,
eos
=
tl
.
load
(
cu_seqlens
+
i_n
).
to
(
tl
.
int32
),
tl
.
load
(
cu_seqlens
+
i_n
+
1
).
to
(
tl
.
int32
)
T
=
eos
-
bos
else
:
bos
,
eos
=
i_b
*
T
,
i_b
*
T
+
T
A
+=
(
bos
*
H
+
i_h
)
*
64
Ad
+=
(
bos
*
H
+
i_h
)
*
16
Ai
+=
(
bos
*
H
+
i_h
)
*
64
p_A_21
=
tl
.
make_block_ptr
(
A
,
(
T
,
64
),
(
H
*
64
,
1
),
(
i_t
*
64
+
16
,
0
),
(
16
,
16
),
(
1
,
0
))
p_A_32
=
tl
.
make_block_ptr
(
A
,
(
T
,
64
),
(
H
*
64
,
1
),
(
i_t
*
64
+
32
,
16
),
(
16
,
16
),
(
1
,
0
))
p_A_31
=
tl
.
make_block_ptr
(
A
,
(
T
,
64
),
(
H
*
64
,
1
),
(
i_t
*
64
+
32
,
0
),
(
16
,
16
),
(
1
,
0
))
p_A_43
=
tl
.
make_block_ptr
(
A
,
(
T
,
64
),
(
H
*
64
,
1
),
(
i_t
*
64
+
48
,
32
),
(
16
,
16
),
(
1
,
0
))
p_A_42
=
tl
.
make_block_ptr
(
A
,
(
T
,
64
),
(
H
*
64
,
1
),
(
i_t
*
64
+
48
,
16
),
(
16
,
16
),
(
1
,
0
))
p_A_41
=
tl
.
make_block_ptr
(
A
,
(
T
,
64
),
(
H
*
64
,
1
),
(
i_t
*
64
+
48
,
0
),
(
16
,
16
),
(
1
,
0
))
p_Ad_11
=
tl
.
make_block_ptr
(
Ad
,
(
T
,
16
),
(
H
*
16
,
1
),
(
i_t
*
64
,
0
),
(
16
,
16
),
(
1
,
0
))
p_Ad_22
=
tl
.
make_block_ptr
(
Ad
,
(
T
,
16
),
(
H
*
16
,
1
),
(
i_t
*
64
+
16
,
0
),
(
16
,
16
),
(
1
,
0
))
p_Ad_33
=
tl
.
make_block_ptr
(
Ad
,
(
T
,
16
),
(
H
*
16
,
1
),
(
i_t
*
64
+
32
,
0
),
(
16
,
16
),
(
1
,
0
))
p_Ad_44
=
tl
.
make_block_ptr
(
Ad
,
(
T
,
16
),
(
H
*
16
,
1
),
(
i_t
*
64
+
48
,
0
),
(
16
,
16
),
(
1
,
0
))
A_21
=
tl
.
load
(
p_A_21
,
boundary_check
=
(
0
,
1
)).
to
(
tl
.
float32
)
A_32
=
tl
.
load
(
p_A_32
,
boundary_check
=
(
0
,
1
)).
to
(
tl
.
float32
)
A_31
=
tl
.
load
(
p_A_31
,
boundary_check
=
(
0
,
1
)).
to
(
tl
.
float32
)
A_43
=
tl
.
load
(
p_A_43
,
boundary_check
=
(
0
,
1
)).
to
(
tl
.
float32
)
A_42
=
tl
.
load
(
p_A_42
,
boundary_check
=
(
0
,
1
)).
to
(
tl
.
float32
)
A_41
=
tl
.
load
(
p_A_41
,
boundary_check
=
(
0
,
1
)).
to
(
tl
.
float32
)
Ai_11
=
tl
.
load
(
p_Ad_11
,
boundary_check
=
(
0
,
1
)).
to
(
tl
.
float32
)
Ai_22
=
tl
.
load
(
p_Ad_22
,
boundary_check
=
(
0
,
1
)).
to
(
tl
.
float32
)
Ai_33
=
tl
.
load
(
p_Ad_33
,
boundary_check
=
(
0
,
1
)).
to
(
tl
.
float32
)
Ai_44
=
tl
.
load
(
p_Ad_44
,
boundary_check
=
(
0
,
1
)).
to
(
tl
.
float32
)
Ai_21
=
-
tl
.
dot
(
tl
.
dot
(
Ai_22
,
A_21
,
input_precision
=
'ieee'
),
Ai_11
,
input_precision
=
'ieee'
)
Ai_32
=
-
tl
.
dot
(
tl
.
dot
(
Ai_33
,
A_32
,
input_precision
=
'ieee'
),
Ai_22
,
input_precision
=
'ieee'
)
Ai_43
=
-
tl
.
dot
(
tl
.
dot
(
Ai_44
,
A_43
,
input_precision
=
'ieee'
),
Ai_33
,
input_precision
=
'ieee'
)
Ai_31
=
-
tl
.
dot
(
Ai_33
,
tl
.
dot
(
A_31
,
Ai_11
,
input_precision
=
'ieee'
)
+
tl
.
dot
(
A_32
,
Ai_21
,
input_precision
=
'ieee'
),
input_precision
=
'ieee'
)
Ai_42
=
-
tl
.
dot
(
Ai_44
,
tl
.
dot
(
A_42
,
Ai_22
,
input_precision
=
'ieee'
)
+
tl
.
dot
(
A_43
,
Ai_32
,
input_precision
=
'ieee'
),
input_precision
=
'ieee'
)
Ai_41
=
-
tl
.
dot
(
Ai_44
,
tl
.
dot
(
A_41
,
Ai_11
,
input_precision
=
'ieee'
)
+
tl
.
dot
(
A_42
,
Ai_21
,
input_precision
=
'ieee'
)
+
tl
.
dot
(
A_43
,
Ai_31
,
input_precision
=
'ieee'
),
input_precision
=
'ieee'
)
p_Ai_11
=
tl
.
make_block_ptr
(
Ai
,
(
T
,
64
),
(
H
*
64
,
1
),
(
i_t
*
64
,
0
),
(
16
,
16
),
(
1
,
0
))
p_Ai_22
=
tl
.
make_block_ptr
(
Ai
,
(
T
,
64
),
(
H
*
64
,
1
),
(
i_t
*
64
+
16
,
16
),
(
16
,
16
),
(
1
,
0
))
p_Ai_33
=
tl
.
make_block_ptr
(
Ai
,
(
T
,
64
),
(
H
*
64
,
1
),
(
i_t
*
64
+
32
,
32
),
(
16
,
16
),
(
1
,
0
))
p_Ai_44
=
tl
.
make_block_ptr
(
Ai
,
(
T
,
64
),
(
H
*
64
,
1
),
(
i_t
*
64
+
48
,
48
),
(
16
,
16
),
(
1
,
0
))
p_Ai_21
=
tl
.
make_block_ptr
(
Ai
,
(
T
,
64
),
(
H
*
64
,
1
),
(
i_t
*
64
+
16
,
0
),
(
16
,
16
),
(
1
,
0
))
p_Ai_31
=
tl
.
make_block_ptr
(
Ai
,
(
T
,
64
),
(
H
*
64
,
1
),
(
i_t
*
64
+
32
,
0
),
(
16
,
16
),
(
1
,
0
))
p_Ai_32
=
tl
.
make_block_ptr
(
Ai
,
(
T
,
64
),
(
H
*
64
,
1
),
(
i_t
*
64
+
32
,
16
),
(
16
,
16
),
(
1
,
0
))
p_Ai_41
=
tl
.
make_block_ptr
(
Ai
,
(
T
,
64
),
(
H
*
64
,
1
),
(
i_t
*
64
+
48
,
0
),
(
16
,
16
),
(
1
,
0
))
p_Ai_42
=
tl
.
make_block_ptr
(
Ai
,
(
T
,
64
),
(
H
*
64
,
1
),
(
i_t
*
64
+
48
,
16
),
(
16
,
16
),
(
1
,
0
))
p_Ai_43
=
tl
.
make_block_ptr
(
Ai
,
(
T
,
64
),
(
H
*
64
,
1
),
(
i_t
*
64
+
48
,
32
),
(
16
,
16
),
(
1
,
0
))
tl
.
store
(
p_Ai_11
,
Ai_11
.
to
(
p_Ai_11
.
dtype
.
element_ty
,
fp_downcast_rounding
=
"rtne"
),
boundary_check
=
(
0
,
1
))
tl
.
store
(
p_Ai_22
,
Ai_22
.
to
(
p_Ai_22
.
dtype
.
element_ty
,
fp_downcast_rounding
=
"rtne"
),
boundary_check
=
(
0
,
1
))
tl
.
store
(
p_Ai_33
,
Ai_33
.
to
(
p_Ai_33
.
dtype
.
element_ty
,
fp_downcast_rounding
=
"rtne"
),
boundary_check
=
(
0
,
1
))
tl
.
store
(
p_Ai_44
,
Ai_44
.
to
(
p_Ai_44
.
dtype
.
element_ty
,
fp_downcast_rounding
=
"rtne"
),
boundary_check
=
(
0
,
1
))
tl
.
store
(
p_Ai_21
,
Ai_21
.
to
(
p_Ai_21
.
dtype
.
element_ty
,
fp_downcast_rounding
=
"rtne"
),
boundary_check
=
(
0
,
1
))
tl
.
store
(
p_Ai_31
,
Ai_31
.
to
(
p_Ai_31
.
dtype
.
element_ty
,
fp_downcast_rounding
=
"rtne"
),
boundary_check
=
(
0
,
1
))
tl
.
store
(
p_Ai_32
,
Ai_32
.
to
(
p_Ai_32
.
dtype
.
element_ty
,
fp_downcast_rounding
=
"rtne"
),
boundary_check
=
(
0
,
1
))
tl
.
store
(
p_Ai_41
,
Ai_41
.
to
(
p_Ai_41
.
dtype
.
element_ty
,
fp_downcast_rounding
=
"rtne"
),
boundary_check
=
(
0
,
1
))
tl
.
store
(
p_Ai_42
,
Ai_42
.
to
(
p_Ai_42
.
dtype
.
element_ty
,
fp_downcast_rounding
=
"rtne"
),
boundary_check
=
(
0
,
1
))
tl
.
store
(
p_Ai_43
,
Ai_43
.
to
(
p_Ai_43
.
dtype
.
element_ty
,
fp_downcast_rounding
=
"rtne"
),
boundary_check
=
(
0
,
1
))
fill_zeros
=
tl
.
zeros
((
16
,
16
),
dtype
=
tl
.
float32
)
p_Ai_12
=
tl
.
make_block_ptr
(
Ai
,
(
T
,
64
),
(
H
*
64
,
1
),
(
i_t
*
64
,
16
),
(
16
,
16
),
(
1
,
0
))
p_Ai_13
=
tl
.
make_block_ptr
(
Ai
,
(
T
,
64
),
(
H
*
64
,
1
),
(
i_t
*
64
,
32
),
(
16
,
16
),
(
1
,
0
))
p_Ai_14
=
tl
.
make_block_ptr
(
Ai
,
(
T
,
64
),
(
H
*
64
,
1
),
(
i_t
*
64
,
48
),
(
16
,
16
),
(
1
,
0
))
p_Ai_23
=
tl
.
make_block_ptr
(
Ai
,
(
T
,
64
),
(
H
*
64
,
1
),
(
i_t
*
64
+
16
,
32
),
(
16
,
16
),
(
1
,
0
))
p_Ai_24
=
tl
.
make_block_ptr
(
Ai
,
(
T
,
64
),
(
H
*
64
,
1
),
(
i_t
*
64
+
16
,
48
),
(
16
,
16
),
(
1
,
0
))
p_Ai_34
=
tl
.
make_block_ptr
(
Ai
,
(
T
,
64
),
(
H
*
64
,
1
),
(
i_t
*
64
+
32
,
48
),
(
16
,
16
),
(
1
,
0
))
tl
.
store
(
p_Ai_12
,
fill_zeros
.
to
(
p_Ai_12
.
dtype
.
element_ty
,
fp_downcast_rounding
=
"rtne"
),
boundary_check
=
(
0
,
1
))
tl
.
store
(
p_Ai_13
,
fill_zeros
.
to
(
p_Ai_13
.
dtype
.
element_ty
,
fp_downcast_rounding
=
"rtne"
),
boundary_check
=
(
0
,
1
))
tl
.
store
(
p_Ai_14
,
fill_zeros
.
to
(
p_Ai_14
.
dtype
.
element_ty
,
fp_downcast_rounding
=
"rtne"
),
boundary_check
=
(
0
,
1
))
tl
.
store
(
p_Ai_23
,
fill_zeros
.
to
(
p_Ai_23
.
dtype
.
element_ty
,
fp_downcast_rounding
=
"rtne"
),
boundary_check
=
(
0
,
1
))
tl
.
store
(
p_Ai_24
,
fill_zeros
.
to
(
p_Ai_24
.
dtype
.
element_ty
,
fp_downcast_rounding
=
"rtne"
),
boundary_check
=
(
0
,
1
))
tl
.
store
(
p_Ai_34
,
fill_zeros
.
to
(
p_Ai_34
.
dtype
.
element_ty
,
fp_downcast_rounding
=
"rtne"
),
boundary_check
=
(
0
,
1
))
@
input_guard
def
solve_tril
(
A
:
torch
.
Tensor
,
cu_seqlens
:
Optional
[
torch
.
Tensor
]
=
None
,
output_dtype
:
torch
.
dtype
=
torch
.
float
)
->
torch
.
Tensor
:
"""
Compute the inverse of the lower triangular matrix
A should be strictly lower triangular, i.e., A.triu() == 0.
Args:
A (torch.Tensor):
[B, T, H, K]
cu_seqlens (torch.Tensor):
The cumulative sequence lengths of the input tensor.
Default: None.
output_dtype (torch.dtype):
The dtype of the output tensor. Default: `torch.float`
Returns:
(I + A)^-1 with the same shape as A
"""
assert
A
.
shape
[
-
1
]
in
[
16
,
32
,
64
]
B
,
T
,
H
,
BT
=
A
.
shape
Ad
=
torch
.
empty
(
B
,
T
,
H
,
16
,
device
=
A
.
device
,
dtype
=
torch
.
float
if
BT
!=
16
else
output_dtype
)
chunk_indices
=
prepare_chunk_indices
(
cu_seqlens
,
16
)
if
cu_seqlens
is
not
None
else
None
NT
=
len
(
chunk_indices
)
if
cu_seqlens
is
not
None
else
triton
.
cdiv
(
T
,
16
)
solve_tril_16x16_kernel
[
NT
,
B
*
H
](
A
=
A
,
Ad
=
Ad
,
cu_seqlens
=
cu_seqlens
,
chunk_indices
=
chunk_indices
,
T
=
T
,
H
=
H
,
BT
=
BT
,
)
if
BT
==
16
:
return
Ad
Ai
=
torch
.
empty
(
B
,
T
,
H
,
BT
,
device
=
A
.
device
,
dtype
=
output_dtype
)
merge_fn
=
merge_16x16_to_32x32_inverse_kernel
if
BT
==
32
else
merge_16x16_to_64x64_inverse_kernel
chunk_indices
=
prepare_chunk_indices
(
cu_seqlens
,
BT
)
if
cu_seqlens
is
not
None
else
None
NT
=
len
(
chunk_indices
)
if
cu_seqlens
is
not
None
else
triton
.
cdiv
(
T
,
BT
)
merge_fn
[
NT
,
B
*
H
](
A
=
A
,
Ad
=
Ad
,
Ai
=
Ai
,
cu_seqlens
=
cu_seqlens
,
chunk_indices
=
chunk_indices
,
T
=
T
,
H
=
H
,
BT
=
BT
,
)
return
Ai
vllm/model_executor/layers/fla/ops/utils.py
0 → 100644
View file @
1aa427fd
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# ruff: noqa: E501
import
contextlib
import
functools
import
logging
import
os
from
enum
import
Enum
from
typing
import
Any
,
Callable
,
Literal
,
Optional
import
torch
from
vllm.triton_utils
import
triton
logger
=
logging
.
getLogger
(
__name__
)
COMPILER_MODE
=
os
.
getenv
(
"FLA_COMPILER_MODE"
)
==
"1"
FLA_CI_ENV
=
os
.
getenv
(
"FLA_CI_ENV"
)
==
"1"
FLA_GDN_FIX_BT
=
os
.
getenv
(
"FLA_GDN_FIX_BT"
,
"0"
)
==
"1"
SUPPRESS_LEVEL
=
int
(
os
.
getenv
(
"GDN_RECOMPUTE_SUPPRESS_LEVEL"
,
"0"
))
def
tensor_cache
(
fn
:
Callable
[...,
torch
.
Tensor
])
->
Callable
[...,
torch
.
Tensor
]:
"""
A decorator that caches the most recent results of a function with tensor inputs.
This decorator will store the output of the decorated function for the most recent set of input tensors.
The cache is limited to a fixed size (default is 4). When the cache is full, the oldest entry will be removed.
Args:
fn (Callable[..., torch.Tensor]):
The function to be decorated. It should take tensor inputs and return tensor outputs.
Returns:
Callable[..., torch.Tensor]:
A wrapped version of the input function with single-entry caching.
"""
cache_entries
:
tuple
[
Optional
[
tuple
],
Optional
[
dict
],
Any
]
=
[]
cache_size
=
4
@
functools
.
wraps
(
fn
)
def
wrapper
(
*
args
:
Any
,
**
kwargs
:
Any
)
->
Any
:
nonlocal
cache_entries
,
cache_size
for
i
,
entry
in
enumerate
(
cache_entries
):
last_args
,
last_kwargs
,
last_result
=
entry
if
len
(
args
)
==
len
(
last_args
)
and
len
(
kwargs
)
==
len
(
last_kwargs
)
\
and
all
(
a
is
b
for
a
,
b
in
zip
(
args
,
last_args
))
\
and
all
(
k
in
last_kwargs
and
v
is
last_kwargs
[
k
]
for
k
,
v
in
kwargs
.
items
()):
cache_entries
=
cache_entries
[:
i
]
+
cache_entries
[
i
+
1
:]
+
[
(
args
,
kwargs
,
last_result
)
]
return
last_result
result
=
fn
(
*
args
,
**
kwargs
)
if
len
(
cache_entries
)
>=
cache_size
:
cache_entries
=
cache_entries
[
1
:]
cache_entries
.
append
((
args
,
kwargs
,
result
))
return
result
return
wrapper
def
input_guard
(
fn
:
Callable
[...,
torch
.
Tensor
])
->
Callable
[...,
torch
.
Tensor
]:
"""
A decorator to make sure all input tensors are contiguous and set the device based on input tensors.
"""
@
functools
.
wraps
(
fn
)
def
wrapper
(
*
args
,
**
kwargs
):
contiguous_args
=
(
i
if
not
isinstance
(
i
,
torch
.
Tensor
)
else
i
.
contiguous
()
for
i
in
args
)
contiguous_kwargs
=
{
k
:
(
v
if
not
isinstance
(
v
,
torch
.
Tensor
)
else
v
.
contiguous
())
for
k
,
v
in
kwargs
.
items
()
}
tensor
=
None
for
arg
in
args
:
if
isinstance
(
arg
,
torch
.
Tensor
):
tensor
=
arg
break
if
tensor
is
None
:
for
value
in
kwargs
.
values
():
if
isinstance
(
value
,
torch
.
Tensor
):
tensor
=
value
break
if
tensor
is
not
None
:
ctx
=
torch
.
cuda
.
device
(
tensor
.
device
.
index
)
else
:
ctx
=
contextlib
.
nullcontext
()
with
ctx
:
return
fn
(
*
contiguous_args
,
**
contiguous_kwargs
)
return
wrapper
@
functools
.
cache
def
get_available_device
()
->
str
:
try
:
return
triton
.
runtime
.
driver
.
active
.
get_current_target
().
backend
except
BaseException
:
return
'cpu'
@
functools
.
cache
def
_check_platform
()
->
Literal
[
'nvidia'
,
'amd'
,
'intel'
,
'musa'
]:
device
=
get_available_device
()
mapping
=
{
"cuda"
:
"nvidia"
,
"hip"
:
"amd"
,
"xpu"
:
"intel"
,
}
# return the mapped value, or the original if not found
return
mapping
.
get
(
device
,
device
)
# For AMD GPUs, the triton backend is 'hip', while for Nvidia GPUs, the triton backend is 'cuda'.
# However, the torch backend is 'cuda' for both Nvidia and AMD GPUs.
# Therefore, we need to check the triton backend to determine the actual GPU vendor.
device
=
get_available_device
()
if
get_available_device
()
!=
'hip'
else
'cuda'
device_torch_lib
=
getattr
(
torch
,
device
)
device_platform
=
_check_platform
()
is_amd
=
(
device_platform
==
'amd'
)
is_intel
=
(
device_platform
==
'intel'
)
is_nvidia
=
(
device_platform
==
'nvidia'
)
is_intel_alchemist
=
(
is_intel
and
'Intel(R) Arc(TM) A'
in
torch
.
xpu
.
get_device_name
(
0
))
is_nvidia_hopper
=
(
is_nvidia
and
(
'NVIDIA H'
in
torch
.
cuda
.
get_device_name
(
0
)
or
torch
.
cuda
.
get_device_capability
()[
0
]
>=
9
))
use_cuda_graph
=
(
is_nvidia
and
os
.
environ
.
get
(
'FLA_USE_CUDA_GRAPH'
,
'0'
)
==
'1'
)
def
get_all_max_shared_mem
():
try
:
return
[
triton
.
runtime
.
driver
.
active
.
utils
.
get_device_properties
(
i
)
[
'max_shared_mem'
]
for
i
in
range
(
device_torch_lib
.
device_count
())
]
except
BaseException
:
return
[
-
1
]
class
Backend
(
Enum
):
ADA
=
101376
# RTX 4090
AMPERE
=
166912
# A100
HOPPER
=
232448
# H100
DEFAULT
=
102400
# Default
@
classmethod
def
get_shared_memory
(
cls
,
arch
:
str
)
->
int
:
try
:
return
cls
[
arch
.
upper
()].
value
except
KeyError
:
return
cls
.
DEFAULT
.
value
@
functools
.
cache
def
check_shared_mem
(
arch
:
str
=
"none"
,
tensor_idx
:
int
=
0
)
->
bool
:
try
:
device_shared_mem_list
=
get_all_max_shared_mem
()
max_shared_memory
=
device_shared_mem_list
[
tensor_idx
]
return
max_shared_memory
>=
Backend
.
get_shared_memory
(
arch
)
except
Exception
:
return
False
vllm/model_executor/layers/fla/ops/wy_fast.py
0 → 100644
View file @
1aa427fd
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# ruff: noqa: E501
from
typing
import
Optional
import
torch
from
vllm.triton_utils
import
tl
,
triton
from
.index
import
prepare_chunk_indices
@
triton
.
heuristics
({
'IS_VARLEN'
:
lambda
args
:
args
[
'cu_seqlens'
]
is
not
None
})
@
triton
.
autotune
(
configs
=
[
triton
.
Config
({},
num_warps
=
num_warps
,
num_stages
=
num_stages
)
for
num_warps
in
[
2
,
4
,
8
]
for
num_stages
in
[
2
,
3
,
4
]
],
key
=
[
'H'
,
'K'
,
'V'
,
'BT'
,
'BK'
,
'BV'
,
'IS_VARLEN'
],
)
@
triton
.
jit
(
do_not_specialize
=
[
'T'
])
def
recompute_w_u_fwd_kernel
(
k
,
v
,
beta
,
w
,
u
,
A
,
g
,
cu_seqlens
,
chunk_indices
,
T
,
H
:
tl
.
constexpr
,
Hg
:
tl
.
constexpr
,
K
:
tl
.
constexpr
,
V
:
tl
.
constexpr
,
BT
:
tl
.
constexpr
,
BK
:
tl
.
constexpr
,
BV
:
tl
.
constexpr
,
IS_VARLEN
:
tl
.
constexpr
):
i_t
,
i_bh
=
tl
.
program_id
(
0
),
tl
.
program_id
(
1
)
i_b
,
i_h
=
i_bh
//
H
,
i_bh
%
H
if
IS_VARLEN
:
i_n
,
i_t
=
tl
.
load
(
chunk_indices
+
i_t
*
2
).
to
(
tl
.
int32
),
tl
.
load
(
chunk_indices
+
i_t
*
2
+
1
).
to
(
tl
.
int32
)
bos
,
eos
=
tl
.
load
(
cu_seqlens
+
i_n
).
to
(
tl
.
int32
),
tl
.
load
(
cu_seqlens
+
i_n
+
1
).
to
(
tl
.
int32
)
T
=
eos
-
bos
else
:
bos
,
eos
=
i_b
*
T
,
i_b
*
T
+
T
p_beta
=
tl
.
make_block_ptr
(
beta
+
bos
*
H
+
i_h
,
(
T
,
),
(
H
,
),
(
i_t
*
BT
,
),
(
BT
,
),
(
0
,
))
p_g
=
tl
.
make_block_ptr
(
g
+
(
bos
*
H
+
i_h
),
(
T
,
),
(
H
,
),
(
i_t
*
BT
,
),
(
BT
,
),
(
0
,
))
p_A
=
tl
.
make_block_ptr
(
A
+
(
bos
*
H
+
i_h
)
*
BT
,
(
T
,
BT
),
(
H
*
BT
,
1
),
(
i_t
*
BT
,
0
),
(
BT
,
BT
),
(
1
,
0
))
b_beta
=
tl
.
load
(
p_beta
,
boundary_check
=
(
0
,
))
b_A
=
tl
.
load
(
p_A
,
boundary_check
=
(
0
,
1
))
b_g
=
tl
.
exp
(
tl
.
load
(
p_g
,
boundary_check
=
(
0
,
)))
for
i_v
in
range
(
tl
.
cdiv
(
V
,
BV
)):
p_v
=
tl
.
make_block_ptr
(
v
+
(
bos
*
H
+
i_h
)
*
V
,
(
T
,
V
),
(
H
*
V
,
1
),
(
i_t
*
BT
,
i_v
*
BV
),
(
BT
,
BV
),
(
1
,
0
))
p_u
=
tl
.
make_block_ptr
(
u
+
(
bos
*
H
+
i_h
)
*
V
,
(
T
,
V
),
(
H
*
V
,
1
),
(
i_t
*
BT
,
i_v
*
BV
),
(
BT
,
BV
),
(
1
,
0
))
b_v
=
tl
.
load
(
p_v
,
boundary_check
=
(
0
,
1
))
b_vb
=
(
b_v
*
b_beta
[:,
None
]).
to
(
b_v
.
dtype
)
b_u
=
tl
.
dot
(
b_A
,
b_vb
,
allow_tf32
=
False
)
tl
.
store
(
p_u
,
b_u
.
to
(
p_u
.
dtype
.
element_ty
),
boundary_check
=
(
0
,
1
))
for
i_k
in
range
(
tl
.
cdiv
(
K
,
BK
)):
p_k
=
tl
.
make_block_ptr
(
k
+
(
bos
*
Hg
+
i_h
//
(
H
//
Hg
))
*
K
,
(
T
,
K
),
(
Hg
*
K
,
1
),
(
i_t
*
BT
,
i_k
*
BK
),
(
BT
,
BK
),
(
1
,
0
))
p_w
=
tl
.
make_block_ptr
(
w
+
(
bos
*
H
+
i_h
)
*
K
,
(
T
,
K
),
(
H
*
K
,
1
),
(
i_t
*
BT
,
i_k
*
BK
),
(
BT
,
BK
),
(
1
,
0
))
b_k
=
tl
.
load
(
p_k
,
boundary_check
=
(
0
,
1
))
b_kb
=
(
b_k
*
b_beta
[:,
None
]
*
b_g
[:,
None
]).
to
(
b_k
.
dtype
)
b_w
=
tl
.
dot
(
b_A
,
b_kb
)
tl
.
store
(
p_w
,
b_w
.
to
(
p_w
.
dtype
.
element_ty
),
boundary_check
=
(
0
,
1
))
def
recompute_w_u_fwd
(
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
beta
:
torch
.
Tensor
,
g_cumsum
:
torch
.
Tensor
,
A
:
torch
.
Tensor
,
cu_seqlens
:
Optional
[
torch
.
LongTensor
],
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
B
,
T
,
Hg
,
K
,
V
=
*
k
.
shape
,
v
.
shape
[
-
1
]
H
=
v
.
shape
[
-
2
]
BT
=
A
.
shape
[
-
1
]
chunk_indices
=
prepare_chunk_indices
(
cu_seqlens
,
BT
)
if
cu_seqlens
is
not
None
else
None
NT
=
triton
.
cdiv
(
T
,
BT
)
if
cu_seqlens
is
None
else
len
(
chunk_indices
)
BK
=
64
BV
=
64
u
=
torch
.
empty_like
(
v
)
w
=
k
.
new_empty
(
B
,
T
,
H
,
K
)
recompute_w_u_fwd_kernel
[(
NT
,
B
*
H
)](
k
=
k
,
v
=
v
,
beta
=
beta
,
w
=
w
,
u
=
u
,
A
=
A
,
g
=
g_cumsum
,
cu_seqlens
=
cu_seqlens
,
chunk_indices
=
chunk_indices
,
T
=
T
,
H
=
H
,
Hg
=
Hg
,
K
=
K
,
V
=
V
,
BT
=
BT
,
BK
=
BK
,
BV
=
BV
,
)
return
w
,
u
vllm/triton_utils/__init__.py
View file @
1aa427fd
...
@@ -7,8 +7,10 @@ from vllm.triton_utils.importing import (HAS_TRITON, TritonLanguagePlaceholder,
...
@@ -7,8 +7,10 @@ from vllm.triton_utils.importing import (HAS_TRITON, TritonLanguagePlaceholder,
if
HAS_TRITON
:
if
HAS_TRITON
:
import
triton
import
triton
import
triton.language
as
tl
import
triton.language
as
tl
import
triton.language.extra.libdevice
as
tldevice
else
:
else
:
triton
=
TritonPlaceholder
()
triton
=
TritonPlaceholder
()
tl
=
TritonLanguagePlaceholder
()
tl
=
TritonLanguagePlaceholder
()
tldevice
=
TritonLanguagePlaceholder
()
__all__
=
[
"HAS_TRITON"
,
"triton"
,
"tl"
]
__all__
=
[
"HAS_TRITON"
,
"triton"
,
"tl"
,
"tldevice"
]
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment