xuwx1 / LightX2V
"official/projects/yolo/README.md" did not exist on "14c7d6095e4b4b380ca443a94e02370bf399cbc7"
Commit e08c4f90, authored Jul 17, 2025 by sandy, committed by GitHub on Jul 17, 2025

Merge branch 'main' into audio_r2v

Parents: 12bfd120, 6d07a72e
Changes: 191
Showing 20 changed files with 1022 additions and 209 deletions (+1022 -209)
lightx2v_kernel/csrc/common_extension.cc  +4 -0
lightx2v_kernel/csrc/gemm/mxfp6_quant_kernels_sm120.cu  +348 -0
lightx2v_kernel/csrc/gemm/mxfp8_quant_kernels_sm120.cu  +1 -1
lightx2v_kernel/include/lightx2v_kernel_ops.h  +2 -0
lightx2v_kernel/python/lightx2v_kernel/gemm.py  +14 -1
lightx2v_kernel/test/mxfp6_mxfp8/test_bench.py  +121 -0
lightx2v_kernel/test/mxfp6_mxfp8/test_fake_quant.py  +181 -0
lightx2v_kernel/test/mxfp6_mxfp8/test_mm_tflops.py  +115 -0
lightx2v_kernel/test/mxfp6_mxfp8/test_mxfp6_quant.py  +49 -0
lightx2v_kernel/test/mxfp6_mxfp8/test_quant_mem_utils.py  +158 -0
requirements.txt  +3 -0
requirements_win.txt  +19 -0
scripts/cache/readme.md  +6 -52
scripts/cache/run_wan_i2v_taylor.sh  +0 -39
scripts/cache/run_wan_t2v_custom.sh  +0 -38
scripts/cache/run_wan_t2v_tea.sh  +1 -1
scripts/cogvideox/readme.md  +0 -1
scripts/cogvideox/run_cogvideox_t2v.sh  +0 -39
scripts/deploy/readme.md  +0 -1
scripts/deploy/start_dit_server.sh  +0 -36
lightx2v_kernel/csrc/common_extension.cc
View file @
e08c4f90
...
@@ -20,6 +20,10 @@ TORCH_LIBRARY_FRAGMENT(lightx2v_kernel, m) {
      "scaled_fp8_quant_sm120(Tensor! output, Tensor! input,"
      " Tensor! output_scale) -> ()");
  m.impl("scaled_fp8_quant_sm120", torch::kCUDA, &scaled_fp8_quant_sm120);

  m.def(
      "scaled_fp6_quant_sm120(Tensor! output, Tensor! input,"
      " Tensor! output_scale) -> ()");
  m.impl("scaled_fp6_quant_sm120", torch::kCUDA, &scaled_fp6_quant_sm120);

  m.def(
      "cutlass_scaled_mxfp6_mxfp8_mm_sm120(Tensor! out, Tensor mat_a, Tensor mat_b, Tensor scales_a, Tensor scales_b, Tensor "
...
lightx2v_kernel/csrc/gemm/mxfp6_quant_kernels_sm120.cu
0 → 100644
View file @
e08c4f90
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_fp8.h>
#include <cuda_fp6.h>
#include <cuda_runtime.h>
#include <cuda_runtime_api.h>
#include <torch/all.h>

#include "utils.h"

// Get type2 from type or vice versa (applied to half and bfloat16)
template <typename T>
struct TypeConverter {
  using Type = half2;
};  // keep for generality

template <>
struct TypeConverter<half2> {
  using Type = half;
};

template <>
struct TypeConverter<half> {
  using Type = half2;
};

template <>
struct TypeConverter<__nv_bfloat162> {
  using Type = __nv_bfloat16;
};

template <>
struct TypeConverter<__nv_bfloat16> {
  using Type = __nv_bfloat162;
};

#define ELTS_PER_THREAD 8

constexpr int CVT_FP6_ELTS_PER_THREAD = 8;
constexpr int CVT_FP6_SF_VEC_SIZE = 32;

struct uint8x6_t {
  uint8_t elts[6];
};
// Convert 4 float2 values into 8 e3m2 values (represented as one uint8x6_t).
inline __device__ uint8x6_t fp32_vec_to_e3m2(float2 (&array)[4]) {
  uint64_t val;
  asm volatile(
      "{\n"
      ".reg .b16 pack0;\n"
      ".reg .b16 pack1;\n"
      ".reg .b16 pack2;\n"
      ".reg .b16 pack3;\n"
      "cvt.rn.satfinite.e3m2x2.f32 pack0, %2, %1;\n"
      "cvt.rn.satfinite.e3m2x2.f32 pack1, %4, %3;\n"
      "cvt.rn.satfinite.e3m2x2.f32 pack2, %6, %5;\n"
      "cvt.rn.satfinite.e3m2x2.f32 pack3, %8, %7;\n"
      "mov.b64 %0, {pack0, pack1, pack2, pack3};\n"
      "}"
      : "=l"(val)
      : "f"(array[0].x), "f"(array[0].y), "f"(array[1].x), "f"(array[1].y),
        "f"(array[2].x), "f"(array[2].y), "f"(array[3].x), "f"(array[3].y));

  uint8x6_t result;
  // Pack 8 fp6 values (one per uint8_t) into 6 uint8_t. Here is how to pack:
  // 4 fp6 values a b c d. a: [a5 a4 a3 a2 a1 a0], b..., c..., d...
  // 3 uint8 values pack0 pack1 pack2:
  // packed0: [b1 b0][a5 a4 a3 a2 a1 a0]
  // packed1: [c3 c2 c1 c0][b5 b4 b3 b2]
  // packed2: [d5 d4 d3 d2 d1 d0][c5 c4]

  // lower 4 uint8_t
  uint8_t l_val_0 = val & 0xFF;
  uint8_t l_val_1 = (val >> 8) & 0xFF;
  uint8_t l_val_2 = (val >> 16) & 0xFF;
  uint8_t l_val_3 = (val >> 24) & 0xFF;
  // higher 4 uint8_t
  uint8_t h_val_0 = (val >> 32) & 0xFF;
  uint8_t h_val_1 = (val >> 40) & 0xFF;
  uint8_t h_val_2 = (val >> 48) & 0xFF;
  uint8_t h_val_3 = (val >> 56) & 0xFF;

  // pack result
  result.elts[0] = (l_val_1 << 6) | l_val_0;
  result.elts[1] = (l_val_2 << 4) | (l_val_1 >> 2);
  result.elts[2] = (l_val_3 << 2) | (l_val_2 >> 4);
  result.elts[3] = (h_val_1 << 6) | h_val_0;
  result.elts[4] = (h_val_2 << 4) | (h_val_1 >> 2);
  result.elts[5] = (h_val_3 << 2) | (h_val_2 >> 4);

  return result;
}

// Fast reciprocal.
inline __device__ float reciprocal_approximate_ftz(float a) {
  float b;
  asm volatile("rcp.approx.ftz.f32 %0, %1;\n" : "=f"(b) : "f"(a));
  return b;
}
template <class SFType, int CVT_FP6_NUM_THREADS_PER_SF>
__device__ uint8_t* get_sf_out_address(int rowIdx, int colIdx, int numCols, SFType* SFout) {
  // #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  static_assert(CVT_FP6_NUM_THREADS_PER_SF == 4);

  // One of 4 threads writes one SF to global memory.
  // TODO: stage through smem for packed STG.32
  // is it better than STG.8 from 4 threads ?
  if (threadIdx.x % CVT_FP6_NUM_THREADS_PER_SF == 0) {
    // SF vector index (32 elements share one SF in the K dimension).
    int32_t kIdx = colIdx / CVT_FP6_NUM_THREADS_PER_SF;
    int32_t mIdx = rowIdx;

    // SF layout [numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4 (kTile)]
    // --> index [mTileIdx, kTileIdx, outerMIdx, innerMIdx, innerKIdx]
    int32_t mTileIdx = mIdx / (32 * 4);
    // SF vector size 32.
    int factor = CVT_FP6_SF_VEC_SIZE * 4;
    int32_t numKTiles = (numCols + factor - 1) / factor;
    int64_t mTileStride = numKTiles * 32 * 4 * 4;

    int32_t kTileIdx = (kIdx / 4);
    int64_t kTileStride = 32 * 4 * 4;

    // M tile layout [32, 4] is column-major.
    int32_t outerMIdx = (mIdx % 32);  // same as (mIdx % 128) % 32
    int64_t outerMStride = 4 * 4;

    int32_t innerMIdx = (mIdx % (32 * 4)) / 32;
    int64_t innerMStride = 4;

    int32_t innerKIdx = (kIdx % 4);
    int64_t innerKStride = 1;

    // Compute the global offset.
    int64_t SFOffset = mTileIdx * mTileStride + kTileIdx * kTileStride + outerMIdx * outerMStride + innerMIdx * innerMStride + innerKIdx * innerKStride;

    return reinterpret_cast<uint8_t*>(SFout) + SFOffset;
  } else {
    // Other threads do not write to SFout.
    return nullptr;
  }
}

// Define a 16 bytes packed data type.
template <class Type>
struct PackedVec {
  typename TypeConverter<Type>::Type elts[4];
};

// template <>
// struct PackedVec<__nv_fp8_e4m3> {
//   __nv_fp8x2_e4m3 elts[8];
// };
template <class Type>  // Type can be half or bfloat16
__device__ uint8x6_t cvt_warp_fp16_to_fp6(PackedVec<Type>& vec, uint8_t* SFout) {
  // #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  // Get absolute maximum values among the local 8 values.
  auto localMax = __habs2(vec.elts[0]);

  // Local maximum value.
#pragma unroll
  for (int i = 1; i < CVT_FP6_ELTS_PER_THREAD / 2; i++) {
    localMax = __hmax2(localMax, __habs2(vec.elts[i]));
  }

  // Get the absolute maximum among all 32 values (four threads).
  localMax = __hmax2(__shfl_xor_sync(uint32_t(-1), localMax, 1), localMax);
  localMax = __hmax2(__shfl_xor_sync(uint32_t(-1), localMax, 2), localMax);
  // Get the final absolute maximum values.
  float vecMax = float(__hmax(localMax.x, localMax.y));

  // Get the SF (max value of the vector / max value of e3m2).
  // maximum value of e3m2 = 28.0.
  // TODO: use half as compute data type.
  float SFValue = (vecMax / 28.0f);
  // 8 bits representation of the SF.
  uint8_t fp8SFVal;
  // Write the SF to global memory (STG.8).
  __nv_fp8_e8m0 tmp;
  tmp.__x = __nv_cvt_float_to_e8m0(SFValue, __NV_SATFINITE, cudaRoundPosInf);
  SFValue = static_cast<float>(tmp);
  fp8SFVal = tmp.__x;

  float outputScale = SFValue != 0 ? reciprocal_approximate_ftz(SFValue) : 0.0f;

  if (SFout) {
    // Write the SF to global memory (STG.8).
    *SFout = fp8SFVal;
  }

  // Convert the input to float.
  float2 fp2Vals[CVT_FP6_ELTS_PER_THREAD / 2];

#pragma unroll
  for (int i = 0; i < CVT_FP6_ELTS_PER_THREAD / 2; i++) {
    if constexpr (std::is_same_v<Type, half>) {
      fp2Vals[i] = __half22float2(vec.elts[i]);
    } else {
      fp2Vals[i] = __bfloat1622float2(vec.elts[i]);
    }
    fp2Vals[i].x *= outputScale;
    fp2Vals[i].y *= outputScale;
  }

  // Convert to e3m2 values.
  uint8x6_t e3m2Vec = fp32_vec_to_e3m2(fp2Vals);

  return e3m2Vec;
}
template <class Type>  // Type can be half or bfloat16
__global__ void
// #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
__launch_bounds__(256, 6) cvt_fp16_to_fp6(
// #else
// cvt_fp16_to_fp6(
// #endif
    int32_t numRows, int32_t numCols, Type const* in, uint8x6_t* out, uint32_t* SFout) {
  // #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  using PackedVec = PackedVec<Type>;
  static constexpr int CVT_FP6_NUM_THREADS_PER_SF = (CVT_FP6_SF_VEC_SIZE / CVT_FP6_ELTS_PER_THREAD);
  static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP6_ELTS_PER_THREAD, "Vec size is not matched.");

  // Input tensor row/col loops.
  for (int rowIdx = blockIdx.x; rowIdx < numRows; rowIdx += gridDim.x) {
    for (int colIdx = threadIdx.x; colIdx < numCols / CVT_FP6_ELTS_PER_THREAD; colIdx += blockDim.x) {
      int64_t inOffset = rowIdx * (numCols / CVT_FP6_ELTS_PER_THREAD) + colIdx;
      PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
      // Get the output tensor offset.
      // Same as inOffset because 8 elements (E3M2) are packed into one uint8x6_t.
      int64_t outOffset = inOffset;
      auto& out_pos = out[outOffset];

      auto sf_out = get_sf_out_address<uint32_t, CVT_FP6_NUM_THREADS_PER_SF>(rowIdx, colIdx, numCols, SFout);

      out_pos = cvt_warp_fp16_to_fp6<Type>(in_vec, sf_out);
    }
  }
  // #endif
}
template <typename T>
void invokeFP6Quantization(int m, int n, T const* input, int64_t* output, int32_t* SFOuput, int multiProcessorCount, cudaStream_t stream) {
  // Grid, Block size.
  // Each thread converts 8 values.
  dim3 block(std::min(int(n / ELTS_PER_THREAD), 256));
  // Get number of blocks per SM (assume we can fully utilize the SM).
  int const numBlocksPerSM = 1536 / block.x;
  dim3 grid(std::min(int(m), multiProcessorCount * numBlocksPerSM));

  // Launch the cvt kernel.
  cvt_fp16_to_fp6<T><<<grid, block, 0, stream>>>(m, n, input, reinterpret_cast<uint8x6_t*>(output), reinterpret_cast<uint32_t*>(SFOuput));
}

// Instantiate the function.
template void invokeFP6Quantization(int m, int n, half const* input, int64_t* output, int32_t* SFOuput, int multiProcessorCount, cudaStream_t stream);

template void invokeFP6Quantization(int m, int n, __nv_bfloat16 const* input, int64_t* output, int32_t* SFOuput, int multiProcessorCount, cudaStream_t stream);
inline int getMultiProcessorCount() {
  static int multi_processor_count = []() {
    int device_id = 0;
    int count = 0;

    // Get the current CUDA device ID
    CHECK_CUDA_SUCCESS(cudaGetDevice(&device_id));

    // Get the number of multiprocessors for the current device
    CHECK_CUDA_SUCCESS(cudaDeviceGetAttribute(&count, cudaDevAttrMultiProcessorCount, device_id));

    return count;  // Initialize the static variable
  }();

  return multi_processor_count;  // Return the cached value on subsequent calls
}

void scaled_fp6_quant_sm120(torch::Tensor& output, torch::Tensor const& input, torch::Tensor& output_sf) {
  int32_t m = input.size(0);
  int32_t n = input.size(1);

  TORCH_CHECK(n % 32 == 0, "The N dimension must be multiple of 32.");

  int multiProcessorCount = getMultiProcessorCount();

  auto sf_out = static_cast<int32_t*>(output_sf.data_ptr());
  auto output_ptr = static_cast<int64_t*>(output.data_ptr());
  at::cuda::CUDAGuard device_guard{(char)input.get_device()};
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream(input.get_device());

  switch (input.scalar_type()) {
    case torch::kHalf: {
      auto input_ptr = reinterpret_cast<half const*>(input.data_ptr());
      invokeFP6Quantization(m, n, input_ptr, output_ptr, sf_out, multiProcessorCount, stream);
      break;
    }
    case torch::kBFloat16: {
      auto input_ptr = reinterpret_cast<__nv_bfloat16 const*>(input.data_ptr());
      invokeFP6Quantization(m, n, input_ptr, output_ptr, sf_out, multiProcessorCount, stream);
      break;
    }
    default: {
      std::cerr << "Observing: " << input.scalar_type() << " for the input datatype which is invalid";
      throw std::runtime_error("Unsupported input data type for quantize_to_fp6.");
    }
  }
}
lightx2v_kernel/csrc/gemm/mxfp8_quant_kernels_sm120.cu
View file @
e08c4f90
...
@@ -287,7 +287,7 @@ void scaled_fp8_quant_sm120(
  int32_t m = input.size(0);
  int32_t n = input.size(1);

-  TORCH_CHECK(n % 32 == 0, "The N dimension must be multiple of 16.");
+  TORCH_CHECK(n % 32 == 0, "The N dimension must be multiple of 32.");

  int multiProcessorCount = getMultiProcessorCount();
...
lightx2v_kernel/include/lightx2v_kernel_ops.h
View file @
e08c4f90
...
@@ -60,6 +60,8 @@ void scaled_fp4_quant_sm120(
void scaled_fp8_quant_sm120(torch::Tensor& output, torch::Tensor const& input, torch::Tensor& output_sf);

void scaled_fp6_quant_sm120(torch::Tensor& output, torch::Tensor const& input, torch::Tensor& output_sf);

void cutlass_scaled_mxfp6_mxfp8_mm_sm120(
    torch::Tensor& D,
...
lightx2v_kernel/python/lightx2v_kernel/gemm.py
View file @
e08c4f90
...
@@ -48,13 +48,26 @@ def scaled_fp4_quant(input: torch.Tensor, input_global_scale: torch.Tensor):
    # rounded_m = ((m + 128 - 1) // 128) * 128
    # scale_n = n // block_size
    # rounded_n = ((scale_n + 4 - 1) // 4) * 4
-    output_scale = torch.empty((((m + 128 - 1) // 128) * 128, (n // block_size + 4 - 1) // 4), device=device, dtype=torch.int32)
+    output_scale = torch.zeros((((m + 128 - 1) // 128) * 128, (n // block_size + 4 - 1) // 4), device=device, dtype=torch.int32)

    torch.ops.lightx2v_kernel.scaled_fp4_quant_sm120.default(output, input, output_scale, input_global_scale)
    output_scale = output_scale.view(torch.float8_e4m3fn)
    return output, output_scale


def scaled_fp6_quant(input: torch.Tensor):
    m, n = input.shape
    block_size = 32
    device = input.device

    output = torch.empty((m, 3 * n // 4), device=device, dtype=torch.uint8)
    output_scale = torch.zeros(((m + 128 - 1) // 128 * 128, (n // block_size + 4 - 1) // 4), device=device, dtype=torch.int32)

    torch.ops.lightx2v_kernel.scaled_fp6_quant_sm120.default(output, input, output_scale)
    output_scale = output_scale.view(torch.float8_e8m0fnu)
    return output, output_scale


def scaled_fp8_quant(input: torch.Tensor):
    m, n = input.shape
    block_size = 32
...
lightx2v_kernel/test/mxfp6_mxfp8/test_bench.py
0 → 100644
View file @
e08c4f90
import torch
from lightx2v_kernel.gemm import scaled_fp8_quant, scaled_fp6_quant, cutlass_scaled_mxfp6_mxfp8_mm
import time


class MMWeightMxfp8ActMxfp6:
    def __init__(self, weight, bias):
        self.load_fp6_weight(weight, bias)
        self.act_quant_func = self.act_quant_fp8
        self.set_alpha()

    @torch.no_grad()
    def apply(self, input_tensor):
        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
        output_tensor = cutlass_scaled_mxfp6_mxfp8_mm(input_tensor_quant, self.weight, input_tensor_scale, self.weight_scale, alpha=self.alpha, bias=self.bias)
        return output_tensor

    @torch.no_grad()
    def load_fp6_weight(self, weight, bias):
        self.weight, self.weight_scale = scaled_fp6_quant(weight)
        self.bias = bias

    def set_alpha(self):
        self.alpha = torch.tensor(1.0, dtype=torch.float32, device=self.weight.device)

    @torch.no_grad()
    def act_quant_fp8(self, x):
        return scaled_fp8_quant(x)


def test_speed(m, k, n):
    with torch.no_grad():
        input_tensor = torch.randn(m, k, dtype=torch.bfloat16).cuda()
        weight = torch.randn(n, k, dtype=torch.bfloat16, device="cuda")
        # bias = torch.randn(1, n, dtype=torch.bfloat16).cuda()
        bias = None

        mm = MMWeightMxfp8ActMxfp6(weight, bias)

        # warmup
        output_tensor = mm.apply(input_tensor)

        torch.cuda.synchronize()
        start_time = time.time()
        for i in range(100):
            output_tensor = mm.apply(input_tensor)
        torch.cuda.synchronize()
        end_time = time.time()

        lightx2v_kernel_time = (end_time - start_time) / 100
        print(f"lightx2v-kernel time: {lightx2v_kernel_time}")

        input_tensor = torch.randn(m, n, dtype=torch.bfloat16).cuda()
        weight = torch.randn(k, n, dtype=torch.bfloat16, device="cuda")
        bias = torch.randn(1, k, dtype=torch.bfloat16).cuda()

        linear = torch.nn.Linear(k, n, bias=False).cuda()
        linear.weight.data = weight
        # linear.bias.data = bias

        # warmup
        ref_output_tensor = linear(input_tensor)

        torch.cuda.synchronize()
        start_time = time.time()
        for i in range(100):
            ref_output_tensor = linear(input_tensor)
        torch.cuda.synchronize()
        end_time = time.time()

        ref_time = (end_time - start_time) / 100
        print(f"ref time: {ref_time}")

        print(f"speedup: {ref_time / lightx2v_kernel_time:.3f}")


def test_accuracy(m, k, n):
    with torch.no_grad():
        input_tensor = torch.randn(m, k, dtype=torch.bfloat16).cuda()
        weight = torch.randn(n, k, dtype=torch.bfloat16, device="cuda")
        # bias = torch.randn(1, n, dtype=torch.bfloat16).cuda()
        bias = None

        linear = torch.nn.Linear(k, n, bias=False).cuda()
        linear.weight.data = weight
        # linear.bias.data = bias
        ref_output_tensor = linear(input_tensor)

        mm = MMWeightMxfp8ActMxfp6(weight, bias)
        output_tensor = mm.apply(input_tensor)

        # print(f"ref_output_tensor: {ref_output_tensor}")
        # print(f"output_tensor: {output_tensor}")

        # cosine similarity
        cos = torch.nn.functional.cosine_similarity(ref_output_tensor.flatten(), output_tensor.flatten(), dim=0)
        print(f"cos : {cos}")


if __name__ == "__main__":
    test_sizes = [
        (32130, 5120, 5120),
        (512, 5120, 5120),
        (257, 5120, 5120),
        (32130, 5120, 13824),
        (32130, 13824, 5120),
        (75348, 5120, 5120),
        (75348, 13824, 5120),
        (32760, 1536, 1536),
        (512, 1536, 1536),
        (32760, 1536, 8960),
        (32760, 8960, 1536),
    ]

    for i, (m, k, n) in enumerate(test_sizes):
        print("-" * 30)
        print(f"Test {i + 1}: tensor size ({m}, {k}, {n})")
        test_accuracy(m, k, n)
        test_speed(m, k, n)
lightx2v_kernel/test/mxfp6_mxfp8/test_fake_quant.py
0 → 100644
View file @
e08c4f90
import torch
from torchao.prototype.mx_formats.constants import DTYPE_FP6_E3M2
from torchao.prototype.mx_formats.mx_tensor import to_mx, pack_uint6


def quant2mxfp8(x: torch.Tensor):
    block_size = 32
    m, _ = x.shape
    scale, output = to_mx(x, torch.float8_e4m3fn, block_size=block_size)
    return scale.reshape(m, -1), output


def quant2mxfp6(x: torch.Tensor):
    block_size = 32
    m, _ = x.shape
    scale, output = to_mx(x, DTYPE_FP6_E3M2, block_size=block_size, pack_fp6=False)
    return scale.reshape(m, -1), output


def scale_pad_and_swizzle(scale: torch.Tensor):
    m, s = scale.shape

    # pad m up to a multiple of 128, s up to a multiple of 4
    padded_m = (m + 127) // 128 * 128
    padded_s = (s + 3) // 4 * 4

    padded_scale = torch.empty(padded_m, padded_s, device=scale.device, dtype=scale.dtype)
    padded_scale[:m, :s] = scale

    # swizzle the padded scale
    swizzled_scale = padded_scale.reshape(padded_m // 128, 128, padded_s // 4, 4).reshape(padded_m // 128, 4, 32, padded_s // 4, 4).permute(0, 3, 2, 1, 4)

    return swizzled_scale.reshape(padded_m, padded_s)


###############################################################
# Packing kernel and func
###############################################################

import triton  # noqa: E402
import triton.language as tl  # noqa: E402


@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE_IN": 2}, num_warps=1),
        triton.Config({"BLOCK_SIZE_IN": 4}, num_warps=1),
        triton.Config({"BLOCK_SIZE_IN": 8}, num_warps=1),
        triton.Config({"BLOCK_SIZE_IN": 16}, num_warps=1),
    ],
    key=["n_mx_blocks"],
)
@triton.jit
def triton_pack_uint6_kernel(
    input_ptr,
    output_ptr,
    n_mx_blocks,
    MX_BLOCK_SIZE: tl.constexpr,
    PACKED_MX_BLOCK_SIZE: tl.constexpr,
    BLOCK_SIZE_IN: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE_IN

    # input_ptr is shape [n_mx_blocks, MX_BLOCK_SIZE]
    # Load BLOCK_SIZE rows of input_ptr
    offsets_rows = block_start + tl.arange(0, BLOCK_SIZE_IN)
    offsets_cols = tl.arange(0, MX_BLOCK_SIZE // 4)
    offsets = offsets_rows[:, None] * MX_BLOCK_SIZE + (4 * offsets_cols[None, :])
    mask = (offsets_rows[:, None] < n_mx_blocks) & (offsets_cols[None, :] < MX_BLOCK_SIZE // 4)

    # x is shape [BLOCK_SIZE, MX_BLOCK_SIZE]
    x_0 = tl.load(input_ptr + offsets, mask=mask)
    x_1 = tl.load(input_ptr + offsets + 1, mask=mask)
    x_2 = tl.load(input_ptr + offsets + 2, mask=mask)
    x_3 = tl.load(input_ptr + offsets + 3, mask=mask)

    # 4 fp6 values a b c d. a: [a5 a4 a3 a2 a1 a0], b..., c..., d...
    # 3 uint8 values pack0 pack1 pack2, in the layout cutlass expects:
    # packed0: [b1 b0][a5 a4 a3 a2 a1 a0]
    # packed1: [c3 c2 c1 c0][b5 b4 b3 b2]
    # packed2: [d5 d4 d3 d2 d1 d0][c5 c4]
    bits_packed0 = (x_1 << 6) | x_0
    bits_packed1 = (x_2 << 4) | (x_1 >> 2)
    bits_packed2 = (x_3 << 2) | (x_2 >> 4)

    # Store values in a uint8 tensor of length `3 * MX_BLOCK_SIZE / 4`
    offsets_out_4_a = offsets_rows[:, None] * PACKED_MX_BLOCK_SIZE + 3 * offsets_cols[None, :]
    offsets_out_4_b = offsets_rows[:, None] * PACKED_MX_BLOCK_SIZE + 3 * offsets_cols[None, :] + 1
    offsets_out_2 = offsets_rows[:, None] * PACKED_MX_BLOCK_SIZE + 3 * offsets_cols[None, :] + 2

    # Store into output tensor
    tl.store(
        output_ptr + offsets_out_4_a,
        bits_packed0,
        mask=mask,
    )
    tl.store(
        output_ptr + offsets_out_4_b,
        bits_packed1,
        mask=mask,
    )
    tl.store(
        output_ptr + offsets_out_2,
        bits_packed2,
        mask=mask,
    )


def pack_uint6(uint8_data: torch.Tensor) -> torch.Tensor:
    # ensure input data is contiguous before passing to kernel
    assert uint8_data.is_contiguous()

    # tensor should already be of shape [..., mx_block_size]
    mx_block_size = uint8_data.shape[-1]
    assert mx_block_size % 4 == 0

    # effective mx block size since we're packing 4 fp6 into 3 uint8
    packed_mx_block_size = 3 * mx_block_size // 4
    packed_shape = [uint8_data.shape[0], packed_mx_block_size]
    n_mx_blocks = uint8_data.numel() // mx_block_size

    grid = lambda meta: (triton.cdiv(n_mx_blocks, meta["BLOCK_SIZE_IN"]),)  # noqa: E731

    # contiguous uint8 container in which we can store the packed tensor
    packed_uint8_data = torch.empty(packed_shape, dtype=torch.uint8, device=uint8_data.device)

    triton_pack_uint6_kernel[grid](
        uint8_data,
        packed_uint8_data,
        n_mx_blocks,
        MX_BLOCK_SIZE=mx_block_size,
        PACKED_MX_BLOCK_SIZE=packed_mx_block_size,
    )

    return packed_uint8_data


M = [257, 512, 1024, 13325, 32130, 32760]  # , 75348
N = [1536, 5120, 8960]  # , 13824
K = [128, 256, 512, 1024, 2048, 4096]  # , 13824

for m in M:
    for n in N:
        for k in K:
            x = torch.randn(m, k, device="cuda", dtype=torch.bfloat16)
            w = torch.randn(n, k, device="cuda", dtype=torch.bfloat16)

            # execute quant
            x_scale, x_quant = quant2mxfp8(x)
            w_scale, w_quant = quant2mxfp6(w)

            # pack fp6 for cutlass
            w_quant_packed = pack_uint6(w_quant.reshape(-1, 32))

            # pad and swizzle scale
            padded_and_swizzled_x_scale = scale_pad_and_swizzle(x_scale)
            padded_and_swizzled_w_scale = scale_pad_and_swizzle(w_scale)

            # ref mm result
            ref_mm = torch.nn.functional.linear(x, w).to(torch.bfloat16)

            # custom scaled mm
            from lightx2v_kernel.gemm import cutlass_scaled_mxfp6_mxfp8_mm

            alpha = torch.tensor(1.0, device="cuda", dtype=torch.float32)
            bias = None
            x_quant = x_quant.reshape(m, k).view(torch.uint8)
            w_quant_packed = w_quant_packed.reshape(n, 3 * k // 4)
            custom_mm = cutlass_scaled_mxfp6_mxfp8_mm(x_quant, w_quant_packed, padded_and_swizzled_x_scale, padded_and_swizzled_w_scale, alpha, bias)

            # cal snr
            from lightx2v_kernel.utils import error

            print(f"m: {m}, n: {n}, k: {k}, error: {error(ref_mm, custom_mm)}")

            # cal cos
            cos_sim = torch.nn.functional.cosine_similarity(ref_mm.flatten(), custom_mm.flatten(), dim=0)
            print(f"m: {m}, n: {n}, k: {k}, cos_sim: {cos_sim}")
lightx2v_kernel/test/mxfp6_mxfp8/test_mm_tflops.py
0 → 100644
View file @
e08c4f90
import torch
from lightx2v_kernel.gemm import cutlass_scaled_mxfp6_mxfp8_mm


"""
input_shape = (1024, 2048)
weight_shape = (4096, 2048)

input_tensor_quant = (torch.rand((1024, 1024), device="cuda") * 10).to(torch.uint8)
weight = (torch.rand((4096, 1024), device="cuda") * 10).to(torch.uint8)
input_tensor_scale = torch.rand(1024, 128, device="cuda").to(torch.float8_e8m0fnu)
weight_scale = torch.rand(4096, 128, device="cuda").to(torch.float8_e8m0fnu)
alpha = torch.tensor(1.0, device="cuda").to(torch.float32)
bias = None
"""


def test_mm(input_tensor_quant, weight, input_tensor_scale, weight_scale, alpha, bias):
    output_tensor = cutlass_scaled_mxfp6_mxfp8_mm(input_tensor_quant, weight, input_tensor_scale, weight_scale, alpha=alpha, bias=bias)
    return output_tensor


def test_tflops(input_shape, weight_shape, num_warmup=10, num_runs=100):
    """
    Measure the TFLOPS of test_mm.
    """
    # Create input data
    input_tensor_quant = (torch.rand((input_shape[0], input_shape[1]), device="cuda") * 10).to(torch.uint8)
    weight = (torch.rand((weight_shape[0], 3 * weight_shape[1] // 4), device="cuda") * 10).to(torch.uint8)
    input_tensor_scale = torch.rand(((input_shape[0] + 128 - 1) // 128) * 128, (input_shape[1] // 32 + 4 - 1) // 4 * 4, device="cuda").to(torch.float8_e8m0fnu)
    weight_scale = torch.rand(weight_shape[0], weight_shape[1] // 32, device="cuda").to(torch.float8_e8m0fnu)
    alpha = torch.tensor(1.0, device="cuda", dtype=torch.float32)
    bias = None

    # Warm up the GPU
    for _ in range(num_warmup):
        test_mm(input_tensor_quant, weight, input_tensor_scale, weight_scale, alpha, bias)

    # Synchronize the GPU
    torch.cuda.synchronize()

    # Create GPU events for precise timing
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    # Measure time
    start_event.record()
    for _ in range(num_runs):
        result = test_mm(input_tensor_quant, weight, input_tensor_scale, weight_scale, alpha, bias)
    end_event.record()

    # Synchronize and compute elapsed time
    torch.cuda.synchronize()
    elapsed_time_ms = start_event.elapsed_time(end_event)
    elapsed_time_s = elapsed_time_ms / 1000.0

    # Compute FLOPs
    # Matrix multiplication A(M x K) @ B(K x N) = C(M x N)
    # M = batch_size, K = input_dim, N = output_dim
    M = input_shape[0]
    K = input_shape[1]
    N = weight_shape[0]

    # FLOPs per matmul = 2 * M * N * K (each output element needs K multiplies and K adds)
    flops_per_run = 2 * M * N * K
    total_flops = flops_per_run * num_runs

    # Compute TFLOPS (trillions of floating-point operations per second)
    tflops = total_flops / (elapsed_time_s * 1e12)

    print("Results:")
    print(f"  input shape: {input_shape} (M={M}, K={K})")
    print(f"  weight shape: {weight_shape} (N={N}, K={K})")
    print(f"  output shape: ({M}, {N})")
    print(f"  runs: {num_runs}")
    print(f"  total time: {elapsed_time_ms:.2f} ms")
    print(f"  average time per run: {elapsed_time_ms / num_runs:.4f} ms")
    print(f"  FLOPs per run: {flops_per_run / 1e9:.2f} GFLOPS")
    print(f"  total FLOPs: {total_flops / 1e12:.2f} TFLOPS")
    print(f"  compute throughput: {tflops:.2f} TFLOPS")

    return tflops


if __name__ == "__main__":
    # Test matrix multiplications of different sizes
    # (m,k) (n,k)
    test_cases = [
        ((32130, 5120), (5120, 5120)),
        ((512, 1536), (1536, 1536)),
        ((512, 5120), (5120, 5120)),
        ((257, 5120), (5120, 5120)),
        ((32130, 5120), (13824, 5120)),
        ((32130, 13824), (5120, 13824)),
        ((75348, 5120), (5120, 5120)),
        ((75348, 5120), (13824, 5120)),
        ((75348, 13824), (5120, 13824)),
        ((32760, 1536), (1536, 1536)),
        ((512, 1536), (1536, 1536)),
        ((32760, 1536), (8960, 1536)),
        ((32760, 8960), (1536, 8960)),
    ]

    print("=== test_mm TFLOPS benchmark ===\n")

    for i, (input_shape, weight_shape) in enumerate(test_cases):
        print(f"Test {i + 1}: input shape {input_shape}, weight shape {weight_shape}")
        print("-" * 60)
        tflops = test_tflops(input_shape, weight_shape)
        print(f"✓ Test completed, performance: {tflops:.2f} TFLOPS\n")

    print("=== Done ===")
lightx2v_kernel/test/mxfp6_mxfp8/test_mxfp6_quant.py
0 → 100644
View file @
e08c4f90
import unittest

import torch
from lightx2v_kernel.gemm import cutlass_scaled_mxfp6_mxfp8_mm
from lightx2v_kernel.gemm import scaled_fp6_quant, scaled_fp8_quant
from torch.nn.functional import linear

from lightx2v_kernel.utils import error, benchmark


class TestQuantBF162MXFP6(unittest.TestCase):
    def setUp(self):
        self.tokens = [128, 257, 512, 1024, 13325, 32130, 32760]  # , 75348
        self.channels = [128, 1536, 5120, 8960]  # , 13824
        self.hiddenDims = [128, 1536, 3072, 5120, 8960, 12800]  # , 13824
        self.device = "cuda"
        self.dtype = torch.bfloat16

    def test_accuracy(self):
        """Test the accuracy of quantization from BF16 to MXFP6."""
        for m in self.tokens:
            for k in self.hiddenDims:
                for n in self.channels:
                    with self.subTest(shape=[m, k, n]):
                        activation = torch.randn(m, k, dtype=self.dtype, device=self.device)
                        activation_quant_pred, activation_scale_pred = scaled_fp8_quant(activation)

                        weight = torch.randn(n, k, dtype=self.dtype, device=self.device)
                        weight_quant_pred, weight_scale_pred = scaled_fp6_quant(weight)

                        alpha = torch.tensor(1.0, device=self.device, dtype=torch.float32)

                        mm_pred = cutlass_scaled_mxfp6_mxfp8_mm(activation_quant_pred, weight_quant_pred, activation_scale_pred, weight_scale_pred, alpha=alpha)

                        mm_real = linear(activation, weight, bias=None).to(torch.bfloat16)

                        self.assertTrue(
                            error(mm_pred, mm_real) < 1e-2,
                            f"Accuracy test failed for shape {m, k, n}: Error {error(mm_pred, mm_real)} exceeds threshold.",
                        )

    def test_performance(self):
        """Benchmark the performance of Activation quantization from BF16 to MXFP6."""
        for m in self.tokens:
            for k in self.hiddenDims:
                with self.subTest(shape=[m, k]):
                    input = torch.randn(m, k, dtype=self.dtype, device=self.device)
                    shape = [m, k]
                    tflops = 2 * (m * k / 1024**4)
                    benchmark(scaled_fp6_quant, shape, tflops, 100, input)


if __name__ == "__main__":
    unittest.main()
lightx2v_kernel/test/mxfp6_mxfp8/test_quant_mem_utils.py
0 → 100644
View file @
e08c4f90
import torch
from lightx2v_kernel.gemm import scaled_fp6_quant


def quantize_fp6(x):
    return scaled_fp6_quant(x)


def test_memory_bandwidth(func, x, num_warmup=10, num_runs=100):
    """
    Measure the memory bandwidth of a function.
    """
    # Warm up the GPU
    for _ in range(num_warmup):
        func(x)

    # Synchronize the GPU
    torch.cuda.synchronize()

    # Create GPU events for precise timing
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    # Measure time
    start_event.record()
    for _ in range(num_runs):
        result = func(x)
    end_event.record()

    # Synchronize and compute elapsed time
    torch.cuda.synchronize()
    elapsed_time_ms = start_event.elapsed_time(end_event)
    elapsed_time_s = elapsed_time_ms / 1000.0

    # Compute data volume
    input_bytes = x.numel() * x.element_size()  # input bytes
    # After FP6 quantization, each element occupies 3/4 byte
    output_bytes = x.numel() * (3 / 4)  # FP6 output bytes
    scale_bytes = x.numel() / 32  # group_size = 32

    # Total data transferred (read input + write output + scale)
    total_bytes = (input_bytes + output_bytes + scale_bytes) * num_runs

    # Compute bandwidth
    bandwidth_gbps = (total_bytes / elapsed_time_s) / (1024**3)  # GB/s

    print("Results:")
    print(f"  input tensor shape: {x.shape}")
    print(f"  input dtype: {x.dtype}")
    print(f"  runs: {num_runs}")
    print(f"  total time: {elapsed_time_ms:.2f} ms")
    print(f"  average time per run: {elapsed_time_ms / num_runs:.4f} ms")
    print(f"  input size: {input_bytes / (1024**2):.2f} MB")
    print(f"  output size: {output_bytes / (1024**2):.2f} MB")
    print(f"  total data transferred: {total_bytes / (1024**3):.2f} GB")
    print(f"  memory bandwidth: {bandwidth_gbps:.2f} GB/s")

    return bandwidth_gbps


if __name__ == "__main__":
    # Test tensors of different sizes
    test_sizes = [
        # (1, 1024),
        # (1, 2048),
        # (1, 4096),
        # (1, 8192),
        # (1, 16384),
        # (1, 32768),
        # (2, 1024),
        # (2, 2048),
        # (2, 4096),
        # (2, 8192),
        # (2, 16384),
        # (2, 32768),
        # (4, 1024),
        # (4, 2048),
        # (4, 4096),
        # (4, 8192),
        # (4, 16384),
        # (4, 32768),
        # (128, 1024),
        # (128, 2048),
        # (128, 4096),
        # (128, 8192),
        # (128, 16384),
        # (128, 32768),
        # (512, 1024),
        # (512, 2048),
        # (512, 4096),
        # (512, 8192),
        # (512, 16384),
        # (512, 32768),
        # (1024, 1024),
        # (1024, 2048),
        # (1024, 4096),
        # (1024, 8192),
        # (1024, 16384),
        # (1024, 32768),
        # (2048, 1024),
        # (2048, 2048),
        # (2048, 4096),
        # (2048, 8192),
        # (2048, 16384),
        # (2048, 32768),
        # (4096, 1024),
        # (4096, 2048),
        # (4096, 4096),
        # (4096, 8192),
        # (4096, 16384),
        # (4096, 32768),
        # (8192, 1024),
        # (8192, 2048),
        # (8192, 4096),
        # (8192, 8192),
        # (8192, 16384),
        # (8192, 32768),
        # (16384, 1024),
        # (16384, 2048),
        # (16384, 4096),
        # (16384, 8192),
        # (16384, 16384),
        # (16384, 32768),
        # (32768, 1024),
        # (32768, 2048),
        # (32768, 4096),
        # (32768, 8192),
        # (32768, 16384),
        # (32768, 32768),
        (32130, 5120),
        (512, 5120),
        (257, 5120),
        (32130, 13824),
        (75348, 5120),
        (75348, 13824),
        (32760, 1536),
        (512, 3072),
        (512, 1536),
        (32760, 8960),
    ]

    print("=== quantize_fp6 memory bandwidth test ===\n")

    for i, (h, w) in enumerate(test_sizes):
        print(f"Test {i + 1}: tensor size ({h}, {w})")
        print("-" * 50)

        x = torch.randn(h, w, dtype=torch.bfloat16).cuda()

        try:
            bandwidth = test_memory_bandwidth(quantize_fp6, x)
            print(f"✓ Test completed, bandwidth: {bandwidth:.2f} GB/s\n")
        except Exception as e:
            print(f"✗ Test failed: {e}\n")

    print("=== Done ===")
requirements.txt
View file @
e08c4f90
...
@@ -18,3 +18,6 @@ sgl-kernel
qtorch
ftfy
easydict
+gradio
+aiohttp
+pydantic
requirements_win.txt
0 → 100644
View file @
e08c4f90
packaging
ninja
diffusers
transformers
tokenizers
accelerate
safetensors
opencv-python
numpy
imageio
imageio-ffmpeg
einops
loguru
qtorch
ftfy
easydict
gradio
aiohttp
pydantic
scripts/cache/readme.md
View file @
e08c4f90
# Cache

## Cache-based acceleration

- In diffusion-model inference, cache reuse is an important acceleration technique.
- Its core idea is to skip redundant computation at some timesteps and improve inference efficiency by reusing previously cached results.
- The key question is deciding at which timesteps to reuse the cache; this is usually judged dynamically from changes in model state or an error threshold.
- During inference, key intermediate results such as features, residuals, and attention outputs are cached. When a reusable timestep is reached, the cached content is used directly and the current output is reconstructed with approximations such as a Taylor expansion, reducing repeated computation and enabling efficient inference.
# Feature Caching

## TeaCache

The core idea of `TeaCache` is to accumulate the **relative L1** distance between the inputs of adjacent timesteps; once the accumulated distance reaches a preset threshold, the current timestep is judged eligible for cache reuse.

- Concretely, at each inference step the algorithm computes the relative L1 distance between the current input and the previous input and adds it to a running total.
- When the accumulated distance exceeds the threshold, the model state is considered to have changed enough, so the most recently cached content is reused directly and part of the redundant computation is skipped. This significantly reduces the number of forward passes and speeds up inference (a minimal sketch of this decision rule follows this section).

The config files for feature caching are available [here](https://github.com/ModelTC/lightx2v/tree/main/configs/caching).

By specifying --config_json to the specific config file, you can test different cache algorithms.

In practice, TeaCache achieves a clear speedup while preserving generation quality. A before/after comparison:

| Before | After |
|:------:|:------:|
| Single H200, inference time: 58s | Single H200, inference time: 17.9s |
| ![Before](../../assets/gifs/1.gif) | ![After](../../assets/gifs/2.gif) |

- Speedup: **3.24**
- Reference paper: [https://arxiv.org/abs/2411.19108](https://arxiv.org/abs/2411.19108)
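Below is a minimal, illustrative sketch of the accumulate-and-threshold decision rule described above. It is not the LightX2V implementation; the class name, the default threshold value, and the bookkeeping details are assumptions for illustration only.

```python
import torch


class TeaCacheDecider:
    """Hypothetical helper illustrating a TeaCache-style reuse decision."""

    def __init__(self, threshold: float = 0.26):
        self.threshold = threshold  # accumulated-distance threshold (assumed value)
        self.accumulated = 0.0
        self.prev_input = None

    def should_reuse(self, x: torch.Tensor) -> bool:
        if self.prev_input is None:
            self.prev_input = x.detach()
            return False  # the first step always runs the full forward
        # relative L1 distance between adjacent timestep inputs
        rel_l1 = (x - self.prev_input).abs().mean() / (self.prev_input.abs().mean() + 1e-8)
        self.accumulated += rel_l1.item()
        self.prev_input = x.detach()
        if self.accumulated >= self.threshold:
            # enough change has accumulated: reuse the cached output for this step,
            # following the description above, and reset the accumulator
            self.accumulated = 0.0
            return True
        return False  # otherwise run the full forward pass
```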
Please refer to our feature caching doc:

[English doc: Feature Caching](https://lightx2v-en.readthedocs.io/en/latest/method_tutorials/cache.html)

## TaylorSeer Cache

The core of `TaylorSeer Cache` is to recompute the cached content with a Taylor expansion and use it as residual compensation at cache-reuse timesteps. Instead of simply replaying the historical cache at a reused timestep, the current output is approximately reconstructed via a Taylor expansion. This further improves output accuracy while still reducing computation: the expansion captures small changes in model state, so the error introduced by cache reuse is compensated, preserving generation quality while accelerating inference (a minimal extrapolation sketch follows this section).

`TaylorSeer Cache` suits scenarios with higher accuracy requirements and further improves inference quality on top of plain cache reuse.

| Before | After |
|:------:|:------:|
| Single H200, inference time: 57.7s | Single H200, inference time: 41.3s |
| ![Before](../../assets/gifs/3.gif) | ![After](../../assets/gifs/4.gif) |

- Speedup: **1.39**
- Reference paper: [https://arxiv.org/abs/2503.06923](https://arxiv.org/abs/2503.06923)
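The following is a minimal sketch of the Taylor-expansion reuse step described above, not the LightX2V implementation. It assumes the last few computed features are kept in a list (newest last) and uses backward finite differences as derivative estimates; the function name and the uniform-step assumption are illustrative.

```python
import torch


def taylor_extrapolate(cached_feats: list, steps_since_refresh: int, order: int = 2) -> torch.Tensor:
    """Approximate the feature at a skipped timestep from recently computed features."""
    approx = cached_feats[-1].clone()
    # first-order term from a backward difference (~ first derivative)
    if order >= 1 and len(cached_feats) >= 2:
        d1 = cached_feats[-1] - cached_feats[-2]
        approx = approx + d1 * steps_since_refresh
    # second-order term from a second backward difference (~ second derivative)
    if order >= 2 and len(cached_feats) >= 3:
        d2 = cached_feats[-1] - 2 * cached_feats[-2] + cached_feats[-3]
        approx = approx + d2 * (steps_since_refresh**2) / 2
    return approx
```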
## AdaCache

The core idea of `AdaCache` is to adjust the cache-reuse stride dynamically based on part of the cached content in a designated block.

- The algorithm analyzes the feature difference of a specific block between two adjacent timesteps and adaptively decides the interval to the next cache-reuse timestep from the magnitude of that difference.
- When the model state changes little, the stride grows and the cache is refreshed less often; when the state changes a lot, the stride shrinks to preserve output quality.

In this way the caching strategy flexibly adapts to the dynamics of the actual inference run, achieving better acceleration and generation quality. AdaCache suits applications with high demands on both inference speed and output quality (a minimal stride-selection sketch follows below).

| Before | After |
|:------:|:------:|
| Single H200, inference time: 227s | Single H200, inference time: 83s |
| ![Before](../../assets/gifs/5.gif) | ![After](../../assets/gifs/6.gif) |

- Speedup: **2.73**
- Reference paper: [https://arxiv.org/abs/2411.02397](https://arxiv.org/abs/2411.02397)
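A hedged sketch of the adaptive stride selection described above follows; the thresholds and the maximum stride are made-up illustration values, not LightX2V defaults.

```python
import torch


def next_reuse_stride(prev_block_feat: torch.Tensor, cur_block_feat: torch.Tensor,
                      small: float = 0.05, large: float = 0.15, max_stride: int = 6) -> int:
    """Decide how many timesteps the next cache reuse may span, from the feature
    change of the monitored block between two adjacent steps."""
    diff = ((cur_block_feat - prev_block_feat).abs().mean() / (prev_block_feat.abs().mean() + 1e-8)).item()
    if diff < small:  # state barely changed: keep reusing for a long stretch
        return max_stride
    if diff < large:  # moderate change: shorter reuse window
        return max_stride // 2
    return 1  # large change: recompute at the next step
```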
## CustomCache

`CustomCache` combines the strengths of `TeaCache` and `TaylorSeer Cache`.

- It keeps the timeliness and soundness of `TeaCache`'s caching decision, using a dynamic threshold to judge when to reuse the cache.
- At the same time, it applies `TaylorSeer`'s Taylor-expansion method to make full use of the already cached content.

This not only decides the timing of cache reuse efficiently but also exploits the cached content as much as possible, improving output accuracy and generation quality. In practice, `CustomCache` produces better videos on several content-generation tasks than using `TeaCache`, `TaylorSeer Cache`, or `AdaCache` alone, making it one of the best-performing cache acceleration algorithms overall.

| Before | After |
|:------:|:------:|
| Single H200, inference time: 57.9s | Single H200, inference time: 16.6s |
| ![Before](../../assets/gifs/7.gif) | ![After](../../assets/gifs/8.gif) |

- Speedup: **3.49**

[Chinese doc: Feature Caching](https://lightx2v-zhcn.readthedocs.io/zh-cn/latest/method_tutorials/cache.html)
scripts/cache/run_wan_i2v_taylor.sh
deleted
100644 → 0
View file @
12bfd120
#!/bin/bash

# set path first
lightx2v_path=
model_path=

# check section
if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
    cuda_devices=0
    echo "Warn: CUDA_VISIBLE_DEVICES is not set, using default value: ${cuda_devices}, change at shell script or set env variable."
    export CUDA_VISIBLE_DEVICES=${cuda_devices}
fi

if [ -z "${lightx2v_path}" ]; then
    echo "Error: lightx2v_path is not set. Please set this variable first."
    exit 1
fi

if [ -z "${model_path}" ]; then
    echo "Error: model_path is not set. Please set this variable first."
    exit 1
fi

export TOKENIZERS_PARALLELISM=false

export PYTHONPATH=${lightx2v_path}:$PYTHONPATH

export DTYPE=BF16
export ENABLE_PROFILING_DEBUG=true
export ENABLE_GRAPH_MODE=false

python -m lightx2v.infer \
--model_cls wan2.1 \
--task t2v \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/caching/taylorseer/wan_i2v_tea_480p.json \
--prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." \
--negative_prompt "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \
--image_path ${lightx2v_path}/assets/inputs/imgs/img_0.jpg \
--save_video_path ${lightx2v_path}/save_results/output_lightx2v_wan_i2v_taylor.mp4
scripts/cache/run_wan_t2v_custom.sh
deleted
100644 → 0
View file @
12bfd120
#!/bin/bash

# set path first
lightx2v_path=
model_path=

# check section
if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
    cuda_devices=0
    echo "Warn: CUDA_VISIBLE_DEVICES is not set, using default value: ${cuda_devices}, change at shell script or set env variable."
    export CUDA_VISIBLE_DEVICES=${cuda_devices}
fi

if [ -z "${lightx2v_path}" ]; then
    echo "Error: lightx2v_path is not set. Please set this variable first."
    exit 1
fi

if [ -z "${model_path}" ]; then
    echo "Error: model_path is not set. Please set this variable first."
    exit 1
fi

export TOKENIZERS_PARALLELISM=false

export PYTHONPATH=${lightx2v_path}:$PYTHONPATH

export DTYPE=BF16
export ENABLE_PROFILING_DEBUG=true
export ENABLE_GRAPH_MODE=false

python -m lightx2v.infer \
--model_cls wan2.1 \
--task t2v \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/caching/custom/wan_t2v_custom_1_3b.json \
--prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." \
--negative_prompt "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \
--save_video_path ${lightx2v_path}/save_results/output_lightx2v_wan_t2v_custom.mp4
scripts/cache/run_wan_t2v_tea.sh
View file @
e08c4f90
...
@@ -32,7 +32,7 @@ python -m lightx2v.infer \
--model_cls wan2.1 \
--task t2v \
--model_path $model_path \
---config_json ${lightx2v_path}/configs/caching/teacache/wan_t2v_1_3b.json \
+--config_json ${lightx2v_path}/configs/caching/teacache/wan_t2v_1_3b_tea_480p.json \
--prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." \
--negative_prompt "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \
--save_video_path ${lightx2v_path}/save_results/output_lightx2v_wan_t2v_tea.mp4
scripts/cogvideox/readme.md
deleted
100644 → 0
View file @
12bfd120
## todo
scripts/cogvideox/run_cogvideox_t2v.sh
deleted
100755 → 0
View file @
12bfd120
#!/bin/bash

# set path first
lightx2v_path=
model_path=

# check section
if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
    cuda_devices=0
    echo "Warn: CUDA_VISIBLE_DEVICES is not set, using default value: ${cuda_devices}, change at shell script or set env variable."
    export CUDA_VISIBLE_DEVICES=${cuda_devices}
fi

if [ -z "${lightx2v_path}" ]; then
    echo "Error: lightx2v_path is not set. Please set this variable first."
    exit 1
fi

if [ -z "${model_path}" ]; then
    echo "Error: model_path is not set. Please set this variable first."
    exit 1
fi

export TOKENIZERS_PARALLELISM=false

export PYTHONPATH=${lightx2v_path}:$PYTHONPATH

export ENABLE_PROFILING_DEBUG=true
export ENABLE_GRAPH_MODE=false
export DTYPE=BF16
export PYTHONPATH=/mtc/wushuo/VideoGen/diffusers:$PYTHONPATH

python -m lightx2v.infer \
--model_cls cogvideox \
--task t2v \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/cogvideox/cogvideox_t2v.json \
--prompt "A little girl smile." \
--save_video_path ${lightx2v_path}/save_results/output_lightx2v_cogvideox_t2v.mp4
scripts/deploy/readme.md
deleted
100644 → 0
View file @
12bfd120
## todo
scripts/deploy/start_dit_server.sh
deleted
100755 → 0
View file @
12bfd120
#!/bin/bash

# set path first
lightx2v_path=
model_path=

# check section
if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
    cuda_devices=0
    echo "Warn: CUDA_VISIBLE_DEVICES is not set, using default value: ${cuda_devices}, change at shell script or set env variable."
    export CUDA_VISIBLE_DEVICES=${cuda_devices}
fi

if [ -z "${lightx2v_path}" ]; then
    echo "Error: lightx2v_path is not set. Please set this variable first."
    exit 1
fi

if [ -z "${model_path}" ]; then
    echo "Error: model_path is not set. Please set this variable first."
    exit 1
fi

export TOKENIZERS_PARALLELISM=false

export PYTHONPATH=${lightx2v_path}:$PYTHONPATH

export ENABLE_PROFILING_DEBUG=true
export ENABLE_GRAPH_MODE=false

python -m lightx2v.common.apis.dit \
--model_cls wan2.1 \
--task i2v \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/deploy/wan_i2v.json \
--port 9000