import random
import time

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import triton
import triton.language as tl


@triton.jit
def gated_proj_kernel(
    x_ptr, w1_ptr, w2_ptr, out_ptr,
    M, K, N,
    stride_xm, stride_xk,
    stride_wk, stride_wn,  # w is [N, K], so stride_wn = K
    stride_om, stride_on,
    ACTIVATION: tl.constexpr,
    BLOCK_M: tl.constexpr = 64,
    BLOCK_N: tl.constexpr = 64,
    BLOCK_K: tl.constexpr = 32,
):
    """Fused gated projection: out = act(x @ w1.T) * (x @ w2.T).

    x is [M, K]; w1 and w2 are [N, K] (nn.Linear weight layout); out is
    [M, N].  Both matmuls share the loads of x, accumulate in fp32, and the
    gated product is stored once as bf16.
    """
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)

    x_ptrs = x_ptr + offs_m[:, None] * stride_xm + offs_k[None, :] * stride_xk
    w1_ptrs = w1_ptr + offs_n[:, None] * stride_wn + offs_k[None, :] * stride_wk
    w2_ptrs = w2_ptr + offs_n[:, None] * stride_wn + offs_k[None, :] * stride_wk

    acc1 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    acc2 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    for k in range(0, K, BLOCK_K):
        # offs_k stays fixed at arange(0, BLOCK_K); only the pointers advance,
        # so the tail mask is simply offs_k < K - k.
        # BUGFIX: the original also did `offs_k += BLOCK_K` each iteration,
        # double-counting the K offset in the mask (offs_k < K - 2k in effect)
        # and silently zeroing the entire second half of the reduction.
        k_mask = offs_k[None, :] < K - k
        x = tl.load(x_ptrs, mask=(offs_m[:, None] < M) & k_mask, other=0.0)
        w1 = tl.load(w1_ptrs, mask=(offs_n[:, None] < N) & k_mask, other=0.0)
        w2 = tl.load(w2_ptrs, mask=(offs_n[:, None] < N) & k_mask, other=0.0)
        acc1 += tl.dot(x, tl.trans(w1))
        acc2 += tl.dot(x, tl.trans(w2))
        x_ptrs += BLOCK_K * stride_xk
        w1_ptrs += BLOCK_K * stride_wk
        w2_ptrs += BLOCK_K * stride_wk

    z1 = acc1  # already fp32 accumulators
    z2 = acc2

    if ACTIVATION == "silu":
        out = z1 * tl.sigmoid(z1) * z2
    elif ACTIVATION == "gelu":
        # tanh approximation of GELU (the original silently fell back to
        # silu here).  tanh(t) = 2*sigmoid(2t) - 1 keeps this portable to
        # Triton versions without tl.tanh.
        inner = 0.7978845608028654 * (z1 + 0.044715 * z1 * z1 * z1)
        tanh_inner = 2.0 * tl.sigmoid(2.0 * inner) - 1.0
        out = 0.5 * z1 * (1.0 + tanh_inner) * z2
    else:
        out = z1 * z2

    out_ptrs = out_ptr + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
    tl.store(
        out_ptrs,
        out.to(tl.bfloat16),
        mask=(offs_m[:, None] < M) & (offs_n[None, :] < N),
    )


def fused_gated_proj(x, w1, w2, activation="silu"):
    """Launch the fused gated-projection kernel.

    Args:
        x: [M, K] bf16 input.
        w1: [N, K] bf16 gate weight (nn.Linear layout).
        w2: [N, K] bf16 up-projection weight; must match w1's shape.
        activation: "silu", "gelu", or anything else for no activation.

    Returns:
        [M, N] bf16 tensor act(x @ w1.T) * (x @ w2.T).
    """
    assert x.dtype == torch.bfloat16
    assert w1.dtype == torch.bfloat16 and w2.dtype == torch.bfloat16
    M, K = x.shape        # e.g. 1, 4096
    N, _ = w1.shape       # e.g. 11264, 4096
    assert w2.shape == (N, K)

    out = torch.empty(M, N, dtype=torch.bfloat16, device=x.device)

    grid = lambda META: (
        triton.cdiv(M, META['BLOCK_M']),
        triton.cdiv(N, META['BLOCK_N']),
    )
    gated_proj_kernel[grid](
        x, w1, w2, out,
        M, K, N,
        x.stride(0), x.stride(1),
        w1.stride(1), w1.stride(0),
        out.stride(0), out.stride(1),
        ACTIVATION=activation,
        BLOCK_M=64, BLOCK_N=64, BLOCK_K=32,
    )
    return out


class ParallelGatedMLP(nn.Module):
    """Gated MLP (SwiGLU-style): l3(act(l1(x)) * l2(x)), bias-free.

    `forward`/`forward_opt` use the fused Triton kernel for the gated part;
    `forward_org` is the eager reference path.  Note: all three currently
    return the gated intermediate, not the l3 down-projection (the l3 call
    in `forward` is deliberately commented out for comparison/debugging).
    """

    def __init__(self):
        super().__init__()
        self.act = F.silu
        self.act_type = "silu"
        self.l1 = nn.Linear(
            in_features=4096,
            out_features=11264,
            bias=False,
        )
        self.l2 = nn.Linear(
            in_features=4096,
            out_features=11264,
            bias=False,
        )
        self.l3 = nn.Linear(
            in_features=11264,
            out_features=4096,
            bias=False,
        )
        # Make sure the weights are contiguous (nn.Linear usually is, but be
        # safe since the kernel relies on the strides being sane).
        self.l1.weight = torch.nn.Parameter(self.l1.weight.contiguous())
        self.l2.weight = torch.nn.Parameter(self.l2.weight.contiguous())
        self.l3.weight = torch.nn.Parameter(self.l3.weight.contiguous())

    def forward(self, z):
        # z: [B, S, D] -> flatten to [M, D]
        shape = z.shape
        z_flat = z.view(-1, int(shape[-1]))  # [M, D]
        # Triton path
        gated = fused_gated_proj(
            z_flat,
            self.l1.weight,  # [inner, hidden]
            self.l2.weight,
            activation=self.act_type
        )
        # y_flat = self.l3(gated)  # [M, D]
        # y = y_flat.view(*shape)
        return gated

    def forward_org(self, z):
        """Eager reference path (use for GELU or debugging)."""
        shape = z.shape
        z_flat = z.view(-1, shape[-1])
        z1, z2 = self.l1(z_flat), self.l2(z_flat)
        gated = self.act(z1) * z2
        return gated

    def forward_opt(self, z):
        """Triton path; identical to forward()."""
        # z: [B, S, D] -> flatten to [M, D]
        shape = z.shape
        z_flat = z.view(-1, int(shape[-1]))  # [M, D]
        gated = fused_gated_proj(
            z_flat,
            self.l1.weight,  # [inner, hidden]
            self.l2.weight,
            activation=self.act_type
        )
        return gated


if __name__ == "__main__":
    seed = 1111
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # if using multi-GPU
    np.random.seed(seed)
    random.seed(seed)
    # Optional: trade performance for reproducibility (some CUDA ops are
    # non-deterministic otherwise).
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Build the model and move it to bf16 on GPU.
    model = ParallelGatedMLP()
    model = model.to(dtype=torch.bfloat16, device="cuda:0")

    # Input tensor (batch=1, seq_len=1, hidden=4096); needs a bf16-capable
    # GPU (e.g. A100 / H100).
    device = "cuda:0"
    x = torch.randn(1, 1, 4096, dtype=torch.bfloat16, device=device)

    with torch.no_grad():
        result_org = model.forward_org(x)
        print(f"ORG: {result_org[0, :20]}")
        result_opt = model.forward_opt(x)
        print(f"OPT: {result_opt[0, :20]}")