xuwx1 / LightX2V / Commits

Commit 5103aef7
Authored Jul 16, 2025 by Xtra, committed by GitHub on Jul 16, 2025
Parent: 514ea716

add mxfp6 quant kernel and some tests (#126)
Showing 11 changed files with 994 additions and 2 deletions:

- lightx2v_kernel/CMakeLists.txt (+1, -0)
- lightx2v_kernel/csrc/common_extension.cc (+4, -0)
- lightx2v_kernel/csrc/gemm/mxfp6_quant_kernels_sm120.cu (+348, -0)
- lightx2v_kernel/csrc/gemm/mxfp8_quant_kernels_sm120.cu (+1, -1)
- lightx2v_kernel/include/lightx2v_kernel_ops.h (+2, -0)
- lightx2v_kernel/python/lightx2v_kernel/gemm.py (+14, -1)
- lightx2v_kernel/test/mxfp6_mxfp8/test_bench.py (+121, -0)
- lightx2v_kernel/test/mxfp6_mxfp8/test_fake_quant.py (+181, -0)
- lightx2v_kernel/test/mxfp6_mxfp8/test_mm_tflops.py (+115, -0)
- lightx2v_kernel/test/mxfp6_mxfp8/test_mxfp6_quant.py (+49, -0)
- lightx2v_kernel/test/mxfp6_mxfp8/test_quant_mem_utils.py (+158, -0)
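For orientation, the end-to-end flow exercised by the new tests looks like the sketch below. It is a minimal example based on the Python wrappers and test files added in this commit; the shapes are illustrative only.

```python
import torch
from lightx2v_kernel.gemm import scaled_fp6_quant, scaled_fp8_quant, cutlass_scaled_mxfp6_mxfp8_mm

# Illustrative shapes; the commit's tests sweep many (m, k, n) combinations.
m, k, n = 512, 5120, 5120
activation = torch.randn(m, k, dtype=torch.bfloat16, device="cuda")
weight = torch.randn(n, k, dtype=torch.bfloat16, device="cuda")

a_q, a_scale = scaled_fp8_quant(activation)   # activations -> MXFP8
w_q, w_scale = scaled_fp6_quant(weight)       # weights -> MXFP6 (new in this commit)

alpha = torch.tensor(1.0, dtype=torch.float32, device="cuda")
out = cutlass_scaled_mxfp6_mxfp8_mm(a_q, w_q, a_scale, w_scale, alpha=alpha, bias=None)
```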
lightx2v_kernel/CMakeLists.txt

```diff
@@ -94,6 +94,7 @@ set(SOURCES
   "csrc/gemm/nvfp4_scaled_mm_kernels_sm120.cu"
   "csrc/gemm/nvfp4_quant_kernels_sm120.cu"
   "csrc/gemm/mxfp8_quant_kernels_sm120.cu"
+  "csrc/gemm/mxfp6_quant_kernels_sm120.cu"
   "csrc/gemm/mxfp6_mxfp8_scaled_mm_kernels_sm120.cu"
   "csrc/gemm/mxfp8_scaled_mm_kernels_sm120.cu"
   "csrc/common_extension.cc"
```
lightx2v_kernel/csrc/common_extension.cc

```diff
@@ -20,6 +20,10 @@ TORCH_LIBRARY_FRAGMENT(lightx2v_kernel, m) {
   m.def(
       "scaled_fp8_quant_sm120(Tensor! output, Tensor! input,"
       " Tensor! output_scale) -> ()");
   m.impl("scaled_fp8_quant_sm120", torch::kCUDA, &scaled_fp8_quant_sm120);
+  m.def(
+      "scaled_fp6_quant_sm120(Tensor! output, Tensor! input,"
+      " Tensor! output_scale) -> ()");
+  m.impl("scaled_fp6_quant_sm120", torch::kCUDA, &scaled_fp6_quant_sm120);

   m.def(
       "cutlass_scaled_mxfp6_mxfp8_mm_sm120(Tensor! out, Tensor mat_a, Tensor mat_b, Tensor scales_a, Tensor scales_b, Tensor "
   ...
```
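Once registered here, the op is reachable from Python through `torch.ops`; the wrapper added to gemm.py later in this diff calls it roughly as below. This is a sketch mirroring that wrapper, with illustrative shapes.

```python
import torch

m, n = 512, 5120  # illustrative only
inp = torch.randn(m, n, dtype=torch.bfloat16, device="cuda")
out = torch.empty((m, 3 * n // 4), dtype=torch.uint8, device="cuda")
out_scale = torch.zeros(((m + 127) // 128 * 128, (n // 32 + 3) // 4), dtype=torch.int32, device="cuda")

torch.ops.lightx2v_kernel.scaled_fp6_quant_sm120.default(out, inp, out_scale)
scales = out_scale.view(torch.float8_e8m0fnu)
```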
lightx2v_kernel/csrc/gemm/mxfp6_quant_kernels_sm120.cu (new file, 0 → 100644)

```cuda
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_fp8.h>
#include <cuda_fp6.h>
#include <cuda_runtime.h>
#include <cuda_runtime_api.h>
#include <torch/all.h>

#include "utils.h"

// Get type2 from type or vice versa (applied to half and bfloat16)
template <typename T>
struct TypeConverter {
  using Type = half2;
};  // keep for generality

template <>
struct TypeConverter<half2> {
  using Type = half;
};

template <>
struct TypeConverter<half> {
  using Type = half2;
};

template <>
struct TypeConverter<__nv_bfloat162> {
  using Type = __nv_bfloat16;
};

template <>
struct TypeConverter<__nv_bfloat16> {
  using Type = __nv_bfloat162;
};

#define ELTS_PER_THREAD 8

constexpr int CVT_FP6_ELTS_PER_THREAD = 8;
constexpr int CVT_FP6_SF_VEC_SIZE = 32;

struct uint8x6_t {
  uint8_t elts[6];
};

// Convert 4 float2 values into 8 e3m2 values (represented as one uint8x6_t).
inline __device__ uint8x6_t fp32_vec_to_e3m2(float2 (&array)[4]) {
  uint64_t val;
  asm volatile(
      "{\n"
      ".reg .b16 pack0;\n"
      ".reg .b16 pack1;\n"
      ".reg .b16 pack2;\n"
      ".reg .b16 pack3;\n"
      "cvt.rn.satfinite.e3m2x2.f32 pack0, %2, %1;\n"
      "cvt.rn.satfinite.e3m2x2.f32 pack1, %4, %3;\n"
      "cvt.rn.satfinite.e3m2x2.f32 pack2, %6, %5;\n"
      "cvt.rn.satfinite.e3m2x2.f32 pack3, %8, %7;\n"
      "mov.b64 %0, {pack0, pack1, pack2, pack3};\n"
      "}"
      : "=l"(val)
      : "f"(array[0].x), "f"(array[0].y), "f"(array[1].x), "f"(array[1].y),
        "f"(array[2].x), "f"(array[2].y), "f"(array[3].x), "f"(array[3].y));

  uint8x6_t result;
  // Pack 8 uint8_t (one 6-bit code per byte of `val`) into 6 uint8_t.
  // Here is how to pack: 4 fp6 values a b c d, with bits a:[a5 a4 a3 a2 a1 a0], b..., c..., d...,
  // go into 3 uint8 values pack0 pack1 pack2:
  // packed0: [b1 b0][a5 a4 a3 a2 a1 a0]
  // packed1: [c3 c2 c1 c0][b5 b4 b3 b2]
  // packed2: [d5 d4 d3 d2 d1 d0][c5 c4]

  // lower 4 uint8_t
  uint8_t l_val_0 = val & 0xFF;
  uint8_t l_val_1 = (val >> 8) & 0xFF;
  uint8_t l_val_2 = (val >> 16) & 0xFF;
  uint8_t l_val_3 = (val >> 24) & 0xFF;

  // higher 4 uint8_t
  uint8_t h_val_0 = (val >> 32) & 0xFF;
  uint8_t h_val_1 = (val >> 40) & 0xFF;
  uint8_t h_val_2 = (val >> 48) & 0xFF;
  uint8_t h_val_3 = (val >> 56) & 0xFF;

  // pack result
  result.elts[0] = (l_val_1 << 6) | l_val_0;
  result.elts[1] = (l_val_2 << 4) | (l_val_1 >> 2);
  result.elts[2] = (l_val_3 << 2) | (l_val_2 >> 4);
  result.elts[3] = (h_val_1 << 6) | h_val_0;
  result.elts[4] = (h_val_2 << 4) | (h_val_1 >> 2);
  result.elts[5] = (h_val_3 << 2) | (h_val_2 >> 4);

  return result;
}

// Fast reciprocal.
inline __device__ float reciprocal_approximate_ftz(float a) {
  float b;
  asm volatile("rcp.approx.ftz.f32 %0, %1;\n" : "=f"(b) : "f"(a));
  return b;
}

template <class SFType, int CVT_FP6_NUM_THREADS_PER_SF>
__device__ uint8_t* get_sf_out_address(int rowIdx, int colIdx, int numCols, SFType* SFout) {
  // #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  static_assert(CVT_FP6_NUM_THREADS_PER_SF == 4);

  // One of 4 threads writes one SF to global memory.
  // TODO: stage through smem for packed STG.32
  // is it better than STG.8 from 4 threads ?
  if (threadIdx.x % CVT_FP6_NUM_THREADS_PER_SF == 0) {
    // SF vector index (32 elements share one SF in the K dimension).
    int32_t kIdx = colIdx / CVT_FP6_NUM_THREADS_PER_SF;
    int32_t mIdx = rowIdx;

    // SF layout [numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4 (kTile)]
    // --> index [mTileIdx, kTileIdx, outerMIdx, innerMIdx, innerKIdx]
    int32_t mTileIdx = mIdx / (32 * 4);
    // SF vector size 32.
    int factor = CVT_FP6_SF_VEC_SIZE * 4;
    int32_t numKTiles = (numCols + factor - 1) / factor;
    int64_t mTileStride = numKTiles * 32 * 4 * 4;

    int32_t kTileIdx = (kIdx / 4);
    int64_t kTileStride = 32 * 4 * 4;

    // M tile layout [32, 4] is column-major.
    int32_t outerMIdx = (mIdx % 32);  // same as (mIdx % 128) % 32
    int64_t outerMStride = 4 * 4;

    int32_t innerMIdx = (mIdx % (32 * 4)) / 32;
    int64_t innerMStride = 4;

    int32_t innerKIdx = (kIdx % 4);
    int64_t innerKStride = 1;

    // Compute the global offset.
    int64_t SFOffset = mTileIdx * mTileStride + kTileIdx * kTileStride + outerMIdx * outerMStride +
                       innerMIdx * innerMStride + innerKIdx * innerKStride;

    return reinterpret_cast<uint8_t*>(SFout) + SFOffset;
  } else {
    // Other threads do not write to SFout.
    return nullptr;
  }
}

// Define a 16 bytes packed data type.
template <class Type>
struct PackedVec {
  typename TypeConverter<Type>::Type elts[4];
};

// template <>
// struct PackedVec<__nv_fp8_e4m3> {
//   __nv_fp8x2_e4m3 elts[8];
// };

template <class Type>  // Type can be half or bfloat16
__device__ uint8x6_t cvt_warp_fp16_to_fp6(PackedVec<Type>& vec, uint8_t* SFout) {
  // #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  // Get absolute maximum values among the local 8 values.
  auto localMax = __habs2(vec.elts[0]);

  // Local maximum value.
#pragma unroll
  for (int i = 1; i < CVT_FP6_ELTS_PER_THREAD / 2; i++) {
    localMax = __hmax2(localMax, __habs2(vec.elts[i]));
  }

  // Get the absolute maximum among all 32 values (four threads).
  localMax = __hmax2(__shfl_xor_sync(uint32_t(-1), localMax, 1), localMax);
  localMax = __hmax2(__shfl_xor_sync(uint32_t(-1), localMax, 2), localMax);

  // Get the final absolute maximum values.
  float vecMax = float(__hmax(localMax.x, localMax.y));

  // Get the SF (max value of the vector / max value of e3m2).
  // maximum value of e3m2 = 28.0.
  // TODO: use half as compute data type.
  float SFValue = (vecMax / 28.0f);
  // 8 bits representation of the SF.
  uint8_t fp8SFVal;
  // Write the SF to global memory (STG.8).
  __nv_fp8_e8m0 tmp;
  tmp.__x = __nv_cvt_float_to_e8m0(SFValue, __NV_SATFINITE, cudaRoundPosInf);
  SFValue = static_cast<float>(tmp);
  fp8SFVal = tmp.__x;
  float outputScale = SFValue != 0 ? reciprocal_approximate_ftz(SFValue) : 0.0f;

  if (SFout) {
    // Write the SF to global memory (STG.8).
    *SFout = fp8SFVal;
  }

  // Convert the input to float.
  float2 fp2Vals[CVT_FP6_ELTS_PER_THREAD / 2];

#pragma unroll
  for (int i = 0; i < CVT_FP6_ELTS_PER_THREAD / 2; i++) {
    if constexpr (std::is_same_v<Type, half>) {
      fp2Vals[i] = __half22float2(vec.elts[i]);
    } else {
      fp2Vals[i] = __bfloat1622float2(vec.elts[i]);
    }
    fp2Vals[i].x *= outputScale;
    fp2Vals[i].y *= outputScale;
  }

  // Convert to e3m2 values.
  uint8x6_t e3m2Vec = fp32_vec_to_e3m2(fp2Vals);
  return e3m2Vec;
}

template <class Type>  // Type can be half or bfloat16
__global__ void
// #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
__launch_bounds__(256, 6) cvt_fp16_to_fp6(
    // #else
    // cvt_fp16_to_fp6(
    // #endif
    int32_t numRows, int32_t numCols, Type const* in, uint8x6_t* out, uint32_t* SFout) {
  // #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  using PackedVec = PackedVec<Type>;
  static constexpr int CVT_FP6_NUM_THREADS_PER_SF = (CVT_FP6_SF_VEC_SIZE / CVT_FP6_ELTS_PER_THREAD);
  static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP6_ELTS_PER_THREAD, "Vec size is not matched.");

  // Input tensor row/col loops.
  for (int rowIdx = blockIdx.x; rowIdx < numRows; rowIdx += gridDim.x) {
    for (int colIdx = threadIdx.x; colIdx < numCols / CVT_FP6_ELTS_PER_THREAD; colIdx += blockDim.x) {
      int64_t inOffset = rowIdx * (numCols / CVT_FP6_ELTS_PER_THREAD) + colIdx;
      PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
      // Get the output tensor offset.
      // Same as inOffset because 8 elements (E3M2) are packed into one uint8x6_t.
      int64_t outOffset = inOffset;
      auto& out_pos = out[outOffset];

      auto sf_out = get_sf_out_address<uint32_t, CVT_FP6_NUM_THREADS_PER_SF>(rowIdx, colIdx, numCols, SFout);

      out_pos = cvt_warp_fp16_to_fp6<Type>(in_vec, sf_out);
    }
  }
  // #endif
}

template <typename T>
void invokeFP6Quantization(int m, int n, T const* input, int64_t* output, int32_t* SFOuput,
                           int multiProcessorCount, cudaStream_t stream) {
  // Grid, Block size.
  // Each thread converts 8 values.
  dim3 block(std::min(int(n / ELTS_PER_THREAD), 256));
  // Get number of blocks per SM (assume we can fully utilize the SM).
  int const numBlocksPerSM = 1536 / block.x;
  dim3 grid(std::min(int(m), multiProcessorCount * numBlocksPerSM));

  // Launch the cvt kernel.
  cvt_fp16_to_fp6<T><<<grid, block, 0, stream>>>(
      m, n, input, reinterpret_cast<uint8x6_t*>(output), reinterpret_cast<uint32_t*>(SFOuput));
}

// Instantiate the function.
template void invokeFP6Quantization(int m, int n, half const* input, int64_t* output, int32_t* SFOuput,
                                    int multiProcessorCount, cudaStream_t stream);

template void invokeFP6Quantization(int m, int n, __nv_bfloat16 const* input, int64_t* output, int32_t* SFOuput,
                                    int multiProcessorCount, cudaStream_t stream);

inline int getMultiProcessorCount() {
  static int multi_processor_count = []() {
    int device_id = 0;
    int count = 0;

    // Get the current CUDA device ID
    CHECK_CUDA_SUCCESS(cudaGetDevice(&device_id));

    // Get the number of multiprocessors for the current device
    CHECK_CUDA_SUCCESS(cudaDeviceGetAttribute(&count, cudaDevAttrMultiProcessorCount, device_id));

    return count;  // Initialize the static variable
  }();

  return multi_processor_count;  // Return the cached value on subsequent calls
}

void scaled_fp6_quant_sm120(torch::Tensor& output, torch::Tensor const& input, torch::Tensor& output_sf) {
  int32_t m = input.size(0);
  int32_t n = input.size(1);

  TORCH_CHECK(n % 32 == 0, "The N dimension must be multiple of 32.");

  int multiProcessorCount = getMultiProcessorCount();

  auto sf_out = static_cast<int32_t*>(output_sf.data_ptr());
  auto output_ptr = static_cast<int64_t*>(output.data_ptr());
  at::cuda::CUDAGuard device_guard{(char)input.get_device()};
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream(input.get_device());

  switch (input.scalar_type()) {
    case torch::kHalf: {
      auto input_ptr = reinterpret_cast<half const*>(input.data_ptr());
      invokeFP6Quantization(m, n, input_ptr, output_ptr, sf_out, multiProcessorCount, stream);
      break;
    }
    case torch::kBFloat16: {
      auto input_ptr = reinterpret_cast<__nv_bfloat16 const*>(input.data_ptr());
      invokeFP6Quantization(m, n, input_ptr, output_ptr, sf_out, multiProcessorCount, stream);
      break;
    }
    default: {
      std::cerr << "Observing: " << input.scalar_type() << " for the input datatype which is invalid";
      throw std::runtime_error("Unsupported input data type for quantize_to_fp6.");
    }
  }
}
```
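The 8-to-6-byte repacking at the end of fp32_vec_to_e3m2 can be sanity-checked on the host. The snippet below is my own illustration and is not part of the commit; it skips the e3m2 conversion itself and only reproduces the bit packing, with `codes` standing in for eight already-converted 6-bit values.

```python
# Host-side sketch of the 8 -> 6 byte repacking used by fp32_vec_to_e3m2.
codes = [0b000001, 0b000010, 0b000011, 0b000100, 0b000101, 0b000110, 0b000111, 0b001000]

def pack4(a, b, c, d):
    # packed0: [b1 b0][a5..a0], packed1: [c3..c0][b5..b2], packed2: [d5..d0][c5 c4]
    return [((b << 6) | a) & 0xFF, ((c << 4) | (b >> 2)) & 0xFF, ((d << 2) | (c >> 4)) & 0xFF]

packed = pack4(*codes[:4]) + pack4(*codes[4:])
print([f"{byte:08b}" for byte in packed])  # six bytes carrying the eight 6-bit codes
```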
lightx2v_kernel/csrc/gemm/mxfp8_quant_kernels_sm120.cu

```diff
@@ -287,7 +287,7 @@ void scaled_fp8_quant_sm120(
   int32_t m = input.size(0);
   int32_t n = input.size(1);

-  TORCH_CHECK(n % 32 == 0, "The N dimension must be multiple of 16.");
+  TORCH_CHECK(n % 32 == 0, "The N dimension must be multiple of 32.");

   int multiProcessorCount = getMultiProcessorCount();
```
lightx2v_kernel/include/lightx2v_kernel_ops.h

```diff
@@ -60,6 +60,8 @@ void scaled_fp4_quant_sm120(
 void scaled_fp8_quant_sm120(torch::Tensor& output, torch::Tensor const& input, torch::Tensor& output_sf);

+void scaled_fp6_quant_sm120(torch::Tensor& output, torch::Tensor const& input, torch::Tensor& output_sf);
+
 void cutlass_scaled_mxfp6_mxfp8_mm_sm120(torch::Tensor& D,
 ...
```
lightx2v_kernel/python/lightx2v_kernel/gemm.py

```diff
@@ -48,13 +48,26 @@ def scaled_fp4_quant(input: torch.Tensor, input_global_scale: torch.Tensor):
     # rounded_m = ((m + 128 - 1) // 128) * 128
     # scale_n = n // block_size
     # rounded_n = ((scale_n + 4 - 1) // 4) * 4
-    output_scale = torch.empty((((m + 128 - 1) // 128) * 128, (n // block_size + 4 - 1) // 4), device=device, dtype=torch.int32)
+    output_scale = torch.zeros((((m + 128 - 1) // 128) * 128, (n // block_size + 4 - 1) // 4), device=device, dtype=torch.int32)

     torch.ops.lightx2v_kernel.scaled_fp4_quant_sm120.default(output, input, output_scale, input_global_scale)
     output_scale = output_scale.view(torch.float8_e4m3fn)
     return output, output_scale


+def scaled_fp6_quant(input: torch.Tensor):
+    m, n = input.shape
+    block_size = 32
+    device = input.device
+
+    output = torch.empty((m, 3 * n // 4), device=device, dtype=torch.uint8)
+    output_scale = torch.zeros(((m + 128 - 1) // 128 * 128, (n // block_size + 4 - 1) // 4), device=device, dtype=torch.int32)
+
+    torch.ops.lightx2v_kernel.scaled_fp6_quant_sm120.default(output, input, output_scale)
+    output_scale = output_scale.view(torch.float8_e8m0fnu)
+    return output, output_scale
+
+
 def scaled_fp8_quant(input: torch.Tensor):
     m, n = input.shape
     block_size = 32
 ...
```
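For reference, the shape bookkeeping in scaled_fp6_quant works out as follows. This is a small worked example of my own; the numbers are illustrative and not part of the diff.

```python
# Shapes produced by scaled_fp6_quant for an (m, n) = (512, 5120) bfloat16 input.
m, n, block_size = 512, 5120, 32

packed_cols = 3 * n // 4                         # 3840: every 8 fp6 values pack into 6 bytes
scale_rows = (m + 128 - 1) // 128 * 128          # 512: rows padded up to a multiple of 128
scale_cols_i32 = (n // block_size + 4 - 1) // 4  # 40 int32 columns of scale storage

print((m, packed_cols))                  # (512, 3840) uint8 output
print((scale_rows, 4 * scale_cols_i32))  # (512, 160) float8_e8m0fnu scales after .view()
```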
lightx2v_kernel/test/mxfp6_mxfp8/test_bench.py (new file, 0 → 100644)

```python
import torch
from lightx2v_kernel.gemm import scaled_fp8_quant, scaled_fp6_quant, cutlass_scaled_mxfp6_mxfp8_mm
import time


class MMWeightMxfp8ActMxfp6:
    def __init__(self, weight, bias):
        self.load_fp6_weight(weight, bias)
        self.act_quant_func = self.act_quant_fp8
        self.set_alpha()

    @torch.no_grad()
    def apply(self, input_tensor):
        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
        output_tensor = cutlass_scaled_mxfp6_mxfp8_mm(input_tensor_quant, self.weight, input_tensor_scale, self.weight_scale, alpha=self.alpha, bias=self.bias)
        return output_tensor

    @torch.no_grad()
    def load_fp6_weight(self, weight, bias):
        self.weight, self.weight_scale = scaled_fp6_quant(weight)
        self.bias = bias

    def set_alpha(self):
        self.alpha = torch.tensor(1.0, dtype=torch.float32, device=self.weight.device)

    @torch.no_grad()
    def act_quant_fp8(self, x):
        return scaled_fp8_quant(x)


def test_speed(m, k, n):
    with torch.no_grad():
        input_tensor = torch.randn(m, k, dtype=torch.bfloat16).cuda()
        weight = torch.randn(n, k, dtype=torch.bfloat16, device="cuda")
        # bias = torch.randn(1, n, dtype=torch.bfloat16).cuda()
        bias = None

        mm = MMWeightMxfp8ActMxfp6(weight, bias)

        # warmup
        output_tensor = mm.apply(input_tensor)

        torch.cuda.synchronize()
        start_time = time.time()
        for i in range(100):
            output_tensor = mm.apply(input_tensor)
        torch.cuda.synchronize()
        end_time = time.time()
        lightx2v_kernel_time = (end_time - start_time) / 100
        print(f"lightx2v-kernel time: {lightx2v_kernel_time}")

        input_tensor = torch.randn(m, n, dtype=torch.bfloat16).cuda()
        weight = torch.randn(k, n, dtype=torch.bfloat16, device="cuda")
        bias = torch.randn(1, k, dtype=torch.bfloat16).cuda()

        linear = torch.nn.Linear(k, n, bias=False).cuda()
        linear.weight.data = weight
        # linear.bias.data = bias

        # warmup
        ref_output_tensor = linear(input_tensor)

        torch.cuda.synchronize()
        start_time = time.time()
        for i in range(100):
            ref_output_tensor = linear(input_tensor)
        torch.cuda.synchronize()
        end_time = time.time()
        ref_time = (end_time - start_time) / 100
        print(f"ref time: {ref_time}")

        print(f"speedup: {ref_time / lightx2v_kernel_time:.3f}")


def test_accuracy(m, k, n):
    with torch.no_grad():
        input_tensor = torch.randn(m, k, dtype=torch.bfloat16).cuda()
        weight = torch.randn(n, k, dtype=torch.bfloat16, device="cuda")
        # bias = torch.randn(1, n, dtype=torch.bfloat16).cuda()
        bias = None

        linear = torch.nn.Linear(k, n, bias=False).cuda()
        linear.weight.data = weight
        # linear.bias.data = bias
        ref_output_tensor = linear(input_tensor)

        mm = MMWeightMxfp8ActMxfp6(weight, bias)
        output_tensor = mm.apply(input_tensor)

        # print(f"ref_output_tensor: {ref_output_tensor}")
        # print(f"output_tensor: {output_tensor}")

        # cosine similarity
        cos = torch.nn.functional.cosine_similarity(ref_output_tensor.flatten(), output_tensor.flatten(), dim=0)
        print(f"cos : {cos}")


if __name__ == "__main__":
    test_sizes = [
        (32130, 5120, 5120),
        (512, 5120, 5120),
        (257, 5120, 5120),
        (32130, 5120, 13824),
        (32130, 13824, 5120),
        (75348, 5120, 5120),
        (75348, 13824, 5120),
        (32760, 1536, 1536),
        (512, 1536, 1536),
        (32760, 1536, 8960),
        (32760, 8960, 1536),
    ]

    for i, (m, k, n) in enumerate(test_sizes):
        print("-" * 30)
        print(f"Test {i + 1}: tensor size ({m}, {k}, {n})")
        test_accuracy(m, k, n)
        test_speed(m, k, n)
```
lightx2v_kernel/test/mxfp6_mxfp8/test_fake_quant.py (new file, 0 → 100644)

```python
import torch
from torchao.prototype.mx_formats.constants import DTYPE_FP6_E3M2
from torchao.prototype.mx_formats.mx_tensor import to_mx, pack_uint6


def quant2mxfp8(x: torch.Tensor):
    block_size = 32
    m, _ = x.shape
    scale, output = to_mx(x, torch.float8_e4m3fn, block_size=block_size)
    return scale.reshape(m, -1), output


def quant2mxfp6(x: torch.Tensor):
    block_size = 32
    m, _ = x.shape
    scale, output = to_mx(x, DTYPE_FP6_E3M2, block_size=block_size, pack_fp6=False)
    return scale.reshape(m, -1), output


def scale_pad_and_swizzle(scale: torch.Tensor):
    m, s = scale.shape

    # pad m up to a multiple of 128, s up to a multiple of 4
    padded_m = (m + 127) // 128 * 128
    padded_s = (s + 3) // 4 * 4

    padded_scale = torch.empty(padded_m, padded_s, device=scale.device, dtype=scale.dtype)
    padded_scale[:m, :s] = scale

    # swizzle the padded scale
    swizzled_scale = padded_scale.reshape(padded_m // 128, 128, padded_s // 4, 4).reshape(padded_m // 128, 4, 32, padded_s // 4, 4).permute(0, 3, 2, 1, 4)
    return swizzled_scale.reshape(padded_m, padded_s)


###############################################################
# Packing kernel and func
###############################################################

import triton  # noqa: E402
import triton.language as tl  # noqa: E402


@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE_IN": 2}, num_warps=1),
        triton.Config({"BLOCK_SIZE_IN": 4}, num_warps=1),
        triton.Config({"BLOCK_SIZE_IN": 8}, num_warps=1),
        triton.Config({"BLOCK_SIZE_IN": 16}, num_warps=1),
    ],
    key=["n_mx_blocks"],
)
@triton.jit
def triton_pack_uint6_kernel(
    input_ptr,
    output_ptr,
    n_mx_blocks,
    MX_BLOCK_SIZE: tl.constexpr,
    PACKED_MX_BLOCK_SIZE: tl.constexpr,
    BLOCK_SIZE_IN: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE_IN

    # input_ptr is shape [n_mx_blocks, MX_BLOCK_SIZE]
    # Load BLOCK_SIZE rows of input_ptr
    offsets_rows = block_start + tl.arange(0, BLOCK_SIZE_IN)
    offsets_cols = tl.arange(0, MX_BLOCK_SIZE // 4)
    offsets = offsets_rows[:, None] * MX_BLOCK_SIZE + (4 * offsets_cols[None, :])
    mask = (offsets_rows[:, None] < n_mx_blocks) & (offsets_cols[None, :] < MX_BLOCK_SIZE // 4)

    # x is shape [BLOCK_SIZE, MX_BLOCK_SIZE]
    x_0 = tl.load(input_ptr + offsets, mask=mask)
    x_1 = tl.load(input_ptr + offsets + 1, mask=mask)
    x_2 = tl.load(input_ptr + offsets + 2, mask=mask)
    x_3 = tl.load(input_ptr + offsets + 3, mask=mask)

    # 4 fp6 values a b c d, with bits a:[a5 a4 a3 a2 a1 a0], b..., c..., d...,
    # are packed into 3 uint8 values pack0 pack1 pack2 in the layout cutlass expects:
    # packed0: [b1 b0][a5 a4 a3 a2 a1 a0]
    # packed1: [c3 c2 c1 c0][b5 b4 b3 b2]
    # packed2: [d5 d4 d3 d2 d1 d0][c5 c4]
    bits_packed0 = (x_1 << 6) | x_0
    bits_packed1 = (x_2 << 4) | (x_1 >> 2)
    bits_packed2 = (x_3 << 2) | (x_2 >> 4)

    # Store values in a uint8 tensor of length `3 * MX_BLOCK_SIZE / 4`
    offsets_out_4_a = offsets_rows[:, None] * PACKED_MX_BLOCK_SIZE + 3 * offsets_cols[None, :]
    offsets_out_4_b = offsets_rows[:, None] * PACKED_MX_BLOCK_SIZE + 3 * offsets_cols[None, :] + 1
    offsets_out_2 = offsets_rows[:, None] * PACKED_MX_BLOCK_SIZE + 3 * offsets_cols[None, :] + 2

    # Store into output tensor
    tl.store(
        output_ptr + offsets_out_4_a,
        bits_packed0,
        mask=mask,
    )
    tl.store(
        output_ptr + offsets_out_4_b,
        bits_packed1,
        mask=mask,
    )
    tl.store(
        output_ptr + offsets_out_2,
        bits_packed2,
        mask=mask,
    )


def pack_uint6(uint8_data: torch.Tensor) -> torch.Tensor:
    # ensure input data is contiguous before passing to kernel
    assert uint8_data.is_contiguous()

    # tensor should already be of shape [..., mx_block_size]
    mx_block_size = uint8_data.shape[-1]
    assert mx_block_size % 4 == 0

    # effective mx block size since we're packing 4 fp6 values into 3 uint8s
    packed_mx_block_size = 3 * mx_block_size // 4
    packed_shape = [uint8_data.shape[0], packed_mx_block_size]
    n_mx_blocks = uint8_data.numel() // mx_block_size

    grid = lambda meta: (triton.cdiv(n_mx_blocks, meta["BLOCK_SIZE_IN"]),)  # noqa: E731

    # contiguous uint8 container in which we can store the packed tensor
    packed_uint8_data = torch.empty(packed_shape, dtype=torch.uint8, device=uint8_data.device)

    triton_pack_uint6_kernel[grid](
        uint8_data,
        packed_uint8_data,
        n_mx_blocks,
        MX_BLOCK_SIZE=mx_block_size,
        PACKED_MX_BLOCK_SIZE=packed_mx_block_size,
    )

    return packed_uint8_data


M = [257, 512, 1024, 13325, 32130, 32760]  # , 75348
N = [1536, 5120, 8960]  # , 13824
K = [128, 256, 512, 1024, 2048, 4096]  # , 13824

for m in M:
    for n in N:
        for k in K:
            x = torch.randn(m, k, device="cuda", dtype=torch.bfloat16)
            w = torch.randn(n, k, device="cuda", dtype=torch.bfloat16)

            # execute quant
            x_scale, x_quant = quant2mxfp8(x)
            w_scale, w_quant = quant2mxfp6(w)

            # pack fp6 for cutlass
            w_quant_packed = pack_uint6(w_quant.reshape(-1, 32))

            # pad and swizzle scale
            padded_and_swizzled_x_scale = scale_pad_and_swizzle(x_scale)
            padded_and_swizzled_w_scale = scale_pad_and_swizzle(w_scale)

            # ref mm result
            ref_mm = torch.nn.functional.linear(x, w).to(torch.bfloat16)

            # custom scaled mm
            from lightx2v_kernel.gemm import cutlass_scaled_mxfp6_mxfp8_mm

            alpha = torch.tensor(1.0, device="cuda", dtype=torch.float32)
            bias = None
            x_quant = x_quant.reshape(m, k).view(torch.uint8)
            w_quant_packed = w_quant_packed.reshape(n, 3 * k // 4)
            custom_mm = cutlass_scaled_mxfp6_mxfp8_mm(x_quant, w_quant_packed, padded_and_swizzled_x_scale, padded_and_swizzled_w_scale, alpha, bias)

            # compute error (SNR-style metric)
            from lightx2v_kernel.utils import error

            print(f"m: {m}, n: {n}, k: {k}, error: {error(ref_mm, custom_mm)}")

            # compute cosine similarity
            cos_sim = torch.nn.functional.cosine_similarity(ref_mm.flatten(), custom_mm.flatten(), dim=0)
            print(f"m: {m}, n: {n}, k: {k}, cos_sim: {cos_sim}")
```
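A quick shape check of scale_pad_and_swizzle (my own sketch, not part of the commit): the swizzle preserves the padded shape while reordering scale factors into the [mTiles, kTiles, 32, 4, 4] layout described in the SM120 kernel comments.

```python
import torch

scale = torch.arange(257 * 5, dtype=torch.float32).reshape(257, 5)  # m=257 rows, s=5 scale columns
m, s = scale.shape
padded_m = (m + 127) // 128 * 128   # 384
padded_s = (s + 3) // 4 * 4         # 8

padded = torch.zeros(padded_m, padded_s)
padded[:m, :s] = scale
swizzled = padded.reshape(padded_m // 128, 128, padded_s // 4, 4).reshape(padded_m // 128, 4, 32, padded_s // 4, 4).permute(0, 3, 2, 1, 4)
print(swizzled.shape)                               # torch.Size([3, 2, 32, 4, 4])
print(swizzled.reshape(padded_m, padded_s).shape)   # torch.Size([384, 8]) -- same padded shape, reordered
```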
lightx2v_kernel/test/mxfp6_mxfp8/test_mm_tflops.py (new file, 0 → 100644)

```python
import torch
from lightx2v_kernel.gemm import cutlass_scaled_mxfp6_mxfp8_mm

"""
input_shape = (1024, 2048)
weight_shape = (4096, 2048)

input_tensor_quant = (torch.rand((1024, 1024), device="cuda") * 10).to(torch.uint8)
weight = (torch.rand((4096, 1024), device="cuda") * 10).to(torch.uint8)
input_tensor_scale = torch.rand(1024, 128, device="cuda").to(torch.float8_e8m0fnu)
weight_scale = torch.rand(4096, 128, device="cuda").to(torch.float8_e8m0fnu)
alpha = torch.tensor(1.0, device="cuda").to(torch.float32)
bias = None
"""


def test_mm(input_tensor_quant, weight, input_tensor_scale, weight_scale, alpha, bias):
    output_tensor = cutlass_scaled_mxfp6_mxfp8_mm(input_tensor_quant, weight, input_tensor_scale, weight_scale, alpha=alpha, bias=bias)
    return output_tensor


def test_tflops(input_shape, weight_shape, num_warmup=10, num_runs=100):
    """
    Measure the TFLOPS performance of test_mm.
    """
    # Create the input data
    input_tensor_quant = (torch.rand((input_shape[0], input_shape[1]), device="cuda") * 10).to(torch.uint8)
    weight = (torch.rand((weight_shape[0], 3 * weight_shape[1] // 4), device="cuda") * 10).to(torch.uint8)
    input_tensor_scale = torch.rand(((input_shape[0] + 128 - 1) // 128) * 128, (input_shape[1] // 32 + 4 - 1) // 4 * 4, device="cuda").to(torch.float8_e8m0fnu)
    weight_scale = torch.rand(weight_shape[0], weight_shape[1] // 32, device="cuda").to(torch.float8_e8m0fnu)
    alpha = torch.tensor(1.0, device="cuda", dtype=torch.float32)
    bias = None

    # Warm up the GPU
    for _ in range(num_warmup):
        test_mm(input_tensor_quant, weight, input_tensor_scale, weight_scale, alpha, bias)

    # Synchronize the GPU
    torch.cuda.synchronize()

    # Create CUDA events for precise timing
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    # Time the runs
    start_event.record()
    for _ in range(num_runs):
        result = test_mm(input_tensor_quant, weight, input_tensor_scale, weight_scale, alpha, bias)
    end_event.record()

    # Synchronize and compute the elapsed time
    torch.cuda.synchronize()
    elapsed_time_ms = start_event.elapsed_time(end_event)
    elapsed_time_s = elapsed_time_ms / 1000.0

    # Compute FLOPs
    # Matrix multiplication A(M x K) @ B(K x N) = C(M x N)
    # M = batch_size, K = input_dim, N = output_dim
    M = input_shape[0]
    K = input_shape[1]
    N = weight_shape[0]

    # FLOPs per matmul = 2 * M * N * K (each output element needs K multiplies and K adds)
    flops_per_run = 2 * M * N * K
    total_flops = flops_per_run * num_runs

    # TFLOPS (trillions of floating-point operations per second)
    tflops = total_flops / (elapsed_time_s * 1e12)

    print("Results:")
    print(f"  input shape: {input_shape} (M={M}, K={K})")
    print(f"  weight shape: {weight_shape} (N={N}, K={K})")
    print(f"  output shape: ({M}, {N})")
    print(f"  runs: {num_runs}")
    print(f"  total time: {elapsed_time_ms:.2f} ms")
    print(f"  average time per run: {elapsed_time_ms / num_runs:.4f} ms")
    print(f"  FLOPs per run: {flops_per_run / 1e9:.2f} GFLOPs")
    print(f"  total FLOPs: {total_flops / 1e12:.2f} TFLOPs")
    print(f"  compute throughput: {tflops:.2f} TFLOPS")

    return tflops


if __name__ == "__main__":
    # Test matmuls of different sizes
    # (m, k) (n, k)
    test_cases = [
        ((32130, 5120), (5120, 5120)),
        ((512, 1536), (1536, 1536)),
        ((512, 5120), (5120, 5120)),
        ((257, 5120), (5120, 5120)),
        ((32130, 5120), (13824, 5120)),
        ((32130, 13824), (5120, 13824)),
        ((75348, 5120), (5120, 5120)),
        ((75348, 5120), (13824, 5120)),
        ((75348, 13824), (5120, 13824)),
        ((32760, 1536), (1536, 1536)),
        ((512, 1536), (1536, 1536)),
        ((32760, 1536), (8960, 1536)),
        ((32760, 8960), (1536, 8960)),
    ]

    print("=== test_mm TFLOPS benchmark ===\n")

    for i, (input_shape, weight_shape) in enumerate(test_cases):
        print(f"Test {i + 1}: input shape {input_shape}, weight shape {weight_shape}")
        print("-" * 60)
        tflops = test_tflops(input_shape, weight_shape)
        print(f"✓ Test completed, performance: {tflops:.2f} TFLOPS\n")

    print("=== Benchmark complete ===")
```
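As a concrete instance of the FLOP counting in test_tflops, here is my own worked example for the first test case; the 4 ms per-run time is purely hypothetical and only illustrates the arithmetic.

```python
# Worked example for input (32130, 5120) x weight (5120, 5120); 4 ms/run is hypothetical.
M, K, N = 32130, 5120, 5120
flops_per_run = 2 * M * N * K                            # 1,684,537,344,000 FLOPs ≈ 1.68 TFLOP per matmul
hypothetical_s_per_run = 4e-3
print(flops_per_run / hypothetical_s_per_run / 1e12)     # ≈ 421 TFLOPS at 4 ms per run
```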
lightx2v_kernel/test/mxfp6_mxfp8/test_mxfp6_quant.py (new file, 0 → 100644)

```python
import unittest

import torch
from lightx2v_kernel.gemm import cutlass_scaled_mxfp6_mxfp8_mm
from lightx2v_kernel.gemm import scaled_fp6_quant, scaled_fp8_quant
from torch.nn.functional import linear
from lightx2v_kernel.utils import error, benchmark


class TestQuantBF162MXFP6(unittest.TestCase):
    def setUp(self):
        self.tokens = [128, 257, 512, 1024, 13325, 32130, 32760]  # , 75348
        self.channels = [128, 1536, 5120, 8960]  # , 13824
        self.hiddenDims = [128, 1536, 3072, 5120, 8960, 12800]  # , 13824
        self.device = "cuda"
        self.dtype = torch.bfloat16

    def test_accuracy(self):
        """Test the accuracy of quantization from BF16 to MXFP6."""
        for m in self.tokens:
            for k in self.hiddenDims:
                for n in self.channels:
                    with self.subTest(shape=[m, k, n]):
                        activation = torch.randn(m, k, dtype=self.dtype, device=self.device)
                        activation_quant_pred, activation_scale_pred = scaled_fp8_quant(activation)

                        weight = torch.randn(n, k, dtype=self.dtype, device=self.device)
                        weight_quant_pred, weight_scale_pred = scaled_fp6_quant(weight)

                        alpha = torch.tensor(1.0, device=self.device, dtype=torch.float32)

                        mm_pred = cutlass_scaled_mxfp6_mxfp8_mm(activation_quant_pred, weight_quant_pred, activation_scale_pred, weight_scale_pred, alpha=alpha)

                        mm_real = linear(activation, weight, bias=None).to(torch.bfloat16)

                        self.assertTrue(
                            error(mm_pred, mm_real) < 1e-2,
                            f"Accuracy test failed for shape {m, k, n}: Error {error(mm_pred, mm_real)} exceeds threshold.",
                        )

    def test_performance(self):
        """Benchmark the performance of Activation quantization from BF16 to MXFP6."""
        for m in self.tokens:
            for k in self.hiddenDims:
                with self.subTest(shape=[m, k]):
                    input = torch.randn(m, k, dtype=self.dtype, device=self.device)
                    shape = [m, k]
                    tflops = 2 * (m * k / 1024**4)
                    benchmark(scaled_fp6_quant, shape, tflops, 100, input)


if __name__ == "__main__":
    unittest.main()
```
lightx2v_kernel/test/mxfp6_mxfp8/test_quant_mem_utils.py (new file, 0 → 100644)

```python
import torch
from lightx2v_kernel.gemm import scaled_fp6_quant


def quantize_fp6(x):
    return scaled_fp6_quant(x)


def test_memory_bandwidth(func, x, num_warmup=10, num_runs=100):
    """
    Measure the effective memory bandwidth of `func`.
    """
    # Warm up the GPU
    for _ in range(num_warmup):
        func(x)

    # Synchronize the GPU
    torch.cuda.synchronize()

    # Create CUDA events for precise timing
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    # Time the runs
    start_event.record()
    for _ in range(num_runs):
        result = func(x)
    end_event.record()

    # Synchronize and compute the elapsed time
    torch.cuda.synchronize()
    elapsed_time_ms = start_event.elapsed_time(end_event)
    elapsed_time_s = elapsed_time_ms / 1000.0

    # Data volume
    input_bytes = x.numel() * x.element_size()  # bytes read from the input
    # After FP6 quantization each element takes 3/4 of a byte
    output_bytes = x.numel() * (3 / 4)  # bytes written for the FP6 output
    scale_bytes = x.numel() / 32  # group_size = 32

    # Total data moved (read input + write output + write scale)
    total_bytes = (input_bytes + output_bytes + scale_bytes) * num_runs

    # Compute bandwidth
    bandwidth_gbps = (total_bytes / elapsed_time_s) / (1024**3)  # GB/s

    print("Results:")
    print(f"  input tensor shape: {x.shape}")
    print(f"  input dtype: {x.dtype}")
    print(f"  runs: {num_runs}")
    print(f"  total time: {elapsed_time_ms:.2f} ms")
    print(f"  average time per run: {elapsed_time_ms / num_runs:.4f} ms")
    print(f"  input size: {input_bytes / (1024**2):.2f} MB")
    print(f"  output size: {output_bytes / (1024**2):.2f} MB")
    print(f"  total data moved: {total_bytes / (1024**3):.2f} GB")
    print(f"  memory bandwidth: {bandwidth_gbps:.2f} GB/s")

    return bandwidth_gbps


if __name__ == "__main__":
    # Tensor sizes to test
    test_sizes = [
        # (1, 1024),
        # (1, 2048),
        # (1, 4096),
        # (1, 8192),
        # (1, 16384),
        # (1, 32768),
        # (2, 1024),
        # (2, 2048),
        # (2, 4096),
        # (2, 8192),
        # (2, 16384),
        # (2, 32768),
        # (4, 1024),
        # (4, 2048),
        # (4, 4096),
        # (4, 8192),
        # (4, 16384),
        # (4, 32768),
        # (128, 1024),
        # (128, 2048),
        # (128, 4096),
        # (128, 8192),
        # (128, 16384),
        # (128, 32768),
        # (512, 1024),
        # (512, 2048),
        # (512, 4096),
        # (512, 8192),
        # (512, 16384),
        # (512, 32768),
        # (1024, 1024),
        # (1024, 2048),
        # (1024, 4096),
        # (1024, 8192),
        # (1024, 16384),
        # (1024, 32768),
        # (2048, 1024),
        # (2048, 2048),
        # (2048, 4096),
        # (2048, 8192),
        # (2048, 16384),
        # (2048, 32768),
        # (4096, 1024),
        # (4096, 2048),
        # (4096, 4096),
        # (4096, 8192),
        # (4096, 16384),
        # (4096, 32768),
        # (8192, 1024),
        # (8192, 2048),
        # (8192, 4096),
        # (8192, 8192),
        # (8192, 16384),
        # (8192, 32768),
        # (16384, 1024),
        # (16384, 2048),
        # (16384, 4096),
        # (16384, 8192),
        # (16384, 16384),
        # (16384, 32768),
        # (32768, 1024),
        # (32768, 2048),
        # (32768, 4096),
        # (32768, 8192),
        # (32768, 16384),
        # (32768, 32768),
        (32130, 5120),
        (512, 5120),
        (257, 5120),
        (32130, 13824),
        (75348, 5120),
        (75348, 13824),
        (32760, 1536),
        (512, 3072),
        (512, 1536),
        (32760, 8960),
    ]

    print("=== quantize_fp6 memory bandwidth test ===\n")

    for i, (h, w) in enumerate(test_sizes):
        print(f"Test {i + 1}: tensor size ({h}, {w})")
        print("-" * 50)

        x = torch.randn(h, w, dtype=torch.bfloat16).cuda()

        try:
            bandwidth = test_memory_bandwidth(quantize_fp6, x)
            print(f"✓ Test completed, bandwidth: {bandwidth:.2f} GB/s\n")
        except Exception as e:
            print(f"✗ Test failed: {e}\n")

    print("=== Test complete ===")
```
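To make the bandwidth formula above concrete, here is my own worked example for the first test size; the 0.5 ms per-run time is hypothetical and only illustrates the arithmetic.

```python
# Worked example for a (32130, 5120) bfloat16 input (hypothetical 0.5 ms per run).
m, n = 32130, 5120
input_bytes = m * n * 2          # 329,011,200 B read  (~313.8 MB at 1024**2 per MB)
output_bytes = m * n * 3 / 4     # 123,379,200 B written (~117.7 MB)
scale_bytes = m * n / 32         #   5,140,800 B written (~4.9 MB)
bytes_per_run = input_bytes + output_bytes + scale_bytes   # ≈ 4.58e8 bytes per run

hypothetical_s_per_run = 0.5e-3
print(bytes_per_run / hypothetical_s_per_run / 1024**3)    # ≈ 852 GB/s
```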