Commit c7bb59cd authored by helloyongyang

fix ci

parent 01caaf29
@@ -61,7 +61,7 @@ struct Fp4GemmSm120 {
   using ClusterShape = Shape<_1,_1,_1>;  // Shape of the threadblocks in a cluster
   using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
     ArchTag, OperatorClass,
     ThreadBlockShape, ClusterShape,
     cutlass::epilogue::collective::EpilogueTileAuto,
     ElementAccumulator, ElementAccumulator,
@@ -14,6 +14,4 @@ from lightx2v_kernel import common_ops
 from lightx2v_kernel.gemm import cutlass_scaled_fp4_mm, scaled_fp4_quant
 from lightx2v_kernel.version import __version__

-build_tree_kernel = (
-    None
-)
+build_tree_kernel = None
@@ -12,15 +12,11 @@ def cutlass_scaled_fp4_mm(mat_a, mat_b, scales_a, scales_b, alpha, bias=None):
     """
     m, n = mat_a.shape[0], mat_b.shape[0]
     out = torch.empty((m, n), dtype=torch.bfloat16, device=mat_a.device)
-    torch.ops.lightx2v_kernel.cutlass_scaled_fp4_mm_sm120.default(
-        out, mat_a, mat_b, scales_a, scales_b, alpha, bias
-    )
+    torch.ops.lightx2v_kernel.cutlass_scaled_fp4_mm_sm120.default(out, mat_a, mat_b, scales_a, scales_b, alpha, bias)
     return out


-def scaled_fp4_quant(
-    input: torch.Tensor, input_global_scale: torch.Tensor
-):
+def scaled_fp4_quant(input: torch.Tensor, input_global_scale: torch.Tensor):
     """
     Quantize input tensor to FP4 and return quantized tensor and scale.
@@ -60,13 +56,8 @@ def scaled_fp4_quant(
     # rounded_m = ((m + 128 - 1) // 128) * 128
     # scale_n = n // block_size
     # rounded_n = ((scale_n + 4 - 1) // 4) * 4
-    output_scale = torch.empty(
-        (((m + 128 - 1) // 128) * 128, (n // block_size + 4 - 1) // 4), device=device, dtype=torch.int32
-    )
+    output_scale = torch.empty((((m + 128 - 1) // 128) * 128, (n // block_size + 4 - 1) // 4), device=device, dtype=torch.int32)

-    torch.ops.lightx2v_kernel.scaled_fp4_quant_sm120.default(
-        output, input, output_scale, input_global_scale
-    )
+    torch.ops.lightx2v_kernel.scaled_fp4_quant_sm120.default(output, input, output_scale, input_global_scale)

     output_scale = output_scale.view(torch.float8_e4m3fn)
     return output, output_scale
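As a concrete illustration of the padded scale-factor shape computed above, a minimal sketch with example sizes (the numbers are chosen for illustration; the padding factors come from the commented formulas in the diff):

m, n, block_size = 257, 5120, 16           # example sizes, not from the commit
rounded_m = ((m + 128 - 1) // 128) * 128   # rows padded to a multiple of 128 -> 384
scale_n = n // block_size                  # one FP8 scale per 16-element block -> 320
packed_n = (scale_n + 4 - 1) // 4          # four FP8 scales packed per int32 column -> 80
# output_scale is allocated as int32 of shape (384, 80) and later re-viewed as float8_e4m3fn.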
@@ -6,6 +6,7 @@ BLOCK_SIZE = 16
 FLOAT4_E2M1_MAX = 6.0
 FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max


 def cast_to_fp4(x):
     sign = torch.sign(x)
     x = torch.abs(x)
@@ -53,9 +53,7 @@ def break_fp4_bytes(a, dtype):
     return out


-def dequantize_to_dtype(
-    tensor_fp4, tensor_sf, global_scale, dtype, device, block_size=16
-):
+def dequantize_to_dtype(tensor_fp4, tensor_sf, global_scale, dtype, device, block_size=16):
     """Dequantize the fp4 tensor back to high precision."""
     # Two fp4 values are packed into one uint8.
     assert tensor_fp4.dtype == torch.uint8
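To make the packing comment concrete, a minimal decode sketch follows; the nibble order and the helper name are illustrative assumptions, and this is not the repository's break_fp4_bytes implementation:

E2M1_MAGNITUDES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]  # FLOAT4_E2M1_MAX = 6.0

def unpack_fp4_byte(byte):
    """Decode one packed uint8 into two signed E2M1 (FP4) values."""
    def decode(nibble):
        sign = -1.0 if nibble & 0x8 else 1.0
        return sign * E2M1_MAGNITUDES[nibble & 0x7]
    return decode(byte & 0x0F), decode((byte >> 4) & 0x0F)  # (low nibble, high nibble) -- assumed order

print(unpack_fp4_byte(0xB2))  # (1.0, -1.5): low nibble 0x2 -> +1.0, high nibble 0xB -> -1.5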
@@ -88,12 +86,8 @@ def get_ref_results(
     _, m_k = a_fp4.shape
     _, n_k = b_fp4.shape
     assert m_k == n_k
-    a_in_dtype = dequantize_to_dtype(
-        a_fp4, a_sf, a_global_scale, dtype=dtype, device=device, block_size=block_size
-    )
-    b_in_dtype = dequantize_to_dtype(
-        b_fp4, b_sf, b_global_scale, dtype=dtype, device=device, block_size=block_size
-    )
+    a_in_dtype = dequantize_to_dtype(a_fp4, a_sf, a_global_scale, dtype=dtype, device=device, block_size=block_size)
+    b_in_dtype = dequantize_to_dtype(b_fp4, b_sf, b_global_scale, dtype=dtype, device=device, block_size=block_size)
     return torch.matmul(a_in_dtype, b_in_dtype.t())
@@ -109,16 +103,12 @@ def test_nvfp4_gemm(
     b_dtype = torch.randn((n, k), dtype=dtype, device="cuda")
     bias = torch.randn((1, n), dtype=dtype, device="cuda")

-    a_global_scale = (
-        (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(a_dtype.flatten(), dim=-1)
-    ).to(torch.float32)
-    b_global_scale = (
-        (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(b_dtype.flatten(), dim=-1)
-    ).to(torch.float32)
+    a_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(a_dtype.flatten(), dim=-1)).to(torch.float32)
+    b_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(b_dtype.flatten(), dim=-1)).to(torch.float32)
     print(f"a_global_scale : {a_global_scale}, {a_global_scale.shape}")
     print(f"b_global_scale : {b_global_scale}, {b_global_scale.shape}")

     alpha = 1.0 / (a_global_scale * b_global_scale)
     a_fp4, a_scale_interleaved = scaled_fp4_quant(a_dtype, a_global_scale)
     b_fp4, b_scale_interleaved = scaled_fp4_quant(b_dtype, b_global_scale)
@@ -137,15 +127,11 @@ def test_nvfp4_gemm(
         "cuda",
     )
     expected_out = expected_out + bias

     print(f"alpha {alpha}, {alpha.shape}, {alpha.dtype}")
-    out = cutlass_scaled_fp4_mm(
-        a_fp4, b_fp4, a_scale_interleaved, b_scale_interleaved, alpha, bias
-    )
+    out = cutlass_scaled_fp4_mm(a_fp4, b_fp4, a_scale_interleaved, b_scale_interleaved, alpha, bias)
     print(f"out : {out}, {out.shape}, {out.dtype}")
     print(f"expected_out : {expected_out}, {expected_out.shape}, {expected_out.dtype}")
@@ -7,7 +7,7 @@ class MMWeightFp4:
     def __init__(self, weight, bias):
         self.load_fp4_weight(weight, bias)
         self.act_quant_func = self.act_quant_fp4

         # calibrate x_max
         self.calibrate_x_absmax()
@@ -24,7 +24,7 @@ class MMWeightFp4:
         self.bias = bias

     def calibrate_x_absmax(self):
         self.x_absmax = torch.tensor(5.0, dtype=torch.float32, device=self.weight.device)  # need to be calibrated
         self.input_global_scale = (2688.0 / self.x_absmax).to(torch.float32)
         self.alpha = 1.0 / (self.input_global_scale * self.weight_global_scale)
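The hard-coded 2688.0 is the same FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX product used in the quantization tests above (448 * 6 = 2688), applied here with a placeholder activation absmax of 5.0 until real calibration is wired in. A quick sanity check of that constant (not part of the benchmark file):

import torch

fp8_e4m3_max = torch.finfo(torch.float8_e4m3fn).max  # 448.0
fp4_e2m1_max = 6.0
assert fp8_e4m3_max * fp4_e2m1_max == 2688.0          # matches the constant above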
@@ -33,7 +33,6 @@ class MMWeightFp4:
         return scaled_fp4_quant(x, self.input_global_scale)


 def test_speed(m, k, n):
     with torch.no_grad():
         input_tensor = torch.randn(m, k, dtype=torch.bfloat16).cuda()
@@ -42,26 +41,24 @@ def test_speed(m, k, n):
         bias = None

         mm = MMWeightFp4(weight, bias)

         # warmup
         output_tensor = mm.apply(input_tensor)

         torch.cuda.synchronize()
         start_time = time.time()
         for i in range(100):
             output_tensor = mm.apply(input_tensor)
         torch.cuda.synchronize()
         end_time = time.time()

         lightx2v_kernel_time = (end_time - start_time) / 100
         print(f"lightx2v-kernel time: {lightx2v_kernel_time}")

         input_tensor = torch.randn(m, n, dtype=torch.bfloat16).cuda()
         weight = torch.randn(k, n, dtype=torch.bfloat16, device="cuda")
         bias = torch.randn(1, k, dtype=torch.bfloat16).cuda()

         linear = torch.nn.Linear(k, n, bias=False).cuda()
         linear.weight.data = weight
         # linear.bias.data = bias
@@ -72,13 +69,13 @@ def test_speed(m, k, n):
         torch.cuda.synchronize()
         start_time = time.time()
         for i in range(100):
             ref_output_tensor = linear(input_tensor)
         torch.cuda.synchronize()
         end_time = time.time()

         ref_time = (end_time - start_time) / 100
         print(f"ref time: {ref_time}")

         print(f"speedup: {ref_time / lightx2v_kernel_time:.3f}")
@@ -88,47 +85,42 @@ def test_accuracy(m, k, n):
     weight = torch.randn(n, k, dtype=torch.bfloat16, device="cuda")
     # bias = torch.randn(1, n, dtype=torch.bfloat16).cuda()
     bias = None

     linear = torch.nn.Linear(k, n, bias=False).cuda()
     linear.weight.data = weight
     # linear.bias.data = bias
     ref_output_tensor = linear(input_tensor)

     mm = MMWeightFp4(weight, bias)
     output_tensor = mm.apply(input_tensor)

     # print(f"ref_output_tensor: {ref_output_tensor}")
     # print(f"output_tensor: {output_tensor}")

     # cosine
     cos = torch.nn.functional.cosine_similarity(ref_output_tensor.flatten(), output_tensor.flatten(), dim=0)
     print(f"cos : {cos}")


 if __name__ == "__main__":
     test_sizes = [
         (32130, 5120, 5120),
         (512, 5120, 5120),
         (257, 5120, 5120),
         (32130, 5120, 13824),
         (32130, 13824, 5120),
         (75348, 5120, 5120),
         (75348, 13824, 5120),
         (32760, 1536, 1536),
         (512, 1536, 1536),
         (32760, 1536, 8960),
         (32760, 8960, 1536),
     ]

     for i, (m, k, n) in enumerate(test_sizes):
         print("-" * 30)
-        print(f"Test {i+1}: tensor size ({m}, {k}, {n})")
+        print(f"Test {i + 1}: tensor size ({m}, {k}, {n})")
         test_accuracy(m, k, n)
         test_speed(m, k, n)
@@ -11,26 +11,24 @@ def test_speed(m, k, n):
     bias = torch.randn(1, n, dtype=torch.bfloat16).cuda()

     mm = MMWeightFp4(weight, bias)

     # warmup
     output_tensor = mm.apply(input_tensor)

     torch.cuda.synchronize()
     start_time = time.time()
     for i in range(100):
         output_tensor = mm.apply(input_tensor)
     torch.cuda.synchronize()
     end_time = time.time()

     lightx2v_kernel_time = (end_time - start_time) / 100
     print(f"lightx2v-kernel time: {lightx2v_kernel_time}")

     input_tensor = torch.randn(m, n, dtype=torch.bfloat16).cuda()
     weight = torch.randn(k, n, dtype=torch.bfloat16, device="cuda")
     bias = torch.randn(1, k, dtype=torch.bfloat16).cuda()

     linear = torch.nn.Linear(k, n, bias=True).cuda()
     linear.weight.data = weight
     linear.bias.data = bias
@@ -56,47 +54,42 @@ def test_accuracy(m, k, n):
     input_tensor = torch.randn(m, k, dtype=torch.bfloat16).cuda()
     weight = torch.randn(n, k, dtype=torch.bfloat16, device="cuda")
     bias = torch.randn(1, n, dtype=torch.bfloat16).cuda()

     linear = torch.nn.Linear(k, n, bias=True).cuda()
     linear.weight.data = weight
     linear.bias.data = bias
     ref_output_tensor = linear(input_tensor)

     mm = MMWeightFp4(weight, bias)
     output_tensor = mm.apply(input_tensor)

     # print(f"ref_output_tensor: {ref_output_tensor}")
     # print(f"output_tensor: {output_tensor}")

     # cosine
     cos = torch.nn.functional.cosine_similarity(ref_output_tensor.flatten(), output_tensor.flatten(), dim=0)
     print(f"cos : {cos}")


 if __name__ == "__main__":
     test_sizes = [
         (32130, 5120, 5120),
         (512, 5120, 5120),
         (257, 5120, 5120),
         (32130, 5120, 13824),
         (32130, 13824, 5120),
         (75348, 5120, 5120),
         (75348, 13824, 5120),
         (32760, 1536, 1536),
         (512, 1536, 1536),
         (32760, 1536, 8960),
         (32760, 8960, 1536),
     ]

     for i, (m, k, n) in enumerate(test_sizes):
         print("-" * 30)
-        print(f"Test {i+1}: tensor size ({m}, {k}, {n})")
+        print(f"Test {i + 1}: tensor size ({m}, {k}, {n})")
         test_accuracy(m, k, n)
         test_speed(m, k, n)
@@ -14,6 +14,7 @@ alpha = torch.tensor(0.0002765655517578125, device="cuda").to(torch.float32)
 bias = None
 """


 def test_mm(input_tensor_quant, weight, input_tensor_scale, weight_scale, alpha, bias):
     output_tensor = cutlass_scaled_fp4_mm(input_tensor_quant, weight, input_tensor_scale, weight_scale, alpha=alpha, bias=bias)
     return output_tensor
@@ -23,64 +24,63 @@ def test_tflops(input_shape, weight_shape, num_warmup=10, num_runs=100):
     """
     Measure the TFLOPS performance of the test_mm function.
     """
     # Create the input data
     input_tensor_quant = (torch.rand((input_shape[0], input_shape[1] // 2), device="cuda") * 10).to(torch.uint8)
     weight = (torch.rand((weight_shape[0], weight_shape[1] // 2), device="cuda") * 10).to(torch.uint8)
     input_tensor_scale = torch.rand(((input_shape[0] + 128 - 1) // 128) * 128, (input_shape[1] // 16 + 4 - 1) // 4 * 4, device="cuda").to(torch.float8_e4m3fn)
     weight_scale = torch.rand(weight_shape[0], weight_shape[1] // 16, device="cuda").to(torch.float8_e4m3fn)
     alpha = torch.tensor(0.0002765655517578125, device="cuda", dtype=torch.float32)
     bias = None

     # Warm up the GPU
     for _ in range(num_warmup):
         test_mm(input_tensor_quant, weight, input_tensor_scale, weight_scale, alpha, bias)

     # Synchronize the GPU
     torch.cuda.synchronize()

     # Create GPU events for precise timing
     start_event = torch.cuda.Event(enable_timing=True)
     end_event = torch.cuda.Event(enable_timing=True)

     # Measure the elapsed time
     start_event.record()
     for _ in range(num_runs):
         result = test_mm(input_tensor_quant, weight, input_tensor_scale, weight_scale, alpha, bias)
     end_event.record()

     # Synchronize and compute the elapsed time
     torch.cuda.synchronize()
     elapsed_time_ms = start_event.elapsed_time(end_event)
     elapsed_time_s = elapsed_time_ms / 1000.0

     # Compute the FLOPs
     # Matrix multiplication A(M x K) @ B(K x N) = C(M x N)
     # M = batch_size, K = input_dim, N = output_dim
     M = input_shape[0]
     K = input_shape[1]
     N = weight_shape[0]

     # FLOPs per matmul = 2 * M * N * K (each output element needs K multiplications and K additions)
     flops_per_run = 2 * M * N * K
     total_flops = flops_per_run * num_runs

     # Compute TFLOPS (trillions of floating-point operations per second)
     tflops = total_flops / (elapsed_time_s * 1e12)

     print(f"Results:")
     print(f"  Input shape: {input_shape} (M={M}, K={K})")
     print(f"  Weight shape: {weight_shape} (N={N}, K={K})")
     print(f"  Output shape: ({M}, {N})")
     print(f"  Number of runs: {num_runs}")
     print(f"  Total execution time: {elapsed_time_ms:.2f} ms")
-    print(f"  Average time per run: {elapsed_time_ms/num_runs:.4f} ms")
-    print(f"  FLOPs per run: {flops_per_run/1e9:.2f} GFLOPs")
-    print(f"  Total FLOPs: {total_flops/1e12:.2f} TFLOPs")
+    print(f"  Average time per run: {elapsed_time_ms / num_runs:.4f} ms")
+    print(f"  FLOPs per run: {flops_per_run / 1e9:.2f} GFLOPs")
+    print(f"  Total FLOPs: {total_flops / 1e12:.2f} TFLOPs")
     print(f"  Compute throughput: {tflops:.2f} TFLOPS")

     return tflops
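As a worked instance of the 2 * M * N * K count above, using one of the listed test cases (the 1 ms timing is made up purely for illustration):

M, K, N = 512, 5120, 5120                      # example GEMM sizes from the test list
flops_per_run = 2 * M * N * K                  # 26,843,545,600 ops ≈ 26.84 GFLOPs per run
elapsed_s = 0.001                              # assume 1 ms per run (illustrative only)
tflops = flops_per_run / (elapsed_s * 1e12)    # ≈ 26.8 TFLOPS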
@@ -93,24 +93,22 @@ if __name__ == "__main__":
         ((257, 5120), (5120, 5120)),
         ((32130, 5120), (13824, 5120)),
         ((32130, 13824), (5120, 13824)),
         ((75348, 5120), (5120, 5120)),
         ((75348, 5120), (13824, 5120)),
         ((75348, 13824), (5120, 13824)),
         ((32760, 1536), (1536, 1536)),
         ((512, 1536), (1536, 1536)),
         ((32760, 1536), (8960, 1536)),
         ((32760, 8960), (1536, 8960)),
     ]

     print("=== test_mm TFLOPS benchmark ===\n")

     for i, (input_shape, weight_shape) in enumerate(test_cases):
-        print(f"Test {i+1}: input shape {input_shape}, weight shape {weight_shape}")
+        print(f"Test {i + 1}: input shape {input_shape}, weight shape {weight_shape}")
         print("-" * 60)
         tflops = test_tflops(input_shape, weight_shape)
         print(f"✓ Test completed, performance: {tflops:.2f} TFLOPS\n")

     print("=== Done ===")
@@ -4,7 +4,8 @@ from lightx2v_kernel.gemm import scaled_fp4_quant
 input_global_scale = torch.tensor(808.0, dtype=torch.float32).cuda()


 def quantize_fp4(x):
     return scaled_fp4_quant(x, input_global_scale)
@@ -15,51 +16,50 @@ def test_memory_bandwidth(func, x, num_warmup=10, num_runs=100):
     # Warm up the GPU
     for _ in range(num_warmup):
         func(x)

     # Synchronize the GPU
     torch.cuda.synchronize()

     # Create GPU events for precise timing
     start_event = torch.cuda.Event(enable_timing=True)
     end_event = torch.cuda.Event(enable_timing=True)

     # Measure the elapsed time
     start_event.record()
     for _ in range(num_runs):
         result = func(x)
     end_event.record()

     # Synchronize and compute the elapsed time
     torch.cuda.synchronize()
     elapsed_time_ms = start_event.elapsed_time(end_event)
     elapsed_time_s = elapsed_time_ms / 1000.0

     # Compute the amount of data moved
     input_bytes = x.numel() * x.element_size()  # bytes read from the input
     # After FP4 quantization each element takes 0.5 byte
     output_bytes = x.numel() * 0.5  # bytes written for the FP4 output
     scale_bytes = x.numel() / 16  # group_size = 16

     # Total data transferred (read input + write output + scales)
     total_bytes = (input_bytes + output_bytes + scale_bytes) * num_runs

     # Compute the bandwidth
     bandwidth_gbps = (total_bytes / elapsed_time_s) / (1024**3)  # GB/s

     print(f"Results:")
     print(f"  Input tensor shape: {x.shape}")
     print(f"  Input dtype: {x.dtype}")
     print(f"  Number of runs: {num_runs}")
     print(f"  Total execution time: {elapsed_time_ms:.2f} ms")
-    print(f"  Average time per run: {elapsed_time_ms/num_runs:.4f} ms")
+    print(f"  Average time per run: {elapsed_time_ms / num_runs:.4f} ms")
     print(f"  Input data size: {input_bytes / (1024**2):.2f} MB")
     print(f"  Output data size: {output_bytes / (1024**2):.2f} MB")
     print(f"  Total data transferred: {total_bytes / (1024**3):.2f} GB")
     print(f"  Memory bandwidth: {bandwidth_gbps:.2f} GB/s")

     return bandwidth_gbps
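For a concrete sense of the byte accounting above, a small sketch using one of the listed tensor sizes (per-run traffic only; no timing assumed):

h, w = 32130, 5120                      # one of the listed test sizes
numel = h * w                           # 164,505,600 elements
input_bytes = numel * 2                 # bfloat16 input, 2 bytes per element   ≈ 313.8 MB
output_bytes = numel * 0.5              # packed FP4 output, 0.5 byte per element ≈ 78.4 MB
scale_bytes = numel / 16                # one FP8 scale per 16-element group    ≈ 9.8 MB
total_gb = (input_bytes + output_bytes + scale_bytes) / (1024**3)  # ≈ 0.39 GB moved per run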
@@ -132,33 +132,29 @@ if __name__ == "__main__":
         # (32768, 8192),
         # (32768, 16384),
         # (32768, 32768),
         (32130, 5120),
         (512, 5120),
         (257, 5120),
         (32130, 13824),
         (75348, 5120),
         (75348, 13824),
         (32760, 1536),
         (512, 1536),
         (32760, 8960),
     ]

     print("=== quantize_fp4 memory bandwidth benchmark ===\n")

     for i, (h, w) in enumerate(test_sizes):
-        print(f"Test {i+1}: tensor size ({h}, {w})")
+        print(f"Test {i + 1}: tensor size ({h}, {w})")
         print("-" * 50)
         x = torch.randn(h, w, dtype=torch.bfloat16).cuda()
         try:
             bandwidth = test_memory_bandwidth(quantize_fp4, x)
             print(f"✓ Test completed, bandwidth: {bandwidth:.2f} GB/s\n")
         except Exception as e:
             print(f"✗ Test failed: {e}\n")

     print("=== Done ===")