OpenDAS / Megatron-LM / Commits / 183ad176

Commit 183ad176, authored May 03, 2020 by Neel Kant

    Refactor to build BlockData, FaissMIPSIndex, RandProjectLSHIndex

parent 0104f910
Showing 3 changed files with 194 additions and 273 deletions (+194 -273)

    hashed_index.py                  +16  -268
    megatron/data/realm_dataset.py  +176    -4
    pretrain_realm.py                 +2    -1
hashed_index.py
View file @ 183ad176
-from collections import defaultdict
-import os
-import pickle
-import shutil
-import numpy as np
 import torch
 from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
...
@@ -11,7 +5,8 @@ from megatron import get_args
 from megatron import mpu
 from megatron.checkpointing import get_checkpoint_tracker_filename, get_checkpoint_name
 from megatron.data.bert_dataset import get_indexed_dataset_
-from megatron.data.ict_dataset import InverseClozeDataset
+from megatron.data.realm_dataset import InverseClozeDataset
+from megatron.data.realm_index import BlockData, RandProjectionLSHIndex
 from megatron.data.samplers import DistributedBatchSampler
 from megatron.initialize import initialize_megatron
 from megatron.model import REALMRetriever
...
@@ -23,253 +18,6 @@ def detach(tensor):
     return tensor.detach().cpu().numpy()
-class HashedIndex(object):
-    """Class for holding hashed data"""
-    def __init__(self, embed_size, num_buckets, whiten=False, seed=0):
-        np.random.seed(seed)
-        self.block_data = defaultdict(list)
-        self.hash_data = defaultdict(list)
-
-        hash_matrix = 2 * np.random.rand(embed_size, int(num_buckets / 2)) - 1
-        self.hash_matrix = hash_matrix / np.linalg.norm(hash_matrix, axis=0).reshape(1, -1)
-        self.embed_mean = None
-        self.embed_whitener = None
-        self.whiten = whiten
-
-        # alsh
-        self.m = 5
-        self.u = 0.99
-        self.max_norm = None
-        self.block_index = None
-
-    def state(self):
-        state = {
-            'block_data': self.block_data,
-            'hash_data': self.hash_data,
-            'hash_matrix': self.hash_matrix,
-            'embed_mean': self.embed_mean,
-            'embed_whitener': self.embed_whitener,
-        }
-        return state
-
-    def get_block_bucket(self, hash):
-        return self.hash_data[hash]
-
-    def get_block_embed(self, block_idx):
-        return self.block_data[block_idx]
-
-    def hash_embeds(self, embeds, block_data=None):
-        """Hash a tensor of embeddings using a random projection matrix"""
-        embed_scores_pos = torch.matmul(embeds, torch.cuda.FloatTensor(self.hash_matrix))
-        embed_scores = torch.cat((embed_scores_pos, -embed_scores_pos), axis=1)
-        embed_hashes = detach(torch.argmax(embed_scores, axis=1))
-        if block_data is not None:
-            for hash, indices in zip(embed_hashes, block_data):
-                self.hash_data[hash].append(indices)
-        return embed_hashes
-
-    def assign_block_embeds(self, block_indices, block_embeds, allow_overwrite=False):
-        """Assign the embeddings for each block index into a hash map"""
-        for idx, embed in zip(block_indices, block_embeds):
-            if not allow_overwrite and int(idx) in self.block_data:
-                raise ValueError("Attempted to overwrite a read-only HashedIndex")
-            self.block_data[int(idx)] = np.float16(embed)
-
-    def save_shard(self, rank):
-        dir_name = 'block_hash_data'
-        if not os.path.isdir(dir_name):
-            os.mkdir(dir_name)
-        # save the data for each shard
-        with open('{}/{}.pkl'.format(dir_name, rank), 'wb') as data_file:
-            pickle.dump(self.state(), data_file)
-
-    def consolidate_shards_and_save(self, ignore_shard=0):
-        """Combine all the shards made using self.save_shard()"""
-        dir_name = 'block_hash_data'
-        fnames = os.listdir(dir_name)
-        for fname in fnames:
-            with open('{}/{}'.format(dir_name, fname), 'rb') as f:
-                data = pickle.load(f)
-                assert np.array_equal(data['hash_matrix'], self.hash_matrix)
-                old_size = len(self.block_data)
-                shard_size = len(data['block_data'])
-                self.block_data.update(data['block_data'])
-                assert (len(self.block_data) == old_size + shard_size) or (str(ignore_shard) in fname)
-                if not self.whiten:
-                    for bucket, items in data['hash_data'].items():
-                        self.hash_data[bucket].extend(items)
-
-        if self.whiten:
-            self.whiten_block_embeds()
-
-        args = get_args()
-        with open(args.hash_data_path, 'wb') as final_file:
-            pickle.dump(self.state(), final_file)
-        shutil.rmtree(dir_name, ignore_errors=True)
-
-    def clear(self):
-        """Clear the data structures to save memory"""
-        self.block_data = dict()
-        self.hash_data = defaultdict(list)
-
-    def whiten_block_embeds(self):
-        """Transform all block embeds to have zero mean and unit covariance
-        when treated as samples from a distribution"""
-        block_idx, all_embeds = zip(*self.block_data.items())
-        arr_embeds = np.transpose(np.array(all_embeds))
-        mean = np.mean(arr_embeds, axis=1).reshape(-1, 1)
-        centered = arr_embeds - mean
-        inv_cov = np.linalg.inv(np.cov(arr_embeds))
-        whitener = np.transpose(np.linalg.cholesky(inv_cov))
-        whitened = np.float16(np.transpose(whitener.dot(centered)))
-
-        self.embed_mean = mean.reshape(-1)
-        self.embed_whitener = whitener
-        self.block_data = dict(zip(block_idx, list(whitened)))
-        self.hash_data = defaultdict(list)
-
-        batch_size = 16384
-        i = 0
-        args = get_args()
-        with torch.no_grad():
-            hashing_tensor = torch.cuda.HalfTensor(self.hash_matrix)
-            while True:
-                if args.debug:
-                    print(i, flush=True)
-                batch_slice = slice(i * batch_size, (i + 1) * batch_size)
-                batch_embed = torch.cuda.HalfTensor(whitened[batch_slice])
-                batch_block_idx = block_idx[batch_slice]
-                if len(batch_block_idx) == 0:
-                    break
-                hash_scores_pos = torch.matmul(batch_embed, hashing_tensor)
-                embed_scores = torch.cat((hash_scores_pos, -hash_scores_pos), axis=1)
-                embed_hashes = detach(torch.argmax(embed_scores, axis=1))
-                for idx, hash in zip(batch_block_idx, list(embed_hashes)):
-                    # [int] instead of [array<int>] since this is just for analysis rn
-                    self.hash_data[hash].append(idx)
-                i += 1
-
-    def create_block_data_index(self):
-        import faiss
-        self.block_idx, block_embeds = zip(*self.block_data.items())
-        block_embeds = np.array(block_embeds)
-        alsh_preprocessed_blocks = self.alsh_block_preprocess_fn()
-        index = faiss.IndexFlatL2(alsh_preprocessed_blocks.shape[1])
-        index.add(alsh_preprocessed_blocks)
-        print('Total blocks in index: ', index.ntotal)
-        self.block_index = index
-
-    def get_norm_powers_and_halves_array(self, embeds):
-        norm = np.linalg.norm(embeds, axis=1)
-        norm_powers = [np.multiply(norm, norm)]  # squared L2 norms of all
-        for i in range(self.m - 1):
-            norm_powers.append(np.multiply(norm_powers[-1], norm_powers[-1]))
-        # [num_blocks x self.m]
-        norm_powers = np.transpose(np.array(norm_powers))
-        halves_array = 0.5 * np.ones(norm_powers.shape)
-        return norm_powers, halves_array
-
-    def alsh_block_preprocess_fn(self):
-        block_idx, block_embeds = zip(*self.block_data.items())
-        block_embeds = np.array(block_embeds)
-        if self.max_norm is None:
-            self.max_norm = max(np.linalg.norm(block_embeds, axis=1))
-        if self.max_norm > 1:
-            block_embeds = self.u / self.max_norm * block_embeds
-        norm_powers, halves_array = self.get_norm_powers_and_halves_array(block_embeds)
-        # P'(S(x)) for all x in block_embeds
-        return np.float32(np.concatenate((block_embeds, norm_powers, halves_array), axis=1))
-
-    def alsh_query_preprocess_fn(self, query_embeds):
-        max_norm = max(np.linalg.norm(query_embeds, axis=1))
-        if max_norm > 1:
-            query_embeds = self.u / max_norm * query_embeds
-        norm_powers, halves_array = self.get_norm_powers_and_halves_array(query_embeds)
-        # Q'(S(x)) for all x in query_embeds
-        return np.float32(np.concatenate((query_embeds, halves_array, norm_powers), axis=1))
-
-    def exact_mips_equals(self, query_embeds, norm_blocks):
-        """For each query, determine whether the mips block is in the correct hash bucket"""
-        shuffled_block_idx, block_embeds = zip(*self.block_data.items())
-        if norm_blocks:
-            block_embeds = block_embeds / np.linalg.norm(block_embeds, axis=1).reshape(-1, 1)
-
-        with torch.no_grad():
-            # get hashes for the queries
-            hash_scores_pos = torch.matmul(torch.cuda.HalfTensor(query_embeds),
-                                           torch.cuda.HalfTensor(self.hash_matrix))
-            hash_scores = torch.cat((hash_scores_pos, -hash_scores_pos), axis=1)
-            query_hashes = detach(torch.argmax(hash_scores, axis=1))
-
-            # [num_query x num_blocks]
-            inner_products = torch.matmul(torch.cuda.HalfTensor(query_embeds),
-                                          torch.cuda.HalfTensor(np.transpose(np.array(block_embeds))))
-            max_inner_product_idxes = detach(torch.argmax(inner_products, axis=1))
-            best_blocks = [self.block_data[shuffled_block_idx[idx]] for idx in max_inner_product_idxes]
-            best_blocks_tensor = torch.cuda.HalfTensor(np.array(best_blocks))
-
-            # bb = best_blocks
-            bb_hash_scores_pos = torch.matmul(best_blocks_tensor, torch.cuda.HalfTensor(self.hash_matrix))
-            bb_hash_scores = torch.cat((bb_hash_scores_pos, -bb_hash_scores_pos), axis=1)
-            best_block_hashes = detach(torch.argmax(bb_hash_scores, axis=1))
-
-        print('Query hashes: ', query_hashes)
-        print('Block hashes: ', best_block_hashes)
-        equal_arr = np.equal(query_hashes, best_block_hashes).astype(int)
-        # array of zeros and ones which can be used for counting success
-        return equal_arr
-
-    def exact_mips_test(self, num_queries, whitened, norm_blocks, alsh):
-        if whitened:
-            if self.embed_mean is None:
-                self.whiten_block_embeds()
-            query_embeds = np.random.multivariate_normal(np.zeros(128), np.eye(128), num_queries)
-            query_embeds = query_embeds / np.linalg.norm(query_embeds, axis=1).reshape(-1, 1)
-            if alsh:
-                if self.block_index is None:
-                    self.create_block_data_index()
-                alsh_queries = self.alsh_query_preprocess_fn(query_embeds)
-                neighbor_ids, distances = self.block_index.search(alsh_queries, 5)
-                print('DONE')
-                return
-        else:
-            block_idx, all_embeds = zip(*self.block_data.items())
-            arr_embeds = np.transpose(np.array(all_embeds))
-            mean = np.mean(arr_embeds, axis=1).reshape(-1, 1)
-            cov = np.cov(arr_embeds)
-            query_embeds = np.random.multivariate_normal(mean, cov, num_queries)
-
-        equal_arr = self.exact_mips_equals(query_embeds, norm_blocks)
-        print("Num correct: ", sum(equal_arr), " Fraction correct: ", sum(equal_arr) / equal_arr.size)
-        print(equal_arr)
-
-    @classmethod
-    def load_from_file(cls, fname):
-        print(" > Unpickling block hash data")
-        state_dict = pickle.load(open(fname, 'rb'))
-        print(" > Finished unpickling")
-        hash_matrix = state_dict['hash_matrix']
-        new_index = HashedIndex(hash_matrix.shape[0], hash_matrix.shape[1] * 2)
-        new_index.block_data = state_dict['block_data']
-        new_index.hash_data = state_dict['hash_data']
-        new_index.embed_mean = state_dict.get('embed_mean')
-        new_index.embed_whitener = state_dict.get('embed_whitener')
-        new_index.hash_matrix = hash_matrix
-        return new_index
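Editor's note: the bucketing in hash_embeds() above (and, presumably, in the new RandProjectionLSHIndex) is sign-based random projection LSH. A point is assigned to the random direction it aligns with most strongly; concatenating the scores with their negation before the argmax means buckets k and k + num_buckets/2 cover the two signs of direction k. A minimal numpy sketch, not part of the commit:

import numpy as np

np.random.seed(0)
embed_size, num_buckets = 128, 32

# random directions, normalized column-wise as in HashedIndex.__init__
hash_matrix = 2 * np.random.rand(embed_size, num_buckets // 2) - 1
hash_matrix /= np.linalg.norm(hash_matrix, axis=0).reshape(1, -1)

embeds = np.random.randn(5, embed_size)
scores = embeds @ hash_matrix                                   # [5 x 16]
buckets = np.argmax(np.concatenate((scores, -scores), axis=1), axis=1)
print(buckets)                                                  # bucket ids in [0, 32)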
 def test_retriever():
     initialize_megatron(extra_args_provider=None,
                         args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
...
@@ -317,42 +65,42 @@ def main():
     model.eval()
     dataset = get_ict_dataset()
     data_iter = iter(get_one_epoch_dataloader(dataset))
-    hashed_index = HashedIndex(embed_size=128, num_buckets=32, whiten=True)
+    all_block_data = BlockData()
+    hashed_index = RandProjectionLSHIndex(embed_size=128, num_buckets=32, whiten=True)

     i = 1
     total = 0
-    whiten = False
     while True:
         try:
             query_tokens, query_pad_mask, \
-            block_tokens, block_pad_mask, block_indices = get_batch(data_iter)
+            block_tokens, block_pad_mask, block_index_data = get_batch(data_iter)
         except:
             break

-        block_indices = detach(block_indices)
-        block_logits = model(None, None, block_tokens, block_pad_mask, only_block=True)
-        if not whiten:
-            hashed_index.hash_embeds(block_logits, block_indices)
-        hashed_index.assign_block_embeds(block_indices[:, 3], detach(block_logits))
-        total += block_indices.shape[0]
+        block_index_data = detach(block_index_data)
+        block_indices = block_index_data[:, 3]
+        block_meta = block_index_data[:, :3]
+
+        # If whitened, then hashing needs to be done after whitening the block embeds
+        block_logits = model(None, None, block_tokens, block_pad_mask, only_block=True)
+        # which is done in consolidate_shards_and_save()
+        all_block_data.add_block_data(block_indices, block_logits, block_meta)
+        total += block_indices.size

         i += 1
         if i % 20 == 0:
             print('Batch {:10d} | Total {:10d}'.format(i, total), flush=True)
             if args.debug:
                 break

-    hashed_index.save_shard(args.rank)
+    all_block_data.save_shard(args.rank)
     torch.distributed.barrier()
     del model

     if args.rank == 0:
-        hashed_index.consolidate_shards_and_save()
+        all_block_data.consolidate_shards_and_save()
+        hashed_index.hash_whitened_block_embeds(all_block_data)
+        hashed_index.save_to_file()
     else:
-        hashed_index.clear()
+        all_block_data.clear()

 def load_ict_checkpoint(only_query_model=False, only_block_model=False, no_grad=False):
...
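Editor's note: alsh_block_preprocess_fn / alsh_query_preprocess_fn above implement the asymmetric-LSH (ALSH) reduction of maximum inner product search to L2 nearest neighbor in the style of Shrivastava & Li: blocks get norm-power columns plus halves, queries get the mirrored layout, and after scaling all norms below u < 1 the transformed squared distance equals a per-query constant minus 2·x·q plus a ||x||^(2^(m+1)) term that vanishes as m grows. A minimal numpy check of that identity (synthetic data, unit-norm query assumed; not part of the commit):

import numpy as np

np.random.seed(0)
m, u = 5, 0.99
blocks = np.random.randn(1000, 16)
query = np.random.randn(16)
query /= np.linalg.norm(query)          # assume unit-norm queries

def norm_powers_and_halves(x):
    powers = [np.linalg.norm(x, axis=1) ** 2]
    for _ in range(m - 1):
        powers.append(powers[-1] ** 2)  # ||x||^2, ||x||^4, ..., ||x||^(2^m)
    powers = np.transpose(np.array(powers))
    return powers, 0.5 * np.ones(powers.shape)

scaled = u / np.linalg.norm(blocks, axis=1).max() * blocks
p_norms, p_halves = norm_powers_and_halves(scaled)
P = np.concatenate((scaled, p_norms, p_halves), axis=1)              # P'(S(x))
q_norms, q_halves = norm_powers_and_halves(query[None, :])
Q = np.concatenate((query[None, :], q_halves, q_norms), axis=1)      # Q'(S(q))

# ||P - Q||^2 + 2*(x . q) - ||x||^(2^(m+1)) should be constant across blocks,
# so minimizing the transformed L2 distance maximizes the inner product
# up to the vanishing norm-power term.
dist2 = np.sum((P - Q) ** 2, axis=1)
residual = dist2 + 2 * scaled @ query - np.linalg.norm(scaled, axis=1) ** (2 ** (m + 1))
print(np.allclose(residual, residual[0]))                            # True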
megatron/data/realm_dataset.py
View file @ 183ad176
 import itertools
+import os
+import random
+import time

 import numpy as np
 import spacy
+import torch
+from torch.utils.data import Dataset

-from megatron import get_tokenizer
+from megatron import get_tokenizer, print_rank_0, mpu
-from megatron.data.bert_dataset import BertDataset, get_samples_mapping_
+from megatron.data.bert_dataset import BertDataset
 from megatron.data.dataset_utils import create_masked_lm_predictions, pad_and_convert_to_numpy

-qa_nlp = spacy.load('en_core_web_lg')
+# qa_nlp = spacy.load('en_core_web_lg')
+qa_nlp = None

 class RealmDataset(BertDataset):
     """Dataset containing simple masked sentences for masked language modeling.
...
@@ -74,3 +79,170 @@ def spacy_ner(block_text):
         answers.append(str(ent.text))

     candidates['starts'] = starts
     candidates['answers'] = answers
+class InverseClozeDataset(Dataset):
+    """Dataset containing sentences and their blocks for an inverse cloze task."""
+    def __init__(self, name, block_dataset, title_dataset, data_prefix,
+                 num_epochs, max_num_samples, max_seq_length,
+                 short_seq_prob, seed):
+        self.name = name
+        self.seed = seed
+        self.max_seq_length = max_seq_length
+        self.block_dataset = block_dataset
+        self.title_dataset = title_dataset
+        self.short_seq_prob = short_seq_prob
+        self.rng = random.Random(self.seed)
+        self.samples_mapping = self.get_samples_mapping(data_prefix, num_epochs, max_num_samples)
+
+        self.tokenizer = get_tokenizer()
+        self.vocab_id_list = list(self.tokenizer.inv_vocab.keys())
+        self.vocab_id_to_token_list = self.tokenizer.inv_vocab
+        self.cls_id = self.tokenizer.cls
+        self.sep_id = self.tokenizer.sep
+        self.mask_id = self.tokenizer.mask
+        self.pad_id = self.tokenizer.pad
+
+    def __len__(self):
+        return self.samples_mapping.shape[0]
+
+    def __getitem__(self, idx):
+        start_idx, end_idx, doc_idx, block_idx = self.samples_mapping[idx]
+        title = list(self.title_dataset[int(doc_idx)])
+        block = [list(self.block_dataset[i]) for i in range(start_idx, end_idx)]
+        assert len(block) > 1
+
+        # avoid selecting the first or last sentence to be the query.
+        if len(block) == 2:
+            rand_sent_idx = int(self.rng.random() > 0.5)
+        else:
+            rand_sent_idx = self.rng.randint(1, len(block) - 2)
+
+        # keep the query in the context 10% of the time.
+        if self.rng.random() < 1:
+            query = block[rand_sent_idx].copy()
+        else:
+            query = block.pop(rand_sent_idx)
+
+        # still need to truncate because blocks are concluded when
+        # the sentence lengths have exceeded max_seq_length.
+        query = query[:self.max_seq_length - 2]
+        block = list(itertools.chain(*block))[:self.max_seq_length - (3 + len(title))]
+
+        query_tokens, query_pad_mask = self.concat_and_pad_tokens(query)
+        block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)
+
+        sample = {
+            'query_tokens': np.array(query_tokens),
+            'query_pad_mask': np.array(query_pad_mask),
+            'block_tokens': np.array(block_tokens),
+            'block_pad_mask': np.array(block_pad_mask),
+            'block_data': np.array([start_idx, end_idx, doc_idx, block_idx]).astype(np.int64)
+        }
+
+        return sample
+
+    def encode_text(self, text):
+        return self.tokenizer.tokenize(text)
+
+    def decode_tokens(self, token_ids):
+        tokens = self.tokenizer.tokenizer.convert_ids_to_tokens(token_ids)
+        return ' '.join(token for token in tokens if token != '[PAD]')
+
+    def get_block(self, start_idx, end_idx, doc_idx):
+        """Get the IDs for an evidence block plus the title of the corresponding document"""
+        block = [list(self.block_dataset[i]) for i in range(start_idx, end_idx)]
+        title = list(self.title_dataset[int(doc_idx)])
+
+        block = list(itertools.chain(*block))[:self.max_seq_length - (3 + len(title))]
+        block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)
+
+        return (block_tokens, block_pad_mask)
+
+    def concat_and_pad_tokens(self, tokens, title=None):
+        """concat with special tokens and pad sequence to self.max_seq_length"""
+        tokens = [self.cls_id] + tokens + [self.sep_id]
+        if title is not None:
+            tokens += title + [self.sep_id]
+        assert len(tokens) <= self.max_seq_length, len(tokens)
+
+        num_pad = self.max_seq_length - len(tokens)
+        pad_mask = [1] * len(tokens) + [0] * num_pad
+        tokens += [self.pad_id] * num_pad
+
+        return tokens, pad_mask
+
+    def get_samples_mapping(self, data_prefix, num_epochs, max_num_samples):
+        if not num_epochs:
+            if not max_num_samples:
+                raise ValueError("Need to specify either max_num_samples "
+                                 "or num_epochs")
+            num_epochs = np.iinfo(np.int32).max - 1
+        if not max_num_samples:
+            max_num_samples = np.iinfo(np.int64).max - 1
+
+        # Filename of the index mapping
+        indexmap_filename = data_prefix
+        indexmap_filename += '_{}_indexmap'.format(self.name)
+        if num_epochs != (np.iinfo(np.int32).max - 1):
+            indexmap_filename += '_{}ep'.format(num_epochs)
+        if max_num_samples != (np.iinfo(np.int64).max - 1):
+            indexmap_filename += '_{}mns'.format(max_num_samples)
+        indexmap_filename += '_{}msl'.format(self.max_seq_length)
+        indexmap_filename += '_{}s'.format(self.seed)
+        indexmap_filename += '.npy'
+
+        # Build the indexed mapping if not exist.
+        if torch.distributed.get_rank() == 0 and \
+                not os.path.isfile(indexmap_filename):
+            print(' > WARNING: could not find index map file {}, building '
+                  'the indices on rank 0 ...'.format(indexmap_filename))
+
+            # Make sure the types match the helpers input types.
+            assert self.block_dataset.doc_idx.dtype == np.int64
+            assert self.block_dataset.sizes.dtype == np.int32
+
+            # Build samples mapping
+            verbose = torch.distributed.get_rank() == 0
+            start_time = time.time()
+            print_rank_0(' > building samples index mapping for {} ...'.format(self.name))
+            from megatron.data.dataset_utils import compile_helper
+            compile_helper()
+            from megatron.data import helpers
+            samples_mapping = helpers.build_blocks_mapping(
+                self.block_dataset.doc_idx,
+                self.block_dataset.sizes,
+                self.title_dataset.sizes,
+                num_epochs,
+                max_num_samples,
+                self.max_seq_length - 3,  # account for added tokens
+                self.seed,
+                verbose)
+            print_rank_0(' > done building samples index mapping')
+            np.save(indexmap_filename, samples_mapping, allow_pickle=True)
+            print_rank_0(' > saved the index mapping in {}'.format(indexmap_filename))
+            # Make sure all the ranks have built the mapping
+            print_rank_0(' > elapsed time to build and save samples mapping '
+                         '(seconds): {:4f}'.format(time.time() - start_time))
+
+        # This should be a barrier but nccl barrier assumes
+        # device_index=rank which is not the case for model
+        # parallel case
+        counts = torch.cuda.LongTensor([1])
+        torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
+        assert counts[0].item() == torch.distributed.get_world_size(
+            group=mpu.get_data_parallel_group())
+
+        # Load indexed dataset.
+        print_rank_0(' > loading indexed mapping from {}'.format(indexmap_filename))
+        start_time = time.time()
+        samples_mapping = np.load(indexmap_filename, allow_pickle=True)
+        print_rank_0('    loaded indexed file in {:3.3f} seconds'.format(time.time() - start_time))
+        print_rank_0('    total number of samples: {}'.format(samples_mapping.shape[0]))
+
+        return samples_mapping
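Editor's note: a toy sketch of the inverse cloze sampling and padding implemented by __getitem__ and concat_and_pad_tokens above. The token ids (CLS=101, SEP=102, PAD=0) and the tiny block are made up for illustration. (Also worth noting: the committed `rng.random() < 1` test keeps the query in its block every time, despite the "10% of the time" comment.)

import itertools
import random

import numpy as np

CLS, SEP, PAD = 101, 102, 0   # assumed ids, for illustration only
MAX_SEQ_LENGTH = 32
rng = random.Random(0)

def concat_and_pad_tokens(tokens, title=None):
    """concat with special tokens and pad sequence to MAX_SEQ_LENGTH"""
    tokens = [CLS] + tokens + [SEP]
    if title is not None:
        tokens += title + [SEP]
    num_pad = MAX_SEQ_LENGTH - len(tokens)
    pad_mask = [1] * len(tokens) + [0] * num_pad
    tokens += [PAD] * num_pad
    return tokens, pad_mask

# a "block" of four sentences plus a document title, as token ids
block = [[11, 12, 13], [21, 22], [31, 32, 33, 34], [41, 42]]
title = [7, 8]

# pick an interior sentence as the query, mirroring __getitem__
rand_sent_idx = rng.randint(1, len(block) - 2)
query = block[rand_sent_idx].copy()

query = query[:MAX_SEQ_LENGTH - 2]
flat_block = list(itertools.chain(*block))[:MAX_SEQ_LENGTH - (3 + len(title))]

query_tokens, query_pad_mask = concat_and_pad_tokens(query)
block_tokens, block_pad_mask = concat_and_pad_tokens(flat_block, title)
print(np.array(query_tokens))
print(np.array(block_tokens))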
pretrain_realm.py
View file @ 183ad176
...
@@ -18,7 +18,8 @@
 import torch
 import torch.nn.functional as F
-from hashed_index import HashedIndex, load_ict_checkpoint, get_ict_dataset
+from hashed_index import load_ict_checkpoint, get_ict_dataset
+from megatron.data.realm_dataset import HashedIndex
 from megatron import get_args
 from megatron import get_timers
 from megatron import mpu
...
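Editor's note on the whitening used in hashed_index.py above (whiten_block_embeds, and the new hash_whitened_block_embeds path): the transpose of the Cholesky factor of the inverse covariance is a whitening matrix, so the transformed embeddings have identity sample covariance. A minimal numpy check on synthetic data, not part of the commit:

import numpy as np

rng = np.random.default_rng(0)
# synthetic "block embeddings": 128-dim, 1000 samples, anisotropic
embeds = rng.normal(size=(128, 1000)) * np.linspace(0.5, 2.0, 128).reshape(-1, 1)

mean = np.mean(embeds, axis=1).reshape(-1, 1)
centered = embeds - mean
inv_cov = np.linalg.inv(np.cov(embeds))
whitener = np.transpose(np.linalg.cholesky(inv_cov))   # as in whiten_block_embeds
whitened = whitener.dot(centered)

# sample covariance of the whitened embeddings is the identity
print(np.allclose(np.cov(whitened), np.eye(128)))      # True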