xuwx1 / LightX2V · Commit e08c4f90

Merge branch 'main' into audio_r2v

Authored Jul 17, 2025 by sandy; committed by GitHub on Jul 17, 2025
Parents: 12bfd120, 6d07a72e
Changes: 191 in the full merge; this page shows 20 changed files with 1022 additions and 209 deletions (+1022 −209)
Changed files:

lightx2v_kernel/csrc/common_extension.cc                    +4    −0
lightx2v_kernel/csrc/gemm/mxfp6_quant_kernels_sm120.cu      +348  −0
lightx2v_kernel/csrc/gemm/mxfp8_quant_kernels_sm120.cu      +1    −1
lightx2v_kernel/include/lightx2v_kernel_ops.h               +2    −0
lightx2v_kernel/python/lightx2v_kernel/gemm.py              +14   −1
lightx2v_kernel/test/mxfp6_mxfp8/test_bench.py              +121  −0
lightx2v_kernel/test/mxfp6_mxfp8/test_fake_quant.py         +181  −0
lightx2v_kernel/test/mxfp6_mxfp8/test_mm_tflops.py          +115  −0
lightx2v_kernel/test/mxfp6_mxfp8/test_mxfp6_quant.py        +49   −0
lightx2v_kernel/test/mxfp6_mxfp8/test_quant_mem_utils.py    +158  −0
requirements.txt                                            +3    −0
requirements_win.txt                                        +19   −0
scripts/cache/readme.md                                     +6    −52
scripts/cache/run_wan_i2v_taylor.sh                         +0    −39
scripts/cache/run_wan_t2v_custom.sh                         +0    −38
scripts/cache/run_wan_t2v_tea.sh                            +1    −1
scripts/cogvideox/readme.md                                 +0    −1
scripts/cogvideox/run_cogvideox_t2v.sh                      +0    −39
scripts/deploy/readme.md                                    +0    −1
scripts/deploy/start_dit_server.sh                          +0    −36
lightx2v_kernel/csrc/common_extension.cc

@@ -20,6 +20,10 @@ TORCH_LIBRARY_FRAGMENT(lightx2v_kernel, m) {
       "scaled_fp8_quant_sm120(Tensor! output, Tensor! input,"
       " Tensor! output_scale) -> ()");
   m.impl("scaled_fp8_quant_sm120", torch::kCUDA, &scaled_fp8_quant_sm120);
+  m.def(
+      "scaled_fp6_quant_sm120(Tensor! output, Tensor! input,"
+      " Tensor! output_scale) -> ()");
+  m.impl("scaled_fp6_quant_sm120", torch::kCUDA, &scaled_fp6_quant_sm120);
   m.def(
       "cutlass_scaled_mxfp6_mxfp8_mm_sm120(Tensor! out, Tensor mat_a, Tensor mat_b, Tensor scales_a, Tensor scales_b, Tensor "
   ...
lightx2v_kernel/csrc/gemm/mxfp6_quant_kernels_sm120.cu (new file, mode 100644)

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_fp8.h>
#include <cuda_fp6.h>
#include <cuda_runtime.h>
#include <cuda_runtime_api.h>
#include <torch/all.h>

#include "utils.h"

// Get type2 from type or vice versa (applied to half and bfloat16)
template <typename T>
struct TypeConverter {
  using Type = half2;
};  // keep for generality

template <>
struct TypeConverter<half2> {
  using Type = half;
};

template <>
struct TypeConverter<half> {
  using Type = half2;
};

template <>
struct TypeConverter<__nv_bfloat162> {
  using Type = __nv_bfloat16;
};

template <>
struct TypeConverter<__nv_bfloat16> {
  using Type = __nv_bfloat162;
};

#define ELTS_PER_THREAD 8

constexpr int CVT_FP6_ELTS_PER_THREAD = 8;
constexpr int CVT_FP6_SF_VEC_SIZE = 32;

struct uint8x6_t {
  uint8_t elts[6];
};

// Convert 4 float2 values into 8 e3m2 values (represented as one uint8x6_t).
inline __device__ uint8x6_t fp32_vec_to_e3m2(float2 (&array)[4]) {
  uint64_t val;
  asm volatile(
      "{\n"
      ".reg .b16 pack0;\n"
      ".reg .b16 pack1;\n"
      ".reg .b16 pack2;\n"
      ".reg .b16 pack3;\n"
      "cvt.rn.satfinite.e3m2x2.f32 pack0, %2, %1;\n"
      "cvt.rn.satfinite.e3m2x2.f32 pack1, %4, %3;\n"
      "cvt.rn.satfinite.e3m2x2.f32 pack2, %6, %5;\n"
      "cvt.rn.satfinite.e3m2x2.f32 pack3, %8, %7;\n"
      "mov.b64 %0, {pack0, pack1, pack2, pack3};\n"
      "}"
      : "=l"(val)
      : "f"(array[0].x), "f"(array[0].y), "f"(array[1].x), "f"(array[1].y),
        "f"(array[2].x), "f"(array[2].y), "f"(array[3].x), "f"(array[3].y));

  uint8x6_t result;
  // pack 8 uint8_t into 6 uint8_t
  // here is how to pack:
  // 4 fp6 values a b c d. a: [a5 a4 a3 a2 a1 a0], b..., c..., d...
  // 3 uint8 values: pack0 pack1 pack2
  // packed0: [b1 b0][a5 a4 a3 a2 a1 a0]
  // packed1: [c3 c2 c1 c0][b5 b4 b3 b2]
  // packed2: [d5 d4 d3 d2 d1 d0][c5 c4]

  // lower 4 uint8_t
  uint8_t l_val_0 = val & 0xFF;
  uint8_t l_val_1 = (val >> 8) & 0xFF;
  uint8_t l_val_2 = (val >> 16) & 0xFF;
  uint8_t l_val_3 = (val >> 24) & 0xFF;

  // higher 4 uint8_t
  uint8_t h_val_0 = (val >> 32) & 0xFF;
  uint8_t h_val_1 = (val >> 40) & 0xFF;
  uint8_t h_val_2 = (val >> 48) & 0xFF;
  uint8_t h_val_3 = (val >> 56) & 0xFF;

  // pack result
  result.elts[0] = (l_val_1 << 6) | l_val_0;
  result.elts[1] = (l_val_2 << 4) | (l_val_1 >> 2);
  result.elts[2] = (l_val_3 << 2) | (l_val_2 >> 4);
  result.elts[3] = (h_val_1 << 6) | h_val_0;
  result.elts[4] = (h_val_2 << 4) | (h_val_1 >> 2);
  result.elts[5] = (h_val_3 << 2) | (h_val_2 >> 4);

  return result;
}

// Fast reciprocal.
inline __device__ float reciprocal_approximate_ftz(float a) {
  float b;
  asm volatile("rcp.approx.ftz.f32 %0, %1;\n" : "=f"(b) : "f"(a));
  return b;
}

template <class SFType, int CVT_FP6_NUM_THREADS_PER_SF>
__device__ uint8_t* get_sf_out_address(int rowIdx, int colIdx, int numCols, SFType* SFout) {
  // #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  static_assert(CVT_FP6_NUM_THREADS_PER_SF == 4);

  // one of 4 threads write one SF to global memory.
  // TODO: stage through smem for packed STG.32
  // is it better than STG.8 from 4 threads ?
  if (threadIdx.x % CVT_FP6_NUM_THREADS_PER_SF == 0) {
    // SF vector index (32 elements share one SF in the K dimension).
    int32_t kIdx = colIdx / CVT_FP6_NUM_THREADS_PER_SF;
    int32_t mIdx = rowIdx;

    // SF layout [numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4(kTile)]
    // --> index [mTileIdx, kTileIdx, outerMIdx, innerMIdx, innerKIdx]
    int32_t mTileIdx = mIdx / (32 * 4);
    // SF vector size 32.
    int factor = CVT_FP6_SF_VEC_SIZE * 4;
    int32_t numKTiles = (numCols + factor - 1) / factor;
    int64_t mTileStride = numKTiles * 32 * 4 * 4;

    int32_t kTileIdx = (kIdx / 4);
    int64_t kTileStride = 32 * 4 * 4;

    // M tile layout [32, 4] is column-major.
    int32_t outerMIdx = (mIdx % 32);  // same as (mIdx % 128) % 32
    int64_t outerMStride = 4 * 4;

    int32_t innerMIdx = (mIdx % (32 * 4)) / 32;
    int64_t innerMStride = 4;

    int32_t innerKIdx = (kIdx % 4);
    int64_t innerKStride = 1;

    // Compute the global offset.
    int64_t SFOffset = mTileIdx * mTileStride + kTileIdx * kTileStride + outerMIdx * outerMStride +
                       innerMIdx * innerMStride + innerKIdx * innerKStride;

    return reinterpret_cast<uint8_t*>(SFout) + SFOffset;
  } else {
    // Other threads do not write to SFout.
    return nullptr;
  }
}

// Define a 16 bytes packed data type.
template <class Type>
struct PackedVec {
  typename TypeConverter<Type>::Type elts[4];
};

// template <>
// struct PackedVec<__nv_fp8_e4m3> {
//   __nv_fp8x2_e4m3 elts[8];
// };

template <class Type>  // Type can be half or bfloat16
__device__ uint8x6_t cvt_warp_fp16_to_fp6(PackedVec<Type>& vec, uint8_t* SFout) {
  // #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  // Get absolute maximum values among the local 8 values.
  auto localMax = __habs2(vec.elts[0]);

  // Local maximum value.
#pragma unroll
  for (int i = 1; i < CVT_FP6_ELTS_PER_THREAD / 2; i++) {
    localMax = __hmax2(localMax, __habs2(vec.elts[i]));
  }

  // Get the absolute maximum among all 32 values (four threads).
  localMax = __hmax2(__shfl_xor_sync(uint32_t(-1), localMax, 1), localMax);
  localMax = __hmax2(__shfl_xor_sync(uint32_t(-1), localMax, 2), localMax);
  // Get the final absolute maximum values.
  float vecMax = float(__hmax(localMax.x, localMax.y));

  // Get the SF (max value of the vector / max value of e3m2).
  // maximum value of e3m2 = 28.0.
  // TODO: use half as compute data type.
  float SFValue = (vecMax / 28.0f);
  // 8 bits representation of the SF.
  uint8_t fp8SFVal;
  // Write the SF to global memory (STG.8).
  __nv_fp8_e8m0 tmp;
  tmp.__x = __nv_cvt_float_to_e8m0(SFValue, __NV_SATFINITE, cudaRoundPosInf);
  SFValue = static_cast<float>(tmp);
  fp8SFVal = tmp.__x;

  float outputScale = SFValue != 0 ? reciprocal_approximate_ftz(SFValue) : 0.0f;

  if (SFout) {
    // Write the SF to global memory (STG.8).
    *SFout = fp8SFVal;
  }

  // Convert the input to float.
  float2 fp2Vals[CVT_FP6_ELTS_PER_THREAD / 2];

#pragma unroll
  for (int i = 0; i < CVT_FP6_ELTS_PER_THREAD / 2; i++) {
    if constexpr (std::is_same_v<Type, half>) {
      fp2Vals[i] = __half22float2(vec.elts[i]);
    } else {
      fp2Vals[i] = __bfloat1622float2(vec.elts[i]);
    }
    fp2Vals[i].x *= outputScale;
    fp2Vals[i].y *= outputScale;
  }

  // Convert to e3m2 values.
  uint8x6_t e3m2Vec = fp32_vec_to_e3m2(fp2Vals);
  return e3m2Vec;
}

template <class Type>  // Type can be half or bfloat16
__global__ void
// #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
__launch_bounds__(256, 6) cvt_fp16_to_fp6(
    // #else
    // cvt_fp16_to_fp6(
    // #endif
    int32_t numRows, int32_t numCols, Type const* in, uint8x6_t* out, uint32_t* SFout) {
  // #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  using PackedVec = PackedVec<Type>;
  static constexpr int CVT_FP6_NUM_THREADS_PER_SF = (CVT_FP6_SF_VEC_SIZE / CVT_FP6_ELTS_PER_THREAD);
  static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP6_ELTS_PER_THREAD, "Vec size is not matched.");

  // Input tensor row/col loops.
  for (int rowIdx = blockIdx.x; rowIdx < numRows; rowIdx += gridDim.x) {
    for (int colIdx = threadIdx.x; colIdx < numCols / CVT_FP6_ELTS_PER_THREAD; colIdx += blockDim.x) {
      int64_t inOffset = rowIdx * (numCols / CVT_FP6_ELTS_PER_THREAD) + colIdx;
      PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
      // Get the output tensor offset.
      // Same as inOffset because 8 elements(E3M2) are packed into one uint8x6_t.
      int64_t outOffset = inOffset;
      auto& out_pos = out[outOffset];

      auto sf_out = get_sf_out_address<uint32_t, CVT_FP6_NUM_THREADS_PER_SF>(rowIdx, colIdx, numCols, SFout);

      out_pos = cvt_warp_fp16_to_fp6<Type>(in_vec, sf_out);
    }
  }
  // #endif
}

template <typename T>
void invokeFP6Quantization(int m, int n, T const* input, int64_t* output, int32_t* SFOuput,
                           int multiProcessorCount, cudaStream_t stream) {
  // Grid, Block size.
  // Each thread converts 8 values.
  dim3 block(std::min(int(n / ELTS_PER_THREAD), 256));
  // Get number of blocks per SM (assume we can fully utilize the SM).
  int const numBlocksPerSM = 1536 / block.x;
  dim3 grid(std::min(int(m), multiProcessorCount * numBlocksPerSM));

  // Launch the cvt kernel.
  cvt_fp16_to_fp6<T><<<grid, block, 0, stream>>>(
      m, n, input, reinterpret_cast<uint8x6_t*>(output), reinterpret_cast<uint32_t*>(SFOuput));
}

// Instantiate the function.
template void invokeFP6Quantization(int m, int n, half const* input, int64_t* output, int32_t* SFOuput,
                                    int multiProcessorCount, cudaStream_t stream);

template void invokeFP6Quantization(int m, int n, __nv_bfloat16 const* input, int64_t* output, int32_t* SFOuput,
                                    int multiProcessorCount, cudaStream_t stream);

inline int getMultiProcessorCount() {
  static int multi_processor_count = []() {
    int device_id = 0;
    int count = 0;

    // Get the current CUDA device ID
    CHECK_CUDA_SUCCESS(cudaGetDevice(&device_id));

    // Get the number of multiprocessors for the current device
    CHECK_CUDA_SUCCESS(cudaDeviceGetAttribute(&count, cudaDevAttrMultiProcessorCount, device_id));

    return count;  // Initialize the static variable
  }();

  return multi_processor_count;  // Return the cached value on subsequent calls
}

void scaled_fp6_quant_sm120(torch::Tensor& output, torch::Tensor const& input, torch::Tensor& output_sf) {
  int32_t m = input.size(0);
  int32_t n = input.size(1);

  TORCH_CHECK(n % 32 == 0, "The N dimension must be multiple of 32.");

  int multiProcessorCount = getMultiProcessorCount();

  auto sf_out = static_cast<int32_t*>(output_sf.data_ptr());
  auto output_ptr = static_cast<int64_t*>(output.data_ptr());
  at::cuda::CUDAGuard device_guard{(char)input.get_device()};
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream(input.get_device());

  switch (input.scalar_type()) {
    case torch::kHalf: {
      auto input_ptr = reinterpret_cast<half const*>(input.data_ptr());
      invokeFP6Quantization(m, n, input_ptr, output_ptr, sf_out, multiProcessorCount, stream);
      break;
    }
    case torch::kBFloat16: {
      auto input_ptr = reinterpret_cast<__nv_bfloat16 const*>(input.data_ptr());
      invokeFP6Quantization(m, n, input_ptr, output_ptr, sf_out, multiProcessorCount, stream);
      break;
    }
    default: {
      std::cerr << "Observing: " << input.scalar_type() << " for the input datatype which is invalid";
      throw std::runtime_error("Unsupported input data type for quantize_to_fp6.");
    }
  }
}
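To make the byte layout documented in fp32_vec_to_e3m2 easier to check, here is a small host-side Python sketch of the same 4-values-into-3-bytes FP6 packing. It is illustrative only (not part of this commit): the integers stand in for the e3m2 bit patterns, and the unpack function simply inverts the layout for verification.

# Sketch of the FP6 packing used above:
#   packed0 = [b1 b0][a5..a0], packed1 = [c3..c0][b5..b2], packed2 = [d5..d0][c5 c4]
def pack_fp6x4(a: int, b: int, c: int, d: int) -> tuple[int, int, int]:
    """Pack four 6-bit values into three bytes (same layout as the CUDA kernel)."""
    packed0 = ((b << 6) | a) & 0xFF
    packed1 = ((c << 4) | (b >> 2)) & 0xFF
    packed2 = ((d << 2) | (c >> 4)) & 0xFF
    return packed0, packed1, packed2


def unpack_fp6x4(p0: int, p1: int, p2: int) -> tuple[int, int, int, int]:
    """Invert pack_fp6x4, recovering the four 6-bit values."""
    a = p0 & 0x3F
    b = ((p1 & 0x0F) << 2) | (p0 >> 6)
    c = ((p2 & 0x03) << 4) | (p1 >> 4)
    d = p2 >> 2
    return a, b, c, d


if __name__ == "__main__":
    import itertools

    # round-trip check over a sample of 6-bit values per slot
    for a, b, c, d in itertools.product(range(0, 64, 7), repeat=4):
        assert unpack_fp6x4(*pack_fp6x4(a, b, c, d)) == (a, b, c, d)
    print("fp6 packing round-trip OK")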
lightx2v_kernel/csrc/gemm/mxfp8_quant_kernels_sm120.cu

@@ -287,7 +287,7 @@ void scaled_fp8_quant_sm120(
   int32_t m = input.size(0);
   int32_t n = input.size(1);
-  TORCH_CHECK(n % 32 == 0, "The N dimension must be multiple of 16.");
+  TORCH_CHECK(n % 32 == 0, "The N dimension must be multiple of 32.");
   int multiProcessorCount = getMultiProcessorCount();
   ...
lightx2v_kernel/include/lightx2v_kernel_ops.h

@@ -60,6 +60,8 @@ void scaled_fp4_quant_sm120(
 void scaled_fp8_quant_sm120(torch::Tensor& output, torch::Tensor const& input, torch::Tensor& output_sf);

+void scaled_fp6_quant_sm120(torch::Tensor& output, torch::Tensor const& input, torch::Tensor& output_sf);
+
 void cutlass_scaled_mxfp6_mxfp8_mm_sm120(torch::Tensor& D,
   ...
lightx2v_kernel/python/lightx2v_kernel/gemm.py

@@ -48,13 +48,26 @@ def scaled_fp4_quant(input: torch.Tensor, input_global_scale: torch.Tensor):
     # rounded_m = ((m + 128 - 1) // 128) * 128
     # scale_n = n // block_size
     # rounded_n = ((scale_n + 4 - 1) // 4) * 4
-    output_scale = torch.empty((((m + 128 - 1) // 128) * 128, (n // block_size + 4 - 1) // 4), device=device, dtype=torch.int32)
+    output_scale = torch.zeros((((m + 128 - 1) // 128) * 128, (n // block_size + 4 - 1) // 4), device=device, dtype=torch.int32)

     torch.ops.lightx2v_kernel.scaled_fp4_quant_sm120.default(output, input, output_scale, input_global_scale)
     output_scale = output_scale.view(torch.float8_e4m3fn)
     return output, output_scale


+def scaled_fp6_quant(input: torch.Tensor):
+    m, n = input.shape
+    block_size = 32
+    device = input.device
+
+    output = torch.empty((m, 3 * n // 4), device=device, dtype=torch.uint8)
+    output_scale = torch.zeros(((m + 128 - 1) // 128 * 128, (n // block_size + 4 - 1) // 4), device=device, dtype=torch.int32)
+
+    torch.ops.lightx2v_kernel.scaled_fp6_quant_sm120.default(output, input, output_scale)
+    output_scale = output_scale.view(torch.float8_e8m0fnu)
+    return output, output_scale
+
+
 def scaled_fp8_quant(input: torch.Tensor):
     m, n = input.shape
     block_size = 32
     ...
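For reference, the new wrapper composes with the existing ops as follows. This is a minimal usage sketch, assuming the lightx2v_kernel extension is built for SM120 hardware; it mirrors the pattern used by the tests added below (MXFP6 weight, MXFP8 activation, CUTLASS mixed-precision GEMM).

# Usage sketch only; see test_bench.py below for the full benchmark version.
import torch
from lightx2v_kernel.gemm import scaled_fp6_quant, scaled_fp8_quant, cutlass_scaled_mxfp6_mxfp8_mm

m, k, n = 512, 5120, 5120
x = torch.randn(m, k, dtype=torch.bfloat16, device="cuda")
w = torch.randn(n, k, dtype=torch.bfloat16, device="cuda")

w_q, w_scale = scaled_fp6_quant(w)  # packed FP6 weight: (n, 3*k//4) uint8, plus E8M0 scales
x_q, x_scale = scaled_fp8_quant(x)  # FP8 activation, plus E8M0 scales
alpha = torch.tensor(1.0, dtype=torch.float32, device="cuda")

y = cutlass_scaled_mxfp6_mxfp8_mm(x_q, w_q, x_scale, w_scale, alpha=alpha, bias=None)
print(y.shape)  # (m, n)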
lightx2v_kernel/test/mxfp6_mxfp8/test_bench.py (new file, mode 100644)

import torch
from lightx2v_kernel.gemm import scaled_fp8_quant, scaled_fp6_quant, cutlass_scaled_mxfp6_mxfp8_mm
import time


class MMWeightMxfp8ActMxfp6:
    def __init__(self, weight, bias):
        self.load_fp6_weight(weight, bias)
        self.act_quant_func = self.act_quant_fp8
        self.set_alpha()

    @torch.no_grad()
    def apply(self, input_tensor):
        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
        output_tensor = cutlass_scaled_mxfp6_mxfp8_mm(input_tensor_quant, self.weight, input_tensor_scale, self.weight_scale, alpha=self.alpha, bias=self.bias)
        return output_tensor

    @torch.no_grad()
    def load_fp6_weight(self, weight, bias):
        self.weight, self.weight_scale = scaled_fp6_quant(weight)
        self.bias = bias

    def set_alpha(self):
        self.alpha = torch.tensor(1.0, dtype=torch.float32, device=self.weight.device)

    @torch.no_grad()
    def act_quant_fp8(self, x):
        return scaled_fp8_quant(x)


def test_speed(m, k, n):
    with torch.no_grad():
        input_tensor = torch.randn(m, k, dtype=torch.bfloat16).cuda()
        weight = torch.randn(n, k, dtype=torch.bfloat16, device="cuda")
        # bias = torch.randn(1, n, dtype=torch.bfloat16).cuda()
        bias = None

        mm = MMWeightMxfp8ActMxfp6(weight, bias)

        # warmup
        output_tensor = mm.apply(input_tensor)

        torch.cuda.synchronize()
        start_time = time.time()
        for i in range(100):
            output_tensor = mm.apply(input_tensor)
        torch.cuda.synchronize()
        end_time = time.time()

        lightx2v_kernel_time = (end_time - start_time) / 100
        print(f"lightx2v-kernel time: {lightx2v_kernel_time}")

        input_tensor = torch.randn(m, n, dtype=torch.bfloat16).cuda()
        weight = torch.randn(k, n, dtype=torch.bfloat16, device="cuda")
        bias = torch.randn(1, k, dtype=torch.bfloat16).cuda()

        linear = torch.nn.Linear(k, n, bias=False).cuda()
        linear.weight.data = weight
        # linear.bias.data = bias

        # warmup
        ref_output_tensor = linear(input_tensor)

        torch.cuda.synchronize()
        start_time = time.time()
        for i in range(100):
            ref_output_tensor = linear(input_tensor)
        torch.cuda.synchronize()
        end_time = time.time()

        ref_time = (end_time - start_time) / 100
        print(f"ref time: {ref_time}")

        print(f"speedup: {ref_time / lightx2v_kernel_time:.3f}")


def test_accuracy(m, k, n):
    with torch.no_grad():
        input_tensor = torch.randn(m, k, dtype=torch.bfloat16).cuda()
        weight = torch.randn(n, k, dtype=torch.bfloat16, device="cuda")
        # bias = torch.randn(1, n, dtype=torch.bfloat16).cuda()
        bias = None

        linear = torch.nn.Linear(k, n, bias=False).cuda()
        linear.weight.data = weight
        # linear.bias.data = bias
        ref_output_tensor = linear(input_tensor)

        mm = MMWeightMxfp8ActMxfp6(weight, bias)
        output_tensor = mm.apply(input_tensor)

        # print(f"ref_output_tensor: {ref_output_tensor}")
        # print(f"output_tensor: {output_tensor}")

        # cosine
        cos = torch.nn.functional.cosine_similarity(ref_output_tensor.flatten(), output_tensor.flatten(), dim=0)
        print(f"cos : {cos}")


if __name__ == "__main__":
    test_sizes = [
        (32130, 5120, 5120),
        (512, 5120, 5120),
        (257, 5120, 5120),
        (32130, 5120, 13824),
        (32130, 13824, 5120),
        (75348, 5120, 5120),
        (75348, 13824, 5120),
        (32760, 1536, 1536),
        (512, 1536, 1536),
        (32760, 1536, 8960),
        (32760, 8960, 1536),
    ]

    for i, (m, k, n) in enumerate(test_sizes):
        print("-" * 30)
        print(f"Test {i + 1}: tensor size ({m}, {k}, {n})")
        test_accuracy(m, k, n)
        test_speed(m, k, n)
lightx2v_kernel/test/mxfp6_mxfp8/test_fake_quant.py (new file, mode 100644)

import torch
from torchao.prototype.mx_formats.constants import DTYPE_FP6_E3M2
from torchao.prototype.mx_formats.mx_tensor import to_mx, pack_uint6


def quant2mxfp8(x: torch.Tensor):
    block_size = 32
    m, _ = x.shape
    scale, output = to_mx(x, torch.float8_e4m3fn, block_size=block_size)
    return scale.reshape(m, -1), output


def quant2mxfp6(x: torch.Tensor):
    block_size = 32
    m, _ = x.shape
    scale, output = to_mx(x, DTYPE_FP6_E3M2, block_size=block_size, pack_fp6=False)
    return scale.reshape(m, -1), output


def scale_pad_and_swizzle(scale: torch.Tensor):
    m, s = scale.shape

    # pad m up to a multiple of 128 and s up to a multiple of 4
    padded_m = (m + 127) // 128 * 128
    padded_s = (s + 3) // 4 * 4
    padded_scale = torch.empty(padded_m, padded_s, device=scale.device, dtype=scale.dtype)
    padded_scale[:m, :s] = scale

    # swizzle the padded scale
    swizzled_scale = padded_scale.reshape(padded_m // 128, 128, padded_s // 4, 4).reshape(padded_m // 128, 4, 32, padded_s // 4, 4).permute(0, 3, 2, 1, 4)

    return swizzled_scale.reshape(padded_m, padded_s)


###############################################################
# Packing kernel and func
###############################################################

import triton  # noqa: E402
import triton.language as tl  # noqa: E402


@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE_IN": 2}, num_warps=1),
        triton.Config({"BLOCK_SIZE_IN": 4}, num_warps=1),
        triton.Config({"BLOCK_SIZE_IN": 8}, num_warps=1),
        triton.Config({"BLOCK_SIZE_IN": 16}, num_warps=1),
    ],
    key=["n_mx_blocks"],
)
@triton.jit
def triton_pack_uint6_kernel(
    input_ptr,
    output_ptr,
    n_mx_blocks,
    MX_BLOCK_SIZE: tl.constexpr,
    PACKED_MX_BLOCK_SIZE: tl.constexpr,
    BLOCK_SIZE_IN: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE_IN

    # input_ptr is shape [n_mx_blocks, MX_BLOCK_SIZE]
    # Load BLOCK_SIZE rows of input_ptr
    offsets_rows = block_start + tl.arange(0, BLOCK_SIZE_IN)
    offsets_cols = tl.arange(0, MX_BLOCK_SIZE // 4)
    offsets = offsets_rows[:, None] * MX_BLOCK_SIZE + (4 * offsets_cols[None, :])
    mask = (offsets_rows[:, None] < n_mx_blocks) & (offsets_cols[None, :] < MX_BLOCK_SIZE // 4)

    # x is shape [BLOCK_SIZE, MX_BLOCK_SIZE]
    x_0 = tl.load(input_ptr + offsets, mask=mask)
    x_1 = tl.load(input_ptr + offsets + 1, mask=mask)
    x_2 = tl.load(input_ptr + offsets + 2, mask=mask)
    x_3 = tl.load(input_ptr + offsets + 3, mask=mask)

    # 4 fp6 values a b c d. a: [a5 a4 a3 a2 a1 a0], b..., c..., d...
    # 3 uint8 values pack0 pack1 pack2
    # layout expected by cutlass:
    # packed0: [b1 b0][a5 a4 a3 a2 a1 a0]
    # packed1: [c3 c2 c1 c0][b5 b4 b3 b2]
    # packed2: [d5 d4 d3 d2 d1 d0][c5 c4]
    bits_packed0 = (x_1 << 6) | x_0
    bits_packed1 = (x_2 << 4) | (x_1 >> 2)
    bits_packed2 = (x_3 << 2) | (x_2 >> 4)

    # Store values in a uint8 tensor of length `3 * MX_BLOCK_SIZE / 4`
    offsets_out_4_a = offsets_rows[:, None] * PACKED_MX_BLOCK_SIZE + 3 * offsets_cols[None, :]
    offsets_out_4_b = offsets_rows[:, None] * PACKED_MX_BLOCK_SIZE + 3 * offsets_cols[None, :] + 1
    offsets_out_2 = offsets_rows[:, None] * PACKED_MX_BLOCK_SIZE + 3 * offsets_cols[None, :] + 2

    # Store into output tensor
    tl.store(
        output_ptr + offsets_out_4_a,
        bits_packed0,
        mask=mask,
    )
    tl.store(
        output_ptr + offsets_out_4_b,
        bits_packed1,
        mask=mask,
    )
    tl.store(
        output_ptr + offsets_out_2,
        bits_packed2,
        mask=mask,
    )


def pack_uint6(uint8_data: torch.Tensor) -> torch.Tensor:
    # ensure input data is contiguous before passing to kernel
    assert uint8_data.is_contiguous()

    # tensor should already be of shape [..., mx_block_size]
    mx_block_size = uint8_data.shape[-1]
    assert mx_block_size % 4 == 0

    # effective mx block size since we're packing 4 fp6 values into 3 uint8
    packed_mx_block_size = 3 * mx_block_size // 4
    packed_shape = [uint8_data.shape[0], packed_mx_block_size]
    n_mx_blocks = uint8_data.numel() // mx_block_size

    grid = lambda meta: (triton.cdiv(n_mx_blocks, meta["BLOCK_SIZE_IN"]),)  # noqa: E731

    # contiguous uint8 container in which we can store the packed tensor
    packed_uint8_data = torch.empty(packed_shape, dtype=torch.uint8, device=uint8_data.device)

    triton_pack_uint6_kernel[grid](
        uint8_data,
        packed_uint8_data,
        n_mx_blocks,
        MX_BLOCK_SIZE=mx_block_size,
        PACKED_MX_BLOCK_SIZE=packed_mx_block_size,
    )

    return packed_uint8_data


M = [257, 512, 1024, 13325, 32130, 32760]  # , 75348
N = [1536, 5120, 8960]  # , 13824
K = [128, 256, 512, 1024, 2048, 4096]  # , 13824

for m in M:
    for n in N:
        for k in K:
            x = torch.randn(m, k, device="cuda", dtype=torch.bfloat16)
            w = torch.randn(n, k, device="cuda", dtype=torch.bfloat16)

            # execute quant
            x_scale, x_quant = quant2mxfp8(x)
            w_scale, w_quant = quant2mxfp6(w)

            # pack fp6 for cutlass
            w_quant_packed = pack_uint6(w_quant.reshape(-1, 32))

            # pad and swizzle scale
            padded_and_swizzled_x_scale = scale_pad_and_swizzle(x_scale)
            padded_and_swizzled_w_scale = scale_pad_and_swizzle(w_scale)

            # ref mm result
            ref_mm = torch.nn.functional.linear(x, w).to(torch.bfloat16)

            # custom scaled mm
            from lightx2v_kernel.gemm import cutlass_scaled_mxfp6_mxfp8_mm

            alpha = torch.tensor(1.0, device="cuda", dtype=torch.float32)
            bias = None
            x_quant = x_quant.reshape(m, k).view(torch.uint8)
            w_quant_packed = w_quant_packed.reshape(n, 3 * k // 4)
            custom_mm = cutlass_scaled_mxfp6_mxfp8_mm(x_quant, w_quant_packed, padded_and_swizzled_x_scale, padded_and_swizzled_w_scale, alpha, bias)

            # cal snr
            from lightx2v_kernel.utils import error

            print(f"m: {m}, n: {n}, k: {k}, error: {error(ref_mm, custom_mm)}")

            # cal cos
            cos_sim = torch.nn.functional.cosine_similarity(ref_mm.flatten(), custom_mm.flatten(), dim=0)
            print(f"m: {m}, n: {n}, k: {k}, cos_sim: {cos_sim}")
lightx2v_kernel/test/mxfp6_mxfp8/test_mm_tflops.py (new file, mode 100644)

import torch
from lightx2v_kernel.gemm import cutlass_scaled_mxfp6_mxfp8_mm

"""
input_shape = (1024, 2048)
weight_shape = (4096, 2048)

input_tensor_quant = (torch.rand((1024, 1024), device="cuda") * 10).to(torch.uint8)
weight = (torch.rand((4096, 1024), device="cuda") * 10).to(torch.uint8)
input_tensor_scale = torch.rand(1024, 128, device="cuda").to(torch.float8_e8m0fnu)
weight_scale = torch.rand(4096, 128, device="cuda").to(torch.float8_e8m0fnu)
alpha = torch.tensor(1.0, device="cuda").to(torch.float32)
bias = None
"""


def test_mm(input_tensor_quant, weight, input_tensor_scale, weight_scale, alpha, bias):
    output_tensor = cutlass_scaled_mxfp6_mxfp8_mm(input_tensor_quant, weight, input_tensor_scale, weight_scale, alpha=alpha, bias=bias)
    return output_tensor


def test_tflops(input_shape, weight_shape, num_warmup=10, num_runs=100):
    """
    Measure the TFLOPS of test_mm.
    """
    # create input data
    input_tensor_quant = (torch.rand((input_shape[0], input_shape[1]), device="cuda") * 10).to(torch.uint8)
    weight = (torch.rand((weight_shape[0], 3 * weight_shape[1] // 4), device="cuda") * 10).to(torch.uint8)
    input_tensor_scale = torch.rand(((input_shape[0] + 128 - 1) // 128) * 128, (input_shape[1] // 32 + 4 - 1) // 4 * 4, device="cuda").to(torch.float8_e8m0fnu)
    weight_scale = torch.rand(weight_shape[0], weight_shape[1] // 32, device="cuda").to(torch.float8_e8m0fnu)
    alpha = torch.tensor(1.0, device="cuda", dtype=torch.float32)
    bias = None

    # warm up the GPU
    for _ in range(num_warmup):
        test_mm(input_tensor_quant, weight, input_tensor_scale, weight_scale, alpha, bias)

    # synchronize the GPU
    torch.cuda.synchronize()

    # create GPU events for precise timing
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    # measure time
    start_event.record()
    for _ in range(num_runs):
        result = test_mm(input_tensor_quant, weight, input_tensor_scale, weight_scale, alpha, bias)
    end_event.record()

    # synchronize and compute elapsed time
    torch.cuda.synchronize()
    elapsed_time_ms = start_event.elapsed_time(end_event)
    elapsed_time_s = elapsed_time_ms / 1000.0

    # compute FLOPS
    # matrix multiplication A(M x K) @ B(K x N) = C(M x N)
    # M = batch_size, K = input_dim, N = output_dim
    M = input_shape[0]
    K = input_shape[1]
    N = weight_shape[0]

    # FLOPS per matmul = 2 * M * N * K (each output element needs K multiplies and K adds)
    flops_per_run = 2 * M * N * K
    total_flops = flops_per_run * num_runs

    # compute TFLOPS (trillions of floating-point operations per second)
    tflops = total_flops / (elapsed_time_s * 1e12)

    print("Results:")
    print(f"  input shape: {input_shape} (M={M}, K={K})")
    print(f"  weight shape: {weight_shape} (N={N}, K={K})")
    print(f"  output shape: ({M}, {N})")
    print(f"  runs: {num_runs}")
    print(f"  total time: {elapsed_time_ms:.2f} ms")
    print(f"  average time per run: {elapsed_time_ms / num_runs:.4f} ms")
    print(f"  FLOPS per run: {flops_per_run / 1e9:.2f} GFLOPS")
    print(f"  total FLOPS: {total_flops / 1e12:.2f} TFLOPS")
    print(f"  compute throughput: {tflops:.2f} TFLOPS")

    return tflops


if __name__ == "__main__":
    # benchmark matmuls of different sizes
    # (m,k) (n,k)
    test_cases = [
        ((32130, 5120), (5120, 5120)),
        ((512, 1536), (1536, 1536)),
        ((512, 5120), (5120, 5120)),
        ((257, 5120), (5120, 5120)),
        ((32130, 5120), (13824, 5120)),
        ((32130, 13824), (5120, 13824)),
        ((75348, 5120), (5120, 5120)),
        ((75348, 5120), (13824, 5120)),
        ((75348, 13824), (5120, 13824)),
        ((32760, 1536), (1536, 1536)),
        ((512, 1536), (1536, 1536)),
        ((32760, 1536), (8960, 1536)),
        ((32760, 8960), (1536, 8960)),
    ]

    print("=== test_mm TFLOPS benchmark ===\n")

    for i, (input_shape, weight_shape) in enumerate(test_cases):
        print(f"Test {i + 1}: input shape {input_shape}, weight shape {weight_shape}")
        print("-" * 60)

        tflops = test_tflops(input_shape, weight_shape)
        print(f"✓ Test finished, performance: {tflops:.2f} TFLOPS\n")

    print("=== All tests finished ===")
lightx2v_kernel/test/mxfp6_mxfp8/test_mxfp6_quant.py (new file, mode 100644)

import unittest

import torch
from lightx2v_kernel.gemm import cutlass_scaled_mxfp6_mxfp8_mm
from lightx2v_kernel.gemm import scaled_fp6_quant, scaled_fp8_quant
from torch.nn.functional import linear
from lightx2v_kernel.utils import error, benchmark


class TestQuantBF162MXFP6(unittest.TestCase):
    def setUp(self):
        self.tokens = [128, 257, 512, 1024, 13325, 32130, 32760]  # , 75348
        self.channels = [128, 1536, 5120, 8960]  # , 13824
        self.hiddenDims = [128, 1536, 3072, 5120, 8960, 12800]  # , 13824
        self.device = "cuda"
        self.dtype = torch.bfloat16

    def test_accuracy(self):
        """Test the accuracy of quantization from BF16 to MXFP6."""
        for m in self.tokens:
            for k in self.hiddenDims:
                for n in self.channels:
                    with self.subTest(shape=[m, k, n]):
                        activation = torch.randn(m, k, dtype=self.dtype, device=self.device)
                        activation_quant_pred, activation_scale_pred = scaled_fp8_quant(activation)

                        weight = torch.randn(n, k, dtype=self.dtype, device=self.device)
                        weight_quant_pred, weight_scale_pred = scaled_fp6_quant(weight)

                        alpha = torch.tensor(1.0, device=self.device, dtype=torch.float32)

                        mm_pred = cutlass_scaled_mxfp6_mxfp8_mm(activation_quant_pred, weight_quant_pred, activation_scale_pred, weight_scale_pred, alpha=alpha)

                        mm_real = linear(activation, weight, bias=None).to(torch.bfloat16)

                        self.assertTrue(
                            error(mm_pred, mm_real) < 1e-2,
                            f"Accuracy test failed for shape {m, k, n}: Error {error(mm_pred, mm_real)} exceeds threshold.",
                        )

    def test_performance(self):
        """Benchmark the performance of Activation quantization from BF16 to MXFP6."""
        for m in self.tokens:
            for k in self.hiddenDims:
                with self.subTest(shape=[m, k]):
                    input = torch.randn(m, k, dtype=self.dtype, device=self.device)
                    shape = [m, k]
                    tflops = 2 * (m * k / 1024**4)
                    benchmark(scaled_fp6_quant, shape, tflops, 100, input)


if __name__ == "__main__":
    unittest.main()
lightx2v_kernel/test/mxfp6_mxfp8/test_quant_mem_utils.py (new file, mode 100644)

import torch
from lightx2v_kernel.gemm import scaled_fp6_quant


def quantize_fp6(x):
    return scaled_fp6_quant(x)


def test_memory_bandwidth(func, x, num_warmup=10, num_runs=100):
    """
    Measure the memory bandwidth of a function.
    """
    # warm up the GPU
    for _ in range(num_warmup):
        func(x)

    # synchronize the GPU
    torch.cuda.synchronize()

    # create GPU events for precise timing
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    # measure time
    start_event.record()
    for _ in range(num_runs):
        result = func(x)
    end_event.record()

    # synchronize and compute elapsed time
    torch.cuda.synchronize()
    elapsed_time_ms = start_event.elapsed_time(end_event)
    elapsed_time_s = elapsed_time_ms / 1000.0

    # compute data volume
    input_bytes = x.numel() * x.element_size()  # input bytes
    # after FP6 quantization, each element occupies 3/4 of a byte
    output_bytes = x.numel() * (3 / 4)  # FP6 output bytes
    scale_bytes = x.numel() / 32  # group_size = 32

    # total data transferred (read input + write output + scale)
    total_bytes = (input_bytes + output_bytes + scale_bytes) * num_runs

    # compute bandwidth
    bandwidth_gbps = (total_bytes / elapsed_time_s) / (1024**3)  # GB/s

    print("Results:")
    print(f"  input tensor shape: {x.shape}")
    print(f"  input dtype: {x.dtype}")
    print(f"  runs: {num_runs}")
    print(f"  total time: {elapsed_time_ms:.2f} ms")
    print(f"  average time per run: {elapsed_time_ms / num_runs:.4f} ms")
    print(f"  input size: {input_bytes / (1024**2):.2f} MB")
    print(f"  output size: {output_bytes / (1024**2):.2f} MB")
    print(f"  total data transferred: {total_bytes / (1024**3):.2f} GB")
    print(f"  memory bandwidth: {bandwidth_gbps:.2f} GB/s")

    return bandwidth_gbps


if __name__ == "__main__":
    # test tensors of different sizes
    test_sizes = [
        # (1, 1024),
        # (1, 2048),
        # (1, 4096),
        # (1, 8192),
        # (1, 16384),
        # (1, 32768),
        # (2, 1024),
        # (2, 2048),
        # (2, 4096),
        # (2, 8192),
        # (2, 16384),
        # (2, 32768),
        # (4, 1024),
        # (4, 2048),
        # (4, 4096),
        # (4, 8192),
        # (4, 16384),
        # (4, 32768),
        # (128, 1024),
        # (128, 2048),
        # (128, 4096),
        # (128, 8192),
        # (128, 16384),
        # (128, 32768),
        # (512, 1024),
        # (512, 2048),
        # (512, 4096),
        # (512, 8192),
        # (512, 16384),
        # (512, 32768),
        # (1024, 1024),
        # (1024, 2048),
        # (1024, 4096),
        # (1024, 8192),
        # (1024, 16384),
        # (1024, 32768),
        # (2048, 1024),
        # (2048, 2048),
        # (2048, 4096),
        # (2048, 8192),
        # (2048, 16384),
        # (2048, 32768),
        # (4096, 1024),
        # (4096, 2048),
        # (4096, 4096),
        # (4096, 8192),
        # (4096, 16384),
        # (4096, 32768),
        # (8192, 1024),
        # (8192, 2048),
        # (8192, 4096),
        # (8192, 8192),
        # (8192, 16384),
        # (8192, 32768),
        # (16384, 1024),
        # (16384, 2048),
        # (16384, 4096),
        # (16384, 8192),
        # (16384, 16384),
        # (16384, 32768),
        # (32768, 1024),
        # (32768, 2048),
        # (32768, 4096),
        # (32768, 8192),
        # (32768, 16384),
        # (32768, 32768),
        (32130, 5120),
        (512, 5120),
        (257, 5120),
        (32130, 13824),
        (75348, 5120),
        (75348, 13824),
        (32760, 1536),
        (512, 3072),
        (512, 1536),
        (32760, 8960),
    ]

    print("=== quantize_fp6 memory bandwidth test ===\n")

    for i, (h, w) in enumerate(test_sizes):
        print(f"Test {i + 1}: tensor size ({h}, {w})")
        print("-" * 50)

        x = torch.randn(h, w, dtype=torch.bfloat16).cuda()

        try:
            bandwidth = test_memory_bandwidth(quantize_fp6, x)
            print(f"✓ Test finished, bandwidth: {bandwidth:.2f} GB/s\n")
        except Exception as e:
            print(f"✗ Test failed: {e}\n")

    print("=== All tests finished ===")
requirements.txt

@@ -18,3 +18,6 @@ sgl-kernel
 qtorch
 ftfy
 easydict
+gradio
+aiohttp
+pydantic
requirements_win.txt (new file, mode 100644)

packaging
ninja
diffusers
transformers
tokenizers
accelerate
safetensors
opencv-python
numpy
imageio
imageio-ffmpeg
einops
loguru
qtorch
ftfy
easydict
gradio
aiohttp
pydantic
scripts/cache/readme.md (+6 −52)

New content (added):

# Feature Caching

The config files for feature caching are available [here](https://github.com/ModelTC/lightx2v/tree/main/configs/caching)

By specifying --config_json to the specific config file, you can test different cache algorithms.

Please refer to our feature caching docs:

[English doc: Feature Caching](https://lightx2v-en.readthedocs.io/en/latest/method_tutorials/cache.html)

[Chinese doc: Feature Caching](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/method_tutorials/cache.html)

Old content (removed):

# Cache

## Cache acceleration algorithms

- In diffusion-model inference, cache reuse is an important acceleration technique.
- Its core idea is to skip redundant computation at some timesteps and improve inference efficiency by reusing previously cached results.
- The key question is deciding at which timesteps to reuse the cache, usually judged dynamically from changes in model state or from an error threshold.
- During inference, key intermediate results such as features, residuals, and attention outputs are cached. At a reusable timestep, the cached content is used directly and the current output is reconstructed with approximations such as Taylor expansion, reducing repeated computation.

## TeaCache

The core idea of `TeaCache` is to accumulate the **relative L1** distance between the inputs of adjacent timesteps; this accumulated distance, compared against a preset threshold, decides whether the current timestep can reuse the cache.

- Concretely, at each inference step the algorithm computes the relative L1 distance between the current and previous inputs and adds it to a running total.
- While the accumulated distance stays below the threshold, the model state has changed little, so the most recently cached content is reused directly and part of the redundant computation is skipped; once it exceeds the threshold, a full forward pass is run and the cache is refreshed. This significantly reduces the number of forward passes and speeds up inference.

In practice, TeaCache achieves a clear speedup while preserving generation quality. Before/after comparison:

| Before | After |
|:------:|:------:|
| Single-H200 inference time: 58s | Single-H200 inference time: 17.9s |
| ![before](../../assets/gifs/1.gif) | ![after](../../assets/gifs/2.gif) |

- Speedup: **3.24**
- Reference: [https://arxiv.org/abs/2411.19108](https://arxiv.org/abs/2411.19108)

## TaylorSeer Cache

The core of `TaylorSeer Cache` is to recompute the cached content with a Taylor expansion, used as residual compensation at cache-reuse timesteps. At such timesteps the history is not simply replayed; the current output is approximately reconstructed via Taylor expansion. This reduces computation while further improving accuracy: the expansion captures small changes in model state, so the error introduced by cache reuse is compensated and quality is preserved alongside the speedup. `TaylorSeer Cache` suits scenarios with stricter output-accuracy requirements and further improves inference on top of plain cache reuse.

| Before | After |
|:------:|:------:|
| Single-H200 inference time: 57.7s | Single-H200 inference time: 41.3s |
| ![before](../../assets/gifs/3.gif) | ![after](../../assets/gifs/4.gif) |

- Speedup: **1.39**
- Reference: [https://arxiv.org/abs/2503.06923](https://arxiv.org/abs/2503.06923)

## AdaCache

The core idea of `AdaCache` is to adjust the cache-reuse stride dynamically, based on part of the cached content in a designated block.

- The algorithm analyzes the feature difference of a specific block between two adjacent timesteps and adaptively decides the interval to the next cache-reuse timestep.
- When the model state changes little, the stride grows and caches are updated less often; when it changes a lot, the stride shrinks to protect output quality.

The caching strategy thus adapts to the dynamics of the inference process, giving more efficient acceleration and better generation. AdaCache suits applications that demand both inference speed and generation quality.

| Before | After |
|:------:|:------:|
| Single-H200 inference time: 227s | Single-H200 inference time: 83s |
| ![before](../../assets/gifs/5.gif) | ![after](../../assets/gifs/6.gif) |

- Speedup: **2.73**
- Reference: [https://arxiv.org/abs/2411.02397](https://arxiv.org/abs/2411.02397)

## CustomCache

`CustomCache` combines the strengths of `TeaCache` and `TaylorSeer Cache`.

- It keeps `TeaCache`'s timely, well-founded cache decisions, using a dynamic threshold to decide when to reuse the cache.
- At the same time it uses `TaylorSeer`'s Taylor expansion to make full use of the cached content.

This decides reuse timing efficiently while exploiting the cache as much as possible, improving accuracy and generation quality. In tests across several generation tasks, `CustomCache` produces better videos than `TeaCache`, `TaylorSeer Cache`, or `AdaCache` alone, making it one of the strongest overall cache acceleration algorithms currently available.

| Before | After |
|:------:|:------:|
| Single-H200 inference time: 57.9s | Single-H200 inference time: 16.6s |
| ![before](../../assets/gifs/7.gif) | ![after](../../assets/gifs/8.gif) |

- Speedup: **3.49**
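The accumulate-and-threshold decision described for TeaCache above is simple enough to sketch. The following is an illustrative Python sketch only; the class and function names are hypothetical, and this is not the LightX2V implementation (see the feature-caching docs linked in the new readme for that).

# Illustrative sketch of a TeaCache-style reuse gate; `compute_block` stands in for
# the real transformer block / model forward.
import torch


class TeaCacheLikeGate:
    def __init__(self, threshold: float):
        self.threshold = threshold  # budget on accumulated relative L1 distance
        self.accum = 0.0
        self.prev_input = None
        self.cached_output = None

    def step(self, x: torch.Tensor, compute_block) -> torch.Tensor:
        if self.prev_input is not None:
            # relative L1 distance between the inputs of adjacent timesteps
            rel_l1 = (x - self.prev_input).abs().mean() / (self.prev_input.abs().mean() + 1e-8)
            self.accum += rel_l1.item()
        self.prev_input = x.detach()

        if self.cached_output is not None and self.accum < self.threshold:
            # small accumulated change: reuse the cached result and skip the block
            return self.cached_output

        # accumulated change exceeded the budget (or no cache yet): recompute and reset
        self.cached_output = compute_block(x)
        self.accum = 0.0
        return self.cached_output

A TaylorSeer-style variant would, at reuse steps, return a Taylor-expansion estimate built from the last few cached outputs rather than the raw cached value.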
scripts/cache/run_wan_i2v_taylor.sh (deleted, mode 100644)

#!/bin/bash

# set paths first
lightx2v_path=
model_path=

# check section
if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
    cuda_devices=0
    echo "Warn: CUDA_VISIBLE_DEVICES is not set, using default value: ${cuda_devices}, change at shell script or set env variable."
    export CUDA_VISIBLE_DEVICES=${cuda_devices}
fi

if [ -z "${lightx2v_path}" ]; then
    echo "Error: lightx2v_path is not set. Please set this variable first."
    exit 1
fi

if [ -z "${model_path}" ]; then
    echo "Error: model_path is not set. Please set this variable first."
    exit 1
fi

export TOKENIZERS_PARALLELISM=false
export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
export DTYPE=BF16
export ENABLE_PROFILING_DEBUG=true
export ENABLE_GRAPH_MODE=false

python -m lightx2v.infer \
--model_cls wan2.1 \
--task t2v \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/caching/taylorseer/wan_i2v_tea_480p.json \
--prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." \
--negative_prompt "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \
--image_path ${lightx2v_path}/assets/inputs/imgs/img_0.jpg \
--save_video_path ${lightx2v_path}/save_results/output_lightx2v_wan_i2v_taylor.mp4
scripts/cache/run_wan_t2v_custom.sh (deleted, mode 100644)

#!/bin/bash

# set paths first
lightx2v_path=
model_path=

# check section
if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
    cuda_devices=0
    echo "Warn: CUDA_VISIBLE_DEVICES is not set, using default value: ${cuda_devices}, change at shell script or set env variable."
    export CUDA_VISIBLE_DEVICES=${cuda_devices}
fi

if [ -z "${lightx2v_path}" ]; then
    echo "Error: lightx2v_path is not set. Please set this variable first."
    exit 1
fi

if [ -z "${model_path}" ]; then
    echo "Error: model_path is not set. Please set this variable first."
    exit 1
fi

export TOKENIZERS_PARALLELISM=false
export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
export DTYPE=BF16
export ENABLE_PROFILING_DEBUG=true
export ENABLE_GRAPH_MODE=false

python -m lightx2v.infer \
--model_cls wan2.1 \
--task t2v \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/caching/custom/wan_t2v_custom_1_3b.json \
--prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." \
--negative_prompt "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \
--save_video_path ${lightx2v_path}/save_results/output_lightx2v_wan_t2v_custom.mp4
scripts/cache/run_wan_t2v_tea.sh

@@ -32,7 +32,7 @@ python -m lightx2v.infer \
 --model_cls wan2.1 \
 --task t2v \
 --model_path $model_path \
---config_json ${lightx2v_path}/configs/caching/teacache/wan_t2v_1_3b.json \
+--config_json ${lightx2v_path}/configs/caching/teacache/wan_t2v_1_3b_tea_480p.json \
 --prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." \
 --negative_prompt "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \
 --save_video_path ${lightx2v_path}/save_results/output_lightx2v_wan_t2v_tea.mp4
scripts/cogvideox/readme.md (deleted, mode 100644)

## todo
scripts/cogvideox/run_cogvideox_t2v.sh (deleted, mode 100755)

#!/bin/bash

# set paths first
lightx2v_path=
model_path=

# check section
if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
    cuda_devices=0
    echo "Warn: CUDA_VISIBLE_DEVICES is not set, using default value: ${cuda_devices}, change at shell script or set env variable."
    export CUDA_VISIBLE_DEVICES=${cuda_devices}
fi

if [ -z "${lightx2v_path}" ]; then
    echo "Error: lightx2v_path is not set. Please set this variable first."
    exit 1
fi

if [ -z "${model_path}" ]; then
    echo "Error: model_path is not set. Please set this variable first."
    exit 1
fi

export TOKENIZERS_PARALLELISM=false
export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
export ENABLE_PROFILING_DEBUG=true
export ENABLE_GRAPH_MODE=false
export DTYPE=BF16
export PYTHONPATH=/mtc/wushuo/VideoGen/diffusers:$PYTHONPATH

python -m lightx2v.infer \
--model_cls cogvideox \
--task t2v \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/cogvideox/cogvideox_t2v.json \
--prompt "A little girl smile." \
--save_video_path ${lightx2v_path}/save_results/output_lightx2v_cogvideox_t2v.mp4
scripts/deploy/readme.md (deleted, mode 100644)

## todo
scripts/deploy/start_dit_server.sh (deleted, mode 100755)

#!/bin/bash

# set paths first
lightx2v_path=
model_path=

# check section
if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
    cuda_devices=0
    echo "Warn: CUDA_VISIBLE_DEVICES is not set, using default value: ${cuda_devices}, change at shell script or set env variable."
    export CUDA_VISIBLE_DEVICES=${cuda_devices}
fi

if [ -z "${lightx2v_path}" ]; then
    echo "Error: lightx2v_path is not set. Please set this variable first."
    exit 1
fi

if [ -z "${model_path}" ]; then
    echo "Error: model_path is not set. Please set this variable first."
    exit 1
fi

export TOKENIZERS_PARALLELISM=false
export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
export ENABLE_PROFILING_DEBUG=true
export ENABLE_GRAPH_MODE=false

python -m lightx2v.common.apis.dit \
--model_cls wan2.1 \
--task i2v \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/deploy/wan_i2v.json \
--port 9000