test_grouped_gemm.py 10.6 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
import os
import sys

import pytest
import torch

# Put this file's own directory at the front of the import path.
# NOTE(review): presumably so the `fbgemm_grouped_gemm` module imported below
# is picked up from next to this test file -- confirm its location.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

# FBGEMM backend is optional: record whether the import worked in
# FBGEMM_AVAILABLE (used by the skipif markers further down) and report the
# outcome on stdout instead of letting an ImportError abort the module.
try:
    from fbgemm_grouped_gemm import grouped_gemm as fbgemm_grouped_gemm
    from fbgemm_grouped_gemm import (
        grouped_gemm_fp8_rowwise as fbgemm_grouped_gemm_fp8_rowwise,
    )

    FBGEMM_AVAILABLE = True
    print("✓ Successfully imported FBGEMM grouped GEMM")
except ImportError as e:
    print(f"✗ Failed to import FBGEMM grouped GEMM: {e}")
    FBGEMM_AVAILABLE = False

# Same pattern for the SGLang Triton grouped GEMM: SGLANG_AVAILABLE tells the
# tests whether `sglang_grouped_gemm` exists.
try:
    from sglang.srt.layers.moe.ep_moe.kernels import (
        grouped_gemm_triton as sglang_grouped_gemm,
    )

    SGLANG_AVAILABLE = True
    print("✓ Successfully imported SGLang grouped GEMM")
except ImportError as e:
    print(f"✗ Failed to import SGLang grouped GEMM: {e}")
    SGLANG_AVAILABLE = False


def create_uniform_groups(batch_size, num_groups, device):
    """Return an int64 tensor holding ``num_groups`` identical group sizes.

    Each entry equals ``batch_size // num_groups``. The remainder of that
    integer division is dropped, so the entries add up to ``batch_size`` only
    when ``batch_size`` is a multiple of ``num_groups``.

    Args:
        batch_size: Total number of rows to distribute.
        num_groups: Number of groups (length of the returned tensor).
        device: Device on which the tensor is allocated.
    """
    per_group = batch_size // num_groups
    sizes = torch.empty(num_groups, dtype=torch.int64, device=device)
    # fill_ works in place and returns the same tensor.
    return sizes.fill_(per_group)


def create_non_uniform_groups(batch_size, num_groups, device):
    """Randomly split ``batch_size`` rows into ``num_groups`` non-empty groups.

    The first ``num_groups - 1`` sizes are drawn one after another with
    ``torch.randint`` (default generator, so ``torch.manual_seed`` makes the
    split reproducible); the last group receives whatever is left. Each draw
    is capped so that every later group can still get at least one row.

    Args:
        batch_size: Total number of rows; must be at least ``num_groups``.
        num_groups: Number of groups to produce; must be at least 1.
        device: Device on which the returned tensor is allocated.

    Returns:
        An int64 tensor of length ``num_groups`` whose entries are all >= 1
        and sum to ``batch_size``.

    Raises:
        ValueError: If ``num_groups < 1`` or ``batch_size < num_groups``.
            Without this check such inputs produced zero/negative sizes, a
            wrong-length result, or an opaque ``torch.randint`` error.
    """
    if num_groups < 1:
        raise ValueError(f"num_groups must be >= 1, got {num_groups}")
    if batch_size < num_groups:
        raise ValueError(
            f"batch_size ({batch_size}) must be >= num_groups ({num_groups}) "
            "so that every group gets at least one row"
        )

    remaining = batch_size
    m_sizes = []

    for i in range(num_groups - 1):
        groups_after = num_groups - i - 1
        # randint's upper bound is exclusive, hence the +1. The validation
        # above guarantees remaining > groups_after, so the range [1, max_size)
        # is never empty.
        max_size = remaining - groups_after + 1
        size = torch.randint(1, max_size, (1,)).item()
        m_sizes.append(size)
        remaining -= size

    m_sizes.append(remaining)
    return torch.tensor(m_sizes, dtype=torch.int64, device=device)


def create_sglang_inputs(x, w, m_sizes, num_groups, intermediate_size, device):
    """Build the auxiliary tensors passed to the SGLang grouped GEMM call.

    Args:
        x: Input tensor; only ``x.shape[0]`` (the row count) is read.
        w: Weight tensor that is reinterpreted, without copying, as
            ``(num_groups, intermediate_size, -1)``.
        m_sizes: Per-group row counts (tensor or sequence of ints) with
            exactly ``num_groups`` entries.
        num_groups: Number of groups.
        intermediate_size: Number of output columns per group.
        device: Device on which the new tensors are allocated.

    Returns:
        A tuple ``(c_sglang, seg_indptr, weight_indices, w_sglang)``:
        an uninitialised bfloat16 output buffer of shape
        ``(x.shape[0], intermediate_size)``; an int64 tensor of length
        ``num_groups + 1`` holding a leading 0 followed by the running sum of
        ``m_sizes``; ``arange(num_groups)`` as int64; and the 3-D view of ``w``.

    Raises:
        ValueError: If ``m_sizes`` does not contain ``num_groups`` entries.
    """
    batch_size = x.shape[0]

    c_sglang = torch.empty(
        batch_size, intermediate_size, dtype=torch.bfloat16, device=device
    )

    group_sizes = torch.as_tensor(m_sizes, dtype=torch.int64, device=device)
    if group_sizes.numel() != num_groups:
        # Guard explicitly: a length-1 tensor would otherwise broadcast
        # silently in the slice assignment below.
        raise ValueError(
            f"m_sizes has {group_sizes.numel()} entries, expected {num_groups}"
        )

    # Segment boundaries: entry 0 stays 0, entry i + 1 is the sum of the first
    # i + 1 group sizes. A single cumsum replaces a per-element Python loop.
    seg_indptr = torch.zeros(num_groups + 1, dtype=torch.int64, device=device)
    seg_indptr[1:] = torch.cumsum(group_sizes, dim=0)

    weight_indices = torch.arange(num_groups, dtype=torch.int64, device=device)
    w_sglang = w.view(num_groups, intermediate_size, -1)

    return c_sglang, seg_indptr, weight_indices, w_sglang


def create_fp8_data(batch_size, num_groups, hidden_size, intermediate_size, device):
    torch.manual_seed(42)

    x_fp16 = torch.randn(batch_size, hidden_size, dtype=torch.float16, device=device)
    w_fp16 = torch.randn(
        num_groups * intermediate_size, hidden_size, dtype=torch.float16, device=device
    )

    x_fp8 = x_fp16.to(torch.float8_e4m3fn)
    w_fp8 = w_fp16.to(torch.float8_e4m3fn)

    x_scale = torch.randn(batch_size, dtype=torch.float32, device=device).abs() + 1e-4
    w_scale = torch.randn(num_groups, dtype=torch.float32, device=device).abs() + 1e-4

    return x_fp8, w_fp8, x_scale, w_scale


@pytest.fixture
def device():
    """Provide the CUDA device, skipping the requesting test without CUDA."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    # pytest.skip raises, so nothing is returned on this path.
    pytest.skip("CUDA not available")


@pytest.mark.skipif(not FBGEMM_AVAILABLE, reason="FBGEMM not available")
@pytest.mark.skipif(not SGLANG_AVAILABLE, reason="SGLang not available")
@pytest.mark.parametrize("batch_size", [32])
@pytest.mark.parametrize("num_groups", [2, 4, 8])
@pytest.mark.parametrize("hidden_size", [512, 1024])
@pytest.mark.parametrize("intermediate_size", [1024, 2048])
def test_uniform_groups(batch_size, num_groups, hidden_size, intermediate_size, device):
    """Check that FBGEMM and SGLang grouped GEMM agree for equal-sized groups.

    Random bfloat16 inputs and weights are generated under a fixed seed, both
    backends are run on them, and the outputs must match within
    ``rtol=1e-3`` / ``atol=1e-3``.
    """
    if batch_size % num_groups:
        pytest.skip(f"batch_size {batch_size} not divisible by num_groups {num_groups}")

    torch.manual_seed(42)

    group_sizes = create_uniform_groups(batch_size, num_groups, device)

    bf16_opts = {"dtype": torch.bfloat16, "device": device}
    activations = torch.randn(batch_size, hidden_size, **bf16_opts)
    weights = torch.randn(num_groups * intermediate_size, hidden_size, **bf16_opts)

    fbgemm_out = fbgemm_grouped_gemm(
        activations, weights, group_sizes, use_fast_accum=True
    )

    out_buffer, seg_indptr, weight_indices, grouped_weights = create_sglang_inputs(
        activations, weights, group_sizes, num_groups, intermediate_size, device
    )
    sglang_kwargs = {
        "weight_column_major": True,
        "seg_indptr": seg_indptr,
        "weight_indices": weight_indices,
        "c_dtype": out_buffer.dtype,
    }
    sglang_out = sglang_grouped_gemm(
        activations, grouped_weights, out_buffer, num_groups, **sglang_kwargs
    )

    assert torch.allclose(fbgemm_out, sglang_out, rtol=1e-3, atol=1e-3)


@pytest.mark.skipif(not FBGEMM_AVAILABLE, reason="FBGEMM not available")
@pytest.mark.skipif(not SGLANG_AVAILABLE, reason="SGLang not available")
@pytest.mark.parametrize("batch_size", [63, 100, 127])
@pytest.mark.parametrize("num_groups", [3, 5, 7])
@pytest.mark.parametrize("hidden_size", [512, 1024])
@pytest.mark.parametrize("intermediate_size", [1024, 2048])
def test_non_uniform_groups(
    batch_size, num_groups, hidden_size, intermediate_size, device
):
    """Check that FBGEMM and SGLang grouped GEMM agree for random group sizes.

    Group sizes come from ``create_non_uniform_groups`` under a fixed seed;
    both backends then run on the same bfloat16 data and their outputs must
    match within ``rtol=1e-3`` / ``atol=1e-3``.
    """
    torch.manual_seed(42)

    group_sizes = create_non_uniform_groups(batch_size, num_groups, device)

    tensor_opts = dict(dtype=torch.bfloat16, device=device)
    inputs = torch.randn(batch_size, hidden_size, **tensor_opts)
    stacked_weights = torch.randn(
        num_groups * intermediate_size, hidden_size, **tensor_opts
    )

    reference = fbgemm_grouped_gemm(
        inputs, stacked_weights, group_sizes, use_fast_accum=True
    )

    sglang_inputs = create_sglang_inputs(
        inputs, stacked_weights, group_sizes, num_groups, intermediate_size, device
    )
    out_buffer, seg_indptr, weight_indices, grouped_weights = sglang_inputs

    candidate = sglang_grouped_gemm(
        inputs,
        grouped_weights,
        out_buffer,
        num_groups,
        weight_column_major=True,
        seg_indptr=seg_indptr,
        weight_indices=weight_indices,
        c_dtype=out_buffer.dtype,
    )

    assert torch.allclose(reference, candidate, rtol=1e-3, atol=1e-3)


@pytest.mark.skipif(not FBGEMM_AVAILABLE, reason="FBGEMM not available")
@pytest.mark.skipif(not SGLANG_AVAILABLE, reason="SGLang not available")
@pytest.mark.parametrize("batch_size,num_groups", [(64, 4), (128, 8), (256, 16)])
@pytest.mark.parametrize("hidden_size", [768, 2048, 4096])
@pytest.mark.parametrize("intermediate_size", [2048, 4096, 8192])
def test_large_dimensions(
    batch_size, num_groups, hidden_size, intermediate_size, device
):
    """Compare both grouped GEMM backends on larger problem sizes.

    Uses equal-sized groups and seeded random bfloat16 data; the FBGEMM and
    SGLang outputs must match within ``rtol=1e-3`` / ``atol=1e-3``.
    """
    torch.manual_seed(42)

    sizes = create_uniform_groups(batch_size, num_groups, device)

    x_shape = (batch_size, hidden_size)
    w_shape = (num_groups * intermediate_size, hidden_size)
    x_bf16 = torch.randn(*x_shape, dtype=torch.bfloat16, device=device)
    w_bf16 = torch.randn(*w_shape, dtype=torch.bfloat16, device=device)

    out_fbgemm = fbgemm_grouped_gemm(x_bf16, w_bf16, sizes, use_fast_accum=True)

    c_buf, indptr, w_idx, w_grouped = create_sglang_inputs(
        x_bf16, w_bf16, sizes, num_groups, intermediate_size, device
    )
    out_sglang = sglang_grouped_gemm(
        x_bf16,
        w_grouped,
        c_buf,
        num_groups,
        weight_column_major=True,
        seg_indptr=indptr,
        weight_indices=w_idx,
        c_dtype=c_buf.dtype,
    )

    outputs_match = torch.allclose(out_fbgemm, out_sglang, rtol=1e-3, atol=1e-3)
    assert outputs_match


@pytest.mark.skipif(not FBGEMM_AVAILABLE, reason="FBGEMM not available")
@pytest.mark.parametrize("batch_size", [32, 64])
@pytest.mark.parametrize("num_groups", [2, 4])
@pytest.mark.parametrize("hidden_size", [512, 1024])
@pytest.mark.parametrize("intermediate_size", [1024, 2048])
def test_fp8_uniform_groups(
    batch_size, num_groups, hidden_size, intermediate_size, device
):
    """Smoke-test the FBGEMM FP8 rowwise grouped GEMM with equal-sized groups.

    Only the output's shape and dtype are checked. If the kernel call itself
    raises, the test is skipped because FP8 may be unsupported in the current
    environment.
    """
    if batch_size % num_groups != 0:
        pytest.skip(f"batch_size {batch_size} not divisible by num_groups {num_groups}")

    torch.manual_seed(42)

    m_sizes = create_uniform_groups(batch_size, num_groups, device)
    x_fp8, w_fp8, x_scale, w_scale = create_fp8_data(
        batch_size, num_groups, hidden_size, intermediate_size, device
    )

    # Only the kernel call is guarded. The assertions stay outside the try
    # block: AssertionError is an Exception, so asserting inside it would turn
    # every real failure into a skip.
    # NOTE(review): the broad `except Exception` is kept as a deliberate
    # best-effort, but it also hides genuine kernel errors.
    try:
        result_fp8 = fbgemm_grouped_gemm_fp8_rowwise(
            x_fp8, w_fp8, m_sizes, x_scale, w_scale, use_fast_accum=True
        )
    except Exception as e:
        pytest.skip(f"FP8 test failed (possibly unsupported): {e}")

    assert result_fp8.shape == (batch_size, intermediate_size)
    assert result_fp8.dtype == torch.bfloat16


@pytest.mark.skipif(not FBGEMM_AVAILABLE, reason="FBGEMM not available")
@pytest.mark.parametrize("batch_size", [63, 100])
@pytest.mark.parametrize("num_groups", [3, 5])
@pytest.mark.parametrize("hidden_size", [512, 1024])
@pytest.mark.parametrize("intermediate_size", [1024, 2048])
def test_fp8_non_uniform_groups(
    batch_size, num_groups, hidden_size, intermediate_size, device
):
    """Smoke-test the FBGEMM FP8 rowwise grouped GEMM with random group sizes.

    Only the output's shape and dtype are checked. If the kernel call itself
    raises, the test is skipped because FP8 may be unsupported in the current
    environment.
    """
    torch.manual_seed(42)

    m_sizes = create_non_uniform_groups(batch_size, num_groups, device)
    x_fp8, w_fp8, x_scale, w_scale = create_fp8_data(
        batch_size, num_groups, hidden_size, intermediate_size, device
    )

    # Only the kernel call is guarded. The assertions stay outside the try
    # block: AssertionError is an Exception, so asserting inside it would turn
    # every real failure into a skip.
    # NOTE(review): the broad `except Exception` is kept as a deliberate
    # best-effort, but it also hides genuine kernel errors.
    try:
        result_fp8 = fbgemm_grouped_gemm_fp8_rowwise(
            x_fp8, w_fp8, m_sizes, x_scale, w_scale, use_fast_accum=True
        )
    except Exception as e:
        pytest.skip(f"FP8 test failed (possibly unsupported): {e}")

    assert result_fp8.shape == (batch_size, intermediate_size)
    assert result_fp8.dtype == torch.bfloat16


@pytest.mark.skipif(not FBGEMM_AVAILABLE, reason="FBGEMM not available")
def test_fbgemm_only_uniform(device):
    """Run FBGEMM grouped GEMM alone and check output shape and dtype."""
    torch.manual_seed(42)

    batch_size = 64
    num_groups = 4
    hidden_size = 512
    intermediate_size = 1024

    group_sizes = create_uniform_groups(batch_size, num_groups, device)

    bf16_opts = {"dtype": torch.bfloat16, "device": device}
    activations = torch.randn(batch_size, hidden_size, **bf16_opts)
    weights = torch.randn(num_groups * intermediate_size, hidden_size, **bf16_opts)

    output = fbgemm_grouped_gemm(
        activations, weights, group_sizes, use_fast_accum=True
    )

    expected_shape = (batch_size, intermediate_size)
    assert output.shape == expected_shape
    assert output.dtype == torch.bfloat16


@pytest.mark.skipif(not SGLANG_AVAILABLE, reason="SGLang not available")
def test_sglang_only_uniform(device):
    """Run SGLang grouped GEMM alone and check output shape and dtype."""
    torch.manual_seed(42)

    batch_size = 64
    num_groups = 4
    hidden_size = 512
    intermediate_size = 1024

    group_sizes = create_uniform_groups(batch_size, num_groups, device)

    bf16_opts = {"dtype": torch.bfloat16, "device": device}
    activations = torch.randn(batch_size, hidden_size, **bf16_opts)
    weights = torch.randn(num_groups * intermediate_size, hidden_size, **bf16_opts)

    out_buffer, seg_indptr, weight_indices, grouped_weights = create_sglang_inputs(
        activations, weights, group_sizes, num_groups, intermediate_size, device
    )

    call_kwargs = {
        "weight_column_major": True,
        "seg_indptr": seg_indptr,
        "weight_indices": weight_indices,
        "c_dtype": out_buffer.dtype,
    }
    output = sglang_grouped_gemm(
        activations, grouped_weights, out_buffer, num_groups, **call_kwargs
    )

    expected_shape = (batch_size, intermediate_size)
    assert output.shape == expected_shape
    assert output.dtype == torch.bfloat16


def test_imports():
    """Fail when neither grouped GEMM backend could be imported."""
    backends_found = (FBGEMM_AVAILABLE, SGLANG_AVAILABLE)
    assert any(backends_found), "Neither FBGEMM nor SGLang is available"


if __name__ == "__main__":
    # pytest.main returns the session's exit code. Passing it to sys.exit
    # makes a direct run of this file exit non-zero when tests fail, instead
    # of always exiting with 0.
    sys.exit(pytest.main([__file__, "-v"]))