mpt.py 10.8 KB
Newer Older
1
# coding=utf-8
2
3
# Adapted from https://huggingface.co/mosaicml/mpt-7b/tree/main
import math
4
from typing import List, Optional, Tuple
5
6
7
8
9
10

import torch
import torch.nn as nn

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import get_act_fn
11
from vllm.model_executor.layers.attention import Attention
12
13
14
15
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               LinearMethodBase,
                                               QKVParallelLinear,
                                               RowParallelLinear)
16
from vllm.model_executor.layers.logits_processor import LogitsProcessor
17
from vllm.model_executor.layers.sampler import Sampler
18
19
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
20
21
from vllm.model_executor.parallel_utils.parallel_state import (
    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
22
from vllm.model_executor.sampling_metadata import SamplingMetadata
23
24
from vllm.model_executor.weight_utils import (default_weight_loader,
                                              hf_model_weights_iterator)
25
from vllm.sequence import SamplerOutput
26
from vllm.transformers_utils.configs.mpt import MPTConfig
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43

# Per-layer attention cache: a (key_cache, value_cache) pair of tensors that
# is unpacked and handed to ``Attention`` in ``MPTAttention.forward``.
KVCache = Tuple[torch.Tensor, torch.Tensor]


def _get_alibi_slopes(
    total_num_heads: int,
    alibi_bias_max: int,
) -> torch.Tensor:
    """Compute the MPT ALiBi slope for every attention head.

    Let ``n`` be ``total_num_heads`` rounded up to the next power of two. The
    slopes are ``1 / 2 ** (i * alibi_bias_max / n)`` for ``i = 1..n``. When
    ``total_num_heads`` is not itself a power of two, the slopes at odd
    (0-based) positions are placed first, followed by those at even positions,
    and the result is truncated to ``total_num_heads`` entries.

    Args:
        total_num_heads: Number of attention heads across all TP ranks.
        alibi_bias_max: Maximum ALiBi bias exponent taken from the config.

    Returns:
        A 1-D float32 tensor holding ``total_num_heads`` slopes.
    """
    padded_heads = 2**math.ceil(math.log2(total_num_heads))
    exponents = torch.arange(1, padded_heads + 1, dtype=torch.float32)
    exponents = exponents * (alibi_bias_max / padded_heads)
    slopes = 1.0 / torch.pow(2, exponents)
    if padded_heads == total_num_heads:
        return slopes
    # Non power-of-two head count: interleave, then keep the leading heads.
    reordered = torch.concat([slopes[1::2], slopes[::2]])
    return reordered[:total_num_heads]


44
class MPTAttention(nn.Module):
    """MPT self-attention with ALiBi positional bias.

    Q, K and V come from one fused tensor-parallel projection (``Wqkv``), and
    the attention output is projected back with a row-parallel ``out_proj``.
    Grouped-query attention is used when ``attn_config`` contains
    ``kv_n_heads``. Positions enter only through the ALiBi slopes handed to
    ``Attention``; ``position_ids`` is ignored.
    """

    def __init__(
        self,
        config: MPTConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.d_model = config.d_model
        self.total_num_heads = config.n_heads
        self.head_dim = self.d_model // self.total_num_heads
        self.clip_qkv = config.attn_config["clip_qkv"]
        self.qk_ln = config.attn_config["qk_ln"]
        self.alibi_bias_max = config.attn_config["alibi_bias_max"]
        # Without an explicit ``kv_n_heads`` every query head has its own K/V.
        if "kv_n_heads" in config.attn_config:
            self.total_num_kv_heads = config.attn_config['kv_n_heads']
        else:
            self.total_num_kv_heads = self.total_num_heads
        # Only causal (non prefix-LM) attention with ALiBi is supported.
        assert not config.attn_config["prefix_lm"]
        assert config.attn_config["alibi"]

        # pylint: disable=invalid-name
        self.Wqkv = QKVParallelLinear(
            self.d_model,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=not config.no_bias,
            linear_method=linear_method,
        )
        if self.qk_ln:
            # The key slice only holds ``total_num_kv_heads`` heads, so its
            # LayerNorm must be sized to match. For plain multi-head attention
            # this equals ``d_model``. Sizing it to ``d_model`` breaks GQA
            # models: the key slice is narrower than ``d_model``.
            # NOTE(review): both norms operate on this rank's slice of q/k, so
            # their sizes only agree with the slice when the tensor-parallel
            # world size is 1 — confirm before enabling qk_ln with TP.
            self.q_ln = nn.LayerNorm(self.d_model)
            self.k_ln = nn.LayerNorm(self.total_num_kv_heads * self.head_dim)
        self.out_proj = RowParallelLinear(
            self.d_model,
            self.d_model,
            bias=not config.no_bias,
            linear_method=linear_method,
        )

        tp_world_size = get_tensor_model_parallel_world_size()
        assert self.total_num_heads % tp_world_size == 0
        self.num_heads = self.total_num_heads // tp_world_size

        if self.total_num_kv_heads >= tp_world_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_world_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_world_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size)
        # Widths of the local q / k / v slices of the fused projection.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim

        # Create the alibi slopes and keep only this rank's heads.
        tp_rank = get_tensor_model_parallel_rank()
        head_start = tp_rank * self.num_heads
        head_end = (tp_rank + 1) * self.num_heads
        alibi_slopes = _get_alibi_slopes(self.total_num_heads,
                                         self.alibi_bias_max)
        alibi_slopes = alibi_slopes[head_start:head_end].tolist()

        scaling = self.head_dim**-0.5
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              scaling,
                              alibi_slopes=alibi_slopes,
                              num_kv_heads=self.num_kv_heads)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Attend over ``hidden_states`` and project the result.

        Args:
            position_ids: Unused; ALiBi provides the positional bias.
            hidden_states: Input activations fed to ``Wqkv``.
            kv_cache: This layer's ``(key_cache, value_cache)`` pair.
            input_metadata: Batch metadata forwarded to ``Attention``.

        Returns:
            The output of ``out_proj``.
        """
        del position_ids  # unused.
        qkv, _ = self.Wqkv(hidden_states)
        if self.clip_qkv is not None:
            # In-place clamp of the fused projection, as configured by MPT.
            qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        if self.qk_ln:
            q = self.q_ln(q)
            k = self.k_ln(k)
        k_cache, v_cache = kv_cache
        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
        output, _ = self.out_proj(attn_output)
        return output


136
class MPTMLP(nn.Module):
    """Feed-forward sub-layer of an MPT block.

    The input goes through a column-parallel up-projection to
    ``config.expansion_ratio * config.d_model`` features, then a GELU
    activation, then a row-parallel down-projection back to ``d_model``.
    """

    def __init__(
        self,
        config: MPTConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        d_model = config.d_model
        ffn_width = config.expansion_ratio * d_model
        use_bias = not config.no_bias

        self.up_proj = ColumnParallelLinear(
            d_model,
            ffn_width,
            bias=use_bias,
            linear_method=linear_method,
        )
        # Quantization methods may supply a config the activation depends on.
        self.act = get_act_fn("gelu",
                              getattr(linear_method, "quant_config", None),
                              ffn_width)
        self.down_proj = RowParallelLinear(
            ffn_width,
            d_model,
            bias=use_bias,
            linear_method=linear_method,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``up_proj``, the activation, then ``down_proj`` to ``x``."""
        expanded, _ = self.up_proj(x)
        projected, _ = self.down_proj(self.act(expanded))
        return projected


169
class MPTBlock(nn.Module):
    """One pre-norm MPT decoder layer.

    LayerNorm is applied before attention and before the MLP, and each
    sub-layer's output is added back to its input as a residual.
    """

    def __init__(
        self,
        config: MPTConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.norm_1 = nn.LayerNorm(config.d_model)
        self.attn = MPTAttention(config, linear_method)
        self.norm_2 = nn.LayerNorm(config.d_model)
        self.ffn = MPTMLP(config, linear_method)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Run attention and the MLP, each wrapped in a residual connection."""
        attn_out = self.attn(
            position_ids=position_ids,
            hidden_states=self.norm_1(hidden_states),
            kv_cache=kv_cache,
            input_metadata=input_metadata,
        )
        after_attn = hidden_states + attn_out
        return after_attn + self.ffn(self.norm_2(after_attn))


204
class MPTModel(nn.Module):
    """MPT decoder stack: token embedding, ``n_layers`` blocks, final norm."""

    def __init__(
        self,
        config: MPTConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        # Reject configurations this implementation does not handle.
        assert config.embedding_fraction == 1.0
        assert config.norm_type == "low_precision_layernorm"

        self.wte = VocabParallelEmbedding(
            config.vocab_size,
            config.d_model,
        )
        self.blocks = nn.ModuleList(
            [MPTBlock(config, linear_method) for _ in range(config.n_layers)])
        self.norm_f = nn.LayerNorm(config.d_model)
        if config.no_bias:
            # nn.LayerNorm creates a bias by default, so strip every bias
            # parameter in the tree when the config asks for none.
            for module in self.modules():
                if hasattr(module, "bias") and isinstance(
                        module.bias, nn.Parameter):
                    # Remove the bias term in Linear and LayerNorm.
                    module.register_parameter("bias", None)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Embed ``input_ids``, run every block, and apply the final norm.

        Args:
            input_ids: Token ids to embed with ``wte``.
            position_ids: Passed through to each block.
            kv_caches: One cache per block; ``kv_caches[i]`` feeds block ``i``.
            input_metadata: Batch metadata passed through to each block.

        Returns:
            Hidden states after ``norm_f``.
        """
        hidden_states = self.wte(input_ids)
        for layer_idx, block in enumerate(self.blocks):
            hidden_states = block(
                position_ids,
                hidden_states,
                kv_caches[layer_idx],
                input_metadata,
            )
        return self.norm_f(hidden_states)


249
class MPTForCausalLM(nn.Module):
    """Causal language model built on ``MPTModel`` with a tied output head.

    Logits are computed against the input embedding matrix
    (``transformer.wte.weight``), which is also registered on this module as
    ``lm_head_weight``.
    """

    def __init__(
        self,
        config: MPTConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.config = config
        # A separate LM head weight is not supported.
        assert config.tie_word_embeddings
        self.linear_method = linear_method

        self.transformer = MPTModel(config, linear_method)
        # Tied head: reuse the embedding Parameter rather than allocating one.
        self.lm_head_weight = self.transformer.wte.weight
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Return the decoder's final (normed) hidden states; no logits."""
        return self.transformer(input_ids, positions, kv_caches,
                                input_metadata)

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Project ``hidden_states`` to vocabulary logits via the tied weight."""
        return self.logits_processor(self.lm_head_weight, hidden_states,
                                     sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Delegate next-token selection to the ``Sampler``."""
        return self.sampler(logits, sampling_metadata)

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     load_format: str = "auto",
                     revision: Optional[str] = None):
        """Copy checkpoint tensors into this module's parameters.

        Bias tensors with no matching parameter are skipped. Any other
        unknown checkpoint name raises ``KeyError``.
        """
        # remove_duplicate=False keeps shared Parameters listed under every
        # name they are registered under (e.g. the tied embedding).
        params = dict(self.named_parameters(remove_duplicate=False))
        checkpoint = hf_model_weights_iterator(model_name_or_path, cache_dir,
                                               load_format, revision)
        for name, loaded_weight in checkpoint:
            # Skip loading extra bias for GPTQ models.
            if name not in params and name.endswith(".bias"):
                continue
            target = params[name]
            loader = getattr(target, "weight_loader", default_weight_loader)
            loader(target, loaded_weight)