Unverified Commit 4f7f0853 authored by Tim Brooks's avatar Tim Brooks Committed by GitHub
Browse files

Add method for disabling gradient checkpointing (#772)

See https://github.com/facebookresearch/fairscale/issues/771
parent 3ecf76f4
...@@ -25,11 +25,23 @@ class ThreadLocal(threading.local): ...@@ -25,11 +25,23 @@ class ThreadLocal(threading.local):
def __init__(self) -> None: def __init__(self) -> None:
self.is_checkpointing = False self.is_checkpointing = False
self.is_recomputing = False self.is_recomputing = False
self.is_checkpointing_disabled = False
thread_local = ThreadLocal() thread_local = ThreadLocal()
@contextmanager
def disable_checkpointing() -> Generator[None, None, None]:
"""Makes :func:`is_checkpointing_disabled` return :data:`True` within a context."""
orig = thread_local.is_checkpointing_disabled
thread_local.is_checkpointing_disabled = True
try:
yield
finally:
thread_local.is_checkpointing_disabled = orig
@contextmanager @contextmanager
def enable_checkpointing() -> Generator[None, None, None]: def enable_checkpointing() -> Generator[None, None, None]:
"""Makes :func:`is_checkpointing` return :data:`True` within a context.""" """Makes :func:`is_checkpointing` return :data:`True` within a context."""
...@@ -164,7 +176,7 @@ def _checkpointed_forward( ...@@ -164,7 +176,7 @@ def _checkpointed_forward(
# which would be an issue during eval since there wouldn't be a corresponding backward pass # which would be an issue during eval since there wouldn't be a corresponding backward pass
# to decrement the fwd counter. # to decrement the fwd counter.
# See https://github.com/facebookresearch/fairscale/pull/709. # See https://github.com/facebookresearch/fairscale/pull/709.
if not torch.is_grad_enabled(): if not torch.is_grad_enabled() or thread_local.is_checkpointing_disabled:
return original_forward(module, *args, **kwargs) return original_forward(module, *args, **kwargs)
# Autograd Functions in PyTorch work best with positional args, since # Autograd Functions in PyTorch work best with positional args, since
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment