Commit d7022c72 authored by Neel Kant

Mostly debugged realm-mlm

parent 6f54f50f
@@ -198,7 +198,7 @@ def load_ict_checkpoint():
     if isinstance(model, torchDDP):
         model = model.module
-    tracker_filename = get_checkpoint_tracker_filename(args.load)
+    tracker_filename = get_checkpoint_tracker_filename(args.ict_load)
     with open(tracker_filename, 'r') as f:
         iteration = int(f.read().strip())
...
@@ -27,7 +27,6 @@ from megatron import mpu
 from megatron.data.dataset_utils import build_training_sample
 from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset
 from megatron.data.ict_dataset import InverseClozeDataset
-from megatron.data.realm_dataset import RealmDataset
 from megatron import print_rank_0

 DATASET_TYPES = ['standard_bert', 'ict', 'realm']
@@ -76,6 +75,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
     print_split_stats('test', 2)

     def build_dataset(index, name):
+        from megatron.data.realm_dataset import RealmDataset
         dataset = None
         if splits[index + 1] > splits[index]:
             # Get the pointer to the original doc-idx so we can set it later.
...
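Note on the hunk above: moving the RealmDataset import inside build_dataset presumably avoids a circular import between this module and realm_dataset, which imports helpers back from it. A minimal sketch of the deferred-import pattern (the helper name and constructor arguments below are placeholders, not code from the repo):

# Hypothetical sketch of the deferred-import pattern used above.
def build_realm_dataset_sketch(indexed_dataset, **kwargs):
    # Importing here, rather than at module level, delays resolution of
    # megatron.data.realm_dataset until the function is first called, so the
    # two modules can import from each other without a cycle at import time.
    from megatron.data.realm_dataset import RealmDataset
    # Constructor arguments are placeholders for illustration only.
    return RealmDataset(indexed_dataset=indexed_dataset, **kwargs)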
@@ -90,7 +90,7 @@ class InverseClozeDataset(Dataset):
         block = list(itertools.chain(*block))[:self.max_seq_length - (3 + len(title))]
         block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)

-        return block_tokens, block_pad_mask
+        return (block_tokens, block_pad_mask)

     def concat_and_pad_tokens(self, tokens, title=None):
         """concat with special tokens and pad sequence to self.max_seq_length"""
...
@@ -7,8 +7,8 @@ from megatron import get_tokenizer
 from megatron.data.bert_dataset import BertDataset, get_samples_mapping_
 from megatron.data.dataset_utils import create_masked_lm_predictions, pad_and_convert_to_numpy

-qa_nlp = spacy.load('en_core_web_lg')
+#qa_nlp = spacy.load('en_core_web_lg')
+qa_nlp = None

 class RealmDataset(BertDataset):
     """Dataset containing simple masked sentences for masked language modeling.
@@ -47,7 +47,7 @@ def build_simple_training_sample(sample, target_seq_length, max_seq_length,
                                      masked_labels, pad_id, max_seq_length)

     # REALM true sequence length is twice as long but none of that is to be predicted with LM
-    loss_mask_np = np.concatenate((loss_mask_np, np.ones(loss_mask_np.shape)), -1)
+    loss_mask_np = np.concatenate((loss_mask_np, np.ones(loss_mask_np.shape)), -1).astype(np.int64)

     train_sample = {
         'tokens': tokens_np,
...
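For context, a toy numpy sketch of the shape and dtype mechanics of the loss-mask line above; the mask values and max_seq_length below are made up for illustration:

import numpy as np

# Toy per-token loss mask for a query of max_seq_length = 8.
max_seq_length = 8
loss_mask_np = np.array([0, 1, 0, 0, 1, 0, 0, 0], dtype=np.int64)

# REALM feeds [query ; retrieved block] to the reader, so per-token tensors
# must cover 2 * max_seq_length positions; the cast keeps the mask integral.
doubled = np.concatenate((loss_mask_np, np.ones(loss_mask_np.shape)), -1).astype(np.int64)

assert doubled.shape == (2 * max_seq_length,)
print(doubled)  # [0 1 0 0 1 0 0 0 1 1 1 1 1 1 1 1]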
@@ -234,22 +234,35 @@ class REALMBertModel(MegatronModule):
     def forward(self, tokens, attention_mask):
         # [batch_size x 5 x seq_length]
         top5_block_tokens, top5_block_attention_mask = self.retriever.retrieve_evidence_blocks(tokens, attention_mask)
+        batch_size = tokens.shape[0]
+        seq_length = top5_block_tokens.shape[2]
+        top5_block_tokens = torch.cuda.LongTensor(top5_block_tokens).reshape(-1, seq_length)
+        top5_block_attention_mask = torch.cuda.LongTensor(top5_block_attention_mask).reshape(-1, seq_length)
+
+        # [batch_size x 5 x embed_size]
+        fresh_block_logits = self.retriever.ict_model.module.module.embed_block(top5_block_tokens, top5_block_attention_mask).reshape(batch_size, 5, -1)
+
+        # [batch_size x embed_size x 1]
+        query_logits = self.retriever.ict_model.module.module.embed_query(tokens, attention_mask).unsqueeze(2)

         # [batch_size x 5]
-        fresh_block_logits = self.retriever.ict_model.embed_block(top5_block_tokens, top5_block_attention_mask)
-        block_probs = F.softmax(fresh_block_logits, axis=1)
+        fresh_block_scores = torch.matmul(fresh_block_logits, query_logits).squeeze()
+        block_probs = F.softmax(fresh_block_scores, dim=1)

-        # [batch_size x 5 x seq_length]
-        tokens = torch.stack([tokens.unsqueeze(1)] * 5, dim=1)
-        attention_mask = torch.stack([attention_mask.unsqueeze(1)] * 5, dim=1)
+        # [batch_size * 5 x seq_length]
+        tokens = torch.stack([tokens.unsqueeze(1)] * 5, dim=1).reshape(-1, seq_length)
+        attention_mask = torch.stack([attention_mask.unsqueeze(1)] * 5, dim=1).reshape(-1, seq_length)

-        # [batch_size x 5 x 2 * seq_length]
-        all_tokens = torch.cat((tokens, top5_block_tokens), axis=2)
-        all_attention_mask = torch.cat((attention_mask, top5_block_attention_mask), axis=2)
+        # [batch_size * 5 x 2 * seq_length]
+        all_tokens = torch.cat((tokens, top5_block_tokens), axis=1)
+        all_attention_mask = torch.cat((attention_mask, top5_block_attention_mask), axis=1)
         all_token_types = torch.zeros(all_tokens.shape).type(torch.int64).cuda()

         # [batch_size x 5 x 2 * seq_length x vocab_size]
         lm_logits, _ = self.lm_model.forward(all_tokens, all_attention_mask, all_token_types)
+        lm_logits = lm_logits.reshape(batch_size, 5, 2 * seq_length, -1)

         return lm_logits, block_probs

     def state_dict_for_save_checkpoint(self, destination=None, prefix='',
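A standalone shape check of the retrieval-augmented forward pass above, with random tensors standing in for the ICT encoders and the BERT reader; every size and variable name below is illustrative only:

import torch
import torch.nn.functional as F

batch_size, top_k, seq_length, embed_size, vocab_size = 2, 5, 4, 8, 11

# Stand-ins for embed_block(...) and embed_query(...) outputs.
block_embeds = torch.randn(batch_size, top_k, embed_size)        # [B x 5 x D]
query_embeds = torch.randn(batch_size, embed_size).unsqueeze(2)  # [B x D x 1]

# Dot-product relevance score per retrieved block, then P(z|x) over the top-5.
block_scores = torch.matmul(block_embeds, query_embeds).squeeze(2)  # [B x 5]
block_probs = F.softmax(block_scores, dim=1)

# Stand-in for the reader logits over the [query ; block] inputs.
lm_logits = torch.randn(batch_size, top_k, 2 * seq_length, vocab_size)

# P(y|x) = sum_z P(y|z, x) * P(z|x): broadcast P(z|x) and sum out the block axis.
expanded = block_probs.unsqueeze(2).unsqueeze(3).expand_as(lm_logits)
marginal_logits = torch.sum(lm_logits * expanded, dim=1)  # [B x 2*seq_length x V]

assert marginal_logits.shape == (batch_size, 2 * seq_length, vocab_size)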
@@ -263,7 +276,7 @@ class REALMBertModel(MegatronModule):

 class REALMRetriever(MegatronModule):
-    """Retriever which uses a pretrained ICTBertModel and a hashed_index"""
+    """Retriever which uses a pretrained ICTBertModel and a HashedIndex"""

     def __init__(self, ict_model, ict_dataset, hashed_index, top_k=5):
         super(REALMRetriever, self).__init__()
         self.ict_model = ict_model
@@ -301,13 +314,14 @@ class REALMRetriever(MegatronModule):
             top5_start_end_doc = [bucket[idx][:3] for idx in top5_indices.squeeze()]

             # top_k tuples of (block_tokens, block_pad_mask)
-            top5_block_data = [(self.ict_dataset.get_block(*indices)) for indices in top5_start_end_doc]
-            top5_tokens, top5_pad_masks = zip(top5_block_data)
+            top5_block_data = [self.ict_dataset.get_block(*indices) for indices in top5_start_end_doc]
+            top5_tokens, top5_pad_masks = zip(*top5_block_data)

             all_top5_tokens.append(np.array(top5_tokens))
             all_top5_pad_masks.append(np.array(top5_pad_masks))

-        return all_top5_tokens, all_top5_pad_masks
+        return np.array(all_top5_tokens), np.array(all_top5_pad_masks)


 class ICTBertModel(MegatronModule):
...
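The zip fix above in miniature: zip(top5_block_data) yields one 1-tuple per block, so the two-way unpacking fails, while zip(*top5_block_data) transposes the list of (tokens, pad_mask) pairs. A toy example with string placeholders:

# Toy stand-ins for the (block_tokens, block_pad_mask) pairs from get_block.
pairs = [('tok0', 'mask0'), ('tok1', 'mask1'), ('tok2', 'mask2')]

tokens, masks = zip(*pairs)   # ('tok0', 'tok1', 'tok2'), ('mask0', 'mask1', 'mask2')
# tokens, masks = zip(pairs)  # ValueError: too many values to unpack (expected 2)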
@@ -225,6 +225,7 @@ def backward_step(optimizer, model, loss):
     """Backward step."""
     args = get_args()
     timers = get_timers()
+    print("start backward", flush=True)

     # Backward pass.
     optimizer.zero_grad()
@@ -239,11 +240,9 @@ def backward_step(optimizer, model, loss):
             model.allreduce_params(reduce_after=False,
                                    fp32_allreduce=args.fp32_allreduce)
         timers('allreduce').stop()
-
     # Update master gradients.
     if args.fp16:
         optimizer.update_master_grads()
-
     # Clipping gradients helps prevent the exploding gradient.
     if args.clip_grad > 0:
         if not args.fp16:
...
@@ -74,6 +74,7 @@ def forward_step(data_iterator, model):
     # Get the batch.
     timers('batch generator').start()
     tokens, labels, loss_mask, pad_mask = get_batch(data_iterator)
+    labels = torch.cat((labels, labels), axis=-1)
     timers('batch generator').stop()

     # Forward model.
@@ -81,6 +82,7 @@ def forward_step(data_iterator, model):
     lm_logits, block_probs = model(tokens, pad_mask)

     # P(y|x) = sum_z(P(y|z, x) * P(z|x))
+    block_probs = block_probs.unsqueeze(2).unsqueeze(3).expand_as(lm_logits)
     lm_logits = torch.sum(lm_logits * block_probs, dim=1)
     lm_loss_ = mpu.vocab_parallel_cross_entropy(lm_logits.contiguous().float(),
                                                 labels.contiguous())
@@ -88,6 +90,7 @@ def forward_step(data_iterator, model):
         lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
     reduced_loss = reduce_losses([lm_loss])
+    print(reduced_loss, flush=True)

     return lm_loss, {'lm_loss': reduced_loss[0]}
...
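A toy sketch of the label doubling and masked loss averaging in the hunk above, with vocab_parallel_cross_entropy replaced by a random per-token loss; all values and shapes below are made up for illustration:

import torch

seq_length, vocab_size = 4, 7

# Labels are duplicated so they line up with the [query ; block] positions
# (mirrors the torch.cat(..., axis=-1) line in the diff).
labels = torch.randint(vocab_size, (1, seq_length))
labels = torch.cat((labels, labels), dim=-1)
assert labels.shape == (1, 2 * seq_length)

# Stand-in for lm_loss_ and the per-token loss mask over 2 * seq_length positions.
per_token_loss = torch.rand(1, 2 * seq_length)
loss_mask = torch.tensor([[0., 1., 0., 1., 1., 1., 1., 1.]])

# Average the loss only over positions where loss_mask is 1.
lm_loss = torch.sum(per_token_loss.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
print(lm_loss)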