# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Woosuk Kwon's avatar
Woosuk Kwon committed
22
"""Inference-only BaiChuan model compatible with HuggingFace weights."""
23
import math
24
25
from collections.abc import Iterable
from typing import Optional, Union
codethazine's avatar
codethazine committed
26
27

import torch
28
from torch import nn
29
from transformers import PretrainedConfig
codethazine's avatar
codethazine committed
30

31
from vllm.attention import Attention
32
from vllm.compilation.decorators import support_torch_compile
33
from vllm.config import CacheConfig, VllmConfig
34
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
35
                              get_tensor_model_parallel_world_size)
36
37
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
38
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
39
40
                                               QKVParallelLinear,
                                               RowParallelLinear)
41
from vllm.model_executor.layers.logits_processor import LogitsProcessor
42
from vllm.model_executor.layers.quantization import QuantizationConfig
43
from vllm.model_executor.layers.rotary_embedding import get_rope
44
from vllm.model_executor.layers.vocab_parallel_embedding import (
45
    ParallelLMHead, VocabParallelEmbedding)
46
47
from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader, row_parallel_weight_loader)
48
from vllm.model_executor.sampling_metadata import SamplingMetadata
49
from vllm.sequence import IntermediateTensors
codethazine's avatar
codethazine committed
50

51
from .interfaces import SupportsLoRA, SupportsPP, SupportsQuant
52
from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
53
                    make_empty_intermediate_tensors_factory, make_layers)
54

55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87

def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
    closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
    base = torch.tensor(
        2**(-(2**-(math.log2(closest_power_of_2) - 3))),
        dtype=torch.float32,
    )
    powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
    slopes = torch.pow(base, powers)

    if closest_power_of_2 != total_num_heads:
        extra_base = torch.tensor(
            2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
            dtype=torch.float32,
        )
        num_remaining_heads = min(closest_power_of_2,
                                  total_num_heads - closest_power_of_2)
        extra_powers = torch.arange(start=1,
                                    end=1 + 2 * num_remaining_heads,
                                    step=2,
                                    dtype=torch.int32)
        slopes = torch.cat(
            [slopes, torch.pow(extra_base, extra_powers)], dim=0)
    return slopes


class BaiChuanMLP(nn.Module):
    """Feed-forward sub-block of a BaiChuan decoder layer.

    A fused gate/up column-parallel projection (two output slices of width
    ``intermediate_size``) feeds ``SiluAndMul``; the result is projected
    back to ``hidden_size`` by a row-parallel layer.

    Raises:
        ValueError: if ``hidden_act`` is anything other than ``"silu"``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        gate_and_up_widths = [intermediate_size, intermediate_size]
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            gate_and_up_widths,
            bias=False,
            quant_config=quant_config,
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
        )
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()

    def forward(self, x):
        # The parallel linear layers return (output, bias); only the output
        # is used.
        fused_gate_up = self.gate_up_proj(x)[0]
        activated = self.act_fn(fused_gate_up)
        return self.down_proj(activated)[0]


class BaiChuanAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper.

    Q, K and V come from one packed ``W_pack`` projection built with equal
    numbers of query and key/value heads. Position information is added
    either with rotary embeddings (any ``position_embedding`` other than
    ``"ALIBI"``) or with per-head ALiBi slopes handed to ``Attention``
    (``position_embedding == "ALIBI"``).
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        position_embedding: str,
        rope_theta: float = 10000,
        max_position_embeddings: int = 8192,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.hidden_size = hidden_size
        tensor_model_parallel_world_size = get_tensor_model_parallel_world_size(
        )
        self.total_num_heads = num_heads
        assert self.total_num_heads % tensor_model_parallel_world_size == 0
        # Number of heads owned by this tensor-parallel rank.
        self.num_heads = (self.total_num_heads //
                          tensor_model_parallel_world_size)
        self.head_dim = hidden_size // self.total_num_heads
        self.position_embedding = position_embedding
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings
        self.scaling = self.head_dim**-0.5

        # pylint: disable=invalid-name
        self.W_pack = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.W_pack",
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        alibi_slopes: Optional[list[float]] = None
        if self.position_embedding == "ALIBI":
            # Slopes are computed for all heads, then this rank keeps the
            # contiguous slice matching the heads it owns.
            tp_rank = get_tensor_model_parallel_rank()
            head_start = tp_rank * self.num_heads
            head_end = head_start + self.num_heads
            alibi_slopes = _get_alibi_slopes(
                self.total_num_heads)[head_start:head_end].tolist()
        else:
            self.rotary_emb = get_rope(
                self.head_dim,
                rotary_dim=self.head_dim,
                max_position=self.max_position_embeddings,
                base=self.rope_theta,
            )

        # cache_config is forwarded in both modes: previously the ALiBi path
        # dropped it, so ALiBi layers did not see the cache settings the
        # RoPE path receives. alibi_slopes stays None for RoPE.
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              alibi_slopes=alibi_slopes,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn")

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Project to Q/K/V, apply position encoding, attend, project out."""
        qkv, _ = self.W_pack(hidden_states)
        # Q, K and V have the same head count, so an even 3-way split of the
        # last dimension separates them.
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        if self.position_embedding != "ALIBI":
            q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output


class BaiChuanDecoderLayer(nn.Module):
    """One BaiChuan transformer block: pre-norm self-attention, then MLP.

    An RMSNorm precedes both the attention and the MLP sub-blocks. The
    running residual is threaded through the two-argument RMSNorm calls and
    returned alongside the layer output.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 position_embedding: str,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()
        hidden_size = config.hidden_size
        self.hidden_size = hidden_size
        self.self_attn = BaiChuanAttention(
            hidden_size=hidden_size,
            num_heads=config.num_attention_heads,
            position_embedding=position_embedding,
            # Defaults apply when the HF config lacks these attributes.
            rope_theta=getattr(config, "rope_theta", 10000),
            max_position_embeddings=getattr(config,
                                            "max_position_embeddings", 8192),
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )
        self.mlp = BaiChuanMLP(
            hidden_size=hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
        )
        norm_eps = config.rms_norm_eps
        self.input_layernorm = RMSNorm(hidden_size, eps=norm_eps)
        self.post_attention_layernorm = RMSNorm(hidden_size, eps=norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # Self attention. The first layer has no residual yet, so it starts
        # the residual stream from its own input.
        if residual is not None:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        else:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
        )

        # Fully connected.
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        return self.mlp(hidden_states), residual


@support_torch_compile
class BaiChuanModel(nn.Module):
    """BaiChuan decoder stack: token embedding, decoder layers, final norm.

    Under pipeline parallelism this rank runs only the layers in
    ``[start_layer, end_layer)``. A non-first rank reads its input from
    ``intermediate_tensors``; a non-last rank returns ``hidden_states`` and
    ``residual`` as ``IntermediateTensors`` instead of normed outputs.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
        position_embedding: str = "ROPE",
    ) -> None:
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        # start_layer/end_layer bound the slice of self.layers this rank runs
        # (see forward).
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: BaiChuanDecoderLayer(config,
                                                position_embedding,
                                                cache_config,
                                                quant_config,
                                                prefix=prefix),
            prefix=f"{prefix}.layers",
        )
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this pipeline stage's slice of the decoder stack.

        Returns the final-normed hidden states on the last pipeline rank,
        otherwise an ``IntermediateTensors`` with ``hidden_states`` and
        ``residual`` for the next stage.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        for layer in self.layers[self.start_layer:self.end_layer]:
            hidden_states, residual = layer(
                positions,
                hidden_states,
                residual,
            )
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual,
            })
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Copy checkpoint tensors into this module's parameters.

        Separate ``gate_proj`` / ``up_proj`` tensors are written into shard
        0 / 1 of the fused ``gate_up_proj`` parameter; every other tensor is
        loaded under its own name. ``rotary_emb.inv_freq`` buffers, GPTQ bias
        tensors with no matching parameter, and parameters owned by another
        pipeline rank are skipped.

        Returns:
            The names of the parameters that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue

            # Rewrite a shard name to its fused parameter at most once:
            # "up_proj" is a substring of "gate_up_proj", so re-scanning an
            # already rewritten name would corrupt it.
            shard_id: Optional[int] = None
            for fused_name, shard_name, idx in stacked_params_mapping:
                if shard_name in name:
                    name = name.replace(shard_name, fused_name)
                    shard_id = idx
                    break

            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            # Skip parameters that live on another pipeline-parallel rank.
            if is_pp_missing_parameter(name, self):
                continue

            param = params_dict[name]
            if shard_id is None:
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            else:
                param.weight_loader(param, loaded_weight, shard_id)
            loaded_params.add(name)
        return loaded_params

class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
                              SupportsQuant):
    """Causal-LM wrapper shared by the Baichuan and Baichuan2 variants.

    Combines ``BaiChuanModel`` with a ``ParallelLMHead`` and a
    ``LogitsProcessor``. Subclasses choose the ``position_embedding`` mode
    ("ROPE" or "ALIBI") that is forwarded to every attention layer.
    """

    packed_modules_mapping = {
        "W_pack": ["W_pack"],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    # Vocabulary size used to tell Baichuan2 checkpoints apart from
    # Baichuan ones (see lm_head_weight_loader).
    _BAICHUAN2_VOCAB_SIZE = 125696

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        position_embedding: str = "ROPE",
    ):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config
        self.config = config
        self.lora_config = lora_config
        self.tp_size = get_tensor_model_parallel_world_size()
        self.quant_config = quant_config

        # Prefix submodules with their attribute names so the layer names
        # seen by quantization configs and attention layers follow the
        # parameter tree (e.g. everything under "model.layers").
        self.model = BaiChuanModel(vllm_config=vllm_config,
                                   prefix=self._join_prefix(prefix, "model"),
                                   position_embedding=position_embedding)
        self.lm_head = ParallelLMHead(config.vocab_size,
                                      config.hidden_size,
                                      quant_config=quant_config,
                                      prefix=self._join_prefix(
                                          prefix, "lm_head"))
        # Load lm_head through the Baichuan2-aware loader defined below.
        self.lm_head.weight.weight_loader = self.lm_head_weight_loader
        if self.config.tie_word_embeddings:
            self.lm_head.weight = self.model.embed_tokens.weight
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    @staticmethod
    def _join_prefix(prefix: str, name: str) -> str:
        """Return ``name`` qualified by ``prefix`` (just ``name`` if empty)."""
        return f"{prefix}.{name}" if prefix else name

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                   inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)

    def lm_head_weight_loader(self, param: nn.Parameter,
                              loaded_weight: torch.Tensor):
        """Load the LM head, normalizing it first for Baichuan2 checkpoints.

        With tensor parallelism the (possibly normalized) full weight is
        handed to ``row_parallel_weight_loader``; otherwise it is copied with
        ``default_weight_loader``.
        """
        # Unlike Baichuan, Baichuan2 normalizes the head weights.
        # Refer to:
        # https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat/blob/84603cde5ebffb6084e476cfaeceaf0b8b91fe54/modeling_baichuan.py#L508
        # Distinguish between Baichuan and Baichuan2 by checking the
        # vocab size. This is suggested by
        # https://github.com/vllm-project/vllm/pull/1022#discussion_r1325652704
        is_baichuan2 = self.config.vocab_size == self._BAICHUAN2_VOCAB_SIZE
        if is_baichuan2:
            # Normalize on the full weight, before any TP sharding.
            loaded_weight = torch.nn.functional.normalize(loaded_weight)
        if self.tp_size > 1:
            row_parallel_weight_loader(param, loaded_weight)
        else:
            default_weight_loader(param, loaded_weight)


class BaichuanForCausalLM(BaiChuanBaseForCausalLM):
    """Baichuan 13B and Baichuan2 7B/13B.

    NOTE: the class name has a lower case 'c'. The position-embedding mode
    is chosen from the config's hidden size.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        hidden_size = vllm_config.model_config.hf_config.hidden_size
        # hidden_size 4096 identifies Baichuan2-7B, which uses rotary
        # embeddings; Baichuan-13B and Baichuan2-13B use ALiBi.
        position_embedding = "ROPE" if hidden_size == 4096 else "ALIBI"
        super().__init__(vllm_config=vllm_config,
                         prefix=prefix,
                         position_embedding=position_embedding)


class BaiChuanForCausalLM(BaiChuanBaseForCausalLM):
    """Baichuan 7B, which always uses rotary position embeddings.

    NOTE: the class name has an upper case 'C'.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(
            vllm_config=vllm_config,
            prefix=prefix,
            position_embedding="ROPE",
        )