Unverified Commit 85dea5b2 authored by Benjamin Lefaudeux, committed by GitHub

[chore] SDP - adding the profiler labels (#630)

* adding the labels
* longer labels, following the aten:: naming convention
parent 38ce54b7
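
The labels added in this diff are plain torch.autograd.profiler.record_function ranges, so they surface in any standard PyTorch profiling run. A minimal sketch of how to pull them out, assuming model is already a ShardedDataParallel-wrapped module and batch is a ready input tensor (neither is part of this change):

import torch.autograd.profiler as profiler

# Assumption for illustration: `model` is wrapped in ShardedDataParallel and
# `batch` is an input tensor on the right device; neither comes from this diff.
with profiler.profile() as prof:
    out = model(batch)        # recorded under "fairscale::sdp::forward"
    out.sum().backward()

# List only the SDP ranges added here, next to the usual aten:: operators
for evt in prof.key_averages():
    if evt.key.startswith("fairscale::sdp::"):
        print(evt.key, evt.cpu_time_total)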
@@ -18,6 +18,7 @@ from typing import Any, Callable, Deque, Dict, Generator, List, Optional, Union
 import torch
 from torch import nn
 from torch.autograd import Variable
+import torch.autograd.profiler as profiler
 import torch.distributed as dist

 from fairscale.nn.misc import GradBucket
@@ -199,25 +200,26 @@ class ShardedDataParallel(nn.Module):
         backward pass for gradient reduction to the proper ranks.
         """
+        with profiler.record_function("fairscale::sdp::forward"):
             # Deferred initialization, or change detection
             needs_setup = len(self._grad_hooks) == 0 and self.training

             if self._auto_refresh_trainable:
                 needs_setup |= self._detect_train_change()

             if needs_setup:
                 self.refresh_trainable()

             if self._enable_broadcast_buffers:
                 # NCCL communications are on a different stream, needs to be blocking
                 # for the subsequent FW to be correct
                 self.sync_buffers(blocking=True)

             # Reset all the grad reduce and bucket state flags
             self._clear_counters()

             # Normal FW on the base model
             return self._module(*inputs, **kwargs)

     def to(  # type: ignore
         self,
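
One consequence of wrapping forward this way: record_function ranges nest, so when the deferred setup triggers refresh_trainable() from inside forward, its "fairscale::sdp::refresh_trainable" range (added in the next hunk) appears within the span of "fairscale::sdp::forward" in the trace. A toy illustration of that nesting behaviour, not fairscale code:

import torch.autograd.profiler as profiler

# Two nested ranges: "inner" is recorded within the time span of "outer",
# mirroring refresh_trainable() being called from the wrapped forward.
with profiler.profile() as prof:
    with profiler.record_function("outer"):
        with profiler.record_function("inner"):
            pass

print(prof.key_averages().table(sort_by="cpu_time_total"))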
@@ -274,24 +276,25 @@ class ShardedDataParallel(nn.Module):
             "Grads waiting to be reduced. If this is on purpose (grad accumulation), please use a no_sync() context"
         )

+        with profiler.record_function("fairscale::sdp::refresh_trainable"):
             self._trainable_params = list(filter(lambda x: x.requires_grad, self._all_params))
             self._trainable_params.sort(key=lambda x: x.numel())

             self._trainable_param_to_rank = {}
             for optim in self._sharded_optimizers:
                 # OSS may need to change the communication pattern
                 optim.refresh_trainable()

                 # Update ShardedDDP given the new partitions
                 for (
                     device_per_rank_params
                 ) in optim._per_device_params.values():  # all the params on this device (inc all ranks)
                     for device_params in device_per_rank_params:
                         for param in filter(lambda x: x.requires_grad, device_params):
                             self._trainable_param_to_rank[param] = optim._param_to_rank[param]

             self._setup_bucket_strategy()
             self._setup_backward_hooks()

     def reduce(self) -> None:
         """
@@ -320,18 +323,19 @@ class ShardedDataParallel(nn.Module):
             blocking (bool): wait for the operation to conclude.
         """
+        with profiler.record_function("fairscale::sdp::sync_buffers"):
             work_handles = []

             for buffer in self._module.buffers(recurse=True):
                 work_handles.append(
                     dist.broadcast(buffer.data, self._reference_global_rank, self._process_group, async_op=True)
                 )

             if blocking and work_handles:
                 if self._backend != dist.Backend.NCCL:
                     _ = list(filter(lambda x: x.wait(), work_handles))
                 else:
                     work_handles[-1].wait()

     def zero_grad(self, set_to_none: bool = False) -> None:
         r"""Sets gradients of all model parameters to zero. See similar function
@@ -480,39 +484,39 @@ class ShardedDataParallel(nn.Module):
         Attach a reduce function to each grad-requiring parameter.
         This makes the gradient reduction automatic whenever there's a backward pass
         """
+        with profiler.record_function("fairscale::sdp::setup_backward_hooks"):
             # Detach possible pre-existing hooks
             while len(self._grad_hooks) > 0:
                 self._grad_hooks.pop().remove()

             # Go through the parameters, attach the hook
             self._grad_accs = []
             self._manual_reduce = []
             if not self.training:
                 return

             for index, param in enumerate(self._trainable_params):
                 if param.grad is not None and param.grad.requires_grad:
                     raise RuntimeError("ShardedDataParallel only works with gradients that don't require grad")

                 p_tmp = param.expand_as(param)

                 # See https://pytorch.org/docs/stable/tensors.html?highlight=grad_fn
                 # We're interested in the tensors which will be tracked by Autograd
                 # Some tensors can have gradients independent of the inputs (ie. pooling layer for instance),
                 # these do not need to be sync'ed
                 if p_tmp.grad_fn is not None:
                     # Register the hook to the next function in line,
                     # so that the hook is fired when this grad has properly been computed
                     # (by default the hook with Pytorch is a pre-grad, not a post-grad)
                     grad_acc = p_tmp.grad_fn.next_functions[0][0]
                     dst_rank = self._trainable_param_to_rank[param]

                     reduce_function = self._get_reduce_fn(index, param, dst_rank)

                     self._grad_hooks.append(grad_acc.register_hook(reduce_function))
                     self._grad_accs.append(grad_acc)  # keep this hook in scope
                     self._manual_reduce.append(reduce_function)

     @torch.no_grad()
     def _sync_params_and_buffers(self) -> None:
@@ -552,41 +556,42 @@ class ShardedDataParallel(nn.Module):
         This method can be a slow for big models, but it it not typically called often (not for every forward for instance)
         """
+        with profiler.record_function("fairscale::sdp::setup_buckets"):
             if not self._use_buckets:
                 return

             # Devise the bucketing strategy. Parameters are already sorted, in that:
             # - these are only the trainable parameters, so they should produce grads
             # - they are sorted by increasing size
             self._buckets = {}
             self._should_bucket_grad = [False for _ in self._trainable_params]

             for i, param in enumerate(self._trainable_params):
                 device = param.device
                 dst_rank = self._trainable_param_to_rank[param]

                 if param.device not in self._buckets.keys():
                     self._buckets[param.device] = {}

                 if dst_rank not in self._buckets[param.device].keys():
                     self._buckets[param.device][dst_rank] = GradBucket(
                         self._buffer_max_size,
                         dtype=param.dtype,
                         device=param.device,
                         destination=self._local_to_global_rank[dst_rank],
                     )

                 # Criteria to decide whether this parameter is to be bucketed or not:
                 # - enough room in the bucket
                 if self._buckets[device][dst_rank].can_add_grad_view(param):
                     self._buckets[device][dst_rank].add_grad(param)
                     self._should_bucket_grad[i] = True

             self._bucket_list = list(chain(*[self._buckets[device].values() for device in self._buckets.keys()]))

             # Resize the buckets to remove lost space in the end
             for bucket in self._bucket_list:
                 bucket.shrink()

     def _consume_work_handles(self) -> None:
         """Consume all the futures which are tied to this optimizer's buckets.
@@ -628,19 +633,20 @@ class ShardedDataParallel(nn.Module):
         self._consume_work_handles()

     def _detect_train_change(self) -> bool:
+        with profiler.record_function("fairscale::sdp::detect_train_changes"):
             # Optionally check whether the trainable parameters have changed
             trainable_mask = list(map(_trainable, self._all_params))

             # - one or more parameters trainability changed
             trainability_changed = trainable_mask != self._reference_trainable_mask

             # - the whole model is not trainable but we still have grad hooks
             trainability_changed |= not self.training and len(self._grad_hooks) > 0

             if trainability_changed:
                 logging.warning(
                     "ShardedDDP detected that the trainable params changed, either because of eval/train mode or parameter freezing/unfreeze."
                 )

             self._reference_trainable_mask = trainable_mask

             return trainability_changed
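
Since every section above uses the same record_function pattern, the new labels also show up as named spans in a chrome trace, enclosing the aten:: operators they contain; the namespaced naming mirrors the aten:: convention called out in the commit message. A hedged sketch, again assuming a wrapped model and an input batch (neither defined in this diff):

import torch.autograd.profiler as profiler

# Assumption: `model` is a ShardedDataParallel instance and `batch` a ready input.
with profiler.profile() as prof:
    model(batch).sum().backward()

# Inspect in chrome://tracing or Perfetto: the "fairscale::sdp::*" spans
# enclose the aten:: kernels executed inside them.
prof.export_chrome_trace("sdp_trace.json")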