"...text-generation-inference.git" did not exist on "96a982ad8fc232479384476b1596a880697cc1d0"
Unverified Commit 3d02f052 authored by Benjamin Lefaudeux, committed by GitHub

[refactor][OSS] Adding a pytorch parity unit test (#298)

* adding a parity unit test
* code review: better testing, use torch defaults, check the loss, log the world size
parent 3399e97c
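For context: OSS (optimizer state sharding) is meant to be a drop-in wrapper around a regular torch.optim optimizer, and the parity test added here checks that the wrapper tracks vanilla PyTorch exactly. Below is a minimal single-rank sketch of the wrapping pattern the test exercises, assuming `optim.OSS` resolves to `fairscale.optim.OSS` as in the test file; the tiny model and hyperparameters are placeholders, not part of the commit.

import tempfile

import torch
import torch.distributed as dist
from fairscale.optim import OSS  # assumption: import path for the OSS wrapper

# OSS needs an initialized process group, even for a single rank.
url = "file://" + tempfile.mkstemp()[1]
dist.init_process_group(init_method=url, backend="gloo", rank=0, world_size=1)

model = torch.nn.Linear(2, 3)

# The wrapper takes the optimizer *class* plus its kwargs; each rank only
# keeps the shard of optimizer state it is responsible for.
optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3)

loss = model(torch.rand(64, 2)).abs().sum()
loss.backward()
optimizer.step()

dist.destroy_process_group()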
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from typing import Any, List, Dict, Iterable, Union, Callable, Optional
from .. import Tensor
@@ -8,7 +8,7 @@ _params_t = Union[Iterable[Tensor], Iterable[Dict]]
 class Optimizer(object):
     param_groups: List[Dict]
     state: Dict
-    def __init__(self, params: _params_t, defaults: Dict) -> None: ...
+    def __init__(self, params: _params_t, defaults: Optional[Dict]=None, lr: Optional[float]=None) -> None: ...
     def state_dict(self) -> Dict: ...
     def load_state_dict(self, state_dict: Dict) -> None: ...
     def zero_grad(self) -> None: ...
 ...
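The stub change above widens `Optimizer.__init__` so that `defaults` becomes optional and an explicit `lr` keyword is accepted, which matches subclasses that take hyperparameters as keywords rather than a pre-built defaults dict. A hedged illustration of a subclass shape this accommodates; the class below is hypothetical and not part of the commit.

import torch

class MinimalSGD(torch.optim.Optimizer):
    # Hyperparameters arrive as keywords and are folded into the defaults dict.
    def __init__(self, params, lr: float = 1e-3) -> None:
        super().__init__(params, defaults={"lr": lr})

    def step(self, closure=None):
        # Bare SGD-style update, for illustration only.
        loss = closure() if closure is not None else None
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is not None:
                    p.data.add_(p.grad, alpha=-group["lr"])
        return loss

opt = MinimalSGD(torch.nn.Linear(2, 2).parameters(), lr=0.1)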
@@ -11,6 +11,7 @@
 import copy
 from math import inf
 import tempfile
+from typing import Type, cast
 import unittest
 import numpy as np
@@ -676,3 +677,80 @@ def test_state_dict_distributed():
     mp.spawn(
         run_state_dict_distributed, args=(world_size, temp_file_name), nprocs=world_size, join=True,
     )
+
+
+def run_ddp_parity(rank, world_size, backend, temp_file_name):
+    url = "file://" + temp_file_name
+    dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)
+
+    device = torch.device("cuda")
+    torch.cuda.set_device(rank)
+    torch.manual_seed(rank)
+    np.random.seed(rank)
+
+    def check_optimizer_equivalence(optimizer: Type[torch.optim.Optimizer]):
+        # Any model works. Add one different buffer per rank
+        model = torch.nn.Sequential(torch.nn.Linear(2, 3), torch.nn.Linear(3, 3), torch.nn.Linear(3, 3))
+        model.register_buffer("test_buffer", torch.ones((1)) * rank)
+        model.to(device)
+
+        sharded_optimizer = optim.OSS(params=model.parameters(), optim=optimizer, lr=1e-3)
+        sharded_ddp_model = DDP(module=model, device_ids=[rank], broadcast_buffers=True)
+
+        ddp_model_single = copy.deepcopy(model)
+        ddp_optimizer = optimizer(ddp_model_single.parameters(), lr=1e-3)
+        ddp_model = DDP(ddp_model_single, device_ids=[rank], broadcast_buffers=True)
+
+        def check_same_model_params():
+            for pg, ddp_pg in zip(sharded_optimizer.param_groups, ddp_optimizer.param_groups):
+                for p, ddp_p in zip(pg["params"], ddp_pg["params"]):
+                    assert torch.allclose(
+                        p, ddp_p, atol=1e-3
+                    ), f"Model parameters differ in between Pytorch optim and OSS \n{p} {ddp_p}\nworld size {world_size}"
+            for b, ddp_b in zip(sharded_ddp_model.buffers(), ddp_model.buffers()):
+                assert torch.allclose(
+                    b, ddp_b
+                ), f"Model buffers differ in between Pytorch optim and OSS\nworld size {world_size}"
+
+        # The model should be synchronized in between the ranks at construction time, check that
+        check_same_model_params()
+
+        # The models should stay the same in between the ranks
+        for i in range(20):
+            input_tensor = torch.rand((64, 2)).to(device)
+
+            def closure_ddp(input_tensor=input_tensor):
+                ddp_optimizer.zero_grad()
+                ddp_loss = ddp_model(input_tensor).abs().sum()
+                ddp_loss.backward()
+                return ddp_loss
+
+            def closure_sharded(input_tensor=input_tensor):
+                sharded_optimizer.zero_grad()
+                sharded_loss = sharded_ddp_model(input_tensor).abs().sum()
+                sharded_loss.backward()
+                return sharded_loss
+
+            loss_ddp = cast(torch.Tensor, ddp_optimizer.step(closure=closure_ddp))
+            loss_sharded_optim = cast(torch.Tensor, sharded_optimizer.step(closure=closure_sharded))
+
+            assert torch.allclose(
+                loss_ddp, loss_sharded_optim
+            ), f"Losses differ in between Pytorch optim and OSS\nworld size {world_size}"
+
+            check_same_model_params()
+
+    for opt in [torch.optim.SGD, torch.optim.Adam]:
+        check_optimizer_equivalence(opt)
+
+    dist.destroy_process_group()
+
+
+@skip_if_no_cuda
+@skip_if_single_gpu
+def test_ddp_parity():
+    temp_file_name = tempfile.mkstemp()[1]
+    world_size = torch.cuda.device_count()
+    backend = dist.Backend.NCCL
+    mp.spawn(run_ddp_parity, args=(world_size, backend, temp_file_name), nprocs=world_size, join=True)
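A note on the mechanics: the test can compare returned losses because of the torch.optim closure contract, where `step(closure)` re-evaluates the model and returns the closure's loss. A minimal single-process sketch of that contract, with illustrative names:

import torch

model = torch.nn.Linear(2, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
data = torch.rand(8, 2)

def closure():
    optimizer.zero_grad()
    loss = model(data).abs().sum()
    loss.backward()
    return loss

# step() runs the closure and hands its return value back, which is what
# lets the test compare loss_ddp and loss_sharded_optim directly.
loss = optimizer.step(closure)
print(float(loss))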