Unverified commit 2e9a14e7 authored by Min Xu, committed by GitHub

[fix]: handle inputs with containers in mixed precision (#486)

* [fix]: handle inputs with containers

- this is an issue surfaced by vissl as well
- fix seems to be super simple; a sketch of the container-aware cast follows the commit metadata below
- also cleaned up two tests so that multiple such tests can run back to back
  (they don't do that presently)

* cleanup

* fix

* lint
parent 1204c7cf
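For context: with mixed_precision=True, FSDP casts forward-pass inputs to FP16, and the old cast only handled tensors passed directly as positional or keyword arguments, so tensors nested inside a dict or list stayed in FP32 and clashed with the FP16 parameters. Below is a minimal, self-contained sketch of a container-aware cast; cast_container_to_fp16 is a hypothetical stand-in for illustration only, while the actual change in the diff delegates the traversal to fairscale's apply_to_tensors.

import torch

def cast_container_to_fp16(obj):
    # Hypothetical helper: recursively walk dicts, lists and tuples, casting
    # float32 tensors to half and leaving every other value untouched.
    if torch.is_tensor(obj):
        return obj.half() if obj.dtype is torch.float32 else obj
    if isinstance(obj, dict):
        return {k: cast_container_to_fp16(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return type(obj)(cast_container_to_fp16(v) for v in obj)
    return obj

# The case the old flatten-based cast missed: a tensor hidden inside a dict.
batch = {"in": torch.rand(64, 4)}
batch = cast_container_to_fp16(batch)
assert batch["in"].dtype is torch.float16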
@@ -20,13 +20,7 @@ import torch.nn.functional as F
from fairscale.nn.misc import FlattenParamsWrapper
from fairscale.optim.utils import calc_grad_norm
from fairscale.utils.containers import (
apply_to_tensors,
pack_kwargs,
split_non_tensors,
unpack_kwargs,
unpack_non_tensors,
)
from fairscale.utils.containers import apply_to_tensors
from fairscale.utils.parallel import chunk_and_pad, validate_process_group
from fairscale.utils.reduce_scatter_bucketer import ReduceScatterBucketer
from fairscale.utils.state_dict import replace_by_prefix_
@@ -1189,15 +1183,14 @@ class FullyShardedDataParallel(nn.Module):
def cast_inputs_to_fp16(*args: Any, **kwargs: Any) -> Tuple[Any, Any]:
"""
Cast any Tensors in *args or **kwargs to FP16.
Doesn't currently support Tensors nested inside containers (e.g., dict).
"""
kwarg_keys, flat_args = pack_kwargs(*args, **kwargs)
tensor_inputs, packed_non_tensor_inputs = split_non_tensors(flat_args)
tensor_inputs = tuple(t.half() if torch.is_floating_point(t) else t for t in tensor_inputs)
flat_args = unpack_non_tensors(tensor_inputs, packed_non_tensor_inputs)
args, kwargs = unpack_kwargs(kwarg_keys, flat_args)
return args, kwargs
def fn(x: torch.Tensor) -> torch.Tensor:
if x.dtype is torch.float32:
return x.half()
return x
return apply_to_tensors(fn, args), apply_to_tensors(fn, kwargs)
def cast_buffers_(
......
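The rewritten cast_inputs_to_fp16 above applies a per-tensor cast to args and kwargs separately and lets fairscale.utils.containers.apply_to_tensors do the traversal. A rough usage sketch follows, assuming only what the new code relies on: that apply_to_tensors(fn, container) maps fn over every tensor nested in tuples, lists and dicts and returns the same structure.

import torch
from fairscale.utils.containers import apply_to_tensors

def fn(x: torch.Tensor) -> torch.Tensor:
    # Only float32 tensors are cast; integer and bool tensors pass through.
    return x.half() if x.dtype is torch.float32 else x

args = (torch.rand(2, 4), {"mask": torch.ones(2, 4, dtype=torch.bool)})
kwargs = {"features": [torch.rand(2, 4)]}

# Mirrors the new implementation: each structure is traversed independently.
half_args = apply_to_tensors(fn, args)
half_kwargs = apply_to_tensors(fn, kwargs)
assert half_args[0].dtype is torch.float16
assert half_args[1]["mask"].dtype is torch.bool

Non-floating-point tensors, such as the boolean mask here, come back unchanged, matching the dtype check in fn.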
@@ -2,6 +2,7 @@ tests/nn/data_parallel/test_fsdp_uneven.py
tests/nn/data_parallel/test_fsdp_grad_scaler.py
tests/nn/data_parallel/test_fsdp_no_sync.py
tests/nn/data_parallel/test_fsdp_summon_full_params.py
tests/nn/data_parallel/test_fsdp_input.py
tests/nn/data_parallel/test_sharded_ddp_features.py
tests/nn/data_parallel/test_sharded_ddp_pytorch_parity.py
tests/nn/pipe/skip/test_gpipe.py
......
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# pylint: disable=missing-module-docstring
# pylint: disable=missing-class-docstring
# pylint: disable=missing-function-docstring
""" Test FSDP with grad scaler. """
import os
from unittest import mock
import random
import pytest
import torch
@@ -17,28 +28,36 @@ except ImportError:
pytestmark = pytest.mark.skip
@mock.patch.dict(os.environ, {"MASTER_ADDR": "localhost", "MASTER_PORT": "1337"}, clear=True)
# Mixed precision needs cuda.
@skip_if_no_cuda
def test_scaler_cpu_offload_breaks():
device = torch.device("cuda")
torch.cuda.set_device(0)
# Use a random port in case the next test runs quickly; reusing the same port would cause a conflict.
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = str(random.randint(2000, 3000))
torch.distributed.init_process_group(backend="nccl", rank=0, world_size=1)
scaler = ShardedGradScaler()
model = FullyShardedDataParallel(nn.Linear(5, 5), cpu_offload=True, mixed_precision=True)
optim = torch.optim.SGD(model.parameters(), lr=1e-3)
input = torch.rand((1, 5), dtype=torch.float).to(device)
optim.zero_grad()
with autocast():
output = model(input)
loss = F.mse_loss(input, output)
scaler.scale(loss).backward()
# TODO (Min): Need to fix. Details in issue #421.
with pytest.raises(RuntimeError):
scaler.step(optim)
scaler.update()
torch.distributed.destroy_process_group()
try:
scaler = ShardedGradScaler()
model = FullyShardedDataParallel(nn.Linear(5, 5), cpu_offload=True, mixed_precision=True)
optim = torch.optim.SGD(model.parameters(), lr=1e-3)
input = torch.rand((1, 5), dtype=torch.float).to(device)
optim.zero_grad()
with autocast():
output = model(input)
loss = F.mse_loss(input, output)
scaler.scale(loss).backward()
# TODO (Min): Need to fix. Details in issue #421.
with pytest.raises(RuntimeError):
scaler.step(optim)
scaler.update()
finally:
# Clean-up is important or the next test in this file may fail to init the PG.
torch.distributed.destroy_process_group()
del os.environ["MASTER_ADDR"]
del os.environ["MASTER_PORT"]
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# pylint: disable=missing-module-docstring
# pylint: disable=missing-class-docstring
# pylint: disable=missing-function-docstring
""" Test FSDP with different input types. """
import os
import random
import pytest
import torch
from torch.nn import Linear, Module
from torch.optim import SGD
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
from fairscale.nn.data_parallel import TrainingState
from fairscale.utils.testing import skip_if_no_cuda, torch_version
# We only test on GPU since mixed precision only really works on GPU.
@skip_if_no_cuda
@pytest.mark.parametrize(
"fsdp_config", [{}, {"mixed_precision": True}],
)
@pytest.mark.parametrize("input_cls", [dict, list])
def test_it(fsdp_config, input_cls):
"""Test FSDP with input being a list or a dict, only single GPU."""
if torch_version() < (1, 6, 0):
pytest.skip("older pytorch doesn't support reduce_scatter")
# Use a random port in case the next test runs quickly; reusing the same port would cause a conflict.
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = str(random.randint(2000, 3000))
torch.distributed.init_process_group(backend="nccl", rank=0, world_size=1)
try:
assert isinstance(fsdp_config, dict), str(fsdp_config)
class Model(Module):
def __init__(self):
super().__init__()
self.layer = Linear(4, 4)
def forward(self, input):
if isinstance(input, list):
input = input[0]
else:
assert isinstance(input, dict), input
input = input["in"]
return self.layer(input)
model = FSDP(Model(), **fsdp_config).cuda()
optim = SGD(model.parameters(), lr=0.1)
for _ in range(5):
in_data = torch.rand(64, 4).cuda()
in_data.requires_grad = True
if input_cls is list:
in_data = [in_data]
else:
assert input_cls is dict
in_data = {"in": in_data}
out = model(in_data)
out.sum().backward()
optim.step()
optim.zero_grad()
model.assert_state(TrainingState.IDLE)
finally:
# Clean-up is important or the next test in this file may fail to init the PG.
torch.distributed.destroy_process_group()
del os.environ["MASTER_ADDR"]
del os.environ["MASTER_PORT"]
@@ -5,8 +5,8 @@
import functools
import os
import random
import unittest
from unittest import mock
import torch
import torch.nn as nn
@@ -16,13 +16,14 @@ from fairscale.nn import FullyShardedDataParallel as FSDP
from fairscale.nn import auto_wrap, default_auto_wrap_policy, enable_wrap, wrap
from fairscale.utils.testing import DummyProcessGroup
try:
from torch.cuda.amp import autocast
except ImportError:
autocast = None # type: ignore
class TestAutoWrap(unittest.TestCase):
def setUp(self) -> None:
version = torch.__version__.split(".")[:2]
major, minor = int(version[0]), int(version[1])
if major < 1 or (major == 1 and minor < 6):
raise unittest.SkipTest("Need pytorch version >= 1.6 due to autocast")
self.process_group = DummyProcessGroup(rank=0, size=1)
def test_wrap(self):
@@ -129,26 +130,32 @@ class TestAutoWrap(unittest.TestCase):
"""
self._auto_wrap_smoke_test(enable_mixed_precision=True)
@mock.patch.dict(os.environ, {"MASTER_ADDR": "localhost", "MASTER_PORT": "12345"}, clear=True)
@unittest.skipIf(not torch.cuda.is_available(), "Test Requires CUDA")
@unittest.skipIf(autocast is None, "Test Requires autocast")
def _auto_wrap_smoke_test(self, enable_mixed_precision):
from torch.cuda.amp import autocast
device = torch.device("cuda")
torch.cuda.set_device(0)
# Use a random port in case the next test runs quickly; reusing the same port would cause a conflict.
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = str(random.randint(2000, 3000))
torch.distributed.init_process_group(backend="nccl", rank=0, world_size=1)
with enable_wrap(wrapper_cls=FSDP, mixed_precision=enable_mixed_precision):
sequential = nn.Sequential(
nn.Linear(5, 5), nn.Linear(5, 5), nn.Sequential(nn.Linear(5, 5), nn.Linear(5, 5))
)
my_auto_wrap_policy = functools.partial(default_auto_wrap_policy, min_num_params=40)
model = auto_wrap(sequential, auto_wrap_policy=my_auto_wrap_policy)
model.to(device)
input = torch.rand((1, 5), dtype=torch.float).to(device)
with autocast(enabled=enable_mixed_precision):
output = model(input)
loss = F.mse_loss(input, output)
loss.backward()
torch.distributed.destroy_process_group()
try:
with enable_wrap(wrapper_cls=FSDP, mixed_precision=enable_mixed_precision):
sequential = nn.Sequential(
nn.Linear(5, 5), nn.Linear(5, 5), nn.Sequential(nn.Linear(5, 5), nn.Linear(5, 5))
)
my_auto_wrap_policy = functools.partial(default_auto_wrap_policy, min_num_params=40)
model = auto_wrap(sequential, auto_wrap_policy=my_auto_wrap_policy)
model.to(device)
input = torch.rand((1, 5), dtype=torch.float).to(device)
with autocast(enabled=enable_mixed_precision):
output = model(input)
loss = F.mse_loss(input, output)
loss.backward()
finally:
torch.distributed.destroy_process_group()
del os.environ["MASTER_ADDR"]
del os.environ["MASTER_PORT"]