OpenDAS / Megatron-LM / Commits

Commit 3aca1415, authored Apr 29, 2024 by liangjing

Merge branch 'megatron-lm_dtk24.04' into 'main'

Megatron lm dtk24.04. See merge request !1

Parents: 0024a5c6, 1005e9d3
Pipeline #1806 passed.

Showing 20 changed files with 932 additions and 735 deletions (+932, -735).
tools/retro/cli/__main__.py                   +2    -0
tools/retro/cli/cli.py                        +299  -0
tools/retro/db/build.py                       +100  -137
tools/retro/db/dataset.py                     +15   -1
tools/retro/db/utils.py                       +24   -28
tools/retro/examples/get_dataset_configs.sh   +0    -40
tools/retro/examples/get_preprocess_cmd.sh    +0    -138
tools/retro/examples/preprocess_data.sh       +117  -29
tools/retro/examples/pretrain_model.sh        +62   -66
tools/retro/index/__init__.py                 +2    -1
tools/retro/index/build.py                    +54   -6
tools/retro/index/index.py                    +14   -2
tools/retro/index/indexes/faiss_base.py       +17   -6
tools/retro/index/indexes/faiss_par_add.py    +24   -11
tools/retro/index/utils.py                    +14   -114
tools/retro/main.py                           +95   -52
tools/retro/pretraining/utils.py              +0    -10
tools/retro/query/__init__.py                 +3    -0
tools/retro/query/chunk_dataset.py            +23   -27
tools/retro/query/query.py                    +67   -67
tools/retro/cli/__main__.py  (view file @ 3aca1415)

 # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

+import os
+from . import retro
 ...
tools/retro/cli/cli.py  (new file, 0 → 100644; view file @ 3aca1415)

# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

import json
import numpy as np
import os
import torch
import types

from megatron.global_vars import set_global_variables, set_retro_args
from megatron.initialize import (
    initialize_megatron,
    _initialize_distributed,
    _set_random_seed,
    _compile_dependencies,
)
from tools.retro.db.utils import (
    get_indexed_dataset_infos as get_db_indexed_dataset_infos,
    get_merged_train_dataset as get_db_dataset,
)
from tools.retro.main import add_retro_args
from tools.retro.query.retro_dataset import get_retro_datasets
from tools.retro.utils import get_args_path, get_bert_tokenizer, get_gpt_tokenizer


def shorten_str(s, n):
    s = "\\n".join(s.splitlines())
    return s if len(s) <= n else "%s ... %s" % (s[:n//2], s[-n//2:])


class retro:

    args = None

    ##############################################
    # initialize.
    ##############################################

    @classmethod
    def parse_dtype_str(cls, dtype_str):
        return {
            "torch.float16" : torch.float16,
            "torch.float32" : torch.float32,
            "torch.bfloat16" : torch.bfloat16,
        }[dtype_str]

    @classmethod
    def init_megatron(cls, workdir):
        '''Custom initialization of Megatron.'''

        # Load args.
        args_path = get_args_path(workdir)
        assert os.path.exists(args_path), "args.json not found in workdir."
        with open(args_path) as f:
            cls.args = types.SimpleNamespace(**json.load(f))
            cls.args.retro_workdir = workdir # just in case workdir moved
            cls.args.rank = 0 # override env
            cls.args.world_size = 1 # override env
            cls.args.params_dtype = cls.parse_dtype_str(cls.args.params_dtype)

        set_global_variables(cls.args)
        set_retro_args(cls.args)
        _initialize_distributed()
        _set_random_seed(cls.args.seed, cls.args.data_parallel_random_init)
        _compile_dependencies()

    @classmethod
    def init(cls, workdir):
        '''Initialize Megatron, tokenizers, and datasets.'''

        # Load args.
        cls.init_megatron(workdir)

        cls.tokenizers = types.SimpleNamespace(
            gpt=get_gpt_tokenizer(),
            bert=get_bert_tokenizer(),
        )

        # Load data.
        cls.db_indexed_dataset_infos = get_db_indexed_dataset_infos()
        cls.db_dataset = get_db_dataset()
        pt_train_ds, pt_valid_ds, _ = get_retro_datasets(verify_sizes=False)
        cls.pt_datasets = types.SimpleNamespace(
            train=pt_train_ds,
            valid=pt_valid_ds,
        )

        # Retrieve max saved neighbors.
        for key in vars(cls.pt_datasets):
            getattr(cls.pt_datasets, key).num_neighbors = \
                cls.args.retro_query_num_neighbors_save

        # Print usage.
        cls.print_usage()

    ##############################################
    # utils.
    ##############################################

    @classmethod
    def gpt_to_text(cls, token_ids):
        '''GPT tokens to text.'''
        return cls.tokenizers.gpt.detokenize(token_ids.tolist()
                                             if isinstance(token_ids, np.ndarray)
                                             else token_ids)

    @classmethod
    def text_to_bert(cls, text):
        '''Text to Bert tokens.'''
        return cls.tokenizers.bert.tokenize(text)

    ##############################################
    # chunk db.
    ##############################################

    @classmethod
    def get_db_num_indexed_datasets(cls):
        '''Number of indexed datasets within blendable dataset.'''
        return len(cls.db_indexed_dataset_infos)

    @classmethod
    def get_db_indexed_dataset_infos(cls):
        '''Dataset infos, including number of training & sampled sets.'''
        return [(info["ratio"], info["name"])
                for info in cls.db_indexed_dataset_infos]

    @classmethod
    def get_db_dataset(cls):
        return cls.db_dataset

    @classmethod
    def get_db_num_chunks(cls):
        '''Number of DB chunks.'''
        return len(cls.get_db_dataset())

    @classmethod
    def get_db_chunk_gpt(cls, idx):
        '''Get DB chunk as GPT token ids.'''
        return cls.get_db_dataset()[idx]["text"].tolist()

    @classmethod
    def get_db_chunk_bert(cls, idx):
        '''Get DB chunk as Bert token ids.'''
        return cls.text_to_bert(cls.get_db_chunk_text(idx))

    @classmethod
    def get_db_chunk_text(cls, idx):
        '''Get DB chunk as text.'''
        return cls.gpt_to_text(cls.get_db_chunk_gpt(idx))

    @classmethod
    def get_db_chunk_and_continuation_text(cls, idx):
        '''Get DB chunk along with continuation, as text.'''

        # Modulus used here to match original implementation (i.e., last
        # chunks continuation wraps around to first chunk).
        return [
            cls.get_db_chunk_text(idx),
            cls.get_db_chunk_text((idx + 1) % len(cls.get_db_dataset())),
        ]

    ##############################################
    # pretraining corpus.
    ##############################################

    @classmethod
    def get_pt_num_samples_and_chunks(cls, data_key):
        '''Number of samples & chunks (e.g., 32*n_samples) in corpus.'''
        assert hasattr(cls.pt_datasets, data_key), \
            "pretraining set '%s' not found (choices: %s)." % (
                data_key, ", ".join(vars(cls.pt_datasets).keys()))
        chunk_dataset = getattr(cls.pt_datasets, data_key).chunk_dataset
        return (
            len(chunk_dataset.sample_dataset),
            len(chunk_dataset),
        )

    @classmethod
    def get_pt_num_samples(cls, data_key):
        '''Number of pretraining samples.'''
        return cls.get_pt_num_samples_and_chunks(data_key)[0]

    @classmethod
    def get_pt_num_chunks(cls, data_key):
        '''Number of pretraining chunks (e.g., 32*n_samples).'''
        return cls.get_pt_num_samples_and_chunks(data_key)[1]

    @classmethod
    def get_pt_dataset(cls, data_key):
        return getattr(cls.pt_datasets, data_key)

    @classmethod
    def get_pt_sample(cls, data_key, idx):
        return getattr(cls.pt_datasets, data_key)[idx]

    @classmethod
    def get_neighbor_tokens(cls, sample_id, chunk_id, data_key="train"):
        try:
            sample = cls.get_pt_sample(data_key, sample_id)
            sample_token_ids = sample["text"]
            chunk_length = cls.args.retro_gpt_chunk_length
            chunk_start_idx = chunk_id * chunk_length
            chunk_end_idx = min(sample_token_ids.shape[0],
                                chunk_start_idx + chunk_length)
            chunk_token_ids = sample_token_ids[chunk_start_idx:chunk_end_idx]
            neighbor_token_ids = sample["neighbor_tokens"][chunk_id]
            return {
                "chunk_tokens" : chunk_token_ids,
                "neighbor_tokens" : neighbor_token_ids,
            }
        except:
            return None

    @classmethod
    def print_neighbor_texts(cls, sample_id, chunk_id, data_key="train"):
        tokens = cls.get_neighbor_tokens(sample_id, chunk_id, data_key)
        print("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
        try:
            print("PRETRAINING CHUNK:")
            print("  - %s" % shorten_str(cls.gpt_to_text(tokens["chunk_tokens"]), 150))
            print("NEIGHBOR_CHUNKS:")
            for token_ids in tokens["neighbor_tokens"]:
                print("  - %s" % shorten_str(cls.gpt_to_text(token_ids), 150))
        except:
            print("<no neighbors for sample %d>" % sample_id)

    ##############################################
    # usage.
    ##############################################

    @classmethod
    def print_usage(cls):
        '''Print usage.'''

        print()
        print("+++++++++++++++++++++++++++++++++++++++++++++++++++")
        print("examples ... [ *note*: 'db' = chunk db; 'pt' = pretraining corpus. ]")
        print("+++++++++++++++++++++++++++++++++++++++++++++++++++")

        print()
        print("~~~~ indexed datasets ~~~~")
        print("retro.get_db_num_indexed_datasets() : %s" %
              cls.get_db_num_indexed_datasets())
        print("retro.get_db_indexed_dataset_infos() :")
        for i, (ratio, prefix) in enumerate(cls.get_db_indexed_dataset_infos()):
            print("  %s(%f, %s)%s" % (
                "[" if i == 0 else " ",
                ratio,
                prefix,
                "]" if i == len(cls.db_indexed_dataset_infos) - 1 else ",",
            ))

        print()
        print("~~~~ counts ~~~~")
        print("retro.get_db_num_chunks : %d." % cls.get_db_num_chunks())

        print()
        for sq_key in ("sample", "chunk"):
            for data_key in ("train", "valid"): # test?
                print("retro.get_pt_num_%ss('%s') : %d." % (
                    sq_key, data_key,
                    getattr(cls, f"get_pt_num_{sq_key}s")(data_key)))

        print()
        print("~~~~ tokens, text ~~~~")
        print("retro.get_db_chunk_gpt(chunk_id) : %s" %
              shorten_str(str(retro.get_db_chunk_gpt(0)), 50))
        print("retro.get_db_chunk_bert(chunk_id) : %s" %
              shorten_str(str(retro.get_db_chunk_bert(0)), 50))
        print("retro.get_db_chunk_text(chunk_id) : %s" %
              shorten_str(retro.get_db_chunk_text(0).strip(), 50))
        print("retro.get_db_chunk_and_continuation_text(chunk_id) :")
        for i, t in enumerate(retro.get_db_chunk_and_continuation_text(0)):
            print("  %s'%s'%s" % (
                "[" if i == 0 else " ",
                shorten_str(t.strip().replace("\n", " "), 50),
                "]" if i == 1 else ",",
            ))

        sample = cls.get_pt_sample("train", 0)
        sample_chunk_id = sample["neighbor_tokens"].shape[0] // 2
        sample_neighbor_id = 0
        print()
        print("retro.get_pt_sample('train', sample_id) :")
        print("  {")
        for k, v in sample.items():
            print("    '%s' : %s" % (k, shorten_str(str(v), 50)))
        print("  }")

        print()
        print("(e.g., sample = retro.get_pt_sample(...))")
        print()
        print("  sample['text'].shape : %s" % str(sample["text"].shape))
        print("  sample['neighbor_tokens'].shape : %s" %
              str(sample["neighbor_tokens"].shape))
        print("  sample['text'] : %s" %
              shorten_str(str(sample["text"]), 50))
        print("  sample['neighbor_tokens'][17][1] : %s" % shorten_str(
            str(sample["neighbor_tokens"][sample_chunk_id][sample_neighbor_id]), 50))
        print("  retro.gpt_to_text(sample['text']) : %s" %
              shorten_str(cls.gpt_to_text(sample["text"]), 50))
        print("  retro.gpt_to_text(sample['neighbor_tokens']) : %s" %
              shorten_str(cls.gpt_to_text(
                  sample["neighbor_tokens"][sample_chunk_id][sample_neighbor_id]), 50))

        print("+++++++++++++++++++++++++++++++++++++++++++++++++++")
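For orientation, here is a minimal interactive sketch of how this new CLI class might be used. It assumes the `tools.retro.cli` package re-exports `retro` the same way `__main__.py` consumes it, and the workdir path below is a placeholder that must already contain args.json and the preprocessed Retro data.

# Minimal usage sketch (workdir path is hypothetical).
from tools.retro.cli import retro

retro.init("/path/to/retro/workdir")   # loads args, tokenizers, chunk DB, pretraining datasets
print(retro.get_db_num_chunks())       # number of chunks in the retrieval database
print(retro.get_db_chunk_text(0))      # first DB chunk, detokenized to text
retro.print_neighbor_texts(0, 0)       # chunk 0 of pretraining sample 0, plus its retrieved neighbors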
tools/retro/db/build.py  (view file @ 3aca1415)

@@ -24,11 +24,13 @@ from tools.retro.external_libs import h5py
 from tools.retro.utils import get_gpt_tokenizer, get_bert_tokenizer

 from .utils import (
-    get_individual_db,
     get_indexed_dataset_infos,
+    get_indexed_dataset_infos_path,
     get_individual_db_dir,
+    get_individual_chunk_db,
+    get_individual_doc_offsets,
     get_merged_dataset,
     get_merged_db_path_map,
-    get_train_doc_chunk_map_dir,
     save_indexed_dataset_infos,
 )
@@ -52,7 +54,7 @@ def init_indexed_dataset_infos():
         prefix = args.data_path[i + 1]
         path = prefix + ".bin"
         name = os.path.basename(prefix)
-        assert os.path.exists(path)
+        assert os.path.exists(path), "couldn't find '%s'." % path
         infos.append({
             "ratio" : ratio,
             "prefix" : prefix,
@@ -114,6 +116,7 @@ def build_partial_db(
     # Iterate documents & parse chunks.
     chunk_db_valid = []
     chunk_db_invalid = []
+    doc_size_map = {}
     for doc_id in pbar:

         # Progress description.
@@ -130,7 +133,7 @@ def build_partial_db(
         # Remove EOD token.
         doc = indexed_dataset.get(doc_id)
-        if doc[-1].item() == tokenizers.gpt.eod_id:
+        if doc[-1].item() == tokenizers.gpt.eod:
             doc = doc[:-1]
         doc_len = len(doc)
@@ -140,6 +143,7 @@ def build_partial_db(
                             for s in chunk_start_idxs]

         # Re-tokenize each chunk to Bert/Wordpiece (empty bert -> 'invalid').
+        doc_size_map[doc_id] = 0
         for i, chunk_start_idx in enumerate(chunk_start_idxs):

             # Re-tokenize.
@@ -149,13 +153,15 @@ def build_partial_db(
                 offset=chunk_start_idx,
                 length=chunk_end_idx - chunk_start_idx,
             )
-            text = tokenizers.gpt.detokenize(gpt_token_ids)
+            text = tokenizers.gpt.detokenize(gpt_token_ids.tolist())
             bert_token_ids = tokenizers.bert.tokenize(text)

             # 'Valid' for non-empty Bert chunks; 'invalid' otherwise.
-            _chunk_db = chunk_db_invalid \
-                if len(bert_token_ids) == 0 else \
-                chunk_db_valid
+            if len(bert_token_ids) == 0:
+                _chunk_db = chunk_db_invalid
+            else:
+                _chunk_db = chunk_db_valid
+                doc_size_map[doc_id] += 1
             _chunk_db.append((
                 doc_id,
                 chunk_start_idx,
@@ -163,7 +169,7 @@ def build_partial_db(
                 len(bert_token_ids),
             ))

-    return proc_id, chunk_db_valid, chunk_db_invalid
+    return proc_id, chunk_db_valid, chunk_db_invalid, doc_size_map


 def build_individual_db(dataset_idx, n_datasets, dataset_info, tokenizers):
@@ -181,9 +187,10 @@ def build_individual_db(dataset_idx, n_datasets, dataset_info, tokenizers):
     # Missing db blocks.
     n_missing_world, missing_db_blocks = get_missing_blocks_by_rank(
         db_dir,
-        len(indexed_dataset.doc_idx) - 1,
+        len(indexed_dataset),
         args.retro_doc_block_size,
-        validate=lambda f : f["chunks_valid"].shape[1] == 4)
+        validate=lambda f : f["chunks_valid"].shape == (0,) \
+            or f["chunks_valid"].shape[1] == 4)

     # Prevent missing-path-write race condition.
     torch.distributed.barrier()
@@ -209,6 +216,8 @@ def build_individual_db(dataset_idx, n_datasets, dataset_info, tokenizers):
         if block is not None:

+            db_path = block["path"]
+
             # Build partial dbs.
             print_rank_0(' > build partial dbs.')
             futures = []
@@ -240,15 +249,27 @@ def build_individual_db(dataset_idx, n_datasets, dataset_info, tokenizers):
             # Convert to numpy.
             print_rank_0(' > converting chunk db to numpy.')
-            chunk_db_valid = np.array(chunk_db_valid)
-            chunk_db_invalid = np.array(chunk_db_invalid)
+            chunk_db_valid = np.array(chunk_db_valid, dtype="uint32")
+            chunk_db_invalid = np.array(chunk_db_invalid, dtype="uint32")
+
+            # Document offsets.
+            doc_sizes = [(d, s)
+                         for partial_chunk_db in partial_chunk_dbs
+                         for d, s in partial_chunk_db[3].items()]
+            doc_sizes.sort(key=lambda item : item[0])
+            doc_offsets = np.cumsum([item[1] for item in doc_sizes]) \
+                            .astype("uint64")
+            doc_offsets = np.stack((
+                np.array([item[0] for item in doc_sizes], dtype="uint64"),
+                doc_offsets), axis=1)

             # Save DB.
             print_rank_0(" > saving individual db.")
-            f = h5py.File(block["path"], "w")
-            dset = f.create_dataset("chunks_valid", data=chunk_db_valid)
-            dset = f.create_dataset("chunks_invalid", data=chunk_db_invalid)
-            f.close()
+            with h5py.File(db_path, "w") as f:
+                dset = f.create_dataset("chunks_valid", data=chunk_db_valid)
+                dset = f.create_dataset("chunks_invalid", data=chunk_db_invalid)
+                dset = f.create_dataset("doc_offsets", data=doc_offsets)

         # Wait for all ranks to finish block.
         print_rank_0(" > waiting for all ranks to finish block.")
@@ -292,14 +313,16 @@ def update_chunk_counts(indexed_dataset_infos):
     if torch.distributed.get_rank() != 0:
         return

+    # Data ratio sum (for setting index training chunks).
+    data_ratio_sum = sum([d["ratio"] for d in indexed_dataset_infos])
+
     # Training split size (split at document level).
     train_fraction = float(args.split.split(",")[0]) / 100
     assert train_fraction > 0 and train_fraction <= 1

     # Set n_chunks (including n_chunks_sampled for unambiguity).
     print_rank_0(" > compute n_chunks.")
-    for ds_index, ds_info in \
-        enumerate(tqdm(indexed_dataset_infos, "count_chunks")):
+    for ds_index, ds_info in enumerate(indexed_dataset_infos):

         db_dir = ds_info["db_dir"]
         db_paths = sorted(glob.glob(db_dir + "/*.hdf5"))
@@ -310,16 +333,17 @@ def update_chunk_counts(indexed_dataset_infos):
         ds_info["n_chunks"] = 0 # previously, 'n_chunks_valid'
         ds_info["n_chunks_train"] = 0
         ds_info["n_chunks_invalid"] = 0
-        for db_path in db_paths:
-            with h5py.File(db_path, "r") as f:
+        for db_path in tqdm(db_paths, "%d/%d, %s" % (
+                ds_index, len(indexed_dataset_infos), ds_info["name"])):
+            with h5py.File(db_path, "r") as f:
                 ds_info["n_chunks"] += len(f["chunks_valid"])
                 ds_info["n_chunks_invalid"] += len(f["chunks_invalid"])
                 ds_info["n_chunks_train"] += \
                     (np.copy(f["chunks_valid"][:, 0]) < ds_info["n_docs_train"]) \
                     .sum().item()

-        ds_info["n_chunks_sampled"] = \
-            int(round(args.retro_nchunks_sampled * ds_info["ratio"]))
+        ds_info["n_chunks_sampled"] = int(args.retro_index_ntrain *
+                                          ds_info["ratio"] / data_ratio_sum)

         # Verify counts.
         assert ds_info["n_chunks_train"] <= ds_info["n_chunks"], \
@@ -339,15 +363,14 @@ def merge_dbs(indexed_dataset_infos, db_type):
     print(" > build %s chunk db." % db_type)

     # Count chunks.
-    if db_type == "full":
-        raise Exception("deprecated; use 'train' or 'sampled'.")
-        n_chunks_key = "n_chunks"
-    elif db_type == "sampled":
+    if db_type == "sampled":
         n_chunks_key = "n_chunks_sampled"
+        n_docs_key = None
     elif db_type == "train":
         n_chunks_key = "n_chunks_train"
+        n_docs_key = "n_docs_train"
     elif db_type == "valid":
-        pass
+        n_docs_key = None
     else:
         raise Exception("handle db_type '%s'." % db_type)
@@ -356,6 +379,8 @@ def merge_dbs(indexed_dataset_infos, db_type):
                        for m in indexed_dataset_infos)
     else:
         n_chunks = sum(m[n_chunks_key] for m in indexed_dataset_infos)
+        n_docs = None if n_docs_key is None else \
+            sum(m[n_docs_key] for m in indexed_dataset_infos)

     # DB path.
     db_path = get_merged_db_path_map()[db_type]
@@ -375,10 +400,10 @@ def merge_dbs(indexed_dataset_infos, db_type):
         except Exception as e:
             if isinstance(e, OSError):
-                os.remove(full_db_path)
+                os.remove(db_path)
             elif isinstance(e, KeyError):
                 f.close()
-                os.remove(full_db_path)
+                os.remove(db_path)
             else:
                 raise e
@@ -389,121 +414,60 @@ def merge_dbs(indexed_dataset_infos, db_type):
     f = h5py.File(db_path, "w")

     # Initialize output arrays.
-    merged_db = f.create_dataset("chunks", (n_chunks, 5), dtype="i8")
+    merged_chunk_db = \
+        f.create_dataset("chunks", (n_chunks, 5), dtype="uint32")
+    merged_doc_offsets = None if n_docs_key is None else \
+        f.create_dataset("doc_offsets", (n_docs, 3), dtype="uint64")
     n_written = f.create_dataset("n_written", (1,), dtype="uint64")
     n_written[0] = 0

     # Iterate indexed datasets & collect chunks.
-    start_index = 0
+    chunk_start_index = 0
+    doc_start_index = 0
+    doc_start_offset = 0
     for ds_idx, ds_info in enumerate(indexed_dataset_infos):
         print(" > merging dbs; '%s', dataset %d / %d ... '%s'." %
               (db_type, ds_idx, len(indexed_dataset_infos), ds_info["name"]))
-        individual_db = get_individual_db(ds_idx, ds_info)
+        individual_chunk_db = get_individual_chunk_db(ds_idx, ds_info)
+        individual_doc_offsets = None if n_docs_key is None else \
+            get_individual_doc_offsets(ds_idx, ds_info)

         if db_type == "valid":
-            individual_db = individual_db[ds_info["n_chunks_train"]:]
+            individual_chunk_db = \
+                individual_chunk_db[ds_info["n_chunks_train"]:]
+            if n_docs_key is None:
+                individual_doc_offsets = None
+            else:
+                train_doc_offset = \
+                    individual_doc_offsets[ds_info["n_docs_train"] - 1, 2]
+                individual_doc_offsets = \
+                    np.copy(individual_doc_offsets[ds_info["n_docs_train"]:])
+                individual_doc_offsets[:, 2] -= train_doc_offset
+
+                print("~~~")
+                print(individual_doc_offsets)
+                print(train_doc_offset)
+                raise Exception("test me.")
         else:
-            individual_db = individual_db[:ds_info[n_chunks_key]]
-
-        merged_db[start_index:start_index+len(individual_db)] = individual_db
-        start_index += len(individual_db)
-        n_written[0] = start_index
+            individual_chunk_db = \
+                individual_chunk_db[:ds_info[n_chunks_key]]
+            individual_doc_offsets = None if n_docs_key is None else \
+                np.copy(individual_doc_offsets[:ds_info[n_docs_key]])
+
+        merged_chunk_db[chunk_start_index:chunk_start_index+len(individual_chunk_db)] = \
+            individual_chunk_db
+        chunk_start_index += len(individual_chunk_db)
+        n_written[0] = chunk_start_index
+        if n_docs_key is not None:
+            individual_doc_offsets[:, 2] += doc_start_offset
+            doc_end_index = doc_start_index + individual_doc_offsets.shape[0]
+            merged_doc_offsets[doc_start_index:doc_end_index] = \
+                individual_doc_offsets
+            doc_start_index = doc_end_index
+            doc_start_offset = individual_doc_offsets[-1, 2].item()

     f.close()


-def get_partial_banned_chunk_map(proc_id, db_path, chunk_range_info):
-    '''Build partial mapping of {(dataset_id,doc_id):[chunk_ids]}.
-
-    In this method, only chunks within the range (start_chunk_id, end_chunk_id]
-    are processed.'''
-
-    start_chunk_id = chunk_range_info["start"]
-    end_chunk_id = chunk_range_info["end"]
-    output_path = chunk_range_info["path"]
-
-    # Skip, if output file exists.
-    if os.path.exists(output_path):
-        return
-
-    # Chunk subset.
-    with h5py.File(db_path) as f:
-        sub_chunk_db = np.copy(f["chunks"][start_chunk_id:end_chunk_id, :2])
-
-    # Map docs to chunks.
-    banned_chunk_map = defaultdict(list)
-    for rel_chunk_id, (dataset_id, doc_id) in enumerate(tqdm(
-            sub_chunk_db,
-            "map banned docs, proc %d" % proc_id,
-            total=sub_chunk_db.shape[0],
-    )):
-        chunk_id = start_chunk_id + rel_chunk_id
-        banned_chunk_map["%d,%d" % (dataset_id.item(), doc_id.item())] \
-            .append(chunk_id)
-
-    # Save output.
-    with open(output_path, "w") as f:
-        json.dump(banned_chunk_map, f)
-
-
-def build_doc_chunk_map(indexed_dataset_infos, db_type):
-    '''Build mapping of {(dataset_id,doc_id):[chunk_ids]}.'''
-
-    if torch.distributed.get_rank() != 0:
-        return
-
-    print(" > build %s doc-chunk map." % db_type)
-
-    n_procs = 128
-
-    # Get dataset.
-    db_dataset = get_merged_dataset(db_type, indexed_dataset_infos)
-
-    # Sub-ranges for parallel processing.
-    n_chunks = db_dataset.chunks.shape[0]
-    n_chunks_per_proc = max(1, int(np.ceil(n_chunks / n_procs)))
-    chunk_id_starts = list(range(0, n_chunks, n_chunks_per_proc))
-    chunk_id_ranges = [(s, min(n_chunks, s + n_chunks_per_proc))
-                       for s in chunk_id_starts]
-
-    # Wrap range info with output path.
-    n_digits = int(np.ceil(np.log(n_chunks) / np.log(10)) + 1)
-    output_dirname = get_train_doc_chunk_map_dir()
-    chunk_range_infos = [{
-        "start" : start_id,
-        "end" : end_id,
-        "path" : os.path.join(output_dirname, "%s-%s.json" % (
-            str(start_id).zfill(n_digits),
-            str(end_id).zfill(n_digits),
-        )),
-    } for start_id, end_id in chunk_id_ranges]
-
-    # Build doc-chunk map.
-    print_rank_0("build doc-chunk-map.")
-    with ProcessPoolExecutor(max_workers=n_procs) as executor:
-
-        # Build partial chunk maps.
-        futures = []
-        for proc_id, chunk_range_info in enumerate(chunk_range_infos):
-
-            if os.path.exists(chunk_range_info["path"]):
-                continue
-
-            # Submit job.
-            futures.append(executor.submit(
-                get_partial_banned_chunk_map,
-                proc_id,
-                db_dataset.db_path,
-                chunk_range_info,
-            ))
-
-        # Wait for processes to finish.
-        banned_chunk_paths = []
-        for finished_idx, future in enumerate(as_completed(futures)):
-            print("finished %d / %d." % (finished_idx, n_procs))
-            future.result()
-
-
 def build_db():
     '''Extract token chunks from each indexed dataset.
@@ -521,14 +485,13 @@ def build_db():
     if torch.distributed.get_rank() != 0:
         return

-    # Update n_chunks.
-    update_chunk_counts(indexed_dataset_infos)
+    # Update n_chunks & save indexed dataset infos.
+    if not os.path.exists(get_indexed_dataset_infos_path()):
+        update_chunk_counts(indexed_dataset_infos)
+        save_indexed_dataset_infos(indexed_dataset_infos)
+    indexed_dataset_infos = get_indexed_dataset_infos()

     # Merge dbs.
     merge_dbs(indexed_dataset_infos, "sampled")
     merge_dbs(indexed_dataset_infos, "train")
     merge_dbs(indexed_dataset_infos, "valid")
-
-    build_doc_chunk_map(indexed_dataset_infos, "train")
-
-    # Save (fully annotated) indexed dataset infos.
-    save_indexed_dataset_infos(indexed_dataset_infos)
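The new doc-offset bookkeeping above turns each document's chunk count into a cumulative offset, so a document's chunks can later be sliced out of the flat chunk array. A toy illustration (not repository code; numbers are made up):

# doc_id -> number of valid chunks, as accumulated in doc_size_map.
import numpy as np

doc_size_map = {7: 3, 2: 5, 9: 1}
doc_sizes = sorted(doc_size_map.items(), key=lambda kv: kv[0])
offsets = np.cumsum([s for _, s in doc_sizes]).astype("uint64")
doc_offsets = np.stack(
    (np.array([d for d, _ in doc_sizes], dtype="uint64"), offsets), axis=1)
print(doc_offsets)
# [[2 5]
#  [7 8]
#  [9 9]]   -> doc 2 owns chunks [0, 5), doc 7 owns [5, 8), doc 9 owns [8, 9)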
tools/retro/db/dataset.py  (view file @ 3aca1415)

@@ -3,6 +3,7 @@
 import json
 import numpy as np
 import torch
+from tqdm import tqdm

 from megatron import get_args, print_rank_0
 from tools.retro.external_libs import h5py
@@ -27,9 +28,10 @@ class DBDataset(torch.utils.data.Dataset):
         self.db_path = db_path
         self.indexed_datasets = indexed_datasets
         self.chunks = chunks
+        self.doc_chunk_map = None

         self.max_chunk_length = max_chunk_length
-        self.eod_token_id = get_gpt_tokenizer().eod_id
+        self.eod_token_id = get_gpt_tokenizer().eod

     def __len__(self):
         return self.chunks.shape[0]
@@ -58,3 +60,15 @@ class DBDataset(torch.utils.data.Dataset):
             "doc_id" : doc_id,
             "text" : np.array(token_ids, dtype=np.int64),
         }

+    def load_doc_tuples(self):
+        '''Load the dataset & document ids.
+
+        Load the dataset id & document id of each chunk in the database, to
+        be used for causality filtering during querying.
+        '''
+        self.doc_tuples = np.zeros(shape=(len(self), 2), dtype="uint32")
+        block_size = int(1e6)
+        for start_idx in tqdm(range(0, len(self), block_size)):
+            end_idx = min(len(self), start_idx + block_size)
+            self.doc_tuples[start_idx:end_idx] = self.chunks[start_idx:end_idx, :2]
tools/retro/db/utils.py  (view file @ 3aca1415)

@@ -57,14 +57,14 @@ def get_indexed_dataset_infos():
 def get_individual_db_dir(name):
     '''Individual DB's directory.'''
-    return os.path.join(get_base_db_workdir(), "individual", name, "db")
+    return os.path.join(get_base_db_workdir(), "individual", name)


-def get_individual_db(ds_id, ds_info):
+def get_individual_chunk_db(ds_id, ds_info):
     '''Load individual dataset's chunk DB.'''
     db_paths = sorted(glob.glob(ds_info["db_dir"] + "/*hdf5"))
     # *Note*: convert to dataset, rather than copying to memory.
-    db = np.zeros((ds_info["n_chunks"], 5), dtype="i8")
+    db = np.zeros((ds_info["n_chunks"], 5), dtype="uint32")
     db[:, 0] = ds_id
     start_idx = 0
     for db_path in db_paths:
@@ -79,6 +79,27 @@ def get_individual_db(ds_id, ds_info):
     return db


+def get_individual_doc_offsets(ds_id, ds_info):
+    '''Load individual dataset's chunk DB.'''
+    paths = sorted(glob.glob(ds_info["db_dir"] + "/*hdf5"))
+    # *Note*: convert to dataset, rather than copying to memory.
+    doc_offsets = np.zeros((ds_info["n_docs"], 3), dtype="uint64")
+    doc_offsets[:, 0] = ds_id
+    start_idx = 0
+    start_offset = 0
+    for path in paths:
+        with h5py.File(path) as f:
+            current_doc_offsets = np.copy(f["doc_offsets"])
+            current_doc_offsets[:, 1] += start_offset
+            current_ndocs = current_doc_offsets.shape[0]
+            doc_offsets[start_idx:(start_idx+current_ndocs), 1:] = \
+                current_doc_offsets
+            start_idx += current_ndocs
+            start_offset = current_doc_offsets[-1, 1].item()
+    return doc_offsets
+
+
 def get_merged_db_path_map():
     '''Paths to merged datasets.'''
     base_dir = get_base_db_workdir()
@@ -120,28 +141,3 @@ def get_merged_train_dataset(indexed_dataset_infos=None):
 def get_merged_valid_dataset(indexed_dataset_infos=None):
     return get_merged_dataset("valid", indexed_dataset_infos)
-
-
-def get_train_doc_chunk_map_dir():
-    dirname = os.path.join(get_base_db_workdir(), "merged", "train_doc_chunk_map")
-    os.makedirs(dirname, exist_ok=True)
-    return dirname
-
-
-def get_train_doc_chunk_map():
-    paths = sorted(glob.glob(get_train_doc_chunk_map_dir() + "/*.json"))
-    doc_map = defaultdict(set)
-    for path in tqdm(paths, "load train doc maps"):
-
-        # Read file.
-        with open(path) as f:
-            crnt_doc_map = json.load(f)
-
-        # Add to doc map.
-        for key, chunk_ids in crnt_doc_map.items():
-            key = tuple(int(i) for i in key.split(","))
-            doc_map[key].update(chunk_ids)
-
-    return doc_map
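For reference, a toy round-trip of the per-block HDF5 layout that build.py now writes and these utilities read back. This is not repository code; shapes and the file path are made up, and the column meanings are inferred from the append/validate logic above (doc id, chunk start, chunk end, Bert length for "chunks_valid"; doc id and cumulative chunk count for "doc_offsets").

import h5py
import numpy as np

chunks_valid = np.zeros((10, 4), dtype="uint32")   # 4 columns, per the validate check in build.py
doc_offsets = np.zeros((3, 2), dtype="uint64")     # doc id, cumulative chunk count

with h5py.File("/tmp/example_block.hdf5", "w") as f:
    f.create_dataset("chunks_valid", data=chunks_valid)
    f.create_dataset("chunks_invalid", data=np.zeros((0, 4), dtype="uint32"))
    f.create_dataset("doc_offsets", data=doc_offsets)

with h5py.File("/tmp/example_block.hdf5", "r") as f:
    print(f["chunks_valid"].shape, f["doc_offsets"].shape)   # (10, 4) (3, 2)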
tools/retro/examples/get_dataset_configs.sh  (deleted, 100644 → 0; view file @ 0024a5c6)

#!/bin/bash

# Small English Wikipedia dataset (~2M chunks).
get_wiki_tiny_config() {
    RETRO_INDEX_STR="IVF4096_HNSW4,Flat"
    RETRO_GPT_TRAIN_SAMPLES=31250
    LR_DECAY_SAMPLES=2
    LR_WARMUP_SAMPLES=1
    RETRO_GPT_EVAL_INTERVAL=2000
    RETRO_GPT_EVAL_ITERS=100
    RETRO_EF_SEARCH=4
    RETRO_NPROBE=64
    DATALOADER_TYPE=cyclic
}

# English Wikipedia dataset (~67M chunks).
get_wiki_config() {
    RETRO_INDEX_STR="IVF262144_HNSW32,Flat"
    RETRO_GPT_TRAIN_SAMPLES=2037248
    LR_DECAY_SAMPLES=2
    LR_WARMUP_SAMPLES=1
    RETRO_GPT_EVAL_INTERVAL=2000
    RETRO_GPT_EVAL_ITERS=100
    RETRO_EF_SEARCH=16
    RETRO_NPROBE=4096
    DATALOADER_TYPE=cyclic
}

# Full corpus (~5B chunks).
get_corpus_config() {
    RETRO_INDEX_STR="OPQ32_256,IVF4194304_HNSW32,PQ32"
    RETRO_GPT_TRAIN_SAMPLES=192000000
    LR_DECAY_SAMPLES=166400000
    LR_WARMUP_SAMPLES=162761
    RETRO_GPT_EVAL_INTERVAL=2000
    RETRO_GPT_EVAL_ITERS=50
    RETRO_EF_SEARCH=32
    RETRO_NPROBE=4096
    DATALOADER_TYPE=single
}
tools/retro/examples/get_preprocess_cmd.sh  (deleted, 100644 → 0; view file @ 0024a5c6)

#!/bin/bash

# Build preprocessing command for Retro.

set -u

DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )

################ Required environment variables. ################
# Required environment variables:
#   - REPO_DIR : Root directory of Megatron codebase.
#   - RETRO_WORKDIR : Root directory of this Retro project's processed data. (For
#     example, this project directory might be for a blended dataset, while
#     another project directory might be for just a Wikipedia dataset, and
#     another for just Book Corpus data, etc.) This project directory will
#     contain a complete set of processed data, including the retrieval
#     database, search index, and pretraining neighbors.
#   - RETRO_TASKS : One of 'build', 'db-build', 'index-build', or
#     'pretraining-query-neighbors'. See 'Retro tasks' below for task
#     descriptions.
#   - DATA_BLEND_SCRIPT : Path to blended dataset definition file.
#   - GPT_VOCAB_FILE : GPT vocab file.
#   - GPT_MERGE_FILE : GPT merge file.
#   - GPT_TOKENIZER : GPT tokenizer type (e.g., GPT2BPETokenizer)
#   - BERT_LOAD_PATH : Bert checkpoint directory.
#   - BERT_VOCAB_FILE : Bert vocab file.
#   - BERT_TOKENIZER : Bert tokenizer type (e.g., BertWordPieceLowerCase,
#     BertWordPieceCase).
#   - BERT_EMBEDDER_TYPE : One of 'megatron' or 'huggingface'.
#   - EXTRA_ARGS : Extra arguments (else, leave empty).

################ Data blend. ################
. ${DATA_BLEND_SCRIPT}
DATA_PATH=${DATA_BLEND}

################ Retro setup. ################
RETRO_GPT_SEQ_LENGTH=2048
RETRO_GPT_CHUNK_LENGTH=64
RETRO_GPT_MICRO_BATCH_SIZE=1 # *8
RETRO_GPT_GLOBAL_BATCH_SIZE=256
RETRO_NCHUNKS_SAMPLED=300000000

################ Retro tasks. ################
# The '--retro-tasks' argument is a comma-separated list of tasks to run, in
# sequential order. For a quick start, simply set this to 'build' to run the
# entire preprocessing pipeline. For finer control, you may specify the list of
# tasks to run. This is desirable for tuning computational resources. For
# example, training the search index is relatively fast and utilizes GPUs,
# while querying the search index is relatively slow, CPU-only, and memory
# intensive (i.e., multiple populated search indexes are loaded simultaneously).

# *Note* : Once the task(s) below have been completed -- by running either
# 1) 'build', or 2) the sequential combination of 'db-build', 'index-build',
# and 'pretraining-query-neighbors' -- we are ready to pretrain Retro by
# calling pretrain_retro.py.

# ---- Option #1 : Run entire pipeline. ----
# RETRO_TASKS="build" # (*note*: default tasks)

# ---- Option #2 : Run specific stages. ----
# *Note*: Run the following stages in the given order. Optionally, tune your
# cluster setup for each stage, as described above.
# RETRO_TASKS="db-build" # ....................... run 1st
# RETRO_TASKS="index-build" # .................... run 2nd
# RETRO_TASKS="pretraining-query-neighbors" # .... run 3rd

################ Megatron args. ################
MEGATRON_ARGS=" \
    --seed 1234 \
    --distributed-timeout-minutes 600 \
    --tokenizer-type ${BERT_TOKENIZER} \
    --tensor-model-parallel-size 1 \
    --pipeline-model-parallel-size 1 \
    --num-layers 24 \
    --hidden-size 1024 \
    --num-attention-heads 16 \
    --micro-batch-size ${RETRO_GPT_MICRO_BATCH_SIZE} \
    --global-batch-size ${RETRO_GPT_GLOBAL_BATCH_SIZE} \
    --seq-length 512 \
    --max-position-embeddings 512 \
    --train-samples ${RETRO_GPT_TRAIN_SAMPLES} \
    --load ${BERT_LOAD_PATH} \
    --exit-on-missing-checkpoint \
    --no-load-optim \
    --data-path ${DATA_PATH} \
    --vocab-file ${BERT_VOCAB_FILE} \
    --data-impl mmap \
    --split 98,2,0 \
    --distributed-backend nccl \
    --lr 0.0001 \
    --lr-decay-style linear \
    --min-lr 1.0e-5 \
    --lr-decay-samples ${LR_DECAY_SAMPLES} \
    --lr-warmup-samples ${LR_WARMUP_SAMPLES} \
    --weight-decay 1e-2 \
    --clip-grad 1.0 \
    --eval-interval ${RETRO_GPT_EVAL_INTERVAL} \
    --eval-iters ${RETRO_GPT_EVAL_ITERS} \
    --fp16 \
    --DDP-impl local \
    --dataloader-type ${DATALOADER_TYPE} \
    --no-data-sharding \
    --no-gradient-accumulation-fusion \
    --no-async-tensor-model-parallel-allreduce \
"

################ Retro args. ################
RETRO_ARGS=" \
    --bert-embedder-type ${BERT_EMBEDDER_TYPE} \
    --output-bert-embeddings \
    \
    --retro-gpt-vocab-file ${GPT_VOCAB_FILE} \
    --retro-gpt-merge-file ${GPT_MERGE_FILE} \
    --retro-gpt-tokenizer-type ${GPT_TOKENIZER} \
    --retro-gpt-seq-length ${RETRO_GPT_SEQ_LENGTH} \
    --retro-gpt-chunk-length ${RETRO_GPT_CHUNK_LENGTH} \
    --retro-bert-vocab-file ${BERT_VOCAB_FILE} \
    --retro-bert-tokenizer-type ${BERT_TOKENIZER} \
    \
    --retro-tasks ${RETRO_TASKS} \
    --retro-index-str ${RETRO_INDEX_STR} \
    --retro-ef-search ${RETRO_EF_SEARCH} \
    --retro-nprobe ${RETRO_NPROBE} \
    \
    --retro-workdir ${RETRO_WORKDIR} \
    --retro-nchunks-sampled ${RETRO_NCHUNKS_SAMPLED} \
    \
    --retro-return-doc-ids \
"

################ Command. ################
RETRO_PREPROCESS_CMD=" \
    ./tools/retro/main.py \
    ${MEGATRON_ARGS} \
    ${RETRO_ARGS} \
    ${EXTRA_ARGS} \
"
tools/retro/examples/preprocess_data.sh  (view file @ 3aca1415)

 #!/bin/bash

 set -u
+unset NCCL_DEBUG
+
+NPROCS=8 # NPROCS must be <= number of GPUs.

 ######## Megatron, Retro dirs. ########

-set_current_dir() {
-    DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
-}
+REPO_DIR="<path/to/megatron/repo>"
+RETRO_WORKDIR="<path/to/retro/data/directory>"

-################ Dataset configs. ################
-# This script contains methods to customize arguments to specific dataset
-# types. Customize this script as needed for your datasets.
-set_current_dir
-. $DIR/get_dataset_configs.sh
+######## Task (e.g., db, index, query). ########

-################ Environment variables. ################
-# *Note*: See 'Required environment variables' in 'get_preprocess_cmd.sh' for
-# a description of the required environment variables. These variables can be
-# set however a user would like. In our setup, we use another bash script
-# (location defined by $RETRO_ENV_VARS) that sets all the environment variables
-# at once.
-. $RETRO_ENV_VARS
+RETRO_TASKS="db-build"
+# RETRO_TASKS="index-train"
+# RETRO_TASKS="index-add"
+# RETRO_TASKS="query-pretraining-neighbors"

-######## Environment vars. ########
-set_current_dir
-. ${DIR}/get_preprocess_cmd.sh
+######## Data. ########

-echo "~~~~~~~~~~~~~~~~~~~~~~~~~~"
-echo "DIR = '$DIR'."
-echo "RETRO_PREPROCESS_CMD = '$RETRO_PREPROCESS_CMD'."
-echo "~~~~~~~~~~~~~~~~~~~~~~~~~~"
+DATA_BLEND="<see --data-path in arguments.py>"

+######## Index. ########
+
+RETRO_INDEX_STR="OPQ32_64,IVF65536_HNSW8,PQ32"
+RETRO_INDEX_NTRAIN=1000000
+RETRO_INDEX_TRAIN_LOAD_FRACTION=0.97
+RETRO_INDEX_ADD_LOAD_FRACTION=0.95
+
+######## GPT. ########
+
+RETRO_GPT_SEED=1234
+RETRO_GPT_SPLIT="98,2,0"
+RETRO_GPT_DATA_PATH=${DATA_BLEND}
+RETRO_GPT_DATA_IMPL=mmap
+RETRO_GPT_DATALOADER_TYPE=single
+RETRO_GPT_EVAL_INTERVAL=2000
+RETRO_GPT_EVAL_ITERS=50
+RETRO_GPT_TRAIN_SAMPLES=200000
+RETRO_GPT_LR_DECAY_SAMPLES=175000
+RETRO_GPT_LR_WARMUP_SAMPLES=10000
+RETRO_GPT_SEQ_LENGTH=512
+RETRO_GPT_GLOBAL_BATCH_SIZE=256
+RETRO_GPT_CHUNK_LENGTH=64
+
+######## Query. ########
+
+RETRO_QUERY_NUM_NEIGHBORS_QUERY=200
+RETRO_QUERY_NUM_NEIGHBORS_SAVE=20
+RETRO_QUERY_EF_SEARCH=32
+RETRO_QUERY_NPROBE=4096
+
+######## Args. ########
+
+ARGS=" \
+    --distributed-timeout-minutes 600 \
+    --tensor-model-parallel-size 1 \
+    --pipeline-model-parallel-size 1 \
+    --num-layers 24 \
+    --hidden-size 1024 \
+    --num-attention-heads 16 \
+    --micro-batch-size 1 \
+    --global-batch-size ${RETRO_GPT_GLOBAL_BATCH_SIZE} \
+    --seq-length 512 \
+    --max-position-embeddings 512 \
+    --load <path/to/bert/checkpoint> \
+    --exit-on-missing-checkpoint \
+    --no-load-optim \
+    --data-path ${RETRO_GPT_DATA_PATH} \
+    --tokenizer-type BertWordPieceLowerCase \
+    --vocab-file <path/to/bert/vocab> \
+    --data-impl ${RETRO_GPT_DATA_IMPL} \
+    --split ${RETRO_GPT_SPLIT} \
+    --distributed-backend nccl \
+    --lr 0.0001 \
+    --lr-decay-style linear \
+    --min-lr 1.0e-5 \
+    --train-samples ${RETRO_GPT_TRAIN_SAMPLES} \
+    --lr-decay-samples ${RETRO_GPT_LR_DECAY_SAMPLES} \
+    --lr-warmup-samples ${RETRO_GPT_LR_WARMUP_SAMPLES} \
+    --weight-decay 1e-2 \
+    --clip-grad 1.0 \
+    --eval-interval ${RETRO_GPT_EVAL_INTERVAL} \
+    --eval-iters ${RETRO_GPT_EVAL_ITERS} \
+    --fp16 \
+    --DDP-impl local \
+    --dataloader-type ${RETRO_GPT_DATALOADER_TYPE} \
+    --no-data-sharding \
+    --no-gradient-accumulation-fusion \
+    --no-async-tensor-model-parallel-allreduce \
+    --bert-embedder-type megatron \
+    --output-bert-embeddings \
+    \
+    --retro-workdir ${RETRO_WORKDIR} \
+    --retro-tasks ${RETRO_TASKS} \
+    --retro-return-doc-ids \
+    --retro-bert-vocab-file <path/to/bert/vocab> \
+    --retro-bert-tokenizer-type BertWordPieceLowerCase \
+    --retro-gpt-seed ${RETRO_GPT_SEED} \
+    --retro-gpt-tokenizer-type GPTSentencePieceTokenizer \
+    --retro-gpt-tokenizer-model <path/to/gpt/tokenizer/model> \
+    --retro-gpt-seq-length ${RETRO_GPT_SEQ_LENGTH} \
+    --retro-gpt-chunk-length ${RETRO_GPT_CHUNK_LENGTH} \
+    --retro-gpt-global-batch-size ${RETRO_GPT_GLOBAL_BATCH_SIZE} \
+    --retro-gpt-eval-interval ${RETRO_GPT_EVAL_INTERVAL} \
+    --retro-gpt-eval-iters ${RETRO_GPT_EVAL_ITERS} \
+    --retro-gpt-split ${RETRO_GPT_SPLIT} \
+    --retro-gpt-data-impl ${RETRO_GPT_DATA_IMPL} \
+    --retro-gpt-data-path ${RETRO_GPT_DATA_PATH} \
+    --retro-index-str ${RETRO_INDEX_STR} \
+    --retro-index-ntrain ${RETRO_INDEX_NTRAIN} \
+    --retro-index-train-load-fraction ${RETRO_INDEX_TRAIN_LOAD_FRACTION} \
+    --retro-index-add-load-fraction ${RETRO_INDEX_ADD_LOAD_FRACTION} \
+    --retro-index-no-delete-training-embeddings \
+    --retro-index-no-delete-added-codes \
+    --retro-query-num-neighbors-query ${RETRO_QUERY_NUM_NEIGHBORS_QUERY} \
+    --retro-query-num-neighbors-save ${RETRO_QUERY_NUM_NEIGHBORS_SAVE} \
+    --retro-query-ef-search ${RETRO_QUERY_EF_SEARCH} \
+    --retro-query-nprobe ${RETRO_QUERY_NPROBE} \
+"

 ######## Command. ########

-FULL_CMD=" \
-    pwd && cd ${REPO_DIR} && pwd && \
+NPROCS=8 # Number of GPUs.
+CMD=" \
+    cd ${REPO_DIR} && pwd && \
     export PYTHONPATH=$PYTHONPATH:${REPO_DIR} && \
-    python -m torch.distributed.launch \
+    python -m torch.distributed.run \
     --nproc_per_node ${NPROCS} \
     --nnodes 1 \
     --node_rank ${NODE_RANK} \
     --master_addr ${MASTER_ADDR} \
     --master_port 6000 \
-    $RETRO_PREPROCESS_CMD \
+    tools/retro/main.py ${ARGS} \
 "
 echo "~~~~~~~~~~~~~~~~~~~~~~~~~~"
-echo "FULL_CMD = '$FULL_CMD'."
+echo "CMD = '$CMD'."
 echo "~~~~~~~~~~~~~~~~~~~~~~~~~~"
-eval $FULL_CMD
+eval $CMD
tools/retro/examples/pretrain_model.sh  (view file @ 3aca1415)

 #!/bin/bash

-##################################################
-# Example script for pretraining Retro.
-##################################################

 set -u
 unset NCCL_DEBUG
 export CUDA_DEVICE_MAX_CONNECTIONS=1

+NPROCS=8 # NPROCS must be <= number of GPUs.
+
+######## GPT or Retro?. ########
+
+# 0 : GPT.
+# 1 : Retro
+
+ADD_RETRIEVER=1

-################ Dataset configs. ################
-# This script contains methods to customize arguments to specific dataset
-# types. Customize this script as needed for your datasets.
-DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
-. $DIR/get_dataset_configs.sh
+######## Megatron, Retro dirs. ########

-################ Environment variables. ################
-# *Note*: See 'Required environment variables' in 'get_preprocess_cmd.sh' for
-# a description of the required environment variables. These variables can be
-# set however a user would like. In our setup, we use another bash script
-# (location defined by $RETRO_ENV_VARS) that sets all the environment variables
-# at once.
-. $RETRO_ENV_VARS
+REPO_DIR="<path/to/megatron/repo>"
+RETRO_WORKDIR="<path/to/retro/data/directory>"

-################ Data blend. ################
-. ${DATA_BLEND_SCRIPT}
-DATA_PATH=${DATA_BLEND}
+######## Data. ########

-######## Retro setup. ########
-RETRO_ADD_RETRIEVER=1
-RETRO_CYCLIC_TRAIN_ITERS=750000
-RETRO_NUM_NEIGHBORS=2
+DATA_BLEND="<see --data-path in arguments.py>"

-######## Args. ########
+######## Arguments. ########

 CHECKPOINT_DIR=${RETRO_WORKDIR}/checkpoints/${RETRO_ADD_RETRIEVER}
 TENSORBOARD_DIR="${CHECKPOINT_DIR}/tensorboard"
 mkdir -p ${TENSORBOARD_DIR}

 ARGS=" \
     --save-interval 1000 \
     --save ${CHECKPOINT_DIR} \
     --load ${CHECKPOINT_DIR} \
     --tensorboard-dir ${TENSORBOARD_DIR} \
-    --log-interval 5 \
+    --log-interval 1 \
+    --use-flash-attn \
+    --apply-layernorm-1p \
+    --untie-embeddings-and-output-weights \
+    --disable-bias-linear \
+    --no-position-embedding \
+    --use-rotary-position-embeddings \
+    --rotary-percent 0.5 \
+    --swiglu \
+    --attention-dropout 0.0 \
+    --hidden-dropout 0.0 \
+    --exit-duration-in-mins 220 \
     --tensor-model-parallel-size 1 \
     --pipeline-model-parallel-size 1 \
-    --num-layers 12 \
-    --hidden-size 768 \
-    --num-attention-heads 12 \
-    --seq-length 2048 \
-    --max-position-embeddings 2048 \
-    --micro-batch-size 4 \
+    --num-layers 24 \
+    --hidden-size 1024 \
+    --num-attention-heads 16 \
+    --seq-length 512 \
+    --max-position-embeddings 512 \
+    --micro-batch-size 16 \
     --global-batch-size 256 \
-    --train-samples ${RETRO_GPT_TRAIN_SAMPLES} \
-    --lr-decay-samples ${LR_DECAY_SAMPLES} \
-    --lr-warmup-samples ${LR_WARMUP_SAMPLES} \
-    --lr 6.0e-4 \
-    --min-lr 6.0e-5 \
+    --train-samples 200000 \
+    --lr-decay-samples 175000 \
+    --lr-warmup-samples 10000 \
+    --lr 2.5e-5 \
+    --min-lr 2.5e-6 \
     --lr-decay-style cosine \
-    --eval-interval ${RETRO_GPT_EVAL_INTERVAL} \
-    --eval-iters ${RETRO_GPT_EVAL_ITERS} \
-    --data-path ${DATA_PATH} \
-    --vocab-file ${GPT_VOCAB_FILE} \
-    --merge-file ${GPT_MERGE_FILE} \
+    --eval-iters 50 \
+    --eval-interval 2000 \
+    --tokenizer-type GPTSentencePieceTokenizer \
+    --tokenizer-model <path/to/gpt/tokenizer/model> \
+    --data-path ${DATA_BLEND} \
     --split 98,2,0 \
     --clip-grad 1.0 \
     --weight-decay 0.1 \
     --adam-beta1 0.9 \
     --adam-beta2 0.95 \
-    --init-method-std 0.023 \
+    --init-method-std 0.007 \
     --log-params-norm \
     --log-num-zeros-in-grad \
-    --fp16 \
+    --bf16 \
     --DDP-impl local \
     --dataloader-type ${DATALOADER_TYPE} \
     --no-data-sharding \
     --no-gradient-accumulation-fusion \
 "

-if [ "$RETRO_ADD_RETRIEVER" = "0" ]; then
+######## Retro. ########
+
+if [ "$ADD_RETRIEVER" = "0" ]; then
     SCRIPT=pretrain_gpt.py
 else
     ARGS="${ARGS} \
-    --retro-add-retriever \
     --retro-workdir ${RETRO_WORKDIR} \
-    --retro-cyclic-train-iters ${RETRO_CYCLIC_TRAIN_ITERS} \
-    --retro-num-neighbors ${RETRO_NUM_NEIGHBORS} \
+    --retro-add-retriever \
     "
     SCRIPT=pretrain_retro.py
 fi

 echo "~~~~~~~~~~~~~~~~~~~~~~~~~~"
 echo "ARGS = '$ARGS'."
 echo "~~~~~~~~~~~~~~~~~~~~~~~~~~"

 ######## Command. ########

-python -m torch.distributed.launch \
+NPROCS=8
+CMD=" \
+    pwd && cd ${REPO_DIR} && pwd && \
+    export PYTHONPATH=$PYTHONPATH:${REPO_DIR} && \
+    python -m torch.distributed.run \
     --nproc_per_node ${NPROCS} \
     --nnodes 1 \
-    --node_rank 0 \
-    --master_addr localhost \
+    --node_rank ${NODE_RANK} \
+    --master_addr ${MASTER_ADDR} \
     --master_port 6000 \
-    ${SCRIPT} \
-    ${ARGS} \
+    ${SCRIPT} ${ARGS} \
+"
 echo "~~~~~~~~~~~~~~~~~~~~~~~~~~"
 echo "CMD = '$CMD'."
 echo "~~~~~~~~~~~~~~~~~~~~~~~~~~"
 eval $CMD
tools/retro/index/__init__.py  (view file @ 3aca1415)

 # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

-from .index import Index
+from .build import add_to_index, build_index, train_index
+# from .index import Index
tools/retro/index/build.py  (view file @ 3aca1415)

@@ -18,8 +18,10 @@ from tools.retro.index.factory import IndexFactory
 from tools.retro.utils import GPTToTextDataset

 from .utils import (
-    get_training_data_dir,
-    get_training_data_merged,
+    get_training_data_block_dir,
+    get_training_data_block_paths,
+    get_training_data_merged_path,
+    get_training_data_root_dir,
 )
@@ -36,6 +38,43 @@ def get_empty_index_path():
     return empty_index_path


+def get_block_nload(block_path, load_fraction):
+    with h5py.File(block_path) as fi:
+        return int(load_fraction * fi["data"].shape[0])
+
+
+def merge_embedding_blocks():
+
+    if torch.distributed.get_rank() != 0:
+        return
+
+    args = get_retro_args()
+
+    # Get block, merged paths.
+    load_fraction = args.retro_index_train_load_fraction
+    block_paths = get_training_data_block_paths()
+    bin_path = get_training_data_merged_path()
+
+    # Skip, if already built.
+    if os.path.exists(bin_path):
+        return
+
+    # Merge blocks.
+    with open(bin_path, "wb") as fo:
+        byte_offset = 0
+        for block_idx, block_path in \
+            enumerate(tqdm(block_paths, "merge train embeddings")):
+            with h5py.File(block_path) as fi:
+
+                nload = get_block_nload(block_path, load_fraction)
+                block = np.array(fi["data"][:nload], copy=False)
+
+                fo.write(block.tobytes())
+
+                byte_offset += block.size * block.itemsize
+                fo.seek(byte_offset)
+
+
 def embed_db():
     '''Embed DB chunks.
@@ -45,6 +84,10 @@ def embed_db():
     args = get_retro_args()

+    merged_train_data_path = get_training_data_merged_path()
+    if os.path.exists(merged_train_data_path):
+        return
+
     # Get db dataset.
     gpt_dataset = get_merged_sampled_dataset()
     text_dataset = GPTToTextDataset(gpt_dataset)
@@ -54,14 +97,19 @@ def embed_db():
         args.retro_bert_max_chunk_length,
         args.retro_block_size,
         args.bert_embedder_type)
-    embedder.embed_text_dataset("index", get_training_data_dir(), text_dataset)
+    embedder.embed_text_dataset("index",
+                                get_training_data_block_dir(),
+                                text_dataset)
+
+    # Merge embeddings.
+    merge_embedding_blocks()


 def train_on_embeddings():
     '''Train index on embedded DB chunks.'''
     args = get_retro_args()
     index = IndexFactory.get_index(args.retro_index_type)
-    index.train(get_training_data_merged)
+    index.train()


 def remove_embeddings():
@@ -71,7 +119,7 @@ def remove_embeddings():
         return
     empty_index_path = get_empty_index_path()
     assert os.path.isfile(empty_index_path)
-    shutil.rmtree(get_training_data_dir(), ignore_errors=True)
+    shutil.rmtree(get_training_data_root_dir(), ignore_errors=True)


 def train_index():
@@ -92,7 +140,7 @@ def train_index():
         torch.distributed.barrier()

     # Remove embeddings.
-    if args.retro_delete_index_training_embeddings:
+    if args.retro_index_delete_training_embeddings:
         remove_embeddings()
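The merge above keeps only a fraction of each embedding block, controlled by the new index-train load fraction. A toy check of that arithmetic (block sizes below are made up):

# Only the first int(fraction * n) rows of each block are written to the merged file.
block_rows = [100_000, 100_000, 37_512]
load_fraction = 0.97
print([int(load_fraction * n) for n in block_rows])   # [97000, 97000, 36386]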
tools/retro/index/index.py  (view file @ 3aca1415)

@@ -5,6 +5,7 @@ import numpy as np
 import os
 import torch

+from megatron import get_retro_args
 from tools.retro.external_libs import faiss

 from .utils import get_index_dir
@@ -30,13 +31,24 @@ class Index(abc.ABC):
         faiss.ParameterSpace().set_index_parameter(index, "verbose", v)

     def get_empty_index_path(self):
-        return os.path.join(get_index_dir(), "empty.faissindex")
+        args = get_retro_args()
+        return os.path.join(
+            get_index_dir(),
+            "empty_%.3f.faissindex" % args.retro_index_train_load_fraction,
+        )

     def get_empty_index(self):
         return faiss.read_index(self.get_empty_index_path())

     def get_added_index_path(self):
-        return os.path.join(get_index_dir(), "added.faissindex")
+        args = get_retro_args()
+        return os.path.join(
+            get_index_dir(),
+            "added_%.3f_%.3f.faissindex" % (
+                args.retro_index_train_load_fraction,
+                args.retro_index_add_load_fraction,
+            ),
+        )

     def get_added_index(self):
         return faiss.read_index(self.get_added_index_path())
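The load fractions are now stamped into the index filenames, so indexes built with different fractions do not overwrite each other. A quick worked example of the formatting (fraction values are hypothetical):

load_fraction, add_fraction = 0.97, 0.95
print("empty_%.3f.faissindex" % load_fraction)                         # empty_0.970.faissindex
print("added_%.3f_%.3f.faissindex" % (load_fraction, add_fraction))    # added_0.970_0.950.faissindex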
tools/retro/index/indexes/faiss_base.py  (view file @ 3aca1415)

@@ -8,6 +8,7 @@ inherit from this class (see FaissParAddIndex, for an example).
 """

 from datetime import timedelta
+import numpy as np
 import os
 import torch
 from tqdm import tqdm
@@ -15,13 +16,16 @@ from tqdm import tqdm
 from megatron import get_retro_args, print_rank_0
 from tools.bert_embedding import BertEmbedder
 from tools.retro.external_libs import faiss
-from tools.retro.index import Index
-from tools.retro.index.utils import num_samples_to_block_ranges
+from tools.retro.index.index import Index
+from tools.retro.index.utils import (
+    get_training_data_merged_path,
+    num_samples_to_block_ranges,
+)


 class FaissBaseIndex(Index):

-    def _train(self, input_data_loader):
+    def _train(self):
         '''Train index (rank 0's method).'''

         args = get_retro_args()
@@ -40,17 +44,24 @@ class FaissBaseIndex(Index):
             return

         # Load data.
-        inp = input_data_loader()
+        merged_path = get_training_data_merged_path()
+        inp = np.memmap(
+            merged_path,
+            dtype="f4",
+            mode="r",
+        ).reshape((-1, args.hidden_size))

         # Init index.
         index = faiss.index_factory(args.retro_index_nfeats, args.retro_index_str)

+        # Move to GPU.
+        print("> move faiss index to gpu.")
+        index_ivf = faiss.extract_index_ivf(index)
+        clustering_index = \
+            faiss.index_cpu_to_all_gpus(faiss.IndexFlatL2(index_ivf.d))
+        index_ivf.clustering_index = clustering_index
+        print("> finished moving to gpu.")
+
         self.c_verbose(index, True)
         self.c_verbose(index_ivf, True)
         self.c_verbose(index_ivf.quantizer, True)
@@ -62,12 +73,12 @@ class FaissBaseIndex(Index):
         # Save index.
         faiss.write_index(index, empty_index_path)

-    def train(self, input_data_loader):
+    def train(self):
         '''Train index.'''

         # Single process only.
         if torch.distributed.get_rank() == 0:
-            self._train(input_data_loader)
+            self._train()

         torch.distributed.barrier()
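Training now memory-maps the merged raw-float32 file instead of calling a loader function, which is why merge_embedding_blocks() writes plain bytes. A small runnable round-trip illustrating the reshape (sizes below are made up, not real embeddings):

import numpy as np

hidden_size = 1024
emb = np.random.rand(10, hidden_size).astype("f4")
emb.tofile("/tmp/train_example.bin")                  # raw float32 bytes, no header

inp = np.memmap("/tmp/train_example.bin", dtype="f4", mode="r").reshape((-1, hidden_size))
print(inp.shape)                                      # (10, 1024)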
tools/retro/index/indexes/faiss_par_add.py  (view file @ 3aca1415)

@@ -10,6 +10,7 @@ the vast majority of the computational effort is embarrassingly parallel.
 import numpy as np
 import os
+import psutil
 import shutil
 import torch
 from tqdm import tqdm
@@ -104,6 +105,8 @@ class FaissParallelAddIndex(FaissBaseIndex):
         if os.path.exists(added_index_path):
             return

+        args = get_retro_args()
+
         # Index.
         print_rank_0("read empty index.")
         index = self.get_empty_index()
@@ -112,10 +115,19 @@ class FaissParallelAddIndex(FaissBaseIndex):
         # Add codes.
         print_rank_0("add codes.")
         code_paths = get_added_code_paths()
-        for code_path in tqdm(code_paths, "add codes"):
+        pbar = tqdm(code_paths)
+        for code_path in pbar:
+            pbar.set_description("add codes, mem %.3f gb, %.1f%%" % (
+                psutil.virtual_memory()[3] / 1024**3,
+                psutil.virtual_memory()[2],
+            ))
             with h5py.File(code_path) as f:
-                codes = np.copy(f["data"])
-                index_ivf.add_sa_codes(codes)
+
+                nload = int(args.retro_index_add_load_fraction*f["data"].shape[0])
+                offset = int(os.path.basename(code_path).split("-")[0])
+                xids = np.arange(offset, offset + nload)
+                codes = np.copy(f["data"][:nload])
+                index_ivf.add_sa_codes(codes, xids)

         # Update index's ntotal.
         index.ntotal = index_ivf.ntotal
@@ -129,18 +141,19 @@ class FaissParallelAddIndex(FaissBaseIndex):
         if torch.distributed.get_rank() != 0:
             return
         assert os.path.isfile(self.get_added_index_path())
-        shutil.rmtree(get_added_codes_dir(), ignore_errors=True)

-    def add(self, text_dataset):
+        args = get_retro_args()
+        if args.retro_index_delete_added_codes:
+            raise Exception("remove?")
+            shutil.rmtree(get_added_codes_dir(), ignore_errors=True)

-        # Check if index already exists.
-        if not os.path.isfile(self.get_added_index_path()):
+    def add(self, text_dataset):

-            # Encode chunks.
-            self.encode(text_dataset)
+        # Encode chunks.
+        self.encode(text_dataset)

-            # Add codes to index.
-            self.add_codes()
+        # Add codes to index.
+        self.add_codes()

         # Wait for (single-process) adding to complete.
         torch.distributed.barrier()
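A toy check of the id-offset logic in the new add_codes() path: code blocks are written as "<start>-<end>.hdf5", so the integer prefix of the filename gives the global id of the block's first code (the path below is hypothetical and only the name is parsed; no file is opened):

import os
import numpy as np

code_path = "/tmp/add_codes/0000800000-0000900000.hdf5"
offset = int(os.path.basename(code_path).split("-")[0])   # 800000
nload = int(0.95 * 100000)                                 # add-load fraction of a 100k block
xids = np.arange(offset, offset + nload)
print(offset, xids[:3], xids.shape)   # 800000 [800000 800001 800002] (95000,)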
tools/retro/index/utils.py
View file @
3aca1415
...
...
@@ -45,128 +45,28 @@ def num_samples_to_block_ranges(num_samples):
return
ranges
def get_training_data_dir():
    return os.path.join(get_index_dir(), "train_tmp")


def get_training_data_paths():
    return sorted(glob.glob(get_training_data_dir() + "/*.hdf5"))


def get_added_codes_dir():
    return os.path.join(get_index_dir(), "add_tmp")


def get_added_code_paths():
    return sorted(glob.glob(get_added_codes_dir() + "/*.hdf5"))


def get_training_data_group_infos():

    block_paths = get_training_data_paths()
    max_group_size = args.retro_index_train_block_size

    groups = []
    group = []
    group_size = 0
    for block_path in block_paths:
        with h5py.File(block_path) as f:
            block_size = f["data"].shape[0]
        group.append(block_path)
        group_size += block_size

        if group_size >= max_group_size:
            groups.append({"paths": group, "size": group_size})
            group = []
            group_size = 0
    if group:
        groups.append({"paths": group, "size": group_size})

    return groups


def get_training_data_root_dir():
    args = get_retro_args()
    return os.path.join(args.retro_workdir, "index", "train_emb")


def get_training_data_block_dir():
    return os.path.join(get_training_data_root_dir(), "blocks")


def get_training_data_block_paths():
    return sorted(glob.glob(get_training_data_block_dir() + "/*.hdf5"))


def load_training_block(path, load_fraction):
    with h5py.File(path) as f:
        n_load = int(load_fraction * f["data"].shape[0])
        return np.copy(f["data"][:n_load])


def load_training_group(executor, group_info, load_fraction):

    # Launch threads to load block data.
    futures = []
    for path in group_info["paths"]:
        futures.append(executor.submit(load_training_block, path, load_fraction))

    # Collect block data.
    block_datas = []
    for future in futures:
        block_datas.append(future.result())

    # Concatenate blocks.
    group_data = np.concatenate(block_datas, axis=0)

    # Garbage collect.
    for d in block_datas:
        del d
    gc.collect()

    return group_data


def get_training_data_merged_path():
    args = get_retro_args()
    return os.path.join(get_training_data_root_dir(),
                        "train_%.3f.bin" % args.retro_index_train_load_fraction)


def get_training_data_merged():
    '''Merge embeddings into single dataset.'''

    args = get_retro_args()

    # Setup.
    ds_infos = get_indexed_dataset_infos()
    n_chunks_sampled = sum(d["n_chunks_sampled"] for d in ds_infos)
    load_fraction = args.retro_index_train_load_fraction

    # Initialize merged data.
    print("allocate training data array.")
    t = time.time()
    data = np.empty((n_chunks_sampled, args.retro_index_nfeats), dtype="f4")
    print(" time : %.3f sec." % (time.time() - t))

    # Data groups (minimizing fragmentation).
    group_infos = get_training_data_group_infos()

    # Load data blocks.
    n_threads = max(len(group["paths"]) for group in group_infos)
    with concurrent.futures.ThreadPoolExecutor(max_workers=n_threads) as executor:

        # Load data blocks.
        print("load training data blocks.")
        start_idx = 0
        pbar = tqdm(group_infos)
        for group_info in pbar:

            pbar.set_description("mem %.0f gb, %.1f%%" % (
                psutil.virtual_memory()[3] / 1024**3,
                psutil.virtual_memory()[2],
            ))

            # Load group data.
            group_data = load_training_group(executor, group_info, load_fraction)
            data[start_idx:(start_idx + len(group_data))] = group_data
            start_idx += len(group_data)

            # Garbage collect.
            del group_data
            gc.collect()

    # Handle load ratio <1.
    data = data[:start_idx]

    print("> training block data.shape = %s." % str(data.shape))

    return data


def get_added_codes_dir():
    return os.path.join(get_index_dir(), "add_codes")


def get_added_code_paths():
    return sorted(glob.glob(get_added_codes_dir() + "/*.hdf5"))
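For orientation, the grouping above packs per-block HDF5 files into groups of roughly retro_index_train_block_size rows, so each group can be copied into the preallocated training array in one pass. A minimal standalone sketch of that strategy (file names and sizes below are invented for illustration):

# Sketch of the block-grouping strategy; not the repository code.
def group_blocks(block_sizes, max_group_size):
    groups, group, group_size = [], [], 0
    for path, size in block_sizes:
        group.append(path)
        group_size += size
        if group_size >= max_group_size:
            groups.append({"paths": group, "size": group_size})
            group, group_size = [], 0
    if group:
        groups.append({"paths": group, "size": group_size})
    return groups

if __name__ == "__main__":
    # Invented block files and row counts, purely for illustration.
    blocks = [("block_000.hdf5", 90000), ("block_001.hdf5", 100000),
              ("block_002.hdf5", 100000), ("block_003.hdf5", 40000)]
    for g in group_blocks(blocks, max_group_size=150000):
        print(g["size"], g["paths"])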
tools/retro/main.py
View file @
3aca1415
...
...
@@ -15,8 +15,8 @@ import torch
from megatron import get_args, initialize_megatron, print_rank_0
from megatron.global_vars import set_retro_args
from tools.retro.db import build_db
from tools.retro.index.build import add_to_index, build_index, train_index
from tools.retro.pretraining.query import query_pretraining_neighbors
from tools.retro.index import add_to_index, build_index, train_index
from tools.retro.query import query_pretraining_neighbors
from tools.retro.utils import get_args_path
...
...
@@ -31,16 +31,69 @@ def add_retro_args(parser):
    group = parser.add_argument_group(title="Retro preprocessing.")

    group.add_argument("--retro-gpt-vocab-file", required=True,
                       help="GPT vocab file.")
    group.add_argument("--retro-gpt-merge-file", required=True,
                       help="GPT merge file.")

    # Basic args.
    group.add_argument("--retro-tasks", default="build",
                       help="Comma-separated list of tasks to run. Run the entire "
                       "preprocessing pipeline by using '--retro-tasks build'. "
                       "Alternatively, run individual stages with tasks (in "
                       "this order) 'db-build', 'index-build', or "
                       "'query-pretraining-neighbors'. For example, "
                       "'--retro-tasks db-build,index-build,"
                       "query-pretraining-neighbors' is equivalent to "
                       "'--retro-tasks build'; or the argument can contain "
                       "a subset of these tasks. Stages must always be run "
                       "in the correct order (listed above).")
    group.add_argument("--retro-block-size", type=int, default=100000,
                       help="Number of chunks to process at a time when "
                       "generating Bert embeddings and querying the search "
                       "index. Partial results for each block are generally "
                       "saved to disk in separate files.")
    group.add_argument("--retro-doc-block-size", type=int, default=100000,
                       help="Number of documents to process at a time when "
                       "processing token datasets into chunk databases. The "
                       "partial chunk database for each block is saved into "
                       "a separate file.")

    # GPT args.
    group.add_argument('--retro-gpt-seed', type=int, default=1234,
                       help='Random seed used for python, numpy, '
                       'pytorch, and cuda.')
    group.add_argument('--retro-gpt-data-impl', type=str, default='infer',
                       choices=['lazy', 'cached', 'mmap', 'infer'],
                       help='Implementation of indexed datasets.')
    group.add_argument('--retro-gpt-data-path', nargs='*', required=True,
                       help='Path to the training dataset. Accepted format: '
                       '1) a single data path, 2) multiple datasets in the '
                       'form: dataset1-weight dataset1-path dataset2-weight '
                       'dataset2-path ... It is used with --split when a '
                       'single dataset is used for all three: train, valid '
                       'and test. It is exclusive to the other '
                       '--*-data-path args.')
    group.add_argument('--retro-gpt-split', type=str, default='969,30,1',
                       help='Comma-separated list of proportions for training, '
                       'validation, and test split. For example the split '
                       '`90,5,5` will use 90%% of data for training, 5%% for '
                       'validation and 5%% for test.')
    group.add_argument('--retro-gpt-mmap-warmup', action='store_true',
                       help='Warm up mmap files.')
    group.add_argument("--retro-gpt-eval-interval", type=int, required=True,
                       help="GPT evaluation interval.")
    group.add_argument("--retro-gpt-eval-iters", type=int, required=True,
                       help="GPT evaluation iterations.")
    group.add_argument("--retro-gpt-tokenizer-type", required=True,
                       help="GPT tokenizer type.")
    group.add_argument("--retro-gpt-seq-length", type=int, default=2048,
    group.add_argument("--retro-gpt-vocab-file",
                       help="GPT vocab file.")
    group.add_argument("--retro-gpt-merge-file",
                       help="GPT merge file.")
    group.add_argument("--retro-gpt-tokenizer-model",
                       help="GPT tokenizer model file.")
    group.add_argument("--retro-gpt-seq-length", type=int, required=True,
                       help="GPT sequence length.")
    group.add_argument("--retro-gpt-global-batch-size", type=int, required=True,
                       help="GPT global batch size.")
    group.add_argument("--retro-gpt-chunk-length", type=int, default=64,
                       help="GPT chunk length.")

    # Bert args.
    group.add_argument("--retro-bert-vocab-file", required=True,
                       help="Bert vocab file.")
    group.add_argument("--retro-bert-tokenizer-type", required=True,
...
@@ -52,17 +105,8 @@ def add_retro_args(parser):
                       help="Maximum sequence length for Bert embeddings. "
                       "(Named 'chunk' here in reference to these Bert "
                       "sequences being converted from GPT chunks.)")
    group.add_argument("--retro-tasks", default="build",
                       help="Comma-separated list of tasks to run. Run the entire "
                       "preprocessing pipeline by using '--retro-tasks build'. "
                       "Alternatively, run individual stages with tasks (in "
                       "this order) 'db-build', 'index-build', or "
                       "'pretraining-query-neighbors'. For example, "
                       "'--retro-tasks db-build,index-build,"
                       "pretraining-query-neighbors' is equivalent to "
                       "'--retro-tasks build'; or the argument can contain "
                       "a subset of these tasks. Stages must always be run "
                       "in the correct order (listed above).")

    # Index args.
    group.add_argument("--retro-index-nfeats", "-f", type=int, default=1024,
                       help="Dimension of Bert embeddings. Bert-large is "
                       "commonly used, so this value defaults to 1024.")
...
...
@@ -78,34 +122,10 @@ def add_retro_args(parser):
"faiss.index_factory(). For example, "
"'IVF262144_HNSW32,Flat' or "
"'OPQ32_256,IVF4194304_HNSW32,PQ32'."
)
group
.
add_argument
(
"--retro-ef-search"
,
type
=
int
,
default
=
256
,
help
=
"Index ef-search parameter for HNSW during "
"querying."
)
group
.
add_argument
(
"--retro-nprobe"
,
type
=
int
,
default
=
65536
,
help
=
"Index nprobe parameter for IVF during "
"querying."
)
group
.
add_argument
(
"--retro-nchunks-sampled"
,
type
=
int
,
required
=
True
,
group
.
add_argument
(
"--retro-index-ntrain"
,
type
=
int
,
required
=
True
,
help
=
"Number of database chunks to use for training "
"the index. This value must be less or equal to the "
"total number of chunks in the database."
)
group
.
add_argument
(
"--retro-doc-block-size"
,
type
=
int
,
default
=
100000
,
help
=
"Number of documents to processe at time when "
"processing token datasets into chunk databases. The "
"partial chunk database for each block is saved into "
"a separate file."
)
group
.
add_argument
(
"--retro-block-size"
,
type
=
int
,
default
=
100000
,
help
=
"Number of chunks to process at a time when "
"generating Bert embeddings and querying the search "
"index. Partial results for each block are generally "
"saved to disk in separate files."
)
group
.
add_argument
(
"--retro-index-train-block-size"
,
type
=
int
,
default
=
3750000
,
help
=
"As a memory fragmentation optimization, when "
"loading training data for training the search index, "
"enough data blocks loaded at a time until they reach "
"retro_index_train_block_size, and then this "
"data block is copied into the full training data "
"array."
)
group
.
add_argument
(
"--retro-index-train-load-fraction"
,
type
=
float
,
default
=
1.
,
help
=
"Fraction of sampled chunks to use for training "
...
...
@@ -113,19 +133,36 @@ def add_retro_args(parser):
"use too much memory; lowering the load fraction is "
"less costly than re-embedding a new sampled dataset "
"from scratch."
)
group
.
add_argument
(
"--retro-num-neighbors-query"
,
type
=
int
,
default
=
2000
,
group
.
add_argument
(
"--retro-index-add-load-fraction"
,
type
=
float
,
default
=
1.
,
help
=
"Fraction of database chunks to use for adding to "
"the index. Useful when our total index size would "
"use too much memory; lowering the load fraction is "
"less costly than re-designing our token datasets."
)
group
.
add_argument
(
"--retro-index-no-delete-training-embeddings"
,
action
=
'store_false'
,
dest
=
"retro_index_delete_training_embeddings"
,
help
=
"Skip deleting training embeddings for the search "
"index. Useful for debugging."
)
group
.
add_argument
(
"--retro-index-no-delete-added-codes"
,
action
=
'store_false'
,
dest
=
"retro_index_delete_added_codes"
,
help
=
"Skip deleting added codes for the search "
"index. Useful for debugging."
)
# Query args.
group
.
add_argument
(
"--retro-query-ef-search"
,
type
=
int
,
default
=
256
,
help
=
"Index ef-search parameter for HNSW during querying."
)
group
.
add_argument
(
"--retro-query-nprobe"
,
type
=
int
,
default
=
65536
,
help
=
"Index nprobe parameter for IVF during querying."
)
group
.
add_argument
(
"--retro-query-num-neighbors-query"
,
type
=
int
,
default
=
200
,
help
=
"Number of neighbors to retrieve when calling "
"index.search()."
)
group
.
add_argument
(
"--retro-num-neighbors-
target
"
,
type
=
int
,
default
=
20
0
,
group
.
add_argument
(
"--retro-
query-
num-neighbors-
save
"
,
type
=
int
,
default
=
20
,
help
=
"Number of neighbors to save to disk after "
"the index's returned neighbors. If longer than target "
"value, neighbors truncated; and if shorter than target "
"value, neighbors are padded with -1's."
)
group
.
add_argument
(
"--retro-no-delete-index-training-embeddings"
,
action
=
'store_false'
,
dest
=
"retro_delete_index_training_embeddings"
,
help
=
"Skip deleting training embeddings for the search "
"index. Useful for debugging."
)
# Enforce argument naming convention.
for
action
in
group
.
_group_actions
:
...
...
@@ -140,10 +177,16 @@ def add_retro_args(parser):
def save_args(args):
    '''Save copy of args within retro workdir.'''

    def default_dump(obj):
        if isinstance(obj, torch.dtype):
            return str(obj)
        else:
            raise Exception("specialize for <%s>." % type(obj).__name__)

    if torch.distributed.get_rank() == 0:
        args_path = get_args_path(args.retro_workdir)
        with open(args_path, "w") as f:
            json.dump(vars(args), f, indent=4, default=lambda o: "<skipped>")
            json.dump(vars(args), f, indent=4, default=default_dump)

    torch.distributed.barrier()
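The new default_dump hook above serializes torch dtypes explicitly instead of silently writing "<skipped>" for every non-JSON value. A standalone sketch of the same json.dump pattern, with a made-up args_like dict standing in for vars(args):

import json
import torch

def default_dump(obj):
    # Serialize torch dtypes as strings; fail loudly for any other
    # non-serializable type so it gets noticed rather than skipped.
    if isinstance(obj, torch.dtype):
        return str(obj)
    raise TypeError("specialize for <%s>." % type(obj).__name__)

# Invented stand-in for vars(args).
args_like = {"retro_index_nfeats": 1024, "params_dtype": torch.bfloat16}
print(json.dumps(args_like, indent=4, default=default_dump))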
...
...
@@ -188,7 +231,7 @@ if __name__ == "__main__":
            add_to_index()  # add only

        # Pretraining.
        elif task == "pretraining-query-neighbors":
        elif task == "query-pretraining-neighbors":
            query_pretraining_neighbors()

        else:
...
tools/retro/pretraining/utils.py
deleted
100644 → 0
View file @
0024a5c6
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
import os

from megatron import get_retro_args


def get_pretraining_workdir():
    args = get_retro_args()
    return os.path.join(args.retro_workdir, "pretraining")
tools/retro/query/__init__.py
0 → 100644
View file @
3aca1415
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
from .query import query_pretraining_neighbors
tools/retro/pretraining/chunk_dataset.py → tools/retro/query/chunk_dataset.py
View file @
3aca1415
...
...
@@ -4,15 +4,16 @@ import os
import torch

from megatron import get_retro_args, print_rank_0
from megatron.data.gpt_dataset import build_train_valid_test_datasets
from megatron.data.gpt_dataset import build_train_valid_test_datasets \
    as build_gpt_train_valid_test_datasets
from megatron.training import (
    build_train_valid_test_data_loaders,
    build_train_valid_test_datasets as build_pretraining_train_valid_test_datasets,
    update_train_iters,
)
from tools.retro.db.utils import get_indexed_dataset_infos
from tools.retro.utils import get_num_chunks_per_sample

from .utils import get_pretraining_workdir
from .utils import get_neighbor_dirname, get_query_workdir


class ChunkDataset(torch.utils.data.Dataset):
...
...
@@ -86,14 +87,14 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
    print_rank_0('> building train, validation, and test datasets '
                 'for GPT ...')
    train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
        data_prefix=args.data_path,
        data_impl=args.data_impl,
        splits_string=args.split,
    train_ds, valid_ds, test_ds = build_gpt_train_valid_test_datasets(
        data_prefix=args.retro_gpt_data_path,
        data_impl=args.retro_gpt_data_impl,
        splits_string=args.retro_gpt_split,
        train_valid_test_num_samples=train_val_test_num_samples,
        seq_length=args.retro_gpt_seq_length,
        seed=args.seed,
        skip_warmup=(not args.mmap_warmup),
        seed=args.retro_gpt_seed,
        skip_warmup=(not args.retro_gpt_mmap_warmup),
        return_doc_ids=args.retro_return_doc_ids)
    print_rank_0("> finished creating pretrained GPT datasets ...")
...
...
@@ -115,28 +116,23 @@ def get_chunk_dataset_map():
    verify_indexed_dataset_order()

    # Datasets.
    print_rank_0(" > data loader.")
    train_data_loader, valid_data_loader, test_data_loader \
        = build_train_valid_test_data_loaders(
            train_valid_test_datasets_provider)
    data_loader_map = {
        "train": train_data_loader,
        "valid": valid_data_loader,
        "test": test_data_loader,
    print_rank_0(" > datasets.")
    train_ds, valid_ds, test_ds = build_pretraining_train_valid_test_datasets(
        train_valid_test_datasets_provider)
    sample_dataset_map = {
        "train": train_ds,
        "valid": valid_ds,
        "test": test_ds,
    }

    # Info dict.
    workdir = get_pretraining_workdir()
    dataset_map = {
    chunk_dataset_map = {
        key: {
            "neighbor_dir": os.path.join(
                workdir,
                os.path.basename(loader.dataset.datasets[0].index_prefix),
            ),
            "data": ChunkDataset(loader.dataset, args.retro_gpt_chunk_length),
            "neighbor_dir": get_neighbor_dirname(key, sample_ds),
            "data": ChunkDataset(sample_ds, args.retro_gpt_chunk_length),
        }
        for key, loader in data_loader_map.items() if loader
        for key, sample_ds in sample_dataset_map.items() if sample_ds
    }

    return dataset_map
    return chunk_dataset_map
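The ChunkDataset bookkeeping above relies on each GPT sample of retro_gpt_seq_length tokens being split into fixed-length chunks of retro_gpt_chunk_length, so a chunk id maps back to its sample by integer division. An illustrative sketch (2048 and 64 mirror the defaults declared earlier; everything else is invented):

seq_length = 2048
chunk_length = 64
n_chunks_per_sample = seq_length // chunk_length  # 32

def chunk_to_sample(chunk_id):
    # Map a global chunk id to (sample id, token range within the sample).
    sample_id = chunk_id // n_chunks_per_sample
    chunk_in_sample = chunk_id % n_chunks_per_sample
    token_start = chunk_in_sample * chunk_length
    return sample_id, (token_start, token_start + chunk_length)

print(chunk_to_sample(0))   # (0, (0, 64))
print(chunk_to_sample(33))  # (1, (64, 128))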
tools/retro/pretraining/query.py → tools/retro/query/query.py
View file @
3aca1415
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
from collections import defaultdict
import numpy as np
import os
import psutil
import time
import torch
from tqdm import tqdm

from megatron import get_retro_args, mpu, print_rank_0
from megatron import get_retro_args, print_rank_0
from tools.bert_embedding import BertEmbedder
from tools.bert_embedding.utils import get_missing_blocks_by_rank
from tools.retro.db.utils import (
    get_merged_train_dataset as get_db_merged_train_dataset,
    get_train_doc_chunk_map,
)
from tools.retro.db.utils import \
    get_merged_train_dataset as get_db_merged_train_dataset
from tools.retro.external_libs import faiss, h5py
from tools.retro.index.factory import IndexFactory
from tools.retro.index.utils import get_index_dir, num_samples_to_block_ranges
from tools.retro.index.utils import get_index_dir
from tools.retro.utils import GPTToTextDataset

from .chunk_dataset import get_chunk_dataset_map
from .chunk_dataset import get_chunk_dataset_map as get_query_dataset_map


def get_index(chunk_db_dataset, ondisk=False):
def get_index(ondisk=False):
    '''Read index from disk.'''

    args = get_retro_args()

    # Chunk db block ranges.
    n_db_chunks = len(chunk_db_dataset)
    dataset_block_ranges = num_samples_to_block_ranges(n_db_chunks)

    # Load index.
    index_wrapper = IndexFactory.get_index(args.retro_index_type)
    index_dir = get_index_dir()
...
...
@@ -42,9 +36,9 @@ def get_index(chunk_db_dataset, ondisk=False):
    # Search parameters.
    faiss.ParameterSpace().set_index_parameter(index, "efSearch",
                                               args.retro_ef_search)
                                               args.retro_query_ef_search)
    faiss.ParameterSpace().set_index_parameter(index, "nprobe",
                                               args.retro_nprobe)
                                               args.retro_query_nprobe)

    return index
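The efSearch / nprobe calls above tune FAISS search behavior at query time. A self-contained sketch of the same ParameterSpace mechanism on a small random IVF index (the dimensions and data below are invented; efSearch additionally applies when the coarse quantizer is HNSW, as in the index strings used by this pipeline):

import faiss
import numpy as np

d, n = 64, 5000
rng = np.random.default_rng(0)
xb = rng.standard_normal((n, d)).astype("f4")
xq = rng.standard_normal((5, d)).astype("f4")

# Tiny IVF index, built the same way as the factory strings above (scaled down).
index = faiss.index_factory(d, "IVF64,Flat")
index.train(xb)
index.add(xb)

# Runtime search parameters are set by name, exactly as in get_index().
faiss.ParameterSpace().set_index_parameter(index, "nprobe", 8)

dists, ids = index.search(xq, 10)
print(ids.shape)  # (5, 10)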
...
...
@@ -58,8 +52,9 @@ def embed_block(gpt_dataset, block, embedder):
    return embedder.embed_text_dataset(text_block_dataset)


def query_embeddings(index, banned_chunk_map, chunk_id_range,
                     embeddings, sample_map, n_chunks_per_sample,
def query_embeddings(db_dataset, index, embeddings, chunk_id_range,
                     sample_map, n_chunks_per_sample,
                     verbose=True):
    '''Query neighbors of a block of embeddings.'''
...
...
@@ -70,24 +65,13 @@ def query_embeddings(index, banned_chunk_map, chunk_id_range,
    t = time.time()
    assert index.ntotal > 0, "check we don't accidentally have an empty index."
    _, query_neighbor_ids = \
        index.search(embeddings, args.retro_num_neighbors_query)
        index.search(embeddings, args.retro_query_num_neighbors_query)
    if verbose: print_rank_0(" time : %.3f sec." % (time.time() - t))

    # Banned neighbor ids.
    if verbose: print_rank_0("get banned neighbor ids.")
    sample_banned_chunk_id_map = {}
    for sample_id, sample in sample_map.items():
        dataset_idx = sample["dataset_idx"].item()
        doc_ids = sample["doc_ids"].tolist()
        banned_chunk_ids = set()
        for doc_id in doc_ids:
            banned_chunk_ids.update(banned_chunk_map[(dataset_idx, doc_id)])
        sample_banned_chunk_id_map[sample_id] = banned_chunk_ids

    # Filter banned neighbor ids.
    if verbose: print_rank_0("filter banned neighbor ids.")
    filtered_neighbor_ids = np.full(
        shape=(len(query_neighbor_ids), args.retro_num_neighbors_target),
        shape=(len(query_neighbor_ids), args.retro_query_num_neighbors_save),
        fill_value=-1,
        dtype="int64",
    )
...
...
@@ -95,24 +79,30 @@ def query_embeddings(index, banned_chunk_map, chunk_id_range,
    for chunk_id in range(min_chunk_id, max_chunk_id):

        sample_id = chunk_id // n_chunks_per_sample
        sample = sample_map[sample_id]
        sample_dataset_idx = sample["dataset_idx"].item()
        sample_doc_ids = sample["doc_ids"].tolist()
        sample_doc_tuples = [(sample_dataset_idx, d) for d in sample_doc_ids]

        # Get valid neighbors (!= -1).
        query_row = [i for i in query_neighbor_ids[chunk_id - min_chunk_id]
                     if i >= 0]

        # Filter row.
        filtered_row = [i for i in query_row
                        if i not in sample_banned_chunk_id_map[sample_id]]
        filtered_row = filtered_row[:args.retro_num_neighbors_target]
        filtered_row = [i for i in query_row
                        if tuple(db_dataset.doc_tuples[i].tolist())
                        not in sample_doc_tuples]
        filtered_row = filtered_row[:args.retro_query_num_neighbors_save]
        filtered_row += \
            [-1] * (args.retro_num_neighbors_target - len(filtered_row))
            [-1] * (args.retro_query_num_neighbors_save - len(filtered_row))
        filtered_neighbor_ids[chunk_id - min_chunk_id] = filtered_row

    return query_neighbor_ids, filtered_neighbor_ids


def query_embedding_block(index, banned_chunk_map, chunk_id_range,
                          embeddings, sample_map, n_chunks_per_sample):
def query_embedding_block(db_dataset, index, embeddings, chunk_id_range,
                          sample_map, n_chunks_per_sample):

    query_neighbor_ids = []
    filtered_neighbor_ids = []
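The new filter above drops any retrieved neighbor whose (dataset_idx, doc_id) tuple matches one of the query sample's own documents, then truncates or pads the row to retro_query_num_neighbors_save using -1. A small sketch of that step with invented ids:

import numpy as np

num_neighbors_save = 5  # stand-in for args.retro_query_num_neighbors_save

# Invented retrieved neighbor ids and their (dataset_idx, doc_id) tuples.
query_row = [3, 7, 11, 15, 19, 23]
neighbor_doc_tuples = {3: (0, 42), 7: (0, 7), 11: (1, 99),
                       15: (0, 42), 19: (2, 5), 23: (0, 1)}
sample_doc_tuples = [(0, 42)]  # documents the query chunk itself came from

# Drop neighbors that come from the query sample's own documents.
filtered_row = [i for i in query_row
                if neighbor_doc_tuples[i] not in sample_doc_tuples]
# Truncate, then pad with -1 up to the save width.
filtered_row = filtered_row[:num_neighbors_save]
filtered_row += [-1] * (num_neighbors_save - len(filtered_row))

print(np.array(filtered_row))  # [ 7 11 19 23 -1 ]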
...
...
@@ -131,8 +121,9 @@ def query_embedding_block(index, banned_chunk_map, chunk_id_range,
            chunk_id_range[0] + partial_end_idx,
        )

        partial_query_neighbor_ids, partial_filtered_neighbor_ids = \
            query_embeddings(index, banned_chunk_map, partial_chunk_id_range,
                             partial_embeddings, sample_map, n_chunks_per_sample,
            query_embeddings(db_dataset, index, partial_embeddings,
                             partial_chunk_id_range, sample_map,
                             n_chunks_per_sample,
                             verbose=False)

        query_neighbor_ids.append(partial_query_neighbor_ids)
        filtered_neighbor_ids.append(partial_filtered_neighbor_ids)
...
...
@@ -144,26 +135,33 @@ def query_embedding_block(index, banned_chunk_map, chunk_id_range,
    return query_neighbor_ids, filtered_neighbor_ids


def query_block_neighbors(index, banned_chunk_map, chunk_dataset,
                          block, embedder):
def query_block_neighbors(db_dataset, query_dataset, index, embedder, block):
    '''Query neighbors of a dataset block (i.e., range).'''

    args = get_retro_args()
    n_chunks_per_sample = chunk_dataset.n_chunks_per_sample
    n_chunks_per_sample = query_dataset.n_chunks_per_sample

    # Sample map.
    sample_ids = sorted(list(set(chunk_id // n_chunks_per_sample
                                 for chunk_id in range(*block["range"]))))
    sample_map = {i: chunk_dataset.sample_dataset[i] for i in sample_ids}
    sample_map = {}
    for i in sample_ids:
        sample = query_dataset.sample_dataset[i]
        sample_map[i] = {
            "dataset_idx": sample["dataset_idx"],
            "doc_ids": sample["doc_ids"],
        }

    # Embed block.
    embeddings = embed_block(chunk_dataset, block, embedder)
    embeddings = embed_block(query_dataset, block, embedder)

    # Query embeddings.
    _, filtered_neighbor_ids = query_embedding_block(
        index, banned_chunk_map, block["range"],
        embeddings, sample_map, n_chunks_per_sample)
        db_dataset, index, embeddings, block["range"],
        sample_map, n_chunks_per_sample)

    # Save neighbors.
    print_rank_0("save neighbors.")
...
...
@@ -173,22 +171,22 @@ def query_block_neighbors(index, banned_chunk_map, chunk_dataset,
            f.close()


def query_dataset_neighbors(index, banned_chunk_map, prefix,
                            chunk_dataset, neighbor_dir, embedder):
def query_dataset_neighbors(db_dataset, query_dataset,
                            prefix, neighbor_dir, index, embedder):
    '''Query neighbors of each chunk within a dataset.'''

    args = get_retro_args()

    def validate(f):
        assert f["neighbors"].shape[1] == args.retro_num_neighbors_target, \
        assert f["neighbors"].shape[1] == args.retro_query_num_neighbors_save, \
            "neighbors.shape == %s; num_neighbors_target == %d." % (
                str(f["neighbors"].shape),
                args.retro_num_neighbors_target,
            )
    n_missing_blocks, missing_neighbor_blocks = get_missing_blocks_by_rank(
        neighbor_dir,
        len(chunk_dataset),
        len(query_dataset),
        args.retro_block_size,
        validate=validate,
    )
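Each completed block above appears to be an HDF5 file whose "neighbors" dataset has shape (block_size, retro_query_num_neighbors_save); the validate callback re-checks that shape when a partially finished run resumes. A hedged sketch of writing and validating such a file (the file name and sizes are invented):

import h5py
import numpy as np

num_neighbors_save = 20  # stand-in for args.retro_query_num_neighbors_save
block_size = 1000

# Write a fake neighbor block in the assumed on-disk layout.
neighbors = np.full((block_size, num_neighbors_save), -1, dtype="int64")
with h5py.File("neighbors_block_demo.hdf5", "w") as f:
    f.create_dataset("neighbors", data=neighbors)

# Validate callback in the same spirit as the one above.
def validate(f):
    assert f["neighbors"].shape[1] == num_neighbors_save, \
        "neighbors.shape == %s." % str(f["neighbors"].shape)

with h5py.File("neighbors_block_demo.hdf5") as f:
    validate(f)
    print("block ok:", f["neighbors"].shape)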
...
...
@@ -199,16 +197,19 @@ def query_dataset_neighbors(index, banned_chunk_map,
        if block is not None:

            # Progress.
            print_rank_0("query '%s' block %d / %d ... %s." % (
            print_rank_0("query '%s' block %d / %d ... %s ... mem %.3f gb, %.1f%%." % (
                prefix,
                block_index,
                len(missing_neighbor_blocks),
                block["path"],
                os.path.basename(block["path"]),
                psutil.virtual_memory()[3] / 1024**3,
                psutil.virtual_memory()[2],
            ))

            # Query block neighbors.
            query_block_neighbors(index, banned_chunk_map,
                                  chunk_dataset, block, embedder)
            query_block_neighbors(db_dataset, query_dataset,
                                  index, embedder, block)

        # Synchronize progress across all ranks. (for easier observation)
        print_rank_0(" > waiting for other ranks to finish block.")
...
...
@@ -225,17 +226,16 @@ def query_pretraining_neighbors():
    # Load chunk db dataset.
    print_rank_0("load chunk db dataset.")
    chunk_db_dataset = get_db_merged_train_dataset()
    db_dataset = get_db_merged_train_dataset()
    db_dataset.load_doc_tuples()

    # Load index, banned chunk ids, datasets.
    # Load index.
    print_rank_0(" > get index.")
    index = get_index(chunk_db_dataset)
    print_rank_0(" > get banned doc-chunk id map.")
    banned_chunk_map = get_train_doc_chunk_map()
    index = get_index()

    # Load datasets.
    print_rank_0(" > get dataset map.")
    chunk_dataset_map = get_chunk_dataset_map()
    query_dataset_map = get_query_dataset_map()

    # Bert embedder.
    embedder = BertEmbedder(args.retro_bert_batch_size,
...
...
@@ -244,9 +244,9 @@ def query_pretraining_neighbors():
    # Query each (i.e., train, valid, test) dataset.
    print_rank_0(" > query.")
    for prefix, info in chunk_dataset_map.items():
    for prefix, info in query_dataset_map.items():
        print_rank_0(" > query '%s' dataset ... %d samples." %
                     (prefix, len(info["data"])))
        query_dataset_neighbors(index, banned_chunk_map,
                                prefix, info["data"], info["neighbor_dir"],
                                embedder)
        query_dataset_neighbors(db_dataset, info["data"],
                                prefix, info["neighbor_dir"],
                                index, embedder)