Commit 01caaf29 in xuwx1/LightX2V, authored Jun 23, 2025 by helloyongyang

Add lightx2v_kernel for nvfp4

Parent: ea618db2

Changes: 23 in the full commit; this page shows 3 changed files with 382 additions and 0 deletions (+382, -0).
lightx2v_kernel/test/test_bench3_bias.py        +102  -0
lightx2v_kernel/test/test_mm_tflops.py          +116  -0
lightx2v_kernel/test/test_quant_mem_utils.py    +164  -0
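The tests exercise two kernel entry points from lightx2v_kernel.gemm: scaled_fp4_quant and cutlass_scaled_fp4_mm. As a quick orientation, below is a minimal, self-contained call sketch for the GEMM path, assembled from the example shapes quoted in the test_mm_tflops.py docstring further down; the packed-uint8 layout (two fp4 values per byte), the fp8 per-16-element scales, and the expected (1024, 4096) output shape are assumptions read off those shapes rather than documented behaviour.

import torch
from lightx2v_kernel.gemm import cutlass_scaled_fp4_mm

# A is logically (1024, 2048) and B is (4096, 2048); fp4 payloads are packed
# two values per uint8 byte, so the stored K dimension is halved to 1024.
a_q = (torch.rand((1024, 1024), device="cuda") * 10).to(torch.uint8)
b_q = (torch.rand((4096, 1024), device="cuda") * 10).to(torch.uint8)

# One float8_e4m3fn scale per 16-element group along K (2048 / 16 = 128).
a_scale = torch.rand(1024, 128, device="cuda").to(torch.float8_e4m3fn)
b_scale = torch.rand(4096, 128, device="cuda").to(torch.float8_e4m3fn)

alpha = torch.tensor(0.0002765655517578125, device="cuda", dtype=torch.float32)

out = cutlass_scaled_fp4_mm(a_q, b_q, a_scale, b_scale, alpha=alpha, bias=None)
print(out.shape)  # expected: (1024, 4096)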
lightx2v_kernel/test/test_bench3_bias.py (new file, mode 0 → 100644)
import torch
from lightx2v_kernel.gemm import scaled_fp4_quant, cutlass_scaled_fp4_mm
import time

from test_bench2 import MMWeightFp4


def test_speed(m, k, n):
    with torch.no_grad():
        input_tensor = torch.randn(m, k, dtype=torch.bfloat16).cuda()
        weight = torch.randn(n, k, dtype=torch.bfloat16, device="cuda")
        bias = torch.randn(1, n, dtype=torch.bfloat16).cuda()

        mm = MMWeightFp4(weight, bias)

        # warmup
        output_tensor = mm.apply(input_tensor)

        torch.cuda.synchronize()
        start_time = time.time()
        for i in range(100):
            output_tensor = mm.apply(input_tensor)
        torch.cuda.synchronize()
        end_time = time.time()

        lightx2v_kernel_time = (end_time - start_time) / 100
        print(f"lightx2v-kernel time: {lightx2v_kernel_time}")

        input_tensor = torch.randn(m, n, dtype=torch.bfloat16).cuda()
        weight = torch.randn(k, n, dtype=torch.bfloat16, device="cuda")
        bias = torch.randn(1, k, dtype=torch.bfloat16).cuda()

        linear = torch.nn.Linear(k, n, bias=True).cuda()
        linear.weight.data = weight
        linear.bias.data = bias

        # warmup
        ref_output_tensor = linear(input_tensor)

        torch.cuda.synchronize()
        start_time = time.time()
        for i in range(100):
            ref_output_tensor = linear(input_tensor)
        torch.cuda.synchronize()
        end_time = time.time()

        ref_time = (end_time - start_time) / 100
        print(f"ref time: {ref_time}")

        print(f"speedup: {ref_time / lightx2v_kernel_time:.3f}")


def test_accuracy(m, k, n):
    with torch.no_grad():
        input_tensor = torch.randn(m, k, dtype=torch.bfloat16).cuda()
        weight = torch.randn(n, k, dtype=torch.bfloat16, device="cuda")
        bias = torch.randn(1, n, dtype=torch.bfloat16).cuda()

        linear = torch.nn.Linear(k, n, bias=True).cuda()
        linear.weight.data = weight
        linear.bias.data = bias
        ref_output_tensor = linear(input_tensor)

        mm = MMWeightFp4(weight, bias)
        output_tensor = mm.apply(input_tensor)

        # print(f"ref_output_tensor: {ref_output_tensor}")
        # print(f"output_tensor: {output_tensor}")

        # cosine similarity against the bf16 reference
        cos = torch.nn.functional.cosine_similarity(ref_output_tensor.flatten(), output_tensor.flatten(), dim=0)
        print(f"cos : {cos}")


if __name__ == "__main__":
    test_sizes = [
        (32130, 5120, 5120),
        (512, 5120, 5120),
        (257, 5120, 5120),
        (32130, 5120, 13824),
        (32130, 13824, 5120),
        (75348, 5120, 5120),
        (75348, 13824, 5120),
        (32760, 1536, 1536),
        (512, 1536, 1536),
        (32760, 1536, 8960),
        (32760, 8960, 1536),
    ]

    for i, (m, k, n) in enumerate(test_sizes):
        print("-" * 30)
        print(f"Test {i + 1}: tensor size ({m}, {k}, {n})")
        test_accuracy(m, k, n)
        test_speed(m, k, n)
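MMWeightFp4 comes from test_bench2, which is not included in this page, so the wrapper's internals are not visible here. A plausible minimal sketch is shown below for orientation only: the tuple return of scaled_fp4_quant, the (448 * 6) / amax global-scale convention, and the alpha rescaling are all assumptions, not behaviour confirmed by this commit.

import torch
from lightx2v_kernel.gemm import scaled_fp4_quant, cutlass_scaled_fp4_mm


class MMWeightFp4Sketch:
    """Hypothetical stand-in for test_bench2.MMWeightFp4 (not part of this diff)."""

    def __init__(self, weight, bias):
        # Assumed: quantize the bf16 weight once at construction time.
        self.weight_global_scale = (448.0 * 6.0) / weight.abs().max().to(torch.float32)
        self.weight_q, self.weight_scale = scaled_fp4_quant(weight, self.weight_global_scale)
        self.bias = bias

    def apply(self, x):
        # Assumed: quantize the activation per call, then run the fp4 GEMM.
        x_global_scale = (448.0 * 6.0) / x.abs().max().to(torch.float32)
        x_q, x_scale = scaled_fp4_quant(x, x_global_scale)
        alpha = 1.0 / (x_global_scale * self.weight_global_scale)
        return cutlass_scaled_fp4_mm(x_q, self.weight_q, x_scale, self.weight_scale, alpha=alpha, bias=self.bias)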
lightx2v_kernel/test/test_mm_tflops.py (new file, mode 0 → 100644)
import torch
from lightx2v_kernel.gemm import cutlass_scaled_fp4_mm

"""
input_shape = (1024, 2048)
weight_shape = (4096, 2048)

input_tensor_quant = (torch.rand((1024, 1024), device="cuda") * 10).to(torch.uint8)
weight = (torch.rand((4096, 1024), device="cuda") * 10).to(torch.uint8)
input_tensor_scale = torch.rand(1024, 128, device="cuda").to(torch.float8_e4m3fn)
weight_scale = torch.rand(4096, 128, device="cuda").to(torch.float8_e4m3fn)
alpha = torch.tensor(0.0002765655517578125, device="cuda").to(torch.float32)
bias = None
"""


def test_mm(input_tensor_quant, weight, input_tensor_scale, weight_scale, alpha, bias):
    output_tensor = cutlass_scaled_fp4_mm(input_tensor_quant, weight, input_tensor_scale, weight_scale, alpha=alpha, bias=bias)
    return output_tensor


def test_tflops(input_shape, weight_shape, num_warmup=10, num_runs=100):
    """
    Measure the TFLOPS performance of test_mm.
    """
    # Build the input data (fp4 payloads are packed two values per uint8 byte)
    input_tensor_quant = (torch.rand((input_shape[0], input_shape[1] // 2), device="cuda") * 10).to(torch.uint8)
    weight = (torch.rand((weight_shape[0], weight_shape[1] // 2), device="cuda") * 10).to(torch.uint8)
    input_tensor_scale = torch.rand(((input_shape[0] + 128 - 1) // 128) * 128, (input_shape[1] // 16 + 4 - 1) // 4 * 4, device="cuda").to(torch.float8_e4m3fn)
    weight_scale = torch.rand(weight_shape[0], weight_shape[1] // 16, device="cuda").to(torch.float8_e4m3fn)
    alpha = torch.tensor(0.0002765655517578125, device="cuda", dtype=torch.float32)
    bias = None

    # Warm up the GPU
    for _ in range(num_warmup):
        test_mm(input_tensor_quant, weight, input_tensor_scale, weight_scale, alpha, bias)

    # Synchronize the GPU
    torch.cuda.synchronize()

    # CUDA events for precise timing
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    # Time the runs
    start_event.record()
    for _ in range(num_runs):
        result = test_mm(input_tensor_quant, weight, input_tensor_scale, weight_scale, alpha, bias)
    end_event.record()

    # Synchronize and compute the elapsed time
    torch.cuda.synchronize()
    elapsed_time_ms = start_event.elapsed_time(end_event)
    elapsed_time_s = elapsed_time_ms / 1000.0

    # Count FLOPs
    # Matrix multiply A(M x K) @ B(K x N) = C(M x N)
    # M = batch_size, K = input_dim, N = output_dim
    M = input_shape[0]
    K = input_shape[1]
    N = weight_shape[0]

    # FLOPs per matmul = 2 * M * N * K (each output element needs K multiplies and K adds)
    flops_per_run = 2 * M * N * K
    total_flops = flops_per_run * num_runs

    # TFLOPS (trillions of floating-point operations per second)
    tflops = total_flops / (elapsed_time_s * 1e12)

    print("Results:")
    print(f"  input shape: {input_shape} (M={M}, K={K})")
    print(f"  weight shape: {weight_shape} (N={N}, K={K})")
    print(f"  output shape: ({M}, {N})")
    print(f"  runs: {num_runs}")
    print(f"  total elapsed time: {elapsed_time_ms:.2f} ms")
    print(f"  average time per run: {elapsed_time_ms / num_runs:.4f} ms")
    print(f"  FLOPs per run: {flops_per_run / 1e9:.2f} GFLOP")
    print(f"  total FLOPs: {total_flops / 1e12:.2f} TFLOP")
    print(f"  compute throughput: {tflops:.2f} TFLOPS")

    return tflops


if __name__ == "__main__":
    # Matrix multiplies of different sizes
    # (m, k) (n, k)
    test_cases = [
        ((32130, 5120), (5120, 5120)),
        ((512, 5120), (5120, 5120)),
        ((257, 5120), (5120, 5120)),
        ((32130, 5120), (13824, 5120)),
        ((32130, 13824), (5120, 13824)),
        ((75348, 5120), (5120, 5120)),
        ((75348, 5120), (13824, 5120)),
        ((75348, 13824), (5120, 13824)),
        ((32760, 1536), (1536, 1536)),
        ((512, 1536), (1536, 1536)),
        ((32760, 1536), (8960, 1536)),
        ((32760, 8960), (1536, 8960)),
    ]

    print("=== test_mm TFLOPS benchmark ===\n")

    for i, (input_shape, weight_shape) in enumerate(test_cases):
        print(f"Test {i + 1}: input shape {input_shape}, weight shape {weight_shape}")
        print("-" * 60)
        tflops = test_tflops(input_shape, weight_shape)
        print(f"✓ Test completed, throughput: {tflops:.2f} TFLOPS\n")

    print("=== Tests finished ===")
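As a sanity check on the 2 * M * N * K accounting in test_tflops, the arithmetic for the fourth case, input (32130, 5120) against weight (13824, 5120), works out as below; the per-run time used here is a placeholder, not a measurement from this commit.

M, K, N = 32130, 5120, 13824

flops_per_run = 2 * M * N * K        # 4,548,250,828,800 FLOPs, i.e. about 4.55 TFLOP per GEMM
example_time_s = 5e-3                # placeholder per-run time, not a measured value
tflops = flops_per_run / (example_time_s * 1e12)
print(f"{flops_per_run / 1e12:.2f} TFLOP per run -> {tflops:.1f} TFLOPS at {example_time_s * 1e3:.1f} ms per run")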
lightx2v_kernel/test/test_quant_mem_utils.py (new file, mode 0 → 100644)
import torch
from lightx2v_kernel.gemm import scaled_fp4_quant

input_global_scale = torch.tensor(808.0, dtype=torch.float32).cuda()


def quantize_fp4(x):
    return scaled_fp4_quant(x, input_global_scale)


def test_memory_bandwidth(func, x, num_warmup=10, num_runs=100):
    """
    Measure the GPU memory bandwidth achieved by func.
    """
    # Warm up the GPU
    for _ in range(num_warmup):
        func(x)

    # Synchronize the GPU
    torch.cuda.synchronize()

    # CUDA events for precise timing
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    # Time the runs
    start_event.record()
    for _ in range(num_runs):
        result = func(x)
    end_event.record()

    # Synchronize and compute the elapsed time
    torch.cuda.synchronize()
    elapsed_time_ms = start_event.elapsed_time(end_event)
    elapsed_time_s = elapsed_time_ms / 1000.0

    # Data volume
    input_bytes = x.numel() * x.element_size()  # bytes read for the input
    # After FP4 quantization each element takes 0.5 byte
    output_bytes = x.numel() * 0.5  # bytes written for the FP4 output
    scale_bytes = x.numel() / 16  # one scale byte per group (group_size = 16)

    # Total bytes moved (read input + write output + write scales)
    total_bytes = (input_bytes + output_bytes + scale_bytes) * num_runs

    # Bandwidth
    bandwidth_gbps = (total_bytes / elapsed_time_s) / (1024**3)  # GB/s

    print("Results:")
    print(f"  input tensor shape: {x.shape}")
    print(f"  input dtype: {x.dtype}")
    print(f"  runs: {num_runs}")
    print(f"  total elapsed time: {elapsed_time_ms:.2f} ms")
    print(f"  average time per run: {elapsed_time_ms / num_runs:.4f} ms")
    print(f"  input size: {input_bytes / (1024**2):.2f} MB")
    print(f"  output size: {output_bytes / (1024**2):.2f} MB")
    print(f"  total data moved: {total_bytes / (1024**3):.2f} GB")
    print(f"  memory bandwidth: {bandwidth_gbps:.2f} GB/s")

    return bandwidth_gbps


if __name__ == "__main__":
    # Tensor sizes to test
    test_sizes = [
        # (1, 1024),
        # (1, 2048),
        # (1, 4096),
        # (1, 8192),
        # (1, 16384),
        # (1, 32768),
        # (2, 1024),
        # (2, 2048),
        # (2, 4096),
        # (2, 8192),
        # (2, 16384),
        # (2, 32768),
        # (4, 1024),
        # (4, 2048),
        # (4, 4096),
        # (4, 8192),
        # (4, 16384),
        # (4, 32768),
        # (128, 1024),
        # (128, 2048),
        # (128, 4096),
        # (128, 8192),
        # (128, 16384),
        # (128, 32768),
        # (512, 1024),
        # (512, 2048),
        # (512, 4096),
        # (512, 8192),
        # (512, 16384),
        # (512, 32768),
        # (1024, 1024),
        # (1024, 2048),
        # (1024, 4096),
        # (1024, 8192),
        # (1024, 16384),
        # (1024, 32768),
        # (2048, 1024),
        # (2048, 2048),
        # (2048, 4096),
        # (2048, 8192),
        # (2048, 16384),
        # (2048, 32768),
        # (4096, 1024),
        # (4096, 2048),
        # (4096, 4096),
        # (4096, 8192),
        # (4096, 16384),
        # (4096, 32768),
        # (8192, 1024),
        # (8192, 2048),
        # (8192, 4096),
        # (8192, 8192),
        # (8192, 16384),
        # (8192, 32768),
        # (16384, 1024),
        # (16384, 2048),
        # (16384, 4096),
        # (16384, 8192),
        # (16384, 16384),
        # (16384, 32768),
        # (32768, 1024),
        # (32768, 2048),
        # (32768, 4096),
        # (32768, 8192),
        # (32768, 16384),
        # (32768, 32768),
        (32130, 5120),
        (512, 5120),
        (257, 5120),
        (32130, 13824),
        (75348, 5120),
        (75348, 13824),
        (32760, 1536),
        (512, 1536),
        (32760, 8960),
    ]

    print("=== quantize_fp4 memory-bandwidth benchmark ===\n")

    for i, (h, w) in enumerate(test_sizes):
        print(f"Test {i + 1}: tensor size ({h}, {w})")
        print("-" * 50)
        x = torch.randn(h, w, dtype=torch.bfloat16).cuda()
        try:
            bandwidth = test_memory_bandwidth(quantize_fp4, x)
            print(f"✓ Test completed, bandwidth: {bandwidth:.2f} GB/s\n")
        except Exception as e:
            print(f"✗ Test failed: {e}\n")

    print("=== Tests finished ===")
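For the first active size, (32130, 5120), the byte accounting in test_memory_bandwidth works out as follows: a bf16 input at 2 bytes per element, a packed fp4 output at 0.5 byte per element, and one scale byte per 16-element group. Note the script divides by 1024**3 when reporting GB, and its accounting ignores any padding the scale tensor may carry.

h, w = 32130, 5120
numel = h * w                  # 164,505,600 elements

input_bytes = numel * 2        # 329,011,200 B, about 313.8 MiB of bf16 input read
output_bytes = numel * 0.5     # 82,252,800 B, about 78.4 MiB of packed fp4 output written
scale_bytes = numel / 16       # 10,281,600 B, about 9.8 MiB of scales written

total_bytes = input_bytes + output_bytes + scale_bytes
print(total_bytes / 1024**3)   # about 0.39 GB moved per quantize_fp4 call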