Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
apex
Commits
2ca894da
"git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "3a2631ba0ffb07a8b1ea53636224cf0cd8d26949"
Commit
2ca894da
authored
Jan 27, 2020
by
Vitaly Fedyunin
Committed by
mcarilli
Jan 27, 2020
Browse files
Channels last support (#668)
parent
b66ffc1d
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
20 additions
and
14 deletions
+20
-14
examples/imagenet/main_amp.py
examples/imagenet/main_amp.py
+20
-14
No files found.
examples/imagenet/main_amp.py
View file @
2ca894da
...
@@ -25,21 +25,19 @@ try:
...
@@ -25,21 +25,19 @@ try:
except
ImportError
:
except
ImportError
:
raise
ImportError
(
"Please install apex from https://www.github.com/nvidia/apex to run this example."
)
raise
ImportError
(
"Please install apex from https://www.github.com/nvidia/apex to run this example."
)
def fast_collate(batch, memory_format=torch.contiguous_format):
    """Collate a batch of (image, label) pairs into uint8 tensors.

    Faster than the default DataLoader collate: it skips the per-image
    ToTensor()/float conversion and packs the raw uint8 pixel data directly,
    deferring normalization to the GPU-side prefetcher.

    Args:
        batch: sequence of ``(img, target)`` pairs. Each ``img`` exposes a
            PIL-style ``.size`` (width, height) tuple and is convertible via
            ``np.asarray`` — presumably PIL Images; confirm against the
            dataset's transform pipeline.
        memory_format: layout for the output image tensor, e.g.
            ``torch.channels_last``. Defaults to ``torch.contiguous_format``
            so callers that predate this parameter keep working unchanged.

    Returns:
        tuple: ``(tensor, targets)`` where ``tensor`` is an ``(N, 3, H, W)``
        uint8 tensor in the requested memory format and ``targets`` is an
        int64 tensor of labels.
    """
    imgs = [img[0] for img in batch]
    targets = torch.tensor([target[1] for target in batch], dtype=torch.int64)
    # All images in the batch are assumed to share the first image's size
    # (guaranteed here by the Resize/Crop transforms upstream).
    w = imgs[0].size[0]
    h = imgs[0].size[1]
    tensor = torch.zeros((len(imgs), 3, h, w), dtype=torch.uint8).contiguous(
        memory_format=memory_format)
    for i, img in enumerate(imgs):
        nump_array = np.asarray(img, dtype=np.uint8)
        if nump_array.ndim < 3:
            # Grayscale image: add a trailing channel axis so the rollaxis
            # below still produces a CHW array.
            nump_array = np.expand_dims(nump_array, axis=-1)
        nump_array = np.rollaxis(nump_array, 2)  # HWC -> CHW
        # Adding into zeros is equivalent to a copy, and broadcasts a
        # single-channel array across the 3 output channels.
        tensor[i] += torch.from_numpy(nump_array)
    return tensor, targets
...
@@ -90,6 +88,7 @@ def parse():
...
@@ -90,6 +88,7 @@ def parse():
parser
.
add_argument
(
'--opt-level'
,
type
=
str
)
parser
.
add_argument
(
'--opt-level'
,
type
=
str
)
parser
.
add_argument
(
'--keep-batchnorm-fp32'
,
type
=
str
,
default
=
None
)
parser
.
add_argument
(
'--keep-batchnorm-fp32'
,
type
=
str
,
default
=
None
)
parser
.
add_argument
(
'--loss-scale'
,
type
=
str
,
default
=
None
)
parser
.
add_argument
(
'--loss-scale'
,
type
=
str
,
default
=
None
)
parser
.
add_argument
(
'--channels-last'
,
type
=
bool
,
default
=
False
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
return
args
return
args
...
@@ -127,6 +126,11 @@ def main():
...
@@ -127,6 +126,11 @@ def main():
assert
torch
.
backends
.
cudnn
.
enabled
,
"Amp requires cudnn backend to be enabled."
assert
torch
.
backends
.
cudnn
.
enabled
,
"Amp requires cudnn backend to be enabled."
if
args
.
channels_last
:
memory_format
=
torch
.
channels_last
else
:
memory_format
=
torch
.
contiguous_format
# create model
# create model
if
args
.
pretrained
:
if
args
.
pretrained
:
print
(
"=> using pre-trained model '{}'"
.
format
(
args
.
arch
))
print
(
"=> using pre-trained model '{}'"
.
format
(
args
.
arch
))
...
@@ -140,10 +144,10 @@ def main():
...
@@ -140,10 +144,10 @@ def main():
print
(
"using apex synced BN"
)
print
(
"using apex synced BN"
)
model
=
apex
.
parallel
.
convert_syncbn_model
(
model
)
model
=
apex
.
parallel
.
convert_syncbn_model
(
model
)
model
=
model
.
cuda
()
model
=
model
.
cuda
()
.
to
(
memory_format
=
memory_format
)
# Scale learning rate based on global batch size
# Scale learning rate based on global batch size
args
.
lr
=
args
.
lr
*
float
(
args
.
batch_size
*
args
.
world_size
)
/
256.
args
.
lr
=
args
.
lr
*
float
(
args
.
batch_size
*
args
.
world_size
)
/
256.
optimizer
=
torch
.
optim
.
SGD
(
model
.
parameters
(),
args
.
lr
,
optimizer
=
torch
.
optim
.
SGD
(
model
.
parameters
(),
args
.
lr
,
momentum
=
args
.
momentum
,
momentum
=
args
.
momentum
,
weight_decay
=
args
.
weight_decay
)
weight_decay
=
args
.
weight_decay
)
...
@@ -161,7 +165,7 @@ def main():
...
@@ -161,7 +165,7 @@ def main():
# before model, ... = amp.initialize(model, ...), the call to amp.initialize may alter
# before model, ... = amp.initialize(model, ...), the call to amp.initialize may alter
# the types of model's parameters in a way that disrupts or destroys DDP's allreduce hooks.
# the types of model's parameters in a way that disrupts or destroys DDP's allreduce hooks.
if
args
.
distributed
:
if
args
.
distributed
:
# By default, apex.parallel.DistributedDataParallel overlaps communication with
# By default, apex.parallel.DistributedDataParallel overlaps communication with
# computation in the backward pass.
# computation in the backward pass.
# model = DDP(model)
# model = DDP(model)
# delay_allreduce delays all communication to the end of the backward pass.
# delay_allreduce delays all communication to the end of the backward pass.
...
@@ -218,16 +222,18 @@ def main():
...
@@ -218,16 +222,18 @@ def main():
train_sampler
=
torch
.
utils
.
data
.
distributed
.
DistributedSampler
(
train_dataset
)
train_sampler
=
torch
.
utils
.
data
.
distributed
.
DistributedSampler
(
train_dataset
)
val_sampler
=
torch
.
utils
.
data
.
distributed
.
DistributedSampler
(
val_dataset
)
val_sampler
=
torch
.
utils
.
data
.
distributed
.
DistributedSampler
(
val_dataset
)
collate_fn
=
lambda
b
:
fast_collate
(
b
,
memory_format
)
train_loader
=
torch
.
utils
.
data
.
DataLoader
(
train_loader
=
torch
.
utils
.
data
.
DataLoader
(
train_dataset
,
batch_size
=
args
.
batch_size
,
shuffle
=
(
train_sampler
is
None
),
train_dataset
,
batch_size
=
args
.
batch_size
,
shuffle
=
(
train_sampler
is
None
),
num_workers
=
args
.
workers
,
pin_memory
=
True
,
sampler
=
train_sampler
,
collate_fn
=
fast_
collate
)
num_workers
=
args
.
workers
,
pin_memory
=
True
,
sampler
=
train_sampler
,
collate_fn
=
collate
_fn
)
val_loader
=
torch
.
utils
.
data
.
DataLoader
(
val_loader
=
torch
.
utils
.
data
.
DataLoader
(
val_dataset
,
val_dataset
,
batch_size
=
args
.
batch_size
,
shuffle
=
False
,
batch_size
=
args
.
batch_size
,
shuffle
=
False
,
num_workers
=
args
.
workers
,
pin_memory
=
True
,
num_workers
=
args
.
workers
,
pin_memory
=
True
,
sampler
=
val_sampler
,
sampler
=
val_sampler
,
collate_fn
=
fast_
collate
)
collate_fn
=
collate
_fn
)
if
args
.
evaluate
:
if
args
.
evaluate
:
validate
(
val_loader
,
model
,
criterion
)
validate
(
val_loader
,
model
,
criterion
)
...
@@ -297,7 +303,7 @@ class data_prefetcher():
...
@@ -297,7 +303,7 @@ class data_prefetcher():
# else:
# else:
self
.
next_input
=
self
.
next_input
.
float
()
self
.
next_input
=
self
.
next_input
.
float
()
self
.
next_input
=
self
.
next_input
.
sub_
(
self
.
mean
).
div_
(
self
.
std
)
self
.
next_input
=
self
.
next_input
.
sub_
(
self
.
mean
).
div_
(
self
.
std
)
def
next
(
self
):
def
next
(
self
):
torch
.
cuda
.
current_stream
().
wait_stream
(
self
.
stream
)
torch
.
cuda
.
current_stream
().
wait_stream
(
self
.
stream
)
input
=
self
.
next_input
input
=
self
.
next_input
...
@@ -361,20 +367,20 @@ def train(train_loader, model, criterion, optimizer, epoch):
...
@@ -361,20 +367,20 @@ def train(train_loader, model, criterion, optimizer, epoch):
# Measure accuracy
# Measure accuracy
prec1
,
prec5
=
accuracy
(
output
.
data
,
target
,
topk
=
(
1
,
5
))
prec1
,
prec5
=
accuracy
(
output
.
data
,
target
,
topk
=
(
1
,
5
))
# Average loss and accuracy across processes for logging
# Average loss and accuracy across processes for logging
if
args
.
distributed
:
if
args
.
distributed
:
reduced_loss
=
reduce_tensor
(
loss
.
data
)
reduced_loss
=
reduce_tensor
(
loss
.
data
)
prec1
=
reduce_tensor
(
prec1
)
prec1
=
reduce_tensor
(
prec1
)
prec5
=
reduce_tensor
(
prec5
)
prec5
=
reduce_tensor
(
prec5
)
else
:
else
:
reduced_loss
=
loss
.
data
reduced_loss
=
loss
.
data
# to_python_float incurs a host<->device sync
# to_python_float incurs a host<->device sync
losses
.
update
(
to_python_float
(
reduced_loss
),
input
.
size
(
0
))
losses
.
update
(
to_python_float
(
reduced_loss
),
input
.
size
(
0
))
top1
.
update
(
to_python_float
(
prec1
),
input
.
size
(
0
))
top1
.
update
(
to_python_float
(
prec1
),
input
.
size
(
0
))
top5
.
update
(
to_python_float
(
prec5
),
input
.
size
(
0
))
top5
.
update
(
to_python_float
(
prec5
),
input
.
size
(
0
))
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
batch_time
.
update
((
time
.
time
()
-
end
)
/
args
.
print_freq
)
batch_time
.
update
((
time
.
time
()
-
end
)
/
args
.
print_freq
)
end
=
time
.
time
()
end
=
time
.
time
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment