# arctic.py
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""Inference-only Snowflake Arctic model."""
4

5
from collections.abc import Iterable
6
from itertools import islice
7
8
9
10

import torch
from torch import nn

11
from vllm.attention.layer import Attention
12
from vllm.compilation.decorators import support_torch_compile
13
from vllm.config import CacheConfig, VllmConfig
14
15
16
17
18
19
from vllm.distributed import (
    get_pp_group,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
    tensor_model_parallel_all_reduce,
)
20
21
22
23
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk
from vllm.model_executor.layers.layernorm import RMSNorm
24
25
26
27
28
29
from vllm.model_executor.layers.linear import (
    MergedColumnParallelLinear,
    QKVParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
30
from vllm.model_executor.layers.logits_processor import LogitsProcessor
31
from vllm.model_executor.layers.quantization import QuantizationConfig
32
from vllm.model_executor.layers.quantization.deepspeedfp import (
33
34
35
    DeepSpeedFPConfig,
    DeepSpeedFPParameter,
)
36
37
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
38
39
40
    ParallelLMHead,
    VocabParallelEmbedding,
)
41
42
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.utils import set_weight_attrs
43
from vllm.platforms import current_platform
44
from vllm.sequence import IntermediateTensors
45
46
from vllm.transformers_utils.configs.arctic import ArcticConfig

47
from .interfaces import SupportsPP, SupportsQuant
48
49
50
51
52
53
54
from .utils import (
    extract_layer_index,
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)
55

56
57
58
59
# Module-level logger; ArcticForCausalLM.load_weights uses it to emit a
# one-time hint about weight-loading time.
logger = init_logger(__name__)


class ArcticMLP(nn.Module):
    """Gated feed-forward block: merged ``w13`` projection, SiLU-and-mul, ``w2``.

    The same module is used for the dense MLP of non-MoE layers and, with
    ``is_residual_mlp=True``, for the residual MLP of a decoder layer.
    """

    def __init__(
        self,
        config: ArcticConfig,
        expert_id: int = -1,
        is_residual_mlp: bool = False,
        quant_config: QuantizationConfig | None = None,
        reduce_results: bool = True,
        prefix: str = "",
    ):
        """Build the projections.

        Args:
            config: Model configuration; ``hidden_size``, ``intermediate_size``
                and ``hidden_act`` are read.
            expert_id: Stored on the module as ``self.expert_id``; not used
                elsewhere in this class.
            is_residual_mlp: When true the inner width equals ``hidden_size``
                instead of ``config.intermediate_size``.
            quant_config: Forwarded to both linear layers.
            reduce_results: Forwarded to the ``w2`` row-parallel layer.
            prefix: Parameter-name prefix for the sub-layers.

        Raises:
            ValueError: If ``config.hidden_act`` is anything other than
                ``"silu"``.
        """
        super().__init__()
        # Validate first so an unsupported config fails before any weights
        # are allocated.
        if config.hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {config.hidden_act}. "
                "Only silu is supported for now."
            )

        self.hidden_size = config.hidden_size
        self.expert_id = expert_id

        self.ffn_dim = (
            config.intermediate_size if not is_residual_mlp else self.hidden_size
        )

        # Gate (w1) and up (w3) projections are fused into one matmul.
        self.w13 = MergedColumnParallelLinear(
            self.hidden_size,
            [self.ffn_dim] * 2,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.w13",
        )
        self.w2 = RowParallelLinear(
            self.ffn_dim,
            self.hidden_size,
            bias=False,
            reduce_results=reduce_results,
            quant_config=quant_config,
            prefix=f"{prefix}.w2",
        )
        self.act_fn = SiluAndMul()

    def forward(self, hidden_states):
        """Apply ``w13`` -> ``SiluAndMul`` -> ``w2`` and return the result."""
        gate_up, _ = self.w13(hidden_states)
        hidden_states = self.act_fn(gate_up)
        hidden_states, _ = self.w2(hidden_states)
        return hidden_states


class ArcticMoE(nn.Module):
    """
    Model-parallel implementation of Arctic MoE Layer.

    Depending on the layer index parsed from ``prefix`` the module is either
    a dense :class:`ArcticMLP` (``self.mlp``) or a routed mixture of experts
    whose weights are stacked in ``self.ws`` (w1 and w3) and ``self.w2s``.
    """

    def __init__(
        self,
        config: ArcticConfig,
        tp_size: int | None = None,
        params_dtype: torch.dtype | None = None,
        quant_config: QuantizationConfig | None = None,
        reduce_results: bool = True,
        prefix: str = "",
    ):
        """Create either the dense MLP or the gate plus stacked expert weights.

        Args:
            config: Model configuration.
            tp_size: Tensor-parallel world size; falls back to
                ``get_tensor_model_parallel_world_size()`` when falsy.
            params_dtype: Dtype of the expert weights; defaults to
                ``torch.get_default_dtype()``.
            quant_config: Quantization config. A ``DeepSpeedFPConfig`` makes
                the expert weights ``DeepSpeedFPParameter`` instances.
            reduce_results: Whether the layer all-reduces its own output.
            prefix: Parameter-name prefix; the layer index is extracted from
                it.
        """
        super().__init__()

        layer_id = extract_layer_index(prefix)
        self.tp_size = tp_size or get_tensor_model_parallel_world_size()
        self.hidden_size = config.hidden_size
        self.num_experts = config.num_local_experts
        self.layer_id = layer_id
        self.top_k = config.num_experts_per_tok
        # Per-rank slice of the expert intermediate dimension.
        self.intermediate_size = config.intermediate_size // self.tp_size

        self.is_moe_layer = (layer_id + 1) % config.moe_layer_frequency == 0
        self.is_quant = isinstance(quant_config, DeepSpeedFPConfig)
        self.reduce_results = reduce_results
        # Some other parameters
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.params_dtype = params_dtype

        if not self.is_moe_layer:
            self.mlp = ArcticMLP(
                config,
                quant_config=quant_config,
                reduce_results=reduce_results,
                prefix=f"{prefix}.mlp",
            )
        else:
            self.gate = ReplicatedLinear(
                self.hidden_size,
                self.num_experts,
                bias=False,
                params_dtype=self.params_dtype,
                quant_config=quant_config,
                prefix=f"{prefix}.gate",
            )
            if self.is_quant:
                self.ws = DeepSpeedFPParameter(
                    torch.Size(
                        (self.num_experts, 2 * self.intermediate_size, self.hidden_size)
                    ),
                    params_dtype=params_dtype,
                    quant_config=quant_config,
                )
                self.w2s = DeepSpeedFPParameter(
                    torch.Size(
                        (self.num_experts, self.hidden_size, self.intermediate_size)
                    ),
                    params_dtype=params_dtype,
                    quant_config=quant_config,
                )
            else:
                self.ws = nn.Parameter(
                    torch.empty(
                        self.num_experts,
                        2 * self.intermediate_size,
                        self.hidden_size,
                        device=current_platform.device_type,
                        dtype=self.params_dtype,
                    )
                )
                self.w2s = nn.Parameter(
                    torch.empty(
                        self.num_experts,
                        self.hidden_size,
                        self.intermediate_size,
                        device=current_platform.device_type,
                        dtype=self.params_dtype,
                    )
                )
            # Both stacked tensors are filled expert-by-expert through
            # ``weight_loader`` below.
            set_weight_attrs(
                self.ws,
                {
                    "weight_loader": self.weight_loader,
                },
            )
            set_weight_attrs(
                self.w2s,
                {
                    "weight_loader": self.weight_loader,
                },
            )

    def weight_loader(
        self,
        param: nn.Parameter,
        loaded_weight: torch.Tensor,
        weight_name: str,
        expert_id: int,
    ):
        """Copy this rank's shard of one expert tensor into a stacked parameter.

        ``w1`` fills rows ``[0, shard)`` and ``w3`` rows ``[shard, 2 * shard)``
        of ``param[expert_id]``; ``w2`` is sharded along its second dimension.
        For DeepSpeedFP parameters the tensor is dequantized, updated and
        re-quantized in place.

        Args:
            param: Destination parameter (``self.ws`` or ``self.w2s``).
            loaded_weight: Full, unsharded checkpoint tensor for one expert.
            weight_name: Checkpoint weight name; only its suffix is inspected.
            expert_id: Index of the expert inside the stacked parameter.
        """
        tp_rank = get_tensor_model_parallel_rank()
        param_data = param.ds_dequantize() if self.is_quant else param.data
        shard_size = self.intermediate_size
        shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
        # The three suffixes are mutually exclusive, so at most one branch runs.
        if weight_name.endswith("w1.weight"):
            param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
        elif weight_name.endswith("w3.weight"):
            param_data[expert_id, shard_size : 2 * shard_size, :] = loaded_weight[
                shard, :
            ]
        elif weight_name.endswith("w2.weight"):
            param_data[expert_id, :, :] = loaded_weight[:, shard]
        if self.is_quant:
            param.ds_quantize_(param_data)

    def local_moe_fused(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route tokens to their top-k experts and run the fused expert kernel.

        Args:
            hidden_states: 2-D tensor; its shape is unpacked as
                ``(num_tokens, hidden_size)``.

        Returns:
            Tensor viewed back to ``(num_tokens, hidden_size)``, all-reduced
            across tensor-parallel ranks when ``reduce_results`` is set and
            ``tp_size > 1``.
        """
        num_tokens, hidden_size = hidden_states.shape
        hidden_states = hidden_states.view(-1, self.hidden_size)
        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.gate(hidden_states)
        do_normalize = self.top_k > 1
        # The third return value (token/expert indices) is not needed here.
        topk_weights, topk_ids, _ = fused_topk(
            hidden_states, router_logits, self.top_k, renormalize=do_normalize
        )
        # topk_ids: (num_tokens, k)
        w13, w2 = self.ws, self.w2s
        if self.is_quant:
            if 2 * num_tokens <= self.num_experts:
                # If much fewer tokens than experts, use selective dequantize.
                flat_ids = topk_ids.flatten()
                w13 = self.ws.ds_selective_dequantize(flat_ids)
                w2 = self.w2s.ds_selective_dequantize(flat_ids)
                # We gathered the experts to the tokens so update the mapping.
                topk_ids = torch.arange(
                    0,
                    topk_ids.numel(),
                    device=topk_ids.device,
                ).reshape(topk_ids.shape)
            else:
                w13 = self.ws.ds_dequantize()
                w2 = self.w2s.ds_dequantize()

        final_hidden_states = fused_experts(
            hidden_states,
            w13,
            w2,
            topk_weights,
            topk_ids,
            inplace=True,
        )
        if self.reduce_results and self.tp_size > 1:
            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
        return final_hidden_states.view(num_tokens, hidden_size)

    def forward(self, hidden_states: torch.Tensor):
        """Dispatch to the routed experts or to the dense MLP of this layer."""
        if self.is_moe_layer:
            final_hidden_states = self.local_moe_fused(hidden_states)
        else:
            final_hidden_states = self.mlp(hidden_states)
        return final_hidden_states


class ArcticAttention(nn.Module):
    """Self-attention block: fused QKV projection, rotary embedding,
    attention, and output projection, with heads split across
    tensor-parallel ranks."""

    def __init__(
        self,
        config: ArcticConfig,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        """Derive per-rank head counts and build the sub-layers.

        Args:
            config: Model configuration (head counts, hidden size, rope
                parameters, maximum positions).
            cache_config: Forwarded to the ``Attention`` layer.
            quant_config: Forwarded to the linear and attention layers.
            prefix: Parameter-name prefix for the sub-layers.
        """
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size

        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = config.num_attention_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = config.num_key_value_heads
        # KV heads must either divide evenly across ranks or, when there are
        # fewer KV heads than ranks, the rank count must be a multiple of them.
        if self.total_num_kv_heads >= tp_size:
            assert self.total_num_kv_heads % tp_size == 0
        else:
            assert tp_size % self.total_num_kv_heads == 0
        # Every rank keeps at least one KV head.
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = self.hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim

        self.max_position_embeddings = config.max_position_embeddings
        self.scaling = self.head_dim**-0.5

        self.qkv_proj = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            self.hidden_size,
            bias=False,
            reduce_results=True,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=self.max_position_embeddings,
            rope_parameters=config.rope_parameters,
            is_neox_style=True,
        )

        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Project to Q/K/V, apply rotary embeddings to Q and K, attend, and
        project the result through ``o_proj``."""
        qkv, _ = self.qkv_proj(hidden_states)
        # The fused projection is laid out as [q | k | v] along the last dim.
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output


class ArcticDecoderLayer(nn.Module):
    """One Arctic decoder layer: self-attention followed by the MoE/MLP block,
    optionally with a parallel residual MLP on MoE layers."""

    def __init__(
        self,
        config: ArcticConfig,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        """Build attention, the MoE/MLP block, norms and the optional
        residual MLP.

        Args:
            config: Model configuration.
            cache_config: Forwarded to the attention block.
            quant_config: Forwarded to the attention and MoE blocks.
            prefix: Parameter-name prefix; the layer index is extracted from
                it.
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        layer_idx = extract_layer_index(prefix)
        is_moe_layer = (layer_idx + 1) % config.moe_layer_frequency == 0
        # The residual MLP path only exists on MoE layers.
        self.use_residual = config.use_residual and is_moe_layer
        self.self_attn = ArcticAttention(
            config,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )
        # With the residual path active the MoE output is left un-reduced so
        # that ``forward`` can all-reduce it together with the residual MLP
        # output in a single collective.
        self.block_sparse_moe = ArcticMoE(
            config,
            quant_config=quant_config,
            reduce_results=(not self.use_residual),
            prefix=f"{prefix}.block_sparse_moe",
        )

        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

        if self.use_residual:
            self.residual_layernorm = RMSNorm(
                config.hidden_size, eps=config.rms_norm_eps
            )
            self.residual_mlp = ArcticMLP(
                config,
                is_residual_mlp=True,
                reduce_results=False,
                prefix=f"{prefix}.residual_mlp",
            )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run the layer and return the updated hidden states.

        When ``use_residual`` is set, the residual MLP consumes the
        post-attention stream while the MoE block consumes the normalized
        layer *input*; their un-reduced outputs are summed, all-reduced once,
        and added to the post-attention stream.
        """
        residual_input = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
        )
        hidden_states = residual_input + hidden_states

        residual_attn = hidden_states
        if self.use_residual:
            hidden_states = self.residual_layernorm(hidden_states)
            hidden_states = self.residual_mlp(hidden_states)
            residual_mlp = hidden_states
            # The MoE branch reads the layer input, not the attention output.
            hidden_states = self.post_attention_layernorm(residual_input)
            hidden_states = self.block_sparse_moe(hidden_states)
            hidden_states = residual_mlp + hidden_states
            # Both branches were built with reduce_results=False; reduce once.
            hidden_states = tensor_model_parallel_all_reduce(hidden_states)
            hidden_states = residual_attn + hidden_states
        else:
            hidden_states = self.post_attention_layernorm(hidden_states)
            hidden_states = self.block_sparse_moe(hidden_states)
            hidden_states = residual_attn + hidden_states
        return hidden_states


418
@support_torch_compile
class ArcticModel(nn.Module):
    """Arctic transformer backbone: token embedding, decoder layers owned by
    this pipeline-parallel rank, and the final RMSNorm."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the embedding, the decoder layers and the final norm.

        Args:
            vllm_config: Engine configuration; the HF model config, cache
                config and quantization config are read from it.
            prefix: Parameter-name prefix for the decoder layers.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.vocab_size = config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size, config.hidden_size, org_num_embeddings=self.vocab_size
        )
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: ArcticDecoderLayer(
                config, cache_config, quant_config, prefix=prefix
            ),
            prefix=f"{prefix}.layers",
        )
        self._attn_implementation = config._attn_implementation
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states"], config.hidden_size
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run this rank's slice of the decoder stack.

        Args:
            input_ids: Token ids; used on the first pipeline rank when
                ``inputs_embeds`` is not given.
            positions: Position tensor passed to every decoder layer.
            intermediate_tensors: Hidden states handed over from the previous
                pipeline rank; required on every rank except the first.
            inputs_embeds: Optional precomputed embeddings that replace the
                embedding lookup on the first rank.

        Returns:
            ``IntermediateTensors`` holding ``"hidden_states"`` on non-final
            ranks, otherwise the normalized hidden states.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.embed_input_ids(input_ids)
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
        for layer in islice(self.layers, self.start_layer, self.end_layer):
            hidden_states = layer(positions, hidden_states)
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})
        hidden_states = self.norm(hidden_states)
        return hidden_states


470
471
class ArcticForCausalLM(nn.Module, SupportsPP, SupportsQuant):
    """Arctic backbone plus LM head, logits processor and checkpoint loader."""

    packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the backbone, the LM head and the logits processor.

        Args:
            vllm_config: Engine configuration; the HF model config and the
                quantization config are read from it.
            prefix: Parameter-name prefix for the sub-modules.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.model = ArcticModel(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )
        self.vocab_size = config.vocab_size
        self.lm_head = ParallelLMHead(
            self.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        if self.config.tie_word_embeddings:
            # Share one tensor between the input embedding and the LM head.
            self.lm_head.weight = self.model.embed_tokens.weight
        self.num_experts = config.num_local_experts
        self.num_experts_per_tok = config.num_experts_per_tok

        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate the embedding lookup to the backbone."""
        return self.model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the backbone and return whatever it returns (hidden states, or
        ``IntermediateTensors`` on non-final pipeline ranks)."""
        hidden_states = self.model(
            input_ids, positions, intermediate_tensors, inputs_embeds
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Turn hidden states into logits via the LM head."""
        logits = self.logits_processor(self.lm_head, hidden_states)
        return logits

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this model's parameters.

        Each checkpoint name is matched, in order, against the stacked QKV
        mapping, the fused ``w13`` MLP mapping and the stacked-expert mapping;
        names that match none of them are loaded directly. Parameters that
        belong to another pipeline rank are skipped.

        Args:
            weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

        Returns:
            The set of (remapped) parameter names that were processed.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]

        mlp_params_mapping: list[tuple[str, str, int]] = []
        expert_params_mapping: list[tuple[str, str, int]] = []

        # Use the same MoE-layer rule as ArcticMoE / ArcticDecoderLayer rather
        # than assuming every second layer is MoE.
        moe_layer_frequency = self.config.moe_layer_frequency
        has_moe_layer = False
        for layer in range(self.config.num_hidden_layers):
            mlp_params_mapping.append(
                (
                    f"layers.{layer}.residual_mlp.w13.weight",
                    f"layers.{layer}.residual_mlp.w1.weight",
                    0,
                )
            )
            mlp_params_mapping.append(
                (
                    f"layers.{layer}.residual_mlp.w13.weight",
                    f"layers.{layer}.residual_mlp.w3.weight",
                    1,
                )
            )
            if (layer + 1) % moe_layer_frequency == 0:
                # MoE layers
                has_moe_layer = True
            else:
                # MLP layers
                mlp_params_mapping.append(
                    (
                        f"layers.{layer}.block_sparse_moe.mlp.w13.weight",
                        f"layers.{layer}.block_sparse_moe.mlp.w1.weight",
                        0,
                    )
                )
                mlp_params_mapping.append(
                    (
                        f"layers.{layer}.block_sparse_moe.mlp.w13.weight",
                        f"layers.{layer}.block_sparse_moe.mlp.w3.weight",
                        1,
                    )
                )

        # The expert entries carry no layer index, so one copy serves every
        # MoE layer; build it once instead of once per MoE layer.
        if has_moe_layer:
            for expert_id in range(self.config.num_local_experts):
                expert_params_mapping.append(
                    ("ws", f"experts.{expert_id}.w1.weight", expert_id)
                )
                expert_params_mapping.append(
                    ("w2s", f"experts.{expert_id}.w2.weight", expert_id)
                )
                expert_params_mapping.append(
                    ("ws", f"experts.{expert_id}.w3.weight", expert_id)
                )

        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        logger.info(
            "It will take ~10 minutes loading from the 16-bit weights. "
            "Alternatively, use the prequantized 8-bit weights of arctic "
            "and set load-format to `sharded_state` will accelerate loading."
        )
        # NOTE: the nested for/else chain is order-sensitive. ``name`` is
        # rewritten as soon as a mapping matches, and a ``continue`` inside an
        # inner loop keeps scanning with the rewritten name; only the
        # ``continue`` statements in the innermost ``else`` skip to the next
        # checkpoint tensor.
        for name, loaded_weight in weights:
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                for param_name, weight_name, shard_id in mlp_params_mapping:
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    if is_pp_missing_parameter(name, self):
                        continue
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(param, loaded_weight, shard_id)
                    break
                else:
                    for param_name, weight_name, shard_id in expert_params_mapping:
                        if weight_name not in name:
                            continue
                        name = name.replace(weight_name, param_name)
                        if is_pp_missing_parameter(name, self):
                            continue
                        param = params_dict[name]
                        weight_loader = param.weight_loader
                        # Here ``shard_id`` is the expert index.
                        weight_loader(
                            param, loaded_weight, weight_name, expert_id=shard_id
                        )
                        break
                    else:
                        if name.endswith(".bias") and name not in params_dict:
                            continue
                        if is_pp_missing_parameter(name, self):
                            continue
                        param = params_dict[name]

                        weight_loader = getattr(
                            param, "weight_loader", default_weight_loader
                        )
                        weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params