"googlemock/include/gmock/gmock-matchers.h" did not exist on "6414d806cd7d0954cce81348552fdd1e5bd31515"
demo.py 4.7 KB
Newer Older
liuys's avatar
liuys committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import triton
import triton.language as tl

# @triton.jit
# def gated_mlp_kernel(
#     # inputs
#     x_ptr,          # [M, K]
#     w1_ptr,         # [N, K]  -> note: w1 is out_features x in_features
#     w2_ptr,
#     w3_ptr,         # [K_out, N] = [hidden, inner]
#     y_ptr,          # output [M, K_out]

#     # shapes
#     M,              # batch * seq_len
#     K,              # hidden_size (e.g., 4096)
#     N,              # inner_size (e.g., 11264)
#     K_out: tl.constexpr,

#     # tiling
#     BLOCK_M: tl.constexpr = 64,
#     BLOCK_N: tl.constexpr = 128,
#     BLOCK_K: tl.constexpr = 64,
# ):
#     pid_m = tl.program_id(0)
#     pid_n = tl.program_id(1)

#     # output region covered by this block: [pid_m*BLOCK_M : ..., pid_n*BLOCK_N : ...]
#     offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
#     offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
#     offs_k = tl.arange(0, BLOCK_K)

#     # load one row (or a few rows) of x
#     x_ptrs = x_ptr + offs_m[:, None] * K + offs_k[None, :]
#     w1_ptrs = w1_ptr + offs_n[:, None] * K + offs_k[None, :]
#     w2_ptrs = w2_ptr + offs_n[:, None] * K + offs_k[None, :]

#     # initialize accumulators
#     acc1 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
#     acc2 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

#     for k in range(0, K, BLOCK_K):
#         # boundary handling
#         k_mask = (offs_k[None, :] < K - k)
#         x = tl.load(x_ptrs, mask=k_mask, other=0.0)
#         w1 = tl.load(w1_ptrs, mask=k_mask, other=0.0)
#         w2 = tl.load(w2_ptrs, mask=k_mask, other=0.0)

#         acc1 += tl.dot(x, w1.T)
#         acc2 += tl.dot(x, w2.T)

#         x_ptrs += BLOCK_K
#         w1_ptrs += BLOCK_K
#         w2_ptrs += BLOCK_K
#         offs_k += BLOCK_K

#     # apply SiLU: x * sigmoid(x)
#     z1 = acc1.to(tl.bfloat16)
#     z2 = acc2.to(tl.bfloat16)
#     sig = tl.sigmoid(z1)
#     gated = z1 * sig * z2  # SiLU(z1) * z2

#     # stage 2: gated @ w3.T → [M, N] @ [K_out, N].T = [M, K_out]
#     # note: w3 is [K_out, N]; we want gated (M, N) x w3.T (N, K_out)
#     offs_k2 = tl.arange(0, BLOCK_K)
#     w3_ptrs = w3_ptr + offs_n[:, None] + offs_k2[None, :] * N  # w3[k_out, n] → column-major view?
#     # safer: assume w3 is [K_out, N], stored row-major, so w3[k, n] = w3_ptr[k*N + n]
#     # loading column n of w3 therefore needs a transposed access pattern

#     # alternative: for each output column k_out, accumulate gated[:, n] * w3[k_out, n]
#     # but then pid_n would have to map to k_out, so the launch grid logic must change

#     # ⚠️ the design above is broken! Better to split the work into two kernels:
#     #   1. compute gated = SiLU(x@W1) * (x@W2)  → [M, N]
#     #   2. gated @ W3.T → [M, K_out]
#     # with N=11264 this large, fusing all three GEMMs spills registers

#     # so we fuse only the first two GEMMs + activation; stage 3 uses cuBLAS
#     # (torch.matmul), see the host-side sketch at the end of the file

@triton.jit
def gated_proj_kernel(
    x_ptr, w1_ptr, w2_ptr, out_ptr,
    M, K, N,
    stride_xm, stride_xk,
    stride_wk, stride_wn,  # w is [N, K], so stride_wn = K
    stride_om, stride_on,
    ACTIVATION: tl.constexpr,
    BLOCK_M: tl.constexpr = 64,
    BLOCK_N: tl.constexpr = 64,
    BLOCK_K: tl.constexpr = 32,
):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)

    x_ptrs = x_ptr + offs_m[:, None] * stride_xm + offs_k[None, :] * stride_xk
    w1_ptrs = w1_ptr + offs_n[:, None] * stride_wn + offs_k[None, :] * stride_wk
    w2_ptrs = w2_ptr + offs_n[:, None] * stride_wn + offs_k[None, :] * stride_wk

    acc1 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    acc2 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    for k in range(0, K, BLOCK_K):
        k_mask = offs_k[None, :] < K - k
        x = tl.load(x_ptrs, mask=(offs_m[:, None] < M) & k_mask, other=0.0)
        w1 = tl.load(w1_ptrs, mask=(offs_n[:, None] < N) & k_mask, other=0.0)
        w2 = tl.load(w2_ptrs, mask=(offs_n[:, None] < N) & k_mask, other=0.0)

        # tl.trans is the portable transpose in triton.language
        acc1 += tl.dot(x, tl.trans(w1))
        acc2 += tl.dot(x, tl.trans(w2))

        # advance only the pointers; offs_k stays at arange(0, BLOCK_K) so the
        # tail mask `offs_k < K - k` above indexes the current tile correctly
        x_ptrs += BLOCK_K * stride_xk
        w1_ptrs += BLOCK_K * stride_wk
        w2_ptrs += BLOCK_K * stride_wk

    z1 = acc1.to(tl.bfloat16)
    z2 = acc2.to(tl.bfloat16)

    if ACTIVATION == "silu":
        sig = tl.sigmoid(z1)
        out = z1 * sig * z2
    elif ACTIVATION == "gelu":
        # Triton 没有 gelu,可近似或回退
        out = z1 * 0.5 * (1 + tl.tanh(0.79788456 * (z1 + 0.044715 * z1 * z1 * z1))) * z2
    else:
        out = z1 * z2

    out_ptrs = out_ptr + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
    tl.store(out_ptrs, out, mask=(offs_m[:, None] < M) & (offs_n[None, :] < N))
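

# Host-side wrapper (assumption: a sketch, not part of the original file).
# It realizes the two-stage plan from the comments above: launch
# gated_proj_kernel to compute SiLU(x @ W1.T) * (x @ W2.T), then hand the
# [M, N] intermediate to cuBLAS (torch.matmul) for the down-projection.
# The name `gated_mlp` and the block sizes are illustrative choices.
def gated_mlp(x, w1, w2, w3, activation="silu"):
    # x: [M, K] bf16; w1, w2: [N, K]; w3: [K_out, N], all row-major
    M, K = x.shape
    N = w1.shape[0]
    gated = torch.empty((M, N), device=x.device, dtype=torch.bfloat16)
    BLOCK_M, BLOCK_N = 64, 64
    grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N))
    gated_proj_kernel[grid](
        x, w1, w2, gated,
        M, K, N,
        x.stride(0), x.stride(1),
        w1.stride(1), w1.stride(0),  # stride_wk, stride_wn for row-major [N, K]
        gated.stride(0), gated.stride(1),
        ACTIVATION=activation,
        BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,
    )
    # Stage 2: down-projection via cuBLAS, as the notes above recommend.
    # Should match gated_mlp_reference(x, w1, w2, w3) up to bf16 rounding.
    return gated @ w3.T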