Unverified Commit 31fe887e authored by Guolin Ke's avatar Guolin Ke Committed by GitHub
Browse files

support gradient checkpointing

parent 89ac1b7b
...@@ -290,6 +290,35 @@ def eval_bool(x, default=False): ...@@ -290,6 +290,35 @@ def eval_bool(x, default=False):
return default return default
def checkpoint_sequential(
    functions,
    input,
):
    """Apply ``functions`` in sequence to ``input`` with gradient checkpointing.

    When gradients are enabled, each function is wrapped in
    ``torch.utils.checkpoint.checkpoint`` so its intermediate activations are
    discarded after the forward pass and recomputed during backward, trading
    compute for memory.  When gradients are disabled there is nothing to
    checkpoint, so the functions are simply applied in order.

    Args:
        functions: iterable of callables; each is invoked with the unpacked
            tuple produced by the previous step.
        input: initial argument(s) for the first function; a non-tuple value
            is wrapped into a 1-tuple.

    Returns:
        tuple: the tuple-wrapped output of the final function.
    """
    def wrap_tuple(a):
        # Normalize every intermediate result to a tuple so it can be
        # star-unpacked into the next function.
        return a if isinstance(a, tuple) else (a,)

    # Named `_apply` instead of `exec` to avoid shadowing the builtin.
    def _apply(func, a):
        return wrap_tuple(func(*a))

    def get_wrap_exec(func):
        # Factory binds `func` so checkpoint() receives a plain *args callable.
        def wrap_exec(*a):
            return _apply(func, a)
        return wrap_exec

    input = wrap_tuple(input)
    if torch.is_grad_enabled():
        for func in functions:
            # Checkpointed call: activations are recomputed in backward.
            input = torch.utils.checkpoint.checkpoint(get_wrap_exec(func), *input)
    else:
        # Inference / no-grad path: run the functions directly.
        for func in functions:
            input = _apply(func, input)
    return input
def permute_final_dims(tensor: torch.Tensor, inds: List[int]): def permute_final_dims(tensor: torch.Tensor, inds: List[int]):
zero_index = -1 * len(inds) zero_index = -1 * len(inds)
first_inds = list(range(len(tensor.shape[:zero_index]))) first_inds = list(range(len(tensor.shape[:zero_index])))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment