orion.py 12.3 KB
Newer Older
张大成's avatar
张大成 committed
1
2
3
4
5
6
# coding=utf-8
# Adapted from
# https://huggingface.co/OrionStarAI/Orion-14B-Base/blob/main/modeling_orion.py
# Copyright (c) OrionStar Inc.
# LICENSE: https://huggingface.co/OrionStarAI/Orion-14B-Base/blob/main/LICENSE
"""Inference-only Orion-14B model compatible with HuggingFace weights."""
7
from typing import Any, Dict, Iterable, List, Optional, Tuple
张大成's avatar
张大成 committed
8
9
10
11
12

import torch
from torch import nn
from transformers import PretrainedConfig

13
from vllm.attention import Attention, AttentionMetadata
14
from vllm.config import CacheConfig
15
from vllm.distributed import get_tensor_model_parallel_world_size
张大成's avatar
张大成 committed
16
from vllm.model_executor.layers.activation import SiluAndMul
17
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
张大成's avatar
张大成 committed
18
19
                                               QKVParallelLinear,
                                               RowParallelLinear)
20
from vllm.model_executor.layers.logits_processor import LogitsProcessor
21
22
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
23
from vllm.model_executor.layers.rotary_embedding import get_rope
张大成's avatar
张大成 committed
24
25
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
26
    ParallelLMHead, VocabParallelEmbedding)
27
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
张大成's avatar
张大成 committed
28
from vllm.model_executor.sampling_metadata import SamplingMetadata
29
from vllm.sequence import IntermediateTensors, SamplerOutput
张大成's avatar
张大成 committed
30
31
32
33
34
35
36
37
38


class OrionMLP(nn.Module):
    """Feed-forward block of an Orion decoder layer.

    The gate and up projections are fused into one column-parallel linear
    layer whose output is fed to ``SiluAndMul``; a row-parallel ``down_proj``
    then maps the activated result back to ``hidden_size``.

    Raises:
        ValueError: if ``hidden_act`` is anything other than ``"silu"``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        # One output partition of width `intermediate_size` for the gate
        # projection and one for the up projection.
        fused_widths = [intermediate_size, intermediate_size]
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            fused_widths,
            bias=False,
            quant_config=quant_config,
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
        )
        if hidden_act == "silu":
            self.act_fn = SiluAndMul()
        else:
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")

    def forward(self, x):
        """Project up (gate and up fused), activate, then project down."""
        # The parallel linear layers return a pair; only the first element
        # (the projected tensor) is used here.
        fused, _unused = self.gate_up_proj(x)
        activated = self.act_fn(fused)
        out, _unused = self.down_proj(activated)
        return out


class OrionAttention(nn.Module):
    """Self-attention with rotary position embeddings, sharded across the
    tensor-parallel group.

    Query heads are divided evenly among the tensor-parallel ranks.  KV heads
    are divided as well when there are at least as many of them as ranks;
    otherwise they are replicated so that every rank keeps one KV head.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        world = get_tensor_model_parallel_world_size()

        self.hidden_size = hidden_size
        self.total_num_heads = num_heads
        self.total_num_kv_heads = num_kv_heads
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        # Every rank must own the same number of query heads.
        assert num_heads % world == 0
        if num_kv_heads < world:
            # Fewer KV heads than ranks: KV heads are replicated, so the
            # rank count has to be a multiple of the KV head count.
            assert world % num_kv_heads == 0
        else:
            # At least one KV head per rank: partition them evenly.
            assert num_kv_heads % world == 0

        # Per-rank head counts; a rank never holds fewer than one KV head.
        self.num_heads = num_heads // world
        self.num_kv_heads = max(1, num_kv_heads // world)
        self.head_dim = hidden_size // num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            num_heads,
            num_kv_heads,
            bias=False,
            quant_config=quant_config,
        )
        self.o_proj = RowParallelLinear(
            num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run fused QKV projection, RoPE, attention and output projection."""
        fused_qkv, _unused = self.qkv_proj(hidden_states)
        # Last dim holds this rank's query, key and value slices in order.
        split_sizes = [self.q_size, self.kv_size, self.kv_size]
        query, key, value = fused_qkv.split(split_sizes, dim=-1)
        query, key = self.rotary_emb(positions, query, key)
        context = self.attn(query, key, value, kv_cache, attn_metadata)
        projected, _unused = self.o_proj(context)
        return projected


class OrionDecoderLayer(nn.Module):
    """One Orion transformer block.

    Pre-norm layout: LayerNorm -> self-attention -> residual add, followed by
    LayerNorm -> MLP -> residual add.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = OrionAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=getattr(config, "rope_theta", 10000),
            rope_scaling=getattr(config, "rope_scaling", None),
            max_position_embeddings=getattr(config,
                                            "max_position_embeddings",
                                            8192),
            cache_config=cache_config,
            quant_config=quant_config,
        )
        self.mlp = OrionMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
        )

        # Orion uses standard LayerNorm even though the config names its
        # epsilon `rms_norm_eps`.
        norm_eps = config.rms_norm_eps
        self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=norm_eps)
        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
                                                     eps=norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply the block and return ``(hidden_states, None)``.

        The incoming ``residual`` argument is ignored: each sub-block keeps
        its own skip connection, and ``None`` is returned in the residual
        slot of the result tuple.
        """
        shortcut = hidden_states
        attn_out = self.self_attn(
            positions=positions,
            hidden_states=self.input_layernorm(hidden_states),
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        hidden_states = shortcut + attn_out

        shortcut = hidden_states
        mlp_out = self.mlp(self.post_attention_layernorm(hidden_states))
        return shortcut + mlp_out, None


class OrionModel(nn.Module):
    """Orion decoder stack.

    Components: token embedding, ``config.num_hidden_layers`` decoder
    layers, and a final LayerNorm.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
                                                   config.hidden_size)
        decoder_layers = [
            OrionDecoderLayer(config, cache_config, quant_config)
            for _ in range(config.num_hidden_layers)
        ]
        self.layers = nn.ModuleList(decoder_layers)
        self.norm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Embed ``input_ids``, run every decoder layer, then normalize.

        ``kv_caches`` is indexed per layer, so it must hold at least one
        entry per decoder layer.
        """
        hidden_states = self.embed_tokens(input_ids)
        residual = None
        for layer_idx, layer in enumerate(self.layers):
            hidden_states, residual = layer(
                positions,
                hidden_states,
                kv_caches[layer_idx],
                attn_metadata,
                residual,
            )
        return self.norm(hidden_states)


class OrionForCausalLM(nn.Module):
    """Orion causal language model for inference.

    Wraps ``OrionModel`` with a parallel LM head, a logits processor and a
    sampler, and knows how to load HuggingFace-format checkpoints.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.model = OrionModel(config, cache_config, quant_config)
        self.lm_head = ParallelLMHead(config.vocab_size,
                                      config.hidden_size,
                                      quant_config=quant_config)
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> torch.Tensor:
        """Return the hidden states produced by ``self.model``.

        ``intermediate_tensors`` is accepted for interface compatibility but
        is not used by this model.  Logits are computed separately by
        ``compute_logits``.
        """
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to vocabulary logits through ``lm_head``."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Draw next tokens from ``logits`` using ``self.sampler``."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Copy checkpoint tensors into this model's parameters.

        Checkpoints store the q/k/v and gate/up projections as separate
        tensors, while this model fuses them into ``qkv_proj`` and
        ``gate_up_proj``.  Such tensors are passed to the fused parameter's
        ``weight_loader`` together with their shard id.  Rotary-embedding
        buffers, and ``.bias`` tensors that have no matching parameter
        (extra GPTQ biases), are skipped.

        Args:
            weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

        Raises:
            KeyError: if a tensor that is not skipped has no matching
                parameter (after fused-name mapping).
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if ("rotary_emb.cos_cached" in name
                    or "rotary_emb.sin_cached" in name):
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                # Map into a separate variable instead of rebinding `name`.
                # The fused names contain shard names ("qkv_proj" contains
                # "v_proj", "gate_up_proj" contains "up_proj"), so a
                # rewritten name could be matched and mangled again by a
                # later entry.
                fused_name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.  `break` drops the
                # tensor outright rather than retrying the remaining entries.
                if (fused_name.endswith(".bias")
                        and fused_name not in params_dict):
                    break
                param = params_dict[fused_name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Not a stacked shard: load into the same-named parameter.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)