import triton
import triton.language as tl

# @triton.jit
# def gated_mlp_kernel(
#     # Inputs
#     x_ptr,    # [M, K]
#     w1_ptr,   # [N, K] -> note: w1 is out_features x in_features
#     w2_ptr,
#     w3_ptr,   # [K_out, N] = [hidden, inner]
#     y_ptr,    # output [M, K_out]
#     # Shapes
#     M,        # batch * seq_len
#     K,        # hidden_size (e.g., 4096)
#     N,        # inner_size (e.g., 11264)
#     K_out: tl.constexpr,
#     # Block sizes
#     BLOCK_M: tl.constexpr = 64,
#     BLOCK_N: tl.constexpr = 128,
#     BLOCK_K: tl.constexpr = 64,
# ):
#     pid_m = tl.program_id(0)
#     pid_n = tl.program_id(1)
#     # Output region covered by this program: [pid_m*BLOCK_M : ..., pid_n*BLOCK_N : ...]
#     offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
#     offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
#     offs_k = tl.arange(0, BLOCK_K)
#     # Pointers to a row block of x and the matching tiles of w1 / w2
#     x_ptrs = x_ptr + offs_m[:, None] * K + offs_k[None, :]
#     w1_ptrs = w1_ptr + offs_n[:, None] * K + offs_k[None, :]
#     w2_ptrs = w2_ptr + offs_n[:, None] * K + offs_k[None, :]
#     # Initialize accumulators
#     acc1 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
#     acc2 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
#     for k in range(0, K, BLOCK_K):
#         # Boundary handling
#         k_mask = (offs_k[None, :] < K - k)
#         x = tl.load(x_ptrs, mask=k_mask, other=0.0)
#         w1 = tl.load(w1_ptrs, mask=k_mask, other=0.0)
#         w2 = tl.load(w2_ptrs, mask=k_mask, other=0.0)
#         acc1 += tl.dot(x, w1.T)
#         acc2 += tl.dot(x, w2.T)
#         x_ptrs += BLOCK_K
#         w1_ptrs += BLOCK_K
#         w2_ptrs += BLOCK_K
#         offs_k += BLOCK_K
#     # Apply SiLU: x * sigmoid(x)
#     z1 = acc1.to(tl.bfloat16)
#     z2 = acc2.to(tl.bfloat16)
#     sig = tl.sigmoid(z1)
#     gated = z1 * sig * z2  # SiLU(z1) * z2
#     # Second stage: gated @ w3.T → [M, N] @ [K_out, N].T = [M, K_out]
#     # Note: w3 is [K_out, N]; we need gated (M, N) x w3.T (N, K_out)
#     offs_k2 = tl.arange(0, BLOCK_K)
#     w3_ptrs = w3_ptr + offs_n[:, None] + offs_k2[None, :] * N  # w3[k_out, n] → column-major?
#     # Safer view: w3 is [K_out, N], row-major, so w3[k, n] = w3_ptr[k*N + n],
#     # which means loading column n of w3 requires a transposed access pattern.
#     # Alternative: for each output column k_out, accumulate gated[:, n] * w3[k_out, n];
#     # but then pid_n in the launch grid maps to k_out, which requires different indexing.
#     # ⚠️ The design above is flawed! A better approach is two separate steps:
#     #   1. gated = SiLU(x @ W1.T) * (x @ W2.T) → [M, N]
#     #   2. gated @ W3.T → [M, K_out]
#     # Because N = 11264 is large, fusing all three GEMMs into one kernel would
#     # cause register spilling. Therefore we only fuse the first two GEMMs plus
#     # the activation; the third step is left to cuBLAS (torch.matmul).


@triton.jit
def gated_proj_kernel(
    x_ptr, w1_ptr, w2_ptr, out_ptr,
    M, K, N,
    stride_xm, stride_xk,
    stride_wk, stride_wn,  # w1/w2 are [N, K] row-major, so stride_wn = K, stride_wk = 1
    stride_om, stride_on,
    ACTIVATION: tl.constexpr,
    BLOCK_M: tl.constexpr = 64,
    BLOCK_N: tl.constexpr = 64,
    BLOCK_K: tl.constexpr = 32,
):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)

    # Pointers to the current (BLOCK_M, BLOCK_K) tile of x and the
    # (BLOCK_N, BLOCK_K) tiles of w1 / w2.
    x_ptrs = x_ptr + offs_m[:, None] * stride_xm + offs_k[None, :] * stride_xk
    w1_ptrs = w1_ptr + offs_n[:, None] * stride_wn + offs_k[None, :] * stride_wk
    w2_ptrs = w2_ptr + offs_n[:, None] * stride_wn + offs_k[None, :] * stride_wk

    # Accumulate both GEMMs (x @ w1.T and x @ w2.T) in fp32.
    acc1 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    acc2 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    for k in range(0, K, BLOCK_K):
        # Mask out-of-range rows/columns and the tail of the K dimension.
        k_mask = offs_k[None, :] < K - k
        x = tl.load(x_ptrs, mask=(offs_m[:, None] < M) & k_mask, other=0.0)
        w1 = tl.load(w1_ptrs, mask=(offs_n[:, None] < N) & k_mask, other=0.0)
        w2 = tl.load(w2_ptrs, mask=(offs_n[:, None] < N) & k_mask, other=0.0)
        acc1 += tl.dot(x, tl.trans(w1))
        acc2 += tl.dot(x, tl.trans(w2))
        # Advance pointers along K. offs_k itself stays fixed: the `K - k`
        # term in the mask already accounts for the progress along K.
        x_ptrs += BLOCK_K * stride_xk
        w1_ptrs += BLOCK_K * stride_wk
        w2_ptrs += BLOCK_K * stride_wk

    z1 = acc1.to(tl.bfloat16)
    z2 = acc2.to(tl.bfloat16)

    if ACTIVATION == "silu":
        # SiLU(z1) * z2 = z1 * sigmoid(z1) * z2
        sig = tl.sigmoid(z1)
        out = z1 * sig * z2
    elif ACTIVATION == "gelu":
        # Triton has no built-in gelu; use the tanh approximation.
        # tanh(v) is expressed via the identity tanh(v) = 2*sigmoid(2v) - 1.
        inner = 0.79788456 * (z1 + 0.044715 * z1 * z1 * z1)
        tanh_inner = 2.0 * tl.sigmoid(2.0 * inner) - 1.0
        out = z1 * 0.5 * (1.0 + tanh_inner) * z2
    else:
        out = z1 * z2

    out_ptrs = out_ptr + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
    tl.store(out_ptrs, out, mask=(offs_m[:, None] < M) & (offs_n[None, :] < N))
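

# ---------------------------------------------------------------------------
# Minimal host-side wrapper (illustrative sketch, not part of the original
# kernel). It shows the two-step scheme described above: launch
# gated_proj_kernel for SiLU(x @ W1.T) * (x @ W2.T), then hand the down
# projection to cuBLAS via torch.matmul. The function name `gated_mlp`, the
# block sizes, and the layout assumptions (x: [M, K], w1/w2: [N, K],
# w3: [K_out, N], all contiguous bf16 on the same CUDA device) are
# assumptions for this sketch.
# ---------------------------------------------------------------------------
import torch


def gated_mlp(x, w1, w2, w3, activation="silu"):
    """Sketch: (act(x @ w1.T) * (x @ w2.T)) @ w3.T with the fused gate/up kernel."""
    M, K = x.shape
    N = w1.shape[0]
    assert w1.shape == (N, K) and w2.shape == (N, K) and w3.shape[1] == N

    # Intermediate gated activation, bf16 to match what the kernel stores.
    gated = torch.empty((M, N), device=x.device, dtype=torch.bfloat16)

    BLOCK_M, BLOCK_N, BLOCK_K = 64, 64, 32  # illustrative tile sizes
    grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N))
    gated_proj_kernel[grid](
        x, w1, w2, gated,
        M, K, N,
        x.stride(0), x.stride(1),
        w1.stride(1), w1.stride(0),   # stride_wk, stride_wn for a [N, K] weight
        gated.stride(0), gated.stride(1),
        ACTIVATION=activation,
        BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K,
    )

    # Final down projection left to cuBLAS, as noted above.
    return gated @ w3.t()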