# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from collections.abc import Iterable

import torch
import torch.nn as nn
from transformers import LlamaConfig

from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.distributed.parallel_state import get_pp_group
from vllm.logger import init_logger
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.llama import (LlamaDecoderLayer,
                                              LlamaForCausalLM)

from .utils import AutoWeightsLoader, maybe_prefix

# Module-level logger for this file.
# NOTE(review): not referenced anywhere in the visible code — confirm it is
# unused before removing it.
logger = init_logger(__name__)


class LlamaDecoderLayer(LlamaDecoderLayer):
    """Llama decoder layer whose input layernorm can be switched off.

    Behaves exactly like the imported ``LlamaDecoderLayer`` it extends (and
    whose name it shadows in this module), except that when
    ``disable_input_layernorm`` is true the ``input_layernorm`` submodule is
    replaced with ``nn.Identity``.
    """

    def __init__(
        self,
        config: LlamaConfig,
        disable_input_layernorm: bool,
        prefix: str = "",
    ) -> None:
        """Create the layer.

        Args:
            config: HF Llama config forwarded to the parent layer.
            disable_input_layernorm: When true, ``input_layernorm`` becomes
                a pass-through ``nn.Identity``.
            prefix: Module-name prefix forwarded to the parent layer.
        """
        super().__init__(config, prefix=prefix)
        if not disable_input_layernorm:
            return

        # Skip the input_layernorm, as the reference EAGLE code does:
        # https://github.com/SafeAILab/EAGLE/blob/35c78f6cdc19a73e05cf5c330b4c358dad970c6a/eagle/model/cnets.py#L427
        # Deleting first (instead of plain reassignment) re-registers the
        # submodule at the end of the module's child ordering.
        del self.input_layernorm
        self.input_layernorm = nn.Identity()


@support_torch_compile
class LlamaModel(nn.Module):
    """EAGLE draft model for a Llama target.

    Each step concatenates the token embeddings with the incoming hidden
    states, projects the result back to ``hidden_size`` with ``fc``, runs it
    through ``num_hidden_layers`` decoder layers from the draft config, and
    adds the final residual back in.
    """

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        start_layer_id: int = 0,
    ) -> None:
        """Build the draft model.

        Args:
            vllm_config: Global vLLM config. The draft's HF config is read
                from ``speculative_config.draft_model_config.hf_config``.
            prefix: Module-name prefix for the submodules created here.
            start_layer_id: Offset added to each draft layer index when
                building its ``layers.{i}`` prefix, so draft layer names can
                be placed after another model's layers.
        """
        super().__init__()

        self.config = (
            vllm_config.speculative_config.draft_model_config.hf_config)
        self.vocab_size = self.config.vocab_size

        # If PP is disabled the draft shares the target's embedding, so
        # ``embed_tokens`` is NOT created here. It must be attached to this
        # module externally before ``forward`` is called.
        if get_pp_group().world_size > 1:
            self.embed_tokens = VocabParallelEmbedding(
                self.config.vocab_size,
                self.config.hidden_size,
                prefix=maybe_prefix(prefix, "embed_tokens"),
            )

        # Only the first draft layer is built with its input layernorm
        # disabled.
        self.layers = nn.ModuleList([
            LlamaDecoderLayer(
                self.config,
                disable_input_layernorm=(i == 0),
                prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"),
            ) for i in range(self.config.num_hidden_layers)
        ])

        # Fuses [token embedding, incoming hidden state] (2 * hidden_size
        # wide) back down to hidden_size.
        self.fc = torch.nn.Linear(self.config.hidden_size * 2,
                                  self.config.hidden_size,
                                  bias=False)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run one draft step.

        Args:
            input_ids: Token ids to embed with ``embed_tokens``.
            positions: Positions forwarded unchanged to every decoder layer.
            hidden_states: Incoming hidden states (presumably the target
                model's — confirm with the caller). They are concatenated
                with the token embeddings along the last dimension.

        Returns:
            ``(hidden_states, hidden_states)``: the same tensor twice. No
            final norm is applied; the last residual is added back instead.
        """
        input_embeds = self.embed_tokens(input_ids)
        hidden_states = self.fc(
            torch.cat((input_embeds, hidden_states), dim=-1))
        residual = None
        for layer in self.layers:
            hidden_states, residual = layer(
                positions,
                hidden_states,
                residual,
            )
        hidden_states = hidden_states + residual
        return hidden_states, hidden_states

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load draft weights into this module.

        Separate ``q/k/v`` and ``gate/up`` checkpoint tensors are routed into
        the fused ``qkv_proj`` / ``gate_up_proj`` parameters through each
        parameter's ``weight_loader`` with the matching shard id. Every other
        tensor uses the parameter's ``weight_loader`` if present, otherwise
        ``default_weight_loader``. With PP disabled, ``embed_tokens.*``
        tensors are skipped because the embedding is shared with the target.

        Args:
            weights: ``(name, tensor)`` pairs whose names are relative to
                this module's ``named_parameters()``.

        Returns:
            The (possibly remapped) names of the parameters that were loaded.

        Raises:
            KeyError: If a name (after remapping) has no matching parameter.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                # Rename the shard to the fused parameter it belongs to.
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # If PP is disabled the draft shares the target's embedding,
                # so there is no local parameter to load it into.
                if get_pp_group().world_size == 1 and \
                    "embed_tokens." in name:
                    continue

                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params


class EagleLlamaForCausalLM(LlamaForCausalLM):
    """EAGLE draft wrapper with a ``LlamaForCausalLM``-style interface.

    ``nn.Module.__init__`` is called directly, bypassing
    ``LlamaForCausalLM.__init__``. Only the draft ``model`` and a
    ``logits_processor`` are created here, and no ``lm_head`` is built.
    NOTE(review): an ``lm_head`` is presumably attached externally or taken
    from the checkpoint — confirm.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the draft.

        Args:
            vllm_config: Global vLLM config. The draft HF config comes from
                ``speculative_config.draft_model_config.hf_config``. The
                layer count from ``model_config`` (the target) offsets the
                draft layer indices.
            prefix: Accepted for interface compatibility. It is not used; the
                inner model is always created with prefix ``"model"``.
        """
        nn.Module.__init__(self)

        self.config = (
            vllm_config.speculative_config.draft_model_config.hf_config)

        # Draft layers are numbered after the target's layers.
        target_layer_num = vllm_config.model_config.get_num_layers(
            vllm_config.parallel_config)
        self.model = LlamaModel(vllm_config=vllm_config,
                                prefix="model",
                                start_layer_id=target_layer_num)

        logit_scale = getattr(self.config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(self.config.vocab_size,
                                                scale=logit_scale)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Delegate one draft step to the inner ``LlamaModel``.

        Returns:
            Whatever ``self.model`` returns for these inputs.
        """
        return self.model(input_ids, positions, hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """Load an EAGLE draft checkpoint.

        Checkpoint names without ``lm_head`` in them are prefixed with
        ``model.`` so they address the inner draft model. ``lm_head`` names
        are passed through unchanged.

        Args:
            weights: ``(name, tensor)`` pairs from the draft checkpoint.

        Returns:
            The result of ``AutoWeightsLoader.load_weights`` (presumably the
            set of loaded parameter names).
        """
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=None,
        )

        # Collected into a dict, so a later duplicate of a name overwrites an
        # earlier one.
        model_weights = {
            (name if "lm_head" in name else "model." + name): loaded_weight
            for name, loaded_weight in weights
        }
        return loader.load_weights(model_weights.items())