Commit 06076c7a authored by Mostofa Patwary

Implementation of DPR (Dense Passage Retrieval)

parent cdde4338
......@@ -17,7 +17,9 @@ from .module import MegatronModule
def biencoder_model_provider(only_query_model=False,
only_context_model=False,
biencoder_shared_query_context_model=False):
biencoder_shared_query_context_model=False,
pre_process=True,
post_process=True):
"""Build the model."""
args = get_args()
......@@ -35,7 +37,9 @@ def biencoder_model_provider(only_query_model=False,
only_query_model=only_query_model,
only_context_model=only_context_model,
biencoder_shared_query_context_model=\
biencoder_shared_query_context_model)
biencoder_shared_query_context_model,
pre_process=pre_process,
post_process=post_process)
return model
......@@ -48,13 +52,17 @@ class BiEncoderModel(MegatronModule):
parallel_output=True,
only_query_model=False,
only_context_model=False,
biencoder_shared_query_context_model=False):
biencoder_shared_query_context_model=False,
pre_process=True,
post_process=True):
super(BiEncoderModel, self).__init__()
args = get_args()
bert_kwargs = dict(
num_tokentypes=num_tokentypes,
parallel_output=parallel_output)
parallel_output=parallel_output,
pre_process=pre_process,
post_process=post_process)
self.biencoder_shared_query_context_model = \
biencoder_shared_query_context_model
......@@ -78,6 +86,19 @@ class BiEncoderModel(MegatronModule):
self.context_model = PretrainedBertModel(**bert_kwargs)
self._context_key = 'context_model'
def set_input_tensor(self, input_tensor):
"""See megatron.model.transformer.set_input_tensor()"""
# The biencoder wraps separate query/context (or one shared) BERT
# towers rather than a single self.language_model, so the pipeline
# input tensor is not forwarded here; this method is currently a no-op.
return
def forward(self, query_tokens, query_attention_mask, query_types,
context_tokens, context_attention_mask, context_types):
"""Run a forward pass for each of the models and
......@@ -217,7 +238,7 @@ class PretrainedBertModel(MegatronModule):
learned information retrieval."""
def __init__(self, num_tokentypes=2,
parallel_output=True):
parallel_output=True, pre_process=True, post_process=True):
super(PretrainedBertModel, self).__init__()
args = get_args()
......@@ -225,6 +246,8 @@ class PretrainedBertModel(MegatronModule):
self.pad_id = tokenizer.pad
self.biencoder_projection_dim = args.biencoder_projection_dim
self.parallel_output = parallel_output
self.pre_process = pre_process
self.post_process = post_process
init_method = init_method_normal(args.init_method_std)
scaled_init_method = scaled_init_method_normal(
args.init_method_std, args.num_layers)
......@@ -234,7 +257,9 @@ class PretrainedBertModel(MegatronModule):
add_pooler=False,
encoder_attn_mask_type=AttnMaskType.padding,
init_method=init_method,
scaled_init_method=scaled_init_method)
scaled_init_method=scaled_init_method,
pre_process=self.pre_process,
post_process=self.post_process)
if args.biencoder_projection_dim > 0:
self.projection_enc = get_linear_layer(args.hidden_size,
......
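The pre_process/post_process flags threaded through biencoder_model_provider, BiEncoderModel and PretrainedBertModel follow Megatron's pipeline-parallel convention: the first pipeline stage owns the embeddings (pre_process) and the last stage owns the output head (post_process), while intermediate stages receive activations through set_input_tensor. A minimal, self-contained sketch of that pattern is below; the class and attribute names are illustrative, not Megatron APIs.

import torch
import torch.nn as nn

class StageAwareEncoder(nn.Module):
    """Toy encoder that gates embedding/head layers on pipeline-stage flags."""

    def __init__(self, vocab_size=128, hidden=16,
                 pre_process=True, post_process=True):
        super().__init__()
        self.pre_process = pre_process
        self.post_process = post_process
        self.input_tensor = None
        # Only the first pipeline stage builds the embedding.
        self.embedding = nn.Embedding(vocab_size, hidden) if pre_process else None
        self.layer = nn.Linear(hidden, hidden)
        # Only the last pipeline stage builds the projection head.
        self.head = nn.Linear(hidden, hidden) if post_process else None

    def set_input_tensor(self, input_tensor):
        # Intermediate/last stages get their input from the previous stage.
        self.input_tensor = input_tensor

    def forward(self, token_ids=None):
        hidden = self.embedding(token_ids) if self.pre_process else self.input_tensor
        hidden = torch.tanh(self.layer(hidden))
        return self.head(hidden) if self.post_process else hidden

# First-stage usage: embeds tokens; a later stage would instead call
# set_input_tensor() with activations received over the pipeline.
stage0 = StageAwareEncoder(pre_process=True, post_process=False)
out = stage0(torch.randint(0, 128, (2, 5)))
print(out.shape)  # torch.Size([2, 5, 16])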
......@@ -181,6 +181,35 @@ class FullTokenizer(object):
def convert_ids_to_tokens(self, ids):
return convert_by_vocab(self.inv_vocab, ids)
@staticmethod
def convert_tokens_to_string(tokens, clean_up_tokenization_spaces=True):
""" Converts a sequence of tokens (string) in a single string. """
def clean_up_tokenization(out_string):
""" Clean up a list of simple English tokenization artifacts
like spaces before punctuations and abreviated forms.
"""
out_string = (
out_string.replace(" .", ".")
.replace(" ?", "?")
.replace(" !", "!")
.replace(" ,", ",")
.replace(" ' ", "'")
.replace(" n't", "n't")
.replace(" 'm", "'m")
.replace(" 's", "'s")
.replace(" 've", "'ve")
.replace(" 're", "'re")
)
return out_string
text = ' '.join(tokens).replace(' ##', '').strip()
if clean_up_tokenization_spaces:
clean_text = clean_up_tokenization(text)
return clean_text
else:
return text
def vocab_size(self):
return len(self.vocab)
......
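For reference, the new convert_tokens_to_string joins WordPiece tokens, strips the '##' continuation markers, and then optionally undoes the spaces the basic tokenizer inserted around punctuation and contractions. The snippet below reimplements that logic standalone for illustration; it is not a call into the class itself.

def tokens_to_string(tokens, clean_up_tokenization_spaces=True):
    """Join WordPiece tokens and optionally clean detokenization artifacts."""
    text = ' '.join(tokens).replace(' ##', '').strip()
    if not clean_up_tokenization_spaces:
        return text
    for src, dst in [(" .", "."), (" ?", "?"), (" !", "!"), (" ,", ","),
                     (" ' ", "'"), (" n't", "n't"), (" 'm", "'m"),
                     (" 's", "'s"), (" 've", "'ve"), (" 're", "'re")]:
        text = text.replace(src, dst)
    return text

print(tokens_to_string(['it', "'", 's', 'a', 'ret', '##riev', '##er', '.']))
# -> "it's a retriever."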
......@@ -155,6 +155,10 @@ class _BertWordPieceTokenizer(AbstractTokenizer):
text_tokens = self.tokenizer.tokenize(text)
return self.tokenizer.convert_tokens_to_ids(text_tokens)
def decode(self, ids):
tokens = self.tokenizer.convert_ids_to_tokens(ids)
return self.tokenizer.convert_tokens_to_string(tokens)
def decode_token_ids(self, token_ids):
tokens = self.tokenizer.convert_ids_to_tokens(token_ids)
exclude_list = ['[PAD]', '[CLS]']
......
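The new decode() complements the pre-existing decode_token_ids(), which filters out the tokens in its exclude_list ('[PAD]', '[CLS]'); decode() instead runs the full id -> token -> string detokenization with space cleanup. A toy round trip with a made-up vocabulary (the ids and vocabulary below are illustrative, not the real BERT vocab):

# Hypothetical round trip mirroring _BertWordPieceTokenizer.decode().
inv_vocab = {5: 'dense', 6: 'pass', 7: '##age', 8: 'retrieval', 9: '.'}

def decode(ids):
    tokens = [inv_vocab[i] for i in ids]
    return ' '.join(tokens).replace(' ##', '').replace(' .', '.').strip()

print(decode([5, 6, 7, 8, 9]))  # -> "dense passage retrieval."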
......@@ -80,7 +80,8 @@ def _cross_entropy_forward_step(batch, model):
return output_tensor, partial(cross_entropy_loss_func, labels)
def build_data_loader(dataset, micro_batch_size, num_workers, drop_last):
def build_data_loader(dataset, micro_batch_size, num_workers, drop_last,
task_collate_fn=None):
"""Data loader. Note that batch-size is the local (per GPU) batch-size."""
# Sampler.
......@@ -89,6 +90,8 @@ def build_data_loader(dataset, micro_batch_size, num_workers, drop_last):
sampler = torch.utils.data.distributed.DistributedSampler(
dataset, num_replicas=world_size, rank=rank)
print_rank_0(' > number of samples per rank in the sampler: {}'.format(
len(sampler)))
# Data loader. Note that batch size is the per GPU batch size.
data_loader = torch.utils.data.DataLoader(dataset,
batch_size=micro_batch_size,
......@@ -96,7 +99,8 @@ def build_data_loader(dataset, micro_batch_size, num_workers, drop_last):
shuffle=False,
num_workers=num_workers,
drop_last=drop_last,
pin_memory=True)
pin_memory=True,
collate_fn=task_collate_fn)
return data_loader
......@@ -112,21 +116,23 @@ def _build_infinite_size_dataloader(dataloader):
iterator = dataloader.__iter__()
def _build_train_valid_dataloaders(train_dataset, valid_dataset):
def _build_train_valid_dataloaders(train_dataset, valid_dataset, task_collate_fn=None):
"""Traing and validation dataloaders."""
args = get_args()
print_rank_0('building train and validation dataloaders ...')
# Training dataset.
train_dataloader = build_data_loader(train_dataset, args.micro_batch_size,
args.num_workers, not args.keep_last)
args.num_workers, not args.keep_last,
task_collate_fn)
# Set the training iterations.
args.train_iters_per_epoch = len(train_dataloader)
args.train_iters = args.epochs * args.train_iters_per_epoch
# Validation dataset. For this dataset, we do not need to set up
# shuffling so we can just use a simple infinite loop.
valid_dataloader_ = build_data_loader(valid_dataset, args.micro_batch_size,
args.num_workers, not args.keep_last)
args.num_workers, not args.keep_last,
task_collate_fn)
valid_dataloader = _build_infinite_size_dataloader(valid_dataloader_)
# Now that we've built the data loaders, set batch_size arguments
......@@ -185,9 +191,10 @@ def _train(model, optimizer, lr_scheduler, forward_step,
continue
# Set to zero so the next epoch does not skip any batches.
start_iteration = 0
# Train for one step.
out = train_step(forward_step, batch, model, optimizer, lr_scheduler)
losses_dict, skipped_iter, grad_norm, num_zeros_in_grad = out
iteration += 1
......@@ -220,6 +227,10 @@ def _train(model, optimizer, lr_scheduler, forward_step,
valid_dataloader, model,
iteration, False)
# Checkpointing at the end of each epoch.
if args.save:
save_checkpoint(iteration, model, optimizer, lr_scheduler)
......@@ -231,7 +242,8 @@ def _train(model, optimizer, lr_scheduler, forward_step,
def finetune(train_valid_datasets_provider, model_provider,
forward_step=_cross_entropy_forward_step,
end_of_epoch_callback_provider=None):
end_of_epoch_callback_provider=None,
task_collate_fn=None):
"""Main finetune function used across all tasks."""
args = get_args()
timers = get_timers()
......@@ -244,7 +256,7 @@ def finetune(train_valid_datasets_provider, model_provider,
if args.epochs > 0:
train_dataset, valid_dataset = train_valid_datasets_provider()
train_dataloader, valid_dataloader = _build_train_valid_dataloaders(
train_dataset, valid_dataset)
train_dataset, valid_dataset, task_collate_fn)
else:
args.train_iters = 0
timers('train/valid/test dataset/dataloder').stop()
......@@ -256,8 +268,6 @@ def finetune(train_valid_datasets_provider, model_provider,
end_of_epoch_callback = end_of_epoch_callback_provider()
timers('callback function').stop()
# Build model, optimizer and learning rate scheduler.
timers('model and optimizer').start()
model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider)
......
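The task_collate_fn argument added here is passed straight through finetune() and _build_train_valid_dataloaders() to the standard collate_fn parameter of torch.utils.data.DataLoader, letting a task control how its samples are batched (for example, padding variable-length query/context sequences for DPR). Below is a minimal sketch of such a collate function; the 'tokens'/'mask' field names and the pad id are assumptions for illustration, not the actual DPR dataset schema in this commit.

import torch
from torch.utils.data import DataLoader

def task_collate_fn(samples):
    """Pad variable-length token id lists in a batch to a common length."""
    pad_id = 0  # assumed pad id for this toy example
    max_len = max(len(s['tokens']) for s in samples)
    tokens, mask = [], []
    for s in samples:
        pad = max_len - len(s['tokens'])
        tokens.append(s['tokens'] + [pad_id] * pad)
        mask.append([1] * len(s['tokens']) + [0] * pad)
    return {'tokens': torch.tensor(tokens, dtype=torch.long),
            'mask': torch.tensor(mask, dtype=torch.long)}

# Toy dataset with ragged sequences; the DataLoader receives the collate_fn
# exactly as build_data_loader() now does via its task_collate_fn argument.
dataset = [{'tokens': [101, 7, 8, 102]}, {'tokens': [101, 9, 102]}]
loader = DataLoader(dataset, batch_size=2, collate_fn=task_collate_fn)
batch = next(iter(loader))
print(batch['tokens'].shape)  # torch.Size([2, 4])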