OpenDAS / Megatron-LM

Commit 77bff386, authored Jul 21, 2021 by Deepak Narayanan
Parent: f3be8ead

    Use helper method in megatron/schedules.py as intended
Showing 2 changed files with 6 additions and 20 deletions:

    megatron/schedules.py  +3 -0
    megatron/training.py   +3 -20
megatron/schedules.py (view file @ 77bff386)

@@ -31,6 +31,9 @@ def get_forward_backward_func():
     if mpu.get_pipeline_model_parallel_world_size() > 1:
         if args.virtual_pipeline_model_parallel_size is not None:
             forward_backward_func = forward_backward_pipelining_with_interleaving
+            assert get_num_microbatches() % args.pipeline_model_parallel_size == 0, \
+                'number of microbatches is not divisible by pipeline-parallel ' \
+                'size when using interleaved schedule'
         else:
             forward_backward_func = forward_backward_pipelining_without_interleaving
     else:
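This hunk moves the microbatch-divisibility check for the interleaved schedule into the helper itself, so every caller of get_forward_backward_func() gets the check for free. Below is a minimal, self-contained sketch of the resulting dispatch logic; it is not megatron's actual code: the real helper reads global state via get_args() and mpu, whereas here those values are passed as hypothetical explicit parameters and the three schedule functions are no-op stubs.

    # Sketch only: explicit parameters stand in for megatron's get_args()/mpu
    # globals; the three schedule functions are no-op stubs.

    def forward_backward_no_pipelining(*args, **kwargs):
        pass  # stub

    def forward_backward_pipelining_with_interleaving(*args, **kwargs):
        pass  # stub

    def forward_backward_pipelining_without_interleaving(*args, **kwargs):
        pass  # stub

    def get_forward_backward_func(pipeline_world_size, virtual_pipeline_size,
                                  num_microbatches):
        """Pick the forward/backward schedule for the current parallel config."""
        if pipeline_world_size > 1:
            if virtual_pipeline_size is not None:
                # The interleaved schedule needs the microbatch count to divide
                # evenly across pipeline stages (the assert added by this commit).
                assert num_microbatches % pipeline_world_size == 0, \
                    'number of microbatches is not divisible by pipeline-parallel ' \
                    'size when using interleaved schedule'
                return forward_backward_pipelining_with_interleaving
            return forward_backward_pipelining_without_interleaving
        return forward_backward_no_pipelining

Centralizing the selection (and the assert) in one helper means the branching no longer has to be duplicated at each call site, which is exactly what the training.py hunks below delete.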
megatron/training.py (view file @ 77bff386)

@@ -47,9 +47,7 @@ from megatron.utils import check_adlr_autoresume_termination
 from megatron.utils import unwrap_model
 from megatron.data.data_samplers import build_pretraining_data_loader
 from megatron.utils import calc_params_l2_norm
-from megatron.schedules import forward_backward_no_pipelining
-from megatron.schedules import forward_backward_pipelining_without_interleaving
-from megatron.schedules import forward_backward_pipelining_with_interleaving
+from megatron.schedules import get_forward_backward_func
 from megatron.utils import report_memory

@@ -359,16 +357,7 @@ def train_step(forward_step_func, data_iterator,
     else:
         optimizer.zero_grad()

-    if mpu.get_pipeline_model_parallel_world_size() > 1:
-        if args.virtual_pipeline_model_parallel_size is not None:
-            forward_backward_func = forward_backward_pipelining_with_interleaving
-            assert get_num_microbatches() % args.pipeline_model_parallel_size == 0, \
-                'number of microbatches is not divisible by pipeline-parallel ' \
-                'size when using interleaved schedule'
-        else:
-            forward_backward_func = forward_backward_pipelining_without_interleaving
-    else:
-        forward_backward_func = forward_backward_no_pipelining
+    forward_backward_func = get_forward_backward_func()
     losses_reduced = forward_backward_func(
         forward_step_func, data_iterator, model,
         optimizer, timers, forward_only=False)

@@ -722,13 +711,7 @@ def evaluate(forward_step_func, data_iterator, model, verbose=False):
                 print_rank_0('Evaluating iter {}/{}'.format(iteration,
                                                             args.eval_iters))
-            if mpu.get_pipeline_model_parallel_world_size() > 1:
-                if args.virtual_pipeline_model_parallel_size is not None:
-                    forward_backward_func = forward_backward_pipelining_with_interleaving
-                else:
-                    forward_backward_func = forward_backward_pipelining_without_interleaving
-            else:
-                forward_backward_func = forward_backward_no_pipelining
+            forward_backward_func = get_forward_backward_func()
             loss_dicts = forward_backward_func(
                 forward_step_func, data_iterator, model, optimizer=None,
                 timers=None, forward_only=True)
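Both call sites now delegate schedule selection to the helper. The sketch below shows the resulting call pattern, reusing the stub get_forward_backward_func from the sketch above; the demo_* names and all config values are hypothetical, not megatron's API.

    # Usage sketch built on the stub helper above; everything here is
    # illustrative only.

    def demo_train_step(forward_step_func, data_iterator, model, optimizer, timers):
        forward_backward_func = get_forward_backward_func(
            pipeline_world_size=2, virtual_pipeline_size=None, num_microbatches=8)
        # Training: run both forward and backward passes.
        return forward_backward_func(forward_step_func, data_iterator, model,
                                     optimizer, timers, forward_only=False)

    def demo_evaluate(forward_step_func, data_iterator, model):
        forward_backward_func = get_forward_backward_func(
            pipeline_world_size=2, virtual_pipeline_size=None, num_microbatches=8)
        # Evaluation: forward passes only; no optimizer or timers needed.
        return forward_backward_func(forward_step_func, data_iterator, model,
                                     optimizer=None, timers=None, forward_only=True)

The net effect matches the summary above: 20 lines of duplicated branching deleted from training.py in exchange for 3 additions (one import plus one get_forward_backward_func() call in each of train_step and evaluate).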
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment