yi.py 12.1 KB
Newer Older
Roy's avatar
Roy committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Yi model (https://01.ai) compatible with HuggingFace weights.

The input of the model is flattened to a 1D tensor of tokens. The model uses
InputMetadata to extract the original 2D shape of the input.
"""
from typing import Any, Dict, List, Optional, Tuple

import torch
from torch import nn
from vllm.transformers_utils.configs.yi import YiConfig

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
37
38
39
40
41
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
                                               MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
Roy's avatar
Roy committed
42
from vllm.model_executor.layers.sampler import Sampler
43
44
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding, ParallelLMHead)
Roy's avatar
Roy committed
45
from vllm.model_executor.parallel_utils.parallel_state import (
46
47
48
    get_tensor_model_parallel_world_size)
from vllm.model_executor.weight_utils import (default_weight_loader,
                                              hf_model_weights_iterator)
Roy's avatar
Roy committed
49
50
51
52
53
54
55
56
57
58
59
60
from vllm.sequence import SamplerOutput

# KV cache for a single decoder layer: a (key cache, value cache) pair of
# tensors. The model receives one entry per layer as ``List[KVCache]``.
KVCache = Tuple[torch.Tensor, torch.Tensor]


class YiMLP(nn.Module):
    """Gated feed-forward block used inside each Yi decoder layer.

    A fused gate/up projection feeds ``SiluAndMul``; the result is projected
    back to ``hidden_size`` by a row-parallel ``down_proj``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        """Build the MLP projections.

        Args:
            hidden_size: Model hidden width (input and output of the block).
            intermediate_size: Output width of each of the gate and up
                projections.
            hidden_act: Activation name; only ``"silu"`` is supported.
            linear_method: Forwarded unchanged to the parallel linear layers.

        Raises:
            ValueError: If ``hidden_act`` is not ``"silu"``.
        """
        super().__init__()
        # Validate the config before allocating any projection weights so an
        # unsupported activation fails fast.
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        # gate_proj and up_proj are fused into one column-parallel layer with
        # two output shards of ``intermediate_size`` each.
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size, [intermediate_size] * 2,
            bias=False,
            linear_method=linear_method)
        self.down_proj = RowParallelLinear(intermediate_size,
                                           hidden_size,
                                           bias=False,
                                           linear_method=linear_method)
        # Consumes the fused gate/up output (presumably SiLU(gate) * up).
        self.act_fn = SiluAndMul()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply gate/up projection, activation, and down projection to ``x``."""
        # The parallel linear layers return a pair; the second element is
        # discarded here (presumably an unapplied bias, unused since bias=False).
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


class YiAttention(nn.Module):
    """Self-attention for Yi, sharded across tensor-parallel ranks.

    Q, K and V come from a single fused ``QKVParallelLinear``. The attention
    computation is delegated to ``PagedAttentionWithRoPE``, and the result is
    projected back to ``hidden_size`` by a row-parallel ``o_proj``. The number
    of KV heads may be smaller than the number of query heads.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        """Set up the per-rank head layout and the projection layers.

        Args:
            hidden_size: Model hidden width; ``head_dim`` is
                ``hidden_size // num_heads``.
            num_heads: Total number of query heads across all TP ranks.
            num_kv_heads: Total number of key/value heads across all TP ranks.
            rope_theta: Passed to the attention layer as ``base``.
            rope_scaling: Passed through to the attention layer unchanged.
            max_position_embeddings: Passed to the attention layer as
                ``max_position``.
            linear_method: Forwarded unchanged to the parallel linear layers.
        """
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        # At least one KV head per rank (the replicated case above).
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = hidden_size // self.total_num_heads
        # Per-rank widths of the Q and K/V slices of the fused QKV output.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            linear_method=linear_method,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            linear_method=linear_method,
        )
        self.attn = PagedAttentionWithRoPE(
            self.num_heads,
            self.head_dim,
            self.scaling,
            base=self.rope_theta,
            max_position=self.max_position_embeddings,
            rotary_dim=self.head_dim,
            num_kv_heads=self.num_kv_heads,
            rope_scaling=rope_scaling)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        """Run attention for this rank's heads and return the ``o_proj`` output.

        Args:
            positions: Token positions, passed to the attention layer.
            hidden_states: Layer input fed to the fused QKV projection.
            kv_cache: This layer's (key cache, value cache) pair.
            input_metadata: Batch metadata forwarded to the attention layer.
            cache_event: Optional CUDA event forwarded to the attention layer.
        """
        qkv, _ = self.qkv_proj(hidden_states)
        # The fused output is laid out as [Q | K | V] along the last dim.
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        k_cache, v_cache = kv_cache
        attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
                                input_metadata, cache_event)
        output, _ = self.o_proj(attn_output)
        return output


class YiDecoderLayer(nn.Module):
    """One Yi transformer block: ``ln1`` -> self-attention -> ``ln2`` -> MLP.

    The residual stream is carried separately from the hidden states and
    threaded through the ``RMSNorm`` calls rather than added explicitly here.
    """

    def __init__(
        self,
        config: YiConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        """Build the attention, MLP and norm sub-layers from ``config``.

        ``rope_theta``, ``rope_scaling`` and ``max_position_embeddings`` are
        optional on the config and default to 10000, None and 8192.
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        self.self_attn = YiAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            linear_method=linear_method,
        )
        self.mlp = YiMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            linear_method=linear_method,
        )
        # Yi names its pre-attention / pre-MLP norms ln1 / ln2.
        self.ln1 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.ln2 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the block and return ``(mlp_output, residual)``.

        ``residual`` is None on the first call, in which case the input itself
        becomes the residual. Otherwise the two-argument ``RMSNorm`` form is
        used, which returns ``(normalized, updated_residual)``. The MLP output
        is returned without being added to the residual; that is presumably
        done by the next ``RMSNorm`` call (fused add + norm).
        """
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.ln1(hidden_states)
        else:
            hidden_states, residual = self.ln1(hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
            cache_event=cache_event,
        )

        # Fully Connected
        hidden_states, residual = self.ln2(hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual


class YiModel(nn.Module):
    """Yi decoder stack: token embedding, ``num_hidden_layers`` decoder
    layers, and a final ``RMSNorm``."""

    def __init__(
        self,
        config: YiConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        """Build the embedding, decoder layers and final norm from ``config``."""
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        self.layers = nn.ModuleList([
            YiDecoderLayer(config, linear_method)
            for _ in range(config.num_hidden_layers)
        ])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> torch.Tensor:
        """Embed ``input_ids``, run every decoder layer, and apply the final norm.

        Args:
            input_ids: Token ids to embed.
            positions: Token positions, forwarded to each layer.
            kv_caches: One (key cache, value cache) pair per layer, indexed
                in layer order.
            input_metadata: Batch metadata forwarded to each layer.
            cache_events: Optional per-layer CUDA events; when None every
                layer receives ``cache_event=None``.

        Returns:
            The final normalized hidden states.
        """
        hidden_states = self.embed_tokens(input_ids)
        # The residual stream starts empty; the first layer seeds it.
        residual = None
        for i, layer in enumerate(self.layers):
            cache_event = None if cache_events is None else cache_events[i]
            hidden_states, residual = layer(
                positions,
                hidden_states,
                kv_caches[i],
                input_metadata,
                cache_event,
                residual,
            )
        # The last layer's MLP output has not yet been combined with the
        # residual; the two-argument norm receives both.
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


class YiForCausalLM(nn.Module):
    """Yi causal LM: ``YiModel`` followed by a parallel LM head and sampler."""

    def __init__(
        self,
        config: YiConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        """Build the decoder stack, LM head and sampler from ``config``.

        ``linear_method`` is stored and forwarded to ``YiModel``; the LM head
        is built without it.
        """
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        self.model = YiModel(config, linear_method)
        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
        self.sampler = Sampler(config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> SamplerOutput:
        """Run the decoder and sample next tokens using the LM head weight."""
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   input_metadata, cache_events)
        next_tokens = self.sampler(self.lm_head.weight, hidden_states,
                                   input_metadata)
        return next_tokens

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     load_format: str = "auto",
                     revision: Optional[str] = None):
        """Load checkpoint weights into this model's parameters.

        Separate ``q_proj``/``k_proj``/``v_proj`` and ``gate_proj``/``up_proj``
        checkpoint tensors are written as shards of the fused ``qkv_proj`` and
        ``gate_up_proj`` parameters. All other tensors are loaded by name, via
        the parameter's own ``weight_loader`` if it has one, otherwise
        ``default_weight_loader``.

        Args:
            model_name_or_path: Passed to ``hf_model_weights_iterator``.
            cache_dir: Passed to ``hf_model_weights_iterator``.
            load_format: Passed to ``hf_model_weights_iterator``.
            revision: Passed to ``hf_model_weights_iterator``.

        Raises:
            KeyError: If a checkpoint tensor has no matching parameter.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, load_format, revision):
            # Checkpoint rotary inverse-frequency buffers are not parameters
            # of this model; skip them.
            if "rotary_emb.inv_freq" in name:
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                # Route this unfused checkpoint tensor into its shard of the
                # fused parameter.
                param = params_dict[name.replace(weight_name, param_name)]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Not a stacked tensor: load it into the same-named parameter.
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)