Unverified commit ba5785f7 authored by Anupam Bhatnagar, committed by GitHub

Allow sharded grad scaler to cpu offload with FSDP (#831)

* first commit

* sharded scaler hitting nan assertions

* adding test for sharded grad scaler without cpu offload

* ddp grad scaler and fsdp sharded grad scaler test failing

* removing test_output

* fix no cpu offload test

* changing optimizer from OSS to SGD

* all tests passing, code cleanup pending

* code cleanup

* fix pyproject.toml

* removing .isort.cfg

* running isort linter

* resolving isort issues

* resolving black linter issue

* resolving mypy issues

* fix import statement

* fix mypy error

* modifying import statement

* adding pytorch version requirement

* fixing pytest skip test decorator

* apply version guard for ShardedGradScaler

* removing test_fsdp_grad_scaler

* increasing num_epochs for ShardedGradScaler so that updates are not skipped

* adding support for torch 1.8

* minor edit

* [skip ci] more torch 1.8 changes

* parametrizing the tests

* cleanup code with linters

* [skip ci] update doc string

* [skip ci] addressing some more comments
parent 7d7edf6d
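
For orientation, here is the end-to-end pattern this change targets: a ShardedGradScaler driving an FSDP-wrapped model whose parameters and gradients are offloaded to CPU. This is a minimal sketch, not code from the diff; it assumes a single-GPU process group is already initialized, and the model, shapes and hyper-parameters are illustrative only.

import torch
import torch.nn as nn
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
from fairscale.optim.grad_scaler import ShardedGradScaler

# Assumes torch.distributed.init_process_group(...) has already been called
# and a CUDA device is set; cpu_offload keeps the sharded params on CPU.
model = FSDP(nn.Linear(5, 5), mixed_precision=True, cpu_offload=True, move_grads_to_cpu=True)
optim = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
scaler = ShardedGradScaler()

for _ in range(6):  # a few steps; random inputs purely for illustration
    inp = torch.rand(8, 5, device="cuda")
    optim.zero_grad()
    with torch.cuda.amp.autocast():
        loss = model(inp).sum()
    scaler.scale(loss).backward()  # backward runs on the scaled loss
    scaler.step(optim)             # skipped internally if any grad is inf/nan
    scaler.update()                # adjusts the scale for the next iteration
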
@@ -21,8 +21,8 @@ Device = Union[torch.device, int, str]
def check_pytorch_version() -> None:
if torch_version() < (1, 9, 0):
raise Exception("DistributedPipeline requires PyTorch version 1.9 or higher")
if torch_version() < (1, 8, 0):
raise Exception("DistributedPipeline requires PyTorch version 1.8 or higher")
MOVING_DENIED = TypeError(
This diff is collapsed.
@@ -3,7 +3,6 @@ tests/nn/data_parallel/test_fsdp_multiple_wrapping.py
tests/nn/data_parallel/test_fsdp_freezing_weights.py
tests/nn/data_parallel/test_fsdp_regnet.py
tests/nn/data_parallel/test_fsdp_uneven.py
tests/nn/data_parallel/test_fsdp_grad_scaler.py
tests/nn/data_parallel/test_fsdp_grad_acc.py
tests/nn/data_parallel/test_fsdp_summon_full_params.py
tests/nn/data_parallel/test_fsdp_input.py
@@ -13,6 +13,7 @@ import unittest
from unittest import mock
from parameterized import parameterized
import pytest
import torch
from torch import nn
import torch.distributed
@@ -29,6 +30,9 @@ from fairscale.utils.testing import (
spawn_for_all_world_sizes,
)
if torch_version() >= (1, 8, 0):
from fairscale.optim.grad_scaler import ShardedGradScaler
# How to use remote-pdb: https://gist.github.com/sshleifer/9d43351957179c13606e015b072927d4
# All helper functions called by spawn must be either @classmethod, @staticmethod
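
As an aside on the guarded import above: the commit notes mention mypy fixes, and one general way (not necessarily what this commit does) to keep a version-guarded import visible to the type checker while still skipping it at runtime on older torch is:

from typing import TYPE_CHECKING

from fairscale.utils import torch_version

if TYPE_CHECKING:
    # mypy always takes this branch, so the name is known to the type checker.
    from fairscale.optim.grad_scaler import ShardedGradScaler
elif torch_version() >= (1, 8, 0):
    # At runtime the import only happens on torch >= 1.8.
    from fairscale.optim.grad_scaler import ShardedGradScaler
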
@@ -49,7 +53,9 @@ class DistributedTest(unittest.TestCase):
model_device = next(model.parameters()).device
# use SGD with momentum instead of Adam, since Adam is scale invariant
# and this makes it bad for tests
optim = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
optim = torch.optim.SGD(params=model.parameters(), lr=lr, momentum=0.9)
scaler = ShardedGradScaler()
for _ in range(num_steps):
optim.zero_grad()
with torch.cuda.amp.autocast(enabled=autocast):
@@ -57,6 +63,7 @@ class DistributedTest(unittest.TestCase):
input = model.module.get_input(torch.device("cuda"))
output = model(*input)
loss = model.module.get_loss(input, output).to(model_device)
loss = scaler.scale(loss)
assert loss.dtype == torch.float32
model.module.run_backward(loss)
if norm_type is not None:
@@ -65,10 +72,10 @@
model.clip_grad_norm_(clip_norm, norm_type)
else:
torch.nn.utils.clip_grad_norm_(model.parameters(), clip_norm, norm_type)
params = [p for p in model.parameters()]
print(f"params.device {params[0].device} param.grad.device {params[0].grad.device}")
optim.step()
scaler.step(optim)
scaler.update()
if hasattr(model, "assert_idle"):
model.assert_idle()
if isinstance(model, FullyShardedDataParallel):
model.assert_state(TrainingState.IDLE)
return loss.detach()
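
One aside on the clipping path above: when a grad scaler and gradient clipping are combined, the usual torch.cuda.amp recipe is to unscale the gradients before clipping, so the threshold applies to their true magnitudes. A sketch, assuming the scaler exposes GradScaler's unscale_ method; the helper name clipped_scaler_step is made up for illustration:

import torch

def clipped_scaler_step(scaler, model, optim, loss, clip_norm=0.25):
    scaler.scale(loss).backward()
    scaler.unscale_(optim)      # grads are now at their true magnitude, in place
    torch.nn.utils.clip_grad_norm_(model.parameters(), clip_norm)
    scaler.step(optim)          # still skipped if an inf/nan gradient is found
    scaler.update()

For an FSDP-wrapped model the diff itself uses model.clip_grad_norm_(...), which accounts for the sharded parameters.
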
@@ -308,21 +315,21 @@ class TestComparisonToPyTorchDDP(DistributedTest):
# Test every combination of these options:
spawn_and_init(functools.partial(self._test_identical_outputs, TransformerWithSharedParams, config))
def test_cpu_offload_and_cpu_grads(self):
# We don't test the False condition because that requires the optimizer to internally do
# the device transfer and PyTorch optimizers don't support this.
config = {"mixed_precision": True, "cpu_offload": True, "move_grads_to_cpu": True}
# testing moving params to cpu while using full and mixed precision
@parameterized.expand([(True,), (False,)], name_func=rename_test)
def test_cpu_offload_and_cpu_grads(self, mixed_precision):
config = {"mixed_precision": mixed_precision, "cpu_offload": True}
test_fn = functools.partial(
self._test_identical_outputs, TransformerWithSharedParams, config, use_cuda=False, lr=0.01
)
spawn_and_init(test_fn)
def test_cpu_offload_and_cpu_grads_no_mixed_precision(self):
# We don't test the False condition because that requires the optimizer to internally do
# the device transfer and PyTorch optimizers don't support this.
config = {"mixed_precision": False, "cpu_offload": True, "move_grads_to_cpu": True}
# testing full and mixed precision on the gpu
@parameterized.expand([(True,), (False,)], name_func=rename_test)
def test_no_cpu_offload_with_sharded_grad_scaler(self, mixed_precision):
config = {"mixed_precision": mixed_precision, "move_params_to_cpu": False}
test_fn = functools.partial(
self._test_identical_outputs, TransformerWithSharedParams, config, use_cuda=False, lr=0.01
self._test_identical_outputs, TransformerWithSharedParams, config, use_cuda=True, lr=0.01
)
spawn_and_init(test_fn)
@@ -485,10 +492,10 @@ class TestSerialization(DistributedTest):
optim.step()
@pytest.mark.skipif(torch_version() < (1, 8, 0), reason="pytorch version >= 1.8.0 required")
class TestHooks(DistributedTest):
# Feel free to modify these tests as the implementation changes.
# They aspire to make sure that backward hooks are registered and used
@parameterized.expand([[True], [False]])
def test_output_backward_hooks(self, cuda_first):
fn = functools.partial(self._test_output_backward_hooks, cuda_first=cuda_first)
@@ -541,6 +548,7 @@ class TestHooks(DistributedTest):
assert model._register_pre_backward_hooks.called
@pytest.mark.skipif(torch_version() < (1, 8, 0), reason="pytorch version >= 1.8.0 required")
class TestNoGrad(DistributedTest):
@parameterized.expand(CONFIG_OPTIONS, name_func=rename_test)
def test_transformer_parameterized(self, config):
@@ -568,6 +576,7 @@ class TestNoGrad(DistributedTest):
assert objects_are_equal(ref_output, no_grad_output, raise_exception=True)
@pytest.mark.skipif(torch_version() < (1, 8, 0), reason="pytorch version >= 1.8.0 required")
class TestModuleProperties(DistributedTest):
@parameterized.expand([[{"flatten_parameters": False}], [{"flatten_parameters": True}]], name_func=rename_test)
def test_named_parameters(self, config):
@@ -7,8 +7,11 @@ import functools
import unittest
from parameterized import parameterized
import pytest
import torch.nn as nn
from fairscale.utils import torch_version
from .test_fsdp import (
CONFIG_OPTIONS,
DistributedTest,
@@ -19,6 +22,7 @@ from .test_fsdp import (
)
@pytest.mark.skipif(torch_version() < (1, 8, 0), reason="pytorch version >= 1.8.0 required")
class TestApply(DistributedTest):
@parameterized.expand(CONFIG_OPTIONS, name_func=rename_test)
def test_transformer_weight_init(self, config):
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# pylint: disable=missing-module-docstring
# pylint: disable=missing-class-docstring
# pylint: disable=missing-function-docstring
""" Test FSDP with grad scaler. """
import os
import random
import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F
from fairscale.nn import FullyShardedDataParallel
from fairscale.optim.grad_scaler import ShardedGradScaler
from fairscale.utils.testing import skip_if_no_cuda
try:
from torch.cuda.amp import autocast
except ImportError:
# Older version doesn't support autocast. Skip this file.
pytestmark = pytest.mark.skip
# Mixed precision needs cuda.
@skip_if_no_cuda
def test_scaler_cpu_offload_breaks():
device = torch.device("cuda")
torch.cuda.set_device(0)
# Random port in case the next test run quickly, same port would cause conflict.
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = str(random.randint(2000, 3000))
torch.distributed.init_process_group(backend="nccl", rank=0, world_size=1)
try:
scaler = ShardedGradScaler()
model = FullyShardedDataParallel(nn.Linear(5, 5), cpu_offload=True, mixed_precision=True)
optim = torch.optim.SGD(model.parameters(), lr=1e-3)
input = torch.rand((1, 5), dtype=torch.float).to(device)
optim.zero_grad()
with autocast():
output = model(input)
loss = F.mse_loss(input, output)
scaler.scale(loss).backward()
# TODO (Min): Need to fix. Details in issue #421.
with pytest.raises(RuntimeError):
scaler.step(optim)
scaler.update()
finally:
# Clean-up is important or the next test in this file may fail to init the PG.
torch.distributed.destroy_process_group()
del os.environ["MASTER_ADDR"]
del os.environ["MASTER_PORT"]
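
A side note on the MASTER_PORT setup in the test above: picking a random port in a fixed range can still collide across quick, successive runs. A common alternative (a general pattern, not something this commit changes; find_free_port is a hypothetical helper name) is to let the OS hand out an unused port:

import os
import socket

def find_free_port() -> int:
    # Bind to port 0 so the OS picks an unused TCP port, then reuse that number.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("localhost", 0))
        return s.getsockname()[1]

os.environ["MASTER_PORT"] = str(find_free_port())

There is a small window between closing the probe socket and the process group binding the port, but that is usually acceptable for unit tests.
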
@@ -35,7 +35,6 @@ from torch.optim import SGD
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
from fairscale.nn.data_parallel import TrainingState, auto_wrap_bn
from fairscale.optim.grad_scaler import ShardedGradScaler
from fairscale.utils import torch_version
from fairscale.utils.testing import (
dist_init,
@@ -47,6 +46,9 @@ from fairscale.utils.testing import (
torch_cuda_version,
)
if torch_version() >= (1, 8, 0):
from fairscale.optim.grad_scaler import ShardedGradScaler
# Const test params.
# Reduce iterations to 1 for debugging.
# Change world_size to 8 on beefy machines for better test coverage.
@@ -352,8 +354,8 @@ def _distributed_worker(
@pytest.mark.parametrize("flatten", ["flatten", "no_flatten"])
@pytest.mark.parametrize("sync_bn", ["none", "pytorch"])
def test_regnet(temp_files, ddp_ref, precision, flatten, sync_bn):
if torch_version() < (1, 6, 0):
pytest.skip("older pytorch doesn't support reduce_scatter")
if torch_version() < (1, 8, 0):
pytest.skip("pytorch version >= 1.8.0 required")
state_before, inputs, conv_bias, linear_bias, state_after = ddp_ref
@@ -7,10 +7,12 @@ import functools
import unittest
from parameterized import parameterized
import pytest
import torch
from torch import nn
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
from fairscale.utils import torch_version
from fairscale.utils.testing import dist_init, objects_are_equal, skip_if_cuda, teardown, temp_files_ctx
from .test_fsdp import (
@@ -23,6 +25,7 @@ from .test_fsdp import (
)
@pytest.mark.skipif(torch_version() < (1, 8, 0), reason="pytorch version >= 1.8.0 required")
class TestLocalStateDict(DistributedTest):
@parameterized.expand([[True, True], [False, False]], name_func=rename_test)
def test_load_local_state_dict(self, flatten_params, mixed_precision):
@@ -50,7 +53,9 @@ class TestLocalStateDict(DistributedTest):
state_1_module_weight = model.module.state_dict()[weight_key]
torch.testing.assert_allclose(state_1_weight, state_1_module_weight)
torch.testing.assert_allclose(state_1_weight, model.module.embed_tokens.weight)
self._train_for_several_steps(model, 1, model.mixed_precision)
# increasing number of epochs from 1 to 6 for ShardedGradScaler to work properly.
# test fails for num_epochs < 6 since the updates are skipped due to gradient being inf.
self._train_for_several_steps(model, 6, model.mixed_precision)
state_2 = model.local_state_dict()
state_after_training = {k: v.cpu().clone() for k, v in state_2.items()}
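
For context on the 1 → 6 change above: torch.cuda.amp.GradScaler, which ShardedGradScaler derives from, starts at a large init_scale (65536.0 by default) and, whenever a scaled gradient overflows to inf/nan, skips that optimizer step and multiplies the scale by backoff_factor (0.5 by default). With a single step, every update can therefore be skipped and the state dict never changes. Lowering the initial scale is one way to shorten that warm-up; whether ShardedGradScaler forwards these constructor arguments depends on the fairscale version, so this sketch uses the stock scaler:

from torch.cuda.amp import GradScaler

# Defaults: init_scale=2.**16, growth_factor=2.0, backoff_factor=0.5, growth_interval=2000.
# A smaller init_scale means fewer skipped steps before the scale settles.
scaler = GradScaler(init_scale=2.0**8, backoff_factor=0.5, growth_interval=100)
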
@@ -69,6 +74,7 @@
raise AssertionError(f"params {unchanged} not changed after training")
@pytest.mark.skipif(torch_version() < (1, 8, 0), reason="pytorch version >= 1.8.0 required")
class TestSaveLoadStateDict(DistributedTest):
@parameterized.expand([[False], [True]], name_func=rename_test)
def test_calling_state_dict_twice_mixed_precision(self, mixed_precision):
@@ -178,6 +184,7 @@ class TestSaveLoadStateDict(DistributedTest):
), f"{key}, {ref_state_dict[key]} != {state_dict[key]}"
@pytest.mark.skipif(torch_version() < (1, 8, 0), reason="pytorch version >= 1.8.0 required")
class TestStateDictDeviceDtype(DistributedTest):
@parameterized.expand([[False, False], [True, False], [True, True]], name_func=rename_test)
def test_state_dict_device(self, mixed_precision, cpu_offload):
@@ -8,8 +8,11 @@ import gc
import unittest
from parameterized import parameterized
import pytest
import torch
from fairscale.utils.version import torch_version
from .test_fsdp import CONFIG_OPTIONS, DistributedTest, rename_test, spawn_and_init
@@ -19,6 +22,7 @@ def get_cuda_mem():
return torch.cuda.memory_allocated()
@pytest.mark.skipif(torch_version() < (1, 8, 0), reason="pytorch version >= 1.8.0 required")
class TestMemory(DistributedTest):
@parameterized.expand(CONFIG_OPTIONS, name_func=rename_test)
def test_memory(self, config):
@@ -21,10 +21,12 @@ from torch.nn.parallel import DistributedDataParallel as DDP
from fairscale.nn.data_parallel import ShardedDataParallel
from fairscale.optim import OSS
from fairscale.optim.grad_scaler import ShardedGradScaler
from fairscale.utils import torch_version
from fairscale.utils.testing import check_same_model_params, skip_if_no_cuda, skip_if_single_gpu, temp_files_ctx
if torch_version() >= (1, 8, 0):
from fairscale.optim.grad_scaler import ShardedGradScaler
"""
Check that ShardedDDP gets the same results as DDP in a variety of scenarii
"""
@@ -249,6 +251,8 @@ def test_ddp_parity(
manual_reduction,
multiple_fw,
):
if torch_version() < (1, 8, 0):
pytest.skip("pytorch version >= 1.8.0 required")
if manual_reduction and change_train_graph:
pytest.skip("Skipping changing model and grad accumulation combination, makes little sense")