Unverified Commit a44b9c68 authored by Ziyao Li's avatar Ziyao Li Committed by GitHub
Browse files

pass arg to switch ckp on/off (#22)

parent ee13d84f
...@@ -296,6 +296,7 @@ def eval_bool(x, default=False): ...@@ -296,6 +296,7 @@ def eval_bool(x, default=False):
def checkpoint_sequential( def checkpoint_sequential(
functions, functions,
input, input,
enabled=True,
): ):
def wrap_tuple(a): def wrap_tuple(a):
return (a,) if type(a) is not tuple else a return (a,) if type(a) is not tuple else a
...@@ -313,7 +314,7 @@ def checkpoint_sequential( ...@@ -313,7 +314,7 @@ def checkpoint_sequential(
is_grad_enabled = torch.is_grad_enabled() is_grad_enabled = torch.is_grad_enabled()
if is_grad_enabled: if enabled and is_grad_enabled:
for func in functions: for func in functions:
input = torch.utils.checkpoint.checkpoint(get_wrap_exec(func), *input) input = torch.utils.checkpoint.checkpoint(get_wrap_exec(func), *input)
else: else:
...@@ -433,4 +434,4 @@ def validate_with_ema(trainer, ema=False): ...@@ -433,4 +434,4 @@ def validate_with_ema(trainer, ema=False):
yield yield
finally: finally:
trainer._wrapped_model = _wrapped_model trainer._wrapped_model = _wrapped_model
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment