"template/alfred.json" did not exist on "23ebbaa46ead40c44c20b707b0e53d954ea51dc5"
test_fsdp_grad_scaler.py 1.29 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import os
from unittest import mock

import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F

from fairscale.nn import FullyShardedDataParallel
from fairscale.optim.grad_scaler import ShardedGradScaler
from fairscale.utils.testing import skip_if_no_cuda

try:
    from torch.cuda.amp import autocast
except ImportError:
    # Older torch versions don't support autocast; skip this whole file.
    pytestmark = pytest.mark.skip


@mock.patch.dict(os.environ, {"MASTER_ADDR": "localhost", "MASTER_PORT": "1337"}, clear=True)
@skip_if_no_cuda
def test_scaler_cpu_offload_breaks():
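    """ShardedGradScaler.step currently fails when FSDP combines cpu_offload with
    mixed_precision; assert the RuntimeError until issue #421 is fixed."""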

    device = torch.device("cuda")
    torch.cuda.set_device(0)

    # Single-rank NCCL process group so FSDP can be constructed in a one-process test.
    torch.distributed.init_process_group(backend="nccl", rank=0, world_size=1)

    # Tiny linear model wrapped in FSDP with CPU offload and mixed-precision compute,
    # paired with the sharded gradient scaler.
    scaler = ShardedGradScaler()
    model = FullyShardedDataParallel(nn.Linear(5, 5), cpu_offload=True, mixed_precision=True)
    optim = torch.optim.SGD(model.parameters(), lr=1e-3)

    input = torch.rand((1, 5), dtype=torch.float).to(device)
    optim.zero_grad()
    # Forward pass under autocast; a trivial MSE between input and output serves as the loss.
    with autocast():
        output = model(input)
        loss = F.mse_loss(input, output)

    scaler.scale(loss).backward()
    # TODO (Min): Need to fix. Details in issue #421.
    with pytest.raises(RuntimeError):
        scaler.step(optim)
        scaler.update()
    # Clean up the default process group so other tests can initialize their own.
    torch.distributed.destroy_process_group()