Unverified Commit c485a1fb authored by Thorsten Kurth's avatar Thorsten Kurth Committed by GitHub
Browse files

Merge pull request #84 from NVIDIA/depth_small_fix

Small fix in metric computation
parents 5da00de0 3c3b0f8e
...@@ -138,7 +138,7 @@ def validate_model( ...@@ -138,7 +138,7 @@ def validate_model(
for metric in metrics_fns: for metric in metrics_fns:
metric_buff = metrics[metric] metric_buff = metrics[metric]
metric_fn = metrics_fns[metric] metric_fn = metrics_fns[metric]
metric_buff[idx] = metric_fn(prd, tar.unsqueeze(-3), mask) metric_buff[idx] = metric_fn(prd, tar, mask)
tar = (tar * mask).squeeze() tar = (tar * mask).squeeze()
prd = (prd * mask).squeeze() prd = (prd * mask).squeeze()
...@@ -257,7 +257,7 @@ def train_model( ...@@ -257,7 +257,7 @@ def train_model(
# prepare metrics buffer for accumulation of validation metrics # prepare metrics buffer for accumulation of validation metrics
valid_metrics = {} valid_metrics = {}
for metric in metrics_fns: for metric in metrics_fns:
valid_metrics[metric] = torch.zeros(1, dtype=torch.float32, device=device) valid_metrics[metric] = torch.zeros(2, dtype=torch.float32, device=device)
model.eval() model.eval()
...@@ -287,6 +287,7 @@ def train_model( ...@@ -287,6 +287,7 @@ def train_model(
metric_buff = valid_metrics[metric] metric_buff = valid_metrics[metric]
metric_fn = metrics_fns[metric] metric_fn = metrics_fns[metric]
metric_buff[0] += metric_fn(prd, tar, mask) * inp.size(0) metric_buff[0] += metric_fn(prd, tar, mask) * inp.size(0)
metric_buff[1] += inp.size(0)
if dist.is_initialized(): if dist.is_initialized():
dist.all_reduce(valid_loss) dist.all_reduce(valid_loss)
...@@ -294,8 +295,9 @@ def train_model( ...@@ -294,8 +295,9 @@ def train_model(
dist.all_reduce(valid_metrics[metric]) dist.all_reduce(valid_metrics[metric])
valid_loss = (valid_loss[0] / valid_loss[1]).item() valid_loss = (valid_loss[0] / valid_loss[1]).item()
for metric in valid_metrics: for metric in valid_metrics:
valid_metrics[metric] = (valid_metrics[metric][0] / valid_loss[1]).item() valid_metrics[metric] = (valid_metrics[metric][0] / valid_metrics[metric][1]).item()
if scheduler is not None: if scheduler is not None:
scheduler.step(valid_loss) scheduler.step(valid_loss)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment