Unverified commit 543d5693, authored by Benjamin Lefaudeux and committed by GitHub

[fix] OSS tests - remove concurrent dist inits (#177)

parent cc766aa5
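The commit replaces the shared `MASTER_ADDR`/`MASTER_PORT` environment-variable rendezvous, which made concurrently running tests race for the same TCP port, with a per-test `file://` init method backed by a fresh temporary file. Below is a minimal sketch of that pattern, independent of the fairscale test suite; the names `worker` and `world_size` are illustrative and not taken from the commit.

import tempfile

import torch.distributed as dist
import torch.multiprocessing as mp


def worker(rank: int, world_size: int, tempfile_name: str) -> None:
    # Every spawned process rendezvouses through the same temporary file,
    # so no TCP port (MASTER_ADDR/MASTER_PORT) is shared across test runs.
    dist.init_process_group(
        init_method="file://" + tempfile_name, backend="gloo", rank=rank, world_size=world_size
    )
    try:
        pass  # exercise the code under test here
    finally:
        # Tear down the group so the next test can initialize cleanly.
        dist.destroy_process_group()


if __name__ == "__main__":
    world_size = 2
    # A fresh file per test avoids collisions between concurrently running tests.
    mp.spawn(worker, args=(world_size, tempfile.mkstemp()[1]), nprocs=world_size, join=True)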
@@ -7,7 +7,9 @@
 # pylint: disable=missing-class-docstring
 # pylint: disable=missing-function-docstring
 
 import os
+import tempfile
+import unittest
 
 import numpy as np
 import pytest
@@ -23,163 +25,156 @@
 BACKEND = dist.Backend.NCCL if torch.cuda.is_available() else dist.Backend.GLOO
 DEVICE = "cuda" if torch.cuda.is_available() else torch.device("cpu")
 
 
-def setup_module(module):
-    os.environ["MASTER_ADDR"] = "localhost"
-    os.environ["MASTER_PORT"] = "29500"
-    dist.init_process_group(backend=BACKEND, rank=0, world_size=1)
-
-
-def teardown_module(module):
-    torch.distributed.destroy_process_group()
-
-
-def dist_init(rank, world_size):
-    os.environ["MASTER_ADDR"] = "localhost"
-    os.environ["MASTER_PORT"] = "29501"
-    dist.init_process_group(backend=BACKEND, rank=rank, world_size=world_size)
-
-
-def test_create():
-    params = [torch.rand(1)]
-    o = optim.OSS(params, lr=0.01)
-
-
-def test_state_dict():
-    x = torch.tensor([1.0], device=DEVICE, requires_grad=True)
-    o = optim.OSS([x], lr=0.1, momentum=0.9)
-    x.backward()
-    o.step()
-    assert x == torch.tensor([0.9], device=DEVICE)
-    assert o.optim.state[x]["momentum_buffer"] == torch.tensor([1.0], device=DEVICE)
-    o.zero_grad()
-    o.consolidate_state_dict()  # Sync state dict in between replicas - even if there are none
-    state_dict = o.state_dict()
-
-    # Check that the state dict is pytorch-compliant key wise
-    assert "param_groups" in state_dict.keys()
-    assert "state" in state_dict.keys()
-
-    # Check that the pulled state is what we expect, and that we have all the expected keys
-    assert state_dict["param_groups"][0]["lr"] == 0.1
-    assert state_dict["param_groups"][0]["momentum"] == 0.9
-    assert not state_dict["param_groups"][0]["nesterov"]
-    assert state_dict["param_groups"][0]["weight_decay"] == 0.0
-    assert state_dict["param_groups"][0]["dampening"] == 0.0
-
-    # Check that the pulled state and the .param_groups attribute are in sync
-    for k in state_dict["param_groups"][0].keys():
-        if k != "params":
-            assert state_dict["param_groups"][0][k] == o.param_groups[0][k]
-
-    # Check that it's correctly loaded
-    o = optim.OSS([x], lr=0.01)
-    o.load_state_dict(state_dict)
-    # Check that state is correct and on proper device
-    assert o.optim.state[x]["momentum_buffer"] == torch.tensor([1.0], device=DEVICE)
-
-    # We should now be using a lr of 0.1, both within the optimizer
-    # and as exposed by the .param_groups attribute
-    assert o.param_groups[0]["lr"] == 0.1
-    x.backward()
-    o.step()
-    assert x == torch.tensor([0.71], device=DEVICE)
-    assert o.optim.state[x]["momentum_buffer"] == torch.tensor([1.9], device=DEVICE)
-
-    # Check that the exposed param_groups are on the proper device
-    assert o.param_groups[0]["params"][0].device == x.device
-
-
-def test_lr_scheduler():
-    x = torch.tensor([1.0], device=DEVICE, requires_grad=True)
-    x2 = torch.tensor([1.0], device=DEVICE, requires_grad=True)
-    o = optim.OSS([x], lr=0.01)
-    o2 = torch.optim.SGD([x2], lr=0.01)
-    s = torch.optim.lr_scheduler.StepLR(o, 1)
-    s2 = torch.optim.lr_scheduler.StepLR(o2, 1)
-    for _ in range(5):
-        x.backward()
-        o.zero_grad()
-        o.step()
-        s.step()
-        x2.backward()
-        o2.zero_grad()
-        o2.step()
-        s2.step()
-        assert x == x2
-
-
-def test_step_with_kwargs():
-    class SGDWithStepKWArg(torch.optim.SGD):
-        def step(self, closure=None, kwarg=[]):
-            super().step()
-            kwarg.append(5)
-
-    kwarg = []
-    x = torch.tensor([1.0], device=DEVICE, requires_grad=True)
-    o = optim.OSS([x], SGDWithStepKWArg, lr=0.1)
-    x.backward()
-    o.step(0, kwarg=kwarg)
-    assert kwarg == [5]
-    assert x == torch.tensor([0.9], device=DEVICE)
-
-
-def test_step_with_extra_inner_key():
-    class SGDWithNewKey(torch.optim.SGD):
-        # Dummy optimizer which adds a new key to the param groups
-        def step(self, closure=None):
-            super().step()
-            self.param_groups[0]["new_key"] = 0.1
-
-    x = torch.tensor([1.0], device=DEVICE, requires_grad=True)
-    o = optim.OSS([x], SGDWithNewKey, lr=0.1)
-    x.backward()
-    o.step()
-    assert o.param_groups[0]["new_key"] == 0.1
-    assert x == torch.tensor([0.9], device=DEVICE)
-
-
-def test_step_without_closure():
-    class SGDWithoutClosure(torch.optim.SGD):
-        def step(self):
-            return super().step()
-
-    x = torch.tensor([1.0], device=DEVICE, requires_grad=True)
-    o = optim.OSS([x], SGDWithoutClosure, lr=0.1)
-    x.backward()
-    o.step()
-    assert x == torch.tensor([0.9], device=DEVICE)
-
-
-def test_local_state_dict():
-    x = torch.tensor([1.0], device=DEVICE, requires_grad=True)
-    o = optim.OSS([x], lr=0.1)
-    local_state_dict = o.local_state_dict()
-    o = optim.OSS([x], lr=0.01)
-    o.load_local_state_dict(local_state_dict)
-    # We should now be using a lr of 0.1.
-    assert o.optim.param_groups[0]["lr"] == 0.1
-    assert o.param_groups[0]["lr"] == 0.1
-    x.backward()
-    o.step()
-    assert x == torch.tensor([0.9], device=DEVICE)
-
-
-def test_implicit_local_state_dict():
-    x = torch.tensor([1.0], device=DEVICE, requires_grad=True)
-    o = optim.OSS([x], lr=0.1)
-    local_state_dict = o.state_dict()
-    o = optim.OSS([x], lr=0.01)
-    o.load_state_dict(local_state_dict)
-    # We should now be using a lr of 0.1.
-    assert o.optim.param_groups[0]["lr"] == 0.1
-    assert o.param_groups[0]["lr"] == 0.1
-    x.backward()
-    o.step()
-    assert x == torch.tensor([0.9], device=DEVICE)
+def dist_init(rank, world_size, tempfile_name, backend=BACKEND):
+    url = "file://" + tempfile_name
+    dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)
+
+
+class TestSingleRank(unittest.TestCase):
+    """
+    All the following tests do not check for inter-process communication
+    """
+
+    def setUp(self):
+        dist_init(0, 1, tempfile.mkstemp()[1])
+
+    def tearDown(self):
+        torch.distributed.destroy_process_group()
+
+    def test_create(self):
+        params = [torch.rand(1)]
+        o = optim.OSS(params, lr=0.01)
+
+    def test_state_dict(self):
+        x = torch.tensor([1.0], device=DEVICE, requires_grad=True)
+        o = optim.OSS([x], lr=0.1, momentum=0.9)
+        x.backward()
+        o.step()
+        assert x == torch.tensor([0.9], device=DEVICE)
+        assert o.optim.state[x]["momentum_buffer"] == torch.tensor([1.0], device=DEVICE)
+        o.zero_grad()
+        o.consolidate_state_dict()  # Sync state dict in between replicas - even if there are none
+        state_dict = o.state_dict()
+
+        # Check that the state dict is pytorch-compliant key wise
+        assert "param_groups" in state_dict.keys()
+        assert "state" in state_dict.keys()
+
+        # Check that the pulled state is what we expect, and that we have all the expected keys
+        assert state_dict["param_groups"][0]["lr"] == 0.1
+        assert state_dict["param_groups"][0]["momentum"] == 0.9
+        assert not state_dict["param_groups"][0]["nesterov"]
+        assert state_dict["param_groups"][0]["weight_decay"] == 0.0
+        assert state_dict["param_groups"][0]["dampening"] == 0.0
+
+        # Check that the pulled state and the .param_groups attribute are in sync
+        for k in state_dict["param_groups"][0].keys():
+            if k != "params":
+                assert state_dict["param_groups"][0][k] == o.param_groups[0][k]
+
+        # Check that it's correctly loaded
+        o = optim.OSS([x], lr=0.01)
+        o.load_state_dict(state_dict)
+        # Check that state is correct and on proper device
+        assert o.optim.state[x]["momentum_buffer"] == torch.tensor([1.0], device=DEVICE)
+
+        # We should now be using a lr of 0.1, both within the optimizer
+        # and as exposed by the .param_groups attribute
+        assert o.param_groups[0]["lr"] == 0.1
+        x.backward()
+        o.step()
+        assert x == torch.tensor([0.71], device=DEVICE)
+        assert o.optim.state[x]["momentum_buffer"] == torch.tensor([1.9], device=DEVICE)
+
+        # Check that the exposed param_groups are on the proper device
+        assert o.param_groups[0]["params"][0].device == x.device
+
+    def test_lr_scheduler(self):
+        x = torch.tensor([1.0], device=DEVICE, requires_grad=True)
+        x2 = torch.tensor([1.0], device=DEVICE, requires_grad=True)
+        o = optim.OSS([x], lr=0.01)
+        o2 = torch.optim.SGD([x2], lr=0.01)
+        s = torch.optim.lr_scheduler.StepLR(o, 1)
+        s2 = torch.optim.lr_scheduler.StepLR(o2, 1)
+        for _ in range(5):
+            x.backward()
+            o.zero_grad()
+            o.step()
+            s.step()
+            x2.backward()
+            o2.zero_grad()
+            o2.step()
+            s2.step()
+            assert x == x2
+
+    def test_step_with_kwargs(self):
+        class SGDWithStepKWArg(torch.optim.SGD):
+            def step(self, closure=None, kwarg=[]):
+                super().step()
+                kwarg.append(5)
+
+        kwarg = []
+        x = torch.tensor([1.0], device=DEVICE, requires_grad=True)
+        o = optim.OSS([x], SGDWithStepKWArg, lr=0.1)
+        x.backward()
+        o.step(0, kwarg=kwarg)
+        assert kwarg == [5]
+        assert x == torch.tensor([0.9], device=DEVICE)
+
+    def test_step_with_extra_inner_key(self):
+        class SGDWithNewKey(torch.optim.SGD):
+            # Dummy optimizer which adds a new key to the param groups
+            def step(self, closure=None):
+                super().step()
+                self.param_groups[0]["new_key"] = 0.1
+
+        x = torch.tensor([1.0], device=DEVICE, requires_grad=True)
+        o = optim.OSS([x], SGDWithNewKey, lr=0.1)
+        x.backward()
+        o.step()
+        assert o.param_groups[0]["new_key"] == 0.1
+        assert x == torch.tensor([0.9], device=DEVICE)
+
+    def test_step_without_closure(self):
+        class SGDWithoutClosure(torch.optim.SGD):
+            def step(self):
+                return super().step()
+
+        x = torch.tensor([1.0], device=DEVICE, requires_grad=True)
+        o = optim.OSS([x], SGDWithoutClosure, lr=0.1)
+        x.backward()
+        o.step()
+        assert x == torch.tensor([0.9], device=DEVICE)
+
+    def test_local_state_dict(self):
+        x = torch.tensor([1.0], device=DEVICE, requires_grad=True)
+        o = optim.OSS([x], lr=0.1)
+        local_state_dict = o.local_state_dict()
+        o = optim.OSS([x], lr=0.01)
+        o.load_local_state_dict(local_state_dict)
+        # We should now be using a lr of 0.1.
+        assert o.optim.param_groups[0]["lr"] == 0.1
+        assert o.param_groups[0]["lr"] == 0.1
+        x.backward()
+        o.step()
+        assert x == torch.tensor([0.9], device=DEVICE)
+
+    def test_implicit_local_state_dict(self):
+        x = torch.tensor([1.0], device=DEVICE, requires_grad=True)
+        o = optim.OSS([x], lr=0.1)
+        local_state_dict = o.state_dict()
+        o = optim.OSS([x], lr=0.01)
+        o.load_state_dict(local_state_dict)
+        # We should now be using a lr of 0.1.
+        assert o.optim.param_groups[0]["lr"] == 0.1
+        assert o.param_groups[0]["lr"] == 0.1
+        x.backward()
+        o.step()
+        assert x == torch.tensor([0.9], device=DEVICE)
 
 
-def run_test_add_param_group(rank, world_size):
-    dist_init(rank, world_size)
+def run_test_add_param_group(rank, world_size, tempfile_name):
+    dist_init(rank, world_size, tempfile_name)
     params = []
     for size in [4, 5, 2, 6, 4]:
         params.append(torch.rand(size, 1))
@@ -191,14 +186,17 @@ def run_test_add_param_group(rank, world_size):
     assert sum([x.numel() for g in o.optim.param_groups for x in g["params"]]) == 8
     assert len(o.optim.param_groups) == 2
+    dist.destroy_process_group()
 
 
 def test_add_param_group():
     world_size = 3
-    mp.spawn(run_test_add_param_group, args=(world_size,), nprocs=world_size, join=True)
+    temp_file_name = tempfile.mkstemp()[1]
+
+    mp.spawn(run_test_add_param_group, args=(world_size, temp_file_name), nprocs=world_size, join=True)
 
 
-def run_test_zero_grad(rank, world_size):
-    dist_init(rank, world_size)
+def run_test_zero_grad(rank, world_size, tempfile_name):
+    dist_init(rank, world_size, tempfile_name)
     x = torch.rand(1)
     m = torch.nn.Linear(1, 1)
     o = optim.OSS(m.parameters(), lr=0.1)
@@ -210,14 +208,17 @@ def run_test_zero_grad(rank, world_size):
     assert not m.weight.grad
     assert not m.bias.grad
+    dist.destroy_process_group()
 
 
 def test_zero_grad():
     world_size = 2
-    mp.spawn(run_test_zero_grad, args=(world_size,), nprocs=world_size, join=True)
+    temp_file_name = tempfile.mkstemp()[1]
+
+    mp.spawn(run_test_zero_grad, args=(world_size, temp_file_name), nprocs=world_size, join=True)
 
 
-def run_test_step(rank, world_size):
-    dist_init(rank, world_size)
+def run_test_step(rank, world_size, tempfile_name):
+    dist_init(rank, world_size, tempfile_name, backend="gloo")
     x = torch.tensor([float(rank + 1)], device=rank)
     m = torch.nn.Linear(1, 1)
     m.weight.data = torch.tensor([[1.0]])
@@ -233,15 +234,19 @@ def run_test_step(rank, world_size):
     assert m.weight == torch.tensor([[0.75]], device=rank)
     assert m.bias == torch.tensor([1.85], device=rank)
+    dist.destroy_process_group()
 
 
 @skip_if_no_cuda
 def test_step():
     world_size = min(2, torch.cuda.device_count())
-    mp.spawn(run_test_step, args=(world_size,), nprocs=world_size, join=True)
+    temp_file_name = tempfile.mkstemp()[1]
+
+    mp.spawn(run_test_step, args=(world_size, temp_file_name), nprocs=world_size, join=True)
 
 
-def run_test_step_with_closure(rank, world_size, optimizer=None):
-    dist_init(rank, world_size)
+def run_test_step_with_closure(rank, world_size, tempfile_name, optimizer=None):
+    dist_init(rank, world_size, tempfile_name)
     x_val = rank + 1
     weight = 1.0
@@ -277,33 +282,41 @@ def run_test_step_with_closure(rank, world_size, optimizer=None):
     assert m.weight == torch.tensor([[1.1]], device=rank)
     assert m.bias == torch.tensor([2.1], device=rank)
+    dist.destroy_process_group()
 
 
 @skip_if_no_cuda
 def test_step_with_closure():
     world_size = min(2, torch.cuda.device_count())
-    mp.spawn(run_test_step_with_closure, args=(world_size,), nprocs=world_size, join=True)
+    temp_file_name = tempfile.mkstemp()[1]
+
+    mp.spawn(run_test_step_with_closure, args=(world_size, temp_file_name), nprocs=world_size, join=True)
 
 
-def run_test_sharding(rank, world_size):
-    dist_init(rank, world_size)
+def run_test_sharding(rank, world_size, tempfile_name):
+    dist_init(rank, world_size, tempfile_name)
     params = []
     for size in [5, 4, 2, 6, 4, 3]:
         params.append(torch.rand(size, 1))
     o = optim.OSS(params, lr=0.1)
     assert sum([x.numel() for x in o.optim.param_groups[0]["params"]]) == 8
+    dist.destroy_process_group()
 
 
 def test_sharding():
     world_size = 3
-    mp.spawn(run_test_sharding, args=(world_size,), nprocs=world_size, join=True)
+    temp_file_name = tempfile.mkstemp()[1]
+
+    mp.spawn(run_test_sharding, args=(world_size, temp_file_name), nprocs=world_size, join=True)
 
 
-def run_test_collect_shards(rank, world_size, reference_rank):
-    dist_init(rank, world_size)
+def run_test_collect_shards(rank, world_size, reference_rank, tempfile_name):
+    dist_init(rank, world_size, tempfile_name)
     device = torch.device(rank) if torch.cuda.device_count() > 1 else DEVICE
 
     # Run a dummy step so that the optimizer state dict exists
-    batch, input_width, hidden, target_width = 3, 20, 10, 5
+    batch, input_width, hidden, target_width = 3, 3, 3, 5
     target = torch.rand((batch, target_width), device=device)
     inputs = torch.rand((batch, input_width), device=device)
@@ -343,24 +356,25 @@ def run_test_collect_shards(rank, world_size, reference_rank):
     # Load the optimizer state dict
     optimizer.load_state_dict(optimizer_state_dict)
 
+    dist.destroy_process_group()
+
 
 def test_collect_shards():
     world_size = 3
+    temp_file_name = tempfile.mkstemp()[1]
+
     if torch.cuda.is_available():
         world_size = min(world_size, torch.cuda.device_count())
     reference_rank = 0
     mp.spawn(
-        run_test_collect_shards, args=(world_size, reference_rank), nprocs=world_size, join=True,
+        run_test_collect_shards, args=(world_size, reference_rank, temp_file_name), nprocs=world_size, join=True,
     )
 
 
-def run_test_multiple_groups(rank, world_size):
+def run_test_multiple_groups(rank, world_size, tempfile_name):
     # Only work with the even ranks, to check that the global_rank indexing is properly used
-    os.environ["MASTER_ADDR"] = "localhost"
-    os.environ["MASTER_PORT"] = "29501"
-    dist.init_process_group(backend="gloo", rank=rank, world_size=world_size)
+    dist_init(rank=rank, world_size=world_size, tempfile_name=tempfile_name, backend="gloo")
 
     sub_group_ranks = [0, 2, 4]
     process_group = torch.distributed.new_group(ranks=sub_group_ranks, backend="gloo")
@@ -422,10 +436,14 @@ def run_test_multiple_groups(rank, world_size):
     optimizer = optim.OSS(model.parameters(), lr=0.1, momentum=0.99, group=process_group, broadcast_buffer_size=0)
     check(optimizer)
 
+    dist.destroy_process_group(process_group)
+    dist.destroy_process_group()
+
 
 def test_multiple_groups():
     world_size = 6
+    temp_file_name = tempfile.mkstemp()[1]
+
     mp.spawn(
-        run_test_multiple_groups, args=(world_size,), nprocs=world_size, join=True,
+        run_test_multiple_groups, args=(world_size, temp_file_name), nprocs=world_size, join=True,
     )