Commit 17d897e0 authored by Mostofa Patwary

WIP: main_retriver_merge

parent b69bc7ef
@@ -39,7 +39,7 @@ def parse_args(extra_args_provider=None, defaults={},
     parser = _add_validation_args(parser)
     parser = _add_data_args(parser)
     parser = _add_autoresume_args(parser)
-    parser = _add_realm_args(parser)
+    parser = _add_biencoder_args(parser)
     # Custom arguments.
     if extra_args_provider is not None:
@@ -310,6 +310,8 @@ def _add_training_args(parser):
     group.add_argument('--checkpoint-activations', action='store_true',
                        help='Checkpoint activation to allow for training '
                        'with larger models, sequences, and batch sizes.')
+    group.add_argument('--override-checkpoint-version', type=float, default=None,
+                       help='Override checkpoint version')
     group.add_argument('--distribute-checkpointed-activations',
                        action='store_true',
                        help='If set, distribute checkpointed activations '
@@ -567,12 +569,19 @@ def _add_autoresume_args(parser):
     return parser


-def _add_realm_args(parser):
-    group = parser.add_argument_group(title='realm')
+def _add_biencoder_args(parser):
+    group = parser.add_argument_group(title='biencoder')

     # network size
     group.add_argument('--ict-head-size', type=int, default=None,
                        help='Size of block embeddings to be used in ICT and REALM (paper default: 128)')
+    group.add_argument('--projection-dim', type=int, default=0,
+                       help='Size of projection head used in biencoder (paper default: 128)')
+    group.add_argument('--shared-query-context-model', action='store_true',
+                       help='Whether to share the parameters of the query and context models or not')
+    group.add_argument('--pool-type', type=str, default='cls-token',
+                       choices=['avg', 'cls-token', 'max'],
+                       help='different options are: avg | cls-token | max, default=cls-token')

     # checkpointing
     group.add_argument('--ict-load', type=str, default=None,
@@ -589,14 +598,16 @@ def _add_realm_args(parser):
                        help='Whether to use one sentence documents in ICT')

     # training
-    group.add_argument('--report-topk-accuracies', nargs='+', default=[],
+    group.add_argument('--report-topk-accuracies', nargs='+', type=int, default=[],
                        help="Which top-k accuracies to report (e.g. '1 5 20')")
+    group.add_argument('--retriever-score-scaling', action='store_true',
+                       help='Whether to scale retriever scores by inverse square root of hidden size')

     # faiss index
     group.add_argument('--faiss-use-gpu', action='store_true',
                        help='Whether create the FaissMIPSIndex on GPU')
-    group.add_argument('--block-data-path', type=str, default=None,
-                       help='Where to save/load BlockData to/from')
+    #group.add_argument('--block-data-path', type=str, default=None,
+    #                   help='Where to save/load BlockData to/from')

     # indexer
     group.add_argument('--indexer-batch-size', type=int, default=128,
...
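The biencoder flag changes above are easiest to sanity-check in isolation. Below is a minimal sketch using plain argparse (outside Megatron, with made-up values) that mirrors the added options; in particular it shows why adding type=int to --report-topk-accuracies matters, since without it '1 5 20' would parse as strings.

import argparse

# Standalone mirror of the biencoder argument group added in this commit.
parser = argparse.ArgumentParser()
group = parser.add_argument_group(title='biencoder')
group.add_argument('--projection-dim', type=int, default=0)
group.add_argument('--shared-query-context-model', action='store_true')
group.add_argument('--pool-type', type=str, default='cls-token',
                   choices=['avg', 'cls-token', 'max'])
group.add_argument('--report-topk-accuracies', nargs='+', type=int, default=[])
group.add_argument('--retriever-score-scaling', action='store_true')

args = parser.parse_args(['--projection-dim', '128',
                          '--shared-query-context-model',
                          '--report-topk-accuracies', '1', '5', '20',
                          '--retriever-score-scaling'])
print(args.report_topk_accuracies)   # [1, 5, 20] as ints, thanks to type=int
print(args.shared_query_context_model, args.retriever_score_scaling)   # True True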
@@ -9,6 +9,16 @@ from megatron import get_args
 from megatron.data.dataset_utils import get_indexed_dataset_
 from megatron.data.realm_dataset_utils import get_block_samples_mapping

+def make_attention_mask(source_block, target_block):
+    """
+    Returns a 2-dimensional (2-D) attention mask
+    :param source_block: 1-D array
+    :param target_block: 1-D array
+    """
+    mask = (target_block[None, :] >= 1) * (source_block[:, None] >= 1)
+    mask = mask.astype(np.int64)
+    # (source_length, target_length)
+    return mask
+
 def get_ict_dataset(use_titles=True, query_in_block_prob=1):
     """Get a dataset which uses block samples mappings to get ICT/block indexing data (via get_block())

@@ -93,14 +103,20 @@ class ICTDataset(Dataset):
         block = list(itertools.chain(*block))[:self.max_seq_length - title_pad_offset]

         query_tokens, query_pad_mask = self.concat_and_pad_tokens(query)
-        block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)
+        context_tokens, context_pad_mask = self.concat_and_pad_tokens(block, title)
+
+        query_mask = make_attention_mask(query_tokens, query_tokens)
+        context_mask = make_attention_mask(context_tokens, context_tokens)

         block_data = sample_data.as_array()

         sample = {
             'query_tokens': query_tokens,
+            'query_mask': query_mask,
             'query_pad_mask': query_pad_mask,
-            'block_tokens': block_tokens,
-            'block_pad_mask': block_pad_mask,
+            'context_tokens': context_tokens,
+            'context_mask': context_mask,
+            'context_pad_mask': context_pad_mask,
             'block_data': block_data,
         }
...
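make_attention_mask builds the 2-D mask as an outer product of two padding indicators: a position pair is 1 only when both token ids are >= 1 (id 0 is padding). A small NumPy sketch with toy ids:

import numpy as np

def make_attention_mask(source_block, target_block):
    # 1 where both positions hold real tokens (id >= 1), 0 where either is padding.
    mask = (target_block[None, :] >= 1) * (source_block[:, None] >= 1)
    return mask.astype(np.int64)

query = np.array([101, 2054, 2003, 102, 0, 0])   # toy token ids, 0 = pad
mask = make_attention_mask(query, query)
print(mask.shape)    # (6, 6): (source_length, target_length)
print(mask[:, 4])    # all zeros: no position attends to the padded slot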
@@ -59,6 +59,12 @@ class AnnealingLR(object):
         """Learning rate decay functions from:
              https://openreview.net/pdf?id=BJYwwY9ll pg. 4"""

+        #print_rank_0(self.warmup_steps)
+        #print_rank_0(self.num_steps)
+        #print_rank_0(self.warmup_steps)
+        #print_rank_0(self.max_lr)
+        #print_rank_0(self.max_lr * float(self.num_steps) / float(self.warmup_steps))
+
         # Use linear warmup for the initial part.
         if self.warmup_steps > 0 and self.num_steps <= self.warmup_steps:
             return self.max_lr * float(self.num_steps) / \

@@ -97,7 +103,7 @@ class AnnealingLR(object):
         new_lr = self.get_lr()
         for group in self.optimizer.param_groups:
             group['lr'] = new_lr
+        #print_rank_0(new_lr)

     def state_dict(self):
         state_dict = {
...
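The commented-out prints probe the linear-warmup branch of get_lr(), where the rate is simply max_lr scaled by num_steps / warmup_steps. A tiny worked example of that formula with assumed values:

# Linear warmup: lr rises linearly from ~0 to max_lr over warmup_steps.
max_lr = 1.0e-4
warmup_steps = 1000

for num_steps in (1, 250, 500, 1000):
    if warmup_steps > 0 and num_steps <= warmup_steps:
        lr = max_lr * float(num_steps) / float(warmup_steps)
        print(num_steps, lr)   # 1e-07, 2.5e-05, 5e-05, 1e-04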
@@ -374,6 +374,19 @@ class TransformerLanguageModelBase(MegatronModule):
         # Transformer.
         if self._transformer_key in state_dict:
             state_dict_ = state_dict[self._transformer_key]
+        # for compatibility with the t5 architecture;
+        # this is temporary until t5_main is merged
+        elif 'encoder' in state_dict:
+            state_dict_ = state_dict['encoder']
+            # for forward compatibility with the t5 architecture
+            state_dict_attention = {}
+            for key in state_dict_.keys():
+                if '.self_attention.' in key:
+                    state_dict_attention[key.replace(".self_attention.",
+                        ".attention.")] = state_dict_[key]
+                else:
+                    state_dict_attention[key] = state_dict_[key]
+            state_dict_ = state_dict_attention
         else:
             # for backward compatibility.
             state_dict_ = {}
...
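The new elif branch loads a T5-style checkpoint by taking the 'encoder' entry and renaming '.self_attention.' keys back to '.attention.'. The renaming logic in isolation, on a dummy state dict (the key names below are illustrative, not the actual Megatron parameter names):

from collections import OrderedDict

# Toy checkpoint fragment using T5-style attention naming.
state_dict_ = OrderedDict([
    ('layers.0.self_attention.query_key_value.weight', 'W_qkv'),
    ('layers.0.self_attention.dense.weight', 'W_o'),
    ('layers.0.mlp.dense_h_to_4h.weight', 'W_mlp'),
])

# Same remapping as the diff: rewrite attention keys, pass everything else through.
state_dict_attention = {}
for key in state_dict_.keys():
    if '.self_attention.' in key:
        state_dict_attention[key.replace('.self_attention.', '.attention.')] = state_dict_[key]
    else:
        state_dict_attention[key] = state_dict_[key]

print(list(state_dict_attention))
# ['layers.0.attention.query_key_value.weight',
#  'layers.0.attention.dense.weight',
#  'layers.0.mlp.dense_h_to_4h.weight']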
@@ -214,6 +214,9 @@ class ParallelSelfAttention(MegatronModule):
         mixed_x_layer, _ = self.query_key_value(hidden_states)

         checkpoint_version = get_checkpoint_version()
+        if get_args().override_checkpoint_version is not None:
+            checkpoint_version = get_args().override_checkpoint_version
+
         if checkpoint_version is not None:
             if checkpoint_version == 0:
                 # [s, b, (3 * np * hn)] --> [s, b, (np * 3 * hn)]

@@ -472,7 +475,7 @@ class ParallelTransformerLayer(MegatronModule):
         # MLP.
         mlp_output, mlp_bias = self.mlp(layernorm_output)

         # Second residual connection.
         if self.apply_residual_connection_post_layernorm:
             residual = layernorm_output
...
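--override-checkpoint-version simply takes precedence over whatever get_checkpoint_version() returns, which helps when loading checkpoints whose stored metadata does not match the QKV tensor layout they actually use. A sketch of just the selection logic (resolve_checkpoint_version is a hypothetical helper for illustration, not part of the diff):

def resolve_checkpoint_version(stored_version, override_version):
    # Prefer the explicit CLI override; otherwise fall back to the stored value.
    # Either may be None; the result selects the QKV reshaping path in
    # ParallelSelfAttention (e.g. version 0 vs. a later layout).
    return override_version if override_version is not None else stored_version

assert resolve_checkpoint_version(None, 1.0) == 1.0    # override wins
assert resolve_checkpoint_version(0, None) == 0        # stored value used
assert resolve_checkpoint_version(None, None) is None  # no reshaping applied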
@@ -48,7 +48,7 @@ from megatron.model import get_params_for_weight_decay_optimization
 from megatron.model.realm_model import ICTBertModel
 from megatron.utils import check_adlr_autoresume_termination
 from megatron.data.data_loaders import build_pretraining_data_loader
-from megatron.utils import report_memory
+from megatron.utils import report_memory, params_grad_norm, params_global_norm, print_model, print_grads


 def print_datetime(string):

@@ -663,11 +663,25 @@ def train_step(forward_step_func, data_iterator,
             optimizer.clip_master_grads(args.clip_grad)
         timers('backward-clip-grad').stop()

+    #print_rank_0("after backward")
+    #print_grads(model)
+    print_model(model)
+    print_rank_0(params_global_norm(model))
+    print_rank_0(params_grad_norm(model))
+
     # Update parameters.
     timers('optimizer').start()
     optimizer.step()
     timers('optimizer').stop()

+    #print_rank_0("after optimizer")
+    #print_model(model)
+    print_rank_0(params_global_norm(model))
+    #print_rank_0(params_grad_norm(model))
+    #sys.exit()
+
     # Update learning rate.
     skipped_iter = 0
     if not (args.fp16 and optimizer.overflow):

@@ -905,9 +919,9 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
         # Exiting based on iterations
         if args.exit_interval and iteration % args.exit_interval == 0:
-            if not saved_checkpoint:
-                save_checkpoint_and_time(iteration, model, optimizer,
-                                         lr_scheduler)
+            #if not saved_checkpoint:
+            #    save_checkpoint_and_time(iteration, model, optimizer,
+            #                             lr_scheduler)
             torch.distributed.barrier()
             print_datetime('exiting program at iteration {}'.format(iteration))
             sys.exit()
...
@@ -150,4 +150,40 @@ def get_ltor_masks_and_position_ids(data,
     return attention_mask, loss_mask, position_ids


+def params_grad_norm(model):
+    print_rank_0("params_grad_norm")
+    norm2 = torch.cuda.FloatTensor([0.0])
+    for param in model.parameters():
+        if param.grad is None:
+            continue
+        norm = torch.norm(param.grad.data.float(), 2)
+        norm2 += norm * norm
+    torch.distributed.all_reduce(norm2)
+    norm = norm2 ** 0.5
+    return norm.item()
+
+
+def params_global_norm(model):
+    print_rank_0("params_global_norm")
+    norm2 = torch.cuda.FloatTensor([0.0])
+    for param in model.parameters():
+        norm = torch.norm(param.data.float(), 2)
+        norm2 += norm * norm
+    torch.distributed.all_reduce(norm2)
+    norm = norm2 ** 0.5
+    return norm.item()
+
+
+def print_model(model):
+    print_rank_0("print-model")
+    for name, param in model.named_parameters():
+        if param.requires_grad:
+            #print("{} {}".format(name, param.data), flush=True)
+            print_rank_0("{} {}".format(name, param.data))
+    return
+
+
+def print_grads(model):
+    print_rank_0("print-grads")
+    for name, param in model.named_parameters():
+        if param.grad is None:
+            continue
+        print_rank_0("{} {}".format(name, param.grad))
@@ -14,6 +14,8 @@
 # limitations under the License.

 """Pretrain BERT for Inverse Cloze Task"""
+import sys
+import math

 import torch
 import torch.distributed as dist

@@ -26,14 +28,16 @@ from megatron import mpu
 from megatron.data.dataset_utils import build_train_valid_test_datasets
 from megatron.training import pretrain
 from megatron.utils import average_losses_across_data_parallel_group
-from megatron.model.realm_model import general_ict_model_provider
-from megatron.data.realm_dataset_utils import get_ict_batch
+from megatron.model.biencoder_model import biencoder_model_provider
+from megatron.data.biencoder_dataset_utils import get_ict_batch


 def pretrain_ict_model_provider():
     args = get_args()
-    return general_ict_model_provider(False, False)
+    model = biencoder_model_provider(only_context_model=False,
+                                     only_query_model=False,
+                                     shared_query_context_model=args.shared_query_context_model)
+    return model


 def get_group_world_size_rank():
@@ -72,7 +76,6 @@ class AllgatherFromDataParallelRegion(torch.autograd.Function):
         output = output_list[rank].contiguous()
         return output

 def forward_step(data_iterator, model, input_tensor):
     """Forward step."""
     args = get_args()
@@ -80,37 +83,76 @@ def forward_step(data_iterator, model, input_tensor):
     # Get the batch.
     timers('batch-generator').start()
-    query_tokens, query_pad_mask, \
-    block_tokens, block_pad_mask, block_indices = get_ict_batch(data_iterator)
+    query_tokens, query_mask, \
+    context_tokens, context_mask, context_indices = get_ict_batch(data_iterator)
     timers('batch-generator').stop()

+    # Query and Context Types
+    query_types = torch.cuda.LongTensor(*query_tokens.shape).fill_(0)
+    context_types = torch.cuda.LongTensor(*context_tokens.shape).fill_(0)
+
+    #print_rank_0(query_tokens)
+    #print_rank_0(context_tokens)
+    #print_rank_0(torch.sum(query_types))
+    #print_rank_0(torch.sum(query_mask))
+    #print_rank_0(torch.sum(context_types))
+    #print_rank_0(torch.sum(context_mask))
+    #print_rank_0(params_global_norm(model))
+    #print_rank_0(params_grad_norm(model))
+
     # Forward model.
-    query_logits, block_logits = model(query_tokens, query_pad_mask, block_tokens, block_pad_mask)
+    query_logits, context_logits = model(query_tokens, query_mask,
+                                         query_types, context_tokens,
+                                         context_mask, context_types)
+    #print_rank_0(query_logits)
+    #print_rank_0(context_logits)

     micro_batch_size = query_logits.shape[0]
-    global_batch_size = dist.get_world_size() * micro_batch_size # recall we assert that tensor_model_parallel_size == 1
-
-    all_query_logits = AllgatherFromDataParallelRegion.apply(query_logits)
-    all_block_logits = AllgatherFromDataParallelRegion.apply(block_logits)
+    # recall we assert that tensor_model_parallel_size == 1
+    #global_batch_size = dist.get_world_size() * micro_batch_size
+    #all_query_logits = AllgatherFromDataParallelRegion.apply(query_logits)
+    #all_context_logits = AllgatherFromDataParallelRegion.apply(context_logits)
+    global_batch_size = micro_batch_size
+    all_query_logits = query_logits
+    all_context_logits = context_logits

-    # scores are inner products between query and block embeddings
-    retrieval_scores = all_query_logits.float().matmul(torch.transpose(all_block_logits, 0, 1).float())
-    softmaxed = F.softmax(retrieval_scores, dim=1)
-
-    sorted_vals, sorted_indices = torch.topk(softmaxed, k=softmaxed.shape[1], sorted=True)
+    # scores are inner products between query and context embeddings
+    retrieval_scores = torch.matmul(all_query_logits,
+                                    torch.transpose(all_context_logits, 0, 1))
+
+    # scaling the retriever scores
+    if args.retriever_score_scaling:
+        retrieval_scores = retrieval_scores / math.sqrt(args.hidden_size)
+
+    softmax_scores = F.log_softmax(retrieval_scores, dim=1)
+    sorted_vals, sorted_indices = torch.topk(softmax_scores,
+                                             k=softmax_scores.shape[1], sorted=True)

     def topk_accuracy(k):
-        return torch.cuda.FloatTensor([sum([int(i in sorted_indices[i, :k]) for i in range(global_batch_size)]) / global_batch_size])
+        return torch.cuda.FloatTensor([sum([int(i in sorted_indices[i, :k]) \
+            for i in range(global_batch_size)]) / global_batch_size])

     topk_accs = [topk_accuracy(int(k)) for k in args.report_topk_accuracies]

-    retrieval_loss = torch.nn.CrossEntropyLoss()(retrieval_scores, torch.arange(global_batch_size).long().cuda())
-    retrieval_loss = retrieval_loss.float()
-    averaged_losses = average_losses_across_data_parallel_group([retrieval_loss, *topk_accs])
+    labels = torch.arange(global_batch_size).long().cuda()
+    loss = F.nll_loss(softmax_scores, labels, reduction='mean')
+    reduced_losses = average_losses_across_data_parallel_group([loss, *topk_accs])
+
+    # Scale the retrieval loss
+    loss = loss * mpu.get_data_parallel_world_size()

     # create stats_dict with retrieval loss and all specified top-k accuracies
-    topk_acc_dict = {'top{}_acc'.format(k): v for k, v in zip(args.report_topk_accuracies, averaged_losses[1:])}
-    stats_dict = dict(retrieval_loss=averaged_losses[0], **topk_acc_dict)
+    topk_acc_dict = {'top{}_acc'.format(k): v * 100 for k, v in \
+        zip(args.report_topk_accuracies, reduced_losses[1:])}
+    stats_dict = dict(loss=reduced_losses[0], **topk_acc_dict)
+
+    #print_rank_0(loss)
+    #print_rank_0(stats_dict)
+    #sys.exit()

-    return retrieval_loss, stats_dict
+    return loss, stats_dict
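The rewritten forward_step trains with an in-batch softmax: row i of the query-context score matrix should put its probability mass on column i, so the labels are just arange(batch). A self-contained sketch of that loss on random embeddings (score scaling is applied unconditionally here, and the data-parallel loss averaging/scaling from the diff is omitted):

import math
import torch
import torch.nn.functional as F

batch_size, hidden_size = 8, 16
query_logits = torch.randn(batch_size, hidden_size)
context_logits = torch.randn(batch_size, hidden_size)

# Scores: inner product between every query and every context in the batch.
retrieval_scores = torch.matmul(query_logits, context_logits.t())
# Optional scaling, as enabled by --retriever-score-scaling.
retrieval_scores = retrieval_scores / math.sqrt(hidden_size)

softmax_scores = F.log_softmax(retrieval_scores, dim=1)
labels = torch.arange(batch_size)   # the matching context sits on the diagonal
loss = F.nll_loss(softmax_scores, labels, reduction='mean')

# Top-1 accuracy: how often the true context gets the highest score.
top1 = (retrieval_scores.argmax(dim=1) == labels).float().mean()
print(loss.item(), top1.item())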
def train_valid_test_datasets_provider(train_val_test_num_samples): def train_valid_test_datasets_provider(train_val_test_num_samples):
@@ -136,5 +178,7 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):

 if __name__ == "__main__":
-    pretrain(train_valid_test_datasets_provider, pretrain_ict_model_provider, forward_step,
+    pretrain(train_valid_test_datasets_provider,
+             pretrain_ict_model_provider,
+             forward_step,
              args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})