mpt.py 11.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
# Adapted from https://huggingface.co/mosaicml/mpt-7b/tree/main
import math
6
from collections.abc import Iterable
7
from itertools import islice
8
from typing import Optional, Union
9
10
11

import torch
import torch.nn as nn
12
from transformers import MptConfig
13

14
from vllm.attention import Attention
15
from vllm.compilation.decorators import support_torch_compile
16
from vllm.config import CacheConfig, VllmConfig
17
18
19
20
21
from vllm.distributed import (
    get_pp_group,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)
22
from vllm.model_executor.layers.activation import get_act_fn
23
24
25
26
27
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
28
from vllm.model_executor.layers.logits_processor import LogitsProcessor
29
from vllm.model_executor.layers.quantization import QuantizationConfig
30
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
31
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
32
from vllm.sequence import IntermediateTensors
33

34
from .interfaces import SupportsPP
35
36
37
38
39
40
41
from .utils import (
    AutoWeightsLoader,
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)
42

43
44
45
46
47

def _get_alibi_slopes(
    total_num_heads: int,
    alibi_bias_max: int,
) -> torch.Tensor:
    """Build the per-head ALiBi slopes used by MPT attention.

    Let ``P`` be ``total_num_heads`` rounded up to the next power of two.
    Slope ``k`` (for ``k = 1..P``) is ``1 / 2 ** (k * alibi_bias_max / P)``.
    When the head count is not itself a power of two, the slopes at odd
    indices are placed before those at even indices and the sequence is
    truncated to ``total_num_heads`` entries.

    Args:
        total_num_heads: Number of attention heads summed over all TP ranks.
        alibi_bias_max: Maximum ALiBi bias exponent from the attention config.

    Returns:
        A 1-D float32 tensor with ``total_num_heads`` slopes.
    """
    padded_heads = 2 ** math.ceil(math.log2(total_num_heads))
    exponents = torch.arange(1, padded_heads + 1, dtype=torch.float32) * (
        alibi_bias_max / padded_heads
    )
    slopes = 1.0 / torch.pow(2, exponents)
    if padded_heads == total_num_heads:
        return slopes
    # Non-power-of-two head count: odd-indexed slopes first, then the
    # even-indexed ones, keeping only as many as there are heads.
    reordered = torch.cat((slopes[1::2], slopes[0::2]))
    return reordered[:total_num_heads]


57
class MPTAttention(nn.Module):
    """ALiBi self-attention layer for MPT (multi-head or grouped-query).

    Q, K and V are produced by one fused ``Wqkv`` projection and split into
    this tensor-parallel rank's query / key / value slices. Position
    information enters only through the ALiBi slopes passed to
    :class:`Attention`; ``position_ids`` is unused in :meth:`forward`.
    """

    def __init__(
        self,
        config: MptConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.d_model = config.d_model
        self.total_num_heads = config.n_heads
        self.head_dim = self.d_model // self.total_num_heads
        self.clip_qkv = config.attn_config.clip_qkv
        self.qk_ln = config.attn_config.qk_ln
        self.alibi_bias_max = config.attn_config.alibi_bias_max
        # Grouped-query configs declare kv_n_heads; otherwise every query head
        # has its own K/V head.
        if "kv_n_heads" in config.attn_config:
            self.total_num_kv_heads = config.attn_config.kv_n_heads
        else:
            self.total_num_kv_heads = self.total_num_heads
        assert not config.attn_config.prefix_lm
        assert config.attn_config.alibi

        # pylint: disable=invalid-name
        self.Wqkv = QKVParallelLinear(
            self.d_model,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=not config.no_bias,
            quant_config=quant_config,
        )
        if self.qk_ln:
            # Q spans all query heads (d_model features), but K spans only the
            # KV heads: total_num_kv_heads * head_dim features. Sizing k_ln as
            # d_model was wrong whenever kv heads != query heads (GQA).
            # NOTE(review): both norms use the unsharded width while forward()
            # feeds them this rank's q/k shards, so qk_ln only lines up at
            # tensor-parallel size 1.
            self.q_ln = nn.LayerNorm(self.d_model)
            self.k_ln = nn.LayerNorm(self.total_num_kv_heads * self.head_dim)
        self.out_proj = RowParallelLinear(
            self.d_model,
            self.d_model,
            bias=not config.no_bias,
            quant_config=quant_config,
        )

        tp_world_size = get_tensor_model_parallel_world_size()
        assert self.total_num_heads % tp_world_size == 0
        self.num_heads = self.total_num_heads // tp_world_size

        if self.total_num_kv_heads >= tp_world_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_world_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_world_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size)
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim

        # Build the slopes for all heads, then keep only the contiguous range
        # of heads owned by this tensor-parallel rank.
        tp_rank = get_tensor_model_parallel_rank()
        head_start = tp_rank * self.num_heads
        head_end = (tp_rank + 1) * self.num_heads
        alibi_slopes = _get_alibi_slopes(self.total_num_heads, self.alibi_bias_max)
        alibi_slopes = alibi_slopes[head_start:head_end].tolist()

        scaling = self.head_dim**-0.5
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            scaling,
            alibi_slopes=alibi_slopes,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run attention over ``hidden_states``.

        Args:
            position_ids: Accepted for interface compatibility; ignored
                because ALiBi supplies positional information.
            hidden_states: Input activations for the fused QKV projection.

        Returns:
            The output of ``out_proj`` applied to the attention result.
        """
        del position_ids  # unused.
        qkv, _ = self.Wqkv(hidden_states)
        if self.clip_qkv is not None:
            # In-place clamp of the fused activations to [-clip_qkv, clip_qkv].
            qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        if self.qk_ln:
            q = self.q_ln(q)
            k = self.k_ln(k)
        attn_output = self.attn(q, k, v)
        output, _ = self.out_proj(attn_output)
        return output


151
class MPTMLP(nn.Module):
    """MPT feed-forward sublayer: up-projection, GELU, down-projection.

    The inner width is ``config.expansion_ratio * config.d_model``. The
    up-projection is a column-parallel linear layer and the down-projection a
    row-parallel one, so the inner activations are sharded across
    tensor-parallel ranks.
    """

    def __init__(
        self,
        config: MptConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        d_model = config.d_model
        ffn_width = config.expansion_ratio * d_model
        use_bias = not config.no_bias

        self.up_proj = ColumnParallelLinear(
            d_model,
            ffn_width,
            bias=use_bias,
            quant_config=quant_config,
        )
        self.act = get_act_fn("gelu")
        self.down_proj = RowParallelLinear(
            ffn_width,
            d_model,
            bias=use_bias,
            quant_config=quant_config,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply up-projection, GELU and down-projection back to ``d_model``."""
        widened, _ = self.up_proj(x)
        activated = self.act(widened)
        projected, _ = self.down_proj(activated)
        return projected


182
class MPTBlock(nn.Module):
    """A single pre-norm MPT decoder layer.

    Applies ``h = h + attn(norm_1(h))`` and then ``h = h + ffn(norm_2(h))``.
    """

    def __init__(
        self,
        config: MptConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        width = config.d_model
        self.norm_1 = nn.LayerNorm(width)
        self.attn = MPTAttention(
            config, cache_config, quant_config, prefix=f"{prefix}.attn"
        )
        self.norm_2 = nn.LayerNorm(width)
        self.ffn = MPTMLP(config, quant_config)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run the attention and FFN sublayers, each with a residual add."""
        attn_out = self.attn(
            position_ids=position_ids,
            hidden_states=self.norm_1(hidden_states),
        )
        after_attn = hidden_states + attn_out
        ffn_out = self.ffn(self.norm_2(after_attn))
        return after_attn + ffn_out


216
@support_torch_compile
class MPTModel(nn.Module):
    """MPT decoder stack: token embedding, transformer blocks, final LayerNorm.

    Pipeline-parallel aware: this rank runs only the blocks in
    ``[start_layer, end_layer)``. Non-first ranks take their input from
    ``intermediate_tensors``; non-last ranks return their output wrapped in
    :class:`IntermediateTensors` instead of applying the final norm.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        hf_config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        # Only this configuration of MPT is implemented here.
        assert hf_config.embedding_fraction == 1.0
        assert hf_config.norm_type == "low_precision_layernorm"

        self.wte = VocabParallelEmbedding(
            hf_config.vocab_size,
            hf_config.d_model,
        )

        # make_layers invokes the factory with a keyword ``prefix`` argument,
        # so the parameter name must stay ``prefix``.
        def _block_factory(prefix: str) -> MPTBlock:
            return MPTBlock(hf_config, cache_config, quant_config, prefix=prefix)

        self.start_layer, self.end_layer, self.blocks = make_layers(
            hf_config.n_layers,
            _block_factory,
            prefix=f"{prefix}.blocks",
        )
        self.norm_f = nn.LayerNorm(hf_config.d_model)

        if hf_config.no_bias:
            # Bias-free checkpoints: unregister every bias parameter in the
            # tree, which covers both the linear layers and the LayerNorms.
            for submodule in self.modules():
                if isinstance(getattr(submodule, "bias", None), nn.Parameter):
                    submodule.register_parameter("bias", None)

        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states"], hf_config.d_model
        )

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.wte(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this rank's slice of the decoder.

        Returns final normalized hidden states on the last pipeline rank and
        an ``IntermediateTensors`` carrying ``"hidden_states"`` elsewhere.
        """
        pp_group = get_pp_group()
        if not pp_group.is_first_rank:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
        elif inputs_embeds is None:
            hidden_states = self.get_input_embeddings(input_ids)
        else:
            hidden_states = inputs_embeds

        for layer in islice(self.blocks, self.start_layer, self.end_layer):
            hidden_states = layer(position_ids, hidden_states)

        if pp_group.is_last_rank:
            return self.norm_f(hidden_states)
        return IntermediateTensors({"hidden_states": hidden_states})

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Copy checkpoint tensors into matching parameters.

        Returns the set of parameter names that were loaded.
        """
        params = dict(self.named_parameters(remove_duplicate=False))
        loaded: set[str] = set()
        for name, tensor in weights:
            # Skip loading extra bias for GPTQ models, and skip parameters
            # that live on another pipeline-parallel rank.
            is_orphan_bias = name.endswith(".bias") and name not in params
            if is_orphan_bias or is_pp_missing_parameter(name, self):
                continue
            param = params[name]
            loader = getattr(param, "weight_loader", default_weight_loader)
            loader(param, tensor)
            loaded.add(name)
        return loaded

288

289
class MPTForCausalLM(nn.Module, SupportsPP):
    """MPT causal language model: an :class:`MPTModel` plus a tied LM head.

    The output head is the input embedding module (``transformer.wte``), so
    the config is required to declare tied word embeddings.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        self.config = config
        assert config.tie_word_embeddings
        self.quant_config = vllm_config.quant_config

        self.transformer = MPTModel(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "transformer")
        )
        # Weight tying: logits are computed against the embedding table.
        self.lm_head = self.transformer.wte
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.transformer.make_empty_intermediate_tensors
        )

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate token-embedding lookup to the decoder."""
        return self.transformer.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the decoder; returns hidden states or pipeline intermediates."""
        return self.transformer(
            input_ids, positions, intermediate_tensors, inputs_embeds
        )

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to vocabulary logits via the tied head."""
        return self.logits_processor(self.lm_head, hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights through ``AutoWeightsLoader``."""
        return AutoWeightsLoader(self).load_weights(weights)