"""
Triton FP8支持测试Demo
测试环境要求:
- Triton >= 2.1.0
- CUDA >= 11.8
- GPU计算能力 >= 8.9 (H100, L40S, etc.)
"""

import torch
import triton
import triton.language as tl
from typing import Tuple


# Report versions and hardware capabilities
print(f"Triton version: {triton.__version__}")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Compute Capability: {torch.cuda.get_device_capability(0)}")
    
    # Check whether the compute capability supports FP8
    cc = torch.cuda.get_device_capability(0)
    fp8_supported = cc[0] * 10 + cc[1] >= 89  # 8.9+
    print(f"FP8 hardware support: {fp8_supported}")


@triton.jit
def add_kernel_fp8(
    x_ptr, y_ptr, output_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    """使用FP8的向量加法kernel"""
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    
    # Load the FP8 data and upcast to FP32 for the computation
    x = tl.load(x_ptr + offsets, mask=mask).to(tl.float32)
    y = tl.load(y_ptr + offsets, mask=mask).to(tl.float32)
    
    # Compute in FP32
    output = x + y
    
    # Downcast back to FP8 and store
    # Note: pick the FP8 format (and a suitable scale factor) for your use case;
    # tl.float8e5 is E5M2, and newer Triton versions expose tl.float8e4nv for E4M3.
    output_fp8 = output.to(tl.float8e5)
    tl.store(output_ptr + offsets, output_fp8, mask=mask)
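
# Launch sketch for add_kernel_fp8 (hypothetical host-side code; assumes x and y
# are torch.float8_e5m2 CUDA tensors of length n):
#   out = torch.empty_like(x)
#   grid = (triton.cdiv(n, 1024),)
#   add_kernel_fp8[grid](x, y, out, n, BLOCK_SIZE=1024)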


@triton.jit
def matmul_kernel_fp8(
    a_ptr, b_ptr, c_ptr,
    M, N, K,
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
    """使用FP8的矩阵乘法kernel"""
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    
    # Compute this program's tile offsets
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)
    
    # Build boundary masks
    m_mask = offs_m[:, None] < M
    n_mask = offs_n[None, :] < N
    
    # Initialize the accumulator in FP32
    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    
    for k in range(0, K, BLOCK_K):
        # Load the FP8 tiles for this K-slice
        a_ptrs = a_ptr + (offs_m[:, None] * stride_am + (offs_k[None, :] + k) * stride_ak)
        b_ptrs = b_ptr + ((offs_k[:, None] + k) * stride_bk + offs_n[None, :] * stride_bn)

        # offs_k[None, :] and offs_k[:, None] are already 2D, so combine them
        # with the row/column masks directly; masked-off elements load as 0 so
        # they do not pollute the dot product
        a = tl.load(a_ptrs, mask=m_mask & ((offs_k[None, :] + k) < K), other=0.0)
        b = tl.load(b_ptrs, mask=((offs_k[:, None] + k) < K) & n_mask, other=0.0)

        # Upcast to FP32 for the computation
        a_fp32 = a.to(tl.float32)
        b_fp32 = b.to(tl.float32)

        # Multiply-accumulate
        accumulator += tl.dot(a_fp32, b_fp32)
    
    # Downcast the result to FP8
    c = accumulator.to(tl.float8e5)
    
    # Store the result tile
    c_ptrs = c_ptr + (offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn)
    tl.store(c_ptrs, c, mask=m_mask & n_mask)
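
# Note: the accumulator stays in FP32 because E5M2 keeps only 2 mantissa bits
# and E4M3 only 3, so summing many products directly in FP8 would lose most of
# the precision. Depending on the Triton version and GPU, tl.dot may also accept
# FP8 operands directly, letting Hopper-class hardware use native FP8 tensor cores.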


def create_fp8_tensor(data: torch.Tensor, fp8_type: str = 'e5m2') -> torch.Tensor:
    """Quantize a tensor to FP8.

    Note: the scale factor is discarded here, so the original magnitudes cannot
    be recovered; use create_fp8_tensor_with_scaling when they are needed.
    """
    if fp8_type == 'e5m2':
        # Scale into the representable range first; E5M2's largest finite value is 57344
        max_val = data.abs().max()
        scale = 57344.0 / max_val if max_val > 0 else 1.0
        scaled_data = data * scale
        return scaled_data.to(torch.float8_e5m2)
    elif fp8_type == 'e4m3':
        # E4M3FN's largest finite value is 448
        max_val = data.abs().max()
        scale = 448.0 / max_val if max_val > 0 else 1.0
        scaled_data = data * scale
        return scaled_data.to(torch.float8_e4m3fn)
    else:
        raise ValueError(f"Unsupported FP8 type: {fp8_type}")


def create_fp8_tensor_with_scaling(data: torch.Tensor, fp8_type: str = 'e5m2') -> Tuple[torch.Tensor, float]:
    """Quantize a tensor to FP8 and return it together with its scale factor."""
    if fp8_type == 'e5m2':
        fp8_max = 57344.0  # largest finite E5M2 value
        fp8_dtype = torch.float8_e5m2
    elif fp8_type == 'e4m3':
        fp8_max = 448.0  # largest finite E4M3FN value
        fp8_dtype = torch.float8_e4m3fn
    else:
        raise ValueError(f"Unsupported FP8 type: {fp8_type}")

    # Return the scale as a Python float: Triton treats tensor arguments as
    # pointers, so a 0-dim tensor scale would break the kernels below
    data_max = data.abs().max().item()
    scale = fp8_max / data_max if data_max > 0 else 1.0
    # Clamp to keep every value inside the FP8 range before the conversion
    scaled_data = torch.clamp(data * scale, -fp8_max, fp8_max)
    return scaled_data.to(fp8_dtype), scale
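
# Usage sketch for create_fp8_tensor_with_scaling (hypothetical tensor t):
# quantize, then divide by the returned scale to dequantize.
#   t = torch.randn(16, device='cuda')
#   t_fp8, s = create_fp8_tensor_with_scaling(t, 'e4m3')
#   t_approx = t_fp8.float() / s  # close to t, up to FP8 quantization error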

@triton.jit
def add_kernel_fp8_workaround(
    x_ptr, y_ptr, output_ptr,
    scale_x, scale_y, scale_out,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    """绕过类型问题的FP8加法"""
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    
    # Load and upcast to FP32 immediately, then apply the scale factor;
    # dividing the raw FP8 values directly would hit type-promotion problems
    x = tl.load(x_ptr + offsets, mask=mask).to(tl.float32) / scale_x
    y = tl.load(y_ptr + offsets, mask=mask).to(tl.float32) / scale_y
    
    # Compute in FP32
    output = x + y
    
    # Rescale and store as FP8
    output_scaled = output * scale_out
    tl.store(output_ptr + offsets, output_scaled.to(tl.float8e5), mask=mask)
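
# The dequantize/requantize scheme above computes, in FP32,
#   out = (x_fp8 / scale_x + y_fp8 / scale_y) * scale_out
# so the stored FP8 value approximates (x + y) * scale_out.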

def test_fp8_vector_addition_fixed():
    """FP8 vector addition test with scale factors (fixed version)."""
    print("\n" + "="*50)
    print("Test 2: FP8 vector addition with scale factors")
    print("="*50)
    
    n_elements = 1024
    
    # Prepare the input data
    torch.manual_seed(42)
    x = torch.randn(n_elements, device='cuda', dtype=torch.float32) * 2
    y = torch.randn(n_elements, device='cuda', dtype=torch.float32) * 2
    
    # Create scaled FP8 tensors
    x_fp8, scale_x = create_fp8_tensor_with_scaling(x, 'e5m2')
    y_fp8, scale_y = create_fp8_tensor_with_scaling(y, 'e5m2')
    output_fp8 = torch.empty_like(x_fp8)
    
    # Choose the output scale factor from a bound on |x + y|, with 1.5x headroom
    expected_max = (x.abs().max() + y.abs().max()).item() * 1.5
    scale_out = 57344.0 / expected_max if expected_max > 0 else 1.0
    
    print(f"缩放因子: scale_x={scale_x:.4f}, scale_y={scale_y:.4f}, scale_out={scale_out:.4f}")
    print(f"数据范围: x=[{x.min():.3f}, {x.max():.3f}], y=[{y.min():.3f}, {y.max():.3f}]")
    
    # Kernel launch configuration
    BLOCK_SIZE = 256
    grid = (triton.cdiv(n_elements, BLOCK_SIZE),)
    
    # Launch the kernel; note that the scale factors are passed as Python scalars
    add_kernel_fp8_workaround[grid](
        x_fp8, y_fp8, output_fp8,
        scale_x, scale_y, scale_out,
        n_elements, BLOCK_SIZE
    )
    
    # Dequantize the output
    output = output_fp8.float() / scale_out
    expected = x + y
    
    # Compute the quantization error
    abs_error = (output - expected).abs()
    rel_error = abs_error / (expected.abs() + 1e-8)
    
    print(f"\n前10个结果对比:")
    print(f"预期: {expected[:10].cpu()}")
    print(f"实际: {output[:10].cpu()}")
    print(f"差异: {(output - expected)[:10].cpu()}")
    
    print(f"\n误差统计:")
    print(f"平均绝对误差: {abs_error.mean():.6f}")
    print(f"最大绝对误差: {abs_error.max():.6f}")
    print(f"平均相对误差: {rel_error.mean():.6f}")
    print(f"最大相对误差: {rel_error.max():.6f}")
    
    # Sanity check on the accuracy
    if rel_error.mean() < 0.1:
        print("✓ FP8 accuracy is within the acceptable range")
        return True
    else:
        print("⚠️ Significant FP8 accuracy loss")
        return False


def test_fp8_matrix_multiplication():
    """Test FP8 matrix multiplication."""
    print("\n" + "="*50)
    print("Test 3: FP8 matrix multiplication")
    print("="*50)
    
    try:
        # Matrix dimensions
        M, N, K = 512, 512, 256
        
        # Create the input matrices
        a = torch.randn((M, K), device='cuda', dtype=torch.float32)
        b = torch.randn((K, N), device='cuda', dtype=torch.float32)
        
        # Quantize to FP8 (the scale factors are discarded by create_fp8_tensor)
        a_fp8 = create_fp8_tensor(a, 'e5m2')
        b_fp8 = create_fp8_tensor(b, 'e5m2')
        c_fp8 = torch.empty((M, N), device='cuda', dtype=torch.float8_e5m2)
        
        # Kernel configuration
        BLOCK_M, BLOCK_N, BLOCK_K = 64, 64, 64
        grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N))
        
        # Launch the matmul kernel
        matmul_kernel_fp8[grid](
            a_fp8, b_fp8, c_fp8,
            M, N, K,
            a_fp8.stride(0), a_fp8.stride(1),
            b_fp8.stride(0), b_fp8.stride(1),
            c_fp8.stride(0), c_fp8.stride(1),
            BLOCK_M, BLOCK_N, BLOCK_K
        )
        
        # Upcast the FP8 output back to FP32
        c_output = c_fp8.float()
        
        # Reference result from PyTorch's FP32 matmul
        expected = torch.mm(a, b)
        
        # Normalized error. Note: since create_fp8_tensor discards the input
        # scale factors and the output is clamped to the FP8 range, a large
        # relative error is expected; this test mainly verifies execution
        norm_expected = torch.norm(expected)
        rel_error = torch.norm(c_output - expected) / norm_expected
        
        print(f"✓ FP8矩阵乘法成功执行")
        print(f"  矩阵维度: {M}x{K} * {K}x{N} = {M}x{N}")
        print(f"  相对误差: {rel_error:.6f}")
        
        # 显示部分结果
        print(f"  预期结果[0:5,0:5]: {expected[0,0]:.4f} {expected[0,1]:.4f} ...")
        print(f"  实际结果[0:5,0:5]: {c_output[0,0]:.4f} {c_output[0,1]:.4f} ...")
        
        return True
        
    except Exception as e:
        print(f"✗ FP8矩阵乘法失败: {e}")
        return False


def test_fp8_dtype_support():
    """Test FP8 dtype availability."""
    print("\n" + "="*50)
    print("Test 1: FP8 dtype support")
    print("="*50)
    
    # Check PyTorch's FP8 dtypes
    fp8_types = {
        'float8_e5m2': hasattr(torch, 'float8_e5m2'),
        'float8_e4m3fn': hasattr(torch, 'float8_e4m3fn'),
    }
    
    for dtype_name, supported in fp8_types.items():
        status = "✓" if supported else "✗"
        print(f"{status} torch.{dtype_name}: {supported}")
    
    # Check Triton's FP8 dtypes (the names vary across Triton versions)
    triton_fp8_types = {
        'float8e5': hasattr(tl, 'float8e5'),
        'float8e4': hasattr(tl, 'float8e4'),
        'float8e4nv': hasattr(tl, 'float8e4nv'),  # E4M3 name in newer Triton
    }
    
    for dtype_name, supported in triton_fp8_types.items():
        status = "✓" if supported else "✗"
        print(f"{status} tl.{dtype_name}: {supported}")
    
    return any(fp8_types.values()) and any(triton_fp8_types.values())


def performance_comparison():
    """Performance comparison: FP32 vs FP8."""
    print("\n" + "="*50)
    print("Test 4: Performance comparison (FP32 vs FP8)")
    print("="*50)
    
    try:
        import time
        
        M, N, K = 1024, 1024, 512
        
        # FP32 matmul (cuBLAS via torch.mm)
        a_fp32 = torch.randn((M, K), device='cuda', dtype=torch.float32)
        b_fp32 = torch.randn((K, N), device='cuda', dtype=torch.float32)
        
        # Warm-up
        for _ in range(10):
            torch.mm(a_fp32, b_fp32)
        
        torch.cuda.synchronize()
        start = time.time()
        for _ in range(100):
            torch.mm(a_fp32, b_fp32)
        torch.cuda.synchronize()
        fp32_time = (time.time() - start) / 100
        
        # FP8 matmul (the Triton kernel above)
        a_fp8 = create_fp8_tensor(a_fp32, 'e5m2')
        b_fp8 = create_fp8_tensor(b_fp32, 'e5m2')
        c_fp8 = torch.empty((M, N), device='cuda', dtype=torch.float8_e5m2)
        
        BLOCK_M, BLOCK_N, BLOCK_K = 64, 64, 64
        grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N))
        
        # Warm-up
        for _ in range(10):
            matmul_kernel_fp8[grid](
                a_fp8, b_fp8, c_fp8,
                M, N, K,
                a_fp8.stride(0), a_fp8.stride(1),
                b_fp8.stride(0), b_fp8.stride(1),
                c_fp8.stride(0), c_fp8.stride(1),
                BLOCK_M, BLOCK_N, BLOCK_K
            )
        
        torch.cuda.synchronize()
        start = time.time()
        for _ in range(100):
            matmul_kernel_fp8[grid](
                a_fp8, b_fp8, c_fp8,
                M, N, K,
                a_fp8.stride(0), a_fp8.stride(1),
                b_fp8.stride(0), b_fp8.stride(1),
                c_fp8.stride(0), c_fp8.stride(1),
                BLOCK_M, BLOCK_N, BLOCK_K
            )
        torch.cuda.synchronize()
        fp8_time = (time.time() - start) / 100
        
        speedup = fp32_time / fp8_time
        
        print(f"FP32 平均时间: {fp32_time*1000:.3f} ms")
        print(f"FP8  平均时间: {fp8_time*1000:.3f} ms")
        print(f"加速比: {speedup:.2f}x")
        
        return speedup > 1.0
        
    except Exception as e:
        print(f"性能测试失败: {e}")
        return False
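
# For more precise GPU timing one could use CUDA events instead of time.time()
# (a sketch, not wired into the test above):
#   start, end = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
#   start.record(); torch.mm(a_fp32, b_fp32); end.record()
#   torch.cuda.synchronize()
#   elapsed_ms = start.elapsed_time(end)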


def main():
    """Main test entry point."""
    print("\n🚀 Triton FP8 support test suite")
    print("="*50)

    # Check hardware support
    if not torch.cuda.is_available():
        print("❌ CUDA is not available; cannot test FP8")
        return

    cc = torch.cuda.get_device_capability(0)
    if cc[0] * 10 + cc[1] < 89:
        print(f"⚠️  Warning: GPU compute capability {cc[0]}.{cc[1]} < 8.9")
        print("   FP8 requires an H100, L40S, or newer GPU")
        print("   Continuing, but the tests may fail...")
    
    # Run the tests
    results = {}
    
    # Test 1: dtype support
    results['dtype_support'] = test_fp8_dtype_support()
    
    # Test 2: vector addition
    if results['dtype_support']:
        results['vector_add'] = test_fp8_vector_addition_fixed()
        
        # Test 3: matrix multiplication
        results['matmul'] = test_fp8_matrix_multiplication()
        
        # Test 4: performance comparison
        if results['matmul']:
            results['performance'] = performance_comparison()
    else:
        print("\n❌ FP8数据类型不支持,跳过后续测试")
    
    # Summarize the results
    print("\n" + "="*50)
    print("📊 Test summary")
    print("="*50)
    
    for test_name, passed in results.items():
        status = "✅ PASS" if passed else "❌ FAIL"
        print(f"{status}: {test_name}")
    
    # Final verdict
    print("\n" + "="*50)
    if results.get('dtype_support', False):
        print("🎉 Your Triton build supports FP8!")
        if results.get('performance', False):
            print("⚡ FP8 shows a speedup and can be used for acceleration")
        else:
            print("⚠️  FP8 works, but the speedup is not significant")
    else:
        print("❌ FP8 is not supported in the current environment")
        print("\nSuggestions:")
        print("1. Upgrade Triton: pip install --upgrade triton")
        print("2. Upgrade PyTorch: pip install --upgrade torch")
        print("3. Use a supported GPU (H100, L40S, etc.)")
        print("4. Check the CUDA version: nvcc --version")


if __name__ == "__main__":
    main()