granitemoeshared.py 11.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
7
"""Inference-only GraniteMoeShared model.

The architecture is the same as granitemoe but with the addition of shared
experts.
"""
8

9
from collections.abc import Iterable
10
from itertools import islice
11
12
13
14
15
16
17
18
19
20

import torch
from torch import nn
from transformers.models.granitemoeshared import GraniteMoeSharedConfig

from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
21
22
23
24
from vllm.model_executor.layers.linear import (
    MergedColumnParallelLinear,
    RowParallelLinear,
)
25
from vllm.model_executor.layers.logits_processor import LogitsProcessor
26
from vllm.model_executor.layers.quantization import QuantizationConfig
27
from vllm.model_executor.layers.vocab_parallel_embedding import (
28
29
30
    ParallelLMHead,
    VocabParallelEmbedding,
)
31
32
from vllm.sequence import IntermediateTensors

33
from .granitemoe import GraniteMoeAttention, GraniteMoeModel, GraniteMoeMoE
34
from .interfaces import SupportsLoRA, SupportsPP
35
from .utils import AutoWeightsLoader, make_layers, maybe_prefix
36
37
38
39
40
41


class GraniteMoeSharedMLP(nn.Module):
    """Shared-expert MLP used alongside the routed experts.

    The block is a merged input projection (two outputs of
    ``shared_intermediate_size`` each), a ``SiluAndMul`` activation, and an
    output projection back to ``hidden_size``.
    """

    def __init__(
        self,
        config: GraniteMoeSharedConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        """Build the projections for the shared MLP.

        Args:
            config: Model config; ``hidden_size``, ``shared_intermediate_size``
                and ``hidden_act`` are read from it.
            quant_config: Optional quantization config forwarded to both
                linear layers.
            prefix: Parameter-name prefix of this module; sub-layers are
                registered as ``{prefix}.input_linear`` and
                ``{prefix}.output_linear``.

        Raises:
            ValueError: If ``config.hidden_act`` is not ``"silu"``.
        """
        super().__init__()

        # Validate first so an unsupported config fails before any layer
        # (and its weights) is constructed.
        if config.hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {config.hidden_act}. "
                "Only silu is supported for now."
            )

        self.input_size = config.hidden_size
        self.hidden_size = config.shared_intermediate_size
        # A single merged projection producing two outputs of equal width,
        # which are then consumed together by SiluAndMul.
        self.input_linear = MergedColumnParallelLinear(
            input_size=self.input_size,
            output_sizes=[self.hidden_size] * 2,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.input_linear",
        )
        self.output_linear = RowParallelLinear(
            self.hidden_size,
            self.input_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.output_linear",
        )
        self.act_fn = SiluAndMul()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply input projection, activation and output projection in order.

        Both linear layers return a tuple; only the first element (the
        projected tensor) is used.
        """
        hidden_states, _ = self.input_linear(hidden_states)
        hidden_states = self.act_fn(hidden_states)
        hidden_states, _ = self.output_linear(hidden_states)
        return hidden_states


class GraniteMoeSharedDecoderLayer(nn.Module):
    """One decoder layer: self-attention followed by MoE (+ optional shared MLP).

    Both sub-blocks are pre-normalized with RMSNorm and their outputs are
    scaled by ``residual_multiplier`` before being added to the residual.
    """

    def __init__(
        self,
        config: GraniteMoeSharedConfig,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        """Build attention, routed experts, optional shared MLP and norms.

        Args:
            config: Model config supplying sizes, head counts, RoPE
                parameters and the attention/residual multipliers.
            cache_config: Optional cache config forwarded to attention.
            quant_config: Optional quantization config forwarded to the
                attention, MoE and shared-MLP sub-modules.
            prefix: Parameter-name prefix of this layer.
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = GraniteMoeAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            max_position=config.max_position_embeddings,
            num_kv_heads=config.num_key_value_heads,
            rope_parameters=config.rope_parameters,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
            attention_multiplier=config.attention_multiplier,
        )
        self.block_sparse_moe = GraniteMoeMoE(
            num_experts=config.num_local_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            quant_config=quant_config,
            prefix=f"{prefix}.block_sparse_moe",
        )
        # The shared MLP is skipped when the config has no
        # ``shared_intermediate_size`` attribute or it is 0.
        self.shared_mlp = (
            None
            if getattr(config, "shared_intermediate_size", 0) == 0
            else GraniteMoeSharedMLP(
                config, quant_config=quant_config, prefix=f"{prefix}.shared_mlp"
            )
        )

        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

        self.residual_multiplier = config.residual_multiplier

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run attention and the MoE/shared-MLP block with scaled residuals.

        Args:
            positions: Position tensor passed through to self-attention.
            hidden_states: Layer input.

        Returns:
            The layer output after both residual additions.
        """
        # Self Attention
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
        )
        hidden_states = residual + hidden_states * self.residual_multiplier

        # Routed experts, plus the shared MLP when configured.
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        if self.shared_mlp is None:
            hidden_states = self.block_sparse_moe(hidden_states)
        else:
            # create a copy since block_sparse_moe modifies in-place;
            # the shared MLP must see the un-modified normalized input.
            moe_hidden_states = hidden_states.clone()
            moe_hidden_states = self.block_sparse_moe(moe_hidden_states)
            hidden_states = moe_hidden_states + self.shared_mlp(hidden_states)
            del moe_hidden_states
        hidden_states = residual + hidden_states * self.residual_multiplier

        return hidden_states


@support_torch_compile
class GraniteMoeSharedModel(nn.Module):
    """Decoder stack: token embedding, decoder layers and final RMSNorm."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the embedding, the layer stack and the final norm.

        Args:
            vllm_config: Engine config; the HF model config, cache config
                and quantization config are read from it.
            prefix: Parameter-name prefix of this module; layers are
                created under ``{prefix}.layers``.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.quant_config = quant_config  # Required by MixtralModel

        self.vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
        )
        self.embedding_multiplier = config.embedding_multiplier

        # make_layers returns the layer index range used by forward()
        # together with the layer container.
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: GraniteMoeSharedDecoderLayer(
                config, cache_config, quant_config=quant_config, prefix=prefix
            ),
            prefix=f"{prefix}.layers",
        )

        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings (without the embedding multiplier)."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the layers in ``[start_layer, end_layer)``.

        On the first pipeline rank the input is ``inputs_embeds`` when given,
        otherwise the embedding of ``input_ids``; either way it is scaled by
        ``embedding_multiplier``. On other ranks the input is taken from
        ``intermediate_tensors["hidden_states"]``.

        Returns:
            ``IntermediateTensors`` holding ``hidden_states`` when this is
            not the last pipeline rank, else the final-normed hidden states.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.embed_input_ids(input_ids)
            # NOTE: in-place scale; when inputs_embeds was supplied this
            # also mutates the caller's tensor.
            hidden_states *= self.embedding_multiplier
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
        for layer in islice(self.layers, self.start_layer, self.end_layer):
            hidden_states = layer(positions, hidden_states)
        if not get_pp_group().is_last_rank:
            return IntermediateTensors(
                {
                    "hidden_states": hidden_states,
                }
            )
        hidden_states = self.norm(hidden_states)
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Rename/split checkpoint MoE weights, then delegate the loading.

        Renaming rules applied to each ``(name, tensor)`` pair:

        * ``*.block_sparse_moe.input_linear.weight``: for every index ``e``
          along dim 0, the slice is chunked in two along its dim 0 and stored
          as ``experts.{e}.w1.weight`` and ``experts.{e}.w3.weight``.
        * ``*.block_sparse_moe.output_linear.weight``: each slice along dim 0
          is stored as ``experts.{e}.w2.weight``.
        * ``*.block_sparse_moe.router.layer.weight``: renamed to
          ``*.block_sparse_moe.gate.weight``.
        * Anything else is passed through unchanged.

        The remapped pairs are handed to ``GraniteMoeModel._load_weights``
        and its result is returned.
        """
        new_weights = {}
        for n, p in weights:
            if n.endswith(".block_sparse_moe.input_linear.weight"):
                for e in range(p.size(0)):
                    w1_name = n.replace(
                        ".block_sparse_moe.input_linear.weight",
                        f".block_sparse_moe.experts.{e}.w1.weight",
                    )
                    w3_name = n.replace(
                        ".block_sparse_moe.input_linear.weight",
                        f".block_sparse_moe.experts.{e}.w3.weight",
                    )
                    # First half of the expert slice goes to w1, second to w3.
                    w1_param, w3_param = p[e].chunk(2, dim=0)
                    assert w1_name not in new_weights
                    assert w3_name not in new_weights
                    new_weights[w1_name] = w1_param
                    new_weights[w3_name] = w3_param
            elif n.endswith(".block_sparse_moe.output_linear.weight"):
                for e in range(p.size(0)):
                    w2_name = n.replace(
                        ".block_sparse_moe.output_linear.weight",
                        f".block_sparse_moe.experts.{e}.w2.weight",
                    )
                    w2_param = p[e]
                    assert w2_name not in new_weights
                    new_weights[w2_name] = w2_param
            elif n.endswith(".block_sparse_moe.router.layer.weight"):
                gate_name = n.replace(
                    ".block_sparse_moe.router.layer.weight",
                    ".block_sparse_moe.gate.weight",
                )
                assert gate_name not in new_weights
                new_weights[gate_name] = p
            else:
                new_weights[n] = p
        return GraniteMoeModel._load_weights(self, new_weights.items())
247

248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272

class GraniteMoeSharedForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    """GraniteMoeShared decoder stack with an LM head and logits processor."""

    fall_back_to_pt_during_load = False

    # Checkpoint q/k/v projections are packed into a single qkv_proj module.
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
    }

    # LoRA specific attributes
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the backbone model, LM head and logits processor.

        Args:
            vllm_config: Engine config; the HF model config and quantization
                config are read from it.
            prefix: Parameter-name prefix; sub-modules are registered under
                ``model`` and ``lm_head`` relative to it.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config

        self.config = config

        self.model = GraniteMoeSharedModel(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )

        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        # With tied embeddings the LM head shares the input embedding weight.
        if config.tie_word_embeddings:
            self.lm_head.weight = self.model.embed_tokens.weight

        # Logits are scaled by the reciprocal of config.logits_scaling.
        self.logits_processor = LogitsProcessor(
            config.vocab_size,
            config.vocab_size,
            scale=1 / self.config.logits_scaling,
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate token embedding lookup to the backbone model."""
        return self.model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the backbone and return its output unchanged.

        Logits are not computed here; see ``compute_logits``.
        """
        hidden_states = self.model(
            input_ids, positions, intermediate_tensors, inputs_embeds
        )
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor | None:
        """Project hidden states to logits via the logits processor."""
        logits = self.logits_processor(self.lm_head, hidden_states)
        return logits

    def make_empty_intermediate_tensors(
        self, batch_size: int, dtype: torch.dtype, device: torch.device
    ) -> IntermediateTensors:
        """Return zero-filled ``hidden_states`` of shape
        ``(batch_size, config.hidden_size)`` wrapped in IntermediateTensors.
        """
        return IntermediateTensors(
            {
                "hidden_states": torch.zeros(
                    (batch_size, self.config.hidden_size), dtype=dtype, device=device
                ),
            }
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load weights through AutoWeightsLoader.

        Weights under ``lm_head.`` are skipped when the config ties word
        embeddings, since the head then reuses the embedding weight.
        """
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None),
        )
        return loader.load_weights(weights)