Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Megatron-LM
Commits
a6756bf8
Commit
a6756bf8
authored
Nov 03, 2020
by
Deepak Narayanan
Browse files
Better 'forward' and 'backward' timing in megatron/training.py
parent
3e6898e6
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
6 additions
and
4 deletions
+6
-4
megatron/training.py
megatron/training.py
+6
-4
No files found.
megatron/training.py
View file @
a6756bf8
...
@@ -363,6 +363,7 @@ def forward_and_backward_steps_with_communication(forward_step_func, data_iterat
...
@@ -363,6 +363,7 @@ def forward_and_backward_steps_with_communication(forward_step_func, data_iterat
input_tensors
,
output_tensors
,
input_tensors
,
output_tensors
,
losses_reduced
,
timers
):
losses_reduced
,
timers
):
# Forward model for one step.
# Forward model for one step.
timers
(
'forward'
).
start
()
timers
(
'forward-compute'
).
start
()
timers
(
'forward-compute'
).
start
()
output_tensor
=
forward_step_func
(
data_iterator
,
model
,
input_tensor
)
output_tensor
=
forward_step_func
(
data_iterator
,
model
,
input_tensor
)
timers
(
'forward-compute'
).
stop
()
timers
(
'forward-compute'
).
stop
()
...
@@ -374,14 +375,13 @@ def forward_and_backward_steps_with_communication(forward_step_func, data_iterat
...
@@ -374,14 +375,13 @@ def forward_and_backward_steps_with_communication(forward_step_func, data_iterat
losses_reduced
.
append
(
loss_reduced
)
losses_reduced
.
append
(
loss_reduced
)
else
:
else
:
timers
(
'forward-send'
).
start
()
timers
(
'forward-send'
).
start
()
timers
(
'backward-recv'
).
start
()
_
,
output_tensor_grad
=
communicate
(
_
,
output_tensor_grad
=
communicate
(
tensor_send_next
=
output_tensor
,
tensor_send_next
=
output_tensor
,
tensor_send_prev
=
None
,
tensor_send_prev
=
None
,
recv_forward
=
False
,
recv_forward
=
False
,
recv_backward
=
True
)
recv_backward
=
True
)
timers
(
'forward-send'
).
stop
()
timers
(
'forward-send'
).
stop
()
timers
(
'
backward-recv
'
).
stop
()
timers
(
'
forward
'
).
stop
()
input_tensors
.
append
(
input_tensor
)
input_tensors
.
append
(
input_tensor
)
output_tensors
.
append
(
output_tensor
)
output_tensors
.
append
(
output_tensor
)
...
@@ -390,6 +390,7 @@ def forward_and_backward_steps_with_communication(forward_step_func, data_iterat
...
@@ -390,6 +390,7 @@ def forward_and_backward_steps_with_communication(forward_step_func, data_iterat
output_tensor
=
output_tensors
.
pop
(
0
)
output_tensor
=
output_tensors
.
pop
(
0
)
# Backward pass for one step.
# Backward pass for one step.
timers
(
'backward'
).
start
()
timers
(
'backward-compute'
).
start
()
timers
(
'backward-compute'
).
start
()
input_grad_tensor
=
\
input_grad_tensor
=
\
backward_step
(
optimizer
,
model
,
input_tensor
,
output_tensor
,
output_tensor_grad
)
backward_step
(
optimizer
,
model
,
input_tensor
,
output_tensor
,
output_tensor_grad
)
...
@@ -397,16 +398,15 @@ def forward_and_backward_steps_with_communication(forward_step_func, data_iterat
...
@@ -397,16 +398,15 @@ def forward_and_backward_steps_with_communication(forward_step_func, data_iterat
if
not
mpu
.
is_pipeline_first_stage
():
if
not
mpu
.
is_pipeline_first_stage
():
timers
(
'backward-send'
).
start
()
timers
(
'backward-send'
).
start
()
timers
(
'forward-recv'
).
start
()
input_tensor
,
_
=
communicate
(
input_tensor
,
_
=
communicate
(
tensor_send_next
=
None
,
tensor_send_next
=
None
,
tensor_send_prev
=
input_grad_tensor
,
tensor_send_prev
=
input_grad_tensor
,
recv_forward
=
(
not
last_microbatch
),
recv_forward
=
(
not
last_microbatch
),
recv_backward
=
False
)
recv_backward
=
False
)
timers
(
'backward-send'
).
stop
()
timers
(
'backward-send'
).
stop
()
timers
(
'forward-recv'
).
stop
()
else
:
else
:
input_tensor
=
None
input_tensor
=
None
timers
(
'backward'
).
stop
()
return
input_tensor
return
input_tensor
...
@@ -460,10 +460,12 @@ def train_step(forward_step_func, data_iterator,
...
@@ -460,10 +460,12 @@ def train_step(forward_step_func, data_iterator,
if
mpu
.
is_pipeline_first_stage
():
if
mpu
.
is_pipeline_first_stage
():
input_tensor
=
None
input_tensor
=
None
else
:
else
:
timers
(
'forward-recv'
).
start
()
input_tensor
,
_
=
communicate
(
tensor_send_next
=
None
,
input_tensor
,
_
=
communicate
(
tensor_send_next
=
None
,
tensor_send_prev
=
None
,
tensor_send_prev
=
None
,
recv_forward
=
True
,
recv_forward
=
True
,
recv_backward
=
False
)
recv_backward
=
False
)
timers
(
'forward-recv'
).
stop
()
# Run 1F1B.
# Run 1F1B.
for
i
in
range
(
num_microbatches_in_minibatch
-
num_warmup_microbatches
):
for
i
in
range
(
num_microbatches_in_minibatch
-
num_warmup_microbatches
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment