# albert.py
import torch
from transformers import AlbertConfig, AlbertForSequenceClassification

from .bert import get_bert_data_loader
from .registry import non_distributed_component_funcs


@non_distributed_component_funcs.register(name="albert")
def get_training_components():
    """Build the tiny ALBERT sequence-classification setup registered as ``"albert"``.

    All model and data sizes are deliberately small so the components are cheap
    to build in unit tests.

    Returns:
        A 5-tuple ``(model_builder, trainloader, testloader, optim_class, criterion)``:

        * ``model_builder(checkpoint=False)`` -- callable that returns a freshly
          constructed ``AlbertForSequenceClassification`` subclass whose
          ``forward`` takes ``(input_ids, labels)`` and returns only the first
          element of the HuggingFace output (the loss).
        * ``trainloader`` / ``testloader`` -- two loaders created by
          ``get_bert_data_loader`` with identical settings.
        * ``optim_class`` -- ``torch.optim.Adam`` (the class, not an instance).
        * ``criterion`` -- ``None``; the adapted model's forward already
          returns the loss.
    """
    hidden_dim = 8
    num_head = 4
    sequence_length = 12
    num_layer = 2
    vocab_size = 32

    def bert_model_builder(checkpoint: bool = False):
        """Instantiate a tiny, dropout-free ALBERT classifier with fresh weights.

        Args:
            checkpoint: forwarded to ``AlbertConfig`` as ``gradient_checkpointing``.

        Returns:
            A ``ModelAdaptor`` instance (subclass of
            ``AlbertForSequenceClassification``).
        """
        config = AlbertConfig(
            vocab_size=vocab_size,
            # NOTE(review): newer transformers releases deprecate this config
            # kwarg in favour of ``model.gradient_checkpointing_enable()`` --
            # confirm behaviour against the pinned transformers version.
            gradient_checkpointing=checkpoint,
            hidden_size=hidden_dim,
            intermediate_size=hidden_dim * 4,
            num_attention_heads=num_head,
            max_position_embeddings=sequence_length,
            num_hidden_layers=num_layer,
            # Dropout disabled, presumably so forward passes are reproducible
            # across the runs a test compares.
            hidden_dropout_prob=0.0,
            attention_probs_dropout_prob=0.0,
        )
        print("building AlbertForSequenceClassification model")

        # Adapt HuggingFace's AlbertForSequenceClassification to the shared
        # unit-test calling convention: model(input_ids, labels) -> loss.
        class ModelAdaptor(AlbertForSequenceClassification):
            def forward(self, input_ids, labels):
                """
                inputs: data, label
                outputs: loss
                """
                return super().forward(input_ids=input_ids, labels=labels)[0]

        return ModelAdaptor(config)

    # On PyTorch builds without distributed support, torch.distributed exposes
    # no APIs beyond is_available(); check it first so such builds fall back to
    # a non-distributed loader instead of raising.
    is_distributed = (
        torch.distributed.is_available() and torch.distributed.is_initialized()
    )

    # Train and test loaders intentionally share identical settings; vocab_size
    # is passed as n_class (presumably the token-id range -- confirm in .bert).
    loader_kwargs = dict(
        n_class=vocab_size,
        batch_size=2,
        total_samples=10000,
        sequence_length=sequence_length,
        is_distributed=is_distributed,
    )
    trainloader = get_bert_data_loader(**loader_kwargs)
    testloader = get_bert_data_loader(**loader_kwargs)

    # No external criterion: the adaptor's forward already returns the loss.
    criterion = None
    return bert_model_builder, trainloader, testloader, torch.optim.Adam, criterion