Unverified commit 4f7f0853, authored by Tim Brooks, committed by GitHub

Add method for disabling gradient checkpointing (#772)

See https://github.com/facebookresearch/fairscale/issues/771
parent 3ecf76f4
@@ -25,11 +25,23 @@ class ThreadLocal(threading.local):
     def __init__(self) -> None:
         self.is_checkpointing = False
         self.is_recomputing = False
+        self.is_checkpointing_disabled = False
 
 
 thread_local = ThreadLocal()
 
 
+@contextmanager
+def disable_checkpointing() -> Generator[None, None, None]:
+    """Makes :func:`is_checkpointing_disabled` return :data:`True` within a context."""
+    orig = thread_local.is_checkpointing_disabled
+    thread_local.is_checkpointing_disabled = True
+    try:
+        yield
+    finally:
+        thread_local.is_checkpointing_disabled = orig
+
+
 @contextmanager
 def enable_checkpointing() -> Generator[None, None, None]:
     """Makes :func:`is_checkpointing` return :data:`True` within a context."""
@@ -164,7 +176,7 @@ def _checkpointed_forward(
     # which would be an issue during eval since there wouldn't be a corresponding backward pass
     # to decrement the fwd counter.
     # See https://github.com/facebookresearch/fairscale/pull/709.
-    if not torch.is_grad_enabled():
+    if not torch.is_grad_enabled() or thread_local.is_checkpointing_disabled:
         return original_forward(module, *args, **kwargs)
 
     # Autograd Functions in PyTorch work best with positional args, since
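For context, a minimal usage sketch (not part of the commit): it assumes a fairscale version that includes this change and imports the context manager from the patched checkpoint_activations module; the toy model and tensor shapes are made up for illustration.

    import torch
    import torch.nn as nn

    from fairscale.nn.checkpoint.checkpoint_activations import (
        checkpoint_wrapper,
        disable_checkpointing,
    )

    # Wrap a submodule so its activations are recomputed during the backward pass.
    model = nn.Sequential(
        checkpoint_wrapper(nn.Linear(16, 16)),
        nn.ReLU(),
        nn.Linear(16, 4),
    )

    x = torch.randn(2, 16, requires_grad=True)

    # Normal training step: the wrapped layer checkpoints its activations.
    model(x).sum().backward()

    # Inside this context the wrapped layer falls back to the original forward,
    # i.e. no activation checkpointing even though gradients are enabled.
    with disable_checkpointing():
        model(x).sum().backward()

Because the flag is stored on the thread-local state and restored in the finally block, the context manager only affects the current thread and nests safely.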