# adapted from https://github.com/huggingface/transformers/blob/v4.39.3/src/transformers/models/persimmon/modeling_persimmon.py
# Copyright 2023 The vLLM team.
# Copyright 2023 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only persimmon model compatible with HuggingFace weights."""
22
from typing import Iterable, List, Optional, Set, Tuple, Union
23
24
25
26
27
28

import torch
from torch import nn
from transformers import PersimmonConfig

from vllm.attention import Attention, AttentionMetadata
29
from vllm.compilation.decorators import support_torch_compile
30
from vllm.config import CacheConfig, VllmConfig
31
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
32
from vllm.model_executor.layers.activation import get_act_fn
33
34
35
36
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
37
from vllm.model_executor.layers.quantization import QuantizationConfig
38
from vllm.model_executor.layers.rotary_embedding import get_rope
Joe Runde's avatar
Joe Runde committed
39
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
40
41
42
43
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
44
from vllm.sequence import IntermediateTensors
45

46
47
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
48
49
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)
50

51
52
53
54
55
56
57
58
59
60
61
62
63

class PersimmonMLP(nn.Module):
    """Feed-forward block: linear projection, activation, linear projection.

    The first projection maps ``hidden_size`` to ``intermediate_size`` and
    the second maps back to ``hidden_size``. The activation is looked up
    from ``config.hidden_act``.
    """

    def __init__(self,
                 config: PersimmonConfig,
                 quant_config: Optional[QuantizationConfig] = None):
        super().__init__()
        model_width = config.hidden_size
        inner_width = config.intermediate_size
        self.dense_h_to_4h = ColumnParallelLinear(model_width,
                                                  inner_width,
                                                  quant_config=quant_config)
        self.dense_4h_to_h = RowParallelLinear(inner_width,
                                               model_width,
                                               quant_config=quant_config)
        self.act = get_act_fn(config.hidden_act)

    def forward(self, hidden_states) -> torch.Tensor:
        # Each parallel linear layer returns a pair; only the first element
        # (the projected tensor) is used here.
        expanded, _ = self.dense_h_to_4h(hidden_states)
        activated = self.act(expanded)
        projected, _ = self.dense_4h_to_h(activated)
        return projected


class PersimmonAttention(nn.Module):
    """Persimmon self-attention layer.

    Projects the input with a fused QKV linear layer, optionally applies a
    per-head LayerNorm to the queries and keys (``config.qk_layernorm``),
    applies rotary position embeddings to a fraction of each head
    (``config.partial_rotary_factor``), runs attention and finally projects
    the result with ``dense``.

    Args:
        config: Model configuration providing the hidden size, head count,
            rotary settings and LayerNorm epsilon.
        cache_config: Forwarded to the ``Attention`` layer.
        quant_config: Forwarded to the linear layers and the ``Attention``
            layer.
        prefix: Module name prefix; ``"<prefix>.attn"`` is passed to the
            ``Attention`` layer.
    """

    def __init__(self,
                 config: PersimmonConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()
        self.config = config
        tensor_parallel_world_size = get_tensor_model_parallel_world_size()

        self.hidden_size = config.hidden_size
        self.total_num_heads = config.num_attention_heads
        # Check divisibility before the floor division below so that an
        # invalid tensor-parallel size is reported before a truncated
        # per-rank head count is ever computed.
        assert self.total_num_heads % tensor_parallel_world_size == 0
        self.num_heads = self.total_num_heads // tensor_parallel_world_size
        self.head_dim = self.hidden_size // self.total_num_heads
        assert (self.head_dim * self.total_num_heads) == self.hidden_size

        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.partial_rotary_factor = config.partial_rotary_factor
        self.is_causal = True

        self.query_key_value = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.total_num_heads,
            bias=True,
            quant_config=quant_config,
        )
        self.dense = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            self.hidden_size,
            bias=True,
            quant_config=quant_config,
        )
        self.is_qk_layernorm = config.qk_layernorm

        if self.is_qk_layernorm:
            # Use the configured epsilon, consistent with the other
            # LayerNorm modules of this model. Without an explicit value
            # nn.LayerNorm falls back to its own default (1e-5), which
            # silently diverges for configs that set a different
            # ``layer_norm_eps``.
            self.q_layernorm = nn.LayerNorm(self.head_dim,
                                            eps=config.layer_norm_eps)
            self.k_layernorm = nn.LayerNorm(self.head_dim,
                                            eps=config.layer_norm_eps)

        # Only the first ``partial_rotary_factor * head_dim`` dimensions of
        # each head are rotated.
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=int(self.partial_rotary_factor * self.head_dim),
            max_position=self.max_position_embeddings,
            base=self.rope_theta,
        )
        self.scaling = self.head_dim**-0.5
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              scale=self.scaling,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn")

    def _split_heads(self, x: torch.Tensor) -> torch.Tensor:
        """Reshape ``[seq_length, num_heads * head_dim]`` to
        ``[seq_length, num_heads, head_dim]``."""
        seq_length = x.shape[0]
        return x.view(seq_length, self.num_heads, self.head_dim)

    def _merge_heads(self, x: torch.Tensor) -> torch.Tensor:
        """Reshape ``[seq_length, num_heads, head_dim]`` back to
        ``[seq_length, num_heads * head_dim]``."""
        seq_length = x.shape[0]
        return x.view(seq_length, self.num_heads * self.head_dim)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run attention over ``hidden_states`` and return the output of the
        final ``dense`` projection."""
        # The fused projection is split into three equal parts along the
        # last dimension: queries, keys and values.
        qkv, _ = self.query_key_value(hidden_states)
        q, k, v = qkv.chunk(chunks=3, dim=-1)

        if self.is_qk_layernorm:
            # LayerNorm is applied over head_dim, so expose the head axis,
            # normalize, then flatten back for the rotary embedding.
            # [seq_length, num_heads, head_dim]
            q = self._split_heads(q)
            k = self._split_heads(k)

            q = self.q_layernorm(q)
            k = self.k_layernorm(k)

            q = self._merge_heads(q)
            k = self._merge_heads(k)

        q, k = self.rotary_emb(position_ids, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.dense(attn_output)
        return output


class PersimmonDecoderLayer(nn.Module):
    """One decoder block: LayerNorm + self-attention, then LayerNorm + MLP,
    each wrapped in a residual connection."""

    def __init__(self,
                 config: PersimmonConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = PersimmonAttention(
            config=config,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )
        self.mlp = PersimmonMLP(config, quant_config=quant_config)
        # Both norms share the configured width and epsilon.
        norm_eps = config.layer_norm_eps
        self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=norm_eps)
        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
                                                     eps=norm_eps)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        # Attention sub-block: normalize, attend, add the block input back.
        attn_input = self.input_layernorm(hidden_states)
        attn_output = self.self_attn(
            position_ids=position_ids,
            hidden_states=attn_input,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        after_attn = hidden_states + attn_output

        # Feed-forward sub-block: normalize, project, add the residual.
        mlp_output = self.mlp(self.post_attention_layernorm(after_attn))
        return mlp_output + after_attn


217
@support_torch_compile
class PersimmonModel(nn.Module):
    """Token embedding, the decoder layers owned by this pipeline stage, and
    the final LayerNorm."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        hf_config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.vocab_size = hf_config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(hf_config.vocab_size,
                                                   hf_config.hidden_size)

        def build_layer(prefix: str) -> PersimmonDecoderLayer:
            # The parameter keeps the name ``prefix`` so the factory accepts
            # the same keyword as the lambda it replaces.
            return PersimmonDecoderLayer(hf_config,
                                         cache_config,
                                         quant_config,
                                         prefix=prefix)

        self.start_layer, self.end_layer, self.layers = make_layers(
            hf_config.num_hidden_layers,
            build_layer,
            prefix=f"{prefix}.layers")
        self.final_layernorm = nn.LayerNorm(hf_config.hidden_size,
                                            eps=hf_config.layer_norm_eps)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(["hidden_states"],
                                                    hf_config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up the embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this stage's decoder layers.

        Returns the normalized hidden states on the last pipeline rank and
        an ``IntermediateTensors`` holding ``"hidden_states"`` otherwise.
        """
        if not get_pp_group().is_first_rank:
            # Later stages continue from the previous stage's output.
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
        elif inputs_embeds is None:
            hidden_states = self.get_input_embeddings(input_ids)
        else:
            hidden_states = inputs_embeds

        # ``kv_caches`` is indexed relative to the first layer of this stage.
        layer_range = range(self.start_layer, self.end_layer)
        for cache_idx, layer_idx in enumerate(layer_range):
            decoder_layer = self.layers[layer_idx]
            hidden_states = decoder_layer(
                positions,
                hidden_states,
                kv_caches[cache_idx],
                attn_metadata,
            )

        if get_pp_group().is_last_rank:
            return self.final_layernorm(hidden_states)
        return IntermediateTensors({"hidden_states": hidden_states})


275
class PersimmonForCausalLM(nn.Module, SupportsPP):
    """Persimmon decoder stack with an LM head, logits processor and
    sampler, plus checkpoint weight loading."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        self.config = hf_config
        self.vocab_size = hf_config.vocab_size
        self.model = PersimmonModel(vllm_config=vllm_config,
                                    prefix=maybe_prefix(prefix, "model"))
        self.lm_head = ParallelLMHead(hf_config.vocab_size,
                                      hf_config.hidden_size,
                                      bias=False)
        self.logits_processor = LogitsProcessor(hf_config.vocab_size)
        self.sampler = get_sampler()
        # Re-export the inner model's factory under the same attribute name.
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate the embedding lookup to the wrapped model."""
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ):
        """Run the wrapped model and return its output unchanged."""
        return self.model(
            input_ids=input_ids,
            positions=positions,
            kv_caches=kv_caches,
            attn_metadata=attn_metadata,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Apply the logits processor using the LM head."""
        return self.logits_processor(self.lm_head, hidden_states,
                                     sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Run the sampler on ``logits``."""
        return self.sampler(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint tensors into this module's parameters.

        Returns the set of parameter names that were loaded.
        """
        # Checkpoint entries that are skipped rather than loaded. Models
        # trained using ColossalAI may include the cos/sin caches in the
        # checkpoint.
        skipped_markers = (
            "rotary_emb.inv_freq",
            "rotary_emb.cos_cached",
            "rotary_emb.sin_cached",
        )
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: Set[str] = set()

        for name, loaded_weight in weights:
            if any(marker in name for marker in skipped_markers):
                continue
            if is_pp_missing_parameter(name, self):
                continue
            param = params_dict[name]

            if "query_key_value" in name:
                # copy from vllm/model_executor/models/bloom.py
                # NOTE: Persimmon's fused QKV's output_dim has the shape of
                # (num_heads * 3 * head_size), while the
                # required shape is (3 * num_heads * head_size).
                # Thus, we need weight conversion.
                output_dim = getattr(param, "output_dim", None)
                num_heads = self.config.num_attention_heads
                if output_dim is not None:
                    original_shape = loaded_weight.shape
                    leading_dims = original_shape[:output_dim]
                    trailing_dims = original_shape[output_dim + 1:]
                    # Expose (num_heads, 3, head_size) on the output axis,
                    # swap the first two, then restore the original shape.
                    per_head_view = loaded_weight.view(leading_dims +
                                                       (num_heads, 3, -1) +
                                                       trailing_dims)
                    qkv_major = per_head_view.transpose(
                        output_dim, output_dim + 1)
                    loaded_weight = qkv_major.reshape(original_shape)

            # Parameters may carry their own loader; otherwise use the
            # default one.
            load_fn = getattr(param, "weight_loader", default_weight_loader)
            load_fn(param, loaded_weight)
            loaded_params.add(name)

        return loaded_params