Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
apex
Commits
a3a09c8c
Commit
a3a09c8c
authored
Mar 08, 2019
by
Michael Carilli
Browse files
Fix for #188
parent
371633d5
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
6 additions
and
3 deletions
+6
-3
apex/amp/frontend.py
apex/amp/frontend.py
+1
-1
examples/imagenet/main_amp.py
examples/imagenet/main_amp.py
+2
-2
tests/L1/common/main_amp.py
tests/L1/common/main_amp.py
+3
-0
No files found.
apex/amp/frontend.py
View file @
a3a09c8c
...
@@ -269,7 +269,7 @@ def initialize(
...
@@ -269,7 +269,7 @@ def initialize(
https://github.com/NVIDIA/apex/tree/master/examples/imagenet
https://github.com/NVIDIA/apex/tree/master/examples/imagenet
"""
"""
_amp_state
.
opt_properties
=
Properties
()
_amp_state
.
opt_properties
=
Properties
()
_amp_state
.
opt_properties
.
verbosity
=
verbosity
_amp_state
.
verbosity
=
verbosity
if
not
enabled
:
if
not
enabled
:
return
models
,
optimizers
return
models
,
optimizers
...
...
examples/imagenet/main_amp.py
View file @
a3a09c8c
...
@@ -334,8 +334,6 @@ def train(train_loader, model, criterion, optimizer, epoch):
...
@@ -334,8 +334,6 @@ def train(train_loader, model, criterion, optimizer, epoch):
optimizer
.
step
()
optimizer
.
step
()
if
args
.
prof
:
torch
.
cuda
.
nvtx
.
range_pop
()
if
args
.
prof
:
torch
.
cuda
.
nvtx
.
range_pop
()
input
,
target
=
prefetcher
.
next
()
if
i
%
args
.
print_freq
==
0
:
if
i
%
args
.
print_freq
==
0
:
# Every print_freq iterations, check the loss accuracy and speed.
# Every print_freq iterations, check the loss accuracy and speed.
# For best performance, it doesn't make sense to print these metrics every
# For best performance, it doesn't make sense to print these metrics every
...
@@ -374,6 +372,8 @@ def train(train_loader, model, criterion, optimizer, epoch):
...
@@ -374,6 +372,8 @@ def train(train_loader, model, criterion, optimizer, epoch):
batch_time
=
batch_time
,
batch_time
=
batch_time
,
loss
=
losses
,
top1
=
top1
,
top5
=
top5
))
loss
=
losses
,
top1
=
top1
,
top5
=
top5
))
input
,
target
=
prefetcher
.
next
()
def
validate
(
val_loader
,
model
,
criterion
):
def
validate
(
val_loader
,
model
,
criterion
):
batch_time
=
AverageMeter
()
batch_time
=
AverageMeter
()
...
...
tests/L1/common/main_amp.py
View file @
a3a09c8c
...
@@ -365,6 +365,9 @@ def train(train_loader, model, criterion, optimizer, epoch):
...
@@ -365,6 +365,9 @@ def train(train_loader, model, criterion, optimizer, epoch):
batch_time
.
update
(
time
.
time
()
-
end
)
batch_time
.
update
(
time
.
time
()
-
end
)
end
=
time
.
time
()
end
=
time
.
time
()
# If you decide to refactor this test, like examples/imagenet, to sample the loss every
# print_freq iterations, make sure to move this prefetching below the accuracy calculation.
input
,
target
=
prefetcher
.
next
()
input
,
target
=
prefetcher
.
next
()
if
i
%
args
.
print_freq
==
0
and
i
>
1
:
if
i
%
args
.
print_freq
==
0
and
i
>
1
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment