"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "fb29132b98abdd218bacb6dbaab372f5bb177a2e"
Unverified commit a31b08a5, authored by Benjamin Lefaudeux and committed by GitHub

[refactor] OSS - broadcasts - getting rid of the while loop (#165)

* small refactor, getting rid of the while loop
parent 339cf060
@@ -418,58 +418,60 @@ class OSS(Optimizer):
         return global_rank
 
     def _broadcast_params(self, buffers: List[torch.Tensor], per_rank_params: List[List[Parameter]]) -> None:
-        """Helper function to broadcast all the parameters from a given device
-        """
+        """Helper function to broadcast all the parameters from a given device"""
         buffer_size = buffers[0].numel()
         bucket_requests = []
         direct_requests = []
 
         # Bucket and issue all the async calls
-        for (dst_rank, params), buffer in zip(enumerate(per_rank_params), buffers):
-            # All the params are sorted per rank and per increasing size
-            if len(params) == 0:
-                continue
-
-            global_dst_rank = OSS.get_global_rank(self.group, dst_rank)
-
-            # Copy small parameters into per-GPU buffers
-            i_bucketed = 0  # the number of tensors packed in the buffer
-            offset = 0
-
-            # Since all the parameters are already sorted per increasing size, we only need to consider the first ones.
-            while i_bucketed < len(params) and offset + params[i_bucketed].numel() < buffer_size:
-                end = offset + params[i_bucketed].numel()
-                if global_dst_rank == self.global_rank:
-                    buffer[offset:end].copy_(params[i_bucketed].data.view(-1))  # type: ignore
-                offset = end
-                i_bucketed += 1
-
-            if i_bucketed > 0:
-                future = dist.broadcast(tensor=buffer, src=global_dst_rank, group=self.group, async_op=True)
-                if global_dst_rank != self.global_rank:
-                    # This request will need to be unrolled
-                    bucket_requests.append((future, dst_rank))
-
-            # Directly broadcast the rest
-            for param in params[i_bucketed:]:
-                direct_requests.append(
-                    dist.broadcast(tensor=param.data, src=global_dst_rank, group=self.group, async_op=True),
-                )
+        for (src_rank, params), buffer in zip(enumerate(per_rank_params), buffers):
+            global_src_rank = self.get_global_rank(self.group, src_rank)
+
+            # Copy small parameters into per-GPU buffers and then async broadcast
+            offset = 0
+            bucket_sent = False
+            bucket_params = []
+
+            # All the params are sorted per rank and per increasing size
+            for p in params:
+                # Since all the parameters are already sorted per increasing size, we only need to consider the first ones.
+                if not bucket_sent and offset + p.numel() < buffer_size:
+                    end = offset + p.numel()
+                    buffer[offset:end].copy_(p.data.view(-1))
+                    bucket_params.append((p, offset, end))
+                    offset = end
+                else:
+                    if offset > 0 and not bucket_sent:
+                        bucket_requests.append(
+                            (
+                                dist.broadcast(tensor=buffer, src=global_src_rank, group=self.group, async_op=True),
+                                src_rank,
+                                bucket_params,
+                            )
+                        )
+                        bucket_sent = True
+
+                    direct_requests.append(
+                        dist.broadcast(tensor=p.data, src=global_src_rank, group=self.group, async_op=True)
+                    )
+
+            # Catch a trailing bucket
+            if not bucket_sent:
+                bucket_requests.append(
+                    (
+                        dist.broadcast(tensor=buffer, src=global_src_rank, group=self.group, async_op=True),
+                        src_rank,
+                        bucket_params,
+                    )
+                )
 
         # Unroll the initial packed small parameters
-        for gate, rank in bucket_requests:
-            gate.wait()
-
-            params = per_rank_params[rank]
-            buffer = buffers[rank]
-            i_bucketed = 0  # the number of tensors packed in the buffer
-            offset = 0
-
-            while i_bucketed < len(params) and offset + params[i_bucketed].numel() < buffer_size:
-                end = offset + params[i_bucketed].numel()
-                params[i_bucketed].data.copy_(buffer[offset:end].view_as(params[i_bucketed]))  # type: ignore
-                offset = end
-                i_bucketed += 1
-
-        # Unroll all the async work items, wait for completion
+        for work_handle, src_rank, bucket_params in bucket_requests:
+            work_handle.wait()
+
+            if src_rank != self.rank:
+                for p, offset, end in bucket_params:
+                    p.data.copy_(buffers[src_rank][offset:end].view_as(p.data))
+
+        # Unroll all the async work items, just in case
         _ = list(map(lambda x: x.wait(), direct_requests))
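
For context, the refactor above replaces the index-based while loops with a single pass per rank: small parameters are packed into a flat per-rank buffer and broadcast once, larger parameters are broadcast directly, and receiving ranks later unpack the bucket using the recorded (param, offset, end) triples. Below is a minimal, standalone sketch of that pack/unpack bookkeeping, with plain tensor copies standing in for dist.broadcast; the helper names are illustrative and not part of the OSS class above.

import torch

# Illustrative helpers (not from the commit above): pack the leading
# parameters that fit into a flat buffer, record their offsets, and
# restore them later on the receiving side.
def pack_small_params(params, buffer):
    # Returns (records, leftovers): records are (tensor, start, end) triples
    # for bucketed tensors; leftovers would each get a direct broadcast.
    offset, records, leftovers = 0, [], []
    for p in params:  # assumed sorted by increasing size, as in OSS
        n = p.numel()
        if not leftovers and offset + n < buffer.numel():
            buffer[offset:offset + n].copy_(p.view(-1))
            records.append((p, offset, offset + n))
            offset += n
        else:
            leftovers.append(p)
    return records, leftovers

def unpack_small_params(records, buffer):
    # Receiver side: copy the bucketed slices back into the parameters.
    for p, start, end in records:
        p.copy_(buffer[start:end].view_as(p))

params = [torch.randn(4), torch.randn(8), torch.randn(1024)]
buffer = torch.empty(64)
records, leftovers = pack_small_params(params, buffer)
# In the real code, `buffer` is broadcast once (async) and each leftover gets
# its own broadcast; the records drive the final unroll loop.
unpack_small_params(records, buffer)
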
@@ -320,6 +320,7 @@ class Tensor:
     def coalesce(self) -> Tensor: ...
     def conj(self) -> Tensor: ...
     def contiguous(self) -> Tensor: ...
+    def copy_(self, other: Tensor) -> None: ...
     def cos(self) -> Tensor: ...
     def cos_(self) -> Tensor: ...
     def cosh(self) -> Tensor: ...
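
The one-line stub addition above is presumably what lets the new p.data.copy_(...) calls in the broadcast unroll type-check without the # type: ignore comments used in the old code: once Tensor.copy_ is declared in the torch stub, a call such as the following resolves against that signature (a minimal illustration, not code from this commit).

import torch

a = torch.zeros(4)
b = torch.ones(4)
a.copy_(b)  # checked against the stub's Tensor.copy_ declaration; copies b into a
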