Unverified Commit a44b9c68 authored by Ziyao Li's avatar Ziyao Li Committed by GitHub
Browse files

pass arg to switch ckp on/off (#22)

parent ee13d84f
......@@ -296,6 +296,7 @@ def eval_bool(x, default=False):
def checkpoint_sequential(
functions,
input,
enabled=True,
):
def wrap_tuple(a):
return (a,) if type(a) is not tuple else a
......@@ -313,7 +314,7 @@ def checkpoint_sequential(
is_grad_enabled = torch.is_grad_enabled()
if is_grad_enabled:
if enabled and is_grad_enabled:
for func in functions:
input = torch.utils.checkpoint.checkpoint(get_wrap_exec(func), *input)
else:
......@@ -433,4 +434,4 @@ def validate_with_ema(trainer, ema=False):
yield
finally:
trainer._wrapped_model = _wrapped_model
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment