Unverified Commit a44b9c68 authored by Ziyao Li's avatar Ziyao Li Committed by GitHub
Browse files

pass arg to switch ckp on/off (#22)

parent ee13d84f
......@@ -296,6 +296,7 @@ def eval_bool(x, default=False):
def checkpoint_sequential(
functions,
input,
enabled=True,
):
def wrap_tuple(a):
return (a,) if type(a) is not tuple else a
......@@ -313,7 +314,7 @@ def checkpoint_sequential(
is_grad_enabled = torch.is_grad_enabled()
if is_grad_enabled:
if enabled and is_grad_enabled:
for func in functions:
input = torch.utils.checkpoint.checkpoint(get_wrap_exec(func), *input)
else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment