# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Adapted from
# https://github.com/huggingface/transformers/blob/v4.40.1/src/transformers/models/olmo/modeling_olmo.py
# Copyright 2024 The vLLM team.
# Copyright 2024 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only OLMo model compatible with HuggingFace weights."""

from collections.abc import Iterable
from itertools import islice

import torch
from torch import nn
from transformers import OlmoConfig

from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.linear import (
    MergedColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.sequence import IntermediateTensors

from .interfaces import SupportsLoRA, SupportsPP
from .utils import (
    AutoWeightsLoader,
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)


class OlmoAttention(nn.Module):
    """Multi-head self-attention sub-block of an OLMo decoder layer.

    Computes the ``Attention(LN(x))`` term of the decoder block
    ``MLP(LN(x + Attention(LN(x))))`` (plus another skip connection). This
    module receives already-normalized hidden states: it applies no layer
    norm and no residual add itself.

    Attention heads are split evenly across tensor-parallel ranks; each rank
    owns ``num_attention_heads // tensor_parallel_world_size`` heads.
    """

    def __init__(
        self,
        config: OlmoConfig,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        """Build the fused QKV projection, rotary embedding, attention and
        output projection.

        Args:
            config: HF OLMo config. Reads ``hidden_size``,
                ``num_attention_heads``, ``max_position_embeddings``,
                ``rope_theta``, ``clip_qkv`` and ``attention_bias``.
            cache_config: Cache configuration forwarded to ``Attention``.
            quant_config: Optional quantization config forwarded to the
                linear layers and to ``Attention``.
            prefix: Dotted module path used to name the sub-layers.
        """
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = config.num_attention_heads

        assert self.hidden_size % self.total_num_heads == 0, (
            "hidden_size must be divisible by num_attention_heads"
        )
        # Each tensor-parallel rank must own a whole number of heads.
        assert self.total_num_heads % tensor_model_parallel_world_size == 0, (
            "num_attention_heads must be divisible by the tensor parallel size"
        )

        self.num_heads = self.total_num_heads // tensor_model_parallel_world_size
        self.head_dim = self.hidden_size // self.total_num_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        # Optional symmetric bound for clamping the fused QKV activations;
        # None disables clamping (see forward).
        self.clip_qkv = config.clip_qkv

        # Attention input projection. Projects x -> (q, k, v)
        self.qkv_proj = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.total_num_heads,
            bias=config.attention_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        # Rotary embeddings over the full head dimension.
        # NOTE(review): no rope scaling is forwarded to get_rope; confirm that
        # checkpoints with a non-null config.rope_scaling are not expected.
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=self.max_position_embeddings,
            base=self.rope_theta,
        )
        # Softmax scale 1 / sqrt(head_dim).
        self.scaling = self.head_dim**-0.5
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            scale=self.scaling,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

        # Attention output projection.
        self.o_proj = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=config.attention_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Apply rotary self-attention to already-normalized hidden states.

        Args:
            positions: Token positions consumed by the rotary embedding.
            hidden_states: Normalized input activations.

        Returns:
            The attention output after the ``o_proj`` projection.
        """
        qkv, _ = self.qkv_proj(hidden_states)
        if self.clip_qkv is not None:
            # In-place clamp into [-clip_qkv, clip_qkv].
            qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
        # The equal three-way split assumes q, k and v have the same width
        # (plain multi-head attention: no KV-head count is passed to
        # QKVParallelLinear above).
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output


class OlmoMLP(nn.Module):
    """Gated feed-forward sub-block of an OLMo decoder layer.

    Computes the ``MLP(LN(x))`` term of the decoder block
    ``MLP(LN(x + Attention(LN(x))))`` (plus another skip connection). The
    input is expected to be already normalized; this module applies no layer
    norm and no residual add.

    The gate and up projections are fused into one column-parallel matmul
    producing ``2 * intermediate_size`` features. ``SiluAndMul`` (presumably
    ``silu(gate) * up``) reduces them back to ``intermediate_size``, and a
    row-parallel ``down_proj`` maps the result back to ``hidden_size``. None
    of the projections has a bias.
    """

    def __init__(
        self,
        config: OlmoConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        """Build the fused gate/up projection, activation and down projection.

        Args:
            config: HF OLMo config. Reads ``hidden_size`` and
                ``intermediate_size``.
            quant_config: Optional quantization config for both projections.
            prefix: Dotted module path used to name the sub-layers.
        """
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size

        # Feed-forward input projection: gate and up fused side by side.
        self.gate_up_proj = MergedColumnParallelLinear(
            self.hidden_size,
            [self.intermediate_size] * 2,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )

        # Activation function.
        self.act_fn = SiluAndMul()

        # Feed-forward output projection.
        self.down_proj = RowParallelLinear(
            self.intermediate_size,
            self.hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
        )

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """Apply the gated MLP to normalized activations ``x``.

        Returns:
            A tensor whose last dimension is ``hidden_size``.
        """
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


class OlmoDecoderLayer(nn.Module):
    """One pre-norm OLMo transformer block.

    Computes::

        h   = x + Attention(LN(x))
        out = h + MLP(LN(h))

    Both layer norms are built with ``elementwise_affine=False`` and
    ``bias=False``, so they have no learnable parameters.
    """

    def __init__(
        self,
        config: OlmoConfig,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        """Build the attention and MLP sub-blocks and their layer norms.

        Args:
            config: HF OLMo config shared by all sub-modules.
            cache_config: Cache configuration forwarded to the attention.
            quant_config: Optional quantization config for the sub-blocks.
            prefix: Dotted module path of this layer.
        """
        super().__init__()
        # Attention block.
        self.self_attn = OlmoAttention(
            config, cache_config, quant_config, prefix=f"{prefix}.self_attn"
        )

        # MLP block.
        self.mlp = OlmoMLP(config, quant_config, prefix=f"{prefix}.mlp")

        # Non-parametric LayerNorms (no affine weight, no bias).
        self.input_layernorm = nn.LayerNorm(
            config.hidden_size, elementwise_affine=False, bias=False
        )
        self.post_attention_layernorm = nn.LayerNorm(
            config.hidden_size, elementwise_affine=False, bias=False
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run attention and MLP, each with its own residual connection.

        Args:
            positions: Token positions forwarded to the attention's rotary
                embedding.
            hidden_states: Layer input activations.

        Returns:
            The layer output, a single tensor of the same width as the input.
        """
        # Attention block.
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(positions, hidden_states)
        hidden_states = hidden_states + residual

        # MLP block.
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states


@support_torch_compile
class OlmoModel(nn.Module):
    """OLMo decoder stack: token embedding, decoder layers and final norm.

    ``forward`` runs only the layers ``[start_layer, end_layer)`` returned
    by ``make_layers``. Pipeline ranks other than the last return their
    activations as ``IntermediateTensors`` instead of applying the final
    norm.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build embeddings, this rank's decoder layers and the final norm.

        Args:
            vllm_config: Engine config; supplies the HF model config, the
                cache config and the quantization config.
            prefix: Dotted module path of this model.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = config

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size, config.hidden_size
        )
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: OlmoDecoderLayer(
                config, cache_config, quant_config, prefix=prefix
            ),
            prefix=f"{prefix}.layers",
        )
        # Final non-parametric LayerNorm (no affine weight, no bias).
        self.norm = nn.LayerNorm(
            config.hidden_size, elementwise_affine=False, bias=False
        )
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states"], config.hidden_size
        )

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run this pipeline rank's slice of the decoder stack.

        Args:
            input_ids: Token ids, embedded on the first pipeline rank unless
                ``inputs_embeds`` is given. NOTE(review): the layout is not
                fixed here; vLLM commonly passes flattened ``(num_tokens,)``
                tensors rather than ``(batch_size, seq_len)``. Confirm.
            positions: Token positions forwarded to every decoder layer.
            intermediate_tensors: Activations from the previous pipeline
                rank; must be non-None on every rank but the first.
            inputs_embeds: Optional precomputed embeddings that replace the
                embedding lookup on the first rank.

        Returns:
            ``IntermediateTensors({"hidden_states": ...})`` on non-last
            pipeline ranks; otherwise the final-normed hidden states.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]

        # Apply this rank's blocks one-by-one.
        for layer in islice(self.layers, self.start_layer, self.end_layer):
            hidden_states = layer(positions, hidden_states)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})

        # Apply final layer norm.
        hidden_states = self.norm(hidden_states)
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Separate ``q_proj``/``k_proj``/``v_proj`` and ``gate_proj``/``up_proj``
        checkpoint tensors are loaded as shards of the fused ``qkv_proj`` and
        ``gate_up_proj`` parameters; every other tensor is loaded by name.
        Extra bias tensors without a matching parameter (e.g. from GPTQ
        checkpoints) and parameters that belong to other pipeline ranks are
        skipped.

        Args:
            weights: ``(name, tensor)`` pairs from the checkpoint.

        Returns:
            Names (after fused-name rewriting) of the parameters loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            # Rewrite at most once, using the first matching shard name. The
            # rewritten name must not be scanned again: "v_proj" is a
            # substring of "qkv_proj" and "up_proj" of "gate_up_proj".
            shard_id = None
            for param_name, weight_name, stacked_shard_id in stacked_params_mapping:
                if weight_name in name:
                    name = name.replace(weight_name, param_name)
                    shard_id = stacked_shard_id
                    break

            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            if is_pp_missing_parameter(name, self):
                continue

            param = params_dict[name]
            if shard_id is None:
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            else:
                # Fused parameters need their shard-aware loader.
                param.weight_loader(param, loaded_weight, shard_id)
            loaded_params.add(name)
        return loaded_params

class OlmoForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
    """OLMo causal language model: ``OlmoModel`` plus an LM head.

    When ``config.tie_word_embeddings`` is set, the LM head is the input
    embedding module itself, and a checkpoint's ``lm_head.weight`` is skipped
    during loading.
    """

    # Fused vLLM module name -> the separate checkpoint modules it packs.
    # NOTE(review): presumably consumed by the LoRA / quantization support to
    # map fused layers back to checkpoint names; confirm.
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the decoder stack, the LM head and the logits processor.

        Args:
            vllm_config: Engine config; supplies the HF model config and the
                quantization config.
            prefix: Dotted module path of this model.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.model = OlmoModel(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )
        if config.tie_word_embeddings:
            # Reuse the embedding module so the head shares its weight.
            self.lm_head = self.model.embed_tokens
        else:
            self.unpadded_vocab_size = config.vocab_size
            self.lm_head = ParallelLMHead(
                self.unpadded_vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "lm_head"),
            )
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings via the wrapped ``OlmoModel``."""
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the decoder stack and return hidden states, not logits.

        All arguments are forwarded unchanged to ``OlmoModel.forward``. The
        result is either the final hidden states or, on non-last pipeline
        ranks, ``IntermediateTensors``. Logits are produced separately by
        ``compute_logits``.
        """
        hidden_states = self.model(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Project hidden states to vocabulary logits through ``lm_head``.

        The work is delegated to ``LogitsProcessor``. As the annotation
        allows, the result may be ``None``; presumably this happens on ranks
        that do not produce logits. Confirm.
        """
        logits = self.logits_processor(self.lm_head, hidden_states)
        return logits

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into the whole model.

        When embeddings are tied, ``lm_head.weight`` is skipped because the
        head shares the embedding module.

        Returns:
            The set of loaded parameter names reported by
            ``AutoWeightsLoader``.
        """
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=(
                ["lm_head.weight"] if self.config.tie_word_embeddings else None
            ),
        )
        return loader.load_weights(weights)