Commit 0e8f4331 authored by Neel Kant

Correct CrossEntropyLoss

parent 8e22824e
@@ -95,6 +95,7 @@ def setup_realm_groups_and_vars():
class IndexBuilder(object):
def __init__(self):
args = get_args()
self.debug = args.debug
self.rank = args.rank
self.model = None
self.dataloader = None
@@ -287,6 +288,6 @@ def get_one_epoch_dataloader(dataset, batch_size=None):
if __name__ == "__main__":
initialize_megatron(extra_args_provider=None,
args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
index_builder = BasicIndexBuilder()
index_builder = IndexBuilder()
index_builder.build_and_save_index()
@@ -265,7 +265,7 @@ def _add_checkpointing_args(parser):
group.add_argument('--ict-load', type=str, default=None,
help='Directory containing an ICTBertModel checkpoint')
group.add_argument('--bert-load', type=str, default=None,
help='Directory containing a BertModel checkpoint (needed to start REALM)')
help='Directory containing a BertModel checkpoint (needed to start ICT and REALM)')
group.add_argument('--no-load-optim', action='store_true',
help='Do not load optimizer when loading checkpoint.')
group.add_argument('--no-load-rng', action='store_true',
@@ -97,7 +97,8 @@ class InverseClozeDataset(Dataset):
"""concat with special tokens and pad sequence to self.max_seq_length"""
tokens = [self.cls_id] + tokens + [self.sep_id]
if title is not None:
tokens += title + [self.sep_id]
# tokens += title + [self.sep_id]
tokens = t
assert len(tokens) <= self.max_seq_length, len(tokens)
num_pad = self.max_seq_length - len(tokens)
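For reference, here is a minimal standalone sketch of the concat-and-pad step, using hypothetical token ids; the pad_id and pad_mask handling are assumptions based on the surrounding (not shown) lines, and the title branch follows the pre-change behaviour of appending the title plus a second [SEP]:

def concat_and_pad(tokens, cls_id, sep_id, pad_id, max_seq_length, title=None):
    # [CLS] block tokens [SEP], optionally followed by the title and another [SEP].
    tokens = [cls_id] + tokens + [sep_id]
    if title is not None:
        tokens = tokens + title + [sep_id]
    assert len(tokens) <= max_seq_length, len(tokens)
    # Right-pad to a fixed length and record which positions hold real tokens.
    num_pad = max_seq_length - len(tokens)
    pad_mask = [1] * len(tokens) + [0] * num_pad
    tokens = tokens + [pad_id] * num_pad
    return tokens, pad_mask

# Example: a 3-token block padded to length 8.
tokens, pad_mask = concat_and_pad([11, 12, 13], cls_id=1, sep_id=2, pad_id=0, max_seq_length=8)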
@@ -294,10 +294,11 @@ class ICTBertModel(MegatronModule):
query_logits = self.embed_query(query_tokens, query_attention_mask)
block_logits = self.embed_block(block_tokens, block_attention_mask)
return query_logits, block_logits
# [batch x embed] * [embed x batch]
retrieval_scores = query_logits.matmul(torch.transpose(block_logits, 0, 1))
return retrieval_scores
# retrieval_scores = query_logits.matmul(torch.transpose(block_logits, 0, 1))
# return retrieval_scores
def embed_query(self, query_tokens, query_attention_mask):
"""Embed a batch of tokens using the query model"""
@@ -343,3 +344,31 @@ class ICTBertModel(MegatronModule):
print("Loading ICT block model", flush=True)
self.block_model.load_state_dict(
state_dict[self._block_key], strict=strict)
def init_state_dict_from_bert(self):
args = get_args()
import os
from megatron import mpu
from megatron.checkpointing import get_checkpoint_tracker_filename, get_checkpoint_name
tracker_filename = get_checkpoint_tracker_filename(args.bert_load)
if not os.path.isfile(tracker_filename):
raise FileNotFoundError("Could not find BERT load for ICT")
with open(tracker_filename, 'r') as f:
iteration = int(f.read().strip())
assert iteration > 0
checkpoint_name = get_checkpoint_name(args.bert_load, iteration, False)
if mpu.get_data_parallel_rank() == 0:
print('global rank {} is loading checkpoint {}'.format(
torch.distributed.get_rank(), checkpoint_name))
try:
state_dict = torch.load(checkpoint_name, map_location='cpu')
except BaseException:
raise ValueError("Could not load checkpoint")
model_dict = state_dict['model']['language_model']
self.query_model.language_model.load_state_dict(model_dict)
self.block_model.language_model.load_state_dict(model_dict)
query_ict_head_state_dict = self.state_dict_for_save_checkpoint()[self._query_key]['ict_head']
self.block_model.ict_head.load_state_dict(query_ict_head_state_dict)
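The new init_state_dict_from_bert loads one pretrained language_model state dict into both the query and block towers and then copies the query tower's ICT head into the block tower, so the two encoders start from identical weights. A toy sketch of that state-dict copying pattern on stand-in modules (illustration only, not the ICTBertModel classes):

import torch

# Stand-ins for the two towers: a shared backbone plus a projection head each.
query_tower = torch.nn.ModuleDict({'language_model': torch.nn.Linear(8, 8),
                                   'ict_head': torch.nn.Linear(8, 4)})
block_tower = torch.nn.ModuleDict({'language_model': torch.nn.Linear(8, 8),
                                   'ict_head': torch.nn.Linear(8, 4)})

# Load the same (pretrained) backbone weights into both towers.
pretrained_backbone = torch.nn.Linear(8, 8).state_dict()
query_tower['language_model'].load_state_dict(pretrained_backbone)
block_tower['language_model'].load_state_dict(pretrained_backbone)

# Start the block head from the query head so both towers begin identical.
block_tower['ict_head'].load_state_dict(query_tower['ict_head'].state_dict())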
@@ -37,6 +37,7 @@ from megatron.learning_rates import AnnealingLR
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import get_params_for_weight_decay_optimization
from megatron.mpu.initialize import get_index_ready, get_train_group, get_data_parallel_group, get_gloo_comm_group
from megatron.model.realm_model import ICTBertModel
from megatron.utils import check_adlr_autoresume_termination
from megatron.utils import make_data_loader
from megatron.utils import report_memory
@@ -229,6 +230,12 @@ def setup_model_and_optimizer(model_provider_func):
else:
args.iteration = 0
if args.iteration == 0 and isinstance(model.module.module, ICTBertModel):
print("Yes, located ICT model", flush=True)
model.module.module.init_state_dict_from_bert()
elif args.iteration == 0:
print("Ooops", flush=True)
return model, optimizer, lr_scheduler
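The isinstance check above reaches through model.module.module, which appears to assume two levels of wrapping (e.g. DDP around an FP16 wrapper). A hypothetical helper, not in this diff, that makes the same check independent of how many wrapper layers are present:

def unwrap_model(model):
    # Walk through wrapper layers (DDP, FP16 wrappers, ...) until the raw model is reached.
    while hasattr(model, 'module'):
        model = model.module
    return model

# Usage mirroring the check above (names as in the surrounding code):
# if args.iteration == 0 and isinstance(unwrap_model(model), ICTBertModel):
#     unwrap_model(model).init_state_dict_from_bert()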
@@ -239,10 +246,12 @@ def backward_step(optimizer, model, loss):
# torch.cuda.synchronize()
# Backward pass.
optimizer.zero_grad(set_grads_to_None=True)
# optimizer.zero_grad(set_grads_to_None=True)
if args.fp16:
optimizer.zero_grad(set_grads_to_None=True)
optimizer.backward(loss, update_master_grads=False)
else:
optimizer.zero_grad()
loss.backward()
# All-reduce if needed.
@@ -377,9 +386,10 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
print('>>> Starting train()', flush=True)
# start off by posting a receive call which will be answered.
# synchronize for start
torch.distributed.broadcast(INDEX_READY, 0, group=get_gloo_comm_group())
recv_handle = torch.distributed.broadcast(INDEX_READY, args.max_training_rank, group=get_gloo_comm_group(), async_op=True)
last_reload_iteration = iteration
if args.max_training_rank is not None:
torch.distributed.broadcast(INDEX_READY, 0, group=get_gloo_comm_group())
recv_handle = torch.distributed.broadcast(INDEX_READY, args.max_training_rank, group=get_gloo_comm_group(), async_op=True)
last_reload_iteration = iteration
while iteration < args.train_iters:
# this only applies for realm right here
if args.max_training_rank is not None and recv_handle.is_completed() and iteration >= last_reload_iteration + 500:
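The guarded block above posts a non-blocking broadcast on the Gloo group and then polls it inside the training loop: torch.distributed.broadcast with async_op=True returns a work handle whose is_completed() can be checked without blocking training. A minimal sketch of that handshake pattern (hypothetical helper names; assumes an already-initialized process group and flag tensor):

import torch.distributed as dist

def post_index_ready_recv(index_ready_flag, src_rank, gloo_group):
    # Non-blocking broadcast: returns immediately with a work handle.
    return dist.broadcast(index_ready_flag, src_rank, group=gloo_group, async_op=True)

def should_reload_index(recv_handle, iteration, last_reload_iteration, min_gap=500):
    # Act only once the signal has arrived and enough steps have passed since the last reload.
    return recv_handle.is_completed() and iteration >= last_reload_iteration + min_gap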
@@ -16,6 +16,7 @@
"""Pretrain BERT for Inverse Cloze Task"""
import torch
import torch.distributed as dist
import torch.nn.functional as F
from megatron import get_args
@@ -71,6 +72,7 @@ def get_batch(data_iterator):
def forward_step(data_iterator, model):
"""Forward step."""
args = get_args()
timers = get_timers()
# Get the batch.
@@ -80,21 +82,49 @@ def forward_step(data_iterator, model):
timers('batch generator').stop()
# Forward model.
retrieval_scores = model(query_tokens, query_pad_mask, block_tokens, block_pad_mask).float()
# retrieval_scores = model(query_tokens, query_pad_mask, block_tokens, block_pad_mask).float()
query_logits, block_logits = model(query_tokens, query_pad_mask, block_tokens, block_pad_mask)
data_parallel_size = dist.get_world_size() / args.model_parallel_size
batch_size = query_logits.shape[0]
global_batch_size = int(batch_size * data_parallel_size)
all_logits_shape = (int(global_batch_size), int(query_logits.shape[1]))
all_query_logits = torch.zeros(all_logits_shape).type(query_logits.dtype).cuda()
all_block_logits = all_query_logits.clone().cuda()
all_query_logits[args.rank * batch_size:(args.rank + 1) * batch_size] = query_logits
all_block_logits[args.rank * batch_size:(args.rank + 1) * batch_size] = block_logits
# print(all_query_logits[:, :5], flush=True)
# print(all_block_logits[:, :5], flush=True)
dist.all_reduce(all_query_logits)
dist.all_reduce(all_block_logits)
# print(all_query_logits[:, :5], flush=True)
# print(all_block_logits[:, :5], flush=True)
retrieval_scores = all_query_logits.float().matmul(torch.transpose(all_block_logits, 0, 1).float())
softmaxed = F.softmax(retrieval_scores, dim=1)
sorted_vals, sorted_indices = torch.topk(softmaxed, k=softmaxed.shape[1], sorted=True)
batch_size = softmaxed.shape[0]
top1_acc = torch.cuda.FloatTensor([sum([int(sorted_indices[i, 0] == i) for i in range(batch_size)]) / batch_size])
top5_acc = torch.cuda.FloatTensor([sum([int(i in sorted_indices[i, :5]) for i in range(batch_size)]) / batch_size])
def topk_acc(k):
return torch.cuda.FloatTensor([sum([int(i in sorted_indices[i, :k]) for i in range(global_batch_size)]) / global_batch_size])
top_accs = [topk_acc(k) for k in [1, 8, 20, 100]]
retrieval_loss = torch.nn.CrossEntropyLoss()(retrieval_scores, torch.arange(global_batch_size).long().cuda())
# correct_probs = torch.gather(softmaxed, 1, torch.arange(global_batch_size).long().cuda().reshape(-1, 1))
# assert correct_probs[3] == softmaxed[3, 3]
# retrieval_loss = -torch.sum(torch.log(correct_probs)) / global_batch_size
retrieval_loss = F.cross_entropy(softmaxed, torch.arange(batch_size).cuda())
reduced_losses = reduce_losses([retrieval_loss, top1_acc, top5_acc])
reduced_losses = reduce_losses([retrieval_loss, *top_accs])
stats_dict = {
'retrieval loss': reduced_losses[0],
'top1_acc': reduced_losses[1],
'top5_acc': reduced_losses[2]
'top8_acc': reduced_losses[2],
'top20_acc': reduced_losses[3],
'top100_acc': reduced_losses[4],
}
return retrieval_loss, stats_dict
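The corrected loss treats retrieval over the gathered batch as a classification problem: row i of the [batch x batch] score matrix should assign its highest score to column i, so the targets are simply torch.arange(batch_size). Note that F.cross_entropy applies log-softmax to its input internally, so it is normally given the raw score matrix rather than already-softmaxed probabilities. A toy sketch with random embeddings (hypothetical sizes, CPU tensors, not the Megatron forward_step):

import torch
import torch.nn.functional as F

batch_size, embed_dim = 8, 128
query_logits = torch.randn(batch_size, embed_dim)
block_logits = torch.randn(batch_size, embed_dim)

# [batch x embed] @ [embed x batch] -> [batch x batch] similarity scores.
retrieval_scores = query_logits.matmul(block_logits.t())

# The correct block for query i is block i, so the targets are the diagonal indices.
targets = torch.arange(batch_size)
retrieval_loss = F.cross_entropy(retrieval_scores, targets)

# Top-k retrieval accuracy: does row i rank its own block among the k highest scores?
_, sorted_indices = torch.topk(retrieval_scores, k=batch_size, sorted=True)
topk_acc = {k: sum(int(i in sorted_indices[i, :k]) for i in range(batch_size)) / batch_size
            for k in (1, 5)}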