xuwx1 / LightX2V / Commits
Commit 29a90944
authored Jul 21, 2025 by Xtra, committed by GitHub on Jul 21, 2025

add mxfp4 kernels and rename some func for clarity (#148)

parent 505c5a47
Changes: 25 files
Showing 5 changed files with 12 additions and 13 deletions (+12 -13):
lightx2v_kernel/test/nvfp4_nvfp4/test_bench1.py           +4 -4
lightx2v_kernel/test/nvfp4_nvfp4/test_bench2.py           +4 -4
lightx2v_kernel/test/nvfp4_nvfp4/test_bench3_bias.py      +0 -1
lightx2v_kernel/test/nvfp4_nvfp4/test_mm_tflops.py        +2 -2
lightx2v_kernel/test/nvfp4_nvfp4/test_quant_mem_utils.py  +2 -2
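Every diff shown on this page applies the same rename inside lightx2v_kernel.gemm: scaled_fp4_quant becomes scaled_nvfp4_quant and cutlass_scaled_fp4_mm becomes cutlass_scaled_nvfp4_mm. Read together with the commit title, the added "nv" infix presumably keeps the existing NVFP4 path distinct from the newly added MXFP4 kernels, which live in files not shown on this page. In import form, the change each test makes is:

# before this commit
from lightx2v_kernel.gemm import scaled_fp4_quant, cutlass_scaled_fp4_mm
# after this commit
from lightx2v_kernel.gemm import scaled_nvfp4_quant, cutlass_scaled_nvfp4_mm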
lightx2v_kernel/test/nvfp4_nvfp4/test_bench1.py
 import torch
-from lightx2v_kernel.gemm import scaled_fp4_quant, cutlass_scaled_fp4_mm
+from lightx2v_kernel.gemm import scaled_nvfp4_quant, cutlass_scaled_nvfp4_mm
 FLOAT4_E2M1_MAX = 6.0
...
@@ -110,8 +110,8 @@ def test_nvfp4_gemm(
     print(f"b_global_scale : {b_global_scale}, {b_global_scale.shape}")
     alpha = 1.0 / (a_global_scale * b_global_scale)
-    a_fp4, a_scale_interleaved = scaled_fp4_quant(a_dtype, a_global_scale)
-    b_fp4, b_scale_interleaved = scaled_fp4_quant(b_dtype, b_global_scale)
+    a_fp4, a_scale_interleaved = scaled_nvfp4_quant(a_dtype, a_global_scale)
+    b_fp4, b_scale_interleaved = scaled_nvfp4_quant(b_dtype, b_global_scale)
     expected_out = get_ref_results(
         a_fp4,
...
@@ -130,7 +130,7 @@ def test_nvfp4_gemm(
     print(f"alpha {alpha}, {alpha.shape}, {alpha.dtype}")
-    out = cutlass_scaled_fp4_mm(a_fp4, b_fp4, a_scale_interleaved, b_scale_interleaved, alpha, bias)
+    out = cutlass_scaled_nvfp4_mm(a_fp4, b_fp4, a_scale_interleaved, b_scale_interleaved, alpha, bias)
     print(f"out : {out}, {out.shape}, {out.dtype}")
     print(f"expected_out : {expected_out}, {expected_out.shape}, {expected_out.dtype}")
...
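For orientation, here is a minimal sketch of the renamed quantize-then-GEMM flow this test exercises. Only the function names, the argument order, the alpha formula, and the bias = None case come from the diffs on this page; the tensor shapes, the bfloat16 inputs, the (n, k) weight layout, the 2688/amax global-scale rule (borrowed from test_bench2.py below), and the plain torch matmul reference (standing in for get_ref_results, whose body is not shown here) are assumptions.

import torch
from lightx2v_kernel.gemm import scaled_nvfp4_quant, cutlass_scaled_nvfp4_mm

# Assumed problem size and layout: activations (m, k), weights (n, k).
m, k, n = 128, 256, 512
a = torch.randn(m, k, dtype=torch.bfloat16, device="cuda")
b = torch.randn(n, k, dtype=torch.bfloat16, device="cuda")

# Per-tensor global scales; 2688.0 mirrors the constant used in test_bench2.py.
a_global_scale = (2688.0 / torch.max(torch.abs(a))).to(torch.float32)
b_global_scale = (2688.0 / torch.max(torch.abs(b))).to(torch.float32)
alpha = 1.0 / (a_global_scale * b_global_scale)  # as computed in the hunk above

# Quantize to packed NVFP4 plus interleaved block scales, then run the CUTLASS GEMM.
a_fp4, a_scale_interleaved = scaled_nvfp4_quant(a, a_global_scale)
b_fp4, b_scale_interleaved = scaled_nvfp4_quant(b, b_global_scale)
out = cutlass_scaled_nvfp4_mm(a_fp4, b_fp4, a_scale_interleaved, b_scale_interleaved, alpha, None)

# Rough reference in bf16 (the test's own get_ref_results is not reproduced here).
expected_out = a @ b.t()
print((out.float() - expected_out.float()).abs().max())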
lightx2v_kernel/test/nvfp4_nvfp4/test_bench2.py
 import torch
-from lightx2v_kernel.gemm import scaled_fp4_quant, cutlass_scaled_fp4_mm
+from lightx2v_kernel.gemm import scaled_nvfp4_quant, cutlass_scaled_nvfp4_mm
 import time
...
@@ -14,13 +14,13 @@ class MMWeightFp4:
     @torch.no_grad()
     def apply(self, input_tensor):
         input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
-        output_tensor = cutlass_scaled_fp4_mm(input_tensor_quant, self.weight, input_tensor_scale, self.weight_scale, alpha=self.alpha, bias=self.bias)
+        output_tensor = cutlass_scaled_nvfp4_mm(input_tensor_quant, self.weight, input_tensor_scale, self.weight_scale, alpha=self.alpha, bias=self.bias)
         return output_tensor

     @torch.no_grad()
     def load_fp4_weight(self, weight, bias):
         self.weight_global_scale = (2688.0 / torch.max(torch.abs(weight))).to(torch.float32)
-        self.weight, self.weight_scale = scaled_fp4_quant(weight, self.weight_global_scale)
+        self.weight, self.weight_scale = scaled_nvfp4_quant(weight, self.weight_global_scale)
         self.bias = bias

     def calibrate_x_absmax(self):
...
@@ -30,7 +30,7 @@ class MMWeightFp4:
     @torch.no_grad()
     def act_quant_fp4(self, x):
-        return scaled_fp4_quant(x, self.input_global_scale)
+        return scaled_nvfp4_quant(x, self.input_global_scale)

 def test_speed(m, k, n):
...
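One detail worth unpacking from load_fp4_weight above: 2688.0 is 448.0 × 6.0, i.e. the maximum of float8_e4m3 times FLOAT4_E2M1_MAX (the constant defined at the top of test_bench1.py), which is the usual per-tensor global scale for NVFP4 quantization with FP8 block scales. A hedged restatement of that arithmetic:

import torch

FLOAT4_E2M1_MAX = 6.0    # from test_bench1.py
FLOAT8_E4M3_MAX = 448.0  # assumption: max normal value of torch.float8_e4m3fn

def weight_global_scale(weight: torch.Tensor) -> torch.Tensor:
    # Same expression as MMWeightFp4.load_fp4_weight, with 2688.0 spelled out as 448.0 * 6.0.
    return (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / torch.max(torch.abs(weight))).to(torch.float32)

With both operands quantized this way, the alpha = 1.0 / (a_global_scale * b_global_scale) used in test_bench1.py undoes the two global scalings in the GEMM epilogue.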
lightx2v_kernel/test/nvfp4_nvfp4/test_bench3_bias.py
 import torch
-from lightx2v_kernel.gemm import scaled_fp4_quant, cutlass_scaled_fp4_mm
 import time
 from test_bench2 import MMWeightFp4
...
lightx2v_kernel/test/nvfp4_nvfp4/test_mm_tflops.py
 import torch
-from lightx2v_kernel.gemm import cutlass_scaled_fp4_mm
+from lightx2v_kernel.gemm import cutlass_scaled_nvfp4_mm
 """
...
@@ -16,7 +16,7 @@ bias = None
 def test_mm(input_tensor_quant, weight, input_tensor_scale, weight_scale, alpha, bias):
-    output_tensor = cutlass_scaled_fp4_mm(input_tensor_quant, weight, input_tensor_scale, weight_scale, alpha=alpha, bias=bias)
+    output_tensor = cutlass_scaled_nvfp4_mm(input_tensor_quant, weight, input_tensor_scale, weight_scale, alpha=alpha, bias=bias)
     return output_tensor
...
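Only the test_mm wrapper is visible on this page, so the timing loop below is a sketch of how a TFLOPS figure is usually obtained around such a call; the CUDA-event timing and the 2·m·n·k FLOP count are assumptions, not code from this file.

import torch

def measure_tflops(fn, m, n, k, num_warmup=10, num_runs=100):
    # Warm up so kernel selection and caching do not pollute the measurement.
    for _ in range(num_warmup):
        fn()
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(num_runs):
        fn()
    end.record()
    torch.cuda.synchronize()
    seconds = start.elapsed_time(end) / 1000.0 / num_runs  # elapsed_time is in ms
    return 2 * m * n * k / seconds / 1e12  # one (m, n, k) GEMM is ~2*m*n*k FLOPs

# e.g. measure_tflops(lambda: test_mm(input_tensor_quant, weight, input_tensor_scale, weight_scale, alpha, bias), m, n, k)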
lightx2v_kernel/test/nvfp4_nvfp4/test_quant_mem_utils.py
 import torch
-from lightx2v_kernel.gemm import scaled_fp4_quant
+from lightx2v_kernel.gemm import scaled_nvfp4_quant

 input_global_scale = torch.tensor(808.0, dtype=torch.float32).cuda()

 def quantize_fp4(x):
-    return scaled_fp4_quant(x, input_global_scale)
+    return scaled_nvfp4_quant(x, input_global_scale)

 def test_memory_bandwidth(func, x, num_warmup=10, num_runs=100):
...
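test_memory_bandwidth is visible only down to its signature here, so the body below is a guess at the usual shape of such a benchmark: time quantize_fp4 with CUDA events and divide an estimate of bytes moved by the elapsed time. The byte accounting (bf16 input read, packed FP4 output plus block scales written) is an assumption.

import torch

def test_memory_bandwidth(func, x, num_warmup=10, num_runs=100):
    # Sketch only; the file's real implementation is not shown on this page.
    for _ in range(num_warmup):
        func(x)
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(num_runs):
        out, scales = func(x)  # quantize_fp4 returns (packed fp4, block scales)
    end.record()
    torch.cuda.synchronize()
    seconds = start.elapsed_time(end) / 1000.0 / num_runs
    bytes_moved = (x.numel() * x.element_size()
                   + out.numel() * out.element_size()
                   + scales.numel() * scales.element_size())
    print(f"~{bytes_moved / seconds / 1e9:.1f} GB/s")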