phi.py 11.2 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
# coding=utf-8
# Adapted from
# https://huggingface.co/microsoft/phi-1_5/blob/main/modeling_phi.py
# Copyright 2023 The vLLM team.
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
#
# BSD 3-Clause License
#
# Copyright (c) 2022, Tri Dao, trid@cs.stanford.edu.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
#   list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
#   this list of conditions and the following disclaimer in the documentation
#   and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
#   contributors may be used to endorse or promote products derived from
#   this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
Woosuk Kwon's avatar
Woosuk Kwon committed
37
"""Inference-only Phi-1.5 model compatible with HuggingFace weights."""
38
from typing import Iterable, List, Optional, Tuple
39
40
41
42
43

import torch
from torch import nn
from transformers import PretrainedConfig

44
from vllm.attention import Attention, AttentionMetadata
45
from vllm.distributed import get_tensor_model_parallel_world_size
46
47
48
49
50
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               LinearMethodBase,
                                               QKVParallelLinear,
                                               RowParallelLinear)
51
from vllm.model_executor.layers.logits_processor import LogitsProcessor
52
from vllm.model_executor.layers.rotary_embedding import get_rope
53
54
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
55
    ParallelLMHead, VocabParallelEmbedding)
56
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
57
from vllm.model_executor.sampling_metadata import SamplingMetadata
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
from vllm.sequence import SamplerOutput


class PhiAttention(nn.Module):
    """Phi self-attention with a fused QKV projection and partial rotary embedding.

    Q, K and V are produced by a single tensor-parallel ``QKVParallelLinear``
    (with bias). A rotary embedding is built with
    ``rotary_dim = partial_rotary_factor * head_size``. The attention result
    is projected back to ``hidden_size`` by a row-parallel ``dense`` layer.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 linear_method: Optional[LinearMethodBase] = None):
        super().__init__()
        self.total_num_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.head_size = self.hidden_size // self.total_num_heads

        tensor_model_parallel_world_size = (
            get_tensor_model_parallel_world_size())
        # Heads are split evenly across tensor-parallel ranks, so
        # ``num_heads`` is the per-rank head count.
        assert self.total_num_heads % tensor_model_parallel_world_size == 0
        self.num_heads = (self.total_num_heads //
                          tensor_model_parallel_world_size)

        # pylint: disable=C0103
        self.qkv_proj = QKVParallelLinear(
            self.hidden_size,
            self.head_size,
            self.total_num_heads,
            bias=True,
            linear_method=linear_method,
        )
        self.dense = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            linear_method=linear_method,
        )

        scaling = self.head_size**-0.5
        rotary_dim = int(config.partial_rotary_factor *
                         (config.hidden_size // config.num_attention_heads))
        assert rotary_dim % 2 == 0

        # pylint: disable=C0301
        # Defaults (theta=10000) follow the original implementation:
        # https://huggingface.co/microsoft/phi-1_5/blob/d212a789620c380ff32ca1d1ee9943a777360987/modeling_phi.py#L518
        rope_theta, max_position_embeddings = self._rope_settings(config)
        self.rotary_emb = get_rope(
            self.head_size,
            rotary_dim=rotary_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
        )
        self.attn = Attention(self.num_heads, self.head_size, scaling)

    @staticmethod
    def _rope_settings(config: PretrainedConfig) -> Tuple[float, int]:
        """Return ``(rope_theta, max_position)`` for the rotary embedding.

        Prefers the HF-native ``rope_theta`` / ``max_position_embeddings``
        config fields. Falls back to the legacy ``n_positions`` field and
        finally to the historical defaults ``(10000, 2048)``. Previously only
        ``n_positions`` was consulted, so HF-native configs always got 2048.
        """
        rope_theta = getattr(config, "rope_theta", None)
        if rope_theta is None:
            rope_theta = 10000
        max_position = getattr(config, "max_position_embeddings", None)
        if max_position is None:
            max_position = getattr(config, "n_positions", 2048)
        return rope_theta, max_position

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run attention over ``hidden_states``.

        The fused projection is split into equal Q/K/V chunks along the last
        dimension. ``position_ids`` drives the rotary embedding on Q and K.
        ``kv_cache`` and ``attn_metadata`` are passed through unchanged to
        ``Attention``. Returns the ``dense`` projection of the attention
        output.
        """
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        q, k = self.rotary_emb(position_ids, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.dense(attn_output)
        return output


class PhiMLP(nn.Module):
    """Phi feed-forward block computing ``fc2(act(fc1(x)))``.

    The inner width is ``config.n_inner``; when that is missing or ``None``
    it defaults to ``4 * hidden_size``. ``fc1`` is column-parallel and
    ``fc2`` is row-parallel. The activation comes from ``get_act_fn``, which
    receives the linear method's ``quant_config`` when the method has one.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 linear_method: Optional[LinearMethodBase] = None):
        super().__init__()

        inner_dim = getattr(config, "n_inner", None)
        if inner_dim is None:
            inner_dim = 4 * config.hidden_size

        self.fc1 = ColumnParallelLinear(config.hidden_size,
                                        inner_dim,
                                        linear_method=linear_method)
        self.fc2 = RowParallelLinear(inner_dim,
                                     config.hidden_size,
                                     linear_method=linear_method)
        self.act = get_act_fn(config.hidden_act,
                              getattr(linear_method, "quant_config", None),
                              inner_dim)

    def forward(self, hidden_states):
        """Expand, activate, and project back to ``hidden_size``."""
        expanded, _ = self.fc1(hidden_states)
        activated = self.act(expanded)
        projected, _ = self.fc2(activated)
        return projected


class PhiLayer(nn.Module):
    """One Phi decoder block with parallel attention and MLP branches.

    A single LayerNorm output feeds both branches. The block returns
    ``attention + mlp + input``, i.e. the residual is the un-normalized input.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 linear_method: Optional[LinearMethodBase] = None):
        super().__init__()
        norm_eps = config.layer_norm_eps
        self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=norm_eps)
        self.self_attn = PhiAttention(config, linear_method)
        self.mlp = PhiMLP(config, linear_method)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Apply the shared norm, run both branches, and add the residual."""
        normed = self.input_layernorm(hidden_states)
        attn_branch = self.self_attn(
            position_ids=position_ids,
            hidden_states=normed,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        mlp_branch = self.mlp(normed)
        # Summation order (attention, MLP, residual) is kept deliberately.
        return attn_branch + mlp_branch + hidden_states


class PhiModel(nn.Module):
    """Phi decoder stack: embedding, ``num_hidden_layers`` PhiLayers, final LayerNorm."""

    def __init__(self,
                 config: PretrainedConfig,
                 linear_method: Optional[LinearMethodBase] = None):
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
                                                   config.hidden_size)
        decoder_layers = [
            PhiLayer(config, linear_method)
            for _ in range(config.num_hidden_layers)
        ]
        self.layers = nn.ModuleList(decoder_layers)
        self.final_layernorm = nn.LayerNorm(config.hidden_size,
                                            eps=config.layer_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and run every decoder layer in order.

        ``kv_caches`` is indexed by layer position, so each layer receives
        its own cache entry. Returns the final-LayerNorm output.
        """
        hidden = self.embed_tokens(input_ids)
        for layer_idx, decoder_layer in enumerate(self.layers):
            hidden = decoder_layer(positions, hidden, kv_caches[layer_idx],
                                   attn_metadata)
        return self.final_layernorm(hidden)


class PhiForCausalLM(nn.Module):
    """Phi causal LM: a ``PhiModel`` backbone plus a biased LM head.

    ``forward`` returns the backbone's final hidden states. Logit computation
    (``compute_logits``) and token sampling (``sample``) are separate steps.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 linear_method: Optional[LinearMethodBase] = None):
        super().__init__()
        self.config = config
        self.linear_method = linear_method

        self.model = PhiModel(config, linear_method)

        # Phi's output projection has a bias; it is handed to the logits
        # processor in ``compute_logits``.
        self.lm_head = ParallelLMHead(config.vocab_size,
                                      config.hidden_size,
                                      bias=True)
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Return the backbone's hidden states (no LM head applied)."""
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Project hidden states to vocabulary logits using the LM head weight and bias."""
        logits = self.logits_processor(self.lm_head.weight, hidden_states,
                                       sampling_metadata, self.lm_head.bias)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Delegate token selection to the ``Sampler``."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Copy checkpoint tensors into this module's parameters.

        Separate ``q_proj`` / ``k_proj`` / ``v_proj`` checkpoint tensors are
        loaded as shards of the fused ``qkv_proj`` parameter. Other tensors
        are loaded one-to-one. ``rotary_emb.inv_freq`` entries are ignored,
        and bias tensors with no matching parameter (extra biases in GPTQ
        checkpoints) are skipped. Any other unknown name raises ``KeyError``.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v")
        ]
        params_dict = dict(self.named_parameters())

        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue

            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                # Keep ``name`` intact: rewriting it in place would let later
                # mapping entries match the rewritten text ("qkv_proj"
                # contains "v_proj").
                stacked_name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models. ``break`` (rather
                # than ``continue``) ends the search without reaching ``else``.
                if (stacked_name.endswith(".bias")
                        and stacked_name not in params_dict):
                    break
                param = params_dict[stacked_name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Not a stacked QKV shard: load it one-to-one.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # pylint: disable=E1136
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)