# mpt.py
1
# coding=utf-8
2
3
# Adapted from https://huggingface.co/mosaicml/mpt-7b/tree/main
import math
4
from typing import Iterable, List, Optional, Tuple
5
6
7
8

import torch
import torch.nn as nn

9
from vllm.attention import Attention, AttentionMetadata
10
from vllm.config import CacheConfig
11
12
from vllm.distributed import (get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size)
13
from vllm.model_executor.layers.activation import get_act_fn
14
15
16
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
17
from vllm.model_executor.layers.logits_processor import LogitsProcessor
18
19
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
20
from vllm.model_executor.layers.sampler import Sampler
21
22
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
23
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
24
from vllm.model_executor.sampling_metadata import SamplingMetadata
25
from vllm.sequence import SamplerOutput
26
from vllm.transformers_utils.configs.mpt import MPTConfig
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41


def _get_alibi_slopes(
    total_num_heads: int,
    alibi_bias_max: int,
) -> torch.Tensor:
    next_power_of_2 = 2**math.ceil(math.log2(total_num_heads))
    m = torch.arange(1, next_power_of_2 + 1, dtype=torch.float32)
    m = m.mul(alibi_bias_max / next_power_of_2)
    slopes = 1.0 / torch.pow(2, m)
    if next_power_of_2 != total_num_heads:
        slopes = torch.concat([slopes[1::2], slopes[::2]])[:total_num_heads]
    return slopes


42
class MPTAttention(nn.Module):
    """Self-attention layer of an MPT block.

    Runs a fused QKV projection, optionally clamps the result and applies
    LayerNorm to Q and K, calls the ``Attention`` op with ALiBi slopes, and
    finishes with an output projection. Heads are split across tensor-parallel
    ranks; each rank keeps only the ALiBi slopes for its own heads.
    """

    def __init__(
        self,
        config: MPTConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        """Build the projections and the attention op from ``config``.

        Args:
            config: Model config. Reads ``d_model``, ``n_heads``, ``no_bias``
                and the ``attn_config`` entries ``clip_qkv``, ``qk_ln``,
                ``alibi_bias_max``, ``prefix_lm``, ``alibi`` and (optionally)
                ``kv_n_heads``.
            cache_config: Forwarded to ``Attention``.
            quant_config: Forwarded to the parallel linear layers.

        Raises:
            AssertionError: If ``prefix_lm`` is set, ``alibi`` is not set, or
                the head counts are incompatible with the tensor-parallel
                world size.
        """
        super().__init__()
        self.d_model = config.d_model
        self.total_num_heads = config.n_heads
        self.head_dim = self.d_model // self.total_num_heads
        self.clip_qkv = config.attn_config["clip_qkv"]
        self.qk_ln = config.attn_config["qk_ln"]
        self.alibi_bias_max = config.attn_config["alibi_bias_max"]
        # Without an explicit "kv_n_heads" there is one KV head per query
        # head.
        if "kv_n_heads" in config.attn_config:
            self.total_num_kv_heads = config.attn_config['kv_n_heads']
        else:
            self.total_num_kv_heads = self.total_num_heads
        assert not config.attn_config["prefix_lm"]
        assert config.attn_config["alibi"]

        # pylint: disable=invalid-name
        self.Wqkv = QKVParallelLinear(
            self.d_model,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=not config.no_bias,
            quant_config=quant_config,
        )
        if self.qk_ln:
            # NOTE(review): both norms are sized to d_model, while forward()
            # applies them to tensors whose last dim is q_size / kv_size.
            # Those equal d_model only with a single tensor-parallel rank and
            # kv_n_heads == n_heads -- confirm whether other setups are meant
            # to be supported with qk_ln enabled.
            self.q_ln = nn.LayerNorm(self.d_model)
            self.k_ln = nn.LayerNorm(self.d_model)
        self.out_proj = RowParallelLinear(
            self.d_model,
            self.d_model,
            bias=not config.no_bias,
            quant_config=quant_config,
        )

        tp_world_size = get_tensor_model_parallel_world_size()
        assert self.total_num_heads % tp_world_size == 0, (
            "n_heads must be divisible by the tensor-parallel world size")
        self.num_heads = self.total_num_heads // tp_world_size

        if self.total_num_kv_heads >= tp_world_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_world_size == 0, (
                "kv_n_heads must be divisible by the tensor-parallel "
                "world size")
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_world_size % self.total_num_kv_heads == 0, (
                "the tensor-parallel world size must be divisible by "
                "kv_n_heads")
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size)
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim

        # Create the alibi slopes and slice them down to this rank's heads.
        tp_rank = get_tensor_model_parallel_rank()
        head_start = tp_rank * self.num_heads
        head_end = (tp_rank + 1) * self.num_heads
        alibi_slopes = _get_alibi_slopes(self.total_num_heads,
                                         self.alibi_bias_max)
        alibi_slopes = alibi_slopes[head_start:head_end].tolist()

        scaling = self.head_dim**-0.5
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              scaling,
                              alibi_slopes=alibi_slopes,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Apply attention to ``hidden_states``.

        Args:
            position_ids: Unused; accepted only for a uniform layer signature.
            hidden_states: Input to the fused QKV projection.
            kv_cache: Passed through to the ``Attention`` op.
            attn_metadata: Passed through to the ``Attention`` op.

        Returns:
            The output of the final projection.
        """
        del position_ids  # unused.
        qkv, _ = self.Wqkv(hidden_states)
        if self.clip_qkv is not None:
            # In-place clamp of the fused projection to +/- clip_qkv.
            qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        if self.qk_ln:
            q = self.q_ln(q)
            k = self.k_ln(k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.out_proj(attn_output)
        return output


135
class MPTMLP(nn.Module):
    """Feed-forward sub-layer of an MPT block: up-projection, GELU, down-projection.

    The hidden width is ``config.expansion_ratio * config.d_model``.
    """

    def __init__(
        self,
        config: MPTConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        """Create the two projections and the activation.

        Args:
            config: Model config. Reads ``d_model``, ``expansion_ratio`` and
                ``no_bias``.
            quant_config: Forwarded to both linear layers and to
                ``get_act_fn``.
        """
        super().__init__()
        hidden_size = config.d_model
        expansion_ratio = config.expansion_ratio
        intermediate_size = expansion_ratio * hidden_size
        self.up_proj = ColumnParallelLinear(
            hidden_size,
            intermediate_size,
            bias=not config.no_bias,
            quant_config=quant_config,
        )
        self.act = get_act_fn("gelu", quant_config, intermediate_size)
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=not config.no_bias,
            quant_config=quant_config,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``down_proj(act(up_proj(x)))``.

        Both projections return a 2-tuple; only the first element is used.
        """
        x, _ = self.up_proj(x)
        x = self.act(x)
        x, _ = self.down_proj(x)
        return x


167
class MPTBlock(nn.Module):
    """One MPT transformer block: pre-norm attention and pre-norm MLP.

    Each sub-layer's output is added back onto its un-normalised input.
    """

    def __init__(
        self,
        config: MPTConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        """Create the two LayerNorms, the attention layer and the MLP.

        Args:
            config: Model config; ``d_model`` sets the LayerNorm width.
            cache_config: Forwarded to ``MPTAttention``.
            quant_config: Forwarded to ``MPTAttention`` and ``MPTMLP``.
        """
        super().__init__()
        hidden_size = config.d_model
        self.norm_1 = nn.LayerNorm(hidden_size)
        self.attn = MPTAttention(config, cache_config, quant_config)
        self.norm_2 = nn.LayerNorm(hidden_size)
        self.ffn = MPTMLP(config, quant_config)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run the block on ``hidden_states`` and return the updated states.

        Args:
            position_ids: Forwarded to the attention layer.
            hidden_states: Block input; also the residual stream.
            kv_cache: Forwarded to the attention layer.
            attn_metadata: Forwarded to the attention layer.
        """
        # Attention sub-layer with residual connection.
        x = self.norm_1(hidden_states)
        x = self.attn(
            position_ids=position_ids,
            hidden_states=x,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        hidden_states = hidden_states + x
        # Feed-forward sub-layer with residual connection.
        x = self.norm_2(hidden_states)
        x = self.ffn(x)
        hidden_states = hidden_states + x
        return hidden_states


203
class MPTModel(nn.Module):
    """MPT decoder stack: token embedding, ``n_layers`` blocks, final LayerNorm."""

    def __init__(
        self,
        config: MPTConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        """Build the embedding, the blocks and the final norm.

        Args:
            config: Model config. Reads ``vocab_size``, ``d_model``,
                ``n_layers``, ``no_bias``, ``embedding_fraction`` and
                ``norm_type``.
            cache_config: Forwarded to every ``MPTBlock``.
            quant_config: Forwarded to every ``MPTBlock``.

        Raises:
            AssertionError: If ``embedding_fraction`` is not 1.0 or
                ``norm_type`` is not ``"low_precision_layernorm"``.
        """
        super().__init__()
        assert config.embedding_fraction == 1.0
        assert config.norm_type == "low_precision_layernorm"

        self.wte = VocabParallelEmbedding(
            config.vocab_size,
            config.d_model,
        )
        self.blocks = nn.ModuleList([
            MPTBlock(config, cache_config, quant_config)
            for _ in range(config.n_layers)
        ])
        self.norm_f = nn.LayerNorm(config.d_model)
        if config.no_bias:
            for module in self.modules():
                if hasattr(module, "bias") and isinstance(
                        module.bias, nn.Parameter):
                    # Remove the bias term in Linear and LayerNorm.
                    module.register_parameter("bias", None)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and run them through every block.

        Args:
            input_ids: Token ids passed to the embedding.
            position_ids: Forwarded to every block.
            kv_caches: One cache per block, indexed by block position.
            attn_metadata: Forwarded to every block.

        Returns:
            The hidden states after the final LayerNorm.
        """
        hidden_states = self.wte(input_ids)
        for i, block in enumerate(self.blocks):
            hidden_states = block(
                position_ids,
                hidden_states,
                kv_caches[i],
                attn_metadata,
            )
        hidden_states = self.norm_f(hidden_states)
        return hidden_states


251
class MPTForCausalLM(nn.Module):
    """MPT language model: the transformer stack plus logits and sampling.

    The output head reuses the token-embedding weight, so the config must
    have ``tie_word_embeddings`` set.
    """

    def __init__(
        self,
        config: MPTConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        """Build the transformer, the logits processor and the sampler.

        Args:
            config: Model config; must have ``tie_word_embeddings`` set.
            cache_config: Forwarded to ``MPTModel``.
            quant_config: Stored and forwarded to ``MPTModel``.

        Raises:
            AssertionError: If ``config.tie_word_embeddings`` is falsy.
        """
        super().__init__()
        self.config = config
        assert config.tie_word_embeddings
        self.quant_config = quant_config

        self.transformer = MPTModel(config, cache_config, quant_config)
        # Tied embeddings: the LM head is the input embedding matrix.
        self.lm_head_weight = self.transformer.wte.weight
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Return the transformer's final hidden states (not logits)."""
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         attn_metadata)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Run the logits processor with the tied LM-head weight."""
        logits = self.logits_processor(self.lm_head_weight, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Run the sampler on ``logits`` and return its output."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Copy checkpoint tensors into the matching parameters by name.

        Each parameter's own ``weight_loader`` attribute is used when present,
        otherwise ``default_weight_loader``.

        Args:
            weights: ``(parameter_name, tensor)`` pairs from the checkpoint.

        Raises:
            KeyError: If a non-bias name has no matching parameter.
        """
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        for name, loaded_weight in weights:
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)