Unverified Commit 64f23c29 authored by Song, committed by GitHub

Fix Baichuan: handle the different position embeddings used by the 7B and 13B models (#643)

parent d4c7755c
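Note on the change: Baichuan-7B uses rotary position embeddings (RoPE), while Baichuan-13B uses ALiBi attention biases, and the two checkpoints report differently capitalized architecture names ("BaiChuanForCausalLM" vs. "BaichuanForCausalLM"). The diff below registers both names, refactors the shared code into a BaiChuanBaseForCausalLM that takes a position_embedding argument ("ROPE" or "ALIBI"), and adds a helper that computes the ALiBi slopes. As a quick illustration of that slope schedule (not part of the commit): for a power-of-two head count n, head h gets a slope of 2 ** (-8 * h / n); Baichuan-13B's 40 heads fall on the non-power-of-two branch handled by the new helper.

def alibi_slopes_power_of_two(n_heads: int) -> list:
    # Illustrative sketch of the ALiBi slope schedule for a power-of-two
    # head count; the commit's _get_alibi_slopes also handles the general case.
    ratio = 2 ** (-8.0 / n_heads)
    return [ratio ** h for h in range(1, n_heads + 1)]

# For 8 heads this yields 1/2, 1/4, ..., 1/256, matching the ALiBi paper.
print(alibi_slopes_power_of_two(8))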
@@ -11,7 +11,8 @@ from vllm.model_executor.weight_utils import initialize_dummy_weights
 # TODO(woosuk): Lazy-load the model classes.
 _MODEL_REGISTRY = {
-    "BaiChuanForCausalLM": BaiChuanForCausalLM,
+    "BaiChuanForCausalLM": BaiChuanForCausalLM,  # baichuan-7b
+    "BaichuanForCausalLM": BaichuanForCausalLM,  # baichuan-13b
     "BloomForCausalLM": BloomForCausalLM,
     "GPT2LMHeadModel": GPT2LMHeadModel,
     "GPTBigCodeForCausalLM": GPTBigCodeForCausalLM,
...
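For context, vLLM picks the model class by matching the architectures field of the checkpoint's Hugging Face config against this registry, which is why both capitalizations need their own entry. A minimal sketch of that lookup, assuming the standard transformers config API (the helper name here is illustrative, not the actual loader code):

from transformers import AutoConfig

def resolve_model_class(model_name_or_path: str, registry: dict):
    # Illustrative sketch: return the registered class whose name appears in
    # the checkpoint's declared architectures (e.g. "BaichuanForCausalLM").
    config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
    for arch in getattr(config, "architectures", None) or []:
        if arch in registry:
            return registry[arch]
    raise ValueError(f"Unsupported architectures: {getattr(config, 'architectures', None)}")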
-from vllm.model_executor.models.baichuan import BaiChuanForCausalLM
+from vllm.model_executor.models.baichuan import BaiChuanForCausalLM, BaichuanForCausalLM
 from vllm.model_executor.models.bloom import BloomForCausalLM
 from vllm.model_executor.models.gpt2 import GPT2LMHeadModel
 from vllm.model_executor.models.gpt_bigcode import GPTBigCodeForCausalLM
@@ -10,6 +10,7 @@ from vllm.model_executor.models.opt import OPTForCausalLM
 __all__ = [
     "BaiChuanForCausalLM",
+    "BaichuanForCausalLM",
     "BloomForCausalLM",
     "GPT2LMHeadModel",
     "GPTBigCodeForCausalLM",
...
@@ -22,6 +22,7 @@
 The input of the model is flattened to a 1D tensor of tokens. The model uses
 InputMetadata to extract the original 2D shape of the input.
 """
+import math
 from typing import Dict, List, Optional, Tuple
 
 import torch
@@ -31,7 +32,7 @@ from vllm.sequence import SequenceOutputs
 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
-from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
+from vllm.model_executor.layers.attention import PagedAttentionWithRoPE, PagedAttentionWithALiBi
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
                                               load_tensor_parallel_weights)
@@ -44,6 +45,31 @@ from vllm.transformers_utils.configs.baichuan import BaiChuanConfig
 KVCache = Tuple[torch.Tensor, torch.Tensor]
 
 
+def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
+    closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
+    base = torch.tensor(
+        2**(-(2**-(math.log2(closest_power_of_2) - 3))),
+        dtype=torch.float32,
+    )
+    powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
+    slopes = torch.pow(base, powers)
+
+    if closest_power_of_2 != total_num_heads:
+        extra_base = torch.tensor(
+            2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
+            dtype=torch.float32,
+        )
+        num_remaining_heads = min(closest_power_of_2,
+                                  total_num_heads - closest_power_of_2)
+        extra_powers = torch.arange(start=1,
+                                    end=1 + 2 * num_remaining_heads,
+                                    step=2,
+                                    dtype=torch.int32)
+        slopes = torch.cat(
+            [slopes, torch.pow(extra_base, extra_powers)], dim=0)
+    return slopes
+
+
 class BaiChuanMLP(nn.Module):
 
     def __init__(
@@ -82,6 +108,7 @@ class BaiChuanAttention(nn.Module):
         self,
         hidden_size: int,
         num_heads: int,
+        position_embedding: str,
     ):
         super().__init__()
         self.hidden_size = hidden_size
@@ -92,7 +119,7 @@ class BaiChuanAttention(nn.Module):
         self.num_heads = (self.total_num_heads //
                           tensor_model_parallel_world_size)
         self.head_dim = hidden_size // self.total_num_heads
-        self.scaling = self.head_dim**-0.5
+        self.postion_embedding = position_embedding
         # pylint: disable=invalid-name
         self.W_pack = ColumnParallelLinear(
@@ -109,11 +136,23 @@ class BaiChuanAttention(nn.Module):
             input_is_parallel=True,
             perform_initialization=False,
         )
-
-        self.attn = PagedAttentionWithRoPE(self.num_heads,
-                                           self.head_dim,
-                                           self.scaling,
-                                           rotary_dim=self.head_dim)
+        # Create the alibi slopes and slice them.
+        if self.postion_embedding == "ALIBI":
+            tp_rank = get_tensor_model_parallel_rank()
+            head_start = tp_rank * self.num_heads
+            head_end = (tp_rank + 1) * self.num_heads
+            alibi_slopes = _get_alibi_slopes(self.total_num_heads)
+            alibi_slopes = alibi_slopes[head_start:head_end].tolist()
+
+            scaling = self.head_dim**-0.5
+            self.attn = PagedAttentionWithALiBi(self.num_heads, self.head_dim,
+                                                scaling, alibi_slopes)
+        else:
+            self.scaling = self.head_dim**-0.5
+            self.attn = PagedAttentionWithRoPE(self.num_heads,
+                                               self.head_dim,
+                                               self.scaling,
+                                               rotary_dim=self.head_dim)
 
     def forward(
         self,
@@ -126,20 +165,26 @@ class BaiChuanAttention(nn.Module):
         qkv, _ = self.W_pack(hidden_states)
         q, k, v = qkv.chunk(chunks=3, dim=-1)
         k_cache, v_cache = kv_cache
-        attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
-                                input_metadata, cache_event)
+        if self.postion_embedding == "ALIBI":
+            attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
+                                    cache_event)
+        else:
+            attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
+                                    input_metadata, cache_event)
         output, _ = self.o_proj(attn_output)
         return output
 
 
 class BaiChuanDecoderLayer(nn.Module):
 
-    def __init__(self, config: BaiChuanConfig):
+    def __init__(self, config: BaiChuanConfig, position_embedding: str):
         super().__init__()
         self.hidden_size = config.hidden_size
         self.self_attn = BaiChuanAttention(
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
+            position_embedding=position_embedding,
         )
         self.mlp = BaiChuanMLP(
             hidden_size=self.hidden_size,
@@ -181,7 +226,7 @@ class BaiChuanDecoderLayer(nn.Module):
 
 class BaiChuanModel(nn.Module):
 
-    def __init__(self, config: BaiChuanConfig):
+    def __init__(self, config: BaiChuanConfig, position_embedding: str):
         super().__init__()
         self.config = config
         self.padding_idx = config.pad_token_id
@@ -192,7 +237,7 @@ class BaiChuanModel(nn.Module):
             config.hidden_size,
             perform_initialization=False)
         self.layers = nn.ModuleList([
-            BaiChuanDecoderLayer(config)
+            BaiChuanDecoderLayer(config, position_embedding)
             for _ in range(config.num_hidden_layers)
         ])
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -223,12 +268,12 @@
         return hidden_states
 
 
-class BaiChuanForCausalLM(nn.Module):
+class BaiChuanBaseForCausalLM(nn.Module):
 
-    def __init__(self, config):
+    def __init__(self, config, position_embedding: str):
         super().__init__()
         self.config = config
-        self.model = BaiChuanModel(config)
+        self.model = BaiChuanModel(config, position_embedding)
         self.lm_head = ColumnParallelLinear(config.hidden_size,
                                             config.vocab_size,
                                             bias=False,
@@ -318,3 +363,15 @@ class BaiChuanForCausalLM(nn.Module):
                 self._row_parallel_weights,
                 tp_rank,
             )
+
+
+class BaichuanForCausalLM(BaiChuanBaseForCausalLM):  # baichuan 13b
+
+    def __init__(self, config):
+        super().__init__(config, "ALIBI")
+
+
+class BaiChuanForCausalLM(BaiChuanBaseForCausalLM):  # baichuan 7b
+
+    def __init__(self, config):
+        super().__init__(config, "ROPE")