import torch
from lightx2v_kernel.gemm import cutlass_scaled_mxfp6_mxfp8_mm


def test_cutlass_scaled_mxfp6_mxfp8_mm_sm120():
    """Smoke-test the CUTLASS scaled mxfp6 x mxfp8 GEMM kernel on CUDA.

    Builds random byte tensors standing in for quantized activations and
    packed-fp6 weights plus per-block e8m0 scales, runs the fused scaled
    matmul, and prints the output tensor and its shape.
    """
    m, k, n = 1024, 2048, 4096

    # Activation payload is (m, k) bytes; the weight's reduction dimension is
    # k * 3 // 4 bytes — presumably 4 fp6 values packed into 3 bytes (TODO
    # confirm packing against the kernel's documentation).
    act_quant = (torch.rand(m, k, device="cuda") * 10).to(torch.uint8)
    packed_weight = (torch.rand(n, k * 3 // 4, device="cuda") * 10).to(torch.uint8)

    print(f"shape: {act_quant.shape}, {packed_weight.shape}")

    # One e8m0 scale per 32-element block along the (unpacked) k dimension,
    # for both operands.
    act_scale = torch.rand(m, k // 32, device="cuda").to(torch.float8_e8m0fnu)
    wgt_scale = torch.rand(n, k // 32, device="cuda").to(torch.float8_e8m0fnu)

    print(f"shape: {act_scale.shape}, {wgt_scale.shape}")

    # Scalar output multiplier; no bias term for this smoke test.
    alpha = torch.tensor(0.0002765655517578125, device="cuda", dtype=torch.float32)
    bias = None

    out = cutlass_scaled_mxfp6_mxfp8_mm(act_quant, packed_weight, act_scale, wgt_scale, alpha, bias)
    print(f"out: {out}, shape: {out.shape}")


if __name__ == "__main__":
    # Run the smoke test when executed as a script. Requires a CUDA device
    # and the lightx2v_kernel extension (name suggests SM120 hardware —
    # confirm against the kernel's supported-arch list).
    test_cutlass_scaled_mxfp6_mxfp8_mm_sm120()