Commit af7f4372 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.5.5' into v0.5.5-dtk24.04.1

parents 5e19cdef 09c77926
......@@ -345,6 +345,8 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA):
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config)
if self.config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
......@@ -369,8 +371,11 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA):
attn_metadata)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
......
# Derived from BART implementation posted on HuggingFace; license below:
#
# coding=utf-8
# Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch BART model."""
import math
from typing import Iterable, List, Optional, Tuple
import torch
from torch import nn
from transformers import BartConfig
from transformers.utils import logging
from vllm.attention import Attention, AttentionMetadata, AttentionType
from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors, SamplerOutput
logger = logging.get_logger(__name__)
def get_bsz_seq_len(input_ids):
shp = input_ids.shape
ndim = len(shp)
if ndim == 1:
return 1, input_ids.numel()
else:
return shp[:2]
class BartLearnedPositionalEmbedding(VocabParallelEmbedding):
"""
This module learns positional embeddings up to a fixed maximum size.
"""
def __init__(self, num_embeddings: int, embedding_dim: int):
# Bart is set up so that if padding_idx is
# specified then offset the embedding ids by 2
# and adjust num_embeddings appropriately.
# Other models don't have this hack
self.offset = 2
super().__init__(num_embeddings + self.offset, embedding_dim)
def forward(
self,
positions: torch.Tensor,
attn_type: AttentionType,
) -> torch.Tensor:
"""`input_ids' shape is expected to be [bsz x seqlen]."""
assert attn_type != AttentionType.ENCODER_DECODER
return super().forward(positions + self.offset)
class BartScaledWordEmbedding(VocabParallelEmbedding):
"""
This module overrides VocabParallelEmbedding's
forward by multiplying with embeddings scale.
"""
def __init__(self,
num_embeddings: int,
embedding_dim: int,
embed_scale: float = 1.0):
super().__init__(num_embeddings, embedding_dim)
self.embed_scale = embed_scale
def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
return super().forward(input_ids) * self.embed_scale
class BartParallelLMHead(ParallelLMHead):
"""
This module overrides ParallelLMHead's
forward by dividing by embeddings scale,
yielding effectively the inverse of
BartScaledWordEmbedding
"""
def __init__(self,
num_embeddings: int,
embedding_dim: int,
embed_scale: float = 1.0):
super().__init__(num_embeddings, embedding_dim)
self.embed_scale = embed_scale
def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
return super().forward(input_ids) / self.embed_scale
class BartEncoderAttention(nn.Module):
def __init__(
self,
embed_dim: int,
num_heads: int,
bias: bool = True,
config: Optional[BartConfig] = None,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.d_model = config.d_model
self.embed_dim = embed_dim
self.total_num_heads = num_heads
self.total_num_kv_heads = self.total_num_heads
self.head_dim = embed_dim // num_heads
self.config = config
if (self.head_dim * num_heads) != self.embed_dim:
raise ValueError(f"embed_dim must be divisible by num_heads "
f"(got `embed_dim`: {self.embed_dim}"
f" and `num_heads`: {num_heads}).")
self.scaling = self.head_dim**-0.5
self.qkv_proj = QKVParallelLinear(
self.d_model,
self.d_model // self.total_num_heads,
self.total_num_heads,
self.total_num_kv_heads,
bias=bias,
quant_config=quant_config,
)
self.out_proj = RowParallelLinear(
embed_dim,
embed_dim,
bias=bias,
quant_config=quant_config,
)
tp_world_size = get_tensor_model_parallel_world_size()
assert self.total_num_heads % tp_world_size == 0
self.num_heads = self.total_num_heads // tp_world_size
if self.total_num_kv_heads >= tp_world_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_world_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_world_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size)
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config)
def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata) -> torch.Tensor:
"""Input shape: Batch x Time x Channel"""
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
attn_output = self.attn(q,
k,
v,
kv_cache,
attn_metadata,
attn_type=AttentionType.ENCODER)
output, _ = self.out_proj(attn_output)
return output
class BartDecoderSelfAttention(nn.Module):
def __init__(
self,
embed_dim: int,
num_heads: int,
bias: bool = True,
config: Optional[BartConfig] = None,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.d_model = config.d_model
self.embed_dim = embed_dim
self.total_num_heads = num_heads
self.total_num_kv_heads = self.total_num_heads
self.head_dim = embed_dim // num_heads
self.config = config
if (self.head_dim * num_heads) != self.embed_dim:
raise ValueError(f"embed_dim must be divisible by num_heads "
f"(got `embed_dim`: {self.embed_dim}"
f" and `num_heads`: {num_heads}).")
self.scaling = self.head_dim**-0.5
self.qkv_proj = QKVParallelLinear(
self.d_model,
self.d_model // self.total_num_heads,
self.total_num_heads,
self.total_num_kv_heads,
bias=bias,
quant_config=quant_config,
)
self.out_proj = RowParallelLinear(
embed_dim,
embed_dim,
bias=bias,
quant_config=quant_config,
)
tp_world_size = get_tensor_model_parallel_world_size()
assert self.total_num_heads % tp_world_size == 0
self.num_heads = self.total_num_heads // tp_world_size
if self.total_num_kv_heads >= tp_world_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_world_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_world_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size)
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config)
def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata) -> torch.Tensor:
"""Input shape: Batch x Time x Channel"""
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
attn_output = self.attn(q,
k,
v,
kv_cache,
attn_metadata,
attn_type=AttentionType.DECODER)
output, _ = self.out_proj(attn_output)
return output
class BartCrossAttention(nn.Module):
def __init__(
self,
embed_dim: int,
num_heads: int,
bias: bool = True,
config: Optional[BartConfig] = None,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.d_model = config.d_model
self.embed_dim = embed_dim
self.total_num_heads = num_heads
self.total_num_kv_heads = self.total_num_heads
self.head_dim = embed_dim // num_heads
self.config = config
if (self.head_dim * num_heads) != self.embed_dim:
raise ValueError(f"embed_dim must be divisible by num_heads "
f"(got `embed_dim`: {self.embed_dim}"
f" and `num_heads`: {num_heads}).")
self.scaling = self.head_dim**-0.5
self.qkv_proj = QKVParallelLinear(
self.d_model,
self.d_model // self.total_num_heads,
self.total_num_heads,
self.total_num_kv_heads,
bias=bias,
quant_config=quant_config,
)
self.out_proj = RowParallelLinear(
embed_dim,
embed_dim,
bias=bias,
quant_config=quant_config,
)
tp_world_size = get_tensor_model_parallel_world_size()
assert self.total_num_heads % tp_world_size == 0
self.num_heads = self.total_num_heads // tp_world_size
if self.total_num_kv_heads >= tp_world_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_world_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_world_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size)
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config)
def forward(
self,
decoder_hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
encoder_hidden_states: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Input shape: Batch x Time x Channel"""
# (afeldman-nm 2024/07/22) TODO:
# Need a more efficient solution for q/k/v
qkv_dec, _ = self.qkv_proj(decoder_hidden_states)
q, _, _ = qkv_dec.split([self.q_size, self.kv_size, self.kv_size],
dim=-1)
if encoder_hidden_states is None:
k = None
v = None
else:
qkv_enc, _ = self.qkv_proj(encoder_hidden_states)
_, k, v = qkv_enc.split([self.q_size, self.kv_size, self.kv_size],
dim=-1)
attn_output = self.attn(q,
k,
v,
kv_cache,
attn_metadata,
attn_type=AttentionType.ENCODER_DECODER)
output, _ = self.out_proj(attn_output)
return output
class BartEncoderLayer(nn.Module):
def __init__(
self,
config: BartConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.embed_dim = config.d_model
self.self_attn = BartEncoderAttention(
embed_dim=self.embed_dim,
num_heads=config.encoder_attention_heads,
config=config,
cache_config=cache_config,
quant_config=quant_config)
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.activation_fn = get_act_fn(config.activation_function,
quant_config)
ffn_hidden_size = self.embed_dim
ffn_intermediate_size = config.encoder_ffn_dim
ffn_has_bias = True
self.fc1 = ColumnParallelLinear(
ffn_hidden_size,
ffn_intermediate_size,
bias=ffn_has_bias,
quant_config=quant_config,
)
self.act = get_act_fn("gelu", quant_config, ffn_intermediate_size)
self.fc2 = RowParallelLinear(
ffn_intermediate_size,
ffn_hidden_size,
bias=ffn_has_bias,
quant_config=quant_config,
)
self.final_layer_norm = nn.LayerNorm(self.embed_dim)
def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata) -> torch.Tensor:
r"""
Args:
hidden_states
torch.Tensor of *encoder* input embeddings.
kv_cache:
Layer-wise list of KV cache tensors
attn_metadata:
vLLM Attention metadata structure
Returns:
Encoder layer output torch.Tensor
"""
residual = hidden_states
hidden_states = self.self_attn(hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata)
hidden_states = residual + hidden_states
hidden_states = self.self_attn_layer_norm(hidden_states)
residual = hidden_states
fc1_out, _ = self.fc1(hidden_states)
hidden_states = self.activation_fn(fc1_out)
hidden_states, _ = self.fc2(hidden_states)
hidden_states = residual + hidden_states
hidden_states = self.final_layer_norm(hidden_states)
if hidden_states.dtype == torch.float16 and (
torch.isinf(hidden_states).any()
or torch.isnan(hidden_states).any()):
clamp_value = torch.finfo(hidden_states.dtype).max - 1000
hidden_states = torch.clamp(hidden_states,
min=-clamp_value,
max=clamp_value)
return hidden_states
class BartDecoderLayer(nn.Module):
def __init__(
self,
config: BartConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.embed_dim = config.d_model
self.self_attn = BartDecoderSelfAttention(
embed_dim=self.embed_dim,
num_heads=config.decoder_attention_heads,
config=config,
cache_config=cache_config,
quant_config=quant_config)
self.activation_fn = get_act_fn(config.activation_function,
quant_config)
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
'''
afeldman-nm: personally I would call this "cross-attention",
however I left the name as "encoder_attn" to maintain consistency
with the name of the pretrained weights.
'''
self.encoder_attn = BartCrossAttention(
self.embed_dim,
config.decoder_attention_heads,
config=config,
)
self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
ffn_hidden_size = self.embed_dim
ffn_intermediate_size = config.encoder_ffn_dim
ffn_has_bias = True
self.fc1 = ColumnParallelLinear(
ffn_hidden_size,
ffn_intermediate_size,
bias=ffn_has_bias,
quant_config=quant_config,
)
self.fc2 = RowParallelLinear(
ffn_intermediate_size,
ffn_hidden_size,
bias=ffn_has_bias,
quant_config=quant_config,
)
self.final_layer_norm = nn.LayerNorm(self.embed_dim)
def forward(
self,
decoder_hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
encoder_hidden_states: Optional[torch.Tensor] = None,
) -> torch.Tensor:
r"""
Args:
decoder_hidden_states
torch.Tensor of *decoder* input embeddings.
kv_cache:
KV cache tensor
attn_metadata:
vLLM Attention metadata structure
encoder_hidden_states
torch.Tensor of *encoder* input embeddings.
Returns:
Decoder layer output torch.Tensor
"""
residual = decoder_hidden_states
# Self Attention
hidden_states = self.self_attn(hidden_states=decoder_hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata)
hidden_states = residual + hidden_states
hidden_states = self.self_attn_layer_norm(hidden_states)
# Cross-Attention Block
residual = hidden_states
hidden_states = self.encoder_attn(
decoder_hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
encoder_hidden_states=encoder_hidden_states,
)
hidden_states = residual + hidden_states
hidden_states = self.encoder_attn_layer_norm(hidden_states)
# Fully Connected
residual = hidden_states
fc1_out, _ = self.fc1(hidden_states)
hidden_states = self.activation_fn(fc1_out)
hidden_states, _ = self.fc2(hidden_states)
hidden_states = residual + hidden_states
hidden_states = self.final_layer_norm(hidden_states)
return hidden_states
class BartEncoder(nn.Module):
"""
Transformer encoder consisting of *config.encoder_layers*
self attention layers. Each layer is a [`BartEncoderLayer`].
Args:
config: BartConfig
embed_tokens (nn.Embedding): output embedding
"""
def __init__(self,
config: BartConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
embed_tokens: Optional[nn.Embedding] = None):
super().__init__()
self.cache_config = cache_config
self.quant_config = quant_config
self.lora_config = lora_config
embed_dim = config.d_model
self.max_source_positions = config.max_position_embeddings
embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
self.embed_tokens = BartScaledWordEmbedding(config.vocab_size,
embed_dim,
embed_scale=embed_scale)
if embed_tokens is not None:
self.embed_tokens.weight = embed_tokens.weight
self.embed_positions = BartLearnedPositionalEmbedding(
config.max_position_embeddings,
embed_dim,
)
self.layers = nn.ModuleList(
[BartEncoderLayer(config,cache_config,quant_config) \
for _ in range(config.encoder_layers)])
self.layernorm_embedding = nn.LayerNorm(embed_dim)
def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata) -> torch.Tensor:
r"""
Args:
input_ids
Indices of *encoder* input sequence tokens in the vocabulary.
Padding will be ignored by default should you
provide it.
positions
Positions of *encoder* input sequence tokens.
kv_caches:
Layer-wise list of KV cache tensors
attn_metadata:
vLLM Attention metadata structure
Returns:
Decoder output torch.Tensor
"""
# retrieve input_ids and inputs_embeds
input_ids = input_ids.view(-1, input_ids.shape[-1])
inputs_embeds = self.embed_tokens(input_ids)
embed_pos = self.embed_positions(
positions,
AttentionType.ENCODER,
)
embed_pos = embed_pos.to(inputs_embeds.device)
hidden_states = inputs_embeds + embed_pos
hidden_states = self.layernorm_embedding(hidden_states)
for idx, encoder_layer in enumerate(self.layers):
hidden_states = encoder_layer(
hidden_states=hidden_states,
kv_cache=kv_caches[idx],
attn_metadata=attn_metadata,
)
return hidden_states
class BartDecoder(nn.Module):
"""
Transformer decoder consisting of *config.decoder_layers* layers.
Each layer is a [`BartDecoderLayer`]
Args:
config: BartConfig
embed_tokens (nn.Embedding): output embedding
"""
def __init__(
self,
config: BartConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
embed_tokens: Optional[nn.Embedding] = None,
):
super().__init__()
self.cache_config = cache_config
self.quant_config = quant_config
self.lora_config = lora_config
self.max_target_positions = config.max_position_embeddings
embed_scale = math.sqrt(
config.d_model) if config.scale_embedding else 1.0
self.embed_tokens = BartScaledWordEmbedding(config.vocab_size,
config.d_model,
embed_scale=embed_scale)
if embed_tokens is not None:
self.embed_tokens.weight = embed_tokens.weight
self.embed_positions = BartLearnedPositionalEmbedding(
config.max_position_embeddings,
config.d_model,
)
self.layers = nn.ModuleList(
[BartDecoderLayer(config,cache_config,quant_config) \
for _ in range(config.decoder_layers)])
self.layernorm_embedding = nn.LayerNorm(config.d_model)
def forward(self, decoder_input_ids: torch.Tensor,
decoder_positions: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor],
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata) -> torch.Tensor:
r"""
Args:
decoder_input_ids
Indices of *decoder* input sequence tokens in the vocabulary.
Padding will be ignored by default should you
provide it.
decoder_positions
Positions of *decoder* input sequence tokens.
encoder_hidden_states:
Tensor of encoder output embeddings
kv_caches:
Layer-wise list of KV cache tensors
attn_metadata:
vLLM Attention metadata structure
Returns:
Decoder output torch.Tensor
"""
inputs_embeds = self.embed_tokens(decoder_input_ids)
# embed positions
embed_pos = self.embed_positions(
decoder_positions,
AttentionType.DECODER,
)
embed_pos = embed_pos.to(inputs_embeds.device)
hidden_states = inputs_embeds + embed_pos
hidden_states = self.layernorm_embedding(hidden_states)
# decoder layers
for idx, decoder_layer in enumerate(self.layers):
hidden_states = decoder_layer(
decoder_hidden_states=hidden_states,
kv_cache=kv_caches[idx],
attn_metadata=attn_metadata,
encoder_hidden_states=encoder_hidden_states,
)
return hidden_states
class BartModel(nn.Module):
_tied_weights_keys = [
"encoder.embed_tokens.weight", "decoder.embed_tokens.weight"
]
def __init__(self,
config: BartConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None):
super().__init__()
self.config = config
self.padding_idx = config.pad_token_id
lora_vocab = (lora_config.lora_extra_vocab_size *
(lora_config.max_loras or 1)) if lora_config else 0
self.vocab_size = config.vocab_size + lora_vocab
self.org_vocab_size = config.vocab_size
self.encoder = BartEncoder(config,
cache_config,
quant_config=quant_config)
self.decoder = BartDecoder(config,
cache_config,
quant_config=quant_config)
def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
encoder_input_ids: torch.Tensor,
encoder_positions: torch.Tensor, kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata) -> torch.Tensor:
r"""
Args:
input_ids
Indices of *decoder* input sequence tokens in the vocabulary.
Padding will be ignored by default should you
provide it.
positions
Positions of *decoder* input sequence tokens.
encoder_input_ids
Indices of *encoder* input sequence tokens in the vocabulary.
encoder_positions:
Positions of *encoder* input sequence tokens.
kv_caches:
Layer-wise list of KV cache tensors
attn_metadata:
vLLM Attention metadata structure
Returns:
Model output torch.Tensor
"""
encoder_hidden_states = None
if encoder_input_ids.numel() > 0:
# Run encoder attention if a non-zero number of encoder tokens
# are provided as input
encoder_hidden_states = self.encoder(input_ids=encoder_input_ids,
positions=encoder_positions,
kv_caches=kv_caches,
attn_metadata=attn_metadata)
# decoder outputs consists of
# (dec_features, past_key_value, dec_hidden, dec_attn)
decoder_outputs = self.decoder(
decoder_input_ids=input_ids,
decoder_positions=positions,
encoder_hidden_states=encoder_hidden_states,
kv_caches=kv_caches,
attn_metadata=attn_metadata)
return decoder_outputs
class BartForConditionalGeneration(nn.Module):
base_model_prefix = "model"
def __init__(self,
config: BartConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None):
super().__init__()
# currently all existing BART models have `tie_word_embeddings` enabled
assert config.tie_word_embeddings
self.config = config
self.model = BartModel(config,
cache_config,
quant_config,
lora_config=lora_config)
self.unpadded_vocab_size = config.vocab_size
if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
embed_scale = math.sqrt(
config.d_model) if config.scale_embedding else 1.0
self.lm_head = BartParallelLMHead(config.vocab_size,
config.d_model,
embed_scale=embed_scale)
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size)
self.sampler = Sampler()
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
encoder_input_ids: torch.Tensor,
encoder_positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor:
r"""
Args:
input_ids
torch.Tensor of *decoder* input token ids.
positions
torch.Tensor of *decoder* position indices.
encoder_input_ids
torch.Tensor of *encoder* input token ids.
encoder_positions
torch.Tensor of *encoder* position indices
kv_caches:
Layer-wise list of KV cache tensors
attn_metadata:
vLLM Attention metadata structure
Returns:
Output torch.Tensor
"""
return self.model(input_ids, positions, encoder_input_ids,
encoder_positions, kv_caches, attn_metadata)
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
def sample(
self,
logits: Optional[torch.Tensor],
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
stacked_params_mapping = {
"q_proj": {
"param_name": "qkv_proj",
"shard_id": "q",
},
"k_proj": {
"param_name": "qkv_proj",
"shard_id": "k",
},
"v_proj": {
"param_name": "qkv_proj",
"shard_id": "v",
},
}
params_mapping = {
"beta": "bias",
"gamma": "weight",
"LayerNorm": "layernorm",
}
def _rename_key(self, key: str):
prefix = f"{self.base_model_prefix}."
key = key[len(prefix):] if key.startswith(prefix) else key
for src, dst in self.params_mapping.items():
key = key.replace(src, dst)
return key
def _rename_stacked_param(
self,
name: str,
) -> Tuple[str, Optional[str]]:
for key, mapping in self.stacked_params_mapping.items():
if key in name:
name = name.replace(key, mapping["param_name"])
return name, mapping["shard_id"]
return name, None
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
model_params_dict = dict(self.model.named_parameters())
top_params_dict = dict(self.named_parameters())
weights_tuple_list = list(weights)
shared_embedding_weight = None
shared_embedding_shard_id = None
for name, loaded_weight in weights_tuple_list:
name = self._rename_key(name)
name, shard_id = self._rename_stacked_param(name)
if ('shared.weight' in name
or 'encoder.embed_tokens.weight' in name
or 'decoder.embed_tokens.weight' in name
or 'lm_head.weight' in name):
assert shared_embedding_weight is None, (
"Conflicting embedding weights.")
shared_embedding_weight = loaded_weight
shared_embedding_shard_id = shard_id
else:
# Skip the specific downstream task weight.
if name.startswith('cls.'):
continue
# use Pooler instead.
if name.startswith('pooler.'):
continue
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in model_params_dict:
continue
param = model_params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
if shard_id:
weight_loader(param, loaded_weight, shard_id)
else:
weight_loader(param, loaded_weight)
# Assign shared weight values
encoder_in_param = model_params_dict['encoder.embed_tokens.weight']
encoder_in_weight_loader = getattr(encoder_in_param, "weight_loader",
default_weight_loader)
decoder_in_param = model_params_dict['decoder.embed_tokens.weight']
decoder_in_weight_loader = getattr(decoder_in_param, "weight_loader",
default_weight_loader)
lm_head_in_param = top_params_dict['lm_head.weight']
lm_head_in_weight_loader = getattr(lm_head_in_param, "weight_loader",
default_weight_loader)
assert shared_embedding_weight is not None
if shared_embedding_shard_id:
encoder_in_weight_loader(encoder_in_param, shared_embedding_weight,
shared_embedding_shard_id)
decoder_in_weight_loader(decoder_in_param, shared_embedding_weight,
shared_embedding_shard_id)
lm_head_in_weight_loader(lm_head_in_param, shared_embedding_weight,
shared_embedding_shard_id)
else:
encoder_in_weight_loader(encoder_in_param, shared_embedding_weight)
decoder_in_weight_loader(decoder_in_param, shared_embedding_weight)
lm_head_in_weight_loader(lm_head_in_param, shared_embedding_weight)
"""Minimal implementation of BlipVisionModel intended to be only used
within a vision language model."""
from array import array
from typing import Optional, Union
import torch
......@@ -14,9 +15,9 @@ from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.multimodal.image import (cached_get_tokenizer,
repeat_and_pad_image_tokens)
from vllm.sequence import SequenceData
from vllm.multimodal.utils import (cached_get_tokenizer,
repeat_and_pad_placeholder_tokens)
from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData
def get_blip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
......@@ -31,13 +32,13 @@ def get_blip_num_patches(*, image_size: int, patch_size: int) -> int:
def get_blip_image_feature_size(
hf_config: Union[BlipVisionConfig, Blip2VisionConfig], ) -> int:
hf_config: Union[BlipVisionConfig, Blip2VisionConfig]) -> int:
return get_blip_num_patches(image_size=hf_config.image_size,
patch_size=hf_config.patch_size)
def get_max_blip_image_tokens(
hf_config: Union[BlipVisionConfig, Blip2VisionConfig], ) -> int:
hf_config: Union[BlipVisionConfig, Blip2VisionConfig]) -> int:
return get_blip_image_feature_size(hf_config)
......@@ -53,13 +54,16 @@ def dummy_seq_data_for_blip(
else:
image_feature_size = image_feature_size_override
token_ids = [image_token_id] * image_feature_size
token_ids += [0] * (seq_len - image_feature_size)
token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
[image_token_id]) * image_feature_size
token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE,
[0]) * (seq_len - image_feature_size)
return SequenceData(token_ids)
def dummy_image_for_blip(
hf_config: Union[BlipVisionConfig, Blip2VisionConfig],
num_images: int,
*,
image_width_override: Optional[int] = None,
image_height_override: Optional[int] = None,
......@@ -71,7 +75,7 @@ def dummy_image_for_blip(
height = image_height_override
image = Image.new("RGB", (width, height), color=0)
return {"image": image}
return {"image": image if num_images == 1 else [image] * num_images}
def input_processor_for_blip(
......@@ -93,11 +97,11 @@ def input_processor_for_blip(
else:
image_feature_size = image_feature_size_override
new_prompt, new_token_ids = repeat_and_pad_image_tokens(
new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
tokenizer,
llm_inputs.get("prompt"),
llm_inputs["prompt_token_ids"],
image_token_id=image_token_id,
placeholder_token_id=image_token_id,
repeat_count=image_feature_size,
)
......
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict
from array import array
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
TypedDict, Union)
import torch
import torch.nn as nn
......@@ -16,18 +18,42 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.opt import OPTModel
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import IntermediateTensors, SamplerOutput, SequenceData
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
SamplerOutput, SequenceData)
from .blip import (BlipVisionModel, dummy_image_for_blip,
get_max_blip_image_tokens)
from .interfaces import SupportsVision
from .utils import merge_vision_embeddings
from .interfaces import SupportsMultiModal
from .utils import merge_multimodal_embeddings
_KEYS_TO_MODIFY_MAPPING = {
"language_model.lm_head": "lm_head",
"language_model.model": "language_model",
}
# We use this internally as placeholders since there is no image token
# defined on the HuggingFace repo
BLIP2_IMAGE_TOKEN = "<image>"
BLIP2_IMAGE_TOKEN_ID = 50265
class Blip2ImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
data: torch.Tensor
"""Shape: (batch_size, num_channels, height, width)"""
class Blip2ImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"]
data: torch.Tensor
"""Shape: `(batch_size, image_feature_size, hidden_size)`
`hidden_size` must match the hidden size of language model backbone.
"""
Blip2ImageInputs = Union[Blip2ImagePixelInputs, Blip2ImageEmbeddingInputs]
class Blip2QFormerMultiHeadAttention(nn.Module):
......@@ -375,20 +401,6 @@ class Blip2QFormerModel(nn.Module):
return sequence_output
class Blip2ImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
data: torch.Tensor
"""Shape: (batch_size, num_channels, height, width)"""
Blip2ImageInputs = Blip2ImagePixelInputs
# We use this internally as placeholders since there is no image token
# defined on the HuggingFace repo
BLIP2_IMAGE_TOKEN = "<image>"
BLIP2_IMAGE_TOKEN_ID = 50265
def get_blip2_image_feature_size(hf_config: Blip2Config) -> int:
return hf_config.num_query_tokens
......@@ -404,17 +416,41 @@ def get_max_blip2_image_tokens(ctx: InputContext):
raise NotImplementedError(msg)
def dummy_data_for_blip2(ctx: InputContext, seq_len: int):
def dummy_seq_data_for_blip2(
hf_config: Blip2Config,
seq_len: int,
num_images: int,
*,
image_token_id: int,
image_feature_size_override: Optional[int] = None,
):
if image_feature_size_override is None:
image_feature_size = get_blip2_image_feature_size(hf_config)
else:
image_feature_size = image_feature_size_override
token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
[image_token_id]) * image_feature_size * num_images
token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE,
[0]) * (seq_len - image_feature_size * num_images)
return SequenceData(token_ids)
def dummy_data_for_blip2(ctx: InputContext, seq_len: int,
mm_counts: Mapping[str, int]):
hf_config = ctx.get_hf_config(Blip2Config)
vision_config = hf_config.vision_config
num_images = mm_counts["image"]
image_feature_size = get_blip2_image_feature_size(hf_config)
token_ids = [BLIP2_IMAGE_TOKEN_ID] * image_feature_size
token_ids += [0] * (seq_len - image_feature_size)
seq_data = SequenceData(token_ids)
seq_data = dummy_seq_data_for_blip2(
hf_config,
seq_len,
num_images,
image_token_id=BLIP2_IMAGE_TOKEN_ID,
)
if isinstance(vision_config, Blip2VisionConfig):
mm_data = dummy_image_for_blip(vision_config)
mm_data = dummy_image_for_blip(vision_config, num_images)
return seq_data, mm_data
......@@ -448,7 +484,7 @@ def input_processor_for_blip2(ctx: InputContext, llm_inputs: LLMInputs):
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_blip2_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_blip2)
@INPUT_REGISTRY.register_input_processor(input_processor_for_blip2)
class Blip2ForConditionalGeneration(nn.Module, SupportsVision):
class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal):
def __init__(self,
config: Blip2Config,
......@@ -458,6 +494,9 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsVision):
super().__init__()
# currently all existing BLIP-2 models have `tie_word_embeddings`
# enabled
assert config.tie_word_embeddings
self.config = config
self.multimodal_config = multimodal_config
......@@ -506,18 +545,31 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsVision):
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[Blip2ImageInputs]:
pixel_values = kwargs.pop("pixel_values", None)
image_embeds = kwargs.pop("image_embeds", None)
if pixel_values is None:
if pixel_values is None and image_embeds is None:
return None
if not isinstance(pixel_values, torch.Tensor):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
if pixel_values is not None:
if not isinstance(pixel_values, torch.Tensor):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
return Blip2ImagePixelInputs(
type="pixel_values",
data=self._validate_pixel_values(pixel_values),
)
return Blip2ImagePixelInputs(
type="pixel_values",
data=self._validate_pixel_values(pixel_values),
)
if image_embeds is not None:
if not isinstance(image_embeds, torch.Tensor):
raise ValueError("Incorrect type of image embeddings. "
f"Got type: {type(image_embeds)}")
return Blip2ImageEmbeddingInputs(
type="image_embeds",
data=image_embeds,
)
raise AssertionError("This line should be unreachable.")
def _image_pixels_to_features(self, vision_model: BlipVisionModel,
pixel_values: torch.Tensor) -> torch.Tensor:
......@@ -538,6 +590,10 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsVision):
def _process_image_input(self,
image_input: Blip2ImageInputs) -> torch.Tensor:
if image_input["type"] == "image_embeds":
return image_input["data"]
assert self.vision_model is not None
image_features = self._process_image_pixels(image_input)
......@@ -595,9 +651,9 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsVision):
vision_embeddings = self._process_image_input(image_input)
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
inputs_embeds = merge_vision_embeddings(input_ids, inputs_embeds,
vision_embeddings,
BLIP2_IMAGE_TOKEN_ID)
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, vision_embeddings,
BLIP2_IMAGE_TOKEN_ID)
input_ids = None
else:
......@@ -611,8 +667,11 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsVision):
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.get_lm_head(), hidden_states,
sampling_metadata)
return logits
......
......@@ -36,7 +36,7 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors, SamplerOutput
......@@ -276,7 +276,12 @@ class BloomForCausalLM(nn.Module):
self.config = config
self.quant_config = quant_config
self.transformer = BloomModel(config, cache_config, quant_config)
self.lm_head = self.transformer.word_embeddings
if self.config.tie_word_embeddings:
self.lm_head = self.transformer.word_embeddings
else:
self.lm_head = ParallelLMHead(self.config.vocab_size,
self.config.hidden_size)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
......@@ -292,8 +297,11 @@ class BloomForCausalLM(nn.Module):
attn_metadata)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
......
from array import array
from functools import cached_property
from typing import (Any, Dict, Iterable, List, Literal, Optional, Tuple,
TypedDict)
from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional,
Tuple, TypedDict)
import torch
import torch.nn.functional as F
......@@ -19,21 +20,23 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, row_parallel_weight_loader)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.image import (cached_get_tokenizer,
repeat_and_pad_image_tokens)
from vllm.sequence import IntermediateTensors, SamplerOutput, SequenceData
from vllm.multimodal.utils import (cached_get_tokenizer,
repeat_and_pad_placeholder_tokens)
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
SamplerOutput, SequenceData)
from vllm.utils import print_warning_once
from .interfaces import SupportsVision
from .interfaces import SupportsMultiModal
logger = init_logger(__name__)
......@@ -59,6 +62,7 @@ def get_max_chameleon_image_tokens(ctx: InputContext):
def dummy_seq_data_for_chameleon(
seq_len: int,
num_images: int,
*,
image_token_id: int,
image_feature_size_override: Optional[int] = None,
......@@ -68,12 +72,16 @@ def dummy_seq_data_for_chameleon(
else:
image_feature_size = image_feature_size_override
token_ids = [image_token_id] * image_feature_size
token_ids += [0] * (seq_len - image_feature_size)
token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
[image_token_id]) * image_feature_size * num_images
token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE,
[0]) * (seq_len - image_feature_size * num_images)
return SequenceData(token_ids)
def dummy_image_for_chameleon(
num_images: int,
*,
image_width_override: Optional[int] = None,
image_height_override: Optional[int] = None,
):
......@@ -85,17 +93,20 @@ def dummy_image_for_chameleon(
height = image_height_override
image = Image.new("RGB", (width, height), color=0)
return {"image": image}
return {"image": image if num_images == 1 else [image] * num_images}
def dummy_data_for_chameleon(ctx: InputContext, seq_len: int):
def dummy_data_for_chameleon(ctx: InputContext, seq_len: int,
mm_counts: Mapping[str, int]):
num_images = mm_counts["image"]
seq_data = dummy_seq_data_for_chameleon(
seq_len,
num_images,
image_token_id=CHAMELEON_IMAGE_TOKEN_ID,
)
mm_data = dummy_image_for_chameleon()
mm_data = dummy_image_for_chameleon(num_images)
return seq_data, mm_data
......@@ -113,11 +124,11 @@ def input_processor_for_chameleon(ctx: InputContext, llm_inputs: LLMInputs):
model_config = ctx.model_config
tokenizer = cached_get_tokenizer(model_config.tokenizer)
new_prompt, new_token_ids = repeat_and_pad_image_tokens(
new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
tokenizer,
llm_inputs.get("prompt"),
llm_inputs["prompt_token_ids"],
image_token_id=CHAMELEON_IMAGE_TOKEN_ID,
placeholder_token_id=CHAMELEON_IMAGE_TOKEN_ID,
repeat_count=CHAMELEON_IMAGE_SEQ_LENGTH,
pad_token_left=CHAMELEON_IMAGE_START_TOKEN_ID,
pad_token_right=CHAMELEON_IMAGE_END_TOKEN_ID,
......@@ -141,6 +152,11 @@ class ChameleonLayerNorm(nn.LayerNorm):
super().__init__(hidden_size, *args, **kwargs)
self.normalized_shape = (hidden_size[-1], )
set_weight_attrs(self.weight,
{"weight_loader": row_parallel_weight_loader})
set_weight_attrs(self.bias,
{"weight_loader": row_parallel_weight_loader})
def forward(self, hidden_states):
hidden_states = F.layer_norm(hidden_states,
self.normalized_shape,
......@@ -697,6 +713,8 @@ class ChameleonVQVAEEncoder(nn.Module):
)
def forward(self, pixel_values: torch.Tensor):
pixel_values = pixel_values.to(self.conv_in.weight.dtype)
# downsampling
hidden_states = [self.conv_in(pixel_values)]
for i_level in range(self.num_resolutions):
......@@ -877,7 +895,7 @@ class ChameleonModel(nn.Module):
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_chameleon_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_chameleon)
@INPUT_REGISTRY.register_input_processor(input_processor_for_chameleon)
class ChameleonForConditionalGeneration(nn.Module, SupportsVision):
class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal):
def __init__(
self,
......@@ -959,15 +977,19 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsVision):
attn_metadata)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
# Disallow image tokens which does not include special
# begin-image and end-image tokens
image_tokens = self.model.vocabulary_mapping.image_tokens
logits[:, image_tokens] = torch.finfo(logits.dtype).min
if logits is not None:
image_tokens = self.model.vocabulary_mapping.image_tokens
logits[:, image_tokens] = torch.finfo(logits.dtype).min
return logits
......
......@@ -368,6 +368,9 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA):
self.max_position_embeddings = getattr(config, "max_sequence_length",
8192)
self.transformer = ChatGLMModel(config, cache_config, quant_config)
if self.config.tie_word_embeddings:
self.transformer.output_layer.weight = (
self.transformer.embedding.weight)
self.lm_head = self.transformer.output_layer
self.logits_processor = LogitsProcessor(config.padded_vocab_size)
self.sampler = Sampler()
......@@ -393,8 +396,11 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA):
attn_metadata)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
......
"""Minimal implementation of CLIPVisionModel intended to be only used
within a vision language model."""
from typing import Optional
from array import array
from typing import Iterable, Optional, Tuple
import torch
import torch.nn as nn
......@@ -14,9 +15,10 @@ from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.multimodal.image import (cached_get_tokenizer,
repeat_and_pad_image_tokens)
from vllm.sequence import SequenceData
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.multimodal.utils import (cached_get_tokenizer,
repeat_and_pad_placeholder_tokens)
from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData
def get_clip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
......@@ -32,7 +34,7 @@ def get_clip_num_patches(*, image_size: int, patch_size: int) -> int:
def get_clip_image_feature_size(hf_config: CLIPVisionConfig) -> int:
return get_clip_num_patches(image_size=hf_config.image_size,
patch_size=hf_config.patch_size)
patch_size=hf_config.patch_size) + 1
def get_max_clip_image_tokens(hf_config: CLIPVisionConfig) -> int:
......@@ -42,6 +44,7 @@ def get_max_clip_image_tokens(hf_config: CLIPVisionConfig) -> int:
def dummy_seq_data_for_clip(
hf_config: CLIPVisionConfig,
seq_len: int,
num_images: int,
*,
image_token_id: int,
image_feature_size_override: Optional[int] = None,
......@@ -51,13 +54,16 @@ def dummy_seq_data_for_clip(
else:
image_feature_size = image_feature_size_override
token_ids = [image_token_id] * image_feature_size
token_ids += [0] * (seq_len - image_feature_size)
token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
[image_token_id]) * image_feature_size * num_images
token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE,
[0]) * (seq_len - image_feature_size * num_images)
return SequenceData(token_ids)
def dummy_image_for_clip(
hf_config: CLIPVisionConfig,
num_images: int,
*,
image_width_override: Optional[int] = None,
image_height_override: Optional[int] = None,
......@@ -69,7 +75,7 @@ def dummy_image_for_clip(
height = image_height_override
image = Image.new("RGB", (width, height), color=0)
return {"image": image}
return {"image": image if num_images == 1 else [image] * num_images}
def input_processor_for_clip(
......@@ -87,15 +93,21 @@ def input_processor_for_clip(
tokenizer = cached_get_tokenizer(model_config.tokenizer)
if image_feature_size_override is None:
image_feature_size = get_clip_image_feature_size(hf_config)
image_data = multi_modal_data["image"]
if isinstance(image_data, Image.Image):
image_feature_size = get_clip_image_feature_size(hf_config)
elif isinstance(image_data, torch.Tensor):
image_feature_size = image_data.shape[0]
else:
raise TypeError(f"Invalid image type: {type(image_data)}")
else:
image_feature_size = image_feature_size_override
new_prompt, new_token_ids = repeat_and_pad_image_tokens(
new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
tokenizer,
llm_inputs.get("prompt"),
llm_inputs["prompt_token_ids"],
image_token_id=image_token_id,
placeholder_token_id=image_token_id,
repeat_count=image_feature_size,
)
......@@ -291,3 +303,22 @@ class CLIPVisionModel(nn.Module):
@property
def device(self):
return next(self.parameters()).device
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
params_dict = dict(self.named_parameters())
layer_count = len(self.vision_model.encoder.layers)
for name, loaded_weight in weights:
# post_layernorm is not needed in CLIPVisionModel
if "vision_model.post_layernorm" in name:
continue
# omit layers when num_hidden_layers_override is set
if "vision_model.encoder.layers." in name:
layer_idx = int(name.split(".")[3])
if layer_idx >= layer_count:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
......@@ -25,13 +25,11 @@ from typing import Iterable, List, Optional, Set, Tuple
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn.parameter import Parameter
from transformers import CohereConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
......@@ -43,7 +41,8 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, row_parallel_weight_loader)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs
from vllm.sequence import IntermediateTensors, SamplerOutput
......@@ -67,25 +66,14 @@ class LayerNorm(nn.Module):
super().__init__()
self.weight = nn.Parameter(torch.ones(param_shape))
self.variance_epsilon = eps
set_weight_attrs(self.weight, {"weight_loader": self.weight_loader})
set_weight_attrs(self.weight,
{"weight_loader": row_parallel_weight_loader})
def forward(self, hidden_states, residuals=None):
hidden_states = layer_norm_func(hidden_states, self.weight,
self.variance_epsilon)
return hidden_states, residuals
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
tp_rank = get_tensor_model_parallel_rank()
shard_dim = 0 if param.dim() != 1 else None
param_data = param.data
if shard_dim is not None:
shard_size = param_data.shape[shard_dim]
start_idx = tp_rank * shard_size
loaded_weight = loaded_weight.narrow(shard_dim, start_idx,
shard_size)
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
# Copied from transformers.models.llama.modeling_llama.LlamaMLP Llama->Cohere
class CohereMLP(nn.Module):
......@@ -333,6 +321,9 @@ class CohereForCausalLM(nn.Module):
) -> None:
super().__init__()
self.config = config
# currently all existing command R models have `tie_word_embeddings`
# enabled
assert config.tie_word_embeddings
self.unpadded_vocab_size = config.vocab_size
if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
......@@ -359,8 +350,11 @@ class CohereForCausalLM(nn.Module):
attn_metadata)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
is_not_lora = hasattr(self.model.embed_tokens, 'weight')
if is_not_lora:
logits = self.logits_processor(self.model.embed_tokens,
......
......@@ -362,6 +362,9 @@ class DbrxForCausalLM(nn.Module):
):
super().__init__()
self.config = config
if config.tie_word_embeddings:
raise ValueError(
"tie_word_embeddings is not supported for Dbrx models.")
self.quant_config = quant_config
self.unpadded_vocab_size = config.vocab_size
self.transformer = DbrxModel(config, cache_config, quant_config)
......@@ -388,8 +391,11 @@ class DbrxForCausalLM(nn.Module):
attn_metadata)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
......
......@@ -380,6 +380,8 @@ class DeepseekForCausalLM(nn.Module):
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config)
if self.config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
......@@ -395,8 +397,11 @@ class DeepseekForCausalLM(nn.Module):
attn_metadata)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
......
......@@ -505,8 +505,11 @@ class DeepseekV2ForCausalLM(nn.Module):
attn_metadata, intermediate_tensors)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
......@@ -590,7 +593,7 @@ class DeepseekV2ForCausalLM(nn.Module):
weight_loader = param.weight_loader
weight_loader(param,
loaded_weight,
weight_name,
name,
shard_id=shard_id,
expert_id=expert_id)
break
......
from typing import Iterable, List, Optional, Tuple
import torch
import torch.nn as nn
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models import ModelRegistry
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors, SamplerOutput
from vllm.transformers_utils.configs.eagle import EAGLEConfig
class EAGLE(nn.Module):
"""This class implements the EAGLE draft model from the paper: https://arxiv.org/pdf/2401.15077
Reference implementation: https://github.com/SafeAILab/EAGLE
Differences from reference implementation:
1. In reference, LlamaDecoderLayer implementation doesn't have
input_layernorm for 1st decoder layer (https://github.com/SafeAILab/EAGLE/blob/7d065d084443fbfd386f88839efd7193c12be869/eagle/model/cnets.py#L427)
but we do as HF implementation also does.
2. We allow any decoder layer to be used in EAGLE whereas in reference
decoder layer is fixed to be LlamaDecoderLayer.
3. We have an optional token_map which reduces draft vocab to most
frequently used tokens to give some additional speed-up by reducing
sampling overhead. This is disabled unless the checkpoint file has
explicit token_map tensor and config has an optional attribute
truncated_vocab_size < vocab_size. To use this technique, one has to find
the top-k most frequent tokens in target dataset and add that as a tensor
in the draft checkpoint (using key token_map). Also, the draft config
needs to have truncated_vocab_size (=k) as an attribute."""
def __init__(self, config: EAGLEConfig, *args, **kwargs) -> None:
super().__init__()
self.config = config
architectures = getattr(self.config.model, "architectures", [])
model_cls, _ = ModelRegistry.resolve_model_cls(architectures)
self.model = model_cls(self.config.model, *args, **kwargs)
self.fc = nn.Linear(config.model.hidden_size * 2,
config.model.hidden_size,
bias=False)
self.orig_vocab_size = config.vocab_size
self.truncated_vocab_size = config.truncated_vocab_size
self.unpadded_vocab_size = self.truncated_vocab_size
self.lm_head = ParallelLMHead(
self.unpadded_vocab_size,
config.hidden_size,
org_num_embeddings=self.truncated_vocab_size,
padding_size=DEFAULT_VOCAB_PADDING_SIZE,
)
logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
self.truncated_vocab_size,
logit_scale)
# Token map is a idx to token mapping to reduce the vocab size for
# the draft model. Using smaller vocab size for draft, containing
# only most frequent tokens reduces the speculation overhead. This
# doesn't affect the acceptance rate much and thus gives more speed
# -up. By default, this is disabled and is only used if the EAGLE
# checkpoint file has token_map tensor.
self.token_map = None
@property
def sampler(self):
return self.model.sampler
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
previous_hidden_states: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor:
tok_embeds = self.model.model.embed_tokens(input_ids)
inputs_embeds = self.fc(
torch.cat([tok_embeds, previous_hidden_states], dim=-1))
inputs_embeds[positions == 0] = 0 # masking inputs at position=0
hidden_states = self.model.model(
input_ids=None,
inputs_embeds=inputs_embeds,
positions=positions,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
intermediate_tensors=intermediate_tensors)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
if self.token_map is not None:
_logits = logits
logits = -torch.inf * torch.ones(
size=(*_logits.shape[:-1], self.orig_vocab_size),
device=_logits.device,
dtype=_logits.dtype)
logits[..., self.token_map] = _logits
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
# This implementation is incompitable with https://huggingface.co/yuhuili/EAGLE-LLaMA3-Instruct-8B
# due to missing lm_head weights and its config being that of a
# Llama model. Here's a compatible version with the same weights:
# https://huggingface.co/abhigoyal/EAGLE-LLaMA3-Instruct-8B-vllm
# Also, here's an example script for converting trained EAGLE
# checkpoint to vLLM compatible version: https://gist.github.com/abhigoyal1997/1e7a4109ccb7704fbc67f625e86b2d6d
model_weights = {}
for name, loaded_weight in weights:
if name == "token_map":
if self.config.truncated_vocab_size < self.config.vocab_size:
self.token_map = nn.Parameter(loaded_weight,
requires_grad=False)
elif name.startswith("fc."):
weight_loader = getattr(self.fc.weight, "weight_loader",
default_weight_loader)
weight_loader(self.fc.weight, loaded_weight)
elif name.startswith("model.lm_head.") or name.startswith(
"model.model."):
model_weights[name.split("model.", 1)[-1]] = loaded_weight
elif name.startswith("lm_head.") or name.startswith("model."):
model_weights[name] = loaded_weight
else:
model_weights[f"model.{name}"] = loaded_weight
lm_head_weight = model_weights.pop("lm_head.weight")
if self.token_map is not None and\
lm_head_weight.shape[0] > self.token_map.shape[0]:
lm_head_weight = lm_head_weight[self.token_map]
weight_loader = getattr(self.lm_head.weight, "weight_loader",
default_weight_loader)
weight_loader(self.lm_head.weight, lm_head_weight)
self.model.load_weights(model_weights.items())
......@@ -420,8 +420,11 @@ class FalconForCausalLM(nn.Module):
)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
......
......@@ -16,7 +16,8 @@
# limitations under the License.
""" PyTorch Fuyu model."""
import math
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict
from array import array
from typing import Iterable, List, Literal, Mapping, Optional, Tuple, TypedDict
import torch
import torch.nn as nn
......@@ -29,19 +30,19 @@ from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.persimmon import PersimmonForCausalLM
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.base import MultiModalInputs
from vllm.multimodal.image import (cached_get_image_processor,
cached_get_tokenizer)
from vllm.sequence import IntermediateTensors, SamplerOutput, SequenceData
from vllm.multimodal.image import cached_get_image_processor
from vllm.multimodal.utils import cached_get_tokenizer
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
SamplerOutput, SequenceData)
from .interfaces import SupportsVision
from .utils import merge_vision_embeddings
from .interfaces import SupportsMultiModal
from .utils import merge_multimodal_embeddings
logger = init_logger(__name__)
......@@ -94,27 +95,36 @@ def get_max_fuyu_image_tokens(ctx: InputContext):
return (ncol + 1) * nrow
def dummy_seq_data_for_fuyu(ctx: InputContext, seq_len: int):
def dummy_seq_data_for_fuyu(ctx: InputContext, seq_len: int, num_images: int):
ncol, nrow = get_max_fuyu_image_feature_size()
image_feature_size = get_max_fuyu_image_tokens(ctx)
token_ids = ([_IMAGE_TOKEN_ID] * ncol + [_NEWLINE_TOKEN_ID]) * nrow
token_ids += [0] * (seq_len - image_feature_size)
image_token_ids = (
array(VLLM_TOKEN_ID_ARRAY_TYPE, [_IMAGE_TOKEN_ID]) * ncol +
array(VLLM_TOKEN_ID_ARRAY_TYPE, [_NEWLINE_TOKEN_ID])) * nrow
token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, image_token_ids) * num_images
token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE,
[0]) * (seq_len - image_feature_size * num_images)
return SequenceData(token_ids)
def dummy_image_for_fuyu(
num_images: int,
*,
image_width: int,
image_height: int,
):
image = Image.new("RGB", (image_width, image_height), color=0)
return {"image": image}
return {"image": image if num_images == 1 else [image] * num_images}
def dummy_data_for_fuyu(ctx: InputContext, seq_len: int):
seq_data = dummy_seq_data_for_fuyu(ctx, seq_len)
mm_data = dummy_image_for_fuyu(MAX_IMAGE_FEATURE_SIZE_WIDTH,
MAX_IMAGE_FEATURE_SIZE_HEIGHT)
def dummy_data_for_fuyu(ctx: InputContext, seq_len: int,
mm_counts: Mapping[str, int]):
num_images = mm_counts["image"]
seq_data = dummy_seq_data_for_fuyu(ctx, seq_len, num_images)
mm_data = dummy_image_for_fuyu(num_images,
image_width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
image_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT)
return seq_data, mm_data
......@@ -209,7 +219,7 @@ def input_mapper_for_fuyu(ctx: InputContext, data: object):
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_fuyu_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_fuyu)
@INPUT_REGISTRY.register_input_processor(input_processor_for_fuyu)
class FuyuForCausalLM(nn.Module, SupportsVision):
class FuyuForCausalLM(nn.Module, SupportsMultiModal):
def __init__(self,
config: FuyuConfig,
......@@ -234,7 +244,8 @@ class FuyuForCausalLM(nn.Module, SupportsVision):
cache_config=cache_config,
quant_config=quant_config)
def _parse_and_validate_image_input(self, **kwargs: object):
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[FuyuImagePixelInputs]:
image_patches = kwargs.pop("image_patches", None)
if isinstance(image_patches, torch.Tensor):
......@@ -249,6 +260,13 @@ class FuyuForCausalLM(nn.Module, SupportsVision):
data=image_patches)
return None
def _process_image_input(
self, image_input: FuyuImagePixelInputs) -> torch.Tensor:
assert self.vision_embed_tokens is not None
vision_embeddings, _ = self.vision_embed_tokens(image_input["data"])
return vision_embeddings
def forward(
self,
input_ids: torch.Tensor,
......@@ -261,12 +279,11 @@ class FuyuForCausalLM(nn.Module, SupportsVision):
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is not None:
vision_embeddings, _ = self.vision_embed_tokens(
image_input["data"])
vision_embeddings = self._process_image_input(image_input)
inputs_embeds = self.language_model.model.embed_tokens(input_ids)
inputs_embeds = merge_vision_embeddings(input_ids, inputs_embeds,
vision_embeddings,
self.image_token_id)
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, vision_embeddings,
self.image_token_id)
else:
inputs_embeds = None
......@@ -280,8 +297,11 @@ class FuyuForCausalLM(nn.Module, SupportsVision):
)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.language_model.logits_processor(
self.language_model.lm_head, hidden_states, sampling_metadata)
return logits
......
......@@ -33,7 +33,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import GemmaRotaryEmbedding
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
......@@ -148,14 +148,12 @@ class GemmaAttention(nn.Module):
quant_config=quant_config,
)
# TODO(woosuk): Use the `get_rope` interface.
self.rotary_emb = GemmaRotaryEmbedding(
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position_embeddings=max_position_embeddings,
max_position=max_position_embeddings,
base=self.rope_theta,
is_neox_style=True,
dtype=torch.get_default_dtype(),
)
self.attn = Attention(self.num_heads,
self.head_dim,
......@@ -333,6 +331,8 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA):
super().__init__()
self.config = config
# currently all existing Gemma models have `tie_word_embeddings` enabled
assert config.tie_word_embeddings
self.lora_config = lora_config
self.quant_config = quant_config
......@@ -352,8 +352,11 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA):
attn_metadata)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.model.embed_tokens, hidden_states,
sampling_metadata)
return logits
......
......@@ -32,7 +32,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import GemmaRotaryEmbedding
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
......@@ -130,14 +130,12 @@ class Gemma2Attention(nn.Module):
bias=config.attention_bias,
quant_config=quant_config,
)
# TODO(woosuk): Use the `get_rope` interface.
self.rotary_emb = GemmaRotaryEmbedding(
self.rotary_emb = get_rope(
self.head_dim,
self.head_dim,
max_position_embeddings,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
base=self.rope_theta,
is_neox_style=True,
dtype=torch.get_default_dtype(),
)
# FIXME(woosuk): While Gemma 2 uses sliding window attention for every
......@@ -325,6 +323,8 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA):
del lora_config # Unused.
super().__init__()
self.config = config
# currently all existing Gemma models have `tie_word_embeddings` enabled
assert config.tie_word_embeddings
self.quant_config = quant_config
self.model = Gemma2Model(config, cache_config, quant_config)
self.logits_processor = LogitsProcessor(
......@@ -343,8 +343,11 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA):
attn_metadata)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.model.embed_tokens, hidden_states,
sampling_metadata)
return logits
......
......@@ -36,7 +36,7 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors, SamplerOutput
......@@ -249,7 +249,11 @@ class GPT2LMHeadModel(nn.Module):
cache_config,
quant_config,
prefix="transformer")
self.lm_head = self.transformer.wte
if self.config.tie_word_embeddings:
self.lm_head = self.transformer.wte
else:
self.lm_head = ParallelLMHead(self.config.vocab_size,
self.config.hidden_size)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
......@@ -265,8 +269,11 @@ class GPT2LMHeadModel(nn.Module):
attn_metadata, intermediate_tensors)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
......
......@@ -36,7 +36,7 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors, SamplerOutput
......@@ -259,7 +259,13 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA):
self.quant_config = quant_config
self.transformer = GPTBigCodeModel(config, cache_config, quant_config,
lora_config)
self.lm_head = self.transformer.wte
if self.config.tie_word_embeddings:
self.lm_head = self.transformer.wte
else:
self.lm_head = ParallelLMHead(
self.transformer.vocab_size,
self.transformer.embed_dim,
org_num_embeddings=self.config.vocab_size)
self.unpadded_vocab_size = config.vocab_size
if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
......@@ -279,8 +285,11 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA):
attn_metadata)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
......
......@@ -246,8 +246,11 @@ class GPTJForCausalLM(nn.Module):
attn_metadata)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata, self.lm_head.bias)
return logits
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment