olmo2.py 16.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
# Adapted from
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/olmo2/modeling_olmo2.py
# Copyright 2024 The vLLM team.
# Copyright 2024 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only OLMo2 model compatible with HuggingFace weights."""

27
from collections.abc import Iterable
28
from functools import partial
29
from itertools import islice
30
from typing import Optional, Union
31
32
33

import torch
from torch import nn
34
from transformers import Olmo2Config
35

36
from vllm.attention import Attention
37
from vllm.compilation.decorators import support_torch_compile
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
from vllm.config import VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.distributed.communication_op import tensor_model_parallel_all_gather
from vllm.distributed.parallel_state import get_tensor_model_parallel_rank
from vllm.distributed.utils import split_tensor_along_last_dim
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
53
from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP
54
from vllm.model_executor.models.utils import (
55
    AutoWeightsLoader, extract_layer_index, is_pp_missing_parameter,
56
    make_empty_intermediate_tensors_factory, make_layers, maybe_prefix)
57
from vllm.sequence import IntermediateTensors
58
from vllm.transformers_utils.configs import Olmo3Config
59
60
61
62
63
64
65
66
67
68
69
70


class Olmo2Attention(nn.Module):
    """
    Self-attention block for OLMo2/OLMo3 with QK-norm.

    The decoder layer passes ``x`` in directly (there is no pre-attention
    norm) and normalizes this block's output itself, i.e. the layer computes
    ``x + LN(Attention(x))``. Inside the block, RMSNorm is applied to the
    queries and keys over their full, un-sharded projection width before
    rotary embeddings are applied.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        self.config = vllm_config.model_config.hf_config
        assert isinstance(self.config, (Olmo2Config, Olmo3Config))

        hidden_size = self.config.hidden_size
        self.tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = self.config.num_attention_heads

        assert hidden_size % self.total_num_heads == 0
        assert self.total_num_heads % self.tp_size == 0

        self.num_heads = self.total_num_heads // self.tp_size
        self.total_num_kv_heads = (self.config.num_key_value_heads
                                   or self.total_num_heads)
        if self.total_num_kv_heads >= self.tp_size:
            # KV heads are partitioned across TP ranks.
            assert self.total_num_kv_heads % self.tp_size == 0
        else:
            # Fewer KV heads than ranks: each KV head is replicated.
            assert self.tp_size % self.total_num_kv_heads == 0

        self.num_kv_heads = max(1, self.total_num_kv_heads // self.tp_size)
        # Number of TP ranks sharing one KV head (1 when KV heads are
        # partitioned rather than replicated).
        self.num_kv_head_replicas = max(
            1, self.tp_size // self.total_num_kv_heads)
        self.head_dim = hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.max_position_embeddings = self.config.max_position_embeddings
        self.rope_theta = self.config.rope_theta

        # Attention input projection. Projects x -> (q, k, v)
        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=vllm_config.quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.tp_rank = get_tensor_model_parallel_rank()
        # QK-norm weights cover the complete (all-rank) projection widths.
        self.k_norm = RMSNorm(
            self.total_num_kv_heads * self.head_dim,
            eps=self.config.rms_norm_eps,
        )
        self.q_norm = RMSNorm(self.config.hidden_size,
                              eps=self.config.rms_norm_eps)

        self.scaling = self.head_dim**-0.5

        # Configs that declare per-layer `layer_types` may mark some layers
        # as sliding-window attention.
        layer_idx = extract_layer_index(prefix)
        sliding_window = None
        if ((layer_types := getattr(self.config, "layer_types", None))
                is not None and layer_types[layer_idx] == "sliding_attention"):
            sliding_window = self.config.sliding_window

        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=vllm_config.cache_config,
            quant_config=vllm_config.quant_config,
            per_layer_sliding_window=sliding_window,
            prefix=f"{prefix}.attn",
        )

        # Rotary embeddings. Rope scaling is only applied on full attention
        # layers.
        self.rope_scaling = (self.config.rope_scaling
                             if sliding_window is None else None)
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=self.max_position_embeddings,
            base=self.rope_theta,  # type: ignore
            rope_scaling=self.rope_scaling,
        )

        # Attention output projection.
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=vllm_config.quant_config,
            prefix=f"{prefix}.o_proj",
        )

    def _apply_qk_norm(self, q: torch.Tensor,
                       k: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply RMSNorm to q and k over their full (all-rank) width.

        ``q_norm``/``k_norm`` hold weights for the complete projections, so
        under tensor parallelism the local shards are all-gathered along the
        last dim, normalized, and this rank's slice is taken back out.
        """
        if self.tp_size > 1:
            q = tensor_model_parallel_all_gather(q.contiguous())
            k = tensor_model_parallel_all_gather(k.contiguous())

        if self.num_kv_head_replicas > 1:
            # With replicated KV heads the gathered k is tp_size * head_dim
            # wide and holds every head num_kv_head_replicas times, which
            # does not match k_norm's total_num_kv_heads * head_dim weight.
            # Keep a single copy of each head before normalizing.
            # NOTE(review): assumes rank r holds KV head
            # r // num_kv_head_replicas (QKVParallelLinear's replication
            # layout) — confirm.
            k = k.reshape(*k.shape[:-1], self.total_num_kv_heads,
                          self.num_kv_head_replicas, self.head_dim)[..., 0, :]
            k = k.reshape(*k.shape[:-2], -1)

        q = self.q_norm(q)
        k = self.k_norm(k)

        if self.tp_size > 1:
            splitter = partial(split_tensor_along_last_dim,
                               num_partitions=self.tp_size)
            q = splitter(q)[self.tp_rank]
            if self.num_kv_head_replicas > 1:
                # Select this rank's (replicated) KV head.
                kv_head = self.tp_rank // self.num_kv_head_replicas
                k = k.reshape(*k.shape[:-1], self.total_num_kv_heads,
                              self.head_dim)[..., kv_head, :]
            else:
                k = splitter(k)[self.tp_rank]
        return q, k

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Project to q/k/v, apply QK-norm and RoPE, attend, project out."""
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self._apply_qk_norm(q, k)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output


class Olmo2MLP(nn.Module):
    """
    Gated feed-forward sub-block of an OLMo2 decoder layer.

    A merged column-parallel projection produces the gate and up halves,
    ``SiluAndMul`` combines them, and a row-parallel projection maps the
    result back to ``hidden_size``. No projection uses a bias.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        cfg = vllm_config.model_config.hf_config
        assert isinstance(cfg, (Olmo2Config, Olmo3Config))

        quant = vllm_config.quant_config
        d_model = cfg.hidden_size
        d_ff = cfg.intermediate_size

        # hidden_size -> [gate | up], each of width intermediate_size.
        self.gate_up_proj = MergedColumnParallelLinear(
            d_model,
            [d_ff, d_ff],
            bias=False,
            quant_config=quant,
            prefix=f"{prefix}.gate_up_proj",
        )

        # Combines the gate and up halves.
        self.act_fn = SiluAndMul()

        # intermediate_size -> hidden_size.
        self.down_proj = RowParallelLinear(
            d_ff,
            d_model,
            bias=False,
            quant_config=quant,
            prefix=f"{prefix}.down_proj",
        )

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """Return ``down_proj(act_fn(gate_up_proj(x)))``."""
        fused, _ = self.gate_up_proj(x)
        out, _ = self.down_proj(self.act_fn(fused))
        return out


class Olmo2DecoderLayer(nn.Module):
    """
    One OLMo2 transformer block. Normalization is applied to each
    sub-block's output before the residual add:

        h   = x + LN_attn(Attention(x))
        out = h + LN_ffn(MLP(h))
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        cfg = vllm_config.model_config.hf_config
        assert isinstance(cfg, (Olmo2Config, Olmo3Config))

        # Sub-blocks.
        self.self_attn = Olmo2Attention(vllm_config=vllm_config,
                                        prefix=f"{prefix}.self_attn")
        self.mlp = Olmo2MLP(vllm_config=vllm_config, prefix=f"{prefix}.mlp")

        # Post-sub-block RMSNorms.
        width = cfg.hidden_size
        eps = cfg.rms_norm_eps
        self.post_attention_layernorm = RMSNorm(width, eps=eps)
        self.post_feedforward_layernorm = RMSNorm(width, eps=eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Apply the attention and MLP sub-blocks with post-norm residuals."""
        attn_out = self.self_attn(positions, hidden_states)
        hidden_states = (self.post_attention_layernorm(attn_out) +
                         hidden_states)

        ffn_out = self.post_feedforward_layernorm(self.mlp(hidden_states))
        return hidden_states + ffn_out


271
@support_torch_compile
class Olmo2Model(nn.Module):
    """Token embedding, decoder-layer stack and final norm of OLMo2/OLMo3.

    Under pipeline parallelism only layers ``[start_layer, end_layer)`` (as
    returned by ``make_layers``) are executed on this rank.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        self.config = vllm_config.model_config.hf_config
        assert isinstance(self.config, (Olmo2Config, Olmo3Config))

        self.embed_tokens = VocabParallelEmbedding(
            self.config.vocab_size,
            self.config.hidden_size,
            prefix=f"{prefix}.embed_tokens",
        )
        self.start_layer, self.end_layer, self.layers = make_layers(
            self.config.num_hidden_layers,
            lambda prefix: Olmo2DecoderLayer(vllm_config=vllm_config,
                                             prefix=prefix),
            prefix=f"{prefix}.layers",
        )
        self.norm = RMSNorm(
            self.config.hidden_size,
            eps=self.config.rms_norm_eps,
        )
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(["hidden_states"],
                                                    self.config.hidden_size))

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this pipeline stage's share of the model.

        :param input_ids: Token ids. They are embedded only on the first PP
            rank and only when ``inputs_embeds`` is not given.
            NOTE(review): previously documented as ``(batch_size, seq_len)``;
            vLLM may pass a flattened token dimension instead — confirm.
        :param positions: Token positions, forwarded to every decoder layer.
        :param intermediate_tensors: ``hidden_states`` from the previous
            pipeline stage. Required on every rank except the first.
        :param inputs_embeds: Optional precomputed embeddings that replace
            the token-embedding lookup on the first rank.
        :return: ``IntermediateTensors`` on non-last PP ranks; otherwise the
            hidden states after the final norm.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                # Embed the input tokens.
                hidden_states = self.embed_tokens(input_ids)
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            assert isinstance(hidden_states, torch.Tensor)

        # Apply this rank's blocks one-by-one.
        for layer in islice(self.layers, self.start_layer, self.end_layer):
            hidden_states = layer(positions, hidden_states)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})

        # Apply final layer norm.
        hidden_states = self.norm(hidden_states)
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, fusing q/k/v and gate/up projections.

        :param weights: ``(name, tensor)`` pairs, with names relative to
            this module.
        :return: Names of the parameters that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]

        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            if is_pp_missing_parameter(name, self):
                continue

            # Resolve the checkpoint name to its fused parameter exactly
            # once. A rewritten name must not be matched again, because
            # e.g. "v_proj" is a substring of "qkv_proj".
            # shard_id stays None for unfused parameters (0 is a valid
            # shard id).
            shard_id = None
            for param_name, weight_name, stacked_shard_id in (
                    stacked_params_mapping):
                if weight_name in name:
                    name = name.replace(weight_name, param_name)
                    shard_id = stacked_shard_id
                    break

            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue

            param = params_dict[name]
            if shard_id is None:
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            else:
                weight_loader = param.weight_loader  # type: ignore
                weight_loader(param, loaded_weight, shard_id)
            loaded_params.add(name)
        return loaded_params
371

372

373
class Olmo2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
    """
    Causal-LM wrapper: ``Olmo2Model`` plus an LM head and logits processor.

    When ``tie_word_embeddings`` is set, the LM head reuses the input
    embedding table and any checkpoint ``lm_head.weight`` is ignored.
    """

    # Fused parameter name -> checkpoint sub-module names it packs.
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        assert isinstance(hf_config, (Olmo2Config, Olmo3Config))

        self.config = hf_config
        self.model = Olmo2Model(vllm_config=vllm_config,
                                prefix=maybe_prefix(prefix, "model"))

        if not hf_config.tie_word_embeddings:
            self.unpadded_vocab_size = hf_config.vocab_size
            self.lm_head = ParallelLMHead(
                hf_config.vocab_size,
                hf_config.hidden_size,
                org_num_embeddings=hf_config.vocab_size,
                quant_config=vllm_config.quant_config,
                prefix=maybe_prefix(prefix, "lm_head"),
            )
        else:
            # Share the input embedding table as the output projection.
            self.lm_head = self.model.embed_tokens

        self.logits_processor = LogitsProcessor(hf_config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Delegate to the underlying ``Olmo2Model``."""
        return self.model(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to vocabulary logits via the LM head."""
        return self.logits_processor(self.lm_head, hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """Load weights via ``AutoWeightsLoader``.

        ``lm_head.weight`` is skipped when the embeddings are tied.
        """
        if self.config.tie_word_embeddings:
            skipped = ["lm_head.weight"]
        else:
            skipped = None
        weight_loader = AutoWeightsLoader(
            self,
            skip_prefixes=skipped,
        )
        return weight_loader.load_weights(weights)