OpenDAS / Megatron-LM · Commits · d2eabecb

Commit d2eabecb authored Mar 24, 2020 by Neel Kant

Complete __getitem__ for InverseClozeDataset

parent 1c4e8955

Showing 1 changed file with 39 additions and 56 deletions:
megatron/data_utils/datasets.py (+39, -56)
...
@@ -791,7 +791,7 @@ class bert_sentencepair_dataset(data.Dataset):
     def mask_token(self, idx, tokens, types, vocab_words, rng):
         """
         helper function to mask `idx` token from `tokens` according to
-        section 3.3.1 of https://arxiv.org/pdf/1810.04805.pdf
+        section 3.1.1 of https://arxiv.org/pdf/1810.04805.pdf
         """
         label = tokens[idx]
         if rng.random() < 0.8:
...
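Note: the change above only touches the section reference in mask_token's docstring; the masking rule itself (mask 80% of the time, otherwise keep the token or substitute a random word with equal probability) is unchanged. A minimal standalone sketch of that rule, using plain string tokens instead of the tokenizer's command Ids (all names below are illustrative, not part of the commit):

import random

def mask_token_sketch(idx, tokens, vocab_words, rng, mask="[MASK]"):
    # Returns the original token as the prediction label, mirroring the
    # two-draw 80/10/10 rule used in mask_token above.
    label = tokens[idx]
    if rng.random() < 0.8:
        tokens[idx] = mask                     # 80%: replace with [MASK]
    elif rng.random() < 0.5:
        pass                                   # ~10%: keep the original token
    else:
        tokens[idx] = rng.choice(vocab_words)  # ~10%: random vocabulary word
    return label

rng = random.Random(0)
toks = ["the", "cat", "sat"]
print(mask_token_sketch(1, toks, ["dog", "mat"], rng), toks)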
@@ -856,15 +856,12 @@ class InverseClozeDataset(data.Dataset):
     Arguments:
         ds (Dataset or array-like): data corpus to use for training
         max_seq_len (int): maximum sequence length to use for a target sentence
         mask_lm_prob (float): proportion of tokens to mask for masked LM
         max_preds_per_seq (int): Maximum number of masked tokens per sentence pair. Default: math.ceil(max_seq_len*mask_lm_prob/10)*10
         short_seq_prob (float): Proportion of sentence pairs purposefully shorter than max_seq_len
         dataset_size (int): number of random sentencepairs in the dataset. Default: len(ds)*(len(ds)-1)
     """
     def __init__(self, ds, max_seq_len=512, mask_lm_prob=.15, max_preds_per_seq=None, short_seq_prob=.01, dataset_size=None,
...
@@ -877,10 +874,6 @@ class InverseClozeDataset(data.Dataset):
         self.vocab_words = list(self.tokenizer.text_token_vocab.values())
         self.ds.SetTokenizer(None)
         self.max_seq_len = max_seq_len
         self.mask_lm_prob = mask_lm_prob
         if max_preds_per_seq is None:
             max_preds_per_seq = math.ceil(max_seq_len*mask_lm_prob/10)*10
         self.max_preds_per_seq = max_preds_per_seq
         self.short_seq_prob = short_seq_prob
         self.dataset_size = dataset_size
         if self.dataset_size is None:
...
@@ -889,9 +882,6 @@ class InverseClozeDataset(data.Dataset):
         if not self.presplit_sentences:
             nltk.download('punkt', download_dir="./nltk")
         self.weighted = weighted
         self.get_weighting()

     def get_weighting(self):
         if self.weighted:
             if hasattr(self.ds, 'is_lazy') and self.ds.is_lazy:
                 lens = np.array(self.ds.lens)
...
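Note: the nltk.download('punkt', download_dir="./nltk") call above fetches the Punkt model that sentence splitting relies on when documents are not pre-split. A hedged sketch of that splitting step in isolation (this is not the module's code; the search-path line is an assumption about where the local download lands):

import nltk

nltk.download('punkt', download_dir="./nltk")  # one-time download, as in the diff
nltk.data.path.append("./nltk")                # make the local copy discoverable

doc = "The first sentence. The second one follows it."
print(nltk.tokenize.sent_tokenize(doc))
# ['The first sentence.', 'The second one follows it.']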
@@ -907,7 +897,7 @@ class InverseClozeDataset(data.Dataset):
             idx = np_rng.randint(self.total_len)
             return bisect_right(self.weighting, idx)
         else:
-            return np_rng.randint(self.ds_len)
+            return np_rng.randint(self.ds_len - 1)

     def __len__(self):
         return self.dataset_size
...
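Note: the weighted branch above draws a document index in proportion to document length: a position is sampled uniformly over the total token count and then located inside the cumulative-length array with bisect_right. A small self-contained sketch of that idea, with illustrative names (the real code keeps the cumulative array in self.weighting and the total in self.total_len):

import numpy as np
from bisect import bisect_right

def weighted_doc_index(doc_lens, np_rng):
    # Probability of picking document i is doc_lens[i] / sum(doc_lens).
    weighting = np.cumsum(doc_lens)      # e.g. lens [3, 5, 2] -> [3, 8, 10]
    total_len = int(weighting[-1])
    idx = np_rng.randint(total_len)      # uniform position in [0, total_len)
    return bisect_right(weighting, idx)  # index of the document containing it

np_rng = np.random.RandomState(1234)
print(weighted_doc_index([3, 5, 2], np_rng))  # prints 0, 1, or 2; 1 is most likely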
@@ -917,15 +907,24 @@ class InverseClozeDataset(data.Dataset):
         rng = random.Random(idx)
         np_rng = np.random.RandomState(seed=[rng.randint(0, 2**32-1) for _ in range(16)])

-        # get seq length
-        target_seq_length = self.max_seq_len
+        # get seq length. Save 2 tokens for beginning and end
+        target_seq_length = self.max_seq_len - 2
         if rng.random() < self.short_seq_prob:
             target_seq_length = rng.randint(2, target_seq_length)

-        input_data, context_data, doc_idx = self.get_input_and_context(target_seq_length, rng, np_rng)
-        # get other documents too
-        # return sample
+        input_data, context_data = self.get_input_and_context(target_seq_length, rng, np_rng)
+        input_tokens, input_token_types, input_pad_mask = input_data
+        context_tokens, context_token_types, context_pad_mask = context_data
+
+        sample = {'input_text': np.array(input_tokens),
+                  'input_types': np.array(input_token_types),
+                  'input_pad_mask': np.array(input_pad_mask),
+                  'context_text': np.array(context_tokens),
+                  'context_types': np.array(context_token_types),
+                  'context_pad_mask': np.array(context_pad_mask)}
+        return sample

     def get_sentence_split_doc(self, idx):
         """fetch document at index idx and split into sentences"""
...
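Note: with the hunk above, every sample returned by __getitem__ is a dict of fixed-length numpy arrays, which batches cleanly with PyTorch's default collate function. A runnable toy sketch of that contract (the stand-in dataset and token ids below are made up for illustration; they are not part of the commit):

import numpy as np
from torch.utils.data import DataLoader, Dataset

class FakeICTSample(Dataset):
    # Stand-in dataset that mimics the dict contract of the completed
    # __getitem__; ids are arbitrary and much shorter than max_seq_len.
    def __len__(self):
        return 4

    def __getitem__(self, idx):
        return {'input_text': np.array([101, 5, 6, 102, 0]),
                'input_types': np.array([0, 0, 0, 0, 0]),
                'input_pad_mask': np.array([0, 0, 0, 1, 1]),
                'context_text': np.array([101, 7, 8, 9, 102]),
                'context_types': np.array([0, 0, 0, 0, 0]),
                'context_pad_mask': np.array([0, 0, 0, 0, 0])}

# Because every value is a fixed-length numpy array, the default collate
# function stacks them into per-key tensors without custom code.
for batch in DataLoader(FakeICTSample(), batch_size=2):
    print({k: v.shape for k, v in batch.items()})
    break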
@@ -950,17 +949,15 @@ class InverseClozeDataset(data.Dataset):
     def get_input_and_context(self, target_seq_length, rng, np_rng):
         """fetches a sentence and its surrounding context"""
-        doc = doc_idx = None
+        doc = None
         while doc is None:
-            if self.weighted:
-                doc_idx = self.get_weighted_samples(np_rng)
-            else:
-                doc_idx = rng.randint(0, self.ds_len - 1)
+            doc_idx = self.get_weighted_samples(np_rng)
             # doc is a list of sentences
             doc = self.get_sentence_split_doc(doc_idx)
             if not doc:
                 doc = None

         # set up and tokenize the entire selected document
         num_sentences = len(doc)
         all_token_lists = []
         all_token_type_lists = []
...
@@ -972,9 +969,10 @@ class InverseClozeDataset(data.Dataset):
         sentence_token_lens = [len(l) for l in all_token_lists]
         inclusion_mask = [True] * num_sentences

         # select a random sentence from the document as input
         input_sentence_idx = rng.randint(0, len(all_token_lists) - 1)
-        input_sentence_tokens = all_token_lists[input_sentence_idx].copy()
-        input_sentence_token_types = all_token_type_lists[input_sentence_idx].copy()
+        input_tokens = all_token_lists[input_sentence_idx].copy()
+        input_token_types = all_token_type_lists[input_sentence_idx].copy()

         # 10% of the time, the input sentence is left in the context.
         # The other 90% of the time, remove it.
...
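Note: the renamed input_tokens/input_token_types above hold a copy of one randomly chosen sentence; the remaining sentences are assembled into the context in the next hunk by keeping every index whose inclusion_mask entry is still True and flattening with itertools.chain. A toy, runnable sketch of that pattern with made-up token lists:

import itertools

all_token_lists = [[1, 2], [3, 4, 5], [6], [7, 8]]  # toy "sentences"
inclusion_mask = [True, False, True, True]          # sentence 1 was picked as input and dropped

context_tokens = list(itertools.chain(
    *[l for i, l in enumerate(all_token_lists) if inclusion_mask[i]]))
print(context_tokens)  # [1, 2, 6, 7, 8]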
@@ -994,42 +992,27 @@ class InverseClozeDataset(data.Dataset):
                 inclusion_mask[num_sentences - view_radius] = False
             remove_preceding = not remove_preceding

         # assemble the tokens and token types of the context
         context_tokens = list(itertools.chain(*[l for i, l in enumerate(all_token_lists) if inclusion_mask[i]]))
         context_token_types = list(itertools.chain(*[l for i, l in enumerate(all_token_type_lists) if inclusion_mask[i]]))

-        return (input_sentence_tokens, input_sentence_token_types), (context_tokens, context_token_types), doc_idx
-
-    def calc_seq_len(self, max_seq_len):
-        return max_seq_len - 3
-
-    def mask_token(self, idx, tokens, types, vocab_words, rng):
-        """
-        helper function to mask `idx` token from `tokens` according to
-        section 3.3.1 of https://arxiv.org/pdf/1810.04805.pdf
-        """
-        label = tokens[idx]
-        if rng.random() < 0.8:
-            new_label = self.tokenizer.get_command('MASK').Id
-        else:
-            if rng.random() < 0.5:
-                new_label = label
-            else:
-                new_label = rng.choice(vocab_words)
-        tokens[idx] = new_label
-
-        return label
+        # concatenate 'CLS' and 'SEP' tokens and add extra token types
+        input_tokens, input_token_types, input_pad_mask = self.concat_and_pad_tokens(input_tokens, input_token_types)
+        context_tokens, context_token_types, context_pad_mask = self.concat_and_pad_tokens(context_tokens, context_token_types)
+
+        return (input_tokens, input_token_types, input_pad_mask), \
+               (context_tokens, context_token_types, context_pad_mask)

-    def pad_seq(self, seq):
-        """helper function to pad sequence pair"""
-        num_pad = max(0, self.max_seq_len - len(seq))
-        pad_mask = [0] * len(seq) + [1] * num_pad
-        seq += [self.tokenizer.get_command('pad').Id] * num_pad
-        return seq, pad_mask

     def concat_and_pad_tokens(self, tokens, token_types):
         """concat with special tokens and pad sequence to self.max_seq_len"""
         tokens = [self.tokenizer.get_command('ENC').Id] + tokens + [self.tokenizer.get_command('sep').Id]
         token_types = [token_types[0]] + token_types + [token_types[0]]

-    def concat_tokens(self, tokens_a, token_types_a, tokens_b, token_types_b):
-        tokens = [self.tokenizer.get_command('ENC').Id] + tokens_a + [self.tokenizer.get_command('sep').Id] + tokens_b + [self.tokenizer.get_command('sep').Id]
-        token_types = [token_types_a[0]] + token_types_a + [token_types_a[0]] + token_types_b + [token_types_b[0]]
-        return tokens, token_types
+        num_pad = max(0, self.max_seq_len - len(tokens))
+        pad_mask = [0] * len(tokens) + [1] * num_pad
+        tokens += [self.tokenizer.get_command('pad').Id] * num_pad
+        return tokens, token_types, pad_mask
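Note: to make the new padding behaviour concrete, here is a hedged, standalone re-statement of concat_and_pad_tokens as reconstructed above, with integer stand-ins for the tokenizer's 'ENC', 'sep', and 'pad' command Ids (the real values come from self.tokenizer.get_command(...).Id):

ENC, SEP, PAD = 101, 102, 0  # stand-in ids, not the project's real command Ids
max_seq_len = 8

def concat_and_pad_tokens_sketch(tokens, token_types):
    # Wrap with ENC/SEP, extend the token types to match, then right-pad the
    # tokens to max_seq_len and return a pad mask (1 marks padding positions).
    tokens = [ENC] + tokens + [SEP]
    token_types = [token_types[0]] + token_types + [token_types[0]]
    num_pad = max(0, max_seq_len - len(tokens))
    pad_mask = [0] * len(tokens) + [1] * num_pad
    tokens += [PAD] * num_pad
    return tokens, token_types, pad_mask

print(concat_and_pad_tokens_sketch([5, 6, 7], [0, 0, 0]))
# ([101, 5, 6, 7, 102, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 1, 1, 1])

As in the diff, only the token ids are padded: the token-type list keeps its unpadded length, and nothing is truncated when a sequence already exceeds max_seq_len.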