Unverified Commit 61bb32b5 authored by Benjamin Lefaudeux, committed by GitHub

[refactor][minor] OSS - small refactor of the bucketing (#153)

* small refactor, code cleanup
* broadcast tensor .data attribute directly
parent 66b2b514
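
For context on the second bullet, here is a hedged, minimal sketch (not part of the commit and not fairscale code) of why broadcasting the .data attribute makes the old requires_grad save/restore dance unnecessary. The single-process gloo group, the port number, and the tensor shapes are illustrative assumptions only.

import os

import torch
import torch.distributed as dist

if __name__ == "__main__":
    # Single-process gloo group, just to make the call sequence runnable.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group(backend="gloo", rank=0, world_size=1)

    param = torch.nn.Parameter(torch.randn(4))

    # Old approach (removed by this commit): flip `requires_grad` off, broadcast
    # the parameter itself, then restore the flag, because (per the comment
    # removed in this diff) gloo asserts on an in-place broadcast of a tensor
    # that requires grad.
    #
    # New approach: broadcast `param.data`. The `.data` view shares storage with
    # the parameter but is not tracked by autograd, so `requires_grad` can stay
    # untouched on every rank.
    handle = dist.broadcast(tensor=param.data, src=0, async_op=True)
    handle.wait()

    assert param.requires_grad  # never toggled

    dist.destroy_process_group()
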
@@ -421,17 +421,16 @@ class OSS(Optimizer):
         """Helper function to broadcast all the parameters from a given device
         """
         buffer_size = buffers[0].numel()
-        restore_require_grad = []
         bucket_requests = []
-        requests = []
+        direct_requests = []
 
         # Bucket and issue all the async calls
-        for (rank, params), buffer in zip(enumerate(per_rank_params), buffers):
+        for (dst_rank, params), buffer in zip(enumerate(per_rank_params), buffers):
             # All the params are sorted per rank and per increasing size
             if len(params) == 0:
                 continue
 
-            global_rank = OSS.get_global_rank(self.group, rank)
+            global_dst_rank = OSS.get_global_rank(self.group, dst_rank)
 
             # Copy small parameters into per-GPU buffers
             i_bucketed = 0  # the number of tensors packed in the buffer
@@ -440,28 +439,22 @@ class OSS(Optimizer):
             # Since all the parameters are already sorted per increasing size, we only need to consider the first ones.
             while i_bucketed < len(params) and offset + params[i_bucketed].numel() < buffer_size:
                 end = offset + params[i_bucketed].numel()
-                if global_rank == self.global_rank:
+                if global_dst_rank == self.global_rank:
                     buffer[offset:end].copy_(params[i_bucketed].data.view(-1))  # type: ignore
                 offset = end
                 i_bucketed += 1
 
             if i_bucketed > 0:
-                future = dist.broadcast(tensor=buffer, src=global_rank, group=self.group, async_op=True)
-                if global_rank != self.global_rank:
+                future = dist.broadcast(tensor=buffer, src=global_dst_rank, group=self.group, async_op=True)
+                if global_dst_rank != self.global_rank:
                     # This request will need to be unrolled
-                    bucket_requests.append((future, rank))
+                    bucket_requests.append((future, dst_rank))
 
             # Directly broadcast the rest
             for param in params[i_bucketed:]:
-                # NOTE: Broadcast is in-place and not differentiable
-                # Gloo will assert on this operation for any tensor that requires grad.
-                # We save and restore the grad requirement state to work around that, in our case
-                # the grad is only useful on the source rank.
-                if param.requires_grad:
-                    restore_require_grad.append(param)
-                    param.requires_grad = False
-
-                requests.append(dist.broadcast(tensor=param, src=global_rank, group=self.group, async_op=True))
+                direct_requests.append(
+                    dist.broadcast(tensor=param.data, src=global_dst_rank, group=self.group, async_op=True),
+                )
 
         # Unroll the initial packed small parameters
         for gate, rank in bucket_requests:
@@ -478,8 +471,5 @@ class OSS(Optimizer):
                 offset = end
                 i_bucketed += 1
 
-        # Unroll all the async work items, just in case
-        _ = list(map(lambda x: x.wait(), requests))
-
-        for p in restore_require_grad:
-            p.requires_grad = True
+        # Unroll all the async work items, wait for completion
+        _ = list(map(lambda x: x.wait(), direct_requests))
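
For readers who want the bucketing idea outside the optimizer, below is a hedged standalone sketch of the same pattern: small tensors are packed into one flat buffer so a single async broadcast covers all of them, the remaining tensors are broadcast directly, and every async handle is waited on before returning. The helper name bucket_broadcast, the single-process gloo setup, and the buffer/tensor sizes are illustrative assumptions, not fairscale APIs.

import os
from typing import List

import torch
import torch.distributed as dist


def bucket_broadcast(tensors: List[torch.Tensor], buffer: torch.Tensor, src: int) -> None:
    # Assumes `tensors` is sorted by increasing size, as in the diff above.
    buffer_size = buffer.numel()
    direct_requests = []

    # Pack as many small tensors as fit into the flat buffer (source rank only).
    offset = 0
    n_bucketed = 0
    while n_bucketed < len(tensors) and offset + tensors[n_bucketed].numel() < buffer_size:
        end = offset + tensors[n_bucketed].numel()
        if dist.get_rank() == src:
            buffer[offset:end].copy_(tensors[n_bucketed].view(-1))
        offset = end
        n_bucketed += 1

    # One async broadcast for the whole bucket ...
    bucket_work = dist.broadcast(tensor=buffer, src=src, async_op=True) if n_bucketed > 0 else None

    # ... and direct async broadcasts for everything that did not fit.
    for t in tensors[n_bucketed:]:
        direct_requests.append(dist.broadcast(tensor=t, src=src, async_op=True))

    # Unpack the bucket on the receiving ranks once its broadcast has completed.
    if bucket_work is not None:
        bucket_work.wait()
        if dist.get_rank() != src:
            offset = 0
            for t in tensors[:n_bucketed]:
                end = offset + t.numel()
                t.copy_(buffer[offset:end].view_as(t))
                offset = end

    # Wait on the remaining direct broadcasts.
    for work in direct_requests:
        work.wait()


if __name__ == "__main__":
    # Single-process gloo group, just to make the call sequence runnable.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29501")
    dist.init_process_group(backend="gloo", rank=0, world_size=1)

    tensors = [torch.randn(n) for n in (2, 3, 8, 64)]  # sorted by increasing size
    bucket_broadcast(tensors, buffer=torch.empty(16), src=0)

    dist.destroy_process_group()

The point of the bucket is to trade many tiny collectives for one larger one; tensors too big for the buffer are not worth packing and go out as individual broadcasts.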