OpenDAS / apex — Commit 858d7899

Merge branch 'master' of https://github.com/NVIDIA/apex

Authored Feb 05, 2020 by Kexin Yu
Parents: 8d2647f8, 2ca894da
Showing 3 changed files with 25 additions and 20 deletions:

apex/optimizers/fused_novograd.py   +1  -1
examples/imagenet/main_amp.py       +20 -14
setup.py                            +4  -5
apex/optimizers/fused_novograd.py
```diff
@@ -47,7 +47,7 @@ class FusedNovoGrad(torch.optim.Optimizer):
         reg_inside_moment (bool, optional): whether do regularization (norm and L2)
             in momentum calculation. True for include, False for not include and
             only do it on update term. (default: False)
-        grad_averaging (bool, optional): whether apply (1-beta2) to grad when
+        grad_averaging (bool, optional): whether apply (1-beta1) to grad when
             calculating running averages of gradient. (default: True)
         norm_type (int, optional): which norm to calculate for each layer.
             2 for L2 norm, and 0 for infinite norm. These 2 are only supported
```
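The docstring fix matters because `grad_averaging` changes the scale of the first moment: it is beta1 (the momentum coefficient), not beta2 (the norm-averaging coefficient), that determines the scaling. A minimal sketch of the behavior being documented (simplified from the fused CUDA kernel; `exp_avg` and `grad_norm` are illustrative names, not the kernel's actual internals):

```python
import torch

def novograd_first_moment(exp_avg, grad, grad_norm, beta1, grad_averaging=True):
    # NovoGrad normalizes each layer's gradient by its norm before averaging.
    normed_grad = grad / grad_norm
    # With grad_averaging=True the new term is scaled by (1 - beta1), making
    # exp_avg a true exponential moving average; beta2 only enters the
    # (separate) running average of the gradient norms.
    scale = 1.0 - beta1 if grad_averaging else 1.0
    return exp_avg.mul_(beta1).add_(normed_grad, alpha=scale)
```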
examples/imagenet/main_amp.py
```diff
@@ -25,21 +25,19 @@ try:
 except ImportError:
     raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this example.")
 
-def fast_collate(batch):
+def fast_collate(batch, memory_format):
     imgs = [img[0] for img in batch]
     targets = torch.tensor([target[1] for target in batch], dtype=torch.int64)
     w = imgs[0].size[0]
     h = imgs[0].size[1]
-    tensor = torch.zeros( (len(imgs), 3, h, w), dtype=torch.uint8)
+    tensor = torch.zeros( (len(imgs), 3, h, w), dtype=torch.uint8).contiguous(memory_format=memory_format)
     for i, img in enumerate(imgs):
         nump_array = np.asarray(img, dtype=np.uint8)
         if(nump_array.ndim < 3):
             nump_array = np.expand_dims(nump_array, axis=-1)
         nump_array = np.rollaxis(nump_array, 2)
         tensor[i] += torch.from_numpy(nump_array)
     return tensor, targets
```
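`fast_collate` now takes the target memory format, so each batch tensor is laid out NHWC (or NCHW) before the training loop ever sees it. A hypothetical direct call, mirroring what the DataLoader does internally (`dataset` is an assumption here, standing in for something that yields (PIL.Image, int) pairs the way torchvision's ImageFolder does):

```python
# Hypothetical standalone use of fast_collate:
batch = [dataset[i] for i in range(4)]   # `dataset` assumed, see note above
tensor, targets = fast_collate(batch, torch.channels_last)
print(tensor.is_contiguous(memory_format=torch.channels_last))  # True
```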
```diff
@@ -90,6 +88,7 @@ def parse():
     parser.add_argument('--opt-level', type=str)
     parser.add_argument('--keep-batchnorm-fp32', type=str, default=None)
     parser.add_argument('--loss-scale', type=str, default=None)
+    parser.add_argument('--channels-last', type=bool, default=False)
     args = parser.parse_args()
     return args
```
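One caveat worth knowing about the new flag: argparse's `type=bool` applies Python's truthiness to the raw string, so any non-empty value enables the option:

```python
import argparse

# bool('False') is True in Python, so '--channels-last False' still enables
# the flag; only omitting it (or passing an empty string) yields False.
p = argparse.ArgumentParser()
p.add_argument('--channels-last', type=bool, default=False)
print(p.parse_args(['--channels-last', 'False']).channels_last)  # True
```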
```diff
@@ -127,6 +126,11 @@ def main():
     assert torch.backends.cudnn.enabled, "Amp requires cudnn backend to be enabled."
 
+    if args.channels_last:
+        memory_format = torch.channels_last
+    else:
+        memory_format = torch.contiguous_format
+
     # create model
     if args.pretrained:
         print("=> using pre-trained model '{}'".format(args.arch))
```
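For reference, `torch.channels_last` stores a 4-D tensor in NHWC order while the logical shape stays NCHW; only the strides change. A quick illustration:

```python
import torch

x = torch.randn(8, 3, 224, 224)                      # default NCHW layout
y = x.contiguous(memory_format=torch.channels_last)  # NHWC in memory
print(x.stride())  # (150528, 50176, 224, 1)
print(y.stride())  # (150528, 1, 672, 3) -- channel stride is now 1
print(y.shape)     # torch.Size([8, 3, 224, 224]), logical shape unchanged
```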
```diff
@@ -140,10 +144,10 @@ def main():
         print("using apex synced BN")
         model = apex.parallel.convert_syncbn_model(model)
 
-    model = model.cuda()
+    model = model.cuda().to(memory_format=memory_format)
 
     # Scale learning rate based on global batch size
-    args.lr = args.lr*float(args.batch_size*args.world_size)/256.
+    args.lr = args.lr*float(args.batch_size*args.world_size)/256.
     optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                 momentum=args.momentum,
                                 weight_decay=args.weight_decay)
```
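The linear scaling rule above keeps the per-sample step size constant as the global batch grows: for example, `--lr 0.1` with 8 processes at batch size 64 per GPU gives a global batch of 512 and a scaled rate of 0.1 × 512/256 = 0.2.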
```diff
@@ -161,7 +165,7 @@ def main():
     # before model, ... = amp.initialize(model, ...), the call to amp.initialize may alter
     # the types of model's parameters in a way that disrupts or destroys DDP's allreduce hooks.
     if args.distributed:
-        # By default, apex.parallel.DistributedDataParallel overlaps communication with
+        # By default, apex.parallel.DistributedDataParallel overlaps communication with
         # computation in the backward pass.
         # model = DDP(model)
         # delay_allreduce delays all communication to the end of the backward pass.
```
```diff
@@ -218,16 +222,18 @@ def main():
         train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
         val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset)
 
+    collate_fn = lambda b: fast_collate(b, memory_format)
+
     train_loader = torch.utils.data.DataLoader(
         train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
-        num_workers=args.workers, pin_memory=True, sampler=train_sampler, collate_fn=fast_collate)
+        num_workers=args.workers, pin_memory=True, sampler=train_sampler, collate_fn=collate_fn)
 
     val_loader = torch.utils.data.DataLoader(
         val_dataset,
         batch_size=args.batch_size, shuffle=False,
         num_workers=args.workers, pin_memory=True,
-        sampler=val_sampler, collate_fn=fast_collate)
+        sampler=val_sampler, collate_fn=collate_fn)
 
     if args.evaluate:
         validate(val_loader, model, criterion)
```
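DataLoader invokes `collate_fn` with the batch as its only argument, so the extra `memory_format` parameter has to be bound up front; the lambda above does exactly that. An equivalent binding with `functools.partial`, shown only as an alternative:

```python
from functools import partial

# fast_collate(batch, memory_format) with memory_format pre-bound:
collate_fn = partial(fast_collate, memory_format=memory_format)
```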
```diff
@@ -297,7 +303,7 @@ class data_prefetcher():
             # else:
             self.next_input = self.next_input.float()
             self.next_input = self.next_input.sub_(self.mean).div_(self.std)
 
     def next(self):
         torch.cuda.current_stream().wait_stream(self.stream)
         input = self.next_input
```
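No functional change is visible in this hunk (the line numbers shift by the lines added earlier, and any remaining difference is likely whitespace), but the surrounding pattern is worth spelling out: the prefetcher stages the next batch on a side CUDA stream, and `next()` makes the compute stream wait for it. A simplified sketch of that pattern (the real `data_prefetcher` also preloads targets and applies the mean/std normalization shown above):

```python
import torch

class Prefetcher:
    def __init__(self, loader):
        self.loader = iter(loader)
        self.stream = torch.cuda.Stream()  # side stream for host-to-device copies
        self.preload()

    def preload(self):
        try:
            self.next_input, self.next_target = next(self.loader)
        except StopIteration:
            self.next_input = self.next_target = None
            return
        with torch.cuda.stream(self.stream):
            # Asynchronous copies overlap with compute on the default stream
            # (requires the DataLoader's pin_memory=True, as set above).
            self.next_input = self.next_input.cuda(non_blocking=True)
            self.next_target = self.next_target.cuda(non_blocking=True)

    def next(self):
        # Block the default stream until the staged copies have finished.
        torch.cuda.current_stream().wait_stream(self.stream)
        input, target = self.next_input, self.next_target
        self.preload()
        return input, target
```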
```diff
@@ -361,20 +367,20 @@ def train(train_loader, model, criterion, optimizer, epoch):
             # Measure accuracy
             prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
 
-            # Average loss and accuracy across processes for logging
+            # Average loss and accuracy across processes for logging
             if args.distributed:
                 reduced_loss = reduce_tensor(loss.data)
                 prec1 = reduce_tensor(prec1)
                 prec5 = reduce_tensor(prec5)
             else:
                 reduced_loss = loss.data
 
             # to_python_float incurs a host<->device sync
             losses.update(to_python_float(reduced_loss), input.size(0))
             top1.update(to_python_float(prec1), input.size(0))
             top5.update(to_python_float(prec5), input.size(0))
 
             torch.cuda.synchronize()
             batch_time.update((time.time() - end)/args.print_freq)
             end = time.time()
```
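`reduce_tensor` is defined elsewhere in main_amp.py; a sketch consistent with how it is used here, assuming the usual all-reduce-then-divide implementation:

```python
import torch.distributed as dist

def reduce_tensor(tensor):
    # Average a metric across all ranks so every process logs the same
    # value; world_size comes from the parsed arguments.
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= args.world_size
    return rt
```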
setup.py
```diff
@@ -2,7 +2,6 @@ import torch
 from setuptools import setup, find_packages
 import subprocess
-from pip._internal import main as pipmain
 import sys
 import warnings
 import os
```
```diff
@@ -31,10 +30,11 @@ if TORCH_MAJOR == 0 and TORCH_MINOR < 4:
 cmdclass = {}
 ext_modules = []
+extras = {}
 
 if "--pyprof" in sys.argv:
     with open('requirements.txt') as f:
         required_packages = f.read().splitlines()
-    pipmain(["install"] + required_packages)
+    extras['pyprof'] = required_packages
     try:
         sys.argv.remove("--pyprof")
     except:
```
```diff
@@ -153,9 +153,7 @@ if "--bnp" in sys.argv:
                                       'nvcc':['-DCUDA_HAS_FP16=1',
                                               '-D__CUDA_NO_HALF_OPERATORS__',
                                               '-D__CUDA_NO_HALF_CONVERSIONS__',
-                                              '-D__CUDA_NO_HALF2_OPERATORS__',
-                                              '-gencode',
-                                              'arch=compute_70,code=sm_70'] + version_dependent_macros}))
+                                              '-D__CUDA_NO_HALF2_OPERATORS__'] + version_dependent_macros}))
 
 if "--xentropy" in sys.argv:
     from torch.utils.cpp_extension import CUDAExtension
```
```diff
@@ -209,4 +207,5 @@ setup(
     description='PyTorch Extensions written by NVIDIA',
     ext_modules=ext_modules,
     cmdclass=cmdclass,
+    extras_require=extras,
 )
```
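Taken together, the setup.py changes stop installing the pyprof requirements at build time through pip's private `pip._internal` API and instead declare them as a setuptools extra, so users opt in with `pip install .[pyprof]`. A minimal standalone illustration of the mechanism (the package name and requirement below are placeholders, not apex's actual metadata):

```python
from setuptools import setup

# Extras defer optional dependencies to install time:
# `pip install .[pyprof]` pulls them in; plain `pip install .` does not.
setup(
    name='example-pkg',                                 # placeholder
    version='0.1',
    extras_require={'pyprof': ['some-profiling-dep']},  # placeholder requirement
)
```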