Commit c17d880c authored by Neel Kant
Browse files

Debug null document and exclude trivial candidate

parent efcee158
...@@ -44,7 +44,7 @@ def check_checkpoint_args(checkpoint_args): ...@@ -44,7 +44,7 @@ def check_checkpoint_args(checkpoint_args):
_compare('num_layers') _compare('num_layers')
_compare('hidden_size') _compare('hidden_size')
_compare('num_attention_heads') _compare('num_attention_heads')
_compare('max_position_embeddings') # _compare('max_position_embeddings')
_compare('make_vocab_size_divisible_by') _compare('make_vocab_size_divisible_by')
_compare('padded_vocab_size') _compare('padded_vocab_size')
_compare('tokenizer_type') _compare('tokenizer_type')
......
...@@ -425,7 +425,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string, ...@@ -425,7 +425,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
data_impl, data_impl,
skip_warmup) skip_warmup)
if dataset_type == 'ict': if dataset_type in ['ict', 'realm']:
title_dataset = get_indexed_dataset_(data_prefix + '-titles', title_dataset = get_indexed_dataset_(data_prefix + '-titles',
data_impl, data_impl,
skip_warmup) skip_warmup)
......
...@@ -92,7 +92,7 @@ class REALMDataset(Dataset): ...@@ -92,7 +92,7 @@ class REALMDataset(Dataset):
self.pad_id, self.pad_id,
self.masked_lm_prob, self.masked_lm_prob,
np_rng) np_rng)
sample.update({'query_block_indices': np.array([block_idx])}) sample.update({'query_block_indices': np.array([block_idx]).astype(np.int64)})
return sample return sample
def get_samples_mapping(self, data_prefix, num_epochs, max_num_samples): def get_samples_mapping(self, data_prefix, num_epochs, max_num_samples):
......
...@@ -12,7 +12,7 @@ class REALMBertModel(MegatronModule): ...@@ -12,7 +12,7 @@ class REALMBertModel(MegatronModule):
def __init__(self, retriever): def __init__(self, retriever):
super(REALMBertModel, self).__init__() super(REALMBertModel, self).__init__()
bert_args = dict( bert_args = dict(
num_tokentypes=1, num_tokentypes=2,
add_binary_head=False, add_binary_head=False,
parallel_output=True parallel_output=True
) )
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment