Unverified Commit 82dbd5d8 authored by ngoyal2707's avatar ngoyal2707 Committed by GitHub
Browse files

[fix] megatron + oss (#127)


authored-by: default avatarNaman Goyal <namangoyal@learnfair0755.h2.fair>
parent 6658be22
......@@ -434,14 +434,14 @@ class OSS(Optimizer):
# Since all the parameters are already sorted per increasing size, we only need to consider the first ones.
while i_bucketed < len(params) and offset + params[i_bucketed].numel() < buffer_size:
end = offset + params[i_bucketed].numel()
if rank == self_rank:
if global_rank == self_rank:
buffer[offset:end].copy_(params[i_bucketed].data.view(-1)) # type: ignore
offset = end
i_bucketed += 1
if i_bucketed > 0:
future = dist.broadcast(tensor=buffer, src=global_rank, group=group, async_op=True)
if rank != self_rank:
if global_rank != self_rank:
# This request will need to be unrolled
bucket_requests.append((future, rank))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment