Commit c7bb59cd authored by helloyongyang

fix ci

parent 01caaf29
......@@ -61,7 +61,7 @@ struct Fp4GemmSm120 {
using ClusterShape = Shape<_1,_1,_1>; // Shape of the threadblocks in a cluster
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
- ArchTag, OperatorClass,
+ ArchTag, OperatorClass,
ThreadBlockShape, ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator, ElementAccumulator,
......
......@@ -14,6 +14,4 @@ from lightx2v_kernel import common_ops
from lightx2v_kernel.gemm import cutlass_scaled_fp4_mm, scaled_fp4_quant
from lightx2v_kernel.version import __version__
- build_tree_kernel = (
- None
- )
+ build_tree_kernel = None
......@@ -12,15 +12,11 @@ def cutlass_scaled_fp4_mm(mat_a, mat_b, scales_a, scales_b, alpha, bias=None):
"""
m, n = mat_a.shape[0], mat_b.shape[0]
out = torch.empty((m, n), dtype=torch.bfloat16, device=mat_a.device)
- torch.ops.lightx2v_kernel.cutlass_scaled_fp4_mm_sm120.default(
- out, mat_a, mat_b, scales_a, scales_b, alpha, bias
- )
+ torch.ops.lightx2v_kernel.cutlass_scaled_fp4_mm_sm120.default(out, mat_a, mat_b, scales_a, scales_b, alpha, bias)
return out
- def scaled_fp4_quant(
- input: torch.Tensor, input_global_scale: torch.Tensor
- ):
+ def scaled_fp4_quant(input: torch.Tensor, input_global_scale: torch.Tensor):
"""
Quantize input tensor to FP4 and return quantized tensor and scale.
......@@ -60,13 +56,8 @@ def scaled_fp4_quant(
# rounded_m = ((m + 128 - 1) // 128) * 128
# scale_n = n // block_size
# rounded_n = ((scale_n + 4 - 1) // 4) * 4
- output_scale = torch.empty(
- (((m + 128 - 1) // 128) * 128, (n // block_size + 4 - 1) // 4), device=device, dtype=torch.int32
- )
+ output_scale = torch.empty((((m + 128 - 1) // 128) * 128, (n // block_size + 4 - 1) // 4), device=device, dtype=torch.int32)
- torch.ops.lightx2v_kernel.scaled_fp4_quant_sm120.default(
- output, input, output_scale, input_global_scale
- )
+ torch.ops.lightx2v_kernel.scaled_fp4_quant_sm120.default(output, input, output_scale, input_global_scale)
output_scale = output_scale.view(torch.float8_e4m3fn)
return output, output_scale
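The scale buffer allocated above is worth unpacking: it is allocated as int32 and only later reinterpreted as float8_e4m3fn, so each int32 column packs four FP8 block scales, and the row count is padded to a multiple of 128 for the swizzled layout. A small worked example of that shape arithmetic (the m, n values are illustrative, not from the diff):

# Worked example of the output_scale shape math in scaled_fp4_quant (illustrative sizes).
m, n, block_size = 257, 5120, 16
rounded_m = ((m + 128 - 1) // 128) * 128      # 384: rows padded up to a multiple of 128
scale_n = n // block_size                     # 320: one scale per 16-element block
packed_cols = (scale_n + 4 - 1) // 4          # 80: stored as int32, four FP8 scales per int32
# After output_scale.view(torch.float8_e4m3fn) the last dim becomes packed_cols * 4 = 320,
# i.e. scale_n rounded up to a multiple of 4.
print(rounded_m, packed_cols, packed_cols * 4)  # -> 384 80 320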
......@@ -6,6 +6,7 @@ BLOCK_SIZE = 16
FLOAT4_E2M1_MAX = 6.0
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
def cast_to_fp4(x):
sign = torch.sign(x)
x = torch.abs(x)
......
......@@ -53,9 +53,7 @@ def break_fp4_bytes(a, dtype):
return out
- def dequantize_to_dtype(
- tensor_fp4, tensor_sf, global_scale, dtype, device, block_size=16
- ):
+ def dequantize_to_dtype(tensor_fp4, tensor_sf, global_scale, dtype, device, block_size=16):
"""Dequantize the fp4 tensor back to high precision."""
# Two fp4 values are packed into one uint8.
assert tensor_fp4.dtype == torch.uint8
......@@ -88,12 +86,8 @@ def get_ref_results(
_, m_k = a_fp4.shape
_, n_k = b_fp4.shape
assert m_k == n_k
- a_in_dtype = dequantize_to_dtype(
- a_fp4, a_sf, a_global_scale, dtype=dtype, device=device, block_size=block_size
- )
- b_in_dtype = dequantize_to_dtype(
- b_fp4, b_sf, b_global_scale, dtype=dtype, device=device, block_size=block_size
- )
+ a_in_dtype = dequantize_to_dtype(a_fp4, a_sf, a_global_scale, dtype=dtype, device=device, block_size=block_size)
+ b_in_dtype = dequantize_to_dtype(b_fp4, b_sf, b_global_scale, dtype=dtype, device=device, block_size=block_size)
return torch.matmul(a_in_dtype, b_in_dtype.t())
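For readers unfamiliar with the packing that break_fp4_bytes and dequantize_to_dtype undo: each uint8 holds two FP4 (E2M1) codes, whose magnitudes are limited to {0, 0.5, 1, 1.5, 2, 3, 4, 6} (hence FLOAT4_E2M1_MAX = 6.0). A hedged, lookup-table sketch of the unpacking; the low-nibble-first order is an assumption, and the actual layout is defined by break_fp4_bytes in this file:

import torch

# E2M1 magnitude table: 2 exponent bits, 1 mantissa bit (sign handled separately).
E2M1_VALUES = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

def unpack_fp4_reference(packed: torch.Tensor) -> torch.Tensor:
    """Unpack a uint8 tensor holding two FP4 codes per byte (low nibble first, assumed)."""
    lo = (packed & 0x0F).long()
    hi = (packed >> 4).long()
    nibbles = torch.stack([lo, hi], dim=-1).reshape(*packed.shape[:-1], -1)
    sign = 1.0 - 2.0 * ((nibbles >> 3) & 1).float()   # bit 3 is the sign bit
    return sign * E2M1_VALUES[nibbles & 0x7]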
......@@ -109,16 +103,12 @@ def test_nvfp4_gemm(
b_dtype = torch.randn((n, k), dtype=dtype, device="cuda")
bias = torch.randn((1, n), dtype=dtype, device="cuda")
- a_global_scale = (
- (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(a_dtype.flatten(), dim=-1)
- ).to(torch.float32)
- b_global_scale = (
- (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(b_dtype.flatten(), dim=-1)
- ).to(torch.float32)
+ a_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(a_dtype.flatten(), dim=-1)).to(torch.float32)
+ b_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(b_dtype.flatten(), dim=-1)).to(torch.float32)
print(f"a_global_scale : {a_global_scale}, {a_global_scale.shape}")
print(f"b_global_scale : {b_global_scale}, {b_global_scale.shape}")
alpha = 1.0 / (a_global_scale * b_global_scale)
a_fp4, a_scale_interleaved = scaled_fp4_quant(a_dtype, a_global_scale)
b_fp4, b_scale_interleaved = scaled_fp4_quant(b_dtype, b_global_scale)
......@@ -137,15 +127,11 @@ def test_nvfp4_gemm(
"cuda",
)
expected_out = expected_out + bias
print(f"alpha {alpha}, {alpha.shape}, {alpha.dtype}")
- out = cutlass_scaled_fp4_mm(
- a_fp4, b_fp4, a_scale_interleaved, b_scale_interleaved, alpha, bias
- )
+ out = cutlass_scaled_fp4_mm(a_fp4, b_fp4, a_scale_interleaved, b_scale_interleaved, alpha, bias)
print(f"out : {out}, {out.shape}, {out.dtype}")
print(f"expected_out : {expected_out}, {expected_out.shape}, {expected_out.dtype}")
......
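The scale algebra exercised by this test is compact enough to spell out: FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX = 448 * 6 = 2688, each global scale maps a tensor's absolute maximum onto that combined FP8-scale x FP4-value range, and alpha = 1 / (a_global_scale * b_global_scale) folds both factors back out of the accumulated product. A minimal numeric sketch (the amax values are illustrative):

import torch

FLOAT4_E2M1_MAX = 6.0
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0

amax_a, amax_b = 5.0, 3.0                                       # illustrative per-tensor absolute maxima
a_global_scale = (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / amax_a   # 2688 / 5 = 537.6
b_global_scale = (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / amax_b   # 2688 / 3 = 896.0
alpha = 1.0 / (a_global_scale * b_global_scale)                 # rescales the FP4 GEMM output to the original range
print(a_global_scale, b_global_scale, alpha)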
......@@ -7,7 +7,7 @@ class MMWeightFp4:
def __init__(self, weight, bias):
self.load_fp4_weight(weight, bias)
self.act_quant_func = self.act_quant_fp4
# calibrate x_max
self.calibrate_x_absmax()
......@@ -24,7 +24,7 @@ class MMWeightFp4:
self.bias = bias
def calibrate_x_absmax(self):
- self.x_absmax = torch.tensor(5.0, dtype=torch.float32, device=self.weight.device) # need to be calibrated
+ self.x_absmax = torch.tensor(5.0, dtype=torch.float32, device=self.weight.device)  # need to be calibrated
self.input_global_scale = (2688.0 / self.x_absmax).to(torch.float32)
self.alpha = 1.0 / (self.input_global_scale * self.weight_global_scale)
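The hard-coded 5.0 above is a stand-in for a measured activation absmax (the 2688.0 constant is FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX = 448 * 6 written out). A hedged sketch of what a data-driven calibration could look like; sample_inputs and the helper name are hypothetical, not part of this diff:

import torch

def calibrate_x_absmax_from_samples(mm, sample_inputs):
    """Hypothetical helper: replace the fixed 5.0 with an absmax observed on real activations."""
    absmax = torch.stack([x.abs().max().float() for x in sample_inputs]).max()
    mm.x_absmax = absmax.to(torch.float32)
    mm.input_global_scale = (2688.0 / mm.x_absmax).to(torch.float32)  # 2688 = 448 * 6
    mm.alpha = 1.0 / (mm.input_global_scale * mm.weight_global_scale)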
......@@ -33,7 +33,6 @@ class MMWeightFp4:
return scaled_fp4_quant(x, self.input_global_scale)
def test_speed(m, k, n):
with torch.no_grad():
input_tensor = torch.randn(m, k, dtype=torch.bfloat16).cuda()
......@@ -42,26 +41,24 @@ def test_speed(m, k, n):
bias = None
mm = MMWeightFp4(weight, bias)
# warmup
output_tensor = mm.apply(input_tensor)
torch.cuda.synchronize()
start_time = time.time()
for i in range(100):
output_tensor = mm.apply(input_tensor)
torch.cuda.synchronize()
end_time = time.time()
lightx2v_kernel_time = (end_time - start_time) / 100
print(f"lightx2v-kernel time: {lightx2v_kernel_time}")
input_tensor = torch.randn(m, n, dtype=torch.bfloat16).cuda()
weight = torch.randn(k, n, dtype=torch.bfloat16, device="cuda")
bias = torch.randn(1, k, dtype=torch.bfloat16).cuda()
linear = torch.nn.Linear(k, n, bias=False).cuda()
linear.weight.data = weight
# linear.bias.data = bias
......@@ -72,13 +69,13 @@ def test_speed(m, k, n):
torch.cuda.synchronize()
start_time = time.time()
for i in range(100):
- ref_output_tensor = linear(input_tensor)
+ ref_output_tensor = linear(input_tensor)
torch.cuda.synchronize()
end_time = time.time()
ref_time = (end_time - start_time) / 100
print(f"ref time: {ref_time}")
print(f"speedup: {ref_time / lightx2v_kernel_time:.3f}")
......@@ -88,47 +85,42 @@ def test_accuracy(m, k, n):
weight = torch.randn(n, k, dtype=torch.bfloat16, device="cuda")
# bias = torch.randn(1, n, dtype=torch.bfloat16).cuda()
bias = None
linear = torch.nn.Linear(k, n, bias=False).cuda()
linear.weight.data = weight
# linear.bias.data = bias
ref_output_tensor = linear(input_tensor)
mm = MMWeightFp4(weight, bias)
output_tensor = mm.apply(input_tensor)
# print(f"ref_output_tensor: {ref_output_tensor}")
# print(f"output_tensor: {output_tensor}")
# cosine
cos = torch.nn.functional.cosine_similarity(ref_output_tensor.flatten(), output_tensor.flatten(), dim=0)
print(f"cos : {cos}")
if __name__ == "__main__":
test_sizes = [
(32130, 5120, 5120),
(512, 5120, 5120),
(257, 5120, 5120),
(32130, 5120, 13824),
(32130, 13824, 5120),
(75348, 5120, 5120),
(75348, 13824, 5120),
(32760, 1536, 1536),
(512, 1536, 1536),
(32760, 1536, 8960),
(32760, 8960, 1536),
]
for i, (m, k, n) in enumerate(test_sizes):
print("-" * 30)
print(f"测试 {i+1}: 张量大小 ({m}, {k}, {n})")
print(f"测试 {i + 1}: 张量大小 ({m}, {k}, {n})")
test_accuracy(m, k, n)
test_speed(m, k, n)
......@@ -11,26 +11,24 @@ def test_speed(m, k, n):
bias = torch.randn(1, n, dtype=torch.bfloat16).cuda()
mm = MMWeightFp4(weight, bias)
# warmup
output_tensor = mm.apply(input_tensor)
torch.cuda.synchronize()
start_time = time.time()
for i in range(100):
output_tensor = mm.apply(input_tensor)
torch.cuda.synchronize()
end_time = time.time()
lightx2v_kernel_time = (end_time - start_time) / 100
print(f"lightx2v-kernel time: {lightx2v_kernel_time}")
input_tensor = torch.randn(m, n, dtype=torch.bfloat16).cuda()
weight = torch.randn(k, n, dtype=torch.bfloat16, device="cuda")
bias = torch.randn(1, k, dtype=torch.bfloat16).cuda()
linear = torch.nn.Linear(k, n, bias=True).cuda()
linear.weight.data = weight
linear.bias.data = bias
......@@ -41,13 +39,13 @@ def test_speed(m, k, n):
torch.cuda.synchronize()
start_time = time.time()
for i in range(100):
- ref_output_tensor = linear(input_tensor)
+ ref_output_tensor = linear(input_tensor)
torch.cuda.synchronize()
end_time = time.time()
ref_time = (end_time - start_time) / 100
print(f"ref time: {ref_time}")
print(f"speedup: {ref_time / lightx2v_kernel_time:.3f}")
......@@ -56,47 +54,42 @@ def test_accuracy(m, k, n):
input_tensor = torch.randn(m, k, dtype=torch.bfloat16).cuda()
weight = torch.randn(n, k, dtype=torch.bfloat16, device="cuda")
bias = torch.randn(1, n, dtype=torch.bfloat16).cuda()
linear = torch.nn.Linear(k, n, bias=True).cuda()
linear.weight.data = weight
linear.bias.data = bias
ref_output_tensor = linear(input_tensor)
mm = MMWeightFp4(weight, bias)
output_tensor = mm.apply(input_tensor)
# print(f"ref_output_tensor: {ref_output_tensor}")
# print(f"output_tensor: {output_tensor}")
# cosine
cos = torch.nn.functional.cosine_similarity(ref_output_tensor.flatten(), output_tensor.flatten(), dim=0)
print(f"cos : {cos}")
if __name__ == "__main__":
test_sizes = [
(32130, 5120, 5120),
(512, 5120, 5120),
(257, 5120, 5120),
(32130, 5120, 13824),
(32130, 13824, 5120),
(75348, 5120, 5120),
(75348, 13824, 5120),
(32760, 1536, 1536),
(512, 1536, 1536),
(32760, 1536, 8960),
(32760, 8960, 1536),
]
for i, (m, k, n) in enumerate(test_sizes):
print("-" * 30)
print(f"测试 {i+1}: 张量大小 ({m}, {k}, {n})")
print(f"测试 {i + 1}: 张量大小 ({m}, {k}, {n})")
test_accuracy(m, k, n)
test_speed(m, k, n)
......@@ -14,6 +14,7 @@ alpha = torch.tensor(0.0002765655517578125, device="cuda").to(torch.float32)
bias = None
"""
def test_mm(input_tensor_quant, weight, input_tensor_scale, weight_scale, alpha, bias):
output_tensor = cutlass_scaled_fp4_mm(input_tensor_quant, weight, input_tensor_scale, weight_scale, alpha=alpha, bias=bias)
return output_tensor
......@@ -23,64 +24,63 @@ def test_tflops(input_shape, weight_shape, num_warmup=10, num_runs=100):
"""
Measure the TFLOPS throughput of test_mm.
"""
# Create the input data
input_tensor_quant = (torch.rand((input_shape[0], input_shape[1] // 2), device="cuda") * 10).to(torch.uint8)
weight = (torch.rand((weight_shape[0], weight_shape[1] // 2), device="cuda") * 10).to(torch.uint8)
input_tensor_scale = torch.rand(((input_shape[0] + 128 - 1) // 128) * 128, (input_shape[1] // 16 + 4 - 1) // 4 * 4, device="cuda").to(torch.float8_e4m3fn)
weight_scale = torch.rand(weight_shape[0], weight_shape[1] // 16, device="cuda").to(torch.float8_e4m3fn)
alpha = torch.tensor(0.0002765655517578125, device="cuda", dtype=torch.float32)
bias = None
# Warm up the GPU
for _ in range(num_warmup):
test_mm(input_tensor_quant, weight, input_tensor_scale, weight_scale, alpha, bias)
# Synchronize the GPU
torch.cuda.synchronize()
# Create CUDA events for precise timing
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
# Time the runs
start_event.record()
for _ in range(num_runs):
result = test_mm(input_tensor_quant, weight, input_tensor_scale, weight_scale, alpha, bias)
end_event.record()
# Synchronize and compute the elapsed time
torch.cuda.synchronize()
elapsed_time_ms = start_event.elapsed_time(end_event)
elapsed_time_s = elapsed_time_ms / 1000.0
# Compute FLOPs
# Matrix multiplication A(M x K) @ B(K x N) = C(M x N)
# M = batch_size, K = input_dim, N = output_dim
M = input_shape[0]
K = input_shape[1]
N = weight_shape[0]
# FLOPs per matmul = 2 * M * N * K (each output element takes K multiplies and K adds)
flops_per_run = 2 * M * N * K
total_flops = flops_per_run * num_runs
# Compute TFLOPS (trillions of floating-point operations per second)
tflops = total_flops / (elapsed_time_s * 1e12)
print(f"测试结果:")
print(f" 输入形状: {input_shape} (M={M}, K={K})")
print(f" 权重形状: {weight_shape} (N={N}, K={K})")
print(f" 输出形状: ({M}, {N})")
print(f" 运行次数: {num_runs}")
print(f" 总执行时间: {elapsed_time_ms:.2f} ms")
print(f" 平均每次执行时间: {elapsed_time_ms/num_runs:.4f} ms")
print(f" 每次运行FLOPS: {flops_per_run/1e9:.2f} GFLOPS")
print(f" 总FLOPS: {total_flops/1e12:.2f} TFLOPS")
print(f" 平均每次执行时间: {elapsed_time_ms / num_runs:.4f} ms")
print(f" 每次运行FLOPS: {flops_per_run / 1e9:.2f} GFLOPS")
print(f" 总FLOPS: {total_flops / 1e12:.2f} TFLOPS")
print(f" 计算性能: {tflops:.2f} TFLOPS")
return tflops
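As a sanity check on the FLOP accounting above, one worked example with an assumed timing (the 0.1 ms per run is illustrative, not a measurement):

# Worked example of the TFLOPS formula above (the timing is made up for illustration).
M, K, N = 512, 5120, 5120
flops_per_run = 2 * M * N * K          # 26_843_545_600, i.e. ~26.84 GFLOPs per GEMM
avg_time_s = 0.1e-3                    # assume 0.1 ms per run
tflops = flops_per_run / (avg_time_s * 1e12)
print(f"{tflops:.1f} TFLOPS")          # -> 268.4 TFLOPS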
......@@ -93,24 +93,22 @@ if __name__ == "__main__":
((257, 5120), (5120, 5120)),
((32130, 5120), (13824, 5120)),
((32130, 13824), (5120, 13824)),
((75348, 5120), (5120, 5120)),
((75348, 5120), (13824, 5120)),
((75348, 13824), (5120, 13824)),
((32760, 1536), (1536, 1536)),
((512, 1536), (1536, 1536)),
((32760, 1536), (8960, 1536)),
((32760, 8960), (1536, 8960)),
]
print("=== test_mm TFLOPS性能测试 ===\n")
for i, (input_shape, weight_shape) in enumerate(test_cases):
print(f"测试 {i+1}: 输入形状 {input_shape}, 权重形状 {weight_shape}")
print(f"测试 {i + 1}: 输入形状 {input_shape}, 权重形状 {weight_shape}")
print("-" * 60)
tflops = test_tflops(input_shape, weight_shape)
print(f"✓ 成功完成测试,性能: {tflops:.2f} TFLOPS\n")
print("=== 测试完成 ===")
......@@ -4,7 +4,8 @@ from lightx2v_kernel.gemm import scaled_fp4_quant
input_global_scale = torch.tensor(808.0, dtype=torch.float32).cuda()
- def quantize_fp4(x):
+ def quantize_fp4(x):
return scaled_fp4_quant(x, input_global_scale)
......@@ -15,51 +16,50 @@ def test_memory_bandwidth(func, x, num_warmup=10, num_runs=100):
# Warm up the GPU
for _ in range(num_warmup):
func(x)
# Synchronize the GPU
torch.cuda.synchronize()
# Create CUDA events for precise timing
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
# Time the runs
start_event.record()
for _ in range(num_runs):
result = func(x)
end_event.record()
# Synchronize and compute the elapsed time
torch.cuda.synchronize()
elapsed_time_ms = start_event.elapsed_time(end_event)
elapsed_time_s = elapsed_time_ms / 1000.0
# Compute the data volume
input_bytes = x.numel() * x.element_size() # bytes of input data
# After FP4 quantization each element takes 0.5 bytes
output_bytes = x.numel() * 0.5 # bytes of FP4 output data
- scale_bytes = x.numel() / 16 # group_size = 16
+ scale_bytes = x.numel() / 16  # group_size = 16
# Total data moved (read input + write output + scales)
total_bytes = (input_bytes + output_bytes + scale_bytes) * num_runs
# Compute the bandwidth
bandwidth_gbps = (total_bytes / elapsed_time_s) / (1024**3) # GB/s
print(f"测试结果:")
print(f" 输入张量形状: {x.shape}")
print(f" 输入数据类型: {x.dtype}")
print(f" 运行次数: {num_runs}")
print(f" 总执行时间: {elapsed_time_ms:.2f} ms")
print(f" 平均每次执行时间: {elapsed_time_ms/num_runs:.4f} ms")
print(f" 平均每次执行时间: {elapsed_time_ms / num_runs:.4f} ms")
print(f" 输入数据大小: {input_bytes / (1024**2):.2f} MB")
print(f" 输出数据大小: {output_bytes / (1024**2):.2f} MB")
print(f" 输出数据大小: {output_bytes / (1024**2):.2f} MB")
print(f" 总数据传输量: {total_bytes / (1024**3):.2f} GB")
print(f" 显存带宽: {bandwidth_gbps:.2f} GB/s")
return bandwidth_gbps
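The byte accounting above works out as follows for one of the test sizes, with an assumed timing (the 0.5 ms per run is illustrative, not a measurement):

# Worked example of the bandwidth accounting above (the timing is made up for illustration).
h, w = 32130, 5120
numel = h * w                                # 164_505_600 elements
input_bytes = numel * 2                      # bf16 input: 2 bytes per element
output_bytes = numel * 0.5                   # packed FP4 output: 0.5 bytes per element
scale_bytes = numel / 16                     # one FP8 scale per 16-element group
total_bytes = input_bytes + output_bytes + scale_bytes   # 421_545_600.0 bytes per run
avg_time_s = 0.5e-3                          # assume 0.5 ms per run
print(f"{(total_bytes / avg_time_s) / (1024**3):.1f} GB/s")  # -> ~785.2 GB/s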
......@@ -132,33 +132,29 @@ if __name__ == "__main__":
# (32768, 8192),
# (32768, 16384),
# (32768, 32768),
(32130, 5120),
(512, 5120),
(257, 5120),
(32130, 13824),
(75348, 5120),
(75348, 13824),
(32760, 1536),
(512, 1536),
(32760, 8960),
]
print("=== quantize_fp4 显存带宽测试 ===\n")
for i, (h, w) in enumerate(test_sizes):
print(f"测试 {i+1}: 张量大小 ({h}, {w})")
print(f"测试 {i + 1}: 张量大小 ({h}, {w})")
print("-" * 50)
x = torch.randn(h, w, dtype=torch.bfloat16).cuda()
try:
bandwidth = test_memory_bandwidth(quantize_fp4, x)
print(f"✓ 成功完成测试,带宽: {bandwidth:.2f} GB/s\n")
except Exception as e:
print(f"✗ 测试失败: {e}\n")
print("=== 测试完成 ===")