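"""GPT model definitions built on ColossalAI's parallel layers.

Provides tensor-parallel GPT variants (gpt2_small through gpt3) and
pipeline-parallel counterparts that split the transformer stack across
pipeline stages.
"""
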
import math
from typing import Callable

import torch
from colossalai import nn as col_nn
from colossalai.builder.pipeline import partition_uniform
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
from colossalai.nn.layer.utils import CheckpointModule, divide
from colossalai.nn.layer.wrapper import PipelineSharedModuleWrapper
from colossalai.registry import LAYERS, LOSSES, MODELS
from colossalai.utils import get_current_device
from torch import dtype, nn

__all__ = [
    'GPT', 'GPTLMLoss', 'gpt2_small', 'gpt2_medium', 'gpt2_large', 'gpt2_xl', 'gpt2_8B', 'gpt2_xl_pipeline',
    'gpt2_8B_pipeline', 'gpt3', 'gpt3_pipeline'
]


@LAYERS.register_module
class GPTEmbedding(nn.Module):
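    """Sums word, position and (optional) token-type embeddings, then applies
    dropout. Also converts a 2D padding mask into the additive 4D mask
    expected by the attention layers.
    """
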
    def __init__(self,
                 embedding_dim: int,
                 vocab_size: int,
                 max_position_embeddings: int,
                 num_tokentypes: int = 0,
                 padding_idx: int = None,
                 dropout: float = 0.,
                 dtype: dtype = None) -> None:
        super().__init__()
        self.word_embeddings = col_nn.Embedding(vocab_size, embedding_dim, padding_idx=padding_idx, dtype=dtype)
        self.position_embeddings = col_nn.Embedding(max_position_embeddings, embedding_dim, dtype=dtype)
        if num_tokentypes > 0:
            self.tokentype_embeddings = col_nn.Embedding(num_tokentypes, embedding_dim, dtype=dtype)
        else:
            self.tokentype_embeddings = None
        self.dropout = col_nn.Dropout(dropout)

    @property
    def word_embedding_weight(self):
        return self.word_embeddings.weight

    def forward(self, input_ids, attention_mask=None, position_ids=None, tokentype_ids=None):
        seq_length = input_ids.size(1)
        if position_ids is None:
            position_ids = torch.arange(seq_length, dtype=torch.long, device=get_current_device()).unsqueeze(0)
        x = self.word_embeddings(input_ids) + self.position_embeddings(position_ids)
        if self.tokentype_embeddings is not None and tokentype_ids is not None:
            x = x + self.tokentype_embeddings(tokentype_ids)
        x = self.dropout(x)

        # We create a 3D attention mask from a 2D tensor mask.
        # Sizes are [batch_size, 1, 1, to_seq_length]
        # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
        # Adapted from huggingface
        if attention_mask is not None:
            batch_size = input_ids.shape[0]
            attention_mask = attention_mask.view(batch_size, -1)
            attention_mask = col_nn.partition_batch(attention_mask)
            attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
            attention_mask = attention_mask.to(dtype=x.dtype)  # fp16 compatibility
            attention_mask = (1.0 - attention_mask) * -10000.0

        return x, attention_mask


@LAYERS.register_module
class GPTSelfAttention(nn.Module):
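    """Multi-head causal self-attention with a fused QKV projection.

    When fuse_scale_mask_softmax is enabled, scaling and causal masking are
    delegated to ColossalAI's fused CUDA softmax kernel; otherwise an
    explicit lower-triangular mask is applied before a standard softmax.
    """
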
    def __init__(self,
                 dim: int,
                 num_heads: int,
                 attention_dropout: float,
                 dropout: float,
                 bias: bool = True,
                 fuse_scale_mask_softmax: bool = False,
                 dtype: dtype = None) -> None:
        super().__init__()
        self.fuse_scale_mask_softmax = fuse_scale_mask_softmax
        self.attention_head_size = divide(dim, num_heads)
        self.query_key_value = col_nn.Linear(dim, 3 * dim, dtype=dtype, bias=bias)
        if fuse_scale_mask_softmax:
            from colossalai.kernel import FusedScaleMaskSoftmax
            from colossalai.kernel.cuda_native.scaled_softmax import AttnMaskType
            self.softmax = FusedScaleMaskSoftmax(input_in_fp16=True,
                                                 input_in_bf16=False,
                                                 attn_mask_type=AttnMaskType.causal,
                                                 scaled_masked_softmax_fusion=True,
                                                 mask_func=None,
                                                 softmax_in_fp32=True,
                                                 scale=math.sqrt(self.attention_head_size))
        else:
            self.softmax = nn.Softmax(dim=-1)
        self.attention_dropout = col_nn.Dropout(attention_dropout)
        self.dense = col_nn.Linear(dim, dim, dtype=dtype, bias=True)
        self.dropout = col_nn.Dropout(dropout)

    def forward(self, x, attention_mask=None):
        qkv = self.query_key_value(x)
        all_head_size = qkv.shape[-1] // 3
        num_attention_heads = divide(all_head_size, self.attention_head_size)
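        # Reshape the fused projection to [batch, seq, heads, 3 * head_size],
        # move the head dimension forward, then split the last dimension into
        # separate q, k and v tensors of [batch, heads, seq, head_size].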
        new_qkv_shape = qkv.shape[:-1] + \
            (num_attention_heads, 3 * self.attention_head_size)
        qkv = qkv.view(new_qkv_shape)
        qkv = qkv.permute((0, 2, 1, 3))
        q, k, v = torch.chunk(qkv, 3, dim=-1)

        x = torch.matmul(q, k.transpose(-1, -2))

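        # Fused path: the kernel applies scaling and the causal mask itself.
        # Fallback path: scale the scores, zero out future positions with a
        # lower-triangular mask, then add the additive padding mask.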
        if self.fuse_scale_mask_softmax:
            x = self.softmax(x, attention_mask)
        else:
            x = x / math.sqrt(self.attention_head_size)
            # causal mask
            q_len, k_len = q.size(-2), k.size(-2)
            causal_mask = torch.tril(torch.ones((q_len, k_len), dtype=torch.uint8,
                                                device=get_current_device())).view(1, 1, q_len, k_len).bool()
            x = torch.where(causal_mask, x, torch.tensor(-1e4, dtype=x.dtype, device=get_current_device()))
            if attention_mask is not None:
                x = x + attention_mask
            x = self.softmax(x)

        x = self.attention_dropout(x)

        x = torch.matmul(x, v)
        x = x.transpose(1, 2)
        new_context_layer_shape = x.size()[:-2] + (all_head_size, )
        x = x.reshape(new_context_layer_shape)

        x = self.dense(x)
        x = self.dropout(x)

        return x


@LAYERS.register_module
class GPTMLP(nn.Module):
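    """Position-wise feed-forward block: dim -> dim * mlp_ratio -> dim, with
    an activation in between, followed by dropout.
    """
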
    def __init__(self,
                 dim: int,
                 mlp_ratio: float,
                 activation: Callable,
                 dropout: float,
                 dtype: dtype = None,
                 bias: bool = True):
        super().__init__()
        intermediate_dim = int(dim * mlp_ratio)
        self.dense_1 = col_nn.Linear(dim, intermediate_dim, dtype=dtype, bias=bias)
        self.activation = activation
        self.dense_2 = col_nn.Linear(intermediate_dim, dim, dtype=dtype, bias=bias)
        self.dropout = col_nn.Dropout(dropout)

    def forward(self, x):
        x = self.dense_1(x)
        x = self.activation(x)
        x = self.dense_2(x)
        x = self.dropout(x)
        return x


@LAYERS.register_module
class GPTBlock(CheckpointModule):
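    """Transformer block supporting pre-layernorm (default) or
    post-layernorm residual wiring, with optional activation checkpointing
    via CheckpointModule.
    """
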
    def __init__(self,
                 dim: int,
                 num_heads: int,
                 mlp_ratio: float,
                 activation: Callable,
                 attention_dropout: float = 0.,
                 dropout: float = 0.,
                 layernorm_epsilon: float = 1e-5,
                 dtype: dtype = None,
                 bias: bool = True,
                 apply_post_layernorm: bool = False,
                 fuse_scale_mask_softmax: bool = False,
                 checkpoint: bool = False):
        super().__init__(checkpoint)
        self.apply_post_layernorm = apply_post_layernorm
        self.norm1 = col_nn.LayerNorm(normalized_shape=dim, eps=layernorm_epsilon, dtype=dtype)
        self.attn = GPTSelfAttention(dim=dim,
                                     num_heads=num_heads,
                                     attention_dropout=attention_dropout,
                                     dropout=dropout,
                                     bias=bias,
                                     fuse_scale_mask_softmax=fuse_scale_mask_softmax,
                                     dtype=dtype)
        self.norm2 = col_nn.LayerNorm(normalized_shape=dim, eps=layernorm_epsilon, dtype=dtype)
        self.mlp = GPTMLP(dim=dim, mlp_ratio=mlp_ratio, activation=activation, dropout=dropout, dtype=dtype, bias=bias)

    def _forward(self, x, attention_mask=None):
        if not self.apply_post_layernorm:
            residual = x
        x = self.norm1(x)
        if self.apply_post_layernorm:
            residual = x
        x = residual + self.attn(x, attention_mask)

        if not self.apply_post_layernorm:
            residual = x
        x = self.norm2(x)
        if self.apply_post_layernorm:
            residual = x
        x = residual + self.mlp(x)

        return x, attention_mask


@LAYERS.register_module
class GPTLMHead(nn.Module):
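    """Language-model head projecting hidden states to vocabulary logits.
    Passing word_embedding_weight ties the classifier to the input
    embedding matrix.
    """
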
    def __init__(self,
                 dim: int,
                 vocab_size: int,
                 word_embedding_weight: nn.Parameter = None,
                 bias: bool = False,
                 dtype: dtype = None) -> None:
        super().__init__()
        self.dense = col_nn.Classifier(dim, vocab_size, word_embedding_weight, bias=bias, dtype=dtype)

    @property
    def weight(self):
        return self.dense.weight

    def forward(self, x):
        x = self.dense(x)
        return x


@LOSSES.register_module
class GPTLMLoss(nn.Module):
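    """Next-token cross-entropy: the last logit and the first label are
    dropped so that position i predicts token i + 1.
    """
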
    def __init__(self):
        super().__init__()
        self.loss = col_nn.CrossEntropyLoss()

    def forward(self, logits, labels):
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        # Flatten the tokens
        return self.loss(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))


@MODELS.register_module
class GPT(nn.Module):
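    """Full GPT language model: embedding, depth transformer blocks, a final
    layernorm and a weight-tied LM head.
    """
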
    def __init__(self,
                 vocab_size: int = 50304,
                 max_position_embeddings: int = 1024,
                 dim: int = 768,
                 num_heads: int = 12,
                 depth: int = 12,
                 mlp_ratio: float = 4.0,
                 dropout: float = 0.1,
                 embedding_dropout: float = 0.1,
                 attention_dropout: float = 0.1,
                 layernorm_epsilon: float = 1e-5,
                 activation: Callable = nn.functional.gelu,
                 padding_idx: int = None,
                 dtype: dtype = None,
                 bias: bool = True,
                 apply_post_layernorm: bool = False,
                 fuse_scale_mask_softmax: bool = False,
                 checkpoint: bool = False) -> None:
        super().__init__()
        self.embed = GPTEmbedding(embedding_dim=dim,
                                  vocab_size=vocab_size,
                                  max_position_embeddings=max_position_embeddings,
                                  padding_idx=padding_idx,
                                  dropout=embedding_dropout,
                                  dtype=dtype)
        self.blocks = nn.ModuleList([
            GPTBlock(
                dim=dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                activation=activation,
                attention_dropout=attention_dropout,
                dropout=dropout,
                layernorm_epsilon=layernorm_epsilon,
                dtype=dtype,
                bias=bias,
                apply_post_layernorm=apply_post_layernorm,
                fuse_scale_mask_softmax=fuse_scale_mask_softmax,
                checkpoint=checkpoint,
            ) for _ in range(depth)
        ])

        self.norm = col_nn.LayerNorm(normalized_shape=dim, eps=layernorm_epsilon, dtype=dtype)

        self.head = GPTLMHead(dim=dim,
                              vocab_size=vocab_size,
                              word_embedding_weight=self.embed.word_embedding_weight,
                              dtype=dtype)

    def forward(self, input_ids, attention_mask=None):
        x, attention_mask = self.embed(input_ids, attention_mask)

        for block in self.blocks:
            x, attention_mask = block(x, attention_mask)

        x = self.head(self.norm(x))

        return x


class PipelineGPT(nn.Module):
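    """One pipeline stage of GPT. Only the first stage owns the embedding
    and only the last stage owns the final layernorm and LM head; weight
    tying between them is handled by the caller (see
    _create_gpt_pipeline_model).
    """
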
    def __init__(self,
                 vocab_size: int = 50304,
                 max_position_embeddings: int = 1024,
                 dim: int = 768,
                 num_heads: int = 12,
                 depth: int = 12,
                 mlp_ratio: float = 4.0,
                 dropout: float = 0.1,
                 embedding_dropout: float = 0.1,
                 attention_dropout: float = 0.1,
                 layernorm_epsilon: float = 1e-5,
                 activation: Callable = nn.functional.gelu,
                 padding_idx: int = None,
                 dtype: dtype = None,
                 bias: bool = True,
                 apply_post_layernorm: bool = False,
                 fuse_scale_mask_softmax: bool = False,
                 checkpoint: bool = False,
                 first: bool = False,
                 last: bool = False):
        super().__init__()
        self.checkpoint = checkpoint
        self.first = first
        self.last = last
        if first:
            self.embed = GPTEmbedding(embedding_dim=dim,
                                      vocab_size=vocab_size,
                                      max_position_embeddings=max_position_embeddings,
                                      padding_idx=padding_idx,
                                      dropout=embedding_dropout,
                                      dtype=dtype)
        self.blocks = nn.ModuleList([
            GPTBlock(
                dim=dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                activation=activation,
                attention_dropout=attention_dropout,
                dropout=dropout,
                layernorm_epsilon=layernorm_epsilon,
                dtype=dtype,
                bias=bias,
                apply_post_layernorm=apply_post_layernorm,
                fuse_scale_mask_softmax=fuse_scale_mask_softmax,
                checkpoint=checkpoint,
            ) for _ in range(depth)
        ])
        if self.last:
            self.norm = col_nn.LayerNorm(normalized_shape=dim, eps=layernorm_epsilon, dtype=dtype)
            self.head = GPTLMHead(dim=dim, vocab_size=vocab_size, dtype=dtype)

    def forward(self, x=None, input_ids=None, attention_mask=None):
        if self.first:
            x, attention_mask = self.embed(input_ids, attention_mask)

        for block in self.blocks:
            x, attention_mask = block(x, attention_mask)

        if self.last:
            x = self.head(self.norm(x))

        return x


def _create_gpt_model(**model_kwargs):
    model = GPT(**model_kwargs)
    return model


def _create_gpt_pipeline_model(depth=48, num_chunks=1, layer_partitions=None, **model_kwargs):
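    """Builds the PipelineGPT chunks owned by the local pipeline rank.

    depth layers are partitioned uniformly across pipeline stages (unless
    layer_partitions is given), and the word-embedding and LM-head weights
    are registered with a PipelineSharedModuleWrapper so the first and last
    stages share them.
    """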
    logger = get_dist_logger()
    pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE)
    pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
    rank = gpc.get_global_rank()
    wrapper = PipelineSharedModuleWrapper([0, pipeline_size - 1])
    parts = partition_uniform(depth, pipeline_size,
                              num_chunks)[pipeline_rank] if layer_partitions is None else layer_partitions
    models = []
    for start, end in parts:
        model_kwargs['first'] = start == 0
        model_kwargs['last'] = end == depth
        model_kwargs['depth'] = end - start
        chunk = PipelineGPT(**model_kwargs).to(get_current_device())
        if start == 0:
            wrapper.register_parameter(chunk.embed.word_embedding_weight)
        elif end == depth:
            wrapper.register_parameter(chunk.head.weight)
        models.append(chunk)
        logger.info(f'==> Rank {rank} built layer {start}-{end} / total {depth}')
    if len(models) == 1:
        model = models[0]
    else:
        model = nn.ModuleList(models)
    return model


@MODELS.register_module
def gpt2_small(**kwargs):
    model_kwargs = dict(dim=768, depth=12, num_heads=12, **kwargs)
    return _create_gpt_model(**model_kwargs)


@MODELS.register_module
def gpt2_medium(**kwargs):
    model_kwargs = dict(dim=1024, depth=24, num_heads=8, **kwargs)
    return _create_gpt_model(**model_kwargs)


@MODELS.register_module
def gpt2_large(**kwargs):
    model_kwargs = dict(dim=1536, depth=36, num_heads=12, **kwargs)
    return _create_gpt_model(**model_kwargs)


@MODELS.register_module
def gpt2_xl(**kwargs):
    model_kwargs = dict(dim=1600, depth=48, num_heads=16, **kwargs)
    return _create_gpt_model(**model_kwargs)


@MODELS.register_module
def gpt2_8B(**kwargs):
    model_kwargs = dict(dim=3072, depth=72, num_heads=24, **kwargs)
    return _create_gpt_model(**model_kwargs)


@MODELS.register_module
def gpt2_xl_pipeline(**kwargs):
    model_kwargs = dict(dim=1600, depth=48, num_heads=20, **kwargs)
    return _create_gpt_pipeline_model(**model_kwargs)


@MODELS.register_module
def gpt2_8B_pipeline(**kwargs):
    model_kwargs = dict(dim=3072, depth=72, num_heads=24, **kwargs)
    return _create_gpt_pipeline_model(**model_kwargs)


@MODELS.register_module
def gpt3(**kwargs):
    model_kwargs = dict(dim=12288, depth=96, num_heads=96, **kwargs)
    return _create_gpt_model(**model_kwargs)


@MODELS.register_module
def gpt3_pipeline(**kwargs):
    model_kwargs = dict(dim=12288, depth=96, num_heads=96, **kwargs)
    return _create_gpt_pipeline_model(**model_kwargs)
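

# Minimal usage sketch (assumes a ColossalAI distributed context has already
# been initialised, e.g. via colossalai.launch_from_torch, with a parallel
# configuration matching the chosen col_nn layers):
#
#   model = gpt2_small(dtype=torch.half, checkpoint=True)
#   criterion = GPTLMLoss()
#   input_ids = torch.randint(0, 50304, (4, 1024), device=get_current_device())
#   logits = model(input_ids)
#   loss = criterion(logits, input_ids)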