Unverified Commit 85dea5b2 authored by Benjamin Lefaudeux, committed by GitHub

[chore] SDP - adding the profiler labels (#630)

* adding the labels
* longer labels, following aten::
parent 38ce54b7
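For context, a minimal sketch (not part of this commit) of how such `record_function` labels surface next to the `aten::` ops in a profiler trace; `toy_forward` below is a hypothetical stand-in for `ShardedDataParallel.forward()`, and only the label string is taken from the diff.

```python
# Minimal sketch, not part of this commit: shows how a record_function label
# appears alongside aten:: ops in a profiler trace. toy_forward is a
# hypothetical stand-in for ShardedDataParallel.forward(); only the label
# string comes from the diff below.
import torch
import torch.autograd.profiler as profiler


def toy_forward(x: torch.Tensor) -> torch.Tensor:
    with profiler.record_function("fairscale::sdp::forward"):
        return x @ x.t()


with profiler.profile() as prof:
    toy_forward(torch.randn(128, 128))

# "fairscale::sdp::forward" shows up as its own row next to the aten:: matmul ops.
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
```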
@@ -18,6 +18,7 @@ from typing import Any, Callable, Deque, Dict, Generator, List, Optional, Union
import torch
from torch import nn
from torch.autograd import Variable
import torch.autograd.profiler as profiler
import torch.distributed as dist
from fairscale.nn.misc import GradBucket
@@ -199,6 +200,7 @@ class ShardedDataParallel(nn.Module):
        backward pass for gradient reduction to the proper ranks.
        """
        with profiler.record_function("fairscale::sdp::forward"):
            # Deferred initialization, or change detection
            needs_setup = len(self._grad_hooks) == 0 and self.training
@@ -274,6 +276,7 @@
                "Grads waiting to be reduced. If this is on purpose (grad accumulation), please use a no_sync() context"
            )

        with profiler.record_function("fairscale::sdp::refresh_trainable"):
            self._trainable_params = list(filter(lambda x: x.requires_grad, self._all_params))
            self._trainable_params.sort(key=lambda x: x.numel())
@@ -320,6 +323,7 @@
            blocking (bool): wait for the operation to conclude.
        """
        with profiler.record_function("fairscale::sdp::sync_buffers"):
            work_handles = []
            for buffer in self._module.buffers(recurse=True):
@@ -480,7 +484,7 @@
        Attach a reduce function to each grad-requiring parameter.
        This makes the gradient reduction automatic whenever there's a backward pass
        """
        with profiler.record_function("fairscale::sdp::setup_backward_hooks"):
            # Detach possible pre-existing hooks
            while len(self._grad_hooks) > 0:
                self._grad_hooks.pop().remove()
@@ -552,6 +556,7 @@
        This method can be slow for big models, but it is not typically called often (not for every forward, for instance)
        """
        with profiler.record_function("fairscale::sdp::setup_buckets"):
            if not self._use_buckets:
                return
@@ -628,6 +633,7 @@
        self._consume_work_handles()

    def _detect_train_change(self) -> bool:
        with profiler.record_function("fairscale::sdp::detect_train_changes"):
            # Optionally check whether the trainable parameters have changed
            trainable_mask = list(map(_trainable, self._all_params))
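A hedged usage sketch of how one might pull out just the new `fairscale::sdp::` entries after a profiled step; `profile_sdp_step` and its arguments are illustrative names, and a real run would need `torch.distributed` initialized plus ShardedDataParallel wrapped around the model with an OSS optimizer.

```python
# Hedged sketch: filter a profiler trace down to the fairscale::sdp:: labels.
# profile_sdp_step, model, optimizer, and batch are illustrative; a real run
# needs torch.distributed initialized and ShardedDataParallel/OSS set up.
import torch.autograd.profiler as profiler


def profile_sdp_step(model, optimizer, batch):
    with profiler.profile() as prof:
        loss = model(batch).sum()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    # Print only the rows emitted by the labels added in this commit.
    for evt in prof.key_averages():
        if evt.key.startswith("fairscale::sdp::"):
            print(f"{evt.key}: {evt.cpu_time_total:.1f}us")
```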