Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
apex
Commits
81c788f0
Commit
81c788f0
authored
Aug 27, 2018
by
Carl Case
Browse files
WIP: promotion tests
parent
2e69d933
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
148 additions
and
35 deletions
+148
-35
apex/amp/test/test_basic_casts.py
apex/amp/test/test_basic_casts.py
+70
-35
apex/amp/test/test_promotion.py
apex/amp/test/test_promotion.py
+58
-0
apex/amp/test/utils.py
apex/amp/test/utils.py
+20
-0
No files found.
apex/amp/test/test_basic_casts.py
View file @
81c788f0
...
...
@@ -8,58 +8,44 @@ import torch
from
torch
import
nn
import
torch.nn.functional
as
F
HALF
=
'torch.cuda.HalfTensor'
FLOAT
=
'torch.cuda.FloatTensor'
ALWAYS_HALF
=
{
torch
.
float
:
HALF
,
torch
.
half
:
HALF
}
ALWAYS_FLOAT
=
{
torch
.
float
:
FLOAT
,
torch
.
half
:
FLOAT
}
MATCH_INPUT
=
{
torch
.
float
:
FLOAT
,
torch
.
half
:
HALF
}
def
_common_init
(
test_case
):
test_case
.
h
=
64
test_case
.
b
=
16
test_case
.
c
=
16
test_case
.
k
=
3
torch
.
set_default_tensor_type
(
torch
.
cuda
.
FloatTensor
)
from
.utils
import
common_init
,
HALF
,
FLOAT
,
\
ALWAYS_HALF
,
ALWAYS_FLOAT
,
MATCH_INPUT
def
run_layer_test
(
test_case
,
fns
,
expected
,
input_shape
,
test_backward
=
True
):
for
fn
,
typ
in
it
.
product
(
fns
,
expected
.
keys
()):
x
=
torch
.
randn
(
input_shape
,
dtype
=
typ
).
requires_grad_
()
y
=
fn
(
x
)
test_case
.
assertEqual
(
y
.
type
(),
expected
[
typ
])
if
test_backward
:
y
.
float
().
sum
().
backward
()
test_case
.
assertEqual
(
x
.
grad
.
type
(),
MATCH_INPUT
[
typ
])
class
TestBasicCasts
(
unittest
.
TestCase
):
def
setUp
(
self
):
self
.
handle
=
amp
.
init
(
enabled
=
True
)
_
common_init
(
self
)
common_init
(
self
)
def
tearDown
(
self
):
self
.
handle
.
_deactivate
()
def
run_layer_test
(
self
,
fns
,
expected
,
input_shape
,
test_backward
=
True
):
for
fn
,
typ
in
it
.
product
(
fns
,
expected
.
keys
()):
x
=
torch
.
randn
(
input_shape
,
dtype
=
typ
).
requires_grad_
()
y
=
fn
(
x
)
self
.
assertEqual
(
y
.
type
(),
expected
[
typ
])
if
test_backward
:
y
.
float
().
sum
().
backward
()
self
.
assertEqual
(
x
.
grad
.
type
(),
MATCH_INPUT
[
typ
])
def
test_linear_is_half
(
self
):
m
=
nn
.
Linear
(
self
.
h
,
self
.
h
)
f
=
ft
.
partial
(
F
.
linear
,
weight
=
m
.
weight
,
bias
=
m
.
bias
)
self
.
run_layer_test
([
m
,
f
],
ALWAYS_HALF
,
(
self
.
b
,
self
.
h
))
run_layer_test
(
self
,
[
m
,
f
],
ALWAYS_HALF
,
(
self
.
b
,
self
.
h
))
def
test_conv2d_is_half
(
self
):
m
=
nn
.
Conv2d
(
self
.
c
,
self
.
c
,
self
.
k
)
f
=
ft
.
partial
(
F
.
conv2d
,
weight
=
m
.
weight
,
bias
=
m
.
bias
)
self
.
run_layer_test
([
m
,
f
],
ALWAYS_HALF
,
(
self
.
b
,
self
.
c
,
self
.
h
,
self
.
h
))
run_layer_test
(
self
,
[
m
,
f
],
ALWAYS_HALF
,
(
self
.
b
,
self
.
c
,
self
.
h
,
self
.
h
))
def
test_softmax_is_float
(
self
):
m
=
nn
.
Softmax
(
dim
=
1
)
f
=
ft
.
partial
(
F
.
softmax
,
dim
=
1
)
self
.
run_layer_test
([
m
,
f
],
ALWAYS_FLOAT
,
(
self
.
b
,
self
.
h
))
run_layer_test
(
self
,
[
m
,
f
],
ALWAYS_FLOAT
,
(
self
.
b
,
self
.
h
))
def
test_group_norm_is_float
(
self
):
m
=
nn
.
GroupNorm
(
num_groups
=
4
,
num_channels
=
self
.
c
)
self
.
run_layer_test
([
m
],
ALWAYS_FLOAT
,
(
self
.
b
,
self
.
c
,
self
.
h
,
self
.
h
))
run_layer_test
(
self
,
[
m
],
ALWAYS_FLOAT
,
(
self
.
b
,
self
.
c
,
self
.
h
,
self
.
h
))
def
test_mse_loss_is_float
(
self
):
shape
=
(
self
.
b
,
self
.
h
)
...
...
@@ -67,27 +53,76 @@ class TestBasicCasts(unittest.TestCase):
mod
=
nn
.
MSELoss
()
m
=
lambda
x
:
mod
(
x
,
target
)
f
=
ft
.
partial
(
F
.
mse_loss
,
target
=
target
)
self
.
run_layer_test
([
m
],
ALWAYS_FLOAT
,
shape
)
run_layer_test
(
self
,
[
m
],
ALWAYS_FLOAT
,
shape
)
def
test_relu_is_match
(
self
):
self
.
run_layer_test
([
nn
.
ReLU
(),
F
.
relu
],
MATCH_INPUT
,
(
self
.
b
,
self
.
h
))
run_layer_test
(
self
,
[
nn
.
ReLU
(),
F
.
relu
],
MATCH_INPUT
,
(
self
.
b
,
self
.
h
))
def
test_batch_norm_is_match
(
self
):
m
=
nn
.
BatchNorm2d
(
num_features
=
self
.
c
)
f
=
ft
.
partial
(
F
.
batch_norm
,
running_mean
=
m
.
running_mean
,
running_var
=
m
.
running_var
,
weight
=
m
.
weight
,
bias
=
m
.
bias
,
training
=
True
)
self
.
run_layer_test
([
m
],
MATCH_INPUT
,
(
self
.
b
,
self
.
c
,
self
.
h
,
self
.
h
))
run_layer_test
(
self
,
[
m
],
MATCH_INPUT
,
(
self
.
b
,
self
.
c
,
self
.
h
,
self
.
h
))
# Test forward-only for BN inference
m
.
eval
()
f
=
ft
.
partial
(
F
.
batch_norm
,
running_mean
=
m
.
running_mean
,
running_var
=
m
.
running_var
,
weight
=
m
.
weight
,
bias
=
m
.
bias
,
training
=
False
)
self
.
run_layer_test
([
m
,
f
],
MATCH_INPUT
,
(
self
.
b
,
self
.
c
,
self
.
h
,
self
.
h
),
test_backward
=
False
)
run_layer_test
(
self
,
[
m
,
f
],
MATCH_INPUT
,
(
self
.
b
,
self
.
c
,
self
.
h
,
self
.
h
),
test_backward
=
False
)
def
test_bce_raises
(
self
):
shape
=
(
self
.
b
,
self
.
h
)
target
=
torch
.
randn
(
shape
)
mod
=
nn
.
BCELoss
()
m
=
lambda
x
:
mod
(
x
,
target
)
f
=
ft
.
partial
(
F
.
binary_cross_entropy
,
target
=
target
)
for
fn
in
[
m
,
f
]:
x
=
torch
.
randn
(
shape
,
dtype
=
torch
.
half
)
self
.
assertRaises
(
NotImplementedError
,
fn
,
x
)
class
TestTensorCasts
(
unittest
.
TestCase
):
def
setUp
(
self
):
self
.
handle
=
amp
.
init
(
enabled
=
True
)
common_init
(
self
)
def
tearDown
(
self
):
self
.
handle
.
_deactivate
()
def
test_matmul_method_is_half
(
self
):
other
=
torch
.
randn
(
self
.
h
,
self
.
h
)
lhs
=
lambda
x
:
x
.
matmul
(
other
)
rhs
=
lambda
x
:
other
.
matmul
(
x
)
run_layer_test
(
self
,
[
lhs
,
rhs
],
ALWAYS_HALF
,
(
self
.
h
,
self
.
h
))
def
test_matmul_op_is_half
(
self
):
other
=
torch
.
randn
(
self
.
h
,
self
.
h
)
lhs
=
lambda
x
:
x
@
other
rhs
=
lambda
x
:
other
@
x
run_layer_test
(
self
,
[
lhs
,
rhs
],
ALWAYS_HALF
,
(
self
.
h
,
self
.
h
))
def
test_pow_method_is_float
(
self
):
fn
=
lambda
x
:
x
.
pow
(
2.
)
run_layer_test
(
self
,
[
fn
],
ALWAYS_FLOAT
,
(
self
.
b
,
self
.
h
))
def
test_pow_op_is_float
(
self
):
fn
=
lambda
x
:
x
**
2.
run_layer_test
(
self
,
[
fn
],
ALWAYS_FLOAT
,
(
self
.
b
,
self
.
h
))
def
test_cpu_is_float
(
self
):
fn
=
lambda
x
:
x
.
cpu
()
always_cpu_float
=
{
torch
.
float
:
'torch.FloatTensor'
,
torch
.
half
:
'torch.FloatTensor'
}
run_layer_test
(
self
,
[
fn
],
always_cpu_float
,
(
self
.
b
,
self
.
h
))
def
test_sum_is_float
(
self
):
fn
=
lambda
x
:
x
.
sum
()
run_layer_test
(
self
,
[
fn
],
ALWAYS_FLOAT
,
(
self
.
b
,
self
.
h
))
class
TestDisabledCasts
(
unittest
.
TestCase
):
def
setUp
(
self
):
self
.
handle
=
amp
.
init
(
enabled
=
False
)
_
common_init
(
self
)
common_init
(
self
)
def
test_disabled_linear
(
self
):
m
=
nn
.
Linear
(
self
.
h
,
self
.
h
)
...
...
apex/amp/test/test_promotion.py
0 → 100644
View file @
81c788f0
import
unittest
import
itertools
as
it
from
apex
import
amp
import
torch
from
torch
import
nn
import
torch.nn.functional
as
F
from
.utils
import
common_init
,
HALF
,
FLOAT
,
DTYPES
class
TestPromotion
(
unittest
.
TestCase
):
def
setUp
(
self
):
self
.
handle
=
amp
.
init
(
enabled
=
True
)
common_init
(
self
)
def
tearDown
(
self
):
self
.
handle
.
_deactivate
()
def
run_binary_promote_test
(
self
,
fns
,
input_shape
):
type_pairs
=
it
.
product
(
DTYPES
,
DTYPES
)
for
fn
,
(
xtype
,
ytype
)
in
it
.
product
(
fns
,
type_pairs
):
x
=
torch
.
randn
(
input_shape
,
dtype
=
xtype
).
requires_grad_
()
y
=
torch
.
randn
(
input_shape
,
dtype
=
ytype
)
out
=
fn
(
x
,
y
)
if
xtype
==
torch
.
float
or
ytype
==
torch
.
float
:
self
.
assertEqual
(
out
.
type
(),
FLOAT
)
else
:
self
.
assertEqual
(
out
.
type
(),
HALF
)
out
.
float
().
sum
().
backward
()
self
.
assertEqual
(
x
.
grad
.
dtype
,
xtype
)
def
test_atan2_matches_widest
(
self
):
fns
=
[
lambda
x
,
y
:
torch
.
atan2
(
x
,
y
),
lambda
x
,
y
:
x
.
atan2
(
y
)]
self
.
run_binary_promote_test
(
fns
,
(
self
.
b
,))
def
test_mul_matches_widest
(
self
):
fns
=
[
lambda
x
,
y
:
torch
.
mul
(
x
,
y
),
lambda
x
,
y
:
x
.
mul
(
y
)]
self
.
run_binary_promote_test
(
fns
,
(
self
.
b
,))
def
test_cat_matches_widest
(
self
):
shape
=
self
.
b
ys
=
[
torch
.
randn
(
shape
,
dtype
=
torch
.
half
)
for
_
in
range
(
5
)]
x_float
=
torch
.
randn
(
shape
)
out
=
torch
.
cat
(
ys
+
[
x_float
])
self
.
assertEqual
(
out
.
type
(),
FLOAT
)
x_half
=
torch
.
randn
(
shape
,
dtype
=
torch
.
half
)
out
=
torch
.
cat
(
ys
+
[
x_half
])
self
.
assertEqual
(
out
.
type
(),
HALF
)
# TODOs:
# In-place methods on fp16 are errors for fp32
# In-place methods match type of self tensor
if
__name__
==
'__main__'
:
unittest
.
main
()
apex/amp/test/utils.py
0 → 100644
View file @
81c788f0
import
torch
HALF
=
'torch.cuda.HalfTensor'
FLOAT
=
'torch.cuda.FloatTensor'
DTYPES
=
[
torch
.
half
,
torch
.
float
]
ALWAYS_HALF
=
{
torch
.
float
:
HALF
,
torch
.
half
:
HALF
}
ALWAYS_FLOAT
=
{
torch
.
float
:
FLOAT
,
torch
.
half
:
FLOAT
}
MATCH_INPUT
=
{
torch
.
float
:
FLOAT
,
torch
.
half
:
HALF
}
def
common_init
(
test_case
):
test_case
.
h
=
64
test_case
.
b
=
16
test_case
.
c
=
16
test_case
.
k
=
3
torch
.
set_default_tensor_type
(
torch
.
cuda
.
FloatTensor
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment