Commit ca0cdfaa authored by Neel Kant

Merge ict-merge into indexer-merge

parents d9d4ce70 a5bfc013
...@@ -4,13 +4,3 @@ __pycache__ ...@@ -4,13 +4,3 @@ __pycache__
build/ build/
dist/ dist/
*.egg-info/ *.egg-info/
# added by neel
*.npy
*.bin
*.idx
*.pkl
raw_*
run_*
realm_*
...@@ -272,15 +272,15 @@ Loosely, they are pretraining the retriever modules, then jointly training the l ...@@ -272,15 +272,15 @@ Loosely, they are pretraining the retriever modules, then jointly training the l
### Inverse Cloze Task (ICT) Pretraining ### Inverse Cloze Task (ICT) Pretraining
1. Have a corpus in loose JSON format with the intention of creating a collection of fixed-size blocks of text as the fundamental units of data. For a corpus like Wikipedia, this will mean multiple sentences per block but also multiple blocks per document. 1. Have a corpus in loose JSON format with the intention of creating a collection of fixed-size blocks of text as the fundamental units of data. For a corpus like Wikipedia, this will mean multiple sentences per block but also multiple blocks per document.
Run `tools/preprocess_data.py` to construct one or more indexed datasets with the `--split-sentences` argument to make sentences the basic unit. For the original REALM system, we construct two datasets, one with the title of every document, and another with the body. Run `tools/preprocess_data.py` to construct one or more indexed datasets with the `--split-sentences` argument to make sentences the basic unit. For the original REALM system, we construct two datasets, one with the title of every document, and another with the body.
Refer to the following script meant to be run in an interactive session on draco: Refer to the following script:
<pre> <pre>
python preprocess_data.py \ python preprocess_data.py \
--input /home/universal-lm-data.cosmos549/datasets/wikipedia/wikidump_lines.json \ --input /path/to/corpus.json \
--json-keys text title \ --json-keys text title \
--split-sentences \ --split-sentences \
--tokenizer-type BertWordPieceLowerCase \ --tokenizer-type BertWordPieceLowerCase \
--vocab-file /home/universal-lm-data.cosmos549/scratch/mshoeybi/data/albert/vocab.txt \ --vocab-file /path/to/vocab.txt \
--output-prefix wiki_indexed \ --output-prefix corpus_indexed \
--workers 5 # works well for 10 CPU cores. Scale up accordingly. --workers 5 # works well for 10 CPU cores. Scale up accordingly.
</pre> </pre>
...@@ -288,13 +288,10 @@ python preprocess_data.py \ ...@@ -288,13 +288,10 @@ python preprocess_data.py \
The samples mapping is responsible for holding all of the required metadata needed to construct the sample from one or more indexed datasets. In REALM, the samples mapping contains the start and end sentence indices, as well as the document index (to find the correct title for a body) and a unique ID for every block. The samples mapping is responsible for holding all of the required metadata needed to construct the sample from one or more indexed datasets. In REALM, the samples mapping contains the start and end sentence indices, as well as the document index (to find the correct title for a body) and a unique ID for every block.
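For concreteness, here is a hedged sketch of the metadata one samples-mapping entry carries; the field names mirror what `ICTDataset.__getitem__` unpacks later in this commit, but the class itself is illustrative and is not the repo's actual `BlockSamplesMapping` implementation (which wraps a numpy array).
<pre>
# Illustrative only: roughly the metadata held per samples-mapping entry
# (the real BlockSamplesMapping wraps a numpy array rather than dataclasses).
from dataclasses import dataclass

@dataclass
class BlockSampleRecord:
    start_idx: int   # index of the block's first sentence in the indexed dataset
    end_idx: int     # one past the block's last sentence
    doc_idx: int     # document index, used to look up the matching title
    block_idx: int   # unique ID for this block

    def as_tuple(self):
        return self.start_idx, self.end_idx, self.doc_idx, self.block_idx

# e.g. a block made of sentences 105..109 of document 17, stored with ID 42
record = BlockSampleRecord(start_idx=105, end_idx=110, doc_idx=17, block_idx=42)
</pre>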
3. Pretrain a BERT language model using `pretrain_bert.py`, with the sequence length equal to the block size in token ids. This model should be trained on the same indexed dataset that is used to supply the blocks for the information retrieval task. 3. Pretrain a BERT language model using `pretrain_bert.py`, with the sequence length equal to the block size in token ids. This model should be trained on the same indexed dataset that is used to supply the blocks for the information retrieval task.
In REALM, this is an uncased bert base model trained with the standard hyperparameters. In REALM, this is an uncased bert base model trained with the standard hyperparameters.
4. Use `pretrain_bert_ict.py` to train an `ICTBertModel` which uses two BERT-based encoders to encode queries and blocks to perform retrieval with. 4. Use `pretrain_ict.py` to train an `ICTBertModel` which uses two BERT-based encoders to encode queries and blocks to perform retrieval with.
The script below trains the ICT model from REALM on draco. It references a pretrained BERT model (step 3) in the `--bert-load` argument. The script below trains the ICT model from REALM. It references a pretrained BERT model (step 3) in the `--bert-load` argument. The batch size used in the paper is 4096, so this would need to be run with data parallel world size 32.
<pre> <pre>
EXPNAME="ict_wikipedia" python pretrain_ict.py \
CHKPT="chkpts/${EXPNAME}"
LOGDIR="logs/${EXPNAME}"
COMMAND="/home/scratch.gcf/adlr-utils/release/cluster-interface/latest/mp_launch python pretrain_bert_ict.py \
--num-layers 12 \ --num-layers 12 \
--num-attention-heads 12 \ --num-attention-heads 12 \
--hidden-size 768 \ --hidden-size 768 \
...@@ -304,13 +301,12 @@ COMMAND="/home/scratch.gcf/adlr-utils/release/cluster-interface/latest/mp_launch ...@@ -304,13 +301,12 @@ COMMAND="/home/scratch.gcf/adlr-utils/release/cluster-interface/latest/mp_launch
--ict-head-size 128 \ --ict-head-size 128 \
--train-iters 100000 \ --train-iters 100000 \
--checkpoint-activations \ --checkpoint-activations \
--bert-load /home/dcg-adlr-nkant-output.cosmos1203/chkpts/base_bert_seq256 \ --bert-load /path/to/pretrained_bert \
--load $CHKPT \ --load checkpoints \
--save $CHKPT \ --save checkpoints \
--data-path /home/dcg-adlr-nkant-data.cosmos1202/wiki/wikipedia_lines \ --data-path /path/to/indexed_dataset \
--titles-data-path /home/dcg-adlr-nkant-data.cosmos1202/wiki/wikipedia_lines-titles \ --titles-data-path /path/to/titles_indexed_dataset \
--vocab-file /home/universal-lm-data.cosmos549/scratch/mshoeybi/data/albert/vocab.txt \ --vocab-file /path/to/vocab.txt \
--distributed-backend nccl \
--lr 0.0001 \ --lr 0.0001 \
--num-workers 2 \ --num-workers 2 \
--lr-decay-style linear \ --lr-decay-style linear \
...@@ -319,11 +315,8 @@ COMMAND="/home/scratch.gcf/adlr-utils/release/cluster-interface/latest/mp_launch ...@@ -319,11 +315,8 @@ COMMAND="/home/scratch.gcf/adlr-utils/release/cluster-interface/latest/mp_launch
--warmup .01 \ --warmup .01 \
--save-interval 3000 \ --save-interval 3000 \
--query-in-block-prob 0.1 \ --query-in-block-prob 0.1 \
--fp16 \ --fp16
--adlr-autoresume \
--adlr-autoresume-interval 100"
submit_job --image 'http://gitlab-master.nvidia.com/adlr/megatron-lm/megatron:20.03_faiss' --mounts /home/universal-lm-data.cosmos549,/home/dcg-adlr-nkant-data.cosmos1202,/home/dcg-adlr-nkant-output.cosmos1203,/home/nkant --name "${EXPNAME}" --partition batch_32GB --gpu 8 --nodes 4 --autoresume_timer 420 -c "${COMMAND}" --logdir "${LOGDIR}"
</pre> </pre>
### Building an Index of Block Embeddings ### Building an Index of Block Embeddings
...@@ -331,9 +324,7 @@ After having trained an ICT model, you can now embed an entire dataset of blocks ...@@ -331,9 +324,7 @@ After having trained an ICT model, you can now embed an entire dataset of blocks
and wrap it with a `FaissMIPSIndex` to do fast similarity search, which is key to the learned information retrieval pipeline. The initial index can be built with the following script, meant to be run in an interactive session. It can leverage multiple GPUs on multiple nodes to index large datasets much more quickly. and wrap it with a `FaissMIPSIndex` to do fast similarity search, which is key to the learned information retrieval pipeline. The initial index can be built with the following script, meant to be run in an interactive session. It can leverage multiple GPUs on multiple nodes to index large datasets much more quickly.
<pre> <pre>
ICT_LOAD="chkpts/ict_wikipedia" python indexer.py \
BLOCK_DATA="block_data/wikipedia"
/home/scratch.gcf/adlr-utils/release/cluster-interface/latest/mp_launch python indexer.py \
--num-layers 12 \ --num-layers 12 \
--hidden-size 768 \ --hidden-size 768 \
--ict-head-size 128 \ --ict-head-size 128 \
...@@ -342,11 +333,11 @@ BLOCK_DATA="block_data/wikipedia" ...@@ -342,11 +333,11 @@ BLOCK_DATA="block_data/wikipedia"
--checkpoint-activations \ --checkpoint-activations \
--seq-length 256 \ --seq-length 256 \
--max-position-embeddings 256 \ --max-position-embeddings 256 \
--ict-load $ICT_LOAD \ --ict-load /path/to/pretrained_ict \
--data-path /home/dcg-adlr-nkant-data.cosmos1202/wiki/wikipedia_lines \ --data-path /path/to/indexed_dataset \
--titles-data-path /home/dcg-adlr-nkant-data.cosmos1202/wiki/wikipedia_lines \ --titles-data-path /path/to/titles_indexed_dataset \
--block-data-path $BLOCK_DATA \ --block-data-path embedded_blocks.pkl \
--vocab-file /home/universal-lm-data.cosmos549/scratch/mshoeybi/data/albert/vocab.txt \ --vocab-file /path/to/vocab.txt \
--num-workers 2 \ --num-workers 2 \
--fp16 --fp16
</pre> </pre>
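For orientation, here is a minimal sketch of the maximum inner product search that the `FaissMIPSIndex` wrapper performs, written against the faiss library directly; the array shapes and names are made up for illustration and this is not the repo's `FaissMIPSIndex` API.
<pre>
# Minimal MIPS sketch using faiss directly (not the repo's FaissMIPSIndex API).
# Block embeddings come from the ICT block encoder; queries from the query encoder.
import numpy as np
import faiss

embed_dim = 128                                    # matches --ict-head-size above
block_embeds = np.random.rand(10000, embed_dim).astype('float32')  # stand-in embeddings
query_embeds = np.random.rand(4, embed_dim).astype('float32')

index = faiss.IndexFlatIP(embed_dim)               # exact inner-product index
index.add(block_embeds)                            # index every block embedding
scores, block_ids = index.search(query_embeds, 5)  # top-5 blocks per query by inner product
</pre>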
......
...@@ -37,6 +37,7 @@ def parse_args(extra_args_provider=None, defaults={}, ...@@ -37,6 +37,7 @@ def parse_args(extra_args_provider=None, defaults={},
parser = _add_validation_args(parser) parser = _add_validation_args(parser)
parser = _add_data_args(parser) parser = _add_data_args(parser)
parser = _add_autoresume_args(parser) parser = _add_autoresume_args(parser)
parser = _add_realm_args(parser)
# Custom arguments. # Custom arguments.
if extra_args_provider is not None: if extra_args_provider is not None:
...@@ -139,8 +140,6 @@ def _add_network_size_args(parser): ...@@ -139,8 +140,6 @@ def _add_network_size_args(parser):
' grouped: [1, 2, 1, 2] and spaced: [1, 1, 2, 2].') ' grouped: [1, 2, 1, 2] and spaced: [1, 1, 2, 2].')
group.add_argument('--hidden-size', type=int, default=None, group.add_argument('--hidden-size', type=int, default=None,
help='Tansformer hidden size.') help='Tansformer hidden size.')
group.add_argument('--ict-head-size', type=int, default=None,
help='Size of block embeddings to be used in ICT and REALM (paper default: 128)')
group.add_argument('--num-attention-heads', type=int, default=None, group.add_argument('--num-attention-heads', type=int, default=None,
help='Number of transformer attention heads.') help='Number of transformer attention heads.')
group.add_argument('--max-position-embeddings', type=int, default=None, group.add_argument('--max-position-embeddings', type=int, default=None,
...@@ -264,10 +263,6 @@ def _add_checkpointing_args(parser): ...@@ -264,10 +263,6 @@ def _add_checkpointing_args(parser):
help='Do not save current rng state.') help='Do not save current rng state.')
group.add_argument('--load', type=str, default=None, group.add_argument('--load', type=str, default=None,
help='Directory containing a model checkpoint.') help='Directory containing a model checkpoint.')
group.add_argument('--ict-load', type=str, default=None,
help='Directory containing an ICTBertModel checkpoint')
group.add_argument('--bert-load', type=str, default=None,
help='Directory containing an BertModel checkpoint (needed to start ICT and REALM)')
group.add_argument('--no-load-optim', action='store_true', group.add_argument('--no-load-optim', action='store_true',
help='Do not load optimizer when loading checkpoint.') help='Do not load optimizer when loading checkpoint.')
group.add_argument('--no-load-rng', action='store_true', group.add_argument('--no-load-rng', action='store_true',
...@@ -347,10 +342,6 @@ def _add_data_args(parser): ...@@ -347,10 +342,6 @@ def _add_data_args(parser):
group.add_argument('--data-path', type=str, default=None, group.add_argument('--data-path', type=str, default=None,
help='Path to combined dataset to split.') help='Path to combined dataset to split.')
group.add_argument('--titles-data-path', type=str, default=None,
help='Path to titles dataset used for ICT')
group.add_argument('--block-data-path', type=str, default=None,
help='Path for loading and saving block data')
group.add_argument('--split', type=str, default='969, 30, 1', group.add_argument('--split', type=str, default='969, 30, 1',
help='Comma-separated list of proportions for training,' help='Comma-separated list of proportions for training,'
' validation, and test split. For example the split ' ' validation, and test split. For example the split '
...@@ -386,10 +377,6 @@ def _add_data_args(parser): ...@@ -386,10 +377,6 @@ def _add_data_args(parser):
'end-of-document token.') 'end-of-document token.')
group.add_argument('--eod-mask-loss', action='store_true', group.add_argument('--eod-mask-loss', action='store_true',
help='Mask loss for the end of document tokens.') help='Mask loss for the end of document tokens.')
group.add_argument('--query-in-block-prob', type=float, default=0.1,
help='Probability of keeping query in block for ICT dataset')
group.add_argument('--faiss-use-gpu', action='store_true',
help='Whether create the FaissMIPSIndex on GPU')
return parser return parser
...@@ -404,3 +391,37 @@ def _add_autoresume_args(parser): ...@@ -404,3 +391,37 @@ def _add_autoresume_args(parser):
'termination signal') 'termination signal')
return parser return parser
def _add_realm_args(parser):
group = parser.add_argument_group(title='realm')
# network size
group.add_argument('--ict-head-size', type=int, default=None,
help='Size of block embeddings to be used in ICT and REALM (paper default: 128)')
# checkpointing
group.add_argument('--ict-load', type=str, default=None,
help='Directory containing an ICTBertModel checkpoint')
group.add_argument('--bert-load', type=str, default=None,
help='Directory containing an BertModel checkpoint (needed to start ICT and REALM)')
# data
group.add_argument('--titles-data-path', type=str, default=None,
help='Path to titles dataset used for ICT')
group.add_argument('--query-in-block-prob', type=float, default=0.1,
help='Probability of keeping query in block for ICT dataset')
group.add_argument('--ict-one-sent', action='store_true',
help='Whether to use one sentence documents in ICT')
# training
group.add_argument('--report-topk-accuracies', nargs='+', default=[],
help="Which top-k accuracies to report (e.g. '1 5 20')")
# faiss index
group.add_argument('--faiss-use-gpu', action='store_true',
help='Whether create the FaissMIPSIndex on GPU')
group.add_argument('--block-data-path', type=str,
help='Where to save/load BlockData to/from')
return parser
...@@ -23,8 +23,9 @@ import numpy as np ...@@ -23,8 +23,9 @@ import numpy as np
import torch import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from megatron import mpu, print_rank_0 from megatron import mpu
from megatron import get_args from megatron import get_args
from megatron import print_rank_0
def check_checkpoint_args(checkpoint_args): def check_checkpoint_args(checkpoint_args):
......
...@@ -22,7 +22,8 @@ import numpy as np ...@@ -22,7 +22,8 @@ import numpy as np
import torch import torch
from torch.utils.data import Dataset from torch.utils.data import Dataset
from megatron import get_tokenizer, get_args, print_rank_0 from megatron import get_tokenizer, get_args
from megatron import print_rank_0
from megatron import mpu from megatron import mpu
from megatron.data.dataset_utils import get_a_and_b_segments from megatron.data.dataset_utils import get_a_and_b_segments
from megatron.data.dataset_utils import truncate_segments from megatron.data.dataset_utils import truncate_segments
......
...@@ -399,7 +399,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string, ...@@ -399,7 +399,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
def build_dataset(index, name): def build_dataset(index, name):
from megatron.data.bert_dataset import BertDataset from megatron.data.bert_dataset import BertDataset
from megatron.data.realm_dataset import ICTDataset from megatron.data.ict_dataset import ICTDataset
dataset = None dataset = None
if splits[index + 1] > splits[index]: if splits[index + 1] > splits[index]:
# Get the pointer to the original doc-idx so we can set it later. # Get the pointer to the original doc-idx so we can set it later.
...@@ -427,6 +427,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string, ...@@ -427,6 +427,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
block_dataset=indexed_dataset, block_dataset=indexed_dataset,
title_dataset=title_dataset, title_dataset=title_dataset,
query_in_block_prob=args.query_in_block_prob, query_in_block_prob=args.query_in_block_prob,
use_one_sent_docs=args.ict_one_sent,
**kwargs **kwargs
) )
else: else:
......
import itertools
import random
import numpy as np
from torch.utils.data import Dataset
from megatron import get_tokenizer
from megatron.data.realm_dataset_utils import get_block_samples_mapping
class ICTDataset(Dataset):
"""Dataset containing sentences and their blocks for an inverse cloze task."""
def __init__(self, name, block_dataset, title_dataset, data_prefix,
num_epochs, max_num_samples, max_seq_length, query_in_block_prob,
seed, use_titles=True, use_one_sent_docs=False):
self.name = name
self.seed = seed
self.max_seq_length = max_seq_length
self.query_in_block_prob = query_in_block_prob
self.block_dataset = block_dataset
self.title_dataset = title_dataset
self.rng = random.Random(self.seed)
self.use_titles = use_titles
self.use_one_sent_docs = use_one_sent_docs
self.samples_mapping = get_block_samples_mapping(
block_dataset, title_dataset, data_prefix, num_epochs,
max_num_samples, max_seq_length, seed, name, use_one_sent_docs)
self.tokenizer = get_tokenizer()
self.vocab_id_list = list(self.tokenizer.inv_vocab.keys())
self.vocab_id_to_token_list = self.tokenizer.inv_vocab
self.cls_id = self.tokenizer.cls
self.sep_id = self.tokenizer.sep
self.mask_id = self.tokenizer.mask
self.pad_id = self.tokenizer.pad
def __len__(self):
return self.samples_mapping.shape[0]
def __getitem__(self, idx):
"""Get an ICT example of a pseudo-query and the block of text from which it was extracted"""
sample_data = self.samples_mapping[idx]
start_idx, end_idx, doc_idx, block_idx = sample_data.as_tuple()
if self.use_titles:
title = self.title_dataset[int(doc_idx)]
title_pad_offset = 3 + len(title)
else:
title = None
title_pad_offset = 2
block = [self.block_dataset[i] for i in range(start_idx, end_idx)]
assert len(block) > 1 or self.use_one_sent_docs or self.query_in_block_prob == 1
# randint() is inclusive for Python rng
rand_sent_idx = self.rng.randint(0, len(block) - 1)
# keep the query in the context query_in_block_prob fraction of the time.
if self.rng.random() < self.query_in_block_prob:
query = block[rand_sent_idx].copy()
else:
query = block.pop(rand_sent_idx)
# still need to truncate because blocks are concluded when
# the sentence lengths have exceeded max_seq_length.
query = query[:self.max_seq_length - 2]
block = list(itertools.chain(*block))[:self.max_seq_length - title_pad_offset]
query_tokens, query_pad_mask = self.concat_and_pad_tokens(query)
block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)
block_data = sample_data.as_array()
sample = {
'query_tokens': query_tokens,
'query_pad_mask': query_pad_mask,
'block_tokens': block_tokens,
'block_pad_mask': block_pad_mask,
'block_data': block_data,
}
return sample
def get_block(self, start_idx, end_idx, doc_idx):
"""Get the IDs for an evidence block plus the title of the corresponding document"""
block = [self.block_dataset[i] for i in range(start_idx, end_idx)]
title = self.title_dataset[int(doc_idx)]
block = list(itertools.chain(*block))[:self.max_seq_length - (3 + len(title))]
block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)
return block_tokens, block_pad_mask
def get_null_block(self):
"""Get empty block and title - used in REALM pretraining"""
block, title = [], []
block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)
return block_tokens, block_pad_mask
def concat_and_pad_tokens(self, tokens, title=None):
"""Concat with special tokens and pad sequence to self.max_seq_length"""
tokens = list(tokens)
if title is None:
tokens = [self.cls_id] + tokens + [self.sep_id]
else:
title = list(title)
tokens = [self.cls_id] + title + [self.sep_id] + tokens + [self.sep_id]
assert len(tokens) <= self.max_seq_length
num_pad = self.max_seq_length - len(tokens)
pad_mask = [1] * len(tokens) + [0] * num_pad
tokens += [self.pad_id] * num_pad
return np.array(tokens), np.array(pad_mask)
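To make the sampling in `__getitem__` above concrete, here is a stripped-down sketch of the inverse cloze step: one sentence of a block becomes the pseudo-query and, with probability `query_in_block_prob`, it is left in the block rather than removed. This is an illustration only; it ignores the tokenization, titles, truncation, and padding handled by the real class.
<pre>
# Stripped-down illustration of the inverse cloze sampling in ICTDataset.__getitem__
# (ignores tokenization, titles, truncation and padding).
import random

def make_ict_example(block_sentences, query_in_block_prob, rng):
    """block_sentences: list of sentences, each a list of token ids."""
    rand_sent_idx = rng.randint(0, len(block_sentences) - 1)
    context = list(block_sentences)
    if rng.random() < query_in_block_prob:
        # keep the query sentence inside its block some fraction of the time
        query = list(context[rand_sent_idx])
    else:
        # usual case: remove the query so the model must match it to the rest of the block
        query = context.pop(rand_sent_idx)
    return query, context

rng = random.Random(1234)
query, context = make_ict_example([[10, 11], [12, 13, 14], [15]], query_in_block_prob=0.1, rng=rng)
</pre>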
...@@ -23,6 +23,7 @@ class ICTDataset(Dataset): ...@@ -23,6 +23,7 @@ class ICTDataset(Dataset):
self.short_seq_prob = short_seq_prob self.short_seq_prob = short_seq_prob
self.rng = random.Random(self.seed) self.rng = random.Random(self.seed)
self.use_titles = use_titles self.use_titles = use_titles
self.use_one_sent_docs = use_one_sent_docs
self.samples_mapping = get_block_samples_mapping( self.samples_mapping = get_block_samples_mapping(
block_dataset, title_dataset, data_prefix, num_epochs, block_dataset, title_dataset, data_prefix, num_epochs,
...@@ -50,7 +51,7 @@ class ICTDataset(Dataset): ...@@ -50,7 +51,7 @@ class ICTDataset(Dataset):
title = None title = None
title_pad_offset = 2 title_pad_offset = 2
block = [self.block_dataset[i] for i in range(start_idx, end_idx)] block = [self.block_dataset[i] for i in range(start_idx, end_idx)]
assert len(block) > 1 or self.query_in_block_prob == 1 assert len(block) > 1 or self.use_one_sent_docs or self.query_in_block_prob == 1
# randint() is inclusive for Python rng # randint() is inclusive for Python rng
rand_sent_idx = self.rng.randint(0, len(block) - 1) rand_sent_idx = self.rng.randint(0, len(block) - 1)
......
...@@ -138,7 +138,7 @@ def get_block_samples_mapping(block_dataset, title_dataset, data_prefix, num_epo ...@@ -138,7 +138,7 @@ def get_block_samples_mapping(block_dataset, title_dataset, data_prefix, num_epo
indexmap_filename)) indexmap_filename))
start_time = time.time() start_time = time.time()
mapping_array = np.load(indexmap_filename, allow_pickle=True) mapping_array = np.load(indexmap_filename, allow_pickle=True, mmap_mode='r')
samples_mapping = BlockSamplesMapping(mapping_array) samples_mapping = BlockSamplesMapping(mapping_array)
print_rank_0(' loaded indexed file in {:3.3f} seconds'.format( print_rank_0(' loaded indexed file in {:3.3f} seconds'.format(
......
...@@ -26,12 +26,46 @@ from megatron.model.utils import openai_gelu ...@@ -26,12 +26,46 @@ from megatron.model.utils import openai_gelu
from megatron.model.utils import get_linear_layer from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal from megatron.model.utils import init_method_normal
from megatron.model.utils import scaled_init_method_normal from megatron.model.utils import scaled_init_method_normal
from megatron.model.utils import bert_attention_mask_func
from megatron.model.utils import bert_extended_attention_mask
from megatron.model.utils import bert_position_ids
from megatron.module import MegatronModule from megatron.module import MegatronModule
def bert_attention_mask_func(attention_scores, attention_mask):
attention_scores = attention_scores + attention_mask
return attention_scores
def bert_extended_attention_mask(attention_mask, dtype):
# We create a 3D attention mask from a 2D tensor mask.
# [b, 1, s]
attention_mask_b1s = attention_mask.unsqueeze(1)
# [b, s, 1]
attention_mask_bs1 = attention_mask.unsqueeze(2)
# [b, s, s]
attention_mask_bss = attention_mask_b1s * attention_mask_bs1
# [b, 1, s, s]
extended_attention_mask = attention_mask_bss.unsqueeze(1)
# Since attention_mask is 1.0 for positions we want to attend and 0.0
# for masked positions, this operation will create a tensor which is
# 0.0 for positions we want to attend and -10000.0 for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
# fp16 compatibility
extended_attention_mask = extended_attention_mask.to(dtype=dtype)
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
return extended_attention_mask
def bert_position_ids(token_ids):
# Create position ids
seq_length = token_ids.size(1)
position_ids = torch.arange(seq_length, dtype=torch.long,
device=token_ids.device)
position_ids = position_ids.unsqueeze(0).expand_as(token_ids)
return position_ids
class BertLMHead(MegatronModule): class BertLMHead(MegatronModule):
"""Masked LM head for Bert """Masked LM head for Bert
...@@ -171,5 +205,3 @@ class BertModel(MegatronModule): ...@@ -171,5 +205,3 @@ class BertModel(MegatronModule):
if self.add_binary_head: if self.add_binary_head:
self.binary_head.load_state_dict( self.binary_head.load_state_dict(
state_dict[self._binary_head_key], strict=strict) state_dict[self._binary_head_key], strict=strict)
...@@ -18,9 +18,7 @@ ...@@ -18,9 +18,7 @@
import torch import torch
from megatron import get_args, print_rank_0 from megatron import get_args, print_rank_0
from megatron.model.bert_model import bert_attention_mask_func from megatron.model.bert_model import bert_attention_mask_func, bert_extended_attention_mask, bert_position_ids
from megatron.model.bert_model import bert_extended_attention_mask
from megatron.model.bert_model import bert_position_ids
from megatron.model.language_model import get_language_model from megatron.model.language_model import get_language_model
from megatron.model.utils import get_linear_layer from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal from megatron.model.utils import init_method_normal
......
...@@ -18,9 +18,7 @@ ...@@ -18,9 +18,7 @@
import torch import torch
from megatron import get_args, print_rank_0 from megatron import get_args, print_rank_0
from megatron.model.bert_model import bert_attention_mask_func from megatron.model.bert_model import bert_attention_mask_func, bert_extended_attention_mask, bert_position_ids
from megatron.model.bert_model import bert_extended_attention_mask
from megatron.model.bert_model import bert_position_ids
from megatron.model.language_model import get_language_model from megatron.model.language_model import get_language_model
from megatron.model.utils import get_linear_layer from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal from megatron.model.utils import init_method_normal
......
...@@ -10,9 +10,7 @@ from megatron.model.utils import get_linear_layer ...@@ -10,9 +10,7 @@ from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal from megatron.model.utils import init_method_normal
from megatron.model.language_model import get_language_model from megatron.model.language_model import get_language_model
from megatron.model.utils import scaled_init_method_normal from megatron.model.utils import scaled_init_method_normal
from megatron.model.utils import bert_attention_mask_func from megatron.model.bert_model import bert_attention_mask_func, bert_extended_attention_mask, bert_position_ids
from megatron.model.utils import bert_extended_attention_mask
from megatron.model.utils import bert_position_ids
class ICTBertModel(MegatronModule): class ICTBertModel(MegatronModule):
...@@ -52,7 +50,7 @@ class ICTBertModel(MegatronModule): ...@@ -52,7 +50,7 @@ class ICTBertModel(MegatronModule):
def embed_query(self, query_tokens, query_attention_mask): def embed_query(self, query_tokens, query_attention_mask):
"""Embed a batch of tokens using the query model""" """Embed a batch of tokens using the query model"""
if self.use_query_model: if self.use_query_model:
query_types = torch.zeros(query_tokens.shape).type(torch.int64).cuda() query_types = torch.cuda.LongTensor(*query_tokens.shape).fill_(0)
query_ict_logits, _ = self.query_model.forward(query_tokens, query_attention_mask, query_types) query_ict_logits, _ = self.query_model.forward(query_tokens, query_attention_mask, query_types)
return query_ict_logits return query_ict_logits
else: else:
...@@ -61,7 +59,7 @@ class ICTBertModel(MegatronModule): ...@@ -61,7 +59,7 @@ class ICTBertModel(MegatronModule):
def embed_block(self, block_tokens, block_attention_mask): def embed_block(self, block_tokens, block_attention_mask):
"""Embed a batch of tokens using the block model""" """Embed a batch of tokens using the block model"""
if self.use_block_model: if self.use_block_model:
block_types = torch.zeros(block_tokens.shape).type(torch.int64).cuda() block_types = torch.cuda.LongTensor(*block_tokens.shape).fill_(0)
block_ict_logits, _ = self.block_model.forward(block_tokens, block_attention_mask, block_types) block_ict_logits, _ = self.block_model.forward(block_tokens, block_attention_mask, block_types)
return block_ict_logits return block_ict_logits
else: else:
......
...@@ -78,42 +78,3 @@ def get_params_for_weight_decay_optimization(module): ...@@ -78,42 +78,3 @@ def get_params_for_weight_decay_optimization(module):
if p is not None and n == 'bias']) if p is not None and n == 'bias'])
return weight_decay_params, no_weight_decay_params return weight_decay_params, no_weight_decay_params
def bert_attention_mask_func(attention_scores, attention_mask):
attention_scores = attention_scores + attention_mask
return attention_scores
def bert_extended_attention_mask(attention_mask, dtype):
# We create a 3D attention mask from a 2D tensor mask.
# [b, 1, s]
attention_mask_b1s = attention_mask.unsqueeze(1)
# [b, s, 1]
attention_mask_bs1 = attention_mask.unsqueeze(2)
# [b, s, s]
attention_mask_bss = attention_mask_b1s * attention_mask_bs1
# [b, 1, s, s]
extended_attention_mask = attention_mask_bss.unsqueeze(1)
# Since attention_mask is 1.0 for positions we want to attend and 0.0
# for masked positions, this operation will create a tensor which is
# 0.0 for positions we want to attend and -10000.0 for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
# fp16 compatibility
extended_attention_mask = extended_attention_mask.to(dtype=dtype)
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
return extended_attention_mask
def bert_position_ids(token_ids):
# Create position ids
seq_length = token_ids.size(1)
position_ids = torch.arange(seq_length, dtype=torch.long,
device=token_ids.device)
position_ids = position_ids.unsqueeze(0).expand_as(token_ids)
return position_ids
...@@ -234,7 +234,9 @@ class CheckpointFunction(torch.autograd.Function): ...@@ -234,7 +234,9 @@ class CheckpointFunction(torch.autograd.Function):
if isinstance(outputs, torch.Tensor): if isinstance(outputs, torch.Tensor):
outputs = (outputs,) outputs = (outputs,)
torch.autograd.backward(outputs, args) torch.autograd.backward(outputs, args)
return (None,) + tuple(inp.grad for inp in detached_inputs) grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else inp
for inp in detached_inputs)
return (None,) + grads
def checkpoint(function, *args): def checkpoint(function, *args):
......
...@@ -22,10 +22,11 @@ import torch ...@@ -22,10 +22,11 @@ import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from apex.optimizers import FusedAdam as Adam from apex.optimizers import FusedAdam as Adam
from megatron import get_args, print_rank_0 from megatron import get_args
from megatron import get_timers from megatron import get_timers
from megatron import get_tensorboard_writer from megatron import get_tensorboard_writer
from megatron import mpu from megatron import mpu
from megatron import print_rank_0
from megatron.checkpointing import load_checkpoint from megatron.checkpointing import load_checkpoint
from megatron.checkpointing import save_checkpoint from megatron.checkpointing import save_checkpoint
from megatron.fp16 import FP16_Module from megatron.fp16 import FP16_Module
...@@ -222,7 +223,11 @@ def setup_model_and_optimizer(model_provider_func): ...@@ -222,7 +223,11 @@ def setup_model_and_optimizer(model_provider_func):
else: else:
args.iteration = 0 args.iteration = 0
unwrapped_model = model.module.module # get model without FP16 and/or TorchDDP wrappers
unwrapped_model = model
while hasattr(unwrapped_model, 'module'):
unwrapped_model = unwrapped_model.module
if args.iteration == 0 and hasattr(unwrapped_model, 'init_state_dict_from_bert'): if args.iteration == 0 and hasattr(unwrapped_model, 'init_state_dict_from_bert'):
print("Initializing ICT from pretrained BERT model", flush=True) print("Initializing ICT from pretrained BERT model", flush=True)
unwrapped_model.init_state_dict_from_bert() unwrapped_model.init_state_dict_from_bert()
......
...@@ -19,7 +19,8 @@ import sys ...@@ -19,7 +19,8 @@ import sys
import torch import torch
from megatron import get_args, print_rank_0 from megatron import get_args
from megatron import print_rank_0
from megatron import get_adlr_autoresume from megatron import get_adlr_autoresume
from megatron import mpu from megatron import mpu
from megatron.checkpointing import save_checkpoint from megatron.checkpointing import save_checkpoint
......
...@@ -18,7 +18,8 @@ ...@@ -18,7 +18,8 @@
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from megatron import get_args, print_rank_0 from megatron import get_args
from megatron import print_rank_0
from megatron import get_timers from megatron import get_timers
from megatron import mpu from megatron import mpu
from megatron.data.dataset_utils import build_train_valid_test_datasets from megatron.data.dataset_utils import build_train_valid_test_datasets
......
...@@ -17,7 +17,8 @@ ...@@ -17,7 +17,8 @@
import torch import torch
from megatron import get_args, print_rank_0 from megatron import get_args
from megatron import print_rank_0
from megatron import get_timers from megatron import get_timers
from megatron import get_tokenizer from megatron import get_tokenizer
from megatron import mpu from megatron import mpu
......
...@@ -19,7 +19,8 @@ import torch ...@@ -19,7 +19,8 @@ import torch
import torch.distributed as dist import torch.distributed as dist
import torch.nn.functional as F import torch.nn.functional as F
from megatron import get_args, print_rank_0 from megatron import get_args
from megatron import print_rank_0
from megatron import get_timers from megatron import get_timers
from megatron import mpu from megatron import mpu
from megatron.data.dataset_utils import build_train_valid_test_datasets from megatron.data.dataset_utils import build_train_valid_test_datasets
...@@ -56,6 +57,44 @@ def model_provider(): ...@@ -56,6 +57,44 @@ def model_provider():
return general_ict_model_provider(False, False) return general_ict_model_provider(False, False)
def get_group_world_size_rank():
group = mpu.get_data_parallel_group()
rank = torch.distributed.get_rank(group=group)
world_size = torch.distributed.get_world_size(group=group)
return group, rank, world_size
class AllgatherFromDataParallelRegion(torch.autograd.Function):
@staticmethod
def forward(ctx, input_):
assert input_.dim() == 2
group, rank, world_size = get_group_world_size_rank()
tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
tensor_list[rank] = input_
torch.distributed.all_gather(tensor_list, input_, group=group)
output = torch.cat(tensor_list, dim=0).contiguous()
return output
@staticmethod
def backward(ctx, grad_output):
group, rank, world_size = get_group_world_size_rank()
assert grad_output.shape[0] % world_size == 0
dim_size = grad_output.shape[0] // world_size
output_list = torch.split(grad_output, dim_size, dim=0)
# get chunk from this rank
output = output_list[rank].contiguous()
return output
def get_batch(data_iterator): def get_batch(data_iterator):
# Items and their type. # Items and their type.
keys = ['query_tokens', 'query_pad_mask', keys = ['query_tokens', 'query_pad_mask',
...@@ -91,43 +130,30 @@ def forward_step(data_iterator, model): ...@@ -91,43 +130,30 @@ def forward_step(data_iterator, model):
block_tokens, block_pad_mask, block_indices = get_batch(data_iterator) block_tokens, block_pad_mask, block_indices = get_batch(data_iterator)
timers('batch generator').stop() timers('batch generator').stop()
# Forward model. # Forward model.
query_logits, block_logits = model(query_tokens, query_pad_mask, block_tokens, block_pad_mask) query_logits, block_logits = model(query_tokens, query_pad_mask, block_tokens, block_pad_mask)
local_batch_size = query_logits.shape[0]
global_batch_size = dist.get_world_size() * local_batch_size # recall we assert that model_parallel_size == 1
data_parallel_size = dist.get_world_size() / args.model_parallel_size all_query_logits = AllgatherFromDataParallelRegion.apply(query_logits)
batch_size = query_logits.shape[0] all_block_logits = AllgatherFromDataParallelRegion.apply(block_logits)
global_batch_size = int(batch_size * data_parallel_size)
all_logits_shape = (int(global_batch_size), int(query_logits.shape[1]))
all_query_logits = torch.zeros(all_logits_shape).type(query_logits.dtype).cuda()
all_block_logits = all_query_logits.clone().cuda()
# record this processes' data
all_query_logits[args.rank * batch_size:(args.rank + 1) * batch_size] = query_logits
all_block_logits[args.rank * batch_size:(args.rank + 1) * batch_size] = block_logits
# merge data from all processes
dist.all_reduce(all_query_logits)
dist.all_reduce(all_block_logits)
# scores are inner products between query and block embeddings # scores are inner products between query and block embeddings
retrieval_scores = all_query_logits.float().matmul(torch.transpose(all_block_logits, 0, 1).float()) retrieval_scores = all_query_logits.float().matmul(torch.transpose(all_block_logits, 0, 1).float())
softmaxed = F.softmax(retrieval_scores, dim=1) softmaxed = F.softmax(retrieval_scores, dim=1)
sorted_vals, sorted_indices = torch.topk(softmaxed, k=softmaxed.shape[1], sorted=True) sorted_vals, sorted_indices = torch.topk(softmaxed, k=softmaxed.shape[1], sorted=True)
def topk_acc(k): def topk_accuracy(k):
return torch.cuda.FloatTensor([sum([int(i in sorted_indices[i, :k]) for i in range(global_batch_size)]) / global_batch_size]) return torch.cuda.FloatTensor([sum([int(i in sorted_indices[i, :k]) for i in range(global_batch_size)]) / global_batch_size])
top_accs = [topk_acc(k) for k in [1, 8, 20, 100]]
topk_accs = [topk_accuracy(int(k)) for k in args.report_topk_accuracies]
retrieval_loss = torch.nn.CrossEntropyLoss()(retrieval_scores, torch.arange(global_batch_size).long().cuda()) retrieval_loss = torch.nn.CrossEntropyLoss()(retrieval_scores, torch.arange(global_batch_size).long().cuda())
reduced_losses = reduce_losses([retrieval_loss, *top_accs]) reduced_losses = reduce_losses([retrieval_loss, *topk_accs])
stats_dict = {
'retrieval loss': reduced_losses[0], # create stats_dict with retrieval loss and all specified top-k accuracies
'top1_acc': reduced_losses[1], topk_acc_dict = {'top{}_acc'.format(k): v for k, v in zip(args.report_topk_accuracies, reduced_losses[1:])}
'top8_acc': reduced_losses[2], stats_dict = dict(retrieval_loss=reduced_losses[0], **topk_acc_dict)
'top20_acc': reduced_losses[3],
'top100_acc': reduced_losses[4],
}
return retrieval_loss, stats_dict return retrieval_loss, stats_dict
......
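For reference, the refactored `forward_step` above implements an in-batch negatives loss: query and block embeddings from all data-parallel ranks are all-gathered, their pairwise inner products form a score matrix, and the correct block for query i sits on the diagonal. Below is a minimal single-process sketch of that loss with made-up tensor sizes; it omits the distributed all-gather and the top-k accuracy bookkeeping.
<pre>
# Single-process sketch of the in-batch retrieval loss computed in forward_step
# (omits the data-parallel all-gather and top-k accuracy reporting).
import torch
import torch.nn.functional as F

global_batch_size, embed_dim = 8, 128              # made-up sizes for illustration
query_logits = torch.randn(global_batch_size, embed_dim)
block_logits = torch.randn(global_batch_size, embed_dim)

# scores[i, j] = inner product between query i and block j
retrieval_scores = query_logits.float().matmul(block_logits.float().t())

# the matching block for query i is block i, so the labels are simply 0..N-1
labels = torch.arange(global_batch_size)
retrieval_loss = F.cross_entropy(retrieval_scores, labels)
</pre>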