Commit 68233932 authored by Jared Casper

Merge branch 'main_retriver_merge' into 'main'

Merge ICT Retriever with Main branch

See merge request ADLR/megatron-lm!227
parents 76e3fca0 816fb890
#! /bin/bash
# Runs the "217M" parameter biencoder model for ICT retriever
RANK=0
WORLD_SIZE=1
PRETRAINED_BERT_PATH=<Specify path of pretrained BERT model>
TEXT_DATA_PATH=<Specify path and file prefix of the text data>
TITLE_DATA_PATH=<Specify path and file prefix of the titles>
CHECKPOINT_PATH=<Specify path>
python pretrain_ict.py \
    --num-layers 12 \
    --hidden-size 768 \
    --num-attention-heads 12 \
    --tensor-model-parallel-size 1 \
    --micro-batch-size 32 \
    --seq-length 256 \
    --max-position-embeddings 512 \
    --train-iters 100000 \
    --vocab-file bert-vocab.txt \
    --tokenizer-type BertWordPieceLowerCase \
    --DDP-impl torch \
    --bert-load ${PRETRAINED_BERT_PATH} \
    --log-interval 100 \
    --eval-interval 1000 \
    --eval-iters 10 \
    --retriever-report-topk-accuracies 1 5 10 20 100 \
    --retriever-score-scaling \
    --load $CHECKPOINT_PATH \
    --save $CHECKPOINT_PATH \
    --data-path ${TEXT_DATA_PATH} \
    --titles-data-path ${TITLE_DATA_PATH} \
    --lr 0.0001 \
    --lr-decay-style linear \
    --weight-decay 1e-2 \
    --clip-grad 1.0 \
    --lr-warmup-fraction 0.01 \
    --save-interval 4000 \
    --exit-interval 8000 \
    --query-in-block-prob 0.1 \
    --fp16
@@ -39,7 +39,7 @@ def parse_args(extra_args_provider=None, defaults={},
     parser = _add_validation_args(parser)
     parser = _add_data_args(parser)
     parser = _add_autoresume_args(parser)
-    parser = _add_realm_args(parser)
+    parser = _add_biencoder_args(parser)
     parser = _add_vit_args(parser)
     parser = _add_logging_args(parser)
@@ -672,13 +672,19 @@ def _add_autoresume_args(parser):
     return parser

-def _add_realm_args(parser):
-    group = parser.add_argument_group(title='realm')
+def _add_biencoder_args(parser):
+    group = parser.add_argument_group(title='biencoder')

     # network size
     group.add_argument('--ict-head-size', type=int, default=None,
                        help='Size of block embeddings to be used in ICT and '
                        'REALM (paper default: 128)')
+    group.add_argument('--biencoder-projection-dim', type=int, default=0,
+                       help='Size of projection head used in biencoder (paper'
+                       ' default: 128)')
+    group.add_argument('--biencoder-shared-query-context-model', action='store_true',
+                       help='Whether to share the parameters of the query '
+                       'and context models or not')

     # checkpointing
     group.add_argument('--ict-load', type=str, default=None,
@@ -697,8 +703,12 @@ def _add_realm_args(parser):
                        help='Whether to use one sentence documents in ICT')

     # training
-    group.add_argument('--report-topk-accuracies', nargs='+', default=[],
-                       help="Which top-k accuracies to report (e.g. '1 5 20')")
+    group.add_argument('--retriever-report-topk-accuracies', nargs='+', type=int,
+                       default=[], help="Which top-k accuracies to report "
+                       "(e.g. '1 5 20')")
+    group.add_argument('--retriever-score-scaling', action='store_true',
+                       help='Whether to scale retriever scores by inverse '
+                       'square root of hidden size')

     # faiss index
     group.add_argument('--faiss-use-gpu', action='store_true',
...
@@ -206,6 +206,33 @@ def _transpose_first_dim(t, num_splits, num_splits_first, model):
     return t

+def fix_query_key_value_ordering(model, checkpoint_version):
+    """Fix up query/key/value matrix ordering if checkpoint
+    version is smaller than 2.0
+    """
+    if checkpoint_version < 2.0:
+        for name, param in model.named_parameters():
+            if name.endswith(('.query_key_value.weight', '.query_key_value.bias')):
+                if checkpoint_version == 0:
+                    fixed_param = _transpose_first_dim(param.data, 3, True, model)
+                elif checkpoint_version == 1.0:
+                    fixed_param = _transpose_first_dim(param.data, 3, False, model)
+                else:
+                    print_rank_0(f"Invalid checkpoint version {checkpoint_version}.")
+                    sys.exit()
+                param.data.copy_(fixed_param)
+            if name.endswith(('.key_value.weight', '.key_value.bias')):
+                if checkpoint_version == 0:
+                    fixed_param = _transpose_first_dim(param.data, 2, True, model)
+                elif checkpoint_version == 1.0:
+                    fixed_param = _transpose_first_dim(param.data, 2, False, model)
+                else:
+                    print_rank_0(f"Invalid checkpoint version {checkpoint_version}.")
+                    sys.exit()
+                param.data.copy_(fixed_param)
+        print_rank_0(" successfully fixed query-key-values ordering for"
+                     " checkpoint version {}".format(checkpoint_version))
def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True):
    """Load a model checkpoint and return the iteration.
    strict (bool): whether to strictly enforce that the keys in
@@ -308,28 +335,10 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
            mpu.set_virtual_pipeline_model_parallel_rank(i)
            model[i].load_state_dict(state_dict['model%d' % i], strict=strict)

-    # Fix up query/key/value matrix ordering
-    if get_checkpoint_version() < 2.0:
-        checkpoint_version = get_checkpoint_version()
-        for name, param in model.named_parameters():
-            if name.endswith(('.query_key_value.weight', '.query_key_value.bias')):
-                if checkpoint_version == 0:
-                    fixed_param = _transpose_first_dim(param.data, 3, True, model)
-                elif checkpoint_version == 1.0:
-                    fixed_param = _transpose_first_dim(param.data, 3, False, model)
-                else:
-                    print_rank_0(f"Invalid checkpoint version {checkpoint_version}.")
-                    sys.exit()
-                param.data.copy_(fixed_param)
-            if name.endswith(('.key_value.weight', '.key_value.bias')):
-                if checkpoint_version == 0:
-                    fixed_param = _transpose_first_dim(param.data, 2, True, model)
-                elif checkpoint_version == 1.0:
-                    fixed_param = _transpose_first_dim(param.data, 2, False, model)
-                else:
-                    print_rank_0(f"Invalid checkpoint version {checkpoint_version}.")
-                    sys.exit()
-                param.data.copy_(fixed_param)
+    # Fix up query/key/value matrix ordering if needed
+    checkpoint_version = get_checkpoint_version()
+    print_rank_0(f' checkpoint version {checkpoint_version}')
+    fix_query_key_value_ordering(model, checkpoint_version)

    # Optimizer.
    if not release and not args.finetune and not args.no_load_optim:
...
import os
import time

import numpy as np
import torch

from megatron import get_args, get_tokenizer, mpu, print_rank_0
from megatron.data.dataset_utils import create_masked_lm_predictions, \
    pad_and_convert_to_numpy


def get_ict_batch(data_iterator):
    # Items and their type.
    keys = ['query_tokens', 'query_mask',
            'context_tokens', 'context_mask', 'block_data']
    datatype = torch.int64

    # Broadcast data.
    if data_iterator is None:
        data = None
    else:
        data = next(data_iterator)
    data_b = mpu.broadcast_data(keys, data, datatype)

    # Unpack.
    query_tokens = data_b['query_tokens'].long()
    query_mask = data_b['query_mask'] < 0.5
    context_tokens = data_b['context_tokens'].long()
    context_mask = data_b['context_mask'] < 0.5
    block_indices = data_b['block_data'].long()

    return query_tokens, query_mask, \
        context_tokens, context_mask, block_indices
def join_str_list(str_list):
    """Join a list of strings, handling spaces appropriately."""
    result = ""
    for s in str_list:
        if s.startswith("##"):
            result += s[2:]
        else:
            result += " " + s
    return result
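A quick usage sketch with hypothetical WordPiece pieces; note that the helper leaves a leading space, since the first piece also takes the `" " + s` branch:

    pieces = ['the', 'in', '##verse', 'cl', '##oze', 'task']
    assert join_str_list(pieces) == ' the inverse cloze task'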
class BlockSampleData(object):
    """A struct for fully describing a fixed-size block of data as used in REALM

    :param start_idx: for first sentence of the block
    :param end_idx: for last sentence of the block (may be partially truncated in sample construction)
    :param doc_idx: the index of the document from which the block comes in the original indexed dataset
    :param block_idx: a unique integer identifier given to every block.
    """
    def __init__(self, start_idx, end_idx, doc_idx, block_idx):
        self.start_idx = start_idx
        self.end_idx = end_idx
        self.doc_idx = doc_idx
        self.block_idx = block_idx

    def as_array(self):
        return np.array([self.start_idx, self.end_idx,
                         self.doc_idx, self.block_idx]).astype(np.int64)

    def as_tuple(self):
        return self.start_idx, self.end_idx, self.doc_idx, self.block_idx


class BlockSamplesMapping(object):
    def __init__(self, mapping_array):
        # make sure that the array is compatible with BlockSampleData
        assert mapping_array.shape[1] == 4
        self.mapping_array = mapping_array

    def __len__(self):
        return self.mapping_array.shape[0]

    def __getitem__(self, idx):
        """Get the data associated with an indexed sample."""
        sample_data = BlockSampleData(*self.mapping_array[idx])
        return sample_data
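A minimal sketch of how the mapping is consumed, using a hypothetical two-row array (columns are start_idx, end_idx, doc_idx, block_idx):

    import numpy as np

    mapping = BlockSamplesMapping(
        np.array([[0, 3, 7, 0],
                  [3, 6, 7, 1]], dtype=np.int64))
    assert len(mapping) == 2
    assert mapping[1].as_tuple() == (3, 6, 7, 1)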
def get_block_samples_mapping(block_dataset, title_dataset, data_prefix, num_epochs,
                              max_num_samples, max_seq_length, seed, name,
                              use_one_sent_docs=False):
    """Get samples mapping for a dataset over fixed size blocks. This function
    also requires a dataset of the titles for the source documents since their
    lengths must be taken into account.

    :return: samples_mapping (BlockSamplesMapping)
    """
    if not num_epochs:
        if not max_num_samples:
            raise ValueError("Need to specify either max_num_samples "
                             "or num_epochs")
        num_epochs = np.iinfo(np.int32).max - 1
    if not max_num_samples:
        max_num_samples = np.iinfo(np.int64).max - 1

    # Filename of the index mapping
    indexmap_filename = data_prefix
    indexmap_filename += '_{}_indexmap'.format(name)
    if num_epochs != (np.iinfo(np.int32).max - 1):
        indexmap_filename += '_{}ep'.format(num_epochs)
    if max_num_samples != (np.iinfo(np.int64).max - 1):
        indexmap_filename += '_{}mns'.format(max_num_samples)
    indexmap_filename += '_{}msl'.format(max_seq_length)
    indexmap_filename += '_{}s'.format(seed)
    if use_one_sent_docs:
        indexmap_filename += '_1sentok'
    indexmap_filename += '.npy'

    # Build the indexed mapping if it does not exist.
    if mpu.get_data_parallel_rank() == 0 and \
       not os.path.isfile(indexmap_filename):
        print(' > WARNING: could not find index map file {}, building '
              'the indices on rank 0 ...'.format(indexmap_filename))

        # Make sure the types match the helpers input types.
        assert block_dataset.doc_idx.dtype == np.int64
        assert block_dataset.sizes.dtype == np.int32

        # Build samples mapping
        verbose = torch.distributed.get_rank() == 0
        start_time = time.time()
        print_rank_0(' > building samples index mapping for {} ...'.format(name))

        from megatron.data import helpers
        mapping_array = helpers.build_blocks_mapping(
            block_dataset.doc_idx,
            block_dataset.sizes,
            title_dataset.sizes,
            num_epochs,
            max_num_samples,
            max_seq_length - 3,  # account for added tokens
            seed,
            verbose,
            use_one_sent_docs)

        print_rank_0(' > done building samples index mapping')
        np.save(indexmap_filename, mapping_array, allow_pickle=True)
        print_rank_0(' > saved the index mapping in {}'.format(indexmap_filename))
        print_rank_0(' > elapsed time to build and save samples mapping '
                     '(seconds): {:4f}'.format(time.time() - start_time))

    # Make sure all the ranks have built the mapping.
    # This should be a barrier, but the nccl barrier assumes
    # device_index=rank, which is not the case for the model
    # parallel case.
    counts = torch.cuda.LongTensor([1])
    torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
    assert counts[0].item() == torch.distributed.get_world_size(
        group=mpu.get_data_parallel_group())

    # Load indexed dataset.
    print_rank_0(' > loading indexed mapping from {}'.format(indexmap_filename))
    start_time = time.time()
    mapping_array = np.load(indexmap_filename, allow_pickle=True, mmap_mode='r')
    samples_mapping = BlockSamplesMapping(mapping_array)
    print_rank_0('    loaded indexed file in {:3.3f} seconds'.format(
        time.time() - start_time))
    print_rank_0('    total number of samples: {}'.format(mapping_array.shape[0]))

    return samples_mapping
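For illustration, the cache-file naming logic above resolves as follows for hypothetical arguments (num_epochs unset, use_one_sent_docs=True):

    data_prefix, name = 'wiki-text', 'train'          # hypothetical values
    max_num_samples, max_seq_length, seed = 100000, 256, 1234

    filename = '{}_{}_indexmap_{}mns_{}msl_{}s_1sentok.npy'.format(
        data_prefix, name, max_num_samples, max_seq_length, seed)
    assert filename == 'wiki-text_train_indexmap_100000mns_256msl_1234s_1sentok.npy'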
@@ -9,6 +9,16 @@ from megatron import get_args
 from megatron.data.dataset_utils import get_indexed_dataset_
 from megatron.data.realm_dataset_utils import get_block_samples_mapping

+def make_attention_mask(source_block, target_block):
+    """
+    Returns a 2-dimensional (2-D) attention mask
+    :param source_block: 1-D array
+    :param target_block: 1-D array
+    """
+    mask = (target_block[None, :] >= 1) * (source_block[:, None] >= 1)
+    mask = mask.astype(np.int64)
+    # (source_length, target_length)
+    return mask
def get_ict_dataset(use_titles=True, query_in_block_prob=1):
    """Get a dataset which uses block samples mappings to get ICT/block indexing data (via get_block())
@@ -39,7 +49,7 @@ class ICTDataset(Dataset):
    """Dataset containing sentences and their blocks for an inverse cloze task."""

    def __init__(self, name, block_dataset, title_dataset, data_prefix,
                 num_epochs, max_num_samples, max_seq_length, query_in_block_prob,
-                 seed, use_titles=True, use_one_sent_docs=False):
+                 seed, use_titles=True, use_one_sent_docs=False, binary_head=False):
        self.name = name
        self.seed = seed
        self.max_seq_length = max_seq_length
@@ -93,14 +103,20 @@ class ICTDataset(Dataset):
        block = list(itertools.chain(*block))[:self.max_seq_length - title_pad_offset]

        query_tokens, query_pad_mask = self.concat_and_pad_tokens(query)
-        block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)
+        context_tokens, context_pad_mask = self.concat_and_pad_tokens(block, title)
+
+        query_mask = make_attention_mask(query_tokens, query_tokens)
+        context_mask = make_attention_mask(context_tokens, context_tokens)

        block_data = sample_data.as_array()

        sample = {
            'query_tokens': query_tokens,
+           'query_mask': query_mask,
            'query_pad_mask': query_pad_mask,
-           'block_tokens': block_tokens,
-           'block_pad_mask': block_pad_mask,
+           'context_tokens': context_tokens,
+           'context_mask': context_mask,
+           'context_pad_mask': context_pad_mask,
            'block_data': block_data,
        }
...
import os
import sys

import torch

from megatron import get_args, print_rank_0
from megatron import mpu, get_tokenizer
from megatron.checkpointing import fix_query_key_value_ordering
from megatron.checkpointing import get_checkpoint_tracker_filename
from megatron.checkpointing import get_checkpoint_name
from megatron.model.bert_model import bert_position_ids
from megatron.model.enums import AttnMaskType
from megatron.model.language_model import get_language_model
from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal
from megatron.model.utils import scaled_init_method_normal
from .module import MegatronModule
def biencoder_model_provider(only_query_model=False,
                             only_context_model=False,
                             biencoder_shared_query_context_model=False):
    """Build the model."""
    args = get_args()

    assert mpu.get_tensor_model_parallel_world_size() == 1 and \
        mpu.get_pipeline_model_parallel_world_size() == 1, \
        "Model parallel size > 1 not supported for ICT"

    print_rank_0('building BiEncoderModel...')

    # simpler to just keep using 2 tokentypes since
    # the LM we initialize with has 2 tokentypes
    model = BiEncoderModel(
        num_tokentypes=2,
        parallel_output=False,
        only_query_model=only_query_model,
        only_context_model=only_context_model,
        biencoder_shared_query_context_model=\
            biencoder_shared_query_context_model)

    return model
class BiEncoderModel(MegatronModule):
    """Bert-based module for Biencoder model."""

    def __init__(self,
                 num_tokentypes=1,
                 parallel_output=True,
                 only_query_model=False,
                 only_context_model=False,
                 biencoder_shared_query_context_model=False):
        super(BiEncoderModel, self).__init__()
        args = get_args()

        bert_kwargs = dict(
            num_tokentypes=num_tokentypes,
            parallel_output=parallel_output)

        self.biencoder_shared_query_context_model = \
            biencoder_shared_query_context_model
        assert not (only_context_model and only_query_model)
        self.use_context_model = not only_query_model
        self.use_query_model = not only_context_model
        self.biencoder_projection_dim = args.biencoder_projection_dim

        if self.biencoder_shared_query_context_model:
            self.model = PretrainedBertModel(**bert_kwargs)
            self._model_key = 'shared_model'
            self.query_model, self.context_model = self.model, self.model
        else:
            if self.use_query_model:
                # this model embeds (pseudo-)queries - Embed_input in the paper
                self.query_model = PretrainedBertModel(**bert_kwargs)
                self._query_key = 'query_model'
            if self.use_context_model:
                # this model embeds evidence blocks - Embed_doc in the paper
                self.context_model = PretrainedBertModel(**bert_kwargs)
                self._context_key = 'context_model'

    def forward(self, query_tokens, query_attention_mask, query_types,
                context_tokens, context_attention_mask, context_types):
        """Run a forward pass for each of the models and
        return the respective embeddings."""
        if self.use_query_model:
            query_logits = self.embed_text(self.query_model,
                                           query_tokens,
                                           query_attention_mask,
                                           query_types)
        else:
            raise ValueError("Cannot embed query without the query model.")

        if self.use_context_model:
            context_logits = self.embed_text(self.context_model,
                                             context_tokens,
                                             context_attention_mask,
                                             context_types)
        else:
            raise ValueError("Cannot embed block without the block model.")

        return query_logits, context_logits

    @staticmethod
    def embed_text(model, tokens, attention_mask, token_types):
        """Embed a batch of tokens using the model."""
        logits = model(tokens,
                       attention_mask,
                       token_types)
        return logits

    def state_dict_for_save_checkpoint(self, destination=None,
                                       prefix='', keep_vars=False):
        """Save dict with state dicts of each of the models."""
        state_dict_ = {}
        if self.biencoder_shared_query_context_model:
            state_dict_[self._model_key] = \
                self.model.state_dict_for_save_checkpoint(destination,
                                                          prefix,
                                                          keep_vars)
        else:
            if self.use_query_model:
                state_dict_[self._query_key] = \
                    self.query_model.state_dict_for_save_checkpoint(
                        destination, prefix, keep_vars)
            if self.use_context_model:
                state_dict_[self._context_key] = \
                    self.context_model.state_dict_for_save_checkpoint(
                        destination, prefix, keep_vars)
        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Load the state dicts of each of the models."""
        if self.biencoder_shared_query_context_model:
            print_rank_0("Loading shared query-context model")
            self.model.load_state_dict(state_dict[self._model_key],
                                       strict=strict)
        else:
            if self.use_query_model:
                print_rank_0("Loading query model")
                self.query_model.load_state_dict(
                    state_dict[self._query_key], strict=strict)
            if self.use_context_model:
                print_rank_0("Loading context model")
                self.context_model.load_state_dict(
                    state_dict[self._context_key], strict=strict)
    def init_state_dict_from_bert(self):
        """Initialize the state from a pretrained BERT model
        on iteration zero of ICT pretraining."""
        args = get_args()

        if args.bert_load is None:
            print_rank_0("bert-load argument is None")
            return

        tracker_filename = get_checkpoint_tracker_filename(args.bert_load)
        if not os.path.isfile(tracker_filename):
            raise FileNotFoundError("Could not find BERT checkpoint")
        with open(tracker_filename, 'r') as f:
            iteration = int(f.read().strip())
            assert iteration > 0

        checkpoint_name = get_checkpoint_name(args.bert_load, iteration, False)
        if mpu.get_data_parallel_rank() == 0:
            print('global rank {} is loading BERT checkpoint {}'.format(
                torch.distributed.get_rank(), checkpoint_name))

        # Load the checkpoint.
        try:
            state_dict = torch.load(checkpoint_name, map_location='cpu')
        except ModuleNotFoundError:
            from megatron.fp16_deprecated import loss_scaler
            # For backward compatibility.
            print_rank_0(' > deserializing using the old code structure ...')
            sys.modules['fp16.loss_scaler'] = sys.modules[
                'megatron.fp16_deprecated.loss_scaler']
            sys.modules['megatron.fp16.loss_scaler'] = sys.modules[
                'megatron.fp16_deprecated.loss_scaler']
            state_dict = torch.load(checkpoint_name, map_location='cpu')
            sys.modules.pop('fp16.loss_scaler', None)
            sys.modules.pop('megatron.fp16.loss_scaler', None)
        except BaseException:
            print_rank_0('could not load the BERT checkpoint')
            sys.exit()

        checkpoint_version = state_dict.get('checkpoint_version', 0)

        # load the LM state dict into each model
        model_dict = state_dict['model']['language_model']

        if self.biencoder_shared_query_context_model:
            self.model.language_model.load_state_dict(model_dict)
            fix_query_key_value_ordering(self.model, checkpoint_version)
        else:
            if self.use_query_model:
                self.query_model.language_model.load_state_dict(model_dict)
                # give each model the same ict_head to begin with as well
                if self.biencoder_projection_dim > 0:
                    query_proj_state_dict = \
                        self.state_dict_for_save_checkpoint()\
                        [self._query_key]['projection_enc']
                fix_query_key_value_ordering(self.query_model,
                                             checkpoint_version)
            if self.use_context_model:
                self.context_model.language_model.load_state_dict(model_dict)
                if self.query_model is not None and \
                        self.biencoder_projection_dim > 0:
                    self.context_model.projection_enc.load_state_dict(
                        query_proj_state_dict)
                fix_query_key_value_ordering(self.context_model,
                                             checkpoint_version)
class PretrainedBertModel(MegatronModule):
    """BERT-based encoder for queries or contexts used for
    learned information retrieval."""

    def __init__(self, num_tokentypes=2,
                 parallel_output=True):
        super(PretrainedBertModel, self).__init__()

        args = get_args()
        tokenizer = get_tokenizer()
        self.pad_id = tokenizer.pad
        self.biencoder_projection_dim = args.biencoder_projection_dim
        self.parallel_output = parallel_output
        init_method = init_method_normal(args.init_method_std)
        scaled_init_method = scaled_init_method_normal(
            args.init_method_std, args.num_layers)

        self.language_model, self._language_model_key = get_language_model(
            num_tokentypes=num_tokentypes,
            add_pooler=False,
            encoder_attn_mask_type=AttnMaskType.padding,
            init_method=init_method,
            scaled_init_method=scaled_init_method)

        if args.biencoder_projection_dim > 0:
            self.projection_enc = get_linear_layer(args.hidden_size,
                                                   args.biencoder_projection_dim,
                                                   init_method)
            self._projection_enc_key = 'projection_enc'

    def forward(self, input_ids, attention_mask, tokentype_ids=None):
        extended_attention_mask = attention_mask.unsqueeze(1)
        #extended_attention_mask = bert_extended_attention_mask(attention_mask)
        position_ids = bert_position_ids(input_ids)

        lm_output = self.language_model(input_ids,
                                        position_ids,
                                        extended_attention_mask,
                                        tokentype_ids=tokentype_ids)
        # This mask will be used in average-pooling and max-pooling
        pool_mask = (input_ids == self.pad_id).unsqueeze(2)

        # Taking the representation of the [CLS] token of BERT
        pooled_output = lm_output[:, 0, :]

        # Converting to float16 dtype
        pooled_output = pooled_output.to(lm_output.dtype)

        # Output.
        if self.biencoder_projection_dim:
            pooled_output = self.projection_enc(pooled_output)

        return pooled_output

    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                       keep_vars=False):
        """For easy load when model is combined with other heads,
        add an extra key."""
        state_dict_ = {}
        state_dict_[self._language_model_key] = \
            self.language_model.state_dict_for_save_checkpoint(
                destination, prefix, keep_vars)

        if self.biencoder_projection_dim > 0:
            state_dict_[self._projection_enc_key] = \
                self.projection_enc.state_dict(destination, prefix, keep_vars)

        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Customized load."""
        print_rank_0("loading BERT weights")
        self.language_model.load_state_dict(
            state_dict[self._language_model_key], strict=strict)

        if self.biencoder_projection_dim > 0:
            print_rank_0("loading projection head weights")
            self.projection_enc.load_state_dict(
                state_dict[self._projection_enc_key], strict=strict)
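A shape walk-through of PretrainedBertModel.forward, with a random tensor standing in for the language-model output so the sketch runs without Megatron initialization (sizes are hypothetical). Note that pool_mask above is computed for average- and max-pooling but only the [CLS] vector is actually used:

    import torch

    batch, seq, hidden = 4, 256, 768                       # hypothetical sizes
    attention_mask = torch.ones(batch, seq, seq, dtype=torch.bool)

    extended_attention_mask = attention_mask.unsqueeze(1)  # (batch, 1, seq, seq)
    lm_output = torch.randn(batch, seq, hidden)            # stand-in for the LM
    pooled_output = lm_output[:, 0, :]                     # [CLS] representation
    assert pooled_output.shape == (batch, hidden)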
@@ -311,7 +311,7 @@ def setup_model_and_optimizer(model_provider_func):
    # We only support local DDP with multiple micro-batches.
    if get_num_microbatches() > 1:
        assert args.DDP_impl == 'local'
-    if len(model) == 1:
+    if len(model) > 1:
        assert args.DDP_impl == 'local'
    if mpu.get_pipeline_model_parallel_world_size() > 1:
        assert args.DDP_impl == 'local'
@@ -323,6 +323,8 @@ def setup_model_and_optimizer(model_provider_func):
                   'init_state_dict_from_bert'):
            print("Initializing ICT from pretrained BERT model", flush=True)
            module.init_state_dict_from_bert()
+            if args.fp16:
+                optimizer.reload_model_params()

    return model, optimizer, lr_scheduler
...
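The reload_model_params() call matters because init_state_dict_from_bert() overwrites the fp16 model weights in place, while the fp16 optimizer keeps separate fp32 master copies. Conceptually (hypothetical tensors, not Megatron's actual optimizer internals):

    import torch

    model_param = torch.randn(4, 4).half()        # fp16 model weight
    master_param = model_param.float().clone()    # optimizer's fp32 master copy

    # init_state_dict_from_bert() replaces the model weight in place...
    model_param.copy_(torch.randn(4, 4).half())

    # ...so the stale master copy must be refreshed, which is roughly
    # what optimizer.reload_model_params() does for every parameter.
    master_param.copy_(model_param.float())
    assert torch.equal(master_param.half(), model_param)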
@@ -14,6 +14,7 @@
 # limitations under the License.

 """Pretrain BERT for Inverse Cloze Task"""

+import math

 import torch
 import torch.distributed as dist
@@ -23,17 +24,21 @@ from megatron import get_args
 from megatron import print_rank_0
 from megatron import get_timers
 from megatron import mpu
+from megatron.data.biencoder_dataset_utils import get_ict_batch
 from megatron.data.dataset_utils import build_train_valid_test_datasets
+from megatron.model.biencoder_model import biencoder_model_provider
 from megatron.training import pretrain
 from megatron.utils import average_losses_across_data_parallel_group
-from megatron.model.realm_model import general_ict_model_provider
-from megatron.data.realm_dataset_utils import get_ict_batch

 def pretrain_ict_model_provider():
     args = get_args()
-    return general_ict_model_provider(False, False)
+
+    model = biencoder_model_provider(
+        only_context_model=False,
+        only_query_model=False,
+        biencoder_shared_query_context_model=\
+            args.biencoder_shared_query_context_model)
+
+    return model

 def get_group_world_size_rank():
@@ -72,7 +77,6 @@ class AllgatherFromDataParallelRegion(torch.autograd.Function):
         output = output_list[rank].contiguous()
         return output

 def forward_step(data_iterator, model, input_tensor):
     """Forward step."""
     args = get_args()
@@ -80,37 +84,57 @@ def forward_step(data_iterator, model, input_tensor):
     # Get the batch.
     timers('batch-generator').start()
-    query_tokens, query_pad_mask, \
-    block_tokens, block_pad_mask, block_indices = get_ict_batch(data_iterator)
+    query_tokens, query_mask, \
+    context_tokens, context_mask, context_indices = get_ict_batch(data_iterator)
     timers('batch-generator').stop()

+    # Query and Context Types
+    query_types = torch.cuda.LongTensor(*query_tokens.shape).fill_(0)
+    context_types = torch.cuda.LongTensor(*context_tokens.shape).fill_(0)

     # Forward model.
-    query_logits, block_logits = model(query_tokens, query_pad_mask, block_tokens, block_pad_mask)
+    query_logits, context_logits = model(query_tokens, query_mask,
+                                         query_types, context_tokens,
+                                         context_mask, context_types)

     micro_batch_size = query_logits.shape[0]
     # recall we assert that tensor_model_parallel_size == 1
+    assert mpu.get_tensor_model_parallel_world_size() == 1, \
+        "Model parallel size > 1 not supported for ICT"
     global_batch_size = dist.get_world_size() * micro_batch_size

     all_query_logits = AllgatherFromDataParallelRegion.apply(query_logits)
-    all_block_logits = AllgatherFromDataParallelRegion.apply(block_logits)
+    all_context_logits = AllgatherFromDataParallelRegion.apply(context_logits)

-    # scores are inner products between query and block embeddings
-    retrieval_scores = all_query_logits.float().matmul(torch.transpose(all_block_logits, 0, 1).float())
-    softmaxed = F.softmax(retrieval_scores, dim=1)
-    sorted_vals, sorted_indices = torch.topk(softmaxed, k=softmaxed.shape[1], sorted=True)
+    # scores are inner products between query and context embeddings
+    retrieval_scores = torch.matmul(all_query_logits,
+                                    torch.transpose(all_context_logits, 0, 1))
+
+    # scaling the retriever scores
+    if args.retriever_score_scaling:
+        retrieval_scores = retrieval_scores / math.sqrt(args.hidden_size)
+
+    softmax_scores = F.log_softmax(retrieval_scores, dim=1)
+    sorted_vals, sorted_indices = torch.topk(softmax_scores,
+                                             k=softmax_scores.shape[1], sorted=True)

     def topk_accuracy(k):
-        return torch.cuda.FloatTensor([sum([int(i in sorted_indices[i, :k]) for i in range(global_batch_size)]) / global_batch_size])
+        return torch.cuda.FloatTensor([sum([int(i in sorted_indices[i, :k]) \
+            for i in range(global_batch_size)]) / global_batch_size])

-    topk_accs = [topk_accuracy(int(k)) for k in args.report_topk_accuracies]
+    topk_accs = [topk_accuracy(int(k)) for k in args.retriever_report_topk_accuracies]

-    retrieval_loss = torch.nn.CrossEntropyLoss()(retrieval_scores, torch.arange(global_batch_size).long().cuda())
-    retrieval_loss = retrieval_loss.float()
-    averaged_losses = average_losses_across_data_parallel_group([retrieval_loss, *topk_accs])
+    labels = torch.arange(global_batch_size).long().cuda()
+    loss = F.nll_loss(softmax_scores, labels, reduction='mean')
+    reduced_losses = average_losses_across_data_parallel_group([loss, *topk_accs])
+
+    # Scale the retrieval loss
+    loss = loss * mpu.get_data_parallel_world_size()

     # create stats_dict with retrieval loss and all specified top-k accuracies
-    topk_acc_dict = {'top{}_acc'.format(k): v for k, v in zip(args.report_topk_accuracies, averaged_losses[1:])}
-    stats_dict = dict(retrieval_loss=averaged_losses[0], **topk_acc_dict)
+    topk_acc_dict = {'top{}_acc'.format(k): v * 100 for k, v in \
+        zip(args.retriever_report_topk_accuracies, reduced_losses[1:])}
+    stats_dict = dict(loss=reduced_losses[0], **topk_acc_dict)

-    return retrieval_loss, stats_dict
+    return loss, stats_dict
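The new loss is a standard in-batch-negatives objective: every other context in the all-gathered batch serves as a negative for a given query, and the correct context for query i sits at index i. A self-contained sketch with random stand-in embeddings (hypothetical sizes):

    import math
    import torch
    import torch.nn.functional as F

    batch, hidden_size = 8, 768
    query_logits = torch.randn(batch, hidden_size)
    context_logits = torch.randn(batch, hidden_size)

    # Inner-product scores, optionally scaled by 1/sqrt(hidden_size)
    # as --retriever-score-scaling does.
    scores = torch.matmul(query_logits, context_logits.t())
    scores = scores / math.sqrt(hidden_size)

    # Query i should rank context i first.
    labels = torch.arange(batch)
    loss = F.nll_loss(F.log_softmax(scores, dim=1), labels, reduction='mean')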
def train_valid_test_datasets_provider(train_val_test_num_samples):
@@ -129,6 +153,7 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
         short_seq_prob=args.short_seq_prob,
         seed=args.seed,
         skip_warmup=(not args.mmap_warmup),
+        binary_head=False,
         dataset_type='ict')

     print_rank_0("> finished creating BERT ICT datasets ...")
@@ -136,5 +161,7 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):

 if __name__ == "__main__":
-    pretrain(train_valid_test_datasets_provider, pretrain_ict_model_provider, forward_step,
+    pretrain(train_valid_test_datasets_provider,
+             pretrain_ict_model_provider,
+             forward_step,
              args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})