Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
xuwx1
LightX2V
Commits
29a90944
Commit
29a90944
authored
Jul 21, 2025
by
Xtra
Committed by
GitHub
Jul 21, 2025
Browse files
add mxfp4 kernels and rename some func for clarity (#148)
parent
505c5a47
Changes
25
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
12 additions
and
13 deletions
+12
-13
lightx2v_kernel/test/nvfp4_nvfp4/test_bench1.py
lightx2v_kernel/test/nvfp4_nvfp4/test_bench1.py
+4
-4
lightx2v_kernel/test/nvfp4_nvfp4/test_bench2.py
lightx2v_kernel/test/nvfp4_nvfp4/test_bench2.py
+4
-4
lightx2v_kernel/test/nvfp4_nvfp4/test_bench3_bias.py
lightx2v_kernel/test/nvfp4_nvfp4/test_bench3_bias.py
+0
-1
lightx2v_kernel/test/nvfp4_nvfp4/test_mm_tflops.py
lightx2v_kernel/test/nvfp4_nvfp4/test_mm_tflops.py
+2
-2
lightx2v_kernel/test/nvfp4_nvfp4/test_quant_mem_utils.py
lightx2v_kernel/test/nvfp4_nvfp4/test_quant_mem_utils.py
+2
-2
No files found.
lightx2v_kernel/test/nvfp4_nvfp4/test_bench1.py
View file @
29a90944
import
torch
from
lightx2v_kernel.gemm
import
scaled_fp4_quant
,
cutlass_scaled_fp4_mm
from
lightx2v_kernel.gemm
import
scaled_
nv
fp4_quant
,
cutlass_scaled_
nv
fp4_mm
FLOAT4_E2M1_MAX
=
6.0
...
...
@@ -110,8 +110,8 @@ def test_nvfp4_gemm(
print
(
f
"b_global_scale :
{
b_global_scale
}
,
{
b_global_scale
.
shape
}
"
)
alpha
=
1.0
/
(
a_global_scale
*
b_global_scale
)
a_fp4
,
a_scale_interleaved
=
scaled_fp4_quant
(
a_dtype
,
a_global_scale
)
b_fp4
,
b_scale_interleaved
=
scaled_fp4_quant
(
b_dtype
,
b_global_scale
)
a_fp4
,
a_scale_interleaved
=
scaled_
nv
fp4_quant
(
a_dtype
,
a_global_scale
)
b_fp4
,
b_scale_interleaved
=
scaled_
nv
fp4_quant
(
b_dtype
,
b_global_scale
)
expected_out
=
get_ref_results
(
a_fp4
,
...
...
@@ -130,7 +130,7 @@ def test_nvfp4_gemm(
print
(
f
"alpha
{
alpha
}
,
{
alpha
.
shape
}
,
{
alpha
.
dtype
}
"
)
out
=
cutlass_scaled_fp4_mm
(
a_fp4
,
b_fp4
,
a_scale_interleaved
,
b_scale_interleaved
,
alpha
,
bias
)
out
=
cutlass_scaled_
nv
fp4_mm
(
a_fp4
,
b_fp4
,
a_scale_interleaved
,
b_scale_interleaved
,
alpha
,
bias
)
print
(
f
"out :
{
out
}
,
{
out
.
shape
}
,
{
out
.
dtype
}
"
)
print
(
f
"expected_out :
{
expected_out
}
,
{
expected_out
.
shape
}
,
{
expected_out
.
dtype
}
"
)
...
...
lightx2v_kernel/test/nvfp4_nvfp4/test_bench2.py
View file @
29a90944
import
torch
from
lightx2v_kernel.gemm
import
scaled_fp4_quant
,
cutlass_scaled_fp4_mm
from
lightx2v_kernel.gemm
import
scaled_
nv
fp4_quant
,
cutlass_scaled_
nv
fp4_mm
import
time
...
...
@@ -14,13 +14,13 @@ class MMWeightFp4:
@torch.no_grad()
def apply(self, input_tensor):
    """Quantize the activation and run the NVFP4 GEMM against the stored weight.

    Args:
        input_tensor: activation tensor to multiply with ``self.weight``.
            (Assumes a 2-D [m, k] layout as expected by the quant kernel —
            TODO confirm against the kernel's contract.)

    Returns:
        The GEMM output ``input_tensor @ weight`` with the dequantization
        scales folded in via ``alpha`` and optional ``bias`` added.
    """
    # Per-call activation quantization: produces the packed FP4 tensor and
    # its interleaved block scales.
    input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
    # Post-rename kernel entry point (was cutlass_scaled_fp4_mm before #148).
    output_tensor = cutlass_scaled_nvfp4_mm(
        input_tensor_quant,
        self.weight,
        input_tensor_scale,
        self.weight_scale,
        alpha=self.alpha,
        bias=self.bias,
    )
    return output_tensor
@torch.no_grad()
def load_fp4_weight(self, weight, bias):
    """Quantize and store a weight tensor (and bias) in NVFP4 format.

    Args:
        weight: full-precision weight tensor to quantize.
        bias: optional bias tensor, stored unmodified.

    Side effects:
        Sets ``self.weight_global_scale``, ``self.weight``,
        ``self.weight_scale`` and ``self.bias``.
    """
    # Global scale maps the weight's absolute max onto the representable
    # range. NOTE(review): 2688.0 presumably = 6.0 (FP4 E2M1 max) * 448.0
    # (FP8 E4M3 max) per the NVFP4 two-level scaling scheme — confirm
    # against the kernel documentation.
    self.weight_global_scale = (
        2688.0 / torch.max(torch.abs(weight))
    ).to(torch.float32)
    # Post-rename quant kernel (was scaled_fp4_quant before #148); returns
    # the packed FP4 weight and its interleaved block scales.
    self.weight, self.weight_scale = scaled_nvfp4_quant(
        weight, self.weight_global_scale
    )
    self.bias = bias
def
calibrate_x_absmax
(
self
):
...
...
@@ -30,7 +30,7 @@ class MMWeightFp4:
@torch.no_grad()
def act_quant_fp4(self, x):
    """Quantize activation ``x`` to NVFP4 using the calibrated global scale.

    Args:
        x: activation tensor to quantize.

    Returns:
        Tuple of (packed FP4 tensor, interleaved block scales), as produced
        by ``scaled_nvfp4_quant``.
    """
    # Post-rename quant kernel (was scaled_fp4_quant before #148).
    return scaled_nvfp4_quant(x, self.input_global_scale)
def
test_speed
(
m
,
k
,
n
):
...
...
lightx2v_kernel/test/nvfp4_nvfp4/test_bench3_bias.py
View file @
29a90944
import
torch
from
lightx2v_kernel.gemm
import
scaled_fp4_quant
,
cutlass_scaled_fp4_mm
import
time
from
test_bench2
import
MMWeightFp4
...
...
lightx2v_kernel/test/nvfp4_nvfp4/test_mm_tflops.py
View file @
29a90944
import
torch
from
lightx2v_kernel.gemm
import
cutlass_scaled_fp4_mm
from
lightx2v_kernel.gemm
import
cutlass_scaled_
nv
fp4_mm
"""
...
...
@@ -16,7 +16,7 @@ bias = None
def test_mm(input_tensor_quant, weight, input_tensor_scale, weight_scale, alpha, bias):
    """Run a single NVFP4 GEMM for the TFLOPS benchmark.

    Args:
        input_tensor_quant: packed FP4 activation tensor.
        weight: packed FP4 weight tensor.
        input_tensor_scale: interleaved block scales for the activation.
        weight_scale: interleaved block scales for the weight.
        alpha: combined dequantization scale applied to the accumulator.
        bias: optional bias tensor (may be None).

    Returns:
        The GEMM output tensor.
    """
    # Post-rename kernel entry point (was cutlass_scaled_fp4_mm before #148).
    output_tensor = cutlass_scaled_nvfp4_mm(
        input_tensor_quant,
        weight,
        input_tensor_scale,
        weight_scale,
        alpha=alpha,
        bias=bias,
    )
    return output_tensor
...
...
lightx2v_kernel/test/nvfp4_nvfp4/test_quant_mem_utils.py
View file @
29a90944
import
torch
from
lightx2v_kernel.gemm
import
scaled_fp4_quant
from
lightx2v_kernel.gemm
import
scaled_
nv
fp4_quant
input_global_scale
=
torch
.
tensor
(
808.0
,
dtype
=
torch
.
float32
).
cuda
()
def quantize_fp4(x):
    """Quantize ``x`` to NVFP4 with the module-level ``input_global_scale``.

    Args:
        x: tensor to quantize.

    Returns:
        Tuple of (packed FP4 tensor, interleaved block scales), as produced
        by ``scaled_nvfp4_quant``.
    """
    # Post-rename quant kernel (was scaled_fp4_quant before #148);
    # input_global_scale is the module-level calibration constant (808.0).
    return scaled_nvfp4_quant(x, input_global_scale)
def
test_memory_bandwidth
(
func
,
x
,
num_warmup
=
10
,
num_runs
=
100
):
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment