baichuan.py 18.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

codethazine's avatar
codethazine committed
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Woosuk Kwon's avatar
Woosuk Kwon committed
22
"""Inference-only BaiChuan model compatible with HuggingFace weights."""
23
import math
24
from collections.abc import Iterable
25
from itertools import islice
26
from typing import Optional, Union
codethazine's avatar
codethazine committed
27
28

import torch
29
from torch import nn
30
from transformers import PretrainedConfig
codethazine's avatar
codethazine committed
31

32
from vllm.attention import Attention
33
from vllm.compilation.decorators import support_torch_compile
34
from vllm.config import CacheConfig, VllmConfig
35
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
36
                              get_tensor_model_parallel_world_size)
37
38
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
39
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
40
41
                                               QKVParallelLinear,
                                               RowParallelLinear)
42
from vllm.model_executor.layers.logits_processor import LogitsProcessor
43
from vllm.model_executor.layers.quantization import QuantizationConfig
44
from vllm.model_executor.layers.rotary_embedding import get_rope
45
from vllm.model_executor.layers.vocab_parallel_embedding import (
46
    ParallelLMHead, VocabParallelEmbedding)
47
48
from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader, row_parallel_weight_loader)
49
from vllm.sequence import IntermediateTensors
codethazine's avatar
codethazine committed
50

51
from .interfaces import SupportsLoRA, SupportsPP, SupportsQuant
52
from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
53
54
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)
55

56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88

def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
    closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
    base = torch.tensor(
        2**(-(2**-(math.log2(closest_power_of_2) - 3))),
        dtype=torch.float32,
    )
    powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
    slopes = torch.pow(base, powers)

    if closest_power_of_2 != total_num_heads:
        extra_base = torch.tensor(
            2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
            dtype=torch.float32,
        )
        num_remaining_heads = min(closest_power_of_2,
                                  total_num_heads - closest_power_of_2)
        extra_powers = torch.arange(start=1,
                                    end=1 + 2 * num_remaining_heads,
                                    step=2,
                                    dtype=torch.int32)
        slopes = torch.cat(
            [slopes, torch.pow(extra_base, extra_powers)], dim=0)
    return slopes


class BaiChuanMLP(nn.Module):
    """Gated feed-forward block of a Baichuan decoder layer.

    The gate and up projections are fused into one merged column-parallel
    linear layer; its output goes through ``SiluAndMul`` and is then mapped
    back to ``hidden_size`` by a row-parallel down projection.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        """
        Args:
            hidden_size: Input and output feature size.
            intermediate_size: Output size of each of the gate and up
                projections.
            hidden_act: Activation name; only ``"silu"`` is accepted.
            quant_config: Optional quantization config forwarded to both
                linear layers.

        Raises:
            ValueError: If ``hidden_act`` is not ``"silu"``.
        """
        super().__init__()
        # Two equally sized shards: one for gate_proj, one for up_proj.
        merged_output_sizes = [intermediate_size] * 2
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            merged_output_sizes,
            bias=False,
            quant_config=quant_config,
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
        )
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()

    def forward(self, x):
        """Apply gate/up projection, activation and down projection to x."""
        # Both linear layers return a pair; only the first element is used.
        fused, _ = self.gate_up_proj(x)
        activated = self.act_fn(fused)
        projected, _ = self.down_proj(activated)
        return projected


class BaiChuanAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper.

    ``W_pack`` is a fused projection whose output is split into Q, K and V;
    ``o_proj`` maps the attention output back to ``hidden_size``. The heads
    are divided evenly across tensor-parallel ranks. Position information
    comes from ALiBi slopes when ``position_embedding == "ALIBI"`` and from
    rotary embeddings for any other value.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        position_embedding: str,
        rope_theta: float = 10000,
        max_position_embeddings: int = 8192,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """
        Args:
            hidden_size: Model hidden size.
            num_heads: Total number of attention heads over all
                tensor-parallel ranks; must be divisible by the
                tensor-parallel world size.
            position_embedding: ``"ALIBI"`` selects ALiBi; anything else
                selects rotary embeddings.
            rope_theta: Base passed to ``get_rope`` (rotary path only).
            max_position_embeddings: Maximum position passed to ``get_rope``
                (rotary path only).
            cache_config: Optional cache config forwarded to the attention
                layer.
            quant_config: Optional quantization config forwarded to the
                linear layers and the attention layer.
            prefix: Module-name prefix; the attention layer is registered
                as ``f"{prefix}.attn"``.
        """
        super().__init__()
        self.hidden_size = hidden_size
        tensor_model_parallel_world_size = get_tensor_model_parallel_world_size(
        )
        self.total_num_heads = num_heads
        assert self.total_num_heads % tensor_model_parallel_world_size == 0
        self.num_heads = (self.total_num_heads //
                          tensor_model_parallel_world_size)
        self.head_dim = hidden_size // self.total_num_heads
        self.position_embedding = position_embedding
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings
        # Softmax scale, identical for both position-embedding variants.
        self.scaling = self.head_dim**-0.5

        # pylint: disable=invalid-name
        self.W_pack = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_heads,
            bias=False,
            quant_config=quant_config,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
        )
        if self.position_embedding == "ALIBI":
            # Create the alibi slopes and slice them down to the heads
            # owned by this tensor-parallel rank.
            tp_rank = get_tensor_model_parallel_rank()
            head_start = tp_rank * self.num_heads
            head_end = (tp_rank + 1) * self.num_heads
            alibi_slopes = _get_alibi_slopes(self.total_num_heads)
            alibi_slopes = alibi_slopes[head_start:head_end].tolist()

            # cache_config is forwarded here as well so that the attention
            # layer receives the same cache settings on the ALiBi path as
            # on the rotary path.
            self.attn = Attention(self.num_heads,
                                  self.head_dim,
                                  self.scaling,
                                  alibi_slopes=alibi_slopes,
                                  cache_config=cache_config,
                                  quant_config=quant_config,
                                  prefix=f"{prefix}.attn")
        else:
            self.rotary_emb = get_rope(
                self.head_dim,
                rotary_dim=self.head_dim,
                max_position=self.max_position_embeddings,
                base=self.rope_theta,
            )
            self.attn = Attention(self.num_heads,
                                  self.head_dim,
                                  self.scaling,
                                  cache_config=cache_config,
                                  quant_config=quant_config,
                                  prefix=f"{prefix}.attn")

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run self-attention over ``hidden_states``.

        Args:
            positions: Token positions; only consumed by the rotary
                embedding (ignored on the ALiBi path).
            hidden_states: Input activations fed to ``W_pack``.

        Returns:
            The output of ``o_proj`` applied to the attention result.
        """
        qkv, _ = self.W_pack(hidden_states)
        # The fused projection is split into three equal parts on the last
        # dimension.
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        if self.position_embedding != "ALIBI":
            q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output


class BaiChuanDecoderLayer(nn.Module):
    """One decoder layer: RMSNorm -> self-attention -> RMSNorm -> MLP.

    The residual stream is carried alongside the hidden states and handed
    to the ``RMSNorm`` layers together with them.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 position_embedding: str,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        """
        Args:
            config: HuggingFace model config. ``rope_theta`` and
                ``max_position_embeddings`` fall back to 10000 and 8192
                when the config does not define them.
            position_embedding: Forwarded to ``BaiChuanAttention``.
            cache_config: Optional cache config for the attention layer.
            quant_config: Optional quantization config for attention and
                MLP.
            prefix: Module-name prefix of this layer.
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        norm_eps = config.rms_norm_eps

        self.self_attn = BaiChuanAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            position_embedding=position_embedding,
            rope_theta=getattr(config, "rope_theta", 10000),
            max_position_embeddings=getattr(config,
                                            "max_position_embeddings", 8192),
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )
        self.mlp = BaiChuanMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
        )
        self.input_layernorm = RMSNorm(config.hidden_size, eps=norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply the layer and return ``(hidden_states, residual)``.

        Args:
            positions: Token positions forwarded to the attention block.
            hidden_states: Output of the previous layer (or the embeddings).
            residual: Residual stream from the previous layer, or ``None``
                when there is none yet.
        """
        # Self Attention
        if residual is not None:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        else:
            # No residual yet: the un-normalized input becomes the residual.
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        attn_out = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
        )

        # Fully Connected
        normed, residual = self.post_attention_layernorm(attn_out, residual)
        return self.mlp(normed), residual


257
@support_torch_compile
class BaiChuanModel(nn.Module):
    """Baichuan backbone: token embedding, decoder layers and final RMSNorm.

    ``make_layers`` returns the layer list together with the
    ``[start_layer, end_layer)`` range that ``forward`` iterates over.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
        position_embedding: str = "ROPE",
    ) -> None:
        """
        Args:
            vllm_config: Supplies the HF model config, cache config and
                quantization config.
            prefix: Module-name prefix; layers live under
                ``f"{prefix}.layers"``.
            position_embedding: Forwarded to every decoder layer.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )

        def build_layer(prefix: str) -> BaiChuanDecoderLayer:
            # The parameter keeps the name ``prefix`` of the lambda it
            # replaces, in case the factory is invoked by keyword.
            return BaiChuanDecoderLayer(config,
                                        position_embedding,
                                        cache_config,
                                        quant_config,
                                        prefix=prefix)

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            build_layer,
            prefix=f"{prefix}.layers",
        )
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up the token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the decoder layers in ``[start_layer, end_layer)``.

        On the first pipeline rank the input is ``inputs_embeds`` if given,
        otherwise the embedding of ``input_ids``; other ranks read
        ``hidden_states`` and ``residual`` from ``intermediate_tensors``.
        Ranks other than the last return an ``IntermediateTensors`` carrying
        both tensors; the last rank returns the normalized hidden states.
        """
        if get_pp_group().is_first_rank:
            hidden_states = (inputs_embeds if inputs_embeds is not None else
                             self.get_input_embeddings(input_ids))
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        for decoder_layer in islice(self.layers, self.start_layer,
                                    self.end_layer):
            hidden_states, residual = decoder_layer(positions, hidden_states,
                                                    residual)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual,
            })
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Copy checkpoint tensors into this module's parameters.

        ``gate_proj`` and ``up_proj`` checkpoint tensors are loaded as shards
        0 and 1 of the fused ``gate_up_proj`` parameter. Entries containing
        ``rotary_emb.inv_freq``, ``.bias`` entries without a matching
        parameter, and parameters reported missing by
        ``is_pp_missing_parameter`` are skipped.

        Returns:
            The (possibly remapped) names of the parameters that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        def should_skip(param_name: str) -> bool:
            # Extra bias tensors (e.g. from GPTQ checkpoints) have no
            # matching parameter; pipeline-missing parameters are not
            # loaded here either.
            if param_name.endswith(".bias") and param_name not in params_dict:
                return True
            return is_pp_missing_parameter(param_name, self)

        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue

            for fused_name, shard_name, shard_id in stacked_params_mapping:
                if shard_name not in name:
                    continue
                name = name.replace(shard_name, fused_name)
                if should_skip(name):
                    # Moves on to the next mapping entry, not the next
                    # checkpoint tensor.
                    continue
                target = params_dict[name]
                target.weight_loader(target, loaded_weight, shard_id)
                break
            else:
                # No stacked mapping consumed this tensor: load it directly.
                if should_skip(name):
                    continue
                target = params_dict[name]
                load_fn = getattr(target, "weight_loader",
                                  default_weight_loader)
                load_fn(target, loaded_weight)
            loaded_params.add(name)
        return loaded_params

366

367
368
class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
                              SupportsQuant):
    """Common causal-LM wrapper: a ``BaiChuanModel`` plus an LM head.

    Subclasses only choose the ``position_embedding`` variant.
    """

    # Fused module name -> checkpoint module names packed into it.
    packed_modules_mapping = {
        "W_pack": ["W_pack"],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        position_embedding: str = "ROPE",
    ):
        """
        Args:
            vllm_config: Supplies the HF model config, quantization config
                and LoRA config.
            prefix: Module-name prefix.
            position_embedding: Forwarded to ``BaiChuanModel``.
        """
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config

        self.config = hf_config
        self.lora_config = vllm_config.lora_config
        self.tp_size = get_tensor_model_parallel_world_size()
        self.quant_config = quant_config

        self.model = BaiChuanModel(
            vllm_config=vllm_config,
            prefix=prefix,
            position_embedding=position_embedding,
        )
        self.lm_head = ParallelLMHead(
            hf_config.vocab_size,
            hf_config.hidden_size,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        # Install the custom head loader (see lm_head_weight_loader).
        self.lm_head.weight.weight_loader = self.lm_head_weight_loader
        if self.config.tie_word_embeddings:
            self.lm_head.weight = self.model.embed_tokens.weight
        self.logits_processor = LogitsProcessor(hf_config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate the embedding lookup to the wrapped model."""
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the wrapped model and return its output unchanged."""
        return self.model(input_ids, positions, intermediate_tensors,
                          inputs_embeds)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> Optional[torch.Tensor]:
        """Feed ``hidden_states`` and the LM head to the logits processor."""
        return self.logits_processor(self.lm_head, hidden_states)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load all weights through an ``AutoWeightsLoader`` built on self."""
        return AutoWeightsLoader(self).load_weights(weights)

    def lm_head_weight_loader(self, param: nn.Parameter,
                              loaded_weight: torch.Tensor):
        """Load the LM head weight, normalizing it first for Baichuan2."""
        # Unlike Baichuan, Baichuan2 normalizes the head weights.
        # Refer to:
        # https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat/blob/84603cde5ebffb6084e476cfaeceaf0b8b91fe54/modeling_baichuan.py#L508
        # Distinguish between Baichuan and Baichuan2 by checking the
        # vocab size. This is suggested by
        # https://github.com/vllm-project/vllm/pull/1022#discussion_r1325652704
        if self.config.vocab_size == 125696:
            loaded_weight = torch.nn.functional.normalize(loaded_weight)

        # With more than one tensor-parallel rank the row-parallel loader is
        # used; otherwise the default loader.
        load_fn = (row_parallel_weight_loader
                   if self.tp_size > 1 else default_weight_loader)
        load_fn(param, loaded_weight)
447
448


449
class BaichuanForCausalLM(BaiChuanBaseForCausalLM):
    """Baichuan 13B and Baichuan2 7B/13B.
    NOTE: the class name has a lower case 'c'.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Select the position embedding from the configured hidden size."""
        hf_config = vllm_config.model_config.hf_config
        # hidden_size 4096 -> baichuan2 7b (ROPE);
        # anything else -> baichuan 13b / baichuan2 13b (ALIBI).
        is_baichuan2_7b = hf_config.hidden_size == 4096
        super().__init__(
            vllm_config=vllm_config,
            prefix=prefix,
            position_embedding="ROPE" if is_baichuan2_7b else "ALIBI",
        )
464
465


466
class BaiChuanForCausalLM(BaiChuanBaseForCausalLM):
    """Baichuan 7B.
    NOTE: the class name has an upper case 'C'.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the base model with rotary position embeddings."""
        base_kwargs = {
            "vllm_config": vllm_config,
            "prefix": prefix,
            "position_embedding": "ROPE",
        }
        super().__init__(**base_kwargs)