Unverified Commit 61bb32b5 authored by Benjamin Lefaudeux, committed by GitHub

[refactor][minor] OSS - small refactor of the bucketing (#153)

* small refactor, code cleanup
* broadcast tensor .data attribute directly
parent 66b2b514
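For context, a minimal sketch of what the second bullet amounts to. This is not the fairscale code itself: it assumes an already-initialized torch.distributed process group (Gloo or NCCL), and `broadcast_param_old`, `broadcast_param_new`, `param`, and `src_rank` are illustrative names.

```python
# Sketch of the two broadcast strategies, assuming dist.init_process_group() was called.
import torch
import torch.distributed as dist


def broadcast_param_old(param: torch.nn.Parameter, src_rank: int, group=None) -> None:
    # Previous approach: Gloo asserts when broadcasting a tensor that requires grad,
    # so the flag was saved, cleared, and restored once the collective completed.
    needed_grad = param.requires_grad
    param.requires_grad = False
    handle = dist.broadcast(tensor=param, src=src_rank, group=group, async_op=True)
    handle.wait()
    param.requires_grad = needed_grad


def broadcast_param_new(param: torch.nn.Parameter, src_rank: int, group=None) -> None:
    # New approach: broadcast the .data attribute directly. It aliases the same
    # storage but does not require grad, so no flag juggling is needed.
    handle = dist.broadcast(tensor=param.data, src=src_rank, group=group, async_op=True)
    handle.wait()
```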
@@ -421,17 +421,16 @@ class OSS(Optimizer):
         """Helper function to broadcast all the parameters from a given device
         """
         buffer_size = buffers[0].numel()
-        restore_require_grad = []
         bucket_requests = []
-        requests = []
+        direct_requests = []

         # Bucket and issue all the async calls
-        for (rank, params), buffer in zip(enumerate(per_rank_params), buffers):
+        for (dst_rank, params), buffer in zip(enumerate(per_rank_params), buffers):
             # All the params are sorted per rank and per increasing size
             if len(params) == 0:
                 continue

-            global_rank = OSS.get_global_rank(self.group, rank)
+            global_dst_rank = OSS.get_global_rank(self.group, dst_rank)

             # Copy small parameters into per-GPU buffers
             i_bucketed = 0  # the number of tensors packed in the buffer
@@ -440,28 +439,22 @@ class OSS(Optimizer):
             # Since all the parameters are already sorted per increasing size, we only need to consider the first ones.
             while i_bucketed < len(params) and offset + params[i_bucketed].numel() < buffer_size:
                 end = offset + params[i_bucketed].numel()
-                if global_rank == self.global_rank:
+                if global_dst_rank == self.global_rank:
                     buffer[offset:end].copy_(params[i_bucketed].data.view(-1))  # type: ignore
                 offset = end
                 i_bucketed += 1

             if i_bucketed > 0:
-                future = dist.broadcast(tensor=buffer, src=global_rank, group=self.group, async_op=True)
-                if global_rank != self.global_rank:
+                future = dist.broadcast(tensor=buffer, src=global_dst_rank, group=self.group, async_op=True)
+                if global_dst_rank != self.global_rank:
                     # This request will need to be unrolled
-                    bucket_requests.append((future, rank))
+                    bucket_requests.append((future, dst_rank))

             # Directly broadcast the rest
             for param in params[i_bucketed:]:
-                # NOTE: Broadcast is in-place and not differentiable
-                # Gloo will assert on this operation for any tensor that requires grad.
-                # We save and restore the grad requirement state to work around that, in our case
-                # the grad is only useful on the source rank.
-                if param.requires_grad:
-                    restore_require_grad.append(param)
-                    param.requires_grad = False
-
-                requests.append(dist.broadcast(tensor=param, src=global_rank, group=self.group, async_op=True))
+                direct_requests.append(
+                    dist.broadcast(tensor=param.data, src=global_dst_rank, group=self.group, async_op=True),
+                )

         # Unroll the initial packed small parameters
         for gate, rank in bucket_requests:
@@ -478,8 +471,5 @@ class OSS(Optimizer):
                 offset = end
                 i_bucketed += 1

-        # Unroll all the async work items, just in case
-        _ = list(map(lambda x: x.wait(), requests))
-
-        for p in restore_require_grad:
-            p.requires_grad = True
+        # Unroll all the async work items, wait for completion
+        _ = list(map(lambda x: x.wait(), direct_requests))
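Stepping back, the bucketing pattern this diff touches can be summarized in a standalone sketch. It is an illustration under assumptions, not the OSS implementation: the function name `pack_and_broadcast`, the `bucket_cap` parameter, and the sorted-tensors precondition stand in for what OSS manages internally, and the default torch.distributed process group is assumed to be initialized.

```python
# Sketch of the bucket-then-broadcast pattern shown above; names are placeholders.
from typing import List

import torch
import torch.distributed as dist


def pack_and_broadcast(tensors: List[torch.Tensor], src_rank: int, bucket_cap: int) -> None:
    # Tensors are assumed sorted by increasing size, so the small ones come first.
    buffer = tensors[0].new_zeros(bucket_cap)

    # 1) On the source rank, pack as many leading tensors as fit into the flat buffer.
    offset, n_packed = 0, 0
    while n_packed < len(tensors) and offset + tensors[n_packed].numel() < bucket_cap:
        end = offset + tensors[n_packed].numel()
        if dist.get_rank() == src_rank:
            buffer[offset:end].copy_(tensors[n_packed].data.view(-1))
        offset, n_packed = end, n_packed + 1

    # 2) One async broadcast for the whole bucket, plus direct .data broadcasts for the rest.
    requests = []
    if n_packed > 0:
        requests.append(dist.broadcast(tensor=buffer, src=src_rank, async_op=True))
    for t in tensors[n_packed:]:
        requests.append(dist.broadcast(tensor=t.data, src=src_rank, async_op=True))

    # 3) Wait for completion, then unpack the bucket on the receiving ranks.
    for req in requests:
        req.wait()
    if dist.get_rank() != src_rank:
        offset = 0
        for t in tensors[:n_packed]:
            end = offset + t.numel()
            t.data.copy_(buffer[offset:end].view_as(t))
            offset = end
```

The single bucket broadcast amortizes per-collective latency over the many small parameters, while larger tensors are still broadcast individually so they never have to fit in the fixed-size buffer.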