Unverified Commit 0f980f15 authored by Jay Zhuang's avatar Jay Zhuang Committed by GitHub
Browse files

Bug fix for Gated Delta Net benchmark script (#1267)



* fix argument order for fla chunk_gated_delta_rule_fwd_h

* explicit import assert_similar from utils

* rename utils module to avoid name clash

* set store_final_state and save_new_value to True

* fix

---------
Co-authored-by: default avatarLeiWang1999 <leiwang1999@outlook.com>
parent 49c85715
......@@ -24,7 +24,7 @@ import torch.nn.functional as F
torch.random.manual_seed(0)
# torch.set_printoptions(profile="full")
from utils import *
from test_utils import assert_similar
def prepare_input(
......
......@@ -20,7 +20,7 @@ import torch
import torch.nn.functional as F
from tilelang.engine.callback import register_cuda_postproc_callback # noqa: F401
from utils import *
from test_utils import assert_similar
# (zhengju) We can slightly modify the generated cuda code from tilelang lowering
# in the debug folder to make the performance better. To enable this callback,
......@@ -292,9 +292,15 @@ def run_test(
getattr(torch, state_dtype))
# fla ref
h_ref, V_new_ref, final_state_ref = chunk_gated_delta_rule_fwd_h(K, W, U, G, initial_state,
store_final_state, chunk_size,
save_new_value)
h_ref, V_new_ref, final_state_ref = chunk_gated_delta_rule_fwd_h(
k=K,
w=W,
u=U,
g=G,
initial_state=initial_state,
output_final_state=store_final_state,
chunk_size=chunk_size,
save_new_value=save_new_value)
# tilelang
kernel = tilelang_chunk_gated_delta_rule_fwd_h(B, S, H, DK, DV, input_dtype, output_dtype,
......@@ -305,8 +311,16 @@ def run_test(
# (zhengju) If you want to print the generated cuda code, you can uncomment the following line
# print("CUDA Code:\n", kernel.get_kernel_source())
fla_time = do_bench(chunk_gated_delta_rule_fwd_h, K, W, U, G, initial_state, store_final_state,
chunk_size, save_new_value)
fla_time = do_bench(
chunk_gated_delta_rule_fwd_h,
k=K,
w=W,
u=U,
g=G,
initial_state=initial_state,
output_final_state=store_final_state,
chunk_size=chunk_size,
save_new_value=save_new_value)
tilelang_time = do_bench(kernel, K, W, U, G, initial_state)
# check correctness
......@@ -371,8 +385,8 @@ def main():
chunk_size=64,
use_g=True,
use_initial_state=False,
store_final_state=False,
save_new_value=False,
store_final_state=True,
save_new_value=True,
block_DK=32,
block_DV=32,
threads=128,
......
......@@ -19,7 +19,7 @@ except ImportError:
fla = None
import torch
from utils import *
from test_utils import assert_similar
torch.random.manual_seed(0)
# torch.set_printoptions(profile="full")
......
......@@ -501,7 +501,7 @@ def run_test(
dg_tilelang = dg_tilelang + dg_tilelang_A_positive.sum(dim=-1) - dg_tilelang_A_negative.sum(
dim=-1)
from utils import assert_similar
from test_utils import assert_similar
assert_similar(dk_ref, dk_tilelang, eps=1e-5, name="dk", raise_assert=False)
assert_similar(dv_ref, dv_tilelang, eps=1e-5, name="dv", raise_assert=False)
assert_similar(dbeta_ref, dbeta_tilelang, eps=1e-5, name="dbeta", raise_assert=False)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment