OpenDAS / apex · Commit d283f97f
Authored May 12, 2020 by rohithkrn (parent 69251362)

add bfloat16 tests in test_basic_casts
Showing 2 changed files, with 163 additions and 50 deletions:

    tests/L0/run_amp/test_basic_casts.py    +158 −49
    tests/L0/run_amp/utils.py               +5 −1
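In outline, the commit turns the fp16-only test classes into dtype-parameterized base classes (_TestBasicCasts, _TestTensorCasts) whose _test_* helpers take an expected-type table, and adds Half and BFloat16 subclasses that initialize amp with the matching patch_type. A condensed sketch of that pattern, with test bodies elided (patch_type is the amp.init keyword this branch introduces; see the full diff below for the real tests):

    import unittest
    import torch
    from apex import amp
    # expected-type tables and common_init come from tests/L0/run_amp/utils.py
    from utils import common_init, ALWAYS_HALF, ALWAYS_BFLOAT16

    class _TestBasicCasts(unittest.TestCase):
        def _test_linear(self, expected):
            ...  # build nn.Linear and check output types against the expected table

    class TestBasicCastsHalf(_TestBasicCasts):
        def setUp(self):
            self.handle = amp.init(enabled=True, patch_type=torch.half)
            common_init(self)

        def test_linear_is_half(self):
            self._test_linear(ALWAYS_HALF)

    class TestBasicCastsBFloat16(_TestBasicCasts):
        def setUp(self):
            self.handle = amp.init(enabled=True, patch_type=torch.bfloat16)
            common_init(self)

        def test_linear_is_bfloat16(self):
            self._test_linear(ALWAYS_BFLOAT16)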
tests/L0/run_amp/test_basic_casts.py

@@ -9,7 +9,7 @@ from torch import nn
 import torch.nn.functional as F
 
 from utils import common_init, HALF, FLOAT,\
-    ALWAYS_HALF, ALWAYS_FLOAT, MATCH_INPUT
+    ALWAYS_HALF, ALWAYS_BFLOAT16, ALWAYS_FLOAT, MATCH_INPUT
 
 def run_layer_test(test_case, fns, expected, input_shape, test_backward=True):
     for fn, typ in it.product(fns, expected.keys()):
@@ -20,124 +20,233 @@ def run_layer_test(test_case, fns, expected, input_shape, test_backward=True):
             y.float().sum().backward()
             test_case.assertEqual(x.grad.type(), MATCH_INPUT[typ])
 
-class TestBasicCasts(unittest.TestCase):
-    def setUp(self):
-        self.handle = amp.init(enabled=True)
-        common_init(self)
-
-    def tearDown(self):
-        self.handle._deactivate()
-
-    def test_linear_is_half(self):
+class _TestBasicCasts(unittest.TestCase):
+    def _test_linear(self, expected):
         m = nn.Linear(self.h, self.h)
         f = ft.partial(F.linear, weight=m.weight, bias=m.bias)
-        run_layer_test(self, [m, f], ALWAYS_HALF, (self.b, self.h))
+        run_layer_test(self, [m, f], expected, (self.b, self.h))
 
-    def test_conv2d_is_half(self):
+    def _test_conv2d(self, expected):
         m = nn.Conv2d(self.c, self.c, self.k)
         f = ft.partial(F.conv2d, weight=m.weight, bias=m.bias)
-        run_layer_test(self, [m, f], ALWAYS_HALF, (self.b, self.c, self.h, self.h))
+        run_layer_test(self, [m, f], expected, (self.b, self.c, self.h, self.h))
 
-    def test_softmax_is_float(self):
+    def _test_softmax(self, expected):
         m = nn.Softmax(dim=1)
         f = ft.partial(F.softmax, dim=1)
-        run_layer_test(self, [m, f], ALWAYS_FLOAT, (self.b, self.h))
+        run_layer_test(self, [m, f], expected, (self.b, self.h))
 
-    def test_group_norm_is_float(self):
+    def _test_group_norm(self, expected):
         m = nn.GroupNorm(num_groups=4, num_channels=self.c)
-        run_layer_test(self, [m], ALWAYS_FLOAT, (self.b, self.c, self.h, self.h))
+        run_layer_test(self, [m], expected, (self.b, self.c, self.h, self.h))
 
-    def test_mse_loss_is_float(self):
+    def _test_mse_loss(self, expected):
         shape = (self.b, self.h)
         target = torch.randn(shape)
         mod = nn.MSELoss()
         m = lambda x: mod(x, target)
         f = ft.partial(F.mse_loss, target=target)
-        run_layer_test(self, [m], ALWAYS_FLOAT, shape)
+        run_layer_test(self, [m], expected, shape)
 
-    def test_relu_is_match(self):
-        run_layer_test(self, [nn.ReLU(), F.relu], MATCH_INPUT, (self.b, self.h))
+    def _test_relu(self, expected):
+        run_layer_test(self, [nn.ReLU(), F.relu], expected, (self.b, self.h))
 
-    def test_batch_norm_is_match(self):
+    def _test_batch_norm(self, expected):
         m = nn.BatchNorm2d(num_features=self.c)
         f = ft.partial(F.batch_norm, running_mean=m.running_mean, running_var=m.running_var,
                        weight=m.weight, bias=m.bias, training=True)
-        run_layer_test(self, [m], MATCH_INPUT, (self.b, self.c, self.h, self.h))
+        run_layer_test(self, [m], expected, (self.b, self.c, self.h, self.h))
 
         # Test forward-only for BN inference
         m.eval()
         f = ft.partial(F.batch_norm, running_mean=m.running_mean, running_var=m.running_var,
                        weight=m.weight, bias=m.bias, training=False)
-        run_layer_test(self, [m, f], MATCH_INPUT, (self.b, self.c, self.h, self.h),
+        run_layer_test(self, [m, f], expected, (self.b, self.c, self.h, self.h),
                        test_backward=False)
 
+class TestBasicCastsHalf(_TestBasicCasts):
+    def setUp(self):
+        self.handle = amp.init(enabled=True, patch_type=torch.half)
+        common_init(self)
+
+    def tearDown(self):
+        self.handle._deactivate()
+
+    def test_linear_is_half(self):
+        self._test_linear(ALWAYS_HALF)
+
+    def test_conv2d_is_half(self):
+        self._test_conv2d(ALWAYS_HALF)
+
+    def test_softmax_is_float(self):
+        self._test_softmax(ALWAYS_FLOAT)
+
+    def test_group_norm_is_float(self):
+        self._test_group_norm(ALWAYS_FLOAT)
+
+    def test_mse_loss_is_float(self):
+        self._test_mse_loss(ALWAYS_FLOAT)
+
+    def test_relu_is_match(self):
+        self._test_relu(MATCH_INPUT)
+
+    def test_batch_norm_is_match(self):
+        self._test_batch_norm(MATCH_INPUT)
+
+class TestBasicCastsBFloat16(_TestBasicCasts):
+    def setUp(self):
+        self.handle = amp.init(enabled=True, patch_type=torch.bfloat16)
+        common_init(self)
+
+    def tearDown(self):
+        self.handle._deactivate()
+
+    def test_linear_is_bfloat16(self):
+        self._test_linear(ALWAYS_BFLOAT16)
+
+    def test_conv2d_is_bfloat16(self):
+        self._test_conv2d(ALWAYS_BFLOAT16)
+
+    def test_softmax_is_float(self):
+        self._test_softmax(ALWAYS_FLOAT)
+
+    def test_group_norm_is_float(self):
+        self._test_group_norm(ALWAYS_FLOAT)
+
+    def test_mse_loss_is_float(self):
+        self._test_mse_loss(ALWAYS_FLOAT)
+
+    def test_relu_is_match(self):
+        self._test_relu(MATCH_INPUT)
+
+    def test_batch_norm_is_match(self):
+        self._test_batch_norm(MATCH_INPUT)
 
 class TestBannedMethods(unittest.TestCase):
     def setUp(self):
-        self.handle = amp.init(enabled=True)
+        self.handle = amp.init(enabled=True, patch_type=torch.half)
         common_init(self)
 
     def tearDown(self):
         self.handle._deactivate()
 
-    def bce_common(self, assertion):
+    def bce_common(self, assertion, dtype=torch.half):
         shape = (self.b, self.h)
         target = torch.rand(shape)
         mod = nn.BCELoss()
         m = lambda x: mod(x, target)
         f = ft.partial(F.binary_cross_entropy, target=target)
         for fn in [m, f]:
-            x = torch.rand(shape, dtype=torch.half)
+            x = torch.rand(shape, dtype=dtype)
             assertion(fn, x)
 
     def test_bce_raises_by_default(self):
         assertion = lambda fn, x: self.assertRaises(NotImplementedError, fn, x)
-        self.bce_common(assertion)
+        self.bce_common(assertion, dtype=torch.half)
+
+        # handle with bfloat16 as patch_type
+        self.handle._deactivate()
+        self.handle = amp.init(enabled=True, patch_type=torch.bfloat16)
+        self.bce_common(assertion, dtype=torch.bfloat16)
 
     def test_bce_is_float_with_allow_banned(self):
         self.handle._deactivate()
-        self.handle = amp.init(enabled=True, allow_banned=True)
+        self.handle = amp.init(enabled=True, allow_banned=True, patch_type=torch.half)
         assertion = lambda fn, x: self.assertEqual(fn(x).type(), FLOAT)
-        self.bce_common(assertion)
+        self.bce_common(assertion, dtype=torch.half)
+
+        # handle with bfloat16 as patch_type
+        self.handle._deactivate()
+        self.handle = amp.init(enabled=True, allow_banned=True, patch_type=torch.bfloat16)
+        self.bce_common(assertion, dtype=torch.bfloat16)
 
-class TestTensorCasts(unittest.TestCase):
-    def setUp(self):
-        self.handle = amp.init(enabled=True)
-        common_init(self)
-
-    def tearDown(self):
-        self.handle._deactivate()
-
-    def test_matmul_method_is_half(self):
+class _TestTensorCasts(unittest.TestCase):
+    def _test_matmul_method(self, expected):
         other = torch.randn(self.h, self.h)
         lhs = lambda x: x.matmul(other)
         rhs = lambda x: other.matmul(x)
-        run_layer_test(self, [lhs, rhs], ALWAYS_HALF, (self.h, self.h))
+        run_layer_test(self, [lhs, rhs], expected, (self.h, self.h))
 
-    def test_matmul_op_is_half(self):
+    def _test_matmul_op(self, expected):
         other = torch.randn(self.h, self.h)
         lhs = lambda x: x @ other
         rhs = lambda x: other @ x
-        run_layer_test(self, [lhs, rhs], ALWAYS_HALF, (self.h, self.h))
+        run_layer_test(self, [lhs, rhs], expected, (self.h, self.h))
 
-    def test_pow_method_is_float(self):
+    def _test_pow_method(self, expected):
         fn = lambda x: x.pow(2.)
-        run_layer_test(self, [fn], ALWAYS_FLOAT, (self.b, self.h))
+        run_layer_test(self, [fn], expected, (self.b, self.h))
 
-    def test_pow_op_is_float(self):
+    def _test_pow_op(self, expected):
         fn = lambda x: x ** 2.
-        run_layer_test(self, [fn], ALWAYS_FLOAT, (self.b, self.h))
+        run_layer_test(self, [fn], expected, (self.b, self.h))
 
-    def test_cpu_is_float(self):
+    def _test_cpu(self, expected):
         fn = lambda x: x.cpu()
-        always_cpu_float = {torch.float: 'torch.FloatTensor',
-                            torch.half: 'torch.FloatTensor'}
-        run_layer_test(self, [fn], always_cpu_float, (self.b, self.h))
+        run_layer_test(self, [fn], expected, (self.b, self.h))
 
-    def test_sum_is_float(self):
+    def _test_sum(self, expected):
         fn = lambda x: x.sum()
-        run_layer_test(self, [fn], ALWAYS_FLOAT, (self.b, self.h))
+        run_layer_test(self, [fn], expected, (self.b, self.h))
 
-    # TODO: maybe more tests on disabled casting?
+class TestTensorCastsHalf(_TestTensorCasts):
+    def setUp(self):
+        self.handle = amp.init(enabled=True, patch_type=torch.half)
+        common_init(self)
+
+    def tearDown(self):
+        self.handle._deactivate()
+
+    def test_matmul_method_is_half(self):
+        self._test_matmul_method(ALWAYS_HALF)
+
+    def test_matmul_op_is_half(self):
+        self._test_matmul_op(ALWAYS_HALF)
+
+    def test_pow_method_is_float(self):
+        self._test_pow_method(ALWAYS_FLOAT)
+
+    def test_pow_op_is_float(self):
+        self._test_pow_op(ALWAYS_FLOAT)
+
+    def test_cpu_is_float(self):
+        always_cpu_float = {torch.float: 'torch.FloatTensor',
+                            torch.half: 'torch.FloatTensor'}
+        self._test_cpu(always_cpu_float)
+
+    def test_sum_is_float(self):
+        self._test_sum(ALWAYS_FLOAT)
+
+class TestTensorCastsBFloat16(_TestTensorCasts):
+    def setUp(self):
+        self.handle = amp.init(enabled=True, patch_type=torch.bfloat16)
+        common_init(self)
+
+    def tearDown(self):
+        self.handle._deactivate()
+
+    def test_matmul_method_is_bfloat16(self):
+        self._test_matmul_method(ALWAYS_BFLOAT16)
+
+    def test_matmul_op_is_bfloat16(self):
+        self._test_matmul_op(ALWAYS_BFLOAT16)
+
+    def test_pow_method_is_float(self):
+        self._test_pow_method(ALWAYS_FLOAT)
+
+    def test_pow_op_is_float(self):
+        self._test_pow_op(ALWAYS_FLOAT)
+
+    def test_cpu_is_float(self):
+        always_cpu_float = {torch.float: 'torch.FloatTensor',
+                            torch.bfloat16: 'torch.FloatTensor'}
+        self._test_cpu(always_cpu_float)
+
+    def test_sum_is_float(self):
+        self._test_sum(ALWAYS_FLOAT)
+
+# TODO: maybe more tests on disabled casting?
 
 if __name__ == '__main__':
     unittest.main()
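For context while reading the assertions above: the body of run_layer_test is mostly collapsed in this view, but its signature and the visible tail show that it loops over every (function, input dtype) pair and checks the forward output type and, optionally, the gradient type. A rough sketch under that reading (how the input tensor is constructed here is an assumption, not the file's exact code):

    import itertools as it
    import torch
    from utils import MATCH_INPUT  # input dtype -> expected tensor type string

    def run_layer_test(test_case, fns, expected, input_shape, test_backward=True):
        for fn, typ in it.product(fns, expected.keys()):
            # assumption: a CUDA input of the dtype under test
            x = torch.randn(input_shape, device='cuda', dtype=typ, requires_grad=True)
            y = fn(x)
            test_case.assertEqual(y.type(), expected[typ])   # forward output type
            if test_backward:
                y.float().sum().backward()                   # visible in the diff above
                test_case.assertEqual(x.grad.type(), MATCH_INPUT[typ])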
tests/L0/run_amp/utils.py

@@ -2,15 +2,19 @@ import torch
 
 HALF = 'torch.cuda.HalfTensor'
 FLOAT = 'torch.cuda.FloatTensor'
+BFLOAT16 = 'torch.cuda.BFloat16Tensor'
 
 DTYPES = [torch.half, torch.float]
 
 ALWAYS_HALF = {torch.float: HALF,
                torch.half: HALF}
+ALWAYS_BFLOAT16 = {torch.bfloat16: BFLOAT16,
+                   torch.float: BFLOAT16}
 ALWAYS_FLOAT = {torch.float: FLOAT,
                 torch.half: FLOAT}
 MATCH_INPUT = {torch.float: FLOAT,
-               torch.half: HALF}
+               torch.half: HALF,
+               torch.bfloat16: BFLOAT16}
 
 def common_init(test_case):
     test_case.h = 64
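The values in these tables are the strings returned by Tensor.type(), which is what run_layer_test compares against. A quick illustration (assumes a CUDA build; the constant mirrors the BFLOAT16 definition above):

    import torch

    BFLOAT16 = 'torch.cuda.BFloat16Tensor'
    x = torch.zeros(4, device='cuda', dtype=torch.bfloat16)
    print(x.type())             # prints 'torch.cuda.BFloat16Tensor'
    assert x.type() == BFLOAT16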