import torch
import torch.nn as nn
import torch.nn.functional as F
import triton
import triton.language as tl
import numpy as np
import random
import time


@triton.jit
def gated_proj_kernel(
    x_ptr, w1_ptr, w2_ptr, out_ptr,
    M, K, N,
    stride_xm, stride_xk,
    stride_w1k, stride_w1n,  # w1 is [K, N]
    stride_w2k, stride_w2n,  # w2 is [K, N]
    stride_om, stride_on,
    ACTIVATION: tl.constexpr,
    BLOCK_M: tl.constexpr = 64,
    BLOCK_N: tl.constexpr = 64,
    BLOCK_K: tl.constexpr = 32,
):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)

    # x: [M, K]
    x_ptrs = x_ptr + offs_m[:, None] * stride_xm + offs_k[None, :] * stride_xk
    # w1 and w2: [K, N] (pre-transposed weights).
    # Note: w1_ptr and w2_ptr already point to the transposed weights,
    # so no transpose is needed inside the kernel.
    w1_ptrs = w1_ptr + offs_k[:, None] * stride_w1k + offs_n[None, :] * stride_w1n
    w2_ptrs = w2_ptr + offs_k[:, None] * stride_w2k + offs_n[None, :] * stride_w2n

    acc1 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    acc2 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    for k in range(0, K, BLOCK_K):
        # Load a tile of x.
        x_mask = (offs_m[:, None] < M) & (offs_k[None, :] < K - k)
        x = tl.load(x_ptrs, mask=x_mask, other=0.0)

        # Load the matching tiles of w1 and w2.
        w_mask = (offs_k[:, None] < K - k) & (offs_n[None, :] < N)
        w1 = tl.load(w1_ptrs, mask=w_mask, other=0.0)
        w2 = tl.load(w2_ptrs, mask=w_mask, other=0.0)

        # Accumulate x @ W1^T and x @ W2^T.
        # x: [BLOCK_M, BLOCK_K], w1: [BLOCK_K, BLOCK_N].
        # Because w1/w2 were transposed to [K, N] on the host, tl.dot(x, w1)
        # is exactly x @ W1^T for the original [N, K] Linear weight W1.
        acc1 += tl.dot(x, w1)
        acc2 += tl.dot(x, w2)

        # Advance the pointers to the next K-block.
        x_ptrs += BLOCK_K * stride_xk
        w1_ptrs += BLOCK_K * stride_w1k
        w2_ptrs += BLOCK_K * stride_w2k

    # Apply the gating activation.
    if ACTIVATION == "silu":
        # SiLU(x) = x * sigmoid(x)
        sig = tl.sigmoid(acc1)
        out = acc1 * sig * acc2  # SiLU(x @ W1^T) * (x @ W2^T)
    # elif ACTIVATION == "gelu":
    #     # Tanh approximation of GELU:
    #     # GELU(x) ≈ 0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x^3)))
    #     # (the tanh op may live under tl.math depending on the Triton version)
    #     gelu_approx = 0.5 * acc1 * (1 + tl.tanh(0.79788456 * (acc1 + 0.044715 * acc1 * acc1 * acc1)))
    #     out = gelu_approx * acc2
    else:
        # No activation: plain elementwise gating. This branch also guarantees
        # `out` is defined for any ACTIVATION value other than "silu".
        out = acc1 * acc2

    # Store the result.
    out_ptrs = out_ptr + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
    out_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    tl.store(out_ptrs, out.to(tl.bfloat16), mask=out_mask)


def fused_gated_proj(x, w1, w2, activation="silu"):
    """
    x:  [M, K] - input
    w1: [N, K] - weight1 (PyTorch Linear weight, shape [out_features, in_features])
    w2: [N, K] - weight2 (PyTorch Linear weight, shape [out_features, in_features])
    Returns: [M, N]

    Computes: activation(w1 @ x^T)^T * (w2 @ x^T)^T,
    which is equivalent to SiLU(x @ w1^T) * (x @ w2^T) for activation="silu".
    """
    assert x.dtype == torch.bfloat16
    assert w1.dtype == torch.bfloat16 and w2.dtype == torch.bfloat16

    M, K = x.shape    # e.g. M=1, K=4096
    N, K2 = w1.shape  # e.g. N=11264, K2=4096
    assert K == K2, f"Dimension mismatch: x K={K}, w1 K={K2}"
    assert w2.shape == (N, K), f"w2 shape mismatch: expected {(N, K)}, got {w2.shape}"

    # Transpose the weights to [K, N] up front so the kernel can read them directly.
    w1_t = w1.t().contiguous()  # [K, N]
    w2_t = w2.t().contiguous()  # [K, N]

    out = torch.empty(M, N, dtype=torch.bfloat16, device=x.device)

    grid = lambda META: (
        triton.cdiv(M, META['BLOCK_M']),
        triton.cdiv(N, META['BLOCK_N']),
    )

    gated_proj_kernel[grid](
        x, w1_t, w2_t, out,              # pass the pre-transposed weights
        M, K, N,
        x.stride(0), x.stride(1),
        w1_t.stride(0), w1_t.stride(1),  # strides of the [K, N] weight
        w2_t.stride(0), w2_t.stride(1),
        out.stride(0), out.stride(1),
        ACTIVATION=activation,
        BLOCK_M=64,
        BLOCK_N=64,
        BLOCK_K=32,
    )
    return out
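
# A minimal eager-mode reference for the fused kernel, kept here as a sketch for
# quick correctness checks. The helper name `gated_proj_reference` is not part of
# the original module; it simply mirrors what ParallelGatedMLP.forward_org computes,
# using the identity F.linear(x, w) = x @ w^T. The "gelu" path uses PyTorch's tanh
# approximation, which is assumed (not taken from the original script).
def gated_proj_reference(x, w1, w2, activation="silu"):
    z1 = F.linear(x, w1)  # x @ w1^T -> [M, N]
    z2 = F.linear(x, w2)  # x @ w2^T -> [M, N]
    if activation == "silu":
        z1 = F.silu(z1)
    elif activation == "gelu":
        z1 = F.gelu(z1, approximate="tanh")
    return z1 * z2
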
class ParallelGatedMLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.act = F.silu
        self.act_type = "silu"
        self.l1 = nn.Linear(
            in_features=4096,
            out_features=11264,
            bias=False,
        )
        self.l2 = nn.Linear(
            in_features=4096,
            out_features=11264,
            bias=False,
        )
        self.l3 = nn.Linear(
            in_features=11264,
            out_features=4096,
            bias=False,
        )

    def forward_org(self, z):
        """Original (eager PyTorch) implementation."""
        shape = z.shape
        z_flat = z.view(-1, shape[-1])  # [M, K]
        # PyTorch: F.linear(x, weight) = x @ weight^T
        # z1 = F.linear(z_flat, self.l1.weight)  # [M, N]
        # z2 = F.linear(z_flat, self.l2.weight)  # [M, N]
        z1, z2 = self.l1(z_flat), self.l2(z_flat)
        gated = self.act(z1) * z2
        return gated

    def forward_opt(self, z):
        """Triton-optimized implementation."""
        shape = z.shape
        z_flat = z.view(-1, shape[-1])  # [M, K]
        # Triton path
        gated = fused_gated_proj(
            z_flat,
            self.l1.weight,  # [N, K]
            self.l2.weight,  # [N, K]
            activation=self.act_type,
        )
        return gated


if __name__ == "__main__":
    seed = 1111
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Build the model.
    model = ParallelGatedMLP()
    model = model.to(dtype=torch.bfloat16, device="cuda:0")

    # Test different batch sizes.
    for batch_size in [1]:
        print(f"\n{'='*50}")
        print(f"Testing with batch_size={batch_size}")
        print('='*50)

        x = torch.randn(batch_size, 1, 4096, dtype=torch.bfloat16, device="cuda:0")

        with torch.no_grad():
            # Warm-up (also triggers Triton compilation).
            for _ in range(3):
                _ = model.forward_org(x)
                _ = model.forward_opt(x)
            torch.cuda.synchronize()

            t0 = time.time()
            # Original version
            result_org = model.forward_org(x)
            torch.cuda.synchronize()
            t1 = time.time()
            print(f"Time taken for forward_org: {t1 - t0:.4f} seconds")

            # Optimized version
            result_opt = model.forward_opt(x)
            torch.cuda.synchronize()
            print(f"Time taken for forward_opt: {time.time() - t1:.4f} seconds")

            # Check shapes.
            print(f"ORG shape: {result_org.shape}")
            print(f"OPT shape: {result_opt.shape}")

            # Absolute differences.
            diff = torch.abs(result_org - result_opt)
            print(f"Max diff: {diff.max().item():.6f}")
            print(f"Mean diff: {diff.mean().item():.6f}")
            print(f"Min diff: {diff.min().item():.6f}")

            # Relative error.
            rel_error = diff / (torch.abs(result_org) + 1e-8)
            print(f"Max relative error: {rel_error.max().item():.6f}")
            print(f"Mean relative error: {rel_error.mean().item():.6f}")

            # Compare the first few values.
            print("\nFirst 10 values comparison:")
            print(f"ORG: {result_org[0, :10].float().cpu().numpy()}")
            print(f"OPT: {result_opt[0, :10].float().cpu().numpy()}")
            print(f"Diff: {diff[0, :10].float().cpu().numpy()}")

            # Final check.
            if torch.allclose(result_org, result_opt, rtol=1e-2, atol=1e-3):
                print("✓ Results match within tolerance!")
            else:
                print("✗ Results do not match!")

    # # Additional check: verify mathematical equivalence on a small problem.
    # print(f"\n{'='*50}")
    # print("Additional mathematical verification")
    # print('='*50)
    # # Small-matrix sanity check.
    # test_x = torch.randn(2, 16, dtype=torch.bfloat16, device="cuda:0")
    # test_w1 = torch.randn(32, 16, dtype=torch.bfloat16, device="cuda:0")
    # test_w2 = torch.randn(32, 16, dtype=torch.bfloat16, device="cuda:0")
    # # PyTorch computation
    # z1_pt = F.linear(test_x, test_w1)  # x @ w1^T
    # z2_pt = F.linear(test_x, test_w2)  # x @ w2^T
    # result_pt = F.silu(z1_pt) * z2_pt
    # # Triton computation
    # result_triton = fused_gated_proj(test_x, test_w1, test_w2, activation="silu")
    # diff_test = torch.abs(result_pt - result_triton)
    # print(f"Test max diff: {diff_test.max().item():.6f}")
    # print(f"Test mean diff: {diff_test.mean().item():.6f}")
    # if torch.allclose(result_pt, result_triton, rtol=1e-2, atol=1e-3):
    #     print("✓ Test passed: Triton implementation matches PyTorch!")
    # else:
    #     print("✗ Test failed: Triton implementation doesn't match PyTorch!")
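
    # # Optional: a more robust timing sketch, not part of the original script.
    # # Wrapping async CUDA work with time.time() mostly measures launch overhead;
    # # triton.testing.do_bench (available in recent Triton releases, exact
    # # signature may vary by version) synchronizes and averages over many runs.
    # from triton.testing import do_bench
    # ms_org = do_bench(lambda: model.forward_org(x))
    # ms_opt = do_bench(lambda: model.forward_opt(x))
    # print(f"forward_org: {ms_org:.3f} ms, forward_opt: {ms_opt:.3f} ms")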