Commit 20895f2c authored by Neel Kant

runs with new log loss but plateaus early

parent 51204a4d
......@@ -176,7 +176,7 @@ class AsyncIndexBuilder(IndexBuilder):
print(">>>>> No realm chkpt available", flush=True)
self.model = load_ict_checkpoint(only_block_model=True, no_grad=True, from_realm_chkpt=False)
self.model.eval()
self.dataloader = iter(get_one_epoch_dataloader(get_ict_dataset()))
self.dataloader = iter(get_one_epoch_dataloader(get_ict_dataset(), batch_size=128))
self.block_data = BlockData()
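# the async builder runs the frozen ICT block model (no_grad, eval) over one epoch of the ICT dataset
# in batches of 128 and accumulates the resulting block embeddings in BlockData for the MIPS index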
def send_index_ready_signal(self):
......
......@@ -150,6 +150,12 @@ class FaissMIPSIndex(object):
for j in range(block_indices.shape[1]):
fresh_indices[i, j] = self.id_map[block_indices[i, j]]
block_indices = fresh_indices
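# Faiss returns positions within its internal index; id_map translates them back to the original block ids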
args = get_args()
if args.rank == 0:
torch.save({'query_embeds': query_embeds,
'id_map': self.id_map,
'block_indices': block_indices,
'distances': distances}, 'search.data')
return distances, block_indices
# functions below are for ALSH, which currently isn't being used
......
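For orientation, a minimal standalone sketch of the exact inner-product search that FaissMIPSIndex (index_type='flat_ip') performs, assuming Faiss and NumPy are installed; block_embeds, query_embeds, and id_map_arr are illustrative stand-ins rather than names from this repo:

import numpy as np
import faiss

embed_size, num_blocks, top_k = 128, 1000, 5
block_embeds = np.random.rand(num_blocks, embed_size).astype(np.float32)
query_embeds = np.random.rand(4, embed_size).astype(np.float32)
id_map_arr = np.random.permutation(num_blocks)            # stand-in for the builder's id_map

index = faiss.IndexFlatIP(embed_size)                     # exact maximum-inner-product search
index.add(block_embeds)
distances, positions = index.search(query_embeds, top_k)  # each [num_queries x top_k]
block_ids = id_map_arr[positions]                         # map Faiss positions back to block ids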
......@@ -92,9 +92,9 @@ class REALMBertModel(MegatronModule):
self.retriever = retriever
self.top_k = self.retriever.top_k
self._retriever_key = 'retriever'
# self.eval()
def forward(self, tokens, attention_mask, query_block_indices, return_topk_block_tokens=False):
# print("\nNEW FORWARD", '-' * 100, flush=True)
dset = self.retriever.ict_dataset
det_tokens = detach(tokens)[0].tolist()
......@@ -112,7 +112,6 @@ class REALMBertModel(MegatronModule):
# text = dset.decode_tokens(det_tokens)
# print(text, flush=True)
# print("Token shape: ", tokens.shape, flush=True)
# [batch_size x k x seq_length]
topk_block_tokens, topk_block_attention_mask = self.retriever.retrieve_evidence_blocks(
......@@ -132,23 +131,21 @@ class REALMBertModel(MegatronModule):
# [batch_size x k x embed_size]
true_model = self.retriever.ict_model.module.module
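# reach through the wrapper modules (e.g. DDP / fp16 wrappers) to the underlying ICT model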
fresh_block_logits = mpu.checkpoint(true_model.embed_block, topk_block_tokens, topk_block_attention_mask)
fresh_block_logits = fresh_block_logits.reshape(batch_size, self.top_k, -1)
fresh_block_logits = fresh_block_logits.reshape(batch_size, self.top_k, -1).float()
# print('Fresh block logits shape: ', fresh_block_logits.shape, flush=True)
# [batch_size x 1 x embed_size]
query_logits = mpu.checkpoint(true_model.embed_query, tokens, attention_mask).unsqueeze(1)
# print('Query logits shape: ', query_logits.shape, flush=True)
query_logits = mpu.checkpoint(true_model.embed_query, tokens, attention_mask).unsqueeze(1).float()
# [batch_size x k]
fresh_block_scores = torch.matmul(query_logits, torch.transpose(fresh_block_logits, 1, 2)).squeeze()
# print('Block score shape: ', fresh_block_scores.shape, flush=True)
block_probs = F.softmax(fresh_block_scores, dim=1)
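# block_probs is P(z|x): a softmax over the k retrieved blocks' relevance scores (query/block inner products)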
# [batch_size * k x seq_length]
tokens = torch.stack([tokens.unsqueeze(1)] * self.top_k, dim=1).reshape(-1, seq_length)
#assert all(tokens[i] == tokens[0] for i in range(self.top_k))
#assert all(tokens[i] == tokens[self.top_k] for i in range(self.top_k, 2 * self.top_k))
#assert not any(tokens[i] == tokens[0] for i in range(self.top_k, batch_size * self.top_k))
# assert all(torch.equal(tokens[i], tokens[0]) for i in range(self.top_k))
# assert all(torch.equal(tokens[i], tokens[self.top_k]) for i in range(self.top_k, 2 * self.top_k))
# assert not any(torch.equal(tokens[i], tokens[0]) for i in range(self.top_k, batch_size * self.top_k))
attention_mask = torch.stack([attention_mask.unsqueeze(1)] * self.top_k, dim=1).reshape(-1, seq_length)
# [batch_size * k x 2 * seq_length]
......@@ -156,9 +153,6 @@ class REALMBertModel(MegatronModule):
all_tokens = torch.zeros(lm_input_batch_shape).long().cuda()
all_attention_mask = all_tokens.clone()
all_token_types = all_tokens.clone()
#all_tokens = torch.cat((tokens, topk_block_tokens), axis=1)
#all_attention_mask = torch.cat((attention_mask, topk_block_attention_mask), axis=1)
#all_token_types = torch.zeros(all_tokens.shape).type(torch.int64).cuda()
query_lengths = torch.sum(attention_mask, axis=1)
# all blocks (including null ones) will have two SEP tokens
......@@ -169,23 +163,40 @@ class REALMBertModel(MegatronModule):
# block body ends after the second SEP
block_ends = block_sep_indices[:, 1, 1] + 1
# block_lengths = torch.sum(topk_block_attention_mask, axis=1)
print('-' * 100)
for row_num in range(all_tokens.shape[0]):
q_len = query_lengths[row_num]
b_start = block_starts[row_num]
b_end = block_ends[row_num]
# new tokens = CLS + query + SEP + block + SEP
# new_tokens_length = q_len + b_end - b_start
new_tokens_length = q_len
new_tokens_length = q_len + b_end - b_start
# splice query and block tokens accordingly
all_tokens[row_num, :q_len] = tokens[row_num, :q_len]
# all_tokens[row_num, q_len:new_tokens_length] = topk_block_tokens[row_num, b_start:b_end]
all_tokens[row_num, q_len:new_tokens_length] = topk_block_tokens[row_num, b_start:b_end]
all_tokens[row_num, new_tokens_length:] = self.retriever.ict_dataset.pad_id
# print(dset.decode_tokens(detach(all_tokens[row_num]).tolist()), '\n', flush=True)
print(dset.decode_tokens(detach(all_tokens[row_num]).tolist()), '\n', flush=True)
all_attention_mask[row_num, :new_tokens_length] = 1
all_attention_mask[row_num, new_tokens_length:] = 0
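# each row now reads CLS + query + SEP + block body + SEP, right-padded with pad_id,
# and the attention mask covers only the spliced tokens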
print('-' * 100)
args = get_args()
if args.rank == 0:
torch.save({'lm_tokens': all_tokens,
'lm_attn_mask': all_attention_mask,
'query_tokens': tokens,
'query_attn_mask': attention_mask,
'query_logits': query_logits,
'block_tokens': topk_block_tokens,
'block_attn_mask': topk_block_attention_mask,
'block_logits': fresh_block_logits,
'block_probs': block_probs,
}, 'final_lm_inputs.data')
# assert all(torch.equal(all_tokens[i], all_tokens[0]) for i in range(self.top_k))
# assert all(torch.equal(all_attention_mask[i], all_attention_mask[0]) for i in range(self.top_k))
# [batch_size x k x 2 * seq_length x vocab_size]
lm_logits, _ = self.lm_model.forward(all_tokens, all_attention_mask, all_token_types)
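# the LM scores each (query, retrieved block) concatenation; forward_step later marginalizes
# these logits over blocks using block_probs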
......@@ -261,7 +272,7 @@ class REALMRetriever(MegatronModule):
true_model = self.ict_model
# print("true model: ", true_model, flush=True)
query_embeds = self.ict_model(query_tokens, query_pad_mask, None, None, only_query=True)
query_embeds = true_model.embed_query(query_tokens, query_pad_mask)
_, block_indices = self.hashed_index.search_mips_index(query_embeds, top_k=self.top_k, reconstruct=False)
all_topk_tokens, all_topk_pad_masks = [], []
......
......@@ -242,6 +242,8 @@ def setup_model_and_optimizer(model_provider_func):
def backward_step(optimizer, model, loss):
"""Backward step."""
# if args.rank == 0:
# torch.save(lick)
args = get_args()
timers = get_timers()
torch.cuda.synchronize()
......@@ -392,7 +394,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
recv_handle = torch.distributed.broadcast(INDEX_READY, args.max_training_rank, group=get_gloo_comm_group(), async_op=True)
last_reload_iteration = iteration
while iteration < args.train_iters:
if args.max_training_rank is not None and iteration >= last_reload_iteration + 500:
if args.max_training_rank is not None and iteration >= last_reload_iteration + 100:
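# only consider reloading the block index once at least 100 iterations have passed since the
# last reload and the async INDEX_READY broadcast above has completed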
if recv_handle.is_completed():
# should add a check that INDEX_READY == 1, though nothing else should be arriving on this broadcast
true_model = model
......
......@@ -14,6 +14,9 @@
# limitations under the License.
"""Pretrain BERT for Inverse Cloze Task"""
import sys
import numpy as np
import torch
import torch.nn.functional as F
......@@ -29,6 +32,7 @@ from megatron.training import pretrain
from megatron.utils import reduce_losses, report_memory
from megatron import mpu
from indexer import initialize_and_run_async_megatron
from megatron.mpu.initialize import get_data_parallel_group
num_batches = 0
......@@ -44,7 +48,6 @@ def model_provider():
ict_model = load_ict_checkpoint(from_realm_chkpt=False)
ict_dataset = get_ict_dataset(use_titles=False)
all_block_data = BlockData.load_from_file(args.block_data_path)
# hashed_index = RandProjectionLSHIndex.load_from_file(args.block_index_path)
hashed_index = FaissMIPSIndex(index_type='flat_ip', embed_size=128, use_gpu=args.faiss_use_gpu)
hashed_index.add_block_embed_data(all_block_data)
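# replaces the commented-out LSH index with an exact inner-product Faiss index over the
# 128-d block embeddings loaded from block_data_path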
......@@ -66,8 +69,6 @@ def get_batch(data_iterator):
else:
data = next(data_iterator)
data_b = mpu.broadcast_data(keys, data, datatype)
# Unpack.
......@@ -96,21 +97,57 @@ def forward_step(data_iterator, model):
# Forward model.
lm_logits, block_probs = model(tokens, pad_mask, query_block_indices)
# print('logits shape: ', lm_logits.shape, flush=True)
# print('labels shape: ', labels.shape, flush=True)
with torch.no_grad():
max_retrieval_utility, top_retrieval_utility, avg_retrieval_utility = mpu.checkpoint(
get_retrieval_utility, lm_logits, block_probs, labels, loss_mask)
# P(y|x) = sum_z(P(y|z, x) * P(z|x))
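# equivalently, in log space: log P(y|x) = logsumexp_z( log P(z|x) + log P(y|z, x) )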
null_block_probs = torch.mean(block_probs[:, block_probs.shape[1] - 1])
block_probs = block_probs.unsqueeze(2).unsqueeze(3).expand_as(lm_logits)
lm_logits = torch.sum(lm_logits * block_probs, dim=1)[:, :labels.shape[1]]
lm_loss_ = mpu.vocab_parallel_cross_entropy(lm_logits.contiguous().float(),
labels.contiguous())
lm_loss = torch.sum(
lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
# logits: [batch x top_k x 2 * seq_length x vocab_size]
# labels: [batch x seq_length]
relevant_logits = lm_logits[:, :, :labels.shape[1]].float()
# if get_args().rank == 0:
# torch.save({'logits': relevant_logits.cpu(),
# 'block_probs': block_probs.cpu(),
# 'labels': labels.cpu(),
# 'loss_mask': loss_mask.cpu(),
# 'tokens': tokens.cpu(),
# 'pad_mask': pad_mask.cpu(),
# }, 'tensors.data')
# torch.load('gagaga')
block_probs = block_probs.unsqueeze(2).unsqueeze(3).expand_as(relevant_logits)
# print(torch.sum(block_probs, dim=1), flush=True)
def get_log_probs(logits, b_probs):
max_logits = torch.max(logits, dim=-1, keepdim=True)[0].expand_as(logits)
logits = logits - max_logits
softmaxed_logits = F.softmax(logits, dim=-1)
marginalized_probs = torch.sum(softmaxed_logits * b_probs, dim=1)
l_probs = torch.log(marginalized_probs)
return l_probs
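# note: summing softmaxed probabilities and only then taking the log can underflow for
# low-probability tokens; an equivalent log-space sketch (an alternative, not what runs here):
#   def get_log_probs_stable(logits, b_probs):
#       # log P(y|x) = logsumexp_z( log P(z|x) + log_softmax(logits_z) )
#       return torch.logsumexp(F.log_softmax(logits, dim=-1) + torch.log(b_probs), dim=1)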
log_probs = mpu.checkpoint(get_log_probs, relevant_logits, block_probs)
def get_loss(l_probs, labs):
vocab_size = l_probs.shape[2]
loss = torch.nn.NLLLoss(ignore_index=-1)(l_probs.reshape(-1, vocab_size), labs.reshape(-1))
# loss = torch.sum(lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
return loss.float()
lm_loss = mpu.checkpoint(get_loss, log_probs, labels)
# marginalized_logits = torch.sum(relevant_logits * block_probs, dim=1)
# vocab_size = marginalized_logits.shape[2]
# lm_loss_ = torch.nn.CrossEntropyLoss()(marginalized_logits.reshape(-1, vocab_size), labels.reshape(-1))
# lm_loss = torch.sum(lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
reduced_loss = reduce_losses([lm_loss, max_retrieval_utility, top_retrieval_utility, avg_retrieval_utility, null_block_probs])
# reduced_loss = reduce_losses([lm_loss])
# torch.cuda.synchronize()
return lm_loss, {'lm_loss': reduced_loss[0],
'max_ru': reduced_loss[1],
......@@ -119,10 +156,10 @@ def forward_step(data_iterator, model):
'null_prob': reduced_loss[4]}
def get_retrieval_utility(lm_logits, block_probs, labels, loss_mask):
def get_retrieval_utility(lm_logits_, block_probs, labels, loss_mask):
"""log P(y | z, x) - log P(y | null, x)"""
# [batch x seq_len x vocab_size]
lm_logits = lm_logits[:, :, :labels.shape[1], :]
lm_logits = lm_logits_[:, :, :labels.shape[1], :]
#non_null_block_probs = block_probs[:, :-1]
#non_null_block_probs /= torch.sum(non_null_block_probs, axis=1, keepdim=True)
# non_null_block_probs = non_null_block_probs.expand_as(lm_logits[:, :-1, :, :])
......