arctic.py 21.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""Inference-only Snowflake Arctic model."""
4

5
from collections.abc import Iterable
6
from itertools import islice
7
8
9
10

import torch
from torch import nn

11
from vllm.compilation.decorators import support_torch_compile
12
from vllm.config import CacheConfig, VllmConfig
13
14
15
16
17
18
from vllm.distributed import (
    get_pp_group,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
    tensor_model_parallel_all_reduce,
)
19
from vllm.model_executor.layers.activation import SiluAndMul
20
from vllm.model_executor.layers.attention import Attention
21
22
23
24
from vllm.model_executor.layers.fused_moe import (
    fused_experts,
    fused_topk,
)
25
from vllm.model_executor.layers.layernorm import RMSNorm
26
27
28
29
30
31
from vllm.model_executor.layers.linear import (
    MergedColumnParallelLinear,
    QKVParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
32
from vllm.model_executor.layers.logits_processor import LogitsProcessor
33
from vllm.model_executor.layers.quantization import QuantizationConfig
34
35
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
36
37
38
    ParallelLMHead,
    VocabParallelEmbedding,
)
39
40
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.utils import set_weight_attrs
41
from vllm.platforms import current_platform
42
from vllm.sequence import IntermediateTensors
43
44
from vllm.transformers_utils.configs.arctic import ArcticConfig

45
from .interfaces import SupportsPP, SupportsQuant
46
from .utils import (
47
    AutoWeightsLoader,
48
49
50
51
52
53
    extract_layer_index,
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)
54

55
56

class ArcticMLP(nn.Module):
    """Feed-forward block: merged input projection, activation, output projection.

    ``w13`` is a ``MergedColumnParallelLinear`` that maps ``hidden_size`` inputs
    to two ``ffn_dim``-wide outputs, ``act_fn`` is ``SiluAndMul`` applied to that
    merged output, and ``w2`` is a ``RowParallelLinear`` that maps ``ffn_dim``
    back to ``hidden_size``.
    """

    def __init__(
        self,
        config: ArcticConfig,
        expert_id: int = -1,
        is_residual_mlp: bool = False,
        quant_config: QuantizationConfig | None = None,
        reduce_results: bool = True,
        prefix: str = "",
    ):
        """Build the two projections and the activation.

        Args:
            config: Model configuration; ``hidden_size``, ``intermediate_size``
                and ``hidden_act`` are read from it.
            expert_id: Stored as ``self.expert_id``; not otherwise used by this
                class.
            is_residual_mlp: When true the inner width ``ffn_dim`` equals
                ``hidden_size``; otherwise it is ``config.intermediate_size``.
            quant_config: Forwarded unchanged to both linear layers.
            reduce_results: Forwarded to ``w2``.
                NOTE(review): presumably controls whether ``w2`` all-reduces
                its output across tensor-parallel ranks -- confirm against
                ``RowParallelLinear``.
            prefix: Dotted module path; ``.w13`` and ``.w2`` are appended for
                the sub-layers.

        Raises:
            ValueError: If ``config.hidden_act`` is not ``"silu"``.
        """
        super().__init__()

        # Reject unsupported configurations before any layer is constructed.
        if config.hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {config.hidden_act}. "
                "Only silu is supported for now."
            )

        self.hidden_size = config.hidden_size
        self.expert_id = expert_id

        self.ffn_dim = (
            config.intermediate_size if not is_residual_mlp else self.hidden_size
        )

        self.w13 = MergedColumnParallelLinear(
            self.hidden_size,
            [self.ffn_dim] * 2,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.w13",
        )
        self.w2 = RowParallelLinear(
            self.ffn_dim,
            self.hidden_size,
            bias=False,
            reduce_results=reduce_results,
            quant_config=quant_config,
            prefix=f"{prefix}.w2",
        )
        self.act_fn = SiluAndMul()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply ``w13``, then ``act_fn``, then ``w2`` and return the result."""
        # Both linear layers return a pair; only the first element is used.
        gate_up, _ = self.w13(hidden_states)
        hidden_states = self.act_fn(gate_up)
        hidden_states, _ = self.w2(hidden_states)
        return hidden_states


class ArcticMoE(nn.Module):
    """
    Model-parallel implementation of Arctic MoE Layer.

    Depending on the layer index parsed from ``prefix`` this module is either a
    plain ``ArcticMLP`` (stored as ``self.mlp``) or a routed mixture of experts
    with a replicated ``gate`` and two stacked expert weight tensors:

    * ``ws``  -- ``(num_experts, 2 * intermediate_size, hidden_size)``; the first
      ``intermediate_size`` rows of each expert hold ``w1`` and the second hold
      ``w3`` (see ``weight_loader``).
    * ``w2s`` -- ``(num_experts, hidden_size, intermediate_size)``.

    ``intermediate_size`` here is the per-rank width, i.e.
    ``config.intermediate_size // tp_size``.
    """

    def __init__(
        self,
        config: ArcticConfig,
        tp_size: int | None = None,
        params_dtype: torch.dtype | None = None,
        quant_config: QuantizationConfig | None = None,
        reduce_results: bool = True,
        prefix: str = "",
    ):
        """Create either the dense MLP or the gate plus expert weights.

        Args:
            config: Model configuration (sizes, expert counts and
                ``moe_layer_frequency`` are read from it).
            tp_size: Tensor-parallel world size; when falsy it is taken from
                ``get_tensor_model_parallel_world_size()``.
            params_dtype: Dtype of the gate and expert weights; defaults to
                ``torch.get_default_dtype()``.
            quant_config: Forwarded to the dense MLP or to the gate.
            reduce_results: For a MoE layer, whether ``local_moe_fused``
                all-reduces its output; for a dense layer it is forwarded to
                ``ArcticMLP``.
            prefix: Dotted module path; the layer index is extracted from it.

        Raises:
            ValueError: If this is a MoE layer and ``config.intermediate_size``
                is not divisible by the tensor-parallel size.
        """
        super().__init__()

        layer_id = extract_layer_index(prefix)
        self.tp_size = tp_size or get_tensor_model_parallel_world_size()
        self.hidden_size = config.hidden_size
        self.num_experts = config.num_local_experts
        self.layer_id = layer_id
        self.top_k = config.num_experts_per_tok
        self.intermediate_size = config.intermediate_size // self.tp_size

        # Every ``moe_layer_frequency``-th layer (1-based) is a MoE layer.
        self.is_moe_layer = (layer_id + 1) % config.moe_layer_frequency == 0
        self.reduce_results = reduce_results
        # Some other parameters
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.params_dtype = params_dtype

        if not self.is_moe_layer:
            self.mlp = ArcticMLP(
                config,
                quant_config=quant_config,
                reduce_results=reduce_results,
                prefix=f"{prefix}.mlp",
            )
            return

        # The expert tensors below are sharded by hand in ``weight_loader``;
        # a non-divisible width would silently drop the trailing columns.
        if config.intermediate_size % self.tp_size != 0:
            raise ValueError(
                f"intermediate_size ({config.intermediate_size}) must be "
                f"divisible by the tensor-parallel size ({self.tp_size})."
            )

        self.gate = ReplicatedLinear(
            self.hidden_size,
            self.num_experts,
            bias=False,
            params_dtype=self.params_dtype,
            quant_config=quant_config,
            prefix=f"{prefix}.gate",
        )
        self.ws = nn.Parameter(
            torch.empty(
                self.num_experts,
                2 * self.intermediate_size,
                self.hidden_size,
                device=current_platform.device_type,
                dtype=self.params_dtype,
            )
        )
        self.w2s = nn.Parameter(
            torch.empty(
                self.num_experts,
                self.hidden_size,
                self.intermediate_size,
                device=current_platform.device_type,
                dtype=self.params_dtype,
            )
        )
        # Attach the custom loader so checkpoint loading can fill one expert
        # slice at a time.
        for expert_param in (self.ws, self.w2s):
            set_weight_attrs(expert_param, {"weight_loader": self.weight_loader})

    def weight_loader(
        self,
        param: nn.Parameter,
        loaded_weight: torch.Tensor,
        weight_name: str,
        expert_id: int,
    ):
        """Copy this rank's shard of one expert weight into ``param``.

        Args:
            param: ``self.ws`` or ``self.w2s``.
            loaded_weight: Full (unsharded) checkpoint tensor for one expert.
            weight_name: Checkpoint name; its suffix (``w1.weight``,
                ``w3.weight`` or ``w2.weight``) selects the destination.
                Any other suffix is ignored.
            expert_id: Index of the expert along dimension 0 of ``param``.
        """
        tp_rank = get_tensor_model_parallel_rank()
        param_data = param.data
        shard_size = self.intermediate_size
        shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)

        if weight_name.endswith("w1.weight"):
            # w1 occupies the first half of the merged ``ws`` rows.
            param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
        elif weight_name.endswith("w3.weight"):
            # w3 occupies the second half of the merged ``ws`` rows.
            param_data[expert_id, shard_size : 2 * shard_size, :] = loaded_weight[
                shard, :
            ]
        elif weight_name.endswith("w2.weight"):
            # w2 is sharded along its input (column) dimension.
            param_data[expert_id, :, :] = loaded_weight[:, shard]

    def local_moe_fused(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route tokens through the top-k experts with the fused kernels.

        ``hidden_states`` must be two-dimensional (its shape is unpacked into
        ``num_tokens, hidden_size``); the result has the same shape.
        """
        num_tokens, hidden_size = hidden_states.shape
        hidden_states = hidden_states.view(-1, self.hidden_size)
        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.gate(hidden_states)
        # Routing weights are renormalized only when more than one expert is
        # selected per token.
        do_normalize = self.top_k > 1
        # The third return value of fused_topk is not needed here.
        topk_weights, topk_ids, _ = fused_topk(
            hidden_states, router_logits, self.top_k, renormalize=do_normalize
        )
        final_hidden_states = fused_experts(
            hidden_states,
            self.ws,
            self.w2s,
            topk_weights,
            topk_ids,
            inplace=True,
        )
        if self.reduce_results and self.tp_size > 1:
            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
        return final_hidden_states.view(num_tokens, hidden_size)

    def forward(self, hidden_states: torch.Tensor):
        """Dispatch to the fused MoE path or to the dense MLP."""
        if self.is_moe_layer:
            return self.local_moe_fused(hidden_states)
        return self.mlp(hidden_states)


class ArcticAttention(nn.Module):
    """Self-attention block: fused QKV projection, rotary embedding, attention
    and output projection, with heads partitioned across tensor-parallel ranks.
    """

    def __init__(
        self,
        config: ArcticConfig,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        """Derive per-rank head counts from ``config`` and build the sub-layers.

        Args:
            config: Model configuration (hidden size, head counts, maximum
                position and ``rope_parameters`` are read from it).
            cache_config: Forwarded to ``Attention``.
            quant_config: Forwarded to the projections and to ``Attention``.
            prefix: Dotted module path used to name the sub-layers.

        Raises:
            AssertionError: If the head counts cannot be partitioned evenly
                across the tensor-parallel world size.
        """
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size

        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = config.num_attention_heads
        assert self.total_num_heads % tp_size == 0, (
            f"num_attention_heads ({self.total_num_heads}) must be divisible "
            f"by the tensor-parallel size ({tp_size})"
        )
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = config.num_key_value_heads
        if self.total_num_kv_heads >= tp_size:
            # At least one KV head per rank: they must split evenly.
            assert self.total_num_kv_heads % tp_size == 0, (
                f"num_key_value_heads ({self.total_num_kv_heads}) must be "
                f"divisible by the tensor-parallel size ({tp_size})"
            )
        else:
            # Fewer KV heads than ranks: the rank count must be a multiple of
            # the KV head count (presumably so heads can be replicated evenly
            # across ranks -- confirm against QKVParallelLinear).
            assert tp_size % self.total_num_kv_heads == 0, (
                f"tensor-parallel size ({tp_size}) must be divisible by "
                f"num_key_value_heads ({self.total_num_kv_heads})"
            )
        # Each rank always holds at least one KV head.
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = self.hidden_size // self.total_num_heads
        # Per-rank widths of the q and k/v slices of the fused QKV output.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim

        self.max_position_embeddings = config.max_position_embeddings
        self.scaling = self.head_dim**-0.5

        self.qkv_proj = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            self.hidden_size,
            bias=False,
            reduce_results=True,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            max_position=self.max_position_embeddings,
            rope_parameters=config.rope_parameters,
            is_neox_style=True,
        )

        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Project to q/k/v, apply rotary embedding to q and k, attend, then
        apply the output projection.
        """
        qkv, _ = self.qkv_proj(hidden_states)
        # Split the fused projection along the last dimension into q, k and v.
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output


class ArcticDecoderLayer(nn.Module):
    """One Arctic decoder layer: self-attention followed by ``ArcticMoE``.

    When ``config.use_residual`` is set and the layer is a MoE layer, an
    additional dense ``residual_mlp`` branch is evaluated alongside the MoE
    branch and the two outputs are summed (see ``forward``).
    """

    def __init__(
        self,
        config: ArcticConfig,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        """Build attention, MoE/MLP, layer norms and the optional residual MLP.

        Args:
            config: Model configuration.
            cache_config: Forwarded to ``ArcticAttention``.
            quant_config: Forwarded to ``ArcticAttention`` and ``ArcticMoE``.
            prefix: Dotted module path; the layer index is extracted from it.
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        layer_idx = extract_layer_index(prefix)
        # Same MoE-layer rule as ArcticMoE: every moe_layer_frequency-th layer.
        is_moe_layer = (layer_idx + 1) % config.moe_layer_frequency == 0
        self.use_residual = config.use_residual and is_moe_layer

        self.self_attn = ArcticAttention(
            config,
            cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )
        # With the residual branch active the MoE output is left un-reduced so
        # that forward() can sum both branches and all-reduce only once.
        self.block_sparse_moe = ArcticMoE(
            config,
            quant_config=quant_config,
            reduce_results=(not self.use_residual),
            prefix=f"{prefix}.block_sparse_moe",
        )

        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

        if self.use_residual:
            self.residual_layernorm = RMSNorm(
                config.hidden_size, eps=config.rms_norm_eps
            )
            # NOTE(review): quant_config is not forwarded here, so the residual
            # MLP is built unquantized -- confirm this is intentional.
            self.residual_mlp = ArcticMLP(
                config,
                is_residual_mlp=True,
                reduce_results=False,
                prefix=f"{prefix}.residual_mlp",
            )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run attention and the feed-forward stage with residual connections."""
        # Attention sub-block with its residual connection.
        residual_input = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
        )
        hidden_states = residual_input + hidden_states

        residual_attn = hidden_states
        if self.use_residual:
            # Dense branch: operates on the attention output.
            hidden_states = self.residual_layernorm(hidden_states)
            hidden_states = self.residual_mlp(hidden_states)
            residual_mlp = hidden_states
            # MoE branch: normalizes the layer's original input
            # (residual_input), not the attention output.
            hidden_states = self.post_attention_layernorm(residual_input)
            hidden_states = self.block_sparse_moe(hidden_states)
            # Both branches were built with reduce_results=False, so their
            # partial sums are combined first and all-reduced once.
            hidden_states = residual_mlp + hidden_states
            hidden_states = tensor_model_parallel_all_reduce(hidden_states)
            hidden_states = residual_attn + hidden_states
        else:
            hidden_states = self.post_attention_layernorm(hidden_states)
            hidden_states = self.block_sparse_moe(hidden_states)
            hidden_states = residual_attn + hidden_states
        return hidden_states


379
@support_torch_compile
class ArcticModel(nn.Module):
    """Arctic decoder stack: token embedding, decoder layers and final norm.

    Only the layers in ``[start_layer, end_layer)`` (as returned by
    ``make_layers``) are executed by ``forward``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the embedding, the decoder layers and the final ``RMSNorm``.

        Args:
            vllm_config: Source of the HF model config, cache config and
                quantization config.
            prefix: Dotted module path; layers are created under
                ``{prefix}.layers``.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.vocab_size = config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size, config.hidden_size, org_num_embeddings=self.vocab_size
        )
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: ArcticDecoderLayer(
                config, cache_config, quant_config, prefix=prefix
            ),
            prefix=f"{prefix}.layers",
        )
        self._attn_implementation = config._attn_implementation
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states"], config.hidden_size
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run this pipeline stage's decoder layers.

        On the first pipeline rank the input is ``inputs_embeds`` when given,
        otherwise the embedding of ``input_ids``; on later ranks it is taken
        from ``intermediate_tensors["hidden_states"]``.

        Returns:
            The normalized hidden states on the last pipeline rank, otherwise
            an ``IntermediateTensors`` carrying ``hidden_states``.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.embed_input_ids(input_ids)
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]

        for layer in islice(self.layers, self.start_layer, self.end_layer):
            hidden_states = layer(positions, hidden_states)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})

        hidden_states = self.norm(hidden_states)
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Each checkpoint name is matched, in order, against three mapping
        tables (fused QKV, merged ``w13`` MLP weights, stacked expert weights);
        names that match none of them are loaded directly by name.

        Args:
            weights: ``(name, tensor)`` pairs from the checkpoint.

        Returns:
            The (remapped) parameter names that were processed.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]

        # (merged param name, checkpoint name, shard index within w13)
        mlp_params_mapping: list[tuple[str, str, int]] = []
        has_moe_layer = False
        for layer in range(self.config.num_hidden_layers):
            is_moe_layer = (layer + 1) % self.config.moe_layer_frequency == 0
            if is_moe_layer:
                has_moe_layer = True
                if not self.config.use_residual:
                    # MoE layer without a residual MLP has no dense w13.
                    continue
                mlp_prefix = f"layers.{layer}.residual_mlp"
            else:
                mlp_prefix = f"layers.{layer}.block_sparse_moe.mlp"
            mlp_params_mapping.append(
                (f"{mlp_prefix}.w13.weight", f"{mlp_prefix}.w1.weight", 0)
            )
            mlp_params_mapping.append(
                (f"{mlp_prefix}.w13.weight", f"{mlp_prefix}.w3.weight", 1)
            )

        # (stacked param name, checkpoint suffix, expert id). These entries do
        # not depend on the layer index, so they are built once instead of
        # once per MoE layer; only the first match is ever used.
        expert_params_mapping: list[tuple[str, str, int]] = []
        if has_moe_layer:
            for expert_id in range(self.config.num_local_experts):
                expert_params_mapping.append(
                    ("ws", f"experts.{expert_id}.w1.weight", expert_id)
                )
                expert_params_mapping.append(
                    ("w2s", f"experts.{expert_id}.w2.weight", expert_id)
                )
                expert_params_mapping.append(
                    ("ws", f"experts.{expert_id}.w3.weight", expert_id)
                )

        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        for name, loaded_weight in weights:
            # Each ``for ... else`` below falls through to the next table only
            # when no entry of the current table ended in ``break``.
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                for param_name, weight_name, shard_id in mlp_params_mapping:
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    if is_pp_missing_parameter(name, self):
                        continue
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(param, loaded_weight, shard_id)
                    break
                else:
                    for param_name, weight_name, shard_id in expert_params_mapping:
                        if weight_name not in name:
                            continue
                        name = name.replace(weight_name, param_name)
                        if is_pp_missing_parameter(name, self):
                            continue
                        param = params_dict[name]
                        weight_loader = param.weight_loader
                        # Expert loader also needs the checkpoint suffix to
                        # tell w1 / w2 / w3 apart.
                        weight_loader(
                            param, loaded_weight, weight_name, expert_id=shard_id
                        )
                        break
                    else:
                        # No table matched: load by name. These ``continue``
                        # statements skip to the next checkpoint tensor
                        # without recording the name as loaded.
                        if name.endswith(".bias") and name not in params_dict:
                            continue
                        if is_pp_missing_parameter(name, self):
                            continue
                        param = params_dict[name]
                        weight_loader = getattr(
                            param, "weight_loader", default_weight_loader
                        )
                        weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597


class ArcticForCausalLM(nn.Module, SupportsPP, SupportsQuant):
    """Arctic decoder (``ArcticModel``) plus a language-model head."""

    # Checkpoint q/k/v projections are packed into one ``qkv_proj`` module.
    packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Create the backbone, the LM head and the logits processor.

        When ``config.tie_word_embeddings`` is set, the LM head shares its
        weight with the backbone's token embedding.
        """
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        self.config = hf_config

        self.model = ArcticModel(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "model"),
        )

        self.vocab_size = hf_config.vocab_size
        self.lm_head = ParallelLMHead(
            self.vocab_size,
            hf_config.hidden_size,
            quant_config=vllm_config.quant_config,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        if self.config.tie_word_embeddings:
            self.lm_head.weight = self.model.embed_tokens.weight

        self.num_experts = hf_config.num_local_experts
        self.num_experts_per_tok = hf_config.num_experts_per_tok

        self.logits_processor = LogitsProcessor(hf_config.vocab_size)
        # Re-export the backbone's factory under the same attribute name.
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate token embedding to the backbone."""
        return self.model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the backbone and return whatever it returns."""
        return self.model(input_ids, positions, intermediate_tensors, inputs_embeds)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Turn hidden states into logits via the LM head."""
        return self.logits_processor(self.lm_head, hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights through ``AutoWeightsLoader``.

        ``lm_head.`` entries are skipped when the head is tied to the
        embedding.
        """
        skipped = ["lm_head."] if self.config.tie_word_embeddings else None
        return AutoWeightsLoader(self, skip_prefixes=skipped).load_weights(weights)