granitemoeshared.py 12 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
7
"""Inference-only GraniteMoeShared model.

The architecture is the same as granitemoe but with the addition of shared
experts.
"""
8

9
from collections.abc import Iterable
10
from itertools import islice
11
12
13
14
15
16
17
18
19
20

import torch
from torch import nn
from transformers.models.granitemoeshared import GraniteMoeSharedConfig

from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
21
22
23
24
from vllm.model_executor.layers.linear import (
    MergedColumnParallelLinear,
    RowParallelLinear,
)
25
from vllm.model_executor.layers.logits_processor import LogitsProcessor
26
from vllm.model_executor.layers.quantization import QuantizationConfig
27
from vllm.model_executor.layers.vocab_parallel_embedding import (
28
29
30
    ParallelLMHead,
    VocabParallelEmbedding,
)
31
32
from vllm.sequence import IntermediateTensors

33
from .granitemoe import GraniteMoeAttention, GraniteMoeModel, GraniteMoeMoE
34
from .interfaces import SupportsLoRA, SupportsPP
35
from .utils import AutoWeightsLoader, make_layers, maybe_prefix
36
37
38
39
40
41


class GraniteMoeSharedMLP(nn.Module):
    """Shared-expert feed-forward block.

    Applies a merged column-parallel input projection (two outputs of
    ``config.shared_intermediate_size`` each), ``SiluAndMul``, and a
    row-parallel output projection back to ``config.hidden_size``.

    Note the attribute naming: ``self.input_size`` holds the model hidden
    size, while ``self.hidden_size`` holds the shared intermediate size.

    Args:
        config: Model config; ``hidden_size``, ``shared_intermediate_size``
            and ``hidden_act`` are read.
        quant_config: Optional quantization config forwarded to both linear
            layers.
        prefix: Module-name prefix; ``.input_linear`` / ``.output_linear``
            are appended for the two sub-layers.

    Raises:
        ValueError: If ``config.hidden_act`` is anything other than ``"silu"``.
    """

    def __init__(
        self,
        config: GraniteMoeSharedConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        # Validate up front so an unsupported config fails before any
        # projection layer is constructed.
        if config.hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {config.hidden_act}. "
                "Only silu is supported for now."
            )

        self.input_size = config.hidden_size
        self.hidden_size = config.shared_intermediate_size
        self.input_linear = MergedColumnParallelLinear(
            input_size=self.input_size,
            output_sizes=[self.hidden_size] * 2,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.input_linear",
        )
        self.output_linear = RowParallelLinear(
            self.hidden_size,
            self.input_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.output_linear",
        )
        self.act_fn = SiluAndMul()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Run input projection -> activation -> output projection.

        Both linear layers return a 2-tuple; only the first element is used.
        """
        hidden_states, _ = self.input_linear(hidden_states)
        hidden_states = self.act_fn(hidden_states)
        hidden_states, _ = self.output_linear(hidden_states)
        return hidden_states


class GraniteMoeSharedDecoderLayer(nn.Module):
    """One decoder layer: attention, then MoE with an optional shared MLP.

    Both sub-blocks are applied pre-norm (RMSNorm) and their outputs are
    scaled by ``config.residual_multiplier`` before being added to the
    residual stream.

    Args:
        config: Model config supplying sizes, head counts, RoPE settings,
            expert counts, ``rms_norm_eps``, ``attention_multiplier`` and
            ``residual_multiplier``.
        cache_config: Optional cache config forwarded to the attention module.
        quant_config: Optional quantization config forwarded to attention,
            the MoE block and the shared MLP.
        prefix: Module-name prefix; ``.self_attn``, ``.block_sparse_moe`` and
            ``.shared_mlp`` are appended for the sub-modules.
    """

    def __init__(
        self,
        config: GraniteMoeSharedConfig,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        # Requires transformers > 4.32.0
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        self.self_attn = GraniteMoeAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            max_position=config.max_position_embeddings,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
            attention_multiplier=config.attention_multiplier,
        )
        self.block_sparse_moe = GraniteMoeMoE(
            num_experts=config.num_local_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            quant_config=quant_config,
            prefix=f"{prefix}.block_sparse_moe",
        )
        # The shared MLP is skipped entirely when the config has no
        # (or a zero) ``shared_intermediate_size``.
        self.shared_mlp = (
            None
            if getattr(config, "shared_intermediate_size", 0) == 0
            else GraniteMoeSharedMLP(
                config, quant_config=quant_config, prefix=f"{prefix}.shared_mlp"
            )
        )

        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

        self.residual_multiplier = config.residual_multiplier

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Apply the attention and MoE sub-blocks with scaled residuals.

        Args:
            positions: Position tensor passed through to the attention module.
            hidden_states: Layer input.

        Returns:
            The layer output after both residual additions.
        """
        # Self Attention
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
        )
        hidden_states = residual + hidden_states * self.residual_multiplier

        # MoE (+ optional shared MLP)
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        if self.shared_mlp is None:
            hidden_states = self.block_sparse_moe(hidden_states)
        else:
            # create a copy since block_sparse_moe modifies in-place
            moe_hidden_states = hidden_states.clone()
            moe_hidden_states = self.block_sparse_moe(moe_hidden_states)
            hidden_states = moe_hidden_states + self.shared_mlp(hidden_states)
            # Drop the extra reference before the residual add below.
            del moe_hidden_states
        hidden_states = residual + hidden_states * self.residual_multiplier

        return hidden_states


@support_torch_compile
class GraniteMoeSharedModel(nn.Module):
    """Embedding, decoder-layer stack and final RMSNorm for GraniteMoeShared.

    Layers are created through ``make_layers``, which also yields the
    ``start_layer`` / ``end_layer`` bounds iterated in ``forward``.

    Args:
        vllm_config: Source of the HF model config, cache config and
            quantization config.
        prefix: Module-name prefix; decoder layers live under
            ``{prefix}.layers``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.quant_config = quant_config  # Required by MixtralModel
        self.padding_idx = config.pad_token_id

        self.vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
        )
        self.embedding_multiplier = config.embedding_multiplier

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: GraniteMoeSharedDecoderLayer(
                config, cache_config, quant_config=quant_config, prefix=prefix
            ),
            prefix=f"{prefix}.layers",
        )

        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings (without the embedding multiplier)."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run this rank's slice of the decoder stack.

        Args:
            input_ids: Token ids; embedded on the first rank when
                ``inputs_embeds`` is not given.
            positions: Position tensor forwarded to every layer.
            intermediate_tensors: Must be provided on non-first ranks; its
                ``"hidden_states"`` entry is the input there.
            inputs_embeds: Optional precomputed embeddings used instead of
                ``input_ids`` on the first rank.

        Returns:
            Normalized hidden states on the last rank; otherwise an
            ``IntermediateTensors`` carrying ``"hidden_states"``.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.embed_input_ids(input_ids)
            # Out-of-place on purpose: an in-place multiply would write the
            # scaled values back into a caller-supplied ``inputs_embeds``.
            hidden_states = hidden_states * self.embedding_multiplier
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]

        for layer in islice(self.layers, self.start_layer, self.end_layer):
            hidden_states = layer(positions, hidden_states)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors(
                {
                    "hidden_states": hidden_states,
                }
            )

        hidden_states = self.norm(hidden_states)
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Rename/split checkpoint MoE tensors, then delegate the actual load.

        Remapping performed on names ending with:

        * ``.block_sparse_moe.input_linear.weight``: for each index ``e``
          along dim 0, the slice is chunked in two along its dim 0 and stored
          as ``.block_sparse_moe.experts.{e}.w1.weight`` and
          ``.block_sparse_moe.experts.{e}.w3.weight``.
        * ``.block_sparse_moe.output_linear.weight``: slice ``e`` along dim 0
          becomes ``.block_sparse_moe.experts.{e}.w2.weight``.
        * ``.block_sparse_moe.router.layer.weight``: renamed to
          ``.block_sparse_moe.gate.weight``.

        Every other entry is passed through unchanged. The remapped items are
        handed to ``GraniteMoeModel._load_weights`` with this module as
        ``self``, and its result is returned.
        """
        input_suffix = ".block_sparse_moe.input_linear.weight"
        output_suffix = ".block_sparse_moe.output_linear.weight"
        router_suffix = ".block_sparse_moe.router.layer.weight"

        new_weights: dict[str, torch.Tensor] = {}
        for n, p in weights:
            if n.endswith(input_suffix):
                for e in range(p.size(0)):
                    w1_name = n.replace(
                        input_suffix,
                        f".block_sparse_moe.experts.{e}.w1.weight",
                    )
                    w3_name = n.replace(
                        input_suffix,
                        f".block_sparse_moe.experts.{e}.w3.weight",
                    )
                    w1_param, w3_param = p[e].chunk(2, dim=0)
                    assert w1_name not in new_weights
                    assert w3_name not in new_weights
                    new_weights[w1_name] = w1_param
                    new_weights[w3_name] = w3_param
            elif n.endswith(output_suffix):
                for e in range(p.size(0)):
                    w2_name = n.replace(
                        output_suffix,
                        f".block_sparse_moe.experts.{e}.w2.weight",
                    )
                    w2_param = p[e]
                    assert w2_name not in new_weights
                    new_weights[w2_name] = w2_param
            elif n.endswith(router_suffix):
                gate_name = n.replace(
                    router_suffix,
                    ".block_sparse_moe.gate.weight",
                )
                assert gate_name not in new_weights
                new_weights[gate_name] = p
            else:
                new_weights[n] = p
        return GraniteMoeModel._load_weights(self, new_weights.items())
252

253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278

class GraniteMoeSharedForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    """Causal-LM wrapper: ``GraniteMoeSharedModel`` plus LM head and logits processor.

    Args:
        vllm_config: Source of the HF model config and quantization config;
            also forwarded to the inner model.
        prefix: Module-name prefix; the inner model and head are registered
            under ``maybe_prefix(prefix, "model")`` and
            ``maybe_prefix(prefix, "lm_head")``.
    """

    fall_back_to_pt_during_load = False

    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
    }

    # LoRA specific attributes
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }
    embedding_padding_modules = ["lm_head"]

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config

        self.config = config

        self.model = GraniteMoeSharedModel(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )

        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        # With tied embeddings the head reuses the input embedding weight.
        if config.tie_word_embeddings:
            self.lm_head.weight = self.model.embed_tokens.weight

        # ``config.vocab_size`` is passed for both leading positional
        # parameters; the scale is the reciprocal of ``logits_scaling``.
        self.logits_processor = LogitsProcessor(
            config.vocab_size,
            config.vocab_size,
            scale=1 / self.config.logits_scaling,
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate token embedding lookup to the inner model."""
        return self.model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the inner model and return its output unchanged."""
        hidden_states = self.model(
            input_ids, positions, intermediate_tensors, inputs_embeds
        )
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor | None:
        """Produce logits by applying the logits processor with the LM head."""
        logits = self.logits_processor(self.lm_head, hidden_states)
        return logits

    def make_empty_intermediate_tensors(
        self, batch_size: int, dtype: torch.dtype, device: torch.device
    ) -> IntermediateTensors:
        """Return zero-filled ``"hidden_states"`` of shape ``(batch_size, hidden_size)``."""
        return IntermediateTensors(
            {
                "hidden_states": torch.zeros(
                    (batch_size, self.config.hidden_size), dtype=dtype, device=device
                ),
            }
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load weights through ``AutoWeightsLoader``.

        Entries under ``lm_head.`` are skipped when
        ``config.tie_word_embeddings`` is set, since the head then shares the
        embedding weight.
        """
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None),
        )
        return loader.load_weights(weights)