# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Adapted from
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/olmo2/modeling_olmo2.py
# Copyright 2024 The vLLM team.
# Copyright 2024 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only OLMo2 model compatible with HuggingFace weights."""

from collections.abc import Iterable
from functools import partial
from itertools import islice

import torch
from torch import nn
from transformers import Olmo2Config, Olmo3Config

from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.distributed.communication_op import tensor_model_parallel_all_gather
from vllm.distributed.parallel_state import get_tensor_model_parallel_rank
from vllm.distributed.utils import split_tensor_along_last_dim
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
    MergedColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP
from vllm.model_executor.models.utils import (
    AutoWeightsLoader,
    extract_layer_index,
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)
from vllm.sequence import IntermediateTensors


class Olmo2Attention(nn.Module):
    """
    Self-attention sub-block of an OLMo2 / OLMo3 decoder layer.

    Computes `Attention(x)`: a fused QKV projection, RMSNorm over the query
    and key projections (QK-norm), rotary position embedding, attention and
    the output projection. No normalization is applied to the input here;
    `Olmo2DecoderLayer` normalizes this block's output and adds the skip
    connection, i.e. `x + LN(Attention(x))`.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """
        :param vllm_config: Engine configuration; the HF model config, the
            cache config and the quantization config are read from it.
        :param prefix: Dotted module path of this block. It names the
            sub-modules and is parsed for the layer index.
        """
        super().__init__()
        self.config = vllm_config.model_config.hf_config
        assert isinstance(self.config, (Olmo2Config, Olmo3Config))

        hidden_size = self.config.hidden_size
        self.tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = self.config.num_attention_heads

        assert hidden_size % self.total_num_heads == 0
        assert self.total_num_heads % self.tp_size == 0

        self.num_heads = self.total_num_heads // self.tp_size
        # Fall back to one KV head per query head when the config leaves
        # `num_key_value_heads` unset (or falsy).
        self.total_num_kv_heads = (
            self.config.num_key_value_heads or self.total_num_heads
        )
        if self.total_num_kv_heads >= self.tp_size:
            assert self.total_num_kv_heads % self.tp_size == 0
        else:
            assert self.tp_size % self.total_num_kv_heads == 0

        # Every rank keeps at least one KV head, even when there are fewer
        # KV heads than tensor-parallel ranks.
        self.num_kv_heads = max(1, self.total_num_kv_heads // self.tp_size)
        self.head_dim = hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.max_position_embeddings = self.config.max_position_embeddings

        # Attention input projection. Projects x -> (q, k, v)
        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=vllm_config.quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        # QK-norm. Both norms are sized for the full (unsharded) projection
        # width, which is why `_apply_qk_norm` gathers q and k across ranks
        # before normalizing.
        self.tp_rank = get_tensor_model_parallel_rank()
        self.k_norm = RMSNorm(
            self.total_num_kv_heads * self.head_dim,
            eps=self.config.rms_norm_eps,
        )
        self.q_norm = RMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps)

        self.scaling = self.head_dim**-0.5

        # A layer uses sliding-window attention only when the config carries
        # `layer_types` and marks this layer index as "sliding_attention".
        layer_idx = extract_layer_index(prefix)
        sliding_window = None
        if (
            layer_types := getattr(self.config, "layer_types", None)
        ) is not None and layer_types[layer_idx] == "sliding_attention":
            sliding_window = self.config.sliding_window

        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=vllm_config.cache_config,
            quant_config=vllm_config.quant_config,
            per_layer_sliding_window=sliding_window,
            prefix=f"{prefix}.attn",
        )

        # Rotary embeddings. Rope scaling is only applied on full attention
        # layers; sliding-window layers keep only `rope_theta` and use the
        # "default" rope type.
        if sliding_window is None:
            rope_parameters = self.config.rope_parameters
        else:
            rope_theta = self.config.rope_parameters["rope_theta"]
            rope_parameters = {"rope_type": "default", "rope_theta": rope_theta}
        self.rotary_emb = get_rope(
            self.head_dim,
            max_position=self.max_position_embeddings,
            rope_parameters=rope_parameters,
        )

        # Attention output projection.
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=vllm_config.quant_config,
            prefix=f"{prefix}.o_proj",
        )

    def _apply_qk_norm(
        self, q: torch.Tensor, k: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Apply `q_norm` / `k_norm` to the query and key projections.

        With tensor parallelism the per-rank shards are all-gathered first so
        the norms see the full projection, and afterwards only this rank's
        partition of the last dimension is kept.

        :param q: This rank's query projection.
        :param k: This rank's key projection.
        :return: The normalized `(q, k)` pair, sharded like the inputs.
        """
        if self.tp_size > 1:
            # NOTE(review): presumably gathers along the last dimension,
            # mirroring the split below - confirm against the helper.
            q = tensor_model_parallel_all_gather(q.contiguous())
            k = tensor_model_parallel_all_gather(k.contiguous())
        q = self.q_norm(q)
        k = self.k_norm(k)
        if self.tp_size > 1:
            splitter = partial(split_tensor_along_last_dim, num_partitions=self.tp_size)
            q = splitter(q)[self.tp_rank]
            k = splitter(k)[self.tp_rank]
        return q, k

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """
        :param positions: Position indices passed to the rotary embedding.
        :param hidden_states: Input activations of the block.
        :return: Output of the attention output projection.
        """
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        # QK-norm comes before the rotary embedding.
        q, k = self._apply_qk_norm(q, k)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output


class Olmo2MLP(nn.Module):
    """
    Feed-forward sub-block of an OLMo2 / OLMo3 decoder layer.

    Computes `MLP(x) = down_proj(act_fn(gate_up_proj(x)))`, where the gate
    and up projections are fused into one linear layer and `act_fn` is
    `SiluAndMul`. No normalization or skip connection is applied here;
    `Olmo2DecoderLayer` normalizes this block's output and adds the residual.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """
        :param vllm_config: Engine configuration; the HF model config and
            the quantization config are read from it.
        :param prefix: Dotted module path of this block, used to name the
            sub-modules.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        assert isinstance(config, (Olmo2Config, Olmo3Config))
        hidden_size = config.hidden_size
        intermediate_size = config.intermediate_size

        # Feed-forward input projection (gate and up fused into one matmul).
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size] * 2,
            bias=False,
            quant_config=vllm_config.quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )

        # Activation function.
        self.act_fn = SiluAndMul()

        # Feed-forward output projection.
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=vllm_config.quant_config,
            prefix=f"{prefix}.down_proj",
        )

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """
        :param x: Input activations of the block.
        :return: Output of the down projection.
        """
        # The linear layers return a tuple; only the first element is used.
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


class Olmo2DecoderLayer(nn.Module):
    """
    One OLMo2 / OLMo3 transformer block.

    Normalization is applied to the *output* of each sub-block, before the
    skip connection is added:

        h   = x + LN(Attention(x))
        out = h + LN(MLP(h))
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """
        :param vllm_config: Engine configuration, forwarded to the attention
            and MLP sub-blocks.
        :param prefix: Dotted module path of this layer, used to name the
            sub-modules.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        assert isinstance(config, (Olmo2Config, Olmo3Config))
        # Attention block.
        self.self_attn = Olmo2Attention(
            vllm_config=vllm_config, prefix=f"{prefix}.self_attn"
        )

        # MLP block.
        self.mlp = Olmo2MLP(vllm_config=vllm_config, prefix=f"{prefix}.mlp")

        # LayerNorm applied to the attention output.
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

        # LayerNorm applied to the MLP output.
        self.post_feedforward_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """
        :param positions: Position indices forwarded to the attention block.
        :param hidden_states: Input activations of the layer.
        :return: Activations after both residual sub-blocks.
        """
        # Attention block.
        residual = hidden_states
        hidden_states = self.self_attn(positions, hidden_states)
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = residual + hidden_states

        # MLP block.
        residual = hidden_states
        hidden_states = self.mlp(hidden_states)
        hidden_states = self.post_feedforward_layernorm(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states


@support_torch_compile
class Olmo2Model(nn.Module):
    """
    OLMo2 / OLMo3 decoder stack: token embedding, the decoder layers owned
    by this pipeline stage, and the final RMSNorm.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """
        :param vllm_config: Engine configuration; the HF model config is
            read from it and the whole object is passed to every layer.
        :param prefix: Dotted module path of this model, used to name the
            sub-modules.
        """
        super().__init__()
        self.config = vllm_config.model_config.hf_config
        assert isinstance(self.config, (Olmo2Config, Olmo3Config))

        self.embed_tokens = VocabParallelEmbedding(
            self.config.vocab_size,
            self.config.hidden_size,
            prefix=f"{prefix}.embed_tokens",
        )
        # `make_layers` also reports the [start_layer, end_layer) slice that
        # `forward` iterates over.
        self.start_layer, self.end_layer, self.layers = make_layers(
            self.config.num_hidden_layers,
            lambda prefix: Olmo2DecoderLayer(vllm_config=vllm_config, prefix=prefix),
            prefix=f"{prefix}.layers",
        )
        self.norm = RMSNorm(
            self.config.hidden_size,
            eps=self.config.rms_norm_eps,
        )
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states"], self.config.hidden_size
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the token embeddings for `input_ids`."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """
        :param input_ids: Token ids. Only read on the first pipeline rank,
            and only when `inputs_embeds` is not given.
        :param positions: Position indices forwarded to every decoder layer.
        :param intermediate_tensors: Output of the previous pipeline stage.
            Must not be None on any rank other than the first.
        :param inputs_embeds: Optional precomputed embeddings that replace
            the embedding lookup on the first pipeline rank.
        :return: The normalized hidden states on the last pipeline rank;
            otherwise an `IntermediateTensors` holding "hidden_states" for
            the next stage.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                # Get embeddings of input.
                hidden_states = self.embed_input_ids(input_ids)
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            assert isinstance(hidden_states, torch.Tensor)

        # Apply blocks one-by-one.
        for layer in islice(self.layers, self.start_layer, self.end_layer):
            hidden_states = layer(positions, hidden_states)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})

        # Apply final layer norm.
        hidden_states = self.norm(hidden_states)
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """
        Load checkpoint tensors into this module's parameters.

        Checkpoint names containing `q_proj` / `k_proj` / `v_proj` or
        `gate_proj` / `up_proj` are redirected into the fused `qkv_proj` /
        `gate_up_proj` parameters with the matching shard id. Every other
        tensor is loaded under its own name.

        :param weights: Iterable of `(name, tensor)` pairs.
        :return: Names of the parameters that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]

        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            # Parameters that is_pp_missing_parameter reports as missing on
            # this rank are skipped.
            if is_pp_missing_parameter(name, self):
                continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader  # type: ignore
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # No stacked shard was loaded; load under the plain name.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params


class Olmo2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
    """
    Causal-LM wrapper around `Olmo2Model`: adds the LM head (tied to the
    input embedding when the config requests it) and the logits processor.
    """

    # Fused parameter name -> checkpoint projection names packed into it.
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """
        :param vllm_config: Engine configuration; the HF model config and
            the quantization config are read from it.
        :param prefix: Dotted module path of this wrapper, used to name the
            sub-modules.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        assert isinstance(config, (Olmo2Config, Olmo3Config))
        self.config = config
        self.model = Olmo2Model(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )
        if config.tie_word_embeddings:
            # Reuse the input embedding as the output projection.
            self.lm_head = self.model.embed_tokens
        else:
            self.lm_head = ParallelLMHead(
                config.vocab_size,
                config.hidden_size,
                quant_config=vllm_config.quant_config,
                prefix=maybe_prefix(prefix, "lm_head"),
            )
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the token embeddings for `input_ids`."""
        return self.model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """
        Run the decoder stack; all arguments are forwarded unchanged to
        `Olmo2Model.forward` and its result is returned as-is.
        """
        hidden_states = self.model(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Project `hidden_states` through the LM head via the logits processor."""
        logits = self.logits_processor(self.lm_head, hidden_states)
        return logits

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """
        Load checkpoint tensors through `AutoWeightsLoader`.

        :param weights: Iterable of `(name, tensor)` pairs.
        :return: Whatever `AutoWeightsLoader.load_weights` returns
            (presumably the set of loaded parameter names, as in
            `Olmo2Model.load_weights` - confirm against the loader).
        """
        loader = AutoWeightsLoader(
            self,
            # With tied embeddings the head shares the embedding weight, so
            # a separate "lm_head.weight" checkpoint entry is skipped.
            skip_prefixes=(
                ["lm_head.weight"] if self.config.tie_word_embeddings else None
            ),
        )
        return loader.load_weights(weights)