"vscode:/vscode.git/clone" did not exist on "37c5899fc2100de1c9afd51a7b1977b2f8185a28"
grad_scaler.rst 1.92 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
Sharded Grad Scaler
========================
Enabling PyTorch's automatic mixed precision usually means using a `GradScaler` to guard against gradient underflow.
The default grad scaler is not aware of the optimizer state sharding that Fairscale's OSS introduces, and using it will lead to deadlocks.
Make sure that you use `ShardedGradScaler` in that case; it is a shard-aware wrapper around PyTorch's implementation.

.. code-block:: python

    import torch
    from fairscale.optim.oss import OSS
    from fairscale.optim.grad_scaler import ShardedGradScaler
    from torch.nn.parallel import DistributedDataParallel as DDP

    def train(
        rank: int,
        world_size: int,
        epochs: int):

        # Initialize the torch.distributed process group for DDP
        # (dist_init is assumed to wrap torch.distributed.init_process_group)
        dist_init(rank, world_size)

        # Problem statement
        model = myAwesomeModel().to(rank)
        model = DDP(model, device_ids=[rank])
        dataloader = mySuperFastDataloader()
        loss_fn = myVeryRelevantLoss()

        # Optimizer-specific arguments, e.g. lr, momentum, etc.
        base_optimizer_arguments = {"lr": 1e-4}

        # ** NEW ** Wrap a base optimizer into OSS
        base_optimizer = torch.optim.SGD  # any pytorch compliant optimizer
        optimizer = OSS(
            params=model.parameters(),
            optim=base_optimizer,
            **base_optimizer_arguments)

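        # ** NEW ** Use the shard-aware grad scaler in place of torch.cuda.amp.GradScaler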
        scaler = ShardedGradScaler()

        # Any relevant training loop, nothing specific to OSS. For example:
        model.train()
        for e in range(epochs):
            for (data, target) in dataloader:
                data, target = data.to(rank), target.to(rank)

                # Zero the gradients from the previous iteration
                model.zero_grad()

                # Automatically compute the forward pass in half precision
                with torch.cuda.amp.autocast():
                    outputs = model(data)
                    loss = loss_fn(outputs, target)

                # Scale the loss and backpropagate, then step through the scaler
                # so gradients are unscaled and inf/nan steps are skipped
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
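
If you also clip gradients, unscale them first so that the clipping threshold applies to the true gradient values.
The snippet below is a minimal sketch of that pattern for the training loop above; the `max_norm` value is purely illustrative, and `torch.nn.utils.clip_grad_norm_` over `model.parameters()` is used on the assumption that DDP keeps the full gradients replicated on every rank while OSS only shards the optimizer state.

.. code-block:: python

    # Tail of the inner training loop shown above
    scaler.scale(loss).backward()

    # Unscale the gradients in place so that the threshold below applies
    # to the true gradient values, not the scaled ones
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

    # step() will not unscale a second time, and update() adjusts the
    # scale factor for the next iteration
    scaler.step(optimizer)
    scaler.update()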