OpenDAS / Megatron-LM / Commits

Commit cf2f4d9d, authored Oct 31, 2019 by Jared Casper

Merge branch 'code_reuse' into 'master'

refactored for code reuse

See merge request ADLR/megatron-lm!11

Parents: beb3e0d3, cbd8c054
Showing 4 changed files with 101 additions and 75 deletions (+101 -75)
megatron/training.py   +84 -63
megatron/utils.py      +11 -1
pretrain_bert.py       +3 -7
pretrain_gpt2.py       +3 -4
megatron/training.py  (view file @ cf2f4d9d)
...
@@ -71,28 +71,17 @@ def run(top_level_message, train_val_test_data_provider,
     function add `batch generator` to the timers class.
     """

-    # Timer.
-    timers = Timers()
-
     # Arguments.
     args = get_args()

+    # Timer.
+    timers = Timers()
+
     # Tensorboard writer
     writer = get_tensorboard_writer(args)

-    # Pytorch distributed.
-    initialize_distributed(args)
-    if torch.distributed.get_rank() == 0:
-        print(top_level_message, flush=True)
-        print_args(args, writer)
-
-    # Autoresume.
-    torch.distributed.barrier()
-    if args.adlr_autoresume:
-        enable_adlr_autoresume(args)
-
-    # Random seeds for reproducability.
-    set_random_seed(args.seed)
+    # Initalize.
+    initialize_megatron(top_level_message, args, writer)

     # Data stuff.
     train_data, val_data, test_data = train_val_test_data_provider(args)
...
@@ -135,6 +124,24 @@ def run(top_level_message, train_val_test_data_provider,
                              args, None, 0, timers, True)


+def initialize_megatron(message, args, writer):
+    """"Initialize distributed, random seed, and autoresume."""
+
+    # Pytorch distributed.
+    initialize_distributed(args)
+    if torch.distributed.get_rank() == 0:
+        print(message, flush=True)
+        print_args(args, writer)
+
+    # Autoresume.
+    torch.distributed.barrier()
+    if args.adlr_autoresume:
+        enable_adlr_autoresume(args)
+
+    # Random seeds for reproducability.
+    set_random_seed(args.seed)
+
+
 def get_model(model_provider_func, args):
     """Build the model."""
@@ -301,53 +308,31 @@ def train_step(forward_step_func, data_iterator, model, optimizer, lr_scheduler,
...
@@ -301,53 +308,31 @@ def train_step(forward_step_func, data_iterator, model, optimizer, lr_scheduler,
return
loss_reduced
,
skipped_iter
return
loss_reduced
,
skipped_iter
def
train
(
forward_step_func
,
model
,
optimizer
,
lr_scheduler
,
def
training_log
(
loss_dict
,
total_loss_dict
,
learning_rate
,
iteration
,
train_data_iterator
,
val_data_iterator
,
timers
,
args
,
writer
):
loss_scale
,
report_memory_flag
,
writer
,
args
,
timers
):
"""Train the model function."""
# Turn on training mode which enables dropout.
model
.
train
()
# Tracking loss.
total_loss_dict
=
{}
# Iterations.
iteration
=
args
.
iteration
skipped_iters
=
0
timers
(
'interval time'
).
start
()
report_memory_flag
=
True
while
iteration
<
args
.
train_iters
:
loss_dict
,
skipped_iter
=
train_step
(
forward_step_func
,
train_data_iterator
,
model
,
optimizer
,
lr_scheduler
,
args
,
timers
)
skipped_iters
+=
skipped_iter
iteration
+=
1
# Update losses.
# Update losses.
for
key
in
loss_dict
:
for
key
in
loss_dict
:
total_loss_dict
[
key
]
=
total_loss_dict
.
get
(
key
,
0.
)
+
loss_dict
[
key
]
total_loss_dict
[
key
]
=
total_loss_dict
.
get
(
key
,
0.
)
+
loss_dict
[
key
]
# Logging.
# Logging.
if
args
.
DDP_impl
==
'torch'
:
timers_to_log
=
[]
timers_to_log
=
[
'forward'
,
'backward'
,
'optimizer'
,
def
add_to_logging
(
name
):
'batch generator'
]
if
name
in
timers
.
timers
:
else
:
timers_to_log
.
append
(
name
)
timers_to_log
=
[
'forward'
,
'backward'
,
'allreduce'
,
'optimizer'
,
add_to_logging
(
'forward'
)
'batch generator'
]
add_to_logging
(
'backward'
)
add_to_logging
(
'allreduce'
)
learning_rate
=
optimizer
.
param_groups
[
0
][
'lr'
]
add_to_logging
(
'optimizer'
)
add_to_logging
(
'batch generator'
)
# Tensorboard values.
if
writer
and
torch
.
distributed
.
get_rank
()
==
0
:
if
writer
and
torch
.
distributed
.
get_rank
()
==
0
:
writer
.
add_scalar
(
'learning_rate'
,
learning_rate
,
iteration
)
writer
.
add_scalar
(
'learning_rate'
,
learning_rate
,
iteration
)
for
key
in
total_
loss_dict
:
for
key
in
loss_dict
:
writer
.
add_scalar
(
key
,
total_
loss_dict
[
key
],
iteration
)
writer
.
add_scalar
(
key
,
loss_dict
[
key
],
iteration
)
if
args
.
fp16
:
if
args
.
fp16
:
writer
.
add_scalar
(
'loss_scale'
,
optimizer
.
loss_scale
,
iteration
)
writer
.
add_scalar
(
'loss_scale'
,
loss_scale
,
iteration
)
normalizer
=
iteration
%
args
.
log_interval
normalizer
=
iteration
%
args
.
log_interval
if
normalizer
==
0
:
if
normalizer
==
0
:
normalizer
=
args
.
log_interval
normalizer
=
args
.
log_interval
...
@@ -369,14 +354,50 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
             log_string += ' {}: {:.6E} |'.format(key, avg)
             total_loss_dict[key] = 0.0
         if args.fp16:
             log_string += ' loss scale: {:.1f} |'.format(
-                optimizer.loss_scale)
+                loss_scale)
         print_rank_0(log_string)
         if report_memory_flag:
             report_memory('after {} iterations'.format(iteration))
             report_memory_flag = False
         timers.log(timers_to_log, normalizer=args.log_interval)

+    return report_memory_flag
+
+
+def train(forward_step_func, model, optimizer, lr_scheduler,
+          train_data_iterator, val_data_iterator, timers, args, writer):
+    """Train the model function."""
+
+    # Turn on training mode which enables dropout.
+    model.train()
+
+    # Tracking loss.
+    total_loss_dict = {}
+
+    # Iterations.
+    iteration = args.iteration
+    skipped_iters = 0
+
+    timers('interval time').start()
+    report_memory_flag = True
+    while iteration < args.train_iters:
+
+        loss_dict, skipped_iter = train_step(forward_step_func,
+                                             train_data_iterator,
+                                             model, optimizer,
+                                             lr_scheduler, args, timers)
+        skipped_iters += skipped_iter
+        iteration += 1
+
+        # Logging.
+        report_memory_flag = training_log(loss_dict, total_loss_dict,
+                                          optimizer.param_groups[0]['lr'],
+                                          iteration, optimizer.loss_scale,
+                                          report_memory_flag, writer,
+                                          args, timers)

         # Autoresume
         if (iteration % args.adlr_autoresume_interval == 0) and \
            args.adlr_autoresume:
...
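
A hedged sketch of the train()/training_log() split shown above: the loop now only accumulates losses and advances the iteration counter, while all per-interval reporting lives in training_log(), which hands report_memory_flag back so memory statistics are printed at most once. Everything below is a stand-in (no real model, optimizer, timers, or TensorBoard writer), just the control flow.

def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
                 loss_scale, report_memory_flag, writer, args, timers):
    # Accumulate losses and report; in Megatron this also feeds TensorBoard
    # and the timer log.
    for key in loss_dict:
        total_loss_dict[key] = total_loss_dict.get(key, 0.) + loss_dict[key]
    print('iteration {} | lr {:.3e} | {}'.format(iteration, learning_rate, loss_dict))
    return False  # memory was (or did not need to be) reported after this call

def train(train_step, args):
    total_loss_dict, iteration, report_memory_flag = {}, 0, True
    while iteration < args['train_iters']:
        loss_dict = train_step(iteration)
        iteration += 1
        report_memory_flag = training_log(loss_dict, total_loss_dict,
                                          1.0e-4, iteration, 1.0,
                                          report_memory_flag, None, args, None)

train(lambda it: {'lm loss': 1.0 / (it + 1)}, {'train_iters': 3})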
megatron/utils.py  (view file @ cf2f4d9d)
...
@@ -31,9 +31,19 @@ from megatron.model import DistributedDataParallel as LocalDDP
 from megatron.model import get_params_for_weight_decay_optimization


+def reduce_losses(losses):
+    reduced_losses = torch.cat(
+        [loss.clone().detach().view(1) for loss in losses])
+    torch.distributed.all_reduce(reduced_losses)
+    reduced_losses = reduced_losses / torch.distributed.get_world_size()
+    return reduced_losses
+
+
 def get_tensorboard_writer(args):
     writer = None
-    if args.tensorboard_dir and args.rank == 0:
+    if hasattr(args, 'tensorboard_dir') and \
+       args.tensorboard_dir and args.rank == 0:
         try:
             from torch.utils.tensorboard import SummaryWriter
             writer = SummaryWriter(log_dir=args.tensorboard_dir)
...
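
What the new utility computes, as a hedged single-process sketch: each scalar loss is detached, concatenated into a 1-D tensor, summed across data-parallel ranks with all_reduce, and divided by the world size, so the caller gets the average of every loss over all GPUs. The stand-in below skips the collective so it runs without initializing torch.distributed.

import torch

def reduce_losses_sketch(losses, world_size=1):
    # Detach each scalar loss, stack into one 1-D tensor, then average across
    # data-parallel ranks; torch.distributed.all_reduce(reduced) would do the
    # cross-rank sum here, which a single process can skip.
    reduced = torch.cat([loss.clone().detach().view(1) for loss in losses])
    return reduced / world_size

lm_loss, nsp_loss = torch.tensor(2.5), torch.tensor(0.7)
print(reduce_losses_sketch([lm_loss, nsp_loss]))  # tensor([2.5000, 0.7000])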
pretrain_bert.py  (view file @ cf2f4d9d)
...
@@ -22,6 +22,7 @@ from configure_data import configure_data
 from megatron import mpu
 from megatron.model import BertModel
 from megatron.utils import print_rank_0
+from megatron.utils import reduce_losses
 from megatron.utils import vocab_size_with_padding
 from megatron.training import run
@@ -99,14 +100,9 @@ def forward_step(data_iterator, model, args, timers):
...
@@ -99,14 +100,9 @@ def forward_step(data_iterator, model, args, timers):
loss
=
lm_loss
+
nsp_loss
loss
=
lm_loss
+
nsp_loss
reduced_losses
=
torch
.
cat
((
lm_loss
.
clone
().
detach
().
view
(
1
),
reduced_losses
=
reduce_losses
([
lm_loss
,
nsp_loss
])
nsp_loss
.
clone
().
detach
().
view
(
1
)))
torch
.
distributed
.
all_reduce
(
reduced_losses
)
reduced_losses
=
reduced_losses
/
torch
.
distributed
.
get_world_size
()
lm_loss_reduced
=
reduced_losses
[
0
]
nsp_loss_reduced
=
reduced_losses
[
1
]
return
loss
,
{
'lm loss'
:
lm_loss_reduced
,
'nsp loss'
:
nsp_loss_reduced
}
return
loss
,
{
'lm loss'
:
reduced_losses
[
0
],
'nsp loss'
:
reduced_losses
[
1
]
}
def
get_train_val_test_data
(
args
):
def
get_train_val_test_data
(
args
):
...
...
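
A hedged sketch of how the BERT forward_step now ends after this change: the hand-rolled cat / all_reduce / divide sequence collapses into one reduce_losses() call, and the returned tensor is indexed directly when building the logging dict. lm_loss and nsp_loss are placeholder tensors here, not real model outputs, and the helper omits the collective so the snippet runs on one process.

import torch

def reduce_losses(losses):
    # Stand-in with the same shape as the new utility; the all_reduce and the
    # division by world size are omitted for a single-process run.
    return torch.cat([loss.clone().detach().view(1) for loss in losses])

def bert_forward_step_tail(lm_loss, nsp_loss):
    loss = lm_loss + nsp_loss
    reduced_losses = reduce_losses([lm_loss, nsp_loss])
    return loss, {'lm loss': reduced_losses[0], 'nsp loss': reduced_losses[1]}

print(bert_forward_step_tail(torch.tensor(2.0), torch.tensor(0.5)))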
pretrain_gpt2.py  (view file @ cf2f4d9d)
...
@@ -22,6 +22,7 @@ from gpt2_data_loader import make_gpt2_dataloaders
 from megatron import mpu
 from megatron.model import GPT2Model
 from megatron.utils import print_rank_0
+from megatron.utils import reduce_losses
 from megatron.utils import vocab_size_with_padding
 from megatron.training import run
@@ -155,11 +156,9 @@ def forward_step(data_iterator, model, args, timers):
...
@@ -155,11 +156,9 @@ def forward_step(data_iterator, model, args, timers):
loss
=
torch
.
sum
(
losses
.
view
(
-
1
)
*
loss_mask
)
/
loss_mask
.
sum
()
loss
=
torch
.
sum
(
losses
.
view
(
-
1
)
*
loss_mask
)
/
loss_mask
.
sum
()
# Reduce loss for logging.
# Reduce loss for logging.
reduced_loss
=
loss
.
clone
().
detach
().
view
(
1
)
reduced_loss
=
reduce_losses
([
loss
])
torch
.
distributed
.
all_reduce
(
reduced_loss
)
reduced_loss
=
reduced_loss
/
torch
.
distributed
.
get_world_size
()
return
loss
,
{
'lm loss'
:
reduced_loss
}
return
loss
,
{
'lm loss'
:
reduced_loss
[
0
]
}
def
get_train_val_test_data
(
args
):
def
get_train_val_test_data
(
args
):
...
...
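
The GPT-2 change follows the same pattern with a single loss; one detail worth noting, shown in this hedged stand-alone sketch, is that reduce_losses() always returns a 1-D tensor (here of length 1), which is why the logging dict now stores element [0] instead of the whole tensor.

import torch

loss = torch.tensor(3.25)                                    # placeholder for the LM loss
reduced_loss = torch.cat([loss.clone().detach().view(1)])    # reduce_losses([loss]) minus the all_reduce
print({'lm loss': reduced_loss[0]})                          # {'lm loss': tensor(3.2500)}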