orion.py 13.7 KB
Newer Older
张大成's avatar
张大成 committed
1
2
3
4
5
6
# coding=utf-8
# Adapted from
# https://huggingface.co/OrionStarAI/Orion-14B-Base/blob/main/modeling_orion.py
# Copyright (c) OrionStar Inc.
# LICENSE: https://huggingface.co/OrionStarAI/Orion-14B-Base/blob/main/LICENSE
"""Inference-only Orion-14B model compatible with HuggingFace weights."""
7
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
张大成's avatar
张大成 committed
8
9
10
11
12

import torch
from torch import nn
from transformers import PretrainedConfig

13
from vllm.attention import Attention, AttentionMetadata
14
from vllm.config import CacheConfig
15
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
张大成's avatar
张大成 committed
16
from vllm.model_executor.layers.activation import SiluAndMul
17
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
张大成's avatar
张大成 committed
18
19
                                               QKVParallelLinear,
                                               RowParallelLinear)
20
from vllm.model_executor.layers.logits_processor import LogitsProcessor
21
from vllm.model_executor.layers.quantization import QuantizationConfig
22
from vllm.model_executor.layers.rotary_embedding import get_rope
23
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
张大成's avatar
张大成 committed
24
from vllm.model_executor.layers.vocab_parallel_embedding import (
25
    ParallelLMHead, VocabParallelEmbedding)
26
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
张大成's avatar
张大成 committed
27
from vllm.model_executor.sampling_metadata import SamplingMetadata
28
from vllm.sequence import IntermediateTensors
张大成's avatar
张大成 committed
29

30
31
32
33
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
                    make_empty_intermediate_tensors_factory, make_layers)

张大成's avatar
张大成 committed
34
35
36
37
38
39
40
41

class OrionMLP(nn.Module):
    """Feed-forward block of an Orion decoder layer.

    The input goes through a merged gate/up projection, then ``SiluAndMul``,
    then a down projection back to ``hidden_size``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        """Build the projections and the activation.

        Args:
            hidden_size: Width of the block's input and output.
            intermediate_size: Width of each of the gate and up projections.
            hidden_act: Activation name; only ``"silu"`` is accepted.
            quant_config: Optional quantization config passed to both
                linear layers.

        Raises:
            ValueError: If ``hidden_act`` is anything other than ``"silu"``.
        """
        super().__init__()
        # Gate and up projections are fused into a single linear layer with
        # two equally sized output partitions.
        merged_output_sizes = [intermediate_size, intermediate_size]
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            merged_output_sizes,
            bias=False,
            quant_config=quant_config,
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
        )
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()

    def forward(self, x):
        """Apply gate/up projection, activation and down projection to ``x``."""
        # The linear layers return a tuple; only the first element is used.
        merged, _ = self.gate_up_proj(x)
        activated = self.act_fn(merged)
        projected, _ = self.down_proj(activated)
        return projected


class OrionAttention(nn.Module):
    """Self-attention block of an Orion decoder layer.

    Computes a fused QKV projection, applies rotary embeddings to the query
    and key, runs the attention backend, and projects the result back to
    ``hidden_size``.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        """Derive per-rank head counts and build the sub-layers.

        Args:
            hidden_size: Model hidden width.
            num_heads: Total number of query heads (across all TP ranks).
            num_kv_heads: Total number of key/value heads.
            rope_theta: Base passed to ``get_rope``.
            rope_scaling: Optional rope-scaling dict passed to ``get_rope``.
            max_position_embeddings: Maximum position passed to ``get_rope``.
            cache_config: Optional cache config passed to ``Attention``.
            quant_config: Optional quantization config passed to the linear
                layers and to ``Attention``.
        """
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()

        # Query heads are split evenly across tensor-parallel ranks.
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size

        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads < tp_size:
            # Fewer KV heads than ranks: the KV heads are replicated across
            # multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        else:
            # At least as many KV heads as ranks: the KV heads are
            # partitioned across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)

        self.head_dim = hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run attention over ``hidden_states`` and return the projected output."""
        qkv, _ = self.qkv_proj(hidden_states)
        # The fused projection is laid out as [q | k | v] along the last dim.
        split_sizes = [self.q_size, self.kv_size, self.kv_size]
        q, k, v = qkv.split(split_sizes, dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.o_proj(attn_output)
        return output


class OrionDecoderLayer(nn.Module):
    """One Orion transformer layer: pre-norm attention and pre-norm MLP,
    each wrapped in a residual connection."""

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        """Build the attention, MLP and the two LayerNorms from ``config``.

        ``rope_theta``, ``rope_scaling`` and ``max_position_embeddings`` are
        optional on the config and fall back to 10000, ``None`` and 8192.
        """
        super().__init__()
        self.hidden_size = config.hidden_size

        self.self_attn = OrionAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=getattr(config, "rope_theta", 10000),
            rope_scaling=getattr(config, "rope_scaling", None),
            max_position_embeddings=getattr(config,
                                            "max_position_embeddings", 8192),
            cache_config=cache_config,
            quant_config=quant_config,
        )
        self.mlp = OrionMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
        )

        # Both norms are plain LayerNorm; the epsilon is read from the
        # config field named ``rms_norm_eps``.
        norm_eps = config.rms_norm_eps
        self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=norm_eps)
        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
                                                     eps=norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply the layer to ``hidden_states``.

        The incoming ``residual`` argument is not read: the residual is
        always taken from ``hidden_states``. The second element of the
        returned tuple is always ``None``.
        """
        # Self Attention
        attn_input = self.input_layernorm(hidden_states)
        attn_out = self.self_attn(
            positions=positions,
            hidden_states=attn_input,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        hidden_states = hidden_states + attn_out

        # Fully Connected
        mlp_out = self.mlp(self.post_attention_layernorm(hidden_states))
        hidden_states = hidden_states + mlp_out
        return hidden_states, None


class OrionModel(nn.Module):
    """Orion decoder stack: token embedding, decoder layers, final LayerNorm.

    Supports pipeline parallelism: ``make_layers`` yields the
    ``[start_layer, end_layer)`` range this rank runs, and hidden states are
    handed to the next stage through ``IntermediateTensors``.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build the embedding, the decoder layers and the final norm.

        Args:
            config: HuggingFace model config.
            cache_config: Optional cache config forwarded to every layer.
            quant_config: Optional quantization config forwarded to every
                layer.
            prefix: Name prefix used when building the layer list.
        """
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            # The per-layer prefix is accepted but unused: OrionDecoderLayer
            # takes no prefix argument.
            lambda prefix: OrionDecoderLayer(
                config,
                cache_config,
                quant_config,
            ),
            prefix=f"{prefix}.layers")
        self.norm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps)
        # Only hidden states cross pipeline stages. The decoder layers always
        # return ``None`` as their residual, so there is no residual tensor
        # to allocate or transfer.
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(["hidden_states"],
                                                    config.hidden_size))

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this rank's slice of the decoder stack.

        Args:
            input_ids: Token ids; only read on the first pipeline rank.
            positions: Position ids forwarded to every layer.
            kv_caches: Per-layer KV caches, indexed relative to
                ``start_layer``.
            attn_metadata: Attention metadata forwarded to every layer.
            intermediate_tensors: Output of the previous pipeline stage;
                required on every rank except the first.

        Returns:
            The normalized hidden states on the last pipeline rank, otherwise
            an ``IntermediateTensors`` carrying ``"hidden_states"`` for the
            next stage.
        """
        if get_pp_group().is_first_rank:
            hidden_states = self.embed_tokens(input_ids)
        else:
            # Explicit None check rather than relying on truthiness.
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]

        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            # The layer ignores its residual argument and returns None for
            # it, so nothing is threaded through here.
            hidden_states, _ = layer(
                positions,
                hidden_states,
                kv_caches[i - self.start_layer],
                attn_metadata,
                None,
            )

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
            })
        hidden_states = self.norm(hidden_states)
        return hidden_states


272
class OrionForCausalLM(nn.Module, SupportsPP):
    """Orion causal language model: decoder stack, LM head, logits
    processor and sampler, plus HuggingFace checkpoint loading."""

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        """Build the model and its output head.

        When ``config.tie_word_embeddings`` is set, the LM head shares its
        weight with the input embedding.
        """
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.model = OrionModel(config, cache_config, quant_config)
        self.lm_head = ParallelLMHead(config.vocab_size,
                                      config.hidden_size,
                                      quant_config=quant_config)
        if self.config.tie_word_embeddings:
            self.lm_head.weight = self.model.embed_tokens.weight
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Delegate to the decoder stack and return its output unchanged."""
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata, intermediate_tensors)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Compute logits from ``hidden_states`` using the LM head."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits``."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint tensors into this model's parameters.

        Separate ``q_proj``/``k_proj``/``v_proj`` and ``gate_proj``/``up_proj``
        checkpoint tensors are routed into the fused ``qkv_proj`` and
        ``gate_up_proj`` parameters with the matching shard id. Rotary
        caches, unknown ``.bias`` tensors and parameters owned by another
        pipeline stage are skipped.

        Args:
            weights: Iterable of ``(checkpoint_name, tensor)`` pairs.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if ("rotary_emb.cos_cached" in name
                    or "rotary_emb.sin_cached" in name):
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue

            # Resolve the fused-parameter mapping exactly once, against the
            # original checkpoint name. Rewriting the name and then trying
            # further mappings would re-match the rewritten name, because
            # "v_proj" is a substring of "qkv_proj" and "up_proj" of
            # "gate_up_proj".
            is_stacked = False
            shard_id = None
            for param_name, weight_name, mapped_shard in stacked_params_mapping:
                if weight_name in name:
                    name = name.replace(weight_name, param_name)
                    shard_id = mapped_shard
                    is_stacked = True
                    break

            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            if is_pp_missing_parameter(name, self):
                continue

            param = params_dict[name]
            if is_stacked:
                # Fused parameters must provide a shard-aware loader.
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
            else:
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)