OpenDAS / Megatron-LM, commit a31833ce

Fix some bugs, add exit-duration capability

Authored Dec 11, 2020 by mshoeybi; committed by Deepak Narayanan, Dec 19, 2020. Parent: 51315905
Showing 7 changed files with 158 additions and 42 deletions:

    megatron/__init__.py             +1  -0
    megatron/arguments.py            +2  -0
    megatron/data/dataset_utils.py   +17 -5
    megatron/data/gpt2_dataset.py    +15 -6
    megatron/global_vars.py          +7  -2
    megatron/microbatches.py         +22 -16
    megatron/training.py             +94 -13

megatron/__init__.py

Exports the new get_current_global_batch_size accessor alongside the existing global_vars imports:

```diff
@@ -26,6 +26,7 @@ from .package_info import (
 )
 from .global_vars import get_args
+from .global_vars import get_current_global_batch_size
 from .global_vars import get_num_microbatches
 from .global_vars import update_num_microbatches
 from .global_vars import get_tokenizer
```

megatron/arguments.py

Adds the --exit-duration-in-mins flag next to the existing --exit-interval:

```diff
@@ -326,6 +326,8 @@ def _add_training_args(parser):
     group.add_argument('--exit-interval', type=int, default=None,
                        help='Exit the program after the iteration is divisible '
                        'by this value.')
+    group.add_argument('--exit-duration-in-mins', type=int, default=None,
+                       help='Exit the program after this many minutes.')
     group.add_argument('--tensorboard-dir', type=str, default=None,
                        help='Write TensorBoard logs to this directory.')
     group.add_argument('--scaled-upper-triang-masked-softmax-fusion',
```
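
For context, a minimal sketch of how the two exit flags combine at run time. The two add_argument calls mirror the diff; everything else (should_exit, the inline argv) is illustrative and not Megatron code:

```python
import argparse
import time

parser = argparse.ArgumentParser()
parser.add_argument('--exit-interval', type=int, default=None,
                    help='Exit the program after the iteration is divisible '
                    'by this value.')
parser.add_argument('--exit-duration-in-mins', type=int, default=None,
                    help='Exit the program after this many minutes.')
# Inline argv for demonstration; a real run would parse sys.argv.
args = parser.parse_args(['--exit-interval', '1000',
                          '--exit-duration-in-mins', '230'])

_START_TIME = time.time()

def should_exit(iteration):
    """True once either exit condition is met."""
    if args.exit_interval and iteration % args.exit_interval == 0:
        return True
    if args.exit_duration_in_mins:
        elapsed_mins = (time.time() - _START_TIME) / 60.0
        if elapsed_mins > args.exit_duration_in_mins:
            return True
    return False

print(should_exit(999))   # False
print(should_exit(1000))  # True: divisible by --exit-interval
```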

megatron/data/dataset_utils.py

Guards against empty splits when building and blending datasets (a runnable sketch of this guard pattern follows the gpt2_dataset.py diff below):

```diff
@@ -418,10 +418,22 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
             datasets_train_valid_test_num_samples[i],
             max_seq_length, masked_lm_prob, short_seq_prob,
             seed, skip_warmup, dataset_type=dataset_type)
+        if train_ds:
+            train_datasets.append(train_ds)
+        if valid_ds:
+            valid_datasets.append(valid_ds)
+        if test_ds:
+            test_datasets.append(test_ds)

     # Blend.
+    blending_train_dataset = None
+    if train_datasets:
         blending_train_dataset = BlendableDataset(train_datasets, weights)
+    blending_valid_dataset = None
+    if valid_datasets:
         blending_valid_dataset = BlendableDataset(valid_datasets, weights)
+    blending_test_dataset = None
+    if test_datasets:
         blending_test_dataset = BlendableDataset(test_datasets, weights)

     return (blending_train_dataset, blending_valid_dataset,
```

megatron/data/gpt2_dataset.py

Applies the same empty-split guards to the GPT-2 dataset builder:

```diff
@@ -55,13 +55,22 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
             prefixes[i], data_impl, splits_string,
             datasets_train_valid_test_num_samples[i],
             seq_length, seed, skip_warmup)
+        if train_ds:
             train_datasets.append(train_ds)
+        if valid_ds:
             valid_datasets.append(valid_ds)
+        if test_ds:
             test_datasets.append(test_ds)

     # Blend.
+    blending_train_dataset = None
+    if train_datasets:
         blending_train_dataset = BlendableDataset(train_datasets, weights)
+    blending_valid_dataset = None
+    if valid_datasets:
         blending_valid_dataset = BlendableDataset(valid_datasets, weights)
+    blending_test_dataset = None
+    if test_datasets:
         blending_test_dataset = BlendableDataset(test_datasets, weights)

     return (blending_train_dataset, blending_valid_dataset,
```
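
Both data files apply the same fix: a source can contribute zero samples to a split, so appends are guarded and blending only happens when at least one dataset survived. A toy, self-contained sketch of that pattern (WeightedBlend is a stand-in for Megatron's BlendableDataset; the data and weights are made up):

```python
import random

class WeightedBlend:
    """Toy stand-in for BlendableDataset: sample a source by weight."""
    def __init__(self, datasets, weights):
        assert datasets and len(datasets) == len(weights)
        self.datasets = datasets
        self.weights = weights

    def sample(self):
        (dataset,) = random.choices(self.datasets, weights=self.weights, k=1)
        return random.choice(dataset)

splits = [[1, 2, 3], [], [4, 5]]          # the middle source has no data
train_datasets, weights = [], []
for ds, w in zip(splits, [0.2, 0.3, 0.5]):
    if ds:                                 # the new 'if train_ds:' guard
        train_datasets.append(ds)
        weights.append(w)

blending_train_dataset = None              # the new None default
if train_datasets:
    blending_train_dataset = WeightedBlend(train_datasets, weights)

print(blending_train_dataset.sample()
      if blending_train_dataset else 'no data for this split')
```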

megatron/global_vars.py

Adds the get_current_global_batch_size accessor and threads a consistency_check flag through update_num_microbatches:

```diff
@@ -43,8 +43,13 @@ def get_num_microbatches():
     return _GLOBAL_NUM_MICROBATCHES_CALCULATOR.get()


-def update_num_microbatches(consumed_samples):
-    _GLOBAL_NUM_MICROBATCHES_CALCULATOR.update(consumed_samples)
+def get_current_global_batch_size():
+    return _GLOBAL_NUM_MICROBATCHES_CALCULATOR.get_current_global_batch_size()
+
+
+def update_num_microbatches(consumed_samples, consistency_check=True):
+    _GLOBAL_NUM_MICROBATCHES_CALCULATOR.update(consumed_samples,
+                                               consistency_check)


 def get_tokenizer():
```
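
The pattern here is a module-level singleton behind thin accessor functions; the commit widens that surface with get_current_global_batch_size and the consistency_check pass-through. A compressed, runnable sketch of the shape (the constant-schedule calculator and the numbers are illustrative):

```python
class _ConstantCalculator:
    """Stand-in for the object behind _GLOBAL_NUM_MICROBATCHES_CALCULATOR."""
    def __init__(self, global_batch_size, micro_times_dp):
        self.num_micro_batches = global_batch_size // micro_times_dp
        self.current_global_batch_size = global_batch_size

    def get(self):
        return self.num_micro_batches

    def get_current_global_batch_size(self):
        return self.current_global_batch_size

    def update(self, consumed_samples, consistency_check):
        pass  # constant schedule: nothing changes as samples accumulate

_CALCULATOR = _ConstantCalculator(global_batch_size=512, micro_times_dp=16)

def get_num_microbatches():
    return _CALCULATOR.get()

def get_current_global_batch_size():
    return _CALCULATOR.get_current_global_batch_size()

def update_num_microbatches(consumed_samples, consistency_check=True):
    _CALCULATOR.update(consumed_samples, consistency_check)

update_num_microbatches(0, consistency_check=False)
print(get_num_microbatches(), get_current_global_batch_size())  # 32 512
```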

megatron/microbatches.py

Tracks the current global batch size on the calculators and makes the divisibility check optional via a consistency_check argument:

```diff
@@ -56,12 +56,16 @@ class NumMicroBatchesCalculator(ABC):
     def __init__(self):
         self.num_micro_batches = None
+        self.current_global_batch_size = None

     def get(self):
         return self.num_micro_batches

+    def get_current_global_batch_size(self):
+        return self.current_global_batch_size
+
     @abstractmethod
-    def update(self, consumed_samples):
+    def update(self, consumed_samples, consistency_check):
         pass
```

```diff
@@ -78,8 +82,9 @@ class ConstantNumMicroBatches(NumMicroBatchesCalculator):
         self.num_micro_batches = global_batch_size // \
             micro_batch_times_data_parallel
         assert self.num_micro_batches >= 1
+        self.current_global_batch_size = global_batch_size

-    def update(self, consumed_samples):
+    def update(self, consumed_samples, consistency_check):
         pass
```

```diff
@@ -128,24 +133,25 @@ class RampupBatchsizeNumMicroBatches(NumMicroBatchesCalculator):
         self.rampup_samples_per_increment = self.ramup_samples / num_increments

         # Initialize number of microbatches.
-        self.update(0)
+        self.update(0, False)

-    def update(self, consumed_samples):
+    def update(self, consumed_samples, consistency_check):
         if consumed_samples > self.ramup_samples:
-            current_global_batch_size = self.global_batch_size
+            self.current_global_batch_size = self.global_batch_size
         else:
             steps = int(consumed_samples / self.rampup_samples_per_increment)
-            current_global_batch_size = self.start_batch_size + \
-                steps * self.batch_size_increment
-        assert current_global_batch_size <= self.global_batch_size
+            self.current_global_batch_size = self.start_batch_size + \
+                steps * self.batch_size_increment
+            assert self.current_global_batch_size <= self.global_batch_size

-        assert current_global_batch_size % \
+        if consistency_check:
+            assert self.current_global_batch_size % \
                 self.micro_batch_times_data_parallel_size == 0, 'current global ' \
                 'batch size ({}) is not divisible by micro-batch-size ({}) times' \
-            'data parallel size ({})'.format(current_global_batch_size,
+                'data parallel size ({})'.format(self.current_global_batch_size,
                                                  self.micro_batch_size,
                                                  self.data_parallel_size)
-        self.num_micro_batches = current_global_batch_size // \
+        self.num_micro_batches = self.current_global_batch_size // \
             self.micro_batch_times_data_parallel_size
```
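
Tracking current_global_batch_size on the calculator lets callers read the ramped batch size directly instead of reconstructing it, and the new consistency_check flag lets initialization (self.update(0, False)) and the iteration-count scan in update_train_iters skip the divisibility assert. A standalone sketch of the ramp arithmetic, with illustrative constants:

```python
# Sketch of RampupBatchsizeNumMicroBatches.update() above; constants made up.
START, INCREMENT, GLOBAL = 32, 32, 1536        # start, step, final batch size
RAMPUP_SAMPLES = 300_000
MICRO_TIMES_DP = 16                            # micro-batch-size * data-parallel

num_increments = (GLOBAL - START) // INCREMENT
samples_per_increment = RAMPUP_SAMPLES / num_increments

def current_global_batch_size(consumed_samples):
    if consumed_samples > RAMPUP_SAMPLES:
        return GLOBAL
    steps = int(consumed_samples / samples_per_increment)
    return START + steps * INCREMENT

for consumed in (0, 50_000, 150_000, 400_000):
    gbs = current_global_batch_size(consumed)
    # the consistency check from the diff, now optional behind a flag:
    assert gbs % MICRO_TIMES_DP == 0
    print(consumed, gbs, gbs // MICRO_TIMES_DP)  # last column: num microbatches
```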

megatron/training.py

Records the earliest measurable start time at import so run duration can be tracked:

```diff
@@ -18,6 +18,10 @@
 from datetime import datetime
 import math
 import sys
+import time
+
+# The earliest we can measure the start time.
+_TRAIN_START_TIME = time.time()

 import torch
 from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
 from apex.optimizers import FusedAdam as Adam
```

```diff
@@ -25,6 +29,7 @@ from apex.optimizers import FusedAdam as Adam
 from megatron import get_args
 from megatron import get_timers
 from megatron import get_tensorboard_writer
+from megatron import get_current_global_batch_size
 from megatron import get_num_microbatches
 from megatron import update_num_microbatches
 from megatron import mpu
```

A new print_datetime helper timestamps progress, synchronizing across ranks:

```diff
@@ -44,6 +49,13 @@ from megatron.data.data_loaders import build_pretraining_data_loader
 from megatron.utils import report_memory


+def print_datetime(string):
+    """Note that this call will sync across all ranks."""
+    torch.distributed.barrier()
+    time_str = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
+    print_rank_0('[' + string + '] datetime: {} '.format(time_str))
+
+
 def pretrain(train_valid_test_dataset_provider, model_provider,
              forward_step_func, extra_args_provider=None, args_defaults={}):
     """Main training program.
```

After initialization, the per-rank start times are min-reduced so all ranks agree on the earliest one:

```diff
@@ -74,6 +86,18 @@ def pretrain(train_valid_test_dataset_provider, model_provider,
     initialize_megatron(extra_args_provider=extra_args_provider,
                         args_defaults=args_defaults)

+    # Adjust the startup time so it reflects the largest value.
+    # This will be closer to what scheduler will see (outside of
+    # image ... launches.
+    global _TRAIN_START_TIME
+    start_time_tensor = torch.cuda.FloatTensor([_TRAIN_START_TIME])
+    torch.distributed.all_reduce(start_time_tensor,
+                                 op=torch.distributed.ReduceOp.MIN)
+    _TRAIN_START_TIME = start_time_tensor.item()
+    print_rank_0('time took to initialize megatron (seconds): {:.3f}'.format(
+        time.time() - _TRAIN_START_TIME))
+    print_datetime('after megatron is initialized')
+
     args = get_args()
     timers = get_timers()
```
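
_TRAIN_START_TIME is captured at import time on every rank and then min-reduced, so every rank measures elapsed time from the earliest launch, which is closest to what the job scheduler observes. A single-process sketch of the reduction (a plain min stands in for torch.distributed.all_reduce with ReduceOp.MIN; the timestamps are made up):

```python
# Each rank records time.time() at module import; ranks launch at slightly
# different moments, so the recorded values differ.
rank_start_times = [1000.2, 1000.0, 1003.7, 1001.1]   # illustrative seconds

# all_reduce(start_time_tensor, op=ReduceOp.MIN) leaves every rank holding
# the same minimum:
train_start_time = min(rank_start_times)

now = 1042.5                                   # illustrative 'time.time()'
print('time took to initialize megatron (seconds): {:.3f}'.format(
    now - train_start_time))                   # 42.500 on every rank
```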

More print_datetime checkpoints through setup:

```diff
@@ -81,6 +105,8 @@ def pretrain(train_valid_test_dataset_provider, model_provider,
     timers('model and optimizer').start()
     model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider)
     timers('model and optimizer').stop()
+    print_datetime('after model, optimizer, and learning rate '
+                   'scheduler are built')

     # Data stuff.
     timers('train/valid/test data iterators').start()
```

```diff
@@ -88,6 +114,7 @@ def pretrain(train_valid_test_dataset_provider, model_provider,
         = build_train_valid_test_data_iterators(
             train_valid_test_dataset_provider)
     timers('train/valid/test data iterators').stop()
+    print_datetime('after dataloaders are build')

     # Print setup timing.
     print_rank_0('done with setups ...')
```

```diff
@@ -99,6 +126,7 @@ def pretrain(train_valid_test_dataset_provider, model_provider,
         iteration = train(forward_step_func,
                           model, optimizer, lr_scheduler,
                           train_data_iterator, valid_data_iterator)
+    print_datetime('after training is done')

     if args.do_valid:
         prefix = 'the end of training for val data'
```

update_train_iters now reads the ramped batch size directly and skips the consistency check while scanning the ramp:

```diff
@@ -132,13 +160,11 @@ def update_train_iters(args):
     consumed_samples = 0
     # Rampup phase.
     while consumed_samples <= int(args.rampup_batch_size[2]):
-        update_num_microbatches(consumed_samples)
-        consumed_samples += get_num_microbatches() * \
-            args.micro_batch_size * \
-            args.data_parallel_size
+        update_num_microbatches(consumed_samples, consistency_check=False)
+        consumed_samples += get_current_global_batch_size()
         iterations += 1

     # Reset
-    update_num_microbatches(0)
+    update_num_microbatches(0, consistency_check=False)

     # Constant phase
     # Note that we throw away any partial last batch.
     iterations += (args.train_samples - consumed_samples) // \
```
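
get_current_global_batch_size() replaces the old get_num_microbatches() * micro_batch_size * data_parallel_size product, which only equals the true global batch size when the divisibility constraint holds, and mid-ramp that is exactly what consistency_check=False allows to be violated. A self-contained sketch of the iteration count this loop computes (constants are illustrative):

```python
# Sketch of update_train_iters: count iterations needed to consume
# train_samples when the global batch size ramps up.
START, INCREMENT, GLOBAL = 32, 32, 256
RAMPUP_SAMPLES, TRAIN_SAMPLES = 10_000, 50_000

num_increments = (GLOBAL - START) // INCREMENT
samples_per_increment = RAMPUP_SAMPLES / num_increments

def gbs(consumed):
    """Current global batch size, as in the ramp-up calculator."""
    if consumed > RAMPUP_SAMPLES:
        return GLOBAL
    return START + int(consumed / samples_per_increment) * INCREMENT

iterations = 0
consumed_samples = 0
# Rampup phase: mirrors update_num_microbatches + get_current_global_batch_size.
while consumed_samples <= RAMPUP_SAMPLES:
    consumed_samples += gbs(consumed_samples)
    iterations += 1

# Constant phase; any partial last batch is thrown away, as in the diff.
iterations += (TRAIN_SAMPLES - consumed_samples) // GLOBAL
print(iterations, consumed_samples)
```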

Checkpoint loading is now timed, with barriers so the reported time covers the slowest rank:

```diff
@@ -267,7 +293,15 @@ def setup_model_and_optimizer(model_provider_func):
     lr_scheduler = get_learning_rate_scheduler(optimizer)

     if args.load is not None:
+        timers = get_timers()
+        # Extra barrier is added to make sure all ranks report the
+        # max time.
+        torch.distributed.barrier()
+        timers('load checkpoint').start()
         args.iteration = load_checkpoint(model, optimizer, lr_scheduler)
+        torch.distributed.barrier()
+        timers('load checkpoint').stop()
+        timers.log(['load checkpoint'])
     else:
         args.iteration = 0
```

TensorBoard scalars are split onto two x-axes (iterations and consumed samples), and the batch size is logged too:

```diff
@@ -685,11 +719,22 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
     # Tensorboard values.
     if writer and torch.distributed.get_rank() == 0:
-        writer.add_scalar('learning_rate', learning_rate, iteration)
+        writer.add_scalar('learning_rate-iterations', learning_rate, iteration)
+        writer.add_scalar('learning_rate-samples', learning_rate,
+                          args.consumed_train_samples)
+        batch_size = args.micro_batch_size * args.data_parallel_size * \
+            get_num_microbatches()
+        writer.add_scalar('batch_size-iterations', batch_size, iteration)
+        writer.add_scalar('batch_size-samples', batch_size,
+                          args.consumed_train_samples)
         for key in loss_dict:
-            writer.add_scalar(key, loss_dict[key], iteration)
+            writer.add_scalar(key, loss_dict[key] + '-iterations', iteration)
+            writer.add_scalar(key, loss_dict[key] + '-samples',
+                              args.consumed_train_samples)
         if args.fp16:
-            writer.add_scalar('loss_scale', loss_scale, iteration)
+            writer.add_scalar('loss_scale-iterations', loss_scale, iteration)
+            writer.add_scalar('loss_scale-samples', loss_scale,
+                              args.consumed_train_samples)
     normalizer = iteration % args.log_interval
     if normalizer == 0:
         normalizer = args.log_interval
```
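
Every scalar now goes to TensorBoard twice, against two x-axes: iteration number and consumed samples. With batch-size ramp-up, equal iteration counts no longer mean equal data seen, so the samples axis is what keeps runs comparable. A toy writer showing the dual-axis pattern (ToyWriter mimics only the add_scalar signature; the values are made up):

```python
class ToyWriter:
    """Records (tag, value, step) like SummaryWriter.add_scalar."""
    def __init__(self):
        self.events = []

    def add_scalar(self, tag, value, step):
        self.events.append((tag, value, step))

writer = ToyWriter()
iteration, consumed_train_samples = 500, 1_280_000
learning_rate = 1.5e-4

# The same value plotted on two x-axes, as in the diff:
writer.add_scalar('learning_rate-iterations', learning_rate, iteration)
writer.add_scalar('learning_rate-samples', learning_rate,
                  consumed_train_samples)
for event in writer.events:
    print(event)
```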

The console log line gains a consumed-samples field:

```diff
@@ -703,6 +748,8 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
                           elapsed_time / args.log_interval, iteration)
     log_string = ' iteration {:8d}/{:8d} |'.format(
         iteration, args.train_iters)
+    log_string += ' consumed samples {:12d} |'.format(
+        args.consumed_train_samples)
     log_string += ' elapsed time per iteration (ms): {:.1f} |'.format(
         elapsed_time * 1000.0 / args.log_interval)
     log_string += ' learning rate: {:.3E} |'.format(learning_rate)
```

A save_checkpoint_and_time wrapper times checkpoint writes the same way loads are timed:

```diff
@@ -732,6 +779,18 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
     return report_memory_flag


+def save_checkpoint_and_time(iteration, model, optimizer, lr_scheduler):
+    timers = get_timers()
+    # Extra barrier is added to make sure
+    # all ranks report the max time.
+    torch.distributed.barrier()
+    timers('save checkpoint').start()
+    save_checkpoint(iteration, model, optimizer, lr_scheduler)
+    torch.distributed.barrier()
+    timers('save checkpoint').stop()
+    timers.log(['save checkpoint'])
+
+
 def train(forward_step_func, model, optimizer, lr_scheduler,
           train_data_iterator, valid_data_iterator):
     """Train the model function."""
```
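
The barrier/start then barrier/stop bracketing, used for both checkpoint load and save, makes all ranks enter and leave the timed region together, so the logged figure reflects the slowest rank rather than whichever rank happens to print. A single-process sketch of that wrapper shape (the no-op barrier and the sleeping save_checkpoint are stand-ins):

```python
import time

def barrier():
    """Stand-in for torch.distributed.barrier(); the real call blocks until
    every rank arrives, so the timer below brackets the slowest rank."""
    pass

def timed_call(label, fn, *args):
    barrier()                          # all ranks start the clock together
    start = time.time()
    fn(*args)
    barrier()                          # nobody stops it before the slowest
    print('{}: {:.3f} s'.format(label, time.time() - start))

def save_checkpoint(iteration):
    time.sleep(0.1)                    # pretend to write a checkpoint

timed_call('save checkpoint', save_checkpoint, 100)
```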

```diff
@@ -748,6 +807,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
     iteration = args.iteration
     timers('interval time').start()
+    print_datetime('before the start of training step')
     report_memory_flag = True
     while iteration < args.train_iters:
         update_num_microbatches(args.consumed_train_samples)
```

The training loop tracks whether a checkpoint was just written so the exit paths below can avoid saving twice:

```diff
@@ -777,9 +837,13 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
                            lr_scheduler)

         # Checkpointing
+        saved_checkpoint = False
         if args.save and args.save_interval and \
            iteration % args.save_interval == 0:
-            save_checkpoint(iteration, model, optimizer, lr_scheduler)
+            save_checkpoint_and_time(iteration, model, optimizer,
+                                     lr_scheduler)
+            saved_checkpoint = True

         # Evaluation
         if args.eval_interval and iteration % args.eval_interval == 0 and \
```

The new duration-based exit: all ranks agree (via a MAX reduction) on whether the time budget is spent, checkpoint if needed, and exit together; the iteration-based exit gains the same checkpoint-on-exit behavior:

```diff
@@ -789,14 +853,31 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
                                        valid_data_iterator, model,
                                        iteration, False)

+        # Exiting based on duration
+        if args.exit_duration_in_mins:
+            train_time = (time.time() - _TRAIN_START_TIME) / 60.0
+            done_cuda = torch.cuda.IntTensor(
+                [train_time > args.exit_duration_in_mins])
+            torch.distributed.all_reduce(
+                done_cuda, op=torch.distributed.ReduceOp.MAX)
+            done = done_cuda.item()
+            if done:
+                if not saved_checkpoint:
+                    save_checkpoint_and_time(iteration, model, optimizer,
+                                             lr_scheduler)
+                print_datetime('exiting program after {} minutes'.format(
+                    train_time))
+                sys.exit()
+
+        # Exiting based on iterations
         if args.exit_interval and iteration % args.exit_interval == 0:
+            if not saved_checkpoint:
+                save_checkpoint_and_time(iteration, model, optimizer,
+                                         lr_scheduler)
             torch.distributed.barrier()
-            time_str = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
-            rank = torch.distributed.get_rank()
-            print_rank_0('rank: {} | time: {} | exiting the program at '
-                         'iteration {}'.format(rank, time_str, iteration))
+            print_datetime('exiting program at iteration {}'.format(
+                iteration))
             sys.exit()

     return iteration
```
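
Ranks can disagree on elapsed wall-clock time, so the exit decision is max-reduced: over 0/1 flags, ReduceOp.MAX is a logical OR, and if any rank has hit the limit every rank exits together (a partial exit would hang the job at the next collective). A single-process sketch of the agreement step (a list stands in for the per-rank tensors):

```python
# Sketch of the duration-based exit decision. In the diff each rank computes
# its own flag and torch.distributed.all_reduce(..., ReduceOp.MAX) combines
# them; all numbers here are illustrative.
EXIT_DURATION_IN_MINS = 230

rank_train_minutes = [229.4, 230.2, 228.9]    # per-rank wall clocks differ
rank_flags = [int(t > EXIT_DURATION_IN_MINS) for t in rank_train_minutes]

done = max(rank_flags)                        # ReduceOp.MAX == logical OR

if done:
    # Every rank takes this branch, saves a checkpoint if one was not just
    # written, and exits, keeping the job in lockstep.
    print('exiting program after {} minutes'.format(max(rank_train_minutes)))
```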