OpenDAS / Megatron-LM · Commit 1dd51c0e
Authored Mar 27, 2020 by Neel Kant
pretrain_bert_ict.py compiles and runs
Parent: b1efc33d

Showing 3 changed files, with 14 additions and 6 deletions (+14 −6):
megatron/data_utils/__init__.py  (+10 −4)
megatron/data_utils/datasets.py  (+2 −0)
pretrain_bert_ict.py             (+2 −2)
megatron/data_utils/__init__.py

@@ -19,7 +19,7 @@ import math
 import torch
 
 from .samplers import DistributedBatchSampler
-from .datasets import json_dataset, csv_dataset, split_ds, ConcatDataset, SplitDataset, bert_sentencepair_dataset, GPT2Dataset
+from .datasets import json_dataset, csv_dataset, split_ds, ConcatDataset, SplitDataset, bert_sentencepair_dataset, GPT2Dataset, InverseClozeDataset
 from .lazy_loader import exists_lazy, make_lazy, lazy_array_loader
 from .tokenization import Tokenization, CommandToken, Tokenizer, CharacterLevelTokenizer, BertWordPieceTokenizer, GPT2BPETokenizer, make_tokenizer
 from . import corpora
@@ -120,14 +120,20 @@ def make_dataset(path, seq_length, text_key, label_key, lazy=False, process_fn=N
         ds = split_ds(ds, split)
         if 'bert' in ds_type.lower():
             presplit_sentences = kwargs['presplit_sentences'] if 'presplit_sentences' in kwargs else False
-            dstype = bert_sentencepair_dataset
-            ds = [dstype(d, max_seq_len=seq_length, presplit_sentences=presplit_sentences) if d is not None else None for d in ds]
+            if 'ict' in ds_type.lower():
+                dstype = InverseClozeDataset
+            else:
+                dstype = bert_sentencepair_dataset
+            ds = [dstype(d, max_seq_len=seq_length, presplit_sentences=presplit_sentences) if d is not None else None for d in ds]
         elif ds_type.lower() == 'gpt2':
             ds = [GPT2Dataset(d, max_seq_len=seq_length) if d is not None else None for d in ds]
     else:
         if 'bert' in ds_type.lower():
             presplit_sentences = kwargs['presplit_sentences'] if 'presplit_sentences' in kwargs else False
-            dstype = bert_sentencepair_dataset
+            if 'ict' in ds_type.lower():
+                dstype = InverseClozeDataset
+            else:
+                dstype = bert_sentencepair_dataset
             ds = dstype(ds, max_seq_len=seq_length, presplit_sentences=presplit_sentences)
         elif ds_type.lower() == 'gpt2':
             ds = GPT2Dataset(ds, max_seq_len=seq_length)
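This hunk threads the new dataset type through make_dataset's string dispatch: any ds_type containing 'bert' takes the BERT branch, and within it a ds_type that also contains 'ict' now selects InverseClozeDataset instead of bert_sentencepair_dataset. Below is a minimal standalone sketch of that selection logic; it returns class names as strings so it runs on its own, and it is an illustration of the branching above, not the repository's API.

def select_dataset_class(ds_type):
    # Mirrors the branch structure in make_dataset above.
    if 'bert' in ds_type.lower():
        if 'ict' in ds_type.lower():
            return 'InverseClozeDataset'
        return 'bert_sentencepair_dataset'
    elif ds_type.lower() == 'gpt2':
        return 'GPT2Dataset'
    raise ValueError('unknown ds_type: %s' % ds_type)

assert select_dataset_class('BERT') == 'bert_sentencepair_dataset'
assert select_dataset_class('gpt2') == 'GPT2Dataset'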
megatron/data_utils/datasets.py

@@ -924,6 +924,7 @@ class InverseClozeDataset(data.Dataset):
             'context_types': np.array(context_token_types),
+            'context_pad_mask': np.array(context_pad_mask)
         }
         return sample
 
     def get_sentence_split_doc(self, idx):
@@ -1015,4 +1016,5 @@ class InverseClozeDataset(data.Dataset):
         num_pad = max(0, self.max_seq_len - len(tokens))
+        pad_mask = [0] * len(tokens) + [1] * num_pad
         tokens += [self.tokenizer.get_command('pad').Id] * num_pad
         token_types += [token_types[0]] * num_pad
         return tokens, token_types, pad_mask
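The added pad_mask line runs before tokens is extended, so len(tokens) is still the unpadded length: real tokens get 0 and padding gets 1, and the mask lines up with the padded sequence. Here is a self-contained sketch of the same helper, with PAD_ID standing in for self.tokenizer.get_command('pad').Id and all other names illustrative.

PAD_ID = 0  # stand-in for self.tokenizer.get_command('pad').Id

def pad_to_max_len(tokens, token_types, max_seq_len):
    # Build the mask first, while len(tokens) is still the unpadded length.
    num_pad = max(0, max_seq_len - len(tokens))
    pad_mask = [0] * len(tokens) + [1] * num_pad
    tokens = tokens + [PAD_ID] * num_pad
    token_types = token_types + [token_types[0]] * num_pad
    return tokens, token_types, pad_mask

tokens, types, mask = pad_to_max_len([101, 7592, 102], [0, 0, 0], max_seq_len=6)
assert mask == [0, 0, 0, 1, 1, 1] and len(tokens) == len(types) == 6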
pretrain_bert_ict.py

@@ -96,7 +96,7 @@ def forward_step(data_iterator, model, args, timers):
                                       context_tokens, 1 - context_pad_mask, context_types)
 
     softmaxed = F.softmax(retrieval_scores, dim=0).float()
-    retrieval_loss = F.cross_entropy(softmaxed, torch.arange(softmaxed.size()[0]))
+    retrieval_loss = F.cross_entropy(softmaxed, torch.arange(softmaxed.size()[0]).cuda())
 
     reduced_losses = reduce_losses([retrieval_loss])
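The fix here is device placement: torch.arange(...) allocates its target indices on the CPU while the model's scores live on the GPU, so F.cross_entropy would raise a device mismatch until the .cuda() call. The targets 0..batch_size-1 encode the usual in-batch ICT objective: query i should score its own context i highest among the batch's contexts. A minimal sketch of that pattern follows, with random scores standing in for the model's query-context match scores; note that F.cross_entropy applies log-softmax internally, whereas the commit's code additionally softmaxes the scores first.

import torch
import torch.nn.functional as F

batch_size = 4
# Hypothetical stand-in for retrieval_scores: entry [i, j] scores query i
# against context j; the diagonal pairs each query with its true context.
retrieval_scores = torch.randn(batch_size, batch_size)

# Query i's positive is context i, so the target for row i is simply i.
# On GPU, these targets would need .cuda() (or device=...) to match the scores,
# which is exactly the one-line fix in the hunk above.
targets = torch.arange(retrieval_scores.size(0))
retrieval_loss = F.cross_entropy(retrieval_scores, targets)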
@@ -114,7 +114,7 @@ def get_train_val_test_data(args):
             or args.data_loader == 'lazy'
             or args.data_loader == 'tfrecords'):
         data_config = configure_data()
-        ds_type = 'BERT'
+        ds_type = 'BERT_ict'
         data_config.set_defaults(data_set_type=ds_type, transpose=False)
         (train_data, val_data, test_data), tokenizer = data_config.apply(args)
         num_tokens = vocab_size_with_padding(tokenizer.num_tokens, args)
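This one-line default is what actually routes pretraining onto the new dataset: 'BERT_ict' satisfies both substring checks in the make_dataset dispatch from the first file, so InverseClozeDataset is constructed. A quick sanity check of that string matching:

ds_type = 'BERT_ict'
# Takes the 'bert' branch, then the nested 'ict' branch -> InverseClozeDataset.
assert 'bert' in ds_type.lower()
assert 'ict' in ds_type.lower()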