mpt.py 12.5 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
4
# Adapted from https://huggingface.co/mosaicml/mpt-7b/tree/main
import math
5
from typing import Iterable, Optional, Set, Tuple, Union
6
7
8
9

import torch
import torch.nn as nn

10
from vllm.attention import Attention
11
from vllm.compilation.decorators import support_torch_compile
12
from vllm.config import CacheConfig, VllmConfig
13
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
14
                              get_tensor_model_parallel_world_size)
15
from vllm.model_executor.layers.activation import get_act_fn
16
17
18
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
19
from vllm.model_executor.layers.logits_processor import LogitsProcessor
20
from vllm.model_executor.layers.quantization import QuantizationConfig
Joe Runde's avatar
Joe Runde committed
21
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
22
23
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
24
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
25
from vllm.model_executor.sampling_metadata import SamplingMetadata
26
from vllm.sequence import IntermediateTensors
27
from vllm.transformers_utils.configs.mpt import MPTConfig
28

29
30
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
31
32
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)
33

34
35
36
37
38
39
40
41
42
43
44
45
46
47

def _get_alibi_slopes(
    total_num_heads: int,
    alibi_bias_max: int,
) -> torch.Tensor:
    next_power_of_2 = 2**math.ceil(math.log2(total_num_heads))
    m = torch.arange(1, next_power_of_2 + 1, dtype=torch.float32)
    m = m.mul(alibi_bias_max / next_power_of_2)
    slopes = 1.0 / torch.pow(2, m)
    if next_power_of_2 != total_num_heads:
        slopes = torch.concat([slopes[1::2], slopes[::2]])[:total_num_heads]
    return slopes


48
class MPTAttention(nn.Module):
    """ALiBi self-attention used by each MPT block.

    Q, K and V come from one fused ``Wqkv`` projection. Query heads are
    split evenly across tensor-parallel ranks. KV heads are split when there
    are at least as many as ranks, and replicated otherwise. Each rank keeps
    only the ALiBi slopes of its own query heads. Grouped-query attention is
    used when ``attn_config`` contains ``kv_n_heads``.
    """

    def __init__(
        self,
        config: MPTConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.d_model = config.d_model
        self.total_num_heads = config.n_heads
        self.head_dim = self.d_model // self.total_num_heads
        self.clip_qkv = config.attn_config["clip_qkv"]
        self.qk_ln = config.attn_config["qk_ln"]
        self.alibi_bias_max = config.attn_config["alibi_bias_max"]
        # GQA checkpoints set ``kv_n_heads``; otherwise K/V use every head.
        if "kv_n_heads" in config.attn_config:
            self.total_num_kv_heads = config.attn_config['kv_n_heads']
        else:
            self.total_num_kv_heads = self.total_num_heads
        # Only non-prefix-LM, ALiBi-based configurations are accepted.
        assert not config.attn_config["prefix_lm"]
        assert config.attn_config["alibi"]

        # pylint: disable=invalid-name
        self.Wqkv = QKVParallelLinear(
            self.d_model,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=not config.no_bias,
            quant_config=quant_config,
        )
        if self.qk_ln:
            self.q_ln = nn.LayerNorm(self.d_model)
            # K covers only the KV heads, so its norm width is
            # kv_heads * head_dim. This equals d_model for plain MHA and is
            # smaller for GQA; a d_model-wide norm would fail on GQA shapes.
            # NOTE(review): both norms use unsharded widths, while forward()
            # applies them to this rank's shard of q/k. That only lines up at
            # tensor-parallel size 1 -- confirm qk_ln models run unsharded.
            self.k_ln = nn.LayerNorm(self.total_num_kv_heads * self.head_dim)
        self.out_proj = RowParallelLinear(
            self.d_model,
            self.d_model,
            bias=not config.no_bias,
            quant_config=quant_config,
        )

        tp_world_size = get_tensor_model_parallel_world_size()
        assert self.total_num_heads % tp_world_size == 0
        self.num_heads = self.total_num_heads // tp_world_size

        if self.total_num_kv_heads >= tp_world_size:
            # At least as many KV heads as TP ranks: partition the KV heads
            # across the tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_world_size == 0
        else:
            # Fewer KV heads than TP ranks: replicate the KV heads across
            # the tensor parallel GPUs.
            assert tp_world_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size)
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim

        # Compute slopes for all heads, then keep this rank's contiguous slice.
        tp_rank = get_tensor_model_parallel_rank()
        head_start = tp_rank * self.num_heads
        head_end = (tp_rank + 1) * self.num_heads
        alibi_slopes = _get_alibi_slopes(self.total_num_heads,
                                         self.alibi_bias_max)
        alibi_slopes = alibi_slopes[head_start:head_end].tolist()

        scaling = self.head_dim**-0.5
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              scaling,
                              alibi_slopes=alibi_slopes,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn")

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Apply attention for one MPT layer.

        Args:
            position_ids: Accepted for interface parity and discarded; the
                ALiBi slopes handed to ``Attention`` are the only positional
                input this module provides.
            hidden_states: Input activations fed to ``Wqkv``.

        Returns:
            The ``out_proj`` projection of the attention output.
        """
        del position_ids  # unused.
        qkv, _ = self.Wqkv(hidden_states)
        if self.clip_qkv is not None:
            # In-place clamp of the fused projection to [-clip, +clip].
            qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        if self.qk_ln:
            q = self.q_ln(q)
            k = self.k_ln(k)
        attn_output = self.attn(q, k, v)
        output, _ = self.out_proj(attn_output)
        return output


142
class MPTMLP(nn.Module):
    """Feed-forward sub-layer of an MPT block: up-projection, GELU, down.

    ``up_proj`` is a ``ColumnParallelLinear`` and ``down_proj`` a
    ``RowParallelLinear``; both carry a bias unless ``config.no_bias``.
    """

    def __init__(
        self,
        config: MPTConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        hidden_size = config.d_model
        ffn_width = config.expansion_ratio * hidden_size
        # llm-foundry allows a float ``expansion_ratio``, but the linear
        # layers need an integer width. Convert it, and reject products
        # that are not whole numbers rather than silently truncating.
        intermediate_size = int(ffn_width)
        if intermediate_size != ffn_width:
            raise ValueError(
                "d_model * expansion_ratio must be an integer, got "
                f"{hidden_size} * {config.expansion_ratio} = {ffn_width}")
        self.up_proj = ColumnParallelLinear(
            hidden_size,
            intermediate_size,
            bias=not config.no_bias,
            quant_config=quant_config,
        )
        self.act = get_act_fn("gelu")
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=not config.no_bias,
            quant_config=quant_config,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``down_proj(gelu(up_proj(x)))``."""
        x, _ = self.up_proj(x)
        x = self.act(x)
        x, _ = self.down_proj(x)
        return x


174
class MPTBlock(nn.Module):
    """One MPT decoder layer.

    Applies LayerNorm -> attention -> residual add, then LayerNorm -> MLP ->
    residual add.
    """

    def __init__(
        self,
        config: MPTConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        d_model = config.d_model
        self.norm_1 = nn.LayerNorm(d_model)
        self.attn = MPTAttention(config,
                                 cache_config,
                                 quant_config,
                                 prefix=f"{prefix}.attn")
        self.norm_2 = nn.LayerNorm(d_model)
        self.ffn = MPTMLP(config, quant_config)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run the layer and return the updated residual stream.

        Args:
            position_ids: Forwarded unchanged to the attention sub-layer.
            hidden_states: Residual-stream input to this layer.
        """
        attn_out = self.attn(position_ids=position_ids,
                             hidden_states=self.norm_1(hidden_states))
        hidden_states = hidden_states + attn_out
        ffn_out = self.ffn(self.norm_2(hidden_states))
        return hidden_states + ffn_out


210
@support_torch_compile
class MPTModel(nn.Module):
    """MPT decoder stack: token embedding, decoder blocks, final LayerNorm.

    Under pipeline parallelism this rank runs only the blocks in
    ``[start_layer, end_layer)``. Ranks other than the last return
    ``IntermediateTensors`` instead of normalized hidden states.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        # Only these configuration variants are accepted; the norms created
        # below are plain nn.LayerNorm modules.
        assert config.embedding_fraction == 1.0
        assert config.norm_type == "low_precision_layernorm"

        self.wte = VocabParallelEmbedding(
            config.vocab_size,
            config.d_model,
        )
        self.start_layer, self.end_layer, self.blocks = make_layers(
            config.n_layers,
            lambda prefix: MPTBlock(
                config, cache_config, quant_config, prefix=prefix),
            prefix=f"{prefix}.blocks")
        self.norm_f = nn.LayerNorm(config.d_model)

        if config.no_bias:
            # Drop the bias parameter from every submodule that has one
            # (Linear and LayerNorm alike), including norm_f created above.
            for submodule in self.modules():
                bias = getattr(submodule, "bias", None)
                if isinstance(bias, nn.Parameter):
                    submodule.register_parameter("bias", None)

        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(["hidden_states"],
                                                    config.d_model))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings via ``wte``."""
        return self.wte(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this rank's slice of the decoder stack.

        On the first pipeline rank, the input is ``inputs_embeds`` when given,
        otherwise the embedding of ``input_ids``. Later ranks read
        ``intermediate_tensors["hidden_states"]``.
        """
        pp_group = get_pp_group()
        if not pp_group.is_first_rank:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
        elif inputs_embeds is None:
            hidden_states = self.get_input_embeddings(input_ids)
        else:
            hidden_states = inputs_embeds

        for layer in self.blocks[self.start_layer:self.end_layer]:
            hidden_states = layer(position_ids, hidden_states)

        if pp_group.is_last_rank:
            return self.norm_f(hidden_states)
        return IntermediateTensors({"hidden_states": hidden_states})


270
class MPTForCausalLM(nn.Module, SupportsPP):
    """MPT causal language model with pipeline-parallel support.

    Wraps :class:`MPTModel` and reuses its token embedding module ``wte`` as
    the output head, so the config must declare tied word embeddings.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        self.config = config
        # lm_head below *is* the embedding module; untied weights can't load.
        assert config.tie_word_embeddings
        self.quant_config = vllm_config.quant_config

        self.transformer = MPTModel(vllm_config=vllm_config,
                                    prefix=maybe_prefix(prefix, "transformer"))
        self.lm_head = self.transformer.wte
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = get_sampler()
        self.make_empty_intermediate_tensors = (
            self.transformer.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate embedding lookup to the wrapped transformer."""
        return self.transformer.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Return whatever the transformer returns for this pipeline rank."""
        return self.transformer(input_ids, positions, intermediate_tensors,
                                inputs_embeds)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to vocabulary logits with the tied head."""
        return self.logits_processor(self.lm_head, hidden_states,
                                     sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Pick next tokens from ``logits`` using the vLLM sampler."""
        return self.sampler(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Copy checkpoint tensors into matching parameters.

        Skipped tensors:
            * ``.bias`` entries with no matching parameter (extra biases
              found in GPTQ checkpoints);
            * parameters reported missing on this pipeline rank by
              ``is_pp_missing_parameter``.

        Returns:
            The names of the parameters that were loaded.
        """
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: Set[str] = set()
        for param_name, tensor in weights:
            is_extra_bias = (param_name.endswith(".bias")
                             and param_name not in params_dict)
            if is_extra_bias or is_pp_missing_parameter(param_name, self):
                continue
            param = params_dict[param_name]
            loader = getattr(param, "weight_loader", default_weight_loader)
            loader(param, tensor)
            loaded_params.add(param_name)
        return loaded_params