Unverified Commit 8c8a625a authored by Benjamin Lefaudeux, committed by GitHub

[test][minor] Improving SDP test coverage (#639)

* Improving test coverage on SDP
* Using the pytest exception catcher (pytest.raises, sketched below)
parent 21cba91b
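
The "pytest exception catcher" mentioned above is pytest.raises, which the new test_catch_grad_grad below uses as a context manager. A minimal, self-contained sketch of the pattern (the helper function and exception type here are illustrative, not part of this diff):

```python
import pytest


def divide(a: float, b: float) -> float:
    # Hypothetical helper, used only to demonstrate pytest.raises.
    return a / b


def test_divide_by_zero():
    # The test passes only if the guarded block raises the expected exception type.
    with pytest.raises(ZeroDivisionError):
        divide(1.0, 0.0)
```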
@@ -40,6 +40,6 @@ repos:
         additional_dependencies: [toml]
 -   repo: https://github.com/pre-commit/mirrors-mypy
-    rev: 'v0.770'
+    rev: 'v0.790'
     hooks:
     - id: mypy
@@ -256,9 +256,11 @@ class ShardedDataParallel(nn.Module):
             Module: self.
         """
-        assert device in self._buckets.keys(), "Changing devices is not supported, because this would break OSSs state"
         assert (
-            len(self._buckets.keys()) == 1
+            len(self._buckets.keys()) == 0 or device in self._buckets.keys()
+        ), "Changing devices is not supported, because this would break OSSs state"
+        assert (
+            len(self._buckets.keys()) < 2
         ), "Several devices specified to begin with, incompatible with setting a single device here"
         for _device in self._buckets.keys():
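For readers skimming the hunk above: the single-device assertion is split in two. Moving the wrapped module is now allowed when no buckets exist yet or when the target device already hosts them, while buckets spread over several devices still make selecting a single device illegal. A standalone sketch of that logic, assuming self._buckets maps devices to bucket lists (the helper name is hypothetical):

```python
from typing import Any, Dict, List

import torch


def check_single_device_move(buckets: Dict[torch.device, List[Any]], device: torch.device) -> None:
    # OK if nothing has been bucketed yet, or the buckets already live on the target device.
    assert (
        len(buckets) == 0 or device in buckets
    ), "Changing devices is not supported, because this would break OSSs state"
    # Buckets on several devices cannot be collapsed onto a single one.
    assert len(buckets) < 2, "Several devices specified to begin with, incompatible with setting a single device here"


# No buckets yet: any target device passes the checks.
check_single_device_move({}, torch.device("cpu"))
```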
@@ -3,7 +3,7 @@
 # This source code is licensed under the BSD license found in the
 # LICENSE file in the root directory of this source tree.
-import collections
+from collections import abc
 import io
 from math import inf
 from typing import Any, Callable, Dict, List, Optional
@@ -46,7 +46,7 @@ def recursive_copy_to_device(value: Any, non_blocking: bool, device: torch.devic
         return values if isinstance(value, list) else tuple(values)
-    if isinstance(value, collections.abc.Mapping):
+    if isinstance(value, abc.Mapping):
         device_val: Dict[str, Any] = {}
         for key, val in value.items():
             device_val[key] = recursive_copy_to_device(val, non_blocking=non_blocking, device=device)
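The import change above is purely cosmetic: abc obtained via "from collections import abc" is the same module as collections.abc, so the Mapping check in recursive_copy_to_device is unchanged in behaviour. A quick sanity sketch:

```python
import collections
from collections import abc

# Both spellings resolve to the same abstract base class.
assert abc.Mapping is collections.abc.Mapping

# Plain dicts are Mappings, so the copy helper's dict branch still triggers;
# lists are not, and fall through to the sequence branch instead.
assert isinstance({"weight": 1.0}, abc.Mapping)
assert not isinstance([1.0, 2.0], abc.Mapping)
```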
@@ -72,6 +72,7 @@ def run_one_step(
     grad_accumulation,
     reduce_buffer_size,
     optimizer_type,
+    reduce_fp16=False,
 ):
     dist.init_process_group(init_method="file://" + temp_file_name, backend=backend, rank=rank, world_size=world_size)
     if device == torch.device("cuda"):
@@ -93,7 +94,11 @@ def run_one_step(
     optimizer = OSS(params=model.parameters(), optim=optimizer_type, **optimizer_settings)
     ddp_model = ShardedDataParallel(
-        model, optimizer, broadcast_buffers=broadcast_buffers, reduce_buffer_size=reduce_buffer_size
+        model,
+        optimizer,
+        broadcast_buffers=broadcast_buffers,
+        reduce_buffer_size=reduce_buffer_size,
+        reduce_fp16=reduce_fp16,
     )
     # The model should be synchronized in between the ranks at ShardedDataParallel construction time, check that
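The new reduce_fp16 flag is threaded from the test parametrization into the ShardedDataParallel constructor; as the name suggests, it requests that the cross-rank gradient reduction be done in fp16. A minimal single-rank sketch of the option being exercised (CPU, gloo, world size 1; import paths assume the usual fairscale layout, and numerical details of the fp16 reduction are not spelled out in this diff):

```python
import tempfile

import torch
import torch.distributed as dist
from torch.nn import Linear, Sequential

from fairscale.nn.data_parallel import ShardedDataParallel
from fairscale.optim import OSS

# Single-process "distributed" group, mirroring the unit tests above.
dist.init_process_group(init_method="file://" + tempfile.mkstemp()[1], backend="gloo", rank=0, world_size=1)

model = Sequential(Linear(2, 3), Linear(3, 1))
optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3)
ddp_model = ShardedDataParallel(model, optimizer, reduce_fp16=True)  # gradients reduced in half precision

ddp_model(torch.rand(8, 2)).sum().backward()
optimizer.step()

dist.destroy_process_group()
```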
@@ -144,6 +149,7 @@ def run_test(backend, device, world_size, broadcast_buffers, grad_accumulation,
 @pytest.mark.parametrize("grad_accumulation", [True, False])
 @pytest.mark.parametrize("reduce_buffer_size", [0, 2 ** 20])
 @pytest.mark.parametrize("optimizer_type", [torch.optim.SGD, SGDWithPausingCompute])
+@pytest.mark.parametrize("reduce_fp16", [False, True])
 @pytest.mark.parametrize(
     "setup",
     [
@@ -152,7 +158,7 @@ def run_test(backend, device, world_size, broadcast_buffers, grad_accumulation,
         [dist.Backend.GLOO, torch.device("cuda")],
     ],
 )
-def test_step(broadcast_buffers, grad_accumulation, reduce_buffer_size, optimizer_type, setup):
+def test_step(broadcast_buffers, grad_accumulation, reduce_buffer_size, optimizer_type, reduce_fp16, setup):
     world_size = 2
     temp_file_name = tempfile.mkstemp()[1]
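Since stacked @pytest.mark.parametrize decorators collect the Cartesian product of their values, the added two-valued reduce_fp16 axis doubles the number of test_step variants pytest generates. A toy illustration of that collection behaviour (parameter names here are illustrative):

```python
import pytest


@pytest.mark.parametrize("reduce_fp16", [False, True])
@pytest.mark.parametrize("reduce_buffer_size", [0, 2 ** 20])
def test_axes_multiply(reduce_fp16, reduce_buffer_size):
    # pytest collects 2 x 2 = 4 cases, one per (reduce_fp16, reduce_buffer_size) pair.
    assert isinstance(reduce_fp16, bool) and reduce_buffer_size >= 0
```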
@@ -167,6 +173,7 @@ def test_step(broadcast_buffers, grad_accumulation, reduce_buffer_size, optimize
             grad_accumulation,
             reduce_buffer_size,
             optimizer_type,
+            reduce_fp16,
         ),
         nprocs=world_size,
         join=True,
@@ -248,6 +255,26 @@ def test_random_attributes():
     dist.destroy_process_group()
 
 
+def test_catch_grad_grad():
+    # Check that ShardedDDP raises if a parameter's gradient itself requires a gradient
+    dist.init_process_group(init_method="file://" + tempfile.mkstemp()[1], backend="gloo", rank=0, world_size=1)
+    model = Sequential(Linear(2, 3), Linear(3, 3))
+    model.train()
+    chained_grad = torch.zeros_like(next(model.parameters()))
+    chained_grad.requires_grad = True
+    next(model.parameters()).grad = chained_grad
+
+    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
+    ddp_model = ShardedDataParallel(model, optimizer)
+
+    inputs = torch.rand(100, 2)
+    with pytest.raises(RuntimeError):
+        _ = ddp_model(inputs)
+
+    dist.destroy_process_group()
+
+
 def test_mixed_types():
     # Check that ShardedDDP exposes the original module's attributes
     dist.init_process_group(init_method="file://" + tempfile.mkstemp()[1], backend="gloo", rank=0, world_size=1)
@@ -312,6 +339,9 @@ def run_test_device_change(rank, world_size, backend, device, temp_file_name, re
     except AssertionError:
         pass
 
+    # Check that we can change the data type
+    ddp_model.to(device=torch.device("cpu"), dtype=torch.float16)
+
     dist.destroy_process_group()