Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
xuwx1
LightX2V
Commits
c7bb59cd
Commit
c7bb59cd
authored
Jun 23, 2025
by
helloyongyang
Browse files
fix ci
parent
01caaf29
Changes
9
Show whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
90 additions
and
135 deletions
+90
-135
lightx2v_kernel/csrc/gemm/nvfp4_scaled_mm_kernels_sm120.cu
lightx2v_kernel/csrc/gemm/nvfp4_scaled_mm_kernels_sm120.cu
+1
-1
lightx2v_kernel/python/lightx2v_kernel/__init__.py
lightx2v_kernel/python/lightx2v_kernel/__init__.py
+1
-3
lightx2v_kernel/python/lightx2v_kernel/gemm.py
lightx2v_kernel/python/lightx2v_kernel/gemm.py
+4
-13
lightx2v_kernel/test/fake_quant.py
lightx2v_kernel/test/fake_quant.py
+1
-0
lightx2v_kernel/test/test_bench1.py
lightx2v_kernel/test/test_bench1.py
+11
-25
lightx2v_kernel/test/test_bench2.py
lightx2v_kernel/test/test_bench2.py
+16
-24
lightx2v_kernel/test/test_bench3_bias.py
lightx2v_kernel/test/test_bench3_bias.py
+14
-21
lightx2v_kernel/test/test_mm_tflops.py
lightx2v_kernel/test/test_mm_tflops.py
+21
-23
lightx2v_kernel/test/test_quant_mem_utils.py
lightx2v_kernel/test/test_quant_mem_utils.py
+21
-25
No files found.
lightx2v_kernel/csrc/gemm/nvfp4_scaled_mm_kernels_sm120.cu
View file @
c7bb59cd
lightx2v_kernel/python/lightx2v_kernel/__init__.py
View file @
c7bb59cd
...
@@ -14,6 +14,4 @@ from lightx2v_kernel import common_ops
...
@@ -14,6 +14,4 @@ from lightx2v_kernel import common_ops
from
lightx2v_kernel.gemm
import
cutlass_scaled_fp4_mm
,
scaled_fp4_quant
from
lightx2v_kernel.gemm
import
cutlass_scaled_fp4_mm
,
scaled_fp4_quant
from
lightx2v_kernel.version
import
__version__
from
lightx2v_kernel.version
import
__version__
build_tree_kernel
=
(
build_tree_kernel
=
None
None
)
lightx2v_kernel/python/lightx2v_kernel/gemm.py
View file @
c7bb59cd
...
@@ -12,15 +12,11 @@ def cutlass_scaled_fp4_mm(mat_a, mat_b, scales_a, scales_b, alpha, bias=None):
...
@@ -12,15 +12,11 @@ def cutlass_scaled_fp4_mm(mat_a, mat_b, scales_a, scales_b, alpha, bias=None):
"""
"""
m
,
n
=
mat_a
.
shape
[
0
],
mat_b
.
shape
[
0
]
m
,
n
=
mat_a
.
shape
[
0
],
mat_b
.
shape
[
0
]
out
=
torch
.
empty
((
m
,
n
),
dtype
=
torch
.
bfloat16
,
device
=
mat_a
.
device
)
out
=
torch
.
empty
((
m
,
n
),
dtype
=
torch
.
bfloat16
,
device
=
mat_a
.
device
)
torch
.
ops
.
lightx2v_kernel
.
cutlass_scaled_fp4_mm_sm120
.
default
(
torch
.
ops
.
lightx2v_kernel
.
cutlass_scaled_fp4_mm_sm120
.
default
(
out
,
mat_a
,
mat_b
,
scales_a
,
scales_b
,
alpha
,
bias
)
out
,
mat_a
,
mat_b
,
scales_a
,
scales_b
,
alpha
,
bias
)
return
out
return
out
def
scaled_fp4_quant
(
def
scaled_fp4_quant
(
input
:
torch
.
Tensor
,
input_global_scale
:
torch
.
Tensor
):
input
:
torch
.
Tensor
,
input_global_scale
:
torch
.
Tensor
):
"""
"""
Quantize input tensor to FP4 and return quantized tensor and scale.
Quantize input tensor to FP4 and return quantized tensor and scale.
...
@@ -60,13 +56,8 @@ def scaled_fp4_quant(
...
@@ -60,13 +56,8 @@ def scaled_fp4_quant(
# rounded_m = ((m + 128 - 1) // 128) * 128
# rounded_m = ((m + 128 - 1) // 128) * 128
# scale_n = n // block_size
# scale_n = n // block_size
# rounded_n = ((scale_n + 4 - 1) // 4) * 4
# rounded_n = ((scale_n + 4 - 1) // 4) * 4
output_scale
=
torch
.
empty
(
output_scale
=
torch
.
empty
((((
m
+
128
-
1
)
//
128
)
*
128
,
(
n
//
block_size
+
4
-
1
)
//
4
),
device
=
device
,
dtype
=
torch
.
int32
)
(((
m
+
128
-
1
)
//
128
)
*
128
,
(
n
//
block_size
+
4
-
1
)
//
4
),
device
=
device
,
dtype
=
torch
.
int32
)
torch
.
ops
.
lightx2v_kernel
.
scaled_fp4_quant_sm120
.
default
(
torch
.
ops
.
lightx2v_kernel
.
scaled_fp4_quant_sm120
.
default
(
output
,
input
,
output_scale
,
input_global_scale
)
output
,
input
,
output_scale
,
input_global_scale
)
output_scale
=
output_scale
.
view
(
torch
.
float8_e4m3fn
)
output_scale
=
output_scale
.
view
(
torch
.
float8_e4m3fn
)
return
output
,
output_scale
return
output
,
output_scale
lightx2v_kernel/test/fake_quant.py
View file @
c7bb59cd
...
@@ -6,6 +6,7 @@ BLOCK_SIZE = 16
...
@@ -6,6 +6,7 @@ BLOCK_SIZE = 16
FLOAT4_E2M1_MAX
=
6.0
FLOAT4_E2M1_MAX
=
6.0
FLOAT8_E4M3_MAX
=
torch
.
finfo
(
torch
.
float8_e4m3fn
).
max
FLOAT8_E4M3_MAX
=
torch
.
finfo
(
torch
.
float8_e4m3fn
).
max
def
cast_to_fp4
(
x
):
def
cast_to_fp4
(
x
):
sign
=
torch
.
sign
(
x
)
sign
=
torch
.
sign
(
x
)
x
=
torch
.
abs
(
x
)
x
=
torch
.
abs
(
x
)
...
...
lightx2v_kernel/test/test_bench1.py
View file @
c7bb59cd
...
@@ -53,9 +53,7 @@ def break_fp4_bytes(a, dtype):
...
@@ -53,9 +53,7 @@ def break_fp4_bytes(a, dtype):
return
out
return
out
def
dequantize_to_dtype
(
def
dequantize_to_dtype
(
tensor_fp4
,
tensor_sf
,
global_scale
,
dtype
,
device
,
block_size
=
16
):
tensor_fp4
,
tensor_sf
,
global_scale
,
dtype
,
device
,
block_size
=
16
):
"""Dequantize the fp4 tensor back to high precision."""
"""Dequantize the fp4 tensor back to high precision."""
# Two fp4 values are packed into one uint8.
# Two fp4 values are packed into one uint8.
assert
tensor_fp4
.
dtype
==
torch
.
uint8
assert
tensor_fp4
.
dtype
==
torch
.
uint8
...
@@ -88,12 +86,8 @@ def get_ref_results(
...
@@ -88,12 +86,8 @@ def get_ref_results(
_
,
m_k
=
a_fp4
.
shape
_
,
m_k
=
a_fp4
.
shape
_
,
n_k
=
b_fp4
.
shape
_
,
n_k
=
b_fp4
.
shape
assert
m_k
==
n_k
assert
m_k
==
n_k
a_in_dtype
=
dequantize_to_dtype
(
a_in_dtype
=
dequantize_to_dtype
(
a_fp4
,
a_sf
,
a_global_scale
,
dtype
=
dtype
,
device
=
device
,
block_size
=
block_size
)
a_fp4
,
a_sf
,
a_global_scale
,
dtype
=
dtype
,
device
=
device
,
block_size
=
block_size
b_in_dtype
=
dequantize_to_dtype
(
b_fp4
,
b_sf
,
b_global_scale
,
dtype
=
dtype
,
device
=
device
,
block_size
=
block_size
)
)
b_in_dtype
=
dequantize_to_dtype
(
b_fp4
,
b_sf
,
b_global_scale
,
dtype
=
dtype
,
device
=
device
,
block_size
=
block_size
)
return
torch
.
matmul
(
a_in_dtype
,
b_in_dtype
.
t
())
return
torch
.
matmul
(
a_in_dtype
,
b_in_dtype
.
t
())
...
@@ -109,12 +103,8 @@ def test_nvfp4_gemm(
...
@@ -109,12 +103,8 @@ def test_nvfp4_gemm(
b_dtype
=
torch
.
randn
((
n
,
k
),
dtype
=
dtype
,
device
=
"cuda"
)
b_dtype
=
torch
.
randn
((
n
,
k
),
dtype
=
dtype
,
device
=
"cuda"
)
bias
=
torch
.
randn
((
1
,
n
),
dtype
=
dtype
,
device
=
"cuda"
)
bias
=
torch
.
randn
((
1
,
n
),
dtype
=
dtype
,
device
=
"cuda"
)
a_global_scale
=
(
a_global_scale
=
((
FLOAT8_E4M3_MAX
*
FLOAT4_E2M1_MAX
)
/
torch
.
amax
(
a_dtype
.
flatten
(),
dim
=-
1
)).
to
(
torch
.
float32
)
(
FLOAT8_E4M3_MAX
*
FLOAT4_E2M1_MAX
)
/
torch
.
amax
(
a_dtype
.
flatten
(),
dim
=-
1
)
b_global_scale
=
((
FLOAT8_E4M3_MAX
*
FLOAT4_E2M1_MAX
)
/
torch
.
amax
(
b_dtype
.
flatten
(),
dim
=-
1
)).
to
(
torch
.
float32
)
).
to
(
torch
.
float32
)
b_global_scale
=
(
(
FLOAT8_E4M3_MAX
*
FLOAT4_E2M1_MAX
)
/
torch
.
amax
(
b_dtype
.
flatten
(),
dim
=-
1
)
).
to
(
torch
.
float32
)
print
(
f
"a_global_scale :
{
a_global_scale
}
,
{
a_global_scale
.
shape
}
"
)
print
(
f
"a_global_scale :
{
a_global_scale
}
,
{
a_global_scale
.
shape
}
"
)
print
(
f
"b_global_scale :
{
b_global_scale
}
,
{
b_global_scale
.
shape
}
"
)
print
(
f
"b_global_scale :
{
b_global_scale
}
,
{
b_global_scale
.
shape
}
"
)
...
@@ -138,13 +128,9 @@ def test_nvfp4_gemm(
...
@@ -138,13 +128,9 @@ def test_nvfp4_gemm(
)
)
expected_out
=
expected_out
+
bias
expected_out
=
expected_out
+
bias
print
(
f
"alpha
{
alpha
}
,
{
alpha
.
shape
}
,
{
alpha
.
dtype
}
"
)
print
(
f
"alpha
{
alpha
}
,
{
alpha
.
shape
}
,
{
alpha
.
dtype
}
"
)
out
=
cutlass_scaled_fp4_mm
(
a_fp4
,
b_fp4
,
a_scale_interleaved
,
b_scale_interleaved
,
alpha
,
bias
)
out
=
cutlass_scaled_fp4_mm
(
a_fp4
,
b_fp4
,
a_scale_interleaved
,
b_scale_interleaved
,
alpha
,
bias
)
print
(
f
"out :
{
out
}
,
{
out
.
shape
}
,
{
out
.
dtype
}
"
)
print
(
f
"out :
{
out
}
,
{
out
.
shape
}
,
{
out
.
dtype
}
"
)
print
(
f
"expected_out :
{
expected_out
}
,
{
expected_out
.
shape
}
,
{
expected_out
.
dtype
}
"
)
print
(
f
"expected_out :
{
expected_out
}
,
{
expected_out
.
shape
}
,
{
expected_out
.
dtype
}
"
)
...
...
lightx2v_kernel/test/test_bench2.py
View file @
c7bb59cd
...
@@ -33,7 +33,6 @@ class MMWeightFp4:
...
@@ -33,7 +33,6 @@ class MMWeightFp4:
return
scaled_fp4_quant
(
x
,
self
.
input_global_scale
)
return
scaled_fp4_quant
(
x
,
self
.
input_global_scale
)
def
test_speed
(
m
,
k
,
n
):
def
test_speed
(
m
,
k
,
n
):
with
torch
.
no_grad
():
with
torch
.
no_grad
():
input_tensor
=
torch
.
randn
(
m
,
k
,
dtype
=
torch
.
bfloat16
).
cuda
()
input_tensor
=
torch
.
randn
(
m
,
k
,
dtype
=
torch
.
bfloat16
).
cuda
()
...
@@ -56,8 +55,6 @@ def test_speed(m, k, n):
...
@@ -56,8 +55,6 @@ def test_speed(m, k, n):
lightx2v_kernel_time
=
(
end_time
-
start_time
)
/
100
lightx2v_kernel_time
=
(
end_time
-
start_time
)
/
100
print
(
f
"lightx2v-kernel time:
{
lightx2v_kernel_time
}
"
)
print
(
f
"lightx2v-kernel time:
{
lightx2v_kernel_time
}
"
)
input_tensor
=
torch
.
randn
(
m
,
n
,
dtype
=
torch
.
bfloat16
).
cuda
()
input_tensor
=
torch
.
randn
(
m
,
n
,
dtype
=
torch
.
bfloat16
).
cuda
()
weight
=
torch
.
randn
(
k
,
n
,
dtype
=
torch
.
bfloat16
,
device
=
"cuda"
)
weight
=
torch
.
randn
(
k
,
n
,
dtype
=
torch
.
bfloat16
,
device
=
"cuda"
)
bias
=
torch
.
randn
(
1
,
k
,
dtype
=
torch
.
bfloat16
).
cuda
()
bias
=
torch
.
randn
(
1
,
k
,
dtype
=
torch
.
bfloat16
).
cuda
()
...
@@ -107,19 +104,15 @@ def test_accuracy(m, k, n):
...
@@ -107,19 +104,15 @@ def test_accuracy(m, k, n):
print
(
f
"cos :
{
cos
}
"
)
print
(
f
"cos :
{
cos
}
"
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
test_sizes
=
[
test_sizes
=
[
(
32130
,
5120
,
5120
),
(
32130
,
5120
,
5120
),
(
512
,
5120
,
5120
),
(
512
,
5120
,
5120
),
(
257
,
5120
,
5120
),
(
257
,
5120
,
5120
),
(
32130
,
5120
,
13824
),
(
32130
,
5120
,
13824
),
(
32130
,
13824
,
5120
),
(
32130
,
13824
,
5120
),
(
75348
,
5120
,
5120
),
(
75348
,
5120
,
5120
),
(
75348
,
13824
,
5120
),
(
75348
,
13824
,
5120
),
(
32760
,
1536
,
1536
),
(
32760
,
1536
,
1536
),
(
512
,
1536
,
1536
),
(
512
,
1536
,
1536
),
(
32760
,
1536
,
8960
),
(
32760
,
1536
,
8960
),
...
@@ -128,7 +121,6 @@ if __name__ == "__main__":
...
@@ -128,7 +121,6 @@ if __name__ == "__main__":
for
i
,
(
m
,
k
,
n
)
in
enumerate
(
test_sizes
):
for
i
,
(
m
,
k
,
n
)
in
enumerate
(
test_sizes
):
print
(
"-"
*
30
)
print
(
"-"
*
30
)
print
(
f
"测试
{
i
+
1
}
: 张量大小 (
{
m
}
,
{
k
}
,
{
n
}
)"
)
print
(
f
"测试
{
i
+
1
}
: 张量大小 (
{
m
}
,
{
k
}
,
{
n
}
)"
)
test_accuracy
(
m
,
k
,
n
)
test_accuracy
(
m
,
k
,
n
)
test_speed
(
m
,
k
,
n
)
test_speed
(
m
,
k
,
n
)
lightx2v_kernel/test/test_bench3_bias.py
View file @
c7bb59cd
...
@@ -25,8 +25,6 @@ def test_speed(m, k, n):
...
@@ -25,8 +25,6 @@ def test_speed(m, k, n):
lightx2v_kernel_time
=
(
end_time
-
start_time
)
/
100
lightx2v_kernel_time
=
(
end_time
-
start_time
)
/
100
print
(
f
"lightx2v-kernel time:
{
lightx2v_kernel_time
}
"
)
print
(
f
"lightx2v-kernel time:
{
lightx2v_kernel_time
}
"
)
input_tensor
=
torch
.
randn
(
m
,
n
,
dtype
=
torch
.
bfloat16
).
cuda
()
input_tensor
=
torch
.
randn
(
m
,
n
,
dtype
=
torch
.
bfloat16
).
cuda
()
weight
=
torch
.
randn
(
k
,
n
,
dtype
=
torch
.
bfloat16
,
device
=
"cuda"
)
weight
=
torch
.
randn
(
k
,
n
,
dtype
=
torch
.
bfloat16
,
device
=
"cuda"
)
bias
=
torch
.
randn
(
1
,
k
,
dtype
=
torch
.
bfloat16
).
cuda
()
bias
=
torch
.
randn
(
1
,
k
,
dtype
=
torch
.
bfloat16
).
cuda
()
...
@@ -75,19 +73,15 @@ def test_accuracy(m, k, n):
...
@@ -75,19 +73,15 @@ def test_accuracy(m, k, n):
print
(
f
"cos :
{
cos
}
"
)
print
(
f
"cos :
{
cos
}
"
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
test_sizes
=
[
test_sizes
=
[
(
32130
,
5120
,
5120
),
(
32130
,
5120
,
5120
),
(
512
,
5120
,
5120
),
(
512
,
5120
,
5120
),
(
257
,
5120
,
5120
),
(
257
,
5120
,
5120
),
(
32130
,
5120
,
13824
),
(
32130
,
5120
,
13824
),
(
32130
,
13824
,
5120
),
(
32130
,
13824
,
5120
),
(
75348
,
5120
,
5120
),
(
75348
,
5120
,
5120
),
(
75348
,
13824
,
5120
),
(
75348
,
13824
,
5120
),
(
32760
,
1536
,
1536
),
(
32760
,
1536
,
1536
),
(
512
,
1536
,
1536
),
(
512
,
1536
,
1536
),
(
32760
,
1536
,
8960
),
(
32760
,
1536
,
8960
),
...
@@ -96,7 +90,6 @@ if __name__ == "__main__":
...
@@ -96,7 +90,6 @@ if __name__ == "__main__":
for
i
,
(
m
,
k
,
n
)
in
enumerate
(
test_sizes
):
for
i
,
(
m
,
k
,
n
)
in
enumerate
(
test_sizes
):
print
(
"-"
*
30
)
print
(
"-"
*
30
)
print
(
f
"测试
{
i
+
1
}
: 张量大小 (
{
m
}
,
{
k
}
,
{
n
}
)"
)
print
(
f
"测试
{
i
+
1
}
: 张量大小 (
{
m
}
,
{
k
}
,
{
n
}
)"
)
test_accuracy
(
m
,
k
,
n
)
test_accuracy
(
m
,
k
,
n
)
test_speed
(
m
,
k
,
n
)
test_speed
(
m
,
k
,
n
)
lightx2v_kernel/test/test_mm_tflops.py
View file @
c7bb59cd
...
@@ -14,6 +14,7 @@ alpha = torch.tensor(0.0002765655517578125, device="cuda").to(torch.float32)
...
@@ -14,6 +14,7 @@ alpha = torch.tensor(0.0002765655517578125, device="cuda").to(torch.float32)
bias = None
bias = None
"""
"""
def
test_mm
(
input_tensor_quant
,
weight
,
input_tensor_scale
,
weight_scale
,
alpha
,
bias
):
def
test_mm
(
input_tensor_quant
,
weight
,
input_tensor_scale
,
weight_scale
,
alpha
,
bias
):
output_tensor
=
cutlass_scaled_fp4_mm
(
input_tensor_quant
,
weight
,
input_tensor_scale
,
weight_scale
,
alpha
=
alpha
,
bias
=
bias
)
output_tensor
=
cutlass_scaled_fp4_mm
(
input_tensor_quant
,
weight
,
input_tensor_scale
,
weight_scale
,
alpha
=
alpha
,
bias
=
bias
)
return
output_tensor
return
output_tensor
...
@@ -28,7 +29,6 @@ def test_tflops(input_shape, weight_shape, num_warmup=10, num_runs=100):
...
@@ -28,7 +29,6 @@ def test_tflops(input_shape, weight_shape, num_warmup=10, num_runs=100):
input_tensor_quant
=
(
torch
.
rand
((
input_shape
[
0
],
input_shape
[
1
]
//
2
),
device
=
"cuda"
)
*
10
).
to
(
torch
.
uint8
)
input_tensor_quant
=
(
torch
.
rand
((
input_shape
[
0
],
input_shape
[
1
]
//
2
),
device
=
"cuda"
)
*
10
).
to
(
torch
.
uint8
)
weight
=
(
torch
.
rand
((
weight_shape
[
0
],
weight_shape
[
1
]
//
2
),
device
=
"cuda"
)
*
10
).
to
(
torch
.
uint8
)
weight
=
(
torch
.
rand
((
weight_shape
[
0
],
weight_shape
[
1
]
//
2
),
device
=
"cuda"
)
*
10
).
to
(
torch
.
uint8
)
input_tensor_scale
=
torch
.
rand
(((
input_shape
[
0
]
+
128
-
1
)
//
128
)
*
128
,
(
input_shape
[
1
]
//
16
+
4
-
1
)
//
4
*
4
,
device
=
"cuda"
).
to
(
torch
.
float8_e4m3fn
)
input_tensor_scale
=
torch
.
rand
(((
input_shape
[
0
]
+
128
-
1
)
//
128
)
*
128
,
(
input_shape
[
1
]
//
16
+
4
-
1
)
//
4
*
4
,
device
=
"cuda"
).
to
(
torch
.
float8_e4m3fn
)
weight_scale
=
torch
.
rand
(
weight_shape
[
0
],
weight_shape
[
1
]
//
16
,
device
=
"cuda"
).
to
(
torch
.
float8_e4m3fn
)
weight_scale
=
torch
.
rand
(
weight_shape
[
0
],
weight_shape
[
1
]
//
16
,
device
=
"cuda"
).
to
(
torch
.
float8_e4m3fn
)
alpha
=
torch
.
tensor
(
0.0002765655517578125
,
device
=
"cuda"
,
dtype
=
torch
.
float32
)
alpha
=
torch
.
tensor
(
0.0002765655517578125
,
device
=
"cuda"
,
dtype
=
torch
.
float32
)
...
@@ -76,9 +76,9 @@ def test_tflops(input_shape, weight_shape, num_warmup=10, num_runs=100):
...
@@ -76,9 +76,9 @@ def test_tflops(input_shape, weight_shape, num_warmup=10, num_runs=100):
print
(
f
" 输出形状: (
{
M
}
,
{
N
}
)"
)
print
(
f
" 输出形状: (
{
M
}
,
{
N
}
)"
)
print
(
f
" 运行次数:
{
num_runs
}
"
)
print
(
f
" 运行次数:
{
num_runs
}
"
)
print
(
f
" 总执行时间:
{
elapsed_time_ms
:.
2
f
}
ms"
)
print
(
f
" 总执行时间:
{
elapsed_time_ms
:.
2
f
}
ms"
)
print
(
f
" 平均每次执行时间:
{
elapsed_time_ms
/
num_runs
:.
4
f
}
ms"
)
print
(
f
" 平均每次执行时间:
{
elapsed_time_ms
/
num_runs
:.
4
f
}
ms"
)
print
(
f
" 每次运行FLOPS:
{
flops_per_run
/
1e9
:.
2
f
}
GFLOPS"
)
print
(
f
" 每次运行FLOPS:
{
flops_per_run
/
1e9
:.
2
f
}
GFLOPS"
)
print
(
f
" 总FLOPS:
{
total_flops
/
1e12
:.
2
f
}
TFLOPS"
)
print
(
f
" 总FLOPS:
{
total_flops
/
1e12
:.
2
f
}
TFLOPS"
)
print
(
f
" 计算性能:
{
tflops
:.
2
f
}
TFLOPS"
)
print
(
f
" 计算性能:
{
tflops
:.
2
f
}
TFLOPS"
)
return
tflops
return
tflops
...
@@ -93,11 +93,9 @@ if __name__ == "__main__":
...
@@ -93,11 +93,9 @@ if __name__ == "__main__":
((
257
,
5120
),
(
5120
,
5120
)),
((
257
,
5120
),
(
5120
,
5120
)),
((
32130
,
5120
),
(
13824
,
5120
)),
((
32130
,
5120
),
(
13824
,
5120
)),
((
32130
,
13824
),
(
5120
,
13824
)),
((
32130
,
13824
),
(
5120
,
13824
)),
((
75348
,
5120
),
(
5120
,
5120
)),
((
75348
,
5120
),
(
5120
,
5120
)),
((
75348
,
5120
),
(
13824
,
5120
)),
((
75348
,
5120
),
(
13824
,
5120
)),
((
75348
,
13824
),
(
5120
,
13824
)),
((
75348
,
13824
),
(
5120
,
13824
)),
((
32760
,
1536
),
(
1536
,
1536
)),
((
32760
,
1536
),
(
1536
,
1536
)),
((
512
,
1536
),
(
1536
,
1536
)),
((
512
,
1536
),
(
1536
,
1536
)),
((
32760
,
1536
),
(
8960
,
1536
)),
((
32760
,
1536
),
(
8960
,
1536
)),
...
@@ -107,7 +105,7 @@ if __name__ == "__main__":
...
@@ -107,7 +105,7 @@ if __name__ == "__main__":
print
(
"=== test_mm TFLOPS性能测试 ===
\n
"
)
print
(
"=== test_mm TFLOPS性能测试 ===
\n
"
)
for
i
,
(
input_shape
,
weight_shape
)
in
enumerate
(
test_cases
):
for
i
,
(
input_shape
,
weight_shape
)
in
enumerate
(
test_cases
):
print
(
f
"测试
{
i
+
1
}
: 输入形状
{
input_shape
}
, 权重形状
{
weight_shape
}
"
)
print
(
f
"测试
{
i
+
1
}
: 输入形状
{
input_shape
}
, 权重形状
{
weight_shape
}
"
)
print
(
"-"
*
60
)
print
(
"-"
*
60
)
tflops
=
test_tflops
(
input_shape
,
weight_shape
)
tflops
=
test_tflops
(
input_shape
,
weight_shape
)
...
...
lightx2v_kernel/test/test_quant_mem_utils.py
View file @
c7bb59cd
...
@@ -4,6 +4,7 @@ from lightx2v_kernel.gemm import scaled_fp4_quant
...
@@ -4,6 +4,7 @@ from lightx2v_kernel.gemm import scaled_fp4_quant
input_global_scale
=
torch
.
tensor
(
808.0
,
dtype
=
torch
.
float32
).
cuda
()
input_global_scale
=
torch
.
tensor
(
808.0
,
dtype
=
torch
.
float32
).
cuda
()
def
quantize_fp4
(
x
):
def
quantize_fp4
(
x
):
return
scaled_fp4_quant
(
x
,
input_global_scale
)
return
scaled_fp4_quant
(
x
,
input_global_scale
)
...
@@ -40,7 +41,6 @@ def test_memory_bandwidth(func, x, num_warmup=10, num_runs=100):
...
@@ -40,7 +41,6 @@ def test_memory_bandwidth(func, x, num_warmup=10, num_runs=100):
# FP4量化后,每个元素占用0.5字节
# FP4量化后,每个元素占用0.5字节
output_bytes
=
x
.
numel
()
*
0.5
# FP4输出数据字节数
output_bytes
=
x
.
numel
()
*
0.5
# FP4输出数据字节数
scale_bytes
=
x
.
numel
()
/
16
# group_size = 16
scale_bytes
=
x
.
numel
()
/
16
# group_size = 16
# 总数据传输量(读取输入 + 写入输出 + scale)
# 总数据传输量(读取输入 + 写入输出 + scale)
...
@@ -54,7 +54,7 @@ def test_memory_bandwidth(func, x, num_warmup=10, num_runs=100):
...
@@ -54,7 +54,7 @@ def test_memory_bandwidth(func, x, num_warmup=10, num_runs=100):
print
(
f
" 输入数据类型:
{
x
.
dtype
}
"
)
print
(
f
" 输入数据类型:
{
x
.
dtype
}
"
)
print
(
f
" 运行次数:
{
num_runs
}
"
)
print
(
f
" 运行次数:
{
num_runs
}
"
)
print
(
f
" 总执行时间:
{
elapsed_time_ms
:.
2
f
}
ms"
)
print
(
f
" 总执行时间:
{
elapsed_time_ms
:.
2
f
}
ms"
)
print
(
f
" 平均每次执行时间:
{
elapsed_time_ms
/
num_runs
:.
4
f
}
ms"
)
print
(
f
" 平均每次执行时间:
{
elapsed_time_ms
/
num_runs
:.
4
f
}
ms"
)
print
(
f
" 输入数据大小:
{
input_bytes
/
(
1024
**
2
):.
2
f
}
MB"
)
print
(
f
" 输入数据大小:
{
input_bytes
/
(
1024
**
2
):.
2
f
}
MB"
)
print
(
f
" 输出数据大小:
{
output_bytes
/
(
1024
**
2
):.
2
f
}
MB"
)
print
(
f
" 输出数据大小:
{
output_bytes
/
(
1024
**
2
):.
2
f
}
MB"
)
print
(
f
" 总数据传输量:
{
total_bytes
/
(
1024
**
3
):.
2
f
}
GB"
)
print
(
f
" 总数据传输量:
{
total_bytes
/
(
1024
**
3
):.
2
f
}
GB"
)
...
@@ -132,16 +132,12 @@ if __name__ == "__main__":
...
@@ -132,16 +132,12 @@ if __name__ == "__main__":
# (32768, 8192),
# (32768, 8192),
# (32768, 16384),
# (32768, 16384),
# (32768, 32768),
# (32768, 32768),
(
32130
,
5120
),
(
32130
,
5120
),
(
512
,
5120
),
(
512
,
5120
),
(
257
,
5120
),
(
257
,
5120
),
(
32130
,
13824
),
(
32130
,
13824
),
(
75348
,
5120
),
(
75348
,
5120
),
(
75348
,
13824
),
(
75348
,
13824
),
(
32760
,
1536
),
(
32760
,
1536
),
(
512
,
1536
),
(
512
,
1536
),
(
32760
,
8960
),
(
32760
,
8960
),
...
@@ -150,7 +146,7 @@ if __name__ == "__main__":
...
@@ -150,7 +146,7 @@ if __name__ == "__main__":
print
(
"=== quantize_fp4 显存带宽测试 ===
\n
"
)
print
(
"=== quantize_fp4 显存带宽测试 ===
\n
"
)
for
i
,
(
h
,
w
)
in
enumerate
(
test_sizes
):
for
i
,
(
h
,
w
)
in
enumerate
(
test_sizes
):
print
(
f
"测试
{
i
+
1
}
: 张量大小 (
{
h
}
,
{
w
}
)"
)
print
(
f
"测试
{
i
+
1
}
: 张量大小 (
{
h
}
,
{
w
}
)"
)
print
(
"-"
*
50
)
print
(
"-"
*
50
)
x
=
torch
.
randn
(
h
,
w
,
dtype
=
torch
.
bfloat16
).
cuda
()
x
=
torch
.
randn
(
h
,
w
,
dtype
=
torch
.
bfloat16
).
cuda
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment