test_mxfp8_quant.py 2.28 KB
Newer Older
xuwx1's avatar
xuwx1 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import unittest
import torch
from lightx2v_kernel.gemm import cutlass_scaled_mxfp8_mm
from lightx2v_kernel.gemm import scaled_mxfp8_quant
from torch.nn.functional import linear
from lightx2v_kernel.utils import error, benchmark


@unittest.skipUnless(torch.cuda.is_available(), "CUDA device required for the MXFP8 kernels")
class TestQuantBF162MXFP8(unittest.TestCase):
    """Accuracy and performance tests for ``scaled_mxfp8_quant`` on BF16 inputs.

    Every tensor is created on ``self.device`` ("cuda"), so the whole case is
    skipped when no CUDA device is available instead of erroring in each
    sub-test.
    """

    # Upper bound (exclusive) for the value returned by ``error`` in test_accuracy.
    ERROR_THRESHOLD = 1e-2

    def setUp(self):
        """Define the shape grid, device and dtype shared by all tests."""
        self.tokens = [257, 512, 1024, 13325, 32130, 32760]  # , 75348
        self.channels = [1536, 5120, 8960]  # , 13824
        self.hiddenDims = [1536, 3072, 5120, 8960, 12800]  # , 13824

        self.device = "cuda"
        self.dtype = torch.bfloat16

    def test_accuracy(self):
        """Check the MXFP8 quantize + GEMM path against a BF16 ``linear`` reference.

        For each (m, k, n) drawn from ``tokens`` x ``hiddenDims`` x ``channels``,
        a random (m, k) activation and a random (n, k) weight are quantized with
        ``scaled_mxfp8_quant`` and multiplied by ``cutlass_scaled_mxfp8_mm`` with
        a (1, n) bias. The result is compared with ``linear`` on the unquantized
        tensors, and ``error(pred, real)`` must stay below ``ERROR_THRESHOLD``.
        """
        # Loop-invariant scale factor handed to the GEMM as ``alpha``; built once.
        alpha = torch.tensor(1.0, device=self.device, dtype=torch.float32)

        for m in self.tokens:
            for k in self.hiddenDims:
                for n in self.channels:
                    with self.subTest(shape=[m, k, n]):
                        activation = torch.randn(m, k, dtype=self.dtype, device=self.device)
                        activation_quant_pred, activation_scale_pred = scaled_mxfp8_quant(activation)

                        weight = torch.randn(n, k, dtype=self.dtype, device=self.device)
                        weight_quant_pred, weight_scale_pred = scaled_mxfp8_quant(weight)

                        bias = torch.rand(1, n, dtype=self.dtype, device=self.device) * 10

                        mm_pred = cutlass_scaled_mxfp8_mm(
                            activation_quant_pred,
                            weight_quant_pred,
                            activation_scale_pred,
                            weight_scale_pred,
                            alpha=alpha,
                            bias=bias,
                        )

                        mm_real = linear(activation, weight, bias=bias).to(torch.bfloat16)

                        # Compute the metric once: the f-string message is built eagerly,
                        # so calling error() inside it would repeat the work for every shape.
                        err = error(mm_pred, mm_real)
                        self.assertLess(
                            err,
                            self.ERROR_THRESHOLD,
                            f"Accuracy test failed for shape {m, k, n}: Error {err} exceeds threshold.",
                        )

    def test_performance(self):
        """Benchmark ``scaled_mxfp8_quant`` on random BF16 activations.

        One ``benchmark`` call is made per (m, k) in ``tokens`` x ``hiddenDims``.
        """
        for m in self.tokens:
            for k in self.hiddenDims:
                with self.subTest(shape=[m, k]):
                    # Named ``activation`` rather than ``input`` to avoid shadowing the builtin.
                    activation = torch.randn(m, k, dtype=self.dtype, device=self.device)
                    shape = [m, k]
                    # Nominal work figure forwarded to ``benchmark``: 2 * m * k scaled by 1024**4.
                    tflops = 2 * (m * k / 1024**4)
                    # NOTE(review): 100 is presumably the iteration count -- confirm against
                    # lightx2v_kernel.utils.benchmark.
                    benchmark(scaled_mxfp8_quant, shape, tflops, 100, activation)


# When this file is executed directly (not imported), hand control to
# unittest's command-line runner.
if __name__ == "__main__":
    unittest.main()