Unverified Commit 0f6d7a9a authored by Murali Andoorveedu's avatar Murali Andoorveedu Committed by GitHub
Browse files

[Models] Add remaining model PP support (#7168)


Signed-off-by: default avatarMuralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>
Signed-off-by: default avatarMurali Andoorveedu <muralidhar.andoorveedu@centml.ai>
Co-authored-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent 303d4479
...@@ -22,7 +22,7 @@ ...@@ -22,7 +22,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only Qwen2MoE model compatible with HuggingFace weights.""" """Inference-only Qwen2MoE model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, List, Optional, Tuple from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
...@@ -42,8 +42,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, ...@@ -42,8 +42,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
ReplicatedLinear, ReplicatedLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization import QuantizationConfig
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
...@@ -53,7 +52,9 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata ...@@ -53,7 +52,9 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import print_warning_once from vllm.utils import print_warning_once
from .utils import is_pp_missing_parameter, make_layers from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
class Qwen2MoeMLP(nn.Module): class Qwen2MoeMLP(nn.Module):
...@@ -338,6 +339,9 @@ class Qwen2MoeModel(nn.Module): ...@@ -338,6 +339,9 @@ class Qwen2MoeModel(nn.Module):
prefix=f"{prefix}.layers", prefix=f"{prefix}.layers",
) )
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
def forward( def forward(
self, self,
...@@ -346,7 +350,7 @@ class Qwen2MoeModel(nn.Module): ...@@ -346,7 +350,7 @@ class Qwen2MoeModel(nn.Module):
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank: if get_pp_group().is_first_rank:
hidden_states = self.embed_tokens(input_ids) hidden_states = self.embed_tokens(input_ids)
residual = None residual = None
...@@ -368,7 +372,7 @@ class Qwen2MoeModel(nn.Module): ...@@ -368,7 +372,7 @@ class Qwen2MoeModel(nn.Module):
return hidden_states return hidden_states
class Qwen2MoeForCausalLM(nn.Module): class Qwen2MoeForCausalLM(nn.Module, SupportsPP):
fall_back_to_pt_during_load = False fall_back_to_pt_during_load = False
...@@ -389,6 +393,8 @@ class Qwen2MoeForCausalLM(nn.Module): ...@@ -389,6 +393,8 @@ class Qwen2MoeForCausalLM(nn.Module):
self.lm_head.weight = self.model.embed_tokens.weight self.lm_head.weight = self.model.embed_tokens.weight
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def forward( def forward(
self, self,
...@@ -397,7 +403,7 @@ class Qwen2MoeForCausalLM(nn.Module): ...@@ -397,7 +403,7 @@ class Qwen2MoeForCausalLM(nn.Module):
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors) attn_metadata, intermediate_tensors)
return hidden_states return hidden_states
...@@ -411,20 +417,6 @@ class Qwen2MoeForCausalLM(nn.Module): ...@@ -411,20 +417,6 @@ class Qwen2MoeForCausalLM(nn.Module):
sampling_metadata) sampling_metadata)
return logits return logits
def make_empty_intermediate_tensors(
self, batch_size: int, dtype: torch.dtype,
device: torch.device) -> IntermediateTensors:
return IntermediateTensors({
"hidden_states":
torch.zeros((batch_size, self.config.hidden_size),
dtype=dtype,
device=device),
"residual":
torch.zeros((batch_size, self.config.hidden_size),
dtype=dtype,
device=device),
})
def sample( def sample(
self, self,
logits: Optional[torch.Tensor], logits: Optional[torch.Tensor],
......
...@@ -55,7 +55,6 @@ from vllm.model_executor.layers.quantization import QuantizationConfig ...@@ -55,7 +55,6 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import SupportsMultiModal
from vllm.model_executor.models.qwen2 import Qwen2Model from vllm.model_executor.models.qwen2 import Qwen2Model
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalDataDict, from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalDataDict,
MultiModalInputs) MultiModalInputs)
...@@ -68,6 +67,7 @@ from vllm.transformers_utils.configs.qwen2vl import (Qwen2VLConfig, ...@@ -68,6 +67,7 @@ from vllm.transformers_utils.configs.qwen2vl import (Qwen2VLConfig,
from vllm.transformers_utils.processor import get_processor from vllm.transformers_utils.processor import get_processor
from vllm.utils import is_cpu from vllm.utils import is_cpu
from .interfaces import SupportsMultiModal, SupportsPP
from .utils import (PPMissingLayer, is_pp_missing_parameter, from .utils import (PPMissingLayer, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory) make_empty_intermediate_tensors_factory)
...@@ -883,7 +883,8 @@ def input_processor_for_qwen2_vl(ctx: InputContext, ...@@ -883,7 +883,8 @@ def input_processor_for_qwen2_vl(ctx: InputContext,
"video", get_max_qwen2_vl_video_tokens) "video", get_max_qwen2_vl_video_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_qwen2_vl) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_qwen2_vl)
@INPUT_REGISTRY.register_input_processor(input_processor_for_qwen2_vl) @INPUT_REGISTRY.register_input_processor(input_processor_for_qwen2_vl)
class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal): class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP):
def __init__(self, def __init__(self,
config: Qwen2VLConfig, config: Qwen2VLConfig,
...@@ -1027,7 +1028,7 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -1027,7 +1028,7 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal):
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
**kwargs: object, **kwargs: object,
) -> SamplerOutput: ) -> Union[torch.Tensor, IntermediateTensors]:
"""Run forward pass for Qwen2-VL. """Run forward pass for Qwen2-VL.
Args: Args:
...@@ -1047,41 +1048,43 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal): ...@@ -1047,41 +1048,43 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal):
video_grid_thw: Tensor `(n_videos, 3)` of video 3D grid in LLM. video_grid_thw: Tensor `(n_videos, 3)` of video 3D grid in LLM.
`None` if no videos are passed. `None` if no videos are passed.
""" """
if intermediate_tensors is not None:
image_input = self._parse_and_validate_image_input(**kwargs) input_ids = None
video_input = self._parse_and_validate_video_input(**kwargs)
if (image_input is None
and video_input is None) or not get_pp_group().is_first_rank:
inputs_embeds = None inputs_embeds = None
else: else:
if getattr(self.config, "rope_scaling", {}).get("type", image_input = self._parse_and_validate_image_input(**kwargs)
None) == "mrope": video_input = self._parse_and_validate_video_input(**kwargs)
assert positions.ndim == 2 and positions.size(0) == 3, (
"multimodal section rotary embedding requires "
f"(3, seq_len) positions, but got {positions.size()}")
inputs_embeds = self.model.embed_tokens(input_ids)
if image_input is not None:
image_embeds = self._process_image_input(image_input)
inputs_embeds = self._merge_multimodal_embeddings(
input_ids,
inputs_embeds,
image_embeds,
placeholder_token_id=self.config.image_token_id,
)
if video_input is not None: if image_input is None and video_input is None:
video_embeds = self._process_video_input(video_input) inputs_embeds = None
inputs_embeds = self._merge_multimodal_embeddings( else:
input_ids, rope_scaling = getattr(self.config, "rope_scaling", {})
inputs_embeds, if rope_scaling.get("type", None) == "mrope":
video_embeds, assert positions.ndim == 2 and positions.size(0) == 3, (
placeholder_token_id=self.config.video_token_id, "multimodal section rotary embedding requires "
) f"(3, seq_len) positions, but got {positions.size()}")
input_ids = None inputs_embeds = self.model.embed_tokens(input_ids)
if image_input is not None:
image_embeds = self._process_image_input(image_input)
inputs_embeds = self._merge_multimodal_embeddings(
input_ids,
inputs_embeds,
image_embeds,
placeholder_token_id=self.config.image_token_id,
)
if video_input is not None:
video_embeds = self._process_video_input(video_input)
inputs_embeds = self._merge_multimodal_embeddings(
input_ids,
inputs_embeds,
video_embeds,
placeholder_token_id=self.config.video_token_id,
)
input_ids = None
hidden_states = self.model( hidden_states = self.model(
input_ids=input_ids, input_ids=input_ids,
......
...@@ -246,7 +246,7 @@ class SiglipParallelAttention(nn.Module): ...@@ -246,7 +246,7 @@ class SiglipParallelAttention(nn.Module):
def __init__( def __init__(
self, self,
config, config: SiglipVisionConfig,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
...@@ -312,7 +312,7 @@ class SiglipMLP(nn.Module): ...@@ -312,7 +312,7 @@ class SiglipMLP(nn.Module):
def __init__( def __init__(
self, self,
config, config: SiglipVisionConfig,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
): ):
super().__init__() super().__init__()
......
...@@ -26,6 +26,7 @@ from typing import Any, Dict, Iterable, List, Optional, Tuple, Union ...@@ -26,6 +26,7 @@ from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
import torch import torch
from torch import nn from torch import nn
from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig from vllm.config import CacheConfig, LoRAConfig
...@@ -37,8 +38,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, ...@@ -37,8 +38,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization import QuantizationConfig
QuantizationConfig)
from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
get_compressed_tensors_cache_scale) get_compressed_tensors_cache_scale)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
...@@ -47,14 +47,14 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -47,14 +47,14 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name) default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.models.interfaces import SupportsLoRA
from vllm.model_executor.models.utils import (PPMissingLayer,
is_pp_missing_parameter,
make_layers)
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import is_hip from vllm.utils import is_hip
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (PPMissingLayer, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
class SolarMLP(nn.Module): class SolarMLP(nn.Module):
...@@ -98,7 +98,7 @@ class SolarAttention(nn.Module): ...@@ -98,7 +98,7 @@ class SolarAttention(nn.Module):
def __init__( def __init__(
self, self,
config, config: PretrainedConfig,
hidden_size: int, hidden_size: int,
num_heads: int, num_heads: int,
num_kv_heads: int, num_kv_heads: int,
...@@ -187,7 +187,7 @@ class SolarDecoderLayer(nn.Module): ...@@ -187,7 +187,7 @@ class SolarDecoderLayer(nn.Module):
def __init__( def __init__(
self, self,
config, config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "", prefix: str = "",
...@@ -267,7 +267,7 @@ class SolarModel(nn.Module): ...@@ -267,7 +267,7 @@ class SolarModel(nn.Module):
def __init__( def __init__(
self, self,
config, config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None, lora_config: Optional[LoRAConfig] = None,
...@@ -304,6 +304,10 @@ class SolarModel(nn.Module): ...@@ -304,6 +304,10 @@ class SolarModel(nn.Module):
else: else:
self.norm = PPMissingLayer() self.norm = PPMissingLayer()
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids) return self.embed_tokens(input_ids)
...@@ -368,7 +372,7 @@ class SolarModel(nn.Module): ...@@ -368,7 +372,7 @@ class SolarModel(nn.Module):
return hidden_states return hidden_states
class SolarForCausalLM(nn.Module, SupportsLoRA): class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
packed_modules_mapping = { packed_modules_mapping = {
"qkv_proj": [ "qkv_proj": [
"q_proj", "q_proj",
...@@ -406,7 +410,7 @@ class SolarForCausalLM(nn.Module, SupportsLoRA): ...@@ -406,7 +410,7 @@ class SolarForCausalLM(nn.Module, SupportsLoRA):
def __init__( def __init__(
self, self,
config, config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None, lora_config: Optional[LoRAConfig] = None,
...@@ -448,6 +452,9 @@ class SolarForCausalLM(nn.Module, SupportsLoRA): ...@@ -448,6 +452,9 @@ class SolarForCausalLM(nn.Module, SupportsLoRA):
else: else:
self.lm_head = PPMissingLayer() self.lm_head = PPMissingLayer()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
...@@ -474,24 +481,6 @@ class SolarForCausalLM(nn.Module, SupportsLoRA): ...@@ -474,24 +481,6 @@ class SolarForCausalLM(nn.Module, SupportsLoRA):
next_tokens = self.sampler(logits, sampling_metadata) next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens return next_tokens
def make_empty_intermediate_tensors(
self, batch_size: int, dtype: torch.dtype,
device: torch.device) -> IntermediateTensors:
return IntermediateTensors({
"hidden_states":
torch.zeros(
(batch_size, self.config.hidden_size),
dtype=dtype,
device=device,
),
"residual":
torch.zeros(
(batch_size, self.config.hidden_size),
dtype=dtype,
device=device,
),
})
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
......
...@@ -19,7 +19,7 @@ ...@@ -19,7 +19,7 @@
# https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/config.json # https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/config.json
"""Inference-only StabeLM (https://github.com/Stability-AI/StableLM) """Inference-only StabeLM (https://github.com/Stability-AI/StableLM)
model compatible with HuggingFace weights.""" model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Tuple from typing import Iterable, List, Optional, Tuple, Union
import torch import torch
from torch import nn from torch import nn
...@@ -27,14 +27,13 @@ from transformers import PretrainedConfig ...@@ -27,14 +27,13 @@ from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig from vllm.config import CacheConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization import QuantizationConfig
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
...@@ -43,6 +42,10 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader ...@@ -43,6 +42,10 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
class StablelmMLP(nn.Module): class StablelmMLP(nn.Module):
...@@ -194,19 +197,25 @@ class StableLMEpochModel(nn.Module): ...@@ -194,19 +197,25 @@ class StableLMEpochModel(nn.Module):
def __init__(self, def __init__(self,
config: PretrainedConfig, config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None) -> None: quant_config: Optional[QuantizationConfig] = None,
prefix: str = '') -> None:
super().__init__() super().__init__()
self.embed_tokens = VocabParallelEmbedding( self.embed_tokens = VocabParallelEmbedding(
config.vocab_size, config.vocab_size,
config.hidden_size, config.hidden_size,
) )
self.layers = nn.ModuleList([ self.start_layer, self.end_layer, self.layers = make_layers(
StablelmDecoderLayer(config, cache_config, quant_config) config.num_hidden_layers,
for _ in range(config.num_hidden_layers) lambda prefix: StablelmDecoderLayer(config, cache_config,
]) quant_config),
prefix=f"{prefix}.layers",
)
norm_eps = getattr(config, "norm_eps", norm_eps = getattr(config, "norm_eps",
getattr(config, "layer_norm_eps", 1e-05)) getattr(config, "layer_norm_eps", 1e-05))
self.norm = nn.LayerNorm(config.hidden_size, eps=norm_eps) self.norm = nn.LayerNorm(config.hidden_size, eps=norm_eps)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(["hidden_states"],
config.hidden_size))
def forward( def forward(
self, self,
...@@ -214,21 +223,28 @@ class StableLMEpochModel(nn.Module): ...@@ -214,21 +223,28 @@ class StableLMEpochModel(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: intermediate_tensors: Optional[IntermediateTensors],
hidden_states = self.embed_tokens(input_ids) ) -> Union[torch.Tensor, IntermediateTensors]:
for i in range(len(self.layers)): if get_pp_group().is_first_rank:
hidden_states = self.embed_tokens(input_ids)
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i] layer = self.layers[i]
hidden_states, residual = layer( hidden_states, residual = layer(
positions, positions,
hidden_states, hidden_states,
kv_caches[i], kv_caches[i - self.start_layer],
attn_metadata, attn_metadata,
) )
if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states})
hidden_states = self.norm(hidden_states) hidden_states = self.norm(hidden_states)
return hidden_states return hidden_states
class StablelmForCausalLM(nn.Module): class StablelmForCausalLM(nn.Module, SupportsPP):
def __init__( def __init__(
self, self,
...@@ -247,6 +263,8 @@ class StablelmForCausalLM(nn.Module): ...@@ -247,6 +263,8 @@ class StablelmForCausalLM(nn.Module):
self.lm_head.weight = self.model.embed_tokens.weight self.lm_head.weight = self.model.embed_tokens.weight
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def forward( def forward(
self, self,
...@@ -255,9 +273,9 @@ class StablelmForCausalLM(nn.Module): ...@@ -255,9 +273,9 @@ class StablelmForCausalLM(nn.Module):
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata) attn_metadata, intermediate_tensors)
return hidden_states return hidden_states
def compute_logits( def compute_logits(
...@@ -302,6 +320,8 @@ class StablelmForCausalLM(nn.Module): ...@@ -302,6 +320,8 @@ class StablelmForCausalLM(nn.Module):
# Skip loading extra bias for GPTQ models. # Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
continue continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name] param = params_dict[name]
weight_loader = param.weight_loader weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id) weight_loader(param, loaded_weight, shard_id)
...@@ -310,6 +330,8 @@ class StablelmForCausalLM(nn.Module): ...@@ -310,6 +330,8 @@ class StablelmForCausalLM(nn.Module):
# Skip loading extra bias for GPTQ models. # Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
continue continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name] param = params_dict[name]
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
......
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
""" PyTorch Starcoder2 model.""" """ PyTorch Starcoder2 model."""
from typing import Iterable, List, Optional, Tuple from typing import Iterable, List, Optional, Tuple, Union
import torch import torch
from torch import nn from torch import nn
...@@ -26,14 +26,13 @@ from transformers import Starcoder2Config ...@@ -26,14 +26,13 @@ from transformers import Starcoder2Config
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig from vllm.config import CacheConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization import QuantizationConfig
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
...@@ -42,6 +41,10 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader ...@@ -42,6 +41,10 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
class Starcoder2Attention(nn.Module): class Starcoder2Attention(nn.Module):
...@@ -195,7 +198,8 @@ class Starcoder2Model(nn.Module): ...@@ -195,7 +198,8 @@ class Starcoder2Model(nn.Module):
def __init__(self, def __init__(self,
config: Starcoder2Config, config: Starcoder2Config,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None): quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__() super().__init__()
self.config = config self.config = config
self.padding_idx = config.pad_token_id self.padding_idx = config.pad_token_id
...@@ -204,13 +208,16 @@ class Starcoder2Model(nn.Module): ...@@ -204,13 +208,16 @@ class Starcoder2Model(nn.Module):
# TODO: consider padding_idx (currently removed) # TODO: consider padding_idx (currently removed)
self.embed_tokens = VocabParallelEmbedding(config.vocab_size, self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
config.hidden_size) config.hidden_size)
self.layers = nn.ModuleList([ self.start_layer, self.end_layer, self.layers = make_layers(
Starcoder2DecoderLayer(config, config.num_hidden_layers,
cache_config, lambda prefix: Starcoder2DecoderLayer(
quant_config=quant_config) config, cache_config, quant_config=quant_config),
for _ in range(config.num_hidden_layers) prefix=f"{prefix}.layers",
]) )
self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon) self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(["hidden_states"],
config.hidden_size))
def forward( def forward(
self, self,
...@@ -218,17 +225,25 @@ class Starcoder2Model(nn.Module): ...@@ -218,17 +225,25 @@ class Starcoder2Model(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: intermediate_tensors: Optional[IntermediateTensors],
hidden_states = self.embed_tokens(input_ids) ) -> Union[torch.Tensor, IntermediateTensors]:
for i in range(len(self.layers)): if get_pp_group().is_first_rank:
hidden_states = self.embed_tokens(input_ids)
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i] layer = self.layers[i]
hidden_states = layer(positions, hidden_states, kv_caches[i], hidden_states = layer(positions, hidden_states,
kv_caches[i - self.start_layer],
attn_metadata) attn_metadata)
if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states})
hidden_states = self.norm(hidden_states) hidden_states = self.norm(hidden_states)
return hidden_states return hidden_states
class Starcoder2ForCausalLM(nn.Module): class Starcoder2ForCausalLM(nn.Module, SupportsPP):
def __init__(self, def __init__(self,
config: Starcoder2Config, config: Starcoder2Config,
...@@ -255,6 +270,8 @@ class Starcoder2ForCausalLM(nn.Module): ...@@ -255,6 +270,8 @@ class Starcoder2ForCausalLM(nn.Module):
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size) config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def forward( def forward(
self, self,
...@@ -263,9 +280,9 @@ class Starcoder2ForCausalLM(nn.Module): ...@@ -263,9 +280,9 @@ class Starcoder2ForCausalLM(nn.Module):
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata) attn_metadata, intermediate_tensors)
return hidden_states return hidden_states
def compute_logits( def compute_logits(
...@@ -302,6 +319,8 @@ class Starcoder2ForCausalLM(nn.Module): ...@@ -302,6 +319,8 @@ class Starcoder2ForCausalLM(nn.Module):
if weight_name not in name: if weight_name not in name:
continue continue
name = name.replace(weight_name, param_name) name = name.replace(weight_name, param_name)
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name] param = params_dict[name]
weight_loader = param.weight_loader weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id) weight_loader(param, loaded_weight, shard_id)
...@@ -309,6 +328,8 @@ class Starcoder2ForCausalLM(nn.Module): ...@@ -309,6 +328,8 @@ class Starcoder2ForCausalLM(nn.Module):
else: else:
if self.config.tie_word_embeddings and "lm_head.weight" in name: if self.config.tie_word_embeddings and "lm_head.weight" in name:
continue continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name] param = params_dict[name]
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
import math import math
from array import array from array import array
from functools import lru_cache from functools import cached_property, lru_cache
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
TypedDict, Union, cast) TypedDict, Union, cast)
...@@ -22,12 +22,10 @@ from vllm.inputs.data import LLMInputs ...@@ -22,12 +22,10 @@ from vllm.inputs.data import LLMInputs
from vllm.inputs.registry import InputContext from vllm.inputs.registry import InputContext
from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization import QuantizationConfig
QuantizationConfig) from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader.loader import DefaultModelLoader from vllm.model_executor.model_loader.loader import DefaultModelLoader
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import SupportsMultiModal
from vllm.model_executor.models.utils import (flatten_bn, from vllm.model_executor.models.utils import (flatten_bn,
group_weights_with_prefix, group_weights_with_prefix,
init_vllm_registered_model, init_vllm_registered_model,
...@@ -37,9 +35,12 @@ from vllm.multimodal import MULTIMODAL_REGISTRY ...@@ -37,9 +35,12 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.base import MultiModalInputs, NestedTensors from vllm.multimodal.base import MultiModalInputs, NestedTensors
from vllm.multimodal.utils import (cached_get_tokenizer, from vllm.multimodal.utils import (cached_get_tokenizer,
repeat_and_pad_placeholder_tokens) repeat_and_pad_placeholder_tokens)
from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
SequenceData)
from vllm.transformers_utils.configs.ultravox import UltravoxConfig from vllm.transformers_utils.configs.ultravox import UltravoxConfig
from .interfaces import SupportsMultiModal, SupportsPP
_AUDIO_PLACEHOLDER_TOKEN = 128002 _AUDIO_PLACEHOLDER_TOKEN = 128002
_AUDIO_TOKENS_PER_SECOND = 6.25 _AUDIO_TOKENS_PER_SECOND = 6.25
...@@ -323,7 +324,7 @@ class ModifiedWhisperEncoder(WhisperEncoder): ...@@ -323,7 +324,7 @@ class ModifiedWhisperEncoder(WhisperEncoder):
"audio", get_ultravox_max_audio_tokens) "audio", get_ultravox_max_audio_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_ultravox) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_ultravox)
@INPUT_REGISTRY.register_input_processor(input_processor_for_ultravox) @INPUT_REGISTRY.register_input_processor(input_processor_for_ultravox)
class UltravoxModel(nn.Module, SupportsMultiModal): class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP):
def __init__(self, def __init__(self,
config: UltravoxConfig, config: UltravoxConfig,
...@@ -353,6 +354,16 @@ class UltravoxModel(nn.Module, SupportsMultiModal): ...@@ -353,6 +354,16 @@ class UltravoxModel(nn.Module, SupportsMultiModal):
revision=None, revision=None,
prefix="language_model.")) prefix="language_model."))
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors)
@cached_property
def sampler(self):
if hasattr(self.language_model, "sampler"):
return self.language_model.sampler
return Sampler()
def _audio_features_to_embeddings( def _audio_features_to_embeddings(
self, input_features: torch.Tensor) -> torch.Tensor: self, input_features: torch.Tensor) -> torch.Tensor:
audio_input = input_features.to(self.audio_tower.dtype) audio_input = input_features.to(self.audio_tower.dtype)
...@@ -425,7 +436,7 @@ class UltravoxModel(nn.Module, SupportsMultiModal): ...@@ -425,7 +436,7 @@ class UltravoxModel(nn.Module, SupportsMultiModal):
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[torch.Tensor], intermediate_tensors: Optional[torch.Tensor],
**kwargs) -> SamplerOutput: **kwargs) -> Union[torch.Tensor, IntermediateTensors]:
"""Run forward pass for Ultravox """Run forward pass for Ultravox
One key thing to understand is the `input_ids` already accounts for the One key thing to understand is the `input_ids` already accounts for the
...@@ -438,18 +449,22 @@ class UltravoxModel(nn.Module, SupportsMultiModal): ...@@ -438,18 +449,22 @@ class UltravoxModel(nn.Module, SupportsMultiModal):
Args: Args:
audio_features: A batch of audio inputs [B, N, 80, M]. audio_features: A batch of audio inputs [B, N, 80, M].
""" """
audio_input = self._parse_and_validate_audio_input(**kwargs) if intermediate_tensors is not None:
if audio_input is not None:
audio_embeddings = self._process_audio_input(audio_input)
inputs_embeds = self.language_model.model.get_input_embeddings(
input_ids)
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, audio_embeddings,
_AUDIO_PLACEHOLDER_TOKEN)
input_ids = None input_ids = None
else:
inputs_embeds = None inputs_embeds = None
else:
audio_input = self._parse_and_validate_audio_input(**kwargs)
if audio_input is not None:
audio_embeddings = self._process_audio_input(audio_input)
inputs_embeds = self.language_model.model.get_input_embeddings(
input_ids)
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, audio_embeddings,
_AUDIO_PLACEHOLDER_TOKEN)
input_ids = None
else:
inputs_embeds = None
hidden_states = self.language_model.model( hidden_states = self.language_model.model(
input_ids=input_ids, input_ids=input_ids,
......
...@@ -24,7 +24,7 @@ class WeightsGroup(UserDict): ...@@ -24,7 +24,7 @@ class WeightsGroup(UserDict):
when attempting to access a weight component that does not exist. when attempting to access a weight component that does not exist.
""" """
def __getitem__(self, key: str) -> int: def __getitem__(self, key: str) -> Iterable[Tuple[str, torch.Tensor]]:
try: try:
return super().__getitem__(key) return super().__getitem__(key)
except KeyError as exc: except KeyError as exc:
...@@ -49,8 +49,7 @@ def filter_weights(weights: Iterable[Tuple[str, torch.Tensor]], ...@@ -49,8 +49,7 @@ def filter_weights(weights: Iterable[Tuple[str, torch.Tensor]],
def group_weights_with_prefix( def group_weights_with_prefix(
weights: Iterable[Tuple[str, torch.Tensor]] weights: Iterable[Tuple[str, torch.Tensor]], ) -> WeightsGroup:
) -> Dict[str, Iterable[Tuple[str, torch.Tensor]]]:
""" """
Helper function to group weights with prefix Helper function to group weights with prefix
""" """
...@@ -183,10 +182,7 @@ def merge_multimodal_embeddings(input_ids: torch.Tensor, ...@@ -183,10 +182,7 @@ def merge_multimodal_embeddings(input_ids: torch.Tensor,
class LayerFn(Protocol): class LayerFn(Protocol):
def __call__( def __call__(self, prefix: str) -> torch.nn.Module:
self,
prefix="",
) -> torch.nn.Module:
... ...
...@@ -319,8 +315,10 @@ def is_pp_missing_parameter(name: str, model: torch.nn.Module) -> bool: ...@@ -319,8 +315,10 @@ def is_pp_missing_parameter(name: str, model: torch.nn.Module) -> bool:
def make_empty_intermediate_tensors_factory(keys: List[str], hidden_size: int): def make_empty_intermediate_tensors_factory(keys: List[str], hidden_size: int):
def make_empty_intermediate_tensors( def make_empty_intermediate_tensors(
batch_size: int, dtype: torch.dtype, batch_size: int,
device: torch.device) -> IntermediateTensors: dtype: torch.dtype,
device: torch.device,
) -> IntermediateTensors:
return IntermediateTensors({ return IntermediateTensors({
key: torch.zeros((batch_size, hidden_size), key: torch.zeros((batch_size, hidden_size),
dtype=dtype, dtype=dtype,
...@@ -342,8 +340,14 @@ class LLMWrapper(nn.Module): ...@@ -342,8 +340,14 @@ class LLMWrapper(nn.Module):
self.model_name = name self.model_name = name
setattr(self, name, llm) setattr(self, name, llm)
def forward(self, *args, **kwargs) -> Any: def __getattr__(self, key: str):
return getattr(self, self.model_name)(*args, **kwargs) llm = super().__getattr__(self.model_name)
if key == self.model_name:
return llm
def embed_tokens(self, *args, **kwargs) -> Any: return getattr(llm, key)
return getattr(self, self.model_name).embed_tokens(*args, **kwargs)
# We need to explicitly override this
def __call__(self, *args: Any, **kwargs: Any) -> Any:
llm = super().__getattr__(self.model_name)
return llm(*args, **kwargs)
...@@ -20,7 +20,7 @@ ...@@ -20,7 +20,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only Xverse model compatible with HuggingFace weights.""" """Inference-only Xverse model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, List, Optional, Tuple from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
import torch import torch
from torch import nn from torch import nn
...@@ -28,15 +28,14 @@ from transformers import PretrainedConfig ...@@ -28,15 +28,14 @@ from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization import QuantizationConfig
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
...@@ -45,7 +44,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader ...@@ -45,7 +44,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA from .interfaces import SupportsLoRA, SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers)
class XverseMLP(nn.Module): class XverseMLP(nn.Module):
...@@ -227,6 +228,7 @@ class XverseModel(nn.Module): ...@@ -227,6 +228,7 @@ class XverseModel(nn.Module):
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None, lora_config: Optional[LoRAConfig] = None,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -240,11 +242,16 @@ class XverseModel(nn.Module): ...@@ -240,11 +242,16 @@ class XverseModel(nn.Module):
config.hidden_size, config.hidden_size,
org_num_embeddings=config.vocab_size, org_num_embeddings=config.vocab_size,
) )
self.layers = nn.ModuleList([ self.start_layer, self.end_layer, self.layers = make_layers(
XverseDecoderLayer(config, cache_config, quant_config) config.num_hidden_layers,
for _ in range(config.num_hidden_layers) lambda prefix: XverseDecoderLayer(config, cache_config,
]) quant_config),
prefix=f"{prefix}.layers",
)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
def forward( def forward(
self, self,
...@@ -252,23 +259,32 @@ class XverseModel(nn.Module): ...@@ -252,23 +259,32 @@ class XverseModel(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: intermediate_tensors: Optional[IntermediateTensors],
hidden_states = self.embed_tokens(input_ids) ) -> Union[torch.Tensor, IntermediateTensors]:
residual = None if get_pp_group().is_first_rank:
for i in range(len(self.layers)): hidden_states = self.embed_tokens(input_ids)
residual = None
else:
hidden_states = intermediate_tensors["hidden_states"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i] layer = self.layers[i]
hidden_states, residual = layer( hidden_states, residual = layer(
positions, positions,
hidden_states, hidden_states,
kv_caches[i], kv_caches[i - self.start_layer],
attn_metadata, attn_metadata,
residual, residual,
) )
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
hidden_states, _ = self.norm(hidden_states, residual) hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states return hidden_states
class XverseForCausalLM(nn.Module, SupportsLoRA): class XverseForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
packed_modules_mapping = { packed_modules_mapping = {
"qkv_proj": [ "qkv_proj": [
"q_proj", "q_proj",
...@@ -317,6 +333,8 @@ class XverseForCausalLM(nn.Module, SupportsLoRA): ...@@ -317,6 +333,8 @@ class XverseForCausalLM(nn.Module, SupportsLoRA):
self.lm_head.weight = self.model.embed_tokens.weight self.lm_head.weight = self.model.embed_tokens.weight
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
def forward( def forward(
self, self,
...@@ -325,9 +343,9 @@ class XverseForCausalLM(nn.Module, SupportsLoRA): ...@@ -325,9 +343,9 @@ class XverseForCausalLM(nn.Module, SupportsLoRA):
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata) attn_metadata, intermediate_tensors)
return hidden_states return hidden_states
def compute_logits( def compute_logits(
...@@ -368,6 +386,8 @@ class XverseForCausalLM(nn.Module, SupportsLoRA): ...@@ -368,6 +386,8 @@ class XverseForCausalLM(nn.Module, SupportsLoRA):
# Skip loading extra bias for GPTQ models. # Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
continue continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name] param = params_dict[name]
weight_loader = param.weight_loader weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id) weight_loader(param, loaded_weight, shard_id)
...@@ -376,6 +396,8 @@ class XverseForCausalLM(nn.Module, SupportsLoRA): ...@@ -376,6 +396,8 @@ class XverseForCausalLM(nn.Module, SupportsLoRA):
# Skip loading extra bias for GPTQ models. # Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
continue continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name] param = params_dict[name]
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment