baichuan.py 18 KB
Newer Older
codethazine's avatar
codethazine committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Woosuk Kwon's avatar
Woosuk Kwon committed
20
"""Inference-only BaiChuan model compatible with HuggingFace weights."""
21
import math
22
from typing import Iterable, List, Optional, Tuple, Union
codethazine's avatar
codethazine committed
23
24

import torch
25
from torch import nn
26
from transformers import PretrainedConfig
codethazine's avatar
codethazine committed
27

28
from vllm.attention import Attention, AttentionMetadata
29
from vllm.compilation.decorators import support_torch_compile
30
from vllm.config import CacheConfig, LoRAConfig
31
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
32
                              get_tensor_model_parallel_world_size)
33
34
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
35
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
36
37
                                               QKVParallelLinear,
                                               RowParallelLinear)
38
from vllm.model_executor.layers.logits_processor import LogitsProcessor
39
from vllm.model_executor.layers.quantization import QuantizationConfig
40
from vllm.model_executor.layers.rotary_embedding import get_rope
41
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
42
from vllm.model_executor.layers.vocab_parallel_embedding import (
43
    ParallelLMHead, VocabParallelEmbedding)
44
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
45
from vllm.model_executor.sampling_metadata import SamplingMetadata
46
from vllm.sequence import IntermediateTensors
codethazine's avatar
codethazine committed
47

48
49
50
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (is_pp_missing_parameter,
                    make_empty_intermediate_tensors_factory, make_layers)
51

52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84

def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
    closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
    base = torch.tensor(
        2**(-(2**-(math.log2(closest_power_of_2) - 3))),
        dtype=torch.float32,
    )
    powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
    slopes = torch.pow(base, powers)

    if closest_power_of_2 != total_num_heads:
        extra_base = torch.tensor(
            2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
            dtype=torch.float32,
        )
        num_remaining_heads = min(closest_power_of_2,
                                  total_num_heads - closest_power_of_2)
        extra_powers = torch.arange(start=1,
                                    end=1 + 2 * num_remaining_heads,
                                    step=2,
                                    dtype=torch.int32)
        slopes = torch.cat(
            [slopes, torch.pow(extra_base, extra_powers)], dim=0)
    return slopes


class BaiChuanMLP(nn.Module):
    """Feed-forward block of a BaiChuan decoder layer.

    A fused column-parallel ``gate_up_proj`` produces both the gate and the
    up projection (each ``intermediate_size`` wide) in one matmul,
    ``SiluAndMul`` combines them, and a row-parallel ``down_proj`` maps the
    result back to ``hidden_size``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        """
        Args:
            hidden_size: Model hidden dimension.
            intermediate_size: Width of each of the gate and up projections.
            hidden_act: Activation name; only ``"silu"`` is supported.
            quant_config: Optional quantization config for both projections.

        Raises:
            ValueError: If ``hidden_act`` is not ``"silu"``.
        """
        super().__init__()
        # Reject unsupported activations before building the projection
        # layers so a bad config fails fast.
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size, [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config)
        self.down_proj = RowParallelLinear(intermediate_size,
                                           hidden_size,
                                           bias=False,
                                           quant_config=quant_config)
        self.act_fn = SiluAndMul()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply gate/up projection, SiluAndMul, then the down projection."""
        # The parallel linear layers return a 2-tuple; only the first
        # element (the projected tensor) is used here.
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


class BaiChuanAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper.

    Q, K and V come from one fused ``W_pack`` projection that uses the same
    head count for all three. Position information is injected either as
    ALiBi slopes handed to the ``Attention`` layer (when
    ``position_embedding == "ALIBI"``) or as rotary embeddings applied to
    Q and K in ``forward`` (any other value).
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        position_embedding: str,
        rope_theta: float = 10000,
        max_position_embeddings: int = 8192,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        """
        Args:
            hidden_size: Model hidden dimension.
            num_heads: Total number of attention heads across all TP ranks;
                must be divisible by the tensor-parallel world size.
            position_embedding: ``"ALIBI"`` selects ALiBi; any other value
                selects rotary embeddings.
            rope_theta: Rotary base (RoPE path only).
            max_position_embeddings: ``max_position`` for the rotary
                embedding (RoPE path only).
            cache_config: Cache settings forwarded to ``Attention`` in both
                position-embedding modes.
            quant_config: Optional quantization config for the projections
                and the attention layer.
        """
        super().__init__()
        self.hidden_size = hidden_size
        tensor_model_parallel_world_size = get_tensor_model_parallel_world_size(
        )
        self.total_num_heads = num_heads
        assert self.total_num_heads % tensor_model_parallel_world_size == 0
        self.num_heads = (self.total_num_heads //
                          tensor_model_parallel_world_size)
        self.head_dim = hidden_size // self.total_num_heads
        # NOTE: the attribute keeps its historical misspelling ("postion")
        # in case code outside this file reads it.
        self.postion_embedding = position_embedding
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings
        # Same softmax scaling for both position-embedding modes.
        self.scaling = self.head_dim**-0.5

        # pylint: disable=invalid-name
        self.W_pack = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_heads,
            bias=False,
            quant_config=quant_config,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
        )

        if self.postion_embedding == "ALIBI":
            # Build slopes for all heads, then keep the slice
            # [rank * num_heads, (rank + 1) * num_heads) for this TP rank.
            tp_rank = get_tensor_model_parallel_rank()
            head_start = tp_rank * self.num_heads
            head_end = (tp_rank + 1) * self.num_heads
            alibi_slopes = _get_alibi_slopes(self.total_num_heads)
            alibi_slopes = alibi_slopes[head_start:head_end].tolist()

            # cache_config is passed here as in the RoPE branch; previously
            # the ALiBi path dropped it and silently ignored the cache config.
            self.attn = Attention(self.num_heads,
                                  self.head_dim,
                                  self.scaling,
                                  alibi_slopes=alibi_slopes,
                                  cache_config=cache_config,
                                  quant_config=quant_config)
        else:
            self.rotary_emb = get_rope(
                self.head_dim,
                rotary_dim=self.head_dim,
                max_position=self.max_position_embeddings,
                base=self.rope_theta,
            )
            self.attn = Attention(self.num_heads,
                                  self.head_dim,
                                  self.scaling,
                                  cache_config=cache_config,
                                  quant_config=quant_config)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Project to Q/K/V, apply positions, attend, and project back."""
        qkv, _ = self.W_pack(hidden_states)
        # W_pack was built with equal Q and KV head counts, so an even
        # three-way split along the last dim recovers q, k and v.
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        if self.postion_embedding != "ALIBI":
            # In ALiBi mode the slopes were handed to self.attn in __init__;
            # only the RoPE variant rotates q/k here.
            q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.o_proj(attn_output)
        return output


class BaiChuanDecoderLayer(nn.Module):
    """One BaiChuan transformer block.

    The order is input RMSNorm, then self-attention, then post-attention
    RMSNorm, then MLP. The residual is threaded through explicitly: when
    ``RMSNorm`` is called with a residual it returns ``(hidden_states,
    residual)``, and the layer relies on that to carry the residual stream.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 position_embedding: str,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None):
        """
        Args:
            config: HF config. Reads ``hidden_size``, ``num_attention_heads``,
                ``intermediate_size``, ``hidden_act`` and ``rms_norm_eps``.
                Optionally reads ``rope_theta`` (default 10000) and
                ``max_position_embeddings`` (default 8192).
            position_embedding: ``"ALIBI"`` or a RoPE marker, forwarded to
                the attention module.
            cache_config: Forwarded to the attention module.
            quant_config: Forwarded to the attention and MLP modules.
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        self.self_attn = BaiChuanAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            position_embedding=position_embedding,
            rope_theta=rope_theta,
            max_position_embeddings=max_position_embeddings,
            cache_config=cache_config,
            quant_config=quant_config,
        )
        self.mlp = BaiChuanMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
        )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run one block.

        Returns:
            ``(hidden_states, residual)``. The MLP output and the residual
            are returned separately; the caller (the next layer's input
            norm, or the model's final norm) combines them.
        """
        # Self Attention
        if residual is None:
            # First layer on this stage: the raw input becomes the residual.
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual


254
@support_torch_compile
class BaiChuanModel(nn.Module):
    """BaiChuan decoder stack for one pipeline stage.

    The first stage embeds token ids. Later stages resume from the
    ``IntermediateTensors`` handed over by the previous stage. Each stage
    runs only the layers in ``[start_layer, end_layer)`` returned by
    ``make_layers``. The last stage applies the final RMSNorm.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 position_embedding: str,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        """
        Args:
            config: HF config. Reads ``pad_token_id``, ``vocab_size``,
                ``hidden_size``, ``num_hidden_layers`` and ``rms_norm_eps``,
                and is passed to every decoder layer.
            position_embedding: ``"ALIBI"`` or a RoPE marker for all layers.
            cache_config: Forwarded to every decoder layer.
            quant_config: Forwarded to every decoder layer.
            prefix: Module-name prefix; layers are named under
                ``f"{prefix}.layers"``.
        """
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: BaiChuanDecoderLayer(config, position_embedding,
                                                cache_config, quant_config),
            prefix=f"{prefix}.layers",
        )
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        # Factory used to allocate the tensors exchanged between
        # pipeline stages.
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))

    # NOTE: keep the parameter annotations as they are; the
    # support_torch_compile decorator may inspect this signature.
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this stage's layers.

        Args:
            kv_caches: One cache per layer on this stage, indexed from 0 at
                ``start_layer``.
            intermediate_tensors: Required on every stage except the first.

        Returns:
            The final normalized hidden states on the last stage; otherwise
            ``IntermediateTensors`` with ``hidden_states`` and ``residual``
            for the next stage.
        """
        if get_pp_group().is_first_rank:
            hidden_states = self.embed_tokens(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            hidden_states, residual = layer(
                positions,
                hidden_states,
                kv_caches[i - self.start_layer],
                attn_metadata,
                residual,
            )
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual,
            })
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


316
class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    """Shared causal-LM wrapper for the Baichuan / Baichuan2 families.

    Wraps a ``BaiChuanModel`` with an LM head, logits processor and sampler,
    and implements checkpoint loading. Subclasses only choose the
    position-embedding scheme.
    """

    packed_modules_mapping = {
        "W_pack": ["W_pack"],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }
    # LoRA specific attributes
    supported_lora_modules = [
        "W_pack",
        "o_proj",
        "gate_up_proj",
        "down_proj",
    ]
    embedding_modules = {}
    embedding_padding_modules = []

    def __init__(
        self,
        config: PretrainedConfig,
        position_embedding: str,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
    ):
        """
        Args:
            config: HF config. Reads ``vocab_size``, ``hidden_size`` and
                ``tie_word_embeddings``, and is passed to the model.
            position_embedding: ``"ALIBI"`` or ``"ROPE"``.
            cache_config: Forwarded to the decoder stack.
            quant_config: Forwarded to the decoder stack and the LM head.
            lora_config: Stored on the instance as ``self.lora_config``.
        """
        super().__init__()

        self.config = config
        self.lora_config = lora_config

        self.quant_config = quant_config
        self.model = BaiChuanModel(config, position_embedding, cache_config,
                                   quant_config)
        self.lm_head = ParallelLMHead(config.vocab_size,
                                      config.hidden_size,
                                      quant_config=quant_config)
        if self.config.tie_word_embeddings:
            # Share one tensor between the input embedding and the LM head.
            self.lm_head.weight = self.model.embed_tokens.weight
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the decoder stack. Logits are computed in ``compute_logits``."""
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata, intermediate_tensors)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to vocabulary logits via the LM head."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits``."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint tensors into this model's parameters.

        Checkpoint ``gate_proj`` / ``up_proj`` tensors go into the fused
        ``gate_up_proj`` as shards 0 and 1. Every other tensor is loaded by
        name with the parameter's ``weight_loader`` (falling back to
        ``default_weight_loader``). The following are skipped:

        * names containing ``rotary_emb.inv_freq``;
        * ``.bias`` tensors with no matching parameter (extra GPTQ biases);
        * names for which ``is_pp_missing_parameter`` is true (presumably
          parameters owned by another pipeline stage).

        Args:
            weights: Iterable of ``(checkpoint_name, tensor)`` pairs.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        # Unlike Baichuan, Baichuan2 normalizes the head weights.
        # Refer to:
        # https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat/blob/84603cde5ebffb6084e476cfaeceaf0b8b91fe54/modeling_baichuan.py#L508
        # Distinguish between Baichuan and Baichuan2 by checking the
        # vocab size. This is suggested by
        # https://github.com/vllm-project/vllm/pull/1022#discussion_r1325652704
        # The check depends only on the config, so compute it once.
        is_baichuan2 = self.config.vocab_size == 125696
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if name == "lm_head.weight" and is_baichuan2:
                loaded_weight = torch.nn.functional.normalize(loaded_weight)

            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Not a stacked shard: load the tensor under its own name.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
434
435


436
class BaichuanForCausalLM(BaiChuanBaseForCausalLM):
    """Baichuan 13B and Baichuan2 7B/13B.
    NOTE: the class name has a lower case 'c'.

    The position-embedding scheme is picked from ``config.hidden_size``:
    4096 (Baichuan2 7B) uses RoPE; any other size (Baichuan 13B,
    Baichuan2 13B) uses ALiBi.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
    ):
        # hidden_size 4096 -> baichuan2 7b (RoPE);
        # otherwise baichuan 13b / baichuan2 13b (ALiBi).
        position_embedding = "ROPE" if config.hidden_size == 4096 else "ALIBI"
        super().__init__(config, position_embedding, cache_config,
                         quant_config, lora_config)
454
455


456
class BaiChuanForCausalLM(BaiChuanBaseForCausalLM):
    """Baichuan 7B.
    NOTE: the class name has an upper case 'C'.

    Always uses rotary position embeddings.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
    ):
        super().__init__(config, "ROPE", cache_config, quant_config,
                         lora_config)