Unverified Commit 0fa14907 authored by Siyuan Liu's avatar Siyuan Liu Committed by GitHub
Browse files

[TPU] Add Load-time W8A16 quantization for TPU Backend (#7005)

parent 5923532e
...@@ -244,6 +244,7 @@ class ModelConfig: ...@@ -244,6 +244,7 @@ class ModelConfig:
"fp8", "marlin", "gptq_marlin_24", "gptq_marlin", "awq_marlin", "fp8", "marlin", "gptq_marlin_24", "gptq_marlin", "awq_marlin",
"fbgemm_fp8", "compressed_tensors", "compressed-tensors" "fbgemm_fp8", "compressed_tensors", "compressed-tensors"
] ]
tpu_supported_quantization = ["tpu_int8"]
if self.quantization is not None: if self.quantization is not None:
self.quantization = self.quantization.lower() self.quantization = self.quantization.lower()
...@@ -282,6 +283,11 @@ class ModelConfig: ...@@ -282,6 +283,11 @@ class ModelConfig:
raise ValueError( raise ValueError(
f"{self.quantization} quantization is currently not " f"{self.quantization} quantization is currently not "
f"supported in ROCm.") f"supported in ROCm.")
if is_tpu(
) and self.quantization not in tpu_supported_quantization:
raise ValueError(
f"{self.quantization} quantization is currently not "
f"supported in TPU Backend.")
if self.quantization not in optimized_quantization_methods: if self.quantization not in optimized_quantization_methods:
logger.warning( logger.warning(
"%s quantization is not fully " "%s quantization is not fully "
......
...@@ -22,11 +22,13 @@ from vllm.model_executor.layers.quantization.gptq_marlin_24 import ( ...@@ -22,11 +22,13 @@ from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
from vllm.model_executor.layers.quantization.marlin import MarlinConfig from vllm.model_executor.layers.quantization.marlin import MarlinConfig
from vllm.model_executor.layers.quantization.qqq import QQQConfig from vllm.model_executor.layers.quantization.qqq import QQQConfig
from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig
from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig
QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = { QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
"aqlm": AQLMConfig, "aqlm": AQLMConfig,
"awq": AWQConfig, "awq": AWQConfig,
"deepspeedfp": DeepSpeedFPConfig, "deepspeedfp": DeepSpeedFPConfig,
"tpu_int8": Int8TpuConfig,
"fp8": Fp8Config, "fp8": Fp8Config,
"fbgemm_fp8": FBGEMMFp8Config, "fbgemm_fp8": FBGEMMFp8Config,
# The order of gptq methods is important for config.py iteration over # The order of gptq methods is important for config.py iteration over
......
from typing import Any, Dict, List, Optional, Tuple
import torch
from torch.nn import Module
from torch.nn.parameter import Parameter
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.utils import set_weight_attrs
ACTIVATION_SCHEMES = ["none"]
class Int8TpuConfig(QuantizationConfig):
"""Int8 Quantization Config class for TPU Backend."""
def __init__(
self,
activation_scheme: str = "none",
) -> None:
if activation_scheme not in ACTIVATION_SCHEMES:
raise ValueError(
f"Unsupported activation scheme {activation_scheme}")
self.activation_scheme = activation_scheme
def get_name(self) -> str:
return "tpu_int8"
def get_supported_act_dtypes(self) -> List[torch.dtype]:
return [torch.float16, torch.bfloat16]
@classmethod
def get_min_capability(cls) -> int:
raise NotImplementedError(
"This function should not be called with TPU Backend")
@staticmethod
def get_config_filenames() -> List[str]:
return []
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "Int8TpuConfig":
activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
return cls(activation_scheme=activation_scheme)
def get_quant_method(self, layer: Module,
prefix: str) -> Optional["TPUInt8LinearMethod"]:
if isinstance(layer, LinearBase):
return TPUInt8LinearMethod(self)
return None
def get_scaled_act_names(self) -> List[str]:
return []
class TPUInt8LinearMethod(LinearMethodBase):
"""Int8 Linear method for TPU Quant. """
def __init__(self, quant_config: Int8TpuConfig):
self.quant_config = quant_config
def create_weights(self, layer: Module, input_size_per_partition: int,
output_partition_sizes: List[int], input_size: int,
output_size: int, params_dtype: torch.dtype,
**extra_weight_attrs):
weight = Parameter(torch.empty(sum(output_partition_sizes),
input_size_per_partition,
dtype=params_dtype),
requires_grad=False)
layer.register_parameter("weight", weight)
set_weight_attrs(weight, {
**extra_weight_attrs,
"input_dim": 1,
"output_dim": 0,
})
def _quantize_weight(
self, weight: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
weight_dtype = weight.dtype
weight = weight.cpu().to(torch.float32)
n_bit = 8
eps = 1e-5
max_int = 2**(n_bit - 1) - 1
min_int = -(2**(n_bit - 1))
max_val = weight.abs().amax(dim=-1, keepdim=True)
max_val = max_val.clamp(min=eps)
qscale = max_val / max_int
qweight = torch.clamp(torch.round(weight * (1.0 / qscale)), min_int,
max_int).to(torch.int8)
qscale = qscale.squeeze().to(weight_dtype)
return qweight, qscale
def process_weights_after_loading(self, layer: Module) -> None:
device = layer.weight.device
qweight, qscale = self._quantize_weight(layer.weight)
qweight = qweight.to(device)
qscale = qscale.to(device)
layer.weight = Parameter(qweight, requires_grad=False)
layer.scale = Parameter(qscale, requires_grad=False)
def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
try:
import torch_xla.experimental.xla_quantized_matmul # noqa: F401
except ImportError as err:
raise ImportError(
"Please install torch_xla by following the instructions at "
"https://docs.vllm.ai/en/latest/getting_started/tpu-installation.html " # noqa: E501
"to run vLLM on TPU.") from err
weight = layer.weight
scale = layer.scale
out = torch.ops.xla.quantized_matmul(x, weight, scale)
if bias is not None:
out = out + bias
return out
...@@ -94,14 +94,15 @@ def _get_quantization_config( ...@@ -94,14 +94,15 @@ def _get_quantization_config(
"""Get the quantization config.""" """Get the quantization config."""
if model_config.quantization is not None: if model_config.quantization is not None:
quant_config = get_quant_config(model_config, load_config) quant_config = get_quant_config(model_config, load_config)
capability = current_platform.get_device_capability() if not is_tpu():
capability = capability[0] * 10 + capability[1] capability = current_platform.get_device_capability()
if capability < quant_config.get_min_capability(): capability = capability[0] * 10 + capability[1]
raise ValueError( if capability < quant_config.get_min_capability():
f"The quantization method {model_config.quantization} is not " raise ValueError(
"supported for the current GPU. " f"The quantization method {model_config.quantization} "
f"Minimum capability: {quant_config.get_min_capability()}. " "is not supported for the current GPU. "
f"Current capability: {capability}.") f"Minimum capability: {quant_config.get_min_capability()}. "
f"Current capability: {capability}.")
supported_dtypes = quant_config.get_supported_act_dtypes() supported_dtypes = quant_config.get_supported_act_dtypes()
if model_config.dtype not in supported_dtypes: if model_config.dtype not in supported_dtypes:
raise ValueError( raise ValueError(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment