# mpt.py (12 KB)
1
2
# Adapted from https://huggingface.co/mosaicml/mpt-7b/tree/main
import math
3
from typing import Iterable, List, Optional, Tuple, Union
4
5
6
7

import torch
import torch.nn as nn

8
from vllm.attention import Attention, AttentionMetadata
9
from vllm.compilation.decorators import support_torch_compile
10
from vllm.config import CacheConfig
11
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
12
                              get_tensor_model_parallel_world_size)
13
from vllm.model_executor.layers.activation import get_act_fn
14
15
16
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
17
from vllm.model_executor.layers.logits_processor import LogitsProcessor
18
from vllm.model_executor.layers.quantization import QuantizationConfig
19
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
20
21
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
22
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
23
from vllm.model_executor.sampling_metadata import SamplingMetadata
24
from vllm.sequence import IntermediateTensors
25
from vllm.transformers_utils.configs.mpt import MPTConfig
26

27
28
29
30
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
                    make_empty_intermediate_tensors_factory, make_layers)

31
32
33
34
35
36
37
38
39
40
41
42
43
44

def _get_alibi_slopes(
    total_num_heads: int,
    alibi_bias_max: int,
) -> torch.Tensor:
    """Compute the ALiBi slope for every attention head.

    Slopes are first generated for ``padded_heads``, the smallest power of
    two that is >= ``total_num_heads``. The slope at 0-based position ``i``
    is ``2 ** -((i + 1) * alibi_bias_max / padded_heads)``. If the head count
    is not a power of two, the odd positions are placed before the even ones
    and the result is cut down to ``total_num_heads`` entries.

    Args:
        total_num_heads: Number of attention heads across all ranks.
        alibi_bias_max: Scale of the largest exponent.

    Returns:
        A 1-D float32 tensor holding ``total_num_heads`` slopes.
    """
    padded_heads = 2**math.ceil(math.log2(total_num_heads))
    step = alibi_bias_max / padded_heads
    exponents = torch.arange(1, padded_heads + 1, dtype=torch.float32) * step
    slopes = 1.0 / torch.pow(2, exponents)
    if padded_heads == total_num_heads:
        return slopes
    # Non power-of-two head count: odd positions first, then even ones.
    reordered = torch.cat((slopes[1::2], slopes[::2]))
    return reordered[:total_num_heads]


45
class MPTAttention(nn.Module):
    """ALiBi self-attention layer for MPT, sharded for tensor parallelism.

    Q, K and V come from one fused ``Wqkv`` projection. Query heads are split
    evenly across tensor-parallel ranks. KV heads are split when there are at
    least as many as ranks and replicated otherwise. Each rank keeps only the
    ALiBi slopes for its own query heads.
    """

    def __init__(
        self,
        config: MPTConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.d_model = config.d_model
        self.total_num_heads = config.n_heads
        self.head_dim = self.d_model // self.total_num_heads
        self.clip_qkv = config.attn_config["clip_qkv"]
        self.qk_ln = config.attn_config["qk_ln"]
        self.alibi_bias_max = config.attn_config["alibi_bias_max"]
        # Grouped-query configs declare ``kv_n_heads``; otherwise every query
        # head has its own KV head.
        if "kv_n_heads" in config.attn_config:
            self.total_num_kv_heads = config.attn_config['kv_n_heads']
        else:
            self.total_num_kv_heads = self.total_num_heads
        # Prefix-LM masking is unsupported. ALiBi is required because this
        # layer discards position_ids and has no other positional signal.
        assert not config.attn_config["prefix_lm"]
        assert config.attn_config["alibi"]

        # pylint: disable=invalid-name
        self.Wqkv = QKVParallelLinear(
            self.d_model,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=not config.no_bias,
            quant_config=quant_config,
        )
        if self.qk_ln:
            # q covers all query heads (d_model features). k covers only the
            # KV heads, which is fewer features than d_model under GQA/MQA.
            # NOTE(review): forward() applies these to this rank's shards
            # (q_size / kv_size features), so they only line up with these
            # sizes when the tensor-parallel world size is 1.
            self.q_ln = nn.LayerNorm(self.d_model)
            self.k_ln = nn.LayerNorm(self.total_num_kv_heads * self.head_dim)
        self.out_proj = RowParallelLinear(
            self.d_model,
            self.d_model,
            bias=not config.no_bias,
            quant_config=quant_config,
        )

        tp_world_size = get_tensor_model_parallel_world_size()
        assert self.total_num_heads % tp_world_size == 0
        self.num_heads = self.total_num_heads // tp_world_size

        if self.total_num_kv_heads >= tp_world_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_world_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_world_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size)
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim

        # Create the alibi slopes and keep only this rank's query heads.
        tp_rank = get_tensor_model_parallel_rank()
        head_start = tp_rank * self.num_heads
        head_end = (tp_rank + 1) * self.num_heads
        alibi_slopes = _get_alibi_slopes(self.total_num_heads,
                                         self.alibi_bias_max)
        alibi_slopes = alibi_slopes[head_start:head_end].tolist()

        # Use an explicit softmax_scale from the config when given, as the
        # reference MPT attention does. Otherwise use 1/sqrt(head_dim).
        softmax_scale = config.attn_config.get("softmax_scale")
        scaling = (self.head_dim**-0.5
                   if softmax_scale is None else softmax_scale)
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              scaling,
                              alibi_slopes=alibi_slopes,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Attend over ``hidden_states`` and apply the output projection.

        ``position_ids`` is ignored. Position information comes from the
        ALiBi slopes passed to ``Attention``.
        """
        del position_ids  # unused.
        qkv, _ = self.Wqkv(hidden_states)
        if self.clip_qkv is not None:
            # Clamp in place to [-clip_qkv, clip_qkv] before splitting.
            qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        if self.qk_ln:
            q = self.q_ln(q)
            k = self.k_ln(k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.out_proj(attn_output)
        return output


139
class MPTMLP(nn.Module):
    """MPT feed-forward sub-layer: up-projection, GELU, down-projection.

    The up-projection is column-parallel and the down-projection is
    row-parallel, so the hidden activation stays sharded between them.
    """

    def __init__(
        self,
        config: MPTConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        hidden_size = config.d_model
        expansion_ratio = config.expansion_ratio
        # The linear layers need an integer width. If a config supplies a
        # fractional expansion_ratio, truncate the product to an int rather
        # than passing a float size. Integer ratios are unaffected.
        intermediate_size = int(expansion_ratio * hidden_size)
        self.up_proj = ColumnParallelLinear(
            hidden_size,
            intermediate_size,
            bias=not config.no_bias,
            quant_config=quant_config,
        )
        self.act = get_act_fn("gelu")
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=not config.no_bias,
            quant_config=quant_config,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``down_proj(gelu(up_proj(x)))``."""
        x, _ = self.up_proj(x)
        x = self.act(x)
        x, _ = self.down_proj(x)
        return x


171
class MPTBlock(nn.Module):
    """A single pre-norm MPT decoder layer.

    Computes ``h = h + attn(norm_1(h))`` and then ``h = h + ffn(norm_2(h))``.
    """

    def __init__(
        self,
        config: MPTConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        width = config.d_model
        # These attribute names become part of the parameter names used when
        # loading checkpoints, so they must stay exactly as they are.
        self.norm_1 = nn.LayerNorm(width)
        self.attn = MPTAttention(config, cache_config, quant_config)
        self.norm_2 = nn.LayerNorm(width)
        self.ffn = MPTMLP(config, quant_config)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Apply the attention and MLP sub-layers, each with a residual add."""
        attn_in = self.norm_1(hidden_states)
        attn_out = self.attn(
            position_ids=position_ids,
            hidden_states=attn_in,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        residual = hidden_states + attn_out
        mlp_out = self.ffn(self.norm_2(residual))
        return residual + mlp_out


207
@support_torch_compile
class MPTModel(nn.Module):
    """MPT decoder stack: token embedding, transformer blocks and final norm.

    Under pipeline parallelism, only the first rank embeds ``input_ids`` and
    only the last rank applies ``norm_f``. The other ranks pass
    ``hidden_states`` along inside ``IntermediateTensors``.
    """

    def __init__(
        self,
        config: MPTConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        # Only the full-embedding, low-precision-layernorm variant is handled.
        assert config.embedding_fraction == 1.0
        assert config.norm_type == "low_precision_layernorm"

        self.wte = VocabParallelEmbedding(config.vocab_size, config.d_model)
        self.start_layer, self.end_layer, self.blocks = make_layers(
            config.n_layers,
            lambda prefix: MPTBlock(config, cache_config, quant_config),
            prefix=f"{prefix}.blocks")
        self.norm_f = nn.LayerNorm(config.d_model)

        if config.no_bias:
            # Bias-free checkpoints: unregister every bias Parameter that a
            # submodule (e.g. LayerNorm) created by default.
            for submodule in self.modules():
                bias = getattr(submodule, "bias", None)
                if isinstance(bias, nn.Parameter):
                    submodule.register_parameter("bias", None)

        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(["hidden_states"],
                                                    config.d_model))

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this rank's share of the decoder blocks.

        ``kv_caches`` is indexed relative to ``start_layer``. Returns the
        final normed hidden states on the last pipeline rank, otherwise
        ``IntermediateTensors`` carrying ``hidden_states``.
        """
        pp_group = get_pp_group()
        if pp_group.is_first_rank:
            hidden_states = self.wte(input_ids)
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]

        for cache_idx, layer_idx in enumerate(
                range(self.start_layer, self.end_layer)):
            hidden_states = self.blocks[layer_idx](
                position_ids,
                hidden_states,
                kv_caches[cache_idx],
                attn_metadata,
            )

        if not pp_group.is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})
        return self.norm_f(hidden_states)


268
class MPTForCausalLM(nn.Module, SupportsPP):
    """MPT causal language model entry point.

    Wraps ``MPTModel`` and reuses its token embedding as the output
    projection, so ``tie_word_embeddings`` must be set.
    """

    def __init__(
        self,
        config: MPTConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        assert config.tie_word_embeddings
        self.quant_config = quant_config

        self.transformer = MPTModel(config, cache_config, quant_config)
        # Tied weights: the LM head is the input embedding module itself.
        self.lm_head = self.transformer.wte
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()
        self.make_empty_intermediate_tensors = (
            self.transformer.make_empty_intermediate_tensors)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Delegate to the underlying ``MPTModel``."""
        return self.transformer(input_ids, positions, kv_caches,
                                attn_metadata, intermediate_tensors)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to vocabulary logits via the tied head."""
        return self.logits_processor(self.lm_head, hidden_states,
                                     sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Pick next tokens from ``logits`` using the configured sampler."""
        return self.sampler(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Copy checkpoint tensors into the matching model parameters.

        Skips bias tensors that have no matching parameter (as found in some
        GPTQ checkpoints) and parameters that belong to another pipeline
        stage.
        """
        params_by_name = dict(self.named_parameters(remove_duplicate=False))
        for name, tensor in weights:
            is_unknown_bias = (name.endswith(".bias")
                               and name not in params_by_name)
            if is_unknown_bias or is_pp_missing_parameter(name, self):
                continue
            param = params_by_name[name]
            loader = getattr(param, "weight_loader", default_weight_loader)
            loader(param, tensor)