import torch
import torch.nn as nn
import torch.nn.functional as F
import triton
import triton.language as tl
import numpy as np
import random
import time


@triton.jit
def matmul_kernel(
    x_ptr, w_ptr, out_ptr,
    M, K, N,
    stride_xm, stride_xk,
    stride_wk, stride_wn,  # w is [K, N] (already transposed)
    stride_om, stride_on,
    BLOCK_M: tl.constexpr = 64,
    BLOCK_N: tl.constexpr = 64,
    BLOCK_K: tl.constexpr = 32,
):
    # Each program instance computes one BLOCK_M x BLOCK_N tile of the output.
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)
    x_ptrs = x_ptr + offs_m[:, None] * stride_xm + offs_k[None, :] * stride_xk
    w_ptrs = w_ptr + offs_k[:, None] * stride_wk + offs_n[None, :] * stride_wn
    # Accumulate in float32 for numerical stability; cast on store.
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(0, K, BLOCK_K):
        x_mask = (offs_m[:, None] < M) & (offs_k[None, :] < K - k)
        w_mask = (offs_k[:, None] < K - k) & (offs_n[None, :] < N)
        x = tl.load(x_ptrs, mask=x_mask, other=0.0)
        w = tl.load(w_ptrs, mask=w_mask, other=0.0)
        acc += tl.dot(x, w)
        x_ptrs += BLOCK_K * stride_xk
        w_ptrs += BLOCK_K * stride_wk
    # Cast the accumulator to bfloat16 for the output.
    out = acc.to(tl.bfloat16)
    out_ptrs = out_ptr + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
    tl.store(out_ptrs, out, mask=(offs_m[:, None] < M) & (offs_n[None, :] < N))


def triton_matmul(x, weight):
    """
    Compute y = x @ weight.T using Triton.
    x:      [M, K], dtype=bfloat16
    weight: [N, K], dtype=bfloat16 (a PyTorch Linear weight, shaped
            [out_features, in_features])
    Returns:
    y:      [M, N], dtype=bfloat16
    """
    assert x.dtype == torch.bfloat16
    assert weight.dtype == torch.bfloat16
    assert x.device == weight.device
    assert x.is_contiguous()
    M, K = x.shape
    N, K2 = weight.shape
    assert K == K2, f"K mismatch: {K} != {K2}"
    # Transpose the weight to [K, N] up front so the Triton kernel can use it
    # directly: weight is [N, K], and we need weight.T = [K, N].
    w_t = weight.t().contiguous()  # [K, N]
    out = torch.empty(M, N, dtype=torch.bfloat16, device=x.device)
    grid = lambda META: (
        triton.cdiv(M, META['BLOCK_M']),
        triton.cdiv(N, META['BLOCK_N']),
    )
    # Note: we pass the transposed weight w_t, shaped [K, N].
    matmul_kernel[grid](
        x, w_t, out,
        M, K, N,
        x.stride(0), x.stride(1),
        w_t.stride(0), w_t.stride(1),
        out.stride(0), out.stride(1),
        BLOCK_M=64, BLOCK_N=64, BLOCK_K=32,
    )
    return out


class ParallelGatedMLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.act = F.silu
        self.act_type = "silu"
        self.l1 = nn.Linear(
            in_features=4096,
            out_features=11264,
            bias=False,
        )
        self.l2 = nn.Linear(
            in_features=4096,
            out_features=11264,
            bias=False,
        )
        self.l3 = nn.Linear(
            in_features=11264,
            out_features=4096,
            bias=False,
        )

    def forward_org(self, z):
        shape = z.shape
        z_flat = z.view(-1, shape[-1])  # [M, K], here [1, 4096] in bfloat16
        # self.l1.weight is [11264, 4096], so y = z_flat @ l1.weight.T.
        y = F.linear(z_flat, self.l1.weight, bias=None)  # [M, N]
        return y

    def forward_org_triton(self, z):
        shape = z.shape
        z_flat = z.view(-1, shape[-1])  # [M, K]
        y = triton_matmul(z_flat, self.l1.weight)  # [M, N]
        return y
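
    # The two methods above exercise only l1. For reference, here is a sketch
    # of what the full gated forward could look like. This method is NOT part
    # of the original script: it assumes the SwiGLU-style gating implied by
    # the layer shapes (l1 and l2 both project 4096 -> 11264 in parallel,
    # l3 projects back 11264 -> 4096) and by self.act = F.silu.
    def forward_gated(self, z):
        shape = z.shape
        z_flat = z.view(-1, shape[-1])       # [M, 4096]
        gate = self.act(self.l1(z_flat))     # [M, 11264], SiLU-activated branch
        up = self.l2(z_flat)                 # [M, 11264], linear branch
        y = self.l3(gate * up)               # [M, 4096], down projection
        return y.view(*shape[:-1], -1)       # restore leading dimensions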


if __name__ == "__main__":
    seed = 1111
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    device = "cuda:0"
    model = ParallelGatedMLP().to(dtype=torch.bfloat16, device=device)
    x = torch.randn(1, 1, 4096, dtype=torch.bfloat16, device=device)

    # Correctness check: run both paths on the same input.
    with torch.no_grad():
        result_org = model.forward_org(x)
        result_opt = model.forward_org_triton(x)

    print(f"ORG shape: {result_org.shape}")
    print(f"OPT shape: {result_opt.shape}")

    # Print the first 20 elements of each result for a quick visual comparison.
    print(f"\nORG first 20: {result_org[0, :20]}")
    print(f"OPT first 20: {result_opt[0, :20]}")

    # Absolute difference.
    diff = torch.abs(result_org - result_opt)
    print(f"\nMax diff: {diff.max().item()}")
    print(f"Mean diff: {diff.mean().item()}")

    # Relative error check.
    rel_error = diff / (torch.abs(result_org) + 1e-8)
    print(f"Max relative error: {rel_error.max().item()}")

    # Verify the results agree within a tolerance that allows for the
    # floating-point accumulation differences between the two paths.
    if torch.allclose(result_org, result_opt, rtol=1e-2, atol=1e-3):
        print("\n✓ Results match within tolerance!")
    else:
        print("\n✗ Results do not match!")
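
    # The script imports `time` but never uses it. The benchmark below is an
    # added sketch (not from the original) showing one reasonable way to time
    # the two paths; the warmup and iteration counts are arbitrary choices.
    def bench(fn, iters=100, warmup=10):
        # Warm up so Triton JIT compilation and CUDA caching are excluded.
        with torch.no_grad():
            for _ in range(warmup):
                fn(x)
            torch.cuda.synchronize()
            t0 = time.perf_counter()
            for _ in range(iters):
                fn(x)
            torch.cuda.synchronize()  # wait for all queued kernels to finish
        return (time.perf_counter() - t0) / iters * 1e3  # ms per call

    print(f"\nF.linear: {bench(model.forward_org):.4f} ms")
    print(f"Triton:   {bench(model.forward_org_triton):.4f} ms")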