OpenDAS / apex · Commit cae6005c

Update imagenet example to fast version.

Authored May 25, 2018 by Christian Sarofeen
parent 343590a1
Showing 1 changed file with 40 additions and 19 deletions.
examples/imagenet/main.py (view file @ cae6005c)
@@ -16,14 +16,14 @@ import torchvision.transforms as transforms
 import torchvision.datasets as datasets
 import torchvision.models as models
 
+import numpy as np
+
 try:
     from apex.parallel import DistributedDataParallel as DDP
     from apex.fp16_utils import *
 except ImportError:
     raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this example.")
 
-import numpy as np
-
 model_names = sorted(name for name in models.__dict__
                      if name.islower() and not name.startswith("__")
                      and callable(models.__dict__[name]))
@@ -61,8 +61,8 @@ parser.add_argument('--pretrained', dest='pretrained', action='store_true',
 parser.add_argument('--fp16', action='store_true',
                     help='Run model fp16 mode.')
-parser.add_argument('--static-loss-scale', type=float, default=1,
-                    help='Static loss scale, positive power of 2 values can improve fp16 convergence.')
+parser.add_argument('--loss-scale', type=float, default=1,
+                    help='Loss scaling, positive power of 2 values can improve fp16 convergence.')
 parser.add_argument('--prof', dest='prof', action='store_true',
                     help='Only run 10 iterations for profiling.')
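Why the help text suggests powers of 2: loss scaling guards against fp16 gradient underflow, and a power-of-2 factor changes only the floating-point exponent, so scaling and later unscaling round-trip exactly. A minimal illustration of the underflow problem (not part of the commit):

import torch

g = torch.tensor([1e-8])                    # a gradient too small for fp16
print(g.half())                             # tensor([0.], dtype=torch.float16) -- underflow
scale = 2 ** 16                             # power of 2: shifts the exponent only
print((g * scale).half().float() / scale)   # ~1e-8 recovered after unscaling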
@@ -80,6 +80,26 @@ parser.add_argument('--rank', default=0, type=int,
 cudnn.benchmark = True
 
+import numpy as np
+
+def fast_collate(batch):
+    imgs = [img[0] for img in batch]
+    targets = torch.tensor([target[1] for target in batch], dtype=torch.int64)
+    w = imgs[0].size[0]
+    h = imgs[0].size[1]
+    tensor = torch.zeros( (len(imgs), 3, h, w), dtype=torch.uint8 )
+    for i, img in enumerate(imgs):
+        nump_array = np.asarray(img, dtype=np.uint8)
+        tens = torch.from_numpy(nump_array)
+        if(nump_array.ndim < 3):
+            nump_array = np.expand_dims(nump_array, axis=-1)
+        nump_array = np.rollaxis(nump_array, 2)
+
+        tensor[i] += torch.from_numpy(nump_array)
+
+    return tensor, targets
+
 best_prec1 = 0
 args = parser.parse_args()
 
 def main():
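What fast_collate produces, checked with synthetic data (illustration only; the PIL images stand in for what an ImageFolder dataset yields before ToTensor()):

from PIL import Image
import numpy as np
import torch

batch = [(Image.fromarray(np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8)), i)
         for i in range(4)]
tensor, targets = fast_collate(batch)
print(tensor.shape, tensor.dtype)   # torch.Size([4, 3, 224, 224]) torch.uint8
print(targets)                      # tensor([0, 1, 2, 3])

Keeping the batch as uint8 defers the per-sample float conversion and normalization, which is the point of the "fast version": that work can later run on the GPU instead of in the loader workers.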
@@ -93,18 +113,12 @@ def main():
     if args.distributed:
         torch.cuda.set_device(args.gpu)
         dist.init_process_group(backend=args.dist_backend,
                                 init_method=args.dist_url,
                                 world_size=args.world_size, rank=args.rank)
 
     if args.fp16:
         assert torch.backends.cudnn.enabled, "fp16 mode requires cudnn backend to be enabled."
-    if args.static_loss_scale != 1.0:
-        if not args.fp16:
-            print("Warning: if --fp16 is not used, static_loss_scale will be ignored.")
 
     # create model
     if args.pretrained:
         print("=> using pre-trained model '{}'".format(args.arch))
@@ -154,7 +168,7 @@ def main():
     if(args.arch == "inception_v3"):
         crop_size = 299
-        val_size = 320 # Arbitrarily chosen, adjustable.
+        val_size = 320 # I chose this value arbitrarily, we can adjust.
     else:
         crop_size = 224
         val_size = 256
@@ -164,8 +178,8 @@ def main():
         transforms.Compose([
             transforms.RandomResizedCrop(crop_size),
             transforms.RandomHorizontalFlip(),
-            transforms.ToTensor(),
-            normalize,
+            # transforms.ToTensor(), Too slow
+            # normalize,
         ]))
 
     if args.distributed:
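With ToTensor() and normalize disabled here, the uint8 batches from fast_collate still need conversion and normalization somewhere, presumably on the GPU. A hedged sketch of that step (gpu_normalize is a hypothetical helper; the mean/std values are the standard ImageNet statistics rescaled to the 0-255 range; none of this code appears in the diff):

import torch

mean = torch.tensor([0.485, 0.456, 0.406]).cuda().view(1, 3, 1, 1) * 255
std  = torch.tensor([0.229, 0.224, 0.225]).cuda().view(1, 3, 1, 1) * 255

def gpu_normalize(batch_uint8):
    # uint8 NCHW batch from fast_collate -> normalized float tensor on the GPU
    x = batch_uint8.cuda(non_blocking=True).float()
    return x.sub_(mean).div_(std)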
@@ -175,7 +189,7 @@ def main():
     train_loader = torch.utils.data.DataLoader(
         train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
-        num_workers=args.workers, pin_memory=True, sampler=train_sampler)
+        num_workers=args.workers, pin_memory=True, sampler=train_sampler, collate_fn=fast_collate)
 
     val_loader = torch.utils.data.DataLoader(
         datasets.ImageFolder(valdir, transforms.Compose([
@@ -215,6 +229,13 @@ def main():
             'optimizer' : optimizer.state_dict(),
         }, is_best)
 
+# item() is a recent addition, so this helps with backward compatibility.
+def to_python_float(t):
+    if hasattr(t, 'item'):
+        return t.item()
+    else:
+        return t[0]
+
 class data_prefetcher():
     def __init__(self, loader):
         self.loader = iter(loader)
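The helper exists because .item() arrived in PyTorch 0.4; before that, a scalar loss was a 1-element tensor and t[0] extracted the value, while on newer versions indexing a 0-dim tensor raises. Usage (illustration only):

import torch

loss = torch.tensor(0.25)
print(to_python_float(loss))   # 0.25 via .item() on current PyTorch, t[0] on older releases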
@@ -284,15 +305,15 @@ def train(train_loader, model, criterion, optimizer, epoch):
         top1.update(to_python_float(prec1), input.size(0))
         top5.update(to_python_float(prec5), input.size(0))
 
+        loss = loss*args.loss_scale
         # compute gradient and do SGD step
         if args.fp16:
-            loss = loss*args.static_loss_scale
             model.zero_grad()
             loss.backward()
             model_grads_to_master_grads(model_params, master_params)
-            if args.static_loss_scale != 1:
+            if args.loss_scale != 1:
                 for param in master_params:
-                    param.grad.data = param.grad.data/args.static_loss_scale
+                    param.grad.data = param.grad.data/args.loss_scale
             optimizer.step()
             master_params_to_model_params(model_params, master_params)
         else:
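For context, the hunk above sits inside apex's fp16 "master weights" recipe: the model computes in fp16 while an fp32 copy of the parameters receives the unscaled gradients and is what the optimizer updates. A condensed, hedged sketch of one step using the apex.fp16_utils helpers visible in the diff (prep_param_lists and the setup lines are assumptions for illustration, not shown in this commit; model, criterion, input, target, and args come from the surrounding script):

from apex.fp16_utils import (prep_param_lists, model_grads_to_master_grads,
                             master_params_to_model_params)
import torch

model = model.cuda().half()                               # fp16 forward/backward
model_params, master_params = prep_param_lists(model)     # paired fp16/fp32 params
optimizer = torch.optim.SGD(master_params, lr=0.1)        # optimizer sees fp32 only

loss = criterion(model(input), target) * args.loss_scale  # scale before backward
model.zero_grad()
loss.backward()                                           # fp16 grads on model_params
model_grads_to_master_grads(model_params, master_params)  # copy grads into fp32
if args.loss_scale != 1:
    for p in master_params:
        p.grad.data = p.grad.data / args.loss_scale       # unscale in fp32
optimizer.step()                                          # update fp32 masters
master_params_to_model_params(model_params, master_params)  # write back to fp16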