Commit 01caaf29 authored by helloyongyang

Add lightx2v_kernel for nvfp4

parent ea618db2
import torch
from lightx2v_kernel.gemm import scaled_fp4_quant, cutlass_scaled_fp4_mm
import time
from test_bench2 import MMWeightFp4
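# NOTE: MMWeightFp4 is imported from test_bench2, which is not part of this commit.
# The quoted block below is only a sketch of what such a wrapper is assumed to look
# like: it presumably pre-quantizes the (n, k) weight with scaled_fp4_quant and calls
# cutlass_scaled_fp4_mm inside apply(). The names, the return signature of
# scaled_fp4_quant, and the global-scale/alpha formulas are illustrative assumptions,
# not the actual implementation.
"""
class MMWeightFp4:
    def __init__(self, weight, bias):
        # hypothetical tensor-wide scale derived from the weight's absolute maximum
        self.weight_global_scale = (448.0 * 6.0) / weight.abs().max().float()
        # scaled_fp4_quant is assumed to return (packed_fp4_uint8, fp8_block_scales)
        self.weight_quant, self.weight_scale = scaled_fp4_quant(weight, self.weight_global_scale)
        self.bias = bias

    def apply(self, x):
        input_global_scale = (448.0 * 6.0) / x.abs().max().float()
        x_quant, x_scale = scaled_fp4_quant(x, input_global_scale)
        alpha = 1.0 / (input_global_scale * self.weight_global_scale)
        return cutlass_scaled_fp4_mm(x_quant, self.weight_quant, x_scale, self.weight_scale, alpha=alpha, bias=self.bias)
"""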
def test_speed(m, k, n):
    with torch.no_grad():
        input_tensor = torch.randn(m, k, dtype=torch.bfloat16).cuda()
        weight = torch.randn(n, k, dtype=torch.bfloat16, device="cuda")
        bias = torch.randn(1, n, dtype=torch.bfloat16).cuda()

        mm = MMWeightFp4(weight, bias)

        # warmup
        output_tensor = mm.apply(input_tensor)

        torch.cuda.synchronize()
        start_time = time.time()
        for i in range(100):
            output_tensor = mm.apply(input_tensor)
        torch.cuda.synchronize()
        end_time = time.time()
        lightx2v_kernel_time = (end_time - start_time) / 100
        print(f"lightx2v-kernel time: {lightx2v_kernel_time}")

        # bf16 reference: torch.nn.Linear with the same input, weight and bias
        linear = torch.nn.Linear(k, n, bias=True).cuda()
        linear.weight.data = weight
        linear.bias.data = bias

        # warmup
        ref_output_tensor = linear(input_tensor)

        torch.cuda.synchronize()
        start_time = time.time()
        for i in range(100):
            ref_output_tensor = linear(input_tensor)
        torch.cuda.synchronize()
        end_time = time.time()
        ref_time = (end_time - start_time) / 100
        print(f"ref time: {ref_time}")

        print(f"speedup: {ref_time / lightx2v_kernel_time:.3f}")
def test_accuracy(m, k, n):
    with torch.no_grad():
        input_tensor = torch.randn(m, k, dtype=torch.bfloat16).cuda()
        weight = torch.randn(n, k, dtype=torch.bfloat16, device="cuda")
        bias = torch.randn(1, n, dtype=torch.bfloat16).cuda()

        linear = torch.nn.Linear(k, n, bias=True).cuda()
        linear.weight.data = weight
        linear.bias.data = bias
        ref_output_tensor = linear(input_tensor)

        mm = MMWeightFp4(weight, bias)
        output_tensor = mm.apply(input_tensor)

        # print(f"ref_output_tensor: {ref_output_tensor}")
        # print(f"output_tensor: {output_tensor}")

        # cosine
        cos = torch.nn.functional.cosine_similarity(ref_output_tensor.flatten(), output_tensor.flatten(), dim=0)
        print(f"cos : {cos}")
if __name__ == "__main__":
    test_sizes = [
        (32130, 5120, 5120),
        (512, 5120, 5120),
        (257, 5120, 5120),
        (32130, 5120, 13824),
        (32130, 13824, 5120),
        (75348, 5120, 5120),
        (75348, 13824, 5120),
        (32760, 1536, 1536),
        (512, 1536, 1536),
        (32760, 1536, 8960),
        (32760, 8960, 1536),
    ]

    for i, (m, k, n) in enumerate(test_sizes):
        print("-" * 30)
        print(f"Test {i+1}: tensor size ({m}, {k}, {n})")
        test_accuracy(m, k, n)
        test_speed(m, k, n)
import torch
from lightx2v_kernel.gemm import cutlass_scaled_fp4_mm
"""
input_shape = (1024, 2048)
weight_shape = (4096, 2048)
input_tensor_quant = (torch.rand((1024, 1024), device="cuda") * 10).to(torch.uint8)
weight = (torch.rand((4096, 1024), device="cuda") * 10).to(torch.uint8)
input_tensor_scale = torch.rand(1024, 128, device="cuda").to(torch.float8_e4m3fn)
weight_scale = torch.rand(4096, 128, device="cuda").to(torch.float8_e4m3fn)
alpha = torch.tensor(0.0002765655517578125, device="cuda").to(torch.float32)
bias = None
"""
def test_mm(input_tensor_quant, weight, input_tensor_scale, weight_scale, alpha, bias):
    output_tensor = cutlass_scaled_fp4_mm(input_tensor_quant, weight, input_tensor_scale, weight_scale, alpha=alpha, bias=bias)
    return output_tensor
def test_tflops(input_shape, weight_shape, num_warmup=10, num_runs=100):
    """
    Measure the TFLOPS throughput of test_mm.
    """
    # Create input data: FP4 operands are packed two-per-uint8, so the quantized
    # tensors have half as many columns as the logical shapes.
    input_tensor_quant = (torch.rand((input_shape[0], input_shape[1] // 2), device="cuda") * 10).to(torch.uint8)
    weight = (torch.rand((weight_shape[0], weight_shape[1] // 2), device="cuda") * 10).to(torch.uint8)
    # One FP8 scale per group of 16 elements; activation scale rows are padded to a
    # multiple of 128 and scale columns to a multiple of 4, matching the layout the kernel expects.
    input_tensor_scale = torch.rand(((input_shape[0] + 128 - 1) // 128) * 128, (input_shape[1] // 16 + 4 - 1) // 4 * 4, device="cuda").to(torch.float8_e4m3fn)
    weight_scale = torch.rand(weight_shape[0], weight_shape[1] // 16, device="cuda").to(torch.float8_e4m3fn)
    alpha = torch.tensor(0.0002765655517578125, device="cuda", dtype=torch.float32)
    bias = None

    # Warm up the GPU
    for _ in range(num_warmup):
        test_mm(input_tensor_quant, weight, input_tensor_scale, weight_scale, alpha, bias)

    # Synchronize before timing
    torch.cuda.synchronize()

    # CUDA events give precise device-side timing
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    # Time the runs
    start_event.record()
    for _ in range(num_runs):
        result = test_mm(input_tensor_quant, weight, input_tensor_scale, weight_scale, alpha, bias)
    end_event.record()

    # Synchronize and compute elapsed time
    torch.cuda.synchronize()
    elapsed_time_ms = start_event.elapsed_time(end_event)
    elapsed_time_s = elapsed_time_ms / 1000.0

    # FLOP count for A(M x K) @ B(K x N) = C(M x N)
    # M = batch size, K = input dim, N = output dim
    M = input_shape[0]
    K = input_shape[1]
    N = weight_shape[0]
    # FLOPs per matmul = 2 * M * N * K (each output element needs K multiplies and K adds)
    flops_per_run = 2 * M * N * K
    total_flops = flops_per_run * num_runs
    # Convert to TFLOPS (10^12 floating-point operations per second)
    tflops = total_flops / (elapsed_time_s * 1e12)
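    # Example (hypothetical timing): for input_shape = (32130, 5120) and weight_shape = (5120, 5120),
    # flops_per_run = 2 * 32130 * 5120 * 5120 ≈ 1.68e12 FLOP, so an average run time of
    # 2 ms would correspond to roughly 1.68e12 / 2e-3 ≈ 840 TFLOPS.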
print(f"测试结果:")
print(f" 输入形状: {input_shape} (M={M}, K={K})")
print(f" 权重形状: {weight_shape} (N={N}, K={K})")
print(f" 输出形状: ({M}, {N})")
print(f" 运行次数: {num_runs}")
print(f" 总执行时间: {elapsed_time_ms:.2f} ms")
print(f" 平均每次执行时间: {elapsed_time_ms/num_runs:.4f} ms")
print(f" 每次运行FLOPS: {flops_per_run/1e9:.2f} GFLOPS")
print(f" 总FLOPS: {total_flops/1e12:.2f} TFLOPS")
print(f" 计算性能: {tflops:.2f} TFLOPS")
return tflops
if __name__ == "__main__":
    # Benchmark matrix multiplications of different sizes
    # input (m, k), weight (n, k)
    test_cases = [
        ((32130, 5120), (5120, 5120)),
        ((512, 5120), (5120, 5120)),
        ((257, 5120), (5120, 5120)),
        ((32130, 5120), (13824, 5120)),
        ((32130, 13824), (5120, 13824)),
        ((75348, 5120), (5120, 5120)),
        ((75348, 5120), (13824, 5120)),
        ((75348, 13824), (5120, 13824)),
        ((32760, 1536), (1536, 1536)),
        ((512, 1536), (1536, 1536)),
        ((32760, 1536), (8960, 1536)),
        ((32760, 8960), (1536, 8960)),
    ]

    print("=== test_mm TFLOPS benchmark ===\n")

    for i, (input_shape, weight_shape) in enumerate(test_cases):
        print(f"Test {i+1}: input shape {input_shape}, weight shape {weight_shape}")
        print("-" * 60)
        tflops = test_tflops(input_shape, weight_shape)
        print(f"✓ Test completed, throughput: {tflops:.2f} TFLOPS\n")

    print("=== All tests finished ===")
import torch
from lightx2v_kernel.gemm import scaled_fp4_quant
input_global_scale = torch.tensor(808.0, dtype=torch.float32).cuda()
def quantize_fp4(x):
    return scaled_fp4_quant(x, input_global_scale)
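# As the bandwidth model below assumes, scaled_fp4_quant packs two FP4 values per
# output byte and emits one FP8 (e4m3) scale per group of 16 elements (group_size = 16).
# The fixed input_global_scale of 808.0 is an arbitrary value for benchmarking, not a
# calibrated scale.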
def test_memory_bandwidth(func, x, num_warmup=10, num_runs=100):
    """
    Measure the memory bandwidth achieved by a function.
    """
    # Warm up the GPU
    for _ in range(num_warmup):
        func(x)

    # Synchronize before timing
    torch.cuda.synchronize()

    # CUDA events give precise device-side timing
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    # Time the runs
    start_event.record()
    for _ in range(num_runs):
        result = func(x)
    end_event.record()

    # Synchronize and compute elapsed time
    torch.cuda.synchronize()
    elapsed_time_ms = start_event.elapsed_time(end_event)
    elapsed_time_s = elapsed_time_ms / 1000.0

    # Data volume per run
    input_bytes = x.numel() * x.element_size()  # bytes read from the bf16 input
    # After FP4 quantization each element occupies 0.5 byte
    output_bytes = x.numel() * 0.5  # bytes written for the packed FP4 output
    scale_bytes = x.numel() / 16  # one FP8 scale (1 byte) per group of 16 elements
    # Total bytes moved (read input + write output + write scales)
    total_bytes = (input_bytes + output_bytes + scale_bytes) * num_runs
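    # Example: for x of shape (512, 5120), numel = 2,621,440, so each run reads
    # 5.0 MB of bf16 input and writes 1.25 MB of packed FP4 data plus 0.16 MB of scales,
    # roughly 6.4 MB moved per call.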
    # Bandwidth
    bandwidth_gbps = (total_bytes / elapsed_time_s) / (1024**3)  # GB/s

    print("Results:")
    print(f"  Input tensor shape: {x.shape}")
    print(f"  Input dtype: {x.dtype}")
    print(f"  Number of runs: {num_runs}")
    print(f"  Total elapsed time: {elapsed_time_ms:.2f} ms")
    print(f"  Average time per run: {elapsed_time_ms/num_runs:.4f} ms")
    print(f"  Input size: {input_bytes / (1024**2):.2f} MB")
    print(f"  Output size: {output_bytes / (1024**2):.2f} MB")
    print(f"  Total data moved: {total_bytes / (1024**3):.2f} GB")
    print(f"  Memory bandwidth: {bandwidth_gbps:.2f} GB/s")

    return bandwidth_gbps
if __name__ == "__main__":
    # Benchmark tensors of different sizes
    test_sizes = [
        # (1, 1024),
        # (1, 2048),
        # (1, 4096),
        # (1, 8192),
        # (1, 16384),
        # (1, 32768),
        # (2, 1024),
        # (2, 2048),
        # (2, 4096),
        # (2, 8192),
        # (2, 16384),
        # (2, 32768),
        # (4, 1024),
        # (4, 2048),
        # (4, 4096),
        # (4, 8192),
        # (4, 16384),
        # (4, 32768),
        # (128, 1024),
        # (128, 2048),
        # (128, 4096),
        # (128, 8192),
        # (128, 16384),
        # (128, 32768),
        # (512, 1024),
        # (512, 2048),
        # (512, 4096),
        # (512, 8192),
        # (512, 16384),
        # (512, 32768),
        # (1024, 1024),
        # (1024, 2048),
        # (1024, 4096),
        # (1024, 8192),
        # (1024, 16384),
        # (1024, 32768),
        # (2048, 1024),
        # (2048, 2048),
        # (2048, 4096),
        # (2048, 8192),
        # (2048, 16384),
        # (2048, 32768),
        # (4096, 1024),
        # (4096, 2048),
        # (4096, 4096),
        # (4096, 8192),
        # (4096, 16384),
        # (4096, 32768),
        # (8192, 1024),
        # (8192, 2048),
        # (8192, 4096),
        # (8192, 8192),
        # (8192, 16384),
        # (8192, 32768),
        # (16384, 1024),
        # (16384, 2048),
        # (16384, 4096),
        # (16384, 8192),
        # (16384, 16384),
        # (16384, 32768),
        # (32768, 1024),
        # (32768, 2048),
        # (32768, 4096),
        # (32768, 8192),
        # (32768, 16384),
        # (32768, 32768),
        (32130, 5120),
        (512, 5120),
        (257, 5120),
        (32130, 13824),
        (75348, 5120),
        (75348, 13824),
        (32760, 1536),
        (512, 1536),
        (32760, 8960),
    ]
print("=== quantize_fp4 显存带宽测试 ===\n")
for i, (h, w) in enumerate(test_sizes):
print(f"测试 {i+1}: 张量大小 ({h}, {w})")
print("-" * 50)
x = torch.randn(h, w, dtype=torch.bfloat16).cuda()
try:
bandwidth = test_memory_bandwidth(quantize_fp4, x)
print(f"✓ 成功完成测试,带宽: {bandwidth:.2f} GB/s\n")
except Exception as e:
print(f"✗ 测试失败: {e}\n")
print("=== 测试完成 ===")