Unverified Commit 5ecac15a authored by Min Xu, committed by GitHub

[test] FSDP: add the failing test for #421 (#453)

* [test] FSDP: add the failing test for #421

* skip on 1.5

* better skipping

* Update tests/nn/data_parallel/test_fsdp_grad_scaler.py
Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
import os
from unittest import mock

import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F

from fairscale.nn import FullyShardedDataParallel
from fairscale.optim.grad_scaler import ShardedGradScaler
from fairscale.utils.testing import skip_if_no_cuda

try:
    from torch.cuda.amp import autocast
except ImportError:
    # Older version doesn't support autocast. Skip this file.
    pytestmark = pytest.mark.skip
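# Note: assigning `pytest.mark.skip` to the module-level `pytestmark` variable
# makes pytest skip every test in this file, which is the desired behavior
# when torch.cuda.amp.autocast is unavailable (it was added in torch 1.6,
# hence the "skip on 1.5" commit above).
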
@mock.patch.dict(os.environ, {"MASTER_ADDR": "localhost", "MASTER_PORT": "1337"}, clear=True)
@skip_if_no_cuda
def test_scaler_cpu_offload_breaks():
    device = torch.device("cuda")
    torch.cuda.set_device(0)
    torch.distributed.init_process_group(backend="nccl", rank=0, world_size=1)

    scaler = ShardedGradScaler()
    model = FullyShardedDataParallel(nn.Linear(5, 5), cpu_offload=True, mixed_precision=True)
    optim = torch.optim.SGD(model.parameters(), lr=1e-3)

    input = torch.rand((1, 5), dtype=torch.float).to(device)
    optim.zero_grad()
    with autocast():
        output = model(input)
        loss = F.mse_loss(input, output)
    scaler.scale(loss).backward()

    # TODO (Min): Need to fix. Details in issue #421.
    # With cpu_offload=True the sharded gradients are moved to CPU, and
    # scaler.step() currently raises; this test pins down that known failure.
    with pytest.raises(RuntimeError):
        scaler.step(optim)
        scaler.update()

    torch.distributed.destroy_process_group()
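

# For contrast, a minimal sketch (not part of the committed test) of the same
# step with cpu_offload=False, which is the configuration expected to succeed
# today and the behavior #421 aims to restore for the offload case. The
# function name below is hypothetical; it assumes MASTER_ADDR/MASTER_PORT are
# set as in the mock.patch.dict decorator above, and otherwise mirrors the
# test.
def _sketch_scaler_without_cpu_offload():
    torch.cuda.set_device(0)
    torch.distributed.init_process_group(backend="nccl", rank=0, world_size=1)

    scaler = ShardedGradScaler()
    # Same wrapping, but parameters and gradients stay on the GPU.
    model = FullyShardedDataParallel(nn.Linear(5, 5), cpu_offload=False, mixed_precision=True)
    optim = torch.optim.SGD(model.parameters(), lr=1e-3)

    input = torch.rand((1, 5), dtype=torch.float).cuda()
    optim.zero_grad()
    with autocast():
        loss = F.mse_loss(input, model(input))
    scaler.scale(loss).backward()

    # With GPU-resident gradients, step() and update() should not raise.
    scaler.step(optim)
    scaler.update()

    torch.distributed.destroy_process_group()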