OpenDAS / Megatron-LM · Commits

Commit bcd605f8, authored Feb 23, 2021 by Mostofa Patwary
Parent: 31d39ec0

    Added code for building embeddings and saving them
Showing 10 changed files with 359 additions and 98 deletions (+359, -98)
megatron/arguments.py                       +5    -0
megatron/checkpointing.py                   +13   -13
megatron/data/biencoder_dataset_utils.py    +25   -12
megatron/data/data_samplers.py              +14   -4
megatron/data/orqa_wiki_dataset.py          +205  -0
megatron/data/realm_index.py                +19   -19
megatron/indexer.py                         +76   -45
megatron/model/__init__.py                  +0    -2
megatron/training.py                        +0    -1
tools/create_doc_index.py                   +2    -2
megatron/arguments.py  (+5, -0)

@@ -635,6 +635,9 @@ def _add_data_args(parser):
     group.add_argument('--retriever-seq-length', type=int, default=256,
                        help='Maximum sequence length for the biencoder model '
                        ' for retriever')
+    group.add_argument('--sample-rate', type=float, default=1.0,
+                       help='sample rate for training data. Supposed to be 0 '
+                       ' < sample_rate < 1')
     group.add_argument('--mask-prob', type=float, default=0.15,
                        help='Probability of replacing a token with mask.')
     group.add_argument('--short-seq-prob', type=float, default=0.1,

@@ -704,6 +707,8 @@ def _add_biencoder_args(parser):
                        'ICT dataset')
     group.add_argument('--use-one-sent-docs', action='store_true',
                        help='Whether to use one sentence documents in ICT')
+    group.add_argument('--evidence-data-path', type=str, default=None,
+                       help='Path to Wikipedia Evidence frm DPR paper')

     # training
     group.add_argument('--retriever-report-topk-accuracies', nargs='+', type=int,
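Editor's note: to make the two new flags concrete, here is a standalone argparse sketch (not part of the commit) showing how they parse; the evidence path is an illustrative placeholder, not a path the commit references.

    # Standalone sketch: mimics how the new flags behave under argparse.
    import argparse

    parser = argparse.ArgumentParser()
    group = parser.add_argument_group(title='data and dataloader')
    group.add_argument('--sample-rate', type=float, default=1.0,
                       help='sample rate for training data. Supposed to be 0 '
                       ' < sample_rate < 1')
    group.add_argument('--evidence-data-path', type=str, default=None,
                       help='Path to Wikipedia Evidence frm DPR paper')

    # Hypothetical invocation; argparse converts dashes to underscores.
    args = parser.parse_args(['--sample-rate', '0.1',
                              '--evidence-data-path', '/data/psgs_w100.tsv'])
    print(args.sample_rate, args.evidence_data_path)   # 0.1 /data/psgs_w100.tsv

--sample-rate feeds the subsampling branch in OpenRetrievalEvidenceDataset below; --evidence-data-path names the DPR Wikipedia TSV that get_open_retrieval_wiki_dataset() loads.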
megatron/checkpointing.py  (+13, -13)

@@ -383,42 +383,42 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True):
     return iteration


-def load_ict_checkpoint(model, only_query_model=False, only_context_model=False, from_realm_chkpt=False):
-    """selectively load ICT models for indexing/retrieving from ICT or REALM checkpoints"""
+def load_biencoder_checkpoint(model, only_query_model=False,
+                              only_context_model=False, custom_load_path=None):
+    """
+    selectively load retrieval models for indexing/retrieving
+    from saved checkpoints
+    """
     args = get_args()

     model = utils.unwrap_model(model)

-    load_path = args.load if from_realm_chkpt else args.ict_load
+    load_path = custom_load_path if custom_load_path is not None else args.load

     tracker_filename = get_checkpoint_tracker_filename(load_path)
     with open(tracker_filename, 'r') as f:
         iteration = int(f.read().strip())

-    # assert iteration > 0
     checkpoint_name = get_checkpoint_name(load_path, iteration, False)
     if mpu.get_data_parallel_rank() == 0:
         print('global rank {} is loading checkpoint {}'.format(
             torch.distributed.get_rank(), checkpoint_name))

     state_dict = torch.load(checkpoint_name, map_location='cpu')
-    ict_state_dict = state_dict['model']
-    print(ict_state_dict)
-    sys.exit()
-    if from_realm_chkpt and mpu.get_data_parallel_rank() == 0:
-        print(" loading ICT state dict from REALM", flush=True)
-        ict_state_dict = ict_state_dict['retriever']['ict_model']
+    ret_state_dict = state_dict['model']

     if only_query_model:
-        ict_state_dict.pop('context_model')
+        ret_state_dict.pop('context_model')
     if only_context_model:
-        ict_state_dict.pop('query_model')
+        ret_state_dict.pop('query_model')

-    model.load_state_dict(ict_state_dict)
+    assert len(model) == 1
+    model[0].load_state_dict(ret_state_dict)
     torch.distributed.barrier()

     if mpu.get_data_parallel_rank() == 0:
         print(' successfully loaded {}'.format(checkpoint_name))

     return model
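Editor's note: the rename replaces the ICT/REALM toggle with an explicit custom_load_path, and the model is now expected as a one-element list (hence the new assert). Below is a standalone, runnable sketch (not part of the commit) of the tracker-file protocol the function follows; the tracker and checkpoint-directory names follow what I believe is Megatron's convention, so treat them as assumptions.

    # Standalone sketch of the load protocol: a tracker file records the
    # latest iteration, which names the checkpoint; unneeded sub-models
    # are popped from the state dict before loading.
    import os, tempfile

    def sketch_load_biencoder(load_path, only_query_model=False,
                              only_context_model=False):
        # assumed tracker filename, per Megatron convention
        tracker = os.path.join(load_path, 'latest_checkpointed_iteration.txt')
        with open(tracker, 'r') as f:
            iteration = int(f.read().strip())

        # assumed directory naming, e.g. iter_0001000
        checkpoint_name = os.path.join(load_path,
                                       'iter_{:07d}'.format(iteration))

        # stand-in for torch.load(checkpoint_name, map_location='cpu')['model']
        ret_state_dict = {'query_model': {}, 'context_model': {}}
        if only_query_model:
            ret_state_dict.pop('context_model')
        if only_context_model:
            ret_state_dict.pop('query_model')
        return checkpoint_name, ret_state_dict

    load_path = tempfile.mkdtemp()
    with open(os.path.join(load_path,
                           'latest_checkpointed_iteration.txt'), 'w') as f:
        f.write('1000')
    print(sketch_load_biencoder(load_path, only_context_model=True))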
megatron/data/biencoder_dataset_utils.py  (+25, -12)

@@ -4,10 +4,21 @@ import time
 import numpy as np
 import torch

-from megatron import mpu, print_rank_0
-from megatron.data.dataset_utils import create_masked_lm_predictions, pad_and_convert_to_numpy
-from megatron import get_args, get_tokenizer, print_rank_0, mpu
+from megatron import get_args, get_tokenizer, mpu, print_rank_0
+from megatron.data.dataset_utils import create_masked_lm_predictions, \
+    pad_and_convert_to_numpy
+from megatron.data.data_samplers import MegatronPretrainingSampler
+
+def make_attention_mask(source_block, target_block):
+    """
+    Returns a 2-dimensional (2-D) attention mask
+    :param source_block: 1-D array
+    :param target_block: 1-D array
+    """
+    mask = (target_block[None, :] >= 1) * (source_block[:, None] >= 1)
+    mask = mask.astype(np.int64)
+    # (source_length, target_length)
+    return mask

 def get_one_epoch_dataloader(dataset, micro_batch_size=None):
     """Specifically one epoch to be used in an indexing job."""

@@ -20,15 +31,17 @@ def get_one_epoch_dataloader(dataset, micro_batch_size=None):
     global_batch_size = micro_batch_size * world_size
     num_workers = args.num_workers

-    sampler = torch.utils.data.SequentialSampler(dataset)
-    # importantly, drop_last must be False to get all the data.
-    assert False, 'DistributedBatchSampler deprecated, change the implementation'
-    from megatron.data.samplers import DistributedBatchSampler
-    batch_sampler = DistributedBatchSampler(sampler,
-                                            batch_size=global_batch_size,
-                                            drop_last=False,
-                                            rank=rank,
-                                            world_size=world_size)
+    # Use megatron's sampler with consumed samples set to 0 as
+    # this is only for evaluation and don't intend to resume half way.
+    # Also, set the drop last to false as don't intend to remove
+    # the last batch
+    batch_sampler = MegatronPretrainingSampler(
+        total_samples=len(dataset),
+        consumed_samples=0,
+        micro_batch_size=args.micro_batch_size,
+        data_parallel_rank=mpu.get_data_parallel_rank(),
+        data_parallel_size=mpu.get_data_parallel_world_size(),
+        drop_last=False)

     return torch.utils.data.DataLoader(dataset,
                                        batch_sampler=batch_sampler,
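Editor's note: make_attention_mask builds a 2-D mask from two 1-D id arrays by treating any id >= 1 as a real token (pad id 0 is masked). A standalone, runnable sketch (not part of the commit) with toy BERT-style ids:

    # Standalone sketch: the mask logic above applied to toy token ids.
    import numpy as np

    def make_attention_mask(source_block, target_block):
        mask = (target_block[None, :] >= 1) * (source_block[:, None] >= 1)
        return mask.astype(np.int64)  # (source_length, target_length)

    tokens = np.array([101, 7592, 102, 0, 0])  # [CLS] hello [SEP] <pad> <pad>
    print(make_attention_mask(tokens, tokens))
    # The 3x3 top-left block is ones; rows/columns 3 and 4 (pads) are zero.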
megatron/data/data_samplers.py  (+14, -4)

@@ -57,7 +57,7 @@ def build_pretraining_data_loader(dataset, consumed_samples):
 class MegatronPretrainingSampler:

     def __init__(self, total_samples, consumed_samples, micro_batch_size,
-                 data_parallel_rank, data_parallel_size):
+                 data_parallel_rank, data_parallel_size, drop_last=True):
         # Keep a copy of input params for later use.
         self.total_samples = total_samples
         self.consumed_samples = consumed_samples

@@ -65,6 +65,7 @@ class MegatronPretrainingSampler:
         self.data_parallel_rank = data_parallel_rank
         self.micro_batch_times_data_parallel_size = \
             self.micro_batch_size * data_parallel_size
+        self.drop_last = drop_last

         # Sanity checks.
         assert self.total_samples > 0, \

@@ -81,17 +82,26 @@ class MegatronPretrainingSampler:
     def __len__(self):
         return self.total_samples

+    def get_start_end_idx(self):
+        start_idx = self.data_parallel_rank * self.micro_batch_size
+        end_idx = start_idx + self.micro_batch_size
+        return start_idx, end_idx
+
     def __iter__(self):
         batch = []
-        # Last batch if not complete will be dropped.
+        # Last batch will be dropped if drop_last is not set False
         for idx in range(self.consumed_samples, self.total_samples):
             batch.append(idx)
             if len(batch) == self.micro_batch_times_data_parallel_size:
-                start_idx = self.data_parallel_rank * self.micro_batch_size
-                end_idx = start_idx + self.micro_batch_size
+                start_idx, end_idx = self.get_start_end_idx()
                 yield batch[start_idx:end_idx]
                 batch = []

+        # Check the last partial batch and see drop_last is set
+        if len(batch) > 0 and not self.drop_last:
+            start_idx, end_idx = self.get_start_end_idx()
+            yield batch[start_idx:end_idx]


 class MegatronPretrainingRandomSampler:
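Editor's note: the new drop_last=False path matters for indexing, where every evidence row must be embedded exactly once. A standalone, runnable sketch (not part of the commit) with the distributed dependencies stripped out; names here are illustrative:

    # Two "ranks" iterate over 10 samples with micro_batch_size=2.
    class TinySampler:
        def __init__(self, total_samples, micro_batch_size,
                     data_parallel_rank, data_parallel_size, drop_last=True):
            self.total_samples = total_samples
            self.micro_batch_size = micro_batch_size
            self.data_parallel_rank = data_parallel_rank
            self.micro_batch_times_data_parallel_size = \
                micro_batch_size * data_parallel_size
            self.drop_last = drop_last

        def get_start_end_idx(self):
            start_idx = self.data_parallel_rank * self.micro_batch_size
            return start_idx, start_idx + self.micro_batch_size

        def __iter__(self):
            batch = []
            for idx in range(self.total_samples):
                batch.append(idx)
                if len(batch) == self.micro_batch_times_data_parallel_size:
                    start_idx, end_idx = self.get_start_end_idx()
                    yield batch[start_idx:end_idx]
                    batch = []
            # the new code path: emit the trailing partial batch
            if len(batch) > 0 and not self.drop_last:
                start_idx, end_idx = self.get_start_end_idx()
                yield batch[start_idx:end_idx]

    for rank in (0, 1):
        print(rank, list(TinySampler(10, 2, rank, 2, drop_last=False)))
    # rank 0 -> [[0, 1], [4, 5], [8, 9]]
    # rank 1 -> [[2, 3], [6, 7], []]

Note the design consequence: with a short final global batch, higher ranks can receive an empty (or truncated) slice, which downstream code has to tolerate.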
megatron/data/orqa_wiki_dataset.py  (new file, +205, -0)

# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Wikipedia dataset from DPR code for ORQA."""

from abc import ABC
import csv
import numpy as np
import random

import torch
from torch.utils.data import Dataset

from megatron import print_rank_0, get_args, get_tokenizer, mpu
from megatron.data.biencoder_dataset_utils import make_attention_mask


def get_open_retrieval_wiki_dataset():
    args = get_args()
    tokenizer = get_tokenizer()

    dataset = OpenRetrievalEvidenceDataset('2018 Wikipedia from DPR codebase',
                                           'evidence',
                                           args.evidence_data_path,
                                           tokenizer,
                                           args.retriever_seq_length)
    return dataset


def get_open_retrieval_batch(data_iterator):
    # Items and their type.
    keys = ['row_id', 'context', 'context_mask', 'context_types',
            'context_pad_mask']
    datatype = torch.int64

    # Broadcast data.
    data = None if data_iterator is None else next(data_iterator)
    data_b = mpu.broadcast_data(keys, data, datatype)

    # Unpack.
    row_id = data_b['row_id'].long()
    context = data_b['context'].long()

    # TODO: make the context mask a binary one
    context_mask = (data_b['context_mask'] < 0.5)

    context_types = data_b['context_types'].long()
    context_pad_mask = data_b['context_pad_mask'].long()

    return row_id, context, context_mask, context_types, context_pad_mask


def build_tokens_types_paddings_from_text(row, tokenizer, max_seq_length):
    """Build token types and paddings, trim if needed, and pad if needed."""

    title_ids = tokenizer.tokenize(row['title'])
    context_ids = tokenizer.tokenize(row['text'])

    # Appending the title of the context at front
    extended_context_ids = title_ids + [tokenizer.sep_id] + context_ids

    context_ids, context_types, context_pad_mask = \
        build_tokens_types_paddings_from_ids(extended_context_ids,
            max_seq_length, tokenizer.cls, tokenizer.sep, tokenizer.pad)

    return context_ids, context_types, context_pad_mask


# noinspection DuplicatedCode
def build_tokens_types_paddings_from_ids(text_ids, max_seq_length,
                                         cls_id, sep_id, pad_id):
    """Build token types and paddings, trim if needed, and pad if needed."""
    enc_ids = []
    tokentypes_enc = []

    # [CLS].
    enc_ids.append(cls_id)
    tokentypes_enc.append(0)

    # A.
    len_src = len(text_ids)
    enc_ids.extend(text_ids)
    tokentypes_enc.extend([0] * len_src)

    # Cap the size.
    if len(enc_ids) > max_seq_length - 1:
        enc_ids = enc_ids[0: max_seq_length - 1]
        tokentypes_enc = tokentypes_enc[0: max_seq_length - 1]

    # [SEP].
    enc_ids.append(sep_id)
    tokentypes_enc.append(0)

    num_tokens_enc = len(enc_ids)

    # Padding.
    padding_length = max_seq_length - len(enc_ids)
    if padding_length > 0:
        enc_ids.extend([pad_id] * padding_length)
        tokentypes_enc.extend([pad_id] * padding_length)

    pad_mask = ([1] * num_tokens_enc) + ([0] * padding_length)
    pad_mask = np.array(pad_mask, dtype=np.int64)

    return enc_ids, tokentypes_enc, pad_mask


def build_sample(row_id, context_ids, context_types, context_pad_mask):
    """Convert to numpy and return a sample consumed by the batch producer."""

    context_ids = np.array(context_ids, dtype=np.int64)
    context_types = np.array(context_types, dtype=np.int64)
    context_mask = make_attention_mask(context_ids, context_ids)

    sample = ({
        'row_id': row_id,
        'context': context_ids,
        'context_mask': context_mask,
        'context_types': context_types,
        'context_pad_mask': context_pad_mask
    })
    return sample


class OpenRetrievalEvidenceDataset(ABC, Dataset):
    """Open Retrieval Evidence dataset class."""

    def __init__(self, task_name, dataset_name, datapath, tokenizer,
                 max_seq_length):
        # Store inputs.
        self.task_name = task_name
        self.dataset_name = dataset_name
        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length
        print_rank_0(' > building {} dataset for {}:'.format(self.task_name,
                                                             self.dataset_name))
        # Process the files.
        print_rank_0(datapath)
        self.samples, self.id2text = self.process_samples_from_single_path(
                                        datapath)

        args = get_args()
        if args.sample_rate < 1:  # subsample
            k = int(len(self.samples) * args.sample_rate)
            self.samples = random.sample(self.samples, k)

        print_rank_0('  >> total number of samples: {}'.format(
                        len(self.samples)))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        row = self.samples[idx]

        context_ids, context_types, context_pad_mask = \
            build_tokens_types_paddings_from_text(row, self.tokenizer,
                self.max_seq_length)

        sample = build_sample(row['doc_id'],
                              context_ids,
                              context_types,
                              context_pad_mask)
        return sample

    @staticmethod
    def process_samples_from_single_path(filename):
        print_rank_0(' > Processing {} ...'.format(filename))
        total = 0

        rows = []
        id2text = {}

        with open(filename) as tsvfile:
            reader = csv.reader(tsvfile, delimiter='\t')
            next(reader, None)  # skip the headers
            for row in reader:
                # file format: doc_id, doc_text, title
                doc_id = int(row[0])
                text = row[1]
                title = row[2]
                rows.append({'doc_id': doc_id,
                             'text': text,
                             'title': title})

                assert doc_id not in id2text
                id2text[doc_id] = (text, title)

                total += 1
                if total % 100000 == 0:
                    print_rank_0('  > processed {} rows so far ...'.format(
                                    total))

        print_rank_0(' >> processed {} samples.'.format(len(rows)))
        return rows, id2text
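Editor's note: the trim/pad helper always reserves the last slot for [SEP] before padding to max_seq_length. A standalone, runnable sketch (not part of the commit) with toy ids (cls=101, sep=102, pad=0) and max_seq_length=8:

    # Standalone sketch of build_tokens_types_paddings_from_ids.
    import numpy as np

    def build(text_ids, max_seq_length, cls_id, sep_id, pad_id):
        enc_ids = [cls_id] + list(text_ids)
        tokentypes = [0] * len(enc_ids)
        # Cap so that [SEP] always fits.
        if len(enc_ids) > max_seq_length - 1:
            enc_ids = enc_ids[:max_seq_length - 1]
            tokentypes = tokentypes[:max_seq_length - 1]
        enc_ids.append(sep_id)
        tokentypes.append(0)
        num_tokens = len(enc_ids)
        padding_length = max_seq_length - num_tokens
        enc_ids.extend([pad_id] * padding_length)
        tokentypes.extend([pad_id] * padding_length)
        pad_mask = np.array([1] * num_tokens + [0] * padding_length,
                            dtype=np.int64)
        return enc_ids, tokentypes, pad_mask

    print(build([11, 12, 13], 8, 101, 102, 0))
    # ([101, 11, 12, 13, 102, 0, 0, 0], [0]*8, array([1, 1, 1, 1, 1, 0, 0, 0]))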
megatron/data/realm_index.py  (+19, -19)

@@ -15,11 +15,12 @@ def detach(tensor):
 class OpenRetreivalDataStore(object):
-    """Serializable data structure for holding data for blocks -- embeddings
-    and necessary metadata for Retriever"""
+    """
+    Serializable data structure for holding data for blocks --
+    embeddings and necessary metadata for Retriever
+    """
     def __init__(self, embedding_path=None, load_from_path=True, rank=None):
         self.embed_data = dict()
-        #self.meta_data = dict()
         if embedding_path is None:
             args = get_args()
             embedding_path = args.embedding_path

@@ -36,13 +37,13 @@ class OpenRetreivalDataStore(object):
     def state(self):
         return {
             'embed_data': self.embed_data,
-            #'meta_data': self.meta_data,
         }

     def clear(self):
-        """Clear the embedding data structures to save memory.
-        The metadata ends up getting used, and is also much smaller in dimensionality
-        so it isn't really worth clearing.
+        """
+        Clear the embedding data structures to save memory.
+        The metadata ends up getting used, and is also much smaller in
+        dimensionality so it isn't really worth clearing.
         """
         self.embed_data = dict()

@@ -56,35 +57,34 @@ class OpenRetreivalDataStore(object):
         print(">> Finished unpickling BlockData\n", flush=True)

         self.embed_data = state_dict['embed_data']
-        #self.meta_data = state_dict['meta_data']

-    #def add_block_data(self, block_indices, block_embeds, block_metas, allow_overwrite=False):
     def add_block_data(self, row_id, block_embeds, allow_overwrite=False):
-        """Add data for set of blocks
+        """
+        Add data for set of blocks
         :param row_id: 1D array of unique int ids for the blocks
         :param block_embeds: 2D array of embeddings of the blocks
-        #:param block_metas: 2D array of metadata for the blocks.
-        In the case of retriever this will be [start_idx, end_idx, doc_idx]
+        In the case of REALM this will be [start_idx, end_idx, doc_idx]
         """
-        #for idx, embed, meta in zip(block_indices, block_embeds, block_metas):
         for idx, embed in zip(row_id, block_embeds):
             if not allow_overwrite and idx in self.embed_data:
                 raise ValueError("Unexpectedly tried to overwrite block data")

             self.embed_data[idx] = np.float16(embed)
-            #self.meta_data[idx] = meta

     def save_shard(self):
-        """Save the block data that was created this in this process"""
+        """
+        Save the block data that was created this in this process
+        """
         if not os.path.isdir(self.temp_dir_name):
             os.makedirs(self.temp_dir_name, exist_ok=True)

         # save the data for each shard
-        with open('{}/{}.pkl'.format(self.temp_dir_name, self.rank), 'wb') as writer:
+        with open('{}/{}.pkl'.format(self.temp_dir_name, self.rank), 'wb') \
+            as writer:
             pickle.dump(self.state(), writer)

     def merge_shards_and_save(self):
-        """Combine all the shards made using self.save_shard()"""
+        # Combine all the shards made using save_shard
         shard_names = os.listdir(self.temp_dir_name)
         seen_own_shard = False

@@ -99,9 +99,9 @@ class OpenRetreivalDataStore(object):
             old_size = len(self.embed_data)
             shard_size = len(data['embed_data'])

-            # add the shard's data and check to make sure there is no overlap
+            # add the shard's data and check to make sure there
+            # is no overlap
             self.embed_data.update(data['embed_data'])
-            #self.meta_data.update(data['meta_data'])
             assert len(self.embed_data) == old_size + shard_size

         assert seen_own_shard
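Editor's note: the shard protocol is simple enough to demonstrate end to end. A standalone, runnable sketch (not part of the commit) of save_shard()/merge_shards_and_save(): each "rank" pickles its {id: fp16 embedding} dict, and the merge asserts the ids are disjoint.

    # Standalone sketch of the shard save/merge protocol.
    import os, pickle, tempfile
    import numpy as np

    temp_dir = tempfile.mkdtemp()

    # Each "rank" saves a shard, as in save_shard().
    for rank, ids in [(0, [0, 1]), (1, [2, 3])]:
        shard = {idx: np.float16(np.random.randn(4)) for idx in ids}
        with open('{}/{}.pkl'.format(temp_dir, rank), 'wb') as writer:
            pickle.dump({'embed_data': shard}, writer)

    # Rank 0 merges, as in merge_shards_and_save().
    embed_data = {}
    for name in os.listdir(temp_dir):
        with open(os.path.join(temp_dir, name), 'rb') as f:
            data = pickle.load(f)
        old_size, shard_size = len(embed_data), len(data['embed_data'])
        embed_data.update(data['embed_data'])
        # disjoint ids: the merged size must grow by exactly one shard
        assert len(embed_data) == old_size + shard_size

    print(sorted(embed_data.keys()))   # [0, 1, 2, 3]

Storing embeddings as np.float16 halves the size of the pickled index at a small precision cost, which is why add_block_data() casts before storing.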
megatron/indexer.py  (+76, -45)

@@ -4,27 +4,32 @@ import torch.distributed as dist
 from megatron import get_args
 from megatron import mpu
-from megatron.checkpointing import load_ict_checkpoint
-from megatron.data.ict_dataset import get_ict_dataset
+from megatron.checkpointing import load_biencoder_checkpoint
+from megatron.data.orqa_wiki_dataset import get_open_retrieval_wiki_dataset
+from megatron.data.orqa_wiki_dataset import get_open_retrieval_batch
 from megatron.data.biencoder_dataset_utils import get_one_epoch_dataloader
 from megatron.data.realm_index import detach, OpenRetreivalDataStore
-from megatron.data.biencoder_dataset_utils import get_ict_batch
 from megatron.model.biencoder_model import biencoder_model_provider
+#from megatron.model.realm_model import general_ict_model_provider
 from megatron.training import get_model


 class IndexBuilder(object):
-    """Object for taking one pass over a dataset and creating a BlockData of its embeddings"""
+    """
+    Object for taking one pass over a dataset and creating a BlockData of its
+    embeddings
+    """
     def __init__(self):
         args = get_args()
         self.model = None
         self.dataloader = None
-        self.block_data = None
+        self.evidence_embedder_obj = None
+        self.biencoder_shared_query_context_model = \
+            args.biencoder_shared_query_context_model

-        # need to know whether we're using a REALM checkpoint (args.load) or ICT checkpoint
+        # need to know whether we're using a REALM checkpoint (args.load)
+        # or ICT checkpoint
         assert not (args.load and args.ict_load)
-        self.using_realm_chkpt = args.ict_load is None
+        #self.using_realm_chkpt = args.ict_load is None

         self.log_interval = args.indexer_log_interval
         self.batch_size = args.indexer_batch_size

@@ -35,62 +40,88 @@ class IndexBuilder(object):
         self.iteration = self.total_processed = 0

     def load_attributes(self):
-        """Load the necessary attributes: model, dataloader and empty BlockData"""
-        model = get_model(lambda: biencoder_model_provider(only_context_model=True))
-        self.model = load_ict_checkpoint(model, only_context_model=True,
-                                         from_realm_chkpt=self.using_realm_chkpt)
-        sys.exit()
-        self.model.eval()
-        self.dataset = get_ict_dataset()
-        self.dataloader = iter(get_one_epoch_dataloader(self.dataset, self.batch_size))
-        self.block_data = OpenRetreivalDataStore(load_from_path=False)
-        print("load_attributes is done", flush=True)
-        sys.exit()
+        """
+        Load the necessary attributes: model, dataloader and empty BlockData
+        """
+        only_context_model = True
+        if self.biencoder_shared_query_context_model:
+            only_context_model = False
+
+        model = get_model(lambda: biencoder_model_provider(only_context_model \
+            = only_context_model, biencoder_shared_query_context_model = \
+            self.biencoder_shared_query_context_model))
+
+        self.model = load_biencoder_checkpoint(model,
+            only_context_model=only_context_model)
+
+        assert len(self.model) == 1
+        self.model[0].eval()
+
+        self.dataset = get_open_retrieval_wiki_dataset()
+        self.dataloader = iter(get_one_epoch_dataloader(self.dataset, \
+            self.batch_size))
+
+        self.evidence_embedder_obj = OpenRetreivalDataStore( \
+            load_from_path=False)

     def track_and_report_progress(self, batch_size):
-        """Utility function for tracking progress"""
+        """
+        Utility function for tracking progress
+        """
         self.iteration += 1
         self.total_processed += batch_size * self.num_total_builders
         if self.is_main_builder and self.iteration % self.log_interval == 0:
             print('Batch {:10d} | Total {:10d}'.format(self.iteration,
                 self.total_processed), flush=True)

     def build_and_save_index(self):
-        """Goes through one epoch of the dataloader and adds all data to this instance's BlockData.
+        """
+        Goes through one epoch of the dataloader and adds all data to this
+        instance's BlockData.

-        The copy of BlockData is saved as a shard, which when run in a distributed setting will be
-        consolidated by the rank 0 process and saved as a final pickled BlockData.
+        The copy of BlockData is saved as a shard, which when run in a
+        distributed setting will be consolidated by the rank 0 process
+        and saved as a final pickled BlockData.
         """
+        assert len(self.model) == 1
+        unwrapped_model = self.model[0]
+
+        while not hasattr(unwrapped_model, 'embed_text'):
+            unwrapped_model = unwrapped_model.module

         while True:
             try:
                 # batch also has query_tokens and query_pad_data
-                _, _, block_tokens, block_pad_mask, block_sample_data = \
-                    get_ict_batch(self.dataloader)
+                row_id, context_tokens, context_mask, context_types, \
+                    context_pad_mask = get_open_retrieval_batch( \
+                    self.dataloader)
             except (StopIteration, IndexError):
                 break

-            unwrapped_model = self.model
-            while not hasattr(unwrapped_model, 'embed_block'):
-                unwrapped_model = unwrapped_model.module
-
+            # TODO: can we add with torch.no_grad() to reduce memory usage
             # detach, separate fields and add to BlockData
-            block_logits = detach(unwrapped_model.embed_block(block_tokens, block_pad_mask))
-            detached_data = detach(block_sample_data)
+            assert context_mask.dtype == torch.bool
+            context_logits = unwrapped_model.embed_text(
+                unwrapped_model.context_model, context_tokens, context_mask,
+                context_types)
+            context_logits = detach(context_logits)
+            row_id = detach(row_id)

-            # block_sample_data is a 2D array [batch x 4]
-            # with columns [start_idx, end_idx, doc_idx, block_idx] same as class BlockSampleData
-            block_indices = detached_data[:, 3]
-            block_metas = detached_data[:, :3]
-
-            self.block_data.add_block_data(block_indices, block_logits, block_metas)
-            self.track_and_report_progress(batch_size=block_tokens.shape[0])
+            self.evidence_embedder_obj.add_block_data(row_id, context_logits)
+            self.track_and_report_progress(batch_size=len(row_id))

-        # This process signals to finalize its shard and then synchronize with the other processes
-        self.block_data.save_shard()
+        # This process signals to finalize its shard and then synchronize with
+        # the other processes
+        self.evidence_embedder_obj.save_shard()
         torch.distributed.barrier()
         del self.model

         # rank 0 process builds the final copy
         if self.is_main_builder:
-            self.block_data.merge_shards_and_save()
+            self.evidence_embedder_obj.merge_shards_and_save()
             # make sure that every single piece of data was embedded
-            assert len(self.block_data.embed_data) == len(self.dataset)
-        self.block_data.clear()
+            assert len(self.evidence_embedder_obj.embed_data) == \
+                len(self.dataset)
+        self.evidence_embedder_obj.clear()
+
+        # complete building the final copy
+        torch.distributed.barrier()
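Editor's note: stripped of the distributed machinery, build_and_save_index() is an embed-and-store loop over one epoch. A standalone, runnable sketch (not part of the commit), with a random stand-in for the context tower; embed_text and the datastore here are mock-ups, not megatron APIs.

    # Standalone sketch of the inner loop of build_and_save_index().
    import numpy as np

    def embed_text(context_tokens):            # stand-in for the context tower
        return np.float32(np.random.randn(len(context_tokens), 128))

    embed_data = {}                            # stand-in for OpenRetreivalDataStore
    dataloader = iter([([0, 1, 2], ['a', 'b', 'c']),
                       ([3, 4], ['d', 'e'])])  # (row_id, context_tokens) batches

    total_processed = 0
    while True:
        try:
            row_id, context_tokens = next(dataloader)
        except (StopIteration, IndexError):
            break
        context_logits = embed_text(context_tokens)
        for idx, embed in zip(row_id, context_logits):
            embed_data[idx] = np.float16(embed)   # as in add_block_data()
        total_processed += len(row_id)

    print(total_processed, len(embed_data))       # 5 5

The final assert in the commit (len(embed_data) == len(dataset)) is exactly this invariant: one stored embedding per evidence row, made possible by the drop_last=False sampler change above.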
megatron/model/__init__.py  (+0, -2)

@@ -34,13 +34,11 @@ from .bert_model import (BertModel,
                          BertModelFirstStage,
                          BertModelIntermediateStage,
                          BertModelLastStage)
-from .realm_model import ICTBertModel
 from .gpt_model import (GPTModel,
                         GPTModelFirstStage,
                         GPTModelIntermediateStage,
                         GPTModelLastStage)
 from .language_model import get_language_model
 from .module import FP16Module
-from .realm_model import ICTBertModel
megatron/training.py  (+0, -1)

@@ -44,7 +44,6 @@ from megatron.initialize import initialize_megatron
 from megatron.initialize import write_args_to_tensorboard
 from megatron.learning_rates import AnnealingLR
 from megatron.model import DistributedDataParallel as LocalDDP
-from megatron.model.realm_model import ICTBertModel
 from megatron.utils import check_adlr_autoresume_termination
 from megatron.utils import unwrap_model
 from megatron.data.data_samplers import build_pretraining_data_loader
...
tools/create_doc_index.py
View file @
bcd605f8
...
@@ -3,6 +3,7 @@ import sys
...
@@ -3,6 +3,7 @@ import sys
sys
.
path
.
append
(
os
.
path
.
abspath
(
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
sys
.
path
.
append
(
os
.
path
.
abspath
(
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
os
.
path
.
pardir
)))
os
.
path
.
pardir
)))
from
megatron
import
print_rank_0
from
megatron.indexer
import
IndexBuilder
from
megatron.indexer
import
IndexBuilder
from
megatron.initialize
import
initialize_megatron
from
megatron.initialize
import
initialize_megatron
...
@@ -24,9 +25,8 @@ def main():
...
@@ -24,9 +25,8 @@ def main():
initialize_megatron
(
extra_args_provider
=
None
,
initialize_megatron
(
extra_args_provider
=
None
,
args_defaults
=
{
'tokenizer_type'
:
'BertWordPieceLowerCase'
})
args_defaults
=
{
'tokenizer_type'
:
'BertWordPieceLowerCase'
})
index_builder
=
IndexBuilder
()
index_builder
=
IndexBuilder
()
sys
.
exit
()
index_builder
.
build_and_save_index
()
index_builder
.
build_and_save_index
()
print_rank_0
(
"Build and save indices: done!"
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
main
()
main
()
...
...