Optimizer state sharding
========================

Using torch.nn.parallel.DistributedDataParallel with OSS leads to some wasted communication, but it is possible, and it makes OSS a drop-in solution for your existing torch distributed code.
Let's suppose that your trainer looks like this:

.. code-block:: python


    import torch
    from torch.nn.parallel import DistributedDataParallel as DDP


    def train(
        rank: int,
        world_size: int,
        epochs: int):

        # DDP
        dist_init(rank, world_size)

        # Problem statement
        model = myAwesomeModel().to(rank)
        model = DDP(model, device_ids=[rank])
        dataloader = mySuperFastDataloader()
        loss_fn = myVeryRelevantLoss()

        # optimizer specific arguments e.g. LR, momentum, etc...
        base_optimizer_arguments = {"lr": 1e-4}
        optimizer = torch.optim.SGD(
            params=model.parameters(),
            **base_optimizer_arguments)

        # Any relevant training loop, nothing specific to OSS. For example:
        model.train()
        for e in range(epochs):
            for (data, target) in dataloader:
                data, target = data.to(rank), target.to(rank)
                # Train
                model.zero_grad()
                outputs = model(data)
                loss = loss_fn(outputs, target)
                loss.backward()
                optimizer.step()


Then sharding the optimizer state is merely a matter of wrapping your optimizer in fairscale.optim.OSS, as follows:

.. code-block:: python


    import torch
    from fairscale.optim.oss import OSS
    from torch.nn.parallel import DistributedDataParallel as DDP

    def train(
        rank: int,
        world_size: int,
        epochs: int):

        # DDP
        dist_init(rank, world_size)

        # Problem statement
        model = myAwesomeModel().to(rank)
        model = DDP(model, device_ids=[rank])
        dataloader = mySuperFastDataloader()
        loss_fn = myVeryRelevantLoss()

        # optimizer specific arguments e.g. LR, momentum, etc...
        base_optimizer_arguments = {"lr": 1e-4}

        # ** NEW ** Wrap a base optimizer into OSS
        base_optimizer = torch.optim.SGD  # any pytorch compliant optimizer
        optimizer = OSS(
            params=model.parameters(),
            optim=base_optimizer,
            **base_optimizer_arguments)

        # Any relevant training loop, nothing specific to OSS. For example:
        model.train()
        for e in range(epochs):
            for (data, target) in dataloader:
                data, target = data.to(rank), target.to(rank)
                # Train
                model.zero_grad()
                outputs = model(data)
                loss = loss_fn(outputs, target)
                loss /= world_size
                loss.backward()
                optimizer.step()
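
Since each rank now only holds a shard of the optimizer state, saving a checkpoint takes one extra step. A minimal sketch, assuming a fairscale version which exposes `OSS.consolidate_state_dict` (the checkpoint path is a placeholder):

.. code-block:: python

    # Gather the sharded optimizer state onto rank 0 before saving;
    # after this call, state_dict() returns the full state on the recipient rank
    optimizer.consolidate_state_dict(recipient_rank=0)
    if rank == 0:
        torch.save(optimizer.state_dict(), "optimizer_checkpoint.pt")  # placeholder path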


The above `train` function then needs to be run through a `torch.multiprocessing.spawn` call, one process per rank:

.. code-block:: python


    import torch.multiprocessing as mp

    # mp.spawn passes the process rank as the first argument to train;
    # WORLD_SIZE and EPOCHS are constants defined by the caller
    mp.spawn(
        train,
        args=(WORLD_SIZE, EPOCHS),
        nprocs=WORLD_SIZE,
        join=True,
    )
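
The `dist_init` helper used in the snippets above is not defined here; a minimal sketch, assuming a single-node run with environment-variable rendezvous (the address, port and backend choices are placeholders):

.. code-block:: python

    import os

    import torch.distributed as dist


    def dist_init(rank: int, world_size: int) -> None:
        # Placeholder rendezvous settings for a single-node run
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = "29501"

        # NCCL is the usual backend for multi-GPU training; use "gloo" on CPU
        dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)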

To see it in action, you can test it with the following script: `tutorial_oss.py <../../../examples/tutorial_oss.py>`_