baichuan.py 17.1 KB
Newer Older
codethazine's avatar
codethazine committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Woosuk Kwon's avatar
Woosuk Kwon committed
20
"""Inference-only BaiChuan model compatible with HuggingFace weights."""
21
import math
22
from typing import Iterable, List, Optional, Tuple
codethazine's avatar
codethazine committed
23
24

import torch
25
from torch import nn
26
from transformers import PretrainedConfig
zhuwenwen's avatar
zhuwenwen committed
27
import os
zhuwenwen's avatar
zhuwenwen committed
28
import re
codethazine's avatar
codethazine committed
29

30
from vllm.attention import Attention, AttentionMetadata
31
from vllm.config import CacheConfig, LoRAConfig
32
33
from vllm.distributed import (get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size)
34
35
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
36
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
37
38
                                               QKVParallelLinear,
                                               RowParallelLinear)
39
from vllm.model_executor.layers.logits_processor import LogitsProcessor
40
41
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
42
from vllm.model_executor.layers.rotary_embedding import get_rope
43
44
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
45
    ParallelLMHead, VocabParallelEmbedding)
46
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
47
48
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput
zhuwenwen's avatar
zhuwenwen committed
49
from vllm import _custom_ops as ops
codethazine's avatar
codethazine committed
50

51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83

def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
    """Return the ALiBi slope for every attention head as a float32 tensor.

    The largest power of two not exceeding ``total_num_heads`` gets a
    geometric sequence of slopes. If the head count is not a power of two,
    the remaining heads are filled from the sequence for twice that power
    of two, taking only its odd exponents.

    Args:
        total_num_heads: Total number of attention heads (all TP ranks).

    Returns:
        1-D float32 tensor of length ``total_num_heads``.
    """

    def geometric_base(head_count: int) -> torch.Tensor:
        # Common ratio of the slope sequence for ``head_count`` heads.
        exponent = -(2**-(math.log2(head_count) - 3))
        return torch.tensor(2**exponent, dtype=torch.float32)

    pow2_heads = 2**math.floor(math.log2(total_num_heads))
    head_slopes = torch.pow(
        geometric_base(pow2_heads),
        torch.arange(1, pow2_heads + 1, dtype=torch.int32),
    )
    if pow2_heads == total_num_heads:
        return head_slopes

    # Interleave in the odd-indexed slopes of the next power of two.
    leftover = min(pow2_heads, total_num_heads - pow2_heads)
    odd_powers = torch.arange(start=1,
                              end=2 * leftover + 1,
                              step=2,
                              dtype=torch.int32)
    extra_slopes = torch.pow(geometric_base(2 * pow2_heads), odd_powers)
    return torch.cat([head_slopes, extra_slopes], dim=0)


class BaiChuanMLP(nn.Module):
    """Gated feed-forward block used by every BaiChuan decoder layer.

    Computes ``down_proj(SiluAndMul(gate_up_proj(x)))``.

    Args:
        hidden_size: Input and output width.
        intermediate_size: Width of each of the two fused projections
            (gate and up) inside ``gate_up_proj``.
        hidden_act: Activation name from the model config; only ``"silu"``
            is supported.
        quant_config: Optional quantization config forwarded to both
            parallel linear layers.

    Raises:
        ValueError: If ``hidden_act`` is not ``"silu"``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        # Validate before allocating any (possibly large) weights.
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        # Gate and up projections are fused into one column-parallel layer
        # with two output shards of ``intermediate_size`` each.
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size, [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config)
        self.down_proj = RowParallelLinear(intermediate_size,
                                           hidden_size,
                                           bias=False,
                                           quant_config=quant_config)
        self.act_fn = SiluAndMul()

    def forward(self, x):
        """Apply the gated MLP to ``x`` and return the projected output."""
        # The parallel linear layers return a pair; only the first element
        # (the output) is used here.
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


class BaiChuanAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper.

    Q, K and V come from a single fused ``W_pack`` projection. Positions are
    encoded with ALiBi slopes when ``position_embedding == "ALIBI"`` and
    with rotary embeddings (``get_rope``) for any other value. Heads are
    split evenly across the tensor-parallel group.

    Args:
        hidden_size: Model hidden dimension; ``head_dim`` is
            ``hidden_size // num_heads``.
        num_heads: Total number of attention heads across all TP ranks.
        position_embedding: ``"ALIBI"`` or a rotary scheme (e.g. ``"ROPE"``).
        rope_theta: Rotary base; only used on the rotary path.
        max_position_embeddings: Maximum position for the rotary cache; only
            used on the rotary path.
        cache_config: Optional KV-cache config forwarded to ``Attention``.
        quant_config: Optional quantization config forwarded to the linear
            layers and ``Attention``.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        position_embedding: str,
        rope_theta: float = 10000,
        max_position_embeddings: int = 8192,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0, (
            "num_heads must be divisible by the tensor-parallel world size")
        self.num_heads = self.total_num_heads // tp_size
        self.head_dim = hidden_size // self.total_num_heads
        # Attribute keeps its historical spelling ("postion") so any code
        # that reads it keeps working.
        self.postion_embedding = position_embedding
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings
        self.scaling = self.head_dim**-0.5

        # pylint: disable=invalid-name
        # Fused QKV projection; the same head count is passed for queries
        # and for keys/values.
        self.W_pack = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_heads,
            bias=False,
            quant_config=quant_config,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
        )
        if self.postion_embedding == "ALIBI":
            # Slopes are computed for all heads, then sliced to the
            # contiguous range of heads owned by this TP rank.
            tp_rank = get_tensor_model_parallel_rank()
            head_start = tp_rank * self.num_heads
            head_end = (tp_rank + 1) * self.num_heads
            alibi_slopes = _get_alibi_slopes(self.total_num_heads)
            alibi_slopes = alibi_slopes[head_start:head_end].tolist()

            # cache_config is forwarded here too, so KV-cache settings apply
            # to ALiBi models exactly as they do on the rotary path.
            self.attn = Attention(self.num_heads,
                                  self.head_dim,
                                  self.scaling,
                                  alibi_slopes=alibi_slopes,
                                  cache_config=cache_config,
                                  quant_config=quant_config)
        else:
            self.rotary_emb = get_rope(
                self.head_dim,
                rotary_dim=self.head_dim,
                max_position=self.max_position_embeddings,
                base=self.rope_theta,
            )
            self.attn = Attention(self.num_heads,
                                  self.head_dim,
                                  self.scaling,
                                  cache_config=cache_config,
                                  quant_config=quant_config)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Compute attention over ``hidden_states`` and project the result.

        ``positions`` is only consumed by the rotary embedding; it is unused
        on the ALiBi path.
        """
        qkv, _ = self.W_pack(hidden_states)
        # Equal query and key/value head counts, so the fused output splits
        # into three equal chunks along the last dimension.
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        if self.postion_embedding != "ALIBI":
            q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.o_proj(attn_output)
        return output


class BaiChuanDecoderLayer(nn.Module):
    """One BaiChuan transformer block.

    Order of operations: input RMSNorm, self-attention, post-attention
    RMSNorm, then the gated MLP. The residual stream is threaded through
    the two norms.

    Args:
        config: HuggingFace model config. Reads ``hidden_size``,
            ``num_attention_heads``, ``intermediate_size``, ``hidden_act``
            and ``rms_norm_eps``. It also reads, when present, ``rope_theta``
            (default 10000) and ``max_position_embeddings`` (default 8192).
        position_embedding: ``"ALIBI"`` selects ALiBi; any other value
            selects rotary embeddings.
        cache_config: Optional KV-cache config forwarded to attention.
        quant_config: Optional quantization config forwarded to sublayers.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 position_embedding: str,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None):
        super().__init__()
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        self.self_attn = BaiChuanAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            position_embedding=position_embedding,
            rope_theta=rope_theta,
            max_position_embeddings=max_position_embeddings,
            cache_config=cache_config,
            quant_config=quant_config,
        )
        self.mlp = BaiChuanMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
        )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply the block.

        Args:
            residual: ``None`` for the first layer, in which case the
                incoming ``hidden_states`` become the residual. Otherwise it
                is the residual returned by the previous layer.

        Returns:
            ``(hidden_states, residual)``. The MLP output and the residual
            stream are kept separate so the next norm can take both.
        """
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            # Two-argument RMSNorm returns (normalized, updated residual);
            # presumably it fuses the residual add — confirm in RMSNorm.
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual


class BaiChuanModel(nn.Module):
    """BaiChuan transformer trunk.

    Consists of a token embedding, ``num_hidden_layers`` decoder layers,
    and a final RMSNorm.

    Args:
        config: HuggingFace model config. Reads ``pad_token_id``,
            ``vocab_size``, ``hidden_size``, ``num_hidden_layers`` and
            ``rms_norm_eps``. It is also passed to each decoder layer.
        position_embedding: ``"ALIBI"`` or a rotary scheme, passed to every
            decoder layer.
        cache_config: Optional KV-cache config forwarded to the layers.
        quant_config: Optional quantization config forwarded to the layers.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 position_embedding: str,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None):
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        self.layers = nn.ModuleList([
            BaiChuanDecoderLayer(config, position_embedding, cache_config,
                                 quant_config)
            for _ in range(config.num_hidden_layers)
        ])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Embed ``input_ids``, run every layer, and return normed states.

        ``kv_caches[i]`` is handed to layer ``i``, so the list must have at
        least one entry per layer.
        """
        hidden_states = self.embed_tokens(input_ids)
        residual = None
        for i, layer in enumerate(self.layers):
            hidden_states, residual = layer(
                positions,
                hidden_states,
                kv_caches[i],
                attn_metadata,
                residual,
            )
        # The last layer returns its output and the residual separately;
        # the two-argument norm takes both.
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


class BaiChuanBaseForCausalLM(nn.Module):
    """Shared causal-LM wrapper for the Baichuan family.

    Combines :class:`BaiChuanModel` with an LM head, a logits processor and
    a sampler, and loads HuggingFace checkpoints. Subclasses pick the
    position-embedding scheme (``"ROPE"`` or ``"ALIBI"``).
    """

    packed_modules_mapping = {
        "W_pack": ["W_pack"],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }
    # LoRA specific attributes
    supported_lora_modules = [
        "W_pack",
        "o_proj",
        "gate_up_proj",
        "down_proj",
    ]
    embedding_modules = {}
    embedding_padding_modules = []

    # Parameter-name fragments whose weights are re-laid out by
    # ``ops.trans_w16_gemm`` when the LLAMA_NN environment flag is set.
    _LLAMA_NN_WEIGHT_KEYS = (
        "self_attn.W_pack.weight",
        "self_attn.o_proj.weight",
        "mlp.gate_up_proj.weight",
        "mlp.down_proj.weight",
    )

    def __init__(
        self,
        config,
        position_embedding: str,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
    ):
        """Build the model.

        Args:
            config: HuggingFace model config.
            position_embedding: ``"ALIBI"`` or a rotary scheme.
            cache_config: Optional KV-cache config.
            quant_config: Optional quantization config.
            lora_config: Accepted for interface compatibility; not used in
                this class.
        """
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.model = BaiChuanModel(config, position_embedding, cache_config,
                                   quant_config)
        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()
        # Opt-in weight re-layout for the custom GEMM path (LLAMA_NN=1).
        self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run the trunk and return the final hidden states (not logits)."""
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Project ``hidden_states`` to vocabulary logits via the LM head."""
        logits = self.logits_processor(self.lm_head.weight, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits``."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load HuggingFace checkpoint tensors into this model.

        The loading rules are:

        * ``gate_proj``/``up_proj`` tensors are routed into shards 0/1 of
          the fused ``gate_up_proj`` parameter.
        * ``rotary_emb.inv_freq`` buffers are skipped.
        * Extra ``.bias`` tensors with no matching parameter are skipped.
        * For Baichuan2 checkpoints the LM head is L2-normalized.
        * When ``LLAMA_NN=1``, the linear weights are re-laid out once,
          after every tensor has been loaded.

        Args:
            weights: Iterable of ``(name, tensor)`` checkpoint pairs.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if name == "lm_head.weight":
                # Unlike Baichuan, Baichuan2 normalizes the head weights.
                # Refer to:
                # https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat/blob/84603cde5ebffb6084e476cfaeceaf0b8b91fe54/modeling_baichuan.py#L508
                # Distinguish between Baichuan and Baichuan2 by checking the
                # vocab size. This is suggested by
                # https://github.com/vllm-project/vllm/pull/1022#discussion_r1325652704
                is_baichuan2 = self.config.vocab_size == 125696
                if is_baichuan2:
                    loaded_weight = torch.nn.functional.normalize(
                        loaded_weight)

            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                # NOTE(review): this `continue` moves to the next mapping
                # entry, not the next tensor. The renamed bias may match
                # again, but with no `break` the `else` branch below still
                # skips it.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)

        # The re-layout must run exactly once, after all tensors are loaded.
        # Running it per loaded tensor would repeatedly re-transform weights
        # and change parameter shapes between loads.
        if self.use_llama_nn:
            self._apply_llama_nn_layout(params_dict)

    def _apply_llama_nn_layout(self, params_dict):
        """Re-lay out linear weights in place for the LLAMA_NN GEMM path.

        For each parameter whose name contains one of
        ``_LLAMA_NN_WEIGHT_KEYS``, the steps are:

        1. ``ops.trans_w16_gemm`` writes the new layout into a zeroed buffer
           of the same shape.
        2. The buffer is copied back into the parameter.
        3. The parameter data is reshaped to ``(shape[1], -1)``.

        Assumes the matched weights are 2-D; TODO confirm for quantized
        checkpoints.

        Args:
            params_dict: Mapping of parameter name to parameter, as built by
                ``dict(self.named_parameters())``.
        """
        for layername, weight in params_dict.items():
            if not any(key in layername
                       for key in self._LLAMA_NN_WEIGHT_KEYS):
                continue
            ori_shape = weight.data.shape
            _weight = torch.zeros_like(weight.data)
            ops.trans_w16_gemm(_weight, weight.data, ori_shape[0],
                               ori_shape[1])
            weight.data.copy_(_weight)
            weight.data = weight.data.reshape(ori_shape[1], -1)
420
421


422
423
class BaichuanForCausalLM(BaiChuanBaseForCausalLM):
    """Baichuan 13B and Baichuan2 7B/13B.

    The position-embedding scheme is chosen from ``config.hidden_size``:

    * 4096 (Baichuan2 7B) uses rotary embeddings.
    * Any other size (Baichuan 13B, Baichuan2 13B) uses ALiBi.
    """

    def __init__(
        self,
        config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
    ):
        position_embedding = "ROPE" if config.hidden_size == 4096 else "ALIBI"
        super().__init__(config, position_embedding, cache_config,
                         quant_config, lora_config)
438
439


440
441
class BaiChuanForCausalLM(BaiChuanBaseForCausalLM):
    """Baichuan 7B; always uses rotary position embeddings."""

    def __init__(
        self,
        config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
    ):
        super().__init__(config, "ROPE", cache_config, quant_config,
                         lora_config)