OpenDAS / Megatron-LM · Commits

Commit 6e856fac
Authored Mar 31, 2020 by Neel Kant

Add while condition to InverseClozeDataset to protect against corner cases

Parent: 6f56b909
Showing 1 changed file with 72 additions and 63 deletions

megatron/data_utils/datasets.py  (+72, -63)
@@ -910,7 +910,7 @@ class InverseClozeDataset(data.Dataset):
         # get seq length. Save 2 tokens for beginning and end
         target_seq_length = self.max_seq_len - 2
         if rng.random() < self.short_seq_prob:
-            target_seq_length = rng.randint(2, target_seq_length)
+            target_seq_length = rng.randint(5, target_seq_length)
 
         input_data, context_data = self.get_input_and_context(target_seq_length, rng, np_rng)
         input_tokens, input_token_types, input_pad_mask = input_data
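For reference, the length sampling in the hunk above can be read in isolation. The following is a minimal standalone sketch, using illustrative values for max_seq_len and short_seq_prob (in the dataset these come from its configuration); it shows that the lower bound of the short-sequence draw is now 5 rather than 2:

import random

# Illustrative values only; max_seq_len and short_seq_prob are dataset settings.
max_seq_len, short_seq_prob = 512, 0.1
rng = random.Random(1234)

target_seq_length = max_seq_len - 2      # save 2 tokens for 'CLS' and 'SEP'
if rng.random() < short_seq_prob:
    # lower bound raised from 2 to 5, presumably so a sampled short target
    # can no longer be a near-empty 2-token sequence
    target_seq_length = rng.randint(5, target_seq_length)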
@@ -950,69 +950,78 @@ class InverseClozeDataset(data.Dataset):
     def get_input_and_context(self, target_seq_length, rng, np_rng):
         """fetches a sentence and its surrounding context"""
-        doc = None
-        while doc is None:
-            doc_idx = self.get_weighted_samples(np_rng)
-            # doc is a list of sentences
-            doc = self.get_sentence_split_doc(doc_idx)
-            if not doc:
-                doc = None
-
-        # set up and tokenize the entire selected document
-        num_sentences = len(doc)
-        all_token_lists = []
-        all_token_type_lists = []
-        for sentence in doc:
-            tokens, token_types = self.sentence_tokenize(sentence, 0)
-            all_token_lists.append(tokens)
-            all_token_type_lists.append(token_types)
-        sentence_token_lens = [len(l) for l in all_token_lists]
-        inclusion_mask = [False] * num_sentences
-
-        # select a random sentence from the document as input
-        input_sentence_idx = rng.randint(0, len(all_token_lists) - 1)
-        input_tokens = all_token_lists[input_sentence_idx].copy()[:self.max_seq_len - 2]
-        input_token_types = all_token_type_lists[input_sentence_idx].copy()[:self.max_seq_len - 2]
-
-        # 10% of the time, the input sentence is left in the context.
-        # The other 90% of the time, remove it.
-        if rng.random() < 0.1:
-            inclusion_mask[input_sentence_idx] = True
-
-        # parameters for examining sentences to remove from the context
-        view_preceding = True
-        view_radius = 1
-        while sum(s for i, s in enumerate(sentence_token_lens) if inclusion_mask[i]) < self.max_seq_len - 2:
-            # keep removing sentences while the context is too large.
-            if view_preceding:
-                examine_idx = input_sentence_idx - view_radius
-                if examine_idx >= 0:
-                    inclusion_mask[examine_idx] = True
-            else:
-                examine_idx = input_sentence_idx + view_radius
-                if examine_idx < num_sentences:
-                    inclusion_mask[examine_idx] = True
-                view_radius += 1
-            view_preceding = not view_preceding
-            if view_radius > num_sentences:
-                break
-
-        # assemble the tokens and token types of the context
-        context_tokens = list(itertools.chain(*[l for i, l in enumerate(all_token_lists) if inclusion_mask[i]]))[:self.max_seq_len - 2]
-        context_token_types = list(itertools.chain(*[l for i, l in enumerate(all_token_type_lists) if inclusion_mask[i]]))[:self.max_seq_len - 2]
-
-        # concatenate 'CLS' and 'SEP' tokens and add extra token types
-        input_tokens, input_token_types, input_pad_mask = self.concat_and_pad_tokens(input_tokens, input_token_types)
-        context_tokens, context_token_types, context_pad_mask = self.concat_and_pad_tokens(context_tokens, context_token_types)
+        num_tries = 0
+        while num_tries < 20:
+            num_tries += 1
+            doc = None
+            while doc is None:
+                doc_idx = self.get_weighted_samples(np_rng)
+                # doc is a list of sentences
+                doc = self.get_sentence_split_doc(doc_idx)
+                if not doc:
+                    doc = None
+
+            # set up and tokenize the entire selected document
+            num_sentences = len(doc)
+            all_token_lists = []
+            all_token_type_lists = []
+            for sentence in doc:
+                tokens, token_types = self.sentence_tokenize(sentence, 0)
+                all_token_lists.append(tokens)
+                all_token_type_lists.append(token_types)
+            sentence_token_lens = [len(l) for l in all_token_lists]
+            inclusion_mask = [False] * num_sentences
+            padless_max_len = self.max_seq_len - 2
+
+            # select a random sentence from the document as input
+            input_sentence_idx = rng.randint(0, len(all_token_lists) - 1)
+            input_tokens = all_token_lists[input_sentence_idx].copy()[:target_seq_length]
+            input_token_types = all_token_type_lists[input_sentence_idx].copy()[:target_seq_length]
+            if not len(input_tokens) > 0:
+                continue
+
+            # 10% of the time, the input sentence is left in the context.
+            # The other 90% of the time, remove it.
+            if rng.random() < 0.1:
+                inclusion_mask[input_sentence_idx] = True
+
+            # parameters for examining sentences to remove from the context
+            view_preceding = True
+            view_radius = 1
+            while sum(s for i, s in enumerate(sentence_token_lens) if inclusion_mask[i]) < padless_max_len:
+                # keep removing sentences while the context is too large.
+                if view_preceding:
+                    examine_idx = input_sentence_idx - view_radius
+                    if examine_idx >= 0:
+                        inclusion_mask[examine_idx] = True
+                else:
+                    examine_idx = input_sentence_idx + view_radius
+                    if examine_idx < num_sentences:
+                        inclusion_mask[examine_idx] = True
+                    view_radius += 1
+                view_preceding = not view_preceding
+                if view_radius > num_sentences:
+                    break
-        return (input_tokens, input_token_types, input_pad_mask), \
-               (context_tokens, context_token_types, context_pad_mask)
+            # assemble the tokens and token types of the context
+            context_tokens = list(itertools.chain(*[l for i, l in enumerate(all_token_lists) if inclusion_mask[i]]))[:padless_max_len]
+            context_token_types = list(itertools.chain(*[l for i, l in enumerate(all_token_type_lists) if inclusion_mask[i]]))[:padless_max_len]
+            if not len(context_tokens) > 0:
+                continue
+
+            # concatenate 'CLS' and 'SEP' tokens and add extra token types
+            input_tokens, input_token_types, input_pad_mask = self.concat_and_pad_tokens(input_tokens, input_token_types)
+            context_tokens, context_token_types, context_pad_mask = self.concat_and_pad_tokens(context_tokens, context_token_types)
+
+            return (input_tokens, input_token_types, input_pad_mask), \
+                   (context_tokens, context_token_types, context_pad_mask)
+        else:
+            raise RuntimeError("Could not get a valid data point from InverseClozeDataset")
 
     def concat_and_pad_tokens(self, tokens, token_types):
         """concat with special tokens and pad sequence to self.max_seq_len"""
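The context around the chosen sentence is still grown by alternately marking the preceding and following sentences until the token budget (padless_max_len) is met. A rough standalone rendering of that selection loop, with made-up sentence lengths and a hypothetical select_context name:

# Rough standalone rendering of the alternating context-selection loop; the
# function name, lens, and budget are illustrative, not part of the dataset API.
def select_context(lens, input_idx, budget, keep_input=False):
    """Mark sentences around input_idx until their total length reaches budget."""
    n = len(lens)
    mask = [False] * n
    if keep_input:
        mask[input_idx] = True  # the 10% case where the input stays in the context
    view_preceding = True
    view_radius = 1
    while sum(s for i, s in enumerate(lens) if mask[i]) < budget:
        if view_preceding:
            idx = input_idx - view_radius
            if idx >= 0:
                mask[idx] = True
        else:
            idx = input_idx + view_radius
            if idx < n:
                mask[idx] = True
            view_radius += 1
        view_preceding = not view_preceding
        if view_radius > n:
            break
    return mask

# e.g. select_context([30, 40, 25, 50, 10], input_idx=2, budget=100)
# marks sentences 1 and 3 first, then sentence 0, and stops once the
# included lengths (30 + 40 + 50) exceed the budget.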