# coding=utf-8
# Adapted from https://huggingface.co/mosaicml/mpt-7b/tree/main
import math
from typing import List, Optional, Tuple

import torch
import torch.nn as nn
from transformers import MptConfig

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import PagedAttentionWithALiBi
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.weight_utils import (convert_pyslice_to_tensor,
                                              hf_model_weights_iterator,
                                              load_tensor_parallel_weights)
from vllm.model_executor.parallel_utils.parallel_state import (
    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding,
                                                       ColumnParallelLinear,
                                                       RowParallelLinear)
from vllm.sequence import SamplerOutput

KVCache = Tuple[torch.Tensor, torch.Tensor]


def _get_alibi_slopes(
    total_num_heads: int,
    alibi_bias_max: int,
) -> torch.Tensor:
    next_power_of_2 = 2**math.ceil(math.log2(total_num_heads))
    m = torch.arange(1, next_power_of_2 + 1, dtype=torch.float32)
    m = m.mul(alibi_bias_max / next_power_of_2)
    slopes = 1.0 / torch.pow(2, m)
    if next_power_of_2 != total_num_heads:
        slopes = torch.concat([slopes[1::2], slopes[::2]])[:total_num_heads]
    return slopes


class MptAttention(nn.Module):
    """MPT multi-head self-attention with ALiBi, sharded over tensor parallelism.

    Heads are split evenly across tensor-parallel ranks. The fused QKV
    projection is a ``ColumnParallelLinear`` with ``gather_output=False``,
    and the output projection is a ``RowParallelLinear`` with
    ``input_is_parallel=True``. Only the ALiBi, non-prefix-LM MPT
    configuration is supported; this is enforced by assertions.
    """

    def __init__(self, config: MptConfig):
        super().__init__()
        self.d_model = config.d_model
        self.total_num_heads = config.n_heads
        self.clip_qkv = config.attn_config.clip_qkv
        self.qk_ln = config.attn_config.qk_ln
        self.alibi_bias_max = config.attn_config.alibi_bias_max
        assert not config.attn_config.prefix_lm
        assert config.attn_config.alibi

        tp_world_size = get_tensor_model_parallel_world_size()
        assert self.total_num_heads % tp_world_size == 0
        self.num_heads = self.total_num_heads // tp_world_size
        if self.qk_ln and tp_world_size > 1:
            # q_ln / k_ln normalize over the full d_model. With tensor
            # parallelism, each rank's q and k only carry this rank's heads
            # (num_heads * head_dim == d_model // tp_world_size features),
            # so the LayerNorm in forward() would fail on a shape mismatch.
            # Fail early with an explicit message instead.
            raise NotImplementedError(
                "MPT attention with attn_config.qk_ln=True does not support "
                f"tensor parallelism (tp_world_size={tp_world_size}).")

        self.qkv_proj = ColumnParallelLinear(
            self.d_model,
            3 * self.d_model,
            bias=not config.no_bias,
            gather_output=False,
        )
        if self.qk_ln:
            self.q_ln = nn.LayerNorm(self.d_model)
            self.k_ln = nn.LayerNorm(self.d_model)
        self.out_proj = RowParallelLinear(
            self.d_model,
            self.d_model,
            bias=not config.no_bias,
            input_is_parallel=True,
        )

        # Compute ALiBi slopes for all heads, then keep only the contiguous
        # range of heads owned by this rank.
        tp_rank = get_tensor_model_parallel_rank()
        head_start = tp_rank * self.num_heads
        head_end = head_start + self.num_heads
        alibi_slopes = _get_alibi_slopes(self.total_num_heads,
                                         self.alibi_bias_max)
        alibi_slopes = alibi_slopes[head_start:head_end].tolist()

        self.head_dim = self.d_model // self.total_num_heads
        scaling = self.head_dim**-0.5
        self.attn = PagedAttentionWithALiBi(self.num_heads, self.head_dim,
                                            scaling, alibi_slopes)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        """Attend over this rank's heads and project back to ``d_model``.

        ``position_ids`` is accepted for signature parity with other model
        layers but is ignored, because positions are encoded by the ALiBi
        slopes given to ``PagedAttentionWithALiBi``.

        Returns:
            The output of ``out_proj``.
        """
        del position_ids  # unused; ALiBi encodes positions.
        qkv, _ = self.qkv_proj(hidden_states)
        if self.clip_qkv is not None:
            # Clamp in place to [-clip_qkv, clip_qkv].
            qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        if self.qk_ln:
            q = self.q_ln(q)
            k = self.k_ln(k)
        k_cache, v_cache = kv_cache
        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
                                cache_event)
        output, _ = self.out_proj(attn_output)
        return output


class MptMLP(nn.Module):
    """MPT feed-forward block: up-projection, GELU, then down-projection.

    The up-projection is column-parallel (``gather_output=False``) and the
    down-projection is row-parallel (``input_is_parallel=True``), so the
    intermediate activation stays sharded across tensor-parallel ranks.
    """

    def __init__(self, config: MptConfig):
        super().__init__()
        hidden_size = config.d_model
        # int() keeps the layer width integral if a config ever supplies a
        # fractional expansion_ratio; it is a no-op for integer ratios.
        intermediate_size = int(config.expansion_ratio * hidden_size)
        self.up_proj = ColumnParallelLinear(
            hidden_size,
            intermediate_size,
            bias=not config.no_bias,
            gather_output=False,
        )
        self.act = get_act_fn("gelu")
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=not config.no_bias,
            input_is_parallel=True,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``down_proj(gelu(up_proj(x)))``."""
        x, _ = self.up_proj(x)
        x = self.act(x)
        x, _ = self.down_proj(x)
        return x


class MptBlock(nn.Module):
    """One pre-LayerNorm MPT decoder layer.

    Computes ``h = x + attn(norm_1(x))`` followed by
    ``out = h + ffn(norm_2(h))``.
    """

    def __init__(self, config: MptConfig):
        super().__init__()
        hidden_size = config.d_model
        self.norm_1 = nn.LayerNorm(hidden_size)
        self.attn = MptAttention(config)
        self.norm_2 = nn.LayerNorm(hidden_size)
        self.ffn = MptMLP(config)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        """Run attention and feed-forward sublayers, each with a residual."""
        # Attention sublayer (pre-norm + residual).
        x = self.norm_1(hidden_states)
        x = self.attn(
            position_ids=position_ids,
            hidden_states=x,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
            cache_event=cache_event,
        )
        hidden_states = hidden_states + x
        # Feed-forward sublayer (pre-norm + residual).
        x = self.norm_2(hidden_states)
        x = self.ffn(x)
        hidden_states = hidden_states + x
        return hidden_states


class MptModel(nn.Module):
    """MPT decoder stack: token embedding, ``n_layers`` blocks, final LayerNorm.

    Only configs with ``embedding_fraction == 1.0`` and
    ``norm_type == "low_precision_layernorm"`` are accepted. The latter is
    implemented here with a regular ``nn.LayerNorm``.
    """

    def __init__(self, config: MptConfig):
        super().__init__()
        assert config.embedding_fraction == 1.0
        assert config.norm_type == "low_precision_layernorm"

        self.wte = VocabParallelEmbedding(
            config.vocab_size,
            config.d_model,
        )
        self.blocks = nn.ModuleList(
            [MptBlock(config) for _ in range(config.n_layers)])
        self.norm_f = nn.LayerNorm(config.d_model)
        if config.no_bias:
            # Remove every bias Parameter in the tree (e.g. the one
            # nn.LayerNorm creates by default) when the config says the
            # model has no biases.
            for module in self.modules():
                if isinstance(getattr(module, "bias", None), nn.Parameter):
                    module.register_parameter("bias", None)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> torch.Tensor:
        """Embed ``input_ids``, run every block, and apply the final norm.

        Args:
            input_ids: Token ids fed to ``wte``.
            position_ids: Forwarded to each block unchanged.
            kv_caches: One ``KVCache`` per block, indexed by layer.
            input_metadata: Forwarded to each block unchanged.
            cache_events: Optional per-layer events; when ``None``, every
                block receives ``None``.
        """
        hidden_states = self.wte(input_ids)
        for i, block in enumerate(self.blocks):
            cache_event = None if cache_events is None else cache_events[i]
            hidden_states = block(
                position_ids,
                hidden_states,
                kv_caches[i],
                input_metadata,
                cache_event,
            )
        hidden_states = self.norm_f(hidden_states)
        return hidden_states


class MptForCausalLM(nn.Module):
    """MPT causal language model with tied input/output embeddings.

    Wraps :class:`MptModel` and samples next tokens by reusing the token
    embedding matrix (``transformer.wte.weight``) as the LM head.
    """

    def __init__(self, config: MptConfig):
        super().__init__()
        self.config = config
        assert config.tie_word_embeddings

        self.transformer = MptModel(config)
        # TODO(zhuohan): create a new weight after implementing pipeline
        #                parallelism
        self.lm_head_weight = self.transformer.wte.weight
        self.sampler = Sampler(config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> SamplerOutput:
        """Run the decoder stack and sample the next tokens."""
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         input_metadata, cache_events)
        next_tokens = self.sampler(self.lm_head_weight, hidden_states,
                                   input_metadata)
        return next_tokens

    # Parameter-name fragments handed to load_tensor_parallel_weights to
    # select column- vs row-parallel sharding. The fused Wqkv tensors are
    # sliced manually in load_weights, so qkv_proj does not appear here.
    _column_parallel_weights = ["wte.weight", "up_proj.weight", "up_proj.bias"]
    _row_parallel_weights = ["out_proj.weight", "down_proj.weight"]

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     load_format: str = "auto",
                     revision: Optional[str] = None):
        """Load a HuggingFace MPT checkpoint into this rank's parameters.

        Fused ``Wqkv`` weights and biases are sliced here to this rank's
        heads and renamed to ``qkv_proj``. Every tensor is then handed to
        ``load_tensor_parallel_weights``.

        Raises:
            ValueError: If a ``Wqkv`` tensor name ends in neither
                ``.weight`` nor ``.bias``.
            KeyError: If a checkpoint tensor has no matching parameter in
                ``self.state_dict()``.
        """
        tp_world_size = get_tensor_model_parallel_world_size()
        tp_rank = get_tensor_model_parallel_rank()
        state_dict = self.state_dict()

        # Geometry of the fused QKV projection; it is constant across the
        # loop, so compute it once.
        total_num_heads = self.config.num_attention_heads
        hidden_size = self.config.hidden_size
        head_size = hidden_size // total_num_heads
        num_heads = total_num_heads // tp_world_size
        head_start = tp_rank * num_heads
        head_end = head_start + num_heads

        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, load_format, revision):
            if "Wqkv" in name:
                # NOTE(woosuk): MPT's fused QKV has the shape of
                # [3 * num_heads * head_size, hidden_size], i.e. dim 0 is
                # laid out as (q|k|v) x heads x head_size. With tensor
                # parallelism, each rank keeps only its own heads for each
                # of q, k and v. The slice is taken along the head (output)
                # dimension, not the hidden dimension.
                loaded_weight = convert_pyslice_to_tensor(loaded_weight)
                if name.endswith(".weight"):
                    loaded_weight = loaded_weight.view(3, total_num_heads,
                                                       head_size, hidden_size)
                    loaded_weight = loaded_weight[:, head_start:head_end, :, :]
                    loaded_weight = loaded_weight.reshape(-1, hidden_size)
                elif name.endswith(".bias"):
                    loaded_weight = loaded_weight.view(3, total_num_heads,
                                                       head_size)
                    loaded_weight = loaded_weight[:, head_start:head_end, :]
                    loaded_weight = loaded_weight.reshape(-1)
                else:
                    raise ValueError(f"Unexpected parameter name {name}")
                name = name.replace("Wqkv", "qkv_proj")
            param = state_dict[name]
            load_tensor_parallel_weights(param, loaded_weight, name,
                                         self._column_parallel_weights,
                                         self._row_parallel_weights, tp_rank)