olmoe.py 18.1 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

Niklas Muennighoff's avatar
Niklas Muennighoff committed
4
5
6
7
8
9
10
11
12
13
14
15
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only OLMoE model compatible with HuggingFace weights."""
16

17
from collections.abc import Iterable
18
from functools import partial
19
from itertools import islice
Niklas Muennighoff's avatar
Niklas Muennighoff committed
20
21
22
23

import torch
from torch import nn

24
from vllm.compilation.decorators import support_torch_compile
25
from vllm.config import VllmConfig
26
27
28
29
30
31
from vllm.distributed import (
    get_pp_group,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
    tensor_model_parallel_all_gather,
)
32
from vllm.distributed.utils import split_tensor_along_last_dim
33
from vllm.logger import init_logger
34
from vllm.model_executor.layers.attention import Attention
35
36
37
38
from vllm.model_executor.layers.fused_moe import (
    FusedMoE,
    fused_moe_make_expert_params_mapping,
)
Niklas Muennighoff's avatar
Niklas Muennighoff committed
39
from vllm.model_executor.layers.layernorm import RMSNorm
40
41
42
43
44
from vllm.model_executor.layers.linear import (
    QKVParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
Niklas Muennighoff's avatar
Niklas Muennighoff committed
45
from vllm.model_executor.layers.logits_processor import LogitsProcessor
46
from vllm.model_executor.layers.quantization import QuantizationConfig
Niklas Muennighoff's avatar
Niklas Muennighoff committed
47
48
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
49
50
51
    ParallelLMHead,
    VocabParallelEmbedding,
)
Niklas Muennighoff's avatar
Niklas Muennighoff committed
52
53
54
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.sequence import IntermediateTensors

55
from .interfaces import SupportsLoRA, SupportsPP
56
57
58
59
60
61
62
from .utils import (
    AutoWeightsLoader,
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)
63

64
65
logger = init_logger(__name__)

Niklas Muennighoff's avatar
Niklas Muennighoff committed
66
67
68
69
70
71
72
73
74
75

class OlmoeMoE(nn.Module):
    """A tensor-parallel MoE implementation for Olmoe that shards each expert
    across all ranks.

    Each expert's weights are sharded across all ranks and a fused MoE
    kernel is used for the forward pass, and finally we reduce the outputs
    across ranks.

    Args:
        num_experts: Number of routed experts.
        top_k: Number of experts each token is routed to.
        hidden_size: Width of the token representations entering/leaving
            the layer.
        intermediate_size: Intermediate width handed to ``FusedMoE``.
        params_dtype: Accepted for signature compatibility; not used here.
        quant_config: Quantization config for the expert weights. The router
            gate is always built unquantized.
        tp_size: Tensor-parallel size forwarded to ``FusedMoE``.
        prefix: Module-name prefix used to name the sub-layers' weights.
    """

    def __init__(
        self,
        num_experts: int,
        top_k: int,
        hidden_size: int,
        intermediate_size: int,
        params_dtype: torch.dtype | None = None,
        quant_config: QuantizationConfig | None = None,
        tp_size: int | None = None,
        prefix: str = "",
    ):
        super().__init__()
        self.hidden_size = hidden_size

        # Gate always runs at half / full precision for now.
        self.gate = ReplicatedLinear(
            hidden_size,
            num_experts,
            bias=False,
            quant_config=None,
            prefix=f"{prefix}.gate",
        )

        # Each rank only holds a shard of every expert, so the expert output
        # computed here is a partial sum that must be all-reduced across TP
        # ranks. FusedMoE only performs that reduction when asked to
        # (reduce_results defaults to False), and nothing after this layer
        # reduces it, so request it explicitly.
        self.experts = FusedMoE(
            num_experts=num_experts,
            top_k=top_k,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            reduce_results=True,
            renormalize=False,
            quant_config=quant_config,
            tp_size=tp_size,
            prefix=f"{prefix}.experts",
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route every token through its top-k experts.

        Leading dimensions are flattened into a single token axis for the
        gate/experts and restored on the way out, so the result has the same
        shape as ``hidden_states``.
        """
        # NOTE: hidden_states can have either 1D or 2D shape.
        orig_shape = hidden_states.shape
        hidden_dim = hidden_states.shape[-1]
        hidden_states = hidden_states.view(-1, hidden_dim)
        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.gate(hidden_states)
        final_hidden_states = self.experts(
            hidden_states=hidden_states, router_logits=router_logits
        )
        return final_hidden_states.view(orig_shape)


class OlmoeAttention(nn.Module):
    """OLMoE self-attention with QK RMSNorm applied over the full head width.

    ``q_norm``/``k_norm`` are sized for *all* heads rather than the local
    tensor-parallel shard, so under TP the shards are gathered, normalized as
    a whole, and re-split before rotary embedding and attention.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__()

        hf_cfg = vllm_config.model_config.hf_config
        cache_cfg = vllm_config.cache_config
        quant_cfg = vllm_config.quant_config

        self.hidden_size = hf_cfg.hidden_size
        max_pos = getattr(hf_cfg, "max_position_embeddings", 4096)

        all_q_heads = hf_cfg.num_attention_heads
        all_kv_heads = hf_cfg.num_key_value_heads

        world = get_tensor_model_parallel_world_size()

        # Query heads must split evenly over the TP ranks.
        self.total_num_heads = all_q_heads
        assert self.total_num_heads % world == 0
        self.num_heads = self.total_num_heads // world

        # KV heads are replicated when there are fewer of them than ranks and
        # partitioned otherwise; either way one count must divide the other.
        self.total_num_kv_heads = all_kv_heads
        if self.total_num_kv_heads < world:
            assert world % self.total_num_kv_heads == 0
        else:
            assert self.total_num_kv_heads % world == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // world)

        self.head_dim = self.hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.max_position_embeddings = max_pos

        self.qkv_proj = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_cfg,
            prefix=f"{prefix}.qkv_proj",
        )
        self.tp_size = world
        self.tp_rank = get_tensor_model_parallel_rank()

        # Norms span every head, not just this rank's slice.
        full_q_width = self.total_num_heads * self.head_dim
        full_k_width = self.total_num_kv_heads * self.head_dim
        self.q_norm = RMSNorm(full_q_width, eps=1e-5)
        self.k_norm = RMSNorm(full_k_width, eps=1e-5)

        self.o_proj = RowParallelLinear(
            full_q_width,
            self.hidden_size,
            bias=False,
            quant_config=quant_cfg,
            prefix=f"{prefix}.o_proj",
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            max_position=max_pos,
            rope_parameters=hf_cfg.rope_parameters,
            is_neox_style=True,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_cfg,
            quant_config=quant_cfg,
            prefix=f"{prefix}.attn",
        )

    def _apply_qk_norm(
        self, q: torch.Tensor, k: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """RMS-normalize q and k over the full (unsharded) head dimension.

        With TP > 1 each rank only has its own slice of heads, so the slices
        are all-gathered, normalized together, and this rank's slice of the
        result is returned.

        NOTE(review): when total_num_kv_heads < tp_size, each rank holds one
        replicated KV head, so the gathered k is presumably tp_size * head_dim
        wide while k_norm expects total_num_kv_heads * head_dim. Confirm
        whether that configuration is supported.
        """
        sharded = self.tp_size > 1
        if sharded:
            q = tensor_model_parallel_all_gather(q.contiguous())
            k = tensor_model_parallel_all_gather(k.contiguous())
        q, k = self.q_norm(q), self.k_norm(k)
        if not sharded:
            return q, k
        split = partial(split_tensor_along_last_dim, num_partitions=self.tp_size)
        return split(q)[self.tp_rank], split(k)[self.tp_rank]

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Project to q/k/v, QK-normalize, apply RoPE, attend, project out."""
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, *self._apply_qk_norm(q, k))
        projected, _ = self.o_proj(self.attn(q, k, v))
        return projected


class OlmoeDecoderLayer(nn.Module):
    """One OLMoE transformer block.

    Pre-norm self-attention followed by a pre-norm sparse-MoE feed-forward.
    The residual stream is carried through the ``RMSNorm`` calls that accept
    and return a residual tensor.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config

        self.hidden_size = config.hidden_size

        self.self_attn = OlmoeAttention(
            vllm_config=vllm_config,
            prefix=f"{prefix}.self_attn",
        )

        self.mlp = OlmoeMoE(
            num_experts=config.num_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
        )
        self.input_layernorm = RMSNorm(config.hidden_size, eps=1e-5)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=1e-5)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run attention and the MoE MLP for one layer.

        Args:
            positions: Token positions, forwarded to the attention block.
            hidden_states: Output of the previous layer (or the embeddings).
            residual: Running residual stream, or ``None`` for the first
                layer executed, in which case ``hidden_states`` seeds it.

        Returns:
            ``(hidden_states, residual)``. The returned ``hidden_states`` is
            the MoE output *without* the residual added; the caller's next
            residual-taking norm is expected to fold it in.
        """
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)

        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual


270
@support_torch_compile
class OlmoeModel(nn.Module):
    """OLMoE decoder stack: token embedding, decoder layers, final RMSNorm.

    Pipeline-parallel aware: only layers ``start_layer``..``end_layer`` are
    run on this rank. Non-final ranks hand ``hidden_states`` and ``residual``
    to the next stage as ``IntermediateTensors``.
    """

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        layer_type: type[nn.Module] = OlmoeDecoderLayer,
    ):
        super().__init__()

        hf_cfg = vllm_config.model_config.hf_config

        self.vocab_size = hf_cfg.vocab_size
        self.config = hf_cfg

        self.embed_tokens = VocabParallelEmbedding(
            hf_cfg.vocab_size,
            hf_cfg.hidden_size,
        )

        # make_layers supplies each layer's prefix by keyword, so the
        # parameter must be called `prefix`.
        def build_layer(prefix: str) -> nn.Module:
            return layer_type(vllm_config=vllm_config, prefix=prefix)

        self.start_layer, self.end_layer, self.layers = make_layers(
            hf_cfg.num_hidden_layers,
            build_layer,
            prefix=f"{prefix}.layers",
        )
        self.norm = RMSNorm(hf_cfg.hidden_size, eps=1e-5)

        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], hf_cfg.hidden_size
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run this rank's slice of the decoder.

        Returns the final-normed hidden states on the last PP rank, otherwise
        an ``IntermediateTensors`` with ``hidden_states`` and ``residual``.
        """
        pp_group = get_pp_group()

        if not pp_group.is_first_rank:
            # Later stages resume from what the previous stage handed over.
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        else:
            hidden_states = (
                self.embed_input_ids(input_ids)
                if inputs_embeds is None
                else inputs_embeds
            )
            residual = None

        for decoder_layer in islice(self.layers, self.start_layer, self.end_layer):
            hidden_states, residual = decoder_layer(positions, hidden_states, residual)

        if not pp_group.is_last_rank:
            return IntermediateTensors(
                {"hidden_states": hidden_states, "residual": residual}
            )

        if residual is None:
            return self.norm(hidden_states)
        normed, _ = self.norm(hidden_states, residual)
        return normed

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        """Return ``(param_name, weight_name, expert_id, shard_id)`` tuples
        mapping per-expert checkpoint projections onto the fused MoE params
        (weights, fp8 weight scales, fp8 activation scales)."""
        return fused_moe_make_expert_params_mapping(
            self,
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.num_experts,
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this model's parameters.

        Each checkpoint name is routed through one of three paths:

        1. q/k/v projections are renamed onto the fused ``qkv_proj``.
        2. Per-expert projections go through ``get_expert_mapping()``.
        3. Everything else loads directly, with ``*.kv_scale`` renamed to
           ``*.attn.kv_scale``.

        Returns:
            The set of (renamed) parameter names that received a tensor.
        """
        # (fused_param_name, checkpoint_shard_name, shard_id)
        qkv_shards = [
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]

        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        expert_mapping = self.get_expert_mapping()

        for name, loaded_weight in weights:
            for fused_name, shard_name, shard_id in qkv_shards:
                # Expert tensors must reach the expert mapping below with their
                # original checkpoint name, so leave them alone here (renaming
                # first would yield e.g. "gate_gate_up_proj" and break loading).
                if shard_name not in name or "mlp.experts" in name:
                    continue
                name = name.replace(shard_name, fused_name)
                # Extra GPTQ biases, parameters living on other PP ranks, and
                # anything this model lacks are skipped.
                skip = (
                    (name.endswith(".bias") and name not in params_dict)
                    or is_pp_missing_parameter(name, self)
                    or name not in params_dict
                )
                if skip:
                    continue
                param = params_dict[name]
                param.weight_loader(param, loaded_weight, shard_id)
                break
            else:
                for param_name, weight_name, expert_id, shard_id in expert_mapping:
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    # Parameter lives on another pipeline stage.
                    if is_pp_missing_parameter(name, self):
                        continue
                    param = params_dict[name]
                    param.weight_loader(
                        param,
                        loaded_weight,
                        name,
                        shard_id=shard_id,
                        expert_id=expert_id,
                    )
                    break
                else:
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue
                    # Parameter lives on another pipeline stage.
                    if is_pp_missing_parameter(name, self):
                        continue
                    # FP8 kv-scales live under the attention module.
                    if name.endswith("kv_scale"):
                        kv_scale_target = name.replace(".kv_scale", ".attn.kv_scale")
                        if kv_scale_target not in params_dict:
                            logger.warning_once(
                                "Found kv scale in the checkpoint (e.g. %s), but not found the expected name in the model (e.g. %s). kv-scale is not loaded.",  # noqa: E501
                                name,
                                kv_scale_target,
                            )
                            continue
                        name = kv_scale_target

                    param = params_dict[name]
                    loader = getattr(param, "weight_loader", default_weight_loader)
                    loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params
436
437


438
class OlmoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
    """OLMoE causal language model: ``OlmoeModel`` plus a vocab-parallel LM
    head and logits processor."""

    # The fused qkv_proj parameter is assembled from these checkpoint
    # projections (presumably consumed by the LoRA/packing machinery).
    packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        layer_type: type[nn.Module] = OlmoeDecoderLayer,
    ):
        super().__init__()
        self.config = config = vllm_config.model_config.hf_config
        self.quant_config = quant_config = vllm_config.quant_config

        self.model = OlmoeModel(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "model"),
            layer_type=layer_type,
        )
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        self.logits_processor = LogitsProcessor(config.vocab_size)

        # Share the inner model's factory so PP stages agree on the layout.
        empty_factory = self.model.make_empty_intermediate_tensors
        self.make_empty_intermediate_tensors = empty_factory

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate token-embedding lookup to the inner model."""
        return self.model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the decoder; see ``OlmoeModel.forward`` for the return value."""
        return self.model(input_ids, positions, intermediate_tensors, inputs_embeds)

    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Project final hidden states to vocabulary logits via the LM head."""
        return self.logits_processor(self.lm_head, hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights via ``AutoWeightsLoader``."""
        return AutoWeightsLoader(self).load_weights(weights)

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        """Expose the inner model's checkpoint-to-expert parameter mapping."""
        return self.model.get_expert_mapping()