Commit 0024a5c6 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'main' of https://github.com/NVIDIA/Megatron-LM

parents b004456b 3db2063b
Pipeline #229 failed with stages
in 0 seconds
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
from .faiss_base import FaissBaseIndex
from .faiss_par_add import FaissParallelAddIndex
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
"""
This class implements a simple, un-optimized wrapper around a Faiss index, that
implements the Index interface (see ..index.py). While this class is
instantiable, it is meant to be extended with optimizations in classes that
inherit from this class (see FaissParAddIndex, for an example).
"""
from datetime import timedelta
import os
import torch
from tqdm import tqdm
from megatron import get_retro_args, print_rank_0
from tools.bert_embedding import BertEmbedder
from tools.retro.external_libs import faiss
from tools.retro.index import Index
from tools.retro.index.utils import num_samples_to_block_ranges
class FaissBaseIndex(Index):
def _train(self, input_data_loader):
'''Train index (rank 0's method).'''
args = get_retro_args()
assert torch.distributed.get_rank() == 0
# Set num threads (torch.distributed reset it to 1).
# faiss.omp_set_num_threads(32)
faiss.omp_set_num_threads(64)
# faiss.omp_set_num_threads(128)
empty_index_path = self.get_empty_index_path()
# Index already exists? -> return.
if os.path.isfile(empty_index_path):
return
# Load data.
inp = input_data_loader()
# Init index.
index = faiss.index_factory(args.retro_index_nfeats,
args.retro_index_str)
# Move to GPU.
index_ivf = faiss.extract_index_ivf(index)
clustering_index = \
faiss.index_cpu_to_all_gpus(faiss.IndexFlatL2(index_ivf.d))
index_ivf.clustering_index = clustering_index
self.c_verbose(index, True)
self.c_verbose(index_ivf, True)
self.c_verbose(index_ivf.quantizer, True)
self.c_verbose(index_ivf.clustering_index, True)
# Train index.
index.train(inp)
# Save index.
faiss.write_index(index, empty_index_path)
def train(self, input_data_loader):
'''Train index.'''
# Single process only.
if torch.distributed.get_rank() == 0:
self._train(input_data_loader)
torch.distributed.barrier()
def _add(self, text_dataset):
'''Add to index (rank 0's method).'''
assert torch.distributed.get_rank() == 0
args = get_retro_args()
dataset_sample_ranges = num_samples_to_block_ranges(len(text_dataset))
# Set num threads (torch.distributed reset it to 1).
faiss.omp_set_num_threads(64)
# Bert embedder.
embedder = BertEmbedder(args.retro_bert_batch_size,
args.retro_bert_max_chunk_length,
args.bert_embedder_type)
# Empty/added index paths.
empty_index_path = self.get_empty_index_path()
added_index_path = self.get_added_index_path()
# Skip adding, if index exists.
if os.path.isfile(added_index_path):
return
# Read trained index.
index = faiss.read_index(empty_index_path)
# Iterate data blocks & add.
for sample_range in tqdm(dataset_sample_ranges, "faiss_base.add"):
# Embed text.
embeds = self.embed_text_dataset_block(
embedder, text_dataset, sample_range)
# Add to index.
index.add(embeds)
# Write index.
faiss.write_index(index, added_index_path)
def add(self, text_dataset):
'''Add to index.'''
# Single process only.
if torch.distributed.get_rank() == 0:
self._add(text_dataset)
# Wait for rank 0.
torch.distributed.barrier()
# Get output index path, for return.
return self.get_added_index_path()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment