"git@developer.sourcefind.cn:OpenDAS/lmdeploy.git" did not exist on "5ea40abf613e47bb56a0c06f695644d55671f585"
Unverified commit c19cc897, authored by anj-s, committed by GitHub

[offload] Add support for record_function when using OffloadModel (#564)

* add record_function support

* add more record_function cutpoints

* add more record_function cutpoints

* lint errors

* make string ids more specific
parent 2d3d5a7b
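For context: torch.autograd.profiler.record_function wraps a region of code in a named range, so the region shows up as a single labeled block in profiler tables and Chrome traces instead of an undifferentiated stream of ATen ops. A minimal standalone sketch, not part of this commit (the "example:forward_pass" label and toy layer are illustrative):

import torch

x = torch.randn(32, 128)
layer = torch.nn.Linear(128, 128)

with torch.autograd.profiler.profile() as prof:
    # Everything inside this block is grouped under one named event.
    with torch.autograd.profiler.record_function("example:forward_pass"):
        y = layer(x)

# The "example:forward_pass" label appears alongside the ops it wraps.
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))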
@@ -214,7 +214,8 @@ def train(model_config, model, benchmark_config, model_specs, args):
             total_tokens_per_log_interval = 0
             total_loss = 0
             start_time = time.time()
-    prof.export_chrome_trace("/tmp/offload_prof")
+    if args.use_profiler:
+        prof.export_chrome_trace("/tmp/offload_prof")
     if epoch_start_time != 0:
         wps = total_tokens / (time.time() - epoch_start_time)
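The benchmark change above gates the Chrome-trace export behind args.use_profiler, so the trace file is only written when profiling was actually enabled. A hedged sketch of the same pattern in isolation (use_profiler stands in for the parsed flag; the matmul is a placeholder workload):

import contextlib
import torch

use_profiler = True
ctx = torch.autograd.profiler.profile() if use_profiler else contextlib.nullcontext()
with ctx as prof:
    torch.randn(1024, 1024) @ torch.randn(1024, 1024)  # placeholder workload

if use_profiler:
    # Open the resulting file in chrome://tracing to inspect labeled ranges.
    prof.export_chrome_trace("/tmp/offload_prof")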
@@ -165,38 +165,41 @@ class ActivationCheckpointing(torch.autograd.Function):
         model_instance._activations = [inputs]
         # Enumerate through layer shards and apply activations from the previous shard.
         for index, layer_shard in enumerate(model_instance.model_slices):
-            # Bring in the current activations onto the device.
-            model_instance._activations[index] = tuple([a.cuda() for a in list(model_instance._activations[index])])
-            # Bring in the current layer shard onto the device.
-            layer_shard.forward_load()
+            with torch.autograd.profiler.record_function("fairscale.experimental.nn.offload:forward_load"):
+                # Bring in the current activations onto the device.
+                model_instance._activations[index] = tuple([a.cuda() for a in list(model_instance._activations[index])])
+                # Bring in the current layer shard onto the device.
+                layer_shard.forward_load()
             # Apply the FP and store the activations on the CPU.
             inputs = model_instance._activations[index]
-            with torch.no_grad():
-                output_list: List[Any] = []
-                for given_input in inputs:
-                    given_input_list = torch.chunk(given_input, model_instance._num_microbatches)
-                    given_output_list = []
-                    for inputs in given_input_list:
-                        output = layer_shard(inputs)
-                        given_output_list.append(output)
-                    given_output = torch.cat(given_output_list).squeeze(-1)
-                    output_list.append(given_output)
-                output = tuple(output_list)
+            with torch.autograd.profiler.record_function("fairscale.experimental.nn.offload:no_grad_forward_pass"):
+                with torch.no_grad():
+                    output_list: List[Any] = []
+                    for given_input in inputs:
+                        given_input_list = torch.chunk(given_input, model_instance._num_microbatches)
+                        given_output_list = []
+                        for inputs in given_input_list:
+                            output = layer_shard(inputs)
+                            given_output_list.append(output)
+                        given_output = torch.cat(given_output_list).squeeze(-1)
+                        output_list.append(given_output)
+                    output = tuple(output_list)
             output = output if isinstance(output, tuple) else (output,)
-            # The last instance will lose the gradient function if we move it to the CPU.
-            # This is because all grad functions are present on the device that ran the FW pass.
-            if index == len(model_instance.model_slices) - 1:
-                model_instance._activations.append(output)
-            else:
-                model_instance._activations.append(tuple([a.cpu() for a in list(output)]))
-            # Move the layer shard back to the CPU.
-            layer_shard.forward_drop()
+            with torch.autograd.profiler.record_function("fairscale.experimental.nn.offload:forward_drop"):
+                # The last instance will lose the gradient function if we move it to the CPU.
+                # This is because all grad functions are present on the device that ran the FW pass.
+                if index == len(model_instance.model_slices) - 1:
+                    model_instance._activations.append(output)
+                else:
+                    model_instance._activations.append(tuple([a.cpu() for a in list(output)]))
+                # Move the layer shard back to the CPU.
+                layer_shard.forward_drop()
         # TODO(anj-s): Check device of the result to make sure the outputs and targets match device.
         result = model_instance._activations[-1]
         result = [r.cuda() for r in result]
         for r in result:
             r.requires_grad = True
         return result[0] if len(result) == 1 else result
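To exercise this code path end to end, checkpoint_activation=True routes OffloadModel's forward and backward through ActivationCheckpointing, so the labels added above show up in a profile. A sketch assuming the OffloadModel constructor as documented in fairscale (argument names may differ across versions) and a CUDA device:

import torch
from fairscale.experimental.nn.offload import OffloadModel

model = torch.nn.Sequential(
    torch.nn.Linear(128, 128), torch.nn.ReLU(), torch.nn.Linear(128, 1),
)
offload_model = OffloadModel(
    model=model,
    device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    num_slices=3,
    checkpoint_activation=True,  # use the ActivationCheckpointing path above
    num_microbatches=1,
)

x = torch.randn(32, 128).cuda()
with torch.autograd.profiler.profile() as prof:
    offload_model(x).sum().backward()
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=20))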
@@ -217,8 +220,10 @@ class ActivationCheckpointing(torch.autograd.Function):
         for model_shard, activation in zip(
             reversed(model_instance.model_slices), reversed(model_instance._activations[:-1])
         ):
-            # Move the model shard to the device.
-            model_shard.backward_load()
+            with torch.autograd.profiler.record_function("fairscale.experimental.nn.offload:backward_load"):
+                # Move the model shard to the device.
+                model_shard.backward_load()
             # Store the BW pass state.
             bwd_rng_state = torch.get_rng_state()
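The bwd_rng_state snapshot above exists so the recomputed forward pass replays any stochastic ops (e.g. dropout) with the same random numbers as the original pass. A minimal sketch of that save/restore pattern (CPU RNG only; CUDA keeps separate state under torch.cuda.get_rng_state):

import torch

state = torch.get_rng_state()                         # snapshot, like bwd_rng_state
first = torch.nn.functional.dropout(torch.ones(8), p=0.5)
torch.set_rng_state(state)                            # rewind before recomputing
second = torch.nn.functional.dropout(torch.ones(8), p=0.5)
assert torch.equal(first, second)                     # identical dropout masks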
@@ -253,13 +258,17 @@ class ActivationCheckpointing(torch.autograd.Function):
                     a.requires_grad = True
                     a.retain_grad()
-                with torch.enable_grad():
-                    # Calculate the output of the last shard w.r.t. the stored activation at the slice boundary.
-                    outputs = model_shard(*chunked_activation)
+                with torch.autograd.profiler.record_function(
+                    "fairscale.experimental.nn.offload:forward_pass_with_enable_grad"
+                ):
+                    with torch.enable_grad():
+                        # Calculate the output of the last shard w.r.t. the stored activation at the slice boundary.
+                        outputs = model_shard(*chunked_activation)
                 # Set the states back to what they were at the start of this function.
                 torch.set_rng_state(bwd_rng_state)
-                torch.autograd.backward(outputs, chunked_grad)
+                with torch.autograd.profiler.record_function("fairscale.experimental.nn.offload:backward_pass"):
+                    torch.autograd.backward(outputs, chunked_grad)
                 intermediate_grads = []
                 for a in chunked_activation:
                     if a.grad is not None:
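The recompute-then-backward step above is the standard activation-checkpointing move: rerun the shard under enable_grad on detached leaf tensors, then call torch.autograd.backward with the incoming gradient and read .grad off the leaves. A toy illustration (the layer and shapes are placeholders for model_shard and chunked_activation):

import torch

activation = torch.randn(4, 8).detach()
activation.requires_grad = True            # leaf at the slice boundary
layer = torch.nn.Linear(8, 8)

with torch.enable_grad():
    outputs = layer(activation)            # recomputed forward

incoming = torch.ones_like(outputs)        # stands in for chunked_grad
torch.autograd.backward(outputs, incoming)
print(activation.grad.shape)               # gradient to pass to the next shard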
@@ -270,8 +279,9 @@ class ActivationCheckpointing(torch.autograd.Function):
                 # Append the list of grads to all_grads; these should be on the CPU.
                 all_grads.append(torch.cat(chunked_grad_list).squeeze(-1))  # type: ignore
             # TODO(anj-s): Why does moving activations to CPU cause the .grad property to be None?
-            # Move the shard back to the CPU.
-            model_shard.backward_drop()
+            with torch.autograd.profiler.record_function("fairscale.experimental.nn.offload:backward_drop"):
+                # Move the shard back to the CPU.
+                model_shard.backward_drop()
         detached_inputs = model_instance._activations[0]
         grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else inp for inp in detached_inputs)
         return (None, None) + grads
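Once these labels are in place, they can be pulled back out of a profile by filtering on the "fairscale.experimental.nn.offload:" prefix. A hedged sketch (the wrapped matmul stands in for a real offload run):

import torch

with torch.autograd.profiler.profile() as prof:
    with torch.autograd.profiler.record_function("fairscale.experimental.nn.offload:forward_load"):
        torch.randn(512, 512) @ torch.randn(512, 512)

for evt in prof.key_averages():
    if evt.key.startswith("fairscale.experimental.nn.offload:"):
        print(evt.key, evt.cpu_time_total)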