Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
tilelang
Commits
eac96cd7
Unverified
Commit
eac96cd7
authored
Nov 14, 2025
by
Zhengju Tang
Committed by
GitHub
Nov 14, 2025
Browse files
[BugFix] Add autotune and exp2 for GDN kernel (#1258)
* [BugFix] Add autotune and exp2 for GDN kernel * [Lint] * [Lint]
parent
5eb30a4f
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
36 additions
and
18 deletions
+36
-18
examples/gdn/example_chunk_delta_h.py
examples/gdn/example_chunk_delta_h.py
+36
-18
No files found.
examples/gdn/example_chunk_delta_h.py
View file @
eac96cd7
...
...
@@ -3,6 +3,7 @@
import
sys
# noqa: F401
import
tilelang
import
tilelang.language
as
T
from
tilelang.autotuner
import
autotune
# Add your fla repository path to sys.path
# Currently we use the fla repository from the flash-linear-attention project at commit id f03cb3ae
...
...
@@ -80,7 +81,25 @@ def prepare_output(
return
h
,
final_state
,
V_new
@
tilelang
.
jit
(
out_idx
=
[
-
3
,
-
2
,
-
1
])
def
get_configs
():
import
itertools
block_DK
=
[
32
,
64
,
128
]
block_DV
=
[
32
,
64
,
128
]
threads
=
[
128
,
256
]
num_stages
=
[
1
,
2
,
3
]
_configs
=
list
(
itertools
.
product
(
block_DK
,
block_DV
,
threads
,
num_stages
))
configs
=
[{
'block_DK'
:
c
[
0
],
'block_DV'
:
c
[
1
],
'threads'
:
c
[
2
],
'num_stages'
:
c
[
3
]
}
for
c
in
_configs
]
return
configs
@
autotune
(
configs
=
get_configs
(),
warmup
=
3
,
rep
=
5
)
@
tilelang
.
jit
(
out_idx
=
[
-
3
,
-
2
,
-
1
],
pass_configs
=
{
tilelang
.
PassConfigKey
.
TL_ENABLE_FAST_MATH
:
True
})
def
tilelang_chunk_gated_delta_rule_fwd_h
(
# task config
B
,
...
...
@@ -94,15 +113,15 @@ def tilelang_chunk_gated_delta_rule_fwd_h(
gate_dtype
,
state_dtype
,
chunk_size
,
use_g
=
True
,
use_initial_state
=
True
,
store_final_state
=
True
,
save_new_value
=
True
,
use_g
,
use_initial_state
,
store_final_state
,
save_new_value
,
# kernel config
block_DK
=
64
,
block_DV
=
64
,
threads
=
256
,
num_stages
=
0
,
block_DV
=
32
,
threads
=
128
,
num_stages
=
1
,
):
block_S
=
chunk_size
BS
=
S
//
block_S
...
...
@@ -193,11 +212,11 @@ def tilelang_chunk_gated_delta_rule_fwd_h(
for
i_s2
,
i_v
in
T
.
Parallel
(
block_S
,
block_DV
):
with
T
.
If
(
G_last_local
[
0
]
-
G_fragment
[
i_s2
,
i_v
]
<=
0
):
with
T
.
Then
():
V_new_fragment
[
i_s2
,
i_v
]
=
V_new_fragment
[
i_s2
,
i_v
]
*
T
.
exp
(
G_last_local
[
0
]
-
G_fragment
[
i_s2
,
i_v
])
V_new_fragment
[
i_s2
,
i_v
]
=
V_new_fragment
[
i_s2
,
i_v
]
*
T
.
exp
2
(
(
G_last_local
[
0
]
-
G_fragment
[
i_s2
,
i_v
])
*
1.442695
)
with
T
.
Else
():
V_new_fragment
[
i_s2
,
i_v
]
=
0
G_last_local
[
0
]
=
T
.
exp
(
G_last_local
[
0
])
G_last_local
[
0
]
=
T
.
exp
2
(
G_last_local
[
0
]
*
1.442695
)
for
i_k
,
i_v
in
T
.
Parallel
(
DK
,
block_DV
):
b_h_fragment
[
i_k
,
i_v
]
*=
G_last_local
[
0
]
...
...
@@ -281,8 +300,7 @@ def run_test(
kernel
=
tilelang_chunk_gated_delta_rule_fwd_h
(
B
,
S
,
H
,
DK
,
DV
,
input_dtype
,
output_dtype
,
accum_dtype
,
gate_dtype
,
state_dtype
,
chunk_size
,
use_g
,
use_initial_state
,
store_final_state
,
save_new_value
,
block_DK
,
block_DV
,
threads
,
num_stages
)
save_new_value
)
h_tilelang
,
final_state_tilelang
,
V_new_tilelang
=
kernel
(
K
,
W
,
U
,
G
,
initial_state
)
# (zhengju) If you want to print the generated cuda code, you can uncomment the following line
# print("CUDA Code:\n", kernel.get_kernel_source())
...
...
@@ -352,13 +370,13 @@ def main():
state_dtype
=
"float32"
,
chunk_size
=
64
,
use_g
=
True
,
use_initial_state
=
Tru
e
,
store_final_state
=
Tru
e
,
save_new_value
=
Tru
e
,
block_DK
=
64
,
use_initial_state
=
Fals
e
,
store_final_state
=
Fals
e
,
save_new_value
=
Fals
e
,
block_DK
=
32
,
block_DV
=
32
,
threads
=
128
,
num_stages
=
1
,
num_stages
=
2
,
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment