OpenDAS / Megatron-LM, commit ebfbfcec
fixed the tensor size mismatch issue
Authored Jun 09, 2021 by Mostofa Patwary
Parent: 04c79f30
Showing 1 changed file with 56 additions and 33 deletions (+56 / -33)
tasks/orqa/supervised/finetune.py (view file @ ebfbfcec)
@@ -33,6 +33,44 @@ from tasks.orqa.supervised.eval_utils import accuracy_func_provider
 from tasks.orqa.supervised.eval_utils import process_batch, task_collate_fn
 from tasks.orqa.evaluate_utils import ORQAEvaluator
 
+# input_ is a 2D tensor
+def check_and_append_tensor_for_gather(group, rank, world_size, input_):
+
+    # gather the size of the first dimension of the tensor from all ranks
+    current_length = input_.size()[0]
+    first_dim = torch.tensor([[current_length]],
+        device=torch.cuda.current_device())
+    input_list = [torch.empty_like(first_dim) for _ in range(world_size)]
+    input_list[rank].copy_(first_dim)
+    torch.distributed.all_gather(input_list, first_dim, group=group)
+    all_input_list = torch.cat(input_list, dim=0).contiguous()
+    max_length = torch.max(all_input_list)
+    min_length = torch.min(all_input_list)
+
+    #if rank == 0:
+    #    print("rank {} all pad neg_context_tokens 0 {}".format(rank, input_[0]), flush=True)
+    #    print("rank {} all pad neg_context_tokens max_length {}".format(rank, input_[max_length-1]), flush=True)
+
+    if max_length > current_length:
+        #print("rank {} before pad neg_context_tokens current_length-1 {}".format(rank, input_[current_length-1]), flush=True)
+        #torch.set_printoptions(profile="full")
+        #input_ = torch.nn.functional.pad(input=input_,
+        #    pad=(0, 0, 0, max_length - current_length))
+        padding = tuple([0] * (input_.dim() * 2 - 1)) + \
+            tuple([max_length - current_length])
+        input_ = F.pad(input=input_, pad=padding)
+        #print("rank {} after pad neg_context_tokens current_length-1 {}".format(rank, input_[current_length-1]), flush=True)
+        #print("rank {} after pad neg_context_tokens current_length {}".format(rank, input_[current_length]), flush=True)
+        #print("rank {} after pad neg_context_tokens max_length {}".format(rank, input_[max_length-1]), flush=True)
+
+    #if rank == 0:
+    #    print("rank {} all pad neg_context_tokens 0 {}".format(rank, input_[0]), flush=True)
+    #    print("rank {} all pad neg_context_tokens max_length {}".format(rank, input_[max_length-1]), flush=True)
+
+    return input_
 
 def orqa(Dataset):
 
     def cross_entropy_forward_step(batch, model):
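The pad tuple built in the new helper is the crux of the fix: torch.nn.functional.pad reads its pad argument from the last dimension backwards, so prefixing input_.dim() * 2 - 1 zeros and appending max_length - current_length grows only the trailing end of the first dimension, leaving everything else untouched. A minimal sketch of that behaviour (my own illustration, not part of the commit; it assumes F is torch.nn.functional, as the diff implies):

import torch
import torch.nn.functional as F

x = torch.arange(6).reshape(3, 2)       # stand-in for input_: 3 rows on this rank
current_length = x.size(0)               # 3
max_length = 5                           # longest first dimension across ranks

# F.pad reads the tuple from the last dimension backwards:
# (last_left, last_right, first_left, first_right) for a 2D tensor,
# so (0, 0, 0, 2) appends two zero rows and changes nothing else.
padding = tuple([0] * (x.dim() * 2 - 1)) + tuple([max_length - current_length])
print(padding)                           # (0, 0, 0, 2)

y = F.pad(input=x, pad=padding)
print(y.shape)                           # torch.Size([5, 2])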
@@ -56,7 +94,6 @@ def orqa(Dataset):
         timers('batch generator').stop()
         local_batch_size = query_tokens.shape[0]
-        #print("rank {} query_tokens {} context_tokens {} batch {} neg_context_tokens {}".format(rank, query_tokens.size(), context_tokens.size(), local_batch_size, neg_context_tokens.size()), flush=True)
 
         # Text representation of query and context
         query_list, context_list = [], []
@@ -64,44 +101,30 @@ def orqa(Dataset):
             query_list.append(tokenizer.decode(query_tokens[i].tolist()))
             context_list.append(tokenizer.decode(context_tokens[i].tolist()))
 
-        if neg_context_tokens.size()[0] > 200:
-            current_length = neg_context_tokens.size()[0]
-            first_dim = torch.tensor([[neg_context_tokens.size()[0]]],
-                device=torch.cuda.current_device())
-            neg_context_list = [torch.empty_like(first_dim) for _ in range(world_size)]
-            neg_context_list[rank].copy_(first_dim)
-            torch.distributed.all_gather(neg_context_list, first_dim, group=group)
-            all_neg_context_list = torch.cat(neg_context_list, dim=0).contiguous()
-            max_length = torch.max(all_neg_context_list)
-            torch.set_printoptions(profile="full")
-            if max_length > current_length:
-                print("rank {} before pad neg_context_tokens {}".format(rank, neg_context_tokens[current_length-1]), flush=True)
-                neg_context_tokens = torch.nn.functional.pad(input=neg_context_tokens,
-                    pad=(0, 0, 0, max_length - neg_context_tokens.size()[0]))
-            input_ = torch.empty_like(neg_context_tokens).copy_(\
-                neg_context_tokens).detach_()
-            tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
-            tensor_list[rank].copy_(input_)
-            torch.distributed.all_gather(tensor_list, input_, group=group)
-            if max_length > current_length:
-                print("rank {} after pad neg_context_tokens current_length-1 {}".format(rank, neg_context_tokens[current_length-1]), flush=True)
-                print("rank {} after pad neg_context_tokens current_length {}".format(rank, neg_context_tokens[current_length]), flush=True)
-                print("rank {} after pad neg_context_tokens max_length-1 {}".format(rank, neg_context_tokens[max_length-1]), flush=True)
-            if rank == 0:
-                print("rank {} other pad neg_context_tokens current_length-1 {}".format(rank, tensor_list[5][451]), flush=True)
-                print("rank {} other pad neg_context_tokens current_length {}".format(rank, tensor_list[5][452]), flush=True)
-                print("rank {} other pad neg_context_tokens max_length-1 {}".format(rank, tensor_list[5][max_length-1]), flush=True)
-            torch.set_printoptions(profile="default")
-            exit()
+        #if rank == 5:
+        #    print("rank {} before query_tokens {} query_mask {} query_types {} context_tokens {} context_mask {} context_types {} neg_context_tokens {} neg_context_mask {} neg_context_types {}".format(rank, query_tokens.size(), query_mask.size(),
+        #        query_types.size(), context_tokens.size(), context_mask.size(), context_types.size(), neg_context_tokens.size(), neg_context_mask.size(), neg_context_types.size()), flush=True)
+
+        if neg_context_tokens is not None: # and neg_context_tokens.size()[0] > local_batch_size:
+            neg_context_tokens = check_and_append_tensor_for_gather(group, rank, world_size, neg_context_tokens)
+            neg_context_mask = check_and_append_tensor_for_gather(group, rank, world_size, neg_context_mask)
+            neg_context_types = check_and_append_tensor_for_gather(group, rank, world_size, neg_context_types)
+        #exit()
+        #if rank == 5:
+        #    print("rank {} middle query_tokens {} query_mask {} query_types {} context_tokens {} context_mask {} context_types {} neg_context_tokens {} neg_context_mask {} neg_context_types {}".format(rank, query_tokens.size(), query_mask.size(),
+        #        query_types.size(), context_tokens.size(), context_mask.size(), context_types.size(), neg_context_tokens.size(), neg_context_mask.size(), neg_context_types.size()), flush=True)
 
         if neg_context_tokens is not None:
             context_tokens = torch.cat([context_tokens, neg_context_tokens])
             context_mask = torch.cat([context_mask, neg_context_mask])
             context_types = torch.cat([context_types, neg_context_types])
-        #if rank == 5:
-        #    print("rank {} after query_tokens {} query_mask {} query_types {} context_tokens {} context_mask {} context_types {}".format(rank, query_tokens.size(), query_mask.size(),
-        #        query_types.size(), context_tokens.size(), context_mask.size(), context_types.size()), flush=True)
         #print("==rank {} query_tokens {} context_tokens {}".format(rank, query_tokens.size(), context_tokens.size()), flush=True)
 
         # Forward model.
         output_tensor = model(query_tokens, query_mask,
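Taken together, the change makes every rank pad its negative-context tensors to the longest first dimension in the group before they are gathered and concatenated, which is what removes the size mismatch. The sketch below is my own illustration, not part of the commit: it reproduces the same pattern on CPU with the gloo backend, a fixed world size of 2, fake data, and a hypothetical pad_first_dim_to_global_max helper, whereas the real code runs on GPUs via torch.cuda.current_device() and the process group already set up in finetune.py.

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn.functional as F

def pad_first_dim_to_global_max(group, rank, world_size, input_):
    # hypothetical stand-in for check_and_append_tensor_for_gather:
    # gather each rank's first-dimension length, then pad to the maximum
    length = torch.tensor([[input_.size(0)]])
    lengths = [torch.empty_like(length) for _ in range(world_size)]
    dist.all_gather(lengths, length, group=group)
    max_length = int(torch.max(torch.cat(lengths, dim=0)))
    if max_length > input_.size(0):
        padding = tuple([0] * (input_.dim() * 2 - 1)) + \
            tuple([max_length - input_.size(0)])
        input_ = F.pad(input=input_, pad=padding)
    return input_

def worker(rank, world_size):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    group = dist.group.WORLD

    # each rank holds a different number of (fake) negative-context rows
    neg_context_tokens = torch.ones(3 + rank, 4, dtype=torch.long) * (rank + 1)
    neg_context_tokens = pad_first_dim_to_global_max(
        group, rank, world_size, neg_context_tokens)

    # after padding, all_gather sees identically shaped tensors on every rank
    gathered = [torch.empty_like(neg_context_tokens) for _ in range(world_size)]
    dist.all_gather(gathered, neg_context_tokens, group=group)
    if rank == 0:
        print([tuple(t.shape) for t in gathered])   # [(4, 4), (4, 4)]
    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)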