granitemoe.py 20.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GraniteMoe model."""
26

27
from collections.abc import Iterable
28
from itertools import islice
29
from typing import Any
30
31
32
33

import torch
from torch import nn

34
from vllm.attention.layer import Attention
35
from vllm.compilation.decorators import support_torch_compile
36
from vllm.config import CacheConfig, VllmConfig
37
38
39
40
41
from vllm.distributed import (
    get_pp_group,
    get_tensor_model_parallel_world_size,
    tensor_model_parallel_all_gather,
)
42
43
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
44
45
46
47
48
from vllm.model_executor.layers.linear import (
    QKVParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
49
from vllm.model_executor.layers.logits_processor import LogitsProcessor
50
from vllm.model_executor.layers.quantization import QuantizationConfig
51
52
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
53
54
55
    ParallelLMHead,
    VocabParallelEmbedding,
)
56
from vllm.model_executor.model_loader.weight_utils import (
57
58
59
    default_weight_loader,
    maybe_remap_kv_scale_name,
)
60
from vllm.model_executor.models.utils import sequence_parallel_chunk
61
62
from vllm.sequence import IntermediateTensors

63
from .interfaces import SupportsLoRA, SupportsPP
64
from .utils import AutoWeightsLoader, is_pp_missing_parameter, make_layers, maybe_prefix
65
66
67
68
69
70
71
72
73
74


class GraniteMoeMoE(nn.Module):
    """Tensor-parallel mixture-of-experts block for GraniteMoe.

    Every expert's weights are sharded across all tensor-parallel ranks. A
    replicated linear router produces per-token expert logits, a fused MoE
    kernel (``FusedMoE``) computes the expert outputs, and the results are
    reduced across ranks (``reduce_results=True``).

    When ``is_sequence_parallel`` is set, the flattened tokens are split
    across ranks with ``sequence_parallel_chunk`` before routing and are
    all-gathered back afterwards.
    """

    def __init__(
        self,
        num_experts: int,
        top_k: int,
        hidden_size: int,
        intermediate_size: int,
        params_dtype: torch.dtype | None = None,
        quant_config: QuantizationConfig | None = None,
        tp_size: int | None = None,
        is_sequence_parallel: bool = False,
        prefix: str = "",
    ):
        """
        Args:
            num_experts: Total number of experts.
            top_k: Number of experts selected per token.
            hidden_size: Model hidden dimension.
            intermediate_size: Per-expert FFN intermediate dimension.
            params_dtype: Forwarded unchanged to both the router and the
                experts.
            quant_config: Quantization config for the experts only; the
                router is always built with ``quant_config=None``.
            tp_size: Forwarded unchanged to ``FusedMoE``.
            is_sequence_parallel: Chunk tokens across ranks around the MoE
                computation (see ``forward``).
            prefix: Parameter-name prefix for the submodules.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.is_sequence_parallel = is_sequence_parallel

        # Gate always runs at half / full precision for now.
        self.gate = ReplicatedLinear(
            hidden_size,
            num_experts,
            bias=False,
            params_dtype=params_dtype,
            quant_config=None,
            prefix=f"{prefix}.gate",
        )

        self.experts = FusedMoE(
            num_experts=num_experts,
            top_k=top_k,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            params_dtype=params_dtype,
            reduce_results=True,
            renormalize=True,
            quant_config=quant_config,
            tp_size=tp_size,
            prefix=f"{prefix}.experts",
            is_sequence_parallel=self.is_sequence_parallel,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route tokens through the experts.

        Args:
            hidden_states: Tensor that is viewed as ``(-1, hidden_size)``,
                e.g. of shape ``(hidden_size,)`` or
                ``(num_tokens, hidden_size)``.

        Returns:
            A tensor with the same shape as ``hidden_states``.
        """
        orig_shape = hidden_states.shape
        hidden_states = hidden_states.view(-1, self.hidden_size)
        # Token count of the *flattened* input. ``orig_shape[0]`` is not a
        # substitute: for a 1D input it equals hidden_size rather than 1.
        num_tokens = hidden_states.shape[0]

        if self.is_sequence_parallel:
            hidden_states = sequence_parallel_chunk(hidden_states)

        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.gate(hidden_states)
        final_hidden_states = self.experts(hidden_states, router_logits)

        if self.is_sequence_parallel:
            final_hidden_states = tensor_model_parallel_all_gather(
                final_hidden_states, 0
            )
            # The gathered tensor may carry extra rows (presumably padding
            # introduced by sequence_parallel_chunk); keep the real tokens.
            final_hidden_states = final_hidden_states[:num_tokens]

        return final_hidden_states.view(orig_shape)


class GraniteMoeAttention(nn.Module):
    """Grouped-query self-attention used by GraniteMoe decoder layers.

    A fused ``QKVParallelLinear`` produces Q, K and V; rotary embeddings are
    applied to Q and K; the attention result is projected back to
    ``hidden_size`` by a ``RowParallelLinear``.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        max_position: int = 4096 * 32,
        rope_parameters: dict[str, Any] | None = None,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        attention_multiplier: float | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        world_size = get_tensor_model_parallel_world_size()

        # Query heads are split evenly over the tensor-parallel ranks.
        self.total_num_heads = num_heads
        assert num_heads % world_size == 0
        self.num_heads = num_heads // world_size

        # KV heads are partitioned when there are at least as many of them as
        # ranks and replicated across ranks otherwise; both need an even split.
        self.total_num_kv_heads = num_kv_heads
        if num_kv_heads < world_size:
            assert world_size % num_kv_heads == 0
        else:
            assert num_kv_heads % world_size == 0
        self.num_kv_heads = max(1, num_kv_heads // world_size)

        head_dim = hidden_size // num_heads
        self.head_dim = head_dim
        self.q_size = self.num_heads * head_dim
        self.kv_size = self.num_kv_heads * head_dim

        # NOTE(review): the fallback is 1 / head_dim, not 1 / sqrt(head_dim);
        # it is only used when attention_multiplier is None.
        if attention_multiplier is None:
            self.scaling = head_dim**-1
        else:
            self.scaling = attention_multiplier

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            head_dim,
            num_heads,
            num_kv_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.o_proj = RowParallelLinear(
            num_heads * head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )
        self.rotary_emb = get_rope(
            head_dim,
            max_position=max_position,
            rope_parameters=rope_parameters,
            is_neox_style=True,
        )
        self.attn = Attention(
            self.num_heads,
            head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Apply attention to ``hidden_states`` at the given ``positions``."""
        fused_qkv, _ = self.qkv_proj(hidden_states)
        query, key, value = fused_qkv.split(
            [self.q_size, self.kv_size, self.kv_size], dim=-1
        )
        query, key = self.rotary_emb(positions, query, key)
        projected, _ = self.o_proj(self.attn(query, key, value))
        return projected


class GraniteMoeDecoderLayer(nn.Module):
    """A single GraniteMoe transformer layer.

    Pre-norm attention followed by a pre-norm sparse MoE block; each
    sub-block's output is scaled by ``residual_multiplier`` before being
    added back onto the residual stream.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ) -> None:
        super().__init__()

        hf_config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config

        self.hidden_size = hf_config.hidden_size
        self.self_attn = GraniteMoeAttention(
            hidden_size=self.hidden_size,
            num_heads=hf_config.num_attention_heads,
            max_position=hf_config.max_position_embeddings,
            num_kv_heads=hf_config.num_key_value_heads,
            rope_parameters=hf_config.rope_parameters,
            cache_config=vllm_config.cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
            attention_multiplier=hf_config.attention_multiplier,
        )
        self.block_sparse_moe = GraniteMoeMoE(
            num_experts=hf_config.num_local_experts,
            top_k=hf_config.num_experts_per_tok,
            hidden_size=hf_config.hidden_size,
            intermediate_size=hf_config.intermediate_size,
            quant_config=quant_config,
            is_sequence_parallel=(
                vllm_config.parallel_config.use_sequence_parallel_moe
            ),
            prefix=f"{prefix}.block_sparse_moe",
        )

        norm_eps = hf_config.rms_norm_eps
        self.input_layernorm = RMSNorm(hf_config.hidden_size, eps=norm_eps)
        self.post_attention_layernorm = RMSNorm(
            hf_config.hidden_size, eps=norm_eps
        )

        self.residual_multiplier = hf_config.residual_multiplier

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run attention then MoE, each with a scaled residual connection."""
        scale = self.residual_multiplier

        # Self Attention
        attn_out = self.self_attn(
            positions=positions,
            hidden_states=self.input_layernorm(hidden_states),
        )
        hidden_states = hidden_states + attn_out * scale

        # Sparse MoE
        moe_out = self.block_sparse_moe(
            self.post_attention_layernorm(hidden_states)
        )
        return hidden_states + moe_out * scale


283
@support_torch_compile
class GraniteMoeModel(nn.Module):
    """GraniteMoe decoder stack.

    Holds the token embedding, the decoder layers owned by this
    pipeline-parallel rank (``start_layer``..``end_layer``), and the final
    RMSNorm. Also converts GraniteMoe checkpoint tensors into the layout the
    fused MoE / QKV layers expect when loading weights.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.quant_config = quant_config  # Required by MixtralModel

        self.vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
        )
        self.embedding_multiplier = config.embedding_multiplier

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: GraniteMoeDecoderLayer(vllm_config, prefix=prefix),
            prefix=f"{prefix}.layers",
        )

        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run this rank's slice of the decoder stack.

        Args:
            input_ids: Token ids; used on the first pipeline rank when
                ``inputs_embeds`` is not given.
            positions: Token positions passed to every layer.
            intermediate_tensors: Output of the previous pipeline rank;
                required on every rank except the first.
            inputs_embeds: Optional precomputed embeddings used instead of
                ``input_ids`` on the first rank. Not modified.

        Returns:
            The final normalized hidden states on the last pipeline rank,
            otherwise ``IntermediateTensors`` with a ``"hidden_states"`` entry.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.embed_input_ids(input_ids)
            # Out-of-place on purpose: an in-place ``*=`` would rescale the
            # caller's ``inputs_embeds`` tensor as a side effect.
            hidden_states = hidden_states * self.embedding_multiplier
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]

        for layer in islice(self.layers, self.start_layer, self.end_layer):
            hidden_states = layer(positions, hidden_states)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors(
                {
                    "hidden_states": hidden_states,
                }
            )

        hidden_states = self.norm(hidden_states)
        return hidden_states

    def _load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load already-remapped weights into this module's parameters.

        This function is copied from `MixtralModel.load_weights`, mainly to
        decouple from mixtral, avoiding impact on support like BNB
        quantization.

        Each entry is routed to exactly one path: KV-cache quantization
        scales, stacked q/k/v shards fused into ``qkv_proj``, per-expert
        w1/w2/w3 tensors fused into ``FusedMoE`` parameters, or the default
        loader. Once an entry matches a path it is loaded or skipped there.
        It never falls through to a later path with a half-rewritten (or
        ``None``) name.

        Returns:
            The names of the parameters that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]

        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        expert_params_mapping = FusedMoE.make_expert_params_mapping(
            self,
            ckpt_gate_proj_name="w1",
            ckpt_down_proj_name="w2",
            ckpt_up_proj_name="w3",
            num_experts=self.config.num_local_experts,
        )

        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        def is_unused_bias(param_name: str) -> bool:
            # Skip loading extra bias for GPTQ models.
            return (
                param_name.endswith((".bias", "_bias"))
                and param_name not in params_dict
            )

        for name, loaded_weight in weights:
            if self.quant_config is not None and (
                scale_name := self.quant_config.get_cache_scale(name)
            ):
                # Loading kv cache quantization scales
                param = params_dict[scale_name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                loaded_weight = (
                    loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
                )
                weight_loader(param, loaded_weight)
                loaded_params.add(scale_name)
                continue

            # First stacked mapping whose checkpoint shard name occurs in name.
            stacked = next(
                (entry for entry in stacked_params_mapping if entry[1] in name),
                None,
            )
            if stacked is not None:
                param_name, weight_name, shard_id = stacked
                name = name.replace(weight_name, param_name)
                if is_unused_bias(name):
                    continue
                # Skip layers on other devices.
                if is_pp_missing_parameter(name, self):
                    continue
                if name.endswith("scale"):
                    # Remapping the name of FP8 kv-scale.
                    name = maybe_remap_kv_scale_name(name, params_dict)
                    if name is None:
                        continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                loaded_params.add(name)
                continue

            # First expert mapping whose checkpoint weight name occurs in name.
            expert = next(
                (entry for entry in expert_params_mapping if entry[1] in name),
                None,
            )
            if expert is not None:
                param_name, weight_name, expert_id, shard_id = expert
                name = name.replace(weight_name, param_name)
                # Skip layers on other devices.
                if is_pp_missing_parameter(name, self):
                    continue
                if is_unused_bias(name):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(
                    param,
                    loaded_weight,
                    name,
                    shard_id=shard_id,
                    expert_id=expert_id,
                )
                loaded_params.add(name)
                continue

            # Default path for every remaining parameter.
            if is_unused_bias(name):
                continue
            # Skip layers on other devices.
            if is_pp_missing_parameter(name, self):
                continue
            # Remapping the name of FP8 kv-scale.
            name = maybe_remap_kv_scale_name(name, params_dict)
            if name is None:
                continue

            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Convert GraniteMoe checkpoint names/layouts, then load them.

        - ``block_sparse_moe.input_linear.weight`` is indexed per expert along
          dim 0; each expert slice is split in half along its dim 0 into
          ``experts.{e}.w1.weight`` (first half) and ``experts.{e}.w3.weight``
          (second half).
        - ``block_sparse_moe.output_linear.weight`` is indexed per expert
          along dim 0 into ``experts.{e}.w2.weight``.
        - ``block_sparse_moe.router.layer.weight`` is renamed to
          ``block_sparse_moe.gate.weight``.
        - Everything else is passed through unchanged.

        Returns:
            The names of the parameters that were loaded.
        """
        new_weights = {}
        for n, p in weights:
            if n.endswith(".block_sparse_moe.input_linear.weight"):
                for e in range(p.size(0)):
                    w1_name = n.replace(
                        ".block_sparse_moe.input_linear.weight",
                        f".block_sparse_moe.experts.{e}.w1.weight",
                    )
                    w3_name = n.replace(
                        ".block_sparse_moe.input_linear.weight",
                        f".block_sparse_moe.experts.{e}.w3.weight",
                    )
                    w1_param, w3_param = p[e].chunk(2, dim=0)
                    assert w1_name not in new_weights
                    assert w3_name not in new_weights
                    new_weights[w1_name] = w1_param
                    new_weights[w3_name] = w3_param
            elif n.endswith(".block_sparse_moe.output_linear.weight"):
                for e in range(p.size(0)):
                    w2_name = n.replace(
                        ".block_sparse_moe.output_linear.weight",
                        f".block_sparse_moe.experts.{e}.w2.weight",
                    )
                    w2_param = p[e]
                    assert w2_name not in new_weights
                    new_weights[w2_name] = w2_param
            elif n.endswith(".block_sparse_moe.router.layer.weight"):
                gate_name = n.replace(
                    ".block_sparse_moe.router.layer.weight",
                    ".block_sparse_moe.gate.weight",
                )
                assert gate_name not in new_weights
                new_weights[gate_name] = p
            else:
                new_weights[n] = p
        return self._load_weights(new_weights.items())
482

483

484
class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    """Causal language model head on top of :class:`GraniteMoeModel`.

    Adds a ``ParallelLMHead`` (optionally tied to the input embedding) and a
    ``LogitsProcessor`` whose scale is ``1 / config.logits_scaling``.
    """

    fall_back_to_pt_during_load = False

    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
    }

    # LoRA specific attributes
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        self.config = hf_config

        self.model = GraniteMoeModel(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )
        self.lm_head = ParallelLMHead(
            hf_config.vocab_size,
            hf_config.hidden_size,
            quant_config=vllm_config.quant_config,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        if hf_config.tie_word_embeddings:
            # The LM head reuses the input embedding Parameter object.
            self.lm_head.weight = self.model.embed_tokens.weight

        self.logits_processor = LogitsProcessor(
            hf_config.vocab_size,
            scale=1 / hf_config.logits_scaling,
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate token-embedding lookup to the inner model."""
        return self.model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the decoder stack; returns whatever the inner model returns."""
        return self.model(input_ids, positions, intermediate_tensors, inputs_embeds)

    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor | None:
        """Project hidden states to vocabulary logits via the LM head."""
        return self.logits_processor(self.lm_head, hidden_states)

    def make_empty_intermediate_tensors(
        self, batch_size: int, dtype: torch.dtype, device: torch.device
    ) -> IntermediateTensors:
        """Allocate a zero-filled ``(batch_size, hidden_size)`` PP buffer."""
        placeholder = torch.zeros(
            (batch_size, self.config.hidden_size), dtype=dtype, device=device
        )
        return IntermediateTensors({"hidden_states": placeholder})

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load weights, skipping ``lm_head.*`` when embeddings are tied."""
        skipped = ["lm_head."] if self.config.tie_word_embeddings else None
        return AutoWeightsLoader(self, skip_prefixes=skipped).load_weights(weights)