mlp-sample.py 8.88 KB
Newer Older
liuys's avatar
liuys committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
import torch
import torch.nn as nn
import torch.nn.functional as F
import triton
import triton.language as tl
import numpy as np
import random
import time


@triton.jit
def gated_proj_kernel(
    x_ptr, w1_ptr, w2_ptr, out_ptr,
    M, K, N,
    stride_xm, stride_xk,
    stride_w1k, stride_w1n,  # w1 is addressed as a logical [K, N] matrix
    stride_w2k, stride_w2n,  # w2 is addressed as a logical [K, N] matrix
    stride_om, stride_on,
    ACTIVATION: tl.constexpr,
    BLOCK_M: tl.constexpr = 64,
    BLOCK_N: tl.constexpr = 64,
    BLOCK_K: tl.constexpr = 32,
):
    """Compute ``out = act(x @ w1) * (x @ w2)`` for one [BLOCK_M, BLOCK_N] output tile.

    ``x`` is read as an [M, K] matrix and ``w1``/``w2`` as [K, N] matrices,
    all purely through the given strides. A PyTorch ``nn.Linear`` weight of
    shape [N, K] can therefore be passed without copying it, by swapping its
    strides (``stride_w*k = weight.stride(1)``, ``stride_w*n = weight.stride(0)``).

    ACTIVATION (compile-time) selects the gate nonlinearity applied to x @ w1:
      * ``"silu"`` - a * sigmoid(a)
      * ``"gelu"`` - tanh approximation of GELU
      * ``"none"`` - identity (plain elementwise product)
    Any other value is rejected at compile time.

    Both products are accumulated in float32 and the result is stored as
    bfloat16. Program (i, j) of a 2-D launch grid owns rows
    [i*BLOCK_M, (i+1)*BLOCK_M) and columns [j*BLOCK_N, (j+1)*BLOCK_N).
    """
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)

    # Tile pointers for the first K step: x tile is [BLOCK_M, BLOCK_K],
    # weight tiles are [BLOCK_K, BLOCK_N].
    x_ptrs = x_ptr + offs_m[:, None] * stride_xm + offs_k[None, :] * stride_xk
    w1_ptrs = w1_ptr + offs_k[:, None] * stride_w1k + offs_n[None, :] * stride_w1n
    w2_ptrs = w2_ptr + offs_k[:, None] * stride_w2k + offs_n[None, :] * stride_w2n

    acc1 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    acc2 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    # Row/column bounds do not change across K steps, so build them once.
    m_valid = offs_m[:, None] < M
    n_valid = offs_n[None, :] < N

    for k in range(0, K, BLOCK_K):
        # Number of K elements left; masks the ragged final K tile.
        k_remaining = K - k

        x = tl.load(x_ptrs, mask=m_valid & (offs_k[None, :] < k_remaining), other=0.0)

        w_mask = (offs_k[:, None] < k_remaining) & n_valid
        w1 = tl.load(w1_ptrs, mask=w_mask, other=0.0)
        w2 = tl.load(w2_ptrs, mask=w_mask, other=0.0)

        # [BLOCK_M, BLOCK_K] @ [BLOCK_K, BLOCK_N]; one x tile feeds both products.
        acc1 += tl.dot(x, w1)
        acc2 += tl.dot(x, w2)

        # Advance all tiles to the next K block.
        x_ptrs += BLOCK_K * stride_xk
        w1_ptrs += BLOCK_K * stride_w1k
        w2_ptrs += BLOCK_K * stride_w2k

    # Gate nonlinearity, resolved at compile time. Every branch binds `gate`,
    # so no ACTIVATION value can leave the output undefined.
    if ACTIVATION == "silu":
        gate = acc1 * tl.sigmoid(acc1)
    elif ACTIVATION == "gelu":
        # tanh-approx GELU: 0.5*a*(1 + tanh(u)), u = sqrt(2/pi)*(a + 0.044715*a^3).
        # Since tanh(u) = 2*sigmoid(2u) - 1, this equals a * sigmoid(2u),
        # which avoids depending on a tanh intrinsic.
        inner = 1.5957691216057308 * (acc1 + 0.044715 * acc1 * acc1 * acc1)
        gate = acc1 * tl.sigmoid(inner)
    else:
        tl.static_assert(ACTIVATION == "none", "ACTIVATION must be 'silu', 'gelu' or 'none'")
        gate = acc1
    out = gate * acc2

    # Store the tile, masking rows/columns beyond the [M, N] output.
    out_ptrs = out_ptr + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
    tl.store(out_ptrs, out.to(tl.bfloat16), mask=m_valid & n_valid)


def fused_gated_proj(x, w1, w2, activation="silu"):
    """Compute ``act(x @ w1.T) * (x @ w2.T)`` with a single fused Triton kernel.

    Args:
        x: [M, K] bfloat16 input.
        w1: [N, K] bfloat16 gate weight, in ``nn.Linear.weight`` layout
            ([out_features, in_features]).
        w2: [N, K] bfloat16 up weight, same layout as ``w1``.
        activation: name forwarded to the kernel's compile-time
            ``ACTIVATION`` parameter (e.g. ``"silu"``).

    Returns:
        A new [M, N] bfloat16 tensor on ``x.device``.

    Raises:
        AssertionError: on a dtype mismatch or inconsistent shapes.
    """
    assert x.dtype == torch.bfloat16
    assert w1.dtype == torch.bfloat16 and w2.dtype == torch.bfloat16
    M, K = x.shape    # in the demo model: M = tokens, K = 4096 (hidden size)
    N, K2 = w1.shape  # in the demo model: N = 11264 (intermediate size), K2 = 4096
    assert K == K2, f"Dimension mismatch: x K={K}, w1 K={K2}"
    assert w2.shape == (N, K), f"w2 shape mismatch: expected {(N, K)}, got {w2.shape}"

    out = torch.empty(M, N, dtype=torch.bfloat16, device=x.device)
    if M == 0 or N == 0:
        # Empty output: nothing to compute, so skip the kernel launch entirely.
        return out

    grid = lambda META: (
        triton.cdiv(M, META['BLOCK_M']),
        triton.cdiv(N, META['BLOCK_N'])
    )

    # The kernel addresses each weight as a logical [K, N] matrix purely via
    # strides. Handing it the [N, K] tensors with their strides swapped gives
    # exactly that view for free. This replaces w.t().contiguous(), which
    # copied both full weight matrices on every call.
    gated_proj_kernel[grid](
        x, w1, w2, out,
        M, K, N,
        x.stride(0), x.stride(1),
        w1.stride(1), w1.stride(0),  # (stride along K, stride along N)
        w2.stride(1), w2.stride(0),
        out.stride(0), out.stride(1),
        ACTIVATION=activation,
        BLOCK_M=64,
        BLOCK_N=64,
        BLOCK_K=32,
    )
    return out


class ParallelGatedMLP(nn.Module):
    """Gated MLP exposing a PyTorch path and a fused Triton path for its gated projection.

    Both ``forward_org`` and ``forward_opt`` compute
    ``act(z @ l1.weight.T) * (z @ l2.weight.T)`` on the input flattened to
    2-D. They return that [tokens, ffn_hidden_size] tensor without reshaping
    it back. Neither applies the down projection ``l3``: it is created here
    but unused by these methods.

    Args:
        hidden_size: model width; ``in_features`` of ``l1``/``l2`` and
            ``out_features`` of ``l3``. Defaults to 4096.
        ffn_hidden_size: intermediate width; ``out_features`` of ``l1``/``l2``
            and ``in_features`` of ``l3``. Defaults to 11264.
    """

    def __init__(self, hidden_size=4096, ffn_hidden_size=11264):
        super().__init__()
        # `act` drives the PyTorch path; `act_type` names the same activation
        # for the Triton kernel. The two must stay in sync.
        self.act = F.silu
        self.act_type = "silu"

        # Gate projection: its output goes through the activation.
        self.l1 = nn.Linear(
            in_features=hidden_size,
            out_features=ffn_hidden_size,
            bias=False,
        )
        # Up projection: multiplied elementwise with the activated gate.
        self.l2 = nn.Linear(
            in_features=hidden_size,
            out_features=ffn_hidden_size,
            bias=False,
        )
        # Down projection: not used by either forward method below.
        self.l3 = nn.Linear(
            in_features=ffn_hidden_size,
            out_features=hidden_size,
            bias=False,
        )

    def forward_org(self, z):
        """Reference PyTorch implementation.

        Args:
            z: tensor whose last dimension is ``hidden_size``; all leading
                dimensions are flattened into one.

        Returns:
            [prod(leading dims), ffn_hidden_size] tensor ``act(l1(z)) * l2(z)``.
        """
        # reshape (not view) so non-contiguous inputs such as transposed
        # tensors also work; it is still zero-copy whenever strides allow.
        z_flat = z.reshape(-1, z.shape[-1])
        z1, z2 = self.l1(z_flat), self.l2(z_flat)
        return self.act(z1) * z2

    def forward_opt(self, z):
        """Fused Triton implementation; same contract as ``forward_org``.

        Delegates to ``fused_gated_proj``, which asserts that the input and
        both weights are bfloat16.
        """
        z_flat = z.reshape(-1, z.shape[-1])  # [M, K]
        return fused_gated_proj(
            z_flat,
            self.l1.weight,  # [N, K]
            self.l2.weight,  # [N, K]
            activation=self.act_type,
        )


if __name__ == "__main__":
    # Seed every RNG in play so the random weights and inputs are reproducible.
    seed = 1111
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    n_iters = 20

    def benchmark(fn, inp, iters):
        """Call ``fn(inp)`` ``iters`` times; return (last result, mean seconds per call).

        CUDA work is launched asynchronously, so the device is synchronized
        before and after the timed loop. Without that, only host-side launch
        overhead would be measured.
        """
        torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(iters):
            result = fn(inp)
        torch.cuda.synchronize()
        return result, (time.perf_counter() - start) / iters

    # Build the model and move it to bfloat16 on the first GPU.
    model = ParallelGatedMLP()
    model = model.to(dtype=torch.bfloat16, device="cuda:0")

    # Batch sizes to test.
    for batch_size in [1]:
        print(f"\n{'='*50}")
        print(f"Testing with batch_size={batch_size}")
        print('='*50)

        x = torch.randn(batch_size, 1, 4096, dtype=torch.bfloat16, device="cuda:0")

        with torch.no_grad():
            # Warm-up, so one-time costs (e.g. Triton JIT compilation) are
            # excluded from the timings below.
            for _ in range(3):
                _ = model.forward_org(x)
                _ = model.forward_opt(x)

            # Time the reference PyTorch version.
            result_org, t_org = benchmark(model.forward_org, x, n_iters)
            print(f"Time taken for forward_org: {t_org:.6f} seconds (mean of {n_iters} runs)")

            # Time the fused Triton version.
            result_opt, t_opt = benchmark(model.forward_opt, x, n_iters)
            print(f"Time taken for forward_opt: {t_opt:.6f} seconds (mean of {n_iters} runs)")

            # Check the output shapes.
            print(f"ORG shape: {result_org.shape}")
            print(f"OPT shape: {result_opt.shape}")

            # Absolute differences, computed in float32 so the statistics are
            # not themselves rounded to bfloat16.
            org_f = result_org.float()
            opt_f = result_opt.float()
            diff = torch.abs(org_f - opt_f)
            print(f"Max diff: {diff.max().item():.6f}")
            print(f"Mean diff: {diff.mean().item():.6f}")
            print(f"Min diff: {diff.min().item():.6f}")

            # Relative error.
            rel_error = diff / (torch.abs(org_f) + 1e-8)
            print(f"Max relative error: {rel_error.max().item():.6f}")
            print(f"Mean relative error: {rel_error.mean().item():.6f}")

            # Inspect the first few values side by side.
            print("\nFirst 10 values comparison:")
            print(f"ORG: {result_org[0, :10].float().cpu().numpy()}")
            print(f"OPT: {result_opt[0, :10].float().cpu().numpy()}")
            print(f"Diff: {diff[0, :10].float().cpu().numpy()}")

            # Overall pass/fail within a bfloat16-appropriate tolerance.
            if torch.allclose(result_org, result_opt, rtol=1e-2, atol=1e-3):
                print("✓ Results match within tolerance!")
            else:
                print("✗ Results do not match!")

    # # Extra check: verify mathematical equivalence on small matrices
    # print(f"\n{'='*50}")
    # print("Additional mathematical verification")
    # print('='*50)

    # # Small matrices for verification
    # test_x = torch.randn(2, 16, dtype=torch.bfloat16, device="cuda:0")
    # test_w1 = torch.randn(32, 16, dtype=torch.bfloat16, device="cuda:0")
    # test_w2 = torch.randn(32, 16, dtype=torch.bfloat16, device="cuda:0")

    # # PyTorch reference
    # z1_pt = F.linear(test_x, test_w1)  # x @ w1^T
    # z2_pt = F.linear(test_x, test_w2)  # x @ w2^T
    # result_pt = F.silu(z1_pt) * z2_pt

    # # Triton result
    # result_triton = fused_gated_proj(test_x, test_w1, test_w2, activation="silu")

    # diff_test = torch.abs(result_pt - result_triton)
    # print(f"Test max diff: {diff_test.max().item():.6f}")
    # print(f"Test mean diff: {diff_test.mean().item():.6f}")

    # if torch.allclose(result_pt, result_triton, rtol=1e-2, atol=1e-3):
    #     print("✓ Test passed: Triton implementation matches PyTorch!")
    # else:
    #     print("✗ Test failed: Triton implementation doesn't match PyTorch!")