Commit 2a3b445d authored by Neel Kant

Cosmetic changes

parent ac967fa0
@@ -65,7 +65,6 @@ class ICTDataset(Dataset):
         query_tokens, query_pad_mask = self.concat_and_pad_tokens(query)
         block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)
-        print(self.tokenizer.decode_token_ids(block_tokens), '\n')
         block_data = np.array([start_idx, end_idx, doc_idx, block_idx]).astype(np.int64)
         sample = {
@@ -33,8 +33,11 @@ num_batches = 0
 def general_model_provider(only_query_model=False, only_block_model=False):
     """Build the model."""
     args = get_args()
-    if args.ict_head_size is None:
-        raise ValueError("Need to specify --ict-head-size to provide an ICTBertModel")
+    assert args.ict_head_size is not None, \
+        "Need to specify --ict-head-size to provide an ICTBertModel"
+    assert args.model_parallel_size == 1, \
+        "Model parallel size > 1 not supported for ICT"
     print_rank_0('building ICTBertModel...')
@@ -89,7 +92,6 @@ def forward_step(data_iterator, model):
     timers('batch generator').stop()
     # Forward model.
-    # retrieval_scores = model(query_tokens, query_pad_mask, block_tokens, block_pad_mask).float()
     query_logits, block_logits = model(query_tokens, query_pad_mask, block_tokens, block_pad_mask)
     data_parallel_size = dist.get_world_size() / args.model_parallel_size
@@ -100,11 +102,11 @@ def forward_step(data_iterator, model):
     all_query_logits = torch.zeros(all_logits_shape).type(query_logits.dtype).cuda()
     all_block_logits = all_query_logits.clone().cuda()
-    # record this processes' data and then merge with other processes below
+    # record this process's data
     all_query_logits[args.rank * batch_size:(args.rank + 1) * batch_size] = query_logits
     all_block_logits[args.rank * batch_size:(args.rank + 1) * batch_size] = block_logits
-    # currently this assumes model parallel size == 1.
+    # merge data from all processes
     dist.all_reduce(all_query_logits)
     dist.all_reduce(all_block_logits)
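The slice assignments plus all_reduce in this hunk implement a gather: every data-parallel rank zero-fills the global buffer, writes its own rows into a disjoint slice, and the elementwise sum merges the slices into the full batch. A minimal standalone sketch of that pattern, assuming torch.distributed is already initialized and every rank holds batch_size rows of 2-D logits (the function name is illustrative, not from this diff):

    import torch
    import torch.distributed as dist

    def gather_logits(local_logits):
        """Merge per-rank logits into one global tensor via all_reduce.

        Each rank writes its rows into a disjoint slice of a zero-filled
        buffer, so summing the buffers across ranks is equivalent to
        concatenating the per-rank slices.
        """
        world_size = dist.get_world_size()
        rank = dist.get_rank()
        batch_size = local_logits.size(0)

        all_logits = torch.zeros(world_size * batch_size, local_logits.size(1),
                                 dtype=local_logits.dtype,
                                 device=local_logits.device)
        all_logits[rank * batch_size:(rank + 1) * batch_size] = local_logits
        dist.all_reduce(all_logits)  # default reduce op is SUM
        return all_logits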
@@ -153,6 +155,5 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
 if __name__ == "__main__":
     pretrain(train_valid_test_datasets_provider, model_provider, forward_step,
              args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
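Once the query and block logits are gathered across ranks (the commented-out retrieval_scores line removed above hints at this step), the ICT objective is typically an in-batch softmax retrieval loss: query i should score its own block i highest among all blocks in the global batch. A hedged sketch under that pairing assumption (the function name is hypothetical):

    import torch
    import torch.nn.functional as F

    def retrieval_loss(all_query_logits, all_block_logits):
        # Score every query against every block in the global batch.
        retrieval_scores = torch.matmul(all_query_logits, all_block_logits.t())
        # Query i's positive example is block i, so targets lie on the diagonal.
        targets = torch.arange(retrieval_scores.size(0),
                               device=retrieval_scores.device)
        return F.cross_entropy(retrieval_scores, targets)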