Commit af7f4372 authored by zhuwenwen
Browse files

Merge tag 'v0.5.5' into v0.5.5-dtk24.04.1

parents 5e19cdef 09c77926
......@@ -4,6 +4,7 @@ import torch
from torch.nn import Module
from torch.nn.parameter import Parameter
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEMethodBase
......@@ -18,11 +19,13 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
is_layer_skipped)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
all_close_1d, apply_fp8_linear, convert_to_channelwise,
create_per_tensor_scale_param, cutlass_fp8_supported,
per_tensor_dequantize, requantize_with_max_scale)
cutlass_fp8_supported, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize,
requantize_with_max_scale)
from vllm.model_executor.parameter import (ModelWeightParameter,
PerTensorScaleParameter)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.utils import print_warning_once
from vllm.utils import is_hip, print_warning_once
ACTIVATION_SCHEMES = ["static", "dynamic"]
......@@ -118,7 +121,10 @@ class Fp8LinearMethod(LinearMethodBase):
# kernel for fast weight-only FP8 quantization
capability = current_platform.get_device_capability()
capability = capability[0] * 10 + capability[1]
self.use_marlin = capability < 89
self.use_marlin = capability < 89 or envs.VLLM_TEST_FORCE_FP8_MARLIN
# Disable marlin for rocm
if is_hip():
self.use_marlin = False
def create_weights(
self,
......@@ -132,6 +138,7 @@ class Fp8LinearMethod(LinearMethodBase):
):
del input_size, output_size
output_size_per_partition = sum(output_partition_sizes)
weight_loader = extra_weight_attrs.get("weight_loader")
layer.logical_widths = output_partition_sizes
......@@ -143,37 +150,54 @@ class Fp8LinearMethod(LinearMethodBase):
weight_dtype = (torch.float8_e4m3fn
if self.quant_config.is_checkpoint_fp8_serialized else
params_dtype)
weight = Parameter(torch.empty(output_size_per_partition,
weight = ModelWeightParameter(data=torch.empty(
output_size_per_partition,
input_size_per_partition,
dtype=weight_dtype),
requires_grad=False)
input_dim=1,
output_dim=0,
weight_loader=weight_loader)
layer.register_parameter("weight", weight)
set_weight_attrs(weight, {
**extra_weight_attrs,
"input_dim": 1,
"output_dim": 0,
})
# If checkpoint is serialized fp8, load them.
# Otherwise, wait until process_weights_after_loading.
if self.quant_config.is_checkpoint_fp8_serialized:
# WEIGHT SCALE
scale = create_per_tensor_scale_param(output_partition_sizes,
**extra_weight_attrs)
scale = PerTensorScaleParameter(data=torch.empty(
len(output_partition_sizes), dtype=torch.float32),
weight_loader=weight_loader)
scale[:] = torch.finfo(torch.float32).min
layer.register_parameter("weight_scale", scale)
# INPUT ACTIVATION SCALE
if self.quant_config.activation_scheme == "static":
scale = create_per_tensor_scale_param(output_partition_sizes,
**extra_weight_attrs)
scale = PerTensorScaleParameter(data=torch.empty(
len(output_partition_sizes), dtype=torch.float32),
weight_loader=weight_loader)
scale[:] = torch.finfo(torch.float32).min
layer.register_parameter("input_scale", scale)
else:
layer.register_parameter("input_scale", None)
def process_weights_after_loading(self, layer: Module) -> None:
layer.weight = torch.nn.Parameter(layer.weight.data,
requires_grad=False)
# If checkpoint not serialized fp8, quantize the weights.
if not self.quant_config.is_checkpoint_fp8_serialized:
qweight, weight_scale = ops.scaled_fp8_quant(layer.weight,
scale=None)
# If using marlin (w8a16), kernel uses channelwise weights,
# so extend the weight scales to be channelwise.
if self.use_marlin:
assert weight_scale.numel() == 1
weight_scale = convert_to_channelwise(
weight_scale.expand(len(layer.logical_widths)),
layer.logical_widths)
# Update the layer with the new values.
layer.weight = Parameter(qweight.t(), requires_grad=False)
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
......@@ -182,6 +206,11 @@ class Fp8LinearMethod(LinearMethodBase):
# If checkpoint is fp8, handle that there are N scales for N
# shards in a fused module
else:
layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data,
requires_grad=False)
if self.quant_config.activation_scheme == "static":
layer.input_scale = torch.nn.Parameter(layer.input_scale.data,
requires_grad=False)
# If using marlin (w8a16), kernel uses channelwise weights,
# so extend the weight scales to be channelwise.
if self.use_marlin:
......@@ -193,9 +222,23 @@ class Fp8LinearMethod(LinearMethodBase):
# requantize the logical shards as a single weight.
else:
# Dequant -> Quant with max scale so we can run per tensor.
weight = layer.weight
weight_scale = layer.weight_scale
# If rocm, use float8_e4m3fnuz.
if is_hip():
weight, weight_scale, input_scale = \
normalize_e4m3fn_to_e4m3fnuz(
weight=weight,
weight_scale=weight_scale,
input_scale=layer.input_scale)
if input_scale is not None:
layer.input_scale = Parameter(input_scale,
requires_grad=False)
weight_scale, weight = requantize_with_max_scale(
weight=layer.weight,
weight_scale=layer.weight_scale,
weight=weight,
weight_scale=weight_scale,
logical_widths=layer.logical_widths,
)
......@@ -205,8 +248,6 @@ class Fp8LinearMethod(LinearMethodBase):
if self.quant_config.activation_scheme == "static":
layer.input_scale = Parameter(layer.input_scale.max(),
requires_grad=False)
else:
layer.input_scale = None
if self.use_marlin:
prepare_fp8_layer_for_marlin(layer)
......@@ -281,23 +322,29 @@ class Fp8MoEMethod(FusedMoEMethodBase):
# WEIGHT_SCALES
# Allocate 2 scales for w1 and w3 respectively.
# They will be combined to a single scale after weight loading.
w13_scale = torch.nn.Parameter(torch.ones(num_experts,
w13_weight_scale = torch.nn.Parameter(torch.ones(num_experts,
2,
dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w13_scale", w13_scale)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
w2_scale = torch.nn.Parameter(torch.ones(num_experts,
w2_weight_scale = torch.nn.Parameter(torch.ones(num_experts,
dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w2_scale", w2_scale)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
# If loading fp8 checkpoint, pass the weight loaders.
# If loading an fp16 checkpoint, do not (we will quantize in
# process_weights_after_loading()
if self.quant_config.is_checkpoint_fp8_serialized:
set_weight_attrs(w13_scale, extra_weight_attrs)
set_weight_attrs(w2_scale, extra_weight_attrs)
set_weight_attrs(w13_weight_scale, {
"is_fp8_scale": True,
**extra_weight_attrs
})
set_weight_attrs(w2_weight_scale, {
"is_fp8_scale": True,
**extra_weight_attrs
})
# INPUT_SCALES
if self.quant_config.activation_scheme == "static":
......@@ -306,42 +353,50 @@ class Fp8MoEMethod(FusedMoEMethodBase):
"Found static activation scheme for checkpoint that "
"was not serialized fp8.")
a13_scale = torch.nn.Parameter(torch.ones(num_experts,
dtype=torch.float32),
w13_input_scale = torch.nn.Parameter(torch.ones(
num_experts, dtype=torch.float32),
requires_grad=False)
layer.register_parameter("a13_scale", a13_scale)
set_weight_attrs(a13_scale, extra_weight_attrs)
layer.register_parameter("w13_input_scale", w13_input_scale)
set_weight_attrs(w13_input_scale, {
"is_fp8_scale": True,
**extra_weight_attrs
})
a2_scale = torch.nn.Parameter(torch.ones(num_experts,
dtype=torch.float32),
w2_input_scale = torch.nn.Parameter(torch.ones(
num_experts, dtype=torch.float32),
requires_grad=False)
layer.register_parameter("a2_scale", a2_scale)
set_weight_attrs(a2_scale, extra_weight_attrs)
layer.register_parameter("w2_input_scale", w2_input_scale)
set_weight_attrs(w2_input_scale, {
"is_fp8_scale": True,
**extra_weight_attrs
})
else:
layer.a13_scale = None
layer.a2_scale = None
layer.w13_input_scale = None
layer.w2_input_scale = None
def process_weights_after_loading(self, layer: Module) -> None:
# If checkpoint is fp16, quantize in place.
if not self.quant_config.is_checkpoint_fp8_serialized:
# If rocm, use float8_e4m3fnuz as dtype
fp8_dtype = torch.float8_e4m3fnuz \
if is_hip() else torch.float8_e4m3fn
w13_weight = torch.empty_like(layer.w13_weight.data,
dtype=torch.float8_e4m3fn)
w2_weight = torch.empty_like(layer.w2_weight.data,
dtype=torch.float8_e4m3fn)
dtype=fp8_dtype)
w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)
# Re-initialize w13_weight_scale because we directly quantize
# merged w13 weights and generate a single scaling factor.
layer.w13_scale = torch.nn.Parameter(torch.ones(
layer.w13_weight_scale = torch.nn.Parameter(torch.ones(
layer.num_experts,
dtype=torch.float32,
device=w13_weight.device),
requires_grad=False)
for expert in range(layer.num_experts):
w13_weight[expert, :, :], layer.w13_scale[
w13_weight[expert, :, :], layer.w13_weight_scale[
expert] = ops.scaled_fp8_quant(
layer.w13_weight.data[expert, :, :])
w2_weight[expert, :, :], layer.w2_scale[
w2_weight[expert, :, :], layer.w2_weight_scale[
expert] = ops.scaled_fp8_quant(
layer.w2_weight.data[expert, :, :])
layer.w13_weight = torch.nn.Parameter(w13_weight,
......@@ -357,39 +412,66 @@ class Fp8MoEMethod(FusedMoEMethodBase):
# Fp8 moe kernels require a single activation scale.
# We take the max of all the scales in case they differ.
if self.quant_config.activation_scheme == "static":
if layer.a13_scale is None or layer.a2_scale is None:
if (layer.w13_input_scale is None
or layer.w2_input_scale is None):
raise ValueError(
"QuantConfig has static quantization, but found "
"activation scales are None.")
if (not all_close_1d(layer.a13_scale)
or not all_close_1d(layer.a2_scale)):
if (not all_close_1d(layer.w13_input_scale)
or not all_close_1d(layer.w2_input_scale)):
print_warning_once(
"Found input_scales that are not equal for "
"fp8 MoE layer. Using the maximum across experts "
"for each layer. ")
layer.a13_scale = torch.nn.Parameter(layer.a13_scale.max(),
layer.w13_input_scale = torch.nn.Parameter(
layer.w13_input_scale.max(), requires_grad=False)
layer.w2_input_scale = torch.nn.Parameter(
layer.w2_input_scale.max(), requires_grad=False)
# If rocm, normalize the weights and scales to e4m3fnuz
if is_hip():
# Normalize the weights and scales
w13_weight, w13_weight_scale, w13_input_scale = \
normalize_e4m3fn_to_e4m3fnuz(
layer.w13_weight, layer.w13_weight_scale,
layer.w13_input_scale)
w2_weight, w2_weight_scale, w2_input_scale = \
normalize_e4m3fn_to_e4m3fnuz(
layer.w2_weight, layer.w2_weight_scale,
layer.w2_input_scale)
# Reset the parameter
layer.w13_weight = torch.nn.Parameter(w13_weight,
requires_grad=False)
layer.w13_weight_scale = torch.nn.Parameter(
w13_weight_scale, requires_grad=False)
if w13_input_scale is not None:
layer.w13_input_scale = torch.nn.Parameter(
w13_input_scale, requires_grad=False)
layer.w2_weight = torch.nn.Parameter(w2_weight,
requires_grad=False)
layer.a2_scale = torch.nn.Parameter(layer.a2_scale.max(),
layer.w2_weight_scale = torch.nn.Parameter(w2_weight_scale,
requires_grad=False)
if w2_input_scale is not None:
layer.w2_input_scale = torch.nn.Parameter(
w2_input_scale, requires_grad=False)
# Fp8 moe kernel needs single weight scale for w13 per expert.
# We take the max then dequant and requant each expert.
assert layer.w13_scale is not None
assert layer.w13_weight_scale is not None
shard_size = layer.intermediate_size_per_partition
max_w13_scales = layer.w13_scale.max(dim=1).values
max_w13_scales = layer.w13_weight_scale.max(dim=1).values
for expert_id in range(layer.num_experts):
start = 0
for shard_id in range(2):
dq_weight = per_tensor_dequantize(
layer.w13_weight[expert_id][start:start +
shard_size, :],
layer.w13_scale[expert_id][shard_id])
layer.w13_weight_scale[expert_id][shard_id])
layer.w13_weight[expert_id][
start:start + shard_size, :], _ = ops.scaled_fp8_quant(
dq_weight, max_w13_scales[expert_id])
start += shard_size
layer.w13_scale = torch.nn.Parameter(max_w13_scales,
layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
requires_grad=False)
return
......@@ -398,27 +480,33 @@ class Fp8MoEMethod(FusedMoEMethodBase):
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool = True,
use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None,
topk_group: Optional[int] = None) -> torch.Tensor:
renormalize: bool,
use_grouped_topk: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe import fused_moe
return fused_moe(x,
from vllm.model_executor.layers.fused_moe import fused_experts
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group)
return fused_experts(x,
layer.w13_weight,
layer.w2_weight,
router_logits,
top_k,
renormalize=renormalize,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
use_fp8=True,
w1_scale=layer.w13_scale,
w2_scale=layer.w2_scale,
a1_scale=layer.a13_scale,
a2_scale=layer.a2_scale,
use_grouped_topk=use_grouped_topk,
num_expert_group=num_expert_group,
topk_group=topk_group)
use_fp8_w8a8=True,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale)
class Fp8KVCacheMethod(BaseKVCacheMethod):
......
from typing import Any, Dict, List, Optional
import gguf
import torch
from torch.nn.parameter import Parameter, UninitializedParameter
from vllm import _custom_ops as ops
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.utils import set_weight_attrs
class GGUFConfig(QuantizationConfig):
    """Quantization config for GGUF.

    The config holds no state: every answer below is a fixed value.
    """

    def __init__(self) -> None:
        # Nothing to store; the base-class initializer is intentionally
        # not invoked here, matching the existing behavior.
        return None

    def __repr__(self) -> str:
        return "GGUFConfig()"

    def get_name(self) -> str:
        """Return the identifier of this quantization scheme."""
        return "gguf"

    def get_supported_act_dtypes(self) -> List[torch.dtype]:
        """Return the activation dtypes usable with GGUF weights."""
        return [torch.float16, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum device capability required.

        NOTE(review): presumably encoded as ``major * 10 + minor`` -- confirm
        against the base class contract.
        """
        return 60

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        """Return the extra config files to look for (none for GGUF)."""
        return []

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "GGUFConfig":
        """Build a config; ``config`` is accepted but not read."""
        return cls()

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["QuantizeMethodBase"]:
        """Pick the quantize method for ``layer``.

        Linear layers get ``GGUFLinearMethod``, vocab-parallel embeddings get
        ``GGUFEmbeddingMethod``; any other layer gets ``None``. ``prefix`` is
        not used.
        """
        if isinstance(layer, LinearBase):
            return GGUFLinearMethod(self)
        if isinstance(layer, VocabParallelEmbedding):
            return GGUFEmbeddingMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        """Return the names of scaled activations (none for GGUF)."""
        return []
def _fuse_mul_mat(x: torch.Tensor, qweight: torch.Tensor,
                  qweight_type: int) -> torch.Tensor:
    """Multiply ``x`` by a GGUF-quantized weight.

    Weight types below 16 go through the fused ``ggml_mul_mat_a8`` kernel
    (mmq for k-quants); types 16 and above (IQmatrix) are dequantized first
    and then multiplied with a regular matmul.

    Args:
        x: Input activations.
        qweight: Packed quantized weight.
        qweight_type: GGML type id of ``qweight``.

    Returns:
        The product of ``x`` and the transposed (de)quantized weight.
    """
    if qweight_type < 16:
        return ops.ggml_mul_mat_a8(qweight, x, qweight_type, qweight.shape[0])

    # Recover the logical column count from the packed byte width.
    block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
    n_rows = qweight.shape[0]
    n_cols = qweight.shape[1] // type_size * block_size
    dense_weight = ops.ggml_dequantize(qweight, qweight_type, n_rows, n_cols)
    return x @ dense_weight.T
class GGUFLinearMethod(LinearMethodBase):
    """Linear method for GGUF.

    Args:
        quant_config: The GGUF quantization config.
    """

    def __init__(self, quant_config: GGUFConfig):
        self.quant_config = quant_config

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: List[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Register the ``qweight`` and ``qweight_type`` parameters on ``layer``.

        ``qweight`` is created uninitialized and only records the logical
        ``tensor_shape``; ``qweight_type`` holds one ``uint8`` slot per output
        partition. ``extra_weight_attrs`` are attached to both parameters.
        ``input_size``, ``output_size`` and ``params_dtype`` are not used.
        """
        logical_shape = (sum(output_partition_sizes),
                         input_size_per_partition)

        # Packed weight: storage is materialized later.
        qweight = UninitializedParameter(requires_grad=False)
        qweight_attrs = {
            "input_dim": 1,
            "output_dim": 0,
            "tensor_shape": logical_shape,
            "is_gguf_weight": True,
            "shard_size": {},
            "shard_id": [],
        }
        set_weight_attrs(qweight, qweight_attrs)
        set_weight_attrs(qweight, extra_weight_attrs)
        layer.register_parameter("qweight", qweight)

        # GGML type id(s) of the packed weight.
        type_storage = torch.empty(len(output_partition_sizes),
                                   dtype=torch.uint8)
        qweight_type = Parameter(type_storage, requires_grad=False)
        type_attrs = {
            "is_gguf_weight_type": True,
            "weight_type": 0,
            "shard_weight_type": {},
            "ignore_warning": True
        }
        set_weight_attrs(qweight_type, type_attrs)
        set_weight_attrs(qweight_type, extra_weight_attrs)
        layer.register_parameter("qweight_type", qweight_type)

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Apply the quantized linear layer to ``x``.

        When ``qweight`` carries per-shard metadata, each shard is multiplied
        on its own with its own weight type and the results are concatenated
        along dim 1; otherwise the whole weight is used at once. ``bias``, if
        given, is added in place to the result.
        """
        shard_size = getattr(layer.qweight, "shard_size", None)
        shard_id = getattr(layer.qweight, "shard_id", None)

        if not (shard_id and shard_size):
            out = _fuse_mul_mat(x, layer.qweight,
                                layer.qweight_type.weight_type)
        else:
            # Shards are processed in q, k, v order whenever "q" is present.
            ordered_ids = ["q", "k", "v"] if "q" in shard_id else shard_id
            partial_outputs = []
            row_start = 0
            for shard_key in ordered_ids:
                n_rows = shard_size[shard_key][0]
                n_cols = shard_size[shard_key][1]
                # dequantize shard weights respectively
                shard_weight = layer.qweight[row_start:row_start + n_rows,
                                             :n_cols].contiguous()
                shard_type = layer.qweight_type.shard_weight_type[shard_key]
                partial_outputs.append(
                    _fuse_mul_mat(x, shard_weight, shard_type))
                row_start += n_rows
            out = torch.cat(partial_outputs, dim=1)

        if bias is not None:
            out.add_(bias)
        return out
class GGUFEmbeddingMethod(GGUFLinearMethod):
    """Embedding method for GGUF.

    Args:
        quant_config: The GGUF quantization config.
    """

    def embedding(self, layer: torch.nn.Module,
                  x: torch.Tensor) -> torch.Tensor:
        """Look up the rows of the GGUF embedding table selected by ``x``.

        Weight types below 2 are passed straight to ``torch.embedding``
        (presumably the unquantized float types -- confirm against the GGML
        type table). For every other type the selected packed rows are
        gathered, dequantized, and reshaped to ``(*x.shape, hidden_size)``.
        """
        table = layer.qweight
        table_type = layer.qweight_type.weight_type

        # Logical row width recovered from the packed byte width.
        block_size, type_size = gguf.GGML_QUANT_SIZES[table_type]
        hidden_size = table.shape[1] // type_size * block_size

        if table_type < 2:
            return torch.embedding(table, x)

        token_ids = x.flatten()
        packed_rows = table.index_select(0, token_ids)
        dense_rows = ops.ggml_dequantize(packed_rows, table_type, hidden_size,
                                         token_ids.shape[0])
        return dense_rows.view(*x.shape, hidden_size)
......@@ -204,13 +204,7 @@ class GPTQLinearMethod(LinearMethodBase):
layer.exllama_state = exllama_state
def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
qweight = layer.qweight
out_shape = x.shape[:-1] + (qweight.shape[-1], )
reshaped_x = x.reshape(-1, x.shape[-1])
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# exllama needs to shuffle the weight after the weight is loaded
# here we do the shuffle once, right after loading (not on the first forward pass)
if layer.exllama_state == ExllamaState.UNINITIALIZED:
......@@ -218,10 +212,19 @@ class GPTQLinearMethod(LinearMethodBase):
layer.g_idx.data = torch.argsort(layer.g_idx).to(torch.int)
else:
layer.g_idx.data = torch.empty((0, ),
dtype=torch.int,
device=layer.g_idx.device)
layer.exllama_state = ExllamaState.READY
ops.gptq_shuffle(layer.qweight, layer.g_idx,
self.quant_config.weight_bits)
def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
out_shape = x.shape[:-1] + (layer.qweight.shape[-1], )
reshaped_x = x.reshape(-1, x.shape[-1])
output = ops.gptq_gemm(reshaped_x, layer.qweight, layer.qzeros,
layer.scales, layer.g_idx,
layer.exllama_state == ExllamaState.READY,
......
from typing import Any, Dict, List, Optional
import torch
from torch.nn.parameter import Parameter
from torch.nn import Parameter
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
set_weight_attrs)
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
......@@ -15,6 +14,11 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
marlin_repeat_scales_on_all_ranks, marlin_sort_g_idx, replace_tensor,
verify_marlin_supported, verify_marlin_supports_shape)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
GroupQuantScaleParameter,
PackedColumnParameter,
PackedvLLMParameter,
RowvLLMParameter)
from vllm.scalar_type import scalar_types
logger = init_logger(__name__)
......@@ -136,8 +140,7 @@ class GPTQMarlinConfig(QuantizationConfig):
return False
return check_marlin_supported(quant_type=cls.TYPE_MAP[(num_bits, sym)],
group_size=group_size,
min_capability=cls.get_min_capability())
group_size=group_size)
class GPTQMarlinLinearMethod(LinearMethodBase):
......@@ -160,9 +163,11 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
params_dtype: torch.dtype,
**extra_weight_attrs,
) -> None:
del output_size
output_size_per_partition = sum(output_partition_sizes)
is_row_parallel = input_size != input_size_per_partition
weight_loader = extra_weight_attrs.get("weight_loader")
# Normalize group_size
if self.quant_config.group_size != -1:
......@@ -191,80 +196,66 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
scales_and_zp_size = input_size_per_partition // group_size
# Quantized weights
qweight = Parameter(
torch.empty(
qweight = PackedvLLMParameter(
data=torch.empty(
input_size_per_partition // self.quant_config.pack_factor,
output_size_per_partition,
dtype=torch.int32,
),
requires_grad=False,
)
set_weight_attrs(
qweight,
{
**extra_weight_attrs,
"input_dim": 0,
"output_dim": 1,
"packed_dim": 0,
"pack_factor": self.quant_config.pack_factor,
},
)
input_dim=0,
output_dim=1,
packed_dim=0,
packed_factor=self.quant_config.pack_factor,
weight_loader=weight_loader)
# Activation order
g_idx = Parameter(
torch.empty(
g_idx = RowvLLMParameter(data=torch.empty(
input_size_per_partition,
dtype=torch.int32,
),
requires_grad=False,
)
# Ignore warning from fused linear layers such as QKVParallelLinear.
set_weight_attrs(
g_idx,
{
**extra_weight_attrs, "input_dim": 0,
"ignore_warning": True
},
)
# Scales
scales = Parameter(
torch.empty(
scales_and_zp_size,
output_size_per_partition,
dtype=params_dtype,
),
requires_grad=False,
)
set_weight_attrs(
scales,
{
**extra_weight_attrs,
"input_dim": scales_and_zp_input_dim,
"output_dim": 1,
},
)
input_dim=0,
weight_loader=weight_loader)
# Quantized zero-points
qzeros = Parameter(
qzeros_args = {
"data":
torch.empty(
scales_and_zp_size,
output_size_per_partition // self.quant_config.pack_factor,
dtype=torch.int32,
device="meta",
),
requires_grad=False,
)
set_weight_attrs(
qzeros,
{
**extra_weight_attrs,
"input_dim": scales_and_zp_input_dim,
"output_dim": 1,
"packed_dim": 1,
"pack_factor": self.quant_config.pack_factor,
},
)
"weight_loader":
weight_loader
}
weight_scale_args = {
"data":
torch.empty(
scales_and_zp_size,
output_size_per_partition,
dtype=params_dtype,
),
"weight_loader":
weight_loader
}
if scales_and_zp_input_dim is None:
scales = ChannelQuantScaleParameter(output_dim=1,
**weight_scale_args)
qzeros = PackedColumnParameter(
output_dim=1,
packed_dim=1,
packed_factor=self.quant_config.pack_factor,
**qzeros_args)
else:
scales = GroupQuantScaleParameter(output_dim=1,
input_dim=0,
**weight_scale_args)
qzeros = PackedvLLMParameter(
input_dim=0,
output_dim=1,
packed_dim=1,
packed_factor=self.quant_config.pack_factor,
**qzeros_args)
layer.register_parameter("qweight", qweight)
layer.register_parameter("g_idx", g_idx)
......@@ -282,6 +273,10 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
device = layer.qweight.device
# required by torch.compile
layer.qweight = Parameter(layer.qweight.data, requires_grad=False)
layer.scales = Parameter(layer.scales.data, requires_grad=False)
# Allocate marlin workspace
layer.workspace = marlin_make_workspace(
layer.output_size_per_partition, device)
......
......@@ -9,7 +9,10 @@ from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.utils import set_weight_attrs
from vllm.model_executor.parameter import (BasevLLMParameter,
ChannelQuantScaleParameter,
GroupQuantScaleParameter,
PackedvLLMParameter)
logger = init_logger(__name__)
......@@ -132,6 +135,7 @@ class MarlinLinearMethod(LinearMethodBase):
**extra_weight_attrs,
):
del output_size # Unused.
weight_loader = extra_weight_attrs["weight_loader"]
if params_dtype != torch.float16:
raise ValueError(
......@@ -170,64 +174,64 @@ class MarlinLinearMethod(LinearMethodBase):
"Each permutation group must reside on the same gpu")
# Quantized 4Bit weights packed into Int32.
qweight = Parameter(
torch.empty(
qweight = PackedvLLMParameter(
data=torch.empty(
input_size_per_partition // self.quant_config.tile_size,
output_size_per_partition * self.quant_config.tile_size //
self.quant_config.pack_factor,
device="cuda",
dtype=torch.int32,
),
requires_grad=False,
)
set_weight_attrs(
qweight,
{
"input_dim": 0,
"output_dim": 1,
"packed_dim": 1,
"pack_factor": self.quant_config.pack_factor,
"marlin_tile_size": self.quant_config.tile_size,
},
)
input_dim=0,
output_dim=1,
packed_dim=1,
packed_factor=self.quant_config.pack_factor,
marlin_tile_size=self.quant_config.tile_size,
weight_loader=weight_loader)
# Determine if channelwise or not
input_groups = (1 if self.quant_config.group_size == -1 else
input_size_per_partition //
self.quant_config.group_size)
scales = Parameter(
weight_scale_args = {
"data":
torch.empty(
input_groups,
output_size_per_partition,
device="cuda",
dtype=params_dtype,
),
requires_grad=False,
)
set_weight_attrs(
scales,
{
"input_dim": None if input_groups == 1 else 0,
"output_dim": 1,
},
)
"weight_loader":
weight_loader
}
if input_groups == 1:
scales = ChannelQuantScaleParameter(output_dim=1,
**weight_scale_args)
else:
scales = GroupQuantScaleParameter(output_dim=1,
input_dim=0,
**weight_scale_args)
# Allocate workspace (Used for internal locking mechanism)
max_workspace_size = (
output_size_per_partition //
self.quant_config.min_n_threads) * self.quant_config.max_parallel
workspace = Parameter(torch.zeros(max_workspace_size,
workspace = BasevLLMParameter(data=torch.zeros(max_workspace_size,
device="cuda",
dtype=torch.int),
requires_grad=False)
weight_loader=weight_loader)
layer.register_parameter("B", qweight)
set_weight_attrs(qweight, extra_weight_attrs)
layer.register_parameter("s", scales)
set_weight_attrs(scales, extra_weight_attrs)
layer.register_parameter("workspace", workspace)
set_weight_attrs(workspace, extra_weight_attrs)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
    """Re-wrap ``B``, ``s`` and ``workspace`` as plain ``Parameter`` objects.

    Each attribute is replaced by a fresh ``Parameter`` built from its
    ``.data`` with ``requires_grad=False``.
    """
    # required by torch.compile
    for attr_name in ("B", "s", "workspace"):
        loaded = getattr(layer, attr_name)
        setattr(layer, attr_name, Parameter(loaded.data, requires_grad=False))
def apply(
self,
......
from typing import Any, Dict, List, Optional, Tuple
import torch
from torch.nn import Module
from torch.nn.parameter import Parameter
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.utils import set_weight_attrs
# Activation schemes accepted by Int8TpuConfig; "none" is the only entry.
ACTIVATION_SCHEMES = ["none"]
class Int8TpuConfig(QuantizationConfig):
    """Int8 Quantization Config class for TPU Backend."""

    def __init__(
        self,
        activation_scheme: str = "none",
    ) -> None:
        """Store ``activation_scheme`` after validating it.

        Raises:
            ValueError: If ``activation_scheme`` is not listed in
                ``ACTIVATION_SCHEMES``.
        """
        scheme_is_known = activation_scheme in ACTIVATION_SCHEMES
        if not scheme_is_known:
            raise ValueError(
                f"Unsupported activation scheme {activation_scheme}")
        self.activation_scheme = activation_scheme

    def get_name(self) -> str:
        """Return the identifier of this quantization scheme."""
        return "tpu_int8"

    def get_supported_act_dtypes(self) -> List[torch.dtype]:
        """Return the activation dtypes usable with this scheme."""
        return [torch.float16, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        """Always raise: this query is not meant to be used on TPU."""
        raise NotImplementedError(
            "This function should not be called with TPU Backend")

    @staticmethod
    def get_config_filenames() -> List[str]:
        """Return the extra config files to look for (none)."""
        return []

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "Int8TpuConfig":
        """Build a config from the ``activation_scheme`` entry of ``config``."""
        return cls(activation_scheme=cls.get_from_keys(
            config, ["activation_scheme"]))

    def get_quant_method(self, layer: Module,
                         prefix: str) -> Optional["TPUInt8LinearMethod"]:
        """Return a ``TPUInt8LinearMethod`` for linear layers, else ``None``.

        ``prefix`` is not used.
        """
        return TPUInt8LinearMethod(self) if isinstance(layer,
                                                       LinearBase) else None

    def get_scaled_act_names(self) -> List[str]:
        """Return the names of scaled activations (none)."""
        return []
class TPUInt8LinearMethod(LinearMethodBase):
    """Int8 Linear method for TPU Quant.

    The weight is created in ``params_dtype`` and converted to int8 with one
    scale per output row in ``process_weights_after_loading``.
    """

    def __init__(self, quant_config: Int8TpuConfig):
        self.quant_config = quant_config

    def create_weights(self, layer: Module, input_size_per_partition: int,
                       output_partition_sizes: List[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Register the (not yet quantized) ``weight`` parameter on ``layer``.

        The parameter has shape
        ``(sum(output_partition_sizes), input_size_per_partition)`` and is
        tagged with ``input_dim=1`` / ``output_dim=0`` plus any
        ``extra_weight_attrs``. ``input_size`` and ``output_size`` are unused.
        """
        weight = Parameter(torch.empty(sum(output_partition_sizes),
                                       input_size_per_partition,
                                       dtype=params_dtype),
                           requires_grad=False)
        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, {
            **extra_weight_attrs,
            "input_dim": 1,
            "output_dim": 0,
        })

    def _quantize_weight(
            self, weight: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Quantize ``weight`` to int8 with one scale per row.

        The computation runs on CPU in float32. Each row (last dim) is scaled
        by ``max(|row|) / 127``, with the maximum clamped to at least ``eps``
        so all-zero rows do not divide by zero.

        Args:
            weight: The floating-point weight to quantize.

        Returns:
            ``(qweight, qscale)``: the int8 weight (same shape as ``weight``)
            and the scales in the original dtype of ``weight``, with the
            reduced last dim removed (one scale per row).
        """
        weight_dtype = weight.dtype
        weight = weight.cpu().to(torch.float32)
        n_bit = 8
        eps = 1e-5
        max_int = 2**(n_bit - 1) - 1
        min_int = -(2**(n_bit - 1))
        max_val = weight.abs().amax(dim=-1, keepdim=True)
        max_val = max_val.clamp(min=eps)
        qscale = max_val / max_int
        qweight = torch.clamp(torch.round(weight * (1.0 / qscale)), min_int,
                              max_int).to(torch.int8)
        # Drop only the keepdim axis. A bare squeeze() would also collapse a
        # single-row weight to a 0-d scale instead of a length-1 vector.
        qscale = qscale.squeeze(-1).to(weight_dtype)
        return qweight, qscale

    def process_weights_after_loading(self, layer: Module) -> None:
        """Replace ``layer.weight`` with its int8 form and attach ``layer.scale``.

        Quantization happens on CPU (see ``_quantize_weight``); both results
        are moved back to the device the weight was loaded on.
        """
        device = layer.weight.device
        qweight, qscale = self._quantize_weight(layer.weight)
        qweight = qweight.to(device)
        qscale = qscale.to(device)
        layer.weight = Parameter(qweight, requires_grad=False)
        layer.scale = Parameter(qscale, requires_grad=False)

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run the quantized matmul for ``layer`` on ``x`` and add ``bias``.

        Raises:
            ImportError: If ``torch_xla`` is not installed.
        """
        try:
            # Imported for its side effect only; presumably this is what makes
            # ``torch.ops.xla.quantized_matmul`` available -- TODO confirm.
            import torch_xla.experimental.xla_quantized_matmul  # noqa: F401
        except ImportError as err:
            raise ImportError(
                "Please install torch_xla by following the instructions at "
                "https://docs.vllm.ai/en/latest/getting_started/tpu-installation.html "  # noqa: E501
                "to run vLLM on TPU.") from err
        weight = layer.weight
        scale = layer.scale
        out = torch.ops.xla.quantized_matmul(x, weight, scale)
        if bias is not None:
            out = out + bias
        return out
......@@ -26,12 +26,13 @@ USE_FP32_REDUCE_DEFAULT = True
# without runtime zero-point. We support common cases, i.e. AWQ and GPTQ.
# TODO: we may want to move this into the C++ so its closer to the actual impl
def query_marlin_supported_quant_types(has_zp: bool,
min_capability: Optional[int] = None):
if min_capability is None:
device_capability: Optional[int] = None
):
if device_capability is None:
major, minor = current_platform.get_device_capability()
min_capability = major * 10 + minor
device_capability = major * 10 + minor
if min_capability < 80:
if device_capability < 80:
return []
if has_zp:
......@@ -48,20 +49,20 @@ def _check_marlin_supported(
quant_type: ScalarType,
group_size: Optional[int],
has_zp: bool,
min_capability: Optional[int] = None) -> Tuple[bool, Optional[str]]:
device_capability: Optional[int] = None) -> Tuple[bool, Optional[str]]:
if min_capability is None:
if device_capability is None:
major, minor = current_platform.get_device_capability()
min_capability = major * 10 + minor
device_capability = major * 10 + minor
supported_types = query_marlin_supported_quant_types(
has_zp, min_capability)
has_zp, device_capability)
if quant_type not in supported_types:
return (False, f"Marlin does not support weight_bits = {quant_type}. "
f"Only types = {supported_types} "
f"are supported (for group_size = {group_size}, "
f"min_capability = {min_capability}, zp = {has_zp}).")
f"device_capability = {device_capability}, zp = {has_zp}).")
if (group_size is None or group_size not in MARLIN_SUPPORTED_GROUP_SIZES):
return (False, f"Marlin does not support group_size = {group_size}. "
f"Only group_sizes = {MARLIN_SUPPORTED_GROUP_SIZES} "
......@@ -73,9 +74,9 @@ def _check_marlin_supported(
def check_marlin_supported(quant_type: ScalarType,
group_size: int,
has_zp: bool = False,
min_capability: Optional[int] = None) -> bool:
device_capability: Optional[int] = None) -> bool:
cond, _ = _check_marlin_supported(quant_type, group_size, has_zp,
min_capability)
device_capability)
return cond
......
......@@ -81,7 +81,8 @@ def permute_rows(q_w: torch.Tensor, w_ref: torch.Tensor, group_size: int):
def quantize_weights(w: torch.Tensor,
quant_type: ScalarType,
group_size: int,
zero_points: bool = False):
zero_points: bool = False,
ref_zero_points_after_scales: bool = False):
assert quant_type.is_integer(), \
"Floating point quantization may work but has not been tested"
......@@ -126,6 +127,12 @@ def quantize_weights(w: torch.Tensor,
w_q = torch.clamp(w_q, min_q_val, max_q_val)
# Compute ref (dequantized)
# Some kernels (namely Machete) apply the zero-points after the scales.
# Computing the reference the same way in that case lets us use tighter
# error tolerances in our unit tests.
if ref_zero_points_after_scales and zero_points:
w_ref = w_q.to(orig_type) * w_s - maybe_w_zp.to(orig_type) * w_s
else:
w_ref = (w_q - (maybe_w_zp if zero_points else 0)).to(orig_type) * w_s
if quant_type.has_bias():
......
......@@ -6,9 +6,19 @@ from torch.nn import Parameter
from vllm import _custom_ops as ops
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.utils import is_hip
# scaled_mm in pytorch on rocm has a bug that requires always
# providing scaling factor for result. This value is created
# as global value to avoid multiple tensor allocations, and
# can be removed once pytorch fixes the bug.
TORCH_SCALED_MM_SCALE_RESULT = torch.ones(1).cuda() if is_hip() else None
def cutlass_fp8_supported() -> bool:
# cutlass is not supported on Rocm
if is_hip():
return False
capability = current_platform.get_device_capability()
capability = capability[0] * 10 + capability[1]
......@@ -147,13 +157,19 @@ def apply_fp8_linear(
if per_tensor_weights and per_tensor_activations:
# Fused GEMM_DQ
output, _ = torch._scaled_mm(qinput,
output = torch._scaled_mm(
qinput,
weight,
out_dtype=input.dtype,
scale_a=x_scale,
scale_b=weight_scale,
scale_result=TORCH_SCALED_MM_SCALE_RESULT,
bias=bias)
# Since torch 2.5, scaled_mm returns a single value rather than a tuple.
# This branch should be removed once vllm-nvidia also moves to torch 2.5.
if is_hip():
return torch.narrow(output, 0, 0, input.shape[0])
return torch.narrow(output[0], 0, 0, input.shape[0])
else:
# Fallback for channelwise case, where we use unfused DQ
......@@ -207,3 +223,27 @@ def apply_int8_linear(
scale_b=weight_scale,
out_dtype=input.dtype,
bias=bias)
def normalize_e4m3fn_to_e4m3fnuz(
weight: torch.Tensor,
weight_scale: torch.Tensor,
input_scale: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
assert weight.dtype == torch.float8_e4m3fn
# The bits pattern 10000000(-128) represents zero in e4m3fn
# but NaN in e4m3fnuz. So here we set it to 0.
# https://onnx.ai/onnx/technical/float8.html
weight_as_int8 = weight.view(torch.int8)
ROCM_FP8_NAN_AS_INT = -128
weight_as_int8[weight_as_int8 == ROCM_FP8_NAN_AS_INT] = 0
weight = weight_as_int8.view(torch.float8_e4m3fnuz)
# For the same bits representation, e4m3fnuz value is half of
# the e4m3fn value, so we should double the scaling factor to
# get the same dequantized value.
# https://onnx.ai/onnx/technical/float8.html
weight_scale = weight_scale * 2.0
if input_scale is not None:
input_scale = input_scale * 2.0
return weight, weight_scale, input_scale
......@@ -78,8 +78,8 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
# Only perform shape/dtype/device checking in strict mode, as it adds
# overhead.
if self._strict_mode:
self._raise_if_incorrect_input(target_probs, bonus_token_ids,
draft_probs, draft_token_ids)
self._raise_if_incorrect_input(target_probs, draft_token_ids,
bonus_token_ids, draft_probs)
accepted, recovered_token_ids = (
self._batch_modified_rejection_sampling(
......
......@@ -28,7 +28,7 @@ import torch
import torch.nn as nn
from vllm.model_executor.custom_op import CustomOp
from vllm.utils import is_tpu
from vllm.platforms import current_platform
def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
......@@ -46,15 +46,23 @@ def _rotate_gptj(x: torch.Tensor) -> torch.Tensor:
def _apply_rotary_emb(
x: torch.Tensor,
freqs_cis: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
) -> torch.Tensor:
x_ = torch.view_as_complex(
torch.stack(torch.chunk(x.transpose(1, 2).float(), 2, dim=-1), dim=-1))
x_out = torch.view_as_real(x_ * freqs_cis).type_as(x)
x_out = torch.cat(torch.chunk(x_out, 2, dim=-1), dim=-2)
x_out = x_out.reshape(x_out.shape[0], x_out.shape[1], x_out.shape[2],
-1).transpose(1, 2)
return x_out
"""
Args:
x: [num_tokens, num_heads, head_size]
cos: [num_tokens, head_size // 2]
sin: [num_tokens, head_size // 2]
"""
orig_dtype = x.dtype
x = x.float()
x1, x2 = torch.chunk(x, 2, dim=-1)
cos = cos.unsqueeze(-2)
sin = sin.unsqueeze(-2)
o1 = x1 * cos - x2 * sin
o2 = x2 * cos + x1 * sin
return torch.cat((o1, o2), dim=-1).to(orig_dtype)
class RotaryEmbedding(CustomOp):
......@@ -78,22 +86,13 @@ class RotaryEmbedding(CustomOp):
self.dtype = dtype
cache = self._compute_cos_sin_cache()
self.use_native2 = is_tpu() and is_neox_style
if not self.use_native2:
cache = cache.to(dtype)
self.register_buffer("cos_sin_cache", cache, persistent=False)
else:
cos, sin = cache.chunk(2, dim=-1)
freqs_cis = cos + 1j * sin
self.register_buffer("freqs_cis", freqs_cis, persistent=False)
self.use_native2 = current_platform.is_tpu() and is_neox_style
def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
"""Compute the inverse frequency."""
# NOTE(woosuk): The HF implementation uses `torch.arange(...).float()`.
# However, we use `torch.arange(..., dtype=torch.float)` instead to
# avoid numerical issues with large base values (e.g., 10000000).
# This may cause a slight numerical difference between the HF
# implementation and ours.
# NOTE(woosuk): To exactly match the HF implementation, we need to
# use CPU to compute the cache and then move it to GPU. However, we
# create the cache on GPU for faster initialization. This may cause
......@@ -173,28 +172,25 @@ class RotaryEmbedding(CustomOp):
This method might perform better than `forward_native()` when compiled.
"""
if positions.dim() == 1:
batch_size = 1
seq_len = positions.shape[0]
else:
batch_size, seq_len = positions.shape
if offsets is not None:
positions = positions + offsets
freqs_cis = self.freqs_cis.index_select(0, positions.flatten())
freqs_cis = freqs_cis.view(batch_size, 1, seq_len, -1)
positions = positions.flatten()
num_tokens = positions.shape[0]
cos_sin = self.cos_sin_cache.index_select(0, positions)
cos, sin = cos_sin.chunk(2, dim=-1)
query_shape = query.shape
query = query.view(batch_size, seq_len, -1, self.head_size)
query = query.view(num_tokens, -1, self.head_size)
query_rot = query[..., :self.rotary_dim]
query_pass = query[..., self.rotary_dim:]
query_rot = _apply_rotary_emb(query_rot, freqs_cis)
query_rot = _apply_rotary_emb(query_rot, cos, sin)
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
key_shape = key.shape
key = key.view(batch_size, seq_len, -1, self.head_size)
key = key.view(num_tokens, -1, self.head_size)
key_rot = key[..., :self.rotary_dim]
key_pass = key[..., self.rotary_dim:]
key_rot = _apply_rotary_emb(key_rot, freqs_cis)
key_rot = _apply_rotary_emb(key_rot, cos, sin)
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
return query, key
......@@ -723,44 +719,50 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
return query, key
class GemmaRotaryEmbedding(RotaryEmbedding):
class Llama3RotaryEmbedding(RotaryEmbedding):
def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
# https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/gemma/modeling_gemma.py#L107
inv_freq = 1.0 / (base**(
torch.arange(0, self.rotary_dim, 2, dtype=torch.int64).float() /
self.rotary_dim))
return inv_freq
class ExtendedRotaryEmbedding(RotaryEmbedding):
def __init__(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: int,
is_neox_style: bool,
dtype: torch.dtype,
scaling_factor: float,
low_freq_factor: float,
high_freq_factor: float,
orig_max_position: int,
) -> None:
self.scaling_factor = scaling_factor
self.low_freq_factor = low_freq_factor
self.high_freq_factor = high_freq_factor
self.orig_max_position = orig_max_position
super().__init__(head_size, rotary_dim, max_position_embeddings, base,
is_neox_style, dtype)
def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
inv_freqs = super()._compute_inv_freq(base)
return self.apply_scaling(inv_freqs)
def apply_scaling(self, freqs: torch.Tensor):
scale_factor = 8
low_freq_factor = 1
high_freq_factor = 4
old_context_len = 8192
low_freq_wavelen = old_context_len / low_freq_factor
high_freq_wavelen = old_context_len / high_freq_factor
new_freqs = []
for freq in freqs:
wavelen = 2 * math.pi / freq
if wavelen < high_freq_wavelen:
new_freqs.append(freq)
elif wavelen > low_freq_wavelen:
new_freqs.append(freq / scale_factor)
low_freq_wavelen = self.orig_max_position / self.low_freq_factor
high_freq_wavelen = self.orig_max_position / self.high_freq_factor
wave_len = 2 * math.pi / inv_freqs
if self.low_freq_factor != self.high_freq_factor:
smooth = (self.orig_max_position / wave_len - self.low_freq_factor
) / (self.high_freq_factor - self.low_freq_factor)
else:
assert low_freq_wavelen != high_freq_wavelen
smooth = (old_context_len / wavelen - low_freq_factor) / (
high_freq_factor - low_freq_factor)
new_freqs.append((1 - smooth) * freq / scale_factor +
smooth * freq)
return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)
smooth = 0
new_freqs = torch.where(
wave_len < high_freq_wavelen,
inv_freqs,
torch.where(
wave_len > low_freq_wavelen,
inv_freqs / self.scaling_factor,
(1 - smooth) * inv_freqs / self.scaling_factor +
smooth * inv_freqs,
),
)
return new_freqs
_ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {}
......@@ -774,7 +776,7 @@ def get_rope(
is_neox_style: bool = True,
rope_scaling: Optional[Dict[str, Any]] = None,
dtype: Optional[torch.dtype] = None,
rotary_percent: float = 1.0,
partial_rotary_factor: float = 1.0,
) -> RotaryEmbedding:
if dtype is None:
dtype = torch.get_default_dtype()
......@@ -787,12 +789,13 @@ def get_rope(
rope_scaling_args = tuple(rope_scaling_tuple.items())
else:
rope_scaling_args = None
if rotary_percent < 1.0:
rotary_dim = int(rotary_dim * rotary_percent)
if partial_rotary_factor < 1.0:
rotary_dim = int(rotary_dim * partial_rotary_factor)
key = (head_size, rotary_dim, max_position, base, is_neox_style,
rope_scaling_args, dtype)
if key in _ROPE_DICT:
return _ROPE_DICT[key]
if rope_scaling is None:
rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base,
is_neox_style, dtype)
......@@ -801,12 +804,19 @@ def get_rope(
"type"] if "type" in rope_scaling else rope_scaling["rope_type"]
# The correct one should be "longrope" but keep "su" here
# for backward compatible
if scaling_type not in {"su", "longrope", "llama3"}:
if scaling_type not in {"su", "longrope"}:
scaling_factor = rope_scaling["factor"]
if scaling_type == "llama3":
rotary_emb = ExtendedRotaryEmbedding(head_size, rotary_dim,
low_freq_factor = rope_scaling["low_freq_factor"]
high_freq_factor = rope_scaling["high_freq_factor"]
original_max_position = rope_scaling[
"original_max_position_embeddings"]
rotary_emb = Llama3RotaryEmbedding(head_size, rotary_dim,
max_position, base,
is_neox_style, dtype)
is_neox_style, dtype,
scaling_factor, low_freq_factor,
high_freq_factor,
original_max_position)
elif scaling_type == "linear":
rotary_emb = LinearScalingRotaryEmbedding(head_size, rotary_dim,
max_position, base,
......
"""A layer that samples the next tokens from the model's outputs."""
import itertools
import warnings
from importlib.util import find_spec
from math import inf
from typing import Dict, List, Optional, Tuple
......@@ -11,6 +13,7 @@ from vllm.triton_utils import HAS_TRITON
if HAS_TRITON:
from vllm.model_executor.layers.ops.sample import sample as sample_triton
import vllm.envs as envs
from vllm.model_executor.sampling_metadata import (SamplingMetadata,
SamplingTensors,
SequenceGroupToSample)
......@@ -19,6 +22,16 @@ from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
PromptLogprobs, SampleLogprobs, SamplerOutput,
SequenceOutput)
if envs.VLLM_USE_FLASHINFER_SAMPLER and find_spec("flashinfer"):
import flashinfer.sampling
# yapf: disable
from flashinfer.sampling import (
top_k_top_p_sampling_from_probs as flashinfer_top_k_top_p_sampling)
# yapf: enable
else:
flashinfer_top_k_top_p_sampling = None
# (num_token_ids, num_parent_ids) per sequence group.
SampleResultType = List[Tuple[List[int], List[int]]]
......@@ -51,6 +64,7 @@ class Sampler(nn.Module):
# containing the sampled token ids and probabilities. This is used by
# speculative decoding.
self.include_gpu_probs_tensor = False
self.should_modify_greedy_probs_inplace = False
def _init_sampling_tensors(
self,
......@@ -117,11 +131,12 @@ class Sampler(nn.Module):
sampling_tensors.frequency_penalties,
sampling_tensors.repetition_penalties)
# Apply temperature scaling.
# Use float32 to apply temperature scaling.
# Use in-place division to avoid creating a new tensor.
logits = logits.to(torch.float)
logits.div_(sampling_tensors.temperatures.unsqueeze(dim=1))
if do_top_p_top_k:
if do_top_p_top_k and flashinfer_top_k_top_p_sampling is None:
logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps,
sampling_tensors.top_ks)
......@@ -177,8 +192,7 @@ class Sampler(nn.Module):
This is used by speculative decoding, which requires that the sampling
method be encoded into the probability distribution.
"""
# Modify greedy probs if include_gpu_probs_tensor is set.
return self.include_gpu_probs_tensor
return self.should_modify_greedy_probs_inplace
def _get_bin_counts_and_mask(
......@@ -475,14 +489,7 @@ def _multinomial(
seq_groups: Optional[List[SequenceGroupToSample]] = None,
) -> torch.Tensor:
if num_samples > 1:
# This is equivalent to torch.repeat_interleaved (which also
# forces a GPU<->CPU sync).
# This allows us to do sampling with replacement by creating
# num_samples copies of each row in the tensor, and then
# batch sampling the resulting tensor.
probs = probs[:, None, :].expand(probs.shape[0], num_samples,
probs.shape[1]).contiguous().view(
-1, probs.shape[1])
probs = probs.repeat_interleave(num_samples, dim=0)
q = torch.empty_like(probs)
if seq_groups is None:
q.exponential_()
......@@ -490,17 +497,57 @@ def _multinomial(
sample_idx = 0
for seq_group in seq_groups:
seq_ids = seq_group.seq_ids
next_sample_idx = sample_idx + len(seq_ids) * num_samples
q[sample_idx:next_sample_idx].exponential_(
generator=seq_group.generator)
sample_idx = next_sample_idx
stride = len(seq_ids) * num_samples
assert seq_group.generator is not None
q[sample_idx:sample_idx +
stride].exponential_(generator=seq_group.generator)
sample_idx += stride
return probs.div_(q).argmax(dim=1).view(-1, num_samples)
def _top_k_top_p_multinomial_with_flashinfer(
        probs: torch.Tensor, top_ks: torch.Tensor, top_ps: torch.Tensor,
        num_samples: int,
        seq_groups: Optional[List[SequenceGroupToSample]]) -> torch.Tensor:
    """Sample token ids with FlashInfer's fused top-k/top-p sampling kernel.

    Args:
        probs: Per-row probabilities; each row is repeated ``num_samples``
            times when ``num_samples > 1``.
        top_ks: Per-row top-k values, repeated alongside ``probs``.
        top_ps: Per-row top-p values, repeated alongside ``probs``.
        num_samples: Number of samples to draw per input row.
        seq_groups: When given, each group's generator is used to fill that
            group's slice of the uniform samples; every group must then have
            a generator. When ``None``, the default RNG is used.

    Returns:
        Sampled token ids viewed as ``(-1, num_samples)``.
    """
    # Number of rows of uniform samples handed to the kernel
    # (presumably one per rejection-sampling round -- confirm against
    # the flashinfer documentation).
    max_top_k_round = 32
    if num_samples > 1:
        probs = probs.repeat_interleave(num_samples, dim=0)
        top_ks = top_ks.repeat_interleave(num_samples)
        top_ps = top_ps.repeat_interleave(num_samples)
    batch_size = probs.shape[0]
    uniform_samples = torch.empty((max_top_k_round, batch_size),
                                  device=probs.device)
    if seq_groups is None:
        uniform_samples.uniform_()
    else:
        # Fill each sequence group's columns from its own generator so that
        # seeded requests stay reproducible.
        sample_idx = 0
        for seq_group in seq_groups:
            seq_ids = seq_group.seq_ids
            stride = len(seq_ids) * num_samples
            assert seq_group.generator is not None
            uniform_samples[:, sample_idx:sample_idx +
                            stride].uniform_(generator=seq_group.generator)
            sample_idx += stride
    batch_next_token_ids, success = flashinfer_top_k_top_p_sampling(
        probs,
        uniform_samples,
        top_ks,
        top_ps,
    )
    if not success.all():
        # Fallback: renormalize with top-k then top-p and sample directly
        # from the renormalized probabilities.
        warnings.warn("FlashInfer rejection sampling failed, fallback.",
                      stacklevel=1)
        probs = flashinfer.sampling.top_k_renorm_prob(probs, top_ks)
        probs = flashinfer.sampling.top_p_renorm_prob(probs, top_ps)
        batch_next_token_ids = flashinfer.sampling.sampling_from_probs(
            probs, uniform_samples[0])
    return batch_next_token_ids.view(-1, num_samples)
def _sample_with_torch(
probs: torch.Tensor,
logprobs: torch.Tensor,
sampling_metadata: SamplingMetadata,
sampling_tensors: SamplingTensors,
include_gpu_probs_tensor: bool,
modify_greedy_probs: bool,
) -> Tuple[SampleResultType, Optional[torch.Tensor]]:
......@@ -563,18 +610,28 @@ def _sample_with_torch(
sampling_params = seq_group.sampling_params
max_best_of_in_batch = max(max_best_of_in_batch,
sampling_params.best_of)
seeded_args = {} if sampling_type == SamplingType.RANDOM else {
"seq_groups": seq_groups,
}
seq_groups_arg = (None if sampling_type == SamplingType.RANDOM else
seq_groups)
if flashinfer_top_k_top_p_sampling is not None:
multinomial_samples[
sampling_type] = _top_k_top_p_multinomial_with_flashinfer(
probs[long_sample_indices],
sampling_tensors.top_ks[long_sample_indices],
sampling_tensors.top_ps[long_sample_indices],
max_best_of_in_batch,
seq_groups_arg,
)
else:
multinomial_samples[sampling_type] = _multinomial(
probs[long_sample_indices], max_best_of_in_batch,
**seeded_args)
probs[long_sample_indices],
max_best_of_in_batch,
seq_groups=seq_groups_arg)
if sampled_token_ids_tensor is not None:
# Store sampled tokens in output tensor.
sampled_token_ids_tensor[
long_sample_indices] = multinomial_samples[sampling_type]
sampled_token_ids_tensor[long_sample_indices] = \
multinomial_samples[sampling_type].to(torch.long)
elif sampling_type == SamplingType.BEAM:
beam_search_logprobs = logprobs[sample_indices]
......@@ -692,9 +749,12 @@ def _sample_with_triton_kernel(
def _sample(
probs: torch.Tensor, logprobs: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_tensors: SamplingTensors,
include_gpu_probs_tensor: bool, modify_greedy_probs: bool
probs: torch.Tensor,
logprobs: torch.Tensor,
sampling_metadata: SamplingMetadata,
sampling_tensors: SamplingTensors,
include_gpu_probs_tensor: bool,
modify_greedy_probs: bool,
) -> Tuple[SampleResultType, Optional[torch.Tensor]]:
"""
Args:
......@@ -712,6 +772,7 @@ def _sample(
probs,
logprobs,
sampling_metadata,
sampling_tensors,
include_gpu_probs_tensor=include_gpu_probs_tensor,
modify_greedy_probs=modify_greedy_probs,
)
......
from abc import abstractmethod
from typing import Dict, Optional
from typing import Dict, Optional, Union
import torch
import torch.jit
......@@ -36,9 +36,12 @@ class SpecDecodeBaseSampler(nn.Module):
self.num_emitted_tokens: Optional[torch.Tensor] = None
self.num_draft_tokens: int = 0
def init_gpu_tensors(self, rank: int) -> None:
def init_gpu_tensors(self, device: Union[int, str]) -> None:
assert self.num_accepted_tokens is None
device = f"cuda:{rank}"
if isinstance(device, int):
device = f"cuda:{device}"
elif not isinstance(device, str):
raise ValueError(f"Device must be int or str, get {type(device)}")
self.num_accepted_tokens = torch.tensor(0,
dtype=torch.long,
device=device)
......
......@@ -3,19 +3,46 @@ from typing import List, Optional, Sequence, Tuple
import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from torch.nn.parameter import Parameter, UninitializedParameter
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
QuantizationConfig, QuantizeMethodBase, method_has_implemented_embedding)
from vllm.model_executor.utils import set_weight_attrs
DEFAULT_VOCAB_PADDING_SIZE = 64
class UnquantizedEmbeddingMethod(QuantizeMethodBase):
    """Unquantized method for embeddings."""

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: List[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Create weights for embedding layer.

        Registers a non-trainable ``weight`` parameter on ``layer`` with
        shape ``(sum(output_partition_sizes), input_size_per_partition)``
        and dtype ``params_dtype``. ``input_size`` and ``output_size`` are
        not used here.
        """
        weight = Parameter(torch.empty(sum(output_partition_sizes),
                                       input_size_per_partition,
                                       dtype=params_dtype),
                           requires_grad=False)
        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
        layer.register_parameter("weight", weight)
        # Caller-supplied attributes are applied last.
        set_weight_attrs(weight, extra_weight_attrs)

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Apply ``layer.weight`` to ``x`` as a linear projection."""
        return F.linear(x, layer.weight, bias)

    def embedding(self, layer: torch.nn.Module,
                  input_: torch.Tensor) -> torch.Tensor:
        """Look up rows of ``layer.weight`` for the indices in ``input_``."""
        return F.embedding(input_, layer.weight)
def pad_vocab_size(vocab_size: int,
pad_to: int = DEFAULT_VOCAB_PADDING_SIZE) -> int:
"""Pad the vocab size to the given value."""
......@@ -199,7 +226,19 @@ class VocabParallelEmbedding(torch.nn.Module):
if quant_config is not None:
linear_method = quant_config.get_quant_method(self, prefix=prefix)
if linear_method is None:
linear_method = UnquantizedLinearMethod()
linear_method = UnquantizedEmbeddingMethod()
# If we are making an embedding layer, then our quantization linear
# method must implement the embedding operation. If we are another
# layer type like ParallelLMHead, this is not important.
# Compare the exact type of the instance: subclasses such as
# ParallelLMHead must not count as embedding layers. (Comparing
# type(self.__class__) would test the metaclass and never match.)
is_embedding_layer = type(self) is VocabParallelEmbedding
linear_method_implements_embedding = method_has_implemented_embedding(
    type(linear_method))
if is_embedding_layer and not linear_method_implements_embedding:
    raise NotImplementedError(
        f"The class {type(linear_method).__name__} must implement "
        "the 'embedding' method, see UnquantizedEmbeddingMethod.")
self.linear_method: QuantizeMethodBase = linear_method
if params_dtype is None:
......@@ -306,6 +345,14 @@ class VocabParallelEmbedding(torch.nn.Module):
output_dim = getattr(param, "output_dim", None)
packed_dim = getattr(param, "packed_dim", None)
# If the parameter is a gguf weight, then load it directly.
if getattr(param, "is_gguf_weight_type", None):
param.data.copy_(loaded_weight)
param.weight_type = loaded_weight.item()
return
elif isinstance(param, UninitializedParameter):
param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)
# If parameter does not have output dim, then it should
# be copied onto all gpus (e.g. g_idx for act_order gptq).
if output_dim is None:
......@@ -344,7 +391,8 @@ class VocabParallelEmbedding(torch.nn.Module):
else:
masked_input = input_
# Get the embeddings.
output_parallel = F.embedding(masked_input.long(), self.weight)
output_parallel = self.linear_method.embedding(self,
masked_input.long())
# Mask the output embedding.
if self.tp_size > 1:
output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0)
......@@ -389,6 +437,7 @@ class ParallelLMHead(VocabParallelEmbedding):
super().__init__(num_embeddings, embedding_dim, params_dtype,
org_num_embeddings, padding_size, quant_config,
prefix)
if bias:
self.bias = Parameter(
torch.empty(self.num_embeddings_per_partition,
......
......@@ -3,8 +3,7 @@ from typing import Optional
from torch import nn
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, MultiModalConfig, ParallelConfig,
SchedulerConfig)
ModelConfig, ParallelConfig, SchedulerConfig)
from vllm.model_executor.model_loader.loader import (BaseModelLoader,
get_model_loader)
from vllm.model_executor.model_loader.utils import (
......@@ -15,13 +14,11 @@ def get_model(*, model_config: ModelConfig, load_config: LoadConfig,
device_config: DeviceConfig, parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
lora_config: Optional[LoRAConfig],
multimodal_config: Optional[MultiModalConfig],
cache_config: CacheConfig) -> nn.Module:
loader = get_model_loader(load_config)
return loader.load_model(model_config=model_config,
device_config=device_config,
lora_config=lora_config,
multimodal_config=multimodal_config,
parallel_config=parallel_config,
scheduler_config=scheduler_config,
cache_config=cache_config)
......
......@@ -10,11 +10,13 @@ from abc import ABC, abstractmethod
from contextlib import contextmanager
from typing import Any, Dict, Generator, List, Optional, Tuple, Type
import gguf
import huggingface_hub
import numpy as np
import torch
from huggingface_hub import HfApi, hf_hub_download
from torch import nn
from transformers import AutoModelForCausalLM, PretrainedConfig
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoadFormat,
LoRAConfig, ModelConfig, MultiModalConfig,
......@@ -31,14 +33,15 @@ from vllm.model_executor.model_loader.utils import (get_model_architecture,
from vllm.model_executor.model_loader.weight_utils import (
download_safetensors_index_file_from_hf, download_weights_from_hf,
filter_duplicate_safetensors_files, filter_files_not_needed_for_inference,
get_quant_config, initialize_dummy_weights, np_cache_weights_iterator,
pt_weights_iterator, safetensors_weights_iterator)
get_gguf_extra_tensor_names, get_quant_config, gguf_quant_weights_iterator,
initialize_dummy_weights, np_cache_weights_iterator, pt_weights_iterator,
safetensors_weights_iterator)
from vllm.model_executor.models.interfaces import (has_inner_state,
supports_lora,
supports_vision)
supports_multimodal)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.utils import is_pin_memory_available, is_tpu
from vllm.utils import is_pin_memory_available
@contextmanager
......@@ -91,12 +94,13 @@ def _get_quantization_config(
"""Get the quantization config."""
if model_config.quantization is not None:
quant_config = get_quant_config(model_config, load_config)
if not current_platform.is_tpu():
capability = current_platform.get_device_capability()
capability = capability[0] * 10 + capability[1]
if capability < quant_config.get_min_capability():
raise ValueError(
f"The quantization method {model_config.quantization} is not "
"supported for the current GPU. "
f"The quantization method {model_config.quantization} "
"is not supported for the current GPU. "
f"Minimum capability: {quant_config.get_min_capability()}. "
f"Current capability: {capability}.")
supported_dtypes = quant_config.get_supported_act_dtypes()
......@@ -130,10 +134,8 @@ def _get_model_initialization_kwargs(
"be added in the future. If this is important to you, "
"please open an issue on github.")
if supports_vision(model_class):
if multimodal_config is None:
raise ValueError("Provide vision related configurations "
"through LLM entrypoint or engine arguments.")
if supports_multimodal(model_class):
assert multimodal_config is not None
extra_kwargs["multimodal_config"] = multimodal_config
......@@ -143,23 +145,40 @@ def _get_model_initialization_kwargs(
return extra_kwargs
def build_model(model_class: Type[nn.Module], hf_config: PretrainedConfig,
                cache_config: Optional[CacheConfig],
                quant_config: Optional[QuantizationConfig], *,
                lora_config: Optional[LoRAConfig],
                multimodal_config: Optional[MultiModalConfig],
                scheduler_config: Optional[SchedulerConfig]) -> nn.Module:
    """Instantiate ``model_class`` from the given configurations.

    ``lora_config``, ``multimodal_config`` and ``scheduler_config`` are
    keyword-only. They are not passed to the constructor directly; they go
    through ``_get_model_initialization_kwargs``, which produces the extra
    keyword arguments for this model class.

    Returns:
        The result of calling ``model_class`` with ``config``,
        ``cache_config``, ``quant_config`` and the extra keyword arguments.
    """
    extra_kwargs = _get_model_initialization_kwargs(model_class, lora_config,
                                                    multimodal_config,
                                                    scheduler_config)

    return model_class(config=hf_config,
                       cache_config=cache_config,
                       quant_config=quant_config,
                       **extra_kwargs)
def _initialize_model(
model_config: ModelConfig,
load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
multimodal_config: Optional[MultiModalConfig],
cache_config: CacheConfig,
scheduler_config: Optional[SchedulerConfig] = None) -> nn.Module:
"""Initialize a model with the given configurations."""
model_class = get_model_architecture(model_config)[0]
quant_config = _get_quantization_config(model_config, load_config)
model_class, _ = get_model_architecture(model_config)
return model_class(config=model_config.hf_config,
return build_model(
model_class,
model_config.hf_config,
cache_config=cache_config,
quant_config=quant_config,
**_get_model_initialization_kwargs(
model_class, lora_config, multimodal_config,
scheduler_config))
quant_config=_get_quantization_config(model_config, load_config),
lora_config=lora_config,
multimodal_config=model_config.multimodal_config,
scheduler_config=scheduler_config,
)
class BaseModelLoader(ABC):
......@@ -172,7 +191,6 @@ class BaseModelLoader(ABC):
def load_model(self, *, model_config: ModelConfig,
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
multimodal_config: Optional[MultiModalConfig],
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
cache_config: CacheConfig) -> nn.Module:
......@@ -301,7 +319,7 @@ class DefaultModelLoader(BaseModelLoader):
else:
weights_iterator = pt_weights_iterator(hf_weights_files)
if is_tpu():
if current_platform.is_tpu():
# In PyTorch XLA, we should call `xm.mark_step` frequently so that
# not too many ops are accumulated in the XLA program.
import torch_xla.core.xla_model as xm
......@@ -317,7 +335,6 @@ class DefaultModelLoader(BaseModelLoader):
def load_model(self, *, model_config: ModelConfig,
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
multimodal_config: Optional[MultiModalConfig],
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
cache_config: CacheConfig) -> nn.Module:
......@@ -325,8 +342,8 @@ class DefaultModelLoader(BaseModelLoader):
with set_default_torch_dtype(model_config.dtype):
with target_device:
model = _initialize_model(model_config, self.load_config,
lora_config, multimodal_config,
cache_config, scheduler_config)
lora_config, cache_config,
scheduler_config)
model.load_weights(
self._get_weights_iterator(model_config.model,
model_config.revision,
......@@ -360,15 +377,14 @@ class DummyModelLoader(BaseModelLoader):
def load_model(self, *, model_config: ModelConfig,
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
multimodal_config: Optional[MultiModalConfig],
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
cache_config: CacheConfig) -> nn.Module:
with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device):
model = _initialize_model(model_config, self.load_config,
lora_config, multimodal_config,
cache_config, scheduler_config)
lora_config, cache_config,
scheduler_config)
# NOTE(woosuk): For accurate performance evaluation, we assign
# random values to the weights.
initialize_dummy_weights(model)
......@@ -401,7 +417,6 @@ class TensorizerLoader(BaseModelLoader):
model_config: ModelConfig,
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
multimodal_config: Optional[MultiModalConfig],
cache_config: CacheConfig,
) -> nn.Module:
"""Load a serialized model with tensorizer to the CPU.
......@@ -414,8 +429,7 @@ class TensorizerLoader(BaseModelLoader):
with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device):
model = _initialize_model(model_config, self.load_config,
lora_config, multimodal_config,
cache_config)
lora_config, cache_config)
model.load_weights(self._get_weights_iterator())
return model.eval()
......@@ -425,7 +439,6 @@ class TensorizerLoader(BaseModelLoader):
model_config: ModelConfig,
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
multimodal_config: Optional[MultiModalConfig],
cache_config: CacheConfig,
) -> nn.Module:
"""Load a serialized model with tensorizer.
......@@ -439,7 +452,7 @@ class TensorizerLoader(BaseModelLoader):
quant_config = _get_quantization_config(
model_config, self.load_config)
extra_kwargs = _get_model_initialization_kwargs(
model_class, lora_config, multimodal_config)
model_class, lora_config, model_config.multimodal_config)
extra_kwargs["quant_config"] = quant_config
extra_kwargs["cache_config"] = cache_config
......@@ -454,7 +467,6 @@ class TensorizerLoader(BaseModelLoader):
def load_model(self, *, model_config: ModelConfig,
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
multimodal_config: Optional[MultiModalConfig],
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
cache_config: CacheConfig) -> nn.Module:
......@@ -468,11 +480,9 @@ class TensorizerLoader(BaseModelLoader):
if is_vllm_tensorized(self.tensorizer_config):
return self._load_model_serialized(model_config, device_config,
lora_config, multimodal_config,
cache_config)
lora_config, cache_config)
return self._load_model_serialized_cpu(model_config, device_config,
lora_config, multimodal_config,
cache_config)
lora_config, cache_config)
@staticmethod
def save_model(
......@@ -558,7 +568,6 @@ class ShardedStateLoader(BaseModelLoader):
def load_model(self, *, model_config: ModelConfig,
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
multimodal_config: Optional[MultiModalConfig],
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
cache_config: CacheConfig) -> nn.Module:
......@@ -572,8 +581,11 @@ class ShardedStateLoader(BaseModelLoader):
with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device):
model = _initialize_model(model_config, self.load_config,
lora_config, multimodal_config,
cache_config)
lora_config, cache_config)
for _, module in model.named_modules():
quant_method = getattr(module, "quant_method", None)
if quant_method is not None:
quant_method.process_weights_after_loading(module)
rank = get_tensor_model_parallel_rank()
pattern = os.path.join(
local_model_path,
......@@ -864,11 +876,11 @@ class BitsAndBytesModelLoader(BaseModelLoader):
if not hasattr(model, 'load_weights'):
raise AttributeError(
"The required method 'load_weights' is not defined in class"
f" {type(self).__name__}.")
f" {type(model).__name__}.")
if not hasattr(model, 'bitsandbytes_stacked_params_mapping'):
raise AttributeError(
f"Model {type(self).__name__} does not support BitsAndBytes "
f"Model {type(model).__name__} does not support BitsAndBytes "
"quantization yet.")
logger.info("Loading weights with BitsAndBytes quantization. "
......@@ -936,21 +948,101 @@ class BitsAndBytesModelLoader(BaseModelLoader):
def load_model(self, *, model_config: ModelConfig,
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
multimodal_config: Optional[MultiModalConfig],
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
cache_config: CacheConfig) -> nn.Module:
with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device):
model = _initialize_model(model_config, self.load_config,
lora_config, multimodal_config,
cache_config)
lora_config, cache_config)
self._load_weights(model_config, model)
return model.eval()
class GGUFModelLoader(BaseModelLoader):
    """
    Loader for checkpoints stored as a single GGUF file.

    The model path must point at an existing file on disk; tensor names in
    the file are translated to their HF checkpoint names before the weights
    are handed to the model's `load_weights`.
    """

    def __init__(self, load_config: LoadConfig):
        super().__init__(load_config)
        # This loader has no tunables, so any extra config is rejected.
        extra_config = load_config.model_loader_extra_config
        if extra_config:
            raise ValueError(f"Model loader extra config is not supported for "
                             f"load format {load_config.load_format}")

    def _prepare_weights(self, model_name_or_path: str):
        """Return the path unchanged if it is an existing file.

        Raises:
            ValueError: if `model_name_or_path` is not a file.
        """
        if not os.path.isfile(model_name_or_path):
            raise ValueError(f"{model_name_or_path} is not a file.")
        return model_name_or_path

    def _get_gguf_weights_map(self, model_config: ModelConfig):
        """
        Build a mapping from GGUF tensor names to HF checkpoint names.

        GGUF names tensors converted from an HF checkpoint as
        `blk.N.BB.weight` and `blk.N.BB.bias`, where N is the block number
        of a layer and BB identifies the attention/mlp component.
        See "Standardized tensor names" in
        https://github.com/ggerganov/ggml/blob/master/docs/gguf.md for details.

        Raises:
            RuntimeError: if no gguf architecture matches the HF model type.
        """
        hf_config = model_config.hf_config

        # hack: ggufs have a different name than transformers
        model_type = hf_config.model_type
        if model_type == "cohere":
            model_type = "command-r"

        # First architecture whose registered name equals the model type.
        arch = next(
            (arch_key
             for arch_key, arch_name in gguf.MODEL_ARCH_NAMES.items()
             if arch_name == model_type),
            None,
        )
        if arch is None:
            raise RuntimeError(f"Unknown gguf model_type: {model_type}")

        name_map = gguf.get_tensor_name_map(arch, hf_config.num_hidden_layers)

        # Instantiate on the "meta" device: only the parameter names are
        # needed here.
        with torch.device("meta"):
            dummy_model = AutoModelForCausalLM.from_config(hf_config)
        hf_param_names = dummy_model.state_dict()

        gguf_to_hf_name_map = {}
        for hf_name in hf_param_names:
            # Split off the trailing component (e.g. "weight" / "bias") and
            # translate only the module path.
            module_path, suffix = hf_name.rsplit(".", 1)
            gguf_module = name_map.get_name(module_path)
            gguf_to_hf_name_map[f"{gguf_module}.{suffix}"] = hf_name
        return gguf_to_hf_name_map

    def _get_weights_iterator(
        self, model_name_or_path: str, gguf_to_hf_name_map: Dict[str, str]
    ) -> Generator[Tuple[str, torch.Tensor], None, None]:
        """Delegate to `gguf_quant_weights_iterator` for the given file."""
        return gguf_quant_weights_iterator(model_name_or_path,
                                           gguf_to_hf_name_map)

    def load_model(self, *, model_config: ModelConfig,
                   device_config: DeviceConfig,
                   lora_config: Optional[LoRAConfig],
                   parallel_config: ParallelConfig,
                   scheduler_config: SchedulerConfig,
                   cache_config: CacheConfig) -> nn.Module:
        """Initialize the model and load its weights from the GGUF file."""
        local_model_path = self._prepare_weights(model_config.model)
        gguf_weights_map = self._get_gguf_weights_map(model_config)

        # we can only know if tie word embeddings after mapping weights
        absent_tensors = get_gguf_extra_tensor_names(local_model_path,
                                                     gguf_weights_map)
        if "lm_head.weight" in absent_tensors:
            model_config.hf_config.update({"tie_word_embeddings": True})

        with set_default_torch_dtype(model_config.dtype):
            with torch.device(device_config.device):
                model = _initialize_model(model_config, self.load_config,
                                          lora_config, cache_config)
            weights = self._get_weights_iterator(local_model_path,
                                                 gguf_weights_map)
            model.load_weights(weights)
        return model
def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
"""Get a model loader based on the load format."""
......@@ -969,4 +1061,7 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
if load_config.load_format == LoadFormat.BITSANDBYTES:
return BitsAndBytesModelLoader(load_config)
if load_config.load_format == LoadFormat.GGUF:
return GGUFModelLoader(load_config)
return DefaultModelLoader(load_config)
......@@ -47,13 +47,7 @@ def get_model_architecture(
and "MixtralForCausalLM" in architectures):
architectures = ["QuantMixtralForCausalLM"]
for arch in architectures:
model_cls = ModelRegistry.load_model_cls(arch)
if model_cls is not None:
return (model_cls, arch)
raise ValueError(
f"Model architectures {architectures} are not supported for now. "
f"Supported architectures: {ModelRegistry.get_supported_archs()}")
return ModelRegistry.resolve_model_cls(architectures)
def get_architecture_class_name(model_config: ModelConfig) -> str:
......
......@@ -6,9 +6,10 @@ import json
import os
import tempfile
from collections import defaultdict
from typing import Any, Generator, Iterable, List, Optional, Tuple, Union
from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, Union
import filelock
import gguf
import huggingface_hub.constants
import numpy as np
import torch
......@@ -18,6 +19,7 @@ from tqdm.auto import tqdm
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
from vllm.config import LoadConfig, ModelConfig
from vllm.distributed import get_tensor_model_parallel_rank
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import (QuantizationConfig,
get_quantization_config)
......@@ -121,9 +123,18 @@ def get_quant_config(model_config: ModelConfig,
load_config: LoadConfig) -> QuantizationConfig:
quant_cls = get_quantization_config(model_config.quantization)
# GGUF doesn't have config file
if model_config.quantization == "gguf":
return quant_cls.from_config({})
# Read the quantization config from the HF model config, if available.
hf_quant_config = getattr(model_config.hf_config, "quantization_config",
None)
# some vision model may keep quantization_config in their text_config
hf_text_config = getattr(model_config.hf_config, "text_config", None)
if hf_quant_config is None and hf_text_config is not None:
hf_quant_config = getattr(hf_text_config, "quantization_config", None)
if hf_quant_config is None:
# compressed-tensors uses a compressions_config
hf_quant_config = getattr(model_config.hf_config, "compression_config",
......@@ -409,6 +420,47 @@ def pt_weights_iterator(
torch.cuda.empty_cache()
def get_gguf_extra_tensor_names(
        gguf_file: str, gguf_to_hf_name_map: Dict[str, str]) -> List[str]:
    """Return HF names of mapped tensors that are absent from the GGUF file.

    Args:
        gguf_file: path handed to `gguf.GGUFReader`.
        gguf_to_hf_name_map: expected GGUF tensor name -> HF parameter name.

    Returns:
        The HF names whose GGUF key does not appear among the file's
        tensors, in no particular order.
    """
    names_in_file = {
        tensor.name
        for tensor in gguf.GGUFReader(gguf_file).tensors
    }
    absent_keys = set(gguf_to_hf_name_map.keys()) - names_in_file
    return [gguf_to_hf_name_map[gguf_key] for gguf_key in absent_keys]
def gguf_quant_weights_iterator(
    gguf_file: str, gguf_to_hf_name_map: Dict[str, str]
) -> Generator[Tuple[str, torch.Tensor], None, None]:
    """
    Yield `(name, tensor)` pairs for the mapped tensors of a GGUF file.

    The file's tensors are walked twice. The first pass yields, for every
    mapped tensor whose type is not F32, its tensor type under the HF name
    with "weight" replaced by "qweight_type". The second pass yields the
    tensor data of every mapped tensor; for non-F32 tensors "weight" in the
    name is replaced by "qweight". Tensors whose name is not a key of
    `gguf_to_hf_name_map` are skipped in both passes.
    """
    reader = gguf.GGUFReader(gguf_file)

    # Pass 1: type markers for all non-F32 tensors come out first.
    for entry in reader.tensors:
        if entry.name not in gguf_to_hf_name_map:
            continue
        tensor_type = entry.tensor_type
        hf_name = gguf_to_hf_name_map[entry.name]
        if tensor_type.name == "F32":
            continue
        yield (hf_name.replace("weight", "qweight_type"),
               torch.tensor(tensor_type))

    # Pass 2: the actual tensor payloads.
    for entry in reader.tensors:
        if entry.name not in gguf_to_hf_name_map:
            continue
        payload = entry.data
        hf_name = gguf_to_hf_name_map[entry.name]
        if entry.tensor_type.name != "F32":
            hf_name = hf_name.replace("weight", "qweight")
        yield hf_name, torch.tensor(payload)
def kv_cache_scales_loader(
filename: str, tp_rank: int, tp_size: int, num_hidden_layers: int,
model_type: Optional[str]) -> Iterable[Tuple[int, float]]:
......@@ -467,8 +519,36 @@ def convert_pyslice_to_tensor(x: Any) -> torch.Tensor:
def default_weight_loader(param: torch.Tensor,
                          loaded_weight: torch.Tensor) -> None:
    """Default weight loader: write `loaded_weight` into `param` in place.

    When both tensors hold exactly one element the value is broadcast with
    `fill_`, so a parameter of shape `[1]` can be loaded from a 0-dim
    checkpoint value (and vice versa). Otherwise the shapes must match
    exactly and the data is copied.

    Raises:
        AssertionError: if the tensors are not both single-element and
            their sizes differ.
    """
    # NOTE: there is deliberately no up-front size assertion here; an
    # unconditional check would reject the single-element case before the
    # broadcast branch below could handle it.
    try:
        if param.numel() == 1 and loaded_weight.numel() == 1:
            # Sometimes scalar values aren't considered tensors with shapes
            # so if both param and loaded_weight are a scalar,
            # "broadcast" instead of copy
            param.data.fill_(loaded_weight.item())
        else:
            assert param.size() == loaded_weight.size(), (
                f"Attempted to load weight ({loaded_weight.size()}) "
                f"into parameter ({param.size()})")
            param.data.copy_(loaded_weight)
    except Exception:
        # NOTE: This exception is added for the purpose of setting breakpoint to
        # debug weight loading issues.
        raise
def row_parallel_weight_loader(param: torch.Tensor,
                               loaded_weight: torch.Tensor) -> None:
    """Load this rank's slice of a row-parallel weight into `param`.

    For parameters that are not 1-D, `loaded_weight` is narrowed along
    dim 0 to the window of `param.data.shape[0]` rows starting at
    `tp_rank * param.data.shape[0]`. 1-D parameters are passed through
    unsliced. The (possibly narrowed) tensor is then handed to
    `default_weight_loader`.
    """
    tp_rank = get_tensor_model_parallel_rank()
    if param.dim() != 1:
        rows_per_rank = param.data.shape[0]
        first_row = tp_rank * rows_per_rank
        loaded_weight = loaded_weight.narrow(0, first_row, rows_per_rank)
    return default_weight_loader(param, loaded_weight)
def initialize_dummy_weights(
......
import functools
import importlib
from typing import Dict, List, Optional, Type
from typing import Dict, List, Optional, Tuple, Type
import torch.nn as nn
......@@ -9,17 +9,12 @@ from vllm.utils import is_hip
logger = init_logger(__name__)
# Architecture -> (module, class).
_GENERATION_MODELS = {
"AquilaModel": ("llama", "LlamaForCausalLM"),
"AquilaForCausalLM": ("llama", "LlamaForCausalLM"), # AquilaChat2
"BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"), # baichuan-7b
"BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"), # baichuan-13b
"BloomForCausalLM": ("bloom", "BloomForCausalLM"),
"Blip2ForConditionalGeneration":
("blip2", "Blip2ForConditionalGeneration"),
"ChameleonForConditionalGeneration":
("chameleon", "ChameleonForConditionalGeneration"),
"ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
"ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
"CohereForCausalLM": ("commandr", "CohereForCausalLM"),
......@@ -28,7 +23,6 @@ _GENERATION_MODELS = {
"DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
"DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
"FalconForCausalLM": ("falcon", "FalconForCausalLM"),
"FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
"GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
"Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
"GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
......@@ -37,13 +31,8 @@ _GENERATION_MODELS = {
"GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
"InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
"InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
"InternVLChatModel": ("internvl", "InternVLChatModel"),
"JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
"LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
"LlavaForConditionalGeneration":
("llava", "LlavaForConditionalGeneration"),
"LlavaNextForConditionalGeneration":
("llava_next", "LlavaNextForConditionalGeneration"),
# For decapoda-research/llama-*
"LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
"MistralForCausalLM": ("llama", "LlamaForCausalLM"),
......@@ -53,17 +42,13 @@ _GENERATION_MODELS = {
"MptForCausalLM": ("mpt", "MPTForCausalLM"),
"MPTForCausalLM": ("mpt", "MPTForCausalLM"),
"MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
"MiniCPMV": ("minicpmv", "MiniCPMV"),
"NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
"OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
"OPTForCausalLM": ("opt", "OPTForCausalLM"),
"OrionForCausalLM": ("orion", "OrionForCausalLM"),
"PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
"PaliGemmaForConditionalGeneration": ("paligemma",
"PaliGemmaForConditionalGeneration"),
"PhiForCausalLM": ("phi", "PhiForCausalLM"),
"Phi3ForCausalLM": ("llama", "LlamaForCausalLM"),
"Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
"QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
"Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
"Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
......@@ -75,15 +60,43 @@ _GENERATION_MODELS = {
"XverseForCausalLM": ("xverse", "XverseForCausalLM"),
"Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"),
"MedusaModel": ("medusa", "Medusa"),
"EAGLEModel": ("eagle", "EAGLE"),
"MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
"JambaForCausalLM": ("jamba", "JambaForCausalLM")
"JambaForCausalLM": ("jamba", "JambaForCausalLM"),
}
# Embedding architectures, in the same architecture -> (module, class)
# layout as _GENERATION_MODELS. Membership in this dict is what
# ModelRegistry.is_embedding_model checks.
_EMBEDDING_MODELS = {
"MistralModel": ("llama_embedding", "LlamaEmbeddingModel"),
}
# NOTE(review): this two-way merge is reassigned by the four-way `_MODELS`
# merge further down, which is the value that remains bound.
_MODELS = {**_GENERATION_MODELS, **_EMBEDDING_MODELS}
# Multimodal architectures -> (module, class). Membership in this dict is
# what ModelRegistry.is_multimodal_model checks.
_MULTIMODAL_MODELS = {
"Blip2ForConditionalGeneration":
("blip2", "Blip2ForConditionalGeneration"),
"ChameleonForConditionalGeneration":
("chameleon", "ChameleonForConditionalGeneration"),
"FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
"InternVLChatModel": ("internvl", "InternVLChatModel"),
"LlavaForConditionalGeneration":
("llava", "LlavaForConditionalGeneration"),
"LlavaNextForConditionalGeneration":
("llava_next", "LlavaNextForConditionalGeneration"),
"MiniCPMV": ("minicpmv", "MiniCPMV"),
"PaliGemmaForConditionalGeneration": ("paligemma",
"PaliGemmaForConditionalGeneration"),
"Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
"UltravoxModel": ("ultravox", "UltravoxModel"),
}
# Conditional-generation architectures -> (module, class); both keys
# resolve to the same ("bart", "BartForConditionalGeneration") entry.
_CONDITIONAL_GENERATION_MODELS = {
"BartModel": ("bart", "BartForConditionalGeneration"),
"BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"),
}
# Combined registry of all four tables above. With dict unpacking, a key
# present in more than one table takes the value from the later table.
_MODELS = {
**_GENERATION_MODELS,
**_EMBEDDING_MODELS,
**_MULTIMODAL_MODELS,
**_CONDITIONAL_GENERATION_MODELS,
}
# Architecture -> type.
# out of tree models
......@@ -126,7 +139,7 @@ class ModelRegistry:
return getattr(module, model_cls_name, None)
@staticmethod
def load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]:
def _try_load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]:
if model_arch in _OOT_MODELS:
return _OOT_MODELS[model_arch]
if model_arch not in _MODELS:
......@@ -143,9 +156,21 @@ class ModelRegistry:
return ModelRegistry._get_model(model_arch)
@staticmethod
def resolve_model_cls(
        architectures: List[str]) -> Tuple[Type[nn.Module], str]:
    """Return the first loadable architecture as `(model_cls, arch)`.

    Architectures are tried in the order given; the first one for which
    `ModelRegistry._try_load_model_cls` returns a class wins.

    Raises:
        ValueError: if none of the architectures yields a model class.
    """
    for candidate in architectures:
        loaded_cls = ModelRegistry._try_load_model_cls(candidate)
        if loaded_cls is None:
            continue
        return (loaded_cls, candidate)

    supported = ModelRegistry.get_supported_archs()
    raise ValueError(
        f"Model architectures {architectures} are not supported for now. "
        f"Supported architectures: {supported}")
@staticmethod
def get_supported_archs() -> List[str]:
    """List every architecture name the registry knows about.

    Returns the keys of the built-in `_MODELS` table followed by the keys
    of `_OOT_MODELS` (out-of-tree models), so architectures that are only
    in `_OOT_MODELS` are reported as supported too.
    """
    # A single return: an earlier `return list(_MODELS.keys())` would make
    # the out-of-tree part unreachable.
    return list(_MODELS.keys()) + list(_OOT_MODELS.keys())
@staticmethod
def register_model(model_arch: str, model_cls: Type[nn.Module]):
......@@ -161,6 +186,15 @@ class ModelRegistry:
def is_embedding_model(model_arch: str) -> bool:
return model_arch in _EMBEDDING_MODELS
@staticmethod
def is_multimodal_model(model_arch: str) -> bool:
    """Return whether `model_arch` is a key of `_MULTIMODAL_MODELS`."""
    # TODO: find a way to avoid initializing CUDA prematurely so that the
    # model class can be loaded via `ModelRegistry._try_load_model_cls` and
    # checked with `supports_multimodal` from
    # `vllm.model_executor.models.interfaces` instead of this static table.
    listed_as_multimodal = model_arch in _MULTIMODAL_MODELS
    return listed_as_multimodal
__all__ = [
"ModelRegistry",
......
......@@ -414,6 +414,8 @@ class ArcticForCausalLM(nn.Module):
config.hidden_size,
quant_config=quant_config,
)
if self.config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight
self.num_experts = config.num_local_experts
self.num_experts_per_tok = config.num_experts_per_tok
self.unpadded_vocab_size = config.vocab_size
......@@ -433,8 +435,11 @@ class ArcticForCausalLM(nn.Module):
attn_metadata)
return hidden_states
def compute_logits(
    self,
    hidden_states: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
    """Compute logits by calling `self.logits_processor`.

    Args:
        hidden_states: forwarded to the logits processor as its second
            argument.
        sampling_metadata: forwarded to the logits processor as its third
            argument.

    Returns:
        Whatever `self.logits_processor(self.lm_head, hidden_states,
        sampling_metadata)` returns; annotated as optional.
    """
    logits = self.logits_processor(self.lm_head, hidden_states,
                                   sampling_metadata)
    return logits
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment