Commit c17d880c authored by Neel Kant
Browse files

Debug null document and exclude trivial candidate

parent efcee158
......@@ -44,7 +44,7 @@ def check_checkpoint_args(checkpoint_args):
_compare('num_layers')
_compare('hidden_size')
_compare('num_attention_heads')
_compare('max_position_embeddings')
# _compare('max_position_embeddings')
_compare('make_vocab_size_divisible_by')
_compare('padded_vocab_size')
_compare('tokenizer_type')
......
......@@ -425,7 +425,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
data_impl,
skip_warmup)
if dataset_type == 'ict':
if dataset_type in ['ict', 'realm']:
title_dataset = get_indexed_dataset_(data_prefix + '-titles',
data_impl,
skip_warmup)
......
......@@ -92,7 +92,7 @@ class REALMDataset(Dataset):
self.pad_id,
self.masked_lm_prob,
np_rng)
sample.update({'query_block_indices': np.array([block_idx])})
sample.update({'query_block_indices': np.array([block_idx]).astype(np.int64)})
return sample
def get_samples_mapping(self, data_prefix, num_epochs, max_num_samples):
......
......@@ -12,7 +12,7 @@ class REALMBertModel(MegatronModule):
def __init__(self, retriever):
super(REALMBertModel, self).__init__()
bert_args = dict(
num_tokentypes=1,
num_tokentypes=2,
add_binary_head=False,
parallel_output=True
)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment