AdaScale SGD Tutorial
=====================

`AdaScale <https://arxiv.org/pdf/2007.05105.pdf>`_ adaptively scales the learning rate when
using larger batch sizes for data-parallel training. Let's suppose that your trainer looks
like the following.

.. code-block:: python


    import torch
    from torch.nn.parallel import DistributedDataParallel as DDP


    def train(
        rank: int,
        world_size: int,
        epochs: int):

        # DDP
        dist_init(rank, world_size)

        # Problem statement
        model = myAwesomeModel().to(rank)
        model = DDP(model, device_ids=[rank])
        dataloader = mySuperFastDataloader()
        loss_fn = myVeryRelevantLoss()

        # optimizer specific arguments e.g. LR, momentum, etc...
        base_optimizer_arguments = { "lr": 1e-4}
        optimizer = torch.optim.SGD(
            params=model.parameters(),
            **base_optimizer_arguments)
        scheduler = torch.optim.lr_scheduler.LambdaLR(
            optimizer, lr_lambda=lambda x: 1 / 10 ** x)

        # Any relevant training loop. For example:
        model.train()
        for e in range(epochs):
            for (data, target) in dataloader:
                data, target = data.to(rank), target.to(rank)
                # Train
                model.zero_grad()
                outputs = model(data)
                loss = loss_fn(outputs, target)
                loss.backward()
                optimizer.step()
            scheduler.step()


Applying AdaScale is as simple as wrapping your SGD optimizer with
`fairscale.optim.AdaScale`, as shown below. The training loop then uses the
optimizer's gain() to advance the effective step count and to drive the
learning rate schedule accordingly.

.. code-block:: python


    import torch
    from fairscale.optim.adascale import AdaScale
    from torch.nn.parallel import DistributedDataParallel as DDP


    def train(
        rank: int,
        world_size: int,
        epochs: int):

        # DDP
        dist_init(rank, world_size)

        # Problem statement
        model = myAwesomeModel().to(rank)
        model = DDP(model, device_ids=[rank])
        dataloader = mySuperFastDataloader()
        loss_fn = myVeryRelevantLoss()

        # optimizer specific arguments e.g. LR, momentum, etc...
        base_optimizer_arguments = { "lr": 1e-4}
        optimizer = torch.optim.SGD(
            params=model.parameters(),
            **base_optimizer_arguments)
        scheduler = torch.optim.lr_scheduler.LambdaLR(
            optimizer, lr_lambda=lambda x: 1 / 10 ** x)

        # Wrap optimizer with AdaScale
        optimizer = AdaScale(optimizer)

        # Any relevant training loop. For example:
        model.train()
        # Track progress in effective steps: each optimizer.step() advances
        # `step` by the current AdaScale gain rather than by 1.
        last_epoch = 0
        step = 0
        done = False
        while not done:
            for (data, target) in dataloader:
                data, target = data.to(rank), target.to(rank)
                # Train
                model.zero_grad()
                outputs = model(data)
                loss = loss_fn(outputs, target)
                loss.backward()
                # Accumulate effective steps using the AdaScale gain
                step += optimizer.gain()
                optimizer.step()
                # Convert effective steps into effective epochs and step the
                # LR scheduler once per effective epoch.
                epoch = step // len(dataloader)
                if last_epoch != epoch:
                    scheduler.step()
                    last_epoch = epoch
                if epoch >= epochs:
                    done = True
                    break
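
Note that the second listing replaces the fixed `for e in range(epochs)` loop with a
while-loop over effective epochs, since the number of real iterations needed now depends
on the accumulated gain.

Both listings assume project-specific helpers such as `dist_init`, `myAwesomeModel`,
`mySuperFastDataloader` and `myVeryRelevantLoss`. As a rough, hypothetical sketch of how
such a trainer might be launched, the snippet below initializes the default process group
and spawns one process per GPU with `torch.multiprocessing.spawn`; the `dist_init` body,
the `nccl` backend and the address/port values are assumptions to adapt to your setup.

.. code-block:: python


    import os

    import torch
    import torch.distributed as dist
    import torch.multiprocessing as mp


    def dist_init(rank: int, world_size: int):
        # Hypothetical helper: join the default process group. The address,
        # port and backend below are placeholders, not fairscale requirements.
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = "29500"
        dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)


    if __name__ == "__main__":
        world_size = torch.cuda.device_count()
        epochs = 10  # assumed value for illustration
        # Spawn one training process per GPU; `train` is the function defined above.
        mp.spawn(train, args=(world_size, epochs), nprocs=world_size, join=True)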