Unverified Commit e89a1916 authored by anj-s, committed by GitHub

[offload] Fix activation offloading to CPU in FW pass. (#588)

* debugging

* debugging activation issue

* fix activation loading

* remove changes used for testing

* remove comment
parent 14abed6e
@@ -131,7 +131,6 @@ def train_seq(model_config, benchmark_config, model_specs, args):
with _get_profiler_record_context("model_training", args.use_profiler):
with _get_fp16_context(use_fp16=args.use_fp16):
output = model(inputs)
print(f"output grad_fn {output.grad_fn}")
loss = criterion(output, target=batch_outputs)
loss.backward()
optimizer.step()
@@ -188,16 +188,18 @@ class ActivationCheckpointing(torch.autograd.Function):
output = output if isinstance(output, tuple) else (output,)
with torch.autograd.profiler.record_function("fairscale.experimental.nn.offload:forward_drop"):
# The last instance will lose the gradient function if we move it to the CPU.
# This is because all grad functions are present on the device that ran the FW pass.
if index == len(model_instance.model_slices) - 1:
model_instance._activations.append(output)
else:
model_instance._activations.append(tuple([a.cpu() for a in list(output)]))
# Move the activation used by the current shard back to the CPU.
model_instance._activations[index] = tuple([a.cpu() for a in list(model_instance._activations[index])])
# The newly computed activations remain on the GPU ready for the next shard computation.
model_instance._activations.append(output)
# Move the layer shard back to the CPU.
layer_shard.forward_drop()
# TODO(anj-s): Check the device of the result to make sure the outputs and targets are on the same device.
# The last instance will lose the gradient function if we move it to the CPU.
# This is because all grad functions are present on the device that ran the FW pass.
# The last activation remains on the GPU and is the return value of this function.
# Note that this assumes the target is also on the GPU, which is required for calculating
# the loss.
result = model_instance._activations[-1]
result = [r.cuda() for r in result]
for r in result:
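For context, below is a minimal, self-contained sketch of the FW-pass device management this hunk implements: the newly computed activation stays on the GPU, the activation the current shard just consumed is pushed back to the CPU, and the last activation never leaves the GPU so that it matches the device of the targets. `forward_with_offload`, its `shards` list, and the `device` argument are hypothetical stand-ins for illustration only, not the fairscale API.

```python
import torch

def forward_with_offload(shards, inputs, device="cuda"):
    """Hypothetical sketch: run `inputs` through a list of nn.Module shards,
    keeping only the activation needed by the next shard on the GPU."""
    activations = [tuple(t.to(device) for t in inputs)]
    for index, shard in enumerate(shards):
        # Load the shard onto the GPU (analogue of forward_load()).
        shard.to(device)
        # Activations are recomputed shard by shard in the BW pass, so no
        # autograd graph is needed here.
        with torch.no_grad():
            output = shard(*activations[index])
        output = output if isinstance(output, tuple) else (output,)
        # The newly computed activation stays on the GPU for the next shard.
        activations.append(output)
        # The activation the current shard just consumed goes back to the CPU.
        activations[index] = tuple(a.cpu() for a in activations[index])
        # Drop the shard back to the CPU (analogue of forward_drop()).
        shard.to("cpu")
    # activations[-1] never left the GPU, so criterion(output, target) sees
    # tensors on the same device (assuming the target is also on the GPU).
    return activations[-1], activations
```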
@@ -221,7 +223,10 @@ class ActivationCheckpointing(torch.autograd.Function):
reversed(model_instance.model_slices), reversed(model_instance._activations[:-1])
):
with torch.autograd.profiler.record_function("fairscale.experimental.nn.offload:backward_load"):
# Move the model shard to the device.
# Move the activation to the GPU.
activation = tuple([a.cuda() for a in list(activation)])
# Move the model shard to the GPU.
model_shard.backward_load()
# Store the BW pass state.
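For reference, a sketch of the reverse loop this hunk feeds into: each saved CPU activation and its shard are brought back to the GPU, the shard's forward is recomputed with grad enabled, and gradients are propagated through it. `backward_with_offload` and its arguments are hypothetical illustrations (assuming floating-point activations), not the actual `ActivationCheckpointing.backward` implementation.

```python
import torch

def backward_with_offload(shards, activations, grad_output, device="cuda"):
    """Hypothetical sketch: walk the shards in reverse, reloading each shard
    and its CPU-resident input activation onto the GPU, recomputing the
    shard's forward, and backpropagating through it."""
    for model_shard, activation in zip(reversed(shards), reversed(activations[:-1])):
        # Move the activation to the GPU as a detached copy that can
        # accumulate its own .grad.
        activation = tuple(a.detach().to(device).requires_grad_(True) for a in activation)
        # Move the model shard to the GPU.
        model_shard.to(device)
        # Recompute this shard's forward with grad enabled, then backprop.
        with torch.enable_grad():
            outputs = model_shard(*activation)
        outputs = outputs if isinstance(outputs, tuple) else (outputs,)
        torch.autograd.backward(outputs, grad_output)
        # The input grads become the grad_output for the previous shard.
        grad_output = tuple(a.grad for a in activation)
        # Drop the shard (and its parameter grads) back to the CPU; the GPU
        # copy of the activation is simply discarded.
        model_shard.to("cpu")
    return grad_output
```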
@@ -276,11 +281,11 @@ class ActivationCheckpointing(torch.autograd.Function):
if None not in intermediate_grads:
chunked_grad_list += intermediate_grads
if chunked_grad_list:
# Append the list of grads to the all_grads list; these grads should be on the CPU.
# Append the list of grads to the all_grads list; these grads should be on the GPU.
all_grads.append(torch.cat(chunked_grad_list).squeeze(-1)) # type: ignore
# TODO(anj-s): Why does moving activations to CPU cause the .grad property to be None?
with torch.autograd.profiler.record_function("fairscale.experimental.nn.offload:backward_drop"):
# Move the shard back to the CPU.
# Move the shard back to the CPU. This should move all the grad tensors to CPU as well.
# We don't need to move activations since we are using a copy of the tensors on the GPU.
model_shard.backward_drop()
detached_inputs = model_instance._activations[0]
grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else inp for inp in detached_inputs)
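The "should move all the grad tensors to CPU as well" comment relies on standard `nn.Module` device-transfer behavior: moving a module also moves the `.grad` tensors of its parameters. A small standalone check (assumes a CUDA device is available; `layer` is just an illustrative module, not part of this change):

```python
import torch
import torch.nn as nn

# Grads start out on the device that ran the backward pass.
layer = nn.Linear(4, 4).cuda()
layer(torch.randn(2, 4, device="cuda")).sum().backward()
assert layer.weight.grad.is_cuda

# Moving the module moves its parameters' .grad tensors along with it, which
# is what dropping the shard back to the CPU relies on.
layer.cpu()
assert not layer.weight.grad.is_cuda
```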