Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Megatron-LM
Commits
24034e03
Commit
24034e03
authored
Apr 23, 2020
by
Neel Kant
Browse files
Revise dataset_type
parent
cf0100cf
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
14 additions
and
5 deletions
+14
-5
megatron/data/bert_dataset.py
megatron/data/bert_dataset.py
+12
-4
pretrain_realm.py
pretrain_realm.py
+2
-1
No files found.
megatron/data/bert_dataset.py
View file @
24034e03
...
...
@@ -27,21 +27,27 @@ from megatron import mpu
from
megatron.data.dataset_utils
import
build_training_sample
from
megatron.data.indexed_dataset
import
make_dataset
as
make_indexed_dataset
from
megatron.data.ict_dataset
import
InverseClozeDataset
from
megatron.data.realm_dataset
import
RealmDataset
from
megatron
import
print_rank_0
DATASET_TYPES
=
[
'standard_bert'
,
'ict'
,
'realm'
]
def
build_train_valid_test_datasets
(
data_prefix
,
data_impl
,
splits_string
,
train_valid_test_num_samples
,
max_seq_length
,
masked_lm_prob
,
short_seq_prob
,
seed
,
skip_warmup
,
ict_dataset
=
False
):
dataset_type
=
'standard_bert'
):
if
dataset_type
not
in
DATASET_TYPES
:
raise
ValueError
(
"Invalid dataset_type: "
,
dataset_type
)
# Indexed dataset.
indexed_dataset
=
get_indexed_dataset_
(
data_prefix
,
data_impl
,
skip_warmup
)
if
ict_
dataset
:
if
dataset
_type
==
'ict'
:
title_dataset
=
get_indexed_dataset_
(
data_prefix
+
'-titles'
,
data_impl
,
skip_warmup
)
...
...
@@ -91,18 +97,20 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
seed
=
seed
)
if
ict_
dataset
:
if
dataset
_type
==
'ict'
:
dataset
=
InverseClozeDataset
(
block_dataset
=
indexed_dataset
,
title_dataset
=
title_dataset
,
**
kwargs
)
else
:
dataset
=
BertDataset
(
dataset_cls
=
BertDataset
if
dataset_type
==
'standard_bert'
else
RealmDataset
dataset
=
dataset_cls
(
indexed_dataset
=
indexed_dataset
,
masked_lm_prob
=
masked_lm_prob
,
**
kwargs
)
# Set the original pointer so dataset remains the main dataset.
indexed_dataset
.
set_doc_idx
(
doc_idx_ptr
)
# Checks.
...
...
pretrain_realm.py
View file @
24034e03
...
...
@@ -132,7 +132,8 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
masked_lm_prob
=
args
.
mask_prob
,
short_seq_prob
=
args
.
short_seq_prob
,
seed
=
args
.
seed
,
skip_warmup
=
(
not
args
.
mmap_warmup
))
skip_warmup
=
(
not
args
.
mmap_warmup
),
dataset_type
=
'realm'
)
print_rank_0
(
"> finished creating BERT ICT datasets ..."
)
return
train_ds
,
valid_ds
,
test_ds
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment