Commit f48954a4 authored by zhuwenwen's avatar zhuwenwen
Browse files

merge v0.5.0

parents 1dba29d3 8f89d720
......@@ -17,6 +17,24 @@ ACTIVATION_SCHEMES = ["static", "dynamic"]
logger = init_logger(__name__)
def cutlass_fp8_supported() -> bool:
capability = torch.cuda.get_device_capability()
capability = capability[0] * 10 + capability[1]
major, minor = torch.version.cuda.split(".")
version = int(major) * 10 + int(minor)
# CUTLASS FP8 kernels need at least
# CUDA 12.0 on SM90 systems (Hopper)
# CUDA 12.4 on SM89 systems (Lovelace)
gpu_is_supported = False
if capability >= 90:
gpu_is_supported = version > 120
elif capability >= 89:
gpu_is_supported = version > 124
return gpu_is_supported
class Fp8Config(QuantizationConfig):
"""Config class for FP8."""
......@@ -85,13 +103,14 @@ class Fp8LinearMethod(LinearMethodBase):
1. Only support per-tensor quantization due to torch._scaled_mm support.
2. Only support float8_e4m3fn data type due to the limitation of
torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856)
Args:
quant_config: The quantization config.
"""
def __init__(self, quant_config: Fp8Config):
self.quant_config = quant_config
self.cutlass_fp8_supported = cutlass_fp8_supported()
def _create_scale_param(
self,
......@@ -152,10 +171,10 @@ class Fp8LinearMethod(LinearMethodBase):
output_partition_sizes=output_partition_sizes,
**extra_weight_attrs)
# ACTIVATION SCALE
# INPUT ACTIVATION SCALE
if self.quant_config.activation_scheme == "static":
self._create_scale_param(
scale_name="act_scale",
scale_name="input_scale",
layer=layer,
output_partition_sizes=output_partition_sizes,
**extra_weight_attrs)
......@@ -188,7 +207,7 @@ class Fp8LinearMethod(LinearMethodBase):
layer.weight = Parameter(qweight.t(), requires_grad=False)
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
layer.logical_widths = None
layer.act_scale = None
layer.input_scale = None
return
# If checkpoint is fp8, requantize the separately quantized logical
......@@ -213,18 +232,18 @@ class Fp8LinearMethod(LinearMethodBase):
weight = layer.weight
layer.weight = Parameter(weight.t(), requires_grad=False)
# ACT_SCALE
# INPUT ACTIVATION SCALE
# Dynamic: set to None (required input to ops.scaled_fp8_quant).
# Static: set to max of the act_scales (since they are equal).
# Static: set to max of the input_scales (since they are equal).
if self.quant_config.activation_scheme == "dynamic":
layer.act_scale = None
layer.input_scale = None
elif self.quant_config.activation_scheme == "static":
if not all_close_1d(layer.act_scale):
if not all_close_1d(layer.input_scale):
raise ValueError(
"All the act_scales for the logical weights of a layer "
f"must be equal. But got {layer.act_scale}")
layer.act_scale = Parameter(layer.act_scale.max(),
requires_grad=False)
"All the input_scales for the logical weights of a "
f"layer must be equal. But got {layer.input_scale}")
layer.input_scale = Parameter(layer.input_scale.max(),
requires_grad=False)
else:
raise ValueError(
f"Unknown scheme {self.quant_config.activation_scheme}")
......@@ -233,25 +252,40 @@ class Fp8LinearMethod(LinearMethodBase):
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
# ops.scaled_fp8_quant supports both dynamic and static quant.
# If dynamic, layer.act_scale is None and x_scale computed from x.
# If static, layer.act_scale is scalar and x_scale set to act_scale.
qinput, x_scale = ops.scaled_fp8_quant(x,
layer.act_scale,
batch_dim_padding=17)
# Fused GEMM_DQ -- note we padded the input above because
# torch._scaled_mm is more performant for matrices with
# batch dimension > 16. Note that this could change
# in the future.
output, _ = torch._scaled_mm(
qinput,
layer.weight,
out_dtype=x.dtype,
scale_a=x_scale,
scale_b=layer.weight_scale,
bias=bias,
)
# If dynamic, layer.input_scale is None and x_scale computed from x.
# If static, layer.input_scale is scalar and x_scale is input_scale.
if bias is None and self.cutlass_fp8_supported:
qinput, x_scale = ops.scaled_fp8_quant(x, layer.input_scale)
# Fused GEMM_DQ
output = ops.cutlass_scaled_mm_dq(
qinput,
layer.weight,
out_dtype=x.dtype,
scale_a=x_scale,
scale_b=layer.weight_scale,
)
else:
qinput, x_scale = ops.scaled_fp8_quant(x,
layer.input_scale,
batch_dim_padding=17)
# Fused GEMM_DQ -- note we padded the input above because
# torch._scaled_mm is more performant for matrices with
# batch dimension > 16. Note that this could change
# in the future.
output, _ = torch._scaled_mm(
qinput,
layer.weight,
out_dtype=x.dtype,
scale_a=x_scale,
scale_b=layer.weight_scale,
bias=bias,
)
return torch.narrow(output, 0, 0, x.shape[0])
......@@ -264,8 +298,8 @@ class Fp8KVCacheMethod(QuantizeMethodBase):
self.quant_config = quant_config
def create_weights(self, layer: torch.nn.Module):
"""Create "weight" (aka kv_scale) for an attention layer.
"""Create "weight" (aka kv_scale) for an attention layer.
Args:
layer: The layer that is using the QuantizeMethodBase factory.
"""
......@@ -300,14 +334,15 @@ def all_close_1d(x: torch.Tensor) -> bool:
def per_tensor_quantize(tensor: torch.Tensor,
inv_scale: float) -> torch.Tensor:
inv_scale: Union[float, torch.Tensor]) -> torch.Tensor:
finfo = torch.finfo(torch.float8_e4m3fn)
qweight = (tensor / inv_scale).clamp(min=finfo.min, max=finfo.max)
return qweight.to(torch.float8_e4m3fn)
def per_tensor_dequantize(tensor: torch.Tensor,
inv_scale: float) -> torch.Tensor:
def per_tensor_dequantize(
tensor: torch.Tensor, inv_scale: Union[float,
torch.Tensor]) -> torch.Tensor:
fake_qweight = tensor.to(torch.float16)
dq_weight = fake_qweight * inv_scale
return dq_weight
......@@ -306,8 +306,10 @@ class RejectionSampler(nn.Module):
# Fill in the first k columns of the output tensor using masks and data
# tensors.
output[:, :k] = torch.where(accepted_mask, draft_token_ids,
-torch.ones_like(draft_token_ids))
torch.where(accepted_mask,
draft_token_ids,
-torch.ones_like(draft_token_ids),
out=output)
# Fill the last column.
# We check output directly as accepted may have True values inconsistent
......
......@@ -27,7 +27,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
from vllm import _custom_ops as ops
from vllm.model_executor.custom_op import CustomOp
def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
......@@ -43,7 +43,7 @@ def _rotate_gptj(x: torch.Tensor) -> torch.Tensor:
return x.flatten(-2)
class RotaryEmbedding(nn.Module):
class RotaryEmbedding(CustomOp):
"""Original rotary positional embedding."""
def __init__(
......@@ -93,7 +93,7 @@ class RotaryEmbedding(nn.Module):
cache = torch.cat((cos, sin), dim=-1)
return cache
def _forward(
def forward_native(
self,
positions: torch.Tensor,
query: torch.Tensor,
......@@ -138,13 +138,15 @@ class RotaryEmbedding(nn.Module):
key = key.flatten(-2)
return query, key
def forward(
def forward_cuda(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
from vllm import _custom_ops as ops
self.cos_sin_cache = self.cos_sin_cache.to(positions.device,
dtype=query.dtype)
# ops.rotary_embedding()/batched_rotary_embedding()
......
from typing import Optional, Sequence
from dataclasses import dataclass
from typing import List, Optional, Sequence, Tuple
import torch
import torch.nn.functional as F
......@@ -18,18 +19,107 @@ def pad_vocab_size(vocab_size: int,
return ((vocab_size + pad_to - 1) // pad_to) * pad_to
def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size: int,
rank: int) -> Sequence[int]:
def vocab_range_from_per_partition_vocab_size(
per_partition_vocab_size: int,
rank: int,
offset: int = 0) -> Sequence[int]:
index_f = rank * per_partition_vocab_size
index_l = index_f + per_partition_vocab_size
return index_f, index_l
return index_f + offset, index_l + offset
def vocab_range_from_global_vocab_size(global_vocab_size: int, rank: int,
world_size: int) -> Sequence[int]:
def vocab_range_from_global_vocab_size(global_vocab_size: int,
rank: int,
world_size: int,
offset: int = 0) -> Sequence[int]:
per_partition_vocab_size = divide(global_vocab_size, world_size)
return vocab_range_from_per_partition_vocab_size(per_partition_vocab_size,
rank)
rank,
offset=offset)
@dataclass
class VocabParallelEmbeddingShardIndices:
"""Indices for a shard of a vocab parallel embedding."""
padded_org_vocab_start_index: int
padded_org_vocab_end_index: int
padded_added_vocab_start_index: int
padded_added_vocab_end_index: int
org_vocab_start_index: int
org_vocab_end_index: int
added_vocab_start_index: int
added_vocab_end_index: int
@property
def num_org_elements(self) -> int:
return self.org_vocab_end_index - self.org_vocab_start_index
@property
def num_added_elements(self) -> int:
return self.added_vocab_end_index - self.added_vocab_start_index
@property
def num_org_elements_padded(self) -> int:
return (self.padded_org_vocab_end_index -
self.padded_org_vocab_start_index)
@property
def num_added_elements_padded(self) -> int:
return (self.padded_added_vocab_end_index -
self.padded_added_vocab_start_index)
@property
def num_org_vocab_padding(self) -> int:
return self.num_org_elements_padded - self.num_org_elements
@property
def num_added_vocab_padding(self) -> int:
return self.num_added_elements_padded - self.num_added_elements
@property
def num_elements_padded(self) -> int:
return self.num_org_elements_padded + self.num_added_elements_padded
def __post_init__(self):
# sanity checks
assert (self.padded_org_vocab_start_index <=
self.padded_org_vocab_end_index)
assert (self.padded_added_vocab_start_index <=
self.padded_added_vocab_end_index)
assert self.org_vocab_start_index <= self.org_vocab_end_index
assert self.added_vocab_start_index <= self.added_vocab_end_index
assert self.org_vocab_start_index <= self.padded_org_vocab_start_index
assert (self.added_vocab_start_index <=
self.padded_added_vocab_start_index)
assert self.org_vocab_end_index <= self.padded_org_vocab_end_index
assert self.added_vocab_end_index <= self.padded_added_vocab_end_index
assert self.num_org_elements <= self.num_org_elements_padded
assert self.num_added_elements <= self.num_added_elements_padded
@torch.jit.script
def get_masked_input_and_mask(
input_: torch.Tensor, org_vocab_start_index: int,
org_vocab_end_index: int, num_org_vocab_padding: int,
added_vocab_start_index: int,
added_vocab_end_index: int) -> Tuple[torch.Tensor, torch.Tensor]:
# torch.jit.script will fuse all of the pointwise ops below
# into a single kernel, making it very fast
org_vocab_mask = (input_ >= org_vocab_start_index) & (input_ <
org_vocab_end_index)
added_vocab_mask = (input_ >= added_vocab_start_index) & (
input_ < added_vocab_end_index)
added_offset = added_vocab_start_index - (
org_vocab_end_index - org_vocab_start_index) - num_org_vocab_padding
valid_offset = (org_vocab_start_index *
org_vocab_mask) + (added_offset * added_vocab_mask)
vocab_mask = org_vocab_mask | added_vocab_mask
input_ = vocab_mask * (input_ - valid_offset)
return input_, ~vocab_mask
class VocabParallelEmbedding(torch.nn.Module):
......@@ -38,13 +128,36 @@ class VocabParallelEmbedding(torch.nn.Module):
Adapted from torch.nn.Embedding, note that we pad the vocabulary size to
make sure it is divisible by the number of model parallel GPUs.
In order to support various loading methods, we ensure that LoRA-added
embeddings are always at the end of TP-sharded tensors. In other words,
we shard base embeddings and LoRA embeddings separately (both padded),
and place them in the same tensor.
In this example, we will have the original vocab size = 1010,
added vocab size = 16 and padding to 64. Therefore, the total
vocab size with padding will be 1088 (because we first pad 1010 to
1024, add 16, and then pad to 1088).
Therefore, the tensor format looks like the following:
TP1, rank 0 (no sharding):
|< --------BASE-------- >|< -BASE PADDING-- >|< -----LORA------ >|< -LORA PADDING-- >|
corresponding token_id: | 0 | 1 | ... | 1009 | -1 | ... | -1 | 1010 | ... | 1015 | -1 | ... | -1 |
index: | 0 | 1 | ... | 1009 | 1010 | ... | 1023 | 1024 | ... | 1039 | 1040 | ... | 1087 |
TP2, rank 0:
|< --------------------BASE--------------------- >|< -----LORA------ >|< -LORA PADDING- >|
corresponding token_id: | 0 | 1 | 2 | ... | 497 | 498 | ... | 511 | 1000 | ... | 1015 | -1 | ... | -1 |
index: | 0 | 1 | 2 | ... | 497 | 498 | ... | 511 | 512 | ... | 527 | 520 | ... | 543 |
TP2, rank 1:
|< -----------BASE----------- >|< -BASE PADDING- >|< -----------LORA PADDING----------- >|
corresponding token_id: | 512 | 513 | 514 | ... | 1009 | -1 | ... | -1 | -1 | ... | -1 | -1 | ... | -1 |
index: | 0 | 1 | 2 | ... | 497 | 498 | ... | 511 | 512 | ... | 519 | 520 | ... | 543 |
Args:
num_embeddings: vocabulary size.
embedding_dim: size of hidden state.
params_dtype: type of the parameters.
org_num_embeddings: original vocabulary size (without LoRA).
padding_size: padding size for the vocabulary.
"""
""" # noqa: E501
def __init__(self,
num_embeddings: int,
......@@ -55,21 +168,39 @@ class VocabParallelEmbedding(torch.nn.Module):
super().__init__()
# Keep the input dimensions.
tp_rank = get_tensor_model_parallel_rank()
self.tp_size = get_tensor_model_parallel_world_size()
self.num_embeddings = num_embeddings
self.padding_size = padding_size
self.org_vocab_size = org_num_embeddings or num_embeddings
self.num_embeddings_padded = pad_vocab_size(num_embeddings,
padding_size)
num_added_embeddings = num_embeddings - self.org_vocab_size
self.org_vocab_size_padded = pad_vocab_size(self.org_vocab_size,
self.padding_size)
self.num_embeddings_padded = pad_vocab_size(
self.org_vocab_size_padded + num_added_embeddings,
self.padding_size)
assert self.org_vocab_size_padded <= self.num_embeddings_padded
self.shard_indices = self._get_indices(self.num_embeddings_padded,
self.org_vocab_size_padded,
self.num_embeddings,
self.org_vocab_size, tp_rank,
self.tp_size)
self.embedding_dim = embedding_dim
if params_dtype is None:
params_dtype = torch.get_default_dtype()
self.tp_size = get_tensor_model_parallel_world_size()
# Divide the weight matrix along the vocaburaly dimension.
self.vocab_start_index, self.vocab_end_index = (
vocab_range_from_global_vocab_size(
self.num_embeddings_padded, get_tensor_model_parallel_rank(),
self.tp_size))
self.num_embeddings_per_partition = (self.vocab_end_index -
self.vocab_start_index)
self.num_added_embeddings = self.num_embeddings - self.org_vocab_size
self.num_embeddings_per_partition = divide(self.num_embeddings_padded,
self.tp_size)
assert (self.shard_indices.num_elements_padded ==
self.num_embeddings_per_partition)
self.num_org_embeddings_per_partition = (
self.shard_indices.org_vocab_end_index -
self.shard_indices.org_vocab_start_index)
self.num_added_embeddings_per_partition = (
self.shard_indices.added_vocab_end_index -
self.shard_indices.added_vocab_start_index)
self.weight = Parameter(
torch.empty(self.num_embeddings_per_partition,
self.embedding_dim,
......@@ -79,28 +210,107 @@ class VocabParallelEmbedding(torch.nn.Module):
"weight_loader": self.weight_loader
})
@classmethod
def _get_indices(cls, vocab_size_padded: int, org_vocab_size_padded: int,
vocab_size: int, org_vocab_size: int, tp_rank: int,
tp_size: int) -> VocabParallelEmbeddingShardIndices:
"""Get start and end indices for vocab parallel embedding, following the
layout outlined in the class docstring, based on the given tp_rank and
tp_size."""
num_added_embeddings_padded = vocab_size_padded - org_vocab_size_padded
padded_org_vocab_start_index, padded_org_vocab_end_index = (
vocab_range_from_global_vocab_size(org_vocab_size_padded, tp_rank,
tp_size))
padded_added_vocab_start_index, padded_added_vocab_end_index = (
vocab_range_from_global_vocab_size(num_added_embeddings_padded,
tp_rank,
tp_size,
offset=org_vocab_size))
# remove padding
org_vocab_start_index = min(padded_org_vocab_start_index,
org_vocab_size)
org_vocab_end_index = min(padded_org_vocab_end_index, org_vocab_size)
added_vocab_start_index = min(padded_added_vocab_start_index,
vocab_size)
added_vocab_end_index = min(padded_added_vocab_end_index, vocab_size)
return VocabParallelEmbeddingShardIndices(
padded_org_vocab_start_index, padded_org_vocab_end_index,
padded_added_vocab_start_index, padded_added_vocab_end_index,
org_vocab_start_index, org_vocab_end_index,
added_vocab_start_index, added_vocab_end_index)
def get_sharded_to_full_mapping(self) -> Optional[List[int]]:
"""Get a mapping that can be used to reindex the gathered
logits for sampling.
During sampling, we gather logits from all ranks. The relationship
of index->token_id will follow the same format as outlined in the class
docstring. However, after the gather, we want to reindex the final
logits tensor to map index->token_id one-to-one (the index is always
equal the token_id it corresponds to). The indices returned by this
method allow us to do that.
"""
if self.tp_size < 2:
return None
base_embeddings: List[int] = []
added_embeddings: List[int] = []
padding: List[int] = []
for tp_rank in range(self.tp_size):
shard_indices = self._get_indices(self.num_embeddings_padded,
self.org_vocab_size_padded,
self.num_embeddings,
self.org_vocab_size, tp_rank,
self.tp_size)
range_start = self.num_embeddings_per_partition * tp_rank
range_end = self.num_embeddings_per_partition * (tp_rank + 1)
base_embeddings.extend(
range(range_start,
range_start + shard_indices.num_org_elements))
padding.extend(
range(range_start + shard_indices.num_org_elements,
range_start + shard_indices.num_org_elements_padded))
added_embeddings.extend(
range(
range_start + shard_indices.num_org_elements_padded,
range_start + shard_indices.num_org_elements_padded +
shard_indices.num_added_elements))
padding.extend(
range(
range_start + shard_indices.num_org_elements_padded +
shard_indices.num_added_elements,
range_start + shard_indices.num_org_elements_padded +
shard_indices.num_added_elements_padded))
assert (range_start + shard_indices.num_org_elements_padded +
shard_indices.num_added_elements_padded == range_end)
ret = base_embeddings + added_embeddings + padding
assert len(ret) == self.num_embeddings_padded
return ret
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
parallel_dim = param.parallel_dim
assert loaded_weight.shape[parallel_dim] == self.org_vocab_size
loaded_weight = loaded_weight[self.vocab_start_index:self.
vocab_end_index]
loaded_weight = loaded_weight[self.shard_indices.org_vocab_start_index:
self.shard_indices.org_vocab_end_index]
param[:loaded_weight.shape[0]].data.copy_(loaded_weight)
param[loaded_weight.shape[0]:].data.fill_(0)
def forward(self, input_):
if self.tp_size > 1:
# Build the mask.
input_mask = ((input_ < self.vocab_start_index) |
(input_ >= self.vocab_end_index))
# Mask the input.
masked_input = input_.clone() - self.vocab_start_index
masked_input[input_mask] = 0
masked_input, input_mask = get_masked_input_and_mask(
input_, self.shard_indices.org_vocab_start_index,
self.shard_indices.org_vocab_end_index,
self.shard_indices.num_org_vocab_padding,
self.shard_indices.added_vocab_start_index,
self.shard_indices.added_vocab_end_index)
else:
masked_input = input_
# Get the embeddings.
output_parallel = F.embedding(masked_input, self.weight)
# Mask the output embedding.
if self.tp_size > 1:
output_parallel[input_mask, :] = 0.0
output_parallel.masked_fill_(input_mask.unsqueeze(1), 0)
# Reduce across all the model parallel GPUs.
output = tensor_model_parallel_all_reduce(output_parallel)
return output
......
# ruff: noqa: SIM117
import collections
import copy
import fnmatch
import glob
import json
import math
import os
from abc import ABC, abstractmethod
from typing import Any, Dict, Generator, List, Optional, Tuple, Type
import huggingface_hub
import numpy as np
import torch
from huggingface_hub import HfApi, hf_hub_download
from torch import nn
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoadFormat,
......@@ -28,6 +33,7 @@ from vllm.model_executor.model_loader.weight_utils import (
get_quant_config, initialize_dummy_weights, np_cache_weights_iterator,
pt_weights_iterator, safetensors_weights_iterator)
from vllm.model_executor.models.vlm_base import VisionLanguageModelBase
from vllm.model_executor.utils import set_weight_attrs
logger = init_logger(__name__)
......@@ -128,7 +134,7 @@ class DefaultModelLoader(BaseModelLoader):
def _maybe_download_from_modelscope(
self, model: str, revision: Optional[str]) -> Optional[str]:
"""Download model from ModelScope hub if VLLM_USE_MODELSCOPE is True.
Returns the path to the downloaded model, or None if the model is not
downloaded from ModelScope."""
if VLLM_USE_MODELSCOPE:
......@@ -250,6 +256,7 @@ class DefaultModelLoader(BaseModelLoader):
model,
"fall_back_to_pt_during_load",
True)), )
for _, module in model.named_modules():
quant_method = getattr(module, "quant_method", None)
if quant_method is not None:
......@@ -389,7 +396,7 @@ class ShardedStateLoader(BaseModelLoader):
Model loader that directly loads each worker's model state dict, which
enables a fast load path for large tensor-parallel models where each worker
only needs to read its own shard rather than the entire checkpoint. See
`examples/save_sharded_states.py` for creating a sharded checkpoint.
`examples/save_sharded_state.py` for creating a sharded checkpoint.
"""
DEFAULT_PATTERN = "model-rank-{rank}-part-{part}.safetensors"
......@@ -542,6 +549,241 @@ class ShardedStateLoader(BaseModelLoader):
)
class BitsAndBytesModelLoader(BaseModelLoader):
"""Model loader to load model weights with BitAndBytes quantization."""
default_target_modules = [
"gate_proj", "down_proj", "up_proj", "q_proj", "k_proj", "v_proj",
"o_proj"
]
possible_config_file_names = ["adapter_config.json"]
def __init__(self, load_config: LoadConfig):
super().__init__(load_config)
# we don't need to quantize the whole model, only the target modules
# that are specified in the adapter config file. If the adapter config
# file is not provided, we will quantize the default modules.
if (not load_config.model_loader_extra_config
or "qlora_adapter_name_or_path"
not in load_config.model_loader_extra_config):
self.target_modules = self.default_target_modules
return
qlora_adapter = load_config.model_loader_extra_config[
"qlora_adapter_name_or_path"]
config_file_path = self._get_config_file(qlora_adapter)
with open(config_file_path, "r") as f:
config = json.load(f)
self.target_modules = config["target_modules"]
def _get_config_file(self, qlora_adapter: str) -> str:
is_local = os.path.isdir(qlora_adapter)
config_file_path = None
if is_local:
for file in self.possible_config_file_names:
config_file_path = os.path.join(qlora_adapter, file)
if os.path.exists(config_file_path):
break
else:
hf_api = HfApi()
repo_files = hf_api.list_repo_files(repo_id=qlora_adapter)
for file in self.possible_config_file_names:
if file in repo_files:
config_file_path = hf_hub_download(repo_id=qlora_adapter,
filename=file)
break
if not config_file_path:
raise ValueError(
f"Cannot find adapter config file in {qlora_adapter}")
return config_file_path
def _get_weight_files(
self,
model_name_or_path: str,
allowed_patterns: List[str],
revision: Optional[str] = None) -> Tuple[List[str], str]:
"""Retrieve weight files. Download the files if necessary.
Return the weight files and the file pattern."""
is_local = os.path.isdir(model_name_or_path)
if is_local:
for pattern in allowed_patterns:
weight_files = glob.glob(
os.path.join(model_name_or_path, pattern))
if weight_files:
return weight_files, pattern
else:
hf_api = HfApi()
repo_files = hf_api.list_repo_files(repo_id=model_name_or_path)
for pattern in allowed_patterns:
matching_files = fnmatch.filter(repo_files, pattern)
if matching_files:
hf_folder = download_weights_from_hf(
model_name_or_path, self.load_config.download_dir,
[pattern], revision)
return glob.glob(os.path.join(hf_folder, pattern)), pattern
raise RuntimeError(
f"No model weights found in: `{model_name_or_path}`")
def _prepare_weights(self, model_name_or_path: str,
revision: Optional[str]) -> Tuple[List[str], bool]:
"""Prepare weight files for the model."""
allowed_patterns = ["*.safetensors", "*.bin", "*.pt"]
hf_weights_files, matched_pattern = self._get_weight_files(
model_name_or_path, allowed_patterns, revision)
if matched_pattern != "*.safetensors":
hf_weights_files = filter_files_not_needed_for_inference(
hf_weights_files)
if len(hf_weights_files) == 0:
raise RuntimeError(
f"Cannot find any model weights with `{model_name_or_path}`")
return hf_weights_files, matched_pattern == "*.safetensors"
def _get_quantized_weights_iterator(
self, model_name_or_path: str, revision: Optional[str]
) -> Tuple[Generator[Tuple[str, torch.Tensor], None, None], Dict[str,
Any]]:
"""Get an iterator to the model weights with bitsandbytes quantization,
as well as the quantization state dictionary."""
# only load the bitsandbytes module when needed
try:
import bitsandbytes
if bitsandbytes.__version__ < "0.42.0":
raise ImportError("bitsandbytes version is wrong. Please "
"install bitsandbytes>=0.42.0.")
from bitsandbytes.functional import quantize_4bit
except ImportError as err:
raise ImportError("Please install bitsandbytes>=0.42.0 via "
"`pip install bitsandbytes>=0.42.0` to use "
"bitsandbytes quantizer.") from err
hf_weights_files, use_safetensors = self._prepare_weights(
model_name_or_path, revision)
quant_state_dict = {}
if use_safetensors:
weight_iterator = safetensors_weights_iterator(hf_weights_files)
else:
weight_iterator = pt_weights_iterator(hf_weights_files)
def generator():
for weight_name, weight_tensor in weight_iterator:
if any(target_module in weight_name
for target_module in self.target_modules):
weight_name = weight_name.replace(".weight", ".qweight")
# bitsandbytes requires data in GPU
loaded_weight = weight_tensor.cuda().data
with set_default_torch_dtype(torch.float32):
processed_weight, quant_state = quantize_4bit(
loaded_weight,
compress_statistics=True,
quant_type="nf4")
quant_state_dict[weight_name] = quant_state
else:
processed_weight = weight_tensor
yield weight_name, processed_weight
return generator(), quant_state_dict
def _load_weights(self, model_config: ModelConfig,
model: nn.Module) -> None:
if not hasattr(model, 'load_weights'):
raise AttributeError(
"The required method 'load_weights' is not defined in class"
f" {type(self).__name__}.")
if not hasattr(model, 'bitsandbytes_stacked_params_mapping'):
raise AttributeError(
f"Model {type(self).__name__} does not support BitsAndBytes "
"quantization yet.")
logger.info("Loading weights with BitsAndBytes quantization. "
" May take a while ...")
qweight_iterator, quant_state_dict = (
self._get_quantized_weights_iterator(model_config.model,
model_config.revision))
model.load_weights(qweight_iterator)
param_dict = dict(model.named_parameters())
stacked_quant_state_dict: Dict[str, Dict[int, Any]] = {}
for quant_param_name in quant_state_dict:
non_stacked_param_name = quant_param_name
shard_index = 0
for shard_name, (
weight_name, index
) in model.bitsandbytes_stacked_params_mapping.items():
if shard_name in quant_param_name:
shard_index = index
quant_param_name = quant_param_name.replace(
shard_name, weight_name)
break
if quant_param_name not in param_dict:
raise ValueError(
f"Parameter {quant_param_name} not found in the model.")
if quant_param_name not in stacked_quant_state_dict:
stacked_quant_state_dict[quant_param_name] = {}
stacked_quant_state_dict[quant_param_name][shard_index] = (
quant_state_dict[non_stacked_param_name])
# save quant_states and offsets as the attributes of the parameters
for param_name, param in param_dict.items():
if param_name in stacked_quant_state_dict:
quant_states = stacked_quant_state_dict[param_name]
set_weight_attrs(param, {"bnb_quant_state": quant_states})
pack_ratio = getattr(param, "pack_factor", -1)
if pack_ratio == -1:
raise ValueError(
f"pack_factor not set for parameter {param_name}.")
num_elements = [0] * len(quant_states)
for seq, quant_state in enumerate(quant_states.items()):
num_elements[seq] = math.prod(
quant_state[1].shape) // pack_ratio
offsets = np.concatenate(([0], np.cumsum(num_elements)))
set_weight_attrs(param, {"bnb_shard_offsets": offsets})
def load_model(self, *, model_config: ModelConfig,
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig],
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
cache_config: CacheConfig) -> nn.Module:
with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device):
model = _initialize_model(model_config, self.load_config,
lora_config, vision_language_config,
cache_config)
self._load_weights(model_config, model)
return model.eval()
def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
"""Get a model loader based on the load format."""
......@@ -557,4 +799,7 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
if load_config.load_format == LoadFormat.SHARDED_STATE:
return ShardedStateLoader(load_config)
if load_config.load_format == LoadFormat.BITSANDBYTES:
return BitsAndBytesModelLoader(load_config)
return DefaultModelLoader(load_config)
......@@ -122,15 +122,22 @@ def get_quant_config(model_config: ModelConfig,
hf_quant_config = getattr(model_config.hf_config, "quantization_config",
None)
if hf_quant_config is None:
compression_config = getattr(model_config.hf_config,
"compression_config", None)
if compression_config is not None:
hf_quant_config = compression_config.get("quantization_config",
None)
# compressed-tensors uses a compressions_config
hf_quant_config = getattr(model_config.hf_config, "compression_config",
None)
if hf_quant_config is not None:
return quant_cls.from_config(hf_quant_config)
model_name_or_path = model_config.model
# In case of bitsandbytes/QLoRA, get quant config from the adapter model.
if model_config.quantization == "bitsandbytes":
if (not load_config.model_loader_extra_config
or "qlora_adapter_name_or_path"
not in load_config.model_loader_extra_config):
return quant_cls.from_config({"adapter_name_or_path": ""})
model_name_or_path = load_config.model_loader_extra_config[
"qlora_adapter_name_or_path"]
else:
model_name_or_path = model_config.model
is_local = os.path.isdir(model_name_or_path)
if not is_local:
# Download the config files.
......@@ -169,6 +176,10 @@ def get_quant_config(model_config: ModelConfig,
quant_config_file = quant_config_files[0]
with open(quant_config_file, "r") as f:
config = json.load(f)
if model_config.quantization == "bitsandbytes":
config["adapter_name_or_path"] = model_name_or_path
return quant_cls.from_config(config)
......
......@@ -33,6 +33,8 @@ _GENERATION_MODELS = {
"LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
"LlavaForConditionalGeneration":
("llava", "LlavaForConditionalGeneration"),
"LlavaNextForConditionalGeneration":
("llava_next", "LlavaNextForConditionalGeneration"),
# For decapoda-research/llama-*
"LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
"MistralForCausalLM": ("llama", "LlamaForCausalLM"),
......
......@@ -247,11 +247,12 @@ class DbrxFusedNormAttention(nn.Module):
def __init__(
self,
config: DbrxConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.d_model = config.d_model
self.attn = DbrxAttention(config, quant_config)
self.attn = DbrxAttention(config, cache_config, quant_config)
self.norm_1 = nn.LayerNorm(self.d_model)
self.norm_2 = nn.LayerNorm(self.d_model)
......
......@@ -319,6 +319,14 @@ class LlamaForCausalLM(nn.Module):
"lm_head": "output_embeddings",
}
embedding_padding_modules = ["lm_head"]
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),
"k_proj": ("qkv_proj", 1),
"v_proj": ("qkv_proj", 2),
"gate_proj": ("gate_up_proj", 0),
"up_proj": ("gate_up_proj", 1),
}
def __init__(
self,
......
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union
import torch
from torch import nn
import torch.nn as nn
# TODO(xwjiang): We should port CLIPVisionModel's code over to not depend on
# transformers' impl.
from transformers import CLIPVisionModel, LlavaConfig
......@@ -17,6 +17,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.llama import LlamaModel
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.image import get_dummy_image_data
from vllm.sequence import SamplerOutput
from .vlm_base import VisionLanguageModelBase
......@@ -49,10 +51,10 @@ class LlavaMultiModalProjector(nn.Module):
return hidden_states
def _merge_vision_embeddings(input_ids: torch.Tensor,
inputs_embeds: torch.Tensor,
vision_embeddings: torch.Tensor,
image_token_id: int) -> torch.Tensor:
def merge_vision_embeddings(input_ids: torch.Tensor,
inputs_embeds: torch.Tensor,
vision_embeddings: torch.Tensor,
image_token_id: int) -> torch.Tensor:
"""In place merges in vision_embeddings with inputs_embeds."""
mask = (input_ids == image_token_id)
......@@ -82,6 +84,9 @@ class LlavaImageFeatureInputs(TypedDict):
LlavaImageInputs = Union[LlavaImagePixelInputs, LlavaImageFeatureInputs]
@MULTIMODAL_REGISTRY.register_image_feature_input()
@MULTIMODAL_REGISTRY.register_image_pixel_input()
@MULTIMODAL_REGISTRY.register_dummy_data(get_dummy_image_data)
class LlavaForConditionalGeneration(VisionLanguageModelBase):
def __init__(self,
......@@ -131,30 +136,43 @@ class LlavaForConditionalGeneration(VisionLanguageModelBase):
return data
def _parse_and_validate_image_input(
self, data: object) -> Optional[LlavaImageInputs]:
self, **kwargs: object) -> Optional[LlavaImageInputs]:
pixel_values = kwargs.pop("pixel_values", None)
image_features = kwargs.pop("image_features", None)
expected_input_type = self.vision_language_config.image_input_type
ImageInputType = VisionLanguageConfig.ImageInputType
if data is None:
return None
if expected_input_type == ImageInputType.PIXEL_VALUES:
if not isinstance(data, torch.Tensor):
raise TypeError("Image pixel vector should be a tensor, "
f"but received type: {type(data)}")
if image_features is not None:
raise ValueError(
"Expected pixel values but got image features")
if pixel_values is None:
return None
if not isinstance(pixel_values, torch.Tensor):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
return LlavaImagePixelInputs(
type="pixel_values",
data=self._validate_image_data(data),
data=self._validate_image_data(pixel_values),
)
elif expected_input_type == ImageInputType.IMAGE_FEATURES:
if not isinstance(data, torch.Tensor):
raise TypeError("Image feature vector should be a tensor, "
f"but received type: {type(data)}")
if expected_input_type == ImageInputType.IMAGE_FEATURES:
if pixel_values is not None:
raise ValueError(
"Expected image features but got pixel values")
if image_features is None:
return None
if not isinstance(image_features, torch.Tensor):
raise ValueError("Incorrect type of image features. "
f"Got type: {type(image_features)}")
return LlavaImageFeatureInputs(
type="image_features",
data=self._validate_image_data(data),
data=self._validate_image_data(image_features),
)
return None
......@@ -201,12 +219,14 @@ class LlavaForConditionalGeneration(VisionLanguageModelBase):
return self.multi_modal_projector(image_features)
def forward(self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
image_input: Optional[torch.Tensor] = None) -> SamplerOutput:
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
**kwargs: object,
) -> SamplerOutput:
"""Run forward pass for Llava 1.5.
One key thing to understand is the `input_ids` already accounts for the
......@@ -227,10 +247,10 @@ class LlavaForConditionalGeneration(VisionLanguageModelBase):
This way, the `positions` and `attn_metadata` are consistent
with the `input_ids`.
The model takes two types of image inputs:
The model takes two types of image inputs:
PIXEL_VALUES and IMAGE_FEATURES.
The following shows how each maps to huggingface implementation.
PIXEL_VALUES:
PIXEL_VALUES:
- https://github.com/huggingface/transformers/blob/07bdbeb/src/transformers/models/llava/modeling_llava.py#L353
IMAGE_FEATURES:
- https://github.com/huggingface/transformers/blob/07bdbeb/src/transformers/models/llava/modeling_llava.py#L430
......@@ -239,17 +259,18 @@ class LlavaForConditionalGeneration(VisionLanguageModelBase):
Args:
input_ids: Flattened (concatenated) input_ids corresponding to a
batch.
image_input: A batch of image inputs.
For PIXEL_VALUES, expecting [1, 3, 336, 336].
For IMAGE_FEATURES, expecting [1, 576, 1024].
pixel_values: For PIXEL_VALUES, expects a batch with shape
[1, 3, 336, 336].
image_features: For IMAGE_FEATURES, expects a batch with shape
[1, 576, 1024].
"""
parsed_image_input = self._parse_and_validate_image_input(image_input)
image_input = self._parse_and_validate_image_input(**kwargs)
if parsed_image_input is not None:
vision_embeddings = self._process_image_input(parsed_image_input)
if image_input is not None:
vision_embeddings = self._process_image_input(image_input)
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
inputs_embeds = _merge_vision_embeddings(
inputs_embeds = merge_vision_embeddings(
input_ids, inputs_embeds, vision_embeddings,
self.vision_language_config.image_token_id)
......
from typing import (Dict, Iterable, List, Literal, Optional, Tuple, TypedDict,
Union)
import torch
import torch.nn as nn
from PIL import Image
# TODO(xwjiang): We should port CLIPVisionModel's code over to not depend on
# transformers' impl.
from transformers import CLIPVisionModel, LlavaNextConfig
from transformers.models.llava_next.modeling_llava_next import (
get_anyres_image_grid_shape, unpad_image)
from typing_extensions import NotRequired
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, ModelConfig, VisionLanguageConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.llama import LlamaModel
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalData
from vllm.multimodal.image import ImagePixelData, get_dummy_image_data
from vllm.sequence import SamplerOutput, SequenceData
from .llava import LlavaMultiModalProjector, merge_vision_embeddings
from .vlm_base import VisionLanguageModelBase
logger = init_logger(__name__)
_KEYS_TO_MODIFY_MAPPING = {
"language_model.lm_head": "lm_head",
"language_model.model": "language_model",
}
class LlavaNextImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
data: torch.Tensor
"""Shape: (batch_size, 1 + num_patches, num_channels, height, width)"""
image_sizes: NotRequired[torch.Tensor]
"""Shape: (batch_size, 2)"""
class LlavaNextImageFeatureInputs(TypedDict):
type: Literal["image_features"]
data: torch.Tensor
"""Shape: (batch_size, 1 + num_patches, image_feature_size, hidden_size)"""
image_sizes: NotRequired[torch.Tensor]
"""Shape: (batch_size, 2)"""
LlavaNextImageInputs = Union[LlavaNextImagePixelInputs,
LlavaNextImageFeatureInputs]
def _get_dummy_image_data(
seq_len: int,
model_config: ModelConfig,
vlm_config: VisionLanguageConfig,
) -> Tuple[SequenceData, MultiModalData]:
seq_data, fake_mm_data = get_dummy_image_data(seq_len, model_config,
vlm_config)
config_input_type = vlm_config.image_input_type
ImageInputType = VisionLanguageConfig.ImageInputType
if config_input_type == ImageInputType.PIXEL_VALUES:
_, c, h, w = vlm_config.image_input_shape
mode = {1: "L", 3: "RGB"}[c]
fake_mm_data = ImagePixelData(Image.new(mode, (w, h), color=0))
return seq_data, fake_mm_data
def _image_pixel_processor(
data: ImagePixelData,
model_config: ModelConfig,
vlm_config: VisionLanguageConfig,
) -> Dict[str, torch.Tensor]:
image = data.image
if isinstance(image, torch.Tensor):
pixel_values = image.to(model_config.dtype)
batch_size, _, _, h, w = pixel_values.shape
image_sizes = torch.tensor([(w, h) for _ in range(batch_size)])
return {"pixel_values": pixel_values, "image_sizes": image_sizes}
# Temporary patch before dynamic number of image tokens is supported
_, _, h, w = vlm_config.image_input_shape
if (w, h) != (image.width, image.height):
logger.warning(
"Dynamic image shape is currently not supported. "
"Resizing input image to (%d, %d).", w, h)
data.image = image.resize((w, h))
return MULTIMODAL_REGISTRY._get_plugin_for_data_type(ImagePixelData) \
._default_input_processor(data, model_config, vlm_config)
@MULTIMODAL_REGISTRY.register_image_pixel_input(_image_pixel_processor)
@MULTIMODAL_REGISTRY.register_dummy_data(_get_dummy_image_data)
class LlavaNextForConditionalGeneration(VisionLanguageModelBase):
"""
Args to `forward()`:
input_ids: Flattened (concatenated) input_ids corresponding to a
batch.
pixel_values: For PIXEL_VALUES, expects a batch with shape
[1, num_patches, 3, 336, 336].
image_features: For IMAGE_FEATURES, expects a batch with shape
[1, num_patches, 1176, 1024].
"""
def __init__(self,
config: LlavaNextConfig,
vision_language_config: VisionLanguageConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None) -> None:
super().__init__(vision_language_config)
# Update the type annotation from that of its superclass
self.config = config
if self.vision_language_config.image_input_type == (
VisionLanguageConfig.ImageInputType.PIXEL_VALUES):
self.vision_tower = CLIPVisionModel(config.vision_config)
else:
raise TypeError("Image features are not supported by LLaVA-NeXT")
self.multi_modal_projector = LlavaMultiModalProjector(
vision_hidden_size=config.vision_config.hidden_size,
text_hidden_size=config.text_config.hidden_size,
projector_hidden_act=config.projector_hidden_act)
self.quant_config = quant_config
self.language_model = LlamaModel(config.text_config, cache_config,
quant_config)
self.unpadded_vocab_size = config.text_config.vocab_size
self.lm_head = ParallelLMHead(
self.unpadded_vocab_size,
config.text_config.hidden_size,
org_num_embeddings=self.language_model.org_vocab_size)
logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size, logit_scale)
self.sampler = Sampler()
self.image_newline = nn.Parameter(
torch.empty(config.text_config.hidden_size))
def _validate_image_pixels(self, data: torch.Tensor) -> torch.Tensor:
_, num_channels, _, _ = self.vision_language_config.image_input_shape
# Note that this is different from that of vLLM vision_language_config
# since the image is resized by the HuggingFace preprocessor
height = width = self.config.vision_config.image_size
if list(data.shape[2:]) != [num_channels, height, width]:
raise ValueError(
f"The expected image tensor shape is batch dimension plus "
f"num_patches plus {[num_channels, height, width]}. "
f"You supplied {data.shape}. "
f"If you are using vLLM's entrypoint, make sure your "
f"supplied image input is consistent with "
f"image_input_shape in engine args.")
return data
def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor:
if list(data.shape[1:]) != [2]:
raise ValueError(
f"The expected image sizes shape is batch dimension plus "
f"{[2]}. You supplied {data.shape}.")
return data
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[LlavaNextImageInputs]:
pixel_values = kwargs.pop("pixel_values", None)
image_sizes = kwargs.pop("image_sizes", None)
image_features = kwargs.pop("image_features", None)
expected_input_type = self.vision_language_config.image_input_type
ImageInputType = VisionLanguageConfig.ImageInputType
if expected_input_type == ImageInputType.PIXEL_VALUES:
if image_features is not None:
raise ValueError(
"Expected pixel values but got image features")
if pixel_values is None:
return None
if not isinstance(pixel_values, torch.Tensor):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
if not isinstance(image_sizes, torch.Tensor):
raise ValueError("Incorrect type of image sizes. "
f"Got type: {type(image_sizes)}")
return LlavaNextImagePixelInputs(
type="pixel_values",
data=self._validate_image_pixels(pixel_values),
image_sizes=self._validate_image_sizes(image_sizes),
)
assert expected_input_type != ImageInputType.IMAGE_FEATURES, (
"Failed to validate this at initialization time")
return None
def _select_image_features(self, image_features: torch.Tensor, *,
strategy: str) -> torch.Tensor:
# Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421 # noqa
if strategy == "default":
return image_features[:, 1:]
elif strategy == "full":
return image_features
raise ValueError(f"Unexpected select feature strategy: {strategy}")
def _image_pixels_to_features(self, vision_tower: CLIPVisionModel,
pixel_values: torch.Tensor) -> torch.Tensor:
# TODO(xwjiang): Maybe port minimal CLIPVisionModel over.
image_outputs = vision_tower(pixel_values.to(vision_tower.device),
output_hidden_states=True)
image_features = image_outputs.hidden_states[
self.config.vision_feature_layer]
return self._select_image_features(
image_features,
strategy=self.config.vision_feature_select_strategy,
)
def _merge_image_patch_embeddings(self, image_size: torch.Tensor,
patch_embeddings: torch.Tensor, *,
strategy: str) -> torch.Tensor:
# Based on: https://github.com/haotian-liu/LLaVA/blob/main/llava/model/llava_arch.py
if strategy == "flat":
return patch_embeddings.flatten(0, 1)
if strategy.startswith("spatial"):
orig_width, orig_height = image_size
height = width = self.config.vision_config.image_size \
// self.config.vision_config.patch_size
base_patch_embeds = patch_embeddings[0]
if height * width != base_patch_embeds.shape[0]:
raise ValueError(
"The number of patches is not consistent with the "
"image size.")
if patch_embeddings.shape[0] > 1:
other_patch_embeds = patch_embeddings[1:]
# image_aspect_ratio == "anyres"
num_patch_width, num_patch_height = get_anyres_image_grid_shape(
(orig_width, orig_height),
self.config.image_grid_pinpoints,
self.config.vision_config.image_size,
)
other_patch_embeds = other_patch_embeds \
.view(num_patch_width, num_patch_height, height, width, -1)
if "unpad" in strategy:
other_patch_embeds = other_patch_embeds \
.permute(4, 0, 2, 1, 3).contiguous() \
.flatten(1, 2).flatten(2, 3)
other_patch_embeds = unpad_image(other_patch_embeds,
image_size)
other_patch_embeds = torch.cat((
other_patch_embeds,
self.image_newline[:, None, None] \
.expand(*other_patch_embeds.shape[:-1], 1) \
.to(other_patch_embeds.device),
), dim=-1)
other_patch_embeds = other_patch_embeds \
.flatten(1, 2).transpose(0, 1)
else:
other_patch_embeds = other_patch_embeds \
.permute(0, 2, 1, 3, 4).contiguous() \
.flatten(0, 3)
merged_patch_embeddings = torch.cat(
(base_patch_embeds, other_patch_embeds), dim=0)
else:
if "unpad" in strategy:
merged_patch_embeddings = torch.cat(
(base_patch_embeds,
self.image_newline[None] \
.to(base_patch_embeds.device)
), dim=0)
else:
merged_patch_embeddings = base_patch_embeds
return merged_patch_embeddings
raise ValueError(f"Unexpected patch merge strategy: {strategy}")
def _process_image_pixels(
self, inputs: LlavaNextImagePixelInputs) -> torch.Tensor:
assert self.vision_tower is not None
pixel_values = inputs["data"]
b, num_patches, c, h, w = pixel_values.shape
stacked_pixel_values = pixel_values.view(b * num_patches, c, h, w)
stacked_image_features = self._image_pixels_to_features(
self.vision_tower, stacked_pixel_values)
return stacked_image_features.view(b, num_patches,
*stacked_image_features.shape[-2:])
def _process_image_input(
self, image_input: LlavaNextImageInputs) -> torch.Tensor:
if image_input["type"] == "pixel_values":
assert self.vision_tower is not None
image_features = self._process_image_pixels(image_input)
else:
image_features = image_input["data"]
patch_embeddings = self.multi_modal_projector(image_features)
image_sizes = image_input.get("image_sizes")
if image_sizes is None:
batch_size = image_input["data"].shape[0]
vision_config = self.config.vision_config
default_width = default_height = vision_config.image_size
image_sizes = torch.as_tensor([[default_width, default_height]
for _ in range(batch_size)])
merged_patch_embeddings = [
self._merge_image_patch_embeddings(image_sizes[i],
patch_features,
strategy="spatial_unpad")
for i, patch_features in enumerate(patch_embeddings)
]
return torch.stack(merged_patch_embeddings, dim=0)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
**kwargs: object,
) -> SamplerOutput:
"""Run forward pass for Llava 1.5.
One key thing to understand is the `input_ids` already accounts for the
positions of the to-be-inserted image embeddings.
Concretely, consider a text prompt:
"<image>\nUSER: What's the content of the image?\nASSISTANT:".
Tokenizer outputs:
[1, 32000, 29871, 13, 11889, 29901, 1724, 29915, 29879, 278,
2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901].
The to-be-inserted image has a size of 576 (24 * 24) along the context
length dimension.
`input_ids` is thus [1, 32000, ..., 32000, 29871, 13, 11889, 29901,
1724, 29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933,
9047, 13566, 29901].
There will be 576 `32000` in the `input_ids`.
(32000 is the token id for `<image>`.)
This way, the `positions` and `attn_metadata` are consistent
with the `input_ids`.
The model takes two types of image inputs:
PIXEL_VALUES and IMAGE_FEATURES.
The following shows how each maps to huggingface implementation.
PIXEL_VALUES:
- https://github.com/huggingface/transformers/blob/07bdbeb/src/transformers/models/llava/modeling_llava.py#L353
IMAGE_FEATURES:
- https://github.com/huggingface/transformers/blob/07bdbeb/src/transformers/models/llava/modeling_llava.py#L430
before going through the multi modal projector.
Args:
input_ids: Flattened (concatenated) input_ids corresponding to a
batch.
pixel_values: For PIXEL_VALUES, expects a batch with shape
[1, 3, 336, 336].
image_features: For IMAGE_FEATURES, expects a batch with shape
[1, 576, 1024].
"""
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is not None:
vision_embeddings = self._process_image_input(image_input)
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
inputs_embeds = merge_vision_embeddings(
input_ids, inputs_embeds, vision_embeddings,
self.vision_language_config.image_token_id)
input_ids = None
else:
inputs_embeds = None
hidden_states = self.language_model(input_ids,
positions,
kv_caches,
attn_metadata,
inputs_embeds=inputs_embeds)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head.weight, hidden_states,
sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
# only doing this for language model part for now.
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
if key_to_modify in name:
name = name.replace(key_to_modify, new_key)
use_default_weight_loading = False
if "vision" in name:
if self.vision_tower is not None:
# We only do sharding for language model and
# not vision model for now.
use_default_weight_loading = True
else:
for (param_name, weight_name,
shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
param = params_dict[name.replace(weight_name, param_name)]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
use_default_weight_loading = True
if use_default_weight_loading:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
......@@ -41,7 +41,9 @@ from vllm.model_executor.layers.linear import (QKVParallelLinear,
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.fp8 import Fp8Config
from vllm.model_executor.layers.quantization.fp8 import (Fp8Config,
per_tensor_dequantize,
per_tensor_quantize)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
......@@ -98,16 +100,16 @@ class MixtralMoE(nn.Module):
if self.use_fp8 and self.quant_config.is_checkpoint_fp8_serialized:
params_dtype = torch.float8_e4m3fn
self.w13_weight = nn.Parameter(
torch.empty(self.num_total_experts,
2 * self.intermediate_size,
self.hidden_size,
dtype=params_dtype))
self.w2_weight = nn.Parameter(
torch.empty(self.num_total_experts,
self.hidden_size,
self.intermediate_size,
dtype=params_dtype))
self.w13_weight = nn.Parameter(torch.empty(self.num_total_experts,
2 * self.intermediate_size,
self.hidden_size,
dtype=params_dtype),
requires_grad=False)
self.w2_weight = nn.Parameter(torch.empty(self.num_total_experts,
self.hidden_size,
self.intermediate_size,
dtype=params_dtype),
requires_grad=False)
set_weight_attrs(self.w13_weight, {
"weight_loader": self.weight_loader,
......@@ -124,7 +126,10 @@ class MixtralMoE(nn.Module):
if self.use_fp8:
# WEIGHT_SCALE (for fp8)
# Allocate 2 scales for w1 and w3 respectively.
# They will be combined to a single scale after weight loading.
self.w13_scale = nn.Parameter(torch.ones(self.num_total_experts,
2,
dtype=torch.float32),
requires_grad=False)
self.w2_scale = nn.Parameter(torch.ones(self.num_total_experts,
......@@ -142,17 +147,17 @@ class MixtralMoE(nn.Module):
"weight_loader": self.weight_loader,
})
# ACT_SCALE (for fp8)
# INPUT_SCALE (for fp8)
if quant_config.activation_scheme == "static":
if not quant_config.is_checkpoint_fp8_serialized:
raise ValueError(
"Found static activation scheme for checkpoint that "
"was not serialized fp8.")
self.a13_scale = nn.Parameter(torch.zeros(
self.a13_scale = nn.Parameter(torch.ones(
self.num_total_experts, dtype=torch.float32),
requires_grad=False)
self.a2_scale = nn.Parameter(torch.zeros(
self.num_total_experts, dtype=torch.float32),
self.a2_scale = nn.Parameter(torch.ones(self.num_total_experts,
dtype=torch.float32),
requires_grad=False)
set_weight_attrs(self.a13_scale, {
......@@ -175,8 +180,22 @@ class MixtralMoE(nn.Module):
shard_size:2 * shard_size, :] = loaded_weight[shard, :]
if weight_name.endswith("w2.weight"):
param_data[expert_id, :, :] = loaded_weight[:, shard]
if "act_scale" in weight_name or "weight_scale" in weight_name:
# Loading scales
if "input_scale" in weight_name or "w2.weight_scale" in weight_name:
if param_data[expert_id] != 1 and (param_data[expert_id] -
loaded_weight).abs() > 1e-5:
raise ValueError(
"input_scales of w1 and w3 of a layer "
f"must be equal. But got {param_data[expert_id]} "
f"vs. {loaded_weight}")
param_data[expert_id] = loaded_weight
elif "weight_scale" in weight_name:
# We have to keep the weight scales of w1 and w3 because
# we need to re-quantize w1/w3 weights after weight loading.
assert "w1" in weight_name or "w3" in weight_name
shard_id = 0 if "w1" in weight_name else 1
param_data[expert_id][shard_id] = loaded_weight
def process_weights_after_loading(self):
# Fp8 is the only case where we need to process after loading.
......@@ -189,6 +208,12 @@ class MixtralMoE(nn.Module):
dtype=torch.float8_e4m3fn)
w2_weight = torch.empty_like(self.w2_weight.data,
dtype=torch.float8_e4m3fn)
# Re-initialize w13_scale because we directly quantize
# merged w13 weights and generate a single scaling factor.
self.w13_scale = nn.Parameter(torch.ones(self.num_total_experts,
dtype=torch.float32),
requires_grad=False)
for expert in range(self.num_total_experts):
w13_weight[expert, :, :], self.w13_scale[
expert] = ops.scaled_fp8_quant(
......@@ -199,25 +224,44 @@ class MixtralMoE(nn.Module):
self.w13_weight = nn.Parameter(w13_weight, requires_grad=False)
self.w2_weight = nn.Parameter(w2_weight, requires_grad=False)
# If checkpoint is fp8 + static, cleanup act_scales.
# Since state_dict has an act_scale per expert but our kernels
# are passed one act_scale shared across all experts.
elif self.quant_config.activation_scheme == "static":
if self.a13_scale is None or self.a2_scale is None:
raise ValueError(
"QuantConfig has static quantization, but found "
"activation scales are None.")
else:
# If checkpoint is fp8 + static, cleanup input_scales.
# Since state_dict has an input_scale per expert but our kernels
# are passed one input_scale shared across all experts.
if self.quant_config.activation_scheme == "static":
if self.a13_scale is None or self.a2_scale is None:
raise ValueError(
"QuantConfig has static quantization, but found "
"activation scales are None.")
if (not all_close_1d(self.a13_scale)
or not all_close_1d(self.a2_scale)):
print_warning_once(
"Found act_scales that are not equal for fp8 MoE layer. "
"Using the maximum across experts for each layer. ")
if (not all_close_1d(self.a13_scale)
or not all_close_1d(self.a2_scale)):
print_warning_once(
"Found input_scales that are not equal for "
"fp8 MoE layer. Using the maximum across experts "
"for each layer. ")
self.a13_scale = nn.Parameter(self.a13_scale.max(),
requires_grad=False)
self.a2_scale = nn.Parameter(self.a2_scale.max(),
requires_grad=False)
self.a13_scale = nn.Parameter(self.a13_scale.max(),
requires_grad=False)
self.a2_scale = nn.Parameter(self.a2_scale.max(),
requires_grad=False)
assert self.w13_scale is not None
shard_size = self.intermediate_size
max_w13_scales = self.w13_scale.max(dim=1).values
for expert_id in range(self.num_total_experts):
start = 0
for shard_id in range(2):
dq_weight = per_tensor_dequantize(
self.w13_weight[expert_id][start:start +
shard_size, :],
self.w13_scale[expert_id][shard_id])
self.w13_weight[expert_id][
start:start + shard_size, :] = per_tensor_quantize(
dq_weight, max_w13_scales[expert_id])
start += shard_size
self.w13_scale = nn.Parameter(max_w13_scales, requires_grad=False)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
num_tokens, hidden_size = hidden_states.shape
......@@ -278,15 +322,6 @@ class MixtralAttention(nn.Module):
self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
if isinstance(
quant_config,
Fp8Config) and not quant_config.is_checkpoint_fp8_serialized:
print_warning_once(
"For Mixtral FP8 quantization, we currently do not quantize "
"the attention layers until their FP8 performance is improved."
)
quant_config = None
self.qkv_proj = QKVParallelLinear(
hidden_size,
self.head_dim,
......@@ -541,7 +576,7 @@ class MixtralForCausalLM(nn.Module):
# These are the activation scales for the experts
# (param_name, weight_name, expert_id)
("a13_scale" if weight_name in ["w1", "w3"] else "a2_scale",
f"experts.{expert_id}.{weight_name}.act_scale", expert_id)
f"experts.{expert_id}.{weight_name}.input_scale", expert_id)
for expert_id in range(self.config.num_local_experts)
for weight_name in ["w1", "w2", "w3"]
]
......
......@@ -233,7 +233,7 @@ def _prepare_seq_groups(
logits = hidden_states[selected_token_indices]
"""
if sampling_params.prompt_logprobs:
if sampling_params.prompt_logprobs is not None:
selected_token_indices.extend(
range(model_output_idx, model_output_idx + prompt_logprob_len))
model_output_idx += prompt_logprob_len
......@@ -386,16 +386,18 @@ class SamplingTensors:
presence_penalties += [0] * prefill_len
frequency_penalties += [0] * prefill_len
repetition_penalties += [1] * prefill_len
prompt_tokens.extend([] for _ in range(prefill_len))
output_tokens.extend([] for _ in range(prefill_len))
if do_penalties:
prompt_tokens.extend([] for _ in range(prefill_len))
output_tokens.extend([] for _ in range(prefill_len))
if seq_group.do_sample:
sample_lens = len(seq_group.sample_indices)
assert sample_lens == len(seq_ids)
for seq_id in seq_ids:
seq_data = seq_group.seq_data[seq_id]
prompt_tokens.append(seq_data.prompt_token_ids)
output_tokens.append(seq_data.output_token_ids)
if do_penalties:
prompt_tokens.append(seq_data.prompt_token_ids)
output_tokens.append(seq_data.output_token_ids)
temperatures += [temperature] * len(seq_ids)
top_ps += [top_p] * len(seq_ids)
top_ks += [top_k] * len(seq_ids)
......@@ -443,18 +445,22 @@ class SamplingTensors:
# Note that the performance will be very bad without
# pinned memory.
pin_memory = is_pin_memory_available()
prompt_max_len = max([len(tokens) for tokens in prompt_tokens],
default=0)
prompt_padded_tokens = [
tokens + [vocab_size] * (prompt_max_len - len(tokens))
for tokens in prompt_tokens
]
output_max_len = max([len(tokens) for tokens in output_tokens],
default=0)
output_padded_tokens = [
tokens + [vocab_size] * (output_max_len - len(tokens))
for tokens in output_tokens
]
do_penalties = prompt_tokens or output_tokens
if do_penalties:
prompt_max_len = max([len(tokens) for tokens in prompt_tokens],
default=0)
prompt_padded_tokens = [
tokens + [vocab_size] * (prompt_max_len - len(tokens))
for tokens in prompt_tokens
]
output_max_len = max([len(tokens) for tokens in output_tokens],
default=0)
output_padded_tokens = [
tokens + [vocab_size] * (output_max_len - len(tokens))
for tokens in output_tokens
]
temperatures_t = torch.tensor(
temperatures,
......@@ -504,18 +510,22 @@ class SamplingTensors:
dtype=torch.long,
pin_memory=pin_memory,
)
prompt_tensor = torch.tensor(
prompt_padded_tokens,
device="cpu",
dtype=torch.long,
pin_memory=pin_memory,
)
output_tensor = torch.tensor(
output_padded_tokens,
device="cpu",
dtype=torch.long,
pin_memory=pin_memory,
)
if do_penalties:
prompt_tensor = torch.tensor(
prompt_padded_tokens,
device="cpu",
dtype=torch.long,
pin_memory=pin_memory,
)
output_tensor = torch.tensor(
output_padded_tokens,
device="cpu",
dtype=torch.long,
pin_memory=pin_memory,
)
else:
prompt_tensor = None
output_tensor = None
# need to transpose and make contiguous to
# copy the tensor correctly.
# [batch_size, n_seeds] -> [n_seeds, batch_size]
......@@ -538,6 +548,16 @@ class SamplingTensors:
extra_seeds_gpu = None
sampling_seeds_gpu = sampling_seeds_gpu[:num_base_seeds]
if do_penalties:
prompt_tokens_gpu = prompt_tensor.to(device=device,
non_blocking=True)
output_tokens_gpu = output_tensor.to(device=device,
non_blocking=True)
else:
empty_tensor = torch.empty(0, device=device, dtype=torch.long)
prompt_tokens_gpu = empty_tensor
output_tokens_gpu = empty_tensor
return cls(
temperatures=temperatures_t.to(device=device, non_blocking=True),
top_ps=top_ps_t.to(device=device, non_blocking=True),
......@@ -549,8 +569,8 @@ class SamplingTensors:
non_blocking=True),
repetition_penalties=repetition_penalties_t.to(device=device,
non_blocking=True),
prompt_tokens=prompt_tensor.to(device=device, non_blocking=True),
output_tokens=output_tensor.to(device=device, non_blocking=True),
prompt_tokens=prompt_tokens_gpu,
output_tokens=output_tokens_gpu,
sampling_seeds=sampling_seeds_gpu,
sample_indices=sample_indices_t.to(device=device,
non_blocking=True),
......
from .base import MultiModalData, MultiModalPlugin
from .registry import MULTIMODAL_REGISTRY, MultiModalRegistry
__all__ = [
"MultiModalData", "MultiModalPlugin", "MULTIMODAL_REGISTRY",
"MultiModalRegistry"
]
from abc import ABC, abstractmethod
from typing import (TYPE_CHECKING, Callable, Dict, Generic, Optional, Type,
TypeVar)
from vllm.config import ModelConfig, VisionLanguageConfig
from vllm.logger import init_logger
if TYPE_CHECKING:
import torch
from torch import nn
logger = init_logger(__name__)
class MultiModalData:
"""
Base class that contains multi-modal data.
To add a new modality, add a new file under ``multimodal`` directory.
In this new file, subclass :class:`~MultiModalData` and
:class:`~MultiModalPlugin`.
Finally, register the new plugin to
:const:`vllm.multimodal.MULTIMODAL_REGISTRY`.
This enables models to call :meth:`MultiModalRegistry.register_input` for
the new modality.
"""
pass
D = TypeVar("D", bound=MultiModalData)
N = TypeVar("N", bound=Type["nn.Module"])
MultiModalInputProcessor = Callable[[D, ModelConfig, VisionLanguageConfig],
Dict[str, "torch.Tensor"]]
"""Return a dictionary to be passed as keyword arguments to
:meth:`torch.nn.Module.forward`. This is similar in concept to tokenizers
and processors in HuggingFace Transformers."""
class MultiModalPlugin(ABC, Generic[D]):
"""
Base class that defines data processing logic for a specific modality.
In particular, we adopt a registry pattern to dispatch data processing
according to the model being used (considering that different models may
process the same data differently). This registry is in turn used by
:class:`~MultiModalRegistry` which acts at a higher level
(i.e., the modality of the data).
"""
@classmethod
def get_model_cls(cls, model_config: ModelConfig) -> Type["nn.Module"]:
# Avoid circular import
from vllm.model_executor.model_loader import get_model_architecture
return get_model_architecture(model_config)[0]
def __init__(self) -> None:
self._input_processors: Dict[Type["nn.Module"],
MultiModalInputProcessor[D]] = {}
@abstractmethod
def get_data_type(self) -> Type[D]:
"""
Get the modality (subclass of :class:`~MultiModalData`) served by
this plugin.
"""
raise NotImplementedError
@abstractmethod
def _default_input_processor(
self, data: D, model_config: ModelConfig,
vlm_config: VisionLanguageConfig) -> Dict[str, "torch.Tensor"]:
"""Return a dictionary to be passed as keyword arguments to
:meth:`torch.nn.Module.forward`. This is similar in concept to
tokenizers and processors in HuggingFace Transformers.
"""
raise NotImplementedError
def register_input_processor(self,
processor: Optional[
MultiModalInputProcessor[D]] = None):
"""
Register an input processor to a model class.
When the model receives input data that matches the modality served by
this plugin (see :meth:`get_data_type`), the provided input processor is
applied to preprocess the data. If `None` is provided, then the default
input processor is applied instead.
"""
def wrapper(model_cls: N) -> N:
if model_cls in self._input_processors:
logger.warning(
"Model class %s already has an input processor "
"registered to %s. It is overwritten by the new one.",
model_cls, self)
self._input_processors[model_cls] = processor \
or self._default_input_processor
return model_cls
return wrapper
def process_input(
self, data: D, model_config: ModelConfig,
vlm_config: VisionLanguageConfig) -> Dict[str, "torch.Tensor"]:
"""
Apply an input processor to a :class:`~MultiModalData` instance passed
to the model.
The model is identified by ``model_config``. ``vlm_config`` is
for compatibility purposes and may be merged into ``model_config``
in the near future.
"""
model_cls = self.get_model_cls(model_config)
processor = self._input_processors.get(model_cls)
if processor is None:
raise KeyError(f"No input processor in {self} is registered for "
f"model class {model_cls.__name__}.")
return processor(data, model_config, vlm_config)
from typing import Dict, Tuple, Type, Union
import torch
from PIL import Image
from vllm.config import ModelConfig, VisionLanguageConfig
from vllm.logger import init_logger
from vllm.sequence import SequenceData
from vllm.transformers_utils.image_processor import cached_get_image_processor
from .base import MultiModalData, MultiModalPlugin
logger = init_logger(__name__)
def _get_dummy_seq_data(seq_len: int,
vlm_config: VisionLanguageConfig) -> SequenceData:
# NOTE: We assume that <image> token is repeated `image_feature_size` times
# and then concatenated with the text prompt
# TODO: Enable other ways of inserting the image into the prompt
token_ids = [vlm_config.image_token_id] * vlm_config.image_feature_size
token_ids += [0] * (seq_len - vlm_config.image_feature_size)
return SequenceData(token_ids)
def _get_dummy_values(vlm_config: VisionLanguageConfig) -> torch.Tensor:
if vlm_config.image_processor is None:
values_dtype = torch.float16
else:
values_dtype = torch.uint8
return torch.zeros(vlm_config.image_input_shape, dtype=values_dtype)
def get_dummy_image_data(
seq_len: int,
model_config: ModelConfig,
vlm_config: VisionLanguageConfig,
) -> Tuple[SequenceData, MultiModalData]:
"""Standard dummy data factory for image data (to be used in
:meth:`vlm.multimodal.MultiModalRegistry.register_dummy_data`)."""
seq_data = _get_dummy_seq_data(seq_len, vlm_config)
values = _get_dummy_values(vlm_config)
config_input_type = vlm_config.image_input_type
ImageInputType = VisionLanguageConfig.ImageInputType
fake_mm_data: MultiModalData
if config_input_type == ImageInputType.PIXEL_VALUES:
fake_mm_data = ImagePixelData(values)
elif config_input_type == ImageInputType.IMAGE_FEATURES:
fake_mm_data = ImageFeatureData(values)
else:
raise NotImplementedError
return seq_data, fake_mm_data
class ImagePixelData(MultiModalData):
"""
The pixel data of an image. Can be one of:
- :class:``PIL.Image``: An image object. Requires that a HuggingFace
processor is available to the model.
- :class:``torch.Tensor``: The raw pixel data which is passed to the model
without additional pre-processing.
"""
def __init__(self, image: Union[Image.Image, torch.Tensor]) -> None:
if isinstance(image, Image.Image):
# So that this class can be created inside the Image context manager
image.load()
self.image = image
def __repr__(self) -> str:
image = self.image
if isinstance(image, Image.Image):
return f"{type(self).__name__}(image={image})"
return (f"{type(self).__name__}(image=torch.Tensor(shape="
f"{image.shape}, dtype={image.dtype}))")
class ImagePixelPlugin(MultiModalPlugin[ImagePixelData]):
def get_data_type(self) -> Type[ImagePixelData]:
return ImagePixelData
def _get_hf_image_processor(self, model_config: ModelConfig,
vlm_config: VisionLanguageConfig):
if vlm_config is None or vlm_config.image_processor is None:
return None
return cached_get_image_processor(
vlm_config.image_processor,
trust_remote_code=model_config.trust_remote_code,
revision=vlm_config.image_processor_revision,
)
def _default_input_processor(
self, data: ImagePixelData, model_config: ModelConfig,
vlm_config: VisionLanguageConfig) -> Dict[str, torch.Tensor]:
image = data.image
if isinstance(image, Image.Image):
image_processor = self._get_hf_image_processor(
model_config, vlm_config)
if image_processor is None:
raise RuntimeError("No HuggingFace processor is available"
"to process the image object")
try:
return image_processor.preprocess(image, return_tensors="pt") \
.to(model_config.dtype).data
except Exception:
logger.error("Failed to process image (%s)", image)
raise
elif isinstance(image, torch.Tensor):
pixel_values = image.to(model_config.dtype)
return {"pixel_values": pixel_values}
raise TypeError(f"Invalid image type: {type(image)}")
class ImageFeatureData(MultiModalData):
"""
The feature vector of an image, passed directly to the model.
This should be the output of the vision tower.
"""
def __init__(self, image_features: torch.Tensor) -> None:
self.image_features = image_features
def __repr__(self) -> str:
image_features = self.image_features
return (f"{type(self).__name__}(image_features=torch.Tensor(shape="
f"{image_features.shape}, dtype={image_features.dtype}))")
class ImageFeaturePlugin(MultiModalPlugin[ImageFeatureData]):
def get_data_type(self) -> Type[ImageFeatureData]:
return ImageFeatureData
def _default_input_processor(
self, data: ImageFeatureData, model_config: ModelConfig,
vlm_config: VisionLanguageConfig) -> Dict[str, torch.Tensor]:
image_features = data.image_features.to(model_config.dtype)
return {"image_features": image_features}
import functools
from typing import (TYPE_CHECKING, Any, Callable, Dict, Optional, Sequence,
Tuple, Type, TypeVar)
from vllm.config import ModelConfig, VisionLanguageConfig
from vllm.logger import init_logger
from .base import MultiModalData, MultiModalPlugin
from .image import (ImageFeatureData, ImageFeaturePlugin, ImagePixelData,
ImagePixelPlugin)
if TYPE_CHECKING:
import torch
from torch import nn
from vllm.sequence import SequenceData
logger = init_logger(__name__)
D = TypeVar("D", bound=MultiModalData)
N = TypeVar("N", bound=Type["nn.Module"])
MultiModalInputProcessor = Callable[[D, ModelConfig, VisionLanguageConfig],
Dict[str, "torch.Tensor"]]
MultiModalDummyFactory = Callable[[int, ModelConfig, VisionLanguageConfig],
Tuple["SequenceData", MultiModalData]]
class MultiModalRegistry:
"""
This registry is used by model runners to dispatch data processing
according to its modality and the target model.
"""
DEFAULT_PLUGINS = (ImageFeaturePlugin(), ImagePixelPlugin())
def __init__(self,
*,
plugins: Sequence[MultiModalPlugin[Any]] = DEFAULT_PLUGINS
) -> None:
self._plugins_by_data_type = {p.get_data_type(): p for p in plugins}
self._dummy_factories_by_model_type: Dict[Type["nn.Module"],
MultiModalDummyFactory] = {}
def register_plugin(self, plugin: MultiModalPlugin[Any]) -> None:
data_type = plugin.get_data_type()
if data_type in self._plugins_by_data_type:
logger.warning(
"A plugin is already registered for data type %s, "
"and will be overwritten by the new plugin %s.", data_type,
plugin)
self._plugins_by_data_type[data_type] = plugin
def _get_plugin_for_data_type(self, data_type: Type[MultiModalData]):
for typ in data_type.mro():
plugin = self._plugins_by_data_type.get(typ)
if plugin is not None:
return plugin
msg = f"Unknown multi-modal data type: {data_type}"
raise NotImplementedError(msg)
def register_dummy_data(self, factory: MultiModalDummyFactory):
"""
Register a dummy data factory to a model class.
During memory profiling, the provided function is invoked to create
dummy data to be inputted into the model. The modality and shape of
the dummy data should be an upper bound of what the model would receive
at inference time.
"""
def wrapper(model_cls: N) -> N:
if model_cls in self._dummy_factories_by_model_type:
logger.warning(
"Model class %s already has dummy data "
"registered to %s. It is overwritten by the new one.",
model_cls, self)
self._dummy_factories_by_model_type[model_cls] = factory
return model_cls
return wrapper
def dummy_data_for_profiling(self, seq_len: int, model_config: ModelConfig,
vlm_config: VisionLanguageConfig):
"""Create dummy data for memory profiling."""
model_cls = MultiModalPlugin.get_model_cls(model_config)
dummy_factory = self._dummy_factories_by_model_type.get(model_cls)
if dummy_factory is None:
msg = f"No dummy data defined for model class: {model_cls}"
raise NotImplementedError(msg)
return dummy_factory(seq_len, model_config, vlm_config)
def register_input(
self,
data_type: Type[D],
processor: Optional[MultiModalInputProcessor[D]] = None):
"""
Register an input processor for a specific modality to a model class.
See :meth:`MultiModalPlugin.register_input_processor` for more details.
"""
return self._get_plugin_for_data_type(data_type) \
.register_input_processor(processor)
def register_image_pixel_input(
self,
processor: Optional[
MultiModalInputProcessor[ImagePixelData]] = None):
"""
Register an input processor for image pixel data to a model class.
See :meth:`MultiModalPlugin.register_input_processor` for more details.
"""
return self.register_input(ImagePixelData, processor)
def register_image_feature_input(
self,
processor: Optional[
MultiModalInputProcessor[ImageFeatureData]] = None):
"""
Register an input processor for image feature data to a model class.
See :meth:`MultiModalPlugin.register_input_processor` for more details.
"""
return self.register_input(ImageFeatureData, processor)
def process_input(self, data: MultiModalData, model_config: ModelConfig,
vlm_config: VisionLanguageConfig):
"""
Apply an input processor to a :class:`~MultiModalData` instance passed
to the model.
See :meth:`MultiModalPlugin.process_input` for more details.
"""
return self._get_plugin_for_data_type(type(data)) \
.process_input(data, model_config, vlm_config)
def create_input_processor(self, model_config: ModelConfig,
vlm_config: VisionLanguageConfig):
"""
Create an input processor (see :meth:`process_input`) for a
specific model.
"""
return functools.partial(self.process_input,
model_config=model_config,
vlm_config=vlm_config)
MULTIMODAL_REGISTRY = MultiModalRegistry()
"""The global :class:`~MultiModalRegistry` which is used by model runners."""
import base64
from io import BytesIO
from typing import Optional, Union
import aiohttp
from PIL import Image
from vllm.config import ModelConfig
from vllm.envs import VLLM_IMAGE_FETCH_TIMEOUT
from vllm.multimodal.image import ImagePixelData
class ImageFetchAiohttp:
aiohttp_client: Optional[aiohttp.ClientSession] = None
@classmethod
def get_aiohttp_client(cls) -> aiohttp.ClientSession:
if cls.aiohttp_client is None:
timeout = aiohttp.ClientTimeout(total=VLLM_IMAGE_FETCH_TIMEOUT)
connector = aiohttp.TCPConnector()
cls.aiohttp_client = aiohttp.ClientSession(timeout=timeout,
connector=connector)
return cls.aiohttp_client
@classmethod
async def fetch_image(cls, image_url: str) -> Image.Image:
"""Load PIL image from a url or base64 encoded openai GPT4V format"""
if image_url.startswith('http'):
# Avoid circular import
from vllm import __version__ as VLLM_VERSION
client = cls.get_aiohttp_client()
headers = {"User-Agent": f"vLLM/{VLLM_VERSION}"}
async with client.get(url=image_url, headers=headers) as response:
response.raise_for_status()
image_raw = await response.read()
image = Image.open(BytesIO(image_raw))
# Only split once and assume the second part is the base64 encoded image
elif image_url.startswith('data:image'):
image = load_image_from_base64(image_url.split(',', 1)[1])
else:
raise ValueError("Invalid image url: A valid image url must start "
"with either 'data:image' or 'http'.")
return image
async def async_get_and_parse_image(image_url: str) -> ImagePixelData:
with await ImageFetchAiohttp.fetch_image(image_url) as image:
return ImagePixelData(image)
def encode_image_base64(image: Image.Image, format: str = 'JPEG') -> str:
"""encode image to base64 format."""
buffered = BytesIO()
if format == 'JPEG':
image = image.convert('RGB')
image.save(buffered, format)
return base64.b64encode(buffered.getvalue()).decode('utf-8')
def load_image_from_base64(image: Union[bytes, str]) -> Image.Image:
"""Load image from base64 format."""
return Image.open(BytesIO(base64.b64decode(image)))
# TODO(ywang96): move this to a model registry for preprocessing vision
# language prompts based on the model type.
def get_full_image_text_prompt(image_prompt: str, text_prompt: str,
config: ModelConfig) -> str:
"""Combine image and text prompts for vision language model depending on
the model architecture."""
if config.hf_config.model_type in ("llava", "llava_next"):
full_prompt = f"{image_prompt}\n{text_prompt}"
else:
raise ValueError(
f"Unsupported model type: {config.hf_config.model_type}")
return full_prompt
......@@ -5,6 +5,8 @@ from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import torch
from vllm.block import LogicalTokenBlock
from vllm.inputs import LLMInputs
from vllm.lora.request import LoRARequest
......@@ -12,8 +14,7 @@ from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
if TYPE_CHECKING:
import torch
from vllm.multimodal import MultiModalData
from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
......@@ -398,25 +399,6 @@ class SequenceGroupState:
generator: Optional = None # type: ignore
class MultiModalData:
"""Multi modal request.
Args:
type: The data type.
data: The actual data.
The required shape and semantic meaning of it depends on the vision
language config of the hosted model.
See `VisionLanguageConfig` in `config.py`.
"""
class Type(enum.Enum):
IMAGE = enum.auto()
def __init__(self, type: Type, data: "torch.Tensor"):
self.type = type
self.data = data
class SequenceGroup:
"""A group of sequences that are generated from the same prompt.
......@@ -473,7 +455,7 @@ class SequenceGroup:
return next(iter(self.seqs_dict.values())).prompt_token_ids
@property
def multi_modal_data(self) -> Optional[MultiModalData]:
def multi_modal_data(self) -> Optional["MultiModalData"]:
# All sequences in the group should have the same multi-modal data.
# We use the multi-modal data of an arbitrary sequence.
return next(iter(self.seqs_dict.values())).multi_modal_data
......@@ -655,7 +637,7 @@ class SequenceGroupMetadata:
lora_request: Optional[LoRARequest] = None,
computed_block_nums: Optional[List[int]] = None,
state: Optional[SequenceGroupState] = None,
multi_modal_data: Optional[MultiModalData] = None,
multi_modal_data: Optional["MultiModalData"] = None,
encoder_seq_data: Optional[SequenceData] = None,
cross_block_table: Optional[List[int]] = None,
) -> None:
......@@ -798,13 +780,13 @@ class SamplerOutput:
outputs: List[CompletionSequenceGroupOutput]
# On-device tensor containing probabilities of each token.
sampled_token_probs: Optional["torch.Tensor"] = None
sampled_token_probs: Optional[torch.Tensor] = None
# On-device tensor containing the logprobs of each token.
logprobs: Optional["torch.Tensor"] = None
# On-device tensor containing the sampled token ids.
sampled_token_ids: Optional["torch.Tensor"] = None
sampled_token_ids: Optional[torch.Tensor] = None
# Spec decode metrics populated by workers.
spec_decode_worker_metrics: Optional["SpecDecodeWorkerMetrics"] = None
......
......@@ -80,7 +80,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
target_sampler_output = self._scorer_worker.execute_model(
execute_model_req=execute_model_req.clone(
seq_group_metadata_list=target_seq_group_metadata_list, ))
seq_group_metadata_list=target_seq_group_metadata_list))
assert len(target_sampler_output) == 1, "expected single-step output"
target_sampler_output = target_sampler_output[0]
......@@ -140,8 +140,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
num_scoring_tokens)
def _contract_batch(
self, contracted_bs: int,
target_sampler_output: List[SamplerOutput],
self, contracted_bs: int, target_sampler_output: SamplerOutput,
proposals: SpeculativeProposals, num_scoring_tokens: int,
non_spec_indices: List[int], spec_indices: List[int],
k: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
......@@ -167,30 +166,16 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
non_spec_expanded_bs, _ = non_spec_target_token_ids.shape
spec_expanded_bs = expanded_batch_size - non_spec_expanded_bs
target_token_ids = target_token_ids.squeeze().reshape(
spec_expanded_bs, k + 1)
target_probs = target_probs.squeeze().reshape(spec_expanded_bs, k + 1,
self._vocab_size)
target_logprobs = target_logprobs.squeeze().reshape(
spec_expanded_bs, k + 1, self._vocab_size)
all_tokens = torch.full(size=(contracted_bs, k + 1),
fill_value=-1,
device=self._device,
dtype=torch.long)
all_probs = torch.zeros(contracted_bs,
k + 1,
self._vocab_size,
device=self._device,
dtype=torch.float32)
all_logprobs = torch.full(size=(
contracted_bs,
k + 1,
self._vocab_size,
),
fill_value=-float("inf"),
device=self._device,
dtype=torch.float32)
target_token_ids = target_token_ids.reshape(spec_expanded_bs, k + 1)
target_probs = target_probs.reshape(*target_token_ids.shape,
self._vocab_size)
target_logprobs = target_logprobs.reshape(target_probs.shape)
all_tokens = target_token_ids.new_full(size=(contracted_bs, k + 1),
fill_value=-1)
all_probs = target_probs.new_zeros(*all_tokens.shape, self._vocab_size)
all_logprobs = target_logprobs.new_full(size=all_probs.shape,
fill_value=-float("inf"))
if non_spec_indices:
all_tokens[non_spec_indices, :1] = non_spec_target_token_ids
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment