Unverified commit ba5785f7, authored by Anupam Bhatnagar and committed by GitHub

Allow sharded grad scaler to cpu offload with FSDP (#831)

* first commit

* sharded scaler hitting nan assertions

* adding test for sharded grad scaler without cpu offload

* ddp grad scaler and fsdp sharded grad scaler test failing

* removing test_output

* fix no cpu offload test

* changing optimizer from OSS to SGD

* all tests passing, code cleanup pending

* code cleanup

* fix pyproject.toml

* removing .isort.cfg

* running isort linter

* resolving isort issues

* resolving black linter issue

* resolving mypy issues

* fix import statement

* fix mypy error

* modifying import statement

* adding pytorch version requirement

* fixing pytest skip test decorator

* apply version guard for ShardedGradScaler

* removing test_fsdp_grad_scaler

* increasing num_epochs for ShardedGradScaler so that updates are not skipped

* adding support for torch 1.8

* minor edit

* [skip ci] more torch 1.8 changes

* parametrizing the tests

* cleanup code with linters

* [skip ci] update doc string

* [skip ci] addressing some more comments
parent 7d7edf6d
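
For context, a minimal sketch of the usage pattern this change enables: driving an FSDP model with cpu_offload through ShardedGradScaler's scale/step/update cycle. This is an illustration only; the data loop and process-group setup are hypothetical and not part of this diff:

import torch
import torch.nn as nn
from fairscale.nn import FullyShardedDataParallel
from fairscale.optim.grad_scaler import ShardedGradScaler

# Assumes torch >= 1.8, CUDA available, and a process group already initialized.
model = FullyShardedDataParallel(nn.Linear(5, 5), cpu_offload=True, mixed_precision=True)
optim = torch.optim.SGD(model.parameters(), lr=1e-3)
scaler = ShardedGradScaler()

for input in batches:  # hypothetical iterable of cuda tensors of shape (N, 5)
    optim.zero_grad()
    with torch.cuda.amp.autocast():
        loss = model(input).sum()
    scaler.scale(loss).backward()  # scale the loss before the backward pass
    scaler.step(optim)             # unscales grads; skips the step on inf/nan
    scaler.update()                # adjusts the scale for the next iteration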
@@ -21,8 +21,8 @@ Device = Union[torch.device, int, str]

 def check_pytorch_version() -> None:
-    if torch_version() < (1, 9, 0):
-        raise Exception("DistributedPipeline requires PyTorch version 1.9 or higher")
+    if torch_version() < (1, 8, 0):
+        raise Exception("DistributedPipeline requires PyTorch version 1.8 or higher")


 MOVING_DENIED = TypeError(
...
This diff is collapsed.
@@ -3,7 +3,6 @@ tests/nn/data_parallel/test_fsdp_multiple_wrapping.py
 tests/nn/data_parallel/test_fsdp_freezing_weights.py
 tests/nn/data_parallel/test_fsdp_regnet.py
 tests/nn/data_parallel/test_fsdp_uneven.py
-tests/nn/data_parallel/test_fsdp_grad_scaler.py
 tests/nn/data_parallel/test_fsdp_grad_acc.py
 tests/nn/data_parallel/test_fsdp_summon_full_params.py
 tests/nn/data_parallel/test_fsdp_input.py
...
@@ -13,6 +13,7 @@ import unittest
 from unittest import mock

 from parameterized import parameterized
+import pytest
 import torch
 from torch import nn
 import torch.distributed
@@ -29,6 +30,9 @@ from fairscale.utils.testing import (
     spawn_for_all_world_sizes,
 )

+if torch_version() >= (1, 8, 0):
+    from fairscale.optim.grad_scaler import ShardedGradScaler
+
 # How to use remote-pdb: https://gist.github.com/sshleifer/9d43351957179c13606e015b072927d4
 # All helper functions called by spawn must be either @classmethod, @staticmethod
@@ -49,7 +53,9 @@ class DistributedTest(unittest.TestCase):
         model_device = next(model.parameters()).device
         # use SGD with momentum instead of Adam, since Adam is scale invariant
         # and this makes it bad for tests
-        optim = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
+        optim = torch.optim.SGD(params=model.parameters(), lr=lr, momentum=0.9)
+        scaler = ShardedGradScaler()
         for _ in range(num_steps):
             optim.zero_grad()
             with torch.cuda.amp.autocast(enabled=autocast):
@@ -57,6 +63,7 @@ class DistributedTest(unittest.TestCase):
                 input = model.module.get_input(torch.device("cuda"))
                 output = model(*input)
                 loss = model.module.get_loss(input, output).to(model_device)
+            loss = scaler.scale(loss)
             assert loss.dtype == torch.float32
             model.module.run_backward(loss)
             if norm_type is not None:
@@ -65,10 +72,10 @@ class DistributedTest(unittest.TestCase):
                     model.clip_grad_norm_(clip_norm, norm_type)
                 else:
                     torch.nn.utils.clip_grad_norm_(model.parameters(), clip_norm, norm_type)
-            params = [p for p in model.parameters()]
-            print(f"params.device {params[0].device} param.grad.device {params[0].grad.device}")
-            optim.step()
+            scaler.step(optim)
+            scaler.update()
+        if hasattr(model, "assert_idle"):
+            model.assert_idle()
         if isinstance(model, FullyShardedDataParallel):
             model.assert_state(TrainingState.IDLE)
         return loss.detach()
@@ -308,21 +315,21 @@ class TestComparisonToPyTorchDDP(DistributedTest):
         # Test every combination of these options:
         spawn_and_init(functools.partial(self._test_identical_outputs, TransformerWithSharedParams, config))

-    def test_cpu_offload_and_cpu_grads(self):
-        # We don't test the False condition because that requires the optimizer to internally do
-        # the device transfer and PyTorch optimizers don't support this.
-        config = {"mixed_precision": True, "cpu_offload": True, "move_grads_to_cpu": True}
+    # testing moving params to cpu while using full and mixed precision
+    @parameterized.expand([(True,), (False,)], name_func=rename_test)
+    def test_cpu_offload_and_cpu_grads(self, mixed_precision):
+        config = {"mixed_precision": mixed_precision, "cpu_offload": True}
         test_fn = functools.partial(
             self._test_identical_outputs, TransformerWithSharedParams, config, use_cuda=False, lr=0.01
         )
         spawn_and_init(test_fn)

-    def test_cpu_offload_and_cpu_grads_no_mixed_precision(self):
-        # We don't test the False condition because that requires the optimizer to internally do
-        # the device transfer and PyTorch optimizers don't support this.
-        config = {"mixed_precision": False, "cpu_offload": True, "move_grads_to_cpu": True}
+    # testing full and mixed precision on the gpu
+    @parameterized.expand([(True,), (False,)], name_func=rename_test)
+    def test_no_cpu_offload_with_sharded_grad_scaler(self, mixed_precision):
+        config = {"mixed_precision": mixed_precision, "move_params_to_cpu": False}
         test_fn = functools.partial(
-            self._test_identical_outputs, TransformerWithSharedParams, config, use_cuda=False, lr=0.01
+            self._test_identical_outputs, TransformerWithSharedParams, config, use_cuda=True, lr=0.01
         )
         spawn_and_init(test_fn)
@@ -485,10 +492,10 @@ class TestSerialization(DistributedTest):
         optim.step()


+@pytest.mark.skipif(torch_version() < (1, 8, 0), reason="pytorch version >= 1.8.0 required")
 class TestHooks(DistributedTest):
     # Feel free to modify these tests as the implementation changes.
     # They aspire to make sure that backward hooks are registered and used
     @parameterized.expand([[True], [False]])
     def test_output_backward_hooks(self, cuda_first):
         fn = functools.partial(self._test_output_backward_hooks, cuda_first=cuda_first)
@@ -541,6 +548,7 @@ class TestHooks(DistributedTest):
         assert model._register_pre_backward_hooks.called


+@pytest.mark.skipif(torch_version() < (1, 8, 0), reason="pytorch version >= 1.8.0 required")
 class TestNoGrad(DistributedTest):
     @parameterized.expand(CONFIG_OPTIONS, name_func=rename_test)
     def test_transformer_parameterized(self, config):
@@ -568,6 +576,7 @@ class TestNoGrad(DistributedTest):
         assert objects_are_equal(ref_output, no_grad_output, raise_exception=True)


+@pytest.mark.skipif(torch_version() < (1, 8, 0), reason="pytorch version >= 1.8.0 required")
 class TestModuleProperties(DistributedTest):
     @parameterized.expand([[{"flatten_parameters": False}], [{"flatten_parameters": True}]], name_func=rename_test)
     def test_named_parameters(self, config):
...
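
The version guarding added above repeats across the remaining test files: the ShardedGradScaler import is wrapped in a torch_version() check so the module still imports under older PyTorch, and the affected test classes get a pytest skipif marker. As a standalone sketch of that pattern (the test function here is hypothetical, not part of this diff):

import pytest

from fairscale.utils import torch_version

# Import guard keeps the module importable under torch < 1.8.
if torch_version() >= (1, 8, 0):
    from fairscale.optim.grad_scaler import ShardedGradScaler


# Skip marker keeps the test itself from running under torch < 1.8.
@pytest.mark.skipif(torch_version() < (1, 8, 0), reason="pytorch version >= 1.8.0 required")
def test_scaler_is_available():
    assert ShardedGradScaler is not None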
@@ -7,8 +7,11 @@ import functools
 import unittest

 from parameterized import parameterized
+import pytest
 import torch.nn as nn

+from fairscale.utils import torch_version
+
 from .test_fsdp import (
     CONFIG_OPTIONS,
     DistributedTest,
@@ -19,6 +22,7 @@ from .test_fsdp import (
 )


+@pytest.mark.skipif(torch_version() < (1, 8, 0), reason="pytorch version >= 1.8.0 required")
 class TestApply(DistributedTest):
     @parameterized.expand(CONFIG_OPTIONS, name_func=rename_test)
     def test_transformer_weight_init(self, config):
...
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# pylint: disable=missing-module-docstring
# pylint: disable=missing-class-docstring
# pylint: disable=missing-function-docstring
""" Test FSDP with grad scaler. """
import os
import random
import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F
from fairscale.nn import FullyShardedDataParallel
from fairscale.optim.grad_scaler import ShardedGradScaler
from fairscale.utils.testing import skip_if_no_cuda
try:
from torch.cuda.amp import autocast
except ImportError:
# Older version doesn't support autocast. Skip this file.
pytestmark = pytest.mark.skip
# Mixed precision needs cuda.
@skip_if_no_cuda
def test_scaler_cpu_offload_breaks():
device = torch.device("cuda")
torch.cuda.set_device(0)
# Random port in case the next test run quickly, same port would cause conflict.
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = str(random.randint(2000, 3000))
torch.distributed.init_process_group(backend="nccl", rank=0, world_size=1)
try:
scaler = ShardedGradScaler()
model = FullyShardedDataParallel(nn.Linear(5, 5), cpu_offload=True, mixed_precision=True)
optim = torch.optim.SGD(model.parameters(), lr=1e-3)
input = torch.rand((1, 5), dtype=torch.float).to(device)
optim.zero_grad()
with autocast():
output = model(input)
loss = F.mse_loss(input, output)
scaler.scale(loss).backward()
# TODO (Min): Need to fix. Details in issue #421.
with pytest.raises(RuntimeError):
scaler.step(optim)
scaler.update()
finally:
# Clean-up is important or the next test in this file may fail to init the PG.
torch.distributed.destroy_process_group()
del os.environ["MASTER_ADDR"]
del os.environ["MASTER_PORT"]
@@ -35,7 +35,6 @@ from torch.optim import SGD
 from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
 from fairscale.nn.data_parallel import TrainingState, auto_wrap_bn
-from fairscale.optim.grad_scaler import ShardedGradScaler
 from fairscale.utils import torch_version
 from fairscale.utils.testing import (
     dist_init,
@@ -47,6 +46,9 @@ from fairscale.utils.testing import (
     torch_cuda_version,
 )

+if torch_version() >= (1, 8, 0):
+    from fairscale.optim.grad_scaler import ShardedGradScaler
+
 # Const test params.
 # Reduce iterations to 1 for debugging.
 # Change world_size to 8 on beefy machines for better test coverage.
@@ -352,8 +354,8 @@ def _distributed_worker(
 @pytest.mark.parametrize("flatten", ["flatten", "no_flatten"])
 @pytest.mark.parametrize("sync_bn", ["none", "pytorch"])
 def test_regnet(temp_files, ddp_ref, precision, flatten, sync_bn):
-    if torch_version() < (1, 6, 0):
-        pytest.skip("older pytorch doesn't support reduce_scatter")
+    if torch_version() < (1, 8, 0):
+        pytest.skip("pytorch version >= 1.8.0 required")
     state_before, inputs, conv_bias, linear_bias, state_after = ddp_ref
...
@@ -7,10 +7,12 @@ import functools
 import unittest

 from parameterized import parameterized
+import pytest
 import torch
 from torch import nn

 from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
+from fairscale.utils import torch_version
 from fairscale.utils.testing import dist_init, objects_are_equal, skip_if_cuda, teardown, temp_files_ctx

 from .test_fsdp import (
@@ -23,6 +25,7 @@ from .test_fsdp import (
 )


+@pytest.mark.skipif(torch_version() < (1, 8, 0), reason="pytorch version >= 1.8.0 required")
 class TestLocalStateDict(DistributedTest):
     @parameterized.expand([[True, True], [False, False]], name_func=rename_test)
     def test_load_local_state_dict(self, flatten_params, mixed_precision):
@@ -50,7 +53,9 @@ class TestLocalStateDict(DistributedTest):
         state_1_module_weight = model.module.state_dict()[weight_key]
         torch.testing.assert_allclose(state_1_weight, state_1_module_weight)
         torch.testing.assert_allclose(state_1_weight, model.module.embed_tokens.weight)
-        self._train_for_several_steps(model, 1, model.mixed_precision)
+        # increasing number of epochs from 1 to 6 for ShardedGradScaler to work properly.
+        # test fails for num_epochs < 6 since the updates are skipped due to gradient being inf.
+        self._train_for_several_steps(model, 6, model.mixed_precision)
         state_2 = model.local_state_dict()
         state_after_training = {k: v.cpu().clone() for k, v in state_2.items()}
@@ -69,6 +74,7 @@ class TestLocalStateDict(DistributedTest):
         raise AssertionError(f"params {unchanged} not changed after training")


+@pytest.mark.skipif(torch_version() < (1, 8, 0), reason="pytorch version >= 1.8.0 required")
 class TestSaveLoadStateDict(DistributedTest):
     @parameterized.expand([[False], [True]], name_func=rename_test)
     def test_calling_state_dict_twice_mixed_precision(self, mixed_precision):
@@ -178,6 +184,7 @@ class TestSaveLoadStateDict(DistributedTest):
         ), f"{key}, {ref_state_dict[key]} != {state_dict[key]}"


+@pytest.mark.skipif(torch_version() < (1, 8, 0), reason="pytorch version >= 1.8.0 required")
 class TestStateDictDeviceDtype(DistributedTest):
     @parameterized.expand([[False, False], [True, False], [True, True]], name_func=rename_test)
     def test_state_dict_device(self, mixed_precision, cpu_offload):
...
@@ -8,8 +8,11 @@ import gc
 import unittest

 from parameterized import parameterized
+import pytest
 import torch

+from fairscale.utils.version import torch_version
+
 from .test_fsdp import CONFIG_OPTIONS, DistributedTest, rename_test, spawn_and_init
@@ -19,6 +22,7 @@ def get_cuda_mem():
     return torch.cuda.memory_allocated()


+@pytest.mark.skipif(torch_version() < (1, 8, 0), reason="pytorch version >= 1.8.0 required")
 class TestMemory(DistributedTest):
     @parameterized.expand(CONFIG_OPTIONS, name_func=rename_test)
     def test_memory(self, config):
...
@@ -21,10 +21,12 @@ from torch.nn.parallel import DistributedDataParallel as DDP
 from fairscale.nn.data_parallel import ShardedDataParallel
 from fairscale.optim import OSS
-from fairscale.optim.grad_scaler import ShardedGradScaler
 from fairscale.utils import torch_version
 from fairscale.utils.testing import check_same_model_params, skip_if_no_cuda, skip_if_single_gpu, temp_files_ctx

+if torch_version() >= (1, 8, 0):
+    from fairscale.optim.grad_scaler import ShardedGradScaler
+
 """
 Check that ShardedDDP gets the same results as DDP in a variety of scenarii
 """
@@ -249,6 +251,8 @@ def test_ddp_parity(
     manual_reduction,
     multiple_fw,
 ):
+    if torch_version() < (1, 8, 0):
+        pytest.skip("pytorch version >= 1.8.0 required")
     if manual_reduction and change_train_graph:
         pytest.skip("Skipping changing model and grad accumulation combination, makes little sense")
...