# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

"""Inference-only GraniteMoeShared model.

The architecture is the same as GraniteMoe, with the addition of shared
experts.
"""

from collections.abc import Iterable
from itertools import islice
from typing import Optional, Union

import torch
from torch import nn
from transformers.models.granitemoeshared import GraniteMoeSharedConfig

from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
    MergedColumnParallelLinear,
    RowParallelLinear,
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
    DEFAULT_VOCAB_PADDING_SIZE,
    ParallelLMHead,
    VocabParallelEmbedding,
)
from vllm.sequence import IntermediateTensors

from .granitemoe import GraniteMoeAttention, GraniteMoeModel, GraniteMoeMoE
from .interfaces import SupportsLoRA, SupportsPP
from .utils import AutoWeightsLoader, make_layers, maybe_prefix


class GraniteMoeSharedMLP(nn.Module):
    """Dense "shared expert" MLP that runs alongside the routed experts.

    The forward pass is: merged ``input_linear`` projection (two outputs of
    width ``config.shared_intermediate_size``), then ``SiluAndMul``, then
    ``output_linear`` back to ``config.hidden_size``.

    Raises:
        ValueError: if ``config.hidden_act`` is anything other than "silu".
    """

    def __init__(
        self,
        config: GraniteMoeSharedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()

        # Validate before allocating (possibly large, quantized) weights so an
        # unsupported config fails fast.
        if config.hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {config.hidden_act}. "
                "Only silu is supported for now."
            )

        # NOTE: despite its name, ``hidden_size`` here is the shared-expert
        # intermediate width; the model hidden size is ``input_size``.
        self.input_size = config.hidden_size
        self.hidden_size = config.shared_intermediate_size

        # Both input projections are fused into one column-parallel matmul
        # producing two halves of width ``self.hidden_size``.
        self.input_linear = MergedColumnParallelLinear(
            input_size=self.input_size,
            output_sizes=[self.hidden_size] * 2,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.input_linear",
        )
        self.output_linear = RowParallelLinear(
            self.hidden_size,
            self.input_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.output_linear",
        )
        self.act_fn = SiluAndMul()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Project up, apply ``SiluAndMul``, and project back down."""
        # The parallel linear layers return a tuple; the second element
        # (presumably the deferred bias) is unused because bias=False.
        hidden_states, _ = self.input_linear(hidden_states)
        hidden_states = self.act_fn(hidden_states)
        hidden_states, _ = self.output_linear(hidden_states)
        return hidden_states


class GraniteMoeSharedDecoderLayer(nn.Module):
    """One GraniteMoeShared transformer block.

    The block has two pre-norm stages:

    - Attention.
    - Feed-forward: the routed ``GraniteMoeMoE`` alone, or the sum of the
      routed MoE and a dense ``GraniteMoeSharedMLP`` when the config defines
      a non-zero ``shared_intermediate_size``.

    Each stage's output is scaled by ``config.residual_multiplier`` before
    it is added to the residual.
    """

    def __init__(
        self,
        config: GraniteMoeSharedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        # Requires transformers > 4.32.0
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        self.self_attn = GraniteMoeAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            max_position=config.max_position_embeddings,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
            attention_multiplier=config.attention_multiplier,
        )
        self.block_sparse_moe = GraniteMoeMoE(
            num_experts=config.num_local_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            quant_config=quant_config,
            prefix=f"{prefix}.block_sparse_moe",
        )
        # A missing, zero, or None ``shared_intermediate_size`` all mean "no
        # shared expert". Previously None slipped past an ``== 0`` check and
        # then failed while building the MLP.
        shared_intermediate_size = getattr(config, "shared_intermediate_size", 0)
        self.shared_mlp = (
            GraniteMoeSharedMLP(
                config, quant_config=quant_config, prefix=f"{prefix}.shared_mlp"
            )
            if shared_intermediate_size
            else None
        )

        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

        self.residual_multiplier = config.residual_multiplier

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run attention, then the routed (plus optional shared) feed-forward."""
        # Self attention: pre-norm, then scaled residual add.
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
        )
        hidden_states = residual + hidden_states * self.residual_multiplier

        # Feed-forward: pre-norm, then scaled residual add.
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        if self.shared_mlp is None:
            hidden_states = self.block_sparse_moe(hidden_states)
        else:
            # create a copy since block_sparse_moe modifies in-place
            moe_hidden_states = hidden_states.clone()
            moe_hidden_states = self.block_sparse_moe(moe_hidden_states)
            hidden_states = moe_hidden_states + self.shared_mlp(hidden_states)
            # Drop the MoE output reference before the next allocation.
            del moe_hidden_states
        hidden_states = residual + hidden_states * self.residual_multiplier

        return hidden_states


@support_torch_compile
class GraniteMoeSharedModel(nn.Module):
    """GraniteMoeShared decoder stack: embeddings, decoder layers, final norm.

    Pipeline-parallel aware:

    - Only the first PP rank embeds tokens.
    - Other ranks read ``hidden_states`` from ``intermediate_tensors``.
    - Every rank except the last returns ``IntermediateTensors`` instead of
      normed hidden states.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = config
        # Required by MixtralModel. NOTE(review): ``GraniteMoeModel._load_weights``
        # is called with this instance as ``self`` (see load_weights), so it
        # may read this too — confirm before removing.
        self.quant_config = quant_config
        self.padding_idx = config.pad_token_id

        # Extra embedding rows: lora_extra_vocab_size for each of max_loras
        # slots (at least one).
        lora_vocab = (
            (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1))
            if lora_config
            else 0
        )
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
            quant_config=quant_config,
        )
        self.embedding_multiplier = config.embedding_multiplier

        # forward() iterates layers[start_layer:end_layer]; presumably this is
        # the slice assigned to the current pipeline stage.
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: GraniteMoeSharedDecoderLayer(
                config, cache_config, quant_config=quant_config, prefix=prefix
            ),
            prefix=f"{prefix}.layers",
        )

        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings (without ``embedding_multiplier``)."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this rank's share of the decoder stack.

        Returns:
            Normed hidden states on the last PP rank. On every other rank, an
            ``IntermediateTensors`` holding ``"hidden_states"``.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            # Out-of-place on purpose: an in-place ``*=`` would rescale the
            # caller-owned ``inputs_embeds`` tensor as a side effect.
            hidden_states = hidden_states * self.embedding_multiplier
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]

        for layer in islice(self.layers, self.start_layer, self.end_layer):
            hidden_states = layer(positions, hidden_states)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})

        return self.norm(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Split fused MoE checkpoint tensors into per-expert names, then load.

        Renaming rules (all other tensors pass through unchanged):

        - ``*.block_sparse_moe.input_linear.weight``: dim 0 indexes experts.
          Each expert slice is split in half along its dim 0 into
          ``experts.{e}.w1.weight`` and ``experts.{e}.w3.weight``.
        - ``*.block_sparse_moe.output_linear.weight``: dim 0 indexes experts.
          Each slice becomes ``experts.{e}.w2.weight``.
        - ``*.block_sparse_moe.router.layer.weight`` becomes
          ``*.block_sparse_moe.gate.weight``.

        Returns:
            Whatever ``GraniteMoeModel._load_weights`` returns for the renamed
            weights (the set of loaded parameter names).

        Raises:
            ValueError: if two checkpoint tensors map to the same split or
                renamed expert/gate name.
        """
        moe_input = ".block_sparse_moe.input_linear.weight"
        moe_output = ".block_sparse_moe.output_linear.weight"
        router = ".block_sparse_moe.router.layer.weight"

        new_weights: dict[str, torch.Tensor] = {}

        def add_unique(name: str, tensor: torch.Tensor) -> None:
            # Explicit check instead of ``assert`` so it survives ``python -O``.
            if name in new_weights:
                raise ValueError(
                    f"Duplicate weight {name!r} produced while remapping "
                    "GraniteMoeShared checkpoint tensors."
                )
            new_weights[name] = tensor

        for name, tensor in weights:
            if name.endswith(moe_input):
                for e in range(tensor.size(0)):
                    w1_param, w3_param = tensor[e].chunk(2, dim=0)
                    add_unique(
                        name.replace(
                            moe_input, f".block_sparse_moe.experts.{e}.w1.weight"
                        ),
                        w1_param,
                    )
                    add_unique(
                        name.replace(
                            moe_input, f".block_sparse_moe.experts.{e}.w3.weight"
                        ),
                        w3_param,
                    )
            elif name.endswith(moe_output):
                for e in range(tensor.size(0)):
                    add_unique(
                        name.replace(
                            moe_output, f".block_sparse_moe.experts.{e}.w2.weight"
                        ),
                        tensor[e],
                    )
            elif name.endswith(router):
                add_unique(
                    name.replace(router, ".block_sparse_moe.gate.weight"), tensor
                )
            else:
                new_weights[name] = tensor

        return GraniteMoeModel._load_weights(self, new_weights.items())


class GraniteMoeSharedForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    """Causal-LM wrapper: ``GraniteMoeSharedModel`` plus a parallel LM head.

    ``1 / config.logits_scaling`` is passed to the logits processor as its
    ``scale``. When ``tie_word_embeddings`` is set, the LM head shares the
    input embedding weight and checkpoint ``lm_head.*`` tensors are skipped.
    """

    fall_back_to_pt_during_load = False

    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
    }

    # LoRA specific attributes
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }
    embedding_padding_modules = ["lm_head"]

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = config
        self.lora_config = lora_config

        self.model = GraniteMoeSharedModel(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )

        self.unpadded_vocab_size = config.vocab_size
        if lora_config:
            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
            # We need bigger padding if using lora for kernel compatibility
            padding_size=(
                lora_config.lora_vocab_padding_size
                if lora_config
                else DEFAULT_VOCAB_PADDING_SIZE
            ),
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        if config.tie_word_embeddings:
            # Output projection reuses the input embedding matrix.
            self.lm_head.weight = self.model.embed_tokens.weight

        self.logits_processor = LogitsProcessor(
            self.unpadded_vocab_size,
            config.vocab_size,
            scale=1 / self.config.logits_scaling,
        )

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate token-embedding lookup to the inner model."""
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the decoder stack.

        Returns:
            Hidden states on the last pipeline rank, or ``IntermediateTensors``
            on earlier ranks (as produced by the inner model).
        """
        hidden_states = self.model(
            input_ids, positions, intermediate_tensors, inputs_embeds
        )
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor) -> Optional[torch.Tensor]:
        """Project hidden states to vocabulary logits via the LM head."""
        logits = self.logits_processor(self.lm_head, hidden_states)
        return logits

    def make_empty_intermediate_tensors(
        self, batch_size: int, dtype: torch.dtype, device: torch.device
    ) -> IntermediateTensors:
        """Allocate a zeroed ``(batch_size, hidden_size)`` PP hand-off buffer."""
        return IntermediateTensors(
            {
                "hidden_states": torch.zeros(
                    (batch_size, self.config.hidden_size), dtype=dtype, device=device
                ),
            }
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights into this module tree.

        With tied embeddings, ``lm_head.*`` tensors are skipped because the
        head shares ``embed_tokens.weight``.
        """
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None),
        )
        return loader.load_weights(weights)