OpenDAS / apex · Commits · e86f986d

Commit e86f986d, authored Jul 26, 2019 by Syed Tousif Ahmed

Put parser in a function to make script importable

parent 8d0deb09
Changes: 1 changed file with 63 additions and 60 deletions

examples/imagenet/main_amp.py  (+63, -60)
examples/imagenet/main_amp.py (view file @ e86f986d)
@@ -25,89 +25,92 @@ try:
 except ImportError:
     raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this example.")
 
-model_names = sorted(name for name in models.__dict__
-                     if name.islower() and not name.startswith("__")
-                     and callable(models.__dict__[name]))
+def fast_collate(batch):
+    imgs = [img[0] for img in batch]
+    targets = torch.tensor([target[1] for target in batch], dtype=torch.int64)
+    w = imgs[0].size[0]
+    h = imgs[0].size[1]
+    tensor = torch.zeros( (len(imgs), 3, h, w), dtype=torch.uint8 )
+    for i, img in enumerate(imgs):
+        nump_array = np.asarray(img, dtype=np.uint8)
+        if(nump_array.ndim < 3):
+            nump_array = np.expand_dims(nump_array, axis=-1)
+        nump_array = np.rollaxis(nump_array, 2)
+        tensor[i] += torch.from_numpy(nump_array)
+    return tensor, targets
 
-parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
-parser.add_argument('data', metavar='DIR',
-                    help='path to dataset')
-parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet18',
-                    choices=model_names,
-                    help='model architecture: ' +
-                    ' | '.join(model_names) +
-                    ' (default: resnet18)')
-parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
-                    help='number of data loading workers (default: 4)')
-parser.add_argument('--epochs', default=90, type=int, metavar='N',
-                    help='number of total epochs to run')
-parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
-                    help='manual epoch number (useful on restarts)')
-parser.add_argument('-b', '--batch-size', default=256, type=int,
-                    metavar='N', help='mini-batch size per process (default: 256)')
-parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
-                    metavar='LR', help='Initial learning rate. Will be scaled by <global batch size>/256: args.lr = args.lr*float(args.batch_size*args.world_size)/256. A warmup schedule will also be applied over the first 5 epochs.')
-parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
-                    help='momentum')
-parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
-                    metavar='W', help='weight decay (default: 1e-4)')
-parser.add_argument('--print-freq', '-p', default=10, type=int,
-                    metavar='N', help='print frequency (default: 10)')
-parser.add_argument('--resume', default='', type=str, metavar='PATH',
-                    help='path to latest checkpoint (default: none)')
-parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
-                    help='evaluate model on validation set')
-parser.add_argument('--pretrained', dest='pretrained', action='store_true',
-                    help='use pre-trained model')
-parser.add_argument('--prof', default=-1, type=int,
-                    help='Only run 10 iterations for profiling.')
-parser.add_argument('--deterministic', action='store_true')
-parser.add_argument("--local_rank", default=0, type=int)
-parser.add_argument('--sync_bn', action='store_true',
-                    help='enabling apex sync BN.')
-parser.add_argument('--opt-level', type=str)
-parser.add_argument('--keep-batchnorm-fp32', type=str, default=None)
-parser.add_argument('--loss-scale', type=str, default=None)
-
-cudnn.benchmark = True
-
-def fast_collate(batch):
-    imgs = [img[0] for img in batch]
-    targets = torch.tensor([target[1] for target in batch], dtype=torch.int64)
-    w = imgs[0].size[0]
-    h = imgs[0].size[1]
-    tensor = torch.zeros( (len(imgs), 3, h, w), dtype=torch.uint8 )
-    for i, img in enumerate(imgs):
-        nump_array = np.asarray(img, dtype=np.uint8)
-        if(nump_array.ndim < 3):
-            nump_array = np.expand_dims(nump_array, axis=-1)
-        nump_array = np.rollaxis(nump_array, 2)
-        tensor[i] += torch.from_numpy(nump_array)
-    return tensor, targets
-
-best_prec1 = 0
-args = parser.parse_args()
+def parse():
+    model_names = sorted(name for name in models.__dict__
+                         if name.islower() and not name.startswith("__")
+                         and callable(models.__dict__[name]))
+
+    parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
+    parser.add_argument('data', metavar='DIR',
+                        help='path to dataset')
+    parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet18',
+                        choices=model_names,
+                        help='model architecture: ' +
+                        ' | '.join(model_names) +
+                        ' (default: resnet18)')
+    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
+                        help='number of data loading workers (default: 4)')
+    parser.add_argument('--epochs', default=90, type=int, metavar='N',
+                        help='number of total epochs to run')
+    parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
+                        help='manual epoch number (useful on restarts)')
+    parser.add_argument('-b', '--batch-size', default=256, type=int,
+                        metavar='N', help='mini-batch size per process (default: 256)')
+    parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
+                        metavar='LR', help='Initial learning rate. Will be scaled by <global batch size>/256: args.lr = args.lr*float(args.batch_size*args.world_size)/256. A warmup schedule will also be applied over the first 5 epochs.')
+    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
+                        help='momentum')
+    parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
+                        metavar='W', help='weight decay (default: 1e-4)')
+    parser.add_argument('--print-freq', '-p', default=10, type=int,
+                        metavar='N', help='print frequency (default: 10)')
+    parser.add_argument('--resume', default='', type=str, metavar='PATH',
+                        help='path to latest checkpoint (default: none)')
+    parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
+                        help='evaluate model on validation set')
+    parser.add_argument('--pretrained', dest='pretrained', action='store_true',
+                        help='use pre-trained model')
+    parser.add_argument('--prof', default=-1, type=int,
+                        help='Only run 10 iterations for profiling.')
+    parser.add_argument('--deterministic', action='store_true')
+    parser.add_argument("--local_rank", default=0, type=int)
+    parser.add_argument('--sync_bn', action='store_true',
+                        help='enabling apex sync BN.')
+    parser.add_argument('--opt-level', type=str)
+    parser.add_argument('--keep-batchnorm-fp32', type=str, default=None)
+    parser.add_argument('--loss-scale', type=str, default=None)
+    args = parser.parse_args()
+    return args
 
 def main():
     global best_prec1, args
 
-    print("opt_level = {}".format(args.opt_level))
-    print("keep_batchnorm_fp32 = {}".format(args.keep_batchnorm_fp32), type(args.keep_batchnorm_fp32))
-    print("loss_scale = {}".format(args.loss_scale), type(args.loss_scale))
-
-    print("\nCUDNN VERSION: {}\n".format(torch.backends.cudnn.version()))
-
+    args = parse()
+    print("opt_level = {}".format(args.opt_level))
+    print("keep_batchnorm_fp32 = {}".format(args.keep_batchnorm_fp32), type(args.keep_batchnorm_fp32))
+    print("loss_scale = {}".format(args.loss_scale), type(args.loss_scale))
+    print("\nCUDNN VERSION: {}\n".format(torch.backends.cudnn.version()))
+
+    cudnn.benchmark = True
+    best_prec1 = 0
     if args.deterministic:
         cudnn.benchmark = False
         cudnn.deterministic = True
         torch.manual_seed(args.local_rank)
         torch.set_printoptions(precision=10)
 
     args.distributed = False
     if 'WORLD_SIZE' in os.environ:
         args.distributed = int(os.environ['WORLD_SIZE']) > 1
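Why this change matters: before this commit, args = parser.parse_args() ran at module scope, so importing main_amp from another script or a notebook would parse that process's own sys.argv and typically fail on the required positional 'data' argument. With the parser wrapped in parse(), importing the module only defines functions. Below is a minimal sketch of programmatic reuse; the caller script and the argument values are hypothetical, and it assumes main_amp.py is on the import path with torch and apex installed.

import sys
import main_amp   # safe now: importing no longer calls parser.parse_args()

# parse() still reads sys.argv[1:] internally, so substitute a controlled
# argv before calling it. All values below are examples, not defaults.
sys.argv = ["main_amp.py", "/datasets/imagenet",        # positional 'data'
            "--arch", "resnet50", "--batch-size", "128",
            "--opt-level", "O1"]
args = main_amp.parse()
print(args.arch, args.batch_size, args.opt_level)       # resnet50 128 O1

As a worked instance of the --lr help text: launching 8 processes (world_size 8) each with --batch-size 128 gives args.lr = 0.1 * (128 * 8) / 256 = 0.4 before the 5-epoch warmup.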
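The relocated fast_collate is importable as well. It packs a batch of (PIL image, class index) pairs into a uint8 NCHW tensor, skipping transforms.ToTensor() so float conversion and normalization can happen later on the GPU. A sketch of wiring it into a DataLoader, in the spirit of how this script uses it; the dataset path is a placeholder.

import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from main_amp import fast_collate

# Stop the transform pipeline before ToTensor(): fast_collate expects PIL
# images and keeps pixels as uint8. Fixed-size crops matter, since
# fast_collate sizes the output tensor from the first image in the batch.
train_dataset = datasets.ImageFolder(
    "/datasets/imagenet/train",                       # placeholder path
    transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
    ]))

train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=256, shuffle=True,
    num_workers=4, pin_memory=True, collate_fn=fast_collate)

images, targets = next(iter(train_loader))
# images: torch.uint8, shape (256, 3, 224, 224); targets: torch.int64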