wxj / Megatron-LM · Commit 6e856fac
"README.md" did not exist on "033f82a9830c307e9c8b47e7664e31b3b022e6cd"
Add while condition to InverseClozeDataset to protect against corner cases

Authored Mar 31, 2020 by Neel Kant
Parent: 6f56b909
Showing 1 changed file with 72 additions and 63 deletions.

megatron/data_utils/datasets.py (+72, -63)
...
@@ -910,7 +910,7 @@ class InverseClozeDataset(data.Dataset):
         # get seq length. Save 2 tokens for beginning and end
         target_seq_length = self.max_seq_len - 2
         if rng.random() < self.short_seq_prob:
-            target_seq_length = rng.randint(2, target_seq_length)
+            target_seq_length = rng.randint(5, target_seq_length)
         input_data, context_data = self.get_input_and_context(target_seq_length, rng, np_rng)
         input_tokens, input_token_types, input_pad_mask = input_data
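The only functional change in this hunk raises the lower bound for randomly shortened inputs from 2 tokens to 5. A minimal sketch of how that sampling step behaves; the values of max_seq_len and short_seq_prob are illustrative assumptions, not taken from this commit:

import random

# Illustrative values only; the real ones come from the dataset's configuration.
max_seq_len = 288
short_seq_prob = 0.1
rng = random.Random(0)

# Save 2 tokens for the beginning and end markers, as in the diff above.
target_seq_length = max_seq_len - 2
if rng.random() < short_seq_prob:
    # After this commit the shortest sampled input is 5 tokens instead of 2.
    target_seq_length = rng.randint(5, target_seq_length)
print(target_seq_length)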
...
@@ -950,69 +950,78 @@ class InverseClozeDataset(data.Dataset):
     def get_input_and_context(self, target_seq_length, rng, np_rng):
         """fetches a sentence and its surrounding context"""
-        doc = None
-        while doc is None:
-            doc_idx = self.get_weighted_samples(np_rng)
-            # doc is a list of sentences
-            doc = self.get_sentence_split_doc(doc_idx)
-            if not doc:
-                doc = None
-        # set up and tokenize the entire selected document
-        num_sentences = len(doc)
-        all_token_lists = []
-        all_token_type_lists = []
-        for sentence in doc:
-            tokens, token_types = self.sentence_tokenize(sentence, 0)
-            all_token_lists.append(tokens)
-            all_token_type_lists.append(token_types)
-        sentence_token_lens = [len(l) for l in all_token_lists]
-        inclusion_mask = [False] * num_sentences
-        # select a random sentence from the document as input
-        input_sentence_idx = rng.randint(0, len(all_token_lists) - 1)
-        input_tokens = all_token_lists[input_sentence_idx].copy()[:self.max_seq_len - 2]
-        input_token_types = all_token_type_lists[input_sentence_idx].copy()[:self.max_seq_len - 2]
-        # 10% of the time, the input sentence is left in the context.
-        # The other 90% of the time, remove it.
-        if rng.random() < 0.1:
-            inclusion_mask[input_sentence_idx] = True
-        # parameters for examining sentences to remove from the context
-        view_preceding = True
-        view_radius = 1
-        while sum(s for i, s in enumerate(sentence_token_lens) if inclusion_mask[i]) < self.max_seq_len - 2:
-            # keep removing sentences while the context is too large.
-            if view_preceding:
-                examine_idx = input_sentence_idx - view_radius
-                if examine_idx >= 0:
-                    inclusion_mask[examine_idx] = True
-            else:
-                examine_idx = input_sentence_idx + view_radius
-                if examine_idx < num_sentences:
-                    inclusion_mask[examine_idx] = True
-                view_radius += 1
-            view_preceding = not view_preceding
-            if view_radius > num_sentences:
-                break
-        # assemble the tokens and token types of the context
-        context_tokens = list(itertools.chain(
-            *[l for i, l in enumerate(all_token_lists) if inclusion_mask[i]]))[:self.max_seq_len - 2]
-        context_token_types = list(itertools.chain(
-            *[l for i, l in enumerate(all_token_type_lists) if inclusion_mask[i]]))[:self.max_seq_len - 2]
-        # concatenate 'CLS' and 'SEP' tokens and add extra token types
-        input_tokens, input_token_types, input_pad_mask = self.concat_and_pad_tokens(input_tokens, input_token_types)
-        context_tokens, context_token_types, context_pad_mask = self.concat_and_pad_tokens(context_tokens, context_token_types)
-        return (input_tokens, input_token_types, input_pad_mask), \
-               (context_tokens, context_token_types, context_pad_mask)
+        num_tries = 0
+        while num_tries < 20:
+            num_tries += 1
+            doc = None
+            while doc is None:
+                doc_idx = self.get_weighted_samples(np_rng)
+                # doc is a list of sentences
+                doc = self.get_sentence_split_doc(doc_idx)
+                if not doc:
+                    doc = None
+            # set up and tokenize the entire selected document
+            num_sentences = len(doc)
+            all_token_lists = []
+            all_token_type_lists = []
+            for sentence in doc:
+                tokens, token_types = self.sentence_tokenize(sentence, 0)
+                all_token_lists.append(tokens)
+                all_token_type_lists.append(token_types)
+            sentence_token_lens = [len(l) for l in all_token_lists]
+            inclusion_mask = [False] * num_sentences
+            padless_max_len = self.max_seq_len - 2
+            # select a random sentence from the document as input
+            input_sentence_idx = rng.randint(0, len(all_token_lists) - 1)
+            input_tokens = all_token_lists[input_sentence_idx].copy()[:target_seq_length]
+            input_token_types = all_token_type_lists[input_sentence_idx].copy()[:target_seq_length]
+            if not len(input_tokens) > 0:
+                continue
+            # 10% of the time, the input sentence is left in the context.
+            # The other 90% of the time, remove it.
+            if rng.random() < 0.1:
+                inclusion_mask[input_sentence_idx] = True
+            # parameters for examining sentences to remove from the context
+            view_preceding = True
+            view_radius = 1
+            while sum(s for i, s in enumerate(sentence_token_lens) if inclusion_mask[i]) < padless_max_len:
+                # keep removing sentences while the context is too large.
+                if view_preceding:
+                    examine_idx = input_sentence_idx - view_radius
+                    if examine_idx >= 0:
+                        inclusion_mask[examine_idx] = True
+                else:
+                    examine_idx = input_sentence_idx + view_radius
+                    if examine_idx < num_sentences:
+                        inclusion_mask[examine_idx] = True
+                    view_radius += 1
+                view_preceding = not view_preceding
+                if view_radius > num_sentences:
+                    break
+            # assemble the tokens and token types of the context
+            context_tokens = list(itertools.chain(
+                *[l for i, l in enumerate(all_token_lists) if inclusion_mask[i]]))[:padless_max_len]
+            context_token_types = list(itertools.chain(
+                *[l for i, l in enumerate(all_token_type_lists) if inclusion_mask[i]]))[:padless_max_len]
+            if not len(context_tokens) > 0:
+                continue
+            # concatenate 'CLS' and 'SEP' tokens and add extra token types
+            input_tokens, input_token_types, input_pad_mask = self.concat_and_pad_tokens(input_tokens, input_token_types)
+            context_tokens, context_token_types, context_pad_mask = self.concat_and_pad_tokens(context_tokens, context_token_types)
+            return (input_tokens, input_token_types, input_pad_mask), \
+                   (context_tokens, context_token_types, context_pad_mask)
+        else:
+            raise RuntimeError("Could not get a valid data point from InverseClozeDataset")

     def concat_and_pad_tokens(self, tokens, token_types):
         """concat with special tokens and pad sequence to self.max_seq_len"""
...
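The corner-case protection this commit adds is a bounded retry: the whole sampling procedure now runs inside "while num_tries < 20", degenerate draws (an empty input or context) are skipped with "continue", and the loop's "else" clause raises only if every attempt fails. Below is a self-contained sketch of that while/else idiom; sample_point, fail_prob and get_data_point are hypothetical stand-ins for illustration, not names from the repository:

import random

def sample_point(rng, fail_prob=0.7):
    # Stand-in for one attempt at building an (input, context) pair;
    # returning None simulates a corner case such as an empty sentence.
    if rng.random() < fail_prob:
        return None
    return "valid data point"

def get_data_point(rng, max_tries=20):
    num_tries = 0
    while num_tries < max_tries:
        num_tries += 1
        point = sample_point(rng)
        if point is None:
            continue  # corner case: draw again
        return point
    else:
        # Runs only when the loop exhausts every try without returning.
        raise RuntimeError("Could not get a valid data point")

print(get_data_point(random.Random(0)))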
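For reference, the context-assembly loop kept by this commit grows a window of sentences symmetrically around the chosen input sentence until the token budget is filled. A toy run of that same loop; the values of sentence_token_lens, input_sentence_idx and padless_max_len are assumptions made up for illustration:

# Toy document: tokens per sentence, with sentence 2 drawn as the input.
sentence_token_lens = [12, 7, 30, 9, 25, 14]
num_sentences = len(sentence_token_lens)
input_sentence_idx = 2
padless_max_len = 60                      # max_seq_len - 2 in the real code
inclusion_mask = [False] * num_sentences  # input sentence removed from context (the 90% case)

view_preceding = True
view_radius = 1
while sum(s for i, s in enumerate(sentence_token_lens) if inclusion_mask[i]) < padless_max_len:
    if view_preceding:
        examine_idx = input_sentence_idx - view_radius
        if examine_idx >= 0:
            inclusion_mask[examine_idx] = True
    else:
        examine_idx = input_sentence_idx + view_radius
        if examine_idx < num_sentences:
            inclusion_mask[examine_idx] = True
        view_radius += 1
    view_preceding = not view_preceding
    if view_radius > num_sentences:
        break

# Which sentences ended up in the context.
print(inclusion_mask)  # -> [True, True, False, True, True, True]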