import os

import torch
import triton
import triton.language as tl
from sgl_kernel import moe_sum_reduce as moe_sum_reduce_cuda
from triton.testing import do_bench

# CI environment detection
IS_CI = (
    os.getenv("CI", "false").lower() == "true"
    or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
)


@triton.jit
def _moe_sum_reduce_kernel(
    input_ptr,
    input_stride_0,
    input_stride_1,
    input_stride_2,
    output_ptr,
    output_stride_0,
    output_stride_1,
    token_num: int,
    topk_num: int,
    hidden_dim: int,
    routed_scaling_factor: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_DIM: tl.constexpr,
    NUM_STAGE: tl.constexpr,
):
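    # Each program reduces one (BLOCK_M, BLOCK_DIM) tile of the output: it sums
    # the topk axis in fp32, applies routed_scaling_factor, and writes the tile
    # back in the input dtype.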
    input_stride_0 = tl.cast(input_stride_0, dtype=tl.int64)
    input_stride_1 = tl.cast(input_stride_1, dtype=tl.int64)
    output_stride_0 = tl.cast(output_stride_0, dtype=tl.int64)

    token_block_id = tl.program_id(0)
    dim_block_id = tl.program_id(1)

    offs_token = token_block_id * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_dim = dim_block_id * BLOCK_DIM + tl.arange(0, BLOCK_DIM)

    mask_token = offs_token < token_num
    mask_dim = offs_dim < hidden_dim

    base_ptrs = input_ptr + offs_token[:, None] * input_stride_0 + offs_dim[None, :]

    accumulator = tl.zeros((BLOCK_M, BLOCK_DIM), dtype=tl.float32)
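    # Reduce over the topk experts; num_stages hints software pipelining of the loads.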
    for i in tl.range(0, topk_num, num_stages=NUM_STAGE):
        tile = tl.load(
            base_ptrs + i * input_stride_1,
            mask=mask_token[:, None] & mask_dim[None, :],
            other=0.0,
        )
        accumulator += tile.to(tl.float32)
    accumulator *= routed_scaling_factor

    # -------- Write back --------
    store_ptrs = output_ptr + offs_token[:, None] * output_stride_0 + offs_dim[None, :]
    tl.store(
        store_ptrs,
        accumulator.to(input_ptr.dtype.element_ty),
        mask=mask_token[:, None] & mask_dim[None, :],
    )


# _moe_sum_reduce_kernel is adapted from https://github.com/ModelTC/lightllm/blob/main/lightllm/common/fused_moe/moe_sum_reduce.py
def moe_sum_reduce_triton(
    input: torch.Tensor, output: torch.Tensor, routed_scaling_factor: float
):
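    """Sum `input` of shape (token_num, topk_num, hidden_dim) over the topk axis,
    scale by `routed_scaling_factor`, and write the result into `output`
    of shape (token_num, hidden_dim)."""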
    assert input.is_contiguous()
    assert output.is_contiguous()

    token_num, topk_num, hidden_dim = input.shape
    assert output.shape[0] == token_num and output.shape[1] == hidden_dim

    BLOCK_M = 1
    BLOCK_DIM = 2048
    NUM_STAGE = 1
    num_warps = 16

    grid = (
        triton.cdiv(token_num, BLOCK_M),
        triton.cdiv(hidden_dim, BLOCK_DIM),
    )

    _moe_sum_reduce_kernel[grid](
        input,
        *input.stride(),
        output,
        *output.stride(),
        token_num=token_num,
        topk_num=topk_num,
        hidden_dim=hidden_dim,
        routed_scaling_factor=routed_scaling_factor,
        BLOCK_M=BLOCK_M,
        BLOCK_DIM=BLOCK_DIM,
        NUM_STAGE=NUM_STAGE,
        num_warps=num_warps,
    )
    return


def compute_sum_scaled_baseline(
    x: torch.Tensor, out: torch.Tensor, routed_scaling_factor: float
) -> torch.Tensor:
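    """Eager reference: reduce over the topk axis, then scale in place."""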
    torch.sum(x, dim=1, out=out)
    out.mul_(routed_scaling_factor)
    return out


@torch.compile
def compute_sum_scaled_compiled(
    x: torch.Tensor, out: torch.Tensor, routed_scaling_factor: float
) -> torch.Tensor:
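    """torch.compile variant that scales the inputs before the reduction."""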
    torch.sum(x * routed_scaling_factor, dim=1, out=out)
    return out


def get_benchmark(dtype=torch.bfloat16):
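    """Build a triton.testing benchmark comparing the eager, compiled, Triton,
    and CUDA implementations across token counts from 1 to 4096."""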
    num_tokens_range = [2**i for i in range(0, 13)]

    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=["num_tokens"],
            x_vals=num_tokens_range,
            line_arg="version",
            line_vals=["baseline", "compiled", "triton", "cuda"],
            line_names=["Original", "TorchCompile", "TritonKernel", "CudaKernel"],
            styles=[("blue", "-"), ("green", "-"), ("red", "-"), ("yellow", "-")],
            ylabel="us",
            plot_name=f"sum_scaled_performance_{str(dtype).split('.')[-1]}",
            args={},
        )
    )
    def benchmark(num_tokens, version):
        topk = 9
        hidden_size = 4096
        # dtype comes from the enclosing get_benchmark(dtype=...) call
        scaling_factor = 0.3

        x = torch.randn(num_tokens, topk, hidden_size, dtype=dtype, device="cuda")
        out = torch.empty(num_tokens, hidden_size, dtype=dtype, device="cuda")

        # Warmup
        for _ in range(3):
            if version == "baseline":
                compute_sum_scaled_baseline(x, out, scaling_factor)
            elif version == "compiled":
                compute_sum_scaled_compiled(x, out, scaling_factor)
            elif version == "triton":
                moe_sum_reduce_triton(x, out, scaling_factor)
            else:
                moe_sum_reduce_cuda(x, out, scaling_factor)

        # Benchmark
        quantiles = [0.5, 0.2, 0.8]
        if version == "baseline":
            ms, min_ms, max_ms = do_bench(
                lambda: compute_sum_scaled_baseline(x, out, scaling_factor),
                quantiles=quantiles,
            )
        elif version == "compiled":
            ms, min_ms, max_ms = do_bench(
                lambda: compute_sum_scaled_compiled(x, out, scaling_factor),
                quantiles=quantiles,
            )
        elif version == "triton":
            ms, min_ms, max_ms = do_bench(
                lambda: moe_sum_reduce_triton(x, out, scaling_factor),
                quantiles=quantiles,
            )
        else:
            ms, min_ms, max_ms = do_bench(
                lambda: moe_sum_reduce_cuda(x, out, scaling_factor),
                quantiles=quantiles,
            )

        # perf_report expects (y, y_min, y_max); convert ms to us
        return 1000 * ms, 1000 * min_ms, 1000 * max_ms

    return benchmark


def verify_correctness(num_tokens=1024, dtype=torch.bfloat16):
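    """Check the compiled, Triton, and CUDA paths against the eager baseline
    with dtype-dependent tolerances (the Triton kernel is skipped for float64)."""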
    x = torch.randn(num_tokens, 9, 4096, device="cuda", dtype=dtype)
    scaling_factor = 0.3

    out_baseline = torch.empty_like(x[:, 0])
    compute_sum_scaled_baseline(x, out_baseline, scaling_factor)

    out_compiled = torch.empty_like(out_baseline)
    compute_sum_scaled_compiled(x, out_compiled, scaling_factor)

    out_cuda = torch.empty_like(out_baseline)
    moe_sum_reduce_cuda(x, out_cuda, scaling_factor)

    triton_skipped = dtype == torch.float64
    if not triton_skipped:
        out_triton = torch.empty_like(out_baseline)
        moe_sum_reduce_triton(x, out_triton, scaling_factor)

    if dtype == torch.float64:
        atol, rtol = 1e-12, 1e-12
    elif dtype == torch.float32:
        atol, rtol = 1e-6, 1e-6
    else:  # bfloat16 / float16
        atol, rtol = 1e-2, 1e-2

    ok_compiled = torch.allclose(out_baseline, out_compiled, atol=atol, rtol=rtol)
    ok_cuda = torch.allclose(out_baseline, out_cuda, atol=atol, rtol=rtol)
    ok_triton = (
        True
        if triton_skipped
        else torch.allclose(out_baseline, out_triton, atol=atol, rtol=rtol)
    )

    if ok_compiled and ok_triton and ok_cuda:
        msg = "✅ All implementations match"
        if triton_skipped:
            msg += " (Triton skipped for float64)"
        print(msg)
    else:
        print("❌ Implementations differ")
        print(
            f"Baseline vs Compiled: {(out_baseline - out_compiled).abs().max().item()}"
        )
        if not triton_skipped:
            print(
                f"Baseline vs Triton: {(out_baseline - out_triton).abs().max().item()}"
            )
        print(f"Baseline vs Cuda: {(out_baseline - out_cuda).abs().max().item()}")


if __name__ == "__main__":
    print("Running correctness verification for bfloat16...")
    verify_correctness(dtype=torch.bfloat16)

    # Skip the slower float64 check in CI
    if not IS_CI:
        print("Running correctness verification for float64...")
        verify_correctness(dtype=torch.float64)

    print("Running correctness verification for float64...")
    verify_correctness(dtype=torch.float64)
245

246
247
    print("\nRunning performance benchmark for bfloat16...")
    benchmark = get_benchmark(dtype=torch.bfloat16)
    benchmark.run(
        print_data=True,
        # save_path="./configs/benchmark_ops/sum_scaled/"
    )