"docs/git@developer.sourcefind.cn:OpenDAS/fairscale.git" did not exist on "efffd7edcf832bd52e78bb0ae016d409fb363133"
Unverified commit c19cc897, authored by anj-s and committed by GitHub

[offload] Add support for record_function when using OffloadModel (#564)

* add record_function support

* add more record_function cutpoints

* add more record_function cutpoints

* fix lint errors

* make string ids more specific
parent 2d3d5a7b
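
The labels added in this change make OffloadModel's load/compute/drop phases visible in the PyTorch autograd profiler. A minimal sketch of how a user might capture them, assuming the OffloadModel constructor arguments documented by FairScale; the toy model, tensor shapes, and slice count below are illustrative and not part of this commit:

# Sketch only: profile one OffloadModel step and export a chrome trace.
# The toy model, tensor shapes, and num_slices are illustrative assumptions.
import torch
import torch.nn as nn
from fairscale.experimental.nn.offload import OffloadModel

model = OffloadModel(
    model=nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 2)),
    device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    num_slices=3,
    checkpoint_activation=True,
)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
inputs = torch.randn(8, 32).cuda()
labels = torch.randint(0, 2, (8,)).cuda()

with torch.autograd.profiler.profile(use_cuda=True) as prof:
    optimizer.zero_grad()
    loss = criterion(model(inputs), labels)
    loss.backward()
    optimizer.step()

# The "fairscale.experimental.nn.offload:*" ranges added by this PR appear as
# named blocks in the exported trace (view it at chrome://tracing).
prof.export_chrome_trace("/tmp/offload_prof")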
@@ -214,7 +214,8 @@ def train(model_config, model, benchmark_config, model_specs, args):
             total_tokens_per_log_interval = 0
             total_loss = 0
             start_time = time.time()

-    prof.export_chrome_trace("/tmp/offload_prof")
+    if args.use_profiler:
+        prof.export_chrome_trace("/tmp/offload_prof")
     if epoch_start_time != 0:
         wps = total_tokens / (time.time() - epoch_start_time)
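
The guard above implies a --use_profiler flag and a prof object created elsewhere in the benchmark; neither is shown in this hunk. A plausible wiring, using the enabled= switch that torch.autograd.profiler.profile already supports, might look like the following; the flag name and the dummy workload are assumptions, not the benchmark's actual code:

# Hypothetical wiring for the hunk above; the benchmark's real argument parsing
# and training loop are not shown in this diff.
import argparse
import torch
import torch.autograd.profiler as profiler

parser = argparse.ArgumentParser()
parser.add_argument("--use_profiler", action="store_true")
args = parser.parse_args()

# profiler.profile(enabled=False) is a cheap no-op context, so the same code
# path runs whether or not profiling was requested.
with profiler.profile(enabled=args.use_profiler, use_cuda=torch.cuda.is_available()) as prof:
    x = torch.randn(64, 64, requires_grad=True)  # stand-in for the training loop
    (x @ x).sum().backward()

if args.use_profiler:
    prof.export_chrome_trace("/tmp/offload_prof")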
@@ -165,38 +165,41 @@ class ActivationCheckpointing(torch.autograd.Function):
         model_instance._activations = [inputs]
         # Enumerate through layer shards and apply activations from the previous shard.
         for index, layer_shard in enumerate(model_instance.model_slices):
+            with torch.autograd.profiler.record_function("fairscale.experimental.nn.offload:forward_load"):
                 # Bring in the current activations onto the device.
                 model_instance._activations[index] = tuple([a.cuda() for a in list(model_instance._activations[index])])
                 # Bring in the current layer shard onto the device.
                 layer_shard.forward_load()

             # Apply the FP and store the activations on the CPU.
             inputs = model_instance._activations[index]
+            with torch.autograd.profiler.record_function("fairscale.experimental.nn.offload:no_grad_forward_pass"):
                 with torch.no_grad():
                     output_list: List[Any] = []
                     for given_input in inputs:
                         given_input_list = torch.chunk(given_input, model_instance._num_microbatches)
                         given_output_list = []
                         for inputs in given_input_list:
                             output = layer_shard(inputs)
                             given_output_list.append(output)
                         given_output = torch.cat(given_output_list).squeeze(-1)
                         output_list.append(given_output)
                     output = tuple(output_list)

             output = output if isinstance(output, tuple) else (output,)
+            with torch.autograd.profiler.record_function("fairscale.experimental.nn.offload:forward_drop"):
                 # The last instance will lose the gradient function if we move it to the CPU.
                 # This is because all grad function are present on the device that ran the FW pass.
                 if index == len(model_instance.model_slices) - 1:
                     model_instance._activations.append(output)
                 else:
                     model_instance._activations.append(tuple([a.cpu() for a in list(output)]))
                 # Move the layer shard back to the CPU.
                 layer_shard.forward_drop()

         # TODO(anj-s): Check device of the result to make sure the outputs and targets match device.
         result = model_instance._activations[-1]
         result = [r.cuda() for r in result]
         for r in result:
             r.requires_grad = True
         return result[0] if len(result) == 1 else result
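
Each phase of the sharded forward pass, bringing a shard and its activations onto the device, running it under no_grad, and dropping it back to the CPU, now gets its own record_function range. A self-contained illustration of that bracketing pattern; the labels and sizes below are made up and are not FairScale's:

# Illustration of the record_function bracketing used above; names and sizes are made up.
import torch
from torch.autograd.profiler import record_function

def fake_shard_forward(x: torch.Tensor) -> torch.Tensor:
    with record_function("demo:forward_load"):
        w = torch.randn(512, 512)  # stand-in for moving a shard onto the device
    with record_function("demo:no_grad_forward_pass"):
        with torch.no_grad():
            y = x @ w
    with record_function("demo:forward_drop"):
        del w  # stand-in for moving the shard back to the CPU
    return y

with torch.autograd.profiler.profile() as prof:
    fake_shard_forward(torch.randn(64, 512))

# Every labeled range shows up as its own row, so time can be attributed per phase.
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))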
@@ -217,8 +220,10 @@ class ActivationCheckpointing(torch.autograd.Function):
         for model_shard, activation in zip(
             reversed(model_instance.model_slices), reversed(model_instance._activations[:-1])
         ):
+            with torch.autograd.profiler.record_function("fairscale.experimental.nn.offload:backward_load"):
                 # Move the model shard to the device.
                 model_shard.backward_load()

             # Store the BW pass state.
             bwd_rng_state = torch.get_rng_state()
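
Besides the new backward_load range, this hunk shows the RNG bookkeeping the checkpointing relies on: the backward-time RNG state captured here is restored after the forward is recomputed under enable_grad in the next hunk, so stochastic layers such as dropout see the same randomness as the original pass. A minimal standalone illustration of that save/restore idiom, not FairScale code:

# Standalone illustration of the get_rng_state / set_rng_state idiom; not fairscale code.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.ones(4)

fwd_rng_state = torch.get_rng_state()   # captured before the original forward
a = F.dropout(x, p=0.5, training=True)  # original forward

bwd_rng_state = torch.get_rng_state()   # state at the start of backward
torch.set_rng_state(fwd_rng_state)
b = F.dropout(x, p=0.5, training=True)  # recomputation sees identical randomness
torch.set_rng_state(bwd_rng_state)      # restore the backward-time state

assert torch.equal(a, b)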
@@ -253,13 +258,17 @@ class ActivationCheckpointing(torch.autograd.Function):
                     a.requires_grad = True
                     a.retain_grad()
+                with torch.autograd.profiler.record_function(
+                    "fairscale.experimental.nn.offload:forward_pass_with_enable_grad"
+                ):
                     with torch.enable_grad():
                         # calculate the output of the last shard wrt to the stored activation at the slice boundary.
                         outputs = model_shard(*chunked_activation)

                 # Set the states back to what it was at the start of this function.
                 torch.set_rng_state(bwd_rng_state)
+                with torch.autograd.profiler.record_function("fairscale.experimental.nn.offload:backward_pass"):
                     torch.autograd.backward(outputs, chunked_grad)
                 intermediate_grads = []
                 for a in chunked_activation:
                     if a.grad is not None:
@@ -270,8 +279,9 @@ class ActivationCheckpointing(torch.autograd.Function):
                 # Append the list of grads to the all_grads list and this should be on the CPU.
                 all_grads.append(torch.cat(chunked_grad_list).squeeze(-1))  # type: ignore
             # TODO(anj-s): Why does moving activations to CPU cause the .grad property to be None?
+            with torch.autograd.profiler.record_function("fairscale.experimental.nn.offload:backward_drop"):
                 # Move the shard back to the CPU.
                 model_shard.backward_drop()
         detached_inputs = model_instance._activations[0]
         grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else inp for inp in detached_inputs)
         return (None, None) + grads
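
With the backward phases labeled as well, the offload-specific ranges can be pulled out of a profile programmatically rather than read off a trace. A hedged sketch, reusing the documented OffloadModel constructor from the earlier example; the toy model and shapes are again illustrative:

# Sketch: extract only the offload-related ranges from a profile.
# The toy model and shapes are illustrative; the label prefix matches the strings above.
import torch
import torch.nn as nn
from fairscale.experimental.nn.offload import OffloadModel

model = OffloadModel(
    model=nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 32)),
    device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    num_slices=3,
    checkpoint_activation=True,
)
inputs = torch.randn(8, 32).cuda()

with torch.autograd.profiler.profile(use_cuda=True) as prof:
    model(inputs).sum().backward()

offload_events = [
    evt for evt in prof.function_events
    if evt.name.startswith("fairscale.experimental.nn.offload:")
]
for evt in offload_events:
    # cpu_time_total is reported in microseconds.
    print(f"{evt.name}: {evt.cpu_time_total:.0f} us")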