Unverified Commit e3865549 authored by Benjamin Lefaudeux, committed by GitHub

[feat][refactor][OSS] Param buckets + fp16 broadcasts (#540)

* param buckets
* unifying the buckets
parent 195d62f1
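As a quick reference for the headline feature, here is a minimal usage sketch of the new broadcast_fp16 flag on the OSS constructor changed below. The model, loss and process-group setup are illustrative assumptions, not part of this commit:

import torch
import torch.distributed as dist

from fairscale.optim import OSS

# Assumes torch.distributed is already initialized (e.g. via dist.init_process_group)
# and that each process drives a single GPU.
model = torch.nn.Linear(32, 32).cuda()

# Shard the optimizer state across ranks; after each step the updated parameter
# shards are exchanged between ranks in fp16 instead of fp32. Per the docstring
# below, this is safe with torch AMP and costs a little accuracy without it.
optimizer = OSS(model.parameters(), optim=torch.optim.SGD, lr=0.1, broadcast_fp16=True)

loss = model(torch.rand(8, 32, device="cuda")).sum()
loss.backward()
optimizer.step()  # local shard update, then fp16 broadcast of the parameter buckets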
......@@ -5,4 +5,4 @@
from .checkpoint_activations import checkpoint_wrapper
from .flatten_params_wrapper import FlattenParamsWrapper
from .grad_bucket import GradBucket
from .param_bucket import GradBucket, ParamBucket
......@@ -16,6 +16,7 @@ class GradBucket:
def __init__(self, size: int, dtype: torch.dtype, device: torch.device, destination: int) -> None:
self._max_size = size
self._params: List[torch.Tensor] = []
self._param_ids: List[int] = []
self._fill = 0
self._is_collapsed = False
......@@ -39,9 +40,9 @@ class GradBucket:
return len(self._params) == self.params_checked_in
def can_add_grad_view(self, param: torch.Tensor) -> bool:
""" Is there enough room in the bucket to add this parameter gradient ?
""" Is there enough room in the bucket to add this parameter gradient, and is this param not already checked in ?
"""
return self._fill + param.numel() < self._max_size
return self._fill + param.numel() < self._max_size and id(param) not in self._param_ids
def to( # type: ignore
self,
......@@ -70,11 +71,15 @@ class GradBucket:
"""
Add a new parameter gradient to the bucket. Param.grad becomes a view of this bucket buffer
"""
assert id(param) not in self._param_ids, "The same gradients cannot be checked in twice"
if param.grad is None:
param.grad = torch.zeros_like(param)
self._add_grad_as_view(param)
self._params.append(param)
self._param_ids.append(id(param))
@torch.no_grad()
def collapse(self) -> None:
......
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from typing import Any, Callable, List, Optional, Union
import torch
class Bucket:
"""
Helper class to simplify the handling of buckets, which unify the underlying storage of multiple tensors
"""
def __init__(self, size: int, dtype: torch.dtype, device: torch.device) -> None:
self._params: List[torch.Tensor] = []
self._param_ids: List[int] = []
self._fill = 0
# The actual flat tensor
self.buffer: torch.Tensor = torch.zeros(size, dtype=dtype, device=device)
def to( # type: ignore
self,
device: Optional[Union[int, torch.device]],
dtype: Optional[torch.dtype] = None,
non_blocking: bool = False,
keep_param_alignment: bool = True,
) -> "ParamBucket":
"""
Move the underlying buffer
"""
assert self.buffer is not None, "Cannot move a collapsed bucket, please rebuild it"
self.buffer.to(device, dtype, non_blocking)
class ParamBucket(Bucket):
"""
Helper class to simplify the handling of parameter buckets
"""
def __init__(self, size: int, dtype: torch.dtype, device: torch.device) -> None:
super().__init__(size, dtype, device)
def to( # type: ignore
self,
device: Optional[Union[int, torch.device]],
dtype: Optional[torch.dtype] = None,
non_blocking: bool = False,
keep_param_alignment: bool = True,
) -> "ParamBucket":
"""
Move the underlying buffer
"""
super().to(device, dtype, non_blocking)
if keep_param_alignment:
self._reattach_params()
@torch.no_grad()
def add_param(self, param: torch.Tensor) -> None:
"""
Add a new parameter to the bucket. Param.data becomes a view of this bucket buffer
"""
assert id(param) not in self._param_ids, "The same param cannot be checked in twice"
self._add_param_as_view(param)
self._params.append(param)
self._param_ids.append(id(param))
@torch.no_grad()
def _add_param_as_view(self, param: torch.Tensor, keep_existing_value: bool = True) -> None:
assert self.buffer is not None
assert param.dtype == self.buffer.dtype
assert param.device == self.buffer.device
fill_next = self._fill + param.numel()
assert fill_next <= self.buffer.numel()
# Copy the current param value
if keep_existing_value:
self.buffer[self._fill : fill_next].copy_(param.data.flatten())
param.data = self.buffer[self._fill : fill_next].view_as(param.data)
self._fill = fill_next
@torch.no_grad()
def _reattach_params(self) -> None:
"""
Given the parameters which have been registered previously, rebuild the whole bucket
"""
assert len(self._params) > 0
self._fill = 0
for p in self._params:
self._add_param_as_view(p, keep_existing_value=False)
class GradBucket(Bucket):
"""
Helper class to simplify the handling of gradient buckets
"""
def __init__(self, size: int, dtype: torch.dtype, device: torch.device, destination: int) -> None:
super().__init__(size, dtype, device)
self._max_size = size
self._is_collapsed = False
self.params_checked_in = 0
self.destination = destination
self.sent = True
self.callback: Optional[Callable[[Any], None]] = None
def reset_checked_in(self) -> None:
""" Reset the counter of the parameter grads which have been checked in
"""
self.params_checked_in = 0
self.sent = False
@property
def all_checked_in(self) -> bool:
""" Have all the expected gradient check-in happened ?"""
return len(self._params) == self.params_checked_in
def can_add_grad_view(self, param: torch.Tensor) -> bool:
""" Is there enough room in the bucket to add this parameter gradient, and is this param not already checked in ?
"""
return self._fill + param.numel() < self._max_size and id(param) not in self._param_ids
def to( # type: ignore
self,
device: Optional[Union[int, torch.device]],
dtype: Optional[torch.dtype] = None,
non_blocking: bool = False,
keep_param_alignment: bool = True,
) -> "GradBucket":
"""
Move the underlying buffer
"""
if self._is_collapsed:
self.rebuild()
super().to(device, dtype, non_blocking)
if keep_param_alignment:
self._reattach_grads()
def zero(self) -> None:
"""
Set all the grads to zero
"""
self.buffer.fill_(0.0)
@torch.no_grad()
def add_grad(self, param: torch.Tensor) -> None:
"""
Add a new parameter gradient to the bucket. Param.grad becomes a view of this bucket buffer
"""
assert id(param) not in self._param_ids, "The same gradients cannot be checked in twice"
if param.grad is None:
param.grad = torch.zeros_like(param)
self._add_grad_as_view(param)
self._params.append(param)
self._param_ids.append(id(param))
@torch.no_grad()
def collapse(self) -> None:
"""
Release the buffer from memory. The bucket will need to be rebuilt before use
"""
if not self._is_collapsed:
for p in self._params:
assert p.grad is not None
p.grad.detach_()
p.grad = None
self.buffer = torch.zeros(0, dtype=self.buffer.dtype, device=self.buffer.device)
self._fill = 0
self.params_checked_in = 0
self._is_collapsed = True
@torch.no_grad()
def rebuild(self) -> None:
"""
Given the parameter gradients which have been registered previously, rebuild the whole bucket
"""
assert len(self._params) > 0
if self._is_collapsed:
self.buffer = torch.zeros(self._max_size, dtype=self._params[0].dtype, device=self._params[0].device)
for p in self._params:
self._add_grad_as_view(p)
self._is_collapsed = False
@torch.no_grad()
def shrink(self) -> None:
"""
Shrink the buffer to the size of the parameter gradients currently checked in, release the extra memory
"""
assert self.buffer.numel() > 0, "Cannot shrink a collapsed bucket, please rebuild"
self.buffer = self.buffer.resize_(self._fill).clone()
self._fill = 0
for p in self._params:
self._add_grad_as_view(p)
self._max_size = self._fill
@torch.no_grad()
def _reattach_grads(self) -> None:
"""
Given the parameter gradients which have been registered previously, rebuild the whole bucket
"""
assert len(self._params) > 0
self._fill = 0
for p in self._params:
self._add_grad_as_view(p, keep_existing_value=False)
@torch.no_grad()
def _add_grad_as_view(self, param: torch.Tensor, keep_existing_value: bool = True) -> None:
assert self.buffer.numel() > 0, "Cannot add a gradient to a collapsed bucket, please rebuild"
assert param.dtype == self.buffer.dtype
assert param.device == self.buffer.device
fill_next = self._fill + param.numel()
assert fill_next <= self.buffer.numel()
# Copy the current grad value, if any
if param.grad is not None:
# keep param.grad in place
if keep_existing_value:
self.buffer[self._fill : fill_next].copy_(param.grad.data.flatten())
param.grad.data = self.buffer[self._fill : fill_next].view_as(param.data)
else:
param.grad = self.buffer[self._fill : fill_next].view_as(param.data)
self._fill = fill_next
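For reference, a minimal sketch of how the new ParamBucket above is exercised (tensor shapes are arbitrary; the assertions only restate what _add_param_as_view does, namely that param.data becomes a view of the flat buffer):

import torch

from fairscale.nn.misc import ParamBucket

p1 = torch.rand(2, 3)
p2 = torch.rand(4)

# One flat buffer large enough to back both tensors
bucket = ParamBucket(size=p1.numel() + p2.numel(), dtype=p1.dtype, device=p1.device)
bucket.add_param(p1)
bucket.add_param(p2)

# Each checked-in param now aliases a slice of bucket.buffer, so a single
# collective on bucket.buffer covers both parameters at once
assert p1.data_ptr() == bucket.buffer.data_ptr()
assert torch.allclose(bucket.buffer[: p1.numel()].view_as(p1), p1)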
......@@ -15,6 +15,8 @@ import torch.distributed as dist
from torch.nn import Parameter
from torch.optim import SGD, Optimizer
from fairscale.nn.misc import ParamBucket
from .utils import broadcast_object, calc_grad_norm, recursive_copy_to_device
__all__ = ["OSS"]
......@@ -52,6 +54,10 @@ class OSS(Optimizer):
torch.distributed group (default: group.WORLD)
broadcast_buffer_size (int):
(deprecated) used to cap the size of the broadcast buffers, not being used anymore.
broadcast_fp16 (bool):
Compress the model shards to fp16 before sharing them between ranks.
This is safe to use when PyTorch AMP is activated; without torch AMP this will lead to a slight
degradation in accuracy.
.. warning: the communication patterns that OSS uses depend on the "trainability" graph,
......@@ -73,6 +79,7 @@ class OSS(Optimizer):
optim: Type[Optimizer] = SGD,
group: Optional[Any] = None,
broadcast_buffer_size: int = -1,
broadcast_fp16: bool = False,
**default: Any,
):
......@@ -99,7 +106,8 @@ class OSS(Optimizer):
self.global_rank = self.get_global_rank(self.group, self.rank)
self._local_to_global_rank = [self.get_global_rank(self.group, i) for i in range(self.world_size)]
self.buckets: Dict[torch.device, List[torch.Tensor]] = {}
self.broadcast_fp16 = broadcast_fp16
self.buckets: Dict[torch.device, Dict[int, ParamBucket]] = {}
self._all_states: List[Dict[str, Any]] = [] # Optional consolidated optimizer state
self._default_device = torch.device("cpu")
......@@ -542,20 +550,31 @@ class OSS(Optimizer):
work_handles = [] # Work handles are consumed within this scope, no callback
# Populate the fp16 shards
if self.broadcast_fp16:
for device in self.buckets.keys():
for dst_rank, bucket in self.buckets[device].items():
bucket.to(dtype=torch.float16, device=device, non_blocking=True, keep_param_alignment=False)
if torch.cuda.is_available():
torch.cuda.synchronize()
# Exchange all the shards with the other ranks
for device in self.buckets.keys():
for src_rank, bucket in enumerate(self.buckets[device]):
if bucket.numel() > 0:
work_handles.append(
dist.broadcast(
tensor=bucket, src=self._local_to_global_rank[src_rank], group=self.group, async_op=True
)
for dst_rank, bucket in self.buckets[device].items():
work_handles.append(
dist.broadcast(
tensor=bucket.buffer, src=self._local_to_global_rank[dst_rank], group=self.group, async_op=True,
)
)
# Only check on the last handle, they're all inlined on the same CUDA stream
if work_handles and self.backend == dist.Backend.NCCL:
work_handles[-1].wait()
else:
_ = list(filter(lambda x: x.wait(), work_handles))
_ = list(filter(lambda x: x.wait(), work_handles))
# Populate back the fp32 shards
if self.broadcast_fp16:
for device in self.buckets.keys():
for dst_rank, bucket in self.buckets[device].items():
bucket.to(dtype=torch.float32, device=device, non_blocking=True, keep_param_alignment=True)
def _setup_flat_buffers(self) -> None:
"""Make all params which are on the same device and tied to the same rank views of a single buffer.
......@@ -567,7 +586,7 @@ class OSS(Optimizer):
# Only create the bucket container for this device if it does not exist yet
# (this can be called twice, when trainability changes)
if device not in self.buckets.keys():
self.buckets[device] = []
self.buckets[device] = {}
# Make parameters a view of the bucket
for dst_rank, params in enumerate(per_rank_params):
......@@ -580,23 +599,12 @@ class OSS(Optimizer):
# Merge all the trainable params in a single bucket
trainable_params = list(filter(lambda x: x.requires_grad, params))
buffer_size = sum(map(lambda x: x.numel(), trainable_params))
bucket = torch.empty(buffer_size, dtype=params[0].dtype, device=device)
offset = 0
bucket = ParamBucket(size=buffer_size, dtype=params[0].dtype, device=device)
for param in trainable_params:
offset_next = offset + param.numel()
bucket[offset:offset_next].copy_(param.data.flatten())
param.data = bucket[offset:offset_next].view_as(param.data)
offset = offset_next
# Either replace the existing bucket, or create it
if len(self.buckets[device]) == dst_rank:
self.buckets[device].append(bucket)
else:
self.buckets[device][dst_rank] = bucket
else:
# This rank has an empty shard, that's fine
self.buckets[device].append(torch.zeros(0, device=device))
bucket.add_param(param)
self.buckets[device][dst_rank] = bucket
# Clear the buffer keys which are not in use anymore (could be that the devices changed)
devices_in_use = list(self.per_device_params.keys())
......
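To summarize the OSS change above: there is now one ParamBucket per (device, owning rank) pair, and the parameter exchange issues a single async broadcast per bucket buffer, optionally round-tripping the buffers through fp16 (under NCCL only the last handle is waited on, since the broadcasts share a stream). A simplified standalone sketch of that pattern, with illustrative function and argument names rather than the OSS internals:

from typing import Any, Dict, List

import torch
import torch.distributed as dist

from fairscale.nn.misc import ParamBucket


def broadcast_param_buckets(
    buckets: Dict[torch.device, Dict[int, ParamBucket]],
    local_to_global_rank: List[int],
    group: Any,
    broadcast_fp16: bool = False,
) -> None:
    # Optionally compress the buckets to fp16 before the exchange, without
    # re-pointing the params at the half-precision buffers
    if broadcast_fp16:
        for device, per_rank in buckets.items():
            for _, bucket in per_rank.items():
                bucket.to(dtype=torch.float16, device=device, non_blocking=True, keep_param_alignment=False)

    # One async broadcast per (device, owning rank) bucket
    handles = []
    for device, per_rank in buckets.items():
        for rank, bucket in per_rank.items():
            handles.append(
                dist.broadcast(tensor=bucket.buffer, src=local_to_global_rank[rank], group=group, async_op=True)
            )

    # Wait on all the handles for simplicity; OSS only waits on the last NCCL one
    for handle in handles:
        handle.wait()

    # Move back to fp32 and re-point the params at the refreshed buffers
    if broadcast_fp16:
        for device, per_rank in buckets.items():
            for _, bucket in per_rank.items():
                bucket.to(dtype=torch.float32, device=device, non_blocking=True, keep_param_alignment=True)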
......@@ -5,6 +5,7 @@ tests/utils/test_state_dict.py
tests/nn/misc/test_checkpoint_activations.py
tests/nn/misc/test_checkpoint_activations_norm.py
tests/nn/misc/test_grad_bucket.py
tests/nn/misc/test_param_bucket.py
tests/nn/wrap/test_wrap.py
tests/nn/pipe_process/test_pipe.py
tests/nn/pipe_process/test_transparency.py
......
......@@ -55,7 +55,7 @@ def test_collapse():
bucket.shrink()
bucket.collapse()
assert bucket.buffer is None
assert bucket.buffer.numel() == 0
assert param.grad is None
bucket.rebuild()
......
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import pytest
import torch
from fairscale.nn.misc import ParamBucket
def test_param_values_conserved():
param = torch.rand((2, 3))
bucket = ParamBucket(10, param.dtype, param.device)
param_ = param.clone()
bucket.add_param(param_)
assert torch.allclose(param, param_)
def test_max_size():
param = torch.rand((20, 30))
bucket = ParamBucket(5, param.dtype, param.device)
with pytest.raises(AssertionError):
bucket.add_param(param)
def test_double_check_int():
param = torch.rand((5, 6))
bucket = ParamBucket(300, param.dtype, param.device)
bucket.add_param(param)
with pytest.raises(AssertionError):
bucket.add_param(param)
def test_type_change():
size = (5, 6)
param = torch.rand(size, requires_grad=True)
param_ = param.clone()
bucket = ParamBucket(30, param.dtype, param.device)
bucket.add_param(param)
# Move the bucket to fp16 and back
bucket.to(dtype=torch.float16, device=param.device)
bucket.to(dtype=torch.float32, device=param.device, keep_param_alignment=True)
# Same with the reference tensor
param_.to(dtype=torch.float16)
param_.to(dtype=torch.float32)
assert torch.allclose(param, param_)
......@@ -484,7 +484,7 @@ def test_collect_shards():
)
def run_test_reproducibility(rank, world_size, tempfile_name):
def run_test_reproducibility(rank, world_size, tempfile_name, broadcast_fp16):
dist_init(rank, world_size, tempfile_name)
device = torch.device(rank) if torch.cuda.device_count() > 1 else DEVICE
torch.cuda.set_device(rank)
......@@ -501,7 +501,7 @@ def run_test_reproducibility(rank, world_size, tempfile_name):
loss_fn = torch.nn.L1Loss()
loss_fn.to(device)
optimizer = optim.OSS(model.parameters(), optim=torch.optim.RMSprop, lr=0.1)
optimizer = optim.OSS(model.parameters(), optim=torch.optim.RMSprop, lr=0.1, broadcast_fp16=broadcast_fp16)
def closure():
optimizer.zero_grad()
......@@ -534,12 +534,13 @@ def run_test_reproducibility(rank, world_size, tempfile_name):
@skip_if_single_gpu
def test_reproducibility():
@pytest.mark.parametrize("broadcast_fp16", [False, True])
def test_reproducibility(broadcast_fp16: bool):
world_size = 2
temp_file_name = tempfile.mkstemp()[1]
mp.spawn(
run_test_reproducibility, args=(world_size, temp_file_name), nprocs=world_size, join=True,
run_test_reproducibility, args=(world_size, temp_file_name, broadcast_fp16), nprocs=world_size, join=True,
)
......@@ -810,7 +811,7 @@ def test_state_dict_distributed():
)
def run_ddp_parity(rank, world_size, backend, temp_file_name, change_train_graph):
def run_ddp_parity(rank, world_size, backend, temp_file_name, change_train_graph, broadcast_fp16):
url = "file://" + temp_file_name
dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)
......@@ -937,9 +938,13 @@ def run_ddp_parity(rank, world_size, backend, temp_file_name, change_train_graph
@skip_if_single_gpu
@pytest.mark.parametrize("change_train_graph", [True, False])
@pytest.mark.parametrize("backend", [dist.Backend.NCCL, dist.Backend.GLOO])
def test_ddp_parity(change_train_graph: bool, backend: dist.Backend):
@pytest.mark.parametrize("broadcast_fp16", [False, True])
def test_ddp_parity(change_train_graph: bool, backend: dist.Backend, broadcast_fp16: bool):
temp_file_name = tempfile.mkstemp()[1]
world_size = torch.cuda.device_count()
mp.spawn(
run_ddp_parity, args=(world_size, backend, temp_file_name, change_train_graph), nprocs=world_size, join=True
run_ddp_parity,
args=(world_size, backend, temp_file_name, change_train_graph, broadcast_fp16),
nprocs=world_size,
join=True,
)