Unverified Commit 0f6d7a9a authored by Murali Andoorveedu's avatar Murali Andoorveedu Committed by GitHub
Browse files

[Models] Add remaining model PP support (#7168)


Signed-off-by: default avatarMuralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>
Signed-off-by: default avatarMurali Andoorveedu <muralidhar.andoorveedu@centml.ai>
Co-authored-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent 303d4479
......@@ -22,7 +22,7 @@
# limitations under the License.
"""Inference-only MiniCPM model compatible with HuggingFace weights."""
import math
from typing import Any, Dict, Iterable, List, Optional, Tuple
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
import torch
from torch import nn
......@@ -30,7 +30,7 @@ from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import (get_tensor_model_parallel_rank,
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.activation import SiluAndMul
......@@ -41,8 +41,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
......@@ -52,7 +51,9 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
class MiniCPMMoE(nn.Module):
......@@ -264,7 +265,7 @@ class MiniCPMDecoderLayer(nn.Module):
def __init__(
self,
config,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
......@@ -346,10 +347,11 @@ class MiniCPMModel(nn.Module):
def __init__(
self,
config,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.config = config
......@@ -365,15 +367,24 @@ class MiniCPMModel(nn.Module):
config.hidden_size,
org_num_embeddings=config.vocab_size,
)
self._init_layers()
self._init_layers(prefix, config, cache_config, quant_config)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], self.config.hidden_size))
def _init_layers(self):
self.layers = nn.ModuleList([
MiniCPMDecoderLayer(self.config, self.cache_config,
self.quant_config)
for _ in range(self.config.num_hidden_layers)
])
def _init_layers(
self,
prefix: str,
config: PretrainedConfig,
cache_config: Optional[CacheConfig],
quant_config: Optional[QuantizationConfig],
):
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: MiniCPMDecoderLayer(config, cache_config,
quant_config),
prefix=f"{prefix}.layers")
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
embedding = self.embed_tokens(input_ids)
......@@ -387,27 +398,36 @@ class MiniCPMModel(nn.Module):
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if inputs_embeds is not None:
hidden_states = inputs_embeds
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
residual = None
else:
hidden_states = self.get_input_embeddings(input_ids)
residual = None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for i in range(len(self.layers)):
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states, residual = layer(
positions,
hidden_states,
kv_caches[i],
kv_caches[i - self.start_layer],
attn_metadata,
residual,
)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
hidden_states = self.norm(hidden_states)
return hidden_states
class MiniCPMForCausalLM(nn.Module, SupportsLoRA):
class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
......@@ -470,6 +490,8 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA):
self.logits_processor = LogitsProcessor(unpadded_vocab_size,
config.vocab_size)
self.sampler = Sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def _init_model(self):
self.model = MiniCPMModel(config=self.config,
......@@ -484,7 +506,7 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA):
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor:
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors)
return hidden_states
......@@ -548,6 +570,8 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA):
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
......@@ -557,6 +581,8 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA):
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param,
......@@ -568,6 +594,8 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA):
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
......
......@@ -26,6 +26,7 @@ from typing import Any, Dict, Optional
import torch
from torch import nn
from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
......@@ -34,19 +35,20 @@ from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.models.minicpm import (MiniCPMDecoderLayer,
MiniCPMForCausalLM,
MiniCPMModel)
from .utils import make_layers
class MiniCPM3Attention(nn.Module):
def __init__(
self,
config,
config: PretrainedConfig,
hidden_size: int,
num_heads: int,
qk_nope_head_dim: int,
......@@ -199,12 +201,18 @@ class MiniCPM3DecoderLayer(MiniCPMDecoderLayer):
class MiniCPM3Model(MiniCPMModel):
def _init_layers(self):
self.layers = nn.ModuleList([
MiniCPM3DecoderLayer(self.config, self.cache_config,
self.quant_config)
for _ in range(self.config.num_hidden_layers)
])
def _init_layers(
self,
prefix: str,
config: PretrainedConfig,
cache_config: Optional[CacheConfig],
quant_config: Optional[QuantizationConfig],
):
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: MiniCPM3DecoderLayer(config, cache_config,
quant_config),
prefix=f"{prefix}.layers")
class MiniCPM3ForCausalLM(MiniCPMForCausalLM):
......
......@@ -45,7 +45,6 @@ from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import SupportsMultiModal
from vllm.model_executor.models.llama import LlamaModel
from vllm.model_executor.models.minicpm import MiniCPMModel
from vllm.model_executor.models.module_mapping import MultiModelKeys
......@@ -59,7 +58,8 @@ from vllm.multimodal.utils import cached_get_tokenizer
from vllm.sequence import IntermediateTensors, SequenceData
from .idefics2_vision_model import Idefics2VisionTransformer
from .interfaces import SupportsLoRA
from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
from .utils import is_pp_missing_parameter
_KEYS_TO_MODIFY_MAPPING = {
"llm.lm_head": "lm_head",
......@@ -337,7 +337,7 @@ def input_mapper_for_minicpmv(ctx: InputContext, data: object):
return MultiModalInputs(batch_data)
class MiniCPMVBaseModel(nn.Module, SupportsMultiModal):
class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
"""
The abstract class of MiniCPMV can only be inherited, but cannot be
instantiated.
......@@ -374,6 +374,9 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal):
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
self.make_empty_intermediate_tensors = (
self.llm.make_empty_intermediate_tensors)
def get_embedding(
self,
input_ids: torch.Tensor,
......@@ -498,9 +501,12 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal):
intermediate_tensors: Optional[IntermediateTensors] = None,
**kwargs: Any,
) -> torch.Tensor:
image_inputs = self._parse_and_validate_inputs(input_ids, **kwargs)
if intermediate_tensors is not None:
vlm_embeddings = None
else:
image_inputs = self._parse_and_validate_inputs(input_ids, **kwargs)
vlm_embeddings, _ = self.get_embedding(input_ids, image_inputs)
vlm_embeddings, _ = self.get_embedding(input_ids, image_inputs)
output = self.llm(
input_ids=None,
......@@ -557,6 +563,9 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal):
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
if is_pp_missing_parameter(
name.replace(weight_name, param_name), self):
continue
param = params_dict[name.replace(weight_name, param_name)]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
......@@ -564,6 +573,8 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal):
else:
use_default_weight_loading = True
if use_default_weight_loading:
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment