Unverified Commit 6f59668d authored by Zhengju Tang, committed by GitHub

Gated Delta Net(GDN) kernel implementation in TileLang (#695)

* [GDN] Add examples for GDN forward and backward kernels

* [Refactor] Folder structure refactor for duplicated utils

* [Test] Add test script for kernels

* [Refactor] Rename examples to align with the repo

* [Lint] Modify README

* [Update] Modified README to align upstream repo

* [BugFix] Path of FLA

* [Fix] Copyright and test

* [Lint]

* [CI] Add GDN compilation test CI

* [Lint]

* [BugFix] Import error of fla
parent 36b57617
# Gated Delta Net (GDN) kernel implementation in TileLang
## Requirements
- The TileLang version used for testing is 0.1.5+17fafc1b3026d910a83eb8052fdf811ba56be0b1.
- We currently use triton 3.3.0 and FLA at commit f03cb3ae for comparison.
## Get started
`common/chunk_delta_h.py` implements the most critical forward kernel of GDN. It is a good starting point for understanding both the GDN logic and the TileLang optimizations; a minimal usage sketch is given below.
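The snippet below is a minimal, illustrative sketch (not part of the original examples) showing how one of the example kernels can be built and run. It assumes the module name `example_chunk_delta_h` used by the test script and the default kernel configuration; adjust the import and shapes to your setup.

```python
import torch
from example_chunk_delta_h import tilelang_chunk_gated_delta_rule_fwd_h, prepare_input

B, S, H, DK, DV, chunk_size = 1, 4096, 32, 128, 128, 64
# Random inputs (K, W, U, gates G and the initial state) on the GPU.
K, W, U, G, h0 = prepare_input(B, S, H, DK, DV, chunk_size,
                               torch.bfloat16, torch.bfloat16, torch.float32, torch.float32)
# JIT-compile the forward hidden-state kernel, then launch it.
kernel = tilelang_chunk_gated_delta_rule_fwd_h(B, S, H, DK, DV, "bfloat16", "bfloat16",
                                               "float32", "float32", "float32", chunk_size)
h, final_state, V_new = kernel(K, W, U, G, h0)
```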
# Reference: fla/ops/common/chunk_delta_h.py
import sys # noqa: F401
import tilelang
import tilelang.language as T
# Add your fla repository path to sys.path
# Currently we use the fla repository from the flash-linear-attention project at commit id f03cb3ae
# sys.path.insert(0, "/home/tzj/flash-linear-attention")
try:
import fla
print(fla.__file__)
from fla.ops.common.chunk_delta_h import chunk_gated_delta_rule_fwd_h
except ImportError:
print("fla not found, using tilelang implementation")
fla = None
import torch
import torch.nn.functional as F
from tilelang.engine.callback import register_cuda_postproc_callback # noqa: F401
from utils import *
# (zhengju) We can slightly modify the generated cuda code from tilelang lowering
# in the debug folder to make the performance better. To enable this callback,
# uncomment the following function.
# @register_cuda_postproc_callback
# def tilelang_callback_cuda_postproc(code, _):
# cuda_code = open("../debug/chunk_delta_h_fuse.cu", "r").read()
# code = cuda_code
# return code
torch.random.manual_seed(0)
tilelang.disable_cache()
def prepare_input(
B,
S,
H,
DK,
DV,
chunk_size,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
):
K = torch.randn(B, S, H, DK, dtype=input_dtype).cuda()
K = F.normalize(K, dim=-1, p=2)
W = torch.randn(B, S, H, DK, dtype=input_dtype).cuda()
W = F.normalize(W, dim=-1, p=2)
U = torch.randn(B, S, H, DV, dtype=input_dtype).cuda()
U = F.normalize(U, dim=-1, p=2)
G = torch.randn(B, S, H, dtype=gate_dtype).cuda()
G = F.logsigmoid(G)
try:
from fla.ops.utils.cumsum import chunk_local_cumsum
G = chunk_local_cumsum(G, chunk_size)
except ImportError:
print("fla not found, skip cumsum")
initial_state = torch.randn(B, H, DK, DV, dtype=input_dtype).cuda()
return K, W, U, G, initial_state
def prepare_output(
B,
S,
H,
DK,
DV,
chunk_size,
output_dtype,
state_dtype,
):
BS = S // chunk_size
h = torch.empty(B, BS, H, DK, DV, dtype=output_dtype).cuda()
final_state = torch.empty(B, H, DK, DV, dtype=state_dtype).cuda()
V_new = torch.empty(B, S, H, DV, dtype=output_dtype).cuda()
return h, final_state, V_new
@tilelang.jit(out_idx=[-3, -2, -1])
def tilelang_chunk_gated_delta_rule_fwd_h(
# task config
B,
S,
H,
DK,
DV,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
state_dtype,
chunk_size,
use_g=True,
use_initial_state=True,
store_final_state=True,
save_new_value=True,
# kernel config
block_DK=64,
block_DV=64,
threads=256,
num_stages=0,
):
block_S = chunk_size
BS = S // block_S
K_shape = (B, S, H, DK)
V_shape = (B, S, H, DV)
W_shape = (B, S, H, DK)
U_shape = (B, S, H, DV)
G_shape = (B, S, H)
h_shape = (B, BS, H, DK, DV)
initial_state_shape = (B, H, DK, DV)
final_state_shape = (B, H, DK, DV)
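# (Added note, illustrative) Per-chunk recurrence computed by the kernel below, with
# h the running (DK, block_DV) state slice and g the in-chunk cumulative gate:
#   V_new_t = U_t - W_t @ h_t
#   h_{t+1} = exp(g_last) * h_t + K_t^T @ (V_new_t * exp(g_last - g))
# The pre-update state h_t is written to `h`; the last state is stored in `final_state`.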
@T.prim_func
def kernel(
K: T.Tensor(K_shape, dtype=input_dtype),
W: T.Tensor(W_shape, dtype=input_dtype),
U: T.Tensor(U_shape, dtype=input_dtype),
G: T.Tensor(G_shape, dtype=gate_dtype),
initial_state: T.Tensor(initial_state_shape, dtype=input_dtype),
h: T.Tensor(h_shape, dtype=output_dtype),
final_state: T.Tensor(final_state_shape, dtype=state_dtype),
V_new: T.Tensor(V_shape, dtype=output_dtype),
):
with T.Kernel(T.ceildiv(DV, block_DV), B * H, threads=threads) as (bv, bbh):
bb, bh = bbh // H, bbh % H
b_h_shared = T.alloc_shared((DK, block_DV), dtype=input_dtype)
b_h_fragment = T.alloc_fragment((DK, block_DV), dtype=accum_dtype)
U_shared = T.alloc_shared((block_S, block_DV), dtype=input_dtype)
U_fragment = T.alloc_fragment((block_S, block_DV), dtype=accum_dtype)
W_shared = T.alloc_shared((block_S, DK), dtype=input_dtype)
V_new_fragment = T.alloc_fragment((block_S, block_DV), dtype=accum_dtype)
V_new_shared = T.alloc_shared((block_S, block_DV), dtype=output_dtype)
K_shared = T.alloc_shared((block_S, DK), dtype=input_dtype)
G_last_local = T.alloc_local((1), dtype=gate_dtype)
G_shared = T.alloc_shared((block_S, block_DV), dtype=gate_dtype)
G_fragment = T.alloc_fragment((block_S, block_DV), dtype=gate_dtype)
T.annotate_layout({
b_h_shared: tilelang.layout.make_swizzled_layout(b_h_shared),
U_shared: tilelang.layout.make_swizzled_layout(U_shared),
W_shared: tilelang.layout.make_swizzled_layout(W_shared),
V_new_shared: tilelang.layout.make_swizzled_layout(V_new_shared),
K_shared: tilelang.layout.make_swizzled_layout(K_shared),
G_shared: tilelang.layout.make_swizzled_layout(G_shared),
})
T.use_swizzle(10)
if use_initial_state:
T.copy(initial_state[bb, bh, 0:DK, bv * block_DV:(bv + 1) * block_DV], b_h_shared)
T.copy(b_h_shared, b_h_fragment)
else:
T.clear(b_h_fragment)
for i_s in T.Pipelined(T.ceildiv(S, block_S), num_stages=num_stages):
# Store the state produced by the previous chunks into the hidden tensor (epilogue of the previous iteration)
T.copy(b_h_shared, h[bb, i_s, bh, 0:DK, bv * block_DV:(bv + 1) * block_DV])
# Recurrence
T.copy(W[bb, i_s * block_S:(i_s + 1) * block_S, bh, 0:DK], W_shared)
T.gemm(W_shared, b_h_shared, V_new_fragment, clear_accum=True)
# V_new = U - W @ h (delta-rule correction)
T.copy(
U[bb, i_s * block_S:(i_s + 1) * block_S, bh, bv * block_DV:(bv + 1) * block_DV],
U_shared)
T.copy(U_shared, U_fragment)
for i_s2, i_v in T.Parallel(block_S, block_DV):
V_new_fragment[i_s2, i_v] = -V_new_fragment[i_s2, i_v] + U_fragment[i_s2, i_v]
# Save V_new
if save_new_value:
T.copy(V_new_fragment, dst=V_new_shared)
T.copy(
V_new_shared, V_new[bb, i_s * block_S:(i_s + 1) * block_S, bh,
bv * block_DV:(bv + 1) * block_DV])
T.copy(K[bb, i_s * block_S:(i_s + 1) * block_S, bh, 0:DK], K_shared)
# use_g
if use_g:
G_last_local[0] = G[bb, (i_s + 1) * block_S - 1, bh]
for i_s2, i_v in T.Parallel(block_S, block_DV):
G_shared[i_s2, i_v] = G[bb, i_s * block_S + i_s2, bh]
T.copy(G_shared, G_fragment)
for i_s2, i_v in T.Parallel(block_S, block_DV):
with T.If(G_last_local[0] - G_fragment[i_s2, i_v] <= 0):
with T.Then():
V_new_fragment[i_s2, i_v] = V_new_fragment[i_s2, i_v] * T.exp(
G_last_local[0] - G_fragment[i_s2, i_v])
with T.Else():
V_new_fragment[i_s2, i_v] = 0
G_last_local[0] = T.exp(G_last_local[0])
for i_k, i_v in T.Parallel(DK, block_DV):
b_h_fragment[i_k, i_v] *= G_last_local[0]
# Update intermediate results
T.copy(V_new_fragment, V_new_shared)
T.gemm(K_shared, V_new_shared, b_h_fragment, transpose_A=True)
T.copy(b_h_fragment, b_h_shared)
# Save final state
if store_final_state:
T.copy(b_h_fragment, final_state[bb, bh, 0:DK, bv * block_DV:(bv + 1) * block_DV])
return kernel
def do_bench(fn, *args, warmup=10, rep=10, **kwargs):
"""
Do benchmark for a function.
"""
start_event = [torch.cuda.Event(enable_timing=True) for i in range(rep)]
end_event = [torch.cuda.Event(enable_timing=True) for i in range(rep)]
for _ in range(warmup):
fn(*args, **kwargs)
torch.cuda.synchronize()
for i in range(rep):
start_event[i].record()
fn(*args, **kwargs)
end_event[i].record()
torch.cuda.synchronize()
# Collect elapsed times (ms) from the recorded CUDA events
times = torch.tensor(
[s.elapsed_time(e) for s, e in zip(start_event, end_event)],
dtype=torch.float,
)
return times.mean().item()
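# Example usage (illustrative, not part of the reference script):
#   x = torch.randn(1024, 1024, device="cuda")
#   avg_ms = do_bench(torch.matmul, x, x)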
def run_test(
B,
S,
H,
DK,
DV,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
state_dtype,
chunk_size,
use_g=True,
use_initial_state=True,
store_final_state=True,
save_new_value=True,
block_DK=64,
block_DV=32,
threads=128,
num_stages=0,
):
K, W, U, G, initial_state = prepare_input(B, S, H, DK, DV, chunk_size,
getattr(torch, input_dtype),
getattr(torch, output_dtype),
getattr(torch, accum_dtype),
getattr(torch, gate_dtype))
h_ref, final_state_ref, V_new_ref = prepare_output(B, S, H, DK, DV, chunk_size,
getattr(torch, output_dtype),
getattr(torch, state_dtype))
h_tilelang, final_state_tilelang, V_new_tilelang = prepare_output(B, S, H, DK, DV, chunk_size,
getattr(torch, output_dtype),
getattr(torch, state_dtype))
# fla ref
h_ref, V_new_ref, final_state_ref = chunk_gated_delta_rule_fwd_h(K, W, U, G, initial_state,
store_final_state, chunk_size,
save_new_value)
# tilelang
kernel = tilelang_chunk_gated_delta_rule_fwd_h(B, S, H, DK, DV, input_dtype, output_dtype,
accum_dtype, gate_dtype, state_dtype, chunk_size,
use_g, use_initial_state, store_final_state,
save_new_value, block_DK, block_DV, threads,
num_stages)
h_tilelang, final_state_tilelang, V_new_tilelang = kernel(K, W, U, G, initial_state)
# (zhengju) If you want to print the generated cuda code, you can uncomment the following line
# print("CUDA Code:\n", kernel.get_kernel_source())
fla_time = do_bench(chunk_gated_delta_rule_fwd_h, K, W, U, G, initial_state, store_final_state,
chunk_size, save_new_value)
tilelang_time = do_bench(kernel, K, W, U, G, initial_state)
# check correctness
try:
h_ref_fp32 = h_ref.to(torch.float32)
h_tilelang_fp32 = h_tilelang.to(torch.float32)
assert_similar(
h_ref_fp32,
h_tilelang_fp32,
eps=1e-5,
name="tilelang chunk gated delta rule fwd h",
raise_assert=False)
print("tilelang chunk gated delta rule fwd h passed √")
except Exception as e:
print("tilelang chunk gated delta rule fwd h failed ✗")
print(e)
try:
final_state_ref_fp32 = final_state_ref.to(torch.float32)
final_state_tilelang_fp32 = final_state_tilelang.to(torch.float32)
assert_similar(
final_state_ref_fp32,
final_state_tilelang_fp32,
eps=1e-5,
name="tilelang chunk gated delta rule fwd final_state",
raise_assert=False)
print("tilelang chunk gated delta rule fwd final_state passed √")
except Exception as e:
print("tilelang chunk gated delta rule fwd final_state failed ✗")
print(e)
try:
V_new_ref_fp32 = V_new_ref.to(torch.float32)
V_new_tilelang_fp32 = V_new_tilelang.to(torch.float32)
assert_similar(
V_new_ref_fp32,
V_new_tilelang_fp32,
eps=1e-5,
name="tilelang chunk gated delta rule fwd V_new",
raise_assert=False)
print("tilelang chunk gated delta rule fwd V_new passed √")
except Exception as e:
print("tilelang chunk gated delta rule fwd V_new failed ✗")
print(e)
print(f"tilelang time: {tilelang_time} ms")
print(f"fla time: {fla_time} ms")
def main():
run_test(
B=1,
S=32768,
H=32,
DK=128,
DV=128,
input_dtype="bfloat16",
output_dtype="bfloat16",
accum_dtype="float32",
gate_dtype="float32",
state_dtype="float32",
chunk_size=64,
use_g=True,
use_initial_state=True,
store_final_state=True,
save_new_value=True,
block_DK=64,
block_DV=32,
threads=128,
num_stages=1,
)
if __name__ == "__main__":
main()
# Reference: fla/ops/common/chunk_o.py
import tilelang
import tilelang.language as T
import sys # noqa: F401
# Add your fla repository path to sys.path
# Currently we use the fla repository from the flash-linear-attention project at commit id f03cb3ae
# sys.path.insert(0, "/home/tzj/flash-linear-attention")
try:
import fla
print(fla.__file__)
from fla.ops.common.chunk_o import chunk_fwd_o
except ImportError:
print("fla not found, using tilelang implementation")
fla = None
import torch
torch.random.manual_seed(1)
tilelang.disable_cache()
def prepare_input(
B,
S,
H,
DK,
DV,
chunk_size,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
):
BS = chunk_size
Q = torch.randn(B, S, H, DK, dtype=input_dtype).cuda()
K = torch.randn(B, S, H, DK, dtype=input_dtype).cuda()
V = torch.randn(B, S, H, DV, dtype=input_dtype).cuda()
HIDDEN = torch.randn(B, S // BS, H, DK, DV, dtype=input_dtype).cuda()
G = torch.randn(B, S, H, dtype=gate_dtype).cuda()
return Q, K, V, HIDDEN, G
def prepare_output(
B,
S,
H,
DK,
DV,
chunk_size,
output_dtype,
):
O = torch.empty(B, S, H, DV, dtype=output_dtype).cuda()
return O
@tilelang.jit(out_idx=[-1])
def tilelang_chunk_fwd_o(
# task config
B,
S,
H,
DK,
DV,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
chunk_size,
scale,
use_g,
# kernel config
block_S=64,
block_DK=64,
block_DV=64,
threads=256,
num_stages=0,
):
assert chunk_size == block_S, "chunk_size must be equal to block_S"
BS = chunk_size
Q_shape = (B, S, H, DK)
K_shape = (B, S, H, DK)
V_shape = (B, S, H, DV)
H_shape = (B, S // BS, H, DK, DV)
G_shape = (B, S, H)
O_shape = (B, S, H, DV)
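# (Added note, illustrative) For each chunk, the kernel below computes (with use_g enabled)
#   O = scale * (exp(g) * (Q @ H_chunk) + tril((Q @ K^T) * exp(g_i - g_j)) @ V)
# where H_chunk is the inter-chunk hidden state and the lower-triangular mask enforces
# causality inside the chunk.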
@T.prim_func
def kernel(
Q: T.Tensor(Q_shape, dtype=input_dtype),
K: T.Tensor(K_shape, dtype=input_dtype),
V: T.Tensor(V_shape, dtype=input_dtype),
HIDDEN: T.Tensor(H_shape, dtype=input_dtype),
G: T.Tensor(G_shape, dtype=gate_dtype),
O: T.Tensor(O_shape, dtype=output_dtype),
):
with T.Kernel(
T.ceildiv(DV, block_DV), T.ceildiv(S, block_S), B * H,
threads=threads) as (bv, bs, bbh):
bb, bh = bbh // H, bbh % H
Q_shared = T.alloc_shared((block_S, block_DK), dtype=input_dtype)
K_shared = T.alloc_shared((block_S, block_DK), dtype=input_dtype)
V_shared = T.alloc_shared((block_S, block_DV), dtype=input_dtype)
H_shared = T.alloc_shared((block_DK, block_DV), dtype=input_dtype)
A_shared = T.alloc_shared((block_S, block_S), dtype=input_dtype)
O_shared = T.alloc_shared((block_S, block_DV), dtype=output_dtype)
A_fragment = T.alloc_fragment((block_S, block_S), dtype=accum_dtype)
O_fragment = T.alloc_fragment((block_S, block_DV), dtype=accum_dtype)
G_shared = T.alloc_shared((block_S,), dtype=gate_dtype, scope="shared")
G_diff_local = T.alloc_fragment((block_S, block_S), dtype=gate_dtype)
T.annotate_layout({
Q_shared: tilelang.layout.make_swizzled_layout(Q_shared),
K_shared: tilelang.layout.make_swizzled_layout(K_shared),
V_shared: tilelang.layout.make_swizzled_layout(V_shared),
H_shared: tilelang.layout.make_swizzled_layout(H_shared),
A_shared: tilelang.layout.make_swizzled_layout(A_shared),
O_shared: tilelang.layout.make_swizzled_layout(O_shared),
})
T.clear(A_fragment)
T.clear(O_fragment)
T.no_set_max_nreg()
for i_k in T.Pipelined(T.ceildiv(DK, block_DK), num_stages=num_stages):
T.copy(
Q[bb, bs * block_S:(bs + 1) * block_S, bh, i_k * block_DK:(i_k + 1) * block_DK],
Q_shared)
T.copy(
K[bb, bs * block_S:(bs + 1) * block_S, bh, i_k * block_DK:(i_k + 1) * block_DK],
K_shared)
T.copy(
HIDDEN[bb, bs, bh, i_k * block_DK:(i_k + 1) * block_DK,
bv * block_DV:(bv + 1) * block_DV], H_shared)
T.gemm(Q_shared, H_shared, O_fragment)
T.gemm(Q_shared, K_shared, A_fragment, transpose_B=True)
if use_g:
for i_s in T.Parallel(block_S):
G_shared[i_s] = G[bb, bs * block_S + i_s, bh]
# T.copy(G[bb, bs * block_S:(bs + 1) * block_S, bh], G_shared)
for i_s, i_v in T.Parallel(block_S, block_DV):
O_fragment[i_s, i_v] = O_fragment[i_s, i_v] * T.exp(G_shared[i_s])
for i_s1, i_s2 in T.Parallel(block_S, block_S):
G_diff_local[i_s1, i_s2] = G_shared[i_s1] - G_shared[i_s2]
for i_s1, i_s2 in T.Parallel(block_S, block_S):
with T.If(G_diff_local[i_s1, i_s2] <= 0):
with T.Then():
A_fragment[i_s1, i_s2] = A_fragment[i_s1, i_s2] * T.exp(
G_diff_local[i_s1, i_s2])
with T.Else():
A_fragment[i_s1, i_s2] = 0
for i_s1, i_s2 in T.Parallel(block_S, block_S):
with T.If(i_s1 < i_s2): # noqa: SIM117
with T.Then():
A_fragment[i_s1, i_s2] = 0
T.copy(V[bb, bs * block_S:(bs + 1) * block_S, bh, bv * block_DV:(bv + 1) * block_DV],
V_shared)
T.copy(A_fragment, A_shared)
T.gemm(A_shared, V_shared, O_fragment)
for i_s, i_v in T.Parallel(block_S, block_DV):
O_fragment[i_s, i_v] = O_fragment[i_s, i_v] * scale
T.copy(O_fragment, O_shared)
T.copy(O_shared, O[bb, bs * block_S:(bs + 1) * block_S, bh,
bv * block_DV:(bv + 1) * block_DV])
return kernel
def run_test(
B,
S,
H,
DK,
DV,
chunk_size,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
use_g,
block_DK,
block_DV,
threads,
num_stages,
):
input_dtype_torch = getattr(torch, input_dtype)
output_dtype_torch = getattr(torch, output_dtype)
accum_dtype_torch = getattr(torch, accum_dtype)
gate_dtype_torch = getattr(torch, gate_dtype)
Q, K, V, HIDDEN, G = prepare_input(B, S, H, DK, DV, chunk_size, input_dtype_torch,
output_dtype_torch, accum_dtype_torch, gate_dtype_torch)
scale = 1.0 / DK**0.5
O_ref = prepare_output(B, S, H, DK, DV, chunk_size, output_dtype_torch)
O_ref = chunk_fwd_o(Q, K, V, HIDDEN, G, scale, chunk_size=chunk_size)
block_S = chunk_size
O_tilelang = prepare_output(B, S, H, DK, DV, chunk_size, output_dtype_torch)
kernel = tilelang_chunk_fwd_o(B, S, H, DK, DV, input_dtype, output_dtype, accum_dtype,
gate_dtype, chunk_size, scale, use_g, block_S, block_DK, block_DV,
threads, num_stages)
O_tilelang = kernel(Q, K, V, HIDDEN, G)
try:
torch.testing.assert_close(O_tilelang, O_ref, rtol=1e-2, atol=1e-2)
print("tilelang chunk fwd o passed √")
except Exception as e:
print("tilelang chunk fwd o failed ✗")
print(e)
def main():
run_test(
B=1,
S=32768,
H=32,
DK=128,
DV=128,
chunk_size=64,
input_dtype="bfloat16",
output_dtype="bfloat16",
accum_dtype="float32",
gate_dtype="float32",
use_g=True,
block_DK=128,
block_DV=128,
threads=128,
num_stages=1,
)
if __name__ == "__main__":
main()
# Reference: fla/ops/common/chunk_scaled_dot_kkt.py
import tilelang
import tilelang.language as T
import sys # noqa: F401
# Add your fla repository path to sys.path
# Currently we use the fla repository from the flash-linear-attention project at commit id f03cb3ae
# sys.path.insert(0, "/home/tzj/flash-linear-attention")
try:
import fla
print(fla.__file__)
from fla.ops.common.chunk_scaled_dot_kkt import chunk_scaled_dot_kkt_fwd
except ImportError:
print("fla not found, using tilelang implementation")
fla = None
import torch
torch.set_printoptions(profile="full")
torch.random.manual_seed(0)
tilelang.disable_cache()
def prepare_input(
B,
S,
H,
DK,
input_dtype,
output_dtype,
accum_dtype,
):
K = torch.randn(B, S, H, DK, dtype=input_dtype).cuda()
Beta = torch.randn(B, S, H, dtype=input_dtype).cuda()
G = torch.randn(B, S, H, dtype=accum_dtype).cuda()
return K, Beta, G
def prepare_output(
B,
S,
H,
chunk_size,
dtype,
):
BS = chunk_size
A = torch.empty(B, S, H, BS, dtype=dtype).cuda()
return A
@tilelang.jit(out_idx=[-1])
def tilelang_chunk_scaled_dot_kkt_fwd(
# task config
B,
S,
H,
DK,
chunk_size=64,
input_dtype="bfloat16",
output_dtype="bfloat16",
accum_dtype="float32",
use_g=True,
# kernel config
block_S=64,
block_DK=64,
threads=256,
num_stages=0,
):
K_shape = (B, S, H, DK)
Beta_shape = (B, S, H)
G_shape = (B, S, H)
assert chunk_size == block_S, "chunk_size must be equal to block_S"
BS = chunk_size
output_shape = (B, S, H, BS)
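# (Added note, illustrative) For each chunk, the kernel below computes
#   A = strict_tril(((Beta * K) @ K^T) * exp(g_i - g_j))
# i.e. a beta-scaled K @ K^T, keeping only entries with i > j; the exp gating factor is
# applied only when use_g is set.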
@T.prim_func
def kernel(
K: T.Tensor(K_shape, dtype=input_dtype),
Beta: T.Tensor(Beta_shape, dtype=input_dtype),
G: T.Tensor(G_shape, dtype=accum_dtype),
A: T.Tensor(output_shape, dtype=output_dtype),
):
with T.Kernel(T.ceildiv(S, block_S), B * H, threads=threads) as (bs, bbh):
bb, bh = bbh // H, bbh % H
# !! Pay attention to the scope of the shared memory: a one-dimensional shape or a very small buffer may cause a misaligned-address error
Beta_shared = T.alloc_shared((block_S,), dtype=input_dtype, scope="shared")
K_shared = T.alloc_shared((block_S, block_DK), dtype=input_dtype)
A_shared = T.alloc_shared((block_S, block_S), dtype=output_dtype)
Beta_K_fragment = T.alloc_fragment((block_S, block_DK), dtype=input_dtype)
A_fragment = T.alloc_fragment((block_S, block_S), dtype=accum_dtype)
# Tensors used for gating:
G_shared = T.alloc_shared((block_S,), dtype=accum_dtype, scope="shared")
G_diff_local = T.alloc_fragment((block_S, block_S), dtype=accum_dtype)
T.annotate_layout({
K_shared: tilelang.layout.make_swizzled_layout(K_shared),
A_shared: tilelang.layout.make_swizzled_layout(A_shared),
})
T.fill(A_fragment, 0)
T.no_set_max_nreg()
for i_s in T.Parallel(block_S):
Beta_shared[i_s] = Beta[bb, bs * block_S + i_s, bh]
for i_k in T.Pipelined(T.ceildiv(DK, block_DK), num_stages=num_stages):
T.copy(
K[bb, bs * block_S:(bs + 1) * block_S, bh, i_k * block_DK:(i_k + 1) * block_DK],
K_shared)
for i_s, i_k2 in T.Parallel(block_S, block_DK):
Beta_K_fragment[i_s, i_k2] = K_shared[i_s, i_k2] * Beta_shared[i_s]
T.gemm(Beta_K_fragment, K_shared, A_fragment, transpose_B=True)
if use_g:
for i_s in T.Parallel(block_S):
G_shared[i_s] = G[bb, bs * block_S + i_s, bh]
for i_s1, i_s2 in T.Parallel(block_S, block_S):
G_diff_local[i_s1, i_s2] = G_shared[i_s1] - G_shared[i_s2]
for i_s1, i_s2 in T.Parallel(block_S, block_S):
with T.If(G_diff_local[i_s1, i_s2] <= 0 and i_s1 > i_s2):
with T.Then():
A_fragment[i_s1, i_s2] = A_fragment[i_s1, i_s2] * T.exp(
G_diff_local[i_s1, i_s2])
with T.Else():
A_fragment[i_s1, i_s2] = 0
else:
for i_s1, i_s2 in T.Parallel(block_S, block_S):
with T.If(i_s1 <= i_s2): # noqa: SIM117
with T.Then():
A_fragment[i_s1, i_s2] = 0
T.copy(A_fragment, A_shared)
T.copy(A_shared, A[bb, bs * block_S:(bs + 1) * block_S, bh, :])
return kernel
def run_test(
B,
S,
H,
DK,
chunk_size,
input_dtype,
output_dtype,
accum_dtype,
use_g,
block_DK,
threads,
num_stages,
):
K, Beta, G = prepare_input(B, S, H, DK, getattr(torch, input_dtype),
getattr(torch, output_dtype), getattr(torch, accum_dtype))
A_ref = prepare_output(B, S, H, chunk_size, getattr(torch, output_dtype))
A_tilelang = prepare_output(B, S, H, chunk_size, getattr(torch, output_dtype))
# reference
if use_g:
A_ref = chunk_scaled_dot_kkt_fwd(
K, Beta, G, chunk_size=chunk_size, output_dtype=getattr(torch, output_dtype))
else:
A_ref = chunk_scaled_dot_kkt_fwd(
K, Beta, None, chunk_size=chunk_size, output_dtype=getattr(torch, output_dtype))
# tilelang
block_S = chunk_size
kernel = tilelang_chunk_scaled_dot_kkt_fwd(B, S, H, DK, chunk_size, input_dtype, output_dtype,
accum_dtype, use_g, block_S, block_DK, threads,
num_stages)
A_tilelang = kernel(K, Beta, G)
try:
torch.testing.assert_close(A_tilelang, A_ref, rtol=1e-2, atol=1e-2)
print("tilelang chunk scaled dot kkt fwd passed √")
except Exception as e:
print("tilelang chunk scaled dot kkt fwd failed ✗")
print(e)
print("reference cuda kernel:")
print(kernel.get_kernel_source())
def main():
run_test(
B=1,
S=32768,
H=32,
DK=128,
chunk_size=64,
input_dtype="bfloat16",
output_dtype="bfloat16",
accum_dtype="float32",
use_g=True,
block_DK=64,
threads=128,
num_stages=2)
if __name__ == "__main__":
main()
# Util functions for flash linear attention cumsum
# Reference: fla/ops/utils/cumsum.py
import tilelang
import tilelang.language as T
import sys # noqa: F401
# Add your fla repository path to sys.path
# Currently we use the fla repository from the flash-linear-attention project at commit id f03cb3ae
# sys.path.insert(0, "/home/tzj/flash-linear-attention")
try:
import fla
print(fla.__file__)
from fla.ops.utils.cumsum import chunk_local_cumsum_scalar
except ImportError:
print("fla not found, using tilelang implementation")
fla = None
import torch
tilelang.disable_cache()
@tilelang.jit(
out_idx=[-1],
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True
})
def tilelang_chunk_local_cumsum_scalar(
# task config
B,
S,
H,
chunk_size=64,
is_varlen=False,
head_first=False,
reverse=False,
input_dtype="float16",
output_dtype="float32",
# kernel config
block_S=64,
threads=256,
use_fragment=False,
):
G_shape = (B, H, S) if head_first else (B, S, H)
assert chunk_size == 2**(chunk_size.bit_length() - 1), "chunk_size must be a power of 2"
assert chunk_size == block_S, "chunk_size must be equal to block_S"
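# (Added note, illustrative) The kernel below performs a chunk-local cumulative sum of the
# gate values along the sequence dimension, one block_S chunk per program, optionally in
# reverse order and with either head-first or sequence-first layout.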
@T.prim_func
def kernel(
G: T.Tensor(G_shape, dtype=input_dtype),
G_new: T.Tensor(G_shape, dtype=output_dtype),
):
with T.Kernel(T.ceildiv(S, block_S), B * H, threads=threads) as (bs, bbh):
bb, bh = bbh // H, bbh % H
G_shared = T.alloc_shared((1, block_S), dtype=output_dtype, scope="shared")
if head_first:
T.copy(G[bb, bh, bs * block_S:(bs + 1) * block_S], G_shared)
else:
T.copy(G[bb, bs * block_S:(bs + 1) * block_S, bh], G_shared)
if use_fragment:
G_fragment = T.alloc_fragment((1, block_S), dtype=output_dtype)
T.copy(G_shared, G_fragment)
T.cumsum(G_fragment, dim=1, reverse=reverse)
if head_first:
T.copy(G_fragment, G_new[bb, bh, bs * block_S:(bs + 1) * block_S])
else:
T.copy(G_fragment, G_new[bb, bs * block_S:(bs + 1) * block_S, bh])
else:
T.cumsum(G_shared, dim=1, reverse=reverse)
if head_first:
T.copy(G_shared, G_new[bb, bh, bs * block_S:(bs + 1) * block_S])
else:
T.copy(G_shared, G_new[bb, bs * block_S:(bs + 1) * block_S, bh])
return kernel
def prepare_cumsum_input(
B,
S,
H,
dtype,
):
G = torch.randn(B, S, H, dtype=dtype).cuda()
return G
def prepare_cumsum_output(
B,
S,
H,
dtype,
):
G_new = torch.empty(B, S, H, dtype=dtype).cuda()
return G_new
def run_test(
B,
S,
H,
chunk_size,
reverse,
head_first,
input_dtype,
output_dtype,
threads,
use_fragment,
):
G = prepare_cumsum_input(B, S, H, getattr(torch, input_dtype))
G_new_ref = prepare_cumsum_output(B, S, H, getattr(torch, output_dtype))
G_new_tilelang = prepare_cumsum_output(B, S, H, getattr(torch, output_dtype))
# reference cumsum
G_new_ref = chunk_local_cumsum_scalar(
g=G,
chunk_size=chunk_size,
reverse=reverse,
head_first=head_first,
output_dtype=getattr(torch, output_dtype))
# tilelang cumsum
block_S = chunk_size
kernel = tilelang_chunk_local_cumsum_scalar(
B=B,
S=S,
H=H,
chunk_size=chunk_size,
reverse=reverse,
head_first=head_first,
input_dtype=input_dtype,
output_dtype=output_dtype,
block_S=block_S,
threads=threads,
use_fragment=use_fragment,
)
torch.cuda.profiler.start()
G_new_tilelang = kernel(G)
torch.cuda.profiler.stop()
try:
torch.testing.assert_close(G_new_tilelang, G_new_ref, rtol=1e-2, atol=1e-2)
print("tilelang cumsum passed √")
except Exception as e:
print("tilelang cumsum failed ✗")
print(e)
print("G:")
print(G.view(-1))
print("G_new_tilelang:")
print(G_new_tilelang.view(-1))
print("G_new_ref:")
print(G_new_ref.view(-1))
def main():
run_test(
B=1,
S=32768,
H=32,
chunk_size=64,
reverse=True,
head_first=False,
input_dtype="float32",
output_dtype="float32",
threads=256,
use_fragment=False)
if __name__ == "__main__":
main()
# Reference: fla/ops/gated_delta_rule/wy_fast.py
import tilelang
import tilelang.language as T
import sys # noqa: F401
# Add your fla repository path to sys.path
# Currently we use the fla repository from the flash-linear-attention project at commit id f03cb3ae
# sys.path.insert(0, "/home/tzj/flash-linear-attention")
try:
import fla
print(fla.__file__)
from fla.ops.gated_delta_rule.wy_fast import recompute_w_u_fwd
except ImportError:
print("fla not found, using tilelang implementation")
fla = None
import torch
torch.random.manual_seed(1)
tilelang.disable_cache()
def prepare_input(B, S, H, DK, DV, chunk_size, input_dtype, output_dtype, gate_dtype=torch.float32):
BS = chunk_size
K = torch.randn(B, S, H, DK, dtype=input_dtype).cuda()
V = torch.randn(B, S, H, DV, dtype=input_dtype).cuda()
Beta = torch.randn(B, S, H, dtype=input_dtype).cuda()
G = torch.randn(B, S, H, dtype=gate_dtype).cuda()
A = torch.randn(B, S, H, BS, dtype=output_dtype).cuda()
return K, V, Beta, G, A
def prepare_output(
B,
S,
H,
DK,
DV,
output_dtype,
):
W = torch.empty(B, S, H, DK, dtype=output_dtype).cuda()
U = torch.empty(B, S, H, DV, dtype=output_dtype).cuda()
return W, U
@tilelang.jit(out_idx=[-2, -1])
def tilelang_recompute_w_u_fwd(
# task config
B,
S,
H,
DK,
DV,
input_dtype,
output_dtype,
gate_dtype,
accum_dtype,
chunk_size,
# kernel config
block_S=64,
block_DK=64,
block_DV=64,
threads=256,
num_stages=0,
):
K_shape = (B, S, H, DK)
V_shape = (B, S, H, DV)
Beta_shape = (B, S, H)
assert chunk_size == block_S, "chunk_size must be equal to block_S"
BS = chunk_size
G_shape = (B, S, H)
A_shape = (B, S, H, BS)
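# (Added note, illustrative) For each chunk, the kernel below recomputes
#   U = A @ (Beta * V)
#   W = A @ (Beta * exp(g) * K)
# blockwise over DV and DK, writing results back to global memory through shared memory.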
@T.prim_func
def kernel(
K: T.Tensor(K_shape, dtype=input_dtype),
V: T.Tensor(V_shape, dtype=input_dtype),
Beta: T.Tensor(Beta_shape, dtype=input_dtype),
G: T.Tensor(G_shape, dtype=gate_dtype),
A: T.Tensor(A_shape, dtype=output_dtype),
W: T.Tensor(K_shape, dtype=output_dtype),
U: T.Tensor(V_shape, dtype=output_dtype),
):
with T.Kernel(T.ceildiv(S, block_S), B * H, threads=threads) as (bs, bbh):
bb, bh = bbh // H, bbh % H
Beta_shared = T.alloc_shared((block_S,), dtype=input_dtype, scope="shared")
K_shared = T.alloc_shared((block_S, block_DK), dtype=input_dtype)
V_shared = T.alloc_shared((block_S, block_DV), dtype=input_dtype)
G_shared = T.alloc_shared((block_S,), dtype=gate_dtype, scope="shared")
A_shared = T.alloc_shared((block_S, block_S), dtype=output_dtype)
W_fragment = T.alloc_fragment((block_S, block_DK), dtype=accum_dtype)
U_fragment = T.alloc_fragment((block_S, block_DV), dtype=accum_dtype)
W_shared = T.alloc_shared((block_S, block_DK), dtype=output_dtype)
U_shared = T.alloc_shared((block_S, block_DV), dtype=output_dtype)
W_Beta_shared = T.alloc_shared((block_S, block_DK), dtype=input_dtype)
U_Beta_shared = T.alloc_shared((block_S, block_DV), dtype=input_dtype)
T.annotate_layout({
K_shared: tilelang.layout.make_swizzled_layout(K_shared),
V_shared: tilelang.layout.make_swizzled_layout(V_shared),
A_shared: tilelang.layout.make_swizzled_layout(A_shared),
W_shared: tilelang.layout.make_swizzled_layout(W_shared),
U_shared: tilelang.layout.make_swizzled_layout(U_shared),
W_Beta_shared: tilelang.layout.make_swizzled_layout(W_Beta_shared),
U_Beta_shared: tilelang.layout.make_swizzled_layout(U_Beta_shared),
})
T.no_set_max_nreg()
for i_s in T.Parallel(block_S):
Beta_shared[i_s] = Beta[bb, bs * block_S + i_s, bh]
G_shared[i_s] = T.exp(G[bb, bs * block_S + i_s, bh])
T.copy(A[bb, bs * block_S:(bs + 1) * block_S, bh, :], A_shared)
for i_v in T.Pipelined(T.ceildiv(DV, block_DV), num_stages=num_stages):
T.copy(
V[bb, bs * block_S:(bs + 1) * block_S, bh, i_v * block_DV:(i_v + 1) * block_DV],
V_shared)
for i_s, i_v2 in T.Parallel(block_S, block_DV):
U_Beta_shared[i_s, i_v2] = V_shared[i_s, i_v2] * Beta_shared[i_s]
T.gemm(A_shared, U_Beta_shared, U_fragment, clear_accum=True)
# First copy to smem, then copy to gmem to reduce U2RU instructions
T.copy(U_fragment, U_shared)
T.copy(
U_shared, U[bb, bs * block_S:(bs + 1) * block_S, bh,
i_v * block_DV:(i_v + 1) * block_DV])
for i_k in T.Pipelined(T.ceildiv(DK, block_DK), num_stages=num_stages):
T.copy(
K[bb, bs * block_S:(bs + 1) * block_S, bh, i_k * block_DK:(i_k + 1) * block_DK],
K_shared)
for i_s, i_k2 in T.Parallel(block_S, block_DK):
W_Beta_shared[i_s,
i_k2] = K_shared[i_s, i_k2] * Beta_shared[i_s] * G_shared[i_s]
T.gemm(A_shared, W_Beta_shared, W_fragment, clear_accum=True)
# First copy to smem, then copy to gmem to reduce U2RU instructions
T.copy(W_fragment, W_shared)
T.copy(
W_shared, W[bb, bs * block_S:(bs + 1) * block_S, bh,
i_k * block_DK:(i_k + 1) * block_DK])
return kernel
def run_test(
B,
S,
H,
DK,
DV,
chunk_size,
input_dtype,
output_dtype,
gate_dtype,
accum_dtype,
block_DK,
block_DV,
threads,
num_stages,
):
K, V, Beta, G, A = prepare_input(
B,
S,
H,
DK,
DV,
chunk_size,
getattr(torch, input_dtype),
getattr(torch, output_dtype),
gate_dtype=getattr(torch, gate_dtype))
W_ref, U_ref = prepare_output(B, S, H, DK, DV, getattr(torch, output_dtype))
W_tilelang, U_tilelang = prepare_output(B, S, H, DK, DV, getattr(torch, output_dtype))
# reference
W_ref, U_ref = recompute_w_u_fwd(K, V, Beta, G, A, None)
# tilelang
block_S = chunk_size
kernel = tilelang_recompute_w_u_fwd(
B,
S,
H,
DK,
DV,
input_dtype,
output_dtype,
gate_dtype,
accum_dtype,
chunk_size,
block_S=block_S,
block_DK=block_DK,
block_DV=block_DV,
threads=threads,
num_stages=num_stages)
print(kernel.get_kernel_source())
W_tilelang, U_tilelang = kernel(K, V, Beta, G, A)
try:
torch.testing.assert_close(W_tilelang, W_ref, rtol=1e-2, atol=1e-2)
print("tilelang recompute w passed √")
except Exception as e:
print("tilelang recompute w failed ✗")
print(e)
try:
torch.testing.assert_close(U_tilelang, U_ref, rtol=1e-2, atol=1e-2)
print("tilelang recompute u passed √")
except Exception as e:
print("tilelang recompute u failed ✗")
print(e)
def main():
run_test(
B=1,
S=32768,
H=32,
DK=128,
DV=128,
chunk_size=64,
input_dtype="bfloat16",
output_dtype="bfloat16",
gate_dtype="float32",
accum_dtype="float32",
block_DK=64,
block_DV=32,
threads=128,
num_stages=3)
if __name__ == "__main__":
main()
import tilelang.testing
import torch
tilelang.disable_cache()
B = 1
S = 32768
H = 32
DK = 128
DV = 128
input_dtype = "bfloat16"
output_dtype = "bfloat16"
accum_dtype = "float32"
gate_dtype = "float32"
state_dtype = "float32"
chunk_size = 64
use_g = True
use_initial_state = True
store_final_state = True
use_final_state_gradient = True
save_new_value = True
block_DK = 64
block_DV = 32
threads = 128
num_stages = 1
def test_example_wy_fast_compilation():
from example_wy_fast import tilelang_recompute_w_u_fwd, prepare_input, prepare_output
K, V, Beta, G, A = prepare_input(
B,
S,
H,
DK,
DV,
chunk_size,
getattr(torch, input_dtype),
getattr(torch, output_dtype),
gate_dtype=getattr(torch, gate_dtype))
W_tilelang, U_tilelang = prepare_output(B, S, H, DK, DV, getattr(torch, output_dtype))
# tilelang
block_S = chunk_size
kernel = tilelang_recompute_w_u_fwd(
B,
S,
H,
DK,
DV,
input_dtype,
output_dtype,
gate_dtype,
accum_dtype,
chunk_size,
block_S=block_S,
block_DK=block_DK,
block_DV=block_DV,
threads=threads,
num_stages=num_stages)
print(kernel.get_kernel_source())
W_tilelang, U_tilelang = kernel(K, V, Beta, G, A)
def test_example_wy_fast_bwd_split_compilation():
from example_wy_fast_bwd_split import tilelang_wy_fast_bwd, tilelang_wy_fast_bwd_split, prepare_input, prepare_output
K, V, Beta, G, A, dw, du = prepare_input(B, S, H, DK, DV, chunk_size,
getattr(torch, input_dtype),
getattr(torch, output_dtype),
getattr(torch,
accum_dtype), getattr(torch, gate_dtype),
getattr(torch, state_dtype))
dk_tilelang, dv_tilelang, dbeta_tilelang, dg_tilelang = prepare_output(
B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype),
getattr(torch, state_dtype))
BS = chunk_size
dA_tilelang = torch.empty(B, S, H, BS, dtype=getattr(torch, input_dtype)).cuda()
dbeta_tilelang_k = torch.empty(B, S, H, dtype=getattr(torch, output_dtype)).cuda()
dg_tilelang_A_positive = torch.empty(B, S, H, BS, dtype=getattr(torch, gate_dtype)).cuda()
dg_tilelang_A_negative = torch.empty(B, S, H, BS, dtype=getattr(torch, gate_dtype)).cuda()
# tilelang
kernel = tilelang_wy_fast_bwd(B, S, H, DK, DV, input_dtype, output_dtype, accum_dtype,
gate_dtype, state_dtype, chunk_size, block_DK, block_DV, threads,
num_stages)
dA_tilelang, dk_tilelang, dv_tilelang, dbeta_tilelang, dg_tilelang = kernel(
K, V, Beta, G, A, dw, du)
torch.cuda.synchronize()
kernel_split = tilelang_wy_fast_bwd_split(B, S, H, DK, DV, input_dtype, output_dtype,
accum_dtype, gate_dtype, state_dtype, chunk_size,
block_DK, block_DV, threads, num_stages)
kernel_split(K, V, Beta, G, A, dw, du, dA_tilelang, dk_tilelang, dv_tilelang, dbeta_tilelang_k,
dg_tilelang_A_positive, dg_tilelang_A_negative)
torch.cuda.synchronize()
dbeta_tilelang = dbeta_tilelang_k + dbeta_tilelang
dg_tilelang = dg_tilelang + dg_tilelang_A_positive.sum(dim=-1) - dg_tilelang_A_negative.sum(
dim=-1)
def test_example_chunk_o_compilation():
from example_chunk_o import tilelang_chunk_fwd_o, prepare_input, prepare_output
Q, K, V, HIDDEN, G = prepare_input(B, S, H, DK, DV, chunk_size, getattr(torch, input_dtype),
getattr(torch, output_dtype), getattr(torch, accum_dtype),
getattr(torch, gate_dtype))
scale = 1.0 / DK**0.5
block_S = chunk_size
O_tilelang = prepare_output(B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype))
kernel = tilelang_chunk_fwd_o(B, S, H, DK, DV, input_dtype, output_dtype, accum_dtype,
gate_dtype, chunk_size, scale, use_g, block_S, block_DK, block_DV,
threads, num_stages)
O_tilelang = kernel(Q, K, V, HIDDEN, G) # noqa: F841
def test_example_chunk_o_bwd_compilation():
from example_chunk_o_bwd import tilelang_chunk_o_bwd_dqkwg, prepare_input, prepare_output
Q, K, V, h, G, dO, dh, dv, W = prepare_input(B, S, H, DK, DV, chunk_size,
getattr(torch, input_dtype),
getattr(torch, output_dtype),
getattr(torch, accum_dtype),
getattr(torch, gate_dtype),
getattr(torch, state_dtype))
dq_tilelang, dk_tilelang, dw_tilelang, dg_tilelang = prepare_output(
B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype),
getattr(torch, state_dtype), block_DK)
kernel = tilelang_chunk_o_bwd_dqkwg(B, S, H, DK, DV, input_dtype, output_dtype, accum_dtype,
gate_dtype, state_dtype, chunk_size, 1.0, use_g, True,
block_DK, block_DV, threads, num_stages)
dq_tilelang, dk_tilelang, dw_tilelang, dg_tilelang = kernel(Q, K, V, h, G, dO, dh, dv,
W) # noqa: F841
if use_g:
dg_tilelang = dg_tilelang.sum(dim=0)
def test_example_chunk_scaled_dot_kkt_compilation():
from example_chunk_scaled_dot_kkt import tilelang_chunk_scaled_dot_kkt_fwd, prepare_input, prepare_output
K, Beta, G = prepare_input(B, S, H, DK, getattr(torch, input_dtype),
getattr(torch, output_dtype), getattr(torch, accum_dtype))
A_tilelang = prepare_output(B, S, H, chunk_size, getattr(torch, output_dtype))
block_S = chunk_size
kernel = tilelang_chunk_scaled_dot_kkt_fwd(B, S, H, DK, chunk_size, input_dtype, output_dtype,
accum_dtype, use_g, block_S, block_DK, threads,
num_stages)
A_tilelang = kernel(K, Beta, G) # noqa: F841
def test_example_cumsum_compilation():
from example_cumsum import tilelang_chunk_local_cumsum_scalar, prepare_cumsum_input, prepare_cumsum_output
G = prepare_cumsum_input(B, S, H, getattr(torch, gate_dtype))
G_new_tilelang = prepare_cumsum_output(B, S, H, getattr(torch, gate_dtype))
block_S = chunk_size
kernel = tilelang_chunk_local_cumsum_scalar(
B=B,
S=S,
H=H,
chunk_size=chunk_size,
reverse=False,
head_first=False,
input_dtype=gate_dtype,
output_dtype=gate_dtype,
block_S=block_S,
threads=threads,
use_fragment=False,
)
G_new_tilelang = kernel(G) # noqa: F841
def test_example_chunk_delta_h_compilation():
from example_chunk_delta_h import tilelang_chunk_gated_delta_rule_fwd_h, prepare_input, prepare_output
K, W, U, G, initial_state = prepare_input(B, S, H, DK, DV, chunk_size,
getattr(torch, input_dtype),
getattr(torch, output_dtype),
getattr(torch, accum_dtype),
getattr(torch, gate_dtype))
h_tilelang, final_state_tilelang, V_new_tilelang = prepare_output(B, S, H, DK, DV, chunk_size,
getattr(torch, output_dtype),
getattr(torch, state_dtype))
kernel = tilelang_chunk_gated_delta_rule_fwd_h(B, S, H, DK, DV, input_dtype, output_dtype,
accum_dtype, gate_dtype, state_dtype, chunk_size,
use_g, use_initial_state, store_final_state,
save_new_value, block_DK, block_DV, threads,
num_stages)
h_tilelang, final_state_tilelang, V_new_tilelang = kernel(K, W, U, G,
initial_state) # noqa: F841
def test_example_chunk_delta_bwd_compilation():
from example_chunk_delta_bwd import tilelang_chunk_gated_delta_rule_bwd_dhu, prepare_input, prepare_output
Q, K, W, G, h0, dht, dO, dv = prepare_input(B, S, H, DK, DV, chunk_size,
getattr(torch, input_dtype),
getattr(torch, output_dtype),
getattr(torch, accum_dtype),
getattr(torch, gate_dtype),
getattr(torch, state_dtype))
dh_tilelang, dh0_tilelang, dv2_tilelang = prepare_output(B, S, H, DK, DV, chunk_size,
getattr(torch, output_dtype),
getattr(torch, gate_dtype),
getattr(torch, state_dtype))
kernel = tilelang_chunk_gated_delta_rule_bwd_dhu(B, S, H, DK, DV, input_dtype, output_dtype,
accum_dtype, gate_dtype, state_dtype,
chunk_size, 1.0, use_g, use_initial_state,
use_final_state_gradient, block_DV, threads,
num_stages)
dh_tilelang, dh0_tilelang, dv2_tilelang = kernel(Q, K, W, G, h0, dht, dO, dv) # noqa: F841
if __name__ == "__main__":
tilelang.testing.main()
import torch
def print_red_warning(message):
print(f"\033[31mWARNING: {message}\033[0m")
def calc_sim(x, y, name="tensor"):
x, y = x.data.double(), y.data.double()
denominator = (x * x + y * y).sum()
if denominator == 0:
print_red_warning(f'{name} all zero')
return 1
sim = 2 * (x * y).sum() / denominator
return sim
def assert_similar(x, y, eps=1e-8, name="tensor", data="", raise_assert=True):
x_mask = torch.isfinite(x)
y_mask = torch.isfinite(y)
if not torch.all(x_mask == y_mask):
print_red_warning(f'{name} Error: isfinite mask mismatch')
if raise_assert:
raise AssertionError
if not torch.isclose(
x.masked_fill(x_mask, 0), y.masked_fill(y_mask, 0), rtol=0, atol=0,
equal_nan=True).all():
print_red_warning(f'{name} Error: nonfinite value mismatch')
if raise_assert:
raise AssertionError
x = x.masked_fill(~x_mask, 0)
y = y.masked_fill(~y_mask, 0)
sim = calc_sim(x, y, name)
diff = 1. - sim
if not (0 <= diff <= eps):
print_red_warning(f'{name} Error: {diff}')
if raise_assert:
raise AssertionError
else:
print(f"{name} {data} passed")