Commit 68233932 authored by Jared Casper's avatar Jared Casper
Browse files

Merge branch 'main_retriver_merge' into 'main'

Merge ICT Retriever with Main branch

See merge request ADLR/megatron-lm!227
parents 76e3fca0 816fb890
#! /bin/bash
# Runs the "217M" parameter biencoder model for ICT retriever
RANK=0
WORLD_SIZE=1
PRETRAINED_BERT_PATH=<Specify path of pretrained BERT model>
TEXT_DATA_PATH=<Specify path and file prefix of the text data>
TITLE_DATA_PATH=<Specify path and file prefix od the titles>
CHECKPOINT_PATH=<Specify path>
python pretrain_ict.py \
--num-layers 12 \
--hidden-size 768 \
--num-attention-heads 12 \
--tensor-model-parallel-size 1 \
--micro-batch-size 32 \
--seq-length 256 \
--max-position-embeddings 512 \
--train-iters 100000 \
--vocab-file bert-vocab.txt \
--tokenizer-type BertWordPieceLowerCase \
--DDP-impl torch \
--bert-load ${PRETRAINED_BERT_PATH} \
--log-interval 100 \
--eval-interval 1000 \
--eval-iters 10 \
--retriever-report-topk-accuracies 1 5 10 20 100 \
--retriever-score-scaling \
--load $CHECKPOINT_PATH \
--save $CHECKPOINT_PATH \
--data-path ${TEXT_DATA_PATH} \
--titles-data-path ${TITLE_DATA_PATH} \
--lr 0.0001 \
--lr-decay-style linear \
--weight-decay 1e-2 \
--clip-grad 1.0 \
--lr-warmup-fraction 0.01 \
--save-interval 4000 \
--exit-interval 8000 \
--query-in-block-prob 0.1 \
--fp16
......@@ -39,7 +39,7 @@ def parse_args(extra_args_provider=None, defaults={},
parser = _add_validation_args(parser)
parser = _add_data_args(parser)
parser = _add_autoresume_args(parser)
parser = _add_realm_args(parser)
parser = _add_biencoder_args(parser)
parser = _add_vit_args(parser)
parser = _add_logging_args(parser)
......@@ -672,13 +672,19 @@ def _add_autoresume_args(parser):
return parser
def _add_realm_args(parser):
group = parser.add_argument_group(title='realm')
def _add_biencoder_args(parser):
group = parser.add_argument_group(title='biencoder')
# network size
group.add_argument('--ict-head-size', type=int, default=None,
help='Size of block embeddings to be used in ICT and '
'REALM (paper default: 128)')
'REALM (paper default: 128)')
group.add_argument('--biencoder-projection-dim', type=int, default=0,
help='Size of projection head used in biencoder (paper'
' default: 128)')
group.add_argument('--biencoder-shared-query-context-model', action='store_true',
help='Whether to share the parameters of the query '
'and context models or not')
# checkpointing
group.add_argument('--ict-load', type=str, default=None,
......@@ -697,8 +703,12 @@ def _add_realm_args(parser):
help='Whether to use one sentence documents in ICT')
# training
group.add_argument('--report-topk-accuracies', nargs='+', default=[],
help="Which top-k accuracies to report (e.g. '1 5 20')")
group.add_argument('--retriever-report-topk-accuracies', nargs='+', type=int,
default=[], help="Which top-k accuracies to report "
"(e.g. '1 5 20')")
group.add_argument('--retriever-score-scaling', action='store_true',
help='Whether to scale retriever scores by inverse '
'square root of hidden size')
# faiss index
group.add_argument('--faiss-use-gpu', action='store_true',
......
......@@ -206,6 +206,33 @@ def _transpose_first_dim(t, num_splits, num_splits_first, model):
return t
def fix_query_key_value_ordering(model, checkpoint_version):
"""Fix up query/key/value matrix ordering if checkpoint
version is smaller than 2.0
"""
if checkpoint_version < 2.0:
for name, param in model.named_parameters():
if name.endswith(('.query_key_value.weight', '.query_key_value.bias')):
if checkpoint_version == 0:
fixed_param = _transpose_first_dim(param.data, 3, True, model)
elif checkpoint_version == 1.0:
fixed_param = _transpose_first_dim(param.data, 3, False, model)
else:
print_rank_0(f"Invalid checkpoint version {checkpoint_version}.")
sys.exit()
param.data.copy_(fixed_param)
if name.endswith(('.key_value.weight', '.key_value.bias')):
if checkpoint_version == 0:
fixed_param = _transpose_first_dim(param.data, 2, True, model)
elif checkpoint_version == 1.0:
fixed_param = _transpose_first_dim(param.data, 2, False, model)
else:
print_rank_0(f"Invalid checkpoint version {checkpoint_version}.")
sys.exit()
param.data.copy_(fixed_param)
print_rank_0(" succesfully fixed query-key-values ordering for"
" checkpoint version {}".format(checkpoint_version))
def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True):
"""Load a model checkpoint and return the iteration.
strict (bool): whether to strictly enforce that the keys in
......@@ -308,28 +335,10 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
mpu.set_virtual_pipeline_model_parallel_rank(i)
model[i].load_state_dict(state_dict['model%d' % i], strict=strict)
# Fix up query/key/value matrix ordering
if get_checkpoint_version() < 2.0:
checkpoint_version = get_checkpoint_version()
for name, param in model.named_parameters():
if name.endswith(('.query_key_value.weight', '.query_key_value.bias')):
if checkpoint_version == 0:
fixed_param = _transpose_first_dim(param.data, 3, True, model)
elif checkpoint_version == 1.0:
fixed_param = _transpose_first_dim(param.data, 3, False, model)
else:
print_rank_0(f"Invalid checkpoint version {checkpoint_version}.")
sys.exit()
param.data.copy_(fixed_param)
if name.endswith(('.key_value.weight', '.key_value.bias')):
if checkpoint_version == 0:
fixed_param = _transpose_first_dim(param.data, 2, True, model)
elif checkpoint_version == 1.0:
fixed_param = _transpose_first_dim(param.data, 2, False, model)
else:
print_rank_0(f"Invalid checkpoint version {checkpoint_version}.")
sys.exit()
param.data.copy_(fixed_param)
# Fix up query/key/value matrix ordering if needed
checkpoint_version = get_checkpoint_version()
print_rank_0(f' checkpoint version {checkpoint_version}')
fix_query_key_value_ordering(model, checkpoint_version)
# Optimizer.
if not release and not args.finetune and not args.no_load_optim:
......
import os
import time
import numpy as np
import torch
from megatron import mpu, print_rank_0
from megatron.data.dataset_utils import create_masked_lm_predictions, pad_and_convert_to_numpy
from megatron import get_args, get_tokenizer, print_rank_0, mpu
def get_ict_batch(data_iterator):
# Items and their type.
keys = ['query_tokens', 'query_mask',
'context_tokens', 'context_mask', 'block_data']
datatype = torch.int64
# Broadcast data.
if data_iterator is None:
data = None
else:
data = next(data_iterator)
data_b = mpu.broadcast_data(keys, data, datatype)
# Unpack.
query_tokens = data_b['query_tokens'].long()
query_mask = data_b['query_mask'] < 0.5
context_tokens = data_b['context_tokens'].long()
context_mask = data_b['context_mask'] < 0.5
block_indices = data_b['block_data'].long()
return query_tokens, query_mask,\
context_tokens, context_mask, block_indices
def join_str_list(str_list):
"""Join a list of strings, handling spaces appropriately"""
result = ""
for s in str_list:
if s.startswith("##"):
result += s[2:]
else:
result += " " + s
return result
class BlockSampleData(object):
"""A struct for fully describing a fixed-size block of data as used in REALM
:param start_idx: for first sentence of the block
:param end_idx: for last sentence of the block (may be partially truncated in sample construction)
:param doc_idx: the index of the document from which the block comes in the original indexed dataset
:param block_idx: a unique integer identifier given to every block.
"""
def __init__(self, start_idx, end_idx, doc_idx, block_idx):
self.start_idx = start_idx
self.end_idx = end_idx
self.doc_idx = doc_idx
self.block_idx = block_idx
def as_array(self):
return np.array([self.start_idx, self.end_idx, self.doc_idx, self.block_idx]).astype(np.int64)
def as_tuple(self):
return self.start_idx, self.end_idx, self.doc_idx, self.block_idx
class BlockSamplesMapping(object):
def __init__(self, mapping_array):
# make sure that the array is compatible with BlockSampleData
assert mapping_array.shape[1] == 4
self.mapping_array = mapping_array
def __len__(self):
return self.mapping_array.shape[0]
def __getitem__(self, idx):
"""Get the data associated with an indexed sample."""
sample_data = BlockSampleData(*self.mapping_array[idx])
return sample_data
def get_block_samples_mapping(block_dataset, title_dataset, data_prefix, num_epochs,
max_num_samples, max_seq_length, seed, name, use_one_sent_docs=False):
"""Get samples mapping for a dataset over fixed size blocks. This function also requires
a dataset of the titles for the source documents since their lengths must be taken into account.
:return: samples_mapping (BlockSamplesMapping)
"""
if not num_epochs:
if not max_num_samples:
raise ValueError("Need to specify either max_num_samples "
"or num_epochs")
num_epochs = np.iinfo(np.int32).max - 1
if not max_num_samples:
max_num_samples = np.iinfo(np.int64).max - 1
# Filename of the index mapping
indexmap_filename = data_prefix
indexmap_filename += '_{}_indexmap'.format(name)
if num_epochs != (np.iinfo(np.int32).max - 1):
indexmap_filename += '_{}ep'.format(num_epochs)
if max_num_samples != (np.iinfo(np.int64).max - 1):
indexmap_filename += '_{}mns'.format(max_num_samples)
indexmap_filename += '_{}msl'.format(max_seq_length)
indexmap_filename += '_{}s'.format(seed)
if use_one_sent_docs:
indexmap_filename += '_1sentok'
indexmap_filename += '.npy'
# Build the indexed mapping if not exist.
if mpu.get_data_parallel_rank() == 0 and \
not os.path.isfile(indexmap_filename):
print(' > WARNING: could not find index map file {}, building '
'the indices on rank 0 ...'.format(indexmap_filename))
# Make sure the types match the helpers input types.
assert block_dataset.doc_idx.dtype == np.int64
assert block_dataset.sizes.dtype == np.int32
# Build samples mapping
verbose = torch.distributed.get_rank() == 0
start_time = time.time()
print_rank_0(' > building samples index mapping for {} ...'.format(
name))
from megatron.data import helpers
mapping_array = helpers.build_blocks_mapping(
block_dataset.doc_idx,
block_dataset.sizes,
title_dataset.sizes,
num_epochs,
max_num_samples,
max_seq_length - 3, # account for added tokens
seed,
verbose,
use_one_sent_docs)
print_rank_0(' > done building samples index mapping')
np.save(indexmap_filename, mapping_array, allow_pickle=True)
print_rank_0(' > saved the index mapping in {}'.format(
indexmap_filename))
# Make sure all the ranks have built the mapping
print_rank_0(' > elapsed time to build and save samples mapping '
'(seconds): {:4f}'.format(
time.time() - start_time))
# This should be a barrier but nccl barrier assumes
# device_index=rank which is not the case for model
# parallel case
counts = torch.cuda.LongTensor([1])
torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
assert counts[0].item() == torch.distributed.get_world_size(
group=mpu.get_data_parallel_group())
# Load indexed dataset.
print_rank_0(' > loading indexed mapping from {}'.format(
indexmap_filename))
start_time = time.time()
mapping_array = np.load(indexmap_filename, allow_pickle=True, mmap_mode='r')
samples_mapping = BlockSamplesMapping(mapping_array)
print_rank_0(' loaded indexed file in {:3.3f} seconds'.format(
time.time() - start_time))
print_rank_0(' total number of samples: {}'.format(
mapping_array.shape[0]))
return samples_mapping
......@@ -9,6 +9,16 @@ from megatron import get_args
from megatron.data.dataset_utils import get_indexed_dataset_
from megatron.data.realm_dataset_utils import get_block_samples_mapping
def make_attention_mask(source_block, target_block):
"""
Returns a 2-dimensional (2-D) attention mask
:param source_block: 1-D array
:param target_block: 1-D array
"""
mask = (target_block[None, :] >= 1) * (source_block[:, None] >= 1)
mask = mask.astype(np.int64)
# (source_length, target_length)
return mask
def get_ict_dataset(use_titles=True, query_in_block_prob=1):
"""Get a dataset which uses block samples mappings to get ICT/block indexing data (via get_block())
......@@ -39,7 +49,7 @@ class ICTDataset(Dataset):
"""Dataset containing sentences and their blocks for an inverse cloze task."""
def __init__(self, name, block_dataset, title_dataset, data_prefix,
num_epochs, max_num_samples, max_seq_length, query_in_block_prob,
seed, use_titles=True, use_one_sent_docs=False):
seed, use_titles=True, use_one_sent_docs=False, binary_head=False):
self.name = name
self.seed = seed
self.max_seq_length = max_seq_length
......@@ -93,14 +103,20 @@ class ICTDataset(Dataset):
block = list(itertools.chain(*block))[:self.max_seq_length - title_pad_offset]
query_tokens, query_pad_mask = self.concat_and_pad_tokens(query)
block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)
context_tokens, context_pad_mask = self.concat_and_pad_tokens(block, title)
query_mask = make_attention_mask(query_tokens, query_tokens)
context_mask = make_attention_mask(context_tokens, context_tokens)
block_data = sample_data.as_array()
sample = {
'query_tokens': query_tokens,
'query_mask': query_mask,
'query_pad_mask': query_pad_mask,
'block_tokens': block_tokens,
'block_pad_mask': block_pad_mask,
'context_tokens': context_tokens,
'context_mask': context_mask,
'context_pad_mask': context_pad_mask,
'block_data': block_data,
}
......
import os
import torch
import sys
from megatron import get_args, print_rank_0
from megatron.checkpointing import fix_query_key_value_ordering
from megatron.checkpointing import get_checkpoint_tracker_filename
from megatron.checkpointing import get_checkpoint_name
from megatron import mpu, get_tokenizer
from megatron.model.bert_model import bert_position_ids
from megatron.model.enums import AttnMaskType
from megatron.model.language_model import get_language_model
from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal
from megatron.model.utils import scaled_init_method_normal
from .module import MegatronModule
def biencoder_model_provider(only_query_model=False,
only_context_model=False,
biencoder_shared_query_context_model=False):
"""Build the model."""
args = get_args()
assert mpu.get_tensor_model_parallel_world_size() == 1 and \
mpu.get_pipeline_model_parallel_world_size() == 1, \
"Model parallel size > 1 not supported for ICT"
print_rank_0('building BiEncoderModel...')
# simpler to just keep using 2 tokentypes since
# the LM we initialize with has 2 tokentypes
model = BiEncoderModel(
num_tokentypes=2,
parallel_output=False,
only_query_model=only_query_model,
only_context_model=only_context_model,
biencoder_shared_query_context_model=\
biencoder_shared_query_context_model)
return model
class BiEncoderModel(MegatronModule):
"""Bert-based module for Biencoder model."""
def __init__(self,
num_tokentypes=1,
parallel_output=True,
only_query_model=False,
only_context_model=False,
biencoder_shared_query_context_model=False):
super(BiEncoderModel, self).__init__()
args = get_args()
bert_kwargs = dict(
num_tokentypes=num_tokentypes,
parallel_output=parallel_output)
self.biencoder_shared_query_context_model = \
biencoder_shared_query_context_model
assert not (only_context_model and only_query_model)
self.use_context_model = not only_query_model
self.use_query_model = not only_context_model
self.biencoder_projection_dim = args.biencoder_projection_dim
if self.biencoder_shared_query_context_model:
self.model = PretrainedBertModel(**bert_kwargs)
self._model_key = 'shared_model'
self.query_model, self.context_model = self.model, self.model
else:
if self.use_query_model:
# this model embeds (pseudo-)queries - Embed_input in the paper
self.query_model = PretrainedBertModel(**bert_kwargs)
self._query_key = 'query_model'
if self.use_context_model:
# this model embeds evidence blocks - Embed_doc in the paper
self.context_model = PretrainedBertModel(**bert_kwargs)
self._context_key = 'context_model'
def forward(self, query_tokens, query_attention_mask, query_types,
context_tokens, context_attention_mask, context_types):
"""Run a forward pass for each of the models and
return the respective embeddings."""
if self.use_query_model:
query_logits = self.embed_text(self.query_model,
query_tokens,
query_attention_mask,
query_types)
else:
raise ValueError("Cannot embed query without the query model.")
if self.use_context_model:
context_logits = self.embed_text(self.context_model,
context_tokens,
context_attention_mask,
context_types)
else:
raise ValueError("Cannot embed block without the block model.")
return query_logits, context_logits
@staticmethod
def embed_text(model, tokens, attention_mask, token_types):
"""Embed a batch of tokens using the model"""
logits = model(tokens,
attention_mask,
token_types)
return logits
def state_dict_for_save_checkpoint(self, destination=None, \
prefix='', keep_vars=False):
"""Save dict with state dicts of each of the models."""
state_dict_ = {}
if self.biencoder_shared_query_context_model:
state_dict_[self._model_key] = \
self.model.state_dict_for_save_checkpoint(destination,
prefix,
keep_vars)
else:
if self.use_query_model:
state_dict_[self._query_key] = \
self.query_model.state_dict_for_save_checkpoint(
destination, prefix, keep_vars)
if self.use_context_model:
state_dict_[self._context_key] = \
self.context_model.state_dict_for_save_checkpoint(
destination, prefix, keep_vars)
return state_dict_
def load_state_dict(self, state_dict, strict=True):
"""Load the state dicts of each of the models"""
if self.biencoder_shared_query_context_model:
print_rank_0("Loading shared query-context model")
self.model.load_state_dict(state_dict[self._model_key], \
strict=strict)
else:
if self.use_query_model:
print_rank_0("Loading query model")
self.query_model.load_state_dict( \
state_dict[self._query_key], strict=strict)
if self.use_context_model:
print_rank_0("Loading context model")
self.context_model.load_state_dict( \
state_dict[self._context_key], strict=strict)
def init_state_dict_from_bert(self):
"""Initialize the state from a pretrained BERT model
on iteration zero of ICT pretraining"""
args = get_args()
if args.bert_load is None:
print_rank_0("bert-load argument is None")
return
tracker_filename = get_checkpoint_tracker_filename(args.bert_load)
if not os.path.isfile(tracker_filename):
raise FileNotFoundError("Could not find BERT checkpoint")
with open(tracker_filename, 'r') as f:
iteration = int(f.read().strip())
assert iteration > 0
checkpoint_name = get_checkpoint_name(args.bert_load, iteration, False)
if mpu.get_data_parallel_rank() == 0:
print('global rank {} is loading BERT checkpoint {}'.format(
torch.distributed.get_rank(), checkpoint_name))
# Load the checkpoint.
try:
state_dict = torch.load(checkpoint_name, map_location='cpu')
except ModuleNotFoundError:
from megatron.fp16_deprecated import loss_scaler
# For backward compatibility.
print_rank_0(' > deserializing using the old code structure ...')
sys.modules['fp16.loss_scaler'] = sys.modules[
'megatron.fp16_deprecated.loss_scaler']
sys.modules['megatron.fp16.loss_scaler'] = sys.modules[
'megatron.fp16_deprecated.loss_scaler']
state_dict = torch.load(checkpoint_name, map_location='cpu')
sys.modules.pop('fp16.loss_scaler', None)
sys.modules.pop('megatron.fp16.loss_scaler', None)
except BaseException:
print_rank_0('could not load the BERT checkpoint')
sys.exit()
checkpoint_version = state_dict.get('checkpoint_version', 0)
# load the LM state dict into each model
model_dict = state_dict['model']['language_model']
if self.biencoder_shared_query_context_model:
self.model.language_model.load_state_dict(model_dict)
fix_query_key_value_ordering(self.model, checkpoint_version)
else:
if self.use_query_model:
self.query_model.language_model.load_state_dict(model_dict)
# give each model the same ict_head to begin with as well
if self.biencoder_projection_dim > 0:
query_proj_state_dict = \
self.state_dict_for_save_checkpoint()\
[self._query_key]['projection_enc']
fix_query_key_value_ordering(self.query_model, checkpoint_version)
if self.use_context_model:
self.context_model.language_model.load_state_dict(model_dict)
if self.query_model is not None and \
self.biencoder_projection_dim > 0:
self.context_model.projection_enc.load_state_dict\
(query_proj_state_dict)
fix_query_key_value_ordering(self.context_model, checkpoint_version)
class PretrainedBertModel(MegatronModule):
"""BERT-based encoder for queries or contexts used for
learned information retrieval."""
def __init__(self, num_tokentypes=2,
parallel_output=True):
super(PretrainedBertModel, self).__init__()
args = get_args()
tokenizer = get_tokenizer()
self.pad_id = tokenizer.pad
self.biencoder_projection_dim = args.biencoder_projection_dim
self.parallel_output = parallel_output
init_method = init_method_normal(args.init_method_std)
scaled_init_method = scaled_init_method_normal(
args.init_method_std, args.num_layers)
self.language_model, self._language_model_key = get_language_model(
num_tokentypes=num_tokentypes,
add_pooler=False,
encoder_attn_mask_type=AttnMaskType.padding,
init_method=init_method,
scaled_init_method=scaled_init_method)
if args.biencoder_projection_dim > 0:
self.projection_enc = get_linear_layer(args.hidden_size,
args.biencoder_projection_dim,
init_method)
self._projection_enc_key = 'projection_enc'
def forward(self, input_ids, attention_mask, tokentype_ids=None):
extended_attention_mask = attention_mask.unsqueeze(1)
#extended_attention_mask = bert_extended_attention_mask(attention_mask)
position_ids = bert_position_ids(input_ids)
lm_output = self.language_model(input_ids,
position_ids,
extended_attention_mask,
tokentype_ids=tokentype_ids)
# This mask will be used in average-pooling and max-pooling
pool_mask = (input_ids == self.pad_id).unsqueeze(2)
# Taking the representation of the [CLS] token of BERT
pooled_output = lm_output[:, 0, :]
# Converting to float16 dtype
pooled_output = pooled_output.to(lm_output.dtype)
# Output.
if self.biencoder_projection_dim:
pooled_output = self.projection_enc(pooled_output)
return pooled_output
def state_dict_for_save_checkpoint(self, destination=None, prefix='',
keep_vars=False):
"""For easy load when model is combined with other heads,
add an extra key."""
state_dict_ = {}
state_dict_[self._language_model_key] \
= self.language_model.state_dict_for_save_checkpoint(
destination, prefix, keep_vars)
if self.biencoder_projection_dim > 0:
state_dict_[self._projection_enc_key] = \
self.projection_enc.state_dict(destination, prefix, keep_vars)
return state_dict_
def load_state_dict(self, state_dict, strict=True):
"""Customized load."""
print_rank_0("loading BERT weights")
self.language_model.load_state_dict(
state_dict[self._language_model_key], strict=strict)
if self.biencoder_projection_dim > 0:
print_rank_0("loading projection head weights")
self.projection_enc.load_state_dict(
state_dict[self._projection_enc_key], strict=strict)
......@@ -311,7 +311,7 @@ def setup_model_and_optimizer(model_provider_func):
# We only support local DDP with multiple micro-batches.
if get_num_microbatches() > 1:
assert args.DDP_impl == 'local'
if len(model) == 1:
if len(model) > 1:
assert args.DDP_impl == 'local'
if mpu.get_pipeline_model_parallel_world_size() > 1:
assert args.DDP_impl == 'local'
......@@ -323,6 +323,8 @@ def setup_model_and_optimizer(model_provider_func):
'init_state_dict_from_bert'):
print("Initializing ICT from pretrained BERT model", flush=True)
module.init_state_dict_from_bert()
if args.fp16:
optimizer.reload_model_params()
return model, optimizer, lr_scheduler
......
......@@ -14,6 +14,7 @@
# limitations under the License.
"""Pretrain BERT for Inverse Cloze Task"""
import math
import torch
import torch.distributed as dist
......@@ -23,17 +24,21 @@ from megatron import get_args
from megatron import print_rank_0
from megatron import get_timers
from megatron import mpu
from megatron.data.biencoder_dataset_utils import get_ict_batch
from megatron.data.dataset_utils import build_train_valid_test_datasets
from megatron.model.biencoder_model import biencoder_model_provider
from megatron.training import pretrain
from megatron.utils import average_losses_across_data_parallel_group
from megatron.model.realm_model import general_ict_model_provider
from megatron.data.realm_dataset_utils import get_ict_batch
def pretrain_ict_model_provider():
args = get_args()
return general_ict_model_provider(False, False)
model = biencoder_model_provider(
only_context_model=False,
only_query_model=False,
biencoder_shared_query_context_model=\
args.biencoder_shared_query_context_model)
return model
def get_group_world_size_rank():
......@@ -72,7 +77,6 @@ class AllgatherFromDataParallelRegion(torch.autograd.Function):
output = output_list[rank].contiguous()
return output
def forward_step(data_iterator, model, input_tensor):
"""Forward step."""
args = get_args()
......@@ -80,37 +84,57 @@ def forward_step(data_iterator, model, input_tensor):
# Get the batch.
timers('batch-generator').start()
query_tokens, query_pad_mask, \
block_tokens, block_pad_mask, block_indices = get_ict_batch(data_iterator)
query_tokens, query_mask, \
context_tokens, context_mask, context_indices = get_ict_batch(data_iterator)
timers('batch-generator').stop()
# Query and Context Types
query_types = torch.cuda.LongTensor(*query_tokens.shape).fill_(0)
context_types = torch.cuda.LongTensor(*context_tokens.shape).fill_(0)
# Forward model.
query_logits, block_logits = model(query_tokens, query_pad_mask, block_tokens, block_pad_mask)
query_logits, context_logits = model(query_tokens, query_mask,
query_types, context_tokens,
context_mask, context_types)
micro_batch_size = query_logits.shape[0]
global_batch_size = dist.get_world_size() * micro_batch_size # recall we assert that tensor_model_parallel_size == 1
# recall we assert that tensor_model_parallel_size == 1
assert mpu.get_tensor_model_parallel_world_size() == 1, \
"Model parallel size > 1 not supported for ICT"
global_batch_size = dist.get_world_size() * micro_batch_size
all_query_logits = AllgatherFromDataParallelRegion.apply(query_logits)
all_block_logits = AllgatherFromDataParallelRegion.apply(block_logits)
all_context_logits = AllgatherFromDataParallelRegion.apply(context_logits)
# scores are inner products between query and context embeddings
retrieval_scores = torch.matmul(all_query_logits,
torch.transpose(all_context_logits, 0, 1))
# scaling the retriever scores
if args.retriever_score_scaling:
retrieval_scores = retrieval_scores / math.sqrt(args.hidden_size)
# scores are inner products between query and block embeddings
retrieval_scores = all_query_logits.float().matmul(torch.transpose(all_block_logits, 0, 1).float())
softmaxed = F.softmax(retrieval_scores, dim=1)
sorted_vals, sorted_indices = torch.topk(softmaxed, k=softmaxed.shape[1], sorted=True)
softmax_scores = F.log_softmax(retrieval_scores, dim=1)
sorted_vals, sorted_indices = torch.topk(softmax_scores,
k=softmax_scores.shape[1], sorted=True)
def topk_accuracy(k):
return torch.cuda.FloatTensor([sum([int(i in sorted_indices[i, :k]) for i in range(global_batch_size)]) / global_batch_size])
return torch.cuda.FloatTensor([sum([int(i in sorted_indices[i, :k]) \
for i in range(global_batch_size)]) / global_batch_size])
topk_accs = [topk_accuracy(int(k)) for k in args.report_topk_accuracies]
retrieval_loss = torch.nn.CrossEntropyLoss()(retrieval_scores, torch.arange(global_batch_size).long().cuda())
retrieval_loss = retrieval_loss.float()
averaged_losses = average_losses_across_data_parallel_group([retrieval_loss, *topk_accs])
topk_accs = [topk_accuracy(int(k)) for k in args.retriever_report_topk_accuracies]
# create stats_dict with retrieval loss and all specified top-k accuracies
topk_acc_dict = {'top{}_acc'.format(k): v for k, v in zip(args.report_topk_accuracies, averaged_losses[1:])}
stats_dict = dict(retrieval_loss=averaged_losses[0], **topk_acc_dict)
labels = torch.arange(global_batch_size).long().cuda()
loss = F.nll_loss(softmax_scores, labels, reduction='mean')
reduced_losses = average_losses_across_data_parallel_group([loss, *topk_accs])
# Scale the retrieval loss
loss = loss * mpu.get_data_parallel_world_size()
return retrieval_loss, stats_dict
# create stats_dict with retrieval loss and all specified top-k accuracies
topk_acc_dict = {'top{}_acc'.format(k): v * 100 for k, v in \
zip(args.retriever_report_topk_accuracies, reduced_losses[1:])}
stats_dict = dict(loss=reduced_losses[0], **topk_acc_dict)
return loss, stats_dict
def train_valid_test_datasets_provider(train_val_test_num_samples):
......@@ -129,6 +153,7 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
short_seq_prob=args.short_seq_prob,
seed=args.seed,
skip_warmup=(not args.mmap_warmup),
binary_head=False,
dataset_type='ict')
print_rank_0("> finished creating BERT ICT datasets ...")
......@@ -136,5 +161,7 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
if __name__ == "__main__":
pretrain(train_valid_test_datasets_provider, pretrain_ict_model_provider, forward_step,
pretrain(train_valid_test_datasets_provider,
pretrain_ict_model_provider,
forward_step,
args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment