OpenDAS / Megatron-LM / Commits

Commit e46f3260, authored Jun 09, 2021 by Mostofa Patwary

fixed the evaluation hangs

Parent: ebfbfcec
Changes: 1 changed file with 9 additions and 39 deletions

tasks/orqa/supervised/finetune.py  (+9, -39)
@@ -47,28 +47,13 @@ def check_and_append_tensor_for_gather(group, rank, world_size, input_):
     max_length = torch.max(all_input_list)
     min_length = torch.min(all_input_list)

-    #if rank == 0:
-    # print("rank {} all pad neg_context_tokens 0 {}".format(rank, input_[0]), flush=True)
-    # print("rank {} all pad neg_context_tokens max_length {}".format(rank, input_[max_length-1]), flush=True)
+    # if the size are different than the max, extend the tensor
+    # accordingly
     if max_length > current_length:
-        #print("rank {} before pad neg_context_tokens current_length-1 {}".format(rank, input_[current_length-1]), flush=True)
-        #torch.set_printoptions(profile="full")
-        #input_ = torch.nn.functional.pad(input=input_,
-        #    pad=(0, 0, 0, max_length - current_length))
         padding = tuple([0] * (input_.dim() * 2 - 1)) + \
             tuple([max_length - current_length])
         input_ = F.pad(input=input_, pad=padding)
-        #print("rank {} after pad neg_context_tokens current_length-1 {}".format(rank, input_[current_length-1]), flush=True)
-        #print("rank {} after pad neg_context_tokens current_length {}".format(rank, input_[current_length]), flush=True)
-        #print("rank {} after pad neg_context_tokens max_length {}".format(rank, input_[max_length-1]), flush=True)
-        #if rank == 0:
-        # print("rank {} all pad neg_context_tokens 0 {}".format(rank, input_[0]), flush=True)
-        # print("rank {} all pad neg_context_tokens max_length {}".format(rank, input_[max_length-1]), flush=True)

     return input_
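The padding in this hunk matters because torch.distributed.all_gather expects every rank to pass identically shaped tensors; a rank holding a shorter tensor can leave the collective stuck, which is consistent with the "evaluation hangs" this commit fixes. Below is a minimal, self-contained sketch of just the padding step; the helper name pad_first_dim and the tensor shapes are illustrative assumptions, not part of the commit.

import torch
import torch.nn.functional as F

def pad_first_dim(input_, max_length):
    # Pad only the trailing end of dimension 0 up to max_length.
    # F.pad reads the pad widths from the last dimension backwards, two
    # entries per dimension, so the final entry is the amount appended
    # to the end of dim 0.
    current_length = input_.size()[0]
    if max_length > current_length:
        padding = tuple([0] * (input_.dim() * 2 - 1)) + \
            tuple([max_length - current_length])
        input_ = F.pad(input=input_, pad=padding)
    return input_

x = torch.ones(3, 5)                  # this rank contributed 3 rows
print(pad_first_dim(x, 7).shape)      # torch.Size([7, 5]), matching the longest rank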
@@ -101,31 +86,19 @@ def orqa(Dataset):
         query_list.append(tokenizer.decode(query_tokens[i].tolist()))
         context_list.append(tokenizer.decode(context_tokens[i].tolist()))

-        #if rank == 5:
-        # print("rank {} before query_tokens {} query_mask {} query_types {} context_tokens {} context_mask {} context_types {} neg_context_tokens {} neg_context_mask {} neg_context_types {}".format(rank, query_tokens.size(), query_mask.size(),
-        # query_types.size(), context_tokens.size(), context_mask.size(), context_types.size(), neg_context_tokens.size(), neg_context_mask.size(), neg_context_types.size()), flush=True)
         if neg_context_tokens is not None:
-            # and neg_context_tokens.size()[0] > local_batch_size:
             neg_context_tokens = check_and_append_tensor_for_gather(group,
                 rank, world_size, neg_context_tokens)
             neg_context_mask = check_and_append_tensor_for_gather(group,
                 rank, world_size, neg_context_mask)
             neg_context_types = check_and_append_tensor_for_gather(group,
                 rank, world_size, neg_context_types)

-        #exit()
-        #if rank == 5:
-        # print("rank {} middle query_tokens {} query_mask {} query_types {} context_tokens {} context_mask {} context_types {} neg_context_tokens {} neg_context_mask {} neg_context_types {}".format(rank, query_tokens.size(), query_mask.size(),
-        # query_types.size(), context_tokens.size(), context_mask.size(), context_types.size(), neg_context_tokens.size(), neg_context_mask.size(), neg_context_types.size()), flush=True)
         if neg_context_tokens is not None:
             context_tokens = torch.cat([context_tokens, neg_context_tokens])
             context_mask = torch.cat([context_mask, neg_context_mask])
             context_types = torch.cat([context_types, neg_context_types])

-        #if rank == 5:
-        # print("rank {} after query_tokens {} query_mask {} query_types {} context_tokens {} context_mask {} context_types {}".format(rank, query_tokens.size(), query_mask.size(),
-        # query_types.size(), context_tokens.size(), context_mask.size(), context_types.size()), flush=True)
-        #print("==rank {} query_tokens {} context_tokens {}".format(rank, query_tokens.size(), context_tokens.size()), flush=True)

         # Forward model.
         output_tensor = model(query_tokens, query_mask,
                               query_types, context_tokens,
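Once each rank has padded its negative-context tensors to the common maximum first-dimension length, the torch.cat above leaves every rank with identically shaped context inputs, so the logits gathered in the next hunk also match in shape. A small illustrative sketch of that effect; the shapes and the assumed global maximum of 4 negatives are invented for the example.

import torch
import torch.nn.functional as F

context_tokens = torch.ones(4, 8, dtype=torch.long)        # positive contexts on this rank
neg_context_tokens = torch.zeros(2, 8, dtype=torch.long)   # this rank has fewer negatives

# Assume 4 is the maximum negative-context count across ranks; pad up to it
# before concatenating so every rank produces an (8, 8) tensor.
delta = 4 - neg_context_tokens.size(0)
neg_context_tokens = F.pad(neg_context_tokens, (0, 0, 0, delta))

context_tokens = torch.cat([context_tokens, neg_context_tokens])
print(context_tokens.shape)   # torch.Size([8, 8]) on every rank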
@@ -144,13 +117,10 @@ def orqa(Dataset):
         query_logits, context_logits = output_tensor

         if world_size > 1:
-            #print("rank {} query_logits {} context_logits {}".format(rank, query_logits.size(), context_logits.size()))
             input_ = torch.empty_like(context_logits).copy_(\
                 context_logits).detach_()
             tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
             tensor_list[rank].copy_(input_)
-            #print_rank_0("At cross_entropy_loss_func")
-            #print("rank {} input_ {}".format(rank, input_.size()))
             torch.distributed.all_gather(tensor_list, input_, group=group)

             # Check if all-gather happens in order
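torch.distributed.all_gather blocks until every rank in the group calls it with matching tensor shapes, which is why the size checks and padding earlier in the diff matter for avoiding hangs. A hedged sketch of the same gather pattern wrapped in a helper for clarity; the name gather_logits is made up, and it assumes torch.distributed.init_process_group has already been called and that rank, world_size and group come from that setup.

import torch
import torch.distributed as dist

def gather_logits(context_logits, rank, world_size, group=None):
    # Work on a detached copy so the collective never touches the autograd graph.
    input_ = torch.empty_like(context_logits).copy_(context_logits).detach_()

    # Preallocate one slot per rank, fill our own slot, then gather the rest.
    tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
    tensor_list[rank].copy_(input_)
    dist.all_gather(tensor_list, input_, group=group)

    # Stack every rank's logits along the batch dimension.
    return torch.cat(tensor_list, dim=0)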