Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
apex
Commits
32157739
Commit
32157739
authored
May 15, 2020
by
rohithkrn
Browse files
add tests for O4 and O5 opt levels
parent
ba2407e2
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
130 additions
and
52 deletions
+130
-52
tests/L0/run_amp/test_add_param_group.py
tests/L0/run_amp/test_add_param_group.py
+18
-7
tests/L0/run_amp/test_cache.py
tests/L0/run_amp/test_cache.py
+29
-8
tests/L0/run_amp/test_checkpointing.py
tests/L0/run_amp/test_checkpointing.py
+2
-1
tests/L0/run_amp/test_multi_tensor_axpby.py
tests/L0/run_amp/test_multi_tensor_axpby.py
+7
-4
tests/L0/run_amp/test_multi_tensor_scale.py
tests/L0/run_amp/test_multi_tensor_scale.py
+6
-3
tests/L0/run_amp/test_promotion.py
tests/L0/run_amp/test_promotion.py
+66
-29
tests/L0/run_amp/utils.py
tests/L0/run_amp/utils.py
+2
-0
No files found.
tests/L0/run_amp/test_add_param_group.py
View file @
32157739
...
...
@@ -14,11 +14,11 @@ from utils import common_init, HALF, FLOAT,\
ALWAYS_HALF
,
ALWAYS_FLOAT
,
MATCH_INPUT
class
MyModel
(
torch
.
nn
.
Module
):
def
__init__
(
self
,
unique
):
def
__init__
(
self
,
unique
,
dtype
=
torch
.
float16
):
super
(
MyModel
,
self
).
__init__
()
self
.
weight0
=
Parameter
(
unique
+
torch
.
arange
(
2
,
device
=
'cuda'
,
dtype
=
torch
.
float32
))
self
.
weight1
=
Parameter
(
1.
+
unique
+
torch
.
arange
(
2
,
device
=
'cuda'
,
dtype
=
torch
.
float16
))
self
.
weight1
=
Parameter
(
1.
+
unique
+
torch
.
arange
(
2
,
device
=
'cuda'
,
dtype
=
dtype
))
@
staticmethod
def
ops
(
input
,
weight0
,
weight1
):
...
...
@@ -51,11 +51,15 @@ class TestAddParamGroup(unittest.TestCase):
optimizer
.
zero_grad
()
def
test_add_param_group
(
self
):
for
opt_level
in
(
"O0"
,
"O1"
,
"O2"
,
"O3"
):
for
opt_level
in
(
"O0"
,
"O1"
,
"O2"
,
"O3"
,
"O4"
,
"O5"
):
for
zero_before_add
in
(
True
,
False
):
for
try_accumulation
in
(
True
,
False
):
model0
=
MyModel
(
1
)
model1
=
MyModel
(
2
)
if
opt_level
in
{
"O4"
,
"O5"
}:
model0
=
MyModel
(
1
,
torch
.
bfloat16
)
model1
=
MyModel
(
2
,
torch
.
bfloat16
)
else
:
model0
=
MyModel
(
1
)
model1
=
MyModel
(
2
)
optimizer
=
torch
.
optim
.
SGD
([{
'params'
:
model0
.
parameters
(),
'lr'
:
0.25
}],
momentum
=
0.125
)
...
...
@@ -89,8 +93,12 @@ class TestAddParamGroup(unittest.TestCase):
[
param
.
data
.
clone
()
for
param
in
model1
.
parameters
()]
for
how_to_zero
in
"none"
,
"model"
,
"optimizer"
:
model0
=
MyModel
(
1
)
model1
=
MyModel
(
2
)
if
opt_level
in
{
"O4"
,
"O5"
}:
model0
=
MyModel
(
1
,
torch
.
bfloat16
)
model1
=
MyModel
(
2
,
torch
.
bfloat16
)
else
:
model0
=
MyModel
(
1
)
model1
=
MyModel
(
2
)
optimizer
=
torch
.
optim
.
SGD
([{
'params'
:
model0
.
parameters
(),
'lr'
:
0.25
}],
momentum
=
0.125
)
...
...
@@ -139,6 +147,9 @@ class TestAddParamGroup(unittest.TestCase):
[
param
.
data
.
clone
()
for
param
in
model1
.
parameters
()]
for
reference
,
final
in
zip
(
reference_params
,
final_params
):
# TODO: remove the conversion once allclose supports bfloat16 type.
if
final
.
dtype
==
torch
.
bfloat16
:
final
=
final
.
float
()
self
.
assertTrue
(
torch
.
allclose
(
reference
.
to
(
final
.
dtype
),
final
),
"opt_level = {}, how_to_zero = {}, zero_before_add = {}"
.
format
(
opt_level
,
how_to_zero
,
zero_before_add
))
...
...
tests/L0/run_amp/test_cache.py
View file @
32157739
...
...
@@ -67,12 +67,12 @@ class TestCache(unittest.TestCase):
def
tearDown
(
self
):
pass
def
train_eval_train_test
(
self
,
module
,
t
):
def
train_eval_train_test
(
self
,
module
,
t
,
opt_level
):
model
=
module
(
t
).
cuda
()
optimizer
=
torch
.
optim
.
SGD
(
model
.
parameters
(),
lr
=
1.0
)
_amp_state
.
allow_incoming_model_not_fp32
=
True
model
,
optimizer
=
amp
.
initialize
(
model
,
optimizer
,
opt_level
=
"O1"
,
verbosity
=
0
)
model
,
optimizer
=
amp
.
initialize
(
model
,
optimizer
,
opt_level
=
opt_level
,
verbosity
=
0
)
_amp_state
.
allow_incoming_model_not_fp32
=
False
def
training_step
():
...
...
@@ -93,6 +93,8 @@ class TestCache(unittest.TestCase):
# but I'm keeping this in case we want different tolerances for fp16 and fp32 checks.
if
model
.
weight
.
grad
.
type
()
==
"torch.cuda.HalfTensor"
:
self
.
assertTrue
(
torch
.
allclose
(
model
.
weight
.
grad
.
float
(),
reference_grad
))
elif
model
.
weight
.
grad
.
type
()
==
"torch.cuda.BFloat16Tensor"
:
self
.
assertTrue
(
torch
.
allclose
(
model
.
weight
.
grad
.
float
(),
reference_grad
))
elif
model
.
weight
.
grad
.
type
()
==
"torch.cuda.FloatTensor"
:
self
.
assertTrue
(
torch
.
allclose
(
model
.
weight
.
grad
.
float
(),
reference_grad
))
else
:
...
...
@@ -115,22 +117,41 @@ class TestCache(unittest.TestCase):
# I could easily have these as a set of for loops in a single test,
# instead of going for granularity.
def
test_whitelist_module_fp16_weight
(
self
):
self
.
train_eval_train_test
(
WhitelistModule
,
torch
.
float16
)
self
.
train_eval_train_test
(
WhitelistModule
,
torch
.
float16
,
"O1"
)
def
test_whitelist_module_fp32_weight
(
self
):
self
.
train_eval_train_test
(
WhitelistModule
,
torch
.
float32
)
self
.
train_eval_train_test
(
WhitelistModule
,
torch
.
float32
,
"O1"
)
def
test_blacklist_module_fp16_weight
(
self
):
self
.
train_eval_train_test
(
BlacklistModule
,
torch
.
float16
)
self
.
train_eval_train_test
(
BlacklistModule
,
torch
.
float16
,
"O1"
)
def
test_blacklist_module_fp32_weight
(
self
):
self
.
train_eval_train_test
(
BlacklistModule
,
torch
.
float32
)
self
.
train_eval_train_test
(
BlacklistModule
,
torch
.
float32
,
"O1"
)
def
test_promote_module_fp16_weight
(
self
):
self
.
train_eval_train_test
(
PromoteModule
,
torch
.
float16
)
self
.
train_eval_train_test
(
PromoteModule
,
torch
.
float16
,
"O1"
)
def
test_promote_module_fp32_weight
(
self
):
self
.
train_eval_train_test
(
PromoteModule
,
torch
.
float32
,
"O1"
)
# opt_level = O4
def
test_whitelist_module_bfp16_weight
(
self
):
self
.
train_eval_train_test
(
WhitelistModule
,
torch
.
bfloat16
,
"O4"
)
def
test_whitelist_module_fp32_weight
(
self
):
self
.
train_eval_train_test
(
WhitelistModule
,
torch
.
float32
,
"O4"
)
def
test_blacklist_module_bfp16_weight
(
self
):
self
.
train_eval_train_test
(
BlacklistModule
,
torch
.
bfloat16
,
"O4"
)
def
test_blacklist_module_fp32_weight
(
self
):
self
.
train_eval_train_test
(
BlacklistModule
,
torch
.
float32
,
"O4"
)
def
test_promote_module_bfp16_weight
(
self
):
self
.
train_eval_train_test
(
PromoteModule
,
torch
.
bfloat16
,
"O4"
)
def
test_promote_module_fp32_weight
(
self
):
self
.
train_eval_train_test
(
PromoteModule
,
torch
.
float32
)
self
.
train_eval_train_test
(
PromoteModule
,
torch
.
float32
,
"O4"
)
if
__name__
==
'__main__'
:
...
...
tests/L0/run_amp/test_checkpointing.py
View file @
32157739
...
...
@@ -28,7 +28,7 @@ class MyModel(torch.nn.Module):
class
TestCheckpointing
(
unittest
.
TestCase
):
def
setUp
(
self
):
self
.
initial_lr
=
1e-3
self
.
test_opt_levels
=
(
"O0"
,
"O1"
,
"O2"
,
"O3"
)
self
.
test_opt_levels
=
(
"O0"
,
"O1"
,
"O2"
,
"O3"
,
"O4"
,
"O5"
)
def
seed
(
self
):
torch
.
manual_seed
(
2809
)
...
...
@@ -236,6 +236,7 @@ class TestCheckpointing(unittest.TestCase):
state_dict
=
model
.
state_dict
()
for
key
in
state_dict
:
self
.
assertFalse
(
'Half'
in
state_dict
[
key
].
type
())
self
.
assertFalse
(
'BFloat16'
in
state_dict
[
key
].
type
())
# Check, if model is still trainable
# Create dummy data
...
...
tests/L0/run_amp/test_multi_tensor_axpby.py
View file @
32157739
...
...
@@ -69,7 +69,10 @@ class TestMultiTensorAxpby(unittest.TestCase):
applier
(
multi_tensor_axpby
,
self
.
overflow_buf
,
[
x_list
,
y_list
,
out_list
],
self
.
a
,
self
.
b
,
-
1
)
self
.
assertTrue
(
all
([
torch
.
allclose
(
out
,
self
.
ref
.
to
(
out_type
))
for
out
in
out_list
]),
# TODO: Remove this workaround for bfloat16 after torch.allcose() support bfloat16
if
out_type
==
torch
.
bfloat16
:
out_list
=
[
out
.
float
()
for
out
in
out_list
]
self
.
assertTrue
(
all
([
torch
.
allclose
(
out
,
self
.
ref
.
to
(
out
.
dtype
))
for
out
in
out_list
]),
msg
=
"{} {} {} {} {} {} {}"
.
format
(
sizea
,
sizeb
,
repeat_tensors
,
x_type
,
y_type
,
out_type
,
inplace
))
self
.
assertTrue
(
self
.
overflow_buf
.
item
()
==
0
,
...
...
@@ -119,9 +122,9 @@ class TestMultiTensorAxpby(unittest.TestCase):
for
sizea
,
sizeb
in
input_size_pairs
:
for
applier
in
appliers
:
for
repeat
in
repeat_tensors
:
for
x_type
in
(
torch
.
float32
,
torch
.
float16
):
for
y_type
in
(
torch
.
float32
,
torch
.
float16
):
for
out_type
in
(
torch
.
float32
,
torch
.
float16
):
for
x_type
in
(
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
):
for
y_type
in
(
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
):
for
out_type
in
(
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
):
for
inplace
in
(
True
,
False
):
if
inplace
is
True
and
(
y_type
is
not
out_type
):
continue
...
...
tests/L0/run_amp/test_multi_tensor_scale.py
View file @
32157739
...
...
@@ -49,7 +49,10 @@ class TestMultiTensorScale(unittest.TestCase):
applier
(
multi_tensor_scale
,
self
.
overflow_buf
,
[
in_list
,
out_list
],
1.
/
self
.
scale
)
self
.
assertTrue
(
all
([
torch
.
allclose
(
out
,
self
.
ref
.
to
(
out_type
))
for
out
in
out_list
]))
# TODO: Remove this workaround for bfloat16 after torch.allcose() support bfloat16
if
out_type
==
torch
.
bfloat16
:
out_list
=
[
out
.
float
()
for
out
in
out_list
]
self
.
assertTrue
(
all
([
torch
.
allclose
(
out
,
self
.
ref
.
to
(
out
.
dtype
))
for
out
in
out_list
]))
self
.
assertTrue
(
self
.
overflow_buf
.
item
()
==
0
)
def
find_inf
(
self
,
sizea
,
sizeb
,
applier
,
repeat_tensors
,
in_type
,
out_type
,
t
,
ind
,
val
,
inplace
=
False
):
...
...
@@ -106,8 +109,8 @@ class TestMultiTensorScale(unittest.TestCase):
for
sizea
,
sizeb
in
input_size_pairs
:
for
applier
in
appliers
:
for
repeat
in
repeat_tensors
:
for
in_type
in
(
torch
.
float32
,
torch
.
float16
):
for
out_type
in
(
torch
.
float32
,
torch
.
float16
):
for
in_type
in
(
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
):
for
out_type
in
(
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
):
for
inplace
in
(
True
,
False
):
if
inplace
is
True
and
(
out_type
is
not
in_type
):
continue
...
...
tests/L0/run_amp/test_promotion.py
View file @
32157739
...
...
@@ -7,18 +7,18 @@ import torch
from
torch
import
nn
import
torch.nn.functional
as
F
from
utils
import
common_init
,
HALF
,
FLOAT
,
DTYPES
from
utils
import
common_init
,
HALF
,
FLOAT
,
DTYPES
,
DTYPES2
,
MATCH_INPUT
class
TestPromotion
(
unittest
.
TestCase
):
def
setUp
(
self
):
self
.
handl
e
=
amp
.
init
(
enabled
=
True
)
common_init
(
self
)
def
tearDown
(
self
):
s
el
f
.
handle
.
_deactivate
()
def
run_binary_promote_test
(
self
,
fns
,
input_shape
,
x_inplace
=
False
):
type_pairs
=
it
.
product
(
DTYPES
,
DTYPES
)
class
_
TestPromotion
(
unittest
.
TestCase
):
def
run_binary_promote_test
(
self
,
fns
,
input_shape
,
lp_type
,
x_inplace
=
False
):
if
lp_typ
e
=
=
torch
.
half
:
dtypes
=
DTYPES
elif
lp_type
==
torch
.
bfloat16
:
dtypes
=
DTYPES2
el
se
:
raise
RuntimeError
(
"Creating test class with invalid low_precision type.
\
Supported types are torch.half and torch.bfloat16"
)
type_pairs
=
it
.
product
(
dtypes
,
dtypes
)
for
fn
,
(
xtype
,
ytype
)
in
it
.
product
(
fns
,
type_pairs
):
x
=
torch
.
randn
(
input_shape
,
dtype
=
xtype
).
requires_grad_
()
x_leaf
=
x
...
...
@@ -35,41 +35,78 @@ class TestPromotion(unittest.TestCase):
if
xtype
==
torch
.
float
or
ytype
==
torch
.
float
:
self
.
assertEqual
(
out
.
type
(),
FLOAT
)
else
:
self
.
assertEqual
(
out
.
type
(),
HALF
)
self
.
assertEqual
(
out
.
type
(),
MATCH_INPUT
[
lp_type
]
)
out
.
float
().
sum
().
backward
()
self
.
assertEqual
(
x_leaf
.
grad
.
dtype
,
xtype
)
def
_test_cat_matches_widest
(
self
,
lp_type
):
shape
=
self
.
b
ys
=
[
torch
.
randn
(
shape
,
dtype
=
lp_type
)
for
_
in
range
(
5
)]
x_float
=
torch
.
randn
(
shape
)
out
=
torch
.
cat
(
ys
+
[
x_float
])
self
.
assertEqual
(
out
.
type
(),
FLOAT
)
x_lp
=
torch
.
randn
(
shape
,
dtype
=
lp_type
)
out
=
torch
.
cat
(
ys
+
[
x_lp
])
self
.
assertEqual
(
out
.
type
(),
MATCH_INPUT
[
lp_type
])
def
_test_inplace_exp_is_error_for_lp
(
self
,
lp_type
):
xs
=
torch
.
randn
(
self
.
b
)
xs
.
exp_
()
self
.
assertEqual
(
xs
.
type
(),
FLOAT
)
xs
=
torch
.
randn
(
self
.
b
,
dtype
=
lp_type
)
with
self
.
assertRaises
(
NotImplementedError
):
xs
.
exp_
()
class
TestPromotionHalf
(
_TestPromotion
):
def
setUp
(
self
):
self
.
handle
=
amp
.
init
(
enabled
=
True
,
patch_type
=
torch
.
half
)
common_init
(
self
)
def
tearDown
(
self
):
self
.
handle
.
_deactivate
()
def
test_atan2_matches_widest
(
self
):
fns
=
[
lambda
x
,
y
:
torch
.
atan2
(
x
,
y
),
lambda
x
,
y
:
x
.
atan2
(
y
)]
self
.
run_binary_promote_test
(
fns
,
(
self
.
b
,))
self
.
run_binary_promote_test
(
fns
,
(
self
.
b
,)
,
torch
.
half
)
def
test_mul_matches_widest
(
self
):
fns
=
[
lambda
x
,
y
:
torch
.
mul
(
x
,
y
),
lambda
x
,
y
:
x
.
mul
(
y
)]
self
.
run_binary_promote_test
(
fns
,
(
self
.
b
,))
self
.
run_binary_promote_test
(
fns
,
(
self
.
b
,)
,
torch
.
half
)
def
test_cat_matches_widest
(
self
):
shape
=
self
.
b
ys
=
[
torch
.
randn
(
shape
,
dtype
=
torch
.
half
)
for
_
in
range
(
5
)]
x_float
=
torch
.
randn
(
shape
)
out
=
torch
.
cat
(
ys
+
[
x_float
])
self
.
assertEqual
(
out
.
type
(),
FLOAT
)
x_half
=
torch
.
randn
(
shape
,
dtype
=
torch
.
half
)
out
=
torch
.
cat
(
ys
+
[
x_half
])
self
.
assertEqual
(
out
.
type
(),
HALF
)
self
.
_test_cat_matches_widest
(
torch
.
half
)
def
test_inplace_exp_is_error_for_half
(
self
):
xs
=
torch
.
randn
(
self
.
b
)
xs
.
exp_
()
self
.
assertEqual
(
xs
.
type
(),
FLOAT
)
xs
=
torch
.
randn
(
self
.
b
,
dtype
=
torch
.
half
)
with
self
.
assertRaises
(
NotImplementedError
):
xs
.
exp_
()
self
.
_test_inplace_exp_is_error_for_lp
(
torch
.
half
)
def
test_inplace_add_matches_self
(
self
):
fn
=
lambda
x
,
y
:
x
.
add_
(
y
)
self
.
run_binary_promote_test
([
fn
],
(
self
.
b
,),
torch
.
half
,
x_inplace
=
True
)
class
TestPromotionBFloat16
(
_TestPromotion
):
def
setUp
(
self
):
self
.
handle
=
amp
.
init
(
enabled
=
True
,
patch_type
=
torch
.
bfloat16
)
common_init
(
self
)
def
tearDown
(
self
):
self
.
handle
.
_deactivate
()
def
test_mul_matches_widest
(
self
):
fns
=
[
lambda
x
,
y
:
torch
.
mul
(
x
,
y
),
lambda
x
,
y
:
x
.
mul
(
y
)]
self
.
run_binary_promote_test
(
fns
,
(
self
.
b
,),
torch
.
bfloat16
)
def
test_cat_matches_widest
(
self
):
self
.
_test_cat_matches_widest
(
torch
.
bfloat16
)
def
test_inplace_exp_is_error_for_bfloat16
(
self
):
self
.
_test_inplace_exp_is_error_for_lp
(
torch
.
bfloat16
)
def
test_inplace_add_matches_self
(
self
):
fn
=
lambda
x
,
y
:
x
.
add_
(
y
)
self
.
run_binary_promote_test
([
fn
],
(
self
.
b
,),
x_inplace
=
True
)
self
.
run_binary_promote_test
([
fn
],
(
self
.
b
,),
torch
.
bfloat16
,
x_inplace
=
True
)
if
__name__
==
'__main__'
:
unittest
.
main
()
tests/L0/run_amp/utils.py
View file @
32157739
...
...
@@ -6,6 +6,8 @@ BFLOAT16 = 'torch.cuda.BFloat16Tensor'
DTYPES
=
[
torch
.
half
,
torch
.
float
]
DTYPES2
=
[
torch
.
bfloat16
,
torch
.
float
]
ALWAYS_HALF
=
{
torch
.
float
:
HALF
,
torch
.
half
:
HALF
}
ALWAYS_BFLOAT16
=
{
torch
.
bfloat16
:
BFLOAT16
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment