Commit e46f3260 authored by Mostofa Patwary's avatar Mostofa Patwary
Browse files

fixed the evaluation hangs

parent ebfbfcec
...@@ -47,28 +47,13 @@ def check_and_append_tensor_for_gather(group, rank, world_size, input_): ...@@ -47,28 +47,13 @@ def check_and_append_tensor_for_gather(group, rank, world_size, input_):
max_length = torch.max(all_input_list) max_length = torch.max(all_input_list)
min_length = torch.min(all_input_list) min_length = torch.min(all_input_list)
#if rank == 0: # if the size are different than the max, extend the tensor
# print("rank {} all pad neg_context_tokens 0 {}".format(rank, input_[0]), flush=True) # accordingly
# print("rank {} all pad neg_context_tokens max_length {}".format(rank, input_[max_length-1]), flush=True)
if max_length > current_length: if max_length > current_length:
#print("rank {} before pad neg_context_tokens current_length-1 {}".format(rank, input_[current_length-1]), flush=True)
#torch.set_printoptions(profile="full")
#input_ = torch.nn.functional.pad(input=input_,
# pad=(0, 0, 0, max_length - current_length))
padding=tuple([0] * (input_.dim() * 2 - 1)) + \ padding=tuple([0] * (input_.dim() * 2 - 1)) + \
tuple([max_length - current_length]) tuple([max_length - current_length])
input_ = F.pad(input=input_, pad=padding) input_ = F.pad(input=input_, pad=padding)
#print("rank {} after pad neg_context_tokens current_length-1 {}".format(rank, input_[current_length-1]), flush=True)
#print("rank {} after pad neg_context_tokens current_length {}".format(rank, input_[current_length]), flush=True)
#print("rank {} after pad neg_context_tokens max_length {}".format(rank, input_[max_length-1]), flush=True)
#if rank == 0:
# print("rank {} all pad neg_context_tokens 0 {}".format(rank, input_[0]), flush=True)
# print("rank {} all pad neg_context_tokens max_length {}".format(rank, input_[max_length-1]), flush=True)
return input_ return input_
def orqa(Dataset): def orqa(Dataset):
...@@ -101,31 +86,19 @@ def orqa(Dataset): ...@@ -101,31 +86,19 @@ def orqa(Dataset):
query_list.append(tokenizer.decode(query_tokens[i].tolist())) query_list.append(tokenizer.decode(query_tokens[i].tolist()))
context_list.append(tokenizer.decode(context_tokens[i].tolist())) context_list.append(tokenizer.decode(context_tokens[i].tolist()))
#if rank == 5: if neg_context_tokens is not None:
# print("rank {} before query_tokens {} query_mask {} query_types {} context_tokens {} context_mask {} context_types {} neg_context_tokens {} neg_context_mask {} neg_context_types {}".format(rank, query_tokens.size(), query_mask.size(), neg_context_tokens = check_and_append_tensor_for_gather(group,
# query_types.size(), context_tokens.size(), context_mask.size(), context_types.size(), neg_context_tokens.size(), neg_context_mask.size(), neg_context_types.size()), flush=True) rank, world_size, neg_context_tokens)
neg_context_mask = check_and_append_tensor_for_gather(group,
if neg_context_tokens is not None: # and neg_context_tokens.size()[0] > local_batch_size: rank, world_size, neg_context_mask)
neg_context_tokens = check_and_append_tensor_for_gather(group, rank, world_size, neg_context_tokens) neg_context_types = check_and_append_tensor_for_gather(group,
neg_context_mask = check_and_append_tensor_for_gather(group, rank, world_size, neg_context_mask) rank, world_size, neg_context_types)
neg_context_types = check_and_append_tensor_for_gather(group, rank, world_size, neg_context_types)
#exit()
#if rank == 5:
# print("rank {} middle query_tokens {} query_mask {} query_types {} context_tokens {} context_mask {} context_types {} neg_context_tokens {} neg_context_mask {} neg_context_types {}".format(rank, query_tokens.size(), query_mask.size(),
# query_types.size(), context_tokens.size(), context_mask.size(), context_types.size(), neg_context_tokens.size(), neg_context_mask.size(), neg_context_types.size()), flush=True)
if neg_context_tokens is not None: if neg_context_tokens is not None:
context_tokens = torch.cat([context_tokens, neg_context_tokens]) context_tokens = torch.cat([context_tokens, neg_context_tokens])
context_mask = torch.cat([context_mask, neg_context_mask]) context_mask = torch.cat([context_mask, neg_context_mask])
context_types = torch.cat([context_types, neg_context_types]) context_types = torch.cat([context_types, neg_context_types])
#if rank == 5:
# print("rank {} after query_tokens {} query_mask {} query_types {} context_tokens {} context_mask {} context_types {}".format(rank, query_tokens.size(), query_mask.size(),
# query_types.size(), context_tokens.size(), context_mask.size(), context_types.size()), flush=True)
#print("==rank {} query_tokens {} context_tokens {}".format(rank, query_tokens.size(), context_tokens.size()), flush=True)
# Forward model. # Forward model.
output_tensor = model(query_tokens, query_mask, output_tensor = model(query_tokens, query_mask,
query_types, context_tokens, query_types, context_tokens,
...@@ -144,13 +117,10 @@ def orqa(Dataset): ...@@ -144,13 +117,10 @@ def orqa(Dataset):
query_logits, context_logits = output_tensor query_logits, context_logits = output_tensor
if world_size > 1: if world_size > 1:
#print("rank {} query_logits {} context_logits {}".format(rank, query_logits.size(), context_logits.size()))
input_ = torch.empty_like(context_logits).copy_(\ input_ = torch.empty_like(context_logits).copy_(\
context_logits).detach_() context_logits).detach_()
tensor_list = [torch.empty_like(input_) for _ in range(world_size)] tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
tensor_list[rank].copy_(input_) tensor_list[rank].copy_(input_)
#print_rank_0("At cross_entropy_loss_func")
#print("rank {} input_ {}".format(rank, input_.size()))
torch.distributed.all_gather(tensor_list, input_, group=group) torch.distributed.all_gather(tensor_list, input_, group=group)
# Check if all-gather happens in order # Check if all-gather happens in order
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment