Optimizer state sharding
========================

Using `torch.nn.parallel.DistributedDataParallel` leads to some wasted communications in the case of OSS, but it is possible, and it makes OSS a drop-in solution for your existing torch distributed code.
Let's suppose that your trainer looks like this:

.. code-block:: python


    import torch
    from torch.nn.parallel import DistributedDataParallel as DDP


    def train(
        rank: int,
        world_size: int,
        epochs: int):

        # DDP
        dist_init(rank, world_size)

        # Problem statement
        model = myAwesomeModel().to(rank)
        model = DDP(model, device_ids=[rank])
        dataloader = mySuperFastDataloader()
        loss_fn = myVeryRelevantLoss()

        # optimizer-specific arguments, e.g. LR, momentum, etc.
        base_optimizer_arguments = {"lr": 1e-4}
        optimizer = torch.optim.SGD(
            params=model.parameters(),
            **base_optimizer_arguments)

        # Any relevant training loop, nothing specific to OSS. For example:
        model.train()
        for e in range(epochs):
            for (data, target) in dataloader:
                data, target = data.to(rank), target.to(rank)
                # Train
                model.zero_grad()
                outputs = model(data)
                loss = loss_fn(outputs, target)
                loss.backward()
                optimizer.step()


Sharding the optimizer state is then merely a matter of wrapping your base optimizer in `fairscale.optim.OSS`, as follows:

.. code-block:: python


    import torch
    from fairscale.optim.oss import OSS
    from torch.nn.parallel import DistributedDataParallel as DDP

    def train(
        rank: int,
        world_size: int,
        epochs: int):

        # DDP
        dist_init(rank, world_size)

        # Problem statement
        model = myAwesomeModel().to(rank)
        model = DDP(model, device_ids=[rank])
        dataloader = mySuperFastDataloader()
        loss_fn = myVeryRelevantLoss()

        # optimizer-specific arguments, e.g. LR, momentum, etc.
        base_optimizer_arguments = {"lr": 1e-4}

        # ** NEW ** Wrap a base optimizer into OSS
        base_optimizer = torch.optim.SGD  # any PyTorch-compliant optimizer
        optimizer = OSS(
            params=model.parameters(),
            optim=base_optimizer,
            **base_optimizer_arguments)

        # Any relevant training loop, nothing specific to OSS. For example:
        model.train()
        for e in range(epochs):
            for (data, target) in dataloader:
                data, target = data.to(rank), target.to(rank)
                # Train
                model.zero_grad()
                outputs = model(data)
                loss = loss_fn(outputs, target)
                loss.backward()
                optimizer.step()


The above `train` function then needs to be launched on each rank, for instance via `torch.multiprocessing.spawn`:

.. code-block:: python


    import torch.multiprocessing as mp

    mp.spawn(
        train,
        args=(WORLD_SIZE, EPOCHS),
        nprocs=WORLD_SIZE,
        join=True,
    )
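
The `dist_init` helper used in the snippets above is assumed to set up the default process group, and `WORLD_SIZE` / `EPOCHS` are user-chosen constants (typically one process per GPU). Its exact implementation is not shown in this tutorial; a minimal single-node sketch, with the master address and port as illustrative assumptions, could look like this:

.. code-block:: python


    import os

    import torch
    import torch.distributed as dist


    def dist_init(rank: int, world_size: int) -> None:
        # Single-node assumption: all ranks rendezvous on localhost.
        os.environ.setdefault("MASTER_ADDR", "localhost")
        os.environ.setdefault("MASTER_PORT", "29500")

        # Use NCCL when each rank owns a GPU, fall back to gloo otherwise.
        backend = "nccl" if torch.cuda.is_available() else "gloo"
        dist.init_process_group(backend=backend, rank=rank, world_size=world_size)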


To see it in action, you can try the full example script `here <../../../examples/tutorial_oss.py>`_.

Using PyTorch Automatic Mixed Precision is possible, but it requires a shard-aware GradScaler, which is available in
`fairscale.optim.grad_scaler`. Autocast can be used as is, and the loss will be scaled and handled in the same way.
See `the original documentation <https://pytorch.org/docs/stable/notes/amp_examples.html?highlight=automatic%20mixed%20precision>`_
for more information.

.. code-block:: python



    import torch.optim as optim
    from torch.cuda.amp import autocast

    from fairscale.optim.grad_scaler import ShardedGradScaler


    # Creates model and optimizer in default precision
    model = Net().cuda()
    optimizer = optim.SGD(model.parameters(), ...)

    # Creates a ShardedGradScaler once at the beginning of training.
    scaler = ShardedGradScaler()

    for epoch in epochs:
        for input, target in data:
            optimizer.zero_grad()

            # Runs the forward pass with autocasting.
            with autocast():
                output = model(input)
                loss = loss_fn(output, target)

            # Scales loss.  Calls backward() on scaled loss to create scaled gradients.
            # Backward passes under autocast are not recommended.
            # Backward ops run in the same dtype autocast chose for corresponding forward ops.
            scaler.scale(loss).backward()

            # scaler.step() first unscales the gradients of the optimizer's assigned params.
            # If these gradients do not contain infs or NaNs, optimizer.step() is then called,
            # otherwise, optimizer.step() is skipped.
            scaler.step(optimizer)

            # Updates the scale for next iteration.
            scaler.update()
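
The recipe above uses a plain `optim.SGD` optimizer, mirroring the upstream PyTorch example; with OSS, the sharded optimizer is passed to the scaler in the same way. Below is a minimal sketch combining the two, assuming it runs inside the same `train(rank, ...)` function and reuses the placeholders (`myAwesomeModel`, `mySuperFastDataloader`, `myVeryRelevantLoss`) from the earlier examples:

.. code-block:: python


    import torch
    from torch.cuda.amp import autocast
    from torch.nn.parallel import DistributedDataParallel as DDP

    from fairscale.optim.grad_scaler import ShardedGradScaler
    from fairscale.optim.oss import OSS

    # Same placeholders as in the DDP examples above.
    model = DDP(myAwesomeModel().to(rank), device_ids=[rank])
    dataloader = mySuperFastDataloader()
    loss_fn = myVeryRelevantLoss()

    # Sharded optimizer plus the shard-aware scaler.
    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-4)
    scaler = ShardedGradScaler()

    for data, target in dataloader:
        data, target = data.to(rank), target.to(rank)
        optimizer.zero_grad()

        # Forward pass under autocast, as in the recipe above.
        with autocast():
            loss = loss_fn(model(data), target)

        # Scaled backward, shard-aware unscale/inf check, then the sharded step.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()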