"vscode:/vscode.git/clone" did not exist on "fbefc8a78d22b20eac042c586805c7dcbfc66b1c"
Unverified commit 34cd32fe, authored by Michael Goin, committed by GitHub
Browse files

[Perf][Kernel] Fused SiLU+Mul+Quant kernel for NVFP4 cutlass_moe (#31832)


Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: Michael Goin <mgoin64@gmail.com>
parent 8e27663b
......@@ -301,6 +301,12 @@ void scaled_fp4_experts_quant(
torch::Tensor const& input_offset_by_experts,
torch::Tensor const& output_scale_offset_by_experts);
// Fused SiLU+Mul + NVFP4 quantization for MoE experts.
// `input` uses the gate || up layout along its last dimension; SiLU(gate) * up
// is computed and quantized in one pass. `output` and `output_scale` are the
// destination tensors written by the op; the remaining tensors are read-only.
// The argument list mirrors scaled_fp4_experts_quant.
void silu_and_mul_scaled_fp4_experts_quant(
torch::Tensor& output, torch::Tensor& output_scale,
torch::Tensor const& input, torch::Tensor const& input_global_scale,
torch::Tensor const& input_offset_by_experts,
torch::Tensor const& output_scale_offset_by_experts);
void per_token_group_quant_fp8(const torch::Tensor& input,
torch::Tensor& output_q, torch::Tensor& output_s,
int64_t group_size, double eps, double fp8_min,
......
......@@ -31,37 +31,6 @@
namespace vllm {
// SiLU activation in float32: x / (1 + exp(-x)), evaluated with the fast-math
// intrinsics __expf and __fdividef.
__device__ __forceinline__ float silu(float x) {
  const float denom = 1.f + __expf(-x);
  return __fdividef(x, denom);
}
// Applies silu() independently to both lanes of a float2.
__device__ __forceinline__ float2 silu2(float2 x) {
  const float lo = silu(x.x);
  const float hi = silu(x.y);
  return make_float2(lo, hi);
}
// Computes SiLU(vec) * vec2 lane by lane and returns the packed result.
//
// Each of the CVT_FP4_ELTS_PER_THREAD / 2 packed pairs is widened to float2,
// the activation and the multiply are done in float32, and the product is
// rounded back to Type with round-to-nearest.
//
// Type is either `half` or the bfloat16 type; any non-half Type takes the
// bfloat16 conversion path.
//
// The operands are only read, so they are taken by const reference. This
// still accepts every argument the former non-const signature accepted.
template <class Type>
__inline__ __device__ PackedVec<Type> compute_silu_mul(
    const PackedVec<Type>& vec, const PackedVec<Type>& vec2) {
  PackedVec<Type> result;

#pragma unroll
  for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; ++i) {
    // silu_mul in float32
    if constexpr (std::is_same_v<Type, half>) {
      float2 silu_vec = silu2(__half22float2(vec.elts[i]));
      result.elts[i] =
          __float22half2_rn(__fmul2_rn(silu_vec, __half22float2(vec2.elts[i])));
    } else {
      float2 silu_vec = silu2(__bfloat1622float2(vec.elts[i]));
      result.elts[i] = __float22bfloat162_rn(
          __fmul2_rn(silu_vec, __bfloat1622float2(vec2.elts[i])));
    }
  }
  return result;
}
// Use UE4M3 by default.
template <class Type, bool UE8M0_SF = false>
__global__ void __launch_bounds__(1024, VLLM_BLOCKS_PER_SM(1024))
......
......@@ -31,8 +31,12 @@
namespace vllm {
// NVFP4 quantization kernel for experts (low-latency path).
// When FUSE_SILU_MUL=true, expects input with gate||up layout and fuses
// SiLU(gate)*up before quantization.
// Use UE4M3 by default.
template <class Type, bool UE8M0_SF = false, bool SMALL_NUM_EXPERTS = false>
template <class Type, bool FUSE_SILU_MUL = false, bool UE8M0_SF = false,
bool SMALL_NUM_EXPERTS = false>
__global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
cvt_fp16_to_fp4(int32_t numRows, int32_t numCols, Type const* in,
float const* SFScale, uint32_t* out, uint32_t* SFout,
......@@ -50,6 +54,8 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int colsPerRow = numCols / CVT_FP4_ELTS_PER_THREAD;
// When fusing SiLU+Mul, input has gate || up layout (doubled width)
int inColsPerRow = FUSE_SILU_MUL ? colsPerRow * 2 : colsPerRow;
// Each global thread processes one element
for (int globalIdx = tid; globalIdx < numRows * colsPerRow;
......@@ -58,13 +64,6 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
int rowIdx = globalIdx / colsPerRow;
int colIdx = globalIdx % colsPerRow;
int64_t inOffset = rowIdx * colsPerRow + colIdx;
PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
// Get the output tensor offset.
// Same as inOffset because 8 elements are packed into one uint32_t.
int64_t outOffset = inOffset;
auto& out_pos = out[outOffset];
// Find index within the experts using different strategies based on expert
// count
int rowIdx_in_expert = 0;
......@@ -111,6 +110,23 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
}
}
// Load input and optionally apply fused SiLU+Mul
int64_t inOffset = rowIdx * inColsPerRow + colIdx;
PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
PackedVec quant_input;
if constexpr (FUSE_SILU_MUL) {
PackedVec in_vec_up =
reinterpret_cast<PackedVec const*>(in)[inOffset + colsPerRow];
quant_input = compute_silu_mul(in_vec, in_vec_up);
} else {
quant_input = in_vec;
}
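// Get the output tensor offset.
// Output rows are colsPerRow wide (8 elements are packed into one uint32_t);
// this equals inOffset only when FUSE_SILU_MUL is false, because the fused
// input rows are twice as wide (gate || up).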
int64_t outOffset = rowIdx * colsPerRow + colIdx;
auto& out_pos = out[outOffset];
// Get the global scaling factor, which will be applied to the SF.
// Note SFScale is the same as next GEMM's alpha, which is
// (448.f / (Alpha_A / 6.f)).
......@@ -124,12 +140,16 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
CVT_FP4_NUM_THREADS_PER_SF>(
rowIdx_in_expert, colIdx, numKTiles, SFout_in_expert);
out_pos = cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(in_vec, SFScaleVal, sf_out);
out_pos =
cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(quant_input, SFScaleVal, sf_out);
}
}
// Kernel for LARGE_M_TOPK = true (large m_topk optimized version)
template <class Type, bool UE8M0_SF = false, bool SMALL_NUM_EXPERTS = false>
// NVFP4 quantization kernel for LARGE_M_TOPK = true (large m_topk optimized
// version). When FUSE_SILU_MUL=true, expects input with gate||up layout and
// fuses SiLU(gate)*up before quantization.
template <class Type, bool FUSE_SILU_MUL = false, bool UE8M0_SF = false,
bool SMALL_NUM_EXPERTS = false>
__global__ void __launch_bounds__(1024, VLLM_BLOCKS_PER_SM(1024))
cvt_fp16_to_fp4(int32_t numRows, int32_t numCols, Type const* in,
float const* SFScale, uint32_t* out, uint32_t* SFout,
......@@ -167,6 +187,8 @@ __global__ void __launch_bounds__(1024, VLLM_BLOCKS_PER_SM(1024))
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int colsPerRow = numCols / CVT_FP4_ELTS_PER_THREAD;
// When fusing SiLU+Mul, input has gate || up layout (doubled width)
int inColsPerRow = FUSE_SILU_MUL ? colsPerRow * 2 : colsPerRow;
// Each global thread processes one element
for (int globalIdx = tid; globalIdx < numRows * colsPerRow;
......@@ -175,11 +197,6 @@ __global__ void __launch_bounds__(1024, VLLM_BLOCKS_PER_SM(1024))
int rowIdx = globalIdx / colsPerRow;
int colIdx = globalIdx % colsPerRow;
int64_t inOffset = rowIdx * colsPerRow + colIdx;
PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
int64_t outOffset = inOffset;
auto& out_pos = out[outOffset];
// Find expert using binary search for better performance with large m_topk
int rowIdx_in_expert = 0;
int expert_idx = 0;
......@@ -204,6 +221,21 @@ __global__ void __launch_bounds__(1024, VLLM_BLOCKS_PER_SM(1024))
}
}
// Load input and optionally apply fused SiLU+Mul
int64_t inOffset = rowIdx * inColsPerRow + colIdx;
PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
PackedVec quant_input;
if constexpr (FUSE_SILU_MUL) {
PackedVec in_vec_up =
reinterpret_cast<PackedVec const*>(in)[inOffset + colsPerRow];
quant_input = compute_silu_mul(in_vec, in_vec_up);
} else {
quant_input = in_vec;
}
int64_t outOffset = rowIdx * colsPerRow + colIdx;
auto& out_pos = out[outOffset];
float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[expert_idx];
uint32_t* SFout_in_expert =
......@@ -214,11 +246,12 @@ __global__ void __launch_bounds__(1024, VLLM_BLOCKS_PER_SM(1024))
CVT_FP4_NUM_THREADS_PER_SF>(
rowIdx_in_expert, colIdx, numKTiles, SFout_in_expert);
out_pos = cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(in_vec, SFScaleVal, sf_out);
out_pos =
cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(quant_input, SFScaleVal, sf_out);
}
}
template <typename T>
template <typename T, bool FUSE_SILU_MUL = false>
void quant_impl(void* output, void* output_scale, void* input,
void* input_global_scale, void* input_offset_by_experts,
void* output_scale_offset_by_experts, int m_topk, int k,
......@@ -246,7 +279,7 @@ void quant_impl(void* output, void* output_scale, void* input,
if (blockRepeat > 1) {
size_t shared_mem_size = (n_experts + 1) * sizeof(uint32_t);
if (n_experts >= 4) {
cvt_fp16_to_fp4<T, false, false>
cvt_fp16_to_fp4<T, FUSE_SILU_MUL, false, false>
<<<grid, block, shared_mem_size, stream>>>(
m_topk, k, reinterpret_cast<T*>(input),
reinterpret_cast<float*>(input_global_scale),
......@@ -256,34 +289,37 @@ void quant_impl(void* output, void* output_scale, void* input,
reinterpret_cast<uint32_t*>(output_scale_offset_by_experts),
n_experts);
} else {
cvt_fp16_to_fp4<T, false, true><<<grid, block, shared_mem_size, stream>>>(
m_topk, k, reinterpret_cast<T*>(input),
reinterpret_cast<float*>(input_global_scale),
reinterpret_cast<uint32_t*>(output),
reinterpret_cast<uint32_t*>(output_scale),
reinterpret_cast<uint32_t*>(input_offset_by_experts),
reinterpret_cast<uint32_t*>(output_scale_offset_by_experts),
n_experts);
cvt_fp16_to_fp4<T, FUSE_SILU_MUL, false, true>
<<<grid, block, shared_mem_size, stream>>>(
m_topk, k, reinterpret_cast<T*>(input),
reinterpret_cast<float*>(input_global_scale),
reinterpret_cast<uint32_t*>(output),
reinterpret_cast<uint32_t*>(output_scale),
reinterpret_cast<uint32_t*>(input_offset_by_experts),
reinterpret_cast<uint32_t*>(output_scale_offset_by_experts),
n_experts);
}
} else {
if (n_experts >= 16) {
cvt_fp16_to_fp4<T, false, false><<<grid, block, 0, stream>>>(
m_topk, k, reinterpret_cast<T*>(input),
reinterpret_cast<float*>(input_global_scale),
reinterpret_cast<uint32_t*>(output),
reinterpret_cast<uint32_t*>(output_scale),
reinterpret_cast<uint32_t*>(input_offset_by_experts),
reinterpret_cast<uint32_t*>(output_scale_offset_by_experts),
n_experts, /* bool low_latency */ true);
cvt_fp16_to_fp4<T, FUSE_SILU_MUL, false, false>
<<<grid, block, 0, stream>>>(
m_topk, k, reinterpret_cast<T*>(input),
reinterpret_cast<float*>(input_global_scale),
reinterpret_cast<uint32_t*>(output),
reinterpret_cast<uint32_t*>(output_scale),
reinterpret_cast<uint32_t*>(input_offset_by_experts),
reinterpret_cast<uint32_t*>(output_scale_offset_by_experts),
n_experts, /* bool low_latency */ true);
} else {
cvt_fp16_to_fp4<T, false, true><<<grid, block, 0, stream>>>(
m_topk, k, reinterpret_cast<T*>(input),
reinterpret_cast<float*>(input_global_scale),
reinterpret_cast<uint32_t*>(output),
reinterpret_cast<uint32_t*>(output_scale),
reinterpret_cast<uint32_t*>(input_offset_by_experts),
reinterpret_cast<uint32_t*>(output_scale_offset_by_experts),
n_experts, /* bool low_latency */ true);
cvt_fp16_to_fp4<T, FUSE_SILU_MUL, false, true>
<<<grid, block, 0, stream>>>(
m_topk, k, reinterpret_cast<T*>(input),
reinterpret_cast<float*>(input_global_scale),
reinterpret_cast<uint32_t*>(output),
reinterpret_cast<uint32_t*>(output_scale),
reinterpret_cast<uint32_t*>(input_offset_by_experts),
reinterpret_cast<uint32_t*>(output_scale_offset_by_experts),
n_experts, /* bool low_latency */ true);
}
}
}
......@@ -304,19 +340,19 @@ constexpr auto FLOAT = at::ScalarType::Float;
constexpr auto INT = at::ScalarType::Int;
constexpr auto UINT8 = at::ScalarType::Byte;
void scaled_fp4_experts_quant_sm1xxa(
torch::Tensor& output, torch::Tensor& output_scale,
// Common validation for fp4 experts quantization entry points.
static void validate_fp4_experts_quant_inputs(
torch::Tensor const& output, torch::Tensor const& output_scale,
torch::Tensor const& input, torch::Tensor const& input_global_scale,
torch::Tensor const& input_offset_by_experts,
torch::Tensor const& output_scale_offset_by_experts) {
CHECK_INPUT(output, "output must be a CUDA tensor");
CHECK_INPUT(output_scale, "output_scale must be a CUDA tensor");
CHECK_INPUT(input, "input must be a CUDA tensor");
CHECK_INPUT(input_global_scale, "input_global_scale must be a CUDA tensor");
CHECK_INPUT(input_offset_by_experts,
"input_offset_by_experts must be a CUDA tensor");
CHECK_INPUT(output_scale_offset_by_experts,
"output_scale_offset_by_experts must be a CUDA tensor");
torch::Tensor const& output_scale_offset_by_experts, int64_t m_topk,
int64_t k) {
CHECK_INPUT(output, "output");
CHECK_INPUT(output_scale, "output_scale");
CHECK_INPUT(input, "input");
CHECK_INPUT(input_global_scale, "input_global_scale");
CHECK_INPUT(input_offset_by_experts, "input_offset_by_experts");
CHECK_INPUT(output_scale_offset_by_experts, "output_scale_offset_by_experts");
TORCH_CHECK(output.dim() == 2);
TORCH_CHECK(output_scale.dim() == 2);
......@@ -335,8 +371,6 @@ void scaled_fp4_experts_quant_sm1xxa(
TORCH_CHECK(output_scale.scalar_type() == INT);
const int BLOCK_SIZE = 16;
auto m_topk = input.size(0);
auto k = input.size(1);
TORCH_CHECK(k % BLOCK_SIZE == 0, "k must be a multiple of 16");
auto n_experts = input_global_scale.size(0);
TORCH_CHECK(input_offset_by_experts.size(0) == n_experts + 1);
......@@ -348,7 +382,21 @@ void scaled_fp4_experts_quant_sm1xxa(
int padded_k = (scales_k + (4 - 1)) / 4 * 4;
// 4 means 4 fp8 values are packed into one int32
TORCH_CHECK(output_scale.size(1) * 4 == padded_k);
}
void scaled_fp4_experts_quant_sm1xxa(
torch::Tensor& output, torch::Tensor& output_scale,
torch::Tensor const& input, torch::Tensor const& input_global_scale,
torch::Tensor const& input_offset_by_experts,
torch::Tensor const& output_scale_offset_by_experts) {
auto m_topk = input.size(0);
auto k = input.size(1);
validate_fp4_experts_quant_inputs(output, output_scale, input,
input_global_scale, input_offset_by_experts,
output_scale_offset_by_experts, m_topk, k);
auto n_experts = input_global_scale.size(0);
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream =
at::cuda::getCurrentCUDAStream(input.get_device());
......@@ -356,7 +404,38 @@ void scaled_fp4_experts_quant_sm1xxa(
VLLM_DISPATCH_HALF_TYPES(
input.scalar_type(), "nvfp4_experts_quant_kernel", [&] {
using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type;
vllm::quant_impl<cuda_type>(
vllm::quant_impl<cuda_type, /*FUSE_SILU_MUL=*/false>(
output.data_ptr(), output_scale.data_ptr(), input.data_ptr(),
input_global_scale.data_ptr(), input_offset_by_experts.data_ptr(),
output_scale_offset_by_experts.data_ptr(), m_topk, k, n_experts,
stream);
});
}
void silu_and_mul_scaled_fp4_experts_quant_sm1xxa(
torch::Tensor& output, torch::Tensor& output_scale,
torch::Tensor const& input, torch::Tensor const& input_global_scale,
torch::Tensor const& input_offset_by_experts,
torch::Tensor const& output_scale_offset_by_experts) {
auto m_topk = input.size(0);
// Input has gate || up layout, so k = input.size(1) / 2
auto k_times_2 = input.size(1);
TORCH_CHECK(k_times_2 % 2 == 0, "input width must be even (gate || up)");
auto k = k_times_2 / 2;
validate_fp4_experts_quant_inputs(output, output_scale, input,
input_global_scale, input_offset_by_experts,
output_scale_offset_by_experts, m_topk, k);
auto n_experts = input_global_scale.size(0);
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream =
at::cuda::getCurrentCUDAStream(input.get_device());
VLLM_DISPATCH_HALF_TYPES(
input.scalar_type(), "silu_mul_nvfp4_experts_quant_kernel", [&] {
using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type;
vllm::quant_impl<cuda_type, /*FUSE_SILU_MUL=*/true>(
output.data_ptr(), output_scale.data_ptr(), input.data_ptr(),
input_global_scale.data_ptr(), input_offset_by_experts.data_ptr(),
output_scale_offset_by_experts.data_ptr(), m_topk, k, n_experts,
......
......@@ -41,6 +41,15 @@ void silu_and_mul_nvfp4_quant_sm1xxa(torch::Tensor& output,
torch::Tensor& input_sf);
#endif
// Forward declaration of the fused SiLU+Mul + NVFP4 experts quantization
// implementation. It is only declared when NVFP4 support for SM100 or SM120
// is enabled at build time.
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
void silu_and_mul_scaled_fp4_experts_quant_sm1xxa(
torch::Tensor& output, torch::Tensor& output_scale,
torch::Tensor const& input, torch::Tensor const& input_global_scale,
torch::Tensor const& input_offset_by_experts,
torch::Tensor const& output_scale_offset_by_experts);
#endif
void scaled_fp4_quant(torch::Tensor& output, torch::Tensor const& input,
torch::Tensor& output_sf, torch::Tensor const& input_sf) {
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
......@@ -74,3 +83,18 @@ void silu_and_mul_nvfp4_quant(torch::Tensor& output, torch::Tensor& output_sf,
TORCH_CHECK_NOT_IMPLEMENTED(
false, "No compiled silu_and_mul nvfp4 quantization kernel");
}
// Entry point for fused SiLU+Mul + NVFP4 experts quantization.
//
// Forwards all arguments unchanged to the SM100/SM120 implementation when
// that kernel is compiled in. Otherwise it raises a "not implemented" error
// through TORCH_CHECK_NOT_IMPLEMENTED.
void silu_and_mul_scaled_fp4_experts_quant(
    torch::Tensor& output, torch::Tensor& output_scale,
    torch::Tensor const& input, torch::Tensor const& input_global_scale,
    torch::Tensor const& input_offset_by_experts,
    torch::Tensor const& output_scale_offset_by_experts) {
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
    (defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
  silu_and_mul_scaled_fp4_experts_quant_sm1xxa(
      output, output_scale, input, input_global_scale, input_offset_by_experts,
      output_scale_offset_by_experts);
  return;
#else
  // No NVFP4 kernel was built for this configuration.
  TORCH_CHECK_NOT_IMPLEMENTED(
      false, "No compiled silu_and_mul nvfp4 experts quantization kernel");
#endif
}
......@@ -239,4 +239,34 @@ __device__ uint32_t cvt_warp_fp16_to_fp4(PackedVec<Type>& vec, float SFScaleVal,
return e2m1Vec;
}
// SiLU in float32, i.e. x / (1 + exp(-x)). Uses the fast-math intrinsics
// __expf and __fdividef rather than the full-precision library routines.
__device__ __forceinline__ float silu(float x) {
  const float one_plus_exp_neg_x = 1.f + __expf(-x);
  return __fdividef(x, one_plus_exp_neg_x);
}
// Two-lane SiLU: applies silu() to the .x and .y components separately.
__device__ __forceinline__ float2 silu2(float2 x) {
  const float first = silu(x.x);
  const float second = silu(x.y);
  return make_float2(first, second);
}
// Returns SiLU(x_vec) * y_vec, lane by lane.
//
// Every packed pair is widened to float2, activated and multiplied in float32,
// then rounded back to Type with round-to-nearest. `half` uses the half2
// conversions; any other Type takes the bfloat16 conversions.
template <class Type>
__inline__ __device__ PackedVec<Type> compute_silu_mul(
    const PackedVec<Type>& x_vec, const PackedVec<Type>& y_vec) {
  constexpr int kNumPairs = CVT_FP4_ELTS_PER_THREAD / 2;
  PackedVec<Type> out;

#pragma unroll
  for (int pair = 0; pair < kNumPairs; ++pair) {
    if constexpr (std::is_same_v<Type, half>) {
      const float2 activated = silu2(__half22float2(x_vec.elts[pair]));
      const float2 multiplier = __half22float2(y_vec.elts[pair]);
      out.elts[pair] = __float22half2_rn(__fmul2_rn(activated, multiplier));
    } else {
      const float2 activated = silu2(__bfloat1622float2(x_vec.elts[pair]));
      const float2 multiplier = __bfloat1622float2(y_vec.elts[pair]);
      out.elts[pair] =
          __float22bfloat162_rn(__fmul2_rn(activated, multiplier));
    }
  }
  return out;
}
} // namespace vllm
......@@ -558,6 +558,15 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"Tensor output_scale_offset_by_experts) -> ()");
ops.impl("scaled_fp4_experts_quant", torch::kCUDA, &scaled_fp4_experts_quant);
// Fused SiLU+Mul+NVFP4 experts quantization.
ops.def(
"silu_and_mul_scaled_fp4_experts_quant(Tensor! output, Tensor! "
"output_scale,"
"Tensor input, Tensor input_global_scale, Tensor input_offset_by_experts,"
"Tensor output_scale_offset_by_experts) -> ()");
ops.impl("silu_and_mul_scaled_fp4_experts_quant", torch::kCUDA,
&silu_and_mul_scaled_fp4_experts_quant);
// Check if cutlass_scaled_mm_fp4 is supported for CUDA devices
// of the given capability
ops.def("cutlass_scaled_mm_supports_fp4(int cuda_device_capability) -> bool");
......
......@@ -1606,15 +1606,15 @@ def scaled_fp4_experts_quant(
topk: int,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Quantize input tensor to FP4 and return quantized tensor and scale, for
Quantize input tensor to NVFP4 and return quantized tensor and scale, for
packed MoE Inputs.
Args:
input_tensor: The input tensor to be quantized to FP4
input_tensor: The input tensor to be quantized to NVFP4
input_global_scale: A scalar scaling factor for the entire tensor.
expert_offsets: The expert offsets tensor
blockscale_offsets: The blockscale offsets tensor
Outputs:
output: The quantized tensor in FP4
output: The quantized tensor in NVFP4
output_scales: The blockscale tensor in FP8-E4M3
"""
assert not current_platform.is_rocm()
......@@ -1660,6 +1660,71 @@ def scaled_fp4_experts_quant(
return output, output_scales
def silu_and_mul_scaled_fp4_experts_quant(
    input_tensor: torch.Tensor,
    input_global_scale: torch.Tensor,
    expert_offsets: torch.Tensor,
    blockscale_offsets: torch.Tensor,
    topk: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Fused SiLU+Mul+NVFP4 quantization for MoE intermediate activations.

    Args:
        input_tensor: The input tensor with gate || up layout [m_topk, k*2]
        input_global_scale: A per-expert scaling factor [n_experts]
        expert_offsets: The expert offsets tensor [n_experts+1]
        blockscale_offsets: The blockscale offsets tensor [n_experts+1]
        topk: Number of top-k experts selected

    Outputs:
        output: The quantized tensor in NVFP4 [m_topk, k/2], stored as uint8
            (two FP4 values per byte)
        output_scales: The blockscale tensor in FP8-E4M3. It is allocated with
            MAX_TOKENS_PER_EXPERT * topk rows of int32 words and then viewed
            as float8_e4m3fn.

    Raises:
        AssertionError: on ROCm, if the input is not 2-D, if its width is odd,
            or if m_topk exceeds MAX_TOKENS_PER_EXPERT * topk.
    """
    assert not current_platform.is_rocm()
    assert input_tensor.ndim == 2, (
        f"input.ndim needs to be == 2, but got {input_tensor.ndim}."
    )

    # Control the maximum number of tokens per expert supported by the
    # NVFP4 MoE Expert Quantization. This is used to prevent the kernel
    # from running out of memory. This value can also be increased to support
    # larger models.
    MAX_TOKENS_PER_EXPERT = envs.VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE
    m_numtopk, k_times_2 = input_tensor.shape
    assert k_times_2 % 2 == 0, "input width must be even (gate || up layout)"
    k = k_times_2 // 2

    # The bound scales with topk, so the message reports the actual limit
    # that was checked rather than MAX_TOKENS_PER_EXPERT alone.
    assert m_numtopk <= MAX_TOKENS_PER_EXPERT * topk, (
        f"m_numtopk must be at most MAX_TOKENS_PER_EXPERT("
        f"{MAX_TOKENS_PER_EXPERT}) * topk({topk})"
        f" for cutlass_moe_fp4, observed m_numtopk = {m_numtopk}. Use"
        f" VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE to set this value."
    )

    # One scale per block of 16 elements along k.
    scales_k = k // 16
    # Despite its name, padded_k is the number of int32 words per scale row:
    # four scales are packed into each int32, rounding up.
    padded_k = (scales_k + (4 - 1)) // 4

    # output is uint8 and packed fp4 values
    output = torch.empty(
        m_numtopk, k // 2, device=input_tensor.device, dtype=torch.uint8
    )
    output_scales = torch.empty(
        MAX_TOKENS_PER_EXPERT * topk,
        padded_k,
        dtype=torch.int32,
        device=input_tensor.device,
    )
    torch.ops._C.silu_and_mul_scaled_fp4_experts_quant(
        output,
        output_scales,
        input_tensor,
        input_global_scale,
        expert_offsets,
        blockscale_offsets,
    )
    # Reinterpret the packed int32 words as individual FP8-E4M3 scales.
    output_scales = output_scales.view(torch.float8_e4m3fn)
    return output, output_scales
# fp8
def scaled_fp8_quant(
input: torch.Tensor,
......
......@@ -549,7 +549,8 @@ def run_cutlass_moe_fp4(
num_topk,
)
c1 = _resize_cache(workspace13, (m * topk, n * 2))
c2 = _resize_cache(workspace2, (m * topk, n))
# Note: c2 workspace is no longer needed since SiLU is fused with quantization.
# c3 reuses workspace13 after c1 is consumed.
c3 = _resize_cache(workspace13, (m * topk, k))
ops.cutlass_fp4_moe_mm(
c1,
......@@ -563,9 +564,9 @@ def run_cutlass_moe_fp4(
blockscale_offsets[:-1],
)
del rep_a_fp4, rep_a_blockscale
torch.ops._C.silu_and_mul(c2, c1)
int_fp4, int_blockscale = ops.scaled_fp4_experts_quant(
c2, a2_gscale, expert_offsets, blockscale_offsets, num_topk
# Fused SiLU+Mul+NVFP4 quantization
int_fp4, int_blockscale = ops.silu_and_mul_scaled_fp4_experts_quant(
c1, a2_gscale, expert_offsets, blockscale_offsets, num_topk
)
ops.cutlass_fp4_moe_mm(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment