Commit 0e8f4331 authored by Neel Kant

Correct CrossEntropyLoss

parent 8e22824e
@@ -95,6 +95,7 @@ def setup_realm_groups_and_vars():

 class IndexBuilder(object):
     def __init__(self):
         args = get_args()
+        self.debug = args.debug
         self.rank = args.rank
         self.model = None
         self.dataloader = None
@@ -287,6 +288,6 @@ def get_one_epoch_dataloader(dataset, batch_size=None):

 if __name__ == "__main__":
     initialize_megatron(extra_args_provider=None,
                         args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
-    index_builder = BasicIndexBuilder()
+    index_builder = IndexBuilder()
     index_builder.build_and_save_index()
@@ -265,7 +265,7 @@ def _add_checkpointing_args(parser):

     group.add_argument('--ict-load', type=str, default=None,
                        help='Directory containing an ICTBertModel checkpoint')
     group.add_argument('--bert-load', type=str, default=None,
-                       help='Directory containing an BertModel checkpoint (needed to start REALM)')
+                       help='Directory containing a BertModel checkpoint (needed to start ICT and REALM)')
     group.add_argument('--no-load-optim', action='store_true',
                        help='Do not load optimizer when loading checkpoint.')
     group.add_argument('--no-load-rng', action='store_true',
...
@@ -97,7 +97,8 @@ class InverseClozeDataset(Dataset):

         """concat with special tokens and pad sequence to self.max_seq_length"""
         tokens = [self.cls_id] + tokens + [self.sep_id]
         if title is not None:
-            tokens += title + [self.sep_id]
+            # tokens += title + [self.sep_id]
+            tokens = t
         assert len(tokens) <= self.max_seq_length, len(tokens)
         num_pad = self.max_seq_length - len(tokens)
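For reference, the concat-and-pad path this hunk touches (now with the title append disabled) follows a standard pattern. A minimal standalone sketch, where the ids and length are made-up stand-ins for the dataset's `cls_id`, `sep_id`, and `max_seq_length` fields:

    # Illustrative sketch only; 101/102/0/288 are assumed values, not the
    # dataset's real vocabulary ids or configured sequence length.
    def concat_and_pad(tokens, title=None, cls_id=101, sep_id=102, pad_id=0,
                       max_seq_length=288):
        tokens = [cls_id] + tokens + [sep_id]
        if title is not None:
            tokens = tokens + title + [sep_id]
        assert len(tokens) <= max_seq_length, len(tokens)
        num_pad = max_seq_length - len(tokens)
        pad_mask = [1] * len(tokens) + [0] * num_pad  # 1 = real token, 0 = pad
        tokens = tokens + [pad_id] * num_pad
        return tokens, pad_mask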
...
@@ -294,10 +294,11 @@ class ICTBertModel(MegatronModule):

         query_logits = self.embed_query(query_tokens, query_attention_mask)
         block_logits = self.embed_block(block_tokens, block_attention_mask)
+        return query_logits, block_logits

         # [batch x embed] * [embed x batch]
-        retrieval_scores = query_logits.matmul(torch.transpose(block_logits, 0, 1))
-        return retrieval_scores
+        # retrieval_scores = query_logits.matmul(torch.transpose(block_logits, 0, 1))
+        # return retrieval_scores

     def embed_query(self, query_tokens, query_attention_mask):
         """Embed a batch of tokens using the query model"""
@@ -343,3 +344,31 @@ class ICTBertModel(MegatronModule):

             print("Loading ICT block model", flush=True)
             self.block_model.load_state_dict(
                 state_dict[self._block_key], strict=strict)
+
+    def init_state_dict_from_bert(self):
+        args = get_args()
+        import os
+        from megatron import mpu
+        from megatron.checkpointing import get_checkpoint_tracker_filename, get_checkpoint_name
+
+        tracker_filename = get_checkpoint_tracker_filename(args.bert_load)
+        if not os.path.isfile(tracker_filename):
+            raise FileNotFoundError("Could not find BERT load for ICT")
+        with open(tracker_filename, 'r') as f:
+            iteration = int(f.read().strip())
+            assert iteration > 0
+
+        checkpoint_name = get_checkpoint_name(args.bert_load, iteration, False)
+        if mpu.get_data_parallel_rank() == 0:
+            print('global rank {} is loading checkpoint {}'.format(
+                torch.distributed.get_rank(), checkpoint_name))
+
+        try:
+            state_dict = torch.load(checkpoint_name, map_location='cpu')
+        except BaseException:
+            raise ValueError("Could not load checkpoint")
+
+        model_dict = state_dict['model']['language_model']
+        self.query_model.language_model.load_state_dict(model_dict)
+        self.block_model.language_model.load_state_dict(model_dict)
+
+        query_ict_head_state_dict = self.state_dict_for_save_checkpoint()[self._query_key]['ict_head']
+        self.block_model.ict_head.load_state_dict(query_ict_head_state_dict)
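The new method leans on Megatron's checkpoint-tracker convention: `<--bert-load dir>/latest_checkpointed_iteration.txt` names the iteration, and `get_checkpoint_name` resolves the weight file beneath it. A minimal sketch of that convention, assuming the Megatron directory layout of this era (`iter_XXXXXXX/mp_rank_XX/model_optim_rng.pt`; treat the layout as an assumption):

    import os
    import torch

    def load_bert_language_model(load_dir, mp_rank=0):
        # Tracker file holds the latest checkpointed iteration as plain text.
        tracker = os.path.join(load_dir, 'latest_checkpointed_iteration.txt')
        with open(tracker, 'r') as f:
            iteration = int(f.read().strip())
        # Weights live under an iteration- and model-parallel-rank-specific dir.
        checkpoint = os.path.join(load_dir, 'iter_{:07d}'.format(iteration),
                                  'mp_rank_{:02d}'.format(mp_rank),
                                  'model_optim_rng.pt')
        state_dict = torch.load(checkpoint, map_location='cpu')
        return state_dict['model']['language_model']

After loading, the method copies the shared `language_model` weights into both towers and clones the query tower's freshly initialized `ict_head` into the block tower, so the two heads start from identical parameters.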
@@ -37,6 +37,7 @@ from megatron.learning_rates import AnnealingLR

 from megatron.model import DistributedDataParallel as LocalDDP
 from megatron.model import get_params_for_weight_decay_optimization
 from megatron.mpu.initialize import get_index_ready, get_train_group, get_data_parallel_group, get_gloo_comm_group
+from megatron.model.realm_model import ICTBertModel
 from megatron.utils import check_adlr_autoresume_termination
 from megatron.utils import make_data_loader
 from megatron.utils import report_memory
@@ -229,6 +230,12 @@ def setup_model_and_optimizer(model_provider_func):

     else:
         args.iteration = 0

+    if args.iteration == 0 and isinstance(model.module.module, ICTBertModel):
+        print("Yes, located ICT model", flush=True)
+        model.module.module.init_state_dict_from_bert()
+    elif args.iteration == 0:
+        print("Ooops", flush=True)
+
     return model, optimizer, lr_scheduler
@@ -239,10 +246,12 @@ def backward_step(optimizer, model, loss):

     # torch.cuda.synchronize()
     # Backward pass.
-    optimizer.zero_grad(set_grads_to_None=True)
     if args.fp16:
+        optimizer.zero_grad(set_grads_to_None=True)
         optimizer.backward(loss, update_master_grads=False)
     else:
+        optimizer.zero_grad()
         loss.backward()

     # All-reduce if needed.
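Moving `zero_grad` inside the branches reads as a fix for the fp32 path: apex's `FP16_Optimizer.zero_grad` accepts a `set_grads_to_None` flag, while a plain `torch.optim` optimizer of this vintage does not, so the unconditional call would raise a `TypeError` with `--fp16` off. A sketch of the resulting control flow (the function name is illustrative, not the repo's):

    # `optimizer` is an apex FP16_Optimizer when fp16 is True and a plain
    # torch.optim optimizer otherwise; only the former takes this kwarg.
    def backward_step_sketch(optimizer, loss, fp16):
        if fp16:
            optimizer.zero_grad(set_grads_to_None=True)  # apex-only kwarg
            optimizer.backward(loss, update_master_grads=False)
        else:
            optimizer.zero_grad()                        # vanilla signature
            loss.backward()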
@@ -377,6 +386,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler,

     print('>>> Starting train()', flush=True)
     # start off by posting a receive call which will be answered.
     # synchronize for start
+    if args.max_training_rank is not None:
         torch.distributed.broadcast(INDEX_READY, 0, group=get_gloo_comm_group())
         recv_handle = torch.distributed.broadcast(INDEX_READY, args.max_training_rank, group=get_gloo_comm_group(), async_op=True)
         last_reload_iteration = iteration
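The guarded handshake lets trainer ranks and index-builder ranks coexist: one synchronous broadcast from rank 0 aligns startup, then each rank posts an async broadcast from `max_training_rank` whose completion signals a refreshed index. A minimal sketch of the pattern, assuming `torch.distributed` is already initialized with a gloo-capable group (names mirror the diff; the helper itself is illustrative):

    import torch.distributed as dist

    def post_index_ready_handshake(index_ready, max_training_rank, gloo_group):
        # Synchronous broadcast: everyone blocks until rank 0 says start.
        dist.broadcast(index_ready, 0, group=gloo_group)
        # Async broadcast: returns a work handle immediately; the train loop
        # can poll recv_handle.is_completed() (or call .wait()) to learn when
        # the index builder has published a new index.
        recv_handle = dist.broadcast(index_ready, max_training_rank,
                                     group=gloo_group, async_op=True)
        return recv_handle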
...
@@ -16,6 +16,7 @@

 """Pretrain BERT for Inverse Cloze Task"""
 import torch
+import torch.distributed as dist
 import torch.nn.functional as F

 from megatron import get_args
@@ -71,6 +72,7 @@ def get_batch(data_iterator):

 def forward_step(data_iterator, model):
     """Forward step."""
+    args = get_args()
     timers = get_timers()

     # Get the batch.
@@ -80,21 +82,49 @@ def forward_step(data_iterator, model):

     timers('batch generator').stop()

     # Forward model.
-    retrieval_scores = model(query_tokens, query_pad_mask, block_tokens, block_pad_mask).float()
+    # retrieval_scores = model(query_tokens, query_pad_mask, block_tokens, block_pad_mask).float()
+    query_logits, block_logits = model(query_tokens, query_pad_mask, block_tokens, block_pad_mask)
+
+    data_parallel_size = dist.get_world_size() / args.model_parallel_size
+    batch_size = query_logits.shape[0]
+    global_batch_size = int(batch_size * data_parallel_size)
+    all_logits_shape = (int(global_batch_size), int(query_logits.shape[1]))
+
+    all_query_logits = torch.zeros(all_logits_shape).type(query_logits.dtype).cuda()
+    all_block_logits = all_query_logits.clone().cuda()
+    all_query_logits[args.rank * batch_size:(args.rank + 1) * batch_size] = query_logits
+    all_block_logits[args.rank * batch_size:(args.rank + 1) * batch_size] = block_logits
+
+    # print(all_query_logits[:, :5], flush=True)
+    # print(all_block_logits[:, :5], flush=True)
+    dist.all_reduce(all_query_logits)
+    dist.all_reduce(all_block_logits)
+    # print(all_query_logits[:, :5], flush=True)
+    # print(all_block_logits[:, :5], flush=True)
+
+    retrieval_scores = all_query_logits.float().matmul(torch.transpose(all_block_logits, 0, 1).float())
     softmaxed = F.softmax(retrieval_scores, dim=1)
     sorted_vals, sorted_indices = torch.topk(softmaxed, k=softmaxed.shape[1], sorted=True)

-    batch_size = softmaxed.shape[0]
-    top1_acc = torch.cuda.FloatTensor([sum([int(sorted_indices[i, 0] == i) for i in range(batch_size)]) / batch_size])
-    top5_acc = torch.cuda.FloatTensor([sum([int(i in sorted_indices[i, :5]) for i in range(batch_size)]) / batch_size])
+    def topk_acc(k):
+        return torch.cuda.FloatTensor([sum([int(i in sorted_indices[i, :k]) for i in range(global_batch_size)]) / global_batch_size])
+    top_accs = [topk_acc(k) for k in [1, 8, 20, 100]]

-    retrieval_loss = F.cross_entropy(softmaxed, torch.arange(batch_size).cuda())
-    reduced_losses = reduce_losses([retrieval_loss, top1_acc, top5_acc])
+    retrieval_loss = torch.nn.CrossEntropyLoss()(retrieval_scores, torch.arange(global_batch_size).long().cuda())
+    # correct_probs = torch.gather(softmaxed, 1, torch.arange(global_batch_size).long().cuda().reshape(-1, 1))
+    # assert correct_probs[3] == softmaxed[3, 3]
+    # retrieval_loss = -torch.sum(torch.log(correct_probs)) / global_batch_size
+    reduced_losses = reduce_losses([retrieval_loss, *top_accs])

     stats_dict = {
         'retrieval loss': reduced_losses[0],
         'top1_acc': reduced_losses[1],
-        'top5_acc': reduced_losses[2]
+        'top8_acc': reduced_losses[2],
+        'top20_acc': reduced_losses[3],
+        'top100_acc': reduced_losses[4],
     }
     return retrieval_loss, stats_dict
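The gather here deliberately avoids `all_gather`: each data-parallel rank writes its [batch, embed] logits into its own slice of a zeroed [global_batch, embed] buffer, and summing the buffers with `all_reduce` reconstructs the full matrix, because every other rank contributed zeros to that slice. One caveat worth noting: `dist.all_reduce` is not autograd-aware, so gradients flow only through each rank's local slice. A single-process illustration of the zero-pad-and-sum idea (shapes are made up for the demo):

    import torch

    world_size, batch, embed = 4, 2, 8
    per_rank = [torch.randn(batch, embed) for _ in range(world_size)]

    buffers = []
    for rank, logits in enumerate(per_rank):
        buf = torch.zeros(world_size * batch, embed)
        buf[rank * batch:(rank + 1) * batch] = logits  # each rank fills its slice
        buffers.append(buf)

    # Summing the buffers is exactly what all_reduce(SUM) would compute.
    all_logits = torch.stack(buffers).sum(0)
    assert torch.equal(all_logits, torch.cat(per_rank))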
...
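The loss change is the commit's namesake fix: `nn.CrossEntropyLoss` and `F.cross_entropy` apply log-softmax internally, so the old `F.cross_entropy(softmaxed, ...)` effectively softmaxed the scores twice and flattened the loss; the new code feeds raw `retrieval_scores`, with targets `arange(global_batch_size)` because the i-th query's positive block sits at index i of the gathered batch. A small demo of the difference:

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    scores = torch.randn(4, 4) * 5   # raw retrieval scores for a batch of 4
    labels = torch.arange(4)         # in-batch targets: query i matches block i

    correct = F.cross_entropy(scores, labels)                    # new behavior
    double = F.cross_entropy(F.softmax(scores, dim=1), labels)   # old behavior
    print(correct.item(), double.item())  # the double-softmaxed loss is much flatter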