Commit 24034e03 authored by Neel Kant's avatar Neel Kant
Browse files

Revise dataset_type

parent cf0100cf
...@@ -27,21 +27,27 @@ from megatron import mpu ...@@ -27,21 +27,27 @@ from megatron import mpu
from megatron.data.dataset_utils import build_training_sample from megatron.data.dataset_utils import build_training_sample
from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset
from megatron.data.ict_dataset import InverseClozeDataset from megatron.data.ict_dataset import InverseClozeDataset
from megatron.data.realm_dataset import RealmDataset
from megatron import print_rank_0 from megatron import print_rank_0
DATASET_TYPES = ['standard_bert', 'ict', 'realm']
def build_train_valid_test_datasets(data_prefix, data_impl, splits_string, def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
train_valid_test_num_samples, train_valid_test_num_samples,
max_seq_length, masked_lm_prob, max_seq_length, masked_lm_prob,
short_seq_prob, seed, skip_warmup, short_seq_prob, seed, skip_warmup,
ict_dataset=False): dataset_type='standard_bert'):
if dataset_type not in DATASET_TYPES:
raise ValueError("Invalid dataset_type: ", dataset_type)
# Indexed dataset. # Indexed dataset.
indexed_dataset = get_indexed_dataset_(data_prefix, indexed_dataset = get_indexed_dataset_(data_prefix,
data_impl, data_impl,
skip_warmup) skip_warmup)
if ict_dataset: if dataset_type == 'ict':
title_dataset = get_indexed_dataset_(data_prefix + '-titles', title_dataset = get_indexed_dataset_(data_prefix + '-titles',
data_impl, data_impl,
skip_warmup) skip_warmup)
...@@ -91,18 +97,20 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string, ...@@ -91,18 +97,20 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
seed=seed seed=seed
) )
if ict_dataset: if dataset_type == 'ict':
dataset = InverseClozeDataset( dataset = InverseClozeDataset(
block_dataset=indexed_dataset, block_dataset=indexed_dataset,
title_dataset=title_dataset, title_dataset=title_dataset,
**kwargs **kwargs
) )
else: else:
dataset = BertDataset( dataset_cls = BertDataset if dataset_type == 'standard_bert' else RealmDataset
dataset = dataset_cls(
indexed_dataset=indexed_dataset, indexed_dataset=indexed_dataset,
masked_lm_prob=masked_lm_prob, masked_lm_prob=masked_lm_prob,
**kwargs **kwargs
) )
# Set the original pointer so dataset remains the main dataset. # Set the original pointer so dataset remains the main dataset.
indexed_dataset.set_doc_idx(doc_idx_ptr) indexed_dataset.set_doc_idx(doc_idx_ptr)
# Checks. # Checks.
......
...@@ -132,7 +132,8 @@ def train_valid_test_datasets_provider(train_val_test_num_samples): ...@@ -132,7 +132,8 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
masked_lm_prob=args.mask_prob, masked_lm_prob=args.mask_prob,
short_seq_prob=args.short_seq_prob, short_seq_prob=args.short_seq_prob,
seed=args.seed, seed=args.seed,
skip_warmup=(not args.mmap_warmup)) skip_warmup=(not args.mmap_warmup),
dataset_type='realm')
print_rank_0("> finished creating BERT ICT datasets ...") print_rank_0("> finished creating BERT ICT datasets ...")
return train_ds, valid_ds, test_ds return train_ds, valid_ds, test_ds
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment