baichuan.py 16 KB
Newer Older
codethazine's avatar
codethazine committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Woosuk Kwon's avatar
Woosuk Kwon committed
20
"""Inference-only BaiChuan model compatible with HuggingFace weights."""
21
import math
22
from typing import Iterable, List, Optional, Tuple
codethazine's avatar
codethazine committed
23
24

import torch
25
from torch import nn
26
from transformers import PretrainedConfig
codethazine's avatar
codethazine committed
27

28
from vllm.attention import Attention, AttentionMetadata
29
from vllm.config import CacheConfig, LoRAConfig
30
31
from vllm.distributed import (get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size)
32
33
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
34
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
35
36
                                               QKVParallelLinear,
                                               RowParallelLinear)
37
from vllm.model_executor.layers.logits_processor import LogitsProcessor
38
39
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
40
from vllm.model_executor.layers.rotary_embedding import get_rope
41
42
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
43
    ParallelLMHead, VocabParallelEmbedding)
44
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
45
46
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput
codethazine's avatar
codethazine committed
47

48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80

def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
    """Return the per-head ALiBi slopes for ``total_num_heads`` heads.

    Let ``p`` be the largest power of two not exceeding the head count.
    The first ``p`` slopes are ``r, r**2, ..., r**p`` with
    ``r = 2 ** (-8 / p)``. When the head count is not a power of two, the
    remaining ``total_num_heads - p`` heads take the odd powers
    ``1, 3, 5, ...`` of ``2 ** (-8 / (2 * p))``.

    Args:
        total_num_heads: Number of attention heads across all TP ranks.

    Returns:
        A 1-D float32 tensor with one slope per head.
    """
    pow2_heads = 2**math.floor(math.log2(total_num_heads))

    ratio = torch.tensor(
        2**(-(2**-(math.log2(pow2_heads) - 3))),
        dtype=torch.float32,
    )
    exponents = torch.arange(1, 1 + pow2_heads, dtype=torch.int32)
    pieces = [torch.pow(ratio, exponents)]

    leftover = total_num_heads - pow2_heads
    if leftover:
        # Odd powers of the 2p-base put these slopes' exponents halfway
        # between those of consecutive slopes above.
        half_step_ratio = torch.tensor(
            2**(-(2**-(math.log2(2 * pow2_heads) - 3))),
            dtype=torch.float32,
        )
        odd_exponents = torch.arange(start=1,
                                     end=1 + 2 * min(pow2_heads, leftover),
                                     step=2,
                                     dtype=torch.int32)
        pieces.append(torch.pow(half_step_ratio, odd_exponents))

    return torch.cat(pieces, dim=0)


class BaiChuanMLP(nn.Module):
    """Gated feed-forward block used by every BaiChuan decoder layer.

    The gate and up projections are fused into a single
    ``MergedColumnParallelLinear`` with two ``intermediate_size`` outputs.
    ``SiluAndMul`` combines the two halves (presumably ``silu(gate) * up``;
    see its definition), and ``down_proj`` maps the result back to
    ``hidden_size``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        """Build the fused gate/up projection, activation and down projection.

        Args:
            hidden_size: Width of the block's input and output.
            intermediate_size: Width of each of the gate/up projections.
            hidden_act: Activation name; only ``"silu"`` is supported.
            quant_config: Optional quantization config forwarded to both
                linear layers.

        Raises:
            ValueError: If ``hidden_act`` is not ``"silu"``.
        """
        super().__init__()
        # Validate the config before allocating the (possibly large,
        # possibly quantized) parallel linear layers, so that a bad config
        # fails fast without building any weights.
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size, [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config)
        self.down_proj = RowParallelLinear(intermediate_size,
                                           hidden_size,
                                           bias=False,
                                           quant_config=quant_config)
        self.act_fn = SiluAndMul()

    def forward(self, x):
        """Apply gate/up projection, activation, then down projection."""
        # The parallel linear layers return (output, bias); the bias is
        # unused because both layers are built with bias=False.
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


class BaiChuanAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper.

    ``position_embedding`` selects BaiChuan's positional scheme:

    * ``"ALIBI"``: per-head ALiBi slopes are handed to the attention
      backend.
    * Any other value: rotary embeddings are applied to q and k before
      attention.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        position_embedding: str,
        rope_theta: float = 10000,
        max_position_embeddings: int = 8192,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        """Build the fused QKV projection, output projection and attention.

        Args:
            hidden_size: Model width; split evenly across ``num_heads``.
            num_heads: Total number of attention heads (all TP ranks).
            position_embedding: ``"ALIBI"`` or a rotary scheme (e.g.
                ``"ROPE"``).
            rope_theta: Rotary base; only used for the rotary scheme.
            max_position_embeddings: Rotary table size; only used for the
                rotary scheme.
            cache_config: Cache configuration forwarded to ``Attention``.
            quant_config: Quantization config forwarded to the linear
                layers and ``Attention``.
        """
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0, (
            f"num_heads ({self.total_num_heads}) must be divisible by the "
            f"tensor parallel world size ({tp_size})")
        self.num_heads = self.total_num_heads // tp_size
        self.head_dim = hidden_size // self.total_num_heads
        self.position_embedding = position_embedding
        # Misspelled alias kept for backward compatibility with any code
        # that still reads the old attribute name.
        self.postion_embedding = position_embedding
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings
        self.scaling = self.head_dim**-0.5

        # pylint: disable=invalid-name
        # Built with as many KV heads as query heads, so the packed output
        # splits into three equal q/k/v chunks in forward().
        self.W_pack = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_heads,
            bias=False,
            quant_config=quant_config,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
        )

        if self.position_embedding == "ALIBI":
            # Slopes are computed for all heads and sliced down to this
            # rank's contiguous range of heads.
            tp_rank = get_tensor_model_parallel_rank()
            head_start = tp_rank * self.num_heads
            head_end = head_start + self.num_heads
            alibi_slopes = _get_alibi_slopes(self.total_num_heads)
            alibi_slopes = alibi_slopes[head_start:head_end].tolist()
            # cache_config must be forwarded here as well; previously it was
            # dropped on the ALiBi path only, so cache settings were ignored
            # for ALiBi models.
            self.attn = Attention(self.num_heads,
                                  self.head_dim,
                                  self.scaling,
                                  alibi_slopes=alibi_slopes,
                                  cache_config=cache_config,
                                  quant_config=quant_config)
        else:
            self.rotary_emb = get_rope(
                self.head_dim,
                rotary_dim=self.head_dim,
                max_position=self.max_position_embeddings,
                base=self.rope_theta,
            )
            self.attn = Attention(self.num_heads,
                                  self.head_dim,
                                  self.scaling,
                                  cache_config=cache_config,
                                  quant_config=quant_config)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Project to q/k/v, apply positional encoding, attend, project out."""
        qkv, _ = self.W_pack(hidden_states)
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        # With ALiBi, position information is injected via the slopes inside
        # the attention backend instead of rotating q and k.
        if self.position_embedding != "ALIBI":
            q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.o_proj(attn_output)
        return output


class BaiChuanDecoderLayer(nn.Module):
    """One pre-norm transformer block: norm -> attention -> norm -> MLP.

    The residual stream is threaded through the RMSNorm calls. When a
    residual is passed, RMSNorm returns a ``(normed, residual)`` pair
    (presumably a fused residual-add + norm; see ``RMSNorm``). The layer
    therefore returns the un-added MLP output together with the running
    residual.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 position_embedding: str,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None):
        """Build attention, MLP and the two RMSNorms from ``config``.

        ``rope_theta`` and ``max_position_embeddings`` fall back to 10000
        and 8192 when the config does not define them.
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = BaiChuanAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            position_embedding=position_embedding,
            rope_theta=getattr(config, "rope_theta", 10000),
            max_position_embeddings=getattr(config, "max_position_embeddings",
                                            8192),
            cache_config=cache_config,
            quant_config=quant_config,
        )
        self.mlp = BaiChuanMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
        )
        eps = config.rms_norm_eps
        self.input_layernorm = RMSNorm(config.hidden_size, eps=eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the block; ``residual`` is None only for the first layer."""
        if residual is not None:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        else:
            # First layer: the embeddings themselves seed the residual.
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)

        attn_out = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        normed, residual = self.post_attention_layernorm(attn_out, residual)
        return self.mlp(normed), residual


class BaiChuanModel(nn.Module):
    """Decoder stack: token embedding, decoder layers, final RMSNorm."""

    def __init__(self,
                 config: PretrainedConfig,
                 position_embedding: str,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None):
        """Build ``config.num_hidden_layers`` identical decoder layers.

        Args:
            config: HuggingFace model config.
            position_embedding: ``"ALIBI"`` or a rotary scheme, forwarded to
                every layer's attention.
            cache_config: Cache configuration forwarded to every layer.
            quant_config: Quantization config forwarded to every layer.
        """
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        self.layers = nn.ModuleList([
            BaiChuanDecoderLayer(config, position_embedding, cache_config,
                                 quant_config)
            for _ in range(config.num_hidden_layers)
        ])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and run them through every decoder layer.

        ``kv_caches`` is indexed by layer position: layer ``i`` uses
        ``kv_caches[i]``.
        """
        hidden_states = self.embed_tokens(input_ids)
        # No residual exists before the first layer; each layer returns the
        # running residual for the next one.
        residual = None
        for layer_idx, layer in enumerate(self.layers):
            hidden_states, residual = layer(
                positions,
                hidden_states,
                kv_caches[layer_idx],
                attn_metadata,
                residual,
            )
        # The final norm also folds in the last pending residual.
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


class BaiChuanBaseForCausalLM(nn.Module):
    """Shared causal-LM wrapper for all BaiChuan variants.

    Wraps ``BaiChuanModel`` with an LM head, logits processor and sampler.
    Subclasses choose the positional scheme (``"ALIBI"`` or ``"ROPE"``).
    """

    packed_modules_mapping = {
        "W_pack": ["W_pack"],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }
    # LoRA specific attributes
    supported_lora_modules = [
        "W_pack",
        "o_proj",
        "gate_up_proj",
        "down_proj",
    ]
    embedding_modules = {}
    embedding_padding_modules = []

    def __init__(
        self,
        config,
        position_embedding: str,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
    ):
        """Build the decoder stack, LM head, logits processor and sampler.

        Args:
            config: HuggingFace model config.
            position_embedding: ``"ALIBI"`` or a rotary scheme.
            cache_config: Cache configuration forwarded to the model.
            quant_config: Quantization config forwarded to the model.
            lora_config: Accepted for interface compatibility; not used by
                this class.
        """
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.model = BaiChuanModel(config, position_embedding, cache_config,
                                   quant_config)
        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Return final hidden states; logits are computed separately."""
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Project hidden states to vocabulary logits with the LM head."""
        logits = self.logits_processor(self.lm_head.weight, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits``."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint tensors into this model's parameters.

        Routing rules:

        * ``gate_proj`` / ``up_proj`` tensors go into shards 0 / 1 of the
          fused ``gate_up_proj`` parameter.
        * All other tensors are loaded by name.
        * ``rotary_emb.inv_freq`` tensors are ignored.
        * Bias tensors with no matching parameter (extra biases in GPTQ
          checkpoints) are skipped.

        Args:
            weights: Iterable of ``(name, tensor)`` pairs from a checkpoint.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        # Unlike Baichuan, Baichuan2 normalizes the head weights.
        # Refer to:
        # https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat/blob/84603cde5ebffb6084e476cfaeceaf0b8b91fe54/modeling_baichuan.py#L508
        # Distinguish between Baichuan and Baichuan2 by checking the
        # vocab size. This is suggested by
        # https://github.com/vllm-project/vllm/pull/1022#discussion_r1325652704
        # The check does not depend on the tensor, so evaluate it once.
        is_baichuan2 = self.config.vocab_size == 125696
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if name == "lm_head.weight" and is_baichuan2:
                loaded_weight = torch.nn.functional.normalize(loaded_weight)

            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                # Use a separate variable so `name` is never rewritten in
                # place.
                stacked_name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models. `break` (not
                # `continue`) is required: `continue` would only advance this
                # inner loop and re-match the rewritten name against the next
                # mapping (e.g. "up_proj" inside "gate_up_proj") before
                # falling into the `else` branch.
                if (stacked_name.endswith(".bias")
                        and stacked_name not in params_dict):
                    break
                param = params_dict[stacked_name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Not a stacked (fused) weight: load it directly by name.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
396
397


398
399
class BaichuanForCausalLM(BaiChuanBaseForCausalLM):
    """Baichuan 13B and Baichuan2 7B/13B.

    The positional scheme is chosen from ``config.hidden_size``:

    * 4096 (Baichuan2 7B) uses rotary embeddings.
    * Any other width (the 13B variants) uses ALiBi.
    """

    def __init__(
        self,
        config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
    ):
        # hidden_size 4096 -> baichuan2 7b; otherwise baichuan 13b or
        # baichuan2 13b.
        scheme = "ROPE" if config.hidden_size == 4096 else "ALIBI"
        super().__init__(config, scheme, cache_config, quant_config,
                         lora_config)
414
415


416
417
class BaiChuanForCausalLM(BaiChuanBaseForCausalLM):
    """Baichuan 7B.

    This variant always uses rotary position embeddings.
    """

    def __init__(
        self,
        config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
    ):
        super().__init__(
            config,
            position_embedding="ROPE",
            cache_config=cache_config,
            quant_config=quant_config,
            lora_config=lora_config,
        )