# coding=utf-8
# Adapted from https://huggingface.co/mosaicml/mpt-7b/tree/main
import math
from typing import List, Optional, Tuple

import torch
import torch.nn as nn

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import PagedAttentionWithALiBi
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               LinearMethodBase,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
from vllm.model_executor.parallel_utils.parallel_state import (
    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.weight_utils import (default_weight_loader,
                                              hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs.mpt import MPTConfig

KVCache = Tuple[torch.Tensor, torch.Tensor]


def _get_alibi_slopes(
    total_num_heads: int,
    alibi_bias_max: int,
) -> torch.Tensor:
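    # ALiBi gives every head a fixed negative slope that penalizes attention
    # to distant positions. The slopes form the geometric sequence
    # 1 / 2**(i * alibi_bias_max / n) for i = 1..n, where n is the head count
    # rounded up to the next power of two; for non-power-of-two head counts,
    # the two interleaved halves of that sequence are taken, as in the MPT
    # reference implementation. For example, total_num_heads=8 with
    # alibi_bias_max=8 yields [1/2, 1/4, 1/8, ..., 1/256].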
    next_power_of_2 = 2**math.ceil(math.log2(total_num_heads))
    m = torch.arange(1, next_power_of_2 + 1, dtype=torch.float32)
    m = m.mul(alibi_bias_max / next_power_of_2)
    slopes = 1.0 / torch.pow(2, m)
    if next_power_of_2 != total_num_heads:
        slopes = torch.concat([slopes[1::2], slopes[::2]])[:total_num_heads]
    return slopes


class MPTAttention(nn.Module):

    def __init__(
        self,
        config: MPTConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.d_model = config.d_model
        self.total_num_heads = config.n_heads
        self.clip_qkv = config.attn_config["clip_qkv"]
        self.qk_ln = config.attn_config["qk_ln"]
        self.alibi_bias_max = config.attn_config["alibi_bias_max"]
        assert not config.attn_config["prefix_lm"]
        assert config.attn_config["alibi"]

        # pylint: disable=invalid-name
        self.Wqkv = QKVParallelLinear(
            self.d_model,
            self.d_model // self.total_num_heads,
            self.total_num_heads,
            bias=not config.no_bias,
            linear_method=linear_method,
        )
        if self.qk_ln:
            self.q_ln = nn.LayerNorm(self.d_model)
            self.k_ln = nn.LayerNorm(self.d_model)
        self.out_proj = RowParallelLinear(
            self.d_model,
            self.d_model,
            bias=not config.no_bias,
            linear_method=linear_method,
        )

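        # Shard the attention heads evenly across tensor-parallel ranks.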
        tp_world_size = get_tensor_model_parallel_world_size()
        assert self.total_num_heads % tp_world_size == 0
        self.num_heads = self.total_num_heads // tp_world_size

        # Create the alibi slopes and slice them.
        tp_rank = get_tensor_model_parallel_rank()
        head_start = tp_rank * self.num_heads
        head_end = (tp_rank + 1) * self.num_heads
        alibi_slopes = _get_alibi_slopes(self.total_num_heads,
                                         self.alibi_bias_max)
        alibi_slopes = alibi_slopes[head_start:head_end].tolist()

        self.head_dim = self.d_model // self.total_num_heads
        scaling = self.head_dim**-0.5
        self.attn = PagedAttentionWithALiBi(self.num_heads, self.head_dim,
                                            scaling, alibi_slopes)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
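        # MPT relies on ALiBi attention biases rather than positional
        # embeddings, so the position ids are not used.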
        del position_ids  # unused.
        qkv, _ = self.Wqkv(hidden_states)
        if self.clip_qkv is not None:
            qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        if self.qk_ln:
            q = self.q_ln(q)
            k = self.k_ln(k)
        k_cache, v_cache = kv_cache
        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
                                cache_event)
        output, _ = self.out_proj(attn_output)
        return output


class MPTMLP(nn.Module):

    def __init__(
        self,
        config: MPTConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        hidden_size = config.d_model
        expansion_ratio = config.expansion_ratio
        intermediate_size = expansion_ratio * hidden_size
        self.up_proj = ColumnParallelLinear(
            hidden_size,
            intermediate_size,
            bias=not config.no_bias,
            linear_method=linear_method,
        )
        self.act = get_act_fn("gelu")
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=not config.no_bias,
            linear_method=linear_method,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
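        # Up-project to the expanded hidden size, apply GELU, project back down.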
        x, _ = self.up_proj(x)
        x = self.act(x)
        x, _ = self.down_proj(x)
        return x


class MPTBlock(nn.Module):

    def __init__(
        self,
        config: MPTConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        hidden_size = config.d_model
        self.norm_1 = nn.LayerNorm(hidden_size)
        self.attn = MPTAttention(config, linear_method)
        self.norm_2 = nn.LayerNorm(hidden_size)
        self.ffn = MPTMLP(config, linear_method)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
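        # Pre-norm residual block: LayerNorm -> attention -> add, then
        # LayerNorm -> MLP -> add.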
        x = self.norm_1(hidden_states)
        x = self.attn(
            position_ids=position_ids,
            hidden_states=x,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
            cache_event=cache_event,
        )
        hidden_states = hidden_states + x
        x = self.norm_2(hidden_states)
        x = self.ffn(x)
        hidden_states = hidden_states + x
        return hidden_states


class MPTModel(nn.Module):

    def __init__(
        self,
        config: MPTConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        assert config.embedding_fraction == 1.0
        assert config.norm_type == "low_precision_layernorm"

        self.wte = VocabParallelEmbedding(
            config.vocab_size,
            config.d_model,
        )
        self.blocks = nn.ModuleList(
            [MPTBlock(config, linear_method) for _ in range(config.n_layers)])
        self.norm_f = nn.LayerNorm(config.d_model)
        if config.no_bias:
            for module in self.modules():
                if hasattr(module, "bias"):
                    if isinstance(module.bias, nn.Parameter):
                        # Remove the bias term in Linear and LayerNorm.
                        module.register_parameter("bias", None)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> torch.Tensor:
        hidden_states = self.wte(input_ids)
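        # Each block is given its layer's KV cache and the optional CUDA event
        # used to synchronize that layer's cache operations.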
        for i in range(len(self.blocks)):
            if cache_events is None:
                cache_event = None
            else:
                cache_event = cache_events[i]
            block = self.blocks[i]
            hidden_states = block(
                position_ids,
                hidden_states,
                kv_caches[i],
                input_metadata,
                cache_event,
            )
        hidden_states = self.norm_f(hidden_states)
        return hidden_states


class MPTForCausalLM(nn.Module):

    def __init__(
        self,
        config: MPTConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.config = config
        assert config.tie_word_embeddings
        self.linear_method = linear_method

        self.transformer = MPTModel(config, linear_method)
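        # The output projection is tied to the input embedding, so the sampler
        # reuses the wte weight matrix as the LM head.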
        self.lm_head_weight = self.transformer.wte.weight
        self.sampler = Sampler(config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> SamplerOutput:
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         input_metadata, cache_events)
        next_tokens = self.sampler(self.lm_head_weight, hidden_states,
                                   input_metadata)
        return next_tokens

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     load_format: str = "auto",
                     revision: Optional[str] = None):
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, load_format, revision):
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)