Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-harmonics
Commits
e6b5a952
Commit
e6b5a952
authored
Jul 01, 2025
by
Andrea Paris
Browse files
first patch
parent
5da00de0
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
14 additions
and
12 deletions
+14
-12
examples/depth/train.py
examples/depth/train.py
+14
-12
No files found.
examples/depth/train.py
View file @
e6b5a952
...
...
@@ -138,7 +138,7 @@ def validate_model(
for
metric
in
metrics_fns
:
metric_buff
=
metrics
[
metric
]
metric_fn
=
metrics_fns
[
metric
]
metric_buff
[
idx
]
=
metric_fn
(
prd
,
tar
.
unsqueeze
(
-
3
)
,
mask
)
metric_buff
[
idx
]
=
metric_fn
(
prd
,
tar
,
mask
)
tar
=
(
tar
*
mask
).
squeeze
()
prd
=
(
prd
*
mask
).
squeeze
()
...
...
@@ -257,7 +257,7 @@ def train_model(
# prepare metrics buffer for accumulation of validation metrics
valid_metrics
=
{}
for
metric
in
metrics_fns
:
valid_metrics
[
metric
]
=
torch
.
zeros
(
1
,
dtype
=
torch
.
float32
,
device
=
device
)
valid_metrics
[
metric
]
=
torch
.
zeros
(
2
,
dtype
=
torch
.
float32
,
device
=
device
)
model
.
eval
()
...
...
@@ -287,6 +287,7 @@ def train_model(
metric_buff
=
valid_metrics
[
metric
]
metric_fn
=
metrics_fns
[
metric
]
metric_buff
[
0
]
+=
metric_fn
(
prd
,
tar
,
mask
)
*
inp
.
size
(
0
)
metric_buff
[
1
]
+=
inp
.
size
(
0
)
if
dist
.
is_initialized
():
dist
.
all_reduce
(
valid_loss
)
...
...
@@ -294,8 +295,9 @@ def train_model(
dist
.
all_reduce
(
valid_metrics
[
metric
])
valid_loss
=
(
valid_loss
[
0
]
/
valid_loss
[
1
]).
item
()
for
metric
in
valid_metrics
:
valid_metrics
[
metric
]
=
(
valid_metrics
[
metric
][
0
]
/
valid_
loss
[
1
]).
item
()
valid_metrics
[
metric
]
=
(
valid_metrics
[
metric
][
0
]
/
valid_
metrics
[
metric
]
[
1
]).
item
()
if
scheduler
is
not
None
:
scheduler
.
step
(
valid_loss
)
...
...
@@ -435,16 +437,16 @@ def main(
# specify which models to train here
models
=
[
"transformer_sc2_layers4_e128"
,
"s2transformer_sc2_layers4_e128"
,
"ntransformer_sc2_layers4_e128"
,
#
"transformer_sc2_layers4_e128",
#
"s2transformer_sc2_layers4_e128",
#
"ntransformer_sc2_layers4_e128",
"s2ntransformer_sc2_layers4_e128"
,
"segformer_sc2_layers4_e128"
,
"s2segformer_sc2_layers4_e128"
,
"nsegformer_sc2_layers4_e128"
,
"s2nsegformer_sc2_layers4_e128"
,
"sfno_sc2_layers4_e32"
,
"lsno_sc2_layers4_e32"
,
#
"segformer_sc2_layers4_e128",
#
"s2segformer_sc2_layers4_e128",
#
"nsegformer_sc2_layers4_e128",
#
"s2nsegformer_sc2_layers4_e128",
#
"sfno_sc2_layers4_e32",
#
"lsno_sc2_layers4_e32",
]
models
=
{
k
:
baseline_models
[
k
]
for
k
in
models
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment