OpenDAS / apex

Commit 56ea6d78, authored Jan 23, 2019 by Michael Carilli
Parent: 3c7a0e44

    saving for carl to review
Showing 2 changed files with 17 additions and 7 deletions (+17 -7):

    apex/amp/utils.py             +9 -6
    apex/parallel/distributed.py  +8 -1
apex/amp/utils.py

@@ -82,17 +82,20 @@ def casted_args(cast_fn, args, kwargs):
     return new_args
 
 def cached_cast(cast_fn, x, cache):
+    print("Calling cached_cast")
     if is_nested(x):
         return type(x)([cached_cast(y) for y in x])
     if x in cache:
         cached_x = cache[x]
-        # During eval, it's possible to end up caching casted weights
-        # with requires_grad == False. This is then a problem when they
-        # get reused on the next train iter. So we ensure that cached
-        # weights have same requires_grad flag of most recent request.
+        if x.requires_grad and cached_x.requires_grad:
+            # Check to make sure x is actually cached_x's autograd parent.
+            if cached_x.grad_fn.next_functions[1][0].variable is not x:
+                raise RuntimeError("x and cache[x] both require grad, but x is not "
+                                   "cache[x]'s parent. This is likely an error.")
         if x.requires_grad != cached_x.requires_grad:
-            cached_x.requires_grad_(x.requires_grad)
-        return cache[x]
+            del cache[x]
+        else:
+            return cached_x
     casted_x = cast_fn(x)
     cache[x] = casted_x
     ...
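The new check walks the casted tensor's autograd graph back to its fp32 parent. The following is a minimal standalone sketch, not part of this commit, of why that works in plain PyTorch; it searches all of grad_fn.next_functions instead of hard-coding the [1][0] index used above, since that index can differ across PyTorch versions.

import torch

x = torch.randn(4, requires_grad=True)   # stand-in for an fp32 master weight
casted = x.half()                         # roughly what cast_fn(x) produces

# Casting a leaf that requires grad records an autograd edge whose
# AccumulateGrad node keeps a reference back to the original tensor.
parents = [fn.variable for fn, _ in casted.grad_fn.next_functions
           if fn is not None and hasattr(fn, "variable")]
print(any(p is x for p in parents))       # True: x is casted's autograd parent

If the cached half tensor no longer has x as its parent, which is what the raised RuntimeError above guards against, reusing it would presumably break the backward path to the current weights.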
apex/parallel/distributed.py

@@ -292,7 +292,8 @@ class DistributedDataParallel(Module):
         # Sanity checks that all the buckets were kicked off
         if self.next_bucket != self.num_buckets:
-            raise RuntimeError("In epilogue, next_bucket != num_buckets. "
+            raise RuntimeError("In epilogue, next_bucket ({}) != num_buckets ({}). ".format(
+                               self.next_bucket, self.num_buckets),
                                "This probably indicates some buckets were not allreduced.")
         for actual, expected in zip(self.buckets_ready_size, self.bucket_sizes):
         ...
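For illustration only, with hypothetical counter values rather than output from a real run, the reworked message now embeds the two counters, so a mismatch reads for example:

print("In epilogue, next_bucket ({}) != num_buckets ({}). ".format(3, 5))
# In epilogue, next_bucket (3) != num_buckets (5).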
@@ -389,6 +390,8 @@ class DistributedDataParallel(Module):
     def allreduce_fallback(self):
         grads = [param.grad.data for param in self.module.parameters() if param.grad is not None]
+        print("In allreduce_fallback: {}".format(len(grads)))
         split_buckets = split_half_float_double(grads)
         # If retain_allreduce_buffers is True and delay_allreduce is False,
         ...
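The added print reports how many parameters actually produced gradients before they are split by dtype. split_half_float_double itself is not shown in this diff; the sketch below (my own split_by_dtype, an assumed illustration rather than apex's implementation) shows the kind of grouping such a helper performs, so that each dtype group can later be flattened and allreduced together.

import torch
from collections import defaultdict

def split_by_dtype(tensors):
    # Group tensors by dtype; each group can then be flattened into one
    # contiguous buffer and allreduced in a single call.
    groups = defaultdict(list)
    for t in tensors:
        groups[t.dtype].append(t)
    return list(groups.values())

grads = [torch.zeros(2, dtype=torch.float16),
         torch.zeros(3, dtype=torch.float32),
         torch.zeros(1, dtype=torch.float64)]
print([[t.dtype for t in g] for g in split_by_dtype(grads)])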
@@ -413,6 +416,7 @@ class DistributedDataParallel(Module):
...
@@ -413,6 +416,7 @@ class DistributedDataParallel(Module):
self
.
buckets
[
bucket_idx
][
bucket_loc
]
=
param
.
grad
.
data
self
.
buckets
[
bucket_idx
][
bucket_loc
]
=
param
.
grad
.
data
self
.
buckets_ready_size
[
bucket_idx
]
+=
1
self
.
buckets_ready_size
[
bucket_idx
]
+=
1
print
(
self
.
buckets_ready_size
)
if
self
.
buckets_ready_size
[
bucket_idx
]
==
self
.
bucket_sizes
[
bucket_idx
]:
if
self
.
buckets_ready_size
[
bucket_idx
]
==
self
.
bucket_sizes
[
bucket_idx
]:
if
bucket_idx
==
self
.
next_bucket
:
if
bucket_idx
==
self
.
next_bucket
:
@@ -472,6 +476,9 @@ class DistributedDataParallel(Module):
             self.allreduce_buffers = [None for _ in range(self.num_buckets)]
         self.next_bucket = 0
         self.ready_buckets_not_reduced = set()
+        print(len(param_list), len(self.active_params),
+              [len(b) for b in self.buckets],
+              self.needs_refresh)
         self.active_params = param_list
         ...
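The remaining prints inspect the bucket bookkeeping these hunks touch: buckets_ready_size counts how many gradients have landed in each bucket, and a full bucket becomes eligible for allreduce once next_bucket catches up to it. Below is a self-contained toy model of that bookkeeping, my own sketch of the mechanism the hunks above suggest rather than apex's code, with hypothetical bucket sizes and arrival order.

bucket_sizes = [3, 2]            # hypothetical: gradients expected per bucket
buckets_ready_size = [0, 0]
next_bucket = 0
ready_buckets_not_reduced = set()

def grad_ready(bucket_idx):
    """Called once per arriving gradient; fires allreduce for full, in-order buckets."""
    global next_bucket
    buckets_ready_size[bucket_idx] += 1
    if buckets_ready_size[bucket_idx] == bucket_sizes[bucket_idx]:
        if bucket_idx == next_bucket:
            print("allreduce bucket", bucket_idx)
            next_bucket += 1
            # Flush any later buckets that filled up out of order.
            while next_bucket in ready_buckets_not_reduced:
                ready_buckets_not_reduced.remove(next_bucket)
                print("allreduce bucket", next_bucket)
                next_bucket += 1
        else:
            ready_buckets_not_reduced.add(bucket_idx)

for idx in (0, 1, 0, 1, 0):      # gradients arriving in some interleaved order
    grad_ready(idx)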