# NOTE: removed non-Python residue (filename/author/line-number gutter pasted
# from a web code viewer) that made this file unparseable.
import torch
import torch.nn as nn
import torch.nn.functional as F
import triton
import triton.language as tl

import time


import torch
import numpy as np
import random





@triton.jit
def gated_proj_kernel(
    x_ptr, w1_ptr, w2_ptr, out_ptr,
    M, K, N,
    stride_xm, stride_xk,
    stride_wk, stride_wn,  # w1/w2 are [N, K]: stride_wn steps rows, stride_wk steps columns
    stride_om, stride_on,
    ACTIVATION: tl.constexpr,
    BLOCK_M: tl.constexpr = 64,
    BLOCK_N: tl.constexpr = 64,
    BLOCK_K: tl.constexpr = 32,
):
    """Fused gated projection: out = act(x @ w1.T) * (x @ w2.T).

    x is [M, K], w1/w2 are [N, K] (nn.Linear layout), out is [M, N] bf16.
    Both matmuls share the loaded x tile and accumulate in fp32.
    ACTIVATION selects "silu", "gelu" (tanh approximation), or identity.
    """
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)

    x_ptrs = x_ptr + offs_m[:, None] * stride_xm + offs_k[None, :] * stride_xk
    w1_ptrs = w1_ptr + offs_n[:, None] * stride_wn + offs_k[None, :] * stride_wk
    w2_ptrs = w2_ptr + offs_n[:, None] * stride_wn + offs_k[None, :] * stride_wk

    acc1 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    acc2 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    for k in range(0, K, BLOCK_K):
        # offs_k stays fixed at arange(0, BLOCK_K); only the pointers advance,
        # so the valid-column test for this tile is offs_k < K - k.
        k_mask = offs_k[None, :] < K - k
        x = tl.load(x_ptrs, mask=(offs_m[:, None] < M) & k_mask, other=0.0)
        w1 = tl.load(w1_ptrs, mask=(offs_n[:, None] < N) & k_mask, other=0.0)
        w2 = tl.load(w2_ptrs, mask=(offs_n[:, None] < N) & k_mask, other=0.0)

        acc1 += tl.dot(x, w1.T)
        acc2 += tl.dot(x, w2.T)

        x_ptrs += BLOCK_K * stride_xk
        w1_ptrs += BLOCK_K * stride_wk
        w2_ptrs += BLOCK_K * stride_wk
        # BUG FIX: the original also did `offs_k += BLOCK_K` here. Combined with
        # the `offs_k < K - k` mask above, that made iteration i mask with
        # `arange < K - 2*i*BLOCK_K`, zeroing every K-tile past the halfway
        # point and dropping half of the reduction.

    z1 = acc1
    z2 = acc2

    if ACTIVATION == "silu":
        out = z1 * tl.sigmoid(z1) * z2
    elif ACTIVATION == "gelu":
        # tanh-approximation GELU (the original silently fell back to silu).
        # tanh(t) is expressed through sigmoid: tanh(t) = 2*sigmoid(2t) - 1.
        inner = 0.7978845608028654 * (z1 + 0.044715 * z1 * z1 * z1)
        tanh_inner = 2.0 * tl.sigmoid(2.0 * inner) - 1.0
        out = 0.5 * z1 * (1.0 + tanh_inner) * z2
    else:
        out = z1 * z2

    out_ptrs = out_ptr + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
    tl.store(out_ptrs, out.to(tl.bfloat16), mask=(offs_m[:, None] < M) & (offs_n[None, :] < N))


def fused_gated_proj(x, w1, w2, activation="silu"):
    """Launch the fused gated-projection kernel: act(x @ w1.T) * (x @ w2.T).

    Args:
        x: [M, K] bf16 CUDA tensor (e.g. M=1, K=4096).
        w1, w2: [N, K] bf16 weights in nn.Linear layout (e.g. N=11264, K=4096).
        activation: "silu", "gelu", or anything else for identity.

    Returns:
        [M, N] bf16 tensor.
    """
    assert x.dtype == torch.bfloat16
    assert w1.dtype == torch.bfloat16 and w2.dtype == torch.bfloat16
    M, K = x.shape
    N, _ = w1.shape  # [N, K] = [out_features, in_features]
    # Both weights must share the layout and match x's reduction dimension.
    assert w1.shape == (N, K)
    assert w2.shape == (N, K)

    out = torch.empty(M, N, dtype=torch.bfloat16, device=x.device)

    # 2D launch grid: one program per (BLOCK_M, BLOCK_N) output tile.
    grid = lambda META: (
        triton.cdiv(M, META['BLOCK_M']),
        triton.cdiv(N, META['BLOCK_N'])
    )

    gated_proj_kernel[grid](
        x, w1, w2, out,
        M, K, N,
        x.stride(0), x.stride(1),
        w1.stride(1), w1.stride(0),  # stride_wk (within a row), stride_wn (between rows)
        out.stride(0), out.stride(1),
        ACTIVATION=activation,
        BLOCK_M=64,
        BLOCK_N=64,
        BLOCK_K=32,
    )
    return out


class ParallelGatedMLP(nn.Module):
    """Gated MLP with a fused-Triton fast path and an eager reference path.

    Computes gated = act(l1(z)) * l2(z) on the flattened input and returns it.
    The down-projection ``l3`` is defined but intentionally NOT applied in any
    forward variant (matching the original benchmark code).
    """

    def __init__(self, hidden_size: int = 4096, inner_size: int = 11264):
        """Build the projections.

        Args:
            hidden_size: input feature dimension (default 4096, the original
                hard-coded value).
            inner_size: gated inner dimension (default 11264).
        """
        super().__init__()

        self.act = F.silu
        self.act_type = "silu"

        self.l1 = nn.Linear(in_features=hidden_size, out_features=inner_size, bias=False)
        self.l2 = nn.Linear(in_features=hidden_size, out_features=inner_size, bias=False)
        self.l3 = nn.Linear(in_features=inner_size, out_features=hidden_size, bias=False)

        # Make sure the weights are contiguous (nn.Linear usually is, but the
        # Triton kernel's stride assumptions make this worth guaranteeing).
        for lin in (self.l1, self.l2, self.l3):
            lin.weight = nn.Parameter(lin.weight.contiguous())

    def forward(self, z):
        """Default path — delegates to the fused Triton implementation."""
        return self.forward_opt(z)

    def forward_org(self, z):
        """Eager reference path: act(l1(z)) * l2(z) on [.., D] -> [M, inner]."""
        z_flat = z.view(-1, z.shape[-1])
        return self.act(self.l1(z_flat)) * self.l2(z_flat)

    def forward_opt(self, z):
        """Fused Triton path; requires bf16 CUDA tensors.

        z: [B, S, D] (or any [.., D]) -> flattened to [M, D] before the kernel.
        """
        z_flat = z.view(-1, int(z.shape[-1]))
        return fused_gated_proj(
            z_flat,
            self.l1.weight,  # [inner_size, hidden_size]
            self.l2.weight,
            activation=self.act_type,
        )

if __name__ == "__main__":

    # Seed every RNG source so the run is reproducible.
    SEED = 1111
    torch.manual_seed(SEED)
    torch.cuda.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)  # covers multi-GPU setups
    np.random.seed(SEED)
    random.seed(SEED)

    # Optional: trade performance for reproducibility
    # (some CUDA operations are nondeterministic by default).
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Build the model and move it to bf16 on the GPU
    # (needs a bf16-capable device such as A100/H100).
    model = ParallelGatedMLP().to(dtype=torch.bfloat16, device="cuda:0")

    # Input tensor: batch=1, seq_len=1, hidden=4096.
    device = "cuda:0"
    x = torch.randn(1, 1, 4096, dtype=torch.bfloat16, device=device)

    # Compare the eager reference path against the fused Triton path.
    with torch.no_grad():
        result_org = model.forward_org(x)
        print(f"ORG: {result_org[0, :20]}")
        result_opt = model.forward_opt(x)
        print(f"OPT: {result_opt[0, :20]}")