Unverified Commit 3d02f052 authored by Benjamin Lefaudeux, committed by GitHub

[refactor][OSS] Adding a pytorch parity unit test (#298)

* adding a parity unit test
* code review, better testing, use torch defaults and check for the loss, log world size
parent 3399e97c
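For context, the parity check added by this commit boils down to driving the same model once with a stock torch optimizer and once with fairscale's `optim.OSS` wrapper, then asserting that the losses and parameters match. A condensed, single-process sketch of that idea (the `gloo` single-rank group, CPU tensors, and the absence of DDP are simplifications for illustration, not part of the commit):

```python
import copy
import tempfile

import torch
import torch.distributed as dist
from fairscale.optim import OSS

# Single-rank process group so that OSS can (trivially) shard and broadcast.
dist.init_process_group(
    backend="gloo", init_method="file://" + tempfile.mkstemp()[1], rank=0, world_size=1
)

model = torch.nn.Linear(2, 3)
model_sharded = copy.deepcopy(model)

vanilla_optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
sharded_optimizer = OSS(params=model_sharded.parameters(), optim=torch.optim.SGD, lr=1e-3)

inputs = torch.rand(64, 2)


def make_closure(module, optimizer):
    def closure():
        optimizer.zero_grad()
        loss = module(inputs).abs().sum()
        loss.backward()
        return loss

    return closure


# Both step() calls return the closure's loss; identical weights and inputs
# should give identical losses, which is what the real test checks per rank.
loss_vanilla = vanilla_optimizer.step(closure=make_closure(model, vanilla_optimizer))
loss_sharded = sharded_optimizer.step(closure=make_closure(model_sharded, sharded_optimizer))
assert torch.allclose(loss_vanilla, loss_sharded)

dist.destroy_process_group()
```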
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from typing import Any, List, Dict, Iterable, Union, Callable, Optional
from .. import Tensor
@@ -8,7 +8,7 @@ _params_t = Union[Iterable[Tensor], Iterable[Dict]]
 class Optimizer(object):
     param_groups: List[Dict]
     state: Dict
-    def __init__(self, params: _params_t, defaults: Dict) -> None: ...
+    def __init__(self, params: _params_t, defaults: Optional[Dict]=None, lr: Optional[float]=None) -> None: ...
     def state_dict(self) -> Dict: ...
     def load_state_dict(self, state_dict: Dict) -> None: ...
     def zero_grad(self) -> None: ...
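The relaxed `Optimizer.__init__` stub above appears aimed at typed call sites like the new test, which constructs an optimizer through a `Type[torch.optim.Optimizer]` argument and passes only a params iterable plus `lr`. A minimal sketch of the pattern the widened signature accepts (the helper name and the type-checking framing are illustrative assumptions, not part of the commit):

```python
from typing import Type

import torch


def build_optimizer(
    optimizer_class: Type[torch.optim.Optimizer], model: torch.nn.Module, lr: float = 1e-3
) -> torch.optim.Optimizer:
    # Under the previous stub, `defaults` was a required positional argument and
    # `lr` was not declared, so this generic call would not type-check.
    return optimizer_class(model.parameters(), lr=lr)


optimizer = build_optimizer(torch.optim.SGD, torch.nn.Linear(2, 2))
```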
@@ -11,6 +11,7 @@
import copy
from math import inf
import tempfile
from typing import Type, cast
import unittest
import numpy as np
@@ -676,3 +677,80 @@ def test_state_dict_distributed():
    mp.spawn(
        run_state_dict_distributed, args=(world_size, temp_file_name), nprocs=world_size, join=True,
    )

def run_ddp_parity(rank, world_size, backend, temp_file_name):
    url = "file://" + temp_file_name
    dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)

    device = torch.device("cuda")
    torch.cuda.set_device(rank)
    torch.manual_seed(rank)
    np.random.seed(rank)

    def check_optimizer_equivalence(optimizer: Type[torch.optim.Optimizer]):
        # Any model works. Add one different buffer per rank
        model = torch.nn.Sequential(torch.nn.Linear(2, 3), torch.nn.Linear(3, 3), torch.nn.Linear(3, 3),)
        model.register_buffer("test_buffer", torch.ones((1)) * rank)
        model.to(device)

        sharded_optimizer = optim.OSS(params=model.parameters(), optim=optimizer, lr=1e-3)
        sharded_ddp_model = DDP(module=model, device_ids=[rank], broadcast_buffers=True)

        ddp_model_single = copy.deepcopy(model)
        ddp_optimizer = optimizer(ddp_model_single.parameters(), lr=1e-3)
        ddp_model = DDP(ddp_model_single, device_ids=[rank], broadcast_buffers=True)

        def check_same_model_params():
            for pg, ddp_pg in zip(sharded_optimizer.param_groups, ddp_optimizer.param_groups):
                for p, ddp_p in zip(pg["params"], ddp_pg["params"]):
                    assert torch.allclose(
                        p, ddp_p, atol=1e-3
                    ), f"Model parameters differ in between Pytorch optim and OSS \n{p} {ddp_p}\nworld size {world_size}"

            for b, ddp_b in zip(sharded_ddp_model.buffers(), ddp_model.buffers()):
                assert torch.allclose(
                    b, ddp_b
                ), f"Model buffers differ in between Pytorch optim and OSS\nworld size {world_size}"

        # The model should be synchronized in between the ranks at construction time, check that
        check_same_model_params()

        # The models should stay the same in between the ranks
        for i in range(20):
            input_tensor = torch.rand((64, 2)).to(device)

            def closure_ddp(input_tensor=input_tensor):
                ddp_optimizer.zero_grad()
                ddp_loss = ddp_model(input_tensor).abs().sum()
                ddp_loss.backward()
                return ddp_loss

            def closure_sharded(input_tensor=input_tensor):
                sharded_optimizer.zero_grad()
                sharded_loss = sharded_ddp_model(input_tensor).abs().sum()
                sharded_loss.backward()
                return sharded_loss

            loss_ddp = cast(torch.Tensor, ddp_optimizer.step(closure=closure_ddp))
            loss_sharded_optim = cast(torch.Tensor, sharded_optimizer.step(closure=closure_sharded))

            assert torch.allclose(
                loss_ddp, loss_sharded_optim
            ), f"Losses differ in between Pytorch optim and OSS\nworld size {world_size}"

            check_same_model_params()

    for opt in [torch.optim.SGD, torch.optim.Adam]:
        check_optimizer_equivalence(opt)

    dist.destroy_process_group()


@skip_if_no_cuda
@skip_if_single_gpu
def test_ddp_parity():
    temp_file_name = tempfile.mkstemp()[1]
    world_size = torch.cuda.device_count()
    backend = dist.Backend.NCCL

    mp.spawn(run_ddp_parity, args=(world_size, backend, temp_file_name), nprocs=world_size, join=True)
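For reference, the new test spawns one process per visible GPU and is skipped unless at least two CUDA devices are available. A hypothetical direct invocation outside of pytest, mirroring the skip decorators (not part of the commit):

```python
# Hypothetical direct run; pytest would normally collect test_ddp_parity.
if __name__ == "__main__":
    assert torch.cuda.device_count() > 1, "this parity test needs at least two GPUs"
    test_ddp_parity()
```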