Unverified commit 8ba649e1 authored by Min Xu, committed by GitHub

[minor] make backward assert a bit better (#919)



* [minor] better assert in backward

* mypy
Co-authored-by: Min Xu <min.xu.public@gmail.com>
parent 5d8a505c
@@ -1716,7 +1716,8 @@ class FullyShardedDataParallel(nn.Module):
     @torch.no_grad()
     def _wait_for_post_backward(self) -> None:
         """Wait for post-backward to finish. Only called on root instance."""
-        assert self._is_root
+        # None, backward runtime swallow the assert error, so we use p_assert() here.
+        p_assert(self._is_root, "WFPB not called on root")
         # Check if the root module has params and if any of them has
         # the `requires_grad` field set. If `requires_grad=False` for
         # all the params, the post_backward hook will not fire and the
@@ -1729,7 +1730,8 @@ class FullyShardedDataParallel(nn.Module):
         if self._require_backward_grad_sync:
             # Flush any unreduced buckets in the post_backward stream.
             with torch.cuda.stream(self._streams["post_backward"]):
-                assert self._reducer is not None
+                p_assert(self._reducer is not None, "WFPB: reducer is None")
+                assert self._reducer is not None  # make mypy happy
                 self._reducer.flush()
         torch.cuda.current_stream().wait_stream(self._streams["post_backward"])
         if self.move_grads_to_cpu:
@@ -1748,7 +1750,7 @@ class FullyShardedDataParallel(nn.Module):
             if not p.requires_grad:
                 continue
             if hasattr(p, "_shard_bwd_hook"):
-                assert len(p._shard_bwd_hook) == 2, len(p._shard_bwd_hook)
+                p_assert(len(p._shard_bwd_hook) == 2, f"WFPB: incorrect hook num: {len(p._shard_bwd_hook)}")
                 p._shard_bwd_hook[1].remove()
                 delattr(p, "_shard_bwd_hook")
@@ -1761,10 +1763,13 @@ class FullyShardedDataParallel(nn.Module):
                 # Parameter and gradient devices must match.
                 if hasattr(p, "_cpu_grad"):
-                    assert p.device == torch.device("cpu")
+                    p_assert(p.device == torch.device("cpu"), f"WFPB: incorrect cpu_grad device {p.device}")
                     p.grad = p._cpu_grad
                 elif hasattr(p, "_saved_grad_shard"):
-                    assert p.device == p._saved_grad_shard.device
+                    p_assert(
+                        p.device == p._saved_grad_shard.device,
+                        f"WFPB: incorrect saved_grad_shard device {p.device} vs {p._saved_grad_shard.device}",
+                    )
                     p.grad = p._saved_grad_shard
             if hasattr(p, "_saved_grad_shard"):
@@ -1799,7 +1804,11 @@ class FullyShardedDataParallel(nn.Module):
         # reset this flag for cases like "one forward pass + multiple backward passes"
         self._post_backward_callback_queued = False
         # clear this list for next iteration
-        assert self._output_pre_backward_hook_registered is not None
+        p_assert(
+            self._output_pre_backward_hook_registered is not None,
+            "WFPB: self._output_pre_backward_hook_registered should not be None",
+        )
+        assert self._output_pre_backward_hook_registered is not None  # make mypy happy
         self._output_pre_backward_hook_registered.clear()

     @torch.no_grad()
@@ -2355,6 +2364,13 @@ class FullyShardedDataParallel(nn.Module):
         return self.move_params_to_cpu


+def p_assert(cond: Any, s: Any) -> None:
+    """Used in backward context to make sure error is printed."""
+    if not cond:
+        print(s)
+        raise AssertionError
+
+
 def _get_default_cuda_device(module: nn.Module) -> torch.device:
     """Try to infer CUDA device from module parameters."""
     try:
...
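For context, here is a minimal, standalone sketch (not part of the commit) of the pattern this change introduces: p_assert prints its message before raising, so the message stays visible even if the surrounding backward machinery reports the AssertionError poorly (the rationale given in the diff), and a plain assert is kept next to it because mypy narrows Optional types on assert statements but not on an ordinary function call. The names MaybeReducer and demo_flush below are hypothetical and used only for illustration.

from typing import Any, Optional


def p_assert(cond: Any, s: Any) -> None:
    """Print the message before raising, so it is visible even when the
    caller's error reporting drops the AssertionError's message."""
    if not cond:
        print(s)
        raise AssertionError


class MaybeReducer:
    """Hypothetical stand-in for FSDP's gradient reducer."""

    def flush(self) -> None:
        print("flushed")


def demo_flush(reducer: Optional[MaybeReducer]) -> None:
    # p_assert gives a readable failure message...
    p_assert(reducer is not None, "demo_flush: reducer is None")
    # ...but mypy does not narrow Optional[MaybeReducer] through it,
    # hence the extra plain assert, mirroring the "# make mypy happy" lines.
    assert reducer is not None  # make mypy happy
    reducer.flush()


if __name__ == "__main__":
    demo_flush(MaybeReducer())  # prints "flushed"
    try:
        demo_flush(None)  # prints the message, then raises AssertionError
    except AssertionError:
        pass

Running the sketch with None first prints the message and then raises, mirroring how the asserts in _wait_for_post_backward now behave.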