Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e88bdd60
Unverified
Commit
e88bdd60
authored
Oct 28, 2025
by
Zhiyuan Li
Committed by
GitHub
Oct 28, 2025
Browse files
[FLA] Introduce Kimi Delta Attention(KDA) to VLLM (#27654)
Signed-off-by:
lizhiyuan
<
lizhiyuan@moonshot.cn
>
parent
05e034f0
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
1451 additions
and
58 deletions
+1451
-58
vllm/model_executor/layers/fla/ops/chunk.py
vllm/model_executor/layers/fla/ops/chunk.py
+1
-1
vllm/model_executor/layers/fla/ops/chunk_delta_h.py
vllm/model_executor/layers/fla/ops/chunk_delta_h.py
+70
-40
vllm/model_executor/layers/fla/ops/chunk_scaled_dot_kkt.py
vllm/model_executor/layers/fla/ops/chunk_scaled_dot_kkt.py
+11
-13
vllm/model_executor/layers/fla/ops/fused_recurrent.py
vllm/model_executor/layers/fla/ops/fused_recurrent.py
+18
-4
vllm/model_executor/layers/fla/ops/kda.py
vllm/model_executor/layers/fla/ops/kda.py
+1351
-0
No files found.
vllm/model_executor/layers/fla/ops/chunk.py
View file @
e88bdd60
...
@@ -36,7 +36,7 @@ def chunk_gated_delta_rule_fwd(
...
@@ -36,7 +36,7 @@ def chunk_gated_delta_rule_fwd(
g
=
chunk_local_cumsum
(
g
,
chunk_size
=
64
,
cu_seqlens
=
cu_seqlens
)
g
=
chunk_local_cumsum
(
g
,
chunk_size
=
64
,
cu_seqlens
=
cu_seqlens
)
# obtain WY representation. u is actually the new v.
# obtain WY representation. u is actually the new v.
A
=
chunk_scaled_dot_kkt_fwd
(
A
=
chunk_scaled_dot_kkt_fwd
(
k
=
k
,
beta
=
beta
,
g
_cumsum
=
g
,
cu_seqlens
=
cu_seqlens
,
output_dtype
=
torch
.
float32
k
=
k
,
beta
=
beta
,
g
=
g
,
cu_seqlens
=
cu_seqlens
,
output_dtype
=
torch
.
float32
)
)
A
=
solve_tril
(
A
=
A
,
cu_seqlens
=
cu_seqlens
,
output_dtype
=
k
.
dtype
)
A
=
solve_tril
(
A
=
A
,
cu_seqlens
=
cu_seqlens
,
output_dtype
=
k
.
dtype
)
w
,
u
=
recompute_w_u_fwd
(
w
,
u
=
recompute_w_u_fwd
(
...
...
vllm/model_executor/layers/fla/ops/chunk_delta_h.py
View file @
e88bdd60
...
@@ -14,14 +14,15 @@ from vllm.triton_utils import tl, triton
...
@@ -14,14 +14,15 @@ from vllm.triton_utils import tl, triton
from
.index
import
prepare_chunk_indices
,
prepare_chunk_offsets
from
.index
import
prepare_chunk_indices
,
prepare_chunk_offsets
from
.op
import
exp
from
.op
import
exp
from
.utils
import
is_nvidia_hopper
,
use_cuda_graph
from
.utils
import
use_cuda_graph
NUM_WARPS
=
[
2
,
4
]
if
is_nvidia_hopper
else
[
2
,
4
,
8
,
16
]
NUM_WARPS
=
[
2
,
4
,
8
,
16
]
@
triton
.
heuristics
(
@
triton
.
heuristics
(
{
{
"USE_G"
:
lambda
args
:
args
[
"g"
]
is
not
None
,
"USE_G"
:
lambda
args
:
args
[
"g"
]
is
not
None
,
"USE_GK"
:
lambda
args
:
args
[
"gk"
]
is
not
None
,
"USE_INITIAL_STATE"
:
lambda
args
:
args
[
"h0"
]
is
not
None
,
"USE_INITIAL_STATE"
:
lambda
args
:
args
[
"h0"
]
is
not
None
,
"STORE_FINAL_STATE"
:
lambda
args
:
args
[
"ht"
]
is
not
None
,
"STORE_FINAL_STATE"
:
lambda
args
:
args
[
"ht"
]
is
not
None
,
"SAVE_NEW_VALUE"
:
lambda
args
:
args
[
"v_new"
]
is
not
None
,
"SAVE_NEW_VALUE"
:
lambda
args
:
args
[
"v_new"
]
is
not
None
,
...
@@ -35,7 +36,7 @@ NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8, 16]
...
@@ -35,7 +36,7 @@ NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8, 16]
for
num_stages
in
[
2
,
3
,
4
]
for
num_stages
in
[
2
,
3
,
4
]
for
BV
in
[
32
,
64
]
for
BV
in
[
32
,
64
]
],
],
key
=
[
"H"
,
"K"
,
"V"
,
"BT"
,
"USE_G"
],
key
=
[
"H"
,
"K"
,
"V"
,
"BT"
],
use_cuda_graph
=
use_cuda_graph
,
use_cuda_graph
=
use_cuda_graph
,
)
)
@
triton
.
jit
(
do_not_specialize
=
[
"T"
])
@
triton
.
jit
(
do_not_specialize
=
[
"T"
])
...
@@ -45,6 +46,7 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
...
@@ -45,6 +46,7 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
w
,
w
,
v_new
,
v_new
,
g
,
g
,
gk
,
h
,
h
,
h0
,
h0
,
ht
,
ht
,
...
@@ -58,6 +60,7 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
...
@@ -58,6 +60,7 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
BT
:
tl
.
constexpr
,
BT
:
tl
.
constexpr
,
BV
:
tl
.
constexpr
,
BV
:
tl
.
constexpr
,
USE_G
:
tl
.
constexpr
,
USE_G
:
tl
.
constexpr
,
USE_GK
:
tl
.
constexpr
,
USE_INITIAL_STATE
:
tl
.
constexpr
,
USE_INITIAL_STATE
:
tl
.
constexpr
,
STORE_FINAL_STATE
:
tl
.
constexpr
,
STORE_FINAL_STATE
:
tl
.
constexpr
,
SAVE_NEW_VALUE
:
tl
.
constexpr
,
SAVE_NEW_VALUE
:
tl
.
constexpr
,
...
@@ -88,12 +91,12 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
...
@@ -88,12 +91,12 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
b_h4
=
tl
.
zeros
([
64
,
BV
],
dtype
=
tl
.
float32
)
b_h4
=
tl
.
zeros
([
64
,
BV
],
dtype
=
tl
.
float32
)
# calculate offset
# calculate offset
h
+=
(
boh
*
H
+
i_h
)
*
K
*
V
h
+=
(
(
boh
*
H
+
i_h
)
*
K
*
V
).
to
(
tl
.
int64
)
v
+=
(
bos
*
H
+
i_h
)
*
V
v
+=
(
(
bos
*
H
+
i_h
)
*
V
).
to
(
tl
.
int64
)
k
+=
(
bos
*
Hg
+
i_h
//
(
H
//
Hg
))
*
K
k
+=
(
(
bos
*
Hg
+
i_h
//
(
H
//
Hg
))
*
K
).
to
(
tl
.
int64
)
w
+=
(
bos
*
H
+
i_h
)
*
K
w
+=
(
(
bos
*
H
+
i_h
)
*
K
).
to
(
tl
.
int64
)
if
SAVE_NEW_VALUE
:
if
SAVE_NEW_VALUE
:
v_new
+=
(
bos
*
H
+
i_h
)
*
V
v_new
+=
(
(
bos
*
H
+
i_h
)
*
V
).
to
(
tl
.
int64
)
stride_v
=
H
*
V
stride_v
=
H
*
V
stride_h
=
H
*
K
*
V
stride_h
=
H
*
K
*
V
stride_k
=
Hg
*
K
stride_k
=
Hg
*
K
...
@@ -145,92 +148,115 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
...
@@ -145,92 +148,115 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
)
)
tl
.
store
(
p_h4
,
b_h4
.
to
(
p_h4
.
dtype
.
element_ty
),
boundary_check
=
(
0
,
1
))
tl
.
store
(
p_h4
,
b_h4
.
to
(
p_h4
.
dtype
.
element_ty
),
boundary_check
=
(
0
,
1
))
p_v
=
tl
.
make_block_ptr
(
v
,
(
T
,
V
),
(
stride_v
,
1
),
(
i_t
*
BT
,
i_v
*
BV
),
(
BT
,
BV
),
(
1
,
0
)
)
p_v_new
=
(
tl
.
make_block_ptr
(
v_new
,
(
T
,
V
),
(
stride_v
,
1
),
(
i_t
*
BT
,
i_v
*
BV
),
(
BT
,
BV
),
(
1
,
0
)
)
if
SAVE_NEW_VALUE
else
None
)
b_v_new
=
tl
.
zeros
([
BT
,
BV
],
dtype
=
tl
.
float32
)
p_w
=
tl
.
make_block_ptr
(
p_w
=
tl
.
make_block_ptr
(
w
,
(
T
,
K
),
(
stride_w
,
1
),
(
i_t
*
BT
,
0
),
(
BT
,
64
),
(
1
,
0
)
w
,
(
T
,
K
),
(
stride_w
,
1
),
(
i_t
*
BT
,
0
),
(
BT
,
64
),
(
1
,
0
)
)
)
b_w
=
tl
.
load
(
p_w
,
boundary_check
=
(
0
,
1
))
b_w
=
tl
.
load
(
p_w
,
boundary_check
=
(
0
,
1
))
b_v
_new
+
=
tl
.
dot
(
b_w
,
b_h1
.
to
(
b_w
.
dtype
))
b_v
=
tl
.
dot
(
b_w
,
b_h1
.
to
(
b_w
.
dtype
))
if
K
>
64
:
if
K
>
64
:
p_w
=
tl
.
make_block_ptr
(
p_w
=
tl
.
make_block_ptr
(
w
,
(
T
,
K
),
(
stride_w
,
1
),
(
i_t
*
BT
,
64
),
(
BT
,
64
),
(
1
,
0
)
w
,
(
T
,
K
),
(
stride_w
,
1
),
(
i_t
*
BT
,
64
),
(
BT
,
64
),
(
1
,
0
)
)
)
b_w
=
tl
.
load
(
p_w
,
boundary_check
=
(
0
,
1
))
b_w
=
tl
.
load
(
p_w
,
boundary_check
=
(
0
,
1
))
b_v
_new
+=
tl
.
dot
(
b_w
,
b_h2
.
to
(
b_w
.
dtype
))
b_v
+=
tl
.
dot
(
b_w
,
b_h2
.
to
(
b_w
.
dtype
))
if
K
>
128
:
if
K
>
128
:
p_w
=
tl
.
make_block_ptr
(
p_w
=
tl
.
make_block_ptr
(
w
,
(
T
,
K
),
(
stride_w
,
1
),
(
i_t
*
BT
,
128
),
(
BT
,
64
),
(
1
,
0
)
w
,
(
T
,
K
),
(
stride_w
,
1
),
(
i_t
*
BT
,
128
),
(
BT
,
64
),
(
1
,
0
)
)
)
b_w
=
tl
.
load
(
p_w
,
boundary_check
=
(
0
,
1
))
b_w
=
tl
.
load
(
p_w
,
boundary_check
=
(
0
,
1
))
b_v
_new
+=
tl
.
dot
(
b_w
,
b_h3
.
to
(
b_w
.
dtype
))
b_v
+=
tl
.
dot
(
b_w
,
b_h3
.
to
(
b_w
.
dtype
))
if
K
>
192
:
if
K
>
192
:
p_w
=
tl
.
make_block_ptr
(
p_w
=
tl
.
make_block_ptr
(
w
,
(
T
,
K
),
(
stride_w
,
1
),
(
i_t
*
BT
,
192
),
(
BT
,
64
),
(
1
,
0
)
w
,
(
T
,
K
),
(
stride_w
,
1
),
(
i_t
*
BT
,
192
),
(
BT
,
64
),
(
1
,
0
)
)
)
b_w
=
tl
.
load
(
p_w
,
boundary_check
=
(
0
,
1
))
b_w
=
tl
.
load
(
p_w
,
boundary_check
=
(
0
,
1
))
b_v_new
+=
tl
.
dot
(
b_w
,
b_h4
.
to
(
b_w
.
dtype
))
b_v
+=
tl
.
dot
(
b_w
,
b_h4
.
to
(
b_w
.
dtype
))
b_v_new
=
-
b_v_new
+
tl
.
load
(
p_v
,
boundary_check
=
(
0
,
1
))
p_v
=
tl
.
make_block_ptr
(
v
,
(
T
,
V
),
(
stride_v
,
1
),
(
i_t
*
BT
,
i_v
*
BV
),
(
BT
,
BV
),
(
1
,
0
)
)
b_v
=
tl
.
load
(
p_v
,
boundary_check
=
(
0
,
1
))
-
b_v
if
SAVE_NEW_VALUE
:
if
SAVE_NEW_VALUE
:
p_v
_new
=
tl
.
make_block_ptr
(
p_v
=
tl
.
make_block_ptr
(
v_new
,
(
T
,
V
),
(
stride_v
,
1
),
(
i_t
*
BT
,
i_v
*
BV
),
(
BT
,
BV
),
(
1
,
0
)
v_new
,
(
T
,
V
),
(
stride_v
,
1
),
(
i_t
*
BT
,
i_v
*
BV
),
(
BT
,
BV
),
(
1
,
0
)
)
)
tl
.
store
(
tl
.
store
(
p_v
,
b_v
.
to
(
p_v
.
dtype
.
element_ty
),
boundary_check
=
(
0
,
1
))
p_v_new
,
b_v_new
.
to
(
p_v_new
.
dtype
.
element_ty
),
boundary_check
=
(
0
,
1
)
)
last_idx
=
min
((
i_t
+
1
)
*
BT
,
T
)
-
1
if
USE_G
:
if
USE_G
:
m_t
=
(
i_t
*
BT
+
tl
.
arange
(
0
,
BT
))
<
T
m_t
=
(
i_t
*
BT
+
tl
.
arange
(
0
,
BT
))
<
T
last_idx
=
min
((
i_t
+
1
)
*
BT
,
T
)
-
1
b_g_last
=
tl
.
load
(
g
+
bos
*
H
+
last_idx
*
H
+
i_h
)
b_g_last
=
tl
.
load
(
g
+
bos
*
H
+
last_idx
*
H
+
i_h
)
p_g
=
tl
.
make_block_ptr
(
p_g
=
tl
.
make_block_ptr
(
g
+
bos
*
H
+
i_h
,
(
T
,),
(
H
,),
(
i_t
*
BT
,),
(
BT
,),
(
0
,)
g
+
bos
*
H
+
i_h
,
(
T
,),
(
H
,),
(
i_t
*
BT
,),
(
BT
,),
(
0
,)
)
)
b_g
=
tl
.
load
(
p_g
,
boundary_check
=
(
0
,))
b_g
=
tl
.
load
(
p_g
,
boundary_check
=
(
0
,))
b_v
_new
=
b_v
_new
*
tl
.
where
(
m_t
,
exp
(
b_g_last
-
b_g
),
0
)[:,
None
]
b_v
=
b_v
*
tl
.
where
(
m_t
,
exp
(
b_g_last
-
b_g
),
0
)[:,
None
]
b_g_last
=
exp
(
b_g_last
)
b_g_last
=
exp
(
b_g_last
)
b_h1
=
b_h1
*
b_g_last
b_h1
*=
b_g_last
if
K
>
64
:
b_h2
*=
b_g_last
if
K
>
128
:
b_h3
*=
b_g_last
if
K
>
192
:
b_h4
*=
b_g_last
if
USE_GK
:
o_k1
=
tl
.
arange
(
0
,
64
)
b_gk_last1
=
tl
.
load
(
gk
+
(
bos
+
last_idx
)
*
H
*
K
+
i_h
*
K
+
o_k1
,
mask
=
(
o_k1
<
K
),
other
=
0.0
,
)
b_h1
*=
exp
(
b_gk_last1
)[:,
None
]
if
K
>
64
:
if
K
>
64
:
b_h2
=
b_h2
*
b_g_last
o_k2
=
64
+
o_k1
b_gk_last2
=
tl
.
load
(
gk
+
(
bos
+
last_idx
)
*
H
*
K
+
i_h
*
K
+
o_k2
,
mask
=
(
o_k2
<
K
),
other
=
0.0
,
)
b_h2
*=
exp
(
b_gk_last2
)[:,
None
]
if
K
>
128
:
if
K
>
128
:
b_h3
=
b_h3
*
b_g_last
o_k3
=
128
+
o_k1
b_gk_last3
=
tl
.
load
(
gk
+
(
bos
+
last_idx
)
*
H
*
K
+
i_h
*
K
+
o_k3
,
mask
=
(
o_k3
<
K
),
other
=
0.0
,
)
b_h3
*=
exp
(
b_gk_last3
)[:,
None
]
if
K
>
192
:
if
K
>
192
:
b_h4
=
b_h4
*
b_g_last
o_k4
=
192
+
o_k1
b_v_new
=
b_v_new
.
to
(
k
.
dtype
.
element_ty
)
b_gk_last4
=
tl
.
load
(
gk
+
(
bos
+
last_idx
)
*
H
*
K
+
i_h
*
K
+
o_k4
,
mask
=
(
o_k4
<
K
),
other
=
0.0
,
)
b_h4
*=
exp
(
b_gk_last4
)[:,
None
]
b_v
=
b_v
.
to
(
k
.
dtype
.
element_ty
)
p_k
=
tl
.
make_block_ptr
(
p_k
=
tl
.
make_block_ptr
(
k
,
(
K
,
T
),
(
1
,
stride_k
),
(
0
,
i_t
*
BT
),
(
64
,
BT
),
(
0
,
1
)
k
,
(
K
,
T
),
(
1
,
stride_k
),
(
0
,
i_t
*
BT
),
(
64
,
BT
),
(
0
,
1
)
)
)
b_k
=
tl
.
load
(
p_k
,
boundary_check
=
(
0
,
1
))
b_k
=
tl
.
load
(
p_k
,
boundary_check
=
(
0
,
1
))
b_h1
+=
tl
.
dot
(
b_k
,
b_v
_new
)
b_h1
+=
tl
.
dot
(
b_k
,
b_v
)
if
K
>
64
:
if
K
>
64
:
p_k
=
tl
.
make_block_ptr
(
p_k
=
tl
.
make_block_ptr
(
k
,
(
K
,
T
),
(
1
,
stride_k
),
(
64
,
i_t
*
BT
),
(
64
,
BT
),
(
0
,
1
)
k
,
(
K
,
T
),
(
1
,
stride_k
),
(
64
,
i_t
*
BT
),
(
64
,
BT
),
(
0
,
1
)
)
)
b_k
=
tl
.
load
(
p_k
,
boundary_check
=
(
0
,
1
))
b_k
=
tl
.
load
(
p_k
,
boundary_check
=
(
0
,
1
))
b_h2
+=
tl
.
dot
(
b_k
,
b_v
_new
)
b_h2
+=
tl
.
dot
(
b_k
,
b_v
)
if
K
>
128
:
if
K
>
128
:
p_k
=
tl
.
make_block_ptr
(
p_k
=
tl
.
make_block_ptr
(
k
,
(
K
,
T
),
(
1
,
stride_k
),
(
128
,
i_t
*
BT
),
(
64
,
BT
),
(
0
,
1
)
k
,
(
K
,
T
),
(
1
,
stride_k
),
(
128
,
i_t
*
BT
),
(
64
,
BT
),
(
0
,
1
)
)
)
b_k
=
tl
.
load
(
p_k
,
boundary_check
=
(
0
,
1
))
b_k
=
tl
.
load
(
p_k
,
boundary_check
=
(
0
,
1
))
b_h3
+=
tl
.
dot
(
b_k
,
b_v
_new
)
b_h3
+=
tl
.
dot
(
b_k
,
b_v
)
if
K
>
192
:
if
K
>
192
:
p_k
=
tl
.
make_block_ptr
(
p_k
=
tl
.
make_block_ptr
(
k
,
(
K
,
T
),
(
1
,
stride_k
),
(
192
,
i_t
*
BT
),
(
64
,
BT
),
(
0
,
1
)
k
,
(
K
,
T
),
(
1
,
stride_k
),
(
192
,
i_t
*
BT
),
(
64
,
BT
),
(
0
,
1
)
)
)
b_k
=
tl
.
load
(
p_k
,
boundary_check
=
(
0
,
1
))
b_k
=
tl
.
load
(
p_k
,
boundary_check
=
(
0
,
1
))
b_h4
+=
tl
.
dot
(
b_k
,
b_v_new
)
b_h4
+=
tl
.
dot
(
b_k
,
b_v
)
# epilogue
# epilogue
if
STORE_FINAL_STATE
:
if
STORE_FINAL_STATE
:
p_ht
=
tl
.
make_block_ptr
(
ht
,
(
K
,
V
),
(
V
,
1
),
(
0
,
i_v
*
BV
),
(
64
,
BV
),
(
1
,
0
))
p_ht
=
tl
.
make_block_ptr
(
ht
,
(
K
,
V
),
(
V
,
1
),
(
0
,
i_v
*
BV
),
(
64
,
BV
),
(
1
,
0
))
...
@@ -257,12 +283,15 @@ def chunk_gated_delta_rule_fwd_h(
...
@@ -257,12 +283,15 @@ def chunk_gated_delta_rule_fwd_h(
w
:
torch
.
Tensor
,
w
:
torch
.
Tensor
,
u
:
torch
.
Tensor
,
u
:
torch
.
Tensor
,
g
:
torch
.
Tensor
|
None
=
None
,
g
:
torch
.
Tensor
|
None
=
None
,
gk
:
torch
.
Tensor
|
None
=
None
,
initial_state
:
torch
.
Tensor
|
None
=
None
,
initial_state
:
torch
.
Tensor
|
None
=
None
,
output_final_state
:
bool
=
False
,
output_final_state
:
bool
=
False
,
chunk_size
:
int
=
64
,
# SY: remove this argument and force chunk size 64?
chunk_size
:
int
=
64
,
# SY: remove this argument and force chunk size 64?
save_new_value
:
bool
=
True
,
save_new_value
:
bool
=
True
,
cu_seqlens
:
torch
.
LongTensor
|
None
=
None
,
cu_seqlens
:
torch
.
LongTensor
|
None
=
None
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
# This kernel is slightly different from fla to support Q/K with different head numbers.
# In fla, Q/K always have the same head number, so Hg is always equal to H.
B
,
T
,
Hg
,
K
,
V
=
*
k
.
shape
,
u
.
shape
[
-
1
]
B
,
T
,
Hg
,
K
,
V
=
*
k
.
shape
,
u
.
shape
[
-
1
]
H
=
u
.
shape
[
-
2
]
H
=
u
.
shape
[
-
2
]
BT
=
chunk_size
BT
=
chunk_size
...
@@ -299,6 +328,7 @@ def chunk_gated_delta_rule_fwd_h(
...
@@ -299,6 +328,7 @@ def chunk_gated_delta_rule_fwd_h(
w
=
w
,
w
=
w
,
v_new
=
v_new
,
v_new
=
v_new
,
g
=
g
,
g
=
g
,
gk
=
gk
,
h
=
h
,
h
=
h
,
h0
=
initial_state
,
h0
=
initial_state
,
ht
=
final_state
,
ht
=
final_state
,
...
...
vllm/model_executor/layers/fla/ops/chunk_scaled_dot_kkt.py
View file @
e88bdd60
...
@@ -18,8 +18,8 @@ from .op import exp
...
@@ -18,8 +18,8 @@ from .op import exp
@
triton
.
heuristics
(
@
triton
.
heuristics
(
{
{
"USE_G"
:
lambda
args
:
args
[
"g"
]
is
not
None
,
"IS_VARLEN"
:
lambda
args
:
args
[
"cu_seqlens"
]
is
not
None
,
"IS_VARLEN"
:
lambda
args
:
args
[
"cu_seqlens"
]
is
not
None
,
"USE_G"
:
lambda
args
:
args
[
"g_cumsum"
]
is
not
None
,
}
}
)
)
@
triton
.
autotune
(
@
triton
.
autotune
(
...
@@ -35,7 +35,7 @@ from .op import exp
...
@@ -35,7 +35,7 @@ from .op import exp
def
chunk_scaled_dot_kkt_fwd_kernel
(
def
chunk_scaled_dot_kkt_fwd_kernel
(
k
,
k
,
beta
,
beta
,
g
_cumsum
,
g
,
A
,
A
,
cu_seqlens
,
cu_seqlens
,
chunk_indices
,
chunk_indices
,
...
@@ -85,9 +85,7 @@ def chunk_scaled_dot_kkt_fwd_kernel(
...
@@ -85,9 +85,7 @@ def chunk_scaled_dot_kkt_fwd_kernel(
b_A
+=
tl
.
dot
(
b_kb
.
to
(
b_k
.
dtype
),
tl
.
trans
(
b_k
))
b_A
+=
tl
.
dot
(
b_kb
.
to
(
b_k
.
dtype
),
tl
.
trans
(
b_k
))
if
USE_G
:
if
USE_G
:
p_g
=
tl
.
make_block_ptr
(
p_g
=
tl
.
make_block_ptr
(
g
+
bos
*
H
+
i_h
,
(
T
,),
(
H
,),
(
i_t
*
BT
,),
(
BT
,),
(
0
,))
g_cumsum
+
bos
*
H
+
i_h
,
(
T
,),
(
H
,),
(
i_t
*
BT
,),
(
BT
,),
(
0
,)
)
b_g
=
tl
.
load
(
p_g
,
boundary_check
=
(
0
,))
b_g
=
tl
.
load
(
p_g
,
boundary_check
=
(
0
,))
b_g_diff
=
b_g
[:,
None
]
-
b_g
[
None
,
:]
b_g_diff
=
b_g
[:,
None
]
-
b_g
[
None
,
:]
b_A
=
b_A
*
exp
(
b_g_diff
)
b_A
=
b_A
*
exp
(
b_g_diff
)
...
@@ -102,8 +100,8 @@ def chunk_scaled_dot_kkt_fwd_kernel(
...
@@ -102,8 +100,8 @@ def chunk_scaled_dot_kkt_fwd_kernel(
def
chunk_scaled_dot_kkt_fwd
(
def
chunk_scaled_dot_kkt_fwd
(
k
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
beta
:
torch
.
Tensor
,
g
:
torch
.
Tensor
|
None
=
None
,
g_cumsum
:
torch
.
Tensor
|
None
=
None
,
beta
:
torch
.
Tensor
|
None
=
None
,
cu_seqlens
:
torch
.
LongTensor
|
None
=
None
,
cu_seqlens
:
torch
.
LongTensor
|
None
=
None
,
chunk_size
:
int
=
64
,
chunk_size
:
int
=
64
,
output_dtype
:
torch
.
dtype
=
torch
.
float32
,
output_dtype
:
torch
.
dtype
=
torch
.
float32
,
...
@@ -116,9 +114,8 @@ def chunk_scaled_dot_kkt_fwd(
...
@@ -116,9 +114,8 @@ def chunk_scaled_dot_kkt_fwd(
The key tensor of shape `[B, T, H, K]`.
The key tensor of shape `[B, T, H, K]`.
beta (torch.Tensor):
beta (torch.Tensor):
The beta tensor of shape `[B, T, H]`.
The beta tensor of shape `[B, T, H]`.
g_cumsum (torch.Tensor):
g (torch.Tensor):
The cumulative sum of the gate tensor of shape `[B, T, H]`.
The cumulative sum of the gate tensor of shape `[B, T, H]`. Default: `None`.
Default: None
cu_seqlens (torch.LongTensor):
cu_seqlens (torch.LongTensor):
The cumulative sequence lengths of the input tensor.
The cumulative sequence lengths of the input tensor.
Default: None
Default: None
...
@@ -130,20 +127,21 @@ def chunk_scaled_dot_kkt_fwd(
...
@@ -130,20 +127,21 @@ def chunk_scaled_dot_kkt_fwd(
Returns:
Returns:
beta * K * K^T of shape `[B, T, H, BT]` where `BT` is the chunk size.
beta * K * K^T of shape `[B, T, H, BT]` where `BT` is the chunk size.
"""
"""
# This kernel is slightly different from fla to support Q/K with different head numbers.
# In fla, Q/K always have the same head number, so Hg is always equal to H.
B
,
T
,
Hg
,
K
=
k
.
shape
B
,
T
,
Hg
,
K
=
k
.
shape
H
=
beta
.
shape
[
-
1
]
H
=
beta
.
shape
[
-
1
]
BT
=
chunk_size
BT
=
chunk_size
chunk_indices
=
(
chunk_indices
=
(
prepare_chunk_indices
(
cu_seqlens
,
BT
)
if
cu_seqlens
is
not
None
else
None
prepare_chunk_indices
(
cu_seqlens
,
BT
)
if
cu_seqlens
is
not
None
else
None
)
)
NT
=
triton
.
cdiv
(
T
,
BT
)
if
cu_seqlens
is
None
else
len
(
chunk_indices
)
NT
=
triton
.
cdiv
(
T
,
BT
)
if
cu_seqlens
is
None
else
len
(
chunk_indices
)
A
=
torch
.
empty
(
B
,
T
,
H
,
BT
,
device
=
k
.
device
,
dtype
=
output_dtype
)
A
=
torch
.
empty
(
B
,
T
,
H
,
BT
,
device
=
k
.
device
,
dtype
=
output_dtype
)
chunk_scaled_dot_kkt_fwd_kernel
[(
NT
,
B
*
H
)](
chunk_scaled_dot_kkt_fwd_kernel
[(
NT
,
B
*
H
)](
k
=
k
,
k
=
k
,
g
=
g
,
beta
=
beta
,
beta
=
beta
,
g_cumsum
=
g_cumsum
,
A
=
A
,
A
=
A
,
cu_seqlens
=
cu_seqlens
,
cu_seqlens
=
cu_seqlens
,
chunk_indices
=
chunk_indices
,
chunk_indices
=
chunk_indices
,
...
...
vllm/model_executor/layers/fla/ops/fused_recurrent.py
View file @
e88bdd60
...
@@ -57,6 +57,7 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
...
@@ -57,6 +57,7 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
IS_VARLEN
:
tl
.
constexpr
,
IS_VARLEN
:
tl
.
constexpr
,
IS_CONTINUOUS_BATCHING
:
tl
.
constexpr
,
IS_CONTINUOUS_BATCHING
:
tl
.
constexpr
,
IS_SPEC_DECODING
:
tl
.
constexpr
,
IS_SPEC_DECODING
:
tl
.
constexpr
,
IS_KDA
:
tl
.
constexpr
,
):
):
i_k
,
i_v
,
i_nh
=
tl
.
program_id
(
0
),
tl
.
program_id
(
1
),
tl
.
program_id
(
2
)
i_k
,
i_v
,
i_nh
=
tl
.
program_id
(
0
),
tl
.
program_id
(
1
),
tl
.
program_id
(
2
)
i_n
,
i_hv
=
i_nh
//
HV
,
i_nh
%
HV
i_n
,
i_hv
=
i_nh
//
HV
,
i_nh
%
HV
...
@@ -86,7 +87,12 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
...
@@ -86,7 +87,12 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
p_beta
=
beta
+
(
bos
*
HV
+
i_hv
)
*
V
+
o_v
p_beta
=
beta
+
(
bos
*
HV
+
i_hv
)
*
V
+
o_v
else
:
else
:
p_beta
=
beta
+
bos
*
HV
+
i_hv
p_beta
=
beta
+
bos
*
HV
+
i_hv
if
not
IS_KDA
:
p_g
=
g
+
bos
*
HV
+
i_hv
p_g
=
g
+
bos
*
HV
+
i_hv
else
:
p_gk
=
g
+
(
bos
*
HV
+
i_hv
)
*
K
+
o_k
p_o
=
o
+
((
i_k
*
all
+
bos
)
*
HV
+
i_hv
)
*
V
+
o_v
p_o
=
o
+
((
i_k
*
all
+
bos
)
*
HV
+
i_hv
)
*
V
+
o_v
mask_k
=
o_k
<
K
mask_k
=
o_k
<
K
...
@@ -116,14 +122,18 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
...
@@ -116,14 +122,18 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
b_q
=
tl
.
load
(
p_q
,
mask
=
mask_k
,
other
=
0
).
to
(
tl
.
float32
)
b_q
=
tl
.
load
(
p_q
,
mask
=
mask_k
,
other
=
0
).
to
(
tl
.
float32
)
b_k
=
tl
.
load
(
p_k
,
mask
=
mask_k
,
other
=
0
).
to
(
tl
.
float32
)
b_k
=
tl
.
load
(
p_k
,
mask
=
mask_k
,
other
=
0
).
to
(
tl
.
float32
)
b_v
=
tl
.
load
(
p_v
,
mask
=
mask_v
,
other
=
0
).
to
(
tl
.
float32
)
b_v
=
tl
.
load
(
p_v
,
mask
=
mask_v
,
other
=
0
).
to
(
tl
.
float32
)
b_g
=
tl
.
load
(
p_g
).
to
(
tl
.
float32
)
if
USE_QK_L2NORM_IN_KERNEL
:
if
USE_QK_L2NORM_IN_KERNEL
:
b_q
=
b_q
/
tl
.
sqrt
(
tl
.
sum
(
b_q
*
b_q
)
+
1e-6
)
b_q
=
b_q
/
tl
.
sqrt
(
tl
.
sum
(
b_q
*
b_q
)
+
1e-6
)
b_k
=
b_k
/
tl
.
sqrt
(
tl
.
sum
(
b_k
*
b_k
)
+
1e-6
)
b_k
=
b_k
/
tl
.
sqrt
(
tl
.
sum
(
b_k
*
b_k
)
+
1e-6
)
b_q
=
b_q
*
scale
b_q
=
b_q
*
scale
# [BK, BV]
# [BK, BV]
if
not
IS_KDA
:
b_g
=
tl
.
load
(
p_g
).
to
(
tl
.
float32
)
b_h
*=
exp
(
b_g
)
b_h
*=
exp
(
b_g
)
else
:
b_gk
=
tl
.
load
(
p_gk
).
to
(
tl
.
float32
)
b_h
*=
exp
(
b_gk
[:,
None
])
# [BV]
# [BV]
b_v
-=
tl
.
sum
(
b_h
*
b_k
[:,
None
],
0
)
b_v
-=
tl
.
sum
(
b_h
*
b_k
[:,
None
],
0
)
if
IS_BETA_HEADWISE
:
if
IS_BETA_HEADWISE
:
...
@@ -155,7 +165,10 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
...
@@ -155,7 +165,10 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
p_k
+=
H
*
K
p_k
+=
H
*
K
p_o
+=
HV
*
V
p_o
+=
HV
*
V
p_v
+=
HV
*
V
p_v
+=
HV
*
V
if
not
IS_KDA
:
p_g
+=
HV
p_g
+=
HV
else
:
p_gk
+=
HV
*
K
p_beta
+=
HV
*
(
V
if
IS_BETA_HEADWISE
else
1
)
p_beta
+=
HV
*
(
V
if
IS_BETA_HEADWISE
else
1
)
...
@@ -228,6 +241,7 @@ def fused_recurrent_gated_delta_rule_fwd(
...
@@ -228,6 +241,7 @@ def fused_recurrent_gated_delta_rule_fwd(
IS_BETA_HEADWISE
=
beta
.
ndim
==
v
.
ndim
,
IS_BETA_HEADWISE
=
beta
.
ndim
==
v
.
ndim
,
USE_QK_L2NORM_IN_KERNEL
=
use_qk_l2norm_in_kernel
,
USE_QK_L2NORM_IN_KERNEL
=
use_qk_l2norm_in_kernel
,
INPLACE_FINAL_STATE
=
inplace_final_state
,
INPLACE_FINAL_STATE
=
inplace_final_state
,
IS_KDA
=
False
,
num_warps
=
num_warps
,
num_warps
=
num_warps
,
num_stages
=
num_stages
,
num_stages
=
num_stages
,
)
)
...
...
vllm/model_executor/layers/fla/ops/kda.py
0 → 100644
View file @
e88bdd60
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# ruff: noqa: E501
import
torch
import
torch.nn
as
nn
from
vllm.triton_utils
import
tl
,
triton
from
vllm.utils.math_utils
import
cdiv
,
next_power_of_2
from
.chunk_delta_h
import
chunk_gated_delta_rule_fwd_h
from
.cumsum
import
chunk_local_cumsum
from
.fused_recurrent
import
fused_recurrent_gated_delta_rule_fwd_kernel
from
.index
import
prepare_chunk_indices
from
.l2norm
import
l2norm_fwd
from
.op
import
exp
,
log
from
.solve_tril
import
solve_tril
from
.utils
import
is_amd
BT_LIST_AUTOTUNE
=
[
32
,
64
,
128
]
NUM_WARPS_AUTOTUNE
=
[
2
,
4
,
8
,
16
]
if
is_amd
else
[
4
,
8
,
16
,
32
]
def
fused_recurrent_kda_fwd
(
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
g
:
torch
.
Tensor
,
beta
:
torch
.
Tensor
,
scale
:
float
,
initial_state
:
torch
.
Tensor
,
inplace_final_state
:
bool
=
True
,
cu_seqlens
:
torch
.
LongTensor
|
None
=
None
,
ssm_state_indices
:
torch
.
Tensor
|
None
=
None
,
num_accepted_tokens
:
torch
.
Tensor
|
None
=
None
,
use_qk_l2norm_in_kernel
:
bool
=
False
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
B
,
T
,
H
,
K
,
V
=
*
k
.
shape
,
v
.
shape
[
-
1
]
HV
=
v
.
shape
[
2
]
N
=
B
if
cu_seqlens
is
None
else
len
(
cu_seqlens
)
-
1
BK
,
BV
=
next_power_of_2
(
K
),
min
(
next_power_of_2
(
V
),
8
)
NK
,
NV
=
cdiv
(
K
,
BK
),
cdiv
(
V
,
BV
)
assert
NK
==
1
,
"NK > 1 is not supported yet"
num_stages
=
3
num_warps
=
1
o
=
torch
.
empty_like
(
k
)
if
inplace_final_state
:
final_state
=
initial_state
else
:
final_state
=
q
.
new_empty
(
T
,
HV
,
K
,
V
,
dtype
=
initial_state
.
dtype
)
stride_init_state_token
=
initial_state
.
stride
(
0
)
stride_final_state_token
=
final_state
.
stride
(
0
)
if
ssm_state_indices
is
None
:
stride_indices_seq
,
stride_indices_tok
=
1
,
1
elif
ssm_state_indices
.
ndim
==
1
:
stride_indices_seq
,
stride_indices_tok
=
ssm_state_indices
.
stride
(
0
),
1
else
:
stride_indices_seq
,
stride_indices_tok
=
ssm_state_indices
.
stride
()
grid
=
(
NK
,
NV
,
N
*
HV
)
fused_recurrent_gated_delta_rule_fwd_kernel
[
grid
](
q
=
q
,
k
=
k
,
v
=
v
,
g
=
g
,
beta
=
beta
,
o
=
o
,
h0
=
initial_state
,
ht
=
final_state
,
cu_seqlens
=
cu_seqlens
,
ssm_state_indices
=
ssm_state_indices
,
num_accepted_tokens
=
num_accepted_tokens
,
scale
=
scale
,
N
=
N
,
T
=
T
,
B
=
B
,
H
=
H
,
HV
=
HV
,
K
=
K
,
V
=
V
,
BK
=
BK
,
BV
=
BV
,
stride_init_state_token
=
stride_init_state_token
,
stride_final_state_token
=
stride_final_state_token
,
stride_indices_seq
=
stride_indices_seq
,
stride_indices_tok
=
stride_indices_tok
,
IS_BETA_HEADWISE
=
beta
.
ndim
==
v
.
ndim
,
USE_QK_L2NORM_IN_KERNEL
=
use_qk_l2norm_in_kernel
,
INPLACE_FINAL_STATE
=
inplace_final_state
,
IS_KDA
=
True
,
num_warps
=
num_warps
,
num_stages
=
num_stages
,
)
return
o
,
final_state
def
fused_recurrent_kda
(
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
g
:
torch
.
Tensor
,
beta
:
torch
.
Tensor
=
None
,
scale
:
float
=
None
,
initial_state
:
torch
.
Tensor
=
None
,
inplace_final_state
:
bool
=
True
,
use_qk_l2norm_in_kernel
:
bool
=
True
,
cu_seqlens
:
torch
.
LongTensor
|
None
=
None
,
ssm_state_indices
:
torch
.
LongTensor
|
None
=
None
,
**
kwargs
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
if
cu_seqlens
is
not
None
and
q
.
shape
[
0
]
!=
1
:
raise
ValueError
(
f
"The batch size is expected to be 1 rather than
{
q
.
shape
[
0
]
}
when using `cu_seqlens`."
f
"Please flatten variable-length inputs before processing."
)
if
scale
is
None
:
scale
=
k
.
shape
[
-
1
]
**
-
0.5
o
,
final_state
=
fused_recurrent_kda_fwd
(
q
=
q
.
contiguous
(),
k
=
k
.
contiguous
(),
v
=
v
.
contiguous
(),
g
=
g
.
contiguous
(),
beta
=
beta
.
contiguous
(),
scale
=
scale
,
initial_state
=
initial_state
,
inplace_final_state
=
inplace_final_state
,
cu_seqlens
=
cu_seqlens
,
ssm_state_indices
=
ssm_state_indices
,
num_accepted_tokens
=
None
,
use_qk_l2norm_in_kernel
=
use_qk_l2norm_in_kernel
,
)
return
o
,
final_state
@
triton
.
heuristics
(
{
"STORE_RESIDUAL_OUT"
:
lambda
args
:
args
[
"residual_out"
]
is
not
None
,
"HAS_RESIDUAL"
:
lambda
args
:
args
[
"residual"
]
is
not
None
,
"HAS_WEIGHT"
:
lambda
args
:
args
[
"w"
]
is
not
None
,
"HAS_BIAS"
:
lambda
args
:
args
[
"b"
]
is
not
None
,
}
)
@
triton
.
jit
def
layer_norm_gated_fwd_kernel
(
x
,
# pointer to the input
g
,
# pointer to the gate
y
,
# pointer to the output
w
,
# pointer to the weights
b
,
# pointer to the biases
residual
,
# pointer to the residual
residual_out
,
# pointer to the residual
mean
,
# pointer to the mean
rstd
,
# pointer to the 1/std
eps
,
# epsilon to avoid division by zero
T
,
# number of rows in x
D
:
tl
.
constexpr
,
# number of columns in x
BT
:
tl
.
constexpr
,
BD
:
tl
.
constexpr
,
ACTIVATION
:
tl
.
constexpr
,
IS_RMS_NORM
:
tl
.
constexpr
,
STORE_RESIDUAL_OUT
:
tl
.
constexpr
,
HAS_RESIDUAL
:
tl
.
constexpr
,
HAS_WEIGHT
:
tl
.
constexpr
,
HAS_BIAS
:
tl
.
constexpr
,
):
i_t
=
tl
.
program_id
(
0
)
o_d
=
tl
.
arange
(
0
,
BD
)
m_d
=
o_d
<
D
p_x
=
tl
.
make_block_ptr
(
x
,
(
T
,
D
),
(
D
,
1
),
(
i_t
*
BT
,
0
),
(
BT
,
BD
),
(
1
,
0
))
b_x
=
tl
.
load
(
p_x
,
boundary_check
=
(
0
,
1
)).
to
(
tl
.
float32
)
if
HAS_RESIDUAL
:
p_res
=
tl
.
make_block_ptr
(
residual
,
(
T
,
D
),
(
D
,
1
),
(
i_t
*
BT
,
0
),
(
BT
,
BD
),
(
1
,
0
)
)
b_x
+=
tl
.
load
(
p_res
,
boundary_check
=
(
0
,
1
)).
to
(
tl
.
float32
)
if
STORE_RESIDUAL_OUT
:
p_res_out
=
tl
.
make_block_ptr
(
residual_out
,
(
T
,
D
),
(
D
,
1
),
(
i_t
*
BT
,
0
),
(
BT
,
BD
),
(
1
,
0
)
)
tl
.
store
(
p_res_out
,
b_x
.
to
(
p_res_out
.
dtype
.
element_ty
),
boundary_check
=
(
0
,
1
))
if
not
IS_RMS_NORM
:
b_mean
=
tl
.
sum
(
b_x
,
axis
=
1
)
/
D
p_mean
=
tl
.
make_block_ptr
(
mean
,
(
T
,),
(
1
,),
(
i_t
*
BT
,),
(
BT
,),
(
0
,))
tl
.
store
(
p_mean
,
b_mean
.
to
(
p_mean
.
dtype
.
element_ty
),
boundary_check
=
(
0
,))
b_xbar
=
tl
.
where
(
m_d
[
None
,
:],
b_x
-
b_mean
[:,
None
],
0.0
)
b_var
=
tl
.
sum
(
b_xbar
*
b_xbar
,
axis
=
1
)
/
D
else
:
b_xbar
=
tl
.
where
(
m_d
[
None
,
:],
b_x
,
0.0
)
b_var
=
tl
.
sum
(
b_xbar
*
b_xbar
,
axis
=
1
)
/
D
b_rstd
=
1
/
tl
.
sqrt
(
b_var
+
eps
)
p_rstd
=
tl
.
make_block_ptr
(
rstd
,
(
T
,),
(
1
,),
(
i_t
*
BT
,),
(
BT
,),
(
0
,))
tl
.
store
(
p_rstd
,
b_rstd
.
to
(
p_rstd
.
dtype
.
element_ty
),
boundary_check
=
(
0
,))
if
HAS_WEIGHT
:
b_w
=
tl
.
load
(
w
+
o_d
,
mask
=
m_d
).
to
(
tl
.
float32
)
if
HAS_BIAS
:
b_b
=
tl
.
load
(
b
+
o_d
,
mask
=
m_d
).
to
(
tl
.
float32
)
b_x_hat
=
(
(
b_x
-
b_mean
[:,
None
])
*
b_rstd
[:,
None
]
if
not
IS_RMS_NORM
else
b_x
*
b_rstd
[:,
None
]
)
b_y
=
b_x_hat
*
b_w
[
None
,
:]
if
HAS_WEIGHT
else
b_x_hat
if
HAS_BIAS
:
b_y
=
b_y
+
b_b
[
None
,
:]
# swish/sigmoid output gate
p_g
=
tl
.
make_block_ptr
(
g
,
(
T
,
D
),
(
D
,
1
),
(
i_t
*
BT
,
0
),
(
BT
,
BD
),
(
1
,
0
))
b_g
=
tl
.
load
(
p_g
,
boundary_check
=
(
0
,
1
)).
to
(
tl
.
float32
)
if
ACTIVATION
==
"swish"
or
ACTIVATION
==
"silu"
:
b_y
=
b_y
*
b_g
*
tl
.
sigmoid
(
b_g
)
elif
ACTIVATION
==
"sigmoid"
:
b_y
=
b_y
*
tl
.
sigmoid
(
b_g
)
# Write output
p_y
=
tl
.
make_block_ptr
(
y
,
(
T
,
D
),
(
D
,
1
),
(
i_t
*
BT
,
0
),
(
BT
,
BD
),
(
1
,
0
))
tl
.
store
(
p_y
,
b_y
.
to
(
p_y
.
dtype
.
element_ty
),
boundary_check
=
(
0
,
1
))
@
triton
.
heuristics
(
{
"STORE_RESIDUAL_OUT"
:
lambda
args
:
args
[
"residual_out"
]
is
not
None
,
"HAS_RESIDUAL"
:
lambda
args
:
args
[
"residual"
]
is
not
None
,
"HAS_WEIGHT"
:
lambda
args
:
args
[
"w"
]
is
not
None
,
"HAS_BIAS"
:
lambda
args
:
args
[
"b"
]
is
not
None
,
}
)
@
triton
.
jit
def
layer_norm_gated_fwd_kernel1
(
x
,
# pointer to the input
g
,
# pointer to the gate
y
,
# pointer to the output
w
,
# pointer to the weights
b
,
# pointer to the biases
residual
,
# pointer to the residual
residual_out
,
# pointer to the residual
mean
,
# pointer to the mean
rstd
,
# pointer to the 1/std
eps
,
# epsilon to avoid division by zero
D
:
tl
.
constexpr
,
# number of columns in x
BD
:
tl
.
constexpr
,
ACTIVATION
:
tl
.
constexpr
,
IS_RMS_NORM
:
tl
.
constexpr
,
STORE_RESIDUAL_OUT
:
tl
.
constexpr
,
HAS_RESIDUAL
:
tl
.
constexpr
,
HAS_WEIGHT
:
tl
.
constexpr
,
HAS_BIAS
:
tl
.
constexpr
,
):
i_t
=
tl
.
program_id
(
0
)
x
+=
i_t
*
D
y
+=
i_t
*
D
g
+=
i_t
*
D
if
HAS_RESIDUAL
:
residual
+=
i_t
*
D
if
STORE_RESIDUAL_OUT
:
residual_out
+=
i_t
*
D
o_d
=
tl
.
arange
(
0
,
BD
)
m_d
=
o_d
<
D
b_x
=
tl
.
load
(
x
+
o_d
,
mask
=
m_d
,
other
=
0.0
).
to
(
tl
.
float32
)
if
HAS_RESIDUAL
:
b_x
+=
tl
.
load
(
residual
+
o_d
,
mask
=
m_d
,
other
=
0.0
).
to
(
tl
.
float32
)
if
STORE_RESIDUAL_OUT
:
tl
.
store
(
residual_out
+
o_d
,
b_x
,
mask
=
m_d
)
if
not
IS_RMS_NORM
:
b_mean
=
tl
.
sum
(
b_x
,
axis
=
0
)
/
D
tl
.
store
(
mean
+
i_t
,
b_mean
)
b_xbar
=
tl
.
where
(
m_d
,
b_x
-
b_mean
,
0.0
)
b_var
=
tl
.
sum
(
b_xbar
*
b_xbar
,
axis
=
0
)
/
D
else
:
b_xbar
=
tl
.
where
(
m_d
,
b_x
,
0.0
)
b_var
=
tl
.
sum
(
b_xbar
*
b_xbar
,
axis
=
0
)
/
D
b_rstd
=
1
/
tl
.
sqrt
(
b_var
+
eps
)
tl
.
store
(
rstd
+
i_t
,
b_rstd
)
if
HAS_WEIGHT
:
b_w
=
tl
.
load
(
w
+
o_d
,
mask
=
m_d
).
to
(
tl
.
float32
)
if
HAS_BIAS
:
b_b
=
tl
.
load
(
b
+
o_d
,
mask
=
m_d
).
to
(
tl
.
float32
)
b_x_hat
=
(
b_x
-
b_mean
)
*
b_rstd
if
not
IS_RMS_NORM
else
b_x
*
b_rstd
b_y
=
b_x_hat
*
b_w
if
HAS_WEIGHT
else
b_x_hat
if
HAS_BIAS
:
b_y
=
b_y
+
b_b
# swish/sigmoid output gate
b_g
=
tl
.
load
(
g
+
o_d
,
mask
=
m_d
,
other
=
0.0
).
to
(
tl
.
float32
)
if
ACTIVATION
==
"swish"
or
ACTIVATION
==
"silu"
:
b_y
=
b_y
*
b_g
*
tl
.
sigmoid
(
b_g
)
elif
ACTIVATION
==
"sigmoid"
:
b_y
=
b_y
*
tl
.
sigmoid
(
b_g
)
# Write output
tl
.
store
(
y
+
o_d
,
b_y
,
mask
=
m_d
)
def
layer_norm_gated_fwd
(
x
:
torch
.
Tensor
,
g
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
bias
:
torch
.
Tensor
,
activation
:
str
=
"swish"
,
eps
:
float
=
1e-5
,
residual
:
torch
.
Tensor
=
None
,
out_dtype
:
torch
.
dtype
=
None
,
residual_dtype
:
torch
.
dtype
=
None
,
is_rms_norm
:
bool
=
False
,
):
if
residual
is
not
None
:
residual_dtype
=
residual
.
dtype
T
,
D
=
x
.
shape
if
residual
is
not
None
:
assert
residual
.
shape
==
(
T
,
D
)
if
weight
is
not
None
:
assert
weight
.
shape
==
(
D
,)
if
bias
is
not
None
:
assert
bias
.
shape
==
(
D
,)
# allocate output
y
=
x
if
out_dtype
is
None
else
torch
.
empty_like
(
x
,
dtype
=
out_dtype
)
if
residual
is
not
None
or
(
residual_dtype
is
not
None
and
residual_dtype
!=
x
.
dtype
):
residual_out
=
torch
.
empty
(
T
,
D
,
device
=
x
.
device
,
dtype
=
residual_dtype
)
else
:
residual_out
=
None
mean
=
(
torch
.
empty
((
T
,),
dtype
=
torch
.
float
,
device
=
x
.
device
)
if
not
is_rms_norm
else
None
)
rstd
=
torch
.
empty
((
T
,),
dtype
=
torch
.
float
,
device
=
x
.
device
)
# Less than 64KB per feature: enqueue fused kernel
MAX_FUSED_SIZE
=
65536
//
x
.
element_size
()
BD
=
min
(
MAX_FUSED_SIZE
,
next_power_of_2
(
D
))
if
D
>
BD
:
raise
RuntimeError
(
"This layer norm doesn't support feature dim >= 64KB."
)
# heuristics for number of warps
if
D
<=
512
:
BT
=
32
layer_norm_gated_fwd_kernel
[(
cdiv
(
T
,
BT
),)](
x
=
x
,
g
=
g
,
y
=
y
,
w
=
weight
,
b
=
bias
,
residual
=
residual
,
residual_out
=
residual_out
,
mean
=
mean
,
rstd
=
rstd
,
eps
=
eps
,
T
=
T
,
D
=
D
,
BD
=
BD
,
BT
=
BT
,
ACTIVATION
=
activation
,
IS_RMS_NORM
=
is_rms_norm
,
num_warps
=
4
,
)
else
:
layer_norm_gated_fwd_kernel1
[(
T
,)](
x
=
x
,
g
=
g
,
y
=
y
,
w
=
weight
,
b
=
bias
,
residual
=
residual
,
residual_out
=
residual_out
,
mean
=
mean
,
rstd
=
rstd
,
eps
=
eps
,
D
=
D
,
BD
=
BD
,
ACTIVATION
=
activation
,
IS_RMS_NORM
=
is_rms_norm
,
num_warps
=
4
,
)
# residual_out is None if residual is None and residual_dtype == input_dtype
return
y
,
mean
,
rstd
,
residual_out
if
residual_out
is
not
None
else
x
def
rms_norm_gated
(
x
:
torch
.
Tensor
,
g
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
bias
:
torch
.
Tensor
,
activation
:
str
=
"swish"
,
residual
:
torch
.
Tensor
|
None
=
None
,
prenorm
:
bool
=
False
,
residual_in_fp32
:
bool
=
False
,
eps
:
float
=
1e-6
,
):
x_shape_og
=
x
.
shape
# reshape input data into 2D tensor
x
=
x
.
contiguous
().
reshape
(
-
1
,
x
.
shape
[
-
1
])
g
=
g
.
contiguous
().
reshape
(
-
1
,
g
.
shape
[
-
1
])
if
residual
is
not
None
:
assert
residual
.
shape
==
x_shape_og
residual
=
residual
.
contiguous
().
reshape
(
-
1
,
residual
.
shape
[
-
1
])
residual_dtype
=
(
residual
.
dtype
if
residual
is
not
None
else
(
torch
.
float
if
residual_in_fp32
else
None
)
)
y
,
_
,
_
,
residual_out
=
layer_norm_gated_fwd
(
x
=
x
,
g
=
g
,
weight
=
weight
,
bias
=
bias
,
activation
=
activation
,
eps
=
eps
,
residual
=
residual
,
residual_dtype
=
residual_dtype
,
is_rms_norm
=
True
,
)
y
=
y
.
reshape
(
x_shape_og
)
return
y
if
not
prenorm
else
(
y
,
residual_out
.
reshape
(
x_shape_og
))
class
FusedRMSNormGated
(
nn
.
Module
):
def
__init__
(
self
,
hidden_size
:
int
,
elementwise_affine
:
bool
=
True
,
eps
:
float
=
1e-5
,
activation
:
str
=
"swish"
,
device
:
torch
.
device
|
None
=
None
,
dtype
:
torch
.
dtype
|
None
=
None
,
)
->
None
:
factory_kwargs
=
{
"device"
:
device
,
"dtype"
:
dtype
}
super
().
__init__
()
self
.
hidden_size
=
hidden_size
self
.
elementwise_affine
=
elementwise_affine
self
.
eps
=
eps
self
.
activation
=
activation
if
self
.
activation
not
in
[
"swish"
,
"silu"
,
"sigmoid"
]:
raise
ValueError
(
f
"Unsupported activation:
{
self
.
activation
}
"
)
if
elementwise_affine
:
self
.
weight
=
nn
.
Parameter
(
torch
.
empty
(
hidden_size
,
**
factory_kwargs
))
else
:
self
.
register_parameter
(
"weight"
,
None
)
self
.
register_parameter
(
"bias"
,
None
)
def
forward
(
self
,
x
:
torch
.
Tensor
,
g
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
|
None
=
None
,
prenorm
:
bool
=
False
,
residual_in_fp32
:
bool
=
False
,
)
->
torch
.
Tensor
:
return
rms_norm_gated
(
x
,
g
,
self
.
weight
,
self
.
bias
,
self
.
activation
,
residual
=
residual
,
eps
=
self
.
eps
,
prenorm
=
prenorm
,
residual_in_fp32
=
residual_in_fp32
,
)
@
triton
.
heuristics
({
"IS_VARLEN"
:
lambda
args
:
args
[
"cu_seqlens"
]
is
not
None
})
@
triton
.
autotune
(
configs
=
[
triton
.
Config
({
"BK"
:
BK
},
num_warps
=
num_warps
,
num_stages
=
num_stages
)
for
BK
in
[
32
,
64
]
for
num_warps
in
[
1
,
2
,
4
,
8
]
for
num_stages
in
[
2
,
3
,
4
]
],
key
=
[
"BC"
],
)
@
triton
.
jit
(
do_not_specialize
=
[
"T"
])
def
chunk_kda_scaled_dot_kkt_fwd_kernel_intra_sub_inter
(
q
,
k
,
g
,
beta
,
A
,
Aqk
,
scale
,
cu_seqlens
,
chunk_indices
,
T
,
H
:
tl
.
constexpr
,
K
:
tl
.
constexpr
,
BT
:
tl
.
constexpr
,
BC
:
tl
.
constexpr
,
BK
:
tl
.
constexpr
,
NC
:
tl
.
constexpr
,
IS_VARLEN
:
tl
.
constexpr
,
):
i_t
,
i_c
,
i_bh
=
tl
.
program_id
(
0
),
tl
.
program_id
(
1
),
tl
.
program_id
(
2
)
i_b
,
i_h
=
i_bh
//
H
,
i_bh
%
H
i_i
,
i_j
=
i_c
//
NC
,
i_c
%
NC
if
IS_VARLEN
:
i_n
,
i_t
=
(
tl
.
load
(
chunk_indices
+
i_t
*
2
).
to
(
tl
.
int32
),
tl
.
load
(
chunk_indices
+
i_t
*
2
+
1
).
to
(
tl
.
int32
),
)
bos
,
eos
=
(
tl
.
load
(
cu_seqlens
+
i_n
).
to
(
tl
.
int32
),
tl
.
load
(
cu_seqlens
+
i_n
+
1
).
to
(
tl
.
int32
),
)
T
=
eos
-
bos
else
:
bos
,
eos
=
i_b
*
T
,
i_b
*
T
+
T
if
i_t
*
BT
+
i_i
*
BC
>=
T
:
return
if
i_i
<=
i_j
:
return
q
+=
(
bos
*
H
+
i_h
)
*
K
k
+=
(
bos
*
H
+
i_h
)
*
K
g
+=
(
bos
*
H
+
i_h
)
*
K
A
+=
(
bos
*
H
+
i_h
)
*
BT
Aqk
+=
(
bos
*
H
+
i_h
)
*
BT
p_b
=
tl
.
make_block_ptr
(
beta
+
bos
*
H
+
i_h
,
(
T
,),
(
H
,),
(
i_t
*
BT
+
i_i
*
BC
,),
(
BC
,),
(
0
,)
)
b_b
=
tl
.
load
(
p_b
,
boundary_check
=
(
0
,))
b_A
=
tl
.
zeros
([
BC
,
BC
],
dtype
=
tl
.
float32
)
b_Aqk
=
tl
.
zeros
([
BC
,
BC
],
dtype
=
tl
.
float32
)
for
i_k
in
range
(
tl
.
cdiv
(
K
,
BK
)):
p_q
=
tl
.
make_block_ptr
(
q
,
(
T
,
K
),
(
H
*
K
,
1
),
(
i_t
*
BT
+
i_i
*
BC
,
i_k
*
BK
),
(
BC
,
BK
),
(
1
,
0
)
)
p_k
=
tl
.
make_block_ptr
(
k
,
(
T
,
K
),
(
H
*
K
,
1
),
(
i_t
*
BT
+
i_i
*
BC
,
i_k
*
BK
),
(
BC
,
BK
),
(
1
,
0
)
)
p_g
=
tl
.
make_block_ptr
(
g
,
(
T
,
K
),
(
H
*
K
,
1
),
(
i_t
*
BT
+
i_i
*
BC
,
i_k
*
BK
),
(
BC
,
BK
),
(
1
,
0
)
)
b_kt
=
tl
.
make_block_ptr
(
k
,
(
K
,
T
),
(
1
,
H
*
K
),
(
i_k
*
BK
,
i_t
*
BT
+
i_j
*
BC
),
(
BK
,
BC
),
(
0
,
1
)
)
p_gk
=
tl
.
make_block_ptr
(
g
,
(
K
,
T
),
(
1
,
H
*
K
),
(
i_k
*
BK
,
i_t
*
BT
+
i_j
*
BC
),
(
BK
,
BC
),
(
0
,
1
)
)
o_k
=
i_k
*
BK
+
tl
.
arange
(
0
,
BK
)
m_k
=
o_k
<
K
# [BK,]
b_gn
=
tl
.
load
(
g
+
(
i_t
*
BT
+
i_i
*
BC
)
*
H
*
K
+
o_k
,
mask
=
m_k
,
other
=
0
)
# [BC, BK]
b_g
=
tl
.
load
(
p_g
,
boundary_check
=
(
0
,
1
))
b_k
=
tl
.
load
(
p_k
,
boundary_check
=
(
0
,
1
))
*
exp
(
b_g
-
b_gn
[
None
,
:])
# [BK, BC]
b_gk
=
tl
.
load
(
p_gk
,
boundary_check
=
(
0
,
1
))
b_kt
=
tl
.
load
(
b_kt
,
boundary_check
=
(
0
,
1
))
# [BC, BC]
b_ktg
=
b_kt
*
exp
(
b_gn
[:,
None
]
-
b_gk
)
b_A
+=
tl
.
dot
(
b_k
,
b_ktg
)
b_q
=
tl
.
load
(
p_q
,
boundary_check
=
(
0
,
1
))
b_qg
=
b_q
*
exp
(
b_g
-
b_gn
[
None
,
:])
*
scale
b_Aqk
+=
tl
.
dot
(
b_qg
,
b_ktg
)
b_A
*=
b_b
[:,
None
]
p_A
=
tl
.
make_block_ptr
(
A
,
(
T
,
BT
),
(
H
*
BT
,
1
),
(
i_t
*
BT
+
i_i
*
BC
,
i_j
*
BC
),
(
BC
,
BC
),
(
1
,
0
)
)
tl
.
store
(
p_A
,
b_A
.
to
(
A
.
dtype
.
element_ty
),
boundary_check
=
(
0
,
1
))
p_Aqk
=
tl
.
make_block_ptr
(
Aqk
,
(
T
,
BT
),
(
H
*
BT
,
1
),
(
i_t
*
BT
+
i_i
*
BC
,
i_j
*
BC
),
(
BC
,
BC
),
(
1
,
0
)
)
tl
.
store
(
p_Aqk
,
b_Aqk
.
to
(
Aqk
.
dtype
.
element_ty
),
boundary_check
=
(
0
,
1
))
@
triton
.
heuristics
({
"IS_VARLEN"
:
lambda
args
:
args
[
"cu_seqlens"
]
is
not
None
})
@
triton
.
autotune
(
configs
=
[
triton
.
Config
({},
num_warps
=
num_warps
)
for
num_warps
in
[
1
,
2
,
4
,
8
]],
key
=
[
"BK"
,
"BT"
],
)
@
triton
.
jit
(
do_not_specialize
=
[
"T"
])
def
chunk_kda_scaled_dot_kkt_fwd_kernel_intra_sub_intra
(
q
,
k
,
g
,
beta
,
A
,
Aqk
,
scale
,
cu_seqlens
,
chunk_indices
,
T
,
H
:
tl
.
constexpr
,
K
:
tl
.
constexpr
,
BT
:
tl
.
constexpr
,
BC
:
tl
.
constexpr
,
BK
:
tl
.
constexpr
,
IS_VARLEN
:
tl
.
constexpr
,
):
i_t
,
i_i
,
i_bh
=
tl
.
program_id
(
0
),
tl
.
program_id
(
1
),
tl
.
program_id
(
2
)
i_b
,
i_h
=
i_bh
//
H
,
i_bh
%
H
if
IS_VARLEN
:
i_n
,
i_t
=
(
tl
.
load
(
chunk_indices
+
i_t
*
2
).
to
(
tl
.
int32
),
tl
.
load
(
chunk_indices
+
i_t
*
2
+
1
).
to
(
tl
.
int32
),
)
bos
,
eos
=
(
tl
.
load
(
cu_seqlens
+
i_n
).
to
(
tl
.
int32
),
tl
.
load
(
cu_seqlens
+
i_n
+
1
).
to
(
tl
.
int32
),
)
T
=
eos
-
bos
else
:
bos
,
eos
=
i_b
*
T
,
i_b
*
T
+
T
if
i_t
*
BT
+
i_i
*
BC
>=
T
:
return
o_i
=
tl
.
arange
(
0
,
BC
)
o_k
=
tl
.
arange
(
0
,
BK
)
m_k
=
o_k
<
K
m_A
=
(
i_t
*
BT
+
i_i
*
BC
+
o_i
)
<
T
o_A
=
(
bos
+
i_t
*
BT
+
i_i
*
BC
+
o_i
)
*
H
*
BT
+
i_h
*
BT
+
i_i
*
BC
p_q
=
tl
.
make_block_ptr
(
q
+
(
bos
*
H
+
i_h
)
*
K
,
(
T
,
K
),
(
H
*
K
,
1
),
(
i_t
*
BT
+
i_i
*
BC
,
0
),
(
BC
,
BK
),
(
1
,
0
),
)
p_k
=
tl
.
make_block_ptr
(
k
+
(
bos
*
H
+
i_h
)
*
K
,
(
T
,
K
),
(
H
*
K
,
1
),
(
i_t
*
BT
+
i_i
*
BC
,
0
),
(
BC
,
BK
),
(
1
,
0
),
)
p_g
=
tl
.
make_block_ptr
(
g
+
(
bos
*
H
+
i_h
)
*
K
,
(
T
,
K
),
(
H
*
K
,
1
),
(
i_t
*
BT
+
i_i
*
BC
,
0
),
(
BC
,
BK
),
(
1
,
0
),
)
b_q
=
tl
.
load
(
p_q
,
boundary_check
=
(
0
,
1
))
b_k
=
tl
.
load
(
p_k
,
boundary_check
=
(
0
,
1
))
b_g
=
tl
.
load
(
p_g
,
boundary_check
=
(
0
,
1
))
p_b
=
beta
+
(
bos
+
i_t
*
BT
+
i_i
*
BC
+
o_i
)
*
H
+
i_h
b_k
=
b_k
*
tl
.
load
(
p_b
,
mask
=
m_A
,
other
=
0
)[:,
None
]
p_kt
=
k
+
(
bos
+
i_t
*
BT
+
i_i
*
BC
)
*
H
*
K
+
i_h
*
K
+
o_k
p_gk
=
g
+
(
bos
+
i_t
*
BT
+
i_i
*
BC
)
*
H
*
K
+
i_h
*
K
+
o_k
for
j
in
range
(
0
,
min
(
BC
,
T
-
i_t
*
BT
-
i_i
*
BC
)):
b_kt
=
tl
.
load
(
p_kt
,
mask
=
m_k
,
other
=
0
).
to
(
tl
.
float32
)
b_gk
=
tl
.
load
(
p_gk
,
mask
=
m_k
,
other
=
0
).
to
(
tl
.
float32
)
b_ktg
=
b_kt
[
None
,
:]
*
exp
(
b_g
-
b_gk
[
None
,
:])
b_A
=
tl
.
sum
(
b_k
*
b_ktg
,
1
)
b_A
=
tl
.
where
(
o_i
>
j
,
b_A
,
0.0
)
b_Aqk
=
tl
.
sum
(
b_q
*
b_ktg
,
1
)
b_Aqk
=
tl
.
where
(
o_i
>=
j
,
b_Aqk
*
scale
,
0.0
)
tl
.
store
(
A
+
o_A
+
j
,
b_A
,
mask
=
m_A
)
tl
.
store
(
Aqk
+
o_A
+
j
,
b_Aqk
,
mask
=
m_A
)
p_kt
+=
H
*
K
p_gk
+=
H
*
K
def
chunk_kda_scaled_dot_kkt_fwd
(
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
gk
:
torch
.
Tensor
|
None
=
None
,
beta
:
torch
.
Tensor
|
None
=
None
,
scale
:
float
|
None
=
None
,
cu_seqlens
:
torch
.
LongTensor
|
None
=
None
,
chunk_size
:
int
=
64
,
output_dtype
:
torch
.
dtype
=
torch
.
float32
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
r
"""
Compute beta * K * K^T.
Args:
k (torch.Tensor):
The key tensor of shape `[B, T, H, K]`.
beta (torch.Tensor):
The beta tensor of shape `[B, T, H]`.
gk (torch.Tensor):
The cumulative sum of the gate tensor of shape `[B, T, H, K]` applied to the key tensor. Default: `None`.
cu_seqlens (torch.LongTensor):
The cumulative sequence lengths of the input tensor.
Default: None
chunk_size (int):
The chunk size. Default: 64.
output_dtype (torch.dtype):
The dtype of the output tensor. Default: `torch.float32`
Returns:
beta * K * K^T of shape `[B, T, H, BT]` where `BT` is the chunk size.
"""
B
,
T
,
H
,
K
=
k
.
shape
assert
K
<=
256
BT
=
chunk_size
chunk_indices
=
(
prepare_chunk_indices
(
cu_seqlens
,
BT
)
if
cu_seqlens
is
not
None
else
None
)
NT
=
cdiv
(
T
,
BT
)
if
cu_seqlens
is
None
else
len
(
chunk_indices
)
BC
=
min
(
16
,
BT
)
NC
=
cdiv
(
BT
,
BC
)
BK
=
max
(
next_power_of_2
(
K
),
16
)
A
=
torch
.
zeros
(
B
,
T
,
H
,
BT
,
device
=
k
.
device
,
dtype
=
output_dtype
)
Aqk
=
torch
.
zeros
(
B
,
T
,
H
,
BT
,
device
=
k
.
device
,
dtype
=
output_dtype
)
grid
=
(
NT
,
NC
*
NC
,
B
*
H
)
chunk_kda_scaled_dot_kkt_fwd_kernel_intra_sub_inter
[
grid
](
q
=
q
,
k
=
k
,
g
=
gk
,
beta
=
beta
,
A
=
A
,
Aqk
=
Aqk
,
scale
=
scale
,
cu_seqlens
=
cu_seqlens
,
chunk_indices
=
chunk_indices
,
T
=
T
,
H
=
H
,
K
=
K
,
BT
=
BT
,
BC
=
BC
,
NC
=
NC
,
)
grid
=
(
NT
,
NC
,
B
*
H
)
chunk_kda_scaled_dot_kkt_fwd_kernel_intra_sub_intra
[
grid
](
q
=
q
,
k
=
k
,
g
=
gk
,
beta
=
beta
,
A
=
A
,
Aqk
=
Aqk
,
scale
=
scale
,
cu_seqlens
=
cu_seqlens
,
chunk_indices
=
chunk_indices
,
T
=
T
,
H
=
H
,
K
=
K
,
BT
=
BT
,
BC
=
BC
,
BK
=
BK
,
)
return
A
,
Aqk
@
triton
.
heuristics
(
{
"STORE_QG"
:
lambda
args
:
args
[
"qg"
]
is
not
None
,
"STORE_KG"
:
lambda
args
:
args
[
"kg"
]
is
not
None
,
"IS_VARLEN"
:
lambda
args
:
args
[
"cu_seqlens"
]
is
not
None
,
}
)
@
triton
.
autotune
(
configs
=
[
triton
.
Config
({},
num_warps
=
num_warps
,
num_stages
=
num_stages
)
for
num_warps
in
[
2
,
4
,
8
]
for
num_stages
in
[
2
,
3
,
4
]
],
key
=
[
"H"
,
"K"
,
"V"
,
"BT"
,
"BK"
,
"BV"
,
"IS_VARLEN"
],
)
@
triton
.
jit
(
do_not_specialize
=
[
"T"
])
def
recompute_w_u_fwd_kernel
(
q
,
k
,
qg
,
kg
,
v
,
beta
,
w
,
u
,
A
,
gk
,
cu_seqlens
,
chunk_indices
,
T
,
H
:
tl
.
constexpr
,
K
:
tl
.
constexpr
,
V
:
tl
.
constexpr
,
BT
:
tl
.
constexpr
,
BK
:
tl
.
constexpr
,
BV
:
tl
.
constexpr
,
STORE_QG
:
tl
.
constexpr
,
STORE_KG
:
tl
.
constexpr
,
IS_VARLEN
:
tl
.
constexpr
,
DOT_PRECISION
:
tl
.
constexpr
,
):
i_t
,
i_bh
=
tl
.
program_id
(
0
),
tl
.
program_id
(
1
)
i_b
,
i_h
=
i_bh
//
H
,
i_bh
%
H
if
IS_VARLEN
:
i_n
,
i_t
=
(
tl
.
load
(
chunk_indices
+
i_t
*
2
).
to
(
tl
.
int32
),
tl
.
load
(
chunk_indices
+
i_t
*
2
+
1
).
to
(
tl
.
int32
),
)
bos
,
eos
=
(
tl
.
load
(
cu_seqlens
+
i_n
).
to
(
tl
.
int32
),
tl
.
load
(
cu_seqlens
+
i_n
+
1
).
to
(
tl
.
int32
),
)
T
=
eos
-
bos
else
:
bos
,
eos
=
i_b
*
T
,
i_b
*
T
+
T
p_b
=
tl
.
make_block_ptr
(
beta
+
bos
*
H
+
i_h
,
(
T
,),
(
H
,),
(
i_t
*
BT
,),
(
BT
,),
(
0
,))
b_b
=
tl
.
load
(
p_b
,
boundary_check
=
(
0
,))
p_A
=
tl
.
make_block_ptr
(
A
+
(
bos
*
H
+
i_h
)
*
BT
,
(
T
,
BT
),
(
H
*
BT
,
1
),
(
i_t
*
BT
,
0
),
(
BT
,
BT
),
(
1
,
0
)
)
b_A
=
tl
.
load
(
p_A
,
boundary_check
=
(
0
,
1
))
for
i_v
in
range
(
tl
.
cdiv
(
V
,
BV
)):
p_v
=
tl
.
make_block_ptr
(
v
+
(
bos
*
H
+
i_h
)
*
V
,
(
T
,
V
),
(
H
*
V
,
1
),
(
i_t
*
BT
,
i_v
*
BV
),
(
BT
,
BV
),
(
1
,
0
),
)
p_u
=
tl
.
make_block_ptr
(
u
+
(
bos
*
H
+
i_h
)
*
V
,
(
T
,
V
),
(
H
*
V
,
1
),
(
i_t
*
BT
,
i_v
*
BV
),
(
BT
,
BV
),
(
1
,
0
),
)
b_v
=
tl
.
load
(
p_v
,
boundary_check
=
(
0
,
1
))
b_vb
=
(
b_v
*
b_b
[:,
None
]).
to
(
b_v
.
dtype
)
b_u
=
tl
.
dot
(
b_A
,
b_vb
,
input_precision
=
DOT_PRECISION
)
tl
.
store
(
p_u
,
b_u
.
to
(
p_u
.
dtype
.
element_ty
),
boundary_check
=
(
0
,
1
))
for
i_k
in
range
(
tl
.
cdiv
(
K
,
BK
)):
p_w
=
tl
.
make_block_ptr
(
w
+
(
bos
*
H
+
i_h
)
*
K
,
(
T
,
K
),
(
H
*
K
,
1
),
(
i_t
*
BT
,
i_k
*
BK
),
(
BT
,
BK
),
(
1
,
0
),
)
p_k
=
tl
.
make_block_ptr
(
k
+
(
bos
*
H
+
i_h
)
*
K
,
(
T
,
K
),
(
H
*
K
,
1
),
(
i_t
*
BT
,
i_k
*
BK
),
(
BT
,
BK
),
(
1
,
0
),
)
b_k
=
tl
.
load
(
p_k
,
boundary_check
=
(
0
,
1
))
b_kb
=
b_k
*
b_b
[:,
None
]
p_gk
=
tl
.
make_block_ptr
(
gk
+
(
bos
*
H
+
i_h
)
*
K
,
(
T
,
K
),
(
H
*
K
,
1
),
(
i_t
*
BT
,
i_k
*
BK
),
(
BT
,
BK
),
(
1
,
0
),
)
b_gk
=
tl
.
load
(
p_gk
,
boundary_check
=
(
0
,
1
))
b_kb
*=
exp
(
b_gk
)
if
STORE_QG
:
p_q
=
tl
.
make_block_ptr
(
q
+
(
bos
*
H
+
i_h
)
*
K
,
(
T
,
K
),
(
H
*
K
,
1
),
(
i_t
*
BT
,
i_k
*
BK
),
(
BT
,
BK
),
(
1
,
0
),
)
p_qg
=
tl
.
make_block_ptr
(
qg
+
(
bos
*
H
+
i_h
)
*
K
,
(
T
,
K
),
(
H
*
K
,
1
),
(
i_t
*
BT
,
i_k
*
BK
),
(
BT
,
BK
),
(
1
,
0
),
)
b_q
=
tl
.
load
(
p_q
,
boundary_check
=
(
0
,
1
))
b_qg
=
b_q
*
exp
(
b_gk
)
tl
.
store
(
p_qg
,
b_qg
.
to
(
p_qg
.
dtype
.
element_ty
),
boundary_check
=
(
0
,
1
))
if
STORE_KG
:
last_idx
=
min
(
i_t
*
BT
+
BT
,
T
)
-
1
o_k
=
i_k
*
BK
+
tl
.
arange
(
0
,
BK
)
m_k
=
o_k
<
K
b_gn
=
tl
.
load
(
gk
+
((
bos
+
last_idx
)
*
H
+
i_h
)
*
K
+
o_k
,
mask
=
m_k
,
other
=
0.0
)
b_kg
=
b_k
*
exp
(
b_gn
-
b_gk
)
p_kg
=
tl
.
make_block_ptr
(
kg
+
(
bos
*
H
+
i_h
)
*
K
,
(
T
,
K
),
(
H
*
K
,
1
),
(
i_t
*
BT
,
i_k
*
BK
),
(
BT
,
BK
),
(
1
,
0
),
)
tl
.
store
(
p_kg
,
b_kg
.
to
(
p_kg
.
dtype
.
element_ty
),
boundary_check
=
(
0
,
1
))
b_w
=
tl
.
dot
(
b_A
,
b_kb
.
to
(
b_k
.
dtype
))
tl
.
store
(
p_w
,
b_w
.
to
(
p_w
.
dtype
.
element_ty
),
boundary_check
=
(
0
,
1
))
def
recompute_w_u_fwd
(
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
beta
:
torch
.
Tensor
,
A
:
torch
.
Tensor
,
q
:
torch
.
Tensor
|
None
=
None
,
gk
:
torch
.
Tensor
|
None
=
None
,
cu_seqlens
:
torch
.
LongTensor
|
None
=
None
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
B
,
T
,
H
,
K
,
V
=
*
k
.
shape
,
v
.
shape
[
-
1
]
BT
=
A
.
shape
[
-
1
]
BK
=
64
BV
=
64
chunk_indices
=
(
prepare_chunk_indices
(
cu_seqlens
,
BT
)
if
cu_seqlens
is
not
None
else
None
)
NT
=
cdiv
(
T
,
BT
)
if
cu_seqlens
is
None
else
len
(
chunk_indices
)
w
=
torch
.
empty_like
(
k
)
u
=
torch
.
empty_like
(
v
)
kg
=
torch
.
empty_like
(
k
)
if
gk
is
not
None
else
None
recompute_w_u_fwd_kernel
[(
NT
,
B
*
H
)](
q
=
q
,
k
=
k
,
qg
=
None
,
kg
=
kg
,
v
=
v
,
beta
=
beta
,
w
=
w
,
u
=
u
,
A
=
A
,
gk
=
gk
,
cu_seqlens
=
cu_seqlens
,
chunk_indices
=
chunk_indices
,
T
=
T
,
H
=
H
,
K
=
K
,
V
=
V
,
BT
=
BT
,
BK
=
BK
,
BV
=
BV
,
DOT_PRECISION
=
"ieee"
,
)
return
w
,
u
,
None
,
kg
@
triton
.
heuristics
({
"IS_VARLEN"
:
lambda
args
:
args
[
"cu_seqlens"
]
is
not
None
})
@
triton
.
autotune
(
configs
=
[
triton
.
Config
({
"BK"
:
BK
,
"BV"
:
BV
},
num_warps
=
num_warps
,
num_stages
=
num_stages
)
for
BK
in
[
32
,
64
]
for
BV
in
[
64
,
128
]
for
num_warps
in
[
2
,
4
,
8
]
for
num_stages
in
[
2
,
3
,
4
]
],
key
=
[
"BT"
],
)
@
triton
.
jit
(
do_not_specialize
=
[
"T"
])
def
chunk_gla_fwd_kernel_o
(
q
,
v
,
g
,
h
,
o
,
A
,
cu_seqlens
,
chunk_indices
,
scale
,
T
,
H
:
tl
.
constexpr
,
K
:
tl
.
constexpr
,
V
:
tl
.
constexpr
,
BT
:
tl
.
constexpr
,
BK
:
tl
.
constexpr
,
BV
:
tl
.
constexpr
,
IS_VARLEN
:
tl
.
constexpr
,
):
i_v
,
i_t
,
i_bh
=
tl
.
program_id
(
0
),
tl
.
program_id
(
1
),
tl
.
program_id
(
2
)
i_b
,
i_h
=
i_bh
//
H
,
i_bh
%
H
if
IS_VARLEN
:
i_tg
=
i_t
i_n
,
i_t
=
(
tl
.
load
(
chunk_indices
+
i_t
*
2
).
to
(
tl
.
int32
),
tl
.
load
(
chunk_indices
+
i_t
*
2
+
1
).
to
(
tl
.
int32
),
)
bos
,
eos
=
(
tl
.
load
(
cu_seqlens
+
i_n
).
to
(
tl
.
int32
),
tl
.
load
(
cu_seqlens
+
i_n
+
1
).
to
(
tl
.
int32
),
)
T
=
eos
-
bos
NT
=
tl
.
cdiv
(
T
,
BT
)
else
:
NT
=
tl
.
cdiv
(
T
,
BT
)
i_tg
=
i_b
*
NT
+
i_t
bos
,
eos
=
i_b
*
T
,
i_b
*
T
+
T
m_s
=
tl
.
arange
(
0
,
BT
)[:,
None
]
>=
tl
.
arange
(
0
,
BT
)[
None
,
:]
b_o
=
tl
.
zeros
([
BT
,
BV
],
dtype
=
tl
.
float32
)
for
i_k
in
range
(
tl
.
cdiv
(
K
,
BK
)):
p_q
=
tl
.
make_block_ptr
(
q
+
(
bos
*
H
+
i_h
)
*
K
,
(
T
,
K
),
(
H
*
K
,
1
),
(
i_t
*
BT
,
i_k
*
BK
),
(
BT
,
BK
),
(
1
,
0
),
)
p_g
=
tl
.
make_block_ptr
(
g
+
(
bos
*
H
+
i_h
)
*
K
,
(
T
,
K
),
(
H
*
K
,
1
),
(
i_t
*
BT
,
i_k
*
BK
),
(
BT
,
BK
),
(
1
,
0
),
)
p_h
=
tl
.
make_block_ptr
(
h
+
(
i_tg
*
H
+
i_h
)
*
K
*
V
,
(
K
,
V
),
(
V
,
1
),
(
i_k
*
BK
,
i_v
*
BV
),
(
BK
,
BV
),
(
1
,
0
),
)
# [BT, BK]
b_q
=
tl
.
load
(
p_q
,
boundary_check
=
(
0
,
1
))
b_q
=
(
b_q
*
scale
).
to
(
b_q
.
dtype
)
# [BT, BK]
b_g
=
tl
.
load
(
p_g
,
boundary_check
=
(
0
,
1
))
# [BT, BK]
b_qg
=
(
b_q
*
exp
(
b_g
)).
to
(
b_q
.
dtype
)
# [BK, BV]
b_h
=
tl
.
load
(
p_h
,
boundary_check
=
(
0
,
1
))
# works but dkw, owing to divine benevolence
# [BT, BV]
if
i_k
>=
0
:
b_o
+=
tl
.
dot
(
b_qg
,
b_h
.
to
(
b_qg
.
dtype
))
p_v
=
tl
.
make_block_ptr
(
v
+
(
bos
*
H
+
i_h
)
*
V
,
(
T
,
V
),
(
H
*
V
,
1
),
(
i_t
*
BT
,
i_v
*
BV
),
(
BT
,
BV
),
(
1
,
0
),
)
p_o
=
tl
.
make_block_ptr
(
o
+
(
bos
*
H
+
i_h
)
*
V
,
(
T
,
V
),
(
H
*
V
,
1
),
(
i_t
*
BT
,
i_v
*
BV
),
(
BT
,
BV
),
(
1
,
0
),
)
p_A
=
tl
.
make_block_ptr
(
A
+
(
bos
*
H
+
i_h
)
*
BT
,
(
T
,
BT
),
(
H
*
BT
,
1
),
(
i_t
*
BT
,
0
),
(
BT
,
BT
),
(
1
,
0
)
)
# [BT, BV]
b_v
=
tl
.
load
(
p_v
,
boundary_check
=
(
0
,
1
))
# [BT, BT]
b_A
=
tl
.
load
(
p_A
,
boundary_check
=
(
0
,
1
))
b_A
=
tl
.
where
(
m_s
,
b_A
,
0.0
).
to
(
b_v
.
dtype
)
b_o
+=
tl
.
dot
(
b_A
,
b_v
,
allow_tf32
=
False
)
tl
.
store
(
p_o
,
b_o
.
to
(
p_o
.
dtype
.
element_ty
),
boundary_check
=
(
0
,
1
))
def
chunk_gla_fwd_o_gk
(
q
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
g
:
torch
.
Tensor
,
A
:
torch
.
Tensor
,
h
:
torch
.
Tensor
,
o
:
torch
.
Tensor
,
scale
:
float
,
cu_seqlens
:
torch
.
LongTensor
|
None
=
None
,
chunk_size
:
int
=
64
,
):
B
,
T
,
H
,
K
,
V
=
*
q
.
shape
,
v
.
shape
[
-
1
]
BT
=
chunk_size
chunk_indices
=
(
prepare_chunk_indices
(
cu_seqlens
,
chunk_size
)
if
cu_seqlens
is
not
None
else
None
)
NT
=
cdiv
(
T
,
BT
)
if
cu_seqlens
is
None
else
len
(
chunk_indices
)
def
grid
(
meta
):
return
(
cdiv
(
V
,
meta
[
"BV"
]),
NT
,
B
*
H
)
chunk_gla_fwd_kernel_o
[
grid
](
q
=
q
,
v
=
v
,
g
=
g
,
h
=
h
,
o
=
o
,
A
=
A
,
cu_seqlens
=
cu_seqlens
,
chunk_indices
=
chunk_indices
,
scale
=
scale
,
T
=
T
,
H
=
H
,
K
=
K
,
V
=
V
,
BT
=
BT
,
)
return
o
def
chunk_kda_fwd
(
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
g
:
torch
.
Tensor
,
beta
:
torch
.
Tensor
,
scale
:
float
,
initial_state
:
torch
.
Tensor
,
output_final_state
:
bool
,
cu_seqlens
:
torch
.
LongTensor
|
None
=
None
,
):
chunk_size
=
64
g
=
chunk_local_cumsum
(
g
,
chunk_size
=
chunk_size
,
cu_seqlens
=
cu_seqlens
)
# the intra Aqk is kept in fp32
# the computation has very marginal effect on the entire throughput
A
,
Aqk
=
chunk_kda_scaled_dot_kkt_fwd
(
q
=
q
,
k
=
k
,
gk
=
g
,
beta
=
beta
,
scale
=
scale
,
cu_seqlens
=
cu_seqlens
,
output_dtype
=
torch
.
float32
,
)
A
=
solve_tril
(
A
=
A
,
cu_seqlens
=
cu_seqlens
,
output_dtype
=
k
.
dtype
)
w
,
u
,
_
,
kg
=
recompute_w_u_fwd
(
k
=
k
,
v
=
v
,
beta
=
beta
,
A
=
A
,
gk
=
g
,
cu_seqlens
=
cu_seqlens
,
)
del
A
h
,
v_new
,
final_state
=
chunk_gated_delta_rule_fwd_h
(
k
=
kg
,
w
=
w
,
u
=
u
,
gk
=
g
,
initial_state
=
initial_state
,
output_final_state
=
output_final_state
,
cu_seqlens
=
cu_seqlens
,
)
del
w
,
u
,
kg
o
=
chunk_gla_fwd_o_gk
(
q
=
q
,
v
=
v_new
,
g
=
g
,
A
=
Aqk
,
h
=
h
,
o
=
v
,
scale
=
scale
,
cu_seqlens
=
cu_seqlens
,
chunk_size
=
chunk_size
,
)
del
Aqk
,
v_new
,
h
return
o
,
final_state
def
chunk_kda
(
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
g
:
torch
.
Tensor
,
beta
:
torch
.
Tensor
,
scale
:
float
=
None
,
initial_state
:
torch
.
Tensor
=
None
,
output_final_state
:
bool
=
False
,
use_qk_l2norm_in_kernel
:
bool
=
False
,
cu_seqlens
:
torch
.
LongTensor
|
None
=
None
,
**
kwargs
,
):
if
scale
is
None
:
scale
=
k
.
shape
[
-
1
]
**
-
0.5
if
use_qk_l2norm_in_kernel
:
q
=
l2norm_fwd
(
q
.
contiguous
())
k
=
l2norm_fwd
(
k
.
contiguous
())
o
,
final_state
=
chunk_kda_fwd
(
q
=
q
,
k
=
k
,
v
=
v
.
contiguous
(),
g
=
g
.
contiguous
(),
beta
=
beta
.
contiguous
(),
scale
=
scale
,
initial_state
=
initial_state
.
contiguous
(),
output_final_state
=
output_final_state
,
cu_seqlens
=
cu_seqlens
,
)
return
o
,
final_state
@
triton
.
autotune
(
configs
=
[
triton
.
Config
({
"BT"
:
bt
},
num_warps
=
nw
,
num_stages
=
ns
)
for
bt
in
BT_LIST_AUTOTUNE
for
nw
in
NUM_WARPS_AUTOTUNE
for
ns
in
[
2
,
3
]
],
key
=
[
"H"
,
"D"
],
)
@
triton
.
jit
def
kda_gate_fwd_kernel
(
g
,
A
,
y
,
g_bias
,
beta
:
tl
.
constexpr
,
threshold
:
tl
.
constexpr
,
T
,
H
,
D
:
tl
.
constexpr
,
BT
:
tl
.
constexpr
,
BD
:
tl
.
constexpr
,
HAS_BIAS
:
tl
.
constexpr
,
):
i_t
,
i_h
=
tl
.
program_id
(
0
),
tl
.
program_id
(
1
)
n_t
=
i_t
*
BT
b_a
=
tl
.
load
(
A
+
i_h
).
to
(
tl
.
float32
)
b_a
=
-
tl
.
exp
(
b_a
)
stride_row
=
H
*
D
stride_col
=
1
g_ptr
=
tl
.
make_block_ptr
(
base
=
g
+
i_h
*
D
,
shape
=
(
T
,
D
),
strides
=
(
stride_row
,
stride_col
),
offsets
=
(
n_t
,
0
),
block_shape
=
(
BT
,
BD
),
order
=
(
1
,
0
),
)
y_ptr
=
tl
.
make_block_ptr
(
base
=
y
+
i_h
*
D
,
shape
=
(
T
,
D
),
strides
=
(
stride_row
,
stride_col
),
offsets
=
(
n_t
,
0
),
block_shape
=
(
BT
,
BD
),
order
=
(
1
,
0
),
)
b_g
=
tl
.
load
(
g_ptr
,
boundary_check
=
(
0
,
1
)).
to
(
tl
.
float32
)
if
HAS_BIAS
:
n_d
=
tl
.
arange
(
0
,
BD
)
bias_mask
=
n_d
<
D
b_bias
=
tl
.
load
(
g_bias
+
i_h
*
D
+
n_d
,
mask
=
bias_mask
,
other
=
0.0
).
to
(
tl
.
float32
)
b_g
=
b_g
+
b_bias
[
None
,
:]
# softplus(x, beta) = (1/beta) * log(1 + exp(beta * x))
# When beta * x > threshold, use linear approximation x
# Use threshold to switch to linear when beta*x > threshold
g_scaled
=
b_g
*
beta
use_linear
=
g_scaled
>
threshold
sp
=
tl
.
where
(
use_linear
,
b_g
,
(
1.0
/
beta
)
*
log
(
1.0
+
tl
.
exp
(
g_scaled
)))
b_y
=
b_a
*
sp
tl
.
store
(
y_ptr
,
b_y
.
to
(
y
.
dtype
.
element_ty
),
boundary_check
=
(
0
,
1
))
def
kda_gate_fwd
(
g
:
torch
.
Tensor
,
A
:
torch
.
Tensor
,
head_k_dim
:
int
,
g_bias
:
torch
.
Tensor
|
None
=
None
,
beta
:
float
=
1.0
,
threshold
:
float
=
20.0
,
)
->
torch
.
Tensor
:
"""
Forward pass for KDA gate:
input g: [..., H*D]
param A: [H] or [1, 1, H, 1]
beta: softplus beta parameter
threshold: softplus threshold parameter
return : [..., H, D]
"""
orig_shape
=
g
.
shape
[:
-
1
]
g
=
g
.
view
(
-
1
,
g
.
shape
[
-
1
])
T
=
g
.
shape
[
0
]
HD
=
g
.
shape
[
1
]
H
=
A
.
numel
()
assert
H
*
head_k_dim
==
HD
y
=
torch
.
empty_like
(
g
,
dtype
=
torch
.
float32
)
def
grid
(
meta
):
return
(
cdiv
(
T
,
meta
[
"BT"
]),
H
)
kda_gate_fwd_kernel
[
grid
](
g
,
A
,
y
,
g_bias
,
beta
,
threshold
,
T
,
H
,
head_k_dim
,
BD
=
next_power_of_2
(
head_k_dim
),
HAS_BIAS
=
g_bias
is
not
None
,
)
y
=
y
.
view
(
*
orig_shape
,
H
,
head_k_dim
)
return
y
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment