Unverified Commit 0f980f15 authored by Jay Zhuang's avatar Jay Zhuang Committed by GitHub
Browse files

Bug fix for Gated Delta Net benchmark script (#1267)



* fix argument order for fla chunk_gated_delta_rule_fwd_h

* explicit import assert_similar from utils

* rename utils module to avoid name clash

* set store_final_state and save_new_value to True

* fix

---------
Co-authored-by: default avatarLeiWang1999 <leiwang1999@outlook.com>
parent 49c85715
...@@ -24,7 +24,7 @@ import torch.nn.functional as F ...@@ -24,7 +24,7 @@ import torch.nn.functional as F
torch.random.manual_seed(0) torch.random.manual_seed(0)
# torch.set_printoptions(profile="full") # torch.set_printoptions(profile="full")
from utils import * from test_utils import assert_similar
def prepare_input( def prepare_input(
......
...@@ -20,7 +20,7 @@ import torch ...@@ -20,7 +20,7 @@ import torch
import torch.nn.functional as F import torch.nn.functional as F
from tilelang.engine.callback import register_cuda_postproc_callback # noqa: F401 from tilelang.engine.callback import register_cuda_postproc_callback # noqa: F401
from utils import * from test_utils import assert_similar
# (zhengju) We can slightly modify the generated cuda code from tilelang lowering # (zhengju) We can slightly modify the generated cuda code from tilelang lowering
# in the debug folder to make the performance better. To enable this callback, # in the debug folder to make the performance better. To enable this callback,
...@@ -292,9 +292,15 @@ def run_test( ...@@ -292,9 +292,15 @@ def run_test(
getattr(torch, state_dtype)) getattr(torch, state_dtype))
# fla ref # fla ref
h_ref, V_new_ref, final_state_ref = chunk_gated_delta_rule_fwd_h(K, W, U, G, initial_state, h_ref, V_new_ref, final_state_ref = chunk_gated_delta_rule_fwd_h(
store_final_state, chunk_size, k=K,
save_new_value) w=W,
u=U,
g=G,
initial_state=initial_state,
output_final_state=store_final_state,
chunk_size=chunk_size,
save_new_value=save_new_value)
# tilelang # tilelang
kernel = tilelang_chunk_gated_delta_rule_fwd_h(B, S, H, DK, DV, input_dtype, output_dtype, kernel = tilelang_chunk_gated_delta_rule_fwd_h(B, S, H, DK, DV, input_dtype, output_dtype,
...@@ -305,8 +311,16 @@ def run_test( ...@@ -305,8 +311,16 @@ def run_test(
# (zhengju) If you want to print the generated cuda code, you can uncomment the following line # (zhengju) If you want to print the generated cuda code, you can uncomment the following line
# print("CUDA Code:\n", kernel.get_kernel_source()) # print("CUDA Code:\n", kernel.get_kernel_source())
fla_time = do_bench(chunk_gated_delta_rule_fwd_h, K, W, U, G, initial_state, store_final_state, fla_time = do_bench(
chunk_size, save_new_value) chunk_gated_delta_rule_fwd_h,
k=K,
w=W,
u=U,
g=G,
initial_state=initial_state,
output_final_state=store_final_state,
chunk_size=chunk_size,
save_new_value=save_new_value)
tilelang_time = do_bench(kernel, K, W, U, G, initial_state) tilelang_time = do_bench(kernel, K, W, U, G, initial_state)
# check correctness # check correctness
...@@ -371,8 +385,8 @@ def main(): ...@@ -371,8 +385,8 @@ def main():
chunk_size=64, chunk_size=64,
use_g=True, use_g=True,
use_initial_state=False, use_initial_state=False,
store_final_state=False, store_final_state=True,
save_new_value=False, save_new_value=True,
block_DK=32, block_DK=32,
block_DV=32, block_DV=32,
threads=128, threads=128,
......
...@@ -19,7 +19,7 @@ except ImportError: ...@@ -19,7 +19,7 @@ except ImportError:
fla = None fla = None
import torch import torch
from utils import * from test_utils import assert_similar
torch.random.manual_seed(0) torch.random.manual_seed(0)
# torch.set_printoptions(profile="full") # torch.set_printoptions(profile="full")
......
...@@ -501,7 +501,7 @@ def run_test( ...@@ -501,7 +501,7 @@ def run_test(
dg_tilelang = dg_tilelang + dg_tilelang_A_positive.sum(dim=-1) - dg_tilelang_A_negative.sum( dg_tilelang = dg_tilelang + dg_tilelang_A_positive.sum(dim=-1) - dg_tilelang_A_negative.sum(
dim=-1) dim=-1)
from utils import assert_similar from test_utils import assert_similar
assert_similar(dk_ref, dk_tilelang, eps=1e-5, name="dk", raise_assert=False) assert_similar(dk_ref, dk_tilelang, eps=1e-5, name="dk", raise_assert=False)
assert_similar(dv_ref, dv_tilelang, eps=1e-5, name="dv", raise_assert=False) assert_similar(dv_ref, dv_tilelang, eps=1e-5, name="dv", raise_assert=False)
assert_similar(dbeta_ref, dbeta_tilelang, eps=1e-5, name="dbeta", raise_assert=False) assert_similar(dbeta_ref, dbeta_tilelang, eps=1e-5, name="dbeta", raise_assert=False)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment