# gpt.py
# GPT model components and GPT-2 / GPT-3 model factories built on Colossal-AI layers.
import math
from typing import Callable

import torch
from colossalai import nn as col_nn
from colossalai.nn.layer.utils import CheckpointModule
from colossalai.registry import LAYERS, MODELS, LOSSES
from colossalai.utils import get_current_device
from torch import dtype, nn

__all__ = ['GPT', 'GPTLMLoss', 'gpt2_small', 'gpt2_medium', 'gpt2_large', 'gpt2_xl', 'gpt3']


@LAYERS.register_module
class GPTEmbedding(nn.Module):
    """Sum of word, position and (optional) token-type embeddings followed by dropout.

    Args:
        embedding_dim: Size of every embedding vector.
        vocab_size: Number of entries in the word-embedding table.
        max_position_embeddings: Number of entries in the position-embedding table.
        num_tokentypes: Number of token types; the token-type table is only created when this is > 0.
        padding_idx: Forwarded to the word-embedding layer as ``padding_idx``.
        dropout: Dropout probability applied to the summed embeddings.
        dtype: Forwarded to every embedding layer.
    """

    def __init__(self,
                 embedding_dim: int,
                 vocab_size: int,
                 max_position_embeddings: int,
                 num_tokentypes: int = 0,
                 padding_idx: int = 0,
                 dropout: float = 0.,
                 dtype: dtype = None) -> None:
        super().__init__()
        self.word_embeddings = col_nn.Embedding(vocab_size, embedding_dim, padding_idx=padding_idx, dtype=dtype)
        self.position_embeddings = col_nn.Embedding(max_position_embeddings, embedding_dim, dtype=dtype)
        if num_tokentypes > 0:
            self.tokentype_embeddings = col_nn.Embedding(num_tokentypes, embedding_dim, dtype=dtype)
        else:
            self.tokentype_embeddings = None
        self.dropout = col_nn.Dropout(dropout)

    @property
    def word_embedding_weight(self):
        """Weight of the word-embedding layer (exposed so other modules can reuse it)."""
        return self.word_embeddings.weight

    def forward(self, input_ids, position_ids=None, tokentype_ids=None):
        """Embed ``input_ids`` and return the dropped-out sum of all active embeddings.

        Args:
            input_ids: Integer token ids; dimension 1 is read as the sequence length.
            position_ids: Optional position ids. Defaults to ``0 .. seq_length - 1`` with a leading
                broadcast dimension.
            tokentype_ids: Optional token-type ids; ignored when the module has no token-type table.
        """
        if position_ids is None:
            seq_length = input_ids.size(1)
            # Create the default positions on the ids' own device rather than the process-wide
            # current device, so inputs that live elsewhere (e.g. on CPU) do not hit a device mismatch.
            position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device).unsqueeze(0)
        x = self.word_embeddings(input_ids) + self.position_embeddings(position_ids)
        if self.tokentype_embeddings is not None and tokentype_ids is not None:
            x = x + self.tokentype_embeddings(tokentype_ids)
        return self.dropout(x)


@LAYERS.register_module
class GPTSelfAttention(nn.Module):
    """Multi-head causal self-attention with a fused query/key/value projection.

    Args:
        dim: Feature size of the input and output.
        num_heads: Number of attention heads; the per-head size is ``dim // num_heads``.
        attention_dropout: Dropout probability applied to the attention probabilities.
        dropout: Dropout probability applied to the output projection.
        bias: Whether the fused QKV projection has a bias (the output projection always has one).
        dtype: Forwarded to the linear layers.
    """

    def __init__(self,
                 dim: int,
                 num_heads: int,
                 attention_dropout: float,
                 dropout: float,
                 bias: bool = True,
                 dtype: dtype = None) -> None:
        super().__init__()

        self.attention_head_size = dim // num_heads
        self.query_key_value = col_nn.Linear(dim, 3 * dim, dtype=dtype, bias=bias)
        self.attention_dropout = col_nn.Dropout(attention_dropout)
        self.dense = col_nn.Linear(dim, dim, dtype=dtype, bias=True)
        self.dropout = col_nn.Dropout(dropout)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, attention_mask=None):
        """Apply causal self-attention to ``x``.

        Args:
            x: Input whose last dimension is projected to fused QKV; the code indexes it as
                (batch, sequence, features).
            attention_mask: Optional additive mask that is broadcast-added to the attention scores
                before the softmax.
        """
        qkv = self.query_key_value(x)
        # The head count is derived from the projection output, not from the constructor arguments
        # (presumably so a partitioned projection still works -- TODO confirm).
        all_head_size = qkv.shape[-1] // 3
        num_attention_heads = all_head_size // self.attention_head_size
        new_qkv_shape = qkv.shape[:-1] + (num_attention_heads, 3 * self.attention_head_size)
        qkv = qkv.view(new_qkv_shape).permute((0, 2, 1, 3))
        q, k, v = torch.chunk(qkv, 3, dim=-1)

        scores = torch.matmul(q, k.transpose(-1, -2))
        scores = scores / math.sqrt(self.attention_head_size)

        # Causal mask: a query position may only attend to itself and earlier key positions.
        # It is built on the scores' own device so the layer does not depend on the process-wide
        # current device matching the input's device.
        q_len, k_len = q.size(-2), k.size(-2)
        causal_mask = torch.tril(torch.ones((q_len, k_len), dtype=torch.uint8,
                                            device=scores.device)).view(1, 1, q_len, k_len).bool()
        scores = scores.masked_fill(~causal_mask, -1e4)

        if attention_mask is not None:
            scores = scores + attention_mask
        probs = self.softmax(scores)
        probs = self.attention_dropout(probs)

        context = torch.matmul(probs, v)
        # Move heads next to the per-head features and merge them back into one feature dimension.
        context = context.transpose(1, 2)
        new_context_layer_shape = context.size()[:-2] + (all_head_size, )
        context = context.reshape(new_context_layer_shape)

        out = self.dense(context)
        out = self.dropout(out)

        return out


@LAYERS.register_module
class GPTMLP(nn.Module):
    """Feed-forward sub-layer: linear expansion, activation, linear projection, dropout.

    Args:
        dim: Input and output feature size.
        mlp_ratio: The hidden layer has ``mlp_ratio * dim`` features.
        activation: Callable applied between the two linear layers.
        dropout: Dropout probability applied to the output.
        dtype: Forwarded to both linear layers.
        bias: Whether both linear layers carry a bias.
    """

    def __init__(self,
                 dim: int,
                 mlp_ratio: int,
                 activation: Callable,
                 dropout: float,
                 dtype: dtype = None,
                 bias: bool = True):
        super().__init__()
        hidden_dim = mlp_ratio * dim
        self.dense_1 = col_nn.Linear(dim, hidden_dim, dtype=dtype, bias=bias)
        self.activation = activation
        self.dense_2 = col_nn.Linear(hidden_dim, dim, dtype=dtype, bias=bias)
        self.dropout = col_nn.Dropout(dropout)

    def forward(self, x):
        """Return ``dropout(dense_2(activation(dense_1(x))))``."""
        hidden = self.activation(self.dense_1(x))
        return self.dropout(self.dense_2(hidden))


@LAYERS.register_module
class GPTBlock(CheckpointModule):
    """Transformer block: layer norm + self-attention and layer norm + MLP, each with a residual add.

    Args:
        dim: Feature size of the block.
        num_heads: Number of attention heads.
        mlp_ratio: Hidden-size multiplier of the MLP.
        activation: Activation callable used inside the MLP.
        attention_dropout: Dropout probability on the attention probabilities.
        dropout: Dropout probability on the attention and MLP outputs.
        dtype: Forwarded to every sub-layer.
        bias: Forwarded to the attention and MLP sub-layers.
        checkpoint: Forwarded to ``CheckpointModule`` (presumably toggles activation checkpointing
            -- TODO confirm against ``CheckpointModule``).
        layernorm_epsilon: ``eps`` of both layer norms; the default keeps the previously
            hard-coded ``1e-6``.
    """

    def __init__(self,
                 dim: int,
                 num_heads: int,
                 mlp_ratio: int,
                 activation: Callable,
                 attention_dropout: float = 0.,
                 dropout: float = 0.,
                 dtype: dtype = None,
                 bias: bool = True,
                 checkpoint: bool = False,
                 layernorm_epsilon: float = 1e-6):
        super().__init__(checkpoint=checkpoint)
        self.norm1 = col_nn.LayerNorm(normalized_shape=dim, eps=layernorm_epsilon, dtype=dtype)
        self.attn = GPTSelfAttention(dim=dim,
                                     num_heads=num_heads,
                                     attention_dropout=attention_dropout,
                                     dropout=dropout,
                                     bias=bias,
                                     dtype=dtype)
        self.norm2 = col_nn.LayerNorm(normalized_shape=dim, eps=layernorm_epsilon, dtype=dtype)
        self.mlp = GPTMLP(dim=dim, mlp_ratio=mlp_ratio, activation=activation, dropout=dropout, dtype=dtype, bias=bias)

    def _forward(self, x, attention_mask=None):
        """Run the block and return ``(x, attention_mask)``.

        The mask is returned unchanged alongside the hidden states so that consecutive blocks can
        be chained with the same call signature.
        """
        x = x + self.attn(self.norm1(x), attention_mask)
        x = x + self.mlp(self.norm2(x))
        return x, attention_mask


@LAYERS.register_module
class GPTLMHead(nn.Module):
    """Language-model head that maps hidden states to vocabulary logits.

    Args:
        dim: Feature size of the incoming hidden states.
        vocab_size: Number of output classes.
        word_embeeding_weight: Optional parameter handed to the classifier as its weight argument
            (the spelling of this name is part of the existing interface).
        bias: Whether the classifier has a bias.
        dtype: Forwarded to the classifier.
    """

    def __init__(self,
                 dim: int,
                 vocab_size: int,
                 word_embeeding_weight: nn.Parameter = None,
                 bias: bool = False,
                 dtype: dtype = None) -> None:
        super().__init__()
        classifier = col_nn.Classifier(dim, vocab_size, word_embeeding_weight, bias=bias, dtype=dtype)
        self.dense = classifier

    def forward(self, x):
        """Return the classifier output for ``x``."""
        return self.dense(x)


@LOSSES.register_module
class GPTLMLoss(nn.Module):
    """Next-token cross-entropy loss for language modelling."""

    def __init__(self):
        super().__init__()
        self.loss = col_nn.CrossEntropyLoss()

    def forward(self, logits, labels):
        """Compute the loss of ``logits`` at position ``t`` against ``labels`` at position ``t + 1``.

        The last logit step and the first label step are dropped, then both tensors are flattened
        before being handed to the wrapped criterion.
        """
        vocab_size = logits.size(-1)
        # Drop the final prediction and the first target so that position t predicts token t + 1.
        predictions = logits[..., :-1, :].contiguous().view(-1, vocab_size)
        targets = labels[..., 1:].contiguous().view(-1)
        return self.loss(predictions, targets)


@MODELS.register_module
class GPT(nn.Module):
    """GPT language model: embeddings, a stack of transformer blocks, a final layer norm and an LM head.

    Args:
        vocab_size: Vocabulary size used by the embedding and the LM head.
        max_position_embeddings: Size of the position-embedding table.
        dim: Hidden feature size.
        num_heads: Number of attention heads per block.
        depth: Number of transformer blocks.
        mlp_ratio: Hidden-size multiplier of each block's MLP.
        dropout: Dropout probability inside the blocks.
        embedding_dropout: Dropout probability after the embeddings.
        attention_dropout: Dropout probability on the attention probabilities.
        layernorm_epsilon: ``eps`` of the final layer norm (it is not passed to the blocks).
        activation: Activation callable used in each block's MLP.
        checkpoint: Forwarded to every block.
        dtype: Forwarded to all sub-modules; also used as the attention-mask dtype when set.
        bias: Forwarded to the blocks and the LM head.
        padding_idx: Forwarded to the embedding module.
    """

    def __init__(self,
                 vocab_size: int = 50304,
                 max_position_embeddings: int = 1024,
                 dim: int = 768,
                 num_heads: int = 12,
                 depth: int = 12,
                 mlp_ratio: int = 4,
                 dropout: float = 0.1,
                 embedding_dropout: float = 0.1,
                 attention_dropout: float = 0.1,
                 layernorm_epsilon: float = 1e-5,
                 activation: Callable = nn.functional.gelu,
                 checkpoint: bool = False,
                 dtype: dtype = None,
                 bias: bool = True,
                 padding_idx: int = 0) -> None:
        super().__init__()
        self.dtype = dtype
        self.embed = GPTEmbedding(embedding_dim=dim,
                                  vocab_size=vocab_size,
                                  max_position_embeddings=max_position_embeddings,
                                  padding_idx=padding_idx,
                                  dropout=embedding_dropout,
                                  dtype=dtype)
        self.blocks = nn.ModuleList([
            GPTBlock(
                dim=dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                activation=activation,
                attention_dropout=attention_dropout,
                dropout=dropout,
                dtype=dtype,
                bias=bias,
                checkpoint=checkpoint,
            ) for _ in range(depth)
        ])

        self.norm = col_nn.LayerNorm(normalized_shape=dim, eps=layernorm_epsilon, dtype=dtype)

        # The LM head receives the word-embedding weight of the embedding module.
        self.head = GPTLMHead(dim=dim,
                              vocab_size=vocab_size,
                              word_embeeding_weight=self.embed.word_embedding_weight,
                              bias=bias,
                              dtype=dtype)

    def forward(self, input_ids, attention_mask=None):
        """Return the LM-head output for ``input_ids``.

        Args:
            input_ids: Token ids; dimension 0 is read as the batch size.
            attention_mask: Optional mask reshaped to ``(batch_size, -1)`` where 1 (or ``True``)
                marks a position to keep and 0 (or ``False``) a position to mask out.
        """
        x = self.embed(input_ids)

        # We create a 4D additive attention mask from a 2D tensor mask.
        # Sizes are [batch_size, 1, 1, to_seq_length]
        # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
        # Adapted from huggingface
        if attention_mask is not None:
            batch_size = input_ids.shape[0]
            attention_mask = attention_mask.view(batch_size, -1)
            attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
            # fp16 compatibility: when no dtype was configured, follow the hidden states' dtype so
            # that bool/int masks are converted to a floating type before the arithmetic below.
            mask_dtype = self.dtype if self.dtype is not None else x.dtype
            attention_mask = attention_mask.to(dtype=mask_dtype)
            # Kept positions become 0, masked positions become -10000 (added to the scores).
            attention_mask = (1.0 - attention_mask) * -10000.0

        for block in self.blocks:
            x, attention_mask = block(x, attention_mask)

        x = self.head(self.norm(x))

        return x


def _create_gpt_model(**model_kwargs):
    """Instantiate :class:`GPT` with the given keyword arguments and return it."""
    return GPT(**model_kwargs)


@MODELS.register_module
def gpt2_small(**kwargs):
    """Create a GPT model with the preset ``dim=768, depth=12, num_heads=12``.

    Additional keyword arguments are forwarded to the model; repeating one of the preset keys
    raises ``TypeError``.
    """
    preset = dict(dim=768, depth=12, num_heads=12, **kwargs)
    return _create_gpt_model(**preset)


@MODELS.register_module
def gpt2_medium(**kwargs):
    """Create a GPT model with the preset ``dim=1024, depth=24, num_heads=16``.

    Additional keyword arguments are forwarded to the model; repeating one of the preset keys
    raises ``TypeError``.
    """
    preset = dict(dim=1024, depth=24, num_heads=16, **kwargs)
    return _create_gpt_model(**preset)


@MODELS.register_module
def gpt2_large(**kwargs):
    """Create a GPT model with the preset ``dim=1280, depth=36, num_heads=20``.

    Additional keyword arguments are forwarded to the model; repeating one of the preset keys
    raises ``TypeError``.
    """
    preset = dict(dim=1280, depth=36, num_heads=20, **kwargs)
    return _create_gpt_model(**preset)


@MODELS.register_module
def gpt2_xl(**kwargs):
    """Create a GPT model with the preset ``dim=1600, depth=48, num_heads=25``.

    Additional keyword arguments are forwarded to the model; repeating one of the preset keys
    raises ``TypeError``.
    """
    preset = dict(dim=1600, depth=48, num_heads=25, **kwargs)
    return _create_gpt_model(**preset)


@MODELS.register_module
def gpt3(**kwargs):
    """Create a GPT model with the preset ``dim=12288, max_position_embeddings=2048, depth=96, num_heads=96``.

    Additional keyword arguments are forwarded to the model; repeating one of the preset keys
    raises ``TypeError``.
    """
    preset = dict(dim=12288, max_position_embeddings=2048, depth=96, num_heads=96, **kwargs)
    return _create_gpt_model(**preset)