"docs/vscode:/vscode.git/clone" did not exist on "265c09d80e22c7c26063770ea2d11c0095270fc0"
oss.rst 2.58 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
Optimizer state sharding
========================

Using torch.nn.parallel.DistributedDataParallel alongside OSS leads to some wasted communication, since gradients are still all-reduced in full even though each rank only updates the parameter shard it owns, but it works out of the box and makes OSS a drop-in solution for your existing torch distributed code.
Let's suppose that your trainer looks like this:

.. code-block:: python


    import torch

    def train(
        rank: int,
        world_size: int,
        epochs: int):

        # DDP
        dist_init(rank, world_size)

        # Problem statement
        model = myAwesomeModel()
        dataloader = mySuperFastDataloader()
        loss_fn = myVeryRelevantLoss()

        base_optimizer_arguments = {} # any optimizer specific arguments, LR, momentum, etc...
        optimizer = torch.optim.SGD(params=model.parameters(), **base_optimizer_arguments)

        # Any relevant training loop, nothing specific to OSS. For example:
        model.train()
        for e in range(epochs):
            for batch in dataloader:
                # Train
                model.zero_grad()
                outputs = model(batch["inputs"])
                loss = loss_fn(outputs, batch["label"])
                torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.SUM)
                loss /= world_size
                loss.backward()
                optimizer.step()

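The dist_init call above is a placeholder for your own process group setup. A minimal sketch, assuming a single-node run where the rendezvous is configured through environment variables (the backend and addresses are illustrative choices, not part of the OSS API):

.. code-block:: python

    import os

    import torch.distributed as dist

    def dist_init(rank: int, world_size: int) -> None:
        # Illustrative rendezvous settings, adjust them to your own setup
        os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
        os.environ.setdefault("MASTER_PORT", "29500")

        # gloo works for CPU-only runs; NCCL is the usual choice for GPU training
        dist.init_process_group(backend="gloo", rank=rank, world_size=world_size)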

Then sharding the optimizer state is merely a matter of wrapping your optimizer in fairscale.optim.OSS, as follows:

.. code-block:: python
    :emphasize-lines: 2, 18, 19

    import torch
    from fairscale.optim.oss import OSS

    def train(
        rank: int,
        world_size: int,
        epochs: int):

        # DDP
        dist_init(rank, world_size)

        # Problem statement
        model = myAwesomeModel()
        dataloader = mySuperFastDataloader()
        loss_fn = myVeryRelevantLoss()

        base_optimizer_arguments = {}  # pass any optimizer specific arguments here, or directly below when instantiating OSS
        base_optimizer = torch.optim.SGD  # any pytorch compliant optimizer
        optimizer = OSS(params=model.parameters(), optim=base_optimizer, **base_optimizer_arguments)

        # Any relevant training loop, nothing specific to OSS. For example:
        model.train()
        for e in range(epochs):
            for batch in dataloader:
                # Train
                model.zero_grad()
                outputs = model(batch["inputs"])
                loss = loss_fn(outputs, batch["label"])
                torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.SUM)
                loss /= world_size
                loss.backward()
                optimizer.step()
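
To actually run the trainer you can spawn one process per rank on a single node with torch.multiprocessing.spawn; the world size and epoch count below are illustrative values, not anything required by OSS:

.. code-block:: python

    import torch.multiprocessing as mp

    if __name__ == "__main__":
        world_size = 2  # illustrative: one process per GPU on this node
        epochs = 10     # illustrative

        # spawn() passes the rank as the first argument to train()
        mp.spawn(train, args=(world_size, epochs), nprocs=world_size, join=True)

Keep in mind that each rank only holds its own shard of the optimizer state; when checkpointing, OSS exposes a consolidate_state_dict() method to gather the full state on a single rank before calling state_dict() (see the fairscale API reference for details).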