Unverified commit e73f76ee, authored by Besher Alkurdi, committed by GitHub
Browse files

[Model] Pipeline parallel support for JAIS (#7603)

parent d95cc0a5
...@@ -36,6 +36,7 @@ _PP_SUPPORTED_MODELS = [ ...@@ -36,6 +36,7 @@ _PP_SUPPORTED_MODELS = [
"AquilaForCausalLM", "AquilaForCausalLM",
"DeepseekV2ForCausalLM", "DeepseekV2ForCausalLM",
"InternLMForCausalLM", "InternLMForCausalLM",
"JAISLMHeadModel",
"LlamaForCausalLM", "LlamaForCausalLM",
"LLaMAForCausalLM", "LLaMAForCausalLM",
"MistralForCausalLM", "MistralForCausalLM",
......
...@@ -20,14 +20,14 @@ ...@@ -20,14 +20,14 @@
"""Inference-only Jais model compatible with HuggingFace weights.""" """Inference-only Jais model compatible with HuggingFace weights."""
import math import math
from typing import Iterable, List, Optional, Tuple from typing import Iterable, List, Optional, Tuple, Union
import torch import torch
from torch import nn from torch import nn
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig from vllm.config import CacheConfig
from vllm.distributed import (get_tensor_model_parallel_rank, from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size) get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
...@@ -43,6 +43,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata ...@@ -43,6 +43,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors, SamplerOutput from vllm.sequence import IntermediateTensors, SamplerOutput
from vllm.transformers_utils.configs import JAISConfig from vllm.transformers_utils.configs import JAISConfig
from .utils import is_pp_missing_parameter, make_layers
class SwiGLUActivation(nn.Module): class SwiGLUActivation(nn.Module):
...@@ -216,6 +218,7 @@ class JAISModel(nn.Module): ...@@ -216,6 +218,7 @@ class JAISModel(nn.Module):
config: JAISConfig, config: JAISConfig,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
self.config = config self.config = config
...@@ -231,10 +234,15 @@ class JAISModel(nn.Module): ...@@ -231,10 +234,15 @@ class JAISModel(nn.Module):
self.embeddings_scale = config.embeddings_scale self.embeddings_scale = config.embeddings_scale
else: else:
self.embeddings_scale = config.mup_embeddings_scale self.embeddings_scale = config.mup_embeddings_scale
self.h = nn.ModuleList([
JAISBlock(config, cache_config, quant_config) self.start_layer, self.end_layer, self.h = make_layers(
for _ in range(config.num_hidden_layers) config.num_hidden_layers,
]) lambda prefix: JAISBlock(config=config,
cache_config=cache_config,
quant_config=quant_config),
prefix=f"{prefix}.h",
)
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
def forward( def forward(
...@@ -243,7 +251,9 @@ class JAISModel(nn.Module): ...@@ -243,7 +251,9 @@ class JAISModel(nn.Module):
position_ids: torch.Tensor, position_ids: torch.Tensor,
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: intermediate_tensors: Optional[IntermediateTensors] = None,
) -> Union[IntermediateTensors, torch.Tensor]:
if get_pp_group().is_first_rank:
inputs_embeds = self.wte(input_ids) inputs_embeds = self.wte(input_ids)
if self.wpe is not None: if self.wpe is not None:
position_embeds = self.wpe(position_ids) position_embeds = self.wpe(position_ids)
...@@ -252,10 +262,18 @@ class JAISModel(nn.Module): ...@@ -252,10 +262,18 @@ class JAISModel(nn.Module):
hidden_states = inputs_embeds hidden_states = inputs_embeds
hidden_states *= torch.tensor(float(self.embeddings_scale), hidden_states *= torch.tensor(float(self.embeddings_scale),
dtype=hidden_states.dtype) dtype=hidden_states.dtype)
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
for i in range(len(self.h)): for i in range(self.start_layer, self.end_layer):
layer = self.h[i] layer = self.h[i]
hidden_states = layer(hidden_states, kv_caches[i], attn_metadata) hidden_states = layer(hidden_states,
kv_caches[i - self.start_layer],
attn_metadata)
if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states})
hidden_states = self.ln_f(hidden_states) hidden_states = self.ln_f(hidden_states)
return hidden_states return hidden_states
...@@ -290,9 +308,9 @@ class JAISLMHeadModel(nn.Module): ...@@ -290,9 +308,9 @@ class JAISLMHeadModel(nn.Module):
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor: ) -> Union[IntermediateTensors, torch.Tensor]:
hidden_states = self.transformer(input_ids, positions, kv_caches, hidden_states = self.transformer(input_ids, positions, kv_caches,
attn_metadata) attn_metadata, intermediate_tensors)
return hidden_states return hidden_states
def compute_logits( def compute_logits(
...@@ -304,6 +322,16 @@ class JAISLMHeadModel(nn.Module): ...@@ -304,6 +322,16 @@ class JAISLMHeadModel(nn.Module):
sampling_metadata) sampling_metadata)
return logits return logits
def make_empty_intermediate_tensors(
        self, batch_size: int, dtype: torch.dtype,
        device: torch.device) -> IntermediateTensors:
    """Build a zero-filled ``IntermediateTensors`` container.

    The container holds a single entry, ``"hidden_states"``, which is the
    same key the transformer's forward pass reads from
    ``intermediate_tensors`` when it is not the first pipeline rank.

    Args:
        batch_size: Size of the first dimension of the tensor.
        dtype: Element type of the allocated tensor.
        device: Device on which the tensor is allocated.

    Returns:
        An ``IntermediateTensors`` wrapping a zero tensor of shape
        ``(batch_size, self.config.hidden_size)``.
    """
    # The second dimension comes from the model config so the tensor
    # matches the model's hidden width.
    placeholder_shape = (batch_size, self.config.hidden_size)
    zero_hidden_states = torch.zeros(placeholder_shape,
                                     dtype=dtype,
                                     device=device)
    return IntermediateTensors({"hidden_states": zero_hidden_states})
def sample( def sample(
self, self,
logits: torch.Tensor, logits: torch.Tensor,
...@@ -327,6 +355,10 @@ class JAISLMHeadModel(nn.Module): ...@@ -327,6 +355,10 @@ class JAISLMHeadModel(nn.Module):
continue continue
if not name.startswith("transformer."): if not name.startswith("transformer."):
name = "transformer." + name name = "transformer." + name
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name] param = params_dict[name]
# The HF's GPT-2 implementation uses Conv1D instead of Linear. # The HF's GPT-2 implementation uses Conv1D instead of Linear.
# Because of this, we need to transpose the weights. # Because of this, we need to transpose the weights.
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment