"tests/vscode:/vscode.git/clone" did not exist on "3adbfa1812913d3aa3e124c68715d9ab68dc674a"
Commit 16a64c41 authored by Neel Kant's avatar Neel Kant
Browse files

Move get_train_val... to dataset_utils

parent 59031aa7
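
This commit moves `build_train_valid_test_datasets` out of `megatron/data/bert_dataset.py` and into `megatron/data/dataset_utils.py`, and updates the pretrain entry points to match. For callers, the only visible change is the import path; both lines below are taken verbatim from the hunks that follow:

```python
# Before this commit:
from megatron.data.bert_dataset import build_train_valid_test_datasets
# After this commit:
from megatron.data.dataset_utils import build_train_valid_test_datasets
```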
megatron/data/bert_dataset.py
@@ -26,106 +26,11 @@ from megatron import get_tokenizer
from megatron import mpu
from megatron.data.dataset_utils import build_training_sample
from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset
from megatron.data.realm_dataset import InverseClozeDataset
from megatron import print_rank_0
DATASET_TYPES = ['standard_bert', 'ict', 'realm']
def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
train_valid_test_num_samples,
max_seq_length, masked_lm_prob,
short_seq_prob, seed, skip_warmup,
dataset_type='standard_bert'):
if dataset_type not in DATASET_TYPES:
        raise ValueError("Invalid dataset_type: {}".format(dataset_type))
# Indexed dataset.
indexed_dataset = get_indexed_dataset_(data_prefix,
data_impl,
skip_warmup)
if dataset_type == 'ict':
title_dataset = get_indexed_dataset_(data_prefix + '-titles',
data_impl,
skip_warmup)
    # Get start and end indices of train/valid/test into doc-idx
    # Note that doc-idx is designed to be num-docs + 1 so we can
    # easily iterate over it.
total_num_of_documents = indexed_dataset.doc_idx.shape[0] - 1
splits = get_train_valid_test_split_(splits_string, total_num_of_documents)
# Print stats about the splits.
print_rank_0(' > dataset split:')
def print_split_stats(name, index):
print_rank_0(' {}:'.format(name))
print_rank_0(' document indices in [{}, {}) total of {} '
'documents'.format(splits[index], splits[index + 1],
splits[index + 1] - splits[index]))
start_index = indexed_dataset.doc_idx[splits[index]]
end_index = indexed_dataset.doc_idx[splits[index + 1]]
print_rank_0(' sentence indices in [{}, {}) total of {} '
'sentences'.format(start_index, end_index,
end_index - start_index))
print_split_stats('train', 0)
print_split_stats('validation', 1)
print_split_stats('test', 2)
def build_dataset(index, name):
from megatron.data.realm_dataset import RealmDataset
dataset = None
if splits[index + 1] > splits[index]:
# Get the pointer to the original doc-idx so we can set it later.
doc_idx_ptr = indexed_dataset.get_doc_idx()
# Slice the doc-idx
start_index = splits[index]
# Add +1 so we can index into the dataset to get the upper bound.
end_index = splits[index + 1] + 1
# New doc_idx view.
indexed_dataset.set_doc_idx(doc_idx_ptr[start_index:end_index])
# Build the dataset accordingly.
kwargs = dict(
name=name,
data_prefix=data_prefix,
num_epochs=None,
max_num_samples=train_valid_test_num_samples[index],
max_seq_length=max_seq_length,
short_seq_prob=short_seq_prob,
seed=seed
)
if dataset_type == 'ict':
dataset = InverseClozeDataset(
block_dataset=indexed_dataset,
title_dataset=title_dataset,
**kwargs
)
else:
dataset_cls = BertDataset if dataset_type == 'standard_bert' else RealmDataset
dataset = dataset_cls(
indexed_dataset=indexed_dataset,
masked_lm_prob=masked_lm_prob,
**kwargs
)
# Set the original pointer so dataset remains the main dataset.
indexed_dataset.set_doc_idx(doc_idx_ptr)
# Checks.
assert indexed_dataset.doc_idx[0] == 0
assert indexed_dataset.doc_idx.shape[0] == \
(total_num_of_documents + 1)
return dataset
train_dataset = build_dataset(0, 'train')
valid_dataset = build_dataset(1, 'valid')
test_dataset = build_dataset(2, 'test')
return (train_dataset, valid_dataset, test_dataset)
class BertDataset(Dataset):
def __init__(self, name, indexed_dataset, data_prefix,
......
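
The split boundaries printed above come from `get_train_valid_test_split_`, which this diff imports but does not show. Below is a minimal sketch of the kind of computation such a helper performs, assuming comma-separated weights; this is a reconstruction for illustration, not the repo's implementation:

```python
def train_valid_test_split_sketch(splits_string, size):
    """Illustrative only: map weights like '949,50,1' to cumulative
    document-index boundaries [0, a, b, size] over `size` documents."""
    weights = [float(s) for s in splits_string.split(',')]
    weights += [0.0] * (3 - len(weights))  # pad to train/valid/test
    total = sum(weights[:3])
    splits = [0]
    for w in weights[:3]:
        splits.append(splits[-1] + int(round(size * w / total)))
    splits[3] = size  # absorb rounding drift so all documents are covered
    return splits

# train_valid_test_split_sketch('949,50,1', 10000) -> [0, 9490, 9990, 10000]
```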
megatron/data/dataset_utils.py
@@ -22,6 +22,9 @@ import collections
import itertools
import numpy as np
from megatron import print_rank_0
from megatron.data.bert_dataset import DATASET_TYPES, get_indexed_dataset_, get_train_valid_test_split_, BertDataset
from megatron.data.realm_dataset import InverseClozeDataset
def compile_helper():
@@ -406,3 +409,97 @@ def pad_and_convert_to_numpy(tokens, tokentypes, masked_positions,
loss_mask_np = np.array(loss_mask, dtype=np.int64)
return tokens_np, tokentypes_np, labels_np, padding_mask_np, loss_mask_np
def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
train_valid_test_num_samples,
max_seq_length, masked_lm_prob,
short_seq_prob, seed, skip_warmup,
dataset_type='standard_bert'):
if dataset_type not in DATASET_TYPES:
        raise ValueError("Invalid dataset_type: {}".format(dataset_type))
# Indexed dataset.
indexed_dataset = get_indexed_dataset_(data_prefix,
data_impl,
skip_warmup)
if dataset_type == 'ict':
title_dataset = get_indexed_dataset_(data_prefix + '-titles',
data_impl,
skip_warmup)
    # Get start and end indices of train/valid/test into doc-idx
    # Note that doc-idx is designed to be num-docs + 1 so we can
    # easily iterate over it.
total_num_of_documents = indexed_dataset.doc_idx.shape[0] - 1
splits = get_train_valid_test_split_(splits_string, total_num_of_documents)
# Print stats about the splits.
print_rank_0(' > dataset split:')
def print_split_stats(name, index):
print_rank_0(' {}:'.format(name))
print_rank_0(' document indices in [{}, {}) total of {} '
'documents'.format(splits[index], splits[index + 1],
splits[index + 1] - splits[index]))
start_index = indexed_dataset.doc_idx[splits[index]]
end_index = indexed_dataset.doc_idx[splits[index + 1]]
print_rank_0(' sentence indices in [{}, {}) total of {} '
'sentences'.format(start_index, end_index,
end_index - start_index))
print_split_stats('train', 0)
print_split_stats('validation', 1)
print_split_stats('test', 2)
def build_dataset(index, name):
from megatron.data.realm_dataset import RealmDataset
dataset = None
if splits[index + 1] > splits[index]:
# Get the pointer to the original doc-idx so we can set it later.
doc_idx_ptr = indexed_dataset.get_doc_idx()
# Slice the doc-idx
start_index = splits[index]
# Add +1 so we can index into the dataset to get the upper bound.
end_index = splits[index + 1] + 1
# New doc_idx view.
indexed_dataset.set_doc_idx(doc_idx_ptr[start_index:end_index])
# Build the dataset accordingly.
kwargs = dict(
name=name,
data_prefix=data_prefix,
num_epochs=None,
max_num_samples=train_valid_test_num_samples[index],
max_seq_length=max_seq_length,
short_seq_prob=short_seq_prob,
seed=seed
)
if dataset_type == 'ict':
dataset = InverseClozeDataset(
block_dataset=indexed_dataset,
title_dataset=title_dataset,
**kwargs
)
else:
dataset_cls = BertDataset if dataset_type == 'standard_bert' else RealmDataset
dataset = dataset_cls(
indexed_dataset=indexed_dataset,
masked_lm_prob=masked_lm_prob,
**kwargs
)
# Set the original pointer so dataset remains the main dataset.
indexed_dataset.set_doc_idx(doc_idx_ptr)
# Checks.
assert indexed_dataset.doc_idx[0] == 0
assert indexed_dataset.doc_idx.shape[0] == \
(total_num_of_documents + 1)
return dataset
train_dataset = build_dataset(0, 'train')
valid_dataset = build_dataset(1, 'valid')
test_dataset = build_dataset(2, 'test')
return (train_dataset, valid_dataset, test_dataset)
\ No newline at end of file
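
A detail worth noting in the function above is the slice-and-restore pattern around `doc_idx`: the shared indexed dataset is temporarily narrowed to one split's documents while that split's dataset object is built, then the original pointer is restored and re-checked. Here is a toy illustration of the same pattern, using an invented stand-in class rather than the real `indexed_dataset`:

```python
import numpy as np

class ToyIndexedDataset:
    """Invented stand-in: models only the doc_idx get/set behavior
    that build_train_valid_test_datasets relies on."""
    def __init__(self, num_docs, sents_per_doc=10):
        # doc_idx has num_docs + 1 entries so document i spans
        # [doc_idx[i], doc_idx[i + 1]) in the flat sentence index.
        self.doc_idx = np.arange(num_docs + 1) * sents_per_doc

    def get_doc_idx(self):
        return self.doc_idx

    def set_doc_idx(self, doc_idx):
        self.doc_idx = doc_idx

ds = ToyIndexedDataset(num_docs=100)
splits = [0, 80, 95, 100]          # train/valid/test document boundaries
index = 0                          # build the 'train' split

doc_idx_ptr = ds.get_doc_idx()     # keep the original pointer
# Narrow the view; the +1 keeps the upper bound of the last document.
ds.set_doc_idx(doc_idx_ptr[splits[index]:splits[index + 1] + 1])
# ... the split's Dataset object would be constructed here ...
ds.set_doc_idx(doc_idx_ptr)        # restore the full view
assert ds.doc_idx[0] == 0 and ds.doc_idx.shape[0] == 100 + 1
```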
pretrain_bert.py
@@ -22,7 +22,7 @@ from megatron import get_args
from megatron import get_timers
from megatron import mpu
from megatron import print_rank_0
from megatron.data.bert_dataset import build_train_valid_test_datasets
from megatron.data.dataset_utils import build_train_valid_test_datasets
from megatron.model import BertModel
from megatron.training import pretrain
from megatron.utils import reduce_losses
......
pretrain_ict.py
@@ -22,7 +22,7 @@ from megatron import get_args
from megatron import get_timers
from megatron import mpu
from megatron import print_rank_0
from megatron.data.bert_dataset import build_train_valid_test_datasets
from megatron.data.dataset_utils import build_train_valid_test_datasets
from megatron.model import ICTBertModel
from megatron.training import pretrain
from megatron.utils import reduce_losses
......
pretrain_realm.py
@@ -24,7 +24,7 @@ from megatron import get_args
from megatron import get_timers
from megatron import mpu
from megatron import print_rank_0
from megatron.data.bert_dataset import build_train_valid_test_datasets
from megatron.data.dataset_utils import build_train_valid_test_datasets
from megatron.model import REALMBertModel, REALMRetriever
from megatron.training import pretrain
from megatron.utils import reduce_losses
......
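
The three hunks above change nothing but the import path in the pretrain entry points. For orientation, here is a hedged sketch of how such a script typically wires the function into a dataset provider; the provider body and the argument names (`args.data_path`, `args.split`, etc.) are assumptions, since those bodies are elided from this diff:

```python
from megatron import get_args
from megatron.data.dataset_utils import build_train_valid_test_datasets

def train_valid_test_datasets_provider(train_val_test_num_samples):
    # Hypothetical provider: only the import path above is what this
    # commit actually changes in these files.
    args = get_args()
    return build_train_valid_test_datasets(
        data_prefix=args.data_path,       # assumed argument names
        data_impl=args.data_impl,
        splits_string=args.split,
        train_valid_test_num_samples=train_val_test_num_samples,
        max_seq_length=args.seq_length,
        masked_lm_prob=args.mask_prob,
        short_seq_prob=args.short_seq_prob,
        seed=args.seed,
        skip_warmup=(not args.mmap_warmup),
        dataset_type='standard_bert',     # or 'ict' / 'realm'
    )
```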