decilm.py 5.43 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 DeciAI Research Team. All rights reserved.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only DeciLM model compatible with HuggingFace weights."""

25
from typing import Iterable, Optional, Tuple
26
27

import torch
28
from transformers import LlamaConfig
29

30
from vllm.config import CacheConfig, LoRAConfig
31
from vllm.model_executor.layers.quantization import QuantizationConfig
32
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
33
34
from vllm.model_executor.models.llama import LlamaForCausalLM

35
36
from .utils import is_pp_missing_parameter

37
38
39
40
41
42
43

class DeciLMForCausalLM(LlamaForCausalLM):
    """
    Implementation for https://huggingface.co/Deci/DeciLM-7b-instruct.
    Based on the llama executor.

    The main difference is that DeciLM uses Variable Grouped Query Attention:
    the constant number of GQA heads in the decoder is overridden with a
    value per layer.

    In the HuggingFace implementation,
    "config.num_key_value_heads_per_layer[i]" (which varies between layers)
    is used instead of "config.num_key_value_heads".

    Currently, PagedAttention does not work well with variable GQA, so we
    normalize the weights upon loading, and use uniform GQA with the max value
    instead.
    """

    def __init__(
        self,
        config: LlamaConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
    ) -> None:
        """Collapse the per-layer KV-head counts and build the Llama model.

        Args:
            config: Model config. It is modified in place:
                ``num_key_value_heads`` is set to the maximum of
                ``num_key_value_heads_per_layer`` and the per-layer attribute
                is removed before the config is handed to the parent class.
            cache_config: Forwarded unchanged to ``LlamaForCausalLM``.
            quant_config: Forwarded unchanged to ``LlamaForCausalLM``.
            lora_config: Forwarded unchanged to ``LlamaForCausalLM``.
        """
        # The config object is mutated in place, so a config that has already
        # been normalized (e.g. reused for a second model instance) no longer
        # carries the per-layer attribute. Only normalize when it is present
        # instead of failing with AttributeError.
        kv_heads_per_layer = getattr(config, "num_key_value_heads_per_layer",
                                     None)
        if kv_heads_per_layer is not None:
            config.num_key_value_heads = max(kv_heads_per_layer)
            delattr(config, "num_key_value_heads_per_layer")
        super().__init__(config=config,
                         cache_config=cache_config,
                         quant_config=quant_config,
                         lora_config=lora_config)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint tensors into this model's parameters.

        K/V projection tensors are first expanded to the uniform KV-head
        count via ``_degroup_weight``. Tensors belonging to a fused parameter
        (``qkv_proj``, ``gate_up_proj``) are routed to that parameter's
        ``weight_loader`` together with their shard id; everything else is
        loaded with the parameter's own ``weight_loader`` or, if it has none,
        ``default_weight_loader``.

        Args:
            weights: Iterable of ``(checkpoint_name, tensor)`` pairs.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue

            # NOTE(review): this matches every tensor whose name contains
            # k_proj / v_proj, and _degroup_weight reshapes to
            # (..., hidden_size) -- presumably only 2-D projection weights
            # are expected here; confirm for checkpoints with biases/scales.
            if "k_proj" in name or "v_proj" in name:
                loaded_weight = self._degroup_weight(loaded_weight)

            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                # Map the checkpoint shard name onto the fused parameter name.
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Reached when no fused-parameter shard was loaded above.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)

    def _degroup_weight(self, loaded_weight: torch.Tensor) -> torch.Tensor:
        """Expand a K/V projection weight to the uniform KV-head count.

        The tensor is interpreted as ``num_kv_heads`` blocks of ``head_size``
        rows each; every block is repeated consecutively so that the result
        has ``config.num_key_value_heads`` blocks.

        Args:
            loaded_weight: Tensor reshaped here as
                ``(num_kv_heads * head_size, hidden_size)``.

        Returns:
            Tensor of shape
            ``(config.num_key_value_heads * head_size, hidden_size)``.

        Raises:
            ValueError: If the row count is not a positive multiple of the
                head size, or the target head count is not a multiple of the
                tensor's head count.
        """
        hidden_size = self.config.hidden_size
        head_size = hidden_size // self.config.num_attention_heads
        target_num_kv_heads = self.config.num_key_value_heads

        # Validate with explicit errors rather than `assert`, which is
        # stripped under `python -O`, and use integer arithmetic instead of
        # comparing a float quotient against its truncation.
        num_kv_heads, leftover_rows = divmod(loaded_weight.shape[0],
                                             head_size)
        if num_kv_heads == 0 or leftover_rows != 0:
            raise ValueError(
                f"K/V weight has {loaded_weight.shape[0]} rows, which is not "
                f"a positive multiple of the head size {head_size}.")
        n_repeats, remainder = divmod(target_num_kv_heads, num_kv_heads)
        if remainder != 0:
            raise ValueError(
                f"Cannot expand {num_kv_heads} KV heads to "
                f"{target_num_kv_heads}: not an integer multiple.")

        # reshape (rather than view) also accepts non-contiguous tensors.
        per_head = loaded_weight.reshape(num_kv_heads, head_size, hidden_size)
        # Duplicate each head n_repeats times, keeping copies adjacent.
        expanded = torch.repeat_interleave(per_head,
                                           repeats=n_repeats,
                                           dim=0)
        return expanded.reshape(target_num_kv_heads * head_size, hidden_size)