from typing import Optional

import torch
import triton
import triton.language as tl
from torch import Tensor

from lightx2v.models.networks.wan.infer.triton_ops import fuse_scale_shift_kernel as wan_fuse_scale_shift_kernel


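# Thin adapter around the Wan fuse_scale_shift kernel: adds a leading batch
# dim to scale/shift on the way in and strips it from the result.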
def fuse_scale_shift_kernel(
    x: torch.Tensor,
    scale: torch.Tensor,
    shift: torch.Tensor,
    block_l: int = 128,
    block_c: int = 128,
):
    return wan_fuse_scale_shift_kernel(x, scale.unsqueeze(0), shift.unsqueeze(0), block_l=block_l, block_c=block_c).squeeze(0)


@triton.jit
def _fused_rmsnorm_modulate_kernel(
    X,  # Input [M, N]
    Y,  # Output [M, N]
    Scale,  # Scale tensor (various shapes supported)
    Shift,  # Shift tensor (various shapes supported)
    W,  # Optional weight for RMSNorm [N]
    B,  # Optional bias for RMSNorm [N]
    stride_x_row,
    stride_y_row,
    stride_scale_row,  # For 2D/3D scale
    stride_shift_row,  # For 2D/3D shift
    M,
    N,
    eps,
    HAS_WEIGHT: tl.constexpr,
    HAS_BIAS: tl.constexpr,
    SCALE_IS_4D: tl.constexpr,  # [B, F, 1, C] format
    num_frames: tl.constexpr,  # For 4D scale/shift
    frame_seqlen: tl.constexpr,  # For 4D scale/shift
    BLOCK_N: tl.constexpr,
):
    row = tl.program_id(0)
    X += row * stride_x_row
    Y += row * stride_y_row

    cols = tl.arange(0, BLOCK_N)
    mask = cols < N

    # Load input
    x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32)

    # Step 1: RMSNorm
    xbar = tl.where(mask, x, 0.0)
    var = tl.sum(xbar * xbar, axis=0) / N
    rstd = 1 / tl.sqrt(var + eps)
    x_hat = x * rstd

    # Apply optional weight and bias for RMSNorm
    if HAS_WEIGHT:
        w = tl.load(W + cols, mask=mask, other=1.0).to(tl.float32)
        x_hat = x_hat * w

    if HAS_BIAS:
        b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32)
        x_hat = x_hat + b

    # Step 2: Load scale and shift based on format
    if SCALE_IS_4D:
        # For 4D: [B, F, 1, C] -> need to map row to correct frame
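        # Rows of X are laid out batch-major as (batch, frame, position-in-frame),
        # so each run of frame_seqlen consecutive rows shares one [C] scale/shift row.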
        batch_idx = row // (num_frames * frame_seqlen)
        t_idx = row % (num_frames * frame_seqlen)
        frame_idx = t_idx // frame_seqlen
        scale_row_idx = batch_idx * num_frames + frame_idx

        Scale += scale_row_idx * N
        Shift += scale_row_idx * N
        scale = tl.load(Scale + cols, mask=mask, other=0.0).to(tl.float32)
        shift = tl.load(Shift + cols, mask=mask, other=0.0).to(tl.float32)
    else:
        # For 2D/3D: direct row indexing with stride
        Scale += row * stride_scale_row
        Shift += row * stride_shift_row
        scale = tl.load(Scale + cols, mask=mask, other=0.0).to(tl.float32)
        shift = tl.load(Shift + cols, mask=mask, other=0.0).to(tl.float32)

    # Step 3: Apply modulation: x_hat * (1 + scale) + shift
    y = x_hat * (1.0 + scale) + shift

    tl.store(Y + cols, y, mask=mask)


def fused_rmsnorm_modulate(
    x: Tensor,
    scale: Tensor,
    shift: Tensor,
    weight: Optional[Tensor] = None,
    bias: Optional[Tensor] = None,
    eps: float = 1e-6,
    out: Optional[Tensor] = None,
) -> Tensor:
    """
    融合的 RMSNorm + Modulation 操作

    计算: (RMSNorm(x) * weight + bias) * (1 + scale) + shift

    Args:
        x: 输入张量 [B, L, C] 或 [M, N]
        scale: 调制缩放张量,支持多种形状:
            - [B, F, 1, C]: 4D 格式(帧级调制)
            - [B, L, C] 或 [B, 1, C]: 3D 格式
            - [1, C]: 2D 格式
        shift: 调制偏移张量,形状同 scale
        weight: RMSNorm 权重 [C],可选
        bias: RMSNorm 偏置 [C],可选
        eps: RMSNorm 的 epsilon
        out: 输出张量,可选

    Returns:
        调制后的张量,形状与 x 相同
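
    Example (shapes are illustrative):
        >>> x = torch.randn(2, 128, 1024, device="cuda")
        >>> scale = torch.randn(2, 1, 1024, device="cuda")  # [B, 1, C], broadcast over L
        >>> shift = torch.randn(2, 1, 1024, device="cuda")
        >>> y = fused_rmsnorm_modulate(x, scale, shift)  # y.shape == (2, 128, 1024)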
    """
    # Reshape to 2D for processing
    original_shape = x.shape
    if x.dim() == 3:
        B, L, C = x.shape
        x_2d = x.reshape(B * L, C)
    elif x.dim() == 2:
        x_2d = x
        B, L = 1, x.shape[0]
        C = x.shape[1]
    else:
        raise ValueError(f"Input must be 2D or 3D, got {x.dim()}D")

    M, N = x_2d.shape
    x_2d = x_2d.contiguous()

    # Validate weight and bias
    if weight is not None:
        assert weight.shape == (N,) and weight.stride(-1) == 1
    if bias is not None:
        assert bias.shape == (N,) and bias.stride(-1) == 1

    # Prepare output
    if out is None:
        out_2d = torch.empty_like(x_2d)
    else:
        out_2d = out.reshape(M, N) if out.dim() == 3 else out

    # Determine scale/shift format
    is_4d = scale.dim() == 4

    if is_4d:
        # [B, F, 1, C] format (per-frame modulation)
        assert scale.shape[0] in (B, 1)
        num_frames = scale.shape[1]
        assert scale.shape[2] == 1
        assert scale.shape[3] == C
        assert L % num_frames == 0, f"seq_len {L} must be divisible by num_frames {num_frames}"
        frame_seqlen = L // num_frames

        # Flatten to [B*F, C]; broadcast a size-1 batch so the kernel's
        # row -> (batch, frame) indexing stays valid when B > 1.
        scale_2d = scale.squeeze(2).expand(B, num_frames, C).reshape(-1, C).contiguous()
        shift_2d = shift.squeeze(2).expand(B, num_frames, C).reshape(-1, C).contiguous()
        stride_scale_row = 0  # unused in the 4D path
        stride_shift_row = 0
    else:
        # Handle 2D/3D formats
        if scale.dim() == 2:
            # [B, C] or [1, C] -> expand to [M, C]
            scale_exp = scale.expand(M, N).contiguous()
            shift_exp = shift.expand(M, N).contiguous()
        elif scale.dim() == 3:
            # [B, L, C] or [B, 1, C] -> broadcast along L, then flatten to [M, C]
            scale_exp = scale.expand(B, L, C).reshape(M, N).contiguous()
            shift_exp = shift.expand(B, L, C).reshape(M, N).contiguous()
        else:
            raise ValueError(f"scale must be 2D, 3D, or 4D, got {scale.dim()}D")

        scale_2d = scale_exp
        shift_2d = shift_exp
        stride_scale_row = scale_2d.stride(0)
        stride_shift_row = shift_2d.stride(0)
        num_frames = 0
        frame_seqlen = 0

    # Kernel configuration
    # Each program handles one full row; cap the block at 64KB of data.
    MAX_FUSED_SIZE = 65536 // x_2d.element_size()
    BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
    if N > BLOCK_N:
        raise RuntimeError(f"This kernel doesn't support rows larger than 64KB (got N={N}).")

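    # Heuristic: roughly one warp per 256 columns, clamped to [1, 8].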
    num_warps = min(max(BLOCK_N // 256, 1), 8)

    # Launch kernel
    grid = (M,)
    _fused_rmsnorm_modulate_kernel[grid](
        x_2d,
        out_2d,
        scale_2d,
        shift_2d,
        weight if weight is not None else x_2d,  # dummy when HAS_WEIGHT=False
        bias if bias is not None else x_2d,  # dummy when HAS_BIAS=False
        x_2d.stride(0),
        out_2d.stride(0),
        stride_scale_row,
        stride_shift_row,
        M,
        N,
        eps,
        HAS_WEIGHT=weight is not None,
        HAS_BIAS=bias is not None,
        SCALE_IS_4D=is_4d,
        num_frames=num_frames if is_4d else 0,
        frame_seqlen=frame_seqlen if is_4d else 0,
        BLOCK_N=BLOCK_N,
        num_warps=num_warps,
    )

    # Reshape back to original shape
    return out_2d.reshape(original_shape) if out is None else out


# Correctness test
def test_fused_rmsnorm_modulate():
    """Check the fused kernel against a pure-PyTorch reference."""
    B, L, C = 2, 128, 1024
    eps = 1e-6

    # Test data
    x = torch.randn(B, L, C, device="cuda", dtype=torch.float32)
    scale = torch.randn(B, L, C, device="cuda", dtype=torch.float32)
    shift = torch.randn(B, L, C, device="cuda", dtype=torch.float32)
    weight = torch.randn(C, device="cuda", dtype=torch.float32)
    bias = torch.randn(C, device="cuda", dtype=torch.float32)

    # PyTorch reference implementation (without weight/bias)
    def reference_impl(x, scale, shift, eps):
        return (x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)) * (1 + scale) + shift

    # Test without weight/bias
    out_ref = reference_impl(x, scale, shift, eps)
    out_triton = fused_rmsnorm_modulate(x, scale, shift, eps=eps)

    print("测试不带 weight/bias:")
    print(f"  最大误差: {(out_ref - out_triton).abs().max().item():.6e}")
    print(f"  平均误差: {(out_ref - out_triton).abs().mean().item():.6e}")
    assert torch.allclose(out_ref, out_triton, rtol=1e-4, atol=1e-5), "不带 weight/bias 的测试失败"

    # Test with weight/bias
    def reference_impl_with_wb(x, scale, shift, weight, bias, eps):
        normed = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
        normed = normed * weight + bias
        return normed * (1 + scale) + shift

    out_ref_wb = reference_impl_with_wb(x, scale, shift, weight, bias, eps)
    out_triton_wb = fused_rmsnorm_modulate(x, scale, shift, weight, bias, eps=eps)

    print("\n测试带 weight/bias:")
    print(f"  最大误差: {(out_ref_wb - out_triton_wb).abs().max().item():.6e}")
    print(f"  平均误差: {(out_ref_wb - out_triton_wb).abs().mean().item():.6e}")
    assert torch.allclose(out_ref_wb, out_triton_wb, rtol=1e-4, atol=1e-5), "带 weight/bias 的测试失败"

    # Test the 4D scale/shift format
    num_frames = 8
    scale_4d = torch.randn(B, num_frames, 1, C, device="cuda", dtype=torch.float32)
    shift_4d = torch.randn(B, num_frames, 1, C, device="cuda", dtype=torch.float32)

    # Expand scale_4d to [B, L, C] for the reference implementation
    frame_len = L // num_frames
    scale_expanded = scale_4d.squeeze(2).repeat_interleave(frame_len, dim=1)
    shift_expanded = shift_4d.squeeze(2).repeat_interleave(frame_len, dim=1)

    out_ref_4d = reference_impl(x, scale_expanded, shift_expanded, eps)
    out_triton_4d = fused_rmsnorm_modulate(x, scale_4d, shift_4d, eps=eps)

    print("\n测试 4D scale/shift 格式:")
    print(f"  最大误差: {(out_ref_4d - out_triton_4d).abs().max().item():.6e}")
    print(f"  平均误差: {(out_ref_4d - out_triton_4d).abs().mean().item():.6e}")
    assert torch.allclose(out_ref_4d, out_triton_4d, rtol=1e-4, atol=1e-5), "4D 格式测试失败"

    print("\n✅ 所有测试通过!")


if __name__ == "__main__":
    test_fused_rmsnorm_modulate()