Commit f7f730e1 authored by Neel Kant
Browse files

Write pretrain_realm.py and misc dataset_type left from earlier

parent f42b4d24
...@@ -113,7 +113,7 @@ def train_valid_test_datasets_provider(train_val_test_num_samples): ...@@ -113,7 +113,7 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
short_seq_prob=args.short_seq_prob, short_seq_prob=args.short_seq_prob,
seed=args.seed, seed=args.seed,
skip_warmup=(not args.mmap_warmup), skip_warmup=(not args.mmap_warmup),
ict_dataset=True) dataset_type='ict')
print_rank_0("> finished creating BERT ICT datasets ...") print_rank_0("> finished creating BERT ICT datasets ...")
return train_ds, valid_ds, test_ds return train_ds, valid_ds, test_ds
......
...@@ -17,18 +17,16 @@ ...@@ -17,18 +17,16 @@
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from hashed_index import HashedIndex, load_ict_checkpoint, get_ict_dataset
from megatron import get_args from megatron import get_args
from megatron import get_timers from megatron import get_timers
from megatron import mpu from megatron import mpu
from megatron import print_rank_0 from megatron import print_rank_0
from megatron.checkpointing import get_checkpoint_tracker_filename, get_checkpoint_name
from megatron.data.bert_dataset import build_train_valid_test_datasets from megatron.data.bert_dataset import build_train_valid_test_datasets
from megatron.model import ICTBertModel, REALMBertModel from megatron.model import REALMBertModel, REALMRetriever
from megatron.training import get_model, pretrain from megatron.training import pretrain
from megatron.utils import reduce_losses from megatron.utils import reduce_losses
from pretrain_bert_ict import model_provider as ict_model_provider
num_batches = 0 num_batches = 0
...@@ -36,39 +34,21 @@ num_batches = 0 ...@@ -36,39 +34,21 @@ num_batches = 0
def model_provider(): def model_provider():
"""Build the model.""" """Build the model."""
args = get_args() args = get_args()
print_rank_0('building BERT models ...') print_rank_0('building REALM models ...')
ict_model = get_model(ict_model_provider) ict_model = load_ict_checkpoint()
ict_dataset = get_ict_dataset()
hashed_index = HashedIndex.load_from_file('block_hash_data.pkl')
if isinstance(ict_model, torchDDP): retriever = REALMRetriever(ict_model, ict_dataset, hashed_index)
model = ict_model.module model = REALMBertModel(retriever)
tracker_filename = get_checkpoint_tracker_filename(args.load)
with open(tracker_filename, 'r') as f:
iteration = int(f.read().strip())
assert iteration > 0 return model
checkpoint_name = get_checkpoint_name(args.load, iteration, False)
if mpu.get_data_parallel_rank() == 0:
print('global rank {} is loading checkpoint {}'.format(
torch.distributed.get_rank(), checkpoint_name))
state_dict = torch.load(checkpoint_name, map_location='cpu')
ict_model.load_state_dict(state_dict['model'])
torch.distributed.barrier()
if mpu.get_data_parallel_rank() == 0:
print(' successfully loaded {}'.format(checkpoint_name))
realm_model = REALMBertModel(ict_model,
args.block_hash_data_path)
return ict_model
def get_batch(data_iterator): def get_batch(data_iterator):
# Items and their type. # Items and their type.
keys = ['query_tokens', 'query_types', 'query_pad_mask'] keys = ['tokens', 'labels', 'loss_mask', 'pad_mask']
datatype = torch.int64 datatype = torch.int64
# Broadcast data. # Broadcast data.
...@@ -79,11 +59,12 @@ def get_batch(data_iterator): ...@@ -79,11 +59,12 @@ def get_batch(data_iterator):
data_b = mpu.broadcast_data(keys, data, datatype) data_b = mpu.broadcast_data(keys, data, datatype)
# Unpack. # Unpack.
query_tokens = data_b['query_tokens'].long() tokens = data_b['tokens'].long()
query_types = data_b['query_types'].long() labels = data_b['labels'].long()
query_pad_mask = data_b['query_pad_mask'].long() loss_mask = data_b['loss_mask'].long()
pad_mask = data_b['pad_mask'].long()
return query_tokens, query_types, query_pad_mask return tokens, labels, loss_mask, pad_mask
def forward_step(data_iterator, model): def forward_step(data_iterator, model):
...@@ -92,29 +73,22 @@ def forward_step(data_iterator, model): ...@@ -92,29 +73,22 @@ def forward_step(data_iterator, model):
# Get the batch. # Get the batch.
timers('batch generator').start() timers('batch generator').start()
query_tokens, query_types, query_pad_mask = get_batch(data_iterator) tokens, labels, loss_mask, pad_mask = get_batch(data_iterator)
timers('batch generator').stop() timers('batch generator').stop()
# Forward model. # Forward model.
query_logits, block_logits = model(query_tokens, query_pad_mask, query_types, # TODO: MAKE SURE PAD IS NOT 1 - PAD
block_tokens, block_pad_mask, block_types).float() lm_logits, block_probs = model(tokens, pad_mask)
# [batch x h] * [h x batch] # P(y|x) = sum_z(P(y|z, x) * P(z|x))
retrieval_scores = query_logits.matmul(torch.transpose(block_logits, 0, 1)) lm_logits = torch.sum(lm_logits * block_probs, dim=1)
softmaxed = F.softmax(retrieval_scores, dim=1) lm_loss_ = mpu.vocab_parallel_cross_entropy(lm_logits.contiguous().float(),
labels.contiguous())
top5_vals, top5_indices = torch.topk(softmaxed, k=5, sorted=True) lm_loss = torch.sum(
batch_size = softmaxed.shape[0] lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
top1_acc = torch.cuda.FloatTensor([sum([int(top5_indices[i, 0] == i) for i in range(batch_size)]) / batch_size]) reduced_loss = reduce_losses([lm_loss])
top5_acc = torch.cuda.FloatTensor([sum([int(i in top5_indices[i]) for i in range(batch_size)]) / batch_size]) return lm_loss, {'lm_loss': reduced_loss[0]}
retrieval_loss = F.cross_entropy(softmaxed, torch.arange(batch_size).cuda())
reduced_losses = reduce_losses([retrieval_loss, top1_acc, top5_acc])
return retrieval_loss, {'retrieval loss': reduced_losses[0],
'top1_acc': reduced_losses[1],
'top5_acc': reduced_losses[2]}
def train_valid_test_datasets_provider(train_val_test_num_samples): def train_valid_test_datasets_provider(train_val_test_num_samples):
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment