"...composable_kernel_onnxruntime.git" did not exist on "4b616aad52807740908071e90e06e184d3177357"
Unverified commit eac96cd7, authored by Zhengju Tang and committed by GitHub
Browse files

[BugFix] Add autotune and exp2 for GDN kernel (#1258)

* [BugFix] Add autotune and exp2 for GDN kernel

* [Lint]

* [Lint]
parent 5eb30a4f
......@@ -3,6 +3,7 @@
import sys # noqa: F401
import tilelang
import tilelang.language as T
from tilelang.autotuner import autotune
# Add your fla repository path to sys.path
# Currently we use the fla repository from the flash-linear-attention project at commit id f03cb3ae
......@@ -80,7 +81,25 @@ def prepare_output(
return h, final_state, V_new
@tilelang.jit(out_idx=[-3, -2, -1])
def get_configs():
    """Build the list of candidate kernel configurations for autotuning.

    Returns:
        list[dict]: One dict per combination of the candidate values, each
        with the keys 'block_DK', 'block_DV', 'threads' and 'num_stages'.
        Combinations are ordered as ``itertools.product`` yields them, so
        'num_stages' varies fastest and 'block_DK' slowest.
    """
    import itertools

    # Candidate values per tunable parameter. The dict's insertion order
    # fixes both the key order of every config and the nesting order of
    # the Cartesian product below.
    search_space = {
        'block_DK': [32, 64, 128],
        'block_DV': [32, 64, 128],
        'threads': [128, 256],
        'num_stages': [1, 2, 3],
    }
    # Pair each combination with its parameter names directly, instead of
    # materializing the product into a temporary list and indexing into
    # each tuple by position.
    return [
        dict(zip(search_space, combo))
        for combo in itertools.product(*search_space.values())
    ]
@autotune(configs=get_configs(), warmup=3, rep=5)
@tilelang.jit(out_idx=[-3, -2, -1], pass_configs={tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True})
def tilelang_chunk_gated_delta_rule_fwd_h(
# task config
B,
......@@ -94,15 +113,15 @@ def tilelang_chunk_gated_delta_rule_fwd_h(
gate_dtype,
state_dtype,
chunk_size,
use_g=True,
use_initial_state=True,
store_final_state=True,
save_new_value=True,
use_g,
use_initial_state,
store_final_state,
save_new_value,
# kernel config
block_DK=64,
block_DV=64,
threads=256,
num_stages=0,
block_DV=32,
threads=128,
num_stages=1,
):
block_S = chunk_size
BS = S // block_S
......@@ -193,11 +212,11 @@ def tilelang_chunk_gated_delta_rule_fwd_h(
for i_s2, i_v in T.Parallel(block_S, block_DV):
with T.If(G_last_local[0] - G_fragment[i_s2, i_v] <= 0):
with T.Then():
V_new_fragment[i_s2, i_v] = V_new_fragment[i_s2, i_v] * T.exp(
G_last_local[0] - G_fragment[i_s2, i_v])
V_new_fragment[i_s2, i_v] = V_new_fragment[i_s2, i_v] * T.exp2(
(G_last_local[0] - G_fragment[i_s2, i_v]) * 1.442695)
with T.Else():
V_new_fragment[i_s2, i_v] = 0
G_last_local[0] = T.exp(G_last_local[0])
G_last_local[0] = T.exp2(G_last_local[0] * 1.442695)
for i_k, i_v in T.Parallel(DK, block_DV):
b_h_fragment[i_k, i_v] *= G_last_local[0]
......@@ -281,8 +300,7 @@ def run_test(
kernel = tilelang_chunk_gated_delta_rule_fwd_h(B, S, H, DK, DV, input_dtype, output_dtype,
accum_dtype, gate_dtype, state_dtype, chunk_size,
use_g, use_initial_state, store_final_state,
save_new_value, block_DK, block_DV, threads,
num_stages)
save_new_value)
h_tilelang, final_state_tilelang, V_new_tilelang = kernel(K, W, U, G, initial_state)
# (zhengju) If you want to print the generated cuda code, you can uncomment the following line
# print("CUDA Code:\n", kernel.get_kernel_source())
......@@ -352,13 +370,13 @@ def main():
state_dtype="float32",
chunk_size=64,
use_g=True,
use_initial_state=True,
store_final_state=True,
save_new_value=True,
block_DK=64,
use_initial_state=False,
store_final_state=False,
save_new_value=False,
block_DK=32,
block_DV=32,
threads=128,
num_stages=1,
num_stages=2,
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment