Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
87da161c
Commit
87da161c
authored
Nov 04, 2018
by
thomwolf
Browse files
finishing model test
parent
d69b0b0e
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
2 additions
and
121 deletions
+2
-121
tests/modeling_test.py
tests/modeling_test.py
+2
-121
No files found.
tests/modeling_test.py
View file @
87da161c
...
@@ -16,16 +16,13 @@ from __future__ import absolute_import
...
@@ -16,16 +16,13 @@ from __future__ import absolute_import
from
__future__
import
division
from
__future__
import
division
from
__future__
import
print_function
from
__future__
import
print_function
import
six
import
unittest
import
unittest
import
collections
import
json
import
json
import
random
import
random
import
re
import
torch
import
torch
import
modeling
as
modeling
import
modeling
class
BertModelTest
(
unittest
.
TestCase
):
class
BertModelTest
(
unittest
.
TestCase
):
...
@@ -124,9 +121,6 @@ class BertModelTest(unittest.TestCase):
...
@@ -124,9 +121,6 @@ class BertModelTest(unittest.TestCase):
output_result
=
tester
.
create_model
()
output_result
=
tester
.
create_model
()
tester
.
check_output
(
output_result
)
tester
.
check_output
(
output_result
)
# TODO Find PyTorch equivalent of assert_all_tensors_reachable() if necessary
# self.assert_all_tensors_reachable(sess, [init_op, ops])
@
classmethod
@
classmethod
def
ids_tensor
(
cls
,
shape
,
vocab_size
,
rng
=
None
,
name
=
None
):
def
ids_tensor
(
cls
,
shape
,
vocab_size
,
rng
=
None
,
name
=
None
):
"""Creates a random int32 tensor of the shape within the vocab size."""
"""Creates a random int32 tensor of the shape within the vocab size."""
...
@@ -141,120 +135,7 @@ class BertModelTest(unittest.TestCase):
...
@@ -141,120 +135,7 @@ class BertModelTest(unittest.TestCase):
for
_
in
range
(
total_dims
):
for
_
in
range
(
total_dims
):
values
.
append
(
rng
.
randint
(
0
,
vocab_size
-
1
))
values
.
append
(
rng
.
randint
(
0
,
vocab_size
-
1
))
# TODO Solve : the returned tensors provoke index out of range errors when passed to the model
return
torch
.
tensor
(
data
=
values
,
dtype
=
torch
.
long
).
view
(
shape
).
contiguous
()
return
torch
.
tensor
(
data
=
values
,
dtype
=
torch
.
int32
)
def
assert_all_tensors_reachable
(
self
,
sess
,
outputs
):
"""Checks that all the tensors in the graph are reachable from outputs."""
graph
=
sess
.
graph
ignore_strings
=
[
"^.*/dilation_rate$"
,
"^.*/Tensordot/concat$"
,
"^.*/Tensordot/concat/axis$"
,
"^testing/.*$"
,
]
ignore_regexes
=
[
re
.
compile
(
x
)
for
x
in
ignore_strings
]
unreachable
=
self
.
get_unreachable_ops
(
graph
,
outputs
)
filtered_unreachable
=
[]
for
x
in
unreachable
:
do_ignore
=
False
for
r
in
ignore_regexes
:
m
=
r
.
match
(
x
.
name
)
if
m
is
not
None
:
do_ignore
=
True
if
do_ignore
:
continue
filtered_unreachable
.
append
(
x
)
unreachable
=
filtered_unreachable
self
.
assertEqual
(
len
(
unreachable
),
0
,
"The following ops are unreachable: %s"
%
(
" "
.
join
([
x
.
name
for
x
in
unreachable
])))
@
classmethod
def
get_unreachable_ops
(
cls
,
graph
,
outputs
):
"""Finds all of the tensors in graph that are unreachable from outputs."""
outputs
=
cls
.
flatten_recursive
(
outputs
)
output_to_op
=
collections
.
defaultdict
(
list
)
op_to_all
=
collections
.
defaultdict
(
list
)
assign_out_to_in
=
collections
.
defaultdict
(
list
)
for
op
in
graph
.
get_operations
():
for
x
in
op
.
inputs
:
op_to_all
[
op
.
name
].
append
(
x
.
name
)
for
y
in
op
.
outputs
:
output_to_op
[
y
.
name
].
append
(
op
.
name
)
op_to_all
[
op
.
name
].
append
(
y
.
name
)
if
str
(
op
.
type
)
==
"Assign"
:
for
y
in
op
.
outputs
:
for
x
in
op
.
inputs
:
assign_out_to_in
[
y
.
name
].
append
(
x
.
name
)
assign_groups
=
collections
.
defaultdict
(
list
)
for
out_name
in
assign_out_to_in
.
keys
():
name_group
=
assign_out_to_in
[
out_name
]
for
n1
in
name_group
:
assign_groups
[
n1
].
append
(
out_name
)
for
n2
in
name_group
:
if
n1
!=
n2
:
assign_groups
[
n1
].
append
(
n2
)
seen_tensors
=
{}
stack
=
[
x
.
name
for
x
in
outputs
]
while
stack
:
name
=
stack
.
pop
()
if
name
in
seen_tensors
:
continue
seen_tensors
[
name
]
=
True
if
name
in
output_to_op
:
for
op_name
in
output_to_op
[
name
]:
if
op_name
in
op_to_all
:
for
input_name
in
op_to_all
[
op_name
]:
if
input_name
not
in
stack
:
stack
.
append
(
input_name
)
expanded_names
=
[]
if
name
in
assign_groups
:
for
assign_name
in
assign_groups
[
name
]:
expanded_names
.
append
(
assign_name
)
for
expanded_name
in
expanded_names
:
if
expanded_name
not
in
stack
:
stack
.
append
(
expanded_name
)
unreachable_ops
=
[]
for
op
in
graph
.
get_operations
():
is_unreachable
=
False
all_names
=
[
x
.
name
for
x
in
op
.
inputs
]
+
[
x
.
name
for
x
in
op
.
outputs
]
for
name
in
all_names
:
if
name
not
in
seen_tensors
:
is_unreachable
=
True
if
is_unreachable
:
unreachable_ops
.
append
(
op
)
return
unreachable_ops
@
classmethod
def
flatten_recursive
(
cls
,
item
):
"""Flattens (potentially nested) a tuple/dictionary/list to a list."""
output
=
[]
if
isinstance
(
item
,
list
):
output
.
extend
(
item
)
elif
isinstance
(
item
,
tuple
):
output
.
extend
(
list
(
item
))
elif
isinstance
(
item
,
dict
):
for
(
_
,
v
)
in
six
.
iteritems
(
item
):
output
.
append
(
v
)
else
:
return
[
item
]
flat_output
=
[]
for
x
in
output
:
flat_output
.
extend
(
cls
.
flatten_recursive
(
x
))
return
flat_output
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment