"vscode:/vscode.git/clone" did not exist on "9a6a66f3b837bd3565471dc09ce3e23831e0e3f7"
mpt.py 11.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
# Adapted from https://huggingface.co/mosaicml/mpt-7b/tree/main
import math
6
from collections.abc import Iterable
7
from itertools import islice
8
9
10

import torch
import torch.nn as nn
11
from transformers import MptConfig
12

13
from vllm.attention.layer import Attention
14
from vllm.compilation.decorators import support_torch_compile
15
from vllm.config import CacheConfig, VllmConfig
16
17
18
19
20
from vllm.distributed import (
    get_pp_group,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)
21
from vllm.model_executor.layers.activation import get_act_fn
22
23
24
25
26
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
27
from vllm.model_executor.layers.logits_processor import LogitsProcessor
28
from vllm.model_executor.layers.quantization import QuantizationConfig
29
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
30
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
31
from vllm.sequence import IntermediateTensors
32

33
from .interfaces import SupportsPP
34
35
36
37
38
39
40
from .utils import (
    AutoWeightsLoader,
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)
41

42
43
44
45
46

def _get_alibi_slopes(
    total_num_heads: int,
    alibi_bias_max: int,
) -> torch.Tensor:
    """Return one ALiBi slope per attention head as a float32 tensor.

    With ``P`` = ``total_num_heads`` rounded up to a power of two, the slopes
    are ``1 / 2 ** (k * alibi_bias_max / P)`` for ``k = 1..P``. When
    ``total_num_heads`` is not itself a power of two, those ``P`` slopes are
    reordered (odd positions first, then even positions) and truncated to
    ``total_num_heads`` entries.
    """
    padded = 1 << math.ceil(math.log2(total_num_heads))
    step = alibi_bias_max / padded
    exponents = step * torch.arange(1, padded + 1, dtype=torch.float32)
    slopes = 1.0 / torch.pow(2, exponents)

    if padded == total_num_heads:
        return slopes

    # Interleave so that a non-power-of-two head count still samples slopes
    # spread across the whole geometric range.
    return torch.cat((slopes[1::2], slopes[::2]))[:total_num_heads]


56
class MPTAttention(nn.Module):
    """ALiBi self-attention layer of an MPT block.

    Supports multi-head and grouped-query attention (the latter when
    ``attn_config`` carries ``kv_n_heads``). Q/K/V come from one fused,
    tensor-parallel ``Wqkv`` projection. Each tensor-parallel rank owns a
    contiguous slice of the query heads and the matching slice of the ALiBi
    slopes.
    """

    def __init__(
        self,
        config: MptConfig,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        self.d_model = config.d_model
        self.total_num_heads = config.n_heads
        self.head_dim = self.d_model // self.total_num_heads
        self.clip_qkv = config.attn_config.clip_qkv
        self.qk_ln = config.attn_config.qk_ln
        self.alibi_bias_max = config.attn_config.alibi_bias_max
        # ``kv_n_heads`` is only present on grouped-query-attention configs;
        # without it every query head has its own K/V head.
        if "kv_n_heads" in config.attn_config:
            self.total_num_kv_heads = config.attn_config.kv_n_heads
        else:
            self.total_num_kv_heads = self.total_num_heads
        # Only causal attention with ALiBi positions is implemented here.
        assert not config.attn_config.prefix_lm
        assert config.attn_config.alibi

        tp_world_size = get_tensor_model_parallel_world_size()
        if self.qk_ln and tp_world_size > 1:
            # q_ln / k_ln normalize over the full, unsharded projection width,
            # but each rank only holds a shard of Q and K, so the LayerNorm
            # shapes cannot match the activations seen in forward().
            raise NotImplementedError(
                "MPT with attn_config.qk_ln=True is not supported with "
                "tensor parallelism (tensor_parallel_size > 1)."
            )

        # pylint: disable=invalid-name
        self.Wqkv = QKVParallelLinear(
            self.d_model,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=not config.no_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.Wqkv",
        )
        if self.qk_ln:
            self.q_ln = nn.LayerNorm(self.d_model)
            # K spans total_num_kv_heads heads, which is narrower than d_model
            # under grouped-query attention; size k_ln to match.
            self.k_ln = nn.LayerNorm(self.total_num_kv_heads * self.head_dim)
        self.out_proj = RowParallelLinear(
            self.d_model,
            self.d_model,
            bias=not config.no_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.out_proj",
        )

        assert self.total_num_heads % tp_world_size == 0
        self.num_heads = self.total_num_heads // tp_world_size

        if self.total_num_kv_heads >= tp_world_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_world_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_world_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size)
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim

        # Build slopes for all heads, then keep this rank's contiguous slice.
        tp_rank = get_tensor_model_parallel_rank()
        head_start = tp_rank * self.num_heads
        head_end = head_start + self.num_heads
        alibi_slopes = _get_alibi_slopes(self.total_num_heads, self.alibi_bias_max)
        alibi_slopes = alibi_slopes[head_start:head_end].tolist()

        scaling = self.head_dim**-0.5
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            scaling,
            alibi_slopes=alibi_slopes,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Project to Q/K/V, optionally clip and normalize, attend, project out.

        ``position_ids`` is accepted for interface compatibility but ignored:
        position information enters through the ALiBi slopes given to
        ``Attention``.
        """
        del position_ids  # unused.
        qkv, _ = self.Wqkv(hidden_states)
        if self.clip_qkv is not None:
            qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        if self.qk_ln:
            q = self.q_ln(q)
            k = self.k_ln(k)
        attn_output = self.attn(q, k, v)
        output, _ = self.out_proj(attn_output)
        return output


152
class MPTMLP(nn.Module):
    """Feed-forward sublayer of an MPT block: up-projection, GELU, down-projection.

    ``up_proj`` is column-parallel and ``down_proj`` is row-parallel, so the
    intermediate activation stays sharded across tensor-parallel ranks.
    """

    def __init__(
        self,
        config: MptConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        hidden_size = config.d_model
        # Some MPT configs use a fractional expansion_ratio; layer sizes must be
        # integers, so truncate like the reference implementation. For the usual
        # integer ratio this is a no-op.
        intermediate_size = int(config.expansion_ratio * hidden_size)
        self.up_proj = ColumnParallelLinear(
            hidden_size,
            intermediate_size,
            bias=not config.no_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.up_proj",
        )
        self.act = get_act_fn("gelu")
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=not config.no_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply up_proj -> GELU -> down_proj and return the projected output."""
        x, _ = self.up_proj(x)
        x = self.act(x)
        x, _ = self.down_proj(x)
        return x


186
class MPTBlock(nn.Module):
    """One pre-norm MPT decoder layer.

    Computes ``h + attn(norm_1(h))`` followed by ``h + ffn(norm_2(h))``.
    """

    def __init__(
        self,
        config: MptConfig,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        width = config.d_model
        # Registration order (norm_1, attn, norm_2, ffn) is kept so parameter
        # names and ordering are unchanged.
        self.norm_1 = nn.LayerNorm(width)
        self.attn = MPTAttention(
            config, cache_config, quant_config, prefix=f"{prefix}.attn"
        )
        self.norm_2 = nn.LayerNorm(width)
        self.ffn = MPTMLP(config, quant_config, prefix=f"{prefix}.ffn")

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run the attention and feed-forward sublayers, each with a residual add."""
        attn_out = self.attn(
            position_ids=position_ids,
            hidden_states=self.norm_1(hidden_states),
        )
        hidden_states = hidden_states + attn_out
        ffn_out = self.ffn(self.norm_2(hidden_states))
        return hidden_states + ffn_out


220
@support_torch_compile
class MPTModel(nn.Module):
    """MPT decoder trunk: token embedding, a stack of MPTBlocks, final LayerNorm.

    Under pipeline parallelism only layers ``[start_layer, end_layer)`` run on
    this rank. The embedding is applied only on the first rank and ``norm_f``
    only on the last rank.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        hf_config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        # Only the standard MPT variant is handled by this implementation.
        assert hf_config.embedding_fraction == 1.0
        assert hf_config.norm_type == "low_precision_layernorm"

        self.wte = VocabParallelEmbedding(
            hf_config.vocab_size,
            hf_config.d_model,
        )

        # make_layers invokes this with a ``prefix=`` keyword argument.
        def build_block(prefix: str) -> MPTBlock:
            return MPTBlock(hf_config, cache_config, quant_config, prefix=prefix)

        self.start_layer, self.end_layer, self.blocks = make_layers(
            hf_config.n_layers,
            build_block,
            prefix=f"{prefix}.blocks",
        )
        self.norm_f = nn.LayerNorm(hf_config.d_model)

        if hf_config.no_bias:
            # Drop every remaining bias Parameter (the LayerNorms; linear layers
            # were already built with bias disabled).
            biased = [
                mod
                for mod in self.modules()
                if isinstance(getattr(mod, "bias", None), nn.Parameter)
            ]
            for mod in biased:
                mod.register_parameter("bias", None)

        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states"], hf_config.d_model
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.wte(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run this rank's slice of the decoder.

        Returns the normalized hidden states on the last pipeline rank and an
        ``IntermediateTensors`` carrying ``"hidden_states"`` otherwise.
        """
        pp_group = get_pp_group()

        if not pp_group.is_first_rank:
            # Downstream ranks resume from the previous stage's activations.
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
        elif inputs_embeds is None:
            hidden_states = self.embed_input_ids(input_ids)
        else:
            hidden_states = inputs_embeds

        for block in islice(self.blocks, self.start_layer, self.end_layer):
            hidden_states = block(position_ids, hidden_states)

        if pp_group.is_last_rank:
            return self.norm_f(hidden_states)
        return IntermediateTensors({"hidden_states": hidden_states})

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Copy checkpoint tensors into matching parameters; return the loaded names."""
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded: set[str] = set()
        for name, tensor in weights:
            # Checkpoints such as GPTQ may carry bias tensors this model lacks;
            # parameters owned by other pipeline ranks are skipped as well.
            orphan_bias = name.endswith(".bias") and name not in params_dict
            if orphan_bias or is_pp_missing_parameter(name, self):
                continue
            param = params_dict[name]
            loader = getattr(param, "weight_loader", default_weight_loader)
            loader(param, tensor)
            loaded.add(name)
        return loaded

292

293
class MPTForCausalLM(nn.Module, SupportsPP):
    """MPT causal language model: ``MPTModel`` trunk with a tied-embedding LM head.

    The output projection reuses ``transformer.wte``, so the config must
    declare tied word embeddings.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        hf_config = vllm_config.model_config.hf_config

        self.config = hf_config
        assert hf_config.tie_word_embeddings
        self.quant_config = vllm_config.quant_config

        self.transformer = MPTModel(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "transformer")
        )
        # Logits are computed against the input embedding matrix.
        self.lm_head = self.transformer.wte
        self.logits_processor = LogitsProcessor(hf_config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.transformer.make_empty_intermediate_tensors
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate token embedding lookup to the trunk."""
        return self.transformer.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the trunk and return its output unchanged."""
        return self.transformer(
            input_ids, positions, intermediate_tensors, inputs_embeds
        )

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Project hidden states to vocabulary logits via the tied LM head."""
        return self.logits_processor(self.lm_head, hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights through vLLM's generic AutoWeightsLoader."""
        return AutoWeightsLoader(self).load_weights(weights)