baichuan.py 15.2 KB
Newer Older
codethazine's avatar
codethazine committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Woosuk Kwon's avatar
Woosuk Kwon committed
20
"""Inference-only BaiChuan model compatible with HuggingFace weights."""
21
import math
22
from typing import Iterable, List, Optional, Tuple
codethazine's avatar
codethazine committed
23
24

import torch
25
from torch import nn
26
from transformers import PretrainedConfig
codethazine's avatar
codethazine committed
27

28
from vllm.attention import Attention, AttentionMetadata
29
from vllm.config import LoRAConfig
30
31
from vllm.distributed import (get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size)
32
33
34
35
36
37
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
                                               MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
38
from vllm.model_executor.layers.logits_processor import LogitsProcessor
39
from vllm.model_executor.layers.rotary_embedding import get_rope
40
41
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
42
    ParallelLMHead, VocabParallelEmbedding)
43
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
44
45
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput
codethazine's avatar
codethazine committed
46

47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151

def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
    """Return one ALiBi slope per attention head.

    Let ``p = 2**floor(log2(total_num_heads))``. The first ``p`` heads get the
    geometric sequence ``r**1 .. r**p`` with ratio ``r = 2**(-8 / p)``. When
    ``total_num_heads`` is not a power of two, the remaining heads get the
    odd-exponent terms ``R**1, R**3, ...`` of the sequence built for ``2 * p``
    (ratio ``R = 2**(-8 / (2 * p))``), appended after the first ``p`` slopes.

    Args:
        total_num_heads: Number of attention heads summed over all
            tensor-parallel ranks.

    Returns:
        A 1-D float32 tensor of length ``total_num_heads``.
    """

    def _ratio(num_heads_pow2: int) -> torch.Tensor:
        # Equivalent to 2 ** (-8 / num_heads_pow2); the expression is kept in
        # this exact form so the resulting float32 values are unchanged.
        exponent = -(2**-(math.log2(num_heads_pow2) - 3))
        return torch.tensor(2**exponent, dtype=torch.float32)

    pow2_heads = 2**math.floor(math.log2(total_num_heads))
    primary = torch.pow(
        _ratio(pow2_heads),
        torch.arange(1, pow2_heads + 1, dtype=torch.int32),
    )

    leftover = total_num_heads - pow2_heads
    if leftover == 0:
        return primary

    # Remaining heads take the odd-exponent terms of the 2 * pow2_heads series.
    extra_count = min(pow2_heads, leftover)
    odd_exponents = torch.arange(start=1,
                                 end=1 + 2 * extra_count,
                                 step=2,
                                 dtype=torch.int32)
    extra = torch.pow(_ratio(2 * pow2_heads), odd_exponents)
    return torch.cat([primary, extra], dim=0)


class BaiChuanMLP(nn.Module):
    """Feed-forward block of a BaiChuan decoder layer.

    The gate and up projections are produced by one fused
    ``MergedColumnParallelLinear`` (``gate_up_proj``). Its output goes through
    ``SiluAndMul`` and then the row-parallel ``down_proj``, which maps back to
    ``hidden_size``. Only the ``"silu"`` activation is accepted.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        # Two equally sized output halves: one for the gate, one for "up".
        fused_output_sizes = [intermediate_size, intermediate_size]
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            fused_output_sizes,
            bias=False,
            linear_method=linear_method,
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            linear_method=linear_method,
        )
        is_silu = hidden_act == "silu"
        if not is_silu:
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()

    def forward(self, x):
        """Project ``x`` up, activate, and project back to ``hidden_size``."""
        fused, _unused = self.gate_up_proj(x)
        activated = self.act_fn(fused)
        projected, _unused = self.down_proj(activated)
        return projected


class BaiChuanAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper.

    Q, K and V come from one fused projection (``W_pack``) that uses the same
    head count for queries and keys/values. Position information is added in
    one of two ways:

    * ``position_embedding == "ALIBI"``: per-head ALiBi slopes (this rank's
      slice) are handed to the ``Attention`` backend.
    * any other value: rotary embeddings (``get_rope``) are applied to Q and K
      in ``forward``.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        position_embedding: str,
        rope_theta: float = 10000,
        max_position_embeddings: int = 8192,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0, (
            f"num_heads ({num_heads}) must be divisible by the tensor "
            f"parallel world size ({tp_size})")
        # Number of heads owned by this tensor-parallel rank.
        self.num_heads = self.total_num_heads // tp_size
        self.head_dim = hidden_size // self.total_num_heads
        self.position_embedding = position_embedding
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings
        # Softmax scale; identical for the ALiBi and rotary variants, so it is
        # computed once and stored for both.
        self.scaling = self.head_dim**-0.5

        # pylint: disable=invalid-name
        # Fused QKV projection; the same head count is given for Q and KV.
        self.W_pack = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_heads,
            bias=False,
            linear_method=linear_method,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            linear_method=linear_method,
        )

        if self.position_embedding == "ALIBI":
            # Build slopes for every head, then keep only this rank's slice.
            tp_rank = get_tensor_model_parallel_rank()
            head_start = tp_rank * self.num_heads
            head_end = head_start + self.num_heads
            alibi_slopes = _get_alibi_slopes(self.total_num_heads)
            alibi_slopes = alibi_slopes[head_start:head_end].tolist()
            self.attn = Attention(self.num_heads,
                                  self.head_dim,
                                  self.scaling,
                                  alibi_slopes=alibi_slopes)
        else:
            self.rotary_emb = get_rope(
                self.head_dim,
                rotary_dim=self.head_dim,
                max_position=self.max_position_embeddings,
                base=self.rope_theta,
            )
            self.attn = Attention(self.num_heads, self.head_dim, self.scaling)

    @property
    def postion_embedding(self) -> str:
        """Backward-compatible alias for the old misspelled attribute name."""
        return self.position_embedding

    @postion_embedding.setter
    def postion_embedding(self, value: str) -> None:
        self.position_embedding = value

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Compute attention for ``hidden_states`` and project the result.

        ``positions`` is only used when rotary embeddings are active; with
        ALiBi the slopes were already given to the attention backend.
        """
        qkv, _ = self.W_pack(hidden_states)
        # Q, K and V have equal width because KV heads == Q heads here.
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        if self.position_embedding != "ALIBI":
            q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.o_proj(attn_output)
        return output


class BaiChuanDecoderLayer(nn.Module):
    """Pre-norm decoder block: RMSNorm -> self-attention -> RMSNorm -> MLP.

    Each call returns ``(hidden_states, residual)`` so the residual stream can
    be handed to the next layer. The norms are called with the residual as a
    second argument (presumably a fused residual-add + RMSNorm in vLLM —
    confirm against ``RMSNorm``).
    """

    def __init__(self,
                 config: PretrainedConfig,
                 position_embedding: str,
                 linear_method: Optional[LinearMethodBase] = None):
        super().__init__()
        width = config.hidden_size
        eps = config.rms_norm_eps
        self.hidden_size = width
        self.self_attn = BaiChuanAttention(
            hidden_size=width,
            num_heads=config.num_attention_heads,
            position_embedding=position_embedding,
            # Fall back to these defaults when the config omits the fields.
            rope_theta=getattr(config, "rope_theta", 10000),
            max_position_embeddings=getattr(config,
                                            "max_position_embeddings",
                                            8192),
            linear_method=linear_method,
        )
        self.mlp = BaiChuanMLP(
            hidden_size=width,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            linear_method=linear_method,
        )
        self.input_layernorm = RMSNorm(width, eps=eps)
        self.post_attention_layernorm = RMSNorm(width, eps=eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run attention and MLP; return ``(mlp_output, residual)``."""
        if residual is not None:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        else:
            # No residual yet (BaiChuanModel passes None to its first layer):
            # the raw input becomes the residual.
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)

        attn_out = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        normed, residual = self.post_attention_layernorm(attn_out, residual)
        return self.mlp(normed), residual


class BaiChuanModel(nn.Module):
    """Token embedding, a stack of decoder layers and a final RMSNorm.

    Produces final hidden states only; the LM head lives in the causal-LM
    wrapper class.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 position_embedding: str,
                 linear_method: Optional[LinearMethodBase] = None):
        super().__init__()
        self.config = config
        # Kept as a public attribute; not passed to the embedding below.
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        self.layers = nn.ModuleList([
            BaiChuanDecoderLayer(config, position_embedding, linear_method)
            for _ in range(config.num_hidden_layers)
        ])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Embed ``input_ids``, run every decoder layer and apply the final norm.

        Args:
            input_ids: Token ids to embed.
            positions: Token positions, forwarded to each layer.
            kv_caches: One KV-cache tensor per decoder layer; entry ``i`` is
                used by ``self.layers[i]``.
            attn_metadata: Attention metadata shared by all layers.

        Returns:
            The normalized hidden states after the last layer.
        """
        hidden_states = self.embed_tokens(input_ids)
        # No residual exists before the first layer; each layer returns the
        # updated residual for the next one.
        residual = None
        for layer_idx, layer in enumerate(self.layers):
            hidden_states, residual = layer(
                positions,
                hidden_states,
                kv_caches[layer_idx],
                attn_metadata,
                residual,
            )
        # The final norm also receives the running residual (same two-argument
        # form used inside each layer).
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


class BaiChuanBaseForCausalLM(nn.Module):
    """Causal-LM wrapper shared by the Baichuan and Baichuan2 model classes.

    Combines ``BaiChuanModel`` with a ``ParallelLMHead``, a
    ``LogitsProcessor`` and a ``Sampler``. ``position_embedding`` is passed
    straight through to the decoder: ``"ALIBI"`` selects ALiBi slopes; any
    other value (the subclasses in this file use ``"ROPE"``) selects rotary
    embeddings.
    """

    packed_modules_mapping = {
        "W_pack": ["W_pack"],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }
    # LoRA specific attributes
    supported_lora_modules = [
        "W_pack",
        "o_proj",
        "gate_up_proj",
        "down_proj",
    ]
    embedding_modules = {}
    embedding_padding_modules = []

    # Baichuan2 checkpoints are distinguished from Baichuan ones by their
    # vocabulary size, as suggested by
    # https://github.com/vllm-project/vllm/pull/1022#discussion_r1325652704
    _BAICHUAN2_VOCAB_SIZE = 125696

    def __init__(
        self,
        config,
        position_embedding: str,
        linear_method: Optional[LinearMethodBase] = None,
        lora_config: Optional[LoRAConfig] = None,
    ):
        """Build the decoder stack and the output head.

        Args:
            config: HuggingFace model config; ``vocab_size`` and
                ``hidden_size`` are read here, other fields by
                ``BaiChuanModel``.
            position_embedding: ``"ALIBI"`` for ALiBi, otherwise rotary.
            linear_method: Optional ``LinearMethodBase`` forwarded to the
                decoder's linear layers.
            lora_config: Accepted but not used by this constructor.
        """
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        self.model = BaiChuanModel(config, position_embedding, linear_method)
        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Return the decoder's final hidden states.

        Logits are computed separately by ``compute_logits``.
        """
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Project hidden states to vocabulary logits using the LM head weight."""
        logits = self.logits_processor(self.lm_head.weight, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Pick next tokens from ``logits`` with the configured ``Sampler``."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Copy checkpoint tensors into this model's parameters.

        Separate ``gate_proj``/``up_proj`` checkpoint tensors are loaded as
        shards 0/1 of the fused ``gate_up_proj`` parameter. All other tensors
        are loaded by name, using the parameter's ``weight_loader`` when it
        has one.

        Args:
            weights: ``(name, tensor)`` pairs with checkpoint parameter names.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        # Unlike Baichuan, Baichuan2 normalizes the head weights.
        # Refer to:
        # https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat/blob/84603cde5ebffb6084e476cfaeceaf0b8b91fe54/modeling_baichuan.py#L508
        # The model family is fixed by the config, so decide once up front.
        is_baichuan2 = self.config.vocab_size == self._BAICHUAN2_VOCAB_SIZE
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if name == "lm_head.weight" and is_baichuan2:
                loaded_weight = torch.nn.functional.normalize(loaded_weight)

            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
383
384


385
386
class BaichuanForCausalLM(BaiChuanBaseForCausalLM):
    """Baichuan 13B and Baichuan2 7B/13B.

    The position-embedding scheme is chosen from ``config.hidden_size``:
    4096 (Baichuan2 7B) uses rotary embeddings; every other size (the 13B
    models) uses ALiBi.
    """

    def __init__(
        self,
        config,
        linear_method: Optional[LinearMethodBase] = None,
        lora_config: Optional[LoRAConfig] = None,
    ):
        is_baichuan2_7b = config.hidden_size == 4096
        position_embedding = "ROPE" if is_baichuan2_7b else "ALIBI"
        super().__init__(config, position_embedding, linear_method,
                         lora_config)
398
399


400
401
class BaiChuanForCausalLM(BaiChuanBaseForCausalLM):
    """Baichuan 7B.

    Always uses rotary position embeddings (``"ROPE"``).
    """

    def __init__(
        self,
        config,
        linear_method: Optional[LinearMethodBase] = None,
        lora_config: Optional[LoRAConfig] = None,
    ):
        super().__init__(config,
                         position_embedding="ROPE",
                         linear_method=linear_method,
                         lora_config=lora_config)