Commit 6c0a5bd8 authored by Neel Kant

Update and test misc functionality

parent 1eccfc94
@@ -29,7 +29,8 @@ class HashedIndex(object):
         np.random.seed(seed)
         self.block_data = defaultdict(list)
         self.hash_data = defaultdict(list)
-        self.hash_matrix = np.random.rand(embed_size, int(num_buckets / 2))
+        hash_matrix = np.random.rand(embed_size, int(num_buckets / 2))
+        self.hash_matrix = hash_matrix / np.linalg.norm(hash_matrix, axis=0).reshape(1, -1)

     def state(self):
         state = {
@@ -47,7 +48,7 @@ class HashedIndex(object):
     def hash_embeds(self, embeds, block_data=None):
         """Hash a tensor of embeddings using a random projection matrix"""
-        embed_scores_pos = torch.matmul(embeds, torch.cuda.HalfTensor(self.hash_matrix))
+        embed_scores_pos = torch.matmul(embeds, torch.cuda.FloatTensor(self.hash_matrix))
         embed_scores = torch.cat((embed_scores_pos, -embed_scores_pos), axis=1)
         embed_hashes = detach(torch.argmax(embed_scores, axis=1))
@@ -62,7 +63,7 @@ class HashedIndex(object):
         for idx, embed in zip(block_indices, block_embeds):
             if not allow_overwrite and int(idx) in self.block_data:
                 raise ValueError("Attempted to overwrite a read-only HashedIndex")
-            self.block_data[int(idx)] = embed
+            self.block_data[int(idx)] = np.float16(embed)

     def save_shard(self, rank):
         dir_name = 'block_hash_data'
@@ -92,7 +93,8 @@ class HashedIndex(object):
             for bucket, items in data['hash_data'].items():
                 self.hash_data[bucket].extend(items)

-        with open('block_hash_data.pkl', 'wb') as final_file:
+        args = get_args()
+        with open(args.hash_data_path, 'wb') as final_file:
            pickle.dump(self.state(), final_file)
        shutil.rmtree(dir_name, ignore_errors=True)
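For reference, a minimal NumPy sketch of the bucketing scheme that hash_embeds implements after the change above: the hash matrix now has unit-norm columns (one random hyperplane per column), and concatenating the projection scores with their negations before the argmax assigns each embedding to one of num_buckets signed half-spaces. Sizes below are illustrative, not taken from the repo.

    import numpy as np

    embed_size, num_buckets = 128, 4096                          # illustrative sizes
    rng = np.random.RandomState(0)
    hash_matrix = rng.rand(embed_size, num_buckets // 2)
    # Normalize each column, mirroring the new constructor.
    hash_matrix = hash_matrix / np.linalg.norm(hash_matrix, axis=0).reshape(1, -1)

    def hash_embeds_np(embeds):
        """Map [batch, embed_size] embeddings to bucket ids in [0, num_buckets)."""
        scores_pos = embeds @ hash_matrix                        # [batch, num_buckets // 2]
        scores = np.concatenate((scores_pos, -scores_pos), axis=1)  # signed half-spaces
        return np.argmax(scores, axis=1)

    buckets = hash_embeds_np(rng.rand(4, embed_size))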
@@ -119,7 +121,7 @@ def test_retriever():
     initialize_megatron(extra_args_provider=None,
                         args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
     args = get_args()
-    model = load_ict_checkpoint()
+    model = load_ict_checkpoint(only_block_model=True)
     model.eval()
     dataset = get_ict_dataset()
     hashed_index = HashedIndex.load_from_file(args.hash_data_path)
@@ -158,11 +160,11 @@ def main():
     initialize_megatron(extra_args_provider=None,
                         args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
     args = get_args()
-    model = load_ict_checkpoint()
+    model = load_ict_checkpoint(only_block_model=True, no_grad=True)
     model.eval()
     dataset = get_ict_dataset()
     data_iter = iter(get_dataloader(dataset))
-    hashed_index = HashedIndex(embed_size=128, num_buckets=2048)
+    hashed_index = HashedIndex(embed_size=128, num_buckets=4096)

     i = 0
     while True:
@@ -172,10 +174,8 @@ def main():
         except:
             break

-        actual_model = model.module.module
         block_indices = detach(block_indices)
-        block_logits = actual_model.embed_block(block_tokens, block_pad_mask)
+        block_logits = model(None, None, block_tokens, block_pad_mask, only_block=True)

         hashed_index.hash_embeds(block_logits, block_indices)
         hashed_index.assign_block_embeds(block_indices[:,3], detach(block_logits))
@@ -193,9 +193,9 @@ def main():
     hashed_index.clear()

-def load_ict_checkpoint():
+def load_ict_checkpoint(only_query_model=False, only_block_model=False, no_grad=False):
     args = get_args()
-    model = get_model(model_provider)
+    model = get_model(lambda: model_provider(only_query_model, only_block_model))

     if isinstance(model, torchDDP):
         model = model.module
@@ -210,6 +210,14 @@ def load_ict_checkpoint():
                  torch.distributed.get_rank(), checkpoint_name))

     state_dict = torch.load(checkpoint_name, map_location='cpu')
+    if only_query_model:
+        state_dict['model'].pop('context_model')
+    if only_block_model:
+        state_dict['model'].pop('question_model')
+    if no_grad:
+        with torch.no_grad():
+            model.load_state_dict(state_dict['model'])
+    else:
+        model.load_state_dict(state_dict['model'])
-    model.load_state_dict(state_dict['model'])
     torch.distributed.barrier()
@@ -261,4 +269,4 @@ def get_dataloader(dataset):

 if __name__ == "__main__":
-    test_retriever()
+    main()
@@ -131,8 +131,8 @@ class InverseClozeDataset(Dataset):
                 'the indices on rank 0 ...'.format(indexmap_filename))

             # Make sure the types match the helpers input types.
-            assert self.context_dataset.doc_idx.dtype == np.int64
-            assert self.context_dataset.sizes.dtype == np.int32
+            assert self.block_dataset.doc_idx.dtype == np.int64
+            assert self.block_dataset.sizes.dtype == np.int32

             # Build samples mapping
             verbose = torch.distributed.get_rank() == 0
@@ -140,9 +140,9 @@ class InverseClozeDataset(Dataset):
                 print_rank_0(' > building samples index mapping for {} ...'.format(
                     self.name))
                 samples_mapping = helpers.build_blocks_mapping(
-                    self.context_dataset.doc_idx,
-                    self.context_dataset.sizes,
-                    self.titles_dataset.sizes,
+                    self.block_dataset.doc_idx,
+                    self.block_dataset.sizes,
+                    self.title_dataset.sizes,
                     num_epochs,
                     max_num_samples,
                     self.max_seq_length-3,  # account for added tokens
...
@@ -47,7 +47,7 @@ def build_simple_training_sample(sample, target_seq_length, max_seq_length,
                                              masked_labels, pad_id, max_seq_length)

     # REALM true sequence length is twice as long but none of that is to be predicted with LM
-    loss_mask_np = np.concatenate((loss_mask_np, np.ones(loss_mask_np.shape)), -1).astype(np.int64)
+    # loss_mask_np = np.concatenate((loss_mask_np, np.ones(loss_mask_np.shape)), -1).astype(np.int64)

     train_sample = {
         'tokens': tokens_np,
...
@@ -126,12 +126,18 @@ class BertModel(MegatronModule):
         add_pooler = self.add_binary_head or self.add_ict_head
         scaled_init_method = scaled_init_method_normal(args.init_method_std,
                                                        args.num_layers)
+        max_pos_embeds = None
+        if not add_binary_head and ict_head_size is None:
+            max_pos_embeds = 2 * args.seq_length
+
         self.language_model, self._language_model_key = get_language_model(
             attention_mask_func=bert_attention_mask_func,
             num_tokentypes=num_tokentypes,
             add_pooler=add_pooler,
             init_method=init_method,
-            scaled_init_method=scaled_init_method)
+            scaled_init_method=scaled_init_method,
+            max_pos_embeds=max_pos_embeds)

         if not self.add_ict_head:
             self.lm_head = BertLMHead(
@@ -218,6 +224,8 @@ class BertModel(MegatronModule):

 class REALMBertModel(MegatronModule):
+    # TODO: load BertModel checkpoint
     def __init__(self, retriever):
         super(REALMBertModel, self).__init__()
         bert_args = dict(
@@ -241,10 +249,11 @@ class REALMBertModel(MegatronModule):
         top5_block_attention_mask = torch.cuda.LongTensor(top5_block_attention_mask).reshape(-1, seq_length)

         # [batch_size x 5 x embed_size]
-        fresh_block_logits = self.retriever.ict_model.module.module.embed_block(top5_block_tokens, top5_block_attention_mask).reshape(batch_size, 5, -1)
+        fresh_block_logits = self.retriever.ict_model(None, None, top5_block_tokens, top5_block_attention_mask, only_block=True).reshape(batch_size, 5, -1)
+        # fresh_block_logits.register_hook(lambda x: print("fresh block: ", x.shape, flush=True))

         # [batch_size x embed_size x 1]
-        query_logits = self.retriever.ict_model.module.module.embed_query(tokens, attention_mask).unsqueeze(2)
+        query_logits = self.retriever.ict_model(tokens, attention_mask, None, None, only_query=True).unsqueeze(2)

         # [batch_size x 5]
@@ -282,6 +291,7 @@ class REALMRetriever(MegatronModule):
         self.ict_model = ict_model
         self.ict_dataset = ict_dataset
         self.hashed_index = hashed_index
+        self.top_k = top_k

     def retrieve_evidence_blocks_text(self, query_text):
         """Get the top k evidence blocks for query_text in text form"""
@@ -300,16 +310,25 @@ class REALMRetriever(MegatronModule):
             print('\n > Block {}: {}'.format(i, block_text))

     def retrieve_evidence_blocks(self, query_tokens, query_pad_mask):
-        query_embeds = self.ict_model.module.module.embed_query(query_tokens, query_pad_mask)
+        """Embed blocks to be used in a forward pass"""
+        query_embeds = self.ict_model(query_tokens, query_pad_mask, None, None, only_query=True)
         query_hashes = self.hashed_index.hash_embeds(query_embeds)
         block_buckets = [self.hashed_index.get_block_bucket(hash) for hash in query_hashes]

-        block_embeds = [torch.cuda.HalfTensor(np.array([self.hashed_index.get_block_embed(arr[3])
+        for j, bucket in enumerate(block_buckets):
+            if len(bucket) < 5:
+                for i in range(len(block_buckets)):
+                    if len(block_buckets[i]) > 5:
+                        block_buckets[j] = block_buckets[i].copy()
+
+        # [batch_size x max_bucket_population x embed_size]
+        block_embeds = [torch.cuda.FloatTensor(np.array([self.hashed_index.get_block_embed(arr[3])
                                                          for arr in bucket])) for bucket in block_buckets]

         all_top5_tokens, all_top5_pad_masks = [], []
         for query_embed, embed_tensor, bucket in zip(query_embeds, block_embeds, block_buckets):
-            retrieval_scores = query_embed.matmul(torch.transpose(embed_tensor, 0, 1))
+            retrieval_scores = query_embed.matmul(torch.transpose(embed_tensor.reshape(-1, query_embed.size()[0]), 0, 1))
+            print(retrieval_scores.shape, flush=True)
             top5_vals, top5_indices = torch.topk(retrieval_scores, k=5, sorted=True)
             top5_start_end_doc = [bucket[idx][:3] for idx in top5_indices.squeeze()]
@@ -354,8 +373,16 @@ class ICTBertModel(MegatronModule):
             self.block_model = BertModel(**bert_args)
             self._block_key = 'context_model'

-    def forward(self, query_tokens, query_attention_mask, block_tokens, block_attention_mask):
+    def forward(self, query_tokens, query_attention_mask, block_tokens, block_attention_mask, only_query=False, only_block=False):
         """Run a forward pass for each of the models and compute the similarity scores."""
+        if only_query:
+            return self.embed_query(query_tokens, query_attention_mask)
+
+        if only_block:
+            return self.embed_block(block_tokens, block_attention_mask)
+
         query_logits = self.embed_query(query_tokens, query_attention_mask)
         block_logits = self.embed_block(block_tokens, block_attention_mask)
@@ -399,9 +426,11 @@ class ICTBertModel(MegatronModule):
     def load_state_dict(self, state_dict, strict=True):
         """Load the state dicts of each of the models"""
         if self.use_query_model:
+            print("Loading ICT query model", flush=True)
             self.query_model.load_state_dict(
                 state_dict[self._query_key], strict=strict)

         if self.use_block_model:
+            print("Loading ICT block model", flush=True)
             self.block_model.load_state_dict(
                 state_dict[self._block_key], strict=strict)
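A toy stand-in for the only_query / only_block dispatch added to ICTBertModel.forward above (all names and sizes below are illustrative, not from the repo): routing the single-tower calls through forward lets callers keep any DDP/FP16 wrapper in the call path instead of reaching into model.module.module.

    import torch

    class ICTStub(torch.nn.Module):
        """Toy module mirroring the only_query / only_block dispatch pattern."""
        def __init__(self, embed_size=8):
            super().__init__()
            self.proj = torch.nn.Linear(16, embed_size)

        def embed_query(self, tokens, mask=None):
            return self.proj(tokens)

        def embed_block(self, tokens, mask=None):
            return self.proj(tokens)

        def forward(self, query_tokens, query_mask, block_tokens, block_mask,
                    only_query=False, only_block=False):
            if only_query:
                return self.embed_query(query_tokens, query_mask)
            if only_block:
                return self.embed_block(block_tokens, block_mask)
            return (self.embed_query(query_tokens, query_mask),
                    self.embed_block(block_tokens, block_mask))

    model = ICTStub()  # in the real code this would be the wrapped model returned by get_model()
    query_logits = model(torch.randn(2, 16), None, None, None, only_query=True)
    block_logits = model(None, None, torch.randn(2, 16), None, only_block=True)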
@@ -45,7 +45,7 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,

 def get_language_model(attention_mask_func, num_tokentypes, add_pooler,
-                       init_method, scaled_init_method):
+                       init_method, scaled_init_method, max_pos_embeds=None):
     """Build language model and return along with the key to save."""

     # Language model.
@@ -55,7 +55,8 @@ def get_language_model(attention_mask_func, num_tokentypes, add_pooler,
         init_method=init_method,
         output_layer_init_method=scaled_init_method,
         num_tokentypes=num_tokentypes,
-        add_pooler=add_pooler)
+        add_pooler=add_pooler,
+        max_pos_embeds=max_pos_embeds)

     # key used for checkpoints.
     language_model_key = 'language_model'
@@ -266,7 +267,8 @@ class TransformerLanguageModel(MegatronModule):
                  init_method,
                  output_layer_init_method,
                  num_tokentypes=0,
-                 add_pooler=False):
+                 add_pooler=False,
+                 max_pos_embeds=None):
         super(TransformerLanguageModel, self).__init__()
         args = get_args()
@@ -275,10 +277,11 @@ class TransformerLanguageModel(MegatronModule):
         self.init_method = init_method
         self.add_pooler = add_pooler
+        max_pos_embeds = args.max_position_embeddings if max_pos_embeds is None else max_pos_embeds

         # Embeddings
         self.embedding = Embedding(self.hidden_size,
                                    args.padded_vocab_size,
-                                   args.max_position_embeddings,
+                                   max_pos_embeds,
                                    args.hidden_dropout,
                                    self.init_method,
                                    self.num_tokentypes)
...
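The max_pos_embeds plumbing above lets BertModel request a larger position-embedding table (2 * args.seq_length when it has neither a binary head nor an ICT head), since the REALM LM input is a query concatenated with a retrieved block and is therefore twice the nominal sequence length. A small sketch of the effect, with illustrative sizes:

    import torch

    seq_length, hidden_size = 256, 1024                           # illustrative sizes
    position_embeddings = torch.nn.Embedding(2 * seq_length, hidden_size)
    joint_len = 2 * seq_length                                    # query tokens + retrieved block tokens
    pos = position_embeddings(torch.arange(joint_len))            # out of range with only seq_length rows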
@@ -225,7 +225,7 @@ def backward_step(optimizer, model, loss):
     """Backward step."""
     args = get_args()
     timers = get_timers()
-    print("start backward", flush=True)
+    torch.cuda.synchronize()

     # Backward pass.
     optimizer.zero_grad()
@@ -250,6 +250,7 @@ def backward_step(optimizer, model, loss):
         else:
             optimizer.clip_master_grads(args.clip_grad)

+ran_backward_once = False

 def train_step(forward_step_func, data_iterator,
                model, optimizer, lr_scheduler):
@@ -262,11 +263,12 @@ def train_step(forward_step_func, data_iterator,
     loss, loss_reduced = forward_step_func(data_iterator, model)
     timers('forward').stop()

+    # Calculate gradients, reduce across processes, and clip.
     timers('backward').start()
     backward_step(optimizer, model, loss)
     timers('backward').stop()
-    # Calculate gradients, reduce across processes, and clip.

     # Update parameters.
     timers('optimizer').start()
     optimizer.step()
...
@@ -29,7 +29,7 @@ from megatron.utils import reduce_losses
 num_batches = 0

-def model_provider():
+def model_provider(only_query_model=False, only_block_model=False):
     """Build the model."""
     args = get_args()
     print_rank_0('building BERT models ...')
@@ -37,7 +37,9 @@ def model_provider():
     model = ICTBertModel(
         ict_head_size=128,
         num_tokentypes=2,
-        parallel_output=True)
+        parallel_output=True,
+        only_query_model=only_query_model,
+        only_block_model=only_block_model)

     return model
...
@@ -38,9 +38,10 @@ def model_provider():
     ict_model = load_ict_checkpoint()
     ict_dataset = get_ict_dataset()
-    hashed_index = HashedIndex.load_from_file('block_hash_data.pkl')
+    hashed_index = HashedIndex.load_from_file(args.hash_data_path)
     retriever = REALMRetriever(ict_model, ict_dataset, hashed_index)
+    # TODO: REALMBertModel should accept a path to a pretrained bert-base
     model = REALMBertModel(retriever)

     return model
@@ -74,7 +75,6 @@ def forward_step(data_iterator, model):
     # Get the batch.
     timers('batch generator').start()
     tokens, labels, loss_mask, pad_mask = get_batch(data_iterator)
-    labels = torch.cat((labels, labels), axis=-1)
     timers('batch generator').stop()

     # Forward model.
@@ -83,13 +83,17 @@ def forward_step(data_iterator, model):
     # P(y|x) = sum_z(P(y|z, x) * P(z|x))
     block_probs = block_probs.unsqueeze(2).unsqueeze(3).expand_as(lm_logits)
-    lm_logits = torch.sum(lm_logits * block_probs, dim=1)
+    #block_probs.register_hook(lambda x: print("block_probs: ", x.shape, flush=True))
+    lm_logits = torch.sum(lm_logits * block_probs, dim=1)[:, :labels.shape[1]]

     lm_loss_ = mpu.vocab_parallel_cross_entropy(lm_logits.contiguous().float(),
                                                 labels.contiguous())
     lm_loss = torch.sum(
         lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()

     reduced_loss = reduce_losses([lm_loss])
+    torch.cuda.synchronize()
     print(reduced_loss, flush=True)

     return lm_loss, {'lm_loss': reduced_loss[0]}
...
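A minimal shape sketch of the marginalization in forward_step above (sizes assumed for illustration): lm_logits holds one LM output per retrieved block, block_probs is broadcast over the sequence and vocabulary dimensions, and the sum over the block dimension implements P(y|x) = sum_z P(y|z, x) * P(z|x); the new [:, :labels.shape[1]] slice trims the doubled-length output back to the label length.

    import torch

    batch, k, seq, vocab = 2, 5, 8, 11                            # illustrative sizes
    lm_logits = torch.randn(batch, k, 2 * seq, vocab)             # one LM output per retrieved block
    block_probs = torch.softmax(torch.randn(batch, k), dim=1)     # P(z|x)

    block_probs = block_probs.unsqueeze(2).unsqueeze(3).expand_as(lm_logits)
    marginal = torch.sum(lm_logits * block_probs, dim=1)[:, :seq]
    assert marginal.shape == (batch, seq, vocab)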