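# NOTE: the following lines are web-viewer residue captured with this file;
# they are not part of the module.
#   "tools/vscode:/vscode.git/clone" did not exist on "66b809cc68259a061c3184012324b7cfb5cf776f"
#   llama_eagle.py 6.63 KB
#   Newer Older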
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from collections.abc import Iterable
5
from typing import Optional
6
7
8
9
10

import torch
import torch.nn as nn
from transformers import LlamaConfig

11
12
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
13
from vllm.distributed.parallel_state import get_pp_group
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
from vllm.logger import init_logger
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.llama import (LlamaDecoderLayer,
                                              LlamaForCausalLM)

from .utils import AutoWeightsLoader, maybe_prefix

logger = init_logger(__name__)


class LlamaDecoderLayer(LlamaDecoderLayer):
    """Llama decoder layer whose input layernorm can be replaced by a no-op.

    This class subclasses, and deliberately reuses the name of, the
    ``LlamaDecoderLayer`` imported from ``vllm.model_executor.models.llama``.
    Everything except the optional layernorm replacement is inherited.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        disable_input_layernorm: bool,
        prefix: str = "",
        config: Optional[LlamaConfig] = None,
    ) -> None:
        """Build the base layer, then optionally drop its input layernorm.

        Args:
            vllm_config: Forwarded unchanged to the base class constructor.
            disable_input_layernorm: When true, ``self.input_layernorm`` is
                replaced with ``nn.Identity()``.
            prefix: Forwarded to the base class constructor.
            config: Forwarded to the base class constructor.
        """
        super().__init__(vllm_config, prefix=prefix, config=config)

        # Skip the input_layernorm
        # https://github.com/SafeAILab/EAGLE/blob/35c78f6cdc19a73e05cf5c330b4c358dad970c6a/eagle/model/cnets.py#L427
        if disable_input_layernorm:
            # Remove the submodule created by the base class before
            # installing the pass-through replacement under the same name.
            del self.input_layernorm
            self.input_layernorm = nn.Identity()


45
@support_torch_compile
class LlamaModel(nn.Module):
    """Draft-model backbone: token embedding, an ``fc`` fusion projection and
    a stack of ``LlamaDecoderLayer`` modules.

    The configuration is read from
    ``vllm_config.speculative_config.draft_model_config.hf_config``.
    """

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        start_layer_id: int = 0,
    ) -> None:
        """Create the embedding, decoder layers and fusion projection.

        Args:
            vllm_config: Source of the draft model's HF config; also passed
                to every decoder layer.
            prefix: Name prefix used when building sub-module prefixes.
            start_layer_id: Offset added to each layer index when forming the
                layer's prefix (``layers.{i + start_layer_id}``).
        """
        super().__init__()
        self.config = vllm_config. \
            speculative_config.draft_model_config.hf_config
        self.vocab_size = self.config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            self.config.vocab_size,
            self.config.hidden_size,
            prefix=maybe_prefix(prefix, "embed_tokens"),
        )

        self.layers = nn.ModuleList([
            LlamaDecoderLayer(
                vllm_config,
                # Only the first layer has its input layernorm disabled.
                i == 0,
                prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"),
                config=self.config,
            ) for i in range(self.config.num_hidden_layers)
        ])
        # Maps the concatenation [token embedding, incoming hidden state]
        # (2 * hidden_size features) back down to hidden_size.
        self.fc = torch.nn.Linear(self.config.hidden_size * 2,
                                  self.config.hidden_size,
                                  bias=False)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return ``self.embed_tokens(input_ids)``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the draft backbone.

        The embedded ``input_ids`` are concatenated with ``hidden_states``
        along the last dimension, projected by ``self.fc``, and passed through
        every decoder layer in order.

        Returns:
            A 2-tuple whose two elements are the same tensor: the final
            hidden states with the last layer's residual added.
        """
        input_embeds = self.embed_tokens(input_ids)
        hidden_states = self.fc(
            torch.cat((input_embeds, hidden_states), dim=-1))
        # The first layer receives no residual; each layer returns the
        # residual that is handed to the next one.
        residual = None
        for layer in self.layers:
            hidden_states, residual = layer(
                positions,
                hidden_states,
                residual,
            )
        hidden_states = hidden_states + residual
        return hidden_states, hidden_states

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Names containing a per-shard projection suffix (``.q_proj``,
        ``.k_proj``, ``.v_proj``, ``.gate_proj``, ``.up_proj``) are rewritten
        to the corresponding fused parameter name and loaded with that
        parameter's ``weight_loader`` and a shard id. Every other name is
        loaded directly, using ``default_weight_loader`` when the parameter
        has no ``weight_loader`` attribute.

        Args:
            weights: Iterable of ``(name, tensor)`` pairs.

        Returns:
            The set of (possibly rewritten) parameter names that were loaded.
            Skipped ``embed_tokens.`` entries are not included.

        Raises:
            KeyError: If a (rewritten) name is not one of this module's
                parameters.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Reached only when no stacked-parameter entry matched.

                # if PP disabled then draft will share embed with target
                if get_pp_group().world_size == 1 and \
                    "embed_tokens." in name:
                    # `continue` also bypasses `loaded_params.add` below.
                    continue

                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params


class EagleLlamaForCausalLM(LlamaForCausalLM):
    """Causal-LM wrapper around the draft ``LlamaModel`` defined in this file.

    Inherits from ``LlamaForCausalLM`` but does not run that class's
    constructor; only ``nn.Module.__init__`` is invoked.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the draft model and its logits processor.

        Args:
            vllm_config: Provides the draft model's HF config
                (``speculative_config.draft_model_config.hf_config``) and the
                model and parallel configs used to compute
                ``target_layer_num``.
            prefix: Accepted but not used here; the inner model is always
                created with the prefix ``"model"``.
        """
        # Initialise as a plain nn.Module, bypassing LlamaForCausalLM.__init__.
        nn.Module.__init__(self)
        self.config = vllm_config. \
            speculative_config.draft_model_config.hf_config
        # Ensure draft_vocab_size is set
        # default to the base vocab size when absent
        if getattr(self.config, "draft_vocab_size", None) is None:
            base_vocab_size = getattr(self.config, "vocab_size", None)
            self.config.draft_vocab_size = base_vocab_size
        # The draft layers' prefixes are numbered starting at this value.
        target_layer_num = vllm_config.model_config.get_num_layers(
            vllm_config.parallel_config)
        self.model = LlamaModel(vllm_config=vllm_config,
                                prefix="model",
                                start_layer_id=target_layer_num)

        # Logits are scaled by the config's `logit_scale`, 1.0 when absent.
        logit_scale = getattr(self.config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(self.config.vocab_size,
                                                scale=logit_scale)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate to ``self.model.get_input_embeddings``."""
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the inner draft model.

        Returns:
            Whatever ``self.model(input_ids, positions, hidden_states)``
            returns.

        Raises:
            NotImplementedError: If ``inputs_embeds`` is not ``None``.
        """
        if inputs_embeds is not None:
            raise NotImplementedError(
                f"{type(self).__name__} does not support multimodal inputs yet."
            )
        return self.model(input_ids, positions, hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """Load checkpoint weights through ``AutoWeightsLoader``.

        Every name that does not contain ``"lm_head"`` is prefixed with
        ``"model."`` before being handed to the loader. The loader's return
        value is not propagated; this method returns ``None``.
        """

        def transform(inputs):
            # Re-root every name except lm_head ones under "model.".
            name, loaded_weight = inputs
            if "lm_head" not in name:
                name = "model." + name
            return name, loaded_weight

        loader = AutoWeightsLoader(
            self,
            skip_prefixes=None,
        )
        loader.load_weights(map(transform, weights))