Commit 63121a9e authored by Mostofa Patwary's avatar Mostofa Patwary
Browse files

updating the scripts

parent d078e54a
......@@ -32,5 +32,8 @@ python tasks/main.py \
--num-workers 2 \
--faiss-use-gpu \
--retriever-report-topk-accuracies 1 5 20 100 \
--fp16
--fp16 \
--indexer-log-interval 1000 \
--indexer-batch-size 128
#!/bin/bash
# Finetune a BERT or pretrained ICT model using Google natural question data
# Datasets can be downloaded from the following link:
# https://github.com/facebookresearch/DPR/blob/master/data/download_data.py
WORLD_SIZE=8
DISTRIBUTED_ARGS="--nproc_per_node $WORLD_SIZE \
--nnodes 1 \
--node_rank 0 \
--master_addr localhost \
--master_port 6000"
CHECKPOINT_PATH=<Specify path for the finetuned retriever model>
# Load either of the below
BERT_LOAD_PATH=<Path of BERT pretrained model>
PRETRAINED_CHECKPOINT=<Path of Pretrained ICT model>
python -m torch.distributed.launch $DISTRIBUTED_ARGS ./tasks/main.py \
--task RET-FINETUNE-NQ \
--train-with-neg \
--train-hard-neg 1 \
--pretrained-checkpoint ${PRETRAINED_CHECKPOINT} \
--num-layers 12 \
--hidden-size 768 \
--num-attention-heads 12 \
--tensor-model-parallel-size 1 \
--tokenizer-type BertWordPieceLowerCase \
--train-data nq-train.json \
--valid-data nq-dev.json \
--save ${CHECKPOINT_PATH} \
--load ${CHECKPOINT_PATH} \
--vocab-file bert-vocab.txt \
--bert-load ${BERT_LOAD_PATH} \
--save-interval 5000 \
--log-interval 10 \
--eval-interval 25000 \
--eval-iters 100 \
--indexer-log-interval 1000 \
--faiss-use-gpu \
--DDP-impl torch \
--fp16 \
--retriever-report-topk-accuracies 1 5 10 20 100 \
--seq-length 512 \
--retriever-seq-length 256 \
--max-position-embeddings 512 \
--retriever-score-scaling \
--epochs 80 \
--micro-batch-size 8 \
--eval-micro-batch-size 16 \
--indexer-batch-size 128 \
--lr 2e-5 \
--lr-warmup-fraction 0.01 \
--weight-decay 1e-1
......@@ -320,7 +320,7 @@ class PretrainedBertModel(MegatronModule):
def load_state_dict(self, state_dict, strict=True):
"""Customized load."""
print_rank_0("loading BERT weights")
print_rank_0("loading pretrained weights")
self.language_model.load_state_dict(
state_dict[self._language_model_key], strict=strict)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment