Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
apex
Commits
858d7899
Commit
858d7899
authored
Feb 05, 2020
by
Kexin Yu
Browse files
Merge branch 'master' of
https://github.com/NVIDIA/apex
parents
8d2647f8
2ca894da
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
25 additions
and
20 deletions
+25
-20
apex/optimizers/fused_novograd.py
apex/optimizers/fused_novograd.py
+1
-1
examples/imagenet/main_amp.py
examples/imagenet/main_amp.py
+20
-14
setup.py
setup.py
+4
-5
No files found.
apex/optimizers/fused_novograd.py
View file @
858d7899
...
@@ -47,7 +47,7 @@ class FusedNovoGrad(torch.optim.Optimizer):
...
@@ -47,7 +47,7 @@ class FusedNovoGrad(torch.optim.Optimizer):
reg_inside_moment (bool, optional): whether do regularization (norm and L2)
reg_inside_moment (bool, optional): whether do regularization (norm and L2)
in momentum calculation. True for include, False for not include and
in momentum calculation. True for include, False for not include and
only do it on update term. (default: False)
only do it on update term. (default: False)
grad_averaging (bool, optional): whether apply (1-beta
2
) to grad when
grad_averaging (bool, optional): whether apply (1-beta
1
) to grad when
calculating running averages of gradient. (default: True)
calculating running averages of gradient. (default: True)
norm_type (int, optional): which norm to calculate for each layer.
norm_type (int, optional): which norm to calculate for each layer.
2 for L2 norm, and 0 for infinite norm. These 2 are only supported
2 for L2 norm, and 0 for infinite norm. These 2 are only supported
...
...
examples/imagenet/main_amp.py
View file @
858d7899
...
@@ -25,21 +25,19 @@ try:
...
@@ -25,21 +25,19 @@ try:
except
ImportError
:
except
ImportError
:
raise
ImportError
(
"Please install apex from https://www.github.com/nvidia/apex to run this example."
)
raise
ImportError
(
"Please install apex from https://www.github.com/nvidia/apex to run this example."
)
def
fast_collate
(
batch
,
memory_format
):
def
fast_collate
(
batch
):
imgs
=
[
img
[
0
]
for
img
in
batch
]
imgs
=
[
img
[
0
]
for
img
in
batch
]
targets
=
torch
.
tensor
([
target
[
1
]
for
target
in
batch
],
dtype
=
torch
.
int64
)
targets
=
torch
.
tensor
([
target
[
1
]
for
target
in
batch
],
dtype
=
torch
.
int64
)
w
=
imgs
[
0
].
size
[
0
]
w
=
imgs
[
0
].
size
[
0
]
h
=
imgs
[
0
].
size
[
1
]
h
=
imgs
[
0
].
size
[
1
]
tensor
=
torch
.
zeros
(
(
len
(
imgs
),
3
,
h
,
w
),
dtype
=
torch
.
uint8
)
tensor
=
torch
.
zeros
(
(
len
(
imgs
),
3
,
h
,
w
),
dtype
=
torch
.
uint8
).
contiguous
(
memory_format
=
memory_format
)
for
i
,
img
in
enumerate
(
imgs
):
for
i
,
img
in
enumerate
(
imgs
):
nump_array
=
np
.
asarray
(
img
,
dtype
=
np
.
uint8
)
nump_array
=
np
.
asarray
(
img
,
dtype
=
np
.
uint8
)
if
(
nump_array
.
ndim
<
3
):
if
(
nump_array
.
ndim
<
3
):
nump_array
=
np
.
expand_dims
(
nump_array
,
axis
=-
1
)
nump_array
=
np
.
expand_dims
(
nump_array
,
axis
=-
1
)
nump_array
=
np
.
rollaxis
(
nump_array
,
2
)
nump_array
=
np
.
rollaxis
(
nump_array
,
2
)
tensor
[
i
]
+=
torch
.
from_numpy
(
nump_array
)
tensor
[
i
]
+=
torch
.
from_numpy
(
nump_array
)
return
tensor
,
targets
return
tensor
,
targets
...
@@ -90,6 +88,7 @@ def parse():
...
@@ -90,6 +88,7 @@ def parse():
parser
.
add_argument
(
'--opt-level'
,
type
=
str
)
parser
.
add_argument
(
'--opt-level'
,
type
=
str
)
parser
.
add_argument
(
'--keep-batchnorm-fp32'
,
type
=
str
,
default
=
None
)
parser
.
add_argument
(
'--keep-batchnorm-fp32'
,
type
=
str
,
default
=
None
)
parser
.
add_argument
(
'--loss-scale'
,
type
=
str
,
default
=
None
)
parser
.
add_argument
(
'--loss-scale'
,
type
=
str
,
default
=
None
)
parser
.
add_argument
(
'--channels-last'
,
type
=
bool
,
default
=
False
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
return
args
return
args
...
@@ -127,6 +126,11 @@ def main():
...
@@ -127,6 +126,11 @@ def main():
assert
torch
.
backends
.
cudnn
.
enabled
,
"Amp requires cudnn backend to be enabled."
assert
torch
.
backends
.
cudnn
.
enabled
,
"Amp requires cudnn backend to be enabled."
if
args
.
channels_last
:
memory_format
=
torch
.
channels_last
else
:
memory_format
=
torch
.
contiguous_format
# create model
# create model
if
args
.
pretrained
:
if
args
.
pretrained
:
print
(
"=> using pre-trained model '{}'"
.
format
(
args
.
arch
))
print
(
"=> using pre-trained model '{}'"
.
format
(
args
.
arch
))
...
@@ -140,10 +144,10 @@ def main():
...
@@ -140,10 +144,10 @@ def main():
print
(
"using apex synced BN"
)
print
(
"using apex synced BN"
)
model
=
apex
.
parallel
.
convert_syncbn_model
(
model
)
model
=
apex
.
parallel
.
convert_syncbn_model
(
model
)
model
=
model
.
cuda
()
model
=
model
.
cuda
()
.
to
(
memory_format
=
memory_format
)
# Scale learning rate based on global batch size
# Scale learning rate based on global batch size
args
.
lr
=
args
.
lr
*
float
(
args
.
batch_size
*
args
.
world_size
)
/
256.
args
.
lr
=
args
.
lr
*
float
(
args
.
batch_size
*
args
.
world_size
)
/
256.
optimizer
=
torch
.
optim
.
SGD
(
model
.
parameters
(),
args
.
lr
,
optimizer
=
torch
.
optim
.
SGD
(
model
.
parameters
(),
args
.
lr
,
momentum
=
args
.
momentum
,
momentum
=
args
.
momentum
,
weight_decay
=
args
.
weight_decay
)
weight_decay
=
args
.
weight_decay
)
...
@@ -161,7 +165,7 @@ def main():
...
@@ -161,7 +165,7 @@ def main():
# before model, ... = amp.initialize(model, ...), the call to amp.initialize may alter
# before model, ... = amp.initialize(model, ...), the call to amp.initialize may alter
# the types of model's parameters in a way that disrupts or destroys DDP's allreduce hooks.
# the types of model's parameters in a way that disrupts or destroys DDP's allreduce hooks.
if
args
.
distributed
:
if
args
.
distributed
:
# By default, apex.parallel.DistributedDataParallel overlaps communication with
# By default, apex.parallel.DistributedDataParallel overlaps communication with
# computation in the backward pass.
# computation in the backward pass.
# model = DDP(model)
# model = DDP(model)
# delay_allreduce delays all communication to the end of the backward pass.
# delay_allreduce delays all communication to the end of the backward pass.
...
@@ -218,16 +222,18 @@ def main():
...
@@ -218,16 +222,18 @@ def main():
train_sampler
=
torch
.
utils
.
data
.
distributed
.
DistributedSampler
(
train_dataset
)
train_sampler
=
torch
.
utils
.
data
.
distributed
.
DistributedSampler
(
train_dataset
)
val_sampler
=
torch
.
utils
.
data
.
distributed
.
DistributedSampler
(
val_dataset
)
val_sampler
=
torch
.
utils
.
data
.
distributed
.
DistributedSampler
(
val_dataset
)
collate_fn
=
lambda
b
:
fast_collate
(
b
,
memory_format
)
train_loader
=
torch
.
utils
.
data
.
DataLoader
(
train_loader
=
torch
.
utils
.
data
.
DataLoader
(
train_dataset
,
batch_size
=
args
.
batch_size
,
shuffle
=
(
train_sampler
is
None
),
train_dataset
,
batch_size
=
args
.
batch_size
,
shuffle
=
(
train_sampler
is
None
),
num_workers
=
args
.
workers
,
pin_memory
=
True
,
sampler
=
train_sampler
,
collate_fn
=
fast_
collate
)
num_workers
=
args
.
workers
,
pin_memory
=
True
,
sampler
=
train_sampler
,
collate_fn
=
collate
_fn
)
val_loader
=
torch
.
utils
.
data
.
DataLoader
(
val_loader
=
torch
.
utils
.
data
.
DataLoader
(
val_dataset
,
val_dataset
,
batch_size
=
args
.
batch_size
,
shuffle
=
False
,
batch_size
=
args
.
batch_size
,
shuffle
=
False
,
num_workers
=
args
.
workers
,
pin_memory
=
True
,
num_workers
=
args
.
workers
,
pin_memory
=
True
,
sampler
=
val_sampler
,
sampler
=
val_sampler
,
collate_fn
=
fast_
collate
)
collate_fn
=
collate
_fn
)
if
args
.
evaluate
:
if
args
.
evaluate
:
validate
(
val_loader
,
model
,
criterion
)
validate
(
val_loader
,
model
,
criterion
)
...
@@ -297,7 +303,7 @@ class data_prefetcher():
...
@@ -297,7 +303,7 @@ class data_prefetcher():
# else:
# else:
self
.
next_input
=
self
.
next_input
.
float
()
self
.
next_input
=
self
.
next_input
.
float
()
self
.
next_input
=
self
.
next_input
.
sub_
(
self
.
mean
).
div_
(
self
.
std
)
self
.
next_input
=
self
.
next_input
.
sub_
(
self
.
mean
).
div_
(
self
.
std
)
def
next
(
self
):
def
next
(
self
):
torch
.
cuda
.
current_stream
().
wait_stream
(
self
.
stream
)
torch
.
cuda
.
current_stream
().
wait_stream
(
self
.
stream
)
input
=
self
.
next_input
input
=
self
.
next_input
...
@@ -361,20 +367,20 @@ def train(train_loader, model, criterion, optimizer, epoch):
...
@@ -361,20 +367,20 @@ def train(train_loader, model, criterion, optimizer, epoch):
# Measure accuracy
# Measure accuracy
prec1
,
prec5
=
accuracy
(
output
.
data
,
target
,
topk
=
(
1
,
5
))
prec1
,
prec5
=
accuracy
(
output
.
data
,
target
,
topk
=
(
1
,
5
))
# Average loss and accuracy across processes for logging
# Average loss and accuracy across processes for logging
if
args
.
distributed
:
if
args
.
distributed
:
reduced_loss
=
reduce_tensor
(
loss
.
data
)
reduced_loss
=
reduce_tensor
(
loss
.
data
)
prec1
=
reduce_tensor
(
prec1
)
prec1
=
reduce_tensor
(
prec1
)
prec5
=
reduce_tensor
(
prec5
)
prec5
=
reduce_tensor
(
prec5
)
else
:
else
:
reduced_loss
=
loss
.
data
reduced_loss
=
loss
.
data
# to_python_float incurs a host<->device sync
# to_python_float incurs a host<->device sync
losses
.
update
(
to_python_float
(
reduced_loss
),
input
.
size
(
0
))
losses
.
update
(
to_python_float
(
reduced_loss
),
input
.
size
(
0
))
top1
.
update
(
to_python_float
(
prec1
),
input
.
size
(
0
))
top1
.
update
(
to_python_float
(
prec1
),
input
.
size
(
0
))
top5
.
update
(
to_python_float
(
prec5
),
input
.
size
(
0
))
top5
.
update
(
to_python_float
(
prec5
),
input
.
size
(
0
))
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
batch_time
.
update
((
time
.
time
()
-
end
)
/
args
.
print_freq
)
batch_time
.
update
((
time
.
time
()
-
end
)
/
args
.
print_freq
)
end
=
time
.
time
()
end
=
time
.
time
()
...
...
setup.py
View file @
858d7899
...
@@ -2,7 +2,6 @@ import torch
...
@@ -2,7 +2,6 @@ import torch
from
setuptools
import
setup
,
find_packages
from
setuptools
import
setup
,
find_packages
import
subprocess
import
subprocess
from
pip._internal
import
main
as
pipmain
import
sys
import
sys
import
warnings
import
warnings
import
os
import
os
...
@@ -31,10 +30,11 @@ if TORCH_MAJOR == 0 and TORCH_MINOR < 4:
...
@@ -31,10 +30,11 @@ if TORCH_MAJOR == 0 and TORCH_MINOR < 4:
cmdclass
=
{}
cmdclass
=
{}
ext_modules
=
[]
ext_modules
=
[]
extras
=
{}
if
"--pyprof"
in
sys
.
argv
:
if
"--pyprof"
in
sys
.
argv
:
with
open
(
'requirements.txt'
)
as
f
:
with
open
(
'requirements.txt'
)
as
f
:
required_packages
=
f
.
read
().
splitlines
()
required_packages
=
f
.
read
().
splitlines
()
pipmain
([
"install"
]
+
required_packages
)
extras
[
'pyprof'
]
=
required_packages
try
:
try
:
sys
.
argv
.
remove
(
"--pyprof"
)
sys
.
argv
.
remove
(
"--pyprof"
)
except
:
except
:
...
@@ -153,9 +153,7 @@ if "--bnp" in sys.argv:
...
@@ -153,9 +153,7 @@ if "--bnp" in sys.argv:
'nvcc'
:[
'-DCUDA_HAS_FP16=1'
,
'nvcc'
:[
'-DCUDA_HAS_FP16=1'
,
'-D__CUDA_NO_HALF_OPERATORS__'
,
'-D__CUDA_NO_HALF_OPERATORS__'
,
'-D__CUDA_NO_HALF_CONVERSIONS__'
,
'-D__CUDA_NO_HALF_CONVERSIONS__'
,
'-D__CUDA_NO_HALF2_OPERATORS__'
,
'-D__CUDA_NO_HALF2_OPERATORS__'
]
+
version_dependent_macros
}))
'-gencode'
,
'arch=compute_70,code=sm_70'
]
+
version_dependent_macros
}))
if
"--xentropy"
in
sys
.
argv
:
if
"--xentropy"
in
sys
.
argv
:
from
torch.utils.cpp_extension
import
CUDAExtension
from
torch.utils.cpp_extension
import
CUDAExtension
...
@@ -209,4 +207,5 @@ setup(
...
@@ -209,4 +207,5 @@ setup(
description
=
'PyTorch Extensions written by NVIDIA'
,
description
=
'PyTorch Extensions written by NVIDIA'
,
ext_modules
=
ext_modules
,
ext_modules
=
ext_modules
,
cmdclass
=
cmdclass
,
cmdclass
=
cmdclass
,
extras_require
=
extras
,
)
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment