OpenDAS / apex · Commit 4ac8ecb9

Authored May 19, 2020 by lcskrishna

Merge branch 'master' of https://github.com/ROCmSoftwarePlatform/apex into cl/enable-test-framework

Parents: bc626b13, b2da92fc
Changes: 30 files in this commit. This page shows 10 changed files with 395 additions and 102 deletions (+395, -102); the remaining files are on the second page of the diff.
csrc/multi_tensor_sgd_kernel.cu              +42   -0
csrc/type_shim.h                             +60   -0
tests/L0/run_amp/test_add_param_group.py     +18   -7
tests/L0/run_amp/test_basic_casts.py         +158  -49
tests/L0/run_amp/test_cache.py               +29   -8
tests/L0/run_amp/test_checkpointing.py       +2    -1
tests/L0/run_amp/test_multi_tensor_axpby.py  +7    -4
tests/L0/run_amp/test_multi_tensor_scale.py  +6    -3
tests/L0/run_amp/test_promotion.py           +66   -29
tests/L0/run_amp/utils.py                    +7    -1
csrc/multi_tensor_sgd_kernel.cu

@@ -166,6 +166,8 @@ void multi_tensor_sgd_cuda(
 // 2. fp32, fp32, fp32, No
 // 3. fp16, fp32, fp32, Yes
 // 4. fp32, fp32, fp32, Yes // this is the materialize_master_grads=True case
+// 5. bfp16, bfp16, bfp16, No
+// 6. bfp16, fp32, fp32, Yes
 // It's easier to hardcode these possibilities than to use
 // switches etc. to handle the cross-product of cases where
 // we don't want the majority of them.

@@ -268,6 +270,46 @@ void multi_tensor_sgd_cuda(
       wd_after_momentum,
       scale);
   }
+  // Case 5. bfp16, bfp16, bfp16, No
+  else if (grad_type == at::ScalarType::BFloat16 &&
+           weight_type == at::ScalarType::BFloat16 &&
+           num_tensors == 3)
+  {
+    multi_tensor_apply<3>(
+      BLOCK_SIZE,
+      chunk_size,
+      noop_flag,
+      tensor_lists,
+      SGDFunctor<3, at::BFloat16, at::BFloat16>(),
+      wd,
+      momentum,
+      dampening,
+      lr,
+      nesterov,
+      first_run,
+      wd_after_momentum,
+      scale);
+  }
+  // Case 6. bfp16, fp32, fp32, Yes
+  else if (grad_type == at::ScalarType::BFloat16 &&
+           weight_type == at::ScalarType::Float &&
+           num_tensors == 4)
+  {
+    multi_tensor_apply<4>(
+      BLOCK_SIZE,
+      chunk_size,
+      noop_flag,
+      tensor_lists,
+      SGDFunctor<4, at::BFloat16, float>(),
+      wd,
+      momentum,
+      dampening,
+      lr,
+      nesterov,
+      first_run,
+      wd_after_momentum,
+      scale);
+  }
   else
   {
     AT_ERROR("multi_tensor_sgd only supports some combinations of gradient & weight types. Given: ",
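For readers who do not want to trace the C++ chain, the following is a minimal Python sketch (not part of the commit) that mirrors only the two dispatch cases this file adds: the triple (gradient dtype, weight dtype, number of tensor lists) selects the SGDFunctor instantiation, and anything else falls through to the AT_ERROR path. The helper name select_bf16_sgd_case is hypothetical and exists purely for illustration; the pre-existing fp16/fp32 cases 1-4 are elided.

import torch

BF16_SGD_CASES = {
    # (grad dtype, weight dtype, number of tensor lists) -> functor instantiation
    (torch.bfloat16, torch.bfloat16, 3): "SGDFunctor<3, at::BFloat16, at::BFloat16>",  # Case 5
    (torch.bfloat16, torch.float32,  4): "SGDFunctor<4, at::BFloat16, float>",         # Case 6
}

def select_bf16_sgd_case(grad_dtype, weight_dtype, num_tensor_lists):
    """Return the functor the dispatcher would instantiate, or raise like the
    AT_ERROR fallback for unsupported dtype combinations."""
    key = (grad_dtype, weight_dtype, num_tensor_lists)
    if key not in BF16_SGD_CASES:
        raise RuntimeError("multi_tensor_sgd only supports some combinations "
                           "of gradient & weight types.")
    return BF16_SGD_CASES[key]

# Case 6 corresponds to bf16 gradients with fp32 master weights (4 tensor lists).
print(select_bf16_sgd_case(torch.bfloat16, torch.float32, 4))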
csrc/type_shim.h

@@ -105,6 +105,66 @@
     AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
   }

+// TODO: We might have to come up with an optimal set of dispatch macros by
+// changing the signature to take an integer suffix for the number of types to
+// dispatch, as defined upstream (e.g. AT_DISPATCH_FLOATING_TYPES_AND2).
+// Refactor once all the extension ops are enabled.
+#define DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(TYPE, LEVEL, NAME, ...)   \
+  switch(TYPE)                                                         \
+  {                                                                    \
+    case at::ScalarType::Float:                                        \
+    {                                                                  \
+      using scalar_t_##LEVEL = float;                                  \
+      __VA_ARGS__;                                                     \
+      break;                                                           \
+    }                                                                  \
+    case at::ScalarType::Half:                                         \
+    {                                                                  \
+      using scalar_t_##LEVEL = at::Half;                               \
+      __VA_ARGS__;                                                     \
+      break;                                                           \
+    }                                                                  \
+    case at::ScalarType::BFloat16:                                     \
+    {                                                                  \
+      using scalar_t_##LEVEL = at::BFloat16;                           \
+      __VA_ARGS__;                                                     \
+      break;                                                           \
+    }                                                                  \
+    default:                                                           \
+      AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'");  \
+  }
+
+#define DISPATCH_DOUBLE_FLOAT_AND_HALF_AND_BFLOAT16(TYPE, LEVEL, NAME, ...) \
+  switch(TYPE)                                                         \
+  {                                                                    \
+    case at::ScalarType::Double:                                       \
+    {                                                                  \
+      using scalar_t_##LEVEL = double;                                 \
+      __VA_ARGS__;                                                     \
+      break;                                                           \
+    }                                                                  \
+    case at::ScalarType::Float:                                        \
+    {                                                                  \
+      using scalar_t_##LEVEL = float;                                  \
+      __VA_ARGS__;                                                     \
+      break;                                                           \
+    }                                                                  \
+    case at::ScalarType::Half:                                         \
+    {                                                                  \
+      using scalar_t_##LEVEL = at::Half;                               \
+      __VA_ARGS__;                                                     \
+      break;                                                           \
+    }                                                                  \
+    case at::ScalarType::BFloat16:                                     \
+    {                                                                  \
+      using scalar_t_##LEVEL = at::BFloat16;                           \
+      __VA_ARGS__;                                                     \
+      break;                                                           \
+    }                                                                  \
+    default:                                                           \
+      AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'");  \
+  }
+
 template <typename T>
 __device__ __forceinline__ T reduce_block_into_lanes
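As a rough analogy for readers not fluent in C preprocessor dispatch: the macro's job is to pick a concrete element type for the kernel body from a runtime at::ScalarType, now including BFloat16. The Python sketch below only illustrates that idea; the real selection happens at compile time through the macro above, and the function name here is made up for the example.

import torch

def dispatch_float_half_bfloat16(dtype, kernel_body):
    """Illustrative analog of DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16:
    invoke kernel_body with the concrete element type selected by dtype."""
    supported = {
        torch.float32: "float",
        torch.float16: "at::Half",
        torch.bfloat16: "at::BFloat16",   # the branch this commit adds
    }
    if dtype not in supported:
        raise NotImplementedError(f"not implemented for '{dtype}'")
    return kernel_body(supported[dtype])

print(dispatch_float_half_bfloat16(torch.bfloat16,
                                   lambda t: f"launch kernel with scalar_t = {t}"))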
tests/L0/run_amp/test_add_param_group.py

@@ -14,11 +14,11 @@ from utils import common_init, HALF, FLOAT,\
     ALWAYS_HALF, ALWAYS_FLOAT, MATCH_INPUT

 class MyModel(torch.nn.Module):
-    def __init__(self, unique):
+    def __init__(self, unique, dtype=torch.float16):
         super(MyModel, self).__init__()
         self.weight0 = Parameter(unique +
             torch.arange(2, device='cuda', dtype=torch.float32))
-        self.weight1 = Parameter(1. + unique +
-            torch.arange(2, device='cuda', dtype=torch.float16))
+        self.weight1 = Parameter(1. + unique +
+            torch.arange(2, device='cuda', dtype=dtype))

     @staticmethod
     def ops(input, weight0, weight1):

@@ -51,11 +51,15 @@ class TestAddParamGroup(unittest.TestCase):
         optimizer.zero_grad()

     def test_add_param_group(self):
-        for opt_level in ("O0", "O1", "O2", "O3"):
+        for opt_level in ("O0", "O1", "O2", "O3", "O4", "O5"):
             for zero_before_add in (True, False):
                 for try_accumulation in (True, False):
-                    model0 = MyModel(1)
-                    model1 = MyModel(2)
+                    if opt_level in {"O4", "O5"}:
+                        model0 = MyModel(1, torch.bfloat16)
+                        model1 = MyModel(2, torch.bfloat16)
+                    else:
+                        model0 = MyModel(1)
+                        model1 = MyModel(2)

                     optimizer = torch.optim.SGD([{'params' : model0.parameters(), 'lr' : 0.25}],
                                                 momentum=0.125)

@@ -89,8 +93,12 @@ class TestAddParamGroup(unittest.TestCase):
                         [param.data.clone() for param in model1.parameters()]

                     for how_to_zero in "none", "model", "optimizer":
-                        model0 = MyModel(1)
-                        model1 = MyModel(2)
+                        if opt_level in {"O4", "O5"}:
+                            model0 = MyModel(1, torch.bfloat16)
+                            model1 = MyModel(2, torch.bfloat16)
+                        else:
+                            model0 = MyModel(1)
+                            model1 = MyModel(2)
                         optimizer = torch.optim.SGD([{'params' : model0.parameters(), 'lr' : 0.25}],
                                                     momentum=0.125)

@@ -139,6 +147,9 @@ class TestAddParamGroup(unittest.TestCase):
                         [param.data.clone() for param in model1.parameters()]

                     for reference, final in zip(reference_params, final_params):
+                        # TODO: remove the conversion once allclose supports bfloat16 type.
+                        if final.dtype == torch.bfloat16:
+                            final = final.float()
                         self.assertTrue(torch.allclose(reference.to(final.dtype), final),
                                         "opt_level = {}, how_to_zero = {}, zero_before_add = {}"
                                         .format(opt_level, how_to_zero, zero_before_add))
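The final.float() conversion above exists because, on the PyTorch build this fork targeted, torch.allclose did not accept bfloat16 inputs; comparing in float32 sidesteps that. A small self-contained sketch of the workaround pattern (illustrative only, CPU tensors, loose tolerances chosen for bfloat16's ~3 significant decimal digits):

import torch

reference = torch.randn(4)              # fp32 reference values
final = reference.to(torch.bfloat16)    # values produced in bfloat16

# Workaround: compare in float32 instead of calling allclose on bfloat16 directly.
if final.dtype == torch.bfloat16:
    final = final.float()

assert torch.allclose(reference.to(final.dtype), final, atol=1e-2, rtol=1e-2)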
tests/L0/run_amp/test_basic_casts.py

@@ -9,7 +9,7 @@ from torch import nn
 import torch.nn.functional as F

 from utils import common_init, HALF, FLOAT,\
-    ALWAYS_HALF, ALWAYS_FLOAT, MATCH_INPUT
+    ALWAYS_HALF, ALWAYS_BFLOAT16, ALWAYS_FLOAT, MATCH_INPUT

 def run_layer_test(test_case, fns, expected, input_shape, test_backward=True):
     for fn, typ in it.product(fns, expected.keys()):

@@ -20,124 +20,233 @@ def run_layer_test(test_case, fns, expected, input_shape, test_backward=True):
             y.float().sum().backward()
             test_case.assertEqual(x.grad.type(), MATCH_INPUT[typ])

-class TestBasicCasts(unittest.TestCase):
-    def setUp(self):
-        self.handle = amp.init(enabled=True)
-        common_init(self)
-
-    def tearDown(self):
-        self.handle._deactivate()
-
-    def test_linear_is_half(self):
+class _TestBasicCasts(unittest.TestCase):
+    def _test_linear(self, expected):
         m = nn.Linear(self.h, self.h)
         f = ft.partial(F.linear, weight=m.weight, bias=m.bias)
-        run_layer_test(self, [m, f], ALWAYS_HALF, (self.b, self.h))
+        run_layer_test(self, [m, f], expected, (self.b, self.h))

-    def test_conv2d_is_half(self):
+    def _test_conv2d(self, expected):
         m = nn.Conv2d(self.c, self.c, self.k)
         f = ft.partial(F.conv2d, weight=m.weight, bias=m.bias)
-        run_layer_test(self, [m, f], ALWAYS_HALF, (self.b, self.c, self.h, self.h))
+        run_layer_test(self, [m, f], expected, (self.b, self.c, self.h, self.h))

-    def test_softmax_is_float(self):
+    def _test_softmax(self, expected):
         m = nn.Softmax(dim=1)
         f = ft.partial(F.softmax, dim=1)
-        run_layer_test(self, [m, f], ALWAYS_FLOAT, (self.b, self.h))
+        run_layer_test(self, [m, f], expected, (self.b, self.h))

-    def test_group_norm_is_float(self):
+    def _test_group_norm(self, expected):
         m = nn.GroupNorm(num_groups=4, num_channels=self.c)
-        run_layer_test(self, [m], ALWAYS_FLOAT, (self.b, self.c, self.h, self.h))
+        run_layer_test(self, [m], expected, (self.b, self.c, self.h, self.h))

-    def test_mse_loss_is_float(self):
+    def _test_mse_loss(self, expected):
         shape = (self.b, self.h)
         target = torch.randn(shape)
         mod = nn.MSELoss()
         m = lambda x: mod(x, target)
         f = ft.partial(F.mse_loss, target=target)
-        run_layer_test(self, [m], ALWAYS_FLOAT, shape)
+        run_layer_test(self, [m], expected, shape)

-    def test_relu_is_match(self):
-        run_layer_test(self, [nn.ReLU(), F.relu], MATCH_INPUT, (self.b, self.h))
+    def _test_relu(self, expected):
+        run_layer_test(self, [nn.ReLU(), F.relu], expected, (self.b, self.h))

-    def test_batch_norm_is_match(self):
+    def _test_batch_norm(self, expected):
         m = nn.BatchNorm2d(num_features=self.c)
         f = ft.partial(F.batch_norm, running_mean=m.running_mean, running_var=m.running_var,
                        weight=m.weight, bias=m.bias, training=True)
-        run_layer_test(self, [m], MATCH_INPUT, (self.b, self.c, self.h, self.h))
+        run_layer_test(self, [m], expected, (self.b, self.c, self.h, self.h))

         # Test forward-only for BN inference
         m.eval()
         f = ft.partial(F.batch_norm, running_mean=m.running_mean, running_var=m.running_var,
                        weight=m.weight, bias=m.bias, training=False)
-        run_layer_test(self, [m, f], MATCH_INPUT, (self.b, self.c, self.h, self.h),
+        run_layer_test(self, [m, f], expected, (self.b, self.c, self.h, self.h),
                        test_backward=False)

+class TestBasicCastsHalf(_TestBasicCasts):
+    def setUp(self):
+        self.handle = amp.init(enabled=True, patch_type=torch.half)
+        common_init(self)
+
+    def tearDown(self):
+        self.handle._deactivate()
+
+    def test_linear_is_half(self):
+        self._test_linear(ALWAYS_HALF)
+
+    def test_conv2d_is_half(self):
+        self._test_conv2d(ALWAYS_HALF)
+
+    def test_softmax_is_float(self):
+        self._test_softmax(ALWAYS_FLOAT)
+
+    def test_group_norm_is_float(self):
+        self._test_group_norm(ALWAYS_FLOAT)
+
+    def test_mse_loss_is_float(self):
+        self._test_mse_loss(ALWAYS_FLOAT)
+
+    def test_relu_is_match(self):
+        self._test_relu(MATCH_INPUT)
+
+    def test_batch_norm_is_match(self):
+        self._test_batch_norm(MATCH_INPUT)
+
+class TestBasicCastsBFloat16(_TestBasicCasts):
+    def setUp(self):
+        self.handle = amp.init(enabled=True, patch_type=torch.bfloat16)
+        common_init(self)
+
+    def tearDown(self):
+        self.handle._deactivate()
+
+    def test_linear_is_bfloat16(self):
+        self._test_linear(ALWAYS_BFLOAT16)
+
+    def test_conv2d_is_bfloat16(self):
+        self._test_conv2d(ALWAYS_BFLOAT16)
+
+    def test_softmax_is_float(self):
+        self._test_softmax(ALWAYS_FLOAT)
+
+    def test_group_norm_is_float(self):
+        self._test_group_norm(ALWAYS_FLOAT)
+
+    def test_mse_loss_is_float(self):
+        self._test_mse_loss(ALWAYS_FLOAT)
+
+    def test_relu_is_match(self):
+        self._test_relu(MATCH_INPUT)
+
+    def test_batch_norm_is_match(self):
+        self._test_batch_norm(MATCH_INPUT)
+
 class TestBannedMethods(unittest.TestCase):
     def setUp(self):
-        self.handle = amp.init(enabled=True)
+        self.handle = amp.init(enabled=True, patch_type=torch.half)
         common_init(self)

     def tearDown(self):
         self.handle._deactivate()

-    def bce_common(self, assertion):
+    def bce_common(self, assertion, dtype=torch.half):
         shape = (self.b, self.h)
         target = torch.rand(shape)
         mod = nn.BCELoss()
         m = lambda x: mod(x, target)
         f = ft.partial(F.binary_cross_entropy, target=target)
         for fn in [m, f]:
-            x = torch.rand(shape, dtype=torch.half)
+            x = torch.rand(shape, dtype=dtype)
             assertion(fn, x)

     def test_bce_raises_by_default(self):
         assertion = lambda fn, x: self.assertRaises(NotImplementedError, fn, x)
-        self.bce_common(assertion)
+        self.bce_common(assertion, dtype=torch.half)
+        # handle with bfloat16 as patch_type
+        self.handle._deactivate()
+        self.handle = amp.init(enabled=True, patch_type=torch.bfloat16)
+        self.bce_common(assertion, dtype=torch.bfloat16)

     def test_bce_is_float_with_allow_banned(self):
         self.handle._deactivate()
-        self.handle = amp.init(enabled=True, allow_banned=True)
+        self.handle = amp.init(enabled=True, allow_banned=True, patch_type=torch.half)
         assertion = lambda fn, x: self.assertEqual(fn(x).type(), FLOAT)
-        self.bce_common(assertion)
+        self.bce_common(assertion, dtype=torch.half)
+        # handle with bfloat16 as patch_type
+        self.handle._deactivate()
+        self.handle = amp.init(enabled=True, allow_banned=True, patch_type=torch.bfloat16)
+        self.bce_common(assertion, dtype=torch.bfloat16)

-class TestTensorCasts(unittest.TestCase):
-    def setUp(self):
-        self.handle = amp.init(enabled=True)
-        common_init(self)
-
-    def tearDown(self):
-        self.handle._deactivate()
-
-    def test_matmul_method_is_half(self):
+class _TestTensorCasts(unittest.TestCase):
+    def _test_matmul_method(self, expected):
         other = torch.randn(self.h, self.h)
         lhs = lambda x: x.matmul(other)
         rhs = lambda x: other.matmul(x)
-        run_layer_test(self, [lhs, rhs], ALWAYS_HALF, (self.h, self.h))
+        run_layer_test(self, [lhs, rhs], expected, (self.h, self.h))

-    def test_matmul_op_is_half(self):
+    def _test_matmul_op(self, expected):
         other = torch.randn(self.h, self.h)
         lhs = lambda x: x @ other
         rhs = lambda x: other @ x
-        run_layer_test(self, [lhs, rhs], ALWAYS_HALF, (self.h, self.h))
+        run_layer_test(self, [lhs, rhs], expected, (self.h, self.h))

-    def test_pow_method_is_float(self):
+    def _test_pow_method(self, expected):
         fn = lambda x: x.pow(2.)
-        run_layer_test(self, [fn], ALWAYS_FLOAT, (self.b, self.h))
+        run_layer_test(self, [fn], expected, (self.b, self.h))

-    def test_pow_op_is_float(self):
+    def _test_pow_op(self, expected):
         fn = lambda x: x**2.
-        run_layer_test(self, [fn], ALWAYS_FLOAT, (self.b, self.h))
+        run_layer_test(self, [fn], expected, (self.b, self.h))

-    def test_cpu_is_float(self):
+    def _test_cpu(self, expected):
         fn = lambda x: x.cpu()
-        always_cpu_float = {torch.float: 'torch.FloatTensor',
-                            torch.half: 'torch.FloatTensor'}
-        run_layer_test(self, [fn], always_cpu_float, (self.b, self.h))
+        run_layer_test(self, [fn], expected, (self.b, self.h))

-    def test_sum_is_float(self):
+    def _test_sum(self, expected):
         fn = lambda x: x.sum()
-        run_layer_test(self, [fn], ALWAYS_FLOAT, (self.b, self.h))
+        run_layer_test(self, [fn], expected, (self.b, self.h))

-    # TODO: maybe more tests on disabled casting?
+class TestTensorCastsHalf(_TestTensorCasts):
+    def setUp(self):
+        self.handle = amp.init(enabled=True, patch_type=torch.half)
+        common_init(self)
+
+    def tearDown(self):
+        self.handle._deactivate()
+
+    def test_matmul_method_is_half(self):
+        self._test_matmul_method(ALWAYS_HALF)
+
+    def test_matmul_op_is_half(self):
+        self._test_matmul_op(ALWAYS_HALF)
+
+    def test_pow_method_is_float(self):
+        self._test_pow_method(ALWAYS_FLOAT)
+
+    def test_pow_op_is_float(self):
+        self._test_pow_op(ALWAYS_FLOAT)
+
+    def test_cpu_is_float(self):
+        always_cpu_float = {torch.float: 'torch.FloatTensor',
+                            torch.half: 'torch.FloatTensor'}
+        self._test_cpu(always_cpu_float)
+
+    def test_sum_is_float(self):
+        self._test_sum(ALWAYS_FLOAT)
+
+class TestTensorCastsBFloat16(_TestTensorCasts):
+    def setUp(self):
+        self.handle = amp.init(enabled=True, patch_type=torch.bfloat16)
+        common_init(self)
+
+    def tearDown(self):
+        self.handle._deactivate()
+
+    def test_matmul_method_is_bfloat16(self):
+        self._test_matmul_method(ALWAYS_BFLOAT16)
+
+    def test_matmul_op_is_bfloat16(self):
+        self._test_matmul_op(ALWAYS_BFLOAT16)
+
+    def test_pow_method_is_float(self):
+        self._test_pow_method(ALWAYS_FLOAT)
+
+    def test_pow_op_is_float(self):
+        self._test_pow_op(ALWAYS_FLOAT)
+
+    def test_cpu_is_float(self):
+        always_cpu_float = {torch.float: 'torch.FloatTensor',
+                            torch.bfloat16: 'torch.FloatTensor'}
+        self._test_cpu(always_cpu_float)
+
+    def test_sum_is_float(self):
+        self._test_sum(ALWAYS_FLOAT)
+
+# TODO: maybe more tests on disabled casting?

 if __name__ == '__main__':
     unittest.main()
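The refactor above follows a common unittest pattern: move the test bodies into a base class whose helpers take the expected dtype map as a parameter, then add one thin subclass per patch type. A stripped-down sketch of that pattern, independent of apex and torch; all names here are illustrative:

import unittest

class _TestCastsBase(unittest.TestCase):
    # No test_* methods here, so the base class itself contributes no test cases;
    # only the concrete subclasses below are collected and run.
    def _test_add(self, expected_type_name):
        result = self.make_value() + self.make_value()
        self.assertEqual(type(result).__name__, expected_type_name)

class TestCastsInt(_TestCastsBase):
    def make_value(self):
        return 1

    def test_add_is_int(self):
        self._test_add("int")

class TestCastsFloat(_TestCastsBase):
    def make_value(self):
        return 1.0

    def test_add_is_float(self):
        self._test_add("float")

if __name__ == "__main__":
    unittest.main()

In the diff above, the subclasses differ in the patch_type passed to amp.init (torch.half versus torch.bfloat16) and in the expected map handed to the shared helpers (ALWAYS_HALF versus ALWAYS_BFLOAT16).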
tests/L0/run_amp/test_cache.py

@@ -67,12 +67,12 @@ class TestCache(unittest.TestCase):
     def tearDown(self):
         pass

-    def train_eval_train_test(self, module, t):
+    def train_eval_train_test(self, module, t, opt_level):
         model = module(t).cuda()
         optimizer = torch.optim.SGD(model.parameters(), lr=1.0)

         _amp_state.allow_incoming_model_not_fp32 = True
-        model, optimizer = amp.initialize(model, optimizer, opt_level="O1", verbosity=0)
+        model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level, verbosity=0)
         _amp_state.allow_incoming_model_not_fp32 = False

         def training_step():

@@ -93,6 +93,8 @@ class TestCache(unittest.TestCase):
             # but I'm keeping this in case we want different tolerances for fp16 and fp32 checks.
             if model.weight.grad.type() == "torch.cuda.HalfTensor":
                 self.assertTrue(torch.allclose(model.weight.grad.float(), reference_grad))
+            elif model.weight.grad.type() == "torch.cuda.BFloat16Tensor":
+                self.assertTrue(torch.allclose(model.weight.grad.float(), reference_grad))
             elif model.weight.grad.type() == "torch.cuda.FloatTensor":
                 self.assertTrue(torch.allclose(model.weight.grad.float(), reference_grad))
             else:

@@ -115,22 +117,41 @@ class TestCache(unittest.TestCase):
     # I could easily have these as a set of for loops in a single test,
     # instead of going for granularity.
     def test_whitelist_module_fp16_weight(self):
-        self.train_eval_train_test(WhitelistModule, torch.float16)
+        self.train_eval_train_test(WhitelistModule, torch.float16, "O1")

     def test_whitelist_module_fp32_weight(self):
-        self.train_eval_train_test(WhitelistModule, torch.float32)
+        self.train_eval_train_test(WhitelistModule, torch.float32, "O1")

     def test_blacklist_module_fp16_weight(self):
-        self.train_eval_train_test(BlacklistModule, torch.float16)
+        self.train_eval_train_test(BlacklistModule, torch.float16, "O1")

     def test_blacklist_module_fp32_weight(self):
-        self.train_eval_train_test(BlacklistModule, torch.float32)
+        self.train_eval_train_test(BlacklistModule, torch.float32, "O1")

     def test_promote_module_fp16_weight(self):
-        self.train_eval_train_test(PromoteModule, torch.float16)
+        self.train_eval_train_test(PromoteModule, torch.float16, "O1")

     def test_promote_module_fp32_weight(self):
-        self.train_eval_train_test(PromoteModule, torch.float32)
+        self.train_eval_train_test(PromoteModule, torch.float32, "O1")
+
+    # opt_level = O4
+    def test_whitelist_module_bfp16_weight(self):
+        self.train_eval_train_test(WhitelistModule, torch.bfloat16, "O4")
+
+    def test_whitelist_module_fp32_weight(self):
+        self.train_eval_train_test(WhitelistModule, torch.float32, "O4")
+
+    def test_blacklist_module_bfp16_weight(self):
+        self.train_eval_train_test(BlacklistModule, torch.bfloat16, "O4")
+
+    def test_blacklist_module_fp32_weight(self):
+        self.train_eval_train_test(BlacklistModule, torch.float32, "O4")
+
+    def test_promote_module_bfp16_weight(self):
+        self.train_eval_train_test(PromoteModule, torch.bfloat16, "O4")
+
+    def test_promote_module_fp32_weight(self):
+        self.train_eval_train_test(PromoteModule, torch.float32, "O4")

 if __name__ == '__main__':
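train_eval_train_test now takes the opt_level explicitly and forwards it to amp.initialize, so the same cache test body can run under this fork's bfloat16 levels ("O4" above) as well as the stock "O1". A minimal usage sketch, assuming "O4" is defined by this fork analogously to "O1" but with bfloat16 as the low-precision type (it is not an opt level in upstream apex); requires a CUDA device and the fork installed:

import torch
from apex import amp

model = torch.nn.Linear(16, 16).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1.0)

# "O1" patches ops to cast to fp16; "O4" is assumed here to be the bfloat16
# analogue added by this fork, matching how the tests pass it through.
model, optimizer = amp.initialize(model, optimizer, opt_level="O4", verbosity=0)

loss = model(torch.randn(8, 16, device="cuda")).sum()
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
optimizer.step()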
tests/L0/run_amp/test_checkpointing.py

@@ -28,7 +28,7 @@ class MyModel(torch.nn.Module):
 class TestCheckpointing(unittest.TestCase):
     def setUp(self):
         self.initial_lr = 1e-3
-        self.test_opt_levels = ("O0", "O1", "O2", "O3")
+        self.test_opt_levels = ("O0", "O1", "O2", "O3", "O4", "O5")

     def seed(self):
         torch.manual_seed(2809)

@@ -237,6 +237,7 @@ class TestCheckpointing(unittest.TestCase):
         state_dict = model.state_dict()
         for key in state_dict:
             self.assertFalse('Half' in state_dict[key].type())
+            self.assertFalse('BFloat16' in state_dict[key].type())

         # Check, if model is still trainable
         # Create dummy data
tests/L0/run_amp/test_multi_tensor_axpby.py

@@ -71,7 +71,10 @@ class TestMultiTensorAxpby(unittest.TestCase):
         applier(multi_tensor_axpby, self.overflow_buf, [x_list, y_list, out_list], self.a, self.b, -1)

-        self.assertTrue(all([torch.allclose(out, self.ref.to(out_type)) for out in out_list]),
+        # TODO: Remove this workaround for bfloat16 after torch.allclose() supports bfloat16
+        if out_type == torch.bfloat16:
+            out_list = [out.float() for out in out_list]
+        self.assertTrue(all([torch.allclose(out, self.ref.to(out.dtype)) for out in out_list]),
                         msg="{} {} {} {} {} {} {}".format(sizea, sizeb, repeat_tensors,
                                                           x_type, y_type, out_type, inplace))

         self.assertTrue(self.overflow_buf.item() == 0,

@@ -121,9 +124,9 @@ class TestMultiTensorAxpby(unittest.TestCase):
         for sizea, sizeb in input_size_pairs:
           for applier in appliers:
             for repeat in repeat_tensors:
-              for x_type in (torch.float32, torch.float16):
-                for y_type in (torch.float32, torch.float16):
-                  for out_type in (torch.float32, torch.float16):
+              for x_type in (torch.float32, torch.float16, torch.bfloat16):
+                for y_type in (torch.float32, torch.float16, torch.bfloat16):
+                  for out_type in (torch.float32, torch.float16, torch.bfloat16):
                     for inplace in (True, False):
                       if inplace is True and (y_type is not out_type):
                         continue
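For context on the call being checked, a sketch of how the fused axpby kernel is exercised in these tests. The argument order follows the applier call visible in this hunk (x list, y list, output list, then the scalars a and b and the trailing -1 flag used by the kernel's overflow checking); the applier construction and chunk size are assumptions based on apex's MultiTensorApply helper, and running this requires apex built with its CUDA extensions (amp_C) plus a GPU:

import torch
import amp_C
from apex.multi_tensor_apply import MultiTensorApply

applier = MultiTensorApply(2048 * 32)          # assumed chunk size, as used elsewhere in apex
overflow_buf = torch.cuda.IntTensor([0])

x = [torch.randn(1024, device="cuda", dtype=torch.bfloat16)]
y = [torch.randn(1024, device="cuda", dtype=torch.bfloat16)]
out = [torch.empty(1024, device="cuda", dtype=torch.bfloat16)]

a, b = 2.0, 8.0
# Computes out = a*x + b*y for every tensor triple; mirrors the test's call.
applier(amp_C.multi_tensor_axpby, overflow_buf, [x, y, out], a, b, -1)

# Same bfloat16 workaround as the test: compare against an fp32 reference in float32.
ref = a * x[0].float() + b * y[0].float()
print("max abs err vs fp32 reference:", (out[0].float() - ref).abs().max().item())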
tests/L0/run_amp/test_multi_tensor_scale.py

@@ -49,7 +49,10 @@ class TestMultiTensorScale(unittest.TestCase):
         applier(multi_tensor_scale, self.overflow_buf, [in_list, out_list], 1./self.scale)

-        self.assertTrue(all([torch.allclose(out, self.ref.to(out_type)) for out in out_list]))
+        # TODO: Remove this workaround for bfloat16 after torch.allclose() supports bfloat16
+        if out_type == torch.bfloat16:
+            out_list = [out.float() for out in out_list]
+        self.assertTrue(all([torch.allclose(out, self.ref.to(out.dtype)) for out in out_list]))
         self.assertTrue(self.overflow_buf.item() == 0)

     def find_inf(self, sizea, sizeb, applier, repeat_tensors, in_type, out_type, t, ind, val, inplace=False):

@@ -106,8 +109,8 @@ class TestMultiTensorScale(unittest.TestCase):
         for sizea, sizeb in input_size_pairs:
           for applier in appliers:
             for repeat in repeat_tensors:
-              for in_type in (torch.float32, torch.float16):
-                for out_type in (torch.float32, torch.float16):
+              for in_type in (torch.float32, torch.float16, torch.bfloat16):
+                for out_type in (torch.float32, torch.float16, torch.bfloat16):
                   for inplace in (True, False):
                     if inplace is True and (out_type is not in_type):
                       continue
tests/L0/run_amp/test_promotion.py

@@ -7,18 +7,18 @@ import torch
 from torch import nn
 import torch.nn.functional as F

-from utils import common_init, HALF, FLOAT, DTYPES
+from utils import common_init, HALF, FLOAT, DTYPES, DTYPES2, MATCH_INPUT

-class TestPromotion(unittest.TestCase):
-    def setUp(self):
-        self.handle = amp.init(enabled=True)
-        common_init(self)
-
-    def tearDown(self):
-        self.handle._deactivate()
-
-    def run_binary_promote_test(self, fns, input_shape, x_inplace=False):
-        type_pairs = it.product(DTYPES, DTYPES)
+class _TestPromotion(unittest.TestCase):
+    def run_binary_promote_test(self, fns, input_shape, lp_type, x_inplace=False):
+        if lp_type == torch.half:
+            dtypes = DTYPES
+        elif lp_type == torch.bfloat16:
+            dtypes = DTYPES2
+        else:
+            raise RuntimeError("Creating test class with invalid low_precision type. \
+                                Supported types are torch.half and torch.bfloat16")
+        type_pairs = it.product(dtypes, dtypes)
         for fn, (xtype, ytype) in it.product(fns, type_pairs):
             x = torch.randn(input_shape, dtype=xtype).requires_grad_()
             x_leaf = x

@@ -35,41 +35,78 @@ class TestPromotion(unittest.TestCase):
             if xtype == torch.float or ytype == torch.float:
                 self.assertEqual(out.type(), FLOAT)
             else:
-                self.assertEqual(out.type(), HALF)
+                self.assertEqual(out.type(), MATCH_INPUT[lp_type])

             out.float().sum().backward()
             self.assertEqual(x_leaf.grad.dtype, xtype)

+    def _test_cat_matches_widest(self, lp_type):
+        shape = self.b
+        ys = [torch.randn(shape, dtype=lp_type) for _ in range(5)]
+        x_float = torch.randn(shape)
+        out = torch.cat(ys + [x_float])
+        self.assertEqual(out.type(), FLOAT)
+        x_lp = torch.randn(shape, dtype=lp_type)
+        out = torch.cat(ys + [x_lp])
+        self.assertEqual(out.type(), MATCH_INPUT[lp_type])
+
+    def _test_inplace_exp_is_error_for_lp(self, lp_type):
+        xs = torch.randn(self.b)
+        xs.exp_()
+        self.assertEqual(xs.type(), FLOAT)
+        xs = torch.randn(self.b, dtype=lp_type)
+        with self.assertRaises(NotImplementedError):
+            xs.exp_()
+
+class TestPromotionHalf(_TestPromotion):
+    def setUp(self):
+        self.handle = amp.init(enabled=True, patch_type=torch.half)
+        common_init(self)
+
+    def tearDown(self):
+        self.handle._deactivate()
+
     def test_atan2_matches_widest(self):
         fns = [lambda x, y : torch.atan2(x, y),
                lambda x, y : x.atan2(y)]
-        self.run_binary_promote_test(fns, (self.b,))
+        self.run_binary_promote_test(fns, (self.b,), torch.half)

     def test_mul_matches_widest(self):
         fns = [lambda x, y : torch.mul(x, y),
                lambda x, y : x.mul(y)]
-        self.run_binary_promote_test(fns, (self.b,))
+        self.run_binary_promote_test(fns, (self.b,), torch.half)

     def test_cat_matches_widest(self):
-        shape = self.b
-        ys = [torch.randn(shape, dtype=torch.half) for _ in range(5)]
-        x_float = torch.randn(shape)
-        out = torch.cat(ys + [x_float])
-        self.assertEqual(out.type(), FLOAT)
-        x_half = torch.randn(shape, dtype=torch.half)
-        out = torch.cat(ys + [x_half])
-        self.assertEqual(out.type(), HALF)
+        self._test_cat_matches_widest(torch.half)

     def test_inplace_exp_is_error_for_half(self):
-        xs = torch.randn(self.b)
-        xs.exp_()
-        self.assertEqual(xs.type(), FLOAT)
-        xs = torch.randn(self.b, dtype=torch.half)
-        with self.assertRaises(NotImplementedError):
-            xs.exp_()
+        self._test_inplace_exp_is_error_for_lp(torch.half)

     def test_inplace_add_matches_self(self):
         fn = lambda x, y: x.add_(y)
-        self.run_binary_promote_test([fn], (self.b,), x_inplace=True)
+        self.run_binary_promote_test([fn], (self.b,), torch.half, x_inplace=True)

+class TestPromotionBFloat16(_TestPromotion):
+    def setUp(self):
+        self.handle = amp.init(enabled=True, patch_type=torch.bfloat16)
+        common_init(self)
+
+    def tearDown(self):
+        self.handle._deactivate()
+
+    def test_mul_matches_widest(self):
+        fns = [lambda x, y : torch.mul(x, y),
+               lambda x, y : x.mul(y)]
+        self.run_binary_promote_test(fns, (self.b,), torch.bfloat16)
+
+    def test_cat_matches_widest(self):
+        self._test_cat_matches_widest(torch.bfloat16)
+
+    def test_inplace_exp_is_error_for_bfloat16(self):
+        self._test_inplace_exp_is_error_for_lp(torch.bfloat16)
+
+    def test_inplace_add_matches_self(self):
+        fn = lambda x, y: x.add_(y)
+        self.run_binary_promote_test([fn], (self.b,), torch.bfloat16, x_inplace=True)

 if __name__ == '__main__':
     unittest.main()
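The rewritten assertion encodes the promotion rule these tests check: if either operand is float32 the result is float32, otherwise the result stays in the active low-precision type, i.e. MATCH_INPUT[lp_type]. A small stand-alone mirror of that rule (illustrative only; it does not call into amp):

import torch

def expected_promotion(xtype, ytype, lp_type):
    """Expected result dtype for a binary op under amp patching, with lp_type
    (torch.half or torch.bfloat16) as the low-precision type."""
    if xtype == torch.float or ytype == torch.float:
        return torch.float
    return lp_type

assert expected_promotion(torch.bfloat16, torch.float, torch.bfloat16) == torch.float
assert expected_promotion(torch.bfloat16, torch.bfloat16, torch.bfloat16) == torch.bfloat16
assert expected_promotion(torch.half, torch.half, torch.half) == torch.half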
tests/L0/run_amp/utils.py

@@ -2,15 +2,21 @@ import torch
 HALF = 'torch.cuda.HalfTensor'
 FLOAT = 'torch.cuda.FloatTensor'
+BFLOAT16 = 'torch.cuda.BFloat16Tensor'

 DTYPES = [torch.half, torch.float]
+DTYPES2 = [torch.bfloat16, torch.float]

 ALWAYS_HALF = {torch.float: HALF,
                torch.half: HALF}
+ALWAYS_BFLOAT16 = {torch.bfloat16: BFLOAT16,
+                   torch.float: BFLOAT16}
 ALWAYS_FLOAT = {torch.float: FLOAT,
                 torch.half: FLOAT}
 MATCH_INPUT = {torch.float: FLOAT,
-               torch.half: HALF}
+               torch.half: HALF,
+               torch.bfloat16: BFLOAT16}

 def common_init(test_case):
     test_case.h = 64
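These dictionaries map an input dtype to the tensor type string the tests expect back; run_layer_test iterates the map's keys as input dtypes and compares the output's type() against the mapped string. A sketch of how a map like ALWAYS_BFLOAT16 is consumed (roughly what run_layer_test does, with an explicit cast standing in for a patched op such as nn.Linear; requires a GPU):

import torch

BFLOAT16 = 'torch.cuda.BFloat16Tensor'
ALWAYS_BFLOAT16 = {torch.bfloat16: BFLOAT16,
                   torch.float: BFLOAT16}

def check_layer(fn, expected, input_shape):
    # Build an input of each dtype in the map and check the output's type string.
    for typ in expected.keys():
        x = torch.randn(input_shape, device="cuda", dtype=typ)
        y = fn(x)
        assert y.type() == expected[typ], (y.type(), expected[typ])

# With amp patching active and patch_type=torch.bfloat16, a whitelist op would
# satisfy this on its own; here the cast is explicit for illustration.
check_layer(lambda x: x.to(torch.bfloat16), ALWAYS_BFLOAT16, (4, 8))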