phi.py 11.2 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
# coding=utf-8
# Adapted from
# https://huggingface.co/microsoft/phi-1_5/blob/main/modeling_phi.py
# Copyright 2023 The vLLM team.
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
#
# BSD 3-Clause License
#
# Copyright (c) 2022, Tri Dao, trid@cs.stanford.edu.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
#   list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
#   this list of conditions and the following disclaimer in the documentation
#   and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
#   contributors may be used to endorse or promote products derived from
#   this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
Woosuk Kwon's avatar
Woosuk Kwon committed
37
"""Inference-only Phi-1.5 model compatible with HuggingFace weights."""
38
from typing import Iterable, List, Optional, Tuple
39
40
41
42
43

import torch
from torch import nn
from transformers import PretrainedConfig

44
from vllm.attention import Attention, AttentionMetadata
45
from vllm.distributed import get_tensor_model_parallel_world_size
46
47
48
49
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
50
from vllm.model_executor.layers.logits_processor import LogitsProcessor
51
52
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
53
from vllm.model_executor.layers.rotary_embedding import get_rope
54
55
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
56
    ParallelLMHead, VocabParallelEmbedding)
57
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
58
from vllm.model_executor.sampling_metadata import SamplingMetadata
59
60
61
62
63
64
65
from vllm.sequence import SamplerOutput


class PhiAttention(nn.Module):
    """Multi-head self-attention used by each Phi decoder layer.

    A fused, tensor-parallel QKV projection (``qkv_proj``) produces queries,
    keys and values. Rotary position embeddings are applied to q/k, and the
    attention output is projected back to ``hidden_size`` by the row-parallel
    ``dense`` layer. Each tensor-parallel rank owns
    ``num_attention_heads // tp_size`` heads.

    Rotary embeddings are built with
    ``rotary_dim = int(partial_rotary_factor * head_size)``, which may be
    smaller than ``head_size``.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 quant_config: Optional[QuantizationConfig] = None):
        super().__init__()
        self.total_num_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.head_size = self.hidden_size // self.total_num_heads

        tensor_model_parallel_world_size = (
            get_tensor_model_parallel_world_size())
        assert self.total_num_heads % tensor_model_parallel_world_size == 0, (
            f"num_attention_heads ({self.total_num_heads}) must be divisible "
            f"by the tensor parallel size "
            f"({tensor_model_parallel_world_size})")
        self.num_heads = (self.total_num_heads //
                          tensor_model_parallel_world_size)

        # pylint: disable=C0103
        self.qkv_proj = QKVParallelLinear(
            self.hidden_size,
            self.head_size,
            self.total_num_heads,
            bias=True,
            quant_config=quant_config,
        )
        self.dense = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            quant_config=quant_config,
        )

        scaling = self.head_size**-0.5
        # self.head_size is hidden_size // num_attention_heads (see above).
        rotary_dim = int(config.partial_rotary_factor * self.head_size)
        assert rotary_dim % 2 == 0, (
            f"rotary_dim ({rotary_dim}) must be even")

        # pylint: disable=C0301
        # Refer to:
        # https://huggingface.co/microsoft/phi-1_5/blob/d212a789620c380ff32ca1d1ee9943a777360987/modeling_phi.py#L518
        # Honor a config-provided rope_theta; otherwise fall back to 10000,
        # the value previously hard-coded here.
        rope_theta = getattr(config, "rope_theta", 10000)
        max_position_embeddings = getattr(config, "n_positions", 2048)
        self.rotary_emb = get_rope(
            self.head_size,
            rotary_dim=rotary_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
        )
        self.attn = Attention(self.num_heads, self.head_size, scaling)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Compute self-attention over ``hidden_states``.

        Args:
            position_ids: Token positions passed to the rotary embedding.
            hidden_states: Input activations for the attention block.
            kv_cache: This layer's KV cache, forwarded to ``self.attn``.
            attn_metadata: Attention metadata, forwarded to ``self.attn``.

        Returns:
            The output of the ``dense`` projection.
        """
        qkv, _ = self.qkv_proj(hidden_states)
        # The fused projection output is split into three equal chunks
        # along the last dimension: q, k, v.
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        q, k = self.rotary_emb(position_ids, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.dense(attn_output)
        return output


class PhiMLP(nn.Module):
    """Feed-forward block: ``fc1`` -> activation -> ``fc2``.

    ``fc1`` is a column-parallel projection from ``hidden_size`` to the inner
    width, and ``fc2`` is a row-parallel projection back to ``hidden_size``.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 quant_config: Optional[QuantizationConfig] = None):
        super().__init__()

        # Inner width: prefer the legacy ``n_inner`` field, then the
        # ``intermediate_size`` field, and finally default to 4 * hidden_size.
        n_inner = getattr(config, "n_inner", None)
        if n_inner is None:
            n_inner = getattr(config, "intermediate_size", None)
        if n_inner is None:
            n_inner = 4 * config.hidden_size

        self.fc1 = ColumnParallelLinear(
            config.hidden_size,
            n_inner,
            quant_config=quant_config,
        )
        self.fc2 = RowParallelLinear(
            n_inner,
            config.hidden_size,
            quant_config=quant_config,
        )
        # NOTE(review): this reads a nested ``quant_config`` attribute from the
        # quantization config and yields None when it is absent, so the
        # activation is typically built without a quant config. Kept as-is to
        # preserve behavior (a separate local avoids shadowing the parameter);
        # confirm the intent.
        act_quant_config = getattr(quant_config, "quant_config", None)
        self.act = get_act_fn(config.hidden_act, act_quant_config, n_inner)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply ``fc1``, the activation, then ``fc2``; return fc2's output."""
        hidden_states, _ = self.fc1(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.fc2(hidden_states)
        return hidden_states


class PhiLayer(nn.Module):
    """A single Phi decoder layer.

    One LayerNorm feeds both the attention branch and the MLP branch. Their
    outputs are added together with the layer's un-normalized input, which
    serves as the residual.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 quant_config: Optional[QuantizationConfig] = None):
        super().__init__()
        self.input_layernorm = nn.LayerNorm(config.hidden_size,
                                            eps=config.layer_norm_eps)
        self.self_attn = PhiAttention(config, quant_config)
        self.mlp = PhiMLP(config, quant_config)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run one decoder layer.

        Args:
            position_ids: Token positions, forwarded to the attention block.
            hidden_states: Layer input; also used as the residual.
            kv_cache: This layer's KV cache.
            attn_metadata: Attention metadata for the current batch.

        Returns:
            ``attention(ln(x)) + mlp(ln(x)) + x``.
        """
        residual = hidden_states
        normed_states = self.input_layernorm(hidden_states)
        attn_outputs = self.self_attn(
            position_ids=position_ids,
            hidden_states=normed_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        # Both branches consume the same normalized input.
        feed_forward_hidden_states = self.mlp(normed_states)
        hidden_states = attn_outputs + feed_forward_hidden_states + residual
        return hidden_states


class PhiModel(nn.Module):
    """Phi decoder stack.

    It applies a token embedding, ``num_hidden_layers`` PhiLayers, and a final
    LayerNorm.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 quant_config: Optional[QuantizationConfig] = None):
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
                                                   config.hidden_size)
        self.layers = nn.ModuleList([
            PhiLayer(config, quant_config)
            for _ in range(config.num_hidden_layers)
        ])
        self.final_layernorm = nn.LayerNorm(config.hidden_size,
                                            eps=config.layer_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Embed ``input_ids``, run every layer, and apply the final norm.

        Args:
            input_ids: Token ids to embed.
            positions: Token positions, passed to every layer.
            kv_caches: Per-layer KV caches. Layer ``i`` receives
                ``kv_caches[i]``, so the list needs at least one entry per
                layer.
            attn_metadata: Attention metadata shared by all layers.

        Returns:
            Hidden states after the final LayerNorm.
        """
        hidden_states = self.embed_tokens(input_ids)
        for layer_idx, layer in enumerate(self.layers):
            hidden_states = layer(
                positions,
                hidden_states,
                kv_caches[layer_idx],
                attn_metadata,
            )

        hidden_states = self.final_layernorm(hidden_states)
        return hidden_states


class PhiForCausalLM(nn.Module):
    """Phi causal language model.

    It combines the decoder stack (``PhiModel``), a biased LM head, a logits
    processor, and a sampler.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 quant_config: Optional[QuantizationConfig] = None):
        super().__init__()
        self.config = config
        self.quant_config = quant_config

        self.model = PhiModel(config, quant_config)

        self.lm_head = ParallelLMHead(config.vocab_size,
                                      config.hidden_size,
                                      bias=True)
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Return the decoder's final hidden states.

        Logits are produced separately by ``compute_logits``.
        """
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Project hidden states to logits.

        The projection uses the LM head's weight and bias.
        """
        logits = self.logits_processor(self.lm_head.weight, hidden_states,
                                       sampling_metadata, self.lm_head.bias)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Run the sampler on ``logits`` and return its result."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint tensors into this module's parameters.

        Separate ``q_proj``/``k_proj``/``v_proj`` tensors are written into the
        matching shard of the fused ``qkv_proj`` parameter through its
        ``weight_loader``. Every other tensor is loaded by name using the
        parameter's ``weight_loader``, or ``default_weight_loader`` when the
        parameter has none.

        Args:
            weights: Iterable of ``(name, tensor)`` pairs from the checkpoint.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v")
        ]
        params_dict = dict(self.named_parameters())

        for name, loaded_weight in weights:
            # Cached rotary inverse-frequency buffers are not loaded.
            if "rotary_emb.inv_freq" in name:
                continue

            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models. Use `break` (not
                # `continue`) so the renamed name is not re-matched by a
                # later mapping ("v_proj" is a substring of "qkv_proj"); the
                # for-else below is skipped as well.
                if name.endswith(".bias") and name not in params_dict:
                    break
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # pylint: disable=E1136
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)