Commit c1819454 authored by zhaosong1's avatar zhaosong1
Browse files

[feature] support online fp8 quant by ptpc_fp8.

parent 0b7cc6cf
......@@ -361,12 +361,12 @@ void static_scaled_fp8_quant(
torch::Tensor& out, torch::Tensor const& input, torch::Tensor const& scale,
std::optional<std::tuple<int64_t, int64_t>> group_shape = std::nullopt);
// void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
// torch::Tensor& scale);
void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
torch::Tensor& scale);
// void dynamic_per_token_scaled_fp8_quant(
// torch::Tensor& out, torch::Tensor const& input, torch::Tensor& scale,
// std::optional<torch::Tensor> const& scale_ub);
void dynamic_per_token_scaled_fp8_quant(
torch::Tensor& out, torch::Tensor const& input, torch::Tensor& scale,
std::optional<torch::Tensor> const& scale_ub);
void selective_scan_fwd(
const torch::Tensor& u, const torch::Tensor& delta, const torch::Tensor& A,
......
......@@ -625,19 +625,19 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.impl("static_scaled_fp8_quant", torch::kCUDA, &static_scaled_fp8_quant);
// Compute dynamic-per-tensor FP8 quantized tensor and scaling factor.
// ops.def(
// "dynamic_scaled_fp8_quant(Tensor! result, Tensor input, Tensor! scale) "
// "-> "
// "()");
// ops.impl("dynamic_scaled_fp8_quant", torch::kCUDA, &dynamic_scaled_fp8_quant);
ops.def(
"dynamic_scaled_fp8_quant(Tensor! result, Tensor input, Tensor! scale) "
"-> "
"()");
ops.impl("dynamic_scaled_fp8_quant", torch::kCUDA, &dynamic_scaled_fp8_quant);
// Compute dynamic-per-token FP8 quantized tensor and scaling factor.
// ops.def(
// "dynamic_per_token_scaled_fp8_quant(Tensor! result, Tensor input, "
// "Tensor! scale, Tensor? scale_ub) -> "
// "()");
// ops.impl("dynamic_per_token_scaled_fp8_quant", torch::kCUDA,
// &dynamic_per_token_scaled_fp8_quant);
ops.def(
"dynamic_per_token_scaled_fp8_quant(Tensor! result, Tensor input, "
"Tensor! scale, Tensor? scale_ub) -> "
"()");
ops.impl("dynamic_per_token_scaled_fp8_quant", torch::kCUDA,
&dynamic_per_token_scaled_fp8_quant);
// Compute int8 quantized tensor for given scaling factor.
ops.def(
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TYPE_CHECKING, Literal, Optional
from typing import TYPE_CHECKING, Literal, Optional, Union
import torch
......@@ -1961,6 +1961,67 @@ def scaled_fp8_quant(
torch.ops._C.static_scaled_fp8_quant(output, input, scale)
return output, scale
def scaled_fp8_quant_weight(
input: torch.Tensor,
scale: Optional[torch.Tensor] = None,
num_token_padding: Optional[int] = None,
scale_ub: Optional[torch.Tensor] = None,
use_per_token_if_dynamic: bool = False,
output: Optional[torch.Tensor] = None,
group_shape: Optional[tuple[int, int]] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Quantize input tensor to FP8 and return quantized tensor and scale.
This function supports both static and dynamic quantization: If you
provide the scale, it will use static scaling and if you omit it,
the scale will be determined dynamically. The function also allows
optional padding of the output tensors for downstream kernels that
will benefit from padding.
Args:
input: The input tensor to be quantized to FP8
scale: Optional scaling factor for the FP8 quantization
scale_ub: Optional upper bound for scaling factor in dynamic
per token case
num_token_padding: If specified, pad the first dimension
of the output to at least this value.
use_per_token_if_dynamic: Whether to do per_tensor or per_token
in the dynamic quantization case.
Returns:
tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
scaling factor.
"""
# This code assumes batch_dim and num_tokens are flattened
assert (input.ndim == 2)
shape: Union[tuple[int, int], torch.Size] = input.shape
# For ROCm on MI300, the output fp8 dtype is torch.float_e3m3fnuz
out_dtype: torch.dtype = current_platform.fp8_dtype()
if num_token_padding:
shape = (max(num_token_padding, input.shape[0]), shape[1])
if output is None:
output = torch.empty(shape, device=input.device, dtype=out_dtype)
else:
assert num_token_padding is None, \
"padding not supported if output passed in"
assert output.dtype == out_dtype
if scale is None:
if use_per_token_if_dynamic:
scale = torch.empty((shape[0], 1),
device=input.device,
dtype=torch.float32)
torch.ops._C.dynamic_per_token_scaled_fp8_quant(
output, input.contiguous(), scale, scale_ub)
# output, scale = per_token_quant_fp8(input.contiguous())
else:
scale = torch.zeros(1, device=input.device, dtype=torch.float32)
torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
else:
assert scale.numel() == 1, f"{scale.shape}"
torch.ops._C.static_scaled_fp8_quant(output, input, scale)
return output, scale
def silu_and_mul_scaled_fp4_experts_quant(
input_tensor: torch.Tensor,
......@@ -2027,76 +2088,6 @@ def silu_and_mul_scaled_fp4_experts_quant(
return output, output_scales
# fp8
# def scaled_fp8_quant(
# input: torch.Tensor,
# scale: torch.Tensor | None = None,
# num_token_padding: int | None = None,
# scale_ub: torch.Tensor | None = None,
# use_per_token_if_dynamic: bool = False,
# output: torch.Tensor | None = None,
# group_shape: tuple[int, int] | None = None,
# ) -> tuple[torch.Tensor, torch.Tensor]:
# """
# Quantize input tensor to FP8 and return quantized tensor and scale.
# This function supports both static and dynamic quantization: If you
# provide the scale, it will use static scaling and if you omit it,
# the scale will be determined dynamically. The function also allows
# optional padding of the output tensors for downstream kernels that
# will benefit from padding.
# Args:
# input: The input tensor to be quantized to FP8 (must be 2D: [M, N])
# scale: Optional scaling factor for the FP8 quantization. Supports:
# - 0D or [1]: per-tensor scaling
# - 1D: requires explicit group_shape to disambiguate per-channel
# vs per-token (use (-1, 1) for per-channel, (1, -1) for per-token)
# - 2D [M/group_m, N/group_n]: group scaling (e.g. [M, N/128] for
# DeepSeek-style (1,128) groups, or [M/128, N/128] for (128,128))
# scale_ub: Optional upper bound for scaling factor in dynamic
# per token case
# num_token_padding: If specified, pad the first dimension
# of the output to at least this value.
# use_per_token_if_dynamic: Whether to do per_tensor or per_token
# in the dynamic quantization case.
# group_shape: Optional tuple (group_m, group_n) specifying the group
# shape for static quantization. Use -1 for "full extent" (e.g.,
# (-1, -1) for per-tensor, (-1, 1) for per-channel, etc.)
# Required for 1D scales; optional for 2D scales.
# Returns:
# tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
# scaling factor.
# """
# # This code assumes batch_dim and num_tokens are flattened
# assert input.ndim == 2
# shape: tuple[int, int] | torch.Size = input.shape
# # For ROCm on MI300, the output fp8 dtype is torch.float_e3m3fnuz
# out_dtype: torch.dtype = current_platform.fp8_dtype()
# if num_token_padding:
# shape = (max(num_token_padding, input.shape[0]), shape[1])
# if output is None:
# output = torch.empty(shape, device=input.device, dtype=out_dtype)
# else:
# assert num_token_padding is None, "padding not supported if output passed in"
# assert output.dtype == out_dtype
# if scale is None:
# if use_per_token_if_dynamic:
# scale = torch.empty((shape[0], 1), device=input.device, dtype=torch.float32)
# torch.ops._C.dynamic_per_token_scaled_fp8_quant(
# output, input, scale, scale_ub
# )
# else:
# scale = torch.empty(1, device=input.device, dtype=torch.float32)
# torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
# else:
# torch.ops._C.static_scaled_fp8_quant(output, input, scale, group_shape)
# return output, scale
# gptq allspark
def allspark_repack_weight(
qweight: torch.Tensor,
......
......@@ -45,7 +45,7 @@ QUANTIZATION_METHODS: list[str] = list(get_args(QuantizationMethods))
DEPRECATED_QUANTIZATION_METHODS = [
"tpu_int8",
"ptpc_fp8",
# "ptpc_fp8",
"fbgemm_fp8",
"fp_quant",
"bitblas",
......
......@@ -16,6 +16,7 @@ from vllm.model_executor.layers.quantization.fp8 import (
Fp8Config,
Fp8KVCacheMethod,
Fp8LinearMethod,
Fp8OnlineLinearMethod,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
init_fp8_linear_kernel,
......@@ -42,10 +43,10 @@ class PTPCFp8Config(Fp8Config):
if not current_platform.is_rocm():
raise ValueError("ptpc_fp8 quantization is supported only on ROCm.")
if not current_platform.has_device_capability(94):
raise ValueError(
"ptpc_fp8 quantization is supported only on AMD Instinct MI300 GPUs and newer." # noqa: E501
)
# if not current_platform.has_device_capability(94):
# raise ValueError(
# "ptpc_fp8 quantization is supported only on AMD Instinct MI300 GPUs and newer." # noqa: E501
# )
if activation_scheme == "static":
raise ValueError("ptpc_fp8 as of now only support dynamic quantization.")
......@@ -77,7 +78,7 @@ class PTPCFp8Config(Fp8Config):
return None
class PTPCFp8LinearMethod(Fp8LinearMethod):
class PTPCFp8LinearMethod(Fp8OnlineLinearMethod):
"""Linear method for Per-Token and Per-Channel FP8 Quantization.
Only supports loading quantized BF16 model checkpoints with dynamic
activation scaling. To load FP16 model checkpoints, user must specify
......@@ -114,13 +115,13 @@ class PTPCFp8LinearMethod(Fp8LinearMethod):
if layer.weight.data.dtype == torch.bfloat16:
# Quantize the weights.
qweight, weight_scale = ops.scaled_fp8_quant(
qweight, weight_scale = ops.scaled_fp8_quant_weight(
layer.weight, scale=None, use_per_token_if_dynamic=True
)
# Update the layer with the new values.
layer.weight = Parameter(
qweight.t(), requires_grad=False
qweight.contiguous(), requires_grad=False
) # Pretranspose the weight
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
else:
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
try:
from ._version import __version__, __version_tuple__
except Exception as e:
import warnings
warnings.warn(f"Failed to read commit hash:\n{e}", RuntimeWarning, stacklevel=2)
__version__ = "dev"
__version_tuple__ = (0, 0, __version__)
def _prev_minor_version_was(version_str):
"""Check whether a given version matches the previous minor version.
Return True if version_str matches the previous minor version.
For example - return True if the current version if 0.7.4 and the
supplied version_str is '0.6'.
Used for --show-hidden-metrics-for-version.
"""
# Match anything if this is a dev tree
if __version_tuple__[0:2] == (0, 0):
return True
# Note - this won't do the right thing when we release 1.0!
assert __version_tuple__[0] == 0
assert isinstance(__version_tuple__[1], int)
return version_str == f"{__version_tuple__[0]}.{__version_tuple__[1] - 1}"
def _prev_minor_version():
"""For the purpose of testing, return a previous minor version number."""
# In dev tree, this will return "0.-1", but that will work fine"
assert isinstance(__version_tuple__[1], int)
return f"{__version_tuple__[0]}.{__version_tuple__[1] - 1}"
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment