Unverified Commit a1bdc7d3 authored by Benjamin Lefaudeux, committed by GitHub

[refactor][fix][SDP] Extract the grad buckets in a dedicated class, fix the resize_ bug (#532)

* extracting the buckets in a dedicated class, fixing the resize_ bug
* adding a unit test
* copyright
parent fcbf1ea3
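
For context on the resize_ bug: Tensor.resize_() with a smaller size shrinks the tensor but keeps the original storage allocated, so the old in-place resize left the unused tail of every bucket buffer in memory. The snippet below is a minimal sketch of that PyTorch behaviour (it is not part of this patch, and the sizes are arbitrary), illustrating why the new GradBucket.shrink() goes through clone():

    import torch

    buf = torch.zeros(300)      # bucket buffer allocated at its maximum size
    buf.resize_(6)              # in-place shrink: the tensor now has 6 elements...
    print(len(buf.storage()))   # ...but the 300-element storage is still alive (prints 300)

    buf = buf.clone()           # clone() copies the data into an exactly-sized storage
    print(len(buf.storage()))   # prints 6, the "lost space in the end" is released

Because clone() moves the data to a fresh storage, any param.grad tensors that were views of the old buffer would stop aliasing it, which is why shrink() re-registers every gradient as a view of the new buffer afterwards.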
@@ -20,8 +20,9 @@ from torch import nn
 from torch.autograd import Variable
 import torch.distributed as dist
 
+from fairscale.nn.misc import GradBucket
 from fairscale.optim import OSS
-from fairscale.optim.utils import Bucket, Workhandle
+from fairscale.optim.utils import Workhandle
 
 
 def _trainable(param: torch.Tensor) -> bool:
@@ -171,9 +172,9 @@ class ShardedDataParallel(nn.Module):
         )
         self.use_buckets = self.buffer_max_size > 0
 
-        self.buckets: Dict[torch.device, List[Bucket]] = {}
+        self.buckets: Dict[torch.device, Dict[int, GradBucket]] = {}
         self._should_bucket_grad: List[bool] = []
-        self._bucket_list: Optional[List[Bucket]] = None
+        self._bucket_list: List[GradBucket] = []
 
         # - setup backward hooks which will be called by Torch's autograd in due time
         self._grad_accs: List[Callable] = []
@@ -257,8 +258,8 @@ class ShardedDataParallel(nn.Module):
         ), "Several devices specified to begin with, incompatible with setting a single device here"
 
         for _device in self.buckets.keys():
-            for bucket in self.buckets[_device]:
-                bucket.buffer.to(device=device, dtype=dtype, non_blocking=non_blocking)
+            for bucket in self.buckets[_device].values():
+                bucket.to(device=_device, dtype=dtype, non_blocking=non_blocking)
 
         self.module.to(device=device, dtype=dtype, non_blocking=non_blocking)
@@ -344,6 +345,9 @@ class ShardedDataParallel(nn.Module):
             elif trainable_param.grad is not None:
                 trainable_param.grad.zero_()
 
+        for bucket in self._bucket_list:
+            bucket.zero()
+
     def __getattr__(self, name: str) -> Any:
         """Forward missing attributes to wrapped module."""
         try:
@@ -367,10 +371,8 @@ class ShardedDataParallel(nn.Module):
         self._bucket_flush_callback_set = False
 
         if self.use_buckets:
-            assert self._bucket_list is not None
-
             for bucket in self._bucket_list:
-                bucket.reset()
+                bucket.reset_checked_in()
 
         if not self.should_accumulate_grads:
             self.accumulate_grads_flipped = False
@@ -443,7 +445,9 @@ class ShardedDataParallel(nn.Module):
                 bucket = self.buckets[param.device][dst_rank]
                 bucket.params_checked_in += 1
 
-                if bucket.full():
+                if bucket.all_checked_in:
+                    assert bucket.buffer is not None
+
                     # Normalize the bucket in one go
                     bucket.buffer.mul_(self.world_size_scaling)
@@ -532,47 +536,34 @@ class ShardedDataParallel(nn.Module):
         # - these are only the trainable parameters, so they should produce grads
         # - they are sorted by increasing size
         self.buckets = {}
+        self._should_bucket_grad = [False for _ in self._trainable_params]
 
-        for param in self._trainable_params:
+        for i, param in enumerate(self._trainable_params):
             device = param.device
             dst_rank = self._trainable_param_to_rank[param]
 
             if param.device not in self.buckets.keys():
-                self.buckets[param.device] = [
-                    Bucket(buffer=torch.zeros(self.buffer_max_size, dtype=param.dtype, device=device))
-                    for _ in range(dist.get_world_size(self.process_group))
-                ]
-            bucket = self.buckets[device][dst_rank]
-            bucket.destination = self._local_to_global_rank[dst_rank]
+                self.buckets[param.device] = {}
+
+            if dst_rank not in self.buckets[param.device].keys():
+                self.buckets[param.device][dst_rank] = GradBucket(
+                    self.buffer_max_size,
+                    dtype=param.dtype,
+                    device=param.device,
+                    destination=self._local_to_global_rank[dst_rank],
+                )
 
             # Criteria to decide whether this parameter is to be bucketed or not:
             # - enough room in the bucket
-            if (bucket.fill + param.numel()) < self.buffer_max_size:
-                self._should_bucket_grad.append(True)
-
-                # This parameter gradients becomes a view of the bucket
-                fill_next = bucket.fill + param.numel()
-
-                if param.grad is None:
-                    # will be overwritten just below, see next line
-                    param.grad = torch.zeros_like(param)
-
-                param.grad.data = bucket.buffer[bucket.fill : fill_next].view_as(param.data)
-                bucket.fill = fill_next
-
-                # Update the bucket
-                self.buckets[device][dst_rank].max_params_checked_in += 1
-            else:
-                self._should_bucket_grad.append(False)
+            if self.buckets[device][dst_rank].can_add_grad_view(param):
+                self.buckets[device][dst_rank].add_grad(param)
+                self._should_bucket_grad[i] = True
 
-        self._bucket_list = list(chain(*[self.buckets[device] for device in self.buckets.keys()]))
+        self._bucket_list = list(chain(*[self.buckets[device].values() for device in self.buckets.keys()]))
 
         # Resize the buckets to remove lost space in the end
         for bucket in self._bucket_list:
-            bucket.buffer.resize_(bucket.fill)
-            bucket.sent = True
+            bucket.shrink()
 
     def _consume_work_handles(self) -> None:
         """Consume all the futures which are tied to this optimizer's buckets.
@@ -593,21 +584,22 @@ class ShardedDataParallel(nn.Module):
             work_handle.callback()
 
     def _flush_reduce_calls(self) -> None:
-        if self._bucket_list is not None:
-            for bucket in self._bucket_list:
-                if not bucket.sent:
-                    # Normalize the bucket in one go
-                    bucket.buffer.mul_(self.world_size_scaling)
-
-                    # Reduce the bucket
-                    self._work_handles.append(
-                        Workhandle(
-                            handle=dist.reduce(
-                                tensor=bucket.buffer, dst=bucket.destination, group=self.process_group, async_op=True,
-                            ),
-                            callback=None,
-                        )
-                    )
-                    bucket.sent = True
+        for bucket in self._bucket_list:
+            if not bucket.sent:
+                assert bucket.buffer is not None
+
+                # Normalize the bucket in one go
+                bucket.buffer.mul_(self.world_size_scaling)
+
+                # Reduce the bucket
+                self._work_handles.append(
+                    Workhandle(
+                        handle=dist.reduce(
+                            tensor=bucket.buffer, dst=bucket.destination, group=self.process_group, async_op=True,
+                        ),
+                        callback=None,
+                    )
+                )
+                bucket.sent = True
 
         self._consume_work_handles()
@@ -5,3 +5,4 @@
 from .checkpoint_activations import checkpoint_wrapper
 from .flatten_params_wrapper import FlattenParamsWrapper
+from .grad_bucket import GradBucket
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from typing import Any, Callable, List, Optional, Union
import torch
class GradBucket:
"""
Helper class to simplify the handling of gradient buckets
"""
def __init__(self, size: int, dtype: torch.dtype, device: torch.device, destination: int) -> None:
self._max_size = size
self._params: List[torch.Tensor] = []
self._fill = 0
self._is_collapsed = False
# The actual flat tensor
self.buffer: Optional[torch.Tensor] = torch.zeros(self._max_size, dtype=dtype, device=device)
self.params_checked_in = 0
self.destination = destination
self.sent = True
self.callback: Optional[Callable[[Any], None]] = None
def reset_checked_in(self) -> None:
""" Reset the counter of the parameter grads which have been checked in
"""
self.params_checked_in = 0
self.sent = False
@property
def all_checked_in(self) -> bool:
""" Have all the expected gradient check-in happened ?"""
return len(self._params) == self.params_checked_in
def can_add_grad_view(self, param: torch.Tensor) -> bool:
""" Is there enough room in the bucket to add this parameter gradient ?
"""
return self._fill + param.numel() < self._max_size
def to( # type: ignore
self,
device: Optional[Union[int, torch.device]],
dtype: Optional[torch.dtype] = None,
non_blocking: bool = False,
) -> "GradBucket":
"""
Move the underlying buffer
"""
if self._is_collapsed:
self.rebuild()
assert self.buffer is not None, "Cannot move a collapsed bucket, please rebuild it"
self.buffer.to(device, dtype, non_blocking)
def zero(self) -> None:
"""
Set all the grads to zero
"""
if self.buffer is not None:
self.buffer.fill_(0.0)
@torch.no_grad()
def add_grad(self, param: torch.Tensor) -> None:
"""
        Add a new parameter gradient to the bucket. param.grad becomes a view of this bucket's buffer.
"""
if param.grad is None:
param.grad = torch.zeros_like(param)
self._add_grad_as_view(param)
self._params.append(param)
@torch.no_grad()
def collapse(self) -> None:
"""
Release the buffer from memory. The bucket will need to be rebuilt before use
"""
if not self._is_collapsed:
for p in self._params:
assert p.grad is not None
p.grad.detach_()
p.grad = None
self.buffer = None
self._fill = 0
self.params_checked_in = 0
self._is_collapsed = True
@torch.no_grad()
def rebuild(self) -> None:
"""
Given the parameter gradients which have been registered previously, rebuild the whole bucket
"""
assert len(self._params) > 0
if self._is_collapsed:
self.buffer = torch.zeros(self._max_size, dtype=self._params[0].dtype, device=self._params[0].device)
for p in self._params:
self._add_grad_as_view(p)
self._is_collapsed = False
@torch.no_grad()
def shrink(self) -> None:
"""
        Shrink the buffer to the size of the parameter gradients currently checked in, and release the extra memory.
"""
assert self.buffer is not None, "Cannot shrink a collapsed bucket, please rebuild"
self.buffer = self.buffer.resize_(self._fill).clone()
self._fill = 0
for p in self._params:
self._add_grad_as_view(p)
self._max_size = self._fill
@torch.no_grad()
def _add_grad_as_view(self, param: torch.Tensor) -> None:
assert self.buffer is not None
assert param.dtype == self.buffer.dtype
assert param.device == self.buffer.device
fill_next = self._fill + param.numel()
assert fill_next <= self.buffer.numel()
# Copy the current grad value, if any
if param.grad is not None:
# keep param.grad in place
self.buffer[self._fill : fill_next].copy_(param.grad.data.flatten())
param.grad.data = self.buffer[self._fill : fill_next].view_as(param.data)
else:
param.grad = self.buffer[self._fill : fill_next].view_as(param.data)
self._fill = fill_next
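
As a quick orientation (not part of the commit), here is a sketch of the lifecycle ShardedDataParallel drives on a single GradBucket; the distributed reduce and rank bookkeeping from the diff above are omitted, so this is illustrative only:

    import torch
    from fairscale.nn.misc import GradBucket

    params = [torch.rand(4, 4, requires_grad=True) for _ in range(3)]
    bucket = GradBucket(128, dtype=torch.float32, device=torch.device("cpu"), destination=0)

    # Setup phase: register the gradients as views of the flat buffer, then drop the unused tail
    for p in params:
        if bucket.can_add_grad_view(p):
            bucket.add_grad(p)          # p.grad now aliases a slice of bucket.buffer
    bucket.shrink()

    # One backward pass: reset the counters, check the grads in, then the bucket is ready to reduce
    bucket.reset_checked_in()
    for p in params:
        bucket.params_checked_in += 1   # done by the autograd hooks in the real code
    if bucket.all_checked_in:
        bucket.buffer.mul_(1.0)         # normalization; dist.reduce(bucket.buffer, ...) would follow
        bucket.sent = True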
@@ -79,32 +79,6 @@ def broadcast_object(
     return obj
 
 
-class Bucket:
-    """
-    Helper class to simplify the handling of broadcast or reduce buckets
-    """
-
-    def __init__(self, buffer: torch.Tensor) -> None:
-        # The actual flat tensor
-        self.buffer = buffer
-        self.max_size = buffer.numel()
-
-        # Current status for this buffer
-        self.fill = 0
-        self.params_checked_in = 0
-        self.max_params_checked_in = 0  # atttribute present for convenience purposes
-        self.destination = -1
-        self.sent = True
-
-    def reset(self) -> None:
-        self.params_checked_in = 0
-        self.sent = False
-
-    def full(self) -> bool:
-        """ is the bucket full ? """
-        return self.max_params_checked_in == self.params_checked_in
-
-
 def calc_grad_norm(parameters: List[torch.nn.Parameter], p: float) -> torch.Tensor:
     r"""Calculate gradient norm of an iterable of parameters.
 
     Returns:
@@ -4,6 +4,7 @@ tests/utils/test_parallel.py
 tests/utils/test_state_dict.py
 tests/nn/misc/test_checkpoint_activations.py
 tests/nn/misc/test_checkpoint_activations_norm.py
+tests/nn/misc/test_grad_bucket.py
 tests/nn/wrap/test_wrap.py
 tests/nn/pipe_process/test_pipe.py
 tests/nn/pipe_process/test_transparency.py
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import pytest
import torch
from fairscale.nn.misc import GradBucket
def test_grad_values_conserved():
with torch.no_grad(): # remove a warning
param = torch.rand((2, 3), requires_grad=True)
param.grad = torch.rand(2, 3)
bucket = GradBucket(10, param.dtype, param.device, -1)
        param_ = param.clone()
        # clone() does not copy .grad, so carry the gradient over before checking it is conserved
        param_.grad = param.grad.clone()

        bucket.add_grad(param_)
        assert torch.allclose(param.grad, param_.grad)
def test_memory_leak():
with torch.no_grad(): # remove a warning
param = torch.rand((2, 3), requires_grad=True)
param.grad = torch.rand(2, 3)
bucket = GradBucket(300, param.dtype, param.device, -1)
bucket.add_grad(param)
bucket.shrink()
assert len(bucket.buffer.storage()) == 6
def test_max_size():
with torch.no_grad(): # remove a warning
param = torch.rand((20, 30), requires_grad=True)
param.grad = torch.rand(20, 30)
bucket = GradBucket(5, param.dtype, param.device, -1)
with pytest.raises(AssertionError):
bucket.add_grad(param)
def test_collapse():
with torch.no_grad(): # remove a warning
size = (5, 6)
param = torch.rand(size, requires_grad=True)
param.grad = torch.rand(size)
bucket = GradBucket(300, param.dtype, param.device, -1)
bucket.add_grad(param)
bucket.shrink()
bucket.collapse()
assert bucket.buffer is None
assert param.grad is None
bucket.rebuild()
assert param.grad is not None
        assert torch.allclose(param.grad, torch.zeros(size))