Commit 5235ed87 authored by Neel Kant

Simplify batch and forward for ICT dataset and model

parent aae93362
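In short, this commit drops the token-type tensors from the ICT batch and from ICTBertModel.forward (token types are now built as zeros inside the model), flips the pad-mask convention to 1 for real tokens and 0 for padding, and has the model return the in-batch retrieval scores directly instead of separate query and block embeddings. The sketch below is not code from this commit; shapes and values are made-up placeholders. It only illustrates the in-batch retrieval objective that the simplified forward pass and forward_step compute together:

    import torch
    import torch.nn.functional as F

    batch_size, embed_size = 4, 8
    # Stand-ins for the query-model and block-model outputs (Embed_input and Embed_doc).
    query_logits = torch.randn(batch_size, embed_size)
    block_logits = torch.randn(batch_size, embed_size)

    # [batch x embed] * [embed x batch] -> a score for every (query, block) pair in the batch.
    retrieval_scores = query_logits.matmul(torch.transpose(block_logits, 0, 1))

    # Query i's positive block is block i, so the target for row i is simply i.
    retrieval_loss = F.cross_entropy(retrieval_scores, torch.arange(batch_size))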
@@ -28,13 +28,13 @@ class InverseClozeDataset(Dataset):
         self.samples_mapping = self.get_samples_mapping(
             data_prefix, num_epochs, max_num_samples)
-        tokenizer = get_tokenizer()
-        self.vocab_id_list = list(tokenizer.inv_vocab.keys())
-        self.vocab_id_to_token_list = tokenizer.inv_vocab
-        self.cls_id = tokenizer.cls
-        self.sep_id = tokenizer.sep
-        self.mask_id = tokenizer.mask
-        self.pad_id = tokenizer.pad
+        self.tokenizer = get_tokenizer()
+        self.vocab_id_list = list(self.tokenizer.inv_vocab.keys())
+        self.vocab_id_to_token_list = self.tokenizer.inv_vocab
+        self.cls_id = self.tokenizer.cls
+        self.sep_id = self.tokenizer.sep
+        self.mask_id = self.tokenizer.mask
+        self.pad_id = self.tokenizer.pad
 
     def __len__(self):
         return self.samples_mapping.shape[0]
@@ -62,21 +62,36 @@ class InverseClozeDataset(Dataset):
         query = query[:self.max_seq_length - 2]
         block = list(itertools.chain(*block))[:self.max_seq_length - (3 + len(title))]
 
-        query_tokens, query_token_types, query_pad_mask = self.concat_and_pad_tokens(query)
-        block_tokens, block_token_types, block_pad_mask = self.concat_and_pad_tokens(block, title)
+        query_tokens, query_pad_mask = self.concat_and_pad_tokens(query)
+        block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)
 
         sample = {
             'query_tokens': np.array(query_tokens),
-            'query_types': np.array(query_token_types),
             'query_pad_mask': np.array(query_pad_mask),
             'block_tokens': np.array(block_tokens),
-            'block_types': np.array(block_token_types),
             'block_pad_mask': np.array(block_pad_mask),
-            'block_indices': np.array([start_idx, end_idx, doc_idx, block_idx]).astype(np.int64)
+            'block_data': np.array([start_idx, end_idx, doc_idx, block_idx]).astype(np.int64)
         }
 
         return sample
 
+    def encode_text(self, text):
+        return self.tokenizer.tokenize(text)
+
+    def decode_tokens(self, token_ids):
+        tokens = self.tokenizer.tokenizer.convert_ids_to_tokens(token_ids)
+        return ' '.join(tokens)
+
+    def get_block(self, start_idx, end_idx, doc_idx):
+        """Get the IDs for an evidence block plus the title of the corresponding document"""
+        block = [self.context_dataset[i] for i in range(start_idx, end_idx)]
+        title = list(self.titles_dataset[int(doc_idx)])
+        block = list(itertools.chain(*block))[:self.max_seq_length - (3 + len(title))]
+        block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)
+        return block_tokens, block_pad_mask
 
     def concat_and_pad_tokens(self, tokens, title=None):
         """concat with special tokens and pad sequence to self.max_seq_length"""
         tokens = [self.cls_id] + tokens + [self.sep_id]
@@ -85,16 +100,9 @@ class InverseClozeDataset(Dataset):
         assert len(tokens) <= self.max_seq_length, len(tokens)
 
         num_pad = self.max_seq_length - len(tokens)
-        pad_mask = [0] * len(tokens) + [1] * num_pad
+        pad_mask = [1] * len(tokens) + [0] * num_pad
         tokens += [self.pad_id] * num_pad
-        token_types = [0] * self.max_seq_length
-        return tokens, token_types, pad_mask
+        return tokens, pad_mask
 
-    def get_block(self, start_idx, end_idx, doc_idx, block_idx):
-        block = [self.context_dataset[i] for i in range(start_idx, end_idx)]
-        title = list(self.titles_dataset[int(doc_idx)])
-        block = list(itertools.chain(*block))[self.max_seq_length - (3 + len(title))]
 
     def get_samples_mapping(self, data_prefix, num_epochs, max_num_samples):
         if not num_epochs:
...
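A tiny worked example of the new concat_and_pad_tokens convention (the ID values here are made-up placeholders, not the real vocabulary IDs): the mask is now 1 over real tokens and 0 over padding, which is why the model no longer needs to invert it with 1 - attention_mask.

    # Hypothetical IDs for illustration only.
    cls_id, sep_id, pad_id, max_seq_length = 101, 102, 0, 8

    tokens = [cls_id] + [2001, 2002, 2003] + [sep_id]    # 5 real tokens
    num_pad = max_seq_length - len(tokens)               # 3 pad positions
    pad_mask = [1] * len(tokens) + [0] * num_pad         # [1, 1, 1, 1, 1, 0, 0, 0]
    tokens += [pad_id] * num_pad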
@@ -273,8 +273,10 @@ class ICTBertModel(MegatronModule):
     """Bert-based module for Inverse Cloze task."""
     def __init__(self,
                  ict_head_size,
-                 num_tokentypes=2,
-                 parallel_output=True):
+                 num_tokentypes=1,
+                 parallel_output=True,
+                 only_query_model=False,
+                 only_block_model=False):
         super(ICTBertModel, self).__init__()
         bert_args = dict(
             num_tokentypes=num_tokentypes,
@@ -282,44 +284,68 @@ class ICTBertModel(MegatronModule):
             ict_head_size=ict_head_size,
             parallel_output=parallel_output
         )
+        assert not (only_block_model and only_query_model)
+        self.use_block_model = not only_query_model
+        self.use_query_model = not only_block_model
 
-        # this model embeds (pseudo-)queries - Embed_input in the paper
-        self.query_model = BertModel(**bert_args)
-        self._query_key = 'question_model'
+        if self.use_query_model:
+            # this model embeds (pseudo-)queries - Embed_input in the paper
+            self.query_model = BertModel(**bert_args)
+            self._query_key = 'question_model'
 
-        # this model embeds evidence blocks - Embed_doc in the paper
-        self.block_model = BertModel(**bert_args)
-        self._block_key = 'context_model'
+        if self.use_block_model:
+            # this model embeds evidence blocks - Embed_doc in the paper
+            self.block_model = BertModel(**bert_args)
+            self._block_key = 'context_model'
 
-    def forward(self, query_tokens, query_attention_mask, query_types,
-                block_tokens, block_attention_mask, block_types):
+    def forward(self, query_tokens, query_attention_mask, block_tokens, block_attention_mask):
         """Run a forward pass for each of the models and compute the similarity scores."""
-        query_logits, _ = self.query_model.forward(query_tokens, 1 - query_attention_mask, query_types)
-        block_logits, _ = self.block_model.forward(block_tokens, 1 - block_attention_mask, block_types)
-        return query_logits, block_logits
+        query_logits = self.embed_query(query_tokens, query_attention_mask)
+        block_logits = self.embed_block(block_tokens, block_attention_mask)
+
+        # [batch x embed] * [embed x batch]
+        retrieval_scores = query_logits.matmul(torch.transpose(block_logits, 0, 1))
+        return retrieval_scores
 
-    def embed_query(self, query_tokens, query_attention_mask, query_types):
-        query_ict_logits, _ = self.question_model.forward(query_tokens, 1 - query_attention_mask, query_types)
-        return query_ict_logits
+    def embed_query(self, query_tokens, query_attention_mask):
+        """Embed a batch of tokens using the query model"""
+        if self.use_query_model:
+            query_types = torch.zeros(query_tokens.shape).type(torch.float16).cuda()
+            query_ict_logits, _ = self.query_model.forward(query_tokens, query_attention_mask, query_types)
+            return query_ict_logits
+        else:
+            raise ValueError("Cannot embed query without query model.")
+
+    def embed_block(self, block_tokens, block_attention_mask):
+        """Embed a batch of tokens using the block model"""
+        if self.use_block_model:
+            block_types = torch.zeros(block_tokens.shape).type(torch.float16).cuda()
+            block_ict_logits, _ = self.block_model.forward(block_tokens, block_attention_mask, block_types)
+            return block_ict_logits
+        else:
+            raise ValueError("Cannot embed block without block model.")
 
-    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
-                                       keep_vars=False):
+    def state_dict_for_save_checkpoint(self, destination=None, prefix='', keep_vars=False):
         """Save dict with state dicts of each of the models."""
         state_dict_ = {}
-        state_dict_[self._query_key] \
-            = self.query_model.state_dict_for_save_checkpoint(
-                destination, prefix, keep_vars)
-        state_dict_[self._block_key] \
-            = self.block_model.state_dict_for_save_checkpoint(
-                destination, prefix, keep_vars)
+        if self.use_query_model:
+            state_dict_[self._query_key] \
+                = self.query_model.state_dict_for_save_checkpoint(
+                    destination, prefix, keep_vars)
+
+        if self.use_block_model:
+            state_dict_[self._block_key] \
+                = self.block_model.state_dict_for_save_checkpoint(
+                    destination, prefix, keep_vars)
         return state_dict_
 
     def load_state_dict(self, state_dict, strict=True):
         """Load the state dicts of each of the models"""
-        self.query_model.load_state_dict(
-            state_dict[self._query_key], strict=strict)
-        self.block_model.load_state_dict(
-            state_dict[self._block_key], strict=strict)
+        if self.use_query_model:
+            self.query_model.load_state_dict(
+                state_dict[self._query_key], strict=strict)
+
+        if self.use_block_model:
+            self.block_model.load_state_dict(
+                state_dict[self._block_key], strict=strict)
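A hedged usage sketch of the new only_query_model / only_block_model flags: the argument values are assumptions (ict_head_size=128 is arbitrary), and this presumes Megatron's global state (args, tokenizer, model-parallel groups) has already been initialized by the pretraining script. Note embed_query builds a float16 token-type tensor on the GPU internally, so it only runs on a CUDA device.

    # Hypothetical instantiation of a query-only tower; values are examples, not from the repo.
    ict_model = ICTBertModel(ict_head_size=128,
                             num_tokentypes=1,
                             only_query_model=True).cuda()

    # query_tokens: [batch x seq] token IDs; query_pad_mask: 1 = real token, 0 = padding.
    # query_embeddings = ict_model.embed_query(query_tokens, query_pad_mask)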
@@ -43,10 +43,9 @@ def model_provider():
 
 def get_batch(data_iterator):
     # Items and their type.
-    keys = ['query_tokens', 'query_types', 'query_pad_mask',
-            'block_tokens', 'block_types', 'block_pad_mask', 'block_indices']
+    keys = ['query_tokens', 'query_pad_mask',
+            'block_tokens', 'block_pad_mask', 'block_data']
     datatype = torch.int64
 
     # Broadcast data.
@@ -58,15 +57,13 @@ def get_batch(data_iterator):
 
     # Unpack.
     query_tokens = data_b['query_tokens'].long()
-    query_types = data_b['query_types'].long()
     query_pad_mask = data_b['query_pad_mask'].long()
     block_tokens = data_b['block_tokens'].long()
-    block_types = data_b['block_types'].long()
     block_pad_mask = data_b['block_pad_mask'].long()
-    block_indices = data_b['block_indices'].long()
+    block_indices = data_b['block_data'].long()
 
-    return query_tokens, query_types, query_pad_mask,\
-           block_tokens, block_types, block_pad_mask, block_indices
+    return query_tokens, query_pad_mask,\
+           block_tokens, block_pad_mask, block_indices
 
 
 def forward_step(data_iterator, model):
@@ -75,16 +72,12 @@ def forward_step(data_iterator, model):
 
     # Get the batch.
     timers('batch generator').start()
-    query_tokens, query_types, query_pad_mask,\
-    block_tokens, block_types, block_pad_mask, block_indices = get_batch(data_iterator)
+    query_tokens, query_pad_mask, \
+    block_tokens, block_pad_mask, block_indices = get_batch(data_iterator)
     timers('batch generator').stop()
 
     # Forward model.
-    query_logits, block_logits = model(query_tokens, query_pad_mask, query_types,
-                                       block_tokens, block_pad_mask, block_types).float()
-
-    # [batch x h] * [h x batch]
-    retrieval_scores = query_logits.matmul(torch.transpose(block_logits, 0, 1))
+    retrieval_scores = model(query_tokens, query_pad_mask, block_tokens, block_pad_mask).float()
 
     softmaxed = F.softmax(retrieval_scores, dim=1)
     top5_vals, top5_indices = torch.topk(softmaxed, k=5, sorted=True)
@@ -95,10 +88,13 @@ def forward_step(data_iterator, model):
 
     retrieval_loss = F.cross_entropy(softmaxed, torch.arange(batch_size).cuda())
 
     reduced_losses = reduce_losses([retrieval_loss, top1_acc, top5_acc])
+    stats_dict = {
+        'retrieval loss': reduced_losses[0],
+        'top1_acc': reduced_losses[1],
+        'top5_acc': reduced_losses[2]
+    }
 
-    return retrieval_loss, {'retrieval loss': reduced_losses[0],
-                            'top1_acc': reduced_losses[1],
-                            'top5_acc': reduced_losses[2]}
+    return retrieval_loss, stats_dict
 
 
 def train_valid_test_datasets_provider(train_val_test_num_samples):
...