# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Adapted from
# https://github.com/huggingface/transformers/blob/v4.40.1/src/transformers/models/olmo/modeling_olmo.py
# Copyright 2024 The vLLM team.
# Copyright 2024 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only OLMo model compatible with HuggingFace weights."""
26
from collections.abc import Iterable
27
from itertools import islice
28
from typing import Optional, Union
Isotr0py's avatar
Isotr0py committed
29
30
31

import torch
from torch import nn
32
from transformers import OlmoConfig
Isotr0py's avatar
Isotr0py committed
33

34
from vllm.attention import Attention
35
from vllm.compilation.decorators import support_torch_compile
36
from vllm.config import CacheConfig, VllmConfig
37
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
38
from vllm.model_executor.layers.activation import SiluAndMul
39
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
40
41
                                               QKVParallelLinear,
                                               RowParallelLinear)
42
from vllm.model_executor.layers.logits_processor import LogitsProcessor
43
from vllm.model_executor.layers.quantization import QuantizationConfig
44
from vllm.model_executor.layers.rotary_embedding import get_rope
45
from vllm.model_executor.layers.vocab_parallel_embedding import (
46
    ParallelLMHead, VocabParallelEmbedding)
47
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
Isotr0py's avatar
Isotr0py committed
48
from vllm.model_executor.sampling_metadata import SamplingMetadata
49
from vllm.sequence import IntermediateTensors
50

51
from .interfaces import SupportsLoRA, SupportsPP
52
from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
53
54
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)
55

Isotr0py's avatar
Isotr0py committed
56
57
58

class OlmoAttention(nn.Module):
    """Self-attention sub-block of an OLMo decoder layer.

    Within the layer formula ``MLP(LN(x + Attention(LN(x))))`` (plus another
    skip connection) this module produces the ``Attention(LN(x))`` term.
    The layer norm and the residual additions are the caller's job.

    Args:
        config: OLMo configuration; this block reads ``hidden_size``,
            ``num_attention_heads``, ``max_position_embeddings``,
            ``rope_theta``, ``clip_qkv`` and ``attention_bias`` from it.
        cache_config: Forwarded unchanged to the ``Attention`` layer.
        quant_config: Forwarded unchanged to the linear layers and to the
            ``Attention`` layer.
        prefix: Dotted module path used to name the sub-layers.
    """

    def __init__(
        self,
        config: OlmoConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = config.num_attention_heads

        # The hidden dimension must split evenly into heads, and the heads
        # must split evenly across the tensor-parallel world size.
        assert self.hidden_size % self.total_num_heads == 0
        assert self.total_num_heads % tp_size == 0

        # Per-rank head count and per-head width.
        self.num_heads = self.total_num_heads // tp_size
        self.head_dim = self.hidden_size // self.total_num_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.clip_qkv = config.clip_qkv

        # Input projection: x -> (q, k, v).
        self.qkv_proj = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.total_num_heads,
            bias=config.attention_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        # Rotary position embedding covering the whole head dimension.
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=self.max_position_embeddings,
            base=self.rope_theta,
        )

        # Attention scores are scaled by 1 / sqrt(head_dim).
        self.scaling = self.head_dim**-0.5
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            scale=self.scaling,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

        # Output projection back to the hidden size.
        self.o_proj = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=config.attention_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run attention over ``hidden_states``.

        Args:
            positions: Position tensor passed to the rotary embedding.
            hidden_states: Input activations for this block.

        Returns:
            The attention output after the output projection.
        """
        fused_qkv, _ = self.qkv_proj(hidden_states)

        # Optional symmetric clipping of the fused projection, done in place.
        clip = self.clip_qkv
        if clip is not None:
            fused_qkv.clamp_(min=-clip, max=clip)

        # Three equal chunks along the last dimension: query, key, value.
        query, key, value = fused_qkv.chunk(chunks=3, dim=-1)
        query, key = self.rotary_emb(positions, query, key)
        context = self.attn(query, key, value)
        projected, _ = self.o_proj(context)
        return projected


class OlmoMLP(nn.Module):
    """Feed-forward sub-block of an OLMo decoder layer.

    Within the layer formula ``MLP(LN(x + Attention(LN(x))))`` (plus another
    skip connection) this module produces the ``MLP(LN(x))`` term. The
    layer norm and the residual addition are the caller's job.

    Args:
        config: OLMo configuration; this block reads ``hidden_size`` and
            ``intermediate_size`` from it.
        quant_config: Forwarded unchanged to both linear layers.
        prefix: Dotted module path used to name the sub-layers.
    """

    def __init__(
        self,
        config: OlmoConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size

        # Merged gate/up input projection: two outputs of the intermediate
        # size each, without bias.
        self.gate_up_proj = MergedColumnParallelLinear(
            self.hidden_size,
            [self.intermediate_size, self.intermediate_size],
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )

        # Activation applied to the merged gate/up output.
        self.act_fn = SiluAndMul()

        # Output projection back to the hidden size, without bias.
        self.down_proj = RowParallelLinear(
            self.intermediate_size,
            self.hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
        )

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """Apply the gate/up projection, the activation and the down projection.

        Args:
            x: Input activations for this block.

        Returns:
            The output of the down projection.
        """
        gate_and_up, _ = self.gate_up_proj(x)
        activated = self.act_fn(gate_and_up)
        projected, _ = self.down_proj(activated)
        return projected


186
class OlmoDecoderLayer(nn.Module):
    """One OLMo transformer block.

    The output is computed as ``MLP(LN(x + Attention(LN(x))))`` (plus another
    skip connection): a pre-norm attention sub-block with a residual add,
    followed by a pre-norm MLP sub-block with a residual add.

    Args:
        config: OLMo configuration; ``hidden_size`` sizes the layer norms and
            the whole config is handed to the attention and MLP sub-blocks.
        cache_config: Forwarded to ``OlmoAttention``.
        quant_config: Forwarded to ``OlmoAttention`` and ``OlmoMLP``.
        prefix: Dotted module path used to name the sub-layers.
    """

    def __init__(self,
                 config: OlmoConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()
        # Attention block.
        self.self_attn = OlmoAttention(config,
                                       cache_config,
                                       quant_config,
                                       prefix=f"{prefix}.self_attn")

        # MLP block.
        self.mlp = OlmoMLP(config, quant_config, prefix=f"{prefix}.mlp")

        # LayerNorm without learnable scale or bias, so these modules own no
        # parameters.
        self.input_layernorm = nn.LayerNorm(config.hidden_size,
                                            elementwise_affine=False,
                                            bias=False)
        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
                                                     elementwise_affine=False,
                                                     bias=False)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Apply the attention and MLP sub-blocks with residual connections.

        Args:
            positions: Position tensor forwarded to the attention sub-block.
            hidden_states: Input activations for this layer.

        Returns:
            The updated hidden states as a single tensor. The return
            annotation used to declare a tuple, which did not match what the
            method returns.
        """
        # Attention block.
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(positions, hidden_states)
        hidden_states = hidden_states + residual

        # MLP block.
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states


235
@support_torch_compile
class OlmoModel(nn.Module):
    """OLMo decoder stack: token embedding, decoder layers and final norm.

    The decoder layers are created through ``make_layers``, which also yields
    the ``start_layer`` / ``end_layer`` bounds that ``forward`` iterates over.

    Args:
        vllm_config: Supplies the HF model config, the cache config and the
            quantization config.
        prefix: Dotted module path used to name the decoder layers.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        hf_config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = hf_config

        self.embed_tokens = VocabParallelEmbedding(hf_config.vocab_size,
                                                   hf_config.hidden_size)

        def build_layer(prefix: str) -> OlmoDecoderLayer:
            # Factory handed to make_layers; builds one decoder layer under
            # the given prefix.
            return OlmoDecoderLayer(hf_config,
                                    cache_config,
                                    quant_config,
                                    prefix=prefix)

        self.start_layer, self.end_layer, self.layers = make_layers(
            hf_config.num_hidden_layers,
            build_layer,
            prefix=f"{prefix}.layers",
        )

        # Final layer norm without learnable scale or bias.
        self.norm = nn.LayerNorm(hf_config.hidden_size,
                                 elementwise_affine=False,
                                 bias=False)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(["hidden_states"],
                                                    hf_config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this rank's slice of the decoder stack.

        Args:
            input_ids: Token ids; only embedded on the first pipeline rank
                and only when ``inputs_embeds`` is not given.
                NOTE(review): earlier notes describe the shape as
                ``(batch_size, seq_len)`` - confirm against the caller.
            positions: Position tensor forwarded to every decoder layer.
            intermediate_tensors: Hidden states handed over from the previous
                pipeline rank; must not be ``None`` on non-first ranks.
            inputs_embeds: Optional precomputed embeddings that take the
                place of the embedding lookup on the first rank.

        Returns:
            The normalized hidden states on the last pipeline rank, otherwise
            an ``IntermediateTensors`` holding ``"hidden_states"``.
        """
        if not get_pp_group().is_first_rank:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
        elif inputs_embeds is not None:
            hidden_states = inputs_embeds
        else:
            hidden_states = self.get_input_embeddings(input_ids)

        # Apply the decoder layers owned by this rank, one by one.
        for decoder_layer in islice(self.layers, self.start_layer,
                                    self.end_layer):
            hidden_states = decoder_layer(positions, hidden_states)

        # Only the last rank applies the final layer norm; earlier ranks
        # hand their hidden states on.
        if get_pp_group().is_last_rank:
            return self.norm(hidden_states)
        return IntermediateTensors({"hidden_states": hidden_states})

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Separate ``q_proj`` / ``k_proj`` / ``v_proj`` and ``gate_proj`` /
        ``up_proj`` checkpoint entries are routed into the fused ``qkv_proj``
        and ``gate_up_proj`` parameters with the matching shard id.

        Args:
            weights: ``(name, tensor)`` pairs from the checkpoint.

        Returns:
            The parameter names that were loaded (after renaming).
        """
        # (fused parameter name, checkpoint shard name, shard id)
        stacked_params_mapping = [
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_by_name = dict(self.named_parameters(remove_duplicate=False))
        loaded: set[str] = set()

        def should_skip(candidate: str) -> bool:
            # Skip loading extra bias for GPTQ models.
            if (candidate.endswith(".bias")
                    and candidate not in params_by_name):
                return True
            return is_pp_missing_parameter(candidate, self)

        for name, loaded_weight in weights:
            for fused_name, shard_name, shard_id in stacked_params_mapping:
                if shard_name not in name:
                    continue
                # ``name`` is rebound on purpose: the renamed value is what
                # later iterations and the ``else`` branch see.
                name = name.replace(shard_name, fused_name)
                if should_skip(name):
                    continue
                target = params_by_name[name]
                target.weight_loader(target, loaded_weight, shard_id)
                break
            else:
                # No stacked mapping consumed this entry: load it directly.
                if should_skip(name):
                    continue
                target = params_by_name[name]
                load_fn = getattr(target, "weight_loader",
                                  default_weight_loader)
                load_fn(target, loaded_weight)
            loaded.add(name)
        return loaded

Isotr0py's avatar
Isotr0py committed
334

335
class OlmoForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
    """Causal language model wrapper around ``OlmoModel``.

    Adds an LM head (shared with the input embedding when
    ``config.tie_word_embeddings`` is set) and a logits processor.

    Args:
        vllm_config: Supplies the HF model config and quantization config,
            and is forwarded to ``OlmoModel``.
        prefix: Dotted module path used to name the sub-modules.
    """

    # Fused module name -> names of the separate modules it packs together.
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = hf_config
        self.model = OlmoModel(vllm_config=vllm_config,
                               prefix=maybe_prefix(prefix, "model"))

        if not hf_config.tie_word_embeddings:
            # Untied: a dedicated LM head.
            self.unpadded_vocab_size = hf_config.vocab_size
            self.lm_head = ParallelLMHead(
                self.unpadded_vocab_size,
                hf_config.hidden_size,
                org_num_embeddings=hf_config.vocab_size,
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "lm_head"),
            )
        else:
            # Tied: the LM head is the input embedding module itself.
            self.lm_head = self.model.embed_tokens

        self.logits_processor = LogitsProcessor(hf_config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the token embeddings for ``input_ids``."""
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Delegate to the wrapped ``OlmoModel`` and return its result."""
        return self.model(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Turn hidden states into logits via the LM head.

        Args:
            hidden_states: Output of ``forward``.
            sampling_metadata: Passed through to the logits processor.

        Returns:
            Whatever the logits processor returns.
        """
        return self.logits_processor(self.lm_head, hidden_states,
                                     sampling_metadata)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint weights through ``AutoWeightsLoader``.

        With tied embeddings, ``lm_head.weight`` entries are skipped, since
        the LM head is the embedding module.

        Args:
            weights: ``(name, tensor)`` pairs from the checkpoint.

        Returns:
            The set of names reported by the loader.
        """
        skipped = (["lm_head.weight"]
                   if self.config.tie_word_embeddings else None)
        loader = AutoWeightsLoader(self, skip_prefixes=skipped)
        return loader.load_weights(weights)