"""Train a SimCLR model on CIFAR10 using ColossalAI's Trainer and hooks."""
import colossalai
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
from colossalai.trainer import Trainer, hooks
from colossalai.utils import get_dataloader, MultiTimer
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR

from torchvision.datasets import CIFAR10
from NT_Xentloss import NT_Xentloss
from myhooks import TotalBatchsizeHook
from models.simclr import SimCLR
from augmentation import SimCLRTransform

def build_dataset_train():
    """Build the distributed dataloader over the CIFAR10 training split.

    Applies the paired SimCLR augmentation transform and reads dataset root
    and batch size from the global ColossalAI config (``gpc.config``).
    """
    transform = SimCLRTransform()
    dataset = CIFAR10(
        root=gpc.config.dataset.root,
        transform=transform,
        train=True,
        download=True,
    )
    # get_dataloader wires in the distributed sampler for us
    return get_dataloader(
        dataset=dataset,
        shuffle=True,
        num_workers=1,
        batch_size=gpc.config.BATCH_SIZE,
        pin_memory=True,
    )

def build_dataset_test():
    """Build the dataloader over the CIFAR10 test split.

    Uses the same SimCLR augmentation as training; no distributed sampler is
    attached (``add_sampler=False``) so every rank sees the full test set.
    """
    transform = SimCLRTransform()
    dataset = CIFAR10(
        root=gpc.config.dataset.root,
        transform=transform,
        train=False,
    )
    return get_dataloader(
        dataset=dataset,
        add_sampler=False,
        num_workers=1,
        batch_size=gpc.config.BATCH_SIZE,
        pin_memory=True,
    )

def main():
    """Entry point: set up ColossalAI, then train SimCLR on CIFAR10.

    Initializes the distributed environment from the torch launcher, builds
    the model/dataloaders/loss/optimizer/scheduler, wraps them in a ColossalAI
    engine, and runs the Trainer with logging, LR-scheduling, checkpoint and
    tensorboard hooks.
    """
    # initialize distributed environment; hyperparameters come from config.py
    colossalai.launch_from_torch(config='./config.py')

    # get logger
    logger = get_dist_logger()

    # build model
    model = SimCLR(model='resnet18')

    # build dataloader
    train_dataloader = build_dataset_train()
    test_dataloader = build_dataset_test()

    # build loss (normalized temperature-scaled cross entropy)
    criterion = NT_Xentloss()

    # build optimizer
    optimizer = colossalai.nn.FusedSGD(model.parameters(),
                                       lr=gpc.config.LEARNING_RATE,
                                       weight_decay=gpc.config.WEIGHT_DECAY,
                                       momentum=gpc.config.MOMENTUM)

    # lr scheduler: cosine decay over epochs with a 10-epoch warmup
    lr_scheduler = CosineAnnealingWarmupLR(optimizer, warmup_steps=10, total_steps=gpc.config.NUM_EPOCHS)

    engine, train_dataloader, test_dataloader, _ = colossalai.initialize(
        model, optimizer, criterion, train_dataloader, test_dataloader
    )
    logger.info("initialized colossalai components", ranks=[0])

    # build a timer to measure time
    timer = MultiTimer()

    # build trainer
    trainer = Trainer(engine=engine, logger=logger, timer=timer)

    # build hooks
    hook_list = [
        hooks.LossHook(),
        hooks.LogMetricByEpochHook(logger),
        hooks.LRSchedulerHook(lr_scheduler, by_epoch=True),
        TotalBatchsizeHook(),

        # comment if you do not need to use the hooks below
        hooks.SaveCheckpointHook(interval=50, checkpoint_dir=f'./ckpt/{gpc.config.LOG_NAME}'),
        hooks.TensorboardHook(log_dir=f'./tb_logs/{gpc.config.LOG_NAME}', ranks=[0]),
    ]

    # start training
    trainer.fit(
        train_dataloader=train_dataloader,
        test_dataloader=test_dataloader,
        epochs=gpc.config.NUM_EPOCHS,
        hooks=hook_list,
        display_progress=True,
        test_interval=1
    )


# script entry point: run training when executed (e.g. via torchrun / colossalai launcher)
if __name__ == '__main__':
    main()