Commit 24034e03 authored by Neel Kant
Browse files

Revise dataset_type

parent cf0100cf
......@@ -27,21 +27,27 @@ from megatron import mpu
from megatron.data.dataset_utils import build_training_sample
from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset
from megatron.data.ict_dataset import InverseClozeDataset
from megatron.data.realm_dataset import RealmDataset
from megatron import print_rank_0
DATASET_TYPES = ['standard_bert', 'ict', 'realm']
def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
train_valid_test_num_samples,
max_seq_length, masked_lm_prob,
short_seq_prob, seed, skip_warmup,
ict_dataset=False):
dataset_type='standard_bert'):
if dataset_type not in DATASET_TYPES:
raise ValueError("Invalid dataset_type: ", dataset_type)
# Indexed dataset.
indexed_dataset = get_indexed_dataset_(data_prefix,
data_impl,
skip_warmup)
if ict_dataset:
if dataset_type == 'ict':
title_dataset = get_indexed_dataset_(data_prefix + '-titles',
data_impl,
skip_warmup)
......@@ -91,18 +97,20 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
seed=seed
)
if ict_dataset:
if dataset_type == 'ict':
dataset = InverseClozeDataset(
block_dataset=indexed_dataset,
title_dataset=title_dataset,
**kwargs
)
else:
dataset = BertDataset(
dataset_cls = BertDataset if dataset_type == 'standard_bert' else RealmDataset
dataset = dataset_cls(
indexed_dataset=indexed_dataset,
masked_lm_prob=masked_lm_prob,
**kwargs
)
# Set the original pointer so dataset remains the main dataset.
indexed_dataset.set_doc_idx(doc_idx_ptr)
# Checks.
......
......@@ -132,7 +132,8 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
masked_lm_prob=args.mask_prob,
short_seq_prob=args.short_seq_prob,
seed=args.seed,
skip_warmup=(not args.mmap_warmup))
skip_warmup=(not args.mmap_warmup),
dataset_type='realm')
print_rank_0("> finished creating BERT ICT datasets ...")
return train_ds, valid_ds, test_ds
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment