Commit 83e08b6f authored by Myle Ott's avatar Myle Ott
Browse files

Fix validation loss

parent bfeb7732
...@@ -266,17 +266,23 @@ class Trainer(object): ...@@ -266,17 +266,23 @@ class Trainer(object):
def valid_step(self, sample): def valid_step(self, sample):
"""Do forward pass in evaluation mode.""" """Do forward pass in evaluation mode."""
self.model.eval()
logging_output, sample_size = {}, 0
with torch.no_grad(): with torch.no_grad():
self.model.eval()
sample = self._prepare_sample(sample) sample = self._prepare_sample(sample)
if sample is None: if sample is None:
sample = self._prepare_sample(self._dummy_batch) sample = self._prepare_sample(self._dummy_batch)
ignore_results = True
else:
ignore_results = False
_loss, sample_size, logging_output = self.task.get_loss( _loss, sample_size, logging_output = self.task.get_loss(
self.model, self.criterion, sample, self.model, self.criterion, sample,
) )
if ignore_results:
logging_output, sample_size = {}, 0
# gather logging outputs from all replicas # gather logging outputs from all replicas
if self.args.distributed_world_size > 1: if self.args.distributed_world_size > 1:
logging_output, sample_size = zip(*distributed_utils.all_gather_list( logging_output, sample_size = zip(*distributed_utils.all_gather_list(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment