baichuan.py 14.3 KB
Newer Older
codethazine's avatar
codethazine committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only BaiChuan model compatible with HuggingFace weights.

The input of the model is flattened to a 1D tensor of tokens. The model uses
InputMetadata to extract the original 2D shape of the input.
"""
25
import math
26
from typing import List, Optional, Tuple
codethazine's avatar
codethazine committed
27
28
29
30
31
32
33

import torch
from torch import nn

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
JFDuan's avatar
JFDuan committed
34
35
from vllm.model_executor.layers.attention import (PagedAttentionWithRoPE,
                                                  PagedAttentionWithALiBi)
codethazine's avatar
codethazine committed
36
from vllm.model_executor.layers.sampler import Sampler
JFDuan's avatar
JFDuan committed
37
from vllm.model_executor.weight_utils import (
38
39
    convert_pyslice_to_tensor, hf_model_weights_iterator,
    load_padded_tensor_parallel_vocab, load_tensor_parallel_weights)
codethazine's avatar
codethazine committed
40
41
from vllm.model_executor.parallel_utils.parallel_state import (
    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
42
43
44
from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding,
                                                       ColumnParallelLinear,
                                                       RowParallelLinear)
45
from vllm.sequence import SamplerOutput
codethazine's avatar
codethazine committed
46
47
48
49
50
from vllm.transformers_utils.configs.baichuan import BaiChuanConfig

# Cache handle passed to each attention layer: a 2-tuple of tensors that the
# attention forward pass unpacks as (k_cache, v_cache).
KVCache = Tuple[torch.Tensor, torch.Tensor]


51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
    """Build the per-head ALiBi slope vector.

    Let ``p`` be the largest power of two that does not exceed
    ``total_num_heads``. The first ``p`` slopes are ``base**1 .. base**p``
    with ``base = 2**(-(2**-(log2(p) - 3)))``. If ``total_num_heads`` is not
    itself a power of two, the remaining ``total_num_heads - p`` slopes are
    the odd powers ``1, 3, 5, ...`` of the base computed for ``2 * p``.

    Args:
        total_num_heads: Number of attention heads across all
            tensor-parallel ranks.

    Returns:
        A 1-D float32 tensor holding ``total_num_heads`` slopes.
    """

    def _base_for(power_of_2: int) -> torch.Tensor:
        # Geometric ratio of the slope sequence for `power_of_2` heads.
        return torch.tensor(
            2**(-(2**-(math.log2(power_of_2) - 3))),
            dtype=torch.float32,
        )

    closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
    exponents = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
    slopes = torch.pow(_base_for(closest_power_of_2), exponents)

    # Power-of-two head counts need nothing beyond the main sequence.
    if closest_power_of_2 == total_num_heads:
        return slopes

    num_remaining_heads = min(closest_power_of_2,
                              total_num_heads - closest_power_of_2)
    odd_exponents = torch.arange(start=1,
                                 end=1 + 2 * num_remaining_heads,
                                 step=2,
                                 dtype=torch.int32)
    extra_slopes = torch.pow(_base_for(2 * closest_power_of_2),
                             odd_exponents)
    return torch.cat([slopes, extra_slopes], dim=0)


codethazine's avatar
codethazine committed
76
77
78
79
80
81
82
83
84
class BaiChuanMLP(nn.Module):
    """Gated feed-forward block used by each BaiChuan decoder layer.

    The gate and up projections are fused into a single column-parallel
    linear layer (``gate_up_proj``) of output width ``2 * intermediate_size``.
    Its output is passed through ``SiluAndMul`` and then through the
    row-parallel ``down_proj`` back to ``hidden_size``.

    Args:
        hidden_size: Width of the block's input and output.
        intermediate_size: Width of each of the gate and up projections.
        hidden_act: Activation name; only ``"silu"`` is accepted.

    Raises:
        ValueError: If ``hidden_act`` is anything other than ``"silu"``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
    ):
        super().__init__()
        # Reject unsupported activations before any projection is built.
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.gate_up_proj = ColumnParallelLinear(
            hidden_size,
            2 * intermediate_size,
            bias=False,
            gather_output=False,
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            input_is_parallel=True,
        )
        self.act_fn = SiluAndMul()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply gate/up projection, activation, and down projection.

        Both parallel linear layers return a pair; only the first element
        (the projected tensor) is used here.
        """
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


class BaiChuanAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper.

    Query, key and value come from one fused column-parallel projection
    (``W_pack``) of output width ``3 * hidden_size``. The attention kernel is
    chosen by ``position_embedding``: ``"ALIBI"`` selects
    ``PagedAttentionWithALiBi`` with this rank's slice of the ALiBi slopes;
    any other value selects ``PagedAttentionWithRoPE``.

    Args:
        hidden_size: Model hidden width.
        num_heads: Total number of attention heads across all
            tensor-parallel ranks; must be divisible by the tensor-parallel
            world size.
        position_embedding: ``"ALIBI"`` for ALiBi biases, otherwise rotary
            embeddings are used.
        rope_theta: Rotary base, forwarded to the RoPE attention layer.
        max_position_embeddings: Maximum position, forwarded to the RoPE
            attention layer.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        position_embedding: str,
        rope_theta: float = 10000,
        max_position_embeddings: int = 8192,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        tensor_model_parallel_world_size = get_tensor_model_parallel_world_size(
        )
        self.total_num_heads = num_heads
        assert self.total_num_heads % tensor_model_parallel_world_size == 0
        # Heads handled by this tensor-parallel rank.
        self.num_heads = (self.total_num_heads //
                          tensor_model_parallel_world_size)
        self.head_dim = hidden_size // self.total_num_heads
        self.position_embedding = position_embedding
        # Historical misspelling, kept as an alias so existing readers of
        # the attribute keep working.
        self.postion_embedding = position_embedding
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings
        # Attention scale, shared by both position-embedding variants.
        self.scaling = self.head_dim**-0.5

        # pylint: disable=invalid-name
        self.W_pack = ColumnParallelLinear(
            hidden_size,
            3 * hidden_size,
            bias=False,
            gather_output=False,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            input_is_parallel=True,
        )

        if self.position_embedding == "ALIBI":
            # Create the alibi slopes for all heads and keep only the
            # contiguous range of heads owned by this rank.
            tp_rank = get_tensor_model_parallel_rank()
            head_start = tp_rank * self.num_heads
            head_end = (tp_rank + 1) * self.num_heads
            alibi_slopes = _get_alibi_slopes(self.total_num_heads)
            alibi_slopes = alibi_slopes[head_start:head_end].tolist()

            self.attn = PagedAttentionWithALiBi(self.num_heads, self.head_dim,
                                                self.scaling, alibi_slopes)
        else:
            self.attn = PagedAttentionWithRoPE(
                self.num_heads,
                self.head_dim,
                self.scaling,
                rotary_dim=self.head_dim,
                base=self.rope_theta,
                max_position=self.max_position_embeddings)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        """Run attention over ``hidden_states`` and project the result.

        ``positions`` is only passed to the attention layer in the RoPE
        variant; the ALiBi variant does not receive it.
        """
        qkv, _ = self.W_pack(hidden_states)
        # The fused projection is split into three equal parts on the last
        # dimension, in q, k, v order.
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        k_cache, v_cache = kv_cache
        if self.position_embedding == "ALIBI":
            attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
                                    cache_event)
        else:
            attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
                                    input_metadata, cache_event)

        output, _ = self.o_proj(attn_output)
        return output


class BaiChuanDecoderLayer(nn.Module):
    """One BaiChuan transformer layer: attention and MLP, each pre-normed.

    Each sub-block applies an ``RMSNorm`` to its input, runs the sub-module,
    and adds the result to the un-normalised input (residual connection).

    Args:
        config: Model configuration. Reads ``hidden_size``,
            ``num_attention_heads``, ``intermediate_size``, ``hidden_act``
            and ``rms_norm_eps``; ``rope_theta`` and
            ``max_position_embeddings`` are optional and default to
            ``10000`` and ``8192``.
        position_embedding: Forwarded unchanged to ``BaiChuanAttention``.
    """

    def __init__(self, config: BaiChuanConfig, position_embedding: str):
        super().__init__()
        self.hidden_size = config.hidden_size
        # Configs may omit these two fields, so fall back to defaults.
        rope_theta = getattr(config, "rope_theta", 10000)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        self.self_attn = BaiChuanAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            position_embedding=position_embedding,
            rope_theta=rope_theta,
            max_position_embeddings=max_position_embeddings,
        )
        self.mlp = BaiChuanMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
        )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        """Apply the attention and MLP sub-blocks to ``hidden_states``."""
        # Self Attention
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
            cache_event=cache_event,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states


class BaiChuanModel(nn.Module):
    """Token embedding, a stack of decoder layers, and a final RMSNorm.

    Args:
        config: Model configuration. Reads ``pad_token_id``, ``vocab_size``,
            ``hidden_size``, ``num_hidden_layers`` and ``rms_norm_eps``
            (plus whatever ``BaiChuanDecoderLayer`` reads).
        position_embedding: Forwarded unchanged to every decoder layer.
    """

    def __init__(self, config: BaiChuanConfig, position_embedding: str):
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        self.layers = nn.ModuleList([
            BaiChuanDecoderLayer(config, position_embedding)
            for _ in range(config.num_hidden_layers)
        ])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> torch.Tensor:
        """Embed ``input_ids`` and run them through every decoder layer.

        ``kv_caches`` and, when given, ``cache_events`` are indexed by layer
        position. When ``cache_events`` is ``None`` every layer receives
        ``None`` as its cache event.

        Returns:
            The hidden states after the final normalisation.
        """
        hidden_states = self.embed_tokens(input_ids)
        for i, layer in enumerate(self.layers):
            cache_event = None if cache_events is None else cache_events[i]
            hidden_states = layer(
                positions,
                hidden_states,
                kv_caches[i],
                input_metadata,
                cache_event,
            )
        hidden_states = self.norm(hidden_states)
        return hidden_states


286
class BaiChuanBaseForCausalLM(nn.Module):
    """BaiChuan decoder stack with an LM head and a sampler.

    Subclasses choose the position-embedding scheme by passing
    ``position_embedding`` to this constructor; the value is forwarded to
    ``BaiChuanModel``.

    Args:
        config: Model configuration. Reads ``hidden_size``, ``vocab_size``
            and, during weight loading, ``num_attention_heads``.
        position_embedding: Forwarded unchanged to ``BaiChuanModel``.
    """

    # Name lists handed to ``load_tensor_parallel_weights``. The column list
    # is empty because ``load_weights`` slices the fused QKV and gate/up
    # tensors itself before delegating.
    _column_parallel_weights: List[str] = []
    _row_parallel_weights: List[str] = ["o_proj.weight", "down_proj.weight"]

    def __init__(self, config, position_embedding: str):
        super().__init__()
        self.config = config
        self.model = BaiChuanModel(config, position_embedding)
        self.lm_head = ColumnParallelLinear(
            config.hidden_size,
            config.vocab_size,
            bias=False,
            gather_output=False,
        )
        self.sampler = Sampler(config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> SamplerOutput:
        """Run the decoder stack and sample from its output.

        The sampler receives the LM head's weight matrix together with the
        final hidden states, rather than pre-computed logits.
        """
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   input_metadata, cache_events)
        next_tokens = self.sampler(self.lm_head.weight, hidden_states,
                                   input_metadata)
        return next_tokens

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     load_format: str = "auto",
                     revision: Optional[str] = None) -> None:
        """Copy this rank's share of a checkpoint into the model parameters.

        Weights are read via ``hf_model_weights_iterator`` and handled in
        this order for each tensor:

        * names containing ``rotary_emb.inv_freq`` are skipped;
        * ``W_pack`` tensors are viewed as
          ``(3, total_num_heads, head_size, hidden_size)``, narrowed to this
          rank's heads, and flattened back to two dimensions;
        * ``gate_proj`` / ``up_proj`` tensors are row-sliced for this rank
          and copied into the first / second half of the fused
          ``gate_up_proj`` parameter;
        * ``embed_tokens`` / ``lm_head`` go through
          ``load_padded_tensor_parallel_vocab``;
        * everything else (including the already-narrowed ``W_pack``) goes
          through ``load_tensor_parallel_weights``.

        Args:
            model_name_or_path: Passed to ``hf_model_weights_iterator``.
            cache_dir: Passed to ``hf_model_weights_iterator``.
            load_format: Passed to ``hf_model_weights_iterator``.
            revision: Passed to ``hf_model_weights_iterator``.
        """
        tp_world_size = get_tensor_model_parallel_world_size()
        tp_rank = get_tensor_model_parallel_rank()
        state_dict = self.state_dict()

        # Head range owned by this rank. It is the same for every W_pack
        # tensor, so compute it once rather than per weight.
        total_num_heads = self.config.num_attention_heads
        hidden_size = self.config.hidden_size
        head_size = hidden_size // total_num_heads
        num_heads = total_num_heads // tp_world_size
        head_start = tp_rank * num_heads
        head_end = head_start + num_heads

        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, load_format, revision):
            if "rotary_emb.inv_freq" in name:
                continue

            loaded_weight = convert_pyslice_to_tensor(loaded_weight)

            if "W_pack" in name:
                # Separate the q/k/v and head axes so that the same head
                # range can be taken from each of the three projections.
                loaded_weight = loaded_weight.view(3, total_num_heads,
                                                   head_size, hidden_size)
                loaded_weight = loaded_weight[:, head_start:head_end, :, :]
                loaded_weight = loaded_weight.reshape(-1, hidden_size)

            is_gate_up_weight = False
            for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]):
                if weight_name not in name:
                    continue
                param = state_dict[name.replace(weight_name, "gate_up_proj")]
                # The fused parameter stacks gate (stride 0) on top of up
                # (stride 1); each half holds this rank's rows only.
                shard_size = param.shape[0] // 2
                loaded_weight = loaded_weight[shard_size * tp_rank:shard_size *
                                              (tp_rank + 1)]
                param_slice = param.data[shard_size * stride_id:shard_size *
                                         (stride_id + 1)]
                assert param_slice.shape == loaded_weight.shape
                param_slice.copy_(loaded_weight)
                is_gate_up_weight = True
                break
            if is_gate_up_weight:
                continue

            param = state_dict[name]

            if "embed_tokens" in name or "lm_head" in name:
                load_padded_tensor_parallel_vocab(param, loaded_weight,
                                                  tp_rank)
                continue

            load_tensor_parallel_weights(
                param,
                loaded_weight,
                name,
                self._column_parallel_weights,
                self._row_parallel_weights,
                tp_rank,
            )
378
379
380
381
382
383
384
385
386
387
388
389


class BaichuanForCausalLM(BaiChuanBaseForCausalLM):  # baichuan 13b
    """Causal-LM variant that builds its attention layers in ALiBi mode."""

    def __init__(self, config):
        super().__init__(config, position_embedding="ALIBI")


class BaiChuanForCausalLM(BaiChuanBaseForCausalLM):  # baichuan 7b
    """Causal-LM variant that builds its attention layers in RoPE mode."""

    def __init__(self, config):
        super().__init__(config, position_embedding="ROPE")