Commit 730266ca authored by Neel Kant

Refactor and add more REALM arguments

parent a2e64ad5
@@ -5,7 +5,7 @@ from megatron import get_args
 from megatron import mpu
 from megatron.checkpointing import get_checkpoint_tracker_filename, get_checkpoint_name
 from megatron.data.bert_dataset import get_indexed_dataset_
-from megatron.data.realm_dataset import InverseClozeDataset
+from megatron.data.realm_dataset import ICTDataset
 from megatron.data.realm_index import detach, BlockData, RandProjectionLSHIndex
 from megatron.data.samplers import DistributedBatchSampler
 from megatron.initialize import initialize_megatron
@@ -150,7 +150,7 @@ def get_ict_dataset():
         short_seq_prob=0.0001,  # doesn't matter
         seed=1
     )
-    dataset = InverseClozeDataset(**kwargs)
+    dataset = ICTDataset(**kwargs)
     return dataset
...
@@ -245,6 +245,8 @@ def _add_checkpointing_args(parser):
                        help='Directory containing a model checkpoint.')
     group.add_argument('--ict-load', type=str, default=None,
                        help='Directory containing an ICTBertModel checkpoint')
+    group.add_argument('--bert-load', type=str, default=None,
+                       help='Directory containing a BertModel checkpoint (needed to start REALM)')
     group.add_argument('--no-load-optim', action='store_true',
                        help='Do not load optimizer when loading checkpoint.')
     group.add_argument('--no-load-rng', action='store_true',
@@ -326,6 +328,8 @@ def _add_data_args(parser):
                        help='Path to pickled BlockData data structure')
     group.add_argument('--block-index-path', type=str, default=None,
                        help='Path to pickled data structure for efficient block indexing')
+    group.add_argument('--block-top-k', type=int, default=5,
+                       help='Number of blocks to use as top-k during retrieval')
     group.add_argument('--split', type=str, default='969, 30, 1',
                        help='Comma-separated list of proportions for training,'
                             ' validation, and test split. For example the split '
...
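For illustration only (the helper name is hypothetical, not from this commit), the two new flags are read back through get_args() and feed the checkpoint loader and the retriever shown in the hunks below:

    # Hypothetical sketch of where the two new arguments flow; see the hunks below.
    from megatron import get_args

    def realm_arg_summary():
        args = get_args()
        return {
            # --bert-load: consulted by load_checkpoint() when the model is a plain
            # BertModel, so the REALM reader can start from a pretrained BERT checkpoint.
            'bert_load': args.bert_load,
            # --block-top-k: passed to REALMRetriever as top_k (default 5).
            'block_top_k': args.block_top_k,
        }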
@@ -131,11 +131,15 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
 def load_checkpoint(model, optimizer, lr_scheduler):
     """Load a model checkpoint and return the iteration."""
     args = get_args()
+    load_dir = args.load
+    from megatron.model.bert_model import BertModel
+    if isinstance(model, BertModel) and args.bert_load is not None:
+        load_dir = args.bert_load

     if isinstance(model, torchDDP):
         model = model.module

     # Read the tracker file and set the iteration.
-    tracker_filename = get_checkpoint_tracker_filename(args.load)
+    tracker_filename = get_checkpoint_tracker_filename(load_dir)
     # If no tracker file, return iteration zero.
     if not os.path.isfile(tracker_filename):
@@ -164,7 +168,7 @@ def load_checkpoint(model, optimizer, lr_scheduler):
                     tracker_filename)

     # Checkpoint.
-    checkpoint_name = get_checkpoint_name(args.load, iteration, release)
+    checkpoint_name = get_checkpoint_name(load_dir, iteration, release)
     if mpu.get_data_parallel_rank() == 0:
         print('global rank {} is loading checkpoint {}'.format(
             torch.distributed.get_rank(), checkpoint_name))
...
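A minimal usage sketch of the new branch (illustrative only, assuming the standard Megatron argument flow): when the model passed in is a plain BertModel and --bert-load is set, checkpoints are resolved under that directory instead of --load.

    # Illustrative only: expected behavior of the --bert-load branch above.
    from megatron import get_args
    from megatron.checkpointing import load_checkpoint
    from megatron.model import BertModel

    def load_reader_from_bert_checkpoint():
        args = get_args()
        reader = BertModel(num_tokentypes=1, add_binary_head=False, parallel_output=True)
        # reader is a plain BertModel and args.bert_load is set, so load_checkpoint()
        # reads the tracker file and weights from args.bert_load rather than args.load.
        iteration = load_checkpoint(reader, optimizer=None, lr_scheduler=None)
        return reader, iteration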
@@ -454,8 +454,8 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
     print_split_stats('test', 2)

     def build_dataset(index, name):
-        from megatron.data.realm_dataset import InverseClozeDataset
-        from megatron.data.realm_dataset import RealmDataset
+        from megatron.data.realm_dataset import ICTDataset
+        from megatron.data.realm_dataset import REALMDataset
         dataset = None
         if splits[index + 1] > splits[index]:
             # Get the pointer to the original doc-idx so we can set it later.
@@ -478,13 +478,13 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
             )

             if dataset_type == 'ict':
-                dataset = InverseClozeDataset(
+                dataset = ICTDataset(
                     block_dataset=indexed_dataset,
                     title_dataset=title_dataset,
                     **kwargs
                 )
             else:
-                dataset_cls = BertDataset if dataset_type == 'standard_bert' else RealmDataset
+                dataset_cls = BertDataset if dataset_type == 'standard_bert' else REALMDataset
                 dataset = dataset_cls(
                     indexed_dataset=indexed_dataset,
                     masked_lm_prob=masked_lm_prob,
...
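For orientation, a compact restatement of the selection logic above (the BertDataset import location is an assumption; only class names from this diff are used):

    # Illustrative restatement of the branch in build_dataset() above.
    def select_dataset_class(dataset_type):
        from megatron.data.bert_dataset import BertDataset  # assumed module for BertDataset
        from megatron.data.realm_dataset import ICTDataset, REALMDataset
        if dataset_type == 'ict':
            return ICTDataset
        if dataset_type == 'standard_bert':
            return BertDataset
        return REALMDataset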
@@ -15,7 +15,7 @@ from megatron.data.dataset_utils import create_masked_lm_predictions, pad_and_co
 #qa_nlp = spacy.load('en_core_web_lg')


-class RealmDataset(BertDataset):
+class REALMDataset(BertDataset):
     """Dataset containing simple masked sentences for masked language modeling.

     The dataset should yield sentences just like the regular BertDataset
@@ -28,7 +28,7 @@ class RealmDataset(BertDataset):
     def __init__(self, name, indexed_dataset, data_prefix,
                  num_epochs, max_num_samples, masked_lm_prob,
                  max_seq_length, short_seq_prob, seed):
-        super(RealmDataset, self).__init__(name, indexed_dataset, data_prefix,
+        super(REALMDataset, self).__init__(name, indexed_dataset, data_prefix,
                                            num_epochs, max_num_samples, masked_lm_prob,
                                            max_seq_length, short_seq_prob, seed)
         self.build_sample_fn = build_simple_training_sample
@@ -81,7 +81,7 @@ def spacy_ner(block_text):
         candidates['answers'] = answers


-class InverseClozeDataset(Dataset):
+class ICTDataset(Dataset):
     """Dataset containing sentences and their blocks for an inverse cloze task."""
     def __init__(self, name, block_dataset, title_dataset, data_prefix,
                  num_epochs, max_num_samples, max_seq_length,
...
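As background (not part of the diff), an inverse cloze sample pairs one sentence from a block with the rest of that block. A minimal sketch of that pairing, assuming a block is given as a list of tokenized sentences; the keep-query probability is a common ICT trick, not a value taken from this code:

    import random

    # Illustrative ICT pairing only; ICTDataset's actual sampling lives in realm_dataset.py.
    def make_ict_pair(block_sentences, keep_query_in_block_prob=0.1):
        """Pick one sentence as the pseudo-query; the remaining sentences form the evidence."""
        idx = random.randrange(len(block_sentences))
        query = block_sentences[idx]
        # Occasionally keep the query sentence inside the evidence so the model
        # cannot rely purely on string overlap between query and block.
        if random.random() < keep_query_in_block_prob:
            evidence = list(block_sentences)
        else:
            evidence = block_sentences[:idx] + block_sentences[idx + 1:]
        return query, evidence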
@@ -14,6 +14,7 @@
 # limitations under the License.

 from .distributed import *
-from .bert_model import BertModel, ICTBertModel, REALMBertModel, REALMRetriever
+from .bert_model import BertModel
+from megatron.model.realm_model import ICTBertModel, REALMRetriever, REALMBertModel
 from .gpt2_model import GPT2Model
 from .utils import get_params_for_weight_decay_optimization
@@ -15,14 +15,9 @@
 """BERT model."""

-import pickle
-
-import numpy as np
 import torch
-import torch.nn.functional as F

 from megatron import get_args
-from megatron.data.realm_index import detach

 from megatron.model.language_model import parallel_lm_logits
 from megatron.model.language_model import get_language_model
 from megatron.model.transformer import LayerNorm
@@ -224,198 +219,3 @@ class BertModel(MegatronModule):
             state_dict[self._ict_head_key], strict=strict)
-
-
-class REALMBertModel(MegatronModule):
-    def __init__(self, retriever):
-        super(REALMBertModel, self).__init__()
-        bert_args = dict(
-            num_tokentypes=1,
-            add_binary_head=False,
-            parallel_output=True
-        )
-        self.lm_model = BertModel(**bert_args)
-        self._lm_key = 'realm_lm'
-
-        self.retriever = retriever
-        self._retriever_key = 'retriever'
-
-    def forward(self, tokens, attention_mask):
-        # [batch_size x 5 x seq_length]
-        top5_block_tokens, top5_block_attention_mask = self.retriever.retrieve_evidence_blocks(tokens, attention_mask)
-        batch_size = tokens.shape[0]
-        seq_length = top5_block_tokens.shape[2]
-        top5_block_tokens = torch.cuda.LongTensor(top5_block_tokens).reshape(-1, seq_length)
-        top5_block_attention_mask = torch.cuda.LongTensor(top5_block_attention_mask).reshape(-1, seq_length)
-
-        # [batch_size x 5 x embed_size]
-        true_model = self.retriever.ict_model.module.module
-        fresh_block_logits = true_model.embed_block(top5_block_tokens, top5_block_attention_mask).reshape(batch_size, 5, -1)
-
-        # [batch_size x embed_size x 1]
-        query_logits = true_model.embed_query(tokens, attention_mask).unsqueeze(2)
-
-        # [batch_size x 5]
-        fresh_block_scores = torch.matmul(fresh_block_logits, query_logits).squeeze()
-        block_probs = F.softmax(fresh_block_scores, dim=1)
-
-        # [batch_size * 5 x seq_length]
-        tokens = torch.stack([tokens.unsqueeze(1)] * 5, dim=1).reshape(-1, seq_length)
-        attention_mask = torch.stack([attention_mask.unsqueeze(1)] * 5, dim=1).reshape(-1, seq_length)
-
-        # [batch_size * 5 x 2 * seq_length]
-        all_tokens = torch.cat((tokens, top5_block_tokens), axis=1)
-        all_attention_mask = torch.cat((attention_mask, top5_block_attention_mask), axis=1)
-        all_token_types = torch.zeros(all_tokens.shape).type(torch.int64).cuda()
-
-        # [batch_size x 5 x 2 * seq_length x vocab_size]
-        lm_logits, _ = self.lm_model.forward(all_tokens, all_attention_mask, all_token_types)
-        lm_logits = lm_logits.reshape(batch_size, 5, 2 * seq_length, -1)
-        return lm_logits, block_probs
-
-    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
-                                       keep_vars=False):
-        """For easy load when model is combined with other heads,
-        add an extra key."""
-        state_dict_ = {}
-        state_dict_[self._lm_key] = self.lm_model.state_dict_for_save_checkpoint(destination, prefix, keep_vars)
-        return state_dict_
-
-
-class REALMRetriever(MegatronModule):
-    """Retriever which uses a pretrained ICTBertModel and a HashedIndex"""
-    def __init__(self, ict_model, ict_dataset, block_data, hashed_index, top_k=5):
-        super(REALMRetriever, self).__init__()
-        self.ict_model = ict_model
-        self.ict_dataset = ict_dataset
-        self.block_data = block_data
-        self.hashed_index = hashed_index
-        self.top_k = top_k
-
-    def retrieve_evidence_blocks_text(self, query_text):
-        """Get the top k evidence blocks for query_text in text form"""
-        print("-" * 100)
-        print("Query: ", query_text)
-
-        padless_max_len = self.ict_dataset.max_seq_length - 2
-        query_tokens = self.ict_dataset.encode_text(query_text)[:padless_max_len]
-        query_tokens, query_pad_mask = self.ict_dataset.concat_and_pad_tokens(query_tokens)
-
-        query_tokens = torch.cuda.LongTensor(np.array(query_tokens).reshape(1, -1))
-        query_pad_mask = torch.cuda.LongTensor(np.array(query_pad_mask).reshape(1, -1))
-
-        top5_block_tokens, _ = self.retrieve_evidence_blocks(query_tokens, query_pad_mask)
-        for i, block in enumerate(top5_block_tokens[0]):
-            block_text = self.ict_dataset.decode_tokens(block)
-            print('\n > Block {}: {}'.format(i, block_text))
-
-    def retrieve_evidence_blocks(self, query_tokens, query_pad_mask):
-        """Embed blocks to be used in a forward pass"""
-        with torch.no_grad():
-            true_model = self.ict_model.module.module
-            query_embeds = detach(true_model.embed_query(query_tokens, query_pad_mask))
-        _, block_indices = self.hashed_index.search_mips_index(query_embeds, top_k=self.top_k, reconstruct=False)
-        all_top5_tokens, all_top5_pad_masks = [], []
-        for indices in block_indices:
-            # [k x meta_dim]
-            top5_metas = np.array([self.block_data.meta_data[idx] for idx in indices])
-            top5_block_data = [self.ict_dataset.get_block(*block_meta) for block_meta in top5_metas]
-            top5_tokens, top5_pad_masks = zip(*top5_block_data)
-            all_top5_tokens.append(np.array(top5_tokens))
-            all_top5_pad_masks.append(np.array(top5_pad_masks))
-
-        # [batch_size x k x seq_length]
-        return np.array(all_top5_tokens), np.array(all_top5_pad_masks)
-
-
-class ICTBertModel(MegatronModule):
-    """Bert-based module for Inverse Cloze task."""
-    def __init__(self,
-                 ict_head_size,
-                 num_tokentypes=1,
-                 parallel_output=True,
-                 only_query_model=False,
-                 only_block_model=False):
-        super(ICTBertModel, self).__init__()
-        bert_args = dict(
-            num_tokentypes=num_tokentypes,
-            add_binary_head=False,
-            ict_head_size=ict_head_size,
-            parallel_output=parallel_output
-        )
-        assert not (only_block_model and only_query_model)
-        self.use_block_model = not only_query_model
-        self.use_query_model = not only_block_model
-
-        if self.use_query_model:
-            # this model embeds (pseudo-)queries - Embed_input in the paper
-            self.query_model = BertModel(**bert_args)
-            self._query_key = 'question_model'
-
-        if self.use_block_model:
-            # this model embeds evidence blocks - Embed_doc in the paper
-            self.block_model = BertModel(**bert_args)
-            self._block_key = 'context_model'
-
-    def forward(self, query_tokens, query_attention_mask, block_tokens, block_attention_mask, only_query=False, only_block=False):
-        """Run a forward pass for each of the models and compute the similarity scores."""
-        if only_query:
-            return self.embed_query(query_tokens, query_attention_mask)
-        if only_block:
-            return self.embed_block(block_tokens, block_attention_mask)
-
-        query_logits = self.embed_query(query_tokens, query_attention_mask)
-        block_logits = self.embed_block(block_tokens, block_attention_mask)
-
-        # [batch x embed] * [embed x batch]
-        retrieval_scores = query_logits.matmul(torch.transpose(block_logits, 0, 1))
-        return retrieval_scores
-
-    def embed_query(self, query_tokens, query_attention_mask):
-        """Embed a batch of tokens using the query model"""
-        if self.use_query_model:
-            query_types = torch.zeros(query_tokens.shape).type(torch.int64).cuda()
-            query_ict_logits, _ = self.query_model.forward(query_tokens, query_attention_mask, query_types)
-            return query_ict_logits
-        else:
-            raise ValueError("Cannot embed query without query model.")
-
-    def embed_block(self, block_tokens, block_attention_mask):
-        """Embed a batch of tokens using the block model"""
-        if self.use_block_model:
-            block_types = torch.zeros(block_tokens.shape).type(torch.int64).cuda()
-            block_ict_logits, _ = self.block_model.forward(block_tokens, block_attention_mask, block_types)
-            return block_ict_logits
-        else:
-            raise ValueError("Cannot embed block without block model.")
-
-    def state_dict_for_save_checkpoint(self, destination=None, prefix='', keep_vars=False):
-        """Save dict with state dicts of each of the models."""
-        state_dict_ = {}
-        if self.use_query_model:
-            state_dict_[self._query_key] \
-                = self.query_model.state_dict_for_save_checkpoint(
-                    destination, prefix, keep_vars)
-
-        if self.use_block_model:
-            state_dict_[self._block_key] \
-                = self.block_model.state_dict_for_save_checkpoint(
-                    destination, prefix, keep_vars)
-
-        return state_dict_
-
-    def load_state_dict(self, state_dict, strict=True):
-        """Load the state dicts of each of the models"""
-        if self.use_query_model:
-            print("Loading ICT query model", flush=True)
-            self.query_model.load_state_dict(
-                state_dict[self._query_key], strict=strict)
-
-        if self.use_block_model:
-            print("Loading ICT block model", flush=True)
-            self.block_model.load_state_dict(
-                state_dict[self._block_key], strict=strict)
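For orientation (not part of this commit), the [batch x batch] retrieval_scores returned by ICTBertModel.forward are commonly trained with an in-batch softmax, treating the diagonal as the positive query-block pairs. A minimal sketch under that assumption:

    import torch
    import torch.nn.functional as F

    # Illustrative ICT objective only; the actual pretraining loss is defined elsewhere in the repo.
    def ict_in_batch_loss(retrieval_scores):
        """retrieval_scores: [batch x batch] query-block similarity matrix."""
        batch_size = retrieval_scores.shape[0]
        # Each query's true evidence block is the one it was drawn from, i.e. the diagonal.
        labels = torch.arange(batch_size, device=retrieval_scores.device)
        return F.cross_entropy(retrieval_scores, labels)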
new file: megatron/model/realm_model.py

+import numpy as np
+import torch
+import torch.nn.functional as F
+
+from megatron.checkpointing import load_checkpoint
+from megatron.data.realm_index import detach
+from megatron.model import BertModel
+from megatron.module import MegatronModule
+
+
+class REALMBertModel(MegatronModule):
+    def __init__(self, retriever):
+        super(REALMBertModel, self).__init__()
+        bert_args = dict(
+            num_tokentypes=1,
+            add_binary_head=False,
+            parallel_output=True
+        )
+        self.lm_model = BertModel(**bert_args)
+        load_checkpoint(self.lm_model, optimizer=None, lr_scheduler=None)
+        self._lm_key = 'realm_lm'
+
+        self.retriever = retriever
+        self._retriever_key = 'retriever'
+
+    def forward(self, tokens, attention_mask):
+        # [batch_size x 5 x seq_length]
+        top5_block_tokens, top5_block_attention_mask = self.retriever.retrieve_evidence_blocks(tokens, attention_mask)
+        batch_size = tokens.shape[0]
+        seq_length = top5_block_tokens.shape[2]
+        top5_block_tokens = torch.cuda.LongTensor(top5_block_tokens).reshape(-1, seq_length)
+        top5_block_attention_mask = torch.cuda.LongTensor(top5_block_attention_mask).reshape(-1, seq_length)
+
+        # [batch_size x 5 x embed_size]
+        true_model = self.retriever.ict_model.module.module
+        fresh_block_logits = true_model.embed_block(top5_block_tokens, top5_block_attention_mask).reshape(batch_size, 5, -1)
+
+        # [batch_size x embed_size x 1]
+        query_logits = true_model.embed_query(tokens, attention_mask).unsqueeze(2)
+
+        # [batch_size x 5]
+        fresh_block_scores = torch.matmul(fresh_block_logits, query_logits).squeeze()
+        block_probs = F.softmax(fresh_block_scores, dim=1)
+
+        # [batch_size * 5 x seq_length]
+        tokens = torch.stack([tokens.unsqueeze(1)] * 5, dim=1).reshape(-1, seq_length)
+        attention_mask = torch.stack([attention_mask.unsqueeze(1)] * 5, dim=1).reshape(-1, seq_length)
+
+        # [batch_size * 5 x 2 * seq_length]
+        all_tokens = torch.cat((tokens, top5_block_tokens), axis=1)
+        all_attention_mask = torch.cat((attention_mask, top5_block_attention_mask), axis=1)
+        all_token_types = torch.zeros(all_tokens.shape).type(torch.int64).cuda()
+
+        # [batch_size x 5 x 2 * seq_length x vocab_size]
+        lm_logits, _ = self.lm_model.forward(all_tokens, all_attention_mask, all_token_types)
+        lm_logits = lm_logits.reshape(batch_size, 5, 2 * seq_length, -1)
+        return lm_logits, block_probs
+
+    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
+                                       keep_vars=False):
+        """For easy load when model is combined with other heads,
+        add an extra key."""
+        state_dict_ = {}
+        state_dict_[self._lm_key] = self.lm_model.state_dict_for_save_checkpoint(destination, prefix, keep_vars)
+        state_dict_[self._retriever_key] = self.retriever.state_dict_for_save_checkpoint(destination, prefix, keep_vars)
+        return state_dict_
+
+    def load_state_dict(self, state_dict, strict=True):
+        """Load the state dicts of each of the models"""
+        self.lm_model.load_state_dict(state_dict[self._lm_key], strict)
+        self.retriever.load_state_dict(state_dict[self._retriever_key], strict)
+
+
+class REALMRetriever(MegatronModule):
+    """Retriever which uses a pretrained ICTBertModel and a HashedIndex"""
+    def __init__(self, ict_model, ict_dataset, block_data, hashed_index, top_k=5):
+        super(REALMRetriever, self).__init__()
+        self.ict_model = ict_model
+        self.ict_dataset = ict_dataset
+        self.block_data = block_data
+        self.hashed_index = hashed_index
+        self.top_k = top_k
+        self._ict_key = 'ict_model'
+
+    def retrieve_evidence_blocks_text(self, query_text):
+        """Get the top k evidence blocks for query_text in text form"""
+        print("-" * 100)
+        print("Query: ", query_text)
+
+        padless_max_len = self.ict_dataset.max_seq_length - 2
+        query_tokens = self.ict_dataset.encode_text(query_text)[:padless_max_len]
+        query_tokens, query_pad_mask = self.ict_dataset.concat_and_pad_tokens(query_tokens)
+
+        query_tokens = torch.cuda.LongTensor(np.array(query_tokens).reshape(1, -1))
+        query_pad_mask = torch.cuda.LongTensor(np.array(query_pad_mask).reshape(1, -1))
+
+        top5_block_tokens, _ = self.retrieve_evidence_blocks(query_tokens, query_pad_mask)
+        for i, block in enumerate(top5_block_tokens[0]):
+            block_text = self.ict_dataset.decode_tokens(block)
+            print('\n > Block {}: {}'.format(i, block_text))
+
+    def retrieve_evidence_blocks(self, query_tokens, query_pad_mask):
+        """Embed blocks to be used in a forward pass"""
+        with torch.no_grad():
+            true_model = self.ict_model.module.module
+            query_embeds = detach(true_model.embed_query(query_tokens, query_pad_mask))
+        _, block_indices = self.hashed_index.search_mips_index(query_embeds, top_k=self.top_k, reconstruct=False)
+        all_top5_tokens, all_top5_pad_masks = [], []
+        for indices in block_indices:
+            # [k x meta_dim]
+            top5_metas = np.array([self.block_data.meta_data[idx] for idx in indices])
+            top5_block_data = [self.ict_dataset.get_block(*block_meta) for block_meta in top5_metas]
+            top5_tokens, top5_pad_masks = zip(*top5_block_data)
+            all_top5_tokens.append(np.array(top5_tokens))
+            all_top5_pad_masks.append(np.array(top5_pad_masks))
+
+        # [batch_size x k x seq_length]
+        return np.array(all_top5_tokens), np.array(all_top5_pad_masks)
+
+    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
+                                       keep_vars=False):
+        """For easy load when model is combined with other heads,
+        add an extra key."""
+        state_dict_ = {}
+        state_dict_[self._ict_key] = self.ict_model.state_dict_for_save_checkpoint(destination, prefix, keep_vars)
+        return state_dict_
+
+    def load_state_dict(self, state_dict, strict=True):
+        """Load the state dicts of each of the models"""
+        self.ict_model.load_state_dict(state_dict[self._ict_key], strict)
+
+
+class ICTBertModel(MegatronModule):
+    """Bert-based module for Inverse Cloze task."""
+    def __init__(self,
+                 ict_head_size,
+                 num_tokentypes=1,
+                 parallel_output=True,
+                 only_query_model=False,
+                 only_block_model=False):
+        super(ICTBertModel, self).__init__()
+        bert_args = dict(
+            num_tokentypes=num_tokentypes,
+            add_binary_head=False,
+            ict_head_size=ict_head_size,
+            parallel_output=parallel_output
+        )
+        assert not (only_block_model and only_query_model)
+        self.use_block_model = not only_query_model
+        self.use_query_model = not only_block_model
+
+        if self.use_query_model:
+            # this model embeds (pseudo-)queries - Embed_input in the paper
+            self.query_model = BertModel(**bert_args)
+            self._query_key = 'question_model'
+
+        if self.use_block_model:
+            # this model embeds evidence blocks - Embed_doc in the paper
+            self.block_model = BertModel(**bert_args)
+            self._block_key = 'context_model'
+
+    def forward(self, query_tokens, query_attention_mask, block_tokens, block_attention_mask, only_query=False, only_block=False):
+        """Run a forward pass for each of the models and compute the similarity scores."""
+        if only_query:
+            return self.embed_query(query_tokens, query_attention_mask)
+        if only_block:
+            return self.embed_block(block_tokens, block_attention_mask)
+
+        query_logits = self.embed_query(query_tokens, query_attention_mask)
+        block_logits = self.embed_block(block_tokens, block_attention_mask)
+
+        # [batch x embed] * [embed x batch]
+        retrieval_scores = query_logits.matmul(torch.transpose(block_logits, 0, 1))
+        return retrieval_scores
+
+    def embed_query(self, query_tokens, query_attention_mask):
+        """Embed a batch of tokens using the query model"""
+        if self.use_query_model:
+            query_types = torch.zeros(query_tokens.shape).type(torch.int64).cuda()
+            query_ict_logits, _ = self.query_model.forward(query_tokens, query_attention_mask, query_types)
+            return query_ict_logits
+        else:
+            raise ValueError("Cannot embed query without query model.")
+
+    def embed_block(self, block_tokens, block_attention_mask):
+        """Embed a batch of tokens using the block model"""
+        if self.use_block_model:
+            block_types = torch.zeros(block_tokens.shape).type(torch.int64).cuda()
+            block_ict_logits, _ = self.block_model.forward(block_tokens, block_attention_mask, block_types)
+            return block_ict_logits
+        else:
+            raise ValueError("Cannot embed block without block model.")
+
+    def state_dict_for_save_checkpoint(self, destination=None, prefix='', keep_vars=False):
+        """Save dict with state dicts of each of the models."""
+        state_dict_ = {}
+        if self.use_query_model:
+            state_dict_[self._query_key] \
+                = self.query_model.state_dict_for_save_checkpoint(
+                    destination, prefix, keep_vars)
+
+        if self.use_block_model:
+            state_dict_[self._block_key] \
+                = self.block_model.state_dict_for_save_checkpoint(
+                    destination, prefix, keep_vars)
+
+        return state_dict_
+
+    def load_state_dict(self, state_dict, strict=True):
+        """Load the state dicts of each of the models"""
+        if self.use_query_model:
+            print("Loading ICT query model", flush=True)
+            self.query_model.load_state_dict(
+                state_dict[self._query_key], strict=strict)
+
+        if self.use_block_model:
+            print("Loading ICT block model", flush=True)
+            self.block_model.load_state_dict(
+                state_dict[self._block_key], strict=strict)
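As a reading aid (not part of the commit), the reader's lm_logits ([batch, 5, 2*seq_length, vocab]) and the retriever's block_probs ([batch, 5]) can be combined into REALM's marginal likelihood, log p(y|x) = log sum_z p(z|x) p(y|x, z). A hedged sketch of that reduction, assuming masked-token labels and a loss mask laid out as [batch, 2*seq_length]:

    import torch
    import torch.nn.functional as F

    # Illustrative loss reduction only; the training loop in this repo may differ.
    def realm_marginal_nll(lm_logits, block_probs, labels, loss_mask):
        """lm_logits: [B, k, L, V]; block_probs: [B, k]; labels, loss_mask: [B, L]."""
        batch_size, k, length, vocab = lm_logits.shape
        log_probs = F.log_softmax(lm_logits, dim=-1)
        # Log-likelihood of every target token under each retrieved block.
        gather_index = labels.unsqueeze(1).expand(-1, k, -1).unsqueeze(-1)   # [B, k, L, 1]
        token_ll = torch.gather(log_probs, -1, gather_index).squeeze(-1)     # [B, k, L]
        # Joint log-likelihood of the masked positions for each block.
        seq_ll = (token_ll * loss_mask.unsqueeze(1)).sum(-1)                 # [B, k]
        # Marginalize over the k retrieved blocks: log sum_z p(z|x) p(y|x, z).
        marginal_ll = torch.logsumexp(seq_ll + torch.log(block_probs + 1e-12), dim=1)
        return -marginal_ll.mean()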
@@ -44,7 +44,7 @@ def model_provider():
     hashed_index = FaissMIPSIndex(index_type='flat_l2', embed_size=128)
     hashed_index.add_block_embed_data(all_block_data)

-    retriever = REALMRetriever(ict_model, ict_dataset, all_block_data, hashed_index)
+    retriever = REALMRetriever(ict_model, ict_dataset, all_block_data, hashed_index, args.block_top_k)

     # TODO: REALMBertModel should accept a path to a pretrained bert-base
     model = REALMBertModel(retriever)
...
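A small, hypothetical driver showing the configured retriever in use (assumes model_provider() above has been called in an initialized Megatron process; the query string is arbitrary):

    # Hypothetical usage; relies only on interfaces defined in realm_model.py above.
    model = model_provider()
    retriever = model.retriever
    # Prints the top --block-top-k evidence blocks for an ad-hoc text query.
    retriever.retrieve_evidence_blocks_text("the capital city of France")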