# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only BaiChuan model compatible with HuggingFace weights.

The input of the model is flattened to a 1D tensor of tokens. The model uses
InputMetadata to extract the original 2D shape of the input.
"""
import math
from typing import List, Optional, Tuple

import torch
from torch import nn

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.attention import (PagedAttentionWithRoPE,
                                                  PagedAttentionWithALiBi)
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.weight_utils import (
    hf_model_weights_iterator, load_padded_tensor_parallel_vocab,
    load_tensor_parallel_weights)
from vllm.model_executor.parallel_utils.parallel_state import (
    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import (
    VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs.baichuan import BaiChuanConfig

KVCache = Tuple[torch.Tensor, torch.Tensor]


def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
    """Compute one ALiBi slope per attention head.

    The first ``2**floor(log2(total_num_heads))`` slopes are successive
    powers of a base derived from that power of two. When the head count is
    not itself a power of two, the remaining slopes are the odd powers of
    the base that twice that power of two would use.

    Args:
        total_num_heads: Number of attention heads before any
            tensor-parallel split (callers slice the result per rank).

    Returns:
        A 1-D tensor with ``total_num_heads`` entries.
    """
    closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
    # Algebraically this is 2 ** (-8 / closest_power_of_2).
    base = torch.tensor(
        2**(-(2**-(math.log2(closest_power_of_2) - 3))),
        dtype=torch.float32,
    )
    powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
    slopes = torch.pow(base, powers)

    if closest_power_of_2 != total_num_heads:
        # Same formula, evaluated for the next power of two.
        extra_base = torch.tensor(
            2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
            dtype=torch.float32,
        )
        num_remaining_heads = min(closest_power_of_2,
                                  total_num_heads - closest_power_of_2)
        # Odd exponents only: 1, 3, 5, ...
        extra_powers = torch.arange(start=1,
                                    end=1 + 2 * num_remaining_heads,
                                    step=2,
                                    dtype=torch.int32)
        slopes = torch.cat(
            [slopes, torch.pow(extra_base, extra_powers)], dim=0)
    return slopes


class BaiChuanMLP(nn.Module):
    """Feed-forward block: fused gate/up projection, SiluAndMul, down
    projection.

    Args:
        hidden_size: Input and output feature size.
        intermediate_size: Feature size of each of the gate and up halves.
        hidden_act: Activation name; only ``"silu"`` is accepted.

    Raises:
        ValueError: If ``hidden_act`` is not ``"silu"``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
    ):
        super().__init__()
        # Gate and up projections share one weight of 2 * intermediate_size
        # output features; load_weights fills the two halves separately.
        self.gate_up_proj = ColumnParallelLinear(hidden_size,
                                                 2 * intermediate_size,
                                                 bias=False,
                                                 gather_output=False,
                                                 perform_initialization=False)
        self.down_proj = RowParallelLinear(intermediate_size,
                                           hidden_size,
                                           bias=False,
                                           input_is_parallel=True,
                                           perform_initialization=False)
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()

    def forward(self, x):
        """Apply gate/up projection, activation and down projection to x."""
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


class BaiChuanAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper

    Args:
        hidden_size: Model hidden size.
        num_heads: Total number of attention heads across all
            tensor-parallel ranks; must be divisible by the tensor-parallel
            world size.
        position_embedding: ``"ALIBI"`` selects PagedAttentionWithALiBi;
            any other value selects PagedAttentionWithRoPE.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        position_embedding: str,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        tp_world_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_world_size == 0, (
            f"num_heads ({self.total_num_heads}) must be divisible by the "
            f"tensor parallel world size ({tp_world_size})")
        # Heads handled by this rank.
        self.num_heads = self.total_num_heads // tp_world_size
        self.head_dim = hidden_size // self.total_num_heads
        # The attribute keeps its existing (misspelled) name so that any
        # code reading it continues to work.
        self.postion_embedding = position_embedding

        # pylint: disable=invalid-name
        self.W_pack = ColumnParallelLinear(
            hidden_size,
            3 * hidden_size,
            bias=False,
            gather_output=False,
            perform_initialization=False,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            input_is_parallel=True,
            perform_initialization=False,
        )

        # Stored for both position-embedding variants so the attribute is
        # available regardless of which attention backend is built below.
        self.scaling = self.head_dim**-0.5

        if self.postion_embedding == "ALIBI":
            # Create the alibi slopes and slice out this rank's heads.
            tp_rank = get_tensor_model_parallel_rank()
            head_start = tp_rank * self.num_heads
            head_end = (tp_rank + 1) * self.num_heads
            alibi_slopes = _get_alibi_slopes(self.total_num_heads)
            alibi_slopes = alibi_slopes[head_start:head_end].tolist()

            self.attn = PagedAttentionWithALiBi(self.num_heads, self.head_dim,
                                                self.scaling, alibi_slopes)
        else:
            self.attn = PagedAttentionWithRoPE(self.num_heads,
                                               self.head_dim,
                                               self.scaling,
                                               rotary_dim=self.head_dim)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        """Project to q/k/v, run paged attention, and project the output.

        ``positions`` is only passed to the attention backend on the
        non-ALiBi path.
        """
        qkv, _ = self.W_pack(hidden_states)
        # W_pack's output is split into three equal parts along the last
        # dimension.
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        k_cache, v_cache = kv_cache
        if self.postion_embedding == "ALIBI":
            attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
                                    cache_event)
        else:
            attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
                                    input_metadata, cache_event)

        output, _ = self.o_proj(attn_output)
        return output


class BaiChuanDecoderLayer(nn.Module):
    """One decoder layer: RMSNorm -> attention -> residual add, then
    RMSNorm -> MLP -> residual add.
    """

    def __init__(self, config: BaiChuanConfig, position_embedding: str):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = BaiChuanAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            position_embedding=position_embedding,
        )
        self.mlp = BaiChuanMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
        )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        """Run the attention and MLP sub-blocks, each with a residual add."""
        # Self Attention
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
            cache_event=cache_event,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states


class BaiChuanModel(nn.Module):
    """Token embedding, ``config.num_hidden_layers`` decoder layers, and a
    final RMSNorm.
    """

    def __init__(self, config: BaiChuanConfig, position_embedding: str):
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
            perform_initialization=False)
        self.layers = nn.ModuleList([
            BaiChuanDecoderLayer(config, position_embedding)
            for _ in range(config.num_hidden_layers)
        ])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> torch.Tensor:
        """Embed ``input_ids`` and run them through every decoder layer.

        ``kv_caches`` and, when not None, ``cache_events`` are indexed by
        layer position.
        """
        hidden_states = self.embed_tokens(input_ids)
        for i, layer in enumerate(self.layers):
            cache_event = None if cache_events is None else cache_events[i]
            hidden_states = layer(
                positions,
                hidden_states,
                kv_caches[i],
                input_metadata,
                cache_event,
            )
        hidden_states = self.norm(hidden_states)
        return hidden_states


class BaiChuanBaseForCausalLM(nn.Module):
    """BaiChuan backbone plus LM head and sampler.

    Subclasses choose the position-embedding variant passed down to the
    attention layers.
    """

    def __init__(self, config, position_embedding: str):
        super().__init__()
        self.config = config
        self.model = BaiChuanModel(config, position_embedding)
        self.lm_head = ColumnParallelLinear(config.hidden_size,
                                            config.vocab_size,
                                            bias=False,
                                            gather_output=False,
                                            perform_initialization=False)
        self.sampler = Sampler(config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> SamplerOutput:
        """Run the backbone and hand its hidden states, together with the
        LM head weight, to the sampler.
        """
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   input_metadata, cache_events)
        next_tokens = self.sampler(self.lm_head.weight, hidden_states,
                                   input_metadata)
        return next_tokens

    # Name lists forwarded to load_tensor_parallel_weights.
    _column_parallel_weights = []
    _row_parallel_weights = ["o_proj.weight", "down_proj.weight"]

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     use_np_cache: bool = False):
        """Copy checkpoint tensors into this rank's parameters.

        Args:
            model_name_or_path: Forwarded to ``hf_model_weights_iterator``.
            cache_dir: Forwarded to ``hf_model_weights_iterator``.
            use_np_cache: Forwarded to ``hf_model_weights_iterator``.
        """
        tp_world_size = get_tensor_model_parallel_world_size()
        tp_rank = get_tensor_model_parallel_rank()
        state_dict = self.state_dict()

        # Head-slice bounds for W_pack do not depend on the tensor being
        # loaded, so compute them once up front.
        total_num_heads = self.config.num_attention_heads
        hidden_size = self.config.hidden_size
        head_size = hidden_size // total_num_heads
        num_heads = total_num_heads // tp_world_size
        head_start = tp_rank * num_heads
        head_end = (tp_rank + 1) * num_heads

        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, use_np_cache):
            if "rotary_emb.inv_freq" in name:
                continue

            if "W_pack" in name:
                # View the packed tensor as (3, heads, head_size, hidden),
                # keep only this rank's heads from each of the three
                # stacked parts, then flatten back to 2-D.
                loaded_weight = loaded_weight.view(3, total_num_heads,
                                                   head_size, hidden_size)
                loaded_weight = loaded_weight[:, head_start:head_end, :, :]
                loaded_weight = loaded_weight.reshape(-1, hidden_size)

            is_gate_up_weight = False
            for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]):
                if weight_name not in name:
                    continue
                # gate_proj fills the first half of the fused gate_up_proj
                # parameter (stride_id 0), up_proj the second half.
                param = state_dict[name.replace(weight_name, "gate_up_proj")]
                shard_size = param.shape[0] // 2
                loaded_weight = loaded_weight[shard_size * tp_rank:shard_size *
                                              (tp_rank + 1)]
                param_slice = param.data[shard_size * stride_id:shard_size *
                                         (stride_id + 1)]
                assert param_slice.shape == loaded_weight.shape
                param_slice.copy_(loaded_weight)
                is_gate_up_weight = True
                break
            if is_gate_up_weight:
                continue

            param = state_dict[name]

            if "embed_tokens" in name or "lm_head" in name:
                load_padded_tensor_parallel_vocab(param, loaded_weight,
                                                  tp_rank)
                continue

            load_tensor_parallel_weights(
                param,
                loaded_weight,
                name,
                self._column_parallel_weights,
                self._row_parallel_weights,
                tp_rank,
            )


class BaichuanForCausalLM(BaiChuanBaseForCausalLM):  # baichuan 13b
    """Causal LM variant built with the "ALIBI" position embedding."""

    def __init__(self, config):
        super().__init__(config, "ALIBI")


class BaiChuanForCausalLM(BaiChuanBaseForCausalLM):  # baichuan 7b
    """Causal LM variant built with the "ROPE" position embedding."""

    def __init__(self, config):
        super().__init__(config, "ROPE")