OpenDAS / Megatron-LM · Commit 15d0d55b
Authored May 07, 2020 by Neel Kant
Parent: 0f5e2809

    Add primitive filesystem-based IPC for indexer and trainer jobs

Showing 6 changed files with 139 additions and 54 deletions (whitespace changes hidden):
    faiss_test.py                     +0   -1
    indexer.py                        +112 -44
    megatron/data/realm_index.py      +8   -4
    megatron/model/realm_model.py     +7   -2
    megatron/training.py              +10  -0
    pretrain_realm.py                 +2   -3
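In outline, the commit makes the indexer and trainer jobs coordinate through two sentinel files on a shared filesystem: the indexer writes 'ready.index' when a freshly built block index is on disk, and the trainer writes 'ready.model' once it has reloaded that index and saved a new checkpoint for the indexer to embed with. A minimal sketch of that handshake, using the file names from the diff but otherwise illustrative helper names (not part of the commit itself):

    import os
    import time

    INDEX_COM_FILE = 'ready.index'   # indexer -> trainer: "a fresh block index is on disk"
    MODEL_COM_FILE = 'ready.model'   # trainer -> indexer: "a fresh checkpoint is on disk"

    def signal(path, ready):
        # Writing '1' marks the resource ready; '0' marks it consumed.
        with open(path, 'w') as f:
            f.write('1' if ready else '0')

    def wait_for(path, poll_seconds=5):
        # Block until the peer writes '1' into the sentinel file.
        while True:
            if os.path.exists(path):
                with open(path) as f:
                    if f.readline() == '1':
                        return
            time.sleep(poll_seconds)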
faiss_test.py  (+0 -1)

@@ -67,7 +67,6 @@ def print_accuracy_stats(name, gold_indices, estimated_indices):
     print('{:20s} First missing: {:4d} | All equal: {:4d} | Mixed: {:4d}'
           .format(name, *[results[s] for s in result_strs]))

 def create_and_test_gold(d, k, embeds, queries):
     times = [time.time()]
     gold_idx = index_factory(d, 'Flat')
 ...
hashed_index.py → indexer.py  (+112 -44)
+import os
+import time
 import torch
 from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
 ...
@@ -15,6 +18,7 @@ from pretrain_bert_ict import get_batch, model_provider
 def test_retriever():
+    # TODO: Update this because it's outdated and definitely won't run.
     initialize_megatron(extra_args_provider=None,
                         args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
     args = get_args()
 ...
@@ -57,75 +61,139 @@ def main():
     initialize_megatron(extra_args_provider=None,
                         args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
     args = get_args()

-    model = load_ict_checkpoint(only_block_model=True, no_grad=True)
+    ran_once = False
+    while True:
+        model = load_ict_checkpoint(only_block_model=True, no_grad=True, from_realm_chkpt=ran_once)
         model.eval()
         dataset = get_ict_dataset()
         data_iter = iter(get_one_epoch_dataloader(dataset))
         all_block_data = BlockData()
         hashed_index = RandProjectionLSHIndex(embed_size=128, num_buckets=32, whiten=True)
         i = 1
         total = 0
         while True:
             with torch.no_grad():
                 try:
                     query_tokens, query_pad_mask, \
                     block_tokens, block_pad_mask, block_index_data = get_batch(data_iter)
                 except:
                     break

                 block_index_data = detach(block_index_data)
                 block_indices = block_index_data[:, 3]
                 block_meta = block_index_data[:, :3]

                 block_logits = detach(model(None, None, block_tokens, block_pad_mask, only_block=True))
                 all_block_data.add_block_data(block_indices, block_logits, block_meta)
                 total += block_indices.size
                 i += 1

                 if i % 20 == 0:
                     print('Batch {:10d} | Total {:10d}'.format(i, total), flush=True)
                     if args.debug:
                         break

         all_block_data.save_shard(args.rank)
         torch.distributed.barrier()
         del model

         if args.rank == 0:
             all_block_data.consolidate_shards_and_save()
             hashed_index.hash_whitened_block_embeds(all_block_data)
             hashed_index.save_to_file()
         else:
             all_block_data.clear()

+        ran_once = True
+        set_index_com_file_ready()
+        torch.distributed.barrier()
+
+        while not check_model_com_file_ready():
+            time.sleep(5)
+        set_model_com_file_not_ready()
+
+
+INDEX_COM_FILE = 'ready.index'
+MODEL_COM_FILE = 'ready.model'
+
+
+def setup_index_com_file():
+    set_index_com_file_not_ready()
+
+
+def set_index_com_file_not_ready():
+    with open(INDEX_COM_FILE, 'w') as com_file:
+        com_file.write('0')
+
+
+def set_index_com_file_ready():
+    with open(INDEX_COM_FILE, 'w') as com_file:
+        com_file.write('1')
+
+
+def check_index_com_file_ready():
+    if os.path.exists(INDEX_COM_FILE):
+        with open(INDEX_COM_FILE, 'r') as com_file:
+            return bool(com_file.readline())
+    return False
+
+
+def setup_model_com_file():
+    set_model_com_file_not_ready()
+
+
+def set_model_com_file_not_ready():
+    with open(MODEL_COM_FILE, 'w') as com_file:
+        com_file.write('0')
+
+
+def set_model_com_file_ready():
+    with open(MODEL_COM_FILE, 'w') as com_file:
+        com_file.write('1')
+
+
+def check_model_com_file_ready():
+    if os.path.exists(MODEL_COM_FILE):
+        with open(MODEL_COM_FILE, 'r') as com_file:
+            return bool(com_file.readline())
+    return False


-def load_ict_checkpoint(only_query_model=False, only_block_model=False, no_grad=False):
+def load_ict_checkpoint(only_query_model=False, only_block_model=False, no_grad=False, from_realm_chkpt=False):
     args = get_args()
     model = get_model(lambda: model_provider(only_query_model, only_block_model))

+    load_path = args.load if from_realm_chkpt else args.ict_load
+
     if isinstance(model, torchDDP):
         model = model.module
-    tracker_filename = get_checkpoint_tracker_filename(args.ict_load)
+    tracker_filename = get_checkpoint_tracker_filename(load_path)
     with open(tracker_filename, 'r') as f:
         iteration = int(f.read().strip())
     assert iteration > 0

-    checkpoint_name = get_checkpoint_name(args.ict_load, iteration, False)
+    checkpoint_name = get_checkpoint_name(load_path, iteration, False)
     if mpu.get_data_parallel_rank() == 0:
         print('global rank {} is loading checkpoint {}'.format(
             torch.distributed.get_rank(), checkpoint_name))

     state_dict = torch.load(checkpoint_name, map_location='cpu')
+    ict_state_dict = state_dict['model']
+    if from_realm_chkpt:
+        ict_state_dict = ict_state_dict['retriever']['ict_model']

     if only_query_model:
-        state_dict['model'].pop('context_model')
+        ict_state_dict.pop('context_model')
     if only_block_model:
-        state_dict['model'].pop('question_model')
+        ict_state_dict.pop('question_model')

     if no_grad:
         with torch.no_grad():
-            model.load_state_dict(state_dict['model'])
+            model.load_state_dict(ict_state_dict)
     else:
-        model.load_state_dict(state_dict['model'])
+        model.load_state_dict(ict_state_dict)

     torch.distributed.barrier()
     if mpu.get_data_parallel_rank() == 0:
 ...
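One wrinkle in the new helpers: check_index_com_file_ready() and check_model_com_file_ready() return bool(com_file.readline()), and bool('0') is True, so any non-empty sentinel file reads as "ready", including one a *_not_ready setter just wrote. A stricter predicate would compare against the ready marker explicitly (a suggested fix, not part of the commit):

    import os

    def check_com_file_ready(path):
        # bool('0') is True, so the original bool(readline()) test reports
        # "ready" even after a *_not_ready setter wrote '0'. Compare against
        # the '1' marker instead.
        if os.path.exists(path):
            with open(path, 'r') as com_file:
                return com_file.readline().strip() == '1'
        return False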
megatron/data/realm_index.py  (+8 -4)

 ...
@@ -86,7 +86,8 @@ class FaissMIPSIndex(object):
         self.m = 5
         self.u = 0.99
         self.max_norm = None
-        self.block_mips_index = self.get_block_index()
+        self.block_mips_index = None
+        self._set_block_index()

     @classmethod
     def load_from_file(cls, fname):
 ...
@@ -101,7 +102,7 @@ class FaissMIPSIndex(object):
         return new_index

-    def get_block_index(self):
+    def _set_block_index(self):
         import faiss
         INDEX_TYPES = ['flat_l2', 'flat_ip']
         if self.index_type not in INDEX_TYPES:
 ...
@@ -109,10 +110,13 @@ class FaissMIPSIndex(object):
         if self.index_type == 'flat_l2':
             index = faiss.IndexFlatL2(self.embed_size + 2 * self.m)
-            return faiss.IndexIDMap(index)
+            self.block_mips_index = faiss.IndexIDMap(index)
         elif self.index_type == 'flat_ip':
             index = faiss.IndexFlatIP(self.embed_size)
-            return faiss.IndexIDMap(index)
+            self.block_mips_index = faiss.IndexIDMap(index)
+
+    def reset_index(self):
+        self._set_block_index()

     def add_block_embed_data(self, all_block_data, clear_block_data=False):
         """Add the embedding of each block to the underlying FAISS index"""
 ...
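For reference, the flat_ip branch above amounts to an exact inner-product index keyed by external block ids. A minimal standalone sketch of the same construction plus a lookup (assumes faiss and numpy are installed; sizes are illustrative):

    import numpy as np
    import faiss

    embed_size = 128
    index = faiss.IndexFlatIP(embed_size)      # exact inner-product search
    index = faiss.IndexIDMap(index)            # lets us key rows by block index

    block_embeds = np.random.randn(1000, embed_size).astype('float32')
    block_ids = np.arange(1000, dtype='int64') # external block indices
    index.add_with_ids(block_embeds, block_ids)

    query = np.random.randn(1, embed_size).astype('float32')
    scores, ids = index.search(query, 5)       # top-5 MIPS candidates
    print(ids[0], scores[0])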
megatron/model/realm_model.py  (+7 -2)

 ...
@@ -4,7 +4,7 @@ import torch.nn.functional as F
 from megatron import get_args
 from megatron.checkpointing import load_checkpoint
-from megatron.data.realm_index import detach
+from megatron.data.realm_index import detach, BlockData, FaissMIPSIndex
 from megatron.model import BertModel
 from megatron.model.utils import get_linear_layer, init_method_normal
 from megatron.module import MegatronModule
 ...
@@ -161,6 +161,12 @@ class REALMRetriever(MegatronModule):
         self.top_k = top_k
         self._ict_key = 'ict_model'

+    def reload_index(self):
+        args = get_args()
+        self.block_data = BlockData.load_from_file(args.block_data_path)
+        self.hashed_index.reset_index()
+        self.hashed_index.add_block_embed_data(self.block_data)
+
     def retrieve_evidence_blocks_text(self, query_text):
         """Get the top k evidence blocks for query_text in text form"""
         print("-" * 100)
 ...
@@ -256,7 +262,6 @@ class ICTBertModel(MegatronModule):
         if only_block:
             return self.embed_block(block_tokens, block_attention_mask)
         query_logits = self.embed_query(query_tokens, query_attention_mask)
         block_logits = self.embed_block(block_tokens, block_attention_mask)
 ...
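Note that reload_index() rebuilds the FAISS index from scratch (reset_index() followed by add_block_embed_data()) rather than adding into the live index. That is presumably because faiss.IndexIDMap.add_with_ids appends rather than upserts, so re-adding refreshed embeddings under the same block ids would leave stale duplicates behind. A small demonstration (our illustration, not commit code):

    import numpy as np
    import faiss

    index = faiss.IndexFlatIP(4)
    index = faiss.IndexIDMap(index)
    vecs = np.ones((3, 4), dtype='float32')
    ids = np.arange(3, dtype='int64')

    index.add_with_ids(vecs, ids)
    index.add_with_ids(vecs, ids)   # same ids again -> duplicated entries
    print(index.ntotal)             # prints 6, not 3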
megatron/training.py  (+10 -0)

 ...
@@ -39,6 +39,7 @@ from megatron.model import get_params_for_weight_decay_optimization
 from megatron.utils import check_adlr_autoresume_termination
 from megatron.utils import make_data_loader
 from megatron.utils import report_memory
+from indexer import check_index_com_file_ready, set_index_com_file_not_ready, set_model_com_file_ready

 def pretrain(train_valid_test_dataset_provider, model_provider,
 ...
@@ -363,6 +364,15 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
     timers('interval time').start()
     report_memory_flag = True
     while iteration < args.train_iters:
+        if hasattr(model, 'retriever'):
+            new_index_ready = check_index_com_file_ready()
+            if new_index_ready:
+                torch.distributed.barrier()
+                model.retriever.reload_index()
+                set_index_com_file_not_ready()
+                save_checkpoint(iteration, model, optimizer, lr_scheduler)
+                set_model_com_file_ready()
+
         loss_dict, skipped_iter = train_step(forward_step_func,
                                              train_data_iterator,
                                              model,
 ...
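Putting the two sides together: the indexer signals through ready.index and then blocks on ready.model, while the trainer polls ready.index once per training step, reloads its retriever index, saves a checkpoint, and answers through ready.model. A compressed two-process sketch of that round trip with the real work stubbed out (our illustration; only the file names come from the diff):

    import os
    import time
    from multiprocessing import Process

    INDEX_COM_FILE = 'ready.index'
    MODEL_COM_FILE = 'ready.model'

    def write_flag(path, value):
        with open(path, 'w') as f:
            f.write(value)

    def flag_is_set(path):
        return os.path.exists(path) and open(path).readline() == '1'

    def indexer():
        for round_num in range(2):
            time.sleep(1)                    # stand-in for embedding all blocks
            write_flag(INDEX_COM_FILE, '1')  # tell trainer a fresh index exists
            while not flag_is_set(MODEL_COM_FILE):
                time.sleep(0.5)              # wait for a fresh checkpoint
            write_flag(MODEL_COM_FILE, '0')
            print('indexer: reloading model, round', round_num)

    def trainer():
        for step in range(20):
            time.sleep(0.2)                  # stand-in for a train step
            if flag_is_set(INDEX_COM_FILE):
                write_flag(INDEX_COM_FILE, '0')
                print('trainer: reloaded index, saved checkpoint at step', step)
                write_flag(MODEL_COM_FILE, '1')

    if __name__ == '__main__':
        write_flag(INDEX_COM_FILE, '0')
        write_flag(MODEL_COM_FILE, '0')
        procs = [Process(target=indexer), Process(target=trainer)]
        for p in procs:
            p.start()
        for p in procs:
            p.join()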
pretrain_realm.py  (+2 -3)

 ...
@@ -18,7 +18,7 @@
 import torch
 import torch.nn.functional as F
-from hashed_index import load_ict_checkpoint, get_ict_dataset
+from indexer import load_ict_checkpoint, get_ict_dataset
 from megatron.data.realm_index import BlockData, RandProjectionLSHIndex, FaissMIPSIndex
 from megatron import get_args
 from megatron import get_timers
 ...
@@ -41,7 +41,7 @@ def model_provider():
     ict_dataset = get_ict_dataset(use_titles=False)
     all_block_data = BlockData.load_from_file(args.block_data_path)
     # hashed_index = RandProjectionLSHIndex.load_from_file(args.block_index_path)
-    hashed_index = FaissMIPSIndex(index_type='flat_l2', embed_size=128)
+    hashed_index = FaissMIPSIndex(index_type='flat_ip', embed_size=128)
     hashed_index.add_block_embed_data(all_block_data)

     # top_k + 1 because we may need to exclude trivial candidate
 ...
@@ -102,7 +102,6 @@ def forward_step(data_iterator, model):
     reduced_loss = reduce_losses([lm_loss])
     torch.cuda.synchronize()
-    print(reduced_loss, flush=True)

     return lm_loss, {'lm_loss': reduced_loss[0]}
 ...
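The switch from 'flat_l2' to 'flat_ip' in model_provider matters because REALM scores query/block pairs by inner product. An IndexFlatIP performs maximum-inner-product search directly, whereas L2 nearest neighbors agree with it only when all vectors share the same norm, which is presumably why the flat_l2 path in realm_index.py pads vectors to embed_size + 2 * self.m for a norm-augmentation transform. A quick numpy check that the two rankings genuinely differ on unnormalized embeddings:

    import numpy as np

    q = np.array([1.0, 0.0])
    blocks = np.array([[0.9, 0.0],   # close in L2, modest inner product
                       [5.0, 4.0]])  # far in L2, large inner product

    l2_best = np.argmin(((blocks - q) ** 2).sum(axis=1))
    ip_best = np.argmax(blocks @ q)
    print(l2_best, ip_best)          # prints 0 1: the rankings disagree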