OpenDAS / Megatron-LM / Commits / b9bd1a11

Commit b9bd1a11, authored Jul 13, 2020 by Neel Kant
Commit message: Additional refactoring
Parent: ca0cdfaa

Showing 6 changed files with 69 additions and 162 deletions (+69, -162):
  indexer.py                        +11  -13
  megatron/data/dataset_utils.py     +1   -1
  megatron/data/realm_dataset.py     +0  -115
  megatron/data/realm_index.py      +52  -21
  megatron/tokenizer/tokenizer.py    +5   -5
  megatron/training.py               +0   -7
indexer.py  @ b9bd1a11

 import os
 import sys
 import time

 import torch
 import torch.distributed as dist
 from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP

-from megatron import get_args, get_adlr_autoresume, print_rank_0
+from megatron import get_args
 from megatron import mpu
 from megatron.checkpointing import get_checkpoint_tracker_filename, get_checkpoint_name
 from megatron.data.dataset_utils import get_indexed_dataset_
-from megatron.data.realm_dataset import ICTDataset
-from megatron.data.realm_dataset_utils import BlockSampleData
-from megatron.data.realm_index import detach, BlockData, FaissMIPSIndex
+from megatron.data.ict_dataset import ICTDataset
+from megatron.data.realm_index import detach, BlockData
 from megatron.data.samplers import DistributedBatchSampler
 from megatron.initialize import initialize_megatron
 from megatron.training import get_model
-from pretrain_bert_ict import get_batch, general_ict_model_provider
+from pretrain_ict import get_batch, general_ict_model_provider


 def pprint(*args):
...
@@ -30,17 +25,21 @@ class IndexBuilder(object):
         self.model = None
         self.dataloader = None
         self.block_data = None
+
+        # need to know whether we're using a REALM checkpoint (args.load) or ICT checkpoint
+        assert not (args.load and args.ict_load)
+        self.using_realm_chkpt = args.ict_load is None
         self.load_attributes()
         self.is_main_builder = args.rank == 0
         self.iteration = self.total_processed = 0

     def load_attributes(self):
         """Load the necessary attributes: model, dataloader and empty BlockData"""
         # TODO: handle from_realm_chkpt correctly
-        self.model = load_ict_checkpoint(only_block_model=True, from_realm_chkpt=False)
+        self.model = load_ict_checkpoint(only_block_model=True, from_realm_chkpt=self.using_realm_chkpt)
         self.model.eval()
         self.dataloader = iter(get_one_epoch_dataloader(get_ict_dataset()))
-        self.block_data = BlockData()
+        self.block_data = BlockData(load_from_path=False)

     def track_and_report_progress(self, batch_size):
         """Utility function for tracking progress"""
...

@@ -141,7 +140,6 @@ def get_ict_dataset(use_titles=True, query_in_block_prob=1):
                     num_epochs=1,
                     max_num_samples=None,
                     max_seq_length=args.seq_length,
-                    short_seq_prob=0.0001,  # doesn't matter
                     seed=1,
                     query_in_block_prob=query_in_block_prob,
                     use_titles=use_titles,
...
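For reference, the following is a small standalone sketch, not part of the commit, of the checkpoint-selection rule the new IndexBuilder.__init__ enforces: only one of args.load (a REALM checkpoint) or args.ict_load (an ICT checkpoint) may be set, and the REALM-vs-ICT decision is derived from which one is present. The choose_checkpoint helper and the SimpleNamespace args are illustrative stand-ins, not Megatron code.

# Illustrative sketch only; mirrors the assert / using_realm_chkpt logic above with toy args.
from types import SimpleNamespace

def choose_checkpoint(args):
    # exactly one of the two checkpoint paths may be provided
    assert not (args.load and args.ict_load)
    using_realm_chkpt = args.ict_load is None
    return (args.load if using_realm_chkpt else args.ict_load), using_realm_chkpt

path, from_realm = choose_checkpoint(SimpleNamespace(load='/ckpts/realm', ict_load=None))
print(path, from_realm)  # /ckpts/realm True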
megatron/data/dataset_utils.py  @ b9bd1a11
...

@@ -417,7 +417,6 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
             num_epochs=None,
             max_num_samples=train_valid_test_num_samples[index],
             max_seq_length=max_seq_length,
-            short_seq_prob=short_seq_prob,
             seed=seed)
...

@@ -434,6 +433,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
         dataset = BertDataset(
             indexed_dataset=indexed_dataset,
             masked_lm_prob=masked_lm_prob,
+            short_seq_prob=short_seq_prob,
             **kwargs)
...
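A toy sketch, not part of the commit, of the pattern this file adjusts: shared keyword arguments are assembled once and splatted into each dataset constructor, while dataset-specific options such as short_seq_prob are passed explicitly to the BertDataset call. The make_bert_dataset helper below is hypothetical.

# Hypothetical helper standing in for BertDataset construction.
def make_bert_dataset(indexed_dataset, masked_lm_prob, short_seq_prob, **kwargs):
    return {'kind': 'bert', 'masked_lm_prob': masked_lm_prob,
            'short_seq_prob': short_seq_prob, **kwargs}

kwargs = dict(name='train', max_seq_length=512, seed=1)       # shared arguments
dataset = make_bert_dataset(indexed_dataset='toy', masked_lm_prob=0.15,
                            short_seq_prob=0.1, **kwargs)     # dataset-specific + shared
print(dataset['max_seq_length'])  # 512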
megatron/data/realm_dataset.py  (deleted, 100644 → 0)  @ ca0cdfaa

The entire file is removed in this commit. Its former contents:

import collections
import itertools
import random

import numpy as np
from torch.utils.data import Dataset

from megatron import get_tokenizer
from megatron.data.realm_dataset_utils import BlockSampleData, get_block_samples_mapping, join_str_list


class ICTDataset(Dataset):
    """Dataset containing sentences and their blocks for an inverse cloze task."""
    def __init__(self, name, block_dataset, title_dataset, data_prefix, num_epochs,
                 max_num_samples, max_seq_length, query_in_block_prob, short_seq_prob,
                 seed, use_titles=True, use_one_sent_docs=False):
        self.name = name
        self.seed = seed
        self.max_seq_length = max_seq_length
        self.query_in_block_prob = query_in_block_prob
        self.block_dataset = block_dataset
        self.title_dataset = title_dataset
        self.short_seq_prob = short_seq_prob
        self.rng = random.Random(self.seed)
        self.use_titles = use_titles
        self.use_one_sent_docs = use_one_sent_docs

        self.samples_mapping = get_block_samples_mapping(
            block_dataset, title_dataset, data_prefix, num_epochs,
            max_num_samples, max_seq_length, seed, name, use_one_sent_docs)
        self.tokenizer = get_tokenizer()
        self.vocab_id_list = list(self.tokenizer.inv_vocab.keys())
        self.vocab_id_to_token_list = self.tokenizer.inv_vocab
        self.cls_id = self.tokenizer.cls
        self.sep_id = self.tokenizer.sep
        self.mask_id = self.tokenizer.mask
        self.pad_id = self.tokenizer.pad

    def __len__(self):
        return self.samples_mapping.shape[0]

    def __getitem__(self, idx):
        """Get an ICT example of a pseudo-query and the block of text from which it was extracted"""
        sample_data = self.samples_mapping[idx]
        start_idx, end_idx, doc_idx, block_idx = sample_data.as_tuple()

        if self.use_titles:
            title = self.title_dataset[int(doc_idx)]
            title_pad_offset = 3 + len(title)
        else:
            title = None
            title_pad_offset = 2
        block = [self.block_dataset[i] for i in range(start_idx, end_idx)]
        assert len(block) > 1 or self.use_one_sent_docs or self.query_in_block_prob == 1

        # randint() is inclusive for Python rng
        rand_sent_idx = self.rng.randint(0, len(block) - 1)

        # keep the query in the context query_in_block_prob fraction of the time.
        if self.rng.random() < self.query_in_block_prob:
            query = block[rand_sent_idx].copy()
        else:
            query = block.pop(rand_sent_idx)

        # still need to truncate because blocks are concluded when
        # the sentence lengths have exceeded max_seq_length.
        query = query[:self.max_seq_length - 2]
        block = list(itertools.chain(*block))[:self.max_seq_length - title_pad_offset]

        query_tokens, query_pad_mask = self.concat_and_pad_tokens(query)
        block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)
        block_data = sample_data.as_array()

        sample = {
            'query_tokens': query_tokens,
            'query_pad_mask': query_pad_mask,
            'block_tokens': block_tokens,
            'block_pad_mask': block_pad_mask,
            'block_data': block_data,
        }

        return sample

    def get_block(self, start_idx, end_idx, doc_idx):
        """Get the IDs for an evidence block plus the title of the corresponding document"""
        block = [self.block_dataset[i] for i in range(start_idx, end_idx)]
        title = self.title_dataset[int(doc_idx)]

        block = list(itertools.chain(*block))[:self.max_seq_length - (3 + len(title))]
        block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)

        return block_tokens, block_pad_mask

    def get_null_block(self):
        """Get empty block and title - used in REALM pretraining"""
        block, title = [], []
        block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)

        return block_tokens, block_pad_mask

    def concat_and_pad_tokens(self, tokens, title=None):
        """Concat with special tokens and pad sequence to self.max_seq_length"""
        tokens = list(tokens)
        if title is None:
            tokens = [self.cls_id] + tokens + [self.sep_id]
        else:
            title = list(title)
            tokens = [self.cls_id] + title + [self.sep_id] + tokens + [self.sep_id]
        assert len(tokens) <= self.max_seq_length

        num_pad = self.max_seq_length - len(tokens)
        pad_mask = [1] * len(tokens) + [0] * num_pad
        tokens += [self.pad_id] * num_pad
        return np.array(tokens), np.array(pad_mask)
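As a reference for the deleted class, here is a small self-contained illustration, not part of the commit, of the layout concat_and_pad_tokens produced: [CLS] title [SEP] block [SEP], padded out to max_seq_length, with a pad mask of 1s over real tokens and 0s over padding. The token ids below are toy values, not Megatron's tokenizer ids.

# Illustrative sketch with toy ids (101/102/0 stand in for CLS/SEP/PAD).
import numpy as np

CLS, SEP, PAD = 101, 102, 0
max_seq_length = 12

def concat_and_pad(tokens, title=None):
    # [CLS] (title [SEP])? tokens [SEP], then pad to max_seq_length
    tokens = [CLS] + (list(title) + [SEP] if title is not None else []) + list(tokens) + [SEP]
    assert len(tokens) <= max_seq_length
    num_pad = max_seq_length - len(tokens)
    pad_mask = [1] * len(tokens) + [0] * num_pad
    tokens += [PAD] * num_pad
    return np.array(tokens), np.array(pad_mask)

toks, mask = concat_and_pad([7, 8, 9], title=[5, 6])
print(toks)  # [101 5 6 102 7 8 9 102 0 0 0 0]
print(mask)  # [1 1 1 1 1 1 1 1 0 0 0 0]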
megatron/data/realm_index.py  @ b9bd1a11

 from collections import defaultdict
 import itertools
 import os
 import pickle
...

@@ -8,7 +7,7 @@ import faiss
 import numpy as np
 import torch

-from megatron import get_args, mpu
+from megatron import get_args


 def detach(tensor):
...

@@ -17,7 +16,7 @@ def detach(tensor):
 class BlockData(object):
     """Serializable data structure for holding data for blocks -- embeddings and necessary metadata for REALM"""
-    def __init__(self, block_data_path=None, rank=None):
+    def __init__(self, block_data_path=None, load_from_path=True, rank=None):
         self.embed_data = dict()
         self.meta_data = dict()
         if block_data_path is None:
...

@@ -27,6 +26,9 @@ class BlockData(object):
         self.block_data_path = block_data_path
         self.rank = rank

+        if load_from_path:
+            self.load_from_file()
+
         block_data_name = os.path.splitext(self.block_data_path)[0]
         self.temp_dir_name = block_data_name + '_tmp'
...
@@ -43,18 +45,23 @@ class BlockData(object):
         """
         self.embed_data = dict()

-    @classmethod
-    def load_from_file(cls, fname):
+    def load_from_file(self):
         """Populate members from instance saved to file"""
         print("\n> Unpickling BlockData", flush=True)
-        state_dict = pickle.load(open(fname, 'rb'))
+        state_dict = pickle.load(open(self.block_data_path, 'rb'))
         print(">> Finished unpickling BlockData\n", flush=True)

-        new_index = cls()
-        new_index.embed_data = state_dict['embed_data']
-        new_index.meta_data = state_dict['meta_data']
-        return new_index
+        self.embed_data = state_dict['embed_data']
+        self.meta_data = state_dict['meta_data']

     def add_block_data(self, block_indices, block_embeds, block_metas, allow_overwrite=False):
         """Add data for set of blocks

         :param block_indices: 1D array of unique int ids for the blocks
         :param block_embeds: 2D array of embeddings of the blocks
         :param block_metas: 2D array of metadata for the blocks.
             In the case of REALM this will be [start_idx, end_idx, doc_idx]
         """
         for idx, embed, meta in zip(block_indices, block_embeds, block_metas):
             if not allow_overwrite and idx in self.embed_data:
                 raise ValueError("Unexpectedly tried to overwrite block data")
...
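A small standalone sketch, not part of the commit, contrasting the two loading styles in this hunk: the removed classmethod built and returned a fresh object from an explicit fname, while the new instance method populates self from self.block_data_path, which the constructor now triggers when load_from_path=True. The ToyBlockData class and temporary pickle file below are illustrative only.

# Illustrative sketch with a toy class; mirrors the new instance-based loading.
import os
import pickle
import tempfile

class ToyBlockData:
    def __init__(self, block_data_path=None, load_from_path=True):
        self.embed_data, self.meta_data = {}, {}
        self.block_data_path = block_data_path
        if load_from_path and block_data_path and os.path.exists(block_data_path):
            self.load_from_file()

    def load_from_file(self):
        # populate self from the pickle at self.block_data_path
        with open(self.block_data_path, 'rb') as f:
            state_dict = pickle.load(f)
        self.embed_data = state_dict['embed_data']
        self.meta_data = state_dict['meta_data']

path = os.path.join(tempfile.mkdtemp(), 'blocks.pkl')
with open(path, 'wb') as f:
    pickle.dump({'embed_data': {0: [0.1, 0.2]}, 'meta_data': {0: (0, 2, 7)}}, f)
print(ToyBlockData(path).embed_data)  # {0: [0.1, 0.2]}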
@@ -63,6 +70,7 @@ class BlockData(object):
             self.meta_data[idx] = meta

     def save_shard(self):
         """Save the block data that was created this in this process"""
         if not os.path.isdir(self.temp_dir_name):
             os.makedirs(self.temp_dir_name, exist_ok=True)
...

@@ -104,9 +112,9 @@ class BlockData(object):
 class FaissMIPSIndex(object):
     """Wrapper object for a BlockData which similarity search via FAISS under the hood"""
-    def __init__(self, index_type, embed_size, use_gpu=False):
-        self.index_type = index_type
+    def __init__(self, embed_size, block_data=None, use_gpu=False):
         self.embed_size = embed_size
+        self.block_data = block_data
         self.use_gpu = use_gpu
         self.id_map = dict()
...

@@ -114,10 +122,7 @@ class FaissMIPSIndex(object):
         self._set_block_index()

     def _set_block_index(self):
-        INDEX_TYPES = ['flat_ip']
-        if self.index_type not in INDEX_TYPES:
-            raise ValueError("Invalid index type specified")
-
+        """Create a Faiss Flat index with inner product as the metric to search against"""
         print("\n> Building index", flush=True)
         self.block_mips_index = faiss.index_factory(self.embed_size, 'Flat', faiss.METRIC_INNER_PRODUCT)
...
@@ -129,29 +134,52 @@ class FaissMIPSIndex(object):
             config.useFloat16 = True

             self.block_mips_index = faiss.GpuIndexFlat(res, self.block_mips_index, config)
-            print(">>> Finished building index on GPU {}\n".format(self.block_mips_index.getDevice()), flush=True)
+            print(">> Initialized index on GPU {}".format(self.block_mips_index.getDevice()), flush=True)
         else:
             # CPU index supports IDs so wrap with IDMap
             self.block_mips_index = faiss.IndexIDMap(self.block_mips_index)
-            print(">> Finished building index\n", flush=True)
+            print(">> Initialized index on CPU", flush=True)

+        # if we were constructed with a BlockData, then automatically load it when the FAISS structure is built
+        if self.block_data is not None:
+            self.add_block_embed_data(self.block_data)
+
+    def reset_index(self):
+        """Delete existing index and create anew"""
+        del self.block_mips_index
+
+        # reset the block data so that _set_block_index will reload it as well
+        if self.block_data is not None:
+            block_data_path = self.block_data.block_data_path
+            del self.block_data
+            self.block_data = BlockData.load_from_file(block_data_path)
+
+        self._set_block_index()

     def add_block_embed_data(self, all_block_data):
         """Add the embedding of each block to the underlying FAISS index"""
         # this assumes the embed_data is a dict : {int: np.array<float>}
         block_indices, block_embeds = zip(*all_block_data.embed_data.items())

+        # the embeddings have to be entered in as float32 even though the math internally is done with float16.
+        block_embeds_arr = np.float32(np.array(block_embeds))
+        block_indices_arr = np.array(block_indices)
+
         # faiss GpuIndex doesn't work with IDMap wrapper so store ids to map back with
         if self.use_gpu:
             for i, idx in enumerate(block_indices):
                 self.id_map[i] = idx

+        # we no longer need the embedding data since it's in the index now
+        all_block_data.clear()
+
         if self.use_gpu:
-            self.block_mips_index.add(np.float32(np.array(block_embeds)))
+            self.block_mips_index.add(block_embeds_arr)
         else:
-            self.block_mips_index.add_with_ids(np.float32(np.array(block_embeds)), np.array(block_indices))
+            self.block_mips_index.add_with_ids(block_embeds_arr, block_indices_arr)

         print(">>> Finished adding block data to index", flush=True)

     def search_mips_index(self, query_embeds, top_k, reconstruct=True):
         """Get the top-k blocks by the index distance metric.
...
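A minimal sketch, not part of the commit, of the CPU path that _set_block_index and add_block_embed_data now take: a flat inner-product FAISS index wrapped in IndexIDMap so external block ids come back from search. It assumes the faiss and numpy packages and uses toy random embeddings in place of a BlockData.

# Illustrative sketch with toy data (requires the faiss and numpy packages).
import faiss
import numpy as np

embed_size = 4
index = faiss.index_factory(embed_size, 'Flat', faiss.METRIC_INNER_PRODUCT)
index = faiss.IndexIDMap(index)  # CPU index supports external ids via IDMap

block_embeds_arr = np.float32(np.random.rand(10, embed_size))
block_indices_arr = np.arange(100, 110).astype(np.int64)
index.add_with_ids(block_embeds_arr, block_indices_arr)

query = np.float32(np.random.rand(1, embed_size))
distances, block_indices = index.search(query, 3)
print(block_indices)  # ids drawn from block_indices_arr, e.g. [[104 101 108]]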
@@ -160,12 +188,15 @@ class FaissMIPSIndex(object):
         if False: return [num_queries x k] array of distances, and another for indices
         """
         query_embeds = np.float32(detach(query_embeds))

         with torch.no_grad():
             if reconstruct:
                 # get the vectors themselves
                 top_k_block_embeds = self.block_mips_index.search_and_reconstruct(query_embeds, top_k)
                 return top_k_block_embeds
             else:
                 # get distances and indices of closest vectors
                 distances, block_indices = self.block_mips_index.search(query_embeds, top_k)
                 if self.use_gpu:
                     fresh_indices = np.zeros(block_indices.shape)
...
megatron/tokenizer/tokenizer.py  @ b9bd1a11
...

@@ -31,11 +31,11 @@ def build_tokenizer(args):
     # Select and instantiate the tokenizer.
     assert args.vocab_file is not None
     if args.tokenizer_type == 'BertWordPieceLowerCase':
-        tokenizer = BertWordPieceTokenizer(vocab_file=args.vocab_file, lower_case=True)
+        tokenizer = _BertWordPieceTokenizer(vocab_file=args.vocab_file, lower_case=True)
     elif args.tokenizer_type == 'BertWordPieceCase':
-        tokenizer = BertWordPieceTokenizer(vocab_file=args.vocab_file, lower_case=False)
+        tokenizer = _BertWordPieceTokenizer(vocab_file=args.vocab_file, lower_case=False)
     elif args.tokenizer_type == 'GPT2BPETokenizer':
         assert args.merge_file is not None
         tokenizer = _GPT2BPETokenizer(args.vocab_file, args.merge_file)
...

@@ -124,7 +124,7 @@ class AbstractTokenizer(ABC):
                                   'tokenizer'.format(self.name))


-class BertWordPieceTokenizer(AbstractTokenizer):
+class _BertWordPieceTokenizer(AbstractTokenizer):
     """Original BERT wordpiece tokenizer."""

     def __init__(self, vocab_file, lower_case=True):
...
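A standalone sketch, not Megatron's implementation, of the dispatch build_tokenizer performs after the rename: the public tokenizer_type strings map onto the now-private _BertWordPieceTokenizer with lower_case set accordingly. The stub class below is illustrative only.

# Illustrative stub; the real _BertWordPieceTokenizer wraps the BERT wordpiece vocab.
class _BertWordPieceTokenizer:
    def __init__(self, vocab_file, lower_case=True):
        self.vocab_file, self.lower_case = vocab_file, lower_case

def build_tokenizer_sketch(tokenizer_type, vocab_file):
    if tokenizer_type == 'BertWordPieceLowerCase':
        return _BertWordPieceTokenizer(vocab_file=vocab_file, lower_case=True)
    elif tokenizer_type == 'BertWordPieceCase':
        return _BertWordPieceTokenizer(vocab_file=vocab_file, lower_case=False)
    raise NotImplementedError(tokenizer_type)

print(build_tokenizer_sketch('BertWordPieceCase', 'vocab.txt').lower_case)  # False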
megatron/training.py  @ b9bd1a11
...

@@ -74,11 +74,6 @@ def pretrain(train_valid_test_dataset_provider, model_provider,
     args = get_args()
     timers = get_timers()

-    if args.rank == 0 and args.cased_data_path is not None:
-        import stanza
-        stanza.download('en', processors={'ner': 'conll03'}, dir='stanza')
-
     # Model, optimizer, and learning rate.
     timers('model and optimizer').start()
     model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider)
...

@@ -237,8 +232,6 @@ def setup_model_and_optimizer(model_provider_func):

 def backward_step(optimizer, model, loss):
     """Backward step."""
-    # if args.rank == 0:
-    #     torch.save(lick)
     args = get_args()
     timers = get_timers()
...