Unverified commit a31b08a5 authored by Benjamin Lefaudeux, committed by GitHub

[refactor] OSS - broadcasts - getting rid of the while loop (#165)

* small refactor, getting rid of the while loop
parent 339cf060
@@ -418,58 +418,60 @@ class OSS(Optimizer):
         return global_rank
 
     def _broadcast_params(self, buffers: List[torch.Tensor], per_rank_params: List[List[Parameter]]) -> None:
-        """Helper function to broadcast all the parameters from a given device
-        """
+        """Helper function to broadcast all the parameters from a given device"""
         buffer_size = buffers[0].numel()
         bucket_requests = []
         direct_requests = []
 
         # Bucket and issue all the async calls
-        for (dst_rank, params), buffer in zip(enumerate(per_rank_params), buffers):
-            # All the params are sorted per rank and per increasing size
-            if len(params) == 0:
-                continue
-
-            global_dst_rank = OSS.get_global_rank(self.group, dst_rank)
+        for (src_rank, params), buffer in zip(enumerate(per_rank_params), buffers):
+            global_src_rank = self.get_global_rank(self.group, src_rank)
 
-            # Copy small parameters into per-GPU buffers
-            i_bucketed = 0  # the number of tensors packed in the buffer
+            # Copy small parameters into per-GPU buffers and then async broadcast
             offset = 0
+            bucket_sent = False
+            bucket_params = []
 
-            # Since all the parameters are already sorted per increasing size, we only need to consider the first ones.
-            while i_bucketed < len(params) and offset + params[i_bucketed].numel() < buffer_size:
-                end = offset + params[i_bucketed].numel()
-                if global_dst_rank == self.global_rank:
-                    buffer[offset:end].copy_(params[i_bucketed].data.view(-1))  # type: ignore
-                offset = end
-                i_bucketed += 1
-
-            if i_bucketed > 0:
-                future = dist.broadcast(tensor=buffer, src=global_dst_rank, group=self.group, async_op=True)
-                if global_dst_rank != self.global_rank:
-                    # This request will need to be unrolled
-                    bucket_requests.append((future, dst_rank))
-
-            # Directly broadcast the rest
-            for param in params[i_bucketed:]:
-                direct_requests.append(
-                    dist.broadcast(tensor=param.data, src=global_dst_rank, group=self.group, async_op=True),
-                )
+            # All the params are sorted per rank and per increasing size
+            for p in params:
+                if not bucket_sent and offset + p.numel() < buffer_size:
+                    end = offset + p.numel()
+                    buffer[offset:end].copy_(p.data.view(-1))
+                    bucket_params.append((p, offset, end))
+                    offset = end
+                else:
+                    if offset > 0 and not bucket_sent:
+                        bucket_requests.append(
+                            (
+                                dist.broadcast(tensor=buffer, src=global_src_rank, group=self.group, async_op=True),
+                                src_rank,
+                                bucket_params,
+                            )
+                        )
+                        bucket_sent = True
+
+                    direct_requests.append(
+                        dist.broadcast(tensor=p.data, src=global_src_rank, group=self.group, async_op=True)
+                    )
 
-        # Unroll the initial packed small parameters
-        for gate, rank in bucket_requests:
-            gate.wait()
-
-            params = per_rank_params[rank]
-            buffer = buffers[rank]
-            i_bucketed = 0  # the number of tensors packed in the buffer
-            offset = 0
-
-            while i_bucketed < len(params) and offset + params[i_bucketed].numel() < buffer_size:
-                end = offset + params[i_bucketed].numel()
-                params[i_bucketed].data.copy_(buffer[offset:end].view_as(params[i_bucketed]))  # type: ignore
-                offset = end
-                i_bucketed += 1
+            # Catch a trailing bucket
+            if not bucket_sent:
+                bucket_requests.append(
+                    (
+                        dist.broadcast(tensor=buffer, src=global_src_rank, group=self.group, async_op=True),
+                        src_rank,
+                        bucket_params,
+                    )
+                )
 
-        # Unroll all the async work items, wait for completion
+        # Unroll the initial packed small parameters
+        for work_handle, src_rank, bucket_params in bucket_requests:
+            work_handle.wait()
+
+            if src_rank != self.rank:
+                for p, offset, end in bucket_params:
+                    p.data.copy_(buffers[src_rank][offset:end].view_as(p.data))
+
+        # Unroll all the async work items, just in case
        _ = list(map(lambda x: x.wait(), direct_requests))
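
The refactor replaces the index-driven while loop with a single for loop over the size-sorted params, recording the bucketed tensors in `bucket_params` so they can be unpacked once the async broadcast of the flat buffer completes. Below is a minimal standalone sketch of that pack / broadcast / unpack pattern, assuming a torch.distributed process group has already been initialized; the names `broadcast_bucketed`, `small_params`, and `buffer_size` are illustrative only and are not part of this commit.

```python
from typing import List, Tuple

import torch
import torch.distributed as dist


def broadcast_bucketed(small_params: List[torch.Tensor], src_rank: int, buffer_size: int = 2 ** 20) -> None:
    """Pack small tensors into one flat buffer, broadcast it once, then unpack on the receiving ranks."""
    buffer = torch.zeros(buffer_size, dtype=small_params[0].dtype)
    bucket: List[Tuple[torch.Tensor, int, int]] = []  # (tensor, start, end) offsets inside the buffer

    offset = 0
    for p in small_params:
        end = offset + p.numel()
        if end >= buffer_size:
            break  # stop at the first tensor that does not fit; those would be broadcast directly instead
        if dist.get_rank() == src_rank:
            buffer[offset:end].copy_(p.data.view(-1))  # pack on the source rank only
        bucket.append((p, offset, end))
        offset = end

    # One async broadcast for the whole bucket instead of one call per tensor
    work_handle = dist.broadcast(tensor=buffer, src=src_rank, async_op=True)
    work_handle.wait()

    # Receiving ranks unpack the flat buffer back into the original tensors
    if dist.get_rank() != src_rank:
        for p, start, end in bucket:
            p.data.copy_(buffer[start:end].view_as(p.data))
```

Coalescing the small parameters this way trades one large broadcast for many tiny ones, which is the same bandwidth/latency trade-off the production code makes; the commit's version additionally keeps the broadcast fully asynchronous and defers the unpacking until the work handles are waited on.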
@@ -320,6 +320,7 @@ class Tensor:
     def coalesce(self) -> Tensor: ...
     def conj(self) -> Tensor: ...
     def contiguous(self) -> Tensor: ...
+    def copy_(self, other: Tensor) -> None: ...
     def cos(self) -> Tensor: ...
     def cos_(self) -> Tensor: ...
     def cosh(self) -> Tensor: ...
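
The stub addition only types `Tensor.copy_`, the in-place copy the refactored bucketing code relies on (presumably why the `# type: ignore` annotations could be dropped). A tiny illustration of the pack/unpack round trip it enables; the variable names are ours, not part of the commit:

```python
import torch

buffer = torch.zeros(6)
param = torch.arange(6, dtype=torch.float32).reshape(2, 3)

# Pack: flatten the parameter into a slice of the flat buffer, in place
buffer[0:6].copy_(param.view(-1))

# Unpack: copy the slice back, viewed with the parameter's shape
restored = torch.empty_like(param)
restored.copy_(buffer[0:6].view_as(param))
assert torch.equal(restored, param)
```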