Unverified Commit 8c8eb8e8 authored by Min Xu, committed by GitHub

[fix] fix eval for oss_ddp (#55)

- added a train(mode) method so the wrapper is aware of eval mode (see the usage sketch below)
parent fb49b515
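For context, a minimal usage sketch of how the wrapper is meant to be driven after this fix; the single-process setup mirrors the run_eval_mode test added below, and the OSS / OssDdp imports from this repository are omitted:

```python
import tempfile

import torch
import torch.distributed as dist
from torch.nn import Linear, Sequential

# Single-process group, as in the added test.
dist.init_process_group(
    init_method=f"file://{tempfile.mkstemp()[1]}", backend=dist.Backend.GLOO, rank=0, world_size=1
)
model = Sequential(Linear(2, 3), Linear(3, 4))
optimizer = OSS(model.parameters(), lr=0.1, momentum=0.99)
ddp = OssDdp(model, optimizer, 1)
batch = torch.rand((64, 2))

# Training mode: each forward should be followed by backward and an explicit
# reduce() before the next forward, otherwise forward() raises a RuntimeError.
ddp.train()
loss = ddp(batch).sum()
loss.backward()
ddp.reduce()
optimizer.step()

# Eval mode (the case this commit fixes): forwards skip the reduction
# bookkeeping, so repeated inference passes no longer hit the
# "must call OssDdp.reduce" error.
ddp.eval()
with torch.no_grad():
    predictions = ddp(batch)
```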
...@@ -99,6 +99,15 @@ class OssDdp(nn.Module):
        attrs = copy.copy(self.__dict__)
        return attrs

    def train(self, mode: bool = True) -> "OssDdp":
        pre_mode = self.module.training
        self.module.train(mode)
        if self.module.training:
            assert not self.need_reduction or pre_mode, "incorrect state transition"
        else:
            assert not self.need_reduction, "try to enter eval with grads unreduced"
        return self

    @contextmanager
    def no_sync(self) -> Generator:
        """A context manager to disable gradient synchronization."""
...@@ -108,6 +117,7 @@ class OssDdp(nn.Module):
        self.accumulate_grads = old_accumulate_grads

    def forward(self, *inputs: Any, **kwargs: Any) -> Tensor:
        if self.module.training:
            if self.need_reduction:
                raise RuntimeError("OssDdp requires explicit reduction, must call OssDdp.reduce")
            if not self.accumulate_grads:
...@@ -119,6 +129,7 @@ class OssDdp(nn.Module):
        This function must be called explicitly after backward to reduce
        gradients. There is no automatic hook like c10d.
        """
        assert self.module.training, "Cannot call reduce in eval"

        def reduce_params(params: List[Parameter], params_rank: int) -> None:
            """ Helper to reduce a list of params that should fix in the buffer. """
......
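To make the new train()/eval() assertions concrete, here is a short sketch of the transitions they allow and reject, continuing the setup from the sketch above and assuming need_reduction is set by a training-mode forward and cleared by reduce(), as the hunks above suggest:

```python
# Allowed: leave training mode only once gradients have been reduced.
ddp.train()
ddp(batch).sum().backward()   # training-mode forward marks a reduction as pending
ddp.reduce()                  # pending reduction is cleared
ddp.eval()                    # passes the assert guarding the eval transition

# Rejected: switching to eval while a reduction is still pending.
ddp.train()
ddp(batch).sum().backward()   # reduction pending again
# ddp.eval()                  # would fail: "try to enter eval with grads unreduced"
ddp.reduce()                  # reduce first, then eval is allowed
ddp.eval()
```

The added test below checks the eval path (no asserts on repeated eval forwards) and that repeated training-mode forwards without reduce() still raise a RuntimeError.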
...@@ -54,3 +54,32 @@ def run_one_step(rank, world_size, backend, device, temp_file_name):
def run_test(backend, device, world_size=2):
    temp_file_name = tempfile.mkstemp()[1]
    mp.spawn(run_one_step, args=(world_size, backend, device, temp_file_name), nprocs=world_size, join=True)
def run_eval_mode(_unused):
    """ Testing eval mode to make sure there are no asserts. """
    dist.init_process_group(
        init_method=f"file://{tempfile.mkstemp()[1]}", backend=dist.Backend.GLOO, rank=0, world_size=1
    )
    model = Sequential(Linear(2, 3), Linear(3, 4))
    optimizer = OSS(model.parameters(), lr=0.1, momentum=0.99)
    ddp = OssDdp(model, optimizer, 1)

    ddp.eval()
    for _ in range(5):
        input_tensor = torch.rand((64, 2))
        output = ddp(input_tensor)

    ddp.train()
    try:
        for _ in range(5):
            input_tensor = torch.rand((64, 2))
            output = ddp(input_tensor)
    except RuntimeError:
        pass
    else:
        assert False, "Multiple forward passes in training mode should not pass"


def test_eval_mode():
    mp.spawn(run_eval_mode, args=(), join=True)