OpenDAS / Megatron-LM · Commits · 5684f904

Commit 5684f904, authored May 19, 2020 by Neel Kant

    Full cycle of communication complete. Also added BasicIndexBuilder

Parent: d4b00be0
Showing 4 changed files with 82 additions and 16 deletions
  indexer.py                       +56 -12
  megatron/checkpointing.py         +3  -2
  megatron/model/realm_model.py     +3  -0
  megatron/training.py             +20  -2
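Taken together, the four files close a two-sided handshake: the index-builder ranks flip a global INDEX_READY flag and broadcast it once a fresh block index has been written, while the training ranks watch for the flip, reload the retriever's index, save a checkpoint, and flip the flag back. The sketch below is an illustration of that flag exchange only, not code from this commit; it assumes the flag is kept in a one-element tensor (torch.distributed.broadcast operates on tensors) and that index_group is a hypothetical subgroup created elsewhere with dist.new_group().

# Illustrative sketch only, not code from this commit.
import torch
import torch.distributed as dist

def make_flag():
    # Shared readiness flag: 0 = index not ready, 1 = index ready.
    return torch.zeros(1, dtype=torch.long)

def signal_index_ready(index_ready, main_builder_rank, index_group):
    # Index-builder side: the main builder toggles the flag, then posts a
    # non-blocking broadcast so the training ranks can pick it up.
    if dist.get_rank() == main_builder_rank:
        index_ready.fill_(1 - int(index_ready.item()))
    handle = dist.broadcast(index_ready, src=main_builder_rank, async_op=True)
    dist.barrier(group=index_group)  # keep the builder ranks in step
    return handle

def watch_index_ready(index_ready, main_builder_rank):
    # Training side: post a non-blocking receive for the next flag value; the
    # caller checks handle.is_completed() before trusting index_ready again.
    return dist.broadcast(index_ready, src=main_builder_rank, async_op=True)

The train() changes further down keep a recv_handle in this same spirit, although they read the flag variable directly.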
indexer.py
...
...
@@ -110,15 +110,14 @@ class AsyncIndexBuilder(object):
     def run_async(self):
         while True:
-            print("Starting (again!)")
-            self.build_index()
-            self.save_index()
+            print("Starting (again!)", flush=True)
+            self.build_and_save_index()
+            self.send_index_ready_signal()
             while INDEX_READY == 1:
-                print("Waiting for new model checkpoint.")
-                time.sleep(1)
+                print("Waiting for new model checkpoint.", flush=True)
+                time.sleep(5)
-            self.load_model()
+            self.load_attributes()

     def load_attributes(self):
         try:
...
...
@@ -129,7 +128,7 @@ class AsyncIndexBuilder(object):
         self.dataloader = iter(get_one_epoch_dataloader(get_ict_dataset()))
         self.block_data = BlockData()

-    def build_index(self):
+    def build_and_save_index(self):
         i = 1
         total = 0
         while True:
...
...
@@ -149,7 +148,7 @@ class AsyncIndexBuilder(object):
             total += block_indices.size
             i += 1
-            if i % 2000 == 0:
+            if i % 10 == 0:
                 print('Batch {:10d} | Total {:10d}'.format(i, total), flush=True)

             if self.debug:
                 break
...
...
@@ -162,27 +161,68 @@ class AsyncIndexBuilder(object):
             print_rank_0(">>> training terminated. Returning")
             sys.exit(0)

     def save_index(self):
         self.block_data.save_shard(self.rank)
         torch.distributed.barrier()
         del self.model

         if self.is_main_builder:
             self.block_data.consolidate_shards_and_save(ignore_shard=self.rank)
-        else:
-            self.block_data.clear()
+        self.block_data.clear()

+    def send_index_ready_signal(self):
+        global INDEX_READY
+        if self.is_main_builder:
+            INDEX_READY = 1 - INDEX_READY
+            print("Switched INDEX_READY", flush=True)
+
+        import time
+        print(time.ctime(time.time()), flush=True)
+        send_handle = dist.broadcast(INDEX_READY, self.main_builder_idx, async_op=True)
+        torch.distributed.barrier(get_index_group())
+        recv_handle = dist.broadcast(INDEX_READY, 0, async_op=True)
+
+
+class BasicIndexBuilder(object):
+    def __init__(self):
+        args = get_args()
+        self.rank = args.rank
+        self.model = load_ict_checkpoint(only_block_model=True, no_grad=True, from_realm_chkpt=False)
+        self.model.eval()
+        self.dataloader = iter(get_one_epoch_dataloader(get_ict_dataset()))
+        self.block_data = BlockData()
+
+    def build_and_save_index(self):
+        i = 1
+        total = 0
+        while True:
+            with torch.no_grad():
+                try:
+                    query_tokens, query_pad_mask, \
+                        block_tokens, block_pad_mask, block_index_data = get_batch(self.dataloader)
+                except:
+                    break
+
+            block_index_data = detach(block_index_data)
+            block_indices = block_index_data[:, 3]
+            block_meta = block_index_data[:, :3]
+            block_logits = detach(self.model(None, None, block_tokens, block_pad_mask, only_block=True))
+            self.block_data.add_block_data(block_indices, block_logits, block_meta)
+
+            total += block_indices.size
+            i += 1
+            if i % 2000 == 0:
+                print('Batch {:10d} | Total {:10d}'.format(i, total), flush=True)
+
+        self.block_data.save_shard(self.rank)
+        torch.distributed.barrier()
+        del self.model
+
+        if self.rank == 0:
+            self.block_data.consolidate_shards_and_save(ignore_shard=self.rank)
+
+        self.block_data.clear()


 def load_ict_checkpoint(only_query_model=False, only_block_model=False, no_grad=False, from_realm_chkpt=False):
     args = get_args()
     model = get_model(lambda: model_provider(only_query_model, only_block_model))
...
...
@@ -270,4 +310,8 @@ def get_one_epoch_dataloader(dataset):

 if __name__ == "__main__":
-    main()
+    initialize_megatron(extra_args_provider=None,
+                        args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
+    index_builder = BasicIndexBuilder()
+    index_builder.build_and_save_index()
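BasicIndexBuilder above writes one shard of block embeddings per rank (save_shard), lets rank 0 merge the other shards into a single file (consolidate_shards_and_save with its own shard skipped, since its data is still in memory), and then clears its buffers. The BlockData class itself is not part of this diff; the sketch below is one plausible shape for such a shard-then-consolidate store, with field names and file layout that are assumptions rather than anything taken from the Megatron-LM sources.

# Illustrative sketch only; not the project's BlockData implementation.
import os
import pickle

class ShardedBlockData:
    def __init__(self, outdir="block_data"):
        self.outdir = outdir
        self.embed_data = {}   # block id -> embedding vector
        self.meta_data = {}    # block id -> (start, end, doc) metadata

    def add_block_data(self, indices, embeds, metas):
        for idx, embed, meta in zip(indices, embeds, metas):
            self.embed_data[int(idx)] = embed
            self.meta_data[int(idx)] = meta

    def clear(self):
        self.embed_data, self.meta_data = {}, {}

    def save_shard(self, rank):
        # Each rank writes its own shard file.
        os.makedirs(self.outdir, exist_ok=True)
        with open(os.path.join(self.outdir, f"shard_{rank}.pkl"), "wb") as f:
            pickle.dump((self.embed_data, self.meta_data), f)

    def consolidate_shards_and_save(self, ignore_shard=None):
        # Merge every other rank's shard into the in-memory data, then write
        # one consolidated file.
        for name in sorted(os.listdir(self.outdir)):
            if not name.startswith("shard_") or name == f"shard_{ignore_shard}.pkl":
                continue
            with open(os.path.join(self.outdir, name), "rb") as f:
                embeds, metas = pickle.load(f)
            self.embed_data.update(embeds)
            self.meta_data.update(metas)
        with open(os.path.join(self.outdir, "block_data.pkl"), "wb") as f:
            pickle.dump((self.embed_data, self.meta_data), f)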
megatron/checkpointing.py
...
...
@@ -24,6 +24,7 @@ import torch
 from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP

 from megatron import mpu
+from megatron.mpu.initialize import get_train_group
 from megatron import get_args
 from megatron import print_rank_0
...
...
@@ -118,14 +119,14 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
        print(' successfully saved {}'.format(checkpoint_name))

     # Wait so everyone is done (necessary)
-    torch.distributed.barrier()
+    torch.distributed.barrier(get_train_group())

     # And update the latest iteration
     if torch.distributed.get_rank() == 0:
         tracker_filename = get_checkpoint_tracker_filename(args.save)
         with open(tracker_filename, 'w') as f:
             f.write(str(iteration))

     # Wait so everyone is done (not necessary)
-    torch.distributed.barrier()
+    torch.distributed.barrier(get_train_group())


 def load_checkpoint(model, optimizer, lr_scheduler):
...
...
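The checkpointing change swaps the two global barriers in save_checkpoint for barriers over get_train_group(): once some ranks are dedicated to index building and never enter save_checkpoint, a world-wide barrier would hang, so only the training ranks synchronize with each other. get_train_group() is imported from megatron.mpu.initialize, but its definition is not part of this diff; the helper below is a hypothetical sketch of how such a subgroup accessor could look, not the actual implementation.

# Hypothetical helper in the spirit of get_train_group(); the real one in
# megatron/mpu/initialize.py may differ.
import torch.distributed as dist

_TRAIN_GROUP = None

def initialize_train_group(max_training_rank):
    # dist.new_group() is collective: every rank must call it, even ranks
    # that will not be members of the resulting group.
    global _TRAIN_GROUP
    _TRAIN_GROUP = dist.new_group(ranks=list(range(max_training_rank)))

def get_train_group():
    assert _TRAIN_GROUP is not None, 'train group is not initialized'
    return _TRAIN_GROUP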
megatron/model/realm_model.py
...
...
@@ -163,8 +163,11 @@ class REALMRetriever(MegatronModule):
     def reload_index(self):
         args = get_args()
+        print("loading from file", flush=True)
         self.block_data = BlockData.load_from_file(args.block_data_path)
+        print("resetting index", flush=True)
         self.hashed_index.reset_index()
+        print("adding block data", flush=True)
         self.hashed_index.add_block_embed_data(self.block_data)

     def prep_query_text_for_retrieval(self, query_text):
...
...
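reload_index drops whatever the retriever was serving and rebuilds it from the freshly consolidated block data: load the embeddings from disk, reset the hashed index, and re-add every block embedding. The HashedIndex internals are not shown in this commit; purely as an illustration, a FAISS-backed reset-and-rebuild supporting the same reset_index / add_block_embed_data calls could look like the sketch below (class layout and field names are assumptions).

# Illustrative only; not the project's HashedIndex.
import numpy as np
import faiss

class SimpleBlockIndex:
    def __init__(self, embed_dim):
        self.embed_dim = embed_dim
        self.index = faiss.IndexFlatIP(embed_dim)  # inner-product search
        self.row_to_block = []                     # row id -> block id

    def reset_index(self):
        # Drop all stored vectors before re-adding the new embeddings.
        self.index.reset()
        self.row_to_block = []

    def add_block_embed_data(self, block_data):
        # block_data.embed_data assumed to map block id -> embedding vector.
        block_ids = list(block_data.embed_data.keys())
        embeds = np.asarray([block_data.embed_data[b] for b in block_ids],
                            dtype=np.float32)
        self.index.add(embeds)
        self.row_to_block.extend(block_ids)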
megatron/training.py
...
...
@@ -373,13 +373,29 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
     timers('interval time').start()
     report_memory_flag = True

+    import time
+    print(">>> going to sleep", flush=True)
+    time.sleep(10)
+    print(">>> woke from sleep", flush=True)
+    print(time.ctime(time.time()), flush=True)
+
+    global INDEX_READY
+    recv_handle = torch.distributed.broadcast(INDEX_READY, args.max_training_rank, async_op=True)
+    print(">>>>>>>> Created recv handle", flush=True)
+
     while iteration < args.train_iters:
-        if hasattr(model, 'retriever') and INDEX_READY == 1:
-            model.retriever.reload_index()
+        print("INDEX READY: ", INDEX_READY)
+        if args.max_training_rank is not None and INDEX_READY == 1:
+            print(">>>>>>> entering the good stuff", flush=True)
+            true_model = model
+            if hasattr(true_model, 'module'):
+                true_model = true_model.module
+            if hasattr(true_model, 'module'):
+                true_model = true_model.module
+            print(">>>>>>> starting to reload index", flush=True)
+            true_model.retriever.reload_index()
+            print(">>>>>>> starting to save checkpoint", flush=True)
+            save_checkpoint(iteration, model, optimizer, lr_scheduler)
+            print(">>>>>>> saved checkpoint", flush=True)
+            if args.rank == 0:
+                INDEX_READY = 1 - INDEX_READY
...
...
@@ -387,6 +403,8 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
             send_handle = torch.distributed.broadcast(INDEX_READY, 0, async_op=True)
             torch.distributed.barrier(get_train_group())
             recv_handle = torch.distributed.broadcast(INDEX_READY, args.max_training_rank, async_op=True)
+        else:
+            print(">>>>>>> moving right along", flush=True)

         loss_dict, skipped_iter = train_step(forward_step_func, train_data_iterator,
...
...
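One caveat worth noting about the pattern in train(): with async_op=True, torch.distributed.broadcast returns a work handle, and the received value is only guaranteed to be visible in the broadcast tensor once that handle has completed. A hedged sketch of a polling helper in that spirit is below; the names and structure are assumptions for illustration, not part of this commit.

# Illustrative polling helper, assuming the flag is a one-element tensor.
import torch.distributed as dist

def poll_flag(flag, src_rank, pending_handle):
    # Returns (flag_value, handle). When the outstanding receive has finished,
    # read the flag and immediately post the next receive so the trainer keeps
    # listening for flips; otherwise report nothing new.
    if pending_handle is not None and pending_handle.is_completed():
        value = int(flag.item())
        pending_handle = dist.broadcast(flag, src=src_rank, async_op=True)
        return value, pending_handle
    return None, pending_handle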