persimmon.py 13.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
# adapted from https://github.com/huggingface/transformers/blob/v4.39.3/src/transformers/models/persimmon/modeling_persimmon.py
# Copyright 2023 The vLLM team.
# Copyright 2023 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only persimmon model compatible with HuggingFace weights."""
25

26
from collections.abc import Iterable
27
from itertools import islice
28
29
30
31
32

import torch
from torch import nn
from transformers import PersimmonConfig

33
from vllm.attention import Attention
34
from vllm.compilation.decorators import support_torch_compile
35
from vllm.config import CacheConfig, VllmConfig
36
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
37
from vllm.model_executor.layers.activation import get_act_fn
38
39
40
41
42
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
43
from vllm.model_executor.layers.logits_processor import LogitsProcessor
44
from vllm.model_executor.layers.quantization import QuantizationConfig
45
46
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
47
48
49
    ParallelLMHead,
    VocabParallelEmbedding,
)
50
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
51
from vllm.sequence import IntermediateTensors
52

53
from .interfaces import SupportsPP
54
55
56
57
58
59
60
from .utils import (
    AutoWeightsLoader,
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)
61

62
63

class PersimmonMLP(nn.Module):
    """Feed-forward block: hidden_size -> intermediate_size -> hidden_size.

    The up-projection is a ``ColumnParallelLinear`` and the down-projection a
    ``RowParallelLinear``; ``config.hidden_act`` is applied between them.
    """

    def __init__(
        self,
        config: PersimmonConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        hidden, intermediate = config.hidden_size, config.intermediate_size

        # Up-projection (names match the checkpoint keys).
        self.dense_h_to_4h = ColumnParallelLinear(
            hidden,
            intermediate,
            quant_config=quant_config,
            prefix=f"{prefix}.dense_h_to_4h",
        )
        # Down-projection back to the model width.
        self.dense_4h_to_h = RowParallelLinear(
            intermediate,
            hidden,
            quant_config=quant_config,
            prefix=f"{prefix}.dense_4h_to_h",
        )
        self.act = get_act_fn(config.hidden_act)

    def forward(self, hidden_states) -> torch.Tensor:
        # Each parallel linear returns a 2-tuple; only the first element
        # (the projected activations) is used.
        up, _ = self.dense_h_to_4h(hidden_states)
        down, _ = self.dense_4h_to_h(self.act(up))
        return down


class PersimmonAttention(nn.Module):
    """Self-attention for one Persimmon decoder layer.

    Pipeline: a biased fused QKV projection, optional per-head LayerNorm on
    queries and keys (``config.qk_layernorm``), rotary position embeddings
    configured with ``config.partial_rotary_factor``, the ``Attention`` op,
    and a biased output projection. Attention heads are divided evenly across
    the tensor-parallel world.
    """

    def __init__(
        self,
        config: PersimmonConfig,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        tensor_parallel_world_size = get_tensor_model_parallel_world_size()

        self.hidden_size = config.hidden_size
        self.total_num_heads = config.num_attention_heads
        # Number of heads held by this tensor-parallel rank.
        self.num_heads = self.total_num_heads // tensor_parallel_world_size
        self.head_dim = self.hidden_size // self.total_num_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.partial_rotary_factor = config.partial_rotary_factor
        self.is_causal = True

        assert (self.head_dim * self.total_num_heads) == self.hidden_size
        assert self.total_num_heads % tensor_parallel_world_size == 0

        self.query_key_value = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.total_num_heads,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.query_key_value",
        )
        self.dense = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            self.hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.dense",
        )
        self.is_qk_layernorm = config.qk_layernorm

        if self.is_qk_layernorm:
            # Use the model's configured epsilon, consistent with every other
            # LayerNorm in this model, instead of the nn.LayerNorm default.
            self.q_layernorm = nn.LayerNorm(self.head_dim, eps=config.layer_norm_eps)
            self.k_layernorm = nn.LayerNorm(self.head_dim, eps=config.layer_norm_eps)

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=self.max_position_embeddings,
            base=self.rope_theta,
            partial_rotary_factor=self.partial_rotary_factor,
        )
        self.scaling = self.head_dim**-0.5
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            scale=self.scaling,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

    def _split_heads(self, x: torch.Tensor) -> torch.Tensor:
        """Reshape [seq_length, num_heads * head_dim] -> [seq_length, num_heads, head_dim]."""
        seq_length = x.shape[0]
        return x.view(seq_length, self.num_heads, self.head_dim)

    def _merge_heads(self, x: torch.Tensor) -> torch.Tensor:
        """Reshape [seq_length, num_heads, head_dim] -> [seq_length, num_heads * head_dim]."""
        seq_length = x.shape[0]
        return x.view(seq_length, self.num_heads * self.head_dim)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run attention over ``hidden_states`` at positions ``position_ids``.

        Returns the output projection. The bias-like second element returned
        by ``self.dense`` is discarded.
        """
        # [seq_length, 3 * num_heads * head_dim] for this rank's heads. The
        # checkpoint loader reorders the fused weight into [q | k | v] blocks,
        # so an even 3-way split on the last dim recovers each projection.
        qkv, _ = self.query_key_value(hidden_states)
        q, k, v = qkv.chunk(chunks=3, dim=-1)

        if self.is_qk_layernorm:
            # The norms act on head_dim, so normalize each head separately.
            q = self._merge_heads(self.q_layernorm(self._split_heads(q)))
            k = self._merge_heads(self.k_layernorm(self._split_heads(k)))

        q, k = self.rotary_emb(position_ids, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.dense(attn_output)
        return output


class PersimmonDecoderLayer(nn.Module):
    """One pre-norm Persimmon transformer block.

    Computes ``h = x + attn(ln_in(x))`` and then ``h + mlp(ln_post(h))``.
    Both norms are ``nn.LayerNorm`` with ``config.layer_norm_eps``.
    """

    def __init__(
        self,
        config: PersimmonConfig,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        width = config.hidden_size
        norm_eps = config.layer_norm_eps
        self.hidden_size = width

        self.self_attn = PersimmonAttention(
            config=config,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )
        self.mlp = PersimmonMLP(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
        )
        self.input_layernorm = nn.LayerNorm(width, eps=norm_eps)
        self.post_attention_layernorm = nn.LayerNorm(width, eps=norm_eps)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        # Attention sub-block; the incoming tensor is the residual.
        attn_out = self.self_attn(
            position_ids=position_ids,
            hidden_states=self.input_layernorm(hidden_states),
        )
        hidden_states = hidden_states + attn_out

        # Feed-forward sub-block, again with a residual connection.
        mlp_out = self.mlp(self.post_attention_layernorm(hidden_states))
        return mlp_out + hidden_states


245
@support_torch_compile
class PersimmonModel(nn.Module):
    """Persimmon decoder stack: token embedding, decoder layers, final norm.

    Pipeline parallelism:
    - The first rank embeds tokens; later ranks read ``hidden_states`` from
      ``intermediate_tensors``.
    - Every rank except the last returns ``IntermediateTensors`` instead of
      the final-normed hidden states.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        hf_config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.vocab_size = hf_config.vocab_size
        self.config = hf_config
        self.embed_tokens = VocabParallelEmbedding(
            hf_config.vocab_size, hf_config.hidden_size
        )

        def build_layer(prefix: str) -> PersimmonDecoderLayer:
            return PersimmonDecoderLayer(
                hf_config, cache_config, quant_config, prefix=prefix
            )

        # Only layers in [start_layer, end_layer) are run by forward() on
        # this pipeline stage.
        self.start_layer, self.end_layer, self.layers = make_layers(
            hf_config.num_hidden_layers,
            build_layer,
            prefix=f"{prefix}.layers",
        )
        self.final_layernorm = nn.LayerNorm(
            hf_config.hidden_size, eps=hf_config.layer_norm_eps
        )
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states"], hf_config.hidden_size
        )

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        pp_group = get_pp_group()

        if not pp_group.is_first_rank:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
        elif inputs_embeds is None:
            hidden_states = self.get_input_embeddings(input_ids)
        else:
            hidden_states = inputs_embeds

        for layer in islice(self.layers, self.start_layer, self.end_layer):
            hidden_states = layer(positions, hidden_states)

        if pp_group.is_last_rank:
            return self.final_layernorm(hidden_states)
        return IntermediateTensors({"hidden_states": hidden_states})

    @staticmethod
    def _reorder_fused_qkv(
        loaded_weight: torch.Tensor, output_dim: int, num_heads: int
    ) -> torch.Tensor:
        """Permute a fused QKV tensor along ``output_dim``.

        The checkpoint lays that dimension out as (num_heads, 3, head_size).
        The QKV layer here expects (3, num_heads, head_size). All other
        dimensions and the overall shape are unchanged.
        """
        shape = loaded_weight.shape
        per_head = loaded_weight.view(
            shape[:output_dim] + (num_heads, 3, -1) + shape[output_dim + 1 :]
        )
        return per_head.transpose(output_dim, output_dim + 1).reshape(shape)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this stage's parameters.

        Returns the set of parameter names that were loaded. Names owned by
        another pipeline stage are skipped.
        """
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            if is_pp_missing_parameter(name, self):
                continue
            param = params_dict[name]

            # Layout conversion adapted from vllm/model_executor/models/bloom.py.
            if "query_key_value" in name:
                output_dim = getattr(param, "output_dim", None)
                if output_dim is not None:
                    loaded_weight = self._reorder_fused_qkv(
                        loaded_weight, output_dim, self.config.num_attention_heads
                    )

            loader_fn = getattr(param, "weight_loader", default_weight_loader)
            loader_fn(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params

329

330
class PersimmonForCausalLM(nn.Module, SupportsPP):
    """Persimmon causal LM: ``PersimmonModel`` plus a parallel LM head.

    ``forward`` returns hidden states (or ``IntermediateTensors`` on
    non-final pipeline stages). ``compute_logits`` projects them to
    vocabulary logits.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        self.config = hf_config
        self.vocab_size = hf_config.vocab_size

        self.model = PersimmonModel(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )
        self.lm_head = ParallelLMHead(
            hf_config.vocab_size,
            hf_config.hidden_size,
            bias=False,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        self.logits_processor = LogitsProcessor(hf_config.vocab_size)
        # Expose the inner model's factory for pipeline-parallel placeholders.
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ):
        return self.model(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        return self.logits_processor(self.lm_head, hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        # AutoWeightsLoader dispatches weights to submodules such as
        # self.model.load_weights.
        return AutoWeightsLoader(self).load_weights(weights)