Commit c7bb59cd authored by helloyongyang

fix ci

parent 01caaf29
@@ -14,6 +14,4 @@ from lightx2v_kernel import common_ops
from lightx2v_kernel.gemm import cutlass_scaled_fp4_mm, scaled_fp4_quant
from lightx2v_kernel.version import __version__

-build_tree_kernel = (
-    None
-)
+build_tree_kernel = None
@@ -12,15 +12,11 @@ def cutlass_scaled_fp4_mm(mat_a, mat_b, scales_a, scales_b, alpha, bias=None):
    """
    m, n = mat_a.shape[0], mat_b.shape[0]
    out = torch.empty((m, n), dtype=torch.bfloat16, device=mat_a.device)
-    torch.ops.lightx2v_kernel.cutlass_scaled_fp4_mm_sm120.default(
-        out, mat_a, mat_b, scales_a, scales_b, alpha, bias
-    )
+    torch.ops.lightx2v_kernel.cutlass_scaled_fp4_mm_sm120.default(out, mat_a, mat_b, scales_a, scales_b, alpha, bias)
    return out


-def scaled_fp4_quant(
-    input: torch.Tensor, input_global_scale: torch.Tensor
-):
+def scaled_fp4_quant(input: torch.Tensor, input_global_scale: torch.Tensor):
    """
    Quantize input tensor to FP4 and return quantized tensor and scale.
@@ -60,13 +56,8 @@ def scaled_fp4_quant(
    # rounded_m = ((m + 128 - 1) // 128) * 128
    # scale_n = n // block_size
    # rounded_n = ((scale_n + 4 - 1) // 4) * 4
-    output_scale = torch.empty(
-        (((m + 128 - 1) // 128) * 128, (n // block_size + 4 - 1) // 4), device=device, dtype=torch.int32
-    )
+    output_scale = torch.empty((((m + 128 - 1) // 128) * 128, (n // block_size + 4 - 1) // 4), device=device, dtype=torch.int32)
-    torch.ops.lightx2v_kernel.scaled_fp4_quant_sm120.default(
-        output, input, output_scale, input_global_scale
-    )
+    torch.ops.lightx2v_kernel.scaled_fp4_quant_sm120.default(output, input, output_scale, input_global_scale)
    output_scale = output_scale.view(torch.float8_e4m3fn)
    return output, output_scale
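For orientation, a minimal usage sketch of the two wrappers above. This is illustrative only: the global-scale formula mirrors the test below, and the alpha convention (undoing both global scales) is an assumption rather than something stated in this diff.

import torch
from lightx2v_kernel.gemm import cutlass_scaled_fp4_mm, scaled_fp4_quant

# Global scales follow FLOAT8_E4M3_MAX (448.0) * FLOAT4_E2M1_MAX (6.0) / amax, as in the test below.
m, k, n = 512, 5120, 5120
a = torch.randn(m, k, dtype=torch.bfloat16, device="cuda")
b = torch.randn(n, k, dtype=torch.bfloat16, device="cuda")
a_gs = ((448.0 * 6.0) / torch.amax(a.flatten(), dim=-1)).to(torch.float32)
b_gs = ((448.0 * 6.0) / torch.amax(b.flatten(), dim=-1)).to(torch.float32)

a_fp4, a_scale = scaled_fp4_quant(a, a_gs)  # packed uint8 data + fp8_e4m3 block scales
b_fp4, b_scale = scaled_fp4_quant(b, b_gs)

# Assumption: alpha undoes both global scales in the GEMM epilogue (the usual NVFP4 convention).
alpha = 1.0 / (a_gs * b_gs)
out = cutlass_scaled_fp4_mm(a_fp4, b_fp4, a_scale, b_scale, alpha)  # (m, n), bfloat16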
@@ -6,6 +6,7 @@ BLOCK_SIZE = 16
FLOAT4_E2M1_MAX = 6.0
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max

def cast_to_fp4(x):
    sign = torch.sign(x)
    x = torch.abs(x)
...
@@ -53,9 +53,7 @@ def break_fp4_bytes(a, dtype):
    return out


-def dequantize_to_dtype(
-    tensor_fp4, tensor_sf, global_scale, dtype, device, block_size=16
-):
+def dequantize_to_dtype(tensor_fp4, tensor_sf, global_scale, dtype, device, block_size=16):
    """Dequantize the fp4 tensor back to high precision."""
    # Two fp4 values are packed into one uint8.
    assert tensor_fp4.dtype == torch.uint8
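As background for the packing comment above, a self-contained sketch of decoding packed FP4 (E2M1) bytes with a lookup table. The nibble order and the helper name are assumptions; this is not the repository's break_fp4_bytes.

import torch

# The 16 E2M1 code points: magnitudes 0, 0.5, 1, 1.5, 2, 3, 4, 6; the high bit of each nibble is the sign.
E2M1_VALUES = torch.tensor(
    [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0],
    dtype=torch.float32,
)

def unpack_fp4_reference(packed: torch.Tensor) -> torch.Tensor:
    """Unpack a uint8 tensor of FP4 pairs; the low nibble is assumed to come first."""
    lo = (packed & 0x0F).long()
    hi = ((packed >> 4) & 0x0F).long()
    values = E2M1_VALUES.to(packed.device)[torch.stack((lo, hi), dim=-1)]
    return values.flatten(start_dim=-2)  # last dimension doubles in length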
@@ -88,12 +86,8 @@ def get_ref_results(
    _, m_k = a_fp4.shape
    _, n_k = b_fp4.shape
    assert m_k == n_k
-    a_in_dtype = dequantize_to_dtype(
-        a_fp4, a_sf, a_global_scale, dtype=dtype, device=device, block_size=block_size
-    )
-    b_in_dtype = dequantize_to_dtype(
-        b_fp4, b_sf, b_global_scale, dtype=dtype, device=device, block_size=block_size
-    )
+    a_in_dtype = dequantize_to_dtype(a_fp4, a_sf, a_global_scale, dtype=dtype, device=device, block_size=block_size)
+    b_in_dtype = dequantize_to_dtype(b_fp4, b_sf, b_global_scale, dtype=dtype, device=device, block_size=block_size)
    return torch.matmul(a_in_dtype, b_in_dtype.t())
@@ -109,12 +103,8 @@ def test_nvfp4_gemm(
    b_dtype = torch.randn((n, k), dtype=dtype, device="cuda")
    bias = torch.randn((1, n), dtype=dtype, device="cuda")
-    a_global_scale = (
-        (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(a_dtype.flatten(), dim=-1)
-    ).to(torch.float32)
-    b_global_scale = (
-        (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(b_dtype.flatten(), dim=-1)
-    ).to(torch.float32)
+    a_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(a_dtype.flatten(), dim=-1)).to(torch.float32)
+    b_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(b_dtype.flatten(), dim=-1)).to(torch.float32)
    print(f"a_global_scale : {a_global_scale}, {a_global_scale.shape}")
    print(f"b_global_scale : {b_global_scale}, {b_global_scale.shape}")
@@ -138,13 +128,9 @@ def test_nvfp4_gemm(
    )
    expected_out = expected_out + bias
    print(f"alpha {alpha}, {alpha.shape}, {alpha.dtype}")
-    out = cutlass_scaled_fp4_mm(
-        a_fp4, b_fp4, a_scale_interleaved, b_scale_interleaved, alpha, bias
-    )
+    out = cutlass_scaled_fp4_mm(a_fp4, b_fp4, a_scale_interleaved, b_scale_interleaved, alpha, bias)
    print(f"out : {out}, {out.shape}, {out.dtype}")
    print(f"expected_out : {expected_out}, {expected_out.shape}, {expected_out.dtype}")
...
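For reference, the interleaved scale tensors above (and the benchmark inputs below) pad rows to a multiple of 128 and pack four fp8_e4m3 scales per int32 column. A small worked sketch of that arithmetic, assuming the same rounding as scaled_fp4_quant:

# Worked example of the scale-tensor shape arithmetic used above.
m, n, block_size = 257, 5120, 16

rounded_m = ((m + 128 - 1) // 128) * 128  # 257 -> 384 rows
scale_n = n // block_size                 # 5120 / 16 = 320 block scales per row
rounded_n_words = (scale_n + 4 - 1) // 4  # 320 -> 80 int32 words (4 fp8 scales each)

print(rounded_m, scale_n, rounded_n_words)  # 384 320 80
# Viewed as float8_e4m3fn, the (384, 80) int32 buffer becomes a (384, 320) scale matrix.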
@@ -33,7 +33,6 @@ class MMWeightFp4:
        return scaled_fp4_quant(x, self.input_global_scale)

def test_speed(m, k, n):
    with torch.no_grad():
        input_tensor = torch.randn(m, k, dtype=torch.bfloat16).cuda()
@@ -56,8 +55,6 @@ def test_speed(m, k, n):
        lightx2v_kernel_time = (end_time - start_time) / 100
        print(f"lightx2v-kernel time: {lightx2v_kernel_time}")

        input_tensor = torch.randn(m, n, dtype=torch.bfloat16).cuda()
        weight = torch.randn(k, n, dtype=torch.bfloat16, device="cuda")
        bias = torch.randn(1, k, dtype=torch.bfloat16).cuda()
@@ -107,19 +104,15 @@ def test_accuracy(m, k, n):
    print(f"cos : {cos}")

if __name__ == "__main__":
    test_sizes = [
        (32130, 5120, 5120),
        (512, 5120, 5120),
        (257, 5120, 5120),
        (32130, 5120, 13824),
        (32130, 13824, 5120),
        (75348, 5120, 5120),
        (75348, 13824, 5120),
        (32760, 1536, 1536),
        (512, 1536, 1536),
        (32760, 1536, 8960),
@@ -128,7 +121,6 @@ if __name__ == "__main__":
    for i, (m, k, n) in enumerate(test_sizes):
        print("-" * 30)
-        print(f"Test {i+1}: tensor size ({m}, {k}, {n})")
+        print(f"Test {i + 1}: tensor size ({m}, {k}, {n})")
        test_accuracy(m, k, n)
        test_speed(m, k, n)
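The body of test_accuracy is collapsed here, but it prints a cosine similarity; a minimal sketch of such a check (the reference path and names are assumptions, not the file's exact code):

import torch

def cosine_check(out: torch.Tensor, ref: torch.Tensor) -> torch.Tensor:
    """Cosine similarity between kernel output and reference; values near 1.0 indicate a close match."""
    return torch.nn.functional.cosine_similarity(out.flatten().float(), ref.flatten().float(), dim=0)

# Hypothetical usage: compare the FP4 GEMM output against a bfloat16 matmul reference.
# cos = cosine_check(out, input_tensor @ weight.t() + bias)
# print(f"cos : {cos}")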
@@ -25,8 +25,6 @@ def test_speed(m, k, n):
        lightx2v_kernel_time = (end_time - start_time) / 100
        print(f"lightx2v-kernel time: {lightx2v_kernel_time}")

        input_tensor = torch.randn(m, n, dtype=torch.bfloat16).cuda()
        weight = torch.randn(k, n, dtype=torch.bfloat16, device="cuda")
        bias = torch.randn(1, k, dtype=torch.bfloat16).cuda()
@@ -75,19 +73,15 @@ def test_accuracy(m, k, n):
    print(f"cos : {cos}")

if __name__ == "__main__":
    test_sizes = [
        (32130, 5120, 5120),
        (512, 5120, 5120),
        (257, 5120, 5120),
        (32130, 5120, 13824),
        (32130, 13824, 5120),
        (75348, 5120, 5120),
        (75348, 13824, 5120),
        (32760, 1536, 1536),
        (512, 1536, 1536),
        (32760, 1536, 8960),
@@ -96,7 +90,6 @@ if __name__ == "__main__":
    for i, (m, k, n) in enumerate(test_sizes):
        print("-" * 30)
-        print(f"Test {i+1}: tensor size ({m}, {k}, {n})")
+        print(f"Test {i + 1}: tensor size ({m}, {k}, {n})")
        test_accuracy(m, k, n)
        test_speed(m, k, n)
@@ -14,6 +14,7 @@ alpha = torch.tensor(0.0002765655517578125, device="cuda").to(torch.float32)
bias = None
"""

def test_mm(input_tensor_quant, weight, input_tensor_scale, weight_scale, alpha, bias):
    output_tensor = cutlass_scaled_fp4_mm(input_tensor_quant, weight, input_tensor_scale, weight_scale, alpha=alpha, bias=bias)
    return output_tensor
@@ -28,7 +29,6 @@ def test_tflops(input_shape, weight_shape, num_warmup=10, num_runs=100):
    input_tensor_quant = (torch.rand((input_shape[0], input_shape[1] // 2), device="cuda") * 10).to(torch.uint8)
    weight = (torch.rand((weight_shape[0], weight_shape[1] // 2), device="cuda") * 10).to(torch.uint8)
    input_tensor_scale = torch.rand(((input_shape[0] + 128 - 1) // 128) * 128, (input_shape[1] // 16 + 4 - 1) // 4 * 4, device="cuda").to(torch.float8_e4m3fn)
    weight_scale = torch.rand(weight_shape[0], weight_shape[1] // 16, device="cuda").to(torch.float8_e4m3fn)
    alpha = torch.tensor(0.0002765655517578125, device="cuda", dtype=torch.float32)
@@ -76,9 +76,9 @@ def test_tflops(input_shape, weight_shape, num_warmup=10, num_runs=100):
    print(f" Output shape: ({M}, {N})")
    print(f" Number of runs: {num_runs}")
    print(f" Total execution time: {elapsed_time_ms:.2f} ms")
-    print(f" Average time per run: {elapsed_time_ms/num_runs:.4f} ms")
-    print(f" FLOPs per run: {flops_per_run/1e9:.2f} GFLOPS")
-    print(f" Total FLOPs: {total_flops/1e12:.2f} TFLOPS")
+    print(f" Average time per run: {elapsed_time_ms / num_runs:.4f} ms")
+    print(f" FLOPs per run: {flops_per_run / 1e9:.2f} GFLOPS")
+    print(f" Total FLOPs: {total_flops / 1e12:.2f} TFLOPS")
    print(f" Compute performance: {tflops:.2f} TFLOPS")
    return tflops
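The TFLOPS numbers above presumably follow the standard GEMM accounting of 2 * M * N * K floating-point operations per call; a short worked sketch with assumed timings:

# Worked example of the TFLOPS accounting (2 * M * N * K FLOPs per GEMM call is assumed).
M, N, K = 32130, 5120, 5120
num_runs = 100
elapsed_time_ms = 250.0  # hypothetical total time for all runs

flops_per_run = 2 * M * N * K  # ~1.68e12 FLOPs per call
total_flops = flops_per_run * num_runs
tflops = total_flops / (elapsed_time_ms / 1000.0) / 1e12  # ~674 TFLOPS in this example

print(f"{flops_per_run / 1e9:.2f} GFLOPS per run, {tflops:.2f} TFLOPS sustained")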
@@ -93,11 +93,9 @@ if __name__ == "__main__":
        ((257, 5120), (5120, 5120)),
        ((32130, 5120), (13824, 5120)),
        ((32130, 13824), (5120, 13824)),
        ((75348, 5120), (5120, 5120)),
        ((75348, 5120), (13824, 5120)),
        ((75348, 13824), (5120, 13824)),
        ((32760, 1536), (1536, 1536)),
        ((512, 1536), (1536, 1536)),
        ((32760, 1536), (8960, 1536)),
@@ -107,7 +105,7 @@ if __name__ == "__main__":
    print("=== test_mm TFLOPS performance test ===\n")
    for i, (input_shape, weight_shape) in enumerate(test_cases):
-        print(f"Test {i+1}: input shape {input_shape}, weight shape {weight_shape}")
+        print(f"Test {i + 1}: input shape {input_shape}, weight shape {weight_shape}")
        print("-" * 60)
        tflops = test_tflops(input_shape, weight_shape)
...
@@ -4,6 +4,7 @@ from lightx2v_kernel.gemm import scaled_fp4_quant
input_global_scale = torch.tensor(808.0, dtype=torch.float32).cuda()

def quantize_fp4(x):
    return scaled_fp4_quant(x, input_global_scale)
@@ -40,7 +41,6 @@ def test_memory_bandwidth(func, x, num_warmup=10, num_runs=100):
    # After FP4 quantization each element occupies 0.5 bytes
    output_bytes = x.numel() * 0.5  # bytes of FP4 output data
    scale_bytes = x.numel() / 16  # group_size = 16

    # Total data transferred (read input + write output + scales)
@@ -54,7 +54,7 @@ def test_memory_bandwidth(func, x, num_warmup=10, num_runs=100):
    print(f" Input dtype: {x.dtype}")
    print(f" Number of runs: {num_runs}")
    print(f" Total execution time: {elapsed_time_ms:.2f} ms")
-    print(f" Average time per run: {elapsed_time_ms/num_runs:.4f} ms")
+    print(f" Average time per run: {elapsed_time_ms / num_runs:.4f} ms")
    print(f" Input data size: {input_bytes / (1024**2):.2f} MB")
    print(f" Output data size: {output_bytes / (1024**2):.2f} MB")
    print(f" Total data transferred: {total_bytes / (1024**3):.2f} GB")
@@ -132,16 +132,12 @@ if __name__ == "__main__":
        # (32768, 8192),
        # (32768, 16384),
        # (32768, 32768),
        (32130, 5120),
        (512, 5120),
        (257, 5120),
        (32130, 13824),
        (75348, 5120),
        (75348, 13824),
        (32760, 1536),
        (512, 1536),
        (32760, 8960),
@@ -150,7 +146,7 @@ if __name__ == "__main__":
    print("=== quantize_fp4 memory bandwidth test ===\n")
    for i, (h, w) in enumerate(test_sizes):
-        print(f"Test {i+1}: tensor size ({h}, {w})")
+        print(f"Test {i + 1}: tensor size ({h}, {w})")
        print("-" * 50)
        x = torch.randn(h, w, dtype=torch.bfloat16).cuda()
...