# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only BaiChuan model compatible with HuggingFace weights."""
import math
from collections.abc import Iterable
from itertools import islice
from typing import Optional, Union

import torch
from torch import nn
from transformers import PretrainedConfig

from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader, row_parallel_weight_loader)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

from .interfaces import SupportsLoRA, SupportsPP, SupportsQuant
from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)


def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
    closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
    base = torch.tensor(
        2**(-(2**-(math.log2(closest_power_of_2) - 3))),
        dtype=torch.float32,
    )
    powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
    slopes = torch.pow(base, powers)

    if closest_power_of_2 != total_num_heads:
        extra_base = torch.tensor(
            2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
            dtype=torch.float32,
        )
        num_remaining_heads = min(closest_power_of_2,
                                  total_num_heads - closest_power_of_2)
        extra_powers = torch.arange(start=1,
                                    end=1 + 2 * num_remaining_heads,
                                    step=2,
                                    dtype=torch.int32)
        slopes = torch.cat(
            [slopes, torch.pow(extra_base, extra_powers)], dim=0)
    return slopes


class BaiChuanMLP(nn.Module):
    """Gated feed-forward block used by every Baichuan decoder layer.

    Applies a merged gate/up projection, ``SiluAndMul``, then a down
    projection back to ``hidden_size``. Only the ``"silu"`` activation is
    accepted.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        # Gate and up projections share a single merged linear layer with
        # two output shards of intermediate_size each.
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size, intermediate_size],
            bias=False,
            quant_config=quant_config,
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
        )
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()

    def forward(self, x):
        # The parallel linear layers return (output, bias); keep the output.
        gate_up = self.gate_up_proj(x)[0]
        activated = self.act_fn(gate_up)
        projected = self.down_proj(activated)[0]
        return projected


class BaiChuanAttention(nn.Module):
    """Multi-headed self-attention for Baichuan models.

    Uses a fused QKV projection (``W_pack``) and an output projection
    (``o_proj``). When ``position_embedding == "ALIBI"`` no rotary embedding
    is created. Instead, the ALiBi slopes for the heads owned by this
    tensor-parallel rank are passed to the attention layer. Otherwise, a
    rotary embedding is applied to q and k in ``forward``.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        position_embedding: str,
        rope_theta: float = 10000,
        max_position_embeddings: int = 8192,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.hidden_size = hidden_size
        tensor_model_parallel_world_size = get_tensor_model_parallel_world_size(
        )
        self.total_num_heads = num_heads
        assert self.total_num_heads % tensor_model_parallel_world_size == 0
        self.num_heads = (self.total_num_heads //
                          tensor_model_parallel_world_size)
        self.head_dim = hidden_size // self.total_num_heads
        self.position_embedding = position_embedding
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings
        # Both branches use the same softmax scale; it is stored on self
        # for either position-embedding kind (previously only RoPE did).
        self.scaling = self.head_dim**-0.5

        # pylint: disable=invalid-name
        # Built with as many KV heads as query heads, so its output splits
        # into three equally sized q/k/v chunks in forward().
        self.W_pack = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_heads,
            bias=False,
            quant_config=quant_config,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
        )
        if self.position_embedding == "ALIBI":
            # Slice the global slope table down to this rank's heads.
            tp_rank = get_tensor_model_parallel_rank()
            head_start = tp_rank * self.num_heads
            head_end = (tp_rank + 1) * self.num_heads
            alibi_slopes = _get_alibi_slopes(self.total_num_heads)
            alibi_slopes = alibi_slopes[head_start:head_end].tolist()

            # cache_config must be forwarded here as well. Without it the
            # ALiBi models (the 13B variants) did not get the engine's
            # cache settings, e.g. the KV-cache dtype, unlike the RoPE path.
            self.attn = Attention(self.num_heads,
                                  self.head_dim,
                                  self.scaling,
                                  alibi_slopes=alibi_slopes,
                                  cache_config=cache_config,
                                  quant_config=quant_config,
                                  prefix=f"{prefix}.attn")
        else:
            self.rotary_emb = get_rope(
                self.head_dim,
                rotary_dim=self.head_dim,
                max_position=self.max_position_embeddings,
                base=self.rope_theta,
            )
            self.attn = Attention(self.num_heads,
                                  self.head_dim,
                                  self.scaling,
                                  cache_config=cache_config,
                                  quant_config=quant_config,
                                  prefix=f"{prefix}.attn")

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Project to q/k/v, apply RoPE unless using ALiBi, attend, project out."""
        qkv, _ = self.W_pack(hidden_states)
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        if self.position_embedding != "ALIBI":
            q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output


class BaiChuanDecoderLayer(nn.Module):
    """One Baichuan transformer block.

    The block applies RMSNorm, then self-attention, then RMSNorm, then the
    gated MLP. The residual stream is carried between blocks as a separate
    tensor. The two-argument RMSNorm calls return both the normalized
    activations and the updated residual.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 position_embedding: str,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()
        width = config.hidden_size
        self.hidden_size = width
        self.self_attn = BaiChuanAttention(
            hidden_size=width,
            num_heads=config.num_attention_heads,
            position_embedding=position_embedding,
            # Fall back to the defaults when the HF config lacks these keys.
            rope_theta=getattr(config, "rope_theta", 10000),
            max_position_embeddings=getattr(config, "max_position_embeddings",
                                            8192),
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )
        self.mlp = BaiChuanMLP(
            hidden_size=width,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
        )
        norm_eps = config.rms_norm_eps
        self.input_layernorm = RMSNorm(width, eps=norm_eps)
        self.post_attention_layernorm = RMSNorm(width, eps=norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # Attention sub-block. The first layer has no residual yet, so the
        # raw input becomes the residual stream.
        if residual is not None:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        else:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        attn_out = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
        )

        # Feed-forward sub-block.
        normed, residual = self.post_attention_layernorm(attn_out, residual)
        return self.mlp(normed), residual


@support_torch_compile
class BaiChuanModel(nn.Module):
    """Baichuan transformer trunk.

    Consists of the token embedding, the decoder layers in
    ``[start_layer, end_layer)`` (this pipeline stage's share), and a final
    RMSNorm. Non-final pipeline stages return ``IntermediateTensors``
    instead of normalized hidden states.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
        position_embedding: str = "ROPE",
    ) -> None:
        super().__init__()
        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.vocab_size = config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )

        # make_layers invokes the factory with a ``prefix=`` keyword.
        def layer_factory(prefix: str) -> nn.Module:
            return BaiChuanDecoderLayer(config,
                                        position_embedding,
                                        cache_config,
                                        quant_config,
                                        prefix=prefix)

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            layer_factory,
            prefix=f"{prefix}.layers",
        )
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        if get_pp_group().is_first_rank:
            # Precomputed embeddings take priority over token ids.
            if inputs_embeds is None:
                hidden_states = self.get_input_embeddings(input_ids)
            else:
                hidden_states = inputs_embeds
            residual = None
        else:
            # Later pipeline stages resume from the previous stage's output.
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        for layer in islice(self.layers, self.start_layer, self.end_layer):
            hidden_states, residual = layer(positions, hidden_states,
                                            residual)

        if get_pp_group().is_last_rank:
            hidden_states, _ = self.norm(hidden_states, residual)
            return hidden_states
        return IntermediateTensors({
            "hidden_states": hidden_states,
            "residual": residual,
        })

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Copy checkpoint tensors into this module's parameters.

        Checkpoint ``gate_proj`` / ``up_proj`` tensors are written as shards
        0 / 1 of the fused ``gate_up_proj``. The following are skipped:
        rotary ``inv_freq`` buffers, ``.bias`` tensors with no matching
        parameter (extra GPTQ biases), and parameters owned by another
        pipeline stage.

        Returns:
            The set of (possibly renamed) parameter names that were loaded.
        """
        # Each entry is (fused parameter name, checkpoint shard name, shard id).
        stacked_params_mapping = [
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded: set[str] = set()
        for name, tensor in weights:
            if "rotary_emb.inv_freq" in name:
                continue

            for fused_name, shard_name, shard_id in stacked_params_mapping:
                if shard_name not in name:
                    continue
                name = name.replace(shard_name, fused_name)
                # Extra GPTQ bias, or not on this pipeline stage.
                if ((name.endswith(".bias") and name not in params_dict)
                        or is_pp_missing_parameter(name, self)):
                    continue
                param = params_dict[name]
                param.weight_loader(param, tensor, shard_id)
                break
            else:
                # Extra GPTQ bias, or not on this pipeline stage.
                if ((name.endswith(".bias") and name not in params_dict)
                        or is_pp_missing_parameter(name, self)):
                    continue
                param = params_dict[name]
                loader = getattr(param, "weight_loader",
                                 default_weight_loader)
                loader(param, tensor)
            loaded.add(name)
        return loaded

class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
                              SupportsQuant):
    """Causal-LM wrapper shared by the Baichuan model classes.

    Combines ``BaiChuanModel`` with a parallel LM head and a logits
    processor. Subclasses choose ``position_embedding`` ("ROPE" or "ALIBI").
    """

    packed_modules_mapping = {
        "W_pack": ["W_pack"],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        position_embedding: str = "ROPE",
    ):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config
        self.config = config
        self.lora_config = lora_config
        self.tp_size = get_tensor_model_parallel_world_size()
        self.quant_config = quant_config

        self.model = BaiChuanModel(vllm_config=vllm_config,
                                   prefix=prefix,
                                   position_embedding=position_embedding)
        self.lm_head = ParallelLMHead(config.vocab_size,
                                      config.hidden_size,
                                      quant_config=quant_config,
                                      prefix=maybe_prefix(prefix, "lm_head"))
        # Route checkpoint lm_head weights through lm_head_weight_loader so
        # the Baichuan2 head can be normalized before it is stored.
        self.lm_head.weight.weight_loader = self.lm_head_weight_loader
        if self.config.tie_word_embeddings:
            self.lm_head.weight = self.model.embed_tokens.weight
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        return self.model(input_ids, positions, intermediate_tensors,
                          inputs_embeds)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        return self.logits_processor(self.lm_head, hidden_states,
                                     sampling_metadata)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)

    def lm_head_weight_loader(self, param: nn.Parameter,
                              loaded_weight: torch.Tensor):
        """Load the checkpoint LM-head weight into ``param``.

        Baichuan2 checkpoints (detected by vocab size) are first passed
        through ``torch.nn.functional.normalize``, which computes an L2
        norm over dim 1 by default. The result is then handed to the
        head's own ``weight_loader``.
        """
        # Unlike Baichuan, Baichuan2 normalizes the head weights.
        # Refer to:
        # https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat/blob/84603cde5ebffb6084e476cfaeceaf0b8b91fe54/modeling_baichuan.py#L508
        # Distinguish between Baichuan and Baichuan2 by checking the
        # vocab size. This is suggested by
        # https://github.com/vllm-project/vllm/pull/1022#discussion_r1325652704
        is_baichuan2 = self.config.vocab_size == 125696
        if is_baichuan2:
            loaded_weight = torch.nn.functional.normalize(loaded_weight)
        # Delegate to ParallelLMHead's own loader rather than choosing
        # row_parallel_weight_loader / default_weight_loader by tp_size.
        # The head's loader knows this rank's vocab shard bounds and
        # padding, whereas the previous path sliced by the padded
        # per-rank parameter shape.
        self.lm_head.weight_loader(param, loaded_weight)


class BaichuanForCausalLM(BaiChuanBaseForCausalLM):
    """Baichuan 13B and Baichuan2 7B/13B.

    NOTE: the class name has a lower case 'c'. The position-embedding
    scheme is selected from the config's hidden size.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        hidden_size = vllm_config.model_config.hf_config.hidden_size
        # hidden_size 4096 -> Baichuan2 7B (rotary);
        # otherwise Baichuan 13B / Baichuan2 13B (ALiBi).
        embedding_kind = "ROPE" if hidden_size == 4096 else "ALIBI"
        super().__init__(vllm_config=vllm_config,
                         prefix=prefix,
                         position_embedding=embedding_kind)


class BaiChuanForCausalLM(BaiChuanBaseForCausalLM):
    """Baichuan 7B, which always uses rotary position embeddings.

    NOTE: the class name has an upper case 'C'.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(
            position_embedding="ROPE",
            prefix=prefix,
            vllm_config=vllm_config,
        )