# SPDX-License-Identifier: Apache-2.0

# adapted from https://github.com/huggingface/transformers/blob/v4.39.3/src/transformers/models/persimmon/modeling_persimmon.py
# Copyright 2023 The vLLM team.
# Copyright 2023 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only persimmon model compatible with HuggingFace weights."""

from collections.abc import Iterable
from typing import Optional, Union

import torch
from torch import nn
from transformers import PersimmonConfig

from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

from .interfaces import SupportsPP
from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)


class PersimmonMLP(nn.Module):
    """Persimmon feed-forward block.

    Up-projects ``hidden_size -> intermediate_size`` with a column-parallel
    linear layer, applies the activation named by ``config.hidden_act``, then
    down-projects ``intermediate_size -> hidden_size`` with a row-parallel
    linear layer.
    """

    def __init__(self,
                 config: PersimmonConfig,
                 quant_config: Optional[QuantizationConfig] = None):
        super().__init__()
        self.dense_h_to_4h = ColumnParallelLinear(config.hidden_size,
                                                  config.intermediate_size,
                                                  quant_config=quant_config)
        self.dense_4h_to_h = RowParallelLinear(config.intermediate_size,
                                               config.hidden_size,
                                               quant_config=quant_config)
        self.act = get_act_fn(config.hidden_act)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Run up-projection, activation and down-projection.

        Args:
            hidden_states: Input activations whose last dimension is
                ``config.hidden_size``.

        Returns:
            The down-projected activations.
        """
        # The parallel linear layers return a pair; only the first element
        # (the projected tensor) is used here.
        hidden_states, _ = self.dense_h_to_4h(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.dense_4h_to_h(hidden_states)
        return hidden_states


class PersimmonAttention(nn.Module):
    """Multi-head self-attention for Persimmon.

    Pipeline: fused QKV projection -> optional per-head LayerNorm on queries
    and keys (``config.qk_layernorm``) -> rotary position embedding with
    ``config.partial_rotary_factor`` -> vLLM ``Attention`` -> output
    projection (``dense``).
    """

    def __init__(self,
                 config: PersimmonConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()
        self.config = config
        tensor_parallel_world_size = get_tensor_model_parallel_world_size()

        self.hidden_size = config.hidden_size
        self.total_num_heads = config.num_attention_heads
        # Number of heads handled by this tensor-parallel rank.
        self.num_heads = self.total_num_heads // tensor_parallel_world_size
        self.head_dim = self.hidden_size // self.total_num_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.partial_rotary_factor = config.partial_rotary_factor
        self.is_causal = True

        assert (self.head_dim * self.total_num_heads) == self.hidden_size
        assert self.total_num_heads % tensor_parallel_world_size == 0

        # Forward the module prefix to the linear layers (as is already done
        # for ``self.attn``) so name-based quantization configs can match
        # these layers.
        self.query_key_value = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.total_num_heads,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.query_key_value",
        )
        self.dense = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            self.hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.dense",
        )
        self.is_qk_layernorm = config.qk_layernorm

        if self.is_qk_layernorm:
            # Normalizes each head's vector independently (over head_dim).
            self.q_layernorm = nn.LayerNorm(self.head_dim)
            self.k_layernorm = nn.LayerNorm(self.head_dim)

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=self.max_position_embeddings,
            base=self.rope_theta,
            partial_rotary_factor=self.partial_rotary_factor,
        )
        self.scaling = self.head_dim**-0.5
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              scale=self.scaling,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn")

    def _split_heads(self, x: torch.Tensor) -> torch.Tensor:
        """Reshape ``x`` so each of this rank's heads gets its own axis."""
        # [seq_length, num_heads * head_dim] -> [seq_length, num_heads, head_dim]
        # (num_heads is the per-rank head count).
        seq_length = x.shape[0]
        return x.view(seq_length, self.num_heads, self.head_dim)

    def _merge_heads(self, x: torch.Tensor) -> torch.Tensor:
        """Inverse of :meth:`_split_heads`: flatten the head axis back."""
        # [seq_length, num_heads, head_dim] -> [seq_length, num_heads * head_dim]
        seq_length = x.shape[0]
        return x.view(seq_length, self.num_heads * self.head_dim)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Compute self-attention for ``hidden_states``.

        Args:
            position_ids: Token positions, passed to the rotary embedding.
            hidden_states: Input activations, first dimension is the
                sequence/token axis.

        Returns:
            Output of the ``dense`` projection.
        """
        # [seq_length, 3 * num_heads * head_dim] for this rank's heads; the
        # weight loader arranges the fused weight as [q | k | v] so a plain
        # chunk recovers the three projections.
        qkv, _ = self.query_key_value(hidden_states)
        q, k, v = qkv.chunk(chunks=3, dim=-1)

        if self.is_qk_layernorm:
            # [seq_length, num_heads, head_dim]
            q = self._split_heads(q)
            k = self._split_heads(k)

            q = self.q_layernorm(q)
            k = self.k_layernorm(k)

            q = self._merge_heads(q)
            k = self._merge_heads(k)

        q, k = self.rotary_emb(position_ids, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.dense(attn_output)
        return output


class PersimmonDecoderLayer(nn.Module):
    """One Persimmon transformer block with pre-LayerNorm residuals.

    Computes ``h = x + self_attn(input_layernorm(x))`` and then
    ``out = mlp(post_attention_layernorm(h)) + h``.
    """

    def __init__(self,
                 config: PersimmonConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = PersimmonAttention(config=config,
                                            cache_config=cache_config,
                                            quant_config=quant_config,
                                            prefix=f"{prefix}.self_attn")
        self.mlp = PersimmonMLP(config, quant_config=quant_config)
        self.input_layernorm = nn.LayerNorm(config.hidden_size,
                                            eps=config.layer_norm_eps)
        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
                                                     eps=config.layer_norm_eps)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Apply attention and MLP sub-blocks, each with a residual add.

        Args:
            position_ids: Token positions, forwarded to self-attention.
            hidden_states: Input activations of width ``hidden_size``.

        Returns:
            The block's output activations.
        """
        # Self-attention sub-block.
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(
            position_ids=position_ids,
            hidden_states=hidden_states,
        )
        hidden_states = residual + hidden_states

        # Feed-forward sub-block.
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        return hidden_states + residual


@support_torch_compile
class PersimmonModel(nn.Module):
    """Persimmon decoder stack: token embedding, decoder layers, final norm.

    Pipeline-parallel aware: only ``layers[start_layer:end_layer]`` run on
    this rank, and non-last ranks return ``IntermediateTensors`` instead of
    normalized hidden states.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.vocab_size = config.vocab_size
        self.config = config
        self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
                                                   config.hidden_size)
        # start_layer/end_layer bound the slice of layers that forward()
        # iterates on this rank.
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: PersimmonDecoderLayer(
                config, cache_config, quant_config, prefix=prefix),
            prefix=f"{prefix}.layers")
        self.final_layernorm = nn.LayerNorm(config.hidden_size,
                                            eps=config.layer_norm_eps)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(["hidden_states"],
                                                    config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this rank's portion of the decoder stack.

        On the first pipeline rank the input is ``inputs_embeds`` when given,
        otherwise the embedding of ``input_ids``; on later ranks it is
        ``intermediate_tensors["hidden_states"]``.

        Returns:
            ``IntermediateTensors`` on non-last ranks, otherwise the hidden
            states after ``final_layernorm``.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]

        for layer in self.layers[self.start_layer:self.end_layer]:
            hidden_states = layer(positions, hidden_states)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})

        hidden_states = self.final_layernorm(hidden_states)
        return hidden_states

    @staticmethod
    def _reorder_fused_qkv(loaded_weight: torch.Tensor, output_dim: int,
                           num_heads: int) -> torch.Tensor:
        """Permute a fused QKV tensor from per-head to per-projection order.

        Persimmon checkpoints lay out ``output_dim`` as
        ``(num_heads, 3, head_dim)``; the fused QKV layer expects
        ``(3, num_heads, head_dim)``. The returned tensor has the same shape
        as ``loaded_weight``.
        """
        shape = loaded_weight.shape
        loaded_weight = loaded_weight.view(shape[:output_dim] +
                                           (num_heads, 3, -1) +
                                           shape[output_dim + 1:])
        # Swap the head axis and the q/k/v axis.
        loaded_weight = loaded_weight.transpose(output_dim, output_dim + 1)
        return loaded_weight.reshape(shape)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Args:
            weights: ``(name, tensor)`` pairs whose names match this module's
                parameter names.

        Returns:
            The set of parameter names that were loaded.
        """
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: set[str] = set()
        num_heads = self.config.num_attention_heads
        for name, loaded_weight in weights:
            # Presumably skips weights of layers not instantiated on this
            # pipeline-parallel rank.
            if is_pp_missing_parameter(name, self):
                continue
            param = params_dict[name]

            if "query_key_value" in name:
                # Adapted from vllm/model_executor/models/bloom.py.
                output_dim = getattr(param, "output_dim", None)
                if output_dim is not None:
                    loaded_weight = self._reorder_fused_qkv(
                        loaded_weight, output_dim, num_heads)

            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params


class PersimmonForCausalLM(nn.Module, SupportsPP):
    """Persimmon causal LM: ``PersimmonModel`` backbone plus an LM head."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        self.config = config
        self.vocab_size = config.vocab_size
        self.model = PersimmonModel(vllm_config=vllm_config,
                                    prefix=maybe_prefix(prefix, "model"))
        self.lm_head = ParallelLMHead(config.vocab_size,
                                      config.hidden_size,
                                      bias=False)
        self.logits_processor = LogitsProcessor(config.vocab_size)
        # Re-exported from the backbone for pipeline-parallel callers.
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings via the backbone."""
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the backbone and return whatever it returns.

        On the last pipeline rank this is the final hidden states; on other
        ranks it is ``IntermediateTensors`` (see ``PersimmonModel.forward``).
        """
        hidden_states = self.model(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project ``hidden_states`` to vocabulary logits with ``lm_head``.

        The result may be ``None``, per the return annotation.
        """
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, returning the loaded parameter names."""
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)