Commit 1591c68f authored by zhuwenwen
Browse files

merge v0.4.2

parents 09bcf00b c7f2cf2b
from typing import Any, Dict, List, Optional, Tuple from typing import Any, Dict, List, Optional, Tuple, Union
import torch import torch
from torch.nn import Module from torch.nn import Module
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
from vllm.model_executor.layers.linear import (LinearMethodBase, from vllm import _custom_ops as ops
set_weight_attrs) from vllm.logger import init_logger
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.model_executor.utils import set_weight_attrs
ACTIVATION_SCHEMES = ["static", "dynamic"]
class FP8Config(QuantizationConfig): logger = init_logger(__name__)
class Fp8Config(QuantizationConfig):
"""Config class for FP8.""" """Config class for FP8."""
def __init__(
    self,
    is_checkpoint_fp8_serialized: bool = False,
    activation_scheme: str = "dynamic",
) -> None:
    """Record how the checkpoint is stored and how activations are scaled.

    Args:
        is_checkpoint_fp8_serialized: Whether the checkpoint already holds
            FP8 weights. A warning is logged when this is true.
        activation_scheme: Must be one of ``ACTIVATION_SCHEMES``.

    Raises:
        ValueError: If ``activation_scheme`` is not a known scheme.
    """
    self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
    if self.is_checkpoint_fp8_serialized:
        logger.warning("Detected fp8 checkpoint. Please note that the "
                       "format is experimental and subject to change.")

    scheme_is_known = activation_scheme in ACTIVATION_SCHEMES
    if not scheme_is_known:
        raise ValueError(
            f"Unsupported activation scheme {activation_scheme}")
    self.activation_scheme = activation_scheme
@classmethod @classmethod
def get_name(cls) -> str: def get_name(cls) -> str:
return "fp8" return "fp8"
...@@ -23,21 +43,25 @@ class FP8Config(QuantizationConfig): ...@@ -23,21 +43,25 @@ class FP8Config(QuantizationConfig):
@classmethod @classmethod
def get_min_capability(cls) -> int: def get_min_capability(cls) -> int:
# TODO: PyTorch 2.3.0+ is required to run FP8 on return 89
# SM 89 (e.g. Ada) GPUs. Specifically, this PR has to
# be included: https://github.com/pytorch/pytorch/pull/118881
return 90
@classmethod @classmethod
def get_config_filenames(cls) -> List[str]: def get_config_filenames(cls) -> List[str]:
return [] return []
@classmethod @classmethod
def from_config(cls, config: Dict[str, Any]) -> "FP8Config": def from_config(cls, config: Dict[str, Any]) -> "Fp8Config":
return cls() quant_method = cls.get_from_keys(config, ["quant_method"])
is_checkpoint_fp8_serialized = ("fp8" in quant_method)
activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
return cls(is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
activation_scheme=activation_scheme)
def get_linear_method(self) -> "Fp8LinearMethod": def get_quant_method(
return Fp8LinearMethod(self) self, layer: torch.nn.Module) -> Optional["Fp8LinearMethod"]:
if isinstance(layer, LinearBase):
return Fp8LinearMethod(self)
return None
def get_scaled_act_names(self) -> List[str]: def get_scaled_act_names(self) -> List[str]:
return [] return []
...@@ -45,8 +69,12 @@ class FP8Config(QuantizationConfig): ...@@ -45,8 +69,12 @@ class FP8Config(QuantizationConfig):
class Fp8LinearMethod(LinearMethodBase): class Fp8LinearMethod(LinearMethodBase):
"""Linear method for FP8. """Linear method for FP8.
We now support common FP16/BF16 model checkpoints ONLY. The weight Supports loading FP8 checkpoints with static weight scale and
scaling factor will be initialized after the model weights are loaded. dynamic/static activation scale.
Also supports loading quantized FP16/BF16 model checkpoints with dynamic
activation scaling. The weight scaling factor will be initialized after
the model weights are loaded.
Limitations: Limitations:
1. Only support per-tensor quantization due to torch._scaled_mm support. 1. Only support per-tensor quantization due to torch._scaled_mm support.
...@@ -57,9 +85,27 @@ class Fp8LinearMethod(LinearMethodBase): ...@@ -57,9 +85,27 @@ class Fp8LinearMethod(LinearMethodBase):
quant_config: The quantization config. quant_config: The quantization config.
""" """
def __init__(self, quant_config: FP8Config): def __init__(self, quant_config: Fp8Config):
self.quant_config = quant_config self.quant_config = quant_config
def _create_scale_param(
    self,
    scale_name: str,
    layer: torch.nn.Module,
    output_partition_sizes: List[int],
    **extra_weight_attrs,
) -> None:
    """Register a float32 scale parameter on ``layer``.

    The parameter has one (uninitialized) entry per element of
    ``output_partition_sizes`` and is registered under ``scale_name``.
    Besides the caller-supplied attributes it is tagged with
    ``fp8_scales_shard_indexer`` pointing at ``self.scales_shard_indexer``.
    """
    num_scales = len(output_partition_sizes)
    scale = Parameter(torch.empty(num_scales, dtype=torch.float32),
                      requires_grad=False)
    layer.register_parameter(scale_name, scale)

    # Start from the caller's attributes, then add/override the indexer.
    attrs = dict(extra_weight_attrs)
    attrs["fp8_scales_shard_indexer"] = self.scales_shard_indexer
    set_weight_attrs(scale, attrs)
def create_weights( def create_weights(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
...@@ -70,70 +116,150 @@ class Fp8LinearMethod(LinearMethodBase): ...@@ -70,70 +116,150 @@ class Fp8LinearMethod(LinearMethodBase):
params_dtype: torch.dtype, params_dtype: torch.dtype,
**extra_weight_attrs, **extra_weight_attrs,
): ):
del input_size, output_size
output_size_per_partition = sum(output_partition_sizes) output_size_per_partition = sum(output_partition_sizes)
layer.process_after_load = True
layer.logical_widths = output_partition_sizes
# WEIGHT
weight_dtype = (torch.float8_e4m3fn
if self.quant_config.is_checkpoint_fp8_serialized else
params_dtype)
weight = Parameter(torch.empty(output_size_per_partition, weight = Parameter(torch.empty(output_size_per_partition,
input_size_per_partition, input_size_per_partition,
dtype=params_dtype), dtype=weight_dtype),
requires_grad=False) requires_grad=False)
layer.register_parameter("weight", weight) layer.register_parameter("weight", weight)
set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) set_weight_attrs(weight, {
set_weight_attrs(weight, extra_weight_attrs) **extra_weight_attrs,
"input_dim": 1,
"output_dim": 0,
})
w_scale = Parameter( # If checkpoint is serialized fp8, load them.
torch.empty(1, dtype=torch.float32), # Otherwise, wait until process_weights_after_loading.
requires_grad=False, if self.quant_config.is_checkpoint_fp8_serialized:
) # WEIGHT SCALE
layer.register_parameter("weight_scaling_factor", w_scale) self._create_scale_param(
scale_name="weight_scale",
layer=layer,
output_partition_sizes=output_partition_sizes,
**extra_weight_attrs)
# ACTIVATION SCALE
if self.quant_config.activation_scheme == "static":
self._create_scale_param(
scale_name="act_scale",
layer=layer,
output_partition_sizes=output_partition_sizes,
**extra_weight_attrs)
def scales_shard_indexer(
        self, param: torch.Tensor, loaded_weight: torch.Tensor,
        shard_id: Union[str, int]) -> Tuple[torch.Tensor, torch.Tensor]:
    """Pick the slot of ``param`` that a loaded scale belongs to.

    Args:
        param: The per-shard scale tensor to index into.
        loaded_weight: The scale read from the checkpoint; returned as-is.
        shard_id: Either an integer index, or one of ``"q"``, ``"k"``,
            ``"v"`` which map to indices 0, 1 and 2.

    Returns:
        ``(param[shard_id], loaded_weight)`` with ``shard_id`` resolved to
        an integer.

    Raises:
        ValueError: If ``shard_id`` is an unknown string, or is neither an
            ``int`` nor a ``str``.
    """
    qkv_idxs = {"q": 0, "k": 1, "v": 2}

    if isinstance(shard_id, str):
        if shard_id not in qkv_idxs:
            raise ValueError(f"Unknown shard_id: {shard_id}")
        shard_id = qkv_idxs[shard_id]
    elif not isinstance(shard_id, int):
        # The exception used to be constructed but never raised, so an
        # invalid id fell through to the indexing below.
        raise ValueError(
            f"Shard id must be int or str but got {type(shard_id)}")

    return param[shard_id], loaded_weight
def process_weights_after_loading(self, layer: Module) -> None: def process_weights_after_loading(self, layer: Module) -> None:
# Although the linear_method is propagated to all layers, if (not hasattr(layer, "process_after_load")
# only linear layers invoke "create_weights". So we check or not layer.process_after_load):
# whether "weight_scaling_facor" is registered to determine return
# whether the layer is a linear layer that requires quantization.
if not hasattr(layer, "weight_scaling_factor"): # If checkpoint is fp/bf16 (not serialized fp8), quantize the weights.
if not self.quant_config.is_checkpoint_fp8_serialized:
qweight, weight_scale = ops.scaled_fp8_quant(layer.weight,
scale=None)
layer.weight = Parameter(qweight.t(), requires_grad=False)
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
layer.logical_widths = None
layer.act_scale = None
return return
qweight, weight_scale = per_tensor_quantize(layer.weight) # If checkpoint is fp8, requantize the separately quantized logical
# torch._scaled_mm requires column-major in the second # weights into a single fp8 weight with a single weight scale.
# input (weight), so we transpose the quantized weight. else:
layer.weight = Parameter(qweight.t(), requires_grad=False) # WEIGHT_SCALE / WEIGHT
layer.weight_scaling_factor.data.copy_(weight_scale) # Loop over logical weights, requantizing with single scale.
max_w_scale = layer.weight_scale.max()
def apply_weights(self, start = 0
layer: torch.nn.Module, for idx, logical_width in enumerate(layer.logical_widths):
x: torch.Tensor, end = start + logical_width
bias: Optional[torch.Tensor] = None) -> torch.Tensor: weight_dq = per_tensor_dequantize(layer.weight[start:end, :],
qinput, x_scale = per_tensor_quantize(x) layer.weight_scale[idx])
layer.weight[start:end, :] = per_tensor_quantize(
weight_dq, layer.weight_scale.max())
start = end
layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
# WEIGHT
# Transpose weight for passing to torch._scaled_mm
weight = layer.weight
layer.weight = Parameter(weight.t(), requires_grad=False)
# ACT_SCALE
# Dynamic: set to None (required input to ops.scaled_fp8_quant).
# Static: set to max of the act_scales (since they are equal).
if self.quant_config.activation_scheme == "dynamic":
layer.act_scale = None
elif self.quant_config.activation_scheme == "static":
if not all_close_1d(layer.act_scale):
raise ValueError(
"All the act_scales for the logical weights of a layer "
f"must be equal. But got {layer.act_scale}")
layer.act_scale = Parameter(layer.act_scale.max(),
requires_grad=False)
else:
raise ValueError(
f"Unknown scheme {self.quant_config.activation_scheme}")
def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
# ops.scaled_fp8_quant supports both dynamic and static quant.
# If dynamic, layer.act_scale is None and x_scale computed from x.
# If static, layer.act_scale is scalar and x_scale set to act_scale.
qinput, x_scale = ops.scaled_fp8_quant(x, layer.act_scale)
# Fused GEMM_DQ
output, _ = torch._scaled_mm( output, _ = torch._scaled_mm(
qinput, qinput,
layer.weight, layer.weight,
out_dtype=x.dtype, out_dtype=x.dtype,
scale_a=x_scale, scale_a=x_scale,
scale_b=layer.weight_scaling_factor, scale_b=layer.weight_scale,
bias=bias, bias=bias,
) )
return output return output
def per_tensor_quantize(tensor: torch.Tensor) -> Tuple[torch.Tensor, float]: def all_close_1d(x: torch.Tensor) -> bool:
"""Quantize a tensor using per-tensor static scaling factor. assert len(x.shape) == 1
return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0]))
Args:
tensor: The input tensor. def per_tensor_quantize(tensor: torch.Tensor,
""" inv_scale: float) -> torch.Tensor:
finfo = torch.finfo(torch.float8_e4m3fn) finfo = torch.finfo(torch.float8_e4m3fn)
# Calculate the scale as dtype max divided by absmax. qweight = (tensor / inv_scale).clamp(min=finfo.min, max=finfo.max)
# Since .abs() creates a new tensor, we use aminmax to get return qweight.to(torch.float8_e4m3fn)
# the min and max first and then calculate the absmax.
min_val, max_val = tensor.aminmax()
amax = min_val.abs().max(max_val.abs()) def per_tensor_dequantize(tensor: torch.Tensor,
scale = finfo.max / amax.clamp(min=1e-12) inv_scale: float) -> torch.Tensor:
# scale and clamp the tensor to bring it to fake_qweight = tensor.to(torch.float16)
# the representative range of float8 data type dq_weight = fake_qweight * inv_scale
# (as default cast is unsaturated) return dq_weight
qweight = (tensor * scale).clamp(min=finfo.min, max=finfo.max)
# Return both float8 data and the inverse scale (as float),
# as both required as inputs to torch._scaled_mm
qweight = qweight.to(torch.float8_e4m3fn)
scale = scale.float().reciprocal()
return qweight, scale
...@@ -7,10 +7,10 @@ import torch ...@@ -7,10 +7,10 @@ import torch
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.layers.linear import (LinearMethodBase, from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.model_executor.utils import set_weight_attrs
class GPTQConfig(QuantizationConfig): class GPTQConfig(QuantizationConfig):
...@@ -63,8 +63,11 @@ class GPTQConfig(QuantizationConfig): ...@@ -63,8 +63,11 @@ class GPTQConfig(QuantizationConfig):
desc_act = cls.get_from_keys(config, ["desc_act"]) desc_act = cls.get_from_keys(config, ["desc_act"])
return cls(weight_bits, group_size, desc_act) return cls(weight_bits, group_size, desc_act)
def get_linear_method(self) -> "GPTQLinearMethod": def get_quant_method(
return GPTQLinearMethod(self) self, layer: torch.nn.Module) -> Optional["GPTQLinearMethod"]:
if isinstance(layer, LinearBase):
return GPTQLinearMethod(self)
return None
def get_scaled_act_names(self) -> List[str]: def get_scaled_act_names(self) -> List[str]:
return [] return []
...@@ -194,10 +197,10 @@ class GPTQLinearMethod(LinearMethodBase): ...@@ -194,10 +197,10 @@ class GPTQLinearMethod(LinearMethodBase):
layer.exllama_state = exllama_state layer.exllama_state = exllama_state
def apply_weights(self, def apply(self,
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor: bias: Optional[torch.Tensor] = None) -> torch.Tensor:
qweight = layer.qweight qweight = layer.qweight
out_shape = x.shape[:-1] + (qweight.shape[-1], ) out_shape = x.shape[:-1] + (qweight.shape[-1], )
reshaped_x = x.reshape(-1, x.shape[-1]) reshaped_x = x.reshape(-1, x.shape[-1])
......
import enum
from enum import Enum
from typing import Any, Dict, List, Optional
import torch
from torch.nn.parameter import Parameter
from vllm import _custom_ops as ops
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
GPTQ_MARLIN_TILE = 16
GPTQ_MARLIN_MIN_THREAD_N = 64
GPTQ_MARLIN_MIN_THREAD_K = 128
GPTQ_MARLIN_MAX_PARALLEL = 16
GPTQ_MARLIN_SUPPORTED_NUM_BITS = [4, 8]
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
GPTQ_MARLIN_SUPPORTED_SYM = [True]
# Permutations for Marlin scale shuffling
def get_scale_perms(num_bits):
    """Build the two index permutations used to shuffle Marlin scales.

    ``num_bits`` is accepted for interface compatibility but does not
    influence the result.

    Returns:
        A pair ``(scale_perm, scale_perm_single)``: a 64-entry permutation
        and a 32-entry permutation, both as plain lists of ints.
    """
    # Walk an 8x8 index grid column by column.
    scale_perm = [row + 8 * col for row in range(8) for col in range(8)]

    # Same stride pattern repeated for four starting offsets (0, 2, 4, 6).
    pair_offsets = (0, 1, 8, 9, 16, 17, 24, 25)
    scale_perm_single = [
        2 * start + offset for start in range(4) for offset in pair_offsets
    ]
    return scale_perm, scale_perm_single
def get_pack_factor(num_bits):
    """Return how many ``num_bits``-wide values fit in one 32-bit word.

    Asserts that ``num_bits`` is listed in
    ``GPTQ_MARLIN_SUPPORTED_NUM_BITS``.
    """
    is_supported = num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS
    assert is_supported, f"Unsupported num_bits = {num_bits}"
    word_bits = 32
    return word_bits // num_bits
def marlin_permute_scales(s, size_k, size_n, group_size, num_bits):
    """Reorder the columns of the scales tensor ``s`` for Marlin.

    The 64-entry permutation is used when scales are grouped along K
    (``group_size`` is not -1 and is smaller than ``size_k``); otherwise
    the 32-entry permutation is used. The result is reshaped to
    ``(-1, size_n)`` and made contiguous.
    """
    grouped_perm, single_perm = get_scale_perms(num_bits)

    is_grouped = group_size != -1 and group_size < size_k
    perm = grouped_perm if is_grouped else single_perm

    # View the scales in rows of len(perm) and shuffle within each row.
    shuffled = s.reshape((-1, len(perm)))[:, perm]
    return shuffled.reshape((-1, size_n)).contiguous()
class GPTQMarlinConfig(QuantizationConfig):
    """Config class for GPTQ Marlin"""

    def __init__(self, weight_bits: int, group_size: int, desc_act: bool,
                 is_sym: bool) -> None:
        """Validate and store the GPTQ Marlin quantization settings.

        Raises:
            ValueError: If ``weight_bits``, ``group_size`` or ``is_sym`` is
                not in the corresponding ``GPTQ_MARLIN_SUPPORTED_*`` list.
        """
        if desc_act and group_size == -1:
            # In this case, act_order == True is the same as act_order == False
            # (since we have only one group per output channel)
            desc_act = False

        self.weight_bits = weight_bits
        self.group_size = group_size
        self.desc_act = desc_act
        self.is_sym = is_sym

        # Verify
        if self.weight_bits not in GPTQ_MARLIN_SUPPORTED_NUM_BITS:
            raise ValueError(
                f"Marlin does not support weight_bits = {self.weight_bits}. "
                f"Only weight_bits = {GPTQ_MARLIN_SUPPORTED_NUM_BITS} "
                "are supported.")
        if self.group_size not in GPTQ_MARLIN_SUPPORTED_GROUP_SIZES:
            raise ValueError(
                f"Marlin does not support group_size = {self.group_size}. "
                f"Only group_sizes = {GPTQ_MARLIN_SUPPORTED_GROUP_SIZES} "
                "are supported.")
        if self.is_sym not in GPTQ_MARLIN_SUPPORTED_SYM:
            raise ValueError(
                f"Marlin does not support is_sym = {self.is_sym}. "
                f"Only sym = {GPTQ_MARLIN_SUPPORTED_SYM} are supported.")

        # Init
        self.pack_factor = get_pack_factor(weight_bits)
        self.tile_size = GPTQ_MARLIN_TILE
        self.min_thread_n = GPTQ_MARLIN_MIN_THREAD_N
        self.min_thread_k = GPTQ_MARLIN_MIN_THREAD_K
        self.max_parallel = GPTQ_MARLIN_MAX_PARALLEL

    def __repr__(self) -> str:
        return (f"GPTQMarlinConfig(weight_bits={self.weight_bits}, "
                f"group_size={self.group_size}, "
                f"desc_act={self.desc_act})")

    @classmethod
    def get_name(cls) -> str:
        return "gptq_marlin"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        return 80

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return ["quantize_config.json"]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "GPTQMarlinConfig":
        """Build a config from the ``bits``, ``group_size``, ``desc_act``
        and ``sym`` entries of ``config``."""
        weight_bits = cls.get_from_keys(config, ["bits"])
        group_size = cls.get_from_keys(config, ["group_size"])
        desc_act = cls.get_from_keys(config, ["desc_act"])
        is_sym = cls.get_from_keys(config, ["sym"])
        return cls(weight_bits, group_size, desc_act, is_sym)

    def get_quant_method(
            self,
            layer: torch.nn.Module) -> Optional["GPTQMarlinLinearMethod"]:
        """Return a Marlin linear method for ``LinearBase`` layers, else
        ``None``."""
        if isinstance(layer, LinearBase):
            return GPTQMarlinLinearMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        return []

    @classmethod
    def is_marlin_compatible(cls, quant_config: Dict[str, Any]):
        """Report whether a GPTQ config can be served through Marlin.

        Returns ``False`` when a required key is missing, when no CUDA
        device is available, when the device capability is below
        ``get_min_capability()``, or when the bits / group size / symmetry
        fall outside the supported lists.
        """
        # Extract data from quant config.
        num_bits = quant_config.get("bits", None)
        group_size = quant_config.get("group_size", None)
        sym = quant_config.get("sym", None)
        desc_act = quant_config.get("desc_act", None)

        # If we cannot find the info needed in the config, cannot convert.
        if (num_bits is None or group_size is None or sym is None
                or desc_act is None):
            return False

        # Without a CUDA device the capability query below raises instead
        # of answering; treat that as "cannot convert".
        if not torch.cuda.is_available():
            return False

        # If the capability of the device is too low, cannot convert.
        major, minor = torch.cuda.get_device_capability()
        device_capability = major * 10 + minor
        if device_capability < cls.get_min_capability():
            return False

        # Otherwise, can convert if model satisfies marlin constraints.
        return (num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS
                and group_size in GPTQ_MARLIN_SUPPORTED_GROUP_SIZES
                and sym in GPTQ_MARLIN_SUPPORTED_SYM)
class GPTQMarlinState(Enum):
    """Tracks whether a layer's tensors still need the one-time repack."""

    # Stored tensors have not yet been converted for the Marlin kernel.
    REPACK = enum.auto()
    # Conversion has been started/performed; tensors are used as they are.
    READY = enum.auto()
class GPTQMarlinLinearMethod(LinearMethodBase):
    """Linear method for GPTQ Marlin.

    Args:
        quant_config: The GPTQ Marlin quantization config.
    """

    def __init__(self, quant_config: GPTQMarlinConfig) -> None:
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ) -> None:
        """Validate the partition shapes and attach GPTQ tensors to ``layer``.

        Registers ``qweight``, ``g_idx``, ``g_idx_sort_indices``, ``scales``
        and ``qzeros`` as parameters, and stores the workspace, partition
        sizes, ``is_k_full`` and the initial ``REPACK`` state as plain
        attributes.

        Raises:
            ValueError: If ``params_dtype`` is not float16, or a partition
                size is not divisible by the required thread tile / group
                size.
        """
        del output_size  # not needed here

        cfg = self.quant_config
        has_groups = cfg.group_size != -1
        # group_size == -1 means a single group spanning the whole input.
        group_size = cfg.group_size if has_groups else input_size

        # --- validation -------------------------------------------------
        if params_dtype != torch.float16:
            raise ValueError(
                f"The params dtype must be float16, but got {params_dtype}")

        output_size_per_partition = sum(output_partition_sizes)
        if output_size_per_partition % cfg.min_thread_n != 0:
            raise ValueError(
                f"Weight output_size_per_partition = "
                f"{output_size_per_partition} is not divisible by "
                f" min_thread_n = {self.quant_config.min_thread_n}.")

        if input_size_per_partition % cfg.min_thread_k != 0:
            raise ValueError(
                f"Weight input_size_per_partition = "
                f"{input_size_per_partition} is not divisible "
                f"by min_thread_k = {self.quant_config.min_thread_k}.")

        if (group_size < input_size
                and input_size_per_partition % group_size != 0):
            raise ValueError(
                f"Weight input_size_per_partition = {input_size_per_partition}"
                f" is not divisible by group_size = {group_size}.")

        # --- sharding of scales / zero-points ---------------------------
        # Default: scales/zp are not sharded along the input dimension.
        scales_and_zp_size = input_size // group_size
        scales_and_zp_input_dim = None

        input_is_sharded = input_size != input_size_per_partition
        if cfg.desc_act:
            # Act-order requires real groups; K is full only when this
            # partition holds the whole input dimension.
            assert has_groups
            is_k_full = not input_is_sharded
        else:
            # Without act-order K is always treated as full; scales/zp are
            # sharded together with the input when groups are in use.
            is_k_full = True
            if input_is_sharded and has_groups:
                scales_and_zp_size = input_size_per_partition // group_size
                scales_and_zp_input_dim = 0

        # --- tensors ----------------------------------------------------
        pack_factor = cfg.pack_factor

        # Quantized weights, packed along the input dimension.
        qweight = Parameter(
            torch.empty(
                input_size_per_partition // pack_factor,
                output_size_per_partition,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            qweight,
            {
                **extra_weight_attrs,
                "input_dim": 0,
                "output_dim": 1,
                "packed_dim": 0,
                "pack_factor": pack_factor,
            },
        )

        # Activation order.
        g_idx = Parameter(
            torch.empty(input_size_per_partition, dtype=torch.int32),
            requires_grad=False,
        )
        # Ignore warning from fused linear layers such as QKVParallelLinear.
        set_weight_attrs(
            g_idx,
            {
                **extra_weight_attrs,
                "input_dim": 0,
                "ignore_warning": True,
            },
        )

        g_idx_sort_indices = Parameter(
            torch.empty(g_idx.shape, dtype=torch.int32),
            requires_grad=False,
        )
        set_weight_attrs(g_idx_sort_indices, extra_weight_attrs)

        # Scales.
        scales = Parameter(
            torch.empty(
                scales_and_zp_size,
                output_size_per_partition,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            scales,
            {
                **extra_weight_attrs,
                "input_dim": scales_and_zp_input_dim,
                "output_dim": 1,
            },
        )

        # Quantized zero-points, created on the "meta" device.
        qzeros = Parameter(
            torch.empty(
                scales_and_zp_size,
                output_size_per_partition // pack_factor,
                dtype=torch.int32,
                device="meta",
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            qzeros,
            {
                **extra_weight_attrs,
                "input_dim": scales_and_zp_input_dim,
                "output_dim": 1,
                "packed_dim": 1,
                "pack_factor": pack_factor,
            },
        )

        # Marlin workspace.
        max_workspace_size = (output_size_per_partition //
                              cfg.min_thread_n) * cfg.max_parallel
        workspace = torch.zeros(max_workspace_size,
                                dtype=torch.int,
                                requires_grad=False)

        # --- attach everything to the layer -----------------------------
        named_params = (
            ("qweight", qweight),
            ("g_idx", g_idx),
            ("g_idx_sort_indices", g_idx_sort_indices),
            ("scales", scales),
            ("qzeros", qzeros),
        )
        for param_name, param in named_params:
            layer.register_parameter(param_name, param)

        layer.workspace = workspace
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.input_size = input_size
        layer.is_k_full = is_k_full
        layer.marlin_state = GPTQMarlinState.REPACK

    def _repack_for_marlin(self, layer: torch.nn.Module) -> None:
        """Rewrite the layer's tensors in place for the Marlin kernel.

        Handles the act-order tensors, repacks ``qweight`` and permutes
        ``scales``.
        """
        part_size_n = layer.output_size_per_partition
        part_size_k = layer.input_size_per_partition
        act_order = self.quant_config.desc_act

        # Newly generated tensors need to replace existing tensors that are
        # already registered as parameters by vLLM (and won't be freed)
        def overwrite_in_place(name, fresh):
            # It is important to use resize_() here since it ensures
            # the same buffer is reused
            target = getattr(layer, name)
            target.resize_(fresh.shape)
            target.copy_(fresh)

        cur_device = layer.qweight.device

        if act_order:
            # Sort g_idx and remember the sorting permutation.
            sort_indices = torch.argsort(layer.g_idx).to(torch.int)
            sorted_g_idx = layer.g_idx[sort_indices]
            overwrite_in_place("g_idx", sorted_g_idx)
            overwrite_in_place("g_idx_sort_indices", sort_indices)
        else:
            # No act-order: replace both tensors with empty ones.
            layer.g_idx = Parameter(
                torch.empty(0, dtype=torch.int, device=cur_device),
                requires_grad=False,
            )
            layer.g_idx_sort_indices = Parameter(
                torch.empty(0, dtype=torch.int, device=cur_device),
                requires_grad=False,
            )

        # Repack weights.
        marlin_qweight = ops.gptq_marlin_repack(
            layer.qweight,
            layer.g_idx_sort_indices,
            part_size_k,
            part_size_n,
            self.quant_config.weight_bits,
        )
        overwrite_in_place("qweight", marlin_qweight)

        # Permute scales; with act-order the full K extent is used.
        scales_size_k = layer.input_size if act_order else part_size_k
        marlin_scales = marlin_permute_scales(
            layer.scales,
            scales_size_k,
            part_size_n,
            self.quant_config.group_size,
            self.quant_config.weight_bits,
        )
        overwrite_in_place("scales", marlin_scales)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the Marlin GEMM for ``x``, repacking the layer on first use.

        ``x`` is flattened to two dimensions for the kernel; the result is
        reshaped so that only the last dimension differs from ``x``. When
        ``bias`` is given it is added in place to the kernel output.
        """
        flat_x = x.reshape(-1, x.shape[-1])
        part_size_n = layer.output_size_per_partition
        part_size_k = layer.input_size_per_partition
        out_shape = x.shape[:-1] + (part_size_n, )

        if layer.marlin_state == GPTQMarlinState.REPACK:
            # The state is flipped before the repack work starts.
            layer.marlin_state = GPTQMarlinState.READY
            self._repack_for_marlin(layer)

        output = ops.gptq_marlin_gemm(
            flat_x,
            layer.qweight,
            layer.scales,
            layer.g_idx,
            layer.g_idx_sort_indices,
            layer.workspace,
            self.quant_config.weight_bits,
            flat_x.shape[0],
            part_size_n,
            part_size_k,
            layer.is_k_full,
        )

        if bias is not None:
            output.add_(bias)  # In-place add

        return output.reshape(out_shape)
...@@ -4,10 +4,10 @@ import torch ...@@ -4,10 +4,10 @@ import torch
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.layers.linear import (LinearMethodBase, from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.model_executor.utils import set_weight_attrs
class MarlinConfig(QuantizationConfig): class MarlinConfig(QuantizationConfig):
...@@ -72,8 +72,11 @@ class MarlinConfig(QuantizationConfig): ...@@ -72,8 +72,11 @@ class MarlinConfig(QuantizationConfig):
group_size = cls.get_from_keys(config, ["group_size"]) group_size = cls.get_from_keys(config, ["group_size"])
return cls(group_size) return cls(group_size)
def get_linear_method(self) -> "MarlinLinearMethod": def get_quant_method(
return MarlinLinearMethod(self) self, layer: torch.nn.Module) -> Optional["MarlinLinearMethod"]:
if isinstance(layer, LinearBase):
return MarlinLinearMethod(self)
return None
def get_scaled_act_names(self) -> List[str]: def get_scaled_act_names(self) -> List[str]:
return [] return []
...@@ -197,7 +200,7 @@ class MarlinLinearMethod(LinearMethodBase): ...@@ -197,7 +200,7 @@ class MarlinLinearMethod(LinearMethodBase):
layer.register_parameter("workspace", workspace) layer.register_parameter("workspace", workspace)
set_weight_attrs(workspace, extra_weight_attrs) set_weight_attrs(workspace, extra_weight_attrs)
def apply_weights( def apply(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
......
...@@ -4,10 +4,10 @@ import torch ...@@ -4,10 +4,10 @@ import torch
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.layers.linear import (LinearMethodBase, from vllm.model_executor.layers.linear import LinearBase
set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.utils import set_weight_attrs
from vllm.utils import is_hip from vllm.utils import is_hip
...@@ -51,14 +51,17 @@ class SqueezeLLMConfig(QuantizationConfig): ...@@ -51,14 +51,17 @@ class SqueezeLLMConfig(QuantizationConfig):
weight_bits = cls.get_from_keys(config, ["wbits"]) weight_bits = cls.get_from_keys(config, ["wbits"])
return cls(weight_bits) return cls(weight_bits)
def get_linear_method(self) -> "SqueezeLLMLinearMethod": def get_quant_method(
return SqueezeLLMLinearMethod(self) self, layer: torch.nn.Module) -> Optional[QuantizeMethodBase]:
if isinstance(layer, LinearBase):
return SqueezeLLMLinearMethod(self)
return None
def get_scaled_act_names(self) -> List[str]: def get_scaled_act_names(self) -> List[str]:
return [] return []
class SqueezeLLMLinearMethod(LinearMethodBase): class SqueezeLLMLinearMethod(QuantizeMethodBase):
"""Linear method for SqueezeLLM. """Linear method for SqueezeLLM.
Args: Args:
...@@ -112,10 +115,10 @@ class SqueezeLLMLinearMethod(LinearMethodBase): ...@@ -112,10 +115,10 @@ class SqueezeLLMLinearMethod(LinearMethodBase):
layer.register_parameter("lookup_table", lookup_table) layer.register_parameter("lookup_table", lookup_table)
set_weight_attrs(lookup_table, extra_weight_attrs) set_weight_attrs(lookup_table, extra_weight_attrs)
def apply_weights(self, def apply(self,
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor: bias: Optional[torch.Tensor] = None) -> torch.Tensor:
qweight = layer.qweight qweight = layer.qweight
lookup_table = layer.lookup_table lookup_table = layer.lookup_table
out_shape = x.shape[:-1] + (qweight.shape[-1], ) out_shape = x.shape[:-1] + (qweight.shape[-1], )
......
...@@ -156,6 +156,12 @@ class RotaryEmbedding(nn.Module): ...@@ -156,6 +156,12 @@ class RotaryEmbedding(nn.Module):
self.cos_sin_cache, self.is_neox_style) self.cos_sin_cache, self.is_neox_style)
return query, key return query, key
def extra_repr(self) -> str:
    """Return a one-line summary of the embedding's configuration."""
    fragments = (
        f"head_size={self.head_size}, rotary_dim={self.rotary_dim}",
        f", max_position_embeddings={self.max_position_embeddings}",
        f", base={self.base}, is_neox_style={self.is_neox_style}",
    )
    return "".join(fragments)
class LinearScalingRotaryEmbedding(RotaryEmbedding): class LinearScalingRotaryEmbedding(RotaryEmbedding):
"""RotaryEmbedding extended with linear scaling. """RotaryEmbedding extended with linear scaling.
...@@ -338,6 +344,114 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding): ...@@ -338,6 +344,114 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
return cache return cache
class Phi3SuScaledRotaryEmbedding(nn.Module):
"""Phi3 family of models scaled rotary embedding.
Based on the original RotaryEmbedding implementation.
"""
def __init__(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
original_max_position_embeddings: int,
base: int,
is_neox_style: bool,
short_factor: List[float],
long_factor: List[float],
short_mscale: float = 1.1,
long_mscale: float = 1.225,
):
super().__init__()
if rotary_dim != head_size:
raise ValueError(
f"`Phi3SuScaledRotaryEmbedding` does not support rotary_dim != \
head_size ({rotary_dim}!={head_size}).")
if is_neox_style is False:
raise ValueError(
"`Phi3SuScaledRotaryEmbedding` only supports neox_style.")
self.head_size = head_size
self.max_position_embeddings = max_position_embeddings
self.original_max_position_embeddings = original_max_position_embeddings
self.base = base
self.short_factor = short_factor
self.long_factor = long_factor
self.short_mscale = short_mscale
self.long_mscale = long_mscale
short_cache = self._compute_cos_sin_cache(
original_max_position_embeddings, short_factor, short_mscale)
short_cache = short_cache.to(torch.get_default_dtype())
self.register_buffer("short_cos_sin_cache",
short_cache,
persistent=False)
long_cache = self._compute_cos_sin_cache(max_position_embeddings,
long_factor, long_mscale)
long_cache = long_cache.to(torch.get_default_dtype())
self.register_buffer("long_cos_sin_cache",
long_cache,
persistent=False)
long_short_cache = torch.cat(
[self.short_cos_sin_cache, self.long_cos_sin_cache], dim=0)
self.register_buffer("long_short_cos_sin_cache",
long_short_cache,
persistent=False)
def _compute_inv_freq(self, rescale_factors: List[float]) -> torch.Tensor:
rescale_factors = torch.tensor(rescale_factors, dtype=torch.float32)
inv_freq = 1.0 / (rescale_factors * (self.base**(torch.arange(
0, self.head_size, 2, dtype=torch.float) / self.head_size)))
return inv_freq
def _compute_cos_sin_cache(
self,
max_position_embeddings: int,
rescale_factors: List[float],
mscale: float,
) -> torch.Tensor:
inv_freq = self._compute_inv_freq(rescale_factors)
t = torch.arange(max_position_embeddings, dtype=torch.float)
freqs = torch.einsum("i,j -> ij", t, inv_freq)
cos = freqs.cos() * mscale
sin = freqs.sin() * mscale
cache = torch.cat((cos, sin), dim=-1)
return cache
def forward(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply the short- or long-factor rotary embedding to query and key.

    Args:
        positions: Token positions. Used as the index of
            ``torch.index_select``, so it must be a 1-D integer tensor.
        query: Query tensor whose last dimension is split into heads of
            ``head_size``.
        key: Key tensor whose last dimension is split into heads of
            ``head_size``.
        offsets: Optional tensor added to the cache indices.

    Returns:
        ``(query, key)`` after rotation, with the head dimensions
        flattened back into the last dimension.
    """
    query = query.view(*query.shape[:-1], -1, self.head_size)
    key = key.view(*key.shape[:-1], -1, self.head_size)

    k = self.original_max_position_embeddings
    # `long_short_cos_sin_cache` holds the short cache in rows [0, k) and
    # the long cache from row k onward. If ANY position in the batch lies
    # outside the short cache (position >= k), every index is shifted by k
    # so the whole batch reads from the long cache. The comparison must be
    # `>=`: with `>`, a maximum position equal to k got no shift and row k
    # (long-cache position 0) was read instead of the entry for position k.
    long_prompt_offset = (torch.any(positions >= k).float() *
                          torch.full_like(positions, k)).long()
    idx = torch.add(positions, long_prompt_offset)

    # Keep the cache on the same device as the indices.
    self.long_short_cos_sin_cache: torch.Tensor = (
        self.long_short_cos_sin_cache.to(idx.device))
    if offsets is not None:
        idx = torch.add(idx, offsets)
    cos_sin = torch.index_select(self.long_short_cos_sin_cache, 0, idx)

    # Each cache row is [cos | sin]; duplicate each half so it covers the
    # full head dimension, and add an axis that broadcasts over heads.
    cos, sin = cos_sin.chunk(2, dim=-1)
    cos = cos.repeat(1, 2).unsqueeze(-2)
    sin = sin.repeat(1, 2).unsqueeze(-2)

    query = query * cos + _rotate_neox(query) * sin
    key = key * cos + _rotate_neox(key) * sin

    return query.flatten(-2), key.flatten(-2)
_ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {} _ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {}
...@@ -349,17 +463,26 @@ def get_rope( ...@@ -349,17 +463,26 @@ def get_rope(
is_neox_style: bool = True, is_neox_style: bool = True,
rope_scaling: Optional[Dict[str, Any]] = None, rope_scaling: Optional[Dict[str, Any]] = None,
) -> RotaryEmbedding: ) -> RotaryEmbedding:
if rope_scaling is not None:
# Transforms every value that is a list into a tuple for caching calls
rope_scaling_tuple = {
k: tuple(v) if isinstance(v, list) else v
for k, v in rope_scaling.items()
}
rope_scaling_args = tuple(rope_scaling_tuple.items())
else:
rope_scaling_args = None
key = (head_size, rotary_dim, max_position, base, is_neox_style, key = (head_size, rotary_dim, max_position, base, is_neox_style,
tuple(rope_scaling.items()) if rope_scaling is not None else None) rope_scaling_args)
if key in _ROPE_DICT: if key in _ROPE_DICT:
return _ROPE_DICT[key] return _ROPE_DICT[key]
if rope_scaling is None: if rope_scaling is None:
rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base, rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base,
is_neox_style) is_neox_style)
else: else:
scaling_type = rope_scaling["type"] scaling_type = rope_scaling["type"]
scaling_factor = rope_scaling["factor"] if scaling_type != "su":
scaling_factor = rope_scaling["factor"]
if scaling_type == "linear": if scaling_type == "linear":
rotary_emb = LinearScalingRotaryEmbedding(head_size, rotary_dim, rotary_emb = LinearScalingRotaryEmbedding(head_size, rotary_dim,
max_position, base, max_position, base,
...@@ -383,6 +506,19 @@ def get_rope( ...@@ -383,6 +506,19 @@ def get_rope(
base, is_neox_style, base, is_neox_style,
scaling_factor, scaling_factor,
**extra_kwargs) **extra_kwargs)
elif scaling_type == "su":
short_factor = rope_scaling["short_factor"]
long_factor = rope_scaling["long_factor"]
original_max_position = rope_scaling[
"original_max_position_embeddings"]
extra_kwargs = {
k: v
for k, v in rope_scaling.items()
if k in ("short_mscale", "long_mscale")
}
rotary_emb = Phi3SuScaledRotaryEmbedding(
head_size, rotary_dim, max_position, original_max_position,
base, is_neox_style, short_factor, long_factor, **extra_kwargs)
else: else:
raise ValueError(f"Unknown RoPE scaling type {scaling_type}") raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
_ROPE_DICT[key] = rotary_emb _ROPE_DICT[key] = rotary_emb
......
...@@ -7,11 +7,14 @@ import torch.nn as nn ...@@ -7,11 +7,14 @@ import torch.nn as nn
from vllm.model_executor.layers.ops.sample import sample as sample_triton from vllm.model_executor.layers.ops.sample import sample as sample_triton
from vllm.model_executor.sampling_metadata import (SamplingMetadata, from vllm.model_executor.sampling_metadata import (SamplingMetadata,
SamplingTensors) SamplingTensors,
from vllm.sampling_params import SamplingParams, SamplingType SequenceGroupToSample)
from vllm.sampling_params import SamplingType
from vllm.sequence import (Logprob, PromptLogprobs, SampleLogprobs, from vllm.sequence import (Logprob, PromptLogprobs, SampleLogprobs,
SamplerOutput, SequenceData, SequenceGroupOutput, SamplerOutput, SequenceGroupOutput, SequenceOutput)
SequenceOutput)
# (num_token_ids, num_parent_ids) per sequence group.
SampleResultType = List[Tuple[List[int], List[int]]]
class Sampler(nn.Module): class Sampler(nn.Module):
...@@ -48,11 +51,14 @@ class Sampler(nn.Module): ...@@ -48,11 +51,14 @@ class Sampler(nn.Module):
logits: torch.Tensor, logits: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]: ) -> Optional[SamplerOutput]:
"""
Args:
logits: (num_tokens, vocab_size).
sampling_metadata: Metadata for sampling.
"""
assert logits is not None assert logits is not None
_, vocab_size = logits.shape _, vocab_size = logits.shape
# Apply min_tokens penalty which sets stop tokens to -inf if min_tokens
# have not been generated yet
logits = _apply_min_tokens_penalty(logits, sampling_metadata) logits = _apply_min_tokens_penalty(logits, sampling_metadata)
# Prepare sampling tensors with pinned memory to avoid blocking. # Prepare sampling tensors with pinned memory to avoid blocking.
...@@ -83,7 +89,6 @@ class Sampler(nn.Module): ...@@ -83,7 +89,6 @@ class Sampler(nn.Module):
# Compute the probabilities. # Compute the probabilities.
probs = torch.softmax(logits, dim=-1, dtype=torch.float) probs = torch.softmax(logits, dim=-1, dtype=torch.float)
# Compute the log probabilities. # Compute the log probabilities.
# Use log_softmax to ensure numerical stability.
logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float) logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
# Sample the next tokens. # Sample the next tokens.
...@@ -98,8 +103,7 @@ class Sampler(nn.Module): ...@@ -98,8 +103,7 @@ class Sampler(nn.Module):
if self.include_gpu_probs_tensor: if self.include_gpu_probs_tensor:
assert maybe_sampled_tokens_tensor is not None assert maybe_sampled_tokens_tensor is not None
sampled_tokens_tensor = maybe_sampled_tokens_tensor on_device_tensors = (probs, logprobs, maybe_sampled_tokens_tensor)
on_device_tensors = (probs, sampled_tokens_tensor)
else: else:
on_device_tensors = None on_device_tensors = None
...@@ -149,46 +153,46 @@ def _apply_min_tokens_penalty( ...@@ -149,46 +153,46 @@ def _apply_min_tokens_penalty(
logits: torch.Tensor, logits: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
"""Apply min_tokens penalty which sets stop tokens to -inf if min_tokens
have not been generated yet
"""
# list of indices in logits that will be set to -inf # list of indices in logits that will be set to -inf
logits_to_penalize = [] logits_to_penalize: List[Tuple[int, int]] = []
start_idx = 0 logits_applied = 0
for i, seq_group in enumerate(sampling_metadata.seq_groups): for seq_group in sampling_metadata.seq_groups:
seq_ids, sampling_params = seq_group seq_ids = seq_group.seq_ids
sampling_params = seq_group.sampling_params
# handle prompt_logprobs by skipping rows in logits added for the prompt
# tokens (prompt logprobs are not penalized) sample_indices = seq_group.sample_indices
if (i < sampling_metadata.num_prompts logits_applied += len(sample_indices) + len(
and sampling_params.prompt_logprobs is not None): seq_group.prompt_logprob_indices)
assert len(seq_ids) == 1 if not seq_group.do_sample:
start_idx += sampling_metadata.prompt_lens[i] - 1 continue
start_idx = sample_indices[0]
min_tokens = sampling_params.min_tokens min_tokens = sampling_params.min_tokens
if min_tokens > 0: token_ids_to_penalize = sampling_params.all_stop_token_ids
if min_tokens > 0 and token_ids_to_penalize:
seqs_to_penalize = [] seqs_to_penalize = []
for i, seq_id in enumerate(seq_ids): for j, seq_id in enumerate(seq_ids):
seq_data = sampling_metadata.seq_data[seq_id] seq_data = seq_group.seq_data[seq_id]
if len(seq_data.output_token_ids) < min_tokens: if len(seq_data.output_token_ids) < min_tokens:
seqs_to_penalize.append(i) seqs_to_penalize.append(j)
if seqs_to_penalize: if seqs_to_penalize:
# convert to the index into logits # convert to the index into logits
seqs_to_penalize = [start_idx + i for i in seqs_to_penalize] seqs_to_penalize = [start_idx + j for j in seqs_to_penalize]
# use set() to remove any duplicates
token_ids_to_penalize = set(sampling_params.stop_token_ids +
[sampling_params.eos_token_id])
# itertools.product pairs each seq index with every token id # itertools.product pairs each seq index with every token id
logits_to_penalize.extend( logits_to_penalize.extend(
itertools.product(seqs_to_penalize, token_ids_to_penalize)) itertools.product(seqs_to_penalize, token_ids_to_penalize))
start_idx += len(seq_ids)
if logits_to_penalize: if logits_to_penalize:
# use zip and * to group indices along each dimension # use zip and * to group indices along each dimension
# eg. [ (1,2), (1,3), (5,6) ] -> ( (1,1,5), (2,3,6) ) # eg. [ (1,2), (1,3), (5,6) ] -> ( (1,1,5), (2,3,6) )
logits[tuple(zip(*logits_to_penalize))] = -float("inf") logits[tuple(zip(*logits_to_penalize))] = -float("inf")
# verifies that no rows in logits were missed unexpectedly # verifies that no rows in logits were missed unexpectedly
assert start_idx == logits.shape[0] assert logits_applied == logits.shape[0]
return logits return logits
...@@ -265,14 +269,30 @@ def _apply_min_p( ...@@ -265,14 +269,30 @@ def _apply_min_p(
def _greedy_sample( def _greedy_sample(
selected_seq_groups: List[Tuple[List[int], SamplingParams]], selected_seq_groups: List[SequenceGroupToSample],
samples: torch.Tensor, samples: torch.Tensor,
) -> List[Tuple[List[int], List[int]]]: ) -> SampleResultType:
"""Run greedy sampling on a given samples.
Args:
selected_seq_groups: A list of sequence groups batched.
samples: (num_selected_samples,) A tensor of samples. The length of
samples could be smaller than selected_seq_groups if
seq_group.do_sample is False.
Returns:
Tuple of (next_token_ids, parent_ids). The length of returned list is
same as the length of selected_seq_groups. If the corresponding
seq_group has do_sample=False, tuple contains ([], [])
"""
samples = samples.tolist() samples = samples.tolist()
sample_idx = 0 sample_idx = 0
results = [] results: SampleResultType = []
for seq_group in selected_seq_groups: for seq_group in selected_seq_groups:
seq_ids, _ = seq_group if not seq_group.do_sample:
results.append(([], []))
continue
seq_ids = seq_group.seq_ids
num_parent_seqs = len(seq_ids) num_parent_seqs = len(seq_ids)
assert num_parent_seqs == 1, ( assert num_parent_seqs == 1, (
"Greedy sampling should have only one seq.") "Greedy sampling should have only one seq.")
...@@ -284,16 +304,33 @@ def _greedy_sample( ...@@ -284,16 +304,33 @@ def _greedy_sample(
def _random_sample( def _random_sample(
selected_seq_groups: List[Tuple[List[int], SamplingParams]], selected_seq_groups: List[SequenceGroupToSample],
is_prompts: List[bool],
random_samples: torch.Tensor, random_samples: torch.Tensor,
) -> List[Tuple[List[int], List[int]]]: ) -> SampleResultType:
"""Run random sampling on a given samples.
Args:
selected_seq_groups: A list of sequence groups batched.
random_samples: (num_selected_samples,) A tensor of samples. The
length of samples could be smaller than selected_seq_groups if
seq_group.do_sample is False.
Returns:
Tuple of (next_token_ids, parent_ids). The length of returned list is
same as the length of selected_seq_groups. If the corresponding
seq_group has do_sample=False, tuple contains ([], [])
"""
# Find the maximum best_of value of the prompt phase requests. # Find the maximum best_of value of the prompt phase requests.
random_samples = random_samples.cpu() random_samples = random_samples.cpu()
sample_idx = 0 sample_idx = 0
results = [] results: SampleResultType = []
for seq_group, is_prompt in zip(selected_seq_groups, is_prompts): for seq_group in selected_seq_groups:
seq_ids, sampling_params = seq_group if not seq_group.do_sample:
results.append(([], []))
continue
seq_ids = seq_group.seq_ids
sampling_params = seq_group.sampling_params
is_prompt = seq_group.is_prompt
num_parent_seqs = len(seq_ids) num_parent_seqs = len(seq_ids)
if is_prompt: if is_prompt:
# Prompt phase. # Prompt phase.
...@@ -311,11 +348,20 @@ def _random_sample( ...@@ -311,11 +348,20 @@ def _random_sample(
def _beam_search_sample( def _beam_search_sample(
selected_seq_groups: List[Tuple[List[int], SamplingParams]], selected_seq_groups: List[SequenceGroupToSample],
is_prompts: List[bool],
seq_data: Dict[int, SequenceData],
logprobs: torch.Tensor, logprobs: torch.Tensor,
) -> List[Tuple[List[int], List[int]]]: ) -> SampleResultType:
"""Run beam sampling on a given samples.
Args:
selected_seq_groups: A list of sequence groups batched.
logprobs: (num_selected_samples, vocab_size,) A tensor of logprob
on selected sample indices.
Returns:
Tuple of (next_token_ids, parent_ids). The length of returned list is
same as the length of selected_seq_groups. If the corresponding
seq_group has do_sample=False, tuple contains ([], [])
"""
# We sample 2 * beam_width candidates to make sure that with high # We sample 2 * beam_width candidates to make sure that with high
# probability we can get `beam_width` candidates in addition to # probability we can get `beam_width` candidates in addition to
# the finished sequences for the next iteration. See # the finished sequences for the next iteration. See
...@@ -326,9 +372,14 @@ def _beam_search_sample( ...@@ -326,9 +372,14 @@ def _beam_search_sample(
# NOTE: Beam search is not vectorized, so its speed can be slower than # NOTE: Beam search is not vectorized, so its speed can be slower than
# other sampling methods. # other sampling methods.
sample_idx = 0 sample_idx = 0
results = [] results: SampleResultType = []
for seq_group, is_prompt in zip(selected_seq_groups, is_prompts): for seq_group in selected_seq_groups:
seq_ids, sampling_params = seq_group if not seq_group.do_sample:
results.append(([], []))
continue
is_prompt = seq_group.is_prompt
seq_ids, sampling_params = seq_group.seq_ids, seq_group.sampling_params
num_parent_seqs = len(seq_ids) num_parent_seqs = len(seq_ids)
beam_width = sampling_params.best_of beam_width = sampling_params.best_of
seq_group_logprobs = logprobs[sample_idx:sample_idx + num_parent_seqs] seq_group_logprobs = logprobs[sample_idx:sample_idx + num_parent_seqs]
...@@ -342,15 +393,16 @@ def _beam_search_sample( ...@@ -342,15 +393,16 @@ def _beam_search_sample(
next_token_ids = next_token_ids.tolist() next_token_ids = next_token_ids.tolist()
else: else:
# Generation phase. # Generation phase.
cumulative_logprobs = [ cumulative_logprobs: List[int] = [
seq_data[seq_id].cumulative_logprob for seq_id in seq_ids seq_group.seq_data[seq_id].cumulative_logprob
for seq_id in seq_ids
] ]
cumulative_logprobs = torch.tensor( cumulative_logprobs_tensor = torch.tensor(
cumulative_logprobs, cumulative_logprobs,
dtype=torch.float, dtype=torch.float,
device=seq_group_logprobs.device) device=seq_group_logprobs.device)
seq_group_logprobs = (seq_group_logprobs + seq_group_logprobs = (seq_group_logprobs +
cumulative_logprobs.unsqueeze(dim=1)) cumulative_logprobs_tensor.unsqueeze(dim=1))
_, topk_ids = torch.topk(seq_group_logprobs.flatten(), _, topk_ids = torch.topk(seq_group_logprobs.flatten(),
2 * beam_width) 2 * beam_width)
topk_ids = topk_ids.tolist() topk_ids = topk_ids.tolist()
...@@ -371,8 +423,7 @@ def _beam_search_sample( ...@@ -371,8 +423,7 @@ def _beam_search_sample(
def _multinomial( def _multinomial(
probs: torch.Tensor, probs: torch.Tensor,
num_samples: int, num_samples: int,
seq_groups: Optional[List[Tuple[List[int], SamplingParams]]] = None, seq_groups: Optional[List[SequenceGroupToSample]] = None,
generators: Optional[List[torch.Generator]] = None,
) -> torch.Tensor: ) -> torch.Tensor:
if num_samples > 1: if num_samples > 1:
# This is equivalent to torch.repeat_interleaved (which also # This is equivalent to torch.repeat_interleaved (which also
...@@ -388,9 +439,11 @@ def _multinomial( ...@@ -388,9 +439,11 @@ def _multinomial(
q.exponential_() q.exponential_()
else: else:
sample_idx = 0 sample_idx = 0
for (seq_ids, _), generator in zip(seq_groups, generators): for seq_group in seq_groups:
seq_ids = seq_group.seq_ids
next_sample_idx = sample_idx + len(seq_ids) * num_samples next_sample_idx = sample_idx + len(seq_ids) * num_samples
q[sample_idx:next_sample_idx].exponential_(generator=generator) q[sample_idx:next_sample_idx].exponential_(
generator=seq_group.generator)
sample_idx = next_sample_idx sample_idx = next_sample_idx
return probs.div_(q).argmax(dim=1).view(-1, num_samples) return probs.div_(q).argmax(dim=1).view(-1, num_samples)
...@@ -401,11 +454,13 @@ def _sample_with_torch( ...@@ -401,11 +454,13 @@ def _sample_with_torch(
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
include_gpu_probs_tensor: bool, include_gpu_probs_tensor: bool,
modify_greedy_probs: bool, modify_greedy_probs: bool,
) -> Tuple[List[Tuple[List[int], List[int]]], Optional[torch.Tensor]]: ) -> Tuple[SampleResultType, Optional[torch.Tensor]]:
categorized_seq_group_ids = {t: [] for t in SamplingType} categorized_seq_group_ids: Dict[SamplingType,
List[int]] = {t: []
for t in SamplingType}
categorized_sample_indices = sampling_metadata.categorized_sample_indices categorized_sample_indices = sampling_metadata.categorized_sample_indices
for i, seq_group in enumerate(sampling_metadata.seq_groups): for i, seq_group in enumerate(sampling_metadata.seq_groups):
_, sampling_params = seq_group sampling_params = seq_group.sampling_params
sampling_type = sampling_params.sampling_type sampling_type = sampling_params.sampling_type
categorized_seq_group_ids[sampling_type].append(i) categorized_seq_group_ids[sampling_type].append(i)
...@@ -429,13 +484,11 @@ def _sample_with_torch( ...@@ -429,13 +484,11 @@ def _sample_with_torch(
num_tokens = len(sample_indices) num_tokens = len(sample_indices)
if num_tokens == 0: if num_tokens == 0:
continue continue
seq_group_ids = categorized_seq_group_ids[sampling_type]
seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_ids]
is_prompts = [i < sampling_metadata.num_prompts for i in seq_group_ids]
sample_metadata[sampling_type] = (seq_group_ids, seq_groups,
is_prompts, sample_indices)
long_sample_indices = sample_indices.long()
seq_group_id = categorized_seq_group_ids[sampling_type]
seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_id]
sample_metadata[sampling_type] = (seq_group_id, seq_groups)
long_sample_indices = sample_indices.long()
if sampling_type == SamplingType.GREEDY: if sampling_type == SamplingType.GREEDY:
greedy_samples = torch.argmax(logprobs[long_sample_indices], greedy_samples = torch.argmax(logprobs[long_sample_indices],
dim=-1) dim=-1)
...@@ -455,14 +508,13 @@ def _sample_with_torch( ...@@ -455,14 +508,13 @@ def _sample_with_torch(
elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED): elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
max_best_of_in_batch = 1 max_best_of_in_batch = 1
for seq_group, is_prompt in zip(seq_groups, is_prompts): for seq_group in seq_groups:
if is_prompt: if seq_group.is_prompt:
_, sampling_params = seq_group sampling_params = seq_group.sampling_params
max_best_of_in_batch = max(max_best_of_in_batch, max_best_of_in_batch = max(max_best_of_in_batch,
sampling_params.best_of) sampling_params.best_of)
seeded_args = {} if sampling_type == SamplingType.RANDOM else { seeded_args = {} if sampling_type == SamplingType.RANDOM else {
"seq_groups": seq_groups, "seq_groups": seq_groups,
"generators": sampling_metadata.generators,
} }
multinomial_samples[sampling_type] = _multinomial( multinomial_samples[sampling_type] = _multinomial(
...@@ -481,25 +533,22 @@ def _sample_with_torch( ...@@ -481,25 +533,22 @@ def _sample_with_torch(
# GPU<->CPU sync happens in the loop below. # GPU<->CPU sync happens in the loop below.
# This also converts the sample output to Python objects. # This also converts the sample output to Python objects.
for sampling_type in SamplingType: for sampling_type in SamplingType:
if sampling_type not in sample_metadata: if sampling_type not in sample_metadata:
continue continue
seq_group_ids, seq_groups, is_prompts, sample_indices = sample_metadata[ (seq_group_id, seq_groups) = sample_metadata[sampling_type]
sampling_type]
if sampling_type == SamplingType.GREEDY: if sampling_type == SamplingType.GREEDY:
sample_results = _greedy_sample(seq_groups, greedy_samples) sample_results = _greedy_sample(seq_groups, greedy_samples)
elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED): elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
sample_results = _random_sample(seq_groups, is_prompts, sample_results = _random_sample(seq_groups,
multinomial_samples[sampling_type]) multinomial_samples[sampling_type])
elif sampling_type == SamplingType.BEAM: elif sampling_type == SamplingType.BEAM:
sample_results = _beam_search_sample(seq_groups, is_prompts, sample_results = _beam_search_sample(seq_groups,
sampling_metadata.seq_data,
beam_search_logprobs) beam_search_logprobs)
sample_results_dict.update(zip(seq_group_ids, sample_results)) sample_results_dict.update(zip(seq_group_id, sample_results))
sample_results = [ sample_results = [
sample_results_dict[i] sample_results_dict.get(i, ([], []))
for i in range(len(sampling_metadata.seq_groups)) for i in range(len(sampling_metadata.seq_groups))
] ]
return sample_results, sampled_token_ids_tensor return sample_results, sampled_token_ids_tensor
...@@ -510,11 +559,13 @@ def _sample_with_triton_kernel( ...@@ -510,11 +559,13 @@ def _sample_with_triton_kernel(
logprobs: torch.Tensor, logprobs: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
sampling_tensors: SamplingTensors, sampling_tensors: SamplingTensors,
) -> List[Tuple[List[int], List[int]]]: ) -> SampleResultType:
categorized_seq_group_ids = {t: [] for t in SamplingType} categorized_seq_group_ids: Dict[SamplingType,
List[int]] = {t: []
for t in SamplingType}
categorized_sample_indices = sampling_metadata.categorized_sample_indices categorized_sample_indices = sampling_metadata.categorized_sample_indices
for i, seq_group in enumerate(sampling_metadata.seq_groups): for i, seq_group in enumerate(sampling_metadata.seq_groups):
_, sampling_params = seq_group sampling_params = seq_group.sampling_params
sampling_type = sampling_params.sampling_type sampling_type = sampling_params.sampling_type
categorized_seq_group_ids[sampling_type].append(i) categorized_seq_group_ids[sampling_type].append(i)
...@@ -530,17 +581,16 @@ def _sample_with_triton_kernel( ...@@ -530,17 +581,16 @@ def _sample_with_triton_kernel(
num_tokens = len(sample_indices) num_tokens = len(sample_indices)
if num_tokens == 0: if num_tokens == 0:
continue continue
seq_group_ids = categorized_seq_group_ids[sampling_type] seq_group_id = categorized_seq_group_ids[sampling_type]
seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_ids] seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_id]
is_prompts = [i < sampling_metadata.num_prompts for i in seq_group_ids] sample_metadata[sampling_type] = (seq_group_id, seq_groups,
sample_metadata[sampling_type] = (seq_group_ids, seq_groups, sample_indices,
is_prompts, sample_indices,
sampled_token_indices) sampled_token_indices)
if sampling_type in (SamplingType.GREEDY, SamplingType.RANDOM, if sampling_type in (SamplingType.GREEDY, SamplingType.RANDOM,
SamplingType.RANDOM_SEED): SamplingType.RANDOM_SEED):
for seq_group, is_prompt in zip(seq_groups, is_prompts): for seq_group in seq_groups:
if is_prompt: if seq_group.is_prompt:
_, sampling_params = seq_group sampling_params = seq_group.sampling_params
max_best_of_in_batch = max(max_best_of_in_batch, max_best_of_in_batch = max(max_best_of_in_batch,
sampling_params.best_of) sampling_params.best_of)
elif sampling_type == SamplingType.BEAM: elif sampling_type == SamplingType.BEAM:
...@@ -564,22 +614,21 @@ def _sample_with_triton_kernel( ...@@ -564,22 +614,21 @@ def _sample_with_triton_kernel(
for sampling_type in SamplingType: for sampling_type in SamplingType:
if sampling_type not in sample_metadata: if sampling_type not in sample_metadata:
continue continue
(seq_group_ids, seq_groups, is_prompts, sample_indices, (seq_group_id, seq_groups, sample_indices,
sampled_token_indices) = sample_metadata[sampling_type] sampled_token_indices) = sample_metadata[sampling_type]
if sampling_type == SamplingType.GREEDY: if sampling_type == SamplingType.GREEDY:
sample_results = _greedy_sample( sample_results = _greedy_sample(
seq_groups, sampled_tokens[sampled_token_indices][:, 0]) seq_groups, sampled_tokens[sampled_token_indices][:, 0])
elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED): elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
sample_results = _random_sample( sample_results = _random_sample(
seq_groups, is_prompts, sampled_tokens[sampled_token_indices]) seq_groups, sampled_tokens[sampled_token_indices])
elif sampling_type == SamplingType.BEAM: elif sampling_type == SamplingType.BEAM:
sample_results = _beam_search_sample(seq_groups, is_prompts, sample_results = _beam_search_sample(seq_groups,
sampling_metadata.seq_data,
beam_search_logprobs) beam_search_logprobs)
sample_results_dict.update(zip(seq_group_ids, sample_results)) sample_results_dict.update(zip(seq_group_id, sample_results))
sample_results = [ sample_results = [
sample_results_dict[i] sample_results_dict.get(i, ([], []))
for i in range(len(sampling_metadata.seq_groups)) for i in range(len(sampling_metadata.seq_groups))
] ]
return sample_results return sample_results
...@@ -589,7 +638,19 @@ def _sample( ...@@ -589,7 +638,19 @@ def _sample(
probs: torch.Tensor, logprobs: torch.Tensor, probs: torch.Tensor, logprobs: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_tensors: SamplingTensors, sampling_metadata: SamplingMetadata, sampling_tensors: SamplingTensors,
include_gpu_probs_tensor: bool, modify_greedy_probs: bool include_gpu_probs_tensor: bool, modify_greedy_probs: bool
) -> Tuple[List[Tuple[List[int], List[int]]], Optional[torch.Tensor]]: ) -> Tuple[SampleResultType, Optional[torch.Tensor]]:
"""
Args:
probs: (num_query_tokens_in_batch, num_vocab)
logprobs: (num_query_tokens_in_batch, num_vocab)
sampling_metadata: The metadata for a batch for sampling.
sampling_tensors: Tensors that include sampling related metadata.
Returns:
(next_token_ids, parent_seq_ids) for each seq group in a batch.
If sampling is skipped, it returns ([], [])
sampled_token_ids_tensor: A tensor of sampled token ids.
"""
return _sample_with_torch( return _sample_with_torch(
probs, probs,
logprobs, logprobs,
...@@ -625,57 +686,98 @@ def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor: ...@@ -625,57 +686,98 @@ def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
def _get_logprobs( def _get_logprobs(
logprobs: torch.Tensor, logprobs: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
sample_results: List[Tuple[List[int], List[int]]], sample_results: SampleResultType,
) -> Tuple[List[Optional[List[Optional[Dict[int, float]]]]], List[List[Dict[ ) -> Tuple[List[Optional[PromptLogprobs]], List[SampleLogprobs]]:
int, float]]]]: """Return sample logprobs and prompt logprobs.
# Prepare query indices
batched_logprobs_query_seq_indices: List[int] = [] The logic consists of 3 parts.
batched_logprobs_query_token_indices: List[int] = [] - Select indices to compute logprob from, ranks of token ids, and
# at least get one logprob for each token the top k token ids from logprobs.
- Compute prompt logprobs if required.
- Compute sample logprobs if required.
Args:
logprobs: (num_query_tokens_across_batch, num_vocab). Each query token's
logprob per vocab. Sequence groups' query tokens are batched in a
single flattened tensor. For example, assuming there are N
seq groups, it is sorted by prefill tokens for seq_group_1 (if
prompt logprob is enabled), decode tokens for seq_group_1 (if
sampling is required), prefill tokens for seq_group_2, ...
sampling_metadata: The sampling metadata.
sample_results: (num_seq_groups) The tuple of (next_token_ids,
parent_ids) for each sequence group. When beam search is enabled,
sample_results can contain different number of seq_ids from
sampling_metadata.seq_groups. It is because beam search creates
2 * BEAM_WIDTH number of samples (whereas there are only up to
BEAM_WIDTH number of seq_ids).
Returns:
A tuple of prompt and sample logprobs per sequence group in a batch.
"""
# The index of query token to calculate logprobs. It includes both
# prompt and sample logprob indices.
query_indices: List[int] = []
# The next token ids to get the logprob value from.
next_token_ids: List[int] = []
# The largest requested number of logprobs. We find logprobs as many as the
# largest num logprobs in this API.
largest_num_logprobs = 1 largest_num_logprobs = 1
sample_idx = 0
for i, (seq_group, sample_result) in enumerate( # Select indices to compute logprob from, ranks of token ids, and the top
zip(sampling_metadata.seq_groups, sample_results)): # k token ids from logprobs.
seq_ids, sampling_params = seq_group for (seq_group, sample_result) in zip(sampling_metadata.seq_groups,
next_token_ids, parent_ids = sample_result sample_results):
num_parent_seqs = len(seq_ids) sampling_params = seq_group.sampling_params
if (i < sampling_metadata.num_prompts
# Update indices and tokens for prompt logprobs.
if (seq_group.is_prompt
and sampling_params.prompt_logprobs is not None): and sampling_params.prompt_logprobs is not None):
largest_num_logprobs = max(largest_num_logprobs, largest_num_logprobs = max(largest_num_logprobs,
sampling_params.prompt_logprobs) sampling_params.prompt_logprobs)
prompt_len = sampling_metadata.prompt_lens[i] next_prompt_tokens = _get_next_prompt_tokens(seq_group)
prompt_tokens = sampling_metadata.seq_data[ query_indices.extend(seq_group.prompt_logprob_indices)
seq_ids[0]].prompt_token_ids next_token_ids.extend(next_prompt_tokens)
batched_logprobs_query_seq_indices.extend(
sample_idx + j for j in range(prompt_len - 1)) # Update indices and next tokens for sample logprob.
batched_logprobs_query_token_indices.extend( if seq_group.do_sample:
token_id for token_id in prompt_tokens[1:]) token_ids, parent_seq_ids = sample_result
sample_idx += prompt_len - 1 # NOTE: We cannot directly use sample_indices because
batched_logprobs_query_seq_indices.extend( # sample_indices only contain parent seq_ids of a previous step.
[sample_idx + parent_id for parent_id in parent_ids]) # The current step may have different number of seq_ids, and
batched_logprobs_query_token_indices.extend(next_token_ids) # we can obtain it from `sample_result[1]`.
if sampling_params.logprobs is not None: query_idx = seq_group.sample_indices[0]
largest_num_logprobs = max(largest_num_logprobs, query_indices.extend(
sampling_params.logprobs) [query_idx + parent_id for parent_id in parent_seq_ids])
sample_idx += num_parent_seqs next_token_ids.extend(token_ids)
assert sample_idx == logprobs.size(0)
if sampling_params.logprobs is not None:
batched_logprobs_query_seq_indices_gpu = torch.tensor( largest_num_logprobs = max(largest_num_logprobs,
batched_logprobs_query_seq_indices, device=logprobs.device) sampling_params.logprobs)
batched_logprobs_query_token_indices_gpu = torch.tensor(
batched_logprobs_query_token_indices, device=logprobs.device) assert len(next_token_ids) == len(query_indices)
# Batched query for logprobs of selected token if len(query_indices) == 0:
batched_logprobs_query_result = logprobs[[ empty_sampled_logprob: SampleLogprobs = []
batched_logprobs_query_seq_indices_gpu, empty_prompt_logprob: Optional[PromptLogprobs] = None
batched_logprobs_query_token_indices_gpu return [empty_prompt_logprob], [empty_sampled_logprob]
query_indices_gpu = torch.tensor(query_indices, device=logprobs.device)
next_token_ids_gpu = torch.tensor(next_token_ids, device=logprobs.device)
# (num_selected_query_tokens, num_logprobs). Note that query_indices can
# contain duplicates if beam search is enabled.
selected_logprobs = logprobs[[
query_indices_gpu,
next_token_ids_gpu,
]] ]]
ranks = _get_ranks(
logprobs[query_indices_gpu],
next_token_ids_gpu,
)
assert selected_logprobs.shape[0] == ranks.shape[0]
batched_ranks_query_result = _get_ranks( # Logprobs of topk tokens for a batch of sequence groups.
logprobs[batched_logprobs_query_seq_indices_gpu], # (num_query_tokens_across_batch).
batched_logprobs_query_token_indices_gpu)
# Batched query for logprobs of topk tokens
if largest_num_logprobs > 0: if largest_num_logprobs > 0:
top_logprobs, top_token_ids = torch.topk(logprobs, top_logprobs, top_token_ids = torch.topk(logprobs,
largest_num_logprobs, largest_num_logprobs,
...@@ -685,79 +787,136 @@ def _get_logprobs( ...@@ -685,79 +787,136 @@ def _get_logprobs(
else: else:
top_logprobs, top_token_ids = None, None top_logprobs, top_token_ids = None, None
batched_logprobs_query_result = batched_logprobs_query_result.cpu() selected_logprobs = selected_logprobs.cpu()
batched_ranks_query_result = batched_ranks_query_result.cpu() ranks = ranks.cpu()
# Gather results # Find prompt/sample logprobs.
result_prompt_logprobs: List[Optional[PromptLogprobs]] = [] prompt_logprobs_per_seq_group: List[Optional[PromptLogprobs]] = []
result_sample_logprobs: List[SampleLogprobs] = [] sample_logprobs_per_seq_group: List[SampleLogprobs] = []
sample_idx = 0 top_logprob_idx = 0
query_result_idx = 0 selected_logprobs_idx = 0
for i, (seq_group, sample_result) in enumerate(
zip(sampling_metadata.seq_groups, sample_results)): for seq_group, sample_result in zip(sampling_metadata.seq_groups,
seq_ids, sampling_params = seq_group sample_results):
next_token_ids, parent_ids = sample_result (prompt_logprobs, top_logprob_idx,
selected_logprobs_idx) = _get_prompt_logprob_if_needed(
seq_group, selected_logprobs, ranks, top_token_ids, top_logprobs,
selected_logprobs_idx, top_logprob_idx)
prompt_logprobs_per_seq_group.append(prompt_logprobs)
(sampled_logprobs, top_logprob_idx,
selected_logprobs_idx) = _get_sampled_logprob_if_needed(
seq_group, sample_result, selected_logprobs, ranks, top_token_ids,
top_logprobs, selected_logprobs_idx, top_logprob_idx)
sample_logprobs_per_seq_group.append(sampled_logprobs)
return prompt_logprobs_per_seq_group, sample_logprobs_per_seq_group
def _get_prompt_logprob_if_needed(
    seq_group: "SequenceGroupToSample",
    selected_logprobs: torch.Tensor,
    ranks: torch.Tensor,
    top_token_ids: Optional[torch.Tensor],
    top_logprobs: Optional[torch.Tensor],
    selected_logprobs_idx: int,
    top_logprob_idx: int,
) -> Tuple[Optional["PromptLogprobs"], int, int]:
    """Compute the prompt logprobs of a sequence group if they were requested.

    Args:
        seq_group: The sequence group being processed. Its `is_prompt` flag
            and `sampling_params.prompt_logprobs` decide whether anything is
            computed.
        selected_logprobs: Logprobs of the queried tokens, read at
            `selected_logprobs_idx`.
        ranks: Vocabulary ranks of the queried tokens, indexed like
            `selected_logprobs`.
        top_token_ids: Token ids of the top-k logprobs, one row per query
            position. May be None when no top-k was computed; it is only
            read when the requested number of prompt logprobs is > 0.
        top_logprobs: Values of the top-k logprobs, same layout as
            `top_token_ids`.
        selected_logprobs_idx: Cursor into `selected_logprobs` / `ranks`.
        top_logprob_idx: Cursor (row index) into the top-k tensors.

    Returns:
        A tuple `(prompt_logprobs, top_logprob_idx, selected_logprobs_idx)`.
        `prompt_logprobs` is None when the group is not a prompt or prompt
        logprobs were not requested; otherwise it holds one
        `{token_id: Logprob}` dict per next prompt token. Both cursors are
        advanced by one per prompt token that was consumed.
    """
    num_logprobs = seq_group.sampling_params.prompt_logprobs
    if not seq_group.is_prompt or num_logprobs is None:
        # Nothing requested: leave both cursors untouched.
        return None, top_logprob_idx, selected_logprobs_idx

    prompt_logprobs: PromptLogprobs = []
    for token_id in _get_next_prompt_tokens(seq_group):
        # Logprob and rank of the real prompt token. Plain tuples are used
        # while assembling and converted to Logprob objects at the end.
        # {token_id: (logprob, rank_from_vocab)}
        logprob_and_rank_by_token: Dict[int, Tuple[float, int]] = {
            token_id: (selected_logprobs[selected_logprobs_idx].item(),
                       ranks[selected_logprobs_idx].item())
        }

        # Add the top-k prompt logprobs together with their ranks.
        if num_logprobs > 0:
            top_ids = top_token_ids[top_logprob_idx, :num_logprobs].tolist()
            top_values = top_logprobs[top_logprob_idx, :num_logprobs].tolist()
            # The top-k values are sorted, so the rank is simply 1..k.
            logprob_and_rank_by_token.update(
                zip(top_ids, zip(top_values, range(1, num_logprobs + 1))))

        prompt_logprobs.append({
            tok: Logprob(*logprob_and_rank)
            for tok, logprob_and_rank in logprob_and_rank_by_token.items()
        })

        # + 1 to go to the next prompt token.
        top_logprob_idx += 1
        selected_logprobs_idx += 1

    return prompt_logprobs, top_logprob_idx, selected_logprobs_idx
def _get_sampled_logprob_if_needed(
    seq_group: "SequenceGroupToSample",
    sample_result: Tuple[List[int], List[int]],
    selected_logprobs: torch.Tensor,
    ranks: torch.Tensor,
    top_token_ids: Optional[torch.Tensor],
    top_logprobs: Optional[torch.Tensor],
    selected_logprobs_idx: int,
    top_logprob_idx: int,
) -> Tuple["SampleLogprobs", int, int]:
    """Compute the sample logprobs of a sequence group if it was sampled.

    Args:
        seq_group: The sequence group being processed. `do_sample` decides
            whether any logprob is produced; `sampling_params.logprobs` is
            the number of top-k entries to attach (None is treated as 0).
        sample_result: `(next_token_ids, parent_seq_ids)` for this group.
        selected_logprobs: Logprobs of the queried tokens, read at
            `selected_logprobs_idx`.
        ranks: Vocabulary ranks of the queried tokens, indexed like
            `selected_logprobs`.
        top_token_ids: Token ids of the top-k logprobs. May be None when no
            top-k was computed; it is only read when the requested number of
            logprobs is > 0.
        top_logprobs: Values of the top-k logprobs, same layout as
            `top_token_ids`.
        selected_logprobs_idx: Cursor into `selected_logprobs` / `ranks`.
        top_logprob_idx: Cursor (row index) into the top-k tensors.

    Returns:
        A tuple `(sampled_logprobs, top_logprob_idx, selected_logprobs_idx)`
        where `sampled_logprobs` holds one `{token_id: Logprob}` dict per
        sampled token (empty when the group is not sampled).
    """
    num_logprobs = seq_group.sampling_params.logprobs
    if num_logprobs is None:
        num_logprobs = 0

    sampled_logprobs: SampleLogprobs = []
    next_token_ids, parent_seq_ids = sample_result

    if seq_group.do_sample:
        assert len(next_token_ids) > 0
        for next_token_id, parent_id in zip(next_token_ids, parent_seq_ids):
            # Logprob and rank of the token that was actually sampled.
            # {token_id: (logprob, rank_from_vocab)}
            logprob_and_rank_by_token: Dict[int, Tuple[float, int]] = {
                next_token_id:
                (selected_logprobs[selected_logprobs_idx].item(),
                 ranks[selected_logprobs_idx].item())
            }
            # +1 to go to the next sampled token. Note that
            # selected_logprobs can contain duplicates unlike top_logprobs
            # when beam search is enabled.
            selected_logprobs_idx += 1

            # Second, add the top-k logprobs along with their ranks. The
            # guard is strictly positive: with 0 requested entries the slice
            # is empty anyway, and the top-k tensors may be None when no
            # top-k was computed.
            if num_logprobs > 0:
                row = top_logprob_idx + parent_id
                top_ids = top_token_ids[row, :num_logprobs].tolist()
                top_values = top_logprobs[row, :num_logprobs].tolist()
                # The top-k values are sorted, so the rank is simply 1..k.
                logprob_and_rank_by_token.update(
                    zip(top_ids, zip(top_values,
                                     range(1, num_logprobs + 1))))

            sampled_logprobs.append({
                tok: Logprob(*logprob_and_rank)
                for tok, logprob_and_rank in
                logprob_and_rank_by_token.items()
            })

        # There are len(seq_ids) number of sampled tokens for the current
        # sequence group in top_logprobs. Jump to the next seq_group.
        # NOTE(review): kept inside the do_sample branch because a group
        # that is not sampled contributes no sample rows -- confirm against
        # the row layout produced by the caller.
        top_logprob_idx += len(seq_group.seq_ids)

    return sampled_logprobs, top_logprob_idx, selected_logprobs_idx
def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor, def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor,
...@@ -805,18 +964,18 @@ def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor, ...@@ -805,18 +964,18 @@ def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor,
has implications on the overall design of the sampler, e.g. how to record has implications on the overall design of the sampler, e.g. how to record
accurate logprobs for the user, so this improvement is deferred to later. accurate logprobs for the user, so this improvement is deferred to later.
""" """
logprobs[sample_indices, :] = -float('inf') # NOTE: logprobs are not modified so they can be returned to the user.
logprobs[sample_indices, greedy_samples] = 0.0
probs[sample_indices, :] = 0 probs[sample_indices, :] = 0
probs[sample_indices, greedy_samples] = 1.0 probs[sample_indices, greedy_samples] = 1.0
def _build_sampler_output( def _build_sampler_output(
sample_results: List[Tuple[List[int], List[int]]], sample_results: SampleResultType,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
prompt_logprobs: List[Optional[PromptLogprobs]], prompt_logprobs: List[Optional[PromptLogprobs]],
sample_logprobs: List[SampleLogprobs], sample_logprobs: List[SampleLogprobs],
on_device_tensors: Optional[Tuple[torch.Tensor, torch.Tensor]], on_device_tensors: Optional[Tuple[torch.Tensor, torch.Tensor,
torch.Tensor]],
) -> SamplerOutput: ) -> SamplerOutput:
"""Construct Python objects with the output of sampling. """Construct Python objects with the output of sampling.
...@@ -832,7 +991,7 @@ def _build_sampler_output( ...@@ -832,7 +991,7 @@ def _build_sampler_output(
group_sample_logprobs) in zip(sampling_metadata.seq_groups, group_sample_logprobs) in zip(sampling_metadata.seq_groups,
sample_results, prompt_logprobs, sample_results, prompt_logprobs,
sample_logprobs): sample_logprobs):
seq_ids, _ = seq_group seq_ids = seq_group.seq_ids
next_token_ids, parent_ids = sample_result next_token_ids, parent_ids = sample_result
seq_outputs = [] seq_outputs = []
for parent_id, next_token_id, logprobs in zip(parent_ids, for parent_id, next_token_id, logprobs in zip(parent_ids,
...@@ -845,12 +1004,48 @@ def _build_sampler_output( ...@@ -845,12 +1004,48 @@ def _build_sampler_output(
# If not specified, store None values in SamplerOutput. # If not specified, store None values in SamplerOutput.
if on_device_tensors is not None: if on_device_tensors is not None:
sampled_token_probs, sampled_token_ids = on_device_tensors (sampled_token_probs, logprobs_tensor,
sampled_token_ids) = on_device_tensors
else: else:
sampled_token_probs, sampled_token_ids = (None, None) sampled_token_probs, logprobs_tensor, sampled_token_ids = (None, None,
None)
return SamplerOutput( return SamplerOutput(
outputs=sampler_output, outputs=sampler_output,
sampled_token_probs=sampled_token_probs, sampled_token_probs=sampled_token_probs,
sampled_token_ids=sampled_token_ids, sampled_token_ids=sampled_token_ids,
logprobs=logprobs_tensor,
) )
def _get_next_prompt_tokens(seq_group: "SequenceGroupToSample") -> List[int]:
    """Get the next prompt tokens whose logprobs must be computed.

    It is used to compute prompt logprobs. Each query token yields the
    logprob distribution of the token that follows it, so every query
    position needs the id of the *next* prompt token. This helper returns
    those ids for the current query window.

    This API has to be used only when the caller knows seq_group is in the
    prefill stage.

    Args:
        seq_group: A sequence group in the prefill stage. It must hold
            exactly one sequence and have a non-None `query_len`.

    Returns:
        A list of next prompt tokens to compute logprob.

    Raises:
        AssertionError: If the group is not a prompt, `query_len` is None,
            or the group holds more than one sequence.
    """
    assert seq_group.is_prompt, (
        "Caller should ensure the sequence group is in a prefill stage.")
    seq_ids = seq_group.seq_ids
    query_len = seq_group.query_len
    assert query_len is not None
    # prompt has only 1 seq id.
    assert len(seq_ids) == 1

    seq_data = seq_group.seq_data[seq_ids[0]]
    computed_len = seq_data.get_num_computed_tokens()
    prompt_tokens = seq_data.prompt_token_ids

    # +1 because we are looking for a next prompt token.
    start = computed_len + 1
    # The last query position of a fully processed prompt has no next
    # prompt token, so the window is clamped to the prompt length.
    end = min(computed_len + query_len + 1, len(prompt_tokens))
    return prompt_tokens[start:end]
...@@ -105,6 +105,14 @@ class VocabParallelEmbedding(torch.nn.Module): ...@@ -105,6 +105,14 @@ class VocabParallelEmbedding(torch.nn.Module):
output = tensor_model_parallel_all_reduce(output_parallel) output = tensor_model_parallel_all_reduce(output_parallel)
return output return output
def extra_repr(self) -> str:
    """Return the layer's key sizes as a comma-separated `name=value` string.

    Reports the per-partition embedding count (labelled `num_embeddings`),
    the embedding dimension, the original vocabulary size, the padded
    embedding count and the tensor-parallel size.
    """
    fields = (
        f"num_embeddings={self.num_embeddings_per_partition}",
        f"embedding_dim={self.embedding_dim}",
        f"org_vocab_size={self.org_vocab_size}",
        f"num_embeddings_padded={self.num_embeddings_padded}",
        f"tp_size={self.tp_size}",
    )
    return ", ".join(fields)
class ParallelLMHead(VocabParallelEmbedding): class ParallelLMHead(VocabParallelEmbedding):
"""Parallelized LM head. """Parallelized LM head.
......
...@@ -3,16 +3,19 @@ import copy ...@@ -3,16 +3,19 @@ import copy
import glob import glob
import os import os
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import (TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple, from typing import Any, Dict, Generator, List, Optional, Tuple, Type
Type)
import huggingface_hub
import torch import torch
from torch import nn from torch import nn
from vllm.config import (VLLM_USE_MODELSCOPE, DeviceConfig, LoadConfig, from vllm.config import (DeviceConfig, LoadConfig, LoadFormat, LoRAConfig,
LoadFormat, LoRAConfig, ModelConfig, ParallelConfig, ModelConfig, ParallelConfig, SchedulerConfig,
SchedulerConfig, VisionLanguageConfig) VisionLanguageConfig)
from vllm.envs import VLLM_USE_MODELSCOPE
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.model_loader.tensorizer import ( from vllm.model_executor.model_loader.tensorizer import (
TensorizerConfig, is_vllm_serialized_tensorizer, load_with_tensorizer, TensorizerConfig, is_vllm_serialized_tensorizer, load_with_tensorizer,
tensorizer_weights_iterator) tensorizer_weights_iterator)
...@@ -24,9 +27,6 @@ from vllm.model_executor.model_loader.weight_utils import ( ...@@ -24,9 +27,6 @@ from vllm.model_executor.model_loader.weight_utils import (
pt_weights_iterator, safetensors_weights_iterator) pt_weights_iterator, safetensors_weights_iterator)
from vllm.model_executor.models.llava import LlavaForConditionalGeneration from vllm.model_executor.models.llava import LlavaForConditionalGeneration
if TYPE_CHECKING:
from vllm.model_executor.layers.linear import LinearMethodBase
_VISION_MODEL_CLASSES = [ _VISION_MODEL_CLASSES = [
LlavaForConditionalGeneration, LlavaForConditionalGeneration,
] ]
...@@ -34,11 +34,10 @@ _VISION_MODEL_CLASSES = [ ...@@ -34,11 +34,10 @@ _VISION_MODEL_CLASSES = [
logger = init_logger(__name__) logger = init_logger(__name__)
def _get_linear_method( def _get_quantization_config(
model_config: ModelConfig, model_config: ModelConfig,
load_config: LoadConfig) -> Optional["LinearMethodBase"]: load_config: LoadConfig) -> Optional[QuantizationConfig]:
"""Get the (maybe quantized) linear method.""" """Get the quantization config."""
linear_method = None
if model_config.quantization is not None: if model_config.quantization is not None:
quant_config = get_quant_config(model_config, load_config) quant_config = get_quant_config(model_config, load_config)
capability = torch.cuda.get_device_capability() capability = torch.cuda.get_device_capability()
...@@ -55,6 +54,7 @@ def _get_linear_method( ...@@ -55,6 +54,7 @@ def _get_linear_method(
f"{model_config.dtype} is not supported for quantization " f"{model_config.dtype} is not supported for quantization "
f"method {model_config.quantization}. Supported dtypes: " f"method {model_config.quantization}. Supported dtypes: "
f"{supported_dtypes}") f"{supported_dtypes}")
<<<<<<< HEAD
linear_method = quant_config.get_linear_method() linear_method = quant_config.get_linear_method()
...@@ -62,6 +62,10 @@ def _get_linear_method( ...@@ -62,6 +62,10 @@ def _get_linear_method(
os.environ['LLAMA_NN'] = '0' os.environ['LLAMA_NN'] = '0'
return linear_method return linear_method
=======
return quant_config
return None
>>>>>>> v0.4.2
def _get_model_initialization_kwargs( def _get_model_initialization_kwargs(
...@@ -89,10 +93,10 @@ def _initialize_model( ...@@ -89,10 +93,10 @@ def _initialize_model(
vision_language_config: Optional[VisionLanguageConfig]) -> nn.Module: vision_language_config: Optional[VisionLanguageConfig]) -> nn.Module:
"""Initialize a model with the given configurations.""" """Initialize a model with the given configurations."""
model_class = get_model_architecture(model_config)[0] model_class = get_model_architecture(model_config)[0]
linear_method = _get_linear_method(model_config, load_config) quant_config = _get_quantization_config(model_config, load_config)
return model_class(config=model_config.hf_config, return model_class(config=model_config.hf_config,
linear_method=linear_method, quant_config=quant_config,
**_get_model_initialization_kwargs( **_get_model_initialization_kwargs(
model_class, lora_config, vision_language_config)) model_class, lora_config, vision_language_config))
...@@ -139,7 +143,9 @@ class DefaultModelLoader(BaseModelLoader): ...@@ -139,7 +143,9 @@ class DefaultModelLoader(BaseModelLoader):
model_path = snapshot_download( model_path = snapshot_download(
model_id=model, model_id=model,
cache_dir=self.load_config.download_dir, cache_dir=self.load_config.download_dir,
revision=revision) local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
revision=revision,
)
else: else:
model_path = model model_path = model
return model_path return model_path
...@@ -233,9 +239,11 @@ class DefaultModelLoader(BaseModelLoader): ...@@ -233,9 +239,11 @@ class DefaultModelLoader(BaseModelLoader):
"fall_back_to_pt_during_load", "fall_back_to_pt_during_load",
True)), ) True)), )
for _, module in model.named_modules(): for _, module in model.named_modules():
linear_method = getattr(module, "linear_method", None) quant_method = getattr(module, "quant_method", None)
if linear_method is not None: if quant_method is not None:
linear_method.process_weights_after_loading(module) quant_method.process_weights_after_loading(module)
# FIXME: Remove this after Mixtral is updated
# to use quant_method.
if hasattr(module, "process_weights_after_loading"): if hasattr(module, "process_weights_after_loading"):
module.process_weights_after_loading() module.process_weights_after_loading()
return model.eval() return model.eval()
...@@ -318,11 +326,11 @@ class TensorizerLoader(BaseModelLoader): ...@@ -318,11 +326,11 @@ class TensorizerLoader(BaseModelLoader):
with set_default_torch_dtype(model_config.dtype): with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device): with torch.device(device_config.device):
model_class = get_model_architecture(model_config)[0] model_class = get_model_architecture(model_config)[0]
linear_method = _get_linear_method(model_config, quant_config = _get_quantization_config(
self.load_config) model_config, self.load_config)
extra_kwargs = _get_model_initialization_kwargs( extra_kwargs = _get_model_initialization_kwargs(
model_class, lora_config, vision_language_config) model_class, lora_config, vision_language_config)
extra_kwargs["linear_method"] = linear_method extra_kwargs["quant_config"] = quant_config
tensorizer_config = copy.copy(self.tensorizer_config) tensorizer_config = copy.copy(self.tensorizer_config)
tensorizer_config.model_class = model_class tensorizer_config.model_class = model_class
......
...@@ -11,9 +11,11 @@ import torch ...@@ -11,9 +11,11 @@ import torch
from torch import nn from torch import nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
import vllm.envs as envs
from vllm.config import ModelConfig, ParallelConfig from vllm.config import ModelConfig, ParallelConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.linear import LinearMethodBase from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding) VocabParallelEmbedding)
...@@ -43,7 +45,7 @@ class TensorizerConfig: ...@@ -43,7 +45,7 @@ class TensorizerConfig:
str, bytes, os.PathLike, int] str, bytes, os.PathLike, int]
vllm_tensorized: bool vllm_tensorized: bool
verify_hash: Optional[bool] = False verify_hash: Optional[bool] = False
num_readers: Optional[int] = 1 num_readers: Optional[int] = None
encryption_keyfile: Optional[str] = None encryption_keyfile: Optional[str] = None
s3_access_key_id: Optional[str] = None s3_access_key_id: Optional[str] = None
s3_secret_access_key: Optional[str] = None s3_secret_access_key: Optional[str] = None
...@@ -63,7 +65,7 @@ class TensorizerConfig: ...@@ -63,7 +65,7 @@ class TensorizerConfig:
"s3_secret_access_key": self.s3_secret_access_key, "s3_secret_access_key": self.s3_secret_access_key,
"s3_endpoint": self.s3_endpoint, "s3_endpoint": self.s3_endpoint,
} }
return TensorizerArgs(**tensorizer_args) return TensorizerArgs(**tensorizer_args) # type: ignore
def verify_with_parallel_config( def verify_with_parallel_config(
self, self,
...@@ -103,7 +105,7 @@ class TensorizerArgs: ...@@ -103,7 +105,7 @@ class TensorizerArgs:
str, bytes, os.PathLike, int] str, bytes, os.PathLike, int]
vllm_tensorized: bool vllm_tensorized: bool
verify_hash: Optional[bool] = False verify_hash: Optional[bool] = False
num_readers: Optional[int] = 1 num_readers: Optional[int] = None
encryption_keyfile: Optional[str] = None encryption_keyfile: Optional[str] = None
s3_access_key_id: Optional[str] = None s3_access_key_id: Optional[str] = None
s3_secret_access_key: Optional[str] = None s3_secret_access_key: Optional[str] = None
...@@ -124,8 +126,9 @@ class TensorizerArgs: ...@@ -124,8 +126,9 @@ class TensorizerArgs:
the hashes stored in the metadata. A `HashMismatchError` will be the hashes stored in the metadata. A `HashMismatchError` will be
raised if any of the hashes do not match. raised if any of the hashes do not match.
num_readers: Controls how many threads are allowed to read concurrently num_readers: Controls how many threads are allowed to read concurrently
from the source file. Default is 1. This greatly increases from the source file. Default is `None`, which will dynamically set
performance. the number of readers based on the number of available
resources and model size. This greatly increases performance.
encryption_keyfile: File path to a binary file containing a encryption_keyfile: File path to a binary file containing a
binary key to use for decryption. `None` (the default) means binary key to use for decryption. `None` (the default) means
no decryption. See the example script in no decryption. See the example script in
...@@ -140,13 +143,10 @@ class TensorizerArgs: ...@@ -140,13 +143,10 @@ class TensorizerArgs:
def __post_init__(self): def __post_init__(self):
self.file_obj = self.tensorizer_uri self.file_obj = self.tensorizer_uri
self.s3_access_key_id = (self.s3_access_key_id self.s3_access_key_id = self.s3_access_key_id or envs.S3_ACCESS_KEY_ID
or os.environ.get("S3_ACCESS_KEY_ID")) or None self.s3_secret_access_key = (self.s3_secret_access_key
self.s3_secret_access_key = ( or envs.S3_SECRET_ACCESS_KEY)
self.s3_secret_access_key self.s3_endpoint = self.s3_endpoint or envs.S3_ENDPOINT_URL
or os.environ.get("S3_SECRET_ACCESS_KEY")) or None
self.s3_endpoint = (self.s3_endpoint
or os.environ.get("S3_ENDPOINT_URL")) or None
self.stream_params = { self.stream_params = {
"s3_access_key_id": self.s3_access_key_id, "s3_access_key_id": self.s3_access_key_id,
"s3_secret_access_key": self.s3_secret_access_key, "s3_secret_access_key": self.s3_secret_access_key,
...@@ -198,10 +198,12 @@ class TensorizerArgs: ...@@ -198,10 +198,12 @@ class TensorizerArgs:
"use for decryption. Can be a file path or S3 network URI.") "use for decryption. Can be a file path or S3 network URI.")
group.add_argument( group.add_argument(
"--num-readers", "--num-readers",
default=1, default=None,
type=int, type=int,
help="Controls how many threads are allowed to read concurrently " help="Controls how many threads are allowed to read concurrently "
"from the source file.") "from the source file. Default is `None`, which will dynamically "
"set the number of readers based on the available resources "
"and model size. This greatly increases performance.")
group.add_argument( group.add_argument(
"--s3-access-key-id", "--s3-access-key-id",
default=None, default=None,
...@@ -251,7 +253,7 @@ class TensorizerAgent: ...@@ -251,7 +253,7 @@ class TensorizerAgent:
""" """
def __init__(self, tensorizer_config: TensorizerConfig, def __init__(self, tensorizer_config: TensorizerConfig,
linear_method: LinearMethodBase, **extra_kwargs): quant_config: QuantizationConfig, **extra_kwargs):
if tensorizer_load_fail is not None: if tensorizer_load_fail is not None:
raise ImportError( raise ImportError(
"Tensorizer is not installed. Please install tensorizer " "Tensorizer is not installed. Please install tensorizer "
...@@ -262,19 +264,21 @@ class TensorizerAgent: ...@@ -262,19 +264,21 @@ class TensorizerAgent:
self.tensorizer_args = ( self.tensorizer_args = (
self.tensorizer_config._construct_tensorizer_args()) self.tensorizer_config._construct_tensorizer_args())
self.extra_kwargs = extra_kwargs self.extra_kwargs = extra_kwargs
if extra_kwargs.get("linear_method", None) is not None: if extra_kwargs.get("quant_config", None) is not None:
self.linear_method = extra_kwargs["linear_method"] self.quant_config = extra_kwargs["quant_config"]
else: else:
self.linear_method = linear_method self.quant_config = quant_config
self.model = self._init_model() self.model = self._init_model()
def _init_model(self):
    """Instantiate the configured model class for later deserialization.

    Reads `hf_config`, `dtype` and `model_class` from
    `self.tensorizer_config`. Note that `hf_config.torch_dtype` is
    overwritten in place with the configured dtype before the model is
    built.

    Returns:
        The model instance created by `model_class(config=...,
        quant_config=..., **extra_kwargs)`.

    Raises:
        AssertionError: If `hf_config` or `model_class` is not set on the
            tensorizer config.
    """
    assert self.tensorizer_config.hf_config is not None
    model_args = self.tensorizer_config.hf_config
    model_args.torch_dtype = self.tensorizer_config.dtype
    assert self.tensorizer_config.model_class is not None
    # Construction happens inside the `no_init_or_tensor()` context
    # (presumably to skip weight initialization/allocation -- confirm
    # against that helper's definition).
    with no_init_or_tensor():
        return self.tensorizer_config.model_class(
            config=model_args,
            quant_config=self.quant_config,
            **self.extra_kwargs)
def _resize_lora_embeddings(self): def _resize_lora_embeddings(self):
...@@ -334,10 +338,10 @@ class TensorizerAgent: ...@@ -334,10 +338,10 @@ class TensorizerAgent:
per_second = convert_bytes(deserializer.total_tensor_bytes / duration) per_second = convert_bytes(deserializer.total_tensor_bytes / duration)
after_mem = get_mem_usage() after_mem = get_mem_usage()
deserializer.close() deserializer.close()
logger.info(f"Deserialized {total_bytes_str} in " logger.info("Deserialized %s in %0.2fs, %s/s", total_bytes_str,
f"{end - start:0.2f}s, {per_second}/s") end - start, per_second)
logger.info(f"Memory usage before: {before_mem}") logger.info("Memory usage before: %s", before_mem)
logger.info(f"Memory usage after: {after_mem}") logger.info("Memory usage after: %s", after_mem)
self._check_tensors_on_meta_device() self._check_tensors_on_meta_device()
self._resize_lora_embeddings() self._resize_lora_embeddings()
......
...@@ -127,11 +127,14 @@ def get_quant_config(model_config: ModelConfig, ...@@ -127,11 +127,14 @@ def get_quant_config(model_config: ModelConfig,
if not is_local: if not is_local:
# Download the config files. # Download the config files.
with get_lock(model_name_or_path, load_config.download_dir): with get_lock(model_name_or_path, load_config.download_dir):
hf_folder = snapshot_download(model_name_or_path, hf_folder = snapshot_download(
revision=model_config.revision, model_name_or_path,
allow_patterns="*.json", revision=model_config.revision,
cache_dir=load_config.download_dir, allow_patterns="*.json",
tqdm_class=DisabledTqdm) cache_dir=load_config.download_dir,
local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
tqdm_class=DisabledTqdm,
)
else: else:
hf_folder = model_name_or_path hf_folder = model_name_or_path
...@@ -161,12 +164,14 @@ def get_quant_config(model_config: ModelConfig, ...@@ -161,12 +164,14 @@ def get_quant_config(model_config: ModelConfig,
return quant_cls.from_config(config) return quant_cls.from_config(config)
def download_weights_from_hf(model_name_or_path: str, def download_weights_from_hf(
cache_dir: Optional[str], model_name_or_path: str,
allow_patterns: List[str], cache_dir: Optional[str],
revision: Optional[str] = None) -> str: allow_patterns: List[str],
revision: Optional[str] = None,
) -> str:
"""Download model weights from Hugging Face Hub. """Download model weights from Hugging Face Hub.
Args: Args:
model_name_or_path (str): The model name or path. model_name_or_path (str): The model name or path.
cache_dir (Optional[str]): The cache directory to store the model cache_dir (Optional[str]): The cache directory to store the model
...@@ -179,26 +184,30 @@ def download_weights_from_hf(model_name_or_path: str, ...@@ -179,26 +184,30 @@ def download_weights_from_hf(model_name_or_path: str,
Returns: Returns:
str: The path to the downloaded model weights. str: The path to the downloaded model weights.
""" """
# Before we download we look at that is available: if not huggingface_hub.constants.HF_HUB_OFFLINE:
fs = HfFileSystem() # Before we download we look at that is available:
file_list = fs.ls(model_name_or_path, detail=False, revision=revision) fs = HfFileSystem()
file_list = fs.ls(model_name_or_path, detail=False, revision=revision)
# depending on what is available we download different things
for pattern in allow_patterns: # depending on what is available we download different things
matching = fnmatch.filter(file_list, pattern) for pattern in allow_patterns:
if len(matching) > 0: matching = fnmatch.filter(file_list, pattern)
allow_patterns = [pattern] if len(matching) > 0:
break allow_patterns = [pattern]
break
logger.info(f"Using model weights format {allow_patterns}")
logger.info("Using model weights format %s", allow_patterns)
# Use file lock to prevent multiple processes from # Use file lock to prevent multiple processes from
# downloading the same model weights at the same time. # downloading the same model weights at the same time.
with get_lock(model_name_or_path, cache_dir): with get_lock(model_name_or_path, cache_dir):
hf_folder = snapshot_download(model_name_or_path, hf_folder = snapshot_download(
allow_patterns=allow_patterns, model_name_or_path,
cache_dir=cache_dir, allow_patterns=allow_patterns,
tqdm_class=DisabledTqdm, cache_dir=cache_dir,
revision=revision) tqdm_class=DisabledTqdm,
revision=revision,
local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
)
return hf_folder return hf_folder
...@@ -310,17 +319,17 @@ def kv_cache_scales_loader( ...@@ -310,17 +319,17 @@ def kv_cache_scales_loader(
return layer_scales_map.items() return layer_scales_map.items()
except FileNotFoundError: except FileNotFoundError:
logger.error(f"File or directory '{filename}' not found.") logger.error("File or directory '%s' not found.", filename)
except json.JSONDecodeError: except json.JSONDecodeError:
logger.error(f"Error decoding JSON in file '{filename}'.") logger.error("Error decoding JSON in file '%s'.", filename)
except Exception as e: except Exception as e:
logger.error(f"An error occurred while reading '{filename}': {e}") logger.error("An error occurred while reading '%s': %s", filename, e)
# This section is reached if and only if any of the excepts are hit # This section is reached if and only if any of the excepts are hit
# Return an empty iterable (list) => no KV cache scales are loaded # Return an empty iterable (list) => no KV cache scales are loaded
# which ultimately defaults to 1.0 scales # which ultimately defaults to 1.0 scales
logger.warning("Defaulting to KV cache scaling factors = 1.0 " logger.warning(
f"for all layers in TP rank {tp_rank} " "Defaulting to KV cache scaling factors = 1.0 for all "
"as an error occurred during loading.") "layers in TP rank %d as an error occurred during loading.", tp_rank)
return [] return []
......
...@@ -42,10 +42,11 @@ _MODELS = { ...@@ -42,10 +42,11 @@ _MODELS = {
"MptForCausalLM": ("mpt", "MPTForCausalLM"), "MptForCausalLM": ("mpt", "MPTForCausalLM"),
"MPTForCausalLM": ("mpt", "MPTForCausalLM"), "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
"MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"), "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
"OLMoForCausalLM": ("olmo", "OLMoForCausalLM"), "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
"OPTForCausalLM": ("opt", "OPTForCausalLM"), "OPTForCausalLM": ("opt", "OPTForCausalLM"),
"OrionForCausalLM": ("orion", "OrionForCausalLM"), "OrionForCausalLM": ("orion", "OrionForCausalLM"),
"PhiForCausalLM": ("phi", "PhiForCausalLM"), "PhiForCausalLM": ("phi", "PhiForCausalLM"),
"Phi3ForCausalLM": ("llama", "LlamaForCausalLM"),
"QWenLMHeadModel": ("qwen", "QWenLMHeadModel"), "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
"Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"), "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
"Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"), "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
...@@ -90,8 +91,8 @@ class ModelRegistry: ...@@ -90,8 +91,8 @@ class ModelRegistry:
"ROCm for now.") "ROCm for now.")
if model_arch in _ROCM_PARTIALLY_SUPPORTED_MODELS: if model_arch in _ROCM_PARTIALLY_SUPPORTED_MODELS:
logger.warning( logger.warning(
f"Model architecture {model_arch} is partially supported " "Model architecture %s is partially supported by ROCm: %s",
"by ROCm: " + _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch]) model_arch, _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch])
module_name, model_cls_name = _MODELS[model_arch] module_name, model_cls_name = _MODELS[model_arch]
module = importlib.import_module( module = importlib.import_module(
...@@ -106,9 +107,9 @@ class ModelRegistry: ...@@ -106,9 +107,9 @@ class ModelRegistry:
def register_model(model_arch: str, model_cls: Type[nn.Module]): def register_model(model_arch: str, model_cls: Type[nn.Module]):
if model_arch in _MODELS: if model_arch in _MODELS:
logger.warning( logger.warning(
f"Model architecture {model_arch} is already registered, " "Model architecture %s is already registered, and will be "
"and will be overwritten by the new model " "overwritten by the new model class %s.", model_arch,
f"class {model_cls.__name__}.") model_cls.__name__)
global _OOT_MODELS global _OOT_MODELS
_OOT_MODELS[model_arch] = model_cls _OOT_MODELS[model_arch] = model_cls
......
...@@ -31,11 +31,12 @@ from vllm.distributed import (get_tensor_model_parallel_rank, ...@@ -31,11 +31,12 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size) get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase, from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
MergedColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
...@@ -77,17 +78,17 @@ class BaiChuanMLP(nn.Module): ...@@ -77,17 +78,17 @@ class BaiChuanMLP(nn.Module):
hidden_size: int, hidden_size: int,
intermediate_size: int, intermediate_size: int,
hidden_act: str, hidden_act: str,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
self.gate_up_proj = MergedColumnParallelLinear( self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2, hidden_size, [intermediate_size] * 2,
bias=False, bias=False,
linear_method=linear_method) quant_config=quant_config)
self.down_proj = RowParallelLinear(intermediate_size, self.down_proj = RowParallelLinear(intermediate_size,
hidden_size, hidden_size,
bias=False, bias=False,
linear_method=linear_method) quant_config=quant_config)
if hidden_act != "silu": if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. " raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.") "Only silu is supported for now.")
...@@ -110,7 +111,7 @@ class BaiChuanAttention(nn.Module): ...@@ -110,7 +111,7 @@ class BaiChuanAttention(nn.Module):
position_embedding: str, position_embedding: str,
rope_theta: float = 10000, rope_theta: float = 10000,
max_position_embeddings: int = 8192, max_position_embeddings: int = 8192,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
self.hidden_size = hidden_size self.hidden_size = hidden_size
...@@ -132,13 +133,13 @@ class BaiChuanAttention(nn.Module): ...@@ -132,13 +133,13 @@ class BaiChuanAttention(nn.Module):
self.total_num_heads, self.total_num_heads,
self.total_num_heads, self.total_num_heads,
bias=False, bias=False,
linear_method=linear_method, quant_config=quant_config,
) )
self.o_proj = RowParallelLinear( self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim, self.total_num_heads * self.head_dim,
hidden_size, hidden_size,
bias=False, bias=False,
linear_method=linear_method, quant_config=quant_config,
) )
# Create the alibi slopes and slice them. # Create the alibi slopes and slice them.
if self.postion_embedding == "ALIBI": if self.postion_embedding == "ALIBI":
...@@ -184,7 +185,7 @@ class BaiChuanDecoderLayer(nn.Module): ...@@ -184,7 +185,7 @@ class BaiChuanDecoderLayer(nn.Module):
def __init__(self, def __init__(self,
config: PretrainedConfig, config: PretrainedConfig,
position_embedding: str, position_embedding: str,
linear_method: Optional[LinearMethodBase] = None): quant_config: Optional[QuantizationConfig] = None):
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000) rope_theta = getattr(config, "rope_theta", 10000)
...@@ -196,13 +197,13 @@ class BaiChuanDecoderLayer(nn.Module): ...@@ -196,13 +197,13 @@ class BaiChuanDecoderLayer(nn.Module):
position_embedding=position_embedding, position_embedding=position_embedding,
rope_theta=rope_theta, rope_theta=rope_theta,
max_position_embeddings=max_position_embeddings, max_position_embeddings=max_position_embeddings,
linear_method=linear_method, quant_config=quant_config,
) )
self.mlp = BaiChuanMLP( self.mlp = BaiChuanMLP(
hidden_size=self.hidden_size, hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size, intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act, hidden_act=config.hidden_act,
linear_method=linear_method, quant_config=quant_config,
) )
self.input_layernorm = RMSNorm(config.hidden_size, self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps) eps=config.rms_norm_eps)
...@@ -243,7 +244,7 @@ class BaiChuanModel(nn.Module): ...@@ -243,7 +244,7 @@ class BaiChuanModel(nn.Module):
def __init__(self, def __init__(self,
config: PretrainedConfig, config: PretrainedConfig,
position_embedding: str, position_embedding: str,
linear_method: Optional[LinearMethodBase] = None): quant_config: Optional[QuantizationConfig] = None):
super().__init__() super().__init__()
self.config = config self.config = config
self.padding_idx = config.pad_token_id self.padding_idx = config.pad_token_id
...@@ -254,7 +255,7 @@ class BaiChuanModel(nn.Module): ...@@ -254,7 +255,7 @@ class BaiChuanModel(nn.Module):
config.hidden_size, config.hidden_size,
) )
self.layers = nn.ModuleList([ self.layers = nn.ModuleList([
BaiChuanDecoderLayer(config, position_embedding, linear_method) BaiChuanDecoderLayer(config, position_embedding, quant_config)
for _ in range(config.num_hidden_layers) for _ in range(config.num_hidden_layers)
]) ])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
...@@ -303,13 +304,13 @@ class BaiChuanBaseForCausalLM(nn.Module): ...@@ -303,13 +304,13 @@ class BaiChuanBaseForCausalLM(nn.Module):
self, self,
config, config,
position_embedding: str, position_embedding: str,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None, lora_config: Optional[LoRAConfig] = None,
): ):
super().__init__() super().__init__()
self.config = config self.config = config
self.linear_method = linear_method self.quant_config = quant_config
self.model = BaiChuanModel(config, position_embedding, linear_method) self.model = BaiChuanModel(config, position_embedding, quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
...@@ -388,13 +389,13 @@ class BaichuanForCausalLM(BaiChuanBaseForCausalLM): ...@@ -388,13 +389,13 @@ class BaichuanForCausalLM(BaiChuanBaseForCausalLM):
def __init__( def __init__(
self, self,
config, config,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None, lora_config: Optional[LoRAConfig] = None,
): ):
if config.hidden_size == 4096: # baichuan2 7b if config.hidden_size == 4096: # baichuan2 7b
super().__init__(config, "ROPE", linear_method, lora_config) super().__init__(config, "ROPE", quant_config, lora_config)
else: # baichuan 13b, baichuan2 13b else: # baichuan 13b, baichuan2 13b
super().__init__(config, "ALIBI", linear_method, lora_config) super().__init__(config, "ALIBI", quant_config, lora_config)
class BaiChuanForCausalLM(BaiChuanBaseForCausalLM): class BaiChuanForCausalLM(BaiChuanBaseForCausalLM):
...@@ -403,7 +404,7 @@ class BaiChuanForCausalLM(BaiChuanBaseForCausalLM): ...@@ -403,7 +404,7 @@ class BaiChuanForCausalLM(BaiChuanBaseForCausalLM):
def __init__( def __init__(
self, self,
config, config,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None, lora_config: Optional[LoRAConfig] = None,
): ):
super().__init__(config, "ROPE", linear_method, lora_config) super().__init__(config, "ROPE", quant_config, lora_config)
...@@ -28,10 +28,11 @@ from vllm.distributed import (get_tensor_model_parallel_rank, ...@@ -28,10 +28,11 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size) get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding) VocabParallelEmbedding)
...@@ -70,7 +71,7 @@ class BloomAttention(nn.Module): ...@@ -70,7 +71,7 @@ class BloomAttention(nn.Module):
def __init__( def __init__(
self, self,
config: BloomConfig, config: BloomConfig,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
...@@ -87,13 +88,13 @@ class BloomAttention(nn.Module): ...@@ -87,13 +88,13 @@ class BloomAttention(nn.Module):
self.head_dim, self.head_dim,
self.total_num_heads, self.total_num_heads,
bias=True, bias=True,
linear_method=linear_method, quant_config=quant_config,
) )
self.dense = RowParallelLinear( self.dense = RowParallelLinear(
self.hidden_size, self.hidden_size,
self.hidden_size, self.hidden_size,
bias=True, bias=True,
linear_method=linear_method, quant_config=quant_config,
) )
# Create the alibi slopes and slice them. # Create the alibi slopes and slice them.
...@@ -129,21 +130,20 @@ class BloomMLP(nn.Module): ...@@ -129,21 +130,20 @@ class BloomMLP(nn.Module):
def __init__( def __init__(
self, self,
config: BloomConfig, config: BloomConfig,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
hidden_size = config.hidden_size hidden_size = config.hidden_size
self.dense_h_to_4h = ColumnParallelLinear( self.dense_h_to_4h = ColumnParallelLinear(
hidden_size, hidden_size,
4 * hidden_size, 4 * hidden_size,
linear_method=linear_method, quant_config=quant_config,
) )
quant_config = getattr(linear_method, "quant_config", None)
self.gelu_impl = get_act_fn("gelu", quant_config, 4 * hidden_size) self.gelu_impl = get_act_fn("gelu", quant_config, 4 * hidden_size)
self.dense_4h_to_h = RowParallelLinear( self.dense_4h_to_h = RowParallelLinear(
4 * hidden_size, 4 * hidden_size,
hidden_size, hidden_size,
linear_method=linear_method, quant_config=quant_config,
) )
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
...@@ -158,17 +158,17 @@ class BloomBlock(nn.Module): ...@@ -158,17 +158,17 @@ class BloomBlock(nn.Module):
def __init__( def __init__(
self, self,
config: BloomConfig, config: BloomConfig,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
hidden_size = config.hidden_size hidden_size = config.hidden_size
self.input_layernorm = nn.LayerNorm(hidden_size, self.input_layernorm = nn.LayerNorm(hidden_size,
eps=config.layer_norm_epsilon) eps=config.layer_norm_epsilon)
self.self_attention = BloomAttention(config, linear_method) self.self_attention = BloomAttention(config, quant_config)
self.post_attention_layernorm = nn.LayerNorm( self.post_attention_layernorm = nn.LayerNorm(
hidden_size, eps=config.layer_norm_epsilon) hidden_size, eps=config.layer_norm_epsilon)
self.mlp = BloomMLP(config, linear_method) self.mlp = BloomMLP(config, quant_config)
self.apply_residual_connection_post_layernorm = ( self.apply_residual_connection_post_layernorm = (
config.apply_residual_connection_post_layernorm) config.apply_residual_connection_post_layernorm)
...@@ -214,7 +214,7 @@ class BloomModel(nn.Module): ...@@ -214,7 +214,7 @@ class BloomModel(nn.Module):
def __init__( def __init__(
self, self,
config: BloomConfig, config: BloomConfig,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
self.embed_dim = config.hidden_size self.embed_dim = config.hidden_size
...@@ -229,7 +229,7 @@ class BloomModel(nn.Module): ...@@ -229,7 +229,7 @@ class BloomModel(nn.Module):
# Transformer blocks # Transformer blocks
self.h = nn.ModuleList([ self.h = nn.ModuleList([
BloomBlock(config, linear_method) BloomBlock(config, quant_config)
for _ in range(config.num_hidden_layers) for _ in range(config.num_hidden_layers)
]) ])
...@@ -262,12 +262,12 @@ class BloomForCausalLM(nn.Module): ...@@ -262,12 +262,12 @@ class BloomForCausalLM(nn.Module):
def __init__( def __init__(
self, self,
config: BloomConfig, config: BloomConfig,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
self.config = config self.config = config
self.linear_method = linear_method self.quant_config = quant_config
self.transformer = BloomModel(config, linear_method) self.transformer = BloomModel(config, quant_config)
self.lm_head_weight = self.transformer.word_embeddings.weight self.lm_head_weight = self.transformer.word_embeddings.weight
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
......
...@@ -13,11 +13,12 @@ from vllm.config import LoRAConfig ...@@ -13,11 +13,12 @@ from vllm.config import LoRAConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase, from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
MergedColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
...@@ -33,7 +34,7 @@ class GLMAttention(nn.Module): ...@@ -33,7 +34,7 @@ class GLMAttention(nn.Module):
def __init__( def __init__(
self, self,
config, config,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
...@@ -65,13 +66,13 @@ class GLMAttention(nn.Module): ...@@ -65,13 +66,13 @@ class GLMAttention(nn.Module):
self.total_num_heads, self.total_num_heads,
self.total_num_kv_heads, self.total_num_kv_heads,
bias=config.add_bias_linear or config.add_qkv_bias, bias=config.add_bias_linear or config.add_qkv_bias,
linear_method=linear_method, quant_config=quant_config,
) )
self.dense = RowParallelLinear( self.dense = RowParallelLinear(
self.total_num_heads * self.head_dim, self.total_num_heads * self.head_dim,
config.hidden_size, config.hidden_size,
bias=config.add_bias_linear, bias=config.add_bias_linear,
linear_method=linear_method, quant_config=quant_config,
) )
# https://huggingface.co/THUDM/chatglm3-6b-32k/blob/e210410255278dd9d74463cf396ba559c0ef801c/modeling_chatglm.py#L141 # https://huggingface.co/THUDM/chatglm3-6b-32k/blob/e210410255278dd9d74463cf396ba559c0ef801c/modeling_chatglm.py#L141
...@@ -123,7 +124,7 @@ class GLMMLP(nn.Module): ...@@ -123,7 +124,7 @@ class GLMMLP(nn.Module):
def __init__( def __init__(
self, self,
config, config,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
...@@ -134,7 +135,7 @@ class GLMMLP(nn.Module): ...@@ -134,7 +135,7 @@ class GLMMLP(nn.Module):
config.hidden_size, config.hidden_size,
[config.ffn_hidden_size] * 2, [config.ffn_hidden_size] * 2,
bias=config.add_bias_linear, bias=config.add_bias_linear,
linear_method=linear_method, quant_config=quant_config,
) )
self.activation_func = SiluAndMul() self.activation_func = SiluAndMul()
...@@ -144,7 +145,7 @@ class GLMMLP(nn.Module): ...@@ -144,7 +145,7 @@ class GLMMLP(nn.Module):
config.ffn_hidden_size, config.ffn_hidden_size,
config.hidden_size, config.hidden_size,
bias=config.add_bias_linear, bias=config.add_bias_linear,
linear_method=linear_method, quant_config=quant_config,
) )
def forward(self, hidden_states): def forward(self, hidden_states):
...@@ -166,7 +167,7 @@ class GLMBlock(nn.Module): ...@@ -166,7 +167,7 @@ class GLMBlock(nn.Module):
def __init__( def __init__(
self, self,
config, config,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
self.apply_residual_connection_post_layernorm = ( self.apply_residual_connection_post_layernorm = (
...@@ -180,7 +181,7 @@ class GLMBlock(nn.Module): ...@@ -180,7 +181,7 @@ class GLMBlock(nn.Module):
eps=config.layernorm_epsilon) eps=config.layernorm_epsilon)
# Self attention. # Self attention.
self.self_attention = GLMAttention(config, linear_method) self.self_attention = GLMAttention(config, quant_config)
self.hidden_dropout = config.hidden_dropout self.hidden_dropout = config.hidden_dropout
# Layernorm on the attention output # Layernorm on the attention output
...@@ -188,7 +189,7 @@ class GLMBlock(nn.Module): ...@@ -188,7 +189,7 @@ class GLMBlock(nn.Module):
config.hidden_size, eps=config.layernorm_epsilon) config.hidden_size, eps=config.layernorm_epsilon)
# MLP # MLP
self.mlp = GLMMLP(config, linear_method) self.mlp = GLMMLP(config, quant_config)
def forward( def forward(
self, self,
...@@ -236,7 +237,7 @@ class GLMTransformer(nn.Module): ...@@ -236,7 +237,7 @@ class GLMTransformer(nn.Module):
def __init__( def __init__(
self, self,
config, config,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
self.post_layer_norm = config.post_layer_norm self.post_layer_norm = config.post_layer_norm
...@@ -246,7 +247,7 @@ class GLMTransformer(nn.Module): ...@@ -246,7 +247,7 @@ class GLMTransformer(nn.Module):
# Transformer layers. # Transformer layers.
self.layers = nn.ModuleList( self.layers = nn.ModuleList(
[GLMBlock(config, linear_method) for i in range(self.num_layers)]) [GLMBlock(config, quant_config) for i in range(self.num_layers)])
if self.post_layer_norm: if self.post_layer_norm:
layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
...@@ -281,7 +282,7 @@ class ChatGLMModel(nn.Module): ...@@ -281,7 +282,7 @@ class ChatGLMModel(nn.Module):
def __init__( def __init__(
self, self,
config, config,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
...@@ -291,7 +292,7 @@ class ChatGLMModel(nn.Module): ...@@ -291,7 +292,7 @@ class ChatGLMModel(nn.Module):
self.num_layers = config.num_layers self.num_layers = config.num_layers
self.multi_query_group_num = config.multi_query_group_num self.multi_query_group_num = config.multi_query_group_num
self.kv_channels = config.kv_channels self.kv_channels = config.kv_channels
self.encoder = GLMTransformer(config, linear_method) self.encoder = GLMTransformer(config, quant_config)
self.output_layer = ParallelLMHead(config.padded_vocab_size, self.output_layer = ParallelLMHead(config.padded_vocab_size,
config.hidden_size) config.hidden_size)
...@@ -333,13 +334,13 @@ class ChatGLMForCausalLM(nn.Module): ...@@ -333,13 +334,13 @@ class ChatGLMForCausalLM(nn.Module):
def __init__( def __init__(
self, self,
config: ChatGLMConfig, config: ChatGLMConfig,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None, lora_config: Optional[LoRAConfig] = None,
): ):
super().__init__() super().__init__()
self.config: ChatGLMConfig = config self.config: ChatGLMConfig = config
self.linear_method = linear_method self.quant_config = quant_config
self.transformer = ChatGLMModel(config, linear_method) self.transformer = ChatGLMModel(config, quant_config)
self.lm_head_weight = self.transformer.output_layer.weight self.lm_head_weight = self.transformer.output_layer.weight
self.logits_processor = LogitsProcessor(config.padded_vocab_size) self.logits_processor = LogitsProcessor(config.padded_vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
......
...@@ -32,11 +32,12 @@ from vllm.attention import Attention, AttentionMetadata ...@@ -32,11 +32,12 @@ from vllm.attention import Attention, AttentionMetadata
from vllm.distributed import (get_tensor_model_parallel_rank, from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size) get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.linear import (LinearMethodBase, from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
MergedColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
...@@ -91,7 +92,7 @@ class CohereMLP(nn.Module): ...@@ -91,7 +92,7 @@ class CohereMLP(nn.Module):
def __init__( def __init__(
self, self,
config, config,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -101,13 +102,13 @@ class CohereMLP(nn.Module): ...@@ -101,13 +102,13 @@ class CohereMLP(nn.Module):
self.hidden_size, self.hidden_size,
[self.intermediate_size] * 2, [self.intermediate_size] * 2,
bias=False, bias=False,
linear_method=linear_method, quant_config=quant_config,
) )
self.down_proj = RowParallelLinear( self.down_proj = RowParallelLinear(
self.intermediate_size, self.intermediate_size,
self.hidden_size, self.hidden_size,
bias=False, bias=False,
linear_method=linear_method, quant_config=quant_config,
) )
self.act_fn = SiluAndMul() self.act_fn = SiluAndMul()
...@@ -123,7 +124,7 @@ class CohereAttention(nn.Module): ...@@ -123,7 +124,7 @@ class CohereAttention(nn.Module):
def __init__( def __init__(
self, self,
config: CohereConfig, config: CohereConfig,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
tp_size = get_tensor_model_parallel_world_size() tp_size = get_tensor_model_parallel_world_size()
...@@ -158,13 +159,13 @@ class CohereAttention(nn.Module): ...@@ -158,13 +159,13 @@ class CohereAttention(nn.Module):
self.total_num_heads, self.total_num_heads,
self.total_num_kv_heads, self.total_num_kv_heads,
bias=False, bias=False,
linear_method=linear_method, quant_config=quant_config,
) )
self.o_proj = RowParallelLinear( self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim, self.total_num_heads * self.head_dim,
self.hidden_size, self.hidden_size,
bias=False, bias=False,
linear_method=linear_method, quant_config=quant_config,
) )
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
...@@ -218,13 +219,13 @@ class CohereDecoderLayer(nn.Module): ...@@ -218,13 +219,13 @@ class CohereDecoderLayer(nn.Module):
def __init__(self, def __init__(self,
config: CohereConfig, config: CohereConfig,
linear_method: Optional[LinearMethodBase] = None): quant_config: Optional[QuantizationConfig] = None):
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
self.self_attn = CohereAttention(config, linear_method=linear_method) self.self_attn = CohereAttention(config, quant_config=quant_config)
self.mlp = CohereMLP(config, linear_method=linear_method) self.mlp = CohereMLP(config, quant_config=quant_config)
self.input_layernorm = LayerNorm(param_shape=(config.hidden_size), self.input_layernorm = LayerNorm(param_shape=(config.hidden_size),
eps=config.layer_norm_eps) eps=config.layer_norm_eps)
...@@ -257,7 +258,7 @@ class CohereModel(nn.Module): ...@@ -257,7 +258,7 @@ class CohereModel(nn.Module):
def __init__( def __init__(
self, self,
config: CohereConfig, config: CohereConfig,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -265,7 +266,7 @@ class CohereModel(nn.Module): ...@@ -265,7 +266,7 @@ class CohereModel(nn.Module):
self.embed_tokens = VocabParallelEmbedding(config.vocab_size, self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
config.hidden_size) config.hidden_size)
self.layers = nn.ModuleList([ self.layers = nn.ModuleList([
CohereDecoderLayer(config, linear_method=linear_method) CohereDecoderLayer(config, quant_config=quant_config)
for _ in range(config.num_hidden_layers) for _ in range(config.num_hidden_layers)
]) ])
self.norm = LayerNorm(param_shape=(config.hidden_size), self.norm = LayerNorm(param_shape=(config.hidden_size),
...@@ -298,14 +299,14 @@ class CohereForCausalLM(nn.Module): ...@@ -298,14 +299,14 @@ class CohereForCausalLM(nn.Module):
def __init__( def __init__(
self, self,
config: CohereConfig, config: CohereConfig,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
self.linear_method = linear_method self.quant_config = quant_config
self.logits_processor = LogitsProcessor(config.vocab_size, self.logits_processor = LogitsProcessor(config.vocab_size,
scale=config.logit_scale) scale=config.logit_scale)
self.model = CohereModel(config, linear_method) self.model = CohereModel(config, quant_config)
self.sampler = Sampler() self.sampler = Sampler()
@torch.no_grad() @torch.no_grad()
......
...@@ -9,11 +9,12 @@ from vllm.distributed import (get_tensor_model_parallel_rank, ...@@ -9,11 +9,12 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce) tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.linear import (LinearMethodBase, from vllm.model_executor.layers.linear import (QKVParallelLinear,
QKVParallelLinear,
ReplicatedLinear, ReplicatedLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
...@@ -44,7 +45,7 @@ class DbrxRouter(nn.Module): ...@@ -44,7 +45,7 @@ class DbrxRouter(nn.Module):
self.num_total_experts, self.num_total_experts,
bias=False, bias=False,
params_dtype=params_dtype, params_dtype=params_dtype,
linear_method=None, quant_config=None,
) )
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
...@@ -63,7 +64,7 @@ class DbrxExperts(nn.Module): ...@@ -63,7 +64,7 @@ class DbrxExperts(nn.Module):
def __init__( def __init__(
self, self,
config: DbrxConfig, config: DbrxConfig,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
params_dtype: Optional[torch.dtype] = None, params_dtype: Optional[torch.dtype] = None,
): ):
super().__init__() super().__init__()
...@@ -165,7 +166,7 @@ class DbrxAttention(nn.Module): ...@@ -165,7 +166,7 @@ class DbrxAttention(nn.Module):
def __init__( def __init__(
self, self,
config: DbrxConfig, config: DbrxConfig,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
self.d_model = config.d_model self.d_model = config.d_model
...@@ -183,13 +184,13 @@ class DbrxAttention(nn.Module): ...@@ -183,13 +184,13 @@ class DbrxAttention(nn.Module):
self.total_num_heads, self.total_num_heads,
self.total_num_kv_heads, self.total_num_kv_heads,
bias=False, bias=False,
linear_method=linear_method, quant_config=quant_config,
) )
self.out_proj = RowParallelLinear( self.out_proj = RowParallelLinear(
self.d_model, self.d_model,
self.d_model, self.d_model,
bias=False, bias=False,
linear_method=linear_method, quant_config=quant_config,
) )
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
self.head_dim, self.head_dim,
...@@ -244,11 +245,11 @@ class DbrxFusedNormAttention(nn.Module): ...@@ -244,11 +245,11 @@ class DbrxFusedNormAttention(nn.Module):
def __init__( def __init__(
self, self,
config: DbrxConfig, config: DbrxConfig,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
self.d_model = config.d_model self.d_model = config.d_model
self.attn = DbrxAttention(config, linear_method) self.attn = DbrxAttention(config, quant_config)
self.norm_1 = nn.LayerNorm(self.d_model) self.norm_1 = nn.LayerNorm(self.d_model)
self.norm_2 = nn.LayerNorm(self.d_model) self.norm_2 = nn.LayerNorm(self.d_model)
...@@ -278,11 +279,11 @@ class DbrxBlock(nn.Module): ...@@ -278,11 +279,11 @@ class DbrxBlock(nn.Module):
def __init__( def __init__(
self, self,
config: DbrxConfig, config: DbrxConfig,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
self.norm_attn_norm = DbrxFusedNormAttention(config, linear_method) self.norm_attn_norm = DbrxFusedNormAttention(config, quant_config)
self.ffn = DbrxExperts(config, linear_method) self.ffn = DbrxExperts(config, quant_config)
def forward( def forward(
self, self,
...@@ -307,7 +308,7 @@ class DbrxModel(nn.Module): ...@@ -307,7 +308,7 @@ class DbrxModel(nn.Module):
def __init__( def __init__(
self, self,
config: DbrxConfig, config: DbrxConfig,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
self.wte = VocabParallelEmbedding( self.wte = VocabParallelEmbedding(
...@@ -315,7 +316,7 @@ class DbrxModel(nn.Module): ...@@ -315,7 +316,7 @@ class DbrxModel(nn.Module):
config.d_model, config.d_model,
) )
self.blocks = nn.ModuleList( self.blocks = nn.ModuleList(
[DbrxBlock(config, linear_method) for _ in range(config.n_layers)]) [DbrxBlock(config, quant_config) for _ in range(config.n_layers)])
self.norm_f = nn.LayerNorm(config.d_model, eps=1e-5) self.norm_f = nn.LayerNorm(config.d_model, eps=1e-5)
for module in self.modules(): for module in self.modules():
if hasattr(module, "bias") and isinstance(module.bias, if hasattr(module, "bias") and isinstance(module.bias,
...@@ -348,13 +349,13 @@ class DbrxForCausalLM(nn.Module): ...@@ -348,13 +349,13 @@ class DbrxForCausalLM(nn.Module):
def __init__( def __init__(
self, self,
config: DbrxConfig, config: DbrxConfig,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
self.config = config self.config = config
self.linear_method = linear_method self.quant_config = quant_config
self.unpadded_vocab_size = config.vocab_size self.unpadded_vocab_size = config.vocab_size
self.transformer = DbrxModel(config, linear_method) self.transformer = DbrxModel(config, quant_config)
self.lm_head = ParallelLMHead( self.lm_head = ParallelLMHead(
config.vocab_size, config.vocab_size,
config.d_model, config.d_model,
......
...@@ -29,7 +29,8 @@ import torch ...@@ -29,7 +29,8 @@ import torch
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.config import LoRAConfig from vllm.config import LoRAConfig
from vllm.model_executor.layers.linear import LinearMethodBase from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.llama import LlamaForCausalLM from vllm.model_executor.models.llama import LlamaForCausalLM
...@@ -55,13 +56,13 @@ class DeciLMForCausalLM(LlamaForCausalLM): ...@@ -55,13 +56,13 @@ class DeciLMForCausalLM(LlamaForCausalLM):
def __init__( def __init__(
self, self,
config: Optional[PretrainedConfig] = None, config: Optional[PretrainedConfig] = None,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None, lora_config: Optional[LoRAConfig] = None,
) -> None: ) -> None:
config.num_key_value_heads = max(config.num_key_value_heads_per_layer) config.num_key_value_heads = max(config.num_key_value_heads_per_layer)
delattr(config, "num_key_value_heads_per_layer") delattr(config, "num_key_value_heads_per_layer")
super().__init__(config=config, super().__init__(config=config,
linear_method=linear_method, quant_config=quant_config,
lora_config=lora_config) lora_config=lora_config)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
......
...@@ -34,12 +34,13 @@ from vllm.distributed import (get_tensor_model_parallel_rank, ...@@ -34,12 +34,13 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase, from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
MergedColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
ReplicatedLinear, ReplicatedLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
...@@ -56,18 +57,18 @@ class DeepseekMLP(nn.Module): ...@@ -56,18 +57,18 @@ class DeepseekMLP(nn.Module):
hidden_size: int, hidden_size: int,
intermediate_size: int, intermediate_size: int,
hidden_act: str, hidden_act: str,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
reduce_results: bool = True, reduce_results: bool = True,
) -> None: ) -> None:
super().__init__() super().__init__()
self.gate_up_proj = MergedColumnParallelLinear( self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2, hidden_size, [intermediate_size] * 2,
bias=False, bias=False,
linear_method=linear_method) quant_config=quant_config)
self.down_proj = RowParallelLinear(intermediate_size, self.down_proj = RowParallelLinear(intermediate_size,
hidden_size, hidden_size,
bias=False, bias=False,
linear_method=linear_method, quant_config=quant_config,
reduce_results=reduce_results) reduce_results=reduce_results)
if hidden_act != "silu": if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. " raise ValueError(f"Unsupported activation: {hidden_act}. "
...@@ -86,7 +87,7 @@ class DeepseekMoE(nn.Module): ...@@ -86,7 +87,7 @@ class DeepseekMoE(nn.Module):
def __init__( def __init__(
self, self,
config: PretrainedConfig, config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -103,7 +104,7 @@ class DeepseekMoE(nn.Module): ...@@ -103,7 +104,7 @@ class DeepseekMoE(nn.Module):
DeepseekMLP(hidden_size=config.hidden_size, DeepseekMLP(hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size, intermediate_size=config.moe_intermediate_size,
hidden_act=config.hidden_act, hidden_act=config.hidden_act,
linear_method=linear_method, quant_config=quant_config,
reduce_results=False) reduce_results=False)
for idx in range(self.n_routed_experts) for idx in range(self.n_routed_experts)
]) ])
...@@ -112,7 +113,7 @@ class DeepseekMoE(nn.Module): ...@@ -112,7 +113,7 @@ class DeepseekMoE(nn.Module):
self.gate = ReplicatedLinear(config.hidden_size, self.gate = ReplicatedLinear(config.hidden_size,
self.n_routed_experts, self.n_routed_experts,
bias=False, bias=False,
linear_method=None) quant_config=None)
if config.n_shared_experts is not None: if config.n_shared_experts is not None:
intermediate_size = (config.moe_intermediate_size * intermediate_size = (config.moe_intermediate_size *
...@@ -121,7 +122,7 @@ class DeepseekMoE(nn.Module): ...@@ -121,7 +122,7 @@ class DeepseekMoE(nn.Module):
hidden_size=config.hidden_size, hidden_size=config.hidden_size,
intermediate_size=intermediate_size, intermediate_size=intermediate_size,
hidden_act=config.hidden_act, hidden_act=config.hidden_act,
linear_method=linear_method, quant_config=quant_config,
reduce_results=False, reduce_results=False,
) )
...@@ -177,7 +178,7 @@ class DeepseekAttention(nn.Module): ...@@ -177,7 +178,7 @@ class DeepseekAttention(nn.Module):
rope_theta: float = 10000, rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None, rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192, max_position_embeddings: int = 8192,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.hidden_size = hidden_size self.hidden_size = hidden_size
...@@ -208,14 +209,14 @@ class DeepseekAttention(nn.Module): ...@@ -208,14 +209,14 @@ class DeepseekAttention(nn.Module):
self.total_num_heads, self.total_num_heads,
self.total_num_kv_heads, self.total_num_kv_heads,
bias=False, bias=False,
linear_method=linear_method, quant_config=quant_config,
) )
self.o_proj = RowParallelLinear( self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim, self.total_num_heads * self.head_dim,
hidden_size, hidden_size,
bias=False, bias=False,
linear_method=linear_method, quant_config=quant_config,
) )
self.rotary_emb = get_rope( self.rotary_emb = get_rope(
...@@ -251,7 +252,7 @@ class DeepseekDecoderLayer(nn.Module): ...@@ -251,7 +252,7 @@ class DeepseekDecoderLayer(nn.Module):
self, self,
config: PretrainedConfig, config: PretrainedConfig,
layer_idx: int, layer_idx: int,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
...@@ -266,18 +267,18 @@ class DeepseekDecoderLayer(nn.Module): ...@@ -266,18 +267,18 @@ class DeepseekDecoderLayer(nn.Module):
rope_theta=rope_theta, rope_theta=rope_theta,
rope_scaling=rope_scaling, rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings, max_position_embeddings=max_position_embeddings,
linear_method=linear_method, quant_config=quant_config,
) )
if (config.n_routed_experts is not None if (config.n_routed_experts is not None
and layer_idx >= config.first_k_dense_replace and layer_idx >= config.first_k_dense_replace
and layer_idx % config.moe_layer_freq == 0): and layer_idx % config.moe_layer_freq == 0):
self.mlp = DeepseekMoE(config=config, linear_method=linear_method) self.mlp = DeepseekMoE(config=config, quant_config=quant_config)
else: else:
self.mlp = DeepseekMLP( self.mlp = DeepseekMLP(
hidden_size=config.hidden_size, hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size, intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act, hidden_act=config.hidden_act,
linear_method=linear_method, quant_config=quant_config,
) )
self.input_layernorm = RMSNorm(config.hidden_size, self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps) eps=config.rms_norm_eps)
...@@ -320,7 +321,7 @@ class DeepseekModel(nn.Module): ...@@ -320,7 +321,7 @@ class DeepseekModel(nn.Module):
def __init__( def __init__(
self, self,
config: PretrainedConfig, config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.padding_idx = config.pad_token_id self.padding_idx = config.pad_token_id
...@@ -331,9 +332,7 @@ class DeepseekModel(nn.Module): ...@@ -331,9 +332,7 @@ class DeepseekModel(nn.Module):
config.hidden_size, config.hidden_size,
) )
self.layers = nn.ModuleList([ self.layers = nn.ModuleList([
DeepseekDecoderLayer(config, DeepseekDecoderLayer(config, layer_idx, quant_config=quant_config)
layer_idx,
linear_method=linear_method)
for layer_idx in range(config.num_hidden_layers) for layer_idx in range(config.num_hidden_layers)
]) ])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
...@@ -361,12 +360,12 @@ class DeepseekForCausalLM(nn.Module): ...@@ -361,12 +360,12 @@ class DeepseekForCausalLM(nn.Module):
def __init__( def __init__(
self, self,
config: PretrainedConfig, config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
self.linear_method = linear_method self.quant_config = quant_config
self.model = DeepseekModel(config, linear_method) self.model = DeepseekModel(config, quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
......
...@@ -32,10 +32,11 @@ from vllm.distributed import (get_tensor_model_parallel_rank, ...@@ -32,10 +32,11 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
tensor_model_parallel_all_reduce) tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
...@@ -76,7 +77,7 @@ class FalconAttention(nn.Module): ...@@ -76,7 +77,7 @@ class FalconAttention(nn.Module):
def __init__( def __init__(
self, self,
config: FalconConfig, config: FalconConfig,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
...@@ -115,7 +116,7 @@ class FalconAttention(nn.Module): ...@@ -115,7 +116,7 @@ class FalconAttention(nn.Module):
self.total_num_kv_heads, self.total_num_kv_heads,
bias=config.bias, bias=config.bias,
skip_bias_add=True, skip_bias_add=True,
linear_method=linear_method, quant_config=quant_config,
) )
self.q_size = self.num_heads * self.head_dim self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim
...@@ -129,7 +130,7 @@ class FalconAttention(nn.Module): ...@@ -129,7 +130,7 @@ class FalconAttention(nn.Module):
self.hidden_size, self.hidden_size,
bias=config.bias, bias=config.bias,
skip_bias_add=True, skip_bias_add=True,
linear_method=linear_method, quant_config=quant_config,
reduce_results=self.reduce_row_parallel_results) reduce_results=self.reduce_row_parallel_results)
self.use_rotary = config.rotary self.use_rotary = config.rotary
...@@ -192,7 +193,7 @@ class FalconMLP(nn.Module): ...@@ -192,7 +193,7 @@ class FalconMLP(nn.Module):
def __init__( def __init__(
self, self,
config: FalconConfig, config: FalconConfig,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
hidden_size = config.hidden_size hidden_size = config.hidden_size
...@@ -201,8 +202,7 @@ class FalconMLP(nn.Module): ...@@ -201,8 +202,7 @@ class FalconMLP(nn.Module):
4 * hidden_size, 4 * hidden_size,
bias=config.bias, bias=config.bias,
skip_bias_add=True, skip_bias_add=True,
linear_method=linear_method) quant_config=quant_config)
quant_config = getattr(linear_method, "quant_config", None)
self.act = get_act_fn("gelu", quant_config, 4 * hidden_size) self.act = get_act_fn("gelu", quant_config, 4 * hidden_size)
self.reduce_row_parallel_results = not (config.new_decoder_architecture self.reduce_row_parallel_results = not (config.new_decoder_architecture
or config.parallel_attn) or config.parallel_attn)
...@@ -212,7 +212,7 @@ class FalconMLP(nn.Module): ...@@ -212,7 +212,7 @@ class FalconMLP(nn.Module):
bias=config.bias, bias=config.bias,
skip_bias_add=True, skip_bias_add=True,
reduce_results=self.reduce_row_parallel_results, reduce_results=self.reduce_row_parallel_results,
linear_method=linear_method) quant_config=quant_config)
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
# NOTE(zhuohan): Following huggingface, we do not fuse bias add here. # NOTE(zhuohan): Following huggingface, we do not fuse bias add here.
...@@ -229,13 +229,13 @@ class FalconDecoderLayer(nn.Module): ...@@ -229,13 +229,13 @@ class FalconDecoderLayer(nn.Module):
def __init__( def __init__(
self, self,
config: FalconConfig, config: FalconConfig,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
hidden_size = config.hidden_size hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads self.num_heads = config.num_attention_heads
self.self_attention = FalconAttention(config, linear_method) self.self_attention = FalconAttention(config, quant_config)
self.mlp = FalconMLP(config, linear_method) self.mlp = FalconMLP(config, quant_config)
self.config = config self.config = config
if config.new_decoder_architecture: if config.new_decoder_architecture:
...@@ -311,7 +311,7 @@ class FalconModel(nn.Module): ...@@ -311,7 +311,7 @@ class FalconModel(nn.Module):
def __init__( def __init__(
self, self,
config: FalconConfig, config: FalconConfig,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -327,7 +327,7 @@ class FalconModel(nn.Module): ...@@ -327,7 +327,7 @@ class FalconModel(nn.Module):
# Transformer blocks # Transformer blocks
self.h = nn.ModuleList([ self.h = nn.ModuleList([
FalconDecoderLayer(config, linear_method) FalconDecoderLayer(config, quant_config)
for _ in range(config.num_hidden_layers) for _ in range(config.num_hidden_layers)
]) ])
...@@ -359,12 +359,12 @@ class FalconForCausalLM(nn.Module): ...@@ -359,12 +359,12 @@ class FalconForCausalLM(nn.Module):
def __init__( def __init__(
self, self,
config: FalconConfig, config: FalconConfig,
linear_method: Optional[LinearMethodBase] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
self.config = config self.config = config
self.linear_method = linear_method self.quant_config = quant_config
self.transformer = FalconModel(config, linear_method) self.transformer = FalconModel(config, quant_config)
self.lm_head_weight = self.transformer.word_embeddings.weight self.lm_head_weight = self.transformer.word_embeddings.weight
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment