Commit 04c79f30 authored by Mostofa Patwary

resolved hang issue

parent 84eb016c
@@ -47,6 +47,8 @@ def orqa(Dataset):
        except BaseException:
            batch_ = batch

        group, rank, world_size = get_group_world_size_rank()

        query_tokens, query_mask, query_types, query_pad_mask, \
        context_tokens, context_mask, context_types, context_pad_mask, \
        neg_context_tokens, neg_context_mask, neg_context_types, \
@@ -54,6 +56,7 @@ def orqa(Dataset):
        timers('batch generator').stop()

        local_batch_size = query_tokens.shape[0]
        #print("rank {} query_tokens {} context_tokens {} batch {} neg_context_tokens {}".format(rank, query_tokens.size(), context_tokens.size(), local_batch_size, neg_context_tokens.size()), flush=True)

        # Text representation of query and context
        query_list, context_list = [], []
@@ -61,16 +64,49 @@ def orqa(Dataset):
            query_list.append(tokenizer.decode(query_tokens[i].tolist()))
            context_list.append(tokenizer.decode(context_tokens[i].tolist()))

        if neg_context_tokens.size()[0] > 200:
            current_length = neg_context_tokens.size()[0]
            first_dim = torch.tensor([[neg_context_tokens.size()[0]]], device=torch.cuda.current_device())
            neg_context_list = [torch.empty_like(first_dim) for _ in range(world_size)]
            neg_context_list[rank].copy_(first_dim)
            torch.distributed.all_gather(neg_context_list, first_dim, group=group)
            all_neg_context_list = torch.cat(neg_context_list, dim=0).contiguous()
            max_length = torch.max(all_neg_context_list)
            torch.set_printoptions(profile="full")

            if max_length > current_length:
                print("rank {} before pad neg_context_tokens {}".format(rank, neg_context_tokens[current_length-1]), flush=True)
                neg_context_tokens = torch.nn.functional.pad(input=neg_context_tokens, pad=(0, 0, 0, max_length - neg_context_tokens.size()[0]))

            input_ = torch.empty_like(neg_context_tokens).copy_(\
                neg_context_tokens).detach_()
            tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
            tensor_list[rank].copy_(input_)
            torch.distributed.all_gather(tensor_list, input_, group=group)

            if max_length > current_length:
                print("rank {} after pad neg_context_tokens current_length-1 {}".format(rank, neg_context_tokens[current_length-1]), flush=True)
                print("rank {} after pad neg_context_tokens current_length {}".format(rank, neg_context_tokens[current_length]), flush=True)
                print("rank {} after pad neg_context_tokens max_length-1 {}".format(rank, neg_context_tokens[max_length-1]), flush=True)
                if rank == 0:
                    print("rank {} other pad neg_context_tokens current_length-1 {}".format(rank, tensor_list[5][451]), flush=True)
                    print("rank {} other pad neg_context_tokens current_length {}".format(rank, tensor_list[5][452]), flush=True)
                    print("rank {} other pad neg_context_tokens max_length-1 {}".format(rank, tensor_list[5][max_length-1]), flush=True)
                torch.set_printoptions(profile="default")
                exit()

        if neg_context_tokens is not None:
            context_tokens = torch.cat([context_tokens, neg_context_tokens])
            context_mask = torch.cat([context_mask, neg_context_mask])
            context_types = torch.cat([context_types, neg_context_types])
        #print("==rank {} query_tokens {} context_tokens {}".format(rank, query_tokens.size(), context_tokens.size()), flush=True)

        # Forward model.
        output_tensor = model(query_tokens, query_mask,
                              query_types, context_tokens,
                              context_mask, context_types)

        return output_tensor, partial(cross_entropy_loss_func, query_tokens, context_tokens)
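The debugging block in the hunk above targets the likely root cause of the hang: torch.distributed.all_gather requires every rank to contribute a tensor of identical shape, and when ranks hold different numbers of negative contexts the collective never completes. The probe gathers each rank's first-dimension size, computes the global maximum, and pads before gathering the payload. A minimal standalone sketch of that pattern, assuming an initialized process group; the helper name all_gather_variable_first_dim is illustrative, not part of this commit:

import torch
import torch.distributed as dist

def all_gather_variable_first_dim(local, group=None):
    # Hypothetical helper: all_gather for tensors whose first dimension
    # differs across ranks. The collective demands identical shapes on
    # every rank, so exchange sizes first, then pad, then gather.
    world_size = dist.get_world_size(group=group)

    # Gather every rank's first-dimension size.
    local_size = torch.tensor([local.size(0)], device=local.device)
    sizes = [torch.empty_like(local_size) for _ in range(world_size)]
    dist.all_gather(sizes, local_size, group=group)
    max_size = int(torch.cat(sizes).max())

    # Zero-pad the local tensor along dim 0 up to the global maximum.
    if local.size(0) < max_size:
        pad = (0, 0) * (local.dim() - 1) + (0, max_size - local.size(0))
        local = torch.nn.functional.pad(local, pad)

    # Shapes now match on every rank, so the collective cannot stall.
    gathered = [torch.empty_like(local) for _ in range(world_size)]
    dist.all_gather(gathered, local, group=group)

    # Trim each slice back to its true length before use.
    return [t[:int(s)] for t, s in zip(gathered, sizes)]

Note that every rank must reach the collective: gating it on a rank-local condition, as the size() > 200 guard above does for debugging, would itself deadlock outside a controlled experiment.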
@@ -85,10 +121,13 @@ def orqa(Dataset):
        query_logits, context_logits = output_tensor

        if world_size > 1:
            #print("rank {} query_logits {} context_logits {}".format(rank, query_logits.size(), context_logits.size()))
            input_ = torch.empty_like(context_logits).copy_(\
                context_logits).detach_()
            tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
            tensor_list[rank].copy_(input_)
            #print_rank_0("At cross_entropy_loss_func")
            #print("rank {} input_ {}".format(rank, input_.size()))
            torch.distributed.all_gather(tensor_list, input_, group=group)

            # Check if all-gather happens in order
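The loss-side hunk follows the usual recipe for sharing context logits across data-parallel ranks so each query can be scored against the full pool of contexts: copy the logits into a detached buffer, all_gather it, and typically splice the local gradient-tracking tensor back into its slot so backpropagation still reaches this rank's contribution. A minimal sketch of that pattern, not the commit's exact code; the name gather_context_logits and the splice-back step are assumptions:

import torch
import torch.distributed as dist

def gather_context_logits(context_logits, group=None):
    # Illustrative helper: make every rank's context logits visible to
    # all ranks (in-batch negatives from each data-parallel worker).
    world_size = dist.get_world_size(group=group)
    rank = dist.get_rank(group=group)
    if world_size == 1:
        return context_logits

    # all_gather fills detached storage, so gather a copy ...
    input_ = context_logits.detach().clone()
    tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
    dist.all_gather(tensor_list, input_, group=group)

    # ... then restore autograd for the local shard: gradients flow only
    # through this rank's own logits, the standard data-parallel contract.
    tensor_list[rank] = context_logits
    return torch.cat(tensor_list, dim=0)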