Commit 29a90944 authored by Xtra's avatar Xtra Committed by GitHub
Browse files

add mxfp4 kernels and rename some func for clarity (#148)

parent 505c5a47
......@@ -93,8 +93,10 @@ list(APPEND LIGHTX2V_KERNEL_CUDA_FLAGS
set(SOURCES
"csrc/gemm/nvfp4_scaled_mm_kernels_sm120.cu"
"csrc/gemm/nvfp4_quant_kernels_sm120.cu"
"csrc/gemm/mxfp4_quant_kernels_sm120.cu"
"csrc/gemm/mxfp8_quant_kernels_sm120.cu"
"csrc/gemm/mxfp6_quant_kernels_sm120.cu"
"csrc/gemm/mxfp4_scaled_mm_kernels_sm120.cu"
"csrc/gemm/mxfp6_mxfp8_scaled_mm_kernels_sm120.cu"
"csrc/gemm/mxfp8_scaled_mm_kernels_sm120.cu"
"csrc/common_extension.cc"
......
......@@ -7,23 +7,34 @@
TORCH_LIBRARY_FRAGMENT(lightx2v_kernel, m) {
m.def(
"cutlass_scaled_fp4_mm_sm120(Tensor! out, Tensor mat_a, Tensor mat_b, Tensor scales_a, Tensor scales_b, Tensor "
"cutlass_scaled_nvfp4_mm_sm120(Tensor! out, Tensor mat_a, Tensor mat_b, Tensor scales_a, Tensor scales_b, Tensor "
"alpha, Tensor? bias) -> ()");
m.impl("cutlass_scaled_fp4_mm_sm120", torch::kCUDA, &cutlass_scaled_fp4_mm_sm120);
m.impl("cutlass_scaled_nvfp4_mm_sm120", torch::kCUDA, &cutlass_scaled_nvfp4_mm_sm120);
m.def(
"scaled_fp4_quant_sm120(Tensor! output, Tensor! input,"
"scaled_nvfp4_quant_sm120(Tensor! output, Tensor! input,"
" Tensor! output_scale, Tensor! input_scale) -> ()");
m.impl("scaled_fp4_quant_sm120", torch::kCUDA, &scaled_fp4_quant_sm120);
m.impl("scaled_nvfp4_quant_sm120", torch::kCUDA, &scaled_nvfp4_quant_sm120);
m.def(
"scaled_fp8_quant_sm120(Tensor! output, Tensor! input,"
"scaled_mxfp4_quant_sm120(Tensor! output, Tensor! input,"
" Tensor! output_scale) -> ()");
m.impl("scaled_mxfp4_quant_sm120", torch::kCUDA, &scaled_mxfp4_quant_sm120);
m.def(
"scaled_mxfp8_quant_sm120(Tensor! output, Tensor! input,"
" Tensor! output_scale) -> ()");
m.impl("scaled_fp8_quant_sm120", torch::kCUDA, &scaled_fp8_quant_sm120);
m.impl("scaled_mxfp8_quant_sm120", torch::kCUDA, &scaled_mxfp8_quant_sm120);
m.def(
"scaled_fp6_quant_sm120(Tensor! output, Tensor! input,"
"scaled_mxfp6_quant_sm120(Tensor! output, Tensor! input,"
" Tensor! output_scale) -> ()");
m.impl("scaled_fp6_quant_sm120", torch::kCUDA, &scaled_fp6_quant_sm120);
m.impl("scaled_mxfp6_quant_sm120", torch::kCUDA, &scaled_mxfp6_quant_sm120);
m.def(
"cutlass_scaled_mxfp4_mm_sm120(Tensor! out, Tensor mat_a, Tensor mat_b, Tensor scales_a, Tensor scales_b, Tensor "
"alpha, Tensor? bias) -> ()");
m.impl("cutlass_scaled_mxfp4_mm_sm120", torch::kCUDA, &cutlass_scaled_mxfp4_mm_sm120);
m.def(
"cutlass_scaled_mxfp6_mxfp8_mm_sm120(Tensor! out, Tensor mat_a, Tensor mat_b, Tensor scales_a, Tensor scales_b, Tensor "
......
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_fp8.h>
#include <cuda_runtime.h>
#include <cuda_runtime_api.h>
#include <torch/all.h>
#include "utils.h"
// Get type2 from type or vice versa (applied to half and bfloat16)
template <typename T>
struct TypeConverter {
using Type = half2;
}; // keep for generality
template <>
struct TypeConverter<half2> {
using Type = half;
};
template <>
struct TypeConverter<half> {
using Type = half2;
};
template <>
struct TypeConverter<__nv_bfloat162> {
using Type = __nv_bfloat16;
};
template <>
struct TypeConverter<__nv_bfloat16> {
using Type = __nv_bfloat162;
};
#define ELTS_PER_THREAD 8
constexpr int CVT_FP4_ELTS_PER_THREAD = 8;
constexpr int CVT_FP4_SF_VEC_SIZE = 32;
// Convert 4 float2 values into 8 e2m1 values (represented as one uint32_t).
inline __device__ uint32_t fp32_vec_to_e2m1(float2 (&array)[4]) {
// PTX instructions used here requires sm100a.
// #if CUDA_VERSION >= 12080
// #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) && __CUDA_ARCH_HAS_FEATURE__(SM100_ALL)
uint32_t val;
asm volatile(
"{\n"
".reg .b8 byte0;\n"
".reg .b8 byte1;\n"
".reg .b8 byte2;\n"
".reg .b8 byte3;\n"
"cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n"
"cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n"
"cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n"
"cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n"
"mov.b32 %0, {byte0, byte1, byte2, byte3};\n"
"}"
: "=r"(val)
: "f"(array[0].x),
"f"(array[0].y),
"f"(array[1].x),
"f"(array[1].y),
"f"(array[2].x),
"f"(array[2].y),
"f"(array[3].x),
"f"(array[3].y));
return val;
// #else
// return 0;
// #endif
// #endif
}
// Fast reciprocal.
inline __device__ float reciprocal_approximate_ftz(float a) {
float b;
asm volatile("rcp.approx.ftz.f32 %0, %1;\n" : "=f"(b) : "f"(a));
return b;
}
template <class SFType, int CVT_FP4_NUM_THREADS_PER_SF>
__device__ uint8_t* get_sf_out_address(int rowIdx, int colIdx, int numCols, SFType* SFout) {
// #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
static_assert(CVT_FP4_NUM_THREADS_PER_SF == 4);
// One pair of threads write one SF to global memory.
// TODO: stage through smem for packed STG.32
// is it better than STG.8 from 4 threads ?
if (threadIdx.x % CVT_FP4_NUM_THREADS_PER_SF == 0) {
// SF vector index (16 elements share one SF in the K dimension).
int32_t kIdx = colIdx / CVT_FP4_NUM_THREADS_PER_SF;
int32_t mIdx = rowIdx;
// SF layout [numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4(kTile)]
// --> index [mTileIdx, kTileIdx, outerMIdx, innerMIdx, innerKIdx]
int32_t mTileIdx = mIdx / (32 * 4);
// SF vector size 16.
int factor = CVT_FP4_SF_VEC_SIZE * 4;
int32_t numKTiles = (numCols + factor - 1) / factor;
int64_t mTileStride = numKTiles * 32 * 4 * 4;
int32_t kTileIdx = (kIdx / 4);
int64_t kTileStride = 32 * 4 * 4;
// M tile layout [32, 4] is column-major.
int32_t outerMIdx = (mIdx % 32);
int64_t outerMStride = 4 * 4;
int32_t innerMIdx = (mIdx % (32 * 4)) / 32;
int64_t innerMStride = 4;
int32_t innerKIdx = (kIdx % 4);
int64_t innerKStride = 1;
// Compute the global offset.
int64_t SFOffset = mTileIdx * mTileStride + kTileIdx * kTileStride + outerMIdx * outerMStride +
innerMIdx * innerMStride + innerKIdx * innerKStride;
return reinterpret_cast<uint8_t*>(SFout) + SFOffset;
}
// #endif
return nullptr;
}
// Define a 16 bytes packed data type.
template <class Type>
struct PackedVec {
typename TypeConverter<Type>::Type elts[4];
};
template <>
struct PackedVec<__nv_fp8_e4m3> {
__nv_fp8x2_e4m3 elts[8];
};
// Quantizes the provided PackedVec into the uint32_t output
template <class Type>
__device__ uint32_t cvt_warp_fp16_to_fp4(PackedVec<Type>& vec, uint8_t* SFout) {
// #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
// Get absolute maximum values among the local 8 values.
auto localMax = __habs2(vec.elts[0]);
// Local maximum value.
#pragma unroll
for (int i = 1; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) {
localMax = __hmax2(localMax, __habs2(vec.elts[i]));
}
// Get the absolute maximum among all 32 values (four threads).
localMax = __hmax2(__shfl_xor_sync(uint32_t(-1), localMax, 1), localMax);
localMax = __hmax2(__shfl_xor_sync(uint32_t(-1), localMax, 2), localMax);
// Get the final absolute maximum values.
float vecMax = float(__hmax(localMax.x, localMax.y));
// Get the SF (max value of the vector / max value of e2m1).
// maximum value of e2m1 = 6.0.
// TODO: use half as compute data type.
float SFValue = vecMax * 0.16666666666666666f;
// 8 bits representation of the SF.
uint8_t fp8SFVal;
// Write the SF to global memory (STG.8).
__nv_fp8_e8m0 tmp;
tmp.__x = __nv_cvt_float_to_e8m0(SFValue, __NV_SATFINITE, cudaRoundPosInf);
SFValue = static_cast<float>(tmp);
fp8SFVal = tmp.__x;
// Get the output scale.
float outputScale =
SFValue != 0 ? reciprocal_approximate_ftz(SFValue) : 0.0f;
if (SFout) {
// Write the SF to global memory (STG.8).
*SFout = fp8SFVal;
}
// Convert the input to float.
float2 fp2Vals[CVT_FP4_ELTS_PER_THREAD / 2];
#pragma unroll
for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) {
if constexpr (std::is_same_v<Type, half>) {
fp2Vals[i] = __half22float2(vec.elts[i]);
} else {
fp2Vals[i] = __bfloat1622float2(vec.elts[i]);
}
fp2Vals[i].x *= outputScale;
fp2Vals[i].y *= outputScale;
}
// Convert to e2m1 values.
uint32_t e2m1Vec = fp32_vec_to_e2m1(fp2Vals);
// Write the e2m1 values to global memory.
return e2m1Vec;
// #else
// return 0;
// #endif
}
// Use UE4M3 by default.
template <class Type>
__global__ void
// #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
__launch_bounds__(256, 6) cvt_fp16_to_fp4(
// #else
// cvt_fp16_to_fp4(
// #endif
int32_t numRows, int32_t numCols, Type const* in, uint32_t* out, uint32_t* SFout) {
// #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
using PackedVec = PackedVec<Type>;
static constexpr int CVT_FP4_NUM_THREADS_PER_SF = (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD);
static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD, "Vec size is not matched.");
// Input tensor row/col loops.
for (int rowIdx = blockIdx.x; rowIdx < numRows; rowIdx += gridDim.x) {
for (int colIdx = threadIdx.x; colIdx < numCols / CVT_FP4_ELTS_PER_THREAD; colIdx += blockDim.x) {
int64_t inOffset = rowIdx * (numCols / CVT_FP4_ELTS_PER_THREAD) + colIdx;
PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
// Get the output tensor offset.
// Same as inOffset because 8 elements are packed into one uint32_t.
int64_t outOffset = inOffset;
auto& out_pos = out[outOffset];
auto sf_out =
get_sf_out_address<uint32_t, CVT_FP4_NUM_THREADS_PER_SF>(rowIdx, colIdx, numCols, SFout);
out_pos = cvt_warp_fp16_to_fp4<Type>(in_vec, sf_out);
}
}
// #endif
}
template <typename T>
void invokeFP4Quantization(
int m,
int n,
T const* input,
int64_t* output,
int32_t* SFOuput,
int multiProcessorCount,
cudaStream_t stream) {
// Grid, Block size.
// Each thread converts 8 values.
dim3 block(std::min(int(n / ELTS_PER_THREAD), 256));
// Get number of blocks per SM (assume we can fully utilize the SM).
int const numBlocksPerSM = 1536 / block.x;
dim3 grid(std::min(int(m), multiProcessorCount * numBlocksPerSM));
// Launch the cvt kernel.
cvt_fp16_to_fp4<T><<<grid, block, 0, stream>>>(
m, n, input, reinterpret_cast<uint32_t*>(output), reinterpret_cast<uint32_t*>(SFOuput));
}
// Instantiate the function.
template void invokeFP4Quantization(
int m,
int n,
half const* input,
int64_t* output,
int32_t* SFOuput,
int multiProcessorCount,
cudaStream_t stream);
template void invokeFP4Quantization(
int m,
int n,
__nv_bfloat16 const* input,
int64_t* output,
int32_t* SFOuput,
int multiProcessorCount,
cudaStream_t stream);
inline int getMultiProcessorCount() {
static int multi_processor_count = []() {
int device_id = 0;
int count = 0;
// Get the current CUDA device ID
CHECK_CUDA_SUCCESS(cudaGetDevice(&device_id));
// Get the number of multiprocessors for the current device
CHECK_CUDA_SUCCESS(cudaDeviceGetAttribute(&count, cudaDevAttrMultiProcessorCount, device_id));
return count; // Initialize the static variable
}();
return multi_processor_count; // Return the cached value on subsequent calls
}
void scaled_mxfp4_quant_sm120(
torch::Tensor& output, torch::Tensor const& input, torch::Tensor& output_sf) {
int32_t m = input.size(0);
int32_t n = input.size(1);
TORCH_CHECK(n % 16 == 0, "The N dimension must be multiple of 16.");
int multiProcessorCount = getMultiProcessorCount();
auto sf_out = static_cast<int32_t*>(output_sf.data_ptr());
auto output_ptr = static_cast<int64_t*>(output.data_ptr());
at::cuda::CUDAGuard device_guard{(char)input.get_device()};
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(input.get_device());
switch (input.scalar_type()) {
case torch::kHalf: {
auto input_ptr = reinterpret_cast<half const*>(input.data_ptr());
invokeFP4Quantization(m, n, input_ptr, output_ptr, sf_out, multiProcessorCount, stream);
break;
}
case torch::kBFloat16: {
auto input_ptr = reinterpret_cast<__nv_bfloat16 const*>(input.data_ptr());
invokeFP4Quantization(m, n, input_ptr, output_ptr, sf_out, multiProcessorCount, stream);
break;
}
default: {
std::cerr << "Observing: " << input.scalar_type() << " for the input datatype which is invalid";
throw std::runtime_error("Unsupported input data type for quantize_to_fp4.");
}
}
}
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>
// clang-format off
#include "cutlass/cutlass.h"
#include "cutlass/epilogue/fusion/operations.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/util/packed_stride.hpp"
// clang-format on
#define CUTLASS_CHECK(status) \
{ \
cutlass::Status error = status; \
TORCH_CHECK(error == cutlass::Status::kSuccess, cutlassGetStatusString(error)); \
}
#define CHECK_TYPE(x, st, m) TORCH_CHECK(x.scalar_type() == st, "Inconsistency of Tensor type:", m)
#define CHECK_TH_CUDA(x, m) TORCH_CHECK(x.is_cuda(), m, "must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x, m) TORCH_CHECK(x.is_contiguous(), m, "must be contiguous")
#define CHECK_INPUT(x, st, m) \
CHECK_TH_CUDA(x, m); \
CHECK_CONTIGUOUS(x, m); \
CHECK_TYPE(x, st, m)
using namespace cute;
struct Mxfp4GemmSm120 {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM kernel configurations
/////////////////////////////////////////////////////////////////////////////////////////////////
// A matrix configuration
using ElementA = cutlass::mx_float4_t<cutlass::float_e2m1_t>; // Element type for A matrix operand
using LayoutATag = cutlass::layout::RowMajor; // Layout type for A matrix operand
static constexpr int AlignmentA = 128; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
// B matrix configuration
using ElementB = cutlass::mx_float4_t<cutlass::float_e2m1_t>; // Element type for B matrix operand
using LayoutBTag = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
static constexpr int AlignmentB = 128; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
// C/D matrix configuration
using ElementD = cutlass::bfloat16_t; // Element type for D matrix operand
using ElementC = cutlass::bfloat16_t; // Element type for C matrix operand
using LayoutCTag = cutlass::layout::RowMajor; // Layout type for C matrix operand
using LayoutDTag = cutlass::layout::RowMajor; // Layout type for D matrix operand
static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
// Kernel functional config
using ElementAccumulator = float; // Element type for internal accumulation
using ArchTag = cutlass::arch::Sm120; // Tag indicating the minimum SM that supports the intended feature
using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp; // Operator class tag
// Kernel Perf config
using ThreadBlockShape = Shape<_128,_128,_128>; // Threadblock's tile size
using ClusterShape = Shape<_1,_1,_1>; // Shape of the threadblocks in a cluster
// use per-column bias, i.e. every column has different bias
using EVTOp = cutlass::epilogue::fusion::LinCombPerColBias<ElementD, ElementAccumulator>;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag, OperatorClass,
ThreadBlockShape, ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator, ElementAccumulator,
ElementC, LayoutCTag, AlignmentC,
ElementD, LayoutDTag, AlignmentD,
cutlass::epilogue::collective::EpilogueScheduleAuto, // Epilogue schedule policy
EVTOp
>::CollectiveOp;
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass,
ElementA, LayoutATag, AlignmentA,
ElementB, LayoutBTag, AlignmentB,
ElementAccumulator,
ThreadBlockShape, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
cutlass::gemm::collective::KernelScheduleAuto // Kernel schedule policy. Auto defaults to cooperative kernel schedule
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int,int,int,int>, // Indicates ProblemShape
CollectiveMainloop,
CollectiveEpilogue,
void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
// Reference device GEMM implementation type
using StrideA = typename Gemm::GemmKernel::StrideA;
using LayoutA = decltype(cute::make_layout(make_shape(0,0,0), StrideA{}));
using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA; // Scale Factor tensors have an interleaved layout. Bring Layout instead of stride.
using StrideB = typename Gemm::GemmKernel::StrideB;
using LayoutB = decltype(cute::make_layout(make_shape(0,0,0), StrideB{}));
using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB; // Scale Factor tensors have an interleaved layout. Bring Layout instead of stride.
using StrideC = typename Gemm::GemmKernel::StrideC;
using LayoutC = decltype(cute::make_layout(make_shape(0,0,0), StrideC{}));
using StrideD = typename Gemm::GemmKernel::StrideD;
using LayoutD = decltype(cute::make_layout(make_shape(0,0,0), StrideD{}));
};
// Populates a Gemm::Arguments structure from the given commandline options
typename Mxfp4GemmSm120::Gemm::Arguments args_from_options_mxp4_mxfp4(
at::Tensor& D,
at::Tensor const& A,
at::Tensor const& B,
at::Tensor const& A_sf,
at::Tensor const& B_sf,
at::Tensor const& alpha,
c10::optional<torch::Tensor> const& bias,
int64_t M,
int64_t N,
int64_t K) {
using Sm1xxBlkScaledConfig = typename Mxfp4GemmSm120::Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;
int m = static_cast<int>(M);
int n = static_cast<int>(N);
int k = static_cast<int>(K);
auto stride_A = cutlass::make_cute_packed_stride(Mxfp4GemmSm120::StrideA{}, {m, k, 1});
auto stride_B = cutlass::make_cute_packed_stride(Mxfp4GemmSm120::StrideB{}, {n, k, 1});
auto stride_D = cutlass::make_cute_packed_stride(Mxfp4GemmSm120::StrideD{}, {m, n, 1});
auto layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(m, n, k, 1));
auto layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(m, n, k, 1));
if (bias){
using StrideBias = Stride<cutlass::_0, cutlass::_1, int64_t>;
typename Mxfp4GemmSm120::Gemm::Arguments arguments{
cutlass::gemm::GemmUniversalMode::kGemm,
{m, n, k, 1},
{// Mainloop arguments
static_cast<Mxfp4GemmSm120::Gemm::ElementA const*>(A.data_ptr()),
stride_A,
static_cast<Mxfp4GemmSm120::Gemm::ElementB const*>(B.data_ptr()),
stride_B,
static_cast<cutlass::float_ue8m0_t const*>(A_sf.data_ptr()),
layout_SFA,
static_cast<cutlass::float_ue8m0_t const*>(B_sf.data_ptr()),
layout_SFB},
{ // Epilogue arguments
{}, // epilogue.thread
static_cast<Mxfp4GemmSm120::Gemm::ElementC const*>(D.data_ptr()),
stride_D,
static_cast<Mxfp4GemmSm120::Gemm::ElementD*>(D.data_ptr()),
stride_D}};
auto& fusion_args = arguments.epilogue.thread;
fusion_args.alpha_ptr = static_cast<float const*>(alpha.data_ptr());
static const float beta_zero = 0.0f;
fusion_args.beta_ptr = &beta_zero;
fusion_args.bias_ptr = static_cast<Mxfp4GemmSm120::Gemm::ElementC const*>(bias->data_ptr());
fusion_args.dBias = StrideBias{};
return arguments;
} else {
typename Mxfp4GemmSm120::Gemm::Arguments arguments{
cutlass::gemm::GemmUniversalMode::kGemm,
{m, n, k, 1},
{// Mainloop arguments
static_cast<Mxfp4GemmSm120::Gemm::ElementA const*>(A.data_ptr()),
stride_A,
static_cast<Mxfp4GemmSm120::Gemm::ElementB const*>(B.data_ptr()),
stride_B,
static_cast<cutlass::float_ue8m0_t const*>(A_sf.data_ptr()),
layout_SFA,
static_cast<cutlass::float_ue8m0_t const*>(B_sf.data_ptr()),
layout_SFB},
{ // Epilogue arguments
{}, // epilogue.thread
static_cast<Mxfp4GemmSm120::Gemm::ElementC const*>(D.data_ptr()),
stride_D,
static_cast<Mxfp4GemmSm120::Gemm::ElementD*>(D.data_ptr()),
stride_D}};
auto& fusion_args = arguments.epilogue.thread;
fusion_args.alpha_ptr = static_cast<float const*>(alpha.data_ptr());
static const float beta_zero = 0.0f;
fusion_args.beta_ptr = &beta_zero;
return arguments;
}
}
void runGemmMxfp4Sm120(
at::Tensor& D,
at::Tensor const& A,
at::Tensor const& B,
at::Tensor const& A_sf,
at::Tensor const& B_sf,
at::Tensor const& alpha,
c10::optional<torch::Tensor> const& bias,
int64_t m,
int64_t n,
int64_t k,
cudaStream_t stream) {
typename Mxfp4GemmSm120::Gemm gemm;
auto arguments = args_from_options_mxp4_mxfp4(D, A, B, A_sf, B_sf, alpha, bias, m, n, k);
size_t workspace_size = Mxfp4GemmSm120::Gemm::get_workspace_size(arguments);
auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(A.device());
auto workspace = torch::empty(workspace_size, workspace_options);
CUTLASS_CHECK(gemm.can_implement(arguments));
CUTLASS_CHECK(gemm.initialize(arguments, workspace.data_ptr(), stream));
CUTLASS_CHECK(gemm.run(arguments, workspace.data_ptr(), stream));
}
constexpr auto FLOAT4_E2M1X2 = at::ScalarType::Byte;
constexpr auto SF_DTYPE = at::ScalarType::Float8_e8m0fnu;
void cutlass_scaled_mxfp4_mm_sm120(
torch::Tensor& D,
torch::Tensor const& A,
torch::Tensor const& B,
torch::Tensor const& A_sf,
torch::Tensor const& B_sf,
torch::Tensor const& alpha,
c10::optional<torch::Tensor> const& bias) {
CHECK_INPUT(A, FLOAT4_E2M1X2, "a");
CHECK_INPUT(B, FLOAT4_E2M1X2, "b");
CHECK_INPUT(A_sf, SF_DTYPE, "scale_a");
CHECK_INPUT(B_sf, SF_DTYPE, "scale_b");
CHECK_INPUT(alpha, at::ScalarType::Float, "alpha");
TORCH_CHECK(A.dim() == 2, "a must be a matrix");
TORCH_CHECK(B.dim() == 2, "b must be a matrix");
TORCH_CHECK(
A.sizes()[1] == B.sizes()[1],
"a and b shapes cannot be multiplied (",
A.sizes()[0],
"x",
A.sizes()[1],
" and ",
B.sizes()[0],
"x",
B.sizes()[1],
")");
auto const m = A.sizes()[0];
auto const n = B.sizes()[0];
auto const k = A.sizes()[1] * 2;
constexpr int alignment = 128;
TORCH_CHECK(
k % alignment == 0,
"Expected k to be divisible by ",
alignment,
", but got a shape: (",
A.sizes()[0],
"x",
A.sizes()[1],
"), k: ",
k,
".");
TORCH_CHECK(
n % alignment == 0,
"Expected n to be divisible by ",
alignment,
", but got b shape: (",
B.sizes()[0],
"x",
B.sizes()[1],
").");
auto round_up = [](int x, int y) { return (x + y - 1) / y * y; };
int rounded_m = round_up(m, 128);
int rounded_n = round_up(n, 128);
// Since k is divisible by 128 (alignment), k / 32 is guaranteed to be an
// integer.
int rounded_k = round_up(k / 32, 4);
TORCH_CHECK(A_sf.dim() == 2, "scale_a must be a matrix");
TORCH_CHECK(B_sf.dim() == 2, "scale_b must be a matrix");
TORCH_CHECK(
A_sf.sizes()[1] == B_sf.sizes()[1],
"scale_a and scale_b shapes cannot be multiplied (",
A_sf.sizes()[0],
"x",
A_sf.sizes()[1],
" and ",
B_sf.sizes()[0],
"x",
B_sf.sizes()[1],
")");
TORCH_CHECK(
A_sf.sizes()[0] == rounded_m && A_sf.sizes()[1] == rounded_k,
"scale_a must be padded and swizzled to a shape (",
rounded_m,
"x",
rounded_k,
"), but got a shape (",
A_sf.sizes()[0],
"x",
A_sf.sizes()[1],
")");
TORCH_CHECK(
B_sf.sizes()[0] == rounded_n && B_sf.sizes()[1] == rounded_k,
"scale_b must be padded and swizzled to a shape (",
rounded_n,
"x",
rounded_k,
"), but got a shape (",
B_sf.sizes()[0],
"x",
B_sf.sizes()[1],
")");
auto out_dtype = D.dtype();
at::cuda::CUDAGuard device_guard{(char)A.get_device()};
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(A.get_device());
runGemmMxfp4Sm120(D, A, B, A_sf, B_sf, alpha, bias, m, n, k, stream);
}
......@@ -315,7 +315,7 @@ inline int getMultiProcessorCount() {
return multi_processor_count; // Return the cached value on subsequent calls
}
void scaled_fp6_quant_sm120(
void scaled_mxfp6_quant_sm120(
torch::Tensor& output, torch::Tensor const& input, torch::Tensor& output_sf) {
int32_t m = input.size(0);
int32_t n = input.size(1);
......
......@@ -282,7 +282,7 @@ inline int getMultiProcessorCount() {
return multi_processor_count; // Return the cached value on subsequent calls
}
void scaled_fp8_quant_sm120(
void scaled_mxfp8_quant_sm120(
torch::Tensor& output, torch::Tensor const& input, torch::Tensor& output_sf) {
int32_t m = input.size(0);
int32_t n = input.size(1);
......
......@@ -350,7 +350,7 @@ inline int getMultiProcessorCount() {
return multi_processor_count; // Return the cached value on subsequent calls
}
void scaled_fp4_quant_sm120(
void scaled_nvfp4_quant_sm120(
torch::Tensor& output, torch::Tensor const& input, torch::Tensor& output_sf, torch::Tensor const& input_sf) {
int32_t m = input.size(0);
int32_t n = input.size(1);
......
......@@ -215,7 +215,7 @@ void runGemmNvfp4Sm120(
constexpr auto FLOAT4_E2M1X2 = at::ScalarType::Byte;
constexpr auto SF_DTYPE = at::ScalarType::Float8_e4m3fn;
void cutlass_scaled_fp4_mm_sm120(
void cutlass_scaled_nvfp4_mm_sm120(
torch::Tensor& D,
torch::Tensor const& A,
torch::Tensor const& B,
......
......@@ -42,8 +42,19 @@ limitations under the License.
/*
* From csrc/gemm
*/
void scaled_nvfp4_quant_sm120(
torch::Tensor& output, torch::Tensor const& input, torch::Tensor& output_sf, torch::Tensor const& input_sf);
void scaled_mxfp4_quant_sm120(
torch::Tensor& output, torch::Tensor const& input, torch::Tensor& output_sf);
void cutlass_scaled_fp4_mm_sm120(
void scaled_mxfp6_quant_sm120(
torch::Tensor& output, torch::Tensor const& input, torch::Tensor& output_sf);
void scaled_mxfp8_quant_sm120(
torch::Tensor& output, torch::Tensor const& input, torch::Tensor& output_sf);
void cutlass_scaled_nvfp4_mm_sm120(
torch::Tensor& D,
torch::Tensor const& A,
torch::Tensor const& B,
......@@ -52,16 +63,14 @@ void cutlass_scaled_fp4_mm_sm120(
torch::Tensor const& alpha,
c10::optional<torch::Tensor> const& bias);
void scaled_fp4_quant_sm120(
torch::Tensor& output, torch::Tensor const& input, torch::Tensor& output_sf, torch::Tensor const& input_sf);
void scaled_fp8_quant_sm120(
torch::Tensor& output, torch::Tensor const& input, torch::Tensor& output_sf);
void scaled_fp6_quant_sm120(
torch::Tensor& output, torch::Tensor const& input, torch::Tensor& output_sf);
void cutlass_scaled_mxfp4_mm_sm120(
torch::Tensor& D,
torch::Tensor const& A,
torch::Tensor const& B,
torch::Tensor const& A_sf,
torch::Tensor const& B_sf,
torch::Tensor const& alpha,
c10::optional<torch::Tensor> const& bias);
void cutlass_scaled_mxfp6_mxfp8_mm_sm120(
torch::Tensor& D,
......
import torch
def cutlass_scaled_fp4_mm(mat_a, mat_b, scales_a, scales_b, alpha, bias=None):
def cutlass_scaled_nvfp4_mm(mat_a, mat_b, scales_a, scales_b, alpha, bias=None):
m, n = mat_a.shape[0], mat_b.shape[0]
out = torch.empty((m, n), dtype=torch.bfloat16, device=mat_a.device)
torch.ops.lightx2v_kernel.cutlass_scaled_fp4_mm_sm120.default(out, mat_a, mat_b, scales_a, scales_b, alpha, bias)
torch.ops.lightx2v_kernel.cutlass_scaled_nvfp4_mm_sm120.default(out, mat_a, mat_b, scales_a, scales_b, alpha, bias)
return out
def scaled_fp4_quant(input: torch.Tensor, input_global_scale: torch.Tensor):
def scaled_nvfp4_quant(input: torch.Tensor, input_global_scale: torch.Tensor):
"""
Quantize input tensor to FP4 and return quantized tensor and scale.
......@@ -50,12 +50,25 @@ def scaled_fp4_quant(input: torch.Tensor, input_global_scale: torch.Tensor):
# rounded_n = ((scale_n + 4 - 1) // 4) * 4
output_scale = torch.zeros((((m + 128 - 1) // 128) * 128, (n // block_size + 4 - 1) // 4), device=device, dtype=torch.int32)
torch.ops.lightx2v_kernel.scaled_fp4_quant_sm120.default(output, input, output_scale, input_global_scale)
torch.ops.lightx2v_kernel.scaled_nvfp4_quant_sm120.default(output, input, output_scale, input_global_scale)
output_scale = output_scale.view(torch.float8_e4m3fn)
return output, output_scale
def scaled_fp6_quant(input: torch.Tensor):
def scaled_mxfp4_quant(input: torch.Tensor):
m, n = input.shape
block_size = 32
device = input.device
output = torch.empty((m, n // 2), device=device, dtype=torch.uint8)
output_scale = torch.zeros(((m + 128 - 1) // 128 * 128, (n // block_size + 4 - 1) // 4), device=device, dtype=torch.int32)
torch.ops.lightx2v_kernel.scaled_mxfp4_quant_sm120.default(output, input, output_scale)
output_scale = output_scale.view(torch.float8_e8m0fnu)
return output, output_scale
def scaled_mxfp6_quant(input: torch.Tensor):
m, n = input.shape
block_size = 32
device = input.device
......@@ -63,12 +76,12 @@ def scaled_fp6_quant(input: torch.Tensor):
output = torch.empty((m, 3 * n // 4), device=device, dtype=torch.uint8)
output_scale = torch.zeros(((m + 128 - 1) // 128 * 128, (n // block_size + 4 - 1) // 4), device=device, dtype=torch.int32)
torch.ops.lightx2v_kernel.scaled_fp6_quant_sm120.default(output, input, output_scale)
torch.ops.lightx2v_kernel.scaled_mxfp6_quant_sm120.default(output, input, output_scale)
output_scale = output_scale.view(torch.float8_e8m0fnu)
return output, output_scale
def scaled_fp8_quant(input: torch.Tensor):
def scaled_mxfp8_quant(input: torch.Tensor):
m, n = input.shape
block_size = 32
device = input.device
......@@ -76,11 +89,18 @@ def scaled_fp8_quant(input: torch.Tensor):
output = torch.empty((m, n), device=device, dtype=torch.uint8)
output_scale = torch.empty(((m + 128 - 1) // 128 * 128, (n // block_size + 4 - 1) // 4), device=device, dtype=torch.int32)
torch.ops.lightx2v_kernel.scaled_fp8_quant_sm120.default(output, input, output_scale)
torch.ops.lightx2v_kernel.scaled_mxfp8_quant_sm120.default(output, input, output_scale)
output_scale = output_scale.view(torch.float8_e8m0fnu)
return output, output_scale
def cutlass_scaled_mxfp4_mm(mat_a, mat_b, scales_a, scales_b, alpha, bias=None):
m, n = mat_a.shape[0], mat_b.shape[0]
out = torch.empty((m, n), dtype=torch.bfloat16, device=mat_a.device)
torch.ops.lightx2v_kernel.cutlass_scaled_mxfp4_mm_sm120.default(out, mat_a, mat_b, scales_a, scales_b, alpha, bias)
return out
def cutlass_scaled_mxfp6_mxfp8_mm(mat_a, mat_b, scales_a, scales_b, alpha, bias=None):
m, n = mat_a.shape[0], mat_b.shape[0]
out = torch.empty((m, n), dtype=torch.bfloat16, device=mat_a.device)
......
import torch
from lightx2v_kernel.gemm import scaled_mxfp4_quant, cutlass_scaled_mxfp4_mm
import time
class MMWeightMxfp4ActMxfp4:
def __init__(self, weight, bias):
self.load_fp4_weight(weight, bias)
self.act_quant_func = self.act_quant_fp4
self.set_alpha()
@torch.no_grad()
def apply(self, input_tensor):
input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
output_tensor = cutlass_scaled_mxfp4_mm(input_tensor_quant, self.weight, input_tensor_scale, self.weight_scale, alpha=self.alpha, bias=self.bias)
return output_tensor
@torch.no_grad()
def load_fp4_weight(self, weight, bias):
self.weight, self.weight_scale = scaled_mxfp4_quant(weight)
self.bias = bias
def set_alpha(self):
self.alpha = torch.tensor(1.0, dtype=torch.float32, device=self.weight.device)
@torch.no_grad()
def act_quant_fp4(self, x):
return scaled_mxfp4_quant(x)
def test_speed(m, k, n):
with torch.no_grad():
input_tensor = torch.randn(m, k, dtype=torch.bfloat16).cuda()
weight = torch.randn(n, k, dtype=torch.bfloat16, device="cuda")
# bias = torch.randn(1, n, dtype=torch.bfloat16).cuda()
bias = None
mm = MMWeightMxfp4ActMxfp4(weight, bias)
# warmup
output_tensor = mm.apply(input_tensor)
torch.cuda.synchronize()
start_time = time.time()
for i in range(100):
output_tensor = mm.apply(input_tensor)
torch.cuda.synchronize()
end_time = time.time()
lightx2v_kernel_time = (end_time - start_time) / 100
print(f"lightx2v-kernel time: {lightx2v_kernel_time}")
input_tensor = torch.randn(m, n, dtype=torch.bfloat16).cuda()
weight = torch.randn(k, n, dtype=torch.bfloat16, device="cuda")
bias = torch.randn(1, k, dtype=torch.bfloat16).cuda()
linear = torch.nn.Linear(k, n, bias=False).cuda()
linear.weight.data = weight
# linear.bias.data = bias
# warmup
ref_output_tensor = linear(input_tensor)
torch.cuda.synchronize()
start_time = time.time()
for i in range(100):
ref_output_tensor = linear(input_tensor)
torch.cuda.synchronize()
end_time = time.time()
ref_time = (end_time - start_time) / 100
print(f"ref time: {ref_time}")
print(f"speedup: {ref_time / lightx2v_kernel_time:.3f}")
def test_accuracy(m, k, n):
with torch.no_grad():
input_tensor = torch.randn(m, k, dtype=torch.bfloat16).cuda()
weight = torch.randn(n, k, dtype=torch.bfloat16, device="cuda")
# bias = torch.randn(1, n, dtype=torch.bfloat16).cuda()
bias = None
linear = torch.nn.Linear(k, n, bias=False).cuda()
linear.weight.data = weight
# linear.bias.data = bias
ref_output_tensor = linear(input_tensor)
mm = MMWeightMxfp4ActMxfp4(weight, bias)
output_tensor = mm.apply(input_tensor)
# print(f"ref_output_tensor: {ref_output_tensor}")
# print(f"output_tensor: {output_tensor}")
# cosine
cos = torch.nn.functional.cosine_similarity(ref_output_tensor.flatten(), output_tensor.flatten(), dim=0)
print(f"cos : {cos}")
if __name__ == "__main__":
test_sizes = [
(32130, 5120, 5120),
(512, 5120, 5120),
(257, 5120, 5120),
(32130, 5120, 13824),
(32130, 13824, 5120),
(75348, 5120, 5120),
(75348, 13824, 5120),
(32760, 1536, 1536),
(512, 1536, 1536),
(32760, 1536, 8960),
(32760, 8960, 1536),
]
for i, (m, k, n) in enumerate(test_sizes):
print("-" * 30)
print(f"测试 {i + 1}: 张量大小 ({m}, {k}, {n})")
test_accuracy(m, k, n)
test_speed(m, k, n)
import torch
import time
from test_bench import MMWeightMxfp4ActMxfp4
def test_speed(m, k, n):
with torch.no_grad():
input_tensor = torch.randn(m, k, dtype=torch.bfloat16).cuda()
weight = torch.randn(n, k, dtype=torch.bfloat16, device="cuda")
bias = torch.ones(1, n, dtype=torch.bfloat16).cuda() * 50
mm = MMWeightMxfp4ActMxfp4(weight, bias)
# warmup
output_tensor = mm.apply(input_tensor)
torch.cuda.synchronize()
start_time = time.time()
for i in range(100):
output_tensor = mm.apply(input_tensor)
torch.cuda.synchronize()
end_time = time.time()
lightx2v_kernel_time = (end_time - start_time) / 100
print(f"lightx2v-kernel time: {lightx2v_kernel_time}")
input_tensor = torch.randn(m, n, dtype=torch.bfloat16).cuda()
weight = torch.randn(k, n, dtype=torch.bfloat16, device="cuda")
bias = torch.randn(1, k, dtype=torch.bfloat16).cuda()
linear = torch.nn.Linear(k, n, bias=True).cuda()
linear.weight.data = weight
linear.bias.data = bias
# warmup
ref_output_tensor = linear(input_tensor)
torch.cuda.synchronize()
start_time = time.time()
for i in range(100):
ref_output_tensor = linear(input_tensor)
torch.cuda.synchronize()
end_time = time.time()
ref_time = (end_time - start_time) / 100
print(f"ref time: {ref_time}")
print(f"speedup: {ref_time / lightx2v_kernel_time:.3f}")
def test_accuracy(m, k, n):
with torch.no_grad():
input_tensor = torch.randn(m, k, dtype=torch.bfloat16).cuda()
weight = torch.randn(n, k, dtype=torch.bfloat16, device="cuda")
bias = torch.ones(1, n, dtype=torch.bfloat16).cuda() * 50
linear = torch.nn.Linear(k, n, bias=True).cuda()
linear.weight.data = weight
linear.bias.data = bias
ref_output_tensor = linear(input_tensor)
mm = MMWeightMxfp4ActMxfp4(weight, bias)
output_tensor = mm.apply(input_tensor)
# print(f"ref_output_tensor: {ref_output_tensor}")
# print(f"output_tensor: {output_tensor}")
# cosine
cos = torch.nn.functional.cosine_similarity(ref_output_tensor.flatten(), output_tensor.flatten(), dim=0)
print(f"cos : {cos}")
if __name__ == "__main__":
test_sizes = [
(32130, 5120, 5120),
(512, 5120, 5120),
(257, 5120, 5120),
(32130, 5120, 13824),
(32130, 13824, 5120),
(75348, 5120, 5120),
(75348, 13824, 5120),
(32760, 1536, 1536),
(512, 1536, 1536),
(32760, 1536, 8960),
(32760, 8960, 1536),
]
for i, (m, k, n) in enumerate(test_sizes):
print("-" * 30)
print(f"测试 {i + 1}: 张量大小 ({m}, {k}, {n})")
test_accuracy(m, k, n)
test_speed(m, k, n)
import unittest
import torch
from lightx2v_kernel.gemm import cutlass_scaled_mxfp4_mm
from lightx2v_kernel.gemm import scaled_mxfp4_quant
from torch.nn.functional import linear
from lightx2v_kernel.utils import error, benchmark
class TestQuantBF162MXFP4(unittest.TestCase):
def setUp(self):
self.tokens = [128, 257, 512, 1024, 13325, 32130, 32760] # , 75348
self.channels = [128, 1536, 5120, 8960] # , 13824
self.hiddenDims = [128, 1536, 3072, 5120, 8960, 12800] # , 13824
self.device = "cuda"
self.dtype = torch.bfloat16
def test_accuracy(self):
"""Test the accuracy of quantization from BF16 to MXFP4."""
for m in self.tokens:
for k in self.hiddenDims:
for n in self.channels:
with self.subTest(shape=[m, k, n]):
activation = torch.randn(m, k, dtype=self.dtype, device=self.device)
activation_quant_pred, activation_scale_pred = scaled_mxfp4_quant(activation)
weight = torch.randn(n, k, dtype=self.dtype, device=self.device)
weight_quant_pred, weight_scale_pred = scaled_mxfp4_quant(weight)
bias = torch.rand(1, n, dtype=self.dtype, device=self.device) * 10
alpha = torch.tensor(1.0, device=self.device, dtype=torch.float32)
mm_pred = cutlass_scaled_mxfp4_mm(activation_quant_pred, weight_quant_pred, activation_scale_pred, weight_scale_pred, alpha=alpha, bias=bias)
mm_real = linear(activation, weight, bias=bias).to(torch.bfloat16)
# mxfp4_mxfp4 mm have very low accuracy, so we set the threshold to 3e-2.
self.assertTrue(error(mm_pred, mm_real) < 3e-2, f"Accuracy test failed for shape {m, k, n}: Error {error(mm_pred, mm_real)} exceeds threshold.")
def test_performance(self):
"""Benchmark the performance of Activation quantization from BF16 to MXFP4."""
for m in self.tokens:
for k in self.hiddenDims:
with self.subTest(shape=[m, k]):
input = torch.randn(m, k, dtype=self.dtype, device=self.device)
shape = [m, k]
tflops = 2 * (m * k / 1024**4)
benchmark(scaled_mxfp4_quant, shape, tflops, 100, input)
if __name__ == "__main__":
unittest.main()
import torch
from lightx2v_kernel.gemm import scaled_fp8_quant, scaled_fp6_quant, cutlass_scaled_mxfp6_mxfp8_mm
from lightx2v_kernel.gemm import scaled_mxfp8_quant, scaled_mxfp6_quant, cutlass_scaled_mxfp6_mxfp8_mm
import time
class MMWeightMxfp8ActMxfp6:
class MMWeightMxfp6ActMxfp8:
def __init__(self, weight, bias):
self.load_fp6_weight(weight, bias)
self.act_quant_func = self.act_quant_fp8
......@@ -17,7 +17,7 @@ class MMWeightMxfp8ActMxfp6:
@torch.no_grad()
def load_fp6_weight(self, weight, bias):
self.weight, self.weight_scale = scaled_fp6_quant(weight)
self.weight, self.weight_scale = scaled_mxfp6_quant(weight)
self.bias = bias
def set_alpha(self):
......@@ -25,7 +25,7 @@ class MMWeightMxfp8ActMxfp6:
@torch.no_grad()
def act_quant_fp8(self, x):
return scaled_fp8_quant(x)
return scaled_mxfp8_quant(x)
def test_speed(m, k, n):
......@@ -35,7 +35,7 @@ def test_speed(m, k, n):
# bias = torch.randn(1, n, dtype=torch.bfloat16).cuda()
bias = None
mm = MMWeightMxfp8ActMxfp6(weight, bias)
mm = MMWeightMxfp6ActMxfp8(weight, bias)
# warmup
output_tensor = mm.apply(input_tensor)
......@@ -87,7 +87,7 @@ def test_accuracy(m, k, n):
ref_output_tensor = linear(input_tensor)
mm = MMWeightMxfp8ActMxfp6(weight, bias)
mm = MMWeightMxfp6ActMxfp8(weight, bias)
output_tensor = mm.apply(input_tensor)
......
import torch
import time
from test_bench import MMWeightMxfp8ActMxfp6
from test_bench import MMWeightMxfp6ActMxfp8
def test_speed(m, k, n):
......@@ -9,7 +9,7 @@ def test_speed(m, k, n):
weight = torch.randn(n, k, dtype=torch.bfloat16, device="cuda")
bias = torch.randn(1, n, dtype=torch.bfloat16).cuda()
mm = MMWeightMxfp8ActMxfp6(weight, bias)
mm = MMWeightMxfp6ActMxfp8(weight, bias)
# warmup
output_tensor = mm.apply(input_tensor)
......@@ -60,7 +60,7 @@ def test_accuracy(m, k, n):
ref_output_tensor = linear(input_tensor)
mm = MMWeightMxfp8ActMxfp6(weight, bias)
mm = MMWeightMxfp6ActMxfp8(weight, bias)
output_tensor = mm.apply(input_tensor)
......
import unittest
import torch
from lightx2v_kernel.gemm import cutlass_scaled_mxfp6_mxfp8_mm
from lightx2v_kernel.gemm import scaled_fp6_quant, scaled_fp8_quant
from lightx2v_kernel.gemm import scaled_mxfp6_quant, scaled_mxfp8_quant
from torch.nn.functional import linear
from lightx2v_kernel.utils import error, benchmark
......@@ -22,10 +22,10 @@ class TestQuantBF162MXFP6(unittest.TestCase):
for n in self.channels:
with self.subTest(shape=[m, k, n]):
activation = torch.randn(m, k, dtype=self.dtype, device=self.device)
activation_quant_pred, activation_scale_pred = scaled_fp8_quant(activation)
activation_quant_pred, activation_scale_pred = scaled_mxfp8_quant(activation)
weight = torch.randn(n, k, dtype=self.dtype, device=self.device)
weight_quant_pred, weight_scale_pred = scaled_fp6_quant(weight)
weight_quant_pred, weight_scale_pred = scaled_mxfp6_quant(weight)
bias = torch.rand(1, n, dtype=self.dtype, device=self.device) * 10
......@@ -44,7 +44,7 @@ class TestQuantBF162MXFP6(unittest.TestCase):
input = torch.randn(m, k, dtype=self.dtype, device=self.device)
shape = [m, k]
tflops = 2 * (m * k / 1024**4)
benchmark(scaled_fp6_quant, shape, tflops, 100, input)
benchmark(scaled_mxfp6_quant, shape, tflops, 100, input)
if __name__ == "__main__":
......
import torch
from lightx2v_kernel.gemm import scaled_fp6_quant
from lightx2v_kernel.gemm import scaled_mxfp6_quant
def quantize_fp6(x):
return scaled_fp6_quant(x)
return scaled_mxfp6_quant(x)
def test_memory_bandwidth(func, x, num_warmup=10, num_runs=100):
......
import torch
from lightx2v_kernel.gemm import scaled_fp8_quant, cutlass_scaled_mxfp8_mm
from lightx2v_kernel.gemm import scaled_mxfp8_quant, cutlass_scaled_mxfp8_mm
import time
......@@ -17,7 +17,7 @@ class MMWeightMxfp8:
@torch.no_grad()
def load_fp8_weight(self, weight, bias):
self.weight, self.weight_scale = scaled_fp8_quant(weight)
self.weight, self.weight_scale = scaled_mxfp8_quant(weight)
self.bias = bias
def set_alpha(self):
......@@ -25,7 +25,7 @@ class MMWeightMxfp8:
@torch.no_grad()
def act_quant_fp8(self, x):
return scaled_fp8_quant(x)
return scaled_mxfp8_quant(x)
def test_speed(m, k, n):
......
import unittest
import torch
from lightx2v_kernel.gemm import cutlass_scaled_mxfp8_mm
from lightx2v_kernel.gemm import scaled_fp8_quant
from lightx2v_kernel.gemm import scaled_mxfp8_quant
from torch.nn.functional import linear
from lightx2v_kernel.utils import error, benchmark
......@@ -22,10 +22,10 @@ class TestQuantBF162MXFP8(unittest.TestCase):
for n in self.channels:
with self.subTest(shape=[m, k, n]):
activation = torch.randn(m, k, dtype=self.dtype, device=self.device)
activation_quant_pred, activation_scale_pred = scaled_fp8_quant(activation)
activation_quant_pred, activation_scale_pred = scaled_mxfp8_quant(activation)
weight = torch.randn(n, k, dtype=self.dtype, device=self.device)
weight_quant_pred, weight_scale_pred = scaled_fp8_quant(weight)
weight_quant_pred, weight_scale_pred = scaled_mxfp8_quant(weight)
bias = torch.rand(1, n, dtype=self.dtype, device=self.device) * 10
......@@ -44,7 +44,7 @@ class TestQuantBF162MXFP8(unittest.TestCase):
input = torch.randn(m, k, dtype=self.dtype, device=self.device)
shape = [m, k]
tflops = 2 * (m * k / 1024**4)
benchmark(scaled_fp8_quant, shape, tflops, 100, input)
benchmark(scaled_mxfp8_quant, shape, tflops, 100, input)
if __name__ == "__main__":
......
import torch
from lightx2v_kernel.gemm import scaled_fp8_quant
from lightx2v_kernel.gemm import scaled_mxfp8_quant
def quantize_fp8(x):
return scaled_fp8_quant(x)
return scaled_mxfp8_quant(x)
def test_memory_bandwidth(func, x, num_warmup=10, num_runs=100):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment