"git@developer.sourcefind.cn:change/sglang.git" did not exist on "e81d7f11dede2b9b3f82de00a433eccc3d47c25e"
Unverified Commit 6219b57b authored by Benjamin Lefaudeux, committed by GitHub

[perf][OSS] tensor views for bucketing (#300)

* min bucket size with model size
* resize the bucket after all the params have been squeezed in, save a tiny bit of memory
* minor, ensure that the cache is freed and improve the comments
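
For context, the core idea of the change in a minimal standalone sketch (illustrative names, not the fairscale API; assumes all parameters share one dtype and device): small parameters are copied once into a pre-allocated flat bucket and rebound as views of it, so a single collective on the bucket refreshes all of them in place and no unroll callback is needed afterwards.

import torch

def bucket_params(params, bucket_size):
    # Pack as many parameters as fit into one flat bucket and rebind them as views of it
    bucket = torch.zeros(bucket_size, dtype=params[0].dtype, device=params[0].device)
    offset = 0
    bucketed, direct = [], []
    for p in params:
        end = offset + p.numel()
        if end <= bucket_size:
            bucket[offset:end].copy_(p.data.flatten())   # squeeze the param into the bucket
            p.data = bucket[offset:end].view_as(p.data)  # the param is now a view of the bucket
            bucketed.append(p)
            offset = end
        else:
            direct.append(p)                             # too big, will be sent directly
    bucket.resize_(offset)                               # drop the unused tail, the views stay valid
    return bucket, bucketed, direct

Broadcasting (or reducing) the bucket once then updates every bucketed parameter, since they alias its storage.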
parent b202804a
@@ -16,7 +16,7 @@ import torch.distributed as dist
 from torch.nn import Parameter
 from torch.optim import SGD, Optimizer
 
-from .utils import Bucket, Workhandle, recursive_copy_to_device
+from .utils import Workhandle, recursive_copy_to_device
 
 __all__ = ["OSS"]

@@ -52,7 +52,9 @@ class OSS(Optimizer):
         group (group):
             torch.distributed group (default: group.WORLD)
         broadcast_buffer_size (int):
-            the size of the buffer used to batch the small parameter tensors (default 128k).
+            the max size of the buffer used to batch the small parameter tensors, in number of elements (default 16M).
+            this will not impact the long term memory consumption, but the peak memory can be impacted by the moment
+            when the buffers are allocated and the bucketed params have not yet been relocated to them.
     """

@@ -65,7 +67,7 @@ class OSS(Optimizer):
         params: _params_t,
         optim: Type[Optimizer] = SGD,
         group: Optional[Any] = None,
-        broadcast_buffer_size: int = 2 ** 17,
+        broadcast_buffer_size: int = 2 ** 24,
         **default: Any,
     ):

@@ -97,17 +99,21 @@ class OSS(Optimizer):
         # Current default device is set by the parameters allocated to this rank
         self._device = list(self.per_device_params.keys())[0]
 
-        self.buckets: Dict[torch.device, List[Bucket]] = {}
+        self.buckets: Dict[torch.device, List[torch.Tensor]] = {}
 
-        if torch.cuda.is_available() and self.world_size <= torch.cuda.device_count():
-            broadcast_buffer_size = 0
-            logging.warning("Assuming single node job, bucketing is disabled")
+        # Get the correct size for the buckets, cannot be bigger than the model
+        model_size = sum([p.numel() for p in self.param_to_rank.keys()])
+        self.bucket_size = min(broadcast_buffer_size, model_size)
+        logging.info(
+            "Bucket size: {:.2f}M parameters, model size {:.2f}M parameters".format(
+                self.bucket_size / 2 ** 20, model_size / 2 ** 20
+            )
+        )
 
-        self.bucket_size = broadcast_buffer_size
-
+        # Allocate one buffer per rank and per device to group the small parameters
         for device, per_device in self.per_device_params.items():
-            # Allocate one buffer per rank and per device to group the small parameters
             self.buckets[device] = [
-                Bucket(buffer=torch.zeros(broadcast_buffer_size, dtype=per_device[0][0].dtype, device=device))
+                torch.zeros(self.bucket_size, dtype=per_device[0][0].dtype, device=device)
                 for _ in range(len(per_device))
             ]
 
         self.should_bucket_param: List[bool] = []

@@ -517,56 +523,39 @@ class OSS(Optimizer):
     def _broadcast_params(self) -> None:
         """Helper function to broadcast all the parameters from a given device"""
-        with torch.no_grad():
-            i_param = 0
-            for (
-                device,
-                device_params,
-            ) in self.per_device_params.items():  # all the params on this device (inc all ranks)
-                buckets = self.buckets[device]
+        i_param = 0
 
-                # Bucket and issue all the async calls
-                for (src_rank, params), bucket in zip(enumerate(device_params), buckets):
-                    global_src_rank = self.get_global_rank(self.group, src_rank)
+        for (device, device_params,) in self.per_device_params.items():  # all the params on this device (inc all ranks)
+            buckets = self.buckets[device]
 
-                    for param in params:
-                        # Bucket broadcast
-                        if self.bucket_size > 0 and self.should_bucket_param[i_param]:
-                            assert bucket.append(param), "Bucket overflow: max %s - current %s - adding %s" % (
-                                bucket.max_size,
-                                bucket.current_offset,
-                                param.numel(),
-                            )
-                            if bucket.full():
-                                self.work_handles.append(
-                                    Workhandle(
-                                        handle=dist.broadcast(
-                                            tensor=bucket.buffer, src=global_src_rank, group=self.group, async_op=True
-                                        ),
-                                        callback=bucket.unroll,
-                                    )
-                                )
-                        # Direct
-                        else:
-                            self.work_handles.append(
-                                Workhandle(
-                                    handle=dist.broadcast(
-                                        tensor=param.data, src=global_src_rank, group=self.group, async_op=True
-                                    ),
-                                    callback=None,
-                                )
-                            )
-                        i_param += 1
+            # Bucket and issue all the async calls
+            for (src_rank, params), bucket in zip(enumerate(device_params), buckets):
+                global_src_rank = self.get_global_rank(self.group, src_rank)
+
+                # Direct broadcasts only
+                for param in params:
+                    if not self.should_bucket_param[i_param]:
+                        self.work_handles.append(
+                            Workhandle(
+                                handle=dist.broadcast(
+                                    tensor=param.data, src=global_src_rank, group=self.group, async_op=True
+                                ),
+                                callback=None,
+                            )
+                        )
+                    i_param += 1
+
+                # Bucket broadcasts
+                self.work_handles.append(
+                    Workhandle(
+                        handle=dist.broadcast(tensor=bucket, src=global_src_rank, group=self.group, async_op=True),
+                        callback=None,
+                    )
+                )
 
         self._consume_work_handles()
 
     def _consume_work_handles(self) -> None:
-        """ Consume all the futures which are tied to this optimizer's buckets.
+        """Consume all the futures which are tied to this optimizer's buckets.
         We start from the first/older ones, since they are the most likely to be ready and non-blocking
         """

@@ -577,15 +566,14 @@ class OSS(Optimizer):
             work_handle.callback()
 
     def _try_consume_work_handle(self) -> None:
-        """ Try to consume the oldest future. This is non blocking, if not ready we'll pass
-        """
+        """Try to consume the oldest future. This is non blocking, if not ready we'll pass"""
         while len(self.work_handles) > 0 and self.work_handles[0].handle.is_completed():
            work_handle = self.work_handles.popleft()
            if work_handle.callback is not None:
                work_handle.callback()
 
     def _setup_bucket_strategy(self) -> None:
-        """ Tag parameters to either bucket them or broadcast/reduce them directly. The parameters are ordered
+        """Tag parameters to either bucket them or broadcast/reduce them directly. The parameters are ordered
         (smallest first), the bucket will hold the smallest elements, the remaining ones will be directly sent
         over the wire.

@@ -604,21 +592,29 @@ class OSS(Optimizer):
                 for param in params:
                     # Criteria to decide whether this parameter is to be bucketed or not:
                     # - enough room in the bucket
-                    if param.requires_grad and (offset + param.numel()) < self.buckets[device][dst_rank].max_size:
+                    if param.requires_grad and (offset + param.numel()) < self.bucket_size:
                         self.should_bucket_param.append(True)
                         if offset == 0:
                             # count this bucket, only once
                             self._max_work_handles += 1
-                        offset += param.numel()
+
+                        # This parameter becomes a view of the bucket
+                        offset_next = offset + param.numel()
+                        self.buckets[device][dst_rank][offset:offset_next].copy_(param.data.flatten())
+                        param.data = self.buckets[device][dst_rank][offset:offset_next].view_as(param.data)
+                        offset = offset_next
                     else:
                         self.should_bucket_param.append(False)
 
-                # Register the max offset for this buffer, and the reference rank
-                self.buckets[device][dst_rank].max_offset = offset
-                self.buckets[device][dst_rank].global_ref_rank = self.get_global_rank(self.group, dst_rank)
-                self.buckets[device][dst_rank].global_rank = self.global_rank
+                # Resize the bucket to remove lost space in the end
+                self.buckets[device][dst_rank].resize_(offset)
+
+        # Make sure that the memory previously taken by the bucketed parameters is released
+        if self._device.type == "cuda":
+            torch.cuda.empty_cache()
 
         # Determine the max work handles in flight:
         # - all the direct reduce/broadcast
...
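
For reference, a minimal sketch of the resulting per-rank broadcast flow (hypothetical helper, not part of the diff): parameters that were not bucketed are broadcast individually, then the whole bucket is broadcast once; the bucketed parameters need no further copy because they are views of the bucket.

import torch.distributed as dist

def broadcast_rank_params(bucket, direct_params, global_src_rank, group):
    # Direct broadcasts for the large, non-bucketed parameters
    handles = [
        dist.broadcast(tensor=p.data, src=global_src_rank, group=group, async_op=True)
        for p in direct_params
    ]
    # One broadcast covers every bucketed parameter, since they alias the bucket's storage
    handles.append(dist.broadcast(tensor=bucket, src=global_src_rank, group=group, async_op=True))
    for handle in handles:
        handle.wait()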
@@ -3,7 +3,7 @@
 # This source code is licensed under the BSD license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable, Dict, Optional
 
 import torch
 from torch._six import container_abcs

@@ -15,88 +15,6 @@ class Workhandle:
         self.callback = callback
 
-
-class FlatParam:
-    def __init__(self, tensor: torch.Tensor, start: int, stop: int) -> None:
-        self.param = tensor
-        self.start = start
-        self.stop = stop
-
-
-class Bucket:
-    """
-    Helper class to simplify the handling of broadcast or reduce buckets
-    """
-
-    def __init__(self, buffer: torch.Tensor) -> None:
-        # The actual flat tensor
-        self.buffer = buffer
-        self.max_size = buffer.numel()
-
-        # Handles to the params and their position in this tensor, can be useful for a callback
-        self.params: List[FlatParam] = []
-
-        # Optional callback, possibly to unwrap the bucket
-        self.callback: Optional[Callable] = None
-
-        # Current status for this buffer
-        self.current_offset = 0
-        self.max_offset = 0
-        self.global_ref_rank = -1  # Either the destination or the src rank, if reducing or broadcasting for instance
-        self.global_rank = -1
-        self.gradients_based = False
-
-    def unroll(self) -> None:
-        """
-        Dsitribute the contents of the flat buffer back to the attached parameters
-        """
-        for flat in self.params:
-            if self.global_ref_rank != self.global_rank and self.gradients_based:
-                # this rank is not the owner, release the grad
-                flat.param.grad = None
-            else:
-                if self.gradients_based:
-                    # this rank is the owner, unroll the results
-                    assert flat.param.grad is not None
-
-                    flat.param.grad.data.copy_(
-                        self.buffer[flat.start : flat.stop].view_as(flat.param.data), non_blocking=True
-                    )
-                else:
-                    flat.param.data.copy_(
-                        self.buffer[flat.start : flat.stop].view_as(flat.param.data), non_blocking=True
-                    )
-        self.reset()
-
-    def reset(self) -> None:
-        """ empty the bucket """
-        self.current_offset = 0
-        self.params.clear()
-
-    def append(self, tensor: torch.Tensor, use_gradient: bool = False) -> bool:
-        """ add a tensor to the bucket """
-        end = self.current_offset + tensor.numel()
-        self.gradients_based = use_gradient
-
-        if end > self.max_size:
-            return False
-
-        if use_gradient:
-            assert tensor.grad is not None
-
-        data_source = tensor.grad.data if use_gradient else tensor.data  # type: ignore  # mypy is drunk
-        self.buffer[self.current_offset : end].copy_(data_source.view(-1))
-        self.params.append(FlatParam(tensor=tensor, start=self.current_offset, stop=end))
-        self.current_offset = end
-
-        return True
-
-    def full(self) -> bool:
-        """ is the bucket full ? """
-        return self.current_offset == self.max_offset
-
-
 # Credits: classy_vision/generic/distributed_util.py
 def recursive_copy_to_device(value: Any, non_blocking: bool, device: torch.device) -> Any:
     """
...
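
A quick sanity check on the new default (assuming fp32 parameters): broadcast_buffer_size = 2 ** 24 caps each bucket at 16M elements, i.e. 64 MiB, allocated once per rank and per device, and further capped by the total model size.

bucket_elems = min(2 ** 24, model_numel)  # model_numel: total parameter count, assumed known here
bucket_bytes = bucket_elems * 4           # at most 64 MiB per bucket for fp32 parameters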