Unverified Commit 667632cc authored by guchaoyang's avatar guchaoyang Committed by GitHub
Browse files

Merge branch 'main' into dcu

parents d6dd2ddf a874e4e8
...@@ -15,18 +15,12 @@ torch.random.manual_seed(0) ...@@ -15,18 +15,12 @@ torch.random.manual_seed(0)
def get_configs(): def get_configs():
block_N = [64, 128] block_N = [64, 128]
block_H = [64] block_H = [64]
num_split = [2, 4, 8] num_split = [1, 2, 4, 8]
num_stages = [1, 2, 3] num_stages = [1, 2, 3]
threads = [128] threads = [128]
_configs = list(itertools.product(block_N, block_H, num_split, num_stages, threads)) _configs = list(itertools.product(block_N, block_H, num_split, num_stages, threads))
configs = [{ configs = [{"block_N": c[0], "block_H": c[1], "num_split": c[2], "num_stages": c[3], "threads": c[4]} for c in _configs]
'block_N': c[0],
'block_H': c[1],
'num_split': c[2],
'num_stages': c[3],
'threads': c[4]
} for c in _configs]
return configs return configs
...@@ -42,29 +36,25 @@ def get_heuristic_config() -> Tuple[Dict, int]: ...@@ -42,29 +36,25 @@ def get_heuristic_config() -> Tuple[Dict, int]:
if sm_version == 89: if sm_version == 89:
cfg = dict(block_N=128, block_H=64, num_split=1, num_stages=0, threads=128) cfg = dict(block_N=128, block_H=64, num_split=1, num_stages=0, threads=128)
else: else:
cfg = dict(block_N=128, block_H=64, num_split=1, num_stages=2, threads=128) cfg = dict(block_N=128, block_H=64, num_split=8, num_stages=2, threads=128)
return cfg, sm_version return cfg, sm_version
# TODO(lei): fix warp specialized and tma lower pass # TODO(lei): fix warp specialized and tma lower pass
def get_pass_configs(): def get_pass_configs():
return { return {tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True
}
@autotune(configs=get_configs(), warmup=10, rep=10) @autotune(configs=get_configs(), warmup=10, rep=10)
@tilelang.jit(out_idx=[6], pass_configs=get_pass_configs()) @tilelang.jit(out_idx=[6], pass_configs=get_pass_configs())
def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split, num_stages, def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split, num_stages, threads):
threads): scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e)
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
shape_q = [batch, heads, dim] shape_q = [batch, heads, dim]
shape_k = [batch, seqlen_kv, groups, dim] shape_k = [batch, seqlen_kv, groups, dim]
shape_v = [batch, seqlen_kv, groups, dim] shape_v = [batch, seqlen_kv, groups, dim]
shape_o = [batch, heads, dim] shape_o = [batch, heads, dim]
dtype = "float16" dtype = T.float16
accum_dtype = "float" accum_dtype = T.float32
kv_group_num = heads // groups kv_group_num = heads // groups
part_shape = [batch, heads, num_split, dim] part_shape = [batch, heads, num_split, dim]
...@@ -73,11 +63,11 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split, ...@@ -73,11 +63,11 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split,
@T.macro @T.macro
def flash_attn( def flash_attn(
Q: T.Tensor(shape_q, dtype), Q: T.Tensor(shape_q, dtype),
K: T.Tensor(shape_k, dtype), K: T.Tensor(shape_k, dtype),
V: T.Tensor(shape_v, dtype), V: T.Tensor(shape_v, dtype),
mask: T.Tensor([batch, seqlen_kv, groups], "uint8"), mask: T.Tensor([batch, seqlen_kv, groups], "uint8"),
Output: T.Tensor([batch, heads, dim], dtype), Output: T.Tensor([batch, heads, dim], dtype),
): ):
with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz): with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_H, dim], dtype) Q_shared = T.alloc_shared([block_H, dim], dtype)
...@@ -98,23 +88,24 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split, ...@@ -98,23 +88,24 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split,
hid = by hid = by
cur_kv_head = hid // (kv_group_num // valid_block_H) cur_kv_head = hid // (kv_group_num // valid_block_H)
T.copy(Q[bid, hid * valid_block_H:hid * valid_block_H + block_H, :], Q_shared) T.copy(Q[bid, hid * valid_block_H : hid * valid_block_H + block_H, :], Q_shared)
T.fill(acc_o, 0) T.fill(acc_o, 0)
T.fill(logsum, 0) T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype)) T.fill(scores_max, -T.infinity(accum_dtype))
loop_range = T.ceildiv((seqlen_kv // num_split), block_N) loop_range = T.ceildiv((seqlen_kv // num_split), block_N)
for k in T.Pipelined(loop_range, num_stages=num_stages): for k in T.Pipelined(loop_range, num_stages=num_stages):
T.copy(K[bid, k * block_N:(k + 1) * block_N, cur_kv_head, :], K_shared) T.copy(K[bid, k * block_N : (k + 1) * block_N, cur_kv_head, :], K_shared)
T.copy(mask[bid, k * block_N:(k + 1) * block_N, cur_kv_head], mask_local) T.copy(mask[bid, k * block_N : (k + 1) * block_N, cur_kv_head], mask_local)
T.clear(acc_s) T.clear(acc_s)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
for i, j in T.Parallel(block_H, block_N): for i, j in T.Parallel(block_H, block_N):
acc_s[i, j] = T.if_then_else(mask_local[j] != 0, acc_s[i, j], acc_s[i, j] = T.if_then_else(mask_local[j] != 0, acc_s[i, j], -T.infinity(accum_dtype))
-T.infinity(accum_dtype))
T.copy(scores_max, scores_max_prev) T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype)) T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False) T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_H):
scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
for i in T.Parallel(block_H): for i in T.Parallel(block_H):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_H, block_N): for i, j in T.Parallel(block_H, block_N):
...@@ -125,23 +116,23 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split, ...@@ -125,23 +116,23 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split,
T.copy(acc_s, acc_s_cast) T.copy(acc_s, acc_s_cast)
for i, j in T.Parallel(block_H, dim): for i, j in T.Parallel(block_H, dim):
acc_o[i, j] *= scores_scale[i] acc_o[i, j] *= scores_scale[i]
T.copy(V[bid, k * block_N:(k + 1) * block_N, cur_kv_head, :], V_shared) T.copy(V[bid, k * block_N : (k + 1) * block_N, cur_kv_head, :], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
for i, j in T.Parallel(block_H, dim): for i, j in T.Parallel(block_H, dim):
acc_o[i, j] /= logsum[i] acc_o[i, j] /= logsum[i]
for i in T.Parallel(block_H): for i in T.Parallel(block_H):
logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
T.copy(acc_o[:valid_block_H, :], O_shared) T.copy(acc_o[:valid_block_H, :], O_shared)
T.copy(O_shared, Output[bid, hid * valid_block_H:(hid + 1) * valid_block_H, :]) T.copy(O_shared, Output[bid, hid * valid_block_H : (hid + 1) * valid_block_H, :])
@T.macro @T.macro
def flash_attn_split( def flash_attn_split(
Q: T.Tensor(shape_q, dtype), Q: T.Tensor(shape_q, dtype),
K: T.Tensor(shape_k, dtype), K: T.Tensor(shape_k, dtype),
V: T.Tensor(shape_v, dtype), V: T.Tensor(shape_v, dtype),
mask: T.Tensor([batch, seqlen_kv, groups], "uint8"), mask: T.Tensor([batch, seqlen_kv, groups], "uint8"),
glse: T.Tensor([batch, heads, num_split], dtype), glse: T.Tensor([batch, heads, num_split], dtype),
Output_partial: T.Tensor(part_shape, dtype), Output_partial: T.Tensor(part_shape, dtype),
): ):
with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz): with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_H, dim], dtype) Q_shared = T.alloc_shared([block_H, dim], dtype)
...@@ -163,7 +154,7 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split, ...@@ -163,7 +154,7 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split,
sid = bz sid = bz
cur_kv_head = hid // (kv_group_num // valid_block_H) cur_kv_head = hid // (kv_group_num // valid_block_H)
T.copy(Q[bid, hid * valid_block_H:hid * valid_block_H + block_H, :], Q_shared) T.copy(Q[bid, hid * valid_block_H : hid * valid_block_H + block_H, :], Q_shared)
T.fill(acc_o, 0) T.fill(acc_o, 0)
T.fill(logsum, 0) T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype)) T.fill(scores_max, -T.infinity(accum_dtype))
...@@ -172,22 +163,31 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split, ...@@ -172,22 +163,31 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split,
for k in T.Pipelined(loop_range, num_stages=num_stages): for k in T.Pipelined(loop_range, num_stages=num_stages):
T.copy( T.copy(
K[bid, (seqlen_kv // num_split) * sid + K[
k * valid_block_N:(seqlen_kv // num_split) * sid + (k + 1) * valid_block_N, bid,
cur_kv_head, :], K_shared) (seqlen_kv // num_split) * sid + k * valid_block_N : (seqlen_kv // num_split) * sid + (k + 1) * valid_block_N,
cur_kv_head,
:,
],
K_shared,
)
T.copy( T.copy(
mask[bid, (seqlen_kv // num_split) * sid + mask[
k * valid_block_N:(seqlen_kv // num_split) * sid + (k + 1) * valid_block_N, bid,
cur_kv_head], mask_local) (seqlen_kv // num_split) * sid + k * valid_block_N : (seqlen_kv // num_split) * sid + (k + 1) * valid_block_N,
cur_kv_head,
],
mask_local,
)
T.clear(acc_s) T.clear(acc_s)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
for i, j in T.Parallel(block_H, block_N): for i, j in T.Parallel(block_H, block_N):
acc_s[i, acc_s[i, j] = T.if_then_else((mask_local[j] != 0) & (j < seqlen_kv // num_split), acc_s[i, j], -T.infinity(accum_dtype))
j] = T.if_then_else((mask_local[j] != 0) & (j < seqlen_kv // num_split),
acc_s[i, j], -T.infinity(accum_dtype))
T.copy(scores_max, scores_max_prev) T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype)) T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False) T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_H):
scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
for i in T.Parallel(block_H): for i in T.Parallel(block_H):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_H, block_N): for i, j in T.Parallel(block_H, block_N):
...@@ -199,9 +199,14 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split, ...@@ -199,9 +199,14 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split,
for i, j in T.Parallel(block_H, dim): for i, j in T.Parallel(block_H, dim):
acc_o[i, j] *= scores_scale[i] acc_o[i, j] *= scores_scale[i]
T.copy( T.copy(
V[bid, (seqlen_kv // num_split) * sid + V[
k * valid_block_N:(seqlen_kv // num_split) * sid + (k + 1) * valid_block_N, bid,
cur_kv_head, :], V_shared) (seqlen_kv // num_split) * sid + k * valid_block_N : (seqlen_kv // num_split) * sid + (k + 1) * valid_block_N,
cur_kv_head,
:,
],
V_shared,
)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
for i, j in T.Parallel(block_H, dim): for i, j in T.Parallel(block_H, dim):
acc_o[i, j] /= logsum[i] acc_o[i, j] /= logsum[i]
...@@ -212,72 +217,74 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split, ...@@ -212,72 +217,74 @@ def flashattn(batch, heads, groups, seqlen_kv, dim, block_N, block_H, num_split,
if i < valid_block_H: if i < valid_block_H:
glse[bid, hid * valid_block_H + i, sid] = logsum[i] glse[bid, hid * valid_block_H + i, sid] = logsum[i]
T.copy(acc_o[:valid_block_H, :], O_shared) T.copy(acc_o[:valid_block_H, :], O_shared)
T.copy(O_shared, Output_partial[bid, hid * valid_block_H:(hid + 1) * valid_block_H, T.copy(O_shared, Output_partial[bid, hid * valid_block_H : (hid + 1) * valid_block_H, sid, :])
sid, :])
@T.macro @T.macro
def combine( def combine(
glse: T.Tensor([batch, heads, num_split], dtype), glse: T.Tensor([batch, heads, num_split], dtype),
Output_partial: T.Tensor(part_shape, dtype), Output_partial: T.Tensor(part_shape, dtype),
Output: T.Tensor(shape_o, dtype), Output: T.Tensor(shape_o, dtype),
): ):
with T.Kernel(heads, batch, threads=128) as (by, bz): with T.Kernel(heads, batch, threads=128) as (by, bz):
po_local = T.alloc_fragment([dim], dtype) po_local = T.alloc_fragment([dim], dtype)
o_accum_local = T.alloc_fragment([dim], accum_dtype) o_accum_local = T.alloc_fragment([dim], accum_dtype)
lse_local = T.alloc_fragment([num_split, 128], dtype) lse_local = T.alloc_fragment([num_split, 128], dtype)
lse_local_split = T.alloc_local([1], accum_dtype) lse_logsum_local = T.alloc_fragment([128], accum_dtype)
lse_logsum_local = T.alloc_local([1], accum_dtype)
lse_max_local = T.alloc_fragment([128], accum_dtype) lse_max_local = T.alloc_fragment([128], accum_dtype)
scale_local = T.alloc_local([1], accum_dtype) scale_local = T.alloc_fragment([128], accum_dtype)
T.annotate_layout({ T.annotate_layout(
lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i), {
lse_max_local: T.Fragment(lse_max_local.shape, forward_thread_fn=lambda i: i), lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i),
# lse_local: (local_id, thread_id) lse_max_local: T.Fragment(lse_max_local.shape, forward_thread_fn=lambda i: i),
lse_local: T.Fragment(lse_local.shape, forward_fn=lambda i, j: (j, i)), # lse_local: (local_id, thread_id)
}) lse_local: T.Fragment(lse_local.shape, forward_fn=lambda i, j: (j, i)),
}
)
T.clear(lse_logsum_local) T.clear(lse_logsum_local)
T.clear(o_accum_local) T.clear(o_accum_local)
for k, j in T.Parallel(num_split, 128): for k, j in T.Parallel(num_split, 128):
lse_local[k, j] = glse[bz, by, k] lse_local[k, j] = glse[bz, by, k]
T.reduce_max(lse_local, lse_max_local, dim=0, clear=True) T.reduce_max(lse_local, lse_max_local, dim=0, clear=True)
for k in T.Pipelined(num_split, num_stages=1): for k in T.serial(num_split):
lse_local_split[0] = glse[bz, by, k] for j in T.Parallel(128):
lse_logsum_local[0] += T.exp2(lse_local_split[0] - lse_max_local[0]) lse_logsum_local[j] += T.exp2(lse_local[k, j] - lse_max_local[j])
lse_logsum_local[0] = T.log2(lse_logsum_local[0]) + lse_max_local[0] for j in T.Parallel(128):
lse_logsum_local[j] = T.log2(lse_logsum_local[j]) + lse_max_local[j]
for k in T.serial(num_split): for k in T.serial(num_split):
for i in T.Parallel(dim): for i in T.Parallel(dim):
po_local[i] = Output_partial[bz, by, k, i] po_local[i] = Output_partial[bz, by, k, i]
lse_local_split[0] = glse[bz, by, k] for j in T.Parallel(128):
scale_local[0] = T.exp2(lse_local_split[0] - lse_logsum_local[0]) scale_local[j] = T.exp2(lse_local[k, j] - lse_logsum_local[j])
# Note: Pay attention to dim and the number of threads in Parallel
for i in T.Parallel(dim): for i in T.Parallel(dim):
o_accum_local[i] += po_local[i] * scale_local[0] o_accum_local[i] += po_local[i] * scale_local[i]
for i in T.Parallel(dim): for i in T.Parallel(dim):
Output[bz, by, i] = o_accum_local[i] Output[bz, by, i] = o_accum_local[i]
@T.prim_func @T.prim_func
def flashattn_gqa_decode_split( def flashattn_gqa_decode_split(
Q: T.Tensor(shape_q, dtype), Q: T.Tensor(shape_q, dtype),
K: T.Tensor(shape_k, dtype), K: T.Tensor(shape_k, dtype),
V: T.Tensor(shape_v, dtype), V: T.Tensor(shape_v, dtype),
mask: T.Tensor([batch, seqlen_kv, groups], "uint8"), mask: T.Tensor([batch, seqlen_kv, groups], "uint8"),
glse: T.Tensor([batch, heads, num_split], dtype), glse: T.Tensor([batch, heads, num_split], dtype),
Output_partial: T.Tensor(part_shape, dtype), Output_partial: T.Tensor(part_shape, dtype),
Output: T.Tensor(shape_o, dtype), Output: T.Tensor(shape_o, dtype),
): ):
flash_attn_split(Q, K, V, mask, glse, Output_partial) flash_attn_split(Q, K, V, mask, glse, Output_partial)
combine(glse, Output_partial, Output) combine(glse, Output_partial, Output)
@T.prim_func @T.prim_func
def flashattn_gqa_decode_no_split( def flashattn_gqa_decode_no_split(
Q: T.Tensor(shape_q, dtype), Q: T.Tensor(shape_q, dtype),
K: T.Tensor(shape_k, dtype), K: T.Tensor(shape_k, dtype),
V: T.Tensor(shape_v, dtype), V: T.Tensor(shape_v, dtype),
mask: T.Tensor([batch, seqlen_kv, groups], "uint8"), mask: T.Tensor([batch, seqlen_kv, groups], "uint8"),
glse: T.Tensor([batch, heads, num_split], dtype), glse: T.Tensor([batch, heads, num_split], dtype),
Output_partial: T.Tensor(part_shape, dtype), Output_partial: T.Tensor(part_shape, dtype),
Output: T.Tensor(shape_o, dtype), Output: T.Tensor(shape_o, dtype),
): ):
flash_attn(Q, K, V, mask, Output) flash_attn(Q, K, V, mask, Output)
...@@ -300,27 +307,21 @@ def ref_program(query, key, value, mask, glse, Output_partial): ...@@ -300,27 +307,21 @@ def ref_program(query, key, value, mask, glse, Output_partial):
dim = query.shape[-1] dim = query.shape[-1]
num_head_groups = query.shape[1] // key.shape[2] num_head_groups = query.shape[1] // key.shape[2]
scale = dim**0.5 scale = dim**0.5
key = rearrange(key, 'b n h d -> b h n d') # [batch_size, groups, seqlen_kv, dim] key = rearrange(key, "b n h d -> b h n d") # [batch_size, groups, seqlen_kv, dim]
value = rearrange(value, 'b n h d -> b h n d') # [batch_size, groups, seqlen_kv, dim] value = rearrange(value, "b n h d -> b h n d") # [batch_size, groups, seqlen_kv, dim]
query = rearrange( query = rearrange(query, "b (h g) d -> b g h d", g=num_head_groups) # [batch_size, num_head_groups, groups, dim]
query, 'b (h g) d -> b g h d',
g=num_head_groups) # [batch_size, num_head_groups, groups, dim]
scores = einsum( scores = einsum(query, key, "b g h d, b h s d -> b g h s") # [batch_size, num_head_groups, groups, seqlen_kv]
query, key,
'b g h d, b h s d -> b g h s') # [batch_size, num_head_groups, groups, seqlen_kv]
if mask is not None: if mask is not None:
mask = rearrange(mask, 'b s h -> b h s') mask = rearrange(mask, "b s h -> b h s")
mask = mask.unsqueeze(1) mask = mask.unsqueeze(1)
scores = scores.masked_fill(mask == 0, float('-inf')) scores = scores.masked_fill(mask == 0, float("-inf"))
attention = F.softmax( attention = F.softmax(scores / scale, dim=-1) # [batch_size, num_head_groups, groups, seqlen_kv]
scores / scale, dim=-1) # [batch_size, num_head_groups, groups, seqlen_kv]
out = einsum(attention, value, out = einsum(attention, value, "b g h s, b h s d -> b g h d") # [batch_size, num_head_groups, groups, dim]
'b g h s, b h s d -> b g h d') # [batch_size, num_head_groups, groups, dim] out = rearrange(out, "b g h d -> b (h g) d") # [batch_size, heads, dim]
out = rearrange(out, 'b g h d -> b (h g) d') # [batch_size, heads, dim]
return out return out
...@@ -334,16 +335,12 @@ def flash_split_ref(Q, K, V, mask): ...@@ -334,16 +335,12 @@ def flash_split_ref(Q, K, V, mask):
seqlen_kv = K.size(1) seqlen_kv = K.size(1)
num_head_groups = nheads // groups num_head_groups = nheads // groups
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e)
acc_s = torch.empty((batch, num_head_groups, groups, block_N), device="cuda", dtype=torch.float) acc_s = torch.empty((batch, num_head_groups, groups, block_N), device="cuda", dtype=torch.float)
acc_s_cast = torch.empty((batch, num_head_groups, groups, block_N), acc_s_cast = torch.empty((batch, num_head_groups, groups, block_N), device="cuda", dtype=torch.float16)
device="cuda",
dtype=torch.float16)
acc_o = torch.empty((batch, num_head_groups, groups, dim), device="cuda", dtype=torch.float) acc_o = torch.empty((batch, num_head_groups, groups, dim), device="cuda", dtype=torch.float)
scores_max = torch.empty((batch, num_head_groups, groups), device="cuda", dtype=torch.float) scores_max = torch.empty((batch, num_head_groups, groups), device="cuda", dtype=torch.float)
scores_max_prev = torch.empty((batch, num_head_groups, groups), scores_max_prev = torch.empty((batch, num_head_groups, groups), device="cuda", dtype=torch.float)
device="cuda",
dtype=torch.float)
scores_scale = torch.empty((batch, num_head_groups, groups), device="cuda", dtype=torch.float) scores_scale = torch.empty((batch, num_head_groups, groups), device="cuda", dtype=torch.float)
scores_sum = torch.empty((batch, num_head_groups, groups), device="cuda", dtype=torch.float) scores_sum = torch.empty((batch, num_head_groups, groups), device="cuda", dtype=torch.float)
logsum = torch.empty((batch, num_head_groups, groups), device="cuda", dtype=torch.float) logsum = torch.empty((batch, num_head_groups, groups), device="cuda", dtype=torch.float)
...@@ -351,25 +348,25 @@ def flash_split_ref(Q, K, V, mask): ...@@ -351,25 +348,25 @@ def flash_split_ref(Q, K, V, mask):
glogsum = torch.empty((num_split, batch, nheads), device="cuda", dtype=torch.float) glogsum = torch.empty((num_split, batch, nheads), device="cuda", dtype=torch.float)
Q_ = Q * scale Q_ = Q * scale
Q_ = rearrange(Q_, 'b (h g) d -> b g h d', g=num_head_groups) Q_ = rearrange(Q_, "b (h g) d -> b g h d", g=num_head_groups)
for ks in range(num_split): for ks in range(num_split):
acc_o.fill_(0) acc_o.fill_(0)
logsum.fill_(0) logsum.fill_(0)
scores_max.fill_(float('-inf')) scores_max.fill_(float("-inf"))
scores_max_prev.fill_(float('-inf')) scores_max_prev.fill_(float("-inf"))
for i in range(int((seqlen_kv // num_split) / block_N)): for i in range(int((seqlen_kv // num_split) / block_N)):
acc_s.fill_(0) acc_s.fill_(0)
acc_s = torch.einsum('bghd,bkhd->bghk', Q_, acc_s = torch.einsum(
K[:, (seqlen_kv // num_split) * ks + "bghd,bkhd->bghk",
i * block_N:(seqlen_kv // num_split) * ks + Q_,
(i + 1) * block_N, :, :]) # [batch, nheads, block_N] K[:, (seqlen_kv // num_split) * ks + i * block_N : (seqlen_kv // num_split) * ks + (i + 1) * block_N, :, :],
) # [batch, nheads, block_N]
if mask is not None: if mask is not None:
mask_local = mask[:, (seqlen_kv // num_split) * ks + mask_local = mask[:, (seqlen_kv // num_split) * ks + i * block_N : (seqlen_kv // num_split) * ks + (i + 1) * block_N, :]
i * block_N:(seqlen_kv // num_split) * ks + (i + 1) * block_N, :] mask_local = rearrange(mask_local, "b s h -> b h s")
mask_local = rearrange(mask_local, 'b s h -> b h s')
mask_local = mask_local.unsqueeze(1) mask_local = mask_local.unsqueeze(1)
acc_s = acc_s.masked_fill(mask_local == 0, float('-inf')) acc_s = acc_s.masked_fill(mask_local == 0, float("-inf"))
scores_max_prev = scores_max scores_max_prev = scores_max
scores_max = acc_s.max(dim=-1, keepdim=False).values # [batch, nheads] scores_max = acc_s.max(dim=-1, keepdim=False).values # [batch, nheads]
scores_scale = torch.exp2(scores_max_prev - scores_max) # [batch, nheads] scores_scale = torch.exp2(scores_max_prev - scores_max) # [batch, nheads]
...@@ -377,15 +374,16 @@ def flash_split_ref(Q, K, V, mask): ...@@ -377,15 +374,16 @@ def flash_split_ref(Q, K, V, mask):
acc_s = torch.exp2(acc_s - scores_max[:, :, :, None]) acc_s = torch.exp2(acc_s - scores_max[:, :, :, None])
acc_s_cast = acc_s.to(torch.float16) # [batch, nheads, block_N] acc_s_cast = acc_s.to(torch.float16) # [batch, nheads, block_N]
acc_o += torch.einsum( acc_o += torch.einsum(
'bghk,bkhd->bghd', acc_s_cast, "bghk,bkhd->bghd",
V[:, (seqlen_kv // num_split) * ks + i * block_N:(seqlen_kv // num_split) * ks + acc_s_cast,
(i + 1) * block_N, :, :]) V[:, (seqlen_kv // num_split) * ks + i * block_N : (seqlen_kv // num_split) * ks + (i + 1) * block_N, :, :],
)
scores_sum = acc_s.sum(dim=-1, keepdim=False) scores_sum = acc_s.sum(dim=-1, keepdim=False)
logsum = logsum * scores_scale + scores_sum logsum = logsum * scores_scale + scores_sum
acc_o_out = rearrange(acc_o, 'b g h d->b (h g) d') acc_o_out = rearrange(acc_o, "b g h d->b (h g) d")
logsum_out = rearrange(logsum, 'b g h->b (h g)') logsum_out = rearrange(logsum, "b g h->b (h g)")
acc_o_out /= logsum_out[:, :, None] acc_o_out /= logsum_out[:, :, None]
logsum_out = torch.log2(logsum_out) + rearrange(scores_max, 'b g h->b (h g)') logsum_out = torch.log2(logsum_out) + rearrange(scores_max, "b g h->b (h g)")
gacc_o[ks, :, :, :] = acc_o_out gacc_o[ks, :, :, :] = acc_o_out
glogsum[ks, :, :] = logsum_out glogsum[ks, :, :] = logsum_out
...@@ -421,7 +419,7 @@ def calc_sim(x, y, name="tensor"): ...@@ -421,7 +419,7 @@ def calc_sim(x, y, name="tensor"):
x, y = x.data.double(), y.data.double() x, y = x.data.double(), y.data.double()
denominator = (x * x + y * y).sum() denominator = (x * x + y * y).sum()
if denominator == 0: if denominator == 0:
print_red_warning(f'{name} all zero') print_red_warning(f"{name} all zero")
return 1 return 1
sim = 2 * (x * y).sum() / denominator sim = 2 * (x * y).sum() / denominator
return sim return sim
...@@ -429,28 +427,23 @@ def calc_sim(x, y, name="tensor"): ...@@ -429,28 +427,23 @@ def calc_sim(x, y, name="tensor"):
def assert_similar(x, y, eps=1e-2, name="tensor", assert_=False, print_=True): def assert_similar(x, y, eps=1e-2, name="tensor", assert_=False, print_=True):
sim = calc_sim(x, y, name) sim = calc_sim(x, y, name)
diff = 1. - sim diff = 1.0 - sim
if not (0 <= diff <= eps): if not (0 <= diff <= eps):
print_red_warning(f'{name} Error: {diff}') print_red_warning(f"{name} Error: {diff}")
if assert_: if assert_:
raise AssertionError(f'{name} Error: {diff}') raise AssertionError(f"{name} Error: {diff}")
else: else:
if print_: if print_:
print(f'passed: {name} diff={diff}') print(f"passed: {name} diff={diff}")
def main(batch: int = 1, def main(batch: int = 1, heads: int = 32, groups: int = 8, kv_seqlen: int = 8192, dim: int = 128, tune: bool = False):
heads: int = 32,
groups: int = 8,
kv_seqlen: int = 8192,
dim: int = 128,
tune: bool = False):
batch, heads, groups, kv_seqlen, dim = batch, heads, groups, kv_seqlen, dim batch, heads, groups, kv_seqlen, dim = batch, heads, groups, kv_seqlen, dim
qk_flops = 2 * batch * heads * kv_seqlen * dim qk_flops = 2 * batch * heads * kv_seqlen * dim
pv_flops = 2 * batch * heads * kv_seqlen * dim pv_flops = 2 * batch * heads * kv_seqlen * dim
total_flops = qk_flops + pv_flops total_flops = qk_flops + pv_flops
if (not tune): if not tune:
config, sm_version = get_heuristic_config() config, sm_version = get_heuristic_config()
kernel = flashattn(batch, heads, groups, kv_seqlen, dim, **config) kernel = flashattn(batch, heads, groups, kv_seqlen, dim, **config)
profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Auto) profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Auto)
...@@ -470,7 +463,7 @@ def main(batch: int = 1, ...@@ -470,7 +463,7 @@ def main(batch: int = 1,
print(o_ref) print(o_ref)
assert_similar(o, o_ref, name="o_ref") assert_similar(o, o_ref, name="o_ref")
assert_similar(o_ref_split, o_ref, name="o_ref_split") assert_similar(o, o_ref_split, name="o_ref_split")
print("All checks pass.") print("All checks pass.")
latency = profiler.do_bench(ref_program, warmup=500) latency = profiler.do_bench(ref_program, warmup=500)
...@@ -492,11 +485,11 @@ def main(batch: int = 1, ...@@ -492,11 +485,11 @@ def main(batch: int = 1,
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=1, help='batch size') parser.add_argument("--batch", type=int, default=1, help="batch size")
parser.add_argument('--heads', type=int, default=32, help='heads') parser.add_argument("--heads", type=int, default=32, help="heads")
parser.add_argument('--groups', type=int, default=8, help='groups') parser.add_argument("--groups", type=int, default=8, help="groups")
parser.add_argument('--kv_seqlen', type=int, default=8192, help='kv sequence length') parser.add_argument("--kv_seqlen", type=int, default=8192, help="kv sequence length")
parser.add_argument('--dim', type=int, default=128, help='dim') parser.add_argument("--dim", type=int, default=128, help="dim")
parser.add_argument('--tune', action='store_true', help='tune configs') parser.add_argument("--tune", action="store_true", help="tune configs")
args = parser.parse_args() args = parser.parse_args()
main(args.batch, args.heads, args.groups, args.kv_seqlen, args.dim, args.tune) main(args.batch, args.heads, args.groups, args.kv_seqlen, args.dim, args.tune)
...@@ -19,8 +19,7 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: ...@@ -19,8 +19,7 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
batch, num_key_value_heads, slen, head_dim = hidden_states.shape batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1: if n_rep == 1:
return hidden_states return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
...@@ -74,14 +73,9 @@ def _fwd_inner( ...@@ -74,14 +73,9 @@ def _fwd_inner(
return m_i, l_i, acc return m_i, l_i, acc
@triton.autotune( @triton.autotune(
configs=[ configs=[triton.Config({}, num_warps=num_warps, num_stages=num_stages) for num_warps in [4, 8] for num_stages in [2, 4]],
triton.Config({}, num_warps=num_warps, num_stages=num_stages) key=["gqa_group_size", "BLOCK_N", "BLOCK_D", "BLOCK_H"],
for num_warps in [4, 8]\
for num_stages in [2, 4]\
],
key=['gqa_group_size', 'BLOCK_N', 'BLOCK_D', 'BLOCK_H'],
) )
@triton.jit @triton.jit
def _fwd_kernel_varlen( def _fwd_kernel_varlen(
...@@ -107,13 +101,12 @@ def _fwd_kernel_varlen( ...@@ -107,13 +101,12 @@ def _fwd_kernel_varlen(
stride_od, stride_od,
stride_sb, stride_sb,
stride_sh, stride_sh,
stride_sn, #bmask shape [b, q_h, seq/BLOCK_N] stride_sn, # bmask shape [b, q_h, seq/BLOCK_N]
gqa_group_size: tl.constexpr, gqa_group_size: tl.constexpr,
BLOCK_H: tl.constexpr, BLOCK_H: tl.constexpr,
BLOCK_N: tl.constexpr, BLOCK_N: tl.constexpr,
BLOCK_D: tl.constexpr, BLOCK_D: tl.constexpr,
): ):
off_z = tl.program_id(0) off_z = tl.program_id(0)
off_h_for_kv = tl.program_id(1) off_h_for_kv = tl.program_id(1)
off_h_q = off_h_for_kv * gqa_group_size off_h_q = off_h_for_kv * gqa_group_size
...@@ -134,8 +127,7 @@ def _fwd_kernel_varlen( ...@@ -134,8 +127,7 @@ def _fwd_kernel_varlen(
S_ptrs = S + off_z * stride_sb + off_h_q * stride_sh S_ptrs = S + off_z * stride_sb + off_h_q * stride_sh
mask_h = offs_h < gqa_group_size mask_h = offs_h < gqa_group_size
q = tl.load( q = tl.load(Q_ptrs + offs_d[None, :] * stride_qd + offs_h[:, None] * stride_qh, mask=mask_h[:, None])
Q_ptrs + offs_d[None, :] * stride_qd + offs_h[:, None] * stride_qh, mask=mask_h[:, None])
if s_aux is not None: if s_aux is not None:
sink = tl.load(s_aux + off_h_q + offs_h, mask=mask_h).to(tl.float32) sink = tl.load(s_aux + off_h_q + offs_h, mask=mask_h).to(tl.float32)
...@@ -189,14 +181,12 @@ def _fwd_kernel_varlen( ...@@ -189,14 +181,12 @@ def _fwd_kernel_varlen(
acc = acc.to(O.dtype.element_ty) acc = acc.to(O.dtype.element_ty)
tl.store( tl.store(O_ptrs + offs_h[:, None] * stride_oh + offs_d[None, :] * stride_od, acc, mask=mask_h[:, None])
O_ptrs + offs_h[:, None] * stride_oh + offs_d[None, :] * stride_od,
acc,
mask=mask_h[:, None])
def get_configs(): def get_configs():
import itertools import itertools
block_N = [64, 128] block_N = [64, 128]
block_H = [64] block_H = [64]
num_split = [1] num_split = [1]
...@@ -204,38 +194,23 @@ def get_configs(): ...@@ -204,38 +194,23 @@ def get_configs():
threads = [128] threads = [128]
_configs = list(itertools.product(block_N, block_H, num_split, num_stages, threads)) _configs = list(itertools.product(block_N, block_H, num_split, num_stages, threads))
configs = [{ configs = [{"block_N": c[0], "block_H": c[1], "num_split": c[2], "num_stages": c[3], "threads": c[4]} for c in _configs]
'block_N': c[0],
'block_H': c[1],
'num_split': c[2],
'num_stages': c[3],
'threads': c[4]
} for c in _configs]
return configs return configs
@autotune(configs=get_configs(), warmup=10, rep=10) @autotune(configs=get_configs(), warmup=10, rep=10)
@tilelang.jit(out_idx=[-2, -1], debug_root_path="./examples/flash_decoding") @tilelang.jit(out_idx=[-2, -1], debug_root_path="./examples/flash_decoding")
def flashattn(batch, def flashattn(
heads, batch, heads, k_heads, max_seqlen_kv, total_seqlen_k, dim, has_sink, block_N=128, block_H=64, num_split=1, num_stages=1, threads=128
k_heads, ):
max_seqlen_kv, scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e)
total_seqlen_k,
dim,
has_sink,
block_N=128,
block_H=64,
num_split=1,
num_stages=1,
threads=128):
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
shape_q = [batch, heads, dim] shape_q = [batch, heads, dim]
shape_k = [total_seqlen_k, k_heads, dim] shape_k = [total_seqlen_k, k_heads, dim]
shape_v = [total_seqlen_k, k_heads, dim] shape_v = [total_seqlen_k, k_heads, dim]
shape_o = [batch, heads, dim] shape_o = [batch, heads, dim]
shape_s = [batch, heads, math.ceil(max_seqlen_kv / block_N)] shape_s = [batch, heads, math.ceil(max_seqlen_kv / block_N)]
dtype = "float16" dtype = T.float16
accum_dtype = "float" accum_dtype = T.float32
kv_group_num = heads // k_heads kv_group_num = heads // k_heads
valid_block_H = min(block_H, kv_group_num) valid_block_H = min(block_H, kv_group_num)
...@@ -243,13 +218,13 @@ def flashattn(batch, ...@@ -243,13 +218,13 @@ def flashattn(batch,
@T.macro @T.macro
def flash_attn( def flash_attn(
Q: T.Tensor(shape_q, dtype), Q: T.Tensor(shape_q, dtype),
K: T.Tensor(shape_k, dtype), K: T.Tensor(shape_k, dtype),
V: T.Tensor(shape_v, dtype), V: T.Tensor(shape_v, dtype),
cu_seqlens_k: T.Tensor([batch + 1], "int32"), cu_seqlens_k: T.Tensor([batch + 1], T.int32),
s_aux: T.Tensor([heads], "float32"), s_aux: T.Tensor([heads], T.float32),
Output: T.Tensor([batch, heads, dim], dtype), Output: T.Tensor([batch, heads, dim], dtype),
S: T.Tensor(shape_s, dtype), S: T.Tensor(shape_s, dtype),
): ):
with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz): with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_H, dim], dtype) Q_shared = T.alloc_shared([block_H, dim], dtype)
...@@ -266,15 +241,17 @@ def flashattn(batch, ...@@ -266,15 +241,17 @@ def flashattn(batch,
logsum = T.alloc_fragment([block_H], accum_dtype) logsum = T.alloc_fragment([block_H], accum_dtype)
S_shared = T.alloc_shared([block_H, math.ceil(max_seqlen_kv / block_N)], dtype) S_shared = T.alloc_shared([block_H, math.ceil(max_seqlen_kv / block_N)], dtype)
# S_fragment = T.alloc_fragment([block_H, math.ceil(max_seqlen_kv / block_N)], accum_dtype) # S_fragment = T.alloc_fragment([block_H, math.ceil(max_seqlen_kv / block_N)], accum_dtype)
s_aux_shared = T.alloc_shared([block_H], "float32") s_aux_shared = T.alloc_shared([block_H], T.float32)
T.annotate_layout({ T.annotate_layout(
# Q_shared: tilelang.layout.make_swizzled_layout(Q_shared), {
# K_shared: tilelang.layout.make_swizzled_layout(K_shared), # Q_shared: tilelang.layout.make_swizzled_layout(Q_shared),
# V_shared: tilelang.layout.make_swizzled_layout(V_shared), # K_shared: tilelang.layout.make_swizzled_layout(K_shared),
# O_shared: tilelang.layout.make_swizzled_layout(O_shared), # V_shared: tilelang.layout.make_swizzled_layout(V_shared),
# S_shared: tilelang.layout.make_swizzled_layout(S_shared), # O_shared: tilelang.layout.make_swizzled_layout(O_shared),
}) # S_shared: tilelang.layout.make_swizzled_layout(S_shared),
}
)
bid = bx bid = bx
hid = by hid = by
...@@ -284,7 +261,7 @@ def flashattn(batch, ...@@ -284,7 +261,7 @@ def flashattn(batch,
cur_end_k = cu_seqlens_k[bid + 1] cur_end_k = cu_seqlens_k[bid + 1]
cur_seqlen_k = cur_end_k - cur_start_k cur_seqlen_k = cur_end_k - cur_start_k
T.copy(Q[bid, hid * valid_block_H:hid * valid_block_H + block_H, :], Q_shared) T.copy(Q[bid, hid * valid_block_H : hid * valid_block_H + block_H, :], Q_shared)
T.fill(acc_o, 0) T.fill(acc_o, 0)
T.fill(logsum, 0) T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype)) T.fill(scores_max, -T.infinity(accum_dtype))
...@@ -292,15 +269,13 @@ def flashattn(batch, ...@@ -292,15 +269,13 @@ def flashattn(batch,
# loop_range = T.ceildiv((seqlen_kv // num_split), block_N) # loop_range = T.ceildiv((seqlen_kv // num_split), block_N)
loop_range = T.ceildiv((cur_seqlen_k // num_split), block_N) loop_range = T.ceildiv((cur_seqlen_k // num_split), block_N)
for k in T.Pipelined(loop_range, num_stages=num_stages): for k in T.Pipelined(loop_range, num_stages=num_stages):
T.copy(K[cur_start_k + k * block_N:cur_start_k + (k + 1) * block_N, cur_kv_head, :], T.copy(K[cur_start_k + k * block_N : cur_start_k + (k + 1) * block_N, cur_kv_head, :], K_shared)
K_shared)
T.clear(acc_s) T.clear(acc_s)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
for i, j in T.Parallel(block_H, block_N): for i, j in T.Parallel(block_H, block_N):
# acc_s[i, j] = T.if_then_else(mask_local[j] != 0 and k * block_N + j < cur_seqlen_k, acc_s[i, j], # acc_s[i, j] = T.if_then_else(mask_local[j] != 0 and k * block_N + j < cur_seqlen_k, acc_s[i, j],
# -T.infinity(accum_dtype)) # -T.infinity(accum_dtype))
acc_s[i, j] = T.if_then_else(k * block_N + j < cur_seqlen_k, acc_s[i, j], acc_s[i, j] = T.if_then_else(k * block_N + j < cur_seqlen_k, acc_s[i, j], -T.infinity(accum_dtype))
-T.infinity(accum_dtype))
T.copy(scores_max, scores_max_prev) T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype)) T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False) T.reduce_max(acc_s, scores_max, dim=1, clear=False)
...@@ -320,12 +295,11 @@ def flashattn(batch, ...@@ -320,12 +295,11 @@ def flashattn(batch,
T.copy(acc_s, acc_s_cast) T.copy(acc_s, acc_s_cast)
for i, j in T.Parallel(block_H, dim): for i, j in T.Parallel(block_H, dim):
acc_o[i, j] *= scores_scale[i] acc_o[i, j] *= scores_scale[i]
T.copy(V[cur_start_k + k * block_N:cur_start_k + (k + 1) * block_N, cur_kv_head, :], T.copy(V[cur_start_k + k * block_N : cur_start_k + (k + 1) * block_N, cur_kv_head, :], V_shared)
V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
if has_sink: if has_sink:
T.copy(s_aux[hid * valid_block_H:hid * valid_block_H + block_H], s_aux_shared) T.copy(s_aux[hid * valid_block_H : hid * valid_block_H + block_H], s_aux_shared)
for i in T.Parallel(block_H): for i in T.Parallel(block_H):
logsum[i] += s_aux_shared[i] logsum[i] += s_aux_shared[i]
for i, j in T.Parallel(block_H, dim): for i, j in T.Parallel(block_H, dim):
...@@ -338,20 +312,19 @@ def flashattn(batch, ...@@ -338,20 +312,19 @@ def flashattn(batch,
for i in T.Parallel(block_H): for i in T.Parallel(block_H):
logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
T.copy(acc_o[:valid_block_H, :], O_shared) T.copy(acc_o[:valid_block_H, :], O_shared)
T.copy(O_shared, Output[bid, hid * valid_block_H:(hid + 1) * valid_block_H, :]) T.copy(O_shared, Output[bid, hid * valid_block_H : (hid + 1) * valid_block_H, :])
# T.copy(S_fragment, S_shared) # T.copy(S_fragment, S_shared)
T.copy(S_shared[:valid_block_H, :], S[bid, T.copy(S_shared[:valid_block_H, :], S[bid, hid * valid_block_H : (hid + 1) * valid_block_H, :])
hid * valid_block_H:(hid + 1) * valid_block_H, :])
@T.prim_func @T.prim_func
def flashattn_gqa_decode_no_split( def flashattn_gqa_decode_no_split(
Q: T.Tensor(shape_q, dtype), Q: T.Tensor(shape_q, dtype),
K: T.Tensor(shape_k, dtype), K: T.Tensor(shape_k, dtype),
V: T.Tensor(shape_v, dtype), V: T.Tensor(shape_v, dtype),
cu_seqlens_k: T.Tensor([batch + 1], "int32"), cu_seqlens_k: T.Tensor([batch + 1], T.int32),
s_aux: T.Tensor([heads], "float32"), s_aux: T.Tensor([heads], T.float32),
Output: T.Tensor(shape_o, dtype), Output: T.Tensor(shape_o, dtype),
S: T.Tensor(shape_s, dtype), S: T.Tensor(shape_s, dtype),
): ):
flash_attn(Q, K, V, cu_seqlens_k, s_aux, Output, S) flash_attn(Q, K, V, cu_seqlens_k, s_aux, Output, S)
...@@ -388,9 +361,7 @@ def flash_attn_with_attn_pool_decode_tilelang( ...@@ -388,9 +361,7 @@ def flash_attn_with_attn_pool_decode_tilelang(
gqa_group_size = q_h // k_h gqa_group_size = q_h // k_h
O_tl = torch.zeros_like(Q) O_tl = torch.zeros_like(Q)
S_tl = torch.zeros((batch, q_h, math.ceil(real_max_k_seqlen / block_size)), S_tl = torch.zeros((batch, q_h, math.ceil(real_max_k_seqlen / block_size)), dtype=Q.dtype, device=Q.device)
dtype=Q.dtype,
device=Q.device)
O_tl, S_tl = tl_kernel(Q, K, V, cu_seqlens_k, s_aux) O_tl, S_tl = tl_kernel(Q, K, V, cu_seqlens_k, s_aux)
if use_per_kv_head_sparse_index: if use_per_kv_head_sparse_index:
...@@ -433,9 +404,7 @@ def flash_attn_with_attn_pool_decode( ...@@ -433,9 +404,7 @@ def flash_attn_with_attn_pool_decode(
BLOCK_H = 64 BLOCK_H = 64
O = torch.zeros_like(Q) O = torch.zeros_like(Q)
S = torch.zeros((batch, q_h, math.ceil(max_seqlen_k / block_size)), S = torch.zeros((batch, q_h, math.ceil(max_seqlen_k / block_size)), dtype=Q.dtype, device=Q.device)
dtype=Q.dtype,
device=Q.device)
def grid(META): def grid(META):
return (batch, k_h) return (batch, k_h)
...@@ -480,18 +449,18 @@ def test_equal_seqlen_decode_main(args): ...@@ -480,18 +449,18 @@ def test_equal_seqlen_decode_main(args):
real_max_k_seqlen = args.k_seqlen real_max_k_seqlen = args.k_seqlen
head_size = args.head_size head_size = args.head_size
block_size = args.block_size block_size = args.block_size
dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16 dtype = torch.bfloat16 if args.dtype == T.bfloat16 else torch.float16
# For decode, query is just 1 token per batch # For decode, query is just 1 token per batch
q = torch.randn(batch_size, q_heads, head_size, device='cuda', dtype=dtype) q = torch.randn(batch_size, q_heads, head_size, device="cuda", dtype=dtype)
k = torch.randn(batch_size, kv_heads, k_seqlen, head_size, device='cuda', dtype=dtype) k = torch.randn(batch_size, kv_heads, k_seqlen, head_size, device="cuda", dtype=dtype)
v = torch.randn(batch_size, kv_heads, k_seqlen, head_size, device='cuda', dtype=dtype) v = torch.randn(batch_size, kv_heads, k_seqlen, head_size, device="cuda", dtype=dtype)
softmax_scale = 1.0 / math.sqrt(head_size) softmax_scale = 1.0 / math.sqrt(head_size)
# Generate sink values if needed # Generate sink values if needed
sink = None sink = None
if args.test_sink: if args.test_sink:
sink = torch.randn(q_heads, device='cuda', dtype=torch.float32) * 0.1 # Small sink values sink = torch.randn(q_heads, device="cuda", dtype=torch.float32) * 0.1 # Small sink values
print(f"Using sink attention with sink values: {sink}") print(f"Using sink attention with sink values: {sink}")
# Convert to varlen format for K, V # Convert to varlen format for K, V
...@@ -499,8 +468,7 @@ def test_equal_seqlen_decode_main(args): ...@@ -499,8 +468,7 @@ def test_equal_seqlen_decode_main(args):
v_varlen = v.transpose(1, 2).reshape(batch_size * k_seqlen, kv_heads, head_size) v_varlen = v.transpose(1, 2).reshape(batch_size * k_seqlen, kv_heads, head_size)
# Generate cumulative sequence lengths # Generate cumulative sequence lengths
cu_seqlens_k = torch.arange( cu_seqlens_k = torch.arange(0, (batch_size + 1) * k_seqlen, k_seqlen, device="cuda", dtype=torch.int32)
0, (batch_size + 1) * k_seqlen, k_seqlen, device='cuda', dtype=torch.int32)
max_seqlen_k = k_seqlen max_seqlen_k = k_seqlen
print(f"q shape: {q.shape}") print(f"q shape: {q.shape}")
...@@ -510,8 +478,7 @@ def test_equal_seqlen_decode_main(args): ...@@ -510,8 +478,7 @@ def test_equal_seqlen_decode_main(args):
num_tokens, q_h, head_size = q.shape num_tokens, q_h, head_size = q.shape
batch = cu_seqlens_k.size(0) - 1 batch = cu_seqlens_k.size(0) - 1
k_h = k_varlen.size(1) k_h = k_varlen.size(1)
tl_kernel = flashattn(batch, q_h, k_h, args.k_seqlen, cu_seqlens_k[-1].item(), head_size, tl_kernel = flashattn(batch, q_h, k_h, args.k_seqlen, cu_seqlens_k[-1].item(), head_size, args.test_sink)
args.test_sink)
# Test our decode kernel # Test our decode kernel
O_triton, S_triton = flash_attn_with_attn_pool_decode( O_triton, S_triton = flash_attn_with_attn_pool_decode(
...@@ -524,7 +491,8 @@ def test_equal_seqlen_decode_main(args): ...@@ -524,7 +491,8 @@ def test_equal_seqlen_decode_main(args):
args.num_split, args.num_split,
softmax_scale, softmax_scale,
s_aux=sink, s_aux=sink,
block_size=block_size) block_size=block_size,
)
O_tilelang, S_tilelang = flash_attn_with_attn_pool_decode_tilelang( O_tilelang, S_tilelang = flash_attn_with_attn_pool_decode_tilelang(
q, q,
k_varlen, k_varlen,
...@@ -539,9 +507,7 @@ def test_equal_seqlen_decode_main(args): ...@@ -539,9 +507,7 @@ def test_equal_seqlen_decode_main(args):
tl_kernel=tl_kernel, tl_kernel=tl_kernel,
) )
for i in range(batch_size): for i in range(batch_size):
S_tilelang[i, :, S_tilelang[i, :, math.ceil((cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item()) / block_size) :] = 0
math.ceil((cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item()) /
block_size):] = 0
# Compute torch reference # Compute torch reference
q_expanded = q.unsqueeze(2) # [b, q_heads, 1, head_size] q_expanded = q.unsqueeze(2) # [b, q_heads, 1, head_size]
...@@ -550,14 +516,12 @@ def test_equal_seqlen_decode_main(args): ...@@ -550,14 +516,12 @@ def test_equal_seqlen_decode_main(args):
if sink is None: if sink is None:
# Standard scaled dot-product attention # Standard scaled dot-product attention
logits = torch.matmul(q_expanded, k_repeat.transpose( logits = torch.matmul(q_expanded, k_repeat.transpose(-2, -1)) * softmax_scale # [batch, q_heads, 1, seqlen_k]
-2, -1)) * softmax_scale # [batch, q_heads, 1, seqlen_k]
attn_weights = torch.softmax(logits, dim=-1) attn_weights = torch.softmax(logits, dim=-1)
O_torch = torch.matmul(attn_weights, v_repeat).squeeze(2) # [batch, q_heads, head_size] O_torch = torch.matmul(attn_weights, v_repeat).squeeze(2) # [batch, q_heads, head_size]
else: else:
# s_aux attention # s_aux attention
logits = torch.matmul(q_expanded, k_repeat.transpose( logits = torch.matmul(q_expanded, k_repeat.transpose(-2, -1)) * softmax_scale # [batch, q_heads, 1, seqlen_k]
-2, -1)) * softmax_scale # [batch, q_heads, 1, seqlen_k]
sink_expanded = sink.view(1, q_heads, 1, 1) # [1, q_heads, 1, 1] sink_expanded = sink.view(1, q_heads, 1, 1) # [1, q_heads, 1, 1]
logits_max = torch.max(logits, dim=-1, keepdim=True).values logits_max = torch.max(logits, dim=-1, keepdim=True).values
...@@ -566,15 +530,15 @@ def test_equal_seqlen_decode_main(args): ...@@ -566,15 +530,15 @@ def test_equal_seqlen_decode_main(args):
unnormalized_scores = torch.exp(logits - logits_or_sinks_max) unnormalized_scores = torch.exp(logits - logits_or_sinks_max)
normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + sinks normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + sinks
attn_weights = unnormalized_scores / normalizer attn_weights = unnormalized_scores / normalizer
O_torch = torch.matmul(attn_weights.to(v_repeat.dtype), O_torch = torch.matmul(attn_weights.to(v_repeat.dtype), v_repeat).squeeze(2) # [batch, q_heads, head_size]
v_repeat).squeeze(2) # [batch, q_heads, head_size]
# Compute attention score pooling # Compute attention score pooling
attn_score_pooled = torch.max_pool2d( attn_score_pooled = torch.max_pool2d(
attn_weights.squeeze(2), # [b, q_heads, k_seqlen] attn_weights.squeeze(2), # [b, q_heads, k_seqlen]
kernel_size=(q_heads, block_size), kernel_size=(q_heads, block_size),
stride=(q_heads, block_size), stride=(q_heads, block_size),
ceil_mode=True).to(torch.float16) ceil_mode=True,
).to(torch.float16)
print("S_tilelang", S_tilelang) print("S_tilelang", S_tilelang)
print("attn_score_pooled", attn_score_pooled) print("attn_score_pooled", attn_score_pooled)
...@@ -588,15 +552,10 @@ def test_equal_seqlen_decode_main(args): ...@@ -588,15 +552,10 @@ def test_equal_seqlen_decode_main(args):
print(f"Max difference in S: {max_diff_s.item()}") print(f"Max difference in S: {max_diff_s.item()}")
print(f"Max difference in O_tilelang: {max_diff_o_tilelang.item()}") print(f"Max difference in O_tilelang: {max_diff_o_tilelang.item()}")
print(f"Max difference in S_tilelang: {max_diff_s_tilelang.item()}") print(f"Max difference in S_tilelang: {max_diff_s_tilelang.item()}")
assert torch.allclose( assert torch.allclose(O_triton, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o.item()}"
O_triton, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o.item()}" assert torch.allclose(S_triton, attn_score_pooled, atol=1e-2, rtol=1e-2), f"Score mismatch: {max_diff_s.item()}"
assert torch.allclose( assert torch.allclose(O_tilelang, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o_tilelang.item()}"
S_triton, attn_score_pooled, atol=1e-2, rtol=1e-2), f"Score mismatch: {max_diff_s.item()}" assert torch.allclose(S_tilelang, attn_score_pooled, atol=1e-2, rtol=1e-2), f"Score mismatch: {max_diff_s_tilelang.item()}"
assert torch.allclose(
O_tilelang, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o_tilelang.item()}"
assert torch.allclose(
S_tilelang, attn_score_pooled, atol=1e-2,
rtol=1e-2), f"Score mismatch: {max_diff_s_tilelang.item()}"
print("✅ All tests passed!") print("✅ All tests passed!")
...@@ -609,14 +568,14 @@ def test_varlen_decode_main(args): ...@@ -609,14 +568,14 @@ def test_varlen_decode_main(args):
real_max_k_seqlen = args.k_seqlen real_max_k_seqlen = args.k_seqlen
head_size = args.head_size head_size = args.head_size
block_size = args.block_size block_size = args.block_size
dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16 dtype = torch.bfloat16 if args.dtype == T.bfloat16 else torch.float16
print(f"Testing decode kernel with variable sequence lengths (max_k_seqlen={max_k_seqlen})") print(f"Testing decode kernel with variable sequence lengths (max_k_seqlen={max_k_seqlen})")
# Generate sink values if needed # Generate sink values if needed
sink = None sink = None
if args.test_sink: if args.test_sink:
sink = torch.randn(q_heads, device='cuda', dtype=torch.float32) * 0.1 # Small sink values sink = torch.randn(q_heads, device="cuda", dtype=torch.float32) * 0.1 # Small sink values
print(f"Using sink attention with sink values: {sink}") print(f"Using sink attention with sink values: {sink}")
# Generate variable length k sequences # Generate variable length k sequences
...@@ -624,7 +583,7 @@ def test_varlen_decode_main(args): ...@@ -624,7 +583,7 @@ def test_varlen_decode_main(args):
print(f"k_seqlens: {k_seqlens}") print(f"k_seqlens: {k_seqlens}")
# Generate cumulative sequence lengths for k # Generate cumulative sequence lengths for k
cu_seqlens_k = torch.zeros(batch_size + 1, device='cuda', dtype=torch.int32) cu_seqlens_k = torch.zeros(batch_size + 1, device="cuda", dtype=torch.int32)
total_k_tokens = 0 total_k_tokens = 0
for i in range(batch_size): for i in range(batch_size):
cu_seqlens_k[i] = total_k_tokens cu_seqlens_k[i] = total_k_tokens
...@@ -634,9 +593,9 @@ def test_varlen_decode_main(args): ...@@ -634,9 +593,9 @@ def test_varlen_decode_main(args):
print(f"cu_seqlens_k: {cu_seqlens_k}") print(f"cu_seqlens_k: {cu_seqlens_k}")
# Generate tensors - Q is [batch_size, q_heads, head_size] for decode # Generate tensors - Q is [batch_size, q_heads, head_size] for decode
q_decode = torch.randn(batch_size, q_heads, head_size, device='cuda', dtype=dtype) q_decode = torch.randn(batch_size, q_heads, head_size, device="cuda", dtype=dtype)
k_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device='cuda', dtype=dtype) k_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device="cuda", dtype=dtype)
v_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device='cuda', dtype=dtype) v_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device="cuda", dtype=dtype)
softmax_scale = 1.0 / math.sqrt(head_size) softmax_scale = 1.0 / math.sqrt(head_size)
max_seqlen_k = int(k_seqlens.max()) max_seqlen_k = int(k_seqlens.max())
...@@ -649,8 +608,7 @@ def test_varlen_decode_main(args): ...@@ -649,8 +608,7 @@ def test_varlen_decode_main(args):
num_tokens, q_h, head_size = q_decode.shape num_tokens, q_h, head_size = q_decode.shape
batch = cu_seqlens_k.size(0) - 1 batch = cu_seqlens_k.size(0) - 1
k_h = k_varlen.size(1) k_h = k_varlen.size(1)
tl_kernel = flashattn(batch, q_h, k_h, args.k_seqlen, cu_seqlens_k[-1].item(), head_size, tl_kernel = flashattn(batch, q_h, k_h, args.k_seqlen, cu_seqlens_k[-1].item(), head_size, args.test_sink)
args.test_sink)
# Test our decode kernel # Test our decode kernel
O_triton, S_triton = flash_attn_with_attn_pool_decode( O_triton, S_triton = flash_attn_with_attn_pool_decode(
...@@ -663,7 +621,8 @@ def test_varlen_decode_main(args): ...@@ -663,7 +621,8 @@ def test_varlen_decode_main(args):
args.num_split, args.num_split,
softmax_scale, softmax_scale,
s_aux=sink, s_aux=sink,
block_size=block_size) block_size=block_size,
)
O_tilelang, S_tilelang = flash_attn_with_attn_pool_decode_tilelang( O_tilelang, S_tilelang = flash_attn_with_attn_pool_decode_tilelang(
q_decode, q_decode,
k_varlen, k_varlen,
...@@ -678,9 +637,7 @@ def test_varlen_decode_main(args): ...@@ -678,9 +637,7 @@ def test_varlen_decode_main(args):
tl_kernel=tl_kernel, tl_kernel=tl_kernel,
) )
for i in range(batch_size): for i in range(batch_size):
S_tilelang[i, :, S_tilelang[i, :, math.ceil((cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item()) / block_size) :] = 0
math.ceil((cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item()) /
block_size):] = 0
# Create torch reference - pad tensors for comparison # Create torch reference - pad tensors for comparison
k_padded_list = [] k_padded_list = []
...@@ -694,8 +651,8 @@ def test_varlen_decode_main(args): ...@@ -694,8 +651,8 @@ def test_varlen_decode_main(args):
k_end = cu_seqlens_k[i + 1] k_end = cu_seqlens_k[i + 1]
# Pad to max_seqlen_k # Pad to max_seqlen_k
k_padded = torch.zeros(max_seqlen_k, kv_heads, head_size, device='cuda', dtype=dtype) k_padded = torch.zeros(max_seqlen_k, kv_heads, head_size, device="cuda", dtype=dtype)
v_padded = torch.zeros(max_seqlen_k, kv_heads, head_size, device='cuda', dtype=dtype) v_padded = torch.zeros(max_seqlen_k, kv_heads, head_size, device="cuda", dtype=dtype)
k_padded[:actual_k_len] = k_varlen[k_start:k_end] k_padded[:actual_k_len] = k_varlen[k_start:k_end]
v_padded[:actual_k_len] = v_varlen[k_start:k_end] v_padded[:actual_k_len] = v_varlen[k_start:k_end]
...@@ -704,10 +661,8 @@ def test_varlen_decode_main(args): ...@@ -704,10 +661,8 @@ def test_varlen_decode_main(args):
v_padded_list.append(v_padded) v_padded_list.append(v_padded)
# Stack to create batched tensors [b, max_seqlen, kv_heads, head_size] # Stack to create batched tensors [b, max_seqlen, kv_heads, head_size]
k_padded_batched = torch.stack( k_padded_batched = torch.stack(k_padded_list, dim=0).transpose(1, 2) # [b, kv_heads, max_seqlen, head_size]
k_padded_list, dim=0).transpose(1, 2) # [b, kv_heads, max_seqlen, head_size] v_padded_batched = torch.stack(v_padded_list, dim=0).transpose(1, 2) # [b, kv_heads, max_seqlen, head_size]
v_padded_batched = torch.stack(
v_padded_list, dim=0).transpose(1, 2) # [b, kv_heads, max_seqlen, head_size]
# Expand q to match kv heads: [b, q_heads, 1, head_size] # Expand q to match kv heads: [b, q_heads, 1, head_size]
q_expanded = q_decode.unsqueeze(2) # [b, q_heads, 1, head_size] q_expanded = q_decode.unsqueeze(2) # [b, q_heads, 1, head_size]
...@@ -717,20 +672,17 @@ def test_varlen_decode_main(args): ...@@ -717,20 +672,17 @@ def test_varlen_decode_main(args):
print(f"v_padded_batched shape: {v_padded_batched.shape}") print(f"v_padded_batched shape: {v_padded_batched.shape}")
# Compute torch reference # Compute torch reference
k_repeat = repeat_kv(k_padded_batched, k_repeat = repeat_kv(k_padded_batched, q_heads // kv_heads) # [b, q_heads, max_seqlen, head_size]
q_heads // kv_heads) # [b, q_heads, max_seqlen, head_size] v_repeat = repeat_kv(v_padded_batched, q_heads // kv_heads) # [b, q_heads, max_seqlen, head_size]
v_repeat = repeat_kv(v_padded_batched,
q_heads // kv_heads) # [b, q_heads, max_seqlen, head_size]
if sink is None: if sink is None:
# Standard attention computation: [b, q_heads, 1, head_size] @ [b, q_heads, head_size, max_seqlen] # Standard attention computation: [b, q_heads, 1, head_size] @ [b, q_heads, head_size, max_seqlen]
attn_score = torch.matmul(q_expanded, k_repeat.transpose( attn_score = torch.matmul(q_expanded, k_repeat.transpose(-2, -1)) * softmax_scale # [b, q_heads, 1, max_seqlen]
-2, -1)) * softmax_scale # [b, q_heads, 1, max_seqlen]
# Apply sequence length masking # Apply sequence length masking
for i in range(batch_size): for i in range(batch_size):
actual_k_len = k_seqlens[i] actual_k_len = k_seqlens[i]
attn_score[i, :, :, actual_k_len:] = float('-inf') attn_score[i, :, :, actual_k_len:] = float("-inf")
attn_weights = attn_score.softmax(dim=-1) # [b, q_heads, 1, max_seqlen] attn_weights = attn_score.softmax(dim=-1) # [b, q_heads, 1, max_seqlen]
...@@ -743,13 +695,12 @@ def test_varlen_decode_main(args): ...@@ -743,13 +695,12 @@ def test_varlen_decode_main(args):
O_torch = torch.matmul(attn_weights, v_repeat) # [b, q_heads, 1, head_size] O_torch = torch.matmul(attn_weights, v_repeat) # [b, q_heads, 1, head_size]
else: else:
# s_aux attention # s_aux attention
logits = torch.matmul(q_expanded, k_repeat.transpose( logits = torch.matmul(q_expanded, k_repeat.transpose(-2, -1)) * softmax_scale # [b, q_heads, 1, max_seqlen]
-2, -1)) * softmax_scale # [b, q_heads, 1, max_seqlen]
# Apply sequence length masking # Apply sequence length masking
for i in range(batch_size): for i in range(batch_size):
actual_k_len = k_seqlens[i] actual_k_len = k_seqlens[i]
logits[i, :, :, actual_k_len:] = float('-inf') logits[i, :, :, actual_k_len:] = float("-inf")
sink_expanded = sink.view(1, q_heads, 1, 1) # [1, q_heads, 1, 1] sink_expanded = sink.view(1, q_heads, 1, 1) # [1, q_heads, 1, 1]
logits_max = torch.max(logits, dim=-1, keepdim=True).values logits_max = torch.max(logits, dim=-1, keepdim=True).values
...@@ -765,8 +716,7 @@ def test_varlen_decode_main(args): ...@@ -765,8 +716,7 @@ def test_varlen_decode_main(args):
attn_weights[i, :, :, actual_k_len:] = 0.0 attn_weights[i, :, :, actual_k_len:] = 0.0
# Compute output: [b, q_heads, 1, max_seqlen] @ [b, q_heads, max_seqlen, head_size] # Compute output: [b, q_heads, 1, max_seqlen] @ [b, q_heads, max_seqlen, head_size]
O_torch = torch.matmul(attn_weights.to(v_repeat.dtype), O_torch = torch.matmul(attn_weights.to(v_repeat.dtype), v_repeat) # [b, q_heads, 1, head_size]
v_repeat) # [b, q_heads, 1, head_size]
O_torch = O_torch.squeeze(2) # [b, q_heads, head_size] O_torch = O_torch.squeeze(2) # [b, q_heads, head_size]
...@@ -775,7 +725,8 @@ def test_varlen_decode_main(args): ...@@ -775,7 +725,8 @@ def test_varlen_decode_main(args):
attn_weights.squeeze(2), # [b, q_heads, max_seqlen] attn_weights.squeeze(2), # [b, q_heads, max_seqlen]
kernel_size=(q_heads, block_size), kernel_size=(q_heads, block_size),
stride=(q_heads, block_size), stride=(q_heads, block_size),
ceil_mode=True).to(dtype=torch.float16) # [b, 1, ceil(max_seqlen/block_size)] ceil_mode=True,
).to(dtype=torch.float16) # [b, 1, ceil(max_seqlen/block_size)]
print(f"O_triton shape: {O_triton.shape}") print(f"O_triton shape: {O_triton.shape}")
print(f"O_tilelang shape: {O_tilelang.shape}") print(f"O_tilelang shape: {O_tilelang.shape}")
...@@ -791,22 +742,16 @@ def test_varlen_decode_main(args): ...@@ -791,22 +742,16 @@ def test_varlen_decode_main(args):
print(f"Max difference in O_tilelang: {max_diff_o_tl.item()}") print(f"Max difference in O_tilelang: {max_diff_o_tl.item()}")
max_diff_s = torch.max(torch.abs(S_triton - attn_score_pooled)) max_diff_s = torch.max(torch.abs(S_triton - attn_score_pooled))
max_diff_s_tl = torch.max( max_diff_s_tl = torch.max(torch.abs(S_tilelang[:, :, : math.ceil(max_seqlen_k / block_size)] - attn_score_pooled))
torch.abs(S_tilelang[:, :, :math.ceil(max_seqlen_k / block_size)] - attn_score_pooled))
print(f"Max difference in S: {max_diff_s.item()}") print(f"Max difference in S: {max_diff_s.item()}")
print(f"Max difference in S_tilelang: {max_diff_s_tl.item()}") print(f"Max difference in S_tilelang: {max_diff_s_tl.item()}")
assert torch.allclose( assert torch.allclose(O_triton, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o.item()}"
O_triton, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o.item()}" assert torch.allclose(S_triton, attn_score_pooled, atol=1e-2, rtol=1e-2), f"Score mismatch: {max_diff_s.item()}"
assert torch.allclose( assert torch.allclose(O_tilelang, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o_tl.item()}"
S_triton, attn_score_pooled, atol=1e-2, rtol=1e-2), f"Score mismatch: {max_diff_s.item()}" assert torch.allclose(S_tilelang[:, :, : math.ceil(max_seqlen_k / block_size)], attn_score_pooled, atol=1e-2, rtol=1e-2), (
assert torch.allclose( f"Score mismatch: {max_diff_s_tl.item()}"
O_tilelang, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o_tl.item()}" )
assert torch.allclose(
S_tilelang[:, :, :math.ceil(max_seqlen_k / block_size)],
attn_score_pooled,
atol=1e-2,
rtol=1e-2), f"Score mismatch: {max_diff_s_tl.item()}"
print("✅ All tests passed!") print("✅ All tests passed!")
...@@ -844,7 +789,7 @@ def speed_benchmark_decode_comparison(args): ...@@ -844,7 +789,7 @@ def speed_benchmark_decode_comparison(args):
max_k_seqlen = args.k_seqlen max_k_seqlen = args.k_seqlen
head_size = args.head_size head_size = args.head_size
block_size = args.block_size block_size = args.block_size
dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16 dtype = torch.bfloat16 if args.dtype == T.bfloat16 else torch.float16
print("\n=== Decode Speed Benchmark Comparison ===") print("\n=== Decode Speed Benchmark Comparison ===")
print("Configuration:") print("Configuration:")
...@@ -865,7 +810,7 @@ def speed_benchmark_decode_comparison(args): ...@@ -865,7 +810,7 @@ def speed_benchmark_decode_comparison(args):
k_seqlens = torch.full((batch_size,), max_k_seqlen, dtype=int) k_seqlens = torch.full((batch_size,), max_k_seqlen, dtype=int)
# Generate cumulative sequence lengths for k # Generate cumulative sequence lengths for k
cu_seqlens_k = torch.zeros(batch_size + 1, device='cuda', dtype=torch.int32) cu_seqlens_k = torch.zeros(batch_size + 1, device="cuda", dtype=torch.int32)
total_k_tokens = 0 total_k_tokens = 0
for i in range(batch_size): for i in range(batch_size):
cu_seqlens_k[i] = total_k_tokens cu_seqlens_k[i] = total_k_tokens
...@@ -873,9 +818,9 @@ def speed_benchmark_decode_comparison(args): ...@@ -873,9 +818,9 @@ def speed_benchmark_decode_comparison(args):
cu_seqlens_k[batch_size] = total_k_tokens cu_seqlens_k[batch_size] = total_k_tokens
# Generate tensors # Generate tensors
q_decode = torch.randn(batch_size, q_heads, head_size, device='cuda', dtype=dtype) q_decode = torch.randn(batch_size, q_heads, head_size, device="cuda", dtype=dtype)
k_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device='cuda', dtype=dtype) k_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device="cuda", dtype=dtype)
v_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device='cuda', dtype=dtype) v_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device="cuda", dtype=dtype)
softmax_scale = 1.0 / math.sqrt(head_size) softmax_scale = 1.0 / math.sqrt(head_size)
max_seqlen_k = int(k_seqlens.max()) max_seqlen_k = int(k_seqlens.max())
...@@ -883,7 +828,7 @@ def speed_benchmark_decode_comparison(args): ...@@ -883,7 +828,7 @@ def speed_benchmark_decode_comparison(args):
# Generate sink values if needed # Generate sink values if needed
sink = None sink = None
if args.test_sink: if args.test_sink:
sink = torch.randn(q_heads, device='cuda', dtype=torch.float32) * 0.1 # Small sink values sink = torch.randn(q_heads, device="cuda", dtype=torch.float32) * 0.1 # Small sink values
print(" Using sink attention with sink values") print(" Using sink attention with sink values")
print("Setup complete:") print("Setup complete:")
...@@ -896,8 +841,7 @@ def speed_benchmark_decode_comparison(args): ...@@ -896,8 +841,7 @@ def speed_benchmark_decode_comparison(args):
num_tokens, q_h, head_size = q_decode.shape num_tokens, q_h, head_size = q_decode.shape
batch = cu_seqlens_k.size(0) - 1 batch = cu_seqlens_k.size(0) - 1
k_h = k_varlen.size(1) k_h = k_varlen.size(1)
tl_kernel = flashattn(batch, q_h, k_h, args.k_seqlen, cu_seqlens_k[-1].item(), head_size, tl_kernel = flashattn(batch, q_h, k_h, args.k_seqlen, cu_seqlens_k[-1].item(), head_size, args.test_sink)
args.test_sink)
# Benchmark # Benchmark
print("⚡ Benchmarking Tilelang kernel (100 iterations)...") print("⚡ Benchmarking Tilelang kernel (100 iterations)...")
...@@ -920,36 +864,41 @@ def speed_benchmark_decode_comparison(args): ...@@ -920,36 +864,41 @@ def speed_benchmark_decode_comparison(args):
# Benchmark # Benchmark
print("⚡ Benchmarking Triton kernel (100 iterations)...") print("⚡ Benchmarking Triton kernel (100 iterations)...")
triton_time = do_bench(flash_attn_with_attn_pool_decode, q_decode, k_varlen, v_varlen, triton_time = do_bench(
cu_seqlens_k, max_seqlen_k, args.k_seqlen, 1, softmax_scale, sink, flash_attn_with_attn_pool_decode,
block_size) q_decode,
k_varlen,
v_varlen,
cu_seqlens_k,
max_seqlen_k,
args.k_seqlen,
1,
softmax_scale,
sink,
block_size,
)
print(f"Average decode kernel time Triton: {triton_time:.3f} ms") print(f"Average decode kernel time Triton: {triton_time:.3f} ms")
print(f"Speedup: {(triton_time / tilelang_time):.3f}") print(f"Speedup: {(triton_time / tilelang_time):.3f}")
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser(description='Flash Attention Decode with Attention Pooling') parser = argparse.ArgumentParser(description="Flash Attention Decode with Attention Pooling")
parser.add_argument('--batch_size', type=int, default=1, help='Batch size') parser.add_argument("--batch_size", type=int, default=1, help="Batch size")
parser.add_argument('--q_heads', type=int, default=32, help='Number of query heads') parser.add_argument("--q_heads", type=int, default=32, help="Number of query heads")
parser.add_argument('--kv_heads', type=int, default=8, help='Number of key-value heads') parser.add_argument("--kv_heads", type=int, default=8, help="Number of key-value heads")
parser.add_argument('--k_seqlen', type=int, default=8192, help='Key sequence length') parser.add_argument("--k_seqlen", type=int, default=8192, help="Key sequence length")
parser.add_argument( parser.add_argument("--head_size", type=int, default=128, choices=[64, 128, 256], help="Head dimension")
'--head_size', type=int, default=128, choices=[64, 128, 256], help='Head dimension') parser.add_argument("--block_size", type=int, default=64, help="Block size for computation")
parser.add_argument('--block_size', type=int, default=64, help='Block size for computation') parser.add_argument("--dtype", type=str, default=T.bfloat16, choices=[T.float16, T.bfloat16], help="Data type")
parser.add_argument( parser.add_argument("--test_varlen", action="store_true", help="Test with truly variable sequence lengths")
'--dtype', type=str, default='bfloat16', choices=['float16', 'bfloat16'], help='Data type') parser.add_argument("--test_sink", action="store_true", help="Test with sink attention mechanism")
parser.add_argument( parser.add_argument("--benchmark", action="store_true", help="Run speed benchmark")
'--test_varlen', action='store_true', help='Test with truly variable sequence lengths') parser.add_argument("--num_split", type=int, default=1, choices=[1, 16], help="Number of splits")
parser.add_argument(
'--test_sink', action='store_true', help='Test with sink attention mechanism')
parser.add_argument('--benchmark', action='store_true', help='Run speed benchmark')
parser.add_argument(
'--num_split', type=int, default=1, choices=[1, 16], help='Number of splits')
args = parser.parse_args() args = parser.parse_args()
args.test_sink = True args.test_sink = True
args.test_varlen = False args.test_varlen = False
args.dtype = 'float16' args.dtype = T.float16
args.num_split = 1 args.num_split = 1
if args.benchmark: if args.benchmark:
......
import torch
import math
import argparse
import tilelang
import tilelang.language as T
from example_gqa_decode_varlen_logits import flash_attn_with_attn_pool_decode, repeat_kv, do_bench
# Fix the RNG seed so the randomly generated Q/K/V test tensors are reproducible.
torch.manual_seed(0)
def get_configs():
    """Enumerate candidate kernel configurations for autotuning.

    Returns a list of dicts, one per point in the Cartesian product of the
    tuning axes (block_N, block_H, num_split, num_stages, threads).
    """
    import itertools

    # Tuning axes; dict insertion order fixes both the key order of each
    # config and the enumeration order of the product.
    search_space = {
        "block_N": [64, 128],
        "block_H": [64],
        "num_split": [1],
        "num_stages": [1, 2, 3],
        "threads": [128],
    }
    axis_names = list(search_space)
    return [dict(zip(axis_names, combo)) for combo in itertools.product(*search_space.values())]
# @autotune(configs=get_configs(), warmup=10, rep=10)
@tilelang.jit(out_idx=[-2, -1], debug_root_path="./examples/flash_decoding")
def flashattn(
    batch,
    heads,
    k_heads,
    max_seqlen_kv,
    total_seqlen_k,
    dim,
    has_sink,
    page_block_size,
    block_N=128,
    block_H=64,
    num_split=1,
    num_stages=1,
    threads=128,
):
    """Build a TileLang GQA decode kernel with paged KV and per-block score output.

    The returned prim_func takes (Q, K, V, cu_seqlens_k, s_aux, BLOCK_TABLE)
    and — via ``out_idx=[-2, -1]`` on the jit wrapper — returns the attention
    output ``Output`` [batch, heads, dim] and the per-key-block softmax-weight
    maxima ``S`` [batch, heads, ceil(max_seqlen_kv / block_N)].

    K/V are packed varlen tensors of shape [total_seqlen_k, k_heads, dim];
    BLOCK_TABLE maps each sequence's logical page index to a physical page
    (page indices are relative to the sequence start, since the kernel adds
    ``cu_seqlens_k[bid]`` when indexing K/V).
    """
    # Softmax is computed in base-2 (exp2/log2), so fold log2(e) into the scale.
    scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e)
    shape_q = [batch, heads, dim]
    shape_k = [total_seqlen_k, k_heads, dim]
    shape_v = [total_seqlen_k, k_heads, dim]
    shape_o = [batch, heads, dim]
    # One score column per block_N-sized key block.
    shape_s = [batch, heads, math.ceil(max_seqlen_kv / block_N)]
    dtype = T.float16
    accum_dtype = T.float32
    # GQA: how many query heads share one KV head.
    kv_group_num = heads // k_heads
    # Each block_N tile must lie entirely inside one physical page.
    assert page_block_size >= block_N and page_block_size % block_N == 0, (
        "page_block_size must be larger than block_N and a multiple of block_N"
    )
    # Number of query heads processed per kernel block (capped by the GQA group).
    valid_block_H = min(block_H, kv_group_num)

    # TODO: check if max_seqlen_kv is correct for varlen case
    @T.macro
    def flash_attn(
        Q: T.Tensor(shape_q, dtype),
        K: T.Tensor(shape_k, dtype),
        V: T.Tensor(shape_v, dtype),
        cu_seqlens_k: T.Tensor([batch + 1], T.int32),
        s_aux: T.Tensor([heads], T.float32),
        # NOTE(review): this declared shape uses block_N, but the prim_func
        # below declares BLOCK_TABLE with page_block_size — confirm which is
        # intended (only the prim_func signature is user-visible).
        BLOCK_TABLE: T.Tensor([batch, math.ceil(max_seqlen_kv / block_N)], T.int32),
        Output: T.Tensor([batch, heads, dim], dtype),
        S: T.Tensor(shape_s, dtype),
    ):
        with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz):
            # Shared-memory staging tiles and per-row online-softmax state.
            Q_shared = T.alloc_shared([block_H, dim], dtype)
            K_shared = T.alloc_shared([block_N, dim], dtype)
            V_shared = T.alloc_shared([block_N, dim], dtype)
            O_shared = T.alloc_shared([valid_block_H, dim], dtype)
            acc_s = T.alloc_fragment([block_H, block_N], accum_dtype)
            acc_s_cast = T.alloc_fragment([block_H, block_N], dtype)
            acc_o = T.alloc_fragment([block_H, dim], accum_dtype)
            scores_max = T.alloc_fragment([block_H], accum_dtype)
            scores_max_prev = T.alloc_fragment([block_H], accum_dtype)
            scores_scale = T.alloc_fragment([block_H], accum_dtype)
            scores_sum = T.alloc_fragment([block_H], accum_dtype)
            logsum = T.alloc_fragment([block_H], accum_dtype)
            # Per-key-block running row maxima, later normalized into S.
            S_shared = T.alloc_shared([block_H, math.ceil(max_seqlen_kv / block_N)], dtype)
            s_aux_shared = T.alloc_shared([block_H], T.float32)

            bid = bx  # batch index
            hid = by  # query-head group index
            cur_kv_head = hid // (kv_group_num // valid_block_H)
            # Varlen bounds for this sequence in the packed K/V tensors.
            cur_start_k = cu_seqlens_k[bid]
            cur_end_k = cu_seqlens_k[bid + 1]
            cur_seqlen_k = cur_end_k - cur_start_k
            T.copy(Q[bid, hid * valid_block_H : hid * valid_block_H + block_H, :], Q_shared)
            T.fill(acc_o, 0)
            T.fill(logsum, 0)
            T.fill(scores_max, -T.infinity(accum_dtype))
            # loop_range = T.ceildiv((seqlen_kv // num_split), block_N)
            loop_range = T.ceildiv((cur_seqlen_k // num_split), block_N)
            for k in T.Pipelined(loop_range, num_stages=num_stages):
                # Translate logical block k to its physical offset via the page table
                # (page indices are relative to this sequence; cur_start_k rebases them).
                k_start = BLOCK_TABLE[bid, (k * block_N) // page_block_size] * page_block_size + (k * block_N) % page_block_size
                T.copy(K[cur_start_k + k_start : cur_start_k + k_start + block_N, cur_kv_head, :], K_shared)
                T.clear(acc_s)
                T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
                # Mask key positions past the true sequence length.
                for i, j in T.Parallel(block_H, block_N):
                    acc_s[i, j] = T.if_then_else(k * block_N + j < cur_seqlen_k, acc_s[i, j], -T.infinity(accum_dtype))
                T.copy(scores_max, scores_max_prev)
                T.fill(scores_max, -T.infinity(accum_dtype))
                # NOTE(review): scores_max is refilled with -inf before the reduce, so it
                # holds the max of the CURRENT block only (not merged with scores_max_prev)
                # — confirm this is intended; most flash-attention variants merge the two.
                T.reduce_max(acc_s, scores_max, dim=1, clear=False)
                # scores_max_prev is m_i
                # scores_max is row_max->m_ij in triton
                T.copy(scores_max, S_shared[:, k])
                # scores_scale is alpha in triton
                for i in T.Parallel(block_H):
                    scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
                for i, j in T.Parallel(block_H, block_N):
                    acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
                T.reduce_sum(acc_s, scores_sum, dim=1)
                # scores_sum is l_ij in triton
                # logsum is l_i in triton
                for i in T.Parallel(block_H):
                    logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
                T.copy(acc_s, acc_s_cast)
                # Rescale the running output by alpha before accumulating this block.
                for i, j in T.Parallel(block_H, dim):
                    acc_o[i, j] *= scores_scale[i]
                v_start = BLOCK_TABLE[bid, (k * block_N) // page_block_size] * page_block_size + (k * block_N) % page_block_size
                T.copy(V[cur_start_k + v_start : cur_start_k + v_start + block_N, cur_kv_head, :], V_shared)
                T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
            if has_sink:
                T.copy(s_aux[hid * valid_block_H : hid * valid_block_H + block_H], s_aux_shared)
                # NOTE(review): raw sink values are added to the normalizer; the torch
                # reference adds exp(sink - running_max) instead — verify equivalence.
                for i in T.Parallel(block_H):
                    logsum[i] += s_aux_shared[i]
            # Final normalization of the output rows.
            for i, j in T.Parallel(block_H, dim):
                acc_o[i, j] /= logsum[i]
            # Convert stored per-block maxima into normalized softmax weights.
            for h, k in T.Parallel(block_H, math.ceil(max_seqlen_kv / block_N)):
                S_shared[h, k] = T.exp2((S_shared[h, k] - scores_max[h]) * scale) / logsum[h]
            # logsum becomes the log-sum-exp (base 2); only used if split merging is added.
            for i in T.Parallel(block_H):
                logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
            T.copy(acc_o[:valid_block_H, :], O_shared)
            T.copy(O_shared, Output[bid, hid * valid_block_H : (hid + 1) * valid_block_H, :])
            T.copy(S_shared[:valid_block_H, :], S[bid, hid * valid_block_H : (hid + 1) * valid_block_H, :])

    @T.prim_func
    def flashattn_gqa_decode_no_split(
        Q: T.Tensor(shape_q, dtype),
        K: T.Tensor(shape_k, dtype),
        V: T.Tensor(shape_v, dtype),
        cu_seqlens_k: T.Tensor([batch + 1], T.int32),
        s_aux: T.Tensor([heads], T.float32),
        BLOCK_TABLE: T.Tensor([batch, math.ceil(max_seqlen_kv / page_block_size)], T.int32),
        Output: T.Tensor(shape_o, dtype),
        S: T.Tensor(shape_s, dtype),
    ):
        flash_attn(Q, K, V, cu_seqlens_k, s_aux, BLOCK_TABLE, Output, S)

    # TODO: split version
    return flashattn_gqa_decode_no_split
def flash_attn_with_attn_pool_decode_tilelang(
    Q: torch.Tensor,  ## [tq = b, q_h, q_dim]
    K: torch.Tensor,  ## [tk, k_h, k_dim]
    V: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_k: int,
    real_max_k_seqlen: int,
    num_split: int,
    softmax_scale: float,
    s_aux: torch.Tensor = None,
    block_size: int = 64,
    use_per_kv_head_sparse_index: bool = False,
    tl_kernel=None,
    block_table: torch.Tensor = None,
):
    """Run the compiled TileLang decode kernel and pool the attention scores.

    Args:
        Q: decode queries, [batch, q_heads, head_size].
        K, V: packed varlen key/value tensors, [total_k_tokens, kv_heads, head_size].
        cu_seqlens_k: cumulative key lengths, [batch + 1].
        max_seqlen_k, real_max_k_seqlen, num_split, softmax_scale, block_size:
            kept for signature parity with the Triton path; the compiled
            ``tl_kernel`` already baked the corresponding values in at build time.
        s_aux: optional per-head sink values, [q_heads].
        use_per_kv_head_sparse_index: pool scores per GQA group instead of
            across all query heads.
        tl_kernel: kernel returned by ``flashattn(...)``.
        block_table: page table for the paged KV layout, [batch, n_pages].

    Returns:
        (O, S_pooled): attention output [batch, q_heads, head_size] and the
        max-pooled per-block score tensor ([batch, 1, n_blocks] or
        [batch, kv_heads, n_blocks] with per-KV-head pooling).
    """
    num_tokens, q_h, head_size = Q.shape
    batch = cu_seqlens_k.size(0) - 1
    k_h = K.size(1)

    assert Q.dim() == K.dim() == 3
    assert Q.size(2) == K.size(2)
    assert cu_seqlens_k.dim() == 1
    assert head_size in {64, 128, 256}
    assert Q.is_contiguous()
    assert K.is_contiguous()
    assert V.is_contiguous()

    gqa_group_size = q_h // k_h
    # The jit wrapper (out_idx) allocates and returns Output and S itself,
    # so no pre-allocation is needed here (previous zero-filled O_tl/S_tl
    # buffers were dead stores and have been removed).
    O_tl, S_tl = tl_kernel(Q, K, V, cu_seqlens_k, s_aux, block_table)
    if use_per_kv_head_sparse_index:
        # One pooled score row per KV head (max over its GQA query-head group).
        S_tl = torch.max_pool2d(S_tl, kernel_size=(gqa_group_size, 1), stride=(gqa_group_size, 1))
    else:
        # Single pooled score row: max over all query heads.
        S_tl = torch.max_pool2d(S_tl, kernel_size=(q_h, 1), stride=(q_h, 1))
    return O_tl, S_tl
def test_equal_seqlen_decode_main(args):
    """Test decode kernel with equal sequence lengths.

    Builds a batch where every sequence has exactly ``args.k_seqlen`` keys,
    runs the Triton and TileLang decode kernels, and checks both against a
    dense PyTorch reference (with optional sink attention). Requires CUDA.
    """
    print("Testing decode kernel with equal sequence lengths")
    batch_size = args.batch_size
    q_heads = args.q_heads
    kv_heads = args.kv_heads
    k_seqlen = args.k_seqlen
    real_max_k_seqlen = args.k_seqlen
    head_size = args.head_size
    block_size = args.block_size
    page_block_size = args.page_block_size
    dtype = torch.bfloat16 if args.dtype == T.bfloat16 else torch.float16
    # For decode, query is just 1 token per batch
    q = torch.randn(batch_size, q_heads, head_size, device="cuda", dtype=dtype)
    k = torch.randn(batch_size, kv_heads, k_seqlen, head_size, device="cuda", dtype=dtype)
    v = torch.randn(batch_size, kv_heads, k_seqlen, head_size, device="cuda", dtype=dtype)
    softmax_scale = 1.0 / math.sqrt(head_size)
    # Generate sink values if needed
    sink = None
    if args.test_sink:
        sink = torch.randn(q_heads, device="cuda", dtype=torch.float32) * 0.1  # Small sink values
        print(f"Using sink attention with sink values: {sink}")
    # Convert to varlen format for K, V
    k_varlen = k.transpose(1, 2).reshape(batch_size * k_seqlen, kv_heads, head_size).contiguous()
    v_varlen = v.transpose(1, 2).reshape(batch_size * k_seqlen, kv_heads, head_size).contiguous()
    # Generate cumulative sequence lengths
    cu_seqlens_k = torch.arange(0, (batch_size + 1) * k_seqlen, k_seqlen, device="cuda", dtype=torch.int32)
    max_seqlen_k = k_seqlen
    print(f"q shape: {q.shape}")
    print(f"k_varlen shape: {k_varlen.shape}")
    print(f"v_varlen shape: {v_varlen.shape}")
    num_tokens, q_h, head_size = q.shape
    batch = cu_seqlens_k.size(0) - 1
    k_h = k_varlen.size(1)
    # Compile the TileLang kernel for this exact problem size.
    tl_kernel = flashattn(batch, q_h, k_h, args.k_seqlen, cu_seqlens_k[-1].item(), head_size, args.test_sink, page_block_size)
    # Build an identity page table: page indices are relative to each sequence's
    # start (the kernel rebases by cu_seqlens_k), hence the reset per batch row.
    block_table = torch.zeros(batch, math.ceil(real_max_k_seqlen / page_block_size), device="cuda", dtype=torch.int32)
    block_cnt = 0
    for i in range(batch):
        cur_seqlen = cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item()
        for j in range(math.ceil(cur_seqlen / page_block_size)):
            block_table[i, j] = block_cnt
            block_cnt += 1
        block_cnt = 0
    # Test our decode kernel
    O_triton, S_triton = flash_attn_with_attn_pool_decode(
        q,
        k_varlen,
        v_varlen,
        cu_seqlens_k,
        max_seqlen_k,
        real_max_k_seqlen,
        args.num_split,
        softmax_scale,
        s_aux=sink,
        block_size=block_size,
    )
    O_tilelang, S_tilelang = flash_attn_with_attn_pool_decode_tilelang(
        q,
        k_varlen,
        v_varlen,
        cu_seqlens_k,
        max_seqlen_k,
        real_max_k_seqlen,
        args.num_split,
        softmax_scale,
        s_aux=sink,
        block_size=block_size,
        tl_kernel=tl_kernel,
        block_table=block_table,
    )
    # Zero score columns beyond each sequence's valid block count before comparing.
    for i in range(batch_size):
        S_tilelang[i, :, math.ceil((cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item()) / block_size) :] = 0
    # Compute torch reference
    q_expanded = q.unsqueeze(2)  # [b, q_heads, 1, head_size]
    k_repeat = repeat_kv(k, q_heads // kv_heads)  # [b, q_heads, k_seqlen, head_size]
    v_repeat = repeat_kv(v, q_heads // kv_heads)  # [b, q_heads, k_seqlen, head_size]
    if sink is None:
        # Standard scaled dot-product attention
        logits = torch.matmul(q_expanded, k_repeat.transpose(-2, -1)) * softmax_scale  # [batch, q_heads, 1, seqlen_k]
        attn_weights = torch.softmax(logits, dim=-1)
        O_torch = torch.matmul(attn_weights, v_repeat).squeeze(2)  # [batch, q_heads, head_size]
    else:
        # s_aux attention: the sink competes with the keys in the softmax normalizer.
        logits = torch.matmul(q_expanded, k_repeat.transpose(-2, -1)) * softmax_scale  # [batch, q_heads, 1, seqlen_k]
        sink_expanded = sink.view(1, q_heads, 1, 1)  # [1, q_heads, 1, 1]
        logits_max = torch.max(logits, dim=-1, keepdim=True).values
        logits_or_sinks_max = torch.maximum(logits_max, sink_expanded)
        sinks = torch.exp(sink_expanded - logits_or_sinks_max)
        unnormalized_scores = torch.exp(logits - logits_or_sinks_max)
        normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + sinks
        attn_weights = unnormalized_scores / normalizer
        O_torch = torch.matmul(attn_weights.to(v_repeat.dtype), v_repeat).squeeze(2)  # [batch, q_heads, head_size]
    # Compute attention score pooling
    attn_score_pooled = torch.max_pool2d(
        attn_weights.squeeze(2),  # [b, q_heads, k_seqlen]
        kernel_size=(q_heads, block_size),
        stride=(q_heads, block_size),
        ceil_mode=True,
    ).to(torch.float16)
    print("S_tilelang", S_tilelang)
    print("attn_score_pooled", attn_score_pooled)
    max_diff_o = torch.max(torch.abs(O_triton - O_torch))
    max_diff_s = torch.max(torch.abs(S_triton - attn_score_pooled))
    max_diff_o_tilelang = torch.max(torch.abs(O_tilelang - O_torch))
    max_diff_s_tilelang = torch.max(torch.abs(S_tilelang - attn_score_pooled))
    print(f"Max difference in O: {max_diff_o.item()}")
    print(f"Max difference in S: {max_diff_s.item()}")
    print(f"Max difference in O_tilelang: {max_diff_o_tilelang.item()}")
    print(f"Max difference in S_tilelang: {max_diff_s_tilelang.item()}")
    assert torch.allclose(O_triton, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o.item()}"
    assert torch.allclose(S_triton, attn_score_pooled, atol=1e-2, rtol=1e-2), f"Score mismatch: {max_diff_s.item()}"
    assert torch.allclose(O_tilelang, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o_tilelang.item()}"
    assert torch.allclose(S_tilelang, attn_score_pooled, atol=1e-2, rtol=1e-2), f"Score mismatch: {max_diff_s_tilelang.item()}"
    print("✅ All tests passed!")
def test_varlen_decode_main(args):
    """Test decode kernel with variable sequence lengths.

    Draws a random key length in [k_seqlen // 4, k_seqlen] per batch entry,
    runs the Triton and TileLang decode kernels on the packed varlen layout,
    and validates both against a padded-and-masked PyTorch reference
    (optionally with sink attention). Requires CUDA.
    """
    batch_size = args.batch_size
    q_heads = args.q_heads
    kv_heads = args.kv_heads
    max_k_seqlen = args.k_seqlen  # Use as max sequence length
    real_max_k_seqlen = args.k_seqlen
    head_size = args.head_size
    block_size = args.block_size
    page_block_size = args.page_block_size
    dtype = torch.bfloat16 if args.dtype == T.bfloat16 else torch.float16
    print(f"Testing decode kernel with variable sequence lengths (max_k_seqlen={max_k_seqlen})")
    # Generate sink values if needed
    sink = None
    if args.test_sink:
        sink = torch.randn(q_heads, device="cuda", dtype=torch.float32) * 0.1  # Small sink values
        print(f"Using sink attention with sink values: {sink}")
    # Generate variable length k sequences
    k_seqlens = torch.randint(max_k_seqlen // 4, max_k_seqlen + 1, size=(batch_size,))
    print(f"k_seqlens: {k_seqlens}")
    # Generate cumulative sequence lengths for k
    cu_seqlens_k = torch.zeros(batch_size + 1, device="cuda", dtype=torch.int32)
    total_k_tokens = 0
    for i in range(batch_size):
        cu_seqlens_k[i] = total_k_tokens
        total_k_tokens += k_seqlens[i]
    cu_seqlens_k[batch_size] = total_k_tokens
    print(f"cu_seqlens_k: {cu_seqlens_k}")
    # Generate tensors - Q is [batch_size, q_heads, head_size] for decode
    q_decode = torch.randn(batch_size, q_heads, head_size, device="cuda", dtype=dtype)
    k_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device="cuda", dtype=dtype)
    v_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device="cuda", dtype=dtype)
    softmax_scale = 1.0 / math.sqrt(head_size)
    max_seqlen_k = int(k_seqlens.max())
    print(f"Actual max_seqlen_k: {max_seqlen_k}")
    print(f"q_decode shape: {q_decode.shape}")
    print(f"k_varlen shape: {k_varlen.shape}")
    print(f"v_varlen shape: {v_varlen.shape}")
    num_tokens, q_h, head_size = q_decode.shape
    batch = cu_seqlens_k.size(0) - 1
    k_h = k_varlen.size(1)
    # Compile the TileLang kernel for this exact problem size.
    tl_kernel = flashattn(batch, q_h, k_h, args.k_seqlen, cu_seqlens_k[-1].item(), head_size, args.test_sink, page_block_size)
    # Identity page table: page indices are relative to each sequence's start
    # (the kernel rebases by cu_seqlens_k), hence the per-row counter reset.
    block_table = torch.zeros(batch, math.ceil(real_max_k_seqlen / page_block_size), device="cuda", dtype=torch.int32)
    block_cnt = 0
    for i in range(batch):
        cur_seqlen = cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item()
        for j in range(math.ceil(cur_seqlen / page_block_size)):
            block_table[i, j] = block_cnt
            block_cnt += 1
        block_cnt = 0
    # Test our decode kernel
    O_triton, S_triton = flash_attn_with_attn_pool_decode(
        q_decode,
        k_varlen,
        v_varlen,
        cu_seqlens_k,
        max_seqlen_k,
        real_max_k_seqlen,
        args.num_split,
        softmax_scale,
        s_aux=sink,
        block_size=block_size,
    )
    O_tilelang, S_tilelang = flash_attn_with_attn_pool_decode_tilelang(
        q_decode,
        k_varlen,
        v_varlen,
        cu_seqlens_k,
        max_seqlen_k,
        real_max_k_seqlen,
        args.num_split,
        softmax_scale,
        s_aux=sink,
        block_size=block_size,
        tl_kernel=tl_kernel,
        block_table=block_table,
    )
    # Zero score columns beyond each sequence's valid block count before comparing.
    for i in range(batch_size):
        S_tilelang[i, :, math.ceil((cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item()) / block_size) :] = 0
    # Create torch reference - pad tensors for comparison
    k_padded_list = []
    v_padded_list = []
    for i in range(batch_size):
        actual_k_len = k_seqlens[i]
        # Extract and pad k, v for this batch
        k_start = cu_seqlens_k[i]
        k_end = cu_seqlens_k[i + 1]
        # Pad to max_seqlen_k
        k_padded = torch.zeros(max_seqlen_k, kv_heads, head_size, device="cuda", dtype=dtype)
        v_padded = torch.zeros(max_seqlen_k, kv_heads, head_size, device="cuda", dtype=dtype)
        k_padded[:actual_k_len] = k_varlen[k_start:k_end]
        v_padded[:actual_k_len] = v_varlen[k_start:k_end]
        k_padded_list.append(k_padded)
        v_padded_list.append(v_padded)
    # Stack to create batched tensors [b, max_seqlen, kv_heads, head_size]
    k_padded_batched = torch.stack(k_padded_list, dim=0).transpose(1, 2)  # [b, kv_heads, max_seqlen, head_size]
    v_padded_batched = torch.stack(v_padded_list, dim=0).transpose(1, 2)  # [b, kv_heads, max_seqlen, head_size]
    # Expand q to match kv heads: [b, q_heads, 1, head_size]
    q_expanded = q_decode.unsqueeze(2)  # [b, q_heads, 1, head_size]
    print(f"q_expanded shape: {q_expanded.shape}")
    print(f"k_padded_batched shape: {k_padded_batched.shape}")
    print(f"v_padded_batched shape: {v_padded_batched.shape}")
    # Compute torch reference
    k_repeat = repeat_kv(k_padded_batched, q_heads // kv_heads)  # [b, q_heads, max_seqlen, head_size]
    v_repeat = repeat_kv(v_padded_batched, q_heads // kv_heads)  # [b, q_heads, max_seqlen, head_size]
    if sink is None:
        # Standard attention computation: [b, q_heads, 1, head_size] @ [b, q_heads, head_size, max_seqlen]
        attn_score = torch.matmul(q_expanded, k_repeat.transpose(-2, -1)) * softmax_scale  # [b, q_heads, 1, max_seqlen]
        # Apply sequence length masking
        for i in range(batch_size):
            actual_k_len = k_seqlens[i]
            attn_score[i, :, :, actual_k_len:] = float("-inf")
        attn_weights = attn_score.softmax(dim=-1)  # [b, q_heads, 1, max_seqlen]
        # Mask out invalid positions
        for i in range(batch_size):
            actual_k_len = k_seqlens[i]
            attn_weights[i, :, :, actual_k_len:] = 0.0
        # Compute output: [b, q_heads, 1, max_seqlen] @ [b, q_heads, max_seqlen, head_size]
        O_torch = torch.matmul(attn_weights, v_repeat)  # [b, q_heads, 1, head_size]
    else:
        # s_aux attention: the sink competes with the keys in the softmax normalizer.
        logits = torch.matmul(q_expanded, k_repeat.transpose(-2, -1)) * softmax_scale  # [b, q_heads, 1, max_seqlen]
        # Apply sequence length masking
        for i in range(batch_size):
            actual_k_len = k_seqlens[i]
            logits[i, :, :, actual_k_len:] = float("-inf")
        sink_expanded = sink.view(1, q_heads, 1, 1)  # [1, q_heads, 1, 1]
        logits_max = torch.max(logits, dim=-1, keepdim=True).values
        logits_or_sinks_max = torch.maximum(logits_max, sink_expanded)
        sinks = torch.exp(sink_expanded - logits_or_sinks_max)
        unnormalized_scores = torch.exp(logits - logits_or_sinks_max)
        normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + sinks
        attn_weights = unnormalized_scores / normalizer
        # Mask out invalid positions
        for i in range(batch_size):
            actual_k_len = k_seqlens[i]
            attn_weights[i, :, :, actual_k_len:] = 0.0
        # Compute output: [b, q_heads, 1, max_seqlen] @ [b, q_heads, max_seqlen, head_size]
        O_torch = torch.matmul(attn_weights.to(v_repeat.dtype), v_repeat)  # [b, q_heads, 1, head_size]
    O_torch = O_torch.squeeze(2)  # [b, q_heads, head_size]
    # Compute attention score pooling for S
    attn_score_pooled = torch.max_pool2d(
        attn_weights.squeeze(2),  # [b, q_heads, max_seqlen]
        kernel_size=(q_heads, block_size),
        stride=(q_heads, block_size),
        ceil_mode=True,
    ).to(dtype=torch.float16)  # [b, 1, ceil(max_seqlen/block_size)]
    print(f"O_triton shape: {O_triton.shape}")
    print(f"O_tilelang shape: {O_tilelang.shape}")
    print(f"O_torch shape: {O_torch.shape}")
    print(f"S_triton shape: {S_triton.shape}")
    print(f"S_tilelang shape: {S_tilelang.shape}")
    print(f"attn_score_pooled shape: {attn_score_pooled.shape}")
    # Compare results
    max_diff_o = torch.max(torch.abs(O_triton - O_torch))
    max_diff_o_tl = torch.max(torch.abs(O_tilelang - O_torch))
    print(f"Max difference in O: {max_diff_o.item()}")
    print(f"Max difference in O_tilelang: {max_diff_o_tl.item()}")
    max_diff_s = torch.max(torch.abs(S_triton - attn_score_pooled))
    # S_tilelang is sized for real_max_k_seqlen; compare only the valid prefix.
    max_diff_s_tl = torch.max(torch.abs(S_tilelang[:, :, : math.ceil(max_seqlen_k / block_size)] - attn_score_pooled))
    print(f"Max difference in S: {max_diff_s.item()}")
    print(f"Max difference in S_tilelang: {max_diff_s_tl.item()}")
    assert torch.allclose(O_triton, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o.item()}"
    assert torch.allclose(S_triton, attn_score_pooled, atol=1e-2, rtol=1e-2), f"Score mismatch: {max_diff_s.item()}"
    assert torch.allclose(O_tilelang, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o_tl.item()}"
    assert torch.allclose(S_tilelang[:, :, : math.ceil(max_seqlen_k / block_size)], attn_score_pooled, atol=1e-2, rtol=1e-2), (
        f"Score mismatch: {max_diff_s_tl.item()}"
    )
    print("✅ All tests passed!")
def speed_benchmark_decode_comparison(args):
    """Speed benchmark for decode kernel.

    Times the TileLang and Triton decode kernels on identical random inputs
    via ``do_bench`` and prints the per-kernel average latency and speedup.
    Requires CUDA.
    """
    batch_size = args.batch_size
    q_heads = args.q_heads
    kv_heads = args.kv_heads
    max_k_seqlen = args.k_seqlen
    real_max_k_seqlen = args.k_seqlen
    head_size = args.head_size
    block_size = args.block_size
    page_block_size = args.page_block_size
    dtype = torch.bfloat16 if args.dtype == T.bfloat16 else torch.float16
    print("\n=== Decode Speed Benchmark Comparison ===")
    print("Configuration:")
    print(f" Batch size: {batch_size}")
    print(f" Q heads: {q_heads}, KV heads: {kv_heads}")
    print(f" Max K sequence length: {max_k_seqlen}")
    print(f" Head size: {head_size}")
    print(f" Block size: {block_size}")
    print(f" Data type: {dtype}")
    print(f" Variable lengths: {args.test_varlen}")
    print(f" s_aux attention: {args.test_sink}")
    print()
    # Generate input data
    if args.test_varlen:
        k_seqlens = torch.randint(max_k_seqlen // 4, max_k_seqlen + 1, size=(batch_size,))
    else:
        k_seqlens = torch.full((batch_size,), max_k_seqlen, dtype=int)
    # Generate cumulative sequence lengths for k
    cu_seqlens_k = torch.zeros(batch_size + 1, device="cuda", dtype=torch.int32)
    total_k_tokens = 0
    for i in range(batch_size):
        cu_seqlens_k[i] = total_k_tokens
        total_k_tokens += k_seqlens[i]
    cu_seqlens_k[batch_size] = total_k_tokens
    # Generate tensors
    q_decode = torch.randn(batch_size, q_heads, head_size, device="cuda", dtype=dtype)
    k_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device="cuda", dtype=dtype)
    v_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device="cuda", dtype=dtype)
    softmax_scale = 1.0 / math.sqrt(head_size)
    max_seqlen_k = int(k_seqlens.max())
    # Generate sink values if needed
    sink = None
    if args.test_sink:
        sink = torch.randn(q_heads, device="cuda", dtype=torch.float32) * 0.1  # Small sink values
        print(" Using sink attention with sink values")
    print("Setup complete:")
    print(f" Total K tokens: {total_k_tokens}")
    print(f" Actual max K seq len: {max_seqlen_k}")
    if args.test_varlen:
        print(f" K sequence lengths: {k_seqlens.tolist()}")
    # Warmup
    num_tokens, q_h, head_size = q_decode.shape
    batch = cu_seqlens_k.size(0) - 1
    k_h = k_varlen.size(1)
    # Compile the TileLang kernel for this exact problem size.
    tl_kernel = flashattn(batch, q_h, k_h, args.k_seqlen, cu_seqlens_k[-1].item(), head_size, args.test_sink, page_block_size)
    # Identity page table, relative to each sequence's start (see kernel indexing).
    block_table = torch.zeros(batch, math.ceil(real_max_k_seqlen / page_block_size), device="cuda", dtype=torch.int32)
    block_cnt = 0
    for i in range(batch):
        cur_seqlen = cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item()
        for j in range(math.ceil(cur_seqlen / page_block_size)):
            block_table[i, j] = block_cnt
            block_cnt += 1
        block_cnt = 0
    # Benchmark
    # NOTE(review): num_split is hard-coded to 1 in both calls below rather than
    # taking args.num_split — confirm this is intentional for the benchmark.
    print("⚡ Benchmarking Tilelang kernel (100 iterations)...")
    tilelang_time = do_bench(
        flash_attn_with_attn_pool_decode_tilelang,
        q_decode,
        k_varlen,
        v_varlen,
        cu_seqlens_k,
        max_seqlen_k,
        args.k_seqlen,
        1,
        softmax_scale,
        sink,
        block_size,
        False,
        tl_kernel,
        block_table,
    )
    print(f"Average decode kernel time Tilelang: {tilelang_time:.3f} ms")
    # Benchmark
    print("⚡ Benchmarking Triton kernel (100 iterations)...")
    triton_time = do_bench(
        flash_attn_with_attn_pool_decode,
        q_decode,
        k_varlen,
        v_varlen,
        cu_seqlens_k,
        max_seqlen_k,
        args.k_seqlen,
        1,
        softmax_scale,
        sink,
        block_size,
    )
    print(f"Average decode kernel time Triton: {triton_time:.3f} ms")
    print(f"Speedup: {(triton_time / tilelang_time):.3f}")
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Flash Attention Decode with Attention Pooling")
    parser.add_argument("--batch_size", type=int, default=1, help="Batch size")
    parser.add_argument("--q_heads", type=int, default=32, help="Number of query heads")
    parser.add_argument("--kv_heads", type=int, default=8, help="Number of key-value heads")
    parser.add_argument("--k_seqlen", type=int, default=8192, help="Key sequence length")
    parser.add_argument("--head_size", type=int, default=128, choices=[64, 128, 256], help="Head dimension")
    parser.add_argument("--block_size", type=int, default=128, help="Block size for computation")
    # NOTE(review): default/choices use T.bfloat16/T.float16 objects while
    # type=str parses the CLI value — verify the T dtype objects compare equal
    # to their string names, otherwise passing --dtype fails the choices check.
    parser.add_argument("--dtype", type=str, default=T.bfloat16, choices=[T.float16, T.bfloat16], help="Data type")
    parser.add_argument("--test_varlen", action="store_true", help="Test with truly variable sequence lengths")
    parser.add_argument("--test_sink", action="store_true", help="Test with sink attention mechanism")
    parser.add_argument("--benchmark", action="store_true", help="Run speed benchmark")
    parser.add_argument("--num_split", type=int, default=1, choices=[1, 16], help="Number of splits")
    parser.add_argument("--page_block_size", type=int, default=128, help="Page block size")
    args = parser.parse_args()
    # NOTE(review): the four assignments below force a fixed configuration and
    # silently override whatever --test_sink/--test_varlen/--dtype/--num_split
    # the user passed — looks like debug leftovers; confirm before shipping.
    args.test_sink = True
    args.test_varlen = True
    args.dtype = T.float16
    args.num_split = 1
    # Dispatch: benchmark takes precedence, then varlen test, else equal-length test.
    if args.benchmark:
        speed_benchmark_decode_comparison(args)
    elif args.test_varlen:
        test_varlen_decode_main(args)
    else:
        test_equal_seqlen_decode_main(args)
...@@ -10,12 +10,12 @@ num_split = 4 ...@@ -10,12 +10,12 @@ num_split = 4
@tilelang.jit(out_idx=[5]) @tilelang.jit(out_idx=[5])
def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_N): def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_N):
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e)
shape_q = [batch, seqlen_q, heads, dim] shape_q = [batch, seqlen_q, heads, dim]
shape_kv = [batch, seqlen_kv, heads, dim] shape_kv = [batch, seqlen_kv, heads, dim]
part_shape = [batch, seqlen_q, heads, num_split, dim] part_shape = [batch, seqlen_q, heads, num_split, dim]
dtype = "float16" dtype = T.float16
accum_dtype = "float" accum_dtype = T.float32
@T.macro @T.macro
def MMA0( def MMA0(
...@@ -29,14 +29,11 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_ ...@@ -29,14 +29,11 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_
bid: T.int32, bid: T.int32,
sid: T.int32, sid: T.int32,
): ):
T.copy( T.copy(K[bid, (seqlen_kv // num_split) * sid + k * block_N : (seqlen_kv // num_split) * sid + (k + 1) * block_N, hid, :], K_shared)
K[bid, (seqlen_kv // num_split) * sid + k * block_N:(seqlen_kv // num_split) * sid +
(k + 1) * block_N, hid, :], K_shared)
# TODO: Handle causal split case # TODO: Handle causal split case
if is_causal: if is_causal:
for i, j in T.Parallel(block_M, block_N): for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(mid * block_M + i >= k * block_N + j, 0, acc_s[i, j] = T.if_then_else(mid * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype))
-T.infinity(acc_s.dtype))
else: else:
T.clear(acc_s) T.clear(acc_s)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
...@@ -52,24 +49,24 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_ ...@@ -52,24 +49,24 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_
bid: T.int32, bid: T.int32,
sid: T.int32, sid: T.int32,
): ):
T.copy( T.copy(V[bid, (seqlen_kv // num_split) * sid + k * block_N : (seqlen_kv // num_split) * sid + (k + 1) * block_N, hid, :], V_shared)
V[bid, (seqlen_kv // num_split) * sid + k * block_N:(seqlen_kv // num_split) * sid +
(k + 1) * block_N, hid, :], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
@T.macro @T.macro
def Softmax( def Softmax(
acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype),
acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype),
scores_max: T.FragmentBuffer([block_M], accum_dtype), scores_max: T.FragmentBuffer([block_M], accum_dtype),
scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), scores_max_prev: T.FragmentBuffer([block_M], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype), scores_scale: T.FragmentBuffer([block_M], accum_dtype),
scores_sum: T.FragmentBuffer([block_M], accum_dtype), scores_sum: T.FragmentBuffer([block_M], accum_dtype),
logsum: T.FragmentBuffer([block_M], accum_dtype), logsum: T.FragmentBuffer([block_M], accum_dtype),
): ):
T.copy(scores_max, scores_max_prev) T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype)) T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False) T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_M):
scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
# To do causal softmax, we need to set the scores_max to 0 if it is -inf # To do causal softmax, we need to set the scores_max to 0 if it is -inf
# This process is called Check_inf in FlashAttention3 code, and it only need to be done # This process is called Check_inf in FlashAttention3 code, and it only need to be done
# in the first ceil_div(kBlockM, kBlockN) steps. # in the first ceil_div(kBlockM, kBlockN) steps.
...@@ -89,23 +86,21 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_ ...@@ -89,23 +86,21 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_
@T.macro @T.macro
def Rescale( def Rescale(
acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), acc_o: T.FragmentBuffer([block_M, dim], accum_dtype),
scores_scale: T.FragmentBuffer([block_M], accum_dtype), scores_scale: T.FragmentBuffer([block_M], accum_dtype),
): ):
for i, j in T.Parallel(block_M, dim): for i, j in T.Parallel(block_M, dim):
acc_o[i, j] *= scores_scale[i] acc_o[i, j] *= scores_scale[i]
@T.macro @T.macro
def flash_attn_split( def flash_attn_split(
Q: T.Tensor(shape_q, dtype), Q: T.Tensor(shape_q, dtype),
K: T.Tensor(shape_kv, dtype), K: T.Tensor(shape_kv, dtype),
V: T.Tensor(shape_kv, dtype), V: T.Tensor(shape_kv, dtype),
glse: T.Tensor([batch, heads, num_split, seqlen_q], dtype), glse: T.Tensor([batch, heads, num_split, seqlen_q], dtype),
Output_partial: T.Tensor(part_shape, dtype), Output_partial: T.Tensor(part_shape, dtype),
): ):
with T.Kernel( with T.Kernel(T.ceildiv(seqlen_q, block_M), heads * batch, num_split, threads=128) as (bx, by, bz):
T.ceildiv(seqlen_q, block_M), heads * batch, num_split,
threads=128) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim], dtype) Q_shared = T.alloc_shared([block_M, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype) K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim], dtype) V_shared = T.alloc_shared([block_N, dim], dtype)
...@@ -126,39 +121,36 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_ ...@@ -126,39 +121,36 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_
# NOTE(wt): tma barrier has some problems with padded dimensions (seq_q here) currently # NOTE(wt): tma barrier has some problems with padded dimensions (seq_q here) currently
# disable relevant tma copy and use SIMT as fallback for now # disable relevant tma copy and use SIMT as fallback for now
T.copy(Q[bid, mid * block_M:(mid + 1) * block_M, hid, :], Q_shared, disable_tma=True) T.copy(Q[bid, mid * block_M : (mid + 1) * block_M, hid, :], Q_shared, disable_tma=True)
T.fill(acc_o, 0) T.fill(acc_o, 0)
T.fill(logsum, 0) T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype)) T.fill(scores_max, -T.infinity(accum_dtype))
# TODO: Handle causal split case # TODO: Handle causal split case
loop_range = ( loop_range = (
T.min(T.ceildiv(seqlen_kv, block_N), T.ceildiv( T.min(T.ceildiv(seqlen_kv, block_N), T.ceildiv((mid + 1) * block_M, block_N))
(mid + 1) * block_M, block_N)) if is_causal else T.ceildiv( if is_causal
(seqlen_kv // num_split), block_N)) else T.ceildiv((seqlen_kv // num_split), block_N)
)
for k in T.Pipelined(loop_range, num_stages=2): for k in T.Pipelined(loop_range, num_stages=2):
MMA0(K, Q_shared, K_shared, acc_s, k, mid, hid, bid, sid) MMA0(K, Q_shared, K_shared, acc_s, k, mid, hid, bid, sid)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, logsum)
logsum)
Rescale(acc_o, scores_scale) Rescale(acc_o, scores_scale)
MMA1(V, V_shared, acc_s_cast, acc_o, k, hid, bid, sid) MMA1(V, V_shared, acc_s_cast, acc_o, k, hid, bid, sid)
for i, j in T.Parallel(block_M, dim): for i, j in T.Parallel(block_M, dim):
acc_o[i, j] /= logsum[i] acc_o[i, j] /= logsum[i]
for i in T.Parallel(block_M): for i in T.Parallel(block_M):
logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
T.copy(logsum, glse[bid, hid, sid, mid * block_M:(mid + 1) * block_M]) T.copy(logsum, glse[bid, hid, sid, mid * block_M : (mid + 1) * block_M])
T.copy(acc_o, O_shared) T.copy(acc_o, O_shared)
T.copy( T.copy(O_shared, Output_partial[bid, mid * block_M : (mid + 1) * block_M, hid, sid, :], disable_tma=True)
O_shared,
Output_partial[bid, mid * block_M:(mid + 1) * block_M, hid, sid, :],
disable_tma=True)
@T.macro @T.macro
def combine( def combine(
glse: T.Tensor([batch, heads, num_split, seqlen_q], dtype), glse: T.Tensor([batch, heads, num_split, seqlen_q], dtype),
Output_partial: T.Tensor(part_shape, dtype), Output_partial: T.Tensor(part_shape, dtype),
Output: T.Tensor(shape_q, dtype), Output: T.Tensor(shape_q, dtype),
): ):
with T.Kernel(T.ceildiv(seqlen_q, block_M), heads, batch, threads=128) as (bx, by, bz): with T.Kernel(T.ceildiv(seqlen_q, block_M), heads, batch, threads=128) as (bx, by, bz):
po_local = T.alloc_fragment([block_M, dim], dtype) po_local = T.alloc_fragment([block_M, dim], dtype)
...@@ -171,20 +163,25 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_ ...@@ -171,20 +163,25 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_
lse_max_local = T.alloc_fragment([block_M], accum_dtype) lse_max_local = T.alloc_fragment([block_M], accum_dtype)
scale_local = T.alloc_fragment([block_M], accum_dtype) scale_local = T.alloc_fragment([block_M], accum_dtype)
T.annotate_layout({ T.annotate_layout(
o_accum_local: T.Fragment(o_accum_local.shape, forward_thread_fn=lambda i, j: i), {
o_shared: tilelang.layout.make_swizzled_layout(o_shared), o_accum_local: T.Fragment(o_accum_local.shape, forward_thread_fn=lambda i, j: i),
po_shared: tilelang.layout.make_swizzled_layout(po_shared), o_shared: tilelang.layout.make_swizzled_layout(o_shared),
}) po_shared: tilelang.layout.make_swizzled_layout(po_shared),
}
)
T.clear(lse_logsum_local) T.clear(lse_logsum_local)
T.clear(o_accum_local) T.clear(o_accum_local)
T.copy(glse[ T.copy(
bz, glse[
by, bz,
:, by,
bx * block_M:(bx + 1) * block_M, :,
], lse_local) bx * block_M : (bx + 1) * block_M,
],
lse_local,
)
T.reduce_max(lse_local, lse_max_local, dim=0, clear=False) T.reduce_max(lse_local, lse_max_local, dim=0, clear=False)
for k in T.Pipelined(num_split): for k in T.Pipelined(num_split):
T.copy(lse_local[k, :], lse_local_split) T.copy(lse_local[k, :], lse_local_split)
...@@ -193,10 +190,7 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_ ...@@ -193,10 +190,7 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_
for i in T.Parallel(block_M): for i in T.Parallel(block_M):
lse_logsum_local[i] = T.log2(lse_logsum_local[i]) + lse_max_local[i] lse_logsum_local[i] = T.log2(lse_logsum_local[i]) + lse_max_local[i]
for k in T.Pipelined(num_split, num_stages=2): for k in T.Pipelined(num_split, num_stages=2):
T.copy( T.copy(Output_partial[bz, bx * block_M : (bx + 1) * block_M, by, k, :], po_shared, disable_tma=True)
Output_partial[bz, bx * block_M:(bx + 1) * block_M, by, k, :],
po_shared,
disable_tma=True)
T.copy(po_shared, po_local) T.copy(po_shared, po_local)
for i in T.Parallel(block_M): for i in T.Parallel(block_M):
lse_local_split[i] = lse_local[k, i] lse_local_split[i] = lse_local[k, i]
...@@ -205,16 +199,16 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_ ...@@ -205,16 +199,16 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_
for i, j in T.Parallel(block_M, dim): for i, j in T.Parallel(block_M, dim):
o_accum_local[i, j] += po_local[i, j] * scale_local[i] o_accum_local[i, j] += po_local[i, j] * scale_local[i]
T.copy(o_accum_local, o_shared) T.copy(o_accum_local, o_shared)
T.copy(o_shared, Output[bz, bx * block_M:(bx + 1) * block_M, by, :], disable_tma=True) T.copy(o_shared, Output[bz, bx * block_M : (bx + 1) * block_M, by, :], disable_tma=True)
@T.prim_func @T.prim_func
def flashattn_mha_inference( def flashattn_mha_inference(
Q: T.Tensor(shape_q, dtype), Q: T.Tensor(shape_q, dtype),
K: T.Tensor(shape_kv, dtype), K: T.Tensor(shape_kv, dtype),
V: T.Tensor(shape_kv, dtype), V: T.Tensor(shape_kv, dtype),
glse: T.Tensor([batch, heads, num_split, seqlen_q], dtype), glse: T.Tensor([batch, heads, num_split, seqlen_q], dtype),
Output_partial: T.Tensor(part_shape, dtype), # [batch, seqlen_q, heads, num_split, dim] Output_partial: T.Tensor(part_shape, dtype), # [batch, seqlen_q, heads, num_split, dim]
Output: T.Tensor(shape_q, dtype), Output: T.Tensor(shape_q, dtype),
): ):
flash_attn_split(Q, K, V, glse, Output_partial) flash_attn_split(Q, K, V, glse, Output_partial)
combine(glse, Output_partial, Output) combine(glse, Output_partial, Output)
...@@ -225,10 +219,10 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_ ...@@ -225,10 +219,10 @@ def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_
def ref_program(Q, K, V, glse, Output_partial, causal): def ref_program(Q, K, V, glse, Output_partial, causal):
assert causal is False assert causal is False
dim = Q.size(-1) dim = Q.size(-1)
scores = torch.einsum('bqhd,bkhd->bhqk', Q, K) scores = torch.einsum("bqhd,bkhd->bhqk", Q, K)
scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype)) scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
attention_weights = F.softmax(scores, dim=-1) attention_weights = F.softmax(scores, dim=-1)
output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V) output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V)
return output return output
...@@ -256,7 +250,7 @@ def flash_split_ref(Q, K, V, causal): ...@@ -256,7 +250,7 @@ def flash_split_ref(Q, K, V, causal):
block_N = 128 block_N = 128
seqlen_kv = K.size(1) seqlen_kv = K.size(1)
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e)
acc_s = torch.empty((batch, nheads, block_M, block_N), device="cuda", dtype=torch.float) acc_s = torch.empty((batch, nheads, block_M, block_N), device="cuda", dtype=torch.float)
acc_s_cast = torch.empty((batch, nheads, block_M, block_N), device="cuda", dtype=torch.float16) acc_s_cast = torch.empty((batch, nheads, block_M, block_N), device="cuda", dtype=torch.float16)
acc_o = torch.empty((batch, block_M, nheads, dim), device="cuda", dtype=torch.float) acc_o = torch.empty((batch, block_M, nheads, dim), device="cuda", dtype=torch.float)
...@@ -273,14 +267,15 @@ def flash_split_ref(Q, K, V, causal): ...@@ -273,14 +267,15 @@ def flash_split_ref(Q, K, V, causal):
for ks in range(num_split): for ks in range(num_split):
acc_o.fill_(0) acc_o.fill_(0)
logsum.fill_(0) logsum.fill_(0)
scores_max.fill_(float('-inf')) scores_max.fill_(float("-inf"))
scores_max_prev.fill_(float('-inf')) scores_max_prev.fill_(float("-inf"))
for i in range(int((seqlen_kv // num_split) / block_N)): for i in range(int((seqlen_kv // num_split) / block_N)):
acc_s.fill_(0) acc_s.fill_(0)
acc_s = torch.einsum('bqhd,bkhd->bhqk', Q_, acc_s = torch.einsum(
K[:, (seqlen_kv // num_split) * ks + "bqhd,bkhd->bhqk",
i * block_N:(seqlen_kv // num_split) * ks + Q_,
(i + 1) * block_N, :, :]) # [batch, seqlen, nheads, block_N] K[:, (seqlen_kv // num_split) * ks + i * block_N : (seqlen_kv // num_split) * ks + (i + 1) * block_N, :, :],
) # [batch, seqlen, nheads, block_N]
scores_max_prev = scores_max scores_max_prev = scores_max
scores_max = acc_s.max(dim=-1, keepdim=False).values # [blockM] scores_max = acc_s.max(dim=-1, keepdim=False).values # [blockM]
scores_scale = torch.exp2(scores_max_prev - scores_max) scores_scale = torch.exp2(scores_max_prev - scores_max)
...@@ -288,9 +283,10 @@ def flash_split_ref(Q, K, V, causal): ...@@ -288,9 +283,10 @@ def flash_split_ref(Q, K, V, causal):
acc_s = torch.exp2(acc_s - scores_max[:, :, :, None]) acc_s = torch.exp2(acc_s - scores_max[:, :, :, None])
acc_s_cast = acc_s.to(torch.float16) acc_s_cast = acc_s.to(torch.float16)
acc_o += torch.einsum( acc_o += torch.einsum(
'bhqk,bkhd->bqhd', acc_s_cast, "bhqk,bkhd->bqhd",
V[:, (seqlen_kv // num_split) * ks + i * block_N:(seqlen_kv // num_split) * ks + acc_s_cast,
(i + 1) * block_N, :, :]) V[:, (seqlen_kv // num_split) * ks + i * block_N : (seqlen_kv // num_split) * ks + (i + 1) * block_N, :, :],
)
scores_sum = acc_s.sum(dim=-1, keepdim=False) scores_sum = acc_s.sum(dim=-1, keepdim=False)
logsum = logsum * scores_scale + scores_sum logsum = logsum * scores_scale + scores_sum
acc_o /= logsum[:, :, :, None].transpose(1, 2) acc_o /= logsum[:, :, :, None].transpose(1, 2)
...@@ -298,8 +294,7 @@ def flash_split_ref(Q, K, V, causal): ...@@ -298,8 +294,7 @@ def flash_split_ref(Q, K, V, causal):
gacc_o[ks, :, :, :, :] = acc_o gacc_o[ks, :, :, :, :] = acc_o
glogsum[ks, :, :, :] = logsum glogsum[ks, :, :, :] = logsum
return glogsum.to(torch.float16).permute(1, 2, 0, return glogsum.to(torch.float16).permute(1, 2, 0, 3), gacc_o.to(torch.float16).permute(1, 2, 3, 0, 4)
3), gacc_o.to(torch.float16).permute(1, 2, 3, 0, 4)
def main(BATCH=1, H=32, Q_CTX=128, KV_CTX=8192, D_HEAD=128, causal=False): def main(BATCH=1, H=32, Q_CTX=128, KV_CTX=8192, D_HEAD=128, causal=False):
......
...@@ -9,17 +9,18 @@ from example_fusedmoe_torch import * ...@@ -9,17 +9,18 @@ from example_fusedmoe_torch import *
@tilelang.jit(pass_configs={"tl.disable_tma_lower": True, "tl.disable_warp_specialized": True}) @tilelang.jit(pass_configs={"tl.disable_tma_lower": True, "tl.disable_warp_specialized": True})
def moe_forward_tilelang_shared(d_hidden, def moe_forward_tilelang_shared(
d_expert, d_hidden,
n_shared_experts, d_expert,
dtype, n_shared_experts,
num_tokens, dtype,
block_token=128, num_tokens,
block_dhidden=128, block_token=128,
block_dexpert=128, block_dhidden=128,
threads=256, block_dexpert=128,
num_stages=1): threads=256,
num_stages=1,
):
scale = 1.44269504 # log2(e) scale = 1.44269504 # log2(e)
# Parameters # Parameters
...@@ -32,21 +33,19 @@ def moe_forward_tilelang_shared(d_hidden, ...@@ -32,21 +33,19 @@ def moe_forward_tilelang_shared(d_hidden,
shared_W_up_shape = (dexpert, dhidden) shared_W_up_shape = (dexpert, dhidden)
shared_W_down_shape = (dhidden, dexpert) shared_W_down_shape = (dhidden, dexpert)
accum_type = "float32" accum_type = T.float32
@T.prim_func @T.prim_func
def kernel_shared( def kernel_shared(
input: T.Tensor(input_shape, dtype), # type: ignore input: T.Tensor(input_shape, dtype), # type: ignore
shared_W_gate: T.Tensor(shared_W_gate_shape, dtype), # type: ignore shared_W_gate: T.Tensor(shared_W_gate_shape, dtype), # type: ignore
shared_W_up: T.Tensor(shared_W_up_shape, dtype), # type: ignore shared_W_up: T.Tensor(shared_W_up_shape, dtype), # type: ignore
shared_W_down: T.Tensor(shared_W_down_shape, dtype), # type: ignore shared_W_down: T.Tensor(shared_W_down_shape, dtype), # type: ignore
up_logits: T.Tensor((num_tokens, dexpert), dtype), # type: ignore up_logits: T.Tensor((num_tokens, dexpert), dtype), # type: ignore
output: T.Tensor(input_shape, dtype), # type: ignore output: T.Tensor(input_shape, dtype), # type: ignore
): ):
# Step 1: Compute gate and up logits # Step 1: Compute gate and up logits
with T.Kernel( with T.Kernel(T.ceildiv(num_tokens, block_token), T.ceildiv(dexpert, block_dexpert), threads=threads) as (bx, by):
T.ceildiv(num_tokens, block_token), T.ceildiv(dexpert, block_dexpert),
threads=threads) as (bx, by):
# Split the block to shared experts and routed experts # Split the block to shared experts and routed experts
input_shared = T.alloc_fragment((block_token, block_dhidden), dtype=dtype) input_shared = T.alloc_fragment((block_token, block_dhidden), dtype=dtype)
W_gate_shared = T.alloc_shared((block_dexpert, block_dhidden), dtype=dtype) W_gate_shared = T.alloc_shared((block_dexpert, block_dhidden), dtype=dtype)
...@@ -70,16 +69,13 @@ def moe_forward_tilelang_shared(d_hidden, ...@@ -70,16 +69,13 @@ def moe_forward_tilelang_shared(d_hidden,
# Fuse with SiLU and element-wise product # Fuse with SiLU and element-wise product
for i, j in T.Parallel(block_token, block_dexpert): for i, j in T.Parallel(block_token, block_dexpert):
gate_logits_local[i, j] = gate_logits_local[i, j] * ( gate_logits_local[i, j] = gate_logits_local[i, j] * (1.0 / (1.0 + T.exp2(-gate_logits_local[i, j] * scale)))
1.0 / (1.0 + T.exp2(-gate_logits_local[i, j] * scale)))
up_logits_local[i, j] = up_logits_local[i, j] * gate_logits_local[i, j] up_logits_local[i, j] = up_logits_local[i, j] * gate_logits_local[i, j]
T.copy(up_logits_local, up_logits[bx * block_token, by * block_dexpert]) T.copy(up_logits_local, up_logits[bx * block_token, by * block_dexpert])
# Step 2: Compute down logits # Step 2: Compute down logits
with T.Kernel( with T.Kernel(T.ceildiv(num_tokens, block_token), T.ceildiv(dhidden, block_dhidden), threads=threads) as (bx, by):
T.ceildiv(num_tokens, block_token), T.ceildiv(dhidden, block_dhidden),
threads=threads) as (bx, by):
up_logits_shared = T.alloc_fragment((block_token, block_dexpert), dtype=dtype) up_logits_shared = T.alloc_fragment((block_token, block_dexpert), dtype=dtype)
W_down_shared = T.alloc_shared((block_dhidden, block_dexpert), dtype=dtype) W_down_shared = T.alloc_shared((block_dhidden, block_dexpert), dtype=dtype)
output_local = T.alloc_fragment((block_token, block_dhidden), dtype=accum_type) output_local = T.alloc_fragment((block_token, block_dhidden), dtype=accum_type)
...@@ -98,20 +94,21 @@ def moe_forward_tilelang_shared(d_hidden, ...@@ -98,20 +94,21 @@ def moe_forward_tilelang_shared(d_hidden,
@tilelang.jit(pass_configs={"tl.disable_tma_lower": True, "tl.disable_warp_specialized": True}) @tilelang.jit(pass_configs={"tl.disable_tma_lower": True, "tl.disable_warp_specialized": True})
def moe_forward_tilelang_routed(d_hidden, def moe_forward_tilelang_routed(
d_expert, d_hidden,
n_routed_experts, d_expert,
dtype, n_routed_experts,
group_sum, dtype,
group_count, group_sum,
block_token=128, group_count,
block_dhidden=128, block_token=128,
block_dexpert=128, block_dhidden=128,
threads=256, block_dexpert=128,
num_stages=1, threads=256,
k_pack=1, num_stages=1,
coalesced_width=None): k_pack=1,
coalesced_width=None,
):
scale = 1.44269504 # log2(e) scale = 1.44269504 # log2(e)
# Parameters # Parameters
...@@ -124,7 +121,7 @@ def moe_forward_tilelang_routed(d_hidden, ...@@ -124,7 +121,7 @@ def moe_forward_tilelang_routed(d_hidden,
# group_count = len(group_sizes_list) # group_count = len(group_sizes_list)
# M = sum([(group_size + block_token - 1) // block_token for group_size in group_sizes_list]) # M = sum([(group_size + block_token - 1) // block_token for group_size in group_sizes_list])
M = math.ceil(group_sum / block_token) + group_count M = math.ceil(group_sum / block_token) + group_count
accum_dtype = "float32" accum_dtype = T.float32
# Tensors: Note that input shape is reshape to (bs * seq_len * n_experts_per_token, dhidden) for grouped gemm # Tensors: Note that input shape is reshape to (bs * seq_len * n_experts_per_token, dhidden) for grouped gemm
input_shape = (group_sum, dhidden) input_shape = (group_sum, dhidden)
...@@ -132,22 +129,22 @@ def moe_forward_tilelang_routed(d_hidden, ...@@ -132,22 +129,22 @@ def moe_forward_tilelang_routed(d_hidden,
routed_expert_gate_shape = (n_routed_experts, dexpert, dhidden) routed_expert_gate_shape = (n_routed_experts, dexpert, dhidden)
routed_expert_up_shape = (n_routed_experts, dexpert, dhidden) routed_expert_up_shape = (n_routed_experts, dexpert, dhidden)
routed_expert_down_shape = (n_routed_experts, dhidden, dexpert) routed_expert_down_shape = (n_routed_experts, dhidden, dexpert)
routed_expert_weights_shape = (group_sum) routed_expert_weights_shape = group_sum
group_sizes_shape = (n_routed_experts) group_sizes_shape = n_routed_experts
@T.prim_func @T.prim_func
def kernel( def kernel(
input: T.Tensor(input_shape, dtype), # type: ignore input: T.Tensor(input_shape, dtype), # type: ignore
routed_expert_gate: T.Tensor(routed_expert_gate_shape, dtype), # type: ignore routed_expert_gate: T.Tensor(routed_expert_gate_shape, dtype), # type: ignore
routed_expert_up: T.Tensor(routed_expert_up_shape, dtype), # type: ignore routed_expert_up: T.Tensor(routed_expert_up_shape, dtype), # type: ignore
routed_expert_down: T.Tensor(routed_expert_down_shape, dtype), # type: ignore routed_expert_down: T.Tensor(routed_expert_down_shape, dtype), # type: ignore
routed_expert_weights: T.Tensor(routed_expert_weights_shape, dtype), # type: ignore routed_expert_weights: T.Tensor(routed_expert_weights_shape, dtype), # type: ignore
group_sizes: T.Tensor(group_sizes_shape, "int32"), # type: ignore group_sizes: T.Tensor(group_sizes_shape, T.int32), # type: ignore
group_offsets: T.Tensor(group_sizes_shape, "int32"), # type: ignore group_offsets: T.Tensor(group_sizes_shape, T.int32), # type: ignore
group_padded_offsets: T.Tensor(group_sizes_shape, "int32"), # type: ignore group_padded_offsets: T.Tensor(group_sizes_shape, T.int32), # type: ignore
group_idx_for_bx: T.Tensor((M,), "int32"), # type: ignore group_idx_for_bx: T.Tensor((M,), T.int32), # type: ignore
up_logits: T.Tensor(intermediate_shape, dtype), # type: ignore up_logits: T.Tensor(intermediate_shape, dtype), # type: ignore
output: T.Tensor(input_shape, dtype), # type: ignore output: T.Tensor(input_shape, dtype), # type: ignore
): ):
# Step 1: Compute gate and up logits # Step 1: Compute gate and up logits
with T.Kernel(M, T.ceildiv(dexpert, block_dexpert), threads=threads) as (bx, by): with T.Kernel(M, T.ceildiv(dexpert, block_dexpert), threads=threads) as (bx, by):
...@@ -158,8 +155,8 @@ def moe_forward_tilelang_routed(d_hidden, ...@@ -158,8 +155,8 @@ def moe_forward_tilelang_routed(d_hidden,
gate_logits_local = T.alloc_fragment((block_token, block_dexpert), dtype=accum_dtype) gate_logits_local = T.alloc_fragment((block_token, block_dexpert), dtype=accum_dtype)
up_logits_local = T.alloc_fragment((block_token, block_dexpert), dtype=accum_dtype) up_logits_local = T.alloc_fragment((block_token, block_dexpert), dtype=accum_dtype)
cur_group_idx = T.alloc_local([1], "int32") cur_group_idx = T.alloc_local([1], T.int32)
cur_group_size = T.alloc_local([1], "int32") cur_group_size = T.alloc_local([1], T.int32)
T.use_swizzle(10, enable=True) T.use_swizzle(10, enable=True)
...@@ -168,48 +165,37 @@ def moe_forward_tilelang_routed(d_hidden, ...@@ -168,48 +165,37 @@ def moe_forward_tilelang_routed(d_hidden,
cur_group_idx[0] = group_idx_for_bx[bx] cur_group_idx[0] = group_idx_for_bx[bx]
cur_group_size[0] = group_sizes[cur_group_idx[0]] cur_group_size[0] = group_sizes[cur_group_idx[0]]
m_start = m_start_padded - group_padded_offsets[cur_group_idx[0]] + group_offsets[ m_start = m_start_padded - group_padded_offsets[cur_group_idx[0]] + group_offsets[cur_group_idx[0]]
cur_group_idx[0]] actual_rows = T.max(0, T.min(block_token, cur_group_size[0] - (m_start_padded - group_padded_offsets[cur_group_idx[0]])))
actual_rows = T.max(
0,
T.min(block_token, cur_group_size[0] -
(m_start_padded - group_padded_offsets[cur_group_idx[0]])))
T.clear(gate_logits_local) T.clear(gate_logits_local)
T.clear(up_logits_local) T.clear(up_logits_local)
for k in T.Pipelined(T.ceildiv(dhidden, block_dhidden), num_stages=num_stages): for k in T.Pipelined(T.ceildiv(dhidden, block_dhidden), num_stages=num_stages):
T.copy( T.copy(
input[m_start:m_start + block_token, k * block_dhidden:(k + 1) * block_dhidden], input[m_start : m_start + block_token, k * block_dhidden : (k + 1) * block_dhidden],
input_shared, input_shared,
coalesced_width=coalesced_width) coalesced_width=coalesced_width,
)
T.copy( T.copy(
routed_expert_gate[cur_group_idx[0], routed_expert_gate[
by * block_dexpert:(by + 1) * block_dexpert, cur_group_idx[0], by * block_dexpert : (by + 1) * block_dexpert, k * block_dhidden : (k + 1) * block_dhidden
k * block_dhidden:(k + 1) * block_dhidden], ],
routed_expert_gate_shared,
coalesced_width=coalesced_width)
T.gemm(
input_shared,
routed_expert_gate_shared, routed_expert_gate_shared,
gate_logits_local, coalesced_width=coalesced_width,
k_pack=k_pack, )
transpose_B=True) T.gemm(input_shared, routed_expert_gate_shared, gate_logits_local, k_pack=k_pack, transpose_B=True)
T.copy( T.copy(
routed_expert_up[cur_group_idx[0], by * block_dexpert:(by + 1) * block_dexpert, routed_expert_up[
k * block_dhidden:(k + 1) * block_dhidden], cur_group_idx[0], by * block_dexpert : (by + 1) * block_dexpert, k * block_dhidden : (k + 1) * block_dhidden
],
routed_expert_up_shared, routed_expert_up_shared,
coalesced_width=coalesced_width) coalesced_width=coalesced_width,
T.gemm( )
input_shared, T.gemm(input_shared, routed_expert_up_shared, up_logits_local, k_pack=k_pack, transpose_B=True)
routed_expert_up_shared,
up_logits_local,
k_pack=k_pack,
transpose_B=True)
for i, j in T.Parallel(block_token, block_dexpert): for i, j in T.Parallel(block_token, block_dexpert):
gate_logits_local[i, j] = gate_logits_local[i, j] * ( gate_logits_local[i, j] = gate_logits_local[i, j] * (1.0 / (1.0 + T.exp2(-gate_logits_local[i, j] * scale)))
1.0 / (1.0 + T.exp2(-gate_logits_local[i, j] * scale)))
up_logits_local[i, j] = up_logits_local[i, j] * gate_logits_local[i, j] up_logits_local[i, j] = up_logits_local[i, j] * gate_logits_local[i, j]
for i, j in T.Parallel(block_token, block_dexpert): for i, j in T.Parallel(block_token, block_dexpert):
...@@ -222,8 +208,8 @@ def moe_forward_tilelang_routed(d_hidden, ...@@ -222,8 +208,8 @@ def moe_forward_tilelang_routed(d_hidden,
routed_expert_down_shared = T.alloc_shared((block_dhidden, block_dexpert), dtype=dtype) routed_expert_down_shared = T.alloc_shared((block_dhidden, block_dexpert), dtype=dtype)
output_local = T.alloc_fragment((block_token, block_dhidden), dtype=accum_dtype) output_local = T.alloc_fragment((block_token, block_dhidden), dtype=accum_dtype)
cur_group_idx = T.alloc_local([1], "int32") cur_group_idx = T.alloc_local([1], T.int32)
cur_group_size = T.alloc_local([1], "int32") cur_group_size = T.alloc_local([1], T.int32)
T.use_swizzle(10, enable=True) T.use_swizzle(10, enable=True)
...@@ -232,50 +218,35 @@ def moe_forward_tilelang_routed(d_hidden, ...@@ -232,50 +218,35 @@ def moe_forward_tilelang_routed(d_hidden,
cur_group_idx[0] = group_idx_for_bx[bx] cur_group_idx[0] = group_idx_for_bx[bx]
cur_group_size[0] = group_sizes[cur_group_idx[0]] cur_group_size[0] = group_sizes[cur_group_idx[0]]
m_start = m_start_padded - group_padded_offsets[cur_group_idx[0]] + group_offsets[ m_start = m_start_padded - group_padded_offsets[cur_group_idx[0]] + group_offsets[cur_group_idx[0]]
cur_group_idx[0]] actual_rows = T.max(0, T.min(block_token, cur_group_size[0] - (m_start_padded - group_padded_offsets[cur_group_idx[0]])))
actual_rows = T.max(
0,
T.min(block_token, cur_group_size[0] -
(m_start_padded - group_padded_offsets[cur_group_idx[0]])))
T.clear(output_local) T.clear(output_local)
for k in T.Pipelined(T.ceildiv(dexpert, block_dexpert), num_stages=num_stages): for k in T.Pipelined(T.ceildiv(dexpert, block_dexpert), num_stages=num_stages):
T.copy( T.copy(
up_logits[m_start:m_start + block_token, up_logits[m_start : m_start + block_token, k * block_dexpert : (k + 1) * block_dexpert],
k * block_dexpert:(k + 1) * block_dexpert],
up_logits_shared, up_logits_shared,
coalesced_width=coalesced_width) coalesced_width=coalesced_width,
)
T.copy( T.copy(
routed_expert_down[cur_group_idx[0], routed_expert_down[
by * block_dhidden:(by + 1) * block_dhidden, cur_group_idx[0], by * block_dhidden : (by + 1) * block_dhidden, k * block_dexpert : (k + 1) * block_dexpert
k * block_dexpert:(k + 1) * block_dexpert], ],
routed_expert_down_shared,
coalesced_width=coalesced_width)
T.gemm(
up_logits_shared,
routed_expert_down_shared, routed_expert_down_shared,
output_local, coalesced_width=coalesced_width,
k_pack=k_pack, )
transpose_B=True) T.gemm(up_logits_shared, routed_expert_down_shared, output_local, k_pack=k_pack, transpose_B=True)
for i, j in T.Parallel(block_token, block_dhidden): for i, j in T.Parallel(block_token, block_dhidden):
if i < actual_rows: if i < actual_rows:
output[m_start + i, by * block_dhidden + output[m_start + i, by * block_dhidden + j] = output_local[i, j] * routed_expert_weights[m_start + i]
j] = output_local[i, j] * routed_expert_weights[m_start + i]
return kernel return kernel
class Expert(nn.Module): class Expert(nn.Module):
def __init__(self, config: Dict, gate: torch.Tensor, up: torch.Tensor, down: torch.Tensor, d_expert: Optional[int] = None):
def __init__(self,
config: Dict,
gate: torch.Tensor,
up: torch.Tensor,
down: torch.Tensor,
d_expert: Optional[int] = None):
super().__init__() super().__init__()
self.config = config self.config = config
self.act_fn = nn.SiLU() self.act_fn = nn.SiLU()
...@@ -294,14 +265,13 @@ class Expert(nn.Module): ...@@ -294,14 +265,13 @@ class Expert(nn.Module):
class MoEGate(nn.Module): class MoEGate(nn.Module):
def __init__(self, config: Dict, weights: Dict): def __init__(self, config: Dict, weights: Dict):
super().__init__() super().__init__()
self.top_k: int = config["n_experts_per_token"] self.top_k: int = config["n_experts_per_token"]
self.num_experts: int = config["n_routed_experts"] self.num_experts: int = config["n_routed_experts"]
self.d_hidden: int = config["d_hidden"] self.d_hidden: int = config["d_hidden"]
self.W_g_weight = weights['router.weight'].t() self.W_g_weight = weights["router.weight"].t()
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
logits = x @ self.W_g_weight logits = x @ self.W_g_weight
...@@ -312,76 +282,69 @@ class MoEGate(nn.Module): ...@@ -312,76 +282,69 @@ class MoEGate(nn.Module):
class MoE(nn.Module): class MoE(nn.Module):
def __init__(
def __init__(self, self, config: Dict, shared_kernel: tilelang.JITKernel, routed_kernel: tilelang.JITKernel, weights: Dict, padding_M: int = 128
config: Dict, ):
shared_kernel: tilelang.JITKernel,
routed_kernel: tilelang.JITKernel,
weights: Dict,
padding_M: int = 128):
super().__init__() super().__init__()
self.config = config self.config = config
self.shared_kernel = shared_kernel self.shared_kernel = shared_kernel
self.routed_kernel = routed_kernel self.routed_kernel = routed_kernel
self.padding_M = padding_M self.padding_M = padding_M
self.experts = nn.ModuleList([ self.experts = nn.ModuleList(
Expert( [
config, Expert(
gate=weights[f'experts.{i}.0.weight'], config,
up=weights[f'experts.{i}.1.weight'], gate=weights[f"experts.{i}.0.weight"],
down=weights[f'experts.{i}.2.weight']) for i in range(config["n_routed_experts"]) up=weights[f"experts.{i}.1.weight"],
]) down=weights[f"experts.{i}.2.weight"],
)
for i in range(config["n_routed_experts"])
]
)
self.device = torch.device("cuda") self.device = torch.device("cuda")
self.gating_network = MoEGate(config, weights).to(self.device) self.gating_network = MoEGate(config, weights).to(self.device)
shared_expert_dim = config["d_expert"] * config["n_shared_experts"] shared_expert_dim = config["d_expert"] * config["n_shared_experts"]
self.shared_expert = Expert( self.shared_expert = Expert(
config=config, config=config,
gate=weights['shared_experts.0.weight'], gate=weights["shared_experts.0.weight"],
up=weights['shared_experts.1.weight'], up=weights["shared_experts.1.weight"],
down=weights['shared_experts.2.weight'], down=weights["shared_experts.2.weight"],
d_expert=shared_expert_dim).to(self.device) d_expert=shared_expert_dim,
).to(self.device)
self.expert_cache = torch.zeros( self.expert_cache = torch.zeros(
(config["batch_size"] * config["seq_len"], config["d_hidden"]), (config["batch_size"] * config["seq_len"], config["d_hidden"]), dtype=torch.float16, device=self.device
dtype=torch.float16, )
device=self.device) self.stacked_expert_w_gate = torch.stack([expert.W_gate_weight for expert in self.experts], dim=0)
self.stacked_expert_w_gate = torch.stack([expert.W_gate_weight for expert in self.experts], self.stacked_expert_w_up = torch.stack([expert.W_up_weight for expert in self.experts], dim=0)
dim=0) self.stacked_expert_w_down = torch.stack([expert.W_down_weight for expert in self.experts], dim=0)
self.stacked_expert_w_up = torch.stack([expert.W_up_weight for expert in self.experts],
dim=0)
self.stacked_expert_w_down = torch.stack([expert.W_down_weight for expert in self.experts],
dim=0)
self.stacked_expert_tokens = torch.empty( self.stacked_expert_tokens = torch.empty(
(config["batch_size"] * config["seq_len"] * config["n_experts_per_token"], (config["batch_size"] * config["seq_len"] * config["n_experts_per_token"], self.config["d_hidden"]),
self.config["d_hidden"]),
dtype=torch.float16, dtype=torch.float16,
device=self.device) device=self.device,
)
self.stacked_expert_weights = torch.empty( self.stacked_expert_weights = torch.empty(
(config["batch_size"] * config["seq_len"] * config["n_experts_per_token"]), (config["batch_size"] * config["seq_len"] * config["n_experts_per_token"]), dtype=torch.float16, device=self.device
dtype=torch.float16, )
device=self.device)
self.stacked_expert_tokens_idxs = torch.empty( self.stacked_expert_tokens_idxs = torch.empty(
(config["batch_size"] * config["seq_len"] * config["n_experts_per_token"]), (config["batch_size"] * config["seq_len"] * config["n_experts_per_token"]), dtype=torch.int64, device=self.device
dtype=torch.int64, )
device=self.device)
self.up_logits_shared = torch.empty( self.up_logits_shared = torch.empty(
(config["batch_size"] * config["seq_len"], self.config["d_expert"]), (config["batch_size"] * config["seq_len"], self.config["d_expert"]), dtype=torch.float16, device=self.device
dtype=torch.float16, )
device=self.device)
self.expert_output_shared = torch.empty( self.expert_output_shared = torch.empty(
(config["batch_size"] * config["seq_len"], self.config["d_hidden"]), (config["batch_size"] * config["seq_len"], self.config["d_hidden"]), dtype=torch.float16, device=self.device
dtype=torch.float16, )
device=self.device)
self.up_logits_routed = torch.empty( self.up_logits_routed = torch.empty(
(config["batch_size"] * config["seq_len"] * config["n_experts_per_token"], (config["batch_size"] * config["seq_len"] * config["n_experts_per_token"], self.config["d_expert"]),
self.config["d_expert"]),
dtype=torch.float16, dtype=torch.float16,
device=self.device) device=self.device,
)
self.expert_output_routed = torch.empty( self.expert_output_routed = torch.empty(
(config["batch_size"] * config["seq_len"] * config["n_experts_per_token"], (config["batch_size"] * config["seq_len"] * config["n_experts_per_token"], self.config["d_hidden"]),
self.config["d_hidden"]),
dtype=torch.float16, dtype=torch.float16,
device=self.device) device=self.device,
)
@torch.no_grad() @torch.no_grad()
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
...@@ -413,22 +376,20 @@ class MoE(nn.Module): ...@@ -413,22 +376,20 @@ class MoE(nn.Module):
self.stacked_expert_tokens[start_idx:end_idx] = expert_tokens self.stacked_expert_tokens[start_idx:end_idx] = expert_tokens
self.stacked_expert_tokens_idxs[start_idx:end_idx] = exp_token_idxs self.stacked_expert_tokens_idxs[start_idx:end_idx] = exp_token_idxs
self.stacked_expert_weights[start_idx:end_idx] = flat_expert_weights[ self.stacked_expert_weights[start_idx:end_idx] = flat_expert_weights[idxs[start_idx:end_idx]]
idxs[start_idx:end_idx]]
group_sizes = torch.tensor(counts, dtype=torch.int32, device=self.device) group_sizes = torch.tensor(counts, dtype=torch.int32, device=self.device)
group_offset = torch.tensor( group_offset = torch.tensor(tokens_per_expert - counts, dtype=torch.int32, device=self.device)
tokens_per_expert - counts, dtype=torch.int32, device=self.device)
group_padded_offsets = [0 for _ in range(len(group_sizes))] group_padded_offsets = [0 for _ in range(len(group_sizes))]
for i in range(1, len(group_sizes)): for i in range(1, len(group_sizes)):
group_padded_offsets[i] = group_padded_offsets[i - 1] + math.ceil( group_padded_offsets[i] = group_padded_offsets[i - 1] + math.ceil((counts[i - 1] + 1) / self.padding_M) * self.padding_M
(counts[i - 1] + 1) / self.padding_M) * self.padding_M
block_token = 128 block_token = 128
M = math.ceil( M = (
self.config["batch_size"] * self.config["seq_len"] * math.ceil(self.config["batch_size"] * self.config["seq_len"] * self.config["n_experts_per_token"] / block_token)
self.config["n_experts_per_token"] / block_token) + self.config["n_routed_experts"] + self.config["n_routed_experts"]
)
group_idx_for_bx = [0 for _ in range(M)] group_idx_for_bx = [0 for _ in range(M)]
for bx in range(M): for bx in range(M):
...@@ -437,8 +398,7 @@ class MoE(nn.Module): ...@@ -437,8 +398,7 @@ class MoE(nn.Module):
if m_start_padded >= group_padded_offsets[i]: if m_start_padded >= group_padded_offsets[i]:
group_idx_for_bx[bx] = i group_idx_for_bx[bx] = i
group_padded_offsets = torch.tensor( group_padded_offsets = torch.tensor(group_padded_offsets, dtype=torch.int32, device=self.device)
group_padded_offsets, dtype=torch.int32, device=self.device)
group_idx_for_bx = torch.tensor(group_idx_for_bx, dtype=torch.int32, device=self.device) group_idx_for_bx = torch.tensor(group_idx_for_bx, dtype=torch.int32, device=self.device)
# Multi-stream execution # Multi-stream execution
...@@ -448,11 +408,19 @@ class MoE(nn.Module): ...@@ -448,11 +408,19 @@ class MoE(nn.Module):
with torch.cuda.stream(routed_stream): with torch.cuda.stream(routed_stream):
# Tilelang version: Grouped GEMM # Tilelang version: Grouped GEMM
self.routed_kernel(self.stacked_expert_tokens, self.stacked_expert_w_gate, self.routed_kernel(
self.stacked_expert_w_up, self.stacked_expert_w_down, self.stacked_expert_tokens,
self.stacked_expert_weights, group_sizes, group_offset, self.stacked_expert_w_gate,
group_padded_offsets, group_idx_for_bx, self.up_logits_routed, self.stacked_expert_w_up,
self.expert_output_routed) self.stacked_expert_w_down,
self.stacked_expert_weights,
group_sizes,
group_offset,
group_padded_offsets,
group_idx_for_bx,
self.up_logits_routed,
self.expert_output_routed,
)
# Scatter reduce # Scatter reduce
self.expert_cache = torch.scatter_reduce( self.expert_cache = torch.scatter_reduce(
...@@ -460,14 +428,19 @@ class MoE(nn.Module): ...@@ -460,14 +428,19 @@ class MoE(nn.Module):
0, 0,
self.stacked_expert_tokens_idxs.view(-1, 1).repeat(1, x_flat.shape[-1]), self.stacked_expert_tokens_idxs.view(-1, 1).repeat(1, x_flat.shape[-1]),
self.expert_output_routed, self.expert_output_routed,
reduce='sum') reduce="sum",
)
routed_output = self.expert_cache.view(*orig_shape) routed_output = self.expert_cache.view(*orig_shape)
with torch.cuda.stream(shared_stream): with torch.cuda.stream(shared_stream):
self.shared_kernel(
self.shared_kernel(x_flat, self.shared_expert.W_gate_weight, x_flat,
self.shared_expert.W_up_weight, self.shared_expert.W_down_weight, self.shared_expert.W_gate_weight,
self.up_logits_shared, self.expert_output_shared) self.shared_expert.W_up_weight,
self.shared_expert.W_down_weight,
self.up_logits_shared,
self.expert_output_shared,
)
shared_output = self.expert_output_shared.view(*orig_shape) shared_output = self.expert_output_shared.view(*orig_shape)
torch.cuda.synchronize() torch.cuda.synchronize()
...@@ -491,14 +464,15 @@ def custom_kernel(data: Tuple[torch.Tensor, Dict, Dict]) -> torch.Tensor: ...@@ -491,14 +464,15 @@ def custom_kernel(data: Tuple[torch.Tensor, Dict, Dict]) -> torch.Tensor:
""" """
input_tensor, weights, config = data input_tensor, weights, config = data
dtype_str = "float16" dtype_str = T.float16
shared_kernel = moe_forward_tilelang_shared( shared_kernel = moe_forward_tilelang_shared(
config["d_hidden"], config["d_hidden"],
config["d_expert"], config["d_expert"],
config["n_shared_experts"], config["n_shared_experts"],
dtype=dtype_str, dtype=dtype_str,
num_tokens=config["batch_size"] * config["seq_len"]) num_tokens=config["batch_size"] * config["seq_len"],
)
routed_kernel = moe_forward_tilelang_routed( routed_kernel = moe_forward_tilelang_routed(
config["d_hidden"], config["d_hidden"],
config["d_expert"], config["d_expert"],
...@@ -512,7 +486,8 @@ def custom_kernel(data: Tuple[torch.Tensor, Dict, Dict]) -> torch.Tensor: ...@@ -512,7 +486,8 @@ def custom_kernel(data: Tuple[torch.Tensor, Dict, Dict]) -> torch.Tensor:
threads=256, threads=256,
num_stages=1, num_stages=1,
k_pack=1, k_pack=1,
coalesced_width=2) coalesced_width=2,
)
moe = MoE(config, shared_kernel, routed_kernel, weights, padding_M=128) moe = MoE(config, shared_kernel, routed_kernel, weights, padding_M=128)
...@@ -521,13 +496,7 @@ def custom_kernel(data: Tuple[torch.Tensor, Dict, Dict]) -> torch.Tensor: ...@@ -521,13 +496,7 @@ def custom_kernel(data: Tuple[torch.Tensor, Dict, Dict]) -> torch.Tensor:
return output return output
def main(d_hidden=7168, def main(d_hidden=7168, d_expert=2048, n_routed_experts=8, n_shared_experts=1, n_experts_per_token=4, batch_size=1, seq_len=8192):
d_expert=2048,
n_routed_experts=8,
n_shared_experts=1,
n_experts_per_token=4,
batch_size=1,
seq_len=8192):
config = { config = {
"dhidden": d_hidden, "dhidden": d_hidden,
"dexpert": d_expert, "dexpert": d_expert,
...@@ -536,7 +505,7 @@ def main(d_hidden=7168, ...@@ -536,7 +505,7 @@ def main(d_hidden=7168,
"nexpertspertoken": n_experts_per_token, "nexpertspertoken": n_experts_per_token,
"bs": batch_size, "bs": batch_size,
"seqlen": seq_len, "seqlen": seq_len,
"seed": 81394 "seed": 81394,
} }
data = generate_input(**config) data = generate_input(**config)
......
...@@ -6,7 +6,6 @@ from typing import Dict, Tuple, Optional ...@@ -6,7 +6,6 @@ from typing import Dict, Tuple, Optional
# Reference code in PyTorch # Reference code in PyTorch
class ExpertTorch(nn.Module): class ExpertTorch(nn.Module):
def __init__(self, config: Dict, d_expert: Optional[int] = None): def __init__(self, config: Dict, d_expert: Optional[int] = None):
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -25,7 +24,6 @@ class ExpertTorch(nn.Module): ...@@ -25,7 +24,6 @@ class ExpertTorch(nn.Module):
class MoEGateTorch(nn.Module): class MoEGateTorch(nn.Module):
def __init__(self, config: Dict): def __init__(self, config: Dict):
super().__init__() super().__init__()
self.top_k: int = config["n_experts_per_token"] self.top_k: int = config["n_experts_per_token"]
...@@ -43,12 +41,10 @@ class MoEGateTorch(nn.Module): ...@@ -43,12 +41,10 @@ class MoEGateTorch(nn.Module):
class MoETorch(nn.Module): class MoETorch(nn.Module):
def __init__(self, config: Dict): def __init__(self, config: Dict):
super().__init__() super().__init__()
self.config = config self.config = config
self.experts = nn.ModuleList( self.experts = nn.ModuleList([ExpertTorch(config) for _ in range(config["n_routed_experts"])])
[ExpertTorch(config) for _ in range(config["n_routed_experts"])])
self.gating_network = MoEGateTorch(config) self.gating_network = MoEGateTorch(config)
shared_expert_dim = config["d_expert"] * config["n_shared_experts"] shared_expert_dim = config["d_expert"] * config["n_shared_experts"]
self.shared_expert = ExpertTorch(config=config, d_expert=shared_expert_dim) self.shared_expert = ExpertTorch(config=config, d_expert=shared_expert_dim)
...@@ -67,8 +63,7 @@ class MoETorch(nn.Module): ...@@ -67,8 +63,7 @@ class MoETorch(nn.Module):
return routed_output + shared_output return routed_output + shared_output
@torch.no_grad() @torch.no_grad()
def moe_infer(self, x: torch.Tensor, flat_expert_indices: torch.Tensor, def moe_infer(self, x: torch.Tensor, flat_expert_indices: torch.Tensor, flat_expert_weights: torch.Tensor) -> torch.Tensor:
flat_expert_weights: torch.Tensor) -> torch.Tensor:
expert_cache = torch.zeros_like(x) expert_cache = torch.zeros_like(x)
# test_expert_cache = torch.zeros((x.shape[0] * self.config["n_experts_per_token"], self.config["d_hidden"])) # test_expert_cache = torch.zeros((x.shape[0] * self.config["n_experts_per_token"], self.config["d_hidden"]))
# test_expert_tokens = torch.zeros((x.shape[0] * self.config["n_experts_per_token"], self.config["d_hidden"])) # test_expert_tokens = torch.zeros((x.shape[0] * self.config["n_experts_per_token"], self.config["d_hidden"]))
...@@ -91,8 +86,7 @@ class MoETorch(nn.Module): ...@@ -91,8 +86,7 @@ class MoETorch(nn.Module):
expert_out = expert(expert_tokens) expert_out = expert(expert_tokens)
expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]]) expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])
expert_cache.scatter_reduce_( expert_cache.scatter_reduce_(0, exp_token_idxs.view(-1, 1).repeat(1, x.shape[-1]), expert_out, reduce="sum")
0, exp_token_idxs.view(-1, 1).repeat(1, x.shape[-1]), expert_out, reduce='sum')
return expert_cache return expert_cache
...@@ -116,21 +110,21 @@ def ref_kernel(data: Tuple[torch.Tensor, Dict, Dict]) -> torch.Tensor: ...@@ -116,21 +110,21 @@ def ref_kernel(data: Tuple[torch.Tensor, Dict, Dict]) -> torch.Tensor:
moe = MoETorch(config) moe = MoETorch(config)
# Fill in the given weights of the model # Fill in the given weights of the model
moe.gating_network.W_g.weight = nn.Parameter(weights['router.weight']) moe.gating_network.W_g.weight = nn.Parameter(weights["router.weight"])
for i in range(num_experts): for i in range(num_experts):
gate_proj_weight = weights[f'experts.{i}.0.weight'] gate_proj_weight = weights[f"experts.{i}.0.weight"]
up_proj_weight = weights[f'experts.{i}.1.weight'] up_proj_weight = weights[f"experts.{i}.1.weight"]
down_proj_weight = weights[f'experts.{i}.2.weight'] down_proj_weight = weights[f"experts.{i}.2.weight"]
# Transpose weights to match expected shape for nn.Linear # Transpose weights to match expected shape for nn.Linear
moe.experts[i].W_gate.weight = nn.Parameter(gate_proj_weight.t()) moe.experts[i].W_gate.weight = nn.Parameter(gate_proj_weight.t())
moe.experts[i].W_up.weight = nn.Parameter(up_proj_weight.t()) moe.experts[i].W_up.weight = nn.Parameter(up_proj_weight.t())
moe.experts[i].W_down.weight = nn.Parameter(down_proj_weight.t()) moe.experts[i].W_down.weight = nn.Parameter(down_proj_weight.t())
moe.shared_expert.W_gate.weight = nn.Parameter(weights['shared_experts.0.weight'].t()) moe.shared_expert.W_gate.weight = nn.Parameter(weights["shared_experts.0.weight"].t())
moe.shared_expert.W_up.weight = nn.Parameter(weights['shared_experts.1.weight'].t()) moe.shared_expert.W_up.weight = nn.Parameter(weights["shared_experts.1.weight"].t())
moe.shared_expert.W_down.weight = nn.Parameter(weights['shared_experts.2.weight'].t()) moe.shared_expert.W_down.weight = nn.Parameter(weights["shared_experts.2.weight"].t())
output = moe(input_tensor) output = moe(input_tensor)
...@@ -140,10 +134,9 @@ def ref_kernel(data: Tuple[torch.Tensor, Dict, Dict]) -> torch.Tensor: ...@@ -140,10 +134,9 @@ def ref_kernel(data: Tuple[torch.Tensor, Dict, Dict]) -> torch.Tensor:
# Input generation for the reference code # Input generation for the reference code
def generate_input(dhidden: int, dexpert: int, nroutedexperts: int, nsharedexperts: int, def generate_input(
nexpertspertoken: int, bs: int, seqlen: int, dhidden: int, dexpert: int, nroutedexperts: int, nsharedexperts: int, nexpertspertoken: int, bs: int, seqlen: int, seed: int
seed: int) -> Tuple[torch.Tensor, Dict, Dict]: ) -> Tuple[torch.Tensor, Dict, Dict]:
# Really dumb but for now _ isn't parsing correctly. # Really dumb but for now _ isn't parsing correctly.
d_hidden = dhidden d_hidden = dhidden
d_expert = dexpert d_expert = dexpert
...@@ -163,50 +156,40 @@ def generate_input(dhidden: int, dexpert: int, nroutedexperts: int, nsharedexper ...@@ -163,50 +156,40 @@ def generate_input(dhidden: int, dexpert: int, nroutedexperts: int, nsharedexper
"seq_len": seq_len, "seq_len": seq_len,
} }
gen = torch.Generator(device='cuda') gen = torch.Generator(device="cuda")
gen.manual_seed(seed) gen.manual_seed(seed)
num_experts = n_routed_experts num_experts = n_routed_experts
expert_dim = d_expert expert_dim = d_expert
weights = {} weights = {}
input_tensor = torch.randn((batch_size, seq_len, d_hidden), input_tensor = torch.randn((batch_size, seq_len, d_hidden), device="cuda", dtype=torch.float16, generator=gen).contiguous()
device='cuda',
dtype=torch.float16,
generator=gen).contiguous()
# Initialize router weights # Initialize router weights
weights['router.weight'] = torch.randn( weights["router.weight"] = torch.randn((num_experts, d_hidden), device="cuda", dtype=torch.float16, generator=gen) / math.sqrt(d_hidden)
(num_experts, d_hidden), device="cuda", dtype=torch.float16,
generator=gen) / math.sqrt(d_hidden)
for i in range(num_experts): for i in range(num_experts):
weights[f'experts.{i}.0.weight'] = torch.randn( weights[f"experts.{i}.0.weight"] = torch.randn(
(d_hidden, expert_dim), device='cuda', dtype=torch.float16, (d_hidden, expert_dim), device="cuda", dtype=torch.float16, generator=gen
generator=gen) / math.sqrt(expert_dim) ) / math.sqrt(expert_dim)
weights[f'experts.{i}.1.weight'] = torch.randn( weights[f"experts.{i}.1.weight"] = torch.randn(
(d_hidden, expert_dim), device='cuda', dtype=torch.float16, (d_hidden, expert_dim), device="cuda", dtype=torch.float16, generator=gen
generator=gen) / math.sqrt(expert_dim) ) / math.sqrt(expert_dim)
weights[f'experts.{i}.2.weight'] = torch.randn( weights[f"experts.{i}.2.weight"] = torch.randn(
(expert_dim, d_hidden), device='cuda', dtype=torch.float16, (expert_dim, d_hidden), device="cuda", dtype=torch.float16, generator=gen
generator=gen) / math.sqrt(d_hidden) ) / math.sqrt(d_hidden)
weights['shared_experts.0.weight'] = torch.randn( weights["shared_experts.0.weight"] = torch.randn(
(d_hidden, expert_dim * n_shared_experts), (d_hidden, expert_dim * n_shared_experts), device="cuda", dtype=torch.float16, generator=gen
device='cuda', ) / math.sqrt(expert_dim * n_shared_experts)
dtype=torch.float16, weights["shared_experts.1.weight"] = torch.randn(
generator=gen) / math.sqrt(expert_dim * n_shared_experts) (d_hidden, expert_dim * n_shared_experts), device="cuda", dtype=torch.float16, generator=gen
weights['shared_experts.1.weight'] = torch.randn( ) / math.sqrt(expert_dim * n_shared_experts)
(d_hidden, expert_dim * n_shared_experts), weights["shared_experts.2.weight"] = torch.randn(
device='cuda', (expert_dim * n_shared_experts, d_hidden), device="cuda", dtype=torch.float16, generator=gen
dtype=torch.float16, ) / math.sqrt(d_hidden)
generator=gen) / math.sqrt(expert_dim * n_shared_experts)
weights['shared_experts.2.weight'] = torch.randn((expert_dim * n_shared_experts, d_hidden),
device='cuda',
dtype=torch.float16,
generator=gen) / math.sqrt(d_hidden)
return (input_tensor, weights, config) return (input_tensor, weights, config)
......
...@@ -4,13 +4,8 @@ import example_fusedmoe_tilelang ...@@ -4,13 +4,8 @@ import example_fusedmoe_tilelang
def test_example_fusedmoe_tilelang(): def test_example_fusedmoe_tilelang():
example_fusedmoe_tilelang.main( example_fusedmoe_tilelang.main(
d_hidden=1024, d_hidden=1024, d_expert=256, n_routed_experts=8, n_shared_experts=1, n_experts_per_token=4, batch_size=1, seq_len=1024
d_expert=256, )
n_routed_experts=8,
n_shared_experts=1,
n_experts_per_token=4,
batch_size=1,
seq_len=1024)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -12,6 +12,7 @@ print(tilelang.__file__, flush=True) ...@@ -12,6 +12,7 @@ print(tilelang.__file__, flush=True)
# sys.path.insert(0, "/home/tzj/flash-linear-attention") # sys.path.insert(0, "/home/tzj/flash-linear-attention")
try: try:
import fla import fla
print(fla.__file__, flush=True) print(fla.__file__, flush=True)
from fla.ops.common.chunk_delta_h import chunk_gated_delta_rule_bwd_dhu from fla.ops.common.chunk_delta_h import chunk_gated_delta_rule_bwd_dhu
except ImportError: except ImportError:
...@@ -24,7 +25,7 @@ import torch.nn.functional as F ...@@ -24,7 +25,7 @@ import torch.nn.functional as F
torch.random.manual_seed(0) torch.random.manual_seed(0)
# torch.set_printoptions(profile="full") # torch.set_printoptions(profile="full")
from utils import * from test_utils import assert_similar
def prepare_input( def prepare_input(
...@@ -49,6 +50,7 @@ def prepare_input( ...@@ -49,6 +50,7 @@ def prepare_input(
G = F.logsigmoid(G) G = F.logsigmoid(G)
try: try:
from fla.ops.utils.cumsum import chunk_local_cumsum from fla.ops.utils.cumsum import chunk_local_cumsum
G = chunk_local_cumsum(G, chunk_size) G = chunk_local_cumsum(G, chunk_size)
except ImportError: except ImportError:
print("fla not found, skip cumsum") print("fla not found, skip cumsum")
...@@ -125,8 +127,11 @@ def torch_chunk_gated_delta_rule_bwd_dhu( ...@@ -125,8 +127,11 @@ def torch_chunk_gated_delta_rule_bwd_dhu(
DV = dv.shape[-1] DV = dv.shape[-1]
block_S = 64 block_S = 64
BS = S // block_S BS = S // block_S
dh, dh0, dv2 = torch.empty((B, BS, H, DK, DV), dtype=output_dtype), torch.empty( dh, dh0, dv2 = (
(B, H, DK, DV), dtype=state_dtype), torch.empty((B, S, H, DV), dtype=output_dtype) torch.empty((B, BS, H, DK, DV), dtype=output_dtype),
torch.empty((B, H, DK, DV), dtype=state_dtype),
torch.empty((B, S, H, DV), dtype=output_dtype),
)
dh_tmp = torch.empty((B, H, DK, DV), dtype=accum_dtype) dh_tmp = torch.empty((B, H, DK, DV), dtype=accum_dtype)
dv_tmp = torch.empty((B, S, H, DV), dtype=accum_dtype) dv_tmp = torch.empty((B, S, H, DV), dtype=accum_dtype)
Q_tmp = torch.empty((B, S, H, DK), dtype=accum_dtype) Q_tmp = torch.empty((B, S, H, DK), dtype=accum_dtype)
...@@ -138,34 +143,30 @@ def torch_chunk_gated_delta_rule_bwd_dhu( ...@@ -138,34 +143,30 @@ def torch_chunk_gated_delta_rule_bwd_dhu(
for i_s in range(BS - 1, -1, -1): for i_s in range(BS - 1, -1, -1):
dh[:, i_s, :, :, :] = dh_tmp dh[:, i_s, :, :, :] = dh_tmp
dv_tmp = torch.matmul(K[:, i_s * block_S:(i_s + 1) * block_S, :, :].permute(0, 2, 1, 3), dv_tmp = torch.matmul(K[:, i_s * block_S : (i_s + 1) * block_S, :, :].permute(0, 2, 1, 3), dh_tmp.to(K.dtype)).permute(0, 2, 1, 3)
dh_tmp.to(K.dtype)).permute(0, 2, 1, 3)
if use_g: if use_g:
for i_bh in range(B * H): for i_bh in range(B * H):
i_b, i_h = i_bh // H, i_bh % H i_b, i_h = i_bh // H, i_bh % H
for i_s2 in range(block_S): for i_s2 in range(block_S):
if G[i_b, i_s * block_S + block_S - 1, i_h] - G[i_b, i_s * block_S + i_s2, if G[i_b, i_s * block_S + block_S - 1, i_h] - G[i_b, i_s * block_S + i_s2, i_h] <= 0:
i_h] <= 0: dv_tmp[i_b, i_s2, i_h, :] *= torch.exp(G[i_b, i_s * block_S + block_S - 1, i_h] - G[i_b, i_s * block_S + i_s2, i_h])
dv_tmp[i_b, i_s2,
i_h, :] *= torch.exp(G[i_b, i_s * block_S + block_S - 1, i_h] -
G[i_b, i_s * block_S + i_s2, i_h])
else: else:
dv_tmp[i_b, i_s2, i_h, :] = 0 dv_tmp[i_b, i_s2, i_h, :] = 0
dv_tmp += dv[:, i_s * block_S:(i_s + 1) * block_S, :, :] dv_tmp += dv[:, i_s * block_S : (i_s + 1) * block_S, :, :]
dv2[:, i_s * block_S:(i_s + 1) * block_S, :, :] = dv_tmp dv2[:, i_s * block_S : (i_s + 1) * block_S, :, :] = dv_tmp
if use_g: if use_g:
G_last = G[:, i_s * block_S + block_S - 1, :] G_last = G[:, i_s * block_S + block_S - 1, :]
for i_bh in range(B * H): for i_bh in range(B * H):
i_b, i_h = i_bh // H, i_bh % H i_b, i_h = i_bh // H, i_bh % H
dh_tmp[i_b, i_h, :, :] *= torch.exp(G_last[i_b, i_h]) dh_tmp[i_b, i_h, :, :] *= torch.exp(G_last[i_b, i_h])
Q_tmp = Q[:, i_s * block_S:(i_s + 1) * block_S, :, :] Q_tmp = Q[:, i_s * block_S : (i_s + 1) * block_S, :, :]
for i_s2 in range(block_S): for i_s2 in range(block_S):
for i_k in range(DK): for i_k in range(DK):
Q_tmp[:, i_s2, :, i_k] *= torch.exp(G[:, i_s * block_S + i_s2, :]) Q_tmp[:, i_s2, :, i_k] *= torch.exp(G[:, i_s * block_S + i_s2, :])
Q_tmp *= scale Q_tmp *= scale
W_tmp = W[:, i_s * block_S:(i_s + 1) * block_S, :, :] W_tmp = W[:, i_s * block_S : (i_s + 1) * block_S, :, :]
dO_tmp = dO[:, i_s * block_S:(i_s + 1) * block_S, :, :] dO_tmp = dO[:, i_s * block_S : (i_s + 1) * block_S, :, :]
torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cuda.matmul.allow_tf32 = True
dh_tmp += torch.matmul(Q_tmp.permute(0, 2, 3, 1), dO_tmp.permute(0, 2, 1, 3)) dh_tmp += torch.matmul(Q_tmp.permute(0, 2, 3, 1), dO_tmp.permute(0, 2, 1, 3))
...@@ -223,19 +224,19 @@ def tilelang_chunk_gated_delta_rule_bwd_dhu( ...@@ -223,19 +224,19 @@ def tilelang_chunk_gated_delta_rule_bwd_dhu(
@T.prim_func @T.prim_func
def kernel( def kernel(
# Input # Input
Q: T.Tensor(Q_shape, dtype=input_dtype), Q: T.Tensor(Q_shape, dtype=input_dtype),
K: T.Tensor(K_shape, dtype=input_dtype), K: T.Tensor(K_shape, dtype=input_dtype),
W: T.Tensor(W_shape, dtype=input_dtype), W: T.Tensor(W_shape, dtype=input_dtype),
G: T.Tensor(G_shape, dtype=gate_dtype), G: T.Tensor(G_shape, dtype=gate_dtype),
h0: T.Tensor(h0_shape, dtype=input_dtype), h0: T.Tensor(h0_shape, dtype=input_dtype),
dht: T.Tensor(dht_shape, dtype=input_dtype), dht: T.Tensor(dht_shape, dtype=input_dtype),
dO: T.Tensor(dO_shape, dtype=input_dtype), dO: T.Tensor(dO_shape, dtype=input_dtype),
dv: T.Tensor(dv_shape, dtype=input_dtype), dv: T.Tensor(dv_shape, dtype=input_dtype),
# Output # Output
dh: T.Tensor(dh_shape, dtype=output_dtype), dh: T.Tensor(dh_shape, dtype=output_dtype),
dh0: T.Tensor(dh0_shape, dtype=state_dtype), dh0: T.Tensor(dh0_shape, dtype=state_dtype),
dv2: T.Tensor(dv2_shape, dtype=output_dtype), dv2: T.Tensor(dv2_shape, dtype=output_dtype),
): ):
with T.Kernel(T.ceildiv(DV, block_DV), B * H, threads=threads) as (bv, bbh): with T.Kernel(T.ceildiv(DV, block_DV), B * H, threads=threads) as (bv, bbh):
bb, bh = bbh // H, bbh % H bb, bh = bbh // H, bbh % H
...@@ -249,13 +250,13 @@ def tilelang_chunk_gated_delta_rule_bwd_dhu( ...@@ -249,13 +250,13 @@ def tilelang_chunk_gated_delta_rule_bwd_dhu(
dv_fragment = T.alloc_fragment((block_S, block_DV), dtype=accum_dtype) dv_fragment = T.alloc_fragment((block_S, block_DV), dtype=accum_dtype)
dv_fragment_2 = T.alloc_fragment((block_S, block_DV), dtype=accum_dtype) dv_fragment_2 = T.alloc_fragment((block_S, block_DV), dtype=accum_dtype)
dO_shared = T.alloc_shared((block_S, block_DV), dtype=input_dtype) dO_shared = T.alloc_shared((block_S, block_DV), dtype=input_dtype)
dO_shared_t = T.alloc_shared((block_DV, block_S), dtype="float32") dO_shared_t = T.alloc_shared((block_DV, block_S), dtype=T.float32)
dO_fragment = T.alloc_fragment((block_S, block_DV), dtype="float32") dO_fragment = T.alloc_fragment((block_S, block_DV), dtype=T.float32)
dO_fragment_t = T.alloc_fragment((block_DV, block_S), dtype="float32") dO_fragment_t = T.alloc_fragment((block_DV, block_S), dtype=T.float32)
K_shared = T.alloc_shared((block_S, DK), dtype=input_dtype) K_shared = T.alloc_shared((block_S, DK), dtype=input_dtype)
Q_shared = T.alloc_shared((block_S, DK), dtype=input_dtype) Q_shared = T.alloc_shared((block_S, DK), dtype=input_dtype)
Q_shared_fp32 = T.alloc_shared((block_S, DK), dtype="float32") Q_shared_fp32 = T.alloc_shared((block_S, DK), dtype=T.float32)
W_shared = T.alloc_shared((block_S, DK), dtype=input_dtype) W_shared = T.alloc_shared((block_S, DK), dtype=input_dtype)
G_last_local = T.alloc_local((1), dtype=gate_dtype) G_last_local = T.alloc_local((1), dtype=gate_dtype)
...@@ -269,20 +270,22 @@ def tilelang_chunk_gated_delta_rule_bwd_dhu( ...@@ -269,20 +270,22 @@ def tilelang_chunk_gated_delta_rule_bwd_dhu(
T.use_swizzle(10) T.use_swizzle(10)
T.annotate_layout({ T.annotate_layout(
b_dh_shared: tilelang.layout.make_swizzled_layout(b_dh_shared), {
b_dh_shared_fp32: tilelang.layout.make_swizzled_layout(b_dh_shared_fp32), b_dh_shared: tilelang.layout.make_swizzled_layout(b_dh_shared),
dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), b_dh_shared_fp32: tilelang.layout.make_swizzled_layout(b_dh_shared_fp32),
dO_shared: tilelang.layout.make_swizzled_layout(dO_shared), dv_shared: tilelang.layout.make_swizzled_layout(dv_shared),
dO_shared_t: tilelang.layout.make_swizzled_layout(dO_shared_t), dO_shared: tilelang.layout.make_swizzled_layout(dO_shared),
K_shared: tilelang.layout.make_swizzled_layout(K_shared), dO_shared_t: tilelang.layout.make_swizzled_layout(dO_shared_t),
Q_shared: tilelang.layout.make_swizzled_layout(Q_shared), K_shared: tilelang.layout.make_swizzled_layout(K_shared),
Q_shared_fp32: tilelang.layout.make_swizzled_layout(Q_shared_fp32), Q_shared: tilelang.layout.make_swizzled_layout(Q_shared),
W_shared: tilelang.layout.make_swizzled_layout(W_shared), Q_shared_fp32: tilelang.layout.make_swizzled_layout(Q_shared_fp32),
}) W_shared: tilelang.layout.make_swizzled_layout(W_shared),
}
)
if use_final_state_gradient: if use_final_state_gradient:
T.copy(dht[bb, bh, 0:DK, bv * block_DV:(bv + 1) * block_DV], b_dh_shared) T.copy(dht[bb, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV], b_dh_shared)
T.copy(b_dh_shared, b_dh_fragment) T.copy(b_dh_shared, b_dh_fragment)
else: else:
T.clear(b_dh_fragment) T.clear(b_dh_fragment)
...@@ -293,17 +296,14 @@ def tilelang_chunk_gated_delta_rule_bwd_dhu( ...@@ -293,17 +296,14 @@ def tilelang_chunk_gated_delta_rule_bwd_dhu(
# Store the updated dh # Store the updated dh
T.copy(b_dh_fragment, b_dh_shared) T.copy(b_dh_fragment, b_dh_shared)
T.copy(b_dh_shared, dh[bb, i_s_inv, bh, 0:DK, bv * block_DV:(bv + 1) * block_DV]) T.copy(b_dh_shared, dh[bb, i_s_inv, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV])
# Update dv # Update dv
T.copy(K[bb, i_s_inv * block_S:(i_s_inv + 1) * block_S, bh, 0:DK], K_shared) T.copy(K[bb, i_s_inv * block_S : (i_s_inv + 1) * block_S, bh, 0:DK], K_shared)
T.gemm(K_shared, b_dh_shared, dv_fragment, clear_accum=True) T.gemm(K_shared, b_dh_shared, dv_fragment, clear_accum=True)
if use_g: if use_g:
T.copy( T.copy(G[bb, i_s_inv * block_S : (i_s_inv + 1) * block_S, bh], G_shared, disable_tma=True)
G[bb, i_s_inv * block_S:(i_s_inv + 1) * block_S, bh],
G_shared,
disable_tma=True)
T.copy(G_shared, G_fragment) T.copy(G_shared, G_fragment)
G_last_local[0] = G_shared[block_S - 1] G_last_local[0] = G_shared[block_S - 1]
G_last_local_exp[0] = T.exp(G_last_local[0]) G_last_local_exp[0] = T.exp(G_last_local[0])
...@@ -313,27 +313,22 @@ def tilelang_chunk_gated_delta_rule_bwd_dhu( ...@@ -313,27 +313,22 @@ def tilelang_chunk_gated_delta_rule_bwd_dhu(
# with T.If(G_last_local[0] - G_shared[i_s2] <= 0): # with T.If(G_last_local[0] - G_shared[i_s2] <= 0):
with T.If(G_last_local[0] - G_fragment[i_s2] <= 0): with T.If(G_last_local[0] - G_fragment[i_s2] <= 0):
with T.Then(): with T.Then():
dv_fragment[i_s2, dv_fragment[i_s2, i_v] = dv_fragment[i_s2, i_v] * G_fragment_post[i_s2]
i_v] = dv_fragment[i_s2, i_v] * G_fragment_post[i_s2]
with T.Else(): with T.Else():
dv_fragment[i_s2, i_v] = 0 dv_fragment[i_s2, i_v] = 0
T.copy( T.copy(dv[bb, i_s_inv * block_S : (i_s_inv + 1) * block_S, bh, bv * block_DV : (bv + 1) * block_DV], dv_shared)
dv[bb, i_s_inv * block_S:(i_s_inv + 1) * block_S, bh,
bv * block_DV:(bv + 1) * block_DV], dv_shared)
T.copy(dv_shared, dv_fragment_2) T.copy(dv_shared, dv_fragment_2)
for i_s2, i_v in T.Parallel(block_S, block_DV): for i_s2, i_v in T.Parallel(block_S, block_DV):
dv_fragment[i_s2, i_v] = dv_fragment[i_s2, i_v] + dv_fragment_2[i_s2, i_v] dv_fragment[i_s2, i_v] = dv_fragment[i_s2, i_v] + dv_fragment_2[i_s2, i_v]
# Store the updated dv # Store the updated dv
T.copy(dv_fragment, dv_shared) T.copy(dv_fragment, dv_shared)
T.copy( T.copy(dv_shared, dv2[bb, i_s_inv * block_S : (i_s_inv + 1) * block_S, bh, bv * block_DV : (bv + 1) * block_DV])
dv_shared, dv2[bb, i_s_inv * block_S:(i_s_inv + 1) * block_S, bh,
bv * block_DV:(bv + 1) * block_DV])
# Update dh # Update dh
T.copy(Q[bb, i_s_inv * block_S:(i_s_inv + 1) * block_S, bh, 0:DK], Q_shared) T.copy(Q[bb, i_s_inv * block_S : (i_s_inv + 1) * block_S, bh, 0:DK], Q_shared)
T.copy(W[bb, i_s_inv * block_S:(i_s_inv + 1) * block_S, bh, 0:DK], W_shared) T.copy(W[bb, i_s_inv * block_S : (i_s_inv + 1) * block_S, bh, 0:DK], W_shared)
T.clear(Q_fragment) T.clear(Q_fragment)
if use_g: if use_g:
...@@ -353,9 +348,7 @@ def tilelang_chunk_gated_delta_rule_bwd_dhu( ...@@ -353,9 +348,7 @@ def tilelang_chunk_gated_delta_rule_bwd_dhu(
for i_s2, i_k in T.Parallel(block_S, DK): for i_s2, i_k in T.Parallel(block_S, DK):
Q_fragment_t[i_k, i_s2] = Q_fragment[i_s2, i_k] Q_fragment_t[i_k, i_s2] = Q_fragment[i_s2, i_k]
T.copy( T.copy(dO[bb, i_s_inv * block_S : (i_s_inv + 1) * block_S, bh, bv * block_DV : (bv + 1) * block_DV], dO_shared)
dO[bb, i_s_inv * block_S:(i_s_inv + 1) * block_S, bh,
bv * block_DV:(bv + 1) * block_DV], dO_shared)
T.copy(dO_shared, dO_fragment) T.copy(dO_shared, dO_fragment)
for i_s2, i_v in T.Parallel(block_S, block_DV): for i_s2, i_v in T.Parallel(block_S, block_DV):
dO_fragment_t[i_v, i_s2] = dO_fragment[i_s2, i_v] dO_fragment_t[i_v, i_s2] = dO_fragment[i_s2, i_v]
...@@ -369,7 +362,7 @@ def tilelang_chunk_gated_delta_rule_bwd_dhu( ...@@ -369,7 +362,7 @@ def tilelang_chunk_gated_delta_rule_bwd_dhu(
b_dh_fragment[i_k, i_v] += b_dh_fragment_1[i_k, i_v] - b_dh_fragment_2[i_k, i_v] b_dh_fragment[i_k, i_v] += b_dh_fragment_1[i_k, i_v] - b_dh_fragment_2[i_k, i_v]
if use_initial_state: if use_initial_state:
T.copy(b_dh_fragment, dh0[bb, bh, 0:DK, bv * block_DV:(bv + 1) * block_DV]) T.copy(b_dh_fragment, dh0[bb, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV])
return kernel return kernel
...@@ -444,44 +437,61 @@ def run_test( ...@@ -444,44 +437,61 @@ def run_test(
num_stages=0, num_stages=0,
use_torch=False, use_torch=False,
): ):
Q, K, W, G, h0, dht, dO, dv = prepare_input(B, S, H, DK, DV, chunk_size, Q, K, W, G, h0, dht, dO, dv = prepare_input(
getattr(torch, input_dtype), B,
getattr(torch, output_dtype), S,
getattr(torch, accum_dtype), H,
getattr(torch, gate_dtype), DK,
getattr(torch, state_dtype)) DV,
dh_ref, dh0_ref, dv2_ref = prepare_output(B, S, H, DK, DV, chunk_size, chunk_size,
getattr(torch, output_dtype), getattr(torch, input_dtype),
getattr(torch, gate_dtype), getattr(torch, output_dtype),
getattr(torch, state_dtype)) getattr(torch, accum_dtype),
dh_tilelang, dh0_tilelang, dv2_tilelang = prepare_output(B, S, H, DK, DV, chunk_size, getattr(torch, gate_dtype),
getattr(torch, output_dtype), getattr(torch, state_dtype),
getattr(torch, gate_dtype), )
getattr(torch, state_dtype)) dh_ref, dh0_ref, dv2_ref = prepare_output(
B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype), getattr(torch, state_dtype)
)
dh_tilelang, dh0_tilelang, dv2_tilelang = prepare_output(
B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype), getattr(torch, state_dtype)
)
# fla ref # fla ref
print("fla running...", flush=True) print("fla running...", flush=True)
if use_g: if use_g:
dh_ref, dh0_ref, dv2_ref = chunk_gated_delta_rule_bwd_dhu(Q, K, W, G, h0, dht, dO, dv, dh_ref, dh0_ref, dv2_ref = chunk_gated_delta_rule_bwd_dhu(Q, K, W, G, h0, dht, dO, dv, scale)
scale)
else: else:
G = G.fill_(0) G = G.fill_(0)
dh_ref, dh0_ref, dv2_ref = chunk_gated_delta_rule_bwd_dhu(Q, K, W, G, h0, dht, dO, dv, dh_ref, dh0_ref, dv2_ref = chunk_gated_delta_rule_bwd_dhu(Q, K, W, G, h0, dht, dO, dv, scale)
scale)
# tilelang # tilelang
print("tilelang running...", flush=True) print("tilelang running...", flush=True)
kernel = tilelang_chunk_gated_delta_rule_bwd_dhu(B, S, H, DK, DV, input_dtype, output_dtype, kernel = tilelang_chunk_gated_delta_rule_bwd_dhu(
accum_dtype, gate_dtype, state_dtype, B,
chunk_size, scale, use_g, use_initial_state, S,
use_final_state_gradient, block_DV, threads, H,
num_stages) DK,
DV,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
state_dtype,
chunk_size,
scale,
use_g,
use_initial_state,
use_final_state_gradient,
block_DV,
threads,
num_stages,
)
# kernel = tilelang.compile(program) # kernel = tilelang.compile(program)
print(kernel.get_kernel_source()) print(kernel.get_kernel_source())
dh_tilelang, dh0_tilelang, dv2_tilelang = kernel(Q, K, W, G, h0, dht, dO, dv) dh_tilelang, dh0_tilelang, dv2_tilelang = kernel(Q, K, W, G, h0, dht, dO, dv)
fla_time = do_bench( fla_time = do_bench(chunk_gated_delta_rule_bwd_dhu, Q, K, W, G, h0, dht, dO, dv, scale, chunk_size=chunk_size)
chunk_gated_delta_rule_bwd_dhu, Q, K, W, G, h0, dht, dO, dv, scale, chunk_size=chunk_size)
tilelang_time = do_bench(kernel, Q, K, W, G, h0, dht, dO, dv) tilelang_time = do_bench(kernel, Q, K, W, G, h0, dht, dO, dv)
print(f"fla time: {fla_time} ms") print(f"fla time: {fla_time} ms")
...@@ -496,19 +506,47 @@ def run_test( ...@@ -496,19 +506,47 @@ def run_test(
print("torch running...", flush=True) print("torch running...", flush=True)
if use_g: if use_g:
dh_ref_torch, dh0_ref_torch, dv2_ref_torch = torch_chunk_gated_delta_rule_bwd_dhu( dh_ref_torch, dh0_ref_torch, dv2_ref_torch = torch_chunk_gated_delta_rule_bwd_dhu(
Q, K, W, G, h0, dht, dO, dv, scale, use_g, use_initial_state, Q,
use_final_state_gradient, getattr(torch, input_dtype), getattr(torch, output_dtype), K,
getattr(torch, accum_dtype), getattr(torch, W,
gate_dtype), getattr(torch, state_dtype)) G,
h0,
dht,
dO,
dv,
scale,
use_g,
use_initial_state,
use_final_state_gradient,
getattr(torch, input_dtype),
getattr(torch, output_dtype),
getattr(torch, accum_dtype),
getattr(torch, gate_dtype),
getattr(torch, state_dtype),
)
dh_ref_torch = dh_ref_torch.cuda() dh_ref_torch = dh_ref_torch.cuda()
dh0_ref_torch = dh0_ref_torch.cuda() dh0_ref_torch = dh0_ref_torch.cuda()
dv2_ref_torch = dv2_ref_torch.cuda() dv2_ref_torch = dv2_ref_torch.cuda()
else: else:
dh_ref_torch, dh0_ref_torch, dv2_ref_torch = torch_chunk_gated_delta_rule_bwd_dhu( dh_ref_torch, dh0_ref_torch, dv2_ref_torch = torch_chunk_gated_delta_rule_bwd_dhu(
Q, K, W, None, h0, dht, dO, dv, scale, use_g, use_initial_state, Q,
use_final_state_gradient, getattr(torch, input_dtype), getattr(torch, output_dtype), K,
getattr(torch, accum_dtype), getattr(torch, W,
gate_dtype), getattr(torch, state_dtype)) None,
h0,
dht,
dO,
dv,
scale,
use_g,
use_initial_state,
use_final_state_gradient,
getattr(torch, input_dtype),
getattr(torch, output_dtype),
getattr(torch, accum_dtype),
getattr(torch, gate_dtype),
getattr(torch, state_dtype),
)
dh_ref_torch = dh_ref_torch.cuda() dh_ref_torch = dh_ref_torch.cuda()
dh0_ref_torch = dh0_ref_torch.cuda() dh0_ref_torch = dh0_ref_torch.cuda()
dv2_ref_torch = dv2_ref_torch.cuda() dv2_ref_torch = dv2_ref_torch.cuda()
...@@ -554,11 +592,11 @@ def main(): ...@@ -554,11 +592,11 @@ def main():
H=8, H=8,
DK=DK, DK=DK,
DV=128, DV=128,
input_dtype="bfloat16", input_dtype=T.bfloat16,
output_dtype="bfloat16", output_dtype=T.bfloat16,
accum_dtype="float32", accum_dtype=T.float32,
gate_dtype="float32", gate_dtype=T.float32,
state_dtype="float32", state_dtype=T.float32,
chunk_size=64, chunk_size=64,
scale=DK**-0.5, scale=DK**-0.5,
use_g=True, use_g=True,
......
...@@ -3,12 +3,14 @@ ...@@ -3,12 +3,14 @@
import sys # noqa: F401 import sys # noqa: F401
import tilelang import tilelang
import tilelang.language as T import tilelang.language as T
from tilelang.autotuner import autotune
# Add your fla repository path to sys.path # Add your fla repository path to sys.path
# Currently we use the fla repository from the flash-linear-attention project at commit id f03cb3ae # Currently we use the fla repository from the flash-linear-attention project at commit id f03cb3ae
# sys.path.insert(0, "/home/tzj/flash-linear-attention") # sys.path.insert(0, "/home/tzj/flash-linear-attention")
try: try:
import fla import fla
print(fla.__file__) print(fla.__file__)
from fla.ops.common.chunk_delta_h import chunk_gated_delta_rule_fwd_h from fla.ops.common.chunk_delta_h import chunk_gated_delta_rule_fwd_h
except ImportError: except ImportError:
...@@ -19,7 +21,7 @@ import torch ...@@ -19,7 +21,7 @@ import torch
import torch.nn.functional as F import torch.nn.functional as F
from tilelang.engine.callback import register_cuda_postproc_callback # noqa: F401 from tilelang.engine.callback import register_cuda_postproc_callback # noqa: F401
from utils import * from test_utils import assert_similar
# (zhengju) We can slightly modify the generated cuda code from tilelang lowering # (zhengju) We can slightly modify the generated cuda code from tilelang lowering
# in the debug folder to make the performance better. To enable this callback, # in the debug folder to make the performance better. To enable this callback,
...@@ -55,6 +57,7 @@ def prepare_input( ...@@ -55,6 +57,7 @@ def prepare_input(
G = F.logsigmoid(G) G = F.logsigmoid(G)
try: try:
from fla.ops.utils.cumsum import chunk_local_cumsum from fla.ops.utils.cumsum import chunk_local_cumsum
G = chunk_local_cumsum(G, chunk_size) G = chunk_local_cumsum(G, chunk_size)
except ImportError: except ImportError:
print("fla not found, skip cumsum") print("fla not found, skip cumsum")
...@@ -80,7 +83,21 @@ def prepare_output( ...@@ -80,7 +83,21 @@ def prepare_output(
return h, final_state, V_new return h, final_state, V_new
@tilelang.jit(out_idx=[-3, -2, -1]) def get_configs():
import itertools
block_DK = [32, 64, 128]
block_DV = [32, 64, 128]
threads = [128, 256]
num_stages = [1, 2, 3]
_configs = list(itertools.product(block_DK, block_DV, threads, num_stages))
configs = [{"block_DK": c[0], "block_DV": c[1], "threads": c[2], "num_stages": c[3]} for c in _configs]
return configs
@autotune(configs=get_configs(), warmup=3, rep=5)
@tilelang.jit(out_idx=[-3, -2, -1], pass_configs={tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True})
def tilelang_chunk_gated_delta_rule_fwd_h( def tilelang_chunk_gated_delta_rule_fwd_h(
# task config # task config
B, B,
...@@ -94,15 +111,15 @@ def tilelang_chunk_gated_delta_rule_fwd_h( ...@@ -94,15 +111,15 @@ def tilelang_chunk_gated_delta_rule_fwd_h(
gate_dtype, gate_dtype,
state_dtype, state_dtype,
chunk_size, chunk_size,
use_g=True, use_g,
use_initial_state=True, use_initial_state,
store_final_state=True, store_final_state,
save_new_value=True, save_new_value,
# kernel config # kernel config
block_DK=64, block_DK=64,
block_DV=64, block_DV=32,
threads=256, threads=128,
num_stages=0, num_stages=1,
): ):
block_S = chunk_size block_S = chunk_size
BS = S // block_S BS = S // block_S
...@@ -118,14 +135,14 @@ def tilelang_chunk_gated_delta_rule_fwd_h( ...@@ -118,14 +135,14 @@ def tilelang_chunk_gated_delta_rule_fwd_h(
@T.prim_func @T.prim_func
def kernel( def kernel(
K: T.Tensor(K_shape, dtype=input_dtype), K: T.Tensor(K_shape, dtype=input_dtype),
W: T.Tensor(W_shape, dtype=input_dtype), W: T.Tensor(W_shape, dtype=input_dtype),
U: T.Tensor(U_shape, dtype=input_dtype), U: T.Tensor(U_shape, dtype=input_dtype),
G: T.Tensor(G_shape, dtype=gate_dtype), G: T.Tensor(G_shape, dtype=gate_dtype),
initial_state: T.Tensor(initial_state_shape, dtype=input_dtype), initial_state: T.Tensor(initial_state_shape, dtype=input_dtype),
h: T.Tensor(h_shape, dtype=output_dtype), h: T.Tensor(h_shape, dtype=output_dtype),
final_state: T.Tensor(final_state_shape, dtype=state_dtype), final_state: T.Tensor(final_state_shape, dtype=state_dtype),
V_new: T.Tensor(V_shape, dtype=output_dtype), V_new: T.Tensor(V_shape, dtype=output_dtype),
): ):
with T.Kernel(T.ceildiv(DV, block_DV), B * H, threads=threads) as (bv, bbh): with T.Kernel(T.ceildiv(DV, block_DV), B * H, threads=threads) as (bv, bbh):
bb, bh = bbh // H, bbh % H bb, bh = bbh // H, bbh % H
...@@ -143,35 +160,35 @@ def tilelang_chunk_gated_delta_rule_fwd_h( ...@@ -143,35 +160,35 @@ def tilelang_chunk_gated_delta_rule_fwd_h(
G_shared = T.alloc_shared((block_S, block_DV), dtype=gate_dtype) G_shared = T.alloc_shared((block_S, block_DV), dtype=gate_dtype)
G_fragment = T.alloc_fragment((block_S, block_DV), dtype=gate_dtype) G_fragment = T.alloc_fragment((block_S, block_DV), dtype=gate_dtype)
T.annotate_layout({ T.annotate_layout(
b_h_shared: tilelang.layout.make_swizzled_layout(b_h_shared), {
U_shared: tilelang.layout.make_swizzled_layout(U_shared), b_h_shared: tilelang.layout.make_swizzled_layout(b_h_shared),
W_shared: tilelang.layout.make_swizzled_layout(W_shared), U_shared: tilelang.layout.make_swizzled_layout(U_shared),
V_new_shared: tilelang.layout.make_swizzled_layout(V_new_shared), W_shared: tilelang.layout.make_swizzled_layout(W_shared),
K_shared: tilelang.layout.make_swizzled_layout(K_shared), V_new_shared: tilelang.layout.make_swizzled_layout(V_new_shared),
G_shared: tilelang.layout.make_swizzled_layout(G_shared), K_shared: tilelang.layout.make_swizzled_layout(K_shared),
}) G_shared: tilelang.layout.make_swizzled_layout(G_shared),
}
)
T.use_swizzle(10) T.use_swizzle(10)
if use_initial_state: if use_initial_state:
T.copy(initial_state[bb, bh, 0:DK, bv * block_DV:(bv + 1) * block_DV], b_h_shared) T.copy(initial_state[bb, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV], b_h_shared)
T.copy(b_h_shared, b_h_fragment) T.copy(b_h_shared, b_h_fragment)
else: else:
T.clear(b_h_fragment) T.clear(b_h_fragment)
for i_s in T.Pipelined(T.ceildiv(S, block_S), num_stages=num_stages): for i_s in T.Pipelined(T.ceildiv(S, block_S), num_stages=num_stages):
# Store previous result to the hidden tensor, like the epilogue # Store previous result to the hidden tensor, like the epilogue
T.copy(b_h_shared, h[bb, i_s, bh, 0:DK, bv * block_DV:(bv + 1) * block_DV]) T.copy(b_h_shared, h[bb, i_s, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV])
# Recurrence # Recurrence
T.copy(W[bb, i_s * block_S:(i_s + 1) * block_S, bh, 0:DK], W_shared) T.copy(W[bb, i_s * block_S : (i_s + 1) * block_S, bh, 0:DK], W_shared)
T.gemm(W_shared, b_h_shared, V_new_fragment, clear_accum=True) T.gemm(W_shared, b_h_shared, V_new_fragment, clear_accum=True)
# U - W * S # U - W * S
T.copy( T.copy(U[bb, i_s * block_S : (i_s + 1) * block_S, bh, bv * block_DV : (bv + 1) * block_DV], U_shared)
U[bb, i_s * block_S:(i_s + 1) * block_S, bh, bv * block_DV:(bv + 1) * block_DV],
U_shared)
T.copy(U_shared, U_fragment) T.copy(U_shared, U_fragment)
for i_s2, i_v in T.Parallel(block_S, block_DV): for i_s2, i_v in T.Parallel(block_S, block_DV):
V_new_fragment[i_s2, i_v] = -V_new_fragment[i_s2, i_v] + U_fragment[i_s2, i_v] V_new_fragment[i_s2, i_v] = -V_new_fragment[i_s2, i_v] + U_fragment[i_s2, i_v]
...@@ -179,11 +196,9 @@ def tilelang_chunk_gated_delta_rule_fwd_h( ...@@ -179,11 +196,9 @@ def tilelang_chunk_gated_delta_rule_fwd_h(
# Save V_new # Save V_new
if save_new_value: if save_new_value:
T.copy(V_new_fragment, dst=V_new_shared) T.copy(V_new_fragment, dst=V_new_shared)
T.copy( T.copy(V_new_shared, V_new[bb, i_s * block_S : (i_s + 1) * block_S, bh, bv * block_DV : (bv + 1) * block_DV])
V_new_shared, V_new[bb, i_s * block_S:(i_s + 1) * block_S, bh,
bv * block_DV:(bv + 1) * block_DV])
T.copy(K[bb, i_s * block_S:(i_s + 1) * block_S, bh, 0:DK], K_shared) T.copy(K[bb, i_s * block_S : (i_s + 1) * block_S, bh, 0:DK], K_shared)
# use_g # use_g
if use_g: if use_g:
G_last_local[0] = G[bb, (i_s + 1) * block_S - 1, bh] G_last_local[0] = G[bb, (i_s + 1) * block_S - 1, bh]
...@@ -193,11 +208,12 @@ def tilelang_chunk_gated_delta_rule_fwd_h( ...@@ -193,11 +208,12 @@ def tilelang_chunk_gated_delta_rule_fwd_h(
for i_s2, i_v in T.Parallel(block_S, block_DV): for i_s2, i_v in T.Parallel(block_S, block_DV):
with T.If(G_last_local[0] - G_fragment[i_s2, i_v] <= 0): with T.If(G_last_local[0] - G_fragment[i_s2, i_v] <= 0):
with T.Then(): with T.Then():
V_new_fragment[i_s2, i_v] = V_new_fragment[i_s2, i_v] * T.exp( V_new_fragment[i_s2, i_v] = V_new_fragment[i_s2, i_v] * T.exp2(
G_last_local[0] - G_fragment[i_s2, i_v]) (G_last_local[0] - G_fragment[i_s2, i_v]) * 1.442695
)
with T.Else(): with T.Else():
V_new_fragment[i_s2, i_v] = 0 V_new_fragment[i_s2, i_v] = 0
G_last_local[0] = T.exp(G_last_local[0]) G_last_local[0] = T.exp2(G_last_local[0] * 1.442695)
for i_k, i_v in T.Parallel(DK, block_DV): for i_k, i_v in T.Parallel(DK, block_DV):
b_h_fragment[i_k, i_v] *= G_last_local[0] b_h_fragment[i_k, i_v] *= G_last_local[0]
...@@ -209,7 +225,7 @@ def tilelang_chunk_gated_delta_rule_fwd_h( ...@@ -209,7 +225,7 @@ def tilelang_chunk_gated_delta_rule_fwd_h(
# Save final state # Save final state
if store_final_state: if store_final_state:
T.copy(b_h_fragment, final_state[bb, bh, 0:DK, bv * block_DV:(bv + 1) * block_DV]) T.copy(b_h_fragment, final_state[bb, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV])
return kernel return kernel
...@@ -260,47 +276,77 @@ def run_test( ...@@ -260,47 +276,77 @@ def run_test(
threads=128, threads=128,
num_stages=0, num_stages=0,
): ):
K, W, U, G, initial_state = prepare_input(B, S, H, DK, DV, chunk_size, K, W, U, G, initial_state = prepare_input(
getattr(torch, input_dtype), B,
getattr(torch, output_dtype), S,
getattr(torch, accum_dtype), H,
getattr(torch, gate_dtype)) DK,
h_ref, final_state_ref, V_new_ref = prepare_output(B, S, H, DK, DV, chunk_size, DV,
getattr(torch, output_dtype), chunk_size,
getattr(torch, state_dtype)) getattr(torch, input_dtype),
h_tilelang, final_state_tilelang, V_new_tilelang = prepare_output(B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype),
getattr(torch, output_dtype), getattr(torch, accum_dtype),
getattr(torch, state_dtype)) getattr(torch, gate_dtype),
)
h_ref, final_state_ref, V_new_ref = prepare_output(
B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, state_dtype)
)
h_tilelang, final_state_tilelang, V_new_tilelang = prepare_output(
B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, state_dtype)
)
# fla ref # fla ref
h_ref, V_new_ref, final_state_ref = chunk_gated_delta_rule_fwd_h(K, W, U, G, initial_state, h_ref, V_new_ref, final_state_ref = chunk_gated_delta_rule_fwd_h(
store_final_state, chunk_size, k=K,
save_new_value) w=W,
u=U,
g=G,
initial_state=initial_state,
output_final_state=store_final_state,
chunk_size=chunk_size,
save_new_value=save_new_value,
)
# tilelang # tilelang
kernel = tilelang_chunk_gated_delta_rule_fwd_h(B, S, H, DK, DV, input_dtype, output_dtype, kernel = tilelang_chunk_gated_delta_rule_fwd_h(
accum_dtype, gate_dtype, state_dtype, chunk_size, B,
use_g, use_initial_state, store_final_state, S,
save_new_value, block_DK, block_DV, threads, H,
num_stages) DK,
DV,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
state_dtype,
chunk_size,
use_g,
use_initial_state,
store_final_state,
save_new_value,
)
h_tilelang, final_state_tilelang, V_new_tilelang = kernel(K, W, U, G, initial_state) h_tilelang, final_state_tilelang, V_new_tilelang = kernel(K, W, U, G, initial_state)
# (zhengju) If you want to print the generated cuda code, you can uncomment the following line # (zhengju) If you want to print the generated cuda code, you can uncomment the following line
# print("CUDA Code:\n", kernel.get_kernel_source()) # print("CUDA Code:\n", kernel.get_kernel_source())
fla_time = do_bench(chunk_gated_delta_rule_fwd_h, K, W, U, G, initial_state, store_final_state, fla_time = do_bench(
chunk_size, save_new_value) chunk_gated_delta_rule_fwd_h,
k=K,
w=W,
u=U,
g=G,
initial_state=initial_state,
output_final_state=store_final_state,
chunk_size=chunk_size,
save_new_value=save_new_value,
)
tilelang_time = do_bench(kernel, K, W, U, G, initial_state) tilelang_time = do_bench(kernel, K, W, U, G, initial_state)
# check correctness # check correctness
try: try:
h_ref_fp32 = h_ref.to(torch.float32) h_ref_fp32 = h_ref.to(torch.float32)
h_tilelang_fp32 = h_tilelang.to(torch.float32) h_tilelang_fp32 = h_tilelang.to(torch.float32)
assert_similar( assert_similar(h_ref_fp32, h_tilelang_fp32, eps=1e-5, name="tilelang chunk gated delta rule fwd h", raise_assert=False)
h_ref_fp32,
h_tilelang_fp32,
eps=1e-5,
name="tilelang chunk gated delta rule fwd h",
raise_assert=False)
print("tilelang chunk gated delta rule fwd h passed √") print("tilelang chunk gated delta rule fwd h passed √")
except Exception as e: except Exception as e:
print("tilelang chunk gated delta rule fwd h failed ✗") print("tilelang chunk gated delta rule fwd h failed ✗")
...@@ -314,7 +360,8 @@ def run_test( ...@@ -314,7 +360,8 @@ def run_test(
final_state_tilelang_fp32, final_state_tilelang_fp32,
eps=1e-5, eps=1e-5,
name="tilelang chunk gated delta rule fwd final_state", name="tilelang chunk gated delta rule fwd final_state",
raise_assert=False) raise_assert=False,
)
print("tilelang chunk gated delta rule fwd final_state passed √") print("tilelang chunk gated delta rule fwd final_state passed √")
except Exception as e: except Exception as e:
print("tilelang chunk gated delta rule fwd final_state failed ✗") print("tilelang chunk gated delta rule fwd final_state failed ✗")
...@@ -323,12 +370,7 @@ def run_test( ...@@ -323,12 +370,7 @@ def run_test(
try: try:
V_new_ref_fp32 = V_new_ref.to(torch.float32) V_new_ref_fp32 = V_new_ref.to(torch.float32)
V_new_tilelang_fp32 = V_new_tilelang.to(torch.float32) V_new_tilelang_fp32 = V_new_tilelang.to(torch.float32)
assert_similar( assert_similar(V_new_ref_fp32, V_new_tilelang_fp32, eps=1e-5, name="tilelang chunk gated delta rule fwd V_new", raise_assert=False)
V_new_ref_fp32,
V_new_tilelang_fp32,
eps=1e-5,
name="tilelang chunk gated delta rule fwd V_new",
raise_assert=False)
print("tilelang chunk gated delta rule fwd V_new passed √") print("tilelang chunk gated delta rule fwd V_new passed √")
except Exception as e: except Exception as e:
print("tilelang chunk gated delta rule fwd V_new failed ✗") print("tilelang chunk gated delta rule fwd V_new failed ✗")
...@@ -345,20 +387,20 @@ def main(): ...@@ -345,20 +387,20 @@ def main():
H=32, H=32,
DK=128, DK=128,
DV=128, DV=128,
input_dtype="bfloat16", input_dtype=T.bfloat16,
output_dtype="bfloat16", output_dtype=T.bfloat16,
accum_dtype="float32", accum_dtype=T.float32,
gate_dtype="float32", gate_dtype=T.float32,
state_dtype="float32", state_dtype=T.float32,
chunk_size=64, chunk_size=64,
use_g=True, use_g=True,
use_initial_state=True, use_initial_state=False,
store_final_state=True, store_final_state=True,
save_new_value=True, save_new_value=True,
block_DK=64, block_DK=32,
block_DV=32, block_DV=32,
threads=128, threads=128,
num_stages=1, num_stages=2,
) )
......
...@@ -9,6 +9,7 @@ import sys # noqa: F401 ...@@ -9,6 +9,7 @@ import sys # noqa: F401
# sys.path.insert(0, "/home/tzj/flash-linear-attention") # sys.path.insert(0, "/home/tzj/flash-linear-attention")
try: try:
import fla import fla
print(fla.__file__) print(fla.__file__)
from fla.ops.common.chunk_o import chunk_fwd_o from fla.ops.common.chunk_o import chunk_fwd_o
except ImportError: except ImportError:
...@@ -87,16 +88,14 @@ def tilelang_chunk_fwd_o( ...@@ -87,16 +88,14 @@ def tilelang_chunk_fwd_o(
@T.prim_func @T.prim_func
def kernel( def kernel(
Q: T.Tensor(Q_shape, dtype=input_dtype), Q: T.Tensor(Q_shape, dtype=input_dtype),
K: T.Tensor(K_shape, dtype=input_dtype), K: T.Tensor(K_shape, dtype=input_dtype),
V: T.Tensor(V_shape, dtype=input_dtype), V: T.Tensor(V_shape, dtype=input_dtype),
HIDDEN: T.Tensor(H_shape, dtype=input_dtype), HIDDEN: T.Tensor(H_shape, dtype=input_dtype),
G: T.Tensor(G_shape, dtype=gate_dtype), G: T.Tensor(G_shape, dtype=gate_dtype),
O: T.Tensor(O_shape, dtype=output_dtype), O: T.Tensor(O_shape, dtype=output_dtype),
): ):
with T.Kernel( with T.Kernel(T.ceildiv(DV, block_DV), T.ceildiv(S, block_S), B * H, threads=threads) as (bv, bs, bbh):
T.ceildiv(DV, block_DV), T.ceildiv(S, block_S), B * H,
threads=threads) as (bv, bs, bbh):
bb, bh = bbh // H, bbh % H bb, bh = bbh // H, bbh % H
Q_shared = T.alloc_shared((block_S, block_DK), dtype=input_dtype) Q_shared = T.alloc_shared((block_S, block_DK), dtype=input_dtype)
K_shared = T.alloc_shared((block_S, block_DK), dtype=input_dtype) K_shared = T.alloc_shared((block_S, block_DK), dtype=input_dtype)
...@@ -109,28 +108,24 @@ def tilelang_chunk_fwd_o( ...@@ -109,28 +108,24 @@ def tilelang_chunk_fwd_o(
G_shared = T.alloc_shared((block_S,), dtype=gate_dtype, scope="shared") G_shared = T.alloc_shared((block_S,), dtype=gate_dtype, scope="shared")
G_diff_local = T.alloc_fragment((block_S, block_S), dtype=gate_dtype) G_diff_local = T.alloc_fragment((block_S, block_S), dtype=gate_dtype)
T.annotate_layout({ T.annotate_layout(
Q_shared: tilelang.layout.make_swizzled_layout(Q_shared), {
K_shared: tilelang.layout.make_swizzled_layout(K_shared), Q_shared: tilelang.layout.make_swizzled_layout(Q_shared),
V_shared: tilelang.layout.make_swizzled_layout(V_shared), K_shared: tilelang.layout.make_swizzled_layout(K_shared),
H_shared: tilelang.layout.make_swizzled_layout(H_shared), V_shared: tilelang.layout.make_swizzled_layout(V_shared),
A_shared: tilelang.layout.make_swizzled_layout(A_shared), H_shared: tilelang.layout.make_swizzled_layout(H_shared),
O_shared: tilelang.layout.make_swizzled_layout(O_shared), A_shared: tilelang.layout.make_swizzled_layout(A_shared),
}) O_shared: tilelang.layout.make_swizzled_layout(O_shared),
}
)
T.clear(A_fragment) T.clear(A_fragment)
T.clear(O_fragment) T.clear(O_fragment)
T.disable_warp_group_reg_alloc() T.disable_warp_group_reg_alloc()
for i_k in T.Pipelined(T.ceildiv(DK, block_DK), num_stages=num_stages): for i_k in T.Pipelined(T.ceildiv(DK, block_DK), num_stages=num_stages):
T.copy( T.copy(Q[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK], Q_shared)
Q[bb, bs * block_S:(bs + 1) * block_S, bh, i_k * block_DK:(i_k + 1) * block_DK], T.copy(K[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK], K_shared)
Q_shared) T.copy(HIDDEN[bb, bs, bh, i_k * block_DK : (i_k + 1) * block_DK, bv * block_DV : (bv + 1) * block_DV], H_shared)
T.copy(
K[bb, bs * block_S:(bs + 1) * block_S, bh, i_k * block_DK:(i_k + 1) * block_DK],
K_shared)
T.copy(
HIDDEN[bb, bs, bh, i_k * block_DK:(i_k + 1) * block_DK,
bv * block_DV:(bv + 1) * block_DV], H_shared)
T.gemm(Q_shared, H_shared, O_fragment) T.gemm(Q_shared, H_shared, O_fragment)
T.gemm(Q_shared, K_shared, A_fragment, transpose_B=True) T.gemm(Q_shared, K_shared, A_fragment, transpose_B=True)
...@@ -145,8 +140,7 @@ def tilelang_chunk_fwd_o( ...@@ -145,8 +140,7 @@ def tilelang_chunk_fwd_o(
for i_s1, i_s2 in T.Parallel(block_S, block_S): for i_s1, i_s2 in T.Parallel(block_S, block_S):
with T.If(G_diff_local[i_s1, i_s2] <= 0): with T.If(G_diff_local[i_s1, i_s2] <= 0):
with T.Then(): with T.Then():
A_fragment[i_s1, i_s2] = A_fragment[i_s1, i_s2] * T.exp( A_fragment[i_s1, i_s2] = A_fragment[i_s1, i_s2] * T.exp(G_diff_local[i_s1, i_s2])
G_diff_local[i_s1, i_s2])
with T.Else(): with T.Else():
A_fragment[i_s1, i_s2] = 0 A_fragment[i_s1, i_s2] = 0
...@@ -155,8 +149,7 @@ def tilelang_chunk_fwd_o( ...@@ -155,8 +149,7 @@ def tilelang_chunk_fwd_o(
with T.Then(): with T.Then():
A_fragment[i_s1, i_s2] = 0 A_fragment[i_s1, i_s2] = 0
T.copy(V[bb, bs * block_S:(bs + 1) * block_S, bh, bv * block_DV:(bv + 1) * block_DV], T.copy(V[bb, bs * block_S : (bs + 1) * block_S, bh, bv * block_DV : (bv + 1) * block_DV], V_shared)
V_shared)
T.copy(A_fragment, A_shared) T.copy(A_fragment, A_shared)
T.gemm(A_shared, V_shared, O_fragment) T.gemm(A_shared, V_shared, O_fragment)
...@@ -164,8 +157,7 @@ def tilelang_chunk_fwd_o( ...@@ -164,8 +157,7 @@ def tilelang_chunk_fwd_o(
O_fragment[i_s, i_v] = O_fragment[i_s, i_v] * scale O_fragment[i_s, i_v] = O_fragment[i_s, i_v] * scale
T.copy(O_fragment, O_shared) T.copy(O_fragment, O_shared)
T.copy(O_shared, O[bb, bs * block_S:(bs + 1) * block_S, bh, T.copy(O_shared, O[bb, bs * block_S : (bs + 1) * block_S, bh, bv * block_DV : (bv + 1) * block_DV])
bv * block_DV:(bv + 1) * block_DV])
return kernel return kernel
...@@ -191,8 +183,9 @@ def run_test( ...@@ -191,8 +183,9 @@ def run_test(
output_dtype_torch = getattr(torch, output_dtype) output_dtype_torch = getattr(torch, output_dtype)
accum_dtype_torch = getattr(torch, accum_dtype) accum_dtype_torch = getattr(torch, accum_dtype)
gate_dtype_torch = getattr(torch, gate_dtype) gate_dtype_torch = getattr(torch, gate_dtype)
Q, K, V, HIDDEN, G = prepare_input(B, S, H, DK, DV, chunk_size, input_dtype_torch, Q, K, V, HIDDEN, G = prepare_input(
output_dtype_torch, accum_dtype_torch, gate_dtype_torch) B, S, H, DK, DV, chunk_size, input_dtype_torch, output_dtype_torch, accum_dtype_torch, gate_dtype_torch
)
scale = 1.0 / DK**0.5 scale = 1.0 / DK**0.5
O_ref = prepare_output(B, S, H, DK, DV, chunk_size, output_dtype_torch) O_ref = prepare_output(B, S, H, DK, DV, chunk_size, output_dtype_torch)
...@@ -200,9 +193,25 @@ def run_test( ...@@ -200,9 +193,25 @@ def run_test(
block_S = chunk_size block_S = chunk_size
O_tilelang = prepare_output(B, S, H, DK, DV, chunk_size, output_dtype_torch) O_tilelang = prepare_output(B, S, H, DK, DV, chunk_size, output_dtype_torch)
kernel = tilelang_chunk_fwd_o(B, S, H, DK, DV, input_dtype, output_dtype, accum_dtype, kernel = tilelang_chunk_fwd_o(
gate_dtype, chunk_size, scale, use_g, block_S, block_DK, block_DV, B,
threads, num_stages) S,
H,
DK,
DV,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
chunk_size,
scale,
use_g,
block_S,
block_DK,
block_DV,
threads,
num_stages,
)
O_tilelang = kernel(Q, K, V, HIDDEN, G) O_tilelang = kernel(Q, K, V, HIDDEN, G)
try: try:
...@@ -221,10 +230,10 @@ def main(): ...@@ -221,10 +230,10 @@ def main():
DK=128, DK=128,
DV=128, DV=128,
chunk_size=64, chunk_size=64,
input_dtype="bfloat16", input_dtype=T.bfloat16,
output_dtype="bfloat16", output_dtype=T.bfloat16,
accum_dtype="float32", accum_dtype=T.float32,
gate_dtype="float32", gate_dtype=T.float32,
use_g=True, use_g=True,
block_DK=128, block_DK=128,
block_DV=128, block_DV=128,
......
...@@ -12,6 +12,7 @@ from tilelang.engine.callback import register_cuda_postproc_callback # noqa: F4 ...@@ -12,6 +12,7 @@ from tilelang.engine.callback import register_cuda_postproc_callback # noqa: F4
# sys.path.insert(0, "/home/tzj/flash-linear-attention") # sys.path.insert(0, "/home/tzj/flash-linear-attention")
try: try:
import fla import fla
print(fla.__file__) print(fla.__file__)
from fla.ops.common.chunk_o import chunk_bwd_dqkwg from fla.ops.common.chunk_o import chunk_bwd_dqkwg
except ImportError: except ImportError:
...@@ -19,7 +20,7 @@ except ImportError: ...@@ -19,7 +20,7 @@ except ImportError:
fla = None fla = None
import torch import torch
from utils import * from test_utils import assert_similar
torch.random.manual_seed(0) torch.random.manual_seed(0)
# torch.set_printoptions(profile="full") # torch.set_printoptions(profile="full")
...@@ -108,10 +109,8 @@ def prepare_output( ...@@ -108,10 +109,8 @@ def prepare_output(
@tilelang.jit( @tilelang.jit(
out_idx=[-4, -3, -2, -1], out_idx=[-4, -3, -2, -1],
pass_configs={ pass_configs={tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True},
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, )
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True
})
def tilelang_chunk_o_bwd_dqkwg( def tilelang_chunk_o_bwd_dqkwg(
# task config # task config
B, B,
...@@ -155,25 +154,23 @@ def tilelang_chunk_o_bwd_dqkwg( ...@@ -155,25 +154,23 @@ def tilelang_chunk_o_bwd_dqkwg(
@T.prim_func @T.prim_func
def kernel( def kernel(
# input # input
Q: T.Tensor(Q_shape, dtype=input_dtype), Q: T.Tensor(Q_shape, dtype=input_dtype),
K: T.Tensor(K_shape, dtype=input_dtype), K: T.Tensor(K_shape, dtype=input_dtype),
V: T.Tensor(V_shape, dtype=input_dtype), V: T.Tensor(V_shape, dtype=input_dtype),
h: T.Tensor(h_shape, dtype=input_dtype), h: T.Tensor(h_shape, dtype=input_dtype),
G: T.Tensor(G_shape, dtype=gate_dtype), G: T.Tensor(G_shape, dtype=gate_dtype),
dO: T.Tensor(dO_shape, dtype=input_dtype), dO: T.Tensor(dO_shape, dtype=input_dtype),
dh: T.Tensor(dh_shape, dtype=input_dtype), dh: T.Tensor(dh_shape, dtype=input_dtype),
dv: T.Tensor(dv_shape, dtype=input_dtype), dv: T.Tensor(dv_shape, dtype=input_dtype),
W: T.Tensor(W_shape, dtype=input_dtype), W: T.Tensor(W_shape, dtype=input_dtype),
# output # output
dq: T.Tensor(dq_shape, dtype=output_dtype), dq: T.Tensor(dq_shape, dtype=output_dtype),
dk: T.Tensor(dk_shape, dtype=output_dtype), dk: T.Tensor(dk_shape, dtype=output_dtype),
dw: T.Tensor(dw_shape, dtype=output_dtype), dw: T.Tensor(dw_shape, dtype=output_dtype),
dg: T.Tensor(dg_shape, dtype=gate_dtype), dg: T.Tensor(dg_shape, dtype=gate_dtype),
): ):
with T.Kernel( with T.Kernel(T.ceildiv(DK, block_DK), T.ceildiv(S, block_S), B * H, threads=threads) as (bk, bs, bbh):
T.ceildiv(DK, block_DK), T.ceildiv(S, block_S), B * H,
threads=threads) as (bk, bs, bbh):
bb, bh = bbh // H, bbh % H bb, bh = bbh // H, bbh % H
V_shared = T.alloc_shared((block_S, block_DV), dtype=input_dtype) V_shared = T.alloc_shared((block_S, block_DV), dtype=input_dtype)
...@@ -212,15 +209,17 @@ def tilelang_chunk_o_bwd_dqkwg( ...@@ -212,15 +209,17 @@ def tilelang_chunk_o_bwd_dqkwg(
T.use_swizzle(10) T.use_swizzle(10)
T.annotate_layout({ T.annotate_layout(
V_shared: tilelang.layout.make_swizzled_layout(V_shared), {
dO_shared: tilelang.layout.make_swizzled_layout(dO_shared), V_shared: tilelang.layout.make_swizzled_layout(V_shared),
h_shared: tilelang.layout.make_swizzled_layout(h_shared), dO_shared: tilelang.layout.make_swizzled_layout(dO_shared),
dh_shared: tilelang.layout.make_swizzled_layout(dh_shared), h_shared: tilelang.layout.make_swizzled_layout(h_shared),
dv_shared: tilelang.layout.make_swizzled_layout(dv_shared), dh_shared: tilelang.layout.make_swizzled_layout(dh_shared),
q_shared: tilelang.layout.make_swizzled_layout(q_shared), dv_shared: tilelang.layout.make_swizzled_layout(dv_shared),
k_shared: tilelang.layout.make_swizzled_layout(k_shared), q_shared: tilelang.layout.make_swizzled_layout(q_shared),
}) k_shared: tilelang.layout.make_swizzled_layout(k_shared),
}
)
T.clear(dg_last_local) T.clear(dg_last_local)
T.clear(G_last_local) T.clear(G_last_local)
...@@ -235,18 +234,10 @@ def tilelang_chunk_o_bwd_dqkwg( ...@@ -235,18 +234,10 @@ def tilelang_chunk_o_bwd_dqkwg(
T.clear(dw_fragment) T.clear(dw_fragment)
for i_v in T.Pipelined(T.ceildiv(DV, block_DV), num_stages=num_stages): for i_v in T.Pipelined(T.ceildiv(DV, block_DV), num_stages=num_stages):
T.copy( T.copy(V[bb, bs * block_S : (bs + 1) * block_S, bh, i_v * block_DV : (i_v + 1) * block_DV], V_shared)
V[bb, bs * block_S:(bs + 1) * block_S, bh, i_v * block_DV:(i_v + 1) * block_DV], T.copy(dO[bb, bs * block_S : (bs + 1) * block_S, bh, i_v * block_DV : (i_v + 1) * block_DV], dO_shared)
V_shared) T.copy(h[bb, bs, bh, bk * block_DK : (bk + 1) * block_DK, i_v * block_DV : (i_v + 1) * block_DV], h_shared)
T.copy( T.copy(dh[bb, bs, bh, bk * block_DK : (bk + 1) * block_DK, i_v * block_DV : (i_v + 1) * block_DV], dh_shared)
dO[bb, bs * block_S:(bs + 1) * block_S, bh,
i_v * block_DV:(i_v + 1) * block_DV], dO_shared)
T.copy(
h[bb, bs, bh, bk * block_DK:(bk + 1) * block_DK,
i_v * block_DV:(i_v + 1) * block_DV], h_shared)
T.copy(
dh[bb, bs, bh, bk * block_DK:(bk + 1) * block_DK,
i_v * block_DV:(i_v + 1) * block_DV], dh_shared)
if use_g: if use_g:
T.clear(dg_last_fragment_scalar) T.clear(dg_last_fragment_scalar)
...@@ -254,9 +245,7 @@ def tilelang_chunk_o_bwd_dqkwg( ...@@ -254,9 +245,7 @@ def tilelang_chunk_o_bwd_dqkwg(
# for i_kv in T.Parallel(block_DK * block_DV): # for i_kv in T.Parallel(block_DK * block_DV):
# dg_last_fragment[i_kv] = h_shared[i_kv // block_DV, i_kv % block_DV] * dh_shared[i_kv // block_DV, i_kv % block_DV] # dg_last_fragment[i_kv] = h_shared[i_kv // block_DV, i_kv % block_DV] * dh_shared[i_kv // block_DV, i_kv % block_DV]
for i_kv in T.Parallel(block_DK * block_DV): for i_kv in T.Parallel(block_DK * block_DV):
dg_last_fragment[i_kv] = h_shared[i_kv // block_DV, i_kv % dg_last_fragment[i_kv] = h_shared[i_kv // block_DV, i_kv % block_DV] * dh_shared[i_kv // block_DV, i_kv % block_DV]
block_DV] * dh_shared[i_kv // block_DV,
i_kv % block_DV]
T.reduce_sum(dg_last_fragment, dg_last_fragment_scalar, dim=-1, clear=False) T.reduce_sum(dg_last_fragment, dg_last_fragment_scalar, dim=-1, clear=False)
dg_last_local[0] += dg_last_fragment_scalar[0] dg_last_local[0] += dg_last_fragment_scalar[0]
...@@ -265,22 +254,16 @@ def tilelang_chunk_o_bwd_dqkwg( ...@@ -265,22 +254,16 @@ def tilelang_chunk_o_bwd_dqkwg(
T.gemm(V_shared, dh_shared, dk_fragment, transpose_B=True) T.gemm(V_shared, dh_shared, dk_fragment, transpose_B=True)
if use_dw: if use_dw:
T.copy( T.copy(dv[bb, bs * block_S : (bs + 1) * block_S, bh, i_v * block_DV : (i_v + 1) * block_DV], dv_shared)
dv[bb, bs * block_S:(bs + 1) * block_S, bh,
i_v * block_DV:(i_v + 1) * block_DV], dv_shared)
T.gemm(dv_shared, h_shared, dw_fragment, transpose_B=True) T.gemm(dv_shared, h_shared, dw_fragment, transpose_B=True)
if use_dw: if use_dw:
for i_s, i_k in T.Parallel(block_S, block_DK): for i_s, i_k in T.Parallel(block_S, block_DK):
dw_fragment[i_s, i_k] = -dw_fragment[i_s, i_k] dw_fragment[i_s, i_k] = -dw_fragment[i_s, i_k]
T.copy( T.copy(dw_fragment, dw[bb, bs * block_S : (bs + 1) * block_S, bh, bk * block_DK : (bk + 1) * block_DK])
dw_fragment, dw[bb, bs * block_S:(bs + 1) * block_S, bh,
bk * block_DK:(bk + 1) * block_DK]) T.copy(Q[bb, bs * block_S : (bs + 1) * block_S, bh, bk * block_DK : (bk + 1) * block_DK], q_shared)
T.copy(K[bb, bs * block_S : (bs + 1) * block_S, bh, bk * block_DK : (bk + 1) * block_DK], k_shared)
T.copy(Q[bb, bs * block_S:(bs + 1) * block_S, bh, bk * block_DK:(bk + 1) * block_DK],
q_shared)
T.copy(K[bb, bs * block_S:(bs + 1) * block_S, bh, bk * block_DK:(bk + 1) * block_DK],
k_shared)
T.copy(q_shared, q_fragment) T.copy(q_shared, q_fragment)
T.copy(k_shared, k_fragment) T.copy(k_shared, k_fragment)
...@@ -294,8 +277,7 @@ def tilelang_chunk_o_bwd_dqkwg( ...@@ -294,8 +277,7 @@ def tilelang_chunk_o_bwd_dqkwg(
dg_last_local[0] = dg_last_local[0] * T.exp(G[bb, bs * block_S + block_S - 1, bh]) dg_last_local[0] = dg_last_local[0] * T.exp(G[bb, bs * block_S + block_S - 1, bh])
for i_s, i_k in T.Parallel(block_S, block_DK): for i_s, i_k in T.Parallel(block_S, block_DK):
dq_fragment[i_s, i_k] = dq_fragment[i_s, i_k] * T.exp(G[bb, bs * block_S + i_s, dq_fragment[i_s, i_k] = dq_fragment[i_s, i_k] * T.exp(G[bb, bs * block_S + i_s, bh]) * scale
bh]) * scale
T.clear(dg_fragment_reduce_tmp) T.clear(dg_fragment_reduce_tmp)
for i_s, i_k in T.Parallel(block_S, block_DK): for i_s, i_k in T.Parallel(block_S, block_DK):
dg_fragment_reduce_tmp[i_s, i_k] = dq_fragment[i_s, i_k] * q_shared[i_s, i_k] dg_fragment_reduce_tmp[i_s, i_k] = dq_fragment[i_s, i_k] * q_shared[i_s, i_k]
...@@ -305,8 +287,7 @@ def tilelang_chunk_o_bwd_dqkwg( ...@@ -305,8 +287,7 @@ def tilelang_chunk_o_bwd_dqkwg(
for i_s, i_k in T.Parallel(block_S, block_DK): for i_s, i_k in T.Parallel(block_S, block_DK):
with T.If(G_last_local[0] - G[bb, bs * block_S + i_s, bh] <= 0): with T.If(G_last_local[0] - G[bb, bs * block_S + i_s, bh] <= 0):
with T.Then(): with T.Then():
dk_fragment[i_s, i_k] = dk_fragment[i_s, i_k] * T.exp( dk_fragment[i_s, i_k] = dk_fragment[i_s, i_k] * T.exp(G_last_local[0] - G[bb, bs * block_S + i_s, bh])
G_last_local[0] - G[bb, bs * block_S + i_s, bh])
with T.Else(): with T.Else():
dk_fragment[i_s, i_k] = 0 dk_fragment[i_s, i_k] = 0
T.clear(dg_fragment_reduce_tmp) T.clear(dg_fragment_reduce_tmp)
...@@ -325,12 +306,11 @@ def tilelang_chunk_o_bwd_dqkwg( ...@@ -325,12 +306,11 @@ def tilelang_chunk_o_bwd_dqkwg(
dg_last_local[1] = dg_last_fragment_scalar_2[0] dg_last_local[1] = dg_last_fragment_scalar_2[0]
for i_s1, i_s2 in T.Parallel(block_S, block_S): for i_s1, i_s2 in T.Parallel(block_S, block_S):
with T.If(i_s1 >= i_s2 and with T.If(i_s1 >= i_s2 and G[bb, bs * block_S + i_s1, bh] - G[bb, bs * block_S + i_s2, bh] <= 0):
G[bb, bs * block_S + i_s1, bh] - G[bb, bs * block_S + i_s2, bh] <= 0):
with T.Then(): with T.Then():
ds_fragment[i_s1, i_s2] = ds_fragment[ ds_fragment[i_s1, i_s2] = (
i_s1, i_s2] * T.exp(G[bb, bs * block_S + i_s1, bh] - ds_fragment[i_s1, i_s2] * T.exp(G[bb, bs * block_S + i_s1, bh] - G[bb, bs * block_S + i_s2, bh]) * scale
G[bb, bs * block_S + i_s2, bh]) * scale )
with T.Else(): with T.Else():
ds_fragment[i_s1, i_s2] = 0 ds_fragment[i_s1, i_s2] = 0
...@@ -338,8 +318,7 @@ def tilelang_chunk_o_bwd_dqkwg( ...@@ -338,8 +318,7 @@ def tilelang_chunk_o_bwd_dqkwg(
T.clear(ds_fragment_positive_transpose) T.clear(ds_fragment_positive_transpose)
T.gemm(q_shared, k_shared, ds_fragment_positive, transpose_B=True) T.gemm(q_shared, k_shared, ds_fragment_positive, transpose_B=True)
for i_s1, i_s2 in T.Parallel(block_S, block_S): for i_s1, i_s2 in T.Parallel(block_S, block_S):
ds_fragment_positive[ ds_fragment_positive[i_s1, i_s2] = ds_fragment[i_s1, i_s2] * ds_fragment_positive[i_s1, i_s2]
i_s1, i_s2] = ds_fragment[i_s1, i_s2] * ds_fragment_positive[i_s1, i_s2]
# FIXME: The reduce_sum statement with clear=True will cause an error of warp specialized pass # FIXME: The reduce_sum statement with clear=True will cause an error of warp specialized pass
T.reduce_sum(ds_fragment_positive, dg_fragment, dim=1, clear=False) T.reduce_sum(ds_fragment_positive, dg_fragment, dim=1, clear=False)
...@@ -363,15 +342,10 @@ def tilelang_chunk_o_bwd_dqkwg( ...@@ -363,15 +342,10 @@ def tilelang_chunk_o_bwd_dqkwg(
for i_s in T.Parallel(block_S): for i_s in T.Parallel(block_S):
with T.If(i_s >= block_S - 1): # noqa: SIM117 with T.If(i_s >= block_S - 1): # noqa: SIM117
with T.Then(): with T.Then():
dg_fragment_final[ dg_fragment_final[i_s] = dg_fragment_final[i_s] + dg_last_local[0] + dg_last_local[1]
i_s] = dg_fragment_final[i_s] + dg_last_local[0] + dg_last_local[1]
T.copy(dq_fragment, dq[bb, bs * block_S : (bs + 1) * block_S, bh, bk * block_DK : (bk + 1) * block_DK])
T.copy( T.copy(dk_fragment, dk[bb, bs * block_S : (bs + 1) * block_S, bh, bk * block_DK : (bk + 1) * block_DK])
dq_fragment, dq[bb, bs * block_S:(bs + 1) * block_S, bh,
bk * block_DK:(bk + 1) * block_DK])
T.copy(
dk_fragment, dk[bb, bs * block_S:(bs + 1) * block_S, bh,
bk * block_DK:(bk + 1) * block_DK])
for i_s in T.Parallel(block_S): for i_s in T.Parallel(block_S):
dg[bk, bb, bs * block_S + i_s, bh] = dg_fragment_final[i_s] dg[bk, bb, bs * block_S + i_s, bh] = dg_fragment_final[i_s]
...@@ -387,12 +361,8 @@ def tilelang_chunk_o_bwd_dqkwg( ...@@ -387,12 +361,8 @@ def tilelang_chunk_o_bwd_dqkwg(
for i_s, i_k in T.Parallel(block_S, block_DK): for i_s, i_k in T.Parallel(block_S, block_DK):
dq_fragment[i_s, i_k] = dq_fragment[i_s, i_k] * scale dq_fragment[i_s, i_k] = dq_fragment[i_s, i_k] * scale
dk_fragment[i_s, i_k] = dk_fragment[i_s, i_k] + dk_fragment_2[i_s, i_k] * scale dk_fragment[i_s, i_k] = dk_fragment[i_s, i_k] + dk_fragment_2[i_s, i_k] * scale
T.copy( T.copy(dq_fragment, dq[bb, bs * block_S : (bs + 1) * block_S, bh, bk * block_DK : (bk + 1) * block_DK])
dq_fragment, dq[bb, bs * block_S:(bs + 1) * block_S, bh, T.copy(dk_fragment, dk[bb, bs * block_S : (bs + 1) * block_S, bh, bk * block_DK : (bk + 1) * block_DK])
bk * block_DK:(bk + 1) * block_DK])
T.copy(
dk_fragment, dk[bb, bs * block_S:(bs + 1) * block_S, bh,
bk * block_DK:(bk + 1) * block_DK])
return kernel return kernel
...@@ -442,33 +412,53 @@ def run_test( ...@@ -442,33 +412,53 @@ def run_test(
threads=256, threads=256,
num_stages=0, num_stages=0,
): ):
Q, K, V, h, G, dO, dh, dv, W = prepare_input(B, S, H, DK, DV, chunk_size, Q, K, V, h, G, dO, dh, dv, W = prepare_input(
getattr(torch, input_dtype), B,
getattr(torch, output_dtype), S,
getattr(torch, accum_dtype), H,
getattr(torch, gate_dtype), DK,
getattr(torch, state_dtype)) DV,
dq_ref, dk_ref, dw_ref, dg_ref = prepare_output(B, S, H, DK, DV, chunk_size, chunk_size,
getattr(torch, output_dtype), getattr(torch, input_dtype),
getattr(torch, gate_dtype), getattr(torch, output_dtype),
getattr(torch, state_dtype), block_DK) getattr(torch, accum_dtype),
getattr(torch, gate_dtype),
getattr(torch, state_dtype),
)
dq_ref, dk_ref, dw_ref, dg_ref = prepare_output(
B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype), getattr(torch, state_dtype), block_DK
)
dq_tilelang, dk_tilelang, dw_tilelang, dg_tilelang = prepare_output( dq_tilelang, dk_tilelang, dw_tilelang, dg_tilelang = prepare_output(
B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype), B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype), getattr(torch, state_dtype), block_DK
getattr(torch, state_dtype), block_DK) )
# ref # ref
if use_g: if use_g:
dq_ref, dk_ref, dw_ref, dg_ref = chunk_bwd_dqkwg( dq_ref, dk_ref, dw_ref, dg_ref = chunk_bwd_dqkwg(Q, K, V, G, dO, h, dh, dv, W, chunk_size=chunk_size, scale=scale)
Q, K, V, G, dO, h, dh, dv, W, chunk_size=chunk_size, scale=scale)
else: else:
dq_ref, dk_ref, dw_ref, dg_ref = chunk_bwd_dqkwg( dq_ref, dk_ref, dw_ref, dg_ref = chunk_bwd_dqkwg(Q, K, V, None, dO, h, dh, dv, W, chunk_size=chunk_size, scale=scale)
Q, K, V, None, dO, h, dh, dv, W, chunk_size=chunk_size, scale=scale)
# tilelang # tilelang
kernel = tilelang_chunk_o_bwd_dqkwg(B, S, H, DK, DV, input_dtype, output_dtype, accum_dtype, kernel = tilelang_chunk_o_bwd_dqkwg(
gate_dtype, state_dtype, chunk_size, scale, use_g, use_dw, B,
block_DK, block_DV, threads, num_stages) S,
print(kernel.get_kernel_source()) H,
DK,
DV,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
state_dtype,
chunk_size,
scale,
use_g,
use_dw,
block_DK,
block_DV,
threads,
num_stages,
)
dq_tilelang, dk_tilelang, dw_tilelang, dg_tilelang = kernel(Q, K, V, h, G, dO, dh, dv, W) dq_tilelang, dk_tilelang, dw_tilelang, dg_tilelang = kernel(Q, K, V, h, G, dO, dh, dv, W)
if use_g: if use_g:
...@@ -515,11 +505,11 @@ def main(): ...@@ -515,11 +505,11 @@ def main():
H=8, H=8,
DK=DK, DK=DK,
DV=DV, DV=DV,
input_dtype="bfloat16", input_dtype=T.bfloat16,
output_dtype="bfloat16", output_dtype=T.bfloat16,
accum_dtype="float32", accum_dtype=T.float32,
gate_dtype="float32", gate_dtype=T.float32,
state_dtype="float32", state_dtype=T.float32,
chunk_size=64, chunk_size=64,
scale=DK**-0.5, scale=DK**-0.5,
# scale=1, # scale=1,
......
...@@ -9,6 +9,7 @@ import sys # noqa: F401 ...@@ -9,6 +9,7 @@ import sys # noqa: F401
# sys.path.insert(0, "/home/tzj/flash-linear-attention") # sys.path.insert(0, "/home/tzj/flash-linear-attention")
try: try:
import fla import fla
print(fla.__file__) print(fla.__file__)
from fla.ops.common.chunk_scaled_dot_kkt import chunk_scaled_dot_kkt_fwd from fla.ops.common.chunk_scaled_dot_kkt import chunk_scaled_dot_kkt_fwd
except ImportError: except ImportError:
...@@ -56,9 +57,9 @@ def tilelang_chunk_scaled_dot_kkt_fwd( ...@@ -56,9 +57,9 @@ def tilelang_chunk_scaled_dot_kkt_fwd(
H, H,
DK, DK,
chunk_size=64, chunk_size=64,
input_dtype="bfloat16", input_dtype=T.bfloat16,
output_dtype="bfloat16", output_dtype=T.bfloat16,
accum_dtype="float32", accum_dtype=T.float32,
use_g=True, use_g=True,
# kernel config # kernel config
block_S=64, block_S=64,
...@@ -75,10 +76,10 @@ def tilelang_chunk_scaled_dot_kkt_fwd( ...@@ -75,10 +76,10 @@ def tilelang_chunk_scaled_dot_kkt_fwd(
@T.prim_func @T.prim_func
def kernel( def kernel(
K: T.Tensor(K_shape, dtype=input_dtype), K: T.Tensor(K_shape, dtype=input_dtype),
Beta: T.Tensor(Beta_shape, dtype=input_dtype), Beta: T.Tensor(Beta_shape, dtype=input_dtype),
G: T.Tensor(G_shape, dtype=accum_dtype), G: T.Tensor(G_shape, dtype=accum_dtype),
A: T.Tensor(output_shape, dtype=output_dtype), A: T.Tensor(output_shape, dtype=output_dtype),
): ):
with T.Kernel(T.ceildiv(S, block_S), B * H, threads=threads) as (bs, bbh): with T.Kernel(T.ceildiv(S, block_S), B * H, threads=threads) as (bs, bbh):
bb, bh = bbh // H, bbh % H bb, bh = bbh // H, bbh % H
...@@ -93,10 +94,12 @@ def tilelang_chunk_scaled_dot_kkt_fwd( ...@@ -93,10 +94,12 @@ def tilelang_chunk_scaled_dot_kkt_fwd(
G_shared = T.alloc_shared((block_S,), dtype=accum_dtype, scope="shared") G_shared = T.alloc_shared((block_S,), dtype=accum_dtype, scope="shared")
G_diff_local = T.alloc_fragment((block_S, block_S), dtype=accum_dtype) G_diff_local = T.alloc_fragment((block_S, block_S), dtype=accum_dtype)
T.annotate_layout({ T.annotate_layout(
K_shared: tilelang.layout.make_swizzled_layout(K_shared), {
A_shared: tilelang.layout.make_swizzled_layout(A_shared), K_shared: tilelang.layout.make_swizzled_layout(K_shared),
}) A_shared: tilelang.layout.make_swizzled_layout(A_shared),
}
)
T.fill(A_fragment, 0) T.fill(A_fragment, 0)
T.disable_warp_group_reg_alloc() T.disable_warp_group_reg_alloc()
...@@ -104,9 +107,7 @@ def tilelang_chunk_scaled_dot_kkt_fwd( ...@@ -104,9 +107,7 @@ def tilelang_chunk_scaled_dot_kkt_fwd(
Beta_shared[i_s] = Beta[bb, bs * block_S + i_s, bh] Beta_shared[i_s] = Beta[bb, bs * block_S + i_s, bh]
for i_k in T.Pipelined(T.ceildiv(DK, block_DK), num_stages=num_stages): for i_k in T.Pipelined(T.ceildiv(DK, block_DK), num_stages=num_stages):
T.copy( T.copy(K[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK], K_shared)
K[bb, bs * block_S:(bs + 1) * block_S, bh, i_k * block_DK:(i_k + 1) * block_DK],
K_shared)
for i_s, i_k2 in T.Parallel(block_S, block_DK): for i_s, i_k2 in T.Parallel(block_S, block_DK):
Beta_K_fragment[i_s, i_k2] = K_shared[i_s, i_k2] * Beta_shared[i_s] Beta_K_fragment[i_s, i_k2] = K_shared[i_s, i_k2] * Beta_shared[i_s]
T.gemm(Beta_K_fragment, K_shared, A_fragment, transpose_B=True) T.gemm(Beta_K_fragment, K_shared, A_fragment, transpose_B=True)
...@@ -119,8 +120,7 @@ def tilelang_chunk_scaled_dot_kkt_fwd( ...@@ -119,8 +120,7 @@ def tilelang_chunk_scaled_dot_kkt_fwd(
for i_s1, i_s2 in T.Parallel(block_S, block_S): for i_s1, i_s2 in T.Parallel(block_S, block_S):
with T.If(G_diff_local[i_s1, i_s2] <= 0 and i_s1 > i_s2): with T.If(G_diff_local[i_s1, i_s2] <= 0 and i_s1 > i_s2):
with T.Then(): with T.Then():
A_fragment[i_s1, i_s2] = A_fragment[i_s1, i_s2] * T.exp( A_fragment[i_s1, i_s2] = A_fragment[i_s1, i_s2] * T.exp(G_diff_local[i_s1, i_s2])
G_diff_local[i_s1, i_s2])
with T.Else(): with T.Else():
A_fragment[i_s1, i_s2] = 0 A_fragment[i_s1, i_s2] = 0
else: else:
...@@ -130,7 +130,7 @@ def tilelang_chunk_scaled_dot_kkt_fwd( ...@@ -130,7 +130,7 @@ def tilelang_chunk_scaled_dot_kkt_fwd(
A_fragment[i_s1, i_s2] = 0 A_fragment[i_s1, i_s2] = 0
T.copy(A_fragment, A_shared) T.copy(A_fragment, A_shared)
T.copy(A_shared, A[bb, bs * block_S:(bs + 1) * block_S, bh, :]) T.copy(A_shared, A[bb, bs * block_S : (bs + 1) * block_S, bh, :])
return kernel return kernel
...@@ -149,24 +149,21 @@ def run_test( ...@@ -149,24 +149,21 @@ def run_test(
threads, threads,
num_stages, num_stages,
): ):
K, Beta, G = prepare_input(B, S, H, DK, getattr(torch, input_dtype), K, Beta, G = prepare_input(B, S, H, DK, getattr(torch, input_dtype), getattr(torch, output_dtype), getattr(torch, accum_dtype))
getattr(torch, output_dtype), getattr(torch, accum_dtype))
A_ref = prepare_output(B, S, H, chunk_size, getattr(torch, output_dtype)) A_ref = prepare_output(B, S, H, chunk_size, getattr(torch, output_dtype))
A_tilelang = prepare_output(B, S, H, chunk_size, getattr(torch, output_dtype)) A_tilelang = prepare_output(B, S, H, chunk_size, getattr(torch, output_dtype))
# reference # reference
if use_g: if use_g:
A_ref = chunk_scaled_dot_kkt_fwd( A_ref = chunk_scaled_dot_kkt_fwd(K, Beta, G, chunk_size=chunk_size, output_dtype=getattr(torch, output_dtype))
K, Beta, G, chunk_size=chunk_size, output_dtype=getattr(torch, output_dtype))
else: else:
A_ref = chunk_scaled_dot_kkt_fwd( A_ref = chunk_scaled_dot_kkt_fwd(K, Beta, None, chunk_size=chunk_size, output_dtype=getattr(torch, output_dtype))
K, Beta, None, chunk_size=chunk_size, output_dtype=getattr(torch, output_dtype))
# tilelang # tilelang
block_S = chunk_size block_S = chunk_size
kernel = tilelang_chunk_scaled_dot_kkt_fwd(B, S, H, DK, chunk_size, input_dtype, output_dtype, kernel = tilelang_chunk_scaled_dot_kkt_fwd(
accum_dtype, use_g, block_S, block_DK, threads, B, S, H, DK, chunk_size, input_dtype, output_dtype, accum_dtype, use_g, block_S, block_DK, threads, num_stages
num_stages) )
A_tilelang = kernel(K, Beta, G) A_tilelang = kernel(K, Beta, G)
try: try:
...@@ -186,13 +183,14 @@ def main(): ...@@ -186,13 +183,14 @@ def main():
H=32, H=32,
DK=128, DK=128,
chunk_size=64, chunk_size=64,
input_dtype="bfloat16", input_dtype=T.bfloat16,
output_dtype="bfloat16", output_dtype=T.bfloat16,
accum_dtype="float32", accum_dtype=T.float32,
use_g=True, use_g=True,
block_DK=64, block_DK=64,
threads=128, threads=128,
num_stages=2) num_stages=2,
)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -10,6 +10,7 @@ import sys # noqa: F401 ...@@ -10,6 +10,7 @@ import sys # noqa: F401
# sys.path.insert(0, "/home/tzj/flash-linear-attention") # sys.path.insert(0, "/home/tzj/flash-linear-attention")
try: try:
import fla import fla
print(fla.__file__) print(fla.__file__)
from fla.ops.utils.cumsum import chunk_local_cumsum_scalar from fla.ops.utils.cumsum import chunk_local_cumsum_scalar
except ImportError: except ImportError:
...@@ -20,11 +21,8 @@ import torch ...@@ -20,11 +21,8 @@ import torch
@tilelang.jit( @tilelang.jit(
out_idx=[-1], out_idx=[-1], pass_configs={tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}
pass_configs={ )
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True
})
def tilelang_chunk_local_cumsum_scalar( def tilelang_chunk_local_cumsum_scalar(
# task config # task config
B, B,
...@@ -34,43 +32,43 @@ def tilelang_chunk_local_cumsum_scalar( ...@@ -34,43 +32,43 @@ def tilelang_chunk_local_cumsum_scalar(
is_varlen=False, is_varlen=False,
head_first=False, head_first=False,
reverse=False, reverse=False,
input_dtype="float16", input_dtype=T.float16,
output_dtype="float32", output_dtype=T.float32,
# kernel config # kernel config
block_S=64, block_S=64,
threads=256, threads=256,
use_fragment=False, use_fragment=False,
): ):
G_shape = (B, H, S) if head_first else (B, S, H) G_shape = (B, H, S) if head_first else (B, S, H)
assert chunk_size == 2**(chunk_size.bit_length() - 1), "chunk_size must be a power of 2" assert chunk_size == 2 ** (chunk_size.bit_length() - 1), "chunk_size must be a power of 2"
assert chunk_size == block_S, "chunk_size must be equal to block_S" assert chunk_size == block_S, "chunk_size must be equal to block_S"
@T.prim_func @T.prim_func
def kernel( def kernel(
G: T.Tensor(G_shape, dtype=input_dtype), G: T.Tensor(G_shape, dtype=input_dtype),
G_new: T.Tensor(G_shape, dtype=output_dtype), G_new: T.Tensor(G_shape, dtype=output_dtype),
): ):
with T.Kernel(T.ceildiv(S, block_S), B * H, threads=threads) as (bs, bbh): with T.Kernel(T.ceildiv(S, block_S), B * H, threads=threads) as (bs, bbh):
bb, bh = bbh // H, bbh % H bb, bh = bbh // H, bbh % H
G_shared = T.alloc_shared((1, block_S), dtype=output_dtype, scope="shared") G_shared = T.alloc_shared((1, block_S), dtype=output_dtype, scope="shared")
if head_first: if head_first:
T.copy(G[bb, bh, bs * block_S:(bs + 1) * block_S], G_shared) T.copy(G[bb, bh, bs * block_S : (bs + 1) * block_S], G_shared)
else: else:
T.copy(G[bb, bs * block_S:(bs + 1) * block_S, bh], G_shared) T.copy(G[bb, bs * block_S : (bs + 1) * block_S, bh], G_shared)
if use_fragment: if use_fragment:
G_fragment = T.alloc_fragment((1, block_S), dtype=output_dtype, scope="shared") G_fragment = T.alloc_fragment((1, block_S), dtype=output_dtype, scope="shared")
T.copy(G_shared, G_fragment) T.copy(G_shared, G_fragment)
T.cumsum(G_fragment, dim=1, reverse=reverse) T.cumsum(G_fragment, dim=1, reverse=reverse)
if head_first: if head_first:
T.copy(G_fragment, G_new[bb, bh, bs * block_S:(bs + 1) * block_S]) T.copy(G_fragment, G_new[bb, bh, bs * block_S : (bs + 1) * block_S])
else: else:
T.copy(G_fragment, G_new[bb, bs * block_S:(bs + 1) * block_S, bh]) T.copy(G_fragment, G_new[bb, bs * block_S : (bs + 1) * block_S, bh])
else: else:
T.cumsum(G_shared, dim=1, reverse=reverse) T.cumsum(G_shared, dim=1, reverse=reverse)
if head_first: if head_first:
T.copy(G_shared, G_new[bb, bh, bs * block_S:(bs + 1) * block_S]) T.copy(G_shared, G_new[bb, bh, bs * block_S : (bs + 1) * block_S])
else: else:
T.copy(G_shared, G_new[bb, bs * block_S:(bs + 1) * block_S, bh]) T.copy(G_shared, G_new[bb, bs * block_S : (bs + 1) * block_S, bh])
return kernel return kernel
...@@ -113,11 +111,8 @@ def run_test( ...@@ -113,11 +111,8 @@ def run_test(
# reference cumsum # reference cumsum
G_new_ref = chunk_local_cumsum_scalar( G_new_ref = chunk_local_cumsum_scalar(
g=G, g=G, chunk_size=chunk_size, reverse=reverse, head_first=head_first, output_dtype=getattr(torch, output_dtype)
chunk_size=chunk_size, )
reverse=reverse,
head_first=head_first,
output_dtype=getattr(torch, output_dtype))
# tilelang cumsum # tilelang cumsum
block_S = chunk_size block_S = chunk_size
...@@ -159,10 +154,11 @@ def main(): ...@@ -159,10 +154,11 @@ def main():
chunk_size=64, chunk_size=64,
reverse=True, reverse=True,
head_first=False, head_first=False,
input_dtype="float32", input_dtype=T.float32,
output_dtype="float32", output_dtype=T.float32,
threads=256, threads=256,
use_fragment=False) use_fragment=False,
)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -9,6 +9,7 @@ import sys # noqa: F401 ...@@ -9,6 +9,7 @@ import sys # noqa: F401
# sys.path.insert(0, "/home/tzj/flash-linear-attention") # sys.path.insert(0, "/home/tzj/flash-linear-attention")
try: try:
import fla import fla
print(fla.__file__) print(fla.__file__)
from fla.ops.gated_delta_rule.wy_fast import recompute_w_u_fwd from fla.ops.gated_delta_rule.wy_fast import recompute_w_u_fwd
except ImportError: except ImportError:
...@@ -73,13 +74,13 @@ def tilelang_recompute_w_u_fwd( ...@@ -73,13 +74,13 @@ def tilelang_recompute_w_u_fwd(
@T.prim_func @T.prim_func
def kernel( def kernel(
K: T.Tensor(K_shape, dtype=input_dtype), K: T.Tensor(K_shape, dtype=input_dtype),
V: T.Tensor(V_shape, dtype=input_dtype), V: T.Tensor(V_shape, dtype=input_dtype),
Beta: T.Tensor(Beta_shape, dtype=input_dtype), Beta: T.Tensor(Beta_shape, dtype=input_dtype),
G: T.Tensor(G_shape, dtype=gate_dtype), G: T.Tensor(G_shape, dtype=gate_dtype),
A: T.Tensor(A_shape, dtype=output_dtype), A: T.Tensor(A_shape, dtype=output_dtype),
W: T.Tensor(K_shape, dtype=output_dtype), W: T.Tensor(K_shape, dtype=output_dtype),
U: T.Tensor(V_shape, dtype=output_dtype), U: T.Tensor(V_shape, dtype=output_dtype),
): ):
with T.Kernel(T.ceildiv(S, block_S), B * H, threads=threads) as (bs, bbh): with T.Kernel(T.ceildiv(S, block_S), B * H, threads=threads) as (bs, bbh):
bb, bh = bbh // H, bbh % H bb, bh = bbh // H, bbh % H
...@@ -95,49 +96,42 @@ def tilelang_recompute_w_u_fwd( ...@@ -95,49 +96,42 @@ def tilelang_recompute_w_u_fwd(
W_Beta_shared = T.alloc_shared((block_S, block_DK), dtype=input_dtype) W_Beta_shared = T.alloc_shared((block_S, block_DK), dtype=input_dtype)
U_Beta_shared = T.alloc_shared((block_S, block_DV), dtype=input_dtype) U_Beta_shared = T.alloc_shared((block_S, block_DV), dtype=input_dtype)
T.annotate_layout({ T.annotate_layout(
K_shared: tilelang.layout.make_swizzled_layout(K_shared), {
V_shared: tilelang.layout.make_swizzled_layout(V_shared), K_shared: tilelang.layout.make_swizzled_layout(K_shared),
A_shared: tilelang.layout.make_swizzled_layout(A_shared), V_shared: tilelang.layout.make_swizzled_layout(V_shared),
W_shared: tilelang.layout.make_swizzled_layout(W_shared), A_shared: tilelang.layout.make_swizzled_layout(A_shared),
U_shared: tilelang.layout.make_swizzled_layout(U_shared), W_shared: tilelang.layout.make_swizzled_layout(W_shared),
W_Beta_shared: tilelang.layout.make_swizzled_layout(W_Beta_shared), U_shared: tilelang.layout.make_swizzled_layout(U_shared),
U_Beta_shared: tilelang.layout.make_swizzled_layout(U_Beta_shared), W_Beta_shared: tilelang.layout.make_swizzled_layout(W_Beta_shared),
}) U_Beta_shared: tilelang.layout.make_swizzled_layout(U_Beta_shared),
}
)
T.disable_warp_group_reg_alloc() T.disable_warp_group_reg_alloc()
for i_s in T.Parallel(block_S): for i_s in T.Parallel(block_S):
Beta_shared[i_s] = Beta[bb, bs * block_S + i_s, bh] Beta_shared[i_s] = Beta[bb, bs * block_S + i_s, bh]
G_shared[i_s] = T.exp(G[bb, bs * block_S + i_s, bh]) G_shared[i_s] = T.exp(G[bb, bs * block_S + i_s, bh])
T.copy(A[bb, bs * block_S:(bs + 1) * block_S, bh, :], A_shared) T.copy(A[bb, bs * block_S : (bs + 1) * block_S, bh, :], A_shared)
for i_v in T.Pipelined(T.ceildiv(DV, block_DV), num_stages=num_stages): for i_v in T.Pipelined(T.ceildiv(DV, block_DV), num_stages=num_stages):
T.copy( T.copy(V[bb, bs * block_S : (bs + 1) * block_S, bh, i_v * block_DV : (i_v + 1) * block_DV], V_shared)
V[bb, bs * block_S:(bs + 1) * block_S, bh, i_v * block_DV:(i_v + 1) * block_DV],
V_shared)
for i_s, i_v2 in T.Parallel(block_S, block_DV): for i_s, i_v2 in T.Parallel(block_S, block_DV):
U_Beta_shared[i_s, i_v2] = V_shared[i_s, i_v2] * Beta_shared[i_s] U_Beta_shared[i_s, i_v2] = V_shared[i_s, i_v2] * Beta_shared[i_s]
T.gemm(A_shared, U_Beta_shared, U_fragment, clear_accum=True) T.gemm(A_shared, U_Beta_shared, U_fragment, clear_accum=True)
# First copy to smem, then copy to gmem to reduce U2RU instructions # First copy to smem, then copy to gmem to reduce U2RU instructions
T.copy(U_fragment, U_shared) T.copy(U_fragment, U_shared)
T.copy( T.copy(U_shared, U[bb, bs * block_S : (bs + 1) * block_S, bh, i_v * block_DV : (i_v + 1) * block_DV])
U_shared, U[bb, bs * block_S:(bs + 1) * block_S, bh,
i_v * block_DV:(i_v + 1) * block_DV])
for i_k in T.Pipelined(T.ceildiv(DK, block_DK), num_stages=num_stages): for i_k in T.Pipelined(T.ceildiv(DK, block_DK), num_stages=num_stages):
T.copy( T.copy(K[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK], K_shared)
K[bb, bs * block_S:(bs + 1) * block_S, bh, i_k * block_DK:(i_k + 1) * block_DK],
K_shared)
for i_s, i_k2 in T.Parallel(block_S, block_DK): for i_s, i_k2 in T.Parallel(block_S, block_DK):
W_Beta_shared[i_s, W_Beta_shared[i_s, i_k2] = K_shared[i_s, i_k2] * Beta_shared[i_s] * G_shared[i_s]
i_k2] = K_shared[i_s, i_k2] * Beta_shared[i_s] * G_shared[i_s]
T.gemm(A_shared, W_Beta_shared, W_fragment, clear_accum=True) T.gemm(A_shared, W_Beta_shared, W_fragment, clear_accum=True)
# First copy to smem, then copy to gmem to reduce U2RU instructions # First copy to smem, then copy to gmem to reduce U2RU instructions
T.copy(W_fragment, W_shared) T.copy(W_fragment, W_shared)
T.copy( T.copy(W_shared, W[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK])
W_shared, W[bb, bs * block_S:(bs + 1) * block_S, bh,
i_k * block_DK:(i_k + 1) * block_DK])
return kernel return kernel
...@@ -159,15 +153,8 @@ def run_test( ...@@ -159,15 +153,8 @@ def run_test(
num_stages, num_stages,
): ):
K, V, Beta, G, A = prepare_input( K, V, Beta, G, A = prepare_input(
B, B, S, H, DK, DV, chunk_size, getattr(torch, input_dtype), getattr(torch, output_dtype), gate_dtype=getattr(torch, gate_dtype)
S, )
H,
DK,
DV,
chunk_size,
getattr(torch, input_dtype),
getattr(torch, output_dtype),
gate_dtype=getattr(torch, gate_dtype))
W_ref, U_ref = prepare_output(B, S, H, DK, DV, getattr(torch, output_dtype)) W_ref, U_ref = prepare_output(B, S, H, DK, DV, getattr(torch, output_dtype))
W_tilelang, U_tilelang = prepare_output(B, S, H, DK, DV, getattr(torch, output_dtype)) W_tilelang, U_tilelang = prepare_output(B, S, H, DK, DV, getattr(torch, output_dtype))
...@@ -191,7 +178,8 @@ def run_test( ...@@ -191,7 +178,8 @@ def run_test(
block_DK=block_DK, block_DK=block_DK,
block_DV=block_DV, block_DV=block_DV,
threads=threads, threads=threads,
num_stages=num_stages) num_stages=num_stages,
)
print(kernel.get_kernel_source()) print(kernel.get_kernel_source())
W_tilelang, U_tilelang = kernel(K, V, Beta, G, A) W_tilelang, U_tilelang = kernel(K, V, Beta, G, A)
...@@ -217,14 +205,15 @@ def main(): ...@@ -217,14 +205,15 @@ def main():
DK=128, DK=128,
DV=128, DV=128,
chunk_size=64, chunk_size=64,
input_dtype="bfloat16", input_dtype=T.bfloat16,
output_dtype="bfloat16", output_dtype=T.bfloat16,
gate_dtype="float32", gate_dtype=T.float32,
accum_dtype="float32", accum_dtype=T.float32,
block_DK=64, block_DK=64,
block_DV=32, block_DV=32,
threads=128, threads=128,
num_stages=3) num_stages=3,
)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -10,6 +10,7 @@ import tilelang.language as T ...@@ -10,6 +10,7 @@ import tilelang.language as T
# sys.path.insert(0, "/home/tzj/flash-linear-attention") # sys.path.insert(0, "/home/tzj/flash-linear-attention")
try: try:
import fla import fla
print(fla.__file__) print(fla.__file__)
from fla.ops.gated_delta_rule.wy_fast import bwd_prepare_wy_repr from fla.ops.gated_delta_rule.wy_fast import bwd_prepare_wy_repr
except ImportError: except ImportError:
...@@ -93,10 +94,8 @@ def prepare_output( ...@@ -93,10 +94,8 @@ def prepare_output(
@tilelang.jit( @tilelang.jit(
out_idx=[-5, -4, -3, -2, -1], out_idx=[-5, -4, -3, -2, -1],
pass_configs={ pass_configs={tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True},
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, )
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True
})
def tilelang_wy_fast_bwd( def tilelang_wy_fast_bwd(
# task config # task config
B, B,
...@@ -135,20 +134,20 @@ def tilelang_wy_fast_bwd( ...@@ -135,20 +134,20 @@ def tilelang_wy_fast_bwd(
@T.prim_func @T.prim_func
def kernel( def kernel(
# input # input
K: T.Tensor(K_shape, dtype=input_dtype), K: T.Tensor(K_shape, dtype=input_dtype),
V: T.Tensor(V_shape, dtype=input_dtype), V: T.Tensor(V_shape, dtype=input_dtype),
Beta: T.Tensor(Beta_shape, dtype=input_dtype), Beta: T.Tensor(Beta_shape, dtype=input_dtype),
G: T.Tensor(G_shape, dtype=gate_dtype), G: T.Tensor(G_shape, dtype=gate_dtype),
A: T.Tensor(A_shape, dtype=input_dtype), A: T.Tensor(A_shape, dtype=input_dtype),
dw: T.Tensor(dw_shape, dtype=input_dtype), dw: T.Tensor(dw_shape, dtype=input_dtype),
du: T.Tensor(du_shape, dtype=input_dtype), du: T.Tensor(du_shape, dtype=input_dtype),
# output # output
dA: T.Tensor(dA_shape, dtype=input_dtype), dA: T.Tensor(dA_shape, dtype=input_dtype),
dk: T.Tensor(dk_shape, dtype=output_dtype), dk: T.Tensor(dk_shape, dtype=output_dtype),
dv: T.Tensor(dv_shape, dtype=output_dtype), dv: T.Tensor(dv_shape, dtype=output_dtype),
dbeta: T.Tensor(dbeta_shape, dtype=output_dtype), dbeta: T.Tensor(dbeta_shape, dtype=output_dtype),
dg: T.Tensor(dg_shape, dtype=gate_dtype), dg: T.Tensor(dg_shape, dtype=gate_dtype),
): ):
with T.Kernel(T.ceildiv(S, block_S), B * H, threads=threads) as (bs, bbh): with T.Kernel(T.ceildiv(S, block_S), B * H, threads=threads) as (bs, bbh):
bb, bh = bbh // H, bbh % H bb, bh = bbh // H, bbh % H
...@@ -187,7 +186,7 @@ def tilelang_wy_fast_bwd( ...@@ -187,7 +186,7 @@ def tilelang_wy_fast_bwd(
T.clear(dbeta_fragment_v) T.clear(dbeta_fragment_v)
T.clear(dg_fragment) T.clear(dg_fragment)
T.copy(A[bb, bs * block_S:(bs + 1) * block_S, bh, :], A_shared) T.copy(A[bb, bs * block_S : (bs + 1) * block_S, bh, :], A_shared)
for i_s in T.Parallel(block_S): for i_s in T.Parallel(block_S):
Beta_shared[i_s] = Beta[bb, bs * block_S + i_s, bh] Beta_shared[i_s] = Beta[bb, bs * block_S + i_s, bh]
G_shared[i_s] = G[bb, bs * block_S + i_s, bh] G_shared[i_s] = G[bb, bs * block_S + i_s, bh]
...@@ -195,51 +194,37 @@ def tilelang_wy_fast_bwd( ...@@ -195,51 +194,37 @@ def tilelang_wy_fast_bwd(
# Update dk # Update dk
for i_k in T.Pipelined(T.ceildiv(DK, block_DK), num_stages=num_stages): for i_k in T.Pipelined(T.ceildiv(DK, block_DK), num_stages=num_stages):
T.copy( T.copy(K[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK], K_shared)
K[bb, bs * block_S:(bs + 1) * block_S, bh, i_k * block_DK:(i_k + 1) * block_DK],
K_shared)
for i_s, i_k2 in T.Parallel(block_S, block_DK): for i_s, i_k2 in T.Parallel(block_S, block_DK):
K_shared_beta_g[i_s, K_shared_beta_g[i_s, i_k2] = K_shared[i_s, i_k2] * Beta_shared[i_s] * G_shared_exp[i_s]
i_k2] = K_shared[i_s, T.copy(dw[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK], dw_shared)
i_k2] * Beta_shared[i_s] * G_shared_exp[i_s]
T.copy(
dw[bb, bs * block_S:(bs + 1) * block_S, bh,
i_k * block_DK:(i_k + 1) * block_DK], dw_shared)
T.gemm(dw_shared, K_shared_beta_g, dA_fragment, transpose_B=True) T.gemm(dw_shared, K_shared_beta_g, dA_fragment, transpose_B=True)
T.gemm(A_shared, dw_shared, dk_fragment_beta_g, clear_accum=True, transpose_A=True) T.gemm(A_shared, dw_shared, dk_fragment_beta_g, clear_accum=True, transpose_A=True)
for i_s, i_k2 in T.Parallel(block_S, block_DK): for i_s, i_k2 in T.Parallel(block_S, block_DK):
dk_fragment[ dk_fragment[i_s, i_k2] = dk_fragment_beta_g[i_s, i_k2] * Beta_shared[i_s] * G_shared_exp[i_s]
i_s,
i_k2] = dk_fragment_beta_g[i_s, i_k2] * Beta_shared[i_s] * G_shared_exp[i_s]
# for i_s, i_k2 in T.Parallel(block_S, block_DK): # for i_s, i_k2 in T.Parallel(block_S, block_DK):
# dbeta_fragment[i_s] = dbeta_fragment[i_s] + dk_fragment_beta_g[i_s, i_k2] * K_shared[i_s, i_k2] * G_shared_exp[i_s] # dbeta_fragment[i_s] = dbeta_fragment[i_s] + dk_fragment_beta_g[i_s, i_k2] * K_shared[i_s, i_k2] * G_shared_exp[i_s]
for i_s, i_k2 in T.Parallel(block_S, block_DK): for i_s, i_k2 in T.Parallel(block_S, block_DK):
dbeta_fragment_reduce_tmpk[i_s, i_k2] = dk_fragment_beta_g[ dbeta_fragment_reduce_tmpk[i_s, i_k2] = dk_fragment_beta_g[i_s, i_k2] * K_shared[i_s, i_k2] * G_shared_exp[i_s]
i_s, i_k2] * K_shared[i_s, i_k2] * G_shared_exp[i_s]
T.reduce_sum(dbeta_fragment_reduce_tmpk, dbeta_fragment_k, dim=1, clear=False) T.reduce_sum(dbeta_fragment_reduce_tmpk, dbeta_fragment_k, dim=1, clear=False)
# for i_s, i_k2 in T.Parallel(block_S, block_DK): # for i_s, i_k2 in T.Parallel(block_S, block_DK):
# dg_fragment[i_s] = dg_fragment[i_s] + dk_fragment_beta_g[i_s, i_k2] * K_shared[i_s, i_k2] * G_shared_exp[i_s] * Beta_shared[i_s] # dg_fragment[i_s] = dg_fragment[i_s] + dk_fragment_beta_g[i_s, i_k2] * K_shared[i_s, i_k2] * G_shared_exp[i_s] * Beta_shared[i_s]
for i_s, i_k2 in T.Parallel(block_S, block_DK): for i_s, i_k2 in T.Parallel(block_S, block_DK):
dg_fragment_reduce_tmp[i_s, i_k2] = dk_fragment_beta_g[i_s, i_k2] * K_shared[ dg_fragment_reduce_tmp[i_s, i_k2] = (
i_s, i_k2] * G_shared_exp[i_s] * Beta_shared[i_s] dk_fragment_beta_g[i_s, i_k2] * K_shared[i_s, i_k2] * G_shared_exp[i_s] * Beta_shared[i_s]
)
T.reduce_sum(dg_fragment_reduce_tmp, dg_fragment, dim=1, clear=False) T.reduce_sum(dg_fragment_reduce_tmp, dg_fragment, dim=1, clear=False)
# correct dk # correct dk
T.copy( T.copy(dk_fragment, dk[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK])
dk_fragment, dk[bb, bs * block_S:(bs + 1) * block_S, bh,
i_k * block_DK:(i_k + 1) * block_DK])
# Update dv # Update dv
for i_v in T.Pipelined(T.ceildiv(DV, block_DV), num_stages=num_stages): for i_v in T.Pipelined(T.ceildiv(DV, block_DV), num_stages=num_stages):
T.copy( T.copy(V[bb, bs * block_S : (bs + 1) * block_S, bh, i_v * block_DV : (i_v + 1) * block_DV], V_shared)
V[bb, bs * block_S:(bs + 1) * block_S, bh, i_v * block_DV:(i_v + 1) * block_DV],
V_shared)
for i_s, i_v2 in T.Parallel(block_S, block_DV): for i_s, i_v2 in T.Parallel(block_S, block_DV):
V_shared_beta[i_s, i_v2] = V_shared[i_s, i_v2] * Beta_shared[i_s] V_shared_beta[i_s, i_v2] = V_shared[i_s, i_v2] * Beta_shared[i_s]
T.copy( T.copy(du[bb, bs * block_S : (bs + 1) * block_S, bh, i_v * block_DV : (i_v + 1) * block_DV], du_shared)
du[bb, bs * block_S:(bs + 1) * block_S, bh,
i_v * block_DV:(i_v + 1) * block_DV], du_shared)
T.gemm(du_shared, V_shared_beta, dA_fragment, transpose_B=True) T.gemm(du_shared, V_shared_beta, dA_fragment, transpose_B=True)
T.gemm(A_shared, du_shared, dv_fragment_beta, clear_accum=True, transpose_A=True) T.gemm(A_shared, du_shared, dv_fragment_beta, clear_accum=True, transpose_A=True)
for i_s, i_v2 in T.Parallel(block_S, block_DV): for i_s, i_v2 in T.Parallel(block_S, block_DV):
...@@ -247,30 +232,22 @@ def tilelang_wy_fast_bwd( ...@@ -247,30 +232,22 @@ def tilelang_wy_fast_bwd(
# for i_s, i_v2 in T.Parallel(block_S, block_DV): # for i_s, i_v2 in T.Parallel(block_S, block_DV):
# dbeta_fragment[i_s] = dbeta_fragment[i_s] + dv_fragment_beta[i_s, i_v2] * V_shared[i_s, i_v2] # dbeta_fragment[i_s] = dbeta_fragment[i_s] + dv_fragment_beta[i_s, i_v2] * V_shared[i_s, i_v2]
for i_s, i_v2 in T.Parallel(block_S, block_DV): for i_s, i_v2 in T.Parallel(block_S, block_DV):
dbeta_fragment_reduce_tmpv[i_s, dbeta_fragment_reduce_tmpv[i_s, i_v2] = dv_fragment_beta[i_s, i_v2] * V_shared[i_s, i_v2]
i_v2] = dv_fragment_beta[i_s, i_v2] * V_shared[i_s,
i_v2]
T.reduce_sum(dbeta_fragment_reduce_tmpv, dbeta_fragment_v, dim=1, clear=False) T.reduce_sum(dbeta_fragment_reduce_tmpv, dbeta_fragment_v, dim=1, clear=False)
T.copy( T.copy(dv_fragment, dv[bb, bs * block_S : (bs + 1) * block_S, bh, i_v * block_DV : (i_v + 1) * block_DV])
dv_fragment, dv[bb, bs * block_S:(bs + 1) * block_S, bh,
i_v * block_DV:(i_v + 1) * block_DV])
# Temporary store dbeta, dg and dA # Temporary store dbeta, dg and dA
for i_s in T.Parallel(block_S): for i_s in T.Parallel(block_S):
dbeta[bb, bs * block_S + i_s, bh] = dbeta_fragment_k[i_s] + dbeta_fragment_v[i_s] dbeta[bb, bs * block_S + i_s, bh] = dbeta_fragment_k[i_s] + dbeta_fragment_v[i_s]
dg[bb, bs * block_S + i_s, bh] = dg_fragment[i_s] dg[bb, bs * block_S + i_s, bh] = dg_fragment[i_s]
# correct dA # correct dA
T.copy(dA_fragment, dA[bb, bs * block_S:(bs + 1) * block_S, bh, :]) T.copy(dA_fragment, dA[bb, bs * block_S : (bs + 1) * block_S, bh, :])
return kernel return kernel
@tilelang.jit( @tilelang.jit(pass_configs={tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True})
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True
})
def tilelang_wy_fast_bwd_split( def tilelang_wy_fast_bwd_split(
# task config # task config
B, B,
...@@ -308,20 +285,20 @@ def tilelang_wy_fast_bwd_split( ...@@ -308,20 +285,20 @@ def tilelang_wy_fast_bwd_split(
@T.prim_func @T.prim_func
def kernel( def kernel(
# input # input
K: T.Tensor(K_shape, dtype=input_dtype), K: T.Tensor(K_shape, dtype=input_dtype),
V: T.Tensor(V_shape, dtype=input_dtype), V: T.Tensor(V_shape, dtype=input_dtype),
Beta: T.Tensor(Beta_shape, dtype=input_dtype), Beta: T.Tensor(Beta_shape, dtype=input_dtype),
G: T.Tensor(G_shape, dtype=gate_dtype), G: T.Tensor(G_shape, dtype=gate_dtype),
A: T.Tensor(A_shape, dtype=input_dtype), A: T.Tensor(A_shape, dtype=input_dtype),
dw: T.Tensor(dw_shape, dtype=input_dtype), dw: T.Tensor(dw_shape, dtype=input_dtype),
du: T.Tensor(du_shape, dtype=input_dtype), du: T.Tensor(du_shape, dtype=input_dtype),
dA: T.Tensor(dA_shape, dtype=input_dtype), dA: T.Tensor(dA_shape, dtype=input_dtype),
dk: T.Tensor(dk_shape, dtype=output_dtype), dk: T.Tensor(dk_shape, dtype=output_dtype),
dv: T.Tensor(dv_shape, dtype=output_dtype), dv: T.Tensor(dv_shape, dtype=output_dtype),
dbeta_k: T.Tensor(dbeta_shape, dtype=output_dtype), dbeta_k: T.Tensor(dbeta_shape, dtype=output_dtype),
dg_A_positive: T.Tensor(dA_shape, dtype=gate_dtype), dg_A_positive: T.Tensor(dA_shape, dtype=gate_dtype),
dg_A_negative: T.Tensor(dA_shape, dtype=gate_dtype), dg_A_negative: T.Tensor(dA_shape, dtype=gate_dtype),
): ):
with T.Kernel(T.ceildiv(S, block_S), B * H, threads=threads) as (bs, bbh): with T.Kernel(T.ceildiv(S, block_S), B * H, threads=threads) as (bs, bbh):
bb, bh = bbh // H, bbh % H bb, bh = bbh // H, bbh % H
...@@ -350,7 +327,7 @@ def tilelang_wy_fast_bwd_split( ...@@ -350,7 +327,7 @@ def tilelang_wy_fast_bwd_split(
T.clear(dA_A_fragment_1) T.clear(dA_A_fragment_1)
T.clear(dA_A_fragment_2) T.clear(dA_A_fragment_2)
T.copy(A[bb, bs * block_S:(bs + 1) * block_S, bh, :], A_shared) T.copy(A[bb, bs * block_S : (bs + 1) * block_S, bh, :], A_shared)
for i_s in T.Parallel(block_S): for i_s in T.Parallel(block_S):
Beta_shared[i_s] = Beta[bb, bs * block_S + i_s, bh] Beta_shared[i_s] = Beta[bb, bs * block_S + i_s, bh]
G_shared[i_s] = G[bb, bs * block_S + i_s, bh] G_shared[i_s] = G[bb, bs * block_S + i_s, bh]
...@@ -361,7 +338,7 @@ def tilelang_wy_fast_bwd_split( ...@@ -361,7 +338,7 @@ def tilelang_wy_fast_bwd_split(
# for i_s in T.Parallel(block_S): # for i_s in T.Parallel(block_S):
# dbeta_fragment[i_s] = dbeta[bb, bs * block_S + i_s, bh] # dbeta_fragment[i_s] = dbeta[bb, bs * block_S + i_s, bh]
# dg_fragment[i_s] = dg[bb, bs * block_S + i_s, bh] # dg_fragment[i_s] = dg[bb, bs * block_S + i_s, bh]
T.copy(dA[bb, bs * block_S:(bs + 1) * block_S, bh, :], dA_shared) T.copy(dA[bb, bs * block_S : (bs + 1) * block_S, bh, :], dA_shared)
# T.copy(dA_shared, dA[bb, bs * block_S:(bs + 1) * block_S, bh, :]) # T.copy(dA_shared, dA[bb, bs * block_S:(bs + 1) * block_S, bh, :])
# Update dA # Update dA
...@@ -385,8 +362,7 @@ def tilelang_wy_fast_bwd_split( ...@@ -385,8 +362,7 @@ def tilelang_wy_fast_bwd_split(
for i_s1, i_s2 in T.Parallel(block_S, block_S): for i_s1, i_s2 in T.Parallel(block_S, block_S):
with T.If(G[bb, bs * block_S + i_s1, bh] - G[bb, bs * block_S + i_s2, bh] <= 0): with T.If(G[bb, bs * block_S + i_s1, bh] - G[bb, bs * block_S + i_s2, bh] <= 0):
with T.Then(): with T.Then():
dA_fragment[i_s1, i_s2] *= T.exp(G[bb, bs * block_S + i_s1, bh] - dA_fragment[i_s1, i_s2] *= T.exp(G[bb, bs * block_S + i_s1, bh] - G[bb, bs * block_S + i_s2, bh])
G[bb, bs * block_S + i_s2, bh])
with T.Else(): with T.Else():
dA_fragment[i_s1, i_s2] = 0 dA_fragment[i_s1, i_s2] = 0
T.copy(dA_fragment, dA_shared) T.copy(dA_fragment, dA_shared)
...@@ -397,12 +373,8 @@ def tilelang_wy_fast_bwd_split( ...@@ -397,12 +373,8 @@ def tilelang_wy_fast_bwd_split(
# Update dk using previous dk # Update dk using previous dk
T.clear(A_fragment) T.clear(A_fragment)
for i_k in T.Pipelined(T.ceildiv(DK, block_DK), num_stages=num_stages): for i_k in T.Pipelined(T.ceildiv(DK, block_DK), num_stages=num_stages):
T.copy( T.copy(K[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK], K_shared)
K[bb, bs * block_S:(bs + 1) * block_S, bh, i_k * block_DK:(i_k + 1) * block_DK], T.copy(dk[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK], dk_shared)
K_shared)
T.copy(
dk[bb, bs * block_S:(bs + 1) * block_S, bh,
i_k * block_DK:(i_k + 1) * block_DK], dk_shared)
T.copy(dk_shared, dk_fragment) T.copy(dk_shared, dk_fragment)
for i_s, i_k2 in T.Parallel(block_S, block_DK): for i_s, i_k2 in T.Parallel(block_S, block_DK):
K_shared_beta[i_s, i_k2] = K_shared[i_s, i_k2] * Beta_shared[i_s] K_shared_beta[i_s, i_k2] = K_shared[i_s, i_k2] * Beta_shared[i_s]
...@@ -411,18 +383,14 @@ def tilelang_wy_fast_bwd_split( ...@@ -411,18 +383,14 @@ def tilelang_wy_fast_bwd_split(
# for i_s, i_k2 in T.Parallel(block_S, block_DK): # for i_s, i_k2 in T.Parallel(block_S, block_DK):
# dbeta_fragment[i_s] = dbeta_fragment[i_s] + dk_fragment_beta[i_s, i_k2] * K_shared[i_s, i_k2] # dbeta_fragment[i_s] = dbeta_fragment[i_s] + dk_fragment_beta[i_s, i_k2] * K_shared[i_s, i_k2]
for i_s, i_k2 in T.Parallel(block_S, block_DK): for i_s, i_k2 in T.Parallel(block_S, block_DK):
dbeta_fragment_reduce_tmpk[i_s, dbeta_fragment_reduce_tmpk[i_s, i_k2] = dk_fragment_beta[i_s, i_k2] * K_shared[i_s, i_k2]
i_k2] = dk_fragment_beta[i_s, i_k2] * K_shared[i_s,
i_k2]
T.reduce_sum(dbeta_fragment_reduce_tmpk, dbeta_fragment_k, dim=1, clear=False) T.reduce_sum(dbeta_fragment_reduce_tmpk, dbeta_fragment_k, dim=1, clear=False)
T.gemm(dA_shared, K_shared_beta, dk_fragment, transpose_A=True) T.gemm(dA_shared, K_shared_beta, dk_fragment, transpose_A=True)
for i_s, i_k2 in T.Parallel(block_S, block_DK): for i_s, i_k2 in T.Parallel(block_S, block_DK):
dk_shared_beta[i_s, i_k2] = dk_fragment_beta[i_s, i_k2] * Beta_shared[i_s] dk_shared_beta[i_s, i_k2] = dk_fragment_beta[i_s, i_k2] * Beta_shared[i_s]
for i_s, i_k2 in T.Parallel(block_S, block_DK): for i_s, i_k2 in T.Parallel(block_S, block_DK):
dk_fragment[i_s, i_k2] = dk_fragment[i_s, i_k2] + dk_shared_beta[i_s, i_k2] dk_fragment[i_s, i_k2] = dk_fragment[i_s, i_k2] + dk_shared_beta[i_s, i_k2]
T.copy( T.copy(dk_fragment, dk[bb, bs * block_S : (bs + 1) * block_S, bh, i_k * block_DK : (i_k + 1) * block_DK])
dk_fragment, dk[bb, bs * block_S:(bs + 1) * block_S, bh,
i_k * block_DK:(i_k + 1) * block_DK])
# Update dg and dbeta # Update dg and dbeta
T.copy(A_fragment, A_shared) T.copy(A_fragment, A_shared)
...@@ -460,19 +428,25 @@ def run_test( ...@@ -460,19 +428,25 @@ def run_test(
threads=128, threads=128,
num_stages=0, num_stages=0,
): ):
K, V, Beta, G, A, dw, du = prepare_input(B, S, H, DK, DV, chunk_size, K, V, Beta, G, A, dw, du = prepare_input(
getattr(torch, input_dtype), B,
getattr(torch, output_dtype), S,
getattr(torch, H,
accum_dtype), getattr(torch, gate_dtype), DK,
getattr(torch, state_dtype)) DV,
dk_ref, dv_ref, dbeta_ref, dg_ref = prepare_output(B, S, H, DK, DV, chunk_size, chunk_size,
getattr(torch, output_dtype), getattr(torch, input_dtype),
getattr(torch, gate_dtype), getattr(torch, output_dtype),
getattr(torch, state_dtype)) getattr(torch, accum_dtype),
getattr(torch, gate_dtype),
getattr(torch, state_dtype),
)
dk_ref, dv_ref, dbeta_ref, dg_ref = prepare_output(
B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype), getattr(torch, state_dtype)
)
dk_tilelang, dv_tilelang, dbeta_tilelang, dg_tilelang = prepare_output( dk_tilelang, dv_tilelang, dbeta_tilelang, dg_tilelang = prepare_output(
B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype), B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype), getattr(torch, state_dtype)
getattr(torch, state_dtype)) )
BS = chunk_size BS = chunk_size
dA_tilelang = torch.empty(B, S, H, BS, dtype=getattr(torch, input_dtype)).cuda() dA_tilelang = torch.empty(B, S, H, BS, dtype=getattr(torch, input_dtype)).cuda()
dbeta_tilelang_k = torch.empty(B, S, H, dtype=getattr(torch, output_dtype)).cuda() dbeta_tilelang_k = torch.empty(B, S, H, dtype=getattr(torch, output_dtype)).cuda()
...@@ -480,28 +454,55 @@ def run_test( ...@@ -480,28 +454,55 @@ def run_test(
dg_tilelang_A_negative = torch.empty(B, S, H, BS, dtype=getattr(torch, gate_dtype)).cuda() dg_tilelang_A_negative = torch.empty(B, S, H, BS, dtype=getattr(torch, gate_dtype)).cuda()
# ref # ref
dk_ref, dv_ref, dbeta_ref, dg_ref = bwd_prepare_wy_repr( dk_ref, dv_ref, dbeta_ref, dg_ref = bwd_prepare_wy_repr(K, V, G, Beta, A, dw, du, cu_seqlens=None)
K, V, G, Beta, A, dw, du, cu_seqlens=None)
# tilelang # tilelang
kernel = tilelang_wy_fast_bwd(B, S, H, DK, DV, input_dtype, output_dtype, accum_dtype, kernel = tilelang_wy_fast_bwd(
gate_dtype, state_dtype, chunk_size, block_DK, block_DV, threads, B,
num_stages) S,
dA_tilelang, dk_tilelang, dv_tilelang, dbeta_tilelang, dg_tilelang = kernel( H,
K, V, Beta, G, A, dw, du) DK,
DV,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
state_dtype,
chunk_size,
block_DK,
block_DV,
threads,
num_stages,
)
dA_tilelang, dk_tilelang, dv_tilelang, dbeta_tilelang, dg_tilelang = kernel(K, V, Beta, G, A, dw, du)
torch.cuda.synchronize() torch.cuda.synchronize()
kernel_split = tilelang_wy_fast_bwd_split(B, S, H, DK, DV, input_dtype, output_dtype, kernel_split = tilelang_wy_fast_bwd_split(
accum_dtype, gate_dtype, state_dtype, chunk_size, B,
block_DK, block_DV, threads, num_stages) S,
kernel_split(K, V, Beta, G, A, dw, du, dA_tilelang, dk_tilelang, dv_tilelang, dbeta_tilelang_k, H,
dg_tilelang_A_positive, dg_tilelang_A_negative) DK,
DV,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
state_dtype,
chunk_size,
block_DK,
block_DV,
threads,
num_stages,
)
kernel_split(
K, V, Beta, G, A, dw, du, dA_tilelang, dk_tilelang, dv_tilelang, dbeta_tilelang_k, dg_tilelang_A_positive, dg_tilelang_A_negative
)
torch.cuda.synchronize() torch.cuda.synchronize()
dbeta_tilelang = dbeta_tilelang_k + dbeta_tilelang dbeta_tilelang = dbeta_tilelang_k + dbeta_tilelang
dg_tilelang = dg_tilelang + dg_tilelang_A_positive.sum(dim=-1) - dg_tilelang_A_negative.sum( dg_tilelang = dg_tilelang + dg_tilelang_A_positive.sum(dim=-1) - dg_tilelang_A_negative.sum(dim=-1)
dim=-1)
from test_utils import assert_similar
from utils import assert_similar
assert_similar(dk_ref, dk_tilelang, eps=1e-5, name="dk", raise_assert=False) assert_similar(dk_ref, dk_tilelang, eps=1e-5, name="dk", raise_assert=False)
assert_similar(dv_ref, dv_tilelang, eps=1e-5, name="dv", raise_assert=False) assert_similar(dv_ref, dv_tilelang, eps=1e-5, name="dv", raise_assert=False)
assert_similar(dbeta_ref, dbeta_tilelang, eps=1e-5, name="dbeta", raise_assert=False) assert_similar(dbeta_ref, dbeta_tilelang, eps=1e-5, name="dbeta", raise_assert=False)
...@@ -517,11 +518,11 @@ def main(): ...@@ -517,11 +518,11 @@ def main():
H=8, H=8,
DK=DK, DK=DK,
DV=DV, DV=DV,
input_dtype="bfloat16", input_dtype=T.bfloat16,
output_dtype="bfloat16", output_dtype=T.bfloat16,
accum_dtype="float32", accum_dtype=T.float32,
gate_dtype="float32", gate_dtype=T.float32,
state_dtype="float32", state_dtype=T.float32,
chunk_size=64, chunk_size=64,
block_DK=32, block_DK=32,
block_DV=32, block_DV=32,
......
import tilelang.testing
import torch import torch
import tilelang.testing
from tilelang import language as T
B = 1 B = 1
S = 1024 # small but for test only. S = 1024 # small but for test only.
H = 32 H = 32
DK = 128 DK = 128
DV = 128 DV = 128
input_dtype = "bfloat16" input_dtype = T.bfloat16
output_dtype = "bfloat16" output_dtype = T.bfloat16
accum_dtype = "float32" accum_dtype = T.float32
gate_dtype = "float32" gate_dtype = T.float32
state_dtype = "float32" state_dtype = T.float32
chunk_size = 64 chunk_size = 64
use_g = True use_g = True
use_initial_state = True use_initial_state = True
...@@ -25,16 +26,10 @@ num_stages = 1 ...@@ -25,16 +26,10 @@ num_stages = 1
def test_example_wy_fast_compilation(): def test_example_wy_fast_compilation():
from example_wy_fast import tilelang_recompute_w_u_fwd, prepare_input from example_wy_fast import tilelang_recompute_w_u_fwd, prepare_input
K, V, Beta, G, A = prepare_input( K, V, Beta, G, A = prepare_input(
B, B, S, H, DK, DV, chunk_size, getattr(torch, input_dtype), getattr(torch, output_dtype), gate_dtype=getattr(torch, gate_dtype)
S, )
H,
DK,
DV,
chunk_size,
getattr(torch, input_dtype),
getattr(torch, output_dtype),
gate_dtype=getattr(torch, gate_dtype))
# tilelang # tilelang
block_S = chunk_size block_S = chunk_size
kernel = tilelang_recompute_w_u_fwd( kernel = tilelang_recompute_w_u_fwd(
...@@ -52,22 +47,31 @@ def test_example_wy_fast_compilation(): ...@@ -52,22 +47,31 @@ def test_example_wy_fast_compilation():
block_DK=block_DK, block_DK=block_DK,
block_DV=block_DV, block_DV=block_DV,
threads=threads, threads=threads,
num_stages=num_stages) num_stages=num_stages,
)
print(kernel.get_kernel_source()) print(kernel.get_kernel_source())
W_tilelang, U_tilelang = kernel(K, V, Beta, G, A) W_tilelang, U_tilelang = kernel(K, V, Beta, G, A)
def test_example_wy_fast_bwd_split_compilation(): def test_example_wy_fast_bwd_split_compilation():
from example_wy_fast_bwd_split import tilelang_wy_fast_bwd, tilelang_wy_fast_bwd_split, prepare_input, prepare_output from example_wy_fast_bwd_split import tilelang_wy_fast_bwd, tilelang_wy_fast_bwd_split, prepare_input, prepare_output
K, V, Beta, G, A, dw, du = prepare_input(B, S, H, DK, DV, chunk_size,
getattr(torch, input_dtype), K, V, Beta, G, A, dw, du = prepare_input(
getattr(torch, output_dtype), B,
getattr(torch, S,
accum_dtype), getattr(torch, gate_dtype), H,
getattr(torch, state_dtype)) DK,
DV,
chunk_size,
getattr(torch, input_dtype),
getattr(torch, output_dtype),
getattr(torch, accum_dtype),
getattr(torch, gate_dtype),
getattr(torch, state_dtype),
)
dk_tilelang, dv_tilelang, dbeta_tilelang, dg_tilelang = prepare_output( dk_tilelang, dv_tilelang, dbeta_tilelang, dg_tilelang = prepare_output(
B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype), B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype), getattr(torch, state_dtype)
getattr(torch, state_dtype)) )
BS = chunk_size BS = chunk_size
dA_tilelang = torch.empty(B, S, H, BS, dtype=getattr(torch, input_dtype)).cuda() dA_tilelang = torch.empty(B, S, H, BS, dtype=getattr(torch, input_dtype)).cuda()
dbeta_tilelang_k = torch.empty(B, S, H, dtype=getattr(torch, output_dtype)).cuda() dbeta_tilelang_k = torch.empty(B, S, H, dtype=getattr(torch, output_dtype)).cuda()
...@@ -75,67 +79,146 @@ def test_example_wy_fast_bwd_split_compilation(): ...@@ -75,67 +79,146 @@ def test_example_wy_fast_bwd_split_compilation():
dg_tilelang_A_negative = torch.empty(B, S, H, BS, dtype=getattr(torch, gate_dtype)).cuda() dg_tilelang_A_negative = torch.empty(B, S, H, BS, dtype=getattr(torch, gate_dtype)).cuda()
# tilelang # tilelang
kernel = tilelang_wy_fast_bwd(B, S, H, DK, DV, input_dtype, output_dtype, accum_dtype, kernel = tilelang_wy_fast_bwd(
gate_dtype, state_dtype, chunk_size, block_DK, block_DV, threads, B,
num_stages) S,
dA_tilelang, dk_tilelang, dv_tilelang, dbeta_tilelang, dg_tilelang = kernel( H,
K, V, Beta, G, A, dw, du) DK,
DV,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
state_dtype,
chunk_size,
block_DK,
block_DV,
threads,
num_stages,
)
dA_tilelang, dk_tilelang, dv_tilelang, dbeta_tilelang, dg_tilelang = kernel(K, V, Beta, G, A, dw, du)
torch.cuda.synchronize() torch.cuda.synchronize()
kernel_split = tilelang_wy_fast_bwd_split(B, S, H, DK, DV, input_dtype, output_dtype, kernel_split = tilelang_wy_fast_bwd_split(
accum_dtype, gate_dtype, state_dtype, chunk_size, B,
block_DK, block_DV, threads, num_stages) S,
kernel_split(K, V, Beta, G, A, dw, du, dA_tilelang, dk_tilelang, dv_tilelang, dbeta_tilelang_k, H,
dg_tilelang_A_positive, dg_tilelang_A_negative) DK,
DV,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
state_dtype,
chunk_size,
block_DK,
block_DV,
threads,
num_stages,
)
kernel_split(
K, V, Beta, G, A, dw, du, dA_tilelang, dk_tilelang, dv_tilelang, dbeta_tilelang_k, dg_tilelang_A_positive, dg_tilelang_A_negative
)
torch.cuda.synchronize() torch.cuda.synchronize()
dbeta_tilelang = dbeta_tilelang_k + dbeta_tilelang dbeta_tilelang = dbeta_tilelang_k + dbeta_tilelang
dg_tilelang = dg_tilelang + dg_tilelang_A_positive.sum(dim=-1) - dg_tilelang_A_negative.sum( dg_tilelang = dg_tilelang + dg_tilelang_A_positive.sum(dim=-1) - dg_tilelang_A_negative.sum(dim=-1)
dim=-1)
def test_example_chunk_o_compilation(): def test_example_chunk_o_compilation():
from example_chunk_o import tilelang_chunk_fwd_o, prepare_input from example_chunk_o import tilelang_chunk_fwd_o, prepare_input
Q, K, V, HIDDEN, G = prepare_input(B, S, H, DK, DV, chunk_size, getattr(torch, input_dtype),
getattr(torch, output_dtype), getattr(torch, accum_dtype), Q, K, V, HIDDEN, G = prepare_input(
getattr(torch, gate_dtype)) B,
S,
H,
DK,
DV,
chunk_size,
getattr(torch, input_dtype),
getattr(torch, output_dtype),
getattr(torch, accum_dtype),
getattr(torch, gate_dtype),
)
scale = 1.0 / DK**0.5 scale = 1.0 / DK**0.5
block_S = chunk_size block_S = chunk_size
kernel = tilelang_chunk_fwd_o(B, S, H, DK, DV, input_dtype, output_dtype, accum_dtype, kernel = tilelang_chunk_fwd_o(
gate_dtype, chunk_size, scale, use_g, block_S, block_DK, block_DV, B,
threads, num_stages) S,
H,
DK,
DV,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
chunk_size,
scale,
use_g,
block_S,
block_DK,
block_DV,
threads,
num_stages,
)
O_tilelang = kernel(Q, K, V, HIDDEN, G) # noqa: F841 O_tilelang = kernel(Q, K, V, HIDDEN, G) # noqa: F841
def test_example_chunk_o_bwd_compilation(): def test_example_chunk_o_bwd_compilation():
from example_chunk_o_bwd import tilelang_chunk_o_bwd_dqkwg, prepare_input from example_chunk_o_bwd import tilelang_chunk_o_bwd_dqkwg, prepare_input
Q, K, V, h, G, dO, dh, dv, W = prepare_input(B, S, H, DK, DV, chunk_size,
getattr(torch, input_dtype), Q, K, V, h, G, dO, dh, dv, W = prepare_input(
getattr(torch, output_dtype), B,
getattr(torch, accum_dtype), S,
getattr(torch, gate_dtype), H,
getattr(torch, state_dtype)) DK,
kernel = tilelang_chunk_o_bwd_dqkwg(B, S, H, DK, DV, input_dtype, output_dtype, accum_dtype, DV,
gate_dtype, state_dtype, chunk_size, 1.0, use_g, True, chunk_size,
block_DK, block_DV, threads, num_stages) getattr(torch, input_dtype),
dq_tilelang, dk_tilelang, dw_tilelang, dg_tilelang = kernel(Q, K, V, h, G, dO, dh, dv, getattr(torch, output_dtype),
W) # noqa: F841 getattr(torch, accum_dtype),
getattr(torch, gate_dtype),
getattr(torch, state_dtype),
)
kernel = tilelang_chunk_o_bwd_dqkwg(
B,
S,
H,
DK,
DV,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
state_dtype,
chunk_size,
1.0,
use_g,
True,
block_DK,
block_DV,
threads,
num_stages,
)
dq_tilelang, dk_tilelang, dw_tilelang, dg_tilelang = kernel(Q, K, V, h, G, dO, dh, dv, W) # noqa: F841
if use_g: if use_g:
dg_tilelang = dg_tilelang.sum(dim=0) dg_tilelang = dg_tilelang.sum(dim=0)
def test_example_chunk_scaled_dot_kkt_compilation(): def test_example_chunk_scaled_dot_kkt_compilation():
from example_chunk_scaled_dot_kkt import tilelang_chunk_scaled_dot_kkt_fwd, prepare_input from example_chunk_scaled_dot_kkt import tilelang_chunk_scaled_dot_kkt_fwd, prepare_input
K, Beta, G = prepare_input(B, S, H, DK, getattr(torch, input_dtype),
getattr(torch, output_dtype), getattr(torch, accum_dtype)) K, Beta, G = prepare_input(B, S, H, DK, getattr(torch, input_dtype), getattr(torch, output_dtype), getattr(torch, accum_dtype))
block_S = chunk_size block_S = chunk_size
kernel = tilelang_chunk_scaled_dot_kkt_fwd(B, S, H, DK, chunk_size, input_dtype, output_dtype, kernel = tilelang_chunk_scaled_dot_kkt_fwd(
accum_dtype, use_g, block_S, block_DK, threads, B, S, H, DK, chunk_size, input_dtype, output_dtype, accum_dtype, use_g, block_S, block_DK, threads, num_stages
num_stages) )
A_tilelang = kernel(K, Beta, G) # noqa: F841 A_tilelang = kernel(K, Beta, G) # noqa: F841
def test_example_cumsum_compilation(): def test_example_cumsum_compilation():
from example_cumsum import tilelang_chunk_local_cumsum_scalar, prepare_cumsum_input, prepare_cumsum_output from example_cumsum import tilelang_chunk_local_cumsum_scalar, prepare_cumsum_input, prepare_cumsum_output
G = prepare_cumsum_input(B, S, H, getattr(torch, gate_dtype)) G = prepare_cumsum_input(B, S, H, getattr(torch, gate_dtype))
G_new_tilelang = prepare_cumsum_output(B, S, H, getattr(torch, gate_dtype)) G_new_tilelang = prepare_cumsum_output(B, S, H, getattr(torch, gate_dtype))
block_S = chunk_size block_S = chunk_size
...@@ -157,33 +240,79 @@ def test_example_cumsum_compilation(): ...@@ -157,33 +240,79 @@ def test_example_cumsum_compilation():
def test_example_chunk_delta_h_compilation(): def test_example_chunk_delta_h_compilation():
from example_chunk_delta_h import tilelang_chunk_gated_delta_rule_fwd_h, prepare_input from example_chunk_delta_h import tilelang_chunk_gated_delta_rule_fwd_h, prepare_input
K, W, U, G, initial_state = prepare_input(B, S, H, DK, DV, chunk_size,
getattr(torch, input_dtype), K, W, U, G, initial_state = prepare_input(
getattr(torch, output_dtype), B,
getattr(torch, accum_dtype), S,
getattr(torch, gate_dtype)) H,
kernel = tilelang_chunk_gated_delta_rule_fwd_h(B, S, H, DK, DV, input_dtype, output_dtype, DK,
accum_dtype, gate_dtype, state_dtype, chunk_size, DV,
use_g, use_initial_state, store_final_state, chunk_size,
save_new_value, block_DK, block_DV, threads, getattr(torch, input_dtype),
num_stages) getattr(torch, output_dtype),
h_tilelang, final_state_tilelang, V_new_tilelang = kernel(K, W, U, G, getattr(torch, accum_dtype),
initial_state) # noqa: F841 getattr(torch, gate_dtype),
)
kernel = tilelang_chunk_gated_delta_rule_fwd_h(
B,
S,
H,
DK,
DV,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
state_dtype,
chunk_size,
use_g,
use_initial_state,
store_final_state,
save_new_value,
block_DK,
block_DV,
threads,
num_stages,
)
h_tilelang, final_state_tilelang, V_new_tilelang = kernel(K, W, U, G, initial_state) # noqa: F841
def test_example_chunk_delta_bwd_compilation(): def test_example_chunk_delta_bwd_compilation():
from example_chunk_delta_bwd import tilelang_chunk_gated_delta_rule_bwd_dhu, prepare_input from example_chunk_delta_bwd import tilelang_chunk_gated_delta_rule_bwd_dhu, prepare_input
Q, K, W, G, h0, dht, dO, dv = prepare_input(B, S, H, DK, DV, chunk_size,
getattr(torch, input_dtype), Q, K, W, G, h0, dht, dO, dv = prepare_input(
getattr(torch, output_dtype), B,
getattr(torch, accum_dtype), S,
getattr(torch, gate_dtype), H,
getattr(torch, state_dtype)) DK,
kernel = tilelang_chunk_gated_delta_rule_bwd_dhu(B, S, H, DK, DV, input_dtype, output_dtype, DV,
accum_dtype, gate_dtype, state_dtype, chunk_size,
chunk_size, 1.0, use_g, use_initial_state, getattr(torch, input_dtype),
use_final_state_gradient, block_DV, threads, getattr(torch, output_dtype),
num_stages) getattr(torch, accum_dtype),
getattr(torch, gate_dtype),
getattr(torch, state_dtype),
)
kernel = tilelang_chunk_gated_delta_rule_bwd_dhu(
B,
S,
H,
DK,
DV,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
state_dtype,
chunk_size,
1.0,
use_g,
use_initial_state,
use_final_state_gradient,
block_DV,
threads,
num_stages,
)
dh_tilelang, dh0_tilelang, dv2_tilelang = kernel(Q, K, W, G, h0, dht, dO, dv) # noqa: F841 dh_tilelang, dh0_tilelang, dv2_tilelang = kernel(Q, K, W, G, h0, dht, dO, dv) # noqa: F841
......
...@@ -9,7 +9,7 @@ def calc_sim(x, y, name="tensor"): ...@@ -9,7 +9,7 @@ def calc_sim(x, y, name="tensor"):
x, y = x.data.double(), y.data.double() x, y = x.data.double(), y.data.double()
denominator = (x * x + y * y).sum() denominator = (x * x + y * y).sum()
if denominator == 0: if denominator == 0:
print_red_warning(f'{name} all zero') print_red_warning(f"{name} all zero")
return 1 return 1
sim = 2 * (x * y).sum() / denominator sim = 2 * (x * y).sum() / denominator
return sim return sim
...@@ -19,21 +19,19 @@ def assert_similar(x, y, eps=1e-8, name="tensor", data="", raise_assert=True): ...@@ -19,21 +19,19 @@ def assert_similar(x, y, eps=1e-8, name="tensor", data="", raise_assert=True):
x_mask = torch.isfinite(x) x_mask = torch.isfinite(x)
y_mask = torch.isfinite(y) y_mask = torch.isfinite(y)
if not torch.all(x_mask == y_mask): if not torch.all(x_mask == y_mask):
print_red_warning(f'{name} Error: isfinite mask mismatch') print_red_warning(f"{name} Error: isfinite mask mismatch")
if raise_assert: if raise_assert:
raise AssertionError raise AssertionError
if not torch.isclose( if not torch.isclose(x.masked_fill(x_mask, 0), y.masked_fill(y_mask, 0), rtol=0, atol=0, equal_nan=True).all():
x.masked_fill(x_mask, 0), y.masked_fill(y_mask, 0), rtol=0, atol=0, print_red_warning(f"{name} Error: nonfinite value mismatch")
equal_nan=True).all():
print_red_warning(f'{name} Error: nonfinite value mismatch')
if raise_assert: if raise_assert:
raise AssertionError raise AssertionError
x = x.masked_fill(~x_mask, 0) x = x.masked_fill(~x_mask, 0)
y = y.masked_fill(~y_mask, 0) y = y.masked_fill(~y_mask, 0)
sim = calc_sim(x, y, name) sim = calc_sim(x, y, name)
diff = 1. - sim diff = 1.0 - sim
if not (0 <= diff <= eps): if not (0 <= diff <= eps):
print_red_warning(f'{name} Error: {diff}') print_red_warning(f"{name} Error: {diff}")
if raise_assert: if raise_assert:
raise AssertionError raise AssertionError
else: else:
......
...@@ -53,7 +53,7 @@ import tilelang ...@@ -53,7 +53,7 @@ import tilelang
from tilelang import Profiler from tilelang import Profiler
import tilelang.language as T import tilelang.language as T
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float):
@T.prim_func @T.prim_func
def main( def main(
A: T.Tensor((M, K), dtype), A: T.Tensor((M, K), dtype),
...@@ -176,7 +176,7 @@ import tilelang.language as T ...@@ -176,7 +176,7 @@ import tilelang.language as T
# that helps align data for MMA (Matrix Multiply-Accumulate) operations. # that helps align data for MMA (Matrix Multiply-Accumulate) operations.
from tilelang.intrinsics import make_mma_swizzle_layout as make_swizzle_layout from tilelang.intrinsics import make_mma_swizzle_layout as make_swizzle_layout
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float):
@T.prim_func @T.prim_func
def main( def main(
A: T.Tensor((M, K), dtype), A: T.Tensor((M, K), dtype),
...@@ -265,18 +265,18 @@ def tl_matmul( ...@@ -265,18 +265,18 @@ def tl_matmul(
accum_dtype, accum_dtype,
): ):
assert in_dtype in [ assert in_dtype in [
"float16", T.float16,
"int8", T.int8,
], "Currently only float16 and int8 are supported" ], "Currently only float16 and int8 are supported"
assert out_dtype in [ assert out_dtype in [
"float16", T.float16,
"float32", T.float32,
"int32", T.int32,
], "Currently only float16, float32 and int32 are supported" ], "Currently only float16, float32 and int32 are supported"
micro_size_x = micro_size_y = micro_size_k = 16 micro_size_x = micro_size_y = micro_size_k = 16
if out_dtype == "int32": if out_dtype == T.int32:
micro_size_k = 32 micro_size_k = 32
# This is a debug config # This is a debug config
......
...@@ -3,13 +3,12 @@ import tilelang.language as T ...@@ -3,13 +3,12 @@ import tilelang.language as T
@tilelang.jit(out_idx=[-1]) @tilelang.jit(out_idx=[-1])
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32):
@T.prim_func @T.prim_func
def gemm( def gemm(
A: T.Tensor((M, K), dtype), A: T.Tensor((M, K), dtype),
B: T.Tensor((K, N), dtype), B: T.Tensor((K, N), dtype),
C: T.Tensor((M, N), dtype), C: T.Tensor((M, N), dtype),
): ):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K), dtype) A_shared = T.alloc_shared((block_M, block_K), dtype)
......
...@@ -51,9 +51,9 @@ def get_configs(M, N, K, with_roller=False, topk=20): ...@@ -51,9 +51,9 @@ def get_configs(M, N, K, with_roller=False, topk=20):
M=M, M=M,
N=N, N=N,
K=K, K=K,
in_dtype="float16", in_dtype=T.float16,
out_dtype="float16", out_dtype=T.float16,
accum_dtype="float", accum_dtype=T.float32,
).with_arch(arch) ).with_arch(arch)
func = carve_template.equivalent_function() func = carve_template.equivalent_function()
...@@ -90,7 +90,8 @@ def get_configs(M, N, K, with_roller=False, topk=20): ...@@ -90,7 +90,8 @@ def get_configs(M, N, K, with_roller=False, topk=20):
num_stages, num_stages,
thread_num, thread_num,
enable_rasterization, enable_rasterization,
)) )
)
configs = [ configs = [
{ {
...@@ -100,13 +101,13 @@ def get_configs(M, N, K, with_roller=False, topk=20): ...@@ -100,13 +101,13 @@ def get_configs(M, N, K, with_roller=False, topk=20):
"num_stages": c[3], "num_stages": c[3],
"thread_num": c[4], "thread_num": c[4],
"enable_rasteration": c[5], # keep param name for backward-compat "enable_rasteration": c[5], # keep param name for backward-compat
} for c in _configs }
for c in _configs
] ]
return configs return configs
def get_best_config(M, N, K, with_roller=False): def get_best_config(M, N, K, with_roller=False):
def kernel( def kernel(
block_M=None, block_M=None,
block_N=None, block_N=None,
...@@ -115,17 +116,16 @@ def get_best_config(M, N, K, with_roller=False): ...@@ -115,17 +116,16 @@ def get_best_config(M, N, K, with_roller=False):
thread_num=None, thread_num=None,
enable_rasteration=None, enable_rasteration=None,
): ):
dtype = "bfloat16" dtype = T.bfloat16
accum_dtype = "float" accum_dtype = T.float32
@T.prim_func @T.prim_func
def main( def main(
A: T.Tensor((M, K), dtype), A: T.Tensor((M, K), dtype),
B: T.Tensor((N, K), dtype), B: T.Tensor((N, K), dtype),
C: T.Tensor((M, N), dtype), C: T.Tensor((M, N), dtype),
): ):
with T.Kernel( with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K), dtype) A_shared = T.alloc_shared((block_M, block_K), dtype)
B_shared = T.alloc_shared((block_N, block_K), dtype) B_shared = T.alloc_shared((block_N, block_K), dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype) C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
...@@ -146,15 +146,18 @@ def get_best_config(M, N, K, with_roller=False): ...@@ -146,15 +146,18 @@ def get_best_config(M, N, K, with_roller=False):
return main return main
autotuner = AutoTuner.from_kernel( autotuner = (
kernel=kernel, configs=get_configs(M, N, K, with_roller)).set_compile_args( AutoTuner.from_kernel(kernel=kernel, configs=get_configs(M, N, K, with_roller))
.set_compile_args(
out_idx=[-1], out_idx=[-1],
target="auto", target="auto",
).set_profile_args( )
.set_profile_args(
supply_type=tl.TensorSupplyType.Integer, supply_type=tl.TensorSupplyType.Integer,
ref_prog=ref_program, ref_prog=ref_program,
skip_check=False, skip_check=False,
) )
)
return autotuner.run(warmup=3, rep=20) return autotuner.run(warmup=3, rep=20)
...@@ -167,52 +170,20 @@ def get_heuristic_config() -> dict: ...@@ -167,52 +170,20 @@ def get_heuristic_config() -> dict:
sm_version = sm_major * 10 + sm_minor sm_version = sm_major * 10 + sm_minor
print(f"CUDA device capability: {sm_version}") print(f"CUDA device capability: {sm_version}")
if sm_version in {80}: if sm_version in {80}:
return { return {"block_M": 128, "block_N": 256, "block_K": 32, "num_stages": 2, "thread_num": 128, "enable_rasteration": True}
"block_M": 128,
"block_N": 256,
"block_K": 32,
"num_stages": 2,
"thread_num": 128,
"enable_rasteration": True
}
elif sm_version in {90}: elif sm_version in {90}:
return { return {"block_M": 128, "block_N": 256, "block_K": 64, "num_stages": 3, "thread_num": 256, "enable_rasteration": True}
"block_M": 128,
"block_N": 256,
"block_K": 64,
"num_stages": 3,
"thread_num": 256,
"enable_rasteration": True
}
else: else:
return { return {"block_M": 128, "block_N": 256, "block_K": 32, "num_stages": 0, "thread_num": 128, "enable_rasteration": True}
"block_M": 128,
"block_N": 256,
"block_K": 32,
"num_stages": 0,
"thread_num": 128,
"enable_rasteration": True
}
@tl.jit(out_idx=[-1]) @tl.jit(out_idx=[-1])
def matmul(M, def matmul(M, N, K, block_M, block_N, block_K, num_stages, thread_num, enable_rasteration, dtype=T.float16, accum_dtype=T.float32):
N,
K,
block_M,
block_N,
block_K,
num_stages,
thread_num,
enable_rasteration,
dtype="float16",
accum_dtype="float"):
@T.prim_func @T.prim_func
def gemm_autotune( def gemm_autotune(
A: T.Tensor((M, K), dtype), A: T.Tensor((M, K), dtype),
B: T.Tensor((N, K), dtype), B: T.Tensor((N, K), dtype),
C: T.Tensor((M, N), dtype), C: T.Tensor((M, N), dtype),
): ):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K), dtype) A_shared = T.alloc_shared((block_M, block_K), dtype)
...@@ -236,11 +207,7 @@ def matmul(M, ...@@ -236,11 +207,7 @@ def matmul(M,
return gemm_autotune return gemm_autotune
def main(M: int = 4096, def main(M: int = 4096, N: int = 4096, K: int = 4096, use_autotune: bool = False, with_roller: bool = False):
N: int = 4096,
K: int = 4096,
use_autotune: bool = False,
with_roller: bool = False):
use_autotune = True use_autotune = True
if use_autotune: if use_autotune:
result = get_best_config(M, N, K, with_roller) result = get_best_config(M, N, K, with_roller)
...@@ -266,15 +233,7 @@ if __name__ == "__main__": ...@@ -266,15 +233,7 @@ if __name__ == "__main__":
parser.add_argument("--m", type=int, default=4096, help="Matrix dimension M") parser.add_argument("--m", type=int, default=4096, help="Matrix dimension M")
parser.add_argument("--n", type=int, default=4096, help="Matrix dimension N") parser.add_argument("--n", type=int, default=4096, help="Matrix dimension N")
parser.add_argument("--k", type=int, default=4096, help="Matrix dimension K") parser.add_argument("--k", type=int, default=4096, help="Matrix dimension K")
parser.add_argument( parser.add_argument("--use_autotune", action="store_true", default=False, help="Whether to use autotune for matmul configs")
"--use_autotune", parser.add_argument("--with_roller", action="store_true", default=False, help="Whether to enable BitBLAS roller for search space")
action="store_true",
default=False,
help="Whether to use autotune for matmul configs")
parser.add_argument(
"--with_roller",
action="store_true",
default=False,
help="Whether to enable BitBLAS roller for search space")
args = parser.parse_args() args = parser.parse_args()
main(args.m, args.n, args.k, args.use_autotune, args.with_roller) main(args.m, args.n, args.k, args.use_autotune, args.with_roller)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment