"vscode:/vscode.git/clone" did not exist on "2cae2907b6c07f83aa6a17ca5b475df574896e7b"
Commit bcd605f8 authored by Mostofa Patwary's avatar Mostofa Patwary
Browse files

Added code for building embeddings and saving them

parent 31d39ec0
@@ -635,6 +635,9 @@ def _add_data_args(parser):
     group.add_argument('--retriever-seq-length', type=int, default=256,
                        help='Maximum sequence length for the biencoder model '
                        ' for retriever')
+    group.add_argument('--sample-rate', type=float, default=1.0,
+                       help='sample rate for training data. Supposed to be 0 '
+                       ' < sample_rate < 1')
     group.add_argument('--mask-prob', type=float, default=0.15,
                        help='Probability of replacing a token with mask.')
     group.add_argument('--short-seq-prob', type=float, default=0.1,
@@ -704,6 +707,8 @@ def _add_biencoder_args(parser):
                             'ICT dataset')
     group.add_argument('--use-one-sent-docs', action='store_true',
                        help='Whether to use one sentence documents in ICT')
+    group.add_argument('--evidence-data-path', type=str, default=None,
+                       help='Path to Wikipedia Evidence from DPR paper')
     # training
     group.add_argument('--retriever-report-topk-accuracies', nargs='+', type=int,
......
@@ -383,42 +383,42 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
     return iteration


-def load_ict_checkpoint(model, only_query_model=False, only_context_model=False, from_realm_chkpt=False):
-    """selectively load ICT models for indexing/retrieving from ICT or REALM checkpoints"""
+def load_biencoder_checkpoint(model, only_query_model=False,
+        only_context_model=False, custom_load_path=None):
+    """
+    selectively load retrieval models for indexing/retrieving
+    from saved checkpoints
+    """

     args = get_args()
     model = utils.unwrap_model(model)

-    load_path = args.load if from_realm_chkpt else args.ict_load
+    load_path = custom_load_path if custom_load_path is not None else args.load

     tracker_filename = get_checkpoint_tracker_filename(load_path)
     with open(tracker_filename, 'r') as f:
         iteration = int(f.read().strip())
-        # assert iteration > 0

     checkpoint_name = get_checkpoint_name(load_path, iteration, False)
     if mpu.get_data_parallel_rank() == 0:
         print('global rank {} is loading checkpoint {}'.format(
             torch.distributed.get_rank(), checkpoint_name))

     state_dict = torch.load(checkpoint_name, map_location='cpu')
-    ict_state_dict = state_dict['model']
-    print(ict_state_dict)
-    sys.exit()
-    if from_realm_chkpt and mpu.get_data_parallel_rank() == 0:
-        print(" loading ICT state dict from REALM", flush=True)
-        ict_state_dict = ict_state_dict['retriever']['ict_model']
+    ret_state_dict = state_dict['model']

     if only_query_model:
-        ict_state_dict.pop('context_model')
+        ret_state_dict.pop('context_model')
     if only_context_model:
-        ict_state_dict.pop('query_model')
+        ret_state_dict.pop('query_model')

-    model.load_state_dict(ict_state_dict)
+    assert len(model) == 1
+    model[0].load_state_dict(ret_state_dict)
     torch.distributed.barrier()

     if mpu.get_data_parallel_rank() == 0:
         print(' successfully loaded {}'.format(checkpoint_name))

     return model
@@ -4,10 +4,21 @@ import time
 import numpy as np
 import torch

-from megatron import mpu, print_rank_0
-from megatron.data.dataset_utils import create_masked_lm_predictions, pad_and_convert_to_numpy
-from megatron import get_args, get_tokenizer, print_rank_0, mpu
+from megatron import get_args, get_tokenizer, mpu, print_rank_0
+from megatron.data.dataset_utils import create_masked_lm_predictions, \
+    pad_and_convert_to_numpy
+from megatron.data.data_samplers import MegatronPretrainingSampler
+
+
+def make_attention_mask(source_block, target_block):
+    """
+    Returns a 2-dimensional (2-D) attention mask
+    :param source_block: 1-D array
+    :param target_block: 1-D array
+    """
+    mask = (target_block[None, :] >= 1) * (source_block[:, None] >= 1)
+    mask = mask.astype(np.int64)
+    # (source_length, target_length)
+    return mask
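
A quick sanity check of the new make_attention_mask helper, as a minimal sketch with hypothetical token ids (any id >= 1 is treated as a real token, 0 as padding):

    import numpy as np
    src = np.array([101, 2023, 2003, 102, 0, 0])   # 4 real tokens, 2 pads
    tgt = np.array([101, 7592, 102, 0])            # 3 real tokens, 1 pad
    mask = make_attention_mask(src, tgt)
    # mask.shape == (6, 4); mask[i, j] == 1 only where both src[i] and
    # tgt[j] are non-pad, so padded rows and columns are all zero.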
@@ -20,15 +31,17 @@ def get_one_epoch_dataloader(dataset, micro_batch_size=None):
     global_batch_size = micro_batch_size * world_size
     num_workers = args.num_workers

-    sampler = torch.utils.data.SequentialSampler(dataset)
-    # importantly, drop_last must be False to get all the data.
-    assert False, 'DistributedBatchSampler deprecated, change the implementation'
-    from megatron.data.samplers import DistributedBatchSampler
-    batch_sampler = DistributedBatchSampler(sampler,
-                                            batch_size=global_batch_size,
-                                            drop_last=False,
-                                            rank=rank,
-                                            world_size=world_size)
+    # Use megatron's sampler with consumed samples set to 0, as this is only
+    # for evaluation and we don't intend to resume halfway.
+    # Also set drop_last to False, as we don't intend to drop the last batch.
+    batch_sampler = MegatronPretrainingSampler(
+        total_samples=len(dataset),
+        consumed_samples=0,
+        micro_batch_size=args.micro_batch_size,
+        data_parallel_rank=mpu.get_data_parallel_rank(),
+        data_parallel_size=mpu.get_data_parallel_world_size(),
+        drop_last=False)

     return torch.utils.data.DataLoader(dataset,
                                        batch_sampler=batch_sampler,
......
@@ -57,7 +57,7 @@ def build_pretraining_data_loader(dataset, consumed_samples):
 class MegatronPretrainingSampler:

     def __init__(self, total_samples, consumed_samples, micro_batch_size,
-                 data_parallel_rank, data_parallel_size):
+                 data_parallel_rank, data_parallel_size, drop_last=True):
         # Keep a copy of input params for later use.
         self.total_samples = total_samples
         self.consumed_samples = consumed_samples
@@ -65,6 +65,7 @@ class MegatronPretrainingSampler:
         self.data_parallel_rank = data_parallel_rank
         self.micro_batch_times_data_parallel_size = \
             self.micro_batch_size * data_parallel_size
+        self.drop_last = drop_last

         # Sanity checks.
         assert self.total_samples > 0, \
@@ -81,17 +82,26 @@ class MegatronPretrainingSampler:
     def __len__(self):
         return self.total_samples

+    def get_start_end_idx(self):
+        start_idx = self.data_parallel_rank * self.micro_batch_size
+        end_idx = start_idx + self.micro_batch_size
+        return start_idx, end_idx
+
     def __iter__(self):
         batch = []
-        # Last batch if not complete will be dropped.
+        # Last batch will be dropped unless drop_last is set to False.
        for idx in range(self.consumed_samples, self.total_samples):
             batch.append(idx)
             if len(batch) == self.micro_batch_times_data_parallel_size:
-                start_idx = self.data_parallel_rank * self.micro_batch_size
-                end_idx = start_idx + self.micro_batch_size
+                start_idx, end_idx = self.get_start_end_idx()
                 yield batch[start_idx:end_idx]
                 batch = []

+        # Check the last partial batch and yield it if drop_last is not set.
+        if len(batch) > 0 and not self.drop_last:
+            start_idx, end_idx = self.get_start_end_idx()
+            yield batch[start_idx:end_idx]
 class MegatronPretrainingRandomSampler:
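
A minimal sketch of the new drop_last=False behavior, using toy values and assuming the sampler's elided sanity checks pass:

    sampler = MegatronPretrainingSampler(total_samples=10, consumed_samples=0,
                                         micro_batch_size=2,
                                         data_parallel_rank=0,
                                         data_parallel_size=2,
                                         drop_last=False)
    print(list(sampler))
    # [[0, 1], [4, 5], [8, 9]] -- the trailing global batch holds only 2 of
    # its 4 samples but is still yielded to rank 0; with the default
    # drop_last=True it would be silently dropped, which is wrong for an
    # indexing job that must embed every row. Note that on a partial batch,
    # higher ranks may receive a short or empty slice.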
......
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Wikipedia dataset from DPR code for ORQA."""
from abc import ABC
import csv
import numpy as np
import random
import torch
from torch.utils.data import Dataset
from megatron import print_rank_0, get_args, get_tokenizer, mpu
from megatron.data.biencoder_dataset_utils import make_attention_mask
def get_open_retrieval_wiki_dataset():
    args = get_args()
    tokenizer = get_tokenizer()

    dataset = OpenRetrievalEvidenceDataset('2018 Wikipedia from DPR codebase',
                                           'evidence',
                                           args.evidence_data_path,
                                           tokenizer,
                                           args.retriever_seq_length)
    return dataset
def get_open_retrieval_batch(data_iterator):
    # Items and their type.
    keys = ['row_id', 'context', 'context_mask', 'context_types',
            'context_pad_mask']
    datatype = torch.int64

    # Broadcast data.
    data = None if data_iterator is None else next(data_iterator)
    data_b = mpu.broadcast_data(keys, data, datatype)

    # Unpack.
    row_id = data_b['row_id'].long()
    context = data_b['context'].long()

    # TODO: make the context mask a binary one
    context_mask = (data_b['context_mask'] < 0.5)

    context_types = data_b['context_types'].long()
    context_pad_mask = data_b['context_pad_mask'].long()

    return row_id, context, context_mask, context_types, context_pad_mask
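
# Note on the mask conversion above: mpu.broadcast_data() ships every field
# with a single dtype (int64 here), so the attention mask is rebuilt as a
# boolean tensor after the broadcast. After `< 0.5`, True marks the padded
# positions that attention should ignore (the indexer below asserts
# context_mask.dtype == torch.bool). For example, an int64 mask row
# [1, 1, 0] becomes tensor([False, False, True]).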
def build_tokens_types_paddings_from_text(row, tokenizer, max_seq_length):
    """Build token types and paddings, trim if needed, and pad if needed."""

    title_ids = tokenizer.tokenize(row['title'])
    context_ids = tokenizer.tokenize(row['text'])

    # Prepend the title to the context.
    extended_context_ids = title_ids + [tokenizer.sep_id] + context_ids

    context_ids, context_types, context_pad_mask = \
        build_tokens_types_paddings_from_ids(extended_context_ids,
            max_seq_length, tokenizer.cls, tokenizer.sep, tokenizer.pad)

    return context_ids, context_types, context_pad_mask
# noinspection DuplicatedCode
def build_tokens_types_paddings_from_ids(text_ids, max_seq_length,
                                         cls_id, sep_id, pad_id):
    """Build token types and paddings, trim if needed, and pad if needed."""
    enc_ids = []
    tokentypes_enc = []

    # [CLS].
    enc_ids.append(cls_id)
    tokentypes_enc.append(0)

    # A.
    len_src = len(text_ids)
    enc_ids.extend(text_ids)
    tokentypes_enc.extend([0] * len_src)

    # Cap the size.
    if len(enc_ids) > max_seq_length - 1:
        enc_ids = enc_ids[0: max_seq_length - 1]
        tokentypes_enc = tokentypes_enc[0: max_seq_length - 1]

    # [SEP].
    enc_ids.append(sep_id)
    tokentypes_enc.append(0)

    num_tokens_enc = len(enc_ids)

    # Padding.
    padding_length = max_seq_length - len(enc_ids)
    if padding_length > 0:
        enc_ids.extend([pad_id] * padding_length)
        tokentypes_enc.extend([pad_id] * padding_length)

    pad_mask = ([1] * num_tokens_enc) + ([0] * padding_length)
    pad_mask = np.array(pad_mask, dtype=np.int64)

    return enc_ids, tokentypes_enc, pad_mask
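
# Worked example (hypothetical ids: cls_id=101, sep_id=102, pad_id=0,
# max_seq_length=8, text_ids=[7592, 2088, 999]):
#   enc_ids        = [101, 7592, 2088, 999, 102, 0, 0, 0]
#   tokentypes_enc = [0, 0, 0, 0, 0, 0, 0, 0]
#   pad_mask       = [1, 1, 1, 1, 1, 0, 0, 0]
# The padding extends tokentypes_enc with pad_id, which is 0 for BERT-style
# vocabularies, so the token types stay well formed.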
def build_sample(row_id, context_ids, context_types, context_pad_mask):
    """Convert to numpy and return a sample consumed by the batch producer."""

    context_ids = np.array(context_ids, dtype=np.int64)
    context_types = np.array(context_types, dtype=np.int64)
    context_mask = make_attention_mask(context_ids, context_ids)

    sample = ({
        'row_id': row_id,
        'context': context_ids,
        'context_mask': context_mask,
        'context_types': context_types,
        'context_pad_mask': context_pad_mask
    })
    return sample
class OpenRetrievalEvidenceDataset(ABC, Dataset):
    """Open Retrieval Evidence dataset class."""

    def __init__(self, task_name, dataset_name, datapath, tokenizer,
                 max_seq_length):
        # Store inputs.
        self.task_name = task_name
        self.dataset_name = dataset_name
        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length
        print_rank_0(' > building {} dataset for {}:'.format(self.task_name,
                                                             self.dataset_name))
        # Process the files.
        print_rank_0(datapath)
        self.samples, self.id2text = self.process_samples_from_single_path(
            datapath)

        args = get_args()
        if args.sample_rate < 1:  # subsample
            k = int(len(self.samples) * args.sample_rate)
            self.samples = random.sample(self.samples, k)

        print_rank_0('  >> total number of samples: {}'.format(
            len(self.samples)))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        row = self.samples[idx]

        context_ids, context_types, context_pad_mask = \
            build_tokens_types_paddings_from_text(row, self.tokenizer,
                                                  self.max_seq_length)

        sample = build_sample(row['doc_id'],
                              context_ids,
                              context_types,
                              context_pad_mask)
        return sample

    @staticmethod
    def process_samples_from_single_path(filename):
        print_rank_0(' > Processing {} ...'.format(filename))
        total = 0

        rows = []
        id2text = {}

        with open(filename) as tsvfile:
            reader = csv.reader(tsvfile, delimiter='\t')
            next(reader, None)  # skip the header row
            # file format: doc_id, doc_text, title
            for row in reader:
                doc_id = int(row[0])
                text = row[1]
                title = row[2]

                rows.append({'doc_id': doc_id,
                             'text': text,
                             'title': title})

                assert doc_id not in id2text
                id2text[doc_id] = (text, title)

                total += 1
                if total % 100000 == 0:
                    print_rank_0('  > processed {} rows so far ...'.format(
                        total))

        print_rank_0(' >> processed {} samples.'.format(len(rows)))
        return rows, id2text
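
# For reference, a hypothetical snippet of the tab-separated evidence file
# parsed above (one header row, then doc_id, doc_text, title columns; the
# exact header names in the DPR dump are an assumption here):
#
#   id    text                                                   title
#   1     Aaron or Aharon is a prophet and high priest ...       Aaron
#   2     His older sister was Miriam and his younger ...        Aaron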
@@ -15,11 +15,12 @@ def detach(tensor):
 class OpenRetreivalDataStore(object):
-    """Serializable data structure for holding data for blocks -- embeddings
-    and necessary metadata for Retriever"""
+    """
+    Serializable data structure for holding data for blocks --
+    embeddings and necessary metadata for Retriever
+    """

     def __init__(self, embedding_path=None, load_from_path=True, rank=None):
         self.embed_data = dict()
-        #self.meta_data = dict()
         if embedding_path is None:
             args = get_args()
             embedding_path = args.embedding_path
@@ -36,13 +37,13 @@ class OpenRetreivalDataStore(object):
     def state(self):
         return {
             'embed_data': self.embed_data,
-            #'meta_data': self.meta_data,
         }

     def clear(self):
-        """Clear the embedding data structures to save memory.
-        The metadata ends up getting used, and is also much smaller in dimensionality
-        so it isn't really worth clearing.
+        """
+        Clear the embedding data structures to save memory.
+        The metadata ends up getting used, and is also much smaller in
+        dimensionality so it isn't really worth clearing.
         """
         self.embed_data = dict()
@@ -56,35 +57,34 @@ class OpenRetreivalDataStore(object):
         print(">> Finished unpickling BlockData\n", flush=True)
         self.embed_data = state_dict['embed_data']
-        #self.meta_data = state_dict['meta_data']

-    #def add_block_data(self, block_indices, block_embeds, block_metas, allow_overwrite=False):
     def add_block_data(self, row_id, block_embeds, allow_overwrite=False):
-        """Add data for set of blocks
+        """
+        Add data for a set of blocks
         :param row_id: 1D array of unique int ids for the blocks
         :param block_embeds: 2D array of embeddings of the blocks
-        #:param block_metas: 2D array of metadata for the blocks. In the case of retriever this will be [start_idx, end_idx, doc_idx]
-        In the case of REALM this will be [start_idx, end_idx, doc_idx]
         """
-        #for idx, embed, meta in zip(block_indices, block_embeds, block_metas):
         for idx, embed in zip(row_id, block_embeds):
             if not allow_overwrite and idx in self.embed_data:
                 raise ValueError("Unexpectedly tried to overwrite block data")

             self.embed_data[idx] = np.float16(embed)
-            #self.meta_data[idx] = meta

     def save_shard(self):
-        """Save the block data that was created this in this process"""
+        """
+        Save the block data that was created in this process
+        """
         if not os.path.isdir(self.temp_dir_name):
             os.makedirs(self.temp_dir_name, exist_ok=True)

         # save the data for each shard
-        with open('{}/{}.pkl'.format(self.temp_dir_name, self.rank), 'wb') as writer:
+        with open('{}/{}.pkl'.format(self.temp_dir_name, self.rank), 'wb') \
+            as writer:
             pickle.dump(self.state(), writer)

     def merge_shards_and_save(self):
-        """Combine all the shards made using self.save_shard()"""
+        # Combine all the shards made using save_shard
         shard_names = os.listdir(self.temp_dir_name)
         seen_own_shard = False
@@ -99,9 +99,9 @@ class OpenRetreivalDataStore(object):
             old_size = len(self.embed_data)
             shard_size = len(data['embed_data'])

-            # add the shard's data and check to make sure there is no overlap
+            # add the shard's data and check to make sure there
+            # is no overlap
             self.embed_data.update(data['embed_data'])
-            #self.meta_data.update(data['meta_data'])
             assert len(self.embed_data) == old_size + shard_size

         assert seen_own_shard
......
@@ -4,27 +4,32 @@ import torch.distributed as dist
 from megatron import get_args
 from megatron import mpu
-from megatron.checkpointing import load_ict_checkpoint
-from megatron.data.ict_dataset import get_ict_dataset
+from megatron.checkpointing import load_biencoder_checkpoint
+from megatron.data.orqa_wiki_dataset import get_open_retrieval_wiki_dataset
+from megatron.data.orqa_wiki_dataset import get_open_retrieval_batch
 from megatron.data.biencoder_dataset_utils import get_one_epoch_dataloader
 from megatron.data.realm_index import detach, OpenRetreivalDataStore
-from megatron.data.biencoder_dataset_utils import get_ict_batch
 from megatron.model.biencoder_model import biencoder_model_provider
-#from megatron.model.realm_model import general_ict_model_provider
 from megatron.training import get_model


 class IndexBuilder(object):
-    """Object for taking one pass over a dataset and creating a BlockData of its embeddings"""
+    """
+    Object for taking one pass over a dataset and creating a BlockData of its
+    embeddings
+    """

     def __init__(self):
         args = get_args()
         self.model = None
         self.dataloader = None
-        self.block_data = None
+        self.evidence_embedder_obj = None
+        self.biencoder_shared_query_context_model = \
+            args.biencoder_shared_query_context_model

-        # need to know whether we're using a REALM checkpoint (args.load) or ICT checkpoint
+        # need to know whether we're using a REALM checkpoint (args.load)
+        # or ICT checkpoint
         assert not (args.load and args.ict_load)
-        self.using_realm_chkpt = args.ict_load is None
+        #self.using_realm_chkpt = args.ict_load is None

         self.log_interval = args.indexer_log_interval
         self.batch_size = args.indexer_batch_size
@@ -35,62 +40,88 @@ class IndexBuilder(object):
         self.iteration = self.total_processed = 0

     def load_attributes(self):
-        """Load the necessary attributes: model, dataloader and empty BlockData"""
-        model = get_model(lambda: biencoder_model_provider(only_context_model=True))
-        self.model = load_ict_checkpoint(model, only_context_model=True, from_realm_chkpt=self.using_realm_chkpt)
-        sys.exit()
-        self.model.eval()
-        self.dataset = get_ict_dataset()
-        self.dataloader = iter(get_one_epoch_dataloader(self.dataset, self.batch_size))
-        self.block_data = OpenRetreivalDataStore(load_from_path=False)
-        print("load_attributes is done", flush=True)
-        sys.exit()
+        """
+        Load the necessary attributes: model, dataloader and empty BlockData
+        """
+        only_context_model = True
+        if self.biencoder_shared_query_context_model:
+            only_context_model = False
+
+        model = get_model(lambda: biencoder_model_provider(only_context_model \
+            = only_context_model, biencoder_shared_query_context_model = \
+            self.biencoder_shared_query_context_model))
+
+        self.model = load_biencoder_checkpoint(model,
+            only_context_model=only_context_model)
+
+        assert len(self.model) == 1
+        self.model[0].eval()
+
+        self.dataset = get_open_retrieval_wiki_dataset()
+        self.dataloader = iter(get_one_epoch_dataloader(self.dataset, \
+            self.batch_size))
+
+        self.evidence_embedder_obj = OpenRetreivalDataStore( \
+            load_from_path=False)

     def track_and_report_progress(self, batch_size):
-        """Utility function for tracking progress"""
+        """
+        Utility function for tracking progress
+        """
         self.iteration += 1
         self.total_processed += batch_size * self.num_total_builders
         if self.is_main_builder and self.iteration % self.log_interval == 0:
-            print('Batch {:10d} | Total {:10d}'.format(self.iteration, self.total_processed), flush=True)
+            print('Batch {:10d} | Total {:10d}'.format(self.iteration,
+                self.total_processed), flush=True)

     def build_and_save_index(self):
-        """Goes through one epoch of the dataloader and adds all data to this instance's BlockData.
+        """
+        Goes through one epoch of the dataloader and adds all data to this
+        instance's BlockData.

-        The copy of BlockData is saved as a shard, which when run in a distributed setting will be
-        consolidated by the rank 0 process and saved as a final pickled BlockData.
+        The copy of BlockData is saved as a shard, which when run in a
+        distributed setting will be consolidated by the rank 0 process
+        and saved as a final pickled BlockData.
         """
+        assert len(self.model) == 1
+        unwrapped_model = self.model[0]
+        while not hasattr(unwrapped_model, 'embed_text'):
+            unwrapped_model = unwrapped_model.module

         while True:
             try:
                 # batch also has query_tokens and query_pad_data
-                _, _, block_tokens, block_pad_mask, block_sample_data = get_ict_batch(self.dataloader)
+                row_id, context_tokens, context_mask, context_types, \
+                    context_pad_mask = get_open_retrieval_batch( \
+                    self.dataloader)
             except (StopIteration, IndexError):
                 break

-            unwrapped_model = self.model
-            while not hasattr(unwrapped_model, 'embed_block'):
-                unwrapped_model = unwrapped_model.module
+            # TODO: can we add with torch.no_grad() to reduce memory usage

             # detach, separate fields and add to BlockData
-            block_logits = detach(unwrapped_model.embed_block(block_tokens, block_pad_mask))
-            detached_data = detach(block_sample_data)
-
-            # block_sample_data is a 2D array [batch x 4]
-            # with columns [start_idx, end_idx, doc_idx, block_idx] same as class BlockSampleData
-            block_indices = detached_data[:, 3]
-            block_metas = detached_data[:, :3]
-
-            self.block_data.add_block_data(block_indices, block_logits, block_metas)
-            self.track_and_report_progress(batch_size=block_tokens.shape[0])
+            assert context_mask.dtype == torch.bool
+            context_logits = unwrapped_model.embed_text(
+                unwrapped_model.context_model, context_tokens, context_mask,
+                context_types)
+            context_logits = detach(context_logits)
+            row_id = detach(row_id)
+
+            self.evidence_embedder_obj.add_block_data(row_id, context_logits)
+            self.track_and_report_progress(batch_size=len(row_id))

-        # This process signals to finalize its shard and then synchronize with the other processes
-        self.block_data.save_shard()
+        # This process signals to finalize its shard and then synchronize
+        # with the other processes
+        self.evidence_embedder_obj.save_shard()
         torch.distributed.barrier()
         del self.model

         # rank 0 process builds the final copy
         if self.is_main_builder:
-            self.block_data.merge_shards_and_save()
+            self.evidence_embedder_obj.merge_shards_and_save()
             # make sure that every single piece of data was embedded
-            assert len(self.block_data.embed_data) == len(self.dataset)
-        self.block_data.clear()
+            assert len(self.evidence_embedder_obj.embed_data) == \
+                len(self.dataset)
+        self.evidence_embedder_obj.clear()
+
+        # complete building the final copy
+        torch.distributed.barrier()
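
Regarding the TODO in the loop above: the embedding pass is inference-only, so wrapping the forward call in torch.no_grad() should be safe and would avoid storing activations for backprop. A hedged sketch, reusing the names from the loop above:

    with torch.no_grad():
        context_logits = unwrapped_model.embed_text(
            unwrapped_model.context_model, context_tokens, context_mask,
            context_types)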
@@ -34,13 +34,11 @@ from .bert_model import (BertModel,
                          BertModelFirstStage,
                          BertModelIntermediateStage,
                          BertModelLastStage)
-from .realm_model import ICTBertModel
 from .gpt_model import (GPTModel,
                         GPTModelFirstStage,
                         GPTModelIntermediateStage,
                         GPTModelLastStage)
 from .language_model import get_language_model
 from .module import FP16Module
-from .realm_model import ICTBertModel
@@ -44,7 +44,6 @@ from megatron.initialize import initialize_megatron
 from megatron.initialize import write_args_to_tensorboard
 from megatron.learning_rates import AnnealingLR
 from megatron.model import DistributedDataParallel as LocalDDP
-from megatron.model.realm_model import ICTBertModel
 from megatron.utils import check_adlr_autoresume_termination
 from megatron.utils import unwrap_model
 from megatron.data.data_samplers import build_pretraining_data_loader
......
@@ -3,6 +3,7 @@ import sys
 sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
                                 os.path.pardir)))

+from megatron import print_rank_0
 from megatron.indexer import IndexBuilder
 from megatron.initialize import initialize_megatron
@@ -24,9 +25,8 @@ def main():
     initialize_megatron(extra_args_provider=None,
                         args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
     index_builder = IndexBuilder()
-    sys.exit()
     index_builder.build_and_save_index()
+    print_rank_0("Build and save indices: done!")


 if __name__ == "__main__":
     main()
......