Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Megatron-LM
Commits
d4b00be0
Commit
d4b00be0
authored
May 19, 2020
by
Neel Kant
Browse files
Reorganize indexer. Things run up to saving model checkpoint and repeating
parent
e338e311
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
196 additions
and
237 deletions
+196
-237
indexer.py
indexer.py
+129
-51
indexer_async.py
indexer_async.py
+0
-170
megatron/data/realm_dataset_utils.py
megatron/data/realm_dataset_utils.py
+2
-2
megatron/mpu/initialize.py
megatron/mpu/initialize.py
+34
-0
megatron/training.py
megatron/training.py
+28
-13
pretrain_realm.py
pretrain_realm.py
+3
-1
No files found.
indexer.py
View file @
d4b00be0
...
@@ -3,6 +3,7 @@ import sys
...
@@ -3,6 +3,7 @@ import sys
import
time
import
time
import
torch
import
torch
import
torch.distributed
as
dist
from
torch.nn.parallel.distributed
import
DistributedDataParallel
as
torchDDP
from
torch.nn.parallel.distributed
import
DistributedDataParallel
as
torchDDP
from
megatron
import
get_args
,
get_adlr_autoresume
,
print_rank_0
from
megatron
import
get_args
,
get_adlr_autoresume
,
print_rank_0
...
@@ -14,58 +15,128 @@ from megatron.data.realm_index import detach, BlockData, FaissMIPSIndex
...
@@ -14,58 +15,128 @@ from megatron.data.realm_index import detach, BlockData, FaissMIPSIndex
from
megatron.data.samplers
import
DistributedBatchSampler
from
megatron.data.samplers
import
DistributedBatchSampler
from
megatron.initialize
import
initialize_megatron
from
megatron.initialize
import
initialize_megatron
from
megatron.model
import
REALMRetriever
from
megatron.model
import
REALMRetriever
from
megatron.global_vars
import
set_global_variables
from
megatron.mpu.initialize
import
get_index_ready
,
get_index_group
,
get_train_group
from
megatron.mpu.initialize
import
set_data_parallel_group
,
set_model_parallel_group
,
init_realm_groups
from
megatron.initialize
import
init_distributed
,
_init_autoresume
,
_set_random_seed
,
_write_args_to_tensorboard
from
megatron.training
import
get_model
from
megatron.training
import
get_model
from
megatron.utils
import
check_adlr_autoresume_termination
from
megatron.utils
import
check_adlr_autoresume_termination
from
pretrain_bert_ict
import
get_batch
,
model_provider
from
pretrain_bert_ict
import
get_batch
,
model_provider
from
indexer_utils
import
set_index_com_file_ready
,
set_model_com_file_not_ready
,
check_model_com_file_ready
def
test_retriever
():
INDEX_READY
=
None
# TODO: Update this because it's outdated and definitely won't run.
initialize_megatron
(
extra_args_provider
=
None
,
args_defaults
=
{
'tokenizer_type'
:
'BertWordPieceLowerCase'
})
def pprint(*args):
    """Print the given values and flush immediately.

    Flushing keeps log lines from multiple distributed ranks from being
    buffered and interleaved out of order.
    """
    message = " ".join(str(value) for value in args)
    print(message, flush=True)
def
initialize_and_run_async_megatron
(
extra_args_provider
=
None
,
args_defaults
=
{},
ignore_unknown_args
=
False
,
allow_no_cuda
=
False
):
if
not
allow_no_cuda
:
# Make sure cuda is available.
assert
torch
.
cuda
.
is_available
(),
'Megatron requires CUDA.'
# Parse args, build tokenizer, and set adlr-autoresume,
# tensorboard-writer, and timers.
set_global_variables
(
extra_args_provider
=
extra_args_provider
,
args_defaults
=
args_defaults
,
ignore_unknown_args
=
ignore_unknown_args
)
# instead of _initialize_distributed()
init_distributed
()
setup_realm_groups_and_vars
()
global
INDEX_READY
INDEX_READY
=
get_index_ready
()
pprint
(
'finished setting up groups'
)
# Autoresume
_init_autoresume
()
pprint
(
'finished setting up autoresume'
)
# Random seeds for reproducibility.
args
=
get_args
()
args
=
get_args
()
model
=
load_ict_checkpoint
()
if
args
.
rank
==
0
:
model
.
eval
(
)
pprint
(
'> setting random seeds to {} ...'
.
format
(
args
.
seed
)
)
dataset
=
get_ict_dataset
(
)
_set_random_seed
(
args
.
seed
)
block_data
=
BlockData
.
load_from_file
(
args
.
block_data_path
)
# Write arguments to tensorboard.
mips_index
=
FaissMIPSIndex
(
'flat_ip'
,
128
)
_write_args_to_tensorboard
()
mips_index
.
add_block_embed_data
(
block_data
)
pprint
(
'finished writing args to tensorboard'
)
retriever
=
REALMRetriever
(
model
,
dataset
,
block_data
,
mips_index
,
top_k
=
5
)
strs
=
[
torch
.
distributed
.
barrier
()
"The last monarch from the house of windsor"
,
"married to Elvis Presley"
,
"tallest building in the world today"
,
"who makes graphics cards"
]
for
s
in
strs
:
if
args
.
rank
<
args
.
max_training_rank
:
retriever
.
retrieve_evidence_blocks_text
(
s
)
torch
.
distributed
.
barrier
(
get_train_group
())
pprint
(
"All trainers ready."
)
return
else
:
runner
=
AsyncIndexBuilder
(
args
.
rank
)
torch
.
distributed
.
barrier
(
get_index_group
())
pprint
(
"All indexers ready."
)
runner
.
run_async
()
def
main
():
def
setup_realm_groups_and_vars
():
initialize_megatron
(
extra_args_provider
=
None
,
args_defaults
=
{
'tokenizer_type'
:
'BertWordPieceLowerCase'
})
args
=
get_args
()
args
=
get_args
()
while
True
:
world_size
=
dist
.
get_world_size
()
max_training_rank
=
args
.
max_training_rank
# assuming no model parallelism right now
set_model_parallel_group
(
dist
.
new_group
([
args
.
rank
]))
init_realm_groups
(
max_training_rank
,
world_size
)
if
args
.
rank
<
max_training_rank
:
set_data_parallel_group
(
get_train_group
())
else
:
set_data_parallel_group
(
get_index_group
())
class
AsyncIndexBuilder
(
object
):
def
__init__
(
self
,
rank
):
self
.
rank
=
rank
args
=
get_args
()
self
.
is_main_builder
=
self
.
rank
==
args
.
max_training_rank
self
.
main_builder_idx
=
args
.
max_training_rank
self
.
debug
=
args
.
debug
self
.
model
=
None
self
.
dataloader
=
None
self
.
block_data
=
None
self
.
load_attributes
()
global
INDEX_READY
INDEX_READY
=
get_index_ready
()
def
run_async
(
self
):
while
True
:
print
(
"Starting (again!)"
)
self
.
build_index
()
self
.
save_index
()
self
.
send_index_ready_signal
()
while
INDEX_READY
==
1
:
print
(
"Waiting for new model checkpoint."
)
time
.
sleep
(
1
)
self
.
load_model
()
def
load_attributes
(
self
):
try
:
try
:
model
=
load_ict_checkpoint
(
only_block_model
=
True
,
no_grad
=
True
,
from_realm_chkpt
=
True
)
self
.
model
=
load_ict_checkpoint
(
only_block_model
=
True
,
no_grad
=
True
,
from_realm_chkpt
=
True
)
except
:
except
:
model
=
load_ict_checkpoint
(
only_block_model
=
True
,
no_grad
=
True
,
from_realm_chkpt
=
False
)
self
.
model
=
load_ict_checkpoint
(
only_block_model
=
True
,
no_grad
=
True
,
from_realm_chkpt
=
False
)
model
.
eval
()
self
.
model
.
eval
()
dataset
=
get_ict_dataset
()
self
.
dataloader
=
iter
(
get_one_epoch_dataloader
(
get_ict_dataset
()))
data_iter
=
iter
(
get_one_epoch_dataloader
(
dataset
))
self
.
block_data
=
BlockData
()
all_block_data
=
BlockData
()
def
build_index
(
self
):
i
=
1
i
=
1
total
=
0
total
=
0
while
True
:
while
True
:
with
torch
.
no_grad
():
with
torch
.
no_grad
():
try
:
try
:
query_tokens
,
query_pad_mask
,
\
query_tokens
,
query_pad_mask
,
\
block_tokens
,
block_pad_mask
,
block_index_data
=
get_batch
(
data_it
er
)
block_tokens
,
block_pad_mask
,
block_index_data
=
get_batch
(
self
.
dataload
er
)
except
:
except
:
break
break
...
@@ -73,30 +144,16 @@ def main():
...
@@ -73,30 +144,16 @@ def main():
block_indices
=
block_index_data
[:,
3
]
block_indices
=
block_index_data
[:,
3
]
block_meta
=
block_index_data
[:,
:
3
]
block_meta
=
block_index_data
[:,
:
3
]
block_logits
=
detach
(
model
(
None
,
None
,
block_tokens
,
block_pad_mask
,
only_block
=
True
))
block_logits
=
detach
(
self
.
model
(
None
,
None
,
block_tokens
,
block_pad_mask
,
only_block
=
True
))
all_
block_data
.
add_block_data
(
block_indices
,
block_logits
,
block_meta
)
self
.
block_data
.
add_block_data
(
block_indices
,
block_logits
,
block_meta
)
total
+=
block_indices
.
size
total
+=
block_indices
.
size
i
+=
1
i
+=
1
if
i
%
2000
==
0
:
if
i
%
2000
==
0
:
print
(
'Batch {:10d} | Total {:10d}'
.
format
(
i
,
total
),
flush
=
True
)
print
(
'Batch {:10d} | Total {:10d}'
.
format
(
i
,
total
),
flush
=
True
)
if
args
.
debug
:
if
self
.
debug
:
break
break
all_block_data
.
save_shard
(
args
.
rank
)
torch
.
distributed
.
barrier
()
del
model
if
args
.
rank
==
0
:
all_block_data
.
consolidate_shards_and_save
()
else
:
all_block_data
.
clear
()
set_index_com_file_ready
()
torch
.
distributed
.
barrier
()
if
args
.
async_indexer
:
while
not
check_model_com_file_ready
():
time
.
sleep
(
5
)
autoresume
=
get_adlr_autoresume
()
autoresume
=
get_adlr_autoresume
()
if
autoresume
.
termination_requested
():
if
autoresume
.
termination_requested
():
print_rank_0
(
">>> autoresume termination request found!"
)
print_rank_0
(
">>> autoresume termination request found!"
)
...
@@ -105,17 +162,36 @@ def main():
...
@@ -105,17 +162,36 @@ def main():
print_rank_0
(
">>> training terminated. Returning"
)
print_rank_0
(
">>> training terminated. Returning"
)
sys
.
exit
(
0
)
sys
.
exit
(
0
)
set_model_com_file_not_ready
()
def
save_index
(
self
):
self
.
block_data
.
save_shard
(
self
.
rank
)
torch
.
distributed
.
barrier
()
del
self
.
model
if
self
.
is_main_builder
:
self
.
block_data
.
consolidate_shards_and_save
(
ignore_shard
=
self
.
rank
)
else
:
self
.
block_data
.
clear
()
def
send_index_ready_signal
(
self
):
global
INDEX_READY
if
self
.
is_main_builder
:
INDEX_READY
=
1
-
INDEX_READY
print
(
"Switched INDEX_READY"
,
flush
=
True
)
send_handle
=
dist
.
broadcast
(
INDEX_READY
,
self
.
main_builder_idx
,
async_op
=
True
)
torch
.
distributed
.
barrier
(
get_index_group
())
recv_handle
=
dist
.
broadcast
(
INDEX_READY
,
0
,
async_op
=
True
)
def
load_ict_checkpoint
(
only_query_model
=
False
,
only_block_model
=
False
,
no_grad
=
False
,
from_realm_chkpt
=
False
):
def
load_ict_checkpoint
(
only_query_model
=
False
,
only_block_model
=
False
,
no_grad
=
False
,
from_realm_chkpt
=
False
):
args
=
get_args
()
args
=
get_args
()
model
=
get_model
(
lambda
:
model_provider
(
only_query_model
,
only_block_model
))
model
=
get_model
(
lambda
:
model_provider
(
only_query_model
,
only_block_model
))
load_path
=
args
.
load
if
from_realm_chkpt
else
args
.
ict_load
if
isinstance
(
model
,
torchDDP
):
if
isinstance
(
model
,
torchDDP
):
model
=
model
.
module
model
=
model
.
module
load_path
=
args
.
load
if
from_realm_chkpt
else
args
.
ict_load
tracker_filename
=
get_checkpoint_tracker_filename
(
load_path
)
tracker_filename
=
get_checkpoint_tracker_filename
(
load_path
)
with
open
(
tracker_filename
,
'r'
)
as
f
:
with
open
(
tracker_filename
,
'r'
)
as
f
:
iteration
=
int
(
f
.
read
().
strip
())
iteration
=
int
(
f
.
read
().
strip
())
...
@@ -174,7 +250,9 @@ def get_one_epoch_dataloader(dataset):
...
@@ -174,7 +250,9 @@ def get_one_epoch_dataloader(dataset):
args
=
get_args
()
args
=
get_args
()
world_size
=
mpu
.
get_data_parallel_world_size
()
world_size
=
mpu
.
get_data_parallel_world_size
()
print
(
world_size
,
flush
=
True
)
rank
=
mpu
.
get_data_parallel_rank
()
rank
=
mpu
.
get_data_parallel_rank
()
print
(
rank
,
flush
=
True
)
global_batch_size
=
args
.
batch_size
*
world_size
global_batch_size
=
args
.
batch_size
*
world_size
num_workers
=
args
.
num_workers
num_workers
=
args
.
num_workers
...
...
indexer_async.py
deleted
100644 → 0
View file @
e338e311
import
os
import
time
import
torch
import
torch.distributed
as
dist
from
megatron
import
get_args
from
megatron.global_vars
import
set_global_variables
from
megatron.initialize
import
init_distributed
,
_init_autoresume
,
_set_random_seed
,
_write_args_to_tensorboard
from
megatron.mpu.initialize
import
set_data_parallel_group
,
set_model_parallel_group
# Example: 4x8 for training, 1x8 for indexing.
# Assign args.rank < 32 to TRAIN_PROCESS_GROUP, args.rank >= to INDEX_PROCESS_GROUP
# can manually assign _MODEL_PARALLEL_GROUP to args.rank, _DATA_PARALLEL_GROUP to train or index process group
# for both, create a torchDDP accordingly because you need to set up the model to be data-parallel on each.

# Shared handshake flag between trainers and index builders.  Set to a CUDA
# tensor in initialize_and_run_async_megatron so it can be dist.broadcast().
INDEX_READY = None
# Process group of the training ranks (set in setup_groups).
TRAIN_GROUP = None
# Process group of the index-building ranks (set in setup_groups).
INDEX_GROUP = None

# flow:
# index builder finishes first and sets INDEX_READY = 1.
# communicates by dist.broadcast(INDEX_READY, src=min_index_rank)
# index builder is now waiting for INDEX_READY = 0.
#
# at every iteration, trainer checks INDEX_READY = 1.
# when INDEX_READY = 1, reload the index, save model checkpoint and set INDEX_READY = 0.
# once done, trainer does dist.broadcast(INDEX_READY, src=min_train_rank)
# when INDEX_READY = 0, indexer loads up model checkpoint and begins again.
def pprint(*args):
    """Print the given values, flushing stdout right away.

    Unbuffered output keeps multi-rank logs readable in real time.
    """
    text = " ".join(str(item) for item in args)
    print(text, flush=True)
def initialize_and_run_async_megatron(extra_args_provider=None, args_defaults={},
                                      ignore_unknown_args=False, allow_no_cuda=False):
    """Initialize Megatron for asynchronous REALM work, then run this rank.

    Ranks with rank < args.max_training_rank become trainers; the rest
    become index builders.  Each side synchronizes on its own process
    group and then enters its dummy driver loop.

    NOTE(review): args_defaults={} is a mutable default argument; it is
    only forwarded to set_global_variables here, but confirm it is never
    mutated downstream.
    """
    if not allow_no_cuda:
        # Make sure cuda is available.
        assert torch.cuda.is_available(), 'Megatron requires CUDA.'

    # Parse args, build tokenizer, and set adlr-autoresume,
    # tensorboard-writer, and timers.
    set_global_variables(extra_args_provider=extra_args_provider,
                         args_defaults=args_defaults,
                         ignore_unknown_args=ignore_unknown_args)

    # instead of _initialize_distributed()
    init_distributed()
    # Creates TRAIN_GROUP / INDEX_GROUP and the data/model parallel groups.
    setup_groups()
    pprint('finished setting up groups')

    # Autoresume
    _init_autoresume()
    pprint('finished setting up autoresume')

    # Random seeds for reproducibility.
    args = get_args()
    if args.rank == 0:
        pprint('> setting random seeds to {} ...'.format(args.seed))
    # _set_random_seed(args.seed)

    # Write arguments to tensorboard.
    _write_args_to_tensorboard()
    pprint('finished writing args to tensorboard')

    # Full-world sync before splitting into trainer/indexer roles.
    torch.distributed.barrier()

    global INDEX_READY
    # CUDA tensor (not a Python int) so it can be moved by dist.broadcast.
    INDEX_READY = torch.zeros(1).cuda()

    if args.rank < args.max_training_rank:
        # Trainer side: sync with the other trainers, then train.
        runner = AsyncREALMTrainer(args.rank)
        torch.distributed.barrier(TRAIN_GROUP)
        pprint("All trainers ready.")
        runner.dummy_train_model()
    else:
        # Indexer side: sync with the other indexers, then build.
        runner = AsyncIndexBuilder(args.rank)
        torch.distributed.barrier(INDEX_GROUP)
        pprint("All indexers ready.")
        runner.dummy_build_index()
def setup_groups():
    """Partition the world into a training group and an indexing group.

    Ranks [0, max_training_rank) are trainers; ranks
    [max_training_rank, world_size) are index builders.  Each side's
    data-parallel group is its own group.
    """
    args = get_args()
    world_size = dist.get_world_size()
    max_training_rank = args.max_training_rank

    # assuming no model parallelism right now: each rank is its own
    # model-parallel group.  BUG FIX: set_model_parallel_group stores a
    # process group, so wrap the rank in a one-member group instead of
    # passing the bare int (the reorganized indexer.py does the same).
    set_model_parallel_group(dist.new_group([args.rank]))

    global TRAIN_GROUP
    global INDEX_GROUP
    # important for batching and whatnot
    TRAIN_GROUP = dist.new_group(list(range(max_training_rank)))
    INDEX_GROUP = dist.new_group(list(range(max_training_rank, world_size)))

    # BUG FIX: rank == max_training_rank is the *first indexer* (trainers
    # are rank < max_training_rank everywhere else in this file), so the
    # comparison must be >=, not >.
    if args.rank >= max_training_rank:
        set_data_parallel_group(INDEX_GROUP)
    else:
        set_data_parallel_group(TRAIN_GROUP)
class AsyncIndexBuilder(object):
    """Dummy index-builder side of the async REALM handshake.

    Runs on ranks >= args.max_training_rank.  Each round it "builds" an
    index (simulated with a sleep), flips INDEX_READY, broadcasts it to
    everyone, then waits until the trainers flip it back to 0.

    NOTE(review): structure reconstructed from a diff page -- the for-loop
    is assumed to wrap the whole build/signal/wait cycle; confirm the
    original indentation upstream.
    """

    def __init__(self, rank):
        # Global (world) rank of this process.
        self.rank = rank
        pprint("My rank: ", self.rank)

    def dummy_build_index(self):
        """Run five simulated build -> broadcast -> wait rounds."""
        start_time = time.time()
        pprint("START: {}".format(time.ctime(start_time)))
        pprint("-" * 100)
        for i in range(5):
            # simulating building the index which takes 20 seconds
            time.sleep(10)
            pprint('built the index. Time: {}'.format(time.ctime(time.time())))
            args = get_args()
            global INDEX_READY
            if self.rank == args.max_training_rank:
                # broadcasting that the index is ready
                # (only the main builder flips the shared flag)
                INDEX_READY = 1 - INDEX_READY
                send_handle = dist.broadcast(INDEX_READY,
                                             args.max_training_rank,
                                             async_op=True)
                pprint("Broadcasted index ready = ", INDEX_READY)
            else:
                # Other builders join the same broadcast as receivers.
                send_recv_handle = dist.broadcast(INDEX_READY,
                                                  args.max_training_rank,
                                                  async_op=True)
            torch.distributed.barrier(INDEX_GROUP)
            pprint("Synced after broadcasting")
            # Async receive of the trainers' "checkpoint saved" flip
            # (broadcast from rank 0, per the flow comment at module top).
            recv_handle = dist.broadcast(INDEX_READY, 0, async_op=True)
            while INDEX_READY == 1:
                pprint('waiting for new model. Time: {}'.format(time.ctime(time.time())))
                time.sleep(1)
class AsyncREALMTrainer(object):
    """Dummy trainer side of the async REALM handshake.

    Runs on ranks < args.max_training_rank.  Each round it waits for the
    index builders to flip INDEX_READY to 1, then (on rank 0) flips it
    back to 0 and broadcasts, simulating "saved a checkpoint".

    NOTE(review): structure reconstructed from a diff page -- confirm the
    original indentation upstream.
    """

    def __init__(self, rank):
        # Global (world) rank of this process.
        self.rank = rank
        pprint("My rank: ", self.rank)

    def dummy_train_model(self):
        """Run five simulated wait -> acknowledge rounds."""
        start_time = time.time()
        pprint("START: {}".format(time.ctime(start_time)))
        pprint("-" * 100)
        args = get_args()
        for i in range(5):
            global INDEX_READY
            # Async receive of the builders' "index ready" flip.
            recv_handle = dist.broadcast(INDEX_READY,
                                         args.max_training_rank,
                                         async_op=True)
            while True:
                if INDEX_READY == 1:
                    break
                # The main builder rank never runs the trainer loop.
                assert self.rank != args.max_training_rank
                pprint('waiting for new index. Time: {}'.format(time.ctime(time.time())))
                time.sleep(2)
            # INDEX_READY is 1
            if self.rank == 0:
                # Rank 0 acknowledges by flipping the flag back to 0.
                INDEX_READY = 1 - INDEX_READY
                send_handle = dist.broadcast(INDEX_READY, 0, async_op=True)
                pprint("Broadcasted index ready = ", INDEX_READY)
            else:
                # Other trainers join the same broadcast as receivers.
                send_recv_handle = dist.broadcast(INDEX_READY, 0, async_op=True)
            torch.distributed.barrier(TRAIN_GROUP)
            pprint("Synced after broadcasting")
if __name__ == "__main__":
    # Entry point: run the async trainer/indexer demo with the lowercase
    # BERT wordpiece tokenizer as the default.
    initialize_and_run_async_megatron(args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
megatron/data/realm_dataset_utils.py
View file @
d4b00be0
...
@@ -187,8 +187,8 @@ def get_block_samples_mapping(block_dataset, title_dataset, data_prefix, num_epo
...
@@ -187,8 +187,8 @@ def get_block_samples_mapping(block_dataset, title_dataset, data_prefix, num_epo
# parallel case
# parallel case
counts
=
torch
.
cuda
.
LongTensor
([
1
])
counts
=
torch
.
cuda
.
LongTensor
([
1
])
torch
.
distributed
.
all_reduce
(
counts
,
group
=
mpu
.
get_data_parallel_group
())
torch
.
distributed
.
all_reduce
(
counts
,
group
=
mpu
.
get_data_parallel_group
())
assert
counts
[
0
].
item
()
==
torch
.
distributed
.
get_world_size
(
#
assert counts[0].item() == torch.distributed.get_world_size(
group
=
mpu
.
get_data_parallel_group
())
#
group=mpu.get_data_parallel_group())
# Load indexed dataset.
# Load indexed dataset.
print_rank_0
(
' > loading indexed mapping from {}'
.
format
(
print_rank_0
(
' > loading indexed mapping from {}'
.
format
(
...
...
megatron/mpu/initialize.py
View file @
d4b00be0
...
@@ -26,6 +26,10 @@ _MODEL_PARALLEL_GROUP = None
...
@@ -26,6 +26,10 @@ _MODEL_PARALLEL_GROUP = None
# Data parallel group that the current rank belongs to.
# Data parallel group that the current rank belongs to.
_DATA_PARALLEL_GROUP
=
None
_DATA_PARALLEL_GROUP
=
None
_TRAIN_GROUP
=
None
_INDEX_GROUP
=
None
_INDEX_READY
=
None
# These values enable us to change the mpu sizes on the fly.
# These values enable us to change the mpu sizes on the fly.
_MPU_WORLD_SIZE
=
None
_MPU_WORLD_SIZE
=
None
_MPU_RANK
=
None
_MPU_RANK
=
None
...
@@ -105,8 +109,10 @@ def set_model_parallel_group(group):
...
@@ -105,8 +109,10 @@ def set_model_parallel_group(group):
def
get_data_parallel_group
():
def
get_data_parallel_group
():
"""Get the data parallel group the caller rank belongs to."""
"""Get the data parallel group the caller rank belongs to."""
#print(">>> yeah this function works.")
assert
_DATA_PARALLEL_GROUP
is
not
None
,
\
assert
_DATA_PARALLEL_GROUP
is
not
None
,
\
'data parallel group is not initialized'
'data parallel group is not initialized'
#print(_DATA_PARALLEL_GROUP)
return
_DATA_PARALLEL_GROUP
return
_DATA_PARALLEL_GROUP
...
@@ -114,6 +120,7 @@ def set_data_parallel_group(group):
...
@@ -114,6 +120,7 @@ def set_data_parallel_group(group):
global
_DATA_PARALLEL_GROUP
global
_DATA_PARALLEL_GROUP
assert
_DATA_PARALLEL_GROUP
is
None
,
\
assert
_DATA_PARALLEL_GROUP
is
None
,
\
'data parallel group has already been initialized'
'data parallel group has already been initialized'
print
(
">>> setting data parallel group: "
,
group
,
flush
=
True
)
_DATA_PARALLEL_GROUP
=
group
_DATA_PARALLEL_GROUP
=
group
...
@@ -169,3 +176,30 @@ def destroy_model_parallel():
...
@@ -169,3 +176,30 @@ def destroy_model_parallel():
_MODEL_PARALLEL_GROUP
=
None
_MODEL_PARALLEL_GROUP
=
None
global
_DATA_PARALLEL_GROUP
global
_DATA_PARALLEL_GROUP
_DATA_PARALLEL_GROUP
=
None
_DATA_PARALLEL_GROUP
=
None
def init_realm_groups(max_training_rank, world_size):
    """Create the REALM trainer/indexer process groups and the shared flag.

    Ranks [0, max_training_rank) form the train group; ranks
    [max_training_rank, world_size) form the index group.  Must be called
    by every rank, since torch.distributed.new_group is collective.
    """
    global _TRAIN_GROUP
    _TRAIN_GROUP = torch.distributed.new_group(list(range(max_training_rank)))
    global _INDEX_GROUP
    _INDEX_GROUP = torch.distributed.new_group(list(range(max_training_rank, world_size)))
    global _INDEX_READY
    # CUDA tensor (not a Python int) so it can be moved by dist.broadcast.
    _INDEX_READY = torch.zeros(1).cuda()
def get_train_group():
    """Return the process group of REALM trainer ranks.

    Raises:
        AssertionError: if init_realm_groups() has not been called yet.
    """
    # Consistency fix: assert with a message, matching the style of
    # get_data_parallel_group in this module.  (A bare `global` statement
    # is unnecessary for read-only access and was removed.)
    assert _TRAIN_GROUP is not None, \
        'train group is not initialized'
    return _TRAIN_GROUP
def get_index_group():
    """Return the process group of REALM index-builder ranks.

    Raises:
        AssertionError: if init_realm_groups() has not been called yet.
    """
    # Consistency fix: assert with a message, matching the style of
    # get_data_parallel_group in this module.  (A bare `global` statement
    # is unnecessary for read-only access and was removed.)
    assert _INDEX_GROUP is not None, \
        'index group is not initialized'
    return _INDEX_GROUP
def get_index_ready():
    """Return the shared INDEX_READY flag tensor.

    Raises:
        AssertionError: if init_realm_groups() has not been called yet.
    """
    # Consistency fix: assert with a message, matching the style of
    # get_data_parallel_group in this module.  (A bare `global` statement
    # is unnecessary for read-only access and was removed.)
    assert _INDEX_READY is not None, \
        'index ready flag is not initialized'
    return _INDEX_READY
megatron/training.py
View file @
d4b00be0
...
@@ -36,14 +36,18 @@ from megatron.initialize import initialize_megatron
...
@@ -36,14 +36,18 @@ from megatron.initialize import initialize_megatron
from
megatron.learning_rates
import
AnnealingLR
from
megatron.learning_rates
import
AnnealingLR
from
megatron.model
import
DistributedDataParallel
as
LocalDDP
from
megatron.model
import
DistributedDataParallel
as
LocalDDP
from
megatron.model
import
get_params_for_weight_decay_optimization
from
megatron.model
import
get_params_for_weight_decay_optimization
from
megatron.mpu.initialize
import
get_index_ready
,
get_train_group
from
megatron.utils
import
check_adlr_autoresume_termination
from
megatron.utils
import
check_adlr_autoresume_termination
from
megatron.utils
import
make_data_loader
from
megatron.utils
import
make_data_loader
from
megatron.utils
import
report_memory
from
megatron.utils
import
report_memory
from
indexer_utils
import
check_index_com_file_ready
,
set_index_com_file_not_ready
,
set_model_com_file_ready
INDEX_READY
=
None
def
pretrain
(
train_valid_test_dataset_provider
,
model_provider
,
def
pretrain
(
train_valid_test_dataset_provider
,
model_provider
,
forward_step_func
,
extra_args_provider
=
None
,
args_defaults
=
{}):
forward_step_func
,
extra_args_provider
=
None
,
args_defaults
=
{},
initializer_func
=
None
):
"""Main training program.
"""Main training program.
This function will run the followings in the order provided:
This function will run the followings in the order provided:
...
@@ -69,8 +73,15 @@ def pretrain(train_valid_test_dataset_provider, model_provider,
...
@@ -69,8 +73,15 @@ def pretrain(train_valid_test_dataset_provider, model_provider,
"""
"""
# Initalize and get arguments, timers, and Tensorboard writer.
# Initalize and get arguments, timers, and Tensorboard writer.
initialize_megatron
(
extra_args_provider
=
extra_args_provider
,
if
initializer_func
is
None
:
args_defaults
=
args_defaults
)
initialize_megatron
(
extra_args_provider
=
extra_args_provider
,
args_defaults
=
args_defaults
)
else
:
initializer_func
(
extra_args_provider
=
extra_args_provider
,
args_defaults
=
args_defaults
)
global
INDEX_READY
INDEX_READY
=
get_index_ready
()
args
=
get_args
()
args
=
get_args
()
timers
=
get_timers
()
timers
=
get_timers
()
...
@@ -250,7 +261,6 @@ def backward_step(optimizer, model, loss):
...
@@ -250,7 +261,6 @@ def backward_step(optimizer, model, loss):
else
:
else
:
optimizer
.
clip_master_grads
(
args
.
clip_grad
)
optimizer
.
clip_master_grads
(
args
.
clip_grad
)
ran_backward_once
=
False
def
train_step
(
forward_step_func
,
data_iterator
,
def
train_step
(
forward_step_func
,
data_iterator
,
model
,
optimizer
,
lr_scheduler
):
model
,
optimizer
,
lr_scheduler
):
...
@@ -363,15 +373,20 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
...
@@ -363,15 +373,20 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
timers
(
'interval time'
).
start
()
timers
(
'interval time'
).
start
()
report_memory_flag
=
True
report_memory_flag
=
True
global
INDEX_READY
recv_handle
=
torch
.
distributed
.
broadcast
(
INDEX_READY
,
args
.
max_training_rank
,
async_op
=
True
)
while
iteration
<
args
.
train_iters
:
while
iteration
<
args
.
train_iters
:
if
hasattr
(
model
,
'retriever'
):
if
hasattr
(
model
,
'retriever'
)
and
INDEX_READY
==
1
:
new_index_ready
=
check_index_com_file_ready
()
model
.
retriever
.
reload_index
()
if
new_index_ready
:
save_checkpoint
(
iteration
,
model
,
optimizer
,
lr_scheduler
)
torch
.
distributed
.
barrier
()
model
.
retriever
.
reload_index
()
if
args
.
rank
==
0
:
set_index_com_file_not_ready
()
INDEX_READY
=
1
-
INDEX_READY
save_checkpoint
(
iteration
,
model
,
optimizer
,
lr_scheduler
)
print
(
"Switched index ready"
,
flush
=
True
)
set_model_com_file_ready
()
send_handle
=
torch
.
distributed
.
broadcast
(
INDEX_READY
,
0
,
async_op
=
True
)
torch
.
distributed
.
barrier
(
get_train_group
())
recv_handle
=
torch
.
distributed
.
broadcast
(
INDEX_READY
,
args
.
max_training_rank
,
async_op
=
True
)
loss_dict
,
skipped_iter
=
train_step
(
forward_step_func
,
loss_dict
,
skipped_iter
=
train_step
(
forward_step_func
,
train_data_iterator
,
train_data_iterator
,
...
...
pretrain_realm.py
View file @
d4b00be0
...
@@ -27,6 +27,7 @@ from megatron.data.dataset_utils import build_train_valid_test_datasets
...
@@ -27,6 +27,7 @@ from megatron.data.dataset_utils import build_train_valid_test_datasets
from
megatron.model
import
REALMBertModel
,
REALMRetriever
from
megatron.model
import
REALMBertModel
,
REALMRetriever
from
megatron.training
import
pretrain
from
megatron.training
import
pretrain
from
megatron.utils
import
reduce_losses
from
megatron.utils
import
reduce_losses
from
indexer
import
initialize_and_run_async_megatron
num_batches
=
0
num_batches
=
0
...
@@ -177,4 +178,5 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
...
@@ -177,4 +178,5 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
pretrain
(
train_valid_test_datasets_provider
,
model_provider
,
forward_step
,
pretrain
(
train_valid_test_datasets_provider
,
model_provider
,
forward_step
,
args_defaults
=
{
'tokenizer_type'
:
'BertWordPieceLowerCase'
})
args_defaults
=
{
'tokenizer_type'
:
'BertWordPieceLowerCase'
},
initializer_func
=
initialize_and_run_async_megatron
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment