OpenDAS / Megatron-LM · Commits

Commit a670b6c9
Authored May 19, 2020 by Neel Kant
Parent: 5684f904

    Async works for total 8 GPU, indexer debug mode

Showing 9 changed files with 35 additions and 45 deletions (+35 -45)
indexer.py                       +12  -13
megatron/checkpointing.py         +4   -4
megatron/global_vars.py           +2   -2
megatron/model/distributed.py     +1   -1
megatron/model/realm_model.py     +3   -2
megatron/mpu/initialize.py        +0   -2
megatron/training.py              +9  -18
megatron/utils.py                 +3   -2
pretrain_realm.py                 +1   -1
indexer.py  (+12 -13)

@@ -16,7 +16,7 @@ from megatron.data.samplers import DistributedBatchSampler
 from megatron.initialize import initialize_megatron
 from megatron.model import REALMRetriever
 from megatron.global_vars import set_global_variables
-from megatron.mpu.initialize import get_index_ready, get_index_group, get_train_group
+from megatron.mpu.initialize import get_index_ready, get_index_group, get_train_group, get_data_parallel_group
 from megatron.mpu.initialize import set_data_parallel_group, set_model_parallel_group, init_realm_groups
 from megatron.initialize import init_distributed, _init_autoresume, _set_random_seed, _write_args_to_tensorboard
 from megatron.training import get_model

@@ -67,12 +67,12 @@ def initialize_and_run_async_megatron(extra_args_provider=None, args_defaults={}
     torch.distributed.barrier()
     if args.rank < args.max_training_rank:
-        torch.distributed.barrier(get_train_group())
+        torch.distributed.barrier(get_data_parallel_group())
         pprint("All trainers ready.")
         return
     else:
         runner = AsyncIndexBuilder(args.rank)
-        torch.distributed.barrier(get_index_group())
+        torch.distributed.barrier(get_data_parallel_group())
         pprint("All indexers ready.")
         runner.run_async()

@@ -123,6 +123,7 @@ class AsyncIndexBuilder(object):
         try:
             self.model = load_ict_checkpoint(only_block_model=True, no_grad=True, from_realm_chkpt=True)
         except:
+            print(">>>>> No realm chkpt available", flush=True)
             self.model = load_ict_checkpoint(only_block_model=True, no_grad=True, from_realm_chkpt=False)
         self.model.eval()
         self.dataloader = iter(get_one_epoch_dataloader(get_ict_dataset()))

@@ -148,7 +149,7 @@
             total += block_indices.size
             i += 1
-            if i % 10 == 0:
+            if i % 500 == 0:
                 print('Batch {:10d} | Total {:10d}'.format(i, total), flush=True)
             if self.debug:
                 break

@@ -162,7 +163,7 @@
             sys.exit(0)
         self.block_data.save_shard(self.rank)
-        torch.distributed.barrier()
+        torch.distributed.barrier(get_data_parallel_group())
         del self.model
         if self.is_main_builder:

@@ -174,12 +175,11 @@
         if self.is_main_builder:
             INDEX_READY = 1 - INDEX_READY
             print("Switched INDEX_READY", flush=True)
             import time
             print(time.ctime(time.time()), flush=True)
             torch.cuda.synchronize()
             send_handle = dist.broadcast(INDEX_READY, self.main_builder_idx, async_op=True)
-        torch.distributed.barrier(get_index_group())
-        recv_handle = dist.broadcast(INDEX_READY, 0, async_op=True)
+        torch.distributed.barrier(get_data_parallel_group())
+        recv_handle = dist.broadcast(INDEX_READY, 0)

 class BasicIndexBuilder(object):

@@ -236,7 +236,7 @@ def load_ict_checkpoint(only_query_model=False, only_block_model=False, no_grad=
     with open(tracker_filename, 'r') as f:
         iteration = int(f.read().strip())
-    assert iteration > 0
+    # assert iteration > 0
     checkpoint_name = get_checkpoint_name(load_path, iteration, False)
     if mpu.get_data_parallel_rank() == 0:
         print('global rank {} is loading checkpoint {}'.format(

@@ -245,6 +245,7 @@ def load_ict_checkpoint(only_query_model=False, only_block_model=False, no_grad=
     state_dict = torch.load(checkpoint_name, map_location='cpu')
     ict_state_dict = state_dict['model']
     if from_realm_chkpt:
+        print(">>>> Attempting to get ict state dict from realm", flush=True)
        ict_state_dict = ict_state_dict['retriever']['ict_model']
     if only_query_model:

@@ -256,7 +257,7 @@ def load_ict_checkpoint(only_query_model=False, only_block_model=False, no_grad=
         model.load_state_dict(ict_state_dict)
     else:
         model.load_state_dict(ict_state_dict)
-    torch.distributed.barrier()
+    torch.distributed.barrier(get_data_parallel_group())
     if mpu.get_data_parallel_rank() == 0:
         print(' successfully loaded {}'.format(checkpoint_name))

@@ -290,9 +291,7 @@ def get_one_epoch_dataloader(dataset):
     args = get_args()
     world_size = mpu.get_data_parallel_world_size()
-    print(world_size, flush=True)
     rank = mpu.get_data_parallel_rank()
-    print(rank, flush=True)
     global_batch_size = args.batch_size * world_size
     num_workers = args.num_workers
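Note: the recurring edit in this file (and in megatron/checkpointing.py and megatron/training.py below) is to pass a process group to torch.distributed.barrier, so that only the ranks in that group block. The sketch below is an illustration of that primitive, not code from this commit; the 50/50 rank split and the worker/make-group layout are assumptions.

# Illustration only -- not code from this commit. torch.distributed.barrier(group)
# synchronizes just the ranks in `group`, so trainer ranks and index-builder
# ranks can each rendezvous among themselves.
import os
import torch.distributed as dist
import torch.multiprocessing as mp

def worker(rank, world_size):
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    max_training_rank = world_size // 2
    # Every rank must create every group, even groups it does not belong to.
    train_group = dist.new_group(list(range(max_training_rank)))
    index_group = dist.new_group(list(range(max_training_rank, world_size)))

    if rank < max_training_rank:
        dist.barrier(train_group)          # waits only for trainer ranks
        print(f"rank {rank}: all trainers ready", flush=True)
    else:
        dist.barrier(index_group)          # waits only for indexer ranks
        print(f"rank {rank}: all indexers ready", flush=True)

    dist.barrier()                         # whole-world barrier, for contrast
    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(worker, args=(4,), nprocs=4)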
megatron/checkpointing.py  (+4 -4)

@@ -24,7 +24,7 @@ import torch
 from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
 from megatron import mpu
-from megatron.mpu.initialize import get_train_group
+from megatron.mpu.initialize import get_train_group, get_data_parallel_group
 from megatron import get_args
 from megatron import print_rank_0

@@ -119,14 +119,14 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
             print(' successfully saved {}'.format(checkpoint_name))
     # Wait so everyone is done (necessary)
-    torch.distributed.barrier(get_train_group())
+    torch.distributed.barrier(get_data_parallel_group())
     # And update the latest iteration
     if torch.distributed.get_rank() == 0:
         tracker_filename = get_checkpoint_tracker_filename(args.save)
         with open(tracker_filename, 'w') as f:
             f.write(str(iteration))
     # Wait so everyone is done (not necessary)
-    torch.distributed.barrier(get_train_group())
+    torch.distributed.barrier(get_data_parallel_group())

 def load_checkpoint(model, optimizer, lr_scheduler):

@@ -243,7 +243,7 @@ def load_checkpoint(model, optimizer, lr_scheduler):
                   'exiting ...'.format(checkpoint_name))
         sys.exit()
-    torch.distributed.barrier()
+    # torch.distributed.barrier()
     if mpu.get_data_parallel_rank() == 0:
         print(' successfully loaded {}'.format(checkpoint_name))
megatron/global_vars.py  (+2 -2)

@@ -164,14 +164,14 @@ class _Timer:
     def start(self):
         """Start the timer."""
         assert not self.started_, 'timer has already been started'
-        torch.cuda.synchronize()
+        # torch.cuda.synchronize()
         self.start_time = time.time()
         self.started_ = True

     def stop(self):
         """Stop the timer."""
         assert self.started_, 'timer is not started'
-        torch.cuda.synchronize()
+        # torch.cuda.synchronize()
         self.elapsed_ += (time.time() - self.start_time)
         self.started_ = False
megatron/model/distributed.py  (+1 -1)

@@ -56,7 +56,7 @@ class DistributedDataParallel(MegatronModule):
                 if not no_scale and not reduce_after:
                     coalesced /= dist.get_world_size(group=self.data_parallel_group)
                 dist.all_reduce(coalesced, group=self.data_parallel_group)
-                torch.cuda.synchronize()
+                # torch.cuda.synchronize()
                 if not no_scale and reduce_after:
                     coalesced /= dist.get_world_size(group=self.data_parallel_group)
                 for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
megatron/model/realm_model.py  (+3 -2)

@@ -199,7 +199,9 @@ class REALMRetriever(MegatronModule):
             true_model = true_model.module
         else:
             true_model = self.ict_model
-        query_embeds = detach(true_model.embed_query(query_tokens, query_pad_mask))
+        # print("true model: ", true_model, flush=True)
+        query_embeds = detach(self.ict_model(query_tokens, query_pad_mask, None, None, only_query=True))
         _, block_indices = self.hashed_index.search_mips_index(query_embeds, top_k=self.top_k, reconstruct=False)
         all_topk_tokens, all_topk_pad_masks = [], []

@@ -268,7 +270,6 @@ class ICTBertModel(MegatronModule):
     def forward(self, query_tokens, query_attention_mask, block_tokens, block_attention_mask, only_query=False, only_block=False):
         """Run a forward pass for each of the models and compute the similarity scores."""
         if only_query:
             return self.embed_query(query_tokens, query_attention_mask)
megatron/mpu/initialize.py  (+0 -2)

@@ -109,10 +109,8 @@ def set_model_parallel_group(group):
 def get_data_parallel_group():
     """Get the data parallel group the caller rank belongs to."""
-    #print(">>> yeah this function works.")
     assert _DATA_PARALLEL_GROUP is not None, \
         'data parallel group is not initialized'
-    #print(_DATA_PARALLEL_GROUP)
     return _DATA_PARALLEL_GROUP
megatron/training.py  (+9 -18)

@@ -36,7 +36,7 @@ from megatron.initialize import initialize_megatron
 from megatron.learning_rates import AnnealingLR
 from megatron.model import DistributedDataParallel as LocalDDP
 from megatron.model import get_params_for_weight_decay_optimization
-from megatron.mpu.initialize import get_index_ready, get_train_group
+from megatron.mpu.initialize import get_index_ready, get_train_group, get_data_parallel_group
 from megatron.utils import check_adlr_autoresume_termination
 from megatron.utils import make_data_loader
 from megatron.utils import report_memory

@@ -236,7 +236,7 @@ def backward_step(optimizer, model, loss):
     """Backward step."""
     args = get_args()
     timers = get_timers()
-    torch.cuda.synchronize()
+    # torch.cuda.synchronize()
     # Backward pass.
     optimizer.zero_grad(set_grads_to_None=True)

@@ -373,19 +373,10 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
     timers('interval time').start()
     report_memory_flag = True
-    import time
-    print(">>> going to sleep", flush=True)
-    time.sleep(10)
-    print(">>> woke from sleep", flush=True)
-    print(time.ctime(time.time()), flush=True)
     global INDEX_READY
     recv_handle = torch.distributed.broadcast(INDEX_READY, args.max_training_rank, async_op=True)
-    print(">>>>>>>> Created recv handle", flush=True)
     while iteration < args.train_iters:
-        print("INDEX READY: ", INDEX_READY)
         if args.max_training_rank is not None and INDEX_READY == 1:
-            print(">>>>>>> entering the good stuff", flush=True)
             true_model = model
             if hasattr(true_model, 'module'):
                 true_model = true_model.module

@@ -393,24 +384,24 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
                 true_model = true_model.module
             print(">>>>>>> starting to reload index", flush=True)
             true_model.retriever.reload_index()
             print(">>>>>>> starting to save checkpoint", flush=True)
             save_checkpoint(iteration, model, optimizer, lr_scheduler)
             print(">>>>>>> saved checkpoint", flush=True)
             if args.rank == 0:
                 INDEX_READY = 1 - INDEX_READY
-                print("Switched index ready", flush=True)
-                send_handle = torch.distributed.broadcast(INDEX_READY, 0, async_op=True)
-            torch.distributed.barrier(get_train_group())
+                print(">>> Switched index ready", flush=True)
+                torch.cuda.synchronize()
+                send_handle = torch.distributed.broadcast(INDEX_READY, 0)
+            torch.distributed.barrier(get_data_parallel_group())
             recv_handle = torch.distributed.broadcast(INDEX_READY, args.max_training_rank, async_op=True)
         else:
-            print(" >>>>>>> moving right along", flush=True)
+            print("moving right along", flush=True)
         loss_dict, skipped_iter = train_step(forward_step_func, train_data_iterator, model, optimizer, lr_scheduler)
         skipped_iters += skipped_iter
         iteration += 1

@@ -443,7 +434,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
                                              iteration, False)
         if args.exit_interval and iteration % args.exit_interval == 0:
-            torch.distributed.barrier()
+            torch.distributed.barrier(get_data_parallel_group())
            time_str = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
            rank = torch.distributed.get_rank()
            print_rank_0('rank: {} | time: {} | exiting the program at '
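Note: the train() loop above is one half of a two-sided handshake. The index builders toggle a one-element INDEX_READY flag and broadcast it from their root rank; the trainers keep a non-blocking broadcast posted so they notice the toggle mid-training, then reload the index, checkpoint, toggle the flag back, and broadcast the answer from rank 0. The sketch below is a simplified rendering of that pattern, not the repo's code; the function names, the tensor-based flag, and the .item()/.is_completed() checks are assumptions.

# Simplified rendering of the INDEX_READY handshake pattern -- assumptions,
# not the actual Megatron-LM/REALM code. `index_ready` is a one-element int
# tensor; `indexer_root` is the first index-builder rank, rank 0 is the
# trainer root. Assumes torch.distributed is already initialized on all ranks.
import torch
import torch.distributed as dist

def indexer_cycle(index_ready, indexer_root, trainer_root=0):
    """Called by every index-builder rank after an index rebuild is saved."""
    index_ready.fill_(1 - int(index_ready.item()))          # toggle the flag
    dist.broadcast(index_ready, src=indexer_root)           # signal the trainers
    dist.broadcast(index_ready, src=trainer_root)           # wait for their reply

def trainer_loop(index_ready, indexer_root, train_step, train_iters):
    """Called by every training rank; keeps training while the flag is 0."""
    recv = dist.broadcast(index_ready, src=indexer_root, async_op=True)
    for iteration in range(train_iters):
        if recv.is_completed() and int(index_ready.item()) == 1:
            # ... reload the MIPS index and save a checkpoint here ...
            index_ready.fill_(0)                             # toggle back
            dist.broadcast(index_ready, src=0)               # answer the indexers
            recv = dist.broadcast(index_ready, src=indexer_root, async_op=True)
        train_step(iteration)                                # one regular step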
megatron/utils.py  (+3 -2)

@@ -25,6 +25,7 @@ from megatron import mpu
 from megatron import print_rank_0
 from megatron.checkpointing import save_checkpoint
 from megatron.data.samplers import DistributedBatchSampler
+from megatron.mpu.initialize import get_data_parallel_group
 from megatron.fp16 import FP16_Optimizer

@@ -32,7 +33,7 @@ def reduce_losses(losses):
     """Reduce a tensor of losses across all GPUs."""
     reduced_losses = torch.cat([loss.clone().detach().view(1) for loss in losses])
-    torch.distributed.all_reduce(reduced_losses)
+    torch.distributed.all_reduce(reduced_losses, group=get_data_parallel_group())
     reduced_losses = reduced_losses / torch.distributed.get_world_size()
     return reduced_losses

@@ -78,7 +79,7 @@ def check_adlr_autoresume_termination(iteration, model,
     args = get_args()
     autoresume = get_adlr_autoresume()
     # Add barrier to ensure consistnecy.
-    torch.distributed.barrier()
+    torch.distributed.barrier(get_data_parallel_group())
     if autoresume.termination_requested():
         if args.save:
             save_checkpoint(iteration, model, optimizer, lr_scheduler)
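Note: the reduce_losses change scopes the all-reduce to the data-parallel group. A stand-alone variant of the same idea is sketched below (hypothetical helper, not the repo's code); for clarity it divides by the group's own size, whereas the diff above keeps dividing by the global world size.

# Stand-alone sketch (assumes an initialized process group): all-reduce scalar
# losses over one process group and average by that group's size.
import torch
import torch.distributed as dist

def reduce_losses_over_group(losses, group=None):
    reduced = torch.cat([loss.clone().detach().view(1) for loss in losses])
    dist.all_reduce(reduced, group=group)                 # sum across the group
    return reduced / dist.get_world_size(group=group)     # mean over group ranks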
pretrain_realm.py  (+1 -1)

@@ -102,7 +102,7 @@ def forward_step(data_iterator, model):
                          lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
     reduced_loss = reduce_losses([lm_loss, retrieval_utility])
-    torch.cuda.synchronize()
+    # torch.cuda.synchronize()
     return lm_loss, {'lm_loss': reduced_loss[0], 'retrieval_utility': reduced_loss[1]}
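Note: the lm_loss context line above is a masked mean: the per-token loss is summed only where loss_mask is 1 and normalized by the number of masked positions. Written out as a stand-alone helper (name and shapes are assumptions for illustration):

# Masked-mean language-model loss, matching the context line above.
import torch

def masked_lm_loss(per_token_loss, loss_mask):
    # per_token_loss: [batch, seq] losses; loss_mask: [batch, seq] of 0/1
    return torch.sum(per_token_loss.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()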