Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-harmonics
Commits
c485a1fb
Unverified
Commit
c485a1fb
authored
Jul 01, 2025
by
Thorsten Kurth
Committed by
GitHub
Jul 01, 2025
Browse files
Merge pull request #84 from NVIDIA/depth_small_fix
Small fix in metric computation
parents
5da00de0
3c3b0f8e
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
5 additions
and
3 deletions
+5
-3
examples/depth/train.py
examples/depth/train.py
+5
-3
No files found.
examples/depth/train.py
View file @
c485a1fb
...
@@ -138,7 +138,7 @@ def validate_model(
...
@@ -138,7 +138,7 @@ def validate_model(
for
metric
in
metrics_fns
:
for
metric
in
metrics_fns
:
metric_buff
=
metrics
[
metric
]
metric_buff
=
metrics
[
metric
]
metric_fn
=
metrics_fns
[
metric
]
metric_fn
=
metrics_fns
[
metric
]
metric_buff
[
idx
]
=
metric_fn
(
prd
,
tar
.
unsqueeze
(
-
3
)
,
mask
)
metric_buff
[
idx
]
=
metric_fn
(
prd
,
tar
,
mask
)
tar
=
(
tar
*
mask
).
squeeze
()
tar
=
(
tar
*
mask
).
squeeze
()
prd
=
(
prd
*
mask
).
squeeze
()
prd
=
(
prd
*
mask
).
squeeze
()
...
@@ -257,7 +257,7 @@ def train_model(
...
@@ -257,7 +257,7 @@ def train_model(
# prepare metrics buffer for accumulation of validation metrics
# prepare metrics buffer for accumulation of validation metrics
valid_metrics
=
{}
valid_metrics
=
{}
for
metric
in
metrics_fns
:
for
metric
in
metrics_fns
:
valid_metrics
[
metric
]
=
torch
.
zeros
(
1
,
dtype
=
torch
.
float32
,
device
=
device
)
valid_metrics
[
metric
]
=
torch
.
zeros
(
2
,
dtype
=
torch
.
float32
,
device
=
device
)
model
.
eval
()
model
.
eval
()
...
@@ -287,6 +287,7 @@ def train_model(
...
@@ -287,6 +287,7 @@ def train_model(
metric_buff
=
valid_metrics
[
metric
]
metric_buff
=
valid_metrics
[
metric
]
metric_fn
=
metrics_fns
[
metric
]
metric_fn
=
metrics_fns
[
metric
]
metric_buff
[
0
]
+=
metric_fn
(
prd
,
tar
,
mask
)
*
inp
.
size
(
0
)
metric_buff
[
0
]
+=
metric_fn
(
prd
,
tar
,
mask
)
*
inp
.
size
(
0
)
metric_buff
[
1
]
+=
inp
.
size
(
0
)
if
dist
.
is_initialized
():
if
dist
.
is_initialized
():
dist
.
all_reduce
(
valid_loss
)
dist
.
all_reduce
(
valid_loss
)
...
@@ -294,8 +295,9 @@ def train_model(
...
@@ -294,8 +295,9 @@ def train_model(
dist
.
all_reduce
(
valid_metrics
[
metric
])
dist
.
all_reduce
(
valid_metrics
[
metric
])
valid_loss
=
(
valid_loss
[
0
]
/
valid_loss
[
1
]).
item
()
valid_loss
=
(
valid_loss
[
0
]
/
valid_loss
[
1
]).
item
()
for
metric
in
valid_metrics
:
for
metric
in
valid_metrics
:
valid_metrics
[
metric
]
=
(
valid_metrics
[
metric
][
0
]
/
valid_
loss
[
1
]).
item
()
valid_metrics
[
metric
]
=
(
valid_metrics
[
metric
][
0
]
/
valid_
metrics
[
metric
]
[
1
]).
item
()
if
scheduler
is
not
None
:
if
scheduler
is
not
None
:
scheduler
.
step
(
valid_loss
)
scheduler
.
step
(
valid_loss
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment