Commit 730266ca authored by Neel Kant

Refactor and add more REALM arguments

parent a2e64ad5
......@@ -5,7 +5,7 @@ from megatron import get_args
from megatron import mpu
from megatron.checkpointing import get_checkpoint_tracker_filename, get_checkpoint_name
from megatron.data.bert_dataset import get_indexed_dataset_
from megatron.data.realm_dataset import InverseClozeDataset
from megatron.data.realm_dataset import ICTDataset
from megatron.data.realm_index import detach, BlockData, RandProjectionLSHIndex
from megatron.data.samplers import DistributedBatchSampler
from megatron.initialize import initialize_megatron
......@@ -150,7 +150,7 @@ def get_ict_dataset():
short_seq_prob=0.0001, # doesn't matter
seed=1
)
dataset = InverseClozeDataset(**kwargs)
dataset = ICTDataset(**kwargs)
return dataset
......
......@@ -245,6 +245,8 @@ def _add_checkpointing_args(parser):
help='Directory containing a model checkpoint.')
group.add_argument('--ict-load', type=str, default=None,
help='Directory containing an ICTBertModel checkpoint')
group.add_argument('--bert-load', type=str, default=None,
help='Directory containing a BertModel checkpoint (needed to start REALM)')
group.add_argument('--no-load-optim', action='store_true',
help='Do not load optimizer when loading checkpoint.')
group.add_argument('--no-load-rng', action='store_true',
......@@ -326,6 +328,8 @@ def _add_data_args(parser):
help='Path to pickled BlockData data structure')
group.add_argument('--block-index-path', type=str, default=None,
help='Path to pickled data structure for efficient block indexing')
group.add_argument('--block-top-k', type=int, default=5,
help='Number of blocks to use as top-k during retrieval')
group.add_argument('--split', type=str, default='969, 30, 1',
help='Comma-separated list of proportions for training,'
' validation, and test split. For example the split '
......
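For reference, a minimal standalone sketch of how the two new flags could be wired and read back (illustrative only; the real parser is assembled by the _add_checkpointing_args / _add_data_args helpers shown above and consumed through megatron's get_args()):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--bert-load', type=str, default=None,
                    help='Directory containing a BertModel checkpoint (needed to start REALM)')
parser.add_argument('--block-top-k', type=int, default=5,
                    help='Number of blocks to use as top-k during retrieval')

# e.g. reading the new options back after parsing (paths are placeholders)
args = parser.parse_args(['--bert-load', '/path/to/bert_checkpoint', '--block-top-k', '8'])
assert args.bert_load == '/path/to/bert_checkpoint' and args.block_top_k == 8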
......@@ -131,11 +131,15 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
def load_checkpoint(model, optimizer, lr_scheduler):
"""Load a model checkpoint and return the iteration."""
args = get_args()
load_dir = args.load
from megatron.model.bert_model import BertModel
if isinstance(model, BertModel) and args.bert_load is not None:
load_dir = args.bert_load
if isinstance(model, torchDDP):
model = model.module
# Read the tracker file and set the iteration.
tracker_filename = get_checkpoint_tracker_filename(args.load)
tracker_filename = get_checkpoint_tracker_filename(load_dir)
# If no tracker file, return iteration zero.
if not os.path.isfile(tracker_filename):
......@@ -164,7 +168,7 @@ def load_checkpoint(model, optimizer, lr_scheduler):
tracker_filename)
# Checkpoint.
checkpoint_name = get_checkpoint_name(args.load, iteration, release)
checkpoint_name = get_checkpoint_name(load_dir, iteration, release)
if mpu.get_data_parallel_rank() == 0:
print('global rank {} is loading checkpoint {}'.format(
torch.distributed.get_rank(), checkpoint_name))
......
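A compact sketch of the checkpoint-directory selection that the hunk above adds to load_checkpoint (the helper name is hypothetical; in the actual code the branch is inline):

def _select_load_dir(model, args):
    # Hypothetical helper mirroring the branch added above: a bare
    # (non-DDP-wrapped) BertModel loads from --bert-load when it is set,
    # e.g. the REALM reader built in realm_model.py; everything else
    # keeps using --load.
    from megatron.model.bert_model import BertModel
    if isinstance(model, BertModel) and args.bert_load is not None:
        return args.bert_load
    return args.load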
......@@ -454,8 +454,8 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
print_split_stats('test', 2)
def build_dataset(index, name):
from megatron.data.realm_dataset import InverseClozeDataset
from megatron.data.realm_dataset import RealmDataset
from megatron.data.realm_dataset import ICTDataset
from megatron.data.realm_dataset import REALMDataset
dataset = None
if splits[index + 1] > splits[index]:
# Get the pointer to the original doc-idx so we can set it later.
......@@ -478,13 +478,13 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
)
if dataset_type == 'ict':
dataset = InverseClozeDataset(
dataset = ICTDataset(
block_dataset=indexed_dataset,
title_dataset=title_dataset,
**kwargs
)
else:
dataset_cls = BertDataset if dataset_type == 'standard_bert' else RealmDataset
dataset_cls = BertDataset if dataset_type == 'standard_bert' else REALMDataset
dataset = dataset_cls(
indexed_dataset=indexed_dataset,
masked_lm_prob=masked_lm_prob,
......
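A compressed view of the class selection after the rename (the helper below is hypothetical; build_dataset keeps the inline branches because the ICT case takes different constructor arguments):

def _dataset_cls_for(dataset_type):
    # Illustrative only: 'ict' builds ICTDataset from block/title datasets,
    # 'standard_bert' keeps BertDataset, and anything else falls through to
    # REALMDataset, matching the if/else above.
    if dataset_type == 'ict':
        return ICTDataset
    return BertDataset if dataset_type == 'standard_bert' else REALMDataset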
......@@ -15,7 +15,7 @@ from megatron.data.dataset_utils import create_masked_lm_predictions, pad_and_co
#qa_nlp = spacy.load('en_core_web_lg')
class RealmDataset(BertDataset):
class REALMDataset(BertDataset):
"""Dataset containing simple masked sentences for masked language modeling.
The dataset should yield sentences just like the regular BertDataset
......@@ -28,7 +28,7 @@ class RealmDataset(BertDataset):
def __init__(self, name, indexed_dataset, data_prefix,
num_epochs, max_num_samples, masked_lm_prob,
max_seq_length, short_seq_prob, seed):
super(RealmDataset, self).__init__(name, indexed_dataset, data_prefix,
super(REALMDataset, self).__init__(name, indexed_dataset, data_prefix,
num_epochs, max_num_samples, masked_lm_prob,
max_seq_length, short_seq_prob, seed)
self.build_sample_fn = build_simple_training_sample
......@@ -81,7 +81,7 @@ def spacy_ner(block_text):
candidates['answers'] = answers
class InverseClozeDataset(Dataset):
class ICTDataset(Dataset):
"""Dataset containing sentences and their blocks for an inverse cloze task."""
def __init__(self, name, block_dataset, title_dataset, data_prefix,
num_epochs, max_num_samples, max_seq_length,
......
......@@ -14,6 +14,7 @@
# limitations under the License.
from .distributed import *
from .bert_model import BertModel, ICTBertModel, REALMBertModel, REALMRetriever
from .bert_model import BertModel
from megatron.model.realm_model import ICTBertModel, REALMRetriever, REALMBertModel
from .gpt2_model import GPT2Model
from .utils import get_params_for_weight_decay_optimization
......@@ -15,14 +15,9 @@
"""BERT model."""
import pickle
import numpy as np
import torch
import torch.nn.functional as F
from megatron import get_args
from megatron.data.realm_index import detach
from megatron.model.language_model import parallel_lm_logits
from megatron.model.language_model import get_language_model
from megatron.model.transformer import LayerNorm
......@@ -224,198 +219,3 @@ class BertModel(MegatronModule):
state_dict[self._ict_head_key], strict=strict)
class REALMBertModel(MegatronModule):
def __init__(self, retriever):
super(REALMBertModel, self).__init__()
bert_args = dict(
num_tokentypes=1,
add_binary_head=False,
parallel_output=True
)
self.lm_model = BertModel(**bert_args)
self._lm_key = 'realm_lm'
self.retriever = retriever
self._retriever_key = 'retriever'
def forward(self, tokens, attention_mask):
# [batch_size x 5 x seq_length]
top5_block_tokens, top5_block_attention_mask = self.retriever.retrieve_evidence_blocks(tokens, attention_mask)
batch_size = tokens.shape[0]
seq_length = top5_block_tokens.shape[2]
top5_block_tokens = torch.cuda.LongTensor(top5_block_tokens).reshape(-1, seq_length)
top5_block_attention_mask = torch.cuda.LongTensor(top5_block_attention_mask).reshape(-1, seq_length)
# [batch_size x 5 x embed_size]
true_model = self.retriever.ict_model.module.module
fresh_block_logits = true_model.embed_block(top5_block_tokens, top5_block_attention_mask).reshape(batch_size, 5, -1)
# [batch_size x embed_size x 1]
query_logits = true_model.embed_query(tokens, attention_mask).unsqueeze(2)
# [batch_size x 5]
fresh_block_scores = torch.matmul(fresh_block_logits, query_logits).squeeze()
block_probs = F.softmax(fresh_block_scores, dim=1)
# [batch_size * 5 x seq_length]
tokens = torch.stack([tokens.unsqueeze(1)] * 5, dim=1).reshape(-1, seq_length)
attention_mask = torch.stack([attention_mask.unsqueeze(1)] * 5, dim=1).reshape(-1, seq_length)
# [batch_size * 5 x 2 * seq_length]
all_tokens = torch.cat((tokens, top5_block_tokens), axis=1)
all_attention_mask = torch.cat((attention_mask, top5_block_attention_mask), axis=1)
all_token_types = torch.zeros(all_tokens.shape).type(torch.int64).cuda()
# [batch_size x 5 x 2 * seq_length x vocab_size]
lm_logits, _ = self.lm_model.forward(all_tokens, all_attention_mask, all_token_types)
lm_logits = lm_logits.reshape(batch_size, 5, 2 * seq_length, -1)
return lm_logits, block_probs
def state_dict_for_save_checkpoint(self, destination=None, prefix='',
keep_vars=False):
"""For easy load when model is combined with other heads,
add an extra key."""
state_dict_ = {}
state_dict_[self._lm_key] = self.lm_model.state_dict_for_save_checkpoint(destination, prefix, keep_vars)
return state_dict_
class REALMRetriever(MegatronModule):
"""Retriever which uses a pretrained ICTBertModel and a HashedIndex"""
def __init__(self, ict_model, ict_dataset, block_data, hashed_index, top_k=5):
super(REALMRetriever, self).__init__()
self.ict_model = ict_model
self.ict_dataset = ict_dataset
self.block_data = block_data
self.hashed_index = hashed_index
self.top_k = top_k
def retrieve_evidence_blocks_text(self, query_text):
"""Get the top k evidence blocks for query_text in text form"""
print("-" * 100)
print("Query: ", query_text)
padless_max_len = self.ict_dataset.max_seq_length - 2
query_tokens = self.ict_dataset.encode_text(query_text)[:padless_max_len]
query_tokens, query_pad_mask = self.ict_dataset.concat_and_pad_tokens(query_tokens)
query_tokens = torch.cuda.LongTensor(np.array(query_tokens).reshape(1, -1))
query_pad_mask = torch.cuda.LongTensor(np.array(query_pad_mask).reshape(1, -1))
top5_block_tokens, _ = self.retrieve_evidence_blocks(query_tokens, query_pad_mask)
for i, block in enumerate(top5_block_tokens[0]):
block_text = self.ict_dataset.decode_tokens(block)
print('\n > Block {}: {}'.format(i, block_text))
def retrieve_evidence_blocks(self, query_tokens, query_pad_mask):
"""Embed blocks to be used in a forward pass"""
with torch.no_grad():
true_model = self.ict_model.module.module
query_embeds = detach(true_model.embed_query(query_tokens, query_pad_mask))
_, block_indices = self.hashed_index.search_mips_index(query_embeds, top_k=self.top_k, reconstruct=False)
all_top5_tokens, all_top5_pad_masks = [], []
for indices in block_indices:
# [k x meta_dim]
top5_metas = np.array([self.block_data.meta_data[idx] for idx in indices])
top5_block_data = [self.ict_dataset.get_block(*block_meta) for block_meta in top5_metas]
top5_tokens, top5_pad_masks = zip(*top5_block_data)
all_top5_tokens.append(np.array(top5_tokens))
all_top5_pad_masks.append(np.array(top5_pad_masks))
# [batch_size x k x seq_length]
return np.array(all_top5_tokens), np.array(all_top5_pad_masks)
class ICTBertModel(MegatronModule):
"""Bert-based module for Inverse Cloze task."""
def __init__(self,
ict_head_size,
num_tokentypes=1,
parallel_output=True,
only_query_model=False,
only_block_model=False):
super(ICTBertModel, self).__init__()
bert_args = dict(
num_tokentypes=num_tokentypes,
add_binary_head=False,
ict_head_size=ict_head_size,
parallel_output=parallel_output
)
assert not (only_block_model and only_query_model)
self.use_block_model = not only_query_model
self.use_query_model = not only_block_model
if self.use_query_model:
# this model embeds (pseudo-)queries - Embed_input in the paper
self.query_model = BertModel(**bert_args)
self._query_key = 'question_model'
if self.use_block_model:
# this model embeds evidence blocks - Embed_doc in the paper
self.block_model = BertModel(**bert_args)
self._block_key = 'context_model'
def forward(self, query_tokens, query_attention_mask, block_tokens, block_attention_mask, only_query=False, only_block=False):
"""Run a forward pass for each of the models and compute the similarity scores."""
if only_query:
return self.embed_query(query_tokens, query_attention_mask)
if only_block:
return self.embed_block(block_tokens, block_attention_mask)
query_logits = self.embed_query(query_tokens, query_attention_mask)
block_logits = self.embed_block(block_tokens, block_attention_mask)
# [batch x embed] * [embed x batch]
retrieval_scores = query_logits.matmul(torch.transpose(block_logits, 0, 1))
return retrieval_scores
def embed_query(self, query_tokens, query_attention_mask):
"""Embed a batch of tokens using the query model"""
if self.use_query_model:
query_types = torch.zeros(query_tokens.shape).type(torch.int64).cuda()
query_ict_logits, _ = self.query_model.forward(query_tokens, query_attention_mask, query_types)
return query_ict_logits
else:
raise ValueError("Cannot embed query without query model.")
def embed_block(self, block_tokens, block_attention_mask):
"""Embed a batch of tokens using the block model"""
if self.use_block_model:
block_types = torch.zeros(block_tokens.shape).type(torch.int64).cuda()
block_ict_logits, _ = self.block_model.forward(block_tokens, block_attention_mask, block_types)
return block_ict_logits
else:
raise ValueError("Cannot embed block without block model.")
def state_dict_for_save_checkpoint(self, destination=None, prefix='', keep_vars=False):
"""Save dict with state dicts of each of the models."""
state_dict_ = {}
if self.use_query_model:
state_dict_[self._query_key] \
= self.query_model.state_dict_for_save_checkpoint(
destination, prefix, keep_vars)
if self.use_block_model:
state_dict_[self._block_key] \
= self.block_model.state_dict_for_save_checkpoint(
destination, prefix, keep_vars)
return state_dict_
def load_state_dict(self, state_dict, strict=True):
"""Load the state dicts of each of the models"""
if self.use_query_model:
print("Loading ICT query model", flush=True)
self.query_model.load_state_dict(
state_dict[self._query_key], strict=strict)
if self.use_block_model:
print("Loading ICT block model", flush=True)
self.block_model.load_state_dict(
state_dict[self._block_key], strict=strict)
import numpy as np
import torch
import torch.nn.functional as F
from megatron.checkpointing import load_checkpoint
from megatron.data.realm_index import detach
from megatron.model import BertModel
from megatron.module import MegatronModule
class REALMBertModel(MegatronModule):
def __init__(self, retriever):
super(REALMBertModel, self).__init__()
bert_args = dict(
num_tokentypes=1,
add_binary_head=False,
parallel_output=True
)
self.lm_model = BertModel(**bert_args)
load_checkpoint(self.lm_model, optimizer=None, lr_scheduler=None)
self._lm_key = 'realm_lm'
self.retriever = retriever
self._retriever_key = 'retriever'
def forward(self, tokens, attention_mask):
# [batch_size x 5 x seq_length]
top5_block_tokens, top5_block_attention_mask = self.retriever.retrieve_evidence_blocks(tokens, attention_mask)
batch_size = tokens.shape[0]
seq_length = top5_block_tokens.shape[2]
top5_block_tokens = torch.cuda.LongTensor(top5_block_tokens).reshape(-1, seq_length)
top5_block_attention_mask = torch.cuda.LongTensor(top5_block_attention_mask).reshape(-1, seq_length)
# [batch_size x 5 x embed_size]
true_model = self.retriever.ict_model.module.module
fresh_block_logits = true_model.embed_block(top5_block_tokens, top5_block_attention_mask).reshape(batch_size, 5, -1)
# [batch_size x embed_size x 1]
query_logits = true_model.embed_query(tokens, attention_mask).unsqueeze(2)
# [batch_size x 5]
fresh_block_scores = torch.matmul(fresh_block_logits, query_logits).squeeze()
block_probs = F.softmax(fresh_block_scores, dim=1)
# [batch_size * 5 x seq_length]
tokens = torch.stack([tokens.unsqueeze(1)] * 5, dim=1).reshape(-1, seq_length)
attention_mask = torch.stack([attention_mask.unsqueeze(1)] * 5, dim=1).reshape(-1, seq_length)
# [batch_size * 5 x 2 * seq_length]
all_tokens = torch.cat((tokens, top5_block_tokens), axis=1)
all_attention_mask = torch.cat((attention_mask, top5_block_attention_mask), axis=1)
all_token_types = torch.zeros(all_tokens.shape).type(torch.int64).cuda()
# [batch_size x 5 x 2 * seq_length x vocab_size]
lm_logits, _ = self.lm_model.forward(all_tokens, all_attention_mask, all_token_types)
lm_logits = lm_logits.reshape(batch_size, 5, 2 * seq_length, -1)
return lm_logits, block_probs
def state_dict_for_save_checkpoint(self, destination=None, prefix='',
keep_vars=False):
"""For easy load when model is combined with other heads,
add an extra key."""
state_dict_ = {}
state_dict_[self._lm_key] = self.lm_model.state_dict_for_save_checkpoint(destination, prefix, keep_vars)
state_dict_[self._retriever_key] = self.retriever.state_dict_for_save_checkpoint(destination, prefix, keep_vars)
return state_dict_
def load_state_dict(self, state_dict, strict=True):
"""Load the state dicts of each of the models"""
self.lm_model.load_state_dict(state_dict[self._lm_key], strict)
self.retriever.load_state_dict(state_dict[self._retriever_key], strict)
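def _example_realm_forward(retriever):
    """Hypothetical usage sketch, not part of the original module: shows the
    shapes coming out of REALMBertModel.forward above. Assumes the retriever
    was built with top_k=5, since forward() hard-codes the factor 5, and that
    megatron args plus a --bert-load checkpoint are available for __init__."""
    batch_size, seq_length = 2, 288
    tokens = torch.randint(0, 30000, (batch_size, seq_length), dtype=torch.int64).cuda()
    attention_mask = torch.ones(batch_size, seq_length, dtype=torch.int64).cuda()
    model = REALMBertModel(retriever)
    lm_logits, block_probs = model(tokens, attention_mask)
    # lm_logits:   [batch_size, 5, 2 * seq_length, vocab_size]
    # block_probs: [batch_size, 5]
    return lm_logits, block_probs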
class REALMRetriever(MegatronModule):
"""Retriever which uses a pretrained ICTBertModel and a HashedIndex"""
def __init__(self, ict_model, ict_dataset, block_data, hashed_index, top_k=5):
super(REALMRetriever, self).__init__()
self.ict_model = ict_model
self.ict_dataset = ict_dataset
self.block_data = block_data
self.hashed_index = hashed_index
self.top_k = top_k
self._ict_key = 'ict_model'
def retrieve_evidence_blocks_text(self, query_text):
"""Get the top k evidence blocks for query_text in text form"""
print("-" * 100)
print("Query: ", query_text)
padless_max_len = self.ict_dataset.max_seq_length - 2
query_tokens = self.ict_dataset.encode_text(query_text)[:padless_max_len]
query_tokens, query_pad_mask = self.ict_dataset.concat_and_pad_tokens(query_tokens)
query_tokens = torch.cuda.LongTensor(np.array(query_tokens).reshape(1, -1))
query_pad_mask = torch.cuda.LongTensor(np.array(query_pad_mask).reshape(1, -1))
top5_block_tokens, _ = self.retrieve_evidence_blocks(query_tokens, query_pad_mask)
for i, block in enumerate(top5_block_tokens[0]):
block_text = self.ict_dataset.decode_tokens(block)
print('\n > Block {}: {}'.format(i, block_text))
def retrieve_evidence_blocks(self, query_tokens, query_pad_mask):
"""Embed blocks to be used in a forward pass"""
with torch.no_grad():
true_model = self.ict_model.module.module
query_embeds = detach(true_model.embed_query(query_tokens, query_pad_mask))
_, block_indices = self.hashed_index.search_mips_index(query_embeds, top_k=self.top_k, reconstruct=False)
all_top5_tokens, all_top5_pad_masks = [], []
for indices in block_indices:
# [k x meta_dim]
top5_metas = np.array([self.block_data.meta_data[idx] for idx in indices])
top5_block_data = [self.ict_dataset.get_block(*block_meta) for block_meta in top5_metas]
top5_tokens, top5_pad_masks = zip(*top5_block_data)
all_top5_tokens.append(np.array(top5_tokens))
all_top5_pad_masks.append(np.array(top5_pad_masks))
# [batch_size x k x seq_length]
return np.array(all_top5_tokens), np.array(all_top5_pad_masks)
def state_dict_for_save_checkpoint(self, destination=None, prefix='',
keep_vars=False):
"""For easy load when model is combined with other heads,
add an extra key."""
state_dict_ = {}
state_dict_[self._ict_key] = self.ict_model.state_dict_for_save_checkpoint(destination, prefix, keep_vars)
return state_dict_
def load_state_dict(self, state_dict, strict=True):
"""Load the state dicts of each of the models"""
self.ict_model.load_state_dict(state_dict[self._ict_key], strict)
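def _example_retrieval(retriever, query_tokens, query_pad_mask):
    """Hypothetical usage sketch, not part of the original module.
    query_tokens / query_pad_mask: cuda LongTensors of shape
    [batch_size, seq_length], built the same way
    retrieve_evidence_blocks_text() builds them above."""
    block_tokens, block_pad_masks = retriever.retrieve_evidence_blocks(
        query_tokens, query_pad_mask)
    # Both are numpy arrays of shape [batch_size, top_k, seq_length]; the
    # token ids can be decoded with retriever.ict_dataset.decode_tokens().
    return block_tokens, block_pad_masks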
class ICTBertModel(MegatronModule):
"""Bert-based module for Inverse Cloze task."""
def __init__(self,
ict_head_size,
num_tokentypes=1,
parallel_output=True,
only_query_model=False,
only_block_model=False):
super(ICTBertModel, self).__init__()
bert_args = dict(
num_tokentypes=num_tokentypes,
add_binary_head=False,
ict_head_size=ict_head_size,
parallel_output=parallel_output
)
assert not (only_block_model and only_query_model)
self.use_block_model = not only_query_model
self.use_query_model = not only_block_model
if self.use_query_model:
# this model embeds (pseudo-)queries - Embed_input in the paper
self.query_model = BertModel(**bert_args)
self._query_key = 'question_model'
if self.use_block_model:
# this model embeds evidence blocks - Embed_doc in the paper
self.block_model = BertModel(**bert_args)
self._block_key = 'context_model'
def forward(self, query_tokens, query_attention_mask, block_tokens, block_attention_mask, only_query=False, only_block=False):
"""Run a forward pass for each of the models and compute the similarity scores."""
if only_query:
return self.embed_query(query_tokens, query_attention_mask)
if only_block:
return self.embed_block(block_tokens, block_attention_mask)
query_logits = self.embed_query(query_tokens, query_attention_mask)
block_logits = self.embed_block(block_tokens, block_attention_mask)
# [batch x embed] * [embed x batch]
retrieval_scores = query_logits.matmul(torch.transpose(block_logits, 0, 1))
return retrieval_scores
def embed_query(self, query_tokens, query_attention_mask):
"""Embed a batch of tokens using the query model"""
if self.use_query_model:
query_types = torch.zeros(query_tokens.shape).type(torch.int64).cuda()
query_ict_logits, _ = self.query_model.forward(query_tokens, query_attention_mask, query_types)
return query_ict_logits
else:
raise ValueError("Cannot embed query without query model.")
def embed_block(self, block_tokens, block_attention_mask):
"""Embed a batch of tokens using the block model"""
if self.use_block_model:
block_types = torch.zeros(block_tokens.shape).type(torch.int64).cuda()
block_ict_logits, _ = self.block_model.forward(block_tokens, block_attention_mask, block_types)
return block_ict_logits
else:
raise ValueError("Cannot embed block without block model.")
def state_dict_for_save_checkpoint(self, destination=None, prefix='', keep_vars=False):
"""Save dict with state dicts of each of the models."""
state_dict_ = {}
if self.use_query_model:
state_dict_[self._query_key] \
= self.query_model.state_dict_for_save_checkpoint(
destination, prefix, keep_vars)
if self.use_block_model:
state_dict_[self._block_key] \
= self.block_model.state_dict_for_save_checkpoint(
destination, prefix, keep_vars)
return state_dict_
def load_state_dict(self, state_dict, strict=True):
"""Load the state dicts of each of the models"""
if self.use_query_model:
print("Loading ICT query model", flush=True)
self.query_model.load_state_dict(
state_dict[self._query_key], strict=strict)
if self.use_block_model:
print("Loading ICT block model", flush=True)
self.block_model.load_state_dict(
state_dict[self._block_key], strict=strict)
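def _example_ict_scores(ict_model, query_tokens, query_pad_mask,
                        block_tokens, block_pad_mask):
    """Hypothetical usage sketch, not part of the original module: in-batch
    retrieval scores from an (unwrapped) ICTBertModel, as used for the ICT
    training objective. All inputs are cuda LongTensors of shape
    [batch_size, seq_length]."""
    scores = ict_model(query_tokens, query_pad_mask,
                       block_tokens, block_pad_mask)
    # scores: [batch_size, batch_size]; scores[i, j] is the similarity of
    # query i to evidence block j, so the diagonal holds the positive pairs.
    return scores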
......@@ -44,7 +44,7 @@ def model_provider():
hashed_index = FaissMIPSIndex(index_type='flat_l2', embed_size=128)
hashed_index.add_block_embed_data(all_block_data)
retriever = REALMRetriever(ict_model, ict_dataset, all_block_data, hashed_index)
retriever = REALMRetriever(ict_model, ict_dataset, all_block_data, hashed_index, args.block_top_k)
# TODO: REALMBertModel should accept a path to a pretrained bert-base
model = REALMBertModel(retriever)
......