# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

# pylint: disable=missing-module-docstring
# pylint: disable=missing-class-docstring
# pylint: disable=missing-function-docstring

""" Test FSDP with grad scaler. """

import os
import random

import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F

from fairscale.nn import FullyShardedDataParallel
from fairscale.optim.grad_scaler import ShardedGradScaler
from fairscale.utils.testing import skip_if_no_cuda

# Guard the ``autocast`` import: older torch builds (without
# ``torch.cuda.amp.autocast``) raise ImportError here, in which case the
# module-level ``pytestmark`` makes pytest skip every test in this file.
try:
    from torch.cuda.amp import autocast
except ImportError:
    # Older version doesn't support autocast. Skip this file.
    pytestmark = pytest.mark.skip


# Mixed precision needs cuda.
@skip_if_no_cuda
def test_scaler_cpu_offload_breaks():
    """Document that ShardedGradScaler.step fails with FSDP cpu_offload.

    Runs a single-rank NCCL process group, wraps a small Linear in FSDP with
    ``cpu_offload=True`` and ``mixed_precision=True``, and asserts that
    ``scaler.step`` raises ``RuntimeError``.  This pins the known limitation
    tracked in issue #421; once fixed, the ``pytest.raises`` should be removed.
    """
    device = torch.device("cuda")
    torch.cuda.set_device(0)

    # Random port in case the next test run quickly, same port would cause conflict.
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = str(random.randint(2000, 3000))

    torch.distributed.init_process_group(backend="nccl", rank=0, world_size=1)

    try:
        scaler = ShardedGradScaler()
        model = FullyShardedDataParallel(nn.Linear(5, 5), cpu_offload=True, mixed_precision=True)
        optim = torch.optim.SGD(model.parameters(), lr=1e-3)

        input = torch.rand((1, 5), dtype=torch.float).to(device)
        optim.zero_grad()
        with autocast():
            output = model(input)
            loss = F.mse_loss(input, output)

        scaler.scale(loss).backward()
        # TODO (Min): Need to fix. Details in issue #421.
        with pytest.raises(RuntimeError):
            scaler.step(optim)
            scaler.update()

    finally:
        # Clean-up is important or the next test in this file may fail to init the PG.
        torch.distributed.destroy_process_group()
        del os.environ["MASTER_ADDR"]
        del os.environ["MASTER_PORT"]