"vscode:/vscode.git/clone" did not exist on "71254ddd23fa848df04cd1ceebc5172656049710"
Unverified commit 6f59668d authored by Zhengju Tang, committed by GitHub

Gated Delta Net (GDN) kernel implementation in TileLang (#695)

* [GDN] Add examples for GDN forward and backward kernels

* [Refactor] Folder structure refactor for duplicated utils

* [Test] Add test script for kernels

* [Refactor] Rename examples to align with the repo

* [Lint] Modify README

* [Update] Modified README to align upstream repo

* [BugFix] Path of FLA

* [Fix] Copyright and test

* [Lint]

* [CI] Add GDN compilation test CI

* [Lint]

* [BugFix] Import error of fla
parent 36b57617
# Gated Delta Net (GDN) kernel implementation in TileLang
## Requirements
The TileLang version used for testing is 0.1.5+17fafc1b3026d910a83eb8052fdf811ba56be0b1.
We currently use triton==3.3.0 and FLA at commit f03cb3ae for comparison.
## Get started
`common/chunk_delta_h.py` implements the most critical forward kernel of GDN and is a good starting point for understanding the GDN logic and the TileLang optimizations. A simplified reference of the recurrence it computes is sketched below.
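
For orientation, here is a minimal PyTorch sketch of the chunk-wise state recurrence that the forward kernel computes. It is illustrative code written for this README (the function name is ours, not part of the kernels) and assumes `G` already holds the chunk-local cumulative sum of the log-gates, as produced in the example scripts:

```python
import torch

def gdn_fwd_h_sketch(K, W, U, G, h0, chunk_size=64):
    """Unoptimized reference of the chunk-wise GDN forward state recurrence."""
    B, S, H, DK = K.shape
    DV = U.shape[-1]
    NC = S // chunk_size
    h = h0.float().clone()                                  # running state, (B, H, DK, DV)
    h_all = torch.empty(B, NC, H, DK, DV, device=K.device)  # state at the start of each chunk
    for c in range(NC):
        sl = slice(c * chunk_size, (c + 1) * chunk_size)
        h_all[:, c] = h
        k = K[:, sl].float().permute(0, 2, 1, 3)            # (B, H, C, DK)
        w = W[:, sl].float().permute(0, 2, 1, 3)            # (B, H, C, DK)
        u = U[:, sl].float().permute(0, 2, 1, 3)            # (B, H, C, DV)
        g = G[:, sl].float().permute(0, 2, 1)               # (B, H, C), log-space cumsum
        g_last = g[..., -1:]                                # total decay of the chunk
        v_new = u - w @ h                                   # "U - W * S" step
        h = torch.exp(g_last).unsqueeze(-1) * h \
            + k.transpose(-1, -2) @ (torch.exp(g_last - g).unsqueeze(-1) * v_new)
    return h_all, h
```

The TileLang kernel follows the same recurrence but keeps the running state on chip, tiles the DV dimension, zeroes gating terms whose exponent would be positive, and optionally writes `V_new` back (`save_new_value`).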
# Reference: fla/ops/common/chunk_delta_h.py
import sys # noqa: F401
import tilelang
import tilelang.language as T
print(tilelang.__file__, flush=True)
# Add your fla repository path to sys.path
# Currently we use the fla repository from the flash-linear-attention project at commit id f03cb3ae
# sys.path.insert(0, "/home/tzj/flash-linear-attention")
try:
import fla
print(fla.__file__, flush=True)
from fla.ops.common.chunk_delta_h import chunk_gated_delta_rule_bwd_dhu
except ImportError:
print("fla not found, using tilelang implementation")
fla = None
import torch
import torch.nn.functional as F
torch.random.manual_seed(0)
# torch.set_printoptions(profile="full")
tilelang.disable_cache()
from utils import *
def prepare_input(
B,
S,
H,
DK,
DV,
chunk_size,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
state_dtype,
):
Q = torch.randn(B, S, H, DK, dtype=input_dtype).cuda()
K = torch.randn(B, S, H, DK, dtype=input_dtype).cuda()
K = F.normalize(K, dim=-1, p=2)
W = torch.randn(B, S, H, DK, dtype=input_dtype).cuda()
# Note: G should be in logspace and do chunkwise cumsum
G = torch.randn(B, S, H, dtype=gate_dtype).cuda()
G = F.logsigmoid(G)
try:
from fla.ops.utils.cumsum import chunk_local_cumsum
G = chunk_local_cumsum(G, chunk_size)
except ImportError:
print("fla not found, skip cumsum")
h0 = torch.randn(B, H, DK, DV, dtype=input_dtype).cuda()
dht = torch.randn(B, H, DK, DV, dtype=input_dtype).cuda()
dO = torch.randn(B, S, H, DV, dtype=input_dtype).cuda()
dv = torch.randn(B, S, H, DV, dtype=input_dtype).cuda()
return Q, K, W, G, h0, dht, dO, dv
def prepare_input_fake(
B,
S,
H,
DK,
DV,
chunk_size,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
state_dtype,
):
Q = torch.ones(B, S, H, DK, dtype=input_dtype).cuda()
K = torch.ones(B, S, H, DK, dtype=input_dtype).cuda()
W = torch.ones(B, S, H, DK, dtype=input_dtype).cuda()
G = torch.ones(B, S, H, dtype=gate_dtype).cuda()
h0 = torch.ones(B, H, DK, DV, dtype=input_dtype).cuda()
dht = torch.ones(B, H, DK, DV, dtype=input_dtype).cuda()
dO = torch.ones(B, S, H, DV, dtype=input_dtype).cuda()
dv = torch.ones(B, S, H, DV, dtype=input_dtype).cuda()
return Q, K, W, G, h0, dht, dO, dv
def prepare_output(
B,
S,
H,
DK,
DV,
chunk_size,
output_dtype,
gate_dtype,
state_dtype,
):
BS = S // chunk_size
dh = torch.empty(B, BS, H, DK, DV, dtype=output_dtype).cuda()
dh0 = torch.empty(B, H, DK, DV, dtype=state_dtype).cuda()
dv2 = torch.empty(B, S, H, DV, dtype=output_dtype).cuda()
return dh, dh0, dv2
def torch_chunk_gated_delta_rule_bwd_dhu(
Q: torch.Tensor,
K: torch.Tensor,
W: torch.Tensor,
G: torch.Tensor,
h0: torch.Tensor,
dht: torch.Tensor,
dO: torch.Tensor,
dv: torch.Tensor,
scale: float,
use_g: bool,
use_initial_state: bool,
use_final_state_gradient: bool,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
state_dtype,
):
B, S, H, DK = Q.shape
DV = dv.shape[-1]
block_S = 64
BS = S // block_S
dh, dh0, dv2 = torch.empty((B, BS, H, DK, DV), dtype=output_dtype), torch.empty(
(B, H, DK, DV), dtype=state_dtype), torch.empty((B, S, H, DV), dtype=output_dtype)
dh_tmp = torch.empty((B, H, DK, DV), dtype=accum_dtype)
dv_tmp = torch.empty((B, S, H, DV), dtype=accum_dtype)
Q_tmp = torch.empty((B, S, H, DK), dtype=accum_dtype)
if use_final_state_gradient:
dh_tmp = dht.clone().to(accum_dtype)
else:
dh_tmp = torch.zeros_like(dht).to(accum_dtype)
for i_s in range(BS - 1, -1, -1):
dh[:, i_s, :, :, :] = dh_tmp
dv_tmp = torch.matmul(K[:, i_s * block_S:(i_s + 1) * block_S, :, :].permute(0, 2, 1, 3),
dh_tmp.to(K.dtype)).permute(0, 2, 1, 3)
if use_g:
for i_bh in range(B * H):
i_b, i_h = i_bh // H, i_bh % H
for i_s2 in range(block_S):
if G[i_b, i_s * block_S + block_S - 1, i_h] - G[i_b, i_s * block_S + i_s2,
i_h] <= 0:
dv_tmp[i_b, i_s2,
i_h, :] *= torch.exp(G[i_b, i_s * block_S + block_S - 1, i_h] -
G[i_b, i_s * block_S + i_s2, i_h])
else:
dv_tmp[i_b, i_s2, i_h, :] = 0
dv_tmp += dv[:, i_s * block_S:(i_s + 1) * block_S, :, :]
dv2[:, i_s * block_S:(i_s + 1) * block_S, :, :] = dv_tmp
if use_g:
G_last = G[:, i_s * block_S + block_S - 1, :]
for i_bh in range(B * H):
i_b, i_h = i_bh // H, i_bh % H
dh_tmp[i_b, i_h, :, :] *= torch.exp(G_last[i_b, i_h])
Q_tmp = Q[:, i_s * block_S:(i_s + 1) * block_S, :, :]
for i_s2 in range(block_S):
for i_k in range(DK):
Q_tmp[:, i_s2, :, i_k] *= torch.exp(G[:, i_s * block_S + i_s2, :])
Q_tmp *= scale
W_tmp = W[:, i_s * block_S:(i_s + 1) * block_S, :, :]
dO_tmp = dO[:, i_s * block_S:(i_s + 1) * block_S, :, :]
torch.backends.cuda.matmul.allow_tf32 = True
dh_tmp += torch.matmul(Q_tmp.permute(0, 2, 3, 1), dO_tmp.permute(0, 2, 1, 3))
dh_tmp -= torch.matmul(W_tmp.permute(0, 2, 3, 1), dv_tmp.permute(0, 2, 1, 3))
torch.backends.cuda.matmul.allow_tf32 = False
if use_initial_state:
dh0 = dh_tmp[:, :, :, :]
else:
dh0 = torch.zeros_like(dh_tmp[:, :, :, :])
print(dh0.dtype)
return dh, dh0, dv2
@tilelang.jit(out_idx=[-3, -2, -1])
def tilelang_chunk_gated_delta_rule_bwd_dhu(
# task config
B,
S,
H,
DK,
DV,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
state_dtype,
chunk_size,
scale,
use_g=True,
use_initial_state=True,
use_final_state_gradient=True,
# kernel config
block_DV=64,
threads=256,
num_stages=0,
):
block_S = chunk_size
# Should support cu_seqlen
BS = S // block_S
Q_shape = (B, S, H, DK)
K_shape = (B, S, H, DK)
W_shape = (B, S, H, DK)
G_shape = (B, S, H)
h0_shape = (B, H, DK, DV)
dht_shape = (B, H, DK, DV)
dO_shape = (B, S, H, DV)
dv_shape = (B, S, H, DV)
dh_shape = (B, BS, H, DK, DV)
dh0_shape = (B, H, DK, DV)
dv2_shape = (B, S, H, DV)
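    # Each thread block handles one (batch, head) pair and one DV tile
    # (grid: (ceildiv(DV, block_DV), B * H)) and walks the sequence chunks in
    # reverse while carrying the running state gradient b_dh. Per chunk it
    # (1) stores the current b_dh to dh, (2) forms dv2 = exp(G_last - G) * (K @ b_dh) + dv,
    # and (3) updates b_dh = exp(G_last) * b_dh + (Q * exp(G) * scale)^T @ dO - W^T @ dv2.
    # The final b_dh is written to dh0 when use_initial_state is set.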
@T.prim_func
def kernel(
# Input
Q: T.Tensor(Q_shape, dtype=input_dtype),
K: T.Tensor(K_shape, dtype=input_dtype),
W: T.Tensor(W_shape, dtype=input_dtype),
G: T.Tensor(G_shape, dtype=gate_dtype),
h0: T.Tensor(h0_shape, dtype=input_dtype),
dht: T.Tensor(dht_shape, dtype=input_dtype),
dO: T.Tensor(dO_shape, dtype=input_dtype),
dv: T.Tensor(dv_shape, dtype=input_dtype),
# Output
dh: T.Tensor(dh_shape, dtype=output_dtype),
dh0: T.Tensor(dh0_shape, dtype=state_dtype),
dv2: T.Tensor(dv2_shape, dtype=output_dtype),
):
with T.Kernel(T.ceildiv(DV, block_DV), B * H, threads=threads) as (bv, bbh):
bb, bh = bbh // H, bbh % H
b_dh_shared = T.alloc_shared((DK, block_DV), dtype=output_dtype)
b_dh_shared_fp32 = T.alloc_shared((DK, block_DV), dtype=state_dtype)
b_dh_fragment = T.alloc_fragment((DK, block_DV), dtype=accum_dtype)
b_dh_fragment_1 = T.alloc_fragment((DK, block_DV), dtype=accum_dtype)
b_dh_fragment_2 = T.alloc_fragment((DK, block_DV), dtype=accum_dtype)
dv_shared = T.alloc_shared((block_S, block_DV), dtype=input_dtype)
dv_fragment = T.alloc_fragment((block_S, block_DV), dtype=accum_dtype)
dv_fragment_2 = T.alloc_fragment((block_S, block_DV), dtype=accum_dtype)
dO_shared = T.alloc_shared((block_S, block_DV), dtype=input_dtype)
dO_shared_t = T.alloc_shared((block_DV, block_S), dtype="float32")
dO_fragment = T.alloc_fragment((block_S, block_DV), dtype="float32")
dO_fragment_t = T.alloc_fragment((block_DV, block_S), dtype="float32")
K_shared = T.alloc_shared((block_S, DK), dtype=input_dtype)
Q_shared = T.alloc_shared((block_S, DK), dtype=input_dtype)
Q_shared_fp32 = T.alloc_shared((block_S, DK), dtype="float32")
W_shared = T.alloc_shared((block_S, DK), dtype=input_dtype)
G_last_local = T.alloc_local((1), dtype=gate_dtype)
G_last_local_exp = T.alloc_local((1), dtype=gate_dtype)
G_shared = T.alloc_shared((block_S), dtype=gate_dtype, scope="shared")
G_fragment = T.alloc_fragment((block_S), dtype=gate_dtype)
G_fragment_post = T.alloc_fragment((block_S), dtype=gate_dtype)
G_fragment_exp = T.alloc_fragment((block_S), dtype=gate_dtype)
Q_fragment = T.alloc_fragment((block_S, DK), dtype=accum_dtype)
Q_fragment_t = T.alloc_fragment((DK, block_S), dtype=accum_dtype)
T.use_swizzle(10)
T.annotate_layout({
b_dh_shared: tilelang.layout.make_swizzled_layout(b_dh_shared),
b_dh_shared_fp32: tilelang.layout.make_swizzled_layout(b_dh_shared_fp32),
dv_shared: tilelang.layout.make_swizzled_layout(dv_shared),
dO_shared: tilelang.layout.make_swizzled_layout(dO_shared),
dO_shared_t: tilelang.layout.make_swizzled_layout(dO_shared_t),
K_shared: tilelang.layout.make_swizzled_layout(K_shared),
Q_shared: tilelang.layout.make_swizzled_layout(Q_shared),
Q_shared_fp32: tilelang.layout.make_swizzled_layout(Q_shared_fp32),
W_shared: tilelang.layout.make_swizzled_layout(W_shared),
})
if use_final_state_gradient:
T.copy(dht[bb, bh, 0:DK, bv * block_DV:(bv + 1) * block_DV], b_dh_shared)
T.copy(b_dh_shared, b_dh_fragment)
else:
T.clear(b_dh_fragment)
for i_s in T.Pipelined(T.ceildiv(S, block_S), num_stages=num_stages):
# The gradient should be stored in the reverse order
i_s_inv = T.ceildiv(S, block_S) - i_s - 1
# Store the updated dh
T.copy(b_dh_fragment, b_dh_shared)
T.copy(b_dh_shared, dh[bb, i_s_inv, bh, 0:DK, bv * block_DV:(bv + 1) * block_DV])
# Update dv
T.copy(K[bb, i_s_inv * block_S:(i_s_inv + 1) * block_S, bh, 0:DK], K_shared)
T.gemm(K_shared, b_dh_shared, dv_fragment, clear_accum=True)
if use_g:
T.copy(
G[bb, i_s_inv * block_S:(i_s_inv + 1) * block_S, bh],
G_shared,
disable_tma=True)
T.copy(G_shared, G_fragment)
G_last_local[0] = G_shared[block_S - 1]
G_last_local_exp[0] = T.exp(G_last_local[0])
for i_s2 in T.Parallel(block_S):
G_fragment_post[i_s2] = T.exp(G_last_local[0] - G_fragment[i_s2])
for i_s2, i_v in T.Parallel(block_S, block_DV):
# with T.If(G_last_local[0] - G_shared[i_s2] <= 0):
with T.If(G_last_local[0] - G_fragment[i_s2] <= 0):
with T.Then():
dv_fragment[i_s2,
i_v] = dv_fragment[i_s2, i_v] * G_fragment_post[i_s2]
with T.Else():
dv_fragment[i_s2, i_v] = 0
T.copy(
dv[bb, i_s_inv * block_S:(i_s_inv + 1) * block_S, bh,
bv * block_DV:(bv + 1) * block_DV], dv_shared)
T.copy(dv_shared, dv_fragment_2)
for i_s2, i_v in T.Parallel(block_S, block_DV):
dv_fragment[i_s2, i_v] = dv_fragment[i_s2, i_v] + dv_fragment_2[i_s2, i_v]
# Store the updated dv
T.copy(dv_fragment, dv_shared)
T.copy(
dv_shared, dv2[bb, i_s_inv * block_S:(i_s_inv + 1) * block_S, bh,
bv * block_DV:(bv + 1) * block_DV])
# Update dh
T.copy(Q[bb, i_s_inv * block_S:(i_s_inv + 1) * block_S, bh, 0:DK], Q_shared)
T.copy(W[bb, i_s_inv * block_S:(i_s_inv + 1) * block_S, bh, 0:DK], W_shared)
T.clear(Q_fragment)
if use_g:
for i_k, i_v in T.Parallel(DK, block_DV):
b_dh_fragment[i_k, i_v] *= G_last_local_exp[0]
T.copy(Q_shared, Q_fragment)
for i_s2 in T.Parallel(block_S):
G_fragment_exp[i_s2] = T.exp(G_shared[i_s2])
for i_s2, i_k in T.Parallel(block_S, DK):
# Q_fragment[i_s2, i_k] = Q_fragment[i_s2, i_k] * T.exp(G_shared[i_s2]) * scale
Q_fragment[i_s2, i_k] = Q_fragment[i_s2, i_k] * G_fragment_exp[i_s2] * scale
else:
T.copy(Q_shared, Q_fragment)
for i_s2, i_k in T.Parallel(block_S, DK):
Q_fragment[i_s2, i_k] = Q_fragment[i_s2, i_k] * scale
# Get transpose of Q_fragment to meet tf32 gemm requirement
for i_s2, i_k in T.Parallel(block_S, DK):
Q_fragment_t[i_k, i_s2] = Q_fragment[i_s2, i_k]
T.copy(
dO[bb, i_s_inv * block_S:(i_s_inv + 1) * block_S, bh,
bv * block_DV:(bv + 1) * block_DV], dO_shared)
T.copy(dO_shared, dO_fragment)
for i_s2, i_v in T.Parallel(block_S, block_DV):
dO_fragment_t[i_v, i_s2] = dO_fragment[i_s2, i_v]
T.copy(dO_fragment_t, dO_shared_t)
T.clear(b_dh_fragment_1)
T.gemm(Q_fragment_t, dO_shared_t, b_dh_fragment_1, transpose_B=True)
T.clear(b_dh_fragment_2)
T.gemm(W_shared, dv_shared, b_dh_fragment_2, transpose_A=True)
for i_k, i_v in T.Parallel(DK, block_DV):
b_dh_fragment[i_k, i_v] += b_dh_fragment_1[i_k, i_v] - b_dh_fragment_2[i_k, i_v]
if use_initial_state:
T.copy(b_dh_fragment, dh0[bb, bh, 0:DK, bv * block_DV:(bv + 1) * block_DV])
return kernel
def test_result(dh_0, dh0_0, dv2_0, dh_1, dh0_1, dv2_1, name):
try:
torch.testing.assert_close(dh_0, dh_1, rtol=1e-2, atol=1e-2, equal_nan=True)
print(f"{name} dh_0 and dh_1 passed for {name}")
except Exception as e:
print(f"{name} dh_0 and dh_1 are not close for {name}")
print(e, end="\n\n")
try:
torch.testing.assert_close(dh0_0, dh0_1, rtol=1e-2, atol=1e-2, equal_nan=True)
print(f"{name} dh0_0 and dh0_1 passed for {name}")
except Exception as e:
print(f"{name} dh0_0 and dh0_1 are not close for {name}")
print(e, end="\n\n")
try:
torch.testing.assert_close(dv2_0, dv2_1, rtol=1e-2, atol=1e-2, equal_nan=True)
print(f"{name} dv2_0 and dv2_1 passed for {name}")
except Exception as e:
print(f"{name} dv2_0 and dv2_1 are not close for {name}")
print(e, end="\n\n")
close = torch.isclose(dh_0, dh_1, rtol=1e-2, atol=1e-2)
mismatch_indices = torch.nonzero(~close, as_tuple=True)
error_num = 0
for indices in zip(*mismatch_indices):
if error_num < 100:
print(
f"{name} dh_0[{[idx.item() for idx in indices]}] = {dh_0[indices[0].item(), indices[1].item(), indices[2].item(), indices[3].item(), indices[4].item()]}, dh_1[{[idx.item() for idx in indices]}] = {dh_1[indices[0].item(), indices[1].item(), indices[2].item(), indices[3].item(), indices[4].item()]}"
)
error_num += 1
close = torch.isclose(dh0_0, dh0_1, rtol=1e-2, atol=1e-2)
mismatch_indices = torch.nonzero(~close, as_tuple=True)
error_num = 0
for indices in zip(*mismatch_indices):
if error_num < 100:
print(
f"{name} dh0_0[{[idx.item() for idx in indices]}] = {dh0_0[indices[0].item(), indices[1].item(), indices[2].item(), indices[3].item()]}, dh0_1[{[idx.item() for idx in indices]}] = {dh0_1[indices[0].item(), indices[1].item(), indices[2].item(), indices[3].item()]}"
)
error_num += 1
close = torch.isclose(dv2_0, dv2_1, rtol=1e-2, atol=1e-2)
mismatch_indices = torch.nonzero(~close, as_tuple=True)
error_num = 0
for indices in zip(*mismatch_indices):
if error_num < 100:
print(
f"{name} dv2_0[{[idx.item() for idx in indices]}] = {dv2_0[indices[0].item(), indices[1].item(), indices[2].item(), indices[3].item()]}, dv2_1[{[idx.item() for idx in indices]}] = {dv2_1[indices[0].item(), indices[1].item(), indices[2].item(), indices[3].item()]}"
)
error_num += 1
def run_test(
B,
S,
H,
DK,
DV,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
state_dtype,
chunk_size,
scale,
use_g=True,
use_initial_state=True,
use_final_state_gradient=True,
block_DV=64,
threads=256,
num_stages=0,
use_torch=False,
):
Q, K, W, G, h0, dht, dO, dv = prepare_input(B, S, H, DK, DV, chunk_size,
getattr(torch, input_dtype),
getattr(torch, output_dtype),
getattr(torch, accum_dtype),
getattr(torch, gate_dtype),
getattr(torch, state_dtype))
dh_ref, dh0_ref, dv2_ref = prepare_output(B, S, H, DK, DV, chunk_size,
getattr(torch, output_dtype),
getattr(torch, gate_dtype),
getattr(torch, state_dtype))
dh_tilelang, dh0_tilelang, dv2_tilelang = prepare_output(B, S, H, DK, DV, chunk_size,
getattr(torch, output_dtype),
getattr(torch, gate_dtype),
getattr(torch, state_dtype))
# fla ref
print("fla running...", flush=True)
if use_g:
dh_ref, dh0_ref, dv2_ref = chunk_gated_delta_rule_bwd_dhu(Q, K, W, G, h0, dht, dO, dv,
scale)
else:
G = G.fill_(0)
dh_ref, dh0_ref, dv2_ref = chunk_gated_delta_rule_bwd_dhu(Q, K, W, G, h0, dht, dO, dv,
scale)
# tilelang
print("tilelang running...", flush=True)
kernel = tilelang_chunk_gated_delta_rule_bwd_dhu(B, S, H, DK, DV, input_dtype, output_dtype,
accum_dtype, gate_dtype, state_dtype,
chunk_size, scale, use_g, use_initial_state,
use_final_state_gradient, block_DV, threads,
num_stages)
# kernel = tilelang.compile(program)
print(kernel.get_kernel_source())
dh_tilelang, dh0_tilelang, dv2_tilelang = kernel(Q, K, W, G, h0, dht, dO, dv)
fla_time = do_bench(
chunk_gated_delta_rule_bwd_dhu, Q, K, W, G, h0, dht, dO, dv, scale, chunk_size=chunk_size)
tilelang_time = do_bench(kernel, Q, K, W, G, h0, dht, dO, dv)
print(f"fla time: {fla_time} ms")
print(f"tilelang time: {tilelang_time} ms")
assert_similar(dh_tilelang, dh_ref, 1e-5, "fla-tilelang", data="dh")
assert_similar(dh0_tilelang, dh0_ref, 1e-5, "fla-tilelang", data="dh0")
assert_similar(dv2_tilelang, dv2_ref, 1e-5, "fla-tilelang", data="dv2")
# torch ref
if use_torch:
print("torch running...", flush=True)
if use_g:
dh_ref_torch, dh0_ref_torch, dv2_ref_torch = torch_chunk_gated_delta_rule_bwd_dhu(
Q, K, W, G, h0, dht, dO, dv, scale, use_g, use_initial_state,
use_final_state_gradient, getattr(torch, input_dtype), getattr(torch, output_dtype),
getattr(torch, accum_dtype), getattr(torch,
gate_dtype), getattr(torch, state_dtype))
dh_ref_torch = dh_ref_torch.cuda()
dh0_ref_torch = dh0_ref_torch.cuda()
dv2_ref_torch = dv2_ref_torch.cuda()
else:
dh_ref_torch, dh0_ref_torch, dv2_ref_torch = torch_chunk_gated_delta_rule_bwd_dhu(
Q, K, W, None, h0, dht, dO, dv, scale, use_g, use_initial_state,
use_final_state_gradient, getattr(torch, input_dtype), getattr(torch, output_dtype),
getattr(torch, accum_dtype), getattr(torch,
gate_dtype), getattr(torch, state_dtype))
dh_ref_torch = dh_ref_torch.cuda()
dh0_ref_torch = dh0_ref_torch.cuda()
dv2_ref_torch = dv2_ref_torch.cuda()
assert_similar(dh_ref_torch, dh_ref, 1e-5, "torch-fla", data="dh")
assert_similar(dh0_ref_torch, dh0_ref, 1e-5, "torch-fla", data="dh0")
assert_similar(dv2_ref_torch, dv2_ref, 1e-5, "torch-fla", data="dv2")
assert_similar(dh_ref_torch, dh_tilelang, 1e-5, "torch-tilelang", data="dh")
assert_similar(dh0_ref_torch, dh0_tilelang, 1e-5, "torch-tilelang", data="dh0")
assert_similar(dv2_ref_torch, dv2_tilelang, 1e-5, "torch-tilelang", data="dv2")
def do_bench(fn, *args, warmup=10, rep=10, **kwargs):
"""
    Benchmark fn(*args, **kwargs) with CUDA events and return the mean latency in milliseconds.
"""
start_event = [torch.cuda.Event(enable_timing=True) for i in range(rep)]
end_event = [torch.cuda.Event(enable_timing=True) for i in range(rep)]
for _ in range(warmup):
fn(*args, **kwargs)
torch.cuda.synchronize()
for i in range(rep):
start_event[i].record()
fn(*args, **kwargs)
end_event[i].record()
torch.cuda.synchronize()
# Record clocks
times = torch.tensor(
[s.elapsed_time(e) for s, e in zip(start_event, end_event)],
dtype=torch.float,
)
return times.mean().item()
def main():
DK = 128
run_test(
B=1,
S=32768,
H=8,
DK=DK,
DV=128,
input_dtype="bfloat16",
output_dtype="bfloat16",
accum_dtype="float32",
gate_dtype="float32",
state_dtype="float32",
chunk_size=64,
scale=DK**-0.5,
use_g=True,
use_initial_state=True,
use_final_state_gradient=True,
block_DV=32,
threads=128,
num_stages=1,
use_torch=False,
)
if __name__ == "__main__":
main()
# Reference: fla/ops/common/chunk_delta_h.py
import sys # noqa: F401
import tilelang
import tilelang.language as T
# Add your fla repository path to sys.path
# Currently we use the fla repository from the flash-linear-attention project at commit id f03cb3ae
# sys.path.insert(0, "/home/tzj/flash-linear-attention")
try:
import fla
print(fla.__file__)
from fla.ops.common.chunk_delta_h import chunk_gated_delta_rule_fwd_h
except ImportError:
print("fla not found, using tilelang implementation")
fla = None
import torch
import torch.nn.functional as F
from tilelang.engine.callback import register_cuda_postproc_callback # noqa: F401
from utils import *
# (zhengju) We can slightly modify the CUDA code generated by TileLang lowering
# (dumped in the debug folder) to improve performance. To enable this callback,
# uncomment the following function.
# @register_cuda_postproc_callback
# def tilelang_callback_cuda_postproc(code, _):
# cuda_code = open("../debug/chunk_delta_h_fuse.cu", "r").read()
# code = cuda_code
# return code
torch.random.manual_seed(0)
tilelang.disable_cache()
def prepare_input(
B,
S,
H,
DK,
DV,
chunk_size,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
):
K = torch.randn(B, S, H, DK, dtype=input_dtype).cuda()
K = F.normalize(K, dim=-1, p=2)
W = torch.randn(B, S, H, DK, dtype=input_dtype).cuda()
W = F.normalize(W, dim=-1, p=2)
U = torch.randn(B, S, H, DV, dtype=input_dtype).cuda()
U = F.normalize(U, dim=-1, p=2)
G = torch.randn(B, S, H, dtype=gate_dtype).cuda()
G = F.logsigmoid(G)
try:
from fla.ops.utils.cumsum import chunk_local_cumsum
G = chunk_local_cumsum(G, chunk_size)
except ImportError:
print("fla not found, skip cumsum")
initial_state = torch.randn(B, H, DK, DV, dtype=input_dtype).cuda()
return K, W, U, G, initial_state
def prepare_output(
B,
S,
H,
DK,
DV,
chunk_size,
output_dtype,
state_dtype,
):
BS = S // chunk_size
h = torch.empty(B, BS, H, DK, DV, dtype=output_dtype).cuda()
final_state = torch.empty(B, H, DK, DV, dtype=state_dtype).cuda()
V_new = torch.empty(B, S, H, DV, dtype=output_dtype).cuda()
return h, final_state, V_new
@tilelang.jit(out_idx=[-3, -2, -1])
def tilelang_chunk_gated_delta_rule_fwd_h(
# task config
B,
S,
H,
DK,
DV,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
state_dtype,
chunk_size,
use_g=True,
use_initial_state=True,
store_final_state=True,
save_new_value=True,
# kernel config
block_DK=64,
block_DV=64,
threads=256,
num_stages=0,
):
block_S = chunk_size
BS = S // block_S
K_shape = (B, S, H, DK)
V_shape = (B, S, H, DV)
W_shape = (B, S, H, DK)
U_shape = (B, S, H, DV)
G_shape = (B, S, H)
h_shape = (B, BS, H, DK, DV)
initial_state_shape = (B, H, DK, DV)
final_state_shape = (B, H, DK, DV)
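    # Each thread block carries the running state b_h for one (batch, head) pair and
    # one DV tile (grid: (ceildiv(DV, block_DV), B * H)). Per chunk it (1) stores the
    # current b_h to h, (2) computes V_new = U - W @ b_h (written out when
    # save_new_value is set), and (3) updates
    # b_h = exp(G_last) * b_h + K^T @ (exp(G_last - G) * V_new).
    # The final state is written to final_state when store_final_state is set.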
@T.prim_func
def kernel(
K: T.Tensor(K_shape, dtype=input_dtype),
W: T.Tensor(W_shape, dtype=input_dtype),
U: T.Tensor(U_shape, dtype=input_dtype),
G: T.Tensor(G_shape, dtype=gate_dtype),
initial_state: T.Tensor(initial_state_shape, dtype=input_dtype),
h: T.Tensor(h_shape, dtype=output_dtype),
final_state: T.Tensor(final_state_shape, dtype=state_dtype),
V_new: T.Tensor(V_shape, dtype=output_dtype),
):
with T.Kernel(T.ceildiv(DV, block_DV), B * H, threads=threads) as (bv, bbh):
bb, bh = bbh // H, bbh % H
b_h_shared = T.alloc_shared((DK, block_DV), dtype=input_dtype)
b_h_fragment = T.alloc_fragment((DK, block_DV), dtype=accum_dtype)
U_shared = T.alloc_shared((block_S, block_DV), dtype=input_dtype)
U_fragment = T.alloc_fragment((block_S, block_DV), dtype=accum_dtype)
W_shared = T.alloc_shared((block_S, DK), dtype=input_dtype)
V_new_fragment = T.alloc_fragment((block_S, block_DV), dtype=accum_dtype)
V_new_shared = T.alloc_shared((block_S, block_DV), dtype=output_dtype)
K_shared = T.alloc_shared((block_S, DK), dtype=input_dtype)
G_last_local = T.alloc_local((1), dtype=gate_dtype)
G_shared = T.alloc_shared((block_S, block_DV), dtype=gate_dtype)
G_fragment = T.alloc_fragment((block_S, block_DV), dtype=gate_dtype)
T.annotate_layout({
b_h_shared: tilelang.layout.make_swizzled_layout(b_h_shared),
U_shared: tilelang.layout.make_swizzled_layout(U_shared),
W_shared: tilelang.layout.make_swizzled_layout(W_shared),
V_new_shared: tilelang.layout.make_swizzled_layout(V_new_shared),
K_shared: tilelang.layout.make_swizzled_layout(K_shared),
G_shared: tilelang.layout.make_swizzled_layout(G_shared),
})
T.use_swizzle(10)
if use_initial_state:
T.copy(initial_state[bb, bh, 0:DK, bv * block_DV:(bv + 1) * block_DV], b_h_shared)
T.copy(b_h_shared, b_h_fragment)
else:
T.clear(b_h_fragment)
for i_s in T.Pipelined(T.ceildiv(S, block_S), num_stages=num_stages):
# Store previous result to the hidden tensor, like the epilogue
T.copy(b_h_shared, h[bb, i_s, bh, 0:DK, bv * block_DV:(bv + 1) * block_DV])
# Recurrence
T.copy(W[bb, i_s * block_S:(i_s + 1) * block_S, bh, 0:DK], W_shared)
T.gemm(W_shared, b_h_shared, V_new_fragment, clear_accum=True)
# U - W * S
T.copy(
U[bb, i_s * block_S:(i_s + 1) * block_S, bh, bv * block_DV:(bv + 1) * block_DV],
U_shared)
T.copy(U_shared, U_fragment)
for i_s2, i_v in T.Parallel(block_S, block_DV):
V_new_fragment[i_s2, i_v] = -V_new_fragment[i_s2, i_v] + U_fragment[i_s2, i_v]
# Save V_new
if save_new_value:
T.copy(V_new_fragment, dst=V_new_shared)
T.copy(
V_new_shared, V_new[bb, i_s * block_S:(i_s + 1) * block_S, bh,
bv * block_DV:(bv + 1) * block_DV])
T.copy(K[bb, i_s * block_S:(i_s + 1) * block_S, bh, 0:DK], K_shared)
# use_g
if use_g:
G_last_local[0] = G[bb, (i_s + 1) * block_S - 1, bh]
for i_s2, i_v in T.Parallel(block_S, block_DV):
G_shared[i_s2, i_v] = G[bb, i_s * block_S + i_s2, bh]
T.copy(G_shared, G_fragment)
for i_s2, i_v in T.Parallel(block_S, block_DV):
with T.If(G_last_local[0] - G_fragment[i_s2, i_v] <= 0):
with T.Then():
V_new_fragment[i_s2, i_v] = V_new_fragment[i_s2, i_v] * T.exp(
G_last_local[0] - G_fragment[i_s2, i_v])
with T.Else():
V_new_fragment[i_s2, i_v] = 0
G_last_local[0] = T.exp(G_last_local[0])
for i_k, i_v in T.Parallel(DK, block_DV):
b_h_fragment[i_k, i_v] *= G_last_local[0]
# Update intermediate results
T.copy(V_new_fragment, V_new_shared)
T.gemm(K_shared, V_new_shared, b_h_fragment, transpose_A=True)
T.copy(b_h_fragment, b_h_shared)
# Save final state
if store_final_state:
T.copy(b_h_fragment, final_state[bb, bh, 0:DK, bv * block_DV:(bv + 1) * block_DV])
return kernel
def do_bench(fn, *args, warmup=10, rep=10, **kwargs):
"""
    Benchmark fn(*args, **kwargs) with CUDA events and return the mean latency in milliseconds.
"""
start_event = [torch.cuda.Event(enable_timing=True) for i in range(rep)]
end_event = [torch.cuda.Event(enable_timing=True) for i in range(rep)]
for _ in range(warmup):
fn(*args, **kwargs)
torch.cuda.synchronize()
for i in range(rep):
start_event[i].record()
fn(*args, **kwargs)
end_event[i].record()
torch.cuda.synchronize()
# Record clocks
times = torch.tensor(
[s.elapsed_time(e) for s, e in zip(start_event, end_event)],
dtype=torch.float,
)
return times.mean().item()
def run_test(
B,
S,
H,
DK,
DV,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
state_dtype,
chunk_size,
use_g=True,
use_initial_state=True,
store_final_state=True,
save_new_value=True,
block_DK=64,
block_DV=32,
threads=128,
num_stages=0,
):
K, W, U, G, initial_state = prepare_input(B, S, H, DK, DV, chunk_size,
getattr(torch, input_dtype),
getattr(torch, output_dtype),
getattr(torch, accum_dtype),
getattr(torch, gate_dtype))
h_ref, final_state_ref, V_new_ref = prepare_output(B, S, H, DK, DV, chunk_size,
getattr(torch, output_dtype),
getattr(torch, state_dtype))
h_tilelang, final_state_tilelang, V_new_tilelang = prepare_output(B, S, H, DK, DV, chunk_size,
getattr(torch, output_dtype),
getattr(torch, state_dtype))
# fla ref
h_ref, V_new_ref, final_state_ref = chunk_gated_delta_rule_fwd_h(K, W, U, G, initial_state,
store_final_state, chunk_size,
save_new_value)
# tilelang
kernel = tilelang_chunk_gated_delta_rule_fwd_h(B, S, H, DK, DV, input_dtype, output_dtype,
accum_dtype, gate_dtype, state_dtype, chunk_size,
use_g, use_initial_state, store_final_state,
save_new_value, block_DK, block_DV, threads,
num_stages)
h_tilelang, final_state_tilelang, V_new_tilelang = kernel(K, W, U, G, initial_state)
# (zhengju) If you want to print the generated cuda code, you can uncomment the following line
# print("CUDA Code:\n", kernel.get_kernel_source())
fla_time = do_bench(chunk_gated_delta_rule_fwd_h, K, W, U, G, initial_state, store_final_state,
chunk_size, save_new_value)
tilelang_time = do_bench(kernel, K, W, U, G, initial_state)
# check correctness
try:
h_ref_fp32 = h_ref.to(torch.float32)
h_tilelang_fp32 = h_tilelang.to(torch.float32)
assert_similar(
h_ref_fp32,
h_tilelang_fp32,
eps=1e-5,
name="tilelang chunk gated delta rule fwd h",
raise_assert=False)
print("tilelang chunk gated delta rule fwd h passed √")
except Exception as e:
print("tilelang chunk gated delta rule fwd h failed ✗")
print(e)
try:
final_state_ref_fp32 = final_state_ref.to(torch.float32)
final_state_tilelang_fp32 = final_state_tilelang.to(torch.float32)
assert_similar(
final_state_ref_fp32,
final_state_tilelang_fp32,
eps=1e-5,
name="tilelang chunk gated delta rule fwd final_state",
raise_assert=False)
print("tilelang chunk gated delta rule fwd final_state passed √")
except Exception as e:
print("tilelang chunk gated delta rule fwd final_state failed ✗")
print(e)
try:
V_new_ref_fp32 = V_new_ref.to(torch.float32)
V_new_tilelang_fp32 = V_new_tilelang.to(torch.float32)
assert_similar(
V_new_ref_fp32,
V_new_tilelang_fp32,
eps=1e-5,
name="tilelang chunk gated delta rule fwd V_new",
raise_assert=False)
print("tilelang chunk gated delta rule fwd V_new passed √")
except Exception as e:
print("tilelang chunk gated delta rule fwd V_new failed ✗")
print(e)
print(f"tilelang time: {tilelang_time} ms")
print(f"fla time: {fla_time} ms")
def main():
run_test(
B=1,
S=32768,
H=32,
DK=128,
DV=128,
input_dtype="bfloat16",
output_dtype="bfloat16",
accum_dtype="float32",
gate_dtype="float32",
state_dtype="float32",
chunk_size=64,
use_g=True,
use_initial_state=True,
store_final_state=True,
save_new_value=True,
block_DK=64,
block_DV=32,
threads=128,
num_stages=1,
)
if __name__ == "__main__":
main()
# Reference: fla/ops/common/chunk_o.py
import tilelang
import tilelang.language as T
import sys # noqa: F401
# Add your fla repository path to sys.path
# Currently we use the fla repository from the flash-linear-attention project at commit id f03cb3ae
# sys.path.insert(0, "/home/tzj/flash-linear-attention")
try:
import fla
print(fla.__file__)
from fla.ops.common.chunk_o import chunk_fwd_o
except ImportError:
print("fla not found, using tilelang implementation")
fla = None
import torch
torch.random.manual_seed(1)
tilelang.disable_cache()
def prepare_input(
B,
S,
H,
DK,
DV,
chunk_size,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
):
BS = chunk_size
Q = torch.randn(B, S, H, DK, dtype=input_dtype).cuda()
K = torch.randn(B, S, H, DK, dtype=input_dtype).cuda()
V = torch.randn(B, S, H, DV, dtype=input_dtype).cuda()
HIDDEN = torch.randn(B, S // BS, H, DK, DV, dtype=input_dtype).cuda()
G = torch.randn(B, S, H, dtype=gate_dtype).cuda()
return Q, K, V, HIDDEN, G
def prepare_output(
B,
S,
H,
DK,
DV,
chunk_size,
output_dtype,
):
O = torch.empty(B, S, H, DV, dtype=output_dtype).cuda()
return O
@tilelang.jit(out_idx=[-1])
def tilelang_chunk_fwd_o(
# task config
B,
S,
H,
DK,
DV,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
chunk_size,
scale,
use_g,
# kernel config
block_S=64,
block_DK=64,
block_DV=64,
threads=256,
num_stages=0,
):
assert chunk_size == block_S, "chunk_size must be equal to block_S"
BS = chunk_size
Q_shape = (B, S, H, DK)
K_shape = (B, S, H, DK)
V_shape = (B, S, H, DV)
H_shape = (B, S // BS, H, DK, DV)
G_shape = (B, S, H)
O_shape = (B, S, H, DV)
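    # Each thread block computes one (chunk, DV tile) output block
    # (grid: (ceildiv(DV, block_DV), ceildiv(S, block_S), B * H)). The inter-chunk term
    # Q @ h and the intra-chunk matrix A = Q @ K^T are accumulated over DK tiles; A is
    # masked to the causal lower triangle, and with use_g the former is gated by exp(G)
    # and the latter by exp(G_i - G_j). The output is O = scale * (Q @ h + A @ V).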
@T.prim_func
def kernel(
Q: T.Tensor(Q_shape, dtype=input_dtype),
K: T.Tensor(K_shape, dtype=input_dtype),
V: T.Tensor(V_shape, dtype=input_dtype),
HIDDEN: T.Tensor(H_shape, dtype=input_dtype),
G: T.Tensor(G_shape, dtype=gate_dtype),
O: T.Tensor(O_shape, dtype=output_dtype),
):
with T.Kernel(
T.ceildiv(DV, block_DV), T.ceildiv(S, block_S), B * H,
threads=threads) as (bv, bs, bbh):
bb, bh = bbh // H, bbh % H
Q_shared = T.alloc_shared((block_S, block_DK), dtype=input_dtype)
K_shared = T.alloc_shared((block_S, block_DK), dtype=input_dtype)
V_shared = T.alloc_shared((block_S, block_DV), dtype=input_dtype)
H_shared = T.alloc_shared((block_DK, block_DV), dtype=input_dtype)
A_shared = T.alloc_shared((block_S, block_S), dtype=input_dtype)
O_shared = T.alloc_shared((block_S, block_DV), dtype=output_dtype)
A_fragment = T.alloc_fragment((block_S, block_S), dtype=accum_dtype)
O_fragment = T.alloc_fragment((block_S, block_DV), dtype=accum_dtype)
G_shared = T.alloc_shared((block_S,), dtype=gate_dtype, scope="shared")
G_diff_local = T.alloc_fragment((block_S, block_S), dtype=gate_dtype)
T.annotate_layout({
Q_shared: tilelang.layout.make_swizzled_layout(Q_shared),
K_shared: tilelang.layout.make_swizzled_layout(K_shared),
V_shared: tilelang.layout.make_swizzled_layout(V_shared),
H_shared: tilelang.layout.make_swizzled_layout(H_shared),
A_shared: tilelang.layout.make_swizzled_layout(A_shared),
O_shared: tilelang.layout.make_swizzled_layout(O_shared),
})
T.clear(A_fragment)
T.clear(O_fragment)
T.no_set_max_nreg()
for i_k in T.Pipelined(T.ceildiv(DK, block_DK), num_stages=num_stages):
T.copy(
Q[bb, bs * block_S:(bs + 1) * block_S, bh, i_k * block_DK:(i_k + 1) * block_DK],
Q_shared)
T.copy(
K[bb, bs * block_S:(bs + 1) * block_S, bh, i_k * block_DK:(i_k + 1) * block_DK],
K_shared)
T.copy(
HIDDEN[bb, bs, bh, i_k * block_DK:(i_k + 1) * block_DK,
bv * block_DV:(bv + 1) * block_DV], H_shared)
T.gemm(Q_shared, H_shared, O_fragment)
T.gemm(Q_shared, K_shared, A_fragment, transpose_B=True)
if use_g:
for i_s in T.Parallel(block_S):
G_shared[i_s] = G[bb, bs * block_S + i_s, bh]
# T.copy(G[bb, bs * block_S:(bs + 1) * block_S, bh], G_shared)
for i_s, i_v in T.Parallel(block_S, block_DV):
O_fragment[i_s, i_v] = O_fragment[i_s, i_v] * T.exp(G_shared[i_s])
for i_s1, i_s2 in T.Parallel(block_S, block_S):
G_diff_local[i_s1, i_s2] = G_shared[i_s1] - G_shared[i_s2]
for i_s1, i_s2 in T.Parallel(block_S, block_S):
with T.If(G_diff_local[i_s1, i_s2] <= 0):
with T.Then():
A_fragment[i_s1, i_s2] = A_fragment[i_s1, i_s2] * T.exp(
G_diff_local[i_s1, i_s2])
with T.Else():
A_fragment[i_s1, i_s2] = 0
for i_s1, i_s2 in T.Parallel(block_S, block_S):
with T.If(i_s1 < i_s2): # noqa: SIM117
with T.Then():
A_fragment[i_s1, i_s2] = 0
T.copy(V[bb, bs * block_S:(bs + 1) * block_S, bh, bv * block_DV:(bv + 1) * block_DV],
V_shared)
T.copy(A_fragment, A_shared)
T.gemm(A_shared, V_shared, O_fragment)
for i_s, i_v in T.Parallel(block_S, block_DV):
O_fragment[i_s, i_v] = O_fragment[i_s, i_v] * scale
T.copy(O_fragment, O_shared)
T.copy(O_shared, O[bb, bs * block_S:(bs + 1) * block_S, bh,
bv * block_DV:(bv + 1) * block_DV])
return kernel
def run_test(
B,
S,
H,
DK,
DV,
chunk_size,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
use_g,
block_DK,
block_DV,
threads,
num_stages,
):
input_dtype_torch = getattr(torch, input_dtype)
output_dtype_torch = getattr(torch, output_dtype)
accum_dtype_torch = getattr(torch, accum_dtype)
gate_dtype_torch = getattr(torch, gate_dtype)
Q, K, V, HIDDEN, G = prepare_input(B, S, H, DK, DV, chunk_size, input_dtype_torch,
output_dtype_torch, accum_dtype_torch, gate_dtype_torch)
scale = 1.0 / DK**0.5
O_ref = prepare_output(B, S, H, DK, DV, chunk_size, output_dtype_torch)
O_ref = chunk_fwd_o(Q, K, V, HIDDEN, G, scale, chunk_size=chunk_size)
block_S = chunk_size
O_tilelang = prepare_output(B, S, H, DK, DV, chunk_size, output_dtype_torch)
kernel = tilelang_chunk_fwd_o(B, S, H, DK, DV, input_dtype, output_dtype, accum_dtype,
gate_dtype, chunk_size, scale, use_g, block_S, block_DK, block_DV,
threads, num_stages)
O_tilelang = kernel(Q, K, V, HIDDEN, G)
try:
torch.testing.assert_close(O_tilelang, O_ref, rtol=1e-2, atol=1e-2)
print("tilelang chunk fwd o passed √")
except Exception as e:
print("tilelang chunk fwd o failed ✗")
print(e)
def main():
run_test(
B=1,
S=32768,
H=32,
DK=128,
DV=128,
chunk_size=64,
input_dtype="bfloat16",
output_dtype="bfloat16",
accum_dtype="float32",
gate_dtype="float32",
use_g=True,
block_DK=128,
block_DV=128,
threads=128,
num_stages=1,
)
if __name__ == "__main__":
main()
# Reference: fla/ops/common/chunk_o.py
import math
import sys # noqa: F401
import tilelang
import tilelang.language as T
from tilelang.engine.callback import register_cuda_postproc_callback # noqa: F401
print(tilelang.__file__)
# Add your fla repository path to sys.path
# Currently we use the fla repository from the flash-linear-attention project at commit id f03cb3ae
# sys.path.insert(0, "/home/tzj/flash-linear-attention")
try:
import fla
print(fla.__file__)
from fla.ops.common.chunk_o import chunk_bwd_dqkwg
except ImportError:
print("fla not found, using tilelang implementation")
fla = None
import torch
from utils import *
torch.random.manual_seed(0)
# torch.set_printoptions(profile="full")
tilelang.disable_cache()
def prepare_input_fake(
B,
S,
H,
DK,
DV,
chunk_size,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
state_dtype,
):
BS = S // chunk_size
Q = torch.ones(B, S, H, DK, dtype=input_dtype).cuda()
K = torch.ones(B, S, H, DK, dtype=input_dtype).cuda()
V = torch.ones(B, S, H, DV, dtype=input_dtype).cuda()
h = torch.ones(B, BS, H, DK, DV, dtype=input_dtype).cuda()
G = torch.ones(B, S, H, dtype=gate_dtype).cuda()
dO = torch.ones(B, S, H, DV, dtype=input_dtype).cuda()
dh = torch.ones(B, BS, H, DK, DV, dtype=input_dtype).cuda()
dv = torch.ones(B, S, H, DV, dtype=output_dtype).cuda()
W = torch.ones(B, S, H, DK, dtype=input_dtype).cuda()
return Q, K, V, h, G, dO, dh, dv, W
def prepare_input(
B,
S,
H,
DK,
DV,
chunk_size,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
state_dtype,
):
BS = S // chunk_size
Q = torch.randn(B, S, H, DK, dtype=input_dtype).cuda()
K = torch.randn(B, S, H, DK, dtype=input_dtype).cuda()
V = torch.randn(B, S, H, DV, dtype=input_dtype).cuda()
h = torch.randn(B, BS, H, DK, DV, dtype=input_dtype).cuda()
G = torch.randn(B, S, H, dtype=gate_dtype).cuda()
dO = torch.randn(B, S, H, DV, dtype=input_dtype).cuda()
dh = torch.randn(B, BS, H, DK, DV, dtype=input_dtype).cuda()
dv = torch.randn(B, S, H, DV, dtype=output_dtype).cuda()
W = torch.randn(B, S, H, DK, dtype=input_dtype).cuda()
return Q, K, V, h, G, dO, dh, dv, W
def prepare_output(
B,
S,
H,
DK,
DV,
chunk_size,
output_dtype,
gate_dtype,
state_dtype,
block_DK,
):
    assert (DK == 32 and block_DK == 32) or (DK > 32 and block_DK >= 64), "When DK > 32, block_DK must be >= 64"
NK = math.ceil(DK / block_DK)
dq = torch.empty(B, S, H, DK, dtype=output_dtype).cuda()
dk = torch.empty(B, S, H, DK, dtype=output_dtype).cuda()
dw = torch.empty(B, S, H, DK, dtype=output_dtype).cuda()
dg = torch.empty(NK, B, S, H, dtype=gate_dtype).cuda()
return dq, dk, dw, dg
# @register_cuda_postproc_callback
# def tilelang_callback_cuda_postproc(code, _):
# cuda_code = open("../debug/chunk_o_bwd3.log", "r").read()
# code = cuda_code
# return code
@tilelang.jit(
out_idx=[-4, -3, -2, -1],
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True
})
def tilelang_chunk_o_bwd_dqkwg(
# task config
B,
S,
H,
DK,
DV,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
state_dtype,
chunk_size,
scale,
use_g=True,
use_dw=True,
# kernel config
block_DK=64,
block_DV=64,
threads=256,
num_stages=0,
):
block_S = chunk_size
BS = S // block_S
NK = math.ceil(DK / block_DK)
Q_shape = (B, S, H, DK)
K_shape = (B, S, H, DK)
V_shape = (B, S, H, DV)
h_shape = (B, BS, H, DK, DV)
G_shape = (B, S, H)
dO_shape = (B, S, H, DV)
dh_shape = (B, BS, H, DK, DV)
dv_shape = (B, S, H, DV)
W_shape = (B, S, H, DK)
dq_shape = (B, S, H, DK)
dk_shape = (B, S, H, DK)
dw_shape = (B, S, H, DK)
dg_shape = (NK, B, S, H)
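    # Each thread block handles one (chunk, DK tile)
    # (grid: (ceildiv(DK, block_DK), ceildiv(S, block_S), B * H)). The DV loop accumulates
    # ds = dO @ V^T, dq = dO @ h^T, dk = V @ dh^T and, if use_dw, dw = -(dv @ h^T).
    # With use_g, dq is scaled by exp(G) * scale, dk by exp(G_last - G), and ds by
    # exp(G_i - G_j) * scale under the causal mask; then dq += ds @ K and dk += ds^T @ Q.
    # dg holds per-DK-tile partial gate gradients (shape (NK, B, S, H)) that the caller
    # sums over the first dimension.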
@T.prim_func
def kernel(
# input
Q: T.Tensor(Q_shape, dtype=input_dtype),
K: T.Tensor(K_shape, dtype=input_dtype),
V: T.Tensor(V_shape, dtype=input_dtype),
h: T.Tensor(h_shape, dtype=input_dtype),
G: T.Tensor(G_shape, dtype=gate_dtype),
dO: T.Tensor(dO_shape, dtype=input_dtype),
dh: T.Tensor(dh_shape, dtype=input_dtype),
dv: T.Tensor(dv_shape, dtype=input_dtype),
W: T.Tensor(W_shape, dtype=input_dtype),
# output
dq: T.Tensor(dq_shape, dtype=output_dtype),
dk: T.Tensor(dk_shape, dtype=output_dtype),
dw: T.Tensor(dw_shape, dtype=output_dtype),
dg: T.Tensor(dg_shape, dtype=gate_dtype),
):
with T.Kernel(
T.ceildiv(DK, block_DK), T.ceildiv(S, block_S), B * H,
threads=threads) as (bk, bs, bbh):
bb, bh = bbh // H, bbh % H
V_shared = T.alloc_shared((block_S, block_DV), dtype=input_dtype)
dO_shared = T.alloc_shared((block_S, block_DV), dtype=input_dtype)
h_shared = T.alloc_shared((block_DK, block_DV), dtype=input_dtype)
dh_shared = T.alloc_shared((block_DK, block_DV), dtype=input_dtype)
dv_shared = T.alloc_shared((block_S, block_DV), dtype=input_dtype)
q_shared = T.alloc_shared((block_S, block_DK), dtype=input_dtype)
k_shared = T.alloc_shared((block_S, block_DK), dtype=input_dtype)
ds_shared = T.alloc_shared((block_S, block_S), dtype=output_dtype)
dg_shared_1 = T.alloc_shared((block_S,), dtype=gate_dtype)
dg_shared_2 = T.alloc_shared((block_S,), dtype=gate_dtype)
dk_shared = T.alloc_shared((block_S, block_DK), dtype=accum_dtype)
ds_fragment = T.alloc_fragment((block_S, block_S), dtype=accum_dtype)
ds_fragment_positive = T.alloc_fragment((block_S, block_S), dtype=accum_dtype)
ds_fragment_positive_transpose = T.alloc_fragment((block_S, block_S), dtype=accum_dtype)
dq_fragment = T.alloc_fragment((block_S, block_DK), dtype=accum_dtype)
dk_fragment = T.alloc_fragment((block_S, block_DK), dtype=accum_dtype)
dk_fragment_2 = T.alloc_fragment((block_S, block_DK), dtype=accum_dtype)
dw_fragment = T.alloc_fragment((block_S, block_DK), dtype=accum_dtype)
q_fragment = T.alloc_fragment((block_S, block_DK), dtype=input_dtype)
k_fragment = T.alloc_fragment((block_S, block_DK), dtype=input_dtype)
dg_fragment_reduce_tmp = T.alloc_fragment((block_S, block_DK), dtype=gate_dtype)
dg_fragment = T.alloc_fragment((block_S,), dtype=gate_dtype)
dg_fragment_2 = T.alloc_fragment((block_S,), dtype=gate_dtype)
dg_fragment_final = T.alloc_fragment((block_S,), dtype=gate_dtype)
dg_last_local = T.alloc_local((2,), dtype=gate_dtype)
dg_last_fragment = T.alloc_fragment((block_DV * block_DK), dtype=gate_dtype)
dg_last_fragment_scalar = T.alloc_fragment((1,), dtype=gate_dtype)
dg_last_fragment_2 = T.alloc_fragment((block_S * block_DK), dtype=gate_dtype)
dg_last_fragment_scalar_2 = T.alloc_fragment((1,), dtype=gate_dtype)
G_shared = T.alloc_shared((block_S, block_DK), dtype=gate_dtype, scope="shared")
G_last_local = T.alloc_local((1,), dtype=gate_dtype)
T.use_swizzle(10)
T.annotate_layout({
V_shared: tilelang.layout.make_swizzled_layout(V_shared),
dO_shared: tilelang.layout.make_swizzled_layout(dO_shared),
h_shared: tilelang.layout.make_swizzled_layout(h_shared),
dh_shared: tilelang.layout.make_swizzled_layout(dh_shared),
dv_shared: tilelang.layout.make_swizzled_layout(dv_shared),
q_shared: tilelang.layout.make_swizzled_layout(q_shared),
k_shared: tilelang.layout.make_swizzled_layout(k_shared),
})
T.clear(dg_last_local)
T.clear(G_last_local)
T.clear(G_shared)
T.clear(q_fragment)
T.clear(k_fragment)
T.clear(dg_last_fragment)
T.clear(ds_fragment)
T.clear(dq_fragment)
T.clear(dk_fragment)
T.clear(dw_fragment)
for i_v in T.Pipelined(T.ceildiv(DV, block_DV), num_stages=num_stages):
T.copy(
V[bb, bs * block_S:(bs + 1) * block_S, bh, i_v * block_DV:(i_v + 1) * block_DV],
V_shared)
T.copy(
dO[bb, bs * block_S:(bs + 1) * block_S, bh,
i_v * block_DV:(i_v + 1) * block_DV], dO_shared)
T.copy(
h[bb, bs, bh, bk * block_DK:(bk + 1) * block_DK,
i_v * block_DV:(i_v + 1) * block_DV], h_shared)
T.copy(
dh[bb, bs, bh, bk * block_DK:(bk + 1) * block_DK,
i_v * block_DV:(i_v + 1) * block_DV], dh_shared)
if use_g:
T.clear(dg_last_fragment_scalar)
                    # FIXME: The reduce operation of a whole buffer to a scalar is not supported and will produce incorrect results
# for i_kv in T.Parallel(block_DK * block_DV):
# dg_last_fragment[i_kv] = h_shared[i_kv // block_DV, i_kv % block_DV] * dh_shared[i_kv // block_DV, i_kv % block_DV]
for i_kv in T.Parallel(block_DK * block_DV):
i_k, i_v = i_kv // block_DV, i_kv % block_DV
dg_last_fragment[i_kv] = h_shared[i_k, i_v] * dh_shared[i_k, i_v]
T.reduce_sum(dg_last_fragment, dg_last_fragment_scalar, dim=-1, clear=False)
dg_last_local[0] += dg_last_fragment_scalar[0]
T.gemm(dO_shared, V_shared, ds_fragment, transpose_B=True)
T.gemm(dO_shared, h_shared, dq_fragment, transpose_B=True)
T.gemm(V_shared, dh_shared, dk_fragment, transpose_B=True)
if use_dw:
T.copy(
dv[bb, bs * block_S:(bs + 1) * block_S, bh,
i_v * block_DV:(i_v + 1) * block_DV], dv_shared)
T.gemm(dv_shared, h_shared, dw_fragment, transpose_B=True)
if use_dw:
for i_s, i_k in T.Parallel(block_S, block_DK):
dw_fragment[i_s, i_k] = -dw_fragment[i_s, i_k]
T.copy(
dw_fragment, dw[bb, bs * block_S:(bs + 1) * block_S, bh,
bk * block_DK:(bk + 1) * block_DK])
T.copy(Q[bb, bs * block_S:(bs + 1) * block_S, bh, bk * block_DK:(bk + 1) * block_DK],
q_shared)
T.copy(K[bb, bs * block_S:(bs + 1) * block_S, bh, bk * block_DK:(bk + 1) * block_DK],
k_shared)
T.copy(q_shared, q_fragment)
T.copy(k_shared, k_fragment)
if use_g:
T.clear(dg_fragment)
T.clear(dg_fragment_2)
for i_s, i_k in T.Parallel(block_S, block_DK):
G_shared[i_s, i_k] = G[bb, bs * block_S + i_s, bh]
G_last_local[0] = G[bb, bs * block_S + block_S - 1, bh]
# Use gmem directly instead of local register
dg_last_local[0] = dg_last_local[0] * T.exp(G[bb, bs * block_S + block_S - 1, bh])
for i_s, i_k in T.Parallel(block_S, block_DK):
dq_fragment[i_s, i_k] = dq_fragment[i_s, i_k] * T.exp(G[bb, bs * block_S + i_s,
bh]) * scale
T.clear(dg_fragment_reduce_tmp)
for i_s, i_k in T.Parallel(block_S, block_DK):
dg_fragment_reduce_tmp[i_s, i_k] = dq_fragment[i_s, i_k] * q_shared[i_s, i_k]
                # FIXME: The reduce_sum statement with clear=True will cause an error in the warp-specialized pass
T.reduce_sum(dg_fragment_reduce_tmp, dg_fragment, dim=-1, clear=False)
for i_s, i_k in T.Parallel(block_S, block_DK):
with T.If(G_last_local[0] - G[bb, bs * block_S + i_s, bh] <= 0):
with T.Then():
dk_fragment[i_s, i_k] = dk_fragment[i_s, i_k] * T.exp(
G_last_local[0] - G[bb, bs * block_S + i_s, bh])
with T.Else():
dk_fragment[i_s, i_k] = 0
T.clear(dg_fragment_reduce_tmp)
for i_s, i_k in T.Parallel(block_S, block_DK):
dg_fragment_reduce_tmp[i_s, i_k] = dk_fragment[i_s, i_k] * (-k_shared[i_s, i_k])
                # FIXME: The reduce_sum statement with clear=True will cause an error in the warp-specialized pass
T.reduce_sum(dg_fragment_reduce_tmp, dg_fragment, dim=-1, clear=False)
                # FIXME: The reduce operation of a whole buffer to a scalar is not supported and will produce incorrect results
T.copy(dk_fragment, dk_shared)
T.clear(dg_last_fragment_scalar_2)
for i_sk in T.Parallel(block_S * block_DK):
i_s, i_k = i_sk // block_DK, i_sk % block_DK
dg_last_fragment_2[i_sk] = dk_shared[i_s, i_k] * k_shared[i_s, i_k]
T.reduce_sum(dg_last_fragment_2, dg_last_fragment_scalar_2, dim=-1, clear=False)
dg_last_local[1] = dg_last_fragment_scalar_2[0]
for i_s1, i_s2 in T.Parallel(block_S, block_S):
with T.If(i_s1 >= i_s2 and
G[bb, bs * block_S + i_s1, bh] - G[bb, bs * block_S + i_s2, bh] <= 0):
with T.Then():
ds_fragment[i_s1, i_s2] = ds_fragment[
i_s1, i_s2] * T.exp(G[bb, bs * block_S + i_s1, bh] -
G[bb, bs * block_S + i_s2, bh]) * scale
with T.Else():
ds_fragment[i_s1, i_s2] = 0
T.clear(ds_fragment_positive)
T.clear(ds_fragment_positive_transpose)
T.gemm(q_shared, k_shared, ds_fragment_positive, transpose_B=True)
for i_s1, i_s2 in T.Parallel(block_S, block_S):
ds_fragment_positive[
i_s1, i_s2] = ds_fragment[i_s1, i_s2] * ds_fragment_positive[i_s1, i_s2]
                # FIXME: The reduce_sum statement with clear=True will cause an error in the warp-specialized pass
T.reduce_sum(ds_fragment_positive, dg_fragment, dim=1, clear=False)
T.copy(dg_fragment, dg_shared_1)
# We should transpose the matrix because the reduce_sum statement can only reduce along the last dimension
for i_s1, i_s2 in T.Parallel(block_S, block_S):
ds_fragment_positive_transpose[i_s2, i_s1] = ds_fragment_positive[i_s1, i_s2]
                # FIXME: The reduce_sum statement with clear=True will cause an error in the warp-specialized pass
T.reduce_sum(ds_fragment_positive_transpose, dg_fragment_2, dim=1, clear=False)
T.copy(dg_fragment_2, dg_shared_2)
for i_s in T.Parallel(block_S):
dg_fragment_final[i_s] = dg_shared_1[i_s] - dg_shared_2[i_s]
T.copy(ds_fragment, ds_shared)
T.gemm(ds_shared, k_shared, dq_fragment)
T.gemm(ds_shared, q_shared, dk_fragment, transpose_A=True)
for i_s in T.Parallel(block_S):
with T.If(i_s >= block_S - 1): # noqa: SIM117
with T.Then():
dg_fragment_final[
i_s] = dg_fragment_final[i_s] + dg_last_local[0] + dg_last_local[1]
T.copy(
dq_fragment, dq[bb, bs * block_S:(bs + 1) * block_S, bh,
bk * block_DK:(bk + 1) * block_DK])
T.copy(
dk_fragment, dk[bb, bs * block_S:(bs + 1) * block_S, bh,
bk * block_DK:(bk + 1) * block_DK])
for i_s in T.Parallel(block_S):
dg[bk, bb, bs * block_S + i_s, bh] = dg_fragment_final[i_s]
else:
for i_s1, i_s2 in T.Parallel(block_S, block_S):
with T.If(i_s1 < i_s2): # noqa: SIM117
with T.Then():
ds_fragment[i_s1, i_s2] = 0
T.clear(dk_fragment_2)
T.copy(ds_fragment, ds_shared)
T.gemm(ds_shared, k_shared, dq_fragment)
T.gemm(ds_shared, q_shared, dk_fragment_2, transpose_A=True)
for i_s, i_k in T.Parallel(block_S, block_DK):
dq_fragment[i_s, i_k] = dq_fragment[i_s, i_k] * scale
dk_fragment[i_s, i_k] = dk_fragment[i_s, i_k] + dk_fragment_2[i_s, i_k] * scale
T.copy(
dq_fragment, dq[bb, bs * block_S:(bs + 1) * block_S, bh,
bk * block_DK:(bk + 1) * block_DK])
T.copy(
dk_fragment, dk[bb, bs * block_S:(bs + 1) * block_S, bh,
bk * block_DK:(bk + 1) * block_DK])
return kernel
def do_bench(fn, *args, warmup=10, rep=10, **kwargs):
"""
    Benchmark fn(*args, **kwargs) with CUDA events and return the mean latency in milliseconds.
"""
start_event = [torch.cuda.Event(enable_timing=True) for i in range(rep)]
end_event = [torch.cuda.Event(enable_timing=True) for i in range(rep)]
for _ in range(warmup):
fn(*args, **kwargs)
torch.cuda.synchronize()
for i in range(rep):
start_event[i].record()
fn(*args, **kwargs)
end_event[i].record()
torch.cuda.synchronize()
# Record clocks
times = torch.tensor(
[s.elapsed_time(e) for s, e in zip(start_event, end_event)],
dtype=torch.float,
)
return times.mean().item()
def run_test(
B,
S,
H,
DK,
DV,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
state_dtype,
chunk_size,
scale,
use_g=True,
use_dw=True,
block_DK=64,
block_DV=64,
threads=256,
num_stages=0,
):
Q, K, V, h, G, dO, dh, dv, W = prepare_input(B, S, H, DK, DV, chunk_size,
getattr(torch, input_dtype),
getattr(torch, output_dtype),
getattr(torch, accum_dtype),
getattr(torch, gate_dtype),
getattr(torch, state_dtype))
dq_ref, dk_ref, dw_ref, dg_ref = prepare_output(B, S, H, DK, DV, chunk_size,
getattr(torch, output_dtype),
getattr(torch, gate_dtype),
getattr(torch, state_dtype), block_DK)
dq_tilelang, dk_tilelang, dw_tilelang, dg_tilelang = prepare_output(
B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype),
getattr(torch, state_dtype), block_DK)
# ref
if use_g:
dq_ref, dk_ref, dw_ref, dg_ref = chunk_bwd_dqkwg(
Q, K, V, G, dO, h, dh, dv, W, chunk_size=chunk_size, scale=scale)
else:
dq_ref, dk_ref, dw_ref, dg_ref = chunk_bwd_dqkwg(
Q, K, V, None, dO, h, dh, dv, W, chunk_size=chunk_size, scale=scale)
# tilelang
kernel = tilelang_chunk_o_bwd_dqkwg(B, S, H, DK, DV, input_dtype, output_dtype, accum_dtype,
gate_dtype, state_dtype, chunk_size, scale, use_g, use_dw,
block_DK, block_DV, threads, num_stages)
print(kernel.get_kernel_source())
dq_tilelang, dk_tilelang, dw_tilelang, dg_tilelang = kernel(Q, K, V, h, G, dO, dh, dv, W)
if use_g:
dg_tilelang = dg_tilelang.sum(dim=0)
# check
try:
assert_similar(dq_ref, dq_tilelang, 1e-5, "tilelang chunk o bwd dq")
print("tilelang chunk o bwd dq passed √")
except Exception as e:
print("tilelang chunk o bwd dq failed ✗")
print(e)
try:
assert_similar(dk_ref, dk_tilelang, 1e-5, "tilelang chunk o bwd dk")
print("tilelang chunk o bwd dk passed √")
except Exception as e:
print("tilelang chunk o bwd dk failed ✗")
print(e)
if use_g:
try:
assert_similar(dg_ref, dg_tilelang, 1e-5, "tilelang chunk o bwd dg")
print("tilelang chunk o bwd dg passed √")
except Exception as e:
print("tilelang chunk o bwd dg failed ✗")
print(e)
if use_dw:
try:
assert_similar(dw_ref, dw_tilelang, 1e-5, "tilelang chunk o bwd dw")
print("tilelang chunk o bwd dw passed √")
except Exception as e:
print("tilelang chunk o bwd dw failed ✗")
print(e)
def main():
DK = 128
DV = 128
run_test(
B=1,
S=32768,
H=8,
DK=DK,
DV=DV,
input_dtype="bfloat16",
output_dtype="bfloat16",
accum_dtype="float32",
gate_dtype="float32",
state_dtype="float32",
chunk_size=64,
scale=DK**-0.5,
# scale=1,
use_g=True,
use_dw=True,
block_DK=64,
block_DV=64,
threads=128,
num_stages=0,
)
if __name__ == "__main__":
main()
# Reference: fla/ops/common/chunk_scaled_dot_kkt.py
import tilelang
import tilelang.language as T
import sys # noqa: F401
# Add your fla repository path to sys.path
# Currently we use the fla repository from the flash-linear-attention project at commit id f03cb3ae
# sys.path.insert(0, "/home/tzj/flash-linear-attention")
try:
import fla
print(fla.__file__)
from fla.ops.common.chunk_scaled_dot_kkt import chunk_scaled_dot_kkt_fwd
except ImportError:
print("fla not found, using tilelang implementation")
fla = None
import torch
torch.set_printoptions(profile="full")
torch.random.manual_seed(0)
tilelang.disable_cache()
def prepare_input(
B,
S,
H,
DK,
input_dtype,
output_dtype,
accum_dtype,
):
K = torch.randn(B, S, H, DK, dtype=input_dtype).cuda()
Beta = torch.randn(B, S, H, dtype=input_dtype).cuda()
G = torch.randn(B, S, H, dtype=accum_dtype).cuda()
return K, Beta, G
def prepare_output(
B,
S,
H,
chunk_size,
dtype,
):
BS = chunk_size
A = torch.empty(B, S, H, BS, dtype=dtype).cuda()
return A
@tilelang.jit(out_idx=[-1])
def tilelang_chunk_scaled_dot_kkt_fwd(
# task config
B,
S,
H,
DK,
chunk_size=64,
input_dtype="bfloat16",
output_dtype="bfloat16",
accum_dtype="float32",
use_g=True,
# kernel config
block_S=64,
block_DK=64,
threads=256,
num_stages=0,
):
K_shape = (B, S, H, DK)
Beta_shape = (B, S, H)
G_shape = (B, S, H)
assert chunk_size == block_S, "chunk_size must be equal to block_S"
BS = chunk_size
output_shape = (B, S, H, BS)
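    # Each thread block handles one (chunk, batch * head) pair and accumulates
    # A = (diag(Beta) @ K) @ K^T over DK tiles, keeping only the strictly lower triangle;
    # with use_g each kept entry is additionally weighted by exp(G_i - G_j).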
@T.prim_func
def kernel(
K: T.Tensor(K_shape, dtype=input_dtype),
Beta: T.Tensor(Beta_shape, dtype=input_dtype),
G: T.Tensor(G_shape, dtype=accum_dtype),
A: T.Tensor(output_shape, dtype=output_dtype),
):
with T.Kernel(T.ceildiv(S, block_S), B * H, threads=threads) as (bs, bbh):
bb, bh = bbh // H, bbh % H
            # !! Pay attention to the shared memory scope: a one-dimensional or very small buffer may cause a misaligned address
Beta_shared = T.alloc_shared((block_S,), dtype=input_dtype, scope="shared")
K_shared = T.alloc_shared((block_S, block_DK), dtype=input_dtype)
A_shared = T.alloc_shared((block_S, block_S), dtype=output_dtype)
Beta_K_fragment = T.alloc_fragment((block_S, block_DK), dtype=input_dtype)
A_fragment = T.alloc_fragment((block_S, block_S), dtype=accum_dtype)
# Tensor used for gated:
G_shared = T.alloc_shared((block_S,), dtype=accum_dtype, scope="shared")
G_diff_local = T.alloc_fragment((block_S, block_S), dtype=accum_dtype)
T.annotate_layout({
K_shared: tilelang.layout.make_swizzled_layout(K_shared),
A_shared: tilelang.layout.make_swizzled_layout(A_shared),
})
T.fill(A_fragment, 0)
T.no_set_max_nreg()
for i_s in T.Parallel(block_S):
Beta_shared[i_s] = Beta[bb, bs * block_S + i_s, bh]
for i_k in T.Pipelined(T.ceildiv(DK, block_DK), num_stages=num_stages):
T.copy(
K[bb, bs * block_S:(bs + 1) * block_S, bh, i_k * block_DK:(i_k + 1) * block_DK],
K_shared)
for i_s, i_k2 in T.Parallel(block_S, block_DK):
Beta_K_fragment[i_s, i_k2] = K_shared[i_s, i_k2] * Beta_shared[i_s]
T.gemm(Beta_K_fragment, K_shared, A_fragment, transpose_B=True)
if use_g:
for i_s in T.Parallel(block_S):
G_shared[i_s] = G[bb, bs * block_S + i_s, bh]
for i_s1, i_s2 in T.Parallel(block_S, block_S):
G_diff_local[i_s1, i_s2] = G_shared[i_s1] - G_shared[i_s2]
for i_s1, i_s2 in T.Parallel(block_S, block_S):
with T.If(G_diff_local[i_s1, i_s2] <= 0 and i_s1 > i_s2):
with T.Then():
A_fragment[i_s1, i_s2] = A_fragment[i_s1, i_s2] * T.exp(
G_diff_local[i_s1, i_s2])
with T.Else():
A_fragment[i_s1, i_s2] = 0
else:
for i_s1, i_s2 in T.Parallel(block_S, block_S):
with T.If(i_s1 <= i_s2): # noqa: SIM117
with T.Then():
A_fragment[i_s1, i_s2] = 0
T.copy(A_fragment, A_shared)
T.copy(A_shared, A[bb, bs * block_S:(bs + 1) * block_S, bh, :])
return kernel
def run_test(
B,
S,
H,
DK,
chunk_size,
input_dtype,
output_dtype,
accum_dtype,
use_g,
block_DK,
threads,
num_stages,
):
K, Beta, G = prepare_input(B, S, H, DK, getattr(torch, input_dtype),
getattr(torch, output_dtype), getattr(torch, accum_dtype))
A_ref = prepare_output(B, S, H, chunk_size, getattr(torch, output_dtype))
A_tilelang = prepare_output(B, S, H, chunk_size, getattr(torch, output_dtype))
# reference
if use_g:
A_ref = chunk_scaled_dot_kkt_fwd(
K, Beta, G, chunk_size=chunk_size, output_dtype=getattr(torch, output_dtype))
else:
A_ref = chunk_scaled_dot_kkt_fwd(
K, Beta, None, chunk_size=chunk_size, output_dtype=getattr(torch, output_dtype))
# tilelang
block_S = chunk_size
kernel = tilelang_chunk_scaled_dot_kkt_fwd(B, S, H, DK, chunk_size, input_dtype, output_dtype,
accum_dtype, use_g, block_S, block_DK, threads,
num_stages)
A_tilelang = kernel(K, Beta, G)
try:
torch.testing.assert_close(A_tilelang, A_ref, rtol=1e-2, atol=1e-2)
print("tilelang chunk scaled dot kkt fwd passed √")
except Exception as e:
print("tilelang chunk scaled dot kkt fwd failed ✗")
print(e)
print("reference cuda kernel:")
print(kernel.get_kernel_source())
def main():
run_test(
B=1,
S=32768,
H=32,
DK=128,
chunk_size=64,
input_dtype="bfloat16",
output_dtype="bfloat16",
accum_dtype="float32",
use_g=True,
block_DK=64,
threads=128,
num_stages=2)
if __name__ == "__main__":
main()
# Util functions for flash linear attention cumsum
# Reference: fla/ops/utils/cumsum.py
import tilelang
import tilelang.language as T
import sys # noqa: F401
# Add your fla repository path to sys.path
# Currently we use the fla repository from the flash-linear-attention project at commit id f03cb3ae
# sys.path.insert(0, "/home/tzj/flash-linear-attention")
try:
import fla
print(fla.__file__)
from fla.ops.utils.cumsum import chunk_local_cumsum_scalar
except ImportError:
print("fla not found, using tilelang implementation")
fla = None
import torch
tilelang.disable_cache()
@tilelang.jit(
out_idx=[-1],
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True
})
def tilelang_chunk_local_cumsum_scalar(
# task config
B,
S,
H,
chunk_size=64,
is_varlen=False,
head_first=False,
reverse=False,
input_dtype="float16",
output_dtype="float32",
# kernel config
block_S=64,
threads=256,
use_fragment=False,
):
G_shape = (B, H, S) if head_first else (B, S, H)
assert chunk_size == 2**(chunk_size.bit_length() - 1), "chunk_size must be a power of 2"
assert chunk_size == block_S, "chunk_size must be equal to block_S"
@T.prim_func
def kernel(
G: T.Tensor(G_shape, dtype=input_dtype),
G_new: T.Tensor(G_shape, dtype=output_dtype),
):
with T.Kernel(T.ceildiv(S, block_S), B * H, threads=threads) as (bs, bbh):
bb, bh = bbh // H, bbh % H
G_shared = T.alloc_shared((1, block_S), dtype=output_dtype, scope="shared")
if head_first:
T.copy(G[bb, bh, bs * block_S:(bs + 1) * block_S], G_shared)
else:
T.copy(G[bb, bs * block_S:(bs + 1) * block_S, bh], G_shared)
if use_fragment:
G_fragment = T.alloc_fragment((1, block_S), dtype=output_dtype, scope="shared")
T.copy(G_shared, G_fragment)
T.cumsum(G_fragment, dim=1, reverse=reverse)
if head_first:
T.copy(G_fragment, G_new[bb, bh, bs * block_S:(bs + 1) * block_S])
else:
T.copy(G_fragment, G_new[bb, bs * block_S:(bs + 1) * block_S, bh])
else:
T.cumsum(G_shared, dim=1, reverse=reverse)
if head_first:
T.copy(G_shared, G_new[bb, bh, bs * block_S:(bs + 1) * block_S])
else:
T.copy(G_shared, G_new[bb, bs * block_S:(bs + 1) * block_S, bh])
return kernel
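# Hedged reference sketch (illustration only): a plain PyTorch version of the
# chunk-local cumsum above for the head_first=False layout (B, S, H), assuming S is
# a multiple of chunk_size. torch_chunk_local_cumsum_ref is a name we introduce
# here, not an fla or tilelang API.
def torch_chunk_local_cumsum_ref(G, chunk_size=64, reverse=False):
    B, S, H = G.shape
    Gc = G.view(B, S // chunk_size, chunk_size, H).float()
    if reverse:
        # reverse cumsum within each chunk: flip along the chunk dim, cumsum, flip back
        Gc = Gc.flip(2).cumsum(dim=2).flip(2)
    else:
        Gc = Gc.cumsum(dim=2)
    return Gc.reshape(B, S, H)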
def prepare_cumsum_input(
B,
S,
H,
dtype,
):
G = torch.randn(B, S, H, dtype=dtype).cuda()
return G
def prepare_cumsum_output(
B,
S,
H,
dtype,
):
G_new = torch.empty(B, S, H, dtype=dtype).cuda()
return G_new
def run_test(
B,
S,
H,
chunk_size,
reverse,
head_first,
input_dtype,
output_dtype,
threads,
use_fragment,
):
G = prepare_cumsum_input(B, S, H, getattr(torch, input_dtype))
G_new_ref = prepare_cumsum_output(B, S, H, getattr(torch, output_dtype))
G_new_tilelang = prepare_cumsum_output(B, S, H, getattr(torch, output_dtype))
# reference cumsum
G_new_ref = chunk_local_cumsum_scalar(
g=G,
chunk_size=chunk_size,
reverse=reverse,
head_first=head_first,
output_dtype=getattr(torch, output_dtype))
# tilelang cumsum
block_S = chunk_size
kernel = tilelang_chunk_local_cumsum_scalar(
B=B,
S=S,
H=H,
chunk_size=chunk_size,
reverse=reverse,
head_first=head_first,
input_dtype=input_dtype,
output_dtype=output_dtype,
block_S=block_S,
threads=threads,
use_fragment=use_fragment,
)
torch.cuda.profiler.start()
G_new_tilelang = kernel(G)
torch.cuda.profiler.stop()
try:
torch.testing.assert_close(G_new_tilelang, G_new_ref, rtol=1e-2, atol=1e-2)
print("tilelang cumsum passed √")
except Exception as e:
print("tilelang cumsum failed ✗")
print(e)
print("G:")
print(G.view(-1))
print("G_new_tilelang:")
print(G_new_tilelang.view(-1))
print("G_new_ref:")
print(G_new_ref.view(-1))
def main():
run_test(
B=1,
S=32768,
H=32,
chunk_size=64,
reverse=True,
head_first=False,
input_dtype="float32",
output_dtype="float32",
threads=256,
use_fragment=False)
if __name__ == "__main__":
main()
# Reference: fla/ops/gated_delta_rule/wy_fast.py
import tilelang
import tilelang.language as T
import sys # noqa: F401
# Add your fla repository path to sys.path
# Currently we use the fla repository from the flash-linear-attention project at commit id f03cb3ae
# sys.path.insert(0, "/home/tzj/flash-linear-attention")
try:
import fla
print(fla.__file__)
from fla.ops.gated_delta_rule.wy_fast import recompute_w_u_fwd
except ImportError:
print("fla not found, using tilelang implementation")
fla = None
import torch
torch.random.manual_seed(1)
tilelang.disable_cache()
def prepare_input(B, S, H, DK, DV, chunk_size, input_dtype, output_dtype, gate_dtype=torch.float32):
BS = chunk_size
K = torch.randn(B, S, H, DK, dtype=input_dtype).cuda()
V = torch.randn(B, S, H, DV, dtype=input_dtype).cuda()
Beta = torch.randn(B, S, H, dtype=input_dtype).cuda()
G = torch.randn(B, S, H, dtype=gate_dtype).cuda()
A = torch.randn(B, S, H, BS, dtype=output_dtype).cuda()
return K, V, Beta, G, A
def prepare_output(
B,
S,
H,
DK,
DV,
output_dtype,
):
W = torch.empty(B, S, H, DK, dtype=output_dtype).cuda()
U = torch.empty(B, S, H, DV, dtype=output_dtype).cuda()
return W, U
@tilelang.jit(out_idx=[-2, -1])
def tilelang_recompute_w_u_fwd(
# task config
B,
S,
H,
DK,
DV,
input_dtype,
output_dtype,
gate_dtype,
accum_dtype,
chunk_size,
# kernel config
block_S=64,
block_DK=64,
block_DV=64,
threads=256,
num_stages=0,
):
K_shape = (B, S, H, DK)
V_shape = (B, S, H, DV)
Beta_shape = (B, S, H)
assert chunk_size == block_S, "chunk_size must be equal to block_S"
BS = chunk_size
G_shape = (B, S, H)
A_shape = (B, S, H, BS)
@T.prim_func
def kernel(
K: T.Tensor(K_shape, dtype=input_dtype),
V: T.Tensor(V_shape, dtype=input_dtype),
Beta: T.Tensor(Beta_shape, dtype=input_dtype),
G: T.Tensor(G_shape, dtype=gate_dtype),
A: T.Tensor(A_shape, dtype=output_dtype),
W: T.Tensor(K_shape, dtype=output_dtype),
U: T.Tensor(V_shape, dtype=output_dtype),
):
with T.Kernel(T.ceildiv(S, block_S), B * H, threads=threads) as (bs, bbh):
bb, bh = bbh // H, bbh % H
Beta_shared = T.alloc_shared((block_S,), dtype=input_dtype, scope="shared")
K_shared = T.alloc_shared((block_S, block_DK), dtype=input_dtype)
V_shared = T.alloc_shared((block_S, block_DV), dtype=input_dtype)
G_shared = T.alloc_shared((block_S,), dtype=gate_dtype, scope="shared")
A_shared = T.alloc_shared((block_S, block_S), dtype=output_dtype)
W_fragment = T.alloc_fragment((block_S, block_DK), dtype=accum_dtype)
U_fragment = T.alloc_fragment((block_S, block_DV), dtype=accum_dtype)
W_shared = T.alloc_shared((block_S, block_DK), dtype=output_dtype)
U_shared = T.alloc_shared((block_S, block_DV), dtype=output_dtype)
W_Beta_shared = T.alloc_shared((block_S, block_DK), dtype=input_dtype)
U_Beta_shared = T.alloc_shared((block_S, block_DV), dtype=input_dtype)
T.annotate_layout({
K_shared: tilelang.layout.make_swizzled_layout(K_shared),
V_shared: tilelang.layout.make_swizzled_layout(V_shared),
A_shared: tilelang.layout.make_swizzled_layout(A_shared),
W_shared: tilelang.layout.make_swizzled_layout(W_shared),
U_shared: tilelang.layout.make_swizzled_layout(U_shared),
W_Beta_shared: tilelang.layout.make_swizzled_layout(W_Beta_shared),
U_Beta_shared: tilelang.layout.make_swizzled_layout(U_Beta_shared),
})
T.no_set_max_nreg()
for i_s in T.Parallel(block_S):
Beta_shared[i_s] = Beta[bb, bs * block_S + i_s, bh]
G_shared[i_s] = T.exp(G[bb, bs * block_S + i_s, bh])
T.copy(A[bb, bs * block_S:(bs + 1) * block_S, bh, :], A_shared)
for i_v in T.Pipelined(T.ceildiv(DV, block_DV), num_stages=num_stages):
T.copy(
V[bb, bs * block_S:(bs + 1) * block_S, bh, i_v * block_DV:(i_v + 1) * block_DV],
V_shared)
for i_s, i_v2 in T.Parallel(block_S, block_DV):
U_Beta_shared[i_s, i_v2] = V_shared[i_s, i_v2] * Beta_shared[i_s]
T.gemm(A_shared, U_Beta_shared, U_fragment, clear_accum=True)
# First copy to smem, then copy to gmem to reduce U2RU instructions
T.copy(U_fragment, U_shared)
T.copy(
U_shared, U[bb, bs * block_S:(bs + 1) * block_S, bh,
i_v * block_DV:(i_v + 1) * block_DV])
for i_k in T.Pipelined(T.ceildiv(DK, block_DK), num_stages=num_stages):
T.copy(
K[bb, bs * block_S:(bs + 1) * block_S, bh, i_k * block_DK:(i_k + 1) * block_DK],
K_shared)
for i_s, i_k2 in T.Parallel(block_S, block_DK):
W_Beta_shared[i_s,
i_k2] = K_shared[i_s, i_k2] * Beta_shared[i_s] * G_shared[i_s]
T.gemm(A_shared, W_Beta_shared, W_fragment, clear_accum=True)
# First copy to smem, then copy to gmem to reduce U2RU instructions
T.copy(W_fragment, W_shared)
T.copy(
W_shared, W[bb, bs * block_S:(bs + 1) * block_S, bh,
i_k * block_DK:(i_k + 1) * block_DK])
return kernel
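# Hedged reference sketch (illustration only): per-chunk PyTorch math for the kernel
# above, i.e. U = A @ (V * beta) and W = A @ (K * beta * exp(g)) within each chunk.
# Shapes follow prepare_input in this file; torch_recompute_w_u_ref is our own name,
# not the fla API.
def torch_recompute_w_u_ref(K, V, Beta, G, A, chunk_size=64):
    B, S, H, DK = K.shape
    NC, C = S // chunk_size, chunk_size
    Kc = K.view(B, NC, C, H, DK).float()
    Vc = V.view(B, NC, C, H, -1).float()
    Ac = A.view(B, NC, C, H, C).float()  # A[b, c, i, h, j]
    Bc = Beta.view(B, NC, C, H).float()
    Gc = G.view(B, NC, C, H).float()
    # U[i] = sum_j A[i, j] * beta_j * v_j
    U = torch.einsum("bcihj,bcjh,bcjhv->bcihv", Ac, Bc, Vc)
    # W[i] = sum_j A[i, j] * beta_j * exp(g_j) * k_j
    W = torch.einsum("bcihj,bcjh,bcjhk->bcihk", Ac, Bc * Gc.exp(), Kc)
    return W.reshape(B, S, H, DK), U.reshape(B, S, H, -1)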
def run_test(
B,
S,
H,
DK,
DV,
chunk_size,
input_dtype,
output_dtype,
gate_dtype,
accum_dtype,
block_DK,
block_DV,
threads,
num_stages,
):
K, V, Beta, G, A = prepare_input(
B,
S,
H,
DK,
DV,
chunk_size,
getattr(torch, input_dtype),
getattr(torch, output_dtype),
gate_dtype=getattr(torch, gate_dtype))
W_ref, U_ref = prepare_output(B, S, H, DK, DV, getattr(torch, output_dtype))
W_tilelang, U_tilelang = prepare_output(B, S, H, DK, DV, getattr(torch, output_dtype))
# reference
W_ref, U_ref = recompute_w_u_fwd(K, V, Beta, G, A, None)
# tilelang
block_S = chunk_size
kernel = tilelang_recompute_w_u_fwd(
B,
S,
H,
DK,
DV,
input_dtype,
output_dtype,
gate_dtype,
accum_dtype,
chunk_size,
block_S=block_S,
block_DK=block_DK,
block_DV=block_DV,
threads=threads,
num_stages=num_stages)
print(kernel.get_kernel_source())
W_tilelang, U_tilelang = kernel(K, V, Beta, G, A)
try:
torch.testing.assert_close(W_tilelang, W_ref, rtol=1e-2, atol=1e-2)
print("tilelang recompute w passed √")
except Exception as e:
print("tilelang recompute w failed ✗")
print(e)
try:
torch.testing.assert_close(U_tilelang, U_ref, rtol=1e-2, atol=1e-2)
print("tilelang recompute u passed √")
except Exception as e:
print("tilelang recompute u failed ✗")
print(e)
def main():
run_test(
B=1,
S=32768,
H=32,
DK=128,
DV=128,
chunk_size=64,
input_dtype="bfloat16",
output_dtype="bfloat16",
gate_dtype="float32",
accum_dtype="float32",
block_DK=64,
block_DV=32,
threads=128,
num_stages=3)
if __name__ == "__main__":
main()
# Reference: fla/ops/gated_delta_rule/wy_fast.py
import sys # noqa: F401
import tilelang
import tilelang.language as T
# Add your fla repository path to sys.path
# Currently we use the fla repository from the flash-linear-attention project at commit id f03cb3ae
# sys.path.insert(0, "/home/tzj/flash-linear-attention")
try:
import fla
print(fla.__file__)
from fla.ops.gated_delta_rule.wy_fast import bwd_prepare_wy_repr
except ImportError:
print("fla not found, using tilelang implementation")
fla = None
import torch
import torch.nn.functional as F
from utils import assert_similar
torch.random.manual_seed(0)
torch.set_printoptions(profile="full")
tilelang.disable_cache()
def prepare_input_fake(
B,
S,
H,
DK,
DV,
chunk_size,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
state_dtype,
):
BS = chunk_size
K = torch.ones(B, S, H, DK, dtype=input_dtype).cuda()
V = torch.ones(B, S, H, DV, dtype=input_dtype).cuda()
Beta = torch.ones(B, S, H, dtype=input_dtype).cuda()
G = torch.ones(B, S, H, dtype=gate_dtype).cuda()
A = torch.ones(B, S, H, BS, dtype=input_dtype).cuda()
dw = torch.ones(B, S, H, DK, dtype=input_dtype).cuda()
du = torch.ones(B, S, H, DV, dtype=input_dtype).cuda()
return K, V, Beta, G, A, dw, du
def prepare_input(
B,
S,
H,
DK,
DV,
chunk_size,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
state_dtype,
):
BS = chunk_size
K = torch.randn(B, S, H, DK, dtype=input_dtype).cuda()
K = F.normalize(K, dim=-1, p=2)
V = torch.randn(B, S, H, DV, dtype=input_dtype).cuda()
V = F.normalize(V, dim=-1, p=2)
Beta = torch.randn(B, S, H, dtype=input_dtype).cuda()
G = torch.randn(B, S, H, dtype=gate_dtype).cuda()
A = torch.randn(B, S, H, BS, dtype=input_dtype).cuda()
dw = torch.randn(B, S, H, DK, dtype=input_dtype).cuda()
du = torch.randn(B, S, H, DV, dtype=input_dtype).cuda()
return K, V, Beta, G, A, dw, du
def prepare_output(
B,
S,
H,
DK,
DV,
chunk_size,
output_dtype,
gate_dtype,
state_dtype,
):
dk = torch.empty(B, S, H, DK, dtype=output_dtype).cuda()
dv = torch.empty(B, S, H, DV, dtype=output_dtype).cuda()
dbeta = torch.empty(B, S, H, dtype=output_dtype).cuda()
dg = torch.empty(B, S, H, dtype=gate_dtype).cuda()
return dk, dv, dbeta, dg
@tilelang.jit(
out_idx=[-5, -4, -3, -2, -1],
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True
})
def tilelang_wy_fast_bwd(
# task config
B,
S,
H,
DK,
DV,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
state_dtype,
chunk_size,
# kernel config
block_DK=64,
block_DV=64,
threads=128,
num_stages=0,
):
block_S = chunk_size
BS = block_S
K_shape = (B, S, H, DK)
V_shape = (B, S, H, DV)
Beta_shape = (B, S, H)
G_shape = (B, S, H)
A_shape = (B, S, H, BS)
dw_shape = (B, S, H, DK)
du_shape = (B, S, H, DV)
dk_shape = (B, S, H, DK)
dv_shape = (B, S, H, DV)
dbeta_shape = (B, S, H)
dg_shape = (B, S, H)
dA_shape = (B, S, H, BS)
@T.prim_func
def kernel(
# input
K: T.Tensor(K_shape, dtype=input_dtype),
V: T.Tensor(V_shape, dtype=input_dtype),
Beta: T.Tensor(Beta_shape, dtype=input_dtype),
G: T.Tensor(G_shape, dtype=gate_dtype),
A: T.Tensor(A_shape, dtype=input_dtype),
dw: T.Tensor(dw_shape, dtype=input_dtype),
du: T.Tensor(du_shape, dtype=input_dtype),
# output
dA: T.Tensor(dA_shape, dtype=input_dtype),
dk: T.Tensor(dk_shape, dtype=output_dtype),
dv: T.Tensor(dv_shape, dtype=output_dtype),
dbeta: T.Tensor(dbeta_shape, dtype=output_dtype),
dg: T.Tensor(dg_shape, dtype=gate_dtype),
):
with T.Kernel(T.ceildiv(S, block_S), B * H, threads=threads) as (bs, bbh):
bb, bh = bbh // H, bbh % H
A_shared = T.alloc_shared((block_S, block_S), dtype=input_dtype)
K_shared = T.alloc_shared((block_S, block_DK), dtype=input_dtype)
K_shared_beta_g = T.alloc_shared((block_S, block_DK), dtype=input_dtype)
V_shared = T.alloc_shared((block_S, block_DV), dtype=input_dtype)
V_shared_beta = T.alloc_shared((block_S, block_DV), dtype=input_dtype)
Beta_shared = T.alloc_shared((block_S,), dtype=input_dtype)
G_shared = T.alloc_shared((block_S,), dtype=gate_dtype)
G_shared_exp = T.alloc_shared((block_S,), dtype=gate_dtype)
dw_shared = T.alloc_shared((block_S, block_DK), dtype=input_dtype)
du_shared = T.alloc_shared((block_S, block_DV), dtype=input_dtype)
dA_fragment = T.alloc_fragment((block_S, block_S), dtype=accum_dtype)
dk_fragment = T.alloc_fragment((block_S, block_DK), dtype=accum_dtype)
dk_fragment_beta_g = T.alloc_fragment((block_S, block_DK), dtype=accum_dtype)
dv_fragment = T.alloc_fragment((block_S, block_DV), dtype=accum_dtype)
dv_fragment_beta = T.alloc_fragment((block_S, block_DV), dtype=accum_dtype)
dbeta_fragment_k = T.alloc_fragment((block_S,), dtype=accum_dtype)
dbeta_fragment_v = T.alloc_fragment((block_S,), dtype=accum_dtype)
dbeta_fragment_reduce_tmpk = T.alloc_fragment((block_S, block_DK), dtype=accum_dtype)
dbeta_fragment_reduce_tmpv = T.alloc_fragment((block_S, block_DV), dtype=accum_dtype)
dg_fragment = T.alloc_fragment((block_S,), dtype=gate_dtype)
dg_fragment_reduce_tmp = T.alloc_fragment((block_S, block_DK), dtype=gate_dtype)
T.use_swizzle(10)
T.clear(dA_fragment)
T.clear(dk_fragment)
T.clear(dk_fragment_beta_g)
T.clear(dv_fragment)
T.clear(dv_fragment_beta)
T.clear(dbeta_fragment_k)
T.clear(dbeta_fragment_v)
T.clear(dg_fragment)
T.copy(A[bb, bs * block_S:(bs + 1) * block_S, bh, :], A_shared)
for i_s in T.Parallel(block_S):
Beta_shared[i_s] = Beta[bb, bs * block_S + i_s, bh]
G_shared[i_s] = G[bb, bs * block_S + i_s, bh]
G_shared_exp[i_s] = T.exp(G[bb, bs * block_S + i_s, bh])
# Update dk
for i_k in T.Pipelined(T.ceildiv(DK, block_DK), num_stages=num_stages):
T.copy(
K[bb, bs * block_S:(bs + 1) * block_S, bh, i_k * block_DK:(i_k + 1) * block_DK],
K_shared)
for i_s, i_k2 in T.Parallel(block_S, block_DK):
K_shared_beta_g[i_s,
i_k2] = K_shared[i_s,
i_k2] * Beta_shared[i_s] * G_shared_exp[i_s]
T.copy(
dw[bb, bs * block_S:(bs + 1) * block_S, bh,
i_k * block_DK:(i_k + 1) * block_DK], dw_shared)
T.gemm(dw_shared, K_shared_beta_g, dA_fragment, transpose_B=True)
T.gemm(A_shared, dw_shared, dk_fragment_beta_g, clear_accum=True, transpose_A=True)
for i_s, i_k2 in T.Parallel(block_S, block_DK):
dk_fragment[
i_s,
i_k2] = dk_fragment_beta_g[i_s, i_k2] * Beta_shared[i_s] * G_shared_exp[i_s]
# for i_s, i_k2 in T.Parallel(block_S, block_DK):
# dbeta_fragment[i_s] = dbeta_fragment[i_s] + dk_fragment_beta_g[i_s, i_k2] * K_shared[i_s, i_k2] * G_shared_exp[i_s]
for i_s, i_k2 in T.Parallel(block_S, block_DK):
dbeta_fragment_reduce_tmpk[i_s, i_k2] = dk_fragment_beta_g[
i_s, i_k2] * K_shared[i_s, i_k2] * G_shared_exp[i_s]
T.reduce_sum(dbeta_fragment_reduce_tmpk, dbeta_fragment_k, dim=1, clear=False)
# for i_s, i_k2 in T.Parallel(block_S, block_DK):
# dg_fragment[i_s] = dg_fragment[i_s] + dk_fragment_beta_g[i_s, i_k2] * K_shared[i_s, i_k2] * G_shared_exp[i_s] * Beta_shared[i_s]
for i_s, i_k2 in T.Parallel(block_S, block_DK):
dg_fragment_reduce_tmp[i_s, i_k2] = dk_fragment_beta_g[i_s, i_k2] * K_shared[
i_s, i_k2] * G_shared_exp[i_s] * Beta_shared[i_s]
T.reduce_sum(dg_fragment_reduce_tmp, dg_fragment, dim=1, clear=False)
# correct dk
T.copy(
dk_fragment, dk[bb, bs * block_S:(bs + 1) * block_S, bh,
i_k * block_DK:(i_k + 1) * block_DK])
# Update dv
for i_v in T.Pipelined(T.ceildiv(DV, block_DV), num_stages=num_stages):
T.copy(
V[bb, bs * block_S:(bs + 1) * block_S, bh, i_v * block_DV:(i_v + 1) * block_DV],
V_shared)
for i_s, i_v2 in T.Parallel(block_S, block_DV):
V_shared_beta[i_s, i_v2] = V_shared[i_s, i_v2] * Beta_shared[i_s]
T.copy(
du[bb, bs * block_S:(bs + 1) * block_S, bh,
i_v * block_DV:(i_v + 1) * block_DV], du_shared)
T.gemm(du_shared, V_shared_beta, dA_fragment, transpose_B=True)
T.gemm(A_shared, du_shared, dv_fragment_beta, clear_accum=True, transpose_A=True)
for i_s, i_v2 in T.Parallel(block_S, block_DV):
dv_fragment[i_s, i_v2] = dv_fragment_beta[i_s, i_v2] * Beta_shared[i_s]
# for i_s, i_v2 in T.Parallel(block_S, block_DV):
# dbeta_fragment[i_s] = dbeta_fragment[i_s] + dv_fragment_beta[i_s, i_v2] * V_shared[i_s, i_v2]
for i_s, i_v2 in T.Parallel(block_S, block_DV):
dbeta_fragment_reduce_tmpv[i_s,
i_v2] = dv_fragment_beta[i_s, i_v2] * V_shared[i_s,
i_v2]
T.reduce_sum(dbeta_fragment_reduce_tmpv, dbeta_fragment_v, dim=1, clear=False)
T.copy(
dv_fragment, dv[bb, bs * block_S:(bs + 1) * block_S, bh,
i_v * block_DV:(i_v + 1) * block_DV])
            # Temporarily store dbeta, dg and dA
for i_s in T.Parallel(block_S):
dbeta[bb, bs * block_S + i_s, bh] = dbeta_fragment_k[i_s] + dbeta_fragment_v[i_s]
dg[bb, bs * block_S + i_s, bh] = dg_fragment[i_s]
# correct dA
T.copy(dA_fragment, dA[bb, bs * block_S:(bs + 1) * block_S, bh, :])
return kernel
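# Hedged reference sketch (illustration only): per-chunk PyTorch math for the
# partial gradients produced by tilelang_wy_fast_bwd above (dA, dk, dv and the
# dbeta/dg contributions). Shapes follow prepare_input in this file;
# torch_wy_fast_bwd_partial_ref is a name we introduce for illustration.
def torch_wy_fast_bwd_partial_ref(K, V, Beta, G, A, dw, du, chunk_size=64):
    B, S, H, DK = K.shape
    NC, C = S // chunk_size, chunk_size
    r = lambda x: x.view(B, NC, C, H, -1).float()  # (B, NC, C, H, D)
    s = lambda x: x.view(B, NC, C, H).float()  # (B, NC, C, H)
    Kc, Vc, Ac, dwc, duc = r(K), r(V), r(A), r(dw), r(du)
    Bc, Ge = s(Beta), s(G).exp()
    # dA[i, j] = <dw_i, beta_j * exp(g_j) * k_j> + <du_i, beta_j * v_j>
    dA = (torch.einsum("bcihk,bcjhk,bcjh->bcihj", dwc, Kc, Bc * Ge) +
          torch.einsum("bcihv,bcjhv,bcjh->bcihj", duc, Vc, Bc))
    # dk = beta_i * exp(g_i) * (A^T @ dw), dv = beta_i * (A^T @ du)
    dk_bg = torch.einsum("bcjhi,bcjhk->bcihk", Ac, dwc)
    dv_b = torch.einsum("bcjhi,bcjhv->bcihv", Ac, duc)
    dk = dk_bg * (Bc * Ge).unsqueeze(-1)
    dv = dv_b * Bc.unsqueeze(-1)
    # dbeta / dg are row-wise reductions of the same products
    dbeta = (dk_bg * Kc).sum(-1) * Ge + (dv_b * Vc).sum(-1)
    dg = (dk_bg * Kc).sum(-1) * Ge * Bc
    return (dA.reshape(B, S, H, C), dk.reshape(B, S, H, DK), dv.reshape(B, S, H, -1),
            dbeta.reshape(B, S, H), dg.reshape(B, S, H))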
@tilelang.jit(
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True
})
def tilelang_wy_fast_bwd_split(
# task config
B,
S,
H,
DK,
DV,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
state_dtype,
chunk_size,
# kernel config
block_DK=64,
block_DV=64,
threads=128,
num_stages=0,
):
block_S = chunk_size
BS = block_S
K_shape = (B, S, H, DK)
V_shape = (B, S, H, DV)
Beta_shape = (B, S, H)
G_shape = (B, S, H)
A_shape = (B, S, H, BS)
dw_shape = (B, S, H, DK)
du_shape = (B, S, H, DV)
dk_shape = (B, S, H, DK)
dv_shape = (B, S, H, DV)
dbeta_shape = (B, S, H)
dA_shape = (B, S, H, BS)
@T.prim_func
def kernel(
# input
K: T.Tensor(K_shape, dtype=input_dtype),
V: T.Tensor(V_shape, dtype=input_dtype),
Beta: T.Tensor(Beta_shape, dtype=input_dtype),
G: T.Tensor(G_shape, dtype=gate_dtype),
A: T.Tensor(A_shape, dtype=input_dtype),
dw: T.Tensor(dw_shape, dtype=input_dtype),
du: T.Tensor(du_shape, dtype=input_dtype),
dA: T.Tensor(dA_shape, dtype=input_dtype),
dk: T.Tensor(dk_shape, dtype=output_dtype),
dv: T.Tensor(dv_shape, dtype=output_dtype),
dbeta_k: T.Tensor(dbeta_shape, dtype=output_dtype),
dg_A_positive: T.Tensor(dA_shape, dtype=gate_dtype),
dg_A_negative: T.Tensor(dA_shape, dtype=gate_dtype),
):
with T.Kernel(T.ceildiv(S, block_S), B * H, threads=threads) as (bs, bbh):
bb, bh = bbh // H, bbh % H
A_shared = T.alloc_shared((block_S, block_S), dtype=input_dtype)
A_fragment = T.alloc_fragment((block_S, block_S), dtype=accum_dtype)
K_shared = T.alloc_shared((block_S, block_DK), dtype=input_dtype)
K_shared_beta = T.alloc_shared((block_S, block_DK), dtype=input_dtype)
dA_shared = T.alloc_shared((block_S, block_S), dtype=input_dtype)
dA_fragment = T.alloc_fragment((block_S, block_S), dtype=accum_dtype)
dA_A_fragment = T.alloc_fragment((block_S, block_S), dtype=accum_dtype)
dA_A_fragment_1 = T.alloc_fragment((block_S,), dtype=accum_dtype)
dA_A_fragment_2 = T.alloc_fragment((block_S,), dtype=accum_dtype)
dk_shared = T.alloc_shared((block_S, block_DK), dtype=input_dtype)
dk_shared_beta = T.alloc_shared((block_S, block_DK), dtype=input_dtype)
dk_fragment = T.alloc_fragment((block_S, block_DK), dtype=accum_dtype)
dk_fragment_beta = T.alloc_fragment((block_S, block_DK), dtype=accum_dtype)
Beta_shared = T.alloc_shared((block_S,), dtype=input_dtype)
dbeta_fragment_reduce_tmpk = T.alloc_fragment((block_S, block_DK), dtype=accum_dtype)
dbeta_fragment_k = T.alloc_fragment((block_S,), dtype=accum_dtype)
G_shared = T.alloc_shared((block_S,), dtype=gate_dtype)
G_shared_exp = T.alloc_shared((block_S,), dtype=gate_dtype)
T.clear(dbeta_fragment_reduce_tmpk)
T.clear(dbeta_fragment_k)
T.clear(dA_A_fragment_1)
T.clear(dA_A_fragment_2)
T.copy(A[bb, bs * block_S:(bs + 1) * block_S, bh, :], A_shared)
for i_s in T.Parallel(block_S):
Beta_shared[i_s] = Beta[bb, bs * block_S + i_s, bh]
G_shared[i_s] = G[bb, bs * block_S + i_s, bh]
for i_s in T.Parallel(block_S):
G_shared_exp[i_s] = T.exp(G_shared[i_s])
# Load intermediate results
# for i_s in T.Parallel(block_S):
# dbeta_fragment[i_s] = dbeta[bb, bs * block_S + i_s, bh]
# dg_fragment[i_s] = dg[bb, bs * block_S + i_s, bh]
T.copy(dA[bb, bs * block_S:(bs + 1) * block_S, bh, :], dA_shared)
# T.copy(dA_shared, dA[bb, bs * block_S:(bs + 1) * block_S, bh, :])
# Update dA
T.copy(dA_shared, dA_fragment)
for i_s1, i_s2 in T.Parallel(block_S, block_S):
with T.If(i_s1 <= i_s2): # noqa: SIM117
with T.Then():
dA_fragment[i_s1, i_s2] = 0
T.copy(dA_fragment, dA_shared)
T.gemm(dA_shared, A_shared, dA_fragment, clear_accum=True, transpose_B=True)
T.copy(dA_fragment, dA_shared)
T.gemm(A_shared, dA_shared, dA_fragment, clear_accum=True, transpose_A=True)
for i_s1, i_s2 in T.Parallel(block_S, block_S):
with T.If(i_s1 <= i_s2):
with T.Then():
dA_fragment[i_s1, i_s2] = 0
with T.Else():
dA_fragment[i_s1, i_s2] = -dA_fragment[i_s1, i_s2]
for i_s1, i_s2 in T.Parallel(block_S, block_S):
with T.If(G[bb, bs * block_S + i_s1, bh] - G[bb, bs * block_S + i_s2, bh] <= 0):
with T.Then():
dA_fragment[i_s1, i_s2] *= T.exp(G[bb, bs * block_S + i_s1, bh] -
G[bb, bs * block_S + i_s2, bh])
with T.Else():
dA_fragment[i_s1, i_s2] = 0
T.copy(dA_fragment, dA_shared)
# acceptable dA diff
# T.copy(dA_fragment, dA[bb, bs * block_S:(bs + 1) * block_S, bh, :])
# Update dk using previous dk
T.clear(A_fragment)
for i_k in T.Pipelined(T.ceildiv(DK, block_DK), num_stages=num_stages):
T.copy(
K[bb, bs * block_S:(bs + 1) * block_S, bh, i_k * block_DK:(i_k + 1) * block_DK],
K_shared)
T.copy(
dk[bb, bs * block_S:(bs + 1) * block_S, bh,
i_k * block_DK:(i_k + 1) * block_DK], dk_shared)
T.copy(dk_shared, dk_fragment)
for i_s, i_k2 in T.Parallel(block_S, block_DK):
K_shared_beta[i_s, i_k2] = K_shared[i_s, i_k2] * Beta_shared[i_s]
T.gemm(K_shared_beta, K_shared, A_fragment, transpose_B=True)
T.gemm(dA_shared, K_shared, dk_fragment_beta, clear_accum=True)
# for i_s, i_k2 in T.Parallel(block_S, block_DK):
# dbeta_fragment[i_s] = dbeta_fragment[i_s] + dk_fragment_beta[i_s, i_k2] * K_shared[i_s, i_k2]
for i_s, i_k2 in T.Parallel(block_S, block_DK):
dbeta_fragment_reduce_tmpk[i_s,
i_k2] = dk_fragment_beta[i_s, i_k2] * K_shared[i_s,
i_k2]
T.reduce_sum(dbeta_fragment_reduce_tmpk, dbeta_fragment_k, dim=1, clear=False)
T.gemm(dA_shared, K_shared_beta, dk_fragment, transpose_A=True)
for i_s, i_k2 in T.Parallel(block_S, block_DK):
dk_shared_beta[i_s, i_k2] = dk_fragment_beta[i_s, i_k2] * Beta_shared[i_s]
for i_s, i_k2 in T.Parallel(block_S, block_DK):
dk_fragment[i_s, i_k2] = dk_fragment[i_s, i_k2] + dk_shared_beta[i_s, i_k2]
T.copy(
dk_fragment, dk[bb, bs * block_S:(bs + 1) * block_S, bh,
i_k * block_DK:(i_k + 1) * block_DK])
# Update dg and dbeta
T.copy(A_fragment, A_shared)
for i_s1, i_s2 in T.Parallel(block_S, block_S):
dA_A_fragment[i_s1, i_s2] = dA_fragment[i_s1, i_s2] * A_fragment[i_s1, i_s2]
            # Note: reduce operations are currently not supported in shared memory
            # FIXME: reduce produces an incorrect result when dim != -1
T.reduce_sum(dA_A_fragment, dA_A_fragment_1, dim=1)
T.reduce_sum(dA_A_fragment, dA_A_fragment_2, dim=0)
for i_s1, i_s2 in T.Parallel(block_S, block_S):
dg_A_positive[bb, bs * block_S + i_s1, bh, i_s2] = dA_A_fragment[i_s1, i_s2]
dg_A_negative[bb, bs * block_S + i_s2, bh, i_s1] = dA_A_fragment[i_s1, i_s2]
for i_s in T.Parallel(block_S):
dbeta_k[bb, bs * block_S + i_s, bh] = dbeta_fragment_k[i_s]
return kernel
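# Hedged reference sketch (illustration only): the per-chunk dA propagation performed
# at the top of tilelang_wy_fast_bwd_split, written for a single (batch, chunk, head)
# slice with dA and A of shape (C, C) and g of shape (C,). torch_propagate_dA_ref is
# a name we introduce here.
def torch_propagate_dA_ref(dA, A, g):
    C = dA.shape[0]
    idx = torch.arange(C, device=dA.device)
    strict_lower = idx[:, None] > idx[None, :]
    dA = torch.where(strict_lower, dA.float(), torch.zeros_like(dA, dtype=torch.float32))
    dA = A.float().t() @ (dA @ A.float().t())  # dA <- A^T @ (mask(dA) @ A^T)
    dA = torch.where(strict_lower, -dA, torch.zeros_like(dA))
    diff = g.float()[:, None] - g.float()[None, :]  # g_i - g_j
    return torch.where(diff <= 0, dA * diff.exp(), torch.zeros_like(dA))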
def run_test(
B,
S,
H,
DK,
DV,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
state_dtype,
chunk_size,
block_DK=64,
block_DV=64,
threads=128,
num_stages=0,
):
K, V, Beta, G, A, dw, du = prepare_input(B, S, H, DK, DV, chunk_size,
getattr(torch, input_dtype),
getattr(torch, output_dtype),
getattr(torch,
accum_dtype), getattr(torch, gate_dtype),
getattr(torch, state_dtype))
dk_ref, dv_ref, dbeta_ref, dg_ref = prepare_output(B, S, H, DK, DV, chunk_size,
getattr(torch, output_dtype),
getattr(torch, gate_dtype),
getattr(torch, state_dtype))
dk_tilelang, dv_tilelang, dbeta_tilelang, dg_tilelang = prepare_output(
B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype),
getattr(torch, state_dtype))
BS = chunk_size
dA_tilelang = torch.empty(B, S, H, BS, dtype=getattr(torch, input_dtype)).cuda()
dbeta_tilelang_k = torch.empty(B, S, H, dtype=getattr(torch, output_dtype)).cuda()
dg_tilelang_A_positive = torch.empty(B, S, H, BS, dtype=getattr(torch, gate_dtype)).cuda()
dg_tilelang_A_negative = torch.empty(B, S, H, BS, dtype=getattr(torch, gate_dtype)).cuda()
# ref
dk_ref, dv_ref, dbeta_ref, dg_ref = bwd_prepare_wy_repr(
K, V, G, Beta, A, dw, du, cu_seqlens=None)
# tilelang
kernel = tilelang_wy_fast_bwd(B, S, H, DK, DV, input_dtype, output_dtype, accum_dtype,
gate_dtype, state_dtype, chunk_size, block_DK, block_DV, threads,
num_stages)
dA_tilelang, dk_tilelang, dv_tilelang, dbeta_tilelang, dg_tilelang = kernel(
K, V, Beta, G, A, dw, du)
torch.cuda.synchronize()
kernel_split = tilelang_wy_fast_bwd_split(B, S, H, DK, DV, input_dtype, output_dtype,
accum_dtype, gate_dtype, state_dtype, chunk_size,
block_DK, block_DV, threads, num_stages)
kernel_split(K, V, Beta, G, A, dw, du, dA_tilelang, dk_tilelang, dv_tilelang, dbeta_tilelang_k,
dg_tilelang_A_positive, dg_tilelang_A_negative)
torch.cuda.synchronize()
dbeta_tilelang = dbeta_tilelang_k + dbeta_tilelang
dg_tilelang = dg_tilelang + dg_tilelang_A_positive.sum(dim=-1) - dg_tilelang_A_negative.sum(
dim=-1)
assert_similar(dk_ref, dk_tilelang, eps=1e-5, name="dk", raise_assert=False)
assert_similar(dv_ref, dv_tilelang, eps=1e-5, name="dv", raise_assert=False)
assert_similar(dbeta_ref, dbeta_tilelang, eps=1e-5, name="dbeta", raise_assert=False)
assert_similar(dg_ref, dg_tilelang, eps=1e-5, name="dg", raise_assert=False)
def main():
DK = 128
DV = 128
run_test(
B=1,
S=32768,
H=8,
DK=DK,
DV=DV,
input_dtype="bfloat16",
output_dtype="bfloat16",
accum_dtype="float32",
gate_dtype="float32",
state_dtype="float32",
chunk_size=64,
block_DK=32,
block_DV=32,
threads=128,
num_stages=0,
)
if __name__ == "__main__":
main()
import tilelang.testing
import torch
tilelang.disable_cache()
B = 1
S = 32768
H = 32
DK = 128
DV = 128
input_dtype = "bfloat16"
output_dtype = "bfloat16"
accum_dtype = "float32"
gate_dtype = "float32"
state_dtype = "float32"
chunk_size = 64
use_g = True
use_initial_state = True
store_final_state = True
use_final_state_gradient = True
save_new_value = True
block_DK = 64
block_DV = 32
threads = 128
num_stages = 1
def test_example_wy_fast_compilation():
from example_wy_fast import tilelang_recompute_w_u_fwd, prepare_input, prepare_output
K, V, Beta, G, A = prepare_input(
B,
S,
H,
DK,
DV,
chunk_size,
getattr(torch, input_dtype),
getattr(torch, output_dtype),
gate_dtype=getattr(torch, gate_dtype))
W_tilelang, U_tilelang = prepare_output(B, S, H, DK, DV, getattr(torch, output_dtype))
# tilelang
block_S = chunk_size
kernel = tilelang_recompute_w_u_fwd(
B,
S,
H,
DK,
DV,
input_dtype,
output_dtype,
gate_dtype,
accum_dtype,
chunk_size,
block_S=block_S,
block_DK=block_DK,
block_DV=block_DV,
threads=threads,
num_stages=num_stages)
print(kernel.get_kernel_source())
W_tilelang, U_tilelang = kernel(K, V, Beta, G, A)
def test_example_wy_fast_bwd_split_compilation():
from example_wy_fast_bwd_split import tilelang_wy_fast_bwd, tilelang_wy_fast_bwd_split, prepare_input, prepare_output
K, V, Beta, G, A, dw, du = prepare_input(B, S, H, DK, DV, chunk_size,
getattr(torch, input_dtype),
getattr(torch, output_dtype),
getattr(torch,
accum_dtype), getattr(torch, gate_dtype),
getattr(torch, state_dtype))
dk_tilelang, dv_tilelang, dbeta_tilelang, dg_tilelang = prepare_output(
B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype),
getattr(torch, state_dtype))
BS = chunk_size
dA_tilelang = torch.empty(B, S, H, BS, dtype=getattr(torch, input_dtype)).cuda()
dbeta_tilelang_k = torch.empty(B, S, H, dtype=getattr(torch, output_dtype)).cuda()
dg_tilelang_A_positive = torch.empty(B, S, H, BS, dtype=getattr(torch, gate_dtype)).cuda()
dg_tilelang_A_negative = torch.empty(B, S, H, BS, dtype=getattr(torch, gate_dtype)).cuda()
# tilelang
kernel = tilelang_wy_fast_bwd(B, S, H, DK, DV, input_dtype, output_dtype, accum_dtype,
gate_dtype, state_dtype, chunk_size, block_DK, block_DV, threads,
num_stages)
dA_tilelang, dk_tilelang, dv_tilelang, dbeta_tilelang, dg_tilelang = kernel(
K, V, Beta, G, A, dw, du)
torch.cuda.synchronize()
kernel_split = tilelang_wy_fast_bwd_split(B, S, H, DK, DV, input_dtype, output_dtype,
accum_dtype, gate_dtype, state_dtype, chunk_size,
block_DK, block_DV, threads, num_stages)
kernel_split(K, V, Beta, G, A, dw, du, dA_tilelang, dk_tilelang, dv_tilelang, dbeta_tilelang_k,
dg_tilelang_A_positive, dg_tilelang_A_negative)
torch.cuda.synchronize()
dbeta_tilelang = dbeta_tilelang_k + dbeta_tilelang
dg_tilelang = dg_tilelang + dg_tilelang_A_positive.sum(dim=-1) - dg_tilelang_A_negative.sum(
dim=-1)
def test_example_chunk_o_compilation():
from example_chunk_o import tilelang_chunk_fwd_o, prepare_input, prepare_output
Q, K, V, HIDDEN, G = prepare_input(B, S, H, DK, DV, chunk_size, getattr(torch, input_dtype),
getattr(torch, output_dtype), getattr(torch, accum_dtype),
getattr(torch, gate_dtype))
scale = 1.0 / DK**0.5
block_S = chunk_size
O_tilelang = prepare_output(B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype))
kernel = tilelang_chunk_fwd_o(B, S, H, DK, DV, input_dtype, output_dtype, accum_dtype,
gate_dtype, chunk_size, scale, use_g, block_S, block_DK, block_DV,
threads, num_stages)
O_tilelang = kernel(Q, K, V, HIDDEN, G) # noqa: F841
def test_example_chunk_o_bwd_compilation():
from example_chunk_o_bwd import tilelang_chunk_o_bwd_dqkwg, prepare_input, prepare_output
Q, K, V, h, G, dO, dh, dv, W = prepare_input(B, S, H, DK, DV, chunk_size,
getattr(torch, input_dtype),
getattr(torch, output_dtype),
getattr(torch, accum_dtype),
getattr(torch, gate_dtype),
getattr(torch, state_dtype))
dq_tilelang, dk_tilelang, dw_tilelang, dg_tilelang = prepare_output(
B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype),
getattr(torch, state_dtype), block_DK)
kernel = tilelang_chunk_o_bwd_dqkwg(B, S, H, DK, DV, input_dtype, output_dtype, accum_dtype,
gate_dtype, state_dtype, chunk_size, 1.0, use_g, True,
block_DK, block_DV, threads, num_stages)
dq_tilelang, dk_tilelang, dw_tilelang, dg_tilelang = kernel(Q, K, V, h, G, dO, dh, dv,
W) # noqa: F841
if use_g:
dg_tilelang = dg_tilelang.sum(dim=0)
def test_example_chunk_scaled_dot_kkt_compilation():
from example_chunk_scaled_dot_kkt import tilelang_chunk_scaled_dot_kkt_fwd, prepare_input, prepare_output
K, Beta, G = prepare_input(B, S, H, DK, getattr(torch, input_dtype),
getattr(torch, output_dtype), getattr(torch, accum_dtype))
A_tilelang = prepare_output(B, S, H, chunk_size, getattr(torch, output_dtype))
block_S = chunk_size
kernel = tilelang_chunk_scaled_dot_kkt_fwd(B, S, H, DK, chunk_size, input_dtype, output_dtype,
accum_dtype, use_g, block_S, block_DK, threads,
num_stages)
A_tilelang = kernel(K, Beta, G) # noqa: F841
def test_example_cumsum_compilation():
from example_cumsum import tilelang_chunk_local_cumsum_scalar, prepare_cumsum_input, prepare_cumsum_output
G = prepare_cumsum_input(B, S, H, getattr(torch, gate_dtype))
G_new_tilelang = prepare_cumsum_output(B, S, H, getattr(torch, gate_dtype))
block_S = chunk_size
kernel = tilelang_chunk_local_cumsum_scalar(
B=B,
S=S,
H=H,
chunk_size=chunk_size,
reverse=False,
head_first=False,
input_dtype=gate_dtype,
output_dtype=gate_dtype,
block_S=block_S,
threads=threads,
use_fragment=False,
)
G_new_tilelang = kernel(G) # noqa: F841
def test_example_chunk_delta_h_compilation():
from example_chunk_delta_h import tilelang_chunk_gated_delta_rule_fwd_h, prepare_input, prepare_output
K, W, U, G, initial_state = prepare_input(B, S, H, DK, DV, chunk_size,
getattr(torch, input_dtype),
getattr(torch, output_dtype),
getattr(torch, accum_dtype),
getattr(torch, gate_dtype))
h_tilelang, final_state_tilelang, V_new_tilelang = prepare_output(B, S, H, DK, DV, chunk_size,
getattr(torch, output_dtype),
getattr(torch, state_dtype))
kernel = tilelang_chunk_gated_delta_rule_fwd_h(B, S, H, DK, DV, input_dtype, output_dtype,
accum_dtype, gate_dtype, state_dtype, chunk_size,
use_g, use_initial_state, store_final_state,
save_new_value, block_DK, block_DV, threads,
num_stages)
h_tilelang, final_state_tilelang, V_new_tilelang = kernel(K, W, U, G,
initial_state) # noqa: F841
def test_example_chunk_delta_bwd_compilation():
from example_chunk_delta_bwd import tilelang_chunk_gated_delta_rule_bwd_dhu, prepare_input, prepare_output
Q, K, W, G, h0, dht, dO, dv = prepare_input(B, S, H, DK, DV, chunk_size,
getattr(torch, input_dtype),
getattr(torch, output_dtype),
getattr(torch, accum_dtype),
getattr(torch, gate_dtype),
getattr(torch, state_dtype))
dh_tilelang, dh0_tilelang, dv2_tilelang = prepare_output(B, S, H, DK, DV, chunk_size,
getattr(torch, output_dtype),
getattr(torch, gate_dtype),
getattr(torch, state_dtype))
kernel = tilelang_chunk_gated_delta_rule_bwd_dhu(B, S, H, DK, DV, input_dtype, output_dtype,
accum_dtype, gate_dtype, state_dtype,
chunk_size, 1.0, use_g, use_initial_state,
use_final_state_gradient, block_DV, threads,
num_stages)
dh_tilelang, dh0_tilelang, dv2_tilelang = kernel(Q, K, W, G, h0, dht, dO, dv) # noqa: F841
if __name__ == "__main__":
tilelang.testing.main()
import torch
def print_red_warning(message):
print(f"\033[31mWARNING: {message}\033[0m")
def calc_sim(x, y, name="tensor"):
x, y = x.data.double(), y.data.double()
denominator = (x * x + y * y).sum()
if denominator == 0:
print_red_warning(f'{name} all zero')
return 1
sim = 2 * (x * y).sum() / denominator
return sim
def assert_similar(x, y, eps=1e-8, name="tensor", data="", raise_assert=True):
x_mask = torch.isfinite(x)
y_mask = torch.isfinite(y)
if not torch.all(x_mask == y_mask):
print_red_warning(f'{name} Error: isfinite mask mismatch')
if raise_assert:
raise AssertionError
if not torch.isclose(
x.masked_fill(x_mask, 0), y.masked_fill(y_mask, 0), rtol=0, atol=0,
equal_nan=True).all():
print_red_warning(f'{name} Error: nonfinite value mismatch')
if raise_assert:
raise AssertionError
x = x.masked_fill(~x_mask, 0)
y = y.masked_fill(~y_mask, 0)
sim = calc_sim(x, y, name)
diff = 1. - sim
if not (0 <= diff <= eps):
print_red_warning(f'{name} Error: {diff}')
if raise_assert:
raise AssertionError
else:
print(f"{name} {data} passed")