Commit 1591c68f authored by zhuwenwen's avatar zhuwenwen
Browse files

merge v0.4.2

parents 09bcf00b c7f2cf2b
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
from torch.nn import Module
from torch.nn.parameter import Parameter
from vllm.model_executor.layers.linear import (LinearMethodBase,
set_weight_attrs)
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.utils import set_weight_attrs
ACTIVATION_SCHEMES = ["static", "dynamic"]
class FP8Config(QuantizationConfig):
logger = init_logger(__name__)
class Fp8Config(QuantizationConfig):
"""Config class for FP8."""
def __init__(
self,
is_checkpoint_fp8_serialized: bool = False,
activation_scheme: str = "dynamic",
) -> None:
self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
if is_checkpoint_fp8_serialized:
logger.warning("Detected fp8 checkpoint. Please note that the "
"format is experimental and subject to change.")
if activation_scheme not in ACTIVATION_SCHEMES:
raise ValueError(
f"Unsupported activation scheme {activation_scheme}")
self.activation_scheme = activation_scheme
@classmethod
def get_name(cls) -> str:
return "fp8"
......@@ -23,21 +43,25 @@ class FP8Config(QuantizationConfig):
@classmethod
def get_min_capability(cls) -> int:
# TODO: PyTorch 2.3.0+ is required to run FP8 on
# SM 89 (e.g. Ada) GPUs. Specifically, this PR has to
# be included: https://github.com/pytorch/pytorch/pull/118881
return 90
return 89
@classmethod
def get_config_filenames(cls) -> List[str]:
return []
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "FP8Config":
return cls()
def from_config(cls, config: Dict[str, Any]) -> "Fp8Config":
quant_method = cls.get_from_keys(config, ["quant_method"])
is_checkpoint_fp8_serialized = ("fp8" in quant_method)
activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
return cls(is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
activation_scheme=activation_scheme)
def get_linear_method(self) -> "Fp8LinearMethod":
return Fp8LinearMethod(self)
def get_quant_method(
self, layer: torch.nn.Module) -> Optional["Fp8LinearMethod"]:
if isinstance(layer, LinearBase):
return Fp8LinearMethod(self)
return None
def get_scaled_act_names(self) -> List[str]:
return []
......@@ -45,8 +69,12 @@ class FP8Config(QuantizationConfig):
class Fp8LinearMethod(LinearMethodBase):
"""Linear method for FP8.
We now support common FP16/BF16 model checkpoints ONLY. The weight
scaling factor will be initialized after the model weights are loaded.
Supports loading FP8 checkpoints with static weight scale and
dynamic/static activation scale.
Also supports loading quantized FP16/BF16 model checkpoints with dynamic
activation scaling. The weight scaling factor will be initialized after
the model weights are loaded.
Limitations:
1. Only support per-tensor quantization due to torch._scaled_mm support.
......@@ -57,9 +85,27 @@ class Fp8LinearMethod(LinearMethodBase):
quant_config: The quantization config.
"""
def __init__(self, quant_config: FP8Config):
def __init__(self, quant_config: Fp8Config):
self.quant_config = quant_config
def _create_scale_param(
self,
scale_name: str,
layer: torch.nn.Module,
output_partition_sizes: List[int],
**extra_weight_attrs,
) -> None:
scale = Parameter(torch.empty(len(output_partition_sizes),
dtype=torch.float32),
requires_grad=False)
layer.register_parameter(scale_name, scale)
set_weight_attrs(
scale, {
**extra_weight_attrs,
"fp8_scales_shard_indexer":
self.scales_shard_indexer,
})
def create_weights(
self,
layer: torch.nn.Module,
......@@ -70,70 +116,150 @@ class Fp8LinearMethod(LinearMethodBase):
params_dtype: torch.dtype,
**extra_weight_attrs,
):
del input_size, output_size
output_size_per_partition = sum(output_partition_sizes)
layer.process_after_load = True
layer.logical_widths = output_partition_sizes
# WEIGHT
weight_dtype = (torch.float8_e4m3fn
if self.quant_config.is_checkpoint_fp8_serialized else
params_dtype)
weight = Parameter(torch.empty(output_size_per_partition,
input_size_per_partition,
dtype=params_dtype),
dtype=weight_dtype),
requires_grad=False)
layer.register_parameter("weight", weight)
set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
set_weight_attrs(weight, extra_weight_attrs)
set_weight_attrs(weight, {
**extra_weight_attrs,
"input_dim": 1,
"output_dim": 0,
})
w_scale = Parameter(
torch.empty(1, dtype=torch.float32),
requires_grad=False,
)
layer.register_parameter("weight_scaling_factor", w_scale)
# If checkpoint is serialized fp8, load them.
# Otherwise, wait until process_weights_after_loading.
if self.quant_config.is_checkpoint_fp8_serialized:
# WEIGHT SCALE
self._create_scale_param(
scale_name="weight_scale",
layer=layer,
output_partition_sizes=output_partition_sizes,
**extra_weight_attrs)
# ACTIVATION SCALE
if self.quant_config.activation_scheme == "static":
self._create_scale_param(
scale_name="act_scale",
layer=layer,
output_partition_sizes=output_partition_sizes,
**extra_weight_attrs)
def scales_shard_indexer(
self, param: torch.Tensor, loaded_weight: torch.Tensor,
shard_id: Union[str, int]) -> Tuple[torch.Tensor, torch.Tensor]:
qkv_idxs = {"q": 0, "k": 1, "v": 2}
if isinstance(shard_id, int):
pass
elif isinstance(shard_id, str):
if shard_id not in qkv_idxs:
raise ValueError(f"Unknown shard_id: {shard_id}")
shard_id = qkv_idxs[shard_id]
else:
ValueError(f"Shard id must be int or str but got {type(shard_id)}")
return param[shard_id], loaded_weight
def process_weights_after_loading(self, layer: Module) -> None:
# Although the linear_method is propagated to all layers,
# only linear layers invoke "create_weights". So we check
# whether "weight_scaling_facor" is registered to determine
# whether the layer is a linear layer that requires quantization.
if not hasattr(layer, "weight_scaling_factor"):
if (not hasattr(layer, "process_after_load")
or not layer.process_after_load):
return
# If checkpoint is fp/bf16 (not serialized fp8), quantize the weights.
if not self.quant_config.is_checkpoint_fp8_serialized:
qweight, weight_scale = ops.scaled_fp8_quant(layer.weight,
scale=None)
layer.weight = Parameter(qweight.t(), requires_grad=False)
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
layer.logical_widths = None
layer.act_scale = None
return
qweight, weight_scale = per_tensor_quantize(layer.weight)
# torch._scaled_mm requires column-major in the second
# input (weight), so we transpose the quantized weight.
layer.weight = Parameter(qweight.t(), requires_grad=False)
layer.weight_scaling_factor.data.copy_(weight_scale)
def apply_weights(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
qinput, x_scale = per_tensor_quantize(x)
# If checkpoint is fp8, requantize the separately quantized logical
# weights into a single fp8 weight with a single weight scale.
else:
# WEIGHT_SCALE / WEIGHT
# Loop over logical weights, requantizing with single scale.
max_w_scale = layer.weight_scale.max()
start = 0
for idx, logical_width in enumerate(layer.logical_widths):
end = start + logical_width
weight_dq = per_tensor_dequantize(layer.weight[start:end, :],
layer.weight_scale[idx])
layer.weight[start:end, :] = per_tensor_quantize(
weight_dq, layer.weight_scale.max())
start = end
layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
# WEIGHT
# Transpose weight for passing to torch._scaled_mm
weight = layer.weight
layer.weight = Parameter(weight.t(), requires_grad=False)
# ACT_SCALE
# Dynamic: set to None (required input to ops.scaled_fp8_quant).
# Static: set to max of the act_scales (since they are equal).
if self.quant_config.activation_scheme == "dynamic":
layer.act_scale = None
elif self.quant_config.activation_scheme == "static":
if not all_close_1d(layer.act_scale):
raise ValueError(
"All the act_scales for the logical weights of a layer "
f"must be equal. But got {layer.act_scale}")
layer.act_scale = Parameter(layer.act_scale.max(),
requires_grad=False)
else:
raise ValueError(
f"Unknown scheme {self.quant_config.activation_scheme}")
def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
# ops.scaled_fp8_quant supports both dynamic and static quant.
# If dynamic, layer.act_scale is None and x_scale computed from x.
# If static, layer.act_scale is scalar and x_scale set to act_scale.
qinput, x_scale = ops.scaled_fp8_quant(x, layer.act_scale)
# Fused GEMM_DQ
output, _ = torch._scaled_mm(
qinput,
layer.weight,
out_dtype=x.dtype,
scale_a=x_scale,
scale_b=layer.weight_scaling_factor,
scale_b=layer.weight_scale,
bias=bias,
)
return output
def per_tensor_quantize(tensor: torch.Tensor) -> Tuple[torch.Tensor, float]:
"""Quantize a tensor using per-tensor static scaling factor.
def all_close_1d(x: torch.Tensor) -> bool:
assert len(x.shape) == 1
return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0]))
Args:
tensor: The input tensor.
"""
def per_tensor_quantize(tensor: torch.Tensor,
inv_scale: float) -> torch.Tensor:
finfo = torch.finfo(torch.float8_e4m3fn)
# Calculate the scale as dtype max divided by absmax.
# Since .abs() creates a new tensor, we use aminmax to get
# the min and max first and then calculate the absmax.
min_val, max_val = tensor.aminmax()
amax = min_val.abs().max(max_val.abs())
scale = finfo.max / amax.clamp(min=1e-12)
# scale and clamp the tensor to bring it to
# the representative range of float8 data type
# (as default cast is unsaturated)
qweight = (tensor * scale).clamp(min=finfo.min, max=finfo.max)
# Return both float8 data and the inverse scale (as float),
# as both required as inputs to torch._scaled_mm
qweight = qweight.to(torch.float8_e4m3fn)
scale = scale.float().reciprocal()
return qweight, scale
qweight = (tensor / inv_scale).clamp(min=finfo.min, max=finfo.max)
return qweight.to(torch.float8_e4m3fn)
def per_tensor_dequantize(tensor: torch.Tensor,
inv_scale: float) -> torch.Tensor:
fake_qweight = tensor.to(torch.float16)
dq_weight = fake_qweight * inv_scale
return dq_weight
......@@ -7,10 +7,10 @@ import torch
from torch.nn.parameter import Parameter
from vllm import _custom_ops as ops
from vllm.model_executor.layers.linear import (LinearMethodBase,
set_weight_attrs)
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.utils import set_weight_attrs
class GPTQConfig(QuantizationConfig):
......@@ -63,8 +63,11 @@ class GPTQConfig(QuantizationConfig):
desc_act = cls.get_from_keys(config, ["desc_act"])
return cls(weight_bits, group_size, desc_act)
def get_linear_method(self) -> "GPTQLinearMethod":
return GPTQLinearMethod(self)
def get_quant_method(
self, layer: torch.nn.Module) -> Optional["GPTQLinearMethod"]:
if isinstance(layer, LinearBase):
return GPTQLinearMethod(self)
return None
def get_scaled_act_names(self) -> List[str]:
return []
......@@ -194,10 +197,10 @@ class GPTQLinearMethod(LinearMethodBase):
layer.exllama_state = exllama_state
def apply_weights(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
qweight = layer.qweight
out_shape = x.shape[:-1] + (qweight.shape[-1], )
reshaped_x = x.reshape(-1, x.shape[-1])
......
import enum
from enum import Enum
from typing import Any, Dict, List, Optional
import torch
from torch.nn.parameter import Parameter
from vllm import _custom_ops as ops
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
GPTQ_MARLIN_TILE = 16
GPTQ_MARLIN_MIN_THREAD_N = 64
GPTQ_MARLIN_MIN_THREAD_K = 128
GPTQ_MARLIN_MAX_PARALLEL = 16
GPTQ_MARLIN_SUPPORTED_NUM_BITS = [4, 8]
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
GPTQ_MARLIN_SUPPORTED_SYM = [True]
# Permutations for Marlin scale shuffling
def get_scale_perms(num_bits):
scale_perm = []
for i in range(8):
scale_perm.extend([i + 8 * j for j in range(8)])
scale_perm_single = []
for i in range(4):
scale_perm_single.extend(
[2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
return scale_perm, scale_perm_single
def get_pack_factor(num_bits):
assert (num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS
), f"Unsupported num_bits = {num_bits}"
return 32 // num_bits
def marlin_permute_scales(s, size_k, size_n, group_size, num_bits):
scale_perm, scale_perm_single = get_scale_perms(num_bits)
if group_size < size_k and group_size != -1:
s = s.reshape((-1, len(scale_perm)))[:, scale_perm]
else:
s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]
s = s.reshape((-1, size_n)).contiguous()
return s
class GPTQMarlinConfig(QuantizationConfig):
"""Config class for GPTQ Marlin"""
def __init__(self, weight_bits: int, group_size: int, desc_act: bool,
is_sym: bool) -> None:
if desc_act and group_size == -1:
# In this case, act_order == True is the same as act_order == False
# (since we have only one group per output channel)
desc_act = False
self.weight_bits = weight_bits
self.group_size = group_size
self.desc_act = desc_act
self.is_sym = is_sym
# Verify
if self.weight_bits not in GPTQ_MARLIN_SUPPORTED_NUM_BITS:
raise ValueError(
f"Marlin does not support weight_bits = {self.weight_bits}. "
f"Only weight_bits = {GPTQ_MARLIN_SUPPORTED_NUM_BITS} "
"are supported.")
if self.group_size not in GPTQ_MARLIN_SUPPORTED_GROUP_SIZES:
raise ValueError(
f"Marlin does not support group_size = {self.group_size}. "
f"Only group_sizes = {GPTQ_MARLIN_SUPPORTED_GROUP_SIZES} "
"are supported.")
if self.is_sym not in GPTQ_MARLIN_SUPPORTED_SYM:
raise ValueError(
f"Marlin does not support is_sym = {self.is_sym}. "
f"Only sym = {GPTQ_MARLIN_SUPPORTED_SYM} are supported.")
# Init
self.pack_factor = get_pack_factor(weight_bits)
self.tile_size = GPTQ_MARLIN_TILE
self.min_thread_n = GPTQ_MARLIN_MIN_THREAD_N
self.min_thread_k = GPTQ_MARLIN_MIN_THREAD_K
self.max_parallel = GPTQ_MARLIN_MAX_PARALLEL
def __repr__(self) -> str:
return (f"GPTQMarlinConfig(weight_bits={self.weight_bits}, "
f"group_size={self.group_size}, "
f"desc_act={self.desc_act})")
@classmethod
def get_name(cls) -> str:
return "gptq_marlin"
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
return [torch.half]
@classmethod
def get_min_capability(cls) -> int:
return 80
@classmethod
def get_config_filenames(cls) -> List[str]:
return ["quantize_config.json"]
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "GPTQMarlinConfig":
weight_bits = cls.get_from_keys(config, ["bits"])
group_size = cls.get_from_keys(config, ["group_size"])
desc_act = cls.get_from_keys(config, ["desc_act"])
is_sym = cls.get_from_keys(config, ["sym"])
return cls(weight_bits, group_size, desc_act, is_sym)
def get_quant_method(
self,
layer: torch.nn.Module) -> Optional["GPTQMarlinLinearMethod"]:
if isinstance(layer, LinearBase):
return GPTQMarlinLinearMethod(self)
return None
def get_scaled_act_names(self) -> List[str]:
return []
@classmethod
def is_marlin_compatible(cls, quant_config: Dict[str, Any]):
# Extract data from quant config.
num_bits = quant_config.get("bits", None)
group_size = quant_config.get("group_size", None)
sym = quant_config.get("sym", None)
desc_act = quant_config.get("desc_act", None)
# If we cannot find the info needed in the config, cannot convert.
if (num_bits is None or group_size is None or sym is None
or desc_act is None):
return False
# If the capability of the device is too low, cannot convert.
major, minor = torch.cuda.get_device_capability()
device_capability = major * 10 + minor
if device_capability < cls.get_min_capability():
return False
# Otherwise, can convert if model satisfies marlin constraints.
return (num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS
and group_size in GPTQ_MARLIN_SUPPORTED_GROUP_SIZES
and sym in GPTQ_MARLIN_SUPPORTED_SYM)
class GPTQMarlinState(Enum):
REPACK = enum.auto()
READY = enum.auto()
class GPTQMarlinLinearMethod(LinearMethodBase):
"""Linear method for GPTQ Marlin.
Args:
quant_config: The GPTQ Marlin quantization config.
"""
def __init__(self, quant_config: GPTQMarlinConfig) -> None:
self.quant_config = quant_config
def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
) -> None:
del output_size
# Normalize group_size
if self.quant_config.group_size != -1:
group_size = self.quant_config.group_size
else:
group_size = input_size
# Validate dtype
if params_dtype != torch.float16:
raise ValueError(
f"The params dtype must be float16, but got {params_dtype}")
# Validate output_size_per_partition
output_size_per_partition = sum(output_partition_sizes)
if output_size_per_partition % self.quant_config.min_thread_n != 0:
raise ValueError(
f"Weight output_size_per_partition = "
f"{output_size_per_partition} is not divisible by "
f" min_thread_n = {self.quant_config.min_thread_n}.")
# Validate input_size_per_partition
if input_size_per_partition % self.quant_config.min_thread_k != 0:
raise ValueError(
f"Weight input_size_per_partition = "
f"{input_size_per_partition} is not divisible "
f"by min_thread_k = {self.quant_config.min_thread_k}.")
if (group_size < input_size
and input_size_per_partition % group_size != 0):
raise ValueError(
f"Weight input_size_per_partition = {input_size_per_partition}"
f" is not divisible by group_size = {group_size}.")
# Detect sharding of scales/zp
# By default, no sharding over "input dim"
scales_and_zp_size = input_size // group_size
scales_and_zp_input_dim = None
if self.quant_config.desc_act:
# Act-order case
assert self.quant_config.group_size != -1
is_k_full = input_size_per_partition == input_size
else:
# No act-order case
# K is always full due to full alignment with
# group-size and shard of scales/zp
is_k_full = True
# If this is a row-parallel case, then shard scales/zp
if (input_size != input_size_per_partition
and self.quant_config.group_size != -1):
scales_and_zp_size = input_size_per_partition // group_size
scales_and_zp_input_dim = 0
# Init buffers
# Quantized weights
qweight = Parameter(
torch.empty(
input_size_per_partition // self.quant_config.pack_factor,
output_size_per_partition,
dtype=torch.int32,
),
requires_grad=False,
)
set_weight_attrs(
qweight,
{
**extra_weight_attrs,
"input_dim": 0,
"output_dim": 1,
"packed_dim": 0,
"pack_factor": self.quant_config.pack_factor,
},
)
# Activation order
g_idx = Parameter(
torch.empty(
input_size_per_partition,
dtype=torch.int32,
),
requires_grad=False,
)
# Ignore warning from fused linear layers such as QKVParallelLinear.
set_weight_attrs(
g_idx,
{
**extra_weight_attrs, "input_dim": 0,
"ignore_warning": True
},
)
g_idx_sort_indices = Parameter(
torch.empty(
g_idx.shape,
dtype=torch.int32,
),
requires_grad=False,
)
set_weight_attrs(g_idx_sort_indices, extra_weight_attrs)
# Scales
scales = Parameter(
torch.empty(
scales_and_zp_size,
output_size_per_partition,
dtype=params_dtype,
),
requires_grad=False,
)
set_weight_attrs(
scales,
{
**extra_weight_attrs,
"input_dim": scales_and_zp_input_dim,
"output_dim": 1,
},
)
# Quantized zero-points
qzeros = Parameter(
torch.empty(
scales_and_zp_size,
output_size_per_partition // self.quant_config.pack_factor,
dtype=torch.int32,
device="meta",
),
requires_grad=False,
)
set_weight_attrs(
qzeros,
{
**extra_weight_attrs,
"input_dim": scales_and_zp_input_dim,
"output_dim": 1,
"packed_dim": 1,
"pack_factor": self.quant_config.pack_factor,
},
)
# Allocate marlin workspace
max_workspace_size = (
output_size_per_partition //
self.quant_config.min_thread_n) * self.quant_config.max_parallel
workspace = torch.zeros(max_workspace_size,
dtype=torch.int,
requires_grad=False)
layer.register_parameter("qweight", qweight)
layer.register_parameter("g_idx", g_idx)
layer.register_parameter("g_idx_sort_indices", g_idx_sort_indices)
layer.register_parameter("scales", scales)
layer.register_parameter("qzeros", qzeros)
layer.workspace = workspace
layer.input_size_per_partition = input_size_per_partition
layer.output_size_per_partition = output_size_per_partition
layer.input_size = input_size
layer.is_k_full = is_k_full
layer.marlin_state = GPTQMarlinState.REPACK
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
reshaped_x = x.reshape(-1, x.shape[-1])
size_m = reshaped_x.shape[0]
part_size_n = layer.output_size_per_partition
part_size_k = layer.input_size_per_partition
full_size_k = layer.input_size
out_shape = x.shape[:-1] + (part_size_n, )
if layer.marlin_state == GPTQMarlinState.REPACK:
layer.marlin_state = GPTQMarlinState.READY
# Newly generated tensors need to replace existing tensors that are
# already registered as parameters by vLLM (and won't be freed)
def replace_tensor(name, new_t):
# It is important to use resize_() here since it ensures
# the same buffer is reused
getattr(layer, name).resize_(new_t.shape)
getattr(layer, name).copy_(new_t)
del new_t
cur_device = layer.qweight.device
# Process act_order
if self.quant_config.desc_act:
# Get sorting based on g_idx
g_idx_sort_indices = torch.argsort(layer.g_idx).to(torch.int)
sorted_g_idx = layer.g_idx[g_idx_sort_indices]
replace_tensor("g_idx", sorted_g_idx)
replace_tensor("g_idx_sort_indices", g_idx_sort_indices)
else:
# Reset g_idx related tensors
layer.g_idx = Parameter(
torch.empty(0, dtype=torch.int, device=cur_device),
requires_grad=False,
)
layer.g_idx_sort_indices = Parameter(
torch.empty(0, dtype=torch.int, device=cur_device),
requires_grad=False,
)
# Repack weights
marlin_qweight = ops.gptq_marlin_repack(
layer.qweight,
layer.g_idx_sort_indices,
part_size_k,
part_size_n,
self.quant_config.weight_bits,
)
replace_tensor("qweight", marlin_qweight)
# Permute scales
scales_size_k = part_size_k
scales_size_n = part_size_n
if self.quant_config.desc_act:
scales_size_k = full_size_k
marlin_scales = marlin_permute_scales(
layer.scales,
scales_size_k,
scales_size_n,
self.quant_config.group_size,
self.quant_config.weight_bits,
)
replace_tensor("scales", marlin_scales)
output = ops.gptq_marlin_gemm(
reshaped_x,
layer.qweight,
layer.scales,
layer.g_idx,
layer.g_idx_sort_indices,
layer.workspace,
self.quant_config.weight_bits,
size_m,
part_size_n,
part_size_k,
layer.is_k_full,
)
if bias is not None:
output.add_(bias) # In-place add
return output.reshape(out_shape)
......@@ -4,10 +4,10 @@ import torch
from torch.nn.parameter import Parameter
from vllm import _custom_ops as ops
from vllm.model_executor.layers.linear import (LinearMethodBase,
set_weight_attrs)
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.utils import set_weight_attrs
class MarlinConfig(QuantizationConfig):
......@@ -72,8 +72,11 @@ class MarlinConfig(QuantizationConfig):
group_size = cls.get_from_keys(config, ["group_size"])
return cls(group_size)
def get_linear_method(self) -> "MarlinLinearMethod":
return MarlinLinearMethod(self)
def get_quant_method(
self, layer: torch.nn.Module) -> Optional["MarlinLinearMethod"]:
if isinstance(layer, LinearBase):
return MarlinLinearMethod(self)
return None
def get_scaled_act_names(self) -> List[str]:
return []
......@@ -197,7 +200,7 @@ class MarlinLinearMethod(LinearMethodBase):
layer.register_parameter("workspace", workspace)
set_weight_attrs(workspace, extra_weight_attrs)
def apply_weights(
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
......
......@@ -4,10 +4,10 @@ import torch
from torch.nn.parameter import Parameter
from vllm import _custom_ops as ops
from vllm.model_executor.layers.linear import (LinearMethodBase,
set_weight_attrs)
from vllm.model_executor.layers.linear import LinearBase
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.utils import set_weight_attrs
from vllm.utils import is_hip
......@@ -51,14 +51,17 @@ class SqueezeLLMConfig(QuantizationConfig):
weight_bits = cls.get_from_keys(config, ["wbits"])
return cls(weight_bits)
def get_linear_method(self) -> "SqueezeLLMLinearMethod":
return SqueezeLLMLinearMethod(self)
def get_quant_method(
self, layer: torch.nn.Module) -> Optional[QuantizeMethodBase]:
if isinstance(layer, LinearBase):
return SqueezeLLMLinearMethod(self)
return None
def get_scaled_act_names(self) -> List[str]:
return []
class SqueezeLLMLinearMethod(LinearMethodBase):
class SqueezeLLMLinearMethod(QuantizeMethodBase):
"""Linear method for SqueezeLLM.
Args:
......@@ -112,10 +115,10 @@ class SqueezeLLMLinearMethod(LinearMethodBase):
layer.register_parameter("lookup_table", lookup_table)
set_weight_attrs(lookup_table, extra_weight_attrs)
def apply_weights(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
qweight = layer.qweight
lookup_table = layer.lookup_table
out_shape = x.shape[:-1] + (qweight.shape[-1], )
......
......@@ -156,6 +156,12 @@ class RotaryEmbedding(nn.Module):
self.cos_sin_cache, self.is_neox_style)
return query, key
def extra_repr(self) -> str:
s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}"
s += f", max_position_embeddings={self.max_position_embeddings}"
s += f", base={self.base}, is_neox_style={self.is_neox_style}"
return s
class LinearScalingRotaryEmbedding(RotaryEmbedding):
"""RotaryEmbedding extended with linear scaling.
......@@ -338,6 +344,114 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
return cache
class Phi3SuScaledRotaryEmbedding(nn.Module):
"""Phi3 family of models scaled rotary embedding.
Based on the original RotaryEmbedding implementation.
"""
def __init__(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
original_max_position_embeddings: int,
base: int,
is_neox_style: bool,
short_factor: List[float],
long_factor: List[float],
short_mscale: float = 1.1,
long_mscale: float = 1.225,
):
super().__init__()
if rotary_dim != head_size:
raise ValueError(
f"`Phi3SuScaledRotaryEmbedding` does not support rotary_dim != \
head_size ({rotary_dim}!={head_size}).")
if is_neox_style is False:
raise ValueError(
"`Phi3SuScaledRotaryEmbedding` only supports neox_style.")
self.head_size = head_size
self.max_position_embeddings = max_position_embeddings
self.original_max_position_embeddings = original_max_position_embeddings
self.base = base
self.short_factor = short_factor
self.long_factor = long_factor
self.short_mscale = short_mscale
self.long_mscale = long_mscale
short_cache = self._compute_cos_sin_cache(
original_max_position_embeddings, short_factor, short_mscale)
short_cache = short_cache.to(torch.get_default_dtype())
self.register_buffer("short_cos_sin_cache",
short_cache,
persistent=False)
long_cache = self._compute_cos_sin_cache(max_position_embeddings,
long_factor, long_mscale)
long_cache = long_cache.to(torch.get_default_dtype())
self.register_buffer("long_cos_sin_cache",
long_cache,
persistent=False)
long_short_cache = torch.cat(
[self.short_cos_sin_cache, self.long_cos_sin_cache], dim=0)
self.register_buffer("long_short_cos_sin_cache",
long_short_cache,
persistent=False)
def _compute_inv_freq(self, rescale_factors: List[float]) -> torch.Tensor:
rescale_factors = torch.tensor(rescale_factors, dtype=torch.float32)
inv_freq = 1.0 / (rescale_factors * (self.base**(torch.arange(
0, self.head_size, 2, dtype=torch.float) / self.head_size)))
return inv_freq
def _compute_cos_sin_cache(
self,
max_position_embeddings: int,
rescale_factors: List[float],
mscale: float,
) -> torch.Tensor:
inv_freq = self._compute_inv_freq(rescale_factors)
t = torch.arange(max_position_embeddings, dtype=torch.float)
freqs = torch.einsum("i,j -> ij", t, inv_freq)
cos = freqs.cos() * mscale
sin = freqs.sin() * mscale
cache = torch.cat((cos, sin), dim=-1)
return cache
def forward(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
query = query.view(*query.shape[:-1], -1, self.head_size)
key = key.view(*key.shape[:-1], -1, self.head_size)
k = self.original_max_position_embeddings
long_prompt_offset = (torch.any(positions > k).float() *
torch.full_like(positions, k)).long()
idx = (torch.add(positions, long_prompt_offset)
if long_prompt_offset is not None else positions)
self.long_short_cos_sin_cache: torch.Tensor = (
self.long_short_cos_sin_cache.to(idx.device))
idx = torch.add(idx, offsets) if offsets is not None else idx
cos_sin = torch.index_select(self.long_short_cos_sin_cache, 0, idx)
cos, sin = cos_sin.chunk(2, dim=-1)
cos = cos.repeat(1, 2).unsqueeze(-2)
sin = sin.repeat(1, 2).unsqueeze(-2)
query = query * cos + _rotate_neox(query) * sin
key = key * cos + _rotate_neox(key) * sin
return query.flatten(-2), key.flatten(-2)
_ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {}
......@@ -349,17 +463,26 @@ def get_rope(
is_neox_style: bool = True,
rope_scaling: Optional[Dict[str, Any]] = None,
) -> RotaryEmbedding:
if rope_scaling is not None:
# Transforms every value that is a list into a tuple for caching calls
rope_scaling_tuple = {
k: tuple(v) if isinstance(v, list) else v
for k, v in rope_scaling.items()
}
rope_scaling_args = tuple(rope_scaling_tuple.items())
else:
rope_scaling_args = None
key = (head_size, rotary_dim, max_position, base, is_neox_style,
tuple(rope_scaling.items()) if rope_scaling is not None else None)
rope_scaling_args)
if key in _ROPE_DICT:
return _ROPE_DICT[key]
if rope_scaling is None:
rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base,
is_neox_style)
else:
scaling_type = rope_scaling["type"]
scaling_factor = rope_scaling["factor"]
if scaling_type != "su":
scaling_factor = rope_scaling["factor"]
if scaling_type == "linear":
rotary_emb = LinearScalingRotaryEmbedding(head_size, rotary_dim,
max_position, base,
......@@ -383,6 +506,19 @@ def get_rope(
base, is_neox_style,
scaling_factor,
**extra_kwargs)
elif scaling_type == "su":
short_factor = rope_scaling["short_factor"]
long_factor = rope_scaling["long_factor"]
original_max_position = rope_scaling[
"original_max_position_embeddings"]
extra_kwargs = {
k: v
for k, v in rope_scaling.items()
if k in ("short_mscale", "long_mscale")
}
rotary_emb = Phi3SuScaledRotaryEmbedding(
head_size, rotary_dim, max_position, original_max_position,
base, is_neox_style, short_factor, long_factor, **extra_kwargs)
else:
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
_ROPE_DICT[key] = rotary_emb
......
......@@ -7,11 +7,14 @@ import torch.nn as nn
from vllm.model_executor.layers.ops.sample import sample as sample_triton
from vllm.model_executor.sampling_metadata import (SamplingMetadata,
SamplingTensors)
from vllm.sampling_params import SamplingParams, SamplingType
SamplingTensors,
SequenceGroupToSample)
from vllm.sampling_params import SamplingType
from vllm.sequence import (Logprob, PromptLogprobs, SampleLogprobs,
SamplerOutput, SequenceData, SequenceGroupOutput,
SequenceOutput)
SamplerOutput, SequenceGroupOutput, SequenceOutput)
# (num_token_ids, num_parent_ids) per sequence group.
SampleResultType = List[Tuple[List[int], List[int]]]
class Sampler(nn.Module):
......@@ -48,11 +51,14 @@ class Sampler(nn.Module):
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
"""
Args:
logits: (num_tokens, vocab_size).
sampling_metadata: Metadata for sampling.
"""
assert logits is not None
_, vocab_size = logits.shape
# Apply min_tokens penalty which sets stop tokens to -inf if min_tokens
# have not been generated yet
logits = _apply_min_tokens_penalty(logits, sampling_metadata)
# Prepare sampling tensors with pinned memory to avoid blocking.
......@@ -83,7 +89,6 @@ class Sampler(nn.Module):
# Compute the probabilities.
probs = torch.softmax(logits, dim=-1, dtype=torch.float)
# Compute the log probabilities.
# Use log_softmax to ensure numerical stability.
logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
# Sample the next tokens.
......@@ -98,8 +103,7 @@ class Sampler(nn.Module):
if self.include_gpu_probs_tensor:
assert maybe_sampled_tokens_tensor is not None
sampled_tokens_tensor = maybe_sampled_tokens_tensor
on_device_tensors = (probs, sampled_tokens_tensor)
on_device_tensors = (probs, logprobs, maybe_sampled_tokens_tensor)
else:
on_device_tensors = None
......@@ -149,46 +153,46 @@ def _apply_min_tokens_penalty(
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
"""Apply min_tokens penalty which sets stop tokens to -inf if min_tokens
have not been generated yet
"""
# list of indices in logits that will be set to -inf
logits_to_penalize = []
start_idx = 0
for i, seq_group in enumerate(sampling_metadata.seq_groups):
seq_ids, sampling_params = seq_group
# handle prompt_logprobs by skipping rows in logits added for the prompt
# tokens (prompt logprobs are not penalized)
if (i < sampling_metadata.num_prompts
and sampling_params.prompt_logprobs is not None):
assert len(seq_ids) == 1
start_idx += sampling_metadata.prompt_lens[i] - 1
logits_to_penalize: List[Tuple[int, int]] = []
logits_applied = 0
for seq_group in sampling_metadata.seq_groups:
seq_ids = seq_group.seq_ids
sampling_params = seq_group.sampling_params
sample_indices = seq_group.sample_indices
logits_applied += len(sample_indices) + len(
seq_group.prompt_logprob_indices)
if not seq_group.do_sample:
continue
start_idx = sample_indices[0]
min_tokens = sampling_params.min_tokens
if min_tokens > 0:
token_ids_to_penalize = sampling_params.all_stop_token_ids
if min_tokens > 0 and token_ids_to_penalize:
seqs_to_penalize = []
for i, seq_id in enumerate(seq_ids):
seq_data = sampling_metadata.seq_data[seq_id]
for j, seq_id in enumerate(seq_ids):
seq_data = seq_group.seq_data[seq_id]
if len(seq_data.output_token_ids) < min_tokens:
seqs_to_penalize.append(i)
seqs_to_penalize.append(j)
if seqs_to_penalize:
# convert to the index into logits
seqs_to_penalize = [start_idx + i for i in seqs_to_penalize]
# use set() to remove any duplicates
token_ids_to_penalize = set(sampling_params.stop_token_ids +
[sampling_params.eos_token_id])
seqs_to_penalize = [start_idx + j for j in seqs_to_penalize]
# itertools.product pairs each seq index with every token id
logits_to_penalize.extend(
itertools.product(seqs_to_penalize, token_ids_to_penalize))
start_idx += len(seq_ids)
if logits_to_penalize:
# use zip and * to group indices along each dimension
# eg. [ (1,2), (1,3), (5,6) ] -> ( (1,1,5), (2,3,6) )
logits[tuple(zip(*logits_to_penalize))] = -float("inf")
# verifies that no rows in logits were missed unexpectedly
assert start_idx == logits.shape[0]
assert logits_applied == logits.shape[0]
return logits
......@@ -265,14 +269,30 @@ def _apply_min_p(
def _greedy_sample(
selected_seq_groups: List[Tuple[List[int], SamplingParams]],
selected_seq_groups: List[SequenceGroupToSample],
samples: torch.Tensor,
) -> List[Tuple[List[int], List[int]]]:
) -> SampleResultType:
"""Run greedy sampling on a given samples.
Args:
selected_seq_groups: A list of sequence groups batched.
samples: (num_selected_samples,) A tensor of samples. The length of
samples could be smaller than selected_seq_groups if
seq_group.do_sample is False.
Returns:
Tuple of (next_token_ids, parent_ids). The length of returned list is
same as the length of selected_seq_groups. If the corresponding
seq_group has do_sample=False, tuple contains ([], [])
"""
samples = samples.tolist()
sample_idx = 0
results = []
results: SampleResultType = []
for seq_group in selected_seq_groups:
seq_ids, _ = seq_group
if not seq_group.do_sample:
results.append(([], []))
continue
seq_ids = seq_group.seq_ids
num_parent_seqs = len(seq_ids)
assert num_parent_seqs == 1, (
"Greedy sampling should have only one seq.")
......@@ -284,16 +304,33 @@ def _greedy_sample(
def _random_sample(
selected_seq_groups: List[Tuple[List[int], SamplingParams]],
is_prompts: List[bool],
selected_seq_groups: List[SequenceGroupToSample],
random_samples: torch.Tensor,
) -> List[Tuple[List[int], List[int]]]:
) -> SampleResultType:
"""Run random sampling on a given samples.
Args:
selected_seq_groups: A list of sequence groups batched.
random_samples: (num_selected_samples,) A tensor of samples. The
length of samples could be smaller than selected_seq_groups if
seq_group.do_sample is False.
Returns:
Tuple of (next_token_ids, parent_ids). The length of returned list is
same as the length of selected_seq_groups. If the corresponding
seq_group has do_sample=False, tuple contains ([], [])
"""
# Find the maximum best_of value of the prompt phase requests.
random_samples = random_samples.cpu()
sample_idx = 0
results = []
for seq_group, is_prompt in zip(selected_seq_groups, is_prompts):
seq_ids, sampling_params = seq_group
results: SampleResultType = []
for seq_group in selected_seq_groups:
if not seq_group.do_sample:
results.append(([], []))
continue
seq_ids = seq_group.seq_ids
sampling_params = seq_group.sampling_params
is_prompt = seq_group.is_prompt
num_parent_seqs = len(seq_ids)
if is_prompt:
# Prompt phase.
......@@ -311,11 +348,20 @@ def _random_sample(
def _beam_search_sample(
selected_seq_groups: List[Tuple[List[int], SamplingParams]],
is_prompts: List[bool],
seq_data: Dict[int, SequenceData],
selected_seq_groups: List[SequenceGroupToSample],
logprobs: torch.Tensor,
) -> List[Tuple[List[int], List[int]]]:
) -> SampleResultType:
"""Run beam sampling on a given samples.
Args:
selected_seq_groups: A list of sequence groups batched.
logprobs: (num_selected_samples, vocab_size,) A tensor of logprob
on selected sample indices.
Returns:
Tuple of (next_token_ids, parent_ids). The length of returned list is
same as the length of selected_seq_groups. If the corresponding
seq_group has do_sample=False, tuple contains ([], [])
"""
# We sample 2 * beam_width candidates to make sure that with high
# probability we can get `beam_width` candidates in addition to
# the finished sequences for the next iteration. See
......@@ -326,9 +372,14 @@ def _beam_search_sample(
# NOTE: Beam search is not vectorized, so its speed can be slower than
# other sampling methods.
sample_idx = 0
results = []
for seq_group, is_prompt in zip(selected_seq_groups, is_prompts):
seq_ids, sampling_params = seq_group
results: SampleResultType = []
for seq_group in selected_seq_groups:
if not seq_group.do_sample:
results.append(([], []))
continue
is_prompt = seq_group.is_prompt
seq_ids, sampling_params = seq_group.seq_ids, seq_group.sampling_params
num_parent_seqs = len(seq_ids)
beam_width = sampling_params.best_of
seq_group_logprobs = logprobs[sample_idx:sample_idx + num_parent_seqs]
......@@ -342,15 +393,16 @@ def _beam_search_sample(
next_token_ids = next_token_ids.tolist()
else:
# Generation phase.
cumulative_logprobs = [
seq_data[seq_id].cumulative_logprob for seq_id in seq_ids
cumulative_logprobs: List[int] = [
seq_group.seq_data[seq_id].cumulative_logprob
for seq_id in seq_ids
]
cumulative_logprobs = torch.tensor(
cumulative_logprobs_tensor = torch.tensor(
cumulative_logprobs,
dtype=torch.float,
device=seq_group_logprobs.device)
seq_group_logprobs = (seq_group_logprobs +
cumulative_logprobs.unsqueeze(dim=1))
cumulative_logprobs_tensor.unsqueeze(dim=1))
_, topk_ids = torch.topk(seq_group_logprobs.flatten(),
2 * beam_width)
topk_ids = topk_ids.tolist()
......@@ -371,8 +423,7 @@ def _beam_search_sample(
def _multinomial(
probs: torch.Tensor,
num_samples: int,
seq_groups: Optional[List[Tuple[List[int], SamplingParams]]] = None,
generators: Optional[List[torch.Generator]] = None,
seq_groups: Optional[List[SequenceGroupToSample]] = None,
) -> torch.Tensor:
if num_samples > 1:
# This is equivalent to torch.repeat_interleaved (which also
......@@ -388,9 +439,11 @@ def _multinomial(
q.exponential_()
else:
sample_idx = 0
for (seq_ids, _), generator in zip(seq_groups, generators):
for seq_group in seq_groups:
seq_ids = seq_group.seq_ids
next_sample_idx = sample_idx + len(seq_ids) * num_samples
q[sample_idx:next_sample_idx].exponential_(generator=generator)
q[sample_idx:next_sample_idx].exponential_(
generator=seq_group.generator)
sample_idx = next_sample_idx
return probs.div_(q).argmax(dim=1).view(-1, num_samples)
......@@ -401,11 +454,13 @@ def _sample_with_torch(
sampling_metadata: SamplingMetadata,
include_gpu_probs_tensor: bool,
modify_greedy_probs: bool,
) -> Tuple[List[Tuple[List[int], List[int]]], Optional[torch.Tensor]]:
categorized_seq_group_ids = {t: [] for t in SamplingType}
) -> Tuple[SampleResultType, Optional[torch.Tensor]]:
categorized_seq_group_ids: Dict[SamplingType,
List[int]] = {t: []
for t in SamplingType}
categorized_sample_indices = sampling_metadata.categorized_sample_indices
for i, seq_group in enumerate(sampling_metadata.seq_groups):
_, sampling_params = seq_group
sampling_params = seq_group.sampling_params
sampling_type = sampling_params.sampling_type
categorized_seq_group_ids[sampling_type].append(i)
......@@ -429,13 +484,11 @@ def _sample_with_torch(
num_tokens = len(sample_indices)
if num_tokens == 0:
continue
seq_group_ids = categorized_seq_group_ids[sampling_type]
seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_ids]
is_prompts = [i < sampling_metadata.num_prompts for i in seq_group_ids]
sample_metadata[sampling_type] = (seq_group_ids, seq_groups,
is_prompts, sample_indices)
long_sample_indices = sample_indices.long()
seq_group_id = categorized_seq_group_ids[sampling_type]
seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_id]
sample_metadata[sampling_type] = (seq_group_id, seq_groups)
long_sample_indices = sample_indices.long()
if sampling_type == SamplingType.GREEDY:
greedy_samples = torch.argmax(logprobs[long_sample_indices],
dim=-1)
......@@ -455,14 +508,13 @@ def _sample_with_torch(
elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
max_best_of_in_batch = 1
for seq_group, is_prompt in zip(seq_groups, is_prompts):
if is_prompt:
_, sampling_params = seq_group
for seq_group in seq_groups:
if seq_group.is_prompt:
sampling_params = seq_group.sampling_params
max_best_of_in_batch = max(max_best_of_in_batch,
sampling_params.best_of)
seeded_args = {} if sampling_type == SamplingType.RANDOM else {
"seq_groups": seq_groups,
"generators": sampling_metadata.generators,
}
multinomial_samples[sampling_type] = _multinomial(
......@@ -481,25 +533,22 @@ def _sample_with_torch(
# GPU<->CPU sync happens in the loop below.
# This also converts the sample output to Python objects.
for sampling_type in SamplingType:
if sampling_type not in sample_metadata:
continue
seq_group_ids, seq_groups, is_prompts, sample_indices = sample_metadata[
sampling_type]
(seq_group_id, seq_groups) = sample_metadata[sampling_type]
if sampling_type == SamplingType.GREEDY:
sample_results = _greedy_sample(seq_groups, greedy_samples)
elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
sample_results = _random_sample(seq_groups, is_prompts,
sample_results = _random_sample(seq_groups,
multinomial_samples[sampling_type])
elif sampling_type == SamplingType.BEAM:
sample_results = _beam_search_sample(seq_groups, is_prompts,
sampling_metadata.seq_data,
sample_results = _beam_search_sample(seq_groups,
beam_search_logprobs)
sample_results_dict.update(zip(seq_group_ids, sample_results))
sample_results_dict.update(zip(seq_group_id, sample_results))
sample_results = [
sample_results_dict[i]
sample_results_dict.get(i, ([], []))
for i in range(len(sampling_metadata.seq_groups))
]
return sample_results, sampled_token_ids_tensor
......@@ -510,11 +559,13 @@ def _sample_with_triton_kernel(
logprobs: torch.Tensor,
sampling_metadata: SamplingMetadata,
sampling_tensors: SamplingTensors,
) -> List[Tuple[List[int], List[int]]]:
categorized_seq_group_ids = {t: [] for t in SamplingType}
) -> SampleResultType:
categorized_seq_group_ids: Dict[SamplingType,
List[int]] = {t: []
for t in SamplingType}
categorized_sample_indices = sampling_metadata.categorized_sample_indices
for i, seq_group in enumerate(sampling_metadata.seq_groups):
_, sampling_params = seq_group
sampling_params = seq_group.sampling_params
sampling_type = sampling_params.sampling_type
categorized_seq_group_ids[sampling_type].append(i)
......@@ -530,17 +581,16 @@ def _sample_with_triton_kernel(
num_tokens = len(sample_indices)
if num_tokens == 0:
continue
seq_group_ids = categorized_seq_group_ids[sampling_type]
seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_ids]
is_prompts = [i < sampling_metadata.num_prompts for i in seq_group_ids]
sample_metadata[sampling_type] = (seq_group_ids, seq_groups,
is_prompts, sample_indices,
seq_group_id = categorized_seq_group_ids[sampling_type]
seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_id]
sample_metadata[sampling_type] = (seq_group_id, seq_groups,
sample_indices,
sampled_token_indices)
if sampling_type in (SamplingType.GREEDY, SamplingType.RANDOM,
SamplingType.RANDOM_SEED):
for seq_group, is_prompt in zip(seq_groups, is_prompts):
if is_prompt:
_, sampling_params = seq_group
for seq_group in seq_groups:
if seq_group.is_prompt:
sampling_params = seq_group.sampling_params
max_best_of_in_batch = max(max_best_of_in_batch,
sampling_params.best_of)
elif sampling_type == SamplingType.BEAM:
......@@ -564,22 +614,21 @@ def _sample_with_triton_kernel(
for sampling_type in SamplingType:
if sampling_type not in sample_metadata:
continue
(seq_group_ids, seq_groups, is_prompts, sample_indices,
(seq_group_id, seq_groups, sample_indices,
sampled_token_indices) = sample_metadata[sampling_type]
if sampling_type == SamplingType.GREEDY:
sample_results = _greedy_sample(
seq_groups, sampled_tokens[sampled_token_indices][:, 0])
elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
sample_results = _random_sample(
seq_groups, is_prompts, sampled_tokens[sampled_token_indices])
seq_groups, sampled_tokens[sampled_token_indices])
elif sampling_type == SamplingType.BEAM:
sample_results = _beam_search_sample(seq_groups, is_prompts,
sampling_metadata.seq_data,
sample_results = _beam_search_sample(seq_groups,
beam_search_logprobs)
sample_results_dict.update(zip(seq_group_ids, sample_results))
sample_results_dict.update(zip(seq_group_id, sample_results))
sample_results = [
sample_results_dict[i]
sample_results_dict.get(i, ([], []))
for i in range(len(sampling_metadata.seq_groups))
]
return sample_results
......@@ -589,7 +638,19 @@ def _sample(
probs: torch.Tensor, logprobs: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_tensors: SamplingTensors,
include_gpu_probs_tensor: bool, modify_greedy_probs: bool
) -> Tuple[List[Tuple[List[int], List[int]]], Optional[torch.Tensor]]:
) -> Tuple[SampleResultType, Optional[torch.Tensor]]:
"""
Args:
probs: (num_query_tokens_in_batch, num_vocab)
logprobs: (num_query_tokens_in_batch, num_vocab)
sampling_metadata: The metadata for a batch for sampling.
sampling_tensors: Tensors that include sampling related metadata.
Returns:
(next_token_ids, parent_seq_ids) for each seq group in a batch.
If sampling is skipped, it returns ([], [])
sampled_token_ids_tensor: A tensor of sampled token ids.
"""
return _sample_with_torch(
probs,
logprobs,
......@@ -625,57 +686,98 @@ def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
def _get_logprobs(
logprobs: torch.Tensor,
sampling_metadata: SamplingMetadata,
sample_results: List[Tuple[List[int], List[int]]],
) -> Tuple[List[Optional[List[Optional[Dict[int, float]]]]], List[List[Dict[
int, float]]]]:
# Prepare query indices
batched_logprobs_query_seq_indices: List[int] = []
batched_logprobs_query_token_indices: List[int] = []
# at least get one logprob for each token
sample_results: SampleResultType,
) -> Tuple[List[Optional[PromptLogprobs]], List[SampleLogprobs]]:
"""Return sample lobprobs and prompt logprobs.
The logic consists of 3 parts.
- Select indices to compute logprob from, ranks of token ids, and
the top k token ids from logprobs.
- Compute prompt logprobs if required.
- Compute sample logprobs if required.
Args:
logprobs: (num_query_tokens_across_batch, num_vocab). Each query token's
logprob per vocab. Sequence groups' query tokens are batched in a
single flattened tensor. For example, assuming there are N
seq groups, it is sorted by prefill tokens for seq_group_1 (if
prompt logprob is enabled), decode tokens for seq_group_1 (if
sampling is required), prefill tokens for seq_group_2, ...
sampling_metadata: The sampling metadata.
sample_results: (num_seq_groups) The tuple of (next_token_ids,
parent_ids) for each sequence group. When beam search is enabled,
sample_results can contain different number of seq_ids from
sampling_metadata.seq_groups. It is because beam search creates
2 * BEAM_WIDTH number of samples (whereas there are only up to
BEAM_WIDTH number of seq_ids).
Returns:
A tuple of prompt and sample logprobs per sequence group in a batch.
"""
# The index of query token to calculate logprobs. It includes both
# prompt and sample logprob indices.
query_indices: List[int] = []
# The next token ids to get the logprob value from.
next_token_ids: List[int] = []
# The largest requested number of logprobs. We find logprobs as many as the
# largest num logprobs in this API.
largest_num_logprobs = 1
sample_idx = 0
for i, (seq_group, sample_result) in enumerate(
zip(sampling_metadata.seq_groups, sample_results)):
seq_ids, sampling_params = seq_group
next_token_ids, parent_ids = sample_result
num_parent_seqs = len(seq_ids)
if (i < sampling_metadata.num_prompts
# Select indices to compute logprob from, ranks of token ids, and the top
# k token ids from logprobs.
for (seq_group, sample_result) in zip(sampling_metadata.seq_groups,
sample_results):
sampling_params = seq_group.sampling_params
# Update indices and tokens for prompt logprobs.
if (seq_group.is_prompt
and sampling_params.prompt_logprobs is not None):
largest_num_logprobs = max(largest_num_logprobs,
sampling_params.prompt_logprobs)
prompt_len = sampling_metadata.prompt_lens[i]
prompt_tokens = sampling_metadata.seq_data[
seq_ids[0]].prompt_token_ids
batched_logprobs_query_seq_indices.extend(
sample_idx + j for j in range(prompt_len - 1))
batched_logprobs_query_token_indices.extend(
token_id for token_id in prompt_tokens[1:])
sample_idx += prompt_len - 1
batched_logprobs_query_seq_indices.extend(
[sample_idx + parent_id for parent_id in parent_ids])
batched_logprobs_query_token_indices.extend(next_token_ids)
if sampling_params.logprobs is not None:
largest_num_logprobs = max(largest_num_logprobs,
sampling_params.logprobs)
sample_idx += num_parent_seqs
assert sample_idx == logprobs.size(0)
batched_logprobs_query_seq_indices_gpu = torch.tensor(
batched_logprobs_query_seq_indices, device=logprobs.device)
batched_logprobs_query_token_indices_gpu = torch.tensor(
batched_logprobs_query_token_indices, device=logprobs.device)
# Batched query for logprobs of selected token
batched_logprobs_query_result = logprobs[[
batched_logprobs_query_seq_indices_gpu,
batched_logprobs_query_token_indices_gpu
next_prompt_tokens = _get_next_prompt_tokens(seq_group)
query_indices.extend(seq_group.prompt_logprob_indices)
next_token_ids.extend(next_prompt_tokens)
# Update indices and next tokenes for sample logprob.
if seq_group.do_sample:
token_ids, parent_seq_ids = sample_result
# NOTE: We cannot directly use sample_indices because
# sample_indices only contain parent seq_ids of a previous step.
# The current step may have different number of seq_ids, and
# we can obtain it from `sample_result[1]`.
query_idx = seq_group.sample_indices[0]
query_indices.extend(
[query_idx + parent_id for parent_id in parent_seq_ids])
next_token_ids.extend(token_ids)
if sampling_params.logprobs is not None:
largest_num_logprobs = max(largest_num_logprobs,
sampling_params.logprobs)
assert len(next_token_ids) == len(query_indices)
if len(query_indices) == 0:
empty_sampled_logprob: SampleLogprobs = []
empty_prompt_logprob: Optional[PromptLogprobs] = None
return [empty_prompt_logprob], [empty_sampled_logprob]
query_indices_gpu = torch.tensor(query_indices, device=logprobs.device)
next_token_ids_gpu = torch.tensor(next_token_ids, device=logprobs.device)
# (num_selected_query_tokens, num_logprobs). Note that query_indices can
# contain duplicates if beam search is enabled.
selected_logprobs = logprobs[[
query_indices_gpu,
next_token_ids_gpu,
]]
ranks = _get_ranks(
logprobs[query_indices_gpu],
next_token_ids_gpu,
)
assert selected_logprobs.shape[0] == ranks.shape[0]
batched_ranks_query_result = _get_ranks(
logprobs[batched_logprobs_query_seq_indices_gpu],
batched_logprobs_query_token_indices_gpu)
# Batched query for logprobs of topk tokens
# Logprobs of topk tokens for a batch of sequence groups.
# (num_query_tokens_across_batch).
if largest_num_logprobs > 0:
top_logprobs, top_token_ids = torch.topk(logprobs,
largest_num_logprobs,
......@@ -685,79 +787,136 @@ def _get_logprobs(
else:
top_logprobs, top_token_ids = None, None
batched_logprobs_query_result = batched_logprobs_query_result.cpu()
batched_ranks_query_result = batched_ranks_query_result.cpu()
# Gather results
result_prompt_logprobs: List[Optional[PromptLogprobs]] = []
result_sample_logprobs: List[SampleLogprobs] = []
sample_idx = 0
query_result_idx = 0
for i, (seq_group, sample_result) in enumerate(
zip(sampling_metadata.seq_groups, sample_results)):
seq_ids, sampling_params = seq_group
next_token_ids, parent_ids = sample_result
selected_logprobs = selected_logprobs.cpu()
ranks = ranks.cpu()
# Find prompt/sample logprobs.
prompt_logprobs_per_seq_group: List[Optional[PromptLogprobs]] = []
sample_logprobs_per_seq_group: List[SampleLogprobs] = []
top_logprob_idx = 0
selected_logprobs_idx = 0
for seq_group, sample_result in zip(sampling_metadata.seq_groups,
sample_results):
(prompt_logprobs, top_logprob_idx,
selected_logprobs_idx) = _get_prompt_logprob_if_needed(
seq_group, selected_logprobs, ranks, top_token_ids, top_logprobs,
selected_logprobs_idx, top_logprob_idx)
prompt_logprobs_per_seq_group.append(prompt_logprobs)
(sampled_logprobs, top_logprob_idx,
selected_logprobs_idx) = _get_sampled_logprob_if_needed(
seq_group, sample_result, selected_logprobs, ranks, top_token_ids,
top_logprobs, selected_logprobs_idx, top_logprob_idx)
sample_logprobs_per_seq_group.append(sampled_logprobs)
return prompt_logprobs_per_seq_group, sample_logprobs_per_seq_group
def _get_prompt_logprob_if_needed(
seq_group: SequenceGroupToSample,
selected_logprobs: torch.Tensor,
ranks: torch.Tensor,
top_token_ids: torch.Tensor,
top_logprobs: torch.Tensor,
selected_logprobs_idx: int,
top_logprob_idx: int,
):
"""Compute the prompt logprob from a sequence group if needed."""
sampling_params = seq_group.sampling_params
is_prompt = seq_group.is_prompt
# Find prompt logprobs
prompt_logprobs: Optional[PromptLogprobs] = None
if (is_prompt and sampling_params.prompt_logprobs is not None):
prompt_logprobs = []
num_logprobs = sampling_params.prompt_logprobs
next_prompt_tokens = _get_next_prompt_tokens(seq_group)
for token_id in next_prompt_tokens:
# Calculate the prompt logprob of the real prompt tokens.
# Use tuple here for performance (to use to_list()).
# {token_id: (logprob, rank_from_vocab)}
prompt_logprobs_dict: Dict[int, Tuple[float, int]] = {
token_id: (selected_logprobs[selected_logprobs_idx].item(),
ranks[selected_logprobs_idx].item())
}
# Prompt logprobs
if (i < sampling_metadata.num_prompts
and sampling_params.prompt_logprobs is not None):
num_logprobs = sampling_params.prompt_logprobs
prompt_tokens = sampling_metadata.seq_data[
seq_ids[0]].prompt_token_ids
group_prompt_logprobs: PromptLogprobs = [None]
for token_id in prompt_tokens[1:]:
prompt_logprobs_dict = {
token_id:
(batched_logprobs_query_result[query_result_idx].item(),
batched_ranks_query_result[query_result_idx].item())
}
if num_logprobs > 0:
prompt_logprobs_dict.update(
# Add top K prompt logprobs along with its rank.
if num_logprobs > 0:
prompt_logprobs_dict.update(
zip(
top_token_ids[top_logprob_idx, :num_logprobs].tolist(),
zip(
top_token_ids[sample_idx, :num_logprobs].tolist(),
zip(
top_logprobs[
sample_idx, :num_logprobs].tolist(),
range(1, num_logprobs + 1))))
group_prompt_logprobs.append({
token_id: Logprob(*logprob_rank)
for token_id, logprob_rank in prompt_logprobs_dict.items()
})
sample_idx += 1
query_result_idx += 1
result_prompt_logprobs.append(group_prompt_logprobs)
else:
result_prompt_logprobs.append(None)
# Sample logprobs
num_logprobs = sampling_params.logprobs
if num_logprobs is None:
num_logprobs = 0
group_sample_logprobs: SampleLogprobs = []
for next_token_id, parent_id in zip(next_token_ids, parent_ids):
sample_logprobs_dict = {
top_logprobs[
top_logprob_idx, :num_logprobs].tolist(),
# This is ranks. Since top_logprob is sorted,
# we can just use a range here.
range(1, num_logprobs + 1))))
prompt_logprobs.append({
token_id: Logprob(*logprob_and_rank)
for token_id, logprob_and_rank in prompt_logprobs_dict.items()
})
# + 1 to go to the next prompt token.
top_logprob_idx += 1
selected_logprobs_idx += 1
return prompt_logprobs, top_logprob_idx, selected_logprobs_idx
def _get_sampled_logprob_if_needed(
seq_group: SequenceGroupToSample,
sample_result: Tuple[List[int], List[int]],
selected_logprobs: torch.Tensor,
ranks: torch.Tensor,
top_token_ids: torch.Tensor,
top_logprobs: torch.Tensor,
selected_logprobs_idx: int,
top_logprob_idx: int,
):
"""Compute the sample logprob if needed."""
seq_ids = seq_group.seq_ids
num_logprobs = seq_group.sampling_params.logprobs
if num_logprobs is None:
num_logprobs = 0
sampled_logprobs: SampleLogprobs = []
next_token_ids, parent_seq_ids = sample_result
if seq_group.do_sample:
assert len(next_token_ids) > 0
for (next_token_id, parent_id) in zip(next_token_ids, parent_seq_ids):
# Calculate the sample logprob of the real sampled tokens.
# Use tuple here for performance (to use to_list()).
# token_id: (logprob, rank_from_vocab)
sampled_logprobs_dict: Dict[int, Tuple[float, int]] = {
next_token_id:
(batched_logprobs_query_result[query_result_idx].item(),
batched_ranks_query_result[query_result_idx].item())
(selected_logprobs[selected_logprobs_idx].item(),
ranks[selected_logprobs_idx].item())
}
query_result_idx += 1
# +1 to go to the next sampled token. Note that
# selected_logprobs can contain duplicates unlike top_logprobs
# when beam search is enabled.
selected_logprobs_idx += 1
# Second, add top K logprobs along with its rank.
if num_logprobs >= 0:
sample_logprobs_dict.update(
sampled_logprobs_dict.update(
zip(
top_token_ids[sample_idx +
top_token_ids[top_logprob_idx +
parent_id, :num_logprobs].tolist(),
zip(
top_logprobs[sample_idx +
top_logprobs[top_logprob_idx +
parent_id, :num_logprobs].tolist(),
# This is rank. Since top_logprob is sorted, we
# can just use a range here.
range(1, num_logprobs + 1))))
group_sample_logprobs.append({
token_id: Logprob(*logprob_rank)
for token_id, logprob_rank in sample_logprobs_dict.items()
sampled_logprobs.append({
token_id: Logprob(*logprob_and_rank)
for token_id, logprob_and_rank in
sampled_logprobs_dict.items()
})
result_sample_logprobs.append(group_sample_logprobs)
sample_idx += len(seq_ids)
return result_prompt_logprobs, result_sample_logprobs
# There are len(seq_ids) number of sampled tokens for the current
# sequence group in top_logprobs. Jump to the next seq_group.
top_logprob_idx += len(seq_ids)
return sampled_logprobs, top_logprob_idx, selected_logprobs_idx
def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor,
......@@ -805,18 +964,18 @@ def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor,
has implications on the overall design of the sampler, e.g. how to record
accurate logprobs for the user, so this improvement is deferred to later.
"""
logprobs[sample_indices, :] = -float('inf')
logprobs[sample_indices, greedy_samples] = 0.0
# NOTE: logprobs are not modified so they can be returned to the user.
probs[sample_indices, :] = 0
probs[sample_indices, greedy_samples] = 1.0
def _build_sampler_output(
sample_results: List[Tuple[List[int], List[int]]],
sample_results: SampleResultType,
sampling_metadata: SamplingMetadata,
prompt_logprobs: List[Optional[PromptLogprobs]],
sample_logprobs: List[SampleLogprobs],
on_device_tensors: Optional[Tuple[torch.Tensor, torch.Tensor]],
on_device_tensors: Optional[Tuple[torch.Tensor, torch.Tensor,
torch.Tensor]],
) -> SamplerOutput:
"""Construct Python objects with the output of sampling.
......@@ -832,7 +991,7 @@ def _build_sampler_output(
group_sample_logprobs) in zip(sampling_metadata.seq_groups,
sample_results, prompt_logprobs,
sample_logprobs):
seq_ids, _ = seq_group
seq_ids = seq_group.seq_ids
next_token_ids, parent_ids = sample_result
seq_outputs = []
for parent_id, next_token_id, logprobs in zip(parent_ids,
......@@ -845,12 +1004,48 @@ def _build_sampler_output(
# If not specified, store None values in SamplerOutput.
if on_device_tensors is not None:
sampled_token_probs, sampled_token_ids = on_device_tensors
(sampled_token_probs, logprobs_tensor,
sampled_token_ids) = on_device_tensors
else:
sampled_token_probs, sampled_token_ids = (None, None)
sampled_token_probs, logprobs_tensor, sampled_token_ids = (None, None,
None)
return SamplerOutput(
outputs=sampler_output,
sampled_token_probs=sampled_token_probs,
sampled_token_ids=sampled_token_ids,
logprobs=logprobs_tensor,
)
def _get_next_prompt_tokens(seq_group: SequenceGroupToSample) -> List[int]:
"""Get a list of next prompt tokens to compute logprob from a
given sequence group.
It is used to compute prompt logprob. Imagine you have logprob for each
query token. Query token needs to know the next prompt token id to compute
prompt logprob. This is a helper to obtain next prompt token ids.
This API has to be used only when the caller knows seq_group is in prefill
stage.
Returns:
A list of next prompt tokens to compute logprob.
"""
assert seq_group.is_prompt, (
"Caller should ensure the sequence group is in a prefill stage.")
seq_ids = seq_group.seq_ids
query_len = seq_group.query_len
assert query_len is not None
# prompt has only 1 seq id.
assert len(seq_ids) == 1
seq_data = seq_group.seq_data[seq_ids[0]]
computed_len = seq_data.get_num_computed_tokens()
prompt_tokens = seq_data.prompt_token_ids
# +1 because we are looking for a next prompt token.
next_token_index_start = computed_len + 1
next_token_index_end = min(computed_len + query_len + 1,
len(prompt_tokens))
next_prompt_tokens = prompt_tokens[
next_token_index_start:next_token_index_end]
return next_prompt_tokens
......@@ -105,6 +105,14 @@ class VocabParallelEmbedding(torch.nn.Module):
output = tensor_model_parallel_all_reduce(output_parallel)
return output
def extra_repr(self) -> str:
s = f"num_embeddings={self.num_embeddings_per_partition}"
s += f", embedding_dim={self.embedding_dim}"
s += f", org_vocab_size={self.org_vocab_size}"
s += f', num_embeddings_padded={self.num_embeddings_padded}'
s += f', tp_size={self.tp_size}'
return s
class ParallelLMHead(VocabParallelEmbedding):
"""Parallelized LM head.
......
......@@ -3,16 +3,19 @@ import copy
import glob
import os
from abc import ABC, abstractmethod
from typing import (TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple,
Type)
from typing import Any, Dict, Generator, List, Optional, Tuple, Type
import huggingface_hub
import torch
from torch import nn
from vllm.config import (VLLM_USE_MODELSCOPE, DeviceConfig, LoadConfig,
LoadFormat, LoRAConfig, ModelConfig, ParallelConfig,
SchedulerConfig, VisionLanguageConfig)
from vllm.config import (DeviceConfig, LoadConfig, LoadFormat, LoRAConfig,
ModelConfig, ParallelConfig, SchedulerConfig,
VisionLanguageConfig)
from vllm.envs import VLLM_USE_MODELSCOPE
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.model_loader.tensorizer import (
TensorizerConfig, is_vllm_serialized_tensorizer, load_with_tensorizer,
tensorizer_weights_iterator)
......@@ -24,9 +27,6 @@ from vllm.model_executor.model_loader.weight_utils import (
pt_weights_iterator, safetensors_weights_iterator)
from vllm.model_executor.models.llava import LlavaForConditionalGeneration
if TYPE_CHECKING:
from vllm.model_executor.layers.linear import LinearMethodBase
_VISION_MODEL_CLASSES = [
LlavaForConditionalGeneration,
]
......@@ -34,11 +34,10 @@ _VISION_MODEL_CLASSES = [
logger = init_logger(__name__)
def _get_linear_method(
def _get_quantization_config(
model_config: ModelConfig,
load_config: LoadConfig) -> Optional["LinearMethodBase"]:
"""Get the (maybe quantized) linear method."""
linear_method = None
load_config: LoadConfig) -> Optional[QuantizationConfig]:
"""Get the quantization config."""
if model_config.quantization is not None:
quant_config = get_quant_config(model_config, load_config)
capability = torch.cuda.get_device_capability()
......@@ -55,6 +54,7 @@ def _get_linear_method(
f"{model_config.dtype} is not supported for quantization "
f"method {model_config.quantization}. Supported dtypes: "
f"{supported_dtypes}")
<<<<<<< HEAD
linear_method = quant_config.get_linear_method()
......@@ -62,6 +62,10 @@ def _get_linear_method(
os.environ['LLAMA_NN'] = '0'
return linear_method
=======
return quant_config
return None
>>>>>>> v0.4.2
def _get_model_initialization_kwargs(
......@@ -89,10 +93,10 @@ def _initialize_model(
vision_language_config: Optional[VisionLanguageConfig]) -> nn.Module:
"""Initialize a model with the given configurations."""
model_class = get_model_architecture(model_config)[0]
linear_method = _get_linear_method(model_config, load_config)
quant_config = _get_quantization_config(model_config, load_config)
return model_class(config=model_config.hf_config,
linear_method=linear_method,
quant_config=quant_config,
**_get_model_initialization_kwargs(
model_class, lora_config, vision_language_config))
......@@ -139,7 +143,9 @@ class DefaultModelLoader(BaseModelLoader):
model_path = snapshot_download(
model_id=model,
cache_dir=self.load_config.download_dir,
revision=revision)
local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
revision=revision,
)
else:
model_path = model
return model_path
......@@ -233,9 +239,11 @@ class DefaultModelLoader(BaseModelLoader):
"fall_back_to_pt_during_load",
True)), )
for _, module in model.named_modules():
linear_method = getattr(module, "linear_method", None)
if linear_method is not None:
linear_method.process_weights_after_loading(module)
quant_method = getattr(module, "quant_method", None)
if quant_method is not None:
quant_method.process_weights_after_loading(module)
# FIXME: Remove this after Mixtral is updated
# to use quant_method.
if hasattr(module, "process_weights_after_loading"):
module.process_weights_after_loading()
return model.eval()
......@@ -318,11 +326,11 @@ class TensorizerLoader(BaseModelLoader):
with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device):
model_class = get_model_architecture(model_config)[0]
linear_method = _get_linear_method(model_config,
self.load_config)
quant_config = _get_quantization_config(
model_config, self.load_config)
extra_kwargs = _get_model_initialization_kwargs(
model_class, lora_config, vision_language_config)
extra_kwargs["linear_method"] = linear_method
extra_kwargs["quant_config"] = quant_config
tensorizer_config = copy.copy(self.tensorizer_config)
tensorizer_config.model_class = model_class
......
......@@ -11,9 +11,11 @@ import torch
from torch import nn
from transformers import PretrainedConfig
import vllm.envs as envs
from vllm.config import ModelConfig, ParallelConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
......@@ -43,7 +45,7 @@ class TensorizerConfig:
str, bytes, os.PathLike, int]
vllm_tensorized: bool
verify_hash: Optional[bool] = False
num_readers: Optional[int] = 1
num_readers: Optional[int] = None
encryption_keyfile: Optional[str] = None
s3_access_key_id: Optional[str] = None
s3_secret_access_key: Optional[str] = None
......@@ -63,7 +65,7 @@ class TensorizerConfig:
"s3_secret_access_key": self.s3_secret_access_key,
"s3_endpoint": self.s3_endpoint,
}
return TensorizerArgs(**tensorizer_args)
return TensorizerArgs(**tensorizer_args) # type: ignore
def verify_with_parallel_config(
self,
......@@ -103,7 +105,7 @@ class TensorizerArgs:
str, bytes, os.PathLike, int]
vllm_tensorized: bool
verify_hash: Optional[bool] = False
num_readers: Optional[int] = 1
num_readers: Optional[int] = None
encryption_keyfile: Optional[str] = None
s3_access_key_id: Optional[str] = None
s3_secret_access_key: Optional[str] = None
......@@ -124,8 +126,9 @@ class TensorizerArgs:
the hashes stored in the metadata. A `HashMismatchError` will be
raised if any of the hashes do not match.
num_readers: Controls how many threads are allowed to read concurrently
from the source file. Default is 1. This greatly increases
performance.
from the source file. Default is `None`, which will dynamically set
the number of readers based on the number of available
resources and model size. This greatly increases performance.
encryption_keyfile: File path to a binary file containing a
binary key to use for decryption. `None` (the default) means
no decryption. See the example script in
......@@ -140,13 +143,10 @@ class TensorizerArgs:
def __post_init__(self):
self.file_obj = self.tensorizer_uri
self.s3_access_key_id = (self.s3_access_key_id
or os.environ.get("S3_ACCESS_KEY_ID")) or None
self.s3_secret_access_key = (
self.s3_secret_access_key
or os.environ.get("S3_SECRET_ACCESS_KEY")) or None
self.s3_endpoint = (self.s3_endpoint
or os.environ.get("S3_ENDPOINT_URL")) or None
self.s3_access_key_id = self.s3_access_key_id or envs.S3_ACCESS_KEY_ID
self.s3_secret_access_key = (self.s3_secret_access_key
or envs.S3_SECRET_ACCESS_KEY)
self.s3_endpoint = self.s3_endpoint or envs.S3_ENDPOINT_URL
self.stream_params = {
"s3_access_key_id": self.s3_access_key_id,
"s3_secret_access_key": self.s3_secret_access_key,
......@@ -198,10 +198,12 @@ class TensorizerArgs:
"use for decryption. Can be a file path or S3 network URI.")
group.add_argument(
"--num-readers",
default=1,
default=None,
type=int,
help="Controls how many threads are allowed to read concurrently "
"from the source file.")
"from the source file. Default is `None`, which will dynamically "
"set the number of readers based on the available resources "
"and model size. This greatly increases performance.")
group.add_argument(
"--s3-access-key-id",
default=None,
......@@ -251,7 +253,7 @@ class TensorizerAgent:
"""
def __init__(self, tensorizer_config: TensorizerConfig,
linear_method: LinearMethodBase, **extra_kwargs):
quant_config: QuantizationConfig, **extra_kwargs):
if tensorizer_load_fail is not None:
raise ImportError(
"Tensorizer is not installed. Please install tensorizer "
......@@ -262,19 +264,21 @@ class TensorizerAgent:
self.tensorizer_args = (
self.tensorizer_config._construct_tensorizer_args())
self.extra_kwargs = extra_kwargs
if extra_kwargs.get("linear_method", None) is not None:
self.linear_method = extra_kwargs["linear_method"]
if extra_kwargs.get("quant_config", None) is not None:
self.quant_config = extra_kwargs["quant_config"]
else:
self.linear_method = linear_method
self.quant_config = quant_config
self.model = self._init_model()
def _init_model(self):
assert self.tensorizer_config.hf_config is not None
model_args = self.tensorizer_config.hf_config
model_args.torch_dtype = self.tensorizer_config.dtype
assert self.tensorizer_config.model_class is not None
with no_init_or_tensor():
return self.tensorizer_config.model_class(
config=model_args,
linear_method=self.linear_method,
quant_config=self.quant_config,
**self.extra_kwargs)
def _resize_lora_embeddings(self):
......@@ -334,10 +338,10 @@ class TensorizerAgent:
per_second = convert_bytes(deserializer.total_tensor_bytes / duration)
after_mem = get_mem_usage()
deserializer.close()
logger.info(f"Deserialized {total_bytes_str} in "
f"{end - start:0.2f}s, {per_second}/s")
logger.info(f"Memory usage before: {before_mem}")
logger.info(f"Memory usage after: {after_mem}")
logger.info("Deserialized %s in %0.2fs, %s/s", total_bytes_str,
end - start, per_second)
logger.info("Memory usage before: %s", before_mem)
logger.info("Memory usage after: %s", after_mem)
self._check_tensors_on_meta_device()
self._resize_lora_embeddings()
......
......@@ -127,11 +127,14 @@ def get_quant_config(model_config: ModelConfig,
if not is_local:
# Download the config files.
with get_lock(model_name_or_path, load_config.download_dir):
hf_folder = snapshot_download(model_name_or_path,
revision=model_config.revision,
allow_patterns="*.json",
cache_dir=load_config.download_dir,
tqdm_class=DisabledTqdm)
hf_folder = snapshot_download(
model_name_or_path,
revision=model_config.revision,
allow_patterns="*.json",
cache_dir=load_config.download_dir,
local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
tqdm_class=DisabledTqdm,
)
else:
hf_folder = model_name_or_path
......@@ -161,12 +164,14 @@ def get_quant_config(model_config: ModelConfig,
return quant_cls.from_config(config)
def download_weights_from_hf(model_name_or_path: str,
cache_dir: Optional[str],
allow_patterns: List[str],
revision: Optional[str] = None) -> str:
def download_weights_from_hf(
model_name_or_path: str,
cache_dir: Optional[str],
allow_patterns: List[str],
revision: Optional[str] = None,
) -> str:
"""Download model weights from Hugging Face Hub.
Args:
model_name_or_path (str): The model name or path.
cache_dir (Optional[str]): The cache directory to store the model
......@@ -179,26 +184,30 @@ def download_weights_from_hf(model_name_or_path: str,
Returns:
str: The path to the downloaded model weights.
"""
# Before we download we look at that is available:
fs = HfFileSystem()
file_list = fs.ls(model_name_or_path, detail=False, revision=revision)
# depending on what is available we download different things
for pattern in allow_patterns:
matching = fnmatch.filter(file_list, pattern)
if len(matching) > 0:
allow_patterns = [pattern]
break
logger.info(f"Using model weights format {allow_patterns}")
if not huggingface_hub.constants.HF_HUB_OFFLINE:
# Before we download we look at that is available:
fs = HfFileSystem()
file_list = fs.ls(model_name_or_path, detail=False, revision=revision)
# depending on what is available we download different things
for pattern in allow_patterns:
matching = fnmatch.filter(file_list, pattern)
if len(matching) > 0:
allow_patterns = [pattern]
break
logger.info("Using model weights format %s", allow_patterns)
# Use file lock to prevent multiple processes from
# downloading the same model weights at the same time.
with get_lock(model_name_or_path, cache_dir):
hf_folder = snapshot_download(model_name_or_path,
allow_patterns=allow_patterns,
cache_dir=cache_dir,
tqdm_class=DisabledTqdm,
revision=revision)
hf_folder = snapshot_download(
model_name_or_path,
allow_patterns=allow_patterns,
cache_dir=cache_dir,
tqdm_class=DisabledTqdm,
revision=revision,
local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
)
return hf_folder
......@@ -310,17 +319,17 @@ def kv_cache_scales_loader(
return layer_scales_map.items()
except FileNotFoundError:
logger.error(f"File or directory '{filename}' not found.")
logger.error("File or directory '%s' not found.", filename)
except json.JSONDecodeError:
logger.error(f"Error decoding JSON in file '{filename}'.")
logger.error("Error decoding JSON in file '%s'.", filename)
except Exception as e:
logger.error(f"An error occurred while reading '{filename}': {e}")
logger.error("An error occurred while reading '%s': %s", filename, e)
# This section is reached if and only if any of the excepts are hit
# Return an empty iterable (list) => no KV cache scales are loaded
# which ultimately defaults to 1.0 scales
logger.warning("Defaulting to KV cache scaling factors = 1.0 "
f"for all layers in TP rank {tp_rank} "
"as an error occurred during loading.")
logger.warning(
"Defaulting to KV cache scaling factors = 1.0 for all "
"layers in TP rank %d as an error occurred during loading.", tp_rank)
return []
......
......@@ -42,10 +42,11 @@ _MODELS = {
"MptForCausalLM": ("mpt", "MPTForCausalLM"),
"MPTForCausalLM": ("mpt", "MPTForCausalLM"),
"MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
"OLMoForCausalLM": ("olmo", "OLMoForCausalLM"),
"OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
"OPTForCausalLM": ("opt", "OPTForCausalLM"),
"OrionForCausalLM": ("orion", "OrionForCausalLM"),
"PhiForCausalLM": ("phi", "PhiForCausalLM"),
"Phi3ForCausalLM": ("llama", "LlamaForCausalLM"),
"QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
"Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
"Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
......@@ -90,8 +91,8 @@ class ModelRegistry:
"ROCm for now.")
if model_arch in _ROCM_PARTIALLY_SUPPORTED_MODELS:
logger.warning(
f"Model architecture {model_arch} is partially supported "
"by ROCm: " + _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch])
"Model architecture %s is partially supported by ROCm: %s",
model_arch, _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch])
module_name, model_cls_name = _MODELS[model_arch]
module = importlib.import_module(
......@@ -106,9 +107,9 @@ class ModelRegistry:
def register_model(model_arch: str, model_cls: Type[nn.Module]):
if model_arch in _MODELS:
logger.warning(
f"Model architecture {model_arch} is already registered, "
"and will be overwritten by the new model "
f"class {model_cls.__name__}.")
"Model architecture %s is already registered, and will be "
"overwritten by the new model class %s.", model_arch,
model_cls.__name__)
global _OOT_MODELS
_OOT_MODELS[model_arch] = model_cls
......
......@@ -31,11 +31,12 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
MergedColumnParallelLinear,
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
......@@ -77,17 +78,17 @@ class BaiChuanMLP(nn.Module):
hidden_size: int,
intermediate_size: int,
hidden_act: str,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2,
bias=False,
linear_method=linear_method)
quant_config=quant_config)
self.down_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=False,
linear_method=linear_method)
quant_config=quant_config)
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
......@@ -110,7 +111,7 @@ class BaiChuanAttention(nn.Module):
position_embedding: str,
rope_theta: float = 10000,
max_position_embeddings: int = 8192,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.hidden_size = hidden_size
......@@ -132,13 +133,13 @@ class BaiChuanAttention(nn.Module):
self.total_num_heads,
self.total_num_heads,
bias=False,
linear_method=linear_method,
quant_config=quant_config,
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
linear_method=linear_method,
quant_config=quant_config,
)
# Create the alibi slopes and slice them.
if self.postion_embedding == "ALIBI":
......@@ -184,7 +185,7 @@ class BaiChuanDecoderLayer(nn.Module):
def __init__(self,
config: PretrainedConfig,
position_embedding: str,
linear_method: Optional[LinearMethodBase] = None):
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
......@@ -196,13 +197,13 @@ class BaiChuanDecoderLayer(nn.Module):
position_embedding=position_embedding,
rope_theta=rope_theta,
max_position_embeddings=max_position_embeddings,
linear_method=linear_method,
quant_config=quant_config,
)
self.mlp = BaiChuanMLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
linear_method=linear_method,
quant_config=quant_config,
)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
......@@ -243,7 +244,7 @@ class BaiChuanModel(nn.Module):
def __init__(self,
config: PretrainedConfig,
position_embedding: str,
linear_method: Optional[LinearMethodBase] = None):
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.config = config
self.padding_idx = config.pad_token_id
......@@ -254,7 +255,7 @@ class BaiChuanModel(nn.Module):
config.hidden_size,
)
self.layers = nn.ModuleList([
BaiChuanDecoderLayer(config, position_embedding, linear_method)
BaiChuanDecoderLayer(config, position_embedding, quant_config)
for _ in range(config.num_hidden_layers)
])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
......@@ -303,13 +304,13 @@ class BaiChuanBaseForCausalLM(nn.Module):
self,
config,
position_embedding: str,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
):
super().__init__()
self.config = config
self.linear_method = linear_method
self.model = BaiChuanModel(config, position_embedding, linear_method)
self.quant_config = quant_config
self.model = BaiChuanModel(config, position_embedding, quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
......@@ -388,13 +389,13 @@ class BaichuanForCausalLM(BaiChuanBaseForCausalLM):
def __init__(
self,
config,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
):
if config.hidden_size == 4096: # baichuan2 7b
super().__init__(config, "ROPE", linear_method, lora_config)
super().__init__(config, "ROPE", quant_config, lora_config)
else: # baichuan 13b, baichuan2 13b
super().__init__(config, "ALIBI", linear_method, lora_config)
super().__init__(config, "ALIBI", quant_config, lora_config)
class BaiChuanForCausalLM(BaiChuanBaseForCausalLM):
......@@ -403,7 +404,7 @@ class BaiChuanForCausalLM(BaiChuanBaseForCausalLM):
def __init__(
self,
config,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
):
super().__init__(config, "ROPE", linear_method, lora_config)
super().__init__(config, "ROPE", quant_config, lora_config)
......@@ -28,10 +28,11 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
......@@ -70,7 +71,7 @@ class BloomAttention(nn.Module):
def __init__(
self,
config: BloomConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.hidden_size = config.hidden_size
......@@ -87,13 +88,13 @@ class BloomAttention(nn.Module):
self.head_dim,
self.total_num_heads,
bias=True,
linear_method=linear_method,
quant_config=quant_config,
)
self.dense = RowParallelLinear(
self.hidden_size,
self.hidden_size,
bias=True,
linear_method=linear_method,
quant_config=quant_config,
)
# Create the alibi slopes and slice them.
......@@ -129,21 +130,20 @@ class BloomMLP(nn.Module):
def __init__(
self,
config: BloomConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
hidden_size = config.hidden_size
self.dense_h_to_4h = ColumnParallelLinear(
hidden_size,
4 * hidden_size,
linear_method=linear_method,
quant_config=quant_config,
)
quant_config = getattr(linear_method, "quant_config", None)
self.gelu_impl = get_act_fn("gelu", quant_config, 4 * hidden_size)
self.dense_4h_to_h = RowParallelLinear(
4 * hidden_size,
hidden_size,
linear_method=linear_method,
quant_config=quant_config,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
......@@ -158,17 +158,17 @@ class BloomBlock(nn.Module):
def __init__(
self,
config: BloomConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
hidden_size = config.hidden_size
self.input_layernorm = nn.LayerNorm(hidden_size,
eps=config.layer_norm_epsilon)
self.self_attention = BloomAttention(config, linear_method)
self.self_attention = BloomAttention(config, quant_config)
self.post_attention_layernorm = nn.LayerNorm(
hidden_size, eps=config.layer_norm_epsilon)
self.mlp = BloomMLP(config, linear_method)
self.mlp = BloomMLP(config, quant_config)
self.apply_residual_connection_post_layernorm = (
config.apply_residual_connection_post_layernorm)
......@@ -214,7 +214,7 @@ class BloomModel(nn.Module):
def __init__(
self,
config: BloomConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.embed_dim = config.hidden_size
......@@ -229,7 +229,7 @@ class BloomModel(nn.Module):
# Transformer blocks
self.h = nn.ModuleList([
BloomBlock(config, linear_method)
BloomBlock(config, quant_config)
for _ in range(config.num_hidden_layers)
])
......@@ -262,12 +262,12 @@ class BloomForCausalLM(nn.Module):
def __init__(
self,
config: BloomConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
self.linear_method = linear_method
self.transformer = BloomModel(config, linear_method)
self.quant_config = quant_config
self.transformer = BloomModel(config, quant_config)
self.lm_head_weight = self.transformer.word_embeddings.weight
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
......
......@@ -13,11 +13,12 @@ from vllm.config import LoRAConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
MergedColumnParallelLinear,
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
......@@ -33,7 +34,7 @@ class GLMAttention(nn.Module):
def __init__(
self,
config,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.hidden_size = config.hidden_size
......@@ -65,13 +66,13 @@ class GLMAttention(nn.Module):
self.total_num_heads,
self.total_num_kv_heads,
bias=config.add_bias_linear or config.add_qkv_bias,
linear_method=linear_method,
quant_config=quant_config,
)
self.dense = RowParallelLinear(
self.total_num_heads * self.head_dim,
config.hidden_size,
bias=config.add_bias_linear,
linear_method=linear_method,
quant_config=quant_config,
)
# https://huggingface.co/THUDM/chatglm3-6b-32k/blob/e210410255278dd9d74463cf396ba559c0ef801c/modeling_chatglm.py#L141
......@@ -123,7 +124,7 @@ class GLMMLP(nn.Module):
def __init__(
self,
config,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
......@@ -134,7 +135,7 @@ class GLMMLP(nn.Module):
config.hidden_size,
[config.ffn_hidden_size] * 2,
bias=config.add_bias_linear,
linear_method=linear_method,
quant_config=quant_config,
)
self.activation_func = SiluAndMul()
......@@ -144,7 +145,7 @@ class GLMMLP(nn.Module):
config.ffn_hidden_size,
config.hidden_size,
bias=config.add_bias_linear,
linear_method=linear_method,
quant_config=quant_config,
)
def forward(self, hidden_states):
......@@ -166,7 +167,7 @@ class GLMBlock(nn.Module):
def __init__(
self,
config,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.apply_residual_connection_post_layernorm = (
......@@ -180,7 +181,7 @@ class GLMBlock(nn.Module):
eps=config.layernorm_epsilon)
# Self attention.
self.self_attention = GLMAttention(config, linear_method)
self.self_attention = GLMAttention(config, quant_config)
self.hidden_dropout = config.hidden_dropout
# Layernorm on the attention output
......@@ -188,7 +189,7 @@ class GLMBlock(nn.Module):
config.hidden_size, eps=config.layernorm_epsilon)
# MLP
self.mlp = GLMMLP(config, linear_method)
self.mlp = GLMMLP(config, quant_config)
def forward(
self,
......@@ -236,7 +237,7 @@ class GLMTransformer(nn.Module):
def __init__(
self,
config,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.post_layer_norm = config.post_layer_norm
......@@ -246,7 +247,7 @@ class GLMTransformer(nn.Module):
# Transformer layers.
self.layers = nn.ModuleList(
[GLMBlock(config, linear_method) for i in range(self.num_layers)])
[GLMBlock(config, quant_config) for i in range(self.num_layers)])
if self.post_layer_norm:
layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
......@@ -281,7 +282,7 @@ class ChatGLMModel(nn.Module):
def __init__(
self,
config,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
......@@ -291,7 +292,7 @@ class ChatGLMModel(nn.Module):
self.num_layers = config.num_layers
self.multi_query_group_num = config.multi_query_group_num
self.kv_channels = config.kv_channels
self.encoder = GLMTransformer(config, linear_method)
self.encoder = GLMTransformer(config, quant_config)
self.output_layer = ParallelLMHead(config.padded_vocab_size,
config.hidden_size)
......@@ -333,13 +334,13 @@ class ChatGLMForCausalLM(nn.Module):
def __init__(
self,
config: ChatGLMConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
):
super().__init__()
self.config: ChatGLMConfig = config
self.linear_method = linear_method
self.transformer = ChatGLMModel(config, linear_method)
self.quant_config = quant_config
self.transformer = ChatGLMModel(config, quant_config)
self.lm_head_weight = self.transformer.output_layer.weight
self.logits_processor = LogitsProcessor(config.padded_vocab_size)
self.sampler = Sampler()
......
......@@ -32,11 +32,12 @@ from vllm.attention import Attention, AttentionMetadata
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.linear import (LinearMethodBase,
MergedColumnParallelLinear,
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
......@@ -91,7 +92,7 @@ class CohereMLP(nn.Module):
def __init__(
self,
config,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
......@@ -101,13 +102,13 @@ class CohereMLP(nn.Module):
self.hidden_size,
[self.intermediate_size] * 2,
bias=False,
linear_method=linear_method,
quant_config=quant_config,
)
self.down_proj = RowParallelLinear(
self.intermediate_size,
self.hidden_size,
bias=False,
linear_method=linear_method,
quant_config=quant_config,
)
self.act_fn = SiluAndMul()
......@@ -123,7 +124,7 @@ class CohereAttention(nn.Module):
def __init__(
self,
config: CohereConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
tp_size = get_tensor_model_parallel_world_size()
......@@ -158,13 +159,13 @@ class CohereAttention(nn.Module):
self.total_num_heads,
self.total_num_kv_heads,
bias=False,
linear_method=linear_method,
quant_config=quant_config,
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
self.hidden_size,
bias=False,
linear_method=linear_method,
quant_config=quant_config,
)
self.rotary_emb = get_rope(
self.head_dim,
......@@ -218,13 +219,13 @@ class CohereDecoderLayer(nn.Module):
def __init__(self,
config: CohereConfig,
linear_method: Optional[LinearMethodBase] = None):
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = CohereAttention(config, linear_method=linear_method)
self.self_attn = CohereAttention(config, quant_config=quant_config)
self.mlp = CohereMLP(config, linear_method=linear_method)
self.mlp = CohereMLP(config, quant_config=quant_config)
self.input_layernorm = LayerNorm(param_shape=(config.hidden_size),
eps=config.layer_norm_eps)
......@@ -257,7 +258,7 @@ class CohereModel(nn.Module):
def __init__(
self,
config: CohereConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
......@@ -265,7 +266,7 @@ class CohereModel(nn.Module):
self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
config.hidden_size)
self.layers = nn.ModuleList([
CohereDecoderLayer(config, linear_method=linear_method)
CohereDecoderLayer(config, quant_config=quant_config)
for _ in range(config.num_hidden_layers)
])
self.norm = LayerNorm(param_shape=(config.hidden_size),
......@@ -298,14 +299,14 @@ class CohereForCausalLM(nn.Module):
def __init__(
self,
config: CohereConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.config = config
self.linear_method = linear_method
self.quant_config = quant_config
self.logits_processor = LogitsProcessor(config.vocab_size,
scale=config.logit_scale)
self.model = CohereModel(config, linear_method)
self.model = CohereModel(config, quant_config)
self.sampler = Sampler()
@torch.no_grad()
......
......@@ -9,11 +9,12 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.linear import (LinearMethodBase,
QKVParallelLinear,
from vllm.model_executor.layers.linear import (QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
......@@ -44,7 +45,7 @@ class DbrxRouter(nn.Module):
self.num_total_experts,
bias=False,
params_dtype=params_dtype,
linear_method=None,
quant_config=None,
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
......@@ -63,7 +64,7 @@ class DbrxExperts(nn.Module):
def __init__(
self,
config: DbrxConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
params_dtype: Optional[torch.dtype] = None,
):
super().__init__()
......@@ -165,7 +166,7 @@ class DbrxAttention(nn.Module):
def __init__(
self,
config: DbrxConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.d_model = config.d_model
......@@ -183,13 +184,13 @@ class DbrxAttention(nn.Module):
self.total_num_heads,
self.total_num_kv_heads,
bias=False,
linear_method=linear_method,
quant_config=quant_config,
)
self.out_proj = RowParallelLinear(
self.d_model,
self.d_model,
bias=False,
linear_method=linear_method,
quant_config=quant_config,
)
self.rotary_emb = get_rope(
self.head_dim,
......@@ -244,11 +245,11 @@ class DbrxFusedNormAttention(nn.Module):
def __init__(
self,
config: DbrxConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.d_model = config.d_model
self.attn = DbrxAttention(config, linear_method)
self.attn = DbrxAttention(config, quant_config)
self.norm_1 = nn.LayerNorm(self.d_model)
self.norm_2 = nn.LayerNorm(self.d_model)
......@@ -278,11 +279,11 @@ class DbrxBlock(nn.Module):
def __init__(
self,
config: DbrxConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.norm_attn_norm = DbrxFusedNormAttention(config, linear_method)
self.ffn = DbrxExperts(config, linear_method)
self.norm_attn_norm = DbrxFusedNormAttention(config, quant_config)
self.ffn = DbrxExperts(config, quant_config)
def forward(
self,
......@@ -307,7 +308,7 @@ class DbrxModel(nn.Module):
def __init__(
self,
config: DbrxConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.wte = VocabParallelEmbedding(
......@@ -315,7 +316,7 @@ class DbrxModel(nn.Module):
config.d_model,
)
self.blocks = nn.ModuleList(
[DbrxBlock(config, linear_method) for _ in range(config.n_layers)])
[DbrxBlock(config, quant_config) for _ in range(config.n_layers)])
self.norm_f = nn.LayerNorm(config.d_model, eps=1e-5)
for module in self.modules():
if hasattr(module, "bias") and isinstance(module.bias,
......@@ -348,13 +349,13 @@ class DbrxForCausalLM(nn.Module):
def __init__(
self,
config: DbrxConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
self.linear_method = linear_method
self.quant_config = quant_config
self.unpadded_vocab_size = config.vocab_size
self.transformer = DbrxModel(config, linear_method)
self.transformer = DbrxModel(config, quant_config)
self.lm_head = ParallelLMHead(
config.vocab_size,
config.d_model,
......
......@@ -29,7 +29,8 @@ import torch
from transformers import PretrainedConfig
from vllm.config import LoRAConfig
from vllm.model_executor.layers.linear import LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.llama import LlamaForCausalLM
......@@ -55,13 +56,13 @@ class DeciLMForCausalLM(LlamaForCausalLM):
def __init__(
self,
config: Optional[PretrainedConfig] = None,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None:
config.num_key_value_heads = max(config.num_key_value_heads_per_layer)
delattr(config, "num_key_value_heads_per_layer")
super().__init__(config=config,
linear_method=linear_method,
quant_config=quant_config,
lora_config=lora_config)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
......
......@@ -34,12 +34,13 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
MergedColumnParallelLinear,
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
......@@ -56,18 +57,18 @@ class DeepseekMLP(nn.Module):
hidden_size: int,
intermediate_size: int,
hidden_act: str,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
reduce_results: bool = True,
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2,
bias=False,
linear_method=linear_method)
quant_config=quant_config)
self.down_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=False,
linear_method=linear_method,
quant_config=quant_config,
reduce_results=reduce_results)
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
......@@ -86,7 +87,7 @@ class DeepseekMoE(nn.Module):
def __init__(
self,
config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
......@@ -103,7 +104,7 @@ class DeepseekMoE(nn.Module):
DeepseekMLP(hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
hidden_act=config.hidden_act,
linear_method=linear_method,
quant_config=quant_config,
reduce_results=False)
for idx in range(self.n_routed_experts)
])
......@@ -112,7 +113,7 @@ class DeepseekMoE(nn.Module):
self.gate = ReplicatedLinear(config.hidden_size,
self.n_routed_experts,
bias=False,
linear_method=None)
quant_config=None)
if config.n_shared_experts is not None:
intermediate_size = (config.moe_intermediate_size *
......@@ -121,7 +122,7 @@ class DeepseekMoE(nn.Module):
hidden_size=config.hidden_size,
intermediate_size=intermediate_size,
hidden_act=config.hidden_act,
linear_method=linear_method,
quant_config=quant_config,
reduce_results=False,
)
......@@ -177,7 +178,7 @@ class DeepseekAttention(nn.Module):
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.hidden_size = hidden_size
......@@ -208,14 +209,14 @@ class DeepseekAttention(nn.Module):
self.total_num_heads,
self.total_num_kv_heads,
bias=False,
linear_method=linear_method,
quant_config=quant_config,
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
linear_method=linear_method,
quant_config=quant_config,
)
self.rotary_emb = get_rope(
......@@ -251,7 +252,7 @@ class DeepseekDecoderLayer(nn.Module):
self,
config: PretrainedConfig,
layer_idx: int,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
......@@ -266,18 +267,18 @@ class DeepseekDecoderLayer(nn.Module):
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
linear_method=linear_method,
quant_config=quant_config,
)
if (config.n_routed_experts is not None
and layer_idx >= config.first_k_dense_replace
and layer_idx % config.moe_layer_freq == 0):
self.mlp = DeepseekMoE(config=config, linear_method=linear_method)
self.mlp = DeepseekMoE(config=config, quant_config=quant_config)
else:
self.mlp = DeepseekMLP(
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
linear_method=linear_method,
quant_config=quant_config,
)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
......@@ -320,7 +321,7 @@ class DeepseekModel(nn.Module):
def __init__(
self,
config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.padding_idx = config.pad_token_id
......@@ -331,9 +332,7 @@ class DeepseekModel(nn.Module):
config.hidden_size,
)
self.layers = nn.ModuleList([
DeepseekDecoderLayer(config,
layer_idx,
linear_method=linear_method)
DeepseekDecoderLayer(config, layer_idx, quant_config=quant_config)
for layer_idx in range(config.num_hidden_layers)
])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
......@@ -361,12 +360,12 @@ class DeepseekForCausalLM(nn.Module):
def __init__(
self,
config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.config = config
self.linear_method = linear_method
self.model = DeepseekModel(config, linear_method)
self.quant_config = quant_config
self.model = DeepseekModel(config, quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
......
......@@ -32,10 +32,11 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
......@@ -76,7 +77,7 @@ class FalconAttention(nn.Module):
def __init__(
self,
config: FalconConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
......@@ -115,7 +116,7 @@ class FalconAttention(nn.Module):
self.total_num_kv_heads,
bias=config.bias,
skip_bias_add=True,
linear_method=linear_method,
quant_config=quant_config,
)
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
......@@ -129,7 +130,7 @@ class FalconAttention(nn.Module):
self.hidden_size,
bias=config.bias,
skip_bias_add=True,
linear_method=linear_method,
quant_config=quant_config,
reduce_results=self.reduce_row_parallel_results)
self.use_rotary = config.rotary
......@@ -192,7 +193,7 @@ class FalconMLP(nn.Module):
def __init__(
self,
config: FalconConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
hidden_size = config.hidden_size
......@@ -201,8 +202,7 @@ class FalconMLP(nn.Module):
4 * hidden_size,
bias=config.bias,
skip_bias_add=True,
linear_method=linear_method)
quant_config = getattr(linear_method, "quant_config", None)
quant_config=quant_config)
self.act = get_act_fn("gelu", quant_config, 4 * hidden_size)
self.reduce_row_parallel_results = not (config.new_decoder_architecture
or config.parallel_attn)
......@@ -212,7 +212,7 @@ class FalconMLP(nn.Module):
bias=config.bias,
skip_bias_add=True,
reduce_results=self.reduce_row_parallel_results,
linear_method=linear_method)
quant_config=quant_config)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# NOTE(zhuohan): Following huggingface, we do not fuse bias add here.
......@@ -229,13 +229,13 @@ class FalconDecoderLayer(nn.Module):
def __init__(
self,
config: FalconConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.self_attention = FalconAttention(config, linear_method)
self.mlp = FalconMLP(config, linear_method)
self.self_attention = FalconAttention(config, quant_config)
self.mlp = FalconMLP(config, quant_config)
self.config = config
if config.new_decoder_architecture:
......@@ -311,7 +311,7 @@ class FalconModel(nn.Module):
def __init__(
self,
config: FalconConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
......@@ -327,7 +327,7 @@ class FalconModel(nn.Module):
# Transformer blocks
self.h = nn.ModuleList([
FalconDecoderLayer(config, linear_method)
FalconDecoderLayer(config, quant_config)
for _ in range(config.num_hidden_layers)
])
......@@ -359,12 +359,12 @@ class FalconForCausalLM(nn.Module):
def __init__(
self,
config: FalconConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
self.linear_method = linear_method
self.transformer = FalconModel(config, linear_method)
self.quant_config = quant_config
self.transformer = FalconModel(config, quant_config)
self.lm_head_weight = self.transformer.word_embeddings.weight
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment