OpenDAS / Megatron-LM · Commits

Commit 183ad176
Authored May 03, 2020 by Neel Kant

Refactor to build BlockData, FaissMIPSIndex, RandProjectLSHIndex

Parent: 0104f910
Changes: 3 changed files, with 194 additions and 273 deletions (+194 / -273).

  hashed_index.py                    +16   -268
  megatron/data/realm_dataset.py     +176     -4
  pretrain_realm.py                    +2     -1
hashed_index.py — view file @ 183ad176
from collections import defaultdict
import os
import pickle
import shutil

import numpy as np
import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
...
...
@@ -11,7 +5,8 @@ from megatron import get_args
from megatron import mpu
from megatron.checkpointing import get_checkpoint_tracker_filename, get_checkpoint_name
from megatron.data.bert_dataset import get_indexed_dataset_
from megatron.data.ict_dataset import InverseClozeDataset
from megatron.data.realm_dataset import InverseClozeDataset
from megatron.data.realm_index import BlockData, RandProjectionLSHIndex
from megatron.data.samplers import DistributedBatchSampler
from megatron.initialize import initialize_megatron
from megatron.model import REALMRetriever
...
...
@@ -23,253 +18,6 @@ def detach(tensor):
    return tensor.detach().cpu().numpy()

class HashedIndex(object):
    """Class for holding hashed data"""
    def __init__(self, embed_size, num_buckets, whiten=False, seed=0):
        np.random.seed(seed)
        self.block_data = defaultdict(list)
        self.hash_data = defaultdict(list)

        hash_matrix = 2 * np.random.rand(embed_size, int(num_buckets / 2)) - 1
        self.hash_matrix = hash_matrix / np.linalg.norm(hash_matrix, axis=0).reshape(1, -1)
        self.embed_mean = None
        self.embed_whitener = None
        self.whiten = whiten

        # alsh
        self.m = 5
        self.u = 0.99
        self.max_norm = None
        self.block_index = None

    def state(self):
        state = {
            'block_data': self.block_data,
            'hash_data': self.hash_data,
            'hash_matrix': self.hash_matrix,
            'embed_mean': self.embed_mean,
            'embed_whitener': self.embed_whitener,
        }
        return state

    def get_block_bucket(self, hash):
        return self.hash_data[hash]

    def get_block_embed(self, block_idx):
        return self.block_data[block_idx]

    def hash_embeds(self, embeds, block_data=None):
        """Hash a tensor of embeddings using a random projection matrix"""
        embed_scores_pos = torch.matmul(embeds, torch.cuda.FloatTensor(self.hash_matrix))
        embed_scores = torch.cat((embed_scores_pos, -embed_scores_pos), axis=1)
        embed_hashes = detach(torch.argmax(embed_scores, axis=1))

        if block_data is not None:
            for hash, indices in zip(embed_hashes, block_data):
                self.hash_data[hash].append(indices)

        return embed_hashes
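
A minimal standalone NumPy sketch of the bucket assignment performed by hash_embeds, separate from the file above: concatenating the projection scores with their negations and taking the argmax assigns each embedding to the random direction (or its negation) it aligns with most, giving a bucket id in [0, num_buckets). The batch size and seed below are illustrative assumptions.

import numpy as np

# Standalone sketch (not part of hashed_index.py): reproduce the bucket
# assignment from hash_embeds on CPU with NumPy only.
np.random.seed(0)
embed_size, num_buckets = 128, 32

# Same construction as __init__: one unit-norm random column per half-bucket.
hash_matrix = 2 * np.random.rand(embed_size, num_buckets // 2) - 1
hash_matrix /= np.linalg.norm(hash_matrix, axis=0).reshape(1, -1)

embeds = np.random.randn(4, embed_size)            # toy batch of embeddings
scores_pos = embeds @ hash_matrix                  # [4 x num_buckets/2]
scores = np.concatenate((scores_pos, -scores_pos), axis=1)
buckets = np.argmax(scores, axis=1)                # one bucket id per embedding

assert scores.shape == (4, num_buckets)
assert all(0 <= b < num_buckets for b in buckets)
print(buckets)
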
    def assign_block_embeds(self, block_indices, block_embeds, allow_overwrite=False):
        """Assign the embeddings for each block index into a hash map"""
        for idx, embed in zip(block_indices, block_embeds):
            if not allow_overwrite and int(idx) in self.block_data:
                raise ValueError("Attempted to overwrite a read-only HashedIndex")
            self.block_data[int(idx)] = np.float16(embed)

    def save_shard(self, rank):
        dir_name = 'block_hash_data'
        if not os.path.isdir(dir_name):
            os.mkdir(dir_name)

        # save the data for each shard
        with open('{}/{}.pkl'.format(dir_name, rank), 'wb') as data_file:
            pickle.dump(self.state(), data_file)

    def consolidate_shards_and_save(self, ignore_shard=0):
        """Combine all the shards made using self.save_shard()"""
        dir_name = 'block_hash_data'
        fnames = os.listdir(dir_name)
        for fname in fnames:
            with open('{}/{}'.format(dir_name, fname), 'rb') as f:
                data = pickle.load(f)
                assert np.array_equal(data['hash_matrix'], self.hash_matrix)
                old_size = len(self.block_data)
                shard_size = len(data['block_data'])
                self.block_data.update(data['block_data'])
                assert (len(self.block_data) == old_size + shard_size) or (str(ignore_shard) in fname)

                if not self.whiten:
                    for bucket, items in data['hash_data'].items():
                        self.hash_data[bucket].extend(items)

        if self.whiten:
            self.whiten_block_embeds()

        args = get_args()
        with open(args.hash_data_path, 'wb') as final_file:
            pickle.dump(self.state(), final_file)
        shutil.rmtree(dir_name, ignore_errors=True)

    def clear(self):
        """Clear the data structures to save memory"""
        self.block_data = dict()
        self.hash_data = defaultdict(list)

    def whiten_block_embeds(self):
        """Transform all block embeds to have zero mean and unit covariance
        when treated as samples from a distribution"""
        block_idx, all_embeds = zip(*self.block_data.items())
        arr_embeds = np.transpose(np.array(all_embeds))
        mean = np.mean(arr_embeds, axis=1).reshape(-1, 1)
        centered = arr_embeds - mean

        inv_cov = np.linalg.inv(np.cov(arr_embeds))
        whitener = np.transpose(np.linalg.cholesky(inv_cov))
        whitened = np.float16(np.transpose(whitener.dot(centered)))

        self.embed_mean = mean.reshape(-1)
        self.embed_whitener = whitener
        self.block_data = dict(zip(block_idx, list(whitened)))
        self.hash_data = defaultdict(list)

        batch_size = 16384
        i = 0
        args = get_args()
        with torch.no_grad():
            hashing_tensor = torch.cuda.HalfTensor(self.hash_matrix)
            while True:
                if args.debug:
                    print(i, flush=True)
                batch_slice = slice(i * batch_size, (i + 1) * batch_size)
                batch_embed = torch.cuda.HalfTensor(whitened[batch_slice])
                batch_block_idx = block_idx[batch_slice]
                if len(batch_block_idx) == 0:
                    break

                hash_scores_pos = torch.matmul(batch_embed, hashing_tensor)
                embed_scores = torch.cat((hash_scores_pos, -hash_scores_pos), axis=1)
                embed_hashes = detach(torch.argmax(embed_scores, axis=1))

                for idx, hash in zip(batch_block_idx, list(embed_hashes)):
                    # [int] instead of [array<int>] since this is just for analysis rn
                    self.hash_data[hash].append(idx)
                i += 1
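
The whitening above relies on a standard identity: if the whitener is the transposed Cholesky factor of the inverse covariance, applying it to the centered samples yields unit covariance. A small self-contained NumPy check of that step, with toy dimensions chosen arbitrarily:

import numpy as np

# Self-contained check of the whitening step in whiten_block_embeds:
# with whitener = cholesky(inv(cov(X)))^T, cov(whitener @ centered) == I.
np.random.seed(0)
d, n = 8, 10000
arr_embeds = np.random.randn(d, n) * np.arange(1, d + 1).reshape(-1, 1)  # [d x n]

mean = np.mean(arr_embeds, axis=1).reshape(-1, 1)
centered = arr_embeds - mean
inv_cov = np.linalg.inv(np.cov(arr_embeds))
whitener = np.transpose(np.linalg.cholesky(inv_cov))
whitened = whitener.dot(centered)

assert np.allclose(np.cov(whitened), np.eye(d), atol=1e-6)
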
    def create_block_data_index(self):
        import faiss
        self.block_idx, block_embeds = zip(*self.block_data.items())
        block_embeds = np.array(block_embeds)
        alsh_preprocessed_blocks = self.alsh_block_preprocess_fn()

        index = faiss.IndexFlatL2(alsh_preprocessed_blocks.shape[1])
        index.add(alsh_preprocessed_blocks)
        print('Total blocks in index: ', index.ntotal)
        self.block_index = index

    def get_norm_powers_and_halves_array(self, embeds):
        norm = np.linalg.norm(embeds, axis=1)
        norm_powers = [np.multiply(norm, norm)]  # squared L2 norms of all
        for i in range(self.m - 1):
            norm_powers.append(np.multiply(norm_powers[-1], norm_powers[-1]))

        # [num_blocks x self.m]
        norm_powers = np.transpose(np.array(norm_powers))
        halves_array = 0.5 * np.ones(norm_powers.shape)
        return norm_powers, halves_array

    def alsh_block_preprocess_fn(self):
        block_idx, block_embeds = zip(*self.block_data.items())
        block_embeds = np.array(block_embeds)
        if self.max_norm is None:
            self.max_norm = max(np.linalg.norm(block_embeds, axis=1))
        if self.max_norm > 1:
            block_embeds = self.u / self.max_norm * block_embeds

        norm_powers, halves_array = self.get_norm_powers_and_halves_array(block_embeds)
        # P'(S(x)) for all x in block_embeds
        return np.float32(np.concatenate((block_embeds, norm_powers, halves_array), axis=1))

    def alsh_query_preprocess_fn(self, query_embeds):
        max_norm = max(np.linalg.norm(query_embeds, axis=1))
        if max_norm > 1:
            query_embeds = self.u / max_norm * query_embeds

        norm_powers, halves_array = self.get_norm_powers_and_halves_array(query_embeds)
        # Q'(S(x)) for all x in query_embeds
        return np.float32(np.concatenate((query_embeds, halves_array, norm_powers), axis=1))
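
The block/query preprocessing pair appears to follow the asymmetric-LSH (ALSH) reduction of maximum inner product search to nearest-neighbour search: blocks are extended with their norm powers then halves, queries with halves then norm powers, so that, once vectors are scaled below unit norm, squared L2 distance in the transformed space equals a constant plus vanishing norm terms minus twice the inner product. Below is a standalone check of that identity, reusing the same m and u values as the constructor; the helper mirrors get_norm_powers_and_halves_array and is not part of this file.

import numpy as np

# Standalone check of the identity behind alsh_block_preprocess_fn /
# alsh_query_preprocess_fn (assuming the m and u defaults from __init__):
#   ||P'(x) - Q'(q)||^2 = m/2 + ||x||^(2^(m+1)) + ||q||^(2^(m+1)) - 2 x.q
# so minimizing L2 distance over transformed blocks approximates MIPS.
np.random.seed(0)
m, u, d = 5, 0.99, 128

def norm_powers_and_halves(embeds):
    # mirrors get_norm_powers_and_halves_array above
    norm = np.linalg.norm(embeds, axis=1)
    powers = [norm * norm]
    for _ in range(m - 1):
        powers.append(powers[-1] * powers[-1])
    powers = np.transpose(np.array(powers))
    return powers, 0.5 * np.ones(powers.shape)

x = np.random.randn(10, d)                     # toy "block" embeddings
q = np.random.randn(1, d)                      # one toy query
x = u / max(np.linalg.norm(x, axis=1)) * x     # scale so all norms < 1
q = u / max(np.linalg.norm(q, axis=1)) * q

x_norms, x_halves = norm_powers_and_halves(x)
q_norms, q_halves = norm_powers_and_halves(q)
P = np.concatenate((x, x_norms, x_halves), axis=1)   # P'(S(x)), as for blocks
Q = np.concatenate((q, q_halves, q_norms), axis=1)   # Q'(S(q)), as for queries

dist_sq = np.sum((P - Q) ** 2, axis=1)
expected = (m / 2
            + np.linalg.norm(x, axis=1) ** (2 ** (m + 1))
            + np.linalg.norm(q) ** (2 ** (m + 1))
            - 2 * (x @ q.T).ravel())
assert np.allclose(dist_sq, expected)
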
    def exact_mips_equals(self, query_embeds, norm_blocks):
        """For each query, determine whether the mips block is in the correct hash bucket"""
        shuffled_block_idx, block_embeds = zip(*self.block_data.items())
        if norm_blocks:
            block_embeds = block_embeds / np.linalg.norm(block_embeds, axis=1).reshape(-1, 1)

        with torch.no_grad():
            # get hashes for the queries
            hash_scores_pos = torch.matmul(torch.cuda.HalfTensor(query_embeds),
                                           torch.cuda.HalfTensor(self.hash_matrix))
            hash_scores = torch.cat((hash_scores_pos, -hash_scores_pos), axis=1)
            query_hashes = detach(torch.argmax(hash_scores, axis=1))

            # [num_query x num_blocks]
            inner_products = torch.matmul(torch.cuda.HalfTensor(query_embeds),
                                          torch.cuda.HalfTensor(np.transpose(np.array(block_embeds))))
            max_inner_product_idxes = detach(torch.argmax(inner_products, axis=1))
            best_blocks = [self.block_data[shuffled_block_idx[idx]] for idx in max_inner_product_idxes]
            best_blocks_tensor = torch.cuda.HalfTensor(np.array(best_blocks))

            # bb = best_blocks
            bb_hash_scores_pos = torch.matmul(best_blocks_tensor, torch.cuda.HalfTensor(self.hash_matrix))
            bb_hash_scores = torch.cat((bb_hash_scores_pos, -bb_hash_scores_pos), axis=1)
            best_block_hashes = detach(torch.argmax(bb_hash_scores, axis=1))

        print('Query hashes: ', query_hashes)
        print('Block hashes: ', best_block_hashes)

        equal_arr = np.equal(query_hashes, best_block_hashes).astype(int)
        # array of zeros and ones which can be used for counting success
        return equal_arr

    def exact_mips_test(self, num_queries, whitened, norm_blocks, alsh):
        if whitened:
            if self.embed_mean is None:
                self.whiten_block_embeds()
            query_embeds = np.random.multivariate_normal(np.zeros(128), np.eye(128), num_queries)
            query_embeds = query_embeds / np.linalg.norm(query_embeds, axis=1).reshape(-1, 1)

            if alsh:
                if self.block_index is None:
                    self.create_block_data_index()
                alsh_queries = self.alsh_query_preprocess_fn(query_embeds)
                neighbor_ids, distances = self.block_index.search(alsh_queries, 5)
                print('DONE')
                return
        else:
            block_idx, all_embeds = zip(*self.block_data.items())
            arr_embeds = np.transpose(np.array(all_embeds))
            mean = np.mean(arr_embeds, axis=1).reshape(-1, 1)
            cov = np.cov(arr_embeds)
            query_embeds = np.random.multivariate_normal(mean, cov, num_queries)

        equal_arr = self.exact_mips_equals(query_embeds, norm_blocks)
        print("Num correct: ", sum(equal_arr), " Fraction correct: ", sum(equal_arr) / equal_arr.size)
        print(equal_arr)

    @classmethod
    def load_from_file(cls, fname):
        print(" > Unpickling block hash data")
        state_dict = pickle.load(open(fname, 'rb'))
        print(" > Finished unpickling")

        hash_matrix = state_dict['hash_matrix']
        new_index = HashedIndex(hash_matrix.shape[0], hash_matrix.shape[1] * 2)
        new_index.block_data = state_dict['block_data']
        new_index.hash_data = state_dict['hash_data']
        new_index.embed_mean = state_dict.get('embed_mean')
        new_index.embed_whitener = state_dict.get('embed_whitener')
        new_index.hash_matrix = hash_matrix

        return new_index

def test_retriever():
    initialize_megatron(extra_args_provider=None,
                        args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
...
...
@@ -317,42 +65,42 @@ def main():
    model.eval()
    dataset = get_ict_dataset()
    data_iter = iter(get_one_epoch_dataloader(dataset))

    hashed_index = HashedIndex(embed_size=128, num_buckets=32, whiten=True)
    all_block_data = BlockData()
    hashed_index = RandProjectionLSHIndex(embed_size=128, num_buckets=32, whiten=True)

    i = 1
    total = 0
    whiten = False
    while True:
        try:
            query_tokens, query_pad_mask, \
                block_tokens, block_pad_mask, block_indices = get_batch(data_iter)
            block_tokens, block_pad_mask, block_index_data = get_batch(data_iter)
        except:
            break

        block_indices = detach(block_indices)
        block_logits = model(None, None, block_tokens, block_pad_mask, only_block=True)
        block_index_data = detach(block_index_data)
        block_indices = block_index_data[:, 3]
        block_meta = block_index_data[:, :3]

        # If whitened, then hashing needs to be done after whitening the block embeds
        # which is done in consolidate_shards_and_save()
        if not whiten:
            hashed_index.hash_embeds(block_logits, block_indices)
        hashed_index.assign_block_embeds(block_indices[:, 3], detach(block_logits))
        block_logits = model(None, None, block_tokens, block_pad_mask, only_block=True)
        all_block_data.add_block_data(block_indices, block_logits, block_meta)

        total += block_indices.shape[0]
        total += block_indices.size
        i += 1
        if i % 20 == 0:
            print('Batch {:10d} | Total {:10d}'.format(i, total), flush=True)
            if args.debug:
                break

    hashed_index.save_shard(args.rank)
    all_block_data.save_shard(args.rank)
    torch.distributed.barrier()
    del model

    if args.rank == 0:
        hashed_index.consolidate_shards_and_save()
        all_block_data.consolidate_shards_and_save()
        hashed_index.hash_whitened_block_embeds(all_block_data)
        hashed_index.save_to_file()
    else:
        hashed_index.clear()
        all_block_data.clear()


def load_ict_checkpoint(only_query_model=False, only_block_model=False, no_grad=False):
...
...
megatron/data/realm_dataset.py — view file @ 183ad176
import itertools
import os
import random
import time

import numpy as np
import spacy
import torch
from torch.utils.data import Dataset

from megatron import get_tokenizer
from megatron.data.bert_dataset import BertDataset, get_samples_mapping_
from megatron import get_tokenizer, print_rank_0, mpu
from megatron.data.bert_dataset import BertDataset
from megatron.data.dataset_utils import create_masked_lm_predictions, pad_and_convert_to_numpy

# qa_nlp = spacy.load('en_core_web_lg')
qa_nlp = None
qa_nlp = spacy.load('en_core_web_lg')


class RealmDataset(BertDataset):
    """Dataset containing simple masked sentences for masked language modeling.
...
...
@@ -74,3 +79,170 @@ def spacy_ner(block_text):
        answers.append(str(ent.text))
    candidates['starts'] = starts
    candidates['answers'] = answers


class InverseClozeDataset(Dataset):
    """Dataset containing sentences and their blocks for an inverse cloze task."""
    def __init__(self, name, block_dataset, title_dataset, data_prefix,
                 num_epochs, max_num_samples, max_seq_length, short_seq_prob, seed):
        self.name = name
        self.seed = seed
        self.max_seq_length = max_seq_length
        self.block_dataset = block_dataset
        self.title_dataset = title_dataset
        self.short_seq_prob = short_seq_prob
        self.rng = random.Random(self.seed)

        self.samples_mapping = self.get_samples_mapping(data_prefix, num_epochs, max_num_samples)
        self.tokenizer = get_tokenizer()
        self.vocab_id_list = list(self.tokenizer.inv_vocab.keys())
        self.vocab_id_to_token_list = self.tokenizer.inv_vocab
        self.cls_id = self.tokenizer.cls
        self.sep_id = self.tokenizer.sep
        self.mask_id = self.tokenizer.mask
        self.pad_id = self.tokenizer.pad

    def __len__(self):
        return self.samples_mapping.shape[0]

    def __getitem__(self, idx):
        start_idx, end_idx, doc_idx, block_idx = self.samples_mapping[idx]
        title = list(self.title_dataset[int(doc_idx)])
        block = [list(self.block_dataset[i]) for i in range(start_idx, end_idx)]
        assert len(block) > 1

        # avoid selecting the first or last sentence to be the query.
        if len(block) == 2:
            rand_sent_idx = int(self.rng.random() > 0.5)
        else:
            rand_sent_idx = self.rng.randint(1, len(block) - 2)

        # keep the query in the context 10% of the time.
        if self.rng.random() < 1:
            query = block[rand_sent_idx].copy()
        else:
            query = block.pop(rand_sent_idx)

        # still need to truncate because blocks are concluded when
        # the sentence lengths have exceeded max_seq_length.
        query = query[:self.max_seq_length - 2]
        block = list(itertools.chain(*block))[:self.max_seq_length - (3 + len(title))]

        query_tokens, query_pad_mask = self.concat_and_pad_tokens(query)
        block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)

        sample = {
            'query_tokens': np.array(query_tokens),
            'query_pad_mask': np.array(query_pad_mask),
            'block_tokens': np.array(block_tokens),
            'block_pad_mask': np.array(block_pad_mask),
            'block_data': np.array([start_idx, end_idx, doc_idx, block_idx]).astype(np.int64)
        }

        return sample
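
To make the selection logic above concrete, here is a toy standalone run of the same steps with small integer lists standing in for tokenized sentences; the values and the seed are illustrative only.

import itertools
import random

# Toy run of the __getitem__ selection logic with fake "sentences".
rng = random.Random(1234)
block = [[1, 2, 3], [4, 5], [6, 7, 8, 9], [10]]   # four tokenized sentences

# avoid selecting the first or last sentence to be the query
if len(block) == 2:
    rand_sent_idx = int(rng.random() > 0.5)
else:
    rand_sent_idx = rng.randint(1, len(block) - 2)

# as written above, the query currently always stays in the context
query = block[rand_sent_idx].copy()
context = list(itertools.chain(*block))

print('query  :', query)     # one inner sentence
print('context:', context)   # all sentences flattened into one block
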
    def encode_text(self, text):
        return self.tokenizer.tokenize(text)

    def decode_tokens(self, token_ids):
        tokens = self.tokenizer.tokenizer.convert_ids_to_tokens(token_ids)
        return ' '.join(token for token in tokens if token != '[PAD]')

    def get_block(self, start_idx, end_idx, doc_idx):
        """Get the IDs for an evidence block plus the title of the corresponding document"""
        block = [list(self.block_dataset[i]) for i in range(start_idx, end_idx)]
        title = list(self.title_dataset[int(doc_idx)])

        block = list(itertools.chain(*block))[:self.max_seq_length - (3 + len(title))]
        block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)

        return (block_tokens, block_pad_mask)

    def concat_and_pad_tokens(self, tokens, title=None):
        """concat with special tokens and pad sequence to self.max_seq_length"""
        tokens = [self.cls_id] + tokens + [self.sep_id]
        if title is not None:
            tokens += title + [self.sep_id]
        assert len(tokens) <= self.max_seq_length, len(tokens)

        num_pad = self.max_seq_length - len(tokens)
        pad_mask = [1] * len(tokens) + [0] * num_pad
        tokens += [self.pad_id] * num_pad
        return tokens, pad_mask
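
For reference, a self-contained sketch of the sequences concat_and_pad_tokens produces; the special-token ids and max_seq_length below are made up for illustration.

# Standalone sketch of concat_and_pad_tokens with made-up special-token ids.
CLS, SEP, PAD = 101, 102, 0     # illustrative ids, not tied to the real vocab
max_seq_length = 12

def concat_and_pad(tokens, title=None):
    tokens = [CLS] + tokens + [SEP]
    if title is not None:
        tokens += title + [SEP]
    num_pad = max_seq_length - len(tokens)
    pad_mask = [1] * len(tokens) + [0] * num_pad
    return tokens + [PAD] * num_pad, pad_mask

block_tokens, block_pad_mask = concat_and_pad([7, 8, 9, 10], title=[5, 6])
print(block_tokens)    # [101, 7, 8, 9, 10, 102, 5, 6, 102, 0, 0, 0]
print(block_pad_mask)  # [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0]
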
    def get_samples_mapping(self, data_prefix, num_epochs, max_num_samples):
        if not num_epochs:
            if not max_num_samples:
                raise ValueError("Need to specify either max_num_samples "
                                 "or num_epochs")
            num_epochs = np.iinfo(np.int32).max - 1
        if not max_num_samples:
            max_num_samples = np.iinfo(np.int64).max - 1

        # Filename of the index mapping
        indexmap_filename = data_prefix
        indexmap_filename += '_{}_indexmap'.format(self.name)
        if num_epochs != (np.iinfo(np.int32).max - 1):
            indexmap_filename += '_{}ep'.format(num_epochs)
        if max_num_samples != (np.iinfo(np.int64).max - 1):
            indexmap_filename += '_{}mns'.format(max_num_samples)
        indexmap_filename += '_{}msl'.format(self.max_seq_length)
        indexmap_filename += '_{}s'.format(self.seed)
        indexmap_filename += '.npy'

        # Build the indexed mapping if not exist.
        if torch.distributed.get_rank() == 0 and \
                not os.path.isfile(indexmap_filename):
            print(' > WARNING: could not find index map file {}, building '
                  'the indices on rank 0 ...'.format(indexmap_filename))

            # Make sure the types match the helpers input types.
            assert self.block_dataset.doc_idx.dtype == np.int64
            assert self.block_dataset.sizes.dtype == np.int32

            # Build samples mapping
            verbose = torch.distributed.get_rank() == 0
            start_time = time.time()
            print_rank_0(' > building samples index mapping for {} ...'.format(self.name))

            from megatron.data.dataset_utils import compile_helper
            compile_helper()
            from megatron.data import helpers
            samples_mapping = helpers.build_blocks_mapping(
                self.block_dataset.doc_idx,
                self.block_dataset.sizes,
                self.title_dataset.sizes,
                num_epochs,
                max_num_samples,
                self.max_seq_length - 3,  # account for added tokens
                self.seed,
                verbose)
            print_rank_0(' > done building samples index mapping')
            np.save(indexmap_filename, samples_mapping, allow_pickle=True)
            print_rank_0(' > saved the index mapping in {}'.format(indexmap_filename))
            # Make sure all the ranks have built the mapping
            print_rank_0(' > elapsed time to build and save samples mapping '
                         '(seconds): {:4f}'.format(time.time() - start_time))

        # This should be a barrier but nccl barrier assumes
        # device_index=rank which is not the case for model
        # parallel case
        counts = torch.cuda.LongTensor([1])
        torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
        assert counts[0].item() == torch.distributed.get_world_size(
            group=mpu.get_data_parallel_group())

        # Load indexed dataset.
        print_rank_0(' > loading indexed mapping from {}'.format(indexmap_filename))
        start_time = time.time()
        samples_mapping = np.load(indexmap_filename, allow_pickle=True)
        print_rank_0('    loaded indexed file in {:3.3f} seconds'.format(time.time() - start_time))
        print_rank_0('    total number of samples: {}'.format(samples_mapping.shape[0]))

        return samples_mapping
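
As a small worked example of the index-map naming scheme in get_samples_mapping, the suffixes compose as shown below; every parameter value here is hypothetical.

# Hypothetical values, only to show how the index-map filename is composed.
data_prefix, name = 'my-corpus_text_sentence', 'train'
max_num_samples, max_seq_length, seed = 1000000, 288, 1234

indexmap_filename = data_prefix
indexmap_filename += '_{}_indexmap'.format(name)
# '_{}ep' is skipped here because num_epochs fell back to its sentinel value
indexmap_filename += '_{}mns'.format(max_num_samples)
indexmap_filename += '_{}msl'.format(max_seq_length)
indexmap_filename += '_{}s'.format(seed)
indexmap_filename += '.npy'

print(indexmap_filename)
# my-corpus_text_sentence_train_indexmap_1000000mns_288msl_1234s.npy
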
pretrain_realm.py — view file @ 183ad176
...
...
@@ -18,7 +18,8 @@
import torch
import torch.nn.functional as F

from hashed_index import HashedIndex, load_ict_checkpoint, get_ict_dataset
from hashed_index import load_ict_checkpoint, get_ict_dataset
from megatron.data.realm_dataset import HashedIndex
from megatron import get_args
from megatron import get_timers
from megatron import mpu
...
...