Commit 3354081f authored by Neel Kant

More ict_merge changes and interactive testing

parent 6495927e
......@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from .package_info import (
__description__,
......@@ -30,7 +31,6 @@ from .global_vars import get_tensorboard_writer
from .global_vars import get_adlr_autoresume
from .global_vars import get_timers
import torch
def print_rank_0(message):
"""If distributed is initialized print only on rank 0."""
......@@ -38,4 +38,4 @@ def print_rank_0(message):
if torch.distributed.get_rank() == 0:
print(message, flush=True)
else:
print(message, flush=True)
print(message, flush=True)
\ No newline at end of file
......@@ -23,9 +23,8 @@ import numpy as np
import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from megatron import mpu
from megatron import mpu, print_rank_0
from megatron import get_args
from megatron import print_rank_0
def check_checkpoint_args(checkpoint_args):
......
......@@ -22,15 +22,13 @@ import numpy as np
import torch
from torch.utils.data import Dataset
from megatron import get_tokenizer, get_args
from megatron import get_tokenizer, get_args, print_rank_0
from megatron import mpu
from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset
from megatron.data.dataset_utils import get_a_and_b_segments
from megatron.data.dataset_utils import truncate_segments
from megatron.data.dataset_utils import create_tokens_and_tokentypes
from megatron.data.dataset_utils import pad_and_convert_to_numpy
from megatron.data.dataset_utils import create_masked_lm_predictions
from megatron import print_rank_0
class BertDataset(Dataset):
......@@ -85,55 +83,6 @@ class BertDataset(Dataset):
self.masked_lm_prob, np_rng)
def get_indexed_dataset_(data_prefix, data_impl, skip_warmup):
print_rank_0(' > building dataset index ...')
start_time = time.time()
indexed_dataset = make_indexed_dataset(data_prefix,
data_impl,
skip_warmup)
assert indexed_dataset.sizes.shape[0] == indexed_dataset.doc_idx[-1]
print_rank_0(' > finished creating indexed dataset in {:4f} '
'seconds'.format(time.time() - start_time))
print_rank_0(' > indexed dataset stats:')
print_rank_0(' number of documents: {}'.format(
indexed_dataset.doc_idx.shape[0] - 1))
print_rank_0(' number of sentences: {}'.format(
indexed_dataset.sizes.shape[0]))
return indexed_dataset
def get_train_valid_test_split_(splits_string, size):
""" Get dataset splits from comma or '/' separated string list."""
splits = []
if splits_string.find(',') != -1:
splits = [float(s) for s in splits_string.split(',')]
elif splits_string.find('/') != -1:
splits = [float(s) for s in splits_string.split('/')]
else:
splits = [float(splits_string)]
while len(splits) < 3:
splits.append(0.)
splits = splits[:3]
splits_sum = sum(splits)
assert splits_sum > 0.0
splits = [split / splits_sum for split in splits]
splits_index = [0]
for index, split in enumerate(splits):
splits_index.append(splits_index[index] +
int(round(split * float(size))))
diff = splits_index[-1] - size
for index in range(1, len(splits_index)):
splits_index[index] -= diff
assert len(splits_index) == 4
assert splits_index[-1] == size
return splits_index
def get_samples_mapping_(indexed_dataset,
data_prefix,
num_epochs,
......
......@@ -18,17 +18,19 @@
# https://github.com/google-research/albert/blob/master/create_pretraining_data.py
# with some modifications.
import time
import collections
import itertools
import numpy as np
from megatron import print_rank_0, get_args
from megatron import get_args, print_rank_0
from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset
DSET_TYPE_STD = 'standard_bert'
DSET_TYPE_ICT = 'ict'
DSET_TYPES = [DSET_TYPE_ICT, DSET_TYPE_STD]
def compile_helper():
"""Compile helper function ar runtime. Make sure this
is invoked on a single process."""
......@@ -447,3 +449,54 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
test_dataset = build_dataset(2, 'test')
return (train_dataset, valid_dataset, test_dataset)
def get_indexed_dataset_(data_prefix, data_impl, skip_warmup):
print_rank_0(' > building dataset index ...')
start_time = time.time()
indexed_dataset = make_indexed_dataset(data_prefix,
data_impl,
skip_warmup)
assert indexed_dataset.sizes.shape[0] == indexed_dataset.doc_idx[-1]
print_rank_0(' > finished creating indexed dataset in {:4f} '
'seconds'.format(time.time() - start_time))
print_rank_0(' > indexed dataset stats:')
print_rank_0(' number of documents: {}'.format(
indexed_dataset.doc_idx.shape[0] - 1))
print_rank_0(' number of sentences: {}'.format(
indexed_dataset.sizes.shape[0]))
return indexed_dataset
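For reference, a tiny illustration (with made-up numbers, not from the patch) of the layout assumed by the assert in get_indexed_dataset_ above: doc_idx holds cumulative sentence counts per document, so its last entry must equal the number of sentence sizes.
# Hypothetical index layout, only to illustrate the invariant checked above.
doc_idx = [0, 3, 5, 9]                    # 4 entries -> 3 documents
sizes = [12, 7, 30, 4, 9, 15, 22, 8, 11]  # 9 sentence lengths
assert len(sizes) == doc_idx[-1]          # mirrors sizes.shape[0] == doc_idx[-1]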
def get_train_valid_test_split_(splits_string, size):
""" Get dataset splits from comma or '/' separated string list."""
splits = []
if splits_string.find(',') != -1:
splits = [float(s) for s in splits_string.split(',')]
elif splits_string.find('/') != -1:
splits = [float(s) for s in splits_string.split('/')]
else:
splits = [float(splits_string)]
while len(splits) < 3:
splits.append(0.)
splits = splits[:3]
splits_sum = sum(splits)
assert splits_sum > 0.0
splits = [split / splits_sum for split in splits]
splits_index = [0]
for index, split in enumerate(splits):
splits_index.append(splits_index[index] +
int(round(split * float(size))))
diff = splits_index[-1] - size
for index in range(1, len(splits_index)):
splits_index[index] -= diff
assert len(splits_index) == 4
assert splits_index[-1] == size
return splits_index
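A minimal usage sketch of get_train_valid_test_split_ as defined above; the split string and size are hypothetical. The fractions are normalized, turned into cumulative boundaries, and the final boundary is shifted so it lands exactly on size.
# Hypothetical call: a '949,50,1' split over 1000 samples.
boundaries = get_train_valid_test_split_('949,50,1', 1000)
assert boundaries == [0, 949, 999, 1000]  # train, valid, test index ranges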
......@@ -21,8 +21,7 @@ import time
import numpy as np
import torch
from megatron import print_rank_0
from megatron import mpu
from megatron import mpu, print_rank_0
from megatron.data.bert_dataset import get_train_valid_test_split_
from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset
......
......@@ -401,7 +401,8 @@ py::array build_blocks_mapping_impl(const py::array_t<int64_t>& docs_,
const uint64_t max_num_samples,
const int32_t max_seq_length,
const int32_t seed,
const bool verbose) {
const bool verbose,
const bool use_one_sent_blocks) {
/* Build a mapping of (start-index, end-index, sequence-length) where
start and end index are the indices of the sentences in the sample
and sequence-length is the target sequence length.
......@@ -442,6 +443,12 @@ py::array build_blocks_mapping_impl(const py::array_t<int64_t>& docs_,
int64_t num_samples = -1;
DocIdx* maps = NULL;
// Acceptable number of sentences per block.
int min_num_sent = 2;
if (use_one_sent_blocks) {
min_num_sent = 1;
}
// Perform two iterations, in the first iteration get the size
// and allocate memory and in the second iteration populate the map.
bool second = false;
......@@ -453,6 +460,9 @@ py::array build_blocks_mapping_impl(const py::array_t<int64_t>& docs_,
// Current map index.
uint64_t map_index = 0;
uint64_t empty_docs = 0;
uint64_t one_sent_docs = 0;
uint64_t long_sent_docs = 0;
// For each epoch:
for (int32_t epoch=0; epoch<num_epochs; ++epoch) {
// assign every block a unique id
......@@ -480,19 +490,31 @@ py::array build_blocks_mapping_impl(const py::array_t<int64_t>& docs_,
// Remaining documents.
auto num_remain_sent = sent_index_last - sent_index_first;
// Some bookkeeping
if ((epoch == 0) && (!second)) {
if (num_remain_sent == 0) {
++empty_docs;
}
if (num_remain_sent == 1) {
++one_sent_docs;
}
}
// Detect documents with long sentences.
bool contains_long_sentence = false;
if (num_remain_sent > 1) {
if (num_remain_sent >= min_num_sent) {
for (auto sent_index=sent_index_first;
sent_index < sent_index_last; ++sent_index) {
if (sizes[sent_index] > LONG_SENTENCE_LEN){
if ((epoch == 0) && (!second)) {
++long_sent_docs;
}
contains_long_sentence = true;
break;
}
}
}
// If we have more than two sentences.
if ((num_remain_sent > 1) && (!contains_long_sentence)) {
// If we have enough sentences and no long sentences.
if ((num_remain_sent >= min_num_sent) && (!contains_long_sentence)) {
// Set values.
auto seq_len = int32_t{0};
......@@ -508,12 +530,12 @@ py::array build_blocks_mapping_impl(const py::array_t<int64_t>& docs_,
--num_remain_sent;
// If we have reached the target length.
// and if not only one sentence is left in the document.
// and if we have at least two sentences.
// and there are an acceptable number of sentences left
// and if we have at least the minimum number of sentences.
// or if we have reached end of the document.
if (((seq_len >= target_seq_len) &&
(num_remain_sent > 1) &&
(num_sent > 1) ) || (num_remain_sent == 0)) {
(num_remain_sent >= min_num_sent) &&
(num_sent >= min_num_sent) ) || (num_remain_sent == 0)) {
// Populate the map.
if (second) {
......@@ -538,11 +560,16 @@ py::array build_blocks_mapping_impl(const py::array_t<int64_t>& docs_,
} // for (auto sent_index=sent_index_first; ...
} // if (num_remain_sent > 1) {
} // for (int doc=0; doc < num_docs; ++doc) {
block_id = 0;
} // for (int epoch=0; epoch < num_epochs; ++epoch) {
if (!second) {
if (verbose) {
cout << " number of empty documents: " << empty_docs <<
endl << std::flush;
cout << " number of documents with one sentence: " <<
one_sent_docs << endl << std::flush;
cout << " number of documents with long sentences: " <<
long_sent_docs << endl << std::flush;
cout << " will create mapping for " << map_index <<
" samples" << endl << std::flush;
}
......@@ -554,9 +581,9 @@ py::array build_blocks_mapping_impl(const py::array_t<int64_t>& docs_,
} // for (int iteration=0; iteration < 2; ++iteration) {
// Shuffle.
// We need a 64 bit random number generator as we might have more
// than 2 billion samples.
std::mt19937_64 rand64_gen(seed + 1);
for (auto i=(num_samples - 1); i > 0; --i) {
const auto j = static_cast<int64_t>(rand64_gen() % (i + 1));
......@@ -591,20 +618,21 @@ py::array build_blocks_mapping(const py::array_t<int64_t>& docs_,
const uint64_t max_num_samples,
const int max_seq_length,
const int seed,
const bool verbose) {
const bool verbose,
const bool use_one_sent_blocks) {
if (sizes_.size() > std::numeric_limits<uint32_t>::max()) {
if (verbose) {
cout << " using uint64 for data mapping..." << endl << std::flush;
}
return build_blocks_mapping_impl<uint64_t>(docs_, sizes_, titles_sizes_,
num_epochs, max_num_samples, max_seq_length, seed, verbose);
num_epochs, max_num_samples, max_seq_length, seed, verbose, use_one_sent_blocks);
} else {
if (verbose) {
cout << " using uint32 for data mapping..." << endl << std::flush;
}
return build_blocks_mapping_impl<uint32_t>(docs_, sizes_, titles_sizes_,
num_epochs, max_num_samples, max_seq_length, seed, verbose);
num_epochs, max_num_samples, max_seq_length, seed, verbose, use_one_sent_blocks);
}
}
......
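A rough Python sketch (not part of the patch) of the minimum-sentence gate that the new use_one_sent_blocks flag introduces in build_blocks_mapping_impl: single-sentence documents are only admitted when the flag is set.
def doc_can_form_block(num_remain_sent, use_one_sent_blocks=False):
    # With the flag set, one remaining sentence is enough; otherwise a document
    # must still contain at least two sentences to contribute a block.
    min_num_sent = 1 if use_one_sent_blocks else 2
    return num_remain_sent >= min_num_sent
# doc_can_form_block(1) -> False; doc_can_form_block(1, use_one_sent_blocks=True) -> True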
......@@ -65,6 +65,7 @@ class ICTDataset(Dataset):
query_tokens, query_pad_mask = self.concat_and_pad_tokens(query)
block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)
print(self.tokenizer.decode_token_ids(block_tokens), '\n')
block_data = np.array([start_idx, end_idx, doc_idx, block_idx]).astype(np.int64)
sample = {
......
......@@ -4,7 +4,7 @@ import time
import numpy as np
import torch
from megatron import print_rank_0, mpu
from megatron import mpu, print_rank_0
def join_str_list(str_list):
......@@ -19,7 +19,7 @@ def join_str_list(str_list):
def get_block_samples_mapping(block_dataset, title_dataset, data_prefix, num_epochs,
max_num_samples, max_seq_length, seed, name):
max_num_samples, max_seq_length, seed, name, use_one_sent_docs=False):
"""Get samples mapping for a dataset over fixed size blocks. This function also requires
a dataset of the titles for the source documents since their lengths must be taken into account."""
if not num_epochs:
......@@ -39,6 +39,8 @@ def get_block_samples_mapping(block_dataset, title_dataset, data_prefix, num_epo
indexmap_filename += '_{}mns'.format(max_num_samples)
indexmap_filename += '_{}msl'.format(max_seq_length)
indexmap_filename += '_{}s'.format(seed)
if use_one_sent_docs:
indexmap_filename += '_1sentok'
indexmap_filename += '.npy'
# Build the indexed mapping if not exist.
......@@ -67,7 +69,8 @@ def get_block_samples_mapping(block_dataset, title_dataset, data_prefix, num_epo
max_num_samples,
max_seq_length-3, # account for added tokens
seed,
verbose)
verbose,
use_one_sent_docs)
print_rank_0(' > done building samples index mapping')
np.save(indexmap_filename, samples_mapping, allow_pickle=True)
print_rank_0(' > saved the index mapping in {}'.format(
......
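An illustration, with made-up argument values, of how the new use_one_sent_docs flag is folded into the cached index-map filename so that mappings built with and without single-sentence documents never collide; only the suffix logic visible in the diff is reproduced.
# Hypothetical values for max_num_samples, max_seq_length, seed.
use_one_sent_docs = True
suffix = '_{}mns_{}msl_{}s'.format(10000, 288, 1234)
if use_one_sent_docs:
    suffix += '_1sentok'
suffix += '.npy'
# suffix == '_10000mns_288msl_1234s_1sentok.npy'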
......@@ -17,7 +17,7 @@
import torch
from megatron import get_args
from megatron import get_args, print_rank_0
from megatron.model.bert_model import bert_attention_mask_func
from megatron.model.bert_model import bert_extended_attention_mask
from megatron.model.bert_model import bert_position_ids
......@@ -26,7 +26,6 @@ from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal
from megatron.model.utils import scaled_init_method_normal
from megatron.module import MegatronModule
from megatron import print_rank_0
class Classification(MegatronModule):
......
......@@ -17,7 +17,7 @@
import torch
from megatron import get_args
from megatron import get_args, print_rank_0
from megatron.model.bert_model import bert_attention_mask_func
from megatron.model.bert_model import bert_extended_attention_mask
from megatron.model.bert_model import bert_position_ids
......@@ -26,7 +26,6 @@ from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal
from megatron.model.utils import scaled_init_method_normal
from megatron.module import MegatronModule
from megatron import print_rank_0
class MultipleChoice(MegatronModule):
......
......@@ -125,7 +125,7 @@ class ICTBertModel(MegatronModule):
class IREncoderBertModel(MegatronModule):
"""Bert Language model."""
"""BERT-based encoder for queries or blocks used for learned information retrieval."""
def __init__(self, ict_head_size, num_tokentypes=2, parallel_output=True):
super(IREncoderBertModel, self).__init__()
args = get_args()
......@@ -158,9 +158,8 @@ class IREncoderBertModel(MegatronModule):
tokentype_ids=tokentype_ids)
# Output.
if self.add_ict_head:
ict_logits = self.ict_head(pooled_output)
return ict_logits, None
ict_logits = self.ict_head(pooled_output)
return ict_logits, None
def state_dict_for_save_checkpoint(self, destination=None, prefix='',
keep_vars=False):
......
......@@ -20,7 +20,6 @@ from abc import abstractmethod
from .bert_tokenization import FullTokenizer as FullBertTokenizer
from .gpt2_tokenization import GPT2Tokenizer
from megatron.data.realm_dataset_utils import join_str_list
def build_tokenizer(args):
......@@ -160,8 +159,15 @@ class _BertWordPieceTokenizer(AbstractTokenizer):
tokens = self.tokenizer.convert_ids_to_tokens(token_ids)
exclude_list = ['[PAD]', '[CLS]']
non_pads = [t for t in tokens if t not in exclude_list]
joined_strs = join_str_list(non_pads)
return joined_strs
result = ""
for s in non_pads:
if s.startswith("##"):
result += s[2:]
else:
result += " " + s
return result
@property
def cls(self):
......
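A standalone sketch of the WordPiece-joining loop added to decode_token_ids above, using assumed token strings; as in the patched code, the result keeps a leading space.
def join_wordpieces(tokens):
    # Pieces starting with '##' continue the previous word; any other piece
    # starts a new, space-separated word.
    result = ""
    for t in tokens:
        if t.startswith("##"):
            result += t[2:]
        else:
            result += " " + t
    return result
# join_wordpieces(['token', '##izer', 'works']) -> ' tokenizer works'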
......@@ -22,11 +22,10 @@ import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from apex.optimizers import FusedAdam as Adam
from megatron import get_args
from megatron import get_args, print_rank_0
from megatron import get_timers
from megatron import get_tensorboard_writer
from megatron import mpu
from megatron import print_rank_0
from megatron.checkpointing import load_checkpoint
from megatron.checkpointing import save_checkpoint
from megatron.fp16 import FP16_Module
......
......@@ -19,10 +19,9 @@ import sys
import torch
from megatron import get_args
from megatron import get_args, print_rank_0
from megatron import get_adlr_autoresume
from megatron import mpu
from megatron import print_rank_0
from megatron.checkpointing import save_checkpoint
from megatron.data.samplers import DistributedBatchSampler
from megatron.fp16 import FP16_Optimizer
......@@ -173,3 +172,5 @@ def get_ltor_masks_and_position_ids(data,
attention_mask = (attention_mask < 0.5)
return attention_mask, loss_mask, position_ids
......@@ -18,10 +18,9 @@
import torch
import torch.nn.functional as F
from megatron import get_args
from megatron import get_args, print_rank_0
from megatron import get_timers
from megatron import mpu
from megatron import print_rank_0
from megatron.data.dataset_utils import build_train_valid_test_datasets
from megatron.model import BertModel
from megatron.training import pretrain
......
......@@ -19,10 +19,9 @@ import torch
import torch.distributed as dist
import torch.nn.functional as F
from megatron import get_args
from megatron import get_args, print_rank_0
from megatron import get_timers
from megatron import mpu
from megatron import print_rank_0
from megatron.data.dataset_utils import build_train_valid_test_datasets
from megatron.model import ICTBertModel
from megatron.training import pretrain
......
......@@ -17,11 +17,10 @@
import torch
from megatron import get_args
from megatron import get_args, print_rank_0
from megatron import get_timers
from megatron import get_tokenizer
from megatron import mpu
from megatron import print_rank_0
from megatron.data.gpt2_dataset import build_train_valid_test_datasets
from megatron.model import GPT2Model
from megatron.training import pretrain
......
......@@ -20,9 +20,8 @@ import time
import torch
from megatron import get_args
from megatron import get_args, print_rank_0
from megatron import mpu
from megatron import print_rank_0
from tasks.finetune_utils import build_data_loader
from tasks.finetune_utils import process_batch
......
......@@ -17,10 +17,9 @@
import torch
from megatron import get_args
from megatron import get_args, print_rank_0
from megatron import get_timers
from megatron import mpu
from megatron import print_rank_0
from megatron.checkpointing import load_checkpoint
from megatron.checkpointing import save_checkpoint
from megatron.training import evaluate_and_print_results
......
......@@ -15,9 +15,8 @@
"""GLUE finetuning/evaluation."""
from megatron import get_args
from megatron import get_args, print_rank_0
from megatron import get_tokenizer
from megatron import print_rank_0
from megatron.model.classification import Classification
from tasks.eval_utils import accuracy_func_provider
from tasks.finetune_utils import finetune
......