Unverified commit 2e9a14e7 authored by Min Xu, committed by GitHub

[fix]: handle inputs with containers in mixed precision (#486)

* [fix]: handle inputs with containers

- this is an issue surfaced by vissl as well
- the fix seems to be super simple (see the sketch below)
- also cleaned up two tests so that multiple such tests can run
  back to back (they don't do that presently)

* cleanup

* fix

* lint
parent 1204c7cf
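To make the behavior change concrete, here is a minimal sketch (not part of the commit) of what the fix enables. It assumes fairscale is installed and uses fairscale.utils.containers.apply_to_tensors, the helper the new code relies on; the variable names are illustrative only. FP32 tensors nested inside a dict are cast to FP16, while non-floating-point tensors are left untouched:

import torch

from fairscale.utils.containers import apply_to_tensors


def cast_fp32_to_fp16(x: torch.Tensor) -> torch.Tensor:
    # Mirrors the casting rule used by the fix: only FP32 tensors are converted.
    return x.half() if x.dtype is torch.float32 else x


# Inputs packed in a container (the case that cast_inputs_to_fp16 used to skip).
batch = {"in": torch.rand(2, 4), "mask": torch.ones(2, 4, dtype=torch.bool)}
batch = apply_to_tensors(cast_fp32_to_fp16, batch)

assert batch["in"].dtype is torch.float16  # floating-point tensor was cast
assert batch["mask"].dtype is torch.bool   # non-FP32 tensor left as-is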
@@ -20,13 +20,7 @@ import torch.nn.functional as F
 from fairscale.nn.misc import FlattenParamsWrapper
 from fairscale.optim.utils import calc_grad_norm
-from fairscale.utils.containers import (
-    apply_to_tensors,
-    pack_kwargs,
-    split_non_tensors,
-    unpack_kwargs,
-    unpack_non_tensors,
-)
+from fairscale.utils.containers import apply_to_tensors
 from fairscale.utils.parallel import chunk_and_pad, validate_process_group
 from fairscale.utils.reduce_scatter_bucketer import ReduceScatterBucketer
 from fairscale.utils.state_dict import replace_by_prefix_
@@ -1189,15 +1183,14 @@ class FullyShardedDataParallel(nn.Module):
 def cast_inputs_to_fp16(*args: Any, **kwargs: Any) -> Tuple[Any, Any]:
     """
     Cast any Tensors in *args or **kwargs to FP16.
-    Doesn't currently support Tensors nested inside containers (e.g., dict).
     """
-    kwarg_keys, flat_args = pack_kwargs(*args, **kwargs)
-    tensor_inputs, packed_non_tensor_inputs = split_non_tensors(flat_args)
-    tensor_inputs = tuple(t.half() if torch.is_floating_point(t) else t for t in tensor_inputs)
-    flat_args = unpack_non_tensors(tensor_inputs, packed_non_tensor_inputs)
-    args, kwargs = unpack_kwargs(kwarg_keys, flat_args)
-    return args, kwargs
+
+    def fn(x: torch.Tensor) -> torch.Tensor:
+        if x.dtype is torch.float32:
+            return x.half()
+        return x
+
+    return apply_to_tensors(fn, args), apply_to_tensors(fn, kwargs)

 def cast_buffers_(
...
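For readers unfamiliar with apply_to_tensors: it recursively walks nested containers and applies a function to every tensor it finds, which is why the new cast_inputs_to_fp16 handles dict and list inputs for free. The sketch below is an illustrative reimplementation of that idea under the hypothetical name map_tensors, not fairscale's actual code:

from typing import Any, Callable

import torch


def map_tensors(fn: Callable[[torch.Tensor], torch.Tensor], obj: Any) -> Any:
    # Apply fn to every tensor found inside (possibly nested) dicts, lists, tuples and sets.
    if torch.is_tensor(obj):
        return fn(obj)
    if isinstance(obj, dict):
        return {k: map_tensors(fn, v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple, set)):
        return type(obj)(map_tensors(fn, v) for v in obj)
    # Non-tensor leaves (ints, strings, None, ...) pass through unchanged.
    return obj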
@@ -2,6 +2,7 @@ tests/nn/data_parallel/test_fsdp_uneven.py
 tests/nn/data_parallel/test_fsdp_grad_scaler.py
 tests/nn/data_parallel/test_fsdp_no_sync.py
 tests/nn/data_parallel/test_fsdp_summon_full_params.py
+tests/nn/data_parallel/test_fsdp_input.py
 tests/nn/data_parallel/test_sharded_ddp_features.py
 tests/nn/data_parallel/test_sharded_ddp_pytorch_parity.py
 tests/nn/pipe/skip/test_gpipe.py
...
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pylint: disable=missing-module-docstring
+# pylint: disable=missing-class-docstring
+# pylint: disable=missing-function-docstring
+
+""" Test FSDP with grad scaler. """
+
 import os
-from unittest import mock
+import random

 import pytest
 import torch
@@ -17,28 +28,36 @@ except ImportError:
     pytestmark = pytest.mark.skip


-@mock.patch.dict(os.environ, {"MASTER_ADDR": "localhost", "MASTER_PORT": "1337"}, clear=True)
+# Mixed precision needs CUDA.
 @skip_if_no_cuda
 def test_scaler_cpu_offload_breaks():
     device = torch.device("cuda")
     torch.cuda.set_device(0)

+    # Use a random port in case the next test runs quickly; reusing the same port would cause a conflict.
+    os.environ["MASTER_ADDR"] = "localhost"
+    os.environ["MASTER_PORT"] = str(random.randint(2000, 3000))
     torch.distributed.init_process_group(backend="nccl", rank=0, world_size=1)
-    scaler = ShardedGradScaler()
-    model = FullyShardedDataParallel(nn.Linear(5, 5), cpu_offload=True, mixed_precision=True)
-    optim = torch.optim.SGD(model.parameters(), lr=1e-3)
-
-    input = torch.rand((1, 5), dtype=torch.float).to(device)
-    optim.zero_grad()
-    with autocast():
-        output = model(input)
-        loss = F.mse_loss(input, output)
-
-    scaler.scale(loss).backward()
-    # TODO (Min): Need to fix. Details in issue #421.
-    with pytest.raises(RuntimeError):
-        scaler.step(optim)
-        scaler.update()
-    torch.distributed.destroy_process_group()
+    try:
+        scaler = ShardedGradScaler()
+        model = FullyShardedDataParallel(nn.Linear(5, 5), cpu_offload=True, mixed_precision=True)
+        optim = torch.optim.SGD(model.parameters(), lr=1e-3)
+
+        input = torch.rand((1, 5), dtype=torch.float).to(device)
+        optim.zero_grad()
+        with autocast():
+            output = model(input)
+            loss = F.mse_loss(input, output)
+
+        scaler.scale(loss).backward()
+        # TODO (Min): Need to fix. Details in issue #421.
+        with pytest.raises(RuntimeError):
+            scaler.step(optim)
+            scaler.update()
+    finally:
+        # Clean-up is important or the next test in this file may fail to init the PG.
+        torch.distributed.destroy_process_group()
+        del os.environ["MASTER_ADDR"]
+        del os.environ["MASTER_PORT"]
...
New file tests/nn/data_parallel/test_fsdp_input.py (all lines added):

# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

# pylint: disable=missing-module-docstring
# pylint: disable=missing-class-docstring
# pylint: disable=missing-function-docstring

""" Test FSDP with different input types. """

import os
import random

import pytest
import torch
from torch.nn import Linear, Module
from torch.optim import SGD

from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
from fairscale.nn.data_parallel import TrainingState
from fairscale.utils.testing import skip_if_no_cuda, torch_version


# We only test on GPU since mixed precision only really works on GPU.
@skip_if_no_cuda
@pytest.mark.parametrize(
    "fsdp_config", [{}, {"mixed_precision": True}],
)
@pytest.mark.parametrize("input_cls", [dict, list])
def test_it(fsdp_config, input_cls):
    """Test FSDP with input being a list or a dict, only single GPU."""
    if torch_version() < (1, 6, 0):
        pytest.skip("older pytorch doesn't support reduce_scatter")

    # Use a random port in case the next test runs quickly; reusing the same port would cause a conflict.
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = str(random.randint(2000, 3000))
    torch.distributed.init_process_group(backend="nccl", rank=0, world_size=1)

    try:
        assert isinstance(fsdp_config, dict), str(fsdp_config)

        class Model(Module):
            def __init__(self):
                super().__init__()
                self.layer = Linear(4, 4)

            def forward(self, input):
                if isinstance(input, list):
                    input = input[0]
                else:
                    assert isinstance(input, dict), input
                    input = input["in"]
                return self.layer(input)

        model = FSDP(Model(), **fsdp_config).cuda()
        optim = SGD(model.parameters(), lr=0.1)

        for _ in range(5):
            in_data = torch.rand(64, 4).cuda()
            in_data.requires_grad = True
            if input_cls is list:
                in_data = [in_data]
            else:
                assert input_cls is dict
                in_data = {"in": in_data}

            out = model(in_data)
            out.sum().backward()
            optim.step()
            optim.zero_grad()

        model.assert_state(TrainingState.IDLE)
    finally:
        # Clean-up is important or the next test in this file may fail to init the PG.
        torch.distributed.destroy_process_group()
        del os.environ["MASTER_ADDR"]
        del os.environ["MASTER_PORT"]
@@ -5,8 +5,8 @@
 import functools
 import os
+import random
 import unittest
-from unittest import mock

 import torch
 import torch.nn as nn
@@ -16,13 +16,14 @@ from fairscale.nn import FullyShardedDataParallel as FSDP
 from fairscale.nn import auto_wrap, default_auto_wrap_policy, enable_wrap, wrap
 from fairscale.utils.testing import DummyProcessGroup

+try:
+    from torch.cuda.amp import autocast
+except ImportError:
+    autocast = None  # type: ignore
+

 class TestAutoWrap(unittest.TestCase):
     def setUp(self) -> None:
-        version = torch.__version__.split(".")[:2]
-        major, minor = int(version[0]), int(version[1])
-        if major < 1 or (major == 1 and minor < 6):
-            raise unittest.SkipTest("Need pytorch version >= 1.6 due to autocast")
         self.process_group = DummyProcessGroup(rank=0, size=1)

     def test_wrap(self):
@@ -129,26 +130,32 @@ class TestAutoWrap(unittest.TestCase):
         """
        self._auto_wrap_smoke_test(enable_mixed_precision=True)

-    @mock.patch.dict(os.environ, {"MASTER_ADDR": "localhost", "MASTER_PORT": "12345"}, clear=True)
     @unittest.skipIf(not torch.cuda.is_available(), "Test Requires CUDA")
+    @unittest.skipIf(autocast is None, "Test Requires autocast")
     def _auto_wrap_smoke_test(self, enable_mixed_precision):
-        from torch.cuda.amp import autocast
-
         device = torch.device("cuda")
         torch.cuda.set_device(0)

+        # Use a random port in case the next test runs quickly; reusing the same port would cause a conflict.
+        os.environ["MASTER_ADDR"] = "localhost"
+        os.environ["MASTER_PORT"] = str(random.randint(2000, 3000))
         torch.distributed.init_process_group(backend="nccl", rank=0, world_size=1)
-        with enable_wrap(wrapper_cls=FSDP, mixed_precision=enable_mixed_precision):
-            sequential = nn.Sequential(
-                nn.Linear(5, 5), nn.Linear(5, 5), nn.Sequential(nn.Linear(5, 5), nn.Linear(5, 5))
-            )
-            my_auto_wrap_policy = functools.partial(default_auto_wrap_policy, min_num_params=40)
-            model = auto_wrap(sequential, auto_wrap_policy=my_auto_wrap_policy)
-        model.to(device)
-        input = torch.rand((1, 5), dtype=torch.float).to(device)
-
-        with autocast(enabled=enable_mixed_precision):
-            output = model(input)
-            loss = F.mse_loss(input, output)
-        loss.backward()
-        torch.distributed.destroy_process_group()
+        try:
+            with enable_wrap(wrapper_cls=FSDP, mixed_precision=enable_mixed_precision):
+                sequential = nn.Sequential(
+                    nn.Linear(5, 5), nn.Linear(5, 5), nn.Sequential(nn.Linear(5, 5), nn.Linear(5, 5))
+                )
+                my_auto_wrap_policy = functools.partial(default_auto_wrap_policy, min_num_params=40)
+                model = auto_wrap(sequential, auto_wrap_policy=my_auto_wrap_policy)
+            model.to(device)
+            input = torch.rand((1, 5), dtype=torch.float).to(device)
+
+            with autocast(enabled=enable_mixed_precision):
+                output = model(input)
+                loss = F.mse_loss(input, output)
+            loss.backward()
+        finally:
+            torch.distributed.destroy_process_group()
+            del os.environ["MASTER_ADDR"]
+            del os.environ["MASTER_PORT"]
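Design note: the same MASTER_ADDR/MASTER_PORT setup, init_process_group call, and try/finally teardown now appears in all three tests touched here. A possible follow-up (not part of this commit) would be to factor it into a small context manager; single_rank_process_group below is a hypothetical helper sketched under that assumption:

import contextlib
import os
import random

import torch


@contextlib.contextmanager
def single_rank_process_group(backend: str = "nccl"):
    # Pick a random rendezvous port so back-to-back single-GPU tests do not collide.
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = str(random.randint(2000, 3000))
    torch.distributed.init_process_group(backend=backend, rank=0, world_size=1)
    try:
        yield
    finally:
        # Tear down so the next test in the same process can init a fresh process group.
        torch.distributed.destroy_process_group()
        del os.environ["MASTER_ADDR"]
        del os.environ["MASTER_PORT"]

A test would then wrap its body in "with single_rank_process_group():" instead of repeating the boilerplate.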