arctic.py 23.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
"""Inference-only Snowflake Arctic model."""
3
from typing import Iterable, List, Optional, Set, Tuple, Union
4
5
6
7

import torch
from torch import nn

8
from vllm.attention import Attention
9
from vllm.compilation.decorators import support_torch_compile
10
from vllm.config import CacheConfig, VllmConfig
11
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
12
13
14
15
16
17
18
19
20
21
22
                              get_tensor_model_parallel_world_size,
                              tensor_model_parallel_all_reduce)
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               ReplicatedLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
23
from vllm.model_executor.layers.quantization import QuantizationConfig
24
25
26
27
28
29
30
31
from vllm.model_executor.layers.quantization.deepspeedfp import (
    DeepSpeedFPConfig, DeepSpeedFPParameter)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs
32
from vllm.platforms import current_platform
33
from vllm.sequence import IntermediateTensors
34
35
from vllm.transformers_utils.configs.arctic import ArcticConfig

36
from .interfaces import SupportsPP, SupportsQuant
37
from .utils import (extract_layer_index, is_pp_missing_parameter,
38
39
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)
40

41
42
43
44
45
46
47
48
49
50
logger = init_logger(__name__)


class ArcticMLP(nn.Module):
    """Gated SiLU feed-forward block used by Arctic.

    ``w1`` and ``w3`` are fused into one column-parallel projection
    (``w13``) whose output is fed through ``SiluAndMul`` and then projected
    back to ``hidden_size`` by the row-parallel ``w2``.

    Args:
        config: Arctic model configuration; ``hidden_act`` must be "silu".
        expert_id: Stored as ``self.expert_id``; not otherwise used here.
        is_residual_mlp: When True the inner width is ``hidden_size``
            instead of ``config.intermediate_size``.
        quant_config: Optional quantization config for the linear layers.
        reduce_results: Whether ``w2`` all-reduces its output across
            tensor-parallel ranks.
        prefix: Module name prefix, forwarded to the linear layers.

    Raises:
        ValueError: If ``config.hidden_act`` is not "silu".
    """

    def __init__(self,
                 config: ArcticConfig,
                 expert_id: int = -1,
                 is_residual_mlp: bool = False,
                 quant_config: Optional[QuantizationConfig] = None,
                 reduce_results: bool = True,
                 prefix: str = ""):
        super().__init__()
        # Fail fast, before allocating any (possibly large) weights.
        if config.hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {config.hidden_act}. "
                             "Only silu is supported for now.")
        self.hidden_size = config.hidden_size
        self.expert_id = expert_id

        # The residual MLP keeps the hidden width; the regular MLP expands
        # to ``intermediate_size``.
        self.ffn_dim = (self.hidden_size
                        if is_residual_mlp else config.intermediate_size)

        # Fused [w1; w3] projection: two output shards of ``ffn_dim`` each.
        self.w13 = MergedColumnParallelLinear(self.hidden_size,
                                              [self.ffn_dim] * 2,
                                              bias=False,
                                              quant_config=quant_config,
                                              prefix=f"{prefix}.w13")
        self.w2 = RowParallelLinear(self.ffn_dim,
                                    self.hidden_size,
                                    bias=False,
                                    reduce_results=reduce_results,
                                    quant_config=quant_config,
                                    prefix=f"{prefix}.w2")
        self.act_fn = SiluAndMul()

    def forward(self, hidden_states):
        """Apply ``w2(SiluAndMul(w13(hidden_states)))``."""
        gate_up, _ = self.w13(hidden_states)
        hidden_states = self.act_fn(gate_up)
        hidden_states, _ = self.w2(hidden_states)
        return hidden_states


class ArcticMoE(nn.Module):
    """Model-parallel implementation of the Arctic MoE layer.

    Whether this module is a real MoE layer depends on the layer index parsed
    from ``prefix``: layer ``i`` is MoE when
    ``(i + 1) % config.moe_layer_frequency == 0``. Otherwise it simply wraps
    a dense ``ArcticMLP``.

    MoE layers hold a replicated router (``gate``) and two stacked expert
    weights:

    * ``ws``  -- fused w1/w3, shape
      ``(num_experts, 2 * intermediate_size, hidden_size)``
    * ``w2s`` -- down projection, shape
      ``(num_experts, hidden_size, intermediate_size)``

    where ``intermediate_size`` is already divided by the TP world size.
    With a ``DeepSpeedFPConfig`` the expert weights are
    ``DeepSpeedFPParameter`` objects that are dequantized on every forward.
    """

    def __init__(self,
                 config: ArcticConfig,
                 tp_size: Optional[int] = None,
                 params_dtype: Optional[torch.dtype] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 reduce_results: bool = True,
                 prefix: str = ""):
        super().__init__()

        layer_id = extract_layer_index(prefix)
        self.tp_size = tp_size or get_tensor_model_parallel_world_size()
        self.hidden_size = config.hidden_size
        self.num_experts = config.num_local_experts
        self.layer_id = layer_id
        self.top_k = config.num_experts_per_tok
        # Per-rank slice of each expert's FFN width.
        self.intermediate_size = config.intermediate_size // self.tp_size

        self.is_moe_layer = (layer_id + 1) % config.moe_layer_frequency == 0
        self.is_quant = isinstance(quant_config, DeepSpeedFPConfig)
        self.reduce_results = reduce_results
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.params_dtype = params_dtype

        if not self.is_moe_layer:
            self.mlp = ArcticMLP(config,
                                 quant_config=quant_config,
                                 reduce_results=reduce_results,
                                 prefix=f"{prefix}.mlp")
        else:
            self.gate = ReplicatedLinear(self.hidden_size,
                                         self.num_experts,
                                         bias=False,
                                         params_dtype=self.params_dtype,
                                         quant_config=quant_config,
                                         prefix=f"{prefix}.gate")
            ws_shape = (self.num_experts, 2 * self.intermediate_size,
                        self.hidden_size)
            w2s_shape = (self.num_experts, self.hidden_size,
                         self.intermediate_size)
            if self.is_quant:
                self.ws = DeepSpeedFPParameter(
                    torch.Size(ws_shape),
                    params_dtype=params_dtype,
                    quant_config=quant_config,
                )
                self.w2s = DeepSpeedFPParameter(
                    torch.Size(w2s_shape),
                    params_dtype=params_dtype,
                    quant_config=quant_config,
                )
            else:
                self.ws = nn.Parameter(
                    torch.empty(ws_shape,
                                device=current_platform.device_type,
                                dtype=self.params_dtype))
                self.w2s = nn.Parameter(
                    torch.empty(w2s_shape,
                                device=current_platform.device_type,
                                dtype=self.params_dtype))
            # Expert tensors are filled one expert at a time by
            # ``weight_loader``.
            set_weight_attrs(self.ws, {
                "weight_loader": self.weight_loader,
            })
            set_weight_attrs(self.w2s, {
                "weight_loader": self.weight_loader,
            })

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
                      weight_name: str, expert_id: int):
        """Copy one expert's checkpoint tensor into a stacked parameter.

        Args:
            param: ``self.ws`` or ``self.w2s``.
            loaded_weight: Checkpoint tensor for a single expert projection;
                this rank's shard is sliced out of it here.
            weight_name: Checkpoint name; its ``w1``/``w3``/``w2`` suffix
                selects the destination slice.
            expert_id: Index of the expert along dim 0 of ``param``.
        """
        tp_rank = get_tensor_model_parallel_rank()
        # DeepSpeedFP parameters are dequantized, patched and re-quantized;
        # plain parameters are written in place.
        param_data = param.ds_dequantize() if self.is_quant else param.data
        shard_size = self.intermediate_size
        shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
        if weight_name.endswith("w1.weight"):
            # w1 -> first half of the fused w13 rows.
            param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
        elif weight_name.endswith("w3.weight"):
            # w3 -> second half of the fused w13 rows.
            param_data[expert_id,
                       shard_size:2 * shard_size, :] = loaded_weight[shard, :]
        elif weight_name.endswith("w2.weight"):
            # w2 is sharded along its input (intermediate) dimension.
            param_data[expert_id, :, :] = loaded_weight[:, shard]
        if self.is_quant:
            param.ds_quantize_(param_data)

    def local_moe_fused(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route tokens to their top-k experts and run the fused MoE kernel.

        The result has the same shape as ``hidden_states``. It is all-reduced
        across TP ranks only when ``reduce_results`` is set and TP > 1.
        """
        num_tokens, hidden_size = hidden_states.shape
        hidden_states = hidden_states.view(-1, self.hidden_size)
        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.gate(hidden_states)
        # With a single expert per token there is nothing to renormalize.
        topk_weights, topk_ids, _ = fused_topk(hidden_states,
                                               router_logits,
                                               self.top_k,
                                               renormalize=self.top_k > 1)
        # topk_ids: (num_tokens, k)
        if not self.is_quant:
            w13_weights, w2_weights = self.ws, self.w2s
        elif 2 * num_tokens <= self.num_experts:
            # Much fewer tokens than experts: dequantize only the selected
            # experts, gathered in topk_ids order.
            selected = topk_ids.flatten()
            w13_weights = self.ws.ds_selective_dequantize(selected)
            w2_weights = self.w2s.ds_selective_dequantize(selected)
            # Gathered expert j now serves the j-th (token, slot) pair, so
            # remap the routing ids accordingly.
            topk_ids = torch.arange(
                0,
                topk_ids.numel(),
                device=topk_ids.device,
            ).reshape(topk_ids.shape)
        else:
            w13_weights = self.ws.ds_dequantize()
            w2_weights = self.w2s.ds_dequantize()

        final_hidden_states = fused_experts(hidden_states,
                                            w13_weights,
                                            w2_weights,
                                            topk_weights,
                                            topk_ids,
                                            inplace=True)
        if self.reduce_results and self.tp_size > 1:
            final_hidden_states = tensor_model_parallel_all_reduce(
                final_hidden_states)
        return final_hidden_states.view(num_tokens, hidden_size)

    def forward(self, hidden_states: torch.Tensor):
        """Run the MoE path on MoE layers, the dense MLP otherwise."""
        if self.is_moe_layer:
            return self.local_moe_fused(hidden_states)
        return self.mlp(hidden_states)


class ArcticAttention(nn.Module):
    """Tensor-parallel self-attention with rotary position embeddings.

    Query heads are split evenly across TP ranks. Key/value heads are split
    when there are at least as many as ranks; otherwise each rank keeps one
    KV head (``max(1, ...)``), which requires ``tp_size`` to be a multiple
    of the KV-head count.
    """

    def __init__(
        self,
        config: ArcticConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size

        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = config.num_attention_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = config.num_key_value_heads
        if self.total_num_kv_heads >= tp_size:
            # KV heads are partitioned across ranks.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Fewer KV heads than ranks: ranks share KV heads.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = self.hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim

        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.scaling = self.head_dim**-0.5

        self.qkv_proj = QKVParallelLinear(self.hidden_size,
                                          self.head_dim,
                                          self.total_num_heads,
                                          self.total_num_kv_heads,
                                          bias=False,
                                          quant_config=quant_config,
                                          prefix=f"{prefix}.qkv_proj")
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            self.hidden_size,
            bias=False,
            reduce_results=True,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=self.max_position_embeddings,
            base=int(self.rope_theta),
            is_neox_style=True,
        )

        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn")

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Project to Q/K/V, apply RoPE, attend, and project back."""
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output


class ArcticDecoderLayer(nn.Module):
    """One Arctic transformer block.

    Pre-norm attention with a residual connection is followed by the
    ``block_sparse_moe`` feed-forward (MoE or dense, see ``ArcticMoE``).
    On MoE layers with ``config.use_residual`` a dense residual MLP runs in
    parallel with the MoE (see ``forward``).
    """

    def __init__(
        self,
        config: ArcticConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        layer_idx = extract_layer_index(prefix)
        # Same MoE-layer rule as ArcticMoE; the residual MLP only exists on
        # MoE layers.
        is_moe_layer = (layer_idx + 1) % config.moe_layer_frequency == 0
        self.use_residual = config.use_residual and is_moe_layer
        self.self_attn = ArcticAttention(config,
                                         cache_config,
                                         quant_config=quant_config,
                                         prefix=f"{prefix}.self_attn")
        # With a residual MLP, forward() sums the MoE and residual outputs
        # and does a single all-reduce, so the MoE skips its own reduction.
        self.block_sparse_moe = ArcticMoE(
            config,
            quant_config=quant_config,
            reduce_results=(not self.use_residual),
            prefix=f"{prefix}.block_sparse_moe",
        )

        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

        if self.use_residual:
            self.residual_layernorm = RMSNorm(config.hidden_size,
                                              eps=config.rms_norm_eps)
            # NOTE(review): quant_config is not forwarded here, so the
            # residual MLP is built without quantization -- confirm this is
            # intended.
            self.residual_mlp = ArcticMLP(config,
                                          is_residual_mlp=True,
                                          reduce_results=False,
                                          prefix=f"{prefix}.residual_mlp")

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run attention then the feed-forward part of the block.

        In the residual configuration the residual MLP reads the
        post-attention stream, while the MoE reads the layer's *input*
        (``residual_input``). Their partial sums are all-reduced once and
        added to the post-attention stream.
        """
        residual_input = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
        )
        hidden_states = residual_input + hidden_states

        residual_attn = hidden_states
        if self.use_residual:
            hidden_states = self.residual_layernorm(hidden_states)
            hidden_states = self.residual_mlp(hidden_states)
            residual_mlp = hidden_states
            # The MoE branch deliberately consumes the pre-attention input.
            hidden_states = self.post_attention_layernorm(residual_input)
            hidden_states = self.block_sparse_moe(hidden_states)
            hidden_states = residual_mlp + hidden_states
            # Both branches were built with reduce_results=False.
            hidden_states = tensor_model_parallel_all_reduce(hidden_states)
            hidden_states = residual_attn + hidden_states
        else:
            hidden_states = self.post_attention_layernorm(hidden_states)
            hidden_states = self.block_sparse_moe(hidden_states)
            hidden_states = residual_attn + hidden_states
        return hidden_states


360
@support_torch_compile
class ArcticModel(nn.Module):
    """Arctic decoder stack: embeddings, decoder layers and final norm.

    Under pipeline parallelism only ``layers[start_layer:end_layer]`` are
    run on this rank. Non-final ranks return their hidden states as
    ``IntermediateTensors`` instead of applying the final norm.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.vocab_size = config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
            org_num_embeddings=self.vocab_size)
        # make_layers returns the [start, end) layer range for this
        # pipeline stage together with the layer list.
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: ArcticDecoderLayer(
                config, cache_config, quant_config, prefix=prefix),
            prefix=f"{prefix}.layers")
        self._attn_implementation = config._attn_implementation
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(["hidden_states"],
                                                    config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Embed token ids with the vocab-parallel embedding table."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this pipeline stage's layers.

        The first stage starts from ``inputs_embeds`` (if given) or from
        embedded ``input_ids``. Later stages start from
        ``intermediate_tensors["hidden_states"]``.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
        for layer in self.layers[self.start_layer:self.end_layer]:
            hidden_states = layer(positions, hidden_states)
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})
        hidden_states = self.norm(hidden_states)
        return hidden_states


412
413
class ArcticForCausalLM(nn.Module, SupportsPP, SupportsQuant):
    """Snowflake Arctic causal LM: ``ArcticModel`` plus LM head and logits."""
    packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.model = ArcticModel(vllm_config=vllm_config,
                                 prefix=maybe_prefix(prefix, "model"))
        self.vocab_size = config.vocab_size
        self.lm_head = ParallelLMHead(
            self.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        if self.config.tie_word_embeddings:
            self.lm_head.weight = self.model.embed_tokens.weight
        self.num_experts = config.num_local_experts
        self.num_experts_per_tok = config.num_experts_per_tok
        self.unpadded_vocab_size = config.vocab_size
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Embed token ids via the underlying model."""
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Return final hidden states (or intermediate tensors under PP)."""
        hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                   inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to vocabulary logits."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def _weight_name_mappings(
        self
    ) -> Tuple[List[Tuple[str, str, str]], List[Tuple[str, str, int]],
               List[Tuple[str, str, int]]]:
        """Build the (param_name, checkpoint_name, shard) tables for loading.

        Returns:
            ``(stacked, mlp, expert)`` mappings. ``stacked`` maps q/k/v onto
            ``qkv_proj``; ``mlp`` maps dense w1/w3 onto fused ``w13`` shards;
            ``expert`` maps per-expert w1/w3/w2 onto ``ws``/``w2s`` with the
            expert index as the shard.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]

        moe_layer_frequency = self.config.moe_layer_frequency
        mlp_params_mapping: List[Tuple[str, str, int]] = []
        for layer in range(self.config.num_hidden_layers):
            mlp_params_mapping.append(
                (f"layers.{layer}.residual_mlp.w13.weight",
                 f"layers.{layer}.residual_mlp.w1.weight", 0))
            mlp_params_mapping.append(
                (f"layers.{layer}.residual_mlp.w13.weight",
                 f"layers.{layer}.residual_mlp.w3.weight", 1))
            # Must match ArcticMoE.is_moe_layer: dense MLP layers are those
            # where (layer + 1) is NOT a multiple of moe_layer_frequency.
            # (Previously hard-coded as ``layer % 2 == 0``.)
            if (layer + 1) % moe_layer_frequency != 0:
                mlp_params_mapping.append(
                    (f"layers.{layer}.block_sparse_moe.mlp.w13.weight",
                     f"layers.{layer}.block_sparse_moe.mlp.w1.weight", 0))
                mlp_params_mapping.append(
                    (f"layers.{layer}.block_sparse_moe.mlp.w13.weight",
                     f"layers.{layer}.block_sparse_moe.mlp.w3.weight", 1))

        # Expert checkpoint names carry no layer index, so a single entry per
        # (expert, projection) covers every MoE layer.
        expert_params_mapping: List[Tuple[str, str, int]] = []
        for expert_id in range(self.config.num_local_experts):
            expert_params_mapping.append(
                ("ws", f"experts.{expert_id}.w1.weight", expert_id))
            expert_params_mapping.append(
                ("w2s", f"experts.{expert_id}.w2.weight", expert_id))
            expert_params_mapping.append(
                ("ws", f"experts.{expert_id}.w3.weight", expert_id))

        return (stacked_params_mapping, mlp_params_mapping,
                expert_params_mapping)

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint tensors into this model.

        Each checkpoint name is tried against four tiers in order; the first
        match wins:

        1. q/k/v projections -> fused ``qkv_proj`` shards.
        2. dense w1/w3 (residual MLP, non-MoE ``block_sparse_moe.mlp``)
           -> fused ``w13`` shards.
        3. per-expert w1/w3/w2 -> stacked ``ws``/``w2s`` through
           ``ArcticMoE.weight_loader``.
        4. anything else -> the parameter's own loader, or
           ``default_weight_loader``.

        Parameters owned by another pipeline stage are skipped.

        Returns:
            The set of parameter names that were loaded on this rank.
        """
        (stacked_params_mapping, mlp_params_mapping,
         expert_params_mapping) = self._weight_name_mappings()

        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()

        logger.info(
            "It will take ~10 minutes loading from the 16-bit weights. "
            "Alternatively, use the prequantized 8-bit weights of arctic "
            "and set load-format to `sharded_state` will accelerate loading.")
        for name, loaded_weight in weights:
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                for param_name, weight_name, shard_id in mlp_params_mapping:
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    if is_pp_missing_parameter(name, self):
                        continue
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(param, loaded_weight, shard_id)
                    break
                else:
                    for param_name, weight_name, shard_id \
                            in expert_params_mapping:
                        if weight_name not in name:
                            continue
                        name = name.replace(weight_name, param_name)
                        if is_pp_missing_parameter(name, self):
                            continue
                        param = params_dict[name]
                        weight_loader = param.weight_loader
                        weight_loader(param,
                                      loaded_weight,
                                      weight_name,
                                      expert_id=shard_id)
                        break
                    else:
                        if name.endswith(".bias") and name not in params_dict:
                            continue
                        if is_pp_missing_parameter(name, self):
                            continue
                        param = params_dict[name]
                        weight_loader = getattr(param, "weight_loader",
                                                default_weight_loader)
                        weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params