Unverified Commit 8c8eb8e8 authored by Min Xu's avatar Min Xu Committed by GitHub
Browse files

[fix] fix eval for oss_ddp (#55)

- added train(mode) method to be aware of eval mode
parent fb49b515
......@@ -99,6 +99,15 @@ class OssDdp(nn.Module):
attrs = copy.copy(self.__dict__)
return attrs
def train(self, mode: bool = True) -> "OssDdp":
    """Switch the wrapped module between train and eval mode.

    Mirrors ``nn.Module.train`` but additionally checks that no
    gradients are still pending reduction when the mode changes.

    Args:
        mode: ``True`` to enter training mode, ``False`` for eval mode.

    Returns:
        ``self``, for call chaining (same contract as ``nn.Module.train``).
    """
    was_training = self.module.training
    self.module.train(mode)
    if not self.module.training:
        # Entering eval: every gradient produced so far must already be reduced.
        assert not self.need_reduction, "try to enter eval with grads unreduced"
    else:
        # Entering (or staying in) train: pending gradients are only legal
        # if we were already in training mode before this call.
        assert not self.need_reduction or was_training, "incorrect state transition"
    return self
@contextmanager
def no_sync(self) -> Generator:
"""A context manager to disable gradient synchronization."""
......@@ -108,10 +117,11 @@ class OssDdp(nn.Module):
self.accumulate_grads = old_accumulate_grads
def forward(self, *inputs: Any, **kwargs: Any) -> Tensor:
    """Run the wrapped module's forward pass.

    In training mode, enforce that gradients from a previous backward have
    been reduced (via ``reduce``) before a new forward starts, and mark
    that a reduction will be needed unless gradient accumulation is
    active. In eval mode no reduction bookkeeping applies, so repeated
    forwards are always legal.

    Raises:
        RuntimeError: if called in training mode while a reduction is
            still outstanding.
    """
    # NOTE: the pre-patch version performed this check unconditionally,
    # which broke eval; the guard must only apply while training.
    if self.module.training:
        if self.need_reduction:
            raise RuntimeError("OssDdp requires explicit reduction, must call OssDdp.reduce")
        if not self.accumulate_grads:
            # The backward pass following this forward will produce grads
            # that must be reduced before the next training forward.
            self.need_reduction = True
    return self.module(*inputs, **kwargs)
def reduce(self) -> None:
......@@ -119,6 +129,7 @@ class OssDdp(nn.Module):
This function must be called explicitly after backward to reduce
gradients. There is no automatic hook like c10d.
"""
assert self.module.training, "Cannot call reduce in eval"
def reduce_params(params: List[Parameter], params_rank: int) -> None:
""" Helper to reduce a list of params that should fix in the buffer. """
......
......@@ -54,3 +54,32 @@ def run_one_step(rank, world_size, backend, device, temp_file_name):
def run_test(backend, device, world_size=2):
    """Spawn ``world_size`` workers that each execute ``run_one_step``.

    A fresh temp file is used as the rendezvous point for process-group
    initialization inside the workers.
    """
    rendezvous_file = tempfile.mkstemp()[1]
    mp.spawn(
        run_one_step,
        args=(world_size, backend, device, rendezvous_file),
        nprocs=world_size,
        join=True,
    )
def run_eval_mode(_unused):
    """Check that eval mode does not trip the OssDdp state assertions.

    Runs several forward passes in eval mode (these must succeed), then
    switches back to train mode and verifies that a second training
    forward without an intervening ``reduce`` raises ``RuntimeError``.
    """
    dist.init_process_group(
        init_method=f"file://{tempfile.mkstemp()[1]}", backend=dist.Backend.GLOO, rank=0, world_size=1
    )
    try:
        model = Sequential(Linear(2, 3), Linear(3, 4))
        optimizer = OSS(model.parameters(), lr=0.1, momentum=0.99)
        ddp = OssDdp(model, optimizer, 1)

        ddp.eval()
        for _ in range(5):
            # Repeated eval forwards must be legal: no reduction bookkeeping.
            ddp(torch.rand((64, 2)))

        ddp.train()
        try:
            for _ in range(5):
                ddp(torch.rand((64, 2)))
        except RuntimeError:
            pass
        else:
            raise AssertionError("Multiple forward passes on training mode should not pass")
    finally:
        # Leave the default process group clean for any subsequent test.
        dist.destroy_process_group()
def test_eval_mode():
    """Entry point: run the eval-mode scenario in one spawned worker process."""
    mp.spawn(run_eval_mode, args=(), nprocs=1, join=True)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment