Commit 06076c7a authored by Mostofa Patwary

Implementation of DPR (Dense Passage Retrieval): add pre_process/post_process support to the biencoder model path, a tokenizer decode method, and an optional task-specific collate function for finetuning.

parent cdde4338
...@@ -17,7 +17,9 @@ from .module import MegatronModule ...@@ -17,7 +17,9 @@ from .module import MegatronModule
def biencoder_model_provider(only_query_model=False, def biencoder_model_provider(only_query_model=False,
only_context_model=False, only_context_model=False,
biencoder_shared_query_context_model=False): biencoder_shared_query_context_model=False,
pre_process=True,
post_process=True):
"""Build the model.""" """Build the model."""
args = get_args() args = get_args()
...@@ -35,7 +37,9 @@ def biencoder_model_provider(only_query_model=False, ...@@ -35,7 +37,9 @@ def biencoder_model_provider(only_query_model=False,
only_query_model=only_query_model, only_query_model=only_query_model,
only_context_model=only_context_model, only_context_model=only_context_model,
biencoder_shared_query_context_model=\ biencoder_shared_query_context_model=\
biencoder_shared_query_context_model) biencoder_shared_query_context_model,
pre_process=pre_process,
post_process=post_process)
return model return model
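A minimal caller sketch (not part of this commit) showing how the new pre_process/post_process flags are typically driven by the pipeline stage when the provider is handed to setup_model_and_optimizer; the wrapper name is illustrative and the shared-model flag is assumed to already exist on args:

def model_provider(pre_process=True, post_process=True):
    # Hypothetical wrapper: the first pipeline stage keeps the embeddings
    # (pre_process=True) and the last stage keeps the output head
    # (post_process=True).
    args = get_args()
    return biencoder_model_provider(
        only_context_model=False,
        only_query_model=False,
        biencoder_shared_query_context_model=args.biencoder_shared_query_context_model,
        pre_process=pre_process,
        post_process=post_process)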
...@@ -48,13 +52,17 @@ class BiEncoderModel(MegatronModule): ...@@ -48,13 +52,17 @@ class BiEncoderModel(MegatronModule):
parallel_output=True, parallel_output=True,
only_query_model=False, only_query_model=False,
only_context_model=False, only_context_model=False,
biencoder_shared_query_context_model=False): biencoder_shared_query_context_model=False,
pre_process=True,
post_process=True):
super(BiEncoderModel, self).__init__() super(BiEncoderModel, self).__init__()
args = get_args() args = get_args()
bert_kwargs = dict( bert_kwargs = dict(
num_tokentypes=num_tokentypes, num_tokentypes=num_tokentypes,
parallel_output=parallel_output) parallel_output=parallel_output,
pre_process=pre_process,
post_process=post_process)
self.biencoder_shared_query_context_model = \ self.biencoder_shared_query_context_model = \
biencoder_shared_query_context_model biencoder_shared_query_context_model
...@@ -78,6 +86,19 @@ class BiEncoderModel(MegatronModule): ...@@ -78,6 +86,19 @@ class BiEncoderModel(MegatronModule):
self.context_model = PretrainedBertModel(**bert_kwargs) self.context_model = PretrainedBertModel(**bert_kwargs)
self._context_key = 'context_model' self._context_key = 'context_model'
def set_input_tensor(self, input_tensor):
    """See megatron.model.transformer.set_input_tensor()"""
    # Pipeline parallelism is not wired up for the biencoder yet, so the
    # tensor handed over from the previous pipeline stage is ignored here.
    return
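# Hedged sketch only (not part of this commit): once pipeline parallelism is
# supported for the biencoder, the hook above would forward the previous
# stage's activations into the wrapped encoders. The helper name below is
# hypothetical and assumes PretrainedBertModel itself exposes set_input_tensor.
def _set_input_tensor_pipelined(self, input_tensor):
    if hasattr(self, 'query_model'):
        self.query_model.set_input_tensor(input_tensor)
    if hasattr(self, 'context_model'):
        self.context_model.set_input_tensor(input_tensor)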
def forward(self, query_tokens, query_attention_mask, query_types, def forward(self, query_tokens, query_attention_mask, query_types,
context_tokens, context_attention_mask, context_types): context_tokens, context_attention_mask, context_types):
"""Run a forward pass for each of the models and """Run a forward pass for each of the models and
...@@ -217,7 +238,7 @@ class PretrainedBertModel(MegatronModule): ...@@ -217,7 +238,7 @@ class PretrainedBertModel(MegatronModule):
learned information retrieval.""" learned information retrieval."""
def __init__(self, num_tokentypes=2, def __init__(self, num_tokentypes=2,
parallel_output=True): parallel_output=True, pre_process=True, post_process=True):
super(PretrainedBertModel, self).__init__() super(PretrainedBertModel, self).__init__()
args = get_args() args = get_args()
...@@ -225,6 +246,8 @@ class PretrainedBertModel(MegatronModule): ...@@ -225,6 +246,8 @@ class PretrainedBertModel(MegatronModule):
self.pad_id = tokenizer.pad self.pad_id = tokenizer.pad
self.biencoder_projection_dim = args.biencoder_projection_dim self.biencoder_projection_dim = args.biencoder_projection_dim
self.parallel_output = parallel_output self.parallel_output = parallel_output
self.pre_process = pre_process
self.post_process = post_process
init_method = init_method_normal(args.init_method_std) init_method = init_method_normal(args.init_method_std)
scaled_init_method = scaled_init_method_normal( scaled_init_method = scaled_init_method_normal(
args.init_method_std, args.num_layers) args.init_method_std, args.num_layers)
...@@ -234,7 +257,9 @@ class PretrainedBertModel(MegatronModule): ...@@ -234,7 +257,9 @@ class PretrainedBertModel(MegatronModule):
add_pooler=False, add_pooler=False,
encoder_attn_mask_type=AttnMaskType.padding, encoder_attn_mask_type=AttnMaskType.padding,
init_method=init_method, init_method=init_method,
scaled_init_method=scaled_init_method) scaled_init_method=scaled_init_method,
pre_process=self.pre_process,
post_process=self.post_process)
if args.biencoder_projection_dim > 0: if args.biencoder_projection_dim > 0:
self.projection_enc = get_linear_layer(args.hidden_size, self.projection_enc = get_linear_layer(args.hidden_size,
......
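The projection head above is only built when args.biencoder_projection_dim > 0. A hedged sketch of where it would typically be applied (the encoder's forward pass is not part of this diff, so the helper below is purely illustrative):

def apply_projection(self, embedding):
    # Hypothetical helper: shrink the per-sequence embedding to the retrieval
    # dimension when a projection head was configured, otherwise pass through.
    if self.biencoder_projection_dim > 0:
        embedding = self.projection_enc(embedding)
    return embedding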
...@@ -181,6 +181,35 @@ class FullTokenizer(object): ...@@ -181,6 +181,35 @@ class FullTokenizer(object):
def convert_ids_to_tokens(self, ids): def convert_ids_to_tokens(self, ids):
return convert_by_vocab(self.inv_vocab, ids) return convert_by_vocab(self.inv_vocab, ids)
@staticmethod
def convert_tokens_to_string(tokens, clean_up_tokenization_spaces=True):
""" Converts a sequence of tokens (string) in a single string. """
def clean_up_tokenization(out_string):
""" Clean up a list of simple English tokenization artifacts
like spaces before punctuation and abbreviated forms.
"""
out_string = (
out_string.replace(" .", ".")
.replace(" ?", "?")
.replace(" !", "!")
.replace(" ,", ",")
.replace(" ' ", "'")
.replace(" n't", "n't")
.replace(" 'm", "'m")
.replace(" 's", "'s")
.replace(" 've", "'ve")
.replace(" 're", "'re")
)
return out_string
text = ' '.join(tokens).replace(' ##', '').strip()
if clean_up_tokenization_spaces:
clean_text = clean_up_tokenization(text)
return clean_text
else:
return text
def vocab_size(self): def vocab_size(self):
return len(self.vocab) return len(self.vocab)
......
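A quick usage sketch for the new detokenizer; convert_tokens_to_string is a staticmethod, so it can be tried without loading a vocabulary file:

pieces = ['what', 'is', 'dp', '##r', '?']
print(FullTokenizer.convert_tokens_to_string(pieces))
# -> "what is dpr?"  (wordpiece continuations merged, space before "?" removed)
print(FullTokenizer.convert_tokens_to_string(pieces, clean_up_tokenization_spaces=False))
# -> "what is dpr ?"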
...@@ -155,6 +155,10 @@ class _BertWordPieceTokenizer(AbstractTokenizer): ...@@ -155,6 +155,10 @@ class _BertWordPieceTokenizer(AbstractTokenizer):
text_tokens = self.tokenizer.tokenize(text) text_tokens = self.tokenizer.tokenize(text)
return self.tokenizer.convert_tokens_to_ids(text_tokens) return self.tokenizer.convert_tokens_to_ids(text_tokens)
def decode(self, ids):
tokens = self.tokenizer.convert_ids_to_tokens(ids)
return self.tokenizer.convert_tokens_to_string(tokens)
def decode_token_ids(self, token_ids): def decode_token_ids(self, token_ids):
tokens = self.tokenizer.convert_ids_to_tokens(token_ids) tokens = self.tokenizer.convert_ids_to_tokens(token_ids)
exclude_list = ['[PAD]', '[CLS]'] exclude_list = ['[PAD]', '[CLS]']
......
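A hedged round-trip sketch for the new decode method; the tokenizer construction (vocabulary file path) is an assumption, not taken from this diff:

# Hypothetical setup; the vocab path is a placeholder.
tokenizer = _BertWordPieceTokenizer(vocab_file='bert-vocab.txt')
ids = tokenizer.tokenize('What is dense passage retrieval?')  # text -> token ids
text = tokenizer.decode(ids)                                  # ids -> cleaned-up string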
...@@ -80,7 +80,8 @@ def _cross_entropy_forward_step(batch, model): ...@@ -80,7 +80,8 @@ def _cross_entropy_forward_step(batch, model):
return output_tensor, partial(cross_entropy_loss_func, labels) return output_tensor, partial(cross_entropy_loss_func, labels)
def build_data_loader(dataset, micro_batch_size, num_workers, drop_last): def build_data_loader(dataset, micro_batch_size, num_workers, drop_last,
task_collate_fn=None):
"""Data loader. Note that batch-size is the local (per GPU) batch-size.""" """Data loader. Note that batch-size is the local (per GPU) batch-size."""
# Sampler. # Sampler.
...@@ -89,6 +90,8 @@ def build_data_loader(dataset, micro_batch_size, num_workers, drop_last): ...@@ -89,6 +90,8 @@ def build_data_loader(dataset, micro_batch_size, num_workers, drop_last):
sampler = torch.utils.data.distributed.DistributedSampler( sampler = torch.utils.data.distributed.DistributedSampler(
dataset, num_replicas=world_size, rank=rank) dataset, num_replicas=world_size, rank=rank)
print_rank_0(len(sampler))
# Data loader. Note that batch size is the per GPU batch size. # Data loader. Note that batch size is the per GPU batch size.
data_loader = torch.utils.data.DataLoader(dataset, data_loader = torch.utils.data.DataLoader(dataset,
batch_size=micro_batch_size, batch_size=micro_batch_size,
...@@ -96,7 +99,8 @@ def build_data_loader(dataset, micro_batch_size, num_workers, drop_last): ...@@ -96,7 +99,8 @@ def build_data_loader(dataset, micro_batch_size, num_workers, drop_last):
shuffle=False, shuffle=False,
num_workers=num_workers, num_workers=num_workers,
drop_last=drop_last, drop_last=drop_last,
pin_memory=True) pin_memory=True,
collate_fn=task_collate_fn)
return data_loader return data_loader
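A hedged sketch of what a task_collate_fn might look like; the per-sample field names ('text', 'uid') and the padding scheme are illustrative only, since the actual DPR collate function is not shown in this commit:

import torch

def example_task_collate_fn(batch, pad_id=0):
    # Pad variable-length token id lists to the longest sample in the batch
    # and stack everything into tensors.
    max_len = max(len(sample['text']) for sample in batch)
    padded = [sample['text'] + [pad_id] * (max_len - len(sample['text']))
              for sample in batch]
    return {'text': torch.LongTensor(padded),
            'uid': torch.LongTensor([sample['uid'] for sample in batch])}

Such a function is passed straight through the new collate_fn=task_collate_fn argument of the DataLoader.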
...@@ -112,21 +116,23 @@ def _build_infinite_size_dataloader(dataloader): ...@@ -112,21 +116,23 @@ def _build_infinite_size_dataloader(dataloader):
iterator = dataloader.__iter__() iterator = dataloader.__iter__()
def _build_train_valid_dataloaders(train_dataset, valid_dataset): def _build_train_valid_dataloaders(train_dataset, valid_dataset, task_collate_fn=None):
"""Traing and validation dataloaders.""" """Traing and validation dataloaders."""
args = get_args() args = get_args()
print_rank_0('building train and validation dataloaders ...') print_rank_0('building train and validation dataloaders ...')
# Training dataset. # Training dataset.
train_dataloader = build_data_loader(train_dataset, args.micro_batch_size, train_dataloader = build_data_loader(train_dataset, args.micro_batch_size,
args.num_workers, not args.keep_last) args.num_workers, not args.keep_last,
task_collate_fn)
# Set the training iterations. # Set the training iterations.
args.train_iters_per_epoch = len(train_dataloader) args.train_iters_per_epoch = len(train_dataloader)
args.train_iters = args.epochs * args.train_iters_per_epoch args.train_iters = args.epochs * args.train_iters_per_epoch
# Validation dataset. For this dataset, we do not need to set up # Validation dataset. For this dataset, we do not need to set up
# shuffling so we can just use a simple infinite loop. # shuffling so we can just use a simple infinite loop.
valid_dataloader_ = build_data_loader(valid_dataset, args.micro_batch_size, valid_dataloader_ = build_data_loader(valid_dataset, args.micro_batch_size,
args.num_workers, not args.keep_last) args.num_workers, not args.keep_last,
task_collate_fn)
valid_dataloader = _build_infinite_size_dataloader(valid_dataloader_) valid_dataloader = _build_infinite_size_dataloader(valid_dataloader_)
# Now that we've built the data loaders, set batch_size arguments # Now that we've built the data loaders, set batch_size arguments
...@@ -185,9 +191,10 @@ def _train(model, optimizer, lr_scheduler, forward_step, ...@@ -185,9 +191,10 @@ def _train(model, optimizer, lr_scheduler, forward_step,
continue continue
# Set to zero so the next epoch does not skip any batches. # Set to zero so the next epoch does not skip any batches.
start_iteration = 0 start_iteration = 0
# Train for one step. # Train for one step.
out = train_step(forward_step, batch, model, optimizer, lr_scheduler) out = train_step(forward_step, batch, model, optimizer, lr_scheduler)
losses_dict, skipped_iter, grad_norm, num_zeros_in_grad = out losses_dict, skipped_iter, grad_norm, num_zeros_in_grad = out
iteration += 1 iteration += 1
...@@ -220,6 +227,10 @@ def _train(model, optimizer, lr_scheduler, forward_step, ...@@ -220,6 +227,10 @@ def _train(model, optimizer, lr_scheduler, forward_step,
valid_dataloader, model, valid_dataloader, model,
iteration, False) iteration, False)
# Checkpointing at the end of each epoch. # Checkpointing at the end of each epoch.
if args.save: if args.save:
save_checkpoint(iteration, model, optimizer, lr_scheduler) save_checkpoint(iteration, model, optimizer, lr_scheduler)
...@@ -231,7 +242,8 @@ def _train(model, optimizer, lr_scheduler, forward_step, ...@@ -231,7 +242,8 @@ def _train(model, optimizer, lr_scheduler, forward_step,
def finetune(train_valid_datasets_provider, model_provider, def finetune(train_valid_datasets_provider, model_provider,
forward_step=_cross_entropy_forward_step, forward_step=_cross_entropy_forward_step,
end_of_epoch_callback_provider=None): end_of_epoch_callback_provider=None,
task_collate_fn=None):
"""Main finetune function used across all tasks.""" """Main finetune function used across all tasks."""
args = get_args() args = get_args()
timers = get_timers() timers = get_timers()
...@@ -244,7 +256,7 @@ def finetune(train_valid_datasets_provider, model_provider, ...@@ -244,7 +256,7 @@ def finetune(train_valid_datasets_provider, model_provider,
if args.epochs > 0: if args.epochs > 0:
train_dataset, valid_dataset = train_valid_datasets_provider() train_dataset, valid_dataset = train_valid_datasets_provider()
train_dataloader, valid_dataloader = _build_train_valid_dataloaders( train_dataloader, valid_dataloader = _build_train_valid_dataloaders(
train_dataset, valid_dataset) train_dataset, valid_dataset, task_collate_fn)
else: else:
args.train_iters = 0 args.train_iters = 0
timers('train/valid/test dataset/dataloder').stop() timers('train/valid/test dataset/dataloder').stop()
...@@ -256,8 +268,6 @@ def finetune(train_valid_datasets_provider, model_provider, ...@@ -256,8 +268,6 @@ def finetune(train_valid_datasets_provider, model_provider,
end_of_epoch_callback = end_of_epoch_callback_provider() end_of_epoch_callback = end_of_epoch_callback_provider()
timers('callback function').stop() timers('callback function').stop()
exit()
# Build model, optimizer and learning rate scheduler. # Build model, optimizer and learning rate scheduler.
timers('model and optimizer').start() timers('model and optimizer').start()
model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider) model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider)
......
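Tying the new hook together, a hedged sketch of how a DPR-style task could invoke finetune; the dataset provider, model provider and callback names are illustrative, and example_task_collate_fn refers to the sketch above:

def main():
    # Hypothetical task entry point; finetune() threads task_collate_fn through
    # to both the train and the validation data loaders.
    finetune(train_valid_datasets_provider,
             model_provider,
             forward_step=_cross_entropy_forward_step,
             end_of_epoch_callback_provider=metrics_provider,
             task_collate_fn=example_task_collate_fn)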