Commit 5103aef7 authored by Xtra's avatar Xtra Committed by GitHub
Browse files

add mxfp6 quant kernel and some tests (#126)

parent 514ea716
...@@ -94,6 +94,7 @@ set(SOURCES ...@@ -94,6 +94,7 @@ set(SOURCES
"csrc/gemm/nvfp4_scaled_mm_kernels_sm120.cu" "csrc/gemm/nvfp4_scaled_mm_kernels_sm120.cu"
"csrc/gemm/nvfp4_quant_kernels_sm120.cu" "csrc/gemm/nvfp4_quant_kernels_sm120.cu"
"csrc/gemm/mxfp8_quant_kernels_sm120.cu" "csrc/gemm/mxfp8_quant_kernels_sm120.cu"
"csrc/gemm/mxfp6_quant_kernels_sm120.cu"
"csrc/gemm/mxfp6_mxfp8_scaled_mm_kernels_sm120.cu" "csrc/gemm/mxfp6_mxfp8_scaled_mm_kernels_sm120.cu"
"csrc/gemm/mxfp8_scaled_mm_kernels_sm120.cu" "csrc/gemm/mxfp8_scaled_mm_kernels_sm120.cu"
"csrc/common_extension.cc" "csrc/common_extension.cc"
......
...@@ -20,6 +20,10 @@ TORCH_LIBRARY_FRAGMENT(lightx2v_kernel, m) { ...@@ -20,6 +20,10 @@ TORCH_LIBRARY_FRAGMENT(lightx2v_kernel, m) {
"scaled_fp8_quant_sm120(Tensor! output, Tensor! input," "scaled_fp8_quant_sm120(Tensor! output, Tensor! input,"
" Tensor! output_scale) -> ()"); " Tensor! output_scale) -> ()");
m.impl("scaled_fp8_quant_sm120", torch::kCUDA, &scaled_fp8_quant_sm120); m.impl("scaled_fp8_quant_sm120", torch::kCUDA, &scaled_fp8_quant_sm120);
m.def(
"scaled_fp6_quant_sm120(Tensor! output, Tensor! input,"
" Tensor! output_scale) -> ()");
m.impl("scaled_fp6_quant_sm120", torch::kCUDA, &scaled_fp6_quant_sm120);
m.def( m.def(
"cutlass_scaled_mxfp6_mxfp8_mm_sm120(Tensor! out, Tensor mat_a, Tensor mat_b, Tensor scales_a, Tensor scales_b, Tensor " "cutlass_scaled_mxfp6_mxfp8_mm_sm120(Tensor! out, Tensor mat_a, Tensor mat_b, Tensor scales_a, Tensor scales_b, Tensor "
......
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_fp8.h>
#include <cuda_fp6.h>
#include <cuda_runtime.h>
#include <cuda_runtime_api.h>
#include <torch/all.h>
#include "utils.h"
// Get type2 from type or vice versa (applied to half and bfloat16)
template <typename T>
struct TypeConverter {
using Type = half2;
}; // keep for generality
template <>
struct TypeConverter<half2> {
using Type = half;
};
template <>
struct TypeConverter<half> {
using Type = half2;
};
template <>
struct TypeConverter<__nv_bfloat162> {
using Type = __nv_bfloat16;
};
template <>
struct TypeConverter<__nv_bfloat16> {
using Type = __nv_bfloat162;
};
#define ELTS_PER_THREAD 8
constexpr int CVT_FP6_ELTS_PER_THREAD = 8;
constexpr int CVT_FP6_SF_VEC_SIZE = 32;
struct uint8x6_t {
uint8_t elts[6];
};
// Convert 4 float2 values into 8 e3m2 values (represented as one uint8x6_t).
inline __device__ uint8x6_t fp32_vec_to_e3m2(float2 (&array)[4]) {
uint64_t val;
asm volatile(
"{\n"
".reg .b16 pack0;\n"
".reg .b16 pack1;\n"
".reg .b16 pack2;\n"
".reg .b16 pack3;\n"
"cvt.rn.satfinite.e3m2x2.f32 pack0, %2, %1;\n"
"cvt.rn.satfinite.e3m2x2.f32 pack1, %4, %3;\n"
"cvt.rn.satfinite.e3m2x2.f32 pack2, %6, %5;\n"
"cvt.rn.satfinite.e3m2x2.f32 pack3, %8, %7;\n"
"mov.b64 %0, {pack0, pack1, pack2, pack3};\n"
"}"
: "=l"(val)
: "f"(array[0].x),
"f"(array[0].y),
"f"(array[1].x),
"f"(array[1].y),
"f"(array[2].x),
"f"(array[2].y),
"f"(array[3].x),
"f"(array[3].y));
uint8x6_t result;
// pack 8 uint8_t into 6 uint8_t
// here is how to pack:
// 4个fp6 a b c d. a:[a5 a4 a3 a2 a1 a0], b..., c..., d...
// 3个unint8 pack0 pack1 pack2
// packed0: [b1 b0][a5 a4 a3 a2 a1 a0]
// packed1: [c3 c2 c1 c0][b5 b4 b3 b2]
// packed2: [d5 d4 d3 d2 d1 d0][c5 c4]
// lower 4 uint8_t
uint8_t l_val_0 = val & 0xFF;
uint8_t l_val_1 = (val >> 8) & 0xFF;
uint8_t l_val_2 = (val >> 16) & 0xFF;
uint8_t l_val_3 = (val >> 24) & 0xFF;
// higher 4 uint8_t
uint8_t h_val_0 = (val >> 32) & 0xFF;
uint8_t h_val_1 = (val >> 40) & 0xFF;
uint8_t h_val_2 = (val >> 48) & 0xFF;
uint8_t h_val_3 = (val >> 56) & 0xFF;
// pack result
result.elts[0] = (l_val_1 << 6) | l_val_0;
result.elts[1] = (l_val_2 << 4) | (l_val_1 >> 2);
result.elts[2] = (l_val_3 << 2) | (l_val_2 >> 4);
result.elts[3] = (h_val_1 << 6) | h_val_0;
result.elts[4] = (h_val_2 << 4) | (h_val_1 >> 2);
result.elts[5] = (h_val_3 << 2) | (h_val_2 >> 4);
return result;
}
// Fast reciprocal.
inline __device__ float reciprocal_approximate_ftz(float a) {
float b;
asm volatile("rcp.approx.ftz.f32 %0, %1;\n" : "=f"(b) : "f"(a));
return b;
}
template <class SFType, int CVT_FP6_NUM_THREADS_PER_SF>
__device__ uint8_t* get_sf_out_address(int rowIdx, int colIdx, int numCols, SFType* SFout) {
// #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
static_assert(CVT_FP6_NUM_THREADS_PER_SF == 4);
// one of 4 threads write one SF to global memory.
// TODO: stage through smem for packed STG.32
// is it better than STG.8 from 4 threads ?
if (threadIdx.x % CVT_FP6_NUM_THREADS_PER_SF == 0) {
// SF vector index (32 elements share one SF in the K dimension).
int32_t kIdx = colIdx / CVT_FP6_NUM_THREADS_PER_SF;
int32_t mIdx = rowIdx;
// SF layout [numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4(kTile)]
// --> index [mTileIdx, kTileIdx, outerMIdx, innerMIdx, innerKIdx]
int32_t mTileIdx = mIdx / (32 * 4);
// SF vector size 32.
int factor = CVT_FP6_SF_VEC_SIZE * 4;
int32_t numKTiles = (numCols + factor - 1) / factor;
int64_t mTileStride = numKTiles * 32 * 4 * 4;
int32_t kTileIdx = (kIdx / 4);
int64_t kTileStride = 32 * 4 * 4;
// M tile layout [32, 4] is column-major.
int32_t outerMIdx = (mIdx % 32); // same as (mIdx % 128) % 32
int64_t outerMStride = 4 * 4;
int32_t innerMIdx = (mIdx % (32 * 4)) / 32;
int64_t innerMStride = 4;
int32_t innerKIdx = (kIdx % 4);
int64_t innerKStride = 1;
// Compute the global offset.
int64_t SFOffset = mTileIdx * mTileStride + kTileIdx * kTileStride + outerMIdx * outerMStride +
innerMIdx * innerMStride + innerKIdx * innerKStride;
return reinterpret_cast<uint8_t*>(SFout) + SFOffset;
} else {
// Other threads do not write to SFout.
return nullptr;
}
}
// Define a 16 bytes packed data type.
template <class Type>
struct PackedVec {
typename TypeConverter<Type>::Type elts[4];
};
// template <>
// struct PackedVec<__nv_fp8_e4m3> {
// __nv_fp8x2_e4m3 elts[8];
// };
template <class Type> // Type can be half or bfloat16
__device__ uint8x6_t cvt_warp_fp16_to_fp6(PackedVec<Type>& vec, uint8_t* SFout) {
// #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
// Get absolute maximum values among the local 8 values.
auto localMax = __habs2(vec.elts[0]);
// Local maximum value.
#pragma unroll
for (int i = 1; i < CVT_FP6_ELTS_PER_THREAD / 2; i++) {
localMax = __hmax2(localMax, __habs2(vec.elts[i]));
}
// Get the absolute maximum among all 32 values (four threads).
localMax = __hmax2(__shfl_xor_sync(uint32_t(-1), localMax, 1), localMax);
localMax = __hmax2(__shfl_xor_sync(uint32_t(-1), localMax, 2), localMax);
// Get the final absolute maximum values.
float vecMax = float(__hmax(localMax.x, localMax.y));
// Get the SF (max value of the vector / max value of e3m2).
// maximum value of e3m2 = 28.0.
// TODO: use half as compute data type.
float SFValue = (vecMax / 28.0f);
// 8 bits representation of the SF.
uint8_t fp8SFVal;
// Write the SF to global memory (STG.8).
__nv_fp8_e8m0 tmp;
tmp.__x = __nv_cvt_float_to_e8m0(SFValue, __NV_SATFINITE, cudaRoundPosInf);
SFValue = static_cast<float>(tmp);
fp8SFVal = tmp.__x;
float outputScale =
SFValue != 0 ? reciprocal_approximate_ftz(SFValue) : 0.0f;
if (SFout) {
// Write the SF to global memory (STG.8).
*SFout = fp8SFVal;
}
// Convert the input to float.
float2 fp2Vals[CVT_FP6_ELTS_PER_THREAD / 2];
#pragma unroll
for (int i = 0; i < CVT_FP6_ELTS_PER_THREAD / 2; i++) {
if constexpr (std::is_same_v<Type, half>) {
fp2Vals[i] = __half22float2(vec.elts[i]);
} else {
fp2Vals[i] = __bfloat1622float2(vec.elts[i]);
}
fp2Vals[i].x *= outputScale;
fp2Vals[i].y *= outputScale;
}
// Convert to e3m2 values.
uint8x6_t e3m2Vec = fp32_vec_to_e3m2(fp2Vals);
return e3m2Vec;
}
template <class Type> // Type can be half or bfloat16
__global__ void
// #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
__launch_bounds__(256, 6) cvt_fp16_to_fp6(
// #else
// cvt_fp16_to_fp6(
// #endif
int32_t numRows, int32_t numCols, Type const* in, uint8x6_t* out, uint32_t* SFout) {
// #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
using PackedVec = PackedVec<Type>;
static constexpr int CVT_FP6_NUM_THREADS_PER_SF = (CVT_FP6_SF_VEC_SIZE / CVT_FP6_ELTS_PER_THREAD);
static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP6_ELTS_PER_THREAD, "Vec size is not matched.");
// Input tensor row/col loops.
for (int rowIdx = blockIdx.x; rowIdx < numRows; rowIdx += gridDim.x) {
for (int colIdx = threadIdx.x; colIdx < numCols / CVT_FP6_ELTS_PER_THREAD; colIdx += blockDim.x) {
int64_t inOffset = rowIdx * (numCols / CVT_FP6_ELTS_PER_THREAD) + colIdx;
PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
// Get the output tensor offset.
// Same as inOffset because 8 elements(E3M2) are packed into one uint8x6_t.
int64_t outOffset = inOffset;
auto& out_pos = out[outOffset];
auto sf_out =
get_sf_out_address<uint32_t, CVT_FP6_NUM_THREADS_PER_SF>(rowIdx, colIdx, numCols, SFout);
out_pos = cvt_warp_fp16_to_fp6<Type>(in_vec, sf_out);
}
}
// #endif
}
template <typename T>
void invokeFP6Quantization(
int m,
int n,
T const* input,
int64_t* output,
int32_t* SFOuput,
int multiProcessorCount,
cudaStream_t stream) {
// Grid, Block size.
// Each thread converts 8 values.
dim3 block(std::min(int(n / ELTS_PER_THREAD), 256));
// Get number of blocks per SM (assume we can fully utilize the SM).
int const numBlocksPerSM = 1536 / block.x;
dim3 grid(std::min(int(m), multiProcessorCount * numBlocksPerSM));
// Launch the cvt kernel.
cvt_fp16_to_fp6<T>
<<<grid, block, 0, stream>>>(
m, n, input, reinterpret_cast<uint8x6_t*>(output), reinterpret_cast<uint32_t*>(SFOuput));
}
// Instantiate the function.
template void invokeFP6Quantization(
int m,
int n,
half const* input,
int64_t* output,
int32_t* SFOuput,
int multiProcessorCount,
cudaStream_t stream);
template void invokeFP6Quantization(
int m,
int n,
__nv_bfloat16 const* input,
int64_t* output,
int32_t* SFOuput,
int multiProcessorCount,
cudaStream_t stream);
inline int getMultiProcessorCount() {
static int multi_processor_count = []() {
int device_id = 0;
int count = 0;
// Get the current CUDA device ID
CHECK_CUDA_SUCCESS(cudaGetDevice(&device_id));
// Get the number of multiprocessors for the current device
CHECK_CUDA_SUCCESS(cudaDeviceGetAttribute(&count, cudaDevAttrMultiProcessorCount, device_id));
return count; // Initialize the static variable
}();
return multi_processor_count; // Return the cached value on subsequent calls
}
void scaled_fp6_quant_sm120(
torch::Tensor& output, torch::Tensor const& input, torch::Tensor& output_sf) {
int32_t m = input.size(0);
int32_t n = input.size(1);
TORCH_CHECK(n % 32 == 0, "The N dimension must be multiple of 32.");
int multiProcessorCount = getMultiProcessorCount();
auto sf_out = static_cast<int32_t*>(output_sf.data_ptr());
auto output_ptr = static_cast<int64_t*>(output.data_ptr());
at::cuda::CUDAGuard device_guard{(char)input.get_device()};
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(input.get_device());
switch (input.scalar_type()) {
case torch::kHalf: {
auto input_ptr = reinterpret_cast<half const*>(input.data_ptr());
invokeFP6Quantization(m, n, input_ptr, output_ptr, sf_out, multiProcessorCount, stream);
break;
}
case torch::kBFloat16: {
auto input_ptr = reinterpret_cast<__nv_bfloat16 const*>(input.data_ptr());
invokeFP6Quantization(m, n, input_ptr, output_ptr, sf_out, multiProcessorCount, stream);
break;
}
default: {
std::cerr << "Observing: " << input.scalar_type() << " for the input datatype which is invalid";
throw std::runtime_error("Unsupported input data type for quantize_to_fp6.");
}
}
}
...@@ -287,7 +287,7 @@ void scaled_fp8_quant_sm120( ...@@ -287,7 +287,7 @@ void scaled_fp8_quant_sm120(
int32_t m = input.size(0); int32_t m = input.size(0);
int32_t n = input.size(1); int32_t n = input.size(1);
TORCH_CHECK(n % 32 == 0, "The N dimension must be multiple of 16."); TORCH_CHECK(n % 32 == 0, "The N dimension must be multiple of 32.");
int multiProcessorCount = getMultiProcessorCount(); int multiProcessorCount = getMultiProcessorCount();
......
...@@ -60,6 +60,8 @@ void scaled_fp4_quant_sm120( ...@@ -60,6 +60,8 @@ void scaled_fp4_quant_sm120(
void scaled_fp8_quant_sm120( void scaled_fp8_quant_sm120(
torch::Tensor& output, torch::Tensor const& input, torch::Tensor& output_sf); torch::Tensor& output, torch::Tensor const& input, torch::Tensor& output_sf);
void scaled_fp6_quant_sm120(
torch::Tensor& output, torch::Tensor const& input, torch::Tensor& output_sf);
void cutlass_scaled_mxfp6_mxfp8_mm_sm120( void cutlass_scaled_mxfp6_mxfp8_mm_sm120(
torch::Tensor& D, torch::Tensor& D,
......
...@@ -48,13 +48,26 @@ def scaled_fp4_quant(input: torch.Tensor, input_global_scale: torch.Tensor): ...@@ -48,13 +48,26 @@ def scaled_fp4_quant(input: torch.Tensor, input_global_scale: torch.Tensor):
# rounded_m = ((m + 128 - 1) // 128) * 128 # rounded_m = ((m + 128 - 1) // 128) * 128
# scale_n = n // block_size # scale_n = n // block_size
# rounded_n = ((scale_n + 4 - 1) // 4) * 4 # rounded_n = ((scale_n + 4 - 1) // 4) * 4
output_scale = torch.empty((((m + 128 - 1) // 128) * 128, (n // block_size + 4 - 1) // 4), device=device, dtype=torch.int32) output_scale = torch.zeros((((m + 128 - 1) // 128) * 128, (n // block_size + 4 - 1) // 4), device=device, dtype=torch.int32)
torch.ops.lightx2v_kernel.scaled_fp4_quant_sm120.default(output, input, output_scale, input_global_scale) torch.ops.lightx2v_kernel.scaled_fp4_quant_sm120.default(output, input, output_scale, input_global_scale)
output_scale = output_scale.view(torch.float8_e4m3fn) output_scale = output_scale.view(torch.float8_e4m3fn)
return output, output_scale return output, output_scale
def scaled_fp6_quant(input: torch.Tensor):
m, n = input.shape
block_size = 32
device = input.device
output = torch.empty((m, 3 * n // 4), device=device, dtype=torch.uint8)
output_scale = torch.zeros(((m + 128 - 1) // 128 * 128, (n // block_size + 4 - 1) // 4), device=device, dtype=torch.int32)
torch.ops.lightx2v_kernel.scaled_fp6_quant_sm120.default(output, input, output_scale)
output_scale = output_scale.view(torch.float8_e8m0fnu)
return output, output_scale
def scaled_fp8_quant(input: torch.Tensor): def scaled_fp8_quant(input: torch.Tensor):
m, n = input.shape m, n = input.shape
block_size = 32 block_size = 32
......
import torch
from lightx2v_kernel.gemm import scaled_fp8_quant, scaled_fp6_quant, cutlass_scaled_mxfp6_mxfp8_mm
import time
class MMWeightMxfp8ActMxfp6:
def __init__(self, weight, bias):
self.load_fp6_weight(weight, bias)
self.act_quant_func = self.act_quant_fp8
self.set_alpha()
@torch.no_grad()
def apply(self, input_tensor):
input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
output_tensor = cutlass_scaled_mxfp6_mxfp8_mm(input_tensor_quant, self.weight, input_tensor_scale, self.weight_scale, alpha=self.alpha, bias=self.bias)
return output_tensor
@torch.no_grad()
def load_fp6_weight(self, weight, bias):
self.weight, self.weight_scale = scaled_fp6_quant(weight)
self.bias = bias
def set_alpha(self):
self.alpha = torch.tensor(1.0, dtype=torch.float32, device=self.weight.device)
@torch.no_grad()
def act_quant_fp8(self, x):
return scaled_fp8_quant(x)
def test_speed(m, k, n):
with torch.no_grad():
input_tensor = torch.randn(m, k, dtype=torch.bfloat16).cuda()
weight = torch.randn(n, k, dtype=torch.bfloat16, device="cuda")
# bias = torch.randn(1, n, dtype=torch.bfloat16).cuda()
bias = None
mm = MMWeightMxfp8ActMxfp6(weight, bias)
# warmup
output_tensor = mm.apply(input_tensor)
torch.cuda.synchronize()
start_time = time.time()
for i in range(100):
output_tensor = mm.apply(input_tensor)
torch.cuda.synchronize()
end_time = time.time()
lightx2v_kernel_time = (end_time - start_time) / 100
print(f"lightx2v-kernel time: {lightx2v_kernel_time}")
input_tensor = torch.randn(m, n, dtype=torch.bfloat16).cuda()
weight = torch.randn(k, n, dtype=torch.bfloat16, device="cuda")
bias = torch.randn(1, k, dtype=torch.bfloat16).cuda()
linear = torch.nn.Linear(k, n, bias=False).cuda()
linear.weight.data = weight
# linear.bias.data = bias
# warmup
ref_output_tensor = linear(input_tensor)
torch.cuda.synchronize()
start_time = time.time()
for i in range(100):
ref_output_tensor = linear(input_tensor)
torch.cuda.synchronize()
end_time = time.time()
ref_time = (end_time - start_time) / 100
print(f"ref time: {ref_time}")
print(f"speedup: {ref_time / lightx2v_kernel_time:.3f}")
def test_accuracy(m, k, n):
with torch.no_grad():
input_tensor = torch.randn(m, k, dtype=torch.bfloat16).cuda()
weight = torch.randn(n, k, dtype=torch.bfloat16, device="cuda")
# bias = torch.randn(1, n, dtype=torch.bfloat16).cuda()
bias = None
linear = torch.nn.Linear(k, n, bias=False).cuda()
linear.weight.data = weight
# linear.bias.data = bias
ref_output_tensor = linear(input_tensor)
mm = MMWeightMxfp8ActMxfp6(weight, bias)
output_tensor = mm.apply(input_tensor)
# print(f"ref_output_tensor: {ref_output_tensor}")
# print(f"output_tensor: {output_tensor}")
# cosine
cos = torch.nn.functional.cosine_similarity(ref_output_tensor.flatten(), output_tensor.flatten(), dim=0)
print(f"cos : {cos}")
if __name__ == "__main__":
test_sizes = [
(32130, 5120, 5120),
(512, 5120, 5120),
(257, 5120, 5120),
(32130, 5120, 13824),
(32130, 13824, 5120),
(75348, 5120, 5120),
(75348, 13824, 5120),
(32760, 1536, 1536),
(512, 1536, 1536),
(32760, 1536, 8960),
(32760, 8960, 1536),
]
for i, (m, k, n) in enumerate(test_sizes):
print("-" * 30)
print(f"测试 {i + 1}: 张量大小 ({m}, {k}, {n})")
test_accuracy(m, k, n)
test_speed(m, k, n)
import torch
from torchao.prototype.mx_formats.constants import DTYPE_FP6_E3M2
from torchao.prototype.mx_formats.mx_tensor import to_mx, pack_uint6
def quant2mxfp8(x: torch.Tensor):
block_size = 32
m, _ = x.shape
scale, output = to_mx(x, torch.float8_e4m3fn, block_size=block_size)
return scale.reshape(m, -1), output
def quant2mxfp6(x: torch.Tensor):
block_size = 32
m, _ = x.shape
scale, output = to_mx(x, DTYPE_FP6_E3M2, block_size=block_size, pack_fp6=False)
return scale.reshape(m, -1), output
def scale_pad_and_swizzle(scale: torch.Tensor):
m, s = scale.shape
# pad the m up to 128, s up to 4
padded_m = (m + 127) // 128 * 128
padded_s = (s + 3) // 4 * 4
padded_scale = torch.empty(padded_m, padded_s, device=scale.device, dtype=scale.dtype)
padded_scale[:m, :s] = scale
# swizzle the padded scale
swizzled_scale = padded_scale.reshape(padded_m // 128, 128, padded_s // 4, 4).reshape(padded_m // 128, 4, 32, padded_s // 4, 4).permute(0, 3, 2, 1, 4)
return swizzled_scale.reshape(padded_m, padded_s)
###############################################################
# Packing kernel and func
###############################################################
import triton # noqa: E402
import triton.language as tl # noqa: E402
@triton.autotune(
configs=[
triton.Config({"BLOCK_SIZE_IN": 2}, num_warps=1),
triton.Config({"BLOCK_SIZE_IN": 4}, num_warps=1),
triton.Config({"BLOCK_SIZE_IN": 8}, num_warps=1),
triton.Config({"BLOCK_SIZE_IN": 16}, num_warps=1),
],
key=["n_mx_blocks"],
)
@triton.jit
def triton_pack_uint6_kernel(
input_ptr,
output_ptr,
n_mx_blocks,
MX_BLOCK_SIZE: tl.constexpr,
PACKED_MX_BLOCK_SIZE: tl.constexpr,
BLOCK_SIZE_IN: tl.constexpr,
):
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE_IN
# input_ptr is shape [n_mx_blocks, MX_BLOCK_SIZE]
# Load BLOCK_SIZE rows of input_ptr
offsets_rows = block_start + tl.arange(0, BLOCK_SIZE_IN)
offsets_cols = tl.arange(0, MX_BLOCK_SIZE // 4)
offsets = offsets_rows[:, None] * MX_BLOCK_SIZE + (4 * offsets_cols[None, :])
mask = (offsets_rows[:, None] < n_mx_blocks) & (offsets_cols[None, :] < MX_BLOCK_SIZE // 4)
# x is shape [BLOCK_SIZE, MX_BLOCK_SIZE]
x_0 = tl.load(input_ptr + offsets, mask=mask)
x_1 = tl.load(input_ptr + offsets + 1, mask=mask)
x_2 = tl.load(input_ptr + offsets + 2, mask=mask)
x_3 = tl.load(input_ptr + offsets + 3, mask=mask)
# 4个fp6 a b c d. a:[a5 a4 a3 a2 a1 a0], b..., c..., d...
# 3个unint8 pack0 pack1 pack2
# cutlass需要的:
# packed0: [b1 b0][a5 a4 a3 a2 a1 a0]
# packed1: [c3 c2 c1 c0][b5 b4 b3 b2]
# packed2: [d5 d4 d3 d2 d1 d0][c5 c4]
bits_packed0 = (x_1 << 6) | x_0
bits_packed1 = (x_2 << 4) | (x_1 >> 2)
bits_packed2 = (x_3 << 2) | (x_2 >> 4)
# Store values in a uint8 tensor of length `3 * MX_BLOCK_SIZE / 4`
offsets_out_4_a = offsets_rows[:, None] * PACKED_MX_BLOCK_SIZE + 3 * offsets_cols[None, :]
offsets_out_4_b = offsets_rows[:, None] * PACKED_MX_BLOCK_SIZE + 3 * offsets_cols[None, :] + 1
offsets_out_2 = offsets_rows[:, None] * PACKED_MX_BLOCK_SIZE + 3 * offsets_cols[None, :] + 2
# Store into output tensor
tl.store(
output_ptr + offsets_out_4_a,
bits_packed0,
mask=mask,
)
tl.store(
output_ptr + offsets_out_4_b,
bits_packed1,
mask=mask,
)
tl.store(
output_ptr + offsets_out_2,
bits_packed2,
mask=mask,
)
def pack_uint6(uint8_data: torch.Tensor) -> torch.Tensor:
# ensure input data is contiguous before passing to kernel
assert uint8_data.is_contiguous()
# tensor should already be of shape [..., mx_block_size]
mx_block_size = uint8_data.shape[-1]
assert mx_block_size % 4 == 0
# effective mx block size since we're packing 2 fp4 into 1 uint8
packed_mx_block_size = 3 * mx_block_size // 4
packed_shape = [uint8_data.shape[0], packed_mx_block_size]
n_mx_blocks = uint8_data.numel() // mx_block_size
grid = lambda meta: (triton.cdiv(n_mx_blocks, meta["BLOCK_SIZE_IN"]),) # noqa: E731
# contiguous uint8 container in which we can store the unpacked tensor
packed_uint8_data = torch.empty(packed_shape, dtype=torch.uint8, device=uint8_data.device)
triton_pack_uint6_kernel[grid](
uint8_data,
packed_uint8_data,
n_mx_blocks,
MX_BLOCK_SIZE=mx_block_size,
PACKED_MX_BLOCK_SIZE=packed_mx_block_size,
)
return packed_uint8_data
M = [257, 512, 1024, 13325, 32130, 32760] # , 75348
N = [1536, 5120, 8960] # , 13824
K = [128, 256, 512, 1024, 2048, 4096] # , 13824
for m in M:
for n in N:
for k in K:
x = torch.randn(m, k, device="cuda", dtype=torch.bfloat16)
w = torch.randn(n, k, device="cuda", dtype=torch.bfloat16)
# excute quant
x_scale, x_quant = quant2mxfp8(x)
w_scale, w_quant = quant2mxfp6(w)
# pack fp6 for cutlass
w_quant_packed = pack_uint6(w_quant.reshape(-1, 32))
# pad and swizzle scale
padded_and_swizzled_x_scale = scale_pad_and_swizzle(x_scale)
padded_and_swizzled_w_scale = scale_pad_and_swizzle(w_scale)
# ref mm result
ref_mm = torch.nn.functional.linear(x, w).to(torch.bfloat16)
# custom scaled mm
from lightx2v_kernel.gemm import cutlass_scaled_mxfp6_mxfp8_mm
alpha = torch.tensor(1.0, device="cuda", dtype=torch.float32)
bias = None
x_quant = x_quant.reshape(m, k).view(torch.uint8)
w_quant_packed = w_quant_packed.reshape(n, 3 * k // 4)
custom_mm = cutlass_scaled_mxfp6_mxfp8_mm(x_quant, w_quant_packed, padded_and_swizzled_x_scale, padded_and_swizzled_w_scale, alpha, bias)
# cal snr
from lightx2v_kernel.utils import error
print(f"m: {m}, n: {n}, k: {k}, error: {error(ref_mm, custom_mm)}")
# cal cos
cos_sim = torch.nn.functional.cosine_similarity(ref_mm.flatten(), custom_mm.flatten(), dim=0)
print(f"m: {m}, n: {n}, k: {k}, cos_sim: {cos_sim}")
import torch
from lightx2v_kernel.gemm import cutlass_scaled_mxfp6_mxfp8_mm
"""
input_shape = (1024, 2048)
weight_shape = (4096, 2048)
input_tensor_quant = (torch.rand((1024, 1024), device="cuda") * 10).to(torch.uint8)
weight = (torch.rand((4096, 1024), device="cuda") * 10).to(torch.uint8)
input_tensor_scale = torch.rand(1024, 128, device="cuda").to(torch.float8_e8m0fnu)
weight_scale = torch.rand(4096, 128, device="cuda").to(torch.float8_e8m0fnu)
alpha = torch.tensor(1.0, device="cuda").to(torch.float32)
bias = None
"""
def test_mm(input_tensor_quant, weight, input_tensor_scale, weight_scale, alpha, bias):
output_tensor = cutlass_scaled_mxfp6_mxfp8_mm(input_tensor_quant, weight, input_tensor_scale, weight_scale, alpha=alpha, bias=bias)
return output_tensor
def test_tflops(input_shape, weight_shape, num_warmup=10, num_runs=100):
"""
测试test_mm函数的TFLOPS性能
"""
# 创建输入数据
input_tensor_quant = (torch.rand((input_shape[0], input_shape[1]), device="cuda") * 10).to(torch.uint8)
weight = (torch.rand((weight_shape[0], 3 * weight_shape[1] // 4), device="cuda") * 10).to(torch.uint8)
input_tensor_scale = torch.rand(((input_shape[0] + 128 - 1) // 128) * 128, (input_shape[1] // 32 + 4 - 1) // 4 * 4, device="cuda").to(torch.float8_e8m0fnu)
weight_scale = torch.rand(weight_shape[0], weight_shape[1] // 32, device="cuda").to(torch.float8_e8m0fnu)
alpha = torch.tensor(1.0, device="cuda", dtype=torch.float32)
bias = None
# 预热GPU
for _ in range(num_warmup):
test_mm(input_tensor_quant, weight, input_tensor_scale, weight_scale, alpha, bias)
# 同步GPU
torch.cuda.synchronize()
# 创建GPU事件用于精确计时
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
# 测量时间
start_event.record()
for _ in range(num_runs):
result = test_mm(input_tensor_quant, weight, input_tensor_scale, weight_scale, alpha, bias)
end_event.record()
# 同步并计算时间
torch.cuda.synchronize()
elapsed_time_ms = start_event.elapsed_time(end_event)
elapsed_time_s = elapsed_time_ms / 1000.0
# 计算FLOPS
# 矩阵乘法 A(M x K) @ B(K x N) = C(M x N)
# M = batch_size, K = input_dim, N = output_dim
M = input_shape[0]
K = input_shape[1]
N = weight_shape[0]
# 每次矩阵乘法的FLOPS = 2 * M * N * K (每个输出元素需要K次乘法和K次加法)
flops_per_run = 2 * M * N * K
total_flops = flops_per_run * num_runs
# 计算TFLOPS (万亿次浮点运算每秒)
tflops = total_flops / (elapsed_time_s * 1e12)
print(f"测试结果:")
print(f" 输入形状: {input_shape} (M={M}, K={K})")
print(f" 权重形状: {weight_shape} (N={N}, K={K})")
print(f" 输出形状: ({M}, {N})")
print(f" 运行次数: {num_runs}")
print(f" 总执行时间: {elapsed_time_ms:.2f} ms")
print(f" 平均每次执行时间: {elapsed_time_ms / num_runs:.4f} ms")
print(f" 每次运行FLOPS: {flops_per_run / 1e9:.2f} GFLOPS")
print(f" 总FLOPS: {total_flops / 1e12:.2f} TFLOPS")
print(f" 计算性能: {tflops:.2f} TFLOPS")
return tflops
if __name__ == "__main__":
# 测试不同大小的矩阵乘法
# (m,k) (n,k)
test_cases = [
((32130, 5120), (5120, 5120)),
((512, 1536), (1536, 1536)),
((512, 5120), (5120, 5120)),
((257, 5120), (5120, 5120)),
((32130, 5120), (13824, 5120)),
((32130, 13824), (5120, 13824)),
((75348, 5120), (5120, 5120)),
((75348, 5120), (13824, 5120)),
((75348, 13824), (5120, 13824)),
((32760, 1536), (1536, 1536)),
((512, 1536), (1536, 1536)),
((32760, 1536), (8960, 1536)),
((32760, 8960), (1536, 8960)),
]
print("=== test_mm TFLOPS性能测试 ===\n")
for i, (input_shape, weight_shape) in enumerate(test_cases):
print(f"测试 {i + 1}: 输入形状 {input_shape}, 权重形状 {weight_shape}")
print("-" * 60)
tflops = test_tflops(input_shape, weight_shape)
print(f"✓ 成功完成测试,性能: {tflops:.2f} TFLOPS\n")
print("=== 测试完成 ===")
import unittest
import torch
from lightx2v_kernel.gemm import cutlass_scaled_mxfp6_mxfp8_mm
from lightx2v_kernel.gemm import scaled_fp6_quant, scaled_fp8_quant
from torch.nn.functional import linear
from lightx2v_kernel.utils import error, benchmark
class TestQuantBF162MXFP6(unittest.TestCase):
def setUp(self):
self.tokens = [128, 257, 512, 1024, 13325, 32130, 32760] # , 75348
self.channels = [128, 1536, 5120, 8960] # , 13824
self.hiddenDims = [128, 1536, 3072, 5120, 8960, 12800] # , 13824
self.device = "cuda"
self.dtype = torch.bfloat16
def test_accuracy(self):
"""Test the accuracy of quantization from BF16 to MXFP6."""
for m in self.tokens:
for k in self.hiddenDims:
for n in self.channels:
with self.subTest(shape=[m, k, n]):
activation = torch.randn(m, k, dtype=self.dtype, device=self.device)
activation_quant_pred, activation_scale_pred = scaled_fp8_quant(activation)
weight = torch.randn(n, k, dtype=self.dtype, device=self.device)
weight_quant_pred, weight_scale_pred = scaled_fp6_quant(weight)
alpha = torch.tensor(1.0, device=self.device, dtype=torch.float32)
mm_pred = cutlass_scaled_mxfp6_mxfp8_mm(activation_quant_pred, weight_quant_pred, activation_scale_pred, weight_scale_pred, alpha=alpha)
mm_real = linear(activation, weight, bias=None).to(torch.bfloat16)
self.assertTrue(error(mm_pred, mm_real) < 1e-2, f"Accuracy test failed for shape {m, k, n}: Error {error(mm_pred, mm_real)} exceeds threshold.")
def test_performance(self):
"""Benchmark the performance of Activation quantization from BF16 to MXFP6."""
for m in self.tokens:
for k in self.hiddenDims:
with self.subTest(shape=[m, k]):
input = torch.randn(m, k, dtype=self.dtype, device=self.device)
shape = [m, k]
tflops = 2 * (m * k / 1024**4)
benchmark(scaled_fp6_quant, shape, tflops, 100, input)
if __name__ == "__main__":
unittest.main()
import torch
from lightx2v_kernel.gemm import scaled_fp6_quant
def quantize_fp6(x):
return scaled_fp6_quant(x)
def test_memory_bandwidth(func, x, num_warmup=10, num_runs=100):
"""
测试函数的显存带宽
"""
# 预热GPU
for _ in range(num_warmup):
func(x)
# 同步GPU
torch.cuda.synchronize()
# 创建GPU事件用于精确计时
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
# 测量时间
start_event.record()
for _ in range(num_runs):
result = func(x)
end_event.record()
# 同步并计算时间
torch.cuda.synchronize()
elapsed_time_ms = start_event.elapsed_time(end_event)
elapsed_time_s = elapsed_time_ms / 1000.0
# 计算数据量
input_bytes = x.numel() * x.element_size() # 输入数据字节数
# FP6量化后,每个元素占用 3/ 4字节
output_bytes = x.numel() * (3 / 4) # FP6输出数据字节数
scale_bytes = x.numel() / 32 # group_size = 32
# 总数据传输量(读取输入 + 写入输出 + scale)
total_bytes = (input_bytes + output_bytes + scale_bytes) * num_runs
# 计算带宽
bandwidth_gbps = (total_bytes / elapsed_time_s) / (1024**3) # GB/s
print(f"测试结果:")
print(f" 输入张量形状: {x.shape}")
print(f" 输入数据类型: {x.dtype}")
print(f" 运行次数: {num_runs}")
print(f" 总执行时间: {elapsed_time_ms:.2f} ms")
print(f" 平均每次执行时间: {elapsed_time_ms / num_runs:.4f} ms")
print(f" 输入数据大小: {input_bytes / (1024**2):.2f} MB")
print(f" 输出数据大小: {output_bytes / (1024**2):.2f} MB")
print(f" 总数据传输量: {total_bytes / (1024**3):.2f} GB")
print(f" 显存带宽: {bandwidth_gbps:.2f} GB/s")
return bandwidth_gbps
if __name__ == "__main__":
# 测试不同大小的张量
test_sizes = [
# (1, 1024),
# (1, 2048),
# (1, 4096),
# (1, 8192),
# (1, 16384),
# (1, 32768),
# (2, 1024),
# (2, 2048),
# (2, 4096),
# (2, 8192),
# (2, 16384),
# (2, 32768),
# (4, 1024),
# (4, 2048),
# (4, 4096),
# (4, 8192),
# (4, 16384),
# (4, 32768),
# (128, 1024),
# (128, 2048),
# (128, 4096),
# (128, 8192),
# (128, 16384),
# (128, 32768),
# (512, 1024),
# (512, 2048),
# (512, 4096),
# (512, 8192),
# (512, 16384),
# (512, 32768),
# (1024, 1024),
# (1024, 2048),
# (1024, 4096),
# (1024, 8192),
# (1024, 16384),
# (1024, 32768),
# (2048, 1024),
# (2048, 2048),
# (2048, 4096),
# (2048, 8192),
# (2048, 16384),
# (2048, 32768),
# (4096, 1024),
# (4096, 2048),
# (4096, 4096),
# (4096, 8192),
# (4096, 16384),
# (4096, 32768),
# (8192, 1024),
# (8192, 2048),
# (8192, 4096),
# (8192, 8192),
# (8192, 16384),
# (8192, 32768),
# (16384, 1024),
# (16384, 2048),
# (16384, 4096),
# (16384, 8192),
# (16384, 16384),
# (16384, 32768),
# (32768, 1024),
# (32768, 2048),
# (32768, 4096),
# (32768, 8192),
# (32768, 16384),
# (32768, 32768),
(32130, 5120),
(512, 5120),
(257, 5120),
(32130, 13824),
(75348, 5120),
(75348, 13824),
(32760, 1536),
(512, 3072),
(512, 1536),
(32760, 8960),
]
print("=== quantize_fp8 显存带宽测试 ===\n")
for i, (h, w) in enumerate(test_sizes):
print(f"测试 {i + 1}: 张量大小 ({h}, {w})")
print("-" * 50)
x = torch.randn(h, w, dtype=torch.bfloat16).cuda()
try:
bandwidth = test_memory_bandwidth(quantize_fp6, x)
print(f"✓ 成功完成测试,带宽: {bandwidth:.2f} GB/s\n")
except Exception as e:
print(f"✗ 测试失败: {e}\n")
print("=== 测试完成 ===")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment