Unverified Commit 34cd32fe authored by Michael Goin's avatar Michael Goin Committed by GitHub
Browse files

[Perf][Kernel] Fused SiLU+Mul+Quant kernel for NVFP4 cutlass_moe (#31832)


Signed-off-by: default avatarmgoin <mgoin64@gmail.com>
Signed-off-by: default avatarMichael Goin <mgoin64@gmail.com>
parent 8e27663b
...@@ -301,6 +301,12 @@ void scaled_fp4_experts_quant( ...@@ -301,6 +301,12 @@ void scaled_fp4_experts_quant(
torch::Tensor const& input_offset_by_experts, torch::Tensor const& input_offset_by_experts,
torch::Tensor const& output_scale_offset_by_experts); torch::Tensor const& output_scale_offset_by_experts);
void silu_and_mul_scaled_fp4_experts_quant(
torch::Tensor& output, torch::Tensor& output_scale,
torch::Tensor const& input, torch::Tensor const& input_global_scale,
torch::Tensor const& input_offset_by_experts,
torch::Tensor const& output_scale_offset_by_experts);
void per_token_group_quant_fp8(const torch::Tensor& input, void per_token_group_quant_fp8(const torch::Tensor& input,
torch::Tensor& output_q, torch::Tensor& output_s, torch::Tensor& output_q, torch::Tensor& output_s,
int64_t group_size, double eps, double fp8_min, int64_t group_size, double eps, double fp8_min,
......
...@@ -31,37 +31,6 @@ ...@@ -31,37 +31,6 @@
namespace vllm { namespace vllm {
// silu in float32
__device__ __forceinline__ float silu(float x) {
return __fdividef(x, (1.f + __expf(-x)));
}
__device__ __forceinline__ float2 silu2(float2 x) {
return make_float2(silu(x.x), silu(x.y));
}
template <class Type>
__inline__ __device__ PackedVec<Type> compute_silu_mul(PackedVec<Type>& vec,
PackedVec<Type>& vec2) {
PackedVec<Type> result;
using packed_type = typename TypeConverter<Type>::Type;
#pragma unroll
for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; ++i) {
// silu_mul in float32
if constexpr (std::is_same_v<Type, half>) {
float2 silu_vec = silu2(__half22float2(vec.elts[i]));
result.elts[i] =
__float22half2_rn(__fmul2_rn(silu_vec, __half22float2(vec2.elts[i])));
} else {
float2 silu_vec = silu2(__bfloat1622float2(vec.elts[i]));
result.elts[i] = __float22bfloat162_rn(
__fmul2_rn(silu_vec, __bfloat1622float2(vec2.elts[i])));
}
}
return result;
}
// Use UE4M3 by default. // Use UE4M3 by default.
template <class Type, bool UE8M0_SF = false> template <class Type, bool UE8M0_SF = false>
__global__ void __launch_bounds__(1024, VLLM_BLOCKS_PER_SM(1024)) __global__ void __launch_bounds__(1024, VLLM_BLOCKS_PER_SM(1024))
......
...@@ -31,8 +31,12 @@ ...@@ -31,8 +31,12 @@
namespace vllm { namespace vllm {
// NVFP4 quantization kernel for experts (low-latency path).
// When FUSE_SILU_MUL=true, expects input with gate||up layout and fuses
// SiLU(gate)*up before quantization.
// Use UE4M3 by default. // Use UE4M3 by default.
template <class Type, bool UE8M0_SF = false, bool SMALL_NUM_EXPERTS = false> template <class Type, bool FUSE_SILU_MUL = false, bool UE8M0_SF = false,
bool SMALL_NUM_EXPERTS = false>
__global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512)) __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
cvt_fp16_to_fp4(int32_t numRows, int32_t numCols, Type const* in, cvt_fp16_to_fp4(int32_t numRows, int32_t numCols, Type const* in,
float const* SFScale, uint32_t* out, uint32_t* SFout, float const* SFScale, uint32_t* out, uint32_t* SFout,
...@@ -50,6 +54,8 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512)) ...@@ -50,6 +54,8 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
int tid = blockIdx.x * blockDim.x + threadIdx.x; int tid = blockIdx.x * blockDim.x + threadIdx.x;
int colsPerRow = numCols / CVT_FP4_ELTS_PER_THREAD; int colsPerRow = numCols / CVT_FP4_ELTS_PER_THREAD;
// When fusing SiLU+Mul, input has gate || up layout (doubled width)
int inColsPerRow = FUSE_SILU_MUL ? colsPerRow * 2 : colsPerRow;
// Each global thread processes one element // Each global thread processes one element
for (int globalIdx = tid; globalIdx < numRows * colsPerRow; for (int globalIdx = tid; globalIdx < numRows * colsPerRow;
...@@ -58,13 +64,6 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512)) ...@@ -58,13 +64,6 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
int rowIdx = globalIdx / colsPerRow; int rowIdx = globalIdx / colsPerRow;
int colIdx = globalIdx % colsPerRow; int colIdx = globalIdx % colsPerRow;
int64_t inOffset = rowIdx * colsPerRow + colIdx;
PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
// Get the output tensor offset.
// Same as inOffset because 8 elements are packed into one uint32_t.
int64_t outOffset = inOffset;
auto& out_pos = out[outOffset];
// Find index within the experts using different strategies based on expert // Find index within the experts using different strategies based on expert
// count // count
int rowIdx_in_expert = 0; int rowIdx_in_expert = 0;
...@@ -111,6 +110,23 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512)) ...@@ -111,6 +110,23 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
} }
} }
// Load input and optionally apply fused SiLU+Mul
int64_t inOffset = rowIdx * inColsPerRow + colIdx;
PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
PackedVec quant_input;
if constexpr (FUSE_SILU_MUL) {
PackedVec in_vec_up =
reinterpret_cast<PackedVec const*>(in)[inOffset + colsPerRow];
quant_input = compute_silu_mul(in_vec, in_vec_up);
} else {
quant_input = in_vec;
}
// Get the output tensor offset.
// Same as inOffset because 8 elements are packed into one uint32_t.
int64_t outOffset = rowIdx * colsPerRow + colIdx;
auto& out_pos = out[outOffset];
// Get the global scaling factor, which will be applied to the SF. // Get the global scaling factor, which will be applied to the SF.
// Note SFScale is the same as next GEMM's alpha, which is // Note SFScale is the same as next GEMM's alpha, which is
// (448.f / (Alpha_A / 6.f)). // (448.f / (Alpha_A / 6.f)).
...@@ -124,12 +140,16 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512)) ...@@ -124,12 +140,16 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
CVT_FP4_NUM_THREADS_PER_SF>( CVT_FP4_NUM_THREADS_PER_SF>(
rowIdx_in_expert, colIdx, numKTiles, SFout_in_expert); rowIdx_in_expert, colIdx, numKTiles, SFout_in_expert);
out_pos = cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(in_vec, SFScaleVal, sf_out); out_pos =
cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(quant_input, SFScaleVal, sf_out);
} }
} }
// Kernel for LARGE_M_TOPK = true (large m_topk optimized version) // NVFP4 quantization kernel for LARGE_M_TOPK = true (large m_topk optimized
template <class Type, bool UE8M0_SF = false, bool SMALL_NUM_EXPERTS = false> // version). When FUSE_SILU_MUL=true, expects input with gate||up layout and
// fuses SiLU(gate)*up before quantization.
template <class Type, bool FUSE_SILU_MUL = false, bool UE8M0_SF = false,
bool SMALL_NUM_EXPERTS = false>
__global__ void __launch_bounds__(1024, VLLM_BLOCKS_PER_SM(1024)) __global__ void __launch_bounds__(1024, VLLM_BLOCKS_PER_SM(1024))
cvt_fp16_to_fp4(int32_t numRows, int32_t numCols, Type const* in, cvt_fp16_to_fp4(int32_t numRows, int32_t numCols, Type const* in,
float const* SFScale, uint32_t* out, uint32_t* SFout, float const* SFScale, uint32_t* out, uint32_t* SFout,
...@@ -167,6 +187,8 @@ __global__ void __launch_bounds__(1024, VLLM_BLOCKS_PER_SM(1024)) ...@@ -167,6 +187,8 @@ __global__ void __launch_bounds__(1024, VLLM_BLOCKS_PER_SM(1024))
int tid = blockIdx.x * blockDim.x + threadIdx.x; int tid = blockIdx.x * blockDim.x + threadIdx.x;
int colsPerRow = numCols / CVT_FP4_ELTS_PER_THREAD; int colsPerRow = numCols / CVT_FP4_ELTS_PER_THREAD;
// When fusing SiLU+Mul, input has gate || up layout (doubled width)
int inColsPerRow = FUSE_SILU_MUL ? colsPerRow * 2 : colsPerRow;
// Each global thread processes one element // Each global thread processes one element
for (int globalIdx = tid; globalIdx < numRows * colsPerRow; for (int globalIdx = tid; globalIdx < numRows * colsPerRow;
...@@ -175,11 +197,6 @@ __global__ void __launch_bounds__(1024, VLLM_BLOCKS_PER_SM(1024)) ...@@ -175,11 +197,6 @@ __global__ void __launch_bounds__(1024, VLLM_BLOCKS_PER_SM(1024))
int rowIdx = globalIdx / colsPerRow; int rowIdx = globalIdx / colsPerRow;
int colIdx = globalIdx % colsPerRow; int colIdx = globalIdx % colsPerRow;
int64_t inOffset = rowIdx * colsPerRow + colIdx;
PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
int64_t outOffset = inOffset;
auto& out_pos = out[outOffset];
// Find expert using binary search for better performance with large m_topk // Find expert using binary search for better performance with large m_topk
int rowIdx_in_expert = 0; int rowIdx_in_expert = 0;
int expert_idx = 0; int expert_idx = 0;
...@@ -204,6 +221,21 @@ __global__ void __launch_bounds__(1024, VLLM_BLOCKS_PER_SM(1024)) ...@@ -204,6 +221,21 @@ __global__ void __launch_bounds__(1024, VLLM_BLOCKS_PER_SM(1024))
} }
} }
// Load input and optionally apply fused SiLU+Mul
int64_t inOffset = rowIdx * inColsPerRow + colIdx;
PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
PackedVec quant_input;
if constexpr (FUSE_SILU_MUL) {
PackedVec in_vec_up =
reinterpret_cast<PackedVec const*>(in)[inOffset + colsPerRow];
quant_input = compute_silu_mul(in_vec, in_vec_up);
} else {
quant_input = in_vec;
}
int64_t outOffset = rowIdx * colsPerRow + colIdx;
auto& out_pos = out[outOffset];
float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[expert_idx]; float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[expert_idx];
uint32_t* SFout_in_expert = uint32_t* SFout_in_expert =
...@@ -214,11 +246,12 @@ __global__ void __launch_bounds__(1024, VLLM_BLOCKS_PER_SM(1024)) ...@@ -214,11 +246,12 @@ __global__ void __launch_bounds__(1024, VLLM_BLOCKS_PER_SM(1024))
CVT_FP4_NUM_THREADS_PER_SF>( CVT_FP4_NUM_THREADS_PER_SF>(
rowIdx_in_expert, colIdx, numKTiles, SFout_in_expert); rowIdx_in_expert, colIdx, numKTiles, SFout_in_expert);
out_pos = cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(in_vec, SFScaleVal, sf_out); out_pos =
cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(quant_input, SFScaleVal, sf_out);
} }
} }
template <typename T> template <typename T, bool FUSE_SILU_MUL = false>
void quant_impl(void* output, void* output_scale, void* input, void quant_impl(void* output, void* output_scale, void* input,
void* input_global_scale, void* input_offset_by_experts, void* input_global_scale, void* input_offset_by_experts,
void* output_scale_offset_by_experts, int m_topk, int k, void* output_scale_offset_by_experts, int m_topk, int k,
...@@ -246,7 +279,7 @@ void quant_impl(void* output, void* output_scale, void* input, ...@@ -246,7 +279,7 @@ void quant_impl(void* output, void* output_scale, void* input,
if (blockRepeat > 1) { if (blockRepeat > 1) {
size_t shared_mem_size = (n_experts + 1) * sizeof(uint32_t); size_t shared_mem_size = (n_experts + 1) * sizeof(uint32_t);
if (n_experts >= 4) { if (n_experts >= 4) {
cvt_fp16_to_fp4<T, false, false> cvt_fp16_to_fp4<T, FUSE_SILU_MUL, false, false>
<<<grid, block, shared_mem_size, stream>>>( <<<grid, block, shared_mem_size, stream>>>(
m_topk, k, reinterpret_cast<T*>(input), m_topk, k, reinterpret_cast<T*>(input),
reinterpret_cast<float*>(input_global_scale), reinterpret_cast<float*>(input_global_scale),
...@@ -256,7 +289,8 @@ void quant_impl(void* output, void* output_scale, void* input, ...@@ -256,7 +289,8 @@ void quant_impl(void* output, void* output_scale, void* input,
reinterpret_cast<uint32_t*>(output_scale_offset_by_experts), reinterpret_cast<uint32_t*>(output_scale_offset_by_experts),
n_experts); n_experts);
} else { } else {
cvt_fp16_to_fp4<T, false, true><<<grid, block, shared_mem_size, stream>>>( cvt_fp16_to_fp4<T, FUSE_SILU_MUL, false, true>
<<<grid, block, shared_mem_size, stream>>>(
m_topk, k, reinterpret_cast<T*>(input), m_topk, k, reinterpret_cast<T*>(input),
reinterpret_cast<float*>(input_global_scale), reinterpret_cast<float*>(input_global_scale),
reinterpret_cast<uint32_t*>(output), reinterpret_cast<uint32_t*>(output),
...@@ -267,7 +301,8 @@ void quant_impl(void* output, void* output_scale, void* input, ...@@ -267,7 +301,8 @@ void quant_impl(void* output, void* output_scale, void* input,
} }
} else { } else {
if (n_experts >= 16) { if (n_experts >= 16) {
cvt_fp16_to_fp4<T, false, false><<<grid, block, 0, stream>>>( cvt_fp16_to_fp4<T, FUSE_SILU_MUL, false, false>
<<<grid, block, 0, stream>>>(
m_topk, k, reinterpret_cast<T*>(input), m_topk, k, reinterpret_cast<T*>(input),
reinterpret_cast<float*>(input_global_scale), reinterpret_cast<float*>(input_global_scale),
reinterpret_cast<uint32_t*>(output), reinterpret_cast<uint32_t*>(output),
...@@ -276,7 +311,8 @@ void quant_impl(void* output, void* output_scale, void* input, ...@@ -276,7 +311,8 @@ void quant_impl(void* output, void* output_scale, void* input,
reinterpret_cast<uint32_t*>(output_scale_offset_by_experts), reinterpret_cast<uint32_t*>(output_scale_offset_by_experts),
n_experts, /* bool low_latency */ true); n_experts, /* bool low_latency */ true);
} else { } else {
cvt_fp16_to_fp4<T, false, true><<<grid, block, 0, stream>>>( cvt_fp16_to_fp4<T, FUSE_SILU_MUL, false, true>
<<<grid, block, 0, stream>>>(
m_topk, k, reinterpret_cast<T*>(input), m_topk, k, reinterpret_cast<T*>(input),
reinterpret_cast<float*>(input_global_scale), reinterpret_cast<float*>(input_global_scale),
reinterpret_cast<uint32_t*>(output), reinterpret_cast<uint32_t*>(output),
...@@ -304,19 +340,19 @@ constexpr auto FLOAT = at::ScalarType::Float; ...@@ -304,19 +340,19 @@ constexpr auto FLOAT = at::ScalarType::Float;
constexpr auto INT = at::ScalarType::Int; constexpr auto INT = at::ScalarType::Int;
constexpr auto UINT8 = at::ScalarType::Byte; constexpr auto UINT8 = at::ScalarType::Byte;
void scaled_fp4_experts_quant_sm1xxa( // Common validation for fp4 experts quantization entry points.
torch::Tensor& output, torch::Tensor& output_scale, static void validate_fp4_experts_quant_inputs(
torch::Tensor const& output, torch::Tensor const& output_scale,
torch::Tensor const& input, torch::Tensor const& input_global_scale, torch::Tensor const& input, torch::Tensor const& input_global_scale,
torch::Tensor const& input_offset_by_experts, torch::Tensor const& input_offset_by_experts,
torch::Tensor const& output_scale_offset_by_experts) { torch::Tensor const& output_scale_offset_by_experts, int64_t m_topk,
CHECK_INPUT(output, "output must be a CUDA tensor"); int64_t k) {
CHECK_INPUT(output_scale, "output_scale must be a CUDA tensor"); CHECK_INPUT(output, "output");
CHECK_INPUT(input, "input must be a CUDA tensor"); CHECK_INPUT(output_scale, "output_scale");
CHECK_INPUT(input_global_scale, "input_global_scale must be a CUDA tensor"); CHECK_INPUT(input, "input");
CHECK_INPUT(input_offset_by_experts, CHECK_INPUT(input_global_scale, "input_global_scale");
"input_offset_by_experts must be a CUDA tensor"); CHECK_INPUT(input_offset_by_experts, "input_offset_by_experts");
CHECK_INPUT(output_scale_offset_by_experts, CHECK_INPUT(output_scale_offset_by_experts, "output_scale_offset_by_experts");
"output_scale_offset_by_experts must be a CUDA tensor");
TORCH_CHECK(output.dim() == 2); TORCH_CHECK(output.dim() == 2);
TORCH_CHECK(output_scale.dim() == 2); TORCH_CHECK(output_scale.dim() == 2);
...@@ -335,8 +371,6 @@ void scaled_fp4_experts_quant_sm1xxa( ...@@ -335,8 +371,6 @@ void scaled_fp4_experts_quant_sm1xxa(
TORCH_CHECK(output_scale.scalar_type() == INT); TORCH_CHECK(output_scale.scalar_type() == INT);
const int BLOCK_SIZE = 16; const int BLOCK_SIZE = 16;
auto m_topk = input.size(0);
auto k = input.size(1);
TORCH_CHECK(k % BLOCK_SIZE == 0, "k must be a multiple of 16"); TORCH_CHECK(k % BLOCK_SIZE == 0, "k must be a multiple of 16");
auto n_experts = input_global_scale.size(0); auto n_experts = input_global_scale.size(0);
TORCH_CHECK(input_offset_by_experts.size(0) == n_experts + 1); TORCH_CHECK(input_offset_by_experts.size(0) == n_experts + 1);
...@@ -348,7 +382,21 @@ void scaled_fp4_experts_quant_sm1xxa( ...@@ -348,7 +382,21 @@ void scaled_fp4_experts_quant_sm1xxa(
int padded_k = (scales_k + (4 - 1)) / 4 * 4; int padded_k = (scales_k + (4 - 1)) / 4 * 4;
// 4 means 4 fp8 values are packed into one int32 // 4 means 4 fp8 values are packed into one int32
TORCH_CHECK(output_scale.size(1) * 4 == padded_k); TORCH_CHECK(output_scale.size(1) * 4 == padded_k);
}
void scaled_fp4_experts_quant_sm1xxa(
torch::Tensor& output, torch::Tensor& output_scale,
torch::Tensor const& input, torch::Tensor const& input_global_scale,
torch::Tensor const& input_offset_by_experts,
torch::Tensor const& output_scale_offset_by_experts) {
auto m_topk = input.size(0);
auto k = input.size(1);
validate_fp4_experts_quant_inputs(output, output_scale, input,
input_global_scale, input_offset_by_experts,
output_scale_offset_by_experts, m_topk, k);
auto n_experts = input_global_scale.size(0);
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = const cudaStream_t stream =
at::cuda::getCurrentCUDAStream(input.get_device()); at::cuda::getCurrentCUDAStream(input.get_device());
...@@ -356,7 +404,38 @@ void scaled_fp4_experts_quant_sm1xxa( ...@@ -356,7 +404,38 @@ void scaled_fp4_experts_quant_sm1xxa(
VLLM_DISPATCH_HALF_TYPES( VLLM_DISPATCH_HALF_TYPES(
input.scalar_type(), "nvfp4_experts_quant_kernel", [&] { input.scalar_type(), "nvfp4_experts_quant_kernel", [&] {
using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type; using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type;
vllm::quant_impl<cuda_type>( vllm::quant_impl<cuda_type, /*FUSE_SILU_MUL=*/false>(
output.data_ptr(), output_scale.data_ptr(), input.data_ptr(),
input_global_scale.data_ptr(), input_offset_by_experts.data_ptr(),
output_scale_offset_by_experts.data_ptr(), m_topk, k, n_experts,
stream);
});
}
void silu_and_mul_scaled_fp4_experts_quant_sm1xxa(
torch::Tensor& output, torch::Tensor& output_scale,
torch::Tensor const& input, torch::Tensor const& input_global_scale,
torch::Tensor const& input_offset_by_experts,
torch::Tensor const& output_scale_offset_by_experts) {
auto m_topk = input.size(0);
// Input has gate || up layout, so k = input.size(1) / 2
auto k_times_2 = input.size(1);
TORCH_CHECK(k_times_2 % 2 == 0, "input width must be even (gate || up)");
auto k = k_times_2 / 2;
validate_fp4_experts_quant_inputs(output, output_scale, input,
input_global_scale, input_offset_by_experts,
output_scale_offset_by_experts, m_topk, k);
auto n_experts = input_global_scale.size(0);
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream =
at::cuda::getCurrentCUDAStream(input.get_device());
VLLM_DISPATCH_HALF_TYPES(
input.scalar_type(), "silu_mul_nvfp4_experts_quant_kernel", [&] {
using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type;
vllm::quant_impl<cuda_type, /*FUSE_SILU_MUL=*/true>(
output.data_ptr(), output_scale.data_ptr(), input.data_ptr(), output.data_ptr(), output_scale.data_ptr(), input.data_ptr(),
input_global_scale.data_ptr(), input_offset_by_experts.data_ptr(), input_global_scale.data_ptr(), input_offset_by_experts.data_ptr(),
output_scale_offset_by_experts.data_ptr(), m_topk, k, n_experts, output_scale_offset_by_experts.data_ptr(), m_topk, k, n_experts,
......
...@@ -41,6 +41,15 @@ void silu_and_mul_nvfp4_quant_sm1xxa(torch::Tensor& output, ...@@ -41,6 +41,15 @@ void silu_and_mul_nvfp4_quant_sm1xxa(torch::Tensor& output,
torch::Tensor& input_sf); torch::Tensor& input_sf);
#endif #endif
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
void silu_and_mul_scaled_fp4_experts_quant_sm1xxa(
torch::Tensor& output, torch::Tensor& output_scale,
torch::Tensor const& input, torch::Tensor const& input_global_scale,
torch::Tensor const& input_offset_by_experts,
torch::Tensor const& output_scale_offset_by_experts);
#endif
void scaled_fp4_quant(torch::Tensor& output, torch::Tensor const& input, void scaled_fp4_quant(torch::Tensor& output, torch::Tensor const& input,
torch::Tensor& output_sf, torch::Tensor const& input_sf) { torch::Tensor& output_sf, torch::Tensor const& input_sf) {
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \ #if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
...@@ -74,3 +83,18 @@ void silu_and_mul_nvfp4_quant(torch::Tensor& output, torch::Tensor& output_sf, ...@@ -74,3 +83,18 @@ void silu_and_mul_nvfp4_quant(torch::Tensor& output, torch::Tensor& output_sf,
TORCH_CHECK_NOT_IMPLEMENTED( TORCH_CHECK_NOT_IMPLEMENTED(
false, "No compiled silu_and_mul nvfp4 quantization kernel"); false, "No compiled silu_and_mul nvfp4 quantization kernel");
} }
void silu_and_mul_scaled_fp4_experts_quant(
torch::Tensor& output, torch::Tensor& output_scale,
torch::Tensor const& input, torch::Tensor const& input_global_scale,
torch::Tensor const& input_offset_by_experts,
torch::Tensor const& output_scale_offset_by_experts) {
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
return silu_and_mul_scaled_fp4_experts_quant_sm1xxa(
output, output_scale, input, input_global_scale, input_offset_by_experts,
output_scale_offset_by_experts);
#endif
TORCH_CHECK_NOT_IMPLEMENTED(
false, "No compiled silu_and_mul nvfp4 experts quantization kernel");
}
...@@ -239,4 +239,34 @@ __device__ uint32_t cvt_warp_fp16_to_fp4(PackedVec<Type>& vec, float SFScaleVal, ...@@ -239,4 +239,34 @@ __device__ uint32_t cvt_warp_fp16_to_fp4(PackedVec<Type>& vec, float SFScaleVal,
return e2m1Vec; return e2m1Vec;
} }
// silu in float32
__device__ __forceinline__ float silu(float x) {
return __fdividef(x, (1.f + __expf(-x)));
}
__device__ __forceinline__ float2 silu2(float2 x) {
return make_float2(silu(x.x), silu(x.y));
}
template <class Type>
__inline__ __device__ PackedVec<Type> compute_silu_mul(
const PackedVec<Type>& x_vec, const PackedVec<Type>& y_vec) {
PackedVec<Type> result;
#pragma unroll
for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; ++i) {
// silu_mul in float32
if constexpr (std::is_same_v<Type, half>) {
float2 silu_vec = silu2(__half22float2(x_vec.elts[i]));
result.elts[i] = __float22half2_rn(
__fmul2_rn(silu_vec, __half22float2(y_vec.elts[i])));
} else {
float2 silu_vec = silu2(__bfloat1622float2(x_vec.elts[i]));
result.elts[i] = __float22bfloat162_rn(
__fmul2_rn(silu_vec, __bfloat1622float2(y_vec.elts[i])));
}
}
return result;
}
} // namespace vllm } // namespace vllm
...@@ -558,6 +558,15 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -558,6 +558,15 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"Tensor output_scale_offset_by_experts) -> ()"); "Tensor output_scale_offset_by_experts) -> ()");
ops.impl("scaled_fp4_experts_quant", torch::kCUDA, &scaled_fp4_experts_quant); ops.impl("scaled_fp4_experts_quant", torch::kCUDA, &scaled_fp4_experts_quant);
// Fused SiLU+Mul+NVFP4 experts quantization.
ops.def(
"silu_and_mul_scaled_fp4_experts_quant(Tensor! output, Tensor! "
"output_scale,"
"Tensor input, Tensor input_global_scale, Tensor input_offset_by_experts,"
"Tensor output_scale_offset_by_experts) -> ()");
ops.impl("silu_and_mul_scaled_fp4_experts_quant", torch::kCUDA,
&silu_and_mul_scaled_fp4_experts_quant);
// Check if cutlass_scaled_mm_fp4 is supported for CUDA devices // Check if cutlass_scaled_mm_fp4 is supported for CUDA devices
// of the given capability // of the given capability
ops.def("cutlass_scaled_mm_supports_fp4(int cuda_device_capability) -> bool"); ops.def("cutlass_scaled_mm_supports_fp4(int cuda_device_capability) -> bool");
......
...@@ -1606,15 +1606,15 @@ def scaled_fp4_experts_quant( ...@@ -1606,15 +1606,15 @@ def scaled_fp4_experts_quant(
topk: int, topk: int,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
""" """
Quantize input tensor to FP4 and return quantized tensor and scale, for Quantize input tensor to NVFP4 and return quantized tensor and scale, for
packed MoE Inputs. packed MoE Inputs.
Args: Args:
input_tensor: The input tensor to be quantized to FP4 input_tensor: The input tensor to be quantized to NVFP4
input_global_scale: A scalar scaling factor for the entire tensor. input_global_scale: A scalar scaling factor for the entire tensor.
expert_offsets: The expert offsets tensor expert_offsets: The expert offsets tensor
blockscale_offsets: The blockscale offsets tensor blockscale_offsets: The blockscale offsets tensor
Outputs: Outputs:
output: The quantized tensor in FP4 output: The quantized tensor in NVFP4
output_scales: The blockscale tensor in FP8-E4M3 output_scales: The blockscale tensor in FP8-E4M3
""" """
assert not current_platform.is_rocm() assert not current_platform.is_rocm()
...@@ -1660,6 +1660,71 @@ def scaled_fp4_experts_quant( ...@@ -1660,6 +1660,71 @@ def scaled_fp4_experts_quant(
return output, output_scales return output, output_scales
def silu_and_mul_scaled_fp4_experts_quant(
input_tensor: torch.Tensor,
input_global_scale: torch.Tensor,
expert_offsets: torch.Tensor,
blockscale_offsets: torch.Tensor,
topk: int,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Fused SiLU+Mul+NVFP4 quantization for MoE intermediate activations.
Args:
input_tensor: The input tensor with gate || up layout [m_topk, k*2]
input_global_scale: A per-expert scaling factor [n_experts]
expert_offsets: The expert offsets tensor [n_experts+1]
blockscale_offsets: The blockscale offsets tensor [n_experts+1]
topk: Number of top-k experts selected
Outputs:
output: The quantized tensor in NVFP4 [m_topk, k/2]
output_scales: The blockscale tensor in FP8-E4M3
"""
assert not current_platform.is_rocm()
assert input_tensor.ndim == 2, (
f"input.ndim needs to be == 2, but got {input_tensor.ndim}."
)
# Control the maximum number of tokens per expert supported by the
# NVFP4 MoE Expert Quantization. This is used to prevent the kernel
# from running out of memory. This value can also be increased to support
# larger models.
MAX_TOKENS_PER_EXPERT = envs.VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE
m_numtopk, k_times_2 = input_tensor.shape
assert k_times_2 % 2 == 0, "input width must be even (gate || up layout)"
k = k_times_2 // 2
assert m_numtopk <= MAX_TOKENS_PER_EXPERT * topk, (
f"m_numtopk must be less than MAX_TOKENS_PER_EXPERT("
f"{MAX_TOKENS_PER_EXPERT})"
f" for cutlass_moe_fp4, observed m_numtopk = {m_numtopk}. Use"
f" VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE to set this value."
)
scales_k = k // 16
padded_k = (scales_k + (4 - 1)) // 4
# output is uint8 and packed fp4 values
output = torch.empty(
m_numtopk, k // 2, device=input_tensor.device, dtype=torch.uint8
)
output_scales = torch.empty(
MAX_TOKENS_PER_EXPERT * topk,
padded_k,
dtype=torch.int32,
device=input_tensor.device,
)
torch.ops._C.silu_and_mul_scaled_fp4_experts_quant(
output,
output_scales,
input_tensor,
input_global_scale,
expert_offsets,
blockscale_offsets,
)
output_scales = output_scales.view(torch.float8_e4m3fn)
return output, output_scales
# fp8 # fp8
def scaled_fp8_quant( def scaled_fp8_quant(
input: torch.Tensor, input: torch.Tensor,
......
...@@ -549,7 +549,8 @@ def run_cutlass_moe_fp4( ...@@ -549,7 +549,8 @@ def run_cutlass_moe_fp4(
num_topk, num_topk,
) )
c1 = _resize_cache(workspace13, (m * topk, n * 2)) c1 = _resize_cache(workspace13, (m * topk, n * 2))
c2 = _resize_cache(workspace2, (m * topk, n)) # Note: c2 workspace is no longer needed since SiLU is fused with quantization.
# c3 reuses workspace13 after c1 is consumed.
c3 = _resize_cache(workspace13, (m * topk, k)) c3 = _resize_cache(workspace13, (m * topk, k))
ops.cutlass_fp4_moe_mm( ops.cutlass_fp4_moe_mm(
c1, c1,
...@@ -563,9 +564,9 @@ def run_cutlass_moe_fp4( ...@@ -563,9 +564,9 @@ def run_cutlass_moe_fp4(
blockscale_offsets[:-1], blockscale_offsets[:-1],
) )
del rep_a_fp4, rep_a_blockscale del rep_a_fp4, rep_a_blockscale
torch.ops._C.silu_and_mul(c2, c1) # Fused SiLU+Mul+NVFP4 quantization
int_fp4, int_blockscale = ops.scaled_fp4_experts_quant( int_fp4, int_blockscale = ops.silu_and_mul_scaled_fp4_experts_quant(
c2, a2_gscale, expert_offsets, blockscale_offsets, num_topk c1, a2_gscale, expert_offsets, blockscale_offsets, num_topk
) )
ops.cutlass_fp4_moe_mm( ops.cutlass_fp4_moe_mm(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment