Unverified commit 0f6d7a9a, authored by Murali Andoorveedu, committed by GitHub
Browse files

[Models] Add remaining model PP support (#7168)


Signed-off-by: Muralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>
Signed-off-by: Murali Andoorveedu <muralidhar.andoorveedu@centml.ai>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent 303d4479
...@@ -22,7 +22,7 @@ ...@@ -22,7 +22,7 @@
# limitations under the License. # limitations under the License.
"""Inference-only MiniCPM model compatible with HuggingFace weights.""" """Inference-only MiniCPM model compatible with HuggingFace weights."""
import math import math
from typing import Any, Dict, Iterable, List, Optional, Tuple from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
import torch import torch
from torch import nn from torch import nn
...@@ -30,7 +30,7 @@ from transformers import PretrainedConfig ...@@ -30,7 +30,7 @@ from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import (get_tensor_model_parallel_rank, from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce) tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
...@@ -41,8 +41,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, ...@@ -41,8 +41,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
ReplicatedLinear, ReplicatedLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization import QuantizationConfig
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
...@@ -52,7 +51,9 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata ...@@ -52,7 +51,9 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA from .interfaces import SupportsLoRA, SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
class MiniCPMMoE(nn.Module): class MiniCPMMoE(nn.Module):
...@@ -264,7 +265,7 @@ class MiniCPMDecoderLayer(nn.Module): ...@@ -264,7 +265,7 @@ class MiniCPMDecoderLayer(nn.Module):
def __init__( def __init__(
self, self,
config, config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
) -> None: ) -> None:
...@@ -346,10 +347,11 @@ class MiniCPMModel(nn.Module): ...@@ -346,10 +347,11 @@ class MiniCPMModel(nn.Module):
def __init__( def __init__(
self, self,
config, config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None, lora_config: Optional[LoRAConfig] = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -365,15 +367,24 @@ class MiniCPMModel(nn.Module): ...@@ -365,15 +367,24 @@ class MiniCPMModel(nn.Module):
config.hidden_size, config.hidden_size,
org_num_embeddings=config.vocab_size, org_num_embeddings=config.vocab_size,
) )
self._init_layers() self._init_layers(prefix, config, cache_config, quant_config)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], self.config.hidden_size))
def _init_layers(self): def _init_layers(
self.layers = nn.ModuleList([ self,
MiniCPMDecoderLayer(self.config, self.cache_config, prefix: str,
self.quant_config) config: PretrainedConfig,
for _ in range(self.config.num_hidden_layers) cache_config: Optional[CacheConfig],
]) quant_config: Optional[QuantizationConfig],
):
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: MiniCPMDecoderLayer(config, cache_config,
quant_config),
prefix=f"{prefix}.layers")
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
embedding = self.embed_tokens(input_ids) embedding = self.embed_tokens(input_ids)
...@@ -387,27 +398,36 @@ class MiniCPMModel(nn.Module): ...@@ -387,27 +398,36 @@ class MiniCPMModel(nn.Module):
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
if inputs_embeds is not None: if inputs_embeds is not None:
hidden_states = inputs_embeds hidden_states = inputs_embeds
else: else:
hidden_states = self.get_input_embeddings(input_ids) hidden_states = self.get_input_embeddings(input_ids)
residual = None residual = None
else:
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for i in range(len(self.layers)): for i in range(self.start_layer, self.end_layer):
layer = self.layers[i] layer = self.layers[i]
hidden_states, residual = layer( hidden_states, residual = layer(
positions, positions,
hidden_states, hidden_states,
kv_caches[i], kv_caches[i - self.start_layer],
attn_metadata, attn_metadata,
residual, residual,
) )
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
hidden_states = self.norm(hidden_states) hidden_states = self.norm(hidden_states)
return hidden_states return hidden_states
class MiniCPMForCausalLM(nn.Module, SupportsLoRA): class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
packed_modules_mapping = { packed_modules_mapping = {
"qkv_proj": [ "qkv_proj": [
"q_proj", "q_proj",
...@@ -470,6 +490,8 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA): ...@@ -470,6 +490,8 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA):
self.logits_processor = LogitsProcessor(unpadded_vocab_size, self.logits_processor = LogitsProcessor(unpadded_vocab_size,
config.vocab_size) config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def _init_model(self): def _init_model(self):
self.model = MiniCPMModel(config=self.config, self.model = MiniCPMModel(config=self.config,
...@@ -484,7 +506,7 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA): ...@@ -484,7 +506,7 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA):
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors) attn_metadata, intermediate_tensors)
return hidden_states return hidden_states
...@@ -548,6 +570,8 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA): ...@@ -548,6 +570,8 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA):
# Skip loading extra bias for GPTQ models. # Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
continue continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name] param = params_dict[name]
weight_loader = param.weight_loader weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id) weight_loader(param, loaded_weight, shard_id)
...@@ -557,6 +581,8 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA): ...@@ -557,6 +581,8 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA):
if weight_name not in name: if weight_name not in name:
continue continue
name = name.replace(weight_name, param_name) name = name.replace(weight_name, param_name)
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name] param = params_dict[name]
weight_loader = param.weight_loader weight_loader = param.weight_loader
weight_loader(param, weight_loader(param,
...@@ -568,6 +594,8 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA): ...@@ -568,6 +594,8 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA):
# Skip loading extra bias for GPTQ models. # Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
continue continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name] param = params_dict[name]
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
......
...@@ -26,6 +26,7 @@ from typing import Any, Dict, Optional ...@@ -26,6 +26,7 @@ from typing import Any, Dict, Optional
import torch import torch
from torch import nn from torch import nn
from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig from vllm.config import CacheConfig
...@@ -34,19 +35,20 @@ from vllm.model_executor.layers.layernorm import RMSNorm ...@@ -34,19 +35,20 @@ from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
ReplicatedLinear, ReplicatedLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization import QuantizationConfig
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.models.minicpm import (MiniCPMDecoderLayer, from vllm.model_executor.models.minicpm import (MiniCPMDecoderLayer,
MiniCPMForCausalLM, MiniCPMForCausalLM,
MiniCPMModel) MiniCPMModel)
from .utils import make_layers
class MiniCPM3Attention(nn.Module): class MiniCPM3Attention(nn.Module):
def __init__( def __init__(
self, self,
config, config: PretrainedConfig,
hidden_size: int, hidden_size: int,
num_heads: int, num_heads: int,
qk_nope_head_dim: int, qk_nope_head_dim: int,
...@@ -199,12 +201,18 @@ class MiniCPM3DecoderLayer(MiniCPMDecoderLayer): ...@@ -199,12 +201,18 @@ class MiniCPM3DecoderLayer(MiniCPMDecoderLayer):
class MiniCPM3Model(MiniCPMModel): class MiniCPM3Model(MiniCPMModel):
def _init_layers(self): def _init_layers(
self.layers = nn.ModuleList([ self,
MiniCPM3DecoderLayer(self.config, self.cache_config, prefix: str,
self.quant_config) config: PretrainedConfig,
for _ in range(self.config.num_hidden_layers) cache_config: Optional[CacheConfig],
]) quant_config: Optional[QuantizationConfig],
):
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: MiniCPM3DecoderLayer(config, cache_config,
quant_config),
prefix=f"{prefix}.layers")
class MiniCPM3ForCausalLM(MiniCPMForCausalLM): class MiniCPM3ForCausalLM(MiniCPMForCausalLM):
......
...@@ -45,7 +45,6 @@ from vllm.model_executor.layers.sampler import Sampler, SamplerOutput ...@@ -45,7 +45,6 @@ from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.utils import set_default_torch_dtype from vllm.model_executor.model_loader.utils import set_default_torch_dtype
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import SupportsMultiModal
from vllm.model_executor.models.llama import LlamaModel from vllm.model_executor.models.llama import LlamaModel
from vllm.model_executor.models.minicpm import MiniCPMModel from vllm.model_executor.models.minicpm import MiniCPMModel
from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.models.module_mapping import MultiModelKeys
...@@ -59,7 +58,8 @@ from vllm.multimodal.utils import cached_get_tokenizer ...@@ -59,7 +58,8 @@ from vllm.multimodal.utils import cached_get_tokenizer
from vllm.sequence import IntermediateTensors, SequenceData from vllm.sequence import IntermediateTensors, SequenceData
from .idefics2_vision_model import Idefics2VisionTransformer from .idefics2_vision_model import Idefics2VisionTransformer
from .interfaces import SupportsLoRA from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
from .utils import is_pp_missing_parameter
_KEYS_TO_MODIFY_MAPPING = { _KEYS_TO_MODIFY_MAPPING = {
"llm.lm_head": "lm_head", "llm.lm_head": "lm_head",
...@@ -337,7 +337,7 @@ def input_mapper_for_minicpmv(ctx: InputContext, data: object): ...@@ -337,7 +337,7 @@ def input_mapper_for_minicpmv(ctx: InputContext, data: object):
return MultiModalInputs(batch_data) return MultiModalInputs(batch_data)
class MiniCPMVBaseModel(nn.Module, SupportsMultiModal): class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
""" """
The abstract class of MiniCPMV can only be inherited, but cannot be The abstract class of MiniCPMV can only be inherited, but cannot be
instantiated. instantiated.
...@@ -374,6 +374,9 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal): ...@@ -374,6 +374,9 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal):
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
self.make_empty_intermediate_tensors = (
self.llm.make_empty_intermediate_tensors)
def get_embedding( def get_embedding(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
...@@ -498,6 +501,9 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal): ...@@ -498,6 +501,9 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal):
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
**kwargs: Any, **kwargs: Any,
) -> torch.Tensor: ) -> torch.Tensor:
if intermediate_tensors is not None:
vlm_embeddings = None
else:
image_inputs = self._parse_and_validate_inputs(input_ids, **kwargs) image_inputs = self._parse_and_validate_inputs(input_ids, **kwargs)
vlm_embeddings, _ = self.get_embedding(input_ids, image_inputs) vlm_embeddings, _ = self.get_embedding(input_ids, image_inputs)
...@@ -557,6 +563,9 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal): ...@@ -557,6 +563,9 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal):
for param_name, weight_name, shard_id in stacked_params_mapping: for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name: if weight_name not in name:
continue continue
if is_pp_missing_parameter(
name.replace(weight_name, param_name), self):
continue
param = params_dict[name.replace(weight_name, param_name)] param = params_dict[name.replace(weight_name, param_name)]
weight_loader = param.weight_loader weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id) weight_loader(param, loaded_weight, shard_id)
...@@ -564,6 +573,8 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal): ...@@ -564,6 +573,8 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal):
else: else:
use_default_weight_loading = True use_default_weight_loading = True
if use_default_weight_loading: if use_default_weight_loading:
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name] param = params_dict[name]
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
......
...@@ -21,7 +21,7 @@ ...@@ -21,7 +21,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only Mixtral model.""" """Inference-only Mixtral model."""
from typing import Iterable, List, Optional, Tuple from typing import Iterable, List, Optional, Tuple, Union
import torch import torch
from torch import nn from torch import nn
...@@ -36,8 +36,7 @@ from vllm.model_executor.layers.linear import (QKVParallelLinear, ...@@ -36,8 +36,7 @@ from vllm.model_executor.layers.linear import (QKVParallelLinear,
ReplicatedLinear, ReplicatedLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization import QuantizationConfig
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
...@@ -47,8 +46,9 @@ from vllm.model_executor.model_loader.weight_utils import ( ...@@ -47,8 +46,9 @@ from vllm.model_executor.model_loader.weight_utils import (
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA from .interfaces import SupportsLoRA, SupportsPP
from .utils import is_pp_missing_parameter, make_layers from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
class MixtralMoE(nn.Module): class MixtralMoE(nn.Module):
...@@ -276,6 +276,9 @@ class MixtralModel(nn.Module): ...@@ -276,6 +276,9 @@ class MixtralModel(nn.Module):
prefix=f"{prefix}.layers") prefix=f"{prefix}.layers")
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
def forward( def forward(
self, self,
...@@ -284,7 +287,7 @@ class MixtralModel(nn.Module): ...@@ -284,7 +287,7 @@ class MixtralModel(nn.Module):
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors], intermediate_tensors: Optional[IntermediateTensors],
) -> torch.Tensor: ) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank: if get_pp_group().is_first_rank:
hidden_states = self.embed_tokens(input_ids) hidden_states = self.embed_tokens(input_ids)
residual = None residual = None
...@@ -306,7 +309,7 @@ class MixtralModel(nn.Module): ...@@ -306,7 +309,7 @@ class MixtralModel(nn.Module):
return hidden_states return hidden_states
class MixtralForCausalLM(nn.Module, SupportsLoRA): class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
fall_back_to_pt_during_load = False fall_back_to_pt_during_load = False
packed_modules_mapping = { packed_modules_mapping = {
...@@ -365,6 +368,8 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA): ...@@ -365,6 +368,8 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size) config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def forward( def forward(
self, self,
...@@ -373,7 +378,7 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA): ...@@ -373,7 +378,7 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors) attn_metadata, intermediate_tensors)
return hidden_states return hidden_states
...@@ -387,20 +392,6 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA): ...@@ -387,20 +392,6 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
sampling_metadata) sampling_metadata)
return logits return logits
def make_empty_intermediate_tensors(
self, batch_size: int, dtype: torch.dtype,
device: torch.device) -> IntermediateTensors:
return IntermediateTensors({
"hidden_states":
torch.zeros((batch_size, self.config.hidden_size),
dtype=dtype,
device=device),
"residual":
torch.zeros((batch_size, self.config.hidden_size),
dtype=dtype,
device=device),
})
def sample( def sample(
self, self,
logits: Optional[torch.Tensor], logits: Optional[torch.Tensor],
......
...@@ -21,7 +21,7 @@ ...@@ -21,7 +21,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only Mixtral model.""" """Inference-only Mixtral model."""
from typing import Iterable, List, Optional, Tuple from typing import Iterable, List, Optional, Tuple, Union
import numpy as np import numpy as np
import torch import torch
...@@ -31,7 +31,7 @@ from transformers import MixtralConfig ...@@ -31,7 +31,7 @@ from transformers import MixtralConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig from vllm.config import CacheConfig
from vllm.distributed import (get_tensor_model_parallel_rank, from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce) tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
...@@ -39,8 +39,7 @@ from vllm.model_executor.layers.linear import (QKVParallelLinear, ...@@ -39,8 +39,7 @@ from vllm.model_executor.layers.linear import (QKVParallelLinear,
ReplicatedLinear, ReplicatedLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization import QuantizationConfig
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
...@@ -49,6 +48,10 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader ...@@ -49,6 +48,10 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
class MixtralMLP(nn.Module): class MixtralMLP(nn.Module):
...@@ -296,6 +299,7 @@ class MixtralModel(nn.Module): ...@@ -296,6 +299,7 @@ class MixtralModel(nn.Module):
config: MixtralConfig, config: MixtralConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.padding_idx = config.pad_token_id self.padding_idx = config.pad_token_id
...@@ -305,13 +309,15 @@ class MixtralModel(nn.Module): ...@@ -305,13 +309,15 @@ class MixtralModel(nn.Module):
config.vocab_size, config.vocab_size,
config.hidden_size, config.hidden_size,
) )
self.layers = nn.ModuleList([ self.start_layer, self.end_layer, self.layers = make_layers(
MixtralDecoderLayer(config, config.num_hidden_layers,
cache_config, lambda prefix: MixtralDecoderLayer(
quant_config=quant_config) config, cache_config, quant_config=quant_config),
for _ in range(config.num_hidden_layers) prefix=f"{prefix}.layers")
])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
def forward( def forward(
self, self,
...@@ -319,19 +325,30 @@ class MixtralModel(nn.Module): ...@@ -319,19 +325,30 @@ class MixtralModel(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: intermediate_tensors: Optional[IntermediateTensors],
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
hidden_states = self.embed_tokens(input_ids) hidden_states = self.embed_tokens(input_ids)
residual = None residual = None
for i in range(len(self.layers)): else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i] layer = self.layers[i]
hidden_states, residual = layer(positions, hidden_states, hidden_states, residual = layer(positions, hidden_states,
kv_caches[i], attn_metadata, kv_caches[i - self.start_layer],
residual) attn_metadata, residual)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
hidden_states, _ = self.norm(hidden_states, residual) hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states return hidden_states
class MixtralForCausalLM(nn.Module): class MixtralForCausalLM(nn.Module, SupportsPP):
fall_back_to_pt_during_load = False fall_back_to_pt_during_load = False
def __init__( def __init__(
...@@ -351,6 +368,8 @@ class MixtralForCausalLM(nn.Module): ...@@ -351,6 +368,8 @@ class MixtralForCausalLM(nn.Module):
self.lm_head.weight = self.model.embed_tokens.weight self.lm_head.weight = self.model.embed_tokens.weight
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def forward( def forward(
self, self,
...@@ -359,9 +378,9 @@ class MixtralForCausalLM(nn.Module): ...@@ -359,9 +378,9 @@ class MixtralForCausalLM(nn.Module):
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata) attn_metadata, intermediate_tensors)
return hidden_states return hidden_states
def compute_logits( def compute_logits(
...@@ -400,6 +419,8 @@ class MixtralForCausalLM(nn.Module): ...@@ -400,6 +419,8 @@ class MixtralForCausalLM(nn.Module):
# Skip loading extra bias for GPTQ models. # Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
continue continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name] param = params_dict[name]
weight_loader = param.weight_loader weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id) weight_loader(param, loaded_weight, shard_id)
...@@ -412,6 +433,8 @@ class MixtralForCausalLM(nn.Module): ...@@ -412,6 +433,8 @@ class MixtralForCausalLM(nn.Module):
if ("block_sparse_moe.experts." in name if ("block_sparse_moe.experts." in name
and name not in params_dict): and name not in params_dict):
continue continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name] param = params_dict[name]
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment