Unverified commit 77d94861, authored by Benjamin Lefaudeux, committed by GitHub

[refactor] OSS only use flat buffers (#371)

* flat params all along, way simpler
* updating the docstring
parent 8778fa66
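The core of the refactor: every parameter owned by a given rank on a given device becomes a view of one flat buffer, so optimizer state can be exchanged with a single collective per rank instead of a mix of bucketed and direct broadcasts. A minimal sketch of that idea in plain PyTorch (the helper name `flatten_params` is made up for illustration, this is not the fairscale code):

import torch

def flatten_params(params):
    # params: tensors owned by one rank, on one device, sharing a dtype
    buffer = torch.empty(sum(p.numel() for p in params), dtype=params[0].dtype, device=params[0].device)
    offset = 0
    for p in params:
        end = offset + p.numel()
        buffer[offset:end].copy_(p.data.flatten())
        p.data = buffer[offset:end].view_as(p.data)  # the param now aliases the buffer
        offset = end
    return buffer

params = [torch.randn(3, 4), torch.randn(5)]
flat = flatten_params(params)
assert flat.numel() == 17 and params[0].data_ptr() == flat.data_ptr()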
@@ -51,9 +51,7 @@ class OSS(Optimizer):
         group (group):
             torch.distributed group (default: group.WORLD)
         broadcast_buffer_size (int):
-            the max size of the buffer used to batch the small parameter tensors, in number of elements (default 16M).
-            this will not impact the long term memory consumption, but the peak memory can be impacted by the moment
-            when the buffers are allocated and the bucketed params have not yet been relocated to them.
+            (deprecated) used to cap the size of the broadcast buffers, not being used anymore.
     """

     #: The optimizer used for a given shard
@@ -66,7 +64,7 @@ class OSS(Optimizer):
         params: _params_t,
         optim: Type[Optimizer] = SGD,
         group: Optional[Any] = None,
-        broadcast_buffer_size: int = 2 ** 24,
+        broadcast_buffer_size: int = -1,
         **default: Any,
     ):
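For callers, nothing changes apart from `broadcast_buffer_size` becoming inert: the flat buffers are now sized from the trainable parameters themselves. A hedged usage sketch; the single-process "gloo" group and the toy Linear model are only there so the snippet can be constructed standalone:

import torch
import torch.distributed as dist
from torch.optim import SGD
from fairscale.optim import OSS

# single-process group, just so OSS can be built outside a real launcher
dist.init_process_group(backend="gloo", init_method="tcp://localhost:29500", rank=0, world_size=1)

model = torch.nn.Linear(4, 2)
# Wrap the base optimizer; extra kwargs (lr, momentum) are forwarded to SGD per shard.
optimizer = OSS(params=model.parameters(), optim=SGD, lr=0.1, momentum=0.9)
# broadcast_buffer_size is still accepted for backward compatibility, but after
# this change it no longer affects memory use (its default is now -1).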
@@ -101,12 +99,9 @@ class OSS(Optimizer):
         # Current default device is set by the parameters allocated to this rank
         self._device = list(self.per_device_params.keys())[0]

-        self.buckets: Dict[torch.device, List[torch.Tensor]] = {}
-        self.buffer_max_size = broadcast_buffer_size
-        self.should_bucket_param: List[bool] = []
         self.work_handles: Deque[Workhandle] = deque()
-        self._setup_bucket_strategy()
+        self.buckets: Dict[torch.device, List[torch.Tensor]] = {}
+        self._setup_flat_buffers()

     # Partition helpers
     def partition_parameters(self) -> List[List[dict]]:
@@ -509,7 +504,7 @@ class OSS(Optimizer):
             self.optim.add_param_group(param_groups[-1])

             # Update the bucketing strategy accordingly
-            self._setup_bucket_strategy()
+            self._setup_flat_buffers()

     def _clear_cache(self) -> None:
         self._partition_parameters.clear()
@@ -540,25 +535,11 @@ class OSS(Optimizer):
     def _broadcast_params(self) -> None:
         """Helper function to broadcast all the parameters from a given device"""
-        i_param = 0
         last_work_handle = None  # Work handles are consumed within this scope, no callback

-        for (device, device_params,) in self.per_device_params.items():  # all the params on this device (inc all ranks)
-            buckets = self.buckets[device]
-
-            # Bucket and issue all the async calls
-            for (src_rank, params), bucket in zip(enumerate(device_params), buckets):
+        for device in self.buckets.keys():
+            for src_rank, bucket in enumerate(self.buckets[device]):
                 global_src_rank = self.get_global_rank(self.group, src_rank)
-
-                # Direct broadcasts only
-                for param in params:
-                    if not self.should_bucket_param[i_param]:
-                        last_work_handle = dist.broadcast(
-                            tensor=param.data, src=global_src_rank, group=self.group, async_op=True
-                        )
-                    i_param += 1
-
-                # Bucket broadcasts
                 last_work_handle = dist.broadcast(tensor=bucket, src=global_src_rank, group=self.group, async_op=True)

         # Only check on the last handle, they're all inlined on the same CUDA stream
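With the parameters being views of the per-rank buckets, the broadcast path collapses to one async collective per (device, rank) pair, and a single wait covers them all because the calls are queued on the same stream. A rough standalone sketch of that pattern (the helper name is made up, and it assumes the default process group so local ranks equal global ranks; it is not the fairscale code):

import torch.distributed as dist

def broadcast_flat_buckets(buckets):
    # buckets: dict mapping a device to a list of flat tensors, indexed by owning rank
    last_handle = None
    for per_rank_buckets in buckets.values():
        for src_rank, bucket in enumerate(per_rank_buckets):
            last_handle = dist.broadcast(tensor=bucket, src=src_rank, async_op=True)
    if last_handle is not None:
        # the broadcasts are issued on the same stream, so waiting on the last
        # one is enough to know that all of them have completed
        last_handle.wait()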
@@ -569,7 +550,6 @@ class OSS(Optimizer):
         """Consume all the futures which are tied to this optimizer's buckets.
         We start from the first/older ones, since they are the most likely to be ready and non-blocking
         """
-
         while len(self.work_handles) > 0:
             work_handle = self.work_handles.popleft()
             work_handle.handle.wait()
@@ -583,51 +563,26 @@ class OSS(Optimizer):
             if work_handle.callback is not None:
                 work_handle.callback()

-    def _setup_bucket_strategy(self) -> None:
-        """Tag parameters to either bucket them or broadcast/reduce them directly. The parameters are ordered
-        (smallest first), the bucket will hold the smallest elements, the remaining ones will be directly sent
-        over the wire.
-        Generating the partition once and for all allows us to save some time at runtime, and to know when all the
-        network requests have been issued.
+    def _setup_flat_buffers(self) -> None:
+        """Make all params which are on the same device and tied to the same rank views of a single buffer.
+        This is used at construction time, and anytime parameter trainability is changed (frozen or unfrozen) and
+        `refresh_trainability` is called.
         """
-
-        # (re) allocate the buckets
-        # - Get the correct size for the buckets, cannot be bigger than the model
-        model_size = sum([p.numel() for p in self.param_to_rank.keys()])
-        self.bucket_size = min(self.buffer_max_size, model_size)
-        logging.info(
-            "Bucket size: {:.2f}M parameters, model size {:.2f}M parameters".format(
-                self.bucket_size / 2 ** 20, model_size / 2 ** 20
-            )
-        )
-
-        # - Allocate one buffer per rank and per device to group the small parameters
-        for device, per_device in self.per_device_params.items():
-            self.buckets[device] = [
-                torch.zeros(self.bucket_size, dtype=per_device[0][0].dtype, device=device)
-                for _ in range(len(per_device))
-            ]
-
-        # Devise the bucketing strategy
         for device, per_rank_params in self.per_device_params.items():
+            self.buckets[device] = []
             for dst_rank, params in enumerate(per_rank_params):
-                offset = 0
-                for param in params:
-                    # Criteria to decide whether this parameter is to be bucketed or not:
-                    # - enough room in the bucket
-                    if param.requires_grad and (offset + param.numel()) < self.bucket_size:
-                        self.should_bucket_param.append(True)
+                if len(params) > 0:
+                    trainable_params = list(filter(lambda x: x.requires_grad, params))
+                    buffer_size = sum(map(lambda x: x.numel(), trainable_params))
+                    self.buckets[device].append(torch.empty(buffer_size, dtype=params[0].dtype, device=device))
+                    offset = 0
+                    for param in trainable_params:
                         # This parameter becomes a view of the bucket
                         offset_next = offset + param.numel()
                         self.buckets[device][dst_rank][offset:offset_next].copy_(param.data.flatten())
                         param.data = self.buckets[device][dst_rank][offset:offset_next].view_as(param.data)
                         offset = offset_next
-                    else:
-                        self.should_bucket_param.append(False)
-
-                # Resize the bucket to remove lost space in the end
-                self.buckets[device][dst_rank].resize_(offset)
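The reason a bare broadcast of each bucket suffices is that the parameters alias the bucket's storage: whatever lands in the buffer is immediately visible through the parameter tensors, with no unpacking step after the collective. A tiny check of that property with toy shapes in plain PyTorch:

import torch

bucket = torch.zeros(6)
w = bucket[0:4].view(2, 2)  # "parameters" carved out of the flat bucket
b = bucket[4:6].view(2)

bucket.copy_(torch.arange(6, dtype=torch.float32))  # simulate a received broadcast

print(w)  # tensor([[0., 1.], [2., 3.]]) -- updated through the view
print(b)  # tensor([4., 5.])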