# test_bench2.py
import torch
from lightx2v_kernel.gemm import scaled_fp4_quant, cutlass_scaled_fp4_mm
import time


class MMWeightFp4:
    """Matmul wrapper that runs both operands through the FP4 kernels.

    The weight is quantized once at construction with ``scaled_fp4_quant``.
    Every ``apply`` call quantizes the activation with the same function and
    multiplies the two quantized operands with ``cutlass_scaled_fp4_mm``.

    Attributes set during construction:
        weight, weight_scale: outputs of ``scaled_fp4_quant`` for the weight.
        weight_global_scale: float32 scalar, 2688 / max(|weight|).
        bias: stored as given and forwarded to the GEMM unchanged.
        x_absmax: float32 scalar used as the activation absolute maximum.
        input_global_scale: float32 scalar, 2688 / x_absmax.
        alpha: 1 / (input_global_scale * weight_global_scale), forwarded to
            the GEMM as its ``alpha`` argument.
    """

    # Numerator shared by the weight and activation global scales.
    # NOTE(review): presumably 448 * 6 (FP8-E4M3 max times FP4-E2M1 max);
    # confirm against the kernel documentation.
    GLOBAL_SCALE_NUMERATOR = 2688.0

    def __init__(self, weight, bias):
        """Quantize ``weight`` and set up the activation scaling.

        Args:
            weight: high-precision weight tensor to quantize.
            bias: optional bias forwarded to the GEMM; may be ``None``.
        """
        self.load_fp4_weight(weight, bias)
        self.act_quant_func = self.act_quant_fp4
        # Uses the default placeholder x_absmax until real calibration exists.
        self.calibrate_x_absmax()

    @torch.no_grad()
    def apply(self, input_tensor):
        """Quantize ``input_tensor`` and return the FP4 GEMM result."""
        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
        return cutlass_scaled_fp4_mm(
            input_tensor_quant,
            self.weight,
            input_tensor_scale,
            self.weight_scale,
            alpha=self.alpha,
            bias=self.bias,
        )

    @torch.no_grad()
    def load_fp4_weight(self, weight, bias):
        """Compute the weight global scale, quantize the weight, keep the bias."""
        weight_absmax = torch.max(torch.abs(weight))
        self.weight_global_scale = (self.GLOBAL_SCALE_NUMERATOR / weight_absmax).to(torch.float32)
        self.weight, self.weight_scale = scaled_fp4_quant(weight, self.weight_global_scale)
        self.bias = bias

    def calibrate_x_absmax(self, x_absmax=5.0):
        """Derive the activation global scale and GEMM ``alpha`` from ``x_absmax``.

        Must run after ``load_fp4_weight``: it reads ``self.weight.device`` and
        ``self.weight_global_scale``.

        Args:
            x_absmax: activation absolute maximum (number or scalar tensor).
                The default 5.0 is a placeholder; it needs to be calibrated
                on real activations.
        """
        self.x_absmax = torch.as_tensor(x_absmax, dtype=torch.float32, device=self.weight.device)
        self.input_global_scale = (self.GLOBAL_SCALE_NUMERATOR / self.x_absmax).to(torch.float32)
        self.alpha = 1.0 / (self.input_global_scale * self.weight_global_scale)

    @torch.no_grad()
    def act_quant_fp4(self, x):
        """Quantize activation ``x`` with the calibrated input global scale."""
        return scaled_fp4_quant(x, self.input_global_scale)


def test_speed(m, k, n):
    """Time the FP4 matmul against a bf16 ``torch.nn.Linear`` on the same problem.

    Both paths multiply an ``(m, k)`` input by an ``(n, k)`` weight, so the
    printed speedup compares like with like. The FP4 timing covers the whole
    ``MMWeightFp4.apply`` call, i.e. activation quantization plus the GEMM.

    Args:
        m: number of input rows.
        k: reduction dimension (``in_features`` of the baseline layer).
        n: number of output columns (``out_features`` of the baseline layer).

    Prints the mean per-call time of each path and their ratio.
    Requires a CUDA device.
    """
    num_iters = 100

    with torch.no_grad():
        input_tensor = torch.randn(m, k, dtype=torch.bfloat16).cuda()
        weight = torch.randn(n, k, dtype=torch.bfloat16, device="cuda")
        # Bias is left out of both paths so only the matmul is compared.
        bias = None

        mm = MMWeightFp4(weight, bias)

        # warmup
        mm.apply(input_tensor)

        # Kernel launches are asynchronous: synchronize before reading the
        # clock on both sides of the timed loop.
        torch.cuda.synchronize()
        start_time = time.perf_counter()
        for _ in range(num_iters):
            mm.apply(input_tensor)
        torch.cuda.synchronize()
        end_time = time.perf_counter()

        lightx2v_kernel_time = (end_time - start_time) / num_iters
        print(f"lightx2v-kernel time: {lightx2v_kernel_time}")

        # Baseline reuses the same input and weight. nn.Linear(k, n) stores
        # its weight as (n, k), which is exactly the shape of `weight`.
        linear = torch.nn.Linear(k, n, bias=False).cuda()
        linear.weight.data = weight

        # warmup
        linear(input_tensor)

        torch.cuda.synchronize()
        start_time = time.perf_counter()
        for _ in range(num_iters):
            linear(input_tensor)
        torch.cuda.synchronize()
        end_time = time.perf_counter()

        ref_time = (end_time - start_time) / num_iters
        print(f"ref time: {ref_time}")

        print(f"speedup: {ref_time / lightx2v_kernel_time:.3f}")


def test_accuracy(m, k, n):
    """Compare the FP4 matmul output with a bf16 ``torch.nn.Linear`` reference.

    Builds a random ``(m, k)`` input and ``(n, k)`` weight, runs both paths
    on them and prints the cosine similarity of the flattened outputs.
    Nothing is asserted; the value is only reported.

    Args:
        m: number of input rows.
        k: reduction dimension (``in_features`` of the reference layer).
        n: number of output columns (``out_features`` of the reference layer).

    Requires a CUDA device.
    """
    with torch.no_grad():
        input_tensor = torch.randn(m, k, dtype=torch.bfloat16).cuda()
        weight = torch.randn(n, k, dtype=torch.bfloat16, device="cuda")
        # Bias is left out of both paths so only the matmul is compared.
        bias = None

        # Reference: nn.Linear(k, n) stores its weight as (n, k), matching `weight`.
        linear = torch.nn.Linear(k, n, bias=False).cuda()
        linear.weight.data = weight

        ref_output_tensor = linear(input_tensor)

        mm = MMWeightFp4(weight, bias)

        output_tensor = mm.apply(input_tensor)

        # cosine
        cos = torch.nn.functional.cosine_similarity(ref_output_tensor.flatten(), output_tensor.flatten(), dim=0)
        print(f"cos : {cos}")


if __name__ == "__main__":
    # Each entry is an (m, k, n) problem size passed straight to
    # test_accuracy and test_speed.
    test_sizes = [
        (32130, 5120, 5120),
        (512, 5120, 5120),
        (257, 5120, 5120),
        (32130, 5120, 13824),
        (32130, 13824, 5120),
        (75348, 5120, 5120),
        (75348, 13824, 5120),
        (32760, 1536, 1536),
        (512, 1536, 1536),
        (32760, 1536, 8960),
        (32760, 8960, 1536),
    ]

    for i, (m, k, n) in enumerate(test_sizes):
        print("-" * 30)
        # Banner reads "Test <index>: tensor size (m, k, n)".
        print(f"测试 {i + 1}: 张量大小 ({m}, {k}, {n})")
        test_accuracy(m, k, n)
        test_speed(m, k, n)