# coding=utf-8
# Adapted from https://huggingface.co/mosaicml/mpt-7b/tree/main
import math
from typing import List, Optional, Tuple

import torch
import torch.nn as nn

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import PagedAttentionWithALiBi
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               LinearMethodBase,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
from vllm.model_executor.parallel_utils.parallel_state import (
    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.weight_utils import (default_weight_loader,
                                              hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs.mpt import MPTConfig

KVCache = Tuple[torch.Tensor, torch.Tensor]


def _get_alibi_slopes(
    total_num_heads: int,
    alibi_bias_max: int,
) -> torch.Tensor:
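    # Slopes follow the ALiBi recipe used by MPT: a geometric sequence is
    # computed for the next power of two of the head count, then interleaved
    # and truncated when the head count is not itself a power of two.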
    next_power_of_2 = 2**math.ceil(math.log2(total_num_heads))
    m = torch.arange(1, next_power_of_2 + 1, dtype=torch.float32)
    m = m.mul(alibi_bias_max / next_power_of_2)
    slopes = 1.0 / torch.pow(2, m)
    if next_power_of_2 != total_num_heads:
        slopes = torch.concat([slopes[1::2], slopes[::2]])[:total_num_heads]
    return slopes


class MPTAttention(nn.Module):

    def __init__(
        self,
        config: MPTConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.d_model = config.d_model
        self.total_num_heads = config.n_heads
        self.clip_qkv = config.attn_config["clip_qkv"]
        self.qk_ln = config.attn_config["qk_ln"]
        self.alibi_bias_max = config.attn_config["alibi_bias_max"]
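        # Only the standard MPT attention variant is supported here:
        # ALiBi position bias enabled and no prefix-LM masking.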
        assert not config.attn_config["prefix_lm"]
        assert config.attn_config["alibi"]

        # pylint: disable=invalid-name
        self.Wqkv = QKVParallelLinear(
            self.d_model,
            self.d_model // self.total_num_heads,
            self.total_num_heads,
            bias=not config.no_bias,
            linear_method=linear_method,
        )
        if self.qk_ln:
            self.q_ln = nn.LayerNorm(self.d_model)
            self.k_ln = nn.LayerNorm(self.d_model)
        self.out_proj = RowParallelLinear(
            self.d_model,
            self.d_model,
            bias=not config.no_bias,
            linear_method=linear_method,
        )

        tp_world_size = get_tensor_model_parallel_world_size()
        assert self.total_num_heads % tp_world_size == 0
        self.num_heads = self.total_num_heads // tp_world_size

        # Create the alibi slopes and slice them.
        tp_rank = get_tensor_model_parallel_rank()
        head_start = tp_rank * self.num_heads
        head_end = (tp_rank + 1) * self.num_heads
        alibi_slopes = _get_alibi_slopes(self.total_num_heads,
                                         self.alibi_bias_max)
        alibi_slopes = alibi_slopes[head_start:head_end].tolist()

        self.head_dim = self.d_model // self.total_num_heads
        scaling = self.head_dim**-0.5
        self.attn = PagedAttentionWithALiBi(self.num_heads, self.head_dim,
                                            scaling, alibi_slopes)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        del position_ids  # unused.
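        # Q, K, and V come from a single fused projection and are split below.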
        qkv, _ = self.Wqkv(hidden_states)
        if self.clip_qkv is not None:
            qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        if self.qk_ln:
            q = self.q_ln(q)
            k = self.k_ln(k)
        k_cache, v_cache = kv_cache
        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
                                cache_event)
        output, _ = self.out_proj(attn_output)
        return output


class MPTMLP(nn.Module):

    def __init__(
        self,
        config: MPTConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        hidden_size = config.d_model
        expansion_ratio = config.expansion_ratio
        intermediate_size = expansion_ratio * hidden_size
        self.up_proj = ColumnParallelLinear(
            hidden_size,
            intermediate_size,
            bias=not config.no_bias,
            linear_method=linear_method,
        )
        quant_config = getattr(linear_method, "quant_config", None)
        self.act = get_act_fn("gelu", quant_config, intermediate_size)
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=not config.no_bias,
            linear_method=linear_method,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x, _ = self.up_proj(x)
        x = self.act(x)
        x, _ = self.down_proj(x)
        return x


class MPTBlock(nn.Module):

    def __init__(
        self,
        config: MPTConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        hidden_size = config.d_model
        self.norm_1 = nn.LayerNorm(hidden_size)
        self.attn = MPTAttention(config, linear_method)
        self.norm_2 = nn.LayerNorm(hidden_size)
        self.ffn = MPTMLP(config, linear_method)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
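        # Pre-norm block: LayerNorm -> attention -> residual add, then
        # LayerNorm -> MLP -> residual add.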
        x = self.norm_1(hidden_states)
        x = self.attn(
            position_ids=position_ids,
            hidden_states=x,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
            cache_event=cache_event,
        )
        hidden_states = hidden_states + x
        x = self.norm_2(hidden_states)
        x = self.ffn(x)
        hidden_states = hidden_states + x
        return hidden_states


class MPTModel(nn.Module):

    def __init__(
        self,
        config: MPTConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
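        # Only the default MPT configuration is supported
        # (embedding_fraction == 1.0, low_precision_layernorm).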
        assert config.embedding_fraction == 1.0
        assert config.norm_type == "low_precision_layernorm"

        self.wte = VocabParallelEmbedding(
            config.vocab_size,
            config.d_model,
        )
        self.blocks = nn.ModuleList(
            [MPTBlock(config, linear_method) for _ in range(config.n_layers)])
        self.norm_f = nn.LayerNorm(config.d_model)
        if config.no_bias:
            for module in self.modules():
                if hasattr(module, "bias"):
                    if isinstance(module.bias, nn.Parameter):
                        # Remove the bias term in Linear and LayerNorm.
                        module.register_parameter("bias", None)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> torch.Tensor:
        hidden_states = self.wte(input_ids)
        for i in range(len(self.blocks)):
            if cache_events is None:
                cache_event = None
            else:
                cache_event = cache_events[i]
            block = self.blocks[i]
            hidden_states = block(
                position_ids,
                hidden_states,
                kv_caches[i],
                input_metadata,
                cache_event,
            )
        hidden_states = self.norm_f(hidden_states)
        return hidden_states


class MPTForCausalLM(nn.Module):

    def __init__(
        self,
        config: MPTConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.config = config
        assert config.tie_word_embeddings
        self.linear_method = linear_method

        self.transformer = MPTModel(config, linear_method)
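        # The output projection is tied to the token embedding, so the LM head
        # reuses transformer.wte.weight.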
        self.lm_head_weight = self.transformer.wte.weight
        self.sampler = Sampler(config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> SamplerOutput:
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         input_metadata, cache_events)
        next_tokens = self.sampler(self.lm_head_weight, hidden_states,
                                   input_metadata)
        return next_tokens

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     load_format: str = "auto",
                     revision: Optional[str] = None):
        params_dict = dict(self.named_parameters(remove_duplicate=False))
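        # Copy each checkpoint tensor into the matching parameter, using the
        # parameter's weight_loader when one is defined (e.g. for
        # tensor-parallel sharded layers).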
        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, load_format, revision):
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)