Commit 539aa992 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.6.2' into v0.6.2-dev

parents 93872128 7193774b
...@@ -120,9 +120,8 @@ class Fp8LinearMethod(LinearMethodBase): ...@@ -120,9 +120,8 @@ class Fp8LinearMethod(LinearMethodBase):
# For GPUs that lack FP8 hardware support, we can leverage the Marlin # For GPUs that lack FP8 hardware support, we can leverage the Marlin
# kernel for fast weight-only FP8 quantization # kernel for fast weight-only FP8 quantization
capability = current_platform.get_device_capability() self.use_marlin = (not current_platform.has_device_capability(89)
capability = capability[0] * 10 + capability[1] or envs.VLLM_TEST_FORCE_FP8_MARLIN)
self.use_marlin = capability < 89 or envs.VLLM_TEST_FORCE_FP8_MARLIN
# Disable marlin for rocm # Disable marlin for rocm
if is_hip(): if is_hip():
self.use_marlin = False self.use_marlin = False
......
...@@ -55,7 +55,10 @@ class GGUFConfig(QuantizationConfig): ...@@ -55,7 +55,10 @@ class GGUFConfig(QuantizationConfig):
def _fuse_mul_mat(x: torch.Tensor, qweight: torch.Tensor, def _fuse_mul_mat(x: torch.Tensor, qweight: torch.Tensor,
qweight_type: int) -> torch.Tensor: qweight_type: int) -> torch.Tensor:
# use dequantize mulmat for IQmatrix, mmq for k-quants # use dequantize mulmat for IQmatrix, mmq for k-quants
if qweight_type >= 16: if x.shape[0] == 1:
# enable mmvq in contiguous batching
y = ops.ggml_mul_mat_vec_a8(qweight, x, qweight_type, qweight.shape[0])
elif qweight_type >= 16:
block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type] block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
shape = (qweight.shape[0], qweight.shape[1] // type_size * block_size) shape = (qweight.shape[0], qweight.shape[1] // type_size * block_size)
weight = ops.ggml_dequantize(qweight, qweight_type, *shape) weight = ops.ggml_dequantize(qweight, qweight_type, *shape)
......
...@@ -217,6 +217,7 @@ class GPTQLinearMethod(LinearMethodBase): ...@@ -217,6 +217,7 @@ class GPTQLinearMethod(LinearMethodBase):
layer.qzeros = Parameter(layer.qzeros.data, requires_grad=False) layer.qzeros = Parameter(layer.qzeros.data, requires_grad=False)
layer.qweight = Parameter(layer.qweight.data, requires_grad=False) layer.qweight = Parameter(layer.qweight.data, requires_grad=False)
layer.g_idx = Parameter(layer.g_idx.data, requires_grad=False) layer.g_idx = Parameter(layer.g_idx.data, requires_grad=False)
layer.scales = Parameter(layer.scales.data, requires_grad=False)
# exllama needs to shuffle the weight after the weight is loaded # exllama needs to shuffle the weight after the weight is loaded
# here we do the shuffle on first forward pass # here we do the shuffle on first forward pass
......
from typing import Any, Callable, Dict, List, Optional, Union from typing import Any, Callable, Dict, List, Optional, Set, Union
import torch import torch
from torch.nn import Parameter
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -11,12 +10,12 @@ from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, ...@@ -11,12 +10,12 @@ from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
set_weight_attrs) set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.model_executor.layers.quantization.kernels import (
MPLinearLayerConfig, choose_mp_linear_kernel)
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.marlin_utils import ( from vllm.model_executor.layers.quantization.utils.marlin_utils import (
apply_gptq_marlin_linear, check_marlin_supported, marlin_is_k_full, check_marlin_supported, marlin_moe_permute_scales,
marlin_make_empty_g_idx, marlin_make_workspace, marlin_moe_permute_scales, marlin_repeat_scales_on_all_ranks, verify_marlin_supported)
marlin_permute_scales, marlin_repeat_scales_on_all_ranks,
marlin_sort_g_idx, replace_tensor, verify_marlin_supported,
verify_marlin_supports_shape)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.parameter import (ChannelQuantScaleParameter, from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
GroupQuantScaleParameter, GroupQuantScaleParameter,
...@@ -132,10 +131,10 @@ class GPTQMarlinConfig(QuantizationConfig): ...@@ -132,10 +131,10 @@ class GPTQMarlinConfig(QuantizationConfig):
def is_gptq_marlin_compatible(cls, quant_config: Dict[str, Any]): def is_gptq_marlin_compatible(cls, quant_config: Dict[str, Any]):
# Extract data from quant config. # Extract data from quant config.
quant_method = quant_config.get("quant_method", "").lower() quant_method = quant_config.get("quant_method", "").lower()
num_bits = quant_config.get("bits", None) num_bits = quant_config.get("bits")
group_size = quant_config.get("group_size", None) group_size = quant_config.get("group_size")
sym = quant_config.get("sym", None) sym = quant_config.get("sym")
desc_act = quant_config.get("desc_act", None) desc_act = quant_config.get("desc_act")
if quant_method != "gptq": if quant_method != "gptq":
return False return False
...@@ -159,6 +158,8 @@ class GPTQMarlinLinearMethod(LinearMethodBase): ...@@ -159,6 +158,8 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
quant_config: The GPTQ Marlin quantization config. quant_config: The GPTQ Marlin quantization config.
""" """
_kernel_backends_being_used: Set[str] = set()
def __init__(self, quant_config: GPTQMarlinConfig) -> None: def __init__(self, quant_config: GPTQMarlinConfig) -> None:
self.quant_config = quant_config self.quant_config = quant_config
...@@ -176,25 +177,34 @@ class GPTQMarlinLinearMethod(LinearMethodBase): ...@@ -176,25 +177,34 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
params_dtype: torch.dtype, params_dtype: torch.dtype,
**extra_weight_attrs, **extra_weight_attrs,
) -> None: ) -> None:
del output_size
output_size_per_partition = sum(output_partition_sizes) output_size_per_partition = sum(output_partition_sizes)
is_row_parallel = input_size != input_size_per_partition is_row_parallel = input_size != input_size_per_partition
weight_loader = extra_weight_attrs.get("weight_loader") weight_loader = extra_weight_attrs.get("weight_loader")
mp_linear_kernel_config = MPLinearLayerConfig(
full_weight_shape=(input_size, output_size),
partition_weight_shape=\
(input_size_per_partition, output_size_per_partition),
weight_type=self.quant_config.quant_type,
act_type=params_dtype,
group_size=self.quant_config.group_size,
zero_points=False,
has_g_idx=self.quant_config.desc_act
)
kernel_type = choose_mp_linear_kernel(mp_linear_kernel_config)
if kernel_type.__name__ not in self._kernel_backends_being_used:
logger.info("Using %s for GPTQMarlinLinearMethod",
kernel_type.__name__)
self._kernel_backends_being_used.add(kernel_type.__name__)
# Normalize group_size # Normalize group_size
if self.quant_config.group_size != -1: if self.quant_config.group_size != -1:
group_size = self.quant_config.group_size group_size = self.quant_config.group_size
else: else:
group_size = input_size group_size = input_size
verify_marlin_supports_shape(
output_size_per_partition=output_size_per_partition,
input_size_per_partition=input_size_per_partition,
input_size=input_size,
group_size=group_size,
)
# Determine sharding # Determine sharding
if marlin_repeat_scales_on_all_ranks(self.quant_config.desc_act, if marlin_repeat_scales_on_all_ranks(self.quant_config.desc_act,
self.quant_config.group_size, self.quant_config.group_size,
...@@ -275,57 +285,15 @@ class GPTQMarlinLinearMethod(LinearMethodBase): ...@@ -275,57 +285,15 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
layer.register_parameter("g_idx", g_idx) layer.register_parameter("g_idx", g_idx)
layer.register_parameter("scales", scales) layer.register_parameter("scales", scales)
layer.register_parameter("qzeros", qzeros) layer.register_parameter("qzeros", qzeros)
layer.input_size_per_partition = input_size_per_partition
layer.output_size_per_partition = output_size_per_partition
layer.input_size = input_size
layer.is_k_full = marlin_is_k_full(self.quant_config.desc_act,
is_row_parallel)
# Checkpoints are serialized in AutoGPTQ format, which is different from the
# marlin format. This function is called after the weights are loaded.
# Here, we handle the repacking, including the activation reordering case.
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
device = layer.qweight.device
# required by torch.compile self.kernel = kernel_type(mp_linear_kernel_config,
layer.qweight = Parameter(layer.qweight.data, requires_grad=False) w_q_param_name="qweight",
layer.scales = Parameter(layer.scales.data, requires_grad=False) w_s_param_name="scales",
w_zp_param_name="qzeros",
w_gidx_param_name="g_idx")
# Allocate marlin workspace def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
layer.workspace = marlin_make_workspace( self.kernel.process_weights_after_loading(layer)
layer.output_size_per_partition, device)
# Handle sorting for activation reordering if needed.
if self.quant_config.desc_act:
g_idx, g_idx_sort_indices = marlin_sort_g_idx(layer.g_idx)
layer.g_idx_sort_indices = g_idx_sort_indices
replace_tensor(layer, "g_idx", g_idx)
else:
layer.g_idx = marlin_make_empty_g_idx(device)
layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)
# No zero-point
layer.zp = marlin_make_empty_g_idx(device)
# Repack weights from autogptq format to marlin format.
marlin_qweight = ops.gptq_marlin_repack(
layer.qweight,
perm=layer.g_idx_sort_indices,
size_k=layer.input_size_per_partition,
size_n=layer.output_size_per_partition,
num_bits=self.quant_config.quant_type.size_bits,
)
replace_tensor(layer, "qweight", marlin_qweight)
# Permute scales from autogptq format to marlin format.
marlin_scales = marlin_permute_scales(
layer.scales,
size_k=(layer.input_size if self.quant_config.desc_act else
layer.input_size_per_partition),
size_n=layer.output_size_per_partition,
group_size=self.quant_config.group_size,
)
replace_tensor(layer, "scales", marlin_scales)
def apply( def apply(
self, self,
...@@ -333,20 +301,7 @@ class GPTQMarlinLinearMethod(LinearMethodBase): ...@@ -333,20 +301,7 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
x: torch.Tensor, x: torch.Tensor,
bias: Optional[torch.Tensor] = None, bias: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
return apply_gptq_marlin_linear( return self.kernel.apply_weights(layer, x, bias)
input=x,
weight=layer.qweight,
weight_scale=layer.scales,
weight_zp=layer.zp,
g_idx=layer.g_idx,
g_idx_sort_indices=layer.g_idx_sort_indices,
workspace=layer.workspace,
wtype=self.quant_config.quant_type,
output_size_per_partition=layer.output_size_per_partition,
input_size_per_partition=layer.input_size_per_partition,
is_k_full=layer.is_k_full,
bias=bias,
)
class GPTQMarlinMoEMethod(FusedMoEMethodBase): class GPTQMarlinMoEMethod(FusedMoEMethodBase):
...@@ -506,12 +461,12 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): ...@@ -506,12 +461,12 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
w13_g_idx_sort_indices[e]] w13_g_idx_sort_indices[e]]
w2_sorted_g_idx[e] = layer.w2_g_idx[e][ w2_sorted_g_idx[e] = layer.w2_g_idx[e][
w2_g_idx_sort_indices[e]] w2_g_idx_sort_indices[e]]
replace_tensor(layer, "w13_g_idx", w13_sorted_g_idx) replace_parameter(layer, "w13_g_idx", w13_sorted_g_idx)
replace_tensor(layer, "w2_g_idx", w2_sorted_g_idx) replace_parameter(layer, "w2_g_idx", w2_sorted_g_idx)
replace_tensor(layer, "w13_g_idx_sort_indices", replace_parameter(layer, "w13_g_idx_sort_indices",
w13_g_idx_sort_indices) w13_g_idx_sort_indices)
replace_tensor(layer, "w2_g_idx_sort_indices", replace_parameter(layer, "w2_g_idx_sort_indices",
w2_g_idx_sort_indices) w2_g_idx_sort_indices)
else: else:
# Reset g_idx related tensors # Reset g_idx related tensors
num_experts = layer.w13_g_idx.shape[0] num_experts = layer.w13_g_idx.shape[0]
...@@ -544,7 +499,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): ...@@ -544,7 +499,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
layer.w13_qweight.shape[2], layer.w13_qweight.shape[2],
self.quant_config.quant_type.size_bits, self.quant_config.quant_type.size_bits,
) )
replace_tensor(layer, "w13_qweight", marlin_w13_qweight) replace_parameter(layer, "w13_qweight", marlin_w13_qweight)
marlin_w2_qweight = ops.gptq_marlin_moe_repack( marlin_w2_qweight = ops.gptq_marlin_moe_repack(
layer.w2_qweight, layer.w2_qweight,
layer.w2_g_idx_sort_indices, layer.w2_g_idx_sort_indices,
...@@ -552,7 +507,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): ...@@ -552,7 +507,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
layer.w2_qweight.shape[2], layer.w2_qweight.shape[2],
self.quant_config.quant_type.size_bits, self.quant_config.quant_type.size_bits,
) )
replace_tensor(layer, "w2_qweight", marlin_w2_qweight) replace_parameter(layer, "w2_qweight", marlin_w2_qweight)
# Repack scales # Repack scales
marlin_w13_scales = marlin_moe_permute_scales( marlin_w13_scales = marlin_moe_permute_scales(
s=layer.w13_scales, s=layer.w13_scales,
...@@ -560,14 +515,14 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): ...@@ -560,14 +515,14 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
size_n=layer.w13_scales.shape[2], size_n=layer.w13_scales.shape[2],
group_size=self.quant_config.group_size, group_size=self.quant_config.group_size,
) )
replace_tensor(layer, "w13_scales", marlin_w13_scales) replace_parameter(layer, "w13_scales", marlin_w13_scales)
marlin_w2_scales = marlin_moe_permute_scales( marlin_w2_scales = marlin_moe_permute_scales(
s=layer.w2_scales, s=layer.w2_scales,
size_k=layer.w2_scales.shape[1] * self.quant_config.pack_factor, size_k=layer.w2_scales.shape[1] * self.quant_config.pack_factor,
size_n=layer.w2_scales.shape[2], size_n=layer.w2_scales.shape[2],
group_size=self.quant_config.group_size, group_size=self.quant_config.group_size,
) )
replace_tensor(layer, "w2_scales", marlin_w2_scales) replace_parameter(layer, "w2_scales", marlin_w2_scales)
def apply( def apply(
self, self,
...@@ -611,4 +566,5 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): ...@@ -611,4 +566,5 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
topk_ids, topk_ids,
w1_scale=layer.w13_scales, w1_scale=layer.w13_scales,
w2_scale=layer.w2_scales, w2_scale=layer.w2_scales,
num_bits=self.quant_config.quant_type.size_bits,
).to(orig_dtype) ).to(orig_dtype)
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Callable, Optional, Tuple
import torch
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.scalar_type import ScalarType
@dataclass
class MPLinearLayerConfig:
    """Description of a mixed-precision linear layer handed to a kernel.

    Kernel classes inspect this config in ``can_implement`` to decide whether
    they support the layer, and keep it around to drive weight repacking and
    the forward pass.
    """

    # Shape of the complete (unsharded) weight, ordered as [in, out].
    full_weight_shape: Tuple[int, int]
    # Shape of the shard held by this rank, in the same [in, out] order.
    partition_weight_shape: Tuple[int, int]
    # Quantized storage type of the weights.
    weight_type: ScalarType
    # Dtype of the activations fed to the layer.
    act_type: torch.dtype
    # Quantization group size; kernels validate it against their own
    # supported-group-size lists.
    group_size: int
    # Whether the layer carries zero points.
    zero_points: bool
    # Whether the layer carries a g_idx tensor (activation reordering).
    has_g_idx: bool
class MPLinearKernel(ABC):
    """Abstract base class for mixed-precision linear kernel backends.

    A concrete kernel declares which layer configurations it supports
    (``get_min_capability`` / ``can_implement``), repacks the loaded
    checkpoint tensors into its own layout
    (``process_weights_after_loading``) and runs the matmul
    (``apply_weights``). Tensors are looked up on the layer by the attribute
    names passed to the constructor.
    """

    @classmethod
    @abstractmethod
    def get_min_capability(cls) -> int:
        """Return the minimum compute capability required by the kernel.

        The value is compared against ``major * 10 + minor`` of the device.
        """
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def can_implement(
            cls, c: "MPLinearLayerConfig") -> Tuple[bool, Optional[str]]:
        """Report whether this kernel supports the layer described by ``c``.

        Returns:
            ``(True, None)`` when supported, otherwise ``(False, reason)``
            with a human readable explanation.
        """
        raise NotImplementedError

    def __init__(self,
                 c: "MPLinearLayerConfig",
                 w_q_param_name: str,
                 w_s_param_name: str,
                 w_zp_param_name: Optional[str] = None,
                 w_gidx_param_name: Optional[str] = None) -> None:
        """Bind the kernel to a layer config and to the layer's tensor names.

        Args:
            c: Configuration of the layer this kernel instance will serve.
            w_q_param_name: Attribute name of the quantized weight.
            w_s_param_name: Attribute name of the weight scales.
            w_zp_param_name: Attribute name of the zero points, if any.
            w_gidx_param_name: Attribute name of the g_idx tensor, if any.

        Raises:
            AssertionError: If the kernel cannot implement ``c``; the
                message carries the reason reported by ``can_implement``.
        """
        # `can_implement` returns a (bool, reason) tuple. Asserting on the
        # tuple itself can never fail because a non-empty tuple is truthy,
        # so the flag has to be unpacked first.
        supported, failure_reason = self.can_implement(c)
        assert supported, failure_reason

        self.config = c
        self.w_q_name = w_q_param_name
        self.w_s_name = w_s_param_name
        self.w_zp_name = w_zp_param_name
        self.w_gidx_name = w_gidx_param_name

    @abstractmethod
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Convert the tensors loaded onto ``layer`` into the kernel layout."""
        raise NotImplementedError

    @abstractmethod
    def apply_weights(self,
                      layer: torch.nn.Module,
                      x: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run the linear layer on ``x`` using the tensors held by ``layer``."""
        raise NotImplementedError

    def _transform_param(self, layer: torch.nn.Module, name: Optional[str],
                         fn: Callable) -> None:
        """Apply ``fn`` to the parameter ``name`` of ``layer`` and store it.

        Does nothing when ``name`` is None or the layer has no such
        (non-None) attribute.
        """
        if name is not None and getattr(layer, name, None) is not None:
            old_param = getattr(layer, name)
            new_param = fn(old_param)
            # replace the parameter with torch.nn.Parameter for TorchDynamo
            # compatibility
            replace_parameter(
                layer, name,
                torch.nn.Parameter(new_param.data, requires_grad=False))

    def _get_weight_params(
        self, layer: torch.nn.Module
    ) -> Tuple[torch.Tensor,  # w_q
               torch.Tensor,  # w_s
               Optional[torch.Tensor],  # w_zp,
               Optional[torch.Tensor]  # w_gidx
               ]:
        """Fetch ``(w_q, w_s, w_zp, w_gidx)`` from ``layer``.

        The optional entries are None when no attribute name was configured
        for them or the layer lacks the attribute.
        """
        return (
            getattr(layer, self.w_q_name),
            getattr(layer, self.w_s_name),
            # An unset name falls back to "", which is never a real attribute,
            # so the lookup yields the None default.
            getattr(layer, self.w_zp_name or "", None),
            getattr(layer, self.w_gidx_name or "", None),
        )
import os
from typing import List, Optional, Type
from vllm.model_executor.layers.quantization.kernels.machete import (
MacheteLinearKernel)
from vllm.model_executor.layers.quantization.kernels.marlin import (
MarlinLinearKernel)
from vllm.model_executor.layers.quantization.kernels.MPLinearKernel import (
MPLinearKernel, MPLinearLayerConfig)
from vllm.platforms import current_platform
# Candidate kernel classes, in priority/performance order (when available).
# `choose_mp_linear_kernel` walks this list front to back and returns the
# first kernel that is not disabled, meets the device capability and can
# implement the requested config.
_POSSIBLE_KERNELS: List[Type[MPLinearKernel]] = [
MacheteLinearKernel,
MarlinLinearKernel,
]
def choose_mp_linear_kernel(
        config: MPLinearLayerConfig,
        compute_capability: Optional[int] = None) -> Type[MPLinearKernel]:
    """
    Choose an MPLinearKernel that can implement the given config for the given
    compute capability. Attempts to choose the best kernel in terms of
    performance.

    Kernels whose class name is listed in the comma separated
    `VLLM_DISABLED_KERNELS` environment variable are skipped.

    Args:
        config (MPLinearLayerConfig): Description of the linear layer to be
            implemented.
        compute_capability (Optional[int], optional): The compute capability of
            the target device, if None uses `current_platform` to get the
            compute capability. Defaults to None.

    Raises:
        ValueError: If the compute capability cannot be determined, or if no
            kernel can implement the given config.

    Returns:
        Type[MPLinearKernel]: Chosen kernel.
    """
    if compute_capability is None:
        if current_platform is None:
            raise ValueError("Cannot determine compute capability")
        _cc = current_platform.get_device_capability()
        # The platform may report no capability at all; fail with a clear
        # error instead of a TypeError when indexing None.
        if _cc is None:
            raise ValueError("Cannot determine compute capability")
        compute_capability = _cc[0] * 10 + _cc[1]

    # Parse the environment variable once instead of once per kernel, and
    # tolerate whitespace around the comma separated names.
    disabled_kernels = {
        name.strip()
        for name in os.environ.get("VLLM_DISABLED_KERNELS", "").split(",")
    }

    failure_reasons = []
    for kernel in _POSSIBLE_KERNELS:
        if kernel.__name__ in disabled_kernels:
            failure_reasons.append(
                f' {kernel.__name__} disabled by environment variable')
            continue

        if kernel.get_min_capability() > compute_capability:
            failure_reasons.append(
                f"{kernel.__name__} requires capability "
                f"{kernel.get_min_capability()}, current compute capability "
                f"is {compute_capability}")
            continue

        can_implement, failure_reason = kernel.can_implement(config)
        if can_implement:
            return kernel
        failure_reasons.append(
            f' {kernel.__name__} cannot implement due to: {failure_reason}')

    raise ValueError(
        "Failed to find a kernel that can implement the "
        "WNA16 linear layer. Reasons: \n" + '\n'.join(failure_reasons))
from functools import partial
from typing import Optional, Tuple
import torch
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.machete_utils import (
MACHETE_SUPPORTED_GROUP_SIZES, check_machete_supports_shape,
query_machete_supported_quant_types)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
pack_weights_into_int32, unpack_weights_into_int32)
from vllm.model_executor.parameter import (BasevLLMParameter,
permute_param_layout_)
from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig
class MacheteLinearKernel(MPLinearKernel):
    """MPLinearKernel backend built on the Machete ops."""

    @classmethod
    def get_min_capability(cls) -> int:
        """Minimum compute capability (major * 10 + minor) for Machete."""
        return 90

    @classmethod
    def can_implement(cls,
                      c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]:
        """Check whether Machete can serve the layer described by ``c``.

        Returns ``(True, None)`` or ``(False, reason)``; the final verdict on
        the partition shape is delegated to ``check_machete_supports_shape``.
        """
        if c.has_g_idx and (c.partition_weight_shape[0]
                            != c.full_weight_shape[0]):
            return False, ("Act reordering currently not supported by Machete, "
                           "when the input features are partitioned across "
                           "devices")

        if c.zero_points:
            return False, ("Zero points currently not supported by "
                           " Compressed Tensors + Machete. (Kernel supports it"
                           " but CompressedTensorsWNA16 does not so support has"
                           " not been added to MacheteWNA16Kernel yet")

        supported_types = query_machete_supported_quant_types(c.zero_points)
        if c.weight_type not in supported_types:
            return False, (f"Quant type ({c.weight_type}) not supported by "
                           "Machete, supported types are: "
                           f"{supported_types}")

        if c.group_size not in MACHETE_SUPPORTED_GROUP_SIZES:
            return False, (f"Group size ({c.group_size}) not supported by "
                           "Machete, supported group sizes are: "
                           f"{MACHETE_SUPPORTED_GROUP_SIZES}")

        in_features, out_features = (c.partition_weight_shape[0],
                                     c.partition_weight_shape[1])
        return check_machete_supports_shape(in_features, out_features)

    # note assumes that
    #  `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
    #  `weight_scale`  is: {input_dim = 0, output_dim = 1}
    def process_weights_after_loading(self, layer: torch.nn.Module):
        """Repack the layer's quantized weight and scales for Machete.

        When the config has a g_idx, the permutation that sorts it is applied
        to the weight rows here, and ``self.act_perm`` is set up so that
        ``apply_weights`` can permute the activation columns to match.
        """
        cfg = self.config

        if cfg.has_g_idx:
            assert self.w_gidx_name is not None
            g_idx = getattr(layer, self.w_gidx_name)
            perm = torch.argsort(g_idx).to(torch.int)

            # use `ops.permute_cols` if possible, otherwise fall back to
            # plain column indexing
            can_use_permute_cols = (
                cfg.act_type in [torch.float16, torch.bfloat16]
                and cfg.partition_weight_shape[0] % 8 == 0)
            if can_use_permute_cols:
                self.act_perm = partial(ops.permute_cols, perm=perm)
            else:
                self.act_perm = lambda x: x[:, perm]

        def transform_w_q(x):
            assert isinstance(x, BasevLLMParameter)
            permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
            if cfg.has_g_idx:
                # Reorder the weight rows with the same permutation that is
                # applied to the activations: unpack, permute, pack again.
                unpacked = unpack_weights_into_int32(x.data,
                                                     cfg.weight_type,
                                                     packed_dim=0)
                x.data = pack_weights_into_int32(unpacked[perm, :],
                                                 cfg.weight_type,
                                                 packed_dim=0)
            relaid_out = x.data.t().contiguous().t()
            x.data = ops.machete_prepack_B(relaid_out, cfg.weight_type)
            return x

        def transform_w_s(x):
            assert isinstance(x, BasevLLMParameter)
            permute_param_layout_(x, input_dim=0, output_dim=1)
            x.data = x.data.contiguous()
            return x

        # Repack weights and scales for Machete
        self._transform_param(layer, self.w_q_name, transform_w_q)
        self._transform_param(layer, self.w_s_name, transform_w_s)

    def apply_weights(self,
                      layer: torch.nn.Module,
                      x: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Compute the linear layer output for ``x`` with ``machete_gemm``.

        ``x`` is flattened to 2D for the GEMM and the result is reshaped so
        that only the last dimension changes (to the partition's output
        size). ``bias``, when given, is added in place to the GEMM output.
        """
        cfg = self.config
        w_q, w_s, _, _ = self._get_weight_params(layer)

        flat_x = x.reshape(-1, x.shape[-1])
        result_shape = x.shape[:-1] + (cfg.partition_weight_shape[1], )

        if cfg.has_g_idx:
            flat_x = self.act_perm(flat_x)

        output = ops.machete_gemm(a=flat_x,
                                  b_q=w_q,
                                  b_type=cfg.weight_type,
                                  b_zeros=None,
                                  b_scales=w_s,
                                  b_group_size=cfg.group_size)

        if bias is not None:
            output.add_(bias)  # In-place add

        return output.reshape(result_shape)
from typing import Optional, Tuple
import torch
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
MARLIN_SUPPORTED_GROUP_SIZES, apply_gptq_marlin_linear,
check_marlin_supports_shape, marlin_is_k_full, marlin_make_empty_g_idx,
marlin_make_workspace, marlin_permute_scales, marlin_sort_g_idx,
query_marlin_supported_quant_types)
from vllm.model_executor.parameter import (BasevLLMParameter,
permute_param_layout_)
from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig
class MarlinLinearKernel(MPLinearKernel):
    """MPLinearKernel backend built on the GPTQ Marlin ops."""

    @classmethod
    def get_min_capability(cls) -> int:
        """Minimum compute capability (major * 10 + minor) for Marlin."""
        return 80

    @classmethod
    def can_implement(cls,
                      c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]:
        """Check whether Marlin can serve the layer described by ``c``.

        Returns ``(True, None)`` or ``(False, reason)``; the final verdict on
        the shapes is delegated to ``check_marlin_supports_shape``.
        """
        if c.zero_points:
            return False, "Zero points currently not supported by "\
                          " MarlinLinearKernel. Will be added when AWQMarlin "\
                          "is migrated over to using MPLinearKernel backend"

        quant_types = query_marlin_supported_quant_types(c.zero_points)
        if c.weight_type not in quant_types:
            return False, f"Quant type ({c.weight_type}) not supported by"\
                          f"  Marlin, supported types are: {quant_types}"

        if c.group_size not in MARLIN_SUPPORTED_GROUP_SIZES:
            return False, f"Group size ({c.group_size}) not supported by "\
                          "Marlin, supported group sizes are: "\
                          f"{MARLIN_SUPPORTED_GROUP_SIZES}"

        # Weight shapes are stored as [in, out], while the shape check takes
        # the output size first followed by the two input sizes; pass them by
        # keyword so the features cannot be swapped.
        return check_marlin_supports_shape(
            output_size_per_partition=c.partition_weight_shape[1],
            input_size_per_partition=c.partition_weight_shape[0],
            input_size=c.full_weight_shape[0],
            group_size=c.group_size)

    # note assumes that
    #  `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
    #  `weight_scale`  is: {input_dim = 0, output_dim = 1}
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Repack the layer's tensors into the Marlin format.

        Sets ``self.is_k_full`` and ``self.workspace``, stores the sorted
        g_idx (or an empty placeholder) and ``g_idx_sort_indices`` on the
        layer, installs an empty zero-point placeholder, and finally repacks
        the quantized weight and permutes the scales.
        """
        device = getattr(layer, self.w_q_name).device
        c = self.config

        row_parallel = (c.partition_weight_shape[0] != c.full_weight_shape[0])
        self.is_k_full = marlin_is_k_full(c.has_g_idx, row_parallel)

        # Allocate marlin workspace.
        self.workspace = marlin_make_workspace(c.partition_weight_shape[1],
                                               device)

        # Default names since marlin requires empty parameters for these,
        # TODO: remove this requirement from marlin (allow optional tensors)
        if self.w_gidx_name is None:
            self.w_gidx_name = "g_idx"
        if self.w_zp_name is None:
            self.w_zp_name = "w_zp"

        if c.has_g_idx:
            g_idx, g_idx_sort_indices = marlin_sort_g_idx(
                getattr(layer, self.w_gidx_name))
            self._transform_param(layer, self.w_gidx_name, lambda _: g_idx)
            layer.g_idx_sort_indices = g_idx_sort_indices
        else:
            setattr(layer, self.w_gidx_name, marlin_make_empty_g_idx(device))
            layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)

        if c.zero_points:
            pass
            # TODO (lucas): add the following when AWQMarlin is migrated over to
            #   using MPLinearKernel backend
            # self._transform_param(layer, self.w_zp_name, lambda x: \
            #     marlin_zero_points(
            #         x,
            #         size_k=c.partition_weight_shape[0],
            #         size_n=c.partition_weight_shape[1],
            #         num_bits=c.weight_type.size_bits))
        else:
            # No zero points: marlin still expects a tensor, so install an
            # empty placeholder.
            setattr(layer, self.w_zp_name, marlin_make_empty_g_idx(device))

        def transform_w_q(x):
            assert isinstance(x, BasevLLMParameter)
            permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
            x.data = ops.gptq_marlin_repack(x.data.contiguous(),
                                            perm=layer.g_idx_sort_indices,
                                            size_k=c.partition_weight_shape[0],
                                            size_n=c.partition_weight_shape[1],
                                            num_bits=c.weight_type.size_bits)
            return x

        def transform_w_s(x):
            assert isinstance(x, BasevLLMParameter)
            permute_param_layout_(x, input_dim=0, output_dim=1)
            x.data = marlin_permute_scales(x.data.contiguous(),
                                           size_k=c.partition_weight_shape[0],
                                           size_n=c.partition_weight_shape[1],
                                           group_size=c.group_size)
            return x

        self._transform_param(layer, self.w_q_name, transform_w_q)
        self._transform_param(layer, self.w_s_name, transform_w_s)

    def apply_weights(self,
                      layer: torch.nn.Module,
                      x: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Compute the linear layer output via ``apply_gptq_marlin_linear``.

        Relies on the workspace, ``is_k_full`` and the layer attributes set
        up by ``process_weights_after_loading``.
        """
        c = self.config
        w_q, w_s, w_zp, w_gidx = self._get_weight_params(layer)

        # `process_weights_after_loading` will ensure w_zp and w_gidx are not
        #  None for marlin
        return apply_gptq_marlin_linear(
            input=x,
            weight=w_q,
            weight_scale=w_s,
            weight_zp=w_zp,  # type: ignore
            g_idx=w_gidx,  # type: ignore
            g_idx_sort_indices=layer.g_idx_sort_indices,
            workspace=self.workspace,
            wtype=c.weight_type,
            input_size_per_partition=c.partition_weight_shape[0],
            output_size_per_partition=c.partition_weight_shape[1],
            is_k_full=self.is_k_full,
            bias=bias)
...@@ -260,7 +260,7 @@ class QQQLinearMethod(LinearMethodBase): ...@@ -260,7 +260,7 @@ class QQQLinearMethod(LinearMethodBase):
size_k = x_2d.shape[1] size_k = x_2d.shape[1]
size_n = s_ch.shape[1] size_n = s_ch.shape[1]
x_int8, s_tok = ops.scaled_int8_quant(x_2d) x_int8, s_tok, _ = ops.scaled_int8_quant(x_2d)
output_2d = ops.marlin_qqq_gemm(x_int8, qweight, s_tok, s_ch, s_group, output_2d = ops.marlin_qqq_gemm(x_int8, qweight, s_tok, s_ch, s_group,
workspace, size_m, size_n, size_k) workspace, size_m, size_n, size_k)
......
from .layer_utils import replace_parameter, update_tensor_inplace
__all__ = ['update_tensor_inplace', 'replace_parameter']
from typing import Union
import torch
def update_tensor_inplace(dst: torch.Tensor, src: torch.Tensor):
    """Make ``dst`` take on the shape, stride and contents of ``src``.

    ``dst`` is re-strided in place to match ``src``; the data is copied only
    when the two tensors do not already start at the same address.

    Raises:
        AssertionError: If the dtypes of ``dst`` and ``src`` differ.
    """
    assert dst.dtype == src.dtype, "Tensors must have the same dtype"

    # Adopt the source's view geometry before touching the data.
    dst.as_strided_(src.shape, src.stride())

    # Same starting address: nothing to move.
    if dst.data_ptr() == src.data_ptr():
        return
    dst.copy_(src)
# Newly generated tensors need to replace existing tensors that are
# already registered as parameters by vLLM (and won't be freed)
def replace_parameter(mod: torch.nn.Module, name: str,
                      new: Union[torch.Tensor, torch.nn.Parameter]) -> None:
    """Swap the parameter ``name`` of ``mod`` for ``new``.

    When the existing value and ``new`` have exactly the same type, the same
    dtype and equally sized underlying storage, the existing tensor is
    updated in place. Otherwise ``new`` is re-registered on the module as a
    ``torch.nn.Parameter`` with ``requires_grad=False``.
    """
    old = getattr(mod, name)

    same_type = type(old) is type(new)
    if (same_type and old.dtype == new.dtype
            and old.untyped_storage().nbytes()
            == new.untyped_storage().nbytes()):
        # If we can just update in-place to avoid re-registering
        #   can be faster if the underlying storage is the same
        update_tensor_inplace(old, new)
        return

    # Fallback re-register parameter, convert to Parameter if necessary
    # this not only ensures we don't register a tensor as a parameter, but
    # also ensures that all parameter subclasses get re-registered as
    # parameters for `torch.compile` compatibility
    if not isinstance(new, torch.nn.Parameter):
        new = torch.nn.Parameter(new, requires_grad=False)
    rewrapped = torch.nn.Parameter(new, requires_grad=False)
    mod.register_parameter(name, rewrapped)
from typing import List, Optional, Tuple
import torch
from vllm.scalar_type import ScalarType, scalar_types
# Group sizes accepted for Machete; the Machete kernel rejects any layer
# config whose group_size is not in this list.
MACHETE_SUPPORTED_GROUP_SIZES = [-1, 128]
# Divisibility requirements checked by `check_machete_supports_shape`:
# element 0 must divide the input feature count, element 1 the output
# feature count.
MACHETE_PREPACKED_BLOCK_SHAPE = [64, 128]
def query_machete_supported_quant_types(zero_points: bool) -> List[ScalarType]:
    """Return the weight quant types Machete accepts.

    With zero points the plain unsigned types are returned, without zero
    points the `uint4b8` / `uint8b128` types are returned.
    """
    return ([scalar_types.uint4, scalar_types.uint8] if zero_points else
            [scalar_types.uint4b8, scalar_types.uint8b128])
def query_machete_supported_act_types(
        zero_points: bool) -> List[torch.dtype]:
    """Return the activation dtypes Machete accepts.

    The result is a list of torch dtypes (not scalar types) and does not
    depend on ``zero_points``; the parameter is accepted for symmetry with
    ``query_machete_supported_quant_types``.
    """
    return [torch.float16, torch.bfloat16]
def check_machete_supports_shape(in_features: int, out_featrues: int) \
        -> Tuple[bool, Optional[str]]:
    """Check the feature counts against the Machete prepacked block shape.

    Returns:
        ``(True, None)`` when both sizes are divisible by the matching entry
        of ``MACHETE_PREPACKED_BLOCK_SHAPE``, otherwise ``(False, reason)``.
    """
    in_block = MACHETE_PREPACKED_BLOCK_SHAPE[0]
    if in_features % in_block:
        return False, f"Input features size must be divisible by {in_block}"

    out_block = MACHETE_PREPACKED_BLOCK_SHAPE[1]
    if out_featrues % out_block:
        return False, f"Output features size must be divisible by {out_block}"

    return True, None
...@@ -29,8 +29,9 @@ def query_marlin_supported_quant_types(has_zp: bool, ...@@ -29,8 +29,9 @@ def query_marlin_supported_quant_types(has_zp: bool,
device_capability: Optional[int] = None device_capability: Optional[int] = None
): ):
if device_capability is None: if device_capability is None:
major, minor = current_platform.get_device_capability() capability_tuple = current_platform.get_device_capability()
device_capability = major * 10 + minor device_capability = (-1 if capability_tuple is None else
capability_tuple.to_int())
if device_capability < 80: if device_capability < 80:
return [] return []
...@@ -52,8 +53,9 @@ def _check_marlin_supported( ...@@ -52,8 +53,9 @@ def _check_marlin_supported(
device_capability: Optional[int] = None) -> Tuple[bool, Optional[str]]: device_capability: Optional[int] = None) -> Tuple[bool, Optional[str]]:
if device_capability is None: if device_capability is None:
major, minor = current_platform.get_device_capability() capability_tuple = current_platform.get_device_capability()
device_capability = major * 10 + minor device_capability = (-1 if capability_tuple is None else
capability_tuple.to_int())
supported_types = query_marlin_supported_quant_types( supported_types = query_marlin_supported_quant_types(
has_zp, device_capability) has_zp, device_capability)
...@@ -118,6 +120,19 @@ def verify_marlin_supports_shape(output_size_per_partition: int, ...@@ -118,6 +120,19 @@ def verify_marlin_supports_shape(output_size_per_partition: int,
"with --quantization gptq.") "with --quantization gptq.")
def check_marlin_supports_shape(output_size_per_partition: int,
                                input_size_per_partition: int,
                                input_size: int, group_size: int) \
        -> Tuple[bool, Optional[str]]:
    """Non-raising variant of ``verify_marlin_supports_shape``.

    Returns:
        ``(True, None)`` when the verification passes, otherwise
        ``(False, message)`` with the text of the raised ``ValueError``.
    """
    try:
        verify_marlin_supports_shape(
            output_size_per_partition=output_size_per_partition,
            input_size_per_partition=input_size_per_partition,
            input_size=input_size,
            group_size=group_size)
    except ValueError as err:
        return False, str(err)
    else:
        return True, None
def marlin_make_workspace(output_size_per_partition: int, def marlin_make_workspace(output_size_per_partition: int,
device: torch.device) -> torch.Tensor: device: torch.device) -> torch.Tensor:
max_workspace_size = (output_size_per_partition // max_workspace_size = (output_size_per_partition //
...@@ -146,6 +161,11 @@ def marlin_make_empty_g_idx(device: torch.device) -> torch.Tensor: ...@@ -146,6 +161,11 @@ def marlin_make_empty_g_idx(device: torch.device) -> torch.Tensor:
requires_grad=False) requires_grad=False)
def marlin_make_empty_zp(device: torch.device) -> torch.Tensor:
    """Return an empty (zero-element) int zero-point placeholder.

    Args:
        device: Device on which the empty tensor is allocated.

    Returns:
        A ``torch.nn.Parameter`` wrapping a 0-element ``torch.int`` tensor
        on ``device``, with ``requires_grad=False``.
    """
    return torch.nn.Parameter(torch.empty(0, dtype=torch.int, device=device),
                              requires_grad=False)
def marlin_sort_g_idx( def marlin_sort_g_idx(
g_idx: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: g_idx: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
g_idx_sort_indices = torch.argsort(g_idx).to(torch.int) g_idx_sort_indices = torch.argsort(g_idx).to(torch.int)
...@@ -238,17 +258,6 @@ def awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int, ...@@ -238,17 +258,6 @@ def awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int,
return marlin_zp return marlin_zp
# Newly generated tensors need to replace existing tensors that are
# already registered as parameters by vLLM (and won't be freed)
def replace_tensor(layer: torch.nn.Module, name: str,
                   new_t: torch.Tensor) -> None:
    """Overwrite the tensor stored at ``layer.<name>`` in place.

    Args:
        layer: Module that owns the attribute to overwrite.
        name: Attribute name of the existing tensor on ``layer``.
        new_t: Tensor whose shape and contents are copied into the
            existing tensor.

    The attribute is never rebound: the existing tensor object is resized
    to ``new_t.shape`` and then filled with ``new_t``'s values, so the
    object registered on the layer stays the same.
    """
    target = getattr(layer, name)
    # It is important to use resize_() here (rather than assigning a new
    # tensor) so that the already-registered tensor object is reused.
    target.resize_(new_t.shape)
    target.copy_(new_t)
def apply_gptq_marlin_linear( def apply_gptq_marlin_linear(
input: torch.Tensor, input: torch.Tensor,
weight: torch.Tensor, weight: torch.Tensor,
......
...@@ -10,8 +10,7 @@ from .marlin_utils import marlin_make_workspace, marlin_permute_scales ...@@ -10,8 +10,7 @@ from .marlin_utils import marlin_make_workspace, marlin_permute_scales
def is_fp8_marlin_supported() -> bool:
    """Report whether the current platform can use the FP8 Marlin path.

    Delegates to ``current_platform.has_device_capability(80)``.
    NOTE(review): 80 presumably encodes compute capability 8.0 using the
    ``major * 10 + minor`` convention seen elsewhere in this change --
    confirm against the platform API.
    """
    return current_platform.has_device_capability(80)
def apply_fp8_marlin_linear( def apply_fp8_marlin_linear(
......
...@@ -20,6 +20,49 @@ FUSED_LAYER_NAME_MAPPING = { ...@@ -20,6 +20,49 @@ FUSED_LAYER_NAME_MAPPING = {
} }
def pack_weights_into_int32(w_q: torch.Tensor,
                            wtype: "ScalarType",
                            packed_dim: int = 0):
    """Pack low-bit integer values into int32 words along one dimension.

    Args:
        w_q: Integer tensor of unpacked values. Only the low
            ``wtype.size_bits`` bits of each element are kept.
        wtype: Object exposing ``size_bits``, the bit width of one value.
        packed_dim: Dimension to pack along (negative values index from
            the end). Its length must be a multiple of
            ``32 // wtype.size_bits``.

    Returns:
        An int32 tensor on ``w_q``'s device whose ``packed_dim`` is
        ``32 // wtype.size_bits`` times shorter. Within each group of
        ``pack_factor`` consecutive inputs, element ``i`` is stored at bit
        offset ``i * wtype.size_bits`` of the output word.
    """
    assert not w_q.is_floating_point(), "w_q must be an integer tensor"

    # Accept negative dims; the permutation below needs a non-negative index.
    if packed_dim < 0:
        packed_dim += w_q.dim()

    # move dim to pack to the end
    perm = (*[i for i in range(w_q.dim()) if i != packed_dim], packed_dim)
    inv_perm = tuple(perm.index(i) for i in range(len(perm)))
    # Widen to int32 *before* masking/shifting: shifting a narrow dtype
    # such as uint8 by 8 or more bits would otherwise drop the value.
    w_q_perm = w_q.permute(perm).to(torch.int32)

    pack_factor = 32 // wtype.size_bits
    mask = (1 << wtype.size_bits) - 1

    new_shape_perm = list(w_q_perm.shape)
    assert w_q_perm.shape[-1] % pack_factor == 0
    new_shape_perm[-1] //= pack_factor

    res = torch.zeros(new_shape_perm, dtype=torch.int32, device=w_q.device)
    for i in range(pack_factor):
        res |= (w_q_perm[..., i::pack_factor] & mask) << (wtype.size_bits * i)

    return res.permute(inv_perm)
def unpack_weights_into_int32(w_q: torch.Tensor,
                              wtype: "ScalarType",
                              packed_dim: int = 0):
    """Expand packed words into one int32 element per low-bit value.

    Args:
        w_q: Tensor of packed words; each is read as holding
            ``32 // wtype.size_bits`` values.
        wtype: Object exposing ``size_bits``, the bit width of one value.
        packed_dim: Dimension that was packed (negative values index from
            the end).

    Returns:
        An int32 tensor on ``w_q``'s device whose ``packed_dim`` is
        ``32 // wtype.size_bits`` times longer. The field at bit offset
        ``i * wtype.size_bits`` of each word becomes element ``i`` of the
        corresponding output group; values are masked, so they are
        returned as non-negative integers.
    """
    # Accept negative dims; the permutation below needs a non-negative index.
    if packed_dim < 0:
        packed_dim += w_q.dim()

    # move dim to pack to the end
    perm = (*[i for i in range(w_q.dim()) if i != packed_dim], packed_dim)
    inv_perm = tuple(perm.index(i) for i in range(len(perm)))
    w_q_perm = w_q.permute(perm)

    pack_factor = 32 // wtype.size_bits
    mask = (1 << wtype.size_bits) - 1

    new_shape_perm = list(w_q_perm.shape)
    new_shape_perm[-1] *= pack_factor

    res = torch.zeros(new_shape_perm, dtype=torch.int32, device=w_q.device)
    for i in range(pack_factor):
        # The mask after the shift discards any sign-extension bits.
        res[..., i::pack_factor] = (w_q_perm >> (wtype.size_bits * i)) & mask

    return res.permute(inv_perm)
def is_layer_skipped(prefix: str, ignored_layers: List[str]) -> bool: def is_layer_skipped(prefix: str, ignored_layers: List[str]) -> bool:
# prefix: model.layers.0.self_attn.q_proj # prefix: model.layers.0.self_attn.q_proj
# proj_name: q_proj # proj_name: q_proj
......
...@@ -6,19 +6,18 @@ from vllm import _custom_ops as ops ...@@ -6,19 +6,18 @@ from vllm import _custom_ops as ops
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import is_hip from vllm.utils import is_hip
# scaled_mm in pytorch on rocm has a bug that requires always # Input scaling factors are no longer optional in _scaled_mm starting
# providing scaling factor for result. This value is created # from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
# as global value to avoid multiple tensor allocations, and TORCH_DEVICE_IDENTITY = torch.ones(1).cuda() if is_hip() else None
# can be removed once pytorch fixes the bug.
TORCH_SCALED_MM_SCALE_RESULT = torch.ones(1).cuda() if is_hip() else None
def cutlass_fp8_supported() -> bool:
    """Report whether the CUTLASS scaled-mm FP8 kernels can be used.

    Returns ``False`` immediately when ``is_hip()`` is true. Otherwise the
    current device capability is converted to an int and passed to
    ``ops.cutlass_scaled_mm_supports_fp8``.
    """
    # cutlass is not supported on Rocm
    if is_hip():
        return False

    capability_tuple = current_platform.get_device_capability()
    # get_device_capability() may return None; -1 is passed in that case.
    capability = -1 if capability_tuple is None else capability_tuple.to_int()

    return ops.cutlass_scaled_mm_supports_fp8(capability)
...@@ -130,19 +129,17 @@ def apply_fp8_linear( ...@@ -130,19 +129,17 @@ def apply_fp8_linear(
if per_tensor_weights and per_tensor_activations: if per_tensor_weights and per_tensor_activations:
# Fused GEMM_DQ # Fused GEMM_DQ
output = torch._scaled_mm( output = torch._scaled_mm(qinput,
qinput, weight,
weight, out_dtype=input.dtype,
out_dtype=input.dtype, scale_a=x_scale,
scale_a=x_scale, scale_b=weight_scale,
scale_b=weight_scale, bias=bias)
scale_result=TORCH_SCALED_MM_SCALE_RESULT, # A fix for discrepancy in scaled_mm which returns tuple
bias=bias) # for torch < 2.5 and a single value in torch >= 2.5
# Since in torch 2.5, scaled_mm only returns single value if type(output) is tuple and len(output) == 2:
# This should be removed when vllm-nvidia also moves to 2.5 return torch.narrow(output[0], 0, 0, input.shape[0])
if is_hip(): return torch.narrow(output, 0, 0, input.shape[0])
return torch.narrow(output, 0, 0, input.shape[0])
return torch.narrow(output[0], 0, 0, input.shape[0])
else: else:
# Fallback for channelwise case, where we use unfused DQ # Fallback for channelwise case, where we use unfused DQ
...@@ -160,12 +157,23 @@ def apply_fp8_linear( ...@@ -160,12 +157,23 @@ def apply_fp8_linear(
# For the scaled_mm fallback case, we break this down, since it # For the scaled_mm fallback case, we break this down, since it
# does not support s_w being a vector. # does not support s_w being a vector.
# Making sure the dummy tensor is on the same device as the weight
global TORCH_DEVICE_IDENTITY
if TORCH_DEVICE_IDENTITY.device != weight.device:
TORCH_DEVICE_IDENTITY = TORCH_DEVICE_IDENTITY.to(weight.device)
# GEMM # GEMM
# This computes C = (X * W). # This computes C = (X * W).
# Output in fp32 to allow subsequent ops to happen in-place # Output in fp32 to allow subsequent ops to happen in-place
output, _ = torch._scaled_mm(qinput, output = torch._scaled_mm(qinput,
weight, weight,
out_dtype=torch.float32) scale_a=TORCH_DEVICE_IDENTITY,
scale_b=TORCH_DEVICE_IDENTITY,
out_dtype=torch.float32)
# A fix for discrepancy in scaled_mm which returns tuple
# for torch < 2.5 and a single value in torch >= 2.5
if type(output) is tuple and len(output) == 2:
output = output[0]
# Unpad (undo num_token_padding) # Unpad (undo num_token_padding)
output = torch.narrow(output, 0, 0, input.shape[0]) output = torch.narrow(output, 0, 0, input.shape[0])
x_scale = torch.narrow(x_scale, 0, 0, input.shape[0]) x_scale = torch.narrow(x_scale, 0, 0, input.shape[0])
...@@ -188,7 +196,7 @@ def apply_int8_linear( ...@@ -188,7 +196,7 @@ def apply_int8_linear(
# ops.scaled_int8_quant supports both dynamic and static quant. # ops.scaled_int8_quant supports both dynamic and static quant.
# * dynamic, layer.input_scale is None and x_scale computed from x. # * dynamic, layer.input_scale is None and x_scale computed from x.
# * static, layer.input_scale is scalar and x_scale is input_scale. # * static, layer.input_scale is scalar and x_scale is input_scale.
x_q, x_scale = ops.scaled_int8_quant(input, input_scale) x_q, x_scale, _ = ops.scaled_int8_quant(input, input_scale)
return ops.cutlass_scaled_mm(x_q, return ops.cutlass_scaled_mm(x_q,
weight, weight,
......
...@@ -31,15 +31,11 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler): ...@@ -31,15 +31,11 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
""" """
def __init__(self, def __init__(self,
disable_bonus_tokens: bool = True,
strict_mode: bool = False, strict_mode: bool = False,
use_flashinfer: Optional[bool] = None): use_flashinfer: Optional[bool] = None):
"""Create a rejection sampler. """Create a rejection sampler.
Args: Args:
disable_bonus_tokens: Whether or not to disable the bonus token.
Require when bonus tokens will cause corrupt KV cache for
proposal methods that require KV cache.
strict_mode: Whether or not to perform shape/device/dtype checks strict_mode: Whether or not to perform shape/device/dtype checks
during sampling. This catches correctness issues but adds during sampling. This catches correctness issues but adds
nontrivial latency. nontrivial latency.
...@@ -48,8 +44,7 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler): ...@@ -48,8 +44,7 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
None, we will use the default value from the environment variable. None, we will use the default value from the environment variable.
This parameter is only used for testing purposes. This parameter is only used for testing purposes.
""" """
super().__init__(disable_bonus_tokens=disable_bonus_tokens, super().__init__(strict_mode=strict_mode)
strict_mode=strict_mode)
if use_flashinfer is None: if use_flashinfer is None:
self.use_flashinfer = envs.VLLM_USE_FLASHINFER_SAMPLER and ( self.use_flashinfer = envs.VLLM_USE_FLASHINFER_SAMPLER and (
chain_speculative_sampling is not None) chain_speculative_sampling is not None)
...@@ -57,8 +52,6 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler): ...@@ -57,8 +52,6 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
self.use_flashinfer = use_flashinfer self.use_flashinfer = use_flashinfer
if self.use_flashinfer: if self.use_flashinfer:
assert not disable_bonus_tokens, \
"flashinfer will enable bonus token by default"
logger.info("Use flashinfer for rejection sampling.") logger.info("Use flashinfer for rejection sampling.")
else: else:
logger.info("Use pytorch for rejection sampling.") logger.info("Use pytorch for rejection sampling.")
......
...@@ -10,19 +10,15 @@ import msgspec ...@@ -10,19 +10,15 @@ import msgspec
import torch import torch
import torch.nn as nn import torch.nn as nn
from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
from vllm.triton_utils import HAS_TRITON
if HAS_TRITON:
from vllm.model_executor.layers.ops.sample import sample as sample_triton
import vllm.envs as envs import vllm.envs as envs
from vllm.model_executor.sampling_metadata import (SamplingMetadata, from vllm.model_executor.sampling_metadata import (SamplingMetadata,
SamplingTensors, SamplingTensors,
SequenceGroupToSample) SequenceGroupToSample)
from vllm.sampling_params import SamplingType from vllm.sampling_params import SamplingType
from vllm.sequence import (CompletionSequenceGroupOutput, Logprob, from vllm.sequence import (VLLM_INVALID_TOKEN_ID,
CompletionSequenceGroupOutput, Logprob,
PromptLogprobs, SampleLogprobs, SequenceOutput) PromptLogprobs, SampleLogprobs, SequenceOutput)
from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
if envs.VLLM_USE_FLASHINFER_SAMPLER and find_spec("flashinfer"): if envs.VLLM_USE_FLASHINFER_SAMPLER and find_spec("flashinfer"):
import flashinfer.sampling import flashinfer.sampling
...@@ -438,12 +434,9 @@ def _apply_top_k_top_p( ...@@ -438,12 +434,9 @@ def _apply_top_k_top_p(
logits_sort.masked_fill_(top_p_mask, -float("inf")) logits_sort.masked_fill_(top_p_mask, -float("inf"))
# Re-sort the probabilities. # Re-sort the probabilities.
src = torch.arange(logits_idx.shape[-1], logits = torch.empty_like(logits_sort).scatter_(dim=-1,
device=logits_idx.device).expand_as(logits_idx) index=logits_idx,
logits_idx_inv = torch.empty_like(logits_idx).scatter_(dim=-1, src=logits_sort)
index=logits_idx,
src=src)
logits = torch.gather(logits_sort, dim=-1, index=logits_idx_inv)
return logits return logits
...@@ -740,7 +733,7 @@ def _sample_with_torch( ...@@ -740,7 +733,7 @@ def _sample_with_torch(
) -> SampleReturnType: ) -> SampleReturnType:
'''Torch-oriented _sample() implementation. '''Torch-oriented _sample() implementation.
Single-step scheduling: Single-step scheduling:
* Perform GPU-side sampling computation * Perform GPU-side sampling computation
* Immediately Pythonize sampling result * Immediately Pythonize sampling result
...@@ -767,17 +760,17 @@ def _sample_with_torch( ...@@ -767,17 +760,17 @@ def _sample_with_torch(
# Create output tensor for sampled token ids. # Create output tensor for sampled token ids.
if include_gpu_probs_tensor: if include_gpu_probs_tensor:
sampled_token_ids_tensor = torch.empty(logprobs.shape[0], sampled_token_ids_tensor = torch.full((logprobs.shape[0], 1),
1, VLLM_INVALID_TOKEN_ID,
dtype=torch.long, dtype=torch.long,
device=logprobs.device) device=logprobs.device)
else: else:
sampled_token_ids_tensor = None sampled_token_ids_tensor = None
# Counterintuitively, having two loops here is actually faster. # Counterintuitively, having two loops here is actually faster.
# The first loop can run without waiting on GPU<->CPU sync. # The first loop can run without waiting on GPU<->CPU sync.
for sampling_type in SamplingType: for sampling_type in SamplingType:
sample_indices = categorized_sample_indices[sampling_type][:, 0] sample_indices = categorized_sample_indices[sampling_type]
num_tokens = len(sample_indices) num_tokens = len(sample_indices)
if num_tokens == 0: if num_tokens == 0:
continue continue
...@@ -863,88 +856,6 @@ def _sample_with_torch( ...@@ -863,88 +856,6 @@ def _sample_with_torch(
) )
def _sample_with_triton_kernel(
    probs: torch.Tensor,
    logprobs: torch.Tensor,
    sampling_metadata: SamplingMetadata,
    sampling_tensors: SamplingTensors,
) -> SampleResultType:
    """Sample via ``sample_triton`` and build per-sequence-group results.

    Sequence groups are bucketed by their sampling type. A first pass
    over the sampling types collects per-type metadata and the largest
    ``best_of`` among prompt groups. ``sample_triton`` is then called once
    for the whole batch. A second pass turns the sampled tokens into
    results using the greedy / random / beam-search helpers.

    Returns:
        One entry per sequence group, in ``sampling_metadata.seq_groups``
        order. Groups that produced no result get ``([], [])``.

    Raises:
        ValueError: If a sampling type with at least one sample index is
            not GREEDY, RANDOM, RANDOM_SEED or BEAM.
    """
    categorized_seq_group_ids: Dict[SamplingType,
                                    List[int]] = {t: []
                                                  for t in SamplingType}
    categorized_sample_indices = sampling_metadata.categorized_sample_indices
    for i, seq_group in enumerate(sampling_metadata.seq_groups):
        sampling_params = seq_group.sampling_params
        sampling_type = sampling_params.sampling_type
        categorized_seq_group_ids[sampling_type].append(i)

    sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {}
    sample_metadata: Dict[SamplingType,
                          Tuple[List[int], List[SequenceGroupToSample],
                                torch.Tensor, torch.Tensor]] = {}
    max_best_of_in_batch = 1

    # Counterintuitively, having two loops here is actually faster.
    # The first loop can run without waiting on GPU<->CPU sync.
    for sampling_type in SamplingType:
        sample_indices = categorized_sample_indices[sampling_type][:, 0]
        sampled_token_indices = categorized_sample_indices[sampling_type][:, 1]
        num_tokens = len(sample_indices)
        if num_tokens == 0:
            continue
        seq_group_id = categorized_seq_group_ids[sampling_type]
        seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_id]
        sample_metadata[sampling_type] = (seq_group_id, seq_groups,
                                          sample_indices,
                                          sampled_token_indices)
        if sampling_type in (SamplingType.GREEDY, SamplingType.RANDOM,
                             SamplingType.RANDOM_SEED):
            for seq_group in seq_groups:
                if seq_group.is_prompt:
                    sampling_params = seq_group.sampling_params
                    max_best_of_in_batch = max(max_best_of_in_batch,
                                               sampling_params.best_of)
        elif sampling_type == SamplingType.BEAM:
            # Only bound when BEAM has samples; the second loop reads it
            # only under that same condition.
            beam_search_logprobs = logprobs[sample_indices]
        else:
            raise ValueError(f"Unsupported sampling type: {sampling_type}")

    sampled_tokens, _, _ = sample_triton(
        probs=probs,
        seeds=sampling_tensors.sampling_seeds,
        max_best_of=max_best_of_in_batch,
        sample_indices=sampling_tensors.sample_indices,
        logprobs=logprobs,
        # don't save logprobs because we have logic for that below
        # TODO: use this instead of the CPU-based logic below
        save_logprobs=False,
    )

    # GPU<->CPU sync happens in the loop below.

    for sampling_type in SamplingType:
        if sampling_type not in sample_metadata:
            continue
        (seq_group_id, seq_groups, sample_indices,
         sampled_token_indices) = sample_metadata[sampling_type]
        if sampling_type == SamplingType.GREEDY:
            sample_results = _greedy_sample(
                seq_groups, sampled_tokens[sampled_token_indices][:, 0])
        elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
            sample_results = _random_sample(
                seq_groups, sampled_tokens[sampled_token_indices])
        elif sampling_type == SamplingType.BEAM:
            sample_results = _beam_search_sample(seq_groups,
                                                 beam_search_logprobs)
        sample_results_dict.update(zip(seq_group_id, sample_results))

    sample_results = [
        sample_results_dict.get(i, ([], []))
        for i in range(len(sampling_metadata.seq_groups))
    ]
    return sample_results
def _sample( def _sample(
probs: torch.Tensor, probs: torch.Tensor,
logprobs: torch.Tensor, logprobs: torch.Tensor,
...@@ -974,10 +885,6 @@ def _sample( ...@@ -974,10 +885,6 @@ def _sample(
modify_greedy_probs=modify_greedy_probs, modify_greedy_probs=modify_greedy_probs,
) )
# TODO: Enable once Triton kernel & associated code is faster.
# return _sample_with_triton_kernel(probs, logprobs, sampling_metadata,
# sampling_tensors)
def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor: def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
""" """
......
...@@ -11,20 +11,14 @@ class SpecDecodeBaseSampler(nn.Module): ...@@ -11,20 +11,14 @@ class SpecDecodeBaseSampler(nn.Module):
step. step.
""" """
def __init__(self, def __init__(self, strict_mode: bool = False):
disable_bonus_tokens: bool = True,
strict_mode: bool = False):
"""Base class constructor. """Base class constructor.
Args: Args:
disable_bonus_tokens: Whether or not to disable the bonus token.
Require when bonus tokens will cause corrupt KV cache for
proposal methods that require KV cache.
strict_mode: Whether or not to perform shape/device/dtype checks strict_mode: Whether or not to perform shape/device/dtype checks
during sampling. This catches correctness issues but adds during sampling. This catches correctness issues but adds
nontrivial latency. nontrivial latency.
""" """
super().__init__() super().__init__()
self._disable_bonus_tokens = disable_bonus_tokens
self._strict_mode = strict_mode self._strict_mode = strict_mode
# NOTE: A "bonus token" is accepted iff all proposal tokens are # NOTE: A "bonus token" is accepted iff all proposal tokens are
...@@ -111,13 +105,6 @@ class SpecDecodeBaseSampler(nn.Module): ...@@ -111,13 +105,6 @@ class SpecDecodeBaseSampler(nn.Module):
output_with_bonus_tokens[:, -1] = torch.where(output[:, -1] != -1, output_with_bonus_tokens[:, -1] = torch.where(output[:, -1] != -1,
bonus_token_ids, -1) bonus_token_ids, -1)
# We disable bonus tokens because it causes corrupt KV cache for
# proposal methods that require KV cache. We can fix it by "prefilling"
# the bonus token in the proposer. The following issue tracks the fix.
# https://github.com/vllm-project/vllm/issues/4212
if self._disable_bonus_tokens:
output_with_bonus_tokens[:, -1] = -1
# Fill the recovered token ids. # Fill the recovered token ids.
output.mul_(~after_false_mask).add_( output.mul_(~after_false_mask).add_(
substitute_token_ids.mul(after_false_mask)) substitute_token_ids.mul(after_false_mask))
......
...@@ -16,15 +16,11 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler): ...@@ -16,15 +16,11 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
self, self,
posterior_threshold: float, posterior_threshold: float,
posterior_alpha: float, posterior_alpha: float,
disable_bonus_tokens: bool = False,
strict_mode: bool = False, strict_mode: bool = False,
): ):
"""Create a Typical Acceptance Sampler. """Create a Typical Acceptance Sampler.
Args: Args:
disable_bonus_tokens: Whether or not to disable the bonus token.
Require when bonus tokens will cause corrupt KV cache for
proposal methods that require KV cache.
strict_mode: Whether or not to perform shape/device/dtype checks strict_mode: Whether or not to perform shape/device/dtype checks
during sampling. This catches correctness issues but adds during sampling. This catches correctness issues but adds
nontrivial latency. nontrivial latency.
...@@ -36,8 +32,7 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler): ...@@ -36,8 +32,7 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
""" """
self._posterior_threshold = posterior_threshold self._posterior_threshold = posterior_threshold
self._posterior_alpha = posterior_alpha self._posterior_alpha = posterior_alpha
super().__init__(disable_bonus_tokens=disable_bonus_tokens, super().__init__(strict_mode=strict_mode)
strict_mode=strict_mode)
def forward( def forward(
self, self,
...@@ -54,7 +49,7 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler): ...@@ -54,7 +49,7 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
one token will be emitted. one token will be emitted.
In the case where all draft tokens are accepted, the bonus token will be In the case where all draft tokens are accepted, the bonus token will be
accepted conditioned on self._disable_bonus_tokens being false. accepted.
Args: Args:
target_probs: The probability distribution over token ids given target_probs: The probability distribution over token ids given
...@@ -85,7 +80,7 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler): ...@@ -85,7 +80,7 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
target_probs = target_with_bonus_probs[:, :-1] target_probs = target_with_bonus_probs[:, :-1]
accepted = self._evaluate_accepted_tokens(target_probs, accepted = self._evaluate_accepted_tokens(target_probs,
draft_token_ids) draft_token_ids)
recovered_token_ids = self._replacement_token_ids(target_probs) recovered_token_ids = self._get_recovered_token_ids(target_probs)
output_token_ids = self._create_output(accepted, recovered_token_ids, output_token_ids = self._create_output(accepted, recovered_token_ids,
draft_token_ids, draft_token_ids,
bonus_token_ids) bonus_token_ids)
...@@ -153,16 +148,10 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler): ...@@ -153,16 +148,10 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
accepted_mask = candidates_prob > threshold accepted_mask = candidates_prob > threshold
return accepted_mask return accepted_mask
def _replacement_token_ids(self, target_probs): def _get_recovered_token_ids(self, target_probs):
""" """
Generate one replacement token ID for each sequence based on target The recovered token ids will fill the first unmatched token
probabilities. The replacement token is used as the fallback option by the target token.
if typical acceptance sampling does not accept any draft tokens for
that particular sequence.
This method computes the token IDs to be replaced by selecting the
token with the highest probability for each sequence in the first
position. The rest of the output is filled with -1.
Parameters Parameters
---------- ----------
...@@ -173,13 +162,9 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler): ...@@ -173,13 +162,9 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
Returns Returns
------- -------
torch.Tensor torch.Tensor
A tensor of shape (batch_size, k) with the replacement A tensor of shape (batch_size, k) with the recovered token
token IDs. Only the first column is set, and the rest of the ids which are selected from target probs.
columns are filled with -1.
""" """
max_indices = torch.argmax(target_probs[:, 0, :], dim=1) max_indices = torch.argmax(target_probs, dim=-1)
output = -torch.ones((target_probs.shape[0], target_probs.shape[1]),
dtype=self.token_id_dtype, return max_indices
device=target_probs.device)
output[:, 0] = max_indices
return output
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment