Unverified Commit 543d5693 authored by Benjamin Lefaudeux, committed by GitHub

[fix] OSS tests - remove concurrent dist inits (#177)

parent cc766aa5
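
The diff below replaces the hard-coded MASTER_ADDR/MASTER_PORT environment-variable setup with a per-test file:// rendezvous, so test processes that initialize torch.distributed at the same time no longer race for the same TCP port. A minimal sketch of that init pattern, assuming a throwaway file from tempfile.mkstemp(); the helper name file_store_init is illustrative and not taken from the diff:

import tempfile

import torch.distributed as dist


def file_store_init(rank, world_size, rendezvous_file, backend="gloo"):
    # A file:// rendezvous keys the process group off a unique temp file,
    # so concurrently running tests cannot collide on a fixed port.
    dist.init_process_group(
        init_method="file://" + rendezvous_file,
        backend=backend,
        rank=rank,
        world_size=world_size,
    )


if __name__ == "__main__":
    # Single-rank usage, mirroring what a per-test setUp would do.
    file_store_init(rank=0, world_size=1, rendezvous_file=tempfile.mkstemp()[1])
    dist.destroy_process_group()
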
@@ -7,7 +7,9 @@
 # pylint: disable=missing-class-docstring
 # pylint: disable=missing-function-docstring
-import os
+import tempfile
+import unittest
 import numpy as np
 import pytest
@@ -23,28 +25,27 @@ BACKEND = dist.Backend.NCCL if torch.cuda.is_available() else dist.Backend.GLOO
 DEVICE = "cuda" if torch.cuda.is_available() else torch.device("cpu")

-def setup_module(module):
-    os.environ["MASTER_ADDR"] = "localhost"
-    os.environ["MASTER_PORT"] = "29500"
-    dist.init_process_group(backend=BACKEND, rank=0, world_size=1)

-def teardown_module(module):
-    torch.distributed.destroy_process_group()

-def dist_init(rank, world_size):
-    os.environ["MASTER_ADDR"] = "localhost"
-    os.environ["MASTER_PORT"] = "29501"
-    dist.init_process_group(backend=BACKEND, rank=rank, world_size=world_size)
+def dist_init(rank, world_size, tempfile_name, backend=BACKEND):
+    url = "file://" + tempfile_name
+    dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)

+class TestSingleRank(unittest.TestCase):
+    """
+    All the following tests do not check for inter-process communication
+    """

+    def setUp(self):
+        dist_init(0, 1, tempfile.mkstemp()[1])

+    def tearDown(self):
+        torch.distributed.destroy_process_group()

-def test_create():
+    def test_create(self):
         params = [torch.rand(1)]
         o = optim.OSS(params, lr=0.01)

-def test_state_dict():
+    def test_state_dict(self):
         x = torch.tensor([1.0], device=DEVICE, requires_grad=True)
         o = optim.OSS([x], lr=0.1, momentum=0.9)
         x.backward()
@@ -88,8 +89,7 @@ def test_state_dict():
         # Check that the exposed param_groups are on the proper device
         assert o.param_groups[0]["params"][0].device == x.device

-def test_lr_scheduler():
+    def test_lr_scheduler(self):
         x = torch.tensor([1.0], device=DEVICE, requires_grad=True)
         x2 = torch.tensor([1.0], device=DEVICE, requires_grad=True)
         o = optim.OSS([x], lr=0.01)
@@ -107,8 +107,7 @@ def test_lr_scheduler():
         s2.step()
         assert x == x2

-def test_step_with_kwargs():
+    def test_step_with_kwargs(self):
         class SGDWithStepKWArg(torch.optim.SGD):
             def step(self, closure=None, kwarg=[]):
                 super().step()
@@ -122,8 +121,7 @@ def test_step_with_kwargs():
         assert kwarg == [5]
         assert x == torch.tensor([0.9], device=DEVICE)

-def test_step_with_extra_inner_key():
+    def test_step_with_extra_inner_key(self):
         class SGDWithNewKey(torch.optim.SGD):
             # Dummy optimizer which adds a new key to the param groups
             def step(self, closure=None):
@@ -137,8 +135,7 @@ def test_step_with_extra_inner_key():
         assert o.param_groups[0]["new_key"] == 0.1
         assert x == torch.tensor([0.9], device=DEVICE)

-def test_step_without_closure():
+    def test_step_without_closure(self):
         class SGDWithoutClosure(torch.optim.SGD):
             def step(self):
                 return super().step()
@@ -149,8 +146,7 @@ def test_step_without_closure():
         o.step()
         assert x == torch.tensor([0.9], device=DEVICE)

-def test_local_state_dict():
+    def test_local_state_dict(self):
         x = torch.tensor([1.0], device=DEVICE, requires_grad=True)
         o = optim.OSS([x], lr=0.1)
         local_state_dict = o.local_state_dict()
@@ -163,8 +159,7 @@ def test_local_state_dict():
         o.step()
         assert x == torch.tensor([0.9], device=DEVICE)

-def test_implicit_local_state_dict():
+    def test_implicit_local_state_dict(self):
         x = torch.tensor([1.0], device=DEVICE, requires_grad=True)
         o = optim.OSS([x], lr=0.1)
         local_state_dict = o.state_dict()
@@ -178,8 +173,8 @@ def test_implicit_local_state_dict():
         assert x == torch.tensor([0.9], device=DEVICE)

-def run_test_add_param_group(rank, world_size):
-    dist_init(rank, world_size)
+def run_test_add_param_group(rank, world_size, tempfile_name):
+    dist_init(rank, world_size, tempfile_name)
     params = []
     for size in [4, 5, 2, 6, 4]:
         params.append(torch.rand(size, 1))
@@ -191,14 +186,17 @@ def run_test_add_param_group(rank, world_size):
     assert sum([x.numel() for g in o.optim.param_groups for x in g["params"]]) == 8
     assert len(o.optim.param_groups) == 2
+    dist.destroy_process_group()

 def test_add_param_group():
     world_size = 3
-    mp.spawn(run_test_add_param_group, args=(world_size,), nprocs=world_size, join=True)
+    temp_file_name = tempfile.mkstemp()[1]
+    mp.spawn(run_test_add_param_group, args=(world_size, temp_file_name), nprocs=world_size, join=True)

-def run_test_zero_grad(rank, world_size):
-    dist_init(rank, world_size)
+def run_test_zero_grad(rank, world_size, tempfile_name):
+    dist_init(rank, world_size, tempfile_name)
     x = torch.rand(1)
     m = torch.nn.Linear(1, 1)
     o = optim.OSS(m.parameters(), lr=0.1)
@@ -210,14 +208,17 @@ def run_test_zero_grad(rank, world_size):
     assert not m.weight.grad
     assert not m.bias.grad
+    dist.destroy_process_group()

 def test_zero_grad():
     world_size = 2
-    mp.spawn(run_test_zero_grad, args=(world_size,), nprocs=world_size, join=True)
+    temp_file_name = tempfile.mkstemp()[1]
+    mp.spawn(run_test_zero_grad, args=(world_size, temp_file_name), nprocs=world_size, join=True)

-def run_test_step(rank, world_size):
-    dist_init(rank, world_size)
+def run_test_step(rank, world_size, tempfile_name):
+    dist_init(rank, world_size, tempfile_name, backend="gloo")
     x = torch.tensor([float(rank + 1)], device=rank)
     m = torch.nn.Linear(1, 1)
     m.weight.data = torch.tensor([[1.0]])
@@ -233,15 +234,19 @@ def run_test_step(rank, world_size):
     assert m.weight == torch.tensor([[0.75]], device=rank)
     assert m.bias == torch.tensor([1.85], device=rank)
+    dist.destroy_process_group()

 @skip_if_no_cuda
 def test_step():
     world_size = min(2, torch.cuda.device_count())
-    mp.spawn(run_test_step, args=(world_size,), nprocs=world_size, join=True)
+    temp_file_name = tempfile.mkstemp()[1]
+    mp.spawn(run_test_step, args=(world_size, temp_file_name), nprocs=world_size, join=True)

-def run_test_step_with_closure(rank, world_size, optimizer=None):
-    dist_init(rank, world_size)
+def run_test_step_with_closure(rank, world_size, tempfile_name, optimizer=None):
+    dist_init(rank, world_size, tempfile_name)
     x_val = rank + 1
     weight = 1.0
@@ -277,33 +282,41 @@ def run_test_step_with_closure(rank, world_size, optimizer=None):
     assert m.weight == torch.tensor([[1.1]], device=rank)
     assert m.bias == torch.tensor([2.1], device=rank)
+    dist.destroy_process_group()

 @skip_if_no_cuda
 def test_step_with_closure():
     world_size = min(2, torch.cuda.device_count())
-    mp.spawn(run_test_step_with_closure, args=(world_size,), nprocs=world_size, join=True)
+    temp_file_name = tempfile.mkstemp()[1]
+    mp.spawn(run_test_step_with_closure, args=(world_size, temp_file_name), nprocs=world_size, join=True)

-def run_test_sharding(rank, world_size):
-    dist_init(rank, world_size)
+def run_test_sharding(rank, world_size, tempfile_name):
+    dist_init(rank, world_size, tempfile_name)
     params = []
     for size in [5, 4, 2, 6, 4, 3]:
         params.append(torch.rand(size, 1))
     o = optim.OSS(params, lr=0.1)
     assert sum([x.numel() for x in o.optim.param_groups[0]["params"]]) == 8
+    dist.destroy_process_group()

 def test_sharding():
     world_size = 3
-    mp.spawn(run_test_sharding, args=(world_size,), nprocs=world_size, join=True)
+    temp_file_name = tempfile.mkstemp()[1]
+    mp.spawn(run_test_sharding, args=(world_size, temp_file_name), nprocs=world_size, join=True)

-def run_test_collect_shards(rank, world_size, reference_rank):
-    dist_init(rank, world_size)
+def run_test_collect_shards(rank, world_size, reference_rank, tempfile_name):
+    dist_init(rank, world_size, tempfile_name)
     device = torch.device(rank) if torch.cuda.device_count() > 1 else DEVICE

     # Run a dummy step so that the optimizer state dict exists
-    batch, input_width, hidden, target_width = 3, 20, 10, 5
+    batch, input_width, hidden, target_width = 3, 3, 3, 5
     target = torch.rand((batch, target_width), device=device)
     inputs = torch.rand((batch, input_width), device=device)
@@ -343,24 +356,25 @@ def run_test_collect_shards(rank, world_size, reference_rank):
     # Load the optimizer state dict
     optimizer.load_state_dict(optimizer_state_dict)
+    dist.destroy_process_group()

 def test_collect_shards():
     world_size = 3
+    temp_file_name = tempfile.mkstemp()[1]
     if torch.cuda.is_available():
         world_size = min(world_size, torch.cuda.device_count())
     reference_rank = 0
     mp.spawn(
-        run_test_collect_shards, args=(world_size, reference_rank), nprocs=world_size, join=True,
+        run_test_collect_shards, args=(world_size, reference_rank, temp_file_name), nprocs=world_size, join=True,
     )

-def run_test_multiple_groups(rank, world_size):
+def run_test_multiple_groups(rank, world_size, tempfile_name):
     # Only work with the even ranks, to check that the global_rank indexing is properly used
-    os.environ["MASTER_ADDR"] = "localhost"
-    os.environ["MASTER_PORT"] = "29501"
-    dist.init_process_group(backend="gloo", rank=rank, world_size=world_size)
+    dist_init(rank=rank, world_size=world_size, tempfile_name=tempfile_name, backend="gloo")

     sub_group_ranks = [0, 2, 4]
     process_group = torch.distributed.new_group(ranks=sub_group_ranks, backend="gloo")
@@ -422,10 +436,14 @@ def run_test_multiple_groups(rank, world_size):
     optimizer = optim.OSS(model.parameters(), lr=0.1, momentum=0.99, group=process_group, broadcast_buffer_size=0)
     check(optimizer)
+    dist.destroy_process_group(process_group)
+    dist.destroy_process_group()

 def test_multiple_groups():
     world_size = 6
+    temp_file_name = tempfile.mkstemp()[1]
     mp.spawn(
-        run_test_multiple_groups, args=(world_size,), nprocs=world_size, join=True,
+        run_test_multiple_groups, args=(world_size, temp_file_name), nprocs=world_size, join=True,
     )
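
For the multi-process tests, the pattern in the diff is: the top-level test creates one temp file, forwards it to every spawned rank, and each worker joins the process group through that file and destroys it when done. A small sketch of that spawn flow under the same assumptions; the worker body is a stand-in rather than the actual OSS assertions:

import tempfile

import torch.distributed as dist
import torch.multiprocessing as mp


def _worker(rank, world_size, rendezvous_file):
    # Every spawned process joins the same file-based rendezvous.
    dist.init_process_group(
        init_method="file://" + rendezvous_file,
        backend="gloo",
        rank=rank,
        world_size=world_size,
    )
    # ... the real tests build an optim.OSS optimizer and assert on it here ...
    dist.destroy_process_group()


def spawn_example():
    world_size = 2
    # One fresh temp file per test keeps parallel test invocations isolated.
    rendezvous_file = tempfile.mkstemp()[1]
    mp.spawn(_worker, args=(world_size, rendezvous_file), nprocs=world_size, join=True)


if __name__ == "__main__":
    spawn_example()
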