**xuwx1 / LightX2V · Commits**

Commit `c7bb59cd`, authored Jun 23, 2025 by helloyongyang

> fix ci

Parent: `01caaf29`

Showing 9 changed files with 90 additions and 135 deletions.
| File | Additions | Deletions |
| --- | --- | --- |
| lightx2v_kernel/csrc/gemm/nvfp4_scaled_mm_kernels_sm120.cu | +1 | −1 |
| lightx2v_kernel/python/lightx2v_kernel/__init__.py | +1 | −3 |
| lightx2v_kernel/python/lightx2v_kernel/gemm.py | +4 | −13 |
| lightx2v_kernel/test/fake_quant.py | +1 | −0 |
| lightx2v_kernel/test/test_bench1.py | +11 | −25 |
| lightx2v_kernel/test/test_bench2.py | +16 | −24 |
| lightx2v_kernel/test/test_bench3_bias.py | +14 | −21 |
| lightx2v_kernel/test/test_mm_tflops.py | +21 | −23 |
| lightx2v_kernel/test/test_quant_mem_utils.py | +21 | −25 |

The hunks below are shown in their post-commit form. The changes are predominantly formatter reflows (multi-line calls collapsed onto single lines, redundant parentheses dropped, blank-line fixes), which is consistent with the "fix ci" commit message.
#### lightx2v_kernel/csrc/gemm/nvfp4_scaled_mm_kernels_sm120.cu (+1, −1)

```cpp
// @@ -61,7 +61,7 @@ struct Fp4GemmSm120 {
  using ClusterShape = Shape<_1, _1, _1>;  // Shape of the threadblocks in a cluster

  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
      ArchTag, OperatorClass,
      ThreadBlockShape, ClusterShape,
      cutlass::epilogue::collective::EpilogueTileAuto,
      ElementAccumulator, ElementAccumulator,
```

The hunk is truncated by the page; the one changed line appears to be a whitespace-only reflow of the `ArchTag, OperatorClass,` template arguments.
#### lightx2v_kernel/python/lightx2v_kernel/__init__.py (+1, −3)

```python
# @@ -14,6 +14,4 @@ from lightx2v_kernel import common_ops
from lightx2v_kernel.gemm import cutlass_scaled_fp4_mm, scaled_fp4_quant
from lightx2v_kernel.version import __version__

build_tree_kernel = None  # was wrapped in redundant parentheses across three lines
```
#### lightx2v_kernel/python/lightx2v_kernel/gemm.py (+4, −13)

```python
# @@ -12,15 +12,11 @@ def cutlass_scaled_fp4_mm(mat_a, mat_b, scales_a, scales_b, alpha, bias=None):
    """
    m, n = mat_a.shape[0], mat_b.shape[0]
    out = torch.empty((m, n), dtype=torch.bfloat16, device=mat_a.device)
    torch.ops.lightx2v_kernel.cutlass_scaled_fp4_mm_sm120.default(out, mat_a, mat_b, scales_a, scales_b, alpha, bias)
    return out


def scaled_fp4_quant(input: torch.Tensor, input_global_scale: torch.Tensor):
    """
    Quantize input tensor to FP4 and return quantized tensor and scale.
```

```python
# @@ -60,13 +56,8 @@ def scaled_fp4_quant(
    # rounded_m = ((m + 128 - 1) // 128) * 128
    # scale_n = n // block_size
    # rounded_n = ((scale_n + 4 - 1) // 4) * 4
    output_scale = torch.empty((((m + 128 - 1) // 128) * 128, (n // block_size + 4 - 1) // 4), device=device, dtype=torch.int32)
    torch.ops.lightx2v_kernel.scaled_fp4_quant_sm120.default(output, input, output_scale, input_global_scale)
    output_scale = output_scale.view(torch.float8_e4m3fn)
    return output, output_scale
```
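For orientation, here is a minimal sketch of how these two ops compose, pieced together from the test files in this commit; it assumes an SM120-class GPU with the extension built, and the variable names are illustrative:

```python
import torch
from lightx2v_kernel.gemm import cutlass_scaled_fp4_mm, scaled_fp4_quant

# Illustrative sizes; any (m, k) x (n, k) pair from the benchmarks below would do.
m, k, n = 512, 5120, 5120
a = torch.randn(m, k, dtype=torch.bfloat16, device="cuda")
b = torch.randn(n, k, dtype=torch.bfloat16, device="cuda")

# Per-tensor global scales map each tensor's absmax into the FP4 x FP8 range.
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0
FLOAT4_E2M1_MAX = 6.0
a_gs = (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / a.abs().amax()).to(torch.float32)
b_gs = (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / b.abs().amax()).to(torch.float32)

a_fp4, a_sf = scaled_fp4_quant(a, a_gs)  # packed uint8 (m, k//2) + interleaved block scales
b_fp4, b_sf = scaled_fp4_quant(b, b_gs)

alpha = 1.0 / (a_gs * b_gs)  # the epilogue undoes both global scales
out = cutlass_scaled_fp4_mm(a_fp4, b_fp4, a_sf, b_sf, alpha)  # (m, n) bfloat16
```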
#### lightx2v_kernel/test/fake_quant.py (+1, −0)

```python
# @@ -6,6 +6,7 @@ BLOCK_SIZE = 16
FLOAT4_E2M1_MAX = 6.0
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max


def cast_to_fp4(x):
    sign = torch.sign(x)
    x = torch.abs(x)
```
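As context for the constants above (not part of the diff): FP4 E2M1 has one sign, two exponent, and one mantissa bit, so each nibble encodes one of a small fixed set of magnitudes, the largest being 6.0, hence `FLOAT4_E2M1_MAX`:

```python
# Positive magnitudes representable in FP4 E2M1; 6.0 is the maximum, so
# per-block scaling maps each block's absmax onto this lattice.
E2M1_MAGNITUDES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
```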
#### lightx2v_kernel/test/test_bench1.py (+11, −25)

```python
# @@ -53,9 +53,7 @@ def break_fp4_bytes(a, dtype):
    return out


def dequantize_to_dtype(tensor_fp4, tensor_sf, global_scale, dtype, device, block_size=16):
    """Dequantize the fp4 tensor back to high precision."""
    # Two fp4 values are packed into one uint8.
    assert tensor_fp4.dtype == torch.uint8
```

```python
# @@ -88,12 +86,8 @@ def get_ref_results(
    _, m_k = a_fp4.shape
    _, n_k = b_fp4.shape
    assert m_k == n_k
    a_in_dtype = dequantize_to_dtype(a_fp4, a_sf, a_global_scale, dtype=dtype, device=device, block_size=block_size)
    b_in_dtype = dequantize_to_dtype(b_fp4, b_sf, b_global_scale, dtype=dtype, device=device, block_size=block_size)
    return torch.matmul(a_in_dtype, b_in_dtype.t())
```

```python
# @@ -109,16 +103,12 @@ def test_nvfp4_gemm(
    b_dtype = torch.randn((n, k), dtype=dtype, device="cuda")
    bias = torch.randn((1, n), dtype=dtype, device="cuda")
    a_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(a_dtype.flatten(), dim=-1)).to(torch.float32)
    b_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(b_dtype.flatten(), dim=-1)).to(torch.float32)
    print(f"a_global_scale : {a_global_scale}, {a_global_scale.shape}")
    print(f"b_global_scale : {b_global_scale}, {b_global_scale.shape}")
    alpha = 1.0 / (a_global_scale * b_global_scale)
    a_fp4, a_scale_interleaved = scaled_fp4_quant(a_dtype, a_global_scale)
    b_fp4, b_scale_interleaved = scaled_fp4_quant(b_dtype, b_global_scale)
```

```python
# @@ -137,15 +127,11 @@ def test_nvfp4_gemm(
        "cuda",
    )
    expected_out = expected_out + bias
    print(f"alpha {alpha}, {alpha.shape}, {alpha.dtype}")
    out = cutlass_scaled_fp4_mm(a_fp4, b_fp4, a_scale_interleaved, b_scale_interleaved, alpha, bias)
    print(f"out : {out}, {out.shape}, {out.dtype}")
    print(f"expected_out : {expected_out}, {expected_out.shape}, {expected_out.dtype}")
```
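A small arithmetic note that ties this file to the hard-coded constant in test_bench2.py below: with FLOAT8_E4M3_MAX = 448 and FLOAT4_E2M1_MAX = 6, the global-scale numerator is 448 × 6 = 2688, which is exactly the 2688.0 used by `MMWeightFp4.calibrate_x_absmax`:

```python
import torch

FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0
FLOAT4_E2M1_MAX = 6.0
assert FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX == 2688.0
# So input_global_scale = 2688.0 / x_absmax is the same formula as
# (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / amax used in test_bench1.py.
```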
#### lightx2v_kernel/test/test_bench2.py (+16, −24)

```python
# @@ -7,7 +7,7 @@ class MMWeightFp4:
    def __init__(self, weight, bias):
        self.load_fp4_weight(weight, bias)
        self.act_quant_func = self.act_quant_fp4

        # calibrate x_max
        self.calibrate_x_absmax()
```

```python
# @@ -24,7 +24,7 @@ class MMWeightFp4:
        self.bias = bias

    def calibrate_x_absmax(self):
        self.x_absmax = torch.tensor(5.0, dtype=torch.float32, device=self.weight.device)  # need to be calibrated
        self.input_global_scale = (2688.0 / self.x_absmax).to(torch.float32)
        self.alpha = 1.0 / (self.input_global_scale * self.weight_global_scale)
```

```python
# @@ -33,7 +33,6 @@ class MMWeightFp4:
        return scaled_fp4_quant(x, self.input_global_scale)


def test_speed(m, k, n):
    with torch.no_grad():
        input_tensor = torch.randn(m, k, dtype=torch.bfloat16).cuda()
```

```python
# @@ -42,26 +41,24 @@ def test_speed(m, k, n):
        bias = None
        mm = MMWeightFp4(weight, bias)

        # warmup
        output_tensor = mm.apply(input_tensor)

        torch.cuda.synchronize()
        start_time = time.time()
        for i in range(100):
            output_tensor = mm.apply(input_tensor)
        torch.cuda.synchronize()
        end_time = time.time()
        lightx2v_kernel_time = (end_time - start_time) / 100
        print(f"lightx2v-kernel time: {lightx2v_kernel_time}")

        input_tensor = torch.randn(m, n, dtype=torch.bfloat16).cuda()
        weight = torch.randn(k, n, dtype=torch.bfloat16, device="cuda")
        bias = torch.randn(1, k, dtype=torch.bfloat16).cuda()
        linear = torch.nn.Linear(k, n, bias=False).cuda()
        linear.weight.data = weight
        # linear.bias.data = bias
```

```python
# @@ -72,13 +69,13 @@ def test_speed(m, k, n):
        torch.cuda.synchronize()
        start_time = time.time()
        for i in range(100):
            ref_output_tensor = linear(input_tensor)
        torch.cuda.synchronize()
        end_time = time.time()
        ref_time = (end_time - start_time) / 100
        print(f"ref time: {ref_time}")
        print(f"speedup: {ref_time / lightx2v_kernel_time:.3f}")
```

```python
# @@ -88,47 +85,42 @@ def test_accuracy(m, k, n):
    weight = torch.randn(n, k, dtype=torch.bfloat16, device="cuda")
    # bias = torch.randn(1, n, dtype=torch.bfloat16).cuda()
    bias = None
    linear = torch.nn.Linear(k, n, bias=False).cuda()
    linear.weight.data = weight
    # linear.bias.data = bias
    ref_output_tensor = linear(input_tensor)

    mm = MMWeightFp4(weight, bias)
    output_tensor = mm.apply(input_tensor)

    # print(f"ref_output_tensor: {ref_output_tensor}")
    # print(f"output_tensor: {output_tensor}")

    # cosine
    cos = torch.nn.functional.cosine_similarity(ref_output_tensor.flatten(), output_tensor.flatten(), dim=0)
    print(f"cos : {cos}")


if __name__ == "__main__":
    test_sizes = [
        (32130, 5120, 5120),
        (512, 5120, 5120),
        (257, 5120, 5120),
        (32130, 5120, 13824),
        (32130, 13824, 5120),
        (75348, 5120, 5120),
        (75348, 13824, 5120),
        (32760, 1536, 1536),
        (512, 1536, 1536),
        (32760, 1536, 8960),
        (32760, 8960, 1536),
    ]
    for i, (m, k, n) in enumerate(test_sizes):
        print("-" * 30)
        print(f"Test {i + 1}: tensor size ({m}, {k}, {n})")
        test_accuracy(m, k, n)
        test_speed(m, k, n)
```
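The 5.0 in `calibrate_x_absmax` is a placeholder (the comment says "need to be calibrated"). A minimal sketch of what an actual calibration could look like, assuming representative activation batches are available; the method name and `sample_inputs` argument are hypothetical:

```python
def calibrate_x_absmax_from_samples(self, sample_inputs):
    # Hypothetical replacement for the hard-coded 5.0: take the running
    # absolute maximum over a few representative activation batches.
    absmax = torch.zeros((), dtype=torch.float32, device=self.weight.device)
    for x in sample_inputs:
        absmax = torch.maximum(absmax, x.abs().amax().to(torch.float32))
    self.x_absmax = absmax
    self.input_global_scale = (2688.0 / self.x_absmax).to(torch.float32)
    self.alpha = 1.0 / (self.input_global_scale * self.weight_global_scale)
```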
#### lightx2v_kernel/test/test_bench3_bias.py (+14, −21)

```python
# @@ -11,26 +11,24 @@ def test_speed(m, k, n):
        bias = torch.randn(1, n, dtype=torch.bfloat16).cuda()
        mm = MMWeightFp4(weight, bias)

        # warmup
        output_tensor = mm.apply(input_tensor)

        torch.cuda.synchronize()
        start_time = time.time()
        for i in range(100):
            output_tensor = mm.apply(input_tensor)
        torch.cuda.synchronize()
        end_time = time.time()
        lightx2v_kernel_time = (end_time - start_time) / 100
        print(f"lightx2v-kernel time: {lightx2v_kernel_time}")

        input_tensor = torch.randn(m, n, dtype=torch.bfloat16).cuda()
        weight = torch.randn(k, n, dtype=torch.bfloat16, device="cuda")
        bias = torch.randn(1, k, dtype=torch.bfloat16).cuda()
        linear = torch.nn.Linear(k, n, bias=True).cuda()
        linear.weight.data = weight
        linear.bias.data = bias
```

```python
# @@ -41,13 +39,13 @@ def test_speed(m, k, n):
        torch.cuda.synchronize()
        start_time = time.time()
        for i in range(100):
            ref_output_tensor = linear(input_tensor)
        torch.cuda.synchronize()
        end_time = time.time()
        ref_time = (end_time - start_time) / 100
        print(f"ref time: {ref_time}")
        print(f"speedup: {ref_time / lightx2v_kernel_time:.3f}")
```

```python
# @@ -56,47 +54,42 @@ def test_accuracy(m, k, n):
    input_tensor = torch.randn(m, k, dtype=torch.bfloat16).cuda()
    weight = torch.randn(n, k, dtype=torch.bfloat16, device="cuda")
    bias = torch.randn(1, n, dtype=torch.bfloat16).cuda()
    linear = torch.nn.Linear(k, n, bias=True).cuda()
    linear.weight.data = weight
    linear.bias.data = bias
    ref_output_tensor = linear(input_tensor)

    mm = MMWeightFp4(weight, bias)
    output_tensor = mm.apply(input_tensor)

    # print(f"ref_output_tensor: {ref_output_tensor}")
    # print(f"output_tensor: {output_tensor}")

    # cosine
    cos = torch.nn.functional.cosine_similarity(ref_output_tensor.flatten(), output_tensor.flatten(), dim=0)
    print(f"cos : {cos}")


if __name__ == "__main__":
    test_sizes = [
        (32130, 5120, 5120),
        (512, 5120, 5120),
        (257, 5120, 5120),
        (32130, 5120, 13824),
        (32130, 13824, 5120),
        (75348, 5120, 5120),
        (75348, 13824, 5120),
        (32760, 1536, 1536),
        (512, 1536, 1536),
        (32760, 1536, 8960),
        (32760, 8960, 1536),
    ]
    for i, (m, k, n) in enumerate(test_sizes):
        print("-" * 30)
        print(f"Test {i + 1}: tensor size ({m}, {k}, {n})")
        test_accuracy(m, k, n)
        test_speed(m, k, n)
```
#### lightx2v_kernel/test/test_mm_tflops.py (+21, −23)

```python
# @@ -14,6 +14,7 @@ alpha = torch.tensor(0.0002765655517578125, device="cuda").to(torch.float32)
bias = None
"""


def test_mm(input_tensor_quant, weight, input_tensor_scale, weight_scale, alpha, bias):
    output_tensor = cutlass_scaled_fp4_mm(input_tensor_quant, weight, input_tensor_scale, weight_scale, alpha=alpha, bias=bias)
    return output_tensor
```

```python
# @@ -23,64 +24,63 @@ def test_tflops(input_shape, weight_shape, num_warmup=10, num_runs=100):
    """
    Measure the TFLOPS of test_mm.
    """
    # Create the input data
    input_tensor_quant = (torch.rand((input_shape[0], input_shape[1] // 2), device="cuda") * 10).to(torch.uint8)
    weight = (torch.rand((weight_shape[0], weight_shape[1] // 2), device="cuda") * 10).to(torch.uint8)
    input_tensor_scale = torch.rand(((input_shape[0] + 128 - 1) // 128) * 128, (input_shape[1] // 16 + 4 - 1) // 4 * 4, device="cuda").to(torch.float8_e4m3fn)
    weight_scale = torch.rand(weight_shape[0], weight_shape[1] // 16, device="cuda").to(torch.float8_e4m3fn)
    alpha = torch.tensor(0.0002765655517578125, device="cuda", dtype=torch.float32)
    bias = None

    # Warm up the GPU
    for _ in range(num_warmup):
        test_mm(input_tensor_quant, weight, input_tensor_scale, weight_scale, alpha, bias)

    # Synchronize the GPU
    torch.cuda.synchronize()

    # CUDA events for precise timing
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    # Time the runs
    start_event.record()
    for _ in range(num_runs):
        result = test_mm(input_tensor_quant, weight, input_tensor_scale, weight_scale, alpha, bias)
    end_event.record()

    # Synchronize and compute the elapsed time
    torch.cuda.synchronize()
    elapsed_time_ms = start_event.elapsed_time(end_event)
    elapsed_time_s = elapsed_time_ms / 1000.0

    # Count FLOPs
    # Matrix multiply A(M x K) @ B(K x N) = C(M x N)
    # M = batch_size, K = input_dim, N = output_dim
    M = input_shape[0]
    K = input_shape[1]
    N = weight_shape[0]
    # FLOPs per matmul = 2 * M * N * K (each output element takes K multiplies and K adds)
    flops_per_run = 2 * M * N * K
    total_flops = flops_per_run * num_runs

    # TFLOPS = trillions of floating-point operations per second
    tflops = total_flops / (elapsed_time_s * 1e12)

    print("Results:")
    print(f"  Input shape: {input_shape} (M={M}, K={K})")
    print(f"  Weight shape: {weight_shape} (N={N}, K={K})")
    print(f"  Output shape: ({M}, {N})")
    print(f"  Runs: {num_runs}")
    print(f"  Total elapsed time: {elapsed_time_ms:.2f} ms")
    print(f"  Average time per run: {elapsed_time_ms / num_runs:.4f} ms")
    print(f"  FLOPs per run: {flops_per_run / 1e9:.2f} GFLOPS")
    print(f"  Total FLOPs: {total_flops / 1e12:.2f} TFLOPS")
    print(f"  Compute throughput: {tflops:.2f} TFLOPS")

    return tflops
```

```python
# @@ -93,24 +93,22 @@ if __name__ == "__main__":
        ((257, 5120), (5120, 5120)),
        ((32130, 5120), (13824, 5120)),
        ((32130, 13824), (5120, 13824)),
        ((75348, 5120), (5120, 5120)),
        ((75348, 5120), (13824, 5120)),
        ((75348, 13824), (5120, 13824)),
        ((32760, 1536), (1536, 1536)),
        ((512, 1536), (1536, 1536)),
        ((32760, 1536), (8960, 1536)),
        ((32760, 8960), (1536, 8960)),
    ]

    print("=== test_mm TFLOPS benchmark ===\n")
    for i, (input_shape, weight_shape) in enumerate(test_cases):
        print(f"Test {i + 1}: input shape {input_shape}, weight shape {weight_shape}")
        print("-" * 60)
        tflops = test_tflops(input_shape, weight_shape)
        print(f"✓ Test finished, throughput: {tflops:.2f} TFLOPS\n")

    print("=== All tests done ===")
```
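As a sanity check on the FLOP accounting (a worked example with a hypothetical timing, not a number taken from this page): for the first case, input (257, 5120) against weight (5120, 5120) gives M=257, K=5120, N=5120:

```python
# Worked instance of the accounting in test_tflops above.
M, K, N = 257, 5120, 5120            # first test case
flops_per_run = 2 * M * N * K        # = 13,474,201,600 ≈ 13.47 GFLOP per call
elapsed_s_per_run = 1e-4             # hypothetical 0.1 ms per call
print(flops_per_run / elapsed_s_per_run / 1e12)  # ≈ 134.7 TFLOPS
```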
#### lightx2v_kernel/test/test_quant_mem_utils.py (+21, −25)

```python
# @@ -4,7 +4,8 @@ from lightx2v_kernel.gemm import scaled_fp4_quant
input_global_scale = torch.tensor(808.0, dtype=torch.float32).cuda()


def quantize_fp4(x):
    return scaled_fp4_quant(x, input_global_scale)
```

```python
# @@ -15,51 +16,50 @@ def test_memory_bandwidth(func, x, num_warmup=10, num_runs=100):
    # Warm up the GPU
    for _ in range(num_warmup):
        func(x)

    # Synchronize the GPU
    torch.cuda.synchronize()

    # CUDA events for precise timing
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    # Time the runs
    start_event.record()
    for _ in range(num_runs):
        result = func(x)
    end_event.record()

    # Synchronize and compute the elapsed time
    torch.cuda.synchronize()
    elapsed_time_ms = start_event.elapsed_time(end_event)
    elapsed_time_s = elapsed_time_ms / 1000.0

    # Data volume
    input_bytes = x.numel() * x.element_size()  # bytes read from the input
    # After FP4 quantization each element occupies half a byte
    output_bytes = x.numel() * 0.5  # bytes written to the FP4 output
    scale_bytes = x.numel() / 16  # group_size = 16

    # Total traffic (read input + write output + write scales)
    total_bytes = (input_bytes + output_bytes + scale_bytes) * num_runs

    # Bandwidth
    bandwidth_gbps = (total_bytes / elapsed_time_s) / (1024**3)  # GB/s

    print("Results:")
    print(f"  Input tensor shape: {x.shape}")
    print(f"  Input dtype: {x.dtype}")
    print(f"  Runs: {num_runs}")
    print(f"  Total elapsed time: {elapsed_time_ms:.2f} ms")
    print(f"  Average time per run: {elapsed_time_ms / num_runs:.4f} ms")
    print(f"  Input size: {input_bytes / (1024**2):.2f} MB")
    print(f"  Output size: {output_bytes / (1024**2):.2f} MB")
    print(f"  Total data moved: {total_bytes / (1024**3):.2f} GB")
    print(f"  Memory bandwidth: {bandwidth_gbps:.2f} GB/s")

    return bandwidth_gbps
```

```python
# @@ -132,33 +132,29 @@ if __name__ == "__main__":
        # (32768, 8192),
        # (32768, 16384),
        # (32768, 32768),
        (32130, 5120),
        (512, 5120),
        (257, 5120),
        (32130, 13824),
        (75348, 5120),
        (75348, 13824),
        (32760, 1536),
        (512, 1536),
        (32760, 8960),
    ]

    print("=== quantize_fp4 memory-bandwidth benchmark ===\n")
    for i, (h, w) in enumerate(test_sizes):
        print(f"Test {i + 1}: tensor size ({h}, {w})")
        print("-" * 50)
        x = torch.randn(h, w, dtype=torch.bfloat16).cuda()
        try:
            bandwidth = test_memory_bandwidth(quantize_fp4, x)
            print(f"✓ Test finished, bandwidth: {bandwidth:.2f} GB/s\n")
        except Exception as e:
            print(f"✗ Test failed: {e}\n")

    print("=== All tests done ===")
```
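A worked instance of the byte accounting above, using the same traffic model as `test_memory_bandwidth` for the (512, 5120) case:

```python
# Traffic per call for a (512, 5120) bf16 input under the model above.
h, w = 512, 5120
numel = h * w                        # 2,621,440 elements
input_bytes = numel * 2              # bf16 read:   5,242,880 B = 5.00 MiB
output_bytes = numel * 0.5           # packed FP4:  1,310,720 B = 1.25 MiB
scale_bytes = numel / 16             # one FP8 scale per 16 elements: 163,840 B
per_run = input_bytes + output_bytes + scale_bytes  # ≈ 6.41 MiB per call
```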