OpenDAS / apex · Commits · 32157739
Commit 32157739, authored May 15, 2020 by rohithkrn
Parent: ba2407e2

add tests for O4 and O5 opt levels

Showing 7 changed files with 130 additions and 52 deletions (+130 / -52).
tests/L0/run_amp/test_add_param_group.py      +18  -7
tests/L0/run_amp/test_cache.py                +29  -8
tests/L0/run_amp/test_checkpointing.py         +2  -1
tests/L0/run_amp/test_multi_tensor_axpby.py    +7  -4
tests/L0/run_amp/test_multi_tensor_scale.py    +6  -3
tests/L0/run_amp/test_promotion.py            +66 -29
tests/L0/run_amp/utils.py                      +2  -0
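The new levels are the bfloat16 counterparts of the existing fp16 modes: wherever the tests below exercise "O4" or "O5", models are built with torch.bfloat16 instead of torch.float16. A minimal usage sketch, assuming this fork's amp accepts the new strings exactly as the updated tests pass them to amp.initialize:

    import torch
    from apex import amp

    model = torch.nn.Linear(4, 4).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=1.0)
    # "O4"/"O5" as bfloat16 analogues of the fp16 levels is an assumption
    # drawn from how the tests construct MyModel(..., torch.bfloat16).
    model, optimizer = amp.initialize(model, optimizer, opt_level="O4", verbosity=0)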
tests/L0/run_amp/test_add_param_group.py

@@ -14,11 +14,11 @@ from utils import common_init, HALF, FLOAT,\
     ALWAYS_HALF, ALWAYS_FLOAT, MATCH_INPUT
 
 class MyModel(torch.nn.Module):
-    def __init__(self, unique):
+    def __init__(self, unique, dtype=torch.float16):
         super(MyModel, self).__init__()
         self.weight0 = Parameter(unique +
             torch.arange(2, device='cuda', dtype=torch.float32))
-        self.weight1 = Parameter(1. + unique + torch.arange(2, device='cuda', dtype=torch.float16))
+        self.weight1 = Parameter(1. + unique + torch.arange(2, device='cuda', dtype=dtype))
 
     @staticmethod
     def ops(input, weight0, weight1):
...
@@ -51,9 +51,13 @@ class TestAddParamGroup(unittest.TestCase):
         optimizer.zero_grad()
 
     def test_add_param_group(self):
-        for opt_level in ("O0", "O1", "O2", "O3"):
+        for opt_level in ("O0", "O1", "O2", "O3", "O4", "O5"):
             for zero_before_add in (True, False):
                 for try_accumulation in (True, False):
-                    model0 = MyModel(1)
-                    model1 = MyModel(2)
+                    if opt_level in {"O4", "O5"}:
+                        model0 = MyModel(1, torch.bfloat16)
+                        model1 = MyModel(2, torch.bfloat16)
+                    else:
+                        model0 = MyModel(1)
+                        model1 = MyModel(2)
...
@@ -89,6 +93,10 @@ class TestAddParamGroup(unittest.TestCase):
             [param.data.clone() for param in model1.parameters()]
 
         for how_to_zero in "none", "model", "optimizer":
-            model0 = MyModel(1)
-            model1 = MyModel(2)
+            if opt_level in {"O4", "O5"}:
+                model0 = MyModel(1, torch.bfloat16)
+                model1 = MyModel(2, torch.bfloat16)
+            else:
+                model0 = MyModel(1)
+                model1 = MyModel(2)
...
@@ -139,6 +147,9 @@ class TestAddParamGroup(unittest.TestCase):
             [param.data.clone() for param in model1.parameters()]
 
         for reference, final in zip(reference_params, final_params):
+            # TODO: remove the conversion once allclose supports bfloat16 type.
+            if final.dtype == torch.bfloat16:
+                final = final.float()
             self.assertTrue(torch.allclose(reference.to(final.dtype), final),
                             "opt_level = {}, how_to_zero = {}, zero_before_add = {}".format(
                             opt_level, how_to_zero, zero_before_add))
...
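The "upcast, then compare" pattern above recurs across several of these files because torch.allclose did not support bfloat16 when this commit was written. A small helper capturing the workaround might look like this (hypothetical; the tests inline the conversion instead):

    import torch

    def allclose_any_dtype(actual, expected, **kwargs):
        # torch.allclose lacked bfloat16 support at the time, so upcast
        # bfloat16 operands to float32 before comparing.
        if actual.dtype == torch.bfloat16:
            actual = actual.float()
        return torch.allclose(actual, expected.to(actual.dtype), **kwargs)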
tests/L0/run_amp/test_cache.py

@@ -67,12 +67,12 @@ class TestCache(unittest.TestCase):
     def tearDown(self):
         pass
 
-    def train_eval_train_test(self, module, t):
+    def train_eval_train_test(self, module, t, opt_level):
         model = module(t).cuda()
         optimizer = torch.optim.SGD(model.parameters(), lr=1.0)
 
         _amp_state.allow_incoming_model_not_fp32 = True
-        model, optimizer = amp.initialize(model, optimizer, opt_level="O1", verbosity=0)
+        model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level, verbosity=0)
         _amp_state.allow_incoming_model_not_fp32 = False
 
         def training_step():
...
@@ -93,6 +93,8 @@ class TestCache(unittest.TestCase):
         # but I'm keeping this in case we want different tolerances for fp16 and fp32 checks.
         if model.weight.grad.type() == "torch.cuda.HalfTensor":
             self.assertTrue(torch.allclose(model.weight.grad.float(), reference_grad))
+        elif model.weight.grad.type() == "torch.cuda.BFloat16Tensor":
+            self.assertTrue(torch.allclose(model.weight.grad.float(), reference_grad))
         elif model.weight.grad.type() == "torch.cuda.FloatTensor":
             self.assertTrue(torch.allclose(model.weight.grad.float(), reference_grad))
         else:
...
@@ -115,22 +117,41 @@ class TestCache(unittest.TestCase):
     # I could easily have these as a set of for loops in a single test,
     # instead of going for granularity.
     def test_whitelist_module_fp16_weight(self):
-        self.train_eval_train_test(WhitelistModule, torch.float16)
+        self.train_eval_train_test(WhitelistModule, torch.float16, "O1")
 
     def test_whitelist_module_fp32_weight(self):
-        self.train_eval_train_test(WhitelistModule, torch.float32)
+        self.train_eval_train_test(WhitelistModule, torch.float32, "O1")
 
     def test_blacklist_module_fp16_weight(self):
-        self.train_eval_train_test(BlacklistModule, torch.float16)
+        self.train_eval_train_test(BlacklistModule, torch.float16, "O1")
 
     def test_blacklist_module_fp32_weight(self):
-        self.train_eval_train_test(BlacklistModule, torch.float32)
+        self.train_eval_train_test(BlacklistModule, torch.float32, "O1")
 
     def test_promote_module_fp16_weight(self):
-        self.train_eval_train_test(PromoteModule, torch.float16)
+        self.train_eval_train_test(PromoteModule, torch.float16, "O1")
+
+    def test_promote_module_fp32_weight(self):
+        self.train_eval_train_test(PromoteModule, torch.float32, "O1")
+
+    # opt_level = O4
+    def test_whitelist_module_bfp16_weight(self):
+        self.train_eval_train_test(WhitelistModule, torch.bfloat16, "O4")
+
+    def test_whitelist_module_fp32_weight(self):
+        self.train_eval_train_test(WhitelistModule, torch.float32, "O4")
+
+    def test_blacklist_module_bfp16_weight(self):
+        self.train_eval_train_test(BlacklistModule, torch.bfloat16, "O4")
+
+    def test_blacklist_module_fp32_weight(self):
+        self.train_eval_train_test(BlacklistModule, torch.float32, "O4")
+
+    def test_promote_module_bfp16_weight(self):
+        self.train_eval_train_test(PromoteModule, torch.bfloat16, "O4")
 
     def test_promote_module_fp32_weight(self):
-        self.train_eval_train_test(PromoteModule, torch.float32)
+        self.train_eval_train_test(PromoteModule, torch.float32, "O4")
 
 if __name__ == '__main__':
...
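One quirk worth noting in the new test_cache.py: test_whitelist_module_fp32_weight, test_blacklist_module_fp32_weight, and test_promote_module_fp32_weight are each defined twice in the class body. Python keeps only the last binding of a name, so the "O1" fp32 variants are silently shadowed by their "O4" counterparts and never run:

    class Demo:
        def method(self):
            return "first"
        def method(self):          # rebinding: this definition wins
            return "second"

    assert Demo().method() == "second"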
tests/L0/run_amp/test_checkpointing.py

@@ -28,7 +28,7 @@ class MyModel(torch.nn.Module):
 class TestCheckpointing(unittest.TestCase):
     def setUp(self):
         self.initial_lr = 1e-3
-        self.test_opt_levels = ("O0", "O1", "O2", "O3")
+        self.test_opt_levels = ("O0", "O1", "O2", "O3", "O4", "O5")
 
     def seed(self):
         torch.manual_seed(2809)
...
@@ -236,6 +236,7 @@ class TestCheckpointing(unittest.TestCase):
         state_dict = model.state_dict()
         for key in state_dict:
             self.assertFalse('Half' in state_dict[key].type())
+            self.assertFalse('BFloat16' in state_dict[key].type())
 
         # Check, if model is still trainable
         # Create dummy data
...
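The new assertion extends the checkpointing contract to the bfloat16 levels: whatever the opt level, a saved state_dict must contain fp32 master weights, never half or bfloat16 copies. For reference, a minimal save-side sketch of the round trip the test exercises, following apex's documented checkpoint pattern (variable names are illustrative):

    checkpoint = {
        'model': model.state_dict(),          # must hold fp32 tensors
        'optimizer': optimizer.state_dict(),
        'amp': amp.state_dict(),              # loss-scaler state
    }
    torch.save(checkpoint, 'amp_checkpoint.pt')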
tests/L0/run_amp/test_multi_tensor_axpby.py

@@ -69,7 +69,10 @@ class TestMultiTensorAxpby(unittest.TestCase):
         applier(multi_tensor_axpby, self.overflow_buf, [x_list, y_list, out_list], self.a, self.b, -1)
 
-        self.assertTrue(all([torch.allclose(out, self.ref.to(out_type)) for out in out_list]),
+        # TODO: Remove this workaround for bfloat16 after torch.allcose() support bfloat16
+        if out_type == torch.bfloat16:
+            out_list = [out.float() for out in out_list]
+        self.assertTrue(all([torch.allclose(out, self.ref.to(out.dtype)) for out in out_list]),
                         msg="{} {} {} {} {} {} {}".format(sizea, sizeb, repeat_tensors,
                                                           x_type, y_type, out_type, inplace))
 
         self.assertTrue(self.overflow_buf.item() == 0,
...
@@ -119,9 +122,9 @@ class TestMultiTensorAxpby(unittest.TestCase):
         for sizea, sizeb in input_size_pairs:
             for applier in appliers:
                 for repeat in repeat_tensors:
-                    for x_type in (torch.float32, torch.float16):
-                        for y_type in (torch.float32, torch.float16):
-                            for out_type in (torch.float32, torch.float16):
+                    for x_type in (torch.float32, torch.float16, torch.bfloat16):
+                        for y_type in (torch.float32, torch.float16, torch.bfloat16):
+                            for out_type in (torch.float32, torch.float16, torch.bfloat16):
                                 for inplace in (True, False):
                                     if inplace is True and (y_type is not out_type):
                                         continue
...
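Adding bfloat16 to each loop grows the dtype grid from 2x2x2 to 3x3x3 combinations per size/applier/repeat setting, with the in-place guard pruning cases where the output cannot alias y. The same enumeration written with itertools (an equivalent sketch, not the test's actual code):

    import itertools as it
    import torch

    DTYPE_SET = (torch.float32, torch.float16, torch.bfloat16)
    for x_type, y_type, out_type in it.product(DTYPE_SET, repeat=3):
        for inplace in (True, False):
            # in-place axpby writes into y, so its dtype must match out's
            if inplace and y_type is not out_type:
                continue
            # ... run one axpby case ...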
tests/L0/run_amp/test_multi_tensor_scale.py

@@ -49,7 +49,10 @@ class TestMultiTensorScale(unittest.TestCase):
         applier(multi_tensor_scale, self.overflow_buf, [in_list, out_list], 1./self.scale)
 
-        self.assertTrue(all([torch.allclose(out, self.ref.to(out_type)) for out in out_list]))
+        # TODO: Remove this workaround for bfloat16 after torch.allcose() support bfloat16
+        if out_type == torch.bfloat16:
+            out_list = [out.float() for out in out_list]
+        self.assertTrue(all([torch.allclose(out, self.ref.to(out.dtype)) for out in out_list]))
         self.assertTrue(self.overflow_buf.item() == 0)
 
     def find_inf(self, sizea, sizeb, applier, repeat_tensors, in_type, out_type, t, ind, val, inplace=False):
...
@@ -106,8 +109,8 @@ class TestMultiTensorScale(unittest.TestCase):
         for sizea, sizeb in input_size_pairs:
             for applier in appliers:
                 for repeat in repeat_tensors:
-                    for in_type in (torch.float32, torch.float16):
-                        for out_type in (torch.float32, torch.float16):
+                    for in_type in (torch.float32, torch.float16, torch.bfloat16):
+                        for out_type in (torch.float32, torch.float16, torch.bfloat16):
                             for inplace in (True, False):
                                 if inplace is True and (out_type is not in_type):
                                     continue
...
tests/L0/run_amp/test_promotion.py

@@ -7,18 +7,18 @@ import torch
 from torch import nn
 import torch.nn.functional as F
 
-from utils import common_init, HALF, FLOAT, DTYPES
+from utils import common_init, HALF, FLOAT, DTYPES, DTYPES2, MATCH_INPUT
 
-class TestPromotion(unittest.TestCase):
-    def setUp(self):
-        self.handle = amp.init(enabled=True)
-        common_init(self)
-
-    def tearDown(self):
-        self.handle._deactivate()
-
-    def run_binary_promote_test(self, fns, input_shape, x_inplace=False):
-        type_pairs = it.product(DTYPES, DTYPES)
+class _TestPromotion(unittest.TestCase):
+    def run_binary_promote_test(self, fns, input_shape, lp_type, x_inplace=False):
+        if lp_type == torch.half:
+            dtypes = DTYPES
+        elif lp_type == torch.bfloat16:
+            dtypes = DTYPES2
+        else:
+            raise RuntimeError("Creating test class with invalid low_precision type.\
+                Supported types are torch.half and torch.bfloat16")
+        type_pairs = it.product(dtypes, dtypes)
         for fn, (xtype, ytype) in it.product(fns, type_pairs):
             x = torch.randn(input_shape, dtype=xtype).requires_grad_()
             x_leaf = x
...
@@ -35,41 +35,78 @@ class TestPromotion(unittest.TestCase):
             if xtype == torch.float or ytype == torch.float:
                 self.assertEqual(out.type(), FLOAT)
             else:
-                self.assertEqual(out.type(), HALF)
+                self.assertEqual(out.type(), MATCH_INPUT[lp_type])
             out.float().sum().backward()
             self.assertEqual(x_leaf.grad.dtype, xtype)
 
+    def _test_cat_matches_widest(self, lp_type):
+        shape = self.b
+        ys = [torch.randn(shape, dtype=lp_type) for _ in range(5)]
+        x_float = torch.randn(shape)
+        out = torch.cat(ys + [x_float])
+        self.assertEqual(out.type(), FLOAT)
+        x_lp = torch.randn(shape, dtype=lp_type)
+        out = torch.cat(ys + [x_lp])
+        self.assertEqual(out.type(), MATCH_INPUT[lp_type])
+
+    def _test_inplace_exp_is_error_for_lp(self, lp_type):
+        xs = torch.randn(self.b)
+        xs.exp_()
+        self.assertEqual(xs.type(), FLOAT)
+        xs = torch.randn(self.b, dtype=lp_type)
+        with self.assertRaises(NotImplementedError):
+            xs.exp_()
+
+class TestPromotionHalf(_TestPromotion):
+    def setUp(self):
+        self.handle = amp.init(enabled=True, patch_type=torch.half)
+        common_init(self)
+
+    def tearDown(self):
+        self.handle._deactivate()
+
     def test_atan2_matches_widest(self):
         fns = [lambda x, y : torch.atan2(x, y),
                lambda x, y : x.atan2(y)]
-        self.run_binary_promote_test(fns, (self.b,))
+        self.run_binary_promote_test(fns, (self.b,), torch.half)
 
     def test_mul_matches_widest(self):
         fns = [lambda x, y : torch.mul(x, y),
                lambda x, y : x.mul(y)]
-        self.run_binary_promote_test(fns, (self.b,))
+        self.run_binary_promote_test(fns, (self.b,), torch.half)
 
     def test_cat_matches_widest(self):
-        shape = self.b
-        ys = [torch.randn(shape, dtype=torch.half) for _ in range(5)]
-        x_float = torch.randn(shape)
-        out = torch.cat(ys + [x_float])
-        self.assertEqual(out.type(), FLOAT)
-        x_half = torch.randn(shape, dtype=torch.half)
-        out = torch.cat(ys + [x_half])
-        self.assertEqual(out.type(), HALF)
+        self._test_cat_matches_widest(torch.half)
 
     def test_inplace_exp_is_error_for_half(self):
-        xs = torch.randn(self.b)
-        xs.exp_()
-        self.assertEqual(xs.type(), FLOAT)
-        xs = torch.randn(self.b, dtype=torch.half)
-        with self.assertRaises(NotImplementedError):
-            xs.exp_()
+        self._test_inplace_exp_is_error_for_lp(torch.half)
 
     def test_inplace_add_matches_self(self):
         fn = lambda x, y: x.add_(y)
-        self.run_binary_promote_test([fn], (self.b,), x_inplace=True)
+        self.run_binary_promote_test([fn], (self.b,), torch.half, x_inplace=True)
+
+class TestPromotionBFloat16(_TestPromotion):
+    def setUp(self):
+        self.handle = amp.init(enabled=True, patch_type=torch.bfloat16)
+        common_init(self)
+
+    def tearDown(self):
+        self.handle._deactivate()
+
+    def test_mul_matches_widest(self):
+        fns = [lambda x, y : torch.mul(x, y),
+               lambda x, y : x.mul(y)]
+        self.run_binary_promote_test(fns, (self.b,), torch.bfloat16)
+
+    def test_cat_matches_widest(self):
+        self._test_cat_matches_widest(torch.bfloat16)
+
+    def test_inplace_exp_is_error_for_bfloat16(self):
+        self._test_inplace_exp_is_error_for_lp(torch.bfloat16)
+
+    def test_inplace_add_matches_self(self):
+        fn = lambda x, y: x.add_(y)
+        self.run_binary_promote_test([fn], (self.b,), torch.bfloat16, x_inplace=True)
 
 if __name__ == '__main__':
     unittest.main()
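The refactor converts TestPromotion into a dtype-agnostic base class with one concrete subclass per low-precision type. The base carries no "test_"-prefixed methods, so unittest's default discovery collects cases only from the subclasses; a stripped-down sketch of the pattern:

    import unittest

    class _Base(unittest.TestCase):
        # shared helper: the name avoids the "test" prefix, so the runner
        # never invokes it directly on the base class
        def _check(self, dtype_name):
            self.assertIn(dtype_name, ("half", "bfloat16"))

    class TestHalf(_Base):
        def test_check(self):
            self._check("half")

    class TestBFloat16(_Base):
        def test_check(self):
            self._check("bfloat16")

    if __name__ == '__main__':
        unittest.main()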
tests/L0/run_amp/utils.py

@@ -6,6 +6,8 @@ BFLOAT16 = 'torch.cuda.BFloat16Tensor'
 DTYPES = [torch.half, torch.float]
+DTYPES2 = [torch.bfloat16, torch.float]
+
 ALWAYS_HALF = {torch.float: HALF,
                torch.half:  HALF}
 ALWAYS_BFLOAT16 = {torch.bfloat16: BFLOAT16,
...
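DTYPES2 is the bfloat16 analogue of DTYPES, feeding the dtype tables that the other tests index by low-precision type. The visible constants suggest the following shape for those tables; MATCH_INPUT is imported by test_promotion.py and test_add_param_group.py but its entries are not shown in this diff, so this reconstruction is an assumption:

    import torch

    HALF = 'torch.cuda.HalfTensor'
    FLOAT = 'torch.cuda.FloatTensor'
    BFLOAT16 = 'torch.cuda.BFloat16Tensor'

    # assumed contents: maps an input dtype to the tensor type string that
    # out.type() should report when the output matches the input
    MATCH_INPUT = {torch.float: FLOAT,
                   torch.half: HALF,
                   torch.bfloat16: BFLOAT16}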