Commit c5832d2a (unverified), authored by Murali Andoorveedu, committed by GitHub
Browse files

[Core] Pipeline Parallel Support (#4412)


Signed-off-by: Muralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>
parent 15aba081
...@@ -44,7 +44,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -44,7 +44,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput from vllm.sequence import IntermediateTensors, SamplerOutput
from vllm.transformers_utils.configs import RWConfig from vllm.transformers_utils.configs import RWConfig
FalconConfig = Union[HF_FalconConfig, RWConfig] FalconConfig = Union[HF_FalconConfig, RWConfig]
...@@ -410,6 +410,7 @@ class FalconForCausalLM(nn.Module): ...@@ -410,6 +410,7 @@ class FalconForCausalLM(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.transformer( hidden_states = self.transformer(
input_ids, input_ids,
......
...@@ -39,7 +39,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -39,7 +39,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding) VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput from vllm.sequence import IntermediateTensors, SamplerOutput
from .interfaces import SupportsLoRA from .interfaces import SupportsLoRA
...@@ -339,6 +339,7 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA): ...@@ -339,6 +339,7 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA):
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata) attn_metadata)
......
...@@ -37,7 +37,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -37,7 +37,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding) VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput from vllm.sequence import IntermediateTensors, SamplerOutput
from vllm.utils import print_warning_once from vllm.utils import print_warning_once
from .interfaces import SupportsLoRA from .interfaces import SupportsLoRA
...@@ -338,6 +338,7 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA): ...@@ -338,6 +338,7 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA):
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata) attn_metadata)
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only GPT-2 model compatible with HuggingFace weights.""" """Inference-only GPT-2 model compatible with HuggingFace weights."""
from typing import Iterable, List, Optional, Tuple from typing import Iterable, List, Optional, Tuple, Union
import torch import torch
from torch import nn from torch import nn
...@@ -25,7 +25,9 @@ from transformers import GPT2Config ...@@ -25,7 +25,9 @@ from transformers import GPT2Config
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig from vllm.config import CacheConfig
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed.parallel_state import (
get_pp_group, get_tensor_model_parallel_world_size)
from vllm.distributed.utils import get_pp_indices
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
...@@ -38,7 +40,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -38,7 +40,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding) VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput from vllm.sequence import IntermediateTensors, SamplerOutput
class GPT2Attention(nn.Module): class GPT2Attention(nn.Module):
...@@ -181,10 +183,18 @@ class GPT2Model(nn.Module): ...@@ -181,10 +183,18 @@ class GPT2Model(nn.Module):
self.embed_dim = config.hidden_size self.embed_dim = config.hidden_size
self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim) self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim)
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
self.h = nn.ModuleList([ self.start_layer, self.end_layer = get_pp_indices(
GPT2Block(config, cache_config, quant_config) config.num_hidden_layers,
for _ in range(config.num_hidden_layers) get_pp_group().rank_in_group,
]) get_pp_group().world_size)
self.h = nn.ModuleList(
[nn.Identity() for _ in range(self.start_layer)] + [
GPT2Block(config, cache_config, quant_config)
for _ in range(self.start_layer, self.end_layer)
] + [
nn.Identity()
for _ in range(self.end_layer, config.num_hidden_layers)
])
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
def forward( def forward(
...@@ -193,14 +203,24 @@ class GPT2Model(nn.Module): ...@@ -193,14 +203,24 @@ class GPT2Model(nn.Module):
position_ids: torch.Tensor, position_ids: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds = self.wte(input_ids) ) -> Union[torch.Tensor, IntermediateTensors]:
position_embeds = self.wpe(position_ids) if get_pp_group().is_first_rank:
hidden_states = inputs_embeds + position_embeds inputs_embeds = self.wte(input_ids)
position_embeds = self.wpe(position_ids)
hidden_states = inputs_embeds + position_embeds
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
for i in range(len(self.h)): for i in range(self.start_layer, self.end_layer):
layer = self.h[i] layer = self.h[i]
hidden_states = layer(hidden_states, kv_caches[i], attn_metadata) hidden_states = layer(hidden_states,
kv_caches[i - self.start_layer],
attn_metadata)
if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states})
hidden_states = self.ln_f(hidden_states) hidden_states = self.ln_f(hidden_states)
return hidden_states return hidden_states
...@@ -228,9 +248,10 @@ class GPT2LMHeadModel(nn.Module): ...@@ -228,9 +248,10 @@ class GPT2LMHeadModel(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, kv_caches, hidden_states = self.transformer(input_ids, positions, kv_caches,
attn_metadata) attn_metadata, intermediate_tensors)
return hidden_states return hidden_states
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(self, hidden_states: torch.Tensor,
...@@ -247,6 +268,16 @@ class GPT2LMHeadModel(nn.Module): ...@@ -247,6 +268,16 @@ class GPT2LMHeadModel(nn.Module):
next_tokens = self.sampler(logits, sampling_metadata) next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens return next_tokens
# Build an IntermediateTensors holding a single zero-filled "hidden_states"
# tensor of shape (batch_size, self.config.hidden_size), created with the
# requested dtype on the requested device.
# The key matches the "hidden_states" entry that GPT2Model.forward reads from
# `intermediate_tensors` and returns when it is not the last pipeline rank.
# NOTE(review): presumably used as a placeholder/receive buffer for ranks
# that are not the first pipeline stage -- confirm against the caller.
def make_empty_intermediate_tensors(
self, batch_size: int, dtype: torch.dtype,
device: torch.device) -> IntermediateTensors:
return IntermediateTensors({
"hidden_states":
torch.zeros((batch_size, self.config.hidden_size),
dtype=dtype,
device=device),
})
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
params_dict = dict(self.named_parameters(remove_duplicate=False)) params_dict = dict(self.named_parameters(remove_duplicate=False))
for name, loaded_weight in weights: for name, loaded_weight in weights:
...@@ -260,16 +291,19 @@ class GPT2LMHeadModel(nn.Module): ...@@ -260,16 +291,19 @@ class GPT2LMHeadModel(nn.Module):
continue continue
if not name.startswith("transformer."): if not name.startswith("transformer."):
name = "transformer." + name name = "transformer." + name
param = params_dict[name] try:
# The HF's GPT-2 implementation uses Conv1D instead of Linear. param = params_dict[name]
# Because of this, we need to transpose the weights. # The HF's GPT-2 implementation uses Conv1D instead of Linear.
# Note(zhuohan): the logic below might break quantized models. # Because of this, we need to transpose the weights.
for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]: # Note(zhuohan): the logic below might break quantized models.
if conv1d_weight_name not in name: for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]:
continue if conv1d_weight_name not in name:
if not name.endswith(".weight"): continue
continue if not name.endswith(".weight"):
loaded_weight = loaded_weight.t() continue
weight_loader = getattr(param, "weight_loader", loaded_weight = loaded_weight.t()
default_weight_loader) weight_loader = getattr(param, "weight_loader",
weight_loader(param, loaded_weight) default_weight_loader)
weight_loader(param, loaded_weight)
except KeyError:
continue
...@@ -39,7 +39,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -39,7 +39,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding) VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput from vllm.sequence import IntermediateTensors, SamplerOutput
from .interfaces import SupportsLoRA from .interfaces import SupportsLoRA
...@@ -273,6 +273,7 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA): ...@@ -273,6 +273,7 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA):
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, kv_caches, hidden_states = self.transformer(input_ids, positions, kv_caches,
attn_metadata) attn_metadata)
......
...@@ -38,7 +38,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -38,7 +38,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput from vllm.sequence import IntermediateTensors, SamplerOutput
class GPTJAttention(nn.Module): class GPTJAttention(nn.Module):
...@@ -239,6 +239,7 @@ class GPTJForCausalLM(nn.Module): ...@@ -239,6 +239,7 @@ class GPTJForCausalLM(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, kv_caches, hidden_states = self.transformer(input_ids, positions, kv_caches,
attn_metadata) attn_metadata)
......
...@@ -38,7 +38,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -38,7 +38,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput from vllm.sequence import IntermediateTensors, SamplerOutput
class GPTNeoXAttention(nn.Module): class GPTNeoXAttention(nn.Module):
...@@ -251,6 +251,7 @@ class GPTNeoXForCausalLM(nn.Module): ...@@ -251,6 +251,7 @@ class GPTNeoXForCausalLM(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.gpt_neox(input_ids, positions, kv_caches, hidden_states = self.gpt_neox(input_ids, positions, kv_caches,
attn_metadata) attn_metadata)
......
...@@ -22,7 +22,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -22,7 +22,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput from vllm.sequence import IntermediateTensors, SamplerOutput
class InternLM2MLP(nn.Module): class InternLM2MLP(nn.Module):
...@@ -263,6 +263,7 @@ class InternLM2ForCausalLM(nn.Module): ...@@ -263,6 +263,7 @@ class InternLM2ForCausalLM(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: IntermediateTensors,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata) attn_metadata)
......
...@@ -40,7 +40,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -40,7 +40,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding) VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput from vllm.sequence import IntermediateTensors, SamplerOutput
from vllm.transformers_utils.configs import JAISConfig from vllm.transformers_utils.configs import JAISConfig
...@@ -289,6 +289,7 @@ class JAISLMHeadModel(nn.Module): ...@@ -289,6 +289,7 @@ class JAISLMHeadModel(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, kv_caches, hidden_states = self.transformer(input_ids, positions, kv_caches,
attn_metadata) attn_metadata)
......
...@@ -21,7 +21,7 @@ ...@@ -21,7 +21,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only LLaMA model compatible with HuggingFace weights.""" """Inference-only LLaMA model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, List, Optional, Tuple from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
import torch import torch
from torch import nn from torch import nn
...@@ -29,7 +29,8 @@ from transformers import LlamaConfig ...@@ -29,7 +29,8 @@ from transformers import LlamaConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import (get_tensor_model_parallel_rank, from vllm.distributed import (get_pp_group, get_pp_indices,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size) get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
...@@ -46,7 +47,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -46,7 +47,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, kv_cache_scales_loader) default_weight_loader, kv_cache_scales_loader)
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput from vllm.sequence import IntermediateTensors, SamplerOutput
from vllm.utils import is_hip, print_warning_once from vllm.utils import is_hip, print_warning_once
from .interfaces import SupportsLoRA from .interfaces import SupportsLoRA
...@@ -261,12 +262,20 @@ class LlamaModel(nn.Module): ...@@ -261,12 +262,20 @@ class LlamaModel(nn.Module):
config.hidden_size, config.hidden_size,
org_num_embeddings=config.vocab_size, org_num_embeddings=config.vocab_size,
) )
self.layers = nn.ModuleList([ self.start_layer, self.end_layer = get_pp_indices(
LlamaDecoderLayer(config=config, config.num_hidden_layers,
cache_config=cache_config, get_pp_group().rank_in_group,
quant_config=quant_config) get_pp_group().world_size)
for idx in range(config.num_hidden_layers) self.layers = nn.ModuleList(
]) [nn.Identity() for _ in range(self.start_layer)] + [
LlamaDecoderLayer(config=config,
cache_config=cache_config,
quant_config=quant_config)
for _ in range(self.start_layer, self.end_layer)
] + [
nn.Identity()
for _ in range(self.end_layer, config.num_hidden_layers)
])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
...@@ -278,22 +287,36 @@ class LlamaModel(nn.Module): ...@@ -278,22 +287,36 @@ class LlamaModel(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> Union[torch.Tensor, IntermediateTensors]:
if inputs_embeds is not None: if get_pp_group().is_first_rank:
hidden_states = inputs_embeds if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
residual = None
else: else:
hidden_states = self.get_input_embeddings(input_ids) assert intermediate_tensors is not None
residual = None hidden_states = intermediate_tensors["hidden_states"]
for i in range(len(self.layers)): residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i] layer = self.layers[i]
hidden_states, residual = layer( hidden_states, residual = layer(
positions, positions,
hidden_states, hidden_states,
kv_caches[i], kv_caches[i - self.start_layer],
attn_metadata, attn_metadata,
residual, residual,
) )
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
hidden_states, _ = self.norm(hidden_states, residual) hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states return hidden_states
...@@ -372,10 +395,11 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA): ...@@ -372,10 +395,11 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: intermediate_tensors: Optional[IntermediateTensors] = None,
hidden_states = self.model(input_ids, positions, kv_caches, ) -> Union[torch.Tensor, IntermediateTensors]:
attn_metadata) model_output = self.model(input_ids, positions, kv_caches,
return hidden_states attn_metadata, intermediate_tensors)
return model_output
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor: sampling_metadata: SamplingMetadata) -> torch.Tensor:
...@@ -391,6 +415,20 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA): ...@@ -391,6 +415,20 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
next_tokens = self.sampler(logits, sampling_metadata) next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens return next_tokens
# Build an IntermediateTensors holding two zero-filled tensors,
# "hidden_states" and "residual", each of shape
# (batch_size, self.config.hidden_size), created with the requested dtype on
# the requested device.
# The keys match the two entries that LlamaModel.forward reads from
# `intermediate_tensors` and returns when it is not the last pipeline rank.
# NOTE(review): presumably used as a placeholder/receive buffer for ranks
# that are not the first pipeline stage -- confirm against the caller.
def make_empty_intermediate_tensors(
self, batch_size: int, dtype: torch.dtype,
device: torch.device) -> IntermediateTensors:
return IntermediateTensors({
"hidden_states":
torch.zeros((batch_size, self.config.hidden_size),
dtype=dtype,
device=device),
"residual":
torch.zeros((batch_size, self.config.hidden_size),
dtype=dtype,
device=device),
})
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
...@@ -416,9 +454,12 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA): ...@@ -416,9 +454,12 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
# Skip loading extra bias for GPTQ models. # Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
continue continue
param = params_dict[name] try:
weight_loader = param.weight_loader param = params_dict[name]
weight_loader(param, loaded_weight, shard_id) weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
except KeyError:
pass
break break
else: else:
# Skip loading extra bias for GPTQ models. # Skip loading extra bias for GPTQ models.
...@@ -437,10 +478,13 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA): ...@@ -437,10 +478,13 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
continue continue
else: else:
name = remapped_kv_scale_name name = remapped_kv_scale_name
param = params_dict[name] try:
weight_loader = getattr(param, "weight_loader", param = params_dict[name]
default_weight_loader) weight_loader = getattr(param, "weight_loader",
weight_loader(param, loaded_weight) default_weight_loader)
weight_loader(param, loaded_weight)
except KeyError:
pass
# If this function is called, it should always initialize KV cache scale # If this function is called, it should always initialize KV cache scale
# factors (or else raise an exception). Thus, handled exceptions should # factors (or else raise an exception). Thus, handled exceptions should
...@@ -452,7 +496,8 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA): ...@@ -452,7 +496,8 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
quantization_param_path, tp_rank, tp_size, quantization_param_path, tp_rank, tp_size,
self.config.num_hidden_layers, self.config.num_hidden_layers,
self.config.__class__.model_type): self.config.__class__.model_type):
layer_self_attn = self.model.layers[layer_idx].self_attn if not isinstance(self.model.layers[layer_idx], nn.Identity):
layer_self_attn = self.model.layers[layer_idx].self_attn
if is_hip(): if is_hip():
# The scaling factor convention we are assuming is # The scaling factor convention we are assuming is
......
...@@ -18,7 +18,7 @@ from vllm.model_executor.models.clip import CLIPVisionModel ...@@ -18,7 +18,7 @@ from vllm.model_executor.models.clip import CLIPVisionModel
from vllm.model_executor.models.llama import LlamaModel from vllm.model_executor.models.llama import LlamaModel
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import SamplerOutput from vllm.sequence import IntermediateTensors, SamplerOutput
from .clip import dummy_image_for_clip, dummy_seq_data_for_clip from .clip import dummy_image_for_clip, dummy_seq_data_for_clip
from .interfaces import SupportsVision from .interfaces import SupportsVision
...@@ -202,6 +202,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision): ...@@ -202,6 +202,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
**kwargs: object, **kwargs: object,
) -> SamplerOutput: ) -> SamplerOutput:
"""Run forward pass for LLaVA-1.5. """Run forward pass for LLaVA-1.5.
...@@ -247,6 +248,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision): ...@@ -247,6 +248,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
positions, positions,
kv_caches, kv_caches,
attn_metadata, attn_metadata,
None,
inputs_embeds=inputs_embeds) inputs_embeds=inputs_embeds)
return hidden_states return hidden_states
......
...@@ -22,7 +22,7 @@ from vllm.model_executor.models.clip import CLIPVisionModel ...@@ -22,7 +22,7 @@ from vllm.model_executor.models.clip import CLIPVisionModel
from vllm.model_executor.models.llama import LlamaModel from vllm.model_executor.models.llama import LlamaModel
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sequence import SamplerOutput from vllm.sequence import IntermediateTensors, SamplerOutput
from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip, from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
get_clip_patch_grid_length) get_clip_patch_grid_length)
...@@ -376,6 +376,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision): ...@@ -376,6 +376,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
**kwargs: object, **kwargs: object,
) -> SamplerOutput: ) -> SamplerOutput:
"""Run forward pass for LlaVA-NeXT. """Run forward pass for LlaVA-NeXT.
...@@ -430,6 +431,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision): ...@@ -430,6 +431,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
positions, positions,
kv_caches, kv_caches,
attn_metadata, attn_metadata,
None,
inputs_embeds=inputs_embeds) inputs_embeds=inputs_embeds)
return hidden_states return hidden_states
......
...@@ -50,7 +50,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -50,7 +50,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.sequence import SamplerOutput from vllm.sequence import IntermediateTensors, SamplerOutput
from .interfaces import SupportsLoRA from .interfaces import SupportsLoRA
...@@ -462,6 +462,7 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA): ...@@ -462,6 +462,7 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA):
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata) attn_metadata)
......
...@@ -51,7 +51,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -51,7 +51,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.sequence import SamplerOutput from vllm.sequence import IntermediateTensors, SamplerOutput
from vllm.utils import print_warning_once from vllm.utils import print_warning_once
from .interfaces import SupportsLoRA from .interfaces import SupportsLoRA
...@@ -536,6 +536,7 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA): ...@@ -536,6 +536,7 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata) attn_metadata)
......
...@@ -47,7 +47,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -47,7 +47,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput from vllm.sequence import IntermediateTensors, SamplerOutput
class MixtralMLP(nn.Module): class MixtralMLP(nn.Module):
...@@ -354,6 +354,7 @@ class MixtralForCausalLM(nn.Module): ...@@ -354,6 +354,7 @@ class MixtralForCausalLM(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata) attn_metadata)
......
...@@ -22,7 +22,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -22,7 +22,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding) VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput from vllm.sequence import IntermediateTensors, SamplerOutput
from vllm.transformers_utils.configs.mpt import MPTConfig from vllm.transformers_utils.configs.mpt import MPTConfig
...@@ -273,6 +273,7 @@ class MPTForCausalLM(nn.Module): ...@@ -273,6 +273,7 @@ class MPTForCausalLM(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, kv_caches, hidden_states = self.transformer(input_ids, positions, kv_caches,
attn_metadata) attn_metadata)
......
...@@ -43,7 +43,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -43,7 +43,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput from vllm.sequence import IntermediateTensors, SamplerOutput
class OlmoAttention(nn.Module): class OlmoAttention(nn.Module):
...@@ -301,6 +301,7 @@ class OlmoForCausalLM(nn.Module): ...@@ -301,6 +301,7 @@ class OlmoForCausalLM(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.model( hidden_states = self.model(
input_ids=input_ids, input_ids=input_ids,
......
...@@ -39,7 +39,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -39,7 +39,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding) VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput from vllm.sequence import IntermediateTensors, SamplerOutput
class OPTLearnedPositionalEmbedding(nn.Embedding): class OPTLearnedPositionalEmbedding(nn.Embedding):
...@@ -304,6 +304,7 @@ class OPTForCausalLM(nn.Module): ...@@ -304,6 +304,7 @@ class OPTForCausalLM(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata) attn_metadata)
......
...@@ -26,7 +26,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -26,7 +26,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput from vllm.sequence import IntermediateTensors, SamplerOutput
class OrionMLP(nn.Module): class OrionMLP(nn.Module):
...@@ -269,6 +269,7 @@ class OrionForCausalLM(nn.Module): ...@@ -269,6 +269,7 @@ class OrionForCausalLM(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata) attn_metadata)
......
...@@ -57,7 +57,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -57,7 +57,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput from vllm.sequence import IntermediateTensors, SamplerOutput
from .interfaces import SupportsLoRA from .interfaces import SupportsLoRA
...@@ -278,6 +278,7 @@ class PhiForCausalLM(nn.Module, SupportsLoRA): ...@@ -278,6 +278,7 @@ class PhiForCausalLM(nn.Module, SupportsLoRA):
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata) attn_metadata)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment