""" Triton FP8支持测试Demo 测试环境要求: - Triton >= 2.1.0 - CUDA >= 11.8 - GPU计算能力 >= 8.9 (H100, L40S, etc.) """ import torch import triton import triton.language as tl import numpy as np from typing import Tuple # 检查Triton版本 print(f"Triton version: {triton.__version__}") print(f"PyTorch version: {torch.__version__}") print(f"CUDA available: {torch.cuda.is_available()}") if torch.cuda.is_available(): print(f"GPU: {torch.cuda.get_device_name(0)}") print(f"Compute Capability: {torch.cuda.get_device_capability(0)}") # 检查计算能力是否支持FP8 cc = torch.cuda.get_device_capability(0) fp8_supported = cc[0] * 10 + cc[1] >= 89 # 8.9+ print(f"FP8 hardware support: {fp8_supported}") @triton.jit def add_kernel_fp8( x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr, ): """使用FP8的向量加法kernel""" pid = tl.program_id(axis=0) block_start = pid * BLOCK_SIZE offsets = block_start + tl.arange(0, BLOCK_SIZE) mask = offsets < n_elements # 加载FP8数据并转换为FP32进行计算 x = tl.load(x_ptr + offsets, mask=mask).to(tl.float32) y = tl.load(y_ptr + offsets, mask=mask).to(tl.float32) # 执行计算 output = x + y # 转换回FP8并存储 # 注意:需要根据实际使用的FP8格式选择合适的缩放因子 output_fp8 = output.to(tl.float8e5) # 或 tl.float8e4m3 tl.store(output_ptr + offsets, output_fp8, mask=mask) @triton.jit def matmul_kernel_fp8( a_ptr, b_ptr, c_ptr, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, ): """使用FP8的矩阵乘法kernel""" pid_m = tl.program_id(0) pid_n = tl.program_id(1) # 计算当前block的位置 offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) offs_k = tl.arange(0, BLOCK_K) # 创建mask m_mask = offs_m[:, None] < M n_mask = offs_n[None, :] < N # 初始化累加器 accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) for k in range(0, K, BLOCK_K): # 加载FP8矩阵块并转换为FP32 a_ptrs = a_ptr + (offs_m[:, None] * stride_am + (offs_k[None, :] + k) * stride_ak) b_ptrs = b_ptr + ((offs_k[:, None] + k) * stride_bk + offs_n[None, :] * stride_bn) a = tl.load(a_ptrs, mask=m_mask & (offs_k[None, :] + k < K)[None, :]) b = tl.load(b_ptrs, mask=(offs_k[:, None] + k < K)[:, None] & n_mask) # 转换为FP32计算 a_fp32 = a.to(tl.float32) b_fp32 = b.to(tl.float32) # 矩阵乘法 accumulator += tl.dot(a_fp32, b_fp32) # 将结果转换为FP8并存储 c = accumulator.to(tl.float8e5) # 存储结果 offs_m_full = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) offs_n_full = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) c_ptrs = c_ptr + (offs_m_full[:, None] * stride_cm + offs_n_full[None, :] * stride_cn) tl.store(c_ptrs, c, mask=m_mask & n_mask) def create_fp8_tensor(data: torch.Tensor, fp8_type: str = 'e5m2') -> torch.Tensor: """创建FP8张量""" if fp8_type == 'e5m2': # 转换为float8_e5m2格式 # 需要先缩放到合适范围 max_val = data.abs().max() scale = 448.0 / max_val if max_val > 0 else 1.0 # e5m2最大值为448 scaled_data = data * scale return scaled_data.to(torch.float8_e5m2) elif fp8_type == 'e4m3': max_val = data.abs().max() scale = 240.0 / max_val if max_val > 0 else 1.0 # e4m3最大值为240 scaled_data = data * scale return scaled_data.to(torch.float8_e4m3fn) else: raise ValueError(f"Unsupported FP8 type: {fp8_type}") def create_fp8_tensor_with_scaling(data: torch.Tensor, fp8_type: str = 'e5m2'): """创建带缩放因子的FP8张量""" if fp8_type == 'e5m2': fp8_max = 57344.0 data_max = data.abs().max() scale = fp8_max / data_max if data_max > 0 else 1.0 scaled_data = data * scale # 确保值在FP8范围内 scaled_data = torch.clamp(scaled_data, -fp8_max, fp8_max) fp8_data = scaled_data.to(torch.float8_e5m2) return fp8_data, scale elif fp8_type == 'e4m3': fp8_max = 448.0 data_max = data.abs().max() scale = fp8_max / data_max if data_max > 0 else 1.0 scaled_data = data * scale scaled_data = torch.clamp(scaled_data, -fp8_max, fp8_max) fp8_data = scaled_data.to(torch.float8_e4m3fn) return fp8_data, scale else: raise ValueError(f"Unsupported FP8 type: {fp8_type}") @triton.jit def add_kernel_fp8_workaround( x_ptr, y_ptr, output_ptr, scale_x, scale_y, scale_out, n_elements, BLOCK_SIZE: tl.constexpr, ): """绕过类型问题的FP8加法""" pid = tl.program_id(axis=0) block_start = pid * BLOCK_SIZE offsets = block_start + tl.arange(0, BLOCK_SIZE) mask = offsets < n_elements # 直接加载并转换,一行完成 # 1. 先加载数据,并立即转换为 float32 类型 x_loaded = tl.load(x_ptr + offsets, mask=mask).to(tl.float32) # 2. 再对已经明确为 float32 类型的变量进行除法运算 x = x_loaded / scale_x # 1. 先加载数据,并立即转换为 float32 类型 y_loaded = tl.load(y_ptr + offsets, mask=mask).to(tl.float32) # 2. 再对已经明确为 float32 类型的变量进行除法运算 y = y_loaded / scale_y # 计算 output = x + y # 存储 output_scaled = output * scale_out tl.store(output_ptr + offsets, output_scaled.to(tl.float8e5), mask=mask) def test_fp8_vector_addition_fixed(): """修复后的FP8向量加法测试""" print("\n" + "="*50) print("测试: 带缩放因子的FP8向量加法") print("="*50) n_elements = 1024 # 准备数据 torch.manual_seed(42) x = torch.randn(n_elements, device='cuda', dtype=torch.float32) * 2 y = torch.randn(n_elements, device='cuda', dtype=torch.float32) * 2 # 创建带缩放的FP8张量 x_fp8, scale_x = create_fp8_tensor_with_scaling(x, 'e5m2') y_fp8, scale_y = create_fp8_tensor_with_scaling(y, 'e5m2') output_fp8 = torch.empty_like(x_fp8) # 计算输出的缩放因子 expected_max = (x.abs().max() + y.abs().max()).item() * 1.5 scale_out = 57344.0 / expected_max if expected_max > 0 else 1.0 print(f"缩放因子: scale_x={scale_x:.4f}, scale_y={scale_y:.4f}, scale_out={scale_out:.4f}") print(f"数据范围: x=[{x.min():.3f}, {x.max():.3f}], y=[{y.min():.3f}, {y.max():.3f}]") # 配置kernel BLOCK_SIZE = 256 grid = (triton.cdiv(n_elements, BLOCK_SIZE),) # 执行kernel - 注意缩放因子作为标量传递 add_kernel_fp8_workaround[grid]( x_fp8, y_fp8, output_fp8, scale_x, scale_y, scale_out, n_elements, BLOCK_SIZE ) # 恢复数据 output = output_fp8.float() / scale_out expected = x + y # 计算误差 abs_error = (output - expected).abs() rel_error = abs_error / (expected.abs() + 1e-8) print(f"\n前10个结果对比:") print(f"预期: {expected[:10].cpu()}") print(f"实际: {output[:10].cpu()}") print(f"差异: {(output - expected)[:10].cpu()}") print(f"\n误差统计:") print(f"平均绝对误差: {abs_error.mean():.6f}") print(f"最大绝对误差: {abs_error.max():.6f}") print(f"平均相对误差: {rel_error.mean():.6f}") print(f"最大相对误差: {rel_error.max():.6f}") # 检查是否合理 if rel_error.mean() < 0.1: print("✓ FP8精度在可接受范围内") return True else: print("⚠️ FP8精度损失较大") return False def test_fp8_matrix_multiplication(): """测试FP8矩阵乘法""" print("\n" + "="*50) print("测试2: FP8矩阵乘法") print("="*50) try: # 配置矩阵维度 M, N, K = 512, 512, 256 # 创建输入矩阵 a = torch.randn((M, K), device='cuda', dtype=torch.float32) b = torch.randn((K, N), device='cuda', dtype=torch.float32) # 转换为FP8 a_fp8 = create_fp8_tensor(a, 'e5m2') b_fp8 = create_fp8_tensor(b, 'e5m2') c_fp8 = torch.empty((M, N), device='cuda', dtype=torch.float8_e5m2) # Kernel配置 BLOCK_M, BLOCK_N, BLOCK_K = 64, 64, 64 grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N)) # 执行矩阵乘法 matmul_kernel_fp8[grid]( a_fp8, b_fp8, c_fp8, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c_fp8.stride(0), c_fp8.stride(1), BLOCK_M, BLOCK_N, BLOCK_K ) # 转换回FP32 c_output = c_fp8.float() # 验证结果(使用PyTorch的矩阵乘法) expected = torch.mm(a, b) # 归一化误差计算 norm_expected = torch.norm(expected) rel_error = torch.norm(c_output - expected) / norm_expected print(f"✓ FP8矩阵乘法成功执行") print(f" 矩阵维度: {M}x{K} * {K}x{N} = {M}x{N}") print(f" 相对误差: {rel_error:.6f}") # 显示部分结果 print(f" 预期结果[0:5,0:5]: {expected[0,0]:.4f} {expected[0,1]:.4f} ...") print(f" 实际结果[0:5,0:5]: {c_output[0,0]:.4f} {c_output[0,1]:.4f} ...") return True except Exception as e: print(f"✗ FP8矩阵乘法失败: {e}") return False def test_fp8_dtype_support(): """测试FP8数据类型支持""" print("\n" + "="*50) print("测试3: FP8数据类型支持") print("="*50) # 检查torch的FP8支持 fp8_types = { 'float8_e5m2': hasattr(torch, 'float8_e5m2'), 'float8_e4m3fn': hasattr(torch, 'float8_e4m3fn'), } for dtype_name, supported in fp8_types.items(): status = "✓" if supported else "✗" print(f"{status} torch.{dtype_name}: {supported}") # 检查triton的FP8支持 triton_fp8_types = { 'float8e5': hasattr(tl, 'float8e5'), 'float8e4': hasattr(tl, 'float8e4'), } for dtype_name, supported in triton_fp8_types.items(): status = "✓" if supported else "✗" print(f"{status} tl.{dtype_name}: {supported}") return any(fp8_types.values()) and any(triton_fp8_types.values()) def performance_comparison(): """性能对比:FP32 vs FP8""" print("\n" + "="*50) print("测试4: 性能对比 (FP32 vs FP8)") print("="*50) try: import time M, N, K = 1024, 1024, 512 # FP32矩阵乘法 a_fp32 = torch.randn((M, K), device='cuda', dtype=torch.float32) b_fp32 = torch.randn((K, N), device='cuda', dtype=torch.float32) # 预热 for _ in range(10): torch.mm(a_fp32, b_fp32) torch.cuda.synchronize() start = time.time() for _ in range(100): torch.mm(a_fp32, b_fp32) torch.cuda.synchronize() fp32_time = (time.time() - start) / 100 # FP8矩阵乘法 a_fp8 = create_fp8_tensor(a_fp32, 'e5m2') b_fp8 = create_fp8_tensor(b_fp32, 'e5m2') c_fp8 = torch.empty((M, N), device='cuda', dtype=torch.float8_e5m2) BLOCK_M, BLOCK_N, BLOCK_K = 64, 64, 64 grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N)) # 预热 for _ in range(10): matmul_kernel_fp8[grid]( a_fp8, b_fp8, c_fp8, M, N, K, a_fp8.stride(0), a_fp8.stride(1), b_fp8.stride(0), b_fp8.stride(1), c_fp8.stride(0), c_fp8.stride(1), BLOCK_M, BLOCK_N, BLOCK_K ) torch.cuda.synchronize() start = time.time() for _ in range(100): matmul_kernel_fp8[grid]( a_fp8, b_fp8, c_fp8, M, N, K, a_fp8.stride(0), a_fp8.stride(1), b_fp8.stride(0), b_fp8.stride(1), c_fp8.stride(0), c_fp8.stride(1), BLOCK_M, BLOCK_N, BLOCK_K ) torch.cuda.synchronize() fp8_time = (time.time() - start) / 100 speedup = fp32_time / fp8_time print(f"FP32 平均时间: {fp32_time*1000:.3f} ms") print(f"FP8 平均时间: {fp8_time*1000:.3f} ms") print(f"加速比: {speedup:.2f}x") return speedup > 1.0 except Exception as e: print(f"性能测试失败: {e}") return False def main(): """主测试函数""" print("\n" + "🚀 Triton FP8 支持测试套件") print("="*50) # 检查硬件支持 if not torch.cuda.is_available(): print("❌ CUDA不可用,无法测试FP8") return cc = torch.cuda.get_device_capability(0) if cc[0] * 10 + cc[1] < 89: print(f"⚠️ 警告: GPU计算能力{cc[0]}.{cc[1]} < 8.9") print(" FP8需要H100、L40S或更新的GPU") print(" 继续测试但可能会失败...") # 运行测试 results = {} # 测试1: 数据类型支持 results['dtype_support'] = test_fp8_dtype_support() # 测试2: 向量加法 if results['dtype_support']: results['vector_add'] = test_fp8_vector_addition_fixed() # 测试3: 矩阵乘法 results['matmul'] = test_fp8_matrix_multiplication() # 测试4: 性能对比 if results['matmul']: results['performance'] = performance_comparison() else: print("\n❌ FP8数据类型不支持,跳过后续测试") # 汇总结果 print("\n" + "="*50) print("📊 测试结果汇总") print("="*50) for test_name, passed in results.items(): status = "✅ PASS" if passed else "❌ FAIL" print(f"{status}: {test_name}") # 最终结论 print("\n" + "="*50) if results.get('dtype_support', False): print("🎉 你的Triton支持FP8!") if results.get('performance', False): print("⚡ FP8性能有提升,可以利用FP8加速") else: print("⚠️ FP8功能正常但性能提升不明显") else: print("❌ 当前环境不支持FP8") print("\n建议:") print("1. 升级Triton: pip install --upgrade triton") print("2. 升级PyTorch: pip install --upgrade torch") print("3. 确保使用支持的GPU (H100, L40S等)") print("4. 检查CUDA版本: nvcc --version") if __name__ == "__main__": main()