Unverified Commit fcb9df99 authored by Roberto L. Castro's avatar Roberto L. Castro Committed by GitHub
Browse files

[Perf][Kernel] Optimize FP4 quantization kernels (SM100F) (#32520)


Signed-off-by: default avatarLopezCastroRoberto <rocastro@redhat.com>
parent 1ebdff41
...@@ -20,8 +20,12 @@ FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max() ...@@ -20,8 +20,12 @@ FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
PROVIDER_CFGS = { PROVIDER_CFGS = {
"vllm": dict(backend="vllm", enabled=True), "vllm": dict(backend="vllm", is_sf_swizzled_layout=False, enabled=True),
"flashinfer": dict(backend="flashinfer", enabled=True), "vllm-swizzle": dict(backend="vllm", is_sf_swizzled_layout=True, enabled=True),
"flashinfer": dict(backend="flashinfer", is_sf_swizzled_layout=False, enabled=True),
"flashinfer-swizzle": dict(
backend="flashinfer", is_sf_swizzled_layout=True, enabled=True
),
} }
_enabled = [k for k, v in PROVIDER_CFGS.items() if v["enabled"]] _enabled = [k for k, v in PROVIDER_CFGS.items() if v["enabled"]]
...@@ -36,7 +40,7 @@ def compute_global_scale(tensor: torch.Tensor) -> torch.Tensor: ...@@ -36,7 +40,7 @@ def compute_global_scale(tensor: torch.Tensor) -> torch.Tensor:
@triton.testing.perf_report( @triton.testing.perf_report(
triton.testing.Benchmark( triton.testing.Benchmark(
x_names=["batch_size"], x_names=["batch_size"],
x_vals=[1, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096], x_vals=[1, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192],
x_log=False, x_log=False,
line_arg="provider", line_arg="provider",
line_vals=_enabled, line_vals=_enabled,
...@@ -63,19 +67,36 @@ def benchmark(batch_size, provider, N, K): ...@@ -63,19 +67,36 @@ def benchmark(batch_size, provider, N, K):
if cfg["backend"] == "vllm": if cfg["backend"] == "vllm":
# vLLM's FP4 quantization # vLLM's FP4 quantization
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( if cfg["is_sf_swizzled_layout"]:
lambda: ops.scaled_fp4_quant(a, a_global_scale), ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
quantiles=quantiles, lambda: ops.scaled_fp4_quant(
) a, a_global_scale, is_sf_swizzled_layout=True
),
quantiles=quantiles,
)
else:
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
lambda: ops.scaled_fp4_quant(
a, a_global_scale, is_sf_swizzled_layout=False
),
quantiles=quantiles,
)
elif cfg["backend"] == "flashinfer": elif cfg["backend"] == "flashinfer":
# FlashInfer's FP4 quantization # FlashInfer's FP4 quantization
# Use is_sf_swizzled_layout=True to match vLLM's output format if cfg["is_sf_swizzled_layout"]:
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
lambda: flashinfer_fp4_quantize( lambda: flashinfer_fp4_quantize(
a, a_global_scale, is_sf_swizzled_layout=True a, a_global_scale, is_sf_swizzled_layout=True
), ),
quantiles=quantiles, quantiles=quantiles,
) )
else:
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
lambda: flashinfer_fp4_quantize(
a, a_global_scale, is_sf_swizzled_layout=False
),
quantiles=quantiles,
)
# Convert ms to us for better readability at small batch sizes # Convert ms to us for better readability at small batch sizes
to_us = lambda t_ms: t_ms * 1000 to_us = lambda t_ms: t_ms * 1000
...@@ -92,7 +113,9 @@ def prepare_shapes(args): ...@@ -92,7 +113,9 @@ def prepare_shapes(args):
return out return out
def _test_accuracy_once(M: int, K: int, dtype: torch.dtype, device: str): def _test_accuracy_once(
M: int, K: int, dtype: torch.dtype, device: str, is_sf_swizzled_layout: bool
):
"""Test accuracy between vLLM and FlashInfer FP4 quantization.""" """Test accuracy between vLLM and FlashInfer FP4 quantization."""
# Create input tensor # Create input tensor
a = torch.randn((M, K), device=device, dtype=dtype) a = torch.randn((M, K), device=device, dtype=dtype)
...@@ -101,11 +124,13 @@ def _test_accuracy_once(M: int, K: int, dtype: torch.dtype, device: str): ...@@ -101,11 +124,13 @@ def _test_accuracy_once(M: int, K: int, dtype: torch.dtype, device: str):
a_global_scale = compute_global_scale(a) a_global_scale = compute_global_scale(a)
# vLLM quantization # vLLM quantization
vllm_fp4, vllm_scale = ops.scaled_fp4_quant(a, a_global_scale) vllm_fp4, vllm_scale = ops.scaled_fp4_quant(
a, a_global_scale, is_sf_swizzled_layout=is_sf_swizzled_layout
)
# FlashInfer quantization (with swizzled layout to match vLLM's output) # FlashInfer quantization (with swizzled layout to match vLLM's output)
flashinfer_fp4, flashinfer_scale = flashinfer_fp4_quantize( flashinfer_fp4, flashinfer_scale = flashinfer_fp4_quantize(
a, a_global_scale, is_sf_swizzled_layout=True a, a_global_scale, is_sf_swizzled_layout=is_sf_swizzled_layout
) )
flashinfer_scale = flashinfer_scale.view(torch.float8_e4m3fn) flashinfer_scale = flashinfer_scale.view(torch.float8_e4m3fn)
...@@ -114,7 +139,14 @@ def _test_accuracy_once(M: int, K: int, dtype: torch.dtype, device: str): ...@@ -114,7 +139,14 @@ def _test_accuracy_once(M: int, K: int, dtype: torch.dtype, device: str):
vllm_fp4, vllm_fp4,
flashinfer_fp4, flashinfer_fp4,
) )
print(f"M={M}, K={K}, dtype={dtype}: PASSED") # Compare scales
torch.testing.assert_close(
vllm_scale,
flashinfer_scale,
)
print(
f"M={M}, K={K}, dtype={dtype}, is_sf_swizzled_layout={is_sf_swizzled_layout}: PASSED" # noqa: E501
)
def test_accuracy(): def test_accuracy():
...@@ -130,9 +162,10 @@ def test_accuracy(): ...@@ -130,9 +162,10 @@ def test_accuracy():
Ms = [1, 1024] Ms = [1, 1024]
Ks = [4096] Ks = [4096]
for M in Ms: for is_sf_swizzled_layout in [True, False]:
for K in Ks: for M in Ms:
_test_accuracy_once(M, K, dtype, device) for K in Ks:
_test_accuracy_once(M, K, dtype, device, is_sf_swizzled_layout)
print("\nAll accuracy tests passed!") print("\nAll accuracy tests passed!")
...@@ -145,7 +178,7 @@ if __name__ == "__main__": ...@@ -145,7 +178,7 @@ if __name__ == "__main__":
"--models", "--models",
nargs="+", nargs="+",
type=str, type=str,
default=["meta-llama/Llama-3.1-8B-Instruct"], default=["meta-llama/Llama-3.3-70B-Instruct"],
choices=list(WEIGHT_SHAPES.keys()), choices=list(WEIGHT_SHAPES.keys()),
) )
parser.add_argument("--tp-sizes", nargs="+", type=int, default=[1]) parser.add_argument("--tp-sizes", nargs="+", type=int, default=[1])
......
...@@ -293,7 +293,8 @@ std::vector<torch::Tensor> cutlass_sparse_compress(torch::Tensor const& a); ...@@ -293,7 +293,8 @@ std::vector<torch::Tensor> cutlass_sparse_compress(torch::Tensor const& a);
void scaled_fp4_quant(torch::Tensor& output, torch::Tensor const& input, void scaled_fp4_quant(torch::Tensor& output, torch::Tensor const& input,
torch::Tensor& output_scale, torch::Tensor& output_scale,
torch::Tensor const& input_scale); torch::Tensor const& input_scale,
bool is_sf_swizzled_layout);
void scaled_fp4_experts_quant( void scaled_fp4_experts_quant(
torch::Tensor& output, torch::Tensor& output_scale, torch::Tensor& output, torch::Tensor& output_scale,
......
...@@ -27,17 +27,24 @@ ...@@ -27,17 +27,24 @@
#include "cuda_utils.h" #include "cuda_utils.h"
#include "launch_bounds_utils.h" #include "launch_bounds_utils.h"
// Define before including nvfp4_utils.cuh so the header
// can use this macro during compilation.
#define NVFP4_ENABLE_ELTS16 1
#include "nvfp4_utils.cuh" #include "nvfp4_utils.cuh"
namespace vllm { namespace vllm {
// Use UE4M3 by default. // Use UE4M3 by default.
template <class Type, bool UE8M0_SF = false> template <class Type, bool UE8M0_SF = false>
__global__ void __launch_bounds__(1024, VLLM_BLOCKS_PER_SM(1024)) __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
silu_mul_cvt_fp16_to_fp4(int32_t numRows, int32_t numCols, Type const* in, silu_mul_cvt_fp16_to_fp4(int32_t numRows, int32_t numCols,
float const* SFScale, uint32_t* out, int32_t num_padded_cols,
uint32_t* SFout) { Type const* __restrict__ in,
using PackedVec = PackedVec<Type>; float const* __restrict__ SFScale,
uint32_t* __restrict__ out,
uint32_t* __restrict__ SFout) {
using PackedVec = vllm::PackedVec<Type>;
static constexpr int CVT_FP4_NUM_THREADS_PER_SF = static constexpr int CVT_FP4_NUM_THREADS_PER_SF =
(CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD); (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD);
static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD, static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD,
...@@ -49,34 +56,60 @@ __global__ void __launch_bounds__(1024, VLLM_BLOCKS_PER_SM(1024)) ...@@ -49,34 +56,60 @@ __global__ void __launch_bounds__(1024, VLLM_BLOCKS_PER_SM(1024))
// Get the global scaling factor, which will be applied to the SF. // Get the global scaling factor, which will be applied to the SF.
// Note SFScale is the same as next GEMM's alpha, which is // Note SFScale is the same as next GEMM's alpha, which is
// (448.f / (Alpha_A / 6.f)). // (448.f / (Alpha_A / 6.f)).
float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[0]; float const SFScaleVal = (SFScale == nullptr) ? 1.0f : SFScale[0];
int32_t const colIdx = blockDim.x * blockIdx.y + threadIdx.x;
int elem_idx = colIdx * CVT_FP4_ELTS_PER_THREAD;
// Input tensor row/col loops. // Input tensor row/col loops.
for (int rowIdx = blockIdx.x; rowIdx < numRows; rowIdx += gridDim.x) { for (int rowIdx = blockIdx.x; rowIdx < numRows; rowIdx += gridDim.x) {
for (int colIdx = threadIdx.x; colIdx < numCols / CVT_FP4_ELTS_PER_THREAD; if (colIdx < num_padded_cols) {
colIdx += blockDim.x) { PackedVec in_vec;
PackedVec in_vec2;
int64_t inOffset = int64_t inOffset =
rowIdx * (numCols * 2 / CVT_FP4_ELTS_PER_THREAD) + colIdx; rowIdx * (numCols * 2 / CVT_FP4_ELTS_PER_THREAD) + colIdx;
int64_t inOffset2 = rowIdx * (numCols * 2 / CVT_FP4_ELTS_PER_THREAD) + int64_t inOffset2 = rowIdx * (numCols * 2 / CVT_FP4_ELTS_PER_THREAD) +
numCols / CVT_FP4_ELTS_PER_THREAD + colIdx; numCols / CVT_FP4_ELTS_PER_THREAD + colIdx;
PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
PackedVec in_vec2 = reinterpret_cast<PackedVec const*>(in)[inOffset2];
// Get the output tensor offset. bool valid = (rowIdx < numRows) && (elem_idx < numCols);
// Same as inOffset because 8 elements are packed into one uint32_t. if constexpr (CVT_FP4_PACK16) {
int64_t outOffset = rowIdx * (numCols / CVT_FP4_ELTS_PER_THREAD) + colIdx; ld256_or_zero_cg_u32<Type>(
auto& out_pos = out[outOffset]; in_vec, &reinterpret_cast<const uint32_t*>(in)[inOffset * 8],
valid);
ld256_or_zero_cg_u32<Type>(
in_vec2, &reinterpret_cast<const uint32_t*>(in)[inOffset2 * 8],
valid);
} else {
ld128_or_zero_cg_u32<Type>(
in_vec, &reinterpret_cast<const uint32_t*>(in)[inOffset * 4],
valid);
ld128_or_zero_cg_u32<Type>(
in_vec2, &reinterpret_cast<const uint32_t*>(in)[inOffset2 * 4],
valid);
}
// Compute silu and mul // Compute silu and mul
PackedVec out_silu_mul = compute_silu_mul(in_vec, in_vec2); PackedVec out_silu_mul = compute_silu_mul<Type>(in_vec, in_vec2);
auto sf_out = auto sf_out =
cvt_quant_to_fp4_get_sf_out_offset<uint32_t, cvt_quant_to_fp4_get_sf_out_offset<uint32_t,
CVT_FP4_NUM_THREADS_PER_SF>( CVT_FP4_NUM_THREADS_PER_SF>(
rowIdx, colIdx, numKTiles, SFout); rowIdx, colIdx, numKTiles, SFout);
out_pos = cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(out_silu_mul, SFScaleVal, auto out_val =
sf_out); cvt_warp_fp16_to_fp4<Type, CVT_FP4_NUM_THREADS_PER_SF, UE8M0_SF>(
out_silu_mul, SFScaleVal, sf_out);
if (valid) {
if constexpr (CVT_FP4_PACK16) {
int64_t outOffset = rowIdx * (numCols / 8) + colIdx * 2;
uint64_t packed64 =
(uint64_t(out_val.hi) << 32) | uint64_t(out_val.lo);
reinterpret_cast<uint64_t*>(out)[outOffset >> 1] = packed64;
} else {
out[inOffset] = out_val;
}
}
} }
} }
} }
...@@ -103,17 +136,23 @@ void silu_and_mul_nvfp4_quant_sm1xxa(torch::Tensor& output, // [..., d] ...@@ -103,17 +136,23 @@ void silu_and_mul_nvfp4_quant_sm1xxa(torch::Tensor& output, // [..., d]
auto output_ptr = static_cast<int64_t*>(output.data_ptr()); auto output_ptr = static_cast<int64_t*>(output.data_ptr());
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
auto stream = at::cuda::getCurrentCUDAStream(input.get_device()); auto stream = at::cuda::getCurrentCUDAStream(input.get_device());
dim3 block(std::min(int(n / ELTS_PER_THREAD), 1024)); dim3 block(std::min(int(n / ELTS_PER_THREAD), 512));
int const numBlocksPerSM = int const numBlocksPerSM =
vllm_runtime_blocks_per_sm(static_cast<int>(block.x)); vllm_runtime_blocks_per_sm(static_cast<int>(block.x));
dim3 grid(std::min(int(m), multiProcessorCount * numBlocksPerSM));
int sf_n_unpadded = int(n / CVT_FP4_SF_VEC_SIZE);
int grid_y = vllm::div_round_up(sf_n_unpadded, static_cast<int>(block.x));
int grid_x = std::min(
int(m), std::max(1, (multiProcessorCount * numBlocksPerSM) / grid_y));
dim3 grid(grid_x, grid_y);
VLLM_DISPATCH_HALF_TYPES( VLLM_DISPATCH_HALF_TYPES(
input.scalar_type(), "silu_and_mul_nvfp4_quant_kernel", [&] { input.scalar_type(), "silu_and_mul_nvfp4_quant_kernel", [&] {
using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type; using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type;
auto input_ptr = static_cast<cuda_type const*>(input.data_ptr()); auto input_ptr = static_cast<cuda_type const*>(input.data_ptr());
vllm::silu_mul_cvt_fp16_to_fp4<cuda_type><<<grid, block, 0, stream>>>( vllm::silu_mul_cvt_fp16_to_fp4<cuda_type><<<grid, block, 0, stream>>>(
m, n, input_ptr, input_sf_ptr, m, n, sf_n_unpadded, input_ptr, input_sf_ptr,
reinterpret_cast<uint32_t*>(output_ptr), reinterpret_cast<uint32_t*>(output_ptr),
reinterpret_cast<uint32_t*>(sf_out)); reinterpret_cast<uint32_t*>(sf_out));
}); });
......
...@@ -140,8 +140,8 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512)) ...@@ -140,8 +140,8 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
CVT_FP4_NUM_THREADS_PER_SF>( CVT_FP4_NUM_THREADS_PER_SF>(
rowIdx_in_expert, colIdx, numKTiles, SFout_in_expert); rowIdx_in_expert, colIdx, numKTiles, SFout_in_expert);
out_pos = out_pos = cvt_warp_fp16_to_fp4<Type, CVT_FP4_NUM_THREADS_PER_SF, UE8M0_SF>(
cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(quant_input, SFScaleVal, sf_out); quant_input, SFScaleVal, sf_out);
} }
} }
...@@ -246,8 +246,8 @@ __global__ void __launch_bounds__(1024, VLLM_BLOCKS_PER_SM(1024)) ...@@ -246,8 +246,8 @@ __global__ void __launch_bounds__(1024, VLLM_BLOCKS_PER_SM(1024))
CVT_FP4_NUM_THREADS_PER_SF>( CVT_FP4_NUM_THREADS_PER_SF>(
rowIdx_in_expert, colIdx, numKTiles, SFout_in_expert); rowIdx_in_expert, colIdx, numKTiles, SFout_in_expert);
out_pos = out_pos = cvt_warp_fp16_to_fp4<Type, CVT_FP4_NUM_THREADS_PER_SF, UE8M0_SF>(
cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(quant_input, SFScaleVal, sf_out); quant_input, SFScaleVal, sf_out);
} }
} }
......
...@@ -21,7 +21,8 @@ ...@@ -21,7 +21,8 @@
void scaled_fp4_quant_sm1xxa(torch::Tensor const& output, void scaled_fp4_quant_sm1xxa(torch::Tensor const& output,
torch::Tensor const& input, torch::Tensor const& input,
torch::Tensor const& output_sf, torch::Tensor const& output_sf,
torch::Tensor const& input_sf); torch::Tensor const& input_sf,
bool is_sf_swizzled_layout);
#endif #endif
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \ #if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
...@@ -51,10 +52,12 @@ void silu_and_mul_scaled_fp4_experts_quant_sm1xxa( ...@@ -51,10 +52,12 @@ void silu_and_mul_scaled_fp4_experts_quant_sm1xxa(
#endif #endif
void scaled_fp4_quant(torch::Tensor& output, torch::Tensor const& input, void scaled_fp4_quant(torch::Tensor& output, torch::Tensor const& input,
torch::Tensor& output_sf, torch::Tensor const& input_sf) { torch::Tensor& output_sf, torch::Tensor const& input_sf,
bool is_sf_swizzled_layout) {
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \ #if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120) (defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)
return scaled_fp4_quant_sm1xxa(output, input, output_sf, input_sf); return scaled_fp4_quant_sm1xxa(output, input, output_sf, input_sf,
is_sf_swizzled_layout);
#endif #endif
TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 quantization kernel"); TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 quantization kernel");
} }
......
...@@ -27,29 +27,23 @@ ...@@ -27,29 +27,23 @@
#include "cuda_utils.h" #include "cuda_utils.h"
#include "launch_bounds_utils.h" #include "launch_bounds_utils.h"
// Define before including nvfp4_utils.cuh so the header
// can use this macro during compilation.
#define NVFP4_ENABLE_ELTS16 1
#include "nvfp4_utils.cuh" #include "nvfp4_utils.cuh"
namespace vllm { namespace vllm {
template <typename Int>
__host__ __device__ inline Int round_up(Int x, Int y) {
static_assert(std::is_integral_v<Int>,
"round_up argument must be integral type");
return ((x + y - 1) / y) * y;
}
// Compute effective rows for grid configuration with swizzled SF layouts.
inline int computeEffectiveRows(int m) {
constexpr int ROW_TILE = 128;
return round_up(m, ROW_TILE);
}
// Use UE4M3 by default. // Use UE4M3 by default.
template <class Type, bool UE8M0_SF = false> template <class Type, bool UE8M0_SF = false>
__global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512)) __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
cvt_fp16_to_fp4(int32_t numRows, int32_t numCols, Type const* in, cvt_fp16_to_fp4(int32_t numRows, int32_t numCols, int32_t num_padded_cols,
float const* SFScale, uint32_t* out, uint32_t* SFout) { Type const* __restrict__ in,
using PackedVec = PackedVec<Type>; float const* __restrict__ SFScale,
uint32_t* __restrict__ out, uint32_t* __restrict__ SFout) {
using PackedVec = vllm::PackedVec<Type>;
static constexpr int CVT_FP4_NUM_THREADS_PER_SF = static constexpr int CVT_FP4_NUM_THREADS_PER_SF =
(CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD); (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD);
static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD, static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD,
...@@ -59,33 +53,31 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512)) ...@@ -59,33 +53,31 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
int32_t const numKTiles = (numCols + 63) / 64; int32_t const numKTiles = (numCols + 63) / 64;
int sf_m = round_up<int>(numRows, 128); int sf_m = round_up<int>(numRows, 128);
int sf_n_unpadded = numCols / CVT_FP4_SF_VEC_SIZE; int32_t const colIdx = blockDim.x * blockIdx.y + threadIdx.x;
int sf_n_int = round_up<int>(sf_n_unpadded, 4) / 4; int elem_idx = colIdx * CVT_FP4_ELTS_PER_THREAD;
int num_padded_cols = sf_n_int * 4 * CVT_FP4_SF_VEC_SIZE;
// Get the global scaling factor, which will be applied to the SF. // Get the global scaling factor, which will be applied to the SF.
// Note SFScale is the same as next GEMM's alpha, which is // Note SFScale is the same as next GEMM's alpha, which is
// (448.f / (Alpha_A / 6.f)). // (448.f / (Alpha_A / 6.f)).
float const global_scale = SFScale == nullptr ? 1.0f : SFScale[0]; float const global_scale = (SFScale == nullptr) ? 1.0f : SFScale[0];
// Iterate over all rows and cols including padded ones - // Iterate over all rows and cols including padded ones -
// ensures we visit every single scale factor address to initialize it. // ensures we visit every single scale factor address to initialize it.
for (int rowIdx = blockIdx.x; rowIdx < sf_m; rowIdx += gridDim.x) { for (int rowIdx = blockIdx.x; rowIdx < sf_m; rowIdx += gridDim.x) {
for (int colIdx = threadIdx.x; if (colIdx < num_padded_cols) {
colIdx < num_padded_cols / CVT_FP4_ELTS_PER_THREAD;
colIdx += blockDim.x) {
int elem_idx = colIdx * CVT_FP4_ELTS_PER_THREAD;
PackedVec in_vec; PackedVec in_vec;
int64_t inOffset = rowIdx * (numCols / CVT_FP4_ELTS_PER_THREAD) + colIdx; int64_t inOffset = rowIdx * (numCols / CVT_FP4_ELTS_PER_THREAD) + colIdx;
// If we are outside valid rows OR outside valid columns -> Use Zeros // If we are outside valid rows OR outside valid columns -> Use Zeros
if (rowIdx >= numRows || elem_idx >= numCols) { bool valid = (rowIdx < numRows) && (elem_idx < numCols);
memset(&in_vec, 0, sizeof(PackedVec)); if constexpr (CVT_FP4_PACK16) {
ld256_or_zero_cg_u32<Type>(
in_vec, &reinterpret_cast<const uint32_t*>(in)[inOffset * 8],
valid);
} else { } else {
// Valid Region: Load actual data ld128_or_zero_cg_u32<Type>(
in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset]; in_vec, &reinterpret_cast<const uint32_t*>(in)[inOffset * 4],
valid);
} }
auto sf_out = auto sf_out =
...@@ -94,13 +86,85 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512)) ...@@ -94,13 +86,85 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
rowIdx, colIdx, numKTiles, SFout); rowIdx, colIdx, numKTiles, SFout);
auto out_val = auto out_val =
cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(in_vec, global_scale, sf_out); cvt_warp_fp16_to_fp4<Type, CVT_FP4_NUM_THREADS_PER_SF, UE8M0_SF>(
in_vec, global_scale, sf_out);
// We do NOT write output for padding because the 'out' tensor is not
// padded.
if (valid) {
if constexpr (CVT_FP4_PACK16) {
int64_t outOffset = rowIdx * (numCols / 8) + colIdx * 2;
uint64_t packed64 =
(uint64_t(out_val.hi) << 32) | uint64_t(out_val.lo);
reinterpret_cast<uint64_t*>(out)[outOffset >> 1] = packed64;
} else {
out[inOffset] = out_val;
}
}
}
}
}
// Use UE4M3 by default.
template <class Type, bool UE8M0_SF = false>
__global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
cvt_fp16_to_fp4_sf_major(int32_t numRows, int32_t numCols,
int32_t sf_n_unpadded, Type const* __restrict__ in,
float const* __restrict__ SFScale,
uint32_t* __restrict__ out,
uint32_t* __restrict__ SFout) {
using PackedVec = PackedVec<Type>;
static constexpr int CVT_FP4_NUM_THREADS_PER_SF =
(CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD);
static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD,
"Vec size is not matched.");
int32_t const colIdx = blockDim.x * blockIdx.y + threadIdx.x;
int elem_idx = colIdx * CVT_FP4_ELTS_PER_THREAD;
// Get the global scaling factor, which will be applied to the SF.
// Note SFScale is the same as next GEMM's alpha, which is
// (448.f / (Alpha_A / 6.f)).
float const global_scale = (SFScale == nullptr) ? 1.0f : SFScale[0];
// Iterate over all rows and cols including padded ones -
// ensures we visit every single scale factor address to initialize it.
for (int rowIdx = blockIdx.x; rowIdx < numRows; rowIdx += gridDim.x) {
if (colIdx < sf_n_unpadded) {
PackedVec in_vec;
int64_t inOffset = rowIdx * (numCols / CVT_FP4_ELTS_PER_THREAD) + colIdx;
// If we are outside valid rows OR outside valid columns -> Use Zeros
bool valid = (rowIdx < numRows) && (elem_idx < numCols);
if constexpr (CVT_FP4_PACK16) {
ld256_or_zero_cg_u32<Type>(
in_vec, &reinterpret_cast<const uint32_t*>(in)[inOffset * 8],
valid);
} else {
ld128_or_zero_cg_u32<Type>(
in_vec, &reinterpret_cast<const uint32_t*>(in)[inOffset * 4],
valid);
}
auto sf_out =
sf_out_rowmajor_u8<uint32_t>(rowIdx, colIdx, sf_n_unpadded, SFout);
auto out_val =
cvt_warp_fp16_to_fp4<Type, CVT_FP4_NUM_THREADS_PER_SF, UE8M0_SF>(
in_vec, global_scale, sf_out);
// We do NOT write output for padding because the 'out' tensor is not // We do NOT write output for padding because the 'out' tensor is not
// padded. // padded.
if (rowIdx < numRows && elem_idx < numCols) { if (valid) {
// Same as inOffset because 8 elements are packed into one uint32_t. if constexpr (CVT_FP4_PACK16) {
out[inOffset] = out_val; int64_t outOffset = rowIdx * (numCols / 8) + colIdx * 2;
uint64_t packed64 =
(uint64_t(out_val.hi) << 32) | uint64_t(out_val.lo);
reinterpret_cast<uint64_t*>(out)[outOffset >> 1] = packed64;
} else {
out[inOffset] = out_val;
}
} }
} }
} }
...@@ -111,7 +175,8 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512)) ...@@ -111,7 +175,8 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
void scaled_fp4_quant_sm1xxa(torch::Tensor const& output, void scaled_fp4_quant_sm1xxa(torch::Tensor const& output,
torch::Tensor const& input, torch::Tensor const& input,
torch::Tensor const& output_sf, torch::Tensor const& output_sf,
torch::Tensor const& input_sf) { torch::Tensor const& input_sf,
bool is_sf_swizzled_layout) {
int32_t m = input.size(0); int32_t m = input.size(0);
int32_t n = input.size(1); int32_t n = input.size(1);
...@@ -129,19 +194,48 @@ void scaled_fp4_quant_sm1xxa(torch::Tensor const& output, ...@@ -129,19 +194,48 @@ void scaled_fp4_quant_sm1xxa(torch::Tensor const& output,
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
auto stream = at::cuda::getCurrentCUDAStream(input.get_device()); auto stream = at::cuda::getCurrentCUDAStream(input.get_device());
int sf_n_unpadded = int(n / CVT_FP4_SF_VEC_SIZE);
// Grid, Block size. Each thread converts 8 values. // Grid, Block size. Each thread converts 8 values.
dim3 block(std::min(int(n / ELTS_PER_THREAD), 512)); dim3 block(std::min(int(n / ELTS_PER_THREAD), 512));
int const numBlocksPerSM = int const numBlocksPerSM =
vllm_runtime_blocks_per_sm(static_cast<int>(block.x)); vllm_runtime_blocks_per_sm(static_cast<int>(block.x));
int effectiveRows = vllm::computeEffectiveRows(m);
dim3 grid(std::min(effectiveRows, multiProcessorCount * numBlocksPerSM)); if (is_sf_swizzled_layout) {
int sf_n_int = int(vllm::round_up(sf_n_unpadded, 4) / 4);
VLLM_DISPATCH_HALF_TYPES(input.scalar_type(), "nvfp4_quant_kernel", [&] { int32_t num_padded_cols =
using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type; sf_n_int * 4 * CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD;
auto input_ptr = static_cast<cuda_type const*>(input.data_ptr());
// NOTE: We don't support e8m0 scales at this moment. int grid_y = vllm::div_round_up(num_padded_cols, static_cast<int>(block.x));
vllm::cvt_fp16_to_fp4<cuda_type, false><<<grid, block, 0, stream>>>( int grid_x =
m, n, input_ptr, input_sf_ptr, reinterpret_cast<uint32_t*>(output_ptr), std::min(vllm::computeEffectiveRows(m),
reinterpret_cast<uint32_t*>(sf_out)); std::max(1, (multiProcessorCount * numBlocksPerSM) / grid_y));
}); dim3 grid(grid_x, grid_y);
}
\ No newline at end of file VLLM_DISPATCH_HALF_TYPES(input.scalar_type(), "nvfp4_quant_kernel", [&] {
using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type;
auto input_ptr = static_cast<cuda_type const*>(input.data_ptr());
// NOTE: We don't support e8m0 scales at this moment.
vllm::cvt_fp16_to_fp4<cuda_type, false><<<grid, block, 0, stream>>>(
m, n, num_padded_cols, input_ptr, input_sf_ptr,
reinterpret_cast<uint32_t*>(output_ptr),
reinterpret_cast<uint32_t*>(sf_out));
});
} else {
int grid_y = vllm::div_round_up(sf_n_unpadded, static_cast<int>(block.x));
int grid_x = std::min(
m, std::max(1, (multiProcessorCount * numBlocksPerSM) / grid_y));
dim3 grid(grid_x, grid_y);
VLLM_DISPATCH_HALF_TYPES(input.scalar_type(), "nvfp4_quant_kernel", [&] {
using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type;
auto input_ptr = static_cast<cuda_type const*>(input.data_ptr());
// NOTE: We don't support e8m0 scales at this moment.
vllm::cvt_fp16_to_fp4_sf_major<cuda_type, false>
<<<grid, block, 0, stream>>>(m, n, sf_n_unpadded, input_ptr,
input_sf_ptr,
reinterpret_cast<uint32_t*>(output_ptr),
reinterpret_cast<uint32_t*>(sf_out));
});
}
}
...@@ -19,9 +19,17 @@ ...@@ -19,9 +19,17 @@
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <cuda_fp8.h> #include <cuda_fp8.h>
#define ELTS_PER_THREAD 8 #if (defined(NVFP4_ENABLE_ELTS16) && (CUDART_VERSION >= 12090) && \
defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100)
#define ELTS_PER_THREAD 16
constexpr int CVT_FP4_ELTS_PER_THREAD = 16;
constexpr bool CVT_FP4_PACK16 = true;
#else
#define ELTS_PER_THREAD 8
constexpr int CVT_FP4_ELTS_PER_THREAD = 8; constexpr int CVT_FP4_ELTS_PER_THREAD = 8;
constexpr bool CVT_FP4_PACK16 = false;
#endif
constexpr int CVT_FP4_SF_VEC_SIZE = 16; constexpr int CVT_FP4_SF_VEC_SIZE = 16;
namespace vllm { namespace vllm {
...@@ -68,19 +76,46 @@ struct TypeConverter<__nv_bfloat16> { ...@@ -68,19 +76,46 @@ struct TypeConverter<__nv_bfloat16> {
using Type = __nv_bfloat162; using Type = __nv_bfloat162;
}; };
#if (defined(NVFP4_ENABLE_ELTS16) && (CUDART_VERSION >= 12090) && \
defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100)
// Define a 32 bytes packed data type.
template <class Type>
struct alignas(32) PackedVec {
typename TypeConverter<Type>::Type elts[8];
};
#else
// Define a 16 bytes packed data type. // Define a 16 bytes packed data type.
template <class Type> template <class Type>
struct PackedVec { struct alignas(16) PackedVec {
typename TypeConverter<Type>::Type elts[4]; typename TypeConverter<Type>::Type elts[4];
}; };
#endif
template <> template <>
struct PackedVec<__nv_fp8_e4m3> { struct PackedVec<__nv_fp8_e4m3> {
__nv_fp8x2_e4m3 elts[8]; __nv_fp8x2_e4m3 elts[8];
}; };
template <typename Int>
__host__ __device__ inline Int round_up(Int x, Int y) {
static_assert(std::is_integral_v<Int>,
"round_up argument must be integral type");
return ((x + y - 1) / y) * y;
}
template <typename Int>
__host__ __device__ __forceinline__ Int div_round_up(Int x, Int y) {
return (x + y - 1) / y;
}
// Compute effective rows for grid configuration with swizzled SF layouts.
inline int computeEffectiveRows(int m) {
constexpr int ROW_TILE = 128;
return round_up(m, ROW_TILE);
}
// Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t). // Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t).
inline __device__ uint32_t fp32_vec_to_e2m1(float (&array)[8]) { inline __device__ uint32_t fp32_vec8_to_e2m1(float (&array)[8]) {
uint32_t val; uint32_t val;
asm volatile( asm volatile(
"{\n" "{\n"
...@@ -101,7 +136,7 @@ inline __device__ uint32_t fp32_vec_to_e2m1(float (&array)[8]) { ...@@ -101,7 +136,7 @@ inline __device__ uint32_t fp32_vec_to_e2m1(float (&array)[8]) {
} }
// Convert 4 float2 values into 8 e2m1 values (represented as one uint32_t). // Convert 4 float2 values into 8 e2m1 values (represented as one uint32_t).
inline __device__ uint32_t fp32_vec_to_e2m1(float2 (&array)[4]) { __device__ __forceinline__ uint32_t fp32_vec8_to_e2m1(float2 (&array)[4]) {
uint32_t val; uint32_t val;
asm volatile( asm volatile(
"{\n" "{\n"
...@@ -114,20 +149,115 @@ inline __device__ uint32_t fp32_vec_to_e2m1(float2 (&array)[4]) { ...@@ -114,20 +149,115 @@ inline __device__ uint32_t fp32_vec_to_e2m1(float2 (&array)[4]) {
"cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n" "cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n"
"cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n" "cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n"
"mov.b32 %0, {byte0, byte1, byte2, byte3};\n" "mov.b32 %0, {byte0, byte1, byte2, byte3};\n"
"}" "}\n"
: "=r"(val) : "=r"(val)
: "f"(array[0].x), "f"(array[0].y), "f"(array[1].x), "f"(array[1].y), : "f"(array[0].x), "f"(array[0].y), "f"(array[1].x), "f"(array[1].y),
"f"(array[2].x), "f"(array[2].y), "f"(array[3].x), "f"(array[3].y)); "f"(array[2].x), "f"(array[2].y), "f"(array[3].x), "f"(array[3].y));
return val; return val;
} }
struct u32x2 {
uint32_t lo, hi;
};
using fp4_packed_t = std::conditional_t<CVT_FP4_PACK16, u32x2, uint32_t>;
__device__ __forceinline__ u32x2 fp32_vec16_to_e2m1(float2 (&array)[8]) {
u32x2 out;
asm volatile(
"{\n"
".reg .b8 b0;\n"
".reg .b8 b1;\n"
".reg .b8 b2;\n"
".reg .b8 b3;\n"
".reg .b8 b4;\n"
".reg .b8 b5;\n"
".reg .b8 b6;\n"
".reg .b8 b7;\n"
"cvt.rn.satfinite.e2m1x2.f32 b0, %3, %2;\n"
"cvt.rn.satfinite.e2m1x2.f32 b1, %5, %4;\n"
"cvt.rn.satfinite.e2m1x2.f32 b2, %7, %6;\n"
"cvt.rn.satfinite.e2m1x2.f32 b3, %9, %8;\n"
"cvt.rn.satfinite.e2m1x2.f32 b4, %11, %10;\n"
"cvt.rn.satfinite.e2m1x2.f32 b5, %13, %12;\n"
"cvt.rn.satfinite.e2m1x2.f32 b6, %15, %14;\n"
"cvt.rn.satfinite.e2m1x2.f32 b7, %17, %16;\n"
"mov.b32 %0, {b0, b1, b2, b3};\n"
"mov.b32 %1, {b4, b5, b6, b7};\n"
"}\n"
: "=r"(out.lo), "=r"(out.hi)
: "f"(array[0].x), "f"(array[0].y), "f"(array[1].x), "f"(array[1].y),
"f"(array[2].x), "f"(array[2].y), "f"(array[3].x), "f"(array[3].y),
"f"(array[4].x), "f"(array[4].y), "f"(array[5].x), "f"(array[5].y),
"f"(array[6].x), "f"(array[6].y), "f"(array[7].x), "f"(array[7].y));
return out;
}
__device__ __forceinline__ uint32_t pack_fp4(float2 (&v)[4]) {
return fp32_vec8_to_e2m1(v);
}
__device__ __forceinline__ u32x2 pack_fp4(float2 (&v)[8]) {
return fp32_vec16_to_e2m1(v);
}
// Fast reciprocal. // Fast reciprocal.
inline __device__ float reciprocal_approximate_ftz(float a) { __device__ __forceinline__ float reciprocal_approximate_ftz(float a) {
float b; float b;
asm volatile("rcp.approx.ftz.f32 %0, %1;\n" : "=f"(b) : "f"(a)); asm volatile("rcp.approx.ftz.f32 %0, %1;" : "=f"(b) : "f"(a));
return b; return b;
} }
template <class Type>
__device__ __forceinline__ void ld128_or_zero_cg_u32(PackedVec<Type>& out,
const void* ptr,
bool pred) {
uint32_t r0, r1, r2, r3;
asm volatile(
"{\n"
" .reg .pred pr;\n"
" setp.ne.u32 pr, %4, 0;\n"
" mov.u32 %0, 0;\n"
" mov.u32 %1, 0;\n"
" mov.u32 %2, 0;\n"
" mov.u32 %3, 0;\n"
" @pr ld.global.cg.v4.u32 {%0,%1,%2,%3}, [%5];\n"
"}\n"
: "=r"(r0), "=r"(r1), "=r"(r2), "=r"(r3)
: "r"((int)pred), "l"(ptr));
*reinterpret_cast<uint4*>(&out) = uint4{r0, r1, r2, r3};
}
template <class Type>
__device__ __forceinline__ void ld256_or_zero_cg_u32(PackedVec<Type>& out,
const void* ptr,
bool pred) {
uint32_t r0, r1, r2, r3, r4, r5, r6, r7;
asm volatile(
"{\n"
" .reg .pred pr;\n"
" setp.ne.u32 pr, %8, 0;\n"
" mov.u32 %0, 0;\n"
" mov.u32 %1, 0;\n"
" mov.u32 %2, 0;\n"
" mov.u32 %3, 0;\n"
" mov.u32 %4, 0;\n"
" mov.u32 %5, 0;\n"
" mov.u32 %6, 0;\n"
" mov.u32 %7, 0;\n"
" @pr ld.global.cg.v8.u32 {%0,%1,%2,%3,%4,%5,%6,%7}, [%9];\n"
"}\n"
: "=r"(r0), "=r"(r1), "=r"(r2), "=r"(r3), "=r"(r4), "=r"(r5), "=r"(r6),
"=r"(r7)
: "r"((int)pred), "l"(ptr));
reinterpret_cast<uint4*>(&out)[0] = uint4{r0, r1, r2, r3};
reinterpret_cast<uint4*>(&out)[1] = uint4{r4, r5, r6, r7};
}
// Compute SF output offset for swizzled tensor core layout. // Compute SF output offset for swizzled tensor core layout.
// SF layout: [numMTiles, numKTiles, 32, 4, 4] // SF layout: [numMTiles, numKTiles, 32, 4, 4]
// Caller must precompute: numKTiles = (numCols + 63) / 64 // Caller must precompute: numKTiles = (numCols + 63) / 64
...@@ -166,21 +296,41 @@ __device__ __forceinline__ uint8_t* cvt_quant_to_fp4_get_sf_out_offset( ...@@ -166,21 +296,41 @@ __device__ __forceinline__ uint8_t* cvt_quant_to_fp4_get_sf_out_offset(
return reinterpret_cast<uint8_t*>(SFout) + SFOffset; return reinterpret_cast<uint8_t*>(SFout) + SFOffset;
} }
template <class SFType>
__device__ __forceinline__ uint8_t* sf_out_rowmajor_u8(int row, int pack,
int packs_per_row_sf,
SFType* SFout) {
constexpr int PACK = CVT_FP4_ELTS_PER_THREAD;
constexpr int THREADS_PER_SF =
CVT_FP4_SF_VEC_SIZE / PACK; // 1 if PACK=16, 2 else PACK=8
if (threadIdx.x % THREADS_PER_SF != 0) return nullptr;
int sf_col =
pack / THREADS_PER_SF; // PACK=16 => sf_col=pack; PACK=8 => sf_col=pack/2
int64_t off = (int64_t)row * packs_per_row_sf + sf_col;
return (uint8_t*)SFout + off;
}
// Quantizes the provided PackedVec into the uint32_t output // Quantizes the provided PackedVec into the uint32_t output
template <class Type, bool UE8M0_SF = false> template <class Type, int CVT_FP4_NUM_THREADS_PER_SF, bool UE8M0_SF = false>
__device__ uint32_t cvt_warp_fp16_to_fp4(PackedVec<Type>& vec, float SFScaleVal, __device__ __forceinline__ fp4_packed_t
uint8_t* SFout) { cvt_warp_fp16_to_fp4(PackedVec<Type>& vec, float SFScaleVal, uint8_t* SFout) {
// Get absolute maximum values among the local 8 values. // Get absolute maximum values among the local 8 values.
auto localMax = __habs2(vec.elts[0]); auto localMax = __habs2(vec.elts[0]);
// Local maximum value. // Local maximum value.
#pragma unroll #pragma unroll
for (int i = 1; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) { for (int i = 1; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) {
localMax = __hmax2(localMax, __habs2(vec.elts[i])); localMax = __hmax2(localMax, __habs2(vec.elts[i]));
} }
// Get the absolute maximum among all 16 values (two threads). // Get the absolute maximum among all 16 values (two threads).
localMax = __hmax2(__shfl_xor_sync(uint32_t(-1), localMax, 1), localMax);
if constexpr (CVT_FP4_NUM_THREADS_PER_SF == 2) {
localMax = __hmax2(__shfl_xor_sync(0xffffffffu, localMax, 1), localMax);
}
// Get the final absolute maximum values. // Get the final absolute maximum values.
float vecMax = float(__hmax(localMax.x, localMax.y)); float vecMax = float(__hmax(localMax.x, localMax.y));
...@@ -205,18 +355,17 @@ __device__ uint32_t cvt_warp_fp16_to_fp4(PackedVec<Type>& vec, float SFScaleVal, ...@@ -205,18 +355,17 @@ __device__ uint32_t cvt_warp_fp16_to_fp4(PackedVec<Type>& vec, float SFScaleVal,
// Convert back to fp32. // Convert back to fp32.
SFValue = float(tmp); SFValue = float(tmp);
} }
// Write the SF to global memory (STG.8).
if (SFout) *SFout = fp8SFVal;
// Get the output scale. // Get the output scale.
// Recipe: final_scale = reciprocal(fp32(fp8(SFValue * SFScaleVal))) * // Recipe: final_scale = reciprocal(fp32(fp8(SFValue * SFScaleVal))) *
// reciprocal(SFScaleVal)) // reciprocal(SFScaleVal))
float outputScale = float outputScale =
SFValue != 0 ? reciprocal_approximate_ftz( SFValue != 0.0f ? reciprocal_approximate_ftz(
SFValue * reciprocal_approximate_ftz(SFScaleVal)) SFValue * reciprocal_approximate_ftz(SFScaleVal))
: 0.0f; : 0.0f;
if (SFout) {
// Write the SF to global memory (STG.8).
*SFout = fp8SFVal;
}
// Convert the input to float. // Convert the input to float.
float2 fp2Vals[CVT_FP4_ELTS_PER_THREAD / 2]; float2 fp2Vals[CVT_FP4_ELTS_PER_THREAD / 2];
...@@ -233,10 +382,7 @@ __device__ uint32_t cvt_warp_fp16_to_fp4(PackedVec<Type>& vec, float SFScaleVal, ...@@ -233,10 +382,7 @@ __device__ uint32_t cvt_warp_fp16_to_fp4(PackedVec<Type>& vec, float SFScaleVal,
} }
// Convert to e2m1 values. // Convert to e2m1 values.
uint32_t e2m1Vec = fp32_vec_to_e2m1(fp2Vals); return pack_fp4(fp2Vals);
// Write the e2m1 values to global memory.
return e2m1Vec;
} }
// silu in float32 // silu in float32
......
...@@ -546,7 +546,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -546,7 +546,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// Compute NVFP4 block quantized tensor. // Compute NVFP4 block quantized tensor.
ops.def( ops.def(
"scaled_fp4_quant(Tensor! output, Tensor input," "scaled_fp4_quant(Tensor! output, Tensor input,"
" Tensor! output_scale, Tensor input_scale) -> ()"); " Tensor! output_scale, Tensor input_scale, bool "
"is_sf_swizzled_layout) -> ()");
ops.impl("scaled_fp4_quant", torch::kCUDA, &scaled_fp4_quant); ops.impl("scaled_fp4_quant", torch::kCUDA, &scaled_fp4_quant);
// Compute NVFP4 experts quantization. // Compute NVFP4 experts quantization.
......
...@@ -107,10 +107,14 @@ def test_flashinfer_nvfp4_gemm( ...@@ -107,10 +107,14 @@ def test_flashinfer_nvfp4_gemm(
# from checkpoints are in linear scales. # from checkpoints are in linear scales.
# So instead of needing to swizzle for cutlass as in modelopt.py, # So instead of needing to swizzle for cutlass as in modelopt.py,
# we need to unswizzle for trtllm here. # we need to unswizzle for trtllm here.
a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(a_dtype, a_global_scale, backend) a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(
a_dtype, a_global_scale, is_sf_swizzled_layout=True, backend=backend
)
is_sf_128x4_layout = not (backend == "trtllm" and m <= 32) is_sf_128x4_layout = not (backend == "trtllm" and m <= 32)
b_fp4, b_scale_interleaved = ops.scaled_fp4_quant(b_dtype, b_global_scale) b_fp4, b_scale_interleaved = ops.scaled_fp4_quant(
b_dtype, b_global_scale, is_sf_swizzled_layout=True
)
# get_ref_results unswizzles the scales internally. # get_ref_results unswizzles the scales internally.
expected_out = get_ref_results( expected_out = get_ref_results(
......
...@@ -27,6 +27,12 @@ PAD_SHAPES = [ ...@@ -27,6 +27,12 @@ PAD_SHAPES = [
(150, 128), (150, 128),
(150, 48), (150, 48),
(90, 80), (90, 80),
(128, 512),
(128, 1024),
(128, 2048),
(64, 7168),
(64, 7152),
(32, 14336),
] ]
SEEDS = [42] SEEDS = [42]
CUDA_DEVICES = ["cuda:0"] CUDA_DEVICES = ["cuda:0"]
...@@ -173,3 +179,25 @@ def test_quantize_to_fp4_padded(pad_shape: tuple[int, int]) -> None: ...@@ -173,3 +179,25 @@ def test_quantize_to_fp4_padded(pad_shape: tuple[int, int]) -> None:
out_ans = cast_from_fp4(out, m, n) out_ans = cast_from_fp4(out, m, n)
torch.testing.assert_close(out_ans, out_ref) torch.testing.assert_close(out_ans, out_ref)
torch.testing.assert_close(scale_ans, scale_ref) torch.testing.assert_close(scale_ans, scale_ref)
@pytest.mark.parametrize("pad_shape", PAD_SHAPES)
@torch.inference_mode()
def test_quantize_to_fp4_padded_no_sf_swizzled(pad_shape: tuple[int, int]) -> None:
dtype = torch.float16
set_random_seed(42)
torch.set_default_device("cuda:0")
m, n = pad_shape
x = torch.randn((m, n), dtype=dtype)
tensor_amax = torch.abs(x).max().to(torch.float32)
global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / tensor_amax
out_ref, scale_ref = ref_nvfp4_quant(x, global_scale)
out, out_scale = ops.scaled_fp4_quant(x, global_scale, is_sf_swizzled_layout=False)
scale_ans = out_scale.to(torch.float32)
out_ans = cast_from_fp4(out, m, n)
torch.testing.assert_close(out_ans, out_ref)
torch.testing.assert_close(scale_ans, scale_ref)
...@@ -1534,6 +1534,7 @@ def permute_cols(a: torch.Tensor, perm: torch.Tensor) -> torch.Tensor: ...@@ -1534,6 +1534,7 @@ def permute_cols(a: torch.Tensor, perm: torch.Tensor) -> torch.Tensor:
def scaled_fp4_quant( def scaled_fp4_quant(
input: torch.Tensor, input: torch.Tensor,
input_global_scale: torch.Tensor, input_global_scale: torch.Tensor,
is_sf_swizzled_layout: bool = True,
backend: str = "none", backend: str = "none",
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
""" """
...@@ -1577,22 +1578,26 @@ def scaled_fp4_quant( ...@@ -1577,22 +1578,26 @@ def scaled_fp4_quant(
else: else:
# Two fp4 values will be packed into an uint8. # Two fp4 values will be packed into an uint8.
output = torch.empty((m, n // 2), device=device, dtype=torch.uint8) output = torch.empty((m, n // 2), device=device, dtype=torch.uint8)
if is_sf_swizzled_layout:
# We use the rounded values to store the swizzled values. Due to the
# requirement of the Tensor Core, the minimum tile is 128x4 for the scales.
# So, we first pad the scales to multiples of 128 and 4. Then, the scales
# (in float8_e4m3fn) are packed into an int32 for every 4 values. More:
# https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-scale-factor-b-layout-4x
round_up = lambda x, y: (x + y - 1) // y * y
rounded_m = round_up(m, 128)
scale_n = n // block_size
rounded_n = round_up(scale_n, 4)
output_scale = torch.empty(
(rounded_m, rounded_n // 4), device=device, dtype=torch.int32
)
else:
output_scale = torch.empty((m, n // 16), device=device, dtype=torch.uint8)
# We use the rounded values to store the swizzled values. Due to the torch.ops._C.scaled_fp4_quant(
# requirement of the Tensor Core, the minimum tile is 128x4 for the scales. output, input, output_scale, input_global_scale, is_sf_swizzled_layout
# So, we first pad the scales to multiples of 128 and 4. Then, the scales
# (in float8_e4m3fn) are packed into an int32 for every 4 values. More:
# https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-scale-factor-b-layout-4x
round_up = lambda x, y: (x + y - 1) // y * y
rounded_m = round_up(m, 128)
scale_n = n // block_size
rounded_n = round_up(scale_n, 4)
output_scale = torch.empty(
(rounded_m, rounded_n // 4), device=device, dtype=torch.int32
) )
torch.ops._C.scaled_fp4_quant(output, input, output_scale, input_global_scale)
output_scale = output_scale.view(torch.float8_e4m3fn) output_scale = output_scale.view(torch.float8_e4m3fn)
return output, output_scale return output, output_scale
......
...@@ -152,6 +152,7 @@ class SiluMulNvfp4QuantPattern(ActivationQuantPattern): ...@@ -152,6 +152,7 @@ class SiluMulNvfp4QuantPattern(ActivationQuantPattern):
input=result_silu_mul, input=result_silu_mul,
output_scale=output_scale, output_scale=output_scale,
input_scale=scale, input_scale=scale,
is_sf_swizzled_layout=True,
) )
return at[1], at[2] return at[1], at[2]
......
...@@ -946,6 +946,7 @@ class AllReduceFusedRMSNormStaticQuantNVFP4Pattern(BasePattern): ...@@ -946,6 +946,7 @@ class AllReduceFusedRMSNormStaticQuantNVFP4Pattern(BasePattern):
input=rms, input=rms,
output_scale=output_scale, output_scale=output_scale,
input_scale=input_global_scale, input_scale=input_global_scale,
is_sf_swizzled_layout=True,
) )
# quant_out, allreduce_output, output_scale # quant_out, allreduce_output, output_scale
...@@ -1043,6 +1044,7 @@ class AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(BasePattern): ...@@ -1043,6 +1044,7 @@ class AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(BasePattern):
input=rms, input=rms,
output_scale=output_scale, output_scale=output_scale,
input_scale=input_global_scale, input_scale=input_global_scale,
is_sf_swizzled_layout=True,
) )
# quant_out, allreduce_output, output_scale # quant_out, allreduce_output, output_scale
......
...@@ -248,6 +248,7 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern): ...@@ -248,6 +248,7 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern):
input=attn_out_view, input=attn_out_view,
output_scale=output_scale, output_scale=output_scale,
input_scale=input_scale, input_scale=input_scale,
is_sf_swizzled_layout=True,
) )
output_scale_view = torch.ops.aten.view.dtype(at2[2], FP8_DTYPE) output_scale_view = torch.ops.aten.view.dtype(at2[2], FP8_DTYPE)
return at2[1], output_scale_view return at2[1], output_scale_view
......
...@@ -24,7 +24,6 @@ from vllm.model_executor.layers.quantization.utils.mxfp8_utils import ( ...@@ -24,7 +24,6 @@ from vllm.model_executor.layers.quantization.utils.mxfp8_utils import (
mxfp8_e4m3_quantize, mxfp8_e4m3_quantize,
) )
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
from vllm.utils.flashinfer import flashinfer_fp4_quantize
from vllm.utils.math_utils import cdiv from vllm.utils.math_utils import cdiv
from vllm.utils.torch_utils import is_torch_equal_or_newer from vllm.utils.torch_utils import is_torch_equal_or_newer
...@@ -117,9 +116,7 @@ def _nvfp4_quantize( ...@@ -117,9 +116,7 @@ def _nvfp4_quantize(
A_scale: torch.Tensor | None, A_scale: torch.Tensor | None,
is_sf_swizzled_layout: bool, is_sf_swizzled_layout: bool,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
return flashinfer_fp4_quantize( return ops.scaled_fp4_quant(A, A_scale, is_sf_swizzled_layout=is_sf_swizzled_layout)
A, A_scale, is_sf_swizzled_layout=is_sf_swizzled_layout
)
def _fp8_quantize( def _fp8_quantize(
......
...@@ -191,7 +191,10 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme): ...@@ -191,7 +191,10 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
# quantize BF16 or FP16 to (FP4 and interleaved block scale) # quantize BF16 or FP16 to (FP4 and interleaved block scale)
x_fp4, x_blockscale = scaled_fp4_quant( x_fp4, x_blockscale = scaled_fp4_quant(
x, layer.input_global_scale, self.backend x,
layer.input_global_scale,
is_sf_swizzled_layout=True,
backend=self.backend,
) )
mm_args = ( mm_args = (
......
...@@ -1307,7 +1307,9 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase): ...@@ -1307,7 +1307,9 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
output_shape = [x.shape[0], layer.weight.shape[0]] output_shape = [x.shape[0], layer.weight.shape[0]]
# quantize BF16 or FP16 to (FP4 and interleaved block scale) # quantize BF16 or FP16 to (FP4 and interleaved block scale)
x_fp4, x_blockscale = scaled_fp4_quant(x, layer.input_scale_inv, self.backend) x_fp4, x_blockscale = scaled_fp4_quant(
x, layer.input_scale_inv, is_sf_swizzled_layout=True, backend=self.backend
)
# validate dtypes of quantized input, input block scale, # validate dtypes of quantized input, input block scale,
# weight and weight_blockscale # weight and weight_blockscale
......
...@@ -8,6 +8,7 @@ import torch ...@@ -8,6 +8,7 @@ import torch
import vllm.envs as envs import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig, FusedMoEConfig,
...@@ -341,10 +342,8 @@ def flashinfer_trtllm_fp4_moe( ...@@ -341,10 +342,8 @@ def flashinfer_trtllm_fp4_moe(
hidden_states_fp4, hidden_states_scale_linear_fp4 = x hidden_states_fp4, hidden_states_scale_linear_fp4 = x
else: else:
# hidden_states is the already quantized # hidden_states is the already quantized
(hidden_states_fp4, hidden_states_scale_linear_fp4) = flashinfer.fp4_quantize( (hidden_states_fp4, hidden_states_scale_linear_fp4) = ops.scaled_fp4_quant(
x, x, layer.a1_gscale, is_sf_swizzled_layout=False
layer.a1_gscale,
is_sf_swizzled_layout=False,
) )
# Determine routing method type # Determine routing method type
...@@ -443,10 +442,8 @@ def flashinfer_trtllm_fp4_routed_moe( ...@@ -443,10 +442,8 @@ def flashinfer_trtllm_fp4_routed_moe(
hidden_states_fp4, hidden_states_scale_linear_fp4 = x hidden_states_fp4, hidden_states_scale_linear_fp4 = x
else: else:
# Quantize input to FP4 # Quantize input to FP4
(hidden_states_fp4, hidden_states_scale_linear_fp4) = flashinfer.fp4_quantize( (hidden_states_fp4, hidden_states_scale_linear_fp4) = ops.scaled_fp4_quant(
x, x, layer.a1_gscale, is_sf_swizzled_layout=False
layer.a1_gscale,
is_sf_swizzled_layout=False,
) )
# Call TRT-LLM FP4 block-scale MoE kernel # Call TRT-LLM FP4 block-scale MoE kernel
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment