Commit 1e422663 authored by Xtra, committed by GitHub

add mxfp8 quant kernel and some tests (#97)



* add mxfp8 quant kernel and some tests

* Update .gitignore

---------
Co-authored-by: Yang Yong (雍洋) <yongyang1030@163.com>
parent 48b57707
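For orientation, a minimal usage sketch of the two Python wrappers this commit adds (scaled_fp8_quant and cutlass_scaled_mxfp8_mm); it mirrors the unit test included below, with placeholder problem sizes:

```python
import torch
from lightx2v_kernel.gemm import scaled_fp8_quant, cutlass_scaled_mxfp8_mm

# Placeholder sizes: k must be a multiple of 32 (one E8M0 scale per 32 elements),
# n a multiple of 128 (GEMM alignment check).
m, k, n = 512, 5120, 5120
activation = torch.randn(m, k, dtype=torch.bfloat16, device="cuda")
weight = torch.randn(n, k, dtype=torch.bfloat16, device="cuda")

# Quantize both operands to MXFP8 (E4M3 payload + E8M0 block scales).
a_q, a_scale = scaled_fp8_quant(activation)
w_q, w_scale = scaled_fp8_quant(weight)

# Block-scaled GEMM: (m, k) x (n, k)^T -> (m, n) in bfloat16.
alpha = torch.tensor(1.0, dtype=torch.float32, device="cuda")
out = cutlass_scaled_mxfp8_mm(a_q, w_q, a_scale, w_scale, alpha=alpha, bias=None)
```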
@@ -93,7 +93,9 @@ list(APPEND LIGHTX2V_KERNEL_CUDA_FLAGS
set(SOURCES
"csrc/gemm/nvfp4_scaled_mm_kernels_sm120.cu"
"csrc/gemm/nvfp4_quant_kernels_sm120.cu"
"csrc/gemm/mxfp8_quant_kernels_sm120.cu"
"csrc/gemm/mxfp6_mxfp8_scaled_mm_kernels_sm120.cu"
"csrc/gemm/mxfp8_scaled_mm_kernels_sm120.cu"
"csrc/common_extension.cc"
)
@@ -16,11 +16,21 @@ TORCH_LIBRARY_FRAGMENT(lightx2v_kernel, m) {
" Tensor! output_scale, Tensor! input_scale) -> ()");
m.impl("scaled_fp4_quant_sm120", torch::kCUDA, &scaled_fp4_quant_sm120);
m.def(
"scaled_fp8_quant_sm120(Tensor! output, Tensor! input,"
" Tensor! output_scale) -> ()");
m.impl("scaled_fp8_quant_sm120", torch::kCUDA, &scaled_fp8_quant_sm120);
m.def(
"cutlass_scaled_mxfp6_mxfp8_mm_sm120(Tensor! out, Tensor mat_a, Tensor mat_b, Tensor scales_a, Tensor scales_b, Tensor "
"alpha, Tensor? bias) -> ()");
m.impl("cutlass_scaled_mxfp6_mxfp8_mm_sm120", torch::kCUDA, &cutlass_scaled_mxfp6_mxfp8_mm_sm120);
m.def(
"cutlass_scaled_mxfp8_mm_sm120(Tensor! out, Tensor mat_a, Tensor mat_b, Tensor scales_a, Tensor scales_b, Tensor "
"alpha, Tensor? bias) -> ()");
m.impl("cutlass_scaled_mxfp8_mm_sm120", torch::kCUDA, &cutlass_scaled_mxfp8_mm_sm120);
}
REGISTER_EXTENSION(common_ops)
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_fp8.h>
#include <cuda_runtime.h>
#include <cuda_runtime_api.h>
#include <torch/all.h>
#include "utils.h"
// Get type2 from type or vice versa (applied to half and bfloat16)
template <typename T>
struct TypeConverter {
using Type = half2;
}; // keep for generality
template <>
struct TypeConverter<half2> {
using Type = half;
};
template <>
struct TypeConverter<half> {
using Type = half2;
};
template <>
struct TypeConverter<__nv_bfloat162> {
using Type = __nv_bfloat16;
};
template <>
struct TypeConverter<__nv_bfloat16> {
using Type = __nv_bfloat162;
};
#define ELTS_PER_THREAD 8
constexpr int CVT_FP8_ELTS_PER_THREAD = 8;
constexpr int CVT_FP8_SF_VEC_SIZE = 32;
// Convert 4 float2 values into 8 e4m3 values (represented as one uint64_t).
inline __device__ uint64_t fp32_vec_to_e4m3(float2 (&array)[4]) {
uint64_t val;
asm volatile(
"{\n"
".reg .b16 pack0;\n"
".reg .b16 pack1;\n"
".reg .b16 pack2;\n"
".reg .b16 pack3;\n"
"cvt.rn.satfinite.e4m3x2.f32 pack0, %2, %1;\n"
"cvt.rn.satfinite.e4m3x2.f32 pack1, %4, %3;\n"
"cvt.rn.satfinite.e4m3x2.f32 pack2, %6, %5;\n"
"cvt.rn.satfinite.e4m3x2.f32 pack3, %8, %7;\n"
"mov.b64 %0, {pack0, pack1, pack2, pack3};\n"
"}"
: "=l"(val)
: "f"(array[0].x),
"f"(array[0].y),
"f"(array[1].x),
"f"(array[1].y),
"f"(array[2].x),
"f"(array[2].y),
"f"(array[3].x),
"f"(array[3].y));
return val;
}
// Fast reciprocal.
inline __device__ float reciprocal_approximate_ftz(float a) {
float b;
asm volatile("rcp.approx.ftz.f32 %0, %1;\n" : "=f"(b) : "f"(a));
return b;
}
template <class SFType, int CVT_FP8_NUM_THREADS_PER_SF>
__device__ uint8_t* get_sf_out_address(int rowIdx, int colIdx, int numCols, SFType* SFout) {
// #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
static_assert(CVT_FP8_NUM_THREADS_PER_SF == 4);
// one of 4 threads write one SF to global memory.
// TODO: stage through smem for packed STG.32
// is it better than STG.8 from 4 threads ?
if (threadIdx.x % CVT_FP8_NUM_THREADS_PER_SF == 0) {
// SF vector index (32 elements share one SF in the K dimension).
int32_t kIdx = colIdx / CVT_FP8_NUM_THREADS_PER_SF;
int32_t mIdx = rowIdx;
// SF layout [numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4(kTile)]
// --> index [mTileIdx, kTileIdx, outerMIdx, innerMIdx, innerKIdx]
int32_t mTileIdx = mIdx / (32 * 4);
// SF vector size 32.
int factor = CVT_FP8_SF_VEC_SIZE * 4;
int32_t numKTiles = (numCols + factor - 1) / factor;
int64_t mTileStride = numKTiles * 32 * 4 * 4;
int32_t kTileIdx = (kIdx / 4);
int64_t kTileStride = 32 * 4 * 4;
// M tile layout [32, 4] is column-major.
int32_t outerMIdx = (mIdx % 32); // same as (mIdx % 128) % 32
int64_t outerMStride = 4 * 4;
int32_t innerMIdx = (mIdx % (32 * 4)) / 32;
int64_t innerMStride = 4;
int32_t innerKIdx = (kIdx % 4);
int64_t innerKStride = 1;
// Compute the global offset.
int64_t SFOffset = mTileIdx * mTileStride + kTileIdx * kTileStride + outerMIdx * outerMStride +
innerMIdx * innerMStride + innerKIdx * innerKStride;
return reinterpret_cast<uint8_t*>(SFout) + SFOffset;
} else {
// Other threads do not write to SFout.
return nullptr;
}
}
// Define a 16 bytes packed data type.
template <class Type>
struct PackedVec {
typename TypeConverter<Type>::Type elts[4];
};
template <>
struct PackedVec<__nv_fp8_e4m3> {
__nv_fp8x2_e4m3 elts[8];
};
// Quantizes the provided PackedVec into the uint64_t output
template <class Type> // Type can be half or bfloat16
__device__ uint64_t cvt_warp_fp16_to_fp8(PackedVec<Type>& vec, uint8_t* SFout) {
// #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
// Get absolute maximum values among the local 8 values.
auto localMax = __habs2(vec.elts[0]);
// Local maximum value.
#pragma unroll
for (int i = 1; i < CVT_FP8_ELTS_PER_THREAD / 2; i++) {
localMax = __hmax2(localMax, __habs2(vec.elts[i]));
}
// Get the absolute maximum among all 32 values (four threads).
localMax = __hmax2(__shfl_xor_sync(uint32_t(-1), localMax, 1), localMax);
localMax = __hmax2(__shfl_xor_sync(uint32_t(-1), localMax, 2), localMax);
// Get the final absolute maximum values.
float vecMax = float(__hmax(localMax.x, localMax.y));
// Get the SF (max value of the vector / max value of e4m3).
// maximum value of e4m3 = 448.0.
// TODO: use half as compute data type.
float SFValue = (vecMax / 448.0f);
// 8 bits representation of the SF.
uint8_t fp8SFVal;
// Write the SF to global memory (STG.8).
__nv_fp8_e8m0 tmp;
tmp.__x = __nv_cvt_float_to_e8m0(SFValue, __NV_SATFINITE, cudaRoundPosInf);
SFValue = static_cast<float>(tmp);
fp8SFVal = tmp.__x;
float outputScale =
SFValue != 0 ? reciprocal_approximate_ftz(SFValue) : 0.0f;
if (SFout) {
// Write the SF to global memory (STG.8).
*SFout = fp8SFVal;
}
// Convert the input to float.
float2 fp2Vals[CVT_FP8_ELTS_PER_THREAD / 2];
#pragma unroll
for (int i = 0; i < CVT_FP8_ELTS_PER_THREAD / 2; i++) {
if constexpr (std::is_same_v<Type, half>) {
fp2Vals[i] = __half22float2(vec.elts[i]);
} else {
fp2Vals[i] = __bfloat1622float2(vec.elts[i]);
}
fp2Vals[i].x *= outputScale;
fp2Vals[i].y *= outputScale;
}
// Convert to e4m3 values.
uint64_t e4m3Vec = fp32_vec_to_e4m3(fp2Vals);
return e4m3Vec;
}
template <class Type> // Type can be half or bfloat16
__global__ void
// #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
__launch_bounds__(256, 6) cvt_fp16_to_fp8(
// #else
// cvt_fp16_to_fp8(
// #endif
int32_t numRows, int32_t numCols, Type const* in, uint64_t* out, uint32_t* SFout) {
// #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
using PackedVec = PackedVec<Type>;
static constexpr int CVT_FP8_NUM_THREADS_PER_SF = (CVT_FP8_SF_VEC_SIZE / CVT_FP8_ELTS_PER_THREAD);
static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP8_ELTS_PER_THREAD, "Vec size is not matched.");
// Input tensor row/col loops.
for (int rowIdx = blockIdx.x; rowIdx < numRows; rowIdx += gridDim.x) {
for (int colIdx = threadIdx.x; colIdx < numCols / CVT_FP8_ELTS_PER_THREAD; colIdx += blockDim.x) {
int64_t inOffset = rowIdx * (numCols / CVT_FP8_ELTS_PER_THREAD) + colIdx;
PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
// Get the output tensor offset.
// Same as inOffset because 8 elements(E4M3) are packed into one uint64_t.
int64_t outOffset = inOffset;
auto& out_pos = out[outOffset];
auto sf_out =
get_sf_out_address<uint32_t, CVT_FP8_NUM_THREADS_PER_SF>(rowIdx, colIdx, numCols, SFout);
out_pos = cvt_warp_fp16_to_fp8<Type>(in_vec, sf_out);
}
}
// #endif
}
template <typename T>
void invokeFP8Quantization(
int m,
int n,
T const* input,
int64_t* output,
int32_t* SFOuput,
int multiProcessorCount,
cudaStream_t stream) {
// Grid, Block size.
// Each thread converts 8 values.
dim3 block(std::min(int(n / ELTS_PER_THREAD), 256));
// Get number of blocks per SM (assume we can fully utilize the SM).
int const numBlocksPerSM = 1536 / block.x;
dim3 grid(std::min(int(m), multiProcessorCount * numBlocksPerSM));
// Launch the cvt kernel.
cvt_fp16_to_fp8<T>
<<<grid, block, 0, stream>>>(
m, n, input, reinterpret_cast<uint64_t*>(output), reinterpret_cast<uint32_t*>(SFOuput));
}
// Instantiate the function.
template void invokeFP8Quantization(
int m,
int n,
half const* input,
int64_t* output,
int32_t* SFOuput,
int multiProcessorCount,
cudaStream_t stream);
template void invokeFP8Quantization(
int m,
int n,
__nv_bfloat16 const* input,
int64_t* output,
int32_t* SFOuput,
int multiProcessorCount,
cudaStream_t stream);
inline int getMultiProcessorCount() {
static int multi_processor_count = []() {
int device_id = 0;
int count = 0;
// Get the current CUDA device ID
CHECK_CUDA_SUCCESS(cudaGetDevice(&device_id));
// Get the number of multiprocessors for the current device
CHECK_CUDA_SUCCESS(cudaDeviceGetAttribute(&count, cudaDevAttrMultiProcessorCount, device_id));
return count; // Initialize the static variable
}();
return multi_processor_count; // Return the cached value on subsequent calls
}
void scaled_fp8_quant_sm120(
torch::Tensor& output, torch::Tensor const& input, torch::Tensor& output_sf) {
int32_t m = input.size(0);
int32_t n = input.size(1);
TORCH_CHECK(n % 32 == 0, "The N dimension must be a multiple of 32.");
int multiProcessorCount = getMultiProcessorCount();
auto sf_out = static_cast<int32_t*>(output_sf.data_ptr());
auto output_ptr = static_cast<int64_t*>(output.data_ptr());
at::cuda::CUDAGuard device_guard{(char)input.get_device()};
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(input.get_device());
switch (input.scalar_type()) {
case torch::kHalf: {
auto input_ptr = reinterpret_cast<half const*>(input.data_ptr());
invokeFP8Quantization(m, n, input_ptr, output_ptr, sf_out, multiProcessorCount, stream);
break;
}
case torch::kBFloat16: {
auto input_ptr = reinterpret_cast<__nv_bfloat16 const*>(input.data_ptr());
invokeFP8Quantization(m, n, input_ptr, output_ptr, sf_out, multiProcessorCount, stream);
break;
}
default: {
std::cerr << "Observing: " << input.scalar_type() << " for the input datatype which is invalid";
throw std::runtime_error("Unsupported input data type for quantize_to_fp8.");
}
}
}
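For intuition, a hedged PyTorch sketch of the block quantization the kernel above implements: each group of 32 elements along the last dimension shares one E8M0 scale, computed as the block amax divided by 448 (the E4M3 maximum) and rounded up to a power of two; the block is then multiplied by the reciprocal of that scale and cast to E4M3. This is a reference for the math only: the real kernel packs eight E4M3 values per uint64_t, writes scales into the swizzled layout produced by get_sf_out_address, and the float8 dtypes used here require a recent PyTorch.

```python
import torch

def mxfp8_quantize_reference(x: torch.Tensor, block: int = 32):
    # Illustrative reference only; assumes x is 2-D and its last dim is a multiple of 32.
    m, n = x.shape
    xb = x.float().reshape(m, n // block, block)
    amax = xb.abs().amax(dim=-1, keepdim=True)                     # per-block absolute max
    # scale = amax / 448, rounded up to a power of two (E8M0, cudaRoundPosInf semantics)
    exp = torch.ceil(torch.log2(amax.clamp(min=1e-38) / 448.0)).clamp(-127.0, 127.0)
    scale = torch.exp2(exp)
    q = (xb / scale).clamp(-448.0, 448.0).to(torch.float8_e4m3fn)  # quantized payload
    return q.reshape(m, n), scale.squeeze(-1).to(torch.float8_e8m0fnu)
```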
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>
// clang-format off
#include "cutlass/cutlass.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/util/packed_stride.hpp"
// clang-format on
#define CUTLASS_CHECK(status) \
{ \
cutlass::Status error = status; \
TORCH_CHECK(error == cutlass::Status::kSuccess, cutlassGetStatusString(error)); \
}
#define CHECK_TYPE(x, st, m) TORCH_CHECK(x.scalar_type() == st, "Inconsistency of Tensor type:", m)
#define CHECK_TH_CUDA(x, m) TORCH_CHECK(x.is_cuda(), m, "must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x, m) TORCH_CHECK(x.is_contiguous(), m, "must be contiguous")
#define CHECK_INPUT(x, st, m) \
CHECK_TH_CUDA(x, m); \
CHECK_CONTIGUOUS(x, m); \
CHECK_TYPE(x, st, m)
using namespace cute;
struct Mxfp8GemmSm120 {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM kernel configurations
/////////////////////////////////////////////////////////////////////////////////////////////////
// A matrix configuration
using ElementA = cutlass::mx_float8_t<cutlass::float_e4m3_t>; // Element type for A matrix operand
using LayoutATag = cutlass::layout::RowMajor; // Layout type for A matrix operand
static constexpr int AlignmentA = 16; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
// B matrix configuration
using ElementB = cutlass::mx_float8_t<cutlass::float_e4m3_t>; // Element type for B matrix operand
using LayoutBTag = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
static constexpr int AlignmentB = 128; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
// C/D matrix configuration
using ElementD = cutlass::bfloat16_t; // Element type for D matrix operand
using ElementC = cutlass::bfloat16_t; // Element type for C matrix operand
using LayoutCTag = cutlass::layout::RowMajor; // Layout type for C matrix operand
using LayoutDTag = cutlass::layout::RowMajor; // Layout type for D matrix operand
static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value; // Memory access granularity/alignment of D matrix in units of elements (up to 16 bytes)
static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
// Kernel functional config
using ElementAccumulator = float; // Element type for internal accumulation
using ArchTag = cutlass::arch::Sm120; // Tag indicating the minimum SM that supports the intended feature
using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp; // Operator class tag
// Kernel Perf config
using ThreadBlockShape = Shape<_128,_128,_128>; // Threadblock's tile size
using ClusterShape = Shape<_1,_1,_1>; // Shape of the threadblocks in a cluster
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag, OperatorClass,
ThreadBlockShape, ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator, ElementAccumulator,
ElementC, LayoutCTag, AlignmentC,
ElementD, LayoutDTag, AlignmentD,
cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy
>::CollectiveOp;
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass,
ElementA, LayoutATag, AlignmentA,
ElementB, LayoutBTag, AlignmentB,
ElementAccumulator,
ThreadBlockShape, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
cutlass::gemm::collective::KernelScheduleAuto // Kernel schedule policy. Auto defaults to cooperative kernel schedule
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int,int,int,int>, // Indicates ProblemShape
CollectiveMainloop,
CollectiveEpilogue,
void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
// Reference device GEMM implementation type
using StrideA = typename Gemm::GemmKernel::StrideA;
using LayoutA = decltype(cute::make_layout(make_shape(0,0,0), StrideA{}));
using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA; // Scale Factor tensors have an interleaved layout. Bring Layout instead of stride.
using StrideB = typename Gemm::GemmKernel::StrideB;
using LayoutB = decltype(cute::make_layout(make_shape(0,0,0), StrideB{}));
using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB; // Scale Factor tensors have an interleaved layout. Bring Layout instead of stride.
using StrideC = typename Gemm::GemmKernel::StrideC;
using LayoutC = decltype(cute::make_layout(make_shape(0,0,0), StrideC{}));
using StrideD = typename Gemm::GemmKernel::StrideD;
using LayoutD = decltype(cute::make_layout(make_shape(0,0,0), StrideD{}));
};
// Populates a Gemm::Arguments structure from the given commandline options
typename Mxfp8GemmSm120::Gemm::Arguments args_from_options_mxfp8(
at::Tensor& D,
at::Tensor const& A,
at::Tensor const& B,
at::Tensor const& A_sf,
at::Tensor const& B_sf,
at::Tensor const& alpha,
c10::optional<torch::Tensor> const& bias,
int64_t M,
int64_t N,
int64_t K) {
using Sm1xxBlkScaledConfig = typename Mxfp8GemmSm120::Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;
int m = static_cast<int>(M);
int n = static_cast<int>(N);
int k = static_cast<int>(K);
auto stride_A = cutlass::make_cute_packed_stride(Mxfp8GemmSm120::StrideA{}, {m, k, 1});
auto stride_B = cutlass::make_cute_packed_stride(Mxfp8GemmSm120::StrideB{}, {n, k, 1});
auto stride_D = cutlass::make_cute_packed_stride(Mxfp8GemmSm120::StrideD{}, {m, n, 1});
auto layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(m, n, k, 1));
auto layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(m, n, k, 1));
if (bias){
auto stride_bias = cutlass::make_cute_packed_stride(Mxfp8GemmSm120::StrideC{}, {});
typename Mxfp8GemmSm120::Gemm::Arguments arguments{
cutlass::gemm::GemmUniversalMode::kGemm,
{m, n, k, 1},
{// Mainloop arguments
static_cast<Mxfp8GemmSm120::Gemm::ElementA const*>(A.data_ptr()),
stride_A,
static_cast<Mxfp8GemmSm120::Gemm::ElementB const*>(B.data_ptr()),
stride_B,
static_cast<cutlass::float_ue8m0_t const*>(A_sf.data_ptr()),
layout_SFA,
static_cast<cutlass::float_ue8m0_t const*>(B_sf.data_ptr()),
layout_SFB},
{ // Epilogue arguments
{}, // epilogue.thread
static_cast<Mxfp8GemmSm120::Gemm::ElementC const*>(bias->data_ptr()),
stride_bias,
static_cast<Mxfp8GemmSm120::Gemm::ElementD*>(D.data_ptr()),
stride_D}};
auto& fusion_args = arguments.epilogue.thread;
fusion_args.alpha_ptr = static_cast<float const*>(alpha.data_ptr());
return arguments;
} else {
typename Mxfp8GemmSm120::Gemm::Arguments arguments{
cutlass::gemm::GemmUniversalMode::kGemm,
{m, n, k, 1},
{// Mainloop arguments
static_cast<Mxfp8GemmSm120::Gemm::ElementA const*>(A.data_ptr()),
stride_A,
static_cast<Mxfp8GemmSm120::Gemm::ElementB const*>(B.data_ptr()),
stride_B,
static_cast<cutlass::float_ue8m0_t const*>(A_sf.data_ptr()),
layout_SFA,
static_cast<cutlass::float_ue8m0_t const*>(B_sf.data_ptr()),
layout_SFB},
{ // Epilogue arguments
{}, // epilogue.thread
static_cast<Mxfp8GemmSm120::Gemm::ElementC const*>(D.data_ptr()),
stride_D,
static_cast<Mxfp8GemmSm120::Gemm::ElementD*>(D.data_ptr()),
stride_D}};
auto& fusion_args = arguments.epilogue.thread;
fusion_args.alpha_ptr = static_cast<float const*>(alpha.data_ptr());
return arguments;
}
}
void runGemmMxfp8Sm120(
at::Tensor& D,
at::Tensor const& A,
at::Tensor const& B,
at::Tensor const& A_sf,
at::Tensor const& B_sf,
at::Tensor const& alpha,
c10::optional<torch::Tensor> const& bias,
int64_t m,
int64_t n,
int64_t k,
cudaStream_t stream) {
typename Mxfp8GemmSm120::Gemm gemm;
auto arguments = args_from_options_mxfp8(D, A, B, A_sf, B_sf, alpha, bias, m, n, k);
size_t workspace_size = Mxfp8GemmSm120::Gemm::get_workspace_size(arguments);
auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(A.device());
auto workspace = torch::empty(workspace_size, workspace_options);
CUTLASS_CHECK(gemm.can_implement(arguments));
CUTLASS_CHECK(gemm.initialize(arguments, workspace.data_ptr(), stream));
CUTLASS_CHECK(gemm.run(arguments, workspace.data_ptr(), stream));
}
constexpr auto FP6_FP8_TYPE = at::ScalarType::Byte;
constexpr auto SF_DTYPE = at::ScalarType::Float8_e8m0fnu;
void cutlass_scaled_mxfp8_mm_sm120(
torch::Tensor& D,
torch::Tensor const& A,
torch::Tensor const& B,
torch::Tensor const& A_sf,
torch::Tensor const& B_sf,
torch::Tensor const& alpha,
c10::optional<torch::Tensor> const& bias) {
CHECK_INPUT(A, FP6_FP8_TYPE, "a");
CHECK_INPUT(B, FP6_FP8_TYPE, "b");
CHECK_INPUT(A_sf, SF_DTYPE, "scale_a");
CHECK_INPUT(B_sf, SF_DTYPE, "scale_b");
CHECK_INPUT(alpha, at::ScalarType::Float, "alpha");
TORCH_CHECK(A.dim() == 2, "a must be a matrix");
TORCH_CHECK(B.dim() == 2, "b must be a matrix");
TORCH_CHECK(
A.sizes()[1] == B.sizes()[1],
"a and b shapes cannot be multiplied (",
A.sizes()[0],
"x",
A.sizes()[1],
" and ",
B.sizes()[0],
"x",
B.sizes()[1],
")");
auto const m = A.sizes()[0];
auto const n = B.sizes()[0];
auto const k = A.sizes()[1];
constexpr int alignment_a = 16;
constexpr int alignment_b = 128;
TORCH_CHECK(
k % alignment_a == 0,
"Expected k to be divisible by ",
alignment_a,
", but got a shape: (",
A.sizes()[0],
"x",
A.sizes()[1],
"), k: ",
k,
".");
TORCH_CHECK(
n % alignment_b == 0,
"Expected n to be divisible by ",
alignment_b,
", but got b shape: (",
B.sizes()[0],
"x",
B.sizes()[1],
").");
auto round_up = [](int x, int y) { return (x + y - 1) / y * y; };
int rounded_m = round_up(m, 128);
int rounded_n = round_up(n, 128);
// k / 32 is the number of 32-element MX scale blocks along K.
int rounded_k = round_up(k / 32, 4);
TORCH_CHECK(A_sf.dim() == 2, "scale_a must be a matrix");
TORCH_CHECK(B_sf.dim() == 2, "scale_b must be a matrix");
TORCH_CHECK(
A_sf.sizes()[1] == B_sf.sizes()[1],
"scale_a and scale_b shapes cannot be multiplied (",
A_sf.sizes()[0],
"x",
A_sf.sizes()[1],
" and ",
B_sf.sizes()[0],
"x",
B_sf.sizes()[1],
")");
TORCH_CHECK(
A_sf.sizes()[0] == rounded_m && A_sf.sizes()[1] == rounded_k,
"scale_a must be padded and swizzled to a shape (",
rounded_m,
"x",
rounded_k,
"), but got a shape (",
A_sf.sizes()[0],
"x",
A_sf.sizes()[1],
")");
TORCH_CHECK(
B_sf.sizes()[0] == rounded_n && B_sf.sizes()[1] == rounded_k,
"scale_b must be padded and swizzled to a shape (",
rounded_n,
"x",
rounded_k,
"), but got a shape (",
B_sf.sizes()[0],
"x",
B_sf.sizes()[1],
")");
auto out_dtype = D.dtype();
at::cuda::CUDAGuard device_guard{(char)A.get_device()};
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(A.get_device());
runGemmMxfp8Sm120(D, A, B, A_sf, B_sf, alpha, bias, m, n, k, stream);
}
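The TORCH_CHECKs above expect the scale-factor tensors in a padded, swizzled layout; a small sketch of the expected sizes, assuming the same round_up convention used by those checks:

```python
def expected_mxfp8_scale_shapes(m: int, n: int, k: int):
    # Expected (rows, cols) of A_sf and B_sf for an (m, k) x (n, k)^T MXFP8 GEMM.
    round_up = lambda x, y: (x + y - 1) // y * y
    rounded_m = round_up(m, 128)      # A_sf rows padded to a multiple of 128
    rounded_n = round_up(n, 128)      # B_sf rows padded to a multiple of 128
    rounded_k = round_up(k // 32, 4)  # one E8M0 scale per 32 K-elements, padded to a multiple of 4
    return (rounded_m, rounded_k), (rounded_n, rounded_k)

# Example: m=512, n=5120, k=5120 -> A_sf is (512, 160), B_sf is (5120, 160)
```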
@@ -57,6 +57,10 @@ void scaled_fp4_quant_sm120(
torch::Tensor& output, torch::Tensor const& input, torch::Tensor& output_sf, torch::Tensor const& input_sf);
void scaled_fp8_quant_sm120(
torch::Tensor& output, torch::Tensor const& input, torch::Tensor& output_sf);
void cutlass_scaled_mxfp6_mxfp8_mm_sm120(
torch::Tensor& D,
torch::Tensor const& A,
@@ -65,3 +69,13 @@ void cutlass_scaled_mxfp6_mxfp8_mm_sm120(
torch::Tensor const& B_sf,
torch::Tensor const& alpha,
c10::optional<torch::Tensor> const& bias);
void cutlass_scaled_mxfp8_mm_sm120(
torch::Tensor& D,
torch::Tensor const& A,
torch::Tensor const& B,
torch::Tensor const& A_sf,
torch::Tensor const& B_sf,
torch::Tensor const& alpha,
c10::optional<torch::Tensor> const& bias);
@@ -55,8 +55,28 @@ def scaled_fp4_quant(input: torch.Tensor, input_global_scale: torch.Tensor):
return output, output_scale
def scaled_fp8_quant(input: torch.Tensor):
m, n = input.shape
block_size = 32
device = input.device
output = torch.empty((m, n), device=device, dtype=torch.uint8)
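# Scale factors: one E8M0 byte per 32-element block along n. Allocating int32 packs 4 scale
# bytes per element, so this buffer is (round_up(m, 128), ceil((n // 32) / 4)) in int32 units;
# the .view(torch.float8_e8m0fnu) below expands the last dim by 4x, giving the padded shape
# (round_up(m, 128), round_up(n // 32, 4)) that the MXFP8 GEMM checks expect.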
output_scale = torch.empty(((m + 128 - 1) // 128 * 128, (n // block_size + 4 - 1) // 4), device=device, dtype=torch.int32)
torch.ops.lightx2v_kernel.scaled_fp8_quant_sm120.default(output, input, output_scale)
output_scale = output_scale.view(torch.float8_e8m0fnu)
return output, output_scale
def cutlass_scaled_mxfp6_mxfp8_mm(mat_a, mat_b, scales_a, scales_b, alpha, bias=None):
m, n = mat_a.shape[0], mat_b.shape[0]
out = torch.empty((m, n), dtype=torch.bfloat16, device=mat_a.device)
torch.ops.lightx2v_kernel.cutlass_scaled_mxfp6_mxfp8_mm_sm120.default(out, mat_a, mat_b, scales_a, scales_b, alpha, bias)
return out
def cutlass_scaled_mxfp8_mm(mat_a, mat_b, scales_a, scales_b, alpha, bias=None):
m, n = mat_a.shape[0], mat_b.shape[0]
out = torch.empty((m, n), dtype=torch.bfloat16, device=mat_a.device)
torch.ops.lightx2v_kernel.cutlass_scaled_mxfp8_mm_sm120.default(out, mat_a, mat_b, scales_a, scales_b, alpha, bias)
return out
import functools
from typing import Dict, Tuple, Callable, List
import torch
@@ -33,3 +33,125 @@ def is_hopper_arch() -> bool:
device = torch.cuda.current_device()
major, minor = torch.cuda.get_device_capability(device)
return major == 9
def error(y_pred: torch.Tensor, y_real: torch.Tensor) -> torch.Tensor:
"""
Compute SNR between y_pred(tensor) and y_real(tensor)
SNR can be calcualted as following equation:
SNR(pred, real) = (pred - real) ^ 2 / (real) ^ 2
if x and y are matrixs, SNR error over matrix should be the mean value of SNR error over all elements.
SNR(pred, real) = mean((pred - real) ^ 2 / (real) ^ 2)
Args:
y_pred (torch.Tensor): _description_
y_real (torch.Tensor): _description_
reduction (str, optional): _description_. Defaults to 'mean'.
Raises:
ValueError: _description_
ValueError: _description_
Returns:
torch.Tensor: _description_
"""
y_pred = torch.flatten(y_pred).float()
y_real = torch.flatten(y_real).float()
if y_pred.shape != y_real.shape:
raise ValueError(f"Can not compute snr loss for tensors with different shape. ({y_pred.shape} and {y_real.shape})")
noise_power = torch.pow(y_pred - y_real, 2).sum(dim=-1)
signal_power = torch.pow(y_real, 2).sum(dim=-1)
snr = (noise_power) / (signal_power + 1e-7)
return snr.item()
def benchmark(func: Callable, shape: List[int], tflops: float, steps: int, *args, **kwargs):
"""
A decorator function to assist in performance testing of CUDA operations.
This function will:
1. Automatically determine whether any parameters in the argument list,
or the output of the `func`, are of type `torch.Tensor`.
2. If so, calculate the memory usage of the input and output tensors
on the GPU (based on their data type and `torch.numel()`).
3. Establish a CUDA graph and attempt to execute `func` repeatedly for `steps` iterations.
4. Record the execution time during these iterations.
5. Use the information above to compute the compute performance (TFLOPS) and memory throughput.
Args:
func (function): The function to benchmark.
shape (list of int): The problem shape.
tflops (float): The computational workload (in TFLOPS) per call of `func`.
steps (int): The number of times the function is executed during benchmarking.
*args: Positional arguments to be passed to the `func`.
**kwargs: Keyword arguments to be passed to the `func`.
Returns:
function result
"""
# Ensure CUDA is available
if not torch.cuda.is_available():
raise RuntimeError("CUDA is required for benchmarking.")
# Check for torch.Tensor in inputs and outputs
input_tensors = [arg for arg in args if isinstance(arg, torch.Tensor)]
input_tensors += [value for value in kwargs.values() if isinstance(value, torch.Tensor)]
def calculate_memory(tensor: torch.Tensor):
"""Calculate memory usage in bytes for a tensor."""
return tensor.numel() * tensor.element_size()
input_memory = sum(calculate_memory(t) for t in input_tensors)
# Execute the function to inspect outputs
with torch.no_grad():
output = func(*args, **kwargs)
output_memory = 0
if isinstance(output, torch.Tensor):
output_memory = calculate_memory(output)
elif isinstance(output, (list, tuple)):
output_memory = sum(calculate_memory(o) for o in output if isinstance(o, torch.Tensor))
total_memory = input_memory + output_memory
# Warm-up runs to stabilize timing
for _ in range(10): # Warm-up
func(*args, **kwargs)
torch.cuda.synchronize() # Ensure no pending operations
# Benchmark the function
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
for _ in range(steps):
func(*args, **kwargs)
end_event.record()
torch.cuda.synchronize() # Ensure all operations are finished
elapsed_time_ms = start_event.elapsed_time(end_event) # Time in milliseconds
# Calculate performance metrics
elapsed_time_s = elapsed_time_ms / 1000 # Convert to seconds
avg_time_per_step = elapsed_time_s / steps
compute_performance = tflops / avg_time_per_step # TFLOPS
memory_throughput = (total_memory * steps / (1024**3)) / elapsed_time_s # GB/s
# Print performance metrics
print(f"Function: {func.__name__}{shape}")
# print(f"Function: {func.__ne__}{shape}")
print(f"Elapsed Time (total): {elapsed_time_s:.4f} seconds")
print(f"Average Time Per Step: {avg_time_per_step * 1000:.3f} ms")
print(f"Compute Performance: {compute_performance:.2f} TFLOPS")
print(f"Memory Throughput: {memory_throughput:.2f} GB/s")
print("") # print a blank line.
import torch
from lightx2v_kernel.gemm import scaled_fp8_quant, cutlass_scaled_mxfp8_mm
import time
class MMWeightMxfp8:
def __init__(self, weight, bias):
self.load_fp8_weight(weight, bias)
self.act_quant_func = self.act_quant_fp8
self.set_alpha()
@torch.no_grad()
def apply(self, input_tensor):
input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
output_tensor = cutlass_scaled_mxfp8_mm(input_tensor_quant, self.weight, input_tensor_scale, self.weight_scale, alpha=self.alpha, bias=self.bias)
return output_tensor
@torch.no_grad()
def load_fp8_weight(self, weight, bias):
self.weight, self.weight_scale = scaled_fp8_quant(weight)
self.bias = bias
def set_alpha(self):
self.alpha = torch.tensor(1.0, dtype=torch.float32, device=self.weight.device)
@torch.no_grad()
def act_quant_fp8(self, x):
return scaled_fp8_quant(x)
def test_speed(m, k, n):
with torch.no_grad():
input_tensor = torch.randn(m, k, dtype=torch.bfloat16).cuda()
weight = torch.randn(n, k, dtype=torch.bfloat16, device="cuda")
# bias = torch.randn(1, n, dtype=torch.bfloat16).cuda()
bias = None
mm = MMWeightMxfp8(weight, bias)
# warmup
output_tensor = mm.apply(input_tensor)
torch.cuda.synchronize()
start_time = time.time()
for i in range(100):
output_tensor = mm.apply(input_tensor)
torch.cuda.synchronize()
end_time = time.time()
lightx2v_kernel_time = (end_time - start_time) / 100
print(f"lightx2v-kernel time: {lightx2v_kernel_time}")
input_tensor = torch.randn(m, k, dtype=torch.bfloat16).cuda()
weight = torch.randn(n, k, dtype=torch.bfloat16, device="cuda")
bias = torch.randn(1, n, dtype=torch.bfloat16).cuda()
linear = torch.nn.Linear(k, n, bias=False).cuda()
linear.weight.data = weight
# linear.bias.data = bias
# warmup
ref_output_tensor = linear(input_tensor)
torch.cuda.synchronize()
start_time = time.time()
for i in range(100):
ref_output_tensor = linear(input_tensor)
torch.cuda.synchronize()
end_time = time.time()
ref_time = (end_time - start_time) / 100
print(f"ref time: {ref_time}")
print(f"speedup: {ref_time / lightx2v_kernel_time:.3f}")
def test_accuracy(m, k, n):
with torch.no_grad():
input_tensor = torch.randn(m, k, dtype=torch.bfloat16).cuda()
weight = torch.randn(n, k, dtype=torch.bfloat16, device="cuda")
# bias = torch.randn(1, n, dtype=torch.bfloat16).cuda()
bias = None
linear = torch.nn.Linear(k, n, bias=False).cuda()
linear.weight.data = weight
# linear.bias.data = bias
ref_output_tensor = linear(input_tensor)
mm = MMWeightMxfp8(weight, bias)
output_tensor = mm.apply(input_tensor)
# print(f"ref_output_tensor: {ref_output_tensor}")
# print(f"output_tensor: {output_tensor}")
# cosine
cos = torch.nn.functional.cosine_similarity(ref_output_tensor.flatten(), output_tensor.flatten(), dim=0)
print(f"cos : {cos}")
if __name__ == "__main__":
test_sizes = [
(32130, 5120, 5120),
(512, 5120, 5120),
(257, 5120, 5120),
(32130, 5120, 13824),
(32130, 13824, 5120),
(75348, 5120, 5120),
(75348, 13824, 5120),
(32760, 1536, 1536),
(512, 1536, 1536),
(32760, 1536, 8960),
(32760, 8960, 1536),
]
for i, (m, k, n) in enumerate(test_sizes):
print("-" * 30)
print(f"测试 {i + 1}: 张量大小 ({m}, {k}, {n})")
test_accuracy(m, k, n)
test_speed(m, k, n)
import torch
from lightx2v_kernel.gemm import cutlass_scaled_mxfp8_mm
"""
input_shape = (1024, 2048)
weight_shape = (4096, 2048)
input_tensor_quant = (torch.rand((1024, 1024), device="cuda") * 10).to(torch.uint8)
weight = (torch.rand((4096, 1024), device="cuda") * 10).to(torch.uint8)
input_tensor_scale = torch.rand(1024, 128, device="cuda").to(torch.float8_e8m0fnu)
weight_scale = torch.rand(4096, 128, device="cuda").to(torch.float8_e8m0fnu)
alpha = torch.tensor(1.0, device="cuda").to(torch.float32)
bias = None
"""
def test_mm(input_tensor_quant, weight, input_tensor_scale, weight_scale, alpha, bias):
output_tensor = cutlass_scaled_mxfp8_mm(input_tensor_quant, weight, input_tensor_scale, weight_scale, alpha=alpha, bias=bias)
return output_tensor
def test_tflops(input_shape, weight_shape, num_warmup=10, num_runs=100):
"""
测试test_mm函数的TFLOPS性能
"""
# 创建输入数据
input_tensor_quant = (torch.rand((input_shape[0], input_shape[1]), device="cuda") * 10).to(torch.uint8)
weight = (torch.rand((weight_shape[0], weight_shape[1]), device="cuda") * 10).to(torch.uint8)
input_tensor_scale = torch.rand(((input_shape[0] + 128 - 1) // 128) * 128, (input_shape[1] // 32 + 4 - 1) // 4 * 4, device="cuda").to(torch.float8_e8m0fnu)
weight_scale = torch.rand(weight_shape[0], weight_shape[1] // 32, device="cuda").to(torch.float8_e8m0fnu)
alpha = torch.tensor(1.0, device="cuda", dtype=torch.float32)
bias = None
# Warm up the GPU
for _ in range(num_warmup):
test_mm(input_tensor_quant, weight, input_tensor_scale, weight_scale, alpha, bias)
# Synchronize the GPU
torch.cuda.synchronize()
# Create CUDA events for accurate timing
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
# Measure the elapsed time
start_event.record()
for _ in range(num_runs):
result = test_mm(input_tensor_quant, weight, input_tensor_scale, weight_scale, alpha, bias)
end_event.record()
# Synchronize and read the elapsed time
torch.cuda.synchronize()
elapsed_time_ms = start_event.elapsed_time(end_event)
elapsed_time_s = elapsed_time_ms / 1000.0
# Compute FLOPs
# Matrix multiplication A(M x K) @ B(K x N) = C(M x N)
# M = batch size, K = input dim, N = output dim
M = input_shape[0]
K = input_shape[1]
N = weight_shape[0]
# FLOPs per matmul = 2 * M * N * K (each output element needs K multiplies and K adds)
flops_per_run = 2 * M * N * K
total_flops = flops_per_run * num_runs
# Convert to TFLOPS (trillions of floating-point operations per second)
tflops = total_flops / (elapsed_time_s * 1e12)
print(f"测试结果:")
print(f" 输入形状: {input_shape} (M={M}, K={K})")
print(f" 权重形状: {weight_shape} (N={N}, K={K})")
print(f" 输出形状: ({M}, {N})")
print(f" 运行次数: {num_runs}")
print(f" 总执行时间: {elapsed_time_ms:.2f} ms")
print(f" 平均每次执行时间: {elapsed_time_ms / num_runs:.4f} ms")
print(f" 每次运行FLOPS: {flops_per_run / 1e9:.2f} GFLOPS")
print(f" 总FLOPS: {total_flops / 1e12:.2f} TFLOPS")
print(f" 计算性能: {tflops:.2f} TFLOPS")
return tflops
if __name__ == "__main__":
# Test matrix multiplications of various sizes
# (m,k) (n,k)
test_cases = [
((32130, 5120), (5120, 5120)),
((512, 1536), (1536, 1536)),
((512, 5120), (5120, 5120)),
((257, 5120), (5120, 5120)),
((32130, 5120), (13824, 5120)),
((32130, 13824), (5120, 13824)),
((75348, 5120), (5120, 5120)),
((75348, 5120), (13824, 5120)),
((75348, 13824), (5120, 13824)),
((32760, 1536), (1536, 1536)),
((512, 1536), (1536, 1536)),
((32760, 1536), (8960, 1536)),
((32760, 8960), (1536, 8960)),
]
print("=== test_mm TFLOPS性能测试 ===\n")
for i, (input_shape, weight_shape) in enumerate(test_cases):
print(f"测试 {i + 1}: 输入形状 {input_shape}, 权重形状 {weight_shape}")
print("-" * 60)
tflops = test_tflops(input_shape, weight_shape)
print(f"✓ 成功完成测试,性能: {tflops:.2f} TFLOPS\n")
print("=== 测试完成 ===")
import unittest
import torch
from lightx2v_kernel.gemm import cutlass_scaled_mxfp8_mm
from lightx2v_kernel.gemm import scaled_fp8_quant
from torch.nn.functional import linear
from lightx2v_kernel.utils import error, benchmark
class TestQuantBF162MXFP8(unittest.TestCase):
def setUp(self):
self.tokens = [257, 512, 1024, 13325, 32130, 32760] # , 75348
self.channels = [1536, 5120, 8960] # , 13824
self.hiddenDims = [1536, 3072, 5120, 8960, 12800] # , 13824
self.device = "cuda"
self.dtype = torch.bfloat16
def test_accuracy(self):
"""Test the accuracy of quantization from BF16 to MXFP8."""
for m in self.tokens:
for k in self.hiddenDims:
for n in self.channels:
with self.subTest(shape=[m, k, n]):
activation = torch.randn(m, k, dtype=self.dtype, device=self.device)
activation_quant_pred, activation_scale_pred = scaled_fp8_quant(activation)
weight = torch.randn(n, k, dtype=self.dtype, device=self.device)
weight_quant_pred, weight_scale_pred = scaled_fp8_quant(weight)
alpha = torch.tensor(1.0, device=self.device, dtype=torch.float32)
mm_pred = cutlass_scaled_mxfp8_mm(activation_quant_pred, weight_quant_pred, activation_scale_pred, weight_scale_pred, alpha=alpha)
mm_real = linear(activation, weight, bias=None).to(torch.bfloat16)
self.assertTrue(error(mm_pred, mm_real) < 1e-2, f"Accuracy test failed for shape {m, k, n}: Error {error(mm_pred, mm_real)} exceeds threshold.")
def test_performance(self):
"""Benchmark the performance of Activation quantization from BF16 to MXFP8."""
for m in self.tokens:
for k in self.hiddenDims:
with self.subTest(shape=[m, k]):
input = torch.randn(m, k, dtype=self.dtype, device=self.device)
shape = [m, k]
tflops = 2 * (m * k / 1024**4)
benchmark(scaled_fp8_quant, shape, tflops, 100, input)
if __name__ == "__main__":
unittest.main()
import torch
from lightx2v_kernel.gemm import scaled_fp8_quant
def quantize_fp8(x):
return scaled_fp8_quant(x)
def test_memory_bandwidth(func, x, num_warmup=10, num_runs=100):
"""
测试函数的显存带宽
"""
# 预热GPU
for _ in range(num_warmup):
func(x)
# Synchronize the GPU
torch.cuda.synchronize()
# Create CUDA events for accurate timing
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
# Measure the elapsed time
start_event.record()
for _ in range(num_runs):
result = func(x)
end_event.record()
# Synchronize and read the elapsed time
torch.cuda.synchronize()
elapsed_time_ms = start_event.elapsed_time(end_event)
elapsed_time_s = elapsed_time_ms / 1000.0
# Compute the amount of data moved
input_bytes = x.numel() * x.element_size() # bytes read from the input
# After FP8 quantization each element occupies 1 byte.
output_bytes = x.numel() * 1 # bytes written for the FP8 output
scale_bytes = x.numel() / 32 # one E8M0 scale byte per group of 32 elements (group_size = 32)
# Total data transferred (input reads + output writes + scales)
total_bytes = (input_bytes + output_bytes + scale_bytes) * num_runs
# Compute the bandwidth
bandwidth_gbps = (total_bytes / elapsed_time_s) / (1024**3) # GB/s
print(f"测试结果:")
print(f" 输入张量形状: {x.shape}")
print(f" 输入数据类型: {x.dtype}")
print(f" 运行次数: {num_runs}")
print(f" 总执行时间: {elapsed_time_ms:.2f} ms")
print(f" 平均每次执行时间: {elapsed_time_ms / num_runs:.4f} ms")
print(f" 输入数据大小: {input_bytes / (1024**2):.2f} MB")
print(f" 输出数据大小: {output_bytes / (1024**2):.2f} MB")
print(f" 总数据传输量: {total_bytes / (1024**3):.2f} GB")
print(f" 显存带宽: {bandwidth_gbps:.2f} GB/s")
return bandwidth_gbps
if __name__ == "__main__":
# Test tensors of various sizes
test_sizes = [
# (1, 1024),
# (1, 2048),
# (1, 4096),
# (1, 8192),
# (1, 16384),
# (1, 32768),
# (2, 1024),
# (2, 2048),
# (2, 4096),
# (2, 8192),
# (2, 16384),
# (2, 32768),
# (4, 1024),
# (4, 2048),
# (4, 4096),
# (4, 8192),
# (4, 16384),
# (4, 32768),
# (128, 1024),
# (128, 2048),
# (128, 4096),
# (128, 8192),
# (128, 16384),
# (128, 32768),
# (512, 1024),
# (512, 2048),
# (512, 4096),
# (512, 8192),
# (512, 16384),
# (512, 32768),
# (1024, 1024),
# (1024, 2048),
# (1024, 4096),
# (1024, 8192),
# (1024, 16384),
# (1024, 32768),
# (2048, 1024),
# (2048, 2048),
# (2048, 4096),
# (2048, 8192),
# (2048, 16384),
# (2048, 32768),
# (4096, 1024),
# (4096, 2048),
# (4096, 4096),
# (4096, 8192),
# (4096, 16384),
# (4096, 32768),
# (8192, 1024),
# (8192, 2048),
# (8192, 4096),
# (8192, 8192),
# (8192, 16384),
# (8192, 32768),
# (16384, 1024),
# (16384, 2048),
# (16384, 4096),
# (16384, 8192),
# (16384, 16384),
# (16384, 32768),
# (32768, 1024),
# (32768, 2048),
# (32768, 4096),
# (32768, 8192),
# (32768, 16384),
# (32768, 32768),
(32130, 5120),
(512, 5120),
(257, 5120),
(32130, 13824),
(75348, 5120),
(75348, 13824),
(32760, 1536),
(512, 3072),
(512, 1536),
(32760, 8960),
]
print("=== quantize_fp8 显存带宽测试 ===\n")
for i, (h, w) in enumerate(test_sizes):
print(f"测试 {i + 1}: 张量大小 ({h}, {w})")
print("-" * 50)
x = torch.randn(h, w, dtype=torch.bfloat16).cuda()
try:
bandwidth = test_memory_bandwidth(quantize_fp8, x)
print(f"✓ 成功完成测试,带宽: {bandwidth:.2f} GB/s\n")
except Exception as e:
print(f"✗ 测试失败: {e}\n")
print("=== 测试完成 ===")