baichuan.py 17.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

codethazine's avatar
codethazine committed
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Woosuk Kwon's avatar
Woosuk Kwon committed
22
"""Inference-only BaiChuan model compatible with HuggingFace weights."""
23

24
import math
25
from collections.abc import Iterable
26
from itertools import islice
codethazine's avatar
codethazine committed
27
28

import torch
29
from torch import nn
30
from transformers import PretrainedConfig
codethazine's avatar
codethazine committed
31

32
from vllm.attention.layer import Attention
33
from vllm.compilation.decorators import support_torch_compile
34
from vllm.config import CacheConfig, VllmConfig
35
36
37
38
39
from vllm.distributed import (
    get_pp_group,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)
40
41
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
42
43
44
45
46
from vllm.model_executor.layers.linear import (
    MergedColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
47
from vllm.model_executor.layers.logits_processor import LogitsProcessor
48
from vllm.model_executor.layers.quantization import QuantizationConfig
49
from vllm.model_executor.layers.rotary_embedding import get_rope
50
from vllm.model_executor.layers.vocab_parallel_embedding import (
51
52
53
    ParallelLMHead,
    VocabParallelEmbedding,
)
54
from vllm.model_executor.model_loader.weight_utils import (
55
56
57
    default_weight_loader,
    row_parallel_weight_loader,
)
58
from vllm.sequence import IntermediateTensors
codethazine's avatar
codethazine committed
59

60
from .interfaces import SupportsLoRA, SupportsPP, SupportsQuant
61
62
63
64
65
66
67
from .utils import (
    AutoWeightsLoader,
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)
68

69
70

def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
    """Return the ALiBi slope for every attention head.

    Let ``p`` be the largest power of two not exceeding ``total_num_heads``.

    - The first ``p`` slopes are ``b**1, b**2, ..., b**p``, where
      ``b = 2 ** -(8 / p)``.
    - If the head count is not a power of two, the remaining heads take the
      odd powers ``b2**1, b2**3, ...`` of the base built for ``2 * p``.

    Returns:
        A 1-D float32 tensor of length ``total_num_heads``.
    """

    def _geometric_base(n: int) -> torch.Tensor:
        # Equals 2 ** -(8 / n), written in the same log2 form as the reference.
        return torch.tensor(
            2 ** (-(2 ** -(math.log2(n) - 3))),
            dtype=torch.float32,
        )

    pow2_heads = 2 ** math.floor(math.log2(total_num_heads))
    exponents = torch.arange(1, 1 + pow2_heads, dtype=torch.int32)
    result = torch.pow(_geometric_base(pow2_heads), exponents)

    if pow2_heads == total_num_heads:
        return result

    # Fill the leftover heads from the interleaved (odd-power) sequence.
    leftover = min(pow2_heads, total_num_heads - pow2_heads)
    odd_exponents = torch.arange(
        start=1, end=1 + 2 * leftover, step=2, dtype=torch.int32
    )
    extra = torch.pow(_geometric_base(2 * pow2_heads), odd_exponents)
    return torch.cat([result, extra], dim=0)


class BaiChuanMLP(nn.Module):
    """Feed-forward block with a fused gate/up projection.

    - The gate and up projections are computed by one merged column-parallel
      linear layer.
    - The two halves are combined by ``SiluAndMul``.
    - The result is projected back to ``hidden_size`` by a row-parallel
      ``down_proj``.

    Only ``"silu"`` is accepted as ``hidden_act``; anything else raises
    ``ValueError``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        # Both projections are bias-free and share the quantization config.
        shared_kwargs = {"bias": False, "quant_config": quant_config}
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size, intermediate_size],
            prefix=f"{prefix}.gate_up_proj",
            **shared_kwargs,
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            prefix=f"{prefix}.down_proj",
            **shared_kwargs,
        )
        if hidden_act == "silu":
            self.act_fn = SiluAndMul()
        else:
            raise ValueError(
                f"Unsupported activation: {hidden_act}. Only silu is supported for now."
            )

    def forward(self, x):
        """Apply gate/up projection, SiLU gating, then the down projection."""
        fused, _ = self.gate_up_proj(x)
        projected, _ = self.down_proj(self.act_fn(fused))
        return projected


class BaiChuanAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper.

    Q, K and V come from a single fused ``W_pack`` projection and are split
    into three equal chunks along the last dimension.

    Position information is injected in one of two ways:
    - per-head ALiBi slopes, when ``position_embedding == "ALIBI"``;
    - rotary embeddings applied to q and k, for any other value.

    Heads are sharded across tensor-parallel ranks, so ``num_heads`` must be
    divisible by the tensor-parallel world size.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        position_embedding: str,
        rope_parameters: dict | None,
        max_position_embeddings: int = 8192,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        """Build the projections and the attention backend.

        Args:
            hidden_size: Model width; ``head_dim = hidden_size // num_heads``.
            num_heads: Total number of attention heads across all TP ranks.
            position_embedding: ``"ALIBI"`` for ALiBi; anything else selects
                rotary embeddings.
            rope_parameters: Passed to ``get_rope``. Only used in the rotary
                branch, and may be ``None``.
            max_position_embeddings: Max position for the rotary embedding.
            cache_config: KV-cache configuration forwarded to ``Attention``.
            quant_config: Quantization config for linear and attention layers.
            prefix: Module name prefix used for submodule names.
        """
        super().__init__()
        self.hidden_size = hidden_size
        tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tensor_model_parallel_world_size == 0
        # Number of heads owned by this tensor-parallel rank.
        self.num_heads = self.total_num_heads // tensor_model_parallel_world_size
        self.head_dim = hidden_size // self.total_num_heads
        self.position_embedding = position_embedding
        self.max_position_embeddings = max_position_embeddings
        # Both branches use the same softmax scaling. Keep it on self so the
        # attribute exists whichever position embedding is selected.
        self.scaling = self.head_dim**-0.5

        # pylint: disable=invalid-name
        # Fused QKV projection; the attribute name matches the checkpoint key.
        self.W_pack = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.W_pack",
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        if self.position_embedding == "ALIBI":
            # Compute slopes for all heads, then keep this rank's contiguous slice.
            tp_rank = get_tensor_model_parallel_rank()
            head_start = tp_rank * self.num_heads
            head_end = (tp_rank + 1) * self.num_heads
            alibi_slopes = _get_alibi_slopes(self.total_num_heads)
            alibi_slopes = alibi_slopes[head_start:head_end].tolist()

            # Forward cache_config here just as the rotary branch does.
            # Omitting it would make ALiBi models silently ignore the engine's
            # KV-cache configuration.
            self.attn = Attention(
                self.num_heads,
                self.head_dim,
                self.scaling,
                alibi_slopes=alibi_slopes,
                cache_config=cache_config,
                quant_config=quant_config,
                prefix=f"{prefix}.attn",
            )
        else:
            self.rotary_emb = get_rope(
                self.head_dim,
                max_position=self.max_position_embeddings,
                rope_parameters=rope_parameters,
            )
            self.attn = Attention(
                self.num_heads,
                self.head_dim,
                self.scaling,
                cache_config=cache_config,
                quant_config=quant_config,
                prefix=f"{prefix}.attn",
            )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run self-attention over ``hidden_states``.

        ``positions`` is only read when rotary embeddings are in use.
        """
        qkv, _ = self.W_pack(hidden_states)
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        if self.position_embedding != "ALIBI":
            q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output


class BaiChuanDecoderLayer(nn.Module):
    """One transformer block: normed self-attention followed by a normed MLP.

    The residual is threaded through the RMSNorm calls. After the first
    layer, each norm is called as ``norm(hidden_states, residual)`` and
    returns ``(normed, residual)``.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        position_embedding: str,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        width = config.hidden_size
        eps = config.rms_norm_eps
        self.hidden_size = width

        self.self_attn = BaiChuanAttention(
            hidden_size=width,
            num_heads=config.num_attention_heads,
            position_embedding=position_embedding,
            rope_parameters=getattr(config, "rope_parameters", None),
            # Falls back to 8192 when the HF config lacks this field.
            max_position_embeddings=getattr(config, "max_position_embeddings", 8192),
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )
        self.mlp = BaiChuanMLP(
            hidden_size=width,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
        )
        self.input_layernorm = RMSNorm(width, eps=eps)
        self.post_attention_layernorm = RMSNorm(width, eps=eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(mlp_output, residual)`` for the next layer.

        ``residual`` is ``None`` for the first layer, in which case the
        incoming hidden states become the residual.
        """
        if residual is not None:
            attn_input, residual = self.input_layernorm(hidden_states, residual)
        else:
            residual = hidden_states
            attn_input = self.input_layernorm(hidden_states)

        attn_output = self.self_attn(
            positions=positions,
            hidden_states=attn_input,
        )

        mlp_input, residual = self.post_attention_layernorm(attn_output, residual)
        return self.mlp(mlp_input), residual


276
@support_torch_compile
class BaiChuanModel(nn.Module):
    """Baichuan decoder stack: embedding, decoder layers, and a final RMSNorm.

    With pipeline parallelism, each rank runs only layers
    ``[start_layer, end_layer)``. Non-final ranks hand
    ``hidden_states``/``residual`` to the next rank as ``IntermediateTensors``.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
        position_embedding: str = "ROPE",
    ) -> None:
        super().__init__()

        hf_config = vllm_config.model_config.hf_config
        cache_cfg = vllm_config.cache_config
        quant_cfg = vllm_config.quant_config

        self.config = hf_config
        self.vocab_size = hf_config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            hf_config.vocab_size,
            hf_config.hidden_size,
        )

        # make_layers invokes the factory with a `prefix=` keyword argument.
        def _build_layer(prefix: str) -> BaiChuanDecoderLayer:
            return BaiChuanDecoderLayer(
                hf_config, position_embedding, cache_cfg, quant_cfg, prefix=prefix
            )

        self.start_layer, self.end_layer, self.layers = make_layers(
            hf_config.num_hidden_layers,
            _build_layer,
            prefix=f"{prefix}.layers",
        )
        self.norm = RMSNorm(hf_config.hidden_size, eps=hf_config.rms_norm_eps)
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], hf_config.hidden_size
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run this rank's slice of the decoder.

        Returns the final normed hidden states on the last pipeline rank.
        Otherwise returns ``IntermediateTensors`` for the next rank.
        """
        pp_group = get_pp_group()

        if not pp_group.is_first_rank:
            # Later pipeline stages resume from the previous stage's outputs.
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        else:
            hidden_states = (
                inputs_embeds
                if inputs_embeds is not None
                else self.embed_input_ids(input_ids)
            )
            residual = None

        for layer in islice(self.layers, self.start_layer, self.end_layer):
            hidden_states, residual = layer(positions, hidden_states, residual)

        if not pp_group.is_last_rank:
            return IntermediateTensors(
                {"hidden_states": hidden_states, "residual": residual}
            )

        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors and return the parameter names that were loaded.

        Separate ``gate_proj``/``up_proj`` checkpoint tensors are loaded into
        shards 0/1 of the merged ``gate_up_proj`` parameter. The following are
        skipped:
        - rotary ``inv_freq`` buffers;
        - biases that have no matching parameter (e.g. extra GPTQ biases);
        - parameters that live on another pipeline rank.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue

            # Map a per-shard checkpoint name onto its merged parameter, if any.
            stacked_shard = None
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name in name:
                    name = name.replace(weight_name, param_name)
                    stacked_shard = shard_id
                    break

            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            if is_pp_missing_parameter(name, self):
                continue

            param = params_dict[name]
            if stacked_shard is None:
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            else:
                param.weight_loader(param, loaded_weight, stacked_shard)
            loaded_params.add(name)

        return loaded_params

382

383
class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsQuant):
    """Shared causal-LM wrapper for the Baichuan variants.

    It combines a ``BaiChuanModel`` decoder, a ``ParallelLMHead`` and a
    ``LogitsProcessor``. Subclasses only choose the position-embedding type.
    """

    # Maps each fused module to the checkpoint sub-modules packed into it.
    packed_modules_mapping = {
        "W_pack": ["W_pack"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        position_embedding: str = "ROPE",
    ):
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        quant_cfg = vllm_config.quant_config

        self.config = hf_config
        self.tp_size = get_tensor_model_parallel_world_size()
        self.quant_config = quant_cfg

        self.model = BaiChuanModel(
            vllm_config=vllm_config,
            prefix=prefix,
            position_embedding=position_embedding,
        )
        self.lm_head = ParallelLMHead(
            hf_config.vocab_size,
            hf_config.hidden_size,
            quant_config=quant_cfg,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        # Route lm_head loading through the Baichuan2-aware loader below.
        self.lm_head.weight.weight_loader = self.lm_head_weight_loader
        if self.config.tie_word_embeddings:
            self.lm_head.weight = self.model.embed_tokens.weight
        self.logits_processor = LogitsProcessor(hf_config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate token-embedding lookup to the decoder."""
        return self.model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the decoder and return its output unchanged."""
        return self.model(input_ids, positions, intermediate_tensors, inputs_embeds)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Project hidden states to vocabulary logits via the LM head."""
        return self.logits_processor(self.lm_head, hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load all checkpoint weights through ``AutoWeightsLoader``."""
        return AutoWeightsLoader(self).load_weights(weights)

    def lm_head_weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
        """Load the LM-head weight, normalizing it first for Baichuan2.

        Unlike Baichuan, Baichuan2 normalizes the head weights. Refer to:
        https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat/blob/84603cde5ebffb6084e476cfaeceaf0b8b91fe54/modeling_baichuan.py#L508

        The two generations are told apart by vocab size, as suggested in
        https://github.com/vllm-project/vllm/pull/1022#discussion_r1325652704
        """
        if self.config.vocab_size == 125696:  # Baichuan2
            loaded_weight = torch.nn.functional.normalize(loaded_weight)

        # Under tensor parallelism, each rank takes its row slice of the weight.
        loader = row_parallel_weight_loader if self.tp_size > 1 else default_weight_loader
        loader(param, loaded_weight)
466
467


468
class BaichuanForCausalLM(BaiChuanBaseForCausalLM):
    """Baichuan 13B and Baichuan2 7B/13B.
    NOTE: the class name has a lower case 'c'.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        # A hidden size of 4096 identifies Baichuan2-7B, which uses rotary
        # embeddings. Baichuan-13B and Baichuan2-13B use ALiBi.
        hidden_size = vllm_config.model_config.hf_config.hidden_size
        position_embedding = "ROPE" if hidden_size == 4096 else "ALIBI"
        super().__init__(
            vllm_config=vllm_config,
            prefix=prefix,
            position_embedding=position_embedding,
        )
483
484


485
class BaiChuanForCausalLM(BaiChuanBaseForCausalLM):
    """Baichuan 7B, which always uses rotary position embeddings.
    NOTE: the class name has an upper case 'C'.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        rotary = "ROPE"
        super().__init__(
            position_embedding=rotary,
            prefix=prefix,
            vllm_config=vllm_config,
        )