Unverified Commit a44b9c68 authored by Ziyao Li's avatar Ziyao Li Committed by GitHub
Browse files

pass arg to switch ckp on/off (#22)

parent ee13d84f
...@@ -296,6 +296,7 @@ def eval_bool(x, default=False): ...@@ -296,6 +296,7 @@ def eval_bool(x, default=False):
def checkpoint_sequential( def checkpoint_sequential(
functions, functions,
input, input,
enabled=True,
): ):
def wrap_tuple(a): def wrap_tuple(a):
return (a,) if type(a) is not tuple else a return (a,) if type(a) is not tuple else a
...@@ -313,7 +314,7 @@ def checkpoint_sequential( ...@@ -313,7 +314,7 @@ def checkpoint_sequential(
is_grad_enabled = torch.is_grad_enabled() is_grad_enabled = torch.is_grad_enabled()
if is_grad_enabled: if enabled and is_grad_enabled:
for func in functions: for func in functions:
input = torch.utils.checkpoint.checkpoint(get_wrap_exec(func), *input) input = torch.utils.checkpoint.checkpoint(get_wrap_exec(func), *input)
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment