Unverified commit c19cc897 authored by anj-s, committed by GitHub

[offload] Add support for record_function when using OffloadModel (#564)

* add record_function support

* add more record_function cutpoints

* add more record_function cutpoints

* lint errors

* make string ids more specific
parent 2d3d5a7b
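
For readers unfamiliar with the API: torch.autograd.profiler.record_function attributes everything executed inside its context to a named range in the profiler output, which is what makes the offload cutpoints added below visible in a trace. A minimal sketch of the mechanism (the model, input, and label here are illustrative, not part of this commit):

import torch

model = torch.nn.Linear(16, 16)
x = torch.randn(4, 16)

with torch.autograd.profiler.profile() as prof:
    # Everything run inside record_function is reported under this label.
    with torch.autograd.profiler.record_function("fairscale.experimental.nn.offload:forward_load"):
        y = model(x)

print(prof.key_averages().table(sort_by="cpu_time_total"))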
@@ -214,6 +214,7 @@ def train(model_config, model, benchmark_config, model_specs, args):
            total_tokens_per_log_interval = 0
            total_loss = 0
            start_time = time.time()

    if args.use_profiler:
        prof.export_chrome_trace("/tmp/offload_prof")
    if epoch_start_time != 0:
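The benchmark change above only exports the trace when profiling was requested. A hedged sketch of the surrounding usage, assuming the profiler context is created around the training step by the caller (the helper name and flag below are illustrative):

import torch

def run_step_with_optional_profile(step_fn, use_profiler: bool) -> None:
    # Wrap the work in a profiler context only when requested, then dump a
    # Chrome trace that can be inspected at chrome://tracing.
    if use_profiler:
        with torch.autograd.profiler.profile() as prof:
            step_fn()
        prof.export_chrome_trace("/tmp/offload_prof")
    else:
        step_fn()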
@@ -165,6 +165,7 @@ class ActivationCheckpointing(torch.autograd.Function):
        model_instance._activations = [inputs]
        # Enumerate through layer shards and apply activations from the previous shard.
        for index, layer_shard in enumerate(model_instance.model_slices):
            with torch.autograd.profiler.record_function("fairscale.experimental.nn.offload:forward_load"):
                # Bring in the current activations onto the device.
                model_instance._activations[index] = tuple([a.cuda() for a in list(model_instance._activations[index])])
                # Bring in the current layer shard onto the device.
@@ -172,7 +173,7 @@ class ActivationCheckpointing(torch.autograd.Function):
            # Apply the FP and store the activations on the CPU.
            inputs = model_instance._activations[index]
            with torch.autograd.profiler.record_function("fairscale.experimental.nn.offload:no_grad_forward_pass"):
                with torch.no_grad():
                    output_list: List[Any] = []
                    for given_input in inputs:
@@ -186,6 +187,7 @@ class ActivationCheckpointing(torch.autograd.Function):
                    output = tuple(output_list)
                output = output if isinstance(output, tuple) else (output,)
            with torch.autograd.profiler.record_function("fairscale.experimental.nn.offload:forward_drop"):
                # The last instance will lose the gradient function if we move it to the CPU.
                # This is because all grad functions are present on the device that ran the FW pass.
                if index == len(model_instance.model_slices) - 1:
@@ -197,6 +199,7 @@ class ActivationCheckpointing(torch.autograd.Function):
        # TODO(anj-s): Check device of the result to make sure the outputs and targets match device.
        result = model_instance._activations[-1]
        result = [r.cuda() for r in result]
        for r in result:
            r.requires_grad = True
        return result[0] if len(result) == 1 else result
@@ -217,8 +220,10 @@ class ActivationCheckpointing(torch.autograd.Function):
        for model_shard, activation in zip(
            reversed(model_instance.model_slices), reversed(model_instance._activations[:-1])
        ):
            with torch.autograd.profiler.record_function("fairscale.experimental.nn.offload:backward_load"):
                # Move the model shard to the device.
                model_shard.backward_load()

            # Store the BW pass state.
            bwd_rng_state = torch.get_rng_state()
@@ -253,12 +258,16 @@ class ActivationCheckpointing(torch.autograd.Function):
                a.requires_grad = True
                a.retain_grad()

            with torch.autograd.profiler.record_function(
                "fairscale.experimental.nn.offload:forward_pass_with_enable_grad"
            ):
                with torch.enable_grad():
                    # calculate the output of the last shard wrt the stored activation at the slice boundary.
                    outputs = model_shard(*chunked_activation)

            # Set the state back to what it was at the start of this function.
            torch.set_rng_state(bwd_rng_state)
            with torch.autograd.profiler.record_function("fairscale.experimental.nn.offload:backward_pass"):
                torch.autograd.backward(outputs, chunked_grad)

            intermediate_grads = []
            for a in chunked_activation:
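The hunk above annotates the checkpoint-style recomputation: the shard's forward pass is re-run under enable_grad, the RNG state is restored, and only then is backward called with the incoming gradient. A simplified sketch of that pattern, assuming a single CPU RNG and a generic nn.Module shard (the actual implementation also restores the forward-pass RNG state before recomputing):

import torch

def recompute_and_backward(shard: torch.nn.Module, activation: torch.Tensor, grad: torch.Tensor) -> torch.Tensor:
    bwd_rng_state = torch.get_rng_state()  # remember where the BW pass left the RNG
    activation = activation.detach().requires_grad_(True)
    with torch.enable_grad():
        output = shard(activation)  # recompute this shard's forward pass
    torch.set_rng_state(bwd_rng_state)  # restore the RNG state before running backward
    torch.autograd.backward(output, grad)  # backprop through the recomputed graph
    return activation.grad  # gradient to hand to the previous shard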
@@ -270,6 +279,7 @@ class ActivationCheckpointing(torch.autograd.Function):
            # Append the list of grads to the all_grads list and this should be on the CPU.
            all_grads.append(torch.cat(chunked_grad_list).squeeze(-1))  # type: ignore
            # TODO(anj-s): Why does moving activations to CPU cause the .grad property to be None?
            with torch.autograd.profiler.record_function("fairscale.experimental.nn.offload:backward_drop"):
                # Move the shard back to the CPU.
                model_shard.backward_drop()
        detached_inputs = model_instance._activations[0]
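Taken together, the labels split each offload step into load, compute, and drop phases on both the forward and backward paths. A hedged sketch (not part of this commit) of pulling those ranges out of a captured profile by their shared name prefix:

import torch

def summarize_offload_events(prof: torch.autograd.profiler.profile) -> None:
    # Events recorded via record_function keep their label as the event key,
    # so the offload phases can be grouped by the shared prefix used above.
    for evt in prof.key_averages():
        if evt.key.startswith("fairscale.experimental.nn.offload:"):
            print(f"{evt.key}: cpu_time_total={evt.cpu_time_total:.0f}us count={evt.count}")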