Unverified Commit 0f6d7a9a authored by Murali Andoorveedu's avatar Murali Andoorveedu Committed by GitHub
Browse files

[Models] Add remaining model PP support (#7168)


Signed-off-by: default avatarMuralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>
Signed-off-by: default avatarMurali Andoorveedu <muralidhar.andoorveedu@centml.ai>
Co-authored-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent 303d4479
...@@ -22,7 +22,7 @@ ...@@ -22,7 +22,7 @@
# limitations under the License. # limitations under the License.
"""Inference-only MiniCPM model compatible with HuggingFace weights.""" """Inference-only MiniCPM model compatible with HuggingFace weights."""
import math import math
from typing import Any, Dict, Iterable, List, Optional, Tuple from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
import torch import torch
from torch import nn from torch import nn
...@@ -30,7 +30,7 @@ from transformers import PretrainedConfig ...@@ -30,7 +30,7 @@ from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import (get_tensor_model_parallel_rank, from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce) tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
...@@ -41,8 +41,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, ...@@ -41,8 +41,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
ReplicatedLinear, ReplicatedLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization import QuantizationConfig
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
...@@ -52,7 +51,9 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata ...@@ -52,7 +51,9 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA from .interfaces import SupportsLoRA, SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
class MiniCPMMoE(nn.Module): class MiniCPMMoE(nn.Module):
...@@ -264,7 +265,7 @@ class MiniCPMDecoderLayer(nn.Module): ...@@ -264,7 +265,7 @@ class MiniCPMDecoderLayer(nn.Module):
def __init__( def __init__(
self, self,
config, config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
) -> None: ) -> None:
...@@ -346,10 +347,11 @@ class MiniCPMModel(nn.Module): ...@@ -346,10 +347,11 @@ class MiniCPMModel(nn.Module):
def __init__( def __init__(
self, self,
config, config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None, lora_config: Optional[LoRAConfig] = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -365,15 +367,24 @@ class MiniCPMModel(nn.Module): ...@@ -365,15 +367,24 @@ class MiniCPMModel(nn.Module):
config.hidden_size, config.hidden_size,
org_num_embeddings=config.vocab_size, org_num_embeddings=config.vocab_size,
) )
self._init_layers() self._init_layers(prefix, config, cache_config, quant_config)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], self.config.hidden_size))
def _init_layers(self): def _init_layers(
self.layers = nn.ModuleList([ self,
MiniCPMDecoderLayer(self.config, self.cache_config, prefix: str,
self.quant_config) config: PretrainedConfig,
for _ in range(self.config.num_hidden_layers) cache_config: Optional[CacheConfig],
]) quant_config: Optional[QuantizationConfig],
):
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: MiniCPMDecoderLayer(config, cache_config,
quant_config),
prefix=f"{prefix}.layers")
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
embedding = self.embed_tokens(input_ids) embedding = self.embed_tokens(input_ids)
...@@ -387,27 +398,36 @@ class MiniCPMModel(nn.Module): ...@@ -387,27 +398,36 @@ class MiniCPMModel(nn.Module):
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> Union[torch.Tensor, IntermediateTensors]:
if inputs_embeds is not None: if get_pp_group().is_first_rank:
hidden_states = inputs_embeds if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
residual = None
else: else:
hidden_states = self.get_input_embeddings(input_ids) hidden_states = intermediate_tensors["hidden_states"]
residual = None residual = intermediate_tensors["residual"]
for i in range(len(self.layers)): for i in range(self.start_layer, self.end_layer):
layer = self.layers[i] layer = self.layers[i]
hidden_states, residual = layer( hidden_states, residual = layer(
positions, positions,
hidden_states, hidden_states,
kv_caches[i], kv_caches[i - self.start_layer],
attn_metadata, attn_metadata,
residual, residual,
) )
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
hidden_states = self.norm(hidden_states) hidden_states = self.norm(hidden_states)
return hidden_states return hidden_states
class MiniCPMForCausalLM(nn.Module, SupportsLoRA): class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
packed_modules_mapping = { packed_modules_mapping = {
"qkv_proj": [ "qkv_proj": [
"q_proj", "q_proj",
...@@ -470,6 +490,8 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA): ...@@ -470,6 +490,8 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA):
self.logits_processor = LogitsProcessor(unpadded_vocab_size, self.logits_processor = LogitsProcessor(unpadded_vocab_size,
config.vocab_size) config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def _init_model(self): def _init_model(self):
self.model = MiniCPMModel(config=self.config, self.model = MiniCPMModel(config=self.config,
...@@ -484,7 +506,7 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA): ...@@ -484,7 +506,7 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA):
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors) attn_metadata, intermediate_tensors)
return hidden_states return hidden_states
...@@ -548,6 +570,8 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA): ...@@ -548,6 +570,8 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA):
# Skip loading extra bias for GPTQ models. # Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
continue continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name] param = params_dict[name]
weight_loader = param.weight_loader weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id) weight_loader(param, loaded_weight, shard_id)
...@@ -557,6 +581,8 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA): ...@@ -557,6 +581,8 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA):
if weight_name not in name: if weight_name not in name:
continue continue
name = name.replace(weight_name, param_name) name = name.replace(weight_name, param_name)
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name] param = params_dict[name]
weight_loader = param.weight_loader weight_loader = param.weight_loader
weight_loader(param, weight_loader(param,
...@@ -568,6 +594,8 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA): ...@@ -568,6 +594,8 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA):
# Skip loading extra bias for GPTQ models. # Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
continue continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name] param = params_dict[name]
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
......
...@@ -26,6 +26,7 @@ from typing import Any, Dict, Optional ...@@ -26,6 +26,7 @@ from typing import Any, Dict, Optional
import torch import torch
from torch import nn from torch import nn
from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig from vllm.config import CacheConfig
...@@ -34,19 +35,20 @@ from vllm.model_executor.layers.layernorm import RMSNorm ...@@ -34,19 +35,20 @@ from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
ReplicatedLinear, ReplicatedLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization import QuantizationConfig
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.models.minicpm import (MiniCPMDecoderLayer, from vllm.model_executor.models.minicpm import (MiniCPMDecoderLayer,
MiniCPMForCausalLM, MiniCPMForCausalLM,
MiniCPMModel) MiniCPMModel)
from .utils import make_layers
class MiniCPM3Attention(nn.Module): class MiniCPM3Attention(nn.Module):
def __init__( def __init__(
self, self,
config, config: PretrainedConfig,
hidden_size: int, hidden_size: int,
num_heads: int, num_heads: int,
qk_nope_head_dim: int, qk_nope_head_dim: int,
...@@ -199,12 +201,18 @@ class MiniCPM3DecoderLayer(MiniCPMDecoderLayer): ...@@ -199,12 +201,18 @@ class MiniCPM3DecoderLayer(MiniCPMDecoderLayer):
class MiniCPM3Model(MiniCPMModel): class MiniCPM3Model(MiniCPMModel):
def _init_layers(self): def _init_layers(
self.layers = nn.ModuleList([ self,
MiniCPM3DecoderLayer(self.config, self.cache_config, prefix: str,
self.quant_config) config: PretrainedConfig,
for _ in range(self.config.num_hidden_layers) cache_config: Optional[CacheConfig],
]) quant_config: Optional[QuantizationConfig],
):
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: MiniCPM3DecoderLayer(config, cache_config,
quant_config),
prefix=f"{prefix}.layers")
class MiniCPM3ForCausalLM(MiniCPMForCausalLM): class MiniCPM3ForCausalLM(MiniCPMForCausalLM):
......
...@@ -45,7 +45,6 @@ from vllm.model_executor.layers.sampler import Sampler, SamplerOutput ...@@ -45,7 +45,6 @@ from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.utils import set_default_torch_dtype from vllm.model_executor.model_loader.utils import set_default_torch_dtype
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import SupportsMultiModal
from vllm.model_executor.models.llama import LlamaModel from vllm.model_executor.models.llama import LlamaModel
from vllm.model_executor.models.minicpm import MiniCPMModel from vllm.model_executor.models.minicpm import MiniCPMModel
from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.models.module_mapping import MultiModelKeys
...@@ -59,7 +58,8 @@ from vllm.multimodal.utils import cached_get_tokenizer ...@@ -59,7 +58,8 @@ from vllm.multimodal.utils import cached_get_tokenizer
from vllm.sequence import IntermediateTensors, SequenceData from vllm.sequence import IntermediateTensors, SequenceData
from .idefics2_vision_model import Idefics2VisionTransformer from .idefics2_vision_model import Idefics2VisionTransformer
from .interfaces import SupportsLoRA from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
from .utils import is_pp_missing_parameter
_KEYS_TO_MODIFY_MAPPING = { _KEYS_TO_MODIFY_MAPPING = {
"llm.lm_head": "lm_head", "llm.lm_head": "lm_head",
...@@ -337,7 +337,7 @@ def input_mapper_for_minicpmv(ctx: InputContext, data: object): ...@@ -337,7 +337,7 @@ def input_mapper_for_minicpmv(ctx: InputContext, data: object):
return MultiModalInputs(batch_data) return MultiModalInputs(batch_data)
class MiniCPMVBaseModel(nn.Module, SupportsMultiModal): class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP):
""" """
The abstract class of MiniCPMV can only be inherited, but cannot be The abstract class of MiniCPMV can only be inherited, but cannot be
instantiated. instantiated.
...@@ -374,6 +374,9 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal): ...@@ -374,6 +374,9 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal):
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
self.make_empty_intermediate_tensors = (
self.llm.make_empty_intermediate_tensors)
def get_embedding( def get_embedding(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
...@@ -498,9 +501,12 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal): ...@@ -498,9 +501,12 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal):
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
**kwargs: Any, **kwargs: Any,
) -> torch.Tensor: ) -> torch.Tensor:
image_inputs = self._parse_and_validate_inputs(input_ids, **kwargs) if intermediate_tensors is not None:
vlm_embeddings = None
else:
image_inputs = self._parse_and_validate_inputs(input_ids, **kwargs)
vlm_embeddings, _ = self.get_embedding(input_ids, image_inputs) vlm_embeddings, _ = self.get_embedding(input_ids, image_inputs)
output = self.llm( output = self.llm(
input_ids=None, input_ids=None,
...@@ -557,6 +563,9 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal): ...@@ -557,6 +563,9 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal):
for param_name, weight_name, shard_id in stacked_params_mapping: for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name: if weight_name not in name:
continue continue
if is_pp_missing_parameter(
name.replace(weight_name, param_name), self):
continue
param = params_dict[name.replace(weight_name, param_name)] param = params_dict[name.replace(weight_name, param_name)]
weight_loader = param.weight_loader weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id) weight_loader(param, loaded_weight, shard_id)
...@@ -564,6 +573,8 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal): ...@@ -564,6 +573,8 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal):
else: else:
use_default_weight_loading = True use_default_weight_loading = True
if use_default_weight_loading: if use_default_weight_loading:
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name] param = params_dict[name]
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
......
...@@ -21,7 +21,7 @@ ...@@ -21,7 +21,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only Mixtral model.""" """Inference-only Mixtral model."""
from typing import Iterable, List, Optional, Tuple from typing import Iterable, List, Optional, Tuple, Union
import torch import torch
from torch import nn from torch import nn
...@@ -36,8 +36,7 @@ from vllm.model_executor.layers.linear import (QKVParallelLinear, ...@@ -36,8 +36,7 @@ from vllm.model_executor.layers.linear import (QKVParallelLinear,
ReplicatedLinear, ReplicatedLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization import QuantizationConfig
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
...@@ -47,8 +46,9 @@ from vllm.model_executor.model_loader.weight_utils import ( ...@@ -47,8 +46,9 @@ from vllm.model_executor.model_loader.weight_utils import (
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA from .interfaces import SupportsLoRA, SupportsPP
from .utils import is_pp_missing_parameter, make_layers from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
class MixtralMoE(nn.Module): class MixtralMoE(nn.Module):
...@@ -276,6 +276,9 @@ class MixtralModel(nn.Module): ...@@ -276,6 +276,9 @@ class MixtralModel(nn.Module):
prefix=f"{prefix}.layers") prefix=f"{prefix}.layers")
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
def forward( def forward(
self, self,
...@@ -284,7 +287,7 @@ class MixtralModel(nn.Module): ...@@ -284,7 +287,7 @@ class MixtralModel(nn.Module):
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors], intermediate_tensors: Optional[IntermediateTensors],
) -> torch.Tensor: ) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank: if get_pp_group().is_first_rank:
hidden_states = self.embed_tokens(input_ids) hidden_states = self.embed_tokens(input_ids)
residual = None residual = None
...@@ -306,7 +309,7 @@ class MixtralModel(nn.Module): ...@@ -306,7 +309,7 @@ class MixtralModel(nn.Module):
return hidden_states return hidden_states
class MixtralForCausalLM(nn.Module, SupportsLoRA): class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
fall_back_to_pt_during_load = False fall_back_to_pt_during_load = False
packed_modules_mapping = { packed_modules_mapping = {
...@@ -365,6 +368,8 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA): ...@@ -365,6 +368,8 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size) config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def forward( def forward(
self, self,
...@@ -373,7 +378,7 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA): ...@@ -373,7 +378,7 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors) attn_metadata, intermediate_tensors)
return hidden_states return hidden_states
...@@ -387,20 +392,6 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA): ...@@ -387,20 +392,6 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
sampling_metadata) sampling_metadata)
return logits return logits
def make_empty_intermediate_tensors(
self, batch_size: int, dtype: torch.dtype,
device: torch.device) -> IntermediateTensors:
return IntermediateTensors({
"hidden_states":
torch.zeros((batch_size, self.config.hidden_size),
dtype=dtype,
device=device),
"residual":
torch.zeros((batch_size, self.config.hidden_size),
dtype=dtype,
device=device),
})
def sample( def sample(
self, self,
logits: Optional[torch.Tensor], logits: Optional[torch.Tensor],
......
...@@ -21,7 +21,7 @@ ...@@ -21,7 +21,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only Mixtral model.""" """Inference-only Mixtral model."""
from typing import Iterable, List, Optional, Tuple from typing import Iterable, List, Optional, Tuple, Union
import numpy as np import numpy as np
import torch import torch
...@@ -31,7 +31,7 @@ from transformers import MixtralConfig ...@@ -31,7 +31,7 @@ from transformers import MixtralConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig from vllm.config import CacheConfig
from vllm.distributed import (get_tensor_model_parallel_rank, from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce) tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
...@@ -39,8 +39,7 @@ from vllm.model_executor.layers.linear import (QKVParallelLinear, ...@@ -39,8 +39,7 @@ from vllm.model_executor.layers.linear import (QKVParallelLinear,
ReplicatedLinear, ReplicatedLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization import QuantizationConfig
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
...@@ -49,6 +48,10 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader ...@@ -49,6 +48,10 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
class MixtralMLP(nn.Module): class MixtralMLP(nn.Module):
...@@ -296,6 +299,7 @@ class MixtralModel(nn.Module): ...@@ -296,6 +299,7 @@ class MixtralModel(nn.Module):
config: MixtralConfig, config: MixtralConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.padding_idx = config.pad_token_id self.padding_idx = config.pad_token_id
...@@ -305,13 +309,15 @@ class MixtralModel(nn.Module): ...@@ -305,13 +309,15 @@ class MixtralModel(nn.Module):
config.vocab_size, config.vocab_size,
config.hidden_size, config.hidden_size,
) )
self.layers = nn.ModuleList([ self.start_layer, self.end_layer, self.layers = make_layers(
MixtralDecoderLayer(config, config.num_hidden_layers,
cache_config, lambda prefix: MixtralDecoderLayer(
quant_config=quant_config) config, cache_config, quant_config=quant_config),
for _ in range(config.num_hidden_layers) prefix=f"{prefix}.layers")
])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
def forward( def forward(
self, self,
...@@ -319,19 +325,30 @@ class MixtralModel(nn.Module): ...@@ -319,19 +325,30 @@ class MixtralModel(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: intermediate_tensors: Optional[IntermediateTensors],
hidden_states = self.embed_tokens(input_ids) ) -> Union[torch.Tensor, IntermediateTensors]:
residual = None if get_pp_group().is_first_rank:
for i in range(len(self.layers)): hidden_states = self.embed_tokens(input_ids)
residual = None
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i] layer = self.layers[i]
hidden_states, residual = layer(positions, hidden_states, hidden_states, residual = layer(positions, hidden_states,
kv_caches[i], attn_metadata, kv_caches[i - self.start_layer],
residual) attn_metadata, residual)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
hidden_states, _ = self.norm(hidden_states, residual) hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states return hidden_states
class MixtralForCausalLM(nn.Module): class MixtralForCausalLM(nn.Module, SupportsPP):
fall_back_to_pt_during_load = False fall_back_to_pt_during_load = False
def __init__( def __init__(
...@@ -351,6 +368,8 @@ class MixtralForCausalLM(nn.Module): ...@@ -351,6 +368,8 @@ class MixtralForCausalLM(nn.Module):
self.lm_head.weight = self.model.embed_tokens.weight self.lm_head.weight = self.model.embed_tokens.weight
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def forward( def forward(
self, self,
...@@ -359,9 +378,9 @@ class MixtralForCausalLM(nn.Module): ...@@ -359,9 +378,9 @@ class MixtralForCausalLM(nn.Module):
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata) attn_metadata, intermediate_tensors)
return hidden_states return hidden_states
def compute_logits( def compute_logits(
...@@ -400,6 +419,8 @@ class MixtralForCausalLM(nn.Module): ...@@ -400,6 +419,8 @@ class MixtralForCausalLM(nn.Module):
# Skip loading extra bias for GPTQ models. # Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
continue continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name] param = params_dict[name]
weight_loader = param.weight_loader weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id) weight_loader(param, loaded_weight, shard_id)
...@@ -412,6 +433,8 @@ class MixtralForCausalLM(nn.Module): ...@@ -412,6 +433,8 @@ class MixtralForCausalLM(nn.Module):
if ("block_sparse_moe.experts." in name if ("block_sparse_moe.experts." in name
and name not in params_dict): and name not in params_dict):
continue continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name] param = params_dict[name]
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
......
# coding=utf-8 # coding=utf-8
# Adapted from https://huggingface.co/mosaicml/mpt-7b/tree/main # Adapted from https://huggingface.co/mosaicml/mpt-7b/tree/main
import math import math
from typing import Iterable, List, Optional, Tuple from typing import Iterable, List, Optional, Tuple, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig from vllm.config import CacheConfig
from vllm.distributed import (get_tensor_model_parallel_rank, from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size) get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization import QuantizationConfig
QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding) VocabParallelEmbedding)
...@@ -25,6 +24,10 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata ...@@ -25,6 +24,10 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.mpt import MPTConfig from vllm.transformers_utils.configs.mpt import MPTConfig
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
def _get_alibi_slopes( def _get_alibi_slopes(
total_num_heads: int, total_num_heads: int,
...@@ -208,6 +211,7 @@ class MPTModel(nn.Module): ...@@ -208,6 +211,7 @@ class MPTModel(nn.Module):
config: MPTConfig, config: MPTConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
assert config.embedding_fraction == 1.0 assert config.embedding_fraction == 1.0
...@@ -217,10 +221,10 @@ class MPTModel(nn.Module): ...@@ -217,10 +221,10 @@ class MPTModel(nn.Module):
config.vocab_size, config.vocab_size,
config.d_model, config.d_model,
) )
self.blocks = nn.ModuleList([ self.start_layer, self.end_layer, self.blocks = make_layers(
MPTBlock(config, cache_config, quant_config) config.n_layers,
for _ in range(config.n_layers) lambda prefix: MPTBlock(config, cache_config, quant_config),
]) prefix=f"{prefix}.blocks")
self.norm_f = nn.LayerNorm(config.d_model) self.norm_f = nn.LayerNorm(config.d_model)
if config.no_bias: if config.no_bias:
for module in self.modules(): for module in self.modules():
...@@ -228,6 +232,9 @@ class MPTModel(nn.Module): ...@@ -228,6 +232,9 @@ class MPTModel(nn.Module):
module.bias, nn.Parameter): module.bias, nn.Parameter):
# Remove the bias term in Linear and LayerNorm. # Remove the bias term in Linear and LayerNorm.
module.register_parameter("bias", None) module.register_parameter("bias", None)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(["hidden_states"],
config.d_model))
def forward( def forward(
self, self,
...@@ -235,21 +242,29 @@ class MPTModel(nn.Module): ...@@ -235,21 +242,29 @@ class MPTModel(nn.Module):
position_ids: torch.Tensor, position_ids: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: intermediate_tensors: Optional[IntermediateTensors],
hidden_states = self.wte(input_ids) ) -> Union[torch.Tensor, IntermediateTensors]:
for i in range(len(self.blocks)): if get_pp_group().is_first_rank:
hidden_states = self.wte(input_ids)
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
for i in range(self.start_layer, self.end_layer):
block = self.blocks[i] block = self.blocks[i]
hidden_states = block( hidden_states = block(
position_ids, position_ids,
hidden_states, hidden_states,
kv_caches[i], kv_caches[i - self.start_layer],
attn_metadata, attn_metadata,
) )
if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states})
hidden_states = self.norm_f(hidden_states) hidden_states = self.norm_f(hidden_states)
return hidden_states return hidden_states
class MPTForCausalLM(nn.Module): class MPTForCausalLM(nn.Module, SupportsPP):
def __init__( def __init__(
self, self,
...@@ -266,6 +281,8 @@ class MPTForCausalLM(nn.Module): ...@@ -266,6 +281,8 @@ class MPTForCausalLM(nn.Module):
self.lm_head = self.transformer.wte self.lm_head = self.transformer.wte
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
self.make_empty_intermediate_tensors = (
self.transformer.make_empty_intermediate_tensors)
def forward( def forward(
self, self,
...@@ -274,9 +291,9 @@ class MPTForCausalLM(nn.Module): ...@@ -274,9 +291,9 @@ class MPTForCausalLM(nn.Module):
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.transformer(input_ids, positions, kv_caches, hidden_states = self.transformer(input_ids, positions, kv_caches,
attn_metadata) attn_metadata, intermediate_tensors)
return hidden_states return hidden_states
def compute_logits( def compute_logits(
...@@ -302,6 +319,8 @@ class MPTForCausalLM(nn.Module): ...@@ -302,6 +319,8 @@ class MPTForCausalLM(nn.Module):
# Skip loading extra bias for GPTQ models. # Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
continue continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name] param = params_dict[name]
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
......
...@@ -34,8 +34,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, ...@@ -34,8 +34,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization import QuantizationConfig
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
...@@ -46,8 +45,9 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata ...@@ -46,8 +45,9 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs import NemotronConfig from vllm.transformers_utils.configs import NemotronConfig
from .interfaces import SupportsLoRA from .interfaces import SupportsLoRA, SupportsPP
from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers from .utils import (PPMissingLayer, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
# The architecture is pretty similar to Llama, with these changes: # The architecture is pretty similar to Llama, with these changes:
# - There is no gate_proj, just up_proj # - There is no gate_proj, just up_proj
...@@ -328,6 +328,9 @@ class NemotronModel(nn.Module): ...@@ -328,6 +328,9 @@ class NemotronModel(nn.Module):
eps=config.norm_eps) eps=config.norm_eps)
else: else:
self.norm = PPMissingLayer() self.norm = PPMissingLayer()
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids) return self.embed_tokens(input_ids)
...@@ -372,7 +375,7 @@ class NemotronModel(nn.Module): ...@@ -372,7 +375,7 @@ class NemotronModel(nn.Module):
return hidden_states return hidden_states
class NemotronForCausalLM(nn.Module, SupportsLoRA): class NemotronForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
packed_modules_mapping = { packed_modules_mapping = {
"qkv_proj": [ "qkv_proj": [
"q_proj", "q_proj",
...@@ -440,6 +443,8 @@ class NemotronForCausalLM(nn.Module, SupportsLoRA): ...@@ -440,6 +443,8 @@ class NemotronForCausalLM(nn.Module, SupportsLoRA):
self.sampler = Sampler() self.sampler = Sampler()
else: else:
self.lm_head = PPMissingLayer() self.lm_head = PPMissingLayer()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def forward( def forward(
self, self,
...@@ -470,20 +475,6 @@ class NemotronForCausalLM(nn.Module, SupportsLoRA): ...@@ -470,20 +475,6 @@ class NemotronForCausalLM(nn.Module, SupportsLoRA):
next_tokens = self.sampler(logits, sampling_metadata) next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens return next_tokens
def make_empty_intermediate_tensors(
self, batch_size: int, dtype: torch.dtype,
device: torch.device) -> IntermediateTensors:
return IntermediateTensors({
"hidden_states":
torch.zeros((batch_size, self.config.hidden_size),
dtype=dtype,
device=device),
"residual":
torch.zeros((batch_size, self.config.hidden_size),
dtype=dtype,
device=device),
})
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
......
...@@ -21,7 +21,7 @@ ...@@ -21,7 +21,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only OLMo model compatible with HuggingFace weights.""" """Inference-only OLMo model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Tuple from typing import Iterable, List, Optional, Tuple, Union
import torch import torch
from torch import nn from torch import nn
...@@ -29,14 +29,13 @@ from transformers import OlmoConfig ...@@ -29,14 +29,13 @@ from transformers import OlmoConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig from vllm.config import CacheConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization import QuantizationConfig
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
...@@ -45,6 +44,10 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader ...@@ -45,6 +44,10 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
class OlmoAttention(nn.Module): class OlmoAttention(nn.Module):
""" """
...@@ -223,19 +226,24 @@ class OlmoModel(nn.Module): ...@@ -223,19 +226,24 @@ class OlmoModel(nn.Module):
def __init__(self, def __init__(self,
config: OlmoConfig, config: OlmoConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None): quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__() super().__init__()
self.config = config self.config = config
self.embed_tokens = VocabParallelEmbedding(config.vocab_size, self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
config.hidden_size) config.hidden_size)
self.layers = nn.ModuleList([ self.start_layer, self.end_layer, self.layers = make_layers(
OlmoDecoderLayer(config, cache_config, quant_config) config.num_hidden_layers,
for layer_idx in range(config.num_hidden_layers) lambda prefix: OlmoDecoderLayer(config, cache_config, quant_config
]) ),
prefix=f"{prefix}.layers")
self.norm = nn.LayerNorm(config.hidden_size, self.norm = nn.LayerNorm(config.hidden_size,
elementwise_affine=False, elementwise_affine=False,
bias=False) bias=False)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(["hidden_states"],
config.hidden_size))
def forward( def forward(
self, self,
...@@ -243,34 +251,41 @@ class OlmoModel(nn.Module): ...@@ -243,34 +251,41 @@ class OlmoModel(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: intermediate_tensors: Optional[IntermediateTensors],
) -> Union[torch.Tensor, IntermediateTensors]:
""" """
:param input_ids: A tensor of shape `(batch_size, seq_len)`. :param input_ids: A tensor of shape `(batch_size, seq_len)`.
""" """
# Get embeddings of input. if get_pp_group().is_first_rank:
# shape: (batch_size, seq_len, d_model) # Get embeddings of input.
inputs_embeds = self.embed_tokens(input_ids) # shape: (batch_size, seq_len, d_model)
inputs_embeds = self.embed_tokens(input_ids)
# embed positions # embed positions
hidden_states = inputs_embeds hidden_states = inputs_embeds
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
# Apply blocks one-by-one. # Apply blocks one-by-one.
for layer_idx, decoder_layer in enumerate(self.layers): for i in range(self.start_layer, self.end_layer):
# shape: (batch_size, seq_len, d_model) # shape: (batch_size, seq_len, d_model)
hidden_states = decoder_layer( hidden_states = self.layers[i](
positions, positions,
hidden_states, hidden_states,
kv_caches[layer_idx], kv_caches[i - self.start_layer],
attn_metadata, attn_metadata,
) )
if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states})
# Apply final layer norm. # Apply final layer norm.
# shape: (batch_size, seq_len or 1, d_model) # shape: (batch_size, seq_len or 1, d_model)
hidden_states = self.norm(hidden_states) hidden_states = self.norm(hidden_states)
return hidden_states return hidden_states
class OlmoForCausalLM(nn.Module): class OlmoForCausalLM(nn.Module, SupportsPP):
""" """
Extremely barebones HF model wrapper. Extremely barebones HF model wrapper.
""" """
...@@ -294,6 +309,8 @@ class OlmoForCausalLM(nn.Module): ...@@ -294,6 +309,8 @@ class OlmoForCausalLM(nn.Module):
) )
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def forward( def forward(
self, self,
...@@ -302,12 +319,13 @@ class OlmoForCausalLM(nn.Module): ...@@ -302,12 +319,13 @@ class OlmoForCausalLM(nn.Module):
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model( hidden_states = self.model(
input_ids=input_ids, input_ids=input_ids,
positions=positions, positions=positions,
kv_caches=kv_caches, kv_caches=kv_caches,
attn_metadata=attn_metadata, attn_metadata=attn_metadata,
intermediate_tensors=intermediate_tensors,
) )
return hidden_states return hidden_states
...@@ -358,6 +376,8 @@ class OlmoForCausalLM(nn.Module): ...@@ -358,6 +376,8 @@ class OlmoForCausalLM(nn.Module):
# Skip loading extra bias for GPTQ models. # Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
continue continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name] param = params_dict[name]
weight_loader = param.weight_loader weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id) weight_loader(param, loaded_weight, shard_id)
...@@ -366,6 +386,8 @@ class OlmoForCausalLM(nn.Module): ...@@ -366,6 +386,8 @@ class OlmoForCausalLM(nn.Module):
# Skip loading extra bias for GPTQ models. # Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
continue continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name] param = params_dict[name]
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
......
...@@ -10,7 +10,7 @@ ...@@ -10,7 +10,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only OLMoE model compatible with HuggingFace weights.""" """Inference-only OLMoE model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, List, Optional, Tuple from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
import torch import torch
from torch import nn from torch import nn
...@@ -18,15 +18,14 @@ from transformers import PretrainedConfig ...@@ -18,15 +18,14 @@ from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig from vllm.config import CacheConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (QKVParallelLinear, from vllm.model_executor.layers.linear import (QKVParallelLinear,
ReplicatedLinear, ReplicatedLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization import QuantizationConfig
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
...@@ -36,6 +35,10 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata ...@@ -36,6 +35,10 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import print_warning_once from vllm.utils import print_warning_once
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
class OlmoeMoE(nn.Module): class OlmoeMoE(nn.Module):
"""A tensor-parallel MoE implementation for Olmoe that shards each expert """A tensor-parallel MoE implementation for Olmoe that shards each expert
...@@ -243,6 +246,7 @@ class OlmoeModel(nn.Module): ...@@ -243,6 +246,7 @@ class OlmoeModel(nn.Module):
config: PretrainedConfig, config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.padding_idx = config.pad_token_id self.padding_idx = config.pad_token_id
...@@ -252,34 +256,54 @@ class OlmoeModel(nn.Module): ...@@ -252,34 +256,54 @@ class OlmoeModel(nn.Module):
config.vocab_size, config.vocab_size,
config.hidden_size, config.hidden_size,
) )
self.layers = nn.ModuleList([ self.start_layer, self.end_layer, self.layers = make_layers(
OlmoeDecoderLayer(config, config.num_hidden_layers,
layer_idx, lambda prefix: OlmoeDecoderLayer(config, int(
cache_config, prefix.split(".")[-1]), cache_config, quant_config),
quant_config=quant_config) prefix=f"{prefix}.layers")
for layer_idx in range(config.num_hidden_layers)
])
self.norm = RMSNorm(config.hidden_size, eps=1e-5) self.norm = RMSNorm(config.hidden_size, eps=1e-5)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: intermediate_tensors: Optional[IntermediateTensors],
hidden_states = self.embed_tokens(input_ids) ) -> Union[torch.Tensor, IntermediateTensors]:
residual = None if get_pp_group().is_first_rank:
for i in range(len(self.layers)): hidden_states = self.embed_tokens(input_ids)
residual = None
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i] layer = self.layers[i]
hidden_states, residual = layer(positions, hidden_states, hidden_states, residual = layer(
kv_caches[i], attn_metadata, positions,
residual) hidden_states,
kv_caches[i - self.start_layer],
attn_metadata,
residual,
)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
hidden_states, _ = self.norm(hidden_states, residual) hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states return hidden_states
class OlmoeForCausalLM(nn.Module): class OlmoeForCausalLM(nn.Module, SupportsPP):
fall_back_to_pt_during_load = False fall_back_to_pt_during_load = False
...@@ -299,6 +323,9 @@ class OlmoeForCausalLM(nn.Module): ...@@ -299,6 +323,9 @@ class OlmoeForCausalLM(nn.Module):
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
...@@ -306,9 +333,9 @@ class OlmoeForCausalLM(nn.Module): ...@@ -306,9 +333,9 @@ class OlmoeForCausalLM(nn.Module):
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata) attn_metadata, intermediate_tensors)
return hidden_states return hidden_states
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(self, hidden_states: torch.Tensor,
...@@ -363,6 +390,9 @@ class OlmoeForCausalLM(nn.Module): ...@@ -363,6 +390,9 @@ class OlmoeForCausalLM(nn.Module):
# Skip loading extra bias for GPTQ models. # Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
continue continue
# Skip layers on other devices.
if is_pp_missing_parameter(name, self):
continue
if name not in params_dict: if name not in params_dict:
continue continue
...@@ -376,6 +406,9 @@ class OlmoeForCausalLM(nn.Module): ...@@ -376,6 +406,9 @@ class OlmoeForCausalLM(nn.Module):
if weight_name not in name: if weight_name not in name:
continue continue
name = name.replace(weight_name, param_name) name = name.replace(weight_name, param_name)
# Skip layers on other devices.
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name] param = params_dict[name]
weight_loader = param.weight_loader weight_loader = param.weight_loader
weight_loader(param, weight_loader(param,
...@@ -388,6 +421,9 @@ class OlmoeForCausalLM(nn.Module): ...@@ -388,6 +421,9 @@ class OlmoeForCausalLM(nn.Module):
# Skip loading extra bias for GPTQ models. # Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
continue continue
# Skip layers on other devices.
if is_pp_missing_parameter(name, self):
continue
# Remapping the name of FP8 kv-scale. # Remapping the name of FP8 kv-scale.
if name.endswith("kv_scale"): if name.endswith("kv_scale"):
remapped_kv_scale_name = name.replace( remapped_kv_scale_name = name.replace(
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only OPT model compatible with HuggingFace weights.""" """Inference-only OPT model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Tuple from typing import Iterable, List, Optional, Tuple, Union
import torch import torch
from torch import nn from torch import nn
...@@ -25,15 +25,14 @@ from transformers import OPTConfig ...@@ -25,15 +25,14 @@ from transformers import OPTConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig from vllm.config import CacheConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
ReplicatedLinear, ReplicatedLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization import QuantizationConfig
QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
...@@ -41,6 +40,10 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader ...@@ -41,6 +40,10 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
class OPTLearnedPositionalEmbedding(nn.Embedding): class OPTLearnedPositionalEmbedding(nn.Embedding):
...@@ -189,6 +192,7 @@ class OPTDecoder(nn.Module): ...@@ -189,6 +192,7 @@ class OPTDecoder(nn.Module):
config: OPTConfig, config: OPTConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -232,10 +236,10 @@ class OPTDecoder(nn.Module): ...@@ -232,10 +236,10 @@ class OPTDecoder(nn.Module):
else: else:
self.final_layer_norm = None self.final_layer_norm = None
self.layers = nn.ModuleList([ self.start_layer, self.end_layer, self.layers = make_layers(
OPTDecoderLayer(config, cache_config, quant_config) config.num_hidden_layers,
for _ in range(config.num_hidden_layers) lambda prefix: OPTDecoderLayer(config, cache_config, quant_config),
]) prefix=f"{prefix}.layers")
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids) return self.embed_tokens(input_ids)
...@@ -246,19 +250,28 @@ class OPTDecoder(nn.Module): ...@@ -246,19 +250,28 @@ class OPTDecoder(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> Union[torch.Tensor, IntermediateTensors]:
if inputs_embeds is None: if get_pp_group().is_first_rank:
inputs_embeds = self.get_input_embeddings(input_ids) if inputs_embeds is None:
pos_embeds = self.embed_positions(positions) inputs_embeds = self.get_input_embeddings(input_ids)
if self.project_in is not None: pos_embeds = self.embed_positions(positions)
inputs_embeds, _ = self.project_in(inputs_embeds) if self.project_in is not None:
hidden_states = inputs_embeds + pos_embeds inputs_embeds, _ = self.project_in(inputs_embeds)
hidden_states = inputs_embeds + pos_embeds
for i in range(len(self.layers)): else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i] layer = self.layers[i]
hidden_states = layer(hidden_states, kv_caches[i], attn_metadata) hidden_states = layer(hidden_states,
kv_caches[i - self.start_layer],
attn_metadata)
if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states})
if self.final_layer_norm is not None: if self.final_layer_norm is not None:
hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.final_layer_norm(hidden_states)
if self.project_out is not None: if self.project_out is not None:
...@@ -276,6 +289,9 @@ class OPTModel(nn.Module): ...@@ -276,6 +289,9 @@ class OPTModel(nn.Module):
): ):
super().__init__() super().__init__()
self.decoder = OPTDecoder(config, cache_config, quant_config) self.decoder = OPTDecoder(config, cache_config, quant_config)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(["hidden_states"],
config.hidden_size))
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.decoder.get_input_embeddings(input_ids) return self.decoder.get_input_embeddings(input_ids)
...@@ -286,20 +302,22 @@ class OPTModel(nn.Module): ...@@ -286,20 +302,22 @@ class OPTModel(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> Union[torch.Tensor, IntermediateTensors]:
return self.decoder(input_ids, return self.decoder(input_ids,
positions, positions,
kv_caches, kv_caches,
attn_metadata, attn_metadata,
intermediate_tensors,
inputs_embeds=inputs_embeds) inputs_embeds=inputs_embeds)
class OPTForCausalLM(nn.Module): class OPTForCausalLM(nn.Module, SupportsPP):
def __init__( def __init__(
self, self,
config, config: OPTConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
...@@ -314,6 +332,8 @@ class OPTForCausalLM(nn.Module): ...@@ -314,6 +332,8 @@ class OPTForCausalLM(nn.Module):
config.word_embed_proj_dim) config.word_embed_proj_dim)
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def forward( def forward(
self, self,
...@@ -322,9 +342,9 @@ class OPTForCausalLM(nn.Module): ...@@ -322,9 +342,9 @@ class OPTForCausalLM(nn.Module):
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata) attn_metadata, intermediate_tensors)
return hidden_states return hidden_states
def compute_logits( def compute_logits(
...@@ -365,6 +385,8 @@ class OPTForCausalLM(nn.Module): ...@@ -365,6 +385,8 @@ class OPTForCausalLM(nn.Module):
# Skip loading extra bias for GPTQ models. # Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
continue continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name] param = params_dict[name]
weight_loader = param.weight_loader weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id) weight_loader(param, loaded_weight, shard_id)
...@@ -373,6 +395,8 @@ class OPTForCausalLM(nn.Module): ...@@ -373,6 +395,8 @@ class OPTForCausalLM(nn.Module):
# Skip loading extra bias for GPTQ models. # Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
continue continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name] param = params_dict[name]
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
# Copyright (c) OrionStar Inc. # Copyright (c) OrionStar Inc.
# LICENSE: https://huggingface.co/OrionStarAI/Orion-14B-Base/blob/main/LICENSE # LICENSE: https://huggingface.co/OrionStarAI/Orion-14B-Base/blob/main/LICENSE
"""Inference-only Orion-14B model compatible with HuggingFace weights.""" """Inference-only Orion-14B model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, List, Optional, Tuple from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
import torch import torch
from torch import nn from torch import nn
...@@ -12,14 +12,13 @@ from transformers import PretrainedConfig ...@@ -12,14 +12,13 @@ from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig from vllm.config import CacheConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization import QuantizationConfig
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
...@@ -28,6 +27,10 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader ...@@ -28,6 +27,10 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
class OrionMLP(nn.Module): class OrionMLP(nn.Module):
...@@ -210,6 +213,7 @@ class OrionModel(nn.Module): ...@@ -210,6 +213,7 @@ class OrionModel(nn.Module):
config: PretrainedConfig, config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -219,11 +223,18 @@ class OrionModel(nn.Module): ...@@ -219,11 +223,18 @@ class OrionModel(nn.Module):
config.vocab_size, config.vocab_size,
config.hidden_size, config.hidden_size,
) )
self.layers = nn.ModuleList([ self.start_layer, self.end_layer, self.layers = make_layers(
OrionDecoderLayer(config, cache_config, quant_config) config.num_hidden_layers,
for _ in range(config.num_hidden_layers) lambda prefix: OrionDecoderLayer(
]) config,
cache_config,
quant_config,
),
prefix=f"{prefix}.layers")
self.norm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
def forward( def forward(
self, self,
...@@ -231,23 +242,34 @@ class OrionModel(nn.Module): ...@@ -231,23 +242,34 @@ class OrionModel(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: intermediate_tensors: Optional[IntermediateTensors],
hidden_states = self.embed_tokens(input_ids) ) -> Union[torch.Tensor, IntermediateTensors]:
residual = None if get_pp_group().is_first_rank:
for i in range(len(self.layers)): hidden_states = self.embed_tokens(input_ids)
residual = None
else:
assert intermediate_tensors
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i] layer = self.layers[i]
hidden_states, residual = layer( hidden_states, residual = layer(
positions, positions,
hidden_states, hidden_states,
kv_caches[i], kv_caches[i - self.start_layer],
attn_metadata, attn_metadata,
residual, residual,
) )
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
hidden_states = self.norm(hidden_states) hidden_states = self.norm(hidden_states)
return hidden_states return hidden_states
class OrionForCausalLM(nn.Module): class OrionForCausalLM(nn.Module, SupportsPP):
def __init__( def __init__(
self, self,
...@@ -266,6 +288,8 @@ class OrionForCausalLM(nn.Module): ...@@ -266,6 +288,8 @@ class OrionForCausalLM(nn.Module):
self.lm_head.weight = self.model.embed_tokens.weight self.lm_head.weight = self.model.embed_tokens.weight
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def forward( def forward(
self, self,
...@@ -274,9 +298,9 @@ class OrionForCausalLM(nn.Module): ...@@ -274,9 +298,9 @@ class OrionForCausalLM(nn.Module):
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata) attn_metadata, intermediate_tensors)
return hidden_states return hidden_states
def compute_logits( def compute_logits(
...@@ -321,6 +345,8 @@ class OrionForCausalLM(nn.Module): ...@@ -321,6 +345,8 @@ class OrionForCausalLM(nn.Module):
# Skip loading extra bias for GPTQ models. # Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
continue continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name] param = params_dict[name]
weight_loader = param.weight_loader weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id) weight_loader(param, loaded_weight, shard_id)
...@@ -329,6 +355,8 @@ class OrionForCausalLM(nn.Module): ...@@ -329,6 +355,8 @@ class OrionForCausalLM(nn.Module):
# Skip loading extra bias for GPTQ models. # Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
continue continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name] param = params_dict[name]
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
......
...@@ -9,9 +9,8 @@ from vllm.attention import AttentionMetadata ...@@ -9,9 +9,8 @@ from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.gemma import GemmaForCausalLM from vllm.model_executor.models.gemma import GemmaForCausalLM
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
...@@ -19,7 +18,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY ...@@ -19,7 +18,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.utils import cached_get_tokenizer from vllm.multimodal.utils import cached_get_tokenizer
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsMultiModal from .interfaces import SupportsMultiModal, SupportsPP
from .siglip import (SiglipVisionModel, dummy_image_for_siglip, from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
dummy_seq_data_for_siglip, get_max_siglip_image_tokens) dummy_seq_data_for_siglip, get_max_siglip_image_tokens)
from .utils import group_weights_with_prefix, merge_multimodal_embeddings from .utils import group_weights_with_prefix, merge_multimodal_embeddings
...@@ -129,7 +128,8 @@ class PaliGemmaMultiModalProjector(nn.Module): ...@@ -129,7 +128,8 @@ class PaliGemmaMultiModalProjector(nn.Module):
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_paligemma_image_tokens) @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_paligemma_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_paligemma) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_paligemma)
@INPUT_REGISTRY.register_input_processor(input_processor_for_paligemma) @INPUT_REGISTRY.register_input_processor(input_processor_for_paligemma)
class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal): class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP):
def __init__(self, def __init__(self,
config: PaliGemmaConfig, config: PaliGemmaConfig,
...@@ -149,12 +149,15 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -149,12 +149,15 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal):
self.quant_config = quant_config self.quant_config = quant_config
self.language_model = GemmaForCausalLM(config.text_config, self.language_model = GemmaForCausalLM(config.text_config,
cache_config, quant_config) cache_config, quant_config)
self.unpadded_vocab_size = config.text_config.vocab_size
logit_scale = getattr(config, "logit_scale", 1.0) logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, self.language_model.logits_processor.scale *= logit_scale
config.text_config.vocab_size,
logit_scale) self.make_empty_intermediate_tensors = (
self.sampler = Sampler() self.language_model.make_empty_intermediate_tensors)
@property
def sampler(self):
return self.language_model.sampler
def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
h = w = self.config.vision_config.image_size h = w = self.config.vision_config.image_size
...@@ -239,32 +242,36 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -239,32 +242,36 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal):
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
**kwargs: object) -> SamplerOutput: **kwargs: object) -> Union[SamplerOutput, IntermediateTensors]:
if intermediate_tensors is not None:
parsed_image_input = self._parse_and_validate_image_input(**kwargs) input_ids = None
inputs_embeds = None
else:
parsed_image_input = self._parse_and_validate_image_input(**kwargs)
if parsed_image_input is not None: if parsed_image_input is not None:
vision_embeddings = self._process_image_input(parsed_image_input) vision_embeddings = self._process_image_input(
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/paligemma/modeling_paligemma.py#L294 # noqa parsed_image_input)
vision_embeddings = vision_embeddings * (self.config.hidden_size** # https://github.com/huggingface/transformers/blob/main/src/transformers/models/paligemma/modeling_paligemma.py#L294 # noqa
-0.5) vision_embeddings = vision_embeddings * (
self.config.hidden_size**-0.5)
inputs_embeds = self.language_model.model.get_input_embeddings( inputs_embeds = self.language_model.model.get_input_embeddings(
input_ids) input_ids)
inputs_embeds = merge_multimodal_embeddings( inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, vision_embeddings, input_ids, inputs_embeds, vision_embeddings,
self.config.image_token_index) self.config.image_token_index)
input_ids = None input_ids = None
else: else:
inputs_embeds = None inputs_embeds = None
hidden_states = self.language_model.model(input_ids, hidden_states = self.language_model.model(input_ids,
positions, positions,
kv_caches, kv_caches,
attn_metadata, attn_metadata,
None, intermediate_tensors,
inputs_embeds=inputs_embeds) inputs_embeds=inputs_embeds)
return hidden_states return hidden_states
......
...@@ -20,7 +20,7 @@ ...@@ -20,7 +20,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only persimmon model compatible with HuggingFace weights.""" """Inference-only persimmon model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Tuple from typing import Iterable, List, Optional, Tuple, Union
import torch import torch
from torch import nn from torch import nn
...@@ -28,14 +28,13 @@ from transformers import PersimmonConfig ...@@ -28,14 +28,13 @@ from transformers import PersimmonConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig from vllm.config import CacheConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization import QuantizationConfig
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
...@@ -44,6 +43,10 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader ...@@ -44,6 +43,10 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
class PersimmonMLP(nn.Module): class PersimmonMLP(nn.Module):
...@@ -211,20 +214,23 @@ class PersimmonModel(nn.Module): ...@@ -211,20 +214,23 @@ class PersimmonModel(nn.Module):
def __init__(self, def __init__(self,
config: PersimmonConfig, config: PersimmonConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None): quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__() super().__init__()
self.vocab_size = config.vocab_size self.vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(config.vocab_size, self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
config.hidden_size) config.hidden_size)
self.layers = nn.ModuleList([ self.start_layer, self.end_layer, self.layers = make_layers(
PersimmonDecoderLayer(config, config.num_hidden_layers,
cache_config=cache_config, lambda prefix: PersimmonDecoderLayer(config, cache_config,
quant_config=quant_config) quant_config),
for _ in range(config.num_hidden_layers) prefix=f"{prefix}.layers")
])
self.final_layernorm = nn.LayerNorm(config.hidden_size, self.final_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps) eps=config.layer_norm_eps)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(["hidden_states"],
config.hidden_size))
def forward( def forward(
self, self,
...@@ -232,24 +238,31 @@ class PersimmonModel(nn.Module): ...@@ -232,24 +238,31 @@ class PersimmonModel(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> Union[torch.Tensor, IntermediateTensors]:
if inputs_embeds is not None: if get_pp_group().is_first_rank:
hidden_states = inputs_embeds if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.embed_tokens(input_ids)
else: else:
hidden_states = self.embed_tokens(input_ids) assert intermediate_tensors is not None
for i in range(len(self.layers)): hidden_states = intermediate_tensors["hidden_states"]
for i in range(self.start_layer, self.end_layer):
hidden_states = self.layers[i]( hidden_states = self.layers[i](
positions, positions,
hidden_states, hidden_states,
kv_caches[i], kv_caches[i - self.start_layer],
attn_metadata, attn_metadata,
) )
if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states})
hidden_states = self.final_layernorm(hidden_states) hidden_states = self.final_layernorm(hidden_states)
return hidden_states return hidden_states
class PersimmonForCausalLM(nn.Module): class PersimmonForCausalLM(nn.Module, SupportsPP):
def __init__(self, def __init__(self,
config: PersimmonConfig, config: PersimmonConfig,
...@@ -266,6 +279,8 @@ class PersimmonForCausalLM(nn.Module): ...@@ -266,6 +279,8 @@ class PersimmonForCausalLM(nn.Module):
bias=False) bias=False)
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def forward( def forward(
self, self,
...@@ -281,6 +296,7 @@ class PersimmonForCausalLM(nn.Module): ...@@ -281,6 +296,7 @@ class PersimmonForCausalLM(nn.Module):
positions=positions, positions=positions,
kv_caches=kv_caches, kv_caches=kv_caches,
attn_metadata=attn_metadata, attn_metadata=attn_metadata,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds, inputs_embeds=inputs_embeds,
) )
return hidden_states return hidden_states
...@@ -312,6 +328,8 @@ class PersimmonForCausalLM(nn.Module): ...@@ -312,6 +328,8 @@ class PersimmonForCausalLM(nn.Module):
# Models trained using ColossalAI may include these tensors in # Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them. # the checkpoint. Skip them.
continue continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name] param = params_dict[name]
if "query_key_value" in name: if "query_key_value" in name:
......
...@@ -35,7 +35,7 @@ ...@@ -35,7 +35,7 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Inference-only Phi-1.5 model compatible with HuggingFace weights.""" """Inference-only Phi-1.5 model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Tuple from typing import Iterable, List, Optional, Tuple, Union
import torch import torch
from torch import nn from torch import nn
...@@ -43,14 +43,13 @@ from transformers import PhiConfig ...@@ -43,14 +43,13 @@ from transformers import PhiConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization import QuantizationConfig
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
...@@ -59,7 +58,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader ...@@ -59,7 +58,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA from .interfaces import SupportsLoRA, SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
class PhiAttention(nn.Module): class PhiAttention(nn.Module):
...@@ -196,18 +197,22 @@ class PhiModel(nn.Module): ...@@ -196,18 +197,22 @@ class PhiModel(nn.Module):
def __init__(self, def __init__(self,
config: PhiConfig, config: PhiConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None): quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__() super().__init__()
self.config = config self.config = config
self.quant_config = quant_config self.quant_config = quant_config
self.embed_tokens = VocabParallelEmbedding(config.vocab_size, self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
config.hidden_size) config.hidden_size)
self.layers = nn.ModuleList([ self.start_layer, self.end_layer, self.layers = make_layers(
PhiLayer(config, cache_config, quant_config) config.num_hidden_layers,
for _ in range(config.num_hidden_layers) lambda prefix: PhiLayer(config, cache_config, quant_config),
]) prefix=f"{prefix}.layers")
self.final_layernorm = nn.LayerNorm(config.hidden_size, self.final_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps) eps=config.layer_norm_eps)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(["hidden_states"],
config.hidden_size))
def forward( def forward(
self, self,
...@@ -215,23 +220,31 @@ class PhiModel(nn.Module): ...@@ -215,23 +220,31 @@ class PhiModel(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: intermediate_tensors: Optional[IntermediateTensors],
hidden_states = self.embed_tokens(input_ids) ) -> Union[torch.Tensor, IntermediateTensors]:
for i in range(self.config.num_hidden_layers): if get_pp_group().is_first_rank:
hidden_states = self.embed_tokens(input_ids)
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i] layer = self.layers[i]
hidden_states = layer( hidden_states = layer(
positions, positions,
hidden_states, hidden_states,
kv_caches[i], kv_caches[i - self.start_layer],
attn_metadata, attn_metadata,
) )
if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states})
hidden_states = self.final_layernorm(hidden_states) hidden_states = self.final_layernorm(hidden_states)
return hidden_states return hidden_states
class PhiForCausalLM(nn.Module, SupportsLoRA): class PhiForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
packed_modules_mapping = { packed_modules_mapping = {
"qkv_proj": [ "qkv_proj": [
"q_proj", "q_proj",
...@@ -274,6 +287,8 @@ class PhiForCausalLM(nn.Module, SupportsLoRA): ...@@ -274,6 +287,8 @@ class PhiForCausalLM(nn.Module, SupportsLoRA):
quant_config=quant_config) quant_config=quant_config)
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def forward( def forward(
self, self,
...@@ -282,9 +297,9 @@ class PhiForCausalLM(nn.Module, SupportsLoRA): ...@@ -282,9 +297,9 @@ class PhiForCausalLM(nn.Module, SupportsLoRA):
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata) attn_metadata, intermediate_tensors)
return hidden_states return hidden_states
...@@ -325,6 +340,8 @@ class PhiForCausalLM(nn.Module, SupportsLoRA): ...@@ -325,6 +340,8 @@ class PhiForCausalLM(nn.Module, SupportsLoRA):
# Skip loading extra bias for GPTQ models. # Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
continue continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name] param = params_dict[name]
weight_loader = param.weight_loader weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id) weight_loader(param, loaded_weight, shard_id)
...@@ -335,6 +352,8 @@ class PhiForCausalLM(nn.Module, SupportsLoRA): ...@@ -335,6 +352,8 @@ class PhiForCausalLM(nn.Module, SupportsLoRA):
continue continue
# pylint: disable=E1136 # pylint: disable=E1136
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name] param = params_dict[name]
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
......
import math import math
from typing import Iterable, List, Optional, Tuple from typing import Iterable, List, Optional, Tuple, Union
import torch import torch
from torch import nn from torch import nn
...@@ -7,14 +7,13 @@ from transformers.configuration_utils import PretrainedConfig ...@@ -7,14 +7,13 @@ from transformers.configuration_utils import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import (get_tensor_model_parallel_rank, from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size) get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization import QuantizationConfig
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
...@@ -23,6 +22,10 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader ...@@ -23,6 +22,10 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
def load_column_parallel_weight(param: torch.nn.Parameter, def load_column_parallel_weight(param: torch.nn.Parameter,
loaded_weight: torch.Tensor): loaded_weight: torch.Tensor):
...@@ -301,20 +304,25 @@ class Phi3SmallModel(nn.Module): ...@@ -301,20 +304,25 @@ class Phi3SmallModel(nn.Module):
config: PretrainedConfig, config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.config = config self.config = config
self.embed_tokens = VocabParallelEmbedding(config.vocab_size, self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
config.hidden_size) config.hidden_size)
self.mup_embedding_multiplier = config.mup_embedding_multiplier self.mup_embedding_multiplier = config.mup_embedding_multiplier
self.layers = nn.ModuleList([ self.start_layer, self.end_layer, self.layers = make_layers(
Phi3SmallDecoderLayer(config, layer_idx, cache_config, config.num_hidden_layers,
quant_config) lambda prefix: Phi3SmallDecoderLayer(config,
for layer_idx in range(config.num_hidden_layers) int(prefix.split('.')[-1]),
]) cache_config, quant_config),
prefix=f"{prefix}.layers")
self.final_layernorm = nn.LayerNorm(config.hidden_size, self.final_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_epsilon) eps=config.layer_norm_epsilon)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(["hidden_states"],
config.hidden_size))
def get_input_embeddings(self): def get_input_embeddings(self):
return self.embed_tokens return self.embed_tokens
...@@ -327,30 +335,37 @@ class Phi3SmallModel(nn.Module): ...@@ -327,30 +335,37 @@ class Phi3SmallModel(nn.Module):
input_ids: torch.LongTensor, input_ids: torch.LongTensor,
positions: Optional[torch.LongTensor], positions: Optional[torch.LongTensor],
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata = None, attn_metadata: AttentionMetadata,
): intermediate_tensors: Optional[IntermediateTensors],
hidden_states = self.embed_tokens(input_ids) ) -> Union[torch.Tensor, IntermediateTensors]:
if (self.mup_embedding_multiplier is not None if get_pp_group().is_first_rank:
and self.mup_embedding_multiplier > 0.0): hidden_states = self.embed_tokens(input_ids)
hidden_states = hidden_states * self.mup_embedding_multiplier if (self.mup_embedding_multiplier is not None
for i in range(len(self.layers)): and self.mup_embedding_multiplier > 0.0):
hidden_states = hidden_states * self.mup_embedding_multiplier
else:
assert intermediate_tensors
hidden_states = intermediate_tensors["hidden_states"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i] layer = self.layers[i]
hidden_states = layer( hidden_states = layer(
positions, positions,
hidden_states, hidden_states,
kv_caches[i], kv_caches[i - self.start_layer],
attn_metadata, attn_metadata,
) )
if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states})
hidden_states = self.final_layernorm(hidden_states) hidden_states = self.final_layernorm(hidden_states)
return hidden_states return hidden_states
class Phi3SmallForCausalLM(nn.Module): class Phi3SmallForCausalLM(nn.Module, SupportsPP):
_tied_weights_keys = ["lm_head.weight"] _tied_weights_keys = ["lm_head.weight"]
def __init__( def __init__(
self, self,
config, config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None, lora_config: Optional[LoRAConfig] = None,
...@@ -372,6 +387,8 @@ class Phi3SmallForCausalLM(nn.Module): ...@@ -372,6 +387,8 @@ class Phi3SmallForCausalLM(nn.Module):
self.lm_head.weight = self.model.embed_tokens.weight self.lm_head.weight = self.model.embed_tokens.weight
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
# tokens in tiktoken but not used # tokens in tiktoken but not used
if hasattr(config, 'dummy_token_indices'): if hasattr(config, 'dummy_token_indices'):
...@@ -419,12 +436,13 @@ class Phi3SmallForCausalLM(nn.Module): ...@@ -419,12 +436,13 @@ class Phi3SmallForCausalLM(nn.Module):
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> Union[torch.Tensor, IntermediateTensors]:
output_hidden_states = self.model( output_hidden_states = self.model(
input_ids=input_ids, input_ids=input_ids,
positions=positions, positions=positions,
kv_caches=kv_caches, kv_caches=kv_caches,
attn_metadata=attn_metadata, attn_metadata=attn_metadata,
intermediate_tensors=intermediate_tensors,
) )
output_hidden_states = output_hidden_states output_hidden_states = output_hidden_states
return output_hidden_states return output_hidden_states
...@@ -447,6 +465,8 @@ class Phi3SmallForCausalLM(nn.Module): ...@@ -447,6 +465,8 @@ class Phi3SmallForCausalLM(nn.Module):
continue continue
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
continue continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name] param = params_dict[name]
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
# limitations under the License. # limitations under the License.
import itertools import itertools
import re import re
from functools import lru_cache from functools import cached_property, lru_cache
from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional, from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional,
Tuple, TypedDict, Union) Tuple, TypedDict, Union)
...@@ -29,13 +29,11 @@ from vllm.attention import AttentionMetadata ...@@ -29,13 +29,11 @@ from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, ModelConfig, MultiModalConfig from vllm.config import CacheConfig, ModelConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.clip import CLIPVisionModel from vllm.model_executor.models.clip import CLIPVisionModel
from vllm.model_executor.models.llama import LlamaModel from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.utils import cached_get_tokenizer, repeat_and_pad_token from vllm.multimodal.utils import cached_get_tokenizer, repeat_and_pad_token
...@@ -43,8 +41,9 @@ from vllm.sequence import IntermediateTensors ...@@ -43,8 +41,9 @@ from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of from vllm.utils import is_list_of
from .clip import dummy_image_for_clip, dummy_seq_data_for_clip from .clip import dummy_image_for_clip, dummy_seq_data_for_clip
from .interfaces import SupportsMultiModal from .interfaces import SupportsMultiModal, SupportsPP
from .utils import flatten_bn, merge_multimodal_embeddings from .utils import (flatten_bn, group_weights_with_prefix,
merge_multimodal_embeddings)
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -295,6 +294,37 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase): ...@@ -295,6 +294,37 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
dim=2).reshape(num_images, -1, hid_dim) dim=2).reshape(num_images, -1, hid_dim)
return image_features_hd_newline return image_features_hd_newline
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
# prepare weight iterators for components
weights_group = group_weights_with_prefix(weights)
# load vision encoder
self.img_processor.load_weights(weights_group["img_processor"])
# load glb_GN
for name, loaded_weight in weights_group["glb_GN"]:
assert name == ""
param = self.glb_GN
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
# load sub_GN
for name, loaded_weight in weights_group["sub_GN"]:
assert name == ""
param = self.sub_GN
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
# load mlp projector
mlp_params_dict = dict(self.img_projection.named_parameters())
for name, loaded_weight in weights_group["img_projection"]:
param = mlp_params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L57 # Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L57
def _calc_padded_size(*, width: int, height: int, padding_unit: int = 336): def _calc_padded_size(*, width: int, height: int, padding_unit: int = 336):
...@@ -508,7 +538,7 @@ def input_processor_for_phi3v(ctx: InputContext, ...@@ -508,7 +538,7 @@ def input_processor_for_phi3v(ctx: InputContext,
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_phi3v_image_tokens) @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_phi3v_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_phi3v) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_phi3v)
@INPUT_REGISTRY.register_input_processor(input_processor_for_phi3v) @INPUT_REGISTRY.register_input_processor(input_processor_for_phi3v)
class Phi3VForCausalLM(nn.Module, SupportsMultiModal): class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
def __init__(self, def __init__(self,
config: PretrainedConfig, config: PretrainedConfig,
...@@ -521,17 +551,21 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal): ...@@ -521,17 +551,21 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal):
self.multimodal_config = multimodal_config self.multimodal_config = multimodal_config
self.image_token_id = _IMAGE_TOKEN_ID self.image_token_id = _IMAGE_TOKEN_ID
self.model = LlamaModel(config, cache_config, quant_config)
# TODO: Optionally initializes this for supporting embeddings. # TODO: Optionally initializes this for supporting embeddings.
self.vision_embed_tokens = Phi3HDImageEmbedding(config) self.vision_embed_tokens = Phi3HDImageEmbedding(config)
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size, self.language_model = LlamaForCausalLM(config, cache_config,
quant_config=quant_config) quant_config)
if self.config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight self.make_empty_intermediate_tensors = (
self.logits_processor = LogitsProcessor(config.vocab_size) self.language_model.make_empty_intermediate_tensors)
self.sampler = Sampler()
@cached_property
def sampler(self):
if hasattr(self.language_model, "sampler"):
return self.language_model.sampler
return Sampler()
def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor: def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor:
expected_dims = (2, ) expected_dims = (2, )
...@@ -631,24 +665,29 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal): ...@@ -631,24 +665,29 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal):
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
**kwargs: object): **kwargs: object):
image_input = self._parse_and_validate_image_input(**kwargs) if intermediate_tensors is not None:
if image_input is not None:
vision_embeddings = self._process_image_input(image_input)
inputs_embeds = self.model.get_input_embeddings(input_ids)
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, vision_embeddings,
self.image_token_id)
input_ids = None input_ids = None
else:
inputs_embeds = None inputs_embeds = None
else:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is not None:
vision_embeddings = self._process_image_input(image_input)
inputs_embeds = self.language_model.model.get_input_embeddings(
input_ids)
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, vision_embeddings,
self.image_token_id)
input_ids = None
else:
inputs_embeds = None
hidden_states = self.model(input_ids, hidden_states = self.language_model.model(input_ids,
positions, positions,
kv_caches, kv_caches,
attn_metadata, attn_metadata,
intermediate_tensors, intermediate_tensors,
inputs_embeds=inputs_embeds) inputs_embeds=inputs_embeds)
return hidden_states return hidden_states
...@@ -657,66 +696,38 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal): ...@@ -657,66 +696,38 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]: ) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states, return self.language_model.compute_logits(hidden_states,
sampling_metadata) sampling_metadata)
return logits
def sample( def sample(
self, self,
logits: torch.Tensor, logits: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]: ) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata) return self.language_model.sample(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [ hf_to_vllm_mapping = {
# (param_name, shard_name, shard_id) "model.vision_embed_tokens.": "vision_embed_tokens.",
(".qkv_proj", ".q_proj", "q"), "lm_head.": "language_model.lm_head.",
(".qkv_proj", ".k_proj", "k"), "model.": "language_model.model.",
(".qkv_proj", ".v_proj", "v"), }
(".gate_up_proj", ".gate_proj", 0),
(".gate_up_proj", ".up_proj", 1),
]
# TODO(ChristopherCho): This is a temporary fix to load def hf_to_vllm_name(key: str) -> str:
# the vision weights with CLIPVisionModel.load_weights() for hf_name, vllm_name in hf_to_vllm_mapping.items():
vision_weights = [] if key.startswith(hf_name):
params_dict = dict(self.named_parameters()) return key.replace(hf_name, vllm_name, 1)
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name: return key
continue
# Skip loading the img_processor weights since they are vllm_weights = {hf_to_vllm_name(k): v for k, v in weights}
# loaded separately.
if "vision_embed_tokens.img_processor" in name: # prepare weight iterators for components
vision_weights.append((name, loaded_weight)) weights_group = group_weights_with_prefix(vllm_weights.items())
continue
# load vision embeddings and encoder
for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items(): self.vision_embed_tokens.load_weights(
if key_to_modify in name: weights_group["vision_embed_tokens"])
name = name.replace(key_to_modify, new_key)
for (param_name, weight_name, shard_id) in stacked_params_mapping: # load llm backbone
if weight_name not in name: self.language_model.load_weights(weights_group["language_model"])
continue
param = params_dict[name.replace(weight_name, param_name)]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if name in params_dict:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
# We use regex to extract the sub-module name
# from "model.vision_embed_tokens.img_processor.*"
vision_weights = [
(re.search(r"vision_embed_tokens\.img_processor\.(.*)",
n).group(1), w) for n, w in vision_weights
]
self.vision_embed_tokens.img_processor.load_weights(vision_weights)
...@@ -21,7 +21,7 @@ ...@@ -21,7 +21,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only PhiMoE model.""" """Inference-only PhiMoE model."""
from typing import Iterable, List, Optional, Tuple from typing import Iterable, List, Optional, Tuple, Union
import torch import torch
from torch import nn from torch import nn
...@@ -29,7 +29,7 @@ from transformers.configuration_utils import PretrainedConfig ...@@ -29,7 +29,7 @@ from transformers.configuration_utils import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import (QKVParallelLinear, from vllm.model_executor.layers.linear import (QKVParallelLinear,
ReplicatedLinear, ReplicatedLinear,
...@@ -46,7 +46,9 @@ from vllm.model_executor.model_loader.weight_utils import ( ...@@ -46,7 +46,9 @@ from vllm.model_executor.model_loader.weight_utils import (
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA from .interfaces import SupportsLoRA, SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
class PhiMoEConfig(PretrainedConfig): class PhiMoEConfig(PretrainedConfig):
...@@ -435,6 +437,7 @@ class PhiMoEModel(nn.Module): ...@@ -435,6 +437,7 @@ class PhiMoEModel(nn.Module):
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None, lora_config: Optional[LoRAConfig] = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.padding_idx = config.pad_token_id self.padding_idx = config.pad_token_id
...@@ -448,33 +451,56 @@ class PhiMoEModel(nn.Module): ...@@ -448,33 +451,56 @@ class PhiMoEModel(nn.Module):
config.hidden_size, config.hidden_size,
org_num_embeddings=config.vocab_size, org_num_embeddings=config.vocab_size,
) )
self.layers = nn.ModuleList([ self.start_layer, self.end_layer, self.layers = make_layers(
PhiMoEDecoderLayer(config, cache_config, quant_config=quant_config) config.num_hidden_layers,
for _ in range(config.num_hidden_layers) lambda prefix: PhiMoEDecoderLayer(config, cache_config,
]) quant_config),
prefix=f"{prefix}.layers")
self.norm = nn.LayerNorm(config.hidden_size, self.norm = nn.LayerNorm(config.hidden_size,
eps=config.rms_norm_eps, eps=config.rms_norm_eps,
elementwise_affine=True) elementwise_affine=True)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: intermediate_tensors: Optional[IntermediateTensors],
hidden_states = self.embed_tokens(input_ids) ) -> Union[torch.Tensor, IntermediateTensors]:
residual = None if get_pp_group().is_first_rank:
for i in range(len(self.layers)): hidden_states = self.embed_tokens(input_ids)
residual = None
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i] layer = self.layers[i]
hidden_states, residual = layer(positions, hidden_states, hidden_states, residual = layer(
kv_caches[i], attn_metadata, positions,
residual) hidden_states,
kv_caches[i - self.start_layer],
attn_metadata,
residual,
)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
hidden_states = self.norm(hidden_states) hidden_states = self.norm(hidden_states)
return hidden_states return hidden_states
class PhiMoEForCausalLM(nn.Module, SupportsLoRA): class PhiMoEForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
fall_back_to_pt_during_load = False fall_back_to_pt_during_load = False
packed_modules_mapping = { packed_modules_mapping = {
...@@ -537,6 +563,9 @@ class PhiMoEForCausalLM(nn.Module, SupportsLoRA): ...@@ -537,6 +563,9 @@ class PhiMoEForCausalLM(nn.Module, SupportsLoRA):
config.vocab_size) config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
...@@ -544,9 +573,9 @@ class PhiMoEForCausalLM(nn.Module, SupportsLoRA): ...@@ -544,9 +573,9 @@ class PhiMoEForCausalLM(nn.Module, SupportsLoRA):
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata) attn_metadata, intermediate_tensors)
return hidden_states return hidden_states
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(self, hidden_states: torch.Tensor,
...@@ -589,6 +618,9 @@ class PhiMoEForCausalLM(nn.Module, SupportsLoRA): ...@@ -589,6 +618,9 @@ class PhiMoEForCausalLM(nn.Module, SupportsLoRA):
# Skip loading extra bias for GPTQ models. # Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
continue continue
# Skip layers on other devices.
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name] param = params_dict[name]
weight_loader = param.weight_loader weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id) weight_loader(param, loaded_weight, shard_id)
...@@ -599,6 +631,9 @@ class PhiMoEForCausalLM(nn.Module, SupportsLoRA): ...@@ -599,6 +631,9 @@ class PhiMoEForCausalLM(nn.Module, SupportsLoRA):
if weight_name not in name: if weight_name not in name:
continue continue
name = name.replace(weight_name, param_name) name = name.replace(weight_name, param_name)
# Skip layers on other devices.
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name] param = params_dict[name]
weight_loader = param.weight_loader weight_loader = param.weight_loader
weight_loader( weight_loader(
...@@ -613,6 +648,9 @@ class PhiMoEForCausalLM(nn.Module, SupportsLoRA): ...@@ -613,6 +648,9 @@ class PhiMoEForCausalLM(nn.Module, SupportsLoRA):
# Skip loading extra bias for GPTQ models. # Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
continue continue
# Skip layers on other devices.
if is_pp_missing_parameter(name, self):
continue
# Remapping the name of FP8 kv-scale. # Remapping the name of FP8 kv-scale.
name = maybe_remap_kv_scale_name(name, params_dict) name = maybe_remap_kv_scale_name(name, params_dict)
if name is None: if name is None:
......
from dataclasses import dataclass, fields from dataclasses import dataclass, fields
from functools import cached_property
from itertools import tee from itertools import tee
from typing import Iterable, List, Mapping, Optional, Tuple, Union from typing import Iterable, List, Mapping, Optional, Tuple, Union
...@@ -16,7 +17,7 @@ from vllm.config import CacheConfig, MultiModalConfig ...@@ -16,7 +17,7 @@ from vllm.config import CacheConfig, MultiModalConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.utils import merge_multimodal_embeddings from vllm.model_executor.models.utils import merge_multimodal_embeddings
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
...@@ -25,7 +26,7 @@ from vllm.multimodal.base import MultiModalInputs ...@@ -25,7 +26,7 @@ from vllm.multimodal.base import MultiModalInputs
from vllm.multimodal.utils import cached_get_tokenizer from vllm.multimodal.utils import cached_get_tokenizer
from vllm.sequence import IntermediateTensors, SequenceData from vllm.sequence import IntermediateTensors, SequenceData
from .interfaces import SupportsMultiModal from .interfaces import SupportsMultiModal, SupportsPP
from .utils import init_vllm_registered_model from .utils import init_vllm_registered_model
...@@ -126,7 +127,8 @@ def input_processor_for_pixtral(ctx: InputContext, llm_inputs: LLMInputs): ...@@ -126,7 +127,8 @@ def input_processor_for_pixtral(ctx: InputContext, llm_inputs: LLMInputs):
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_pixtral_image_tokens) @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_pixtral_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_pixtral) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_pixtral)
@INPUT_REGISTRY.register_input_processor(input_processor_for_pixtral) @INPUT_REGISTRY.register_input_processor(input_processor_for_pixtral)
class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal): class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP):
def __init__(self, def __init__(self,
config: PretrainedConfig, config: PretrainedConfig,
...@@ -155,6 +157,16 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -155,6 +157,16 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal):
self.vision_language_adapter = VisionLanguageAdapter( self.vision_language_adapter = VisionLanguageAdapter(
self.vision_args, dim=config.text_config.hidden_size) self.vision_args, dim=config.text_config.hidden_size)
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors)
@cached_property
def sampler(self):
if hasattr(self.language_model, "sampler"):
return self.language_model.sampler
return Sampler()
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
...@@ -163,32 +175,36 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -163,32 +175,36 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal):
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
**kwargs: object, **kwargs: object,
) -> SamplerOutput: ) -> Union[torch.Tensor, IntermediateTensors]:
"""Run forward pass for pixtral. """Run forward pass for pixtral.
TODO TODO
""" """
image_input = self._parse_and_validate_image_input(**kwargs) if intermediate_tensors is not None:
input_ids = None
inputs_embeds = None
else:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is not None: if image_input is not None:
vision_embeddings = self._process_image_input(image_input) vision_embeddings = self._process_image_input(image_input)
inputs_embeds = self.language_model.model.get_input_embeddings( inputs_embeds = self.language_model.model.get_input_embeddings(
input_ids) input_ids)
inputs_embeds = merge_multimodal_embeddings( inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, vision_embeddings, input_ids, inputs_embeds, vision_embeddings,
self.vision_args.image_token_id) self.vision_args.image_token_id)
input_ids = None input_ids = None
else: else:
inputs_embeds = None inputs_embeds = None
hidden_states = self.language_model.model(input_ids, hidden_states = self.language_model.model(input_ids,
positions, positions,
kv_caches, kv_caches,
attn_metadata, attn_metadata,
None, intermediate_tensors,
inputs_embeds=inputs_embeds) inputs_embeds=inputs_embeds)
return hidden_states return hidden_states
......
...@@ -31,15 +31,13 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, ...@@ -31,15 +31,13 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization import QuantizationConfig
QuantizationConfig)
from vllm.model_executor.layers.resampler import Resampler2, get_abs_pos from vllm.model_executor.layers.resampler import Resampler2, get_abs_pos
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import SupportsMultiModal
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.base import MultiModalInputs from vllm.multimodal.base import MultiModalInputs
...@@ -47,7 +45,9 @@ from vllm.multimodal.utils import cached_get_tokenizer ...@@ -47,7 +45,9 @@ from vllm.multimodal.utils import cached_get_tokenizer
from vllm.sequence import IntermediateTensors, SequenceData from vllm.sequence import IntermediateTensors, SequenceData
from vllm.utils import is_list_of from vllm.utils import is_list_of
from .utils import flatten_bn, is_pp_missing_parameter, make_layers from .interfaces import SupportsMultiModal, SupportsPP
from .utils import (flatten_bn, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -568,6 +568,9 @@ class QWenModel(nn.Module): ...@@ -568,6 +568,9 @@ class QWenModel(nn.Module):
lambda prefix: QWenBlock(config, cache_config, quant_config), lambda prefix: QWenBlock(config, cache_config, quant_config),
prefix=f"{prefix}.h") prefix=f"{prefix}.h")
self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
self.visual = VisionTransformer(**config.visual, self.visual = VisionTransformer(**config.visual,
quant_config=quant_config) if hasattr( quant_config=quant_config) if hasattr(
config, "visual") else None config, "visual") else None
...@@ -580,7 +583,7 @@ class QWenModel(nn.Module): ...@@ -580,7 +583,7 @@ class QWenModel(nn.Module):
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors], intermediate_tensors: Optional[IntermediateTensors],
pixel_values: Optional[QwenImageInputs], pixel_values: Optional[QwenImageInputs],
) -> torch.Tensor: ) -> Union[torch.Tensor, IntermediateTensors]:
img_pos = None img_pos = None
# If pixel / visual embeddings are provided, this is a visual model # If pixel / visual embeddings are provided, this is a visual model
if pixel_values is not None and self.visual is not None: if pixel_values is not None and self.visual is not None:
...@@ -860,7 +863,7 @@ def dummy_data_for_qwen( ...@@ -860,7 +863,7 @@ def dummy_data_for_qwen(
@MULTIMODAL_REGISTRY.register_max_image_tokens(MAX_QWEN_IMG_TOKENS) @MULTIMODAL_REGISTRY.register_max_image_tokens(MAX_QWEN_IMG_TOKENS)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_qwen) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_qwen)
@INPUT_REGISTRY.register_input_processor(input_processor_for_qwen) @INPUT_REGISTRY.register_input_processor(input_processor_for_qwen)
class QWenLMHeadModel(nn.Module, SupportsMultiModal): class QWenLMHeadModel(nn.Module, SupportsMultiModal, SupportsPP):
def __init__( def __init__(
self, self,
...@@ -881,6 +884,8 @@ class QWenLMHeadModel(nn.Module, SupportsMultiModal): ...@@ -881,6 +884,8 @@ class QWenLMHeadModel(nn.Module, SupportsMultiModal):
self.lm_head.weight = self.transformer.wte.weight self.lm_head.weight = self.transformer.wte.weight
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
self.make_empty_intermediate_tensors = (
self.transformer.make_empty_intermediate_tensors)
def _get_image_input_type( def _get_image_input_type(
self, self,
...@@ -912,33 +917,26 @@ class QWenLMHeadModel(nn.Module, SupportsMultiModal): ...@@ -912,33 +917,26 @@ class QWenLMHeadModel(nn.Module, SupportsMultiModal):
) )
return None return None
def forward(self, def forward(
input_ids: torch.Tensor, self,
positions: torch.Tensor, input_ids: torch.Tensor,
kv_caches: List[torch.Tensor], positions: torch.Tensor,
attn_metadata: AttentionMetadata, kv_caches: List[torch.Tensor],
intermediate_tensors: Optional[IntermediateTensors] = None, attn_metadata: AttentionMetadata,
pixel_values: Optional[torch.Tensor] = None) -> torch.Tensor: intermediate_tensors: Optional[IntermediateTensors] = None,
pixel_values = self._get_image_input_type(pixel_values) pixel_values: Optional[torch.Tensor] = None
) -> Union[torch.Tensor, IntermediateTensors]:
if intermediate_tensors is not None:
input_ids = None
pixel_values = None
else:
pixel_values = self._get_image_input_type(pixel_values)
hidden_states = self.transformer(input_ids, positions, kv_caches, hidden_states = self.transformer(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors, attn_metadata, intermediate_tensors,
pixel_values) pixel_values)
return hidden_states return hidden_states
def make_empty_intermediate_tensors(
self, batch_size: int, dtype: torch.dtype,
device: torch.device) -> IntermediateTensors:
return IntermediateTensors({
"hidden_states":
torch.zeros((batch_size, self.config.hidden_size),
dtype=dtype,
device=device),
"residual":
torch.zeros((batch_size, self.config.hidden_size),
dtype=dtype,
device=device),
})
def compute_logits( def compute_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
......
...@@ -22,7 +22,7 @@ ...@@ -22,7 +22,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only Qwen2 model compatible with HuggingFace weights.""" """Inference-only Qwen2 model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Tuple from typing import Iterable, List, Optional, Tuple, Union
import torch import torch
from torch import nn from torch import nn
...@@ -37,8 +37,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, ...@@ -37,8 +37,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization import QuantizationConfig
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
...@@ -48,8 +47,9 @@ from vllm.model_executor.model_loader.weight_utils import ( ...@@ -48,8 +47,9 @@ from vllm.model_executor.model_loader.weight_utils import (
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA from .interfaces import SupportsLoRA, SupportsPP
from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers from .utils import (PPMissingLayer, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
class Qwen2MLP(nn.Module): class Qwen2MLP(nn.Module):
...@@ -253,6 +253,9 @@ class Qwen2Model(nn.Module): ...@@ -253,6 +253,9 @@ class Qwen2Model(nn.Module):
prefix=f"{prefix}.layers", prefix=f"{prefix}.layers",
) )
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
if get_pp_group().is_last_rank: if get_pp_group().is_last_rank:
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
else: else:
...@@ -269,7 +272,7 @@ class Qwen2Model(nn.Module): ...@@ -269,7 +272,7 @@ class Qwen2Model(nn.Module):
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank: if get_pp_group().is_first_rank:
if inputs_embeds is not None: if inputs_embeds is not None:
hidden_states = inputs_embeds hidden_states = inputs_embeds
...@@ -298,7 +301,7 @@ class Qwen2Model(nn.Module): ...@@ -298,7 +301,7 @@ class Qwen2Model(nn.Module):
return hidden_states return hidden_states
class Qwen2ForCausalLM(nn.Module, SupportsLoRA): class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
packed_modules_mapping = { packed_modules_mapping = {
"qkv_proj": [ "qkv_proj": [
"q_proj", "q_proj",
...@@ -357,6 +360,8 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA): ...@@ -357,6 +360,8 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA):
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def forward( def forward(
self, self,
...@@ -365,7 +370,7 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA): ...@@ -365,7 +370,7 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA):
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors) attn_metadata, intermediate_tensors)
return hidden_states return hidden_states
...@@ -379,20 +384,6 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA): ...@@ -379,20 +384,6 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA):
sampling_metadata) sampling_metadata)
return logits return logits
def make_empty_intermediate_tensors(
self, batch_size: int, dtype: torch.dtype,
device: torch.device) -> IntermediateTensors:
return IntermediateTensors({
"hidden_states":
torch.zeros((batch_size, self.config.hidden_size),
dtype=dtype,
device=device),
"residual":
torch.zeros((batch_size, self.config.hidden_size),
dtype=dtype,
device=device),
})
def sample( def sample(
self, self,
logits: torch.Tensor, logits: torch.Tensor,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment