# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Woosuk Kwon's avatar
Woosuk Kwon committed
22
"""Inference-only BaiChuan model compatible with HuggingFace weights."""
23

24
import math
25
from collections.abc import Iterable
26
from itertools import islice
27
from typing import Optional, Union
codethazine's avatar
codethazine committed
28
29

import torch
30
from torch import nn
31
from transformers import PretrainedConfig
codethazine's avatar
codethazine committed
32

33
from vllm.attention import Attention
34
from vllm.compilation.decorators import support_torch_compile
35
from vllm.config import CacheConfig, VllmConfig
36
37
38
39
40
from vllm.distributed import (
    get_pp_group,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)
41
42
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
43
44
45
46
47
from vllm.model_executor.layers.linear import (
    MergedColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
48
from vllm.model_executor.layers.logits_processor import LogitsProcessor
49
from vllm.model_executor.layers.quantization import QuantizationConfig
50
from vllm.model_executor.layers.rotary_embedding import get_rope
51
from vllm.model_executor.layers.vocab_parallel_embedding import (
52
53
54
    ParallelLMHead,
    VocabParallelEmbedding,
)
55
from vllm.model_executor.model_loader.weight_utils import (
56
57
58
    default_weight_loader,
    row_parallel_weight_loader,
)
59
from vllm.sequence import IntermediateTensors
codethazine's avatar
codethazine committed
60

61
from .interfaces import SupportsLoRA, SupportsPP, SupportsQuant
62
63
64
65
66
67
68
from .utils import (
    AutoWeightsLoader,
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)
69

70
71

def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
72
    closest_power_of_2 = 2 ** math.floor(math.log2(total_num_heads))
73
    base = torch.tensor(
74
        2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))),
75
76
77
78
79
80
81
        dtype=torch.float32,
    )
    powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
    slopes = torch.pow(base, powers)

    if closest_power_of_2 != total_num_heads:
        extra_base = torch.tensor(
82
            2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))),
83
84
            dtype=torch.float32,
        )
85
86
87
88
89
90
91
        num_remaining_heads = min(
            closest_power_of_2, total_num_heads - closest_power_of_2
        )
        extra_powers = torch.arange(
            start=1, end=1 + 2 * num_remaining_heads, step=2, dtype=torch.int32
        )
        slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
92
93
94
95
96
97
98
99
100
    return slopes


class BaiChuanMLP(nn.Module):
    """Gated SiLU feed-forward network used inside each BaiChuan layer.

    The gate and up projections are fused into one column-parallel matmul
    producing ``2 * intermediate_size`` features; ``SiluAndMul`` combines the
    two halves and a row-parallel projection maps back to ``hidden_size``.

    Args:
        hidden_size: Model hidden dimension.
        intermediate_size: Width of each of the gate / up projections.
        hidden_act: Activation name; only ``"silu"`` is accepted.
        quant_config: Optional quantization config forwarded to both linears.

    Raises:
        ValueError: If ``hidden_act`` is anything other than ``"silu"``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        fused_widths = [intermediate_size, intermediate_size]
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            fused_widths,
            bias=False,
            quant_config=quant_config,
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
        )
        # Validated after the projections are built, matching original order.
        if hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {hidden_act}. Only silu is supported for now."
            )
        self.act_fn = SiluAndMul()

    def forward(self, x):
        fused, _ = self.gate_up_proj(x)
        projected, _ = self.down_proj(self.act_fn(fused))
        return projected


class BaiChuanAttention(nn.Module):
    """Multi-headed self-attention for BaiChuan models.

    Query, key and value come from a single fused, bias-free ``W_pack``
    projection with equal query and key/value head counts. Heads are split
    evenly across tensor-parallel ranks. Positional information is injected
    either with rotary embeddings or, when ``position_embedding == "ALIBI"``,
    with per-head ALiBi slopes (sliced to this rank's heads).

    Args:
        hidden_size: Model hidden dimension.
        num_heads: Total number of attention heads across all TP ranks.
        position_embedding: ``"ALIBI"`` for ALiBi; any other value uses RoPE.
        rope_theta: RoPE base (unused under ALiBi).
        max_position_embeddings: Max RoPE position (unused under ALiBi).
        cache_config: Engine cache config, forwarded to ``Attention`` in both
            positional-encoding modes.
        quant_config: Optional quantization config for the linear layers and
            the attention op.
        prefix: Module name prefix; ``Attention`` is registered as
            ``f"{prefix}.attn"``.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        position_embedding: str,
        rope_theta: float = 10000,
        max_position_embeddings: int = 8192,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.hidden_size = hidden_size
        tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        # Every rank must own the same whole number of heads.
        assert self.total_num_heads % tensor_model_parallel_world_size == 0
        self.num_heads = self.total_num_heads // tensor_model_parallel_world_size
        self.head_dim = hidden_size // self.total_num_heads
        self.position_embedding = position_embedding
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings
        # Same softmax scale in both modes; stored on the module either way.
        self.scaling = self.head_dim**-0.5

        # pylint: disable=invalid-name
        self.W_pack = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_heads,
            bias=False,
            quant_config=quant_config,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
        )

        if self.position_embedding == "ALIBI":
            # Compute slopes for all heads, then keep only this rank's slice.
            tp_rank = get_tensor_model_parallel_rank()
            head_start = tp_rank * self.num_heads
            head_end = (tp_rank + 1) * self.num_heads
            alibi_slopes = _get_alibi_slopes(self.total_num_heads)
            alibi_slopes = alibi_slopes[head_start:head_end].tolist()

            # cache_config must be forwarded here too; previously it was
            # dropped for ALiBi models, so engine cache settings such as the
            # KV-cache dtype were ignored for them.
            self.attn = Attention(
                self.num_heads,
                self.head_dim,
                self.scaling,
                alibi_slopes=alibi_slopes,
                cache_config=cache_config,
                quant_config=quant_config,
                prefix=f"{prefix}.attn",
            )
        else:
            self.rotary_emb = get_rope(
                self.head_dim,
                rotary_dim=self.head_dim,
                max_position=self.max_position_embeddings,
                base=self.rope_theta,
            )
            self.attn = Attention(
                self.num_heads,
                self.head_dim,
                self.scaling,
                cache_config=cache_config,
                quant_config=quant_config,
                prefix=f"{prefix}.attn",
            )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run attention over ``hidden_states``.

        ``positions`` is consumed only by the rotary embedding; under ALiBi
        it is ignored.
        """
        qkv, _ = self.W_pack(hidden_states)
        # Q, K and V have the same width (equal head counts), so an even
        # three-way split of the last dimension separates them.
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        if self.position_embedding != "ALIBI":
            q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output


class BaiChuanDecoderLayer(nn.Module):
    """One pre-norm BaiChuan transformer block (norm -> attention -> norm -> MLP).

    The residual stream is carried separately from the activations. When a
    residual is available, the two-argument form of ``RMSNorm`` is used, which
    returns the normalized activations together with the updated residual.
    The layer returns ``(hidden_states, residual)`` without adding them.

    Args:
        config: HF config; reads hidden/intermediate sizes, head count,
            activation, RMSNorm epsilon and optional ``rope_theta`` /
            ``max_position_embeddings`` (defaults 10000 / 8192).
        position_embedding: ``"ALIBI"`` or RoPE, forwarded to attention.
        cache_config: Engine cache config, forwarded to attention.
        quant_config: Optional quantization config.
        prefix: Module name prefix for this layer.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        position_embedding: str,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        hidden = config.hidden_size
        self.hidden_size = hidden
        norm_eps = config.rms_norm_eps

        self.self_attn = BaiChuanAttention(
            hidden_size=hidden,
            num_heads=config.num_attention_heads,
            position_embedding=position_embedding,
            rope_theta=getattr(config, "rope_theta", 10000),
            max_position_embeddings=getattr(config, "max_position_embeddings", 8192),
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )
        self.mlp = BaiChuanMLP(
            hidden_size=hidden,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
        )
        self.input_layernorm = RMSNorm(hidden, eps=norm_eps)
        self.post_attention_layernorm = RMSNorm(hidden, eps=norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # Attention sub-block. The first layer has no residual yet, so the
        # raw input becomes the residual and is normalized on its own.
        if residual is not None:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        else:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        attn_out = self.self_attn(positions=positions, hidden_states=hidden_states)

        # Feed-forward sub-block.
        normed, residual = self.post_attention_layernorm(attn_out, residual)
        return self.mlp(normed), residual


268
@support_torch_compile
class BaiChuanModel(nn.Module):
    """BaiChuan backbone: token embedding, decoder layers and final RMSNorm.

    Pipeline-parallel aware: only layers ``[start_layer, end_layer)`` run on
    this stage. Non-first stages read their input from
    ``intermediate_tensors``. Non-last stages return an
    ``IntermediateTensors`` instead of final hidden states.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
        position_embedding: str = "ROPE",
    ) -> None:
        super().__init__()

        hf_config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = hf_config
        self.vocab_size = hf_config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(
            hf_config.vocab_size,
            hf_config.hidden_size,
        )

        # make_layers invokes the factory with a ``prefix`` keyword argument.
        def build_layer(prefix: str) -> nn.Module:
            return BaiChuanDecoderLayer(
                hf_config,
                position_embedding,
                cache_config,
                quant_config,
                prefix=prefix,
            )

        self.start_layer, self.end_layer, self.layers = make_layers(
            hf_config.num_hidden_layers,
            build_layer,
            prefix=f"{prefix}.layers",
        )
        self.norm = RMSNorm(hf_config.hidden_size, eps=hf_config.rms_norm_eps)
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], hf_config.hidden_size
        )

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        if not get_pp_group().is_first_rank:
            # Later pipeline stages resume from the previous stage's outputs.
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        else:
            hidden_states = (
                inputs_embeds
                if inputs_embeds is not None
                else self.get_input_embeddings(input_ids)
            )
            residual = None

        for layer in islice(self.layers, self.start_layer, self.end_layer):
            hidden_states, residual = layer(positions, hidden_states, residual)

        if get_pp_group().is_last_rank:
            hidden_states, _ = self.norm(hidden_states, residual)
            return hidden_states
        return IntermediateTensors(
            {
                "hidden_states": hidden_states,
                "residual": residual,
            }
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Separate ``gate_proj`` / ``up_proj`` checkpoint tensors are loaded as
        shards 0 / 1 of the fused ``gate_up_proj`` parameter. All other
        tensors load under their own name. Returns the set of parameter names
        that were loaded.
        """
        # (fused parameter name, checkpoint sub-name, shard index)
        stacked_params_mapping = [
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        for name, loaded_weight in weights:
            # Rotary inverse-frequency buffers in the checkpoint are skipped.
            if "rotary_emb.inv_freq" in name:
                continue

            # NOTE: a ``continue`` inside this inner loop only advances to the
            # next mapping; the ``else`` clause below repeats the skip checks.
            for fused_name, ckpt_name, shard_id in stacked_params_mapping:
                if ckpt_name not in name:
                    continue
                name = name.replace(ckpt_name, fused_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                param.weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Reached when no stacked mapping loaded the tensor.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                loader = getattr(param, "weight_loader", default_weight_loader)
                loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params

374

375
class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsQuant):
    """Causal-LM wrapper shared by the BaiChuan / Baichuan2 variants.

    Combines ``BaiChuanModel`` with a vocab-parallel LM head and a logits
    processor. Subclasses choose the positional encoding (``"ROPE"`` or
    ``"ALIBI"``). LM-head weights go through ``lm_head_weight_loader``, which
    row-normalizes them for Baichuan2 checkpoints.
    """

    packed_modules_mapping = {
        "W_pack": ["W_pack"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        position_embedding: str = "ROPE",
    ):
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config

        self.config = hf_config
        self.lora_config = vllm_config.lora_config
        self.tp_size = get_tensor_model_parallel_world_size()
        self.quant_config = quant_config

        self.model = BaiChuanModel(
            vllm_config=vllm_config,
            prefix=prefix,
            position_embedding=position_embedding,
        )
        self.lm_head = ParallelLMHead(
            hf_config.vocab_size,
            hf_config.hidden_size,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        # Route LM-head checkpoint loading through the Baichuan2-aware loader.
        self.lm_head.weight.weight_loader = self.lm_head_weight_loader
        if hf_config.tie_word_embeddings:
            self.lm_head.weight = self.model.embed_tokens.weight
        self.logits_processor = LogitsProcessor(hf_config.vocab_size)
        self.make_empty_intermediate_tensors = self.model.make_empty_intermediate_tensors

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        return self.model(input_ids, positions, intermediate_tensors, inputs_embeds)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> Optional[torch.Tensor]:
        return self.logits_processor(self.lm_head, hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        return AutoWeightsLoader(self).load_weights(weights)

    def lm_head_weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
        """Load the LM-head weight, normalizing it first for Baichuan2.

        Unlike Baichuan, Baichuan2 normalizes the head weights; see
        https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat/blob/84603cde5ebffb6084e476cfaeceaf0b8b91fe54/modeling_baichuan.py#L508
        The two are told apart by vocabulary size, as suggested in
        https://github.com/vllm-project/vllm/pull/1022#discussion_r1325652704
        """
        if self.config.vocab_size == 125696:  # Baichuan2 vocabulary size
            loaded_weight = torch.nn.functional.normalize(loaded_weight)

        # With TP > 1 each rank takes its slice along dim 0.
        loader = (
            row_parallel_weight_loader if self.tp_size > 1 else default_weight_loader
        )
        loader(param, loaded_weight)
458
459


460
class BaichuanForCausalLM(BaiChuanBaseForCausalLM):
    """Baichuan 13B and Baichuan2 7B/13B.

    NOTE: the class name has a lower case 'c'.
    The positional encoding is selected from the hidden size: 4096 (Baichuan2
    7B) uses RoPE, any other size (Baichuan / Baichuan2 13B) uses ALiBi.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        hidden_size = vllm_config.model_config.hf_config.hidden_size
        position_embedding = "ROPE" if hidden_size == 4096 else "ALIBI"
        super().__init__(
            vllm_config=vllm_config,
            prefix=prefix,
            position_embedding=position_embedding,
        )
475
476


477
class BaiChuanForCausalLM(BaiChuanBaseForCausalLM):
    """Baichuan 7B, which always uses rotary position embeddings.

    NOTE: the class name has an upper case 'C'.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(
            position_embedding="ROPE",
            prefix=prefix,
            vllm_config=vllm_config,
        )