OpenDAS / apex · Commits · dfd40f9a

Commit dfd40f9a authored Jan 24, 2019 by Michael Carilli
parent 3b8e5c4a

    Adding tests, also, don't drop cache during eval.

Showing 2 changed files with 132 additions and 2 deletions (+132 -2):

    apex/amp/utils.py              +1   -2
    tests/run_amp/test_cache.py    +131 -0
apex/amp/utils.py  (+1 -2)

@@ -82,7 +82,6 @@ def casted_args(cast_fn, args, kwargs):
     return new_args
 
 def cached_cast(cast_fn, x, cache):
-    # print("Calling cached_cast")
     if is_nested(x):
         return type(x)([cached_cast(y) for y in x])
     if x in cache:
@@ -97,7 +96,7 @@ def cached_cast(cast_fn, x, cache):
         # and reused from the cache, it will not actually have x as its parent.
         # Therefore, we choose to invalidate the cache (and force refreshing the cast)
         # if x.requires_grad and cached_x.requires_grad do not match.
-        if x.requires_grad != cached_x.requires_grad:
+        if torch.is_grad_enabled() and x.requires_grad != cached_x.requires_grad:
             del cache[x]
         else:
             return cached_x
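For context, the behavioral change is easiest to see in isolation. The sketch below is a simplified stand-in for cached_cast, not apex's actual implementation (it drops the nesting and autograd-parent checks, and the helper name is mine): a weight cast under torch.no_grad() is cached with requires_grad=False, and without the torch.is_grad_enabled() guard the requires_grad mismatch against the parameter would evict that entry on every eval forward. With the guard, the cache survives eval and is only refreshed once training resumes.

# Minimal sketch (not part of the commit) of the eval-time cache behavior.
import torch

def cached_cast_sketch(cast_fn, x, cache):
    # Simplified: reuse a cached cast of x unless grad is enabled and the
    # requires_grad flags of x and the cached cast disagree.
    if x in cache:
        cached_x = cache[x]
        if torch.is_grad_enabled() and x.requires_grad != cached_x.requires_grad:
            del cache[x]                  # invalidate only while training
        else:
            return cached_x               # under no_grad(), keep the cached cast
    casted_x = cast_fn(x)
    cache[x] = casted_x
    return casted_x

cache = {}
w = torch.nn.Parameter(torch.ones(4))     # requires_grad=True

with torch.no_grad():
    # First eval forward: the half cast is created under no_grad, so the
    # cached copy has requires_grad=False while w.requires_grad stays True.
    cached_cast_sketch(lambda t: t.half(), w, cache)

    # Second eval forward: checking the mismatch alone would evict and re-cast
    # w on every eval iteration; with the is_grad_enabled() guard the cached
    # cast is simply reused.
    out = cached_cast_sketch(lambda t: t.half(), w, cache)

assert w in cache and not out.requires_grad

# Back in training mode the mismatch is honored again: the stale
# requires_grad=False entry is evicted and a fresh cast with w as its
# autograd parent is produced.
out = cached_cast_sketch(lambda t: t.half(), w, cache)
assert out.requires_grad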
tests/run_amp/test_cache.py  (new file, mode 100644, +131 -0)

import unittest

import functools as ft
import itertools as it

from apex import amp
import torch
from torch import nn
import torch.nn.functional as F

from utils import common_init, HALF, FLOAT,\
    ALWAYS_HALF, ALWAYS_FLOAT, MATCH_INPUT

def get_reference_grad(i, w, ops):
    # Creating new tensors ensures, among other things, that the new tensors are not in the cache.
    # In fact, they are guaranteed not to use the cache because they are not torch.nn.Parameters.
    fp32_i = i.detach().clone().float()
    fp32_w = w.detach().clone().float().requires_grad_()
    loss = ops(fp32_i, fp32_w)
    loss.backward()
    return fp32_w.grad

class WhitelistModule(torch.nn.Module):
    def __init__(self, dtype):
        super(WhitelistModule, self).__init__()
        self.weight = torch.nn.Parameter(torch.arange(8*8, device='cuda', dtype=dtype).view(8, 8))

    @staticmethod
    def ops(input, weight):
        return (input.mm(weight)).mm(weight).sum()

    def forward(self, input):
        return self.ops(input, self.weight)


class BlacklistModule(torch.nn.Module):
    def __init__(self, dtype):
        super(BlacklistModule, self).__init__()
        self.weight = torch.nn.Parameter(torch.arange(2*8, device='cuda', dtype=dtype).view(2, 8))

    @staticmethod
    def ops(input, weight):
        return (input + torch.pow(weight, 2) + torch.pow(weight, 2)).sum()

    def forward(self, input):
        return self.ops(input, self.weight)


class PromoteModule(torch.nn.Module):
    def __init__(self, dtype):
        super(PromoteModule, self).__init__()
        self.weight = torch.nn.Parameter(torch.arange(2*8, device='cuda', dtype=dtype).view(2, 8))

    @staticmethod
    def ops(input, weight):
        return ((input*weight)*weight).sum()

    def forward(self, input):
        return self.ops(input, self.weight)


class TestCache(unittest.TestCase):
    def setUp(self):
        self.handle = amp.init(enabled=True)
        self.x = torch.ones((2, 8), device='cuda', dtype=torch.float32)
        common_init(self)

    def tearDown(self):
        self.handle._deactivate()

    def train_eval_train_test(self, module, t):
        model = module(t).cuda()
        dummy_optimizer = torch.optim.SGD(model.parameters(), lr=1.0)

        def training_step():
            for param in model.parameters():
                param.grad = None

            loss = model(self.x).sum()
            self.handle._default_scaler._loss_scale = 1.0
            with self.handle.scale_loss(loss, dummy_optimizer) as scaled_loss:
                scaled_loss.backward()

            self.assertEqual(len([p.grad for p in model.parameters() if p.grad is not None]), 1)
            self.assertEqual(model.weight.grad.type(), model.weight.type())

            reference_grad = get_reference_grad(self.x, model.weight, model.ops)

            # Currently there's no difference in the allclose calls, so no need for branching,
            # but I'm keeping this in case we want different tolerances for fp16 and fp32 checks.
            if model.weight.grad.type() == "torch.cuda.HalfTensor":
                self.assertTrue(torch.allclose(model.weight.grad.float(), reference_grad))
            elif model.weight.grad.type() == "torch.cuda.FloatTensor":
                self.assertTrue(torch.allclose(model.weight.grad.float(), reference_grad))
            else:
                raise RuntimeError("model.weight.grad.type = {}".format(model.weight.grad.type()))

            model.weight.data -= 1.

        # Simulates first epoch
        training_step()

        # Simulates eval
        with torch.no_grad():
            loss = model(self.x).sum()

        # Simulates resuming training after eval
        training_step()

    # I could easily have these as a set of for loops in a single test,
    # instead of going for granularity.
    def test_whitelist_module_fp16_weight(self):
        self.train_eval_train_test(WhitelistModule, torch.float16)

    def test_whitelist_module_fp32_weight(self):
        self.train_eval_train_test(WhitelistModule, torch.float32)

    def test_blacklist_module_fp16_weight(self):
        self.train_eval_train_test(BlacklistModule, torch.float16)

    def test_blacklist_module_fp32_weight(self):
        self.train_eval_train_test(BlacklistModule, torch.float32)

    def test_promote_module_fp16_weight(self):
        self.train_eval_train_test(PromoteModule, torch.float16)

    def test_promote_module_fp32_weight(self):
        self.train_eval_train_test(PromoteModule, torch.float32)


if __name__ == '__main__':
    unittest.main()
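Since the new file ends in unittest.main(), it presumably runs as a standalone script from tests/run_amp/ (so that `from utils import common_init, ...` resolves against the sibling helper module) on a CUDA-capable machine, e.g. `python test_cache.py`. As a sketch, a single cache test could also be selected programmatically; the method name below is taken from the file, everything else is illustrative:

# Illustrative only: run one of the new cache tests from tests/run_amp/.
# Assumes a CUDA device and an apex installation with amp available.
import unittest
from test_cache import TestCache

suite = unittest.TestSuite([TestCache("test_whitelist_module_fp16_weight")])
unittest.TextTestRunner(verbosity=2).run(suite)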