Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
apex
Commits
858d7899
Commit
858d7899
authored
Feb 05, 2020
by
Kexin Yu
Browse files
Merge branch 'master' of https://github.com/NVIDIA/apex
parents
8d2647f8
2ca894da
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
25 additions
and
20 deletions
+25
-20
apex/optimizers/fused_novograd.py
apex/optimizers/fused_novograd.py
+1
-1
examples/imagenet/main_amp.py
examples/imagenet/main_amp.py
+20
-14
setup.py
setup.py
+4
-5
No files found.
apex/optimizers/fused_novograd.py
View file @
858d7899
...
...
@@ -47,7 +47,7 @@ class FusedNovoGrad(torch.optim.Optimizer):
reg_inside_moment (bool, optional): whether do regularization (norm and L2)
in momentum calculation. True for include, False for not include and
only do it on update term. (default: False)
grad_averaging (bool, optional): whether apply (1-beta2) to grad when
grad_averaging (bool, optional): whether apply (1-beta1) to grad when
calculating running averages of gradient. (default: True)
norm_type (int, optional): which norm to calculate for each layer.
2 for L2 norm, and 0 for infinite norm. These 2 are only supported
...
...
examples/imagenet/main_amp.py
View file @
858d7899
...
...
@@ -25,21 +25,19 @@ try:
except
ImportError
:
raise
ImportError
(
"Please install apex from https://www.github.com/nvidia/apex to run this example."
)
def
fast_collate
(
batch
,
memory_format
):
def
fast_collate
(
batch
):
imgs
=
[
img
[
0
]
for
img
in
batch
]
targets
=
torch
.
tensor
([
target
[
1
]
for
target
in
batch
],
dtype
=
torch
.
int64
)
w
=
imgs
[
0
].
size
[
0
]
h
=
imgs
[
0
].
size
[
1
]
tensor
=
torch
.
zeros
(
(
len
(
imgs
),
3
,
h
,
w
),
dtype
=
torch
.
uint8
)
tensor
=
torch
.
zeros
(
(
len
(
imgs
),
3
,
h
,
w
),
dtype
=
torch
.
uint8
).
contiguous
(
memory_format
=
memory_format
)
for
i
,
img
in
enumerate
(
imgs
):
nump_array
=
np
.
asarray
(
img
,
dtype
=
np
.
uint8
)
if
(
nump_array
.
ndim
<
3
):
nump_array
=
np
.
expand_dims
(
nump_array
,
axis
=-
1
)
nump_array
=
np
.
rollaxis
(
nump_array
,
2
)
tensor
[
i
]
+=
torch
.
from_numpy
(
nump_array
)
return
tensor
,
targets
...
...
@@ -90,6 +88,7 @@ def parse():
parser
.
add_argument
(
'--opt-level'
,
type
=
str
)
parser
.
add_argument
(
'--keep-batchnorm-fp32'
,
type
=
str
,
default
=
None
)
parser
.
add_argument
(
'--loss-scale'
,
type
=
str
,
default
=
None
)
parser
.
add_argument
(
'--channels-last'
,
type
=
bool
,
default
=
False
)
args
=
parser
.
parse_args
()
return
args
...
...
@@ -127,6 +126,11 @@ def main():
assert
torch
.
backends
.
cudnn
.
enabled
,
"Amp requires cudnn backend to be enabled."
if
args
.
channels_last
:
memory_format
=
torch
.
channels_last
else
:
memory_format
=
torch
.
contiguous_format
# create model
if
args
.
pretrained
:
print
(
"=> using pre-trained model '{}'"
.
format
(
args
.
arch
))
...
...
@@ -140,7 +144,7 @@ def main():
print
(
"using apex synced BN"
)
model
=
apex
.
parallel
.
convert_syncbn_model
(
model
)
model
=
model
.
cuda
()
model
=
model
.
cuda
()
.
to
(
memory_format
=
memory_format
)
# Scale learning rate based on global batch size
args
.
lr
=
args
.
lr
*
float
(
args
.
batch_size
*
args
.
world_size
)
/
256.
...
...
@@ -218,16 +222,18 @@ def main():
train_sampler
=
torch
.
utils
.
data
.
distributed
.
DistributedSampler
(
train_dataset
)
val_sampler
=
torch
.
utils
.
data
.
distributed
.
DistributedSampler
(
val_dataset
)
collate_fn
=
lambda
b
:
fast_collate
(
b
,
memory_format
)
train_loader
=
torch
.
utils
.
data
.
DataLoader
(
train_dataset
,
batch_size
=
args
.
batch_size
,
shuffle
=
(
train_sampler
is
None
),
num_workers
=
args
.
workers
,
pin_memory
=
True
,
sampler
=
train_sampler
,
collate_fn
=
fast_
collate
)
num_workers
=
args
.
workers
,
pin_memory
=
True
,
sampler
=
train_sampler
,
collate_fn
=
collate
_fn
)
val_loader
=
torch
.
utils
.
data
.
DataLoader
(
val_dataset
,
batch_size
=
args
.
batch_size
,
shuffle
=
False
,
num_workers
=
args
.
workers
,
pin_memory
=
True
,
sampler
=
val_sampler
,
collate_fn
=
fast_
collate
)
collate_fn
=
collate
_fn
)
if
args
.
evaluate
:
validate
(
val_loader
,
model
,
criterion
)
...
...
setup.py
View file @
858d7899
...
...
@@ -2,7 +2,6 @@ import torch
from
setuptools
import
setup
,
find_packages
import
subprocess
from
pip._internal
import
main
as
pipmain
import
sys
import
warnings
import
os
...
...
@@ -31,10 +30,11 @@ if TORCH_MAJOR == 0 and TORCH_MINOR < 4:
cmdclass
=
{}
ext_modules
=
[]
extras
=
{}
if
"--pyprof"
in
sys
.
argv
:
with
open
(
'requirements.txt'
)
as
f
:
required_packages
=
f
.
read
().
splitlines
()
pipmain
([
"install"
]
+
required_packages
)
extras
[
'pyprof'
]
=
required_packages
try
:
sys
.
argv
.
remove
(
"--pyprof"
)
except
:
...
...
@@ -153,9 +153,7 @@ if "--bnp" in sys.argv:
'nvcc'
:[
'-DCUDA_HAS_FP16=1'
,
'-D__CUDA_NO_HALF_OPERATORS__'
,
'-D__CUDA_NO_HALF_CONVERSIONS__'
,
'-D__CUDA_NO_HALF2_OPERATORS__'
,
'-gencode'
,
'arch=compute_70,code=sm_70'
]
+
version_dependent_macros
}))
'-D__CUDA_NO_HALF2_OPERATORS__'
]
+
version_dependent_macros
}))
if
"--xentropy"
in
sys
.
argv
:
from
torch.utils.cpp_extension
import
CUDAExtension
...
...
@@ -209,4 +207,5 @@ setup(
description
=
'PyTorch Extensions written by NVIDIA'
,
ext_modules
=
ext_modules
,
cmdclass
=
cmdclass
,
extras_require
=
extras
,
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment