Unverified Commit 85dea5b2 authored by Benjamin Lefaudeux, committed by GitHub

[chore] SDP - adding the profiler labels (#630)

* adding the profiler labels
* longer labels, following the aten:: naming convention
parent 38ce54b7
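
For context, torch.autograd.profiler.record_function tags a code region so that it shows up as a named range in profiler output, next to PyTorch's built-in aten:: operator labels; that is what the fairscale::sdp:: labels in the diff below provide. A minimal, self-contained sketch of the pattern (illustrative only, not part of this commit; the toy function and tensor sizes are made up):

import torch
import torch.autograd.profiler as profiler


def toy_forward(x: torch.Tensor) -> torch.Tensor:
    # Same pattern as in the diff below: label a custom region so that it is
    # reported alongside the aten:: ops it wraps.
    with profiler.record_function("fairscale::sdp::forward"):
        return torch.relu(x @ x.t())


with profiler.profile() as prof:
    toy_forward(torch.randn(128, 128))

# The custom label appears as its own row in the table, next to the
# aten::matmul / aten::relu calls it contains.
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
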
@@ -18,6 +18,7 @@ from typing import Any, Callable, Deque, Dict, Generator, List, Optional, Union
 import torch
 from torch import nn
 from torch.autograd import Variable
+import torch.autograd.profiler as profiler
 import torch.distributed as dist

 from fairscale.nn.misc import GradBucket
@@ -199,25 +200,26 @@ class ShardedDataParallel(nn.Module):
         backward pass for gradient reduction to the proper ranks.
         """
-        # Deferred initialization, or change detection
-        needs_setup = len(self._grad_hooks) == 0 and self.training
+        with profiler.record_function("fairscale::sdp::forward"):
+            # Deferred initialization, or change detection
+            needs_setup = len(self._grad_hooks) == 0 and self.training

-        if self._auto_refresh_trainable:
-            needs_setup |= self._detect_train_change()
+            if self._auto_refresh_trainable:
+                needs_setup |= self._detect_train_change()

-        if needs_setup:
-            self.refresh_trainable()
+            if needs_setup:
+                self.refresh_trainable()

-        if self._enable_broadcast_buffers:
-            # NCCL communications are on a different stream, needs to be blocking
-            # for the subsequent FW to be correct
-            self.sync_buffers(blocking=True)
+            if self._enable_broadcast_buffers:
+                # NCCL communications are on a different stream, needs to be blocking
+                # for the subsequent FW to be correct
+                self.sync_buffers(blocking=True)

-        # Reset all the grad reduce and bucket state flags
-        self._clear_counters()
+            # Reset all the grad reduce and bucket state flags
+            self._clear_counters()

-        # Normal FW on the base model
-        return self._module(*inputs, **kwargs)
+            # Normal FW on the base model
+            return self._module(*inputs, **kwargs)

     def to(  # type: ignore
         self,
@@ -274,24 +276,25 @@ class ShardedDataParallel(nn.Module):
                 "Grads waiting to be reduced. If this is on purpose (grad accumulation), please use a no_sync() context"
             )

-        self._trainable_params = list(filter(lambda x: x.requires_grad, self._all_params))
-        self._trainable_params.sort(key=lambda x: x.numel())
+        with profiler.record_function("fairscale::sdp::refresh_trainable"):
+            self._trainable_params = list(filter(lambda x: x.requires_grad, self._all_params))
+            self._trainable_params.sort(key=lambda x: x.numel())

-        self._trainable_param_to_rank = {}
-        for optim in self._sharded_optimizers:
-            # OSS may need to change the communication pattern
-            optim.refresh_trainable()
+            self._trainable_param_to_rank = {}
+            for optim in self._sharded_optimizers:
+                # OSS may need to change the communication pattern
+                optim.refresh_trainable()

-            # Update ShardedDDP given the new partitions
-            for (
-                device_per_rank_params
-            ) in optim._per_device_params.values():  # all the params on this device (inc all ranks)
-                for device_params in device_per_rank_params:
-                    for param in filter(lambda x: x.requires_grad, device_params):
-                        self._trainable_param_to_rank[param] = optim._param_to_rank[param]
+                # Update ShardedDDP given the new partitions
+                for (
+                    device_per_rank_params
+                ) in optim._per_device_params.values():  # all the params on this device (inc all ranks)
+                    for device_params in device_per_rank_params:
+                        for param in filter(lambda x: x.requires_grad, device_params):
+                            self._trainable_param_to_rank[param] = optim._param_to_rank[param]

-        self._setup_bucket_strategy()
-        self._setup_backward_hooks()
+            self._setup_bucket_strategy()
+            self._setup_backward_hooks()

     def reduce(self) -> None:
         """
@@ -320,18 +323,19 @@ class ShardedDataParallel(nn.Module):
             blocking (bool): wait for the operation to conclude.
         """
-        work_handles = []
+        with profiler.record_function("fairscale::sdp::sync_buffers"):
+            work_handles = []

-        for buffer in self._module.buffers(recurse=True):
-            work_handles.append(
-                dist.broadcast(buffer.data, self._reference_global_rank, self._process_group, async_op=True)
-            )
+            for buffer in self._module.buffers(recurse=True):
+                work_handles.append(
+                    dist.broadcast(buffer.data, self._reference_global_rank, self._process_group, async_op=True)
+                )

-        if blocking and work_handles:
-            if self._backend != dist.Backend.NCCL:
-                _ = list(filter(lambda x: x.wait(), work_handles))
-            else:
-                work_handles[-1].wait()
+            if blocking and work_handles:
+                if self._backend != dist.Backend.NCCL:
+                    _ = list(filter(lambda x: x.wait(), work_handles))
+                else:
+                    work_handles[-1].wait()

     def zero_grad(self, set_to_none: bool = False) -> None:
         r"""Sets gradients of all model parameters to zero. See similar function
@@ -480,39 +484,39 @@ class ShardedDataParallel(nn.Module):
         Attach a reduce function to each grad-requiring parameter.
         This makes the gradient reduction automatic whenever there's a backward pass
         """
-        # Detach possible pre-existing hooks
-        while len(self._grad_hooks) > 0:
-            self._grad_hooks.pop().remove()
-
-        # Go through the parameters, attach the hook
-        self._grad_accs = []
-        self._manual_reduce = []
-        if not self.training:
-            return
-
-        for index, param in enumerate(self._trainable_params):
-            if param.grad is not None and param.grad.requires_grad:
-                raise RuntimeError("ShardedDataParallel only works with gradients that don't require grad")
-
-            p_tmp = param.expand_as(param)
-
-            # See https://pytorch.org/docs/stable/tensors.html?highlight=grad_fn
-            # We're interested in the tensors which will be tracked by Autograd
-            # Some tensors can have gradients independent of the inputs (ie. pooling layer for instance),
-            # these do not need to be sync'ed
-            if p_tmp.grad_fn is not None:
-                # Register the hook to the next function in line,
-                # so that the hook is fired when this grad has properly been computed
-                # (by default the hook with Pytorch is a pre-grad, not a post-grad)
-                grad_acc = p_tmp.grad_fn.next_functions[0][0]
-                dst_rank = self._trainable_param_to_rank[param]
-
-                reduce_function = self._get_reduce_fn(index, param, dst_rank)
-
-                self._grad_hooks.append(grad_acc.register_hook(reduce_function))
-                self._grad_accs.append(grad_acc)  # keep this hook in scope
-                self._manual_reduce.append(reduce_function)
+        with profiler.record_function("fairscale::sdp::setup_backward_hooks"):
+            # Detach possible pre-existing hooks
+            while len(self._grad_hooks) > 0:
+                self._grad_hooks.pop().remove()
+
+            # Go through the parameters, attach the hook
+            self._grad_accs = []
+            self._manual_reduce = []
+            if not self.training:
+                return
+
+            for index, param in enumerate(self._trainable_params):
+                if param.grad is not None and param.grad.requires_grad:
+                    raise RuntimeError("ShardedDataParallel only works with gradients that don't require grad")
+
+                p_tmp = param.expand_as(param)
+
+                # See https://pytorch.org/docs/stable/tensors.html?highlight=grad_fn
+                # We're interested in the tensors which will be tracked by Autograd
+                # Some tensors can have gradients independent of the inputs (ie. pooling layer for instance),
+                # these do not need to be sync'ed
+                if p_tmp.grad_fn is not None:
+                    # Register the hook to the next function in line,
+                    # so that the hook is fired when this grad has properly been computed
+                    # (by default the hook with Pytorch is a pre-grad, not a post-grad)
+                    grad_acc = p_tmp.grad_fn.next_functions[0][0]
+                    dst_rank = self._trainable_param_to_rank[param]
+
+                    reduce_function = self._get_reduce_fn(index, param, dst_rank)
+
+                    self._grad_hooks.append(grad_acc.register_hook(reduce_function))
+                    self._grad_accs.append(grad_acc)  # keep this hook in scope
+                    self._manual_reduce.append(reduce_function)

     @torch.no_grad()
     def _sync_params_and_buffers(self) -> None:
@@ -552,41 +556,42 @@ class ShardedDataParallel(nn.Module):
         This method can be slow for big models, but it is not typically called often (not for every forward, for instance)
         """
-        if not self._use_buckets:
-            return
+        with profiler.record_function("fairscale::sdp::setup_buckets"):
+            if not self._use_buckets:
+                return

-        # Devise the bucketing strategy. Parameters are already sorted, in that:
-        # - these are only the trainable parameters, so they should produce grads
-        # - they are sorted by increasing size
-        self._buckets = {}
-        self._should_bucket_grad = [False for _ in self._trainable_params]
+            # Devise the bucketing strategy. Parameters are already sorted, in that:
+            # - these are only the trainable parameters, so they should produce grads
+            # - they are sorted by increasing size
+            self._buckets = {}
+            self._should_bucket_grad = [False for _ in self._trainable_params]

-        for i, param in enumerate(self._trainable_params):
-            device = param.device
-            dst_rank = self._trainable_param_to_rank[param]
+            for i, param in enumerate(self._trainable_params):
+                device = param.device
+                dst_rank = self._trainable_param_to_rank[param]

-            if param.device not in self._buckets.keys():
-                self._buckets[param.device] = {}
+                if param.device not in self._buckets.keys():
+                    self._buckets[param.device] = {}

-            if dst_rank not in self._buckets[param.device].keys():
-                self._buckets[param.device][dst_rank] = GradBucket(
-                    self._buffer_max_size,
-                    dtype=param.dtype,
-                    device=param.device,
-                    destination=self._local_to_global_rank[dst_rank],
-                )
+                if dst_rank not in self._buckets[param.device].keys():
+                    self._buckets[param.device][dst_rank] = GradBucket(
+                        self._buffer_max_size,
+                        dtype=param.dtype,
+                        device=param.device,
+                        destination=self._local_to_global_rank[dst_rank],
+                    )

-            # Criteria to decide whether this parameter is to be bucketed or not:
-            # - enough room in the bucket
-            if self._buckets[device][dst_rank].can_add_grad_view(param):
-                self._buckets[device][dst_rank].add_grad(param)
-                self._should_bucket_grad[i] = True
+                # Criteria to decide whether this parameter is to be bucketed or not:
+                # - enough room in the bucket
+                if self._buckets[device][dst_rank].can_add_grad_view(param):
+                    self._buckets[device][dst_rank].add_grad(param)
+                    self._should_bucket_grad[i] = True

-        self._bucket_list = list(chain(*[self._buckets[device].values() for device in self._buckets.keys()]))
+            self._bucket_list = list(chain(*[self._buckets[device].values() for device in self._buckets.keys()]))

-        # Resize the buckets to remove lost space in the end
-        for bucket in self._bucket_list:
-            bucket.shrink()
+            # Resize the buckets to remove lost space in the end
+            for bucket in self._bucket_list:
+                bucket.shrink()

     def _consume_work_handles(self) -> None:
         """Consume all the futures which are tied to this optimizer's buckets.
@@ -628,19 +633,20 @@ class ShardedDataParallel(nn.Module):
         self._consume_work_handles()

     def _detect_train_change(self) -> bool:
-        # Optionally check whether the trainable parameters have changed
-        trainable_mask = list(map(_trainable, self._all_params))
+        with profiler.record_function("fairscale::sdp::detect_train_changes"):
+            # Optionally check whether the trainable parameters have changed
+            trainable_mask = list(map(_trainable, self._all_params))

-        # - one or more parameters trainability changed
-        trainability_changed = trainable_mask != self._reference_trainable_mask
+            # - one or more parameters trainability changed
+            trainability_changed = trainable_mask != self._reference_trainable_mask

-        # - the whole model is not trainable but we still have grad hooks
-        trainability_changed |= not self.training and len(self._grad_hooks) > 0
+            # - the whole model is not trainable but we still have grad hooks
+            trainability_changed |= not self.training and len(self._grad_hooks) > 0

-        if trainability_changed:
-            logging.warning(
-                "ShardedDDP detected that the trainable params changed, either because of eval/train mode or parameter freezing/unfreeze."
-            )
-            self._reference_trainable_mask = trainable_mask
+            if trainability_changed:
+                logging.warning(
+                    "ShardedDDP detected that the trainable params changed, either because of eval/train mode or parameter freezing/unfreeze."
+                )
+                self._reference_trainable_mask = trainable_mask

         return trainability_changed
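
To inspect these regions on a timeline rather than in a table, the profile can also be exported as a Chrome trace. Another small sketch, using a plain nn.Linear as a stand-in for the wrapped model, since a real ShardedDataParallel instance needs an initialized process group that is out of scope here:

import torch
import torch.autograd.profiler as profiler
from torch import nn

# Stand-in module: in real use this would be the ShardedDataParallel-wrapped
# model (which requires an initialized torch.distributed process group).
model = nn.Linear(32, 32)
batch = torch.randn(8, 32)

with profiler.profile(use_cuda=torch.cuda.is_available()) as prof:
    loss = model(batch).sum()
    loss.backward()

# Open the file in chrome://tracing (or Perfetto); with ShardedDataParallel in
# the loop, ranges such as fairscale::sdp::forward appear as labeled spans.
prof.export_chrome_trace("sdp_trace.json")
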