OpenDAS / Megatron-LM

Commit b0a3c636, authored Jun 29, 2020 by Neel Kant

Add indexer.py

parent 68283d93

Showing 1 changed file with 185 additions and 0 deletions

indexer.py  0 → 100644  (+185, -0)
import os
import sys
import time

import torch
import torch.distributed as dist
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP

from megatron import get_args, get_adlr_autoresume, print_rank_0
from megatron import mpu
from megatron.checkpointing import get_checkpoint_tracker_filename, get_checkpoint_name
from megatron.data.dataset_utils import get_indexed_dataset_
from megatron.data.realm_dataset import ICTDataset
from megatron.data.realm_dataset_utils import BlockSampleData
from megatron.data.realm_index import detach, BlockData, FaissMIPSIndex
from megatron.data.samplers import DistributedBatchSampler
from megatron.initialize import initialize_megatron
from megatron.training import get_model
from pretrain_bert_ict import get_batch, general_ict_model_provider


def pprint(*args):
    print(*args, flush=True)
class IndexBuilder(object):
    """Object for taking one pass over a dataset and creating a BlockData of its embeddings"""
    def __init__(self):
        args = get_args()
        self.model = None
        self.dataloader = None
        self.block_data = None

        self.load_attributes()
        self.is_main_builder = args.rank == 0
        self.iteration = self.total_processed = 0

    def load_attributes(self):
        """Load the necessary attributes: model, dataloader and empty BlockData"""
        # TODO: handle from_realm_chkpt correctly
        self.model = load_ict_checkpoint(only_block_model=True, from_realm_chkpt=False)
        self.model.eval()
        self.dataloader = iter(get_one_epoch_dataloader(get_ict_dataset()))
        self.block_data = BlockData()

    def track_and_report_progress(self, batch_size):
        """Utility function for tracking progress"""
        self.iteration += 1
        self.total_processed += batch_size
        if self.iteration % 10 == 0:
            print('Batch {:10d} | Total {:10d}'.format(self.iteration, self.total_processed), flush=True)

    def build_and_save_index(self):
        """Goes through one epoch of the dataloader and adds all data to this instance's BlockData.

        The copy of BlockData is saved as a shard, which when run in a distributed setting will be
        consolidated by the rank 0 process and saved as a final pickled BlockData.
        """

        while True:
            try:
                # batch also has query_tokens and query_pad_data
                _, _, block_tokens, block_pad_mask, block_sample_data = get_batch(self.dataloader)
            except:
                break

            # detach, setup and add to BlockData
            unwrapped_model = self.model
            while not hasattr(unwrapped_model, 'embed_block'):
                unwrapped_model = unwrapped_model.module
            block_logits = detach(unwrapped_model.embed_block(block_tokens, block_pad_mask))
            detached_data = detach(block_sample_data)

            block_indices = detached_data[:, 3]
            block_metas = detached_data[:, :3]

            self.block_data.add_block_data(block_indices, block_logits, block_metas)
            self.track_and_report_progress(batch_size=block_tokens.shape[0])

        # This process signals to finalize its shard and then synchronize with the other processes
        self.block_data.save_shard()
        torch.distributed.barrier()
        del self.model

        # rank 0 process builds the final copy
        if self.is_main_builder:
            self.block_data.merge_shards_and_save()
        self.block_data.clear()
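A note on the unwrap loop in build_and_save_index: wrappers such as torchDDP expose the wrapped model through a .module attribute, so the loop peels wrappers off until it reaches the object that actually defines embed_block. A minimal sketch of the pattern follows; it is not part of indexer.py, and the class names are made up for illustration.

import torch.nn as nn

class BlockEncoder(nn.Module):
    """Stand-in for the model that defines embed_block."""
    def embed_block(self, block_tokens, block_pad_mask):
        return block_tokens

class Wrapper(nn.Module):
    """Mimics wrappers like torchDDP, which hold the real model as .module."""
    def __init__(self, module):
        super().__init__()
        self.module = module

wrapped = Wrapper(Wrapper(BlockEncoder()))   # e.g. an FP16 wrapper around DDP around the model

unwrapped_model = wrapped
while not hasattr(unwrapped_model, 'embed_block'):
    unwrapped_model = unwrapped_model.module

assert isinstance(unwrapped_model, BlockEncoder)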
def load_ict_checkpoint(only_query_model=False, only_block_model=False, from_realm_chkpt=False):
    """load ICT checkpoints for indexing/retrieving. Arguments specify which parts of the state dict to actually use."""
    args = get_args()
    model = get_model(lambda: general_ict_model_provider(only_query_model, only_block_model))

    if isinstance(model, torchDDP):
        model = model.module

    load_path = args.load if from_realm_chkpt else args.ict_load

    tracker_filename = get_checkpoint_tracker_filename(load_path)
    with open(tracker_filename, 'r') as f:
        iteration = int(f.read().strip())

    # assert iteration > 0
    checkpoint_name = get_checkpoint_name(load_path, iteration, False)
    if mpu.get_data_parallel_rank() == 0:
        print('global rank {} is loading checkpoint {}'.format(
            torch.distributed.get_rank(), checkpoint_name))

    state_dict = torch.load(checkpoint_name, map_location='cpu')
    ict_state_dict = state_dict['model']
    if from_realm_chkpt:
        print(">>>> Attempting to get ict state dict from realm", flush=True)
        ict_state_dict = ict_state_dict['retriever']['ict_model']

    if only_query_model:
        ict_state_dict.pop('context_model')
    if only_block_model:
        ict_state_dict.pop('question_model')

    model.load_state_dict(ict_state_dict)
    torch.distributed.barrier()

    if mpu.get_data_parallel_rank() == 0:
        print(' successfully loaded {}'.format(checkpoint_name))

    return model
def get_ict_dataset(use_titles=True, query_in_block_prob=1):
    """Get a dataset which uses block samples mappings to get ICT/block indexing data"""
    args = get_args()
    block_dataset = get_indexed_dataset_(args.data_path, 'mmap', True)
    titles_dataset = get_indexed_dataset_(args.titles_data_path, 'mmap', True)

    kwargs = dict(
        name='full',
        block_dataset=block_dataset,
        title_dataset=titles_dataset,
        data_prefix=args.data_path,
        num_epochs=1,
        max_num_samples=None,
        max_seq_length=args.seq_length,
        short_seq_prob=0.0001,  # doesn't matter
        seed=1,
        query_in_block_prob=query_in_block_prob,
        use_titles=use_titles,
        use_one_sent_docs=True
    )
    dataset = ICTDataset(**kwargs)
    return dataset
def get_one_epoch_dataloader(dataset, batch_size=None):
    """Specifically one epoch to be used in an indexing job."""
    args = get_args()

    world_size = mpu.get_data_parallel_world_size()
    rank = mpu.get_data_parallel_rank()
    if batch_size is None:
        batch_size = args.batch_size
    global_batch_size = batch_size * world_size
    num_workers = args.num_workers

    sampler = torch.utils.data.SequentialSampler(dataset)
    # importantly, drop_last must be False to get all the data.
    batch_sampler = DistributedBatchSampler(sampler,
                                            batch_size=global_batch_size,
                                            drop_last=False,
                                            rank=rank,
                                            world_size=world_size)

    return torch.utils.data.DataLoader(dataset,
                                       batch_sampler=batch_sampler,
                                       num_workers=num_workers,
                                       pin_memory=True)
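The drop_last=False comment above is worth spelling out: with drop_last=True the trailing partial batch would be discarded, so some blocks would never be embedded or indexed. A small illustration using torch's stock BatchSampler follows; it is not part of indexer.py, and assumes only that megatron.data.samplers.DistributedBatchSampler honors the same drop_last convention.

from torch.utils.data import SequentialSampler, BatchSampler

samples = list(range(10))   # pretend these are 10 blocks to embed
sampler = SequentialSampler(samples)

dropped = [i for batch in BatchSampler(sampler, batch_size=4, drop_last=True) for i in batch]
kept = [i for batch in BatchSampler(sampler, batch_size=4, drop_last=False) for i in batch]

print(len(dropped))  # 8  -> the last two blocks are silently skipped
print(len(kept))     # 10 -> every block is embedded exactly once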
if __name__ == "__main__":
    # This usage is for basic (as opposed to realm async) indexing jobs.
    initialize_megatron(extra_args_provider=None,
                        args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})

    index_builder = IndexBuilder()
    index_builder.build_and_save_index()
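FaissMIPSIndex is imported above but never used in this file; the BlockData shards written by this script are meant to back a maximum inner product search over the block embeddings. A minimal sketch of that kind of search using the faiss library directly, rather than Megatron's wrapper, is below; the embedding size and arrays are made up for illustration.

import numpy as np
import faiss

embed_size = 128   # illustrative; the real size comes from the model configuration
block_embeds = np.random.rand(1000, embed_size).astype('float32')   # stand-in for BlockData embeddings
query_embeds = np.random.rand(4, embed_size).astype('float32')      # stand-in for query embeddings

index = faiss.IndexFlatIP(embed_size)               # exact inner-product (MIPS) index
index.add(block_embeds)
scores, block_ids = index.search(query_embeds, 5)   # top-5 blocks per query
print(block_ids.shape)                              # (4, 5)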