test_quant_mem_utils.py 4.2 KB
Newer Older
helloyongyang's avatar
helloyongyang committed
1
2
3
4
5
6
import torch
from lightx2v_kernel.gemm import scaled_fp4_quant


# Global scale handed to scaled_fp4_quant for every benchmarked tensor; it is
# created on the current CUDA device at import time.
# NOTE(review): 808.0 is a magic constant whose derivation is not shown here — confirm.
input_global_scale = torch.tensor(808.0, device="cuda", dtype=torch.float32)

helloyongyang's avatar
fix ci  
helloyongyang committed
7
8

def quantize_fp4(x):
    """Quantize ``x`` with ``scaled_fp4_quant`` using the module-level global scale.

    This wrapper adds no logic of its own: it forwards ``x`` together with
    ``input_global_scale`` and returns whatever ``scaled_fp4_quant`` returns.
    """
    global_scale = input_global_scale
    return scaled_fp4_quant(x, global_scale)


def test_memory_bandwidth(func, x, num_warmup=10, num_runs=100):
    """Estimate the effective memory bandwidth of ``func(x)`` on the current CUDA device.

    Despite the ``test_`` prefix this is a benchmark helper, not a pytest test
    (pytest collection is disabled right after the definition).

    Args:
        func: Callable invoked as ``func(x)``; its return value is discarded.
        x: Input tensor; ``x.numel()`` and ``x.element_size()`` drive the byte
            accounting below.
        num_warmup: Number of untimed calls made before measuring.
        num_runs: Number of timed calls; must be >= 1.

    Returns:
        Estimated bandwidth in units of 1024**3 bytes per second. It is printed
        as "GB/s", but it is really GiB/s, about 7% lower than the same traffic
        expressed in decimal GB/s.

    Raises:
        ValueError: If ``num_runs`` < 1. The per-run average and the bandwidth
            would otherwise divide by zero.
    """
    if num_runs < 1:
        raise ValueError(f"num_runs must be >= 1, got {num_runs}")

    # Untimed warm-up calls, so one-time startup costs (if any) stay out of the measurement.
    for _ in range(num_warmup):
        func(x)

    # Wait for all queued GPU work so the timed region starts from an idle device.
    torch.cuda.synchronize()

    # CUDA events time the work enqueued on the current stream, not host wall-clock time.
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    start_event.record()
    for _ in range(num_runs):
        func(x)
    end_event.record()

    # elapsed_time() needs end_event to have completed on the device.
    torch.cuda.synchronize()
    elapsed_time_ms = start_event.elapsed_time(end_event)
    elapsed_time_s = elapsed_time_ms / 1000.0

    # Modelled (not measured) traffic of one FP4 quantization call.
    input_bytes = x.numel() * x.element_size()  # read the input once
    output_bytes = x.numel() * 0.5  # FP4: 4 bits = 0.5 byte per element
    # Assumes one 1-byte scale per 16-element group. Any padding or swizzled
    # layout of the real scale tensor is not counted.
    # NOTE(review): confirm against scaled_fp4_quant.
    scale_bytes = x.numel() / 16

    # Total traffic over all timed runs: read input + write output + write scales.
    total_bytes = (input_bytes + output_bytes + scale_bytes) * num_runs

    # The sizes below use 1024-based units (MiB / GiB) although labelled MB / GB.
    bandwidth_gbps = (total_bytes / elapsed_time_s) / (1024**3)

    print("测试结果:")
    print(f"  输入张量形状: {x.shape}")
    print(f"  输入数据类型: {x.dtype}")
    print(f"  运行次数: {num_runs}")
    print(f"  总执行时间: {elapsed_time_ms:.2f} ms")
    print(f"  平均每次执行时间: {elapsed_time_ms / num_runs:.4f} ms")
    print(f"  输入数据大小: {input_bytes / (1024**2):.2f} MB")
    print(f"  输出数据大小: {output_bytes / (1024**2):.2f} MB")
    print(f"  总数据传输量: {total_bytes / (1024**3):.2f} GB")
    print(f"  显存带宽: {bandwidth_gbps:.2f} GB/s")

    return bandwidth_gbps


# The name matches pytest's default ``test_*`` pattern. Without this flag,
# pytest would collect it from this ``test_*.py`` file and fail on the missing
# ``func`` / ``x`` fixtures.
test_memory_bandwidth.__test__ = False


if __name__ == "__main__":
    # Shapes (rows, cols) benchmarked by default.
    test_sizes = [
        (32130, 5120),
        (512, 5120),
        (257, 5120),
        (32130, 13824),
        (75348, 5120),
        (75348, 13824),
        (32760, 1536),
        (512, 1536),
        (32760, 8960),
    ]
    # Broader synthetic sweep (previously a long commented-out list). Swap it in
    # to use it:
    # test_sizes = [
    #     (h, 1 << k)
    #     for h in (1, 2, 4, 128, 512, 1024, 2048, 4096, 8192, 16384, 32768)
    #     for k in range(10, 16)  # widths 1024 .. 32768
    # ]

    print("=== quantize_fp4 显存带宽测试 ===\n")

    x = None
    for i, (h, w) in enumerate(test_sizes):
        print(f"测试 {i + 1}: 张量大小 ({h}, {w})")
        print("-" * 50)

        try:
            # Generate directly on the GPU. This avoids building a multi-GB
            # tensor in host memory and a host-to-device copy for the largest
            # shapes. The specific random values should not affect the traffic
            # modelled here.
            # Allocation is inside the try so one OOM does not abort the sweep.
            x = torch.randn(h, w, dtype=torch.bfloat16, device="cuda")
            bandwidth = test_memory_bandwidth(quantize_fp4, x)
            print(f"✓ 成功完成测试,带宽: {bandwidth:.2f} GB/s\n")
        except Exception as e:
            # Deliberate best-effort sweep: report the failure and move on.
            print(f"✗ 测试失败: {e}\n")
        finally:
            # Release this input before the next (possibly larger) allocation.
            x = None

    print("=== 测试完成 ===")