# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Adapted from
# https://github.com/huggingface/transformers/blob/v4.40.1/src/transformers/models/olmo/modeling_olmo.py
# Copyright 2024 The vLLM team.
# Copyright 2024 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only OLMo model compatible with HuggingFace weights."""
26
from collections.abc import Iterable
27
from itertools import islice
28
from typing import Optional, Union
Isotr0py's avatar
Isotr0py committed
29
30
31

import torch
from torch import nn
32
from transformers import OlmoConfig
Isotr0py's avatar
Isotr0py committed
33

34
from vllm.attention import Attention
35
from vllm.compilation.decorators import support_torch_compile
36
from vllm.config import CacheConfig, VllmConfig
37
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
38
from vllm.model_executor.layers.activation import SiluAndMul
39
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
40
41
                                               QKVParallelLinear,
                                               RowParallelLinear)
42
from vllm.model_executor.layers.logits_processor import LogitsProcessor
43
from vllm.model_executor.layers.quantization import QuantizationConfig
44
from vllm.model_executor.layers.rotary_embedding import get_rope
45
from vllm.model_executor.layers.vocab_parallel_embedding import (
46
    ParallelLMHead, VocabParallelEmbedding)
47
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
48
from vllm.sequence import IntermediateTensors
49

50
from .interfaces import SupportsLoRA, SupportsPP
51
from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
52
53
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)
54

Isotr0py's avatar
Isotr0py committed
55
56
57

class OlmoAttention(nn.Module):
    """
    This is the attention block where the output is computed as
    ``Attention(LN(x))`` in ``MLP(LN(x + Attention(LN(x))))``
    (plus another skip connection).

    The layer norm and the skip connection are not applied here; this
    module only runs the fused QKV projection, optional QKV clipping,
    rotary embedding, attention and the output projection.

    :param config: Model config; this block reads ``hidden_size``,
        ``num_attention_heads``, ``max_position_embeddings``,
        ``rope_theta``, ``clip_qkv`` and ``attention_bias``.
    :param cache_config: Forwarded unchanged to ``Attention``.
    :param quant_config: Forwarded unchanged to the linear layers and to
        ``Attention``.
    :param prefix: Dotted name of this module; sub-module prefixes are
        derived from it (``.qkv_proj``, ``.attn``, ``.o_proj``).
    """

    def __init__(
        self,
        config: OlmoConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        tensor_model_parallel_world_size = (
            get_tensor_model_parallel_world_size())
        self.total_num_heads = config.num_attention_heads

        # The hidden dimension is split evenly across heads, and the heads
        # are split evenly across tensor-parallel ranks.
        assert self.hidden_size % self.total_num_heads == 0, (
            f"hidden_size ({self.hidden_size}) must be divisible by "
            f"num_attention_heads ({self.total_num_heads})")
        assert self.total_num_heads % tensor_model_parallel_world_size == 0, (
            f"num_attention_heads ({self.total_num_heads}) must be divisible "
            f"by the tensor parallel world size "
            f"({tensor_model_parallel_world_size})")

        # Number of heads handled by this tensor-parallel rank.
        self.num_heads = (self.total_num_heads //
                          tensor_model_parallel_world_size)
        self.head_dim = self.hidden_size // self.total_num_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.clip_qkv = config.clip_qkv

        # Attention input projection. Projects x -> (q, k, v)
        self.qkv_proj = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.total_num_heads,
            bias=config.attention_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        # Rotary embeddings.
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=self.max_position_embeddings,
            base=self.rope_theta,
        )
        # Attention scale: 1 / sqrt(head_dim).
        self.scaling = self.head_dim**-0.5
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              scale=self.scaling,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn")

        # Attention output projection.
        self.o_proj = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=config.attention_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """
        Run attention over ``hidden_states``.

        :param positions: Positions passed to the rotary embedding.
        :param hidden_states: Input to the fused QKV projection.
        :return: Output of the attention output projection.
        """
        qkv, _ = self.qkv_proj(hidden_states)
        if self.clip_qkv is not None:
            # In-place clamp of the fused projection to
            # [-clip_qkv, clip_qkv].
            qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
        # Split the fused projection into three equal parts along the
        # last dimension.
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output


class OlmoMLP(nn.Module):
    """
    This is the MLP block where the output is computed as
    ``MLP(LN(x))`` in ``MLP(LN(x + Attention(LN(x))))``
    (plus another skip connection).

    The layer norm and the skip connection are not applied here; this
    module only runs the merged gate/up projection, the ``SiluAndMul``
    activation and the down projection. Neither projection has a bias.

    :param config: Model config; this block reads ``hidden_size`` and
        ``intermediate_size``.
    :param quant_config: Forwarded unchanged to both linear layers.
    :param prefix: Dotted name of this module; sub-module prefixes are
        derived from it (``.gate_up_proj``, ``.down_proj``).
    """

    def __init__(
        self,
        config: OlmoConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size

        # Feed-forward input projection: two merged outputs of
        # ``intermediate_size`` each (gate and up).
        self.gate_up_proj = MergedColumnParallelLinear(
            self.hidden_size,
            [self.intermediate_size] * 2,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )

        # Activation function.
        self.act_fn = SiluAndMul()

        # Feed-forward output projection.
        self.down_proj = RowParallelLinear(
            self.intermediate_size,
            self.hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
        )

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """
        Apply the feed-forward network to ``x``.

        :param x: Input to the merged gate/up projection.
        :return: Output of the down projection.
        """
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


185
class OlmoDecoderLayer(nn.Module):
    """
    This is a typical transformer block where the output is
    computed as ``MLP(LN(x + Attention(LN(x))))``
    (plus another skip connection).

    Both layer norms are built with ``elementwise_affine=False`` and
    ``bias=False``.

    :param config: Model config, forwarded to the attention and MLP
        blocks; ``hidden_size`` is also read for the layer norms.
    :param cache_config: Forwarded unchanged to ``OlmoAttention``.
    :param quant_config: Forwarded unchanged to ``OlmoAttention`` and
        ``OlmoMLP``.
    :param prefix: Dotted name of this layer; sub-module prefixes are
        derived from it (``.self_attn``, ``.mlp``).
    """

    def __init__(self,
                 config: OlmoConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()
        # Attention block.
        self.self_attn = OlmoAttention(config,
                                       cache_config,
                                       quant_config,
                                       prefix=f"{prefix}.self_attn")

        # MLP block.
        self.mlp = OlmoMLP(config, quant_config, prefix=f"{prefix}.mlp")

        # LayerNorm
        self.input_layernorm = nn.LayerNorm(config.hidden_size,
                                            elementwise_affine=False,
                                            bias=False)
        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
                                                     elementwise_affine=False,
                                                     bias=False)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """
        Apply the attention and MLP sub-blocks, each with a skip connection.

        :param positions: Positions forwarded to the attention block.
        :param hidden_states: Input to this layer.
        :return: The layer output as a single tensor.
        """
        # Attention block.
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(positions, hidden_states)
        hidden_states = hidden_states + residual

        # MLP block.
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states


234
@support_torch_compile
class OlmoModel(nn.Module):
    """
    OLMo decoder stack: token embedding, decoder layers and a final
    layer norm.

    The decoder layers are created through ``make_layers``, which also
    yields the ``start_layer``/``end_layer`` bounds iterated by
    :meth:`forward`.

    :param vllm_config: Source of the HF model config, the cache config
        and the quantization config.
    :param prefix: Dotted name of this module; the decoder layers are
        created under ``{prefix}.layers``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = config

        self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
                                                   config.hidden_size)
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: OlmoDecoderLayer(
                config, cache_config, quant_config, prefix=prefix),
            prefix=f"{prefix}.layers")
        self.norm = nn.LayerNorm(config.hidden_size,
                                 elementwise_affine=False,
                                 bias=False)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(["hidden_states"],
                                                    config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """
        :param input_ids: A tensor of shape `(batch_size, seq_len)`.
        :param positions: Positions forwarded to every decoder layer.
        :param intermediate_tensors: Must not be ``None`` on a rank that is
            not the first pipeline-parallel rank; its ``"hidden_states"``
            entry is used as the input to this rank's layers.
        :param inputs_embeds: Optional precomputed embeddings. On the first
            pipeline-parallel rank they are used instead of embedding
            ``input_ids``.
        :return: The normalized hidden states on the last
            pipeline-parallel rank, otherwise an ``IntermediateTensors``
            holding ``"hidden_states"``.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]

        # Apply blocks one-by-one.
        for layer in islice(self.layers, self.start_layer, self.end_layer):
            # shape: (batch_size, seq_len, d_model)
            hidden_states = layer(positions, hidden_states)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})

        # Apply final layer norm.
        # shape: (batch_size, seq_len or 1, d_model)
        hidden_states = self.norm(hidden_states)
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """
        Load checkpoint tensors into this module's parameters.

        Checkpoint names containing ``q_proj``/``k_proj``/``v_proj`` or
        ``gate_proj``/``up_proj`` are routed to the fused ``qkv_proj`` /
        ``gate_up_proj`` parameters with the matching shard id. Every
        other tensor is loaded under its own name.

        :param weights: Iterable of ``(name, tensor)`` pairs.
        :return: The parameter names that were actually loaded (fused
            names for stacked shards).
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                # Keep the rewritten name in a local. Rewriting ``name``
                # before the skip checks would let a skipped ``gate_proj``
                # shard be re-matched by the ``up_proj`` entry (since
                # "up_proj" is a substring of "gate_up_proj") and mangled
                # into "gate_gate_up_proj".
                fused_name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if (fused_name.endswith(".bias")
                        and fused_name not in params_dict):
                    continue
                if is_pp_missing_parameter(fused_name, self):
                    continue
                param = params_dict[fused_name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                # Record the fused parameter name as loaded.
                name = fused_name
                break
            else:
                # No stacked shard was loaded; treat as a regular parameter.
                # ``continue`` here advances the outer loop.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params

Isotr0py's avatar
Isotr0py committed
333

334
class OlmoForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
    """
    Extremely barebones HF model wrapper.

    Wraps :class:`OlmoModel` and adds the LM head and logits processor.
    When ``config.tie_word_embeddings`` is set, the LM head is the input
    embedding module itself; otherwise a separate ``ParallelLMHead`` is
    created.

    :param vllm_config: Source of the HF model config and the
        quantization config; also forwarded to :class:`OlmoModel`.
    :param prefix: Dotted name of this module, combined via
        ``maybe_prefix`` with ``"model"`` and ``"lm_head"``.
    """
    # Fused module name -> the checkpoint sub-module names packed into it.
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        # Set on both branches so the attribute exists whether or not the
        # embeddings are tied.
        self.unpadded_vocab_size = config.vocab_size
        self.model = OlmoModel(vllm_config=vllm_config,
                               prefix=maybe_prefix(prefix, "model"))
        if config.tie_word_embeddings:
            self.lm_head = self.model.embed_tokens
        else:
            self.lm_head = ParallelLMHead(
                self.unpadded_vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "lm_head"),
            )
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the token embeddings for ``input_ids``."""
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """
        Delegate to the wrapped :class:`OlmoModel` and return its result
        unchanged; logits are computed separately by
        :meth:`compute_logits`.
        """
        hidden_states = self.model(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> Optional[torch.Tensor]:
        """Run the logits processor with the LM head on ``hidden_states``."""
        logits = self.logits_processor(self.lm_head, hidden_states)
        return logits

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """
        Load checkpoint tensors through ``AutoWeightsLoader``.

        When the embeddings are tied, ``lm_head.weight`` is passed as a
        skip prefix because the LM head is the embedding module.

        :param weights: Iterable of ``(name, tensor)`` pairs.
        :return: Whatever ``AutoWeightsLoader.load_weights`` returns.
        """
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=(["lm_head.weight"]
                           if self.config.tie_word_embeddings else None),
        )
        return loader.load_weights(weights)