OpenDAS / apex · Commit 61b452e8
Authored Jun 05, 2018 by Michael Carilli
Merging latest master
Parents: fb7d4e1d, 6143b30f
Showing 6 changed files with 71 additions and 7 deletions (+71 -7)
apex/amp/compat.py                    +4  -0
apex/amp/lists/tensor_overrides.py    +1  -1
apex/amp/wrap.py                      +1  -3
apex/parallel/LARC.py                 +53 -0
apex/parallel/multiproc.py            +1  -1
examples/imagenet/main.py             +11 -2
apex/amp/compat.py
@@ -5,6 +5,10 @@ def variable_is_tensor():
     v = torch.autograd.Variable()
     return isinstance(v, torch.Tensor)
 
+def tensor_is_variable():
+    x = torch.Tensor()
+    return type(x) == torch.autograd.Variable
+
 # False for post-0.4
 def tensor_is_float_tensor():
     x = torch.Tensor()
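The new tensor_is_variable() probe lets amp detect builds where torch.Tensor() still constructs a torch.autograd.Variable; the override list in tensor_overrides.py below uses it to decide whether to target torch.Tensor or torch.autograd.Variable. A minimal sketch of exercising the probes interactively, assuming apex is importable; the results depend entirely on the installed PyTorch build:

# Sketch only: each probe builds a throwaway object and reports how the
# running PyTorch build relates Variable and Tensor.
from apex.amp import compat

print("variable_is_tensor:", compat.variable_is_tensor())
print("tensor_is_variable:", compat.tensor_is_variable())
print("tensor_is_float_tensor:", compat.tensor_is_float_tensor())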
apex/amp/lists/tensor_overrides.py
@@ -5,7 +5,7 @@ import importlib
 
 import torch
 
-if compat.variable_is_tensor():
+if compat.variable_is_tensor() and not compat.tensor_is_variable():
     MODULE = torch.Tensor
 else:
     MODULE = torch.autograd.Variable
apex/amp/wrap.py
@@ -8,8 +8,6 @@ import torch
 def cached_cast(mod, fn, cast_fn, handle,
                 try_caching=False, verbose=False):
     if not utils.has_func(mod, fn):
-        # Should happen only pre-0.4
-        assert not compat.variable_is_tensor()
         return
 
     orig_fn = utils.get_func(mod, fn)
...
@@ -140,7 +138,7 @@ def rnn_cast(backend, fn, verbose=False):
     # autograd graph correctly backprops from the wgrads computed
     # inside cuDNN (on fp16 weights) into the fp32 weights.
     assert utils.type_string(flat_weight) == 'FloatTensor'
-    if compat.tensor_is_float_tensor():
+    if compat.tensor_is_float_tensor() or compat.tensor_is_variable():
         # Pre-0.4. A little slower, since it zeros out memory.
         flat_weight_fp16 = flat_weight.new().half().resize_(flat_weight.shape)
     else:
apex/parallel/LARC.py
new file (0 → 100644)

import torch
from torch import nn
from torch.autograd import Variable
from torch.nn.parameter import Parameter

class LARC(object):
    def __init__(self, optimizer, trust_coefficient=0.02, epsilon=1e-8):
        self.param_groups = optimizer.param_groups
        self.optim = optimizer
        self.trust_coefficient = trust_coefficient
        self.eps = epsilon

    def __getstate__(self):
        return self.optim.__getstate__()

    def __setstate__(self, state):
        self.optim.__setstate__(state)

    def __repr__(self):
        return self.optim.__repr__()

    def state_dict(self):
        return self.optim.state_dict()

    def load_state_dict(self, state_dict):
        self.optim.load_state_dict(state_dict)

    def zero_grad(self):
        self.optim.zero_grad()

    def add_param_group(self, param_group):
        self.optim.add_param_group(param_group)

    def step(self):
        with torch.no_grad():
            weight_decays = []
            for group in self.optim.param_groups:
                # absorb weight decay control from optimizer
                weight_decay = group['weight_decay'] if 'weight_decay' in group else 0
                weight_decays.append(weight_decay)
                group['weight_decay'] = 0
                for p in group['params']:
                    if p.grad is None:
                        continue
                    param_norm = torch.norm(p.data)

                    # calculate adaptive lr + weight decay
                    adaptive_lr = (param_norm + self.eps) / (torch.norm(p.grad.data) + param_norm * weight_decay + self.eps)
                    p.grad.data += weight_decay * p.data
                    p.grad.data *= self.trust_coefficient * adaptive_lr

        self.optim.step()
        # return weight decay control to optimizer
        for i, group in enumerate(self.optim.param_groups):
            group['weight_decay'] = weight_decays[i]
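Because LARC delegates state_dict, zero_grad, add_param_group and the rest to the wrapped optimizer, it can be dropped in wherever an optimizer is expected. A minimal usage sketch, not taken from the commit; the model, data, and hyperparameters are illustrative only:

import torch
from apex.parallel.LARC import LARC

# Any torch.optim optimizer should work: LARC only rescales each parameter's
# gradient by the trust ratio before delegating to the wrapped step().
model = torch.nn.Linear(10, 10)
base_optimizer = torch.optim.SGD(model.parameters(), lr=0.1, weight_decay=1e-4)
optimizer = LARC(base_optimizer, trust_coefficient=0.02)

for _ in range(5):
    inputs = torch.randn(32, 10)
    loss = model(inputs).pow(2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()  # layer-wise adaptive rescaling, then the SGD update

Since self.param_groups aliases the wrapped optimizer's param_groups, code that reads or adjusts learning rates through param_groups should continue to work on the wrapper.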
apex/parallel/multiproc.py
@@ -13,7 +13,7 @@ argslist = list(sys.argv)[1:]
 world_size = torch.cuda.device_count()
 
 if '--world-size' in argslist:
-    argslist[argslist.index('--world-size')+1] = str(world_size)
+    world_size = int(argslist[argslist.index('--world-size')+1])
 else:
     argslist.append('--world-size')
     argslist.append(str(world_size))
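This one-line change means an explicit --world-size argument is now read back and honored rather than being overwritten with torch.cuda.device_count(). A small self-contained sketch of the new behaviour, using a hypothetical argument list rather than a real launch:

import torch

# Hypothetical command-line tail, as multiproc would see it after sys.argv[1:].
argslist = ['main.py', '--world-size', '2']
world_size = torch.cuda.device_count()

if '--world-size' in argslist:
    # New behaviour: trust the caller's value.
    world_size = int(argslist[argslist.index('--world-size') + 1])
else:
    argslist.append('--world-size')
    argslist.append(str(world_size))

print(world_size)  # 2, regardless of how many GPUs are visible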
examples/imagenet/main.py
@@ -300,6 +300,7 @@ def train(train_loader, model, criterion, optimizer, epoch):
         loss.backward()
         optimizer.step()
+        torch.cuda.synchronize()
 
         # measure elapsed time
         batch_time.update(time.time() - end)
...
@@ -309,11 +310,15 @@ def train(train_loader, model, criterion, optimizer, epoch):
         if args.rank == 0 and i % args.print_freq == 0 and i > 1:
             print('Epoch: [{0}][{1}/{2}]\t'
                   'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
+                  'Speed {3:.3f} ({4:.3f})\t'
                   'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                   'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                   'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                   'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
-                   epoch, i, len(train_loader), batch_time=batch_time,
+                   epoch, i, len(train_loader),
+                   args.world_size * args.batch_size / batch_time.val,
+                   args.world_size * args.batch_size / batch_time.avg,
+                   batch_time=batch_time,
                    data_time=data_time, loss=losses, top1=top1, top5=top5))
...
@@ -362,10 +367,14 @@ def validate(val_loader, model, criterion):
         if args.rank == 0 and i % args.print_freq == 0:
             print('Test: [{0}/{1}]\t'
                   'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
+                  'Speed {2:.3f} ({3:.3f})\t'
                   'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                   'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                   'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
-                   i, len(val_loader), batch_time=batch_time, loss=losses,
+                   i, len(val_loader),
+                   args.world_size * args.batch_size / batch_time.val,
+                   args.world_size * args.batch_size / batch_time.avg,
+                   batch_time=batch_time, loss=losses,
                    top1=top1, top5=top5))
 
         input, target = prefetcher.next()
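The added torch.cuda.synchronize() makes the timed interval include the CUDA work launched by the step, and the new Speed column turns that interval into end-to-end throughput: the global batch size (world_size * per-process batch_size) divided by the measured batch time. A worked example with made-up numbers, not measurements from this commit:

# Hypothetical values for illustration only.
world_size = 8          # number of distributed processes
batch_size = 256        # per-process batch size
batch_time_val = 0.85   # seconds for the most recent iteration
batch_time_avg = 0.90   # running average

speed_val = world_size * batch_size / batch_time_val   # ~2409 images/sec
speed_avg = world_size * batch_size / batch_time_avg   # ~2276 images/sec
print('Speed {0:.3f} ({1:.3f})'.format(speed_val, speed_avg))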