Commit cf6e11c9 authored by qisan

feat: merge dcu branch features

parents 3f27f85a d0436b7b
import tilelang.testing
import example_gqa_decode
import example_mha_inference
# TODO(lei): fix the correctness of gqa decode on sm90
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_le(8, 9)
def test_example_example_gqa_decode():
example_gqa_decode.main()
def test_example_example_mha_inference():
example_mha_inference.main(BATCH=1, H=32, Q_CTX=128, KV_CTX=2048, D_HEAD=128, causal=False)
if __name__ == "__main__":
tilelang.testing.main()
import math
import torch
import torch.nn as nn
from typing import Dict, Tuple, Optional
import tilelang
import tilelang.language as T
from tilelang.autotuner import *
from example_fusedmoe_torch import *
@tilelang.jit(pass_configs={"tl.disable_tma_lower": True, "tl.disable_warp_specialized": True})
def moe_forward_tilelang_shared(
d_hidden,
d_expert,
n_shared_experts,
dtype,
num_tokens,
block_token=128,
block_dhidden=128,
block_dexpert=128,
threads=256,
num_stages=1,
):
scale = 1.44269504 # log2(e)
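# This kernel (and the routed kernel below) evaluates SiLU with the hardware exp2 instruction:
#   silu(x) = x * sigmoid(x) = x / (1 + exp(-x)),  and  exp(-x) = exp2(-x * log2(e)),
# so multiplying the argument by `scale` = log2(e) lets T.exp2 stand in for exp.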
# Parameters
dhidden = d_hidden
dexpert = d_expert * n_shared_experts
# Tensors: note that the input is reshaped to (num_tokens, dhidden)
input_shape = (num_tokens, dhidden)
shared_W_gate_shape = (dexpert, dhidden)
shared_W_up_shape = (dexpert, dhidden)
shared_W_down_shape = (dhidden, dexpert)
accum_type = T.float32
@T.prim_func
def kernel_shared(
input: T.Tensor(input_shape, dtype), # type: ignore
shared_W_gate: T.Tensor(shared_W_gate_shape, dtype), # type: ignore
shared_W_up: T.Tensor(shared_W_up_shape, dtype), # type: ignore
shared_W_down: T.Tensor(shared_W_down_shape, dtype), # type: ignore
up_logits: T.Tensor((num_tokens, dexpert), dtype), # type: ignore
output: T.Tensor(input_shape, dtype), # type: ignore
):
# Step 1: Compute gate and up logits
with T.Kernel(T.ceildiv(num_tokens, block_token), T.ceildiv(dexpert, block_dexpert), threads=threads) as (bx, by):
# The MoE forward is split into this shared-experts kernel and the routed-experts kernel below
input_shared = T.alloc_fragment((block_token, block_dhidden), dtype=dtype)
W_gate_shared = T.alloc_shared((block_dexpert, block_dhidden), dtype=dtype)
W_up_shared = T.alloc_shared((block_dexpert, block_dhidden), dtype=dtype)
# Shared experts: no need to check expert_indices
gate_logits_local = T.alloc_fragment((block_token, block_dexpert), dtype=accum_type)
up_logits_local = T.alloc_fragment((block_token, block_dexpert), dtype=accum_type)
T.use_swizzle(10)
T.clear(gate_logits_local)
T.clear(up_logits_local)
# Parallel for gate and up matmul
for k in T.Pipelined(T.ceildiv(dhidden, block_dhidden), num_stages=num_stages):
T.copy(input[bx * block_token, k * block_dhidden], input_shared)
T.copy(shared_W_gate[by * block_dexpert, k * block_dhidden], W_gate_shared)
T.copy(shared_W_up[by * block_dexpert, k * block_dhidden], W_up_shared)
T.gemm(input_shared, W_gate_shared, gate_logits_local, transpose_B=True)
T.gemm(input_shared, W_up_shared, up_logits_local, transpose_B=True)
# Fuse with SiLU and element-wise product
for i, j in T.Parallel(block_token, block_dexpert):
gate_logits_local[i, j] = gate_logits_local[i, j] * (1.0 / (1.0 + T.exp2(-gate_logits_local[i, j] * scale)))
up_logits_local[i, j] = up_logits_local[i, j] * gate_logits_local[i, j]
T.copy(up_logits_local, up_logits[bx * block_token, by * block_dexpert])
# Step 2: Compute down logits
with T.Kernel(T.ceildiv(num_tokens, block_token), T.ceildiv(dhidden, block_dhidden), threads=threads) as (bx, by):
up_logits_shared = T.alloc_fragment((block_token, block_dexpert), dtype=dtype)
W_down_shared = T.alloc_shared((block_dhidden, block_dexpert), dtype=dtype)
output_local = T.alloc_fragment((block_token, block_dhidden), dtype=accum_type)
T.use_swizzle(10)
T.clear(output_local)
for k in T.Pipelined(T.ceildiv(dexpert, block_dexpert), num_stages=num_stages):
T.copy(up_logits[bx * block_token, k * block_dexpert], up_logits_shared)
T.copy(shared_W_down[by * block_dhidden, k * block_dexpert], W_down_shared)
T.gemm(up_logits_shared, W_down_shared, output_local, transpose_B=True)
T.copy(output_local, output[bx * block_token, by * block_dhidden])
return kernel_shared
@tilelang.jit(pass_configs={"tl.disable_tma_lower": True, "tl.disable_warp_specialized": True})
def moe_forward_tilelang_routed(
d_hidden,
d_expert,
n_routed_experts,
dtype,
group_sum,
group_count,
block_token=128,
block_dhidden=128,
block_dexpert=128,
threads=256,
num_stages=1,
k_pack=1,
coalesced_width=None,
):
scale = 1.44269504 # log2(e)
# Parameters
dhidden = d_hidden
dexpert = d_expert
n_routed_experts = n_routed_experts
# Group info
# group_sum = sum(group_sizes_list)
# group_count = len(group_sizes_list)
# M = sum([(group_size + block_token - 1) // block_token for group_size in group_sizes_list])
M = math.ceil(group_sum / block_token) + group_count
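# M upper-bounds the number of row-blocks over all expert groups: each group adds at most
# one extra partially filled block, so
#   sum(ceil(g / block_token)) <= ceil(group_sum / block_token) + group_count.
# Illustrative (hypothetical) numbers: block_token=128 and group sizes [300, 80, 260]
# need 3 + 1 + 3 = 7 blocks, bounded by M = ceil(640 / 128) + 3 = 8.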
accum_dtype = T.float32
# Tensors: note that the input is reshaped to (bs * seq_len * n_experts_per_token, dhidden) for the grouped GEMM
input_shape = (group_sum, dhidden)
intermediate_shape = (group_sum, dexpert)
routed_expert_gate_shape = (n_routed_experts, dexpert, dhidden)
routed_expert_up_shape = (n_routed_experts, dexpert, dhidden)
routed_expert_down_shape = (n_routed_experts, dhidden, dexpert)
routed_expert_weights_shape = group_sum
group_sizes_shape = n_routed_experts
@T.prim_func
def kernel(
input: T.Tensor(input_shape, dtype), # type: ignore
routed_expert_gate: T.Tensor(routed_expert_gate_shape, dtype), # type: ignore
routed_expert_up: T.Tensor(routed_expert_up_shape, dtype), # type: ignore
routed_expert_down: T.Tensor(routed_expert_down_shape, dtype), # type: ignore
routed_expert_weights: T.Tensor(routed_expert_weights_shape, dtype), # type: ignore
group_sizes: T.Tensor(group_sizes_shape, T.int32), # type: ignore
group_offsets: T.Tensor(group_sizes_shape, T.int32), # type: ignore
group_padded_offsets: T.Tensor(group_sizes_shape, T.int32), # type: ignore
group_idx_for_bx: T.Tensor((M,), T.int32), # type: ignore
up_logits: T.Tensor(intermediate_shape, dtype), # type: ignore
output: T.Tensor(input_shape, dtype), # type: ignore
):
# Step 1: Compute gate and up logits
with T.Kernel(M, T.ceildiv(dexpert, block_dexpert), threads=threads) as (bx, by):
input_shared = T.alloc_fragment((block_token, block_dhidden), dtype=dtype)
routed_expert_gate_shared = T.alloc_shared((block_dexpert, block_dhidden), dtype=dtype)
routed_expert_up_shared = T.alloc_shared((block_dexpert, block_dhidden), dtype=dtype)
gate_logits_local = T.alloc_fragment((block_token, block_dexpert), dtype=accum_dtype)
up_logits_local = T.alloc_fragment((block_token, block_dexpert), dtype=accum_dtype)
cur_group_idx = T.alloc_local([1], T.int32)
cur_group_size = T.alloc_local([1], T.int32)
T.use_swizzle(10, enable=True)
m_start_padded = bx * block_token
cur_group_idx[0] = group_idx_for_bx[bx]
cur_group_size[0] = group_sizes[cur_group_idx[0]]
m_start = m_start_padded - group_padded_offsets[cur_group_idx[0]] + group_offsets[cur_group_idx[0]]
actual_rows = T.max(0, T.min(block_token, cur_group_size[0] - (m_start_padded - group_padded_offsets[cur_group_idx[0]])))
T.clear(gate_logits_local)
T.clear(up_logits_local)
for k in T.Pipelined(T.ceildiv(dhidden, block_dhidden), num_stages=num_stages):
T.copy(
input[m_start : m_start + block_token, k * block_dhidden : (k + 1) * block_dhidden],
input_shared,
coalesced_width=coalesced_width,
)
T.copy(
routed_expert_gate[
cur_group_idx[0], by * block_dexpert : (by + 1) * block_dexpert, k * block_dhidden : (k + 1) * block_dhidden
],
routed_expert_gate_shared,
coalesced_width=coalesced_width,
)
T.gemm(input_shared, routed_expert_gate_shared, gate_logits_local, k_pack=k_pack, transpose_B=True)
T.copy(
routed_expert_up[
cur_group_idx[0], by * block_dexpert : (by + 1) * block_dexpert, k * block_dhidden : (k + 1) * block_dhidden
],
routed_expert_up_shared,
coalesced_width=coalesced_width,
)
T.gemm(input_shared, routed_expert_up_shared, up_logits_local, k_pack=k_pack, transpose_B=True)
for i, j in T.Parallel(block_token, block_dexpert):
gate_logits_local[i, j] = gate_logits_local[i, j] * (1.0 / (1.0 + T.exp2(-gate_logits_local[i, j] * scale)))
up_logits_local[i, j] = up_logits_local[i, j] * gate_logits_local[i, j]
for i, j in T.Parallel(block_token, block_dexpert):
if i < actual_rows:
up_logits[m_start + i, by * block_dexpert + j] = up_logits_local[i, j]
# Step 2: Compute down logits
with T.Kernel(M, T.ceildiv(dhidden, block_dhidden), threads=threads) as (bx, by):
up_logits_shared = T.alloc_fragment((block_token, block_dexpert), dtype=dtype)
routed_expert_down_shared = T.alloc_shared((block_dhidden, block_dexpert), dtype=dtype)
output_local = T.alloc_fragment((block_token, block_dhidden), dtype=accum_dtype)
cur_group_idx = T.alloc_local([1], T.int32)
cur_group_size = T.alloc_local([1], T.int32)
T.use_swizzle(10, enable=True)
m_start_padded = bx * block_token
cur_group_idx[0] = group_idx_for_bx[bx]
cur_group_size[0] = group_sizes[cur_group_idx[0]]
m_start = m_start_padded - group_padded_offsets[cur_group_idx[0]] + group_offsets[cur_group_idx[0]]
actual_rows = T.max(0, T.min(block_token, cur_group_size[0] - (m_start_padded - group_padded_offsets[cur_group_idx[0]])))
T.clear(output_local)
for k in T.Pipelined(T.ceildiv(dexpert, block_dexpert), num_stages=num_stages):
T.copy(
up_logits[m_start : m_start + block_token, k * block_dexpert : (k + 1) * block_dexpert],
up_logits_shared,
coalesced_width=coalesced_width,
)
T.copy(
routed_expert_down[
cur_group_idx[0], by * block_dhidden : (by + 1) * block_dhidden, k * block_dexpert : (k + 1) * block_dexpert
],
routed_expert_down_shared,
coalesced_width=coalesced_width,
)
T.gemm(up_logits_shared, routed_expert_down_shared, output_local, k_pack=k_pack, transpose_B=True)
for i, j in T.Parallel(block_token, block_dhidden):
if i < actual_rows:
output[m_start + i, by * block_dhidden + j] = output_local[i, j] * routed_expert_weights[m_start + i]
return kernel
class Expert(nn.Module):
def __init__(self, config: Dict, gate: torch.Tensor, up: torch.Tensor, down: torch.Tensor, d_expert: Optional[int] = None):
super().__init__()
self.config = config
self.act_fn = nn.SiLU()
self.d_hidden: int = config["d_hidden"]
self.d_expert: int = config["d_expert"] if d_expert is None else d_expert
self.device = torch.device("cuda")
self.W_gate_weight = gate.t().contiguous().to(self.device)
self.W_up_weight = up.t().contiguous().to(self.device)
self.W_down_weight = down.t().contiguous().to(self.device)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# The weights above are stored pre-transposed for the TileLang kernels
# (gate/up as (d_expert, d_hidden), down as (d_hidden, d_expert)), so transpose
# them back for a plain PyTorch matmul.
gate = self.act_fn(x @ self.W_gate_weight.t())
out = (gate * (x @ self.W_up_weight.t())) @ self.W_down_weight.t()
return out
class MoEGate(nn.Module):
def __init__(self, config: Dict, weights: Dict):
super().__init__()
self.top_k: int = config["n_experts_per_token"]
self.num_experts: int = config["n_routed_experts"]
self.d_hidden: int = config["d_hidden"]
self.W_g_weight = weights["router.weight"].t()
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
logits = x @ self.W_g_weight
scores = logits.softmax(dim=-1)
topk_scores, topk_indices = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)
return topk_indices, topk_scores
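# Note on gating shapes: for x of shape (batch, seq_len, d_hidden), MoEGate returns expert
# indices and scores of shape (batch, seq_len, n_experts_per_token); both are flattened to
# (batch * seq_len * n_experts_per_token,) before the grouped GEMM in MoE.forward below.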
class MoE(nn.Module):
def __init__(
self, config: Dict, shared_kernel: tilelang.JITKernel, routed_kernel: tilelang.JITKernel, weights: Dict, padding_M: int = 128
):
super().__init__()
self.config = config
self.shared_kernel = shared_kernel
self.routed_kernel = routed_kernel
self.padding_M = padding_M
self.experts = nn.ModuleList(
[
Expert(
config,
gate=weights[f"experts.{i}.0.weight"],
up=weights[f"experts.{i}.1.weight"],
down=weights[f"experts.{i}.2.weight"],
)
for i in range(config["n_routed_experts"])
]
)
self.device = torch.device("cuda")
self.gating_network = MoEGate(config, weights).to(self.device)
shared_expert_dim = config["d_expert"] * config["n_shared_experts"]
self.shared_expert = Expert(
config=config,
gate=weights["shared_experts.0.weight"],
up=weights["shared_experts.1.weight"],
down=weights["shared_experts.2.weight"],
d_expert=shared_expert_dim,
).to(self.device)
self.expert_cache = torch.zeros(
(config["batch_size"] * config["seq_len"], config["d_hidden"]), dtype=torch.float16, device=self.device
)
self.stacked_expert_w_gate = torch.stack([expert.W_gate_weight for expert in self.experts], dim=0)
self.stacked_expert_w_up = torch.stack([expert.W_up_weight for expert in self.experts], dim=0)
self.stacked_expert_w_down = torch.stack([expert.W_down_weight for expert in self.experts], dim=0)
self.stacked_expert_tokens = torch.empty(
(config["batch_size"] * config["seq_len"] * config["n_experts_per_token"], self.config["d_hidden"]),
dtype=torch.float16,
device=self.device,
)
self.stacked_expert_weights = torch.empty(
(config["batch_size"] * config["seq_len"] * config["n_experts_per_token"]), dtype=torch.float16, device=self.device
)
self.stacked_expert_tokens_idxs = torch.empty(
(config["batch_size"] * config["seq_len"] * config["n_experts_per_token"]), dtype=torch.int64, device=self.device
)
self.up_logits_shared = torch.empty(
(config["batch_size"] * config["seq_len"], self.config["d_expert"]), dtype=torch.float16, device=self.device
)
self.expert_output_shared = torch.empty(
(config["batch_size"] * config["seq_len"], self.config["d_hidden"]), dtype=torch.float16, device=self.device
)
self.up_logits_routed = torch.empty(
(config["batch_size"] * config["seq_len"] * config["n_experts_per_token"], self.config["d_expert"]),
dtype=torch.float16,
device=self.device,
)
self.expert_output_routed = torch.empty(
(config["batch_size"] * config["seq_len"] * config["n_experts_per_token"], self.config["d_hidden"]),
dtype=torch.float16,
device=self.device,
)
@torch.no_grad()
def forward(self, x: torch.Tensor) -> torch.Tensor:
orig_shape = x.shape
batch_size, seq_len, hidden_dim = x.shape
expert_indices, expert_scores = self.gating_network(x)
flat_expert_indices = expert_indices.view(-1)
flat_expert_weights = expert_scores.view(-1)
x_flat = x.view(-1, hidden_dim)
# Prepare for grouped GEMM
idxs = flat_expert_indices.argsort()
counts = flat_expert_indices.bincount().cpu().numpy()
# counts = flat_expert_indices.bincount()
tokens_per_expert = counts.cumsum()
# tokens_per_expert = torch.cumsum(counts, dim=0)
num_per_tok = self.config["n_experts_per_token"]
token_idxs = idxs // num_per_tok
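# Illustrative (hypothetical) example of the grouping trick: with top_k = 2 and per-token
# expert choices [[2, 0], [1, 2], [0, 1]], flat_expert_indices = [2, 0, 1, 2, 0, 1].
# argsort groups the slots by expert, e.g. idxs = [1, 4, 2, 5, 0, 3] (order within a group
# may vary), and idxs // top_k recovers the owning token of each slot:
# token_idxs = [0, 2, 1, 2, 0, 1], i.e. expert 0 sees tokens 0 and 2, and so on.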
# Get stacked expert tokens and expert weights
for expert_id, end_idx in enumerate(tokens_per_expert):
start_idx = 0 if expert_id == 0 else tokens_per_expert[expert_id - 1]
if start_idx == end_idx:
continue
exp_token_idxs = token_idxs[start_idx:end_idx]
expert_tokens = x_flat[exp_token_idxs]
self.stacked_expert_tokens[start_idx:end_idx] = expert_tokens
self.stacked_expert_tokens_idxs[start_idx:end_idx] = exp_token_idxs
self.stacked_expert_weights[start_idx:end_idx] = flat_expert_weights[idxs[start_idx:end_idx]]
group_sizes = torch.tensor(counts, dtype=torch.int32, device=self.device)
group_offset = torch.tensor(tokens_per_expert - counts, dtype=torch.int32, device=self.device)
group_padded_offsets = [0 for _ in range(len(group_sizes))]
for i in range(1, len(group_sizes)):
group_padded_offsets[i] = group_padded_offsets[i - 1] + math.ceil((counts[i - 1] + 1) / self.padding_M) * self.padding_M
block_token = 128
M = (
math.ceil(self.config["batch_size"] * self.config["seq_len"] * self.config["n_experts_per_token"] / block_token)
+ self.config["n_routed_experts"]
)
group_idx_for_bx = [0 for _ in range(M)]
for bx in range(M):
m_start_padded = bx * block_token
for i in range(self.config["n_routed_experts"]):
if m_start_padded >= group_padded_offsets[i]:
group_idx_for_bx[bx] = i
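# Hypothetical example: with padding_M = 128 and per-expert token counts [300, 80, 260],
# group_padded_offsets = [0, 384, 512], so row-blocks bx = 0..2 map to expert 0, bx = 3 to
# expert 1, and bx >= 4 to expert 2; the kernel uses group_idx_for_bx to select the right
# expert weights for each block.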
group_padded_offsets = torch.tensor(group_padded_offsets, dtype=torch.int32, device=self.device)
group_idx_for_bx = torch.tensor(group_idx_for_bx, dtype=torch.int32, device=self.device)
# Multi-stream execution
shared_stream = torch.cuda.Stream()
routed_stream = torch.cuda.default_stream()
torch.cuda.synchronize()
with torch.cuda.stream(routed_stream):
# Tilelang version: Grouped GEMM
self.routed_kernel(
self.stacked_expert_tokens,
self.stacked_expert_w_gate,
self.stacked_expert_w_up,
self.stacked_expert_w_down,
self.stacked_expert_weights,
group_sizes,
group_offset,
group_padded_offsets,
group_idx_for_bx,
self.up_logits_routed,
self.expert_output_routed,
)
# Scatter reduce
self.expert_cache = torch.scatter_reduce(
self.expert_cache,
0,
self.stacked_expert_tokens_idxs.view(-1, 1).repeat(1, x_flat.shape[-1]),
self.expert_output_routed,
reduce="sum",
)
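# scatter_reduce sums each expert-output row back into the row of its source token: the index
# tensor repeats every token index across the hidden dimension, so row i of
# expert_output_routed is added to expert_cache[stacked_expert_tokens_idxs[i], :].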
routed_output = self.expert_cache.view(*orig_shape)
with torch.cuda.stream(shared_stream):
self.shared_kernel(
x_flat,
self.shared_expert.W_gate_weight,
self.shared_expert.W_up_weight,
self.shared_expert.W_down_weight,
self.up_logits_shared,
self.expert_output_shared,
)
shared_output = self.expert_output_shared.view(*orig_shape)
torch.cuda.synchronize()
return shared_output + routed_output
def custom_kernel(data: Tuple[torch.Tensor, Dict, Dict]) -> torch.Tensor:
"""
DeepSeek-style Mixture of Experts using Tilelang.
Args:
data: Tuple of (input: torch.Tensor, weights: Dict[str, torch.Tensor], config: Dict)
- input: Input tensor of shape [batch_size, seq_len, hidden_size]
- weights: Dictionary containing model weights
- config: Dictionary containing model configuration parameters
Returns:
output: Processed tensor of shape [batch_size, seq_len, d_hidden]
"""
input_tensor, weights, config = data
dtype_str = T.float16
shared_kernel = moe_forward_tilelang_shared(
config["d_hidden"],
config["d_expert"],
config["n_shared_experts"],
dtype=dtype_str,
num_tokens=config["batch_size"] * config["seq_len"],
)
routed_kernel = moe_forward_tilelang_routed(
config["d_hidden"],
config["d_expert"],
config["n_routed_experts"],
dtype=dtype_str,
group_sum=config["batch_size"] * config["seq_len"] * config["n_experts_per_token"],
group_count=config["n_routed_experts"],
block_token=128,
block_dhidden=128,
block_dexpert=128,
threads=256,
num_stages=1,
k_pack=1,
coalesced_width=2,
)
moe = MoE(config, shared_kernel, routed_kernel, weights, padding_M=128)
output = moe(input_tensor)
return output
def main(d_hidden=7168, d_expert=2048, n_routed_experts=8, n_shared_experts=1, n_experts_per_token=4, batch_size=1, seq_len=8192):
config = {
"dhidden": d_hidden,
"dexpert": d_expert,
"nroutedexperts": n_routed_experts,
"nsharedexperts": n_shared_experts,
"nexpertspertoken": n_experts_per_token,
"bs": batch_size,
"seqlen": seq_len,
"seed": 81394,
}
data = generate_input(**config)
torch.cuda.synchronize()
ref_output = ref_kernel(clone_data(data)).to(torch.float32)
torch.cuda.synchronize()
tilelang_output = custom_kernel(clone_data(data)).to(torch.float32)
torch.cuda.synchronize()
torch.testing.assert_close(ref_output, tilelang_output, atol=1e-2, rtol=1e-2)
print("✅ Tilelang and Torch match")
if __name__ == "__main__":
main()
import math
import torch
import torch.nn as nn
from typing import Dict, Tuple, Optional
# Reference code in PyTorch
class ExpertTorch(nn.Module):
def __init__(self, config: Dict, d_expert: Optional[int] = None):
super().__init__()
self.config = config
self.act_fn = nn.SiLU()
self.d_hidden: int = config["d_hidden"]
self.d_expert: int = config["d_expert"] if d_expert is None else d_expert
self.W_gate = nn.Linear(self.d_hidden, self.d_expert, bias=False)
self.W_up = nn.Linear(self.d_hidden, self.d_expert, bias=False)
self.W_down = nn.Linear(self.d_expert, self.d_hidden, bias=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
gate = self.act_fn(self.W_gate(x))
out = self.W_down(gate * self.W_up(x))
return out
class MoEGateTorch(nn.Module):
def __init__(self, config: Dict):
super().__init__()
self.top_k: int = config["n_experts_per_token"]
self.num_experts: int = config["n_routed_experts"]
self.d_hidden: int = config["d_hidden"]
self.W_g = nn.Linear(self.d_hidden, self.num_experts, bias=False)
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
logits = self.W_g(x)
scores = logits.softmax(dim=-1)
topk_scores, topk_indices = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)
return topk_indices, topk_scores
class MoETorch(nn.Module):
def __init__(self, config: Dict):
super().__init__()
self.config = config
self.experts = nn.ModuleList([ExpertTorch(config) for _ in range(config["n_routed_experts"])])
self.gating_network = MoEGateTorch(config)
shared_expert_dim = config["d_expert"] * config["n_shared_experts"]
self.shared_expert = ExpertTorch(config=config, d_expert=shared_expert_dim)
def forward(self, x: torch.Tensor) -> torch.Tensor:
shared_output = self.shared_expert(x)
expert_indices, expert_scores = self.gating_network(x)
batch_size, seq_len, hidden_dim = x.shape
orig_shape = x.shape
x_flat = x.view(-1, hidden_dim)
flat_expert_indices = expert_indices.view(-1)
flat_expert_weights = expert_scores.view(-1, 1)
routed_output_flat = self.moe_infer(x_flat, flat_expert_indices, flat_expert_weights)
routed_output = routed_output_flat.view(*orig_shape)
return routed_output + shared_output
@torch.no_grad()
def moe_infer(self, x: torch.Tensor, flat_expert_indices: torch.Tensor, flat_expert_weights: torch.Tensor) -> torch.Tensor:
expert_cache = torch.zeros_like(x)
# test_expert_cache = torch.zeros((x.shape[0] * self.config["n_experts_per_token"], self.config["d_hidden"]))
# test_expert_tokens = torch.zeros((x.shape[0] * self.config["n_experts_per_token"], self.config["d_hidden"]))
# test_expert_ups = torch.zeros((self.config["n_routed_experts"], self.config["d_hidden"], self.config["d_expert"]))
# test_expert_tokens_num = torch.zeros((self.config["n_routed_experts"]))
idxs = flat_expert_indices.argsort()
counts = flat_expert_indices.bincount().cpu().numpy()
tokens_per_expert = counts.cumsum()
num_per_tok = self.config["n_experts_per_token"]
token_idxs = idxs // num_per_tok
for expert_id, end_idx in enumerate(tokens_per_expert):
start_idx = 0 if expert_id == 0 else tokens_per_expert[expert_id - 1]
if start_idx == end_idx:
continue
expert = self.experts[expert_id]
exp_token_idxs = token_idxs[start_idx:end_idx]
expert_tokens = x[exp_token_idxs]
expert_out = expert(expert_tokens)
expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])
expert_cache.scatter_reduce_(0, exp_token_idxs.view(-1, 1).repeat(1, x.shape[-1]), expert_out, reduce="sum")
return expert_cache
def ref_kernel(data: Tuple[torch.Tensor, Dict, Dict]) -> torch.Tensor:
"""
Reference implementation of DeepSeek-style Mixture of Experts using PyTorch.
Args:
data: Tuple of (input: torch.Tensor, weights: Dict[str, torch.Tensor], config: Dict)
- input: Input tensor of shape [batch_size, seq_len, hidden_dim]
- weights: Dictionary containing model weights
- config: Dictionary containing model configuration parameters
Returns:
output: Processed tensor of shape [batch_size, seq_len, d_hidden]
"""
input_tensor, weights, config = data
num_experts = config["n_routed_experts"]
moe = MoETorch(config)
# Fill in the given weights of the model
moe.gating_network.W_g.weight = nn.Parameter(weights["router.weight"])
for i in range(num_experts):
gate_proj_weight = weights[f"experts.{i}.0.weight"]
up_proj_weight = weights[f"experts.{i}.1.weight"]
down_proj_weight = weights[f"experts.{i}.2.weight"]
# Transpose weights to match expected shape for nn.Linear
moe.experts[i].W_gate.weight = nn.Parameter(gate_proj_weight.t())
moe.experts[i].W_up.weight = nn.Parameter(up_proj_weight.t())
moe.experts[i].W_down.weight = nn.Parameter(down_proj_weight.t())
moe.shared_expert.W_gate.weight = nn.Parameter(weights["shared_experts.0.weight"].t())
moe.shared_expert.W_up.weight = nn.Parameter(weights["shared_experts.1.weight"].t())
moe.shared_expert.W_down.weight = nn.Parameter(weights["shared_experts.2.weight"].t())
output = moe(input_tensor)
return output
# Input generation for the reference code
def generate_input(
dhidden: int, dexpert: int, nroutedexperts: int, nsharedexperts: int, nexpertspertoken: int, bs: int, seqlen: int, seed: int
) -> Tuple[torch.Tensor, Dict, Dict]:
# Workaround: underscores in these argument names are not parsed correctly upstream, so map the flattened names back to the snake_case names used below.
d_hidden = dhidden
d_expert = dexpert
n_routed_experts = nroutedexperts
n_shared_experts = nsharedexperts
n_experts_per_token = nexpertspertoken
batch_size = bs
seq_len = seqlen
config = {
"d_hidden": d_hidden,
"d_expert": d_expert,
"n_routed_experts": n_routed_experts,
"n_shared_experts": n_shared_experts,
"n_experts_per_token": n_experts_per_token,
"batch_size": batch_size,
"seq_len": seq_len,
}
gen = torch.Generator(device="cuda")
gen.manual_seed(seed)
num_experts = n_routed_experts
expert_dim = d_expert
weights = {}
input_tensor = torch.randn((batch_size, seq_len, d_hidden), device="cuda", dtype=torch.float16, generator=gen).contiguous()
# Initialize router weights
weights["router.weight"] = torch.randn((num_experts, d_hidden), device="cuda", dtype=torch.float16, generator=gen) / math.sqrt(d_hidden)
for i in range(num_experts):
weights[f"experts.{i}.0.weight"] = torch.randn(
(d_hidden, expert_dim), device="cuda", dtype=torch.float16, generator=gen
) / math.sqrt(expert_dim)
weights[f"experts.{i}.1.weight"] = torch.randn(
(d_hidden, expert_dim), device="cuda", dtype=torch.float16, generator=gen
) / math.sqrt(expert_dim)
weights[f"experts.{i}.2.weight"] = torch.randn(
(expert_dim, d_hidden), device="cuda", dtype=torch.float16, generator=gen
) / math.sqrt(d_hidden)
weights["shared_experts.0.weight"] = torch.randn(
(d_hidden, expert_dim * n_shared_experts), device="cuda", dtype=torch.float16, generator=gen
) / math.sqrt(expert_dim * n_shared_experts)
weights["shared_experts.1.weight"] = torch.randn(
(d_hidden, expert_dim * n_shared_experts), device="cuda", dtype=torch.float16, generator=gen
) / math.sqrt(expert_dim * n_shared_experts)
weights["shared_experts.2.weight"] = torch.randn(
(expert_dim * n_shared_experts, d_hidden), device="cuda", dtype=torch.float16, generator=gen
) / math.sqrt(d_hidden)
return (input_tensor, weights, config)
def clone_data(data):
"""
Recursively goes through data and clones all tensors.
"""
if isinstance(data, tuple):
return tuple(clone_data(x) for x in data)
elif isinstance(data, list):
return [clone_data(x) for x in data]
elif isinstance(data, dict):
return {k: clone_data(v) for k, v in data.items()}
elif isinstance(data, torch.Tensor):
return data.clone()
else:
return data
import tilelang.testing
import example_fusedmoe_tilelang
def test_example_fusedmoe_tilelang():
example_fusedmoe_tilelang.main(
d_hidden=1024, d_expert=256, n_routed_experts=8, n_shared_experts=1, n_experts_per_token=4, batch_size=1, seq_len=1024
)
if __name__ == "__main__":
tilelang.testing.main()
# Gated Delta Net (GDN) kernel implementation with TileLang
## Requirements
- TileLang: `0.1.5+17fafc1b3026d910a83eb8052fdf811ba56be0b1`
- Triton: `3.3.0` (used for comparison)
- FLA: commit `f03cb3ae` (used for comparison)
## Get started
The [chunk_delta_h](common/chunk_delta_h.py) file implements the most critical forward kernel of GDN; it is a good starting point for understanding the GDN logic and the TileLang optimizations.
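A minimal way to run the example end to end (a hedged sketch: it assumes a CUDA GPU with TileLang and PyTorch installed, FLA optionally on `PYTHONPATH` for the reference comparison, and that `common/chunk_delta_h.py` keeps its `if __name__ == "__main__"` entry point):

```python
# Execute the example as a script; following the pattern of the kernels in this
# commit, it is expected to compile the TileLang kernel, compare against the FLA
# reference when available, and print simple timing numbers.
import runpy

runpy.run_path("common/chunk_delta_h.py", run_name="__main__")
```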
## Acknowledgments
This kernel was developed by Yu Cheng and Zhengju Tang following in-depth discussions with Xiaomi's LLM-Core Team (MiMo).
# Reference: fla/ops/common/chunk_delta_h.py
import sys # noqa: F401
import tilelang
import tilelang.language as T
print(tilelang.__file__, flush=True)
# Add your fla repository path to sys.path
# We currently use the fla repository from the flash-linear-attention project at commit f03cb3ae
# sys.path.insert(0, "/home/tzj/flash-linear-attention")
try:
import fla
print(fla.__file__, flush=True)
from fla.ops.common.chunk_delta_h import chunk_gated_delta_rule_bwd_dhu
except ImportError:
print("fla not found, using tilelang implementation")
fla = None
import torch
import torch.nn.functional as F
torch.random.manual_seed(0)
# torch.set_printoptions(profile="full")
from test_utils import assert_similar
def prepare_input(
B,
S,
H,
DK,
DV,
chunk_size,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
state_dtype,
):
Q = torch.randn(B, S, H, DK, dtype=input_dtype).cuda()
K = torch.randn(B, S, H, DK, dtype=input_dtype).cuda()
K = F.normalize(K, dim=-1, p=2)
W = torch.randn(B, S, H, DK, dtype=input_dtype).cuda()
# Note: G should be in log space, with a chunkwise cumulative sum applied (done below via chunk_local_cumsum)
G = torch.randn(B, S, H, dtype=gate_dtype).cuda()
G = F.logsigmoid(G)
try:
from fla.ops.utils.cumsum import chunk_local_cumsum
G = chunk_local_cumsum(G, chunk_size)
except ImportError:
print("fla not found, skip cumsum")
h0 = torch.randn(B, H, DK, DV, dtype=input_dtype).cuda()
dht = torch.randn(B, H, DK, DV, dtype=input_dtype).cuda()
dO = torch.randn(B, S, H, DV, dtype=input_dtype).cuda()
dv = torch.randn(B, S, H, DV, dtype=input_dtype).cuda()
return Q, K, W, G, h0, dht, dO, dv
def prepare_input_fake(
B,
S,
H,
DK,
DV,
chunk_size,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
state_dtype,
):
Q = torch.ones(B, S, H, DK, dtype=input_dtype).cuda()
K = torch.ones(B, S, H, DK, dtype=input_dtype).cuda()
W = torch.ones(B, S, H, DK, dtype=input_dtype).cuda()
G = torch.ones(B, S, H, dtype=gate_dtype).cuda()
h0 = torch.ones(B, H, DK, DV, dtype=input_dtype).cuda()
dht = torch.ones(B, H, DK, DV, dtype=input_dtype).cuda()
dO = torch.ones(B, S, H, DV, dtype=input_dtype).cuda()
dv = torch.ones(B, S, H, DV, dtype=input_dtype).cuda()
return Q, K, W, G, h0, dht, dO, dv
def prepare_output(
B,
S,
H,
DK,
DV,
chunk_size,
output_dtype,
gate_dtype,
state_dtype,
):
BS = S // chunk_size
dh = torch.empty(B, BS, H, DK, DV, dtype=output_dtype).cuda()
dh0 = torch.empty(B, H, DK, DV, dtype=state_dtype).cuda()
dv2 = torch.empty(B, S, H, DV, dtype=output_dtype).cuda()
return dh, dh0, dv2
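# The reference below walks the chunks in reverse with a running state b_dh (reading directly
# from the code; g_last is the last gate value of the current chunk, g the per-row gates):
#   dh[chunk]  = b_dh                                  (state before the update)
#   dv2[chunk] = dv[chunk] + exp(g_last - g) * (K[chunk] @ b_dh)
#                (the factor is forced to 0 where g_last > g, which a valid chunkwise cumsum
#                 should not produce)
#   b_dh      <- exp(g_last) * b_dh
#                + (scale * exp(g) * Q[chunk])^T @ dO[chunk] - W[chunk]^T @ dv2[chunk]
# b_dh is seeded from dht when use_final_state_gradient is set, and dh0 is the final b_dh
# when use_initial_state is set.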
def torch_chunk_gated_delta_rule_bwd_dhu(
Q: torch.Tensor,
K: torch.Tensor,
W: torch.Tensor,
G: torch.Tensor,
h0: torch.Tensor,
dht: torch.Tensor,
dO: torch.Tensor,
dv: torch.Tensor,
scale: float,
use_g: bool,
use_initial_state: bool,
use_final_state_gradient: bool,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
state_dtype,
):
B, S, H, DK = Q.shape
DV = dv.shape[-1]
block_S = 64
BS = S // block_S
dh, dh0, dv2 = (
torch.empty((B, BS, H, DK, DV), dtype=output_dtype),
torch.empty((B, H, DK, DV), dtype=state_dtype),
torch.empty((B, S, H, DV), dtype=output_dtype),
)
dh_tmp = torch.empty((B, H, DK, DV), dtype=accum_dtype)
dv_tmp = torch.empty((B, S, H, DV), dtype=accum_dtype)
Q_tmp = torch.empty((B, S, H, DK), dtype=accum_dtype)
if use_final_state_gradient:
dh_tmp = dht.clone().to(accum_dtype)
else:
dh_tmp = torch.zeros_like(dht).to(accum_dtype)
for i_s in range(BS - 1, -1, -1):
dh[:, i_s, :, :, :] = dh_tmp
dv_tmp = torch.matmul(K[:, i_s * block_S : (i_s + 1) * block_S, :, :].permute(0, 2, 1, 3), dh_tmp.to(K.dtype)).permute(0, 2, 1, 3)
if use_g:
for i_bh in range(B * H):
i_b, i_h = i_bh // H, i_bh % H
for i_s2 in range(block_S):
if G[i_b, i_s * block_S + block_S - 1, i_h] - G[i_b, i_s * block_S + i_s2, i_h] <= 0:
dv_tmp[i_b, i_s2, i_h, :] *= torch.exp(G[i_b, i_s * block_S + block_S - 1, i_h] - G[i_b, i_s * block_S + i_s2, i_h])
else:
dv_tmp[i_b, i_s2, i_h, :] = 0
dv_tmp += dv[:, i_s * block_S : (i_s + 1) * block_S, :, :]
dv2[:, i_s * block_S : (i_s + 1) * block_S, :, :] = dv_tmp
if use_g:
G_last = G[:, i_s * block_S + block_S - 1, :]
for i_bh in range(B * H):
i_b, i_h = i_bh // H, i_bh % H
dh_tmp[i_b, i_h, :, :] *= torch.exp(G_last[i_b, i_h])
Q_tmp = Q[:, i_s * block_S : (i_s + 1) * block_S, :, :].clone()  # clone: the in-place gate/scale updates below must not mutate the input Q
for i_s2 in range(block_S):
for i_k in range(DK):
Q_tmp[:, i_s2, :, i_k] *= torch.exp(G[:, i_s * block_S + i_s2, :])
Q_tmp *= scale
W_tmp = W[:, i_s * block_S : (i_s + 1) * block_S, :, :]
dO_tmp = dO[:, i_s * block_S : (i_s + 1) * block_S, :, :]
torch.backends.cuda.matmul.allow_tf32 = True
dh_tmp += torch.matmul(Q_tmp.permute(0, 2, 3, 1), dO_tmp.permute(0, 2, 1, 3))
dh_tmp -= torch.matmul(W_tmp.permute(0, 2, 3, 1), dv_tmp.permute(0, 2, 1, 3))
torch.backends.cuda.matmul.allow_tf32 = False
if use_initial_state:
dh0 = dh_tmp[:, :, :, :]
else:
dh0 = torch.zeros_like(dh_tmp[:, :, :, :])
print(dh0.dtype)
return dh, dh0, dv2
@tilelang.jit(out_idx=[-3, -2, -1])
def tilelang_chunk_gated_delta_rule_bwd_dhu(
# task config
B,
S,
H,
DK,
DV,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
state_dtype,
chunk_size,
scale,
use_g=True,
use_initial_state=True,
use_final_state_gradient=True,
# kernel config
block_DV=64,
threads=256,
num_stages=0,
):
block_S = chunk_size
# Should support cu_seqlen
BS = S // block_S
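# Assumption: S is a multiple of chunk_size; the kernel iterates over exact chunks and does
# not yet handle ragged / cu_seqlens inputs (see the note above).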
Q_shape = (B, S, H, DK)
K_shape = (B, S, H, DK)
W_shape = (B, S, H, DK)
G_shape = (B, S, H)
h0_shape = (B, H, DK, DV)
dht_shape = (B, H, DK, DV)
dO_shape = (B, S, H, DV)
dv_shape = (B, S, H, DV)
dh_shape = (B, BS, H, DK, DV)
dh0_shape = (B, H, DK, DV)
dv2_shape = (B, S, H, DV)
@T.prim_func
def kernel(
# Input
Q: T.Tensor(Q_shape, dtype=input_dtype),
K: T.Tensor(K_shape, dtype=input_dtype),
W: T.Tensor(W_shape, dtype=input_dtype),
G: T.Tensor(G_shape, dtype=gate_dtype),
h0: T.Tensor(h0_shape, dtype=input_dtype),
dht: T.Tensor(dht_shape, dtype=input_dtype),
dO: T.Tensor(dO_shape, dtype=input_dtype),
dv: T.Tensor(dv_shape, dtype=input_dtype),
# Output
dh: T.Tensor(dh_shape, dtype=output_dtype),
dh0: T.Tensor(dh0_shape, dtype=state_dtype),
dv2: T.Tensor(dv2_shape, dtype=output_dtype),
):
with T.Kernel(T.ceildiv(DV, block_DV), B * H, threads=threads) as (bv, bbh):
bb, bh = bbh // H, bbh % H
b_dh_shared = T.alloc_shared((DK, block_DV), dtype=output_dtype)
b_dh_shared_fp32 = T.alloc_shared((DK, block_DV), dtype=state_dtype)
b_dh_fragment = T.alloc_fragment((DK, block_DV), dtype=accum_dtype)
b_dh_fragment_1 = T.alloc_fragment((DK, block_DV), dtype=accum_dtype)
b_dh_fragment_2 = T.alloc_fragment((DK, block_DV), dtype=accum_dtype)
dv_shared = T.alloc_shared((block_S, block_DV), dtype=input_dtype)
dv_fragment = T.alloc_fragment((block_S, block_DV), dtype=accum_dtype)
dv_fragment_2 = T.alloc_fragment((block_S, block_DV), dtype=accum_dtype)
dO_shared = T.alloc_shared((block_S, block_DV), dtype=input_dtype)
dO_shared_t = T.alloc_shared((block_DV, block_S), dtype=T.float32)
dO_fragment = T.alloc_fragment((block_S, block_DV), dtype=T.float32)
dO_fragment_t = T.alloc_fragment((block_DV, block_S), dtype=T.float32)
K_shared = T.alloc_shared((block_S, DK), dtype=input_dtype)
Q_shared = T.alloc_shared((block_S, DK), dtype=input_dtype)
Q_shared_fp32 = T.alloc_shared((block_S, DK), dtype=T.float32)
W_shared = T.alloc_shared((block_S, DK), dtype=input_dtype)
G_last_local = T.alloc_local((1), dtype=gate_dtype)
G_last_local_exp = T.alloc_local((1), dtype=gate_dtype)
G_shared = T.alloc_shared((block_S), dtype=gate_dtype, scope="shared")
G_fragment = T.alloc_fragment((block_S), dtype=gate_dtype)
G_fragment_post = T.alloc_fragment((block_S), dtype=gate_dtype)
G_fragment_exp = T.alloc_fragment((block_S), dtype=gate_dtype)
Q_fragment = T.alloc_fragment((block_S, DK), dtype=accum_dtype)
Q_fragment_t = T.alloc_fragment((DK, block_S), dtype=accum_dtype)
T.use_swizzle(10)
T.annotate_layout(
{
b_dh_shared: tilelang.layout.make_swizzled_layout(b_dh_shared),
b_dh_shared_fp32: tilelang.layout.make_swizzled_layout(b_dh_shared_fp32),
dv_shared: tilelang.layout.make_swizzled_layout(dv_shared),
dO_shared: tilelang.layout.make_swizzled_layout(dO_shared),
dO_shared_t: tilelang.layout.make_swizzled_layout(dO_shared_t),
K_shared: tilelang.layout.make_swizzled_layout(K_shared),
Q_shared: tilelang.layout.make_swizzled_layout(Q_shared),
Q_shared_fp32: tilelang.layout.make_swizzled_layout(Q_shared_fp32),
W_shared: tilelang.layout.make_swizzled_layout(W_shared),
}
)
if use_final_state_gradient:
T.copy(dht[bb, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV], b_dh_shared)
T.copy(b_dh_shared, b_dh_fragment)
else:
T.clear(b_dh_fragment)
for i_s in T.Pipelined(T.ceildiv(S, block_S), num_stages=num_stages):
# Chunks are processed in reverse, so per-chunk gradients are stored at the reversed index i_s_inv
i_s_inv = T.ceildiv(S, block_S) - i_s - 1
# Store the updated dh
T.copy(b_dh_fragment, b_dh_shared)
T.copy(b_dh_shared, dh[bb, i_s_inv, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV])
# Update dv
T.copy(K[bb, i_s_inv * block_S : (i_s_inv + 1) * block_S, bh, 0:DK], K_shared)
T.gemm(K_shared, b_dh_shared, dv_fragment, clear_accum=True)
if use_g:
T.copy(G[bb, i_s_inv * block_S : (i_s_inv + 1) * block_S, bh], G_shared, disable_tma=True)
T.copy(G_shared, G_fragment)
G_last_local[0] = G_shared[block_S - 1]
G_last_local_exp[0] = T.exp(G_last_local[0])
for i_s2 in T.Parallel(block_S):
G_fragment_post[i_s2] = T.exp(G_last_local[0] - G_fragment[i_s2])
for i_s2, i_v in T.Parallel(block_S, block_DV):
# with T.If(G_last_local[0] - G_shared[i_s2] <= 0):
with T.If(G_last_local[0] - G_fragment[i_s2] <= 0):
with T.Then():
dv_fragment[i_s2, i_v] = dv_fragment[i_s2, i_v] * G_fragment_post[i_s2]
with T.Else():
dv_fragment[i_s2, i_v] = 0
T.copy(dv[bb, i_s_inv * block_S : (i_s_inv + 1) * block_S, bh, bv * block_DV : (bv + 1) * block_DV], dv_shared)
T.copy(dv_shared, dv_fragment_2)
for i_s2, i_v in T.Parallel(block_S, block_DV):
dv_fragment[i_s2, i_v] = dv_fragment[i_s2, i_v] + dv_fragment_2[i_s2, i_v]
# Store the updated dv
T.copy(dv_fragment, dv_shared)
T.copy(dv_shared, dv2[bb, i_s_inv * block_S : (i_s_inv + 1) * block_S, bh, bv * block_DV : (bv + 1) * block_DV])
# Update dh
T.copy(Q[bb, i_s_inv * block_S : (i_s_inv + 1) * block_S, bh, 0:DK], Q_shared)
T.copy(W[bb, i_s_inv * block_S : (i_s_inv + 1) * block_S, bh, 0:DK], W_shared)
T.clear(Q_fragment)
if use_g:
for i_k, i_v in T.Parallel(DK, block_DV):
b_dh_fragment[i_k, i_v] *= G_last_local_exp[0]
T.copy(Q_shared, Q_fragment)
for i_s2 in T.Parallel(block_S):
G_fragment_exp[i_s2] = T.exp(G_shared[i_s2])
for i_s2, i_k in T.Parallel(block_S, DK):
# Q_fragment[i_s2, i_k] = Q_fragment[i_s2, i_k] * T.exp(G_shared[i_s2]) * scale
Q_fragment[i_s2, i_k] = Q_fragment[i_s2, i_k] * G_fragment_exp[i_s2] * scale
else:
T.copy(Q_shared, Q_fragment)
for i_s2, i_k in T.Parallel(block_S, DK):
Q_fragment[i_s2, i_k] = Q_fragment[i_s2, i_k] * scale
# Get transpose of Q_fragment to meet tf32 gemm requirement
for i_s2, i_k in T.Parallel(block_S, DK):
Q_fragment_t[i_k, i_s2] = Q_fragment[i_s2, i_k]
T.copy(dO[bb, i_s_inv * block_S : (i_s_inv + 1) * block_S, bh, bv * block_DV : (bv + 1) * block_DV], dO_shared)
T.copy(dO_shared, dO_fragment)
for i_s2, i_v in T.Parallel(block_S, block_DV):
dO_fragment_t[i_v, i_s2] = dO_fragment[i_s2, i_v]
T.copy(dO_fragment_t, dO_shared_t)
T.clear(b_dh_fragment_1)
T.gemm(Q_fragment_t, dO_shared_t, b_dh_fragment_1, transpose_B=True)
T.clear(b_dh_fragment_2)
T.gemm(W_shared, dv_shared, b_dh_fragment_2, transpose_A=True)
for i_k, i_v in T.Parallel(DK, block_DV):
b_dh_fragment[i_k, i_v] += b_dh_fragment_1[i_k, i_v] - b_dh_fragment_2[i_k, i_v]
if use_initial_state:
T.copy(b_dh_fragment, dh0[bb, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV])
return kernel
def test_result(dh_0, dh0_0, dv2_0, dh_1, dh0_1, dv2_1, name):
try:
torch.testing.assert_close(dh_0, dh_1, rtol=1e-2, atol=1e-2, equal_nan=True)
print(f"{name}: dh_0 and dh_1 match")
except Exception as e:
print(f"{name}: dh_0 and dh_1 are not close")
print(e, end="\n\n")
try:
torch.testing.assert_close(dh0_0, dh0_1, rtol=1e-2, atol=1e-2, equal_nan=True)
print(f"{name}: dh0_0 and dh0_1 match")
except Exception as e:
print(f"{name}: dh0_0 and dh0_1 are not close")
print(e, end="\n\n")
try:
torch.testing.assert_close(dv2_0, dv2_1, rtol=1e-2, atol=1e-2, equal_nan=True)
print(f"{name}: dv2_0 and dv2_1 match")
except Exception as e:
print(f"{name}: dv2_0 and dv2_1 are not close")
print(e, end="\n\n")
close = torch.isclose(dh_0, dh_1, rtol=1e-2, atol=1e-2)
mismatch_indices = torch.nonzero(~close, as_tuple=True)
error_num = 0
for indices in zip(*mismatch_indices):
if error_num < 100:
print(
f"{name} dh_0[{[idx.item() for idx in indices]}] = {dh_0[indices[0].item(), indices[1].item(), indices[2].item(), indices[3].item(), indices[4].item()]}, dh_1[{[idx.item() for idx in indices]}] = {dh_1[indices[0].item(), indices[1].item(), indices[2].item(), indices[3].item(), indices[4].item()]}"
)
error_num += 1
close = torch.isclose(dh0_0, dh0_1, rtol=1e-2, atol=1e-2)
mismatch_indices = torch.nonzero(~close, as_tuple=True)
error_num = 0
for indices in zip(*mismatch_indices):
if error_num < 100:
print(
f"{name} dh0_0[{[idx.item() for idx in indices]}] = {dh0_0[indices[0].item(), indices[1].item(), indices[2].item(), indices[3].item()]}, dh0_1[{[idx.item() for idx in indices]}] = {dh0_1[indices[0].item(), indices[1].item(), indices[2].item(), indices[3].item()]}"
)
error_num += 1
close = torch.isclose(dv2_0, dv2_1, rtol=1e-2, atol=1e-2)
mismatch_indices = torch.nonzero(~close, as_tuple=True)
error_num = 0
for indices in zip(*mismatch_indices):
if error_num < 100:
print(
f"{name} dv2_0[{[idx.item() for idx in indices]}] = {dv2_0[indices[0].item(), indices[1].item(), indices[2].item(), indices[3].item()]}, dv2_1[{[idx.item() for idx in indices]}] = {dv2_1[indices[0].item(), indices[1].item(), indices[2].item(), indices[3].item()]}"
)
error_num += 1
def run_test(
B,
S,
H,
DK,
DV,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
state_dtype,
chunk_size,
scale,
use_g=True,
use_initial_state=True,
use_final_state_gradient=True,
block_DV=64,
threads=256,
num_stages=0,
use_torch=False,
):
Q, K, W, G, h0, dht, dO, dv = prepare_input(
B,
S,
H,
DK,
DV,
chunk_size,
getattr(torch, input_dtype),
getattr(torch, output_dtype),
getattr(torch, accum_dtype),
getattr(torch, gate_dtype),
getattr(torch, state_dtype),
)
dh_ref, dh0_ref, dv2_ref = prepare_output(
B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype), getattr(torch, state_dtype)
)
dh_tilelang, dh0_tilelang, dv2_tilelang = prepare_output(
B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype), getattr(torch, state_dtype)
)
# fla ref
print("fla running...", flush=True)
if use_g:
dh_ref, dh0_ref, dv2_ref = chunk_gated_delta_rule_bwd_dhu(Q, K, W, G, h0, dht, dO, dv, scale)
else:
G = G.fill_(0)
dh_ref, dh0_ref, dv2_ref = chunk_gated_delta_rule_bwd_dhu(Q, K, W, G, h0, dht, dO, dv, scale)
# tilelang
print("tilelang running...", flush=True)
kernel = tilelang_chunk_gated_delta_rule_bwd_dhu(
B,
S,
H,
DK,
DV,
input_dtype,
output_dtype,
accum_dtype,
gate_dtype,
state_dtype,
chunk_size,
scale,
use_g,
use_initial_state,
use_final_state_gradient,
block_DV,
threads,
num_stages,
)
# kernel = tilelang.compile(program)
print(kernel.get_kernel_source())
dh_tilelang, dh0_tilelang, dv2_tilelang = kernel(Q, K, W, G, h0, dht, dO, dv)
fla_time = do_bench(chunk_gated_delta_rule_bwd_dhu, Q, K, W, G, h0, dht, dO, dv, scale, chunk_size=chunk_size)
tilelang_time = do_bench(kernel, Q, K, W, G, h0, dht, dO, dv)
print(f"fla time: {fla_time} ms")
print(f"tilelang time: {tilelang_time} ms")
assert_similar(dh_tilelang, dh_ref, 1e-5, "fla-tilelang", data="dh")
assert_similar(dh0_tilelang, dh0_ref, 1e-5, "fla-tilelang", data="dh0")
assert_similar(dv2_tilelang, dv2_ref, 1e-5, "fla-tilelang", data="dv2")
# torch ref
if use_torch:
print("torch running...", flush=True)
if use_g:
dh_ref_torch, dh0_ref_torch, dv2_ref_torch = torch_chunk_gated_delta_rule_bwd_dhu(
Q,
K,
W,
G,
h0,
dht,
dO,
dv,
scale,
use_g,
use_initial_state,
use_final_state_gradient,
getattr(torch, input_dtype),
getattr(torch, output_dtype),
getattr(torch, accum_dtype),
getattr(torch, gate_dtype),
getattr(torch, state_dtype),
)
dh_ref_torch = dh_ref_torch.cuda()
dh0_ref_torch = dh0_ref_torch.cuda()
dv2_ref_torch = dv2_ref_torch.cuda()
else:
dh_ref_torch, dh0_ref_torch, dv2_ref_torch = torch_chunk_gated_delta_rule_bwd_dhu(
Q,
K,
W,
None,
h0,
dht,
dO,
dv,
scale,
use_g,
use_initial_state,
use_final_state_gradient,
getattr(torch, input_dtype),
getattr(torch, output_dtype),
getattr(torch, accum_dtype),
getattr(torch, gate_dtype),
getattr(torch, state_dtype),
)
dh_ref_torch = dh_ref_torch.cuda()
dh0_ref_torch = dh0_ref_torch.cuda()
dv2_ref_torch = dv2_ref_torch.cuda()
assert_similar(dh_ref_torch, dh_ref, 1e-5, "torch-fla", data="dh")
assert_similar(dh0_ref_torch, dh0_ref, 1e-5, "torch-fla", data="dh0")
assert_similar(dv2_ref_torch, dv2_ref, 1e-5, "torch-fla", data="dv2")
assert_similar(dh_ref_torch, dh_tilelang, 1e-5, "torch-tilelang", data="dh")
assert_similar(dh0_ref_torch, dh0_tilelang, 1e-5, "torch-tilelang", data="dh0")
assert_similar(dv2_ref_torch, dv2_tilelang, 1e-5, "torch-tilelang", data="dv2")
def do_bench(fn, *args, warmup=10, rep=10, **kwargs):
"""
Do benchmark for a function.
"""
start_event = [torch.cuda.Event(enable_timing=True) for i in range(rep)]
end_event = [torch.cuda.Event(enable_timing=True) for i in range(rep)]
for _ in range(warmup):
fn(*args, **kwargs)
torch.cuda.synchronize()
for i in range(rep):
start_event[i].record()
fn(*args, **kwargs)
end_event[i].record()
torch.cuda.synchronize()
# Record clocks
times = torch.tensor(
[s.elapsed_time(e) for s, e in zip(start_event, end_event)],
dtype=torch.float,
)
return times.mean().item()
def main():
DK = 128
run_test(
B=1,
S=32768,
H=8,
DK=DK,
DV=128,
input_dtype=T.bfloat16,
output_dtype=T.bfloat16,
accum_dtype=T.float32,
gate_dtype=T.float32,
state_dtype=T.float32,
chunk_size=64,
scale=DK**-0.5,
use_g=True,
use_initial_state=True,
use_final_state_gradient=True,
block_DV=32,
threads=128,
num_stages=1,
use_torch=False,
)
if __name__ == "__main__":
main()