# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Adapted from
# https://github.com/ROCm/vllm/blob/cea7419f151cc50293a05b7fac8547f8f887c9f6/vllm/model_executor/models/grok1.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Grok1 model."""

from collections.abc import Iterable
from itertools import islice

import torch
import torch.nn.functional as F
from torch import nn

from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
    QKVParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
    DEFAULT_VOCAB_PADDING_SIZE,
    ParallelLMHead,
    VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader,
    maybe_remap_kv_scale_name,
)
from vllm.sequence import IntermediateTensors

from .interfaces import SupportsLoRA, SupportsPP
from .utils import (
    AutoWeightsLoader,
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)

# Default Grok1-specific constants, overridden by config values if present.
# Numerically: 0.08838834... ~= 128**-0.5, 0.57735026... ~= 3**-0.5 and
# 78.3836717... ~= 6144**0.5 (presumably derived from the original Grok-1
# head_dim / hidden_size -- TODO confirm).
# NOTE(review): DEFAULT_ATTN_OUTPUT_MULTIPLIER is not referenced by the
# attention code in this file, which falls back to 1.0 when the config lacks
# `attn_output_multiplier`.
DEFAULT_ATTN_OUTPUT_MULTIPLIER = 0.08838834764831845
# Fallback for `config.output_multiplier_scale`, passed to LogitsProcessor.
DEFAULT_OUTPUT_MULTIPLIER_SCALE = 0.5773502691896257
# Fallback for `config.embedding_multiplier_scale`; token embeddings are
# multiplied by it.
DEFAULT_EMBEDDING_MULTIPLIER_SCALE = 78.38367176906169


class Grok1MoE(nn.Module):
    """A tensor-parallel MoE implementation for Grok1 that shards each expert
    across all ranks.

    Each expert's weights are sharded across all ranks and a fused MoE
    kernel is used for the forward pass, and finally we reduce the outputs
    across ranks.

    Router logits are soft-capped as ``cap * tanh(logits / cap)`` before they
    are handed to the fused experts. ``router_logit_softcap`` defaults to
    30.0 (the value this layer has always used); a non-positive value
    disables the cap.
    """

    def __init__(
        self,
        num_experts: int,
        top_k: int,
        hidden_size: int,
        intermediate_size: int,
        params_dtype: torch.dtype | None = None,
        quant_config: QuantizationConfig | None = None,
        tp_size: int | None = None,
        prefix: str = "",
        router_logit_softcap: float = 30.0,
    ):
        """
        Args:
            num_experts: Number of experts; also the router output width.
            top_k: ``top_k`` forwarded to ``FusedMoE`` (experts per token).
            hidden_size: Input/output width of the block.
            intermediate_size: Expert intermediate size.
            params_dtype: Parameter dtype for the gate and the experts.
            quant_config: Quantization config for the experts only; the gate
                is always built unquantized.
            tp_size: Tensor-parallel size forwarded to ``FusedMoE``.
            prefix: Module name prefix for the ``gate``/``experts`` children.
            router_logit_softcap: Soft-cap applied to router logits in
                ``forward``; ``<= 0`` disables capping.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.router_logit_softcap = router_logit_softcap

        # Gate always runs at half / full precision for now.
        self.gate = ReplicatedLinear(
            hidden_size,
            num_experts,
            bias=False,
            params_dtype=params_dtype,
            quant_config=None,
            prefix=f"{prefix}.gate",
        )

        self.experts = FusedMoE(
            num_experts=num_experts,
            top_k=top_k,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            params_dtype=params_dtype,
            reduce_results=True,
            renormalize=True,
            quant_config=quant_config,
            tp_size=tp_size,
            activation="gelu",
            prefix=f"{prefix}.experts",
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route tokens through the experts and return the combined output.

        The input is viewed as ``(-1, hidden_size)`` for routing and the
        result is viewed back to the input's original shape.
        """
        # NOTE: hidden_states can have either 1D or 2D shape.
        orig_shape = hidden_states.shape
        hidden_states = hidden_states.view(-1, self.hidden_size)
        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.gate(hidden_states)
        cap = self.router_logit_softcap
        if cap > 0:
            # Smoothly squash the logits into (-cap, cap).
            router_logits = cap * F.tanh(router_logits / cap)
        final_hidden_states = self.experts(hidden_states, router_logits)
        return final_hidden_states.view(orig_shape)


class Grok1Attention(nn.Module):
    """Grok1 self-attention layer.

    Fused QKV projection, neox-style rotary embeddings over the full head
    dimension, attention with soft-capped logits, and an output projection
    whose result is multiplied by ``attn_output_multiplier``.

    Query heads are divided evenly across tensor-parallel ranks. KV heads are
    partitioned when there are at least as many as ranks, and replicated
    otherwise.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        max_position: int = 4096 * 32,
        rope_theta: float = 10000,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        config=None,  # Added config parameter
    ) -> None:
        """
        Args:
            hidden_size: Model hidden size; ``head_dim`` is
                ``hidden_size // num_heads``.
            num_heads: Total number of query heads across all TP ranks.
            num_kv_heads: Total number of key/value heads across all ranks.
            max_position: Forwarded to ``get_rope`` as ``max_position``.
            rope_theta: Rotary base; truncated to ``int`` for ``get_rope``.
            cache_config: Forwarded to ``Attention``.
            quant_config: Forwarded to the projections and ``Attention``.
            prefix: Module name prefix.
            config: Optional HF config. Supplies ``attn_logit_softcapping``
                (default 30.0, clamped to >= 0) and ``attn_output_multiplier``
                (default 1.0).
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.config = config  # Store config reference
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position,
            base=int(self.rope_theta),
            is_neox_style=True,
        )

        # Negative soft-cap values from the config are clamped to 0.0.
        attn_logits_soft_cap = max(
            getattr(config, "attn_logit_softcapping", 30.0), 0.0
        )

        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            logits_soft_cap=attn_logits_soft_cap,
            prefix=f"{prefix}.attn",
        )
        # getattr(None, ..., 1.0) also yields 1.0, so no separate None check
        # is needed when no config was passed.
        self.attn_multiplier = getattr(config, "attn_output_multiplier", 1.0)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Apply rotary self-attention to ``hidden_states`` at ``positions``."""
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        # Grok1 scales the projected attention output by a config multiplier.
        output *= self.attn_multiplier
        return output


class Grok1DecoderLayer(nn.Module):
    """A single Grok1 decoder layer.

    Attention and the MoE block are each wrapped by an RMSNorm before and
    after (``pre_attn_norm``/``post_attn_norm`` and
    ``pre_moe_norm``/``post_moe_norm``).
    """

    def __init__(
        self,
        config,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        """
        Args:
            config: HF model config (hidden size, head counts, expert counts,
                norm eps, optional ``rope_theta``).
            cache_config: Forwarded to the attention layer.
            quant_config: Forwarded to the attention and MoE layers.
            prefix: Module name prefix.
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        # Check for fp8 quantization. The flag is not consumed inside this
        # layer; it is kept as a public attribute.
        self.use_fp8 = False
        if quant_config is not None:
            self.use_fp8 = getattr(quant_config, "is_fp8_w8a8", lambda: False)()
            if not self.use_fp8 and hasattr(quant_config, "is_fp8"):
                self.use_fp8 = quant_config.is_fp8

        # Use the config's rope_theta when it is set; otherwise fall back to
        # the 10000 default (previously used unconditionally).
        rope_theta = getattr(config, "rope_theta", None)
        if rope_theta is None:
            rope_theta = 10000
        self.attn = Grok1Attention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            max_position=config.max_position_embeddings,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=rope_theta,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
            # The config supplies the attention soft-cap and output multiplier.
            config=config,
        )

        # Grok1 uses "num_experts" in its config
        num_experts = getattr(config, "num_experts", 8)
        num_experts_per_tok = getattr(config, "num_experts_per_tok", 2)

        self.moe_block = Grok1MoE(
            num_experts=num_experts,
            top_k=num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            quant_config=quant_config,
            prefix=f"{prefix}.moe_block",
        )

        self.pre_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.pre_moe_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_moe_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run attention and then the MoE block for this layer.

        Returns ``(hidden_states, residual)``. When ``residual`` is ``None``
        the un-normalized input becomes the residual. Otherwise the input and
        ``residual`` go through the two-argument ``RMSNorm`` call, which is
        assumed to be vLLM's fused residual-add + norm returning the updated
        residual.
        """
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.pre_attn_norm(hidden_states)
        else:
            hidden_states, residual = self.pre_attn_norm(hidden_states, residual)

        hidden_states = self.attn(
            positions=positions,
            hidden_states=hidden_states,
        )

        # Post attention normalization
        hidden_states = self.post_attn_norm(hidden_states)

        # MoE block with normalization. The MoE output is returned separately
        # from `residual`; the next two-argument norm call combines them.
        hidden_states, residual = self.pre_moe_norm(hidden_states, residual)
        hidden_states = self.moe_block(hidden_states)
        hidden_states = self.post_moe_norm(hidden_states)

        return hidden_states, residual


@support_torch_compile
class Grok1Model(nn.Module):
    """Grok1 decoder stack: scaled token embeddings, this rank's slice of the
    decoder layers, and the final RMSNorm."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = config
        self.quant_config = quant_config
        self.padding_idx = config.pad_token_id
        # Extra embedding rows for LoRA: lora_extra_vocab_size scaled by
        # max_loras (or 1 when max_loras is unset).
        lora_vocab = (
            (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1))
            if lora_config
            else 0
        )
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size
        self.embedding_multiplier_scale = getattr(
            config, "embedding_multiplier_scale", DEFAULT_EMBEDDING_MULTIPLIER_SCALE
        )

        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
            quant_config=quant_config,
        )

        # forward() only runs layers[start_layer:end_layer].
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: Grok1DecoderLayer(
                config, cache_config, quant_config=quant_config, prefix=prefix
            ),
            prefix=f"{prefix}.layers",
        )

        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], config.hidden_size
        )

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Embed ``input_ids`` and multiply by ``embedding_multiplier_scale``."""
        hidden_states = self.embed_tokens(input_ids)
        hidden_states = hidden_states * self.embedding_multiplier_scale
        return hidden_states

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run this pipeline rank's slice of the decoder stack.

        The first rank starts from ``inputs_embeds`` if given (used as-is, no
        multiplier applied here) or from ``get_input_embeddings``; later ranks
        resume from ``intermediate_tensors``. Non-last ranks return the
        ``hidden_states``/``residual`` pair; the last rank applies the final
        norm and returns the hidden states.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        for layer in islice(self.layers, self.start_layer, self.end_layer):
            hidden_states, residual = layer(positions, hidden_states, residual)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors(
                {"hidden_states": hidden_states, "residual": residual}
            )

        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        """Map Grok1 checkpoint expert names onto ``FusedMoE`` parameters."""
        # Map Grok1's unique expert parameter names to standard names
        # Grok1 uses "num_experts" in its config
        num_experts = getattr(self.config, "num_experts", 8)
        return FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="linear",  # Grok1 specific
            ckpt_down_proj_name="linear_1",  # Grok1 specific
            ckpt_up_proj_name="linear_v",  # Grok1 specific
            num_experts=num_experts,
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Each weight is handled by the first matching rule:
        KV-cache quantization scales; q/k/v projections stacked into
        ``qkv_proj``; per-expert weights from ``get_expert_mapping``; and
        everything else (Grok1 ``*norm.scale`` names map to
        ``*norm.weight``). Unused GPTQ biases, weights for parameters on
        other pipeline ranks, and kv-scales with no matching parameter are
        skipped.

        Returns:
            The set of parameter names that received a weight.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]

        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        expert_params_mapping = self.get_expert_mapping()
        for name, loaded_weight in weights:
            if self.quant_config is not None and (
                scale_name := self.quant_config.get_cache_scale(name)
            ):
                # Loading kv cache quantization scales
                param = params_dict[scale_name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                # Scalar scales are used directly; otherwise the first
                # element is taken.
                loaded_weight = (
                    loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
                )
                weight_loader(param, loaded_weight)
                loaded_params.add(scale_name)
                continue

            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if (
                    name.endswith(".bias") or name.endswith("_bias")
                ) and name not in params_dict:
                    continue
                # Skip layers on other devices.
                if is_pp_missing_parameter(name, self):
                    continue
                if name.endswith("scale"):
                    # Remapping the name of FP8 kv-scale.
                    name = maybe_remap_kv_scale_name(name, params_dict)
                    if name is None:
                        # No target parameter: skip this weight. `break`
                        # (not `continue`) so neither the next mapping nor
                        # the expert fallback tests membership in `None`.
                        break
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    # Skip layers on other devices.
                    if is_pp_missing_parameter(name, self):
                        continue
                    if (
                        name.endswith(".bias") or name.endswith("_bias")
                    ) and name not in params_dict:
                        continue
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(
                        param,
                        loaded_weight,
                        name,
                        shard_id=shard_id,
                        expert_id=expert_id,
                    )
                    break
                else:
                    # Skip loading extra bias for GPTQ models.
                    if (
                        name.endswith(".bias") or name.endswith("_bias")
                    ) and name not in params_dict:
                        continue
                    # Skip layers on other devices.
                    if is_pp_missing_parameter(name, self):
                        continue

                    # Remapping the name of FP8 kv-scale.
                    name = maybe_remap_kv_scale_name(name, params_dict)
                    if name is None:
                        continue

                    # Handle Grok1-specific norm.scale naming; only the
                    # `norm.scale` part is rewritten.
                    if "norm.scale" in name:
                        name = name.replace("norm.scale", "norm.weight")

                    param = params_dict[name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, loaded_weight)
            # `name` is None only when a stacked kv-scale had no target.
            if name is not None:
                loaded_params.add(name)
        return loaded_params


class Grok1ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    """Grok1 causal language model: ``Grok1Model`` followed by an LM head.

    Logits come from ``LogitsProcessor``, constructed with
    ``output_multiplier_scale`` (config value, else
    ``DEFAULT_OUTPUT_MULTIPLIER_SCALE``) as its scale argument.
    """

    # NOTE(review): presumably opts this model out of the loader's
    # PyTorch-checkpoint fallback -- confirm against the model loader.
    fall_back_to_pt_during_load = False

    # q/k/v projections are packed into the single qkv_proj module.
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = config
        self.lora_config = lora_config
        self.quant_config = quant_config

        self.model = Grok1Model(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )

        self.unpadded_vocab_size = config.vocab_size
        if lora_config:
            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size

        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
            padding_size=DEFAULT_VOCAB_PADDING_SIZE,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "lm_head"),
        )

        if self.config.tie_word_embeddings:
            # Share the token-embedding tensor with the LM head.
            self.lm_head.weight = self.model.embed_tokens.weight

        self.output_multiplier_scale = getattr(
            config, "output_multiplier_scale", DEFAULT_OUTPUT_MULTIPLIER_SCALE
        )
        self.logits_processor = LogitsProcessor(
            self.unpadded_vocab_size, config.vocab_size, self.output_multiplier_scale
        )

        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the (scaled) token embeddings from the inner model."""
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the decoder stack; returns hidden states or PP intermediates."""
        hidden_states = self.model(
            input_ids, positions, intermediate_tensors, inputs_embeds
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Project hidden states to vocabulary logits via the LM head."""
        logits = self.logits_processor(self.lm_head, hidden_states)
        return logits

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, skipping ``lm_head`` when it is tied."""
        # Skip lm_head when tie_word_embeddings is True
        skip_prefixes = ["lm_head"] if self.config.tie_word_embeddings else None

        loader = AutoWeightsLoader(
            self,
            skip_prefixes=skip_prefixes,
        )
        return loader.load_weights(weights)

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        """Delegate to the inner model's expert-parameter mapping."""
        return self.model.get_expert_mapping()