Unverified commit 63f289f2, authored by Sam Shleifer, committed by GitHub

checkpoint_activations: use non blocking cpu transfer (#719)

parent 308f1057
@@ -248,7 +248,7 @@ class CheckpointFunction(torch.autograd.Function):
         if parent_ctx_dict["offload"]:
             ctx.fwd_device = tuple(x.device for x in tensor_inputs)
             ctx.grad_requirements = tuple(x.requires_grad for x in tensor_inputs)
-            tensor_inputs = tuple(x.cpu() for x in tensor_inputs)
+            tensor_inputs = tuple(x.to("cpu", non_blocking=True) for x in tensor_inputs)
         else:
             ctx.fwd_device, ctx.grad_requirements = None, None
@@ -277,7 +277,7 @@ class CheckpointFunction(torch.autograd.Function):
         tensor_inputs: Tuple = ctx.saved_tensors
         tensor_inputs = torch_checkpoint.detach_variable(tensor_inputs)
         if ctx.fwd_device is not None:
-            tensor_inputs = tuple(t.to(ctx.fwd_device[i]) for i, t in enumerate(tensor_inputs))
+            tensor_inputs = tuple(t.to(ctx.fwd_device[i], non_blocking=True) for i, t in enumerate(tensor_inputs))
             for i, need_grad in enumerate(ctx.grad_requirements):
                 tensor_inputs[i].requires_grad = need_grad
         inputs = unpack_non_tensors(tensor_inputs, ctx.packed_non_tensor_inputs)