Commit 705f6a35 authored by zhuwenwen
Browse files

Merge tag 'v0.5.2' into v0.5.2-dtk24.04.1

parents af837396 4cf256ae
from typing import Callable, List, Optional
import torch
from torch.nn import Parameter
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
apply_marlin_linear, marlin_make_empty_g_idx, marlin_make_workspace,
marlin_permute_scales, replace_tensor, verify_marlin_supported,
verify_marlin_supports_shape)
from vllm.model_executor.utils import set_weight_attrs
__all__ = ["CompressedTensorsWNA16"]
WNA16_SUPPORTED_BITS = [4, 8]
class CompressedTensorsWNA16(CompressedTensorsScheme):
    """Compressed-tensors scheme for packed N-bit weights run via Marlin.

    Weights are registered as int32 words that each hold ``32 // num_bits``
    quantized values, next to a scale tensor in the layer's parameter dtype.
    After loading, both tensors are converted from the compressed-tensors
    layout to the Marlin layout and the forward pass goes through
    ``apply_marlin_linear``.
    """

    def __init__(self,
                 strategy: str,
                 num_bits: int,
                 group_size: Optional[int] = None):
        """Store the quantization settings and validate them.

        Args:
            strategy: Quantization strategy name; ``"channel"`` is the only
                value accepted when ``group_size`` is None.
            num_bits: Bit width of a quantized weight; determines how many
                values are packed into one int32.
            group_size: Number of input elements sharing one scale, or None
                for channelwise quantization (stored internally as -1).

        Raises:
            ValueError: If ``group_size`` is None and ``strategy`` is not
                ``"channel"``.
        """
        self.num_bits = num_bits
        self.pack_factor = 32 // self.num_bits
        self.strategy = strategy

        self.group_size: int
        if group_size is None:
            if self.strategy != "channel":
                raise ValueError(
                    "Marlin kernels require group quantization or "
                    "channelwise quantization, but found no group "
                    "size and strategy is not channelwise.")
            # -1 is the internal marker for "channelwise"; from here on
            # self.group_size is always an int, never None.
            self.group_size = -1
        else:
            self.group_size = group_size

        # Verify supported on platform.
        verify_marlin_supported(num_bits=self.num_bits,
                                group_size=self.group_size,
                                is_sym=True)

    def create_weights(self, layer: torch.nn.Module, input_size: int,
                       output_partition_sizes: List[int],
                       input_size_per_partition: int,
                       params_dtype: torch.dtype, weight_loader: Callable,
                       **kwargs):
        """Register the packed weight, scale and shape parameters on ``layer``.

        Args:
            layer: Module that receives the parameters and size attributes.
            input_size: Full (unsharded) input dimension.
            output_partition_sizes: Output widths of the logical weights held
                by this rank; their sum is the local output dimension.
            input_size_per_partition: Input dimension held by this rank.
            params_dtype: Dtype used for the scale tensor.
            weight_loader: Loader callable attached to every parameter.
            **kwargs: Accepted for interface compatibility; unused here.
        """
        output_size_per_partition = sum(output_partition_sizes)

        # If group_size is -1, we are in channelwise case.
        group_size = input_size if self.group_size == -1 else self.group_size

        verify_marlin_supports_shape(
            output_size_per_partition=output_size_per_partition,
            input_size_per_partition=input_size_per_partition,
            input_size=input_size,
            group_size=group_size)

        weight_scale_dim = None
        scales_and_zp_size = input_size // group_size

        # Only shard the scales along the input dimension when the layer is
        # split along that dimension AND real groups are used. In the
        # channelwise case group_size equals input_size, so dividing the
        # per-partition input size by it would yield 0 scales; the single
        # per-channel scale must instead be kept whole on every rank.
        # (self.group_size is never None here -- __init__ stores -1 -- so the
        # channelwise case has to be detected via -1.)
        row_parallel = input_size != input_size_per_partition
        if row_parallel and self.group_size != -1:
            weight_scale_dim = 1
            scales_and_zp_size = input_size_per_partition // group_size

        weight = Parameter(
            torch.empty(
                output_size_per_partition,
                input_size_per_partition // self.pack_factor,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )

        set_weight_attrs(
            weight, {
                "input_dim": 1,
                "output_dim": 0,
                "packed_dim": 1,
                "pack_factor": self.pack_factor,
                "weight_loader": weight_loader
            })
        layer.register_parameter("weight_packed", weight)

        weight_scale = Parameter(
            torch.empty(
                output_size_per_partition,
                scales_and_zp_size,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )

        set_weight_attrs(
            weight_scale, {
                "weight_loader": weight_loader,
                "input_dim": weight_scale_dim,
                "output_dim": 0
            })
        layer.register_parameter("weight_scale", weight_scale)

        # A 2D array defining the original shape of the weights
        # before packing
        weight_shape = Parameter(torch.empty(2, dtype=torch.int64),
                                 requires_grad=False)

        layer.register_parameter("weight_shape", weight_shape)
        set_weight_attrs(weight_shape, {
            "weight_loader": weight_loader,
            "ignore_warning": True,
        })

        # Remembered for the repacking step and for the forward pass.
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.input_size = input_size
        layer.group_size = group_size

    # Checkpoints are serialized in compressed-tensors format, which is
    # different from marlin format. Handle repacking here.
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Convert the loaded weight and scales of ``layer`` to Marlin format.

        Also attaches the Marlin workspace and the (empty) act-order index
        tensors that ``apply_weights`` later passes to the kernel.
        """
        device = layer.weight_packed.device

        # Allocate marlin workspace.
        layer.workspace = marlin_make_workspace(
            layer.output_size_per_partition, device)

        # Act-order not supported in compressed-tensors yet, so set to empty.
        layer.g_idx = marlin_make_empty_g_idx(device)
        layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)

        # Repack weights from compressed-tensors format to marlin format.
        marlin_qweight = ops.gptq_marlin_repack(
            layer.weight_packed.t().contiguous(),
            perm=layer.g_idx_sort_indices,
            size_k=layer.input_size_per_partition,
            size_n=layer.output_size_per_partition,
            num_bits=self.num_bits)
        replace_tensor(layer, "weight_packed", marlin_qweight)

        # Permute scales from compressed-tensors format to marlin format.
        marlin_scales = marlin_permute_scales(
            layer.weight_scale.squeeze().t().contiguous(),
            size_k=layer.input_size_per_partition,
            size_n=layer.output_size_per_partition,
            group_size=layer.group_size)
        replace_tensor(layer, "weight_scale", marlin_scales)

    def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
                      bias: Optional[torch.Tensor]) -> torch.Tensor:
        """Run the Marlin linear kernel on ``x`` with the layer's weights.

        Args:
            layer: Module prepared by ``create_weights`` and
                ``process_weights_after_loading``.
            x: Input activations, passed to the kernel unchanged.
            bias: Optional bias forwarded to the kernel.

        Returns:
            The tensor returned by ``apply_marlin_linear``.
        """
        return apply_marlin_linear(
            input=x,
            weight=layer.weight_packed,
            weight_scale=layer.weight_scale,
            g_idx=layer.g_idx,
            g_idx_sort_indices=layer.g_idx_sort_indices,
            workspace=layer.workspace,
            num_bits=self.num_bits,
            output_size_per_partition=layer.output_size_per_partition,
            input_size_per_partition=layer.input_size_per_partition,
            is_k_full=True,
            bias=bias)
...@@ -6,6 +6,15 @@ from pydantic import BaseModel, Field ...@@ -6,6 +6,15 @@ from pydantic import BaseModel, Field
from torch.nn import Module from torch.nn import Module
class CompressionFormat(Enum):
    """Identifiers of the supported checkpoint compression formats.

    Each member's value is the hyphenated string form of its name.
    NOTE(review): presumably these strings match the ``format`` field of a
    compressed-tensors config -- confirm against the loader.
    """
    dense = "dense"
    sparse_bitmask = "sparse-bitmask"
    float_quantized = "float-quantized"
    int_quantized = "int-quantized"
    pack_quantized = "pack-quantized"
    marlin_24 = "marlin-24"
class QuantizationType(str, Enum): class QuantizationType(str, Enum):
""" """
Enum storing quantization type options Enum storing quantization type options
......
from typing import Any, Dict, List, Optional, Tuple, Union from typing import Any, Dict, List, Optional
import torch import torch
from torch.nn import Module from torch.nn import Module
...@@ -6,10 +6,18 @@ from torch.nn.parameter import Parameter ...@@ -6,10 +6,18 @@ from torch.nn.parameter import Parameter
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
fused_moe)
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase) QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
all_close_1d, apply_fp8_linear, create_per_tensor_scale_param,
cutlass_fp8_supported, per_tensor_dequantize, requantize_with_max_scale)
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.utils import print_warning_once from vllm.utils import print_warning_once
ACTIVATION_SCHEMES = ["static", "dynamic"] ACTIVATION_SCHEMES = ["static", "dynamic"]
...@@ -17,24 +25,6 @@ ACTIVATION_SCHEMES = ["static", "dynamic"] ...@@ -17,24 +25,6 @@ ACTIVATION_SCHEMES = ["static", "dynamic"]
logger = init_logger(__name__) logger = init_logger(__name__)
def cutlass_fp8_supported() -> bool:
    """Return whether the CUTLASS FP8 kernels can be used on this system.

    The decision is based on the compute capability of the current device
    and on the CUDA version PyTorch was built with.

    Returns:
        True when the device/CUDA combination meets the requirements listed
        below; False otherwise, including when PyTorch has no CUDA version
        (``torch.version.cuda`` is None, e.g. ROCm or CPU-only builds) or no
        CUDA device is available.
    """
    # Without a CUDA toolkit version or a visible device there is nothing to
    # query; report "unsupported" instead of raising.
    if torch.version.cuda is None or not torch.cuda.is_available():
        return False

    device_major, device_minor = torch.cuda.get_device_capability()
    capability = device_major * 10 + device_minor

    # Compare (major, minor) as a tuple so that two-digit minor versions
    # order correctly.
    cuda_major, cuda_minor = torch.version.cuda.split(".")[:2]
    cuda_version = (int(cuda_major), int(cuda_minor))

    # CUTLASS FP8 kernels need at least
    # CUDA 12.0 on SM90 systems (Hopper)
    # CUDA 12.4 on SM89 systems (Lovelace)
    if capability >= 90:
        return cuda_version >= (12, 0)
    if capability >= 89:
        return cuda_version >= (12, 4)
    return False
class Fp8Config(QuantizationConfig): class Fp8Config(QuantizationConfig):
"""Config class for FP8.""" """Config class for FP8."""
...@@ -62,7 +52,7 @@ class Fp8Config(QuantizationConfig): ...@@ -62,7 +52,7 @@ class Fp8Config(QuantizationConfig):
@classmethod @classmethod
def get_min_capability(cls) -> int: def get_min_capability(cls) -> int:
return 89 return 80
@classmethod @classmethod
def get_config_filenames(cls) -> List[str]: def get_config_filenames(cls) -> List[str]:
...@@ -82,7 +72,9 @@ class Fp8Config(QuantizationConfig): ...@@ -82,7 +72,9 @@ class Fp8Config(QuantizationConfig):
if isinstance(layer, LinearBase): if isinstance(layer, LinearBase):
return Fp8LinearMethod(self) return Fp8LinearMethod(self)
if isinstance(layer, Attention): elif isinstance(layer, FusedMoE):
return Fp8MoEMethod(self)
elif isinstance(layer, Attention):
return Fp8KVCacheMethod(self) return Fp8KVCacheMethod(self)
return None return None
...@@ -112,23 +104,11 @@ class Fp8LinearMethod(LinearMethodBase): ...@@ -112,23 +104,11 @@ class Fp8LinearMethod(LinearMethodBase):
self.quant_config = quant_config self.quant_config = quant_config
self.cutlass_fp8_supported = cutlass_fp8_supported() self.cutlass_fp8_supported = cutlass_fp8_supported()
def _create_scale_param( # For GPUs that lack FP8 hardware support, we can leverage the Marlin
self, # kernel for fast weight-only FP8 quantization
scale_name: str, capability = current_platform.get_device_capability()
layer: torch.nn.Module, capability = capability[0] * 10 + capability[1]
output_partition_sizes: List[int], self.use_marlin = capability < 89
**extra_weight_attrs,
) -> None:
scale = Parameter(torch.empty(len(output_partition_sizes),
dtype=torch.float32),
requires_grad=False)
layer.register_parameter(scale_name, scale)
set_weight_attrs(
scale, {
**extra_weight_attrs,
"fp8_scales_shard_indexer":
self.scales_shard_indexer,
})
def create_weights( def create_weights(
self, self,
...@@ -143,9 +123,12 @@ class Fp8LinearMethod(LinearMethodBase): ...@@ -143,9 +123,12 @@ class Fp8LinearMethod(LinearMethodBase):
del input_size, output_size del input_size, output_size
output_size_per_partition = sum(output_partition_sizes) output_size_per_partition = sum(output_partition_sizes)
layer.process_after_load = True
layer.logical_widths = output_partition_sizes layer.logical_widths = output_partition_sizes
layer.input_size_per_partition = input_size_per_partition
layer.output_size_per_partition = output_size_per_partition
layer.orig_dtype = params_dtype
# WEIGHT # WEIGHT
weight_dtype = (torch.float8_e4m3fn weight_dtype = (torch.float8_e4m3fn
if self.quant_config.is_checkpoint_fp8_serialized else if self.quant_config.is_checkpoint_fp8_serialized else
...@@ -165,129 +148,255 @@ class Fp8LinearMethod(LinearMethodBase): ...@@ -165,129 +148,255 @@ class Fp8LinearMethod(LinearMethodBase):
# Otherwise, wait until process_weights_after_loading. # Otherwise, wait until process_weights_after_loading.
if self.quant_config.is_checkpoint_fp8_serialized: if self.quant_config.is_checkpoint_fp8_serialized:
# WEIGHT SCALE # WEIGHT SCALE
self._create_scale_param( scale = create_per_tensor_scale_param(output_partition_sizes,
scale_name="weight_scale", **extra_weight_attrs)
layer=layer, layer.register_parameter("weight_scale", scale)
output_partition_sizes=output_partition_sizes,
**extra_weight_attrs)
# INPUT ACTIVATION SCALE # INPUT ACTIVATION SCALE
if self.quant_config.activation_scheme == "static": if self.quant_config.activation_scheme == "static":
self._create_scale_param( scale = create_per_tensor_scale_param(output_partition_sizes,
scale_name="input_scale", **extra_weight_attrs)
layer=layer, layer.register_parameter("input_scale", scale)
output_partition_sizes=output_partition_sizes,
**extra_weight_attrs)
def scales_shard_indexer(
        self, param: torch.Tensor, loaded_weight: torch.Tensor,
        shard_id: Union[str, int]) -> Tuple[torch.Tensor, torch.Tensor]:
    """Select the slot of ``param`` that belongs to ``shard_id``.

    Args:
        param: Scale parameter indexed along its first dimension.
        loaded_weight: Value read from the checkpoint; returned unchanged.
        shard_id: Either an integer index, or one of ``"q"``, ``"k"``,
            ``"v"`` which map to indices 0, 1 and 2.

    Returns:
        ``(param[index], loaded_weight)``.

    Raises:
        ValueError: If ``shard_id`` is an unknown string or is neither an
            int nor a str.
    """
    qkv_idxs = {"q": 0, "k": 1, "v": 2}

    if isinstance(shard_id, str):
        if shard_id not in qkv_idxs:
            raise ValueError(f"Unknown shard_id: {shard_id}")
        shard_id = qkv_idxs[shard_id]
    elif not isinstance(shard_id, int):
        # The error must actually be raised; otherwise the invalid id would
        # be used to index ``param`` below.
        raise ValueError(
            f"Shard id must be int or str but got {type(shard_id)}")

    return param[shard_id], loaded_weight
def process_weights_after_loading(self, layer: Module) -> None: def process_weights_after_loading(self, layer: Module) -> None:
if (not hasattr(layer, "process_after_load") # If checkpoint not serialized fp8, quantize the weights.
or not layer.process_after_load):
return
# If checkpoint is fp/bf16 (not serialized fp8), quantize the weights.
if not self.quant_config.is_checkpoint_fp8_serialized: if not self.quant_config.is_checkpoint_fp8_serialized:
qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, qweight, weight_scale = ops.scaled_fp8_quant(layer.weight,
scale=None) scale=None)
# Update the layer with the new values.
layer.weight = Parameter(qweight.t(), requires_grad=False) layer.weight = Parameter(qweight.t(), requires_grad=False)
layer.weight_scale = Parameter(weight_scale, requires_grad=False) layer.weight_scale = Parameter(weight_scale, requires_grad=False)
layer.logical_widths = None
layer.input_scale = None layer.input_scale = None
return
# If checkpoint is fp8, requantize the separately quantized logical # If checkpoint is fp8, requantize the separately quantized logical
# weights into a single fp8 weight with a single weight scale. # weights into a single fp8 weight with a single weight scale.
else: else:
# WEIGHT_SCALE / WEIGHT # Dequant -> Quant with max scale.
# Loop over logical weights, requantizing with single scale. max_w_scale, weight = requantize_with_max_scale(
max_w_scale = layer.weight_scale.max() weight=layer.weight,
start = 0 weight_scale=layer.weight_scale,
for idx, logical_width in enumerate(layer.logical_widths): logical_widths=layer.logical_widths,
end = start + logical_width )
weight_dq = per_tensor_dequantize(layer.weight[start:end, :],
layer.weight_scale[idx])
layer.weight[start:end, :] = per_tensor_quantize(
weight_dq, layer.weight_scale.max())
start = end
layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
# WEIGHT # Update layer with new values.
# Transpose weight for passing to torch._scaled_mm
weight = layer.weight
layer.weight = Parameter(weight.t(), requires_grad=False) layer.weight = Parameter(weight.t(), requires_grad=False)
layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
# INPUT ACTIVATION SCALE if self.quant_config.activation_scheme == "static":
# Dynamic: set to None (required input to ops.scaled_fp8_quant).
# Static: set to max of the input_scales (since they are equal).
if self.quant_config.activation_scheme == "dynamic":
layer.input_scale = None
elif self.quant_config.activation_scheme == "static":
if not all_close_1d(layer.input_scale):
raise ValueError(
"All the input_scales for the logical weights of a "
f"layer must be equal. But got {layer.input_scale}")
layer.input_scale = Parameter(layer.input_scale.max(), layer.input_scale = Parameter(layer.input_scale.max(),
requires_grad=False) requires_grad=False)
else: else:
raise ValueError( layer.input_scale = None
f"Unknown scheme {self.quant_config.activation_scheme}")
if self.use_marlin:
prepare_fp8_layer_for_marlin(layer)
# Activations not quantized for marlin.
del layer.input_scale
def apply(self, def apply(self,
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor: bias: Optional[torch.Tensor] = None) -> torch.Tensor:
# ops.scaled_fp8_quant supports both dynamic and static quant. if self.use_marlin:
# If dynamic, layer.input_scale is None and x_scale computed from x. return apply_fp8_marlin_linear(
# If static, layer.input_scale is scalar and x_scale is input_scale. input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
workspace=layer.workspace,
size_n=layer.output_size_per_partition,
size_k=layer.input_size_per_partition,
bias=bias)
return apply_fp8_linear(
input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
input_scale=layer.input_scale,
bias=bias,
cutlass_fp8_supported=self.cutlass_fp8_supported)
class Fp8MoEMethod(FusedMoEMethodBase):
"""MoE method for FP8.
Supports loading FP8 checkpoints with static weight scale and
dynamic/static activation scale.
Also supports loading quantized FP16/BF16 model checkpoints with dynamic
activation scaling. The weight scaling factor will be initialized after
the model weights are loaded.
Args:
quant_config: The quantization config.
"""
if bias is None and self.cutlass_fp8_supported: def __init__(self, quant_config: Fp8Config):
qinput, x_scale = ops.scaled_fp8_quant(x, layer.input_scale) self.quant_config = quant_config
# Fused GEMM_DQ def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
output = ops.cutlass_scaled_mm_dq( intermediate_size: int, params_dtype: torch.dtype,
qinput, **extra_weight_attrs):
layer.weight,
out_dtype=x.dtype,
scale_a=x_scale,
scale_b=layer.weight_scale,
)
if self.quant_config.is_checkpoint_fp8_serialized:
params_dtype = torch.float8_e4m3fn
# WEIGHTS
w13_weight = torch.nn.Parameter(torch.empty(num_experts,
2 * intermediate_size,
hidden_size,
dtype=params_dtype),
requires_grad=False)
layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)
w2_weight = torch.nn.Parameter(torch.empty(num_experts,
hidden_size,
intermediate_size,
dtype=params_dtype),
requires_grad=False)
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
# WEIGHT_SCALES
# Allocate 2 scales for w1 and w3 respectively.
# They will be combined to a single scale after weight loading.
w13_scale = torch.nn.Parameter(torch.ones(num_experts,
2,
dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w13_scale", w13_scale)
w2_scale = torch.nn.Parameter(torch.ones(num_experts,
dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w2_scale", w2_scale)
# If loading fp8 checkpoint, pass the weight loaders.
# If loading an fp16 checkpoint, do not (we will quantize in
# process_weights_after_loading()
if self.quant_config.is_checkpoint_fp8_serialized:
set_weight_attrs(w13_scale, extra_weight_attrs)
set_weight_attrs(w2_scale, extra_weight_attrs)
# INPUT_SCALES
if self.quant_config.activation_scheme == "static":
if not self.quant_config.is_checkpoint_fp8_serialized:
raise ValueError(
"Found static activation scheme for checkpoint that "
"was not serialized fp8.")
a13_scale = torch.nn.Parameter(torch.ones(num_experts,
dtype=torch.float32),
requires_grad=False)
layer.register_parameter("a13_scale", a13_scale)
set_weight_attrs(a13_scale, extra_weight_attrs)
a2_scale = torch.nn.Parameter(torch.ones(num_experts,
dtype=torch.float32),
requires_grad=False)
layer.register_parameter("a2_scale", a2_scale)
set_weight_attrs(a2_scale, extra_weight_attrs)
else: else:
qinput, x_scale = ops.scaled_fp8_quant(x, layer.a13_scale = None
layer.input_scale, layer.a2_scale = None
batch_dim_padding=17)
def process_weights_after_loading(self, layer: Module) -> None:
# Fused GEMM_DQ -- note we padded the input above because
# torch._scaled_mm is more performant for matrices with # If checkpoint is fp16, quantize in place.
# batch dimension > 16. Note that this could change if not self.quant_config.is_checkpoint_fp8_serialized:
# in the future. w13_weight = torch.empty_like(layer.w13_weight.data,
output, _ = torch._scaled_mm( dtype=torch.float8_e4m3fn)
qinput, w2_weight = torch.empty_like(layer.w2_weight.data,
layer.weight, dtype=torch.float8_e4m3fn)
out_dtype=x.dtype,
scale_a=x_scale, # Re-initialize w13_scale because we directly quantize
scale_b=layer.weight_scale, # merged w13 weights and generate a single scaling factor.
bias=bias, layer.w13_scale = torch.nn.Parameter(torch.ones(
) layer.num_experts,
dtype=torch.float32,
device=w13_weight.device),
requires_grad=False)
for expert in range(layer.num_experts):
w13_weight[expert, :, :], layer.w13_scale[
expert] = ops.scaled_fp8_quant(
layer.w13_weight.data[expert, :, :])
w2_weight[expert, :, :], layer.w2_scale[
expert] = ops.scaled_fp8_quant(
layer.w2_weight.data[expert, :, :])
layer.w13_weight = torch.nn.Parameter(w13_weight,
requires_grad=False)
layer.w2_weight = torch.nn.Parameter(w2_weight,
requires_grad=False)
return
# If checkpoint is fp8, we need to handle that the
# MoE kernels require single activation scale and single weight
# scale for w13 per expert.
else:
# Fp8 moe kernels require a single activation scale.
# We take the max of all the scales in case they differ.
if self.quant_config.activation_scheme == "static":
if layer.a13_scale is None or layer.a2_scale is None:
raise ValueError(
"QuantConfig has static quantization, but found "
"activation scales are None.")
if (not all_close_1d(layer.a13_scale)
or not all_close_1d(layer.a2_scale)):
print_warning_once(
"Found input_scales that are not equal for "
"fp8 MoE layer. Using the maximum across experts "
"for each layer. ")
layer.a13_scale = torch.nn.Parameter(layer.a13_scale.max(),
requires_grad=False)
layer.a2_scale = torch.nn.Parameter(layer.a2_scale.max(),
requires_grad=False)
# Fp8 moe kernel needs single weight scale for w13 per expert.
# We take the max then dequant and requant each expert.
assert layer.w13_scale is not None
shard_size = layer.intermediate_size_per_partition
max_w13_scales = layer.w13_scale.max(dim=1).values
for expert_id in range(layer.num_experts):
start = 0
for shard_id in range(2):
dq_weight = per_tensor_dequantize(
layer.w13_weight[expert_id][start:start +
shard_size, :],
layer.w13_scale[expert_id][shard_id])
layer.w13_weight[expert_id][
start:start + shard_size, :], _ = ops.scaled_fp8_quant(
dq_weight, max_w13_scales[expert_id])
start += shard_size
layer.w13_scale = torch.nn.Parameter(max_w13_scales,
requires_grad=False)
return
return torch.narrow(output, 0, 0, x.shape[0]) def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool = True,
use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None,
topk_group: Optional[int] = None) -> torch.Tensor:
return fused_moe(x,
layer.w13_weight,
layer.w2_weight,
router_logits,
top_k,
renormalize=renormalize,
inplace=True,
use_fp8=True,
w1_scale=layer.w13_scale,
w2_scale=layer.w2_scale,
a1_scale=layer.a13_scale,
a2_scale=layer.a2_scale,
use_grouped_topk=use_grouped_topk,
num_expert_group=num_expert_group,
topk_group=topk_group)
class Fp8KVCacheMethod(QuantizeMethodBase): class Fp8KVCacheMethod(QuantizeMethodBase):
...@@ -326,23 +435,3 @@ class Fp8KVCacheMethod(QuantizeMethodBase): ...@@ -326,23 +435,3 @@ class Fp8KVCacheMethod(QuantizeMethodBase):
"cause accuracy issues. Please make sure kv-cache scaling " "cause accuracy issues. Please make sure kv-cache scaling "
"factor is available in the fp8 checkpoint.") "factor is available in the fp8 checkpoint.")
del layer.kv_scale del layer.kv_scale
def all_close_1d(x: torch.Tensor) -> bool:
    """Tell whether every entry of the 1-D tensor ``x`` is close to ``x[0]``.

    Closeness is decided by ``torch.allclose`` with its default tolerances.
    An empty tensor yields True because there is nothing to compare.
    """
    assert x.dim() == 1
    for position in range(x.size(0)):
        if not torch.allclose(x[0], x[position]):
            return False
    return True
def per_tensor_quantize(tensor: torch.Tensor,
inv_scale: Union[float, torch.Tensor]) -> torch.Tensor:
finfo = torch.finfo(torch.float8_e4m3fn)
qweight = (tensor / inv_scale).clamp(min=finfo.min, max=finfo.max)
return qweight.to(torch.float8_e4m3fn)
def per_tensor_dequantize(
        tensor: torch.Tensor, inv_scale: Union[float,
                                               torch.Tensor]) -> torch.Tensor:
    """Dequantize ``tensor`` by casting to float16 and scaling.

    Returns the float16 view of ``tensor`` multiplied by ``inv_scale``.
    """
    as_half = tensor.to(torch.float16)
    return torch.mul(as_half, inv_scale)
...@@ -10,6 +10,7 @@ from vllm import _custom_ops as ops ...@@ -10,6 +10,7 @@ from vllm import _custom_ops as ops
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
...@@ -24,10 +25,12 @@ class GPTQConfig(QuantizationConfig): ...@@ -24,10 +25,12 @@ class GPTQConfig(QuantizationConfig):
weight_bits: int, weight_bits: int,
group_size: int, group_size: int,
desc_act: bool, desc_act: bool,
lm_head_quantized: bool,
) -> None: ) -> None:
self.weight_bits = weight_bits self.weight_bits = weight_bits
self.group_size = group_size self.group_size = group_size
self.desc_act = desc_act self.desc_act = desc_act
self.lm_head_quantized = lm_head_quantized
self.pack_factor = Fraction(32, self.weight_bits) self.pack_factor = Fraction(32, self.weight_bits)
if self.weight_bits not in [2, 3, 4, 8]: if self.weight_bits not in [2, 3, 4, 8]:
raise ValueError( raise ValueError(
...@@ -37,7 +40,8 @@ class GPTQConfig(QuantizationConfig): ...@@ -37,7 +40,8 @@ class GPTQConfig(QuantizationConfig):
def __repr__(self) -> str: def __repr__(self) -> str:
return (f"GPTQConfig(weight_bits={self.weight_bits}, " return (f"GPTQConfig(weight_bits={self.weight_bits}, "
f"group_size={self.group_size}, " f"group_size={self.group_size}, "
f"desc_act={self.desc_act})") f"desc_act={self.desc_act}),"
f"lm_head_quantized={self.lm_head_quantized}")
@classmethod @classmethod
def get_name(cls) -> str: def get_name(cls) -> str:
...@@ -61,11 +65,14 @@ class GPTQConfig(QuantizationConfig): ...@@ -61,11 +65,14 @@ class GPTQConfig(QuantizationConfig):
weight_bits = cls.get_from_keys(config, ["bits"]) weight_bits = cls.get_from_keys(config, ["bits"])
group_size = cls.get_from_keys(config, ["group_size"]) group_size = cls.get_from_keys(config, ["group_size"])
desc_act = cls.get_from_keys(config, ["desc_act"]) desc_act = cls.get_from_keys(config, ["desc_act"])
return cls(weight_bits, group_size, desc_act) lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
default=False)
return cls(weight_bits, group_size, desc_act, lm_head_quantized)
def get_quant_method( def get_quant_method(
self, layer: torch.nn.Module) -> Optional["GPTQLinearMethod"]: self, layer: torch.nn.Module) -> Optional["GPTQLinearMethod"]:
if isinstance(layer, LinearBase): if (isinstance(layer, LinearBase) or
(isinstance(layer, ParallelLMHead) and self.lm_head_quantized)):
return GPTQLinearMethod(self) return GPTQLinearMethod(self)
return None return None
......
import enum
from enum import Enum
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
import torch import torch
...@@ -11,90 +9,43 @@ from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, ...@@ -11,90 +9,43 @@ from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
set_weight_attrs) set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
apply_marlin_linear, check_marlin_supported, marlin_is_k_full,
marlin_make_empty_g_idx, marlin_make_workspace, marlin_permute_scales,
marlin_repeat_scales_on_all_ranks, marlin_sort_g_idx, replace_tensor,
verify_marlin_supported, verify_marlin_supports_shape)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
logger = init_logger(__name__) logger = init_logger(__name__)
# Kernel parameters that GPTQMarlinConfig copies onto itself as tile_size,
# min_thread_n, min_thread_k and max_parallel.
GPTQ_MARLIN_TILE = 16
GPTQ_MARLIN_MIN_THREAD_N = 64
GPTQ_MARLIN_MIN_THREAD_K = 128
GPTQ_MARLIN_MAX_PARALLEL = 16
# Quantization settings accepted by the config validation; any other value
# is rejected with a ValueError there.
GPTQ_MARLIN_SUPPORTED_NUM_BITS = [4, 8]
# A group size of -1 is the "no grouping" marker (one group per output
# channel).
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
GPTQ_MARLIN_SUPPORTED_SYM = [True]
# Permutations for Marlin scale shuffling
def get_scale_perms(num_bits):
    """Build the two index permutations used to shuffle Marlin scales.

    Returns a pair ``(scale_perm, scale_perm_single)``: a 64-entry list that
    walks an 8x8 index grid in transposed order, and a 32-entry list built
    from four shifted copies of a fixed 8-offset pattern. ``num_bits`` is
    accepted but does not influence the result.
    """
    scale_perm = [row + 8 * col for row in range(8) for col in range(8)]

    offsets = (0, 1, 8, 9, 16, 17, 24, 25)
    scale_perm_single = [
        2 * base + offset for base in range(4) for offset in offsets
    ]
    return scale_perm, scale_perm_single
def get_pack_factor(num_bits):
    """Return how many ``num_bits``-wide values fit into one 32-bit word.

    Fails an assertion when ``num_bits`` is not listed in
    ``GPTQ_MARLIN_SUPPORTED_NUM_BITS``.
    """
    is_supported = num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS
    assert is_supported, f"Unsupported num_bits = {num_bits}"
    values_per_word = 32 // num_bits
    return values_per_word
def marlin_permute_scales(s, size_k, size_n, group_size, num_bits):
    """Reorder the scale tensor ``s`` with the permutations from
    ``get_scale_perms``.

    The grouped permutation is applied when ``group_size`` is a real group
    size smaller than ``size_k``; otherwise (``group_size == -1`` or a group
    covering all of ``size_k``) the single-group permutation is applied. The
    result is reshaped to rows of ``size_n`` entries and made contiguous.
    """
    grouped_perm, single_perm = get_scale_perms(num_bits)

    uses_groups = group_size != -1 and group_size < size_k
    chosen_perm = grouped_perm if uses_groups else single_perm

    shuffled = s.reshape((-1, len(chosen_perm)))[:, chosen_perm]
    return shuffled.reshape((-1, size_n)).contiguous()
class GPTQMarlinConfig(QuantizationConfig): class GPTQMarlinConfig(QuantizationConfig):
"""Config class for GPTQ Marlin""" """Config class for GPTQ Marlin"""
def __init__(self, weight_bits: int, group_size: int, desc_act: bool, def __init__(self, weight_bits: int, group_size: int, desc_act: bool,
is_sym: bool) -> None: is_sym: bool, lm_head_quantized: bool) -> None:
if desc_act and group_size == -1: if desc_act and group_size == -1:
# In this case, act_order == True is the same as act_order == False # In this case, act_order == True is the same as act_order == False
# (since we have only one group per output channel) # (since we have only one group per output channel)
desc_act = False desc_act = False
self.weight_bits = weight_bits self.weight_bits = weight_bits
self.pack_factor = 32 // self.weight_bits # packed into int32
self.group_size = group_size self.group_size = group_size
self.desc_act = desc_act self.desc_act = desc_act
self.is_sym = is_sym self.is_sym = is_sym
self.lm_head_quantized = lm_head_quantized
# Verify # Verify supported on platform.
if self.weight_bits not in GPTQ_MARLIN_SUPPORTED_NUM_BITS: verify_marlin_supported(num_bits=self.weight_bits,
raise ValueError( group_size=self.group_size,
f"Marlin does not support weight_bits = {self.weight_bits}. " is_sym=self.is_sym)
f"Only weight_bits = {GPTQ_MARLIN_SUPPORTED_NUM_BITS} "
"are supported.")
if self.group_size not in GPTQ_MARLIN_SUPPORTED_GROUP_SIZES:
raise ValueError(
f"Marlin does not support group_size = {self.group_size}. "
f"Only group_sizes = {GPTQ_MARLIN_SUPPORTED_GROUP_SIZES} "
"are supported.")
if self.is_sym not in GPTQ_MARLIN_SUPPORTED_SYM:
raise ValueError(
f"Marlin does not support is_sym = {self.is_sym}. "
f"Only sym = {GPTQ_MARLIN_SUPPORTED_SYM} are supported.")
# Init
self.pack_factor = get_pack_factor(weight_bits)
self.tile_size = GPTQ_MARLIN_TILE
self.min_thread_n = GPTQ_MARLIN_MIN_THREAD_N
self.min_thread_k = GPTQ_MARLIN_MIN_THREAD_K
self.max_parallel = GPTQ_MARLIN_MAX_PARALLEL
def __repr__(self) -> str: def __repr__(self) -> str:
return (f"GPTQMarlinConfig(weight_bits={self.weight_bits}, " return (f"GPTQMarlinConfig(weight_bits={self.weight_bits}, "
f"group_size={self.group_size}, " f"group_size={self.group_size}, "
f"desc_act={self.desc_act})") f"desc_act={self.desc_act}, "
f"lm_head_quantized={self.lm_head_quantized})")
@classmethod @classmethod
def get_name(cls) -> str: def get_name(cls) -> str:
...@@ -118,7 +69,10 @@ class GPTQMarlinConfig(QuantizationConfig): ...@@ -118,7 +69,10 @@ class GPTQMarlinConfig(QuantizationConfig):
group_size = cls.get_from_keys(config, ["group_size"]) group_size = cls.get_from_keys(config, ["group_size"])
desc_act = cls.get_from_keys(config, ["desc_act"]) desc_act = cls.get_from_keys(config, ["desc_act"])
is_sym = cls.get_from_keys(config, ["sym"]) is_sym = cls.get_from_keys(config, ["sym"])
return cls(weight_bits, group_size, desc_act, is_sym) lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
default=False)
return cls(weight_bits, group_size, desc_act, is_sym,
lm_head_quantized)
@classmethod @classmethod
def override_quantization_method(cls, hf_quant_cfg, def override_quantization_method(cls, hf_quant_cfg,
...@@ -143,7 +97,8 @@ class GPTQMarlinConfig(QuantizationConfig): ...@@ -143,7 +97,8 @@ class GPTQMarlinConfig(QuantizationConfig):
def get_quant_method( def get_quant_method(
self, self,
layer: torch.nn.Module) -> Optional["GPTQMarlinLinearMethod"]: layer: torch.nn.Module) -> Optional["GPTQMarlinLinearMethod"]:
if isinstance(layer, LinearBase): if (isinstance(layer, LinearBase) or
(isinstance(layer, ParallelLMHead) and self.lm_head_quantized)):
return GPTQMarlinLinearMethod(self) return GPTQMarlinLinearMethod(self)
return None return None
...@@ -163,21 +118,10 @@ class GPTQMarlinConfig(QuantizationConfig): ...@@ -163,21 +118,10 @@ class GPTQMarlinConfig(QuantizationConfig):
or desc_act is None): or desc_act is None):
return False return False
# If the capability of the device is too low, cannot convert. return check_marlin_supported(num_bits=num_bits,
major, minor = torch.cuda.get_device_capability() group_size=group_size,
device_capability = major * 10 + minor is_sym=sym,
if device_capability < cls.get_min_capability(): min_capability=cls.get_min_capability())
return False
# Otherwise, can convert if model satisfies marlin constraints.
return (num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS
and group_size in GPTQ_MARLIN_SUPPORTED_GROUP_SIZES
and sym in GPTQ_MARLIN_SUPPORTED_SYM)
class GPTQMarlinState(Enum):
    """Marker stored on a layer as ``layer.marlin_state``."""

    # Weights still need to be repacked into the Marlin layout.
    REPACK = enum.auto()
    # Repacking has already been performed.
    READY = enum.auto()
class GPTQMarlinLinearMethod(LinearMethodBase): class GPTQMarlinLinearMethod(LinearMethodBase):
...@@ -201,6 +145,8 @@ class GPTQMarlinLinearMethod(LinearMethodBase): ...@@ -201,6 +145,8 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
**extra_weight_attrs, **extra_weight_attrs,
) -> None: ) -> None:
del output_size del output_size
output_size_per_partition = sum(output_partition_sizes)
is_row_parallel = input_size != input_size_per_partition
# Normalize group_size # Normalize group_size
if self.quant_config.group_size != -1: if self.quant_config.group_size != -1:
...@@ -208,58 +154,25 @@ class GPTQMarlinLinearMethod(LinearMethodBase): ...@@ -208,58 +154,25 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
else: else:
group_size = input_size group_size = input_size
# Validate dtype verify_marlin_supports_shape(
if params_dtype not in [torch.float16, torch.bfloat16]: output_size_per_partition=output_size_per_partition,
raise ValueError(f"The params dtype must be float16 " input_size_per_partition=input_size_per_partition,
f"or bfloat16, but got {params_dtype}") input_size=input_size,
group_size=group_size)
# Validate output_size_per_partition
output_size_per_partition = sum(output_partition_sizes) # Determine sharding
if output_size_per_partition % self.quant_config.min_thread_n != 0: if marlin_repeat_scales_on_all_ranks(self.quant_config.desc_act,
raise ValueError( self.quant_config.group_size,
f"Weight output_size_per_partition = " is_row_parallel):
f"{output_size_per_partition} is not divisible by " # By setting scale_dim == None, weight_loader will
f" min_thread_n = {self.quant_config.min_thread_n}.") # repeat the scales on each GPU in TP>1 case.
scales_and_zp_input_dim = None
# Validate input_size_per_partition scales_and_zp_size = input_size // group_size
if input_size_per_partition % self.quant_config.min_thread_k != 0:
raise ValueError(
f"Weight input_size_per_partition = "
f"{input_size_per_partition} is not divisible "
f"by min_thread_k = {self.quant_config.min_thread_k}.")
if (group_size < input_size
and input_size_per_partition % group_size != 0):
raise ValueError(
f"Weight input_size_per_partition = {input_size_per_partition}"
f" is not divisible by group_size = {group_size}.")
# Detect sharding of scales/zp
# By default, no sharding over "input dim"
scales_and_zp_size = input_size // group_size
scales_and_zp_input_dim = None
if self.quant_config.desc_act:
# Act-order case
assert self.quant_config.group_size != -1
is_k_full = input_size_per_partition == input_size
else: else:
# No act-order case # By setting scale_dim == 0, weight_loader will
# shard the scales in TP>1 case.
# K is always full due to full alignment with scales_and_zp_input_dim = 0
# group-size and shard of scales/zp scales_and_zp_size = input_size_per_partition // group_size
is_k_full = True
# If this is a row-parallel case, then shard scales/zp
if (input_size != input_size_per_partition
and self.quant_config.group_size != -1):
scales_and_zp_size = input_size_per_partition // group_size
scales_and_zp_input_dim = 0
# Init buffers
# Quantized weights # Quantized weights
qweight = Parameter( qweight = Parameter(
...@@ -298,11 +211,6 @@ class GPTQMarlinLinearMethod(LinearMethodBase): ...@@ -298,11 +211,6 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
}, },
) )
g_idx_sort_indices = torch.empty(
g_idx.shape,
dtype=torch.int32,
)
# Scales # Scales
scales = Parameter( scales = Parameter(
torch.empty( torch.empty(
...@@ -342,25 +250,52 @@ class GPTQMarlinLinearMethod(LinearMethodBase): ...@@ -342,25 +250,52 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
}, },
) )
# Allocate marlin workspace
max_workspace_size = (
output_size_per_partition //
self.quant_config.min_thread_n) * self.quant_config.max_parallel
workspace = torch.zeros(max_workspace_size,
dtype=torch.int,
requires_grad=False)
layer.register_parameter("qweight", qweight) layer.register_parameter("qweight", qweight)
layer.register_parameter("g_idx", g_idx) layer.register_parameter("g_idx", g_idx)
layer.register_parameter("scales", scales) layer.register_parameter("scales", scales)
layer.register_parameter("qzeros", qzeros) layer.register_parameter("qzeros", qzeros)
layer.g_idx_sort_indices = g_idx_sort_indices
layer.workspace = workspace
layer.input_size_per_partition = input_size_per_partition layer.input_size_per_partition = input_size_per_partition
layer.output_size_per_partition = output_size_per_partition layer.output_size_per_partition = output_size_per_partition
layer.input_size = input_size layer.input_size = input_size
layer.is_k_full = is_k_full layer.is_k_full = marlin_is_k_full(self.quant_config.desc_act,
layer.marlin_state = GPTQMarlinState.REPACK is_row_parallel)
# Checkpoints are serialized in AutoGPTQ format, which is different from the
# marlin format. This function is called after the weights are loaded.
# Here, we handle the repacking, including the activation reordering case.
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
    """Turn the loaded tensors on ``layer`` into the form the Marlin kernel uses.

    Steps, in order: allocate the workspace, prepare ``g_idx`` and its sort
    indices, repack ``qweight``, and permute ``scales``. The order matters
    because the repack reads ``layer.g_idx_sort_indices``.

    Args:
        layer: Module carrying ``qweight``, ``g_idx``, ``scales`` and the
            size attributes set when the weights were created.
    """
    cfg = self.quant_config
    device = layer.qweight.device
    size_k = layer.input_size_per_partition
    size_n = layer.output_size_per_partition

    # Scratch buffer for the kernel, created on the weights' device.
    layer.workspace = marlin_make_workspace(size_n, device)

    if not cfg.desc_act:
        # No activation reordering: both tensors become empty placeholders.
        layer.g_idx = marlin_make_empty_g_idx(device)
        layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)
    else:
        # Activation reordering: keep the sorted group ids and the
        # permutation that produced them.
        sorted_g_idx, sort_indices = marlin_sort_g_idx(layer.g_idx)
        layer.g_idx_sort_indices = sort_indices
        replace_tensor(layer, "g_idx", sorted_g_idx)

    # Repack weights from autogptq format to marlin format.
    repacked_qweight = ops.gptq_marlin_repack(
        layer.qweight,
        perm=layer.g_idx_sort_indices,
        size_k=size_k,
        size_n=size_n,
        num_bits=cfg.weight_bits)
    replace_tensor(layer, "qweight", repacked_qweight)

    # Permute scales from autogptq format to marlin format. With
    # activation reordering the full input size is passed as size_k,
    # otherwise the per-partition size.
    scales_size_k = layer.input_size if cfg.desc_act else size_k
    permuted_scales = marlin_permute_scales(layer.scales,
                                            size_k=scales_size_k,
                                            size_n=size_n,
                                            group_size=cfg.group_size)
    replace_tensor(layer, "scales", permuted_scales)
def apply( def apply(
self, self,
...@@ -368,90 +303,15 @@ class GPTQMarlinLinearMethod(LinearMethodBase): ...@@ -368,90 +303,15 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
x: torch.Tensor, x: torch.Tensor,
bias: Optional[torch.Tensor] = None, bias: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
reshaped_x = x.reshape(-1, x.shape[-1]) return apply_marlin_linear(
input=x,
size_m = reshaped_x.shape[0] weight=layer.qweight,
part_size_n = layer.output_size_per_partition weight_scale=layer.scales,
part_size_k = layer.input_size_per_partition g_idx=layer.g_idx,
full_size_k = layer.input_size g_idx_sort_indices=layer.g_idx_sort_indices,
workspace=layer.workspace,
out_shape = x.shape[:-1] + (part_size_n, ) num_bits=self.quant_config.weight_bits,
output_size_per_partition=layer.output_size_per_partition,
if layer.marlin_state == GPTQMarlinState.REPACK: input_size_per_partition=layer.input_size_per_partition,
layer.marlin_state = GPTQMarlinState.READY is_k_full=layer.is_k_full,
bias=bias)
# Newly generated tensors need to replace existing tensors that are
# already registered as parameters by vLLM (and won't be freed)
def replace_tensor(name, new_t):
# It is important to use resize_() here since it ensures
# the same buffer is reused
getattr(layer, name).resize_(new_t.shape)
getattr(layer, name).copy_(new_t)
del new_t
cur_device = layer.qweight.device
# Process act_order
if self.quant_config.desc_act:
# Get sorting based on g_idx
g_idx_sort_indices = torch.argsort(layer.g_idx).to(torch.int)
sorted_g_idx = layer.g_idx[g_idx_sort_indices]
replace_tensor("g_idx", sorted_g_idx)
replace_tensor("g_idx_sort_indices", g_idx_sort_indices)
else:
# Reset g_idx related tensors
layer.g_idx = Parameter(
torch.empty(0, dtype=torch.int, device=cur_device),
requires_grad=False,
)
layer.g_idx_sort_indices = Parameter(
torch.empty(0, dtype=torch.int, device=cur_device),
requires_grad=False,
)
# Repack weights
marlin_qweight = ops.gptq_marlin_repack(
layer.qweight,
layer.g_idx_sort_indices,
part_size_k,
part_size_n,
self.quant_config.weight_bits,
)
replace_tensor("qweight", marlin_qweight)
# Permute scales
scales_size_k = part_size_k
scales_size_n = part_size_n
if self.quant_config.desc_act:
scales_size_k = full_size_k
marlin_scales = marlin_permute_scales(
layer.scales,
scales_size_k,
scales_size_n,
self.quant_config.group_size,
self.quant_config.weight_bits,
)
replace_tensor("scales", marlin_scales)
output = ops.gptq_marlin_gemm(
reshaped_x,
layer.qweight,
layer.scales,
layer.g_idx,
layer.g_idx_sort_indices,
layer.workspace,
self.quant_config.weight_bits,
size_m,
part_size_n,
part_size_k,
layer.is_k_full,
)
if bias is not None:
output.add_(bias) # In-place add
return output.reshape(out_shape)
...@@ -8,6 +8,7 @@ from vllm.logger import init_logger ...@@ -8,6 +8,7 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -22,9 +23,11 @@ class MarlinConfig(QuantizationConfig): ...@@ -22,9 +23,11 @@ class MarlinConfig(QuantizationConfig):
def __init__( def __init__(
self, self,
group_size: int, group_size: int,
lm_head_quantized: bool,
) -> None: ) -> None:
# Group size for the quantization. # Group size for the quantization.
self.group_size = group_size self.group_size = group_size
self.lm_head_quantized = lm_head_quantized
if self.group_size != 128 and self.group_size != -1: if self.group_size != 128 and self.group_size != -1:
raise ValueError( raise ValueError(
"Currently, only group size 128 and -1 (channelwise) " "Currently, only group size 128 and -1 (channelwise) "
...@@ -51,7 +54,8 @@ class MarlinConfig(QuantizationConfig): ...@@ -51,7 +54,8 @@ class MarlinConfig(QuantizationConfig):
self.perm_len = 1024 self.perm_len = 1024
def __repr__(self) -> str: def __repr__(self) -> str:
return f"MarlinConfig(group_size={self.group_size})" return (f"MarlinConfig(group_size={self.group_size}, "
f"lm_head_quantized={self.lm_head_quantized})")
@classmethod @classmethod
def get_name(cls) -> str: def get_name(cls) -> str:
...@@ -73,7 +77,9 @@ class MarlinConfig(QuantizationConfig): ...@@ -73,7 +77,9 @@ class MarlinConfig(QuantizationConfig):
@classmethod @classmethod
def from_config(cls, config: Dict[str, Any]) -> "MarlinConfig": def from_config(cls, config: Dict[str, Any]) -> "MarlinConfig":
group_size = cls.get_from_keys(config, ["group_size"]) group_size = cls.get_from_keys(config, ["group_size"])
return cls(group_size) lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
default=False)
return cls(group_size, lm_head_quantized)
@classmethod @classmethod
def override_quantization_method(cls, hf_quant_cfg, def override_quantization_method(cls, hf_quant_cfg,
...@@ -96,7 +102,8 @@ class MarlinConfig(QuantizationConfig): ...@@ -96,7 +102,8 @@ class MarlinConfig(QuantizationConfig):
def get_quant_method( def get_quant_method(
self, layer: torch.nn.Module) -> Optional["MarlinLinearMethod"]: self, layer: torch.nn.Module) -> Optional["MarlinLinearMethod"]:
if isinstance(layer, LinearBase): if (isinstance(layer, LinearBase) or
(isinstance(layer, ParallelLMHead) and self.lm_head_quantized)):
return MarlinLinearMethod(self) return MarlinLinearMethod(self)
return None return None
......
...@@ -39,7 +39,8 @@ class SqueezeLLMConfig(QuantizationConfig): ...@@ -39,7 +39,8 @@ class SqueezeLLMConfig(QuantizationConfig):
def get_supported_act_dtypes(self) -> List[torch.dtype]: def get_supported_act_dtypes(self) -> List[torch.dtype]:
return [torch.half] return [torch.half]
def get_min_capability(self) -> int: @classmethod
def get_min_capability(cls) -> int:
return 70 return 70
@staticmethod @staticmethod
......
"""This file is used for /tests and /benchmarks"""
import numpy
import torch
# Precompute permutations for Marlin24 weight and scale shuffling # noqa: E501
#
# Marlin works on [16*2,64] tiles. The goal of the permutations is to reorder the weight data so that it is compatible # noqa: E501
# with the tensor-core format that is described here:
# https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type # noqa: E501
#
# As a result of this reordering, the vector loads inside the kernel will get the data as it is needed for tensor-core # noqa: E501
# (without the need to use ldmatrix instructions) # noqa: E501
def get_perms_24(num_bits):
    """Build the Marlin24 weight and scale permutations for one bit-width.

    Args:
        num_bits: Weight bit-width; must be 4 or 8.

    Returns:
        A tuple ``(perm, scale_perm, scale_perm_single)`` where ``perm`` is
        a 1-D tensor with 1024 indices and the two scale permutations are
        plain lists of 64 ints each.

    Raises:
        ValueError: If ``num_bits`` is neither 4 nor 8.
    """
    perm_list = []
    for i in range(32):
        col, lane = divmod(i, 4)
        col_o = col // 2
        rows = (
            2 * lane,
            2 * lane + 1,
            2 * (lane + 4),
            2 * (lane + 4) + 1,
        )
        perm1 = [
            16 * row + col_o * 256 + 8 * (col % 2) + 4 * block
            for block in (0, 1) for row in rows
        ]
        # Four consecutive offsets of the same 8-entry pattern.
        for j in range(4):
            perm_list.extend(p + j for p in perm1)

    if num_bits == 4:
        interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
    elif num_bits == 8:
        interleave = numpy.array([0, 2, 1, 3])
    else:
        raise ValueError("num_bits must be 4 or 8, got {}".format(num_bits))

    # Reorder within groups of len(interleave) entries, then flatten again.
    grouped = numpy.array(perm_list).reshape((-1, len(interleave)))
    perm = torch.from_numpy(grouped[:, interleave].ravel())

    scale_perm = [
        i * 8 + j for i in range(8) for j in (0, 4, 1, 5, 2, 6, 3, 7)
    ]
    scale_perm_single = [8 * i + j for i in range(8) for j in range(8)]
    return perm, scale_perm, scale_perm_single
# Lookup tables keyed by weight bit-width (4 and 8), filled once at import.
marlin_24_perm, marlin_24_scale_perm, marlin_24_scale_perm_single = {}, {}, {}
for num_bits in (4, 8):
    perm_24, scale_perm_24, scale_perm_single_24 = get_perms_24(num_bits)
    marlin_24_perm[num_bits] = perm_24
    marlin_24_scale_perm[num_bits] = scale_perm_24
    marlin_24_scale_perm_single[num_bits] = scale_perm_single_24
"""This file is used for /tests and /benchmarks"""
import numpy
import torch
# Precompute permutations for Marlin weight and scale shuffling # noqa: E501
#
# Marlin works on [16,64] tiles. The goal of the permutations is to reorder the weight data so that it is compatible # noqa: E501
# with the tensor-core format that is described here:
# https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type # noqa: E501
#
# As a result of this reordering, the vector loads inside the kernel will get the data as it is needed for tensor-core # noqa: E501
# (without the need to use ldmatrix instructions) # noqa: E501
def get_perms(num_bits):
    """Build the Marlin weight and scale permutations for one bit-width.

    Args:
        num_bits: Weight bit-width; must be 4 or 8.

    Returns:
        A tuple ``(perm, scale_perm, scale_perm_single)``: ``perm`` is a 1-D
        tensor with 1024 indices, ``scale_perm`` a list of 64 ints and
        ``scale_perm_single`` a list of 32 ints.

    Raises:
        ValueError: If ``num_bits`` is neither 4 nor 8. ``ValueError`` is a
            subclass of ``Exception``, so callers catching ``Exception``
            keep working.
    """
    perm_list = []
    for i in range(32):
        perm1 = []
        col = i // 4
        for block in [0, 1]:
            for row in [
                    2 * (i % 4),
                    2 * (i % 4) + 1,
                    2 * (i % 4 + 4),
                    2 * (i % 4 + 4) + 1,
            ]:
                perm1.append(16 * row + col + 8 * block)
        # Repeat the 8-entry pattern at four offsets of 256.
        for j in range(4):
            perm_list.extend([p + 256 * j for p in perm1])

    perm = numpy.array(perm_list)

    if num_bits == 4:
        interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
    elif num_bits == 8:
        interleave = numpy.array([0, 2, 1, 3])
    else:
        # Same exception type as get_perms_24 uses for this condition.
        raise ValueError("num_bits must be 4 or 8, got {}".format(num_bits))

    # Reorder within groups of len(interleave) entries, then flatten again.
    perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel()
    perm = torch.from_numpy(perm)

    scale_perm = []
    for i in range(8):
        scale_perm.extend([i + 8 * j for j in range(8)])
    scale_perm_single = []
    for i in range(4):
        scale_perm_single.extend(
            [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
    return perm, scale_perm, scale_perm_single
# Lookup tables keyed by weight bit-width (4 and 8), filled once at import.
marlin_perm, marlin_scale_perm, marlin_scale_perm_single = {}, {}, {}
for num_bits in (4, 8):
    perm, scale_perm, scale_perm_single = get_perms(num_bits)
    marlin_perm[num_bits] = perm
    marlin_scale_perm[num_bits] = scale_perm
    marlin_scale_perm_single[num_bits] = scale_perm_single
"""This file is used for /tests and /benchmarks""" from typing import List, Optional, Tuple
import random
import numpy
import torch import torch
from vllm.model_executor.layers.quantization.utils.format_24 import ( from vllm import _custom_ops as ops
mask_creator, sparse_semi_structured_from_dense_cutlass) from vllm.platforms import current_platform
from vllm.model_executor.layers.quantization.utils.marlin_24_perms import (
marlin_24_perm, marlin_24_scale_perm, marlin_24_scale_perm_single) GPTQ_MARLIN_TILE = 16
from vllm.model_executor.layers.quantization.utils.marlin_perms import ( GPTQ_MARLIN_MIN_THREAD_N = 64
marlin_perm, marlin_scale_perm, marlin_scale_perm_single) GPTQ_MARLIN_MIN_THREAD_K = 128
from vllm.model_executor.layers.quantization.utils.quant_utils import ( GPTQ_MARLIN_MAX_PARALLEL = 16
get_pack_factor, quantize_weights, sort_weights)
GPTQ_MARLIN_SUPPORTED_NUM_BITS = [4, 8]
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
GPTQ_MARLIN_SUPPORTED_SYM = [True]
GTPQ_MARLIN_UNSUPPORTED_GROUP_SIZE_ACT_ORDER = [-1]
__cuda_arch = torch.cuda.get_device_capability()
MARLIN_TILE = 16 def check_marlin_supported(num_bits: int, group_size: int, is_sym: bool,
min_capability: int) -> bool:
def is_marlin_supported(): # If the capability of the device is too low, cannot convert.
return __cuda_arch[0] >= 8 major, minor = current_platform.get_device_capability()
device_capability = major * 10 + minor
if device_capability < min_capability:
def marlin_permute_weights(q_w, size_k, size_n, perm, tile=MARLIN_TILE): return False
assert q_w.shape == (size_k, size_n)
assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}" return (device_capability >= min_capability
assert size_n % tile == 0, f"size_k = {size_n}, tile = {tile}" and num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS
and group_size in GPTQ_MARLIN_SUPPORTED_GROUP_SIZES
# Permute weights to 16x64 marlin tiles and is_sym in GPTQ_MARLIN_SUPPORTED_SYM)
q_w = q_w.reshape((size_k // tile, tile, size_n // tile, tile))
q_w = q_w.permute((0, 2, 1, 3))
q_w = q_w.reshape((size_k // tile, size_n * tile)) def verify_marlin_supported(num_bits: int, group_size: Optional[int],
is_sym: bool) -> None:
q_w = q_w.reshape((-1, perm.numel()))[:, perm].reshape(q_w.shape)
if num_bits not in GPTQ_MARLIN_SUPPORTED_NUM_BITS:
return q_w raise ValueError(
f"Marlin does not support weight_bits = {num_bits}. "
f"Only weight_bits = {GPTQ_MARLIN_SUPPORTED_NUM_BITS} "
def marlin_weights(q_w, size_k, size_n, num_bits, perm): "are supported.")
# Permute if (group_size is None
q_w = marlin_permute_weights(q_w, size_k, size_n, perm) or group_size not in GPTQ_MARLIN_SUPPORTED_GROUP_SIZES):
raise ValueError(
# Pack f"Marlin does not support group_size = {group_size}. "
pack_factor = get_pack_factor(num_bits) f"Only group_sizes = {GPTQ_MARLIN_SUPPORTED_GROUP_SIZES} "
orig_device = q_w.device "are supported.")
if is_sym not in GPTQ_MARLIN_SUPPORTED_SYM:
q_w = q_w.cpu().numpy().astype(numpy.uint32) raise ValueError(
f"Marlin does not support is_sym = is_sym. "
q_packed = numpy.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), f"Only sym = {GPTQ_MARLIN_SUPPORTED_SYM} are supported.")
dtype=numpy.uint32)
for i in range(pack_factor):
q_packed |= q_w[:, i::pack_factor] << num_bits * i def verify_marlin_supports_shape(output_size_per_partition: int,
input_size_per_partition: int,
q_packed = torch.from_numpy(q_packed.astype(numpy.int32)).to(orig_device) input_size: int, group_size: int) -> None:
return q_packed # Validate output_size_per_partition
if output_size_per_partition % GPTQ_MARLIN_MIN_THREAD_N != 0:
raise ValueError(f"Weight output_size_per_partition = "
def marlin_permute_scales(s, size_k, size_n, group_size, scale_perm, f"{output_size_per_partition} is not divisible by "
scale_perm_single): f" min_thread_n = {GPTQ_MARLIN_MIN_THREAD_N}. "
"Consider reducing tensor_parallel_size or running "
"with --quantization gptq.")
# Validate input_size_per_partition
if input_size_per_partition % GPTQ_MARLIN_MIN_THREAD_K != 0:
raise ValueError(f"Weight input_size_per_partition = "
f"{input_size_per_partition} is not divisible "
f"by min_thread_k = {GPTQ_MARLIN_MIN_THREAD_K}. "
"Consider reducing tensor_parallel_size or running "
"with --quantization gptq.")
if (group_size < input_size
and input_size_per_partition % group_size != 0):
raise ValueError(
f"Weight input_size_per_partition = {input_size_per_partition}"
f" is not divisible by group_size = {group_size}."
"Consider reducing tensor_parallel_size or running "
"with --quantization gptq.")
def marlin_make_workspace(output_size_per_partition: int,
                          device: torch.device) -> torch.Tensor:
    """Allocate the zero-filled integer workspace used by the Marlin kernel.

    Args:
        output_size_per_partition: Output width of this partition.
        device: Device on which the buffer is created.

    Returns:
        A 1-D ``torch.int`` tensor of zeros with
        ``(output_size_per_partition // GPTQ_MARLIN_MIN_THREAD_N) *
        GPTQ_MARLIN_MAX_PARALLEL`` elements.
    """
    n_tiles = output_size_per_partition // GPTQ_MARLIN_MIN_THREAD_N
    workspace_len = n_tiles * GPTQ_MARLIN_MAX_PARALLEL
    return torch.zeros(workspace_len,
                       dtype=torch.int,
                       device=device,
                       requires_grad=False)
def marlin_is_k_full(act_order: bool, is_row_parallel: bool) -> bool:
    """Return False only when act-order is combined with row parallelism.

    Args:
        act_order: Whether activation reordering is in use.
        is_row_parallel: Whether the layer's input dimension is sharded.
    """
    # (not a) or (a and not b) is the same as not (a and b).
    return not (act_order and is_row_parallel)
def marlin_repeat_scales_on_all_ranks(act_order: bool, group_size: int,
                                      is_row_parallel: bool) -> bool:
    """Decide whether scales must be replicated on every rank.

    Scales are repeated when act-order is used, or when quantization is
    channelwise (``group_size == -1``) on a row-parallel layer.

    Args:
        act_order: Whether activation reordering is in use.
        group_size: Quantization group size; -1 means channelwise.
        is_row_parallel: Whether the layer's input dimension is sharded.
    """
    if act_order:
        return act_order
    return group_size == -1 and is_row_parallel
def marlin_make_empty_g_idx(device: torch.device) -> torch.Tensor:
    """Return a zero-length, non-trainable ``torch.int`` parameter.

    Args:
        device: Device on which the empty tensor is created.
    """
    placeholder = torch.empty(0, dtype=torch.int, device=device)
    return torch.nn.Parameter(placeholder, requires_grad=False)
def marlin_sort_g_idx(
        g_idx: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Sort ``g_idx`` ascending and report the permutation used.

    Args:
        g_idx: Tensor of group indices.

    Returns:
        ``(sorted_g_idx, sort_indices)`` where ``sort_indices`` is the
        ``torch.argsort`` result cast to ``torch.int``.
    """
    order = torch.argsort(g_idx).to(torch.int)
    sorted_g_idx = g_idx[order]
    return sorted_g_idx, order
def get_scale_perms():
    """Return the two fixed index permutations applied to Marlin scales.

    Returns:
        ``(scale_perm, scale_perm_single)``: lists of 64 and 32 ints.
    """
    scale_perm: List[int] = [i + 8 * j for i in range(8) for j in range(8)]
    single_offsets = (0, 1, 8, 9, 16, 17, 24, 25)
    scale_perm_single: List[int] = [
        2 * i + j for i in range(4) for j in single_offsets
    ]
    return scale_perm, scale_perm_single
def marlin_permute_scales(s: torch.Tensor, size_k: int, size_n: int,
group_size: int) -> torch.Tensor:
scale_perm, scale_perm_single = get_scale_perms()
if group_size < size_k and group_size != -1: if group_size < size_k and group_size != -1:
s = s.reshape((-1, len(scale_perm)))[:, scale_perm] s = s.reshape((-1, len(scale_perm)))[:, scale_perm]
else: else:
...@@ -68,157 +138,44 @@ def marlin_permute_scales(s, size_k, size_n, group_size, scale_perm, ...@@ -68,157 +138,44 @@ def marlin_permute_scales(s, size_k, size_n, group_size, scale_perm,
return s return s
def marlin_quantize( # Newly generated tensors need to replace existing tensors that are
w: torch.Tensor, # already registered as parameters by vLLM (and won't be freed)
num_bits: int, def replace_tensor(layer: torch.nn.Module, name: str,
group_size: int, new_t: torch.Tensor) -> None:
act_order: bool, # It is important to use resize_() here since it ensures
): # the same buffer is reused
size_k, size_n = w.shape getattr(layer, name).resize_(new_t.shape)
getattr(layer, name).copy_(new_t)
# Normalize group_size del new_t
if group_size == -1:
group_size = size_k
assert group_size <= size_k def apply_marlin_linear(input: torch.Tensor,
weight: torch.Tensor,
# Quantize (and apply act_order if provided) weight_scale: torch.Tensor,
w_ref, q_w, s, g_idx, rand_perm = quantize_weights(w, num_bits, group_size, g_idx: torch.Tensor,
act_order) g_idx_sort_indices: torch.Tensor,
workspace: torch.Tensor,
# For act_order, sort the "weights" and "g_idx" so that group ids are num_bits: int,
# increasing output_size_per_partition: int,
sort_indices = torch.empty(0, dtype=torch.int, device=w.device) input_size_per_partition: int,
if act_order: is_k_full: bool,
q_w, g_idx, sort_indices = sort_weights(q_w, g_idx) bias: Optional[torch.Tensor] = None) -> torch.Tensor:
reshaped_x = input.reshape(-1, input.shape[-1])
# Reformat to marlin out_shape = input.shape[:-1] + (output_size_per_partition, )
marlin_q_w = marlin_weights(q_w, size_k, size_n, num_bits,
marlin_perm[num_bits]) output = ops.gptq_marlin_gemm(reshaped_x,
marlin_s = marlin_permute_scales(s, size_k, size_n, group_size, weight,
marlin_scale_perm[num_bits], weight_scale,
marlin_scale_perm_single[num_bits]) g_idx,
g_idx_sort_indices,
# Create result workspace,
res_list = [w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm] num_bits,
for i in range(len(res_list)): size_m=reshaped_x.shape[0],
res_list[i] = res_list[i].to(w.device) size_n=output_size_per_partition,
size_k=input_size_per_partition,
return res_list is_k_full=is_k_full)
if bias is not None:
def inject_24(w, size_k, size_n): output.add_(bias) # In-place add
assert w.shape == (size_k, size_n)
return output.reshape(out_shape)
mask = mask_creator(w.t()).t().cuda().bool()
return (mask * w).contiguous(), mask.contiguous()
def check_24(w, num_rows_to_sample=50, _verbose=False):
    """Print how many sampled 4-wide segments of ``w`` break 2:4 sparsity.

    ``w`` is transposed first, then ``num_rows_to_sample`` rows of the
    transposed matrix are drawn at random (with replacement). Each full
    block of four consecutive values may hold at most two non-zeros.
    Offending blocks and a final summary line are printed; nothing is
    returned.

    Args:
        w: 2-D tensor to inspect.
        num_rows_to_sample: Number of rows drawn from the transposed tensor.
        _verbose: If true, also print the sampled row indices.
    """
    BLOCK_SIZE = 4
    MAX_NON_ZEROS = 2

    w = w.t().contiguous()

    print("check_24: w.shape = {}".format(w.shape))

    num_rows, num_cols = w.shape
    sampled_row_idxs = random.choices(range(num_rows), k=num_rows_to_sample)
    if _verbose:
        print(f"Sampled row idxs = {sampled_row_idxs}")

    total_segments = 0
    non_24_segments = 0
    for i in sampled_row_idxs:
        # The stop is num_cols - BLOCK_SIZE + 1 so the last full block is
        # inspected too; a trailing partial block is still skipped.
        for j in range(0, num_cols - BLOCK_SIZE + 1, BLOCK_SIZE):
            total_segments += 1
            block = w[i, j:j + BLOCK_SIZE]
            num_nonzero = torch.count_nonzero(block)
            if num_nonzero > MAX_NON_ZEROS:
                print("i = {} j = {} block = {}".format(i, j, block))
                non_24_segments += 1

    print(f"{non_24_segments} / {total_segments} do not have 2:4 structure.")
def compress_quantized_24_weight(q_24, size_k, size_n, num_bits):
    """Compress a quantized weight with the cutlass semi-structured helper.

    The zero point is subtracted before compression and added back
    afterwards; the compression itself runs on the transposed tensor.

    Args:
        q_24: Quantized weight of shape ``(size_k, size_n)``.
        size_k: Expected first dimension of ``q_24``.
        size_n: Expected second dimension of ``q_24``.
        num_bits: Quantization bit-width, used to derive the zero point.

    Returns:
        ``(q_24_comp, meta)``: the compressed weight and the reshaped
        metadata tensor.
    """
    assert q_24.shape == (size_k, size_n)

    # Same value as ((1 << num_bits) - 1 + 1) // 2, the midpoint of the range.
    zp = (1 << num_bits) // 2

    # Remove zp to normalize over 0, then compress the transposed tensor.
    centered_t = (q_24 - zp).t().contiguous()
    compressed_t, meta = sparse_semi_structured_from_dense_cutlass(centered_t)

    # Transpose back and restore zp.
    q_24_comp = compressed_t.t().contiguous() + zp

    # Resize meta to its actual shape (without moving any data)
    meta = meta.resize_(meta.shape[1] // 2, meta.shape[0] * 2)

    return q_24_comp, meta
def marlin_24_quantize(
    w: torch.Tensor,
    num_bits: int,
    group_size: int,
):
    """Apply 2:4 sparsity, quantize, compress and reformat ``w`` for Marlin24.

    Args:
        w: Weight of shape ``(size_k, size_n)``.
        num_bits: Quantization bit-width (selects the permutation tables).
        group_size: Quantization group size; -1 means one group spanning
            all of ``size_k``.

    Returns:
        A list ``[w_24_ref, marlin_24_q_w_comp, meta, marlin_24_s]`` with
        every tensor moved to ``w.device``.
    """
    size_k, size_n = w.shape

    # Normalize group_size
    group_size = size_k if group_size == -1 else group_size
    assert group_size <= size_k

    # Inject 2:4 sparsity (the mask itself is not used afterwards)
    w_24, _mask_24 = inject_24(w, size_k, size_n)

    # Quantize; the act-order outputs are unused because act_order=False
    w_24_ref, q_w_24, s, _g_idx, _rand_perm = quantize_weights(
        w_24, num_bits, group_size, act_order=False)

    # Compress quantized weight; the compressed tensor has half the rows
    q_w_24_comp, meta = compress_quantized_24_weight(q_w_24, size_k, size_n,
                                                     num_bits)
    size_k_comp = size_k // 2

    # Reformat to marlin
    marlin_24_q_w_comp = marlin_weights(q_w_24_comp, size_k_comp, size_n,
                                        num_bits, marlin_24_perm[num_bits])
    marlin_24_s = marlin_permute_scales(s, size_k, size_n, group_size,
                                        marlin_24_scale_perm[num_bits],
                                        marlin_24_scale_perm_single[num_bits])

    # Create result
    results = (w_24_ref, marlin_24_q_w_comp, meta, marlin_24_s)
    return [t.to(w.device) for t in results]
def compute_max_diff(output, output_ref):
    """Return mean(|output - output_ref|) divided by mean(|output_ref|).

    Despite the name, this is a relative mean absolute difference, not a
    maximum.
    """
    mean_abs_err = torch.abs(output - output_ref).mean()
    mean_abs_ref = torch.abs(output_ref).mean()
    return mean_abs_err / mean_abs_ref
class MarlinWorkspace:
    """Holds a zero-filled integer scratch tensor allocated on ``"cuda"``."""

    def __init__(self, out_features, min_thread_n, max_parallel):
        """Allocate ``(out_features // min_thread_n) * max_parallel`` ints.

        ``out_features`` must be a multiple of ``min_thread_n``; this is
        checked with an ``assert``.
        """
        n_tiles, leftover = divmod(out_features, min_thread_n)
        assert leftover == 0, (
            "out_features = {} is undivisible by min_thread_n = {}".format(
                out_features, min_thread_n))

        self.scratch = torch.zeros(n_tiles * max_parallel,
                                   dtype=torch.int,
                                   device="cuda")
from typing import Optional
import torch
import vllm._custom_ops as ops
from vllm.platforms import current_platform
from vllm.utils import print_warning_once
from .marlin_utils import marlin_make_workspace, marlin_permute_scales
def is_fp8_marlin_supported():
    """Return True when the device's major compute capability is at least 8."""
    major = current_platform.get_device_capability()[0]
    return major >= 8
def apply_fp8_marlin_linear(
    input: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    workspace: torch.Tensor,
    size_n: int,
    size_k: int,
    bias: Optional[torch.Tensor],
) -> torch.Tensor:
    """Run the FP8 Marlin GEMM on ``input`` and optionally add ``bias``.

    For GPUs that lack FP8 hardware support, the Marlin kernel provides
    fast weight-only FP8 quantization.

    Args:
        input: Activations; all leading dimensions are flattened into one.
        weight: Weight passed to the kernel as ``b_q_weight``.
        weight_scale: Scales passed to the kernel as ``b_scales``.
        workspace: Kernel workspace tensor.
        size_n: Output feature count; becomes the last output dimension.
        size_k: Input feature count passed to the kernel.
        bias: Optional bias, added in place to the kernel output.

    Returns:
        Tensor of shape ``input.shape[:-1] + (size_n,)``.
    """
    flat_input = input.reshape(-1, input.shape[-1])
    num_rows = flat_input.shape[0]

    result = ops.fp8_marlin_gemm(
        a=flat_input,
        b_q_weight=weight,
        b_scales=weight_scale,
        workspace=workspace,
        num_bits=8,
        size_m=num_rows,
        size_n=size_n,
        size_k=size_k,
    )

    if bias is not None:
        result.add_(bias)  # In-place add

    # Restore the leading dimensions of the input.
    return result.reshape(input.shape[:-1] + (size_n, ))
def prepare_fp8_layer_for_marlin(layer: torch.nn.Module) -> None:
    """Rewrite an FP8 layer in place so it can be run by the Marlin kernel.

    Emits a one-time warning, then attaches a workspace, replaces
    ``layer.weight`` with the repacked weight and ``layer.weight_scale``
    with the expanded, permuted scales.

    Args:
        layer: Module providing ``weight``, ``weight_scale``, ``orig_dtype``
            and the per-partition input/output sizes.
    """
    print_warning_once(
        "Your GPU does not have native support for FP8 computation but "
        "FP8 quantization is being used. Weight-only FP8 compression will "
        "be used leveraging the Marlin kernel. This may degrade "
        "performance for compute-heavy workloads.")

    size_n = layer.output_size_per_partition
    size_k = layer.input_size_per_partition
    device = layer.weight.device

    # WORKSPACE
    layer.workspace = marlin_make_workspace(size_n, device)

    # WEIGHT
    # Pack the fp8 bytes into int32, then repack to marlin format with an
    # empty permutation tensor.
    packed_weight = pack_fp8_to_int32(layer.weight)
    empty_perm = torch.empty(0, dtype=torch.int, device=device)
    repacked_weight = ops.gptq_marlin_repack(b_q_weight=packed_weight,
                                             perm=empty_perm,
                                             size_k=size_k,
                                             size_n=size_n,
                                             num_bits=8)
    layer.weight = torch.nn.Parameter(repacked_weight, requires_grad=False)

    # WEIGHT SCALES
    # Currently Marlin doesn't support per-tensor scales, so we
    # expand it to channelwise
    channel_scales = layer.weight_scale.repeat(1, size_n)
    channel_scales = channel_scales.to(layer.orig_dtype).to(device)

    # Permute scales
    permuted_scales = marlin_permute_scales(s=channel_scales,
                                            size_k=size_k,
                                            size_n=size_n,
                                            group_size=-1)
    layer.weight_scale = torch.nn.Parameter(permuted_scales,
                                            requires_grad=False)
def pack_fp8_to_int32(fp8_tensor: torch.Tensor) -> torch.Tensor:
"""
Repack FP8 weights to gptq format (packed int32 elements)
"""
assert fp8_tensor.dtype == torch.float8_e4m3fn
assert fp8_tensor.shape[0] % 4 == 0
# Reshape to prepare for packing
reshaped = fp8_tensor.reshape(-1, 4, *fp8_tensor.shape[1:])
# Convert fp8 to uint8 (byte) representation
byte_tensor = reshaped.view(torch.uint8)
# Pack 4 uint8 values into one int32
packed = (byte_tensor[:, 0].to(torch.int32) |
(byte_tensor[:, 1].to(torch.int32) << 8) |
(byte_tensor[:, 2].to(torch.int32) << 16) |
(byte_tensor[:, 3].to(torch.int32) << 24))
return packed.view(fp8_tensor.shape[0] // 4,
*fp8_tensor.shape[1:]).contiguous()
"""Utility functions used for tests and benchmarks"""
from typing import List
import numpy
import torch
from .marlin_utils import GPTQ_MARLIN_TILE, marlin_permute_scales
from .quant_utils import get_pack_factor, quantize_weights, sort_weights
class MarlinWorkspace:
    """Holds a zero-filled integer scratch tensor allocated on ``"cuda"``."""

    def __init__(self, out_features, min_thread_n, max_parallel):
        """Allocate ``(out_features // min_thread_n) * max_parallel`` ints.

        ``out_features`` must be a multiple of ``min_thread_n``; this is
        checked with an ``assert``.
        """
        n_tiles, leftover = divmod(out_features, min_thread_n)
        assert leftover == 0, (
            "out_features = {} is undivisible by min_thread_n = {}".format(
                out_features, min_thread_n))

        self.scratch = torch.zeros(n_tiles * max_parallel,
                                   dtype=torch.int,
                                   device="cuda")
def marlin_permute_weights(q_w, size_k, size_n, perm, tile=GPTQ_MARLIN_TILE):
    """Regroup a ``(size_k, size_n)`` weight matrix into tiles and apply ``perm``.

    Args:
        q_w: Tensor of shape ``(size_k, size_n)``.
        size_k: Number of rows of ``q_w``; must be a multiple of ``tile``.
        size_n: Number of columns of ``q_w``; must be a multiple of ``tile``.
        perm: 1-D index tensor; the tiled matrix is viewed in rows of
            ``perm.numel()`` elements and each row is reordered by it.
        tile: Edge length of the square tiles.

    Returns:
        Tensor of shape ``(size_k // tile, size_n * tile)``.
    """
    assert q_w.shape == (size_k, size_n)
    assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}"
    # The message previously labelled this value "size_k"; it reports size_n.
    assert size_n % tile == 0, f"size_n = {size_n}, tile = {tile}"

    # Permute weights to 16x64 marlin tiles
    q_w = q_w.reshape((size_k // tile, tile, size_n // tile, tile))
    q_w = q_w.permute((0, 2, 1, 3))
    q_w = q_w.reshape((size_k // tile, size_n * tile))

    # Reorder the elements inside every chunk of perm.numel() values.
    q_w = q_w.reshape((-1, perm.numel()))[:, perm].reshape(q_w.shape)

    return q_w
def marlin_weights(q_w, size_k, size_n, num_bits, perm):
    """Permute quantized weights and pack them into int32 words.

    After ``marlin_permute_weights`` the values are moved to the CPU, packed
    with NumPy (``pack_factor`` values per word, value ``i`` of each group
    shifted left by ``num_bits * i``), and returned on the original device.
    """
    # Permute
    permuted = marlin_permute_weights(q_w, size_k, size_n, perm)

    # Pack
    values_per_word = get_pack_factor(num_bits)
    source_device = permuted.device

    unpacked = permuted.cpu().numpy().astype(numpy.uint32)
    packed_shape = (unpacked.shape[0], unpacked.shape[1] // values_per_word)
    words = numpy.zeros(packed_shape, dtype=numpy.uint32)

    for slot in range(values_per_word):
        words |= unpacked[:, slot::values_per_word] << num_bits * slot

    return torch.from_numpy(words.astype(numpy.int32)).to(source_device)
def get_weight_perm(num_bits: int):
    """Build the 1024-entry weight permutation for the given bit width.

    Args:
        num_bits: Quantization bit width; only 4 and 8 are supported.

    Returns:
        A 1-D integer ``torch.Tensor`` containing a permutation of
        ``range(1024)``.

    Raises:
        ValueError: If ``num_bits`` is neither 4 nor 8. ``ValueError`` is a
            subclass of the ``Exception`` raised previously, and matches
            ``get_weight_perm_24``.
    """
    # Validate first so an unsupported width fails before any work is done.
    if num_bits == 4:
        interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
    elif num_bits == 8:
        interleave = numpy.array([0, 2, 1, 3])
    else:
        raise ValueError("num_bits must be 4 or 8, got {}".format(num_bits))

    perm_list: List[int] = []
    for i in range(32):
        perm1: List[int] = []
        col = i // 4
        for block in [0, 1]:
            for row in [
                    2 * (i % 4),
                    2 * (i % 4) + 1,
                    2 * (i % 4 + 4),
                    2 * (i % 4 + 4) + 1,
            ]:
                perm1.append(16 * row + col + 8 * block)
        # Replicate the 8-entry pattern for each of the four 256-wide chunks.
        for j in range(4):
            perm_list.extend([p + 256 * j for p in perm1])

    perm = numpy.array(perm_list)

    # Reorder entries within each group of len(interleave) indices.
    perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel()
    perm = torch.from_numpy(perm)
    return perm
def marlin_quantize(w: torch.Tensor, num_bits: int, group_size: int,
                    act_order: bool):
    """Quantize ``w`` and reformat the result for the Marlin kernel.

    Returns a list ``[w_ref, marlin_q_w, marlin_s, g_idx, sort_indices,
    rand_perm]`` with every entry moved to ``w.device``. ``group_size == -1``
    is treated as a single group spanning all ``size_k`` rows.
    """
    size_k, size_n = w.shape

    # Normalize group_size
    group_size = size_k if group_size == -1 else group_size
    assert group_size <= size_k

    # Quantize (and apply act_order if provided)
    w_ref, q_w, s, g_idx, rand_perm = quantize_weights(w, num_bits, group_size,
                                                       act_order)

    # For act_order, sort the "weights" and "g_idx" so that group ids are
    # increasing; otherwise sort_indices stays an empty placeholder.
    sort_indices = torch.empty(0, dtype=torch.int, device=w.device)
    if act_order:
        q_w, g_idx, sort_indices = sort_weights(q_w, g_idx)

    # Reformat to marlin
    marlin_q_w = marlin_weights(q_w, size_k, size_n, num_bits,
                                get_weight_perm(num_bits))
    marlin_s = marlin_permute_scales(s, size_k, size_n, group_size)

    # Create result
    outputs = (w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm)
    return [tensor.to(w.device) for tensor in outputs]
# """Utility functions used for tests and benchmarks"""
# Modified by Roberto Lopez Castro (roberto.lopez.castro@udc.es).
#
import random
from typing import List
import numpy
import torch import torch
from .marlin_utils_test import marlin_weights
from .quant_utils import quantize_weights
# This is a PyTorch implementation of the main part of the reorder_meta()
# function, from the tools/util/include/cutlass/util/host_reorder.h file
...@@ -306,3 +311,155 @@ def mask_creator(tensor): ...@@ -306,3 +311,155 @@ def mask_creator(tensor):
mask = w_b.scatter_(dim=1, index=index, value=0).reshape(tensor.shape) mask = w_b.scatter_(dim=1, index=index, value=0).reshape(tensor.shape)
return mask return mask
def inject_24(w, size_k, size_n):
    """Mask ``w`` with the pattern produced by ``mask_creator``.

    The mask is computed on the transpose of ``w``, transposed back, moved
    to CUDA and cast to bool. Returns ``(masked_w, mask)``, both contiguous.
    """
    assert w.shape == (size_k, size_n)

    keep = mask_creator(w.t())
    keep = keep.t().cuda().bool()

    masked = (keep * w).contiguous()
    return masked, keep.contiguous()
def check_24(w, num_rows_to_sample=50, _verbose=False):
    """Print how many sampled 4-wide blocks of ``w.t()`` break 2:4 sparsity.

    Rows of the transposed matrix are sampled with replacement, and each row
    is scanned in consecutive blocks of 4 columns. A block violates the
    structure when it has more than 2 non-zero entries. Offending blocks and
    a final summary are printed; nothing is returned.

    Args:
        w: 2-D tensor to check (it is transposed before scanning).
        num_rows_to_sample: Number of row indices drawn via ``random.choices``.
        _verbose: When true, also print the sampled row indices.
    """
    BLOCK_SIZE = 4
    MAX_NON_ZEROS = 2

    w = w.t().contiguous()

    print("check_24: w.shape = {}".format(w.shape))

    num_rows, num_cols = w.shape
    sampled_row_idxs = random.choices(range(num_rows), k=num_rows_to_sample)
    if _verbose:
        print(f"Sampled row idxs = {sampled_row_idxs}")

    total_segments = 0
    non_24_segments = 0
    for i in sampled_row_idxs:
        # The "+ 1" makes the stop bound include the last full block; the
        # previous bound (num_cols - BLOCK_SIZE) silently skipped it.
        for j in range(0, num_cols - BLOCK_SIZE + 1, BLOCK_SIZE):
            total_segments += 1
            block = w[i, j:j + BLOCK_SIZE]
            num_nonzero = torch.count_nonzero(block)
            if num_nonzero > MAX_NON_ZEROS:
                print("i = {} j = {} block = {}".format(i, j, block))
                non_24_segments += 1

    print(f"{non_24_segments} / {total_segments} do not have 2:4 structure.")
def compress_quantized_24_weight(q_24, size_k, size_n, num_bits):
    """Compress a quantized 2:4-sparse weight and return it with its metadata.

    The zero point ``(2**num_bits) // 2`` is subtracted before compression and
    added back afterwards. Compression itself is done on the transposed
    matrix by ``sparse_semi_structured_from_dense_cutlass``.

    Returns:
        ``(q_24_comp, meta)`` where ``meta`` has been resized in place to
        ``(meta.shape[1] // 2, meta.shape[0] * 2)``.
    """
    assert q_24.shape == (size_k, size_n)

    # Remove zp to normalize over 0
    largest_q = (1 << num_bits) - 1
    zero_point = (largest_q + 1) // 2
    centered = q_24 - zero_point

    # Compress (the helper works on the transposed layout)
    centered = centered.t().contiguous()
    compressed, meta = sparse_semi_structured_from_dense_cutlass(centered)
    compressed = compressed.t().contiguous()

    # Restore zp
    q_24_comp = compressed + zero_point

    # Resize meta to its actual shape (without moving any data)
    meta = meta.resize_(meta.shape[1] // 2, meta.shape[0] * 2)

    return q_24_comp, meta
def get_scale_perms_24():
    """Return the two 64-entry scale permutations used by the 2:4 format.

    The first list interleaves each group of 8 indices as
    ``0, 4, 1, 5, 2, 6, 3, 7``; the second is the identity over 0..63.
    """
    interleaved_offsets = (0, 4, 1, 5, 2, 6, 3, 7)
    scale_perm: List[int] = [
        group * 8 + offset for group in range(8)
        for offset in interleaved_offsets
    ]
    scale_perm_single: List[int] = [
        8 * group + offset for group in range(8) for offset in range(8)
    ]
    return scale_perm, scale_perm_single
def get_weight_perm_24(num_bits: int):
    """Build the 1024-entry weight permutation for the 2:4 sparse layout.

    Returns a 1-D integer ``torch.Tensor``. Raises ``ValueError`` when
    ``num_bits`` is neither 4 nor 8.
    """
    # Row offsets relative to 2 * (i % 4): two adjacent rows, then the pair
    # eight rows further down.
    row_offsets = (0, 1, 8, 9)

    perm_list: List[int] = []
    for i in range(32):
        col, lane = divmod(i, 4)
        col_o, col_parity = divmod(col, 2)
        base = col_o * 256 + 8 * col_parity
        perm1: List[int] = [
            16 * (2 * lane + offset) + base + 4 * block for block in (0, 1)
            for offset in row_offsets
        ]
        for j in range(4):
            perm_list.extend(p + j for p in perm1)

    perm = numpy.array(perm_list)

    interleave_by_bits = {
        4: [0, 2, 4, 6, 1, 3, 5, 7],
        8: [0, 2, 1, 3],
    }
    if num_bits not in interleave_by_bits:
        raise ValueError("num_bits must be 4 or 8, got {}".format(num_bits))
    interleave = numpy.array(interleave_by_bits[num_bits])

    # Reorder entries within each group of len(interleave) indices.
    shuffled = perm.reshape((-1, len(interleave)))[:, interleave].ravel()
    return torch.from_numpy(shuffled)
def marlin_permute_scales_24(s: torch.Tensor, size_k: int, size_n: int,
                             group_size: int) -> torch.Tensor:
    """Permute the scale tensor ``s`` for the 2:4 Marlin layout.

    The interleaved permutation is used when ``group_size`` is a real group
    size smaller than ``size_k``; otherwise (``-1`` or a full-column group)
    the single-group permutation is applied. The result is reshaped to
    ``(-1, size_n)`` and made contiguous.
    """
    grouped_perm, single_perm = get_scale_perms_24()

    is_grouped = group_size != -1 and group_size < size_k
    chosen_perm = grouped_perm if is_grouped else single_perm

    permuted = s.reshape((-1, len(chosen_perm)))[:, chosen_perm]
    return permuted.reshape((-1, size_n)).contiguous()
def marlin_24_quantize(
    w: torch.Tensor,
    num_bits: int,
    group_size: int,
):
    """Apply 2:4 sparsity to ``w``, quantize it and reformat for Marlin 2:4.

    Returns a list ``[w_24_ref, marlin_24_q_w_comp, meta, marlin_24_s]`` with
    every entry moved to ``w.device``. ``group_size == -1`` is treated as one
    group covering all ``size_k`` rows.
    """
    size_k, size_n = w.shape

    # Normalize group_size
    group_size = size_k if group_size == -1 else group_size
    assert group_size <= size_k

    # Inject 2:4 sparsity
    w_24, mask_24 = inject_24(w, size_k, size_n)

    # Quantize (g_idx / rand_perm are produced but not used here)
    w_24_ref, q_w_24, s, g_idx, rand_perm = quantize_weights(w_24,
                                                             num_bits,
                                                             group_size,
                                                             act_order=False)

    # Compress quantized weight; the compressed matrix has half the rows.
    q_w_24_comp, meta = compress_quantized_24_weight(q_w_24, size_k, size_n,
                                                     num_bits)
    size_k_comp = size_k // 2

    # Reformat to marlin
    marlin_24_q_w_comp = marlin_weights(q_w_24_comp, size_k_comp, size_n,
                                        num_bits, get_weight_perm_24(num_bits))
    marlin_24_s = marlin_permute_scales_24(s, size_k, size_n, group_size)

    # Create result
    outputs = (w_24_ref, marlin_24_q_w_comp, meta, marlin_24_s)
    return [tensor.to(w.device) for tensor in outputs]
from typing import List, Optional, Tuple, Union
import torch
from torch.nn import Parameter
from vllm import _custom_ops as ops
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
def cutlass_fp8_supported() -> bool:
    """Report whether the CUTLASS scaled-mm op supports FP8 on this device.

    The device capability pair is encoded as ``major * 10 + minor`` before
    being handed to ``ops.cutlass_scaled_mm_supports_fp8``.
    """
    device_capability = current_platform.get_device_capability()
    major, minor = device_capability[0], device_capability[1]
    encoded = major * 10 + minor

    return ops.cutlass_scaled_mm_supports_fp8(encoded)
def per_tensor_dequantize(
        tensor: torch.Tensor, inv_scale: Union[float,
                                               torch.Tensor]) -> torch.Tensor:
    """Dequantize ``tensor`` by casting to float16 and scaling by ``inv_scale``."""
    return tensor.to(torch.float16) * inv_scale
def all_close_1d(x: torch.Tensor) -> bool:
    """Return True when every element of the 1-D tensor ``x`` is close to ``x[0]``.

    Closeness uses ``torch.allclose`` defaults. An empty tensor yields True.
    """
    assert len(x.shape) == 1

    for idx in range(x.shape[0]):
        if not torch.allclose(x[0], x[idx]):
            return False
    return True
def create_per_tensor_scale_param(
    output_partition_sizes: List[int],
    **extra_weight_attrs,
) -> Parameter:
    """Create a frozen float32 scale parameter with one entry per partition.

    Every entry starts at the most negative finite float32 value. The
    parameter is tagged with ``needs_scalar_to_array=True`` plus any
    ``extra_weight_attrs``.
    """
    placeholder = torch.finfo(torch.float32).min
    initial = torch.full((len(output_partition_sizes), ),
                         placeholder,
                         dtype=torch.float32)
    scale = Parameter(initial, requires_grad=False)

    attrs = {"needs_scalar_to_array": True, **extra_weight_attrs}
    set_weight_attrs(scale, attrs)
    return scale
def create_per_channel_scale_param(output_partition_sizes: List[int],
                                   **extra_weight_attrs) -> Parameter:
    """Create a frozen float32 scale parameter of shape ``(sum(sizes), 1)``.

    Every entry starts at the most negative finite float32 value. The
    parameter is tagged with ``output_dim=0`` plus any ``extra_weight_attrs``.
    """
    placeholder = torch.finfo(torch.float32).min
    num_channels = sum(output_partition_sizes)
    initial = torch.full((num_channels, 1), placeholder, dtype=torch.float32)
    scale = Parameter(initial, requires_grad=False)

    attrs = {"output_dim": 0, **extra_weight_attrs}
    set_weight_attrs(scale, attrs)
    return scale
def convert_to_channelwise(weight_scale: torch.Tensor,
                           logical_widths: List[int]) -> torch.Tensor:
    """Expand per-shard scales into a per-channel scale column.

    Args:
        weight_scale: Tensor indexable by shard number; ``weight_scale[idx]``
            is the scale of the ``idx``-th logical matrix.
        logical_widths: Number of output channels of each logical matrix.

    Returns:
        A float32 tensor of shape ``(sum(logical_widths), 1)`` on
        ``weight_scale.device`` in which each shard's scale is repeated once
        per channel. The annotation previously advertised a 2-tuple, but a
        single tensor has always been returned.
    """
    # Create channelwise buffer
    weight_scale_channel = torch.empty((sum(logical_widths), 1),
                                       dtype=torch.float32,
                                       device=weight_scale.device)

    # Expand each scale to match the size of each logical matrix.
    start = 0
    for idx, logical_width in enumerate(logical_widths):
        end = start + logical_width
        weight_scale_channel[start:end, :] = weight_scale[idx]
        start = end

    return weight_scale_channel
def requantize_with_max_scale(
        weight: torch.Tensor, weight_scale: torch.Tensor,
        logical_widths: List[int]) -> Tuple[torch.Tensor, torch.Tensor]:
    """Requantize the shards of a fused weight with one shared (max) scale.

    Returns ``(max_w_scale, weight)``. When requantization happens, the row
    slices of ``weight`` are overwritten in place.
    """
    # Max scale to be used for requantization.
    max_w_scale = weight_scale.max()

    # QKV / MLP is fused in the on-disk checkpoint if any of the weight
    # scales are still set to the default, since we initialize N weight
    # scales for N shards but only load 1 weight scale from disk in that
    # case. Requantization is skipped then, because the weight is already
    # quantized with the single scale.
    # * Sample Model: nm-testing/Phi-3-mini-128k-instruct-FP8
    fp8_lowest = torch.finfo(torch.float8_e4m3fn).min
    unfused_module_in_checkpoint = (weight_scale[-1] > fp8_lowest)

    # If the checkpoint is unfused, requantize every shard with the max scale.
    if unfused_module_in_checkpoint:
        row_begin = 0
        for shard_idx, width in enumerate(logical_widths):
            row_end = row_begin + width
            shard_dq = per_tensor_dequantize(weight[row_begin:row_end, :],
                                             weight_scale[shard_idx])
            weight[row_begin:row_end, :], _ = ops.scaled_fp8_quant(
                shard_dq, max_w_scale)
            row_begin = row_end

    return max_w_scale, weight
def apply_fp8_linear(
    input: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    input_scale: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    cutlass_fp8_supported: bool = True,
) -> torch.Tensor:
    """Quantize ``input`` to FP8 and run the fused scaled matmul.

    Uses ``ops.cutlass_scaled_mm`` when ``cutlass_fp8_supported`` is true and
    ``torch._scaled_mm`` otherwise. The result is narrowed along dim 0 to
    ``input.shape[0]`` rows.
    """
    # ops.scaled_fp8_quant supports both dynamic and static quant.
    #   If dynamic, layer.input_scale is None and x_scale computed from x.
    #   If static, layer.input_scale is scalar and x_scale is input_scale.
    if not cutlass_fp8_supported:
        # The input is padded because torch._scaled_mm is more performant
        # for matrices with batch dimension > 16. Note that this could
        # change in the future.
        qinput, x_scale = ops.scaled_fp8_quant(input,
                                               input_scale,
                                               batch_dim_padding=17)

        # Fused GEMM_DQ
        output, _ = torch._scaled_mm(qinput,
                                     weight,
                                     out_dtype=input.dtype,
                                     scale_a=x_scale,
                                     scale_b=weight_scale,
                                     bias=bias)
    else:
        qinput, x_scale = ops.scaled_fp8_quant(input, input_scale)

        # Fused GEMM_DQ
        output = ops.cutlass_scaled_mm(qinput,
                                       weight,
                                       out_dtype=input.dtype,
                                       scale_a=x_scale,
                                       scale_b=weight_scale,
                                       bias=bias)

    # Drop any rows beyond the original batch size.
    num_rows = input.shape[0]
    return torch.narrow(output, 0, 0, num_rows)
def apply_int8_linear(
    input: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    input_scale: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
):
    """Quantize ``input`` to int8 and run the CUTLASS scaled matmul.

    The output dtype matches ``input.dtype``.
    """
    # ops.scaled_int8_quant supports both dynamic and static quant.
    # * dynamic, layer.input_scale is None and x_scale computed from x.
    # * static, layer.input_scale is scalar and x_scale is input_scale.
    quantized_input, activation_scale = ops.scaled_int8_quant(
        input, input_scale)

    result = ops.cutlass_scaled_mm(quantized_input,
                                   weight,
                                   scale_a=activation_scale,
                                   scale_b=weight_scale,
                                   out_dtype=input.dtype,
                                   bias=bias)
    return result
from functools import cached_property from functools import cached_property
from typing import Optional, Tuple from typing import Tuple
import torch import torch
import torch.jit import torch.jit
import torch.nn as nn
from vllm.model_executor.layers.spec_decode_base_sampler import (
SpecDecodeBaseSampler)
class RejectionSampler(nn.Module):
class RejectionSampler(SpecDecodeBaseSampler):
"""Apply modified rejection sampling as described in "Accelerating Large """Apply modified rejection sampling as described in "Accelerating Large
Language Model Decoding with Speculative Sampling" Language Model Decoding with Speculative Sampling"
https://arxiv.org/pdf/2302.01318.pdf. https://arxiv.org/pdf/2302.01318.pdf.
...@@ -22,39 +24,11 @@ class RejectionSampler(nn.Module): ...@@ -22,39 +24,11 @@ class RejectionSampler(nn.Module):
Require when bonus tokens will cause corrupt KV cache for Require when bonus tokens will cause corrupt KV cache for
proposal methods that require KV cache. proposal methods that require KV cache.
strict_mode: Whether or not to perform shape/device/dtype checks strict_mode: Whether or not to perform shape/device/dtype checks
during sampling. This catches correctness issues but adds during sampling. This catches correctness issues but adds
nontrivial latency. nontrivial latency.
""" """
super().__init__() super().__init__(disable_bonus_tokens=disable_bonus_tokens,
self._disable_bonus_tokens = disable_bonus_tokens strict_mode=strict_mode)
self._strict_mode = strict_mode
# NOTE: A "bonus token" is accepted iff all proposal tokens are
# accepted. There is always only one possible bonus token. We store this
# value in a variable for readability.
self._num_bonus_tokens = 1
self.num_accepted_tokens: Optional[torch.Tensor] = None
self.num_emitted_tokens: Optional[torch.Tensor] = None
self.num_draft_tokens: int = 0
def init_gpu_tensors(self, rank: int) -> None:
    """Create the zeroed token counters on device ``cuda:{rank}``.

    Asserts that ``num_accepted_tokens`` has not been set yet.
    """
    assert self.num_accepted_tokens is None
    target_device = f"cuda:{rank}"

    def _zero_counter() -> torch.Tensor:
        # 0-d long tensor living on the target device.
        return torch.tensor(0, dtype=torch.long, device=target_device)

    self.num_accepted_tokens = _zero_counter()
    self.num_emitted_tokens = _zero_counter()
@property
def probs_dtype(self) -> torch.dtype:
    """Dtype that probability tensors are checked against (float32)."""
    return torch.float32
@property
def token_id_dtype(self) -> torch.dtype:
    """Dtype used for token-id tensors, including the output (int64)."""
    return torch.int64
def forward( def forward(
self, self,
...@@ -100,21 +74,15 @@ class RejectionSampler(nn.Module): ...@@ -100,21 +74,15 @@ class RejectionSampler(nn.Module):
# Only perform shape/dtype/device checking in strict mode, as it adds # Only perform shape/dtype/device checking in strict mode, as it adds
# overhead. # overhead.
if self._strict_mode: if self._strict_mode:
self._raise_if_incorrect_shape(target_probs, bonus_token_ids, self._raise_if_incorrect_input(target_probs, bonus_token_ids,
draft_probs, draft_token_ids)
self._raise_if_incorrect_dtype(target_probs, bonus_token_ids,
draft_probs, draft_token_ids) draft_probs, draft_token_ids)
self._raise_if_inconsistent_device(target_probs, bonus_token_ids,
draft_probs, draft_token_ids) accepted, recovered_token_ids = (
self._raise_if_out_of_bounds_vocab(target_probs.shape[-1], self._batch_modified_rejection_sampling(
bonus_token_ids, target_probs,
draft_token_ids) draft_probs,
draft_token_ids,
accepted, recovered_token_ids = self._batch_modified_rejection_sampling( ))
target_probs,
draft_probs,
draft_token_ids,
)
output_token_ids = self._create_output( output_token_ids = self._create_output(
accepted, accepted,
...@@ -272,128 +240,6 @@ class RejectionSampler(nn.Module): ...@@ -272,128 +240,6 @@ class RejectionSampler(nn.Module):
""" """
return torch.finfo(self.probs_dtype).tiny return torch.finfo(self.probs_dtype).tiny
def _create_output(
        self,
        accepted: torch.Tensor,  # [batch_size, k]
        recovered_token_ids: torch.Tensor,  # [batch_size, k]
        draft_token_ids: torch.Tensor,  # [batch_size, k]
        bonus_token_ids: torch.Tensor,  # [batch_size]
) -> torch.Tensor:
    """Format output. Returns a matrix of token ids. When
    a token is rejected via rejection sampling, all subsequent
    token ids are set to -1 for the sequence.

    For each row, draft tokens before the first rejected position are kept,
    the rejected position itself receives the recovered token id, and the
    bonus column is filled only when the last draft column was kept (and
    bonus tokens are not disabled).

    As a side effect the running counters ``num_accepted_tokens``,
    ``num_emitted_tokens`` and ``num_draft_tokens`` are incremented.

    shape = [batch_size, k + num_bonus_tokens]
    """
    bonus_token_ids = bonus_token_ids.squeeze()
    batch_size, k = recovered_token_ids.shape

    # Determine the index of the first False value for each row.
    limits = (accepted == 0).max(1).indices
    # Rows without any rejection get limit k, i.e. every draft token is kept.
    limits[~(accepted == 0).any(1)] = k

    # Create masks using the indices.
    indices = torch.arange(k, device=accepted.device).unsqueeze(0)
    # True for positions strictly before the first rejection.
    accepted_mask = indices < limits.unsqueeze(1)
    # True only at the first rejected position (never true when limit == k).
    after_false_mask = indices == limits.unsqueeze(1)

    # Create an extended output tensor
    output_with_bonus_tokens = -torch.ones(
        (batch_size, k + self._num_bonus_tokens),
        dtype=self.token_id_dtype,
        device=accepted.device)
    # `output` is a slice of the tensor above, so the in-place writes below
    # (out=, mul_, add_) fill the first k columns of output_with_bonus_tokens.
    output = output_with_bonus_tokens[:, :k]

    # Fill in the first k columns of the output tensor using masks and data
    # tensors.
    torch.where(accepted_mask,
                draft_token_ids,
                -torch.ones_like(draft_token_ids),
                out=output)

    # Fill the last column.
    # We check output directly as accepted may have True values inconsistent
    # with causal acceptance.
    output_with_bonus_tokens[:, -1] = torch.where(output[:, -1] != -1,
                                                  bonus_token_ids, -1)

    # We disable bonus tokens because it causes corrupt KV cache for
    # proposal methods that require KV cache. We can fix it by "prefilling"
    # the bonus token in the proposer. The following issue tracks the fix.
    # https://github.com/vllm-project/vllm/issues/4212
    if self._disable_bonus_tokens:
        output_with_bonus_tokens[:, -1] = -1

    # Fill the recovered token ids.
    # Zero out the first rejected slot, then add the recovered id there.
    output.mul_(~after_false_mask).add_(
        recovered_token_ids.mul(after_false_mask))

    # Update the running statistics.
    self.num_accepted_tokens += accepted.sum()
    self.num_emitted_tokens += (output_with_bonus_tokens != -1).sum()
    self.num_draft_tokens += batch_size * k

    return output_with_bonus_tokens
def _raise_if_incorrect_shape(
self,
target_probs: torch.Tensor,
bonus_token_ids: torch.Tensor,
draft_probs: torch.Tensor,
draft_token_ids: torch.Tensor,
) -> None:
(target_batch_size, num_target_probs,
target_vocab_size) = target_probs.shape
bonus_batch_size, num_bonus_tokens = bonus_token_ids.shape
draft_batch_size, num_draft_probs, draft_vocab_size = draft_probs.shape
draft_token_ids_batch_size, num_draft_token_ids = draft_token_ids.shape
assert draft_batch_size == target_batch_size
assert num_draft_probs == num_target_probs
assert (draft_vocab_size == target_vocab_size
), f"{draft_vocab_size=} {target_vocab_size=}"
assert draft_token_ids_batch_size == draft_batch_size
assert num_draft_token_ids == num_draft_probs
assert bonus_batch_size == target_batch_size
assert num_bonus_tokens == self._num_bonus_tokens
def _raise_if_incorrect_dtype(
self,
target_probs: torch.Tensor,
bonus_token_ids: torch.Tensor,
draft_probs: torch.Tensor,
draft_token_ids: torch.Tensor,
) -> None:
assert all(probs.dtype == self.probs_dtype
for probs in [target_probs, draft_probs])
assert all(token_ids.dtype == self.token_id_dtype
for token_ids in [bonus_token_ids, draft_token_ids])
def _raise_if_inconsistent_device(
self,
target_probs: torch.Tensor,
bonus_token_ids: torch.Tensor,
draft_probs: torch.Tensor,
draft_token_ids: torch.Tensor,
) -> None:
devices = [
t.device for t in
[target_probs, bonus_token_ids, draft_probs, draft_token_ids]
]
assert all([devices[0] == device for device in devices])
def _raise_if_out_of_bounds_vocab(
self,
vocab_size: int,
bonus_token_ids: torch.Tensor,
draft_token_ids: torch.Tensor,
) -> None:
assert torch.all(bonus_token_ids < vocab_size)
assert torch.all(bonus_token_ids >= 0)
assert torch.all(draft_token_ids < vocab_size)
assert torch.all(draft_token_ids >= 0)
# torch.multinomial forces a GPU<->CPU sync. # torch.multinomial forces a GPU<->CPU sync.
# Therefore, we use an optimized implementation instead that skips the sync. # Therefore, we use an optimized implementation instead that skips the sync.
......
...@@ -28,6 +28,7 @@ import torch ...@@ -28,6 +28,7 @@ import torch
import torch.nn as nn import torch.nn as nn
from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.custom_op import CustomOp
from vllm.utils import is_tpu
def _rotate_neox(x: torch.Tensor) -> torch.Tensor: def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
...@@ -43,6 +44,19 @@ def _rotate_gptj(x: torch.Tensor) -> torch.Tensor: ...@@ -43,6 +44,19 @@ def _rotate_gptj(x: torch.Tensor) -> torch.Tensor:
return x.flatten(-2) return x.flatten(-2)
def _apply_rotary_emb(
x: torch.Tensor,
freqs_cis: torch.Tensor,
) -> torch.Tensor:
x_ = torch.view_as_complex(
torch.stack(torch.chunk(x.transpose(1, 2).float(), 2, dim=-1), dim=-1))
x_out = torch.view_as_real(x_ * freqs_cis).type_as(x)
x_out = torch.cat(torch.chunk(x_out, 2, dim=-1), dim=-2)
x_out = x_out.reshape(x_out.shape[0], x_out.shape[1], x_out.shape[2],
-1).transpose(1, 2)
return x_out
class RotaryEmbedding(CustomOp): class RotaryEmbedding(CustomOp):
"""Original rotary positional embedding.""" """Original rotary positional embedding."""
...@@ -64,8 +78,14 @@ class RotaryEmbedding(CustomOp): ...@@ -64,8 +78,14 @@ class RotaryEmbedding(CustomOp):
self.dtype = dtype self.dtype = dtype
cache = self._compute_cos_sin_cache() cache = self._compute_cos_sin_cache()
cache = cache.to(dtype) self.use_native2 = is_tpu() and is_neox_style
self.register_buffer("cos_sin_cache", cache, persistent=False) if not self.use_native2:
cache = cache.to(dtype)
self.register_buffer("cos_sin_cache", cache, persistent=False)
else:
cos, sin = cache.chunk(2, dim=-1)
freqs_cis = cos + 1j * sin
self.register_buffer("freqs_cis", freqs_cis, persistent=False)
def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor: def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
"""Compute the inverse frequency.""" """Compute the inverse frequency."""
...@@ -100,7 +120,11 @@ class RotaryEmbedding(CustomOp): ...@@ -100,7 +120,11 @@ class RotaryEmbedding(CustomOp):
key: torch.Tensor, key: torch.Tensor,
offsets: Optional[torch.Tensor] = None, offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
"""PyTorch-native implementation equivalent to forward().""" """A PyTorch-native implementation equivalent to forward().
This method mimics the implementation of the custom CUDA kernel
used in `forward_cuda()`.
"""
query = query.view(*query.shape[:-1], -1, self.head_size) query = query.view(*query.shape[:-1], -1, self.head_size)
key = key.view(*key.shape[:-1], -1, self.head_size) key = key.view(*key.shape[:-1], -1, self.head_size)
...@@ -138,6 +162,42 @@ class RotaryEmbedding(CustomOp): ...@@ -138,6 +162,42 @@ class RotaryEmbedding(CustomOp):
key = key.flatten(-2) key = key.flatten(-2)
return query, key return query, key
def forward_native2(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Another PyTorch-native implementation of forward().

    This method might perform better than `forward_native()` when compiled.
    It looks up ``self.freqs_cis`` at the given positions and rotates the
    first ``self.rotary_dim`` features of every head of ``query`` and
    ``key``; the remaining features pass through unchanged.
    """
    # A 1-D positions tensor is treated as a single sequence.
    if positions.dim() == 1:
        batch_size, seq_len = 1, positions.shape[0]
    else:
        batch_size, seq_len = positions.shape

    if offsets is not None:
        positions = positions + offsets

    selected = self.freqs_cis.index_select(0, positions.flatten())
    freqs_cis = selected.view(batch_size, 1, seq_len, -1)

    def _rotate(states: torch.Tensor) -> torch.Tensor:
        # Split into heads, rotate the leading rotary_dim features, and
        # restore the caller's original shape.
        original_shape = states.shape
        per_head = states.view(batch_size, seq_len, -1, self.head_size)
        rot_part = per_head[..., :self.rotary_dim]
        pass_part = per_head[..., self.rotary_dim:]
        rot_part = _apply_rotary_emb(rot_part, freqs_cis)
        return torch.cat((rot_part, pass_part),
                         dim=-1).reshape(original_shape)

    query = _rotate(query)
    key = _rotate(key)
    return query, key
def forward_cuda( def forward_cuda(
self, self,
positions: torch.Tensor, positions: torch.Tensor,
...@@ -161,6 +221,40 @@ class RotaryEmbedding(CustomOp): ...@@ -161,6 +221,40 @@ class RotaryEmbedding(CustomOp):
self.cos_sin_cache, self.is_neox_style) self.cos_sin_cache, self.is_neox_style)
return query, key return query, key
def forward_xpu(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply rotary embedding through the IPEX ops.

    The cos/sin cache is first moved to the device of ``positions`` and cast
    to ``query.dtype``. The batched op is used when ``offsets`` is given.
    """
    from vllm._ipex_ops import ipex_ops as ops

    self.cos_sin_cache = self.cos_sin_cache.to(positions.device,
                                               dtype=query.dtype)

    # ops.rotary_embedding()/batched_rotary_embedding()
    # are in-place operations that update the query and key tensors.
    if offsets is None:
        ops.rotary_embedding(positions, query, key, self.head_size,
                             self.cos_sin_cache, self.is_neox_style)
    else:
        ops.batched_rotary_embedding(positions, query, key, self.head_size,
                                     self.cos_sin_cache,
                                     self.is_neox_style, self.rotary_dim,
                                     offsets)
    return query, key
def forward_tpu(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Dispatch to ``forward_native2`` or ``forward_native``.

    ``forward_native2`` is chosen when ``self.use_native2`` is truthy.
    """
    if self.use_native2:
        forward_fn = self.forward_native2
    else:
        forward_fn = self.forward_native
    return forward_fn(positions, query, key, offsets)
def extra_repr(self) -> str: def extra_repr(self) -> str:
s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}" s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}"
s += f", max_position_embeddings={self.max_position_embeddings}" s += f", max_position_embeddings={self.max_position_embeddings}"
...@@ -396,7 +490,7 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding): ...@@ -396,7 +490,7 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
return cache return cache
class Phi3SuScaledRotaryEmbedding(nn.Module): class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
"""Phi3 family of models scaled rotary embedding. """Phi3 family of models scaled rotary embedding.
Based on the original RotaryEmbedding implementation. Based on the original RotaryEmbedding implementation.
...@@ -413,18 +507,19 @@ class Phi3SuScaledRotaryEmbedding(nn.Module): ...@@ -413,18 +507,19 @@ class Phi3SuScaledRotaryEmbedding(nn.Module):
dtype: torch.dtype, dtype: torch.dtype,
short_factor: List[float], short_factor: List[float],
long_factor: List[float], long_factor: List[float],
short_mscale: float = 1.1, short_mscale: float = 1.0,
long_mscale: float = 1.225, long_mscale: float = 1.0,
): ):
super().__init__() super().__init__()
if rotary_dim != head_size: if rotary_dim != head_size:
raise ValueError( raise ValueError(
f"`Phi3SuScaledRotaryEmbedding` does not support rotary_dim != \ f"`Phi3LongRoPEScaledRotaryEmbedding` does not support \
head_size ({rotary_dim}!={head_size}).") rotary_dim != head_size ({rotary_dim}!={head_size}).")
if is_neox_style is False: if is_neox_style is False:
raise ValueError( raise ValueError(
"`Phi3SuScaledRotaryEmbedding` only supports neox_style.") "`Phi3LongRoPEScaledRotaryEmbedding` only supports neox_style."
)
self.head_size = head_size self.head_size = head_size
self.max_position_embeddings = max_position_embeddings self.max_position_embeddings = max_position_embeddings
...@@ -435,6 +530,16 @@ class Phi3SuScaledRotaryEmbedding(nn.Module): ...@@ -435,6 +530,16 @@ class Phi3SuScaledRotaryEmbedding(nn.Module):
self.short_mscale = short_mscale self.short_mscale = short_mscale
self.long_mscale = long_mscale self.long_mscale = long_mscale
scale = (self.max_position_embeddings /
self.original_max_position_embeddings)
if scale <= 1.0:
self.scaling_factor = 1.0
else:
self.scaling_factor = math.sqrt(
1 + math.log(scale) /
math.log(self.original_max_position_embeddings))
short_cache = self._compute_cos_sin_cache( short_cache = self._compute_cos_sin_cache(
original_max_position_embeddings, short_factor, short_mscale) original_max_position_embeddings, short_factor, short_mscale)
short_cache = short_cache.to(dtype) short_cache = short_cache.to(dtype)
...@@ -470,8 +575,8 @@ class Phi3SuScaledRotaryEmbedding(nn.Module): ...@@ -470,8 +575,8 @@ class Phi3SuScaledRotaryEmbedding(nn.Module):
inv_freq = self._compute_inv_freq(rescale_factors) inv_freq = self._compute_inv_freq(rescale_factors)
t = torch.arange(max_position_embeddings, dtype=torch.float) t = torch.arange(max_position_embeddings, dtype=torch.float)
freqs = torch.einsum("i,j -> ij", t, inv_freq) freqs = torch.einsum("i,j -> ij", t, inv_freq)
cos = freqs.cos() * mscale cos = freqs.cos() * mscale * self.scaling_factor
sin = freqs.sin() * mscale sin = freqs.sin() * mscale * self.scaling_factor
cache = torch.cat((cos, sin), dim=-1) cache = torch.cat((cos, sin), dim=-1)
return cache return cache
...@@ -660,7 +765,9 @@ def get_rope( ...@@ -660,7 +765,9 @@ def get_rope(
is_neox_style, dtype) is_neox_style, dtype)
else: else:
scaling_type = rope_scaling["type"] scaling_type = rope_scaling["type"]
if scaling_type != "su": # The correct one should be "longrope" but keep "su" here
# for backward compatible
if scaling_type != "su" and scaling_type != "longrope":
scaling_factor = rope_scaling["factor"] scaling_factor = rope_scaling["factor"]
if scaling_type == "linear": if scaling_type == "linear":
rotary_emb = LinearScalingRotaryEmbedding(head_size, rotary_dim, rotary_emb = LinearScalingRotaryEmbedding(head_size, rotary_dim,
...@@ -710,7 +817,7 @@ def get_rope( ...@@ -710,7 +817,7 @@ def get_rope(
for k, v in rope_scaling.items() for k, v in rope_scaling.items()
if k in ("short_mscale", "long_mscale") if k in ("short_mscale", "long_mscale")
} }
rotary_emb = Phi3SuScaledRotaryEmbedding( rotary_emb = Phi3LongRoPEScaledRotaryEmbedding(
head_size, rotary_dim, max_position, original_max_position, head_size, rotary_dim, max_position, original_max_position,
base, is_neox_style, dtype, short_factor, long_factor, base, is_neox_style, dtype, short_factor, long_factor,
**extra_kwargs) **extra_kwargs)
......
...@@ -174,7 +174,7 @@ def _apply_min_tokens_penalty( ...@@ -174,7 +174,7 @@ def _apply_min_tokens_penalty(
min_tokens = sampling_params.min_tokens min_tokens = sampling_params.min_tokens
token_ids_to_penalize = sampling_params.all_stop_token_ids token_ids_to_penalize = sampling_params.all_stop_token_ids
if min_tokens > 0 and token_ids_to_penalize: if min_tokens > 0 and token_ids_to_penalize:
seqs_to_penalize = [] seqs_to_penalize: List[int] = []
for j, seq_id in enumerate(seq_ids): for j, seq_id in enumerate(seq_ids):
seq_data = seq_group.seq_data[seq_id] seq_data = seq_group.seq_data[seq_id]
if len(seq_data.output_token_ids) < min_tokens: if len(seq_data.output_token_ids) < min_tokens:
...@@ -285,7 +285,7 @@ def _greedy_sample( ...@@ -285,7 +285,7 @@ def _greedy_sample(
same as the length of selected_seq_groups. If the corresponding same as the length of selected_seq_groups. If the corresponding
seq_group has do_sample=False, tuple contains ([], []) seq_group has do_sample=False, tuple contains ([], [])
""" """
samples = samples.tolist() samples_lst = samples.tolist()
sample_idx = 0 sample_idx = 0
results: SampleResultType = [] results: SampleResultType = []
for seq_group in selected_seq_groups: for seq_group in selected_seq_groups:
...@@ -298,7 +298,7 @@ def _greedy_sample( ...@@ -298,7 +298,7 @@ def _greedy_sample(
assert num_parent_seqs == 1, ( assert num_parent_seqs == 1, (
"Greedy sampling should have only one seq.") "Greedy sampling should have only one seq.")
parent_ids = list(range(num_parent_seqs)) parent_ids = list(range(num_parent_seqs))
next_token_ids = [samples[sample_idx]] next_token_ids = [samples_lst[sample_idx]]
results.append((next_token_ids, parent_ids)) results.append((next_token_ids, parent_ids))
sample_idx += num_parent_seqs sample_idx += num_parent_seqs
return results return results
...@@ -394,7 +394,7 @@ def _beam_search_sample( ...@@ -394,7 +394,7 @@ def _beam_search_sample(
next_token_ids = next_token_ids.tolist() next_token_ids = next_token_ids.tolist()
else: else:
# Generation phase. # Generation phase.
cumulative_logprobs: List[int] = [ cumulative_logprobs: List[float] = [
seq_group.seq_data[seq_id].cumulative_logprob seq_group.seq_data[seq_id].cumulative_logprob
for seq_id in seq_ids for seq_id in seq_ids
] ]
...@@ -466,8 +466,9 @@ def _sample_with_torch( ...@@ -466,8 +466,9 @@ def _sample_with_torch(
categorized_seq_group_ids[sampling_type].append(i) categorized_seq_group_ids[sampling_type].append(i)
sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {} sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {}
sample_metadata = {} sample_metadata: Dict[SamplingType,
multinomial_samples = {} Tuple[List[int], List[SequenceGroupToSample]]] = {}
multinomial_samples: Dict[SamplingType, torch.Tensor] = {}
# Create output tensor for sampled token ids. # Create output tensor for sampled token ids.
if include_gpu_probs_tensor: if include_gpu_probs_tensor:
...@@ -494,7 +495,7 @@ def _sample_with_torch( ...@@ -494,7 +495,7 @@ def _sample_with_torch(
greedy_samples = torch.argmax(logprobs[long_sample_indices], greedy_samples = torch.argmax(logprobs[long_sample_indices],
dim=-1) dim=-1)
if include_gpu_probs_tensor: if sampled_token_ids_tensor is not None:
# Store sampled tokens in output tensor. # Store sampled tokens in output tensor.
sampled_token_ids_tensor[ sampled_token_ids_tensor[
long_sample_indices] = greedy_samples.unsqueeze(-1) long_sample_indices] = greedy_samples.unsqueeze(-1)
...@@ -522,7 +523,7 @@ def _sample_with_torch( ...@@ -522,7 +523,7 @@ def _sample_with_torch(
probs[long_sample_indices], max_best_of_in_batch, probs[long_sample_indices], max_best_of_in_batch,
**seeded_args) **seeded_args)
if include_gpu_probs_tensor: if sampled_token_ids_tensor is not None:
# Store sampled tokens in output tensor. # Store sampled tokens in output tensor.
sampled_token_ids_tensor[ sampled_token_ids_tensor[
long_sample_indices] = multinomial_samples[sampling_type] long_sample_indices] = multinomial_samples[sampling_type]
...@@ -571,7 +572,9 @@ def _sample_with_triton_kernel( ...@@ -571,7 +572,9 @@ def _sample_with_triton_kernel(
categorized_seq_group_ids[sampling_type].append(i) categorized_seq_group_ids[sampling_type].append(i)
sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {} sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {}
sample_metadata = {} sample_metadata: Dict[SamplingType,
Tuple[List[int], List[SequenceGroupToSample],
torch.Tensor, torch.Tensor]] = {}
max_best_of_in_batch = 1 max_best_of_in_batch = 1
# Counterintuitively, having two loops here is actually faster. # Counterintuitively, having two loops here is actually faster.
...@@ -676,7 +679,7 @@ def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor: ...@@ -676,7 +679,7 @@ def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
Returns: Returns:
torch.Tensor: 1D tensor of shape (N,) where N is the no. of tokens. torch.Tensor: 1D tensor of shape (N,) where N is the no. of tokens.
Each element in the returned tensor represents the rank Each element in the returned tensor represents the rank
of the chosen token in the input logprob tensor. of the chosen token in the input logprob tensor.
""" """
vals = x[torch.arange(0, len(x), device=x.device, dtype=indices.dtype), vals = x[torch.arange(0, len(x), device=x.device, dtype=indices.dtype),
...@@ -962,7 +965,7 @@ def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor, ...@@ -962,7 +965,7 @@ def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor,
distribution. distribution.
- Greedy sampling performs `argmax` to obtain the token with the - Greedy sampling performs `argmax` to obtain the token with the
highest likelihood. highest likelihood.
Ignoring greedy sampling for a moment, we find that the computed probability Ignoring greedy sampling for a moment, we find that the computed probability
distribution has the following property: we can sample from it independently distribution has the following property: we can sample from it independently
and find that the token sampled by the Sampler has a frequency corresponding and find that the token sampled by the Sampler has a frequency corresponding
...@@ -1008,14 +1011,14 @@ def _build_sampler_output( ...@@ -1008,14 +1011,14 @@ def _build_sampler_output(
speculative decoding rejection sampling. speculative decoding rejection sampling.
""" """
sampler_output = [] sampler_output: List[CompletionSequenceGroupOutput] = []
for (seq_group, sample_result, group_prompt_logprobs, for (seq_group, sample_result, group_prompt_logprobs,
group_sample_logprobs) in zip(sampling_metadata.seq_groups, group_sample_logprobs) in zip(sampling_metadata.seq_groups,
sample_results, prompt_logprobs, sample_results, prompt_logprobs,
sample_logprobs): sample_logprobs):
seq_ids = seq_group.seq_ids seq_ids = seq_group.seq_ids
next_token_ids, parent_ids = sample_result next_token_ids, parent_ids = sample_result
seq_outputs = [] seq_outputs: List[SequenceOutput] = []
for parent_id, next_token_id, logprobs in zip(parent_ids, for parent_id, next_token_id, logprobs in zip(parent_ids,
next_token_ids, next_token_ids,
group_sample_logprobs): group_sample_logprobs):
......
from abc import abstractmethod
from typing import Optional, Union

import torch
import torch.jit
import torch.nn as nn
class SpecDecodeBaseSampler(nn.Module):
    """Base class for samplers used for Speculative Decoding verification
    step.

    Subclasses implement ``forward`` to decide which draft tokens are
    accepted; this class provides the shared output formatting
    (``_create_output``), the running acceptance counters, and the optional
    strict-mode input validation helpers.
    """

    def __init__(self,
                 disable_bonus_tokens: bool = True,
                 strict_mode: bool = False):
        """Base class constructor.

        Args:
            disable_bonus_tokens: Whether or not to disable the bonus token.
                Required when bonus tokens will cause corrupt KV cache for
                proposal methods that require KV cache.
            strict_mode: Whether or not to perform shape/device/dtype checks
                during sampling. This catches correctness issues but adds
                nontrivial latency.
        """
        super().__init__()
        self._disable_bonus_tokens = disable_bonus_tokens
        self._strict_mode = strict_mode

        # NOTE: A "bonus token" is accepted iff all proposal tokens are
        # accepted. There is always only one possible bonus token. We store
        # this value in a variable for readability.
        self._num_bonus_tokens = 1

        # Running counters. The two tensors stay None until
        # init_gpu_tensors() allocates them on the target device.
        self.num_accepted_tokens: Optional[torch.Tensor] = None
        self.num_emitted_tokens: Optional[torch.Tensor] = None
        self.num_draft_tokens: int = 0

    def init_gpu_tensors(self, rank: Union[int, str, torch.device]) -> None:
        """Allocate the accepted/emitted token counters.

        Args:
            rank: An integer is interpreted as a CUDA device index
                (``cuda:{rank}``), exactly as before. A device string or
                ``torch.device`` is used as-is, which allows the counters to
                live on non-CUDA devices (e.g. ``"cpu"``).

        Must be called once, before the first sampling call.
        """
        assert self.num_accepted_tokens is None
        device = f"cuda:{rank}" if isinstance(rank, int) else rank
        self.num_accepted_tokens = torch.tensor(0,
                                                dtype=torch.long,
                                                device=device)
        self.num_emitted_tokens = torch.tensor(0,
                                               dtype=torch.long,
                                               device=device)

    @property
    def probs_dtype(self) -> torch.dtype:
        """dtype expected for probability tensors."""
        return torch.float32

    @property
    def token_id_dtype(self) -> torch.dtype:
        """dtype expected for (and produced as) token id tensors."""
        return torch.int64

    @abstractmethod
    def forward(
        self,
        target_probs: torch.Tensor,
        bonus_token_ids: torch.Tensor,
        draft_probs: torch.Tensor,
        draft_token_ids: torch.Tensor,
    ) -> torch.Tensor:
        """Verify the draft tokens; must be implemented by subclasses."""
        raise NotImplementedError

    def _create_output(
            self,
            accepted: torch.Tensor,  # [batch_size, k]
            substitute_token_ids: torch.Tensor,  # [batch_size, k]
            draft_token_ids: torch.Tensor,  # [batch_size, k]
            bonus_token_ids: torch.Tensor,  # [batch_size]
    ) -> torch.Tensor:
        """Format output. Returns a matrix of token ids. When
        a token is rejected via sampling, all subsequent token ids are
        set to -1 for the sequence.

        Args:
            accepted: A boolean tensor indicating if the corresponding
                draft token in draft_token_ids should be accepted or not.
            substitute_token_ids: A tensor of token_ids that can be used
                as substitutes for the draft token ids if the proposed token
                is rejected.
            draft_token_ids: A tensor of token ids speculated by the
                draft model.
            bonus_token_ids: Token ids to use as the bonus token if
                all the draft tokens are accepted.

        Returns:
            A tensor containing the accepted token ids. The shape of the
            tensor is [batch_size, k + num_bonus_tokens]

        Raises:
            RuntimeError: If the counters were never allocated with
                ``init_gpu_tensors``.
        """
        # Fail fast with a clear message instead of the opaque
        # "NoneType += Tensor" TypeError at the end of the method.
        if (self.num_accepted_tokens is None
                or self.num_emitted_tokens is None):
            raise RuntimeError(
                "Token counters are not initialized; call "
                "init_gpu_tensors() before sampling.")

        batch_size, k = substitute_token_ids.shape
        # Only drop the trailing bonus-token dimension. A bare squeeze()
        # would also drop the batch dimension when batch_size == 1.
        bonus_token_ids = bonus_token_ids.squeeze(-1)

        # Determine the index of the first False value for each row. Rows
        # without any rejection get the limit k (everything accepted).
        rejected = accepted == 0
        limits = rejected.max(1).indices
        limits[~rejected.any(1)] = k

        # Create masks using the indices.
        indices = torch.arange(k, device=accepted.device).unsqueeze(0)
        accepted_mask = indices < limits.unsqueeze(1)
        after_false_mask = indices == limits.unsqueeze(1)

        # Create an extended output tensor, pre-filled with -1.
        output_with_bonus_tokens = torch.full(
            (batch_size, k + self._num_bonus_tokens),
            -1,
            dtype=self.token_id_dtype,
            device=accepted.device)
        # View onto the first k columns; writes go through to the full tensor.
        output = output_with_bonus_tokens[:, :k]

        # Fill in the first k columns of the output tensor using masks and
        # data tensors.
        output[:] = torch.where(accepted_mask, draft_token_ids, -1)

        # Fill the last column.
        # We check output directly as accepted may have True values
        # inconsistent with causal acceptance.
        output_with_bonus_tokens[:, -1] = torch.where(output[:, -1] != -1,
                                                      bonus_token_ids, -1)

        # We disable bonus tokens because it causes corrupt KV cache for
        # proposal methods that require KV cache. We can fix it by
        # "prefilling" the bonus token in the proposer. The following issue
        # tracks the fix.
        # https://github.com/vllm-project/vllm/issues/4212
        if self._disable_bonus_tokens:
            output_with_bonus_tokens[:, -1] = -1

        # Fill the recovered token ids: at the first rejected position the
        # -1 placeholder is replaced by the substitute token.
        output.mul_(~after_false_mask).add_(
            substitute_token_ids.mul(after_false_mask))

        self.num_accepted_tokens += accepted.sum()
        self.num_emitted_tokens += (output_with_bonus_tokens != -1).sum()
        self.num_draft_tokens += batch_size * k

        return output_with_bonus_tokens

    def _raise_if_incorrect_input(
        self,
        target_probs: torch.Tensor,
        draft_token_ids: torch.Tensor,
        bonus_token_ids: torch.Tensor,
        draft_probs: Optional[torch.Tensor] = None,
    ) -> None:
        """Run all shape/dtype/device/vocab-range checks on the inputs.

        Each check raises AssertionError on failure.
        """
        self._raise_if_incorrect_shape(target_probs, draft_token_ids,
                                       bonus_token_ids, draft_probs)
        self._raise_if_incorrect_dtype(target_probs, draft_token_ids,
                                       bonus_token_ids, draft_probs)
        self._raise_if_inconsistent_device(target_probs, draft_token_ids,
                                           bonus_token_ids, draft_probs)
        self._raise_if_out_of_bounds_vocab(target_probs.shape[-1],
                                           draft_token_ids, bonus_token_ids)

    def _raise_if_incorrect_shape(
        self,
        target_probs: torch.Tensor,
        draft_token_ids: torch.Tensor,
        bonus_token_ids: torch.Tensor,
        draft_probs: Optional[torch.Tensor] = None,
    ) -> None:
        """Assert that all inputs agree on batch size, k and vocab size."""
        (target_batch_size, num_target_probs,
         target_vocab_size) = target_probs.shape

        # validate the shape of draft token ids.
        draft_token_ids_batch_size, num_draft_token_ids = draft_token_ids.shape
        assert draft_token_ids_batch_size == target_batch_size
        assert num_draft_token_ids == num_target_probs

        # validate the shape of bonus token ids
        bonus_batch_size, num_bonus_tokens = bonus_token_ids.shape
        assert bonus_batch_size == target_batch_size
        assert num_bonus_tokens == self._num_bonus_tokens

        # validate the shape of draft probs if it is set
        if draft_probs is not None:
            (draft_batch_size, num_draft_probs,
             draft_vocab_size) = draft_probs.shape
            assert draft_batch_size == target_batch_size
            assert num_draft_probs == num_target_probs
            assert (draft_vocab_size == target_vocab_size
                    ), f"{draft_vocab_size=} {target_vocab_size=}"

    def _raise_if_incorrect_dtype(
        self,
        target_probs: torch.Tensor,
        draft_token_ids: torch.Tensor,
        bonus_token_ids: torch.Tensor,
        draft_probs: Optional[torch.Tensor] = None,
    ) -> None:
        """Assert probabilities and token ids use the expected dtypes."""
        assert target_probs.dtype == self.probs_dtype
        assert draft_token_ids.dtype == self.token_id_dtype
        assert bonus_token_ids.dtype == self.token_id_dtype
        if draft_probs is not None:
            assert draft_probs.dtype == self.probs_dtype

    def _raise_if_inconsistent_device(
        self,
        target_probs: torch.Tensor,
        draft_token_ids: torch.Tensor,
        bonus_token_ids: torch.Tensor,
        draft_probs: Optional[torch.Tensor] = None,
    ) -> None:
        """Assert that every provided tensor lives on the same device."""
        devices = [
            t.device for t in
            [target_probs, bonus_token_ids, draft_probs, draft_token_ids]
            if t is not None
        ]
        assert all(devices[0] == device for device in devices)

    def _raise_if_out_of_bounds_vocab(
        self,
        vocab_size: int,
        draft_token_ids: torch.Tensor,
        bonus_token_ids: torch.Tensor,
    ) -> None:
        """Assert that all token ids fall in ``[0, vocab_size)``."""
        assert torch.all(bonus_token_ids < vocab_size)
        assert torch.all(bonus_token_ids >= 0)
        assert torch.all(draft_token_ids < vocab_size)
        assert torch.all(draft_token_ids >= 0)
import torch
import torch.jit
from vllm.model_executor.layers.spec_decode_base_sampler import (
SpecDecodeBaseSampler)
class TypicalAcceptanceSampler(SpecDecodeBaseSampler):
    """Apply typical acceptance sampling as described in section 3.3.1 in
    "MEDUSA: Simple LLM Inference Acceleration Framework with
    Multiple Decoding Heads"
    https://arxiv.org/pdf/2401.10774
    """

    def __init__(
        self,
        posterior_threshold: float,
        posterior_alpha: float,
        disable_bonus_tokens: bool = False,
        strict_mode: bool = False,
    ):
        """Create a Typical Acceptance Sampler.

        Args:
            posterior_threshold : A threshold value that sets a lower bound
                on the posterior probability of a token in target model for it
                to be accepted.
            posterior_alpha : A scaling factor for the entropy-based
                threshold in typical acceptance sampling.
            disable_bonus_tokens: Whether or not to disable the bonus token.
                Required when bonus tokens will cause corrupt KV cache for
                proposal methods that require KV cache.
            strict_mode: Whether or not to perform shape/device/dtype checks
                during sampling. This catches correctness issues but adds
                nontrivial latency.
        """
        # Initialize the nn.Module machinery before attaching attributes.
        super().__init__(disable_bonus_tokens=disable_bonus_tokens,
                         strict_mode=strict_mode)
        self._posterior_threshold = posterior_threshold
        self._posterior_alpha = posterior_alpha

    def forward(
        self,
        target_probs: torch.Tensor,
        bonus_token_ids: torch.Tensor,
        draft_probs: torch.Tensor,
        draft_token_ids: torch.Tensor,
    ) -> torch.Tensor:
        """Sample token ids using typical acceptance sampling. This accepts
        or rejects tokens proposed by the draft model using the probability
        of each token according to the target model only.

        In the worst case where all draft tokens are rejected, it is guaranteed
        one token will be emitted.

        In the case where all draft tokens are accepted, the bonus token will be
        accepted conditioned on self._disable_bonus_tokens being false.

        Args:
            target_probs: The probability distribution over token ids given
                context according to the target model.
            shape = [batch_size, num_speculative_tokens, vocab_size]

            bonus_token_ids: The "bonus" token ids that are accepted iff all
                speculative tokens in a sequence are accepted.
            shape = [batch_size, num_bonus_tokens]

            draft_probs: This parameter is unused by the acceptance sampler.

            draft_token_ids: The token ids that were sampled from the draft
                probabilities.
            shape = [batch_size, num_speculative_tokens]

        Returns:
            output_token_ids: The token ids selected via typical acceptance
                sampling, or -1 if unable to emit a token because the
                previous token was rejected.
            shape = [batch_size, num_speculative_tokens + num_bonus_tokens]
        """
        # Only perform shape/dtype/device checking in strict mode, as it adds
        # overhead.
        if self._strict_mode:
            self._raise_if_incorrect_input(target_probs, draft_token_ids,
                                           bonus_token_ids)
        accepted = self._evaluate_accepted_tokens(target_probs,
                                                  draft_token_ids)
        recovered_token_ids = self._replacement_token_ids(target_probs)
        output_token_ids = self._create_output(accepted, recovered_token_ids,
                                               draft_token_ids,
                                               bonus_token_ids)
        return output_token_ids

    def _evaluate_accepted_tokens(
            self, target_probs: torch.Tensor,
            draft_token_ids: torch.Tensor) -> torch.Tensor:
        r"""
        Evaluates and returns a mask of accepted tokens based on the
        posterior probabilities.

        Parameters:
        ----------
        target_probs : torch.Tensor
            A tensor of shape (batch_size, k, vocab_size) representing
            the probabilities of each token in the vocabulary for each
            position in the proposed sequence. This is the distribution
            generated by the target model.
        draft_token_ids : torch.Tensor
            A tensor of shape (batch_size, k) representing the proposed
            token ids.

        A draft token_id x_{n+k} is accepted if it satisfies the
        following condition

        .. math::
            p_{\text{original}}(x_{n+k} | x_1, x_2, \dots, x_{n+k-1}) >
            \min \left( \epsilon, \delta * \exp \left(
                -H(p_{\text{original}}(
                    \cdot | x_1, x_2, \ldots, x_{n+k-1})) \right) \right)

        where :math:`p_{\text{original}}` corresponds to target_probs
        and :math:`\epsilon` and :math:`\delta` correspond to hyperparameters
        specified using self._posterior_threshold and self._posterior_alpha

        This method computes the posterior probabilities for the given
        draft token ids based on the provided target probabilities. It
        calculates the entropy of the posterior distribution and determines
        a dynamic threshold for each token position using the provided
        posterior_threshold and posterior_alpha values. The method then
        returns a boolean mask indicating which tokens can be accepted.

        Returns:
        -------
        torch.Tensor
            A boolean tensor of shape (batch_size, k) where each element
            indicates whether the corresponding draft token has been accepted
            or rejected. True indicates acceptance and false indicates
            rejection.
        """
        # Target-model probability assigned to each proposed draft token.
        candidates_prob = torch.gather(
            target_probs, dim=-1,
            index=draft_token_ids.unsqueeze(-1)).squeeze(-1)
        # A small constant added to prevent computing the logarithm of zero,
        # which can lead to undefined values.
        epsilon = 1e-5
        posterior_entropy = -torch.sum(
            target_probs * torch.log(target_probs + epsilon), dim=-1)
        # Per-position threshold: min(posterior_threshold,
        # posterior_alpha * exp(-entropy)).
        threshold = torch.minimum(
            torch.full_like(posterior_entropy, self._posterior_threshold),
            torch.exp(-posterior_entropy) * self._posterior_alpha,
        )
        return candidates_prob > threshold

    def _replacement_token_ids(self,
                               target_probs: torch.Tensor) -> torch.Tensor:
        """
        Generate one replacement token ID for each sequence based on target
        probabilities. The replacement token is used as the fallback option
        if typical acceptance sampling does not accept any draft tokens for
        that particular sequence.

        This method selects, for each sequence, the token with the highest
        target probability at the first position. The rest of the output is
        filled with -1, so a replacement is only available when the very
        first draft token is rejected.

        Parameters
        ----------
        target_probs : torch.Tensor
            A tensor of shape (batch_size, k, vocab_size) containing
            the target probability distribution

        Returns
        -------
        torch.Tensor
            A tensor of shape (batch_size, k) with the replacement
            token IDs. Only the first column is set, and the rest of the
            columns are filled with -1.
        """
        batch_size, k = target_probs.shape[:2]
        output = torch.full((batch_size, k),
                            -1,
                            dtype=self.token_id_dtype,
                            device=target_probs.device)
        # Greedy (argmax) token of the target distribution at position 0.
        output[:, 0] = torch.argmax(target_probs[:, 0, :], dim=1)
        return output
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment