Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Megatron-LM
Commits
e338e311
Commit
e338e311
authored
May 18, 2020
by
Neel Kant
Browse files
Indexer_async works in theory
parent
edaf2aab
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
36 additions
and
65 deletions
+36
-65
indexer.py
indexer.py
+5
-21
indexer_async.py
indexer_async.py
+8
-3
indexer_utils.py
indexer_utils.py
+0
-39
megatron/arguments.py
megatron/arguments.py
+1
-0
megatron/initialize.py
megatron/initialize.py
+8
-2
megatron/mpu/initialize.py
megatron/mpu/initialize.py
+14
-0
No files found.
indexer.py
View file @
e338e311
...
...
@@ -20,23 +20,6 @@ from pretrain_bert_ict import get_batch, model_provider
from
indexer_utils
import
set_index_com_file_ready
,
set_model_com_file_not_ready
,
check_model_com_file_ready
# TODO re: main()
# consider broadcasting/all-reducing all in memory rather than using the filesystem
# create a different process group in the same nccl world - don't have to use chkpts on disc or transfer things on disc
# torch distributed new group, constains a list of rank, gives back a group which I can hand to the collective operations
# create a training process group, indexing process group
# pass the training group to the distributed DDP, instead of the large world process group
# use indexing process group for the shard-combining
# communication group between process "8" and process "0" which tells training group that there's a new index
# also, process 0 sends process 8 the new model
# if i want to launch a separate process for indexing, may have to work with environment variables to
# allocate the resources well. Have to subsequently assign the correct gpus to the indexing job
# consider initializing everything in a single group and break off processes based on the ranks
# for debugging purposes, make it so that the training process group checks every some number of intervals
# and if it isn't ready, then wait so that it's consistent. Start with using the filesystem
def
test_retriever
():
# TODO: Update this because it's outdated and definitely won't run.
initialize_megatron
(
extra_args_provider
=
None
,
...
...
@@ -66,9 +49,11 @@ def main():
initialize_megatron
(
extra_args_provider
=
None
,
args_defaults
=
{
'tokenizer_type'
:
'BertWordPieceLowerCase'
})
args
=
get_args
()
ran_once
=
False
while
True
:
model
=
load_ict_checkpoint
(
only_block_model
=
True
,
no_grad
=
True
,
from_realm_chkpt
=
ran_once
)
try
:
model
=
load_ict_checkpoint
(
only_block_model
=
True
,
no_grad
=
True
,
from_realm_chkpt
=
True
)
except
:
model
=
load_ict_checkpoint
(
only_block_model
=
True
,
no_grad
=
True
,
from_realm_chkpt
=
False
)
model
.
eval
()
dataset
=
get_ict_dataset
()
data_iter
=
iter
(
get_one_epoch_dataloader
(
dataset
))
...
...
@@ -93,7 +78,7 @@ def main():
total
+=
block_indices
.
size
i
+=
1
if
i
%
20
==
0
:
if
i
%
20
00
==
0
:
print
(
'Batch {:10d} | Total {:10d}'
.
format
(
i
,
total
),
flush
=
True
)
if
args
.
debug
:
break
...
...
@@ -107,7 +92,6 @@ def main():
else
:
all_block_data
.
clear
()
ran_once
=
True
set_index_com_file_ready
()
torch
.
distributed
.
barrier
()
if
args
.
async_indexer
:
...
...
indexer_async.py
View file @
e338e311
...
...
@@ -111,7 +111,7 @@ class AsyncIndexBuilder(object):
pprint
(
"-"
*
100
)
for
i
in
range
(
5
):
# simulating building the index which takes 20 seconds
time
.
sleep
(
2
0
)
time
.
sleep
(
1
0
)
pprint
(
'built the index. Time: {}'
.
format
(
time
.
ctime
(
time
.
time
())))
args
=
get_args
()
...
...
@@ -121,8 +121,11 @@ class AsyncIndexBuilder(object):
INDEX_READY
=
1
-
INDEX_READY
send_handle
=
dist
.
broadcast
(
INDEX_READY
,
args
.
max_training_rank
,
async_op
=
True
)
pprint
(
"Broadcasted index ready = "
,
INDEX_READY
)
else
:
send_recv_handle
=
dist
.
broadcast
(
INDEX_READY
,
args
.
max_training_rank
,
async_op
=
True
)
torch
.
distributed
.
barrier
(
INDEX_GROUP
)
pprint
(
"Synced after broadcasting"
)
recv_handle
=
dist
.
broadcast
(
INDEX_READY
,
0
,
async_op
=
True
)
while
INDEX_READY
==
1
:
...
...
@@ -154,12 +157,14 @@ class AsyncREALMTrainer(object):
# INDEX_READY is 1
if
self
.
rank
==
0
:
INDEX_READY
=
1
-
INDEX_READY
send_handle
=
dist
.
broadcast
(
INDEX_READY
,
self
.
rank
,
async_op
=
True
)
send_handle
=
dist
.
broadcast
(
INDEX_READY
,
0
,
async_op
=
True
)
pprint
(
"Broadcasted index ready = "
,
INDEX_READY
)
else
:
send_recv_handle
=
dist
.
broadcast
(
INDEX_READY
,
0
,
async_op
=
True
)
torch
.
distributed
.
barrier
(
TRAIN_GROUP
)
pprint
(
"Synced after broadcasting"
)
if
__name__
==
"__main__"
:
initialize_and_run_async_megatron
(
args_defaults
=
{
'tokenizer_type'
:
'BertWordPieceLowerCase'
})
indexer_utils.py
deleted
100644 → 0
View file @
edaf2aab
# Filesystem-based signalling between the training and indexing jobs.
# Each flag file holds a single character: '1' (ready) or '0' (not ready).
INDEX_COM_FILE = 'ready.index'
MODEL_COM_FILE = 'ready.model'
def set_index_com_file_not_ready():
    """Lower the index-ready flag by writing '0' into the index com file."""
    with open(INDEX_COM_FILE, mode='w') as flag:
        flag.write('0')
def set_index_com_file_ready():
    """Raise the index-ready flag by writing '1' into the index com file."""
    with open(INDEX_COM_FILE, mode='w') as flag:
        flag.write('1')
def check_index_com_file_ready():
    """Return True iff the index com file exists and contains '1'.

    If the file does not exist yet, it is initialized to the not-ready
    state and False is returned.

    Fixes two defects in the previous version:
    * it returned ``bool(com_file.readline())``, which is True for ANY
      non-empty content -- including the '0' written by
      ``set_index_com_file_not_ready`` -- so the index looked "ready" as
      soon as the file existed;
    * it called ``os.path.exists`` although the module never imports
      ``os``.  The EAFP ``try/except`` below removes that dependency.
    """
    try:
        with open(INDEX_COM_FILE, 'r') as com_file:
            return com_file.readline().strip() == '1'
    except FileNotFoundError:
        # First call before the indexer created the flag: initialize it.
        set_index_com_file_not_ready()
        return False
def set_model_com_file_not_ready():
    """Lower the model-ready flag by writing '0' into the model com file."""
    with open(MODEL_COM_FILE, mode='w') as flag:
        flag.write('0')
def set_model_com_file_ready():
    """Raise the model-ready flag by writing '1' into the model com file."""
    with open(MODEL_COM_FILE, mode='w') as flag:
        flag.write('1')
def check_model_com_file_ready():
    """Return True iff the model com file exists and contains '1'.

    If the file does not exist yet, it is initialized to the not-ready
    state and False is returned.

    Fixes three defects in the previous version:
    * copy-paste bug: the missing-file branch reset the *index* flag
      (``set_index_com_file_not_ready``) instead of the model flag;
    * ``bool(com_file.readline())`` is True for any non-empty content,
      including '0', so the model looked "ready" whenever the file existed;
    * ``os.path.exists`` was used although the module never imports ``os``.
    """
    try:
        with open(MODEL_COM_FILE, 'r') as com_file:
            return com_file.readline().strip() == '1'
    except FileNotFoundError:
        # First call before the trainer created the flag: initialize the
        # MODEL flag (not the index flag, as the original mistakenly did).
        set_model_com_file_not_ready()
        return False
megatron/arguments.py
View file @
e338e311
...
...
@@ -195,6 +195,7 @@ def _add_training_args(parser):
'by this value.'
)
group
.
add_argument
(
'--tensorboard-dir'
,
type
=
str
,
default
=
None
,
help
=
'Write TensorBoard logs to this directory.'
)
group
.
add_argument
(
'--max-training-rank'
,
type
=
int
,
default
=
None
)
return
parser
...
...
megatron/initialize.py
View file @
e338e311
...
...
@@ -61,8 +61,7 @@ def initialize_megatron(extra_args_provider=None, args_defaults={},
_write_args_to_tensorboard
()
def
_initialize_distributed
():
"""Initialize torch.distributed and mpu."""
def
init_distributed
():
args
=
get_args
()
device_count
=
torch
.
cuda
.
device_count
()
...
...
@@ -102,6 +101,13 @@ def _initialize_distributed():
world_size
=
args
.
world_size
,
rank
=
args
.
rank
,
init_method
=
init_method
)
def
_initialize_distributed
():
"""Initialize torch.distributed and mpu."""
init_distributed
()
args
=
get_args
()
device_count
=
torch
.
cuda
.
device_count
()
# Set the model-parallel / data-parallel communicators.
if
device_count
>
0
:
mpu
.
initialize_model_parallel
(
args
.
model_parallel_size
)
...
...
megatron/mpu/initialize.py
View file @
e338e311
...
...
@@ -96,6 +96,13 @@ def get_model_parallel_group():
return
_MODEL_PARALLEL_GROUP
def set_model_parallel_group(group):
    """Install *group* as the global model-parallel process group.

    May only be called while the group is still unset (i.e. before
    model-parallel initialization has assigned it); asserts otherwise
    so a group is never silently replaced.
    """
    global _MODEL_PARALLEL_GROUP
    assert _MODEL_PARALLEL_GROUP is None, \
        'model parallel group has already been initialized'
    _MODEL_PARALLEL_GROUP = group
def
get_data_parallel_group
():
"""Get the data parallel group the caller rank belongs to."""
assert
_DATA_PARALLEL_GROUP
is
not
None
,
\
...
...
@@ -103,6 +110,13 @@ def get_data_parallel_group():
return
_DATA_PARALLEL_GROUP
def set_data_parallel_group(group):
    """Install *group* as the global data-parallel process group.

    May only be called while the group is still unset (i.e. before
    data-parallel initialization has assigned it); asserts otherwise
    so a group is never silently replaced.
    """
    global _DATA_PARALLEL_GROUP
    assert _DATA_PARALLEL_GROUP is None, \
        'data parallel group has already been initialized'
    _DATA_PARALLEL_GROUP = group
def
set_model_parallel_world_size
(
world_size
):
"""Set the model parallel size"""
global
_MPU_WORLD_SIZE
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment