Commit a84a5fa0 authored by Jared Casper, committed by Deepak Narayanan
Browse files

Make an eval iteration the same number of samples as a training iteration

parent 2cf1d6d0
...@@ -761,30 +761,31 @@ def evaluate(forward_step_func, data_iterator, model, verbose=False): ...@@ -761,30 +761,31 @@ def evaluate(forward_step_func, data_iterator, model, verbose=False):
print_rank_0('Evaluating iter {}/{}'.format(iteration, print_rank_0('Evaluating iter {}/{}'.format(iteration,
args.eval_iters)) args.eval_iters))
if not mpu.is_pipeline_first_stage(): for _ in range(args.num_microbatches_in_minibatch):
input_tensor, _ = communicate( if not mpu.is_pipeline_first_stage():
tensor_send_next=None, input_tensor, _ = communicate(
tensor_send_prev=None, tensor_send_next=None,
recv_forward=True, tensor_send_prev=None,
recv_backward=False) recv_forward=True,
else: recv_backward=False)
input_tensor = None else:
input_tensor = None
# Forward evaluation.
output_tensor = forward_step_func(data_iterator, model, input_tensor) # Forward evaluation.
output_tensor = forward_step_func(data_iterator, model, input_tensor)
if mpu.is_pipeline_last_stage():
_, loss_dict = output_tensor if mpu.is_pipeline_last_stage():
# Reduce across processes. _, loss_dict = output_tensor
for key in loss_dict: # Reduce across processes.
total_loss_dict[key] = total_loss_dict.get(key, torch.cuda.FloatTensor([0.0])) + \ for key in loss_dict:
loss_dict[key] total_loss_dict[key] = total_loss_dict.get(key, torch.cuda.FloatTensor([0.0])) + \
else: loss_dict[key]
communicate( else:
tensor_send_next=output_tensor, communicate(
tensor_send_prev=None, tensor_send_next=output_tensor,
recv_forward=False, tensor_send_prev=None,
recv_backward=False) recv_forward=False,
recv_backward=False)
args.consumed_valid_samples += mpu.get_data_parallel_world_size() \ args.consumed_valid_samples += mpu.get_data_parallel_world_size() \
* args.batch_size \ * args.batch_size \
...@@ -793,7 +794,7 @@ def evaluate(forward_step_func, data_iterator, model, verbose=False): ...@@ -793,7 +794,7 @@ def evaluate(forward_step_func, data_iterator, model, verbose=False):
model.train() model.train()
for key in total_loss_dict: for key in total_loss_dict:
total_loss_dict[key] /= args.eval_iters total_loss_dict[key] /= args.eval_iters * args.num_microbatches_in_minibatch
return total_loss_dict return total_loss_dict
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment