Unverified Commit e73f76ee authored by Besher Alkurdi's avatar Besher Alkurdi Committed by GitHub
Browse files

[Model] Pipeline parallel support for JAIS (#7603)

parent d95cc0a5
......@@ -36,6 +36,7 @@ _PP_SUPPORTED_MODELS = [
"AquilaForCausalLM",
"DeepseekV2ForCausalLM",
"InternLMForCausalLM",
"JAISLMHeadModel",
"LlamaForCausalLM",
"LLaMAForCausalLM",
"MistralForCausalLM",
......
......@@ -20,14 +20,14 @@
"""Inference-only Jais model compatible with HuggingFace weights."""
import math
from typing import Iterable, List, Optional, Tuple
from typing import Iterable, List, Optional, Tuple, Union
import torch
from torch import nn
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
from vllm.distributed import (get_tensor_model_parallel_rank,
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
......@@ -43,6 +43,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors, SamplerOutput
from vllm.transformers_utils.configs import JAISConfig
from .utils import is_pp_missing_parameter, make_layers
class SwiGLUActivation(nn.Module):
......@@ -216,6 +218,7 @@ class JAISModel(nn.Module):
config: JAISConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.config = config
......@@ -231,10 +234,15 @@ class JAISModel(nn.Module):
self.embeddings_scale = config.embeddings_scale
else:
self.embeddings_scale = config.mup_embeddings_scale
self.h = nn.ModuleList([
JAISBlock(config, cache_config, quant_config)
for _ in range(config.num_hidden_layers)
])
self.start_layer, self.end_layer, self.h = make_layers(
config.num_hidden_layers,
lambda prefix: JAISBlock(config=config,
cache_config=cache_config,
quant_config=quant_config),
prefix=f"{prefix}.h",
)
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
def forward(
......@@ -243,7 +251,9 @@ class JAISModel(nn.Module):
position_ids: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> Union[IntermediateTensors, torch.Tensor]:
if get_pp_group().is_first_rank:
inputs_embeds = self.wte(input_ids)
if self.wpe is not None:
position_embeds = self.wpe(position_ids)
......@@ -252,10 +262,18 @@ class JAISModel(nn.Module):
hidden_states = inputs_embeds
hidden_states *= torch.tensor(float(self.embeddings_scale),
dtype=hidden_states.dtype)
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
for i in range(len(self.h)):
for i in range(self.start_layer, self.end_layer):
layer = self.h[i]
hidden_states = layer(hidden_states, kv_caches[i], attn_metadata)
hidden_states = layer(hidden_states,
kv_caches[i - self.start_layer],
attn_metadata)
if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states})
hidden_states = self.ln_f(hidden_states)
return hidden_states
......@@ -290,9 +308,9 @@ class JAISLMHeadModel(nn.Module):
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor:
) -> Union[IntermediateTensors, torch.Tensor]:
hidden_states = self.transformer(input_ids, positions, kv_caches,
attn_metadata)
attn_metadata, intermediate_tensors)
return hidden_states
def compute_logits(
......@@ -304,6 +322,16 @@ class JAISLMHeadModel(nn.Module):
sampling_metadata)
return logits
def make_empty_intermediate_tensors(
self, batch_size: int, dtype: torch.dtype,
device: torch.device) -> IntermediateTensors:
return IntermediateTensors({
"hidden_states":
torch.zeros((batch_size, self.config.hidden_size),
dtype=dtype,
device=device),
})
def sample(
self,
logits: torch.Tensor,
......@@ -327,6 +355,10 @@ class JAISLMHeadModel(nn.Module):
continue
if not name.startswith("transformer."):
name = "transformer." + name
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
# The HF's GPT-2 implementation uses Conv1D instead of Linear.
# Because of this, we need to transpose the weights.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment