Unverified Commit 0f6d7a9a authored by Murali Andoorveedu's avatar Murali Andoorveedu Committed by GitHub
Browse files

[Models] Add remaining model PP support (#7168)


Signed-off-by: default avatarMuralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>
Signed-off-by: default avatarMurali Andoorveedu <muralidhar.andoorveedu@centml.ai>
Co-authored-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent 303d4479
...@@ -28,7 +28,7 @@ from transformers import FalconConfig as HF_FalconConfig ...@@ -28,7 +28,7 @@ from transformers import FalconConfig as HF_FalconConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig from vllm.config import CacheConfig
from vllm.distributed import (get_tensor_model_parallel_rank, from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce) tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
...@@ -36,8 +36,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, ...@@ -36,8 +36,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization import QuantizationConfig
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
...@@ -47,6 +46,10 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata ...@@ -47,6 +46,10 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs import RWConfig from vllm.transformers_utils.configs import RWConfig
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
FalconConfig = Union[HF_FalconConfig, RWConfig] FalconConfig = Union[HF_FalconConfig, RWConfig]
...@@ -333,6 +336,7 @@ class FalconModel(nn.Module): ...@@ -333,6 +336,7 @@ class FalconModel(nn.Module):
config: FalconConfig, config: FalconConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -347,35 +351,45 @@ class FalconModel(nn.Module): ...@@ -347,35 +351,45 @@ class FalconModel(nn.Module):
) )
# Transformer blocks # Transformer blocks
self.h = nn.ModuleList([ self.start_layer, self.end_layer, self.h = make_layers(
FalconDecoderLayer(config, cache_config, quant_config) config.num_hidden_layers,
for _ in range(config.num_hidden_layers) lambda prefix: FalconDecoderLayer(config, cache_config,
]) quant_config),
prefix=f"{prefix}.h")
# Final Layer Norm # Final Layer Norm
self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(["hidden_states"],
config.hidden_size))
def forward( def forward(
self, self,
input_ids: torch.LongTensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: intermediate_tensors: Optional[IntermediateTensors],
hidden_states = self.word_embeddings(input_ids) ) -> Union[torch.Tensor, IntermediateTensors]:
for i in range(len(self.h)): if get_pp_group().is_first_rank:
hidden_states = self.word_embeddings(input_ids)
else:
hidden_states = intermediate_tensors["hidden_states"]
for i in range(self.start_layer, self.end_layer):
layer = self.h[i] layer = self.h[i]
hidden_states = layer( hidden_states = layer(
positions, positions,
hidden_states, hidden_states,
kv_caches[i], kv_caches[i - self.start_layer],
attn_metadata, attn_metadata,
) )
if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states})
hidden_states = self.ln_f(hidden_states) hidden_states = self.ln_f(hidden_states)
return hidden_states return hidden_states
class FalconForCausalLM(nn.Module): class FalconForCausalLM(nn.Module, SupportsPP):
def __init__( def __init__(
self, self,
...@@ -403,6 +417,8 @@ class FalconForCausalLM(nn.Module): ...@@ -403,6 +417,8 @@ class FalconForCausalLM(nn.Module):
) )
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
self.make_empty_intermediate_tensors = (
self.transformer.make_empty_intermediate_tensors)
def forward( def forward(
self, self,
...@@ -412,12 +428,8 @@ class FalconForCausalLM(nn.Module): ...@@ -412,12 +428,8 @@ class FalconForCausalLM(nn.Module):
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.transformer( hidden_states = self.transformer(input_ids, positions, kv_caches,
input_ids, attn_metadata, intermediate_tensors)
positions,
kv_caches,
attn_metadata,
)
return hidden_states return hidden_states
def compute_logits( def compute_logits(
...@@ -454,6 +466,8 @@ class FalconForCausalLM(nn.Module): ...@@ -454,6 +466,8 @@ class FalconForCausalLM(nn.Module):
# Skip loading extra bias for GPTQ models. # Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
continue continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name] param = params_dict[name]
if "query_key_value" in name: if "query_key_value" in name:
output_dim = getattr(param, "output_dim", None) output_dim = getattr(param, "output_dim", None)
......
...@@ -41,8 +41,9 @@ from vllm.multimodal.utils import cached_get_tokenizer ...@@ -41,8 +41,9 @@ from vllm.multimodal.utils import cached_get_tokenizer
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
SequenceData) SequenceData)
from .interfaces import SupportsMultiModal from .interfaces import SupportsMultiModal, SupportsPP
from .utils import flatten_bn, merge_multimodal_embeddings from .utils import (flatten_bn, group_weights_with_prefix,
merge_multimodal_embeddings)
# Cannot find the following 2 numbers from hf config. # Cannot find the following 2 numbers from hf config.
_IMAGE_TOKEN_ID = 71011 _IMAGE_TOKEN_ID = 71011
...@@ -217,7 +218,7 @@ def input_mapper_for_fuyu(ctx: InputContext, data: object): ...@@ -217,7 +218,7 @@ def input_mapper_for_fuyu(ctx: InputContext, data: object):
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_fuyu_image_tokens) @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_fuyu_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_fuyu) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_fuyu)
@INPUT_REGISTRY.register_input_processor(input_processor_for_fuyu) @INPUT_REGISTRY.register_input_processor(input_processor_for_fuyu)
class FuyuForCausalLM(nn.Module, SupportsMultiModal): class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
def __init__(self, def __init__(self,
config: FuyuConfig, config: FuyuConfig,
...@@ -242,6 +243,12 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal): ...@@ -242,6 +243,12 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal):
self.language_model = PersimmonForCausalLM(config.text_config, self.language_model = PersimmonForCausalLM(config.text_config,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config) quant_config=quant_config)
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors)
@property
def sampler(self):
return self.language_model.sampler
def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
...@@ -297,23 +304,29 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal): ...@@ -297,23 +304,29 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal):
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
**kwargs: object, **kwargs: object,
): ):
image_input = self._parse_and_validate_image_input(**kwargs) if intermediate_tensors is not None:
input_ids = None
inputs_embeds = None
else:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is not None: if image_input is not None:
vision_embeddings = self._process_image_input(image_input) vision_embeddings = self._process_image_input(image_input)
inputs_embeds = self.language_model.model.embed_tokens(input_ids) inputs_embeds = self.language_model.model.embed_tokens(
inputs_embeds = merge_multimodal_embeddings( input_ids)
input_ids, inputs_embeds, vision_embeddings, inputs_embeds = merge_multimodal_embeddings(
self.image_token_id) input_ids, inputs_embeds, vision_embeddings,
self.image_token_id)
else: else:
inputs_embeds = None inputs_embeds = None
hidden_states = self.language_model( hidden_states = self.language_model(
input_ids=input_ids, input_ids=input_ids,
positions=positions, positions=positions,
kv_caches=kv_caches, kv_caches=kv_caches,
attn_metadata=attn_metadata, attn_metadata=attn_metadata,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
) )
return hidden_states return hidden_states
...@@ -336,34 +349,16 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal): ...@@ -336,34 +349,16 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal):
return next_tokens return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
params_dict = dict(self.named_parameters(remove_duplicate=False)) # prepare weight iterators for components
for name, loaded_weight in weights: weights_group = group_weights_with_prefix(weights)
if "rotary_emb.inv_freq" in name:
continue
if ("rotary_emb.cos_cached" in name
or "rotary_emb.sin_cached" in name):
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
param = params_dict[name]
if "query_key_value" in name:
# copy from vllm/model_executor/models/bloom.py
# NOTE: Fuyu's fused QKV's output_dim has the shape of
# (num_heads * 3 * head_size), while the
# required shape is (3 * num_heads * head_size).
# Thus, we need weight conversion.
output_dim = getattr(param, "output_dim", None)
num_heads = self.config.num_attention_heads
if output_dim is not None:
loaded_weight_shape = loaded_weight.shape
loaded_weight = loaded_weight.view(
loaded_weight_shape[:output_dim] + (num_heads, 3, -1) +
loaded_weight_shape[output_dim + 1:])
loaded_weight = loaded_weight.transpose(
output_dim, output_dim + 1)
loaded_weight = loaded_weight.reshape(loaded_weight_shape)
# load vision embeddings
vision_params_dict = dict(self.vision_embed_tokens.named_parameters())
for name, loaded_weight in weights_group["vision_embed_tokens"]:
param = vision_params_dict[name]
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
# load llm backbone
self.language_model.load_weights(weights_group["language_model"])
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
# limitations under the License. # limitations under the License.
"""Inference-only Gemma model compatible with HuggingFace weights.""" """Inference-only Gemma model compatible with HuggingFace weights."""
from functools import lru_cache from functools import lru_cache
from typing import Iterable, List, Optional, Set, Tuple from typing import Iterable, List, Optional, Set, Tuple, Union
import torch import torch
from torch import nn from torch import nn
...@@ -23,7 +23,7 @@ from transformers import GemmaConfig ...@@ -23,7 +23,7 @@ from transformers import GemmaConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.activation import GeluAndMul from vllm.model_executor.layers.activation import GeluAndMul
from vllm.model_executor.layers.layernorm import GemmaRMSNorm from vllm.model_executor.layers.layernorm import GemmaRMSNorm
...@@ -31,8 +31,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, ...@@ -31,8 +31,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization import QuantizationConfig
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
...@@ -41,7 +40,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader ...@@ -41,7 +40,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA from .interfaces import SupportsLoRA, SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -245,6 +246,7 @@ class GemmaModel(nn.Module): ...@@ -245,6 +246,7 @@ class GemmaModel(nn.Module):
config: GemmaConfig, config: GemmaConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -253,10 +255,11 @@ class GemmaModel(nn.Module): ...@@ -253,10 +255,11 @@ class GemmaModel(nn.Module):
config.vocab_size, config.vocab_size,
config.hidden_size, config.hidden_size,
) )
self.layers = nn.ModuleList([ self.start_layer, self.end_layer, self.layers = make_layers(
GemmaDecoderLayer(config, cache_config, quant_config) config.num_hidden_layers,
for _ in range(config.num_hidden_layers) lambda prefix: GemmaDecoderLayer(config, cache_config, quant_config
]) ),
prefix=f"{prefix}.layers")
self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
# Normalize the embedding by sqrt(hidden_size) # Normalize the embedding by sqrt(hidden_size)
...@@ -265,6 +268,9 @@ class GemmaModel(nn.Module): ...@@ -265,6 +268,9 @@ class GemmaModel(nn.Module):
# See https://github.com/huggingface/transformers/pull/29402 # See https://github.com/huggingface/transformers/pull/29402
normalizer = self.config.hidden_size**0.5 normalizer = self.config.hidden_size**0.5
self.register_buffer("normalizer", torch.tensor(normalizer)) self.register_buffer("normalizer", torch.tensor(normalizer))
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids) return self.embed_tokens(input_ids)
...@@ -275,29 +281,38 @@ class GemmaModel(nn.Module): ...@@ -275,29 +281,38 @@ class GemmaModel(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> Union[torch.Tensor, IntermediateTensors]:
if inputs_embeds is not None: if get_pp_group().is_first_rank:
hidden_states = inputs_embeds if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
hidden_states *= self.normalizer
residual = None
else: else:
hidden_states = self.get_input_embeddings(input_ids) hidden_states = intermediate_tensors["hidden_states"]
hidden_states *= self.normalizer residual = intermediate_tensors["residual"]
residual = None for i in range(self.start_layer, self.end_layer):
for i in range(len(self.layers)):
layer = self.layers[i] layer = self.layers[i]
hidden_states, residual = layer( hidden_states, residual = layer(
positions, positions,
hidden_states, hidden_states,
kv_caches[i], kv_caches[i - self.start_layer],
attn_metadata, attn_metadata,
residual, residual,
) )
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
hidden_states, _ = self.norm(hidden_states, residual) hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states return hidden_states
class GemmaForCausalLM(nn.Module, SupportsLoRA): class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
packed_modules_mapping = { packed_modules_mapping = {
"qkv_proj": [ "qkv_proj": [
"q_proj", "q_proj",
...@@ -339,6 +354,8 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA): ...@@ -339,6 +354,8 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA):
self.model = GemmaModel(config, cache_config, quant_config) self.model = GemmaModel(config, cache_config, quant_config)
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def forward( def forward(
self, self,
...@@ -347,9 +364,9 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA): ...@@ -347,9 +364,9 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA):
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata) attn_metadata, intermediate_tensors)
return hidden_states return hidden_states
def compute_logits( def compute_logits(
...@@ -388,6 +405,8 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA): ...@@ -388,6 +405,8 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA):
# Skip loading extra bias for GPTQ models. # Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
continue continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name] param = params_dict[name]
weight_loader = param.weight_loader weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id) weight_loader(param, loaded_weight, shard_id)
...@@ -400,6 +419,8 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA): ...@@ -400,6 +419,8 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA):
# Skip loading extra bias for GPTQ models. # Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
continue continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name] param = params_dict[name]
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Iterable, List, Optional, Set, Tuple from typing import Iterable, List, Optional, Set, Tuple, Union
import torch import torch
from torch import nn from torch import nn
...@@ -22,7 +22,7 @@ from transformers import Gemma2Config ...@@ -22,7 +22,7 @@ from transformers import Gemma2Config
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.activation import GeluAndMul from vllm.model_executor.layers.activation import GeluAndMul
from vllm.model_executor.layers.layernorm import GemmaRMSNorm from vllm.model_executor.layers.layernorm import GemmaRMSNorm
...@@ -30,8 +30,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, ...@@ -30,8 +30,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization import QuantizationConfig
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
...@@ -40,7 +39,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader ...@@ -40,7 +39,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA from .interfaces import SupportsLoRA, SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -244,6 +245,7 @@ class Gemma2Model(nn.Module): ...@@ -244,6 +245,7 @@ class Gemma2Model(nn.Module):
config: Gemma2Config, config: Gemma2Config,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -252,10 +254,11 @@ class Gemma2Model(nn.Module): ...@@ -252,10 +254,11 @@ class Gemma2Model(nn.Module):
config.vocab_size, config.vocab_size,
config.hidden_size, config.hidden_size,
) )
self.layers = nn.ModuleList([ self.start_layer, self.end_layer, self.layers = make_layers(
Gemma2DecoderLayer(layer_idx, config, cache_config, quant_config) config.num_hidden_layers,
for layer_idx in range(config.num_hidden_layers) lambda prefix: Gemma2DecoderLayer(int(prefix.split(".")[
]) -1]), config, cache_config, quant_config),
prefix=f"{prefix}.layers")
self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
# Normalize the embedding by sqrt(hidden_size) # Normalize the embedding by sqrt(hidden_size)
...@@ -264,6 +267,9 @@ class Gemma2Model(nn.Module): ...@@ -264,6 +267,9 @@ class Gemma2Model(nn.Module):
# See https://github.com/huggingface/transformers/pull/29402 # See https://github.com/huggingface/transformers/pull/29402
normalizer = self.config.hidden_size**0.5 normalizer = self.config.hidden_size**0.5
self.register_buffer("normalizer", torch.tensor(normalizer)) self.register_buffer("normalizer", torch.tensor(normalizer))
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
def forward( def forward(
self, self,
...@@ -271,25 +277,36 @@ class Gemma2Model(nn.Module): ...@@ -271,25 +277,36 @@ class Gemma2Model(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: intermediate_tensors: Optional[IntermediateTensors],
hidden_states = self.embed_tokens(input_ids) ) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states *= self.normalizer if get_pp_group().is_first_rank:
hidden_states = self.embed_tokens(input_ids)
hidden_states *= self.normalizer
residual = None residual = None
for i in range(len(self.layers)): else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i] layer = self.layers[i]
hidden_states, residual = layer( hidden_states, residual = layer(
positions, positions,
hidden_states, hidden_states,
kv_caches[i], kv_caches[i - self.start_layer],
attn_metadata, attn_metadata,
residual, residual,
) )
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
hidden_states, _ = self.norm(hidden_states, residual) hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states return hidden_states
class Gemma2ForCausalLM(nn.Module, SupportsLoRA): class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
packed_modules_mapping = { packed_modules_mapping = {
"qkv_proj": [ "qkv_proj": [
"q_proj", "q_proj",
...@@ -338,6 +355,8 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA): ...@@ -338,6 +355,8 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA):
self.logits_processor = LogitsProcessor( self.logits_processor = LogitsProcessor(
config.vocab_size, soft_cap=config.final_logit_softcapping) config.vocab_size, soft_cap=config.final_logit_softcapping)
self.sampler = Sampler() self.sampler = Sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def forward( def forward(
self, self,
...@@ -346,9 +365,9 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA): ...@@ -346,9 +365,9 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA):
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata) attn_metadata, intermediate_tensors)
return hidden_states return hidden_states
def compute_logits( def compute_logits(
...@@ -387,6 +406,8 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA): ...@@ -387,6 +406,8 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA):
# Skip loading extra bias for GPTQ models. # Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
continue continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name] param = params_dict[name]
weight_loader = param.weight_loader weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id) weight_loader(param, loaded_weight, shard_id)
...@@ -399,6 +420,8 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA): ...@@ -399,6 +420,8 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA):
# Skip loading extra bias for GPTQ models. # Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
continue continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name] param = params_dict[name]
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
......
...@@ -32,8 +32,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, ...@@ -32,8 +32,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization import QuantizationConfig
QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
...@@ -41,7 +40,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader ...@@ -41,7 +40,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .utils import is_pp_missing_parameter, make_layers from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
class GPT2Attention(nn.Module): class GPT2Attention(nn.Module):
...@@ -204,6 +205,9 @@ class GPT2Model(nn.Module): ...@@ -204,6 +205,9 @@ class GPT2Model(nn.Module):
config, cache_config, quant_config, prefix=prefix), config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.h") prefix=f"{prefix}.h")
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(["hidden_states"],
config.n_embd))
def forward( def forward(
self, self,
...@@ -234,7 +238,7 @@ class GPT2Model(nn.Module): ...@@ -234,7 +238,7 @@ class GPT2Model(nn.Module):
return hidden_states return hidden_states
class GPT2LMHeadModel(nn.Module): class GPT2LMHeadModel(nn.Module, SupportsPP):
def __init__( def __init__(
self, self,
...@@ -256,6 +260,8 @@ class GPT2LMHeadModel(nn.Module): ...@@ -256,6 +260,8 @@ class GPT2LMHeadModel(nn.Module):
self.config.hidden_size) self.config.hidden_size)
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
self.make_empty_intermediate_tensors = (
self.transformer.make_empty_intermediate_tensors)
def forward( def forward(
self, self,
...@@ -264,7 +270,7 @@ class GPT2LMHeadModel(nn.Module): ...@@ -264,7 +270,7 @@ class GPT2LMHeadModel(nn.Module):
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.transformer(input_ids, positions, kv_caches, hidden_states = self.transformer(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors) attn_metadata, intermediate_tensors)
return hidden_states return hidden_states
...@@ -286,16 +292,6 @@ class GPT2LMHeadModel(nn.Module): ...@@ -286,16 +292,6 @@ class GPT2LMHeadModel(nn.Module):
next_tokens = self.sampler(logits, sampling_metadata) next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens return next_tokens
def make_empty_intermediate_tensors(
self, batch_size: int, dtype: torch.dtype,
device: torch.device) -> IntermediateTensors:
return IntermediateTensors({
"hidden_states":
torch.zeros((batch_size, self.config.hidden_size),
dtype=dtype,
device=device),
})
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
params_dict = dict(self.named_parameters(remove_duplicate=False)) params_dict = dict(self.named_parameters(remove_duplicate=False))
for name, loaded_weight in weights: for name, loaded_weight in weights:
......
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only GPTBigCode model compatible with HuggingFace weights.""" """Inference-only GPTBigCode model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Tuple from typing import Iterable, List, Optional, Tuple, Union
import torch import torch
from torch import nn from torch import nn
...@@ -26,14 +26,13 @@ from transformers import GPTBigCodeConfig ...@@ -26,14 +26,13 @@ from transformers import GPTBigCodeConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization import QuantizationConfig
QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
...@@ -41,7 +40,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader ...@@ -41,7 +40,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA from .interfaces import SupportsLoRA, SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
class GPTBigCodeAttention(nn.Module): class GPTBigCodeAttention(nn.Module):
...@@ -194,6 +195,7 @@ class GPTBigCodeModel(nn.Module): ...@@ -194,6 +195,7 @@ class GPTBigCodeModel(nn.Module):
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None, lora_config: Optional[LoRAConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -207,11 +209,15 @@ class GPTBigCodeModel(nn.Module): ...@@ -207,11 +209,15 @@ class GPTBigCodeModel(nn.Module):
self.embed_dim, self.embed_dim,
org_num_embeddings=config.vocab_size) org_num_embeddings=config.vocab_size)
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
self.h = nn.ModuleList([ self.start_layer, self.end_layer, self.h = make_layers(
GPTBigCodeBlock(config, cache_config, quant_config) config.num_hidden_layers,
for _ in range(config.num_hidden_layers) lambda prefix: GPTBigCodeBlock(config, cache_config, quant_config),
]) prefix=f"{prefix}.h",
)
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(["hidden_states"],
config.n_embd))
def forward( def forward(
self, self,
...@@ -219,20 +225,28 @@ class GPTBigCodeModel(nn.Module): ...@@ -219,20 +225,28 @@ class GPTBigCodeModel(nn.Module):
position_ids: torch.Tensor, position_ids: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds = self.wte(input_ids) ) -> Union[torch.Tensor, IntermediateTensors]:
position_embeds = self.wpe(position_ids) if get_pp_group().is_first_rank:
hidden_states = inputs_embeds + position_embeds inputs_embeds = self.wte(input_ids)
position_embeds = self.wpe(position_ids)
hidden_states = inputs_embeds + position_embeds
else:
hidden_states = intermediate_tensors["hidden_states"]
for i in range(len(self.h)): for i in range(self.start_layer, self.end_layer):
layer = self.h[i] layer = self.h[i]
hidden_states = layer(hidden_states, kv_caches[i], attn_metadata) hidden_states = layer(hidden_states,
kv_caches[i - self.start_layer],
attn_metadata)
if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states})
hidden_states = self.ln_f(hidden_states) hidden_states = self.ln_f(hidden_states)
return hidden_states return hidden_states
class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA): class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
packed_modules_mapping = {"c_attn": ["c_attn"]} packed_modules_mapping = {"c_attn": ["c_attn"]}
supported_lora_modules = ["c_fc", "c_proj", "wte", "c_attn"] supported_lora_modules = ["c_fc", "c_proj", "wte", "c_attn"]
...@@ -272,6 +286,8 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA): ...@@ -272,6 +286,8 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA):
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size) config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
self.make_empty_intermediate_tensors = (
self.transformer.make_empty_intermediate_tensors)
def forward( def forward(
self, self,
...@@ -280,9 +296,9 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA): ...@@ -280,9 +296,9 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA):
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.transformer(input_ids, positions, kv_caches, hidden_states = self.transformer(input_ids, positions, kv_caches,
attn_metadata) attn_metadata, intermediate_tensors)
return hidden_states return hidden_states
def compute_logits( def compute_logits(
...@@ -311,6 +327,8 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA): ...@@ -311,6 +327,8 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA):
# Skip attention mask. # Skip attention mask.
# NOTE: "c_attn.bias" should not be skipped. # NOTE: "c_attn.bias" should not be skipped.
continue continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name] param = params_dict[name]
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only GPT-J model compatible with HuggingFace weights.""" """Inference-only GPT-J model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Tuple from typing import Iterable, List, Optional, Tuple, Union
import torch import torch
from torch import nn from torch import nn
...@@ -24,14 +24,13 @@ from transformers import GPTJConfig ...@@ -24,14 +24,13 @@ from transformers import GPTJConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig from vllm.config import CacheConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization import QuantizationConfig
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
...@@ -40,6 +39,10 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader ...@@ -40,6 +39,10 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
class GPTJAttention(nn.Module): class GPTJAttention(nn.Module):
...@@ -178,6 +181,7 @@ class GPTJModel(nn.Module): ...@@ -178,6 +181,7 @@ class GPTJModel(nn.Module):
config: GPTJConfig, config: GPTJConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -186,11 +190,15 @@ class GPTJModel(nn.Module): ...@@ -186,11 +190,15 @@ class GPTJModel(nn.Module):
config.vocab_size, config.vocab_size,
self.embed_dim, self.embed_dim,
) )
self.h = nn.ModuleList([ self.start_layer, self.end_layer, self.h = make_layers(
GPTJBlock(config, cache_config, quant_config) config.n_layer,
for _ in range(config.n_layer) lambda prefix: GPTJBlock(config, cache_config, quant_config),
]) prefix=f"{prefix}.h",
)
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(["hidden_states"],
config.n_embd))
def forward( def forward(
self, self,
...@@ -198,21 +206,27 @@ class GPTJModel(nn.Module): ...@@ -198,21 +206,27 @@ class GPTJModel(nn.Module):
position_ids: torch.Tensor, position_ids: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: intermediate_tensors: Optional[IntermediateTensors],
hidden_states = self.wte(input_ids) ) -> Union[torch.Tensor, IntermediateTensors]:
for i in range(len(self.h)): if get_pp_group().is_first_rank:
hidden_states = self.wte(input_ids)
else:
hidden_states = intermediate_tensors["hidden_states"]
for i in range(self.start_layer, self.end_layer):
layer = self.h[i] layer = self.h[i]
hidden_states = layer( hidden_states = layer(
position_ids, position_ids,
hidden_states, hidden_states,
kv_caches[i], kv_caches[i - self.start_layer],
attn_metadata, attn_metadata,
) )
if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states})
hidden_states = self.ln_f(hidden_states) hidden_states = self.ln_f(hidden_states)
return hidden_states return hidden_states
class GPTJForCausalLM(nn.Module): class GPTJForCausalLM(nn.Module, SupportsPP):
def __init__( def __init__(
self, self,
...@@ -233,6 +247,8 @@ class GPTJForCausalLM(nn.Module): ...@@ -233,6 +247,8 @@ class GPTJForCausalLM(nn.Module):
) )
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
self.make_empty_intermediate_tensors = (
self.transformer.make_empty_intermediate_tensors)
def forward( def forward(
self, self,
...@@ -241,9 +257,9 @@ class GPTJForCausalLM(nn.Module): ...@@ -241,9 +257,9 @@ class GPTJForCausalLM(nn.Module):
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.transformer(input_ids, positions, kv_caches, hidden_states = self.transformer(input_ids, positions, kv_caches,
attn_metadata) attn_metadata, intermediate_tensors)
return hidden_states return hidden_states
def compute_logits( def compute_logits(
...@@ -283,6 +299,8 @@ class GPTJForCausalLM(nn.Module): ...@@ -283,6 +299,8 @@ class GPTJForCausalLM(nn.Module):
# Skip loading extra bias for GPTQ models. # Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
continue continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name] param = params_dict[name]
weight_loader = param.weight_loader weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id) weight_loader(param, loaded_weight, shard_id)
...@@ -291,6 +309,8 @@ class GPTJForCausalLM(nn.Module): ...@@ -291,6 +309,8 @@ class GPTJForCausalLM(nn.Module):
# Skip loading extra bias for GPTQ models. # Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
continue continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name] param = params_dict[name]
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only GPT-NeoX model compatible with HuggingFace weights.""" """Inference-only GPT-NeoX model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Tuple from typing import Iterable, List, Optional, Tuple, Union
import torch import torch
from torch import nn from torch import nn
...@@ -24,14 +24,13 @@ from transformers import GPTNeoXConfig ...@@ -24,14 +24,13 @@ from transformers import GPTNeoXConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig from vllm.config import CacheConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization import QuantizationConfig
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
...@@ -40,6 +39,10 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader ...@@ -40,6 +39,10 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
class GPTNeoXAttention(nn.Module): class GPTNeoXAttention(nn.Module):
...@@ -191,6 +194,7 @@ class GPTNeoXModel(nn.Module): ...@@ -191,6 +194,7 @@ class GPTNeoXModel(nn.Module):
config: GPTNeoXConfig, config: GPTNeoXConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -199,12 +203,16 @@ class GPTNeoXModel(nn.Module): ...@@ -199,12 +203,16 @@ class GPTNeoXModel(nn.Module):
config.vocab_size, config.vocab_size,
config.hidden_size, config.hidden_size,
) )
self.layers = nn.ModuleList([ self.start_layer, self.end_layer, self.layers = make_layers(
GPTNeoXLayer(config, cache_config, quant_config) config.num_hidden_layers,
for _ in range(config.num_hidden_layers) lambda prefix: GPTNeoXLayer(config, cache_config, quant_config),
]) prefix=f"{prefix}.layers",
)
self.final_layer_norm = nn.LayerNorm(config.hidden_size, self.final_layer_norm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps) eps=config.layer_norm_eps)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(["hidden_states"],
config.hidden_size))
def forward( def forward(
self, self,
...@@ -212,21 +220,27 @@ class GPTNeoXModel(nn.Module): ...@@ -212,21 +220,27 @@ class GPTNeoXModel(nn.Module):
position_ids: torch.Tensor, position_ids: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: intermediate_tensors: Optional[IntermediateTensors],
hidden_states = self.embed_in(input_ids) ) -> Union[torch.Tensor, IntermediateTensors]:
for i in range(len(self.layers)): if get_pp_group().is_first_rank:
hidden_states = self.embed_in(input_ids)
else:
hidden_states = intermediate_tensors["hidden_states"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i] layer = self.layers[i]
hidden_states = layer( hidden_states = layer(
position_ids, position_ids,
hidden_states, hidden_states,
kv_caches[i], kv_caches[i - self.start_layer],
attn_metadata, attn_metadata,
) )
if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states})
hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.final_layer_norm(hidden_states)
return hidden_states return hidden_states
class GPTNeoXForCausalLM(nn.Module): class GPTNeoXForCausalLM(nn.Module, SupportsPP):
def __init__( def __init__(
self, self,
...@@ -247,6 +261,8 @@ class GPTNeoXForCausalLM(nn.Module): ...@@ -247,6 +261,8 @@ class GPTNeoXForCausalLM(nn.Module):
self.embed_out.weight = self.gpt_neox.embed_in.weight self.embed_out.weight = self.gpt_neox.embed_in.weight
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
self.make_empty_intermediate_tensors = (
self.gpt_neox.make_empty_intermediate_tensors)
def forward( def forward(
self, self,
...@@ -255,9 +271,9 @@ class GPTNeoXForCausalLM(nn.Module): ...@@ -255,9 +271,9 @@ class GPTNeoXForCausalLM(nn.Module):
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.gpt_neox(input_ids, positions, kv_caches, hidden_states = self.gpt_neox(input_ids, positions, kv_caches,
attn_metadata) attn_metadata, intermediate_tensors)
return hidden_states return hidden_states
def compute_logits( def compute_logits(
...@@ -288,6 +304,8 @@ class GPTNeoXForCausalLM(nn.Module): ...@@ -288,6 +304,8 @@ class GPTNeoXForCausalLM(nn.Module):
# Models trained using OpenRLHF may include # Models trained using OpenRLHF may include
# these tensors in the checkpoint. Skip them. # these tensors in the checkpoint. Skip them.
continue continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name] param = params_dict[name]
if "query_key_value" in name: if "query_key_value" in name:
......
...@@ -51,7 +51,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata ...@@ -51,7 +51,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import is_hip from vllm.utils import is_hip
from .interfaces import SupportsLoRA from .interfaces import SupportsLoRA, SupportsPP
from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers
...@@ -311,13 +311,13 @@ class GraniteModel(nn.Module): ...@@ -311,13 +311,13 @@ class GraniteModel(nn.Module):
else: else:
hidden_states = self.get_input_embeddings(input_ids) hidden_states = self.get_input_embeddings(input_ids)
residual = None residual = None
hidden_states *= self.config.embedding_multiplier
else: else:
assert intermediate_tensors is not None assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
hidden_states *= self.config.embedding_multiplier
for i in range(self.start_layer, self.end_layer): for i in range(self.start_layer, self.end_layer):
layer = self.layers[i] layer = self.layers[i]
hidden_states = layer( hidden_states = layer(
...@@ -337,7 +337,7 @@ class GraniteModel(nn.Module): ...@@ -337,7 +337,7 @@ class GraniteModel(nn.Module):
return hidden_states return hidden_states
class GraniteForCausalLM(nn.Module, SupportsLoRA): class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
packed_modules_mapping = { packed_modules_mapping = {
"qkv_proj": [ "qkv_proj": [
"q_proj", "q_proj",
......
...@@ -46,7 +46,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata ...@@ -46,7 +46,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from . import mixtral from . import mixtral
from .interfaces import SupportsLoRA from .interfaces import SupportsLoRA, SupportsPP
from .utils import make_layers from .utils import make_layers
...@@ -307,7 +307,7 @@ class GraniteMoeModel(nn.Module): ...@@ -307,7 +307,7 @@ class GraniteMoeModel(nn.Module):
return hidden_states return hidden_states
class GraniteMoeForCausalLM(nn.Module, SupportsLoRA): class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
fall_back_to_pt_during_load = False fall_back_to_pt_during_load = False
packed_modules_mapping = { packed_modules_mapping = {
......
from typing import (ClassVar, Dict, List, Literal, Optional, Protocol, Type, import inspect
Union, overload, runtime_checkable) from typing import (TYPE_CHECKING, ClassVar, Dict, List, Literal, Optional,
Protocol, Type, Union, overload, runtime_checkable)
import torch
from typing_extensions import TypeIs from typing_extensions import TypeIs
from vllm.config import LoRAConfig, MultiModalConfig, SchedulerConfig
from vllm.logger import init_logger from vllm.logger import init_logger
if TYPE_CHECKING:
from vllm.attention import AttentionMetadata
from vllm.config import LoRAConfig, MultiModalConfig, SchedulerConfig
from vllm.sequence import IntermediateTensors
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -22,7 +28,7 @@ class SupportsMultiModal(Protocol): ...@@ -22,7 +28,7 @@ class SupportsMultiModal(Protocol):
MRO of your model class. MRO of your model class.
""" """
def __init__(self, *, multimodal_config: MultiModalConfig) -> None: def __init__(self, *, multimodal_config: "MultiModalConfig") -> None:
... ...
...@@ -32,7 +38,7 @@ class SupportsMultiModal(Protocol): ...@@ -32,7 +38,7 @@ class SupportsMultiModal(Protocol):
class _SupportsMultiModalType(Protocol): class _SupportsMultiModalType(Protocol):
supports_multimodal: Literal[True] supports_multimodal: Literal[True]
def __call__(self, *, multimodal_config: MultiModalConfig) -> None: def __call__(self, *, multimodal_config: "MultiModalConfig") -> None:
... ...
...@@ -75,7 +81,7 @@ class SupportsLoRA(Protocol): ...@@ -75,7 +81,7 @@ class SupportsLoRA(Protocol):
embedding_padding_modules: ClassVar[List[str]] embedding_padding_modules: ClassVar[List[str]]
# lora_config is None when LoRA is not enabled # lora_config is None when LoRA is not enabled
def __init__(self, *, lora_config: Optional[LoRAConfig] = None) -> None: def __init__(self, *, lora_config: Optional["LoRAConfig"] = None) -> None:
... ...
...@@ -90,7 +96,7 @@ class _SupportsLoRAType(Protocol): ...@@ -90,7 +96,7 @@ class _SupportsLoRAType(Protocol):
embedding_modules: Dict[str, str] embedding_modules: Dict[str, str]
embedding_padding_modules: List[str] embedding_padding_modules: List[str]
def __call__(self, *, lora_config: Optional[LoRAConfig] = None) -> None: def __call__(self, *, lora_config: Optional["LoRAConfig"] = None) -> None:
... ...
...@@ -145,6 +151,132 @@ def _supports_lora( ...@@ -145,6 +151,132 @@ def _supports_lora(
return isinstance(model, SupportsLoRA) return isinstance(model, SupportsLoRA)
@runtime_checkable
class SupportsPP(Protocol):
"""The interface required for all models that support pipeline parallel."""
supports_pp: ClassVar[Literal[True]] = True
"""
A flag that indicates this model supports pipeline parallel.
Note:
There is no need to redefine this flag if this class is in the
MRO of your model class.
"""
def make_empty_intermediate_tensors(
self,
batch_size: int,
dtype: torch.dtype,
device: torch.device,
) -> "IntermediateTensors":
"""Called when PP rank > 0 for profiling purposes."""
...
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: "AttentionMetadata",
intermediate_tensors: Optional["IntermediateTensors"],
) -> Union[torch.Tensor, "IntermediateTensors"]:
"""
Accept :class:`IntermediateTensors` when PP rank > 0.
Return :class:`IntermediateTensors` only for the last PP rank.
"""
...
# We can't use runtime_checkable with ClassVar for issubclass checks
# so we need to treat the class as an instance and use isinstance instead
@runtime_checkable
class _SupportsPPType(Protocol):
supports_pp: Literal[True]
def make_empty_intermediate_tensors(
self,
batch_size: int,
dtype: torch.dtype,
device: torch.device,
) -> "IntermediateTensors":
...
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: "AttentionMetadata",
intermediate_tensors: Optional["IntermediateTensors"],
) -> Union[torch.Tensor, "IntermediateTensors"]:
...
@overload
def supports_pp(model: Type[object]) -> TypeIs[Type[SupportsPP]]:
...
@overload
def supports_pp(model: object) -> TypeIs[SupportsPP]:
...
def supports_pp(
model: Union[Type[object], object],
) -> Union[bool, TypeIs[Type[SupportsPP]], TypeIs[SupportsPP]]:
supports_attributes = _supports_pp_attributes(model)
supports_inspect = _supports_pp_inspect(model)
if supports_attributes and not supports_inspect:
logger.warning(
"The model (%s) sets `supports_pp=True`, but does not accept "
"`intermediate_tensors` in its `forward` method", model)
if not supports_attributes:
pp_attrs = ("make_empty_intermediate_tensors", )
missing_attrs = tuple(attr for attr in pp_attrs
if not hasattr(model, attr))
if getattr(model, "supports_pp", False):
if missing_attrs:
logger.warning(
"The model (%s) sets `supports_pp=True`, "
"but is missing PP-specific attributes: %s",
model,
missing_attrs,
)
else:
if not missing_attrs:
logger.warning(
"The model (%s) contains all PP-specific attributes, "
"but does not set `supports_pp=True`.", model)
return supports_attributes and supports_inspect
def _supports_pp_attributes(
model: Union[Type[object], object],
) -> Union[bool, TypeIs[Type[SupportsPP]], TypeIs[SupportsPP]]:
if isinstance(model, type):
return isinstance(model, _SupportsPPType)
return isinstance(model, SupportsPP)
def _supports_pp_inspect(
model: Union[Type[object], object],
) -> Union[bool, TypeIs[Type[SupportsPP]], TypeIs[SupportsPP]]:
model_forward = getattr(model, "forward", None)
if not callable(model_forward):
return False
forward_params = inspect.signature(model_forward).parameters
return "intermediate_tensors" in forward_params
@runtime_checkable @runtime_checkable
class HasInnerState(Protocol): class HasInnerState(Protocol):
"""The interface required for all models that has inner state.""" """The interface required for all models that has inner state."""
...@@ -158,7 +290,7 @@ class HasInnerState(Protocol): ...@@ -158,7 +290,7 @@ class HasInnerState(Protocol):
def __init__(self, def __init__(self,
*, *,
scheduler_config: Optional[SchedulerConfig] = None) -> None: scheduler_config: Optional["SchedulerConfig"] = None) -> None:
... ...
...@@ -168,7 +300,7 @@ class _HasInnerStateType(Protocol): ...@@ -168,7 +300,7 @@ class _HasInnerStateType(Protocol):
def __init__(self, def __init__(self,
*, *,
scheduler_config: Optional[SchedulerConfig] = None) -> None: scheduler_config: Optional["SchedulerConfig"] = None) -> None:
... ...
......
...@@ -18,8 +18,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, ...@@ -18,8 +18,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization import QuantizationConfig
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
...@@ -28,6 +27,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader ...@@ -28,6 +27,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter, from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers) make_empty_intermediate_tensors_factory, make_layers)
...@@ -266,7 +266,7 @@ class InternLM2Model(nn.Module): ...@@ -266,7 +266,7 @@ class InternLM2Model(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: IntermediateTensors = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank: if get_pp_group().is_first_rank:
...@@ -297,7 +297,7 @@ class InternLM2Model(nn.Module): ...@@ -297,7 +297,7 @@ class InternLM2Model(nn.Module):
return hidden_states return hidden_states
class InternLM2ForCausalLM(nn.Module): class InternLM2ForCausalLM(nn.Module, SupportsPP):
def __init__( def __init__(
self, self,
...@@ -325,7 +325,7 @@ class InternLM2ForCausalLM(nn.Module): ...@@ -325,7 +325,7 @@ class InternLM2ForCausalLM(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: IntermediateTensors, intermediate_tensors: Optional[IntermediateTensors],
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors) attn_metadata, intermediate_tensors)
......
...@@ -5,9 +5,9 @@ ...@@ -5,9 +5,9 @@
# Licensed under The MIT License [see LICENSE for details] # Licensed under The MIT License [see LICENSE for details]
# -------------------------------------------------------- # --------------------------------------------------------
import re import re
from functools import partial from functools import cached_property, partial
from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional, from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
Tuple, TypedDict, Union) TypedDict, Union)
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -17,7 +17,6 @@ from transformers import PretrainedConfig ...@@ -17,7 +17,6 @@ from transformers import PretrainedConfig
from vllm.attention import AttentionMetadata from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig from vllm.config import CacheConfig, MultiModalConfig
from vllm.distributed import get_pp_group
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
...@@ -32,7 +31,7 @@ from vllm.utils import is_list_of ...@@ -32,7 +31,7 @@ from vllm.utils import is_list_of
from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip, from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
get_clip_num_patches) get_clip_num_patches)
from .interfaces import SupportsMultiModal from .interfaces import SupportsMultiModal, SupportsPP
from .utils import (flatten_bn, group_weights_with_prefix, from .utils import (flatten_bn, group_weights_with_prefix,
init_vllm_registered_model, merge_multimodal_embeddings) init_vllm_registered_model, merge_multimodal_embeddings)
...@@ -123,7 +122,7 @@ def calculate_num_blocks(orig_width: int, orig_height: int, min_num: int, ...@@ -123,7 +122,7 @@ def calculate_num_blocks(orig_width: int, orig_height: int, min_num: int,
return blocks, target_width, target_height return blocks, target_width, target_height
def calculate_num_blocks_wrapper(hf_config: Dict[str, Any], def calculate_num_blocks_wrapper(hf_config: PretrainedConfig,
max_dynamic_patch: Optional[int] = None): max_dynamic_patch: Optional[int] = None):
if max_dynamic_patch is None: if max_dynamic_patch is None:
max_dynamic_patch = hf_config.max_dynamic_patch max_dynamic_patch = hf_config.max_dynamic_patch
...@@ -183,7 +182,7 @@ def image_to_pixel_values(image: Image.Image, input_size: int, min_num: int, ...@@ -183,7 +182,7 @@ def image_to_pixel_values(image: Image.Image, input_size: int, min_num: int,
return pixel_values return pixel_values
def image_to_pixel_values_wrapper(hf_config: Dict[str, Any], def image_to_pixel_values_wrapper(hf_config: PretrainedConfig,
max_dynamic_patch: Optional[int] = None): max_dynamic_patch: Optional[int] = None):
image_size = hf_config.vision_config.image_size image_size = hf_config.vision_config.image_size
min_num = hf_config.min_dynamic_patch min_num = hf_config.min_dynamic_patch
...@@ -197,7 +196,7 @@ def image_to_pixel_values_wrapper(hf_config: Dict[str, Any], ...@@ -197,7 +196,7 @@ def image_to_pixel_values_wrapper(hf_config: Dict[str, Any],
use_thumbnail=use_thumbnail) use_thumbnail=use_thumbnail)
def get_internvl_num_patches(hf_config: Dict[str, Any]): def get_internvl_num_patches(hf_config: PretrainedConfig):
vision_config = hf_config.vision_config vision_config = hf_config.vision_config
downsample_ratio = hf_config.downsample_ratio downsample_ratio = hf_config.downsample_ratio
image_size = vision_config.image_size image_size = vision_config.image_size
...@@ -362,7 +361,7 @@ def dummy_data_for_internvl(ctx: InputContext, ...@@ -362,7 +361,7 @@ def dummy_data_for_internvl(ctx: InputContext,
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_internvl_image_tokens) @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_internvl_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_internvl) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_internvl)
@INPUT_REGISTRY.register_input_processor(input_processor_for_internvl) @INPUT_REGISTRY.register_input_processor(input_processor_for_internvl)
class InternVLChatModel(nn.Module, SupportsMultiModal): class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
def __init__(self, def __init__(self,
config: PretrainedConfig, config: PretrainedConfig,
...@@ -408,10 +407,12 @@ class InternVLChatModel(nn.Module, SupportsMultiModal): ...@@ -408,10 +407,12 @@ class InternVLChatModel(nn.Module, SupportsMultiModal):
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors) self.language_model.make_empty_intermediate_tensors)
@cached_property
def sampler(self):
if hasattr(self.language_model, "sampler"): if hasattr(self.language_model, "sampler"):
self.sampler = self.language_model.sampler return self.language_model.sampler
else:
self.sampler = Sampler() return Sampler()
def pixel_shuffle(self, x, scale_factor=0.5): def pixel_shuffle(self, x, scale_factor=0.5):
n, w, h, c = x.size() n, w, h, c = x.size()
...@@ -515,18 +516,22 @@ class InternVLChatModel(nn.Module, SupportsMultiModal): ...@@ -515,18 +516,22 @@ class InternVLChatModel(nn.Module, SupportsMultiModal):
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
**kwargs: object, **kwargs: object,
) -> SamplerOutput: ) -> Union[SamplerOutput, IntermediateTensors]:
image_input = self._parse_and_validate_image_input(**kwargs) if intermediate_tensors is not None:
if image_input is not None and get_pp_group().is_first_rank:
inputs_embeds = self.language_model.model.get_input_embeddings(
input_ids)
vision_embeddings = self._process_image_input(image_input)
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, vision_embeddings,
self.img_context_token_id)
input_ids = None input_ids = None
else:
inputs_embeds = None inputs_embeds = None
else:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is not None:
inputs_embeds = self.language_model.model.get_input_embeddings(
input_ids)
vision_embeddings = self._process_image_input(image_input)
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, vision_embeddings,
self.img_context_token_id)
input_ids = None
else:
inputs_embeds = None
hidden_states = self.language_model.model(input_ids, hidden_states = self.language_model.model(input_ids,
positions, positions,
......
...@@ -33,8 +33,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, ...@@ -33,8 +33,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization import QuantizationConfig
QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
...@@ -43,7 +42,9 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata ...@@ -43,7 +42,9 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs import JAISConfig from vllm.transformers_utils.configs import JAISConfig
from .utils import is_pp_missing_parameter, make_layers from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
class SwiGLUActivation(nn.Module): class SwiGLUActivation(nn.Module):
...@@ -244,6 +245,9 @@ class JAISModel(nn.Module): ...@@ -244,6 +245,9 @@ class JAISModel(nn.Module):
) )
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(["hidden_states"],
config.n_embd))
def forward( def forward(
self, self,
...@@ -279,7 +283,7 @@ class JAISModel(nn.Module): ...@@ -279,7 +283,7 @@ class JAISModel(nn.Module):
return hidden_states return hidden_states
class JAISLMHeadModel(nn.Module): class JAISLMHeadModel(nn.Module, SupportsPP):
def __init__( def __init__(
self, self,
...@@ -304,6 +308,8 @@ class JAISLMHeadModel(nn.Module): ...@@ -304,6 +308,8 @@ class JAISLMHeadModel(nn.Module):
self.logits_processor = LogitsProcessor(vocab_size=config.vocab_size, self.logits_processor = LogitsProcessor(vocab_size=config.vocab_size,
scale=self.output_logits_scale) scale=self.output_logits_scale)
self.sampler = Sampler() self.sampler = Sampler()
self.make_empty_intermediate_tensors = (
self.transformer.make_empty_intermediate_tensors)
def forward( def forward(
self, self,
...@@ -326,16 +332,6 @@ class JAISLMHeadModel(nn.Module): ...@@ -326,16 +332,6 @@ class JAISLMHeadModel(nn.Module):
sampling_metadata) sampling_metadata)
return logits return logits
def make_empty_intermediate_tensors(
self, batch_size: int, dtype: torch.dtype,
device: torch.device) -> IntermediateTensors:
return IntermediateTensors({
"hidden_states":
torch.zeros((batch_size, self.config.hidden_size),
dtype=dtype,
device=device),
})
def sample( def sample(
self, self,
logits: torch.Tensor, logits: torch.Tensor,
......
...@@ -37,8 +37,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, ...@@ -37,8 +37,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization import QuantizationConfig
QuantizationConfig)
from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
get_compressed_tensors_cache_scale) get_compressed_tensors_cache_scale)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
...@@ -51,8 +50,9 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata ...@@ -51,8 +50,9 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import is_hip from vllm.utils import is_hip
from .interfaces import SupportsLoRA from .interfaces import SupportsLoRA, SupportsPP
from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers from .utils import (PPMissingLayer, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
class LlamaMLP(nn.Module): class LlamaMLP(nn.Module):
...@@ -72,12 +72,15 @@ class LlamaMLP(nn.Module): ...@@ -72,12 +72,15 @@ class LlamaMLP(nn.Module):
output_sizes=[intermediate_size] * 2, output_sizes=[intermediate_size] * 2,
bias=bias, bias=bias,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj") prefix=f"{prefix}.gate_up_proj",
self.down_proj = RowParallelLinear(input_size=intermediate_size, )
output_size=hidden_size, self.down_proj = RowParallelLinear(
bias=bias, input_size=intermediate_size,
quant_config=quant_config, output_size=hidden_size,
prefix=f"{prefix}.down_proj") bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.down_proj",
)
if hidden_act != "silu": if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. " raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.") "Only silu is supported for now.")
...@@ -161,12 +164,14 @@ class LlamaAttention(nn.Module): ...@@ -161,12 +164,14 @@ class LlamaAttention(nn.Module):
rope_scaling=rope_scaling, rope_scaling=rope_scaling,
is_neox_style=is_neox_style, is_neox_style=is_neox_style,
) )
self.attn = Attention(self.num_heads, self.attn = Attention(
self.head_dim, self.num_heads,
self.scaling, self.head_dim,
num_kv_heads=self.num_kv_heads, self.scaling,
cache_config=cache_config, num_kv_heads=self.num_kv_heads,
quant_config=quant_config) cache_config=cache_config,
quant_config=quant_config,
)
def forward( def forward(
self, self,
...@@ -248,12 +253,10 @@ class LlamaDecoderLayer(nn.Module): ...@@ -248,12 +253,10 @@ class LlamaDecoderLayer(nn.Module):
else: else:
hidden_states, residual = self.input_layernorm( hidden_states, residual = self.input_layernorm(
hidden_states, residual) hidden_states, residual)
hidden_states = self.self_attn( hidden_states = self.self_attn(positions=positions,
positions=positions, hidden_states=hidden_states,
hidden_states=hidden_states, kv_cache=kv_cache,
kv_cache=kv_cache, attn_metadata=attn_metadata)
attn_metadata=attn_metadata,
)
# Fully Connected # Fully Connected
hidden_states, residual = self.post_attention_layernorm( hidden_states, residual = self.post_attention_layernorm(
...@@ -295,12 +298,17 @@ class LlamaModel(nn.Module): ...@@ -295,12 +298,17 @@ class LlamaModel(nn.Module):
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config, quant_config=quant_config,
prefix=prefix), prefix=prefix),
prefix=f"{prefix}.layers") prefix=f"{prefix}.layers",
)
if get_pp_group().is_last_rank: if get_pp_group().is_last_rank:
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
else: else:
self.norm = PPMissingLayer() self.norm = PPMissingLayer()
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids) return self.embed_tokens(input_ids)
...@@ -326,13 +334,9 @@ class LlamaModel(nn.Module): ...@@ -326,13 +334,9 @@ class LlamaModel(nn.Module):
for i in range(self.start_layer, self.end_layer): for i in range(self.start_layer, self.end_layer):
layer = self.layers[i] layer = self.layers[i]
hidden_states, residual = layer( hidden_states, residual = layer(positions, hidden_states,
positions, kv_caches[i - self.start_layer],
hidden_states, attn_metadata, residual)
kv_caches[i - self.start_layer],
attn_metadata,
residual,
)
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
return IntermediateTensors({ return IntermediateTensors({
...@@ -344,17 +348,10 @@ class LlamaModel(nn.Module): ...@@ -344,17 +348,10 @@ class LlamaModel(nn.Module):
return hidden_states return hidden_states
class LlamaForCausalLM(nn.Module, SupportsLoRA): class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
packed_modules_mapping = { packed_modules_mapping = {
"qkv_proj": [ "qkv_proj": ["q_proj", "k_proj", "v_proj"],
"q_proj", "gate_up_proj": ["gate_proj", "up_proj"]
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
} }
# LoRA specific attributes # LoRA specific attributes
...@@ -364,7 +361,7 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA): ...@@ -364,7 +361,7 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
] ]
embedding_modules = { embedding_modules = {
"embed_tokens": "input_embeddings", "embed_tokens": "input_embeddings",
"lm_head": "output_embeddings", "lm_head": "output_embeddings"
} }
embedding_padding_modules = ["lm_head"] embedding_padding_modules = ["lm_head"]
bitsandbytes_stacked_params_mapping = { bitsandbytes_stacked_params_mapping = {
...@@ -420,10 +417,12 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA): ...@@ -420,10 +417,12 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
self.unpadded_vocab_size, self.unpadded_vocab_size,
config.hidden_size, config.hidden_size,
org_num_embeddings=config.vocab_size, org_num_embeddings=config.vocab_size,
padding_size=DEFAULT_VOCAB_PADDING_SIZE padding_size=(
# We need bigger padding if using lora for kernel DEFAULT_VOCAB_PADDING_SIZE
# compatibility # We need bigger padding if using lora for kernel
if not lora_config else lora_config.lora_vocab_padding_size, # compatibility
if not lora_config else
lora_config.lora_vocab_padding_size),
quant_config=quant_config, quant_config=quant_config,
) )
if config.tie_word_embeddings: if config.tie_word_embeddings:
...@@ -436,6 +435,8 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA): ...@@ -436,6 +435,8 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
self.sampler = Sampler() self.sampler = Sampler()
else: else:
self.lm_head = PPMissingLayer() self.lm_head = PPMissingLayer()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def forward( def forward(
self, self,
...@@ -458,28 +459,11 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA): ...@@ -458,28 +459,11 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
sampling_metadata) sampling_metadata)
return logits return logits
def sample( def sample(self, logits: torch.Tensor,
self, sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]:
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata) next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens return next_tokens
def make_empty_intermediate_tensors(
self, batch_size: int, dtype: torch.dtype,
device: torch.device) -> IntermediateTensors:
return IntermediateTensors({
"hidden_states":
torch.zeros((batch_size, self.config.hidden_size),
dtype=dtype,
device=device),
"residual":
torch.zeros((batch_size, self.config.hidden_size),
dtype=dtype,
device=device),
})
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
...@@ -513,7 +497,7 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA): ...@@ -513,7 +497,7 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
loaded_weight = loaded_weight[0] loaded_weight = loaded_weight[0]
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
continue continue
for (param_name, weight_name, shard_id) in stacked_params_mapping: for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name: if weight_name not in name:
continue continue
name = name.replace(weight_name, param_name) name = name.replace(weight_name, param_name)
......
from typing import Iterable, List, Optional, Tuple from typing import Iterable, List, Optional, Tuple, Union
import torch import torch
from torch import nn from torch import nn
...@@ -8,10 +8,13 @@ from vllm.model_executor.layers.pooler import Pooler, PoolingType ...@@ -8,10 +8,13 @@ from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.llama import LlamaModel from vllm.model_executor.models.llama import LlamaModel
from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import PoolerOutput from vllm.sequence import IntermediateTensors, PoolerOutput
from .interfaces import SupportsPP
from .utils import is_pp_missing_parameter
class LlamaEmbeddingModel(nn.Module):
class LlamaEmbeddingModel(nn.Module, SupportsPP):
"""A model that uses Llama with additional embedding functionalities. """A model that uses Llama with additional embedding functionalities.
This class encapsulates the LlamaModel and provides an interface for This class encapsulates the LlamaModel and provides an interface for
...@@ -29,6 +32,8 @@ class LlamaEmbeddingModel(nn.Module): ...@@ -29,6 +32,8 @@ class LlamaEmbeddingModel(nn.Module):
super().__init__() super().__init__()
self.model = LlamaModel(**kwargs) self.model = LlamaModel(**kwargs)
self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def forward( def forward(
self, self,
...@@ -36,10 +41,12 @@ class LlamaEmbeddingModel(nn.Module): ...@@ -36,10 +41,12 @@ class LlamaEmbeddingModel(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> Union[torch.Tensor, IntermediateTensors]:
return self.model.forward(input_ids, positions, kv_caches, return self.model.forward(input_ids, positions, kv_caches,
attn_metadata, inputs_embeds) attn_metadata, intermediate_tensors,
inputs_embeds)
def pooler( def pooler(
self, self,
...@@ -73,6 +80,8 @@ class LlamaEmbeddingModel(nn.Module): ...@@ -73,6 +80,8 @@ class LlamaEmbeddingModel(nn.Module):
# Skip loading extra bias for GPTQ models. # Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
continue continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name] param = params_dict[name]
weight_loader = param.weight_loader weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id) weight_loader(param, loaded_weight, shard_id)
...@@ -81,6 +90,8 @@ class LlamaEmbeddingModel(nn.Module): ...@@ -81,6 +90,8 @@ class LlamaEmbeddingModel(nn.Module):
# Skip loading extra bias for GPTQ models. # Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
continue continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name] param = params_dict[name]
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
......
from functools import cached_property
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
TypedDict, Union) TypedDict, Union)
...@@ -11,7 +12,7 @@ from vllm.config import CacheConfig, MultiModalConfig ...@@ -11,7 +12,7 @@ from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
...@@ -21,7 +22,7 @@ from vllm.utils import is_list_of ...@@ -21,7 +22,7 @@ from vllm.utils import is_list_of
from .clip import (CLIPVisionModel, dummy_image_for_clip, from .clip import (CLIPVisionModel, dummy_image_for_clip,
dummy_seq_data_for_clip, get_max_clip_image_tokens, dummy_seq_data_for_clip, get_max_clip_image_tokens,
input_processor_for_clip) input_processor_for_clip)
from .interfaces import SupportsMultiModal from .interfaces import SupportsMultiModal, SupportsPP
from .siglip import (SiglipVisionModel, dummy_image_for_siglip, from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
dummy_seq_data_for_siglip, get_max_siglip_image_tokens, dummy_seq_data_for_siglip, get_max_siglip_image_tokens,
input_processor_for_siglip) input_processor_for_siglip)
...@@ -198,7 +199,7 @@ def _init_vision_tower(hf_config: LlavaConfig): ...@@ -198,7 +199,7 @@ def _init_vision_tower(hf_config: LlavaConfig):
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens) @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava)
@INPUT_REGISTRY.register_input_processor(input_processor_for_llava) @INPUT_REGISTRY.register_input_processor(input_processor_for_llava)
class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal): class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
def __init__(self, def __init__(self,
config: LlavaConfig, config: LlavaConfig,
...@@ -220,6 +221,16 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -220,6 +221,16 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal):
self.language_model = init_vllm_registered_model( self.language_model = init_vllm_registered_model(
config.text_config, cache_config, quant_config) config.text_config, cache_config, quant_config)
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors)
@cached_property
def sampler(self):
if hasattr(self.language_model, "sampler"):
return self.language_model.sampler
return Sampler()
def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
h = w = self.config.vision_config.image_size h = w = self.config.vision_config.image_size
expected_dims = (3, h, w) expected_dims = (3, h, w)
...@@ -315,7 +326,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -315,7 +326,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal):
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
**kwargs: object, **kwargs: object,
) -> SamplerOutput: ) -> Union[torch.Tensor, IntermediateTensors]:
"""Run forward pass for LLaVA-1.5. """Run forward pass for LLaVA-1.5.
One key thing to understand is the `input_ids` already accounts for the One key thing to understand is the `input_ids` already accounts for the
...@@ -351,26 +362,30 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -351,26 +362,30 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal):
See also: See also:
:class:`LlavaImageInputs` :class:`LlavaImageInputs`
""" """
image_input = self._parse_and_validate_image_input(**kwargs) if intermediate_tensors is not None:
input_ids = None
inputs_embeds = None
else:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is not None: if image_input is not None:
vision_embeddings = self._process_image_input(image_input) vision_embeddings = self._process_image_input(image_input)
inputs_embeds = self.language_model.model.get_input_embeddings( inputs_embeds = self.language_model.model.get_input_embeddings(
input_ids) input_ids)
inputs_embeds = merge_multimodal_embeddings( inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, vision_embeddings, input_ids, inputs_embeds, vision_embeddings,
self.config.image_token_index) self.config.image_token_index)
input_ids = None input_ids = None
else: else:
inputs_embeds = None inputs_embeds = None
hidden_states = self.language_model.model(input_ids, hidden_states = self.language_model.model(input_ids,
positions, positions,
kv_caches, kv_caches,
attn_metadata, attn_metadata,
None, intermediate_tensors,
inputs_embeds=inputs_embeds) inputs_embeds=inputs_embeds)
return hidden_states return hidden_states
......
from functools import cached_property
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
TypedDict, Union) TypedDict, Union)
...@@ -13,7 +14,7 @@ from vllm.attention import AttentionMetadata ...@@ -13,7 +14,7 @@ from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
...@@ -23,7 +24,7 @@ from vllm.utils import is_list_of ...@@ -23,7 +24,7 @@ from vllm.utils import is_list_of
from .clip import (CLIPVisionModel, dummy_image_for_clip, from .clip import (CLIPVisionModel, dummy_image_for_clip,
dummy_seq_data_for_clip, get_clip_image_feature_size, dummy_seq_data_for_clip, get_clip_image_feature_size,
get_clip_patch_grid_length, input_processor_for_clip) get_clip_patch_grid_length, input_processor_for_clip)
from .interfaces import SupportsMultiModal from .interfaces import SupportsMultiModal, SupportsPP
from .llava import LlavaMultiModalProjector from .llava import LlavaMultiModalProjector
from .siglip import (SiglipVisionModel, dummy_image_for_siglip, from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
dummy_seq_data_for_siglip, get_siglip_image_feature_size, dummy_seq_data_for_siglip, get_siglip_image_feature_size,
...@@ -286,7 +287,8 @@ def _init_vision_tower(hf_config: LlavaNextConfig): ...@@ -286,7 +287,8 @@ def _init_vision_tower(hf_config: LlavaNextConfig):
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_next_image_tokens) @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_next_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_next) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_next)
@INPUT_REGISTRY.register_input_processor(input_processor_for_llava_next) @INPUT_REGISTRY.register_input_processor(input_processor_for_llava_next)
class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal): class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP):
def __init__(self, def __init__(self,
config: LlavaNextConfig, config: LlavaNextConfig,
...@@ -300,6 +302,8 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -300,6 +302,8 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal):
# TODO: Optionally initializes this for supporting embeddings. # TODO: Optionally initializes this for supporting embeddings.
self.vision_tower = _init_vision_tower(config) self.vision_tower = _init_vision_tower(config)
self.image_newline = nn.Parameter(
torch.empty(config.text_config.hidden_size))
self.multi_modal_projector = LlavaMultiModalProjector( self.multi_modal_projector = LlavaMultiModalProjector(
vision_hidden_size=config.vision_config.hidden_size, vision_hidden_size=config.vision_config.hidden_size,
text_hidden_size=config.text_config.hidden_size, text_hidden_size=config.text_config.hidden_size,
...@@ -308,8 +312,15 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -308,8 +312,15 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal):
self.language_model = init_vllm_registered_model( self.language_model = init_vllm_registered_model(
config.text_config, cache_config, quant_config) config.text_config, cache_config, quant_config)
self.image_newline = nn.Parameter( self.make_empty_intermediate_tensors = (
torch.empty(config.text_config.hidden_size)) self.language_model.make_empty_intermediate_tensors)
@cached_property
def sampler(self):
if hasattr(self.language_model, "sampler"):
return self.language_model.sampler
return Sampler()
def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor: def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor:
expected_dims = (2, ) expected_dims = (2, )
...@@ -542,7 +553,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -542,7 +553,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal):
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
**kwargs: object, **kwargs: object,
) -> SamplerOutput: ) -> Union[torch.Tensor, IntermediateTensors]:
"""Run forward pass for LlaVA-NeXT. """Run forward pass for LlaVA-NeXT.
One key thing to understand is the `input_ids` already accounts for the One key thing to understand is the `input_ids` already accounts for the
...@@ -587,26 +598,30 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -587,26 +598,30 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal):
See also: See also:
:class:`LlavaNextImageInputs` :class:`LlavaNextImageInputs`
""" """
image_input = self._parse_and_validate_image_input(**kwargs) if intermediate_tensors is not None:
input_ids = None
inputs_embeds = None
else:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is not None: if image_input is not None:
vision_embeddings = self._process_image_input(image_input) vision_embeddings = self._process_image_input(image_input)
inputs_embeds = self.language_model.model.get_input_embeddings( inputs_embeds = self.language_model.model.get_input_embeddings(
input_ids) input_ids)
inputs_embeds = merge_multimodal_embeddings( inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, vision_embeddings, input_ids, inputs_embeds, vision_embeddings,
self.config.image_token_index) self.config.image_token_index)
input_ids = None input_ids = None
else: else:
inputs_embeds = None inputs_embeds = None
hidden_states = self.language_model.model(input_ids, hidden_states = self.language_model.model(input_ids,
positions, positions,
kv_caches, kv_caches,
attn_metadata, attn_metadata,
None, intermediate_tensors,
inputs_embeds=inputs_embeds) inputs_embeds=inputs_embeds)
return hidden_states return hidden_states
......
import math import math
from functools import cached_property
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
TypedDict, Union) TypedDict, Union)
...@@ -12,9 +13,8 @@ from vllm.attention import AttentionMetadata ...@@ -12,9 +13,8 @@ from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization import QuantizationConfig
QuantizationConfig) from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.clip import CLIPVisionModel from vllm.model_executor.models.clip import CLIPVisionModel
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
...@@ -25,7 +25,7 @@ from vllm.sequence import IntermediateTensors ...@@ -25,7 +25,7 @@ from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of from vllm.utils import is_list_of
from .clip import dummy_image_for_clip, dummy_seq_data_for_clip from .clip import dummy_image_for_clip, dummy_seq_data_for_clip
from .interfaces import SupportsMultiModal from .interfaces import SupportsMultiModal, SupportsPP
from .siglip import (SiglipVisionModel, dummy_image_for_siglip, from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
dummy_seq_data_for_siglip) dummy_seq_data_for_siglip)
from .utils import (group_weights_with_prefix, init_vllm_registered_model, from .utils import (group_weights_with_prefix, init_vllm_registered_model,
...@@ -267,7 +267,8 @@ class LlavaNextMultiModalProjector(nn.Module): ...@@ -267,7 +267,8 @@ class LlavaNextMultiModalProjector(nn.Module):
"video", get_max_llava_next_video_tokens) "video", get_max_llava_next_video_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_next_video) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_next_video)
@INPUT_REGISTRY.register_input_processor(input_processor_for_llava_next_video) @INPUT_REGISTRY.register_input_processor(input_processor_for_llava_next_video)
class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal): class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP):
def __init__(self, def __init__(self,
config: LlavaNextVideoConfig, config: LlavaNextVideoConfig,
...@@ -281,13 +282,23 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -281,13 +282,23 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal):
# Initialize the vision tower only up to the required feature layer # Initialize the vision tower only up to the required feature layer
self.vision_tower = _init_vision_tower(config) self.vision_tower = _init_vision_tower(config)
self.vision_resampler = LlavaNextVideoPooler(config)
self.multi_modal_projector = LlavaNextMultiModalProjector( self.multi_modal_projector = LlavaNextMultiModalProjector(
vision_hidden_size=config.vision_config.hidden_size, vision_hidden_size=config.vision_config.hidden_size,
text_hidden_size=config.text_config.hidden_size, text_hidden_size=config.text_config.hidden_size,
projector_hidden_act=config.projector_hidden_act) projector_hidden_act=config.projector_hidden_act)
self.language_model = init_vllm_registered_model( self.language_model = init_vllm_registered_model(
config.text_config, cache_config, quant_config) config.text_config, cache_config, quant_config)
self.vision_resampler = LlavaNextVideoPooler(config)
self.make_empty_intermediate_tensors = (
self.language_model.model.make_empty_intermediate_tensors)
@cached_property
def sampler(self):
if hasattr(self.language_model, "sampler"):
return self.language_model.sampler
return Sampler()
def _validate_video_pixel_values( def _validate_video_pixel_values(
self, data: Union[torch.Tensor, List[torch.Tensor]] self, data: Union[torch.Tensor, List[torch.Tensor]]
...@@ -397,34 +408,36 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -397,34 +408,36 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal):
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
**kwargs: object, **kwargs: object,
) -> SamplerOutput: ) -> Union[torch.Tensor, IntermediateTensors]:
"""Run forward pass for LlaVA-NeXT-Video. """Run forward pass for LlaVA-NeXT-Video.
Args: Args:
input_ids: Flattened (concatenated) input_ids corresponding to a input_ids: Flattened (concatenated) input_ids corresponding to a
batch. batch.
pixel_values_videos: Pixels in each frames for each input videos. pixel_values_videos: Pixels in each frames for each input videos.
""" """
video_input = self._parse_and_validate_video_input(**kwargs) if intermediate_tensors is not None:
# merge video embeddings into input embeddings
if video_input is not None:
video_embeddings = self._process_video_pixels(video_input)
inputs_embeds = self.language_model \
.model.get_input_embeddings(input_ids)
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, video_embeddings,
self.config.video_token_index)
input_ids = None input_ids = None
else:
inputs_embeds = None inputs_embeds = None
else:
video_input = self._parse_and_validate_video_input(**kwargs)
if video_input is not None:
video_embeddings = self._process_video_pixels(video_input)
inputs_embeds = self.language_model \
.model.get_input_embeddings(input_ids)
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, video_embeddings,
self.config.video_token_index)
input_ids = None
else:
inputs_embeds = None
hidden_states = self.language_model.model(input_ids, hidden_states = self.language_model.model(input_ids,
positions, positions,
kv_caches, kv_caches,
attn_metadata, attn_metadata,
None, intermediate_tensors,
inputs_embeds=inputs_embeds) inputs_embeds=inputs_embeds)
return hidden_states return hidden_states
......
import math import math
from functools import cached_property
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
TypedDict, Union) TypedDict, Union)
...@@ -17,9 +18,8 @@ from vllm.config import CacheConfig, MultiModalConfig ...@@ -17,9 +18,8 @@ from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization import QuantizationConfig
QuantizationConfig) from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
...@@ -31,7 +31,7 @@ from vllm.utils import is_list_of ...@@ -31,7 +31,7 @@ from vllm.utils import is_list_of
from .clip import (CLIPVisionModel, dummy_seq_data_for_clip, from .clip import (CLIPVisionModel, dummy_seq_data_for_clip,
dummy_video_for_clip, get_clip_image_feature_size, dummy_video_for_clip, get_clip_image_feature_size,
get_clip_patch_grid_length, input_processor_for_clip) get_clip_patch_grid_length, input_processor_for_clip)
from .interfaces import SupportsMultiModal from .interfaces import SupportsMultiModal, SupportsPP
from .siglip import (SiglipVisionModel, dummy_seq_data_for_siglip, from .siglip import (SiglipVisionModel, dummy_seq_data_for_siglip,
dummy_video_for_siglip, get_siglip_image_feature_size, dummy_video_for_siglip, get_siglip_image_feature_size,
get_siglip_patch_grid_length, input_processor_for_siglip) get_siglip_patch_grid_length, input_processor_for_siglip)
...@@ -414,7 +414,8 @@ class LlavaOnevisionMultiModalProjector(nn.Module): ...@@ -414,7 +414,8 @@ class LlavaOnevisionMultiModalProjector(nn.Module):
"video", get_max_llava_onevision_video_tokens) "video", get_max_llava_onevision_video_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_onevision) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_onevision)
@INPUT_REGISTRY.register_input_processor(input_processor_for_llava_onevision) @INPUT_REGISTRY.register_input_processor(input_processor_for_llava_onevision)
class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal): class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP):
def __init__(self, def __init__(self,
config: LlavaOnevisionConfig, config: LlavaOnevisionConfig,
...@@ -434,6 +435,16 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -434,6 +435,16 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal):
self.image_newline = nn.Parameter( self.image_newline = nn.Parameter(
torch.empty(config.text_config.hidden_size)) torch.empty(config.text_config.hidden_size))
self.make_empty_intermediate_tensors = (
self.language_model.model.make_empty_intermediate_tensors)
@cached_property
def sampler(self):
if hasattr(self.language_model, "sampler"):
return self.language_model.sampler
return Sampler()
def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor: def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor:
expected_dims = (2, ) expected_dims = (2, )
...@@ -805,39 +816,42 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -805,39 +816,42 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal):
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
**kwargs: object, **kwargs: object,
) -> SamplerOutput: ) -> Union[torch.Tensor, IntermediateTensors]:
"""Run forward pass for LlaVA-Onevision. """Run forward pass for LlaVA-Onevision.
Args: Args:
input_ids: Flattened (concatenated) input_ids corresponding to a input_ids: Flattened (concatenated) input_ids corresponding to a
batch. batch.
pixel_values_videos: Pixels in each frames for each input videos. pixel_values_videos: Pixels in each frames for each input videos.
""" """
modalities = self._parse_and_validate_multimodal_inputs(**kwargs) if intermediate_tensors is not None:
# merge video embeddings into input embeddings
if modalities:
inputs_embeds = self.language_model.model.get_input_embeddings(
input_ids)
if "images" in modalities:
image_input = modalities["images"]
vision_embeddings = self._process_image_input(image_input)
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, vision_embeddings,
self.config.image_token_index)
if "videos" in modalities:
video_input = modalities["videos"]
video_embeddings = self._process_video_pixels(video_input)
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, video_embeddings,
self.config.video_token_index)
input_ids = None input_ids = None
else:
inputs_embeds = None inputs_embeds = None
else:
modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
if modalities:
inputs_embeds = self.language_model.model.get_input_embeddings(
input_ids)
if "images" in modalities:
image_input = modalities["images"]
vision_embeddings = self._process_image_input(image_input)
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, vision_embeddings,
self.config.image_token_index)
if "videos" in modalities:
video_input = modalities["videos"]
video_embeddings = self._process_video_pixels(video_input)
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, video_embeddings,
self.config.video_token_index)
input_ids = None
else:
inputs_embeds = None
hidden_states = self.language_model.model(input_ids, hidden_states = self.language_model.model(input_ids,
positions, positions,
kv_caches, kv_caches,
attn_metadata, attn_metadata,
None, intermediate_tensors,
inputs_embeds=inputs_embeds) inputs_embeds=inputs_embeds)
return hidden_states return hidden_states
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment