xuwx1 / LightX2V / Commits / 99d12b98

Commit 99d12b98, authored Jul 18, 2025 by Xtra, committed by GitHub on Jul 18, 2025

fix bias epilogue (#141)

parent a5138ed3
Showing 8 changed files with 243 additions and 18 deletions (+243, -18)
lightx2v_kernel/csrc/gemm/mxfp6_mxfp8_scaled_mm_kernels_sm120.cu  +15 -4
lightx2v_kernel/csrc/gemm/mxfp8_scaled_mm_kernels_sm120.cu        +15 -4
lightx2v_kernel/csrc/gemm/nvfp4_scaled_mm_kernels_sm120.cu        +15 -4
lightx2v_kernel/test/mxfp6_mxfp8/test_bench3_bias.py              +94 -0
lightx2v_kernel/test/mxfp6_mxfp8/test_mxfp6_quant.py              +4  -2
lightx2v_kernel/test/mxfp8_mxfp8/test_bench3_bias.py              +94 -0
lightx2v_kernel/test/mxfp8_mxfp8/test_mxfp8_quant.py              +4  -2
lightx2v_kernel/test/nvfp4_nvfp4/test_bench3_bias.py              +2  -2
lightx2v_kernel/csrc/gemm/mxfp6_mxfp8_scaled_mm_kernels_sm120.cu

```diff
@@ -4,6 +4,7 @@
 // clang-format off
 #include "cutlass/cutlass.h"
+#include "cutlass/epilogue/fusion/operations.hpp"
 #include "cutlass/gemm/collective/collective_builder.hpp"
 #include "cutlass/epilogue/collective/collective_builder.hpp"
 #include "cutlass/gemm/device/gemm_universal_adapter.h"
@@ -60,6 +61,9 @@ struct Mxfp6Mxfp8GemmSm120 {
   using ThreadBlockShape = Shape<_128, _128, _128>;  // Threadblock's tile size
   using ClusterShape = Shape<_1, _1, _1>;            // Shape of the threadblocks in a cluster

+  // use per-column bias, i.e. every column has different bias
+  using EVTOp = cutlass::epilogue::fusion::LinCombPerColBias<ElementD, ElementAccumulator>;
+
   using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
       ArchTag, OperatorClass,
       ThreadBlockShape, ClusterShape,
@@ -67,7 +71,8 @@ struct Mxfp6Mxfp8GemmSm120 {
       ElementAccumulator, ElementAccumulator,
       ElementC, LayoutCTag, AlignmentC,
       ElementD, LayoutDTag, AlignmentD,
-      cutlass::epilogue::collective::EpilogueScheduleAuto  // Epilogue schedule policy
+      cutlass::epilogue::collective::EpilogueScheduleAuto,  // Epilogue schedule policy
+      EVTOp
       >::CollectiveOp;

   using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
@@ -127,7 +132,7 @@ typename Mxfp6Mxfp8GemmSm120::Gemm::Arguments args_from_options_mxfp6_mxfp8(
   auto layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(m, n, k, 1));

   if (bias){
-    auto stride_bias = cutlass::make_cute_packed_stride(Mxfp6Mxfp8GemmSm120::StrideC{}, {});
+    using StrideBias = Stride<cutlass::_0, cutlass::_1, int64_t>;
     typename Mxfp6Mxfp8GemmSm120::Gemm::Arguments arguments{
       cutlass::gemm::GemmUniversalMode::kGemm,
@@ -143,12 +148,16 @@ typename Mxfp6Mxfp8GemmSm120::Gemm::Arguments args_from_options_mxfp6_mxfp8(
        layout_SFB},
       {  // Epilogue arguments
        {},  // epilogue.thread
-       static_cast<Mxfp6Mxfp8GemmSm120::Gemm::ElementC const*>(bias->data_ptr()),
-       stride_bias,
+       static_cast<Mxfp6Mxfp8GemmSm120::Gemm::ElementC const*>(D.data_ptr()),
+       stride_D,
        static_cast<Mxfp6Mxfp8GemmSm120::Gemm::ElementD*>(D.data_ptr()),
        stride_D}};
     auto& fusion_args = arguments.epilogue.thread;
     fusion_args.alpha_ptr = static_cast<float const*>(alpha.data_ptr());
+    static const float beta_zero = 0.0f;
+    fusion_args.beta_ptr = &beta_zero;
+    fusion_args.bias_ptr = static_cast<Mxfp6Mxfp8GemmSm120::Gemm::ElementC const*>(bias->data_ptr());
+    fusion_args.dBias = StrideBias{};
     return arguments;
   } else {
     typename Mxfp6Mxfp8GemmSm120::Gemm::Arguments arguments{
@@ -171,6 +180,8 @@ typename Mxfp6Mxfp8GemmSm120::Gemm::Arguments args_from_options_mxfp6_mxfp8(
        stride_D}};
     auto& fusion_args = arguments.epilogue.thread;
     fusion_args.alpha_ptr = static_cast<float const*>(alpha.data_ptr());
+    static const float beta_zero = 0.0f;
+    fusion_args.beta_ptr = &beta_zero;
     return arguments;
   }
 }
```
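The same three-part change is applied to all three SM120 GEMM kernels below: the epilogue is built with cutlass::epilogue::fusion::LinCombPerColBias instead of the default linear combination; the bias pointer moves out of the C operand slot (where it sat before the fix) into fusion_args.bias_ptr with a Stride<_0, _1, int64_t> broadcast, i.e. one value per output column repeated down the rows; and beta is pinned to 0.0f so the C operand, which now aliases D's buffer, cannot contribute to the result. A minimal PyTorch sketch of the intended numerics, assuming an m-by-n output and a (1, n) bias as in the tests below; epilogue_reference is a hypothetical helper, not code from this repository:

```python
import torch

def epilogue_reference(a, b, alpha, bias):
    # D = alpha * (A @ B^T) + bias, with beta pinned to 0 so the C operand
    # (aliased to D's storage in the fixed kernels) never enters the result.
    acc = a.float() @ b.float().t()   # fp32 accumulation (ElementAccumulator)
    d = alpha * acc + bias.float()    # (1, n) bias broadcasts across the m rows
    return d.to(torch.bfloat16)       # ElementD
```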
lightx2v_kernel/csrc/gemm/mxfp8_scaled_mm_kernels_sm120.cu

```diff
@@ -4,6 +4,7 @@
 // clang-format off
 #include "cutlass/cutlass.h"
+#include "cutlass/epilogue/fusion/operations.hpp"
 #include "cutlass/gemm/collective/collective_builder.hpp"
 #include "cutlass/epilogue/collective/collective_builder.hpp"
 #include "cutlass/gemm/device/gemm_universal_adapter.h"
@@ -60,6 +61,9 @@ struct Mxfp8GemmSm120 {
   using ThreadBlockShape = Shape<_128, _128, _128>;  // Threadblock's tile size
   using ClusterShape = Shape<_1, _1, _1>;            // Shape of the threadblocks in a cluster

+  // use per-column bias, i.e. every column has different bias
+  using EVTOp = cutlass::epilogue::fusion::LinCombPerColBias<ElementD, ElementAccumulator>;
+
   using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
       ArchTag, OperatorClass,
       ThreadBlockShape, ClusterShape,
@@ -67,7 +71,8 @@ struct Mxfp8GemmSm120 {
       ElementAccumulator, ElementAccumulator,
       ElementC, LayoutCTag, AlignmentC,
       ElementD, LayoutDTag, AlignmentD,
-      cutlass::epilogue::collective::EpilogueScheduleAuto  // Epilogue schedule policy
+      cutlass::epilogue::collective::EpilogueScheduleAuto,  // Epilogue schedule policy
+      EVTOp
       >::CollectiveOp;

   using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
@@ -127,7 +132,7 @@ typename Mxfp8GemmSm120::Gemm::Arguments args_from_options_mxfp8(
   auto layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(m, n, k, 1));

   if (bias){
-    auto stride_bias = cutlass::make_cute_packed_stride(Mxfp8GemmSm120::StrideC{}, {});
+    using StrideBias = Stride<cutlass::_0, cutlass::_1, int64_t>;
     typename Mxfp8GemmSm120::Gemm::Arguments arguments{
       cutlass::gemm::GemmUniversalMode::kGemm,
@@ -143,12 +148,16 @@ typename Mxfp8GemmSm120::Gemm::Arguments args_from_options_mxfp8(
        layout_SFB},
       {  // Epilogue arguments
        {},  // epilogue.thread
-       static_cast<Mxfp8GemmSm120::Gemm::ElementC const*>(bias->data_ptr()),
-       stride_bias,
+       static_cast<Mxfp8GemmSm120::Gemm::ElementC const*>(D.data_ptr()),
+       stride_D,
        static_cast<Mxfp8GemmSm120::Gemm::ElementD*>(D.data_ptr()),
        stride_D}};
     auto& fusion_args = arguments.epilogue.thread;
     fusion_args.alpha_ptr = static_cast<float const*>(alpha.data_ptr());
+    static const float beta_zero = 0.0f;
+    fusion_args.beta_ptr = &beta_zero;
+    fusion_args.bias_ptr = static_cast<Mxfp8GemmSm120::Gemm::ElementC const*>(bias->data_ptr());
+    fusion_args.dBias = StrideBias{};
     return arguments;
   } else {
     typename Mxfp8GemmSm120::Gemm::Arguments arguments{
@@ -171,6 +180,8 @@ typename Mxfp8GemmSm120::Gemm::Arguments args_from_options_mxfp8(
        stride_D}};
     auto& fusion_args = arguments.epilogue.thread;
     fusion_args.alpha_ptr = static_cast<float const*>(alpha.data_ptr());
+    static const float beta_zero = 0.0f;
+    fusion_args.beta_ptr = &beta_zero;
     return arguments;
   }
 }
```
lightx2v_kernel/csrc/gemm/nvfp4_scaled_mm_kernels_sm120.cu

```diff
@@ -4,6 +4,7 @@
 // clang-format off
 #include "cutlass/cutlass.h"
+#include "cutlass/epilogue/fusion/operations.hpp"
 #include "cutlass/gemm/collective/collective_builder.hpp"
 #include "cutlass/epilogue/collective/collective_builder.hpp"
 #include "cutlass/gemm/device/gemm_universal_adapter.h"
@@ -60,6 +61,9 @@ struct Fp4GemmSm120 {
   using ThreadBlockShape = Shape<_128, _128, _128>;  // Threadblock's tile size
   using ClusterShape = Shape<_1, _1, _1>;            // Shape of the threadblocks in a cluster

+  // use per-column bias, i.e. every column has different bias
+  using EVTOp = cutlass::epilogue::fusion::LinCombPerColBias<ElementD, ElementAccumulator>;
+
   using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
       ArchTag, OperatorClass,
       ThreadBlockShape, ClusterShape,
@@ -67,7 +71,8 @@ struct Fp4GemmSm120 {
       ElementAccumulator, ElementAccumulator,
       ElementC, LayoutCTag, AlignmentC,
       ElementD, LayoutDTag, AlignmentD,
-      cutlass::epilogue::collective::EpilogueScheduleAuto  // Epilogue schedule policy
+      cutlass::epilogue::collective::EpilogueScheduleAuto,  // Epilogue schedule policy
+      EVTOp
       >::CollectiveOp;

   using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
@@ -127,7 +132,7 @@ typename Fp4GemmSm120::Gemm::Arguments args_from_options_nvfp4_nvfp4(
   auto layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(m, n, k, 1));

   if (bias){
-    auto stride_bias = cutlass::make_cute_packed_stride(Fp4GemmSm120::StrideC{}, {});
+    using StrideBias = Stride<cutlass::_0, cutlass::_1, int64_t>;
     typename Fp4GemmSm120::Gemm::Arguments arguments{
       cutlass::gemm::GemmUniversalMode::kGemm,
@@ -143,12 +148,16 @@ typename Fp4GemmSm120::Gemm::Arguments args_from_options_nvfp4_nvfp4(
        layout_SFB},
       {  // Epilogue arguments
        {},  // epilogue.thread
-       static_cast<Fp4GemmSm120::Gemm::ElementC const*>(bias->data_ptr()),
-       stride_bias,
+       static_cast<Fp4GemmSm120::Gemm::ElementC const*>(D.data_ptr()),
+       stride_D,
        static_cast<Fp4GemmSm120::Gemm::ElementD*>(D.data_ptr()),
        stride_D}};
     auto& fusion_args = arguments.epilogue.thread;
     fusion_args.alpha_ptr = static_cast<float const*>(alpha.data_ptr());
+    static const float beta_zero = 0.0f;
+    fusion_args.beta_ptr = &beta_zero;
+    fusion_args.bias_ptr = static_cast<Fp4GemmSm120::Gemm::ElementC const*>(bias->data_ptr());
+    fusion_args.dBias = StrideBias{};
     return arguments;
   } else {
     typename Fp4GemmSm120::Gemm::Arguments arguments{
@@ -171,6 +180,8 @@ typename Fp4GemmSm120::Gemm::Arguments args_from_options_nvfp4_nvfp4(
        stride_D}};
     auto& fusion_args = arguments.epilogue.thread;
     fusion_args.alpha_ptr = static_cast<float const*>(alpha.data_ptr());
+    static const float beta_zero = 0.0f;
+    fusion_args.beta_ptr = &beta_zero;
     return arguments;
   }
 }
```
lightx2v_kernel/test/mxfp6_mxfp8/test_bench3_bias.py (new file, mode 100644)

```python
import torch
import time
from test_bench import MMWeightMxfp8ActMxfp6


def test_speed(m, k, n):
    with torch.no_grad():
        input_tensor = torch.randn(m, k, dtype=torch.bfloat16).cuda()
        weight = torch.randn(n, k, dtype=torch.bfloat16, device="cuda")
        bias = torch.randn(1, n, dtype=torch.bfloat16).cuda()

        mm = MMWeightMxfp8ActMxfp6(weight, bias)

        # warmup
        output_tensor = mm.apply(input_tensor)

        torch.cuda.synchronize()
        start_time = time.time()
        for i in range(100):
            output_tensor = mm.apply(input_tensor)
        torch.cuda.synchronize()
        end_time = time.time()
        lightx2v_kernel_time = (end_time - start_time) / 100
        print(f"lightx2v-kernel time: {lightx2v_kernel_time}")

        input_tensor = torch.randn(m, n, dtype=torch.bfloat16).cuda()
        weight = torch.randn(k, n, dtype=torch.bfloat16, device="cuda")
        bias = torch.randn(1, k, dtype=torch.bfloat16).cuda()

        linear = torch.nn.Linear(k, n, bias=True).cuda()
        linear.weight.data = weight
        linear.bias.data = bias

        # warmup
        ref_output_tensor = linear(input_tensor)

        torch.cuda.synchronize()
        start_time = time.time()
        for i in range(100):
            ref_output_tensor = linear(input_tensor)
        torch.cuda.synchronize()
        end_time = time.time()
        ref_time = (end_time - start_time) / 100
        print(f"ref time: {ref_time}")

        print(f"speedup: {ref_time / lightx2v_kernel_time:.3f}")


def test_accuracy(m, k, n):
    with torch.no_grad():
        input_tensor = torch.randn(m, k, dtype=torch.bfloat16).cuda()
        weight = torch.randn(n, k, dtype=torch.bfloat16, device="cuda")
        bias = torch.ones(1, n, dtype=torch.bfloat16).cuda() * 50

        linear = torch.nn.Linear(k, n, bias=True).cuda()
        linear.weight.data = weight
        linear.bias.data = bias
        ref_output_tensor = linear(input_tensor)

        mm = MMWeightMxfp8ActMxfp6(weight, bias)
        output_tensor = mm.apply(input_tensor)

        # print(f"ref_output_tensor: {ref_output_tensor}")
        # print(f"output_tensor: {output_tensor}")

        # cosine
        cos = torch.nn.functional.cosine_similarity(ref_output_tensor.flatten(), output_tensor.flatten(), dim=0)
        print(f"cos : {cos}")


if __name__ == "__main__":
    test_sizes = [
        (32130, 5120, 5120),
        (512, 5120, 5120),
        (257, 5120, 5120),
        (32130, 5120, 13824),
        (32130, 13824, 5120),
        (75348, 5120, 5120),
        (75348, 13824, 5120),
        (32760, 1536, 1536),
        (512, 1536, 1536),
        (32760, 1536, 8960),
        (32760, 8960, 1536),
    ]

    for i, (m, k, n) in enumerate(test_sizes):
        print("-" * 30)
        print(f"Test {i + 1}: tensor size ({m}, {k}, {n})")
        test_accuracy(m, k, n)
        test_speed(m, k, n)
```
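The host-side timer above is fine once the loop is bracketed by torch.cuda.synchronize(), but CUDA events sidestep host-timer jitter entirely. A sketch of an event-based variant against the same mm.apply interface; time_kernel is a hypothetical helper, not part of this commit:

```python
import torch

def time_kernel(fn, iters=100):
    # Hypothetical event-based alternative to the time.time() loop above.
    fn()  # warmup
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters / 1e3  # elapsed_time is in ms

# usage: time_kernel(lambda: mm.apply(input_tensor))
```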
lightx2v_kernel/test/mxfp6_mxfp8/test_mxfp6_quant.py

```diff
@@ -27,10 +27,12 @@ class TestQuantBF162MXFP6(unittest.TestCase):
         weight = torch.randn(n, k, dtype=self.dtype, device=self.device)
         weight_quant_pred, weight_scale_pred = scaled_fp6_quant(weight)
+        bias = torch.rand(1, n, dtype=self.dtype, device=self.device) * 10
         alpha = torch.tensor(1.0, device=self.device, dtype=torch.float32)

-        mm_pred = cutlass_scaled_mxfp6_mxfp8_mm(activation_quant_pred, weight_quant_pred, activation_scale_pred, weight_scale_pred, alpha=alpha)
+        mm_pred = cutlass_scaled_mxfp6_mxfp8_mm(activation_quant_pred, weight_quant_pred, activation_scale_pred, weight_scale_pred, alpha=alpha, bias=bias)
-        mm_real = linear(activation, weight, bias=None).to(torch.bfloat16)
+        mm_real = linear(activation, weight, bias=bias).to(torch.bfloat16)

         self.assertTrue(error(mm_pred, mm_real) < 1e-2, f"Accuracy test failed for shape {m, k, n}: Error {error(mm_pred, mm_real)} exceeds threshold.")
```
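For context, the full bias path this test now exercises looks roughly like the following. The import path and the activation-side quantizer name (scaled_fp8_quant) are assumptions; only the calls visible in the hunk above are confirmed by the diff:

```python
import torch
from lightx2v_kernel.gemm import scaled_fp6_quant, scaled_fp8_quant, cutlass_scaled_mxfp6_mxfp8_mm  # assumed import path

m, k, n = 512, 5120, 5120
activation = torch.randn(m, k, dtype=torch.bfloat16, device="cuda")
weight = torch.randn(n, k, dtype=torch.bfloat16, device="cuda")
bias = torch.rand(1, n, dtype=torch.bfloat16, device="cuda") * 10

weight_quant, weight_scale = scaled_fp6_quant(weight)  # mxfp6 weights
act_quant, act_scale = scaled_fp8_quant(activation)    # mxfp8 activations (assumed quantizer)
alpha = torch.tensor(1.0, device="cuda", dtype=torch.float32)
out = cutlass_scaled_mxfp6_mxfp8_mm(act_quant, weight_quant, act_scale, weight_scale, alpha=alpha, bias=bias)
```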
lightx2v_kernel/test/mxfp8_mxfp8/test_bench3_bias.py (new file, mode 100644)

```python
import torch
import time
from test_bench import MMWeightMxfp8


def test_speed(m, k, n):
    with torch.no_grad():
        input_tensor = torch.randn(m, k, dtype=torch.bfloat16).cuda()
        weight = torch.randn(n, k, dtype=torch.bfloat16, device="cuda")
        bias = torch.randn(1, n, dtype=torch.bfloat16).cuda()

        mm = MMWeightMxfp8(weight, bias)

        # warmup
        output_tensor = mm.apply(input_tensor)

        torch.cuda.synchronize()
        start_time = time.time()
        for i in range(100):
            output_tensor = mm.apply(input_tensor)
        torch.cuda.synchronize()
        end_time = time.time()
        lightx2v_kernel_time = (end_time - start_time) / 100
        print(f"lightx2v-kernel time: {lightx2v_kernel_time}")

        input_tensor = torch.randn(m, n, dtype=torch.bfloat16).cuda()
        weight = torch.randn(k, n, dtype=torch.bfloat16, device="cuda")
        bias = torch.randn(1, k, dtype=torch.bfloat16).cuda()

        linear = torch.nn.Linear(k, n, bias=True).cuda()
        linear.weight.data = weight
        linear.bias.data = bias

        # warmup
        ref_output_tensor = linear(input_tensor)

        torch.cuda.synchronize()
        start_time = time.time()
        for i in range(100):
            ref_output_tensor = linear(input_tensor)
        torch.cuda.synchronize()
        end_time = time.time()
        ref_time = (end_time - start_time) / 100
        print(f"ref time: {ref_time}")

        print(f"speedup: {ref_time / lightx2v_kernel_time:.3f}")


def test_accuracy(m, k, n):
    with torch.no_grad():
        input_tensor = torch.randn(m, k, dtype=torch.bfloat16).cuda()
        weight = torch.randn(n, k, dtype=torch.bfloat16, device="cuda")
        bias = torch.randn(1, n, dtype=torch.bfloat16).cuda()

        linear = torch.nn.Linear(k, n, bias=True).cuda()
        linear.weight.data = weight
        linear.bias.data = bias
        ref_output_tensor = linear(input_tensor)

        mm = MMWeightMxfp8(weight, bias)
        output_tensor = mm.apply(input_tensor)

        # print(f"ref_output_tensor: {ref_output_tensor}")
        # print(f"output_tensor: {output_tensor}")

        # cosine
        cos = torch.nn.functional.cosine_similarity(ref_output_tensor.flatten(), output_tensor.flatten(), dim=0)
        print(f"cos : {cos}")


if __name__ == "__main__":
    test_sizes = [
        (32130, 5120, 5120),
        (512, 5120, 5120),
        (257, 5120, 5120),
        (32130, 5120, 13824),
        (32130, 13824, 5120),
        (75348, 5120, 5120),
        (75348, 13824, 5120),
        (32760, 1536, 1536),
        (512, 1536, 1536),
        (32760, 1536, 8960),
        (32760, 8960, 1536),
    ]

    for i, (m, k, n) in enumerate(test_sizes):
        print("-" * 30)
        print(f"Test {i + 1}: tensor size ({m}, {k}, {n})")
        test_accuracy(m, k, n)
        test_speed(m, k, n)
```
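Both new benchmark scripts print the cosine similarity rather than asserting on it, so a broken bias epilogue only shows up if someone reads the log. If these were promoted to real tests, a gate like the following would fail loudly; assert_cosine and the 0.999 threshold are assumptions, not part of this commit:

```python
import torch

def assert_cosine(pred, ref, min_cos=0.999):
    # Hypothetical CI gate: fail instead of just printing `cos`.
    cos = torch.nn.functional.cosine_similarity(pred.float().flatten(), ref.float().flatten(), dim=0)
    assert cos.item() > min_cos, f"cosine similarity {cos.item():.5f} below {min_cos}"
```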
lightx2v_kernel/test/mxfp8_mxfp8/test_mxfp8_quant.py

```diff
@@ -27,10 +27,12 @@ class TestQuantBF162MXFP8(unittest.TestCase):
         weight = torch.randn(n, k, dtype=self.dtype, device=self.device)
         weight_quant_pred, weight_scale_pred = scaled_fp8_quant(weight)
+        bias = torch.rand(1, n, dtype=self.dtype, device=self.device) * 10
         alpha = torch.tensor(1.0, device=self.device, dtype=torch.float32)

-        mm_pred = cutlass_scaled_mxfp8_mm(activation_quant_pred, weight_quant_pred, activation_scale_pred, weight_scale_pred, alpha=alpha)
+        mm_pred = cutlass_scaled_mxfp8_mm(activation_quant_pred, weight_quant_pred, activation_scale_pred, weight_scale_pred, alpha=alpha, bias=bias)
-        mm_real = linear(activation, weight, bias=None).to(torch.bfloat16)
+        mm_real = linear(activation, weight, bias=bias).to(torch.bfloat16)

         self.assertTrue(error(mm_pred, mm_real) < 1e-2, f"Accuracy test failed for shape {m, k, n}: Error {error(mm_pred, mm_real)} exceeds threshold.")
```
lightx2v_kernel/test/nvfp4_nvfp4/test_bench3_bias.py

```diff
@@ -8,7 +8,7 @@ def test_speed(m, k, n):
     with torch.no_grad():
         input_tensor = torch.randn(m, k, dtype=torch.bfloat16).cuda()
         weight = torch.randn(n, k, dtype=torch.bfloat16, device="cuda")
-        bias = torch.randn(1, n, dtype=torch.bfloat16).cuda()
+        bias = torch.ones(1, n, dtype=torch.bfloat16).cuda() * 50

         mm = MMWeightFp4(weight, bias)
@@ -53,7 +53,7 @@ def test_accuracy(m, k, n):
     with torch.no_grad():
         input_tensor = torch.randn(m, k, dtype=torch.bfloat16).cuda()
         weight = torch.randn(n, k, dtype=torch.bfloat16, device="cuda")
-        bias = torch.randn(1, n, dtype=torch.bfloat16).cuda()
+        bias = torch.ones(1, n, dtype=torch.bfloat16).cuda() * 50

         linear = torch.nn.Linear(k, n, bias=True).cuda()
         linear.weight.data = weight
```