llama_eagle.py 7.98 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from collections.abc import Iterable
5
6
7
8
9

import torch
import torch.nn as nn
from transformers import LlamaConfig

10
11
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
12
from vllm.distributed.parallel_state import get_pp_group
13
from vllm.logger import init_logger
14
from vllm.model_executor.layers.linear import ReplicatedLinear
15
from vllm.model_executor.layers.logits_processor import LogitsProcessor
16
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
17
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
18
19
20
21
from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader,
    maybe_remap_kv_scale_name,
)
22
from vllm.model_executor.models.llama import LlamaDecoderLayer, LlamaForCausalLM
23

24
25
26
27
28
29
from .utils import (
    AutoWeightsLoader,
    get_draft_quant_config,
    maybe_prefix,
    process_eagle_weight,
)
30
31
32
33
34
35
36

logger = init_logger(__name__)


class LlamaDecoderLayer(LlamaDecoderLayer):
    """Llama decoder layer specialised for the EAGLE draft model.

    Differs from the imported base ``LlamaDecoderLayer`` in two ways:

    * the input layernorm can be switched off (replaced by ``nn.Identity``);
    * quantization settings come from the draft model's config rather than
      the verifier's (see ``get_quant_config``).
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        disable_input_layernorm: bool,
        prefix: str = "",
        config: LlamaConfig | None = None,
    ) -> None:
        """Build the draft decoder layer.

        Args:
            vllm_config: Global vLLM config, forwarded to the base layer.
            disable_input_layernorm: When True, ``input_layernorm`` is removed
                and replaced by an identity op.
            prefix: Parameter-name prefix for this layer's weights.
            config: Optional HF ``LlamaConfig`` forwarded to the base layer.
        """
        super().__init__(vllm_config, prefix=prefix, config=config)

        # Skip the input_layernorm
        # https://github.com/SafeAILab/EAGLE/blob/35c78f6cdc19a73e05cf5c330b4c358dad970c6a/eagle/model/cnets.py#L427
        if disable_input_layernorm:
            # Unregister the norm submodule built by the base class, then
            # install a parameter-free identity in its place.
            del self.input_layernorm
            self.input_layernorm = nn.Identity()

    def get_quant_config(self, vllm_config: VllmConfig) -> QuantizationConfig | None:
        """Use drafter's quantization config instead of verifier's.

        Overrides the base-class hook of the same name (presumably consulted
        while the base layer builds its sublayers — not visible here).
        """
        return get_draft_quant_config(vllm_config)
53

54

55
@support_torch_compile
class LlamaModel(nn.Module):
    """Transformer body of the EAGLE draft model.

    Each step embeds the input tokens, concatenates the embeddings with the
    hidden states supplied by the caller, projects the result back to
    ``hidden_size`` through ``fc``, and runs it through the draft decoder
    layers.
    """

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        start_layer_id: int = 0,
    ) -> None:
        """Build the draft model body.

        Args:
            vllm_config: Global vLLM config. The draft model's HF config is read
                from ``speculative_config.draft_model_config.hf_config``.
            prefix: Parameter-name prefix for this module.
            start_layer_id: Offset added to each layer index when building the
                layer weight prefixes (``layers.{i + start_layer_id}``).
        """
        super().__init__()
        self.config = vllm_config.speculative_config.draft_model_config.hf_config
        self.vocab_size = self.config.vocab_size

        # Get drafter's quantization config
        self.quant_config = get_draft_quant_config(vllm_config)

        self.embed_tokens = VocabParallelEmbedding(
            self.config.vocab_size,
            self.config.hidden_size,
            prefix=maybe_prefix(prefix, "embed_tokens"),
        )

        # The second positional argument is ``disable_input_layernorm``: only
        # the first draft layer (i == 0) skips its input layernorm.
        self.layers = nn.ModuleList(
            [
                LlamaDecoderLayer(
                    vllm_config,
                    i == 0,
                    prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"),
                    config=self.config,
                )
                for i in range(self.config.num_hidden_layers)
            ]
        )

        # Projects the concatenation [token embedding ; incoming hidden state]
        # (2 * hidden_size features) back down to hidden_size.
        self.fc = ReplicatedLinear(
            input_size=self.config.hidden_size * 2,
            output_size=self.config.hidden_size,
            bias=False,
            params_dtype=vllm_config.model_config.dtype,
            quant_config=self.quant_config,
            prefix=maybe_prefix(prefix, "fc"),
            return_bias=False,
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the draft model's token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run one draft step.

        Args:
            input_ids: Token ids to embed.
            positions: Position ids forwarded to every decoder layer.
            hidden_states: Hidden states supplied by the caller. Their last
                dimension must be ``hidden_size`` so the concatenation with
                the embeddings matches ``fc``'s ``2 * hidden_size`` input.

        Returns:
            A pair whose two elements are the same tensor: the output of the
            last decoder layer with its residual added back in.
        """
        input_embeds = self.embed_tokens(input_ids)
        hidden_states = self.fc(torch.cat((input_embeds, hidden_states), dim=-1))
        residual = None
        for layer in self.layers:
            hidden_states, residual = layer(
                positions,
                hidden_states,
                residual,
            )
        # The layers return the residual stream separately. No final norm is
        # built in this module, so the residual is folded in directly here.
        # NOTE(review): assumes num_hidden_layers >= 1; with zero layers
        # ``residual`` is still None and this addition would fail.
        hidden_states = hidden_states + residual
        return hidden_states, hidden_states

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        The checkpoint is handled as follows:

        * KV-cache quantization scales are routed through the drafter's
          quant config when one is set.
        * Other names containing "scale" are passed through
          ``maybe_remap_kv_scale_name``; tensors it maps to ``None`` are
          dropped.
        * Separate q/k/v and gate/up projections are loaded as shards of the
          fused ``qkv_proj`` / ``gate_up_proj`` parameters.
        * ``embed_tokens`` weights are skipped when the pipeline-parallel
          world size is 1, because the draft then shares embeddings with the
          target model.

        Args:
            weights: Iterable of ``(name, tensor)`` pairs.

        Returns:
            Names of the parameters that were loaded. Skipped tensors are not
            included.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            # Handle kv cache quantization scales
            if self.quant_config is not None and (
                scale_name := self.quant_config.get_cache_scale(name)
            ):
                # Loading kv cache quantization scales
                param = params_dict[scale_name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                # Scales may be stored as a 0-d tensor or as a 1-element
                # vector; reduce the latter to its first element.
                loaded_weight = (
                    loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
                )
                weight_loader(param, loaded_weight)
                loaded_params.add(scale_name)
                continue
            # Remapping the name FP8 kv-scale
            if "scale" in name:
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                # Load the per-projection tensor as one shard of the fused
                # parameter.
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # if PP disabled then draft will share embed with target
                if get_pp_group().world_size == 1 and "embed_tokens." in name:
                    continue

                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params


class EagleLlamaForCausalLM(LlamaForCausalLM):
    """EAGLE draft model for Llama-family targets.

    Subclasses ``LlamaForCausalLM`` for its interface but skips the parent
    constructor. It builds its own draft ``LlamaModel`` (see ``model``) and
    a ``LogitsProcessor``. No ``lm_head`` module is created in ``__init__``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the draft model from the speculative-decoding config.

        Args:
            vllm_config: Global vLLM config. The draft HF config is read from
                ``speculative_config.draft_model_config.hf_config``.
            prefix: Accepted for interface compatibility; the inner model is
                always built with prefix ``"model"``.
        """
        # Bypass LlamaForCausalLM.__init__ so the parent's submodules are not
        # built; this class constructs its own below.
        nn.Module.__init__(self)
        self.config = vllm_config.speculative_config.draft_model_config.hf_config
        # Ensure draft_vocab_size is set
        # default to the base vocab size when absent
        if getattr(self.config, "draft_vocab_size", None) is None:
            base_vocab_size = getattr(self.config, "vocab_size", None)
            self.config.draft_vocab_size = base_vocab_size
        # Draft layer indices start after the target model's layer count.
        target_layer_num = vllm_config.model_config.get_num_layers(
            vllm_config.parallel_config
        )
        self.model = LlamaModel(
            vllm_config=vllm_config, prefix="model", start_layer_id=target_layer_num
        )

        logit_scale = getattr(self.config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(
            self.config.vocab_size, scale=logit_scale
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return token embeddings from the draft model body."""
        return self.model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        inputs_embeds: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run one draft step through ``self.model``.

        Args:
            input_ids: Token ids to embed.
            positions: Position ids for the decoder layers.
            hidden_states: Hidden states supplied by the caller.
            inputs_embeds: Must be None; precomputed embeddings are rejected.

        Returns:
            The pair returned by ``self.model``.

        Raises:
            NotImplementedError: If ``inputs_embeds`` is provided.
        """
        if inputs_embeds is not None:
            raise NotImplementedError(
                f"{type(self).__name__} does not support multimodal inputs yet."
            )
        return self.model(input_ids, positions, hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """Load an EAGLE checkpoint.

        Every checkpoint name that does not contain "lm_head" is prefixed
        with "model." so it resolves under ``self.model``. Each name is also
        passed to ``process_eagle_weight`` before loading.

        NOTE(review): nothing is returned (no loaded-parameter set); confirm
        callers do not rely on one.
        """

        def transform(inputs):
            name, loaded_weight = inputs
            if "lm_head" not in name:
                name = "model." + name
            process_eagle_weight(self, name)
            return name, loaded_weight

        loader = AutoWeightsLoader(
            self,
            skip_prefixes=None,
        )
        loader.load_weights(map(transform, weights))