baichuan.py 13.9 KB
Newer Older
codethazine's avatar
codethazine committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only BaiChuan model compatible with HuggingFace weights.

The input of the model is flattened to a 1D tensor of tokens. The model uses
InputMetadata to extract the original 2D shape of the input.
"""
25
import math
26
from typing import List, Optional, Tuple
codethazine's avatar
codethazine committed
27
28
29
30
31
32

import torch
from torch import nn

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul
JFDuan's avatar
JFDuan committed
33
34
from vllm.model_executor.layers.attention import (PagedAttentionWithRoPE,
                                                  PagedAttentionWithALiBi)
35
36
37
38
39
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
                                               MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
codethazine's avatar
codethazine committed
40
from vllm.model_executor.layers.sampler import Sampler
41
42
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding, ParallelLMHead)
codethazine's avatar
codethazine committed
43
44
from vllm.model_executor.parallel_utils.parallel_state import (
    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
45
46
from vllm.model_executor.weight_utils import (default_weight_loader,
                                              hf_model_weights_iterator)
47
from vllm.sequence import SamplerOutput
codethazine's avatar
codethazine committed
48
49
50
51
52
from vllm.transformers_utils.configs.baichuan import BaiChuanConfig

KVCache = Tuple[torch.Tensor, torch.Tensor]


53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
    """Build the per-head ALiBi slope tensor for ``total_num_heads`` heads.

    The first ``p`` slopes, where ``p`` is the largest power of two not
    exceeding ``total_num_heads``, are the powers ``1..p`` of a base derived
    from ``p``. When ``total_num_heads`` is not itself a power of two, the
    remaining slopes are odd powers of the base derived from ``2 * p``.

    Args:
        total_num_heads: Number of attention heads across all ranks.

    Returns:
        A 1-D float32 tensor holding one slope per head.
    """

    def _slope_base(pow2_head_count: int) -> torch.Tensor:
        # Base of the geometric sequence for a power-of-two head count.
        exponent = -(2**-(math.log2(pow2_head_count) - 3))
        return torch.tensor(2**exponent, dtype=torch.float32)

    pow2_heads = 2**math.floor(math.log2(total_num_heads))
    leading_exponents = torch.arange(1, 1 + pow2_heads, dtype=torch.int32)
    leading_slopes = torch.pow(_slope_base(pow2_heads), leading_exponents)

    # Power-of-two head counts need nothing beyond the leading sequence.
    if pow2_heads == total_num_heads:
        return leading_slopes

    # Fill the remaining heads with odd powers of the next finer base.
    leftover_heads = min(pow2_heads, total_num_heads - pow2_heads)
    odd_exponents = torch.arange(start=1,
                                 end=1 + 2 * leftover_heads,
                                 step=2,
                                 dtype=torch.int32)
    trailing_slopes = torch.pow(_slope_base(2 * pow2_heads), odd_exponents)
    return torch.cat([leading_slopes, trailing_slopes], dim=0)


codethazine's avatar
codethazine committed
78
79
80
81
82
83
84
class BaiChuanMLP(nn.Module):
    """Gated feed-forward block of a BaiChuan decoder layer.

    The input goes through a merged gate/up projection, the ``SiluAndMul``
    activation, and finally a down projection back to ``hidden_size``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        """Create the projections and the activation.

        Args:
            hidden_size: Input and output feature size of the block.
            intermediate_size: Output size of each of the two partitions of
                the merged gate/up projection, and input size of the down
                projection.
            hidden_act: Activation name; only ``"silu"`` is accepted.
            linear_method: Optional linear method forwarded to both
                projection layers.

        Raises:
            ValueError: If ``hidden_act`` is not ``"silu"``.
        """
        super().__init__()
        # Validate up front so an unsupported configuration fails before
        # any projection layer is constructed.
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size, [intermediate_size] * 2,
            bias=False,
            linear_method=linear_method)
        self.down_proj = RowParallelLinear(intermediate_size,
                                           hidden_size,
                                           bias=False,
                                           linear_method=linear_method)
        self.act_fn = SiluAndMul()

    def forward(self, x):
        """Apply gate/up projection, activation and down projection to ``x``.

        Both projection layers return a pair; only the first element is
        used here.
        """
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


class BaiChuanAttention(nn.Module):
    """Self-attention layer for BaiChuan with ALiBi or rotary positions.

    Query, key and value come from a single packed projection (``W_pack``).
    When ``position_embedding`` is ``"ALIBI"`` the layer uses
    ``PagedAttentionWithALiBi`` with the slopes belonging to this
    tensor-parallel rank; any other value selects ``PagedAttentionWithRoPE``.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        position_embedding: str,
        rope_theta: float = 10000,
        max_position_embeddings: int = 8192,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        """Build the projections and the attention backend.

        Args:
            hidden_size: Model hidden size.
            num_heads: Total number of attention heads across all
                tensor-parallel ranks; must be divisible by the
                tensor-parallel world size.
            position_embedding: ``"ALIBI"`` for ALiBi attention; any other
                value selects rotary attention.
            rope_theta: Base passed to the rotary attention backend.
            max_position_embeddings: Maximum position passed to the rotary
                attention backend.
            linear_method: Optional linear method forwarded to the
                projection layers.
        """
        super().__init__()
        self.hidden_size = hidden_size
        tensor_model_parallel_world_size = (
            get_tensor_model_parallel_world_size())
        self.total_num_heads = num_heads
        assert self.total_num_heads % tensor_model_parallel_world_size == 0
        self.num_heads = (self.total_num_heads //
                          tensor_model_parallel_world_size)
        self.head_dim = hidden_size // self.total_num_heads
        self.position_embedding = position_embedding
        # Historical misspelling, kept as an alias so existing readers of
        # this attribute keep working.
        self.postion_embedding = position_embedding
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings
        # Set on both code paths so the attribute always exists.
        self.scaling = self.head_dim**-0.5

        # pylint: disable=invalid-name
        self.W_pack = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_heads,
            bias=False,
            linear_method=linear_method,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            linear_method=linear_method,
        )
        if self.position_embedding == "ALIBI":
            # Create the alibi slopes and keep only this rank's heads.
            tp_rank = get_tensor_model_parallel_rank()
            head_start = tp_rank * self.num_heads
            head_end = (tp_rank + 1) * self.num_heads
            alibi_slopes = _get_alibi_slopes(self.total_num_heads)
            alibi_slopes = alibi_slopes[head_start:head_end].tolist()

            self.attn = PagedAttentionWithALiBi(self.num_heads, self.head_dim,
                                                self.scaling, alibi_slopes)
        else:
            self.attn = PagedAttentionWithRoPE(
                self.num_heads,
                self.head_dim,
                self.scaling,
                rotary_dim=self.head_dim,
                base=self.rope_theta,
                max_position=self.max_position_embeddings)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        """Run attention over ``hidden_states`` and project the result.

        Args:
            positions: Token positions; only passed to the rotary backend.
            hidden_states: Input to the packed QKV projection.
            kv_cache: Pair of (key cache, value cache) tensors.
            input_metadata: Metadata forwarded to the attention backend.
            cache_event: Optional event forwarded to the attention backend.

        Returns:
            The first element returned by the output projection.
        """
        qkv, _ = self.W_pack(hidden_states)
        # The packed projection is split into three equal chunks along the
        # last dimension.
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        k_cache, v_cache = kv_cache
        if self.position_embedding == "ALIBI":
            attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
                                    cache_event)
        else:
            attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
                                    input_metadata, cache_event)

        output, _ = self.o_proj(attn_output)
        return output


class BaiChuanDecoderLayer(nn.Module):
    """One BaiChuan decoder layer: RMSNorm, attention, RMSNorm, MLP.

    The residual stream is carried separately from the hidden states and
    handed to the RMSNorm layers together with them.
    """

    def __init__(self,
                 config: BaiChuanConfig,
                 position_embedding: str,
                 linear_method: Optional[LinearMethodBase] = None):
        """Assemble the attention, MLP and normalization sub-modules.

        Args:
            config: Model configuration; ``rope_theta`` and
                ``max_position_embeddings`` fall back to 10000 and 8192 when
                the config does not define them.
            position_embedding: Position-embedding kind forwarded to the
                attention layer.
            linear_method: Optional linear method forwarded to the
                attention and MLP layers.
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = BaiChuanAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            position_embedding=position_embedding,
            rope_theta=getattr(config, "rope_theta", 10000),
            max_position_embeddings=getattr(config,
                                            "max_position_embeddings", 8192),
            linear_method=linear_method,
        )
        self.mlp = BaiChuanMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            linear_method=linear_method,
        )
        norm_eps = config.rms_norm_eps
        self.input_layernorm = RMSNorm(config.hidden_size, eps=norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the layer and return ``(hidden_states, residual)``.

        When ``residual`` is ``None`` the incoming hidden states become the
        residual; otherwise the input norm receives both tensors and returns
        the updated pair.
        """
        # Self Attention
        if residual is not None:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        else:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        attn_out = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
            cache_event=cache_event,
        )

        # Fully Connected
        normed, residual = self.post_attention_layernorm(attn_out, residual)
        mlp_out = self.mlp(normed)
        return mlp_out, residual
codethazine's avatar
codethazine committed
250
251
252
253


class BaiChuanModel(nn.Module):
    """Token embedding, stack of BaiChuan decoder layers, and final RMSNorm."""

    def __init__(self,
                 config: BaiChuanConfig,
                 position_embedding: str,
                 linear_method: Optional[LinearMethodBase] = None):
        """Build the embedding, ``config.num_hidden_layers`` layers and norm.

        Args:
            config: Model configuration.
            position_embedding: Position-embedding kind forwarded to every
                decoder layer.
            linear_method: Optional linear method forwarded to every
                decoder layer.
        """
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        decoder_layers = [
            BaiChuanDecoderLayer(config, position_embedding, linear_method)
            for _ in range(config.num_hidden_layers)
        ]
        self.layers = nn.ModuleList(decoder_layers)
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> torch.Tensor:
        """Embed ``input_ids`` and run them through every decoder layer.

        Layer ``i`` receives ``kv_caches[i]`` and, when ``cache_events`` is
        given, ``cache_events[i]``. The final norm is applied to the last
        hidden states together with the accumulated residual.
        """
        hidden_states = self.embed_tokens(input_ids)
        residual = None
        for idx, layer in enumerate(self.layers):
            event = cache_events[idx] if cache_events is not None else None
            hidden_states, residual = layer(
                positions,
                hidden_states,
                kv_caches[idx],
                input_metadata,
                event,
                residual,
            )
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


298
class BaiChuanBaseForCausalLM(nn.Module):
    """Shared causal-LM wrapper: BaiChuan backbone, LM head and sampler.

    Subclasses choose the position-embedding kind passed to the backbone.
    """

    def __init__(self,
                 config,
                 position_embedding: str,
                 linear_method: Optional[LinearMethodBase] = None):
        """Create the backbone, the LM head and the sampler.

        Args:
            config: Model configuration.
            position_embedding: Position-embedding kind forwarded to the
                backbone.
            linear_method: Optional linear method forwarded to the backbone.
        """
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        self.model = BaiChuanModel(config, position_embedding, linear_method)
        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
        self.sampler = Sampler(config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> SamplerOutput:
        """Run the backbone and return the sampler's output.

        The sampler is given the LM head weight, the backbone's hidden
        states and ``input_metadata``.
        """
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   input_metadata, cache_events)
        return self.sampler(self.lm_head.weight, hidden_states,
                            input_metadata)

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     load_format: str = "auto",
                     revision: Optional[str] = None):
        """Load checkpoint weights into this module's parameters.

        Weights whose name contains ``gate_proj`` or ``up_proj`` are routed
        into the merged ``gate_up_proj`` parameter as shard 0 and shard 1
        respectively. Every other weight goes to the parameter of the same
        name. Entries containing ``rotary_emb.inv_freq`` are skipped.

        Args:
            model_name_or_path: Forwarded to ``hf_model_weights_iterator``.
            cache_dir: Forwarded to ``hf_model_weights_iterator``.
            load_format: Forwarded to ``hf_model_weights_iterator``.
            revision: Forwarded to ``hf_model_weights_iterator``.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        weights = hf_model_weights_iterator(model_name_or_path, cache_dir,
                                            load_format, revision)
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue

            # First mapping entry whose shard name occurs in the weight name.
            stacked_entry = next(
                (entry for entry in stacked_params_mapping
                 if entry[1] in name), None)
            if stacked_entry is not None:
                merged_name, shard_name, shard_id = stacked_entry
                param = params_dict[name.replace(shard_name, merged_name)]
                param.weight_loader(param, loaded_weight, shard_id)
                continue

            # Not a stacked weight: use the parameter's own loader if it
            # has one, else the default loader.
            param = params_dict[name]
            loader = getattr(param, "weight_loader", default_weight_loader)
            loader(param, loaded_weight)
352
353
354
355


class BaichuanForCausalLM(BaiChuanBaseForCausalLM):  # baichuan 13b
    """Causal LM variant that fixes the position embedding to ``"ALIBI"``."""

    def __init__(self,
                 config,
                 linear_method: Optional[LinearMethodBase] = None):
        """Initialize the base model with ALiBi position embedding."""
        super().__init__(config=config,
                         position_embedding="ALIBI",
                         linear_method=linear_method)
360
361
362
363


class BaiChuanForCausalLM(BaiChuanBaseForCausalLM):  # baichuan 7b
    """Causal LM variant that fixes the position embedding to ``"ROPE"``."""

    def __init__(self,
                 config,
                 linear_method: Optional[LinearMethodBase] = None):
        """Initialize the base model with rotary position embedding."""
        super().__init__(config=config,
                         position_embedding="ROPE",
                         linear_method=linear_method)