Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Megatron-LM
Commits
a84a5fa0
Commit
a84a5fa0
authored
Dec 03, 2020
by
Jared Casper
Committed by
Deepak Narayanan
Dec 19, 2020
Browse files
Make an eval iteration the same number of samples as a training iteration
parent
2cf1d6d0
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
26 additions
and
25 deletions
+26
-25
megatron/training.py
megatron/training.py
+26
-25
No files found.
megatron/training.py
View file @
a84a5fa0
...
@@ -761,30 +761,31 @@ def evaluate(forward_step_func, data_iterator, model, verbose=False):
...
@@ -761,30 +761,31 @@ def evaluate(forward_step_func, data_iterator, model, verbose=False):
print_rank_0
(
'Evaluating iter {}/{}'
.
format
(
iteration
,
print_rank_0
(
'Evaluating iter {}/{}'
.
format
(
iteration
,
args
.
eval_iters
))
args
.
eval_iters
))
if
not
mpu
.
is_pipeline_first_stage
():
for
_
in
range
(
args
.
num_microbatches_in_minibatch
):
input_tensor
,
_
=
communicate
(
if
not
mpu
.
is_pipeline_first_stage
():
tensor_send_next
=
None
,
input_tensor
,
_
=
communicate
(
tensor_send_prev
=
None
,
tensor_send_next
=
None
,
recv_forward
=
True
,
tensor_send_prev
=
None
,
recv_backward
=
False
)
recv_forward
=
True
,
else
:
recv_backward
=
False
)
input_tensor
=
None
else
:
input_tensor
=
None
# Forward evaluation.
output_tensor
=
forward_step_func
(
data_iterator
,
model
,
input_tensor
)
# Forward evaluation.
output_tensor
=
forward_step_func
(
data_iterator
,
model
,
input_tensor
)
if
mpu
.
is_pipeline_last_stage
():
_
,
loss_dict
=
output_tensor
if
mpu
.
is_pipeline_last_stage
():
# Reduce across processes.
_
,
loss_dict
=
output_tensor
for
key
in
loss_dict
:
# Reduce across processes.
total_loss_dict
[
key
]
=
total_loss_dict
.
get
(
key
,
torch
.
cuda
.
FloatTensor
([
0.0
]))
+
\
for
key
in
loss_dict
:
loss_dict
[
key
]
total_loss_dict
[
key
]
=
total_loss_dict
.
get
(
key
,
torch
.
cuda
.
FloatTensor
([
0.0
]))
+
\
else
:
loss_dict
[
key
]
communicate
(
else
:
tensor_send_next
=
output_tensor
,
communicate
(
tensor_send_prev
=
None
,
tensor_send_next
=
output_tensor
,
recv_forward
=
False
,
tensor_send_prev
=
None
,
recv_backward
=
False
)
recv_forward
=
False
,
recv_backward
=
False
)
args
.
consumed_valid_samples
+=
mpu
.
get_data_parallel_world_size
()
\
args
.
consumed_valid_samples
+=
mpu
.
get_data_parallel_world_size
()
\
*
args
.
batch_size
\
*
args
.
batch_size
\
...
@@ -793,7 +794,7 @@ def evaluate(forward_step_func, data_iterator, model, verbose=False):
...
@@ -793,7 +794,7 @@ def evaluate(forward_step_func, data_iterator, model, verbose=False):
model
.
train
()
model
.
train
()
for
key
in
total_loss_dict
:
for
key
in
total_loss_dict
:
total_loss_dict
[
key
]
/=
args
.
eval_iters
total_loss_dict
[
key
]
/=
args
.
eval_iters
*
args
.
num_microbatches_in_minibatch
return
total_loss_dict
return
total_loss_dict
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment