Commit 7577931b authored by Mostofa Patwary

Fixed issues with ICT pretraining

parent 8e44d619
#!/bin/bash
# Compute embeddings for each entry of a given dataset (e.g. Wikipedia)
RANK=0
WORLD_SIZE=1
# Wikipedia data can be downloaded from the following link:
# https://github.com/facebookresearch/DPR/blob/master/data/download_data.py
EVIDENCE_DATA_DIR=<Specify path of Wikipedia dataset>
EMBEDDING_PATH=<Specify path to store embeddings>
CHECKPOINT_PATH=<Specify path of pretrained ICT model>
python tools/create_doc_index.py \
    --num-layers 12 \
    --hidden-size 768 \
    --num-attention-heads 12 \
    --tensor-model-parallel-size 1 \
    --micro-batch-size 128 \
    --checkpoint-activations \
    --seq-length 512 \
    --retriever-seq-length 256 \
    --max-position-embeddings 512 \
    --load ${CHECKPOINT_PATH} \
    --evidence-data-path ${EVIDENCE_DATA_DIR} \
    --embedding-path ${EMBEDDING_PATH} \
    --indexer-log-interval 1000 \
    --indexer-batch-size 128 \
    --vocab-file bert-vocab.txt \
    --num-workers 2 \
    --fp16
@@ -14,6 +14,8 @@
# limitations under the License.

"""Pretrain BERT for Inverse Cloze Task"""

from functools import partial
import math
import torch
@@ -31,14 +33,15 @@ from megatron.training import pretrain
from megatron.utils import average_losses_across_data_parallel_group


def pretrain_ict_model_provider(pre_process=True, post_process=True):
    args = get_args()

    model = biencoder_model_provider(
                only_context_model=False,
                only_query_model=False,
                biencoder_shared_query_context_model=\
                    args.biencoder_shared_query_context_model,
                pre_process=pre_process, post_process=post_process)
    return model
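The new pre_process/post_process arguments follow Megatron's model-provider convention for pipeline parallelism: the first pipeline stage builds the input-side layers and the last stage builds the output-side layers. A minimal sketch of that convention, assuming the usual parallel-state helpers (the import path and call sites below are assumptions, not part of this commit):

# Sketch only: each pipeline stage asks the provider just for the pieces it owns.
from megatron import mpu  # assumed location of the parallel-state helpers

model = pretrain_ict_model_provider(
    pre_process=mpu.is_pipeline_first_stage(),
    post_process=mpu.is_pipeline_last_stage())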
@@ -79,25 +82,9 @@ class AllgatherFromDataParallelRegion(torch.autograd.Function):
        output = output_list[rank].contiguous()
        return output

def loss_func(output_tensor):
    args = get_args()
    query_logits, context_logits = output_tensor

    micro_batch_size = query_logits.shape[0]
    # recall we assert that tensor_model_parallel_size == 1

@@ -139,6 +126,28 @@ def forward_step(data_iterator, model, input_tensor):
    return loss, stats_dict
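The body of loss_func between these two hunks is elided by the diff. For orientation only, here is a minimal sketch of the in-batch-negative retrieval loss a biencoder like this typically computes; the function name and details below are assumptions, not the elided Megatron code:

import torch
import torch.nn.functional as F

def sketch_retrieval_loss(query_logits, context_logits):
    # [batch, batch] similarity between every query and every context in the micro-batch.
    retrieval_scores = torch.matmul(query_logits, context_logits.t()).float()
    # The i-th query's positive context is the i-th context in the batch.
    labels = torch.arange(retrieval_scores.shape[0], device=retrieval_scores.device)
    return F.cross_entropy(retrieval_scores, labels)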

def forward_step(data_iterator, model):
    """Forward step."""
    args = get_args()
    timers = get_timers()

    # Get the batch.
    timers('batch-generator').start()
    query_tokens, query_mask, \
    context_tokens, context_mask, context_indices = get_ict_batch(data_iterator)
    timers('batch-generator').stop()

    # Query and Context Types
    query_types = torch.cuda.LongTensor(*query_tokens.shape).fill_(0)
    context_types = torch.cuda.LongTensor(*context_tokens.shape).fill_(0)

    # Forward model.
    output_tensor = model(query_tokens, query_mask, query_types, context_tokens,
                          context_mask, context_types)

    return output_tensor, partial(loss_func)
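With this change forward_step no longer computes the loss itself: it returns the model output together with loss_func, and the training loop applies the latter to the former. A minimal sketch of that calling pattern (the trainer-side variable names are assumptions):

# Sketch only: how a caller consumes the (output, loss function) pair.
output_tensor, loss_fn = forward_step(data_iterator, model)
loss, stats_dict = loss_fn(output_tensor)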

def train_valid_test_datasets_provider(train_val_test_num_samples):
    """Build train, valid and test datasets."""
    args = get_args()
...
import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
                                              os.path.pardir)))

from megatron import print_rank_0
from megatron.indexer import IndexBuilder
from megatron.initialize import initialize_megatron


def main():
    """Create a BlockData data structure by running an IndexBuilder over an ICT Dataset

    - Include all args needed for initial model specification

    Other key args:
        --block-data-path: path to write to
        --ict-load or --realm-load: path to checkpoint with which to embed
        --data-path and --titles-data-path: paths for dataset
        --indexer-log-interval: reporting interval
        --indexer-batch-size: size specific for indexer jobs

    Check README.md for example script
    """
    initialize_megatron(extra_args_provider=None,
                        args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
    index_builder = IndexBuilder()
    index_builder.build_and_save_index()
    print_rank_0("Build and save indices: done!")


if __name__ == "__main__":
    main()