Unverified commit a1bdc7d3, authored by Benjamin Lefaudeux, committed by GitHub

[refactor][fix][SDP] Extract the grad buckets in a dedicated class, fix the resize_ bug (#532)

* extracting the buckets in a dedicated class, fixing the resize_ bug
* adding a unit test
* copyright
parent fcbf1ea3
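
For context, a minimal sketch (not part of this diff) of the behaviour the fix addresses: Tensor.resize_ never shrinks the underlying storage, so the over-allocated bucket buffer was kept alive; the new GradBucket.shrink clones into a right-sized tensor and rebuilds the gradient views instead. Names below (buffer, grad_view) are illustrative.

    # Illustrative only, not from this commit: why resize_ alone leaks memory
    import torch

    buffer = torch.zeros(1024)            # over-allocated flat bucket buffer
    grad_view = buffer[:6].view(2, 3)     # a parameter grad viewing into it

    buffer.resize_(6)                     # shape shrinks, but the full
    assert len(buffer.storage()) == 1024  # 1024-element storage stays alive

    buffer = buffer.clone()               # the fix: copy into fresh storage
    assert len(buffer.storage()) == 6     # right-sized; the old storage can be
                                          # freed once grad_view is rebuilt on it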
@@ -20,8 +20,9 @@ from torch import nn
 from torch.autograd import Variable
 import torch.distributed as dist

+from fairscale.nn.misc import GradBucket
 from fairscale.optim import OSS
-from fairscale.optim.utils import Bucket, Workhandle
+from fairscale.optim.utils import Workhandle


 def _trainable(param: torch.Tensor) -> bool:
@@ -171,9 +172,9 @@ class ShardedDataParallel(nn.Module):
         )
         self.use_buckets = self.buffer_max_size > 0

-        self.buckets: Dict[torch.device, List[Bucket]] = {}
+        self.buckets: Dict[torch.device, Dict[int, GradBucket]] = {}
         self._should_bucket_grad: List[bool] = []
-        self._bucket_list: Optional[List[Bucket]] = None
+        self._bucket_list: List[GradBucket] = []

         # - setup backward hooks which will be called by Torch's autograd in due time
         self._grad_accs: List[Callable] = []
@@ -257,8 +258,8 @@ class ShardedDataParallel(nn.Module):
         ), "Several devices specified to begin with, incompatible with setting a single device here"

         for _device in self.buckets.keys():
-            for bucket in self.buckets[_device]:
-                bucket.buffer.to(device=device, dtype=dtype, non_blocking=non_blocking)
+            for bucket in self.buckets[_device].values():
+                bucket.to(device=_device, dtype=dtype, non_blocking=non_blocking)

         self.module.to(device=device, dtype=dtype, non_blocking=non_blocking)
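
A note on the pattern above (illustrative, not part of the diff): Tensor.to is out of place, so calling it purely for side effect is a silent no-op and the result must be rebound, which is why the bucket now exposes its own to method.

    import torch

    t = torch.zeros(4)
    t.to(torch.float16)      # no-op: the cast result is discarded
    t = t.to(torch.float16)  # correct: rebind to the cast/moved tensor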
@@ -344,6 +345,9 @@ class ShardedDataParallel(nn.Module):
             elif trainable_param.grad is not None:
                 trainable_param.grad.zero_()

+        for bucket in self._bucket_list:
+            bucket.zero()
+
     def __getattr__(self, name: str) -> Any:
         """Forward missing attributes to wrapped module."""
         try:
@@ -367,10 +371,8 @@ class ShardedDataParallel(nn.Module):
         self._bucket_flush_callback_set = False

         if self.use_buckets:
-            assert self._bucket_list is not None
-
             for bucket in self._bucket_list:
-                bucket.reset()
+                bucket.reset_checked_in()

         if not self.should_accumulate_grads:
             self.accumulate_grads_flipped = False
@@ -443,7 +445,9 @@ class ShardedDataParallel(nn.Module):
                 bucket = self.buckets[param.device][dst_rank]
                 bucket.params_checked_in += 1

-                if bucket.full():
+                if bucket.all_checked_in:
+                    assert bucket.buffer is not None
+
                     # Normalize the bucket in one go
                     bucket.buffer.mul_(self.world_size_scaling)
@@ -532,47 +536,34 @@ class ShardedDataParallel(nn.Module):
         # - these are only the trainable parameters, so they should produce grads
         # - they are sorted by increasing size
         self.buckets = {}
+        self._should_bucket_grad = [False for _ in self._trainable_params]

-        for param in self._trainable_params:
+        for i, param in enumerate(self._trainable_params):
             device = param.device
             dst_rank = self._trainable_param_to_rank[param]

             if param.device not in self.buckets.keys():
-                self.buckets[param.device] = [
-                    Bucket(buffer=torch.zeros(self.buffer_max_size, dtype=param.dtype, device=device))
-                    for _ in range(dist.get_world_size(self.process_group))
-                ]
+                self.buckets[param.device] = {}

-            bucket = self.buckets[device][dst_rank]
-            bucket.destination = self._local_to_global_rank[dst_rank]
+            if dst_rank not in self.buckets[param.device].keys():
+                self.buckets[param.device][dst_rank] = GradBucket(
+                    self.buffer_max_size,
+                    dtype=param.dtype,
+                    device=param.device,
+                    destination=self._local_to_global_rank[dst_rank],
+                )

             # Criteria to decide whether this parameter is to be bucketed or not:
             # - enough room in the bucket
-            if (bucket.fill + param.numel()) < self.buffer_max_size:
-                self._should_bucket_grad.append(True)
-
-                # This parameter's gradient becomes a view of the bucket
-                fill_next = bucket.fill + param.numel()
-
-                if param.grad is None:
-                    # will be overwritten just below, see next line
-                    param.grad = torch.zeros_like(param)
-
-                param.grad.data = bucket.buffer[bucket.fill : fill_next].view_as(param.data)
-                bucket.fill = fill_next
-
-                # Update the bucket
-                self.buckets[device][dst_rank].max_params_checked_in += 1
-            else:
-                self._should_bucket_grad.append(False)
+            if self.buckets[device][dst_rank].can_add_grad_view(param):
+                self.buckets[device][dst_rank].add_grad(param)
+                self._should_bucket_grad[i] = True

-        self._bucket_list = list(chain(*[self.buckets[device] for device in self.buckets.keys()]))
+        self._bucket_list = list(chain(*[self.buckets[device].values() for device in self.buckets.keys()]))

         # Resize the buckets to remove lost space in the end
         for bucket in self._bucket_list:
-            bucket.buffer.resize_(bucket.fill)
-            bucket.sent = True
+            bucket.shrink()

     def _consume_work_handles(self) -> None:
         """Consume all the futures which are tied to this optimizer's buckets.
@@ -593,21 +584,22 @@ class ShardedDataParallel(nn.Module):
             work_handle.callback()

     def _flush_reduce_calls(self) -> None:
-        if self._bucket_list is not None:
-            for bucket in self._bucket_list:
-                if not bucket.sent:
-                    # Normalize the bucket in one go
-                    bucket.buffer.mul_(self.world_size_scaling)
-
-                    # Reduce the bucket
-                    self._work_handles.append(
-                        Workhandle(
-                            handle=dist.reduce(
-                                tensor=bucket.buffer, dst=bucket.destination, group=self.process_group, async_op=True,
-                            ),
-                            callback=None,
-                        )
-                    )
-                    bucket.sent = True
+        for bucket in self._bucket_list:
+            if not bucket.sent:
+                assert bucket.buffer is not None
+
+                # Normalize the bucket in one go
+                bucket.buffer.mul_(self.world_size_scaling)
+
+                # Reduce the bucket
+                self._work_handles.append(
+                    Workhandle(
+                        handle=dist.reduce(
+                            tensor=bucket.buffer, dst=bucket.destination, group=self.process_group, async_op=True,
+                        ),
+                        callback=None,
+                    )
+                )
+                bucket.sent = True

         self._consume_work_handles()
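
The flush path follows the usual async-collective pattern: scale the flat buffer once (instead of once per gradient), launch a non-blocking reduce to the shard owner, and wait on the handle later so communication overlaps the rest of the backward pass. A hedged sketch, assuming an initialized process group and a filled bucket:

    # Illustrative pattern, not part of the diff
    import torch.distributed as dist

    scaling = 1.0 / dist.get_world_size()
    bucket.buffer.mul_(scaling)  # normalize the whole bucket in one go
    handle = dist.reduce(tensor=bucket.buffer, dst=bucket.destination, async_op=True)
    bucket.sent = True
    # ... backward keeps running while the reduce is in flight ...
    handle.wait()                # consume the work handle when needed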
@@ -5,3 +5,4 @@
 from .checkpoint_activations import checkpoint_wrapper
 from .flatten_params_wrapper import FlattenParamsWrapper
+from .grad_bucket import GradBucket
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Callable, List, Optional, Union

import torch


class GradBucket:
    """
    Helper class to simplify the handling of gradient buckets
    """

    def __init__(self, size: int, dtype: torch.dtype, device: torch.device, destination: int) -> None:
        self._max_size = size
        self._params: List[torch.Tensor] = []
        self._fill = 0
        self._is_collapsed = False

        # The actual flat tensor
        self.buffer: Optional[torch.Tensor] = torch.zeros(self._max_size, dtype=dtype, device=device)

        self.params_checked_in = 0
        self.destination = destination
        self.sent = True
        self.callback: Optional[Callable[[Any], None]] = None
    def reset_checked_in(self) -> None:
        """ Reset the counter of the parameter grads which have been checked in
        """
        self.params_checked_in = 0
        self.sent = False

    @property
    def all_checked_in(self) -> bool:
        """ Have all the expected gradient check-ins happened? """
        return len(self._params) == self.params_checked_in
    def can_add_grad_view(self, param: torch.Tensor) -> bool:
        """ Is there enough room in the bucket to add this parameter gradient?
        """
        return self._fill + param.numel() < self._max_size
    def to(  # type: ignore
        self,
        device: Optional[Union[int, torch.device]],
        dtype: Optional[torch.dtype] = None,
        non_blocking: bool = False,
    ) -> "GradBucket":
        """
        Move the underlying buffer
        """
        if self._is_collapsed:
            self.rebuild()

        assert self.buffer is not None, "Cannot move a collapsed bucket, please rebuild it"
        # Tensor.to is out of place: bind the result, else this call is a no-op.
        # NOTE: if the buffer actually moves, the grad views need to be re-added.
        self.buffer = self.buffer.to(device, dtype, non_blocking)
        return self
    def zero(self) -> None:
        """
        Set all the grads to zero
        """
        if self.buffer is not None:
            self.buffer.fill_(0.0)
    @torch.no_grad()
    def add_grad(self, param: torch.Tensor) -> None:
        """
        Add a new parameter gradient to the bucket. param.grad becomes a view of this bucket's buffer
        """
        if param.grad is None:
            param.grad = torch.zeros_like(param)

        self._add_grad_as_view(param)
        self._params.append(param)
    @torch.no_grad()
    def collapse(self) -> None:
        """
        Release the buffer from memory. The bucket will need to be rebuilt before use
        """
        if not self._is_collapsed:
            for p in self._params:
                assert p.grad is not None
                p.grad.detach_()
                p.grad = None

            self.buffer = None
            self._fill = 0
            self.params_checked_in = 0
            self._is_collapsed = True
    @torch.no_grad()
    def rebuild(self) -> None:
        """
        Given the parameter gradients which have been registered previously, rebuild the whole bucket
        """
        assert len(self._params) > 0

        if self._is_collapsed:
            self.buffer = torch.zeros(self._max_size, dtype=self._params[0].dtype, device=self._params[0].device)

            for p in self._params:
                self._add_grad_as_view(p)

            self._is_collapsed = False
    @torch.no_grad()
    def shrink(self) -> None:
        """
        Shrink the buffer to the size of the parameter gradients currently checked in, release the extra memory
        """
        assert self.buffer is not None, "Cannot shrink a collapsed bucket, please rebuild"

        # resize_ alone does not release the over-allocated storage, hence the clone
        # into a right-sized tensor; the grad views are then rebuilt on the new buffer
        self.buffer = self.buffer.resize_(self._fill).clone()
        self._fill = 0
        for p in self._params:
            self._add_grad_as_view(p)

        self._max_size = self._fill
    @torch.no_grad()
    def _add_grad_as_view(self, param: torch.Tensor) -> None:
        assert self.buffer is not None
        assert param.dtype == self.buffer.dtype
        assert param.device == self.buffer.device

        fill_next = self._fill + param.numel()
        assert fill_next <= self.buffer.numel()

        # Copy the current grad value, if any
        if param.grad is not None:
            # keep param.grad in place
            self.buffer[self._fill : fill_next].copy_(param.grad.data.flatten())
            param.grad.data = self.buffer[self._fill : fill_next].view_as(param.data)
        else:
            param.grad = self.buffer[self._fill : fill_next].view_as(param.data)
        self._fill = fill_next
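
To make the lifecycle concrete, a minimal sketch (mirroring the unit tests below; single process, dummy destination rank): add a gradient, shrink the buffer to fit, release it between steps, and rebuild it on demand.

    # Illustrative sketch, not part of the diff
    import torch
    from fairscale.nn.misc import GradBucket

    param = torch.rand(2, 3, requires_grad=True)
    with torch.no_grad():
        param.grad = torch.rand(2, 3)

        bucket = GradBucket(10, param.dtype, param.device, destination=-1)
        bucket.add_grad(param)  # param.grad is now a view into bucket.buffer
        bucket.shrink()         # trim the buffer to the 6 grad elements in use

        bucket.collapse()       # free the buffer entirely between steps
        assert bucket.buffer is None and param.grad is None

        bucket.rebuild()        # re-materialize (zero-filled) and re-view grads
        assert param.grad is not None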
@@ -79,32 +79,6 @@ def broadcast_object(
     return obj


-class Bucket:
-    """
-    Helper class to simplify the handling of broadcast or reduce buckets
-    """
-
-    def __init__(self, buffer: torch.Tensor) -> None:
-        # The actual flat tensor
-        self.buffer = buffer
-        self.max_size = buffer.numel()
-
-        # Current status for this buffer
-        self.fill = 0
-        self.params_checked_in = 0
-        self.max_params_checked_in = 0  # attribute present for convenience purposes
-        self.destination = -1
-        self.sent = True
-
-    def reset(self) -> None:
-        self.params_checked_in = 0
-        self.sent = False
-
-    def full(self) -> bool:
-        """ is the bucket full? """
-        return self.max_params_checked_in == self.params_checked_in
-
-
 def calc_grad_norm(parameters: List[torch.nn.Parameter], p: float) -> torch.Tensor:
     r"""Calculate gradient norm of an iterable of parameters.

     Returns:
@@ -4,6 +4,7 @@ tests/utils/test_parallel.py
 tests/utils/test_state_dict.py
 tests/nn/misc/test_checkpoint_activations.py
 tests/nn/misc/test_checkpoint_activations_norm.py
+tests/nn/misc/test_grad_bucket.py
 tests/nn/wrap/test_wrap.py
 tests/nn/pipe_process/test_pipe.py
 tests/nn/pipe_process/test_transparency.py
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import pytest
import torch

from fairscale.nn.misc import GradBucket
def test_grad_values_conserved():
    with torch.no_grad():  # remove a warning
        param = torch.rand((2, 3), requires_grad=True)
        param.grad = torch.rand(2, 3)

        bucket = GradBucket(10, param.dtype, param.device, -1)
        param_ = param.clone()
        # clone() does not carry .grad over, so copy it explicitly
        param_.grad = param.grad.clone()

        bucket.add_grad(param_)
        assert torch.allclose(param.grad, param_.grad)
def test_memory_leak():
    with torch.no_grad():  # remove a warning
        param = torch.rand((2, 3), requires_grad=True)
        param.grad = torch.rand(2, 3)

        bucket = GradBucket(300, param.dtype, param.device, -1)
        bucket.add_grad(param)
        bucket.shrink()

        assert len(bucket.buffer.storage()) == 6
def test_max_size():
    with torch.no_grad():  # remove a warning
        param = torch.rand((20, 30), requires_grad=True)
        param.grad = torch.rand(20, 30)

        bucket = GradBucket(5, param.dtype, param.device, -1)
        with pytest.raises(AssertionError):
            bucket.add_grad(param)
def test_collapse():
    with torch.no_grad():  # remove a warning
        size = (5, 6)
        param = torch.rand(size, requires_grad=True)
        param.grad = torch.rand(size)

        bucket = GradBucket(300, param.dtype, param.device, -1)
        bucket.add_grad(param)
        bucket.shrink()

        bucket.collapse()
        assert bucket.buffer is None
        assert param.grad is None

        bucket.rebuild()
        assert param.grad is not None
        assert torch.allclose(param.grad, torch.zeros(size))
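
Besides running in CI via the list above, the new suite can be run standalone with pytest, e.g.: pytest tests/nn/misc/test_grad_bucket.py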