Commit fcfc474d authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.8.3' into v0.8.3-dev

parents bb94d2e5 296c6572
......@@ -592,9 +592,14 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
) -> torch.Tensor:
assert activation == "silu", "Only SiLU activation is supported."
# `apply_router_weight_on_input` is declared as `bool = False`, so it is never
# None. Test its truthiness so that only an explicit opt-in is rejected; the
# previous `is not None` check raised on every call, including the default.
if apply_router_weight_on_input:
    raise NotImplementedError(
        "Apply router weight on input is not supported for "
        "fused Marlin MoE method.")
# The input must currently be float16
orig_dtype = x.dtype
......
......@@ -3,6 +3,8 @@
import os
from typing import Dict, List, Optional, Type
from vllm.model_executor.layers.quantization.kernels.scaled_mm.aiter import (
AiterScaledMMLinearKernel)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.cutlass import (
CutlassScaledMMLinearKernel)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
......@@ -17,7 +19,7 @@ from vllm.platforms import PlatformEnum, current_platform
_POSSIBLE_KERNELS: Dict[PlatformEnum, List[Type[ScaledMMLinearKernel]]] = {
PlatformEnum.CPU: [CutlassScaledMMLinearKernel],
PlatformEnum.CUDA: [CutlassScaledMMLinearKernel],
PlatformEnum.ROCM: [TritonScaledMMLinearKernel],
PlatformEnum.ROCM: [AiterScaledMMLinearKernel, TritonScaledMMLinearKernel],
PlatformEnum.TPU: [XLAScaledMMLinearKernel],
}
......
# SPDX-License-Identifier: Apache-2.0
from typing import Optional, Tuple
import torch
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.platforms import current_platform
from .cutlass import CutlassScaledMMLinearKernel
from .ScaledMMLinearKernel import ScaledMMLinearLayerConfig
class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel):
    """Scaled-MM linear kernel that dispatches the int8 GEMM to AITER.

    Weight post-processing is inherited unchanged from
    `CutlassScaledMMLinearKernel`; this class only overrides the
    availability check and the GEMM call.
    """

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum capability value declared by this kernel."""
        return 90

    @classmethod
    def can_implement(
            cls, c: ScaledMMLinearLayerConfig) -> Tuple[bool, Optional[str]]:
        """Report whether this kernel can serve the layer config `c`.

        The checks run in this order: ROCm platform, `aiter` importable,
        both AITER environment switches enabled, symmetric input
        quantization.

        Returns:
            `(True, None)` when usable, otherwise `(False, reason)`.
        """
        requires_aiter = (
            "AiterScaledMMLinearKernel requires `aiter` which is not ")

        if not current_platform.is_rocm():
            return (False,
                    requires_aiter + "currently supported on non-ROCm platform.")

        try:
            import aiter  # noqa: F401 # deliberately attempt to import aiter
        except Exception:
            return (False, requires_aiter + "installed on ROCm.")

        # Check if rocm_aiter_gemm_w8a8_scaled_mm is enabled
        if (not envs.VLLM_ROCM_USE_AITER_LINEAR
                or not envs.VLLM_ROCM_USE_AITER):
            return (False, "AiterScaledMMLinearKernel is disabled. "
                    "Enable by setting `VLLM_ROCM_USE_AITER=1` "
                    "and `VLLM_ROCM_USE_AITER_LINEAR=1`. "
                    "`VLLM_ROCM_USE_AITER_LINEAR` default is True.")

        if not c.input_symmetric:
            return (False, "AiterScaledMMLinearKernel only supports symmetric "
                    "quantization.")

        return True, None

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Delegate weight post-processing to the parent kernel."""
        super().process_weights_after_loading(layer)

    def apply_weights(self,
                      layer: torch.nn.Module,
                      x: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        `AiterScaledMMLinearKernel` implements a fused version of
            `output = torch.mm((scale_a * a), (scale_b * b)).to(out_dtype)`
        where scale_a * a and scale_b * b are implemented using numpy-style
        broadcasting.

        Currently only per-tensor-per-tensor GEMM and per-token-per-channel
        GEMM are supported through the AITER w8a8 scaled gemm.
        `AiterScaledMMLinearKernel` also does not support AITER block scaled
        GEMM and mix-precision GEMM.

        Args:
            layer: Module holding the quantized weight parameters.
            x: Activations to quantize and multiply.
            bias: Optional bias; when given, its first dimension must equal
                the weight's second dimension and its dtype must match `x`.

        Returns:
            The GEMM result cast to the dtype of `x`.
        """
        weight, weight_scale, input_scale, input_zp, azp_adj = (
            self._get_weight_params(layer))

        # ops.scaled_int8_quant supports both dynamic and static quant:
        # * dynamic, input_scale is None and x_s is computed from x.
        # * static, input_scale is scalar and x_s is input_scale.
        is_symmetric = azp_adj is None
        assert is_symmetric, (
            "AiterScaledMMLinearKernel only supports symmetric quantization.")
        x_q, x_s, x_zp = ops.scaled_int8_quant(x,
                                               input_scale,
                                               input_zp,
                                               symmetric=is_symmetric)
        assert x_zp is None, (
            "AiterScaledMMLinearKernel only supports symmetric quantization.")

        out_dtype = x.dtype
        assert all(weight.shape[axis] % 16 == 0 for axis in (0, 1))
        assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
        assert bias is None or (bias.shape[0] == weight.shape[1]
                                and bias.dtype == out_dtype)

        num_rows_a = x_q.shape[0]
        num_cols_b = weight.shape[1]
        act_scale_count = x_s.numel()
        weight_scale_count = weight_scale.numel()

        both_per_tensor = act_scale_count == 1 and weight_scale_count == 1
        token_and_channel = (act_scale_count == num_rows_a
                             and weight_scale_count == num_cols_b)

        # @TODO:
        # Maybe broadcast the per-tensor-scale into per-channel-scale
        # if one of the scale is a per-channel-scale.
        # For now, it only supports:
        # - per-tensor-per-tensor a8w8 scaled GEMM, and
        # - per-token-per-channel a8w8 scaled GEMM
        assert both_per_tensor or token_and_channel, (
            "Currently only support per-tensor-per-tensor GEMM "
            " and per-token-per-channel GEMM through AITER"
            " w8a8 scaled gemm. `AiterScaledMMLinearKernel` "
            "does not support AITER block scaled GEMM.")

        from aiter import gemm_a8w8_CK

        # gemm_a8w8_CK(a, b, scale_a, scale_b, bias) expects
        # a to be [M, K]
        # b to be [N, K]
        # CutlassScaledMMLinearKernel prepares the weight in [K, N] format,
        # hence the transpose.
        result = gemm_a8w8_CK(x_q, weight.t(), x_s, weight_scale, bias)
        return result.to(out_dtype)
......@@ -97,7 +97,8 @@ class XLAScaledMMLinearKernel(ScaledMMLinearKernel):
block_size=-1,
int4_weight=False,
quantize_activation=True)
# `quantized_matmul` output is fp32, cast it down to bf16 for perf
out = out.to(x.dtype)
# Explicitly capture control flow to make dynamo happy.
# https://pytorch.org/docs/main/generated/exportdb/index.html#cond-branch-class-method # noqa: E501
return cond(bias is None, self.no_add_bias, self.add_bias, [out, bias])
......@@ -293,6 +293,7 @@ class MoeWNA16Method(FusedMoEMethodBase):
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
use_nn_moe: Optional[bool] = False,
moe_ep_size: Optional[int] = None,
......@@ -316,23 +317,24 @@ class MoeWNA16Method(FusedMoEMethodBase):
weight_bits = self.quant_config.weight_bits
has_zp = self.quant_config.has_zp
return fused_experts(x,
layer.w13_qweight,
layer.w2_qweight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
use_int4_w4a16=weight_bits == 4,
use_int8_w8a16=weight_bits == 8,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=layer.w13_scales,
w2_scale=layer.w2_scales,
w1_zp=layer.w13_qzeros if has_zp else None,
w2_zp=layer.w2_qzeros if has_zp else None,
block_shape=[0, layer.group_size],
use_nn_moe=False,
)
return fused_experts(
x,
layer.w13_qweight,
layer.w2_qweight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
use_int4_w4a16=weight_bits == 4,
use_int8_w8a16=weight_bits == 8,
global_num_experts=global_num_experts,
apply_router_weight_on_input=apply_router_weight_on_input,
expert_map=expert_map,
w1_scale=layer.w13_scales,
w2_scale=layer.w2_scales,
w1_zp=layer.w13_qzeros if has_zp else None,
w2_zp=layer.w2_qzeros if has_zp else None,
block_shape=[0, layer.group_size],
use_nn_moe=False)
@staticmethod
def get_weight_loader(layer, weight_loader):
......
......@@ -202,6 +202,8 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe import fused_experts
......@@ -217,16 +219,18 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias)
return fused_experts(x,
layer.w13_weight,
layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
use_fp8_w8a8=True,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale)
return fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
use_fp8_w8a8=True,
global_num_experts=global_num_experts,
apply_router_weight_on_input=apply_router_weight_on_input,
expert_map=expert_map,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale)
......@@ -22,6 +22,7 @@ class QuarkW8A8Fp8(QuarkScheme):
self.qscheme = qscheme
self.is_static_input_scheme = is_static_input_scheme
self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=True)
self.out_dtype = torch.get_default_dtype()
@classmethod
def get_min_capability(cls) -> int:
......@@ -134,5 +135,6 @@ class QuarkW8A8Fp8(QuarkScheme):
return self.fp8_linear.apply(input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
out_dtype=self.out_dtype,
input_scale=layer.input_scale,
bias=bias)
......@@ -51,6 +51,16 @@ def cutlass_block_fp8_supported() -> bool:
return ops.cutlass_scaled_mm_supports_block_fp8(capability)
def cutlass_group_gemm_supported() -> bool:
    """Return whether the CUTLASS group GEMM op reports support here.

    Always False off CUDA. On CUDA the device capability is converted to an
    int (or -1 when the platform reports no capability) and handed to
    `ops.cutlass_group_gemm_supported`.
    """
    if current_platform.is_cuda():
        device_capability = current_platform.get_device_capability()
        if device_capability is None:
            capability = -1
        else:
            capability = device_capability.to_int()
        return ops.cutlass_group_gemm_supported(capability)

    return False
CUTLASS_FP8_SUPPORTED = cutlass_fp8_supported()
CUTLASS_BLOCK_FP8_SUPPORTED = cutlass_block_fp8_supported()
......@@ -154,6 +164,7 @@ class Fp8LinearOp:
input: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
out_dtype: Optional[torch.dtype] = None,
input_scale: Optional[torch.Tensor] = None,
input_scale_ub: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
......@@ -173,8 +184,13 @@ class Fp8LinearOp:
if use_per_token_if_dynamic is None:
use_per_token_if_dynamic = self.use_per_token_if_dynamic
if out_dtype is None:
out_dtype = input.dtype
# cutlass_scaled_mm supports per tensor/channel W and per tensor/token A
if self.cutlass_fp8_supported:
assert input.dtype != current_platform.fp8_dtype(
), "FP8 input to cutlass is not currently implemented"
qinput, x_scale = ops.scaled_fp8_quant(
input_2d,
input_scale,
......@@ -184,7 +200,7 @@ class Fp8LinearOp:
# Fused GEMM_DQ
output = ops.cutlass_scaled_mm(qinput,
weight,
out_dtype=input.dtype,
out_dtype=out_dtype,
scale_a=x_scale,
scale_b=weight_scale,
bias=bias)
......@@ -193,12 +209,15 @@ class Fp8LinearOp:
# torch.scaled_mm supports per tensor weights + activations only
# so fallback to naive if per channel or per token
else:
# Maybe apply padding to output, see comment in __init__
qinput, x_scale = ops.scaled_fp8_quant(
input_2d,
input_scale,
num_token_padding=self.output_padding,
use_per_token_if_dynamic=use_per_token_if_dynamic)
if input.dtype != current_platform.fp8_dtype():
# Maybe apply padding to output, see comment in __init__
qinput, x_scale = ops.scaled_fp8_quant(
input_2d,
input_scale,
num_token_padding=self.output_padding,
use_per_token_if_dynamic=use_per_token_if_dynamic)
else:
qinput, x_scale = input_2d, input_scale
per_tensor_weights = (weight_scale.numel() == 1)
per_tensor_activations = (x_scale.numel() == 1)
......@@ -207,7 +226,7 @@ class Fp8LinearOp:
# Fused GEMM_DQ
output = torch._scaled_mm(qinput,
weight,
out_dtype=input.dtype,
out_dtype=out_dtype,
scale_a=x_scale,
scale_b=weight_scale,
bias=bias)
......@@ -231,7 +250,7 @@ class Fp8LinearOp:
# Fused GEMM_DQ Rowwise GEMM
output = torch._scaled_mm(qinput,
weight,
out_dtype=input.dtype,
out_dtype=out_dtype,
scale_a=x_scale,
scale_b=weight_scale.t(),
bias=bias)
......
......@@ -851,6 +851,70 @@ class Llama3RotaryEmbedding(RotaryEmbedding):
return new_freqs
class Llama4VisionRotaryEmbedding(RotaryEmbedding):
    """Rotary embedding over a square grid of image patches.

    `max_position_embeddings` is interpreted as the number of image patches,
    and the cache is stored as a complex tensor that `forward` multiplies
    into the complex view of query and key.
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
        dtype: torch.dtype,
    ):
        """Forward every argument unchanged to `RotaryEmbedding.__init__`."""
        super().__init__(head_size, rotary_dim, max_position_embeddings, base,
                         is_neox_style, dtype)

    def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
        """Return the first `rotary_dim // 2` parent inverse frequencies."""
        keep = self.rotary_dim // 2
        return super()._compute_inv_freq(base)[:keep]

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        """Build the complex rotation cache for every patch plus one extra row.

        Returns:
            A complex tensor of shape `(num_patches + 1, 1, F)`, where `F` is
            twice the number of inverse frequencies. The extra last row
            (index -2, the CLS token id) has all angles zeroed.
        """
        inv_freq = self._compute_inv_freq(self.base)

        # self.max_position_embeddings here is number of image patches
        # i.e. (image_size // patch_size) ** 2
        num_patches = self.max_position_embeddings
        patch_ids = torch.arange(num_patches, dtype=torch.int32)
        patch_ids = patch_ids.reshape(num_patches, 1)
        # Append one more row and mark it with a negative id so it can be
        # masked out below.
        patch_ids = torch.cat([patch_ids, patch_ids[:1]], dim=0)
        patch_ids[-1, -1] = -2  # set to ID_CLS_TOKEN
        grid_side = int(math.sqrt(num_patches))

        def axis_freqs(coords: torch.Tensor) -> torch.Tensor:
            # (coord + 1) * inv_freq, with every frequency duplicated once.
            angles = (coords + 1)[..., None] * inv_freq[None, None, :]
            return angles.repeat_interleave(2, dim=-1)

        freqs_x = axis_freqs(patch_ids % grid_side)
        freqs_y = axis_freqs(patch_ids // grid_side)

        freqs = torch.cat([freqs_x, freqs_y], dim=-1)
        freqs = freqs.float().contiguous()[..., ::2]
        freqs = freqs.masked_fill(patch_ids.reshape(-1, 1, 1) < 0, 0)

        cos_sin = torch.stack([torch.cos(freqs), torch.sin(freqs)], dim=-1)
        return torch.view_as_complex(cos_sin)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Rotate `query` and `key` by the cached complex frequencies.

        The cache is moved to the device of `query` and kept there. Both
        inputs are viewed as complex pairs along their last dimension,
        multiplied by the cache (broadcast over every axis except axis 1 and
        the last one), and returned in their original dtypes.
        """
        self.cos_sin_cache = self.cos_sin_cache.to(query.device)

        def as_complex(t: torch.Tensor) -> torch.Tensor:
            return torch.view_as_complex(t.float().reshape(
                *t.shape[:-1], -1, 2))

        query_c = as_complex(query)
        key_c = as_complex(key)

        last_axis = query_c.ndim - 1
        broadcast_shape = [
            size if axis in (1, last_axis) else 1
            for axis, size in enumerate(query_c.shape)
        ]
        freqs_ci = self.cos_sin_cache.view(*broadcast_shape)

        query_out = torch.view_as_real(query_c * freqs_ci).flatten(3)
        key_out = torch.view_as_real(key_c * freqs_ci).flatten(3)
        return query_out.type_as(query), key_out.type_as(key)
class MRotaryEmbedding(RotaryEmbedding):
"""Rotary Embedding with Multimodal Sections."""
......@@ -1130,6 +1194,10 @@ def get_rope(
scaling_factor, low_freq_factor,
high_freq_factor,
original_max_position)
elif scaling_type == "mllama4":
rotary_emb = Llama4VisionRotaryEmbedding(head_size, rotary_dim,
max_position, base,
is_neox_style, dtype)
elif scaling_type == "default":
if "mrope_section" in rope_scaling:
rotary_emb = MRotaryEmbedding(
......
......@@ -250,7 +250,7 @@ class VocabParallelEmbedding(torch.nn.Module):
# If we are making an embedding layer, then our quantization linear
# method must implement the embedding operation. If we are another
# layer type like ParallelLMHead, this is not important.
is_embedding_layer = type(self.__class__) is VocabParallelEmbedding
is_embedding_layer = type(self) is VocabParallelEmbedding
quant_method_implements_embedding = method_has_implemented_embedding(
type(quant_method))
if is_embedding_layer and not quant_method_implements_embedding:
......
......@@ -1261,6 +1261,8 @@ class BitsAndBytesModelLoader(BaseModelLoader):
pack_ratio)
offsets = np.concatenate(([0], np.cumsum(num_elements)))
# Make torch infer_schema happy
offsets = torch.tensor(offsets).cpu()
set_weight_attrs(param, {"bnb_shard_offsets": offsets})
if load_8bit:
......
......@@ -37,16 +37,13 @@ def is_transformers_impl_compatible(
mod = module or getattr(transformers, arch, None)
if mod is None:
return False
if hasattr(mod, "supports_backend"):
return mod.is_backend_compatible()
else:
return mod._supports_flex_attn
return mod.is_backend_compatible()
def resolve_transformers_fallback(model_config: ModelConfig,
architectures: list[str]):
def resolve_transformers_arch(model_config: ModelConfig,
architectures: list[str]):
for i, arch in enumerate(architectures):
if arch == "TransformersModel":
if arch == "TransformersForCausalLM":
continue
auto_map: dict[str, str] = getattr(model_config.hf_config, "auto_map",
None) or dict()
......@@ -70,7 +67,7 @@ def resolve_transformers_fallback(model_config: ModelConfig,
raise ValueError(
f"The Transformers implementation of {arch} is not "
"compatible with vLLM.")
architectures[i] = "TransformersModel"
architectures[i] = "TransformersForCausalLM"
if model_config.model_impl == ModelImpl.AUTO:
if not is_transformers_impl_compatible(arch, custom_model_module):
raise ValueError(
......@@ -81,7 +78,7 @@ def resolve_transformers_fallback(model_config: ModelConfig,
"%s has no vLLM implementation, falling back to Transformers "
"implementation. Some features may not be supported and "
"performance may not be optimal.", arch)
architectures[i] = "TransformersModel"
architectures[i] = "TransformersForCausalLM"
return architectures
......@@ -140,8 +137,7 @@ def get_model_architecture(
for arch in architectures)
if (not is_vllm_supported
or model_config.model_impl == ModelImpl.TRANSFORMERS):
architectures = resolve_transformers_fallback(model_config,
architectures)
architectures = resolve_transformers_arch(model_config, architectures)
model_cls, arch = ModelRegistry.resolve_model_cls(architectures)
if model_config.task == "embed":
......
......@@ -99,16 +99,17 @@ def _create_pooling_model_cls(
mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})
weights = mapper.apply(weights)
self.model.load_weights(weights)
return
loaded_params = self.model.load_weights(weights)
loaded_params = {f"model.{name}" for name in loaded_params}
return loaded_params
# For most other models
if hasattr(orig_cls, "load_weights"):
orig_cls.load_weights(self, weights) # type: ignore
return orig_cls.load_weights(self, weights) # type: ignore
# Fallback
else:
loader = AutoWeightsLoader(self)
loader.load_weights(weights)
return loader.load_weights(weights)
return ModelForPooling # type: ignore
......
# SPDX-License-Identifier: Apache-2.0 Adapted from
# https://github.com/huggingface/transformers/tree/main/src/transformers/models/aya_vision
from functools import cached_property
from typing import (Iterable, Literal, Mapping, Optional, Sequence, Set, Tuple,
TypedDict, Union, cast)
import torch
from torch import nn
from transformers import BatchFeature, GotOcr2ImageProcessor
from transformers.activations import ACT2FN
from transformers.image_processing_utils import get_size_dict
from transformers.models.aya_vision import AyaVisionConfig
from transformers.models.aya_vision.processing_aya_vision import (
AyaVisionProcessor)
from transformers.models.got_ocr2.image_processing_got_ocr2 import (
get_optimal_tiled_canvas)
from vllm.config import VllmConfig
from vllm.jsontree import json_map_leaves
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalKwargs
from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo,
MultiModalFieldConfig,
PromptReplacement, PromptUpdate,
encode_tokens)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .siglip import SiglipVisionModel
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings)
from .vision import scatter_patch_features, select_patch_features
class AyaVisionImagePixelInputs(TypedDict):
    """Typed dictionary carrying Aya Vision image inputs as pixel tensors."""

    type: Literal["pixel_values"]

    # Shape: `(num_patches_total, num_channels, height, width)`
    #
    # `num_patches_total` is the total number of patches over each image over
    # each prompt in the batch.
    pixel_values: torch.Tensor

    # Shape: `(batch_size * num_images)`
    num_patches: torch.Tensor

    # A boolean mask indicating which image embeddings correspond to patch
    # tokens.
    #
    # Shape: `(batch_size * num_images, num_embeds)`
    embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
class AyaVisionMultiModalProjector(nn.Module):
    """Project vision-tower features into the text model's hidden size.

    Pipeline: pixel shuffle -> LayerNorm -> Linear -> SwiGLU -> Linear.
    """

    def __init__(self, config: AyaVisionConfig):
        """Create the projector layers from `config`.

        `alignment_intermediate_size` falls back to the text hidden size
        when the config does not define it.
        """
        super().__init__()
        self.config = config
        self.downsample_factor = config.downsample_factor
        self.alignment_intermediate_size = getattr(
            config, "alignment_intermediate_size",
            config.text_config.hidden_size)

        # Width of the features after pixel shuffle folds a
        # downsample_factor x downsample_factor neighbourhood into channels.
        shuffled_dim = (config.vision_config.hidden_size *
                        (config.downsample_factor**2))

        self.layernorm = nn.LayerNorm(shuffled_dim,
                                      eps=config.adapter_layer_norm_eps)
        self.linear_1 = nn.Linear(shuffled_dim,
                                  self.alignment_intermediate_size,
                                  bias=True)

        self.act = ACT2FN["silu"]  # SwiGLU uses SiLU activation
        # For SwiGLU, project down to half size since we split intermediate dim
        self.linear_2 = nn.Linear(self.alignment_intermediate_size // 2,
                                  config.text_config.hidden_size,
                                  bias=True)

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        """Return the projected features for `image_features` (B, S, D)."""
        shuffled = self.pixel_shuffle(image_features)
        hidden_states = self.linear_1(self.layernorm(shuffled))

        # Split along last dimension and apply SwiGLU
        x, gate = hidden_states.chunk(2, dim=-1)
        return self.linear_2(self.act(gate) * x)

    def pixel_shuffle(self,
                      image_features: torch.Tensor) -> torch.Tensor:  # B, S, D
        """Fold spatial neighbourhoods of a square token grid into channels.

        The sequence axis is treated as a `sqrt(S) x sqrt(S)` grid. Each
        spatial side is divided by `downsample_factor`, and the channel axis
        grows by `downsample_factor ** 2`.
        """
        batch_size, seq_length, _ = image_features.shape
        height = width = int(seq_length**0.5)
        factor = self.downsample_factor

        grid = image_features.reshape(image_features.shape[0], width, height,
                                      -1)
        channels = grid.shape[-1]

        grid = grid.reshape(batch_size, width, int(height / factor),
                            int(channels * factor))
        grid = grid.permute(0, 2, 1, 3)

        grid = grid.reshape(batch_size, int(height / factor),
                            int(width / factor), -1)
        return grid.permute(0, 2, 1, 3)
class AyaVisionProcessingInfo(BaseProcessingInfo):
    """Processing metadata for Aya Vision: config, processor, token budgets."""

    def get_hf_config(self) -> AyaVisionConfig:
        """Return the HF config, requested from the context as AyaVisionConfig."""
        return self.ctx.get_hf_config(AyaVisionConfig)

    def get_hf_processor(self, **kwargs: object) -> AyaVisionProcessor:
        """Return the HF `AyaVisionProcessor`, forwarding `kwargs`."""
        return self.ctx.get_hf_processor(AyaVisionProcessor, **kwargs)

    def get_image_processor(self) -> GotOcr2ImageProcessor:
        """Return the image processor attached to the HF processor."""
        processor = self.get_hf_processor()
        return processor.image_processor

    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int]:
        """Return the per-item token ceiling; only `"image"` is reported."""
        return {"image": self.get_max_image_tokens()}

    def get_max_image_tokens(self) -> int:
        """Return the token count of the prompt text for the largest image.

        The largest supported image size is converted to a patch count, the
        HF processor expands it into its image prompt string, and that
        string is tokenized without special tokens.
        """
        hf_processor = self.get_hf_processor()
        image_processor = hf_processor.image_processor
        largest = self.get_image_size_with_most_features()

        num_patches = self.get_num_patches(
            image_width=largest.width,
            image_height=largest.height,
            size=image_processor.size,
            min_patches=image_processor.min_patches,
            max_patches=image_processor.max_patches)

        image_string = hf_processor._prompt_split_image(num_patches)
        token_ids = encode_tokens(
            hf_processor.tokenizer,
            image_string,
            add_special_tokens=False,
        )
        return len(token_ids)

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        """Return the per-modality item limits (`None` for images)."""
        return {"image": None}

    def get_image_size_with_most_features(self) -> ImageSize:
        """Return the tile size scaled by `max_patches` on both sides."""
        image_processor = self.get_image_processor()
        tile = image_processor.size
        tile_height = tile['height']
        tile_width = tile['width']
        scale = image_processor.max_patches
        return ImageSize(height=tile_height * scale, width=tile_width * scale)

    def get_num_patches(self, *, image_width: int, image_height: int,
                        size: dict, min_patches: int, max_patches: int) -> int:
        """
        Calculate the number of patches needed for a given image based on size
        constraints.  This method replicates and adjusts the logic from:
        transformers/models/got_ocr2/image_processing_got_ocr2

        Returns the tile count of the optimal canvas, plus one whenever the
        canvas has more than one tile (presumably for an extra overview
        tile; confirm against the HF image processor).
        """
        tile_size = get_size_dict(size, default_to_square=False)
        num_columns, num_rows = get_optimal_tiled_canvas(
            (image_height, image_width),
            (tile_size["height"], tile_size["width"]), min_patches,
            max_patches)

        num_blocks = num_columns * num_rows
        if num_blocks == 1:
            return num_blocks
        return num_blocks + 1
class AyaVisionDummyInputsBuilder(
        BaseDummyInputsBuilder[AyaVisionProcessingInfo]):
    """Builds dummy Aya Vision inputs from the largest supported image size."""

    def get_dummy_processor_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        """Return a prompt of repeated image tokens plus matching dummy images.

        The number of images comes from `mm_counts["image"]` (0 if absent);
        every dummy image uses the size with the most features.
        """
        image_token = self.info.get_hf_processor().image_token
        num_images = mm_counts.get("image", 0)
        largest = self.info.get_image_size_with_most_features()

        dummy_images = self._get_dummy_images(width=largest.width,
                                              height=largest.height,
                                              num_images=num_images)

        return ProcessorInputs(
            prompt_text=image_token * num_images,
            mm_data={"image": dummy_images},
        )
class AyaVisionMultiModalProcessor(
        BaseMultiModalProcessor[AyaVisionProcessingInfo]):
    """Multimodal processor that adds the patch metadata vLLM needs."""

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF processor, then attach `num_patches`/`embed_is_patch`.

        The extra fields are only added when `mm_data` has an `"images"`
        entry and the prompt contains `<image>`.
        """
        processed_outputs = super()._call_hf_processor(
            prompt,
            mm_data,
            mm_kwargs,
        )
        hf_processor = self.info.get_hf_processor(**mm_kwargs)
        image_processor = hf_processor.image_processor
        hf_config = self.info.get_hf_config()

        # HF processor pops the `num_patches` kwarg, which is needed by vLLM
        images = mm_data.get("images")
        if images is None or '<image>' not in prompt:
            return processed_outputs

        assert isinstance(images, list)
        parsed_images = self._get_data_parser().parse_mm_data({
            "image": images
        }).get_items("image", ImageProcessorItems)

        image_sizes = [
            parsed_images.get_image_size(idx)
            for idx in range(len(parsed_images))
        ]
        num_patches = [
            self.info.get_num_patches(
                image_width=size.width,
                image_height=size.height,
                size=image_processor.size,
                min_patches=image_processor.min_patches,
                max_patches=image_processor.max_patches)
            for size in image_sizes
        ]

        # Expand each patch count into its prompt text, tokenize it, and
        # mark which of the resulting tokens are image placeholder tokens.
        prompt_texts = [
            hf_processor._prompt_split_image(count) for count in num_patches
        ]
        tokenizer = self.info.get_tokenizer()
        embed_is_patch = []
        for text in prompt_texts:
            token_ids = tokenizer.encode(text, add_special_tokens=False)
            embed_is_patch.append(
                torch.tensor(token_ids) == hf_config.image_token_index)

        processed_outputs["embed_is_patch"] = embed_is_patch
        processed_outputs["num_patches"] = torch.tensor(num_patches)
        return processed_outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Describe how each processed field maps onto image items.

        `pixel_values` is flat and split by `num_patches`; the other fields
        are batched per image.
        """
        num_patches = hf_inputs.get("num_patches", torch.empty(0))

        return {
            "pixel_values":
            MultiModalFieldConfig.flat_from_sizes("image", num_patches),
            "num_patches":
            MultiModalFieldConfig.batched("image"),
            "embed_is_patch":
            MultiModalFieldConfig.batched("image"),
            "image_embeds":
            MultiModalFieldConfig.batched("image"),
        }

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> Sequence[PromptUpdate]:
        """Return the replacement of each image token with its expansion."""
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        image_processor = hf_processor.image_processor

        def get_replacement(item_idx: int):
            # Recompute the patch count for this image and let the HF
            # processor build the matching prompt text.
            items: ImageProcessorItems = mm_items.get("image",
                                                      ImageProcessorItems)
            size: ImageSize = items.get_image_size(item_idx)
            patch_count = self.info.get_num_patches(
                image_width=size.width,
                image_height=size.height,
                size=image_processor.size,
                min_patches=image_processor.min_patches,
                max_patches=image_processor.max_patches)
            return hf_processor._prompt_split_image(num_patches=patch_count)

        image_update = PromptReplacement(
            modality="image",
            target=hf_processor.image_token,
            replacement=get_replacement,
        )
        return [image_update]
def _get_num_hidden_layers(hf_config: AyaVisionConfig) -> int:
feature_layers = hf_config.vision_feature_layer
num_hidden_layers = hf_config.vision_config.num_hidden_layers
# If we have one feature layer, initialize up to that layer
if isinstance(feature_layers, int):
return _get_layer_index(feature_layers, num_hidden_layers)
# If we have multiple feature layers, initialize up to the deepest m
elif isinstance(feature_layers, (list, tuple)):
return max(
_get_layer_index(idx, num_hidden_layers) for idx in feature_layers)
raise TypeError(f"vision_layer_feature type: {type(feature_layers)}"
" is not supported")
def _get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int:
if feature_layer_index < 0:
return num_hidden_layers + feature_layer_index + 1
return feature_layer_index
@MULTIMODAL_REGISTRY.register_processor(
AyaVisionMultiModalProcessor,
info=AyaVisionProcessingInfo,
dummy_inputs=AyaVisionDummyInputsBuilder)
class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
    """Build the vision tower, the projector and the language model.

    The vision tower is created with only as many hidden layers as the
    configured feature layer(s) require.
    """
    super().__init__()
    model_config = vllm_config.model_config
    config: AyaVisionConfig = model_config.hf_config

    self.config = config
    self.quant_config = vllm_config.quant_config
    self.multimodal_config = model_config.multimodal_config

    self.vision_tower = SiglipVisionModel(
        config.vision_config,
        self.quant_config,
        num_hidden_layers_override=_get_num_hidden_layers(config),
        prefix=maybe_prefix(prefix, "vision_model"))
    self.vocab_size = config.text_config.vocab_size
    self.multi_modal_projector = AyaVisionMultiModalProjector(config)
    self.language_model = init_vllm_registered_model(
        vllm_config=vllm_config,
        hf_config=config.text_config,
        prefix=maybe_prefix(prefix, "model"),
        # Cohere2ForCausalLM and CohereForCausalLM are the same on vllm
        architectures=["Cohere2ForCausalLM"])
@property
def dtype(self):
    """Return the dtype of the first parameter yielded by `parameters()`."""
    first_param = next(self.parameters())
    return first_param.dtype
def load_weights(self, weights: Iterable[Tuple[str,
                                               torch.Tensor]]) -> Set[str]:
    """Load `weights` via `AutoWeightsLoader` and return the loader's result."""
    return AutoWeightsLoader(self).load_weights(weights)
def _image_pixels_to_features(self, vision_tower: SiglipVisionModel,
                              pixel_values: torch.Tensor,
                              **kwargs) -> torch.Tensor:
    """Run the vision tower and apply the configured feature selection.

    `pixel_values` is first cast to the dtype of the tower's input
    embedding weight. The selection strategy is then applied to every leaf
    of the tower output.
    """
    embed_dtype = vision_tower.get_input_embeddings().weight.dtype
    tower_output = vision_tower(pixel_values.to(dtype=embed_dtype), **kwargs)

    selected = json_map_leaves(
        lambda leaf: self._select_image_features(
            leaf,
            strategy=self.config.vision_feature_select_strategy,
        ),
        tower_output,
    )
    return cast(Union[torch.Tensor, tuple[torch.Tensor, ...]], selected)
def _select_image_features(self, image_features: torch.Tensor, *,
strategy: str) -> torch.Tensor:
if strategy == "default":
return image_features[:, 1:]
elif strategy == "full":
return image_features
raise ValueError(f"Unexpected select feature strategy: {strategy}")
def _process_image_input(self, image_input: AyaVisionImagePixelInputs,
**kwargs) -> list[torch.Tensor]:
assert self.vision_tower is not None
pixel_values = image_input["pixel_values"]
num_patches = image_input["num_patches"]
image_features = self._image_pixels_to_features(
self.vision_tower, pixel_values=pixel_values)
image_embeds = self.multi_modal_projector(image_features)
return [
e.flatten(0, 2) for e in image_embeds.split(num_patches.tolist())
]
def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
h = w = self.config.vision_config.image_size
expected_dims = (3, h, w)
def _validate_shape(d: torch.Tensor):
if d.shape != expected_dims:
raise ValueError(
"The expected shape of pixel values per image per batch "
f"is {expected_dims}. You supplied {tuple(d.shape)}.")
for d in data:
_validate_shape(d)
return data
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[AyaVisionImagePixelInputs]:
pixel_values = kwargs.pop("pixel_values", None)
num_patches = kwargs.pop("num_patches", None)
embed_is_patch = kwargs.pop("embed_is_patch", None)
image_embeds = kwargs.pop("image_embeds", None)
assert image_embeds is None, "Aya Vision does not support image_embeds."
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
if num_patches is not None and not isinstance(num_patches,
(torch.Tensor, list)):
raise ValueError("Incorrect type of num_patches. "
f"Got type: {type(num_patches)}")
if not isinstance(embed_is_patch, (torch.Tensor, list)):
raise ValueError("Incorrect type of embed_is_patch. "
f"Got type: {type(embed_is_patch)}")
pixel_values = flatten_bn(pixel_values, concat=True)
num_patches = flatten_bn(num_patches, concat=True)
embed_is_patch = flatten_bn(embed_is_patch)
return AyaVisionImagePixelInputs(
type="pixel_values",
pixel_values=self._validate_pixel_values(pixel_values),
num_patches=num_patches,
embed_is_patch=embed_is_patch,
)
def get_multimodal_embeddings(
        self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
    """Compute image embeddings for the image inputs found in ``kwargs``.

    Returns:
        ``None`` when ``self._parse_and_validate_image_input`` yields no
        image input; otherwise the result of ``scatter_patch_features``
        applied to the processed image features and the input's
        ``"embed_is_patch"`` entry.
    """
    parsed = self._parse_and_validate_image_input(**kwargs)
    if parsed is not None:
        features = self._process_image_input(parsed, **kwargs)
        return scatter_patch_features(features, parsed["embed_is_patch"])
    return None
def get_input_embeddings(
    self,
    input_ids: torch.Tensor,
    multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
    """Embed ``input_ids`` and merge in multimodal embeddings, if any.

    Args:
        input_ids: Token ids passed to the language model's embedding
            lookup.
        multimodal_embeddings: Optional embeddings; when given, their
            patch features (via ``select_patch_features``) are merged at
            positions holding ``self.config.image_token_index``.

    Returns:
        The language-model embeddings, merged with the multimodal
        embeddings when those are provided.
    """
    text_embeds = self.language_model.get_input_embeddings(input_ids)
    if multimodal_embeddings is None:
        return text_embeds

    patch_embeds = select_patch_features(multimodal_embeddings)
    return merge_multimodal_embeddings(
        input_ids=input_ids,
        inputs_embeds=text_embeds,
        multimodal_embeddings=patch_embeds,
        placeholder_token_id=self.config.image_token_index)
def forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    intermediate_tensors: Optional[IntermediateTensors] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
    **kwargs: object,
) -> Union[torch.Tensor, IntermediateTensors]:
    """Run the wrapped language model, building input embeddings if needed.

    When ``intermediate_tensors`` is given, ``inputs_embeds`` is discarded.
    Otherwise, if no ``inputs_embeds`` were passed, they are computed from
    ``input_ids`` and the multimodal inputs in ``kwargs``, and ``input_ids``
    is cleared before calling the language model.

    Returns:
        Whatever ``self.language_model.model`` returns.
    """
    if intermediate_tensors is not None:
        inputs_embeds = None
    else:
        # NOTE: In v1, inputs_embeds is always generated at model runner,
        # this condition is for v0 compatibility.
        if inputs_embeds is None:
            mm_embeds = self.get_multimodal_embeddings(**kwargs)
            inputs_embeds = self.get_input_embeddings(input_ids, mm_embeds)
            input_ids = None

    return self.language_model.model(
        input_ids=input_ids,
        positions=positions,
        intermediate_tensors=intermediate_tensors,
        inputs_embeds=inputs_embeds,
    )
def compute_logits(
    self,
    hidden_states: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
    """Delegate logit computation to the wrapped language model."""
    language_model = self.language_model
    logits = language_model.compute_logits(hidden_states, sampling_metadata)
    return logits
@cached_property
def sampler(self):
    """Sampler of the wrapped language model, or a default one.

    Uses ``self.language_model.sampler`` when that attribute exists and
    falls back to ``get_sampler()`` otherwise.
    """
    language_model = self.language_model
    if not hasattr(language_model, "sampler"):
        return get_sampler()
    return language_model.sampler
def sample(
    self,
    logits: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
    """Delegate sampling to the wrapped language model."""
    language_model = self.language_model
    sampled = language_model.sample(logits, sampling_metadata)
    return sampled
......@@ -51,7 +51,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP, SupportsQuant
from .utils import (is_pp_missing_parameter,
from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
from vllm import _custom_ops as ops
......@@ -301,6 +301,16 @@ class BaiChuanModel(nn.Module):
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
self.quant_method = None
if quant_config is not None:
self.quant_method=quant_config.get_name()
self.quant_config=quant_config
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
self.use_fa_pad = os.environ.get('FA_PAD') == '1'
self.use_awq_pad = os.environ.get('AWQ_PAD') == '1'
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
    """Look up the embeddings of ``input_ids`` via ``self.embed_tokens``."""
    embeddings = self.embed_tokens(input_ids)
    return embeddings
......@@ -336,86 +346,6 @@ class BaiChuanModel(nn.Module):
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
SupportsQuant):
packed_modules_mapping = {
"W_pack": ["W_pack"],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
def __init__(
    self,
    *,
    vllm_config: VllmConfig,
    prefix: str = "",
    position_embedding: str = "ROPE",
):
    """Build the Baichuan backbone, LM head, logits processor and sampler.

    Args:
        vllm_config: Source of the HF model config, the quantization
            config and the LoRA config stored on this module.
        prefix: Forwarded to ``BaiChuanModel``.
        position_embedding: Forwarded to ``BaiChuanModel``.
    """
    super().__init__()
    hf_config = vllm_config.model_config.hf_config
    quant_config = vllm_config.quant_config

    self.config = hf_config
    self.lora_config = vllm_config.lora_config
    self.quant_config = quant_config

    self.model = BaiChuanModel(vllm_config=vllm_config,
                               prefix=prefix,
                               position_embedding=position_embedding)
    self.lm_head = ParallelLMHead(hf_config.vocab_size,
                                  hf_config.hidden_size,
                                  quant_config=quant_config)
    if self.config.tie_word_embeddings:
        # Share the embedding matrix with the output head.
        self.lm_head.weight = self.model.embed_tokens.weight
    self.logits_processor = LogitsProcessor(hf_config.vocab_size)
    self.sampler = get_sampler()
    self.make_empty_intermediate_tensors = (
        self.model.make_empty_intermediate_tensors)

    # Name of the quantization method, or None when unquantized.
    # (self.quant_config already holds quant_config from above.)
    self.quant_method = (quant_config.get_name()
                         if quant_config is not None else None)

    # Feature switches read from the environment; enabled only when the
    # variable is exactly '1'.
    env = os.environ
    self.use_llama_nn = env.get('LLAMA_NN') == '1'
    self.use_gemm_pad = env.get('GEMM_PAD') == '1'
    self.use_fa_pad = env.get('FA_PAD') == '1'
    self.use_awq_pad = env.get('AWQ_PAD') == '1'
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
    """Return the embeddings of ``input_ids`` from the wrapped model."""
    backbone = self.model
    return backbone.get_input_embeddings(input_ids)
def forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    intermediate_tensors: Optional[IntermediateTensors] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
    """Run the wrapped model and return its output unchanged.

    All four arguments are passed positionally to ``self.model``.
    """
    return self.model(input_ids, positions, intermediate_tensors,
                      inputs_embeds)
def compute_logits(
    self,
    hidden_states: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
    """Compute logits by calling the logits processor with the LM head."""
    return self.logits_processor(self.lm_head, hidden_states,
                                 sampling_metadata)
def sample(
    self,
    logits: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
    """Run ``self.sampler`` on ``logits`` and return its output."""
    return self.sampler(logits, sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
......@@ -428,17 +358,6 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
if name == "lm_head.weight":
# Unlike Baichuan, Baichuan2 normalizes the head weights.
# Refer to:
# https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat/blob/84603cde5ebffb6084e476cfaeceaf0b8b91fe54/modeling_baichuan.py#L508
# Distinguish between Baichuan and Baichuan2 by checking the
# vocab size. This is suggested by
# https://github.com/vllm-project/vllm/pull/1022#discussion_r1325652704
is_baichuan2 = self.config.vocab_size == 125696
if is_baichuan2:
loaded_weight = torch.nn.functional.normalize(
loaded_weight)
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
......@@ -464,7 +383,7 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
if self.use_llama_nn and self.quant_method is None :
lay_key_words = [
"self_attn.W_pack.weight",
......@@ -540,11 +459,101 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
zeros_and_scalse_pad= torch.zeros(dim_n,pad_group,dtype=torch.int32).cuda()
zeros_and_scalse.data=torch.cat((zeros_and_scalse.data,zeros_and_scalse_pad),dim=1).contiguous()
qweight_pad= torch.zeros(dim_n,int(group_size//4),dtype=torch.int32).cuda()
qweight.data=torch.cat((qweight.data,qweight_pad),dim=1).contiguous()
qweight.data=torch.cat((qweight.data,qweight_pad),dim=1).contiguous()
return loaded_params
class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
                              SupportsQuant):
    """Causal-LM wrapper around ``BaiChuanModel`` with an LM head.

    Weight loading is delegated to ``AutoWeightsLoader``; the LM head weight
    gets a custom loader (``lm_head_weight_loader``).
    """

    packed_modules_mapping = {
        "W_pack": ["W_pack"],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        position_embedding: str = "ROPE",
    ):
        """Build the backbone, LM head, logits processor and sampler.

        Args:
            vllm_config: Source of the HF model config, the quantization
                config and the LoRA config stored on this module.
            prefix: Forwarded to ``BaiChuanModel``.
            position_embedding: Forwarded to ``BaiChuanModel``.
        """
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config

        self.config = hf_config
        self.lora_config = vllm_config.lora_config
        self.quant_config = quant_config

        self.model = BaiChuanModel(vllm_config=vllm_config,
                                   prefix=prefix,
                                   position_embedding=position_embedding)
        self.lm_head = ParallelLMHead(hf_config.vocab_size,
                                      hf_config.hidden_size,
                                      quant_config=quant_config)
        # Install the custom loader on the head's own weight before the
        # weight is (optionally) replaced by the tied embedding below.
        self.lm_head.weight.weight_loader = self.lm_head_weight_loader
        if self.config.tie_word_embeddings:
            self.lm_head.weight = self.model.embed_tokens.weight
        self.logits_processor = LogitsProcessor(hf_config.vocab_size)
        self.sampler = get_sampler()
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the embeddings of ``input_ids`` from the wrapped model."""
        backbone = self.model
        return backbone.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the wrapped model and return its output unchanged."""
        return self.model(input_ids, positions, intermediate_tensors,
                          inputs_embeds)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Compute logits by calling the logits processor with the LM head."""
        return self.logits_processor(self.lm_head, hidden_states,
                                     sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Run ``self.sampler`` on ``logits`` and return its output."""
        return self.sampler(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load ``weights`` through an ``AutoWeightsLoader`` bound to self.

        Returns:
            Whatever ``AutoWeightsLoader.load_weights`` returns.
        """
        return AutoWeightsLoader(self).load_weights(weights)

    def lm_head_weight_loader(self, param: nn.Parameter,
                              loaded_weight: torch.Tensor):
        """Load the LM head weight, normalizing it first for Baichuan2.

        Args:
            param: Destination parameter.
            loaded_weight: Checkpoint tensor for the head.
        """
        # Unlike Baichuan, Baichuan2 normalizes the head weights.
        # Refer to:
        # https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat/blob/84603cde5ebffb6084e476cfaeceaf0b8b91fe54/modeling_baichuan.py#L508
        # Distinguish between Baichuan and Baichuan2 by checking the
        # vocab size. This is suggested by
        # https://github.com/vllm-project/vllm/pull/1022#discussion_r1325652704
        weight = loaded_weight
        if self.config.vocab_size == 125696:
            weight = torch.nn.functional.normalize(weight)
        default_weight_loader(param, weight)
class BaichuanForCausalLM(BaiChuanBaseForCausalLM):
"""Baichuan 13B and Baichuan2 7B/13B.
NOTE: the class name has a lower case 'c'.
......
......@@ -34,7 +34,7 @@ from vllm.utils import LayerBlockType
from .interfaces import (HasInnerState, IsHybrid, SupportsLoRA, SupportsPP,
SupportsQuant, SupportsV0Only)
from .utils import (is_pp_missing_parameter,
from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
......@@ -363,6 +363,58 @@ class BambaModel(nn.Module):
hidden_states, _ = self.final_layernorm(hidden_states, residual)
return hidden_states
def load_weights(self, weights: Iterable[Tuple[str,
                                               torch.Tensor]]) -> Set[str]:
    """Load checkpoint tensors into this module's parameters.

    Checkpoint names are rewritten before lookup: ``A_log`` becomes ``A``,
    ``.self_attn`` is removed from names containing ``.self_attn.``, and
    per-shard projection names are mapped onto the fused parameters listed
    in the shard table below. Entries containing ``rotary_emb.inv_freq``
    are ignored.

    Args:
        weights: ``(name, tensor)`` pairs from the checkpoint.

    Returns:
        The (rewritten) names that reached the end of an iteration without
        being skipped.
    """
    # (fused parameter name, checkpoint shard name, shard id)
    fused_shards = [
        ("qkv_proj", "q_proj", "q"),
        ("qkv_proj", "k_proj", "k"),
        ("qkv_proj", "v_proj", "v"),
        ("gate_up_proj", "gate_proj", 0),
        ("gate_up_proj", "up_proj", 1),
    ]

    own_params = dict(self.named_parameters())
    loaded_names: Set[str] = set()

    for name, tensor in weights:
        if "rotary_emb.inv_freq" in name:
            continue

        # str.replace leaves the name untouched when "A_log" is absent.
        name = name.replace("A_log", "A")

        if ".self_attn." in name:
            name = name.replace(".self_attn", "")

        for fused_name, shard_name, shard_id in fused_shards:
            if shard_name not in name:
                continue

            name = name.replace(shard_name, fused_name)
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in own_params:
                continue
            # Skip layers on other devices.
            if is_pp_missing_parameter(name, self):
                continue
            target = own_params[name]
            target.weight_loader(target, tensor, shard_id)
            break
        else:
            # No fused shard was loaded: treat it as a regular parameter.
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in own_params:
                continue
            if is_pp_missing_parameter(name, self):
                continue

            target = own_params[name]
            load_fn = getattr(target, "weight_loader", default_weight_loader)
            load_fn(target, tensor)

        loaded_names.add(name)

    return loaded_names
class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
IsHybrid, SupportsV0Only, SupportsQuant):
......@@ -502,52 +554,5 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
def load_weights(self, weights: Iterable[Tuple[str,
                                               torch.Tensor]]) -> Set[str]:
    """Load checkpoint tensors into this module's parameters.

    Checkpoint names are rewritten before lookup: ``A_log`` becomes ``A``,
    ``.self_attn`` is removed from names containing ``.self_attn.``, and
    per-shard projection names are mapped onto the fused parameters listed
    in ``stacked_params_mapping``. Entries containing
    ``rotary_emb.inv_freq`` are ignored.

    The two ``AutoWeightsLoader`` statements that used to follow the final
    ``return`` could never execute and have been dropped; the behaviour of
    the function is unchanged.

    Args:
        weights: ``(name, tensor)`` pairs from the checkpoint.

    Returns:
        The (rewritten) names that reached the end of an iteration without
        being skipped.
    """
    stacked_params_mapping = [
        # (param_name, shard_name, shard_id)
        ("qkv_proj", "q_proj", "q"),
        ("qkv_proj", "k_proj", "k"),
        ("qkv_proj", "v_proj", "v"),
        ("gate_up_proj", "gate_proj", 0),
        ("gate_up_proj", "up_proj", 1),
    ]

    params_dict = dict(self.named_parameters())
    loaded_params: Set[str] = set()
    for name, loaded_weight in weights:
        if "rotary_emb.inv_freq" in name:
            continue

        if "A_log" in name:
            name = name.replace("A_log", "A")

        if ".self_attn." in name:
            name = name.replace(".self_attn", "")

        for param_name, weight_name, shard_id in stacked_params_mapping:
            if weight_name not in name:
                continue

            name = name.replace(weight_name, param_name)
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            # Skip layers on other devices.
            if is_pp_missing_parameter(name, self):
                continue
            param = params_dict[name]
            weight_loader = param.weight_loader
            weight_loader(param, loaded_weight, shard_id)
            break
        else:
            # No fused shard was loaded: treat it as a regular parameter.
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            if is_pp_missing_parameter(name, self):
                continue

            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
        loaded_params.add(name)
    return loaded_params
\ No newline at end of file
......@@ -26,7 +26,7 @@ from vllm.sequence import IntermediateTensors, PoolerOutput
from vllm.transformers_utils.config import (
get_cross_encoder_activation_function)
from .interfaces import SupportsCrossEncoding, SupportsV0Only
from .interfaces import SupportsCrossEncoding, SupportsQuant, SupportsV0Only
from .utils import WeightsMapper, maybe_prefix
......@@ -313,7 +313,8 @@ class BertOutput(nn.Module):
return hidden_states
class BertModel(nn.Module):
class BertModel(nn.Module, SupportsQuant):
packed_modules_mapping = {"qkv_proj": ["query", "key", "value"]}
def __init__(self,
*,
......@@ -385,7 +386,7 @@ class BertModel(nn.Module):
return loaded_params
class BertEmbeddingModel(nn.Module, SupportsV0Only):
class BertEmbeddingModel(nn.Module, SupportsV0Only, SupportsQuant):
"""A model that uses Bert to provide embedding functionalities.
This class encapsulates the BertModel and provides an interface for
......@@ -443,7 +444,8 @@ class BertEmbeddingModel(nn.Module, SupportsV0Only):
softmax=False)
class BertForSequenceClassification(nn.Module, SupportsCrossEncoding):
class BertForSequenceClassification(nn.Module, SupportsCrossEncoding,
SupportsQuant):
"""A model that uses Bert to provide embedding functionalities.
This class encapsulates the BertModel and provides an interface for
......
......@@ -16,6 +16,8 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from .interfaces import SupportsQuant
def get_blip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
assert image_size % patch_size == 0
......@@ -243,9 +245,10 @@ class BlipEncoder(nn.Module):
return hidden_states
class BlipVisionModel(nn.Module):
class BlipVisionModel(nn.Module, SupportsQuant):
config_class = BlipVisionConfig
main_input_name = "pixel_values"
packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}
def __init__(
self,
......
......@@ -24,7 +24,8 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from .blip import BlipVisionModel
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .interfaces import (MultiModalEmbeddings, SupportsMultiModal, SupportsPP,
SupportsQuant)
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings)
......@@ -498,7 +499,8 @@ class Blip2MultiModalProcessor(BaseMultiModalProcessor[Blip2ProcessingInfo]):
@MULTIMODAL_REGISTRY.register_processor(Blip2MultiModalProcessor,
info=Blip2ProcessingInfo,
dummy_inputs=Blip2DummyInputsBuilder)
class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
SupportsQuant):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
......
......@@ -46,7 +46,7 @@ from vllm.sequence import IntermediateTensors
from vllm import _custom_ops as ops
from vllm.model_executor.utils import pad_weight, gemm_bank_conf
from .interfaces import SupportsPP, SupportsV0Only
from .interfaces import SupportsPP, SupportsQuant, SupportsV0Only
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
......@@ -290,7 +290,7 @@ class BloomModel(nn.Module):
return hidden_states
class BloomForCausalLM(nn.Module, SupportsPP, SupportsV0Only):
class BloomForCausalLM(nn.Module, SupportsPP, SupportsV0Only, SupportsQuant):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment