Unverified Commit 8c8a625a authored by Benjamin Lefaudeux's avatar Benjamin Lefaudeux Committed by GitHub
Browse files

[test][minor] Improving SDP test coverage (#639)

* Improving test coverage on SDP
* using pytest exception catcher
parent 21cba91b
...@@ -40,6 +40,6 @@ repos: ...@@ -40,6 +40,6 @@ repos:
additional_dependencies: [toml] additional_dependencies: [toml]
- repo: https://github.com/pre-commit/mirrors-mypy - repo: https://github.com/pre-commit/mirrors-mypy
rev: 'v0.770' rev: 'v0.790'
hooks: hooks:
- id: mypy - id: mypy
...@@ -256,9 +256,11 @@ class ShardedDataParallel(nn.Module): ...@@ -256,9 +256,11 @@ class ShardedDataParallel(nn.Module):
Module: self. Module: self.
""" """
assert device in self._buckets.keys(), "Changing devices is not supported, because this would break OSSs state"
assert ( assert (
len(self._buckets.keys()) == 1 len(self._buckets.keys()) == 0 or device in self._buckets.keys()
), "Changing devices is not supported, because this would break OSSs state"
assert (
len(self._buckets.keys()) < 2
), "Several devices specified to begin with, incompatible with setting a single device here" ), "Several devices specified to begin with, incompatible with setting a single device here"
for _device in self._buckets.keys(): for _device in self._buckets.keys():
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
# This source code is licensed under the BSD license found in the # This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
import collections from collections import abc
import io import io
from math import inf from math import inf
from typing import Any, Callable, Dict, List, Optional from typing import Any, Callable, Dict, List, Optional
...@@ -46,7 +46,7 @@ def recursive_copy_to_device(value: Any, non_blocking: bool, device: torch.devic ...@@ -46,7 +46,7 @@ def recursive_copy_to_device(value: Any, non_blocking: bool, device: torch.devic
return values if isinstance(value, list) else tuple(values) return values if isinstance(value, list) else tuple(values)
if isinstance(value, collections.abc.Mapping): if isinstance(value, abc.Mapping):
device_val: Dict[str, Any] = {} device_val: Dict[str, Any] = {}
for key, val in value.items(): for key, val in value.items():
device_val[key] = recursive_copy_to_device(val, non_blocking=non_blocking, device=device) device_val[key] = recursive_copy_to_device(val, non_blocking=non_blocking, device=device)
......
...@@ -72,6 +72,7 @@ def run_one_step( ...@@ -72,6 +72,7 @@ def run_one_step(
grad_accumulation, grad_accumulation,
reduce_buffer_size, reduce_buffer_size,
optimizer_type, optimizer_type,
reduce_fp16=False,
): ):
dist.init_process_group(init_method="file://" + temp_file_name, backend=backend, rank=rank, world_size=world_size) dist.init_process_group(init_method="file://" + temp_file_name, backend=backend, rank=rank, world_size=world_size)
if device == torch.device("cuda"): if device == torch.device("cuda"):
...@@ -93,7 +94,11 @@ def run_one_step( ...@@ -93,7 +94,11 @@ def run_one_step(
optimizer = OSS(params=model.parameters(), optim=optimizer_type, **optimizer_settings) optimizer = OSS(params=model.parameters(), optim=optimizer_type, **optimizer_settings)
ddp_model = ShardedDataParallel( ddp_model = ShardedDataParallel(
model, optimizer, broadcast_buffers=broadcast_buffers, reduce_buffer_size=reduce_buffer_size model,
optimizer,
broadcast_buffers=broadcast_buffers,
reduce_buffer_size=reduce_buffer_size,
reduce_fp16=reduce_fp16,
) )
# The model should be synchronized in between the ranks at ShardedDataParallel construction time, check that # The model should be synchronized in between the ranks at ShardedDataParallel construction time, check that
...@@ -144,6 +149,7 @@ def run_test(backend, device, world_size, broadcast_buffers, grad_accumulation, ...@@ -144,6 +149,7 @@ def run_test(backend, device, world_size, broadcast_buffers, grad_accumulation,
@pytest.mark.parametrize("grad_accumulation", [True, False]) @pytest.mark.parametrize("grad_accumulation", [True, False])
@pytest.mark.parametrize("reduce_buffer_size", [0, 2 ** 20]) @pytest.mark.parametrize("reduce_buffer_size", [0, 2 ** 20])
@pytest.mark.parametrize("optimizer_type", [torch.optim.SGD, SGDWithPausingCompute]) @pytest.mark.parametrize("optimizer_type", [torch.optim.SGD, SGDWithPausingCompute])
@pytest.mark.parametrize("reduce_fp16", [False, True])
@pytest.mark.parametrize( @pytest.mark.parametrize(
"setup", "setup",
[ [
...@@ -152,7 +158,7 @@ def run_test(backend, device, world_size, broadcast_buffers, grad_accumulation, ...@@ -152,7 +158,7 @@ def run_test(backend, device, world_size, broadcast_buffers, grad_accumulation,
[dist.Backend.GLOO, torch.device("cuda")], [dist.Backend.GLOO, torch.device("cuda")],
], ],
) )
def test_step(broadcast_buffers, grad_accumulation, reduce_buffer_size, optimizer_type, setup): def test_step(broadcast_buffers, grad_accumulation, reduce_buffer_size, optimizer_type, reduce_fp16, setup):
world_size = 2 world_size = 2
temp_file_name = tempfile.mkstemp()[1] temp_file_name = tempfile.mkstemp()[1]
...@@ -167,6 +173,7 @@ def test_step(broadcast_buffers, grad_accumulation, reduce_buffer_size, optimize ...@@ -167,6 +173,7 @@ def test_step(broadcast_buffers, grad_accumulation, reduce_buffer_size, optimize
grad_accumulation, grad_accumulation,
reduce_buffer_size, reduce_buffer_size,
optimizer_type, optimizer_type,
reduce_fp16,
), ),
nprocs=world_size, nprocs=world_size,
join=True, join=True,
...@@ -248,6 +255,26 @@ def test_random_attributes(): ...@@ -248,6 +255,26 @@ def test_random_attributes():
dist.destroy_process_group() dist.destroy_process_group()
def test_catch_grad_grad():
    """Check that ShardedDDP raises a RuntimeError in the forward pass when a
    parameter's ``.grad`` is itself a tensor that requires grad.

    NOTE(review): presumably the gradient-reduce hooks cannot handle
    second-order gradient graphs — confirm against ShardedDataParallel.
    """
    # Single-rank gloo group with file-based rendezvous: cheap to set up and tear down.
    dist.init_process_group(init_method="file://" + tempfile.mkstemp()[1], backend="gloo", rank=0, world_size=1)
    model = Sequential(Linear(2, 3), Linear(3, 3))
    model.train()
    # Attach a .grad that is itself part of an autograd graph to the first parameter.
    chained_grad = torch.zeros_like(next(model.parameters()))
    chained_grad.requires_grad = True
    next(model.parameters()).grad = chained_grad
    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
    ddp_model = ShardedDataParallel(model, optimizer)
    inputs = torch.rand(100, 2)
    # The forward pass is expected to reject the chained grad.
    with pytest.raises(RuntimeError):
        _ = ddp_model(inputs)
    dist.destroy_process_group()
def test_mixed_types(): def test_mixed_types():
# Check that ShardedDDP exposes the original module's attributes # Check that ShardedDDP exposes the original module's attributes
dist.init_process_group(init_method="file://" + tempfile.mkstemp()[1], backend="gloo", rank=0, world_size=1) dist.init_process_group(init_method="file://" + tempfile.mkstemp()[1], backend="gloo", rank=0, world_size=1)
...@@ -312,6 +339,9 @@ def run_test_device_change(rank, world_size, backend, device, temp_file_name, re ...@@ -312,6 +339,9 @@ def run_test_device_change(rank, world_size, backend, device, temp_file_name, re
except AssertionError: except AssertionError:
pass pass
# Check that we can change the data type
ddp_model.to(device=torch.device("cpu"), dtype=torch.float16)
dist.destroy_process_group() dist.destroy_process_group()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment