Unverified Commit 0a0aa077 authored by Lucas Wilkinson's avatar Lucas Wilkinson Committed by GitHub
Browse files

[Quant] Make static quant support all group shapes (#30833)


Signed-off-by: default avatarLucas Wilkinson <lwilkins@redhat.com>
parent f9e2a75a
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
#include <optional> #include <optional>
#include <torch/library.h> #include <torch/library.h>
#include <tuple>
#include "core/scalar_type.hpp" #include "core/scalar_type.hpp"
...@@ -346,8 +347,9 @@ torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight, ...@@ -346,8 +347,9 @@ torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit); void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit);
void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input, void static_scaled_fp8_quant(
torch::Tensor const& scale); torch::Tensor& out, torch::Tensor const& input, torch::Tensor const& scale,
std::optional<std::tuple<int64_t, int64_t>> group_shape = std::nullopt);
void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input, void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
torch::Tensor& scale); torch::Tensor& scale);
......
...@@ -4,28 +4,77 @@ ...@@ -4,28 +4,77 @@
#include "quantization/vectorization_utils.cuh" #include "quantization/vectorization_utils.cuh"
#include <c10/cuda/CUDAGuard.h> #include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/Exceptions.h> #include <ATen/cuda/Exceptions.h>
#include <tuple>
namespace vllm { namespace vllm {
template <typename scalar_t, typename fp8_type> // STRIDE_I_ZERO: true if scale_stride_i == 0 (per-tensor or per-channel)
__global__ void scaled_fp8_quant_kernel_strided( // STRIDE_J_ZERO: true if scale_stride_j == 0 (per-tensor or per-token)
template <typename scalar_t, typename fp8_type, bool STRIDE_I_ZERO,
bool STRIDE_J_ZERO>
__global__ void scaled_fp8_quant_kernel_strided_group_shape(
fp8_type* __restrict__ out, const scalar_t* __restrict__ input, fp8_type* __restrict__ out, const scalar_t* __restrict__ input,
const float* __restrict__ scale, int hidden_size, int64_t in_row_stride, const float* __restrict__ scale, int hidden_size, int64_t in_row_stride,
int64_t out_row_stride) { int64_t out_row_stride, int group_m, int group_n, int64_t scale_stride_i,
const int64_t token_idx = blockIdx.x; // one token per block int64_t scale_stride_j) {
const int64_t token_idx = blockIdx.x;
const int tid = threadIdx.x; const int tid = threadIdx.x;
const scalar_t* token_in = input + token_idx * in_row_stride; const scalar_t* token_in = input + token_idx * in_row_stride;
fp8_type* token_out = out + token_idx * out_row_stride; fp8_type* token_out = out + token_idx * out_row_stride;
const float inv_scale = 1.0f / (*scale); // Precompute row-level base offset for scale access (compile-time eliminated
// when STRIDE_I_ZERO)
vectorize_with_alignment<16>( const int64_t scale_row_base =
token_in, token_out, hidden_size, tid, blockDim.x, STRIDE_I_ZERO ? 0
: static_cast<int>(token_idx) / group_m * scale_stride_i;
auto get_inv_scale = [&](int gj) {
return 1.0f / scale[scale_row_base + gj * scale_stride_j];
};
int cached_gj = -1;
float cached_inv_scale = 0.0f;
auto get_inv_scale_cached = [&](int gj) {
if (gj != cached_gj) {
cached_inv_scale = 1.0f / scale[scale_row_base + gj * scale_stride_j];
cached_gj = gj;
}
return cached_inv_scale;
};
constexpr int VEC_SIZE = 16; // FP8 so vectorize to 128 bits
auto scaled_fp8_conversion_vectorized = [&](const scalar_t* in, fp8_type* out,
int size, float inv_scale) {
vectorize_with_alignment<VEC_SIZE>(
in, out, size, tid, blockDim.x,
[=] __device__(fp8_type & dst, const scalar_t& src) { [=] __device__(fp8_type & dst, const scalar_t& src) {
dst = scaled_fp8_conversion<true, fp8_type>(static_cast<float>(src), dst = scaled_fp8_conversion<true, fp8_type>(static_cast<float>(src),
inv_scale); inv_scale);
}); });
};
if (STRIDE_J_ZERO && hidden_size % VEC_SIZE == 0) {
// Per-tensor or per-token: single scale per row, vectorize full row
scaled_fp8_conversion_vectorized(token_in, token_out, hidden_size,
get_inv_scale(0));
} else if (group_n % VEC_SIZE == 0) {
// Multiple column groups with vectorization
const int num_groups_n = hidden_size / group_n;
for (int gj = 0; gj < num_groups_n; gj++) {
scaled_fp8_conversion_vectorized(token_in + gj * group_n,
token_out + gj * group_n, group_n,
get_inv_scale(gj));
}
} else {
// Scalar path for small column groups (group_n < VEC_SIZE)
for (int n = tid; n < hidden_size; n += blockDim.x) {
const int gj = n / group_n;
token_out[n] = scaled_fp8_conversion<true, fp8_type>(
static_cast<float>(token_in[n]), get_inv_scale_cached(gj));
}
}
} }
template <typename scalar_t, typename fp8_type> template <typename scalar_t, typename fp8_type>
...@@ -133,17 +182,116 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel_strided( ...@@ -133,17 +182,116 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel_strided(
} // namespace vllm } // namespace vllm
void static_scaled_fp8_quant(torch::Tensor& out, // [..., d] void static_scaled_fp8_quant(
torch::Tensor& out, // [..., d]
torch::Tensor const& input, // [..., d] torch::Tensor const& input, // [..., d]
torch::Tensor const& scale) // [1] torch::Tensor const& scale, // various shapes
std::optional<std::tuple<int64_t, int64_t>>
opt_group_shape) // optional explicit (group_m, group_n)
{ {
TORCH_CHECK(input.stride(-1) == 1, TORCH_CHECK(input.stride(-1) == 1,
"last dimension of input must be contiguous"); "last dimension of input must be contiguous");
TORCH_CHECK(out.stride(-1) == 1, TORCH_CHECK(out.stride(-1) == 1,
"last dimension of output must be contiguous"); "last dimension of output must be contiguous");
const int hidden_size = input.size(-1); const int hidden_size = input.size(-1); // N (columns)
const int num_tokens = input.numel() / hidden_size; const int num_tokens = input.numel() / hidden_size; // M (rows)
// Determine group_m, group_n, and scale strides from scale shape
// Scale indexing: scale[gi * scale_stride_j + gj * scale_stride_i]
// where gi = m / group_m, gj = n / group_n
int group_m, group_n;
int64_t scale_stride_i, scale_stride_j;
if (scale.dim() == 0 || scale.numel() == 1) {
// Per-tensor: one scale for the entire tensor
group_m = num_tokens;
group_n = hidden_size;
scale_stride_i = 0;
scale_stride_j = 0;
} else if (scale.dim() == 1) {
// 1D scale: require explicit group_shape to disambiguate per-channel vs
// per-token (avoids edge case where num_tokens == hidden_size)
TORCH_CHECK(opt_group_shape.has_value(),
"1D scale requires explicit group_shape to disambiguate "
"per-channel vs per-token quantization. "
"Use group_shape=(-1, 1) for per-channel or group_shape=(1, "
"-1) for per-token.");
const auto& [opt_group_m, opt_group_n] = opt_group_shape.value();
group_m = opt_group_m == -1 ? num_tokens : static_cast<int>(opt_group_m);
group_n = opt_group_n == -1 ? hidden_size : static_cast<int>(opt_group_n);
// Validate the explicit group shape matches the 1D scale
const int64_t scale_len = scale.numel();
const int64_t expected_scale_m = num_tokens / group_m;
const int64_t expected_scale_n = hidden_size / group_n;
const int64_t expected_scale_numel = expected_scale_m * expected_scale_n;
TORCH_CHECK(scale_len == expected_scale_numel, "1D scale length (",
scale_len, ") does not match expected size (",
expected_scale_numel, ") for group_shape (", opt_group_m, ", ",
opt_group_n, ") with input shape (", num_tokens, ", ",
hidden_size, ")");
// For 1D scale, determine strides based on which dim is trivial
// Scale indexing: scale[gi * scale_stride_i + gj * scale_stride_j]
// where gi = m / group_m (row group), gj = n / group_n (col group)
if (expected_scale_m == 1) {
// Per-channel style: one scale in M dim, scale varies along N
// gi = 0 always, gj varies, so stride_1 traverses the scale
scale_stride_i = 0;
scale_stride_j = scale.stride(0);
} else if (expected_scale_n == 1) {
// Per-token style: one scale in N dim, scale varies along M
// gj = 0 always, gi varies, so stride_0 traverses the scale
scale_stride_i = scale.stride(0);
scale_stride_j = 0;
} else {
TORCH_CHECK(
false,
"1D scale can only be used when one of the scale dimensions is 1. "
"For 2D group scaling, use a 2D scale tensor.");
}
} else if (scale.dim() == 2) {
// 2D scale: infer group sizes from scale dimensions (or use explicit if
// provided)
const int64_t scale_size_0 = scale.size(0);
const int64_t scale_size_1 = scale.size(1);
TORCH_CHECK(num_tokens % scale_size_0 == 0, "num_tokens (", num_tokens,
") must be divisible by scale.size(0) (", scale_size_0, ")");
TORCH_CHECK(hidden_size % scale_size_1 == 0, "hidden_size (", hidden_size,
") must be divisible by scale.size(1) (", scale_size_1, ")");
// Infer from 2D scale shape
int inferred_group_m = num_tokens / scale_size_0;
int inferred_group_n = hidden_size / scale_size_1;
// Use explicit if provided, otherwise use inferred
if (opt_group_shape.has_value()) {
const auto& [opt_group_m, opt_group_n] = opt_group_shape.value();
group_m = opt_group_m == -1 ? num_tokens : static_cast<int>(opt_group_m);
group_n = opt_group_n == -1 ? hidden_size : static_cast<int>(opt_group_n);
// Validate explicit matches inferred
TORCH_CHECK(group_m == inferred_group_m && group_n == inferred_group_n,
"Explicit group_shape (", opt_group_m, ", ", opt_group_n,
") does not match inferred group shape (", inferred_group_m,
", ", inferred_group_n, ") from 2D scale tensor shape (",
scale_size_0, ", ", scale_size_1, ")");
} else {
group_m = inferred_group_m;
group_n = inferred_group_n;
}
scale_stride_i = scale.stride(0);
scale_stride_j = scale.stride(1);
} else {
TORCH_CHECK(false, "scale must be 0D, 1D, or 2D tensor, but got ",
scale.dim(), "D");
}
const int block_size = 256; const int block_size = 256;
dim3 grid(num_tokens); dim3 grid(num_tokens);
dim3 block(block_size); dim3 block(block_size);
...@@ -153,15 +301,23 @@ void static_scaled_fp8_quant(torch::Tensor& out, // [..., d] ...@@ -153,15 +301,23 @@ void static_scaled_fp8_quant(torch::Tensor& out, // [..., d]
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
// Dispatch to template-specialized kernel based on stride pattern
VLLM_DISPATCH_FLOATING_TYPES( VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "scaled_fp8_quant_kernel_scalar_type", [&] { input.scalar_type(), "scaled_fp8_quant_kernel_scalar_type", [&] {
VLLM_DISPATCH_FP8_TYPES( VLLM_DISPATCH_FP8_TYPES(
out.scalar_type(), "scaled_fp8_quant_kernel_fp8_type", [&] { out.scalar_type(), "scaled_fp8_quant_kernel_fp8_type", [&] {
vllm::scaled_fp8_quant_kernel_strided<scalar_t, fp8_t> VLLM_DISPATCH_BOOL(scale_stride_i == 0, S0_ZERO, [&] {
VLLM_DISPATCH_BOOL(scale_stride_j == 0, S1_ZERO, [&] {
vllm::scaled_fp8_quant_kernel_strided_group_shape<
scalar_t, fp8_t, S0_ZERO, S1_ZERO>
<<<grid, block, 0, stream>>>( <<<grid, block, 0, stream>>>(
out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(), out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),
scale.data_ptr<float>(), hidden_size, in_row_stride, scale.data_ptr<float>(), hidden_size, in_row_stride,
out_row_stride); out_row_stride, group_m, group_n, scale_stride_i,
scale_stride_j);
});
});
}); });
}); });
} }
......
...@@ -599,9 +599,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -599,9 +599,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.impl("gptq_shuffle", torch::kCUDA, &gptq_shuffle); ops.impl("gptq_shuffle", torch::kCUDA, &gptq_shuffle);
// Compute FP8 quantized tensor for given scaling factor. // Compute FP8 quantized tensor for given scaling factor.
// Supports per-tensor, per-channel, per-token, and arbitrary 2D group
// scaling. Optional group_m/group_n specify the group shape explicitly;
// required for 1D scales to disambiguate per-channel vs per-token.
ops.def( ops.def(
"static_scaled_fp8_quant(Tensor! result, Tensor input, Tensor scale) -> " "static_scaled_fp8_quant(Tensor! result, Tensor input, Tensor scale, "
"()"); "(int, int)? group_shape=None) -> ()");
ops.impl("static_scaled_fp8_quant", torch::kCUDA, &static_scaled_fp8_quant); ops.impl("static_scaled_fp8_quant", torch::kCUDA, &static_scaled_fp8_quant);
// Compute dynamic-per-tensor FP8 quantized tensor and scaling factor. // Compute dynamic-per-tensor FP8 quantized tensor and scaling factor.
......
...@@ -11,6 +11,10 @@ from tests.kernels.quant_utils import ( ...@@ -11,6 +11,10 @@ from tests.kernels.quant_utils import (
ref_dynamic_per_token_quant, ref_dynamic_per_token_quant,
) )
from tests.kernels.utils import opcheck from tests.kernels.utils import opcheck
from vllm.model_executor.layers.quantization.utils.quant_utils import (
scaled_quantize,
)
from vllm.platforms import current_platform
from vllm.utils.torch_utils import set_random_seed from vllm.utils.torch_utils import set_random_seed
DTYPES = [torch.bfloat16, torch.float] DTYPES = [torch.bfloat16, torch.float]
...@@ -21,10 +25,18 @@ SEEDS = [0] ...@@ -21,10 +25,18 @@ SEEDS = [0]
def opcheck_fp8_quant( def opcheck_fp8_quant(
output, input, scale=None, scale_ub=None, use_per_token_if_dynamic=False output,
input,
scale=None,
scale_ub=None,
use_per_token_if_dynamic=False,
group_shape=None,
): ):
if scale is not None: if scale is not None:
opcheck(torch.ops._C.static_scaled_fp8_quant, (output, input, scale)) opcheck(
torch.ops._C.static_scaled_fp8_quant,
(output, input, scale, group_shape),
)
elif use_per_token_if_dynamic: elif use_per_token_if_dynamic:
scale = torch.empty( scale = torch.empty(
(input.shape[0], 1), device=input.device, dtype=torch.float32 (input.shape[0], 1), device=input.device, dtype=torch.float32
...@@ -118,3 +130,92 @@ def test_fp8_quant_large(seed: int) -> None: ...@@ -118,3 +130,92 @@ def test_fp8_quant_large(seed: int) -> None:
ops_out = ops_out.to(dtype=dtype) ops_out = ops_out.to(dtype=dtype)
torch.testing.assert_close(ref_out, ops_out) torch.testing.assert_close(ref_out, ops_out)
# Test static FP8 quantization with 2D group scales
GROUP_SHAPES_2D = [
(-1, -1), # Per-tensor
(-1, 1), # Per-channel
(1, -1), # Per-token
(-1, 128), # Per-head quantization
(1, 128), # DeepSeek-style per-token-per-group (group_m=1, group_n=128)
(128, 128), # DeepSeek-style block quantization
(1, 64), # Smaller group size
(1, 16), # Small group (scalar path in kernel)
(4, 256), # Non-trivial both dimensions
]
# Use sizes divisible by all group shapes
NUM_TOKENS_GROUP = [128, 512]
HIDDEN_SIZES_GROUP = [256, 1024, 2048]
@pytest.mark.parametrize("num_tokens", NUM_TOKENS_GROUP)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES_GROUP)
@pytest.mark.parametrize("group_shape", GROUP_SHAPES_2D)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def test_static_fp8_quant_group_2d(
num_tokens: int,
hidden_size: int,
group_shape: tuple[int, int],
dtype: torch.dtype,
seed: int,
) -> None:
"""Test static FP8 quantization with 2D group scales using scaled_quantize."""
# Normalize group_shape (-1 means full extent)
norm_group_m = num_tokens if group_shape[0] == -1 else group_shape[0]
norm_group_n = hidden_size if group_shape[1] == -1 else group_shape[1]
# Skip if sizes are not divisible by group shape
if num_tokens % norm_group_m != 0 or hidden_size % norm_group_n != 0:
pytest.skip(
f"Skipping: ({num_tokens}, {hidden_size}) not divisible by "
f"group_shape ({group_shape[0]}, {group_shape[1]})"
)
current_platform.seed_everything(seed)
x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda")
ref_out, scale = scaled_quantize(
x, group_shape, FP8_DTYPE, compute_dtype=torch.float32
)
ops_out, ops_scale = ops.scaled_fp8_quant(x, scale=scale, group_shape=group_shape)
torch.testing.assert_close(scale, ops_scale)
torch.testing.assert_close(ref_out.float(), ops_out.float(), rtol=0.12, atol=0.0)
opcheck_fp8_quant(ops_out, x, scale=scale)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS_GROUP)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES_GROUP)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("group_shape", [(1, -1), (-1, 1)]) # per-token, per-channel
@torch.inference_mode()
def test_static_fp8_quant_1d_scale(
num_tokens: int,
hidden_size: int,
dtype: torch.dtype,
seed: int,
group_shape: tuple[int, int],
) -> None:
"""Test static FP8 quantization with 1D scale (per-token or per-channel)."""
current_platform.seed_everything(seed)
x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda")
ref_out, scale_2d = scaled_quantize(
x, group_shape, FP8_DTYPE, compute_dtype=torch.float32
)
# Flatten scale to 1D for testing 1D scale path
scale_1d = scale_2d.flatten()
ops_out, ops_scale = ops.scaled_fp8_quant(
x, scale=scale_1d, group_shape=group_shape
)
torch.testing.assert_close(scale_1d, ops_scale)
torch.testing.assert_close(ref_out.float(), ops_out.float(), rtol=0.12, atol=0.0)
opcheck_fp8_quant(ops_out, x, scale=scale_1d, group_shape=group_shape)
...@@ -1752,6 +1752,7 @@ def scaled_fp8_quant( ...@@ -1752,6 +1752,7 @@ def scaled_fp8_quant(
scale_ub: torch.Tensor | None = None, scale_ub: torch.Tensor | None = None,
use_per_token_if_dynamic: bool = False, use_per_token_if_dynamic: bool = False,
output: torch.Tensor | None = None, output: torch.Tensor | None = None,
group_shape: tuple[int, int] | None = None,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
""" """
Quantize input tensor to FP8 and return quantized tensor and scale. Quantize input tensor to FP8 and return quantized tensor and scale.
...@@ -1763,14 +1764,23 @@ def scaled_fp8_quant( ...@@ -1763,14 +1764,23 @@ def scaled_fp8_quant(
will benefit from padding. will benefit from padding.
Args: Args:
input: The input tensor to be quantized to FP8 input: The input tensor to be quantized to FP8 (must be 2D: [M, N])
scale: Optional scaling factor for the FP8 quantization scale: Optional scaling factor for the FP8 quantization. Supports:
- 0D or [1]: per-tensor scaling
- 1D: requires explicit group_shape to disambiguate per-channel
vs per-token (use (-1, 1) for per-channel, (1, -1) for per-token)
- 2D [M/group_m, N/group_n]: group scaling (e.g. [M, N/128] for
DeepSeek-style (1,128) groups, or [M/128, N/128] for (128,128))
scale_ub: Optional upper bound for scaling factor in dynamic scale_ub: Optional upper bound for scaling factor in dynamic
per token case per token case
num_token_padding: If specified, pad the first dimension num_token_padding: If specified, pad the first dimension
of the output to at least this value. of the output to at least this value.
use_per_token_if_dynamic: Whether to do per_tensor or per_token use_per_token_if_dynamic: Whether to do per_tensor or per_token
in the dynamic quantization case. in the dynamic quantization case.
group_shape: Optional tuple (group_m, group_n) specifying the group
shape for static quantization. Use -1 for "full extent" (e.g.,
(-1, -1) for per-tensor, (-1, 1) for per-channel, etc.)
Required for 1D scales; optional for 2D scales.
Returns: Returns:
tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
...@@ -1799,8 +1809,7 @@ def scaled_fp8_quant( ...@@ -1799,8 +1809,7 @@ def scaled_fp8_quant(
scale = torch.empty(1, device=input.device, dtype=torch.float32) scale = torch.empty(1, device=input.device, dtype=torch.float32)
torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale) torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
else: else:
assert scale.numel() == 1, f"{scale.shape}" torch.ops._C.static_scaled_fp8_quant(output, input, scale, group_shape)
torch.ops._C.static_scaled_fp8_quant(output, input, scale)
return output, scale return output, scale
......
...@@ -10,6 +10,7 @@ from vllm.model_executor.custom_op import CustomOp ...@@ -10,6 +10,7 @@ from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape, GroupShape,
get_fp8_min_max, get_fp8_min_max,
group_broadcast,
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
...@@ -22,7 +23,7 @@ _FP8_MIN_SCALING_FACTOR = 1.0 / (_FP8_MAX * 512.0) ...@@ -22,7 +23,7 @@ _FP8_MIN_SCALING_FACTOR = 1.0 / (_FP8_MAX * 512.0)
@CustomOp.register("quant_fp8") @CustomOp.register("quant_fp8")
class QuantFP8(CustomOp): class QuantFP8(CustomOp):
""" """
Quantize input tensor to FP8 (per-tensor, per-token, or per-group). Quantize input tensor to FP8 (per-tensor, per-token, per-channel, or per-group).
This CustomOp supports both static and dynamic quantization. This CustomOp supports both static and dynamic quantization.
""" """
...@@ -57,14 +58,14 @@ class QuantFP8(CustomOp): ...@@ -57,14 +58,14 @@ class QuantFP8(CustomOp):
self.is_group_quant = group_shape.is_per_group() self.is_group_quant = group_shape.is_per_group()
if self.is_group_quant: if self.is_group_quant:
assert not static, "Group quantization only supports dynamic mode"
self.group_size = group_shape.col self.group_size = group_shape.col
else: else:
assert group_shape in {GroupShape.PER_TOKEN, GroupShape.PER_TENSOR}
assert not static or group_shape == GroupShape.PER_TENSOR, (
"Only per-tensor scales supported for static quantization."
)
self.use_per_token_if_dynamic = group_shape == GroupShape.PER_TOKEN self.use_per_token_if_dynamic = group_shape == GroupShape.PER_TOKEN
if not static:
assert group_shape in (GroupShape.PER_TOKEN, GroupShape.PER_TENSOR), (
"Only per-token or per-tensor scales are supported for dynamic "
"non-group quantization."
)
def forward_cuda( def forward_cuda(
self, self,
...@@ -72,8 +73,8 @@ class QuantFP8(CustomOp): ...@@ -72,8 +73,8 @@ class QuantFP8(CustomOp):
scale: torch.Tensor | None = None, scale: torch.Tensor | None = None,
scale_ub: torch.Tensor | None = None, scale_ub: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
if self.is_group_quant: if self.is_group_quant and not self.static:
assert scale is None, "Group quantization is always dynamic" assert scale is None, "Dynamic group quantization does not use scale"
from vllm.model_executor.layers.quantization.utils import fp8_utils from vllm.model_executor.layers.quantization.utils import fp8_utils
return fp8_utils.per_token_group_quant_fp8( return fp8_utils.per_token_group_quant_fp8(
...@@ -90,12 +91,14 @@ class QuantFP8(CustomOp): ...@@ -90,12 +91,14 @@ class QuantFP8(CustomOp):
and self.group_shape == GroupShape.PER_TOKEN and self.group_shape == GroupShape.PER_TOKEN
and scale_ub.numel() == 1 and scale_ub.numel() == 1
) )
return ops.scaled_fp8_quant( return ops.scaled_fp8_quant(
x, x,
scale, scale,
num_token_padding=self.num_token_padding, num_token_padding=self.num_token_padding,
scale_ub=scale_ub, scale_ub=scale_ub,
use_per_token_if_dynamic=self.use_per_token_if_dynamic, use_per_token_if_dynamic=self.use_per_token_if_dynamic,
group_shape=self.group_shape if self.static else None,
) )
def forward_hip( def forward_hip(
...@@ -131,8 +134,8 @@ class QuantFP8(CustomOp): ...@@ -131,8 +134,8 @@ class QuantFP8(CustomOp):
scale: torch.Tensor | None = None, scale: torch.Tensor | None = None,
scale_ub: torch.Tensor | None = None, scale_ub: torch.Tensor | None = None,
): ):
if self.is_group_quant: if self.is_group_quant and not self.static:
assert scale is None, "Group quantization is always dynamic" assert scale is None, "Dynamic group quantization does not use scale"
return self._quantize_group_native(x) return self._quantize_group_native(x)
assert (scale is not None) == self.static assert (scale is not None) == self.static
...@@ -155,7 +158,10 @@ class QuantFP8(CustomOp): ...@@ -155,7 +158,10 @@ class QuantFP8(CustomOp):
# Even for dynamic per-token scales, # Even for dynamic per-token scales,
# reciprocal performs slightly better than division # reciprocal performs slightly better than division
out = x.to(torch.float32) * scale.reciprocal() out = (
x.to(torch.float32)
* group_broadcast(scale.to(torch.float32), x.shape[-2:]).reciprocal()
)
out = out.clamp(_FP8_MIN, _FP8_MAX).to(_FP8_DTYPE) out = out.clamp(_FP8_MIN, _FP8_MAX).to(_FP8_DTYPE)
# This currently generates an extra Triton kernel in compilation. # This currently generates an extra Triton kernel in compilation.
......
...@@ -158,11 +158,14 @@ def _normalize_quant_group_shape(x: torch.Tensor, group_shape: GroupShape): ...@@ -158,11 +158,14 @@ def _normalize_quant_group_shape(x: torch.Tensor, group_shape: GroupShape):
# with an extent of 1, since this can be done implicitly by pytorch # with an extent of 1, since this can be done implicitly by pytorch
def group_broadcast(t, shape): def group_broadcast(t, shape):
for i, s in enumerate(shape): for i, s in enumerate(shape):
if t.shape[i] != s and t.shape[i] != 1: # If tensor has fewer dimensions than target shape, treat missing
assert s % t.shape[i] == 0 # dimensions as size 1 (standard PyTorch broadcasting behavior)
t_dim_size = t.shape[i] if i < t.ndim else 1
if t_dim_size != s and t_dim_size != 1:
assert s % t_dim_size == 0
t = ( t = (
t.unsqueeze(i + 1) t.unsqueeze(i + 1)
.expand(*t.shape[: i + 1], s // t.shape[i], *t.shape[i + 1 :]) .expand(*t.shape[: i + 1], s // t_dim_size, *t.shape[i + 1 :])
.flatten(i, i + 1) .flatten(i, i + 1)
) )
return t return t
...@@ -180,7 +183,16 @@ def scaled_quantize( ...@@ -180,7 +183,16 @@ def scaled_quantize(
x: torch.Tensor, x: torch.Tensor,
group_shape: GroupShape, group_shape: GroupShape,
quant_dtype: torch.dtype, quant_dtype: torch.dtype,
compute_dtype: torch.dtype | None = None,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
"""
Args:
x: Input tensor to quantize
group_shape: Shape of quantization groups
quant_dtype: Target quantized dtype (e.g., torch.float8_e4m3fn)
compute_dtype: Optional dtype for intermediate computations.
If None, uses input dtype. Use torch.float32 for higher precision.
"""
group_shape = _normalize_quant_group_shape(x, group_shape) group_shape = _normalize_quant_group_shape(x, group_shape)
assert quant_dtype.is_floating_point, ( assert quant_dtype.is_floating_point, (
"currently `scaled_quantize` only supports floating point dtypes " "currently `scaled_quantize` only supports floating point dtypes "
...@@ -189,11 +201,14 @@ def scaled_quantize( ...@@ -189,11 +201,14 @@ def scaled_quantize(
finfo = torch.finfo(quant_dtype) finfo = torch.finfo(quant_dtype)
# Convert to compute dtype if specified
x_compute = x if compute_dtype is None else x.to(compute_dtype)
# Reshape (M, N) into (BLK_M, BLOCK_SIZE_M, BLK_N, BLOCK_SIZE_N) # Reshape (M, N) into (BLK_M, BLOCK_SIZE_M, BLK_N, BLOCK_SIZE_N)
assert x.ndim == 2 assert x.ndim == 2
assert x.shape[0] % group_shape[0] == 0 and x.shape[1] % group_shape[1] == 0 assert x.shape[0] % group_shape[0] == 0 and x.shape[1] % group_shape[1] == 0
blk_m, blk_n = x.shape[0] // group_shape[0], x.shape[1] // group_shape[1] blk_m, blk_n = x.shape[0] // group_shape[0], x.shape[1] // group_shape[1]
x_blkd = x.reshape(blk_m, group_shape[0], blk_n, group_shape[1]) x_blkd = x_compute.reshape(blk_m, group_shape[0], blk_n, group_shape[1])
# Permute to (BLK_M, BLK_N, BLOCK_SIZE_M, BLOCK_SIZE_N) # Permute to (BLK_M, BLK_N, BLOCK_SIZE_M, BLOCK_SIZE_N)
x_blkd_permd = x_blkd.permute(0, 2, 1, 3) x_blkd_permd = x_blkd.permute(0, 2, 1, 3)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment