import tilelang.testing
import example_gqa_decode
import example_mha_inference
# TODO(lei): fix the correctness of gqa decode on sm90
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_le(8, 9)
def test_example_example_gqa_decode():
example_gqa_decode.main()
def test_example_example_mha_inference():
example_mha_inference.main()
if __name__ == "__main__":
tilelang.testing.main()
import math
import torch
import torch.nn as nn
from typing import Dict, Tuple, Optional
import tilelang
import tilelang.language as T
from tilelang.autotuner import *
from example_fusedmoe_torch import *
@tilelang.jit(pass_configs={"tl.disable_tma_lower": True, "tl.disable_warp_specialized": True})
def moe_forward_tilelang_shared(d_hidden,
d_expert,
n_shared_experts,
dtype,
num_tokens,
block_token=128,
block_dhidden=128,
block_dexpert=128,
threads=256,
num_stages=1):
scale = 1.44269504 # log2(e)
# Parameters
dhidden = d_hidden
dexpert = d_expert * n_shared_experts
    # Tensors: note that the input is reshaped to (num_tokens, dhidden)
input_shape = (num_tokens, dhidden)
shared_W_gate_shape = (dexpert, dhidden)
shared_W_up_shape = (dexpert, dhidden)
shared_W_down_shape = (dhidden, dexpert)
accum_type = "float32"
@T.prim_func
def kernel_shared(
input: T.Tensor(input_shape, dtype), # type: ignore
shared_W_gate: T.Tensor(shared_W_gate_shape, dtype), # type: ignore
shared_W_up: T.Tensor(shared_W_up_shape, dtype), # type: ignore
shared_W_down: T.Tensor(shared_W_down_shape, dtype), # type: ignore
up_logits: T.Tensor((num_tokens, dexpert), dtype), # type: ignore
output: T.Tensor(input_shape, dtype), # type: ignore
):
# Step 1: Compute gate and up logits
with T.Kernel(
T.ceildiv(num_tokens, block_token), T.ceildiv(dexpert, block_dexpert),
threads=threads) as (bx, by):
# Split the block to shared experts and routed experts
input_shared = T.alloc_fragment((block_token, block_dhidden), dtype=dtype)
W_gate_shared = T.alloc_shared((block_dexpert, block_dhidden), dtype=dtype)
W_up_shared = T.alloc_shared((block_dexpert, block_dhidden), dtype=dtype)
# Shared experts: no need to check expert_indices
gate_logits_local = T.alloc_fragment((block_token, block_dexpert), dtype=accum_type)
up_logits_local = T.alloc_fragment((block_token, block_dexpert), dtype=accum_type)
T.use_swizzle(10)
T.clear(gate_logits_local)
T.clear(up_logits_local)
# Parallel for gate and up matmul
for k in T.Pipelined(T.ceildiv(dhidden, block_dhidden), num_stages=num_stages):
T.copy(input[bx * block_token, k * block_dhidden], input_shared)
T.copy(shared_W_gate[by * block_dexpert, k * block_dhidden], W_gate_shared)
T.copy(shared_W_up[by * block_dexpert, k * block_dhidden], W_up_shared)
T.gemm(input_shared, W_gate_shared, gate_logits_local, transpose_B=True)
T.gemm(input_shared, W_up_shared, up_logits_local, transpose_B=True)
# Fuse with SiLU and element-wise product
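            # SiLU(x) = x * sigmoid(x); sigmoid is evaluated as 1 / (1 + exp2(-x * log2(e))) so the hardware exp2 can be used.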
for i, j in T.Parallel(block_token, block_dexpert):
gate_logits_local[i, j] = gate_logits_local[i, j] * (
1.0 / (1.0 + T.exp2(-gate_logits_local[i, j] * scale)))
up_logits_local[i, j] = up_logits_local[i, j] * gate_logits_local[i, j]
T.copy(up_logits_local, up_logits[bx * block_token, by * block_dexpert])
# Step 2: Compute down logits
with T.Kernel(
T.ceildiv(num_tokens, block_token), T.ceildiv(dhidden, block_dhidden),
threads=threads) as (bx, by):
up_logits_shared = T.alloc_fragment((block_token, block_dexpert), dtype=dtype)
W_down_shared = T.alloc_shared((block_dhidden, block_dexpert), dtype=dtype)
output_local = T.alloc_fragment((block_token, block_dhidden), dtype=accum_type)
T.use_swizzle(10)
T.clear(output_local)
for k in T.Pipelined(T.ceildiv(dexpert, block_dexpert), num_stages=num_stages):
T.copy(up_logits[bx * block_token, k * block_dexpert], up_logits_shared)
T.copy(shared_W_down[by * block_dhidden, k * block_dexpert], W_down_shared)
T.gemm(up_logits_shared, W_down_shared, output_local, transpose_B=True)
T.copy(output_local, output[bx * block_token, by * block_dhidden])
return kernel_shared
@tilelang.jit(pass_configs={"tl.disable_tma_lower": True, "tl.disable_warp_specialized": True})
def moe_forward_tilelang_routed(d_hidden,
d_expert,
n_routed_experts,
dtype,
group_sum,
group_count,
block_token=128,
block_dhidden=128,
block_dexpert=128,
threads=256,
num_stages=1,
k_pack=1,
coalesced_width=None):
scale = 1.44269504 # log2(e)
# Parameters
dhidden = d_hidden
dexpert = d_expert
n_routed_experts = n_routed_experts
# Group info
# group_sum = sum(group_sizes_list)
# group_count = len(group_sizes_list)
# M = sum([(group_size + block_token - 1) // block_token for group_size in group_sizes_list])
M = math.ceil(group_sum / block_token) + group_count
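    # Static upper bound on the number of token tiles: padding each expert group up to a multiple of
    # block_token adds at most one extra tile per group, so the launch grid does not depend on the runtime routing.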
accum_dtype = "float32"
    # Tensors: note that the input is reshaped to (bs * seq_len * n_experts_per_token, dhidden) for the grouped GEMM
input_shape = (group_sum, dhidden)
intermediate_shape = (group_sum, dexpert)
routed_expert_gate_shape = (n_routed_experts, dexpert, dhidden)
routed_expert_up_shape = (n_routed_experts, dexpert, dhidden)
routed_expert_down_shape = (n_routed_experts, dhidden, dexpert)
routed_expert_weights_shape = (group_sum)
group_sizes_shape = (n_routed_experts)
@T.prim_func
def kernel(
input: T.Tensor(input_shape, dtype), # type: ignore
routed_expert_gate: T.Tensor(routed_expert_gate_shape, dtype), # type: ignore
routed_expert_up: T.Tensor(routed_expert_up_shape, dtype), # type: ignore
routed_expert_down: T.Tensor(routed_expert_down_shape, dtype), # type: ignore
routed_expert_weights: T.Tensor(routed_expert_weights_shape, dtype), # type: ignore
group_sizes: T.Tensor(group_sizes_shape, "int32"), # type: ignore
group_offsets: T.Tensor(group_sizes_shape, "int32"), # type: ignore
group_padded_offsets: T.Tensor(group_sizes_shape, "int32"), # type: ignore
group_idx_for_bx: T.Tensor((M,), "int32"), # type: ignore
up_logits: T.Tensor(intermediate_shape, dtype), # type: ignore
output: T.Tensor(input_shape, dtype), # type: ignore
):
# Step 1: Compute gate and up logits
with T.Kernel(M, T.ceildiv(dexpert, block_dexpert), threads=threads) as (bx, by):
input_shared = T.alloc_fragment((block_token, block_dhidden), dtype=dtype)
routed_expert_gate_shared = T.alloc_shared((block_dexpert, block_dhidden), dtype=dtype)
routed_expert_up_shared = T.alloc_shared((block_dexpert, block_dhidden), dtype=dtype)
gate_logits_local = T.alloc_fragment((block_token, block_dexpert), dtype=accum_dtype)
up_logits_local = T.alloc_fragment((block_token, block_dexpert), dtype=accum_dtype)
cur_group_idx = T.alloc_local([1], "int32")
cur_group_size = T.alloc_local([1], "int32")
T.use_swizzle(10, enable=True)
m_start_padded = bx * block_token
cur_group_idx[0] = group_idx_for_bx[bx]
cur_group_size[0] = group_sizes[cur_group_idx[0]]
m_start = m_start_padded - group_padded_offsets[cur_group_idx[0]] + group_offsets[
cur_group_idx[0]]
actual_rows = T.max(
0,
T.min(block_token, cur_group_size[0] -
(m_start_padded - group_padded_offsets[cur_group_idx[0]])))
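            # m_start maps this padded tile back to its row range in the packed token layout; actual_rows
            # clips the tile to the group's true size, since the last tile of a group may be only partially filled.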
T.clear(gate_logits_local)
T.clear(up_logits_local)
for k in T.Pipelined(T.ceildiv(dhidden, block_dhidden), num_stages=num_stages):
T.copy(
input[m_start:m_start + block_token, k * block_dhidden:(k + 1) * block_dhidden],
input_shared,
coalesced_width=coalesced_width)
T.copy(
routed_expert_gate[cur_group_idx[0],
by * block_dexpert:(by + 1) * block_dexpert,
k * block_dhidden:(k + 1) * block_dhidden],
routed_expert_gate_shared,
coalesced_width=coalesced_width)
T.gemm(
input_shared,
routed_expert_gate_shared,
gate_logits_local,
k_pack=k_pack,
transpose_B=True)
T.copy(
routed_expert_up[cur_group_idx[0], by * block_dexpert:(by + 1) * block_dexpert,
k * block_dhidden:(k + 1) * block_dhidden],
routed_expert_up_shared,
coalesced_width=coalesced_width)
T.gemm(
input_shared,
routed_expert_up_shared,
up_logits_local,
k_pack=k_pack,
transpose_B=True)
for i, j in T.Parallel(block_token, block_dexpert):
gate_logits_local[i, j] = gate_logits_local[i, j] * (
1.0 / (1.0 + T.exp2(-gate_logits_local[i, j] * scale)))
up_logits_local[i, j] = up_logits_local[i, j] * gate_logits_local[i, j]
for i, j in T.Parallel(block_token, block_dexpert):
if i < actual_rows:
up_logits[m_start + i, by * block_dexpert + j] = up_logits_local[i, j]
# Step 2: Compute down logits
with T.Kernel(M, T.ceildiv(dhidden, block_dhidden), threads=threads) as (bx, by):
up_logits_shared = T.alloc_fragment((block_token, block_dexpert), dtype=dtype)
routed_expert_down_shared = T.alloc_shared((block_dhidden, block_dexpert), dtype=dtype)
output_local = T.alloc_fragment((block_token, block_dhidden), dtype=accum_dtype)
cur_group_idx = T.alloc_local([1], "int32")
cur_group_size = T.alloc_local([1], "int32")
T.use_swizzle(10, enable=True)
m_start_padded = bx * block_token
cur_group_idx[0] = group_idx_for_bx[bx]
cur_group_size[0] = group_sizes[cur_group_idx[0]]
m_start = m_start_padded - group_padded_offsets[cur_group_idx[0]] + group_offsets[
cur_group_idx[0]]
actual_rows = T.max(
0,
T.min(block_token, cur_group_size[0] -
(m_start_padded - group_padded_offsets[cur_group_idx[0]])))
T.clear(output_local)
for k in T.Pipelined(T.ceildiv(dexpert, block_dexpert), num_stages=num_stages):
T.copy(
up_logits[m_start:m_start + block_token,
k * block_dexpert:(k + 1) * block_dexpert],
up_logits_shared,
coalesced_width=coalesced_width)
T.copy(
routed_expert_down[cur_group_idx[0],
by * block_dhidden:(by + 1) * block_dhidden,
k * block_dexpert:(k + 1) * block_dexpert],
routed_expert_down_shared,
coalesced_width=coalesced_width)
T.gemm(
up_logits_shared,
routed_expert_down_shared,
output_local,
k_pack=k_pack,
transpose_B=True)
for i, j in T.Parallel(block_token, block_dhidden):
if i < actual_rows:
output[m_start + i, by * block_dhidden +
j] = output_local[i, j] * routed_expert_weights[m_start + i]
return kernel
class Expert(nn.Module):
def __init__(self,
config: Dict,
gate: torch.Tensor,
up: torch.Tensor,
down: torch.Tensor,
d_expert: Optional[int] = None):
super().__init__()
self.config = config
self.act_fn = nn.SiLU()
self.d_hidden: int = config["d_hidden"]
self.d_expert: int = config["d_expert"] if d_expert is None else d_expert
self.device = torch.device("cuda")
self.W_gate_weight = gate.t().contiguous().to(self.device)
self.W_up_weight = up.t().contiguous().to(self.device)
self.W_down_weight = down.t().contiguous().to(self.device)
def forward(self, x: torch.Tensor) -> torch.Tensor:
gate = self.act_fn(x @ self.W_gate_weight)
out = (gate * (x @ self.W_up_weight)) @ self.W_down_weight
return out
class MoEGate(nn.Module):
def __init__(self, config: Dict, weights: Dict):
super().__init__()
self.top_k: int = config["n_experts_per_token"]
self.num_experts: int = config["n_routed_experts"]
self.d_hidden: int = config["d_hidden"]
self.W_g_weight = weights['router.weight'].t()
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
logits = x @ self.W_g_weight
scores = logits.softmax(dim=-1)
topk_scores, topk_indices = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)
return topk_indices, topk_scores
class MoE(nn.Module):
def __init__(self,
config: Dict,
shared_kernel: tilelang.JITKernel,
routed_kernel: tilelang.JITKernel,
weights: Dict,
padding_M: int = 128):
super().__init__()
self.config = config
self.shared_kernel = shared_kernel
self.routed_kernel = routed_kernel
self.padding_M = padding_M
self.experts = nn.ModuleList([
Expert(
config,
gate=weights[f'experts.{i}.0.weight'],
up=weights[f'experts.{i}.1.weight'],
down=weights[f'experts.{i}.2.weight']) for i in range(config["n_routed_experts"])
])
self.device = torch.device("cuda")
self.gating_network = MoEGate(config, weights).to(self.device)
shared_expert_dim = config["d_expert"] * config["n_shared_experts"]
self.shared_expert = Expert(
config=config,
gate=weights['shared_experts.0.weight'],
up=weights['shared_experts.1.weight'],
down=weights['shared_experts.2.weight'],
d_expert=shared_expert_dim).to(self.device)
self.expert_cache = torch.zeros(
(config["batch_size"] * config["seq_len"], config["d_hidden"]),
dtype=torch.float16,
device=self.device)
self.stacked_expert_w_gate = torch.stack([expert.W_gate_weight for expert in self.experts],
dim=0)
self.stacked_expert_w_up = torch.stack([expert.W_up_weight for expert in self.experts],
dim=0)
self.stacked_expert_w_down = torch.stack([expert.W_down_weight for expert in self.experts],
dim=0)
self.stacked_expert_tokens = torch.empty(
(config["batch_size"] * config["seq_len"] * config["n_experts_per_token"],
self.config["d_hidden"]),
dtype=torch.float16,
device=self.device)
self.stacked_expert_weights = torch.empty(
(config["batch_size"] * config["seq_len"] * config["n_experts_per_token"]),
dtype=torch.float16,
device=self.device)
self.stacked_expert_tokens_idxs = torch.empty(
(config["batch_size"] * config["seq_len"] * config["n_experts_per_token"]),
dtype=torch.int64,
device=self.device)
self.up_logits_shared = torch.empty(
(config["batch_size"] * config["seq_len"], self.config["d_expert"]),
dtype=torch.float16,
device=self.device)
self.expert_output_shared = torch.empty(
(config["batch_size"] * config["seq_len"], self.config["d_hidden"]),
dtype=torch.float16,
device=self.device)
self.up_logits_routed = torch.empty(
(config["batch_size"] * config["seq_len"] * config["n_experts_per_token"],
self.config["d_expert"]),
dtype=torch.float16,
device=self.device)
self.expert_output_routed = torch.empty(
(config["batch_size"] * config["seq_len"] * config["n_experts_per_token"],
self.config["d_hidden"]),
dtype=torch.float16,
device=self.device)
@torch.no_grad()
def forward(self, x: torch.Tensor) -> torch.Tensor:
orig_shape = x.shape
batch_size, seq_len, hidden_dim = x.shape
expert_indices, expert_scores = self.gating_network(x)
flat_expert_indices = expert_indices.view(-1)
flat_expert_weights = expert_scores.view(-1)
x_flat = x.view(-1, hidden_dim)
# Prepare for grouped GEMM
idxs = flat_expert_indices.argsort()
counts = flat_expert_indices.bincount().cpu().numpy()
# counts = flat_expert_indices.bincount()
tokens_per_expert = counts.cumsum()
# tokens_per_expert = torch.cumsum(counts, dim=0)
num_per_tok = self.config["n_experts_per_token"]
token_idxs = idxs // num_per_tok
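        # idxs sorts the flattened (token, expert-slot) assignments by expert id; dividing by
        # n_experts_per_token recovers each slot's source token, so consecutive ranges of token_idxs
        # hold all tokens routed to a given expert.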
# Get stacked expert tokens and expert weights
for expert_id, end_idx in enumerate(tokens_per_expert):
start_idx = 0 if expert_id == 0 else tokens_per_expert[expert_id - 1]
if start_idx == end_idx:
continue
exp_token_idxs = token_idxs[start_idx:end_idx]
expert_tokens = x_flat[exp_token_idxs]
self.stacked_expert_tokens[start_idx:end_idx] = expert_tokens
self.stacked_expert_tokens_idxs[start_idx:end_idx] = exp_token_idxs
self.stacked_expert_weights[start_idx:end_idx] = flat_expert_weights[
idxs[start_idx:end_idx]]
group_sizes = torch.tensor(counts, dtype=torch.int32, device=self.device)
group_offset = torch.tensor(
tokens_per_expert - counts, dtype=torch.int32, device=self.device)
group_padded_offsets = [0 for _ in range(len(group_sizes))]
for i in range(1, len(group_sizes)):
group_padded_offsets[i] = group_padded_offsets[i - 1] + math.ceil(
(counts[i - 1] + 1) / self.padding_M) * self.padding_M
block_token = 128
M = math.ceil(
self.config["batch_size"] * self.config["seq_len"] *
self.config["n_experts_per_token"] / block_token) + self.config["n_routed_experts"]
group_idx_for_bx = [0 for _ in range(M)]
for bx in range(M):
m_start_padded = bx * block_token
for i in range(self.config["n_routed_experts"]):
if m_start_padded >= group_padded_offsets[i]:
group_idx_for_bx[bx] = i
group_padded_offsets = torch.tensor(
group_padded_offsets, dtype=torch.int32, device=self.device)
group_idx_for_bx = torch.tensor(group_idx_for_bx, dtype=torch.int32, device=self.device)
# Multi-stream execution
shared_stream = torch.cuda.Stream()
routed_stream = torch.cuda.default_stream()
torch.cuda.synchronize()
with torch.cuda.stream(routed_stream):
# Tilelang version: Grouped GEMM
self.routed_kernel(self.stacked_expert_tokens, self.stacked_expert_w_gate,
self.stacked_expert_w_up, self.stacked_expert_w_down,
self.stacked_expert_weights, group_sizes, group_offset,
group_padded_offsets, group_idx_for_bx, self.up_logits_routed,
self.expert_output_routed)
# Scatter reduce
self.expert_cache = torch.scatter_reduce(
self.expert_cache,
0,
self.stacked_expert_tokens_idxs.view(-1, 1).repeat(1, x_flat.shape[-1]),
self.expert_output_routed,
reduce='sum')
routed_output = self.expert_cache.view(*orig_shape)
with torch.cuda.stream(shared_stream):
self.shared_kernel(x_flat, self.shared_expert.W_gate_weight,
self.shared_expert.W_up_weight, self.shared_expert.W_down_weight,
self.up_logits_shared, self.expert_output_shared)
shared_output = self.expert_output_shared.view(*orig_shape)
torch.cuda.synchronize()
return shared_output + routed_output
def custom_kernel(data: Tuple[torch.Tensor, Dict, Dict]) -> torch.Tensor:
"""
DeepSeek-style Mixture of Experts using Tilelang.
Args:
data: Tuple of (input: torch.Tensor, weights: Dict[str, torch.Tensor], config: Dict)
- input: Input tensor of shape [batch_size, seq_len, hidden_size]
- weights: Dictionary containing model weights
- config: Dictionary containing model configuration parameters
    Returns:
        output: Processed tensor of shape [batch_size, seq_len, d_hidden]
"""
input_tensor, weights, config = data
dtype_str = "float16"
shared_kernel = moe_forward_tilelang_shared(
config["d_hidden"],
config["d_expert"],
config["n_shared_experts"],
dtype=dtype_str,
num_tokens=config["batch_size"] * config["seq_len"])
routed_kernel = moe_forward_tilelang_routed(
config["d_hidden"],
config["d_expert"],
config["n_routed_experts"],
dtype=dtype_str,
group_sum=config["batch_size"] * config["seq_len"] * config["n_experts_per_token"],
group_count=config["n_routed_experts"],
block_token=128,
block_dhidden=128,
block_dexpert=128,
threads=256,
num_stages=1,
k_pack=1,
coalesced_width=2)
moe = MoE(config, shared_kernel, routed_kernel, weights, padding_M=128)
output = moe(input_tensor)
return output
def main(d_hidden=7168,
d_expert=2048,
n_routed_experts=8,
n_shared_experts=1,
n_experts_per_token=4,
batch_size=1,
seq_len=8192):
config = {
"dhidden": d_hidden,
"dexpert": d_expert,
"nroutedexperts": n_routed_experts,
"nsharedexperts": n_shared_experts,
"nexpertspertoken": n_experts_per_token,
"bs": batch_size,
"seqlen": seq_len,
"seed": 81394
}
data = generate_input(**config)
torch.cuda.synchronize()
ref_output = ref_kernel(clone_data(data)).to(torch.float32)
torch.cuda.synchronize()
tilelang_output = custom_kernel(clone_data(data)).to(torch.float32)
torch.cuda.synchronize()
torch.testing.assert_close(ref_output, tilelang_output, atol=1e-2, rtol=1e-2)
print("✅ Tilelang and Torch match")
if __name__ == "__main__":
main()
import math
import torch
import torch.nn as nn
from typing import Dict, Tuple, Optional
# Reference code in PyTorch
class ExpertTorch(nn.Module):
def __init__(self, config: Dict, d_expert: Optional[int] = None):
super().__init__()
self.config = config
self.act_fn = nn.SiLU()
self.d_hidden: int = config["d_hidden"]
self.d_expert: int = config["d_expert"] if d_expert is None else d_expert
self.W_gate = nn.Linear(self.d_hidden, self.d_expert, bias=False)
self.W_up = nn.Linear(self.d_hidden, self.d_expert, bias=False)
self.W_down = nn.Linear(self.d_expert, self.d_hidden, bias=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
gate = self.act_fn(self.W_gate(x))
out = self.W_down(gate * self.W_up(x))
return out
class MoEGateTorch(nn.Module):
def __init__(self, config: Dict):
super().__init__()
self.top_k: int = config["n_experts_per_token"]
self.num_experts: int = config["n_routed_experts"]
self.d_hidden: int = config["d_hidden"]
self.W_g = nn.Linear(self.d_hidden, self.num_experts, bias=False)
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
logits = self.W_g(x)
scores = logits.softmax(dim=-1)
topk_scores, topk_indices = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)
return topk_indices, topk_scores
class MoETorch(nn.Module):
def __init__(self, config: Dict):
super().__init__()
self.config = config
self.experts = nn.ModuleList(
[ExpertTorch(config) for _ in range(config["n_routed_experts"])])
self.gating_network = MoEGateTorch(config)
shared_expert_dim = config["d_expert"] * config["n_shared_experts"]
self.shared_expert = ExpertTorch(config=config, d_expert=shared_expert_dim)
def forward(self, x: torch.Tensor) -> torch.Tensor:
shared_output = self.shared_expert(x)
expert_indices, expert_scores = self.gating_network(x)
batch_size, seq_len, hidden_dim = x.shape
orig_shape = x.shape
x_flat = x.view(-1, hidden_dim)
flat_expert_indices = expert_indices.view(-1)
flat_expert_weights = expert_scores.view(-1, 1)
routed_output_flat = self.moe_infer(x_flat, flat_expert_indices, flat_expert_weights)
routed_output = routed_output_flat.view(*orig_shape)
return routed_output + shared_output
@torch.no_grad()
def moe_infer(self, x: torch.Tensor, flat_expert_indices: torch.Tensor,
flat_expert_weights: torch.Tensor) -> torch.Tensor:
expert_cache = torch.zeros_like(x)
# test_expert_cache = torch.zeros((x.shape[0] * self.config["n_experts_per_token"], self.config["d_hidden"]))
# test_expert_tokens = torch.zeros((x.shape[0] * self.config["n_experts_per_token"], self.config["d_hidden"]))
# test_expert_ups = torch.zeros((self.config["n_routed_experts"], self.config["d_hidden"], self.config["d_expert"]))
# test_expert_tokens_num = torch.zeros((self.config["n_routed_experts"]))
idxs = flat_expert_indices.argsort()
counts = flat_expert_indices.bincount().cpu().numpy()
tokens_per_expert = counts.cumsum()
num_per_tok = self.config["n_experts_per_token"]
token_idxs = idxs // num_per_tok
for expert_id, end_idx in enumerate(tokens_per_expert):
start_idx = 0 if expert_id == 0 else tokens_per_expert[expert_id - 1]
if start_idx == end_idx:
continue
expert = self.experts[expert_id]
exp_token_idxs = token_idxs[start_idx:end_idx]
expert_tokens = x[exp_token_idxs]
expert_out = expert(expert_tokens)
expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])
expert_cache.scatter_reduce_(
0, exp_token_idxs.view(-1, 1).repeat(1, x.shape[-1]), expert_out, reduce='sum')
return expert_cache
def ref_kernel(data: Tuple[torch.Tensor, Dict, Dict]) -> torch.Tensor:
"""
Reference implementation of DeepSeek-style Mixture of Experts using PyTorch.
Args:
data: Tuple of (input: torch.Tensor, weights: Dict[str, torch.Tensor], config: Dict)
- input: Input tensor of shape [batch_size, seq_len, hidden_dim]
- weights: Dictionary containing model weights
- config: Dictionary containing model configuration parameters
    Returns:
        output: Processed tensor of shape [batch_size, seq_len, d_hidden]
"""
input_tensor, weights, config = data
num_experts = config["n_routed_experts"]
moe = MoETorch(config)
# Fill in the given weights of the model
moe.gating_network.W_g.weight = nn.Parameter(weights['router.weight'])
for i in range(num_experts):
gate_proj_weight = weights[f'experts.{i}.0.weight']
up_proj_weight = weights[f'experts.{i}.1.weight']
down_proj_weight = weights[f'experts.{i}.2.weight']
# Transpose weights to match expected shape for nn.Linear
moe.experts[i].W_gate.weight = nn.Parameter(gate_proj_weight.t())
moe.experts[i].W_up.weight = nn.Parameter(up_proj_weight.t())
moe.experts[i].W_down.weight = nn.Parameter(down_proj_weight.t())
moe.shared_expert.W_gate.weight = nn.Parameter(weights['shared_experts.0.weight'].t())
moe.shared_expert.W_up.weight = nn.Parameter(weights['shared_experts.1.weight'].t())
moe.shared_expert.W_down.weight = nn.Parameter(weights['shared_experts.2.weight'].t())
output = moe(input_tensor)
return output
# Input generation for the reference code
def generate_input(dhidden: int, dexpert: int, nroutedexperts: int, nsharedexperts: int,
nexpertspertoken: int, bs: int, seqlen: int,
seed: int) -> Tuple[torch.Tensor, Dict, Dict]:
    # Workaround: parameter names with underscores aren't parsed correctly yet, so map the short argument names back to the full config keys.
d_hidden = dhidden
d_expert = dexpert
n_routed_experts = nroutedexperts
n_shared_experts = nsharedexperts
n_experts_per_token = nexpertspertoken
batch_size = bs
seq_len = seqlen
config = {
"d_hidden": d_hidden,
"d_expert": d_expert,
"n_routed_experts": n_routed_experts,
"n_shared_experts": n_shared_experts,
"n_experts_per_token": n_experts_per_token,
"batch_size": batch_size,
"seq_len": seq_len,
}
gen = torch.Generator(device='cuda')
gen.manual_seed(seed)
num_experts = n_routed_experts
expert_dim = d_expert
weights = {}
input_tensor = torch.randn((batch_size, seq_len, d_hidden),
device='cuda',
dtype=torch.float16,
generator=gen).contiguous()
# Initialize router weights
weights['router.weight'] = torch.randn(
(num_experts, d_hidden), device="cuda", dtype=torch.float16,
generator=gen) / math.sqrt(d_hidden)
for i in range(num_experts):
weights[f'experts.{i}.0.weight'] = torch.randn(
(d_hidden, expert_dim), device='cuda', dtype=torch.float16,
generator=gen) / math.sqrt(expert_dim)
weights[f'experts.{i}.1.weight'] = torch.randn(
(d_hidden, expert_dim), device='cuda', dtype=torch.float16,
generator=gen) / math.sqrt(expert_dim)
weights[f'experts.{i}.2.weight'] = torch.randn(
(expert_dim, d_hidden), device='cuda', dtype=torch.float16,
generator=gen) / math.sqrt(d_hidden)
weights['shared_experts.0.weight'] = torch.randn(
(d_hidden, expert_dim * n_shared_experts),
device='cuda',
dtype=torch.float16,
generator=gen) / math.sqrt(expert_dim * n_shared_experts)
weights['shared_experts.1.weight'] = torch.randn(
(d_hidden, expert_dim * n_shared_experts),
device='cuda',
dtype=torch.float16,
generator=gen) / math.sqrt(expert_dim * n_shared_experts)
weights['shared_experts.2.weight'] = torch.randn((expert_dim * n_shared_experts, d_hidden),
device='cuda',
dtype=torch.float16,
generator=gen) / math.sqrt(d_hidden)
return (input_tensor, weights, config)
def clone_data(data):
"""
Recursively goes through data and clones all tensors.
"""
if isinstance(data, tuple):
return tuple(clone_data(x) for x in data)
elif isinstance(data, list):
return [clone_data(x) for x in data]
elif isinstance(data, dict):
return {k: clone_data(v) for k, v in data.items()}
elif isinstance(data, torch.Tensor):
return data.clone()
else:
return data
import tilelang.testing
import example_fusedmoe_tilelang
def test_example_fusedmoe_tilelang():
example_fusedmoe_tilelang.main(
d_hidden=1024,
d_expert=256,
n_routed_experts=8,
n_shared_experts=1,
n_experts_per_token=4,
batch_size=1,
seq_len=1024)
if __name__ == "__main__":
tilelang.testing.main()
# Gated Delta Net (GDN) kernel implementation with TileLang
## Requirements
- TileLang: `0.1.5+17fafc1b3026d910a83eb8052fdf811ba56be0b1`
- Triton: `3.3.0` (used for comparison)
- FLA: commit `f03cb3ae` (used for comparison)
## Get started
The [chunk_delta_h](common/chunk_delta_h.py) module implements the most critical forward kernel of GDN. It is a good starting point for understanding the GDN logic and the TileLang optimizations.
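
For intuition, below is a minimal PyTorch sketch (names and the single-head layout are illustrative, not part of the kernels) of the per-chunk state recurrence that the forward kernel computes. It assumes the gates `G` are already in log-space with a chunk-local cumulative sum applied, as in the example scripts.

```python
import torch

def gdn_fwd_h_sketch(K, W, U, G, h0, chunk_size=64):
    # K, W: [S, DK]; U: [S, DV]; G: [S] (log-space, chunk-local cumsum); h0: [DK, DV]
    h = h0.clone()
    chunk_states = []
    for s in range(0, K.shape[0], chunk_size):
        k, w, u, g = (x[s:s + chunk_size] for x in (K, W, U, G))
        chunk_states.append(h.clone())                 # state entering this chunk
        v_new = u - w @ h                              # delta-rule correction: U - W * S
        v_new = v_new * torch.exp(g[-1] - g)[:, None]  # decay each row to the chunk boundary
        h = h * torch.exp(g[-1]) + k.t() @ v_new       # decay carried state, then accumulate
    return torch.stack(chunk_states), h                # per-chunk states and final state
```

The optimized kernel tiles this recurrence over the value dimension, fuses the stores of the per-chunk states and `V_new`, and zeroes entries whose gate difference is positive as a numerical safeguard.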
## Acknowledgments
This kernel was developed by Yu Cheng and Zhengju Tang following in-depth discussions with Xiaomi's LLM-Core Team (MiMo).
# Reference: fla/ops/common/chunk_delta_h.py
import sys # noqa: F401
import tilelang
import tilelang.language as T
print(tilelang.__file__, flush=True)
# Add your fla repository path to sys.path
# Currently we use the fla repository from the flash-linear-attention project at commit id f03cb3ae
# sys.path.insert(0, "/home/tzj/flash-linear-attention")
try:
import fla
print(fla.__file__, flush=True)
from fla.ops.common.chunk_delta_h import chunk_gated_delta_rule_bwd_dhu
except ImportError:
print("fla not found, using tilelang implementation")
fla = None
import torch
import torch.nn.functional as F
torch.random.manual_seed(0)
# torch.set_printoptions(profile="full")
from utils import *
def prepare_input(
B,
S,
H,
DK,
DV,
chunk_size,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
state_dtype,
):
Q = torch.randn(B, S, H, DK, dtype=input_dtype).cuda()
K = torch.randn(B, S, H, DK, dtype=input_dtype).cuda()
K = F.normalize(K, dim=-1, p=2)
W = torch.randn(B, S, H, DK, dtype=input_dtype).cuda()
    # Note: G should be in log-space, with a chunk-wise cumulative sum applied
G = torch.randn(B, S, H, dtype=gate_dtype).cuda()
G = F.logsigmoid(G)
try:
from fla.ops.utils.cumsum import chunk_local_cumsum
G = chunk_local_cumsum(G, chunk_size)
except ImportError:
print("fla not found, skip cumsum")
h0 = torch.randn(B, H, DK, DV, dtype=input_dtype).cuda()
dht = torch.randn(B, H, DK, DV, dtype=input_dtype).cuda()
dO = torch.randn(B, S, H, DV, dtype=input_dtype).cuda()
dv = torch.randn(B, S, H, DV, dtype=input_dtype).cuda()
return Q, K, W, G, h0, dht, dO, dv
def prepare_input_fake(
B,
S,
H,
DK,
DV,
chunk_size,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
state_dtype,
):
Q = torch.ones(B, S, H, DK, dtype=input_dtype).cuda()
K = torch.ones(B, S, H, DK, dtype=input_dtype).cuda()
W = torch.ones(B, S, H, DK, dtype=input_dtype).cuda()
G = torch.ones(B, S, H, dtype=gate_dtype).cuda()
h0 = torch.ones(B, H, DK, DV, dtype=input_dtype).cuda()
dht = torch.ones(B, H, DK, DV, dtype=input_dtype).cuda()
dO = torch.ones(B, S, H, DV, dtype=input_dtype).cuda()
dv = torch.ones(B, S, H, DV, dtype=input_dtype).cuda()
return Q, K, W, G, h0, dht, dO, dv
def prepare_output(
B,
S,
H,
DK,
DV,
chunk_size,
output_dtype,
gate_dtype,
state_dtype,
):
BS = S // chunk_size
dh = torch.empty(B, BS, H, DK, DV, dtype=output_dtype).cuda()
dh0 = torch.empty(B, H, DK, DV, dtype=state_dtype).cuda()
dv2 = torch.empty(B, S, H, DV, dtype=output_dtype).cuda()
return dh, dh0, dv2
def torch_chunk_gated_delta_rule_bwd_dhu(
Q: torch.Tensor,
K: torch.Tensor,
W: torch.Tensor,
G: torch.Tensor,
h0: torch.Tensor,
dht: torch.Tensor,
dO: torch.Tensor,
dv: torch.Tensor,
scale: float,
use_g: bool,
use_initial_state: bool,
use_final_state_gradient: bool,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
state_dtype,
):
B, S, H, DK = Q.shape
DV = dv.shape[-1]
block_S = 64
BS = S // block_S
dh, dh0, dv2 = torch.empty((B, BS, H, DK, DV), dtype=output_dtype), torch.empty(
(B, H, DK, DV), dtype=state_dtype), torch.empty((B, S, H, DV), dtype=output_dtype)
dh_tmp = torch.empty((B, H, DK, DV), dtype=accum_dtype)
dv_tmp = torch.empty((B, S, H, DV), dtype=accum_dtype)
Q_tmp = torch.empty((B, S, H, DK), dtype=accum_dtype)
if use_final_state_gradient:
dh_tmp = dht.clone().to(accum_dtype)
else:
dh_tmp = torch.zeros_like(dht).to(accum_dtype)
for i_s in range(BS - 1, -1, -1):
dh[:, i_s, :, :, :] = dh_tmp
dv_tmp = torch.matmul(K[:, i_s * block_S:(i_s + 1) * block_S, :, :].permute(0, 2, 1, 3),
dh_tmp.to(K.dtype)).permute(0, 2, 1, 3)
if use_g:
for i_bh in range(B * H):
i_b, i_h = i_bh // H, i_bh % H
for i_s2 in range(block_S):
if G[i_b, i_s * block_S + block_S - 1, i_h] - G[i_b, i_s * block_S + i_s2,
i_h] <= 0:
dv_tmp[i_b, i_s2,
i_h, :] *= torch.exp(G[i_b, i_s * block_S + block_S - 1, i_h] -
G[i_b, i_s * block_S + i_s2, i_h])
else:
dv_tmp[i_b, i_s2, i_h, :] = 0
dv_tmp += dv[:, i_s * block_S:(i_s + 1) * block_S, :, :]
dv2[:, i_s * block_S:(i_s + 1) * block_S, :, :] = dv_tmp
if use_g:
G_last = G[:, i_s * block_S + block_S - 1, :]
for i_bh in range(B * H):
i_b, i_h = i_bh // H, i_bh % H
dh_tmp[i_b, i_h, :, :] *= torch.exp(G_last[i_b, i_h])
        # Clone so the in-place scaling below does not modify the input Q
        Q_tmp = Q[:, i_s * block_S:(i_s + 1) * block_S, :, :].clone()
for i_s2 in range(block_S):
for i_k in range(DK):
Q_tmp[:, i_s2, :, i_k] *= torch.exp(G[:, i_s * block_S + i_s2, :])
Q_tmp *= scale
W_tmp = W[:, i_s * block_S:(i_s + 1) * block_S, :, :]
dO_tmp = dO[:, i_s * block_S:(i_s + 1) * block_S, :, :]
torch.backends.cuda.matmul.allow_tf32 = True
dh_tmp += torch.matmul(Q_tmp.permute(0, 2, 3, 1), dO_tmp.permute(0, 2, 1, 3))
dh_tmp -= torch.matmul(W_tmp.permute(0, 2, 3, 1), dv_tmp.permute(0, 2, 1, 3))
torch.backends.cuda.matmul.allow_tf32 = False
if use_initial_state:
dh0 = dh_tmp[:, :, :, :]
else:
dh0 = torch.zeros_like(dh_tmp[:, :, :, :])
print(dh0.dtype)
return dh, dh0, dv2
@tilelang.jit(out_idx=[-3, -2, -1])
def tilelang_chunk_gated_delta_rule_bwd_dhu(
# task config
B,
S,
H,
DK,
DV,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
state_dtype,
chunk_size,
scale,
use_g=True,
use_initial_state=True,
use_final_state_gradient=True,
# kernel config
block_DV=64,
threads=256,
num_stages=0,
):
block_S = chunk_size
# Should support cu_seqlen
BS = S // block_S
Q_shape = (B, S, H, DK)
K_shape = (B, S, H, DK)
W_shape = (B, S, H, DK)
G_shape = (B, S, H)
h0_shape = (B, H, DK, DV)
dht_shape = (B, H, DK, DV)
dO_shape = (B, S, H, DV)
dv_shape = (B, S, H, DV)
dh_shape = (B, BS, H, DK, DV)
dh0_shape = (B, H, DK, DV)
dv2_shape = (B, S, H, DV)
@T.prim_func
def kernel(
# Input
Q: T.Tensor(Q_shape, dtype=input_dtype),
K: T.Tensor(K_shape, dtype=input_dtype),
W: T.Tensor(W_shape, dtype=input_dtype),
G: T.Tensor(G_shape, dtype=gate_dtype),
h0: T.Tensor(h0_shape, dtype=input_dtype),
dht: T.Tensor(dht_shape, dtype=input_dtype),
dO: T.Tensor(dO_shape, dtype=input_dtype),
dv: T.Tensor(dv_shape, dtype=input_dtype),
# Output
dh: T.Tensor(dh_shape, dtype=output_dtype),
dh0: T.Tensor(dh0_shape, dtype=state_dtype),
dv2: T.Tensor(dv2_shape, dtype=output_dtype),
):
with T.Kernel(T.ceildiv(DV, block_DV), B * H, threads=threads) as (bv, bbh):
bb, bh = bbh // H, bbh % H
b_dh_shared = T.alloc_shared((DK, block_DV), dtype=output_dtype)
b_dh_shared_fp32 = T.alloc_shared((DK, block_DV), dtype=state_dtype)
b_dh_fragment = T.alloc_fragment((DK, block_DV), dtype=accum_dtype)
b_dh_fragment_1 = T.alloc_fragment((DK, block_DV), dtype=accum_dtype)
b_dh_fragment_2 = T.alloc_fragment((DK, block_DV), dtype=accum_dtype)
dv_shared = T.alloc_shared((block_S, block_DV), dtype=input_dtype)
dv_fragment = T.alloc_fragment((block_S, block_DV), dtype=accum_dtype)
dv_fragment_2 = T.alloc_fragment((block_S, block_DV), dtype=accum_dtype)
dO_shared = T.alloc_shared((block_S, block_DV), dtype=input_dtype)
dO_shared_t = T.alloc_shared((block_DV, block_S), dtype="float32")
dO_fragment = T.alloc_fragment((block_S, block_DV), dtype="float32")
dO_fragment_t = T.alloc_fragment((block_DV, block_S), dtype="float32")
K_shared = T.alloc_shared((block_S, DK), dtype=input_dtype)
Q_shared = T.alloc_shared((block_S, DK), dtype=input_dtype)
Q_shared_fp32 = T.alloc_shared((block_S, DK), dtype="float32")
W_shared = T.alloc_shared((block_S, DK), dtype=input_dtype)
G_last_local = T.alloc_local((1), dtype=gate_dtype)
G_last_local_exp = T.alloc_local((1), dtype=gate_dtype)
G_shared = T.alloc_shared((block_S), dtype=gate_dtype, scope="shared")
G_fragment = T.alloc_fragment((block_S), dtype=gate_dtype)
G_fragment_post = T.alloc_fragment((block_S), dtype=gate_dtype)
G_fragment_exp = T.alloc_fragment((block_S), dtype=gate_dtype)
Q_fragment = T.alloc_fragment((block_S, DK), dtype=accum_dtype)
Q_fragment_t = T.alloc_fragment((DK, block_S), dtype=accum_dtype)
T.use_swizzle(10)
T.annotate_layout({
b_dh_shared: tilelang.layout.make_swizzled_layout(b_dh_shared),
b_dh_shared_fp32: tilelang.layout.make_swizzled_layout(b_dh_shared_fp32),
dv_shared: tilelang.layout.make_swizzled_layout(dv_shared),
dO_shared: tilelang.layout.make_swizzled_layout(dO_shared),
dO_shared_t: tilelang.layout.make_swizzled_layout(dO_shared_t),
K_shared: tilelang.layout.make_swizzled_layout(K_shared),
Q_shared: tilelang.layout.make_swizzled_layout(Q_shared),
Q_shared_fp32: tilelang.layout.make_swizzled_layout(Q_shared_fp32),
W_shared: tilelang.layout.make_swizzled_layout(W_shared),
})
if use_final_state_gradient:
T.copy(dht[bb, bh, 0:DK, bv * block_DV:(bv + 1) * block_DV], b_dh_shared)
T.copy(b_dh_shared, b_dh_fragment)
else:
T.clear(b_dh_fragment)
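            # The running state gradient starts from dht when the final-state gradient is used
            # (otherwise from zero) and is propagated backwards chunk by chunk via i_s_inv below.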
for i_s in T.Pipelined(T.ceildiv(S, block_S), num_stages=num_stages):
# The gradient should be stored in the reverse order
i_s_inv = T.ceildiv(S, block_S) - i_s - 1
# Store the updated dh
T.copy(b_dh_fragment, b_dh_shared)
T.copy(b_dh_shared, dh[bb, i_s_inv, bh, 0:DK, bv * block_DV:(bv + 1) * block_DV])
# Update dv
T.copy(K[bb, i_s_inv * block_S:(i_s_inv + 1) * block_S, bh, 0:DK], K_shared)
T.gemm(K_shared, b_dh_shared, dv_fragment, clear_accum=True)
if use_g:
T.copy(
G[bb, i_s_inv * block_S:(i_s_inv + 1) * block_S, bh],
G_shared,
disable_tma=True)
T.copy(G_shared, G_fragment)
G_last_local[0] = G_shared[block_S - 1]
G_last_local_exp[0] = T.exp(G_last_local[0])
for i_s2 in T.Parallel(block_S):
G_fragment_post[i_s2] = T.exp(G_last_local[0] - G_fragment[i_s2])
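                    # Decay dv by exp(G_last - G_i); since G is a chunk-local cumulative sum of log-gates,
                    # valid differences are <= 0, and the Else branch zeroes any numerically invalid entries.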
for i_s2, i_v in T.Parallel(block_S, block_DV):
# with T.If(G_last_local[0] - G_shared[i_s2] <= 0):
with T.If(G_last_local[0] - G_fragment[i_s2] <= 0):
with T.Then():
dv_fragment[i_s2,
i_v] = dv_fragment[i_s2, i_v] * G_fragment_post[i_s2]
with T.Else():
dv_fragment[i_s2, i_v] = 0
T.copy(
dv[bb, i_s_inv * block_S:(i_s_inv + 1) * block_S, bh,
bv * block_DV:(bv + 1) * block_DV], dv_shared)
T.copy(dv_shared, dv_fragment_2)
for i_s2, i_v in T.Parallel(block_S, block_DV):
dv_fragment[i_s2, i_v] = dv_fragment[i_s2, i_v] + dv_fragment_2[i_s2, i_v]
# Store the updated dv
T.copy(dv_fragment, dv_shared)
T.copy(
dv_shared, dv2[bb, i_s_inv * block_S:(i_s_inv + 1) * block_S, bh,
bv * block_DV:(bv + 1) * block_DV])
# Update dh
T.copy(Q[bb, i_s_inv * block_S:(i_s_inv + 1) * block_S, bh, 0:DK], Q_shared)
T.copy(W[bb, i_s_inv * block_S:(i_s_inv + 1) * block_S, bh, 0:DK], W_shared)
T.clear(Q_fragment)
if use_g:
for i_k, i_v in T.Parallel(DK, block_DV):
b_dh_fragment[i_k, i_v] *= G_last_local_exp[0]
T.copy(Q_shared, Q_fragment)
for i_s2 in T.Parallel(block_S):
G_fragment_exp[i_s2] = T.exp(G_shared[i_s2])
for i_s2, i_k in T.Parallel(block_S, DK):
# Q_fragment[i_s2, i_k] = Q_fragment[i_s2, i_k] * T.exp(G_shared[i_s2]) * scale
Q_fragment[i_s2, i_k] = Q_fragment[i_s2, i_k] * G_fragment_exp[i_s2] * scale
else:
T.copy(Q_shared, Q_fragment)
for i_s2, i_k in T.Parallel(block_S, DK):
Q_fragment[i_s2, i_k] = Q_fragment[i_s2, i_k] * scale
# Get transpose of Q_fragment to meet tf32 gemm requirement
for i_s2, i_k in T.Parallel(block_S, DK):
Q_fragment_t[i_k, i_s2] = Q_fragment[i_s2, i_k]
T.copy(
dO[bb, i_s_inv * block_S:(i_s_inv + 1) * block_S, bh,
bv * block_DV:(bv + 1) * block_DV], dO_shared)
T.copy(dO_shared, dO_fragment)
for i_s2, i_v in T.Parallel(block_S, block_DV):
dO_fragment_t[i_v, i_s2] = dO_fragment[i_s2, i_v]
T.copy(dO_fragment_t, dO_shared_t)
T.clear(b_dh_fragment_1)
T.gemm(Q_fragment_t, dO_shared_t, b_dh_fragment_1, transpose_B=True)
T.clear(b_dh_fragment_2)
T.gemm(W_shared, dv_shared, b_dh_fragment_2, transpose_A=True)
for i_k, i_v in T.Parallel(DK, block_DV):
b_dh_fragment[i_k, i_v] += b_dh_fragment_1[i_k, i_v] - b_dh_fragment_2[i_k, i_v]
if use_initial_state:
T.copy(b_dh_fragment, dh0[bb, bh, 0:DK, bv * block_DV:(bv + 1) * block_DV])
return kernel
def test_result(dh_0, dh0_0, dv2_0, dh_1, dh0_1, dv2_1, name):
try:
torch.testing.assert_close(dh_0, dh_1, rtol=1e-2, atol=1e-2, equal_nan=True)
print(f"{name} dh_0 and dh_1 passed for {name}")
except Exception as e:
print(f"{name} dh_0 and dh_1 are not close for {name}")
print(e, end="\n\n")
try:
torch.testing.assert_close(dh0_0, dh0_1, rtol=1e-2, atol=1e-2, equal_nan=True)
print(f"{name} dh0_0 and dh0_1 passed for {name}")
except Exception as e:
print(f"{name} dh0_0 and dh0_1 are not close for {name}")
print(e, end="\n\n")
try:
torch.testing.assert_close(dv2_0, dv2_1, rtol=1e-2, atol=1e-2, equal_nan=True)
print(f"{name} dv2_0 and dv2_1 passed for {name}")
except Exception as e:
print(f"{name} dv2_0 and dv2_1 are not close for {name}")
print(e, end="\n\n")
close = torch.isclose(dh_0, dh_1, rtol=1e-2, atol=1e-2)
mismatch_indices = torch.nonzero(~close, as_tuple=True)
error_num = 0
for indices in zip(*mismatch_indices):
if error_num < 100:
print(
f"{name} dh_0[{[idx.item() for idx in indices]}] = {dh_0[indices[0].item(), indices[1].item(), indices[2].item(), indices[3].item(), indices[4].item()]}, dh_1[{[idx.item() for idx in indices]}] = {dh_1[indices[0].item(), indices[1].item(), indices[2].item(), indices[3].item(), indices[4].item()]}"
)
error_num += 1
close = torch.isclose(dh0_0, dh0_1, rtol=1e-2, atol=1e-2)
mismatch_indices = torch.nonzero(~close, as_tuple=True)
error_num = 0
for indices in zip(*mismatch_indices):
if error_num < 100:
print(
f"{name} dh0_0[{[idx.item() for idx in indices]}] = {dh0_0[indices[0].item(), indices[1].item(), indices[2].item(), indices[3].item()]}, dh0_1[{[idx.item() for idx in indices]}] = {dh0_1[indices[0].item(), indices[1].item(), indices[2].item(), indices[3].item()]}"
)
error_num += 1
close = torch.isclose(dv2_0, dv2_1, rtol=1e-2, atol=1e-2)
mismatch_indices = torch.nonzero(~close, as_tuple=True)
error_num = 0
for indices in zip(*mismatch_indices):
if error_num < 100:
print(
f"{name} dv2_0[{[idx.item() for idx in indices]}] = {dv2_0[indices[0].item(), indices[1].item(), indices[2].item(), indices[3].item()]}, dv2_1[{[idx.item() for idx in indices]}] = {dv2_1[indices[0].item(), indices[1].item(), indices[2].item(), indices[3].item()]}"
)
error_num += 1
def run_test(
B,
S,
H,
DK,
DV,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
state_dtype,
chunk_size,
scale,
use_g=True,
use_initial_state=True,
use_final_state_gradient=True,
block_DV=64,
threads=256,
num_stages=0,
use_torch=False,
):
Q, K, W, G, h0, dht, dO, dv = prepare_input(B, S, H, DK, DV, chunk_size,
getattr(torch, input_dtype),
getattr(torch, output_dtype),
getattr(torch, accum_dtype),
getattr(torch, gate_dtype),
getattr(torch, state_dtype))
dh_ref, dh0_ref, dv2_ref = prepare_output(B, S, H, DK, DV, chunk_size,
getattr(torch, output_dtype),
getattr(torch, gate_dtype),
getattr(torch, state_dtype))
dh_tilelang, dh0_tilelang, dv2_tilelang = prepare_output(B, S, H, DK, DV, chunk_size,
getattr(torch, output_dtype),
getattr(torch, gate_dtype),
getattr(torch, state_dtype))
# fla ref
print("fla running...", flush=True)
if use_g:
dh_ref, dh0_ref, dv2_ref = chunk_gated_delta_rule_bwd_dhu(Q, K, W, G, h0, dht, dO, dv,
scale)
else:
G = G.fill_(0)
dh_ref, dh0_ref, dv2_ref = chunk_gated_delta_rule_bwd_dhu(Q, K, W, G, h0, dht, dO, dv,
scale)
# tilelang
print("tilelang running...", flush=True)
kernel = tilelang_chunk_gated_delta_rule_bwd_dhu(B, S, H, DK, DV, input_dtype, output_dtype,
accum_dtype, gate_dtype, state_dtype,
chunk_size, scale, use_g, use_initial_state,
use_final_state_gradient, block_DV, threads,
num_stages)
# kernel = tilelang.compile(program)
print(kernel.get_kernel_source())
dh_tilelang, dh0_tilelang, dv2_tilelang = kernel(Q, K, W, G, h0, dht, dO, dv)
fla_time = do_bench(
chunk_gated_delta_rule_bwd_dhu, Q, K, W, G, h0, dht, dO, dv, scale, chunk_size=chunk_size)
tilelang_time = do_bench(kernel, Q, K, W, G, h0, dht, dO, dv)
print(f"fla time: {fla_time} ms")
print(f"tilelang time: {tilelang_time} ms")
assert_similar(dh_tilelang, dh_ref, 1e-5, "fla-tilelang", data="dh")
assert_similar(dh0_tilelang, dh0_ref, 1e-5, "fla-tilelang", data="dh0")
assert_similar(dv2_tilelang, dv2_ref, 1e-5, "fla-tilelang", data="dv2")
# torch ref
if use_torch:
print("torch running...", flush=True)
if use_g:
dh_ref_torch, dh0_ref_torch, dv2_ref_torch = torch_chunk_gated_delta_rule_bwd_dhu(
Q, K, W, G, h0, dht, dO, dv, scale, use_g, use_initial_state,
use_final_state_gradient, getattr(torch, input_dtype), getattr(torch, output_dtype),
getattr(torch, accum_dtype), getattr(torch,
gate_dtype), getattr(torch, state_dtype))
dh_ref_torch = dh_ref_torch.cuda()
dh0_ref_torch = dh0_ref_torch.cuda()
dv2_ref_torch = dv2_ref_torch.cuda()
else:
dh_ref_torch, dh0_ref_torch, dv2_ref_torch = torch_chunk_gated_delta_rule_bwd_dhu(
Q, K, W, None, h0, dht, dO, dv, scale, use_g, use_initial_state,
use_final_state_gradient, getattr(torch, input_dtype), getattr(torch, output_dtype),
getattr(torch, accum_dtype), getattr(torch,
gate_dtype), getattr(torch, state_dtype))
dh_ref_torch = dh_ref_torch.cuda()
dh0_ref_torch = dh0_ref_torch.cuda()
dv2_ref_torch = dv2_ref_torch.cuda()
assert_similar(dh_ref_torch, dh_ref, 1e-5, "torch-fla", data="dh")
assert_similar(dh0_ref_torch, dh0_ref, 1e-5, "torch-fla", data="dh0")
assert_similar(dv2_ref_torch, dv2_ref, 1e-5, "torch-fla", data="dv2")
assert_similar(dh_ref_torch, dh_tilelang, 1e-5, "torch-tilelang", data="dh")
assert_similar(dh0_ref_torch, dh0_tilelang, 1e-5, "torch-tilelang", data="dh0")
assert_similar(dv2_ref_torch, dv2_tilelang, 1e-5, "torch-tilelang", data="dv2")
def do_bench(fn, *args, warmup=10, rep=10, **kwargs):
"""
    Benchmark `fn` with CUDA events and return the mean elapsed time in milliseconds.
"""
start_event = [torch.cuda.Event(enable_timing=True) for i in range(rep)]
end_event = [torch.cuda.Event(enable_timing=True) for i in range(rep)]
for _ in range(warmup):
fn(*args, **kwargs)
torch.cuda.synchronize()
for i in range(rep):
start_event[i].record()
fn(*args, **kwargs)
end_event[i].record()
torch.cuda.synchronize()
# Record clocks
times = torch.tensor(
[s.elapsed_time(e) for s, e in zip(start_event, end_event)],
dtype=torch.float,
)
return times.mean().item()
def main():
DK = 128
run_test(
B=1,
S=32768,
H=8,
DK=DK,
DV=128,
input_dtype="bfloat16",
output_dtype="bfloat16",
accum_dtype="float32",
gate_dtype="float32",
state_dtype="float32",
chunk_size=64,
scale=DK**-0.5,
use_g=True,
use_initial_state=True,
use_final_state_gradient=True,
block_DV=32,
threads=128,
num_stages=1,
use_torch=False,
)
if __name__ == "__main__":
main()
# Reference: fla/ops/common/chunk_delta_h.py
import sys # noqa: F401
import tilelang
import tilelang.language as T
# Add your fla repository path to sys.path
# Currently we use the fla repository from the flash-linear-attention project at commit id f03cb3ae
# sys.path.insert(0, "/home/tzj/flash-linear-attention")
try:
import fla
print(fla.__file__)
from fla.ops.common.chunk_delta_h import chunk_gated_delta_rule_fwd_h
except ImportError:
print("fla not found, using tilelang implementation")
fla = None
import torch
import torch.nn.functional as F
from tilelang.engine.callback import register_cuda_postproc_callback # noqa: F401
from utils import *
# (zhengju) We can slightly modify the generated cuda code from tilelang lowering
# in the debug folder to make the performance better. To enable this callback,
# you can uncomment the following function.
# @register_cuda_postproc_callback
# def tilelang_callback_cuda_postproc(code, _):
# cuda_code = open("../debug/chunk_delta_h_fuse.cu", "r").read()
# code = cuda_code
# return code
torch.random.manual_seed(0)
def prepare_input(
B,
S,
H,
DK,
DV,
chunk_size,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
):
K = torch.randn(B, S, H, DK, dtype=input_dtype).cuda()
K = F.normalize(K, dim=-1, p=2)
W = torch.randn(B, S, H, DK, dtype=input_dtype).cuda()
W = F.normalize(W, dim=-1, p=2)
U = torch.randn(B, S, H, DV, dtype=input_dtype).cuda()
U = F.normalize(U, dim=-1, p=2)
G = torch.randn(B, S, H, dtype=gate_dtype).cuda()
G = F.logsigmoid(G)
try:
from fla.ops.utils.cumsum import chunk_local_cumsum
G = chunk_local_cumsum(G, chunk_size)
except ImportError:
print("fla not found, skip cumsum")
initial_state = torch.randn(B, H, DK, DV, dtype=input_dtype).cuda()
return K, W, U, G, initial_state
def prepare_output(
B,
S,
H,
DK,
DV,
chunk_size,
output_dtype,
state_dtype,
):
BS = S // chunk_size
h = torch.empty(B, BS, H, DK, DV, dtype=output_dtype).cuda()
final_state = torch.empty(B, H, DK, DV, dtype=state_dtype).cuda()
V_new = torch.empty(B, S, H, DV, dtype=output_dtype).cuda()
return h, final_state, V_new
@tilelang.jit(out_idx=[-3, -2, -1])
def tilelang_chunk_gated_delta_rule_fwd_h(
# task config
B,
S,
H,
DK,
DV,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
state_dtype,
chunk_size,
use_g=True,
use_initial_state=True,
store_final_state=True,
save_new_value=True,
# kernel config
block_DK=64,
block_DV=64,
threads=256,
num_stages=0,
):
block_S = chunk_size
BS = S // block_S
K_shape = (B, S, H, DK)
V_shape = (B, S, H, DV)
W_shape = (B, S, H, DK)
U_shape = (B, S, H, DV)
G_shape = (B, S, H)
h_shape = (B, BS, H, DK, DV)
initial_state_shape = (B, H, DK, DV)
final_state_shape = (B, H, DK, DV)
@T.prim_func
def kernel(
K: T.Tensor(K_shape, dtype=input_dtype),
W: T.Tensor(W_shape, dtype=input_dtype),
U: T.Tensor(U_shape, dtype=input_dtype),
G: T.Tensor(G_shape, dtype=gate_dtype),
initial_state: T.Tensor(initial_state_shape, dtype=input_dtype),
h: T.Tensor(h_shape, dtype=output_dtype),
final_state: T.Tensor(final_state_shape, dtype=state_dtype),
V_new: T.Tensor(V_shape, dtype=output_dtype),
):
with T.Kernel(T.ceildiv(DV, block_DV), B * H, threads=threads) as (bv, bbh):
bb, bh = bbh // H, bbh % H
b_h_shared = T.alloc_shared((DK, block_DV), dtype=input_dtype)
b_h_fragment = T.alloc_fragment((DK, block_DV), dtype=accum_dtype)
U_shared = T.alloc_shared((block_S, block_DV), dtype=input_dtype)
U_fragment = T.alloc_fragment((block_S, block_DV), dtype=accum_dtype)
W_shared = T.alloc_shared((block_S, DK), dtype=input_dtype)
V_new_fragment = T.alloc_fragment((block_S, block_DV), dtype=accum_dtype)
V_new_shared = T.alloc_shared((block_S, block_DV), dtype=output_dtype)
K_shared = T.alloc_shared((block_S, DK), dtype=input_dtype)
G_last_local = T.alloc_local((1), dtype=gate_dtype)
G_shared = T.alloc_shared((block_S, block_DV), dtype=gate_dtype)
G_fragment = T.alloc_fragment((block_S, block_DV), dtype=gate_dtype)
T.annotate_layout({
b_h_shared: tilelang.layout.make_swizzled_layout(b_h_shared),
U_shared: tilelang.layout.make_swizzled_layout(U_shared),
W_shared: tilelang.layout.make_swizzled_layout(W_shared),
V_new_shared: tilelang.layout.make_swizzled_layout(V_new_shared),
K_shared: tilelang.layout.make_swizzled_layout(K_shared),
G_shared: tilelang.layout.make_swizzled_layout(G_shared),
})
T.use_swizzle(10)
if use_initial_state:
T.copy(initial_state[bb, bh, 0:DK, bv * block_DV:(bv + 1) * block_DV], b_h_shared)
T.copy(b_h_shared, b_h_fragment)
else:
T.clear(b_h_fragment)
for i_s in T.Pipelined(T.ceildiv(S, block_S), num_stages=num_stages):
# Store previous result to the hidden tensor, like the epilogue
T.copy(b_h_shared, h[bb, i_s, bh, 0:DK, bv * block_DV:(bv + 1) * block_DV])
# Recurrence
T.copy(W[bb, i_s * block_S:(i_s + 1) * block_S, bh, 0:DK], W_shared)
T.gemm(W_shared, b_h_shared, V_new_fragment, clear_accum=True)
# U - W * S
T.copy(
U[bb, i_s * block_S:(i_s + 1) * block_S, bh, bv * block_DV:(bv + 1) * block_DV],
U_shared)
T.copy(U_shared, U_fragment)
for i_s2, i_v in T.Parallel(block_S, block_DV):
V_new_fragment[i_s2, i_v] = -V_new_fragment[i_s2, i_v] + U_fragment[i_s2, i_v]
# Save V_new
if save_new_value:
T.copy(V_new_fragment, dst=V_new_shared)
T.copy(
V_new_shared, V_new[bb, i_s * block_S:(i_s + 1) * block_S, bh,
bv * block_DV:(bv + 1) * block_DV])
T.copy(K[bb, i_s * block_S:(i_s + 1) * block_S, bh, 0:DK], K_shared)
# use_g
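                # Gating: decay V_new to the chunk boundary by exp(G_last - G_i) (entries with a positive
                # difference are zeroed) and decay the carried state by exp(G_last).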
if use_g:
G_last_local[0] = G[bb, (i_s + 1) * block_S - 1, bh]
for i_s2, i_v in T.Parallel(block_S, block_DV):
G_shared[i_s2, i_v] = G[bb, i_s * block_S + i_s2, bh]
T.copy(G_shared, G_fragment)
for i_s2, i_v in T.Parallel(block_S, block_DV):
with T.If(G_last_local[0] - G_fragment[i_s2, i_v] <= 0):
with T.Then():
V_new_fragment[i_s2, i_v] = V_new_fragment[i_s2, i_v] * T.exp(
G_last_local[0] - G_fragment[i_s2, i_v])
with T.Else():
V_new_fragment[i_s2, i_v] = 0
G_last_local[0] = T.exp(G_last_local[0])
for i_k, i_v in T.Parallel(DK, block_DV):
b_h_fragment[i_k, i_v] *= G_last_local[0]
# Update intermediate results
T.copy(V_new_fragment, V_new_shared)
T.gemm(K_shared, V_new_shared, b_h_fragment, transpose_A=True)
T.copy(b_h_fragment, b_h_shared)
# Save final state
if store_final_state:
T.copy(b_h_fragment, final_state[bb, bh, 0:DK, bv * block_DV:(bv + 1) * block_DV])
return kernel
def do_bench(fn, *args, warmup=10, rep=10, **kwargs):
"""
    Benchmark `fn` with CUDA events and return the mean elapsed time in milliseconds.
"""
start_event = [torch.cuda.Event(enable_timing=True) for i in range(rep)]
end_event = [torch.cuda.Event(enable_timing=True) for i in range(rep)]
for _ in range(warmup):
fn(*args, **kwargs)
torch.cuda.synchronize()
for i in range(rep):
start_event[i].record()
fn(*args, **kwargs)
end_event[i].record()
torch.cuda.synchronize()
# Record clocks
times = torch.tensor(
[s.elapsed_time(e) for s, e in zip(start_event, end_event)],
dtype=torch.float,
)
return times.mean().item()
def run_test(
B,
S,
H,
DK,
DV,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
state_dtype,
chunk_size,
use_g=True,
use_initial_state=True,
store_final_state=True,
save_new_value=True,
block_DK=64,
block_DV=32,
threads=128,
num_stages=0,
):
K, W, U, G, initial_state = prepare_input(B, S, H, DK, DV, chunk_size,
getattr(torch, input_dtype),
getattr(torch, output_dtype),
getattr(torch, accum_dtype),
getattr(torch, gate_dtype))
h_ref, final_state_ref, V_new_ref = prepare_output(B, S, H, DK, DV, chunk_size,
getattr(torch, output_dtype),
getattr(torch, state_dtype))
h_tilelang, final_state_tilelang, V_new_tilelang = prepare_output(B, S, H, DK, DV, chunk_size,
getattr(torch, output_dtype),
getattr(torch, state_dtype))
# fla ref
h_ref, V_new_ref, final_state_ref = chunk_gated_delta_rule_fwd_h(K, W, U, G, initial_state,
store_final_state, chunk_size,
save_new_value)
# tilelang
kernel = tilelang_chunk_gated_delta_rule_fwd_h(B, S, H, DK, DV, input_dtype, output_dtype,
accum_dtype, gate_dtype, state_dtype, chunk_size,
use_g, use_initial_state, store_final_state,
save_new_value, block_DK, block_DV, threads,
num_stages)
h_tilelang, final_state_tilelang, V_new_tilelang = kernel(K, W, U, G, initial_state)
# (zhengju) If you want to print the generated cuda code, you can uncomment the following line
# print("CUDA Code:\n", kernel.get_kernel_source())
fla_time = do_bench(chunk_gated_delta_rule_fwd_h, K, W, U, G, initial_state, store_final_state,
chunk_size, save_new_value)
tilelang_time = do_bench(kernel, K, W, U, G, initial_state)
# check correctness
try:
h_ref_fp32 = h_ref.to(torch.float32)
h_tilelang_fp32 = h_tilelang.to(torch.float32)
assert_similar(
h_ref_fp32,
h_tilelang_fp32,
eps=1e-5,
name="tilelang chunk gated delta rule fwd h",
raise_assert=False)
print("tilelang chunk gated delta rule fwd h passed √")
except Exception as e:
print("tilelang chunk gated delta rule fwd h failed ✗")
print(e)
try:
final_state_ref_fp32 = final_state_ref.to(torch.float32)
final_state_tilelang_fp32 = final_state_tilelang.to(torch.float32)
assert_similar(
final_state_ref_fp32,
final_state_tilelang_fp32,
eps=1e-5,
name="tilelang chunk gated delta rule fwd final_state",
raise_assert=False)
print("tilelang chunk gated delta rule fwd final_state passed √")
except Exception as e:
print("tilelang chunk gated delta rule fwd final_state failed ✗")
print(e)
try:
V_new_ref_fp32 = V_new_ref.to(torch.float32)
V_new_tilelang_fp32 = V_new_tilelang.to(torch.float32)
assert_similar(
V_new_ref_fp32,
V_new_tilelang_fp32,
eps=1e-5,
name="tilelang chunk gated delta rule fwd V_new",
raise_assert=False)
print("tilelang chunk gated delta rule fwd V_new passed √")
except Exception as e:
print("tilelang chunk gated delta rule fwd V_new failed ✗")
print(e)
print(f"tilelang time: {tilelang_time} ms")
print(f"fla time: {fla_time} ms")
def main():
run_test(
B=1,
S=32768,
H=32,
DK=128,
DV=128,
input_dtype="bfloat16",
output_dtype="bfloat16",
accum_dtype="float32",
gate_dtype="float32",
state_dtype="float32",
chunk_size=64,
use_g=True,
use_initial_state=True,
store_final_state=True,
save_new_value=True,
block_DK=64,
block_DV=32,
threads=128,
num_stages=1,
)
if __name__ == "__main__":
main()
# Reference: fla/ops/common/chunk_o.py
import tilelang
import tilelang.language as T
import sys # noqa: F401
# Add your fla repository path to sys.path
# Currently we use the fla repository from the flash-linear-attention project at commit id f03cb3ae
# sys.path.insert(0, "/home/tzj/flash-linear-attention")
try:
import fla
print(fla.__file__)
from fla.ops.common.chunk_o import chunk_fwd_o
except ImportError:
print("fla not found, using tilelang implementation")
fla = None
import torch
torch.random.manual_seed(1)
def prepare_input(
B,
S,
H,
DK,
DV,
chunk_size,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
):
BS = chunk_size
Q = torch.randn(B, S, H, DK, dtype=input_dtype).cuda()
K = torch.randn(B, S, H, DK, dtype=input_dtype).cuda()
V = torch.randn(B, S, H, DV, dtype=input_dtype).cuda()
HIDDEN = torch.randn(B, S // BS, H, DK, DV, dtype=input_dtype).cuda()
G = torch.randn(B, S, H, dtype=gate_dtype).cuda()
return Q, K, V, HIDDEN, G
def prepare_output(
B,
S,
H,
DK,
DV,
chunk_size,
output_dtype,
):
O = torch.empty(B, S, H, DV, dtype=output_dtype).cuda()
return O
@tilelang.jit(out_idx=[-1])
def tilelang_chunk_fwd_o(
# task config
B,
S,
H,
DK,
DV,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
chunk_size,
scale,
use_g,
# kernel config
block_S=64,
block_DK=64,
block_DV=64,
threads=256,
num_stages=0,
):
assert chunk_size == block_S, "chunk_size must be equal to block_S"
BS = chunk_size
Q_shape = (B, S, H, DK)
K_shape = (B, S, H, DK)
V_shape = (B, S, H, DV)
H_shape = (B, S // BS, H, DK, DV)
G_shape = (B, S, H)
O_shape = (B, S, H, DV)
@T.prim_func
def kernel(
Q: T.Tensor(Q_shape, dtype=input_dtype),
K: T.Tensor(K_shape, dtype=input_dtype),
V: T.Tensor(V_shape, dtype=input_dtype),
HIDDEN: T.Tensor(H_shape, dtype=input_dtype),
G: T.Tensor(G_shape, dtype=gate_dtype),
O: T.Tensor(O_shape, dtype=output_dtype),
):
with T.Kernel(
T.ceildiv(DV, block_DV), T.ceildiv(S, block_S), B * H,
threads=threads) as (bv, bs, bbh):
bb, bh = bbh // H, bbh % H
Q_shared = T.alloc_shared((block_S, block_DK), dtype=input_dtype)
K_shared = T.alloc_shared((block_S, block_DK), dtype=input_dtype)
V_shared = T.alloc_shared((block_S, block_DV), dtype=input_dtype)
H_shared = T.alloc_shared((block_DK, block_DV), dtype=input_dtype)
A_shared = T.alloc_shared((block_S, block_S), dtype=input_dtype)
O_shared = T.alloc_shared((block_S, block_DV), dtype=output_dtype)
A_fragment = T.alloc_fragment((block_S, block_S), dtype=accum_dtype)
O_fragment = T.alloc_fragment((block_S, block_DV), dtype=accum_dtype)
G_shared = T.alloc_shared((block_S,), dtype=gate_dtype, scope="shared")
G_diff_local = T.alloc_fragment((block_S, block_S), dtype=gate_dtype)
T.annotate_layout({
Q_shared: tilelang.layout.make_swizzled_layout(Q_shared),
K_shared: tilelang.layout.make_swizzled_layout(K_shared),
V_shared: tilelang.layout.make_swizzled_layout(V_shared),
H_shared: tilelang.layout.make_swizzled_layout(H_shared),
A_shared: tilelang.layout.make_swizzled_layout(A_shared),
O_shared: tilelang.layout.make_swizzled_layout(O_shared),
})
T.clear(A_fragment)
T.clear(O_fragment)
T.disable_warp_group_reg_alloc()
for i_k in T.Pipelined(T.ceildiv(DK, block_DK), num_stages=num_stages):
T.copy(
Q[bb, bs * block_S:(bs + 1) * block_S, bh, i_k * block_DK:(i_k + 1) * block_DK],
Q_shared)
T.copy(
K[bb, bs * block_S:(bs + 1) * block_S, bh, i_k * block_DK:(i_k + 1) * block_DK],
K_shared)
T.copy(
HIDDEN[bb, bs, bh, i_k * block_DK:(i_k + 1) * block_DK,
bv * block_DV:(bv + 1) * block_DV], H_shared)
T.gemm(Q_shared, H_shared, O_fragment)
T.gemm(Q_shared, K_shared, A_fragment, transpose_B=True)
if use_g:
for i_s in T.Parallel(block_S):
G_shared[i_s] = G[bb, bs * block_S + i_s, bh]
# T.copy(G[bb, bs * block_S:(bs + 1) * block_S, bh], G_shared)
for i_s, i_v in T.Parallel(block_S, block_DV):
O_fragment[i_s, i_v] = O_fragment[i_s, i_v] * T.exp(G_shared[i_s])
for i_s1, i_s2 in T.Parallel(block_S, block_S):
G_diff_local[i_s1, i_s2] = G_shared[i_s1] - G_shared[i_s2]
for i_s1, i_s2 in T.Parallel(block_S, block_S):
with T.If(G_diff_local[i_s1, i_s2] <= 0):
with T.Then():
A_fragment[i_s1, i_s2] = A_fragment[i_s1, i_s2] * T.exp(
G_diff_local[i_s1, i_s2])
with T.Else():
A_fragment[i_s1, i_s2] = 0
for i_s1, i_s2 in T.Parallel(block_S, block_S):
with T.If(i_s1 < i_s2): # noqa: SIM117
with T.Then():
A_fragment[i_s1, i_s2] = 0
T.copy(V[bb, bs * block_S:(bs + 1) * block_S, bh, bv * block_DV:(bv + 1) * block_DV],
V_shared)
T.copy(A_fragment, A_shared)
T.gemm(A_shared, V_shared, O_fragment)
for i_s, i_v in T.Parallel(block_S, block_DV):
O_fragment[i_s, i_v] = O_fragment[i_s, i_v] * scale
T.copy(O_fragment, O_shared)
T.copy(O_shared, O[bb, bs * block_S:(bs + 1) * block_S, bh,
bv * block_DV:(bv + 1) * block_DV])
return kernel
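# A minimal plain-PyTorch sketch of the computation in the kernel above, for readers who want
# the math without the tiling (the helper name is illustrative and not part of fla). Per chunk
# of size BS it evaluates, for the use_g=True path,
#     O = (Q @ h) * exp(g) + tril(Q @ K^T * exp(g_i - g_j)) @ V, all scaled by `scale`.
def torch_chunk_fwd_o_sketch(Q, K, V, HIDDEN, G, scale, chunk_size):
    B, S, H, DK = Q.shape
    DV = V.shape[-1]
    BS = chunk_size
    O = torch.empty(B, S, H, DV, dtype=Q.dtype, device=Q.device)
    causal = torch.tril(torch.ones(BS, BS, device=Q.device)).bool()
    for c in range(S // BS):
        sl = slice(c * BS, (c + 1) * BS)
        q, k, v = Q[:, sl].float(), K[:, sl].float(), V[:, sl].float()
        g = G[:, sl].float()                     # (B, BS, H)
        h = HIDDEN[:, c].float()                 # (B, H, DK, DV)
        # inter-chunk term through the recurrent state h, gated by exp(g)
        o = torch.einsum("bshk,bhkv->bshv", q, h) * torch.exp(g)[..., None]
        # intra-chunk term: causal scores within the chunk with gate decay
        A = torch.einsum("bshk,bthk->bhst", q, k)
        gh = g.permute(0, 2, 1)                  # (B, H, BS)
        diff = gh[..., :, None] - gh[..., None, :]
        keep = causal & (diff <= 0)              # the kernel also zeroes positive gate diffs
        A = torch.where(keep, A * torch.exp(diff), torch.zeros_like(A))
        o = o + torch.einsum("bhst,bthv->bshv", A, v)
        O[:, sl] = (o * scale).to(O.dtype)
    return O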
def run_test(
B,
S,
H,
DK,
DV,
chunk_size,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
use_g,
block_DK,
block_DV,
threads,
num_stages,
):
input_dtype_torch = getattr(torch, input_dtype)
output_dtype_torch = getattr(torch, output_dtype)
accum_dtype_torch = getattr(torch, accum_dtype)
gate_dtype_torch = getattr(torch, gate_dtype)
Q, K, V, HIDDEN, G = prepare_input(B, S, H, DK, DV, chunk_size, input_dtype_torch,
output_dtype_torch, accum_dtype_torch, gate_dtype_torch)
scale = 1.0 / DK**0.5
O_ref = prepare_output(B, S, H, DK, DV, chunk_size, output_dtype_torch)
O_ref = chunk_fwd_o(Q, K, V, HIDDEN, G, scale, chunk_size=chunk_size)
block_S = chunk_size
O_tilelang = prepare_output(B, S, H, DK, DV, chunk_size, output_dtype_torch)
kernel = tilelang_chunk_fwd_o(B, S, H, DK, DV, input_dtype, output_dtype, accum_dtype,
gate_dtype, chunk_size, scale, use_g, block_S, block_DK, block_DV,
threads, num_stages)
O_tilelang = kernel(Q, K, V, HIDDEN, G)
try:
torch.testing.assert_close(O_tilelang, O_ref, rtol=1e-2, atol=1e-2)
print("tilelang chunk fwd o passed √")
except Exception as e:
print("tilelang chunk fwd o failed ✗")
print(e)
def main():
run_test(
B=1,
S=32768,
H=32,
DK=128,
DV=128,
chunk_size=64,
input_dtype="bfloat16",
output_dtype="bfloat16",
accum_dtype="float32",
gate_dtype="float32",
use_g=True,
block_DK=128,
block_DV=128,
threads=128,
num_stages=1,
)
if __name__ == "__main__":
main()
# Reference: fla/ops/common/chunk_o.py
import math
import sys # noqa: F401
import tilelang
import tilelang.language as T
from tilelang.engine.callback import register_cuda_postproc_callback # noqa: F401
print(tilelang.__file__)
# Add your fla repository path to sys.path
# Currently we use the fla repository from the flash-linear-attention project at commit id f03cb3ae
# sys.path.insert(0, "/home/tzj/flash-linear-attention")
try:
import fla
print(fla.__file__)
from fla.ops.common.chunk_o import chunk_bwd_dqkwg
except ImportError:
print("fla not found, using tilelang implementation")
fla = None
import torch
from utils import *
torch.random.manual_seed(0)
# torch.set_printoptions(profile="full")
def prepare_input_fake(
B,
S,
H,
DK,
DV,
chunk_size,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
state_dtype,
):
BS = S // chunk_size
Q = torch.ones(B, S, H, DK, dtype=input_dtype).cuda()
K = torch.ones(B, S, H, DK, dtype=input_dtype).cuda()
V = torch.ones(B, S, H, DV, dtype=input_dtype).cuda()
h = torch.ones(B, BS, H, DK, DV, dtype=input_dtype).cuda()
G = torch.ones(B, S, H, dtype=gate_dtype).cuda()
dO = torch.ones(B, S, H, DV, dtype=input_dtype).cuda()
dh = torch.ones(B, BS, H, DK, DV, dtype=input_dtype).cuda()
dv = torch.ones(B, S, H, DV, dtype=output_dtype).cuda()
W = torch.ones(B, S, H, DK, dtype=input_dtype).cuda()
return Q, K, V, h, G, dO, dh, dv, W
def prepare_input(
B,
S,
H,
DK,
DV,
chunk_size,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
state_dtype,
):
BS = S // chunk_size
Q = torch.randn(B, S, H, DK, dtype=input_dtype).cuda()
K = torch.randn(B, S, H, DK, dtype=input_dtype).cuda()
V = torch.randn(B, S, H, DV, dtype=input_dtype).cuda()
h = torch.randn(B, BS, H, DK, DV, dtype=input_dtype).cuda()
G = torch.randn(B, S, H, dtype=gate_dtype).cuda()
dO = torch.randn(B, S, H, DV, dtype=input_dtype).cuda()
dh = torch.randn(B, BS, H, DK, DV, dtype=input_dtype).cuda()
dv = torch.randn(B, S, H, DV, dtype=output_dtype).cuda()
W = torch.randn(B, S, H, DK, dtype=input_dtype).cuda()
return Q, K, V, h, G, dO, dh, dv, W
def prepare_output(
B,
S,
H,
DK,
DV,
chunk_size,
output_dtype,
gate_dtype,
state_dtype,
block_DK,
):
    assert (DK == 32 and block_DK == 32) or (DK > 32 and block_DK >= 64), "block_DK must be 32 when DK == 32 and >= 64 when DK > 32"
NK = math.ceil(DK / block_DK)
dq = torch.empty(B, S, H, DK, dtype=output_dtype).cuda()
dk = torch.empty(B, S, H, DK, dtype=output_dtype).cuda()
dw = torch.empty(B, S, H, DK, dtype=output_dtype).cuda()
dg = torch.empty(NK, B, S, H, dtype=gate_dtype).cuda()
return dq, dk, dw, dg
# @register_cuda_postproc_callback
# def tilelang_callback_cuda_postproc(code, _):
# cuda_code = open("../debug/chunk_o_bwd3.log", "r").read()
# code = cuda_code
# return code
@tilelang.jit(
out_idx=[-4, -3, -2, -1],
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True
})
def tilelang_chunk_o_bwd_dqkwg(
# task config
B,
S,
H,
DK,
DV,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
state_dtype,
chunk_size,
scale,
use_g=True,
use_dw=True,
# kernel config
block_DK=64,
block_DV=64,
threads=256,
num_stages=0,
):
block_S = chunk_size
BS = S // block_S
NK = math.ceil(DK / block_DK)
Q_shape = (B, S, H, DK)
K_shape = (B, S, H, DK)
V_shape = (B, S, H, DV)
h_shape = (B, BS, H, DK, DV)
G_shape = (B, S, H)
dO_shape = (B, S, H, DV)
dh_shape = (B, BS, H, DK, DV)
dv_shape = (B, S, H, DV)
W_shape = (B, S, H, DK)
dq_shape = (B, S, H, DK)
dk_shape = (B, S, H, DK)
dw_shape = (B, S, H, DK)
dg_shape = (NK, B, S, H)
@T.prim_func
def kernel(
# input
Q: T.Tensor(Q_shape, dtype=input_dtype),
K: T.Tensor(K_shape, dtype=input_dtype),
V: T.Tensor(V_shape, dtype=input_dtype),
h: T.Tensor(h_shape, dtype=input_dtype),
G: T.Tensor(G_shape, dtype=gate_dtype),
dO: T.Tensor(dO_shape, dtype=input_dtype),
dh: T.Tensor(dh_shape, dtype=input_dtype),
dv: T.Tensor(dv_shape, dtype=input_dtype),
W: T.Tensor(W_shape, dtype=input_dtype),
# output
dq: T.Tensor(dq_shape, dtype=output_dtype),
dk: T.Tensor(dk_shape, dtype=output_dtype),
dw: T.Tensor(dw_shape, dtype=output_dtype),
dg: T.Tensor(dg_shape, dtype=gate_dtype),
):
with T.Kernel(
T.ceildiv(DK, block_DK), T.ceildiv(S, block_S), B * H,
threads=threads) as (bk, bs, bbh):
bb, bh = bbh // H, bbh % H
V_shared = T.alloc_shared((block_S, block_DV), dtype=input_dtype)
dO_shared = T.alloc_shared((block_S, block_DV), dtype=input_dtype)
h_shared = T.alloc_shared((block_DK, block_DV), dtype=input_dtype)
dh_shared = T.alloc_shared((block_DK, block_DV), dtype=input_dtype)
dv_shared = T.alloc_shared((block_S, block_DV), dtype=input_dtype)
q_shared = T.alloc_shared((block_S, block_DK), dtype=input_dtype)
k_shared = T.alloc_shared((block_S, block_DK), dtype=input_dtype)
ds_shared = T.alloc_shared((block_S, block_S), dtype=output_dtype)
dg_shared_1 = T.alloc_shared((block_S,), dtype=gate_dtype)
dg_shared_2 = T.alloc_shared((block_S,), dtype=gate_dtype)
dk_shared = T.alloc_shared((block_S, block_DK), dtype=accum_dtype)
ds_fragment = T.alloc_fragment((block_S, block_S), dtype=accum_dtype)
ds_fragment_positive = T.alloc_fragment((block_S, block_S), dtype=accum_dtype)
ds_fragment_positive_transpose = T.alloc_fragment((block_S, block_S), dtype=accum_dtype)
dq_fragment = T.alloc_fragment((block_S, block_DK), dtype=accum_dtype)
dk_fragment = T.alloc_fragment((block_S, block_DK), dtype=accum_dtype)
dk_fragment_2 = T.alloc_fragment((block_S, block_DK), dtype=accum_dtype)
dw_fragment = T.alloc_fragment((block_S, block_DK), dtype=accum_dtype)
q_fragment = T.alloc_fragment((block_S, block_DK), dtype=input_dtype)
k_fragment = T.alloc_fragment((block_S, block_DK), dtype=input_dtype)
dg_fragment_reduce_tmp = T.alloc_fragment((block_S, block_DK), dtype=gate_dtype)
dg_fragment = T.alloc_fragment((block_S,), dtype=gate_dtype)
dg_fragment_2 = T.alloc_fragment((block_S,), dtype=gate_dtype)
dg_fragment_final = T.alloc_fragment((block_S,), dtype=gate_dtype)
dg_last_local = T.alloc_local((2,), dtype=gate_dtype)
dg_last_fragment = T.alloc_fragment((block_DV * block_DK), dtype=gate_dtype)
dg_last_fragment_scalar = T.alloc_fragment((1,), dtype=gate_dtype)
dg_last_fragment_2 = T.alloc_fragment((block_S * block_DK), dtype=gate_dtype)
dg_last_fragment_scalar_2 = T.alloc_fragment((1,), dtype=gate_dtype)
G_shared = T.alloc_shared((block_S, block_DK), dtype=gate_dtype, scope="shared")
G_last_local = T.alloc_local((1,), dtype=gate_dtype)
T.use_swizzle(10)
T.annotate_layout({
V_shared: tilelang.layout.make_swizzled_layout(V_shared),
dO_shared: tilelang.layout.make_swizzled_layout(dO_shared),
h_shared: tilelang.layout.make_swizzled_layout(h_shared),
dh_shared: tilelang.layout.make_swizzled_layout(dh_shared),
dv_shared: tilelang.layout.make_swizzled_layout(dv_shared),
q_shared: tilelang.layout.make_swizzled_layout(q_shared),
k_shared: tilelang.layout.make_swizzled_layout(k_shared),
})
T.clear(dg_last_local)
T.clear(G_last_local)
T.clear(G_shared)
T.clear(q_fragment)
T.clear(k_fragment)
T.clear(dg_last_fragment)
T.clear(ds_fragment)
T.clear(dq_fragment)
T.clear(dk_fragment)
T.clear(dw_fragment)
for i_v in T.Pipelined(T.ceildiv(DV, block_DV), num_stages=num_stages):
T.copy(
V[bb, bs * block_S:(bs + 1) * block_S, bh, i_v * block_DV:(i_v + 1) * block_DV],
V_shared)
T.copy(
dO[bb, bs * block_S:(bs + 1) * block_S, bh,
i_v * block_DV:(i_v + 1) * block_DV], dO_shared)
T.copy(
h[bb, bs, bh, bk * block_DK:(bk + 1) * block_DK,
i_v * block_DV:(i_v + 1) * block_DV], h_shared)
T.copy(
dh[bb, bs, bh, bk * block_DK:(bk + 1) * block_DK,
i_v * block_DV:(i_v + 1) * block_DV], dh_shared)
if use_g:
T.clear(dg_last_fragment_scalar)
                    # FIXME: Reducing a whole buffer to a scalar is not supported and produces incorrect results
# for i_kv in T.Parallel(block_DK * block_DV):
# dg_last_fragment[i_kv] = h_shared[i_kv // block_DV, i_kv % block_DV] * dh_shared[i_kv // block_DV, i_kv % block_DV]
for i_kv in T.Parallel(block_DK * block_DV):
i_k, i_v = i_kv // block_DV, i_kv % block_DV
dg_last_fragment[i_kv] = h_shared[i_k, i_v] * dh_shared[i_k, i_v]
T.reduce_sum(dg_last_fragment, dg_last_fragment_scalar, dim=-1, clear=False)
dg_last_local[0] += dg_last_fragment_scalar[0]
T.gemm(dO_shared, V_shared, ds_fragment, transpose_B=True)
T.gemm(dO_shared, h_shared, dq_fragment, transpose_B=True)
T.gemm(V_shared, dh_shared, dk_fragment, transpose_B=True)
if use_dw:
T.copy(
dv[bb, bs * block_S:(bs + 1) * block_S, bh,
i_v * block_DV:(i_v + 1) * block_DV], dv_shared)
T.gemm(dv_shared, h_shared, dw_fragment, transpose_B=True)
if use_dw:
for i_s, i_k in T.Parallel(block_S, block_DK):
dw_fragment[i_s, i_k] = -dw_fragment[i_s, i_k]
T.copy(
dw_fragment, dw[bb, bs * block_S:(bs + 1) * block_S, bh,
bk * block_DK:(bk + 1) * block_DK])
T.copy(Q[bb, bs * block_S:(bs + 1) * block_S, bh, bk * block_DK:(bk + 1) * block_DK],
q_shared)
T.copy(K[bb, bs * block_S:(bs + 1) * block_S, bh, bk * block_DK:(bk + 1) * block_DK],
k_shared)
T.copy(q_shared, q_fragment)
T.copy(k_shared, k_fragment)
if use_g:
T.clear(dg_fragment)
T.clear(dg_fragment_2)
for i_s, i_k in T.Parallel(block_S, block_DK):
G_shared[i_s, i_k] = G[bb, bs * block_S + i_s, bh]
G_last_local[0] = G[bb, bs * block_S + block_S - 1, bh]
# Use gmem directly instead of local register
dg_last_local[0] = dg_last_local[0] * T.exp(G[bb, bs * block_S + block_S - 1, bh])
for i_s, i_k in T.Parallel(block_S, block_DK):
dq_fragment[i_s, i_k] = dq_fragment[i_s, i_k] * T.exp(G[bb, bs * block_S + i_s,
bh]) * scale
T.clear(dg_fragment_reduce_tmp)
for i_s, i_k in T.Parallel(block_S, block_DK):
dg_fragment_reduce_tmp[i_s, i_k] = dq_fragment[i_s, i_k] * q_shared[i_s, i_k]
                # FIXME: reduce_sum with clear=True causes an error in the warp-specialized pass
T.reduce_sum(dg_fragment_reduce_tmp, dg_fragment, dim=-1, clear=False)
for i_s, i_k in T.Parallel(block_S, block_DK):
with T.If(G_last_local[0] - G[bb, bs * block_S + i_s, bh] <= 0):
with T.Then():
dk_fragment[i_s, i_k] = dk_fragment[i_s, i_k] * T.exp(
G_last_local[0] - G[bb, bs * block_S + i_s, bh])
with T.Else():
dk_fragment[i_s, i_k] = 0
T.clear(dg_fragment_reduce_tmp)
for i_s, i_k in T.Parallel(block_S, block_DK):
dg_fragment_reduce_tmp[i_s, i_k] = dk_fragment[i_s, i_k] * (-k_shared[i_s, i_k])
                # FIXME: reduce_sum with clear=True causes an error in the warp-specialized pass
T.reduce_sum(dg_fragment_reduce_tmp, dg_fragment, dim=-1, clear=False)
                # FIXME: Reducing a whole buffer to a scalar is not supported and produces incorrect results
T.copy(dk_fragment, dk_shared)
T.clear(dg_last_fragment_scalar_2)
for i_sk in T.Parallel(block_S * block_DK):
i_s, i_k = i_sk // block_DK, i_sk % block_DK
dg_last_fragment_2[i_sk] = dk_shared[i_s, i_k] * k_shared[i_s, i_k]
T.reduce_sum(dg_last_fragment_2, dg_last_fragment_scalar_2, dim=-1, clear=False)
dg_last_local[1] = dg_last_fragment_scalar_2[0]
for i_s1, i_s2 in T.Parallel(block_S, block_S):
with T.If(i_s1 >= i_s2 and
G[bb, bs * block_S + i_s1, bh] - G[bb, bs * block_S + i_s2, bh] <= 0):
with T.Then():
ds_fragment[i_s1, i_s2] = ds_fragment[
i_s1, i_s2] * T.exp(G[bb, bs * block_S + i_s1, bh] -
G[bb, bs * block_S + i_s2, bh]) * scale
with T.Else():
ds_fragment[i_s1, i_s2] = 0
T.clear(ds_fragment_positive)
T.clear(ds_fragment_positive_transpose)
T.gemm(q_shared, k_shared, ds_fragment_positive, transpose_B=True)
for i_s1, i_s2 in T.Parallel(block_S, block_S):
ds_fragment_positive[
i_s1, i_s2] = ds_fragment[i_s1, i_s2] * ds_fragment_positive[i_s1, i_s2]
                # FIXME: reduce_sum with clear=True causes an error in the warp-specialized pass
T.reduce_sum(ds_fragment_positive, dg_fragment, dim=1, clear=False)
T.copy(dg_fragment, dg_shared_1)
# We should transpose the matrix because the reduce_sum statement can only reduce along the last dimension
for i_s1, i_s2 in T.Parallel(block_S, block_S):
ds_fragment_positive_transpose[i_s2, i_s1] = ds_fragment_positive[i_s1, i_s2]
                # FIXME: reduce_sum with clear=True causes an error in the warp-specialized pass
T.reduce_sum(ds_fragment_positive_transpose, dg_fragment_2, dim=1, clear=False)
T.copy(dg_fragment_2, dg_shared_2)
for i_s in T.Parallel(block_S):
dg_fragment_final[i_s] = dg_shared_1[i_s] - dg_shared_2[i_s]
T.copy(ds_fragment, ds_shared)
T.gemm(ds_shared, k_shared, dq_fragment)
T.gemm(ds_shared, q_shared, dk_fragment, transpose_A=True)
for i_s in T.Parallel(block_S):
with T.If(i_s >= block_S - 1): # noqa: SIM117
with T.Then():
dg_fragment_final[
i_s] = dg_fragment_final[i_s] + dg_last_local[0] + dg_last_local[1]
T.copy(
dq_fragment, dq[bb, bs * block_S:(bs + 1) * block_S, bh,
bk * block_DK:(bk + 1) * block_DK])
T.copy(
dk_fragment, dk[bb, bs * block_S:(bs + 1) * block_S, bh,
bk * block_DK:(bk + 1) * block_DK])
for i_s in T.Parallel(block_S):
dg[bk, bb, bs * block_S + i_s, bh] = dg_fragment_final[i_s]
else:
for i_s1, i_s2 in T.Parallel(block_S, block_S):
with T.If(i_s1 < i_s2): # noqa: SIM117
with T.Then():
ds_fragment[i_s1, i_s2] = 0
T.clear(dk_fragment_2)
T.copy(ds_fragment, ds_shared)
T.gemm(ds_shared, k_shared, dq_fragment)
T.gemm(ds_shared, q_shared, dk_fragment_2, transpose_A=True)
for i_s, i_k in T.Parallel(block_S, block_DK):
dq_fragment[i_s, i_k] = dq_fragment[i_s, i_k] * scale
dk_fragment[i_s, i_k] = dk_fragment[i_s, i_k] + dk_fragment_2[i_s, i_k] * scale
T.copy(
dq_fragment, dq[bb, bs * block_S:(bs + 1) * block_S, bh,
bk * block_DK:(bk + 1) * block_DK])
T.copy(
dk_fragment, dk[bb, bs * block_S:(bs + 1) * block_S, bh,
bk * block_DK:(bk + 1) * block_DK])
return kernel
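# Hedged summary of the kernel above (a reading of the code, not an authoritative spec). Per chunk:
#     ds = tril(dO @ V^T) * exp(g_i - g_j) * scale
#     dq = (dO @ h^T) * exp(g_i) * scale + ds @ K
#     dk = (V @ dh^T) * exp(g_last - g_i) + ds^T @ Q
#     dw = -(dv @ h^T)
# and, when use_g is set, dg accumulates sum_k(dq * q) - sum_k(dk * k) plus the row-minus-column
# reductions of ds * (Q K^T), with sum(h * dh) * exp(g_last) and sum(dk * K) added at the last
# position of each chunk.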
def do_bench(fn, *args, warmup=10, rep=10, **kwargs):
"""
    Benchmark fn(*args, **kwargs) with CUDA events and return the mean latency in milliseconds.
"""
start_event = [torch.cuda.Event(enable_timing=True) for i in range(rep)]
end_event = [torch.cuda.Event(enable_timing=True) for i in range(rep)]
for _ in range(warmup):
fn(*args, **kwargs)
torch.cuda.synchronize()
for i in range(rep):
start_event[i].record()
fn(*args, **kwargs)
end_event[i].record()
torch.cuda.synchronize()
# Record clocks
times = torch.tensor(
[s.elapsed_time(e) for s, e in zip(start_event, end_event)],
dtype=torch.float,
)
return times.mean().item()
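# Illustrative usage (the variable names assume the tensors prepared in run_test below):
#     tilelang_ms = do_bench(kernel, Q, K, V, h, G, dO, dh, dv, W)
# returns the mean latency in milliseconds over `rep` timed runs after `warmup` untimed runs.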
def run_test(
B,
S,
H,
DK,
DV,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
state_dtype,
chunk_size,
scale,
use_g=True,
use_dw=True,
block_DK=64,
block_DV=64,
threads=256,
num_stages=0,
):
Q, K, V, h, G, dO, dh, dv, W = prepare_input(B, S, H, DK, DV, chunk_size,
getattr(torch, input_dtype),
getattr(torch, output_dtype),
getattr(torch, accum_dtype),
getattr(torch, gate_dtype),
getattr(torch, state_dtype))
dq_ref, dk_ref, dw_ref, dg_ref = prepare_output(B, S, H, DK, DV, chunk_size,
getattr(torch, output_dtype),
getattr(torch, gate_dtype),
getattr(torch, state_dtype), block_DK)
dq_tilelang, dk_tilelang, dw_tilelang, dg_tilelang = prepare_output(
B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype),
getattr(torch, state_dtype), block_DK)
# ref
if use_g:
dq_ref, dk_ref, dw_ref, dg_ref = chunk_bwd_dqkwg(
Q, K, V, G, dO, h, dh, dv, W, chunk_size=chunk_size, scale=scale)
else:
dq_ref, dk_ref, dw_ref, dg_ref = chunk_bwd_dqkwg(
Q, K, V, None, dO, h, dh, dv, W, chunk_size=chunk_size, scale=scale)
# tilelang
kernel = tilelang_chunk_o_bwd_dqkwg(B, S, H, DK, DV, input_dtype, output_dtype, accum_dtype,
gate_dtype, state_dtype, chunk_size, scale, use_g, use_dw,
block_DK, block_DV, threads, num_stages)
print(kernel.get_kernel_source())
dq_tilelang, dk_tilelang, dw_tilelang, dg_tilelang = kernel(Q, K, V, h, G, dO, dh, dv, W)
if use_g:
dg_tilelang = dg_tilelang.sum(dim=0)
# check
try:
assert_similar(dq_ref, dq_tilelang, 1e-5, "tilelang chunk o bwd dq")
print("tilelang chunk o bwd dq passed √")
except Exception as e:
print("tilelang chunk o bwd dq failed ✗")
print(e)
try:
assert_similar(dk_ref, dk_tilelang, 1e-5, "tilelang chunk o bwd dk")
print("tilelang chunk o bwd dk passed √")
except Exception as e:
print("tilelang chunk o bwd dk failed ✗")
print(e)
if use_g:
try:
assert_similar(dg_ref, dg_tilelang, 1e-5, "tilelang chunk o bwd dg")
print("tilelang chunk o bwd dg passed √")
except Exception as e:
print("tilelang chunk o bwd dg failed ✗")
print(e)
if use_dw:
try:
assert_similar(dw_ref, dw_tilelang, 1e-5, "tilelang chunk o bwd dw")
print("tilelang chunk o bwd dw passed √")
except Exception as e:
print("tilelang chunk o bwd dw failed ✗")
print(e)
def main():
DK = 128
DV = 128
run_test(
B=1,
S=32768,
H=8,
DK=DK,
DV=DV,
input_dtype="bfloat16",
output_dtype="bfloat16",
accum_dtype="float32",
gate_dtype="float32",
state_dtype="float32",
chunk_size=64,
scale=DK**-0.5,
# scale=1,
use_g=True,
use_dw=True,
block_DK=64,
block_DV=64,
threads=128,
num_stages=0,
)
if __name__ == "__main__":
main()
# Reference: fla/ops/common/chunk_scaled_dot_kkt.py
import tilelang
import tilelang.language as T
import sys # noqa: F401
# Add your fla repository path to sys.path
# Currently we use the fla repository from the flash-linear-attention project at commit id f03cb3ae
# sys.path.insert(0, "/home/tzj/flash-linear-attention")
try:
import fla
print(fla.__file__)
from fla.ops.common.chunk_scaled_dot_kkt import chunk_scaled_dot_kkt_fwd
except ImportError:
print("fla not found, using tilelang implementation")
fla = None
import torch
torch.set_printoptions(profile="full")
torch.random.manual_seed(0)
def prepare_input(
B,
S,
H,
DK,
input_dtype,
output_dtype,
accum_dtype,
):
K = torch.randn(B, S, H, DK, dtype=input_dtype).cuda()
Beta = torch.randn(B, S, H, dtype=input_dtype).cuda()
G = torch.randn(B, S, H, dtype=accum_dtype).cuda()
return K, Beta, G
def prepare_output(
B,
S,
H,
chunk_size,
dtype,
):
BS = chunk_size
A = torch.empty(B, S, H, BS, dtype=dtype).cuda()
return A
@tilelang.jit(out_idx=[-1])
def tilelang_chunk_scaled_dot_kkt_fwd(
# task config
B,
S,
H,
DK,
chunk_size=64,
input_dtype="bfloat16",
output_dtype="bfloat16",
accum_dtype="float32",
use_g=True,
# kernel config
block_S=64,
block_DK=64,
threads=256,
num_stages=0,
):
K_shape = (B, S, H, DK)
Beta_shape = (B, S, H)
G_shape = (B, S, H)
assert chunk_size == block_S, "chunk_size must be equal to block_S"
BS = chunk_size
output_shape = (B, S, H, BS)
@T.prim_func
def kernel(
K: T.Tensor(K_shape, dtype=input_dtype),
Beta: T.Tensor(Beta_shape, dtype=input_dtype),
G: T.Tensor(G_shape, dtype=accum_dtype),
A: T.Tensor(output_shape, dtype=output_dtype),
):
with T.Kernel(T.ceildiv(S, block_S), B * H, threads=threads) as (bs, bbh):
bb, bh = bbh // H, bbh % H
            # NOTE: mind the shared-memory scope here: a one-dimensional or very small buffer may hit a misaligned address unless scope="shared" is given explicitly
Beta_shared = T.alloc_shared((block_S,), dtype=input_dtype, scope="shared")
K_shared = T.alloc_shared((block_S, block_DK), dtype=input_dtype)
A_shared = T.alloc_shared((block_S, block_S), dtype=output_dtype)
Beta_K_fragment = T.alloc_fragment((block_S, block_DK), dtype=input_dtype)
A_fragment = T.alloc_fragment((block_S, block_S), dtype=accum_dtype)
# Tensor used for gated:
G_shared = T.alloc_shared((block_S,), dtype=accum_dtype, scope="shared")
G_diff_local = T.alloc_fragment((block_S, block_S), dtype=accum_dtype)
T.annotate_layout({
K_shared: tilelang.layout.make_swizzled_layout(K_shared),
A_shared: tilelang.layout.make_swizzled_layout(A_shared),
})
T.fill(A_fragment, 0)
T.disable_warp_group_reg_alloc()
for i_s in T.Parallel(block_S):
Beta_shared[i_s] = Beta[bb, bs * block_S + i_s, bh]
for i_k in T.Pipelined(T.ceildiv(DK, block_DK), num_stages=num_stages):
T.copy(
K[bb, bs * block_S:(bs + 1) * block_S, bh, i_k * block_DK:(i_k + 1) * block_DK],
K_shared)
for i_s, i_k2 in T.Parallel(block_S, block_DK):
Beta_K_fragment[i_s, i_k2] = K_shared[i_s, i_k2] * Beta_shared[i_s]
T.gemm(Beta_K_fragment, K_shared, A_fragment, transpose_B=True)
if use_g:
for i_s in T.Parallel(block_S):
G_shared[i_s] = G[bb, bs * block_S + i_s, bh]
for i_s1, i_s2 in T.Parallel(block_S, block_S):
G_diff_local[i_s1, i_s2] = G_shared[i_s1] - G_shared[i_s2]
for i_s1, i_s2 in T.Parallel(block_S, block_S):
with T.If(G_diff_local[i_s1, i_s2] <= 0 and i_s1 > i_s2):
with T.Then():
A_fragment[i_s1, i_s2] = A_fragment[i_s1, i_s2] * T.exp(
G_diff_local[i_s1, i_s2])
with T.Else():
A_fragment[i_s1, i_s2] = 0
else:
for i_s1, i_s2 in T.Parallel(block_S, block_S):
with T.If(i_s1 <= i_s2): # noqa: SIM117
with T.Then():
A_fragment[i_s1, i_s2] = 0
T.copy(A_fragment, A_shared)
T.copy(A_shared, A[bb, bs * block_S:(bs + 1) * block_S, bh, :])
return kernel
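# A minimal plain-PyTorch sketch of the kernel above for the use_g=True path (the helper name is
# illustrative and not part of fla). Per chunk it builds the strictly lower-triangular matrix
#     A[i, j] = beta_i * <k_i, k_j> * exp(g_i - g_j)   for i > j (and g_i - g_j <= 0), else 0.
def torch_chunk_scaled_dot_kkt_sketch(K, Beta, G, chunk_size, output_dtype=torch.bfloat16):
    B, S, H, DK = K.shape
    BS = chunk_size
    A = torch.zeros(B, S, H, BS, dtype=output_dtype, device=K.device)
    strict_lower = torch.tril(torch.ones(BS, BS, device=K.device), diagonal=-1).bool()
    for c in range(S // BS):
        sl = slice(c * BS, (c + 1) * BS)
        k = K[:, sl].float()                     # (B, BS, H, DK)
        beta = Beta[:, sl].float()[..., None]    # (B, BS, H, 1)
        g = G[:, sl].float().permute(0, 2, 1)    # (B, H, BS)
        a = torch.einsum("bshk,bthk->bhst", k * beta, k)
        diff = g[..., :, None] - g[..., None, :]
        keep = strict_lower & (diff <= 0)
        a = torch.where(keep, a * torch.exp(diff), torch.zeros_like(a))
        A[:, sl] = a.permute(0, 2, 1, 3).to(output_dtype)
    return A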
def run_test(
B,
S,
H,
DK,
chunk_size,
input_dtype,
output_dtype,
accum_dtype,
use_g,
block_DK,
threads,
num_stages,
):
K, Beta, G = prepare_input(B, S, H, DK, getattr(torch, input_dtype),
getattr(torch, output_dtype), getattr(torch, accum_dtype))
A_ref = prepare_output(B, S, H, chunk_size, getattr(torch, output_dtype))
A_tilelang = prepare_output(B, S, H, chunk_size, getattr(torch, output_dtype))
# reference
if use_g:
A_ref = chunk_scaled_dot_kkt_fwd(
K, Beta, G, chunk_size=chunk_size, output_dtype=getattr(torch, output_dtype))
else:
A_ref = chunk_scaled_dot_kkt_fwd(
K, Beta, None, chunk_size=chunk_size, output_dtype=getattr(torch, output_dtype))
# tilelang
block_S = chunk_size
kernel = tilelang_chunk_scaled_dot_kkt_fwd(B, S, H, DK, chunk_size, input_dtype, output_dtype,
accum_dtype, use_g, block_S, block_DK, threads,
num_stages)
A_tilelang = kernel(K, Beta, G)
try:
torch.testing.assert_close(A_tilelang, A_ref, rtol=1e-2, atol=1e-2)
print("tilelang chunk scaled dot kkt fwd passed √")
except Exception as e:
print("tilelang chunk scaled dot kkt fwd failed ✗")
print(e)
    print("tilelang generated cuda kernel:")
print(kernel.get_kernel_source())
def main():
run_test(
B=1,
S=32768,
H=32,
DK=128,
chunk_size=64,
input_dtype="bfloat16",
output_dtype="bfloat16",
accum_dtype="float32",
use_g=True,
block_DK=64,
threads=128,
num_stages=2)
if __name__ == "__main__":
main()
# Util functions for flash linear attention cumsum
# Reference: fla/ops/utils/cumsum.py
import tilelang
import tilelang.language as T
import sys # noqa: F401
# Add your fla repository path to sys.path
# Currently we use the fla repository from the flash-linear-attention project at commit id f03cb3ae
# sys.path.insert(0, "/home/tzj/flash-linear-attention")
try:
import fla
print(fla.__file__)
from fla.ops.utils.cumsum import chunk_local_cumsum_scalar
except ImportError:
print("fla not found, using tilelang implementation")
fla = None
import torch
@tilelang.jit(
out_idx=[-1],
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True
})
def tilelang_chunk_local_cumsum_scalar(
# task config
B,
S,
H,
chunk_size=64,
is_varlen=False,
head_first=False,
reverse=False,
input_dtype="float16",
output_dtype="float32",
# kernel config
block_S=64,
threads=256,
use_fragment=False,
):
G_shape = (B, H, S) if head_first else (B, S, H)
assert chunk_size == 2**(chunk_size.bit_length() - 1), "chunk_size must be a power of 2"
assert chunk_size == block_S, "chunk_size must be equal to block_S"
@T.prim_func
def kernel(
G: T.Tensor(G_shape, dtype=input_dtype),
G_new: T.Tensor(G_shape, dtype=output_dtype),
):
with T.Kernel(T.ceildiv(S, block_S), B * H, threads=threads) as (bs, bbh):
bb, bh = bbh // H, bbh % H
G_shared = T.alloc_shared((1, block_S), dtype=output_dtype, scope="shared")
if head_first:
T.copy(G[bb, bh, bs * block_S:(bs + 1) * block_S], G_shared)
else:
T.copy(G[bb, bs * block_S:(bs + 1) * block_S, bh], G_shared)
if use_fragment:
G_fragment = T.alloc_fragment((1, block_S), dtype=output_dtype, scope="shared")
T.copy(G_shared, G_fragment)
T.cumsum(G_fragment, dim=1, reverse=reverse)
if head_first:
T.copy(G_fragment, G_new[bb, bh, bs * block_S:(bs + 1) * block_S])
else:
T.copy(G_fragment, G_new[bb, bs * block_S:(bs + 1) * block_S, bh])
else:
T.cumsum(G_shared, dim=1, reverse=reverse)
if head_first:
T.copy(G_shared, G_new[bb, bh, bs * block_S:(bs + 1) * block_S])
else:
T.copy(G_shared, G_new[bb, bs * block_S:(bs + 1) * block_S, bh])
return kernel
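# A minimal plain-PyTorch sketch of the kernel above for the fixed-length case (the helper name
# is illustrative and not part of fla): a (reverse) cumulative sum of the scalar gate within
# each chunk along the sequence dimension.
def torch_chunk_local_cumsum_sketch(G, chunk_size, reverse=False, head_first=False):
    g = G.float()
    seq_dim = 2 if head_first else 1             # (B, H, S) vs (B, S, H)
    num_chunks = g.shape[seq_dim] // chunk_size
    g = g.unflatten(seq_dim, (num_chunks, chunk_size))
    cum_dim = seq_dim + 1
    if reverse:
        g = g.flip((cum_dim,)).cumsum(cum_dim).flip((cum_dim,))
    else:
        g = g.cumsum(cum_dim)
    return g.flatten(seq_dim, seq_dim + 1)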
def prepare_cumsum_input(
B,
S,
H,
dtype,
):
G = torch.randn(B, S, H, dtype=dtype).cuda()
return G
def prepare_cumsum_output(
B,
S,
H,
dtype,
):
G_new = torch.empty(B, S, H, dtype=dtype).cuda()
return G_new
def run_test(
B,
S,
H,
chunk_size,
reverse,
head_first,
input_dtype,
output_dtype,
threads,
use_fragment,
):
G = prepare_cumsum_input(B, S, H, getattr(torch, input_dtype))
G_new_ref = prepare_cumsum_output(B, S, H, getattr(torch, output_dtype))
G_new_tilelang = prepare_cumsum_output(B, S, H, getattr(torch, output_dtype))
# reference cumsum
G_new_ref = chunk_local_cumsum_scalar(
g=G,
chunk_size=chunk_size,
reverse=reverse,
head_first=head_first,
output_dtype=getattr(torch, output_dtype))
# tilelang cumsum
block_S = chunk_size
kernel = tilelang_chunk_local_cumsum_scalar(
B=B,
S=S,
H=H,
chunk_size=chunk_size,
reverse=reverse,
head_first=head_first,
input_dtype=input_dtype,
output_dtype=output_dtype,
block_S=block_S,
threads=threads,
use_fragment=use_fragment,
)
torch.cuda.profiler.start()
G_new_tilelang = kernel(G)
torch.cuda.profiler.stop()
try:
torch.testing.assert_close(G_new_tilelang, G_new_ref, rtol=1e-2, atol=1e-2)
print("tilelang cumsum passed √")
except Exception as e:
print("tilelang cumsum failed ✗")
print(e)
print("G:")
print(G.view(-1))
print("G_new_tilelang:")
print(G_new_tilelang.view(-1))
print("G_new_ref:")
print(G_new_ref.view(-1))
def main():
run_test(
B=1,
S=32768,
H=32,
chunk_size=64,
reverse=True,
head_first=False,
input_dtype="float32",
output_dtype="float32",
threads=256,
use_fragment=False)
if __name__ == "__main__":
main()
# Reference: fla/ops/gated_delta_rule/wy_fast.py
import tilelang
import tilelang.language as T
import sys # noqa: F401
# Add your fla repository path to sys.path
# Currently we use the fla repository from the flash-linear-attention project at commit id f03cb3ae
# sys.path.insert(0, "/home/tzj/flash-linear-attention")
try:
import fla
print(fla.__file__)
from fla.ops.gated_delta_rule.wy_fast import recompute_w_u_fwd
except ImportError:
print("fla not found, using tilelang implementation")
fla = None
import torch
torch.random.manual_seed(1)
def prepare_input(B, S, H, DK, DV, chunk_size, input_dtype, output_dtype, gate_dtype=torch.float32):
BS = chunk_size
K = torch.randn(B, S, H, DK, dtype=input_dtype).cuda()
V = torch.randn(B, S, H, DV, dtype=input_dtype).cuda()
Beta = torch.randn(B, S, H, dtype=input_dtype).cuda()
G = torch.randn(B, S, H, dtype=gate_dtype).cuda()
A = torch.randn(B, S, H, BS, dtype=output_dtype).cuda()
return K, V, Beta, G, A
def prepare_output(
B,
S,
H,
DK,
DV,
output_dtype,
):
W = torch.empty(B, S, H, DK, dtype=output_dtype).cuda()
U = torch.empty(B, S, H, DV, dtype=output_dtype).cuda()
return W, U
@tilelang.jit(out_idx=[-2, -1])
def tilelang_recompute_w_u_fwd(
# task config
B,
S,
H,
DK,
DV,
input_dtype,
output_dtype,
gate_dtype,
accum_dtype,
chunk_size,
# kernel config
block_S=64,
block_DK=64,
block_DV=64,
threads=256,
num_stages=0,
):
K_shape = (B, S, H, DK)
V_shape = (B, S, H, DV)
Beta_shape = (B, S, H)
assert chunk_size == block_S, "chunk_size must be equal to block_S"
BS = chunk_size
G_shape = (B, S, H)
A_shape = (B, S, H, BS)
@T.prim_func
def kernel(
K: T.Tensor(K_shape, dtype=input_dtype),
V: T.Tensor(V_shape, dtype=input_dtype),
Beta: T.Tensor(Beta_shape, dtype=input_dtype),
G: T.Tensor(G_shape, dtype=gate_dtype),
A: T.Tensor(A_shape, dtype=output_dtype),
W: T.Tensor(K_shape, dtype=output_dtype),
U: T.Tensor(V_shape, dtype=output_dtype),
):
with T.Kernel(T.ceildiv(S, block_S), B * H, threads=threads) as (bs, bbh):
bb, bh = bbh // H, bbh % H
Beta_shared = T.alloc_shared((block_S,), dtype=input_dtype, scope="shared")
K_shared = T.alloc_shared((block_S, block_DK), dtype=input_dtype)
V_shared = T.alloc_shared((block_S, block_DV), dtype=input_dtype)
G_shared = T.alloc_shared((block_S,), dtype=gate_dtype, scope="shared")
A_shared = T.alloc_shared((block_S, block_S), dtype=output_dtype)
W_fragment = T.alloc_fragment((block_S, block_DK), dtype=accum_dtype)
U_fragment = T.alloc_fragment((block_S, block_DV), dtype=accum_dtype)
W_shared = T.alloc_shared((block_S, block_DK), dtype=output_dtype)
U_shared = T.alloc_shared((block_S, block_DV), dtype=output_dtype)
W_Beta_shared = T.alloc_shared((block_S, block_DK), dtype=input_dtype)
U_Beta_shared = T.alloc_shared((block_S, block_DV), dtype=input_dtype)
T.annotate_layout({
K_shared: tilelang.layout.make_swizzled_layout(K_shared),
V_shared: tilelang.layout.make_swizzled_layout(V_shared),
A_shared: tilelang.layout.make_swizzled_layout(A_shared),
W_shared: tilelang.layout.make_swizzled_layout(W_shared),
U_shared: tilelang.layout.make_swizzled_layout(U_shared),
W_Beta_shared: tilelang.layout.make_swizzled_layout(W_Beta_shared),
U_Beta_shared: tilelang.layout.make_swizzled_layout(U_Beta_shared),
})
T.disable_warp_group_reg_alloc()
for i_s in T.Parallel(block_S):
Beta_shared[i_s] = Beta[bb, bs * block_S + i_s, bh]
G_shared[i_s] = T.exp(G[bb, bs * block_S + i_s, bh])
T.copy(A[bb, bs * block_S:(bs + 1) * block_S, bh, :], A_shared)
for i_v in T.Pipelined(T.ceildiv(DV, block_DV), num_stages=num_stages):
T.copy(
V[bb, bs * block_S:(bs + 1) * block_S, bh, i_v * block_DV:(i_v + 1) * block_DV],
V_shared)
for i_s, i_v2 in T.Parallel(block_S, block_DV):
U_Beta_shared[i_s, i_v2] = V_shared[i_s, i_v2] * Beta_shared[i_s]
T.gemm(A_shared, U_Beta_shared, U_fragment, clear_accum=True)
# First copy to smem, then copy to gmem to reduce U2RU instructions
T.copy(U_fragment, U_shared)
T.copy(
U_shared, U[bb, bs * block_S:(bs + 1) * block_S, bh,
i_v * block_DV:(i_v + 1) * block_DV])
for i_k in T.Pipelined(T.ceildiv(DK, block_DK), num_stages=num_stages):
T.copy(
K[bb, bs * block_S:(bs + 1) * block_S, bh, i_k * block_DK:(i_k + 1) * block_DK],
K_shared)
for i_s, i_k2 in T.Parallel(block_S, block_DK):
W_Beta_shared[i_s,
i_k2] = K_shared[i_s, i_k2] * Beta_shared[i_s] * G_shared[i_s]
T.gemm(A_shared, W_Beta_shared, W_fragment, clear_accum=True)
# First copy to smem, then copy to gmem to reduce U2RU instructions
T.copy(W_fragment, W_shared)
T.copy(
W_shared, W[bb, bs * block_S:(bs + 1) * block_S, bh,
i_k * block_DK:(i_k + 1) * block_DK])
return kernel
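# A minimal plain-PyTorch sketch of the kernel above (the helper name is illustrative and not
# part of fla). Per chunk it recomputes
#     U = A @ (V * beta)    and    W = A @ (K * beta * exp(g)).
def torch_recompute_w_u_sketch(K, V, Beta, G, A, chunk_size):
    B, S, H, DK = K.shape
    BS = chunk_size
    W = torch.empty_like(K)
    U = torch.empty_like(V)
    for c in range(S // BS):
        sl = slice(c * BS, (c + 1) * BS)
        a = A[:, sl].float().permute(0, 2, 1, 3)          # (B, H, BS, BS)
        beta = Beta[:, sl].float()[..., None]             # (B, BS, H, 1)
        g_exp = torch.exp(G[:, sl].float())[..., None]    # (B, BS, H, 1)
        U[:, sl] = torch.einsum("bhst,bthv->bshv", a, V[:, sl].float() * beta).to(U.dtype)
        W[:, sl] = torch.einsum("bhst,bthk->bshk", a, K[:, sl].float() * beta * g_exp).to(W.dtype)
    return W, U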
def run_test(
B,
S,
H,
DK,
DV,
chunk_size,
input_dtype,
output_dtype,
gate_dtype,
accum_dtype,
block_DK,
block_DV,
threads,
num_stages,
):
K, V, Beta, G, A = prepare_input(
B,
S,
H,
DK,
DV,
chunk_size,
getattr(torch, input_dtype),
getattr(torch, output_dtype),
gate_dtype=getattr(torch, gate_dtype))
W_ref, U_ref = prepare_output(B, S, H, DK, DV, getattr(torch, output_dtype))
W_tilelang, U_tilelang = prepare_output(B, S, H, DK, DV, getattr(torch, output_dtype))
# reference
W_ref, U_ref = recompute_w_u_fwd(K, V, Beta, G, A, None)
# tilelang
block_S = chunk_size
kernel = tilelang_recompute_w_u_fwd(
B,
S,
H,
DK,
DV,
input_dtype,
output_dtype,
gate_dtype,
accum_dtype,
chunk_size,
block_S=block_S,
block_DK=block_DK,
block_DV=block_DV,
threads=threads,
num_stages=num_stages)
print(kernel.get_kernel_source())
W_tilelang, U_tilelang = kernel(K, V, Beta, G, A)
try:
torch.testing.assert_close(W_tilelang, W_ref, rtol=1e-2, atol=1e-2)
print("tilelang recompute w passed √")
except Exception as e:
print("tilelang recompute w failed ✗")
print(e)
try:
torch.testing.assert_close(U_tilelang, U_ref, rtol=1e-2, atol=1e-2)
print("tilelang recompute u passed √")
except Exception as e:
print("tilelang recompute u failed ✗")
print(e)
def main():
run_test(
B=1,
S=32768,
H=32,
DK=128,
DV=128,
chunk_size=64,
input_dtype="bfloat16",
output_dtype="bfloat16",
gate_dtype="float32",
accum_dtype="float32",
block_DK=64,
block_DV=32,
threads=128,
num_stages=3)
if __name__ == "__main__":
main()
# Reference: fla/ops/gated_delta_rule/wy_fast.py
import sys # noqa: F401
import tilelang
import tilelang.language as T
# Add your fla repository path to sys.path
# Currently we use the fla repository from the flash-linear-attention project at commit id 00000000
# sys.path.insert(0, "/home/tzj/flash-linear-attention")
try:
import fla
print(fla.__file__)
from fla.ops.gated_delta_rule.wy_fast import bwd_prepare_wy_repr
except ImportError:
print("fla not found, using tilelang implementation")
fla = None
import torch
import torch.nn.functional as F
torch.random.manual_seed(0)
torch.set_printoptions(profile="full")
def prepare_input_fake(
B,
S,
H,
DK,
DV,
chunk_size,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
state_dtype,
):
BS = chunk_size
K = torch.ones(B, S, H, DK, dtype=input_dtype).cuda()
V = torch.ones(B, S, H, DV, dtype=input_dtype).cuda()
Beta = torch.ones(B, S, H, dtype=input_dtype).cuda()
G = torch.ones(B, S, H, dtype=gate_dtype).cuda()
A = torch.ones(B, S, H, BS, dtype=input_dtype).cuda()
dw = torch.ones(B, S, H, DK, dtype=input_dtype).cuda()
du = torch.ones(B, S, H, DV, dtype=input_dtype).cuda()
return K, V, Beta, G, A, dw, du
def prepare_input(
B,
S,
H,
DK,
DV,
chunk_size,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
state_dtype,
):
BS = chunk_size
K = torch.randn(B, S, H, DK, dtype=input_dtype).cuda()
K = F.normalize(K, dim=-1, p=2)
V = torch.randn(B, S, H, DV, dtype=input_dtype).cuda()
V = F.normalize(V, dim=-1, p=2)
Beta = torch.randn(B, S, H, dtype=input_dtype).cuda()
G = torch.randn(B, S, H, dtype=gate_dtype).cuda()
A = torch.randn(B, S, H, BS, dtype=input_dtype).cuda()
dw = torch.randn(B, S, H, DK, dtype=input_dtype).cuda()
du = torch.randn(B, S, H, DV, dtype=input_dtype).cuda()
return K, V, Beta, G, A, dw, du
def prepare_output(
B,
S,
H,
DK,
DV,
chunk_size,
output_dtype,
gate_dtype,
state_dtype,
):
dk = torch.empty(B, S, H, DK, dtype=output_dtype).cuda()
dv = torch.empty(B, S, H, DV, dtype=output_dtype).cuda()
dbeta = torch.empty(B, S, H, dtype=output_dtype).cuda()
dg = torch.empty(B, S, H, dtype=gate_dtype).cuda()
return dk, dv, dbeta, dg
@tilelang.jit(
out_idx=[-5, -4, -3, -2, -1],
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True
})
def tilelang_wy_fast_bwd(
# task config
B,
S,
H,
DK,
DV,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
state_dtype,
chunk_size,
# kernel config
block_DK=64,
block_DV=64,
threads=128,
num_stages=0,
):
block_S = chunk_size
BS = block_S
K_shape = (B, S, H, DK)
V_shape = (B, S, H, DV)
Beta_shape = (B, S, H)
G_shape = (B, S, H)
A_shape = (B, S, H, BS)
dw_shape = (B, S, H, DK)
du_shape = (B, S, H, DV)
dk_shape = (B, S, H, DK)
dv_shape = (B, S, H, DV)
dbeta_shape = (B, S, H)
dg_shape = (B, S, H)
dA_shape = (B, S, H, BS)
@T.prim_func
def kernel(
# input
K: T.Tensor(K_shape, dtype=input_dtype),
V: T.Tensor(V_shape, dtype=input_dtype),
Beta: T.Tensor(Beta_shape, dtype=input_dtype),
G: T.Tensor(G_shape, dtype=gate_dtype),
A: T.Tensor(A_shape, dtype=input_dtype),
dw: T.Tensor(dw_shape, dtype=input_dtype),
du: T.Tensor(du_shape, dtype=input_dtype),
# output
dA: T.Tensor(dA_shape, dtype=input_dtype),
dk: T.Tensor(dk_shape, dtype=output_dtype),
dv: T.Tensor(dv_shape, dtype=output_dtype),
dbeta: T.Tensor(dbeta_shape, dtype=output_dtype),
dg: T.Tensor(dg_shape, dtype=gate_dtype),
):
with T.Kernel(T.ceildiv(S, block_S), B * H, threads=threads) as (bs, bbh):
bb, bh = bbh // H, bbh % H
A_shared = T.alloc_shared((block_S, block_S), dtype=input_dtype)
K_shared = T.alloc_shared((block_S, block_DK), dtype=input_dtype)
K_shared_beta_g = T.alloc_shared((block_S, block_DK), dtype=input_dtype)
V_shared = T.alloc_shared((block_S, block_DV), dtype=input_dtype)
V_shared_beta = T.alloc_shared((block_S, block_DV), dtype=input_dtype)
Beta_shared = T.alloc_shared((block_S,), dtype=input_dtype)
G_shared = T.alloc_shared((block_S,), dtype=gate_dtype)
G_shared_exp = T.alloc_shared((block_S,), dtype=gate_dtype)
dw_shared = T.alloc_shared((block_S, block_DK), dtype=input_dtype)
du_shared = T.alloc_shared((block_S, block_DV), dtype=input_dtype)
dA_fragment = T.alloc_fragment((block_S, block_S), dtype=accum_dtype)
dk_fragment = T.alloc_fragment((block_S, block_DK), dtype=accum_dtype)
dk_fragment_beta_g = T.alloc_fragment((block_S, block_DK), dtype=accum_dtype)
dv_fragment = T.alloc_fragment((block_S, block_DV), dtype=accum_dtype)
dv_fragment_beta = T.alloc_fragment((block_S, block_DV), dtype=accum_dtype)
dbeta_fragment_k = T.alloc_fragment((block_S,), dtype=accum_dtype)
dbeta_fragment_v = T.alloc_fragment((block_S,), dtype=accum_dtype)
dbeta_fragment_reduce_tmpk = T.alloc_fragment((block_S, block_DK), dtype=accum_dtype)
dbeta_fragment_reduce_tmpv = T.alloc_fragment((block_S, block_DV), dtype=accum_dtype)
dg_fragment = T.alloc_fragment((block_S,), dtype=gate_dtype)
dg_fragment_reduce_tmp = T.alloc_fragment((block_S, block_DK), dtype=gate_dtype)
T.use_swizzle(10)
T.clear(dA_fragment)
T.clear(dk_fragment)
T.clear(dk_fragment_beta_g)
T.clear(dv_fragment)
T.clear(dv_fragment_beta)
T.clear(dbeta_fragment_k)
T.clear(dbeta_fragment_v)
T.clear(dg_fragment)
T.copy(A[bb, bs * block_S:(bs + 1) * block_S, bh, :], A_shared)
for i_s in T.Parallel(block_S):
Beta_shared[i_s] = Beta[bb, bs * block_S + i_s, bh]
G_shared[i_s] = G[bb, bs * block_S + i_s, bh]
G_shared_exp[i_s] = T.exp(G[bb, bs * block_S + i_s, bh])
# Update dk
for i_k in T.Pipelined(T.ceildiv(DK, block_DK), num_stages=num_stages):
T.copy(
K[bb, bs * block_S:(bs + 1) * block_S, bh, i_k * block_DK:(i_k + 1) * block_DK],
K_shared)
for i_s, i_k2 in T.Parallel(block_S, block_DK):
K_shared_beta_g[i_s,
i_k2] = K_shared[i_s,
i_k2] * Beta_shared[i_s] * G_shared_exp[i_s]
T.copy(
dw[bb, bs * block_S:(bs + 1) * block_S, bh,
i_k * block_DK:(i_k + 1) * block_DK], dw_shared)
T.gemm(dw_shared, K_shared_beta_g, dA_fragment, transpose_B=True)
T.gemm(A_shared, dw_shared, dk_fragment_beta_g, clear_accum=True, transpose_A=True)
for i_s, i_k2 in T.Parallel(block_S, block_DK):
dk_fragment[
i_s,
i_k2] = dk_fragment_beta_g[i_s, i_k2] * Beta_shared[i_s] * G_shared_exp[i_s]
# for i_s, i_k2 in T.Parallel(block_S, block_DK):
# dbeta_fragment[i_s] = dbeta_fragment[i_s] + dk_fragment_beta_g[i_s, i_k2] * K_shared[i_s, i_k2] * G_shared_exp[i_s]
for i_s, i_k2 in T.Parallel(block_S, block_DK):
dbeta_fragment_reduce_tmpk[i_s, i_k2] = dk_fragment_beta_g[
i_s, i_k2] * K_shared[i_s, i_k2] * G_shared_exp[i_s]
T.reduce_sum(dbeta_fragment_reduce_tmpk, dbeta_fragment_k, dim=1, clear=False)
# for i_s, i_k2 in T.Parallel(block_S, block_DK):
# dg_fragment[i_s] = dg_fragment[i_s] + dk_fragment_beta_g[i_s, i_k2] * K_shared[i_s, i_k2] * G_shared_exp[i_s] * Beta_shared[i_s]
for i_s, i_k2 in T.Parallel(block_S, block_DK):
dg_fragment_reduce_tmp[i_s, i_k2] = dk_fragment_beta_g[i_s, i_k2] * K_shared[
i_s, i_k2] * G_shared_exp[i_s] * Beta_shared[i_s]
T.reduce_sum(dg_fragment_reduce_tmp, dg_fragment, dim=1, clear=False)
# correct dk
T.copy(
dk_fragment, dk[bb, bs * block_S:(bs + 1) * block_S, bh,
i_k * block_DK:(i_k + 1) * block_DK])
# Update dv
for i_v in T.Pipelined(T.ceildiv(DV, block_DV), num_stages=num_stages):
T.copy(
V[bb, bs * block_S:(bs + 1) * block_S, bh, i_v * block_DV:(i_v + 1) * block_DV],
V_shared)
for i_s, i_v2 in T.Parallel(block_S, block_DV):
V_shared_beta[i_s, i_v2] = V_shared[i_s, i_v2] * Beta_shared[i_s]
T.copy(
du[bb, bs * block_S:(bs + 1) * block_S, bh,
i_v * block_DV:(i_v + 1) * block_DV], du_shared)
T.gemm(du_shared, V_shared_beta, dA_fragment, transpose_B=True)
T.gemm(A_shared, du_shared, dv_fragment_beta, clear_accum=True, transpose_A=True)
for i_s, i_v2 in T.Parallel(block_S, block_DV):
dv_fragment[i_s, i_v2] = dv_fragment_beta[i_s, i_v2] * Beta_shared[i_s]
# for i_s, i_v2 in T.Parallel(block_S, block_DV):
# dbeta_fragment[i_s] = dbeta_fragment[i_s] + dv_fragment_beta[i_s, i_v2] * V_shared[i_s, i_v2]
for i_s, i_v2 in T.Parallel(block_S, block_DV):
dbeta_fragment_reduce_tmpv[i_s,
i_v2] = dv_fragment_beta[i_s, i_v2] * V_shared[i_s,
i_v2]
T.reduce_sum(dbeta_fragment_reduce_tmpv, dbeta_fragment_v, dim=1, clear=False)
T.copy(
dv_fragment, dv[bb, bs * block_S:(bs + 1) * block_S, bh,
i_v * block_DV:(i_v + 1) * block_DV])
# Temporary store dbeta, dg and dA
for i_s in T.Parallel(block_S):
dbeta[bb, bs * block_S + i_s, bh] = dbeta_fragment_k[i_s] + dbeta_fragment_v[i_s]
dg[bb, bs * block_S + i_s, bh] = dg_fragment[i_s]
# correct dA
T.copy(dA_fragment, dA[bb, bs * block_S:(bs + 1) * block_S, bh, :])
return kernel
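# Hedged summary of the kernel above (a reading of the code, not an authoritative spec): the first
# pass accumulates dA = dw @ (K * beta * exp(g))^T + du @ (V * beta)^T, writes the partial gradients
#     dk = (A^T @ dw) * beta * exp(g)    and    dv = (A^T @ du) * beta,
# and stores per-token partial sums for dbeta and dg; the split kernel below finishes dA, dk,
# dbeta and the dg correction terms.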
@tilelang.jit(
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True
})
def tilelang_wy_fast_bwd_split(
# task config
B,
S,
H,
DK,
DV,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
state_dtype,
chunk_size,
# kernel config
block_DK=64,
block_DV=64,
threads=128,
num_stages=0,
):
block_S = chunk_size
BS = block_S
K_shape = (B, S, H, DK)
V_shape = (B, S, H, DV)
Beta_shape = (B, S, H)
G_shape = (B, S, H)
A_shape = (B, S, H, BS)
dw_shape = (B, S, H, DK)
du_shape = (B, S, H, DV)
dk_shape = (B, S, H, DK)
dv_shape = (B, S, H, DV)
dbeta_shape = (B, S, H)
dA_shape = (B, S, H, BS)
@T.prim_func
def kernel(
# input
K: T.Tensor(K_shape, dtype=input_dtype),
V: T.Tensor(V_shape, dtype=input_dtype),
Beta: T.Tensor(Beta_shape, dtype=input_dtype),
G: T.Tensor(G_shape, dtype=gate_dtype),
A: T.Tensor(A_shape, dtype=input_dtype),
dw: T.Tensor(dw_shape, dtype=input_dtype),
du: T.Tensor(du_shape, dtype=input_dtype),
dA: T.Tensor(dA_shape, dtype=input_dtype),
dk: T.Tensor(dk_shape, dtype=output_dtype),
dv: T.Tensor(dv_shape, dtype=output_dtype),
dbeta_k: T.Tensor(dbeta_shape, dtype=output_dtype),
dg_A_positive: T.Tensor(dA_shape, dtype=gate_dtype),
dg_A_negative: T.Tensor(dA_shape, dtype=gate_dtype),
):
with T.Kernel(T.ceildiv(S, block_S), B * H, threads=threads) as (bs, bbh):
bb, bh = bbh // H, bbh % H
A_shared = T.alloc_shared((block_S, block_S), dtype=input_dtype)
A_fragment = T.alloc_fragment((block_S, block_S), dtype=accum_dtype)
K_shared = T.alloc_shared((block_S, block_DK), dtype=input_dtype)
K_shared_beta = T.alloc_shared((block_S, block_DK), dtype=input_dtype)
dA_shared = T.alloc_shared((block_S, block_S), dtype=input_dtype)
dA_fragment = T.alloc_fragment((block_S, block_S), dtype=accum_dtype)
dA_A_fragment = T.alloc_fragment((block_S, block_S), dtype=accum_dtype)
dA_A_fragment_1 = T.alloc_fragment((block_S,), dtype=accum_dtype)
dA_A_fragment_2 = T.alloc_fragment((block_S,), dtype=accum_dtype)
dk_shared = T.alloc_shared((block_S, block_DK), dtype=input_dtype)
dk_shared_beta = T.alloc_shared((block_S, block_DK), dtype=input_dtype)
dk_fragment = T.alloc_fragment((block_S, block_DK), dtype=accum_dtype)
dk_fragment_beta = T.alloc_fragment((block_S, block_DK), dtype=accum_dtype)
Beta_shared = T.alloc_shared((block_S,), dtype=input_dtype)
dbeta_fragment_reduce_tmpk = T.alloc_fragment((block_S, block_DK), dtype=accum_dtype)
dbeta_fragment_k = T.alloc_fragment((block_S,), dtype=accum_dtype)
G_shared = T.alloc_shared((block_S,), dtype=gate_dtype)
G_shared_exp = T.alloc_shared((block_S,), dtype=gate_dtype)
T.clear(dbeta_fragment_reduce_tmpk)
T.clear(dbeta_fragment_k)
T.clear(dA_A_fragment_1)
T.clear(dA_A_fragment_2)
T.copy(A[bb, bs * block_S:(bs + 1) * block_S, bh, :], A_shared)
for i_s in T.Parallel(block_S):
Beta_shared[i_s] = Beta[bb, bs * block_S + i_s, bh]
G_shared[i_s] = G[bb, bs * block_S + i_s, bh]
for i_s in T.Parallel(block_S):
G_shared_exp[i_s] = T.exp(G_shared[i_s])
# Load intermediate results
# for i_s in T.Parallel(block_S):
# dbeta_fragment[i_s] = dbeta[bb, bs * block_S + i_s, bh]
# dg_fragment[i_s] = dg[bb, bs * block_S + i_s, bh]
T.copy(dA[bb, bs * block_S:(bs + 1) * block_S, bh, :], dA_shared)
# T.copy(dA_shared, dA[bb, bs * block_S:(bs + 1) * block_S, bh, :])
# Update dA
T.copy(dA_shared, dA_fragment)
for i_s1, i_s2 in T.Parallel(block_S, block_S):
with T.If(i_s1 <= i_s2): # noqa: SIM117
with T.Then():
dA_fragment[i_s1, i_s2] = 0
T.copy(dA_fragment, dA_shared)
T.gemm(dA_shared, A_shared, dA_fragment, clear_accum=True, transpose_B=True)
T.copy(dA_fragment, dA_shared)
T.gemm(A_shared, dA_shared, dA_fragment, clear_accum=True, transpose_A=True)
for i_s1, i_s2 in T.Parallel(block_S, block_S):
with T.If(i_s1 <= i_s2):
with T.Then():
dA_fragment[i_s1, i_s2] = 0
with T.Else():
dA_fragment[i_s1, i_s2] = -dA_fragment[i_s1, i_s2]
for i_s1, i_s2 in T.Parallel(block_S, block_S):
with T.If(G[bb, bs * block_S + i_s1, bh] - G[bb, bs * block_S + i_s2, bh] <= 0):
with T.Then():
dA_fragment[i_s1, i_s2] *= T.exp(G[bb, bs * block_S + i_s1, bh] -
G[bb, bs * block_S + i_s2, bh])
with T.Else():
dA_fragment[i_s1, i_s2] = 0
T.copy(dA_fragment, dA_shared)
# acceptable dA diff
# T.copy(dA_fragment, dA[bb, bs * block_S:(bs + 1) * block_S, bh, :])
# Update dk using previous dk
T.clear(A_fragment)
for i_k in T.Pipelined(T.ceildiv(DK, block_DK), num_stages=num_stages):
T.copy(
K[bb, bs * block_S:(bs + 1) * block_S, bh, i_k * block_DK:(i_k + 1) * block_DK],
K_shared)
T.copy(
dk[bb, bs * block_S:(bs + 1) * block_S, bh,
i_k * block_DK:(i_k + 1) * block_DK], dk_shared)
T.copy(dk_shared, dk_fragment)
for i_s, i_k2 in T.Parallel(block_S, block_DK):
K_shared_beta[i_s, i_k2] = K_shared[i_s, i_k2] * Beta_shared[i_s]
T.gemm(K_shared_beta, K_shared, A_fragment, transpose_B=True)
T.gemm(dA_shared, K_shared, dk_fragment_beta, clear_accum=True)
# for i_s, i_k2 in T.Parallel(block_S, block_DK):
# dbeta_fragment[i_s] = dbeta_fragment[i_s] + dk_fragment_beta[i_s, i_k2] * K_shared[i_s, i_k2]
for i_s, i_k2 in T.Parallel(block_S, block_DK):
dbeta_fragment_reduce_tmpk[i_s,
i_k2] = dk_fragment_beta[i_s, i_k2] * K_shared[i_s,
i_k2]
T.reduce_sum(dbeta_fragment_reduce_tmpk, dbeta_fragment_k, dim=1, clear=False)
T.gemm(dA_shared, K_shared_beta, dk_fragment, transpose_A=True)
for i_s, i_k2 in T.Parallel(block_S, block_DK):
dk_shared_beta[i_s, i_k2] = dk_fragment_beta[i_s, i_k2] * Beta_shared[i_s]
for i_s, i_k2 in T.Parallel(block_S, block_DK):
dk_fragment[i_s, i_k2] = dk_fragment[i_s, i_k2] + dk_shared_beta[i_s, i_k2]
T.copy(
dk_fragment, dk[bb, bs * block_S:(bs + 1) * block_S, bh,
i_k * block_DK:(i_k + 1) * block_DK])
# Update dg and dbeta
T.copy(A_fragment, A_shared)
for i_s1, i_s2 in T.Parallel(block_S, block_S):
dA_A_fragment[i_s1, i_s2] = dA_fragment[i_s1, i_s2] * A_fragment[i_s1, i_s2]
            # NOTE: reduce operations are currently not supported on shared memory
            # FIXME: reduce produces incorrect results when dim != -1
T.reduce_sum(dA_A_fragment, dA_A_fragment_1, dim=1)
T.reduce_sum(dA_A_fragment, dA_A_fragment_2, dim=0)
for i_s1, i_s2 in T.Parallel(block_S, block_S):
dg_A_positive[bb, bs * block_S + i_s1, bh, i_s2] = dA_A_fragment[i_s1, i_s2]
dg_A_negative[bb, bs * block_S + i_s2, bh, i_s1] = dA_A_fragment[i_s1, i_s2]
for i_s in T.Parallel(block_S):
dbeta_k[bb, bs * block_S + i_s, bh] = dbeta_fragment_k[i_s]
return kernel
def run_test(
B,
S,
H,
DK,
DV,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
state_dtype,
chunk_size,
block_DK=64,
block_DV=64,
threads=128,
num_stages=0,
):
K, V, Beta, G, A, dw, du = prepare_input(B, S, H, DK, DV, chunk_size,
getattr(torch, input_dtype),
getattr(torch, output_dtype),
getattr(torch,
accum_dtype), getattr(torch, gate_dtype),
getattr(torch, state_dtype))
dk_ref, dv_ref, dbeta_ref, dg_ref = prepare_output(B, S, H, DK, DV, chunk_size,
getattr(torch, output_dtype),
getattr(torch, gate_dtype),
getattr(torch, state_dtype))
dk_tilelang, dv_tilelang, dbeta_tilelang, dg_tilelang = prepare_output(
B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype),
getattr(torch, state_dtype))
BS = chunk_size
dA_tilelang = torch.empty(B, S, H, BS, dtype=getattr(torch, input_dtype)).cuda()
dbeta_tilelang_k = torch.empty(B, S, H, dtype=getattr(torch, output_dtype)).cuda()
dg_tilelang_A_positive = torch.empty(B, S, H, BS, dtype=getattr(torch, gate_dtype)).cuda()
dg_tilelang_A_negative = torch.empty(B, S, H, BS, dtype=getattr(torch, gate_dtype)).cuda()
# ref
dk_ref, dv_ref, dbeta_ref, dg_ref = bwd_prepare_wy_repr(
K, V, G, Beta, A, dw, du, cu_seqlens=None)
# tilelang
kernel = tilelang_wy_fast_bwd(B, S, H, DK, DV, input_dtype, output_dtype, accum_dtype,
gate_dtype, state_dtype, chunk_size, block_DK, block_DV, threads,
num_stages)
dA_tilelang, dk_tilelang, dv_tilelang, dbeta_tilelang, dg_tilelang = kernel(
K, V, Beta, G, A, dw, du)
torch.cuda.synchronize()
kernel_split = tilelang_wy_fast_bwd_split(B, S, H, DK, DV, input_dtype, output_dtype,
accum_dtype, gate_dtype, state_dtype, chunk_size,
block_DK, block_DV, threads, num_stages)
kernel_split(K, V, Beta, G, A, dw, du, dA_tilelang, dk_tilelang, dv_tilelang, dbeta_tilelang_k,
dg_tilelang_A_positive, dg_tilelang_A_negative)
torch.cuda.synchronize()
dbeta_tilelang = dbeta_tilelang_k + dbeta_tilelang
dg_tilelang = dg_tilelang + dg_tilelang_A_positive.sum(dim=-1) - dg_tilelang_A_negative.sum(
dim=-1)
from utils import assert_similar
assert_similar(dk_ref, dk_tilelang, eps=1e-5, name="dk", raise_assert=False)
assert_similar(dv_ref, dv_tilelang, eps=1e-5, name="dv", raise_assert=False)
assert_similar(dbeta_ref, dbeta_tilelang, eps=1e-5, name="dbeta", raise_assert=False)
assert_similar(dg_ref, dg_tilelang, eps=1e-5, name="dg", raise_assert=False)
def main():
DK = 128
DV = 128
run_test(
B=1,
S=32768,
H=8,
DK=DK,
DV=DV,
input_dtype="bfloat16",
output_dtype="bfloat16",
accum_dtype="float32",
gate_dtype="float32",
state_dtype="float32",
chunk_size=64,
block_DK=32,
block_DV=32,
threads=128,
num_stages=0,
)
if __name__ == "__main__":
main()
import tilelang.testing
import torch
B = 1
S = 1024  # kept small; for testing only
H = 32
DK = 128
DV = 128
input_dtype = "bfloat16"
output_dtype = "bfloat16"
accum_dtype = "float32"
gate_dtype = "float32"
state_dtype = "float32"
chunk_size = 64
use_g = True
use_initial_state = True
store_final_state = True
use_final_state_gradient = True
save_new_value = True
block_DK = 64
block_DV = 32
threads = 128
num_stages = 1
def test_example_wy_fast_compilation():
from example_wy_fast import tilelang_recompute_w_u_fwd, prepare_input
K, V, Beta, G, A = prepare_input(
B,
S,
H,
DK,
DV,
chunk_size,
getattr(torch, input_dtype),
getattr(torch, output_dtype),
gate_dtype=getattr(torch, gate_dtype))
# tilelang
block_S = chunk_size
kernel = tilelang_recompute_w_u_fwd(
B,
S,
H,
DK,
DV,
input_dtype,
output_dtype,
gate_dtype,
accum_dtype,
chunk_size,
block_S=block_S,
block_DK=block_DK,
block_DV=block_DV,
threads=threads,
num_stages=num_stages)
print(kernel.get_kernel_source())
W_tilelang, U_tilelang = kernel(K, V, Beta, G, A)
def test_example_wy_fast_bwd_split_compilation():
from example_wy_fast_bwd_split import tilelang_wy_fast_bwd, tilelang_wy_fast_bwd_split, prepare_input, prepare_output
K, V, Beta, G, A, dw, du = prepare_input(B, S, H, DK, DV, chunk_size,
getattr(torch, input_dtype),
getattr(torch, output_dtype),
getattr(torch,
accum_dtype), getattr(torch, gate_dtype),
getattr(torch, state_dtype))
dk_tilelang, dv_tilelang, dbeta_tilelang, dg_tilelang = prepare_output(
B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype),
getattr(torch, state_dtype))
BS = chunk_size
dA_tilelang = torch.empty(B, S, H, BS, dtype=getattr(torch, input_dtype)).cuda()
dbeta_tilelang_k = torch.empty(B, S, H, dtype=getattr(torch, output_dtype)).cuda()
dg_tilelang_A_positive = torch.empty(B, S, H, BS, dtype=getattr(torch, gate_dtype)).cuda()
dg_tilelang_A_negative = torch.empty(B, S, H, BS, dtype=getattr(torch, gate_dtype)).cuda()
# tilelang
kernel = tilelang_wy_fast_bwd(B, S, H, DK, DV, input_dtype, output_dtype, accum_dtype,
gate_dtype, state_dtype, chunk_size, block_DK, block_DV, threads,
num_stages)
dA_tilelang, dk_tilelang, dv_tilelang, dbeta_tilelang, dg_tilelang = kernel(
K, V, Beta, G, A, dw, du)
torch.cuda.synchronize()
kernel_split = tilelang_wy_fast_bwd_split(B, S, H, DK, DV, input_dtype, output_dtype,
accum_dtype, gate_dtype, state_dtype, chunk_size,
block_DK, block_DV, threads, num_stages)
kernel_split(K, V, Beta, G, A, dw, du, dA_tilelang, dk_tilelang, dv_tilelang, dbeta_tilelang_k,
dg_tilelang_A_positive, dg_tilelang_A_negative)
torch.cuda.synchronize()
dbeta_tilelang = dbeta_tilelang_k + dbeta_tilelang
dg_tilelang = dg_tilelang + dg_tilelang_A_positive.sum(dim=-1) - dg_tilelang_A_negative.sum(
dim=-1)
def test_example_chunk_o_compilation():
from example_chunk_o import tilelang_chunk_fwd_o, prepare_input
Q, K, V, HIDDEN, G = prepare_input(B, S, H, DK, DV, chunk_size, getattr(torch, input_dtype),
getattr(torch, output_dtype), getattr(torch, accum_dtype),
getattr(torch, gate_dtype))
scale = 1.0 / DK**0.5
block_S = chunk_size
kernel = tilelang_chunk_fwd_o(B, S, H, DK, DV, input_dtype, output_dtype, accum_dtype,
gate_dtype, chunk_size, scale, use_g, block_S, block_DK, block_DV,
threads, num_stages)
O_tilelang = kernel(Q, K, V, HIDDEN, G) # noqa: F841
def test_example_chunk_o_bwd_compilation():
from example_chunk_o_bwd import tilelang_chunk_o_bwd_dqkwg, prepare_input
Q, K, V, h, G, dO, dh, dv, W = prepare_input(B, S, H, DK, DV, chunk_size,
getattr(torch, input_dtype),
getattr(torch, output_dtype),
getattr(torch, accum_dtype),
getattr(torch, gate_dtype),
getattr(torch, state_dtype))
kernel = tilelang_chunk_o_bwd_dqkwg(B, S, H, DK, DV, input_dtype, output_dtype, accum_dtype,
gate_dtype, state_dtype, chunk_size, 1.0, use_g, True,
block_DK, block_DV, threads, num_stages)
dq_tilelang, dk_tilelang, dw_tilelang, dg_tilelang = kernel(Q, K, V, h, G, dO, dh, dv,
W) # noqa: F841
if use_g:
dg_tilelang = dg_tilelang.sum(dim=0)
def test_example_chunk_scaled_dot_kkt_compilation():
from example_chunk_scaled_dot_kkt import tilelang_chunk_scaled_dot_kkt_fwd, prepare_input
K, Beta, G = prepare_input(B, S, H, DK, getattr(torch, input_dtype),
getattr(torch, output_dtype), getattr(torch, accum_dtype))
block_S = chunk_size
kernel = tilelang_chunk_scaled_dot_kkt_fwd(B, S, H, DK, chunk_size, input_dtype, output_dtype,
accum_dtype, use_g, block_S, block_DK, threads,
num_stages)
A_tilelang = kernel(K, Beta, G) # noqa: F841
def test_example_cumsum_compilation():
from example_cumsum import tilelang_chunk_local_cumsum_scalar, prepare_cumsum_input, prepare_cumsum_output
G = prepare_cumsum_input(B, S, H, getattr(torch, gate_dtype))
G_new_tilelang = prepare_cumsum_output(B, S, H, getattr(torch, gate_dtype))
block_S = chunk_size
kernel = tilelang_chunk_local_cumsum_scalar(
B=B,
S=S,
H=H,
chunk_size=chunk_size,
reverse=False,
head_first=False,
input_dtype=gate_dtype,
output_dtype=gate_dtype,
block_S=block_S,
threads=threads,
use_fragment=False,
)
G_new_tilelang = kernel(G) # noqa: F841
def test_example_chunk_delta_h_compilation():
from example_chunk_delta_h import tilelang_chunk_gated_delta_rule_fwd_h, prepare_input
K, W, U, G, initial_state = prepare_input(B, S, H, DK, DV, chunk_size,
getattr(torch, input_dtype),
getattr(torch, output_dtype),
getattr(torch, accum_dtype),
getattr(torch, gate_dtype))
kernel = tilelang_chunk_gated_delta_rule_fwd_h(B, S, H, DK, DV, input_dtype, output_dtype,
accum_dtype, gate_dtype, state_dtype, chunk_size,
use_g, use_initial_state, store_final_state,
save_new_value, block_DK, block_DV, threads,
num_stages)
h_tilelang, final_state_tilelang, V_new_tilelang = kernel(K, W, U, G,
initial_state) # noqa: F841
def test_example_chunk_delta_bwd_compilation():
from example_chunk_delta_bwd import tilelang_chunk_gated_delta_rule_bwd_dhu, prepare_input
Q, K, W, G, h0, dht, dO, dv = prepare_input(B, S, H, DK, DV, chunk_size,
getattr(torch, input_dtype),
getattr(torch, output_dtype),
getattr(torch, accum_dtype),
getattr(torch, gate_dtype),
getattr(torch, state_dtype))
kernel = tilelang_chunk_gated_delta_rule_bwd_dhu(B, S, H, DK, DV, input_dtype, output_dtype,
accum_dtype, gate_dtype, state_dtype,
chunk_size, 1.0, use_g, use_initial_state,
use_final_state_gradient, block_DV, threads,
num_stages)
dh_tilelang, dh0_tilelang, dv2_tilelang = kernel(Q, K, W, G, h0, dht, dO, dv) # noqa: F841
if __name__ == "__main__":
tilelang.testing.main()
import torch
def print_red_warning(message):
print(f"\033[31mWARNING: {message}\033[0m")
def calc_sim(x, y, name="tensor"):
x, y = x.data.double(), y.data.double()
denominator = (x * x + y * y).sum()
if denominator == 0:
print_red_warning(f'{name} all zero')
return 1
sim = 2 * (x * y).sum() / denominator
return sim
def assert_similar(x, y, eps=1e-8, name="tensor", data="", raise_assert=True):
x_mask = torch.isfinite(x)
y_mask = torch.isfinite(y)
if not torch.all(x_mask == y_mask):
print_red_warning(f'{name} Error: isfinite mask mismatch')
if raise_assert:
raise AssertionError
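# Compare only the non-finite entries: finite positions are zeroed on both sides,
# so a mismatch here means the inf/NaN values differ between x and y.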
if not torch.isclose(
x.masked_fill(x_mask, 0), y.masked_fill(y_mask, 0), rtol=0, atol=0,
equal_nan=True).all():
print_red_warning(f'{name} Error: nonfinite value mismatch')
if raise_assert:
raise AssertionError
x = x.masked_fill(~x_mask, 0)
y = y.masked_fill(~y_mask, 0)
sim = calc_sim(x, y, name)
diff = 1. - sim
if not (0 <= diff <= eps):
print_red_warning(f'{name} Error: {diff}')
if raise_assert:
raise AssertionError
else:
print(f"{name} {data} passed")
# TileLang GEMM (Matrix Multiplication) Examples
TileLang is a domain-specific language designed to simplify the process of writing GPU kernels. It provides high-level abstractions for memory allocation, scheduling, and tiling, which are critical for achieving maximum performance on modern hardware architectures like NVIDIA GPUs. This README demonstrates how to write and optimize a matrix multiplication (GEMM) kernel using TileLang.
## Table of Contents
1. [Getting Started](#getting-started)
2. [Simple GEMM Example](#simple-gemm-example)
- [Code Walkthrough](#code-walkthrough)
- [Compiling and Profiling](#compiling-and-profiling)
3. [Advanced GEMM Features](#advanced-gemm-features)
- [Custom Memory Layout / Swizzling](#custom-memory-layout--swizzling)
- [Parallel Copy and Auto-Pipelining](#parallel-copy-and-auto-pipelining)
- [Rasterization for L2 Cache Locality](#rasterization-for-l2-cache-locality)
4. [Enhanced GEMM Example with Annotations](#enhanced-gemm-example-with-annotations)
5. [Verifying Correctness](#verifying-correctness)
6. [Fine-grained MMA Computations](#fine-grained-mma-computations)
- [Example Workflow](#example-workflow)
- [Summary](#summary)
7. [References](#references)
---
## Getting Started
### Prerequisites
- **Python 3.8+**
- **NVIDIA GPU** with a recent CUDA toolkit installed
- **PyTorch** (optional, for easy correctness verification)
- **tilelang**
- **bitblas** (optional; used for swizzle layout utilities in the advanced examples)
### Installation
```bash
pip install tilelang bitblas
```
*(Adjust accordingly if you are installing from source or using a different environment.)*
---
## Simple GEMM Example
Below is a basic matrix multiplication (GEMM) example demonstrating how TileLang handles buffer allocation, tiling, and kernel dispatch. For simplicity, we'll multiply two 1024×1024 matrices using 128 threads/block.
```python
import tilelang
from tilelang import Profiler
import tilelang.language as T
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
@T.prim_func
def main(
A: T.Tensor((M, K), dtype),
B: T.Tensor((K, N), dtype),
C: T.Tensor((M, N), dtype),
):
# Define a grid with enough blocks to cover M×N
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
# Allocate shared memory for the current tile of A and B
A_shared = T.alloc_shared((block_M, block_K), dtype)
B_shared = T.alloc_shared((block_K, block_N), dtype)
# Allocate a local (register) fragment for partial accumulations
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
# Initialize the local accumulation buffer to zero
T.clear(C_local)
# Loop over the K dimension in block_K chunks, using a 3-stage pipeline
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
# Copy from global memory to shared memory
T.copy(A[by * block_M, k * block_K], A_shared)
T.copy(B[k * block_K, bx * block_N], B_shared)
# Perform a matrix multiply-accumulate on the tile
T.gemm(A_shared, B_shared, C_local)
# Copy the accumulated result from local memory (C_local) to global memory (C)
T.copy(C_local, C[by * block_M, bx * block_N])
return main
```
### Code Walkthrough
1. **Define the Kernel Launch Configuration:**
```python
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
```
This creates a grid of blocks (ceildiv(N, block_N) in x-dimension, ceildiv(M, block_M) in y-dimension), each with 128 threads.
2. **Shared Memory Allocation:**
```python
A_shared = T.alloc_shared((block_M, block_K), dtype)
B_shared = T.alloc_shared((block_K, block_N), dtype)
```
Tiles of \(A\) and \(B\) are loaded into these shared memory buffers for faster access.
3. **Local Fragment Accumulation:**
```python
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
```
Partial results are stored in registers (or local memory) to reduce writes to global memory.
4. **Pipelined Loading and GEMM:**
```python
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
T.copy(...)
T.gemm(...)
```
Loads blocks of \(A\) and \(B\) in a pipelined fashion (up to 3 stages). This exploits overlap of data transfer and computation.
5. **Copy Out the Results:**
```python
T.copy(C_local, C[by * block_M, bx * block_N])
```
Writes the final computed tile from registers/shared memory to global memory.
### Compiling and Profiling
```python
func = matmul(1024, 1024, 1024, 128, 128, 32)
print(func) # Prints an IR-like representation of the TileLang kernel
artifact = tilelang.lower(func)
profiler = Profiler(artifact.rt_mod, artifact.params, result_idx=[2])
import torch
a = torch.randn(1024, 1024).cuda().half()
b = torch.randn(1024, 1024).cuda().half()
c = profiler(a, b)
ref_c = a @ b
# Validate results
torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
# Get CUDA Kernel Source
print(artifact.kernel_source)
```
---
## Advanced GEMM Features
### Custom Memory Layout / Swizzling
**Swizzling** rearranges data in shared memory or global memory to mitigate bank conflicts, improve cache utilization, and better match the GPU’s warp execution pattern. TileLang provides helper functions like `make_swizzle_layout` to annotate how buffers should be laid out in memory.
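A minimal sketch of the annotation, reusing the `A_shared`/`B_shared` buffers and the `make_mma_swizzle_layout` import from the enhanced example below:

```python
# Ask TileLang to store the shared tiles in a swizzled layout
T.annotate_layout({
    A_shared: make_swizzle_layout(A_shared),
    B_shared: make_swizzle_layout(B_shared),
})
```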
### Parallel Copy and Auto-Pipelining
- **Parallel Copy** allows you to distribute the copy of a block tile across all threads in a block, speeding up the transfer from global memory to shared memory.
- **Auto-Pipelining** uses multiple stages to overlap copying with computation, reducing idle cycles.
### Rasterization for L2 Cache Locality
Enabling **swizzle (rasterization)** at the kernel level can improve data reuse and reduce cache thrashing in L2. This is especially important when matrices are large.
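A minimal sketch, matching the call used in the enhanced example below (`panel_size=10` is simply the value used throughout these examples):

```python
# Rasterize the block schedule to improve L2 reuse between neighboring blocks
T.use_swizzle(panel_size=10, enable=True)
```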
---
## Enhanced GEMM Example with Annotations
Below is a more advanced snippet that showcases how to apply memory layouts, enable swizzling, and parallelize the copy operations to maximize performance:
```python
import tilelang.language as T
# `make_mma_swizzle_layout` is a python-defined layout function
# that helps align data for MMA (Matrix Multiply-Accumulate) operations.
from tilelang.intrinsics import make_mma_swizzle_layout as make_swizzle_layout
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
@T.prim_func
def main(
A: T.Tensor((M, K), dtype),
B: T.Tensor((K, N), dtype),
C: T.Tensor((M, N), dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
# Allocate shared and local fragments
A_shared = T.alloc_shared((block_M, block_K), dtype)
B_shared = T.alloc_shared((block_K, block_N), dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
# Annotate memory layout
T.annotate_layout({
A_shared: make_swizzle_layout(A_shared),
B_shared: make_swizzle_layout(B_shared),
})
# Enable swizzle-based rasterization for better L2 locality
T.use_swizzle(panel_size=10, enable=True)
# Clear the local accumulation buffer
T.clear(C_local)
# Pipelined iteration over K dimension
for idx in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
# Copy tile of A
T.copy(A[by * block_M, idx * block_K], A_shared)
# Parallel copy tile of B
for ko, j in T.Parallel(block_K, block_N):
B_shared[ko, j] = B[idx * block_K + ko, bx * block_N + j]
# Perform local GEMM on the shared-memory tiles
T.gemm(A_shared, B_shared, C_local)
# Copy the result tile back
T.copy(C_local, C[by * block_M, bx * block_N])
return main
```
**Key Differences vs. Basic Example**
1. **`T.annotate_layout(...)`**: Annotates how data should be organized in shared memory (swizzling).
2. **`T.use_swizzle(...)`**: Enables swizzle-based rasterization.
3. **Parallel Copy Loop** with `T.Parallel(...)`: Distributes global-to-shared copy across all threads, potentially vectorizing load/store instructions.
---
## Verifying Correctness
Once you compile and load your kernel into a runtime module (`rt_mod`), you can use tools like **PyTorch** to easily create random matrices on the GPU, run your TileLang kernel, and compare the results to a reference implementation (e.g., `torch.matmul` or `@` operator).
```python
import torch
# Suppose your compiled kernel is in rt_mod
profiler = Profiler(rt_mod, params, result_idx=[2])
A = torch.randn(1024, 1024).cuda().half()
B = torch.randn(1024, 1024).cuda().half()
C_tilelang = profiler(A, B)
C_ref = A @ B
torch.testing.assert_close(C_tilelang, C_ref, rtol=1e-2, atol=1e-2)
print("Results match!")
```
---
## Fine-grained MMA Computations
For advanced users who need full control over warp-level matrix multiplication, TileLang lets you specify fine-grained MMA (Matrix Multiply-Accumulate) computations in a manner similar to writing raw CUDA. Higher-level abstractions like `T.gemm(...)` or the automatic MMA emitters are sufficient for many use cases, but specialized workloads (for example, dequantized GEMM, which may need a fine-grained layout transformation in the shared-to-register stage) can benefit from explicitly controlling each MMA instruction, the data layout, and the synchronization points.
### Example Workflow
```python
@simplify_prim_func
def tl_matmul(
M,
N,
K,
in_dtype,
out_dtype,
accum_dtype,
):
assert in_dtype in [
"float16",
"int8",
], "Currently only float16 and int8 are supported"
assert out_dtype in [
"float16",
"float32",
"int32",
], "Currently only float16, float32 and int32 are supported"
micro_size_x = micro_size_y = micro_size_k = 16
if out_dtype == "int32":
micro_size_k = 32
# This is a debug config
block_row_warps = 2
block_col_warps = 2
warp_row_tiles = 32
warp_col_tiles = 32
chunk = 32
shared_scope = "shared.dyn"
# Pipeline Stage
stage = 2
block_M = block_row_warps * warp_row_tiles
block_N = block_col_warps * warp_col_tiles
block_K = chunk
A_shape = (M, K)
B_shape = (N, K)
A_shared_shape = (block_M, block_K)
B_shared_shape = (block_N, block_K)
C_shared_shape = (
block_M // micro_size_x,
block_N // micro_size_y,
micro_size_x,
micro_size_y,
)
warp_size = 32
threads = warp_size * (block_row_warps * block_col_warps)
local_size_a = (micro_size_x * micro_size_k) // warp_size
local_size_b = (micro_size_y * micro_size_k) // warp_size
local_size_c = (micro_size_x * micro_size_y) // warp_size
warp_rows = warp_row_tiles // micro_size_x
warp_cols = warp_col_tiles // micro_size_y
# MMA Wrapper to Auto Generate Code for MMA
mma_emitter = TensorCoreIntrinEmitter(
a_dtype=in_dtype,
b_dtype=in_dtype,
accum_dtype=accum_dtype,
a_transposed=False,
b_transposed=True,
block_row_warps=block_row_warps,
block_col_warps=block_col_warps,
warp_row_tiles=warp_row_tiles,
warp_col_tiles=warp_col_tiles,
chunk=chunk,
)
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope)
C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope)
A_local = T.alloc_local((warp_rows * local_size_a), in_dtype)
B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)
T.annotate_layout({
A_shared: make_swizzle_layout(A_shared),
B_shared: make_swizzle_layout(B_shared),
})
# Improve L2 Cache
T.use_swizzle(panel_size=10)
T.clear(C_local)
for ko in T.Pipelined((K // block_K), num_stages=stage):
# Load A into shared memory
for i, k in T.Parallel(block_M, block_K):
A_shared[i, k] = A[by * block_M + i, ko * block_K + k]
# Load B into shared memory
for j, k in T.Parallel(block_N, block_K):
B_shared[j, k] = B[bx * block_N + j, ko * block_K + k]
for ki in T.serial(0, (block_K // micro_size_k)):
# Load A into fragment
mma_emitter.ldmatrix_a(
A_local,
A_shared,
ki
)
# Load B into fragment
mma_emitter.ldmatrix_b(
B_local,
B_shared,
ki
)
# Perform Matrix Multiplication
mma_emitter.mma(A_local, B_local, C_local)
# Perform STMatrix
mma_emitter.stmatrix(
C_local,
C_shared,
)
# Store shared into global
for i, j in T.Parallel(block_M, block_N):
C[by * block_M + i, bx * block_N + j] = C_shared[
i // micro_size_x,
j // micro_size_y,
i % micro_size_x,
j % micro_size_y,
]
```
1. **Set Up Tile Sizes and Thread Bindings**
Just like in CUDA, you will typically start by defining how many warps or threads per block you want and how your matrix is subdivided. In TileLang, this is done via `T.Kernel(...)` and `T.thread_binding(...)`, which ensure that the correct number of threads is launched and that each thread is bound to a specific role (e.g., warp ID or lane ID).
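A minimal sketch of this step, reusing the names from the example above (`block_row_warps`, `block_col_warps`, `block_M`, `block_N` are assumed to be defined as before):

```python
warp_size = 32
# one warp per (row, col) warp slot in the block, e.g. 32 * (2 * 2) = 128 threads
threads = warp_size * (block_row_warps * block_col_warps)

with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
    # bx/by select the output tile; per-thread warp and lane roles are derived by
    # the emitter helpers (ldmatrix/mma/stmatrix) shown in the following steps.
    ...
```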
2. **Allocate Warp-local Fragments**
Instead of using a single shared buffer for partial sums, you allocate local buffers (register fragments) to hold sub-blocks of matrices \(A\) and \(B\). In TileLang, this is done with something like:
```python
A_local = T.alloc_local((warp_rows * local_size_a), in_dtype)
B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)
```
Each of these `local` allocations represents a region of per-thread storage, which collectively forms the warp’s register tiles.
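For the configuration above (micro_size_x = micro_size_k = 16, warp_size = 32, warp_row_tiles = 32), local_size_a = (16 × 16) / 32 = 8 and warp_rows = 32 / 16 = 2, so `A_local` holds 2 × 8 = 16 elements per thread.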
3. **Load Data via `ldmatrix`**
Fine-grained loading instructions allow you to specify exactly how data moves from shared memory to the warp-level fragments. In the example below, `mma_emitter.ldmatrix_a()` and `.ldmatrix_b()` are higher-level wrappers around warp-synchronous intrinsics. You can write your own load logic as well:
```python
for ki in T.serial(0, (block_K // micro_size_k)):
# Warp-synchronous load for A
mma_emitter.ldmatrix_a(A_local, A_shared, ki)
# Warp-synchronous load for B
mma_emitter.ldmatrix_b(B_local, B_shared, ki)
```
Internally, these calls orchestrate how each thread in the warp issues the correct load instructions, performs address calculations, and stores the data into registers.
4. **Perform the MMA Instruction**
After loading sub-tiles (fragments), the warp executes the `mma` instruction. This operation is essentially:
\[
C_{\text{local}} \;+=\; A_{\text{local}} \;\times\; B_{\text{local}}
\]
where each thread in the warp calculates a small portion of the final tile. For instance:
```python
mma_emitter.mma(A_local, B_local, C_local)
```
Under the hood, this translates into Tensor Core instructions (e.g., `mma.sync` in PTX), which process multiple data elements per warp in parallel.
5. **Store Results via `stmatrix`**
Finally, you write the results from the warp-level fragments back to shared memory or global memory. This step might happen multiple times in a loop or just once at the end. The code snippet:
```python
mma_emitter.stmatrix(C_local, C_shared)
```
orchestrates the warp-synchronous stores, ensuring each thread places the correct fragment element into the correct location of the shared or global buffer.
### Summary
By combining warp-synchronous intrinsics (`ldmatrix`, `mma`, `stmatrix`) with manual thread bindings and memory allocations, you can replicate the control and performance of raw CUDA at the TileLang level. This approach is best suited for expert users who are comfortable with GPU warp-level programming, since it does require a deep understanding of hardware concurrency, memory hierarchies, and scheduling. However, the payoff can be significant for performance-critical paths, where every byte of bandwidth and every cycle of latency must be carefully orchestrated.
---
## References
- [NVIDIA CUTLASS Library](https://github.com/NVIDIA/cutlass): A collection of high-performance CUDA C++ template abstractions for GEMM.
- [NVIDIA CUDA Programming Guide](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html): Official documentation for CUDA.
- [PyTorch Documentation](https://pytorch.org/docs): For verifying correctness via CPU or GPU-based matmul.
import tilelang
import tilelang.language as T
@tilelang.jit(out_idx=[-1])
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
@T.prim_func
def gemm(
A: T.Tensor((M, K), dtype),
B: T.Tensor((K, N), dtype),
C: T.Tensor((M, N), dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K), dtype)
B_shared = T.alloc_shared((block_K, block_N), dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
T.copy(A[by * block_M, k * block_K], A_shared)
T.copy(B[k * block_K, bx * block_N], B_shared)
T.gemm(A_shared, B_shared, C_local)
T.copy(C_local, C[by * block_M, bx * block_N])
return gemm
def main():
kernel = matmul(1024, 1024, 1024, 128, 128, 32)
import torch
a = torch.randn(1024, 1024).cuda().half()
b = torch.randn(1024, 1024).cuda().half()
c = kernel(a, b)
ref_c = a @ b
print("c:")
print(c)
print("ref_c:")
print(ref_c)
torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
print("All check passed.")
# Get CUDA Source
print("CUDA Source:")
print(kernel.get_kernel_source())
# benchmark
profiler = kernel.get_profiler()
latency = profiler.do_bench(backend="cupti")
# latency = profiler.do_bench()
print(f"tilelang Latency: {latency}ms")
if __name__ == "__main__":
main()
import argparse
import itertools
import tilelang as tl
import tilelang.language as T
from tilelang.autotuner import AutoTuner
from tilelang.carver.template import MatmulTemplate
from tilelang.carver.arch import CUDA
from tilelang.carver.arch import CDNA
from tilelang.carver.roller.rasterization import NoRasterization
import torch
def ref_program(A, B):
"""
Compute the matrix product of A and the transpose of B.
A and B are expected to be 2-D tensors where A has shape (M, K) and B has shape (N, K). The result is a tensor with shape (M, N) equal to A @ B.T, using the inputs' dtypes.
"""
return A @ B.T
def get_configs(M, N, K, with_roller=False, topk=20):
"""
Generate a list of kernel tuning configuration dictionaries for a tiled matrix-multiply.
When with_roller is True this queries the MatmulTemplate roller to produce up to `topk` recommended
configurations (device-specific TensorCore-friendly tilings). Each returned dict contains:
- block_M, block_N, block_K: tile sizes
- num_stages: pipeline staging (0 means no explicit staging)
- thread_num: total threads used for the block
- enable_rasteration: whether a rasterization/swizzle layout was recommended (note spelling)
When with_roller is False this returns the Cartesian product of a fixed set of candidate
parameters; the returned dicts use the backward-compatible key name "enable_rasteration" for that flag.
Parameters:
M, N, K (int): GEMM dimensions used to generate valid tile sizes.
with_roller (bool): If True, use MatmulTemplate's roller to generate device-aware hints;
otherwise use a predefined candidate grid.
topk (int): Maximum number of roller hints to request when with_roller is True.
Returns:
List[dict]: A list of configuration dictionaries as described above.
Raises:
ValueError: if with_roller is True but the roller returns no hints.
"""
if with_roller:
arch = CUDA("cuda") if torch.version.hip is None else CDNA("hip")
carve_template = MatmulTemplate(
M=M,
N=N,
K=K,
in_dtype="float16",
out_dtype="float16",
accum_dtype="float",
).with_arch(arch)
func = carve_template.equivalent_function()
assert func is not None, "Function is None"
roller_hints = carve_template.recommend_hints(topk=topk)
if roller_hints is None:
raise ValueError("No Roller Hints Found for TensorCore Scheduling")
configs = []
for hint in roller_hints:
config = {}
block_m, block_n = hint.block
warp_m, warp_n = hint.warp
# block_rows, block_cols represent the warp partitioning (warps per block along M and N)
block_rows, block_cols = block_m // warp_m, block_n // warp_n
config["block_M"] = block_m
config["block_N"] = block_n
config["block_K"] = hint.rstep[0]
config["num_stages"] = hint.pipeline_stage if hint.pipeline_stage > 1 else 0
config["thread_num"] = block_rows * block_cols * 32
config["enable_rasteration"] = hint.rasterization_plan is not NoRasterization
configs.append(config)
else:
block_M = [64, 128, 256]
block_N = [64, 128, 256]
block_K = [32, 64]
num_stages = [0, 1, 2, 3]
thread_num = [128, 256]
enable_rasterization = [True, False]
_configs = list(
itertools.product(
block_M,
block_N,
block_K,
num_stages,
thread_num,
enable_rasterization,
))
configs = [
{
"block_M": c[0],
"block_N": c[1],
"block_K": c[2],
"num_stages": c[3],
"thread_num": c[4],
"enable_rasteration": c[5], # keep param name for backward-compat
} for c in _configs
]
return configs
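# Illustrative only: with the fixed grid above, get_configs(4096, 4096, 4096) yields
# 3 * 3 * 2 * 4 * 2 * 2 = 288 candidate dicts, the first being
# {"block_M": 64, "block_N": 64, "block_K": 32, "num_stages": 0, "thread_num": 128, "enable_rasteration": True}.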
def get_best_config(M, N, K, with_roller=False):
def kernel(
block_M=None,
block_N=None,
block_K=None,
num_stages=None,
thread_num=None,
enable_rasteration=None,
):
dtype = "bfloat16"
accum_dtype = "float"
@T.prim_func
def main(
A: T.Tensor((M, K), dtype),
B: T.Tensor((N, K), dtype),
C: T.Tensor((M, N), dtype),
):
with T.Kernel(
T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K), dtype)
B_shared = T.alloc_shared((block_N, block_K), dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
C_shared = T.alloc_shared((block_M, block_N), dtype)
T.use_swizzle(panel_size=10, enable=enable_rasteration)
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
T.copy(A[by * block_M, k * block_K], A_shared)
T.copy(B[bx * block_N, k * block_K], B_shared)
T.gemm(
A_shared,
B_shared,
C_local,
transpose_B=True,
)
T.copy(C_local, C_shared)
T.copy(C_shared, C[by * block_M, bx * block_N])
return main
autotuner = AutoTuner.from_kernel(
kernel=kernel, configs=get_configs(M, N, K, with_roller)).set_compile_args(
out_idx=[-1],
target="auto",
).set_profile_args(
supply_type=tl.TensorSupplyType.Integer,
ref_prog=ref_program,
skip_check=False,
)
return autotuner.run(warmup=3, rep=20)
def get_heuristic_config() -> dict:
# Get CUDA device properties
if not torch.cuda.is_available():
raise RuntimeError("CUDA is not available")
device = torch.cuda.current_device()
sm_major, sm_minor = torch.cuda.get_device_capability(device)
sm_version = sm_major * 10 + sm_minor
print(f"CUDA device capability: {sm_version}")
if sm_version in {80}:
return {
"block_M": 128,
"block_N": 256,
"block_K": 32,
"num_stages": 2,
"thread_num": 128,
"enable_rasteration": True
}
elif sm_version in {90}:
return {
"block_M": 128,
"block_N": 256,
"block_K": 64,
"num_stages": 3,
"thread_num": 256,
"enable_rasteration": True
}
else:
return {
"block_M": 128,
"block_N": 256,
"block_K": 32,
"num_stages": 0,
"thread_num": 128,
"enable_rasteration": True
}
@tl.jit(out_idx=[-1])
def matmul(M,
N,
K,
block_M,
block_N,
block_K,
num_stages,
thread_num,
enable_rasteration,
dtype="float16",
accum_dtype="float"):
@T.prim_func
def gemm_autotune(
A: T.Tensor((M, K), dtype),
B: T.Tensor((N, K), dtype),
C: T.Tensor((M, N), dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K), dtype)
B_shared = T.alloc_shared((block_N, block_K), dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
C_shared = T.alloc_shared((block_M, block_N), dtype)
T.use_swizzle(panel_size=10, enable=enable_rasteration)
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
T.copy(A[by * block_M, k * block_K], A_shared)
T.copy(B[bx * block_N, k * block_K], B_shared)
T.gemm(
A_shared,
B_shared,
C_local,
transpose_B=True,
)
T.copy(C_local, C_shared)
T.copy(C_shared, C[by * block_M, bx * block_N])
return gemm_autotune
def main(M: int = 4096,
N: int = 4096,
K: int = 4096,
use_autotune: bool = False,
with_roller: bool = False):
if use_autotune:
result = get_best_config(M, N, K, with_roller)
print(result.config)
kernel = result.kernel
else:
config = get_heuristic_config()
kernel = matmul(M, N, K, **config)
# benchmark
profiler = kernel.get_profiler(tensor_supply_type=tl.TensorSupplyType.Auto)
tilelang_latency = profiler.do_bench()
ref_latency = profiler.do_bench(ref_program)
profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)
print(f"TileLang latency: {tilelang_latency}")
print(f"Ref latency: {ref_latency}")
print(f"TileLang TFlops: {2 * M * N * K / tilelang_latency * 1e-9}")
print(f"Ref TFlops: {2 * M * N * K / ref_latency * 1e-9}")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Autotuned MatMul Benchmark")
parser.add_argument("--m", type=int, default=4096, help="Matrix dimension M")
parser.add_argument("--n", type=int, default=4096, help="Matrix dimension N")
parser.add_argument("--k", type=int, default=4096, help="Matrix dimension K")
parser.add_argument(
"--use_autotune",
action="store_true",
default=False,
help="Whether to use autotune for matmul configs")
parser.add_argument(
"--with_roller",
action="store_true",
default=False,
help="Whether to enable BitBLAS roller for search space")
args = parser.parse_args()
main(args.m, args.n, args.k, args.use_autotune, args.with_roller)
from tilelang import tvm as tvm
from tvm import DataType
import tilelang
import tilelang.language as T
from tilelang.intrinsics import get_swizzle_layout
from tilelang.intrinsics.mma_macro_generator import (
TensorCoreIntrinEmitter,)
from tilelang.transform import simplify_prim_func
def make_swizzle_layout(shared_buf):
dtype = shared_buf.dtype
shape = shared_buf.shape
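# Swizzling is applied only when the innermost dimension spans exactly 512 bits
# (e.g., 32 fp16 or 64 int8 elements); otherwise fall back to the identity layout.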
can_swizzle = shape[-1] * DataType(dtype).bits == 512
if not can_swizzle:
return T.Layout(shape, lambda *args: args)
def transform_func(i, j):
new_warp_i, new_warp_j = get_swizzle_layout(i, j, shape[-1], dtype)
return [new_warp_i, new_warp_j]
return T.Layout(shape, transform_func)
@tilelang.jit(out_idx=[2])
@simplify_prim_func
def tl_matmul(
M,
N,
K,
in_dtype,
out_dtype,
accum_dtype,
):
assert in_dtype in [
"float16",
"int8",
], "Currently only float16 and int8 are supported"
assert out_dtype in [
"float16",
"float32",
"int32",
], "Currently only float16, float32 and int32 are supported"
micro_size_x = micro_size_y = micro_size_k = 16
if out_dtype == "int32":
micro_size_k = 32
# This is a debug config
block_row_warps = 2
block_col_warps = 2
warp_row_tiles = 64
warp_col_tiles = 64
# chunk = 32 if in_dtype == "float16" else 64
chunk = 32
shared_scope = "shared.dyn"
# Pipeline Stage
stage = 2
block_M = block_row_warps * warp_row_tiles
block_N = block_col_warps * warp_col_tiles
block_K = chunk
A_shape = (M, K)
B_shape = (N, K)
A_shared_shape = (block_M, block_K)
B_shared_shape = (block_N, block_K)
C_shared_shape = (
block_M // micro_size_x,
block_N // micro_size_y,
micro_size_x,
micro_size_y,
)
warp_size = 32
threads = warp_size * (block_row_warps * block_col_warps)
local_size_a = (micro_size_x * micro_size_k) // warp_size
local_size_b = (micro_size_y * micro_size_k) // warp_size
local_size_c = (micro_size_x * micro_size_y) // warp_size
warp_rows = warp_row_tiles // micro_size_x
warp_cols = warp_col_tiles // micro_size_y
# MMA Wrapper to Auto Generate Code for MMA
mma_emitter = TensorCoreIntrinEmitter(
a_dtype=in_dtype,
b_dtype=in_dtype,
accum_dtype=accum_dtype,
a_transposed=False,
b_transposed=True,
block_row_warps=block_row_warps,
block_col_warps=block_col_warps,
warp_row_tiles=warp_row_tiles,
warp_col_tiles=warp_col_tiles,
chunk=chunk,
)
@T.prim_func
def gemm_intrinsics(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope)
C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope)
A_local = T.alloc_local((warp_rows * local_size_a), in_dtype)
B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)
T.annotate_layout({
A_shared: make_swizzle_layout(A_shared),
B_shared: make_swizzle_layout(B_shared),
})
# Improve L2 Cache
T.use_swizzle(panel_size=10)
T.clear(C_local)
for ko in T.Pipelined((K // block_K), num_stages=stage):
# Load A into shared memory
for i, k in T.Parallel(block_M, block_K):
A_shared[i, k] = A[by * block_M + i, ko * block_K + k]
# Load B into shared memory
for j, k in T.Parallel(block_N, block_K):
B_shared[j, k] = B[bx * block_N + j, ko * block_K + k]
for ki in T.serial(0, (block_K // micro_size_k)):
# Load A into fragment
mma_emitter.ldmatrix_a(A_local, A_shared, ki)
# Load B into fragment
mma_emitter.ldmatrix_b(B_local, B_shared, ki)
# Perform Matrix Multiplication
mma_emitter.mma(A_local, B_local, C_local)
# Perform STMatrix
mma_emitter.stmatrix(C_local, C_shared)
# Store shared into global
for i, j in T.Parallel(block_M, block_N):
C[by * block_M + i, bx * block_N + j] = C_shared[
i // micro_size_x,
j // micro_size_y,
i % micro_size_x,
j % micro_size_y,
]
return gemm_intrinsics
def ref_program(A, B):
return A @ B.T
def main(M=4096, N=4096, K=4096):
in_dtype, out_dtype, accum_dtype = "float16", "float16", "float32"
kernel = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype)
src_code = kernel.get_kernel_source()
# src_code is the generated cuda source
assert src_code is not None
profiler = kernel.get_profiler()
latency = profiler.do_bench(profiler.func, warmup=25)
print(latency)
# Ensure that the latency is not None
assert latency is not None
profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)
if __name__ == "__main__":
main(M=4096, N=4096, K=4096)
import tilelang
import tilelang.language as T
from tilelang.carver.arch import driver
import argparse
@tilelang.jit(out_idx=[-1])
def matmul_non_persistent(M,
N,
K,
block_M,
block_N,
block_K,
threads,
num_stages,
dtype="float16",
accum_dtype="float"):
@T.prim_func
def main(
A: T.Tensor((M, K), dtype),
B: T.Tensor((K, N), dtype),
C: T.Tensor((M, N), dtype),
):
with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=threads) as (bx, by):
A_shared = T.alloc_shared((block_M, block_K), dtype)
B_shared = T.alloc_shared((block_K, block_N), dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
C_shared = T.alloc_shared((block_M, block_N), dtype)
T.use_swizzle(10)
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
T.copy(A[bx * block_M, k * block_K], A_shared)
T.copy(B[k * block_K, by * block_N], B_shared)
T.gemm(A_shared, B_shared, C_local)
T.copy(C_local, C_shared)
T.copy(C_shared, C[bx * block_M, by * block_N])
return main
@tilelang.jit(out_idx=[-1])
def matmul_persistent(M,
N,
K,
block_M,
block_N,
block_K,
threads,
num_stages,
dtype="float16",
accum_dtype="float",
use_persistent_primitive=True):
sm_num = driver.get_num_sms()
m_blocks = T.ceildiv(M, block_M)
n_blocks = T.ceildiv(N, block_N)
waves = T.ceildiv(m_blocks * n_blocks, sm_num)
group_size = 8
@T.prim_func
def main(
A: T.Tensor((M, K), dtype),
B: T.Tensor((K, N), dtype),
C: T.Tensor((M, N), dtype),
):
with T.Kernel(sm_num, threads=threads) as (block_id):
A_shared = T.alloc_shared((block_M, block_K), dtype)
B_shared = T.alloc_shared((block_K, block_N), dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
C_shared = T.alloc_shared((block_M, block_N), dtype)
for w in T.serial(waves):
tile_id = sm_num * w + block_id
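# Map the linear tile_id to a (bx, by) tile coordinate: tiles are visited in panels of
# `group_size` block-columns (by cycles inside the panel while bx walks down the M
# dimension), which keeps a small working set of B column-blocks hot in L2, similar
# in effect to the T.use_swizzle rasterization used in the non-persistent kernel.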
bx = (tile_id // group_size) % m_blocks
by = (tile_id % group_size) + (tile_id // group_size) // m_blocks * group_size
if bx * block_M < M and by * block_N < N:
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
T.copy(A[bx * block_M, k * block_K], A_shared)
T.copy(B[k * block_K, by * block_N], B_shared)
T.gemm(A_shared, B_shared, C_local)
T.copy(C_local, C_shared)
T.copy(C_shared, C[bx * block_M, by * block_N])
@T.prim_func
def main_persistent_primitive(
A: T.Tensor((M, K), dtype),
B: T.Tensor((K, N), dtype),
C: T.Tensor((M, N), dtype),
):
with T.Kernel(sm_num, threads=threads) as (block_id):
A_shared = T.alloc_shared((block_M, block_K), dtype)
B_shared = T.alloc_shared((block_K, block_N), dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
C_shared = T.alloc_shared((block_M, block_N), dtype)
for bx, by in T.Persistent(
[T.ceildiv(M, block_M), T.ceildiv(N, block_N)], sm_num, block_id):
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
T.copy(A[bx * block_M, k * block_K], A_shared)
T.copy(B[k * block_K, by * block_N], B_shared)
T.gemm(A_shared, B_shared, C_local)
T.copy(C_local, C_shared)
T.copy(C_shared, C[bx * block_M, by * block_N])
return main_persistent_primitive if use_persistent_primitive else main
def ref_program(A, B):
return A @ B
def main(M=4096, N=4096, K=4096):
total_flops = 2 * M * N * K
BLOCK_M = 128
BLOCK_N = 256
BLOCK_K = 64
threads = 256
num_stages = 3
persistent_kernel = matmul_persistent(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, threads, num_stages)
persistent_profiler = persistent_kernel.get_profiler(
tensor_supply_type=tilelang.TensorSupplyType.Randn)
persistent_profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01)
print("Persistent GEMM: All check passed.")
persistent_latency = persistent_profiler.do_bench(warmup=500)
print(f"Persistent GEMM Latency: {persistent_latency} ms")
print(f"Persistent GEMM TFlops: {total_flops / persistent_latency * 1e-9} TFlops")
non_persistent_kernel = matmul_non_persistent(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, threads,
num_stages)
non_persistent_profiler = non_persistent_kernel.get_profiler(
tensor_supply_type=tilelang.TensorSupplyType.Randn)
non_persistent_profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01)
print("Non-Persistent GEMM: All check passed.")
non_persistent_latency = non_persistent_profiler.do_bench(warmup=500)
print(f"Non-Persistent GEMM Latency: {non_persistent_latency} ms")
print(f"Non-Persistent GEMM TFlops: {total_flops / non_persistent_latency * 1e-9} TFlops")
print(f"Persistent GEMM Speedup: {non_persistent_latency / persistent_latency}")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--M', type=int, default=8192, help='M dimension')
parser.add_argument('--N', type=int, default=8192, help='N dimension')
parser.add_argument('--K', type=int, default=8192, help='K dimension')
args = parser.parse_args()
M, N, K = args.M, args.N, args.K
main(M, N, K)