# SPDX-License-Identifier: Apache-2.0

# Adapted from
# https://github.com/huggingface/transformers/blob/v4.40.1/src/transformers/models/olmo/modeling_olmo.py
# Copyright 2024 The vLLM team.
# Copyright 2024 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only OLMo model compatible with HuggingFace weights."""
25
from typing import Iterable, List, Optional, Set, Tuple, Union
Isotr0py's avatar
Isotr0py committed
26
27
28

import torch
from torch import nn
29
from transformers import OlmoConfig
Isotr0py's avatar
Isotr0py committed
30

31
from vllm.attention import Attention, AttentionMetadata
32
from vllm.compilation.decorators import support_torch_compile
33
from vllm.config import CacheConfig, VllmConfig
34
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
35
from vllm.model_executor.layers.activation import SiluAndMul
36
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
37
38
                                               QKVParallelLinear,
                                               RowParallelLinear)
39
from vllm.model_executor.layers.logits_processor import LogitsProcessor
40
from vllm.model_executor.layers.quantization import QuantizationConfig
41
from vllm.model_executor.layers.rotary_embedding import get_rope
Joe Runde's avatar
Joe Runde committed
42
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
43
from vllm.model_executor.layers.vocab_parallel_embedding import (
44
    ParallelLMHead, VocabParallelEmbedding)
45
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
Isotr0py's avatar
Isotr0py committed
46
from vllm.model_executor.sampling_metadata import SamplingMetadata
47
from vllm.sequence import IntermediateTensors
48

49
50
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
51
52
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)
53

Isotr0py's avatar
Isotr0py committed
54
55
56

class OlmoAttention(nn.Module):
    """
    This is the attention block where the output is computed as
    ``Attention(LN(x))`` in ``MLP(LN(x + Attention(LN(x))))``
    (plus another skip connection).

    The block projects the input to a fused QKV tensor, optionally clamps
    it to ``[-clip_qkv, clip_qkv]``, applies rotary embeddings to the
    queries and keys, runs attention and projects the result back to
    ``hidden_size``. Layer normalization and the residual connection are
    not applied here.

    Args:
        config: HF OLMo configuration; ``hidden_size``,
            ``num_attention_heads``, ``max_position_embeddings``,
            ``rope_theta``, ``clip_qkv`` and ``attention_bias`` are read.
        cache_config: Forwarded to the ``Attention`` layer.
        quant_config: Forwarded to the linear and attention layers.
        prefix: Dotted name of this module. It is extended with
            ``.qkv_proj``, ``.attn`` and ``.o_proj`` and forwarded to the
            corresponding sub-layers so that each of them knows its own
            name.
    """

    def __init__(
        self,
        config: OlmoConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        tensor_model_parallel_world_size = (
            get_tensor_model_parallel_world_size())
        self.total_num_heads = config.num_attention_heads

        # Heads must split evenly both across the hidden dimension and
        # across the tensor-parallel ranks.
        assert self.hidden_size % self.total_num_heads == 0
        assert self.total_num_heads % tensor_model_parallel_world_size == 0

        self.num_heads = (self.total_num_heads //
                          tensor_model_parallel_world_size)
        self.head_dim = self.hidden_size // self.total_num_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.clip_qkv = config.clip_qkv

        # Attention input projection. Projects x -> (q, k, v)
        self.qkv_proj = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.total_num_heads,
            bias=config.attention_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        # Rotary embeddings.
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=self.max_position_embeddings,
            base=self.rope_theta,
        )
        self.scaling = self.head_dim**-0.5
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              scale=self.scaling,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn")

        # Attention output projection.
        self.o_proj = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=config.attention_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run attention on ``hidden_states``.

        Args:
            positions: Positions passed to the rotary embedding.
            hidden_states: Input to the fused QKV projection.
            kv_cache: Cache tensor handed to the attention layer.
            attn_metadata: Metadata handed to the attention layer.

        Returns:
            The output of the attention output projection.
        """
        qkv, _ = self.qkv_proj(hidden_states)
        if self.clip_qkv is not None:
            # In-place clamp of the fused tensor, i.e. q, k and v at once.
            qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
        # q, k and v are equally sized slices of the last dimension.
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.o_proj(attn_output)
        return output


class OlmoMLP(nn.Module):
    """
    This is the MLP block where the output is computed as
    ``MLP(LN(x))`` in ``MLP(LN(x + Attention(LN(x))))``
    (plus another skip connection).

    The block is a bias-free gated feed-forward network: a merged
    gate/up projection, ``SiluAndMul`` and a down projection. Layer
    normalization and the residual connection are not applied here.

    Args:
        config: HF OLMo configuration; ``hidden_size`` and
            ``intermediate_size`` are read.
        quant_config: Forwarded to both linear layers.
        prefix: Dotted name of this module. It is extended with
            ``.gate_up_proj`` and ``.down_proj`` and forwarded to the
            corresponding linear layers. Defaults to ``""`` so existing
            two-argument callers keep working.
    """

    def __init__(
        self,
        config: OlmoConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size

        # Feed-forward input projection.
        self.gate_up_proj = MergedColumnParallelLinear(
            self.hidden_size,
            [self.intermediate_size] * 2,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )

        # Activation function.
        self.act_fn = SiluAndMul()

        # Feed-forward output projection.
        self.down_proj = RowParallelLinear(
            self.intermediate_size,
            self.hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
        )

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """Apply gate/up projection, activation and down projection."""
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


class OlmoDecoderLayer(nn.Module):
    """
    This is a typical transformer block where the output is
    computed as ``MLP(LN(x + Attention(LN(x))))``
    (plus another skip connection).

    Both layer norms are created with ``elementwise_affine=False``, so
    they carry no learnable parameters.

    Args:
        config: HF OLMo configuration.
        cache_config: Forwarded to the attention block.
        quant_config: Forwarded to the attention and MLP blocks.
        prefix: Dotted name of this layer; ``.self_attn`` and ``.mlp``
            are appended for the two sub-blocks.
    """

    def __init__(self,
                 config: OlmoConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()
        # Attention block.
        self.self_attn = OlmoAttention(config,
                                       cache_config,
                                       quant_config,
                                       prefix=f"{prefix}.self_attn")

        # MLP block.
        self.mlp = OlmoMLP(config, quant_config, prefix=f"{prefix}.mlp")

        # LayerNorm
        self.input_layernorm = nn.LayerNorm(config.hidden_size,
                                            elementwise_affine=False,
                                            bias=False)
        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
                                                     elementwise_affine=False,
                                                     bias=False)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run one pre-norm decoder layer and return the new hidden states.

        Args:
            positions: Positions forwarded to the attention block.
            hidden_states: Input of the layer.
            kv_cache: Cache tensor forwarded to the attention block.
            attn_metadata: Metadata forwarded to the attention block.

        Returns:
            A single tensor: ``h + MLP(LN(h))`` where
            ``h = x + Attention(LN(x))``.
        """
        # Attention block.
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(positions, hidden_states, kv_cache,
                                       attn_metadata)
        hidden_states = hidden_states + residual

        # MLP block.
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states


233
@support_torch_compile
class OlmoModel(nn.Module):
    """Token embedding, decoder layers and final LayerNorm of OLMo.

    ``make_layers`` returns the layer container together with the index
    range ``[start_layer, end_layer)`` that ``forward`` iterates over.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        hf_config = vllm_config.model_config.hf_config
        layer_cache_config = vllm_config.cache_config
        layer_quant_config = vllm_config.quant_config
        self.config = hf_config

        hidden_size = hf_config.hidden_size
        self.embed_tokens = VocabParallelEmbedding(hf_config.vocab_size,
                                                   hidden_size)

        def build_layer(prefix: str) -> OlmoDecoderLayer:
            # The parameter is named ``prefix`` like the original lambda's
            # so the factory keeps the same call signature.
            return OlmoDecoderLayer(hf_config,
                                    layer_cache_config,
                                    layer_quant_config,
                                    prefix=prefix)

        self.start_layer, self.end_layer, self.layers = make_layers(
            hf_config.num_hidden_layers,
            build_layer,
            prefix=f"{prefix}.layers",
        )
        # Final norm without learnable scale or bias.
        self.norm = nn.LayerNorm(hidden_size,
                                 elementwise_affine=False,
                                 bias=False)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(["hidden_states"],
                                                    hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up the token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the decoder layers in ``[start_layer, end_layer)``.

        :param input_ids: Token ids; embedded on the first pipeline rank
            unless ``inputs_embeds`` is given.
        :param positions: Positions forwarded to every decoder layer.
        :param kv_caches: One cache tensor per layer in the iterated
            range, indexed relative to ``start_layer``.
        :param attn_metadata: Metadata forwarded to every decoder layer.
        :param intermediate_tensors: Source of ``"hidden_states"`` on
            ranks other than the first; must not be ``None`` there.
        :param inputs_embeds: Optional precomputed embeddings used on the
            first rank instead of embedding ``input_ids``.
        :return: The normalized hidden states on the last rank, otherwise
            an ``IntermediateTensors`` holding ``"hidden_states"``.
        """
        if not get_pp_group().is_first_rank:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
        elif inputs_embeds is None:
            hidden_states = self.get_input_embeddings(input_ids)
        else:
            hidden_states = inputs_embeds

        # Apply blocks one-by-one.
        first_layer = self.start_layer
        for layer_idx in range(first_layer, self.end_layer):
            decoder_layer = self.layers[layer_idx]
            hidden_states = decoder_layer(
                positions,
                hidden_states,
                kv_caches[layer_idx - first_layer],
                attn_metadata,
            )

        if get_pp_group().is_last_rank:
            # Apply final layer norm.
            return self.norm(hidden_states)
        return IntermediateTensors({"hidden_states": hidden_states})
Isotr0py's avatar
Isotr0py committed
299
300


301
class OlmoForCausalLM(nn.Module, SupportsPP):
    """
    Extremely barebones HF model wrapper.

    Wraps ``OlmoModel`` with an LM head (shared with the input embedding
    when ``config.tie_word_embeddings`` is set), a logits processor and a
    sampler, and implements loading of HF-style checkpoints.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.model = OlmoModel(vllm_config=vllm_config,
                               prefix=maybe_prefix(prefix, "model"))
        # Set unconditionally so the attribute exists for tied and untied
        # embeddings alike.
        self.unpadded_vocab_size = config.vocab_size
        if config.tie_word_embeddings:
            self.lm_head = self.model.embed_tokens
        else:
            self.lm_head = ParallelLMHead(
                self.unpadded_vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "lm_head"),
            )
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = get_sampler()
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate the embedding lookup to the wrapped ``OlmoModel``."""
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Forward all arguments to ``self.model`` and return its result."""
        hidden_states = self.model(
            input_ids=input_ids,
            positions=positions,
            kv_caches=kv_caches,
            attn_metadata=attn_metadata,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Turn hidden states into logits via the logits processor."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Run the sampler on ``logits``."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint tensors into this module's parameters.

        Separate ``q_proj``/``k_proj``/``v_proj`` and ``gate_proj``/
        ``up_proj`` checkpoint tensors are routed into the fused
        ``qkv_proj`` and ``gate_up_proj`` parameters with their shard id.

        Args:
            weights: ``(name, tensor)`` pairs from the checkpoint.

        Returns:
            The (possibly remapped) names of the parameters that were
            loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if ("rotary_emb.cos_cached" in name
                    or "rotary_emb.sin_cached" in name):
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            # With tie_word_embeddings, we can skip lm_head.weight
            # The weight might appear unnecessarily in the files if the model is
            # processed with quantization, LoRA, fine-tuning, etc.
            if self.config.tie_word_embeddings and "lm_head.weight" in name:
                continue

            # Resolve the target parameter name exactly once. Stopping at
            # the first match matters: "v_proj" is a substring of
            # "qkv_proj", so scanning further with the rewritten name would
            # rewrite it a second time.
            shard_id = None
            for param_name, weight_name, mapped_shard in (
                    stacked_params_mapping):
                if weight_name in name:
                    name = name.replace(weight_name, param_name)
                    shard_id = mapped_shard
                    break

            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            if is_pp_missing_parameter(name, self):
                continue

            param = params_dict[name]
            if shard_id is not None:
                # Fused parameters must provide their own shard-aware
                # loader; note that shard id 0 is a valid value.
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
            else:
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params