Unverified Commit c19cc897 authored by anj-s, committed by GitHub

[offload] Add support for record_function when using OffloadModel (#564)

* add record_function support

* add more record_function cutpoints

* add more record_function cutpoints

* fix lint errors

* make string ids more specific
parent 2d3d5a7b
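Note: torch.autograd.profiler.record_function attaches a named range to the profiler timeline, so each offload stage shows up as a labeled block in the collected trace. A minimal sketch of the pattern (the toy model and the label string below are illustrative, not taken from this commit):

import torch

# Any module works; this toy layer just gives the profiler something to record.
model = torch.nn.Linear(16, 16)
inputs = torch.randn(4, 16)

with torch.autograd.profiler.profile() as prof:
    # The string becomes the name of the range shown in the trace viewer.
    with torch.autograd.profiler.record_function("example:forward_pass"):
        out = model(inputs)

print(prof.key_averages().table(sort_by="cpu_time_total"))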
@@ -214,6 +214,7 @@ def train(model_config, model, benchmark_config, model_specs, args):
total_tokens_per_log_interval = 0
total_loss = 0
start_time = time.time()
if args.use_profiler:
prof.export_chrome_trace("/tmp/offload_prof")
if epoch_start_time != 0:
...
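The benchmark hunk above only gates the trace export; prof presumably comes from a torch.autograd.profiler.profile context opened around the training loop. A hedged sketch of that wiring (the flag and the toy training step are stand-ins, not code from this file):

import contextlib
import torch

use_profiler = True  # stands in for args.use_profiler
model = torch.nn.Linear(8, 8)  # stands in for the benchmark model
batch = torch.randn(2, 8)

# Only pay the profiling cost when the flag is set.
ctx = torch.autograd.profiler.profile() if use_profiler else contextlib.nullcontext()
with ctx as prof:
    loss = model(batch).sum()
    loss.backward()

if use_profiler:
    prof.export_chrome_trace("/tmp/offload_prof")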
@@ -165,6 +165,7 @@ class ActivationCheckpointing(torch.autograd.Function):
model_instance._activations = [inputs]
# Enumerate through layer shards and apply activations from the previous shard.
for index, layer_shard in enumerate(model_instance.model_slices):
with torch.autograd.profiler.record_function("fairscale.experimental.nn.offload:forward_load"):
# Bring in the current activations onto the device.
model_instance._activations[index] = tuple([a.cuda() for a in list(model_instance._activations[index])])
# Bring in the current layer shard onto the device.
@@ -172,7 +173,7 @@ class ActivationCheckpointing(torch.autograd.Function):
# Apply the FP and store the activations on the CPU.
inputs = model_instance._activations[index]
with torch.autograd.profiler.record_function("fairscale.experimental.nn.offload:no_grad_forward_pass"):
with torch.no_grad():
output_list: List[Any] = []
for given_input in inputs:
@@ -186,6 +187,7 @@ class ActivationCheckpointing(torch.autograd.Function):
output = tuple(output_list)
output = output if isinstance(output, tuple) else (output,)
with torch.autograd.profiler.record_function("fairscale.experimental.nn.offload:forward_drop"):
# The last instance will lose the gradient function if we move it to the CPU.
# This is because all grad functions are present on the device that ran the FW pass.
if index == len(model_instance.model_slices) - 1:
@@ -197,6 +199,7 @@ class ActivationCheckpointing(torch.autograd.Function):
# TODO(anj-s): Check device of the result to make sure the outputs and targets match device.
result = model_instance._activations[-1]
result = [r.cuda() for r in result]
for r in result:
r.requires_grad = True
return result[0] if len(result) == 1 else result
@@ -217,8 +220,10 @@ class ActivationCheckpointing(torch.autograd.Function):
for model_shard, activation in zip(
reversed(model_instance.model_slices), reversed(model_instance._activations[:-1])
):
with torch.autograd.profiler.record_function("fairscale.experimental.nn.offload:backward_load"):
# Move the model shard to the device.
model_shard.backward_load()
# Store the BW pass state.
bwd_rng_state = torch.get_rng_state()
@@ -253,12 +258,16 @@ class ActivationCheckpointing(torch.autograd.Function):
a.requires_grad = True
a.retain_grad()
with torch.autograd.profiler.record_function(
"fairscale.experimental.nn.offload:forward_pass_with_enable_grad"
):
with torch.enable_grad():
# Calculate the output of the last shard w.r.t. the stored activation at the slice boundary.
outputs = model_shard(*chunked_activation)
# Set the RNG state back to what it was at the start of this function.
torch.set_rng_state(bwd_rng_state)
with torch.autograd.profiler.record_function("fairscale.experimental.nn.offload:backward_pass"):
torch.autograd.backward(outputs, chunked_grad)
intermediate_grads = []
for a in chunked_activation:
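The hunk above follows the usual activation-checkpoint recompute pattern: capture the RNG state, replay the shard's forward under enable_grad from the stored activation, restore the RNG state, and run backward on the recomputed outputs. A stripped-down sketch of that pattern outside of OffloadModel (the dropout layer and the label strings are illustrative only):

import torch

layer = torch.nn.Dropout(p=0.5)  # any op that consumes RNG state
saved_activation = torch.randn(4, 4, requires_grad=True)  # stored at the slice boundary
grad_output = torch.ones(4, 4)  # gradient arriving from the next shard

fwd_rng_state = torch.get_rng_state()  # would be captured during the original forward
bwd_rng_state = torch.get_rng_state()  # state at the start of the backward pass

torch.set_rng_state(fwd_rng_state)  # replay the forward with the original RNG state
with torch.autograd.profiler.record_function("example:forward_pass_with_enable_grad"):
    with torch.enable_grad():
        outputs = layer(saved_activation)
torch.set_rng_state(bwd_rng_state)  # put the backward-time RNG state back

with torch.autograd.profiler.record_function("example:backward_pass"):
    torch.autograd.backward(outputs, grad_output)
print(saved_activation.grad is not None)  # True: the stored activation received its gradient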
@@ -270,6 +279,7 @@ class ActivationCheckpointing(torch.autograd.Function):
# Append the list of grads to the all_grads list and this should be on the CPU.
all_grads.append(torch.cat(chunked_grad_list).squeeze(-1))  # type: ignore
# TODO(anj-s): Why does moving activations to CPU cause the .grad property to be None?
with torch.autograd.profiler.record_function("fairscale.experimental.nn.offload:backward_drop"):
# Move the shard back to the CPU.
model_shard.backward_drop()
detached_inputs = model_instance._activations[0]
...
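Taken together, the new labels (forward_load, no_grad_forward_pass, forward_drop, backward_load, forward_pass_with_enable_grad, backward_pass, backward_drop) bracket every stage of the activation-checkpointed offload path. A sketch of how they could be viewed end to end; it needs a CUDA device, the toy model and slice count are illustrative, and the constructor argument names follow the OffloadModel docs, so double-check them against the installed fairscale version:

import torch
from fairscale.experimental.nn.offload import OffloadModel

# A small sequential model so it can be sliced across the offload boundaries.
model = torch.nn.Sequential(
    torch.nn.Linear(32, 32), torch.nn.ReLU(),
    torch.nn.Linear(32, 32), torch.nn.ReLU(),
    torch.nn.Linear(32, 2),
)
offload = OffloadModel(
    model=model,
    device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    num_slices=3,
    checkpoint_activation=True,  # exercises the ActivationCheckpointing path patched above
)

inputs = torch.randn(8, 32).cuda()
with torch.autograd.profiler.profile(use_cuda=True) as prof:
    loss = offload(inputs).sum()
    loss.backward()

# The fairscale.experimental.nn.offload:* ranges appear in chrome://tracing.
prof.export_chrome_trace("/tmp/offload_prof")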