Unverified Commit 543d5693 authored by Benjamin Lefaudeux's avatar Benjamin Lefaudeux Committed by GitHub
Browse files

[fix] OSS tests - remove concurrent dist inits (#177)

parent cc766aa5
......@@ -7,7 +7,9 @@
# pylint: disable=missing-class-docstring
# pylint: disable=missing-function-docstring
import os
import tempfile
import unittest
import numpy as np
import pytest
......@@ -23,163 +25,156 @@ BACKEND = dist.Backend.NCCL if torch.cuda.is_available() else dist.Backend.GLOO
DEVICE = "cuda" if torch.cuda.is_available() else torch.device("cpu")
def setup_module(module):
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29500"
dist.init_process_group(backend=BACKEND, rank=0, world_size=1)
def teardown_module(module):
torch.distributed.destroy_process_group()
def dist_init(rank, world_size):
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29501"
dist.init_process_group(backend=BACKEND, rank=rank, world_size=world_size)
def dist_init(rank, world_size, tempfile_name, backend=BACKEND):
url = "file://" + tempfile_name
dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)
def test_create():
params = [torch.rand(1)]
o = optim.OSS(params, lr=0.01)
class TestSingleRank(unittest.TestCase):
"""
All the following tests do not check for inter-process communication
"""
def setUp(self):
dist_init(0, 1, tempfile.mkstemp()[1])
def test_state_dict():
x = torch.tensor([1.0], device=DEVICE, requires_grad=True)
o = optim.OSS([x], lr=0.1, momentum=0.9)
x.backward()
o.step()
assert x == torch.tensor([0.9], device=DEVICE)
assert o.optim.state[x]["momentum_buffer"] == torch.tensor([1.0], device=DEVICE)
o.zero_grad()
o.consolidate_state_dict() # Sync state dict in between replicas - even if there are none
state_dict = o.state_dict()
# Check that the state dict is pytorch-compliant key wise
assert "param_groups" in state_dict.keys()
assert "state" in state_dict.keys()
# Check that the pulled state is what we expect, and that we have all the expected keys
assert state_dict["param_groups"][0]["lr"] == 0.1
assert state_dict["param_groups"][0]["momentum"] == 0.9
assert not state_dict["param_groups"][0]["nesterov"]
assert state_dict["param_groups"][0]["weight_decay"] == 0.0
assert state_dict["param_groups"][0]["dampening"] == 0.0
# Check that the pulled state and the .param_groups attribute are in sync
for k in state_dict["param_groups"][0].keys():
if k != "params":
assert state_dict["param_groups"][0][k] == o.param_groups[0][k]
# Check that it's correctly loaded
o = optim.OSS([x], lr=0.01)
o.load_state_dict(state_dict)
# Check that state is correct and on proper device
assert o.optim.state[x]["momentum_buffer"] == torch.tensor([1.0], device=DEVICE)
# We should now be using a lr of 0.1, both within the optimizer
# and as exposed by the .param_groups attribute
assert o.param_groups[0]["lr"] == 0.1
x.backward()
o.step()
assert x == torch.tensor([0.71], device=DEVICE)
assert o.optim.state[x]["momentum_buffer"] == torch.tensor([1.9], device=DEVICE)
# Check that the exposed param_groups are on the proper device
assert o.param_groups[0]["params"][0].device == x.device
def tearDown(self):
torch.distributed.destroy_process_group()
def test_create(self):
params = [torch.rand(1)]
o = optim.OSS(params, lr=0.01)
def test_lr_scheduler():
x = torch.tensor([1.0], device=DEVICE, requires_grad=True)
x2 = torch.tensor([1.0], device=DEVICE, requires_grad=True)
o = optim.OSS([x], lr=0.01)
o2 = torch.optim.SGD([x2], lr=0.01)
s = torch.optim.lr_scheduler.StepLR(o, 1)
s2 = torch.optim.lr_scheduler.StepLR(o2, 1)
for _ in range(5):
def test_state_dict(self):
x = torch.tensor([1.0], device=DEVICE, requires_grad=True)
o = optim.OSS([x], lr=0.1, momentum=0.9)
x.backward()
o.step()
assert x == torch.tensor([0.9], device=DEVICE)
assert o.optim.state[x]["momentum_buffer"] == torch.tensor([1.0], device=DEVICE)
o.zero_grad()
o.consolidate_state_dict() # Sync state dict in between replicas - even if there are none
state_dict = o.state_dict()
# Check that the state dict is pytorch-compliant key wise
assert "param_groups" in state_dict.keys()
assert "state" in state_dict.keys()
# Check that the pulled state is what we expect, and that we have all the expected keys
assert state_dict["param_groups"][0]["lr"] == 0.1
assert state_dict["param_groups"][0]["momentum"] == 0.9
assert not state_dict["param_groups"][0]["nesterov"]
assert state_dict["param_groups"][0]["weight_decay"] == 0.0
assert state_dict["param_groups"][0]["dampening"] == 0.0
# Check that the pulled state and the .param_groups attribute are in sync
for k in state_dict["param_groups"][0].keys():
if k != "params":
assert state_dict["param_groups"][0][k] == o.param_groups[0][k]
# Check that it's correctly loaded
o = optim.OSS([x], lr=0.01)
o.load_state_dict(state_dict)
# Check that state is correct and on proper device
assert o.optim.state[x]["momentum_buffer"] == torch.tensor([1.0], device=DEVICE)
# We should now be using a lr of 0.1, both within the optimizer
# and as exposed by the .param_groups attribute
assert o.param_groups[0]["lr"] == 0.1
x.backward()
o.step()
s.step()
x2.backward()
o2.zero_grad()
o2.step()
s2.step()
assert x == x2
def test_step_with_kwargs():
class SGDWithStepKWArg(torch.optim.SGD):
def step(self, closure=None, kwarg=[]):
super().step()
kwarg.append(5)
kwarg = []
x = torch.tensor([1.0], device=DEVICE, requires_grad=True)
o = optim.OSS([x], SGDWithStepKWArg, lr=0.1)
x.backward()
o.step(0, kwarg=kwarg)
assert kwarg == [5]
assert x == torch.tensor([0.9], device=DEVICE)
def test_step_with_extra_inner_key():
class SGDWithNewKey(torch.optim.SGD):
# Dummy optimizer which adds a new key to the param groups
def step(self, closure=None):
super().step()
self.param_groups[0]["new_key"] = 0.1
x = torch.tensor([1.0], device=DEVICE, requires_grad=True)
o = optim.OSS([x], SGDWithNewKey, lr=0.1)
x.backward()
o.step()
assert o.param_groups[0]["new_key"] == 0.1
assert x == torch.tensor([0.9], device=DEVICE)
assert x == torch.tensor([0.71], device=DEVICE)
assert o.optim.state[x]["momentum_buffer"] == torch.tensor([1.9], device=DEVICE)
# Check that the exposed param_groups are on the proper device
assert o.param_groups[0]["params"][0].device == x.device
def test_lr_scheduler(self):
x = torch.tensor([1.0], device=DEVICE, requires_grad=True)
x2 = torch.tensor([1.0], device=DEVICE, requires_grad=True)
o = optim.OSS([x], lr=0.01)
o2 = torch.optim.SGD([x2], lr=0.01)
s = torch.optim.lr_scheduler.StepLR(o, 1)
s2 = torch.optim.lr_scheduler.StepLR(o2, 1)
for _ in range(5):
x.backward()
o.zero_grad()
o.step()
s.step()
x2.backward()
o2.zero_grad()
o2.step()
s2.step()
assert x == x2
def test_step_with_kwargs(self):
class SGDWithStepKWArg(torch.optim.SGD):
def step(self, closure=None, kwarg=[]):
super().step()
kwarg.append(5)
kwarg = []
x = torch.tensor([1.0], device=DEVICE, requires_grad=True)
o = optim.OSS([x], SGDWithStepKWArg, lr=0.1)
x.backward()
o.step(0, kwarg=kwarg)
assert kwarg == [5]
assert x == torch.tensor([0.9], device=DEVICE)
def test_step_with_extra_inner_key(self):
class SGDWithNewKey(torch.optim.SGD):
# Dummy optimizer which adds a new key to the param groups
def step(self, closure=None):
super().step()
self.param_groups[0]["new_key"] = 0.1
x = torch.tensor([1.0], device=DEVICE, requires_grad=True)
o = optim.OSS([x], SGDWithNewKey, lr=0.1)
x.backward()
o.step()
assert o.param_groups[0]["new_key"] == 0.1
assert x == torch.tensor([0.9], device=DEVICE)
def test_step_without_closure():
class SGDWithoutClosure(torch.optim.SGD):
def step(self):
return super().step()
def test_step_without_closure(self):
class SGDWithoutClosure(torch.optim.SGD):
def step(self):
return super().step()
x = torch.tensor([1.0], device=DEVICE, requires_grad=True)
o = optim.OSS([x], SGDWithoutClosure, lr=0.1)
x.backward()
o.step()
assert x == torch.tensor([0.9], device=DEVICE)
def test_local_state_dict():
x = torch.tensor([1.0], device=DEVICE, requires_grad=True)
o = optim.OSS([x], lr=0.1)
local_state_dict = o.local_state_dict()
o = optim.OSS([x], lr=0.01)
o.load_local_state_dict(local_state_dict)
# We should now be using a lr of 0.1.
assert o.optim.param_groups[0]["lr"] == 0.1
assert o.param_groups[0]["lr"] == 0.1
x.backward()
o.step()
assert x == torch.tensor([0.9], device=DEVICE)
def test_implicit_local_state_dict():
x = torch.tensor([1.0], device=DEVICE, requires_grad=True)
o = optim.OSS([x], lr=0.1)
local_state_dict = o.state_dict()
o = optim.OSS([x], lr=0.01)
o.load_state_dict(local_state_dict)
# We should now be using a lr of 0.1.
assert o.optim.param_groups[0]["lr"] == 0.1
assert o.param_groups[0]["lr"] == 0.1
x.backward()
o.step()
assert x == torch.tensor([0.9], device=DEVICE)
x = torch.tensor([1.0], device=DEVICE, requires_grad=True)
o = optim.OSS([x], SGDWithoutClosure, lr=0.1)
x.backward()
o.step()
assert x == torch.tensor([0.9], device=DEVICE)
def test_local_state_dict(self):
x = torch.tensor([1.0], device=DEVICE, requires_grad=True)
o = optim.OSS([x], lr=0.1)
local_state_dict = o.local_state_dict()
o = optim.OSS([x], lr=0.01)
o.load_local_state_dict(local_state_dict)
# We should now be using a lr of 0.1.
assert o.optim.param_groups[0]["lr"] == 0.1
assert o.param_groups[0]["lr"] == 0.1
x.backward()
o.step()
assert x == torch.tensor([0.9], device=DEVICE)
def test_implicit_local_state_dict(self):
x = torch.tensor([1.0], device=DEVICE, requires_grad=True)
o = optim.OSS([x], lr=0.1)
local_state_dict = o.state_dict()
o = optim.OSS([x], lr=0.01)
o.load_state_dict(local_state_dict)
# We should now be using a lr of 0.1.
assert o.optim.param_groups[0]["lr"] == 0.1
assert o.param_groups[0]["lr"] == 0.1
x.backward()
o.step()
assert x == torch.tensor([0.9], device=DEVICE)
def run_test_add_param_group(rank, world_size):
dist_init(rank, world_size)
def run_test_add_param_group(rank, world_size, tempfile_name):
dist_init(rank, world_size, tempfile_name)
params = []
for size in [4, 5, 2, 6, 4]:
params.append(torch.rand(size, 1))
......@@ -191,14 +186,17 @@ def run_test_add_param_group(rank, world_size):
assert sum([x.numel() for g in o.optim.param_groups for x in g["params"]]) == 8
assert len(o.optim.param_groups) == 2
dist.destroy_process_group()
def test_add_param_group():
world_size = 3
mp.spawn(run_test_add_param_group, args=(world_size,), nprocs=world_size, join=True)
temp_file_name = tempfile.mkstemp()[1]
mp.spawn(run_test_add_param_group, args=(world_size, temp_file_name), nprocs=world_size, join=True)
def run_test_zero_grad(rank, world_size):
dist_init(rank, world_size)
def run_test_zero_grad(rank, world_size, tempfile_name):
dist_init(rank, world_size, tempfile_name)
x = torch.rand(1)
m = torch.nn.Linear(1, 1)
o = optim.OSS(m.parameters(), lr=0.1)
......@@ -210,14 +208,17 @@ def run_test_zero_grad(rank, world_size):
assert not m.weight.grad
assert not m.bias.grad
dist.destroy_process_group()
def test_zero_grad():
world_size = 2
mp.spawn(run_test_zero_grad, args=(world_size,), nprocs=world_size, join=True)
temp_file_name = tempfile.mkstemp()[1]
mp.spawn(run_test_zero_grad, args=(world_size, temp_file_name), nprocs=world_size, join=True)
def run_test_step(rank, world_size):
dist_init(rank, world_size)
def run_test_step(rank, world_size, tempfile_name):
dist_init(rank, world_size, tempfile_name, backend="gloo")
x = torch.tensor([float(rank + 1)], device=rank)
m = torch.nn.Linear(1, 1)
m.weight.data = torch.tensor([[1.0]])
......@@ -233,15 +234,19 @@ def run_test_step(rank, world_size):
assert m.weight == torch.tensor([[0.75]], device=rank)
assert m.bias == torch.tensor([1.85], device=rank)
dist.destroy_process_group()
@skip_if_no_cuda
def test_step():
world_size = min(2, torch.cuda.device_count())
mp.spawn(run_test_step, args=(world_size,), nprocs=world_size, join=True)
temp_file_name = tempfile.mkstemp()[1]
mp.spawn(run_test_step, args=(world_size, temp_file_name), nprocs=world_size, join=True)
def run_test_step_with_closure(rank, world_size, optimizer=None):
dist_init(rank, world_size)
def run_test_step_with_closure(rank, world_size, tempfile_name, optimizer=None):
dist_init(rank, world_size, tempfile_name)
x_val = rank + 1
weight = 1.0
......@@ -277,33 +282,41 @@ def run_test_step_with_closure(rank, world_size, optimizer=None):
assert m.weight == torch.tensor([[1.1]], device=rank)
assert m.bias == torch.tensor([2.1], device=rank)
dist.destroy_process_group()
@skip_if_no_cuda
def test_step_with_closure():
world_size = min(2, torch.cuda.device_count())
mp.spawn(run_test_step_with_closure, args=(world_size,), nprocs=world_size, join=True)
temp_file_name = tempfile.mkstemp()[1]
mp.spawn(run_test_step_with_closure, args=(world_size, temp_file_name), nprocs=world_size, join=True)
def run_test_sharding(rank, world_size):
dist_init(rank, world_size)
def run_test_sharding(rank, world_size, tempfile_name):
dist_init(rank, world_size, tempfile_name)
params = []
for size in [5, 4, 2, 6, 4, 3]:
params.append(torch.rand(size, 1))
o = optim.OSS(params, lr=0.1)
assert sum([x.numel() for x in o.optim.param_groups[0]["params"]]) == 8
dist.destroy_process_group()
def test_sharding():
world_size = 3
mp.spawn(run_test_sharding, args=(world_size,), nprocs=world_size, join=True)
temp_file_name = tempfile.mkstemp()[1]
mp.spawn(run_test_sharding, args=(world_size, temp_file_name), nprocs=world_size, join=True)
def run_test_collect_shards(rank, world_size, reference_rank):
dist_init(rank, world_size)
def run_test_collect_shards(rank, world_size, reference_rank, tempfile_name):
dist_init(rank, world_size, tempfile_name)
device = torch.device(rank) if torch.cuda.device_count() > 1 else DEVICE
# Run a dummy step so that the optimizer state dict exists
batch, input_width, hidden, target_width = 3, 20, 10, 5
batch, input_width, hidden, target_width = 3, 3, 3, 5
target = torch.rand((batch, target_width), device=device)
inputs = torch.rand((batch, input_width), device=device)
......@@ -343,24 +356,25 @@ def run_test_collect_shards(rank, world_size, reference_rank):
# Load the optimizer state dict
optimizer.load_state_dict(optimizer_state_dict)
dist.destroy_process_group()
def test_collect_shards():
world_size = 3
temp_file_name = tempfile.mkstemp()[1]
if torch.cuda.is_available():
world_size = min(world_size, torch.cuda.device_count())
reference_rank = 0
mp.spawn(
run_test_collect_shards, args=(world_size, reference_rank), nprocs=world_size, join=True,
run_test_collect_shards, args=(world_size, reference_rank, temp_file_name), nprocs=world_size, join=True,
)
def run_test_multiple_groups(rank, world_size):
def run_test_multiple_groups(rank, world_size, tempfile_name):
# Only work with the even ranks, to check that the global_rank indexing is properly used
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29501"
dist.init_process_group(backend="gloo", rank=rank, world_size=world_size)
dist_init(rank=rank, world_size=world_size, tempfile_name=tempfile_name, backend="gloo")
sub_group_ranks = [0, 2, 4]
process_group = torch.distributed.new_group(ranks=sub_group_ranks, backend="gloo")
......@@ -422,10 +436,14 @@ def run_test_multiple_groups(rank, world_size):
optimizer = optim.OSS(model.parameters(), lr=0.1, momentum=0.99, group=process_group, broadcast_buffer_size=0)
check(optimizer)
dist.destroy_process_group(process_group)
dist.destroy_process_group()
def test_multiple_groups():
world_size = 6
temp_file_name = tempfile.mkstemp()[1]
mp.spawn(
run_test_multiple_groups, args=(world_size,), nprocs=world_size, join=True,
run_test_multiple_groups, args=(world_size, temp_file_name), nprocs=world_size, join=True,
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment