OpenDAS / Megatron-LM / Commits

Commit c671de3e
Authored Nov 12, 2020 by Deepak Narayanan

    Move division of loss tensor by number of microbatches to training.py

Parent: 69a546be
Showing 3 changed files with 11 additions and 10 deletions:

    megatron/training.py    +9   -6
    pretrain_bert.py        +1   -2
    pretrain_gpt2.py        +1   -2
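The change centralizes microbatch loss scaling: each pretraining script's forward_step now returns the plain masked-average loss, and training.py divides by the number of microbatches once before backpropagation. The following is a minimal, generic PyTorch sketch of that gradient-accumulation pattern, not the Megatron-LM code itself (model, loss_fn, microbatches, and optimizer are placeholder names); it only illustrates why scaling each microbatch loss by 1/num_microbatches gives the same gradients as averaging the loss over the whole minibatch:

    def accumulate_gradients(model, loss_fn, microbatches, optimizer):
        """Gradient accumulation with per-microbatch loss scaling.

        Scaling each microbatch loss by 1/num_microbatches before backward()
        makes the accumulated gradient equal to the gradient of the mean loss
        over the whole minibatch, so each forward step can return its raw loss.
        """
        num_microbatches = len(microbatches)
        optimizer.zero_grad()
        reported = []
        for inputs, targets in microbatches:
            loss = loss_fn(model(inputs), targets)   # raw per-microbatch loss
            (loss / num_microbatches).backward()     # scale once, centrally
            reported.append(loss.detach())
        optimizer.step()
        # For logging, report the mean loss over microbatches.
        return sum(reported) / num_microbatches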
megatron/training.py

@@ -294,6 +294,8 @@ def backward_step(optimizer, model, input_tensor, output_tensor, output_tensor_g
 def forward_step_with_communication(forward_step_func, data_iterator, model,
                                     input_tensors, output_tensors,
                                     losses_reduced, timers):
+    args = get_args()
     if not mpu.is_pipeline_first_stage():
         timers('forward-recv').start()
         input_tensor, _ = communicate(
...
@@ -312,7 +314,7 @@ def forward_step_with_communication(forward_step_func, data_iterator, model,
     if mpu.is_pipeline_last_stage():
         loss, loss_reduced = output_tensor
-        output_tensor = loss
+        output_tensor = loss / args.num_microbatches_in_minibatch
         losses_reduced.append(loss_reduced)
     else:
         timers('forward-send').start()
...
@@ -328,7 +330,6 @@ def forward_step_with_communication(forward_step_func, data_iterator, model,
 def backward_step_with_communication(optimizer, model, input_tensors,
                                      output_tensors, timers):
     """Backward step."""
     input_tensor = input_tensors.pop(0)
     output_tensor = output_tensors.pop(0)
...
@@ -364,6 +365,8 @@ def forward_and_backward_steps_with_communication(forward_step_func, data_iterat
                                                   input_tensor, last_microbatch,
                                                   input_tensors, output_tensors,
                                                   losses_reduced, timers):
+    args = get_args()
     # Forward model for one step.
     timers('forward-compute').start()
     output_tensor = forward_step_func(data_iterator, model, input_tensor)
...
@@ -371,7 +374,7 @@ def forward_and_backward_steps_with_communication(forward_step_func, data_iterat
     if mpu.is_pipeline_last_stage():
         loss, loss_reduced = output_tensor
-        output_tensor = loss
+        output_tensor = loss / args.num_microbatches_in_minibatch
         output_tensor_grad = None
         losses_reduced.append(loss_reduced)
     else:
...
@@ -418,7 +421,7 @@ def forward_backward_no_pipelining(forward_step_func, data_iterator, model,
     for i in range(args.num_microbatches_in_minibatch):
         timers('forward-compute').start()
         loss, loss_reduced = forward_step_func(data_iterator, model,
                                                input_tensor=None)
-        output_tensor = loss
+        output_tensor = loss / args.num_microbatches_in_minibatch
         losses_reduced.append(loss_reduced)
         timers('forward-compute').stop()
...
@@ -571,7 +574,7 @@ def train_step(forward_step_func, data_iterator,
         loss_reduced = {}
         for key in losses_reduced[0]:
             losses_reduced_for_key = [x[key] for x in losses_reduced]
-            loss_reduced[key] = sum(losses_reduced_for_key)
+            loss_reduced[key] = sum(losses_reduced_for_key) / len(losses_reduced_for_key)
         return loss_reduced, skipped_iter
     return {}, skipped_iter
...
@@ -770,7 +773,7 @@ def evaluate(forward_step_func, data_iterator, model, verbose=False):
                 _, loss_dict = output_tensor
                 # Reduce across processes.
                 for key in loss_dict:
-                    total_loss_dict[key] = total_loss_dict.get(key, 0.) + \
+                    total_loss_dict[key] = total_loss_dict.get(key, torch.cuda.FloatTensor([0.0])) + \
                         loss_dict[key]
             else:
                 communicate(
...
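Besides scaling the loss tensor used for backpropagation, train_step now averages (rather than sums) the per-key reduced losses it returns for logging. A small made-up example of the difference, with four microbatches:

    # Hypothetical reduced losses for one key (e.g. 'lm loss') over 4 microbatches:
    losses_reduced_for_key = [2.0, 2.2, 1.8, 2.0]

    old_reported = sum(losses_reduced_for_key)                                # 8.0 (before this commit)
    new_reported = sum(losses_reduced_for_key) / len(losses_reduced_for_key)  # 2.0 (after this commit)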
pretrain_bert.py

@@ -118,8 +118,7 @@ def forward_step(data_iterator, model, input_tensor):
         lm_loss_ = lm_loss_.float()
         loss_mask = loss_mask.float()
         lm_loss = torch.sum(
-            lm_loss_.view(-1) * loss_mask.reshape(-1)) / (
-                loss_mask.sum() * args.num_microbatches_in_minibatch)
+            lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
         loss = lm_loss + sop_loss
...
pretrain_gpt2.py

@@ -110,8 +110,7 @@ def forward_step(data_iterator, model, input_tensor):
     if mpu.is_pipeline_last_stage():
         losses = output_tensor.float()
         loss_mask = loss_mask.view(-1).float()
-        loss = torch.sum(losses.view(-1) * loss_mask) / (
-            loss_mask.sum() * args.num_microbatches_in_minibatch)
+        loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()

         # Reduce loss for logging.
         averaged_loss = average_losses_across_data_parallel_group([loss])
...
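For reference, a self-contained toy example of the masked token-loss average that forward_step now returns in both pretraining scripts (tensor values are made up and not from the repository); the division by num_microbatches_in_minibatch is applied afterwards in training.py:

    import torch

    losses = torch.tensor([[1.0, 2.0], [3.0, 4.0]])     # per-token losses
    loss_mask = torch.tensor([[1.0, 1.0], [1.0, 0.0]])  # 1 = real token, 0 = padding

    loss = torch.sum(losses.view(-1) * loss_mask.view(-1)) / loss_mask.sum()
    print(loss)  # (1 + 2 + 3) / 3 = 2.0; no 1/num_microbatches factor here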