Unverified commit a31b08a5 authored by Benjamin Lefaudeux, committed by GitHub

[refactor] OSS - broadcasts - getting rid of the while loop (#165)

* small refactor, getting rid of the while loop
parent 339cf060
@@ -418,58 +418,60 @@ class OSS(Optimizer):
         return global_rank
 
     def _broadcast_params(self, buffers: List[torch.Tensor], per_rank_params: List[List[Parameter]]) -> None:
-        """Helper function to broadcast all the parameters from a given device
-        """
+        """Helper function to broadcast all the parameters from a given device"""
         buffer_size = buffers[0].numel()
         bucket_requests = []
         direct_requests = []
 
         # Bucket and issue all the async calls
-        for (dst_rank, params), buffer in zip(enumerate(per_rank_params), buffers):
-            # All the params are sorted per rank and per increasing size
-            if len(params) == 0:
-                continue
-
-            global_dst_rank = OSS.get_global_rank(self.group, dst_rank)
+        for (src_rank, params), buffer in zip(enumerate(per_rank_params), buffers):
+            global_src_rank = self.get_global_rank(self.group, src_rank)
 
-            # Copy small parameters into per-GPU buffers
-            i_bucketed = 0  # the number of tensors packed in the buffer
+            # Copy small parameters into per-GPU buffers and then async broadcast
             offset = 0
+            bucket_sent = False
+            bucket_params = []
 
-            # Since all the parameters are already sorted per increasing size, we only need to consider the first ones.
-            while i_bucketed < len(params) and offset + params[i_bucketed].numel() < buffer_size:
-                end = offset + params[i_bucketed].numel()
-                if global_dst_rank == self.global_rank:
-                    buffer[offset:end].copy_(params[i_bucketed].data.view(-1))  # type: ignore
-                offset = end
-                i_bucketed += 1
-
-            if i_bucketed > 0:
-                future = dist.broadcast(tensor=buffer, src=global_dst_rank, group=self.group, async_op=True)
-                if global_dst_rank != self.global_rank:
-                    # This request will need to be unrolled
-                    bucket_requests.append((future, dst_rank))
-
-            # Directly broadcast the rest
-            for param in params[i_bucketed:]:
-                direct_requests.append(
-                    dist.broadcast(tensor=param.data, src=global_dst_rank, group=self.group, async_op=True),
-                )
+            # All the params are sorted per rank and per increasing size
+            for p in params:
+                # Since all the parameters are already sorted per increasing size, we only need to consider the first ones.
+                if not bucket_sent and offset + p.numel() < buffer_size:
+                    end = offset + p.numel()
+                    buffer[offset:end].copy_(p.data.view(-1))
+                    bucket_params.append((p, offset, end))
+                    offset = end
+                else:
+                    if offset > 0 and not bucket_sent:
+                        bucket_requests.append(
+                            (
+                                dist.broadcast(tensor=buffer, src=global_src_rank, group=self.group, async_op=True),
+                                src_rank,
+                                bucket_params,
+                            )
+                        )
+                        bucket_sent = True
+
+                    direct_requests.append(
+                        dist.broadcast(tensor=p.data, src=global_src_rank, group=self.group, async_op=True)
+                    )
+
+            # Catch a trailing bucket
+            if not bucket_sent:
+                bucket_requests.append(
+                    (
+                        dist.broadcast(tensor=buffer, src=global_src_rank, group=self.group, async_op=True),
+                        src_rank,
+                        bucket_params,
+                    )
+                )
 
         # Unroll the initial packed small parameters
-        for gate, rank in bucket_requests:
-            gate.wait()
-
-            params = per_rank_params[rank]
-            buffer = buffers[rank]
-            i_bucketed = 0  # the number of tensors packed in the buffer
-            offset = 0
-
-            while i_bucketed < len(params) and offset + params[i_bucketed].numel() < buffer_size:
-                end = offset + params[i_bucketed].numel()
-                params[i_bucketed].data.copy_(buffer[offset:end].view_as(params[i_bucketed]))  # type: ignore
-                offset = end
-                i_bucketed += 1
+        for work_handle, src_rank, bucket_params in bucket_requests:
+            work_handle.wait()
+
+            if src_rank != self.rank:
+                for p, offset, end in bucket_params:
+                    p.data.copy_(buffers[src_rank][offset:end].view_as(p.data))
 
-        # Unroll all the async work items, wait for completion
+        # Unroll all the async work items, just in case
        _ = list(map(lambda x: x.wait(), direct_requests))
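
For anyone skimming the diff: the refactor replaces the index-driven while loop with a single flat pass over the size-sorted parameters, packing small tensors into the per-rank buffer, recording (param, offset, end) triplets so receiving ranks can unpack after the async broadcast, and falling back to direct broadcasts for anything that no longer fits. The sketch below isolates that packing/unpacking pattern on a single process; the helper names pack_bucket and unpack_bucket are made up for illustration, and the dist.broadcast calls are deliberately left out.

import torch
from typing import List, Tuple

BucketRecord = Tuple[torch.Tensor, int, int]  # (param, offset, end)


def pack_bucket(params: List[torch.Tensor], buffer: torch.Tensor) -> Tuple[List[BucketRecord], List[torch.Tensor]]:
    # Pack as many of the (smallest-first) tensors as fit into `buffer` in one
    # flat pass; return the records needed to unpack them, plus the tensors
    # that have to be sent on their own.
    records: List[BucketRecord] = []
    direct: List[torch.Tensor] = []
    offset = 0
    bucket_full = False

    for p in params:
        if not bucket_full and offset + p.numel() < buffer.numel():
            end = offset + p.numel()
            buffer[offset:end].copy_(p.view(-1))
            records.append((p, offset, end))
            offset = end
        else:
            bucket_full = True  # sorted input: nothing later will fit either
            direct.append(p)

    return records, direct


def unpack_bucket(records: List[BucketRecord], buffer: torch.Tensor) -> None:
    # Receiver side: copy the broadcast buffer back into each tensor.
    for p, offset, end in records:
        p.copy_(buffer[offset:end].view_as(p))


# Usage: simulate one rank's bucket without any process group.
params = sorted([torch.randn(n) for n in (1024, 2, 3, 4)], key=lambda t: t.numel())
buffer = torch.zeros(16)

records, direct = pack_bucket(params, buffer)             # three small tensors bucketed, one sent directly
received = [torch.zeros_like(p) for p, _, _ in records]   # stand-ins for the non-source rank
unpack_bucket([(q, o, e) for q, (_, o, e) in zip(received, records)], buffer)

assert all(torch.equal(q, p) for q, (p, _, _) in zip(received, records))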
@@ -320,6 +320,7 @@ class Tensor:
     def coalesce(self) -> Tensor: ...
     def conj(self) -> Tensor: ...
     def contiguous(self) -> Tensor: ...
+    def copy_(self, other: Tensor) -> None: ...
     def cos(self) -> Tensor: ...
     def cos_(self) -> Tensor: ...
     def cosh(self) -> Tensor: ...
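
The stub addition goes hand in hand with the refactored broadcast code dropping its "# type: ignore" comments on copy_ calls. A minimal, made-up example of the kind of in-place copy the type checker now understands:

import torch

buffer = torch.zeros(8)
param = torch.arange(4, dtype=torch.float32)

buffer[0:4].copy_(param.view(-1))        # pack the parameter into a slice of the buffer
param.copy_(buffer[0:4].view_as(param))  # and copy it back out, as receiving ranks do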