Optimizer state sharding
========================

Using ``torch.nn.parallel.DistributedDataParallel`` leads to some wasted communications in the case of OSS, but it is possible, and it makes OSS a drop-in solution in your existing torch distributed code.

Let's suppose that your trainer looks like this:

.. code-block:: python


    import torch

    def train(
        rank: int,
        world_size: int,
        epochs: int):

        # DDP
        dist_init(rank, world_size)

        # Problem statement
        model = myAwesomeModel().to(rank)
        dataloader = mySuperFastDataloader()
        loss_fn = myVeryRelevantLoss()

        # Optimizer-specific arguments, e.g. LR, momentum, etc.
        base_optimizer_arguments = {"lr": 1e-4}
        optimizer = torch.optim.SGD(
            params=model.parameters(),
            **base_optimizer_arguments)

        # Any relevant training loop, nothing specific to OSS. For example:
        model.train()
        for e in range(epochs):
            for (data, target) in dataloader:
                data, target = data.to(rank), target.to(rank)
                # Train
                model.zero_grad()
                outputs = model(data)
                loss = loss_fn(outputs, target)
                loss /= world_size
                loss.backward()
                torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.SUM)
                optimizer.step()

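Note that ``dist_init`` above is a placeholder, not something provided by fairscale: it only needs to initialize the ``torch.distributed`` process group. A minimal sketch of such a helper, assuming a single node and an arbitrarily chosen free port, could look like this:

.. code-block:: python


    import torch.distributed as dist

    def dist_init(rank: int, world_size: int, backend: str = "gloo") -> None:
        # Initialize the default process group. The backend and init_method used
        # here are illustrative assumptions; adjust them to your actual setup
        # (e.g. "nccl" with one GPU per rank).
        dist.init_process_group(
            backend=backend,
            init_method="tcp://localhost:29501",
            rank=rank,
            world_size=world_size,
        )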

Then sharding the optimizer state is merely a matter of wrapping your optimizer in ``fairscale.optim.OSS``, as follows:

.. code-block:: python


    import torch
    from fairscale.optim.oss import OSS

    def train(
        rank: int,
        world_size: int,
        epochs: int):

        # DDP
        dist_init(rank, world_size)

        # Problem statement
        model = myAwesomeModel().to(rank)
        dataloader = mySuperFastDataloader()
        loss_fn = myVeryRelevantLoss()

        # Optimizer-specific arguments, e.g. LR, momentum, etc.
        base_optimizer_arguments = {"lr": 1e-4}

        # ** NEW ** Wrap a base optimizer into OSS
        base_optimizer = torch.optim.SGD  # any pytorch compliant optimizer
        optimizer = OSS(
            params=model.parameters(),
            optim=base_optimizer,
            **base_optimizer_arguments)

        # Any relevant training loop, nothing specific to OSS. For example:
        model.train()
        for e in range(epochs):
            for (data, target) in dataloader:
                data, target = data.to(rank), target.to(rank)
                # Train
                model.zero_grad()
                outputs = model(data)
                loss = loss_fn(outputs, target)
                loss /= world_size
                loss.backward()
                torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.SUM)
                optimizer.step()

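As mentioned at the top of this page, OSS is also a drop-in when the model itself is wrapped in ``torch.nn.parallel.DistributedDataParallel``. A minimal sketch of that variation, assuming a single node with one GPU per rank and reusing the placeholder ``myAwesomeModel`` from above, could look like this:

.. code-block:: python


    import torch
    from torch.nn.parallel import DistributedDataParallel as DDP
    from fairscale.optim.oss import OSS

    # Inside train(), after dist_init(rank, world_size):
    model = myAwesomeModel().to(rank)
    model = DDP(model, device_ids=[rank])  # DDP handles the gradient all_reduce

    optimizer = OSS(
        params=model.parameters(),
        optim=torch.optim.SGD,
        lr=1e-4)

The training loop itself does not need OSS-specific changes: DDP averages the gradients during ``backward()``, while OSS shards the optimizer state across ranks.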

The above ``train`` function will then need to be run via ``torch.multiprocessing.spawn``.

.. code-block:: python


    import torch.multiprocessing as mp

    mp.spawn(
        train,
        args=(WORLD_SIZE, EPOCHS),
        nprocs=WORLD_SIZE,
        join=True,
    )

To see it in action, you can test it with the following script: `tutorial_oss.py <../../../examples/tutorial_oss.py>`_.