Commit 3e4fe79b authored by GoatWu

Merge branch 'main' of github.com:ModelTC/lightx2v into main

parents 8ddd33a5 d013cac7
@@ -83,9 +83,6 @@ class EulerSchedulerTimestepFix(BaseScheduler):
         self.sigmas = self.sigmas.to("cpu")

     def prepare(self, image_encoder_output=None):
-        self.generator = torch.Generator(device=self.device)
-        self.generator.manual_seed(self.config.seed)
         self.prepare_latents(self.config.target_shape, dtype=torch.float32)
         if self.config.task in ["t2v"]:
@@ -113,6 +110,7 @@ class EulerSchedulerTimestepFix(BaseScheduler):
         self.set_timesteps(self.infer_steps, device=self.device, shift=self.sample_shift)

     def prepare_latents(self, target_shape, dtype=torch.float32):
+        self.generator = torch.Generator(device=self.device).manual_seed(self.config.seed)
         self.latents = (
             torch.randn(
                 target_shape[0],
...
@@ -5,21 +5,30 @@ from lightx2v.models.schedulers.wan.scheduler import WanScheduler
 class WanScheduler4ChangingResolution(WanScheduler):
     def __init__(self, config):
         super().__init__(config)
-        self.resolution_rate = config.get("resolution_rate", 0.75)
-        self.changing_resolution_steps = config.get("changing_resolution_steps", config.infer_steps // 2)
+        if "resolution_rate" not in config:
+            config["resolution_rate"] = [0.75]
+        if "changing_resolution_steps" not in config:
+            config["changing_resolution_steps"] = [config.infer_steps // 2]
+        assert len(config["resolution_rate"]) == len(config["changing_resolution_steps"])

     def prepare_latents(self, target_shape, dtype=torch.float32):
-        self.latents = torch.randn(
-            target_shape[0],
-            target_shape[1],
-            int(target_shape[2] * self.resolution_rate) // 2 * 2,
-            int(target_shape[3] * self.resolution_rate) // 2 * 2,
-            dtype=dtype,
-            device=self.device,
-            generator=self.generator,
-        )
-        self.noise_original_resolution = torch.randn(
-            target_shape[0],
-            target_shape[1],
-            target_shape[2],
+        self.latents_list = []
+        for i in range(len(self.config["resolution_rate"])):
+            self.latents_list.append(
+                torch.randn(
+                    target_shape[0],
+                    target_shape[1],
+                    int(target_shape[2] * self.config["resolution_rate"][i]) // 2 * 2,
+                    int(target_shape[3] * self.config["resolution_rate"][i]) // 2 * 2,
+                    dtype=dtype,
+                    device=self.device,
+                    generator=self.generator,
+                )
+            )
+        # add original resolution latents
+        self.latents_list.append(
+            torch.randn(
+                target_shape[0],
+                target_shape[1],
+                target_shape[2],
@@ -28,10 +37,16 @@ class WanScheduler4ChangingResolution(WanScheduler):
                 device=self.device,
                 generator=self.generator,
             )
+        )
+        # set initial latents
+        self.latents = self.latents_list[0]
+        self.changing_resolution_index = 0

     def step_post(self):
-        if self.step_index == self.changing_resolution_steps:
+        if self.step_index + 1 in self.config["changing_resolution_steps"]:
             self.step_post_upsample()
+            self.changing_resolution_index += 1
         else:
             super().step_post()
@@ -45,19 +60,21 @@ class WanScheduler4ChangingResolution(WanScheduler):
         # 2. upsample clean noise to target shape
         denoised_sample_5d = denoised_sample.unsqueeze(0)  # (C,T,H,W) -> (1,C,T,H,W)
-        clean_noise = torch.nn.functional.interpolate(denoised_sample_5d, size=(self.config.target_shape[1], self.config.target_shape[2], self.config.target_shape[3]), mode="trilinear")
+        shape_to_upsampled = self.latents_list[self.changing_resolution_index + 1].shape[1:]
+        clean_noise = torch.nn.functional.interpolate(denoised_sample_5d, size=shape_to_upsampled, mode="trilinear")
         clean_noise = clean_noise.squeeze(0)  # (1,C,T,H,W) -> (C,T,H,W)

         # 3. add noise to clean noise
-        noisy_sample = self.add_noise(clean_noise, self.noise_original_resolution, self.timesteps[self.step_index + 1])
+        noisy_sample = self.add_noise(clean_noise, self.latents_list[self.changing_resolution_index + 1], self.timesteps[self.step_index + 1])

         # 4. update latents
         self.latents = noisy_sample
         # self.disable_corrector = [24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37]  # maybe not needed

-        # 5. update timesteps using shift + 2 (more aggressive denoising)
-        self.set_timesteps(self.infer_steps, device=self.device, shift=self.sample_shift + 2)
+        # 5. update timesteps using shift + self.changing_resolution_index + 1 (more aggressive denoising)
+        self.set_timesteps(self.infer_steps, device=self.device, shift=self.sample_shift + self.changing_resolution_index + 1)

     def add_noise(self, original_samples, noise, timesteps):
         sigma = self.sigmas[self.step_index]
...
@@ -27,11 +27,6 @@ class WanScheduler(BaseScheduler):
         self.prepare_latents(self.config.target_shape, dtype=torch.float32)

-        if self.config.task in ["t2v"]:
-            self.seq_len = math.ceil((self.config.target_shape[2] * self.config.target_shape[3]) / (self.config.patch_size[1] * self.config.patch_size[2]) * self.config.target_shape[1])
-        elif self.config.task in ["i2v"]:
-            self.seq_len = ((self.config.target_video_length - 1) // self.config.vae_stride[0] + 1) * self.config.lat_h * self.config.lat_w // (self.config.patch_size[1] * self.config.patch_size[2])
-
         alphas = np.linspace(1, 1 / self.num_train_timesteps, self.num_train_timesteps)[::-1].copy()
         sigmas = 1.0 - alphas
         sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32)
...
@@ -93,8 +93,10 @@ list(APPEND LIGHTX2V_KERNEL_CUDA_FLAGS
 set(SOURCES
     "csrc/gemm/nvfp4_scaled_mm_kernels_sm120.cu"
     "csrc/gemm/nvfp4_quant_kernels_sm120.cu"
+    "csrc/gemm/mxfp4_quant_kernels_sm120.cu"
     "csrc/gemm/mxfp8_quant_kernels_sm120.cu"
     "csrc/gemm/mxfp6_quant_kernels_sm120.cu"
+    "csrc/gemm/mxfp4_scaled_mm_kernels_sm120.cu"
     "csrc/gemm/mxfp6_mxfp8_scaled_mm_kernels_sm120.cu"
     "csrc/gemm/mxfp8_scaled_mm_kernels_sm120.cu"
     "csrc/common_extension.cc"
...
@@ -7,23 +7,34 @@
 TORCH_LIBRARY_FRAGMENT(lightx2v_kernel, m) {
   m.def(
-      "cutlass_scaled_fp4_mm_sm120(Tensor! out, Tensor mat_a, Tensor mat_b, Tensor scales_a, Tensor scales_b, Tensor "
+      "cutlass_scaled_nvfp4_mm_sm120(Tensor! out, Tensor mat_a, Tensor mat_b, Tensor scales_a, Tensor scales_b, Tensor "
       "alpha, Tensor? bias) -> ()");
-  m.impl("cutlass_scaled_fp4_mm_sm120", torch::kCUDA, &cutlass_scaled_fp4_mm_sm120);
+  m.impl("cutlass_scaled_nvfp4_mm_sm120", torch::kCUDA, &cutlass_scaled_nvfp4_mm_sm120);

   m.def(
-      "scaled_fp4_quant_sm120(Tensor! output, Tensor! input,"
+      "scaled_nvfp4_quant_sm120(Tensor! output, Tensor! input,"
       " Tensor! output_scale, Tensor! input_scale) -> ()");
-  m.impl("scaled_fp4_quant_sm120", torch::kCUDA, &scaled_fp4_quant_sm120);
+  m.impl("scaled_nvfp4_quant_sm120", torch::kCUDA, &scaled_nvfp4_quant_sm120);

   m.def(
-      "scaled_fp8_quant_sm120(Tensor! output, Tensor! input,"
+      "scaled_mxfp4_quant_sm120(Tensor! output, Tensor! input,"
       " Tensor! output_scale) -> ()");
-  m.impl("scaled_fp8_quant_sm120", torch::kCUDA, &scaled_fp8_quant_sm120);
+  m.impl("scaled_mxfp4_quant_sm120", torch::kCUDA, &scaled_mxfp4_quant_sm120);
+
+  m.def(
+      "scaled_mxfp8_quant_sm120(Tensor! output, Tensor! input,"
+      " Tensor! output_scale) -> ()");
+  m.impl("scaled_mxfp8_quant_sm120", torch::kCUDA, &scaled_mxfp8_quant_sm120);

   m.def(
-      "scaled_fp6_quant_sm120(Tensor! output, Tensor! input,"
+      "scaled_mxfp6_quant_sm120(Tensor! output, Tensor! input,"
       " Tensor! output_scale) -> ()");
-  m.impl("scaled_fp6_quant_sm120", torch::kCUDA, &scaled_fp6_quant_sm120);
+  m.impl("scaled_mxfp6_quant_sm120", torch::kCUDA, &scaled_mxfp6_quant_sm120);
+
+  m.def(
+      "cutlass_scaled_mxfp4_mm_sm120(Tensor! out, Tensor mat_a, Tensor mat_b, Tensor scales_a, Tensor scales_b, Tensor "
+      "alpha, Tensor? bias) -> ()");
+  m.impl("cutlass_scaled_mxfp4_mm_sm120", torch::kCUDA, &cutlass_scaled_mxfp4_mm_sm120);

   m.def(
       "cutlass_scaled_mxfp6_mxfp8_mm_sm120(Tensor! out, Tensor mat_a, Tensor mat_b, Tensor scales_a, Tensor scales_b, Tensor "
...
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_fp8.h>
#include <cuda_runtime.h>
#include <cuda_runtime_api.h>
#include <torch/all.h>
#include "utils.h"
// Get type2 from type or vice versa (applied to half and bfloat16)
template <typename T>
struct TypeConverter {
using Type = half2;
}; // keep for generality
template <>
struct TypeConverter<half2> {
using Type = half;
};
template <>
struct TypeConverter<half> {
using Type = half2;
};
template <>
struct TypeConverter<__nv_bfloat162> {
using Type = __nv_bfloat16;
};
template <>
struct TypeConverter<__nv_bfloat16> {
using Type = __nv_bfloat162;
};
#define ELTS_PER_THREAD 8
constexpr int CVT_FP4_ELTS_PER_THREAD = 8;
constexpr int CVT_FP4_SF_VEC_SIZE = 32;
// Convert 4 float2 values into 8 e2m1 values (represented as one uint32_t).
inline __device__ uint32_t fp32_vec_to_e2m1(float2 (&array)[4]) {
// The PTX instructions used here require sm100a.
// #if CUDA_VERSION >= 12080
// #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) && __CUDA_ARCH_HAS_FEATURE__(SM100_ALL)
uint32_t val;
asm volatile(
"{\n"
".reg .b8 byte0;\n"
".reg .b8 byte1;\n"
".reg .b8 byte2;\n"
".reg .b8 byte3;\n"
"cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n"
"cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n"
"cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n"
"cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n"
"mov.b32 %0, {byte0, byte1, byte2, byte3};\n"
"}"
: "=r"(val)
: "f"(array[0].x),
"f"(array[0].y),
"f"(array[1].x),
"f"(array[1].y),
"f"(array[2].x),
"f"(array[2].y),
"f"(array[3].x),
"f"(array[3].y));
return val;
// #else
// return 0;
// #endif
// #endif
}
// Fast reciprocal.
inline __device__ float reciprocal_approximate_ftz(float a) {
float b;
asm volatile("rcp.approx.ftz.f32 %0, %1;\n" : "=f"(b) : "f"(a));
return b;
}
template <class SFType, int CVT_FP4_NUM_THREADS_PER_SF>
__device__ uint8_t* get_sf_out_address(int rowIdx, int colIdx, int numCols, SFType* SFout) {
// #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
static_assert(CVT_FP4_NUM_THREADS_PER_SF == 4);
// One pair of threads write one SF to global memory.
// TODO: stage through smem for packed STG.32
// is it better than STG.8 from 4 threads ?
if (threadIdx.x % CVT_FP4_NUM_THREADS_PER_SF == 0) {
// SF vector index (32 elements share one SF in the K dimension).
int32_t kIdx = colIdx / CVT_FP4_NUM_THREADS_PER_SF;
int32_t mIdx = rowIdx;
// SF layout [numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4(kTile)]
// --> index [mTileIdx, kTileIdx, outerMIdx, innerMIdx, innerKIdx]
int32_t mTileIdx = mIdx / (32 * 4);
// SF vector size 32.
int factor = CVT_FP4_SF_VEC_SIZE * 4;
int32_t numKTiles = (numCols + factor - 1) / factor;
int64_t mTileStride = numKTiles * 32 * 4 * 4;
int32_t kTileIdx = (kIdx / 4);
int64_t kTileStride = 32 * 4 * 4;
// M tile layout [32, 4] is column-major.
int32_t outerMIdx = (mIdx % 32);
int64_t outerMStride = 4 * 4;
int32_t innerMIdx = (mIdx % (32 * 4)) / 32;
int64_t innerMStride = 4;
int32_t innerKIdx = (kIdx % 4);
int64_t innerKStride = 1;
// Compute the global offset.
int64_t SFOffset = mTileIdx * mTileStride + kTileIdx * kTileStride + outerMIdx * outerMStride +
innerMIdx * innerMStride + innerKIdx * innerKStride;
return reinterpret_cast<uint8_t*>(SFout) + SFOffset;
}
// #endif
return nullptr;
}
// Define a 16 bytes packed data type.
template <class Type>
struct PackedVec {
typename TypeConverter<Type>::Type elts[4];
};
template <>
struct PackedVec<__nv_fp8_e4m3> {
__nv_fp8x2_e4m3 elts[8];
};
// Quantizes the provided PackedVec into the uint32_t output
template <class Type>
__device__ uint32_t cvt_warp_fp16_to_fp4(PackedVec<Type>& vec, uint8_t* SFout) {
// #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
// Get absolute maximum values among the local 8 values.
auto localMax = __habs2(vec.elts[0]);
// Local maximum value.
#pragma unroll
for (int i = 1; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) {
localMax = __hmax2(localMax, __habs2(vec.elts[i]));
}
// Get the absolute maximum among all 32 values (four threads).
localMax = __hmax2(__shfl_xor_sync(uint32_t(-1), localMax, 1), localMax);
localMax = __hmax2(__shfl_xor_sync(uint32_t(-1), localMax, 2), localMax);
// Get the final absolute maximum values.
float vecMax = float(__hmax(localMax.x, localMax.y));
// Get the SF (max value of the vector / max value of e2m1).
// maximum value of e2m1 = 6.0.
// TODO: use half as compute data type.
float SFValue = vecMax * 0.16666666666666666f;
// 8 bits representation of the SF.
uint8_t fp8SFVal;
// Write the SF to global memory (STG.8).
__nv_fp8_e8m0 tmp;
tmp.__x = __nv_cvt_float_to_e8m0(SFValue, __NV_SATFINITE, cudaRoundPosInf);
SFValue = static_cast<float>(tmp);
fp8SFVal = tmp.__x;
// Get the output scale.
float outputScale =
SFValue != 0 ? reciprocal_approximate_ftz(SFValue) : 0.0f;
if (SFout) {
// Write the SF to global memory (STG.8).
*SFout = fp8SFVal;
}
// Convert the input to float.
float2 fp2Vals[CVT_FP4_ELTS_PER_THREAD / 2];
#pragma unroll
for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) {
if constexpr (std::is_same_v<Type, half>) {
fp2Vals[i] = __half22float2(vec.elts[i]);
} else {
fp2Vals[i] = __bfloat1622float2(vec.elts[i]);
}
fp2Vals[i].x *= outputScale;
fp2Vals[i].y *= outputScale;
}
// Convert to e2m1 values.
uint32_t e2m1Vec = fp32_vec_to_e2m1(fp2Vals);
// Write the e2m1 values to global memory.
return e2m1Vec;
// #else
// return 0;
// #endif
}
// Scale factors are stored as UE8M0.
template <class Type>
__global__ void
// #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
__launch_bounds__(256, 6) cvt_fp16_to_fp4(
// #else
// cvt_fp16_to_fp4(
// #endif
int32_t numRows, int32_t numCols, Type const* in, uint32_t* out, uint32_t* SFout) {
// #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
using PackedVec = PackedVec<Type>;
static constexpr int CVT_FP4_NUM_THREADS_PER_SF = (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD);
static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD, "Vec size is not matched.");
// Input tensor row/col loops.
for (int rowIdx = blockIdx.x; rowIdx < numRows; rowIdx += gridDim.x) {
for (int colIdx = threadIdx.x; colIdx < numCols / CVT_FP4_ELTS_PER_THREAD; colIdx += blockDim.x) {
int64_t inOffset = rowIdx * (numCols / CVT_FP4_ELTS_PER_THREAD) + colIdx;
PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
// Get the output tensor offset.
// Same as inOffset because 8 elements are packed into one uint32_t.
int64_t outOffset = inOffset;
auto& out_pos = out[outOffset];
auto sf_out =
get_sf_out_address<uint32_t, CVT_FP4_NUM_THREADS_PER_SF>(rowIdx, colIdx, numCols, SFout);
out_pos = cvt_warp_fp16_to_fp4<Type>(in_vec, sf_out);
}
}
// #endif
}
template <typename T>
void invokeFP4Quantization(
int m,
int n,
T const* input,
int64_t* output,
int32_t* SFOuput,
int multiProcessorCount,
cudaStream_t stream) {
// Grid, Block size.
// Each thread converts 8 values.
dim3 block(std::min(int(n / ELTS_PER_THREAD), 256));
// Get number of blocks per SM (assume we can fully utilize the SM).
int const numBlocksPerSM = 1536 / block.x;
dim3 grid(std::min(int(m), multiProcessorCount * numBlocksPerSM));
// Launch the cvt kernel.
cvt_fp16_to_fp4<T><<<grid, block, 0, stream>>>(
m, n, input, reinterpret_cast<uint32_t*>(output), reinterpret_cast<uint32_t*>(SFOuput));
}
// Instantiate the function.
template void invokeFP4Quantization(
int m,
int n,
half const* input,
int64_t* output,
int32_t* SFOuput,
int multiProcessorCount,
cudaStream_t stream);
template void invokeFP4Quantization(
int m,
int n,
__nv_bfloat16 const* input,
int64_t* output,
int32_t* SFOuput,
int multiProcessorCount,
cudaStream_t stream);
inline int getMultiProcessorCount() {
static int multi_processor_count = []() {
int device_id = 0;
int count = 0;
// Get the current CUDA device ID
CHECK_CUDA_SUCCESS(cudaGetDevice(&device_id));
// Get the number of multiprocessors for the current device
CHECK_CUDA_SUCCESS(cudaDeviceGetAttribute(&count, cudaDevAttrMultiProcessorCount, device_id));
return count; // Initialize the static variable
}();
return multi_processor_count; // Return the cached value on subsequent calls
}
void scaled_mxfp4_quant_sm120(
torch::Tensor& output, torch::Tensor const& input, torch::Tensor& output_sf) {
int32_t m = input.size(0);
int32_t n = input.size(1);
TORCH_CHECK(n % 16 == 0, "The N dimension must be multiple of 16.");
int multiProcessorCount = getMultiProcessorCount();
auto sf_out = static_cast<int32_t*>(output_sf.data_ptr());
auto output_ptr = static_cast<int64_t*>(output.data_ptr());
at::cuda::CUDAGuard device_guard{(char)input.get_device()};
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(input.get_device());
switch (input.scalar_type()) {
case torch::kHalf: {
auto input_ptr = reinterpret_cast<half const*>(input.data_ptr());
invokeFP4Quantization(m, n, input_ptr, output_ptr, sf_out, multiProcessorCount, stream);
break;
}
case torch::kBFloat16: {
auto input_ptr = reinterpret_cast<__nv_bfloat16 const*>(input.data_ptr());
invokeFP4Quantization(m, n, input_ptr, output_ptr, sf_out, multiProcessorCount, stream);
break;
}
default: {
std::cerr << "Observing: " << input.scalar_type() << " for the input datatype which is invalid";
throw std::runtime_error("Unsupported input data type for quantize_to_fp4.");
}
}
}
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>
// clang-format off
#include "cutlass/cutlass.h"
#include "cutlass/epilogue/fusion/operations.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/util/packed_stride.hpp"
// clang-format on
#define CUTLASS_CHECK(status) \
{ \
cutlass::Status error = status; \
TORCH_CHECK(error == cutlass::Status::kSuccess, cutlassGetStatusString(error)); \
}
#define CHECK_TYPE(x, st, m) TORCH_CHECK(x.scalar_type() == st, "Inconsistency of Tensor type:", m)
#define CHECK_TH_CUDA(x, m) TORCH_CHECK(x.is_cuda(), m, "must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x, m) TORCH_CHECK(x.is_contiguous(), m, "must be contiguous")
#define CHECK_INPUT(x, st, m) \
CHECK_TH_CUDA(x, m); \
CHECK_CONTIGUOUS(x, m); \
CHECK_TYPE(x, st, m)
using namespace cute;
struct Mxfp4GemmSm120 {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// GEMM kernel configurations
/////////////////////////////////////////////////////////////////////////////////////////////////
// A matrix configuration
using ElementA = cutlass::mx_float4_t<cutlass::float_e2m1_t>; // Element type for A matrix operand
using LayoutATag = cutlass::layout::RowMajor; // Layout type for A matrix operand
static constexpr int AlignmentA = 128; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
// B matrix configuration
using ElementB = cutlass::mx_float4_t<cutlass::float_e2m1_t>; // Element type for B matrix operand
using LayoutBTag = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
static constexpr int AlignmentB = 128; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
// C/D matrix configuration
using ElementD = cutlass::bfloat16_t; // Element type for D matrix operand
using ElementC = cutlass::bfloat16_t; // Element type for C matrix operand
using LayoutCTag = cutlass::layout::RowMajor; // Layout type for C matrix operand
using LayoutDTag = cutlass::layout::RowMajor; // Layout type for D matrix operand
static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
// Kernel functional config
using ElementAccumulator = float; // Element type for internal accumulation
using ArchTag = cutlass::arch::Sm120; // Tag indicating the minimum SM that supports the intended feature
using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp; // Operator class tag
// Kernel Perf config
using ThreadBlockShape = Shape<_128,_128,_128>; // Threadblock's tile size
using ClusterShape = Shape<_1,_1,_1>; // Shape of the threadblocks in a cluster
// use per-column bias, i.e. every column has different bias
using EVTOp = cutlass::epilogue::fusion::LinCombPerColBias<ElementD, ElementAccumulator>;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag, OperatorClass,
ThreadBlockShape, ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator, ElementAccumulator,
ElementC, LayoutCTag, AlignmentC,
ElementD, LayoutDTag, AlignmentD,
cutlass::epilogue::collective::EpilogueScheduleAuto, // Epilogue schedule policy
EVTOp
>::CollectiveOp;
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass,
ElementA, LayoutATag, AlignmentA,
ElementB, LayoutBTag, AlignmentB,
ElementAccumulator,
ThreadBlockShape, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
cutlass::gemm::collective::KernelScheduleAuto // Kernel schedule policy. Auto defaults to cooperative kernel schedule
>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int,int,int,int>, // Indicates ProblemShape
CollectiveMainloop,
CollectiveEpilogue,
void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
// Reference device GEMM implementation type
using StrideA = typename Gemm::GemmKernel::StrideA;
using LayoutA = decltype(cute::make_layout(make_shape(0,0,0), StrideA{}));
using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA; // Scale Factor tensors have an interleaved layout. Bring Layout instead of stride.
using StrideB = typename Gemm::GemmKernel::StrideB;
using LayoutB = decltype(cute::make_layout(make_shape(0,0,0), StrideB{}));
using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB; // Scale Factor tensors have an interleaved layout. Bring Layout instead of stride.
using StrideC = typename Gemm::GemmKernel::StrideC;
using LayoutC = decltype(cute::make_layout(make_shape(0,0,0), StrideC{}));
using StrideD = typename Gemm::GemmKernel::StrideD;
using LayoutD = decltype(cute::make_layout(make_shape(0,0,0), StrideD{}));
};
// Populates a Gemm::Arguments structure from the given commandline options
typename Mxfp4GemmSm120::Gemm::Arguments args_from_options_mxp4_mxfp4(
at::Tensor& D,
at::Tensor const& A,
at::Tensor const& B,
at::Tensor const& A_sf,
at::Tensor const& B_sf,
at::Tensor const& alpha,
c10::optional<torch::Tensor> const& bias,
int64_t M,
int64_t N,
int64_t K) {
using Sm1xxBlkScaledConfig = typename Mxfp4GemmSm120::Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;
int m = static_cast<int>(M);
int n = static_cast<int>(N);
int k = static_cast<int>(K);
auto stride_A = cutlass::make_cute_packed_stride(Mxfp4GemmSm120::StrideA{}, {m, k, 1});
auto stride_B = cutlass::make_cute_packed_stride(Mxfp4GemmSm120::StrideB{}, {n, k, 1});
auto stride_D = cutlass::make_cute_packed_stride(Mxfp4GemmSm120::StrideD{}, {m, n, 1});
auto layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(m, n, k, 1));
auto layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(m, n, k, 1));
if (bias){
using StrideBias = Stride<cutlass::_0, cutlass::_1, int64_t>;
typename Mxfp4GemmSm120::Gemm::Arguments arguments{
cutlass::gemm::GemmUniversalMode::kGemm,
{m, n, k, 1},
{// Mainloop arguments
static_cast<Mxfp4GemmSm120::Gemm::ElementA const*>(A.data_ptr()),
stride_A,
static_cast<Mxfp4GemmSm120::Gemm::ElementB const*>(B.data_ptr()),
stride_B,
static_cast<cutlass::float_ue8m0_t const*>(A_sf.data_ptr()),
layout_SFA,
static_cast<cutlass::float_ue8m0_t const*>(B_sf.data_ptr()),
layout_SFB},
{ // Epilogue arguments
{}, // epilogue.thread
static_cast<Mxfp4GemmSm120::Gemm::ElementC const*>(D.data_ptr()),
stride_D,
static_cast<Mxfp4GemmSm120::Gemm::ElementD*>(D.data_ptr()),
stride_D}};
auto& fusion_args = arguments.epilogue.thread;
fusion_args.alpha_ptr = static_cast<float const*>(alpha.data_ptr());
static const float beta_zero = 0.0f;
fusion_args.beta_ptr = &beta_zero;
fusion_args.bias_ptr = static_cast<Mxfp4GemmSm120::Gemm::ElementC const*>(bias->data_ptr());
fusion_args.dBias = StrideBias{};
return arguments;
} else {
typename Mxfp4GemmSm120::Gemm::Arguments arguments{
cutlass::gemm::GemmUniversalMode::kGemm,
{m, n, k, 1},
{// Mainloop arguments
static_cast<Mxfp4GemmSm120::Gemm::ElementA const*>(A.data_ptr()),
stride_A,
static_cast<Mxfp4GemmSm120::Gemm::ElementB const*>(B.data_ptr()),
stride_B,
static_cast<cutlass::float_ue8m0_t const*>(A_sf.data_ptr()),
layout_SFA,
static_cast<cutlass::float_ue8m0_t const*>(B_sf.data_ptr()),
layout_SFB},
{ // Epilogue arguments
{}, // epilogue.thread
static_cast<Mxfp4GemmSm120::Gemm::ElementC const*>(D.data_ptr()),
stride_D,
static_cast<Mxfp4GemmSm120::Gemm::ElementD*>(D.data_ptr()),
stride_D}};
auto& fusion_args = arguments.epilogue.thread;
fusion_args.alpha_ptr = static_cast<float const*>(alpha.data_ptr());
static const float beta_zero = 0.0f;
fusion_args.beta_ptr = &beta_zero;
return arguments;
}
}
void runGemmMxfp4Sm120(
at::Tensor& D,
at::Tensor const& A,
at::Tensor const& B,
at::Tensor const& A_sf,
at::Tensor const& B_sf,
at::Tensor const& alpha,
c10::optional<torch::Tensor> const& bias,
int64_t m,
int64_t n,
int64_t k,
cudaStream_t stream) {
typename Mxfp4GemmSm120::Gemm gemm;
auto arguments = args_from_options_mxp4_mxfp4(D, A, B, A_sf, B_sf, alpha, bias, m, n, k);
size_t workspace_size = Mxfp4GemmSm120::Gemm::get_workspace_size(arguments);
auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(A.device());
auto workspace = torch::empty(workspace_size, workspace_options);
CUTLASS_CHECK(gemm.can_implement(arguments));
CUTLASS_CHECK(gemm.initialize(arguments, workspace.data_ptr(), stream));
CUTLASS_CHECK(gemm.run(arguments, workspace.data_ptr(), stream));
}
constexpr auto FLOAT4_E2M1X2 = at::ScalarType::Byte;
constexpr auto SF_DTYPE = at::ScalarType::Float8_e8m0fnu;
void cutlass_scaled_mxfp4_mm_sm120(
torch::Tensor& D,
torch::Tensor const& A,
torch::Tensor const& B,
torch::Tensor const& A_sf,
torch::Tensor const& B_sf,
torch::Tensor const& alpha,
c10::optional<torch::Tensor> const& bias) {
CHECK_INPUT(A, FLOAT4_E2M1X2, "a");
CHECK_INPUT(B, FLOAT4_E2M1X2, "b");
CHECK_INPUT(A_sf, SF_DTYPE, "scale_a");
CHECK_INPUT(B_sf, SF_DTYPE, "scale_b");
CHECK_INPUT(alpha, at::ScalarType::Float, "alpha");
TORCH_CHECK(A.dim() == 2, "a must be a matrix");
TORCH_CHECK(B.dim() == 2, "b must be a matrix");
TORCH_CHECK(
A.sizes()[1] == B.sizes()[1],
"a and b shapes cannot be multiplied (",
A.sizes()[0],
"x",
A.sizes()[1],
" and ",
B.sizes()[0],
"x",
B.sizes()[1],
")");
auto const m = A.sizes()[0];
auto const n = B.sizes()[0];
auto const k = A.sizes()[1] * 2;
constexpr int alignment = 128;
TORCH_CHECK(
k % alignment == 0,
"Expected k to be divisible by ",
alignment,
", but got a shape: (",
A.sizes()[0],
"x",
A.sizes()[1],
"), k: ",
k,
".");
TORCH_CHECK(
n % alignment == 0,
"Expected n to be divisible by ",
alignment,
", but got b shape: (",
B.sizes()[0],
"x",
B.sizes()[1],
").");
auto round_up = [](int x, int y) { return (x + y - 1) / y * y; };
int rounded_m = round_up(m, 128);
int rounded_n = round_up(n, 128);
// Since k is divisible by 128 (alignment), k / 32 is guaranteed to be an
// integer.
int rounded_k = round_up(k / 32, 4);
TORCH_CHECK(A_sf.dim() == 2, "scale_a must be a matrix");
TORCH_CHECK(B_sf.dim() == 2, "scale_b must be a matrix");
TORCH_CHECK(
A_sf.sizes()[1] == B_sf.sizes()[1],
"scale_a and scale_b shapes cannot be multiplied (",
A_sf.sizes()[0],
"x",
A_sf.sizes()[1],
" and ",
B_sf.sizes()[0],
"x",
B_sf.sizes()[1],
")");
TORCH_CHECK(
A_sf.sizes()[0] == rounded_m && A_sf.sizes()[1] == rounded_k,
"scale_a must be padded and swizzled to a shape (",
rounded_m,
"x",
rounded_k,
"), but got a shape (",
A_sf.sizes()[0],
"x",
A_sf.sizes()[1],
")");
TORCH_CHECK(
B_sf.sizes()[0] == rounded_n && B_sf.sizes()[1] == rounded_k,
"scale_b must be padded and swizzled to a shape (",
rounded_n,
"x",
rounded_k,
"), but got a shape (",
B_sf.sizes()[0],
"x",
B_sf.sizes()[1],
")");
auto out_dtype = D.dtype();
at::cuda::CUDAGuard device_guard{(char)A.get_device()};
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(A.get_device());
runGemmMxfp4Sm120(D, A, B, A_sf, B_sf, alpha, bias, m, n, k, stream);
}
@@ -315,7 +315,7 @@ inline int getMultiProcessorCount() {
   return multi_processor_count; // Return the cached value on subsequent calls
 }

-void scaled_fp6_quant_sm120(
+void scaled_mxfp6_quant_sm120(
     torch::Tensor& output, torch::Tensor const& input, torch::Tensor& output_sf) {
   int32_t m = input.size(0);
   int32_t n = input.size(1);
...
@@ -282,7 +282,7 @@ inline int getMultiProcessorCount() {
   return multi_processor_count; // Return the cached value on subsequent calls
 }

-void scaled_fp8_quant_sm120(
+void scaled_mxfp8_quant_sm120(
     torch::Tensor& output, torch::Tensor const& input, torch::Tensor& output_sf) {
   int32_t m = input.size(0);
   int32_t n = input.size(1);
...
@@ -350,7 +350,7 @@ inline int getMultiProcessorCount() {
   return multi_processor_count; // Return the cached value on subsequent calls
 }

-void scaled_fp4_quant_sm120(
+void scaled_nvfp4_quant_sm120(
     torch::Tensor& output, torch::Tensor const& input, torch::Tensor& output_sf, torch::Tensor const& input_sf) {
   int32_t m = input.size(0);
   int32_t n = input.size(1);
...
@@ -215,7 +215,7 @@ void runGemmNvfp4Sm120(
 constexpr auto FLOAT4_E2M1X2 = at::ScalarType::Byte;
 constexpr auto SF_DTYPE = at::ScalarType::Float8_e4m3fn;

-void cutlass_scaled_fp4_mm_sm120(
+void cutlass_scaled_nvfp4_mm_sm120(
     torch::Tensor& D,
     torch::Tensor const& A,
     torch::Tensor const& B,
...
# MX-Formats Quantization Basics
**Note: The following focuses on the differences between MX-Formats quantization and Per-Row/Per-Column quantization, as well as the layout requirements for compatibility with Cutlass Block Scaled GEMMs.**
### Data Formats and Quantization Factors
Target data format reference: [MX-Formats](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf). Note that we do not need to pack raw data and scale factors together here.
Source data format: fp16/bf16
Target data format: mxfp4/6/8
Quantization factor data format: E8M0. Per-Row/Per-Column quantization typically stores quantization factors in fp32, whereas E8M0 covers the same numerical range as fp32, so after rounding the factors can be stored directly, though the loss of mantissa bits may affect precision.
Quantization granularity: \[1X32\]
Quantization dimension: following the Cutlass GEMM convention, where M, N, K are the three dimensions of the matrix multiplication, quantization is performed along the K dimension.
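As a concrete illustration, here is a minimal Python sketch of how a per-block E8M0 scale factor can be derived, mirroring the `amax / 6.0` plus round-up-to-a-power-of-two step in the mxfp4 quant kernel added in this commit (the helper name below is ours, not part of the library):

```python
import math

def round_up_to_e8m0(x: float) -> float:
    """Round a positive scale up to the nearest power of two; E8M0 stores only an exponent."""
    if x <= 0.0:
        return 0.0
    return 2.0 ** math.ceil(math.log2(x))

# One [1x32] block: scale = amax / 6.0 (6.0 is the largest finite e2m1 magnitude),
# rounded up so that no scaled element overflows the e2m1 range.
block_amax = 4.7
scale = round_up_to_e8m0(block_amax / 6.0)  # -> 1.0, stored as a single E8M0 byte
```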
### Rounding and Clamp
Unlike software emulation, CUDA can efficiently handle complex rounding and clamping operations using PTX or built-in functions.
For example, `cvt.rn.satfinite.e2m1x2.f32` can convert two fp32 inputs into two fp4 outputs.
Rounding mode: `rn` (round-to-nearest-even)
Clamp mode: `satfinite` (clamped to the maximum finite value within the target range, excluding infinities and NaN)
For more data types and modes, refer to: [PTX cvt Instructions](https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cvt)
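For intuition, the hedged Python emulation below approximates what the instruction does per value: snap to the e2m1 grid with ties going to the neighbor whose mantissa LSB is 0, and clamp the magnitude to 6.0. The real kernels rely on the PTX instruction; this sketch is only illustrative.

```python
# e2m1 representable magnitudes (1 sign bit, 2 exponent bits, 1 mantissa bit)
E2M1_GRID = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]

def quantize_e2m1(x: float) -> float:
    sign = -1.0 if x < 0 else 1.0
    mag = min(abs(x), 6.0)  # satfinite: clamp to the largest finite value, never inf/NaN
    # round to nearest; on a tie prefer the grid point with even mantissa (even index here)
    best = min(range(len(E2M1_GRID)), key=lambda i: (abs(E2M1_GRID[i] - mag), i % 2))
    return sign * E2M1_GRID[best]

print([quantize_e2m1(v) for v in (0.24, 0.26, 2.5, 7.3)])  # [0.0, 0.5, 2.0, 6.0]
```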
### Data Layout and Quantization Factor Layout
**Data Layout**
- mxfp4 requires packing two values into a uint8.
- mxfp6 requires packing every four values into three uint8s. For the format, refer to: [mxfp6 cutlass mm format packing](https://github.com/ModelTC/LightX2V/blob/main/lightx2v_kernel/csrc/gemm/mxfp6_quant_kernels_sm120.cu#L74).
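A small sketch of the storage arithmetic follows, under the assumption that the first value of each fp4 pair lands in the low nibble (the authoritative bit order is whatever the quant kernels emit); the payload sizes match the allocations in the Python wrappers later in this commit.

```python
def pack_fp4_pair(code_lo: int, code_hi: int) -> int:
    """Pack two 4-bit e2m1 codes into one uint8 (first code assumed in the low nibble)."""
    assert 0 <= code_lo < 16 and 0 <= code_hi < 16
    return (code_hi << 4) | code_lo

# Payload sizes per row of n elements:
n = 5120
print(n // 2)      # mxfp4: 2 values per byte   -> torch.empty((m, n // 2), dtype=torch.uint8)
print(3 * n // 4)  # mxfp6: 4 values per 3 bytes -> torch.empty((m, 3 * n // 4), dtype=torch.uint8)
print(n)           # mxfp8: 1 value per byte
```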
**Quantization Factor Layout**
Cutlass Block Scaled GEMMs impose special swizzle requirements on quantization factor layouts to optimize matrix operations.
Reference: [Scale Factor Layouts](https://github.com/NVIDIA/cutlass/blob/main/media/docs/cpp/blackwell_functionality.md#scale-factor-layouts)
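To make the layout concrete, the sketch below is a direct Python transcription of the index math in `get_sf_out_address` from the mxfp4 quant kernel in this commit (the [numMTiles, numKTiles, 32, 4, 4] layout); it is shown only for illustration.

```python
SF_VEC_SIZE = 32  # 32 elements along K share one E8M0 scale factor

def sf_offset(row: int, k_vec: int, num_cols: int) -> int:
    """Byte offset of the scale factor for (row, k_vec) in the swizzled SF tensor."""
    factor = SF_VEC_SIZE * 4
    num_k_tiles = (num_cols + factor - 1) // factor
    m_tile, k_tile = row // (32 * 4), k_vec // 4
    outer_m = row % 32                 # M tile layout [32, 4] is column-major
    inner_m = (row % (32 * 4)) // 32
    inner_k = k_vec % 4
    return (m_tile * num_k_tiles * 32 * 4 * 4 + k_tile * 32 * 4 * 4
            + outer_m * 4 * 4 + inner_m * 4 + inner_k)

print(sf_offset(row=5, k_vec=7, num_cols=5120))  # 595
```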
### Quantization Method
With the above in place, the computation of the target data values and the quantization factors can follow [nvfp4 Quantization Basics](https://github.com/theNiemand/lightx2v/blob/main/lightx2v_kernel/docs/zh_CN/nvfp4%E9%87%8F%E5%8C%96%E5%9F%BA%E7%A1%80.md). Note that MX-Formats do not require quantizing the scale itself.
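Putting it together, a hedged end-to-end sketch using the Python wrappers added in this commit (`scaled_mxfp4_quant` / `cutlass_scaled_mxfp4_mm`, see `lightx2v_kernel.gemm` and the unit tests below); the shapes are chosen to satisfy the kernel's alignment checks (K and N divisible by 128):

```python
import torch
from lightx2v_kernel.gemm import scaled_mxfp4_quant, cutlass_scaled_mxfp4_mm

# activations (M, K) and weights (N, K) in bf16
x = torch.randn(512, 5120, dtype=torch.bfloat16, device="cuda")
w = torch.randn(5120, 5120, dtype=torch.bfloat16, device="cuda")

xq, x_sf = scaled_mxfp4_quant(x)   # packed e2m1 payload + swizzled E8M0 scale factors
wq, w_sf = scaled_mxfp4_quant(w)
alpha = torch.tensor(1.0, dtype=torch.float32, device="cuda")
y = cutlass_scaled_mxfp4_mm(xq, wq, x_sf, w_sf, alpha=alpha)  # bf16 result of shape (M, N)
```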
# MX-Formats Quantization Basics
**Note: The following focuses on the differences between MX-Formats quantization and Per-Row/Per-Column quantization, as well as the layout requirements for working with Cutlass Block Scaled GEMMs.**
### Data Formats and Quantization Factors
Target data format reference: [MX-Formats](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf). Note that we do not need to pack the raw data and scale factors together here.
Source data format: fp16/bf16
Target data format: mxfp4/6/8
Quantization factor data format: E8M0. Per-Row/Per-Column quantization usually stores its quantization factors in fp32; E8M0 covers the same numerical range as fp32, so the factors can be stored directly after rounding, at the cost of the lost mantissa bits affecting precision.
Quantization granularity: \[1X32\]
Quantization dimension: following the Cutlass GEMM convention, where M, N, K are the three dimensions of the matrix multiplication, quantization is performed along the K dimension.
### Rounding and Clamp
Unlike software emulation, CUDA can perform the tedious rounding and clamping operations efficiently and conveniently through PTX or built-in functions.
For example, `cvt.rn.satfinite.e2m1x2.f32` converts two fp32 inputs into two fp4 outputs.
Rounding mode: `rn` (round-to-nearest-even)
Clamp mode: `satfinite` (clamp to the largest finite value of the target range, excluding infinities and NaN)
For more data types and modes, see: [PTX cvt Instructions](https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cvt)
### Data Layout and Quantization Factor Layout
**Data Layout**
- mxfp4 packs every two values into one uint8.
- mxfp6 packs every four values into three uint8s; for the format, see [mxfp6 cutlass mm format packing](https://github.com/ModelTC/LightX2V/blob/main/lightx2v_kernel/csrc/gemm/mxfp6_quant_kernels_sm120.cu#L74).
**Quantization Factor Layout**
To accelerate matrix operations, Cutlass Block Scaled GEMMs impose special swizzle requirements on the quantization factor layout.
Reference: [Scale Factor Layouts](https://github.com/NVIDIA/cutlass/blob/main/media/docs/cpp/blackwell_functionality.md#scale-factor-layouts)
### Quantization Method
With the above in place, the computation of the target data values and the quantization factors can follow [nvfp4 Quantization Basics](https://github.com/theNiemand/lightx2v/blob/main/lightx2v_kernel/docs/zh_CN/nvfp4%E9%87%8F%E5%8C%96%E5%9F%BA%E7%A1%80.md); note that MX-Formats do not require quantizing the scale itself.
@@ -42,8 +42,19 @@ limitations under the License.
 /*
  * From csrc/gemm
  */
-void cutlass_scaled_fp4_mm_sm120(
+void scaled_nvfp4_quant_sm120(
+    torch::Tensor& output, torch::Tensor const& input, torch::Tensor& output_sf, torch::Tensor const& input_sf);
+
+void scaled_mxfp4_quant_sm120(
+    torch::Tensor& output, torch::Tensor const& input, torch::Tensor& output_sf);
+
+void scaled_mxfp6_quant_sm120(
+    torch::Tensor& output, torch::Tensor const& input, torch::Tensor& output_sf);
+
+void scaled_mxfp8_quant_sm120(
+    torch::Tensor& output, torch::Tensor const& input, torch::Tensor& output_sf);
+
+void cutlass_scaled_nvfp4_mm_sm120(
     torch::Tensor& D,
     torch::Tensor const& A,
     torch::Tensor const& B,
@@ -52,16 +63,14 @@ void cutlass_scaled_fp4_mm_sm120(
     torch::Tensor const& alpha,
     c10::optional<torch::Tensor> const& bias);

-void scaled_fp4_quant_sm120(
-    torch::Tensor& output, torch::Tensor const& input, torch::Tensor& output_sf, torch::Tensor const& input_sf);
-
-void scaled_fp8_quant_sm120(
-    torch::Tensor& output, torch::Tensor const& input, torch::Tensor& output_sf);
-
-void scaled_fp6_quant_sm120(
-    torch::Tensor& output, torch::Tensor const& input, torch::Tensor& output_sf);
+void cutlass_scaled_mxfp4_mm_sm120(
+    torch::Tensor& D,
+    torch::Tensor const& A,
+    torch::Tensor const& B,
+    torch::Tensor const& A_sf,
+    torch::Tensor const& B_sf,
+    torch::Tensor const& alpha,
+    c10::optional<torch::Tensor> const& bias);

 void cutlass_scaled_mxfp6_mxfp8_mm_sm120(
     torch::Tensor& D,
...
 import torch

-def cutlass_scaled_fp4_mm(mat_a, mat_b, scales_a, scales_b, alpha, bias=None):
+def cutlass_scaled_nvfp4_mm(mat_a, mat_b, scales_a, scales_b, alpha, bias=None):
     m, n = mat_a.shape[0], mat_b.shape[0]
     out = torch.empty((m, n), dtype=torch.bfloat16, device=mat_a.device)
-    torch.ops.lightx2v_kernel.cutlass_scaled_fp4_mm_sm120.default(out, mat_a, mat_b, scales_a, scales_b, alpha, bias)
+    torch.ops.lightx2v_kernel.cutlass_scaled_nvfp4_mm_sm120.default(out, mat_a, mat_b, scales_a, scales_b, alpha, bias)
     return out

-def scaled_fp4_quant(input: torch.Tensor, input_global_scale: torch.Tensor):
+def scaled_nvfp4_quant(input: torch.Tensor, input_global_scale: torch.Tensor):
     """
     Quantize input tensor to FP4 and return quantized tensor and scale.
@@ -50,12 +50,25 @@ def scaled_fp4_quant(input: torch.Tensor, input_global_scale: torch.Tensor):
     # rounded_n = ((scale_n + 4 - 1) // 4) * 4
     output_scale = torch.zeros((((m + 128 - 1) // 128) * 128, (n // block_size + 4 - 1) // 4), device=device, dtype=torch.int32)

-    torch.ops.lightx2v_kernel.scaled_fp4_quant_sm120.default(output, input, output_scale, input_global_scale)
+    torch.ops.lightx2v_kernel.scaled_nvfp4_quant_sm120.default(output, input, output_scale, input_global_scale)
     output_scale = output_scale.view(torch.float8_e4m3fn)
     return output, output_scale

-def scaled_fp6_quant(input: torch.Tensor):
+def scaled_mxfp4_quant(input: torch.Tensor):
+    m, n = input.shape
+    block_size = 32
+    device = input.device
+    output = torch.empty((m, n // 2), device=device, dtype=torch.uint8)
+    output_scale = torch.zeros(((m + 128 - 1) // 128 * 128, (n // block_size + 4 - 1) // 4), device=device, dtype=torch.int32)
+
+    torch.ops.lightx2v_kernel.scaled_mxfp4_quant_sm120.default(output, input, output_scale)
+    output_scale = output_scale.view(torch.float8_e8m0fnu)
+    return output, output_scale
+
+def scaled_mxfp6_quant(input: torch.Tensor):
     m, n = input.shape
     block_size = 32
     device = input.device
@@ -63,12 +76,12 @@ def scaled_fp6_quant(input: torch.Tensor):
     output = torch.empty((m, 3 * n // 4), device=device, dtype=torch.uint8)
     output_scale = torch.zeros(((m + 128 - 1) // 128 * 128, (n // block_size + 4 - 1) // 4), device=device, dtype=torch.int32)

-    torch.ops.lightx2v_kernel.scaled_fp6_quant_sm120.default(output, input, output_scale)
+    torch.ops.lightx2v_kernel.scaled_mxfp6_quant_sm120.default(output, input, output_scale)
     output_scale = output_scale.view(torch.float8_e8m0fnu)
     return output, output_scale

-def scaled_fp8_quant(input: torch.Tensor):
+def scaled_mxfp8_quant(input: torch.Tensor):
     m, n = input.shape
     block_size = 32
     device = input.device
@@ -76,11 +89,18 @@ def scaled_fp8_quant(input: torch.Tensor):
     output = torch.empty((m, n), device=device, dtype=torch.uint8)
     output_scale = torch.empty(((m + 128 - 1) // 128 * 128, (n // block_size + 4 - 1) // 4), device=device, dtype=torch.int32)

-    torch.ops.lightx2v_kernel.scaled_fp8_quant_sm120.default(output, input, output_scale)
+    torch.ops.lightx2v_kernel.scaled_mxfp8_quant_sm120.default(output, input, output_scale)
     output_scale = output_scale.view(torch.float8_e8m0fnu)
     return output, output_scale

+def cutlass_scaled_mxfp4_mm(mat_a, mat_b, scales_a, scales_b, alpha, bias=None):
+    m, n = mat_a.shape[0], mat_b.shape[0]
+    out = torch.empty((m, n), dtype=torch.bfloat16, device=mat_a.device)
+    torch.ops.lightx2v_kernel.cutlass_scaled_mxfp4_mm_sm120.default(out, mat_a, mat_b, scales_a, scales_b, alpha, bias)
+    return out

 def cutlass_scaled_mxfp6_mxfp8_mm(mat_a, mat_b, scales_a, scales_b, alpha, bias=None):
     m, n = mat_a.shape[0], mat_b.shape[0]
     out = torch.empty((m, n), dtype=torch.bfloat16, device=mat_a.device)
...
import torch
from lightx2v_kernel.gemm import scaled_mxfp4_quant, cutlass_scaled_mxfp4_mm
import time
class MMWeightMxfp4ActMxfp4:
def __init__(self, weight, bias):
self.load_fp4_weight(weight, bias)
self.act_quant_func = self.act_quant_fp4
self.set_alpha()
@torch.no_grad()
def apply(self, input_tensor):
input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
output_tensor = cutlass_scaled_mxfp4_mm(input_tensor_quant, self.weight, input_tensor_scale, self.weight_scale, alpha=self.alpha, bias=self.bias)
return output_tensor
@torch.no_grad()
def load_fp4_weight(self, weight, bias):
self.weight, self.weight_scale = scaled_mxfp4_quant(weight)
self.bias = bias
def set_alpha(self):
self.alpha = torch.tensor(1.0, dtype=torch.float32, device=self.weight.device)
@torch.no_grad()
def act_quant_fp4(self, x):
return scaled_mxfp4_quant(x)
def test_speed(m, k, n):
with torch.no_grad():
input_tensor = torch.randn(m, k, dtype=torch.bfloat16).cuda()
weight = torch.randn(n, k, dtype=torch.bfloat16, device="cuda")
# bias = torch.randn(1, n, dtype=torch.bfloat16).cuda()
bias = None
mm = MMWeightMxfp4ActMxfp4(weight, bias)
# warmup
output_tensor = mm.apply(input_tensor)
torch.cuda.synchronize()
start_time = time.time()
for i in range(100):
output_tensor = mm.apply(input_tensor)
torch.cuda.synchronize()
end_time = time.time()
lightx2v_kernel_time = (end_time - start_time) / 100
print(f"lightx2v-kernel time: {lightx2v_kernel_time}")
input_tensor = torch.randn(m, n, dtype=torch.bfloat16).cuda()
weight = torch.randn(k, n, dtype=torch.bfloat16, device="cuda")
bias = torch.randn(1, k, dtype=torch.bfloat16).cuda()
linear = torch.nn.Linear(k, n, bias=False).cuda()
linear.weight.data = weight
# linear.bias.data = bias
# warmup
ref_output_tensor = linear(input_tensor)
torch.cuda.synchronize()
start_time = time.time()
for i in range(100):
ref_output_tensor = linear(input_tensor)
torch.cuda.synchronize()
end_time = time.time()
ref_time = (end_time - start_time) / 100
print(f"ref time: {ref_time}")
print(f"speedup: {ref_time / lightx2v_kernel_time:.3f}")
def test_accuracy(m, k, n):
with torch.no_grad():
input_tensor = torch.randn(m, k, dtype=torch.bfloat16).cuda()
weight = torch.randn(n, k, dtype=torch.bfloat16, device="cuda")
# bias = torch.randn(1, n, dtype=torch.bfloat16).cuda()
bias = None
linear = torch.nn.Linear(k, n, bias=False).cuda()
linear.weight.data = weight
# linear.bias.data = bias
ref_output_tensor = linear(input_tensor)
mm = MMWeightMxfp4ActMxfp4(weight, bias)
output_tensor = mm.apply(input_tensor)
# print(f"ref_output_tensor: {ref_output_tensor}")
# print(f"output_tensor: {output_tensor}")
# cosine
cos = torch.nn.functional.cosine_similarity(ref_output_tensor.flatten(), output_tensor.flatten(), dim=0)
print(f"cos : {cos}")
if __name__ == "__main__":
test_sizes = [
(32130, 5120, 5120),
(512, 5120, 5120),
(257, 5120, 5120),
(32130, 5120, 13824),
(32130, 13824, 5120),
(75348, 5120, 5120),
(75348, 13824, 5120),
(32760, 1536, 1536),
(512, 1536, 1536),
(32760, 1536, 8960),
(32760, 8960, 1536),
]
for i, (m, k, n) in enumerate(test_sizes):
print("-" * 30)
print(f"测试 {i + 1}: 张量大小 ({m}, {k}, {n})")
test_accuracy(m, k, n)
test_speed(m, k, n)
import torch
import time
from test_bench import MMWeightMxfp4ActMxfp4
def test_speed(m, k, n):
with torch.no_grad():
input_tensor = torch.randn(m, k, dtype=torch.bfloat16).cuda()
weight = torch.randn(n, k, dtype=torch.bfloat16, device="cuda")
bias = torch.ones(1, n, dtype=torch.bfloat16).cuda() * 50
mm = MMWeightMxfp4ActMxfp4(weight, bias)
# warmup
output_tensor = mm.apply(input_tensor)
torch.cuda.synchronize()
start_time = time.time()
for i in range(100):
output_tensor = mm.apply(input_tensor)
torch.cuda.synchronize()
end_time = time.time()
lightx2v_kernel_time = (end_time - start_time) / 100
print(f"lightx2v-kernel time: {lightx2v_kernel_time}")
input_tensor = torch.randn(m, n, dtype=torch.bfloat16).cuda()
weight = torch.randn(k, n, dtype=torch.bfloat16, device="cuda")
bias = torch.randn(1, k, dtype=torch.bfloat16).cuda()
linear = torch.nn.Linear(k, n, bias=True).cuda()
linear.weight.data = weight
linear.bias.data = bias
# warmup
ref_output_tensor = linear(input_tensor)
torch.cuda.synchronize()
start_time = time.time()
for i in range(100):
ref_output_tensor = linear(input_tensor)
torch.cuda.synchronize()
end_time = time.time()
ref_time = (end_time - start_time) / 100
print(f"ref time: {ref_time}")
print(f"speedup: {ref_time / lightx2v_kernel_time:.3f}")
def test_accuracy(m, k, n):
with torch.no_grad():
input_tensor = torch.randn(m, k, dtype=torch.bfloat16).cuda()
weight = torch.randn(n, k, dtype=torch.bfloat16, device="cuda")
bias = torch.ones(1, n, dtype=torch.bfloat16).cuda() * 50
linear = torch.nn.Linear(k, n, bias=True).cuda()
linear.weight.data = weight
linear.bias.data = bias
ref_output_tensor = linear(input_tensor)
mm = MMWeightMxfp4ActMxfp4(weight, bias)
output_tensor = mm.apply(input_tensor)
# print(f"ref_output_tensor: {ref_output_tensor}")
# print(f"output_tensor: {output_tensor}")
# cosine
cos = torch.nn.functional.cosine_similarity(ref_output_tensor.flatten(), output_tensor.flatten(), dim=0)
print(f"cos : {cos}")
if __name__ == "__main__":
test_sizes = [
(32130, 5120, 5120),
(512, 5120, 5120),
(257, 5120, 5120),
(32130, 5120, 13824),
(32130, 13824, 5120),
(75348, 5120, 5120),
(75348, 13824, 5120),
(32760, 1536, 1536),
(512, 1536, 1536),
(32760, 1536, 8960),
(32760, 8960, 1536),
]
for i, (m, k, n) in enumerate(test_sizes):
print("-" * 30)
print(f"测试 {i + 1}: 张量大小 ({m}, {k}, {n})")
test_accuracy(m, k, n)
test_speed(m, k, n)
import unittest
import torch
from lightx2v_kernel.gemm import cutlass_scaled_mxfp4_mm
from lightx2v_kernel.gemm import scaled_mxfp4_quant
from torch.nn.functional import linear
from lightx2v_kernel.utils import error, benchmark
class TestQuantBF162MXFP4(unittest.TestCase):
def setUp(self):
self.tokens = [128, 257, 512, 1024, 13325, 32130, 32760] # , 75348
self.channels = [128, 1536, 5120, 8960] # , 13824
self.hiddenDims = [128, 1536, 3072, 5120, 8960, 12800] # , 13824
self.device = "cuda"
self.dtype = torch.bfloat16
def test_accuracy(self):
"""Test the accuracy of quantization from BF16 to MXFP4."""
for m in self.tokens:
for k in self.hiddenDims:
for n in self.channels:
with self.subTest(shape=[m, k, n]):
activation = torch.randn(m, k, dtype=self.dtype, device=self.device)
activation_quant_pred, activation_scale_pred = scaled_mxfp4_quant(activation)
weight = torch.randn(n, k, dtype=self.dtype, device=self.device)
weight_quant_pred, weight_scale_pred = scaled_mxfp4_quant(weight)
bias = torch.rand(1, n, dtype=self.dtype, device=self.device) * 10
alpha = torch.tensor(1.0, device=self.device, dtype=torch.float32)
mm_pred = cutlass_scaled_mxfp4_mm(activation_quant_pred, weight_quant_pred, activation_scale_pred, weight_scale_pred, alpha=alpha, bias=bias)
mm_real = linear(activation, weight, bias=bias).to(torch.bfloat16)
# mxfp4_mxfp4 mm has very low accuracy, so we set the threshold to 3e-2.
self.assertTrue(error(mm_pred, mm_real) < 3e-2, f"Accuracy test failed for shape {m, k, n}: Error {error(mm_pred, mm_real)} exceeds threshold.")
def test_performance(self):
"""Benchmark the performance of Activation quantization from BF16 to MXFP4."""
for m in self.tokens:
for k in self.hiddenDims:
with self.subTest(shape=[m, k]):
input = torch.randn(m, k, dtype=self.dtype, device=self.device)
shape = [m, k]
tflops = 2 * (m * k / 1024**4)
benchmark(scaled_mxfp4_quant, shape, tflops, 100, input)
if __name__ == "__main__":
unittest.main()
 import torch
-from lightx2v_kernel.gemm import scaled_fp8_quant, scaled_fp6_quant, cutlass_scaled_mxfp6_mxfp8_mm
+from lightx2v_kernel.gemm import scaled_mxfp8_quant, scaled_mxfp6_quant, cutlass_scaled_mxfp6_mxfp8_mm
 import time

-class MMWeightMxfp8ActMxfp6:
+class MMWeightMxfp6ActMxfp8:
     def __init__(self, weight, bias):
         self.load_fp6_weight(weight, bias)
         self.act_quant_func = self.act_quant_fp8
@@ -17,7 +17,7 @@ class MMWeightMxfp8ActMxfp6:
     @torch.no_grad()
     def load_fp6_weight(self, weight, bias):
-        self.weight, self.weight_scale = scaled_fp6_quant(weight)
+        self.weight, self.weight_scale = scaled_mxfp6_quant(weight)
         self.bias = bias

     def set_alpha(self):
@@ -25,7 +25,7 @@ class MMWeightMxfp8ActMxfp6:
     @torch.no_grad()
     def act_quant_fp8(self, x):
-        return scaled_fp8_quant(x)
+        return scaled_mxfp8_quant(x)

 def test_speed(m, k, n):
@@ -35,7 +35,7 @@ def test_speed(m, k, n):
         # bias = torch.randn(1, n, dtype=torch.bfloat16).cuda()
         bias = None

-        mm = MMWeightMxfp8ActMxfp6(weight, bias)
+        mm = MMWeightMxfp6ActMxfp8(weight, bias)

         # warmup
         output_tensor = mm.apply(input_tensor)
@@ -87,7 +87,7 @@ def test_accuracy(m, k, n):
         ref_output_tensor = linear(input_tensor)

-        mm = MMWeightMxfp8ActMxfp6(weight, bias)
+        mm = MMWeightMxfp6ActMxfp8(weight, bias)
         output_tensor = mm.apply(input_tensor)
...
 import torch
 import time

-from test_bench import MMWeightMxfp8ActMxfp6
+from test_bench import MMWeightMxfp6ActMxfp8

 def test_speed(m, k, n):
@@ -9,7 +9,7 @@ def test_speed(m, k, n):
         weight = torch.randn(n, k, dtype=torch.bfloat16, device="cuda")
         bias = torch.randn(1, n, dtype=torch.bfloat16).cuda()

-        mm = MMWeightMxfp8ActMxfp6(weight, bias)
+        mm = MMWeightMxfp6ActMxfp8(weight, bias)

         # warmup
         output_tensor = mm.apply(input_tensor)
@@ -60,7 +60,7 @@ def test_accuracy(m, k, n):
         ref_output_tensor = linear(input_tensor)

-        mm = MMWeightMxfp8ActMxfp6(weight, bias)
+        mm = MMWeightMxfp6ActMxfp8(weight, bias)
         output_tensor = mm.apply(input_tensor)
...