Commit 17d897e0 authored by Mostofa Patwary's avatar Mostofa Patwary

WIP: main_retriver_merge

parent b69bc7ef
......@@ -39,7 +39,7 @@ def parse_args(extra_args_provider=None, defaults={},
parser = _add_validation_args(parser)
parser = _add_data_args(parser)
parser = _add_autoresume_args(parser)
parser = _add_realm_args(parser)
parser = _add_biencoder_args(parser)
# Custom arguments.
if extra_args_provider is not None:
......@@ -310,6 +310,8 @@ def _add_training_args(parser):
group.add_argument('--checkpoint-activations', action='store_true',
help='Checkpoint activation to allow for training '
'with larger models, sequences, and batch sizes.')
group.add_argument('--override-checkpoint-version', type=float, default=None,
help='Override checkpoint version')
group.add_argument('--distribute-checkpointed-activations',
action='store_true',
help='If set, distribute checkpointed activations '
......@@ -567,12 +569,19 @@ def _add_autoresume_args(parser):
return parser
def _add_realm_args(parser):
group = parser.add_argument_group(title='realm')
def _add_biencoder_args(parser):
group = parser.add_argument_group(title='biencoder')
# network size
group.add_argument('--ict-head-size', type=int, default=None,
help='Size of block embeddings to be used in ICT and REALM (paper default: 128)')
group.add_argument('--projection-dim', type=int, default=0,
help='Size of projection head used in biencoder (paper default: 128)')
group.add_argument('--shared-query-context-model', action='store_true',
help='Whether to share the parameters of the query and context models or not')
group.add_argument('--pool-type', type=str, default='cls-token',
choices=['avg', 'cls-token', 'max'],
help='How to pool the sequence output: avg | cls-token | max (default: cls-token)')
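For reference, a minimal sketch of what the three --pool-type options typically compute, assuming token embeddings of shape [batch, seq_len, hidden] and a pad mask that is 1 for real tokens (the function below is illustrative only, not part of this commit):
def pool_sequence(token_embeddings, pad_mask, pool_type='cls-token'):
    # Illustrative only; token_embeddings: [batch, seq_len, hidden],
    # pad_mask: [batch, seq_len] with 1 for real tokens, 0 for padding.
    if pool_type == 'cls-token':
        return token_embeddings[:, 0]  # first ([CLS]) position
    mask = pad_mask.unsqueeze(-1).float()
    if pool_type == 'avg':
        return (token_embeddings * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)
    if pool_type == 'max':
        return token_embeddings.masked_fill(mask == 0, float('-inf')).max(dim=1).values
    raise ValueError('unknown pool_type: {}'.format(pool_type))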
# checkpointing
group.add_argument('--ict-load', type=str, default=None,
......@@ -589,14 +598,16 @@ def _add_realm_args(parser):
help='Whether to use one-sentence documents in ICT')
# training
group.add_argument('--report-topk-accuracies', nargs='+', default=[],
group.add_argument('--report-topk-accuracies', nargs='+', type=int, default=[],
help="Which top-k accuracies to report (e.g. '1 5 20')")
group.add_argument('--retriever-score-scaling', action='store_true',
help="Whether to scale retriever scores by inverse square root of hidden size")
# faiss index
group.add_argument('--faiss-use-gpu', action='store_true',
help='Whether to create the FaissMIPSIndex on GPU')
group.add_argument('--block-data-path', type=str, default=None,
help='Where to save/load BlockData to/from')
#group.add_argument('--block-data-path', type=str, default=None,
# help='Where to save/load BlockData to/from')
# indexer
group.add_argument('--indexer-batch-size', type=int, default=128,
......
......@@ -9,6 +9,16 @@ from megatron import get_args
from megatron.data.dataset_utils import get_indexed_dataset_
from megatron.data.realm_dataset_utils import get_block_samples_mapping
def make_attention_mask(source_block, target_block):
"""
Returns a 2-dimensional (2-D) attention mask
:param source_block: 1-D array
:param target_block: 1-D array
"""
mask = (target_block[None, :] >= 1) * (source_block[:, None] >= 1)
mask = mask.astype(np.int64)
# (source_length, target_length)
return mask
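A quick illustration of the mask, assuming token id 0 is the pad id (which is why the function tests >= 1); the token ids below are made up:
import numpy as np

# Illustrative only: arbitrary ids, with 0 standing in for padding.
source = np.array([101, 7592, 102, 0])   # last position is padding
target = np.array([101, 2088, 102])
mask = make_attention_mask(source, target)
# mask.shape == (4, 3); the last row is all zeros, so the padded source
# position attends to nothing, while real positions attend to all real targets.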
def get_ict_dataset(use_titles=True, query_in_block_prob=1):
"""Get a dataset which uses block samples mappings to get ICT/block indexing data (via get_block())
......@@ -93,14 +103,20 @@ class ICTDataset(Dataset):
block = list(itertools.chain(*block))[:self.max_seq_length - title_pad_offset]
query_tokens, query_pad_mask = self.concat_and_pad_tokens(query)
block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)
context_tokens, context_pad_mask = self.concat_and_pad_tokens(block, title)
query_mask = make_attention_mask(query_tokens, query_tokens)
context_mask = make_attention_mask(context_tokens, context_tokens)
block_data = sample_data.as_array()
sample = {
'query_tokens': query_tokens,
'query_mask': query_mask,
'query_pad_mask': query_pad_mask,
'block_tokens': block_tokens,
'block_pad_mask': block_pad_mask,
'context_tokens': context_tokens,
'context_mask': context_mask,
'context_pad_mask': context_pad_mask,
'block_data': block_data,
}
......
......@@ -59,6 +59,12 @@ class AnnealingLR(object):
"""Learning rate decay functions from:
https://openreview.net/pdf?id=BJYwwY9ll pg. 4"""
#print_rank_0(self.warmup_steps)
#print_rank_0(self.num_steps)
#print_rank_0(self.warmup_steps)
#print_rank_0(self.max_lr)
#print_rank_0(self.max_lr * float(self.num_steps) / float(self.warmup_steps))
# Use linear warmup for the initial part.
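# e.g. (illustrative numbers only): with max_lr = 1e-4 and warmup_steps = 1000, step 100 gives lr = 1e-4 * 100 / 1000 = 1e-5.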
if self.warmup_steps > 0 and self.num_steps <= self.warmup_steps:
return self.max_lr * float(self.num_steps) / \
......@@ -97,7 +103,7 @@ class AnnealingLR(object):
new_lr = self.get_lr()
for group in self.optimizer.param_groups:
group['lr'] = new_lr
#print_rank_0(new_lr)
def state_dict(self):
state_dict = {
......
......@@ -374,6 +374,19 @@ class TransformerLanguageModelBase(MegatronModule):
# Transformer.
if self._transformer_key in state_dict:
state_dict_ = state_dict[self._transformer_key]
# for compatibility with the t5 architecture
# this is temporary until t5_main is merged
elif 'encoder' in state_dict:
state_dict_ = state_dict['encoder']
# for forward compatibility with the t5 architecture: rename '.self_attention.' keys to '.attention.'
state_dict_attention = {}
for key in state_dict_.keys():
if '.self_attention.' in key:
state_dict_attention[key.replace(".self_attention.",
".attention.")] = state_dict_[key]
else:
state_dict_attention[key] = state_dict_[key]
state_dict_ = state_dict_attention
else:
# for backward compatibility.
state_dict_ = {}
......
......@@ -214,6 +214,9 @@ class ParallelSelfAttention(MegatronModule):
mixed_x_layer, _ = self.query_key_value(hidden_states)
checkpoint_version = get_checkpoint_version()
if get_args().override_checkpoint_version is not None:
checkpoint_version = get_args().override_checkpoint_version
if checkpoint_version is not None:
if checkpoint_version == 0:
# [s, b, (3 * np * hn)] --> [s, b, (np * 3 * hn)]
......
......@@ -48,7 +48,7 @@ from megatron.model import get_params_for_weight_decay_optimization
from megatron.model.realm_model import ICTBertModel
from megatron.utils import check_adlr_autoresume_termination
from megatron.data.data_loaders import build_pretraining_data_loader
from megatron.utils import report_memory
from megatron.utils import report_memory, params_grad_norm, params_global_norm, print_model, print_grads
def print_datetime(string):
......@@ -663,11 +663,25 @@ def train_step(forward_step_func, data_iterator,
optimizer.clip_master_grads(args.clip_grad)
timers('backward-clip-grad').stop()
#print_rank_0("after backward")
#print_grads(model)
print_model(model)
print_rank_0(params_global_norm(model))
print_rank_0(params_grad_norm(model))
# Update parameters.
timers('optimizer').start()
optimizer.step()
timers('optimizer').stop()
#print_rank_0("after optimizer")
#print_model(model)
print_rank_0(params_global_norm(model))
#print_rank_0(params_grad_norm(model))
#sys.exit()
# Update learning rate.
skipped_iter = 0
if not (args.fp16 and optimizer.overflow):
......@@ -905,9 +919,9 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
# Exiting based on iterations
if args.exit_interval and iteration % args.exit_interval == 0:
if not saved_checkpoint:
save_checkpoint_and_time(iteration, model, optimizer,
lr_scheduler)
#if not saved_checkpoint:
# save_checkpoint_and_time(iteration, model, optimizer,
# lr_scheduler)
torch.distributed.barrier()
print_datetime('exiting program at iteration {}'.format(iteration))
sys.exit()
......
......@@ -150,4 +150,40 @@ def get_ltor_masks_and_position_ids(data,
return attention_mask, loss_mask, position_ids
def params_grad_norm(model):
print_rank_0("params_grad_norm")
norm2 = torch.cuda.FloatTensor([0.0])
for param in model.parameters():
if param.grad is None:
continue
norm = torch.norm(param.grad.data.float(), 2)
norm2 += norm * norm
torch.distributed.all_reduce(norm2)
norm = norm2 ** 0.5
return norm.item()
def params_global_norm(model):
print_rank_0("params_global_norm")
norm2 = torch.cuda.FloatTensor([0.0])
for param in model.parameters():
norm = torch.norm(param.data.float(), 2)
norm2 += norm * norm
torch.distributed.all_reduce(norm2)
norm = norm2 ** 0.5
return norm.item()
def print_model(model):
print_rank_0("print-model")
for name, param in model.named_parameters():
if param.requires_grad:
#print("{} {}".format(name, param.data), flush=True)
print_rank_0("{} {}".format(name, param.data))
return
def print_grads(model):
print_rank_0("print-grads")
for name, param in model.named_parameters():
if param.grad is None:
continue
print_rank_0("{} {}".format(name, param.grad))
......@@ -14,6 +14,8 @@
# limitations under the License.
"""Pretrain BERT for Inverse Cloze Task"""
import sys
import math
import torch
import torch.distributed as dist
......@@ -26,14 +28,16 @@ from megatron import mpu
from megatron.data.dataset_utils import build_train_valid_test_datasets
from megatron.training import pretrain
from megatron.utils import average_losses_across_data_parallel_group
from megatron.model.realm_model import general_ict_model_provider
from megatron.data.realm_dataset_utils import get_ict_batch
from megatron.model.biencoder_model import biencoder_model_provider
from megatron.data.biencoder_dataset_utils import get_ict_batch
def pretrain_ict_model_provider():
args = get_args()
return general_ict_model_provider(False, False)
model = biencoder_model_provider(only_context_model=False,
only_query_model=False,
shared_query_context_model=args.shared_query_context_model)
return model
def get_group_world_size_rank():
......@@ -72,7 +76,6 @@ class AllgatherFromDataParallelRegion(torch.autograd.Function):
output = output_list[rank].contiguous()
return output
def forward_step(data_iterator, model, input_tensor):
"""Forward step."""
args = get_args()
......@@ -80,37 +83,76 @@ def forward_step(data_iterator, model, input_tensor):
# Get the batch.
timers('batch-generator').start()
query_tokens, query_pad_mask, \
block_tokens, block_pad_mask, block_indices = get_ict_batch(data_iterator)
query_tokens, query_mask, \
context_tokens, context_mask, context_indices = get_ict_batch(data_iterator)
timers('batch-generator').stop()
# Query and Context Types
query_types = torch.cuda.LongTensor(*query_tokens.shape).fill_(0)
context_types = torch.cuda.LongTensor(*context_tokens.shape).fill_(0)
# Forward model.
query_logits, block_logits = model(query_tokens, query_pad_mask, block_tokens, block_pad_mask)
micro_batch_size = query_logits.shape[0]
global_batch_size = dist.get_world_size() * micro_batch_size # recall we assert that tensor_model_parallel_size == 1
#print_rank_0(query_tokens)
#print_rank_0(context_tokens)
#print_rank_0(torch.sum(query_types))
#print_rank_0(torch.sum(query_mask))
#print_rank_0(torch.sum(context_types))
#print_rank_0(torch.sum(context_mask))
all_query_logits = AllgatherFromDataParallelRegion.apply(query_logits)
all_block_logits = AllgatherFromDataParallelRegion.apply(block_logits)
#print_rank_0(params_global_norm(model))
#print_rank_0(params_grad_norm(model))
# Forward model.
query_logits, context_logits = model(query_tokens, query_mask,
query_types, context_tokens,
context_mask, context_types)
#print_rank_0(query_logits)
#print_rank_0(context_logits)
# scores are inner products between query and block embeddings
retrieval_scores = all_query_logits.float().matmul(torch.transpose(all_block_logits, 0, 1).float())
softmaxed = F.softmax(retrieval_scores, dim=1)
sorted_vals, sorted_indices = torch.topk(softmaxed, k=softmaxed.shape[1], sorted=True)
micro_batch_size = query_logits.shape[0]
# recall we assert that tensor_model_parallel_size == 1
#global_batch_size = dist.get_world_size() * micro_batch_size
#all_query_logits = AllgatherFromDataParallelRegion.apply(query_logits)
#all_context_logits = AllgatherFromDataParallelRegion.apply(context_logits)
global_batch_size = micro_batch_size
all_query_logits = query_logits
all_context_logits = context_logits
# scores are inner products between query and context embeddings
retrieval_scores = torch.matmul(all_query_logits,
torch.transpose(all_context_logits, 0, 1))
# scaling the retriever scores
if args.retriever_score_scaling:
retrieval_scores = retrieval_scores / math.sqrt(args.hidden_size)
softmax_scores = F.log_softmax(retrieval_scores, dim=1)
sorted_vals, sorted_indices = torch.topk(softmax_scores,
k=softmax_scores.shape[1], sorted=True)
def topk_accuracy(k):
return torch.cuda.FloatTensor([sum([int(i in sorted_indices[i, :k]) for i in range(global_batch_size)]) / global_batch_size])
return torch.cuda.FloatTensor([sum([int(i in sorted_indices[i, :k]) \
for i in range(global_batch_size)]) / global_batch_size])
topk_accs = [topk_accuracy(int(k)) for k in args.report_topk_accuracies]
retrieval_loss = torch.nn.CrossEntropyLoss()(retrieval_scores, torch.arange(global_batch_size).long().cuda())
retrieval_loss = retrieval_loss.float()
averaged_losses = average_losses_across_data_parallel_group([retrieval_loss, *topk_accs])
# create stats_dict with retrieval loss and all specified top-k accuracies
topk_acc_dict = {'top{}_acc'.format(k): v for k, v in zip(args.report_topk_accuracies, averaged_losses[1:])}
stats_dict = dict(retrieval_loss=averaged_losses[0], **topk_acc_dict)
labels = torch.arange(global_batch_size).long().cuda()
loss = F.nll_loss(softmax_scores, labels, reduction='mean')
reduced_losses = average_losses_across_data_parallel_group([loss, *topk_accs])
# Scale the retrieval loss by the data-parallel world size
loss = loss * mpu.get_data_parallel_world_size()
return retrieval_loss, stats_dict
#retrieval_loss = torch.nn.CrossEntropyLoss()(retrieval_scores, torch.arange(global_batch_size).long().cuda())
#retrieval_loss = retrieval_loss.float()
#averaged_losses = average_losses_across_data_parallel_group([retrieval_loss, *topk_accs])
# create stats_dict with retrieval loss and all specified top-k accuracies
topk_acc_dict = {'top{}_acc'.format(k): v * 100 for k, v in \
zip(args.report_topk_accuracies, reduced_losses[1:])}
stats_dict = dict(loss=reduced_losses[0], **topk_acc_dict)
#print_rank_0(loss)
#print_rank_0(stats_dict)
#sys.exit()
return loss, stats_dict
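To summarize the objective implemented above: queries and contexts from the same batch form the positives on the diagonal of the score matrix, and every other entry in a row acts as an in-batch negative. A self-contained sketch follows (illustrative only; it mirrors the shapes and the --retriever-score-scaling behavior of the code above but is not part of the commit):
import math
import torch
import torch.nn.functional as F

def in_batch_retrieval_loss(query_embeds, context_embeds, hidden_size, score_scaling=False):
    # query_embeds, context_embeds: [batch, dim]; row i of each side is a positive pair.
    scores = torch.matmul(query_embeds, context_embeds.t())       # [batch, batch]
    if score_scaling:
        scores = scores / math.sqrt(hidden_size)                  # cf. --retriever-score-scaling
    labels = torch.arange(scores.shape[0], device=scores.device)  # positives sit on the diagonal
    return F.nll_loss(F.log_softmax(scores, dim=1), labels, reduction='mean')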
def train_valid_test_datasets_provider(train_val_test_num_samples):
......@@ -136,5 +178,7 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
if __name__ == "__main__":
pretrain(train_valid_test_datasets_provider, pretrain_ict_model_provider, forward_step,
pretrain(train_valid_test_datasets_provider,
pretrain_ict_model_provider,
forward_step,
args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})