Commit 83e08b6f authored by Myle Ott's avatar Myle Ott
Browse files

Fix validation loss

parent bfeb7732
......@@ -266,17 +266,23 @@ class Trainer(object):
def valid_step(self, sample):
"""Do forward pass in evaluation mode."""
with torch.no_grad():
self.model.eval()
logging_output, sample_size = {}, 0
with torch.no_grad():
sample = self._prepare_sample(sample)
if sample is None:
sample = self._prepare_sample(self._dummy_batch)
ignore_results = True
else:
ignore_results = False
_loss, sample_size, logging_output = self.task.get_loss(
self.model, self.criterion, sample,
)
if ignore_results:
logging_output, sample_size = {}, 0
# gather logging outputs from all replicas
if self.args.distributed_world_size > 1:
logging_output, sample_size = zip(*distributed_utils.all_gather_list(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment