Optimizer state sharding
========================

Using torch.nn.parallel.DistributedDataParallel together with OSS leads to some wasted communication, but it is possible, and it makes OSS a drop-in solution in your existing torch distributed code.
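
Both examples below assume a small dist_init helper that initializes the torch.distributed process group. A minimal sketch, assuming a single-machine setup (the address, port, and backend choices here are assumptions, adjust them to your environment):

.. code-block:: python

    import os

    import torch
    import torch.distributed

    def dist_init(rank: int, world_size: int) -> None:
        # assumed rendezvous settings for a single machine
        os.environ["MASTER_ADDR"] = "127.0.0.1"
        os.environ["MASTER_PORT"] = "29500"

        # nccl for GPU training, gloo otherwise
        backend = "nccl" if torch.cuda.is_available() else "gloo"
        torch.distributed.init_process_group(backend=backend, rank=rank, world_size=world_size)
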
Let's suppose that your trainer looks like this:

.. code-block:: python

    import torch
    import torch.distributed

    def train(
        rank: int,
        world_size: int,
        epochs: int):

        # DDP
        dist_init(rank, world_size)

        # Problem statement
        model = myAwesomeModel()              # placeholder: your model
        dataloader = mySuperFastDataloader()  # placeholder: your dataloader
        loss_fn = myVeryRelevantLoss()        # placeholder: your loss function

        base_optimizer_arguments = {"lr": 1e-3}  # any optimizer-specific arguments: LR (required by SGD), momentum, etc.
        optimizer = torch.optim.SGD(
            params=model.parameters(),
            **base_optimizer_arguments)

        # Any relevant training loop, nothing specific to OSS. For example:
        model.train()
        for e in range(epochs):
            for batch in dataloader:
                # Train
                model.zero_grad()
                outputs = model(batch["inputs"])
                loss = loss_fn(outputs, batch["label"])
                loss /= world_size
                loss.backward()
                # loss was scaled by 1/world_size, so a SUM reduction yields the average loss across ranks
                torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.SUM)
                optimizer.step()


Then sharding the optimizer state is merely a matter of wrapping your optimizer in fairscale.optim.OSS, as follows:

.. code-block:: python

    import torch
    import torch.distributed

    from fairscale.optim.oss import OSS

    def train(
        rank: int,
        world_size: int,
        epochs: int):

        # DDP
        dist_init(rank, world_size)

        # Problem statement
        model = myAwesomeModel()              # placeholder: your model
        dataloader = mySuperFastDataloader()  # placeholder: your dataloader
        loss_fn = myVeryRelevantLoss()        # placeholder: your loss function

        base_optimizer_arguments = {"lr": 1e-3}  # any optimizer-specific arguments: LR (required by SGD), momentum, etc.

        # ** NEW ** Wrap a base optimizer into OSS
        base_optimizer = torch.optim.SGD  # any PyTorch-compliant optimizer
        optimizer = OSS(
            params=model.parameters(),
            optim=base_optimizer,
            **base_optimizer_arguments)

        # Any relevant training loop, nothing specific to OSS. For example:
        model.train()
        for e in range(epochs):
            for batch in dataloader:
                # Train
                model.zero_grad()
                outputs = model(batch["inputs"])
                loss = loss_fn(outputs, batch["label"])
                loss /= world_size
                loss.backward()
                # loss was scaled by 1/world_size, so a SUM reduction yields the average loss across ranks
                torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.SUM)
                optimizer.step()
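

Either version can then be launched with one process per rank. A minimal sketch using torch.multiprocessing.spawn (WORLD_SIZE and EPOCHS are placeholder values):

.. code-block:: python

    import torch.multiprocessing as mp

    WORLD_SIZE = 2  # placeholder: number of ranks/processes
    EPOCHS = 10     # placeholder: number of training epochs

    if __name__ == "__main__":
        # spawn passes the process index (the rank) as the first argument to train
        mp.spawn(train, args=(WORLD_SIZE, EPOCHS), nprocs=WORLD_SIZE, join=True)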