# mpt.py (12 KB)
1
# coding=utf-8
2
3
# Adapted from https://huggingface.co/mosaicml/mpt-7b/tree/main
import math
4
from typing import Iterable, List, Optional, Tuple, Union
5
6
7
8

import torch
import torch.nn as nn

9
from vllm.attention import Attention, AttentionMetadata
10
from vllm.compilation.decorators import support_torch_compile
11
from vllm.config import CacheConfig
12
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
13
                              get_tensor_model_parallel_world_size)
14
from vllm.model_executor.layers.activation import get_act_fn
15
16
17
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
18
from vllm.model_executor.layers.logits_processor import LogitsProcessor
19
from vllm.model_executor.layers.quantization import QuantizationConfig
20
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
21
22
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
23
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
24
from vllm.model_executor.sampling_metadata import SamplingMetadata
25
from vllm.sequence import IntermediateTensors
26
from vllm.transformers_utils.configs.mpt import MPTConfig
27

28
29
30
31
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
                    make_empty_intermediate_tensors_factory, make_layers)

32
33
34
35
36
37
38
39
40
41
42
43
44
45

def _get_alibi_slopes(
    total_num_heads: int,
    alibi_bias_max: int,
) -> torch.Tensor:
    """Return the per-head ALiBi slopes used by MPT as a 1-D float32 tensor.

    The head count is first rounded up to a power of two, ``n``. Slope ``i``
    (for ``i = 1..n``) is ``1 / 2 ** (i * alibi_bias_max / n)``. When
    ``total_num_heads`` is not itself a power of two, the slopes at odd
    positions are taken first, then those at even positions. The result is
    truncated to ``total_num_heads`` entries.
    """
    padded_heads = 2**math.ceil(math.log2(total_num_heads))
    exponents = torch.arange(1, padded_heads + 1, dtype=torch.float32)
    exponents = exponents.mul(alibi_bias_max / padded_heads)
    slopes = 1.0 / torch.pow(2, exponents)
    if padded_heads == total_num_heads:
        return slopes
    # Non power-of-two head count: interleave odd/even slopes, then truncate.
    interleaved = torch.concat([slopes[1::2], slopes[::2]])
    return interleaved[:total_num_heads]


46
class MPTAttention(nn.Module):
    """Self-attention layer of an MPT block, position-biased with ALiBi.

    Supports multi-head attention and, when ``attn_config["kv_n_heads"]`` is
    set, grouped-query attention. Query heads are partitioned across
    tensor-parallel ranks. KV heads are partitioned when there are at least
    as many of them as ranks, and replicated otherwise.
    """

    def __init__(
        self,
        config: MPTConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.d_model = config.d_model
        self.total_num_heads = config.n_heads
        self.head_dim = self.d_model // self.total_num_heads
        self.clip_qkv = config.attn_config["clip_qkv"]
        self.qk_ln = config.attn_config["qk_ln"]
        self.alibi_bias_max = config.attn_config["alibi_bias_max"]
        # Without an explicit ``kv_n_heads`` every query head has its own
        # KV head (plain multi-head attention).
        self.total_num_kv_heads = config.attn_config.get(
            "kv_n_heads", self.total_num_heads)
        # Only the causal, ALiBi-biased variant is implemented here.
        assert not config.attn_config["prefix_lm"]
        assert config.attn_config["alibi"]

        # pylint: disable=invalid-name
        self.Wqkv = QKVParallelLinear(
            self.d_model,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=not config.no_bias,
            quant_config=quant_config,
        )
        if self.qk_ln:
            # Q spans all query heads (d_model wide). K spans only the KV
            # heads, which is narrower than d_model under grouped-query
            # attention, so k_ln must be sized by the KV width.
            # NOTE(review): both norms cover the unsharded width, so qk_ln
            # presumably only works with tensor-parallel size 1 — confirm.
            self.q_ln = nn.LayerNorm(self.d_model)
            self.k_ln = nn.LayerNorm(self.total_num_kv_heads * self.head_dim)
        self.out_proj = RowParallelLinear(
            self.d_model,
            self.d_model,
            bias=not config.no_bias,
            quant_config=quant_config,
        )

        tp_world_size = get_tensor_model_parallel_world_size()
        assert self.total_num_heads % tp_world_size == 0
        self.num_heads = self.total_num_heads // tp_world_size

        if self.total_num_kv_heads >= tp_world_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_world_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_world_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size)
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim

        # Each rank keeps only the ALiBi slopes of the query heads it owns.
        tp_rank = get_tensor_model_parallel_rank()
        head_start = tp_rank * self.num_heads
        head_end = head_start + self.num_heads
        alibi_slopes = _get_alibi_slopes(self.total_num_heads,
                                         self.alibi_bias_max)
        alibi_slopes = alibi_slopes[head_start:head_end].tolist()

        scaling = self.head_dim**-0.5
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              scaling,
                              alibi_slopes=alibi_slopes,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Attend over ``hidden_states`` and project back to ``d_model``.

        ``position_ids`` is accepted for interface compatibility but is not
        used. Position information comes from the ALiBi slopes handed to
        ``Attention``.
        """
        del position_ids  # unused.
        qkv, _ = self.Wqkv(hidden_states)
        if self.clip_qkv is not None:
            qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        if self.qk_ln:
            q = self.q_ln(q)
            k = self.k_ln(k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.out_proj(attn_output)
        return output


140
class MPTMLP(nn.Module):
    """Feed-forward sub-layer of an MPT block: up-projection, GELU, down."""

    def __init__(
        self,
        config: MPTConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        hidden_size = config.d_model
        expansion_ratio = config.expansion_ratio
        # Layer widths must be integers. Coerce in case expansion_ratio is a
        # float (e.g. 4.0), and reject ratios that do not give a whole width.
        intermediate_size = int(expansion_ratio * hidden_size)
        if intermediate_size != expansion_ratio * hidden_size:
            raise ValueError(
                f"expansion_ratio ({expansion_ratio}) * d_model "
                f"({hidden_size}) must be an integer.")
        self.up_proj = ColumnParallelLinear(
            hidden_size,
            intermediate_size,
            bias=not config.no_bias,
            quant_config=quant_config,
        )
        self.act = get_act_fn("gelu", quant_config, intermediate_size)
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=not config.no_bias,
            quant_config=quant_config,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply up-projection, activation and down-projection to ``x``."""
        x, _ = self.up_proj(x)
        x = self.act(x)
        x, _ = self.down_proj(x)
        return x


172
class MPTBlock(nn.Module):
    """A single pre-norm MPT decoder layer.

    Computes ``h = h + attn(norm_1(h))`` and then ``h = h + ffn(norm_2(h))``.
    """

    def __init__(
        self,
        config: MPTConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        width = config.d_model
        self.norm_1 = nn.LayerNorm(width)
        self.attn = MPTAttention(config, cache_config, quant_config)
        self.norm_2 = nn.LayerNorm(width)
        self.ffn = MPTMLP(config, quant_config)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run attention and feed-forward sub-layers, each with a residual."""
        residual = hidden_states
        attn_out = self.attn(
            position_ids=position_ids,
            hidden_states=self.norm_1(residual),
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        residual = residual + attn_out
        ffn_out = self.ffn(self.norm_2(residual))
        return residual + ffn_out


208
@support_torch_compile
class MPTModel(nn.Module):
    """MPT transformer trunk: embedding, decoder blocks and final norm.

    Under pipeline parallelism each rank only builds and runs its own slice
    of blocks (``start_layer``..``end_layer``). Non-final ranks hand their
    activations on as ``IntermediateTensors``.
    """

    def __init__(
        self,
        config: MPTConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        assert config.embedding_fraction == 1.0
        assert config.norm_type == "low_precision_layernorm"

        self.wte = VocabParallelEmbedding(config.vocab_size, config.d_model)
        # The lambda parameter must be called ``prefix``: make_layers passes
        # it by that keyword.
        self.start_layer, self.end_layer, self.blocks = make_layers(
            config.n_layers,
            lambda prefix: MPTBlock(config, cache_config, quant_config),
            prefix=f"{prefix}.blocks",
        )
        self.norm_f = nn.LayerNorm(config.d_model)

        if config.no_bias:
            # Drop every bias parameter (Linear and LayerNorm alike) so the
            # module tree matches bias-free checkpoints.
            for submodule in self.modules():
                if isinstance(getattr(submodule, "bias", None), nn.Parameter):
                    submodule.register_parameter("bias", None)

        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(["hidden_states"],
                                                    config.d_model))

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Embed (or receive) activations and run this rank's blocks.

        Returns normalized hidden states on the last pipeline rank, and an
        ``IntermediateTensors`` carrying ``"hidden_states"`` on any other rank.
        """
        pp_group = get_pp_group()
        if pp_group.is_first_rank:
            hidden_states = self.wte(input_ids)
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]

        for layer_idx in range(self.start_layer, self.end_layer):
            # kv_caches only holds entries for this rank's layers.
            hidden_states = self.blocks[layer_idx](
                position_ids,
                hidden_states,
                kv_caches[layer_idx - self.start_layer],
                attn_metadata,
            )

        if pp_group.is_last_rank:
            return self.norm_f(hidden_states)
        return IntermediateTensors({"hidden_states": hidden_states})


269
class MPTForCausalLM(nn.Module, SupportsPP):
    """MPT language model whose output head is tied to the input embedding."""

    def __init__(
        self,
        config: MPTConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        assert config.tie_word_embeddings
        self.quant_config = quant_config

        self.transformer = MPTModel(config, cache_config, quant_config)
        # Tied embeddings: the LM head is the very same embedding module.
        self.lm_head = self.transformer.wte
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()
        self.make_empty_intermediate_tensors = (
            self.transformer.make_empty_intermediate_tensors)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Delegate to the transformer trunk and return its output as-is."""
        return self.transformer(input_ids, positions, kv_caches,
                                attn_metadata, intermediate_tensors)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to vocabulary logits via the tied head."""
        return self.logits_processor(self.lm_head, hidden_states,
                                     sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Pick next tokens from ``logits`` using the configured sampler."""
        return self.sampler(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Copy checkpoint tensors into the matching parameters.

        Skips bias tensors that have no matching parameter (e.g. extra biases
        in GPTQ checkpoints). Also skips parameters that live on another
        pipeline rank.
        """
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        for name, loaded_weight in weights:
            unexpected_bias = (name.endswith(".bias")
                               and name not in params_dict)
            if unexpected_bias or is_pp_missing_parameter(name, self):
                continue
            param = params_dict[name]
            loader = getattr(param, "weight_loader", default_weight_loader)
            loader(param, loaded_weight)