# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only BaiChuan model compatible with HuggingFace weights."""
import math
from typing import List, Optional, Tuple

import torch
from torch import nn

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
                                               MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding, ParallelLMHead)
from vllm.model_executor.parallel_utils.parallel_state import (
    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
                                              hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs.baichuan import BaiChuanConfig

KVCache = Tuple[torch.Tensor, torch.Tensor]
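# A per-layer KV cache is a (key_cache, value_cache) pair of tensors.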


def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
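    """Compute per-head ALiBi slopes as in the ALiBi paper: a geometric
    sequence over the heads (e.g. 8 heads give 1/2, 1/4, ..., 1/256), with an
    extra interleaved sequence when the head count is not a power of two."""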
    closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
    base = torch.tensor(
        2**(-(2**-(math.log2(closest_power_of_2) - 3))),
        dtype=torch.float32,
    )
    powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
    slopes = torch.pow(base, powers)

    if closest_power_of_2 != total_num_heads:
        extra_base = torch.tensor(
            2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
            dtype=torch.float32,
        )
        num_remaining_heads = min(closest_power_of_2,
                                  total_num_heads - closest_power_of_2)
        extra_powers = torch.arange(start=1,
                                    end=1 + 2 * num_remaining_heads,
                                    step=2,
                                    dtype=torch.int32)
        slopes = torch.cat(
            [slopes, torch.pow(extra_base, extra_powers)], dim=0)
    return slopes


class BaiChuanMLP(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size, [intermediate_size] * 2,
            bias=False,
            linear_method=linear_method)
        self.down_proj = RowParallelLinear(intermediate_size,
                                           hidden_size,
                                           bias=False,
                                           linear_method=linear_method)
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()

    def forward(self, x):
        gate_up, _ = self.gate_up_proj(x)
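        # SiluAndMul splits the fused gate/up projection in half along the
        # last dimension and returns silu(gate) * up (LLaMA-style gated MLP).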
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


class BaiChuanAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        position_embedding: str,
        rope_theta: float = 10000,
        max_position_embeddings: int = 8192,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        tensor_model_parallel_world_size = get_tensor_model_parallel_world_size(
        )
        self.total_num_heads = num_heads
        assert self.total_num_heads % tensor_model_parallel_world_size == 0
        self.num_heads = (self.total_num_heads //
                          tensor_model_parallel_world_size)
        self.head_dim = hidden_size // self.total_num_heads
        self.position_embedding = position_embedding
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        # pylint: disable=invalid-name
        self.W_pack = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_heads,
            bias=False,
            linear_method=linear_method,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            linear_method=linear_method,
        )
        # Create the alibi slopes and slice them.
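        # Each tensor-parallel rank keeps only the slopes for the heads it
        # owns, so the slopes are computed for all heads and then sliced.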
        if self.position_embedding == "ALIBI":
            tp_rank = get_tensor_model_parallel_rank()
            head_start = tp_rank * self.num_heads
            head_end = (tp_rank + 1) * self.num_heads
            alibi_slopes = _get_alibi_slopes(self.total_num_heads)
            alibi_slopes = alibi_slopes[head_start:head_end].tolist()

            scaling = self.head_dim**-0.5
            self.attn = PagedAttention(self.num_heads,
                                       self.head_dim,
                                       scaling,
                                       alibi_slopes=alibi_slopes)
        else:
            self.rotary_emb = get_rope(
                self.head_dim,
                rotary_dim=self.head_dim,
                max_position=self.max_position_embeddings,
                base=self.rope_theta,
            )
            self.scaling = self.head_dim**-0.5
            self.attn = PagedAttention(self.num_heads, self.head_dim,
                                       self.scaling)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
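        # W_pack is Baichuan's fused QKV projection; its output is split into
        # equal thirds for the query, key and value heads of this rank.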
        qkv, _ = self.W_pack(hidden_states)
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        if self.position_embedding != "ALIBI":
            q, k = self.rotary_emb(positions, q, k)
        k_cache, v_cache = kv_cache
        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
                                cache_event)
        output, _ = self.o_proj(attn_output)
        return output


class BaiChuanDecoderLayer(nn.Module):

    def __init__(self,
                 config: BaiChuanConfig,
                 position_embedding: str,
                 linear_method: Optional[LinearMethodBase] = None):
        super().__init__()
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        self.self_attn = BaiChuanAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            position_embedding=position_embedding,
            rope_theta=rope_theta,
            max_position_embeddings=max_position_embeddings,
            linear_method=linear_method,
        )
        self.mlp = BaiChuanMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            linear_method=linear_method,
        )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Self Attention
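        # RMSNorm fuses the residual add: given a residual it returns both the
        # normalized hidden states and the updated residual stream.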
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
            cache_event=cache_event,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual


class BaiChuanModel(nn.Module):

    def __init__(self,
                 config: BaiChuanConfig,
                 position_embedding: str,
                 linear_method: Optional[LinearMethodBase] = None):
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        self.layers = nn.ModuleList([
            BaiChuanDecoderLayer(config, position_embedding, linear_method)
            for _ in range(config.num_hidden_layers)
        ])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> torch.Tensor:
        hidden_states = self.embed_tokens(input_ids)
        residual = None
        for i in range(len(self.layers)):
            cache_event = None if cache_events is None else cache_events[i]
            layer = self.layers[i]
            hidden_states, residual = layer(
                positions,
                hidden_states,
                kv_caches[i],
                input_metadata,
                cache_event,
                residual,
            )
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


class BaiChuanBaseForCausalLM(nn.Module):

    def __init__(self,
                 config,
                 position_embedding: str,
                 linear_method: Optional[LinearMethodBase] = None):
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        self.model = BaiChuanModel(config, position_embedding, linear_method)
        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
        self.sampler = Sampler(config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> torch.Tensor:
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   input_metadata, cache_events)
        return hidden_states

    def sample(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> SamplerOutput:
        next_tokens = self.sampler(self.lm_head.weight, hidden_states,
                                   sampling_metadata)
        return next_tokens

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     load_format: str = "auto",
                     revision: Optional[str] = None):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
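        # The checkpoint's separate gate_proj/up_proj weights are loaded into
        # the merged gate_up_proj parameter; shard_id selects which half.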
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, load_format, revision):
            if "rotary_emb.inv_freq" in name:
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                param = params_dict[name.replace(weight_name, param_name)]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)


class BaichuanForCausalLM(BaiChuanBaseForCausalLM):  # baichuan 13b
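    """Baichuan-13B variant: uses ALiBi position biases (no rotary embedding)
    and matches the checkpoint's "BaichuanForCausalLM" architecture name."""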

    def __init__(self,
                 config,
                 linear_method: Optional[LinearMethodBase] = None):
        super().__init__(config, "ALIBI", linear_method)


class BaiChuanForCausalLM(BaiChuanBaseForCausalLM):  # baichuan 7b
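    """Baichuan-7B variant: uses rotary position embeddings and matches the
    checkpoint's "BaiChuanForCausalLM" architecture name."""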

    def __init__(self,
                 config,
                 linear_method: Optional[LinearMethodBase] = None):
        super().__init__(config, "ROPE", linear_method)