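# GPT-2 language-model test component: model presets, dummy token data and a
# next-token loss, registered as "gpt2" in non_distributed_component_funcs.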
import torch
import torch.nn as nn
from transformers import GPT2Config, GPT2LMHeadModel
from colossalai.utils.cuda import get_current_device

from .registry import non_distributed_component_funcs
from .utils.dummy_data_generator import DummyDataGenerator


class DummyDataLoader(DummyDataGenerator):
    """Random token-id batches for language-model smoke tests.

    The batch geometry lives in class attributes. ``generate`` reads them
    through ``type(self)``, so a subclass can override any of them without
    re-implementing ``generate``.
    """

    vocab_size = 128  # token ids are drawn uniformly from [0, vocab_size)
    batch_size = 4
    seq_len = 64

    def generate(self):
        """Return one ``(input_ids, labels)`` pair.

        ``input_ids`` is a ``(batch_size, seq_len)`` tensor of random token ids
        allocated on ``get_current_device()``. The same tensor is returned as
        the labels: no shifting happens here (``GPTLMLoss`` in this module
        performs the next-token shift itself).
        """
        cls = type(self)
        input_ids = torch.randint(
            0,
            cls.vocab_size,
            (cls.batch_size, cls.seq_len),
            device=get_current_device(),
        )
        return input_ids, input_ids


class GPTLMModel(nn.Module):
    """GPT-2 causal language model whose forward pass returns only the LM logits.

    Wraps Hugging Face ``GPT2LMHeadModel``, built from a ``GPT2Config`` with
    every dropout probability (residual, embedding, attention) set to 0.0, so
    dropout adds no randomness to forward passes.

    Args:
        hidden_size: embedding / hidden width (``n_embd``).
        num_layers: number of transformer blocks (``n_layer``).
        num_attention_heads: attention heads per block (``n_head``).
        max_seq_len: maximum sequence length (``n_positions``; also passed
            as ``n_ctx``).
        vocab_size: vocabulary size. NOTE(review): the default 50304 is
            presumably GPT-2's 50257 tokens padded to a multiple of 64 — confirm.
        checkpoint: if True, enable Hugging Face gradient checkpointing and
            call the model with ``use_cache=False``.
    """

    def __init__(
        self,
        hidden_size=768,
        num_layers=12,
        num_attention_heads=12,
        max_seq_len=1024,
        vocab_size=50304,
        checkpoint=False,
    ):
        super().__init__()
        self.checkpoint = checkpoint
        self.model = GPT2LMHeadModel(
            GPT2Config(
                n_embd=hidden_size,
                n_layer=num_layers,
                n_head=num_attention_heads,
                n_positions=max_seq_len,
                n_ctx=max_seq_len,
                vocab_size=vocab_size,
                resid_pdrop=0.0,
                embd_pdrop=0.0,
                attn_pdrop=0.0,
            )
        )
        if checkpoint:
            self.model.gradient_checkpointing_enable()

    def forward(self, input_ids):
        """Return the LM logits for ``input_ids``.

        An all-ones attention mask is passed, so every position is attended to
        (no padding is masked). The key/value cache is requested only when
        gradient checkpointing is disabled.
        """
        # Only return lm_logits (the first element of the model output).
        attention_mask = torch.ones_like(input_ids)
        return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=not self.checkpoint)[0]


def gpt2_micro(checkpoint=True):
    """Build a tiny ``GPTLMModel`` (32 hidden, 2 layers, 4 heads) for fast tests.

    The 128-token vocabulary and 64-token maximum length match the
    ``vocab_size`` and ``seq_len`` class attributes of ``DummyDataLoader``.

    Args:
        checkpoint: forwarded to ``GPTLMModel`` (gradient checkpointing).
    """
    return GPTLMModel(
        checkpoint=checkpoint, hidden_size=32, num_layers=2, num_attention_heads=4, max_seq_len=64, vocab_size=128
    )


def gpt2_s(checkpoint=True):
    """Build a ``GPTLMModel`` with its default (small) size.

    Args:
        checkpoint: forwarded to ``GPTLMModel`` (gradient checkpointing).
    """
    small_model = GPTLMModel(checkpoint=checkpoint)
    return small_model


def gpt2_m(checkpoint=True):
    """Build a medium ``GPTLMModel``: 1024 hidden, 24 layers, 16 heads.

    Args:
        checkpoint: forwarded to ``GPTLMModel`` (gradient checkpointing).
    """
    medium_sizes = dict(hidden_size=1024, num_layers=24, num_attention_heads=16)
    return GPTLMModel(checkpoint=checkpoint, **medium_sizes)


class GPTLMLoss(nn.Module):
    """Next-token cross-entropy for causal language models.

    The logit at position ``t`` is scored against the label at ``t + 1``:
    the final logit and the first label are dropped before the loss is taken.
    """

    def __init__(self):
        super().__init__()
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, logits, labels):
        """Compute the shifted cross-entropy.

        Args:
            logits: scores whose last axis indexes the vocabulary and whose
                second-to-last axis is the sequence axis (presumably
                ``(batch, seq_len, vocab_size)`` — confirm with callers).
            labels: token ids shaped like ``logits`` without the vocabulary axis.

        Returns:
            Scalar tensor: ``nn.CrossEntropyLoss`` with its default settings
            (mean reduction) over all shifted token positions.
        """
        vocab = logits.size(-1)
        # Align each prediction with the *next* token, then fold every leading
        # axis into a single token axis; reshape copies the strided slices.
        predicted = logits[..., :-1, :].reshape(-1, vocab)
        target = labels[..., 1:].reshape(-1)
        return self.loss_fn(predicted, target)


@non_distributed_component_funcs.register(name="gpt2")
def get_training_components():
    """Return the pieces of the ``"gpt2"`` test component.

    Returns:
        tuple: ``(model_builder, trainloader, testloader, optimizer_class,
        criterion)`` where

        - ``model_builder`` is the ``gpt2_micro`` factory function (not a
          model instance);
        - ``trainloader`` and ``testloader`` are two separate
          ``DummyDataLoader`` instances;
        - ``optimizer_class`` is the ``torch.optim.Adam`` class (presumably
          instantiated by the caller with the model's parameters);
        - ``criterion`` is a ``GPTLMLoss`` instance.
    """
    trainloader = DummyDataLoader()
    testloader = DummyDataLoader()

    criterion = GPTLMLoss()
    return gpt2_micro, trainloader, testloader, torch.optim.Adam, criterion