import torch
import transformers
from packaging import version
from torch.utils.data import SequentialSampler
from transformers import BertConfig, BertForSequenceClassification

from .registry import non_distributed_component_funcs


def get_bert_data_loader(
        n_class,
        batch_size,
        total_samples,
        sequence_length,
        device=torch.device('cpu:0'),
        is_distributed=False,
):
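    """Create a DataLoader over synthetic data for BERT.

    Inputs are random token ids in [0, n_class) with shape
    (total_samples, sequence_length); labels are random binary classes.
    A DistributedSampler is used when is_distributed is True, otherwise
    a SequentialSampler.
    """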
    train_data = torch.randint(
        low=0,
        high=n_class,
        size=(total_samples, sequence_length),
        device=device,
        dtype=torch.long,
    )
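    # Binary labels (two classes) for the sequence-classification head.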
    train_label = torch.randint(low=0, high=2, size=(total_samples,), device=device, dtype=torch.long)
    train_dataset = torch.utils.data.TensorDataset(train_data, train_label)
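    # Shard the dataset across ranks when running distributed training.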
    if is_distributed:
        sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    else:
        sampler = SequentialSampler(train_dataset)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, sampler=sampler)
    return train_loader


@non_distributed_component_funcs.register(name='bert')
def get_training_components():
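    """Build the components needed to train a tiny BERT classifier in unit tests.

    Returns a (model_builder, trainloader, testloader, optimizer_class, criterion)
    tuple; criterion is None because ModelAdaptor.forward already returns the loss.
    """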
    hidden_dim = 8
    num_head = 4
    sequence_length = 12
    num_layer = 2
    vocab_size = 32

    def bert_model_builder(checkpoint: bool = False):
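        """Build the tiny BERT model, optionally with gradient checkpointing."""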
        config = BertConfig(vocab_size=vocab_size,
                            gradient_checkpointing=checkpoint,
                            hidden_size=hidden_dim,
                            intermediate_size=hidden_dim * 4,
                            num_attention_heads=num_head,
                            max_position_embeddings=sequence_length,
                            num_hidden_layers=num_layer,
                            hidden_dropout_prob=0.,
                            attention_probs_dropout_prob=0.)
        print('building BertForSequenceClassification model')

        # Adapt huggingface's BertForSequenceClassification to the single
        # (input_ids, labels) -> loss calling interface used by the unit tests.
        class ModelAdaptor(BertForSequenceClassification):

            def forward(self, input_ids, labels):
                """
                inputs: data, label
                outputs: loss
                """
                return super().forward(input_ids=input_ids, labels=labels)[0]

        model = ModelAdaptor(config)
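        # transformers >= 4.11 deprecated the config-level flag in favour of
        # gradient_checkpointing_enable(), hence the explicit version gate.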
        if checkpoint and version.parse(transformers.__version__) >= version.parse("4.11.0"):
            model.gradient_checkpointing_enable()

        return model

    is_distributed = torch.distributed.is_initialized()
    trainloader = get_bert_data_loader(n_class=vocab_size,
                                       batch_size=2,
                                       total_samples=10000,
                                       sequence_length=sequence_length,
                                       is_distributed=is_distributed)
    testloader = get_bert_data_loader(n_class=vocab_size,
                                      batch_size=2,
                                      total_samples=10000,
                                      sequence_length=sequence_length,
                                      is_distributed=is_distributed)

    criterion = None
    return bert_model_builder, trainloader, testloader, torch.optim.Adam, criterion
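

# Minimal usage sketch (not executed by the test suite). It assumes the
# registry decorator returns the decorated function unchanged and that this
# module is imported as part of its package; the learning rate is arbitrary.
#
#   model_builder, trainloader, _, optimizer_class, _ = get_training_components()
#   model = model_builder(checkpoint=False)
#   optimizer = optimizer_class(model.parameters(), lr=1e-3)
#   input_ids, labels = next(iter(trainloader))
#   loss = model(input_ids, labels)
#   loss.backward()
#   optimizer.step()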