Unverified Commit 0f6d7a9a authored by Murali Andoorveedu's avatar Murali Andoorveedu Committed by GitHub
Browse files

[Models] Add remaining model PP support (#7168)


Signed-off-by: default avatarMuralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>
Signed-off-by: default avatarMurali Andoorveedu <muralidhar.andoorveedu@centml.ai>
Co-authored-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent 303d4479
......@@ -28,7 +28,7 @@ from transformers import FalconConfig as HF_FalconConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
from vllm.distributed import (get_tensor_model_parallel_rank,
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.activation import get_act_fn
......@@ -36,8 +36,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
......@@ -47,6 +46,10 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs import RWConfig
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
FalconConfig = Union[HF_FalconConfig, RWConfig]
......@@ -333,6 +336,7 @@ class FalconModel(nn.Module):
config: FalconConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.config = config
......@@ -347,35 +351,45 @@ class FalconModel(nn.Module):
)
# Transformer blocks
self.h = nn.ModuleList([
FalconDecoderLayer(config, cache_config, quant_config)
for _ in range(config.num_hidden_layers)
])
self.start_layer, self.end_layer, self.h = make_layers(
config.num_hidden_layers,
lambda prefix: FalconDecoderLayer(config, cache_config,
quant_config),
prefix=f"{prefix}.h")
# Final Layer Norm
self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(["hidden_states"],
config.hidden_size))
def forward(
self,
input_ids: torch.LongTensor,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
hidden_states = self.word_embeddings(input_ids)
for i in range(len(self.h)):
intermediate_tensors: Optional[IntermediateTensors],
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
hidden_states = self.word_embeddings(input_ids)
else:
hidden_states = intermediate_tensors["hidden_states"]
for i in range(self.start_layer, self.end_layer):
layer = self.h[i]
hidden_states = layer(
positions,
hidden_states,
kv_caches[i],
kv_caches[i - self.start_layer],
attn_metadata,
)
if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states})
hidden_states = self.ln_f(hidden_states)
return hidden_states
class FalconForCausalLM(nn.Module):
class FalconForCausalLM(nn.Module, SupportsPP):
def __init__(
self,
......@@ -403,6 +417,8 @@ class FalconForCausalLM(nn.Module):
)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
self.make_empty_intermediate_tensors = (
self.transformer.make_empty_intermediate_tensors)
def forward(
self,
......@@ -412,12 +428,8 @@ class FalconForCausalLM(nn.Module):
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor:
hidden_states = self.transformer(
input_ids,
positions,
kv_caches,
attn_metadata,
)
hidden_states = self.transformer(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors)
return hidden_states
def compute_logits(
......@@ -454,6 +466,8 @@ class FalconForCausalLM(nn.Module):
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
if "query_key_value" in name:
output_dim = getattr(param, "output_dim", None)
......
......@@ -41,8 +41,9 @@ from vllm.multimodal.utils import cached_get_tokenizer
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
SequenceData)
from .interfaces import SupportsMultiModal
from .utils import flatten_bn, merge_multimodal_embeddings
from .interfaces import SupportsMultiModal, SupportsPP
from .utils import (flatten_bn, group_weights_with_prefix,
merge_multimodal_embeddings)
# Cannot find the following 2 numbers from hf config.
_IMAGE_TOKEN_ID = 71011
......@@ -217,7 +218,7 @@ def input_mapper_for_fuyu(ctx: InputContext, data: object):
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_fuyu_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_fuyu)
@INPUT_REGISTRY.register_input_processor(input_processor_for_fuyu)
class FuyuForCausalLM(nn.Module, SupportsMultiModal):
class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
def __init__(self,
config: FuyuConfig,
......@@ -242,6 +243,12 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal):
self.language_model = PersimmonForCausalLM(config.text_config,
cache_config=cache_config,
quant_config=quant_config)
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors)
@property
def sampler(self):
return self.language_model.sampler
def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
......@@ -297,23 +304,29 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal):
intermediate_tensors: Optional[IntermediateTensors] = None,
**kwargs: object,
):
image_input = self._parse_and_validate_image_input(**kwargs)
if intermediate_tensors is not None:
input_ids = None
inputs_embeds = None
else:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is not None:
vision_embeddings = self._process_image_input(image_input)
inputs_embeds = self.language_model.model.embed_tokens(input_ids)
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, vision_embeddings,
self.image_token_id)
if image_input is not None:
vision_embeddings = self._process_image_input(image_input)
inputs_embeds = self.language_model.model.embed_tokens(
input_ids)
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, vision_embeddings,
self.image_token_id)
else:
inputs_embeds = None
else:
inputs_embeds = None
hidden_states = self.language_model(
input_ids=input_ids,
positions=positions,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
)
return hidden_states
......@@ -336,34 +349,16 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal):
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
params_dict = dict(self.named_parameters(remove_duplicate=False))
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
if ("rotary_emb.cos_cached" in name
or "rotary_emb.sin_cached" in name):
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
param = params_dict[name]
if "query_key_value" in name:
# copy from vllm/model_executor/models/bloom.py
# NOTE: Fuyu's fused QKV's output_dim has the shape of
# (num_heads * 3 * head_size), while the
# required shape is (3 * num_heads * head_size).
# Thus, we need weight conversion.
output_dim = getattr(param, "output_dim", None)
num_heads = self.config.num_attention_heads
if output_dim is not None:
loaded_weight_shape = loaded_weight.shape
loaded_weight = loaded_weight.view(
loaded_weight_shape[:output_dim] + (num_heads, 3, -1) +
loaded_weight_shape[output_dim + 1:])
loaded_weight = loaded_weight.transpose(
output_dim, output_dim + 1)
loaded_weight = loaded_weight.reshape(loaded_weight_shape)
# prepare weight iterators for components
weights_group = group_weights_with_prefix(weights)
# load vision embeddings
vision_params_dict = dict(self.vision_embed_tokens.named_parameters())
for name, loaded_weight in weights_group["vision_embed_tokens"]:
param = vision_params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
# load llm backbone
self.language_model.load_weights(weights_group["language_model"])
......@@ -15,7 +15,7 @@
# limitations under the License.
"""Inference-only Gemma model compatible with HuggingFace weights."""
from functools import lru_cache
from typing import Iterable, List, Optional, Set, Tuple
from typing import Iterable, List, Optional, Set, Tuple, Union
import torch
from torch import nn
......@@ -23,7 +23,7 @@ from transformers import GemmaConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import GeluAndMul
from vllm.model_executor.layers.layernorm import GemmaRMSNorm
......@@ -31,8 +31,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
......@@ -41,7 +40,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
logger = init_logger(__name__)
......@@ -245,6 +246,7 @@ class GemmaModel(nn.Module):
config: GemmaConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.config = config
......@@ -253,10 +255,11 @@ class GemmaModel(nn.Module):
config.vocab_size,
config.hidden_size,
)
self.layers = nn.ModuleList([
GemmaDecoderLayer(config, cache_config, quant_config)
for _ in range(config.num_hidden_layers)
])
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: GemmaDecoderLayer(config, cache_config, quant_config
),
prefix=f"{prefix}.layers")
self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
# Normalize the embedding by sqrt(hidden_size)
......@@ -265,6 +268,9 @@ class GemmaModel(nn.Module):
# See https://github.com/huggingface/transformers/pull/29402
normalizer = self.config.hidden_size**0.5
self.register_buffer("normalizer", torch.tensor(normalizer))
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
......@@ -275,29 +281,38 @@ class GemmaModel(nn.Module):
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if inputs_embeds is not None:
hidden_states = inputs_embeds
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
hidden_states *= self.normalizer
residual = None
else:
hidden_states = self.get_input_embeddings(input_ids)
hidden_states *= self.normalizer
residual = None
for i in range(len(self.layers)):
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states, residual = layer(
positions,
hidden_states,
kv_caches[i],
kv_caches[i - self.start_layer],
attn_metadata,
residual,
)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
class GemmaForCausalLM(nn.Module, SupportsLoRA):
class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
......@@ -339,6 +354,8 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA):
self.model = GemmaModel(config, cache_config, quant_config)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def forward(
self,
......@@ -347,9 +364,9 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA):
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor:
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata)
attn_metadata, intermediate_tensors)
return hidden_states
def compute_logits(
......@@ -388,6 +405,8 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA):
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
......@@ -400,6 +419,8 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA):
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
......
......@@ -14,7 +14,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Iterable, List, Optional, Set, Tuple
from typing import Iterable, List, Optional, Set, Tuple, Union
import torch
from torch import nn
......@@ -22,7 +22,7 @@ from transformers import Gemma2Config
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import GeluAndMul
from vllm.model_executor.layers.layernorm import GemmaRMSNorm
......@@ -30,8 +30,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
......@@ -40,7 +39,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
logger = init_logger(__name__)
......@@ -244,6 +245,7 @@ class Gemma2Model(nn.Module):
config: Gemma2Config,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.config = config
......@@ -252,10 +254,11 @@ class Gemma2Model(nn.Module):
config.vocab_size,
config.hidden_size,
)
self.layers = nn.ModuleList([
Gemma2DecoderLayer(layer_idx, config, cache_config, quant_config)
for layer_idx in range(config.num_hidden_layers)
])
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: Gemma2DecoderLayer(int(prefix.split(".")[
-1]), config, cache_config, quant_config),
prefix=f"{prefix}.layers")
self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
# Normalize the embedding by sqrt(hidden_size)
......@@ -264,6 +267,9 @@ class Gemma2Model(nn.Module):
# See https://github.com/huggingface/transformers/pull/29402
normalizer = self.config.hidden_size**0.5
self.register_buffer("normalizer", torch.tensor(normalizer))
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
def forward(
self,
......@@ -271,25 +277,36 @@ class Gemma2Model(nn.Module):
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
hidden_states *= self.normalizer
intermediate_tensors: Optional[IntermediateTensors],
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
hidden_states = self.embed_tokens(input_ids)
hidden_states *= self.normalizer
residual = None
for i in range(len(self.layers)):
residual = None
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states, residual = layer(
positions,
hidden_states,
kv_caches[i],
kv_caches[i - self.start_layer],
attn_metadata,
residual,
)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
class Gemma2ForCausalLM(nn.Module, SupportsLoRA):
class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
......@@ -338,6 +355,8 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA):
self.logits_processor = LogitsProcessor(
config.vocab_size, soft_cap=config.final_logit_softcapping)
self.sampler = Sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def forward(
self,
......@@ -346,9 +365,9 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA):
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor:
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata)
attn_metadata, intermediate_tensors)
return hidden_states
def compute_logits(
......@@ -387,6 +406,8 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA):
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
......@@ -399,6 +420,8 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA):
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
......
......@@ -32,8 +32,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
......@@ -41,7 +40,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .utils import is_pp_missing_parameter, make_layers
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
class GPT2Attention(nn.Module):
......@@ -204,6 +205,9 @@ class GPT2Model(nn.Module):
config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.h")
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(["hidden_states"],
config.n_embd))
def forward(
self,
......@@ -234,7 +238,7 @@ class GPT2Model(nn.Module):
return hidden_states
class GPT2LMHeadModel(nn.Module):
class GPT2LMHeadModel(nn.Module, SupportsPP):
def __init__(
self,
......@@ -256,6 +260,8 @@ class GPT2LMHeadModel(nn.Module):
self.config.hidden_size)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
self.make_empty_intermediate_tensors = (
self.transformer.make_empty_intermediate_tensors)
def forward(
self,
......@@ -264,7 +270,7 @@ class GPT2LMHeadModel(nn.Module):
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor:
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.transformer(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors)
return hidden_states
......@@ -286,16 +292,6 @@ class GPT2LMHeadModel(nn.Module):
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def make_empty_intermediate_tensors(
self, batch_size: int, dtype: torch.dtype,
device: torch.device) -> IntermediateTensors:
return IntermediateTensors({
"hidden_states":
torch.zeros((batch_size, self.config.hidden_size),
dtype=dtype,
device=device),
})
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
params_dict = dict(self.named_parameters(remove_duplicate=False))
for name, loaded_weight in weights:
......
......@@ -18,7 +18,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GPTBigCode model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Tuple
from typing import Iterable, List, Optional, Tuple, Union
import torch
from torch import nn
......@@ -26,14 +26,13 @@ from transformers import GPTBigCodeConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
......@@ -41,7 +40,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
class GPTBigCodeAttention(nn.Module):
......@@ -194,6 +195,7 @@ class GPTBigCodeModel(nn.Module):
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
prefix: str = "",
):
super().__init__()
self.config = config
......@@ -207,11 +209,15 @@ class GPTBigCodeModel(nn.Module):
self.embed_dim,
org_num_embeddings=config.vocab_size)
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
self.h = nn.ModuleList([
GPTBigCodeBlock(config, cache_config, quant_config)
for _ in range(config.num_hidden_layers)
])
self.start_layer, self.end_layer, self.h = make_layers(
config.num_hidden_layers,
lambda prefix: GPTBigCodeBlock(config, cache_config, quant_config),
prefix=f"{prefix}.h",
)
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(["hidden_states"],
config.n_embd))
def forward(
self,
......@@ -219,20 +225,28 @@ class GPTBigCodeModel(nn.Module):
position_ids: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
inputs_embeds = self.wte(input_ids)
position_embeds = self.wpe(position_ids)
hidden_states = inputs_embeds + position_embeds
intermediate_tensors: Optional[IntermediateTensors],
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
inputs_embeds = self.wte(input_ids)
position_embeds = self.wpe(position_ids)
hidden_states = inputs_embeds + position_embeds
else:
hidden_states = intermediate_tensors["hidden_states"]
for i in range(len(self.h)):
for i in range(self.start_layer, self.end_layer):
layer = self.h[i]
hidden_states = layer(hidden_states, kv_caches[i], attn_metadata)
hidden_states = layer(hidden_states,
kv_caches[i - self.start_layer],
attn_metadata)
if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states})
hidden_states = self.ln_f(hidden_states)
return hidden_states
class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA):
class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
packed_modules_mapping = {"c_attn": ["c_attn"]}
supported_lora_modules = ["c_fc", "c_proj", "wte", "c_attn"]
......@@ -272,6 +286,8 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA):
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size)
self.sampler = Sampler()
self.make_empty_intermediate_tensors = (
self.transformer.make_empty_intermediate_tensors)
def forward(
self,
......@@ -280,9 +296,9 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA):
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor:
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.transformer(input_ids, positions, kv_caches,
attn_metadata)
attn_metadata, intermediate_tensors)
return hidden_states
def compute_logits(
......@@ -311,6 +327,8 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA):
# Skip attention mask.
# NOTE: "c_attn.bias" should not be skipped.
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
......
......@@ -16,7 +16,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GPT-J model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Tuple
from typing import Iterable, List, Optional, Tuple, Union
import torch
from torch import nn
......@@ -24,14 +24,13 @@ from transformers import GPTJConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
......@@ -40,6 +39,10 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
class GPTJAttention(nn.Module):
......@@ -178,6 +181,7 @@ class GPTJModel(nn.Module):
config: GPTJConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.config = config
......@@ -186,11 +190,15 @@ class GPTJModel(nn.Module):
config.vocab_size,
self.embed_dim,
)
self.h = nn.ModuleList([
GPTJBlock(config, cache_config, quant_config)
for _ in range(config.n_layer)
])
self.start_layer, self.end_layer, self.h = make_layers(
config.n_layer,
lambda prefix: GPTJBlock(config, cache_config, quant_config),
prefix=f"{prefix}.h",
)
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(["hidden_states"],
config.n_embd))
def forward(
self,
......@@ -198,21 +206,27 @@ class GPTJModel(nn.Module):
position_ids: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
hidden_states = self.wte(input_ids)
for i in range(len(self.h)):
intermediate_tensors: Optional[IntermediateTensors],
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
hidden_states = self.wte(input_ids)
else:
hidden_states = intermediate_tensors["hidden_states"]
for i in range(self.start_layer, self.end_layer):
layer = self.h[i]
hidden_states = layer(
position_ids,
hidden_states,
kv_caches[i],
kv_caches[i - self.start_layer],
attn_metadata,
)
if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states})
hidden_states = self.ln_f(hidden_states)
return hidden_states
class GPTJForCausalLM(nn.Module):
class GPTJForCausalLM(nn.Module, SupportsPP):
def __init__(
self,
......@@ -233,6 +247,8 @@ class GPTJForCausalLM(nn.Module):
)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
self.make_empty_intermediate_tensors = (
self.transformer.make_empty_intermediate_tensors)
def forward(
self,
......@@ -241,9 +257,9 @@ class GPTJForCausalLM(nn.Module):
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor:
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.transformer(input_ids, positions, kv_caches,
attn_metadata)
attn_metadata, intermediate_tensors)
return hidden_states
def compute_logits(
......@@ -283,6 +299,8 @@ class GPTJForCausalLM(nn.Module):
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
......@@ -291,6 +309,8 @@ class GPTJForCausalLM(nn.Module):
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
......
......@@ -16,7 +16,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GPT-NeoX model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Tuple
from typing import Iterable, List, Optional, Tuple, Union
import torch
from torch import nn
......@@ -24,14 +24,13 @@ from transformers import GPTNeoXConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
......@@ -40,6 +39,10 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
class GPTNeoXAttention(nn.Module):
......@@ -191,6 +194,7 @@ class GPTNeoXModel(nn.Module):
config: GPTNeoXConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.config = config
......@@ -199,12 +203,16 @@ class GPTNeoXModel(nn.Module):
config.vocab_size,
config.hidden_size,
)
self.layers = nn.ModuleList([
GPTNeoXLayer(config, cache_config, quant_config)
for _ in range(config.num_hidden_layers)
])
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: GPTNeoXLayer(config, cache_config, quant_config),
prefix=f"{prefix}.layers",
)
self.final_layer_norm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(["hidden_states"],
config.hidden_size))
def forward(
self,
......@@ -212,21 +220,27 @@ class GPTNeoXModel(nn.Module):
position_ids: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
hidden_states = self.embed_in(input_ids)
for i in range(len(self.layers)):
intermediate_tensors: Optional[IntermediateTensors],
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
hidden_states = self.embed_in(input_ids)
else:
hidden_states = intermediate_tensors["hidden_states"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states = layer(
position_ids,
hidden_states,
kv_caches[i],
kv_caches[i - self.start_layer],
attn_metadata,
)
if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states})
hidden_states = self.final_layer_norm(hidden_states)
return hidden_states
class GPTNeoXForCausalLM(nn.Module):
class GPTNeoXForCausalLM(nn.Module, SupportsPP):
def __init__(
self,
......@@ -247,6 +261,8 @@ class GPTNeoXForCausalLM(nn.Module):
self.embed_out.weight = self.gpt_neox.embed_in.weight
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
self.make_empty_intermediate_tensors = (
self.gpt_neox.make_empty_intermediate_tensors)
def forward(
self,
......@@ -255,9 +271,9 @@ class GPTNeoXForCausalLM(nn.Module):
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor:
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.gpt_neox(input_ids, positions, kv_caches,
attn_metadata)
attn_metadata, intermediate_tensors)
return hidden_states
def compute_logits(
......@@ -288,6 +304,8 @@ class GPTNeoXForCausalLM(nn.Module):
# Models trained using OpenRLHF may include
# these tensors in the checkpoint. Skip them.
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
if "query_key_value" in name:
......
......@@ -51,7 +51,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm.utils import is_hip
from .interfaces import SupportsLoRA
from .interfaces import SupportsLoRA, SupportsPP
from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers
......@@ -311,13 +311,13 @@ class GraniteModel(nn.Module):
else:
hidden_states = self.get_input_embeddings(input_ids)
residual = None
hidden_states *= self.config.embedding_multiplier
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
hidden_states *= self.config.embedding_multiplier
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states = layer(
......@@ -337,7 +337,7 @@ class GraniteModel(nn.Module):
return hidden_states
class GraniteForCausalLM(nn.Module, SupportsLoRA):
class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
......
......@@ -46,7 +46,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from . import mixtral
from .interfaces import SupportsLoRA
from .interfaces import SupportsLoRA, SupportsPP
from .utils import make_layers
......@@ -307,7 +307,7 @@ class GraniteMoeModel(nn.Module):
return hidden_states
class GraniteMoeForCausalLM(nn.Module, SupportsLoRA):
class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
fall_back_to_pt_during_load = False
packed_modules_mapping = {
......
from typing import (ClassVar, Dict, List, Literal, Optional, Protocol, Type,
Union, overload, runtime_checkable)
import inspect
from typing import (TYPE_CHECKING, ClassVar, Dict, List, Literal, Optional,
Protocol, Type, Union, overload, runtime_checkable)
import torch
from typing_extensions import TypeIs
from vllm.config import LoRAConfig, MultiModalConfig, SchedulerConfig
from vllm.logger import init_logger
if TYPE_CHECKING:
from vllm.attention import AttentionMetadata
from vllm.config import LoRAConfig, MultiModalConfig, SchedulerConfig
from vllm.sequence import IntermediateTensors
logger = init_logger(__name__)
......@@ -22,7 +28,7 @@ class SupportsMultiModal(Protocol):
MRO of your model class.
"""
def __init__(self, *, multimodal_config: MultiModalConfig) -> None:
def __init__(self, *, multimodal_config: "MultiModalConfig") -> None:
...
......@@ -32,7 +38,7 @@ class SupportsMultiModal(Protocol):
class _SupportsMultiModalType(Protocol):
supports_multimodal: Literal[True]
def __call__(self, *, multimodal_config: MultiModalConfig) -> None:
def __call__(self, *, multimodal_config: "MultiModalConfig") -> None:
...
......@@ -75,7 +81,7 @@ class SupportsLoRA(Protocol):
embedding_padding_modules: ClassVar[List[str]]
# lora_config is None when LoRA is not enabled
def __init__(self, *, lora_config: Optional[LoRAConfig] = None) -> None:
def __init__(self, *, lora_config: Optional["LoRAConfig"] = None) -> None:
...
......@@ -90,7 +96,7 @@ class _SupportsLoRAType(Protocol):
embedding_modules: Dict[str, str]
embedding_padding_modules: List[str]
def __call__(self, *, lora_config: Optional[LoRAConfig] = None) -> None:
def __call__(self, *, lora_config: Optional["LoRAConfig"] = None) -> None:
...
......@@ -145,6 +151,132 @@ def _supports_lora(
return isinstance(model, SupportsLoRA)
@runtime_checkable
class SupportsPP(Protocol):
"""The interface required for all models that support pipeline parallel."""
supports_pp: ClassVar[Literal[True]] = True
"""
A flag that indicates this model supports pipeline parallel.
Note:
There is no need to redefine this flag if this class is in the
MRO of your model class.
"""
def make_empty_intermediate_tensors(
self,
batch_size: int,
dtype: torch.dtype,
device: torch.device,
) -> "IntermediateTensors":
"""Called when PP rank > 0 for profiling purposes."""
...
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: "AttentionMetadata",
intermediate_tensors: Optional["IntermediateTensors"],
) -> Union[torch.Tensor, "IntermediateTensors"]:
"""
Accept :class:`IntermediateTensors` when PP rank > 0.
Return :class:`IntermediateTensors` only for the last PP rank.
"""
...
# We can't use runtime_checkable with ClassVar for issubclass checks
# so we need to treat the class as an instance and use isinstance instead
@runtime_checkable
class _SupportsPPType(Protocol):
supports_pp: Literal[True]
def make_empty_intermediate_tensors(
self,
batch_size: int,
dtype: torch.dtype,
device: torch.device,
) -> "IntermediateTensors":
...
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: "AttentionMetadata",
intermediate_tensors: Optional["IntermediateTensors"],
) -> Union[torch.Tensor, "IntermediateTensors"]:
...
@overload
def supports_pp(model: Type[object]) -> TypeIs[Type[SupportsPP]]:
...
@overload
def supports_pp(model: object) -> TypeIs[SupportsPP]:
...
def supports_pp(
model: Union[Type[object], object],
) -> Union[bool, TypeIs[Type[SupportsPP]], TypeIs[SupportsPP]]:
supports_attributes = _supports_pp_attributes(model)
supports_inspect = _supports_pp_inspect(model)
if supports_attributes and not supports_inspect:
logger.warning(
"The model (%s) sets `supports_pp=True`, but does not accept "
"`intermediate_tensors` in its `forward` method", model)
if not supports_attributes:
pp_attrs = ("make_empty_intermediate_tensors", )
missing_attrs = tuple(attr for attr in pp_attrs
if not hasattr(model, attr))
if getattr(model, "supports_pp", False):
if missing_attrs:
logger.warning(
"The model (%s) sets `supports_pp=True`, "
"but is missing PP-specific attributes: %s",
model,
missing_attrs,
)
else:
if not missing_attrs:
logger.warning(
"The model (%s) contains all PP-specific attributes, "
"but does not set `supports_pp=True`.", model)
return supports_attributes and supports_inspect
def _supports_pp_attributes(
model: Union[Type[object], object],
) -> Union[bool, TypeIs[Type[SupportsPP]], TypeIs[SupportsPP]]:
if isinstance(model, type):
return isinstance(model, _SupportsPPType)
return isinstance(model, SupportsPP)
def _supports_pp_inspect(
model: Union[Type[object], object],
) -> Union[bool, TypeIs[Type[SupportsPP]], TypeIs[SupportsPP]]:
model_forward = getattr(model, "forward", None)
if not callable(model_forward):
return False
forward_params = inspect.signature(model_forward).parameters
return "intermediate_tensors" in forward_params
@runtime_checkable
class HasInnerState(Protocol):
"""The interface required for all models that has inner state."""
......@@ -158,7 +290,7 @@ class HasInnerState(Protocol):
def __init__(self,
*,
scheduler_config: Optional[SchedulerConfig] = None) -> None:
scheduler_config: Optional["SchedulerConfig"] = None) -> None:
...
......@@ -168,7 +300,7 @@ class _HasInnerStateType(Protocol):
def __init__(self,
*,
scheduler_config: Optional[SchedulerConfig] = None) -> None:
scheduler_config: Optional["SchedulerConfig"] = None) -> None:
...
......
......@@ -18,8 +18,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
......@@ -28,6 +27,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
......@@ -266,7 +266,7 @@ class InternLM2Model(nn.Module):
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: IntermediateTensors = None,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
......@@ -297,7 +297,7 @@ class InternLM2Model(nn.Module):
return hidden_states
class InternLM2ForCausalLM(nn.Module):
class InternLM2ForCausalLM(nn.Module, SupportsPP):
def __init__(
self,
......@@ -325,7 +325,7 @@ class InternLM2ForCausalLM(nn.Module):
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: IntermediateTensors,
intermediate_tensors: Optional[IntermediateTensors],
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors)
......
......@@ -5,9 +5,9 @@
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
import re
from functools import partial
from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional,
Tuple, TypedDict, Union)
from functools import cached_property, partial
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
TypedDict, Union)
import torch
import torch.nn as nn
......@@ -17,7 +17,6 @@ from transformers import PretrainedConfig
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig
from vllm.distributed import get_pp_group
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
......@@ -32,7 +31,7 @@ from vllm.utils import is_list_of
from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
get_clip_num_patches)
from .interfaces import SupportsMultiModal
from .interfaces import SupportsMultiModal, SupportsPP
from .utils import (flatten_bn, group_weights_with_prefix,
init_vllm_registered_model, merge_multimodal_embeddings)
......@@ -123,7 +122,7 @@ def calculate_num_blocks(orig_width: int, orig_height: int, min_num: int,
return blocks, target_width, target_height
def calculate_num_blocks_wrapper(hf_config: Dict[str, Any],
def calculate_num_blocks_wrapper(hf_config: PretrainedConfig,
max_dynamic_patch: Optional[int] = None):
if max_dynamic_patch is None:
max_dynamic_patch = hf_config.max_dynamic_patch
......@@ -183,7 +182,7 @@ def image_to_pixel_values(image: Image.Image, input_size: int, min_num: int,
return pixel_values
def image_to_pixel_values_wrapper(hf_config: Dict[str, Any],
def image_to_pixel_values_wrapper(hf_config: PretrainedConfig,
max_dynamic_patch: Optional[int] = None):
image_size = hf_config.vision_config.image_size
min_num = hf_config.min_dynamic_patch
......@@ -197,7 +196,7 @@ def image_to_pixel_values_wrapper(hf_config: Dict[str, Any],
use_thumbnail=use_thumbnail)
def get_internvl_num_patches(hf_config: Dict[str, Any]):
def get_internvl_num_patches(hf_config: PretrainedConfig):
vision_config = hf_config.vision_config
downsample_ratio = hf_config.downsample_ratio
image_size = vision_config.image_size
......@@ -362,7 +361,7 @@ def dummy_data_for_internvl(ctx: InputContext,
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_internvl_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_internvl)
@INPUT_REGISTRY.register_input_processor(input_processor_for_internvl)
class InternVLChatModel(nn.Module, SupportsMultiModal):
class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
def __init__(self,
config: PretrainedConfig,
......@@ -408,10 +407,12 @@ class InternVLChatModel(nn.Module, SupportsMultiModal):
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors)
@cached_property
def sampler(self):
if hasattr(self.language_model, "sampler"):
self.sampler = self.language_model.sampler
else:
self.sampler = Sampler()
return self.language_model.sampler
return Sampler()
def pixel_shuffle(self, x, scale_factor=0.5):
n, w, h, c = x.size()
......@@ -515,18 +516,22 @@ class InternVLChatModel(nn.Module, SupportsMultiModal):
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
**kwargs: object,
) -> SamplerOutput:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is not None and get_pp_group().is_first_rank:
inputs_embeds = self.language_model.model.get_input_embeddings(
input_ids)
vision_embeddings = self._process_image_input(image_input)
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, vision_embeddings,
self.img_context_token_id)
) -> Union[SamplerOutput, IntermediateTensors]:
if intermediate_tensors is not None:
input_ids = None
else:
inputs_embeds = None
else:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is not None:
inputs_embeds = self.language_model.model.get_input_embeddings(
input_ids)
vision_embeddings = self._process_image_input(image_input)
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, vision_embeddings,
self.img_context_token_id)
input_ids = None
else:
inputs_embeds = None
hidden_states = self.language_model.model(input_ids,
positions,
......
......@@ -33,8 +33,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
......@@ -43,7 +42,9 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs import JAISConfig
from .utils import is_pp_missing_parameter, make_layers
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
class SwiGLUActivation(nn.Module):
......@@ -244,6 +245,9 @@ class JAISModel(nn.Module):
)
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(["hidden_states"],
config.n_embd))
def forward(
self,
......@@ -279,7 +283,7 @@ class JAISModel(nn.Module):
return hidden_states
class JAISLMHeadModel(nn.Module):
class JAISLMHeadModel(nn.Module, SupportsPP):
def __init__(
self,
......@@ -304,6 +308,8 @@ class JAISLMHeadModel(nn.Module):
self.logits_processor = LogitsProcessor(vocab_size=config.vocab_size,
scale=self.output_logits_scale)
self.sampler = Sampler()
self.make_empty_intermediate_tensors = (
self.transformer.make_empty_intermediate_tensors)
def forward(
self,
......@@ -326,16 +332,6 @@ class JAISLMHeadModel(nn.Module):
sampling_metadata)
return logits
def make_empty_intermediate_tensors(
self, batch_size: int, dtype: torch.dtype,
device: torch.device) -> IntermediateTensors:
return IntermediateTensors({
"hidden_states":
torch.zeros((batch_size, self.config.hidden_size),
dtype=dtype,
device=device),
})
def sample(
self,
logits: torch.Tensor,
......
......@@ -37,8 +37,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
get_compressed_tensors_cache_scale)
from vllm.model_executor.layers.rotary_embedding import get_rope
......@@ -51,8 +50,9 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm.utils import is_hip
from .interfaces import SupportsLoRA
from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (PPMissingLayer, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
class LlamaMLP(nn.Module):
......@@ -72,12 +72,15 @@ class LlamaMLP(nn.Module):
output_sizes=[intermediate_size] * 2,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj")
self.down_proj = RowParallelLinear(input_size=intermediate_size,
output_size=hidden_size,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.down_proj")
prefix=f"{prefix}.gate_up_proj",
)
self.down_proj = RowParallelLinear(
input_size=intermediate_size,
output_size=hidden_size,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.down_proj",
)
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
......@@ -161,12 +164,14 @@ class LlamaAttention(nn.Module):
rope_scaling=rope_scaling,
is_neox_style=is_neox_style,
)
self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config)
self.attn = Attention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
)
def forward(
self,
......@@ -248,12 +253,10 @@ class LlamaDecoderLayer(nn.Module):
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
hidden_states = self.self_attn(positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata)
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(
......@@ -295,12 +298,17 @@ class LlamaModel(nn.Module):
cache_config=cache_config,
quant_config=quant_config,
prefix=prefix),
prefix=f"{prefix}.layers")
prefix=f"{prefix}.layers",
)
if get_pp_group().is_last_rank:
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
else:
self.norm = PPMissingLayer()
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
......@@ -326,13 +334,9 @@ class LlamaModel(nn.Module):
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states, residual = layer(
positions,
hidden_states,
kv_caches[i - self.start_layer],
attn_metadata,
residual,
)
hidden_states, residual = layer(positions, hidden_states,
kv_caches[i - self.start_layer],
attn_metadata, residual)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
......@@ -344,17 +348,10 @@ class LlamaModel(nn.Module):
return hidden_states
class LlamaForCausalLM(nn.Module, SupportsLoRA):
class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
"gate_up_proj": ["gate_proj", "up_proj"]
}
# LoRA specific attributes
......@@ -364,7 +361,7 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
]
embedding_modules = {
"embed_tokens": "input_embeddings",
"lm_head": "output_embeddings",
"lm_head": "output_embeddings"
}
embedding_padding_modules = ["lm_head"]
bitsandbytes_stacked_params_mapping = {
......@@ -420,10 +417,12 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
self.unpadded_vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
padding_size=DEFAULT_VOCAB_PADDING_SIZE
# We need bigger padding if using lora for kernel
# compatibility
if not lora_config else lora_config.lora_vocab_padding_size,
padding_size=(
DEFAULT_VOCAB_PADDING_SIZE
# We need bigger padding if using lora for kernel
# compatibility
if not lora_config else
lora_config.lora_vocab_padding_size),
quant_config=quant_config,
)
if config.tie_word_embeddings:
......@@ -436,6 +435,8 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
self.sampler = Sampler()
else:
self.lm_head = PPMissingLayer()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def forward(
self,
......@@ -458,28 +459,11 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
sampling_metadata)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
def sample(self, logits: torch.Tensor,
sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def make_empty_intermediate_tensors(
self, batch_size: int, dtype: torch.dtype,
device: torch.device) -> IntermediateTensors:
return IntermediateTensors({
"hidden_states":
torch.zeros((batch_size, self.config.hidden_size),
dtype=dtype,
device=device),
"residual":
torch.zeros((batch_size, self.config.hidden_size),
dtype=dtype,
device=device),
})
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
......@@ -513,7 +497,7 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
loaded_weight = loaded_weight[0]
weight_loader(param, loaded_weight)
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
......
from typing import Iterable, List, Optional, Tuple
from typing import Iterable, List, Optional, Tuple, Union
import torch
from torch import nn
......@@ -8,10 +8,13 @@ from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.llama import LlamaModel
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.sequence import PoolerOutput
from vllm.sequence import IntermediateTensors, PoolerOutput
from .interfaces import SupportsPP
from .utils import is_pp_missing_parameter
class LlamaEmbeddingModel(nn.Module):
class LlamaEmbeddingModel(nn.Module, SupportsPP):
"""A model that uses Llama with additional embedding functionalities.
This class encapsulates the LlamaModel and provides an interface for
......@@ -29,6 +32,8 @@ class LlamaEmbeddingModel(nn.Module):
super().__init__()
self.model = LlamaModel(**kwargs)
self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def forward(
self,
......@@ -36,10 +41,12 @@ class LlamaEmbeddingModel(nn.Module):
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
) -> Union[torch.Tensor, IntermediateTensors]:
return self.model.forward(input_ids, positions, kv_caches,
attn_metadata, inputs_embeds)
attn_metadata, intermediate_tensors,
inputs_embeds)
def pooler(
self,
......@@ -73,6 +80,8 @@ class LlamaEmbeddingModel(nn.Module):
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
......@@ -81,6 +90,8 @@ class LlamaEmbeddingModel(nn.Module):
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
......
from functools import cached_property
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
TypedDict, Union)
......@@ -11,7 +12,7 @@ from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
......@@ -21,7 +22,7 @@ from vllm.utils import is_list_of
from .clip import (CLIPVisionModel, dummy_image_for_clip,
dummy_seq_data_for_clip, get_max_clip_image_tokens,
input_processor_for_clip)
from .interfaces import SupportsMultiModal
from .interfaces import SupportsMultiModal, SupportsPP
from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
dummy_seq_data_for_siglip, get_max_siglip_image_tokens,
input_processor_for_siglip)
......@@ -198,7 +199,7 @@ def _init_vision_tower(hf_config: LlavaConfig):
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava)
@INPUT_REGISTRY.register_input_processor(input_processor_for_llava)
class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal):
class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
def __init__(self,
config: LlavaConfig,
......@@ -220,6 +221,16 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal):
self.language_model = init_vllm_registered_model(
config.text_config, cache_config, quant_config)
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors)
@cached_property
def sampler(self):
if hasattr(self.language_model, "sampler"):
return self.language_model.sampler
return Sampler()
def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
h = w = self.config.vision_config.image_size
expected_dims = (3, h, w)
......@@ -315,7 +326,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal):
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
**kwargs: object,
) -> SamplerOutput:
) -> Union[torch.Tensor, IntermediateTensors]:
"""Run forward pass for LLaVA-1.5.
One key thing to understand is the `input_ids` already accounts for the
......@@ -351,26 +362,30 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal):
See also:
:class:`LlavaImageInputs`
"""
image_input = self._parse_and_validate_image_input(**kwargs)
if intermediate_tensors is not None:
input_ids = None
inputs_embeds = None
else:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is not None:
vision_embeddings = self._process_image_input(image_input)
inputs_embeds = self.language_model.model.get_input_embeddings(
input_ids)
if image_input is not None:
vision_embeddings = self._process_image_input(image_input)
inputs_embeds = self.language_model.model.get_input_embeddings(
input_ids)
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, vision_embeddings,
self.config.image_token_index)
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, vision_embeddings,
self.config.image_token_index)
input_ids = None
else:
inputs_embeds = None
input_ids = None
else:
inputs_embeds = None
hidden_states = self.language_model.model(input_ids,
positions,
kv_caches,
attn_metadata,
None,
intermediate_tensors,
inputs_embeds=inputs_embeds)
return hidden_states
......
from functools import cached_property
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
TypedDict, Union)
......@@ -13,7 +14,7 @@ from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
......@@ -23,7 +24,7 @@ from vllm.utils import is_list_of
from .clip import (CLIPVisionModel, dummy_image_for_clip,
dummy_seq_data_for_clip, get_clip_image_feature_size,
get_clip_patch_grid_length, input_processor_for_clip)
from .interfaces import SupportsMultiModal
from .interfaces import SupportsMultiModal, SupportsPP
from .llava import LlavaMultiModalProjector
from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
dummy_seq_data_for_siglip, get_siglip_image_feature_size,
......@@ -286,7 +287,8 @@ def _init_vision_tower(hf_config: LlavaNextConfig):
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_next_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_next)
@INPUT_REGISTRY.register_input_processor(input_processor_for_llava_next)
class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal):
class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP):
def __init__(self,
config: LlavaNextConfig,
......@@ -300,6 +302,8 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal):
# TODO: Optionally initializes this for supporting embeddings.
self.vision_tower = _init_vision_tower(config)
self.image_newline = nn.Parameter(
torch.empty(config.text_config.hidden_size))
self.multi_modal_projector = LlavaMultiModalProjector(
vision_hidden_size=config.vision_config.hidden_size,
text_hidden_size=config.text_config.hidden_size,
......@@ -308,8 +312,15 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal):
self.language_model = init_vllm_registered_model(
config.text_config, cache_config, quant_config)
self.image_newline = nn.Parameter(
torch.empty(config.text_config.hidden_size))
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors)
@cached_property
def sampler(self):
if hasattr(self.language_model, "sampler"):
return self.language_model.sampler
return Sampler()
def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor:
expected_dims = (2, )
......@@ -542,7 +553,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal):
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
**kwargs: object,
) -> SamplerOutput:
) -> Union[torch.Tensor, IntermediateTensors]:
"""Run forward pass for LlaVA-NeXT.
One key thing to understand is the `input_ids` already accounts for the
......@@ -587,26 +598,30 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal):
See also:
:class:`LlavaNextImageInputs`
"""
image_input = self._parse_and_validate_image_input(**kwargs)
if intermediate_tensors is not None:
input_ids = None
inputs_embeds = None
else:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is not None:
vision_embeddings = self._process_image_input(image_input)
inputs_embeds = self.language_model.model.get_input_embeddings(
input_ids)
if image_input is not None:
vision_embeddings = self._process_image_input(image_input)
inputs_embeds = self.language_model.model.get_input_embeddings(
input_ids)
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, vision_embeddings,
self.config.image_token_index)
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, vision_embeddings,
self.config.image_token_index)
input_ids = None
else:
inputs_embeds = None
input_ids = None
else:
inputs_embeds = None
hidden_states = self.language_model.model(input_ids,
positions,
kv_caches,
attn_metadata,
None,
intermediate_tensors,
inputs_embeds=inputs_embeds)
return hidden_states
......
import math
from functools import cached_property
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
TypedDict, Union)
......@@ -12,9 +13,8 @@ from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.clip import CLIPVisionModel
from vllm.model_executor.sampling_metadata import SamplingMetadata
......@@ -25,7 +25,7 @@ from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of
from .clip import dummy_image_for_clip, dummy_seq_data_for_clip
from .interfaces import SupportsMultiModal
from .interfaces import SupportsMultiModal, SupportsPP
from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
dummy_seq_data_for_siglip)
from .utils import (group_weights_with_prefix, init_vllm_registered_model,
......@@ -267,7 +267,8 @@ class LlavaNextMultiModalProjector(nn.Module):
"video", get_max_llava_next_video_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_next_video)
@INPUT_REGISTRY.register_input_processor(input_processor_for_llava_next_video)
class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal):
class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP):
def __init__(self,
config: LlavaNextVideoConfig,
......@@ -281,13 +282,23 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal):
# Initialize the vision tower only up to the required feature layer
self.vision_tower = _init_vision_tower(config)
self.vision_resampler = LlavaNextVideoPooler(config)
self.multi_modal_projector = LlavaNextMultiModalProjector(
vision_hidden_size=config.vision_config.hidden_size,
text_hidden_size=config.text_config.hidden_size,
projector_hidden_act=config.projector_hidden_act)
self.language_model = init_vllm_registered_model(
config.text_config, cache_config, quant_config)
self.vision_resampler = LlavaNextVideoPooler(config)
self.make_empty_intermediate_tensors = (
self.language_model.model.make_empty_intermediate_tensors)
@cached_property
def sampler(self):
if hasattr(self.language_model, "sampler"):
return self.language_model.sampler
return Sampler()
def _validate_video_pixel_values(
self, data: Union[torch.Tensor, List[torch.Tensor]]
......@@ -397,34 +408,36 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal):
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
**kwargs: object,
) -> SamplerOutput:
) -> Union[torch.Tensor, IntermediateTensors]:
"""Run forward pass for LlaVA-NeXT-Video.
Args:
input_ids: Flattened (concatenated) input_ids corresponding to a
batch.
pixel_values_videos: Pixels in each frames for each input videos.
"""
video_input = self._parse_and_validate_video_input(**kwargs)
# merge video embeddings into input embeddings
if video_input is not None:
video_embeddings = self._process_video_pixels(video_input)
inputs_embeds = self.language_model \
.model.get_input_embeddings(input_ids)
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, video_embeddings,
self.config.video_token_index)
if intermediate_tensors is not None:
input_ids = None
else:
inputs_embeds = None
else:
video_input = self._parse_and_validate_video_input(**kwargs)
if video_input is not None:
video_embeddings = self._process_video_pixels(video_input)
inputs_embeds = self.language_model \
.model.get_input_embeddings(input_ids)
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, video_embeddings,
self.config.video_token_index)
input_ids = None
else:
inputs_embeds = None
hidden_states = self.language_model.model(input_ids,
positions,
kv_caches,
attn_metadata,
None,
intermediate_tensors,
inputs_embeds=inputs_embeds)
return hidden_states
......
import math
from functools import cached_property
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
TypedDict, Union)
......@@ -17,9 +18,8 @@ from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
......@@ -31,7 +31,7 @@ from vllm.utils import is_list_of
from .clip import (CLIPVisionModel, dummy_seq_data_for_clip,
dummy_video_for_clip, get_clip_image_feature_size,
get_clip_patch_grid_length, input_processor_for_clip)
from .interfaces import SupportsMultiModal
from .interfaces import SupportsMultiModal, SupportsPP
from .siglip import (SiglipVisionModel, dummy_seq_data_for_siglip,
dummy_video_for_siglip, get_siglip_image_feature_size,
get_siglip_patch_grid_length, input_processor_for_siglip)
......@@ -414,7 +414,8 @@ class LlavaOnevisionMultiModalProjector(nn.Module):
"video", get_max_llava_onevision_video_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_onevision)
@INPUT_REGISTRY.register_input_processor(input_processor_for_llava_onevision)
class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal):
class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP):
def __init__(self,
config: LlavaOnevisionConfig,
......@@ -434,6 +435,16 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal):
self.image_newline = nn.Parameter(
torch.empty(config.text_config.hidden_size))
self.make_empty_intermediate_tensors = (
self.language_model.model.make_empty_intermediate_tensors)
@cached_property
def sampler(self):
if hasattr(self.language_model, "sampler"):
return self.language_model.sampler
return Sampler()
def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor:
expected_dims = (2, )
......@@ -805,39 +816,42 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal):
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
**kwargs: object,
) -> SamplerOutput:
) -> Union[torch.Tensor, IntermediateTensors]:
"""Run forward pass for LlaVA-Onevision.
Args:
input_ids: Flattened (concatenated) input_ids corresponding to a
batch.
pixel_values_videos: Pixels in each frames for each input videos.
"""
modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
# merge video embeddings into input embeddings
if modalities:
inputs_embeds = self.language_model.model.get_input_embeddings(
input_ids)
if "images" in modalities:
image_input = modalities["images"]
vision_embeddings = self._process_image_input(image_input)
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, vision_embeddings,
self.config.image_token_index)
if "videos" in modalities:
video_input = modalities["videos"]
video_embeddings = self._process_video_pixels(video_input)
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, video_embeddings,
self.config.video_token_index)
if intermediate_tensors is not None:
input_ids = None
else:
inputs_embeds = None
else:
modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
if modalities:
inputs_embeds = self.language_model.model.get_input_embeddings(
input_ids)
if "images" in modalities:
image_input = modalities["images"]
vision_embeddings = self._process_image_input(image_input)
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, vision_embeddings,
self.config.image_token_index)
if "videos" in modalities:
video_input = modalities["videos"]
video_embeddings = self._process_video_pixels(video_input)
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, video_embeddings,
self.config.video_token_index)
input_ids = None
else:
inputs_embeds = None
hidden_states = self.language_model.model(input_ids,
positions,
kv_caches,
attn_metadata,
None,
intermediate_tensors,
inputs_embeds=inputs_embeds)
return hidden_states
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment