"torchvision/vscode:/vscode.git/clone" did not exist on "fc69c22576cbccb59c581ddbaca4dedbdb279688"
Unverified commit e3865549 authored by Benjamin Lefaudeux, committed by GitHub

[feat][refactor][OSS] Param buckets + fp16 broadcasts (#540)

* param buckets
* unifying the buckets
parent 195d62f1
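For context, a rough usage sketch (editorial, not part of this commit): the new broadcast_fp16 flag is passed to OSS at construction time and the rest of the training setup is unchanged. The model, learning rate, and process-group initialization below are hypothetical placeholders; only the broadcast_fp16 argument comes from this diff.

# Hypothetical sketch: OSS with fp16 shard broadcasts.
# Assumes torch.distributed is already initialized (e.g. via torchrun)
# and that PyTorch AMP is used, as the docstring added below recommends.
import torch
from fairscale.optim import OSS

def build_optimizer(model: torch.nn.Module) -> OSS:
    # broadcast_fp16=True makes each rank send its parameter shard in fp16
    # after the optimizer step, halving the broadcast volume.
    return OSS(
        model.parameters(),
        optim=torch.optim.SGD,  # any torch optimizer class can be wrapped
        lr=0.1,                 # placeholder hyper-parameter
        broadcast_fp16=True,
    )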
@@ -5,4 +5,4 @@
from .checkpoint_activations import checkpoint_wrapper
from .flatten_params_wrapper import FlattenParamsWrapper
-from .grad_bucket import GradBucket
+from .param_bucket import GradBucket, ParamBucket
@@ -16,6 +16,7 @@ class GradBucket:
def __init__(self, size: int, dtype: torch.dtype, device: torch.device, destination: int) -> None:
self._max_size = size
self._params: List[torch.Tensor] = []
+self._param_ids: List[int] = []
self._fill = 0
self._is_collapsed = False
@@ -39,9 +40,9 @@ class GradBucket:
return len(self._params) == self.params_checked_in
def can_add_grad_view(self, param: torch.Tensor) -> bool:
-""" Is there enough room in the bucket to add this parameter gradient ?
+""" Is there enough room in the bucket to add this parameter gradient, and is this param not already checked in ?
"""
-return self._fill + param.numel() < self._max_size
+return self._fill + param.numel() < self._max_size and id(param) not in self._param_ids
def to(  # type: ignore
self,
@@ -70,11 +71,15 @@ class GradBucket:
"""
Add a new parameter gradient to the bucket. Param.grad becomes a view of this bucket buffer
"""
+assert id(param) not in self._param_ids, "The same gradients cannot be checked in twice"
if param.grad is None:
param.grad = torch.zeros_like(param)
self._add_grad_as_view(param)
self._params.append(param)
+self._param_ids.append(id(param))
@torch.no_grad()
def collapse(self) -> None:
...
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from typing import Any, Callable, List, Optional, Union
import torch
class Bucket:
"""
Helper class to simplify the handling of buckets, which unify the underlying storage of multiple tensors
"""
def __init__(self, size: int, dtype: torch.dtype, device: torch.device) -> None:
self._params: List[torch.Tensor] = []
self._param_ids: List[int] = []
self._fill = 0
# The actual flat tensor
self.buffer: torch.Tensor = torch.zeros(size, dtype=dtype, device=device)
def to( # type: ignore
self,
device: Optional[Union[int, torch.device]],
dtype: Optional[torch.dtype] = None,
non_blocking: bool = False,
keep_param_alignment: bool = True,
) -> "ParamBucket":
"""
Move the underlying buffer
"""
assert self.buffer is not None, "Cannot move a collapsed bucket, please rebuild it"
self.buffer.to(device, dtype, non_blocking)
class ParamBucket(Bucket):
"""
Helper class to simplify the handling of parameter buckets
"""
def __init__(self, size: int, dtype: torch.dtype, device: torch.device) -> None:
super().__init__(size, dtype, device)
def to( # type: ignore
self,
device: Optional[Union[int, torch.device]],
dtype: Optional[torch.dtype] = None,
non_blocking: bool = False,
keep_param_alignment: bool = True,
) -> "ParamBucket":
"""
Move the underlying buffer
"""
super().to(device, dtype, non_blocking)
if keep_param_alignment:
self._reattach_params()
@torch.no_grad()
def add_param(self, param: torch.Tensor) -> None:
"""
Add a new parameter to the bucket. Param.data becomes a view of this bucket buffer
"""
assert id(param) not in self._param_ids, "The same param cannot be checked in twice"
self._add_param_as_view(param)
self._params.append(param)
self._param_ids.append(id(param))
@torch.no_grad()
def _add_param_as_view(self, param: torch.Tensor, keep_existing_value: bool = True) -> None:
assert self.buffer is not None
assert param.dtype == self.buffer.dtype
assert param.device == self.buffer.device
fill_next = self._fill + param.numel()
assert fill_next <= self.buffer.numel()
# Copy the current param value
if keep_existing_value:
self.buffer[self._fill : fill_next].copy_(param.data.flatten())
param.data = self.buffer[self._fill : fill_next].view_as(param.data)
self._fill = fill_next
@torch.no_grad()
def _reattach_params(self) -> None:
"""
Given the parameters which have been registered previously, rebuild the whole bucket
"""
assert len(self._params) > 0
self._fill = 0
for p in self._params:
self._add_param_as_view(p, keep_existing_value=False)
class GradBucket(Bucket):
"""
Helper class to simplify the handling of gradient buckets
"""
def __init__(self, size: int, dtype: torch.dtype, device: torch.device, destination: int) -> None:
super().__init__(size, dtype, device)
self._max_size = size
self._is_collapsed = False
self.params_checked_in = 0
self.destination = destination
self.sent = True
self.callback: Optional[Callable[[Any], None]] = None
def reset_checked_in(self) -> None:
""" Reset the counter of the parameter grads which have been checked in
"""
self.params_checked_in = 0
self.sent = False
@property
def all_checked_in(self) -> bool:
""" Have all the expected gradient check-in happened ?"""
return len(self._params) == self.params_checked_in
def can_add_grad_view(self, param: torch.Tensor) -> bool:
""" Is there enough room in the bucket to add this parameter gradient, and is this param not already checked in ?
"""
return self._fill + param.numel() < self._max_size and id(param) not in self._param_ids
def to( # type: ignore
self,
device: Optional[Union[int, torch.device]],
dtype: Optional[torch.dtype] = None,
non_blocking: bool = False,
keep_param_alignment: bool = True,
) -> "GradBucket":
"""
Move the underlying buffer
"""
if self._is_collapsed:
self.rebuild()
super().to(device, dtype, non_blocking)
if keep_param_alignment:
self._reattach_grads()
def zero(self) -> None:
"""
Set all the grads to zero
"""
self.buffer.fill_(0.0)
@torch.no_grad()
def add_grad(self, param: torch.Tensor) -> None:
"""
Add a new parameter gradient to the bucket. Param.grad becomes a view of this bucket buffer
"""
assert id(param) not in self._param_ids, "The same gradients cannot be checked in twice"
if param.grad is None:
param.grad = torch.zeros_like(param)
self._add_grad_as_view(param)
self._params.append(param)
self._param_ids.append(id(param))
@torch.no_grad()
def collapse(self) -> None:
"""
Release the buffer from memory. The bucket will need to be rebuilt before use
"""
if not self._is_collapsed:
for p in self._params:
assert p.grad is not None
p.grad.detach_()
p.grad = None
self.buffer = torch.zeros(0, dtype=self.buffer.dtype, device=self.buffer.device)
self._fill = 0
self.params_checked_in = 0
self._is_collapsed = True
@torch.no_grad()
def rebuild(self) -> None:
"""
Given the parameter gradients which have been registered previously, rebuild the whole bucket
"""
assert len(self._params) > 0
if self._is_collapsed:
self.buffer = torch.zeros(self._max_size, dtype=self._params[0].dtype, device=self._params[0].device)
for p in self._params:
self._add_grad_as_view(p)
self._is_collapsed = False
@torch.no_grad()
def shrink(self) -> None:
"""
Shrink the buffer to the size of the parameter gradients currently checked in, release the extra memory
"""
assert self.buffer.numel() > 0, "Cannot shrink a collapsed bucket, please rebuild"
self.buffer = self.buffer.resize_(self._fill).clone()
self._fill = 0
for p in self._params:
self._add_grad_as_view(p)
self._max_size = self._fill
@torch.no_grad()
def _reattach_grads(self) -> None:
"""
Given the parameter gradients which have been registered previously, rebuild the whole bucket
"""
assert len(self._params) > 0
self._fill = 0
for p in self._params:
self._add_grad_as_view(p, keep_existing_value=False)
@torch.no_grad()
def _add_grad_as_view(self, param: torch.Tensor, keep_existing_value: bool = True) -> None:
assert self.buffer.numel() > 0, "Cannot add a gradient to a collapsed bucket, please rebuild"
assert param.dtype == self.buffer.dtype
assert param.device == self.buffer.device
fill_next = self._fill + param.numel()
assert fill_next <= self.buffer.numel()
# Copy the current grad value, if any
if param.grad is not None:
# keep param.grad in place
if keep_existing_value:
self.buffer[self._fill : fill_next].copy_(param.grad.data.flatten())
param.grad.data = self.buffer[self._fill : fill_next].view_as(param.data)
else:
param.grad = self.buffer[self._fill : fill_next].view_as(param.data)
self._fill = fill_next
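As a quick illustration of the new class (an editorial sketch, not part of the commit): after add_param, each registered tensor is backed by a slice of the bucket's flat buffer, which is what lets OSS broadcast one contiguous tensor per rank. The tensors and sizes below are made up for the example.

# Hypothetical sketch of ParamBucket semantics: params become views of one flat buffer.
import torch
from fairscale.nn.misc import ParamBucket

a = torch.rand(2, 3)
b = torch.rand(4)
bucket = ParamBucket(size=a.numel() + b.numel(), dtype=a.dtype, device=a.device)
bucket.add_param(a)
bucket.add_param(b)

# Both tensors now alias slices of the same flat buffer...
assert a.data_ptr() == bucket.buffer.data_ptr()
assert torch.equal(bucket.buffer[: a.numel()], a.flatten())
# ...so writing through the buffer is visible through the original tensors.
bucket.buffer.zero_()
assert bool((a == 0).all()) and bool((b == 0).all())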
@@ -15,6 +15,8 @@ import torch.distributed as dist
from torch.nn import Parameter
from torch.optim import SGD, Optimizer
+from fairscale.nn.misc import ParamBucket
from .utils import broadcast_object, calc_grad_norm, recursive_copy_to_device
__all__ = ["OSS"]
@@ -52,6 +54,10 @@ class OSS(Optimizer):
torch.distributed group (default: group.WORLD)
broadcast_buffer_size (int):
(deprecated) used to cap the size of the broadcast buffers, not being used anymore.
+broadcast_fp16 (bool):
+Compress the model shards in fp16 before sharing them between ranks.
+This is safe to use when PyTorch AMP is activated; without AMP it will lead to a slight
+degradation in accuracy.
.. warning: the communication patterns that OSS use depend on the "trainability" graph,
@@ -73,6 +79,7 @@ class OSS(Optimizer):
optim: Type[Optimizer] = SGD,
group: Optional[Any] = None,
broadcast_buffer_size: int = -1,
+broadcast_fp16: bool = False,
**default: Any,
):
@@ -99,7 +106,8 @@ class OSS(Optimizer):
self.global_rank = self.get_global_rank(self.group, self.rank)
self._local_to_global_rank = [self.get_global_rank(self.group, i) for i in range(self.world_size)]
-self.buckets: Dict[torch.device, List[torch.Tensor]] = {}
+self.broadcast_fp16 = broadcast_fp16
+self.buckets: Dict[torch.device, Dict[int, ParamBucket]] = {}
self._all_states: List[Dict[str, Any]] = []  # Optional consolidated optimizer state
self._default_device = torch.device("cpu")
@@ -542,21 +550,32 @@ class OSS(Optimizer):
work_handles = []  # Work handles are consumed within this scope, no callback
+# Populate the fp16 shards
+if self.broadcast_fp16:
for device in self.buckets.keys():
-for src_rank, bucket in enumerate(self.buckets[device]):
-if bucket.numel() > 0:
+for dst_rank, bucket in self.buckets[device].items():
+bucket.to(dtype=torch.float16, device=device, non_blocking=True, keep_param_alignment=False)
+if torch.cuda.is_available():
+torch.cuda.synchronize()
+# Exchange all the shards with the other ranks
+for device in self.buckets.keys():
+for dst_rank, bucket in self.buckets[device].items():
work_handles.append(
dist.broadcast(
-tensor=bucket, src=self._local_to_global_rank[src_rank], group=self.group, async_op=True
+tensor=bucket.buffer, src=self._local_to_global_rank[dst_rank], group=self.group, async_op=True,
)
)
+# Only check on the last handle, they're all inlined on the same CUDA stream
+if work_handles and self.backend == dist.Backend.NCCL:
+work_handles[-1].wait()
+else:
_ = list(filter(lambda x: x.wait(), work_handles))
+# Populate back the fp32 shards
+if self.broadcast_fp16:
+for device in self.buckets.keys():
+for dst_rank in self.buckets[device].keys():
+bucket.to(dtype=torch.float32, device=device, non_blocking=True, keep_param_alignment=True)
def _setup_flat_buffers(self) -> None:
"""Make all params which are on the same device and tied to the same rank views of a single buffer.
This is used at construction time, and anytime parameter trainability is changed (frozen or unfrozen) and
@@ -567,7 +586,7 @@ class OSS(Optimizer):
# Only wipe the existing buckets if there are none
# (could be that this is called twice, when trainability changes)
if device not in self.buckets.keys():
-self.buckets[device] = []
+self.buckets[device] = {}
# Make parameters a view of the bucket
for dst_rank, params in enumerate(per_rank_params):
@@ -580,23 +599,12 @@ class OSS(Optimizer):
# Merge all the trainable params in a single bucket
trainable_params = list(filter(lambda x: x.requires_grad, params))
buffer_size = sum(map(lambda x: x.numel(), trainable_params))
-bucket = torch.empty(buffer_size, dtype=params[0].dtype, device=device)
-offset = 0
+bucket = ParamBucket(size=buffer_size, dtype=params[0].dtype, device=device)
for param in trainable_params:
-offset_next = offset + param.numel()
-bucket[offset:offset_next].copy_(param.data.flatten())
-param.data = bucket[offset:offset_next].view_as(param.data)
-offset = offset_next
+bucket.add_param(param)
-# Either replace the existing bucket, or create it
-if len(self.buckets[device]) == dst_rank:
-self.buckets[device].append(bucket)
-else:
self.buckets[device][dst_rank] = bucket
-else:
-# This rank has an empty shard, that's fine
-self.buckets[device].append(torch.zeros(0, device=device))
# Clear the buffer keys which are not in use anymore (could be that the devices changed)
devices_in_use = list(self.per_device_params.keys())
...
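To make the broadcast path above easier to follow, here is the general idea behind broadcast_fp16 written as a standalone sketch (editorial, not the exact OSS code path): the flat shard is cast to fp16 for the collective and copied back into fp32 storage afterwards, which is why the payload is halved and why a small precision loss is possible without AMP.

# Hypothetical standalone sketch; assumes an initialized process group.
# `shard` stands in for one ParamBucket buffer owned by `src_rank`.
import torch
import torch.distributed as dist

def broadcast_shard_fp16(shard: torch.Tensor, src_rank: int, group=None) -> None:
    # Halve the payload: run the collective on an fp16 copy of the shard.
    compressed = shard.to(dtype=torch.float16)
    dist.broadcast(tensor=compressed, src=src_rank, group=group)
    # Write the received values back into the fp32 shard in place.
    shard.copy_(compressed.to(dtype=shard.dtype))

In the diff above, OSS does the equivalent through ParamBucket.to(dtype=torch.float16, keep_param_alignment=False) before queueing the broadcasts and a cast back to fp32 afterwards, and with NCCL it only waits on the last work handle since all the broadcasts are enqueued on the same CUDA stream.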
@@ -5,6 +5,7 @@ tests/utils/test_state_dict.py
tests/nn/misc/test_checkpoint_activations.py
tests/nn/misc/test_checkpoint_activations_norm.py
tests/nn/misc/test_grad_bucket.py
+tests/nn/misc/test_param_bucket.py
tests/nn/wrap/test_wrap.py
tests/nn/pipe_process/test_pipe.py
tests/nn/pipe_process/test_transparency.py
...
@@ -55,7 +55,7 @@ def test_collapse():
bucket.shrink()
bucket.collapse()
-assert bucket.buffer is None
+assert bucket.buffer.numel() == 0
assert param.grad is None
bucket.rebuild()
...
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import pytest
import torch
from fairscale.nn.misc import ParamBucket
def test_param_values_conserved():
param = torch.rand((2, 3))
bucket = ParamBucket(10, param.dtype, param.device)
param_ = param.clone()
bucket.add_param(param_)
torch.allclose(param, param_)
def test_max_size():
param = torch.rand((20, 30))
bucket = ParamBucket(5, param.dtype, param.device)
with pytest.raises(AssertionError):
bucket.add_param(param)
def test_double_check_int():
param = torch.rand((5, 6))
bucket = ParamBucket(300, param.dtype, param.device)
bucket.add_param(param)
with pytest.raises(AssertionError):
bucket.add_param(param)
def test_type_change():
size = (5, 6)
param = torch.rand(size, requires_grad=True)
param_ = param.clone()
bucket = ParamBucket(30, param.dtype, param.device)
bucket.add_param(param)
# Move the bucket to fp16 and back
bucket.to(dtype=torch.float16, device=param.device)
bucket.to(dtype=torch.float32, device=param.device, keep_param_alignment=True)
# Same with the reference tensor
param_.to(dtype=torch.float16)
param_.to(dtype=torch.float32)
torch.allclose(param, param_)
@@ -484,7 +484,7 @@ def test_collect_shards():
)
-def run_test_reproducibility(rank, world_size, tempfile_name):
+def run_test_reproducibility(rank, world_size, tempfile_name, broadcast_fp16):
dist_init(rank, world_size, tempfile_name)
device = torch.device(rank) if torch.cuda.device_count() > 1 else DEVICE
torch.cuda.set_device(rank)
@@ -501,7 +501,7 @@ def run_test_reproducibility(rank, world_size, tempfile_name):
loss_fn = torch.nn.L1Loss()
loss_fn.to(device)
-optimizer = optim.OSS(model.parameters(), optim=torch.optim.RMSprop, lr=0.1)
+optimizer = optim.OSS(model.parameters(), optim=torch.optim.RMSprop, lr=0.1, broadcast_fp16=broadcast_fp16)
def closure():
optimizer.zero_grad()
@@ -534,12 +534,13 @@ def run_test_reproducibility(rank, world_size, tempfile_name):
@skip_if_single_gpu
-def test_reproducibility():
+@pytest.mark.parametrize("broadcast_fp16", [False, True])
+def test_reproducibility(broadcast_fp16: bool):
world_size = 2
temp_file_name = tempfile.mkstemp()[1]
mp.spawn(
-run_test_reproducibility, args=(world_size, temp_file_name), nprocs=world_size, join=True,
+run_test_reproducibility, args=(world_size, temp_file_name, broadcast_fp16), nprocs=world_size, join=True,
)
@@ -810,7 +811,7 @@ def test_state_dict_distributed():
)
-def run_ddp_parity(rank, world_size, backend, temp_file_name, change_train_graph):
+def run_ddp_parity(rank, world_size, backend, temp_file_name, change_train_graph, broadcast_fp16):
url = "file://" + temp_file_name
dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)
@@ -937,9 +938,13 @@ def run_ddp_parity(rank, world_size, backend, temp_file_name, change_train_graph
@skip_if_single_gpu
@pytest.mark.parametrize("change_train_graph", [True, False])
@pytest.mark.parametrize("backend", [dist.Backend.NCCL, dist.Backend.GLOO])
-def test_ddp_parity(change_train_graph: bool, backend: dist.Backend):
+@pytest.mark.parametrize("broadcast_fp16", [False, True])
+def test_ddp_parity(change_train_graph: bool, backend: dist.Backend, broadcast_fp16: bool):
temp_file_name = tempfile.mkstemp()[1]
world_size = torch.cuda.device_count()
mp.spawn(
-run_ddp_parity, args=(world_size, backend, temp_file_name, change_train_graph), nprocs=world_size, join=True
+run_ddp_parity,
+args=(world_size, backend, temp_file_name, change_train_graph, broadcast_fp16),
+nprocs=world_size,
+join=True,
)