import unittest
import torch
from lightx2v_kernel.gemm import cutlass_scaled_mxfp8_mm, scaled_mxfp8_quant
from torch.nn.functional import linear
from lightx2v_kernel.utils import error, benchmark


class TestQuantBF162MXFP8(unittest.TestCase):
    def setUp(self):
        self.tokens = [257, 512, 1024, 13325, 32130, 32760]  # , 75348
        self.channels = [1536, 5120, 8960]  # , 13824
        self.hiddenDims = [1536, 3072, 5120, 8960, 12800]  # , 13824

        self.device = "cuda"
        self.dtype = torch.bfloat16

    def test_accuracy(self):
        """Test the accuracy of quantization from BF16 to MXFP8."""
        for m in self.tokens:
            for k in self.hiddenDims:
                for n in self.channels:
                    with self.subTest(shape=[m, k, n]):
                        activation = torch.randn(m, k, dtype=self.dtype, device=self.device)
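                        # Quantize the BF16 activation to MXFP8; scaled_mxfp8_quant returns the quantized tensor and its block scale factors.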
                        activation_quant_pred, activation_scale_pred = scaled_mxfp8_quant(activation)

                        weight = torch.randn(n, k, dtype=self.dtype, device=self.device)
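                        # Quantize the weight the same way; the (n, k) layout matches torch.nn.functional.linear.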
                        weight_quant_pred, weight_scale_pred = scaled_mxfp8_quant(weight)

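                        # Random bias in [0, 10) so the bias path is exercised with non-trivial magnitudes.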
                        bias = torch.rand(1, n, dtype=self.dtype, device=self.device) * 10

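                        # Scalar output scale passed to the kernel; 1.0 leaves the product unchanged.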
                        alpha = torch.tensor(1.0, device=self.device, dtype=torch.float32)
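                        # Quantized GEMM under test: MXFP8 operands combined with their block scales, plus alpha and bias.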
                        mm_pred = cutlass_scaled_mxfp8_mm(activation_quant_pred, weight_quant_pred, activation_scale_pred, weight_scale_pred, alpha=alpha, bias=bias)

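                        # BF16 reference computed from the unquantized inputs.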
                        mm_real = linear(activation, weight, bias=bias).to(torch.bfloat16)

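                        # error() from lightx2v_kernel.utils measures the mismatch between the two outputs; 1e-2 is the accuracy budget here.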
                        quant_error = error(mm_pred, mm_real)
                        self.assertTrue(quant_error < 1e-2, f"Accuracy test failed for shape {m, k, n}: Error {quant_error} exceeds threshold.")

    def test_performance(self):
        """Benchmark the performance of Activation quantization from BF16 to MXFP8."""
        for m in self.tokens:
            for k in self.hiddenDims:
                with self.subTest(shape=[m, k]):
                    activation = torch.randn(m, k, dtype=self.dtype, device=self.device)
                    shape = [m, k]
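                    # Rough op count: ~2 ops per element, in units of 2**40 ("tera" with a binary prefix) for benchmark()'s throughput report.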
                    tflops = 2 * (m * k / 1024**4)
                    benchmark(scaled_mxfp8_quant, shape, tflops, 100, activation)


if __name__ == "__main__":
    unittest.main()