Unverified Commit 8ba649e1 authored by Min Xu, committed by GitHub

[minor] make backward assert a bit better (#919)



* [minor] better assert in backward

* mypy
Co-authored-by: Min Xu <min.xu.public@gmail.com>
parent 5d8a505c
@@ -1716,7 +1716,8 @@ class FullyShardedDataParallel(nn.Module):
     @torch.no_grad()
     def _wait_for_post_backward(self) -> None:
         """Wait for post-backward to finish. Only called on root instance."""
-        assert self._is_root
+        # Note: the backward runtime swallows the assert error, so we use p_assert() here.
+        p_assert(self._is_root, "WFPB not called on root")
         # Check if the root module has params and if any of them has
         # the `requires_grad` field set. If `requires_grad=False` for
         # all the params, the post_backward hook will not fire and the
@@ -1729,7 +1730,8 @@ class FullyShardedDataParallel(nn.Module):
         if self._require_backward_grad_sync:
             # Flush any unreduced buckets in the post_backward stream.
             with torch.cuda.stream(self._streams["post_backward"]):
-                assert self._reducer is not None
+                p_assert(self._reducer is not None, "WFPB: reducer is None")
+                assert self._reducer is not None  # make mypy happy
                 self._reducer.flush()
             torch.cuda.current_stream().wait_stream(self._streams["post_backward"])
             if self.move_grads_to_cpu:
@@ -1748,7 +1750,7 @@ class FullyShardedDataParallel(nn.Module):
                 if not p.requires_grad:
                     continue
                 if hasattr(p, "_shard_bwd_hook"):
-                    assert len(p._shard_bwd_hook) == 2, len(p._shard_bwd_hook)
+                    p_assert(len(p._shard_bwd_hook) == 2, f"WFPB: incorrect hook num: {len(p._shard_bwd_hook)}")
                     p._shard_bwd_hook[1].remove()
                     delattr(p, "_shard_bwd_hook")
@@ -1761,10 +1763,13 @@ class FullyShardedDataParallel(nn.Module):
                 # Parameter and gradient devices must match.
                 if hasattr(p, "_cpu_grad"):
-                    assert p.device == torch.device("cpu")
+                    p_assert(p.device == torch.device("cpu"), f"WFPB: incorrect cpu_grad device {p.device}")
                     p.grad = p._cpu_grad
                 elif hasattr(p, "_saved_grad_shard"):
-                    assert p.device == p._saved_grad_shard.device
+                    p_assert(
+                        p.device == p._saved_grad_shard.device,
+                        f"WFPB: incorrect saved_grad_shard device {p.device} vs {p._saved_grad_shard.device}",
+                    )
                     p.grad = p._saved_grad_shard
 
                 if hasattr(p, "_saved_grad_shard"):
@@ -1799,7 +1804,11 @@ class FullyShardedDataParallel(nn.Module):
         # reset this flag for cases like "one forward pass + multiple backward passes"
         self._post_backward_callback_queued = False
         # clear this list for next iteration
-        assert self._output_pre_backward_hook_registered is not None
+        p_assert(
+            self._output_pre_backward_hook_registered is not None,
+            "WFPB: self._output_pre_backward_hook_registered should not be None",
+        )
+        assert self._output_pre_backward_hook_registered is not None  # make mypy happy
         self._output_pre_backward_hook_registered.clear()
 
     @torch.no_grad()
@@ -2355,6 +2364,13 @@ class FullyShardedDataParallel(nn.Module):
         return self.move_params_to_cpu
 
 
+def p_assert(cond: Any, s: Any) -> None:
+    """Used in backward context to make sure error is printed."""
+    if not cond:
+        print(s)
+        raise AssertionError
+
+
 def _get_default_cuda_device(module: nn.Module) -> torch.device:
     """Try to infer CUDA device from module parameters."""
     try:
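The diff above replaces plain asserts with p_assert() because an AssertionError raised during the backward pass can surface without its message, which makes the failing check hard to identify. The sketch below is illustrative only and is not part of this commit: it reuses the p_assert() helper added above and triggers it from a gradient hook, so the diagnostic is printed to stdout before the exception propagates out of backward(). The linear module, the check_grad hook, and the deliberately failing condition are made up for the demonstration.

# Illustrative sketch (not from the commit): print-then-raise keeps the
# diagnostic visible even if the autograd engine obscures the assert message.
from typing import Any

import torch
import torch.nn as nn


def p_assert(cond: Any, s: Any) -> None:
    """Used in backward context to make sure error is printed."""
    if not cond:
        print(s)
        raise AssertionError


def check_grad(grad: torch.Tensor) -> torch.Tensor:
    # Hypothetical check that deliberately fails to show the printed message.
    p_assert(grad.dim() == 99, f"unexpected grad rank: {grad.dim()}")
    return grad


lin = nn.Linear(4, 4)
lin.weight.register_hook(check_grad)  # gradient hook runs during backward

loss = lin(torch.randn(2, 4)).sum()
try:
    loss.backward()
except Exception as err:
    # The message printed by p_assert above pinpoints the failing check,
    # regardless of how the engine re-raises the error here.
    print("backward raised:", type(err).__name__)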