OpenDAS / Megatron-LM · Commit 612f438a

Authored Feb 17, 2021 by Mostofa Patwary
Parent: ebc95c35

    evaluation code ongoing
Showing 4 changed files with 68 additions and 31 deletions (+68 −31):

- megatron/checkpointing.py (+5 −3)
- megatron/data/biencoder_dataset_utils.py (+27 −0)
- megatron/data/realm_index.py (+23 −20)
- megatron/indexer.py (+13 −8)
megatron/checkpointing.py (View file @ 612f438a)
```diff
@@ -371,7 +371,7 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
     return iteration
 
-def load_ict_checkpoint(model, only_query_model=False, only_block_model=False, from_realm_chkpt=False):
+def load_ict_checkpoint(model, only_query_model=False, only_context_model=False, from_realm_chkpt=False):
     """selectively load ICT models for indexing/retrieving from ICT or REALM checkpoints"""
     args = get_args()
@@ -393,14 +393,16 @@ def load_ict_checkpoint(model, only_query_model=False, only_block_model=False, f
     state_dict = torch.load(checkpoint_name, map_location='cpu')
     ict_state_dict = state_dict['model']
+    print(ict_state_dict)
+    sys.exit()
     if from_realm_chkpt and mpu.get_data_parallel_rank() == 0:
         print(" loading ICT state dict from REALM", flush=True)
         ict_state_dict = ict_state_dict['retriever']['ict_model']
 
     if only_query_model:
         ict_state_dict.pop('context_model')
-    if only_block_model:
-        ict_state_dict.pop('question_model')
+    if only_context_model:
+        ict_state_dict.pop('query_model')
     model.load_state_dict(ict_state_dict)
     torch.distributed.barrier()
```
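The rename from `only_block_model` to `only_context_model` tracks the biencoder's query/context naming; the loader then pops whichever tower is not wanted before calling `load_state_dict`. Megatron's ICT state dicts nest each tower's weights under its submodel name, so a single `pop` removes the whole tower. Below is a minimal sketch of the same selective-load pattern using a plain `nn.Module` with flat keys; the toy `Biencoder` class is an assumption for illustration, not this repository's model:

```python
import torch
import torch.nn as nn

class Biencoder(nn.Module):
    """Toy stand-in for the ICT/biencoder model: two independent towers."""
    def __init__(self):
        super().__init__()
        self.query_model = nn.Linear(8, 8)
        self.context_model = nn.Linear(8, 8)

full = Biencoder().state_dict()  # flat keys like 'context_model.weight'

# Keep only the context tower, mirroring ict_state_dict.pop('query_model')
context_only = {k: v for k, v in full.items()
                if not k.startswith('query_model.')}

model = Biencoder()
# strict=False because the query tower's keys are intentionally absent
missing, unexpected = model.load_state_dict(context_only, strict=False)
assert all(k.startswith('query_model.') for k in missing) and not unexpected
```

The newly added `print(ict_state_dict)` / `sys.exit()` pair is debugging scaffolding consistent with the "evaluation code ongoing" commit message, not part of the final load path.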
megatron/data/biencoder_dataset_utils.py (View file @ 612f438a)
```diff
@@ -9,6 +9,33 @@ from megatron.data.dataset_utils import create_masked_lm_predictions, pad_and_co
 from megatron import get_args, get_tokenizer, print_rank_0, mpu
 
+def get_one_epoch_dataloader(dataset, micro_batch_size=None):
+    """Specifically one epoch to be used in an indexing job."""
+    args = get_args()
+
+    world_size = mpu.get_data_parallel_world_size()
+    rank = mpu.get_data_parallel_rank()
+    if micro_batch_size is None:
+        micro_batch_size = args.micro_batch_size
+    global_batch_size = micro_batch_size * world_size
+    num_workers = args.num_workers
+
+    sampler = torch.utils.data.SequentialSampler(dataset)
+    # importantly, drop_last must be False to get all the data.
+    assert False, 'DistributedBatchSampler deprecated, change the implementation'
+    from megatron.data.samplers import DistributedBatchSampler
+    batch_sampler = DistributedBatchSampler(sampler,
+                                            batch_size=global_batch_size,
+                                            drop_last=False,
+                                            rank=rank,
+                                            world_size=world_size)
+
+    return torch.utils.data.DataLoader(dataset,
+                                       batch_sampler=batch_sampler,
+                                       num_workers=num_workers,
+                                       pin_memory=True)
+
 def get_ict_batch(data_iterator):
     # Items and their type.
     keys = ['query_tokens', 'query_mask',
```
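The new `get_one_epoch_dataloader` is deliberately fenced off with `assert False` because `megatron.data.samplers.DistributedBatchSampler` is deprecated. One possible resolution of that TODO, sketched with PyTorch's built-in `DistributedSampler`; this is an assumption about how the implementation might change, not what this commit does:

```python
import torch
from torch.utils.data.distributed import DistributedSampler

def get_one_epoch_dataloader_sketch(dataset, rank, world_size,
                                    micro_batch_size, num_workers=2):
    """Hypothetical replacement: one sequential pass over `dataset`,
    sharded across data-parallel ranks."""
    sampler = DistributedSampler(dataset,
                                 num_replicas=world_size,
                                 rank=rank,
                                 shuffle=False)  # keep sequential order
    return torch.utils.data.DataLoader(dataset,
                                       batch_size=micro_batch_size,
                                       sampler=sampler,
                                       num_workers=num_workers,
                                       pin_memory=True)
```

One caveat for an indexing job: by default `DistributedSampler` pads the last batch by repeating samples so every rank sees the same count, whereas the old `drop_last=False` path yielded each row exactly once. The duplicated rows would land in more than one shard and trip the size assertion in `merge_shards_and_save`, so they would need filtering downstream.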
megatron/data/realm_index.py (View file @ 612f438a)
```diff
@@ -14,28 +14,29 @@ def detach(tensor):
     return tensor.detach().cpu().numpy()
 
-class BlockData(object):
-    """Serializable data structure for holding data for blocks -- embeddings and necessary metadata for REALM"""
-    def __init__(self, block_data_path=None, load_from_path=True, rank=None):
+class OpenRetreivalDataStore(object):
+    """Serializable data structure for holding data for blocks -- embeddings
+    and necessary metadata for Retriever"""
+    def __init__(self, embedding_path=None, load_from_path=True, rank=None):
         self.embed_data = dict()
-        self.meta_data = dict()
-        if block_data_path is None:
+        #self.meta_data = dict()
+        if embedding_path is None:
             args = get_args()
-            block_data_path = args.block_data_path
+            embedding_path = args.embedding_path
             rank = args.rank
-        self.block_data_path = block_data_path
+        self.embedding_path = embedding_path
         self.rank = rank
 
         if load_from_path:
             self.load_from_file()
 
-        block_data_name = os.path.splitext(self.block_data_path)[0]
+        block_data_name = os.path.splitext(self.embedding_path)[0]
         self.temp_dir_name = block_data_name + '_tmp'
 
     def state(self):
         return {
             'embed_data': self.embed_data,
-            'meta_data': self.meta_data,
+            #'meta_data': self.meta_data,
         }
 
     def clear(self):
@@ -50,26 +51,28 @@ class BlockData(object):
         if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
             print("\n> Unpickling BlockData", flush=True)
-        state_dict = pickle.load(open(self.block_data_path, 'rb'))
+        state_dict = pickle.load(open(self.embedding_path, 'rb'))
         if mpu.is_unitialized() or mpu.get_data_parallel_rank() == 0:
             print(">> Finished unpickling BlockData\n", flush=True)
 
         self.embed_data = state_dict['embed_data']
-        self.meta_data = state_dict['meta_data']
+        #self.meta_data = state_dict['meta_data']
 
-    def add_block_data(self, block_indices, block_embeds, block_metas, allow_overwrite=False):
+    #def add_block_data(self, block_indices, block_embeds, block_metas, allow_overwrite=False):
+    def add_block_data(self, row_id, block_embeds, allow_overwrite=False):
         """Add data for set of blocks
-        :param block_indices: 1D array of unique int ids for the blocks
+        :param row_id: 1D array of unique int ids for the blocks
         :param block_embeds: 2D array of embeddings of the blocks
-        :param block_metas: 2D array of metadata for the blocks.
+        #:param block_metas: 2D array of metadata for the blocks.
             In the case of REALM this will be [start_idx, end_idx, doc_idx]
         """
-        for idx, embed, meta in zip(block_indices, block_embeds, block_metas):
+        #for idx, embed, meta in zip(block_indices, block_embeds, block_metas):
+        for idx, embed in zip(row_id, block_embeds):
             if not allow_overwrite and idx in self.embed_data:
                 raise ValueError("Unexpectedly tried to overwrite block data")
 
             self.embed_data[idx] = np.float16(embed)
-            self.meta_data[idx] = meta
+            #self.meta_data[idx] = meta
 
     def save_shard(self):
         """Save the block data that was created this in this process"""
@@ -77,8 +80,8 @@ class BlockData(object):
         os.makedirs(self.temp_dir_name, exist_ok=True)
 
         # save the data for each shard
-        with open('{}/{}.pkl'.format(self.temp_dir_name, self.rank), 'wb') as data_file:
-            pickle.dump(self.state(), data_file)
+        with open('{}/{}.pkl'.format(self.temp_dir_name, self.rank), 'wb') as writer:
+            pickle.dump(self.state(), writer)
 
     def merge_shards_and_save(self):
         """Combine all the shards made using self.save_shard()"""
@@ -98,13 +101,13 @@ class BlockData(object):
             # add the shard's data and check to make sure there is no overlap
             self.embed_data.update(data['embed_data'])
-            self.meta_data.update(data['meta_data'])
+            #self.meta_data.update(data['meta_data'])
             assert len(self.embed_data) == old_size + shard_size
 
         assert seen_own_shard
 
         # save the consolidated shards and remove temporary directory
-        with open(self.block_data_path, 'wb') as final_file:
+        with open(self.embedding_path, 'wb') as final_file:
             pickle.dump(self.state(), final_file)
 
         shutil.rmtree(self.temp_dir_name, ignore_errors=True)
```
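The rename from `BlockData`/`block_data_path` to `OpenRetreivalDataStore`/`embedding_path` (the "Retreival" spelling is the commit's) also comments out the REALM-specific `meta_data`, leaving plain row-id to embedding storage. A hypothetical single-rank usage sketch of the shard-then-merge flow; the path, rank, and random embeddings are illustrative only:

```python
import numpy as np
from megatron.data.realm_index import OpenRetreivalDataStore

# One data-parallel rank fills its own shard.
store = OpenRetreivalDataStore(embedding_path='embeddings.pkl',
                               load_from_path=False, rank=0)

row_id = np.arange(128)                               # unique int ids
embeds = np.random.randn(128, 128).astype(np.float32)
store.add_block_data(row_id, embeds)                  # stored as float16

store.save_shard()             # writes embeddings_tmp/0.pkl
# ...after all ranks have saved, one process consolidates:
store.merge_shards_and_save()  # rewrites embeddings.pkl, removes tmp dir
```

Keeping `add_block_data` strict about overwrites (it raises `ValueError` unless `allow_overwrite=True`) is what lets `merge_shards_and_save` verify with a size assertion that no two shards held the same row.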
megatron/indexer.py (View file @ 612f438a)
```diff
+import sys
 import torch
 import torch.distributed as dist
@@ -5,10 +6,11 @@ from megatron import get_args
 from megatron import mpu
 from megatron.checkpointing import load_ict_checkpoint
 from megatron.data.ict_dataset import get_ict_dataset
-from megatron.data.realm_dataset_utils import get_one_epoch_dataloader
-from megatron.data.realm_index import detach, BlockData
-from megatron.data.realm_dataset_utils import get_ict_batch
-from megatron.model.realm_model import general_ict_model_provider
+from megatron.data.biencoder_dataset_utils import get_one_epoch_dataloader
+from megatron.data.realm_index import detach, OpenRetreivalDataStore
+from megatron.data.biencoder_dataset_utils import get_ict_batch
+from megatron.model.biencoder_model import biencoder_model_provider
+#from megatron.model.realm_model import general_ict_model_provider
 from megatron.training import get_model
@@ -34,13 +36,16 @@ class IndexBuilder(object):
     def load_attributes(self):
         """Load the necessary attributes: model, dataloader and empty BlockData"""
-        model = get_model(lambda: general_ict_model_provider(only_block_model=True))
-        self.model = load_ict_checkpoint(model, only_block_model=True, from_realm_chkpt=self.using_realm_chkpt)
+        model = get_model(lambda: biencoder_model_provider(only_context_model=True))
+        self.model = load_ict_checkpoint(model, only_context_model=True, from_realm_chkpt=self.using_realm_chkpt)
+        sys.exit()
         self.model.eval()
         self.dataset = get_ict_dataset()
         self.dataloader = iter(get_one_epoch_dataloader(self.dataset, self.batch_size))
-        self.block_data = BlockData(load_from_path=False)
+        self.block_data = OpenRetreivalDataStore(load_from_path=False)
+        print("load_attributes is done", flush=True)
+        sys.exit()
 
     def track_and_report_progress(self, batch_size):
         """Utility function for tracking progress"""
         self.iteration += 1
```
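With the provider and data store swapped to their biencoder equivalents, and the temporary `sys.exit()`/`print` calls marking where the in-progress work stops, `IndexBuilder` is evidently meant to drive a loop along the following lines. This is a sketch under assumptions: the batch keys and the context-model forward call are guesses, since `build_and_save_index` is not shown in this diff:

```python
import torch
from megatron.data.realm_index import detach, OpenRetreivalDataStore

def build_index_sketch(model, dataloader, store):
    """Hypothetical indexing loop: embed every block once, store, shard."""
    model.eval()
    with torch.no_grad():
        for batch in dataloader:
            # Assumed batch keys, by analogy with get_ict_batch's
            # 'query_tokens'/'query_mask' convention.
            row_id = batch['row_id']
            tokens, mask = batch['context_tokens'], batch['context_mask']
            embeds = model(tokens, mask)  # context-only forward (assumed)
            store.add_block_data(detach(row_id), detach(embeds))
    store.save_shard()
    torch.distributed.barrier()           # wait until all shards exist
    if torch.distributed.get_rank() == 0:
        store.merge_shards_and_save()     # consolidate on one rank
```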