arctic.py 21.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""Inference-only Snowflake Arctic model."""
4

5
from collections.abc import Iterable
6
from itertools import islice
7
8
9
10

import torch
from torch import nn

11
from vllm.compilation.decorators import support_torch_compile
12
from vllm.config import CacheConfig, VllmConfig
13
14
15
16
17
18
from vllm.distributed import (
    get_pp_group,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
    tensor_model_parallel_all_reduce,
)
19
from vllm.model_executor.layers.activation import SiluAndMul
20
from vllm.model_executor.layers.attention import Attention
21
22
from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk
from vllm.model_executor.layers.layernorm import RMSNorm
23
24
25
26
27
28
from vllm.model_executor.layers.linear import (
    MergedColumnParallelLinear,
    QKVParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
29
from vllm.model_executor.layers.logits_processor import LogitsProcessor
30
from vllm.model_executor.layers.quantization import QuantizationConfig
31
32
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
33
34
35
    ParallelLMHead,
    VocabParallelEmbedding,
)
36
37
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.utils import set_weight_attrs
38
from vllm.platforms import current_platform
39
from vllm.sequence import IntermediateTensors
40
41
from vllm.transformers_utils.configs.arctic import ArcticConfig

42
from .interfaces import SupportsPP, SupportsQuant
43
from .utils import (
44
    AutoWeightsLoader,
45
46
47
48
49
50
    extract_layer_index,
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)
51

52
53

class ArcticMLP(nn.Module):
    """Gated-SiLU feed-forward block used by Arctic decoder layers.

    The checkpoint's ``w1``/``w3`` projections are fused into a single
    column-parallel layer ``w13`` (output width ``2 * ffn_dim``).
    ``SiluAndMul`` combines the two halves, and the row-parallel ``w2``
    projects back to ``hidden_size``.

    Args:
        config: Arctic model config; ``hidden_size``, ``intermediate_size``
            and ``hidden_act`` are read.
        expert_id: Stored as ``self.expert_id``; not otherwise used here.
        is_residual_mlp: When True the intermediate width is
            ``config.hidden_size`` instead of ``config.intermediate_size``.
        quant_config: Optional quantization config passed to both linears.
        reduce_results: Forwarded to ``w2``. ``ArcticDecoderLayer`` passes
            False for its residual MLP and performs the all-reduce itself.
        prefix: Parameter-name prefix of this module.

    Raises:
        ValueError: If ``config.hidden_act`` is not ``"silu"``.
    """

    def __init__(
        self,
        config: ArcticConfig,
        expert_id: int = -1,
        is_residual_mlp: bool = False,
        quant_config: QuantizationConfig | None = None,
        reduce_results: bool = True,
        prefix: str = "",
    ):
        super().__init__()
        # Validate before building any layers so an unsupported config fails
        # without allocating weight tensors first.
        if config.hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {config.hidden_act}. "
                "Only silu is supported for now."
            )

        self.hidden_size = config.hidden_size
        self.expert_id = expert_id

        self.ffn_dim = (
            config.intermediate_size if not is_residual_mlp else self.hidden_size
        )

        # Fused w1/w3 projection: two output blocks of width ffn_dim each.
        self.w13 = MergedColumnParallelLinear(
            self.hidden_size,
            [self.ffn_dim] * 2,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.w13",
        )
        self.w2 = RowParallelLinear(
            self.ffn_dim,
            self.hidden_size,
            bias=False,
            reduce_results=reduce_results,
            quant_config=quant_config,
            prefix=f"{prefix}.w2",
        )
        self.act_fn = SiluAndMul()

    def forward(self, hidden_states):
        """Return ``w2(SiluAndMul(w13(hidden_states)))``."""
        # The parallel linear layers return a pair; the second value is unused.
        gate_up, _ = self.w13(hidden_states)
        hidden_states = self.act_fn(gate_up)
        hidden_states, _ = self.w2(hidden_states)
        return hidden_states


class ArcticMoE(nn.Module):
    """Model-parallel implementation of the Arctic MoE layer.

    Sparsity is decided from the layer index parsed out of ``prefix``. Layer
    ``i`` is an MoE layer when ``(i + 1) % config.moe_layer_frequency == 0``;
    otherwise it holds a single dense :class:`ArcticMLP` in ``self.mlp``.

    A sparse layer owns a replicated router ``self.gate`` plus two stacked
    expert tensors. Here ``I = config.intermediate_size // tp_size`` is this
    rank's slice of the expert FFN width:

    * ``ws``:  ``(num_experts, 2 * I, hidden_size)``, with w1 rows first and
      then w3 rows.
    * ``w2s``: ``(num_experts, hidden_size, I)``.

    Args:
        config: Arctic model config.
        tp_size: Tensor-parallel world size. Falls back to the current TP
            group size when falsy (``None`` or 0).
        params_dtype: dtype of the router and expert weights; defaults to
            ``torch.get_default_dtype()``.
        quant_config: Optional quantization config, used by the router and
            the dense MLP (not by ``ws``/``w2s``).
        reduce_results: If False, the output is not all-reduced across TP
            ranks, so the caller can fuse the reduction.
        prefix: Parameter-name prefix; the layer index is extracted from it.
    """

    def __init__(
        self,
        config: ArcticConfig,
        tp_size: int | None = None,
        params_dtype: torch.dtype | None = None,
        quant_config: QuantizationConfig | None = None,
        reduce_results: bool = True,
        prefix: str = "",
    ):
        super().__init__()

        layer_id = extract_layer_index(prefix)
        self.tp_size = tp_size or get_tensor_model_parallel_world_size()
        self.hidden_size = config.hidden_size
        self.num_experts = config.num_local_experts
        self.layer_id = layer_id
        self.top_k = config.num_experts_per_tok
        # This rank's slice of the expert FFN width.
        self.intermediate_size = config.intermediate_size // self.tp_size

        self.is_moe_layer = (layer_id + 1) % config.moe_layer_frequency == 0
        self.reduce_results = reduce_results
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.params_dtype = params_dtype

        if not self.is_moe_layer:
            self.mlp = ArcticMLP(
                config,
                quant_config=quant_config,
                reduce_results=reduce_results,
                prefix=f"{prefix}.mlp",
            )
        else:
            self.gate = ReplicatedLinear(
                self.hidden_size,
                self.num_experts,
                bias=False,
                params_dtype=self.params_dtype,
                quant_config=quant_config,
                prefix=f"{prefix}.gate",
            )
            self.ws = nn.Parameter(
                torch.empty(
                    self.num_experts,
                    2 * self.intermediate_size,
                    self.hidden_size,
                    device=current_platform.device_type,
                    dtype=self.params_dtype,
                )
            )
            self.w2s = nn.Parameter(
                torch.empty(
                    self.num_experts,
                    self.hidden_size,
                    self.intermediate_size,
                    device=current_platform.device_type,
                    dtype=self.params_dtype,
                )
            )
            # Route checkpoint tensors for both stacked params through the
            # per-expert shard loader below.
            for expert_param in (self.ws, self.w2s):
                set_weight_attrs(expert_param, {"weight_loader": self.weight_loader})

    def weight_loader(
        self,
        param: nn.Parameter,
        loaded_weight: torch.Tensor,
        weight_name: str,
        expert_id: int,
    ):
        """Copy this rank's shard of one expert's checkpoint weight into ``param``.

        Args:
            param: ``self.ws`` or ``self.w2s``.
            loaded_weight: Unsharded checkpoint tensor. Judging by the slicing
                below, presumably ``(ffn, hidden)`` for w1/w3 and
                ``(hidden, ffn)`` for w2 (TODO confirm).
            weight_name: Checkpoint name; its ``w1.weight`` / ``w3.weight`` /
                ``w2.weight`` suffix selects the destination slice.
            expert_id: Index along the leading (expert) dimension of ``param``.
        """
        # NOTE(review): uses the global TP rank even if a custom tp_size was
        # passed to __init__; both are assumed to describe the same group.
        tp_rank = get_tensor_model_parallel_rank()
        param_data = param.data
        shard_size = self.intermediate_size
        shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
        if weight_name.endswith("w1.weight"):
            # w1 fills rows [0, I) of ws.
            param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
        elif weight_name.endswith("w3.weight"):
            # w3 fills rows [I, 2I) of ws.
            param_data[expert_id, shard_size : 2 * shard_size, :] = loaded_weight[
                shard, :
            ]
        elif weight_name.endswith("w2.weight"):
            # w2 is sharded along its input (column) dimension.
            param_data[expert_id, :, :] = loaded_weight[:, shard]

    def local_moe_fused(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route tokens to their top-k experts and combine the expert outputs.

        Args:
            hidden_states: 2-D tensor ``(num_tokens, hidden_size)``.

        Returns:
            Tensor of the same shape, all-reduced across TP ranks when
            ``reduce_results`` is set and ``tp_size > 1``.
        """
        num_tokens, hidden_size = hidden_states.shape
        hidden_states = hidden_states.view(-1, self.hidden_size)
        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.gate(hidden_states)
        # Renormalize the selected routing weights only when k > 1.
        do_normalize = self.top_k > 1
        topk_weights, topk_ids, token_expert_indices = fused_topk(
            hidden_states, router_logits, self.top_k, renormalize=do_normalize
        )
        final_hidden_states = fused_experts(
            hidden_states,
            self.ws,
            self.w2s,
            topk_weights,
            topk_ids,
            inplace=True,
        )
        if self.reduce_results and self.tp_size > 1:
            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
        return final_hidden_states.view(num_tokens, hidden_size)

    def forward(self, hidden_states: torch.Tensor):
        """Dispatch to the sparse MoE path or the dense MLP for this layer."""
        if self.is_moe_layer:
            final_hidden_states = self.local_moe_fused(hidden_states)
        else:
            final_hidden_states = self.mlp(hidden_states)
        return final_hidden_states


class ArcticAttention(nn.Module):
    """Tensor-parallel self-attention with rotary embeddings for Arctic.

    Q/K/V come from a fused ``QKVParallelLinear``. Attention heads are split
    evenly across TP ranks. KV heads are either split evenly, or, when there
    are fewer KV heads than ranks, each rank keeps one KV head.

    Args:
        config: Arctic model config.
        cache_config: Optional KV-cache config forwarded to ``Attention``.
        quant_config: Optional quantization config for the projections and
            attention.
        prefix: Parameter-name prefix of this module.
    """

    def __init__(
        self,
        config: ArcticConfig,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size

        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = config.num_attention_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = config.num_key_value_heads
        if self.total_num_kv_heads >= tp_size:
            # KV heads are partitioned across ranks.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Fewer KV heads than ranks: presumably each KV head is
            # replicated across several ranks (handled by QKVParallelLinear).
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = self.hidden_size // self.total_num_heads
        # Per-rank widths of the q and k/v slices of the fused projection.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim

        self.max_position_embeddings = config.max_position_embeddings
        self.scaling = self.head_dim**-0.5

        self.qkv_proj = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            self.hidden_size,
            bias=False,
            reduce_results=True,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            max_position=self.max_position_embeddings,
            rope_parameters=config.rope_parameters,
            is_neox_style=True,
        )

        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Project to Q/K/V, apply RoPE, attend, and project back out."""
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output


class ArcticDecoderLayer(nn.Module):
    """One Arctic transformer block: attention followed by a dense/MoE FFN.

    On MoE layers with ``config.use_residual`` set, an extra dense
    ``residual_mlp`` runs alongside the MoE. Both are built without their own
    TP reduction, so a single all-reduce is applied to their sum in
    :meth:`forward`.

    Args:
        config: Arctic model config.
        cache_config: Optional KV-cache config for the attention layer.
        quant_config: Optional quantization config.
        prefix: Parameter-name prefix; the layer index is extracted from it.
    """

    def __init__(
        self,
        config: ArcticConfig,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        layer_idx = extract_layer_index(prefix)
        is_moe_layer = (layer_idx + 1) % config.moe_layer_frequency == 0
        self.use_residual = config.use_residual and is_moe_layer
        self.self_attn = ArcticAttention(
            config,
            cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )
        # With the residual MLP, the MoE output stays un-reduced; forward()
        # reduces (moe + residual_mlp) once instead.
        self.block_sparse_moe = ArcticMoE(
            config,
            quant_config=quant_config,
            reduce_results=(not self.use_residual),
            prefix=f"{prefix}.block_sparse_moe",
        )

        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

        if self.use_residual:
            self.residual_layernorm = RMSNorm(
                config.hidden_size, eps=config.rms_norm_eps
            )
            self.residual_mlp = ArcticMLP(
                config,
                is_residual_mlp=True,
                reduce_results=False,
                prefix=f"{prefix}.residual_mlp",
            )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Apply the block to ``hidden_states`` (call it ``x``).

        First ``h = x + attn(input_layernorm(x))``.

        * Without the residual MLP the result is
          ``h + moe(post_attention_layernorm(h))``.
        * With it the result is
          ``h + all_reduce(residual_mlp(residual_layernorm(h))
          + moe(post_attention_layernorm(x)))``.

        In the second case the MoE branch consumes the layer *input* ``x``,
        not ``h``. NOTE(review): presumably this mirrors the upstream Arctic
        parallel-residual design; confirm before changing.
        """
        residual_input = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
        )
        hidden_states = residual_input + hidden_states

        residual_attn = hidden_states
        if self.use_residual:
            hidden_states = self.residual_layernorm(hidden_states)
            hidden_states = self.residual_mlp(hidden_states)
            residual_mlp = hidden_states
            # MoE reads the pre-attention input, in parallel with attention.
            hidden_states = self.post_attention_layernorm(residual_input)
            hidden_states = self.block_sparse_moe(hidden_states)
            hidden_states = residual_mlp + hidden_states
            # One all-reduce for both un-reduced branches.
            hidden_states = tensor_model_parallel_all_reduce(hidden_states)
            hidden_states = residual_attn + hidden_states
        else:
            hidden_states = self.post_attention_layernorm(hidden_states)
            hidden_states = self.block_sparse_moe(hidden_states)
            hidden_states = residual_attn + hidden_states
        return hidden_states


376
@support_torch_compile
class ArcticModel(nn.Module):
    """Arctic decoder stack: token embedding, decoder layers, final norm.

    Only layers in ``[self.start_layer, self.end_layer)`` run on this rank.
    Non-first pipeline ranks read ``hidden_states`` from
    ``IntermediateTensors``, and non-last ranks return them.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.vocab_size = config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size, config.hidden_size, org_num_embeddings=self.vocab_size
        )
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: ArcticDecoderLayer(
                config, cache_config, quant_config, prefix=prefix
            ),
            prefix=f"{prefix}.layers",
        )
        self._attn_implementation = config._attn_implementation
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states"], config.hidden_size
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run this pipeline stage's decoder layers.

        On the first rank the input is ``inputs_embeds`` if given, else the
        embedding of ``input_ids``. Other ranks read
        ``intermediate_tensors["hidden_states"]``.

        Returns:
            Final-normed hidden states on the last rank; otherwise an
            ``IntermediateTensors`` for the next rank.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.embed_input_ids(input_ids)
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
        for layer in islice(self.layers, self.start_layer, self.end_layer):
            hidden_states = layer(positions, hidden_states)
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})
        hidden_states = self.norm(hidden_states)
        return hidden_states

    def _build_mlp_and_expert_mappings(
        self,
    ) -> tuple[list[tuple[str, str, int]], list[tuple[str, str, int]]]:
        """Build the ``(mlp_params_mapping, expert_params_mapping)`` tables.

        Each entry is ``(param_name, checkpoint_name, shard_or_expert_id)``.

        MLP entries are layer-specific. They fuse checkpoint ``w1``/``w3``
        into ``w13`` (shard 0/1), for the residual MLP of MoE layers (when
        ``use_residual``) and for the dense MLP of non-MoE layers.

        Expert entries do not depend on the layer, so they are built once
        rather than once per MoE layer. This keeps the per-weight linear scan
        in :meth:`load_weights` short without changing which entry matches
        first.
        """
        config = self.config
        mlp_params_mapping: list[tuple[str, str, int]] = []
        has_moe_layer = False
        for layer in range(config.num_hidden_layers):
            is_moe_layer = (layer + 1) % config.moe_layer_frequency == 0
            if is_moe_layer:
                has_moe_layer = True
                if not config.use_residual:
                    continue
                module = f"layers.{layer}.residual_mlp"
            else:
                module = f"layers.{layer}.block_sparse_moe.mlp"
            mlp_params_mapping.append(
                (f"{module}.w13.weight", f"{module}.w1.weight", 0)
            )
            mlp_params_mapping.append(
                (f"{module}.w13.weight", f"{module}.w3.weight", 1)
            )

        expert_params_mapping: list[tuple[str, str, int]] = []
        if has_moe_layer:
            for expert_id in range(config.num_local_experts):
                expert_params_mapping.extend(
                    [
                        ("ws", f"experts.{expert_id}.w1.weight", expert_id),
                        ("w2s", f"experts.{expert_id}.w2.weight", expert_id),
                        ("ws", f"experts.{expert_id}.w3.weight", expert_id),
                    ]
                )
        return mlp_params_mapping, expert_params_mapping

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Each name is tried, in order, against four targets:

        1. fused QKV shards;
        2. fused dense/residual MLP shards;
        3. per-expert MoE weights (via ``ArcticMoE.weight_loader``);
        4. direct parameter names.

        Weights for layers not hosted on this pipeline rank are skipped.

        Args:
            weights: ``(name, tensor)`` pairs. Names are presumably relative
                to this module (the mappings use ``layers.N...``); confirm
                against the caller.

        Returns:
            The (possibly remapped) parameter names that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        mlp_params_mapping, expert_params_mapping = (
            self._build_mlp_and_expert_mappings()
        )

        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        # Each `for ... else` falls through to the next mapping table only
        # when no entry of the current table reached `break`.
        for name, loaded_weight in weights:
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                for param_name, weight_name, shard_id in mlp_params_mapping:
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    if is_pp_missing_parameter(name, self):
                        continue
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(param, loaded_weight, shard_id)
                    break
                else:
                    for param_name, weight_name, shard_id in expert_params_mapping:
                        if weight_name not in name:
                            continue
                        name = name.replace(weight_name, param_name)
                        if is_pp_missing_parameter(name, self):
                            continue
                        param = params_dict[name]
                        weight_loader = param.weight_loader
                        weight_loader(
                            param, loaded_weight, weight_name, expert_id=shard_id
                        )
                        break
                    else:
                        # `continue` here advances the outer `weights` loop.
                        if name.endswith(".bias") and name not in params_dict:
                            continue
                        if is_pp_missing_parameter(name, self):
                            continue
                        param = params_dict[name]
                        weight_loader = getattr(
                            param, "weight_loader", default_weight_loader
                        )
                        weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594


class ArcticForCausalLM(nn.Module, SupportsPP, SupportsQuant):
    """Arctic decoder stack topped with a vocabulary-parallel LM head.

    The output head shares the embedding matrix when
    ``config.tie_word_embeddings`` is set. In that case ``lm_head.*`` tensors
    are skipped during weight loading.
    """

    packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        qcfg = vllm_config.quant_config
        self.config = hf_config

        self.model = ArcticModel(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )

        self.vocab_size = hf_config.vocab_size
        self.lm_head = ParallelLMHead(
            self.vocab_size,
            hf_config.hidden_size,
            quant_config=qcfg,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        if hf_config.tie_word_embeddings:
            # Share a single weight tensor between input and output embeddings.
            self.lm_head.weight = self.model.embed_tokens.weight

        self.num_experts = hf_config.num_local_experts
        self.num_experts_per_tok = hf_config.num_experts_per_tok

        self.logits_processor = LogitsProcessor(hf_config.vocab_size)
        # Re-export the backbone's factory for pipeline-parallel plumbing.
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate token embedding to the backbone."""
        return self.model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the backbone and return its output unchanged."""
        return self.model(input_ids, positions, intermediate_tensors, inputs_embeds)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Project hidden states to vocabulary logits via the LM head."""
        return self.logits_processor(self.lm_head, hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, ignoring ``lm_head.*`` when tied."""
        skipped = ["lm_head."] if self.config.tie_word_embeddings else None
        return AutoWeightsLoader(self, skip_prefixes=skipped).load_weights(weights)