Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
c1f1e226
Unverified
Commit
c1f1e226
authored
Sep 14, 2021
by
Prabhat Roy
Committed by
GitHub
Sep 14, 2021
Browse files
Added update_parameters to EMA to fix calculation (#4406)
parent
9fa689b2
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
11 additions
and
0 deletions
+11
-0
references/classification/utils.py
references/classification/utils.py
+11
-0
No files found.
references/classification/utils.py
View file @
c1f1e226
...
@@ -172,6 +172,17 @@ class ExponentialMovingAverage(torch.optim.swa_utils.AveragedModel):
...
@@ -172,6 +172,17 @@ class ExponentialMovingAverage(torch.optim.swa_utils.AveragedModel):
decay
*
avg_model_param
+
(
1
-
decay
)
*
model_param
)
decay
*
avg_model_param
+
(
1
-
decay
)
*
model_param
)
super
().
__init__
(
model
,
device
,
ema_avg
)
super
().
__init__
(
model
,
device
,
ema_avg
)
def
update_parameters
(
self
,
model
):
for
p_swa
,
p_model
in
zip
(
self
.
module
.
state_dict
().
values
(),
model
.
state_dict
().
values
()):
device
=
p_swa
.
device
p_model_
=
p_model
.
detach
().
to
(
device
)
if
self
.
n_averaged
==
0
:
p_swa
.
detach
().
copy_
(
p_model_
)
else
:
p_swa
.
detach
().
copy_
(
self
.
avg_fn
(
p_swa
.
detach
(),
p_model_
,
self
.
n_averaged
.
to
(
device
)))
self
.
n_averaged
+=
1
def
accuracy
(
output
,
target
,
topk
=
(
1
,)):
def
accuracy
(
output
,
target
,
topk
=
(
1
,)):
"""Computes the accuracy over the k top predictions for the specified values of k"""
"""Computes the accuracy over the k top predictions for the specified values of k"""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment