OpenDAS / apex
"...text-generation-inference.git" did not exist on "70056d1e9c53dc85d00690cd20fab22f26fbbc46"
Commit 5f5dfa42
authored Sep 10, 2018 by Carl Case

add rnn tests

parent 39928327
Showing 2 changed files with 98 additions and 0 deletions:

apex/amp/test/test_rnn.py  +97 −0
apex/amp/test/utils.py     +1 −0
apex/amp/test/test_rnn.py (new file, mode 0 → 100644)
import unittest

from apex import amp
import torch
from torch import nn

from .utils import common_init, HALF

class TestRnnCells(unittest.TestCase):
    def setUp(self):
        self.handle = amp.init(enabled=True)
        common_init(self)

    def tearDown(self):
        self.handle._deactivate()

    def run_cell_test(self, cell, state_tuple=False):
        shape = (self.b, self.h)
        for typ in [torch.float, torch.half]:
            xs = [torch.randn(shape, dtype=typ).requires_grad_()
                  for _ in range(self.t)]
            hidden_fn = lambda: torch.zeros(shape, dtype=typ)
            if state_tuple:
                hidden = (hidden_fn(), hidden_fn())
            else:
                hidden = hidden_fn()
            outputs = []
            for i in range(self.t):
                hidden = cell(xs[i], hidden)
                if state_tuple:
                    output = hidden[0]
                else:
                    output = hidden
                outputs.append(output)
            for y in outputs:
                self.assertEqual(y.type(), HALF)
            outputs[-1].float().sum().backward()
            for i, x in enumerate(xs):
                self.assertEqual(x.grad.dtype, x.dtype)

    def test_rnn_cell_is_half(self):
        cell = nn.RNNCell(self.h, self.h)
        self.run_cell_test(cell)

    def test_gru_cell_is_half(self):
        cell = nn.GRUCell(self.h, self.h)
        self.run_cell_test(cell)

    def test_lstm_cell_is_half(self):
        cell = nn.LSTMCell(self.h, self.h)
        self.run_cell_test(cell, state_tuple=True)

class TestRnns(unittest.TestCase):
    def setUp(self):
        self.handle = amp.init(enabled=True)
        common_init(self)

    def tearDown(self):
        self.handle._deactivate()

    def run_rnn_test(self, rnn, layers, bidir, state_tuple=False):
        for typ in [torch.float, torch.half]:
            x = torch.randn((self.t, self.b, self.h),
                            dtype=typ).requires_grad_()
            hidden_fn = lambda: torch.zeros((layers + (layers * bidir),
                                             self.b, self.h), dtype=typ)
            if state_tuple:
                hidden = (hidden_fn(), hidden_fn())
            else:
                hidden = hidden_fn()
            output, _ = rnn(x, hidden)
            self.assertEqual(output.type(), HALF)
            output[-1, :, :].float().sum().backward()
            self.assertEqual(x.grad.dtype, x.dtype)

    def test_rnn_is_half(self):
        configs = [(1, False), (2, False), (2, True)]
        for layers, bidir in configs:
            rnn = nn.RNN(input_size=self.h, hidden_size=self.h,
                         num_layers=layers, nonlinearity='relu',
                         bidirectional=bidir)
            self.run_rnn_test(rnn, layers, bidir)

    def test_gru_is_half(self):
        configs = [(1, False), (2, False), (2, True)]
        for layers, bidir in configs:
            rnn = nn.GRU(input_size=self.h, hidden_size=self.h,
                         num_layers=layers, bidirectional=bidir)
            self.run_rnn_test(rnn, layers, bidir)

    def test_lstm_is_half(self):
        configs = [(1, False), (2, False), (2, True)]
        for layers, bidir in configs:
            rnn = nn.LSTM(input_size=self.h, hidden_size=self.h,
                          num_layers=layers, bidirectional=bidir)
            self.run_rnn_test(rnn, layers, bidir, state_tuple=True)

if __name__ == '__main__':
    unittest.main()
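Since the new module ends with a unittest.main() guard, the suite can be run directly. For convenience, here is a minimal standalone runner; it is not part of the commit and assumes apex is importable and a CUDA device is available, since common_init switches the default tensor type to torch.cuda.FloatTensor:

import unittest

from apex.amp.test import test_rnn

# Collect every TestCase defined in the new module and run it verbosely.
suite = unittest.defaultTestLoader.loadTestsFromModule(test_rnn)
unittest.TextTestRunner(verbosity=2).run(suite)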
apex/amp/test/utils.py
@@ -17,4 +17,5 @@ def common_init(test_case):
     test_case.b = 16
     test_case.c = 16
     test_case.k = 3
+    test_case.t = 10
     torch.set_default_tensor_type(torch.cuda.FloatTensor)
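The hunk shows only the tail of common_init. For context, here is a minimal sketch of what the helper presumably looks like after this commit; the h field and the HALF constant are assumptions inferred from the tests above (which reference self.h and compare tensor types against HALF) and do not appear in the visible hunk:

import torch

# Assumed value: the tests compare y.type() against HALF after amp casts to fp16.
HALF = 'torch.cuda.HalfTensor'

def common_init(test_case):
    test_case.h = 32   # assumed hidden size; defined above the visible hunk
    test_case.b = 16
    test_case.c = 16
    test_case.k = 3
    test_case.t = 10   # added by this commit: sequence length for the RNN tests
    torch.set_default_tensor_type(torch.cuda.FloatTensor)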