Unverified Commit 543d5693 authored by Benjamin Lefaudeux's avatar Benjamin Lefaudeux Committed by GitHub
Browse files

[fix] OSS tests - remove concurrent dist inits (#177)

parent cc766aa5
......@@ -7,7 +7,9 @@
# pylint: disable=missing-class-docstring
# pylint: disable=missing-function-docstring
import os
import tempfile
import unittest
import numpy as np
import pytest
......@@ -23,28 +25,27 @@ BACKEND = dist.Backend.NCCL if torch.cuda.is_available() else dist.Backend.GLOO
DEVICE = "cuda" if torch.cuda.is_available() else torch.device("cpu")
def setup_module(module):
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29500"
dist.init_process_group(backend=BACKEND, rank=0, world_size=1)
def dist_init(rank, world_size, tempfile_name, backend=BACKEND):
url = "file://" + tempfile_name
dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)
def teardown_module(module):
torch.distributed.destroy_process_group()
class TestSingleRank(unittest.TestCase):
"""
All the following tests do not check for inter-process communication
"""
def dist_init(rank, world_size):
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29501"
dist.init_process_group(backend=BACKEND, rank=rank, world_size=world_size)
def setUp(self):
dist_init(0, 1, tempfile.mkstemp()[1])
def tearDown(self):
torch.distributed.destroy_process_group()
def test_create():
def test_create(self):
params = [torch.rand(1)]
o = optim.OSS(params, lr=0.01)
def test_state_dict():
def test_state_dict(self):
x = torch.tensor([1.0], device=DEVICE, requires_grad=True)
o = optim.OSS([x], lr=0.1, momentum=0.9)
x.backward()
......@@ -88,8 +89,7 @@ def test_state_dict():
# Check that the exposed param_groups are on the proper device
assert o.param_groups[0]["params"][0].device == x.device
def test_lr_scheduler():
def test_lr_scheduler(self):
x = torch.tensor([1.0], device=DEVICE, requires_grad=True)
x2 = torch.tensor([1.0], device=DEVICE, requires_grad=True)
o = optim.OSS([x], lr=0.01)
......@@ -107,8 +107,7 @@ def test_lr_scheduler():
s2.step()
assert x == x2
def test_step_with_kwargs():
def test_step_with_kwargs(self):
class SGDWithStepKWArg(torch.optim.SGD):
def step(self, closure=None, kwarg=[]):
super().step()
......@@ -122,8 +121,7 @@ def test_step_with_kwargs():
assert kwarg == [5]
assert x == torch.tensor([0.9], device=DEVICE)
def test_step_with_extra_inner_key():
def test_step_with_extra_inner_key(self):
class SGDWithNewKey(torch.optim.SGD):
# Dummy optimizer which adds a new key to the param groups
def step(self, closure=None):
......@@ -137,8 +135,7 @@ def test_step_with_extra_inner_key():
assert o.param_groups[0]["new_key"] == 0.1
assert x == torch.tensor([0.9], device=DEVICE)
def test_step_without_closure():
def test_step_without_closure(self):
class SGDWithoutClosure(torch.optim.SGD):
def step(self):
return super().step()
......@@ -149,8 +146,7 @@ def test_step_without_closure():
o.step()
assert x == torch.tensor([0.9], device=DEVICE)
def test_local_state_dict():
def test_local_state_dict(self):
x = torch.tensor([1.0], device=DEVICE, requires_grad=True)
o = optim.OSS([x], lr=0.1)
local_state_dict = o.local_state_dict()
......@@ -163,8 +159,7 @@ def test_local_state_dict():
o.step()
assert x == torch.tensor([0.9], device=DEVICE)
def test_implicit_local_state_dict():
def test_implicit_local_state_dict(self):
x = torch.tensor([1.0], device=DEVICE, requires_grad=True)
o = optim.OSS([x], lr=0.1)
local_state_dict = o.state_dict()
......@@ -178,8 +173,8 @@ def test_implicit_local_state_dict():
assert x == torch.tensor([0.9], device=DEVICE)
def run_test_add_param_group(rank, world_size):
dist_init(rank, world_size)
def run_test_add_param_group(rank, world_size, tempfile_name):
dist_init(rank, world_size, tempfile_name)
params = []
for size in [4, 5, 2, 6, 4]:
params.append(torch.rand(size, 1))
......@@ -191,14 +186,17 @@ def run_test_add_param_group(rank, world_size):
assert sum([x.numel() for g in o.optim.param_groups for x in g["params"]]) == 8
assert len(o.optim.param_groups) == 2
dist.destroy_process_group()
def test_add_param_group():
world_size = 3
mp.spawn(run_test_add_param_group, args=(world_size,), nprocs=world_size, join=True)
temp_file_name = tempfile.mkstemp()[1]
mp.spawn(run_test_add_param_group, args=(world_size, temp_file_name), nprocs=world_size, join=True)
def run_test_zero_grad(rank, world_size):
dist_init(rank, world_size)
def run_test_zero_grad(rank, world_size, tempfile_name):
dist_init(rank, world_size, tempfile_name)
x = torch.rand(1)
m = torch.nn.Linear(1, 1)
o = optim.OSS(m.parameters(), lr=0.1)
......@@ -210,14 +208,17 @@ def run_test_zero_grad(rank, world_size):
assert not m.weight.grad
assert not m.bias.grad
dist.destroy_process_group()
def test_zero_grad():
world_size = 2
mp.spawn(run_test_zero_grad, args=(world_size,), nprocs=world_size, join=True)
temp_file_name = tempfile.mkstemp()[1]
mp.spawn(run_test_zero_grad, args=(world_size, temp_file_name), nprocs=world_size, join=True)
def run_test_step(rank, world_size):
dist_init(rank, world_size)
def run_test_step(rank, world_size, tempfile_name):
dist_init(rank, world_size, tempfile_name, backend="gloo")
x = torch.tensor([float(rank + 1)], device=rank)
m = torch.nn.Linear(1, 1)
m.weight.data = torch.tensor([[1.0]])
......@@ -233,15 +234,19 @@ def run_test_step(rank, world_size):
assert m.weight == torch.tensor([[0.75]], device=rank)
assert m.bias == torch.tensor([1.85], device=rank)
dist.destroy_process_group()
@skip_if_no_cuda
def test_step():
world_size = min(2, torch.cuda.device_count())
mp.spawn(run_test_step, args=(world_size,), nprocs=world_size, join=True)
temp_file_name = tempfile.mkstemp()[1]
mp.spawn(run_test_step, args=(world_size, temp_file_name), nprocs=world_size, join=True)
def run_test_step_with_closure(rank, world_size, optimizer=None):
dist_init(rank, world_size)
def run_test_step_with_closure(rank, world_size, tempfile_name, optimizer=None):
dist_init(rank, world_size, tempfile_name)
x_val = rank + 1
weight = 1.0
......@@ -277,33 +282,41 @@ def run_test_step_with_closure(rank, world_size, optimizer=None):
assert m.weight == torch.tensor([[1.1]], device=rank)
assert m.bias == torch.tensor([2.1], device=rank)
dist.destroy_process_group()
@skip_if_no_cuda
def test_step_with_closure():
world_size = min(2, torch.cuda.device_count())
mp.spawn(run_test_step_with_closure, args=(world_size,), nprocs=world_size, join=True)
temp_file_name = tempfile.mkstemp()[1]
mp.spawn(run_test_step_with_closure, args=(world_size, temp_file_name), nprocs=world_size, join=True)
def run_test_sharding(rank, world_size):
dist_init(rank, world_size)
def run_test_sharding(rank, world_size, tempfile_name):
dist_init(rank, world_size, tempfile_name)
params = []
for size in [5, 4, 2, 6, 4, 3]:
params.append(torch.rand(size, 1))
o = optim.OSS(params, lr=0.1)
assert sum([x.numel() for x in o.optim.param_groups[0]["params"]]) == 8
dist.destroy_process_group()
def test_sharding():
world_size = 3
mp.spawn(run_test_sharding, args=(world_size,), nprocs=world_size, join=True)
temp_file_name = tempfile.mkstemp()[1]
mp.spawn(run_test_sharding, args=(world_size, temp_file_name), nprocs=world_size, join=True)
def run_test_collect_shards(rank, world_size, reference_rank):
dist_init(rank, world_size)
def run_test_collect_shards(rank, world_size, reference_rank, tempfile_name):
dist_init(rank, world_size, tempfile_name)
device = torch.device(rank) if torch.cuda.device_count() > 1 else DEVICE
# Run a dummy step so that the optimizer state dict exists
batch, input_width, hidden, target_width = 3, 20, 10, 5
batch, input_width, hidden, target_width = 3, 3, 3, 5
target = torch.rand((batch, target_width), device=device)
inputs = torch.rand((batch, input_width), device=device)
......@@ -343,24 +356,25 @@ def run_test_collect_shards(rank, world_size, reference_rank):
# Load the optimizer state dict
optimizer.load_state_dict(optimizer_state_dict)
dist.destroy_process_group()
def test_collect_shards():
world_size = 3
temp_file_name = tempfile.mkstemp()[1]
if torch.cuda.is_available():
world_size = min(world_size, torch.cuda.device_count())
reference_rank = 0
mp.spawn(
run_test_collect_shards, args=(world_size, reference_rank), nprocs=world_size, join=True,
run_test_collect_shards, args=(world_size, reference_rank, temp_file_name), nprocs=world_size, join=True,
)
def run_test_multiple_groups(rank, world_size):
def run_test_multiple_groups(rank, world_size, tempfile_name):
# Only work with the even ranks, to check that the global_rank indexing is properly used
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29501"
dist.init_process_group(backend="gloo", rank=rank, world_size=world_size)
dist_init(rank=rank, world_size=world_size, tempfile_name=tempfile_name, backend="gloo")
sub_group_ranks = [0, 2, 4]
process_group = torch.distributed.new_group(ranks=sub_group_ranks, backend="gloo")
......@@ -422,10 +436,14 @@ def run_test_multiple_groups(rank, world_size):
optimizer = optim.OSS(model.parameters(), lr=0.1, momentum=0.99, group=process_group, broadcast_buffer_size=0)
check(optimizer)
dist.destroy_process_group(process_group)
dist.destroy_process_group()
def test_multiple_groups():
world_size = 6
temp_file_name = tempfile.mkstemp()[1]
mp.spawn(
run_test_multiple_groups, args=(world_size,), nprocs=world_size, join=True,
run_test_multiple_groups, args=(world_size, temp_file_name), nprocs=world_size, join=True,
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment