chenpangpang / transformers · Commits

Commit f8276008
authored Nov 03, 2018 by thomwolf
parent 3c24e4be

    update readme, file names, removing TF code, moving tests

Changes: 25 files in the full commit; the 5 files on this page are shown below.
Showing 5 changed files with 42 additions and 181 deletions (+42 -181):

    tests/modeling_test.py          +25   -39
    tests/optimization_test.py       +1    -1
    tests/tokenization_test.py      +16   -17
    tokenization.py                  +0    -0
    tokenization_test_pytorch.py     +0  -124
tensorflow_code/modeling_test.py → tests/modeling_test.py
@@ -16,17 +16,19 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+import six
+import unittest
 import collections
 import json
 import random
 import re

-from tensorflow_code import modeling
-import six
-import tensorflow as tf
+import torch
+
+import modeling as modeling


-class BertModelTest(tf.test.TestCase):
+class BertModelTest(unittest.TestCase):

     class BertModelTester(object):

         def __init__(self,
@@ -68,18 +70,15 @@ class BertModelTest(tf.test.TestCase):
             self.scope = scope

         def create_model(self):
-            input_ids = BertModelTest.ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
+            input_ids = BertModelTest.ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

             input_mask = None
             if self.use_input_mask:
-                input_mask = BertModelTest.ids_tensor(
-                    [self.batch_size, self.seq_length], vocab_size=2)
+                input_mask = BertModelTest.ids_tensor([self.batch_size, self.seq_length], vocab_size=2)

             token_type_ids = None
             if self.use_token_type_ids:
-                token_type_ids = BertModelTest.ids_tensor(
-                    [self.batch_size, self.seq_length], self.type_vocab_size)
+                token_type_ids = BertModelTest.ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)

             config = modeling.BertConfig(
                 vocab_size=self.vocab_size,
@@ -94,33 +93,23 @@ class BertModelTest(tf.test.TestCase):
                 type_vocab_size=self.type_vocab_size,
                 initializer_range=self.initializer_range)

-            model = modeling.BertModel(
-                config=config,
-                is_training=self.is_training,
-                input_ids=input_ids,
-                input_mask=input_mask,
-                token_type_ids=token_type_ids,
-                scope=self.scope)
+            model = modeling.BertModel(config=config)
+
+            all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)

             outputs = {
-                "embedding_output": model.get_embedding_output(),
-                "sequence_output": model.get_sequence_output(),
-                "pooled_output": model.get_pooled_output(),
-                "all_encoder_layers": model.get_all_encoder_layers(),
+                "sequence_output": all_encoder_layers[-1],
+                "pooled_output": pooled_output,
+                "all_encoder_layers": all_encoder_layers,
             }
             return outputs

         def check_output(self, result):
-            self.parent.assertAllEqual(
-                result["embedding_output"].shape,
-                [self.batch_size, self.seq_length, self.hidden_size])
-            self.parent.assertAllEqual(
-                result["sequence_output"].shape,
+            self.parent.assertListEqual(
+                list(result["sequence_output"].size()),
                 [self.batch_size, self.seq_length, self.hidden_size])
-            self.parent.assertAllEqual(
-                result["pooled_output"].shape,
-                [self.batch_size, self.hidden_size])
+            self.parent.assertListEqual(
+                list(result["pooled_output"].size()),
+                [self.batch_size, self.hidden_size])

     def test_default(self):
         self.run_tester(BertModelTest.BertModelTester(self))
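Aside: the hunk above captures the core API shift of the port. The TF model was wired into a graph at construction time and exposed get_*() accessors; the PyTorch BertModel is built from a config alone and called eagerly, returning (all_encoder_layers, pooled_output). A minimal usage sketch, with illustrative shapes, and with the BertConfig hyperparameter names assumed to mirror the TF version:

    import torch
    import modeling  # this repository's PyTorch BERT port

    # Hyperparameter names assumed to follow the TF BertConfig; values are illustrative.
    config = modeling.BertConfig(vocab_size=99, hidden_size=32, num_hidden_layers=5,
                                 num_attention_heads=4, intermediate_size=37)
    model = modeling.BertModel(config=config)

    batch_size, seq_length = 13, 7
    input_ids = torch.randint(0, 99, (batch_size, seq_length))       # token ids
    token_type_ids = torch.zeros(batch_size, seq_length, dtype=torch.long)
    input_mask = torch.ones(batch_size, seq_length, dtype=torch.long)

    # Eager forward pass instead of sess.run(): one hidden-state tensor per
    # layer, plus the pooled output of the [CLS] position.
    all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
    assert list(all_encoder_layers[-1].size()) == [batch_size, seq_length, 32]
    assert list(pooled_output.size()) == [batch_size, 32]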
@@ -132,15 +121,11 @@ class BertModelTest(tf.test.TestCase):
         self.assertEqual(obj["hidden_size"], 37)

     def run_tester(self, tester):
-        with self.test_session() as sess:
-            ops = tester.create_model()
-            init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
-            sess.run(init_op)
-            output_result = sess.run(ops)
+        output_result = tester.create_model()
         tester.check_output(output_result)

-        self.assert_all_tensors_reachable(sess, [init_op, ops])
+        # TODO Find PyTorch equivalent of assert_all_tensors_reachable() if necessary
+        # self.assert_all_tensors_reachable(sess, [init_op, ops])

     @classmethod
     def ids_tensor(cls, shape, vocab_size, rng=None, name=None):
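On the TODO left in run_tester: the TF check guarded against ops being silently disconnected from the fetched outputs, a failure mode that eager PyTorch largely removes. If an equivalent is ever wanted, one possible stand-in (a sketch, not part of this commit, assuming access to the model instance) is to verify that every parameter receives a gradient:

    def assert_all_parameters_reachable(model, outputs):
        # Hypothetical PyTorch analogue of assert_all_tensors_reachable():
        # backpropagate a scalar derived from the outputs and flag any
        # parameter that never received a gradient.
        model.zero_grad()
        outputs["pooled_output"].sum().backward()
        unused = [name for name, p in model.named_parameters() if p.grad is None]
        assert not unused, "parameters unreachable from outputs: %s" % unused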
@@ -156,7 +141,8 @@ class BertModelTest(tf.test.TestCase):
         for _ in range(total_dims):
             values.append(rng.randint(0, vocab_size - 1))

-        return tf.constant(value=values, dtype=tf.int32, shape=shape, name=name)
+        # TODO Solve : the returned tensors provoke index out of range errors when passed to the model
+        return torch.tensor(data=values, dtype=torch.int32)

     def assert_all_tensors_reachable(self, sess, outputs):
         """Checks that all the tensors in the graph are reachable from outputs."""
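The "index out of range" TODO above has a plausible cause visible in the diff itself: tf.constant(..., shape=shape) reshaped the flat values list, while the port returns a flat tensor, so a (batch_size, seq_length) input arrives as one long sequence and can overrun the position-embedding table; nn.Embedding lookups also require int64 (long) indices rather than int32. A possible fix, assuming callers expect the TF shape semantics:

    # Reshape as tf.constant(value=values, shape=shape) did, and use torch.long,
    # which embedding lookups require for index tensors.
    return torch.tensor(data=values, dtype=torch.long).view(shape)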
@@ -272,4 +258,4 @@ class BertModelTest(tf.test.TestCase):


 if __name__ == "__main__":
-    tf.test.main()
+    unittest.main()
optimization_test_pytorch.py → tests/optimization_test.py

@@ -20,7 +20,7 @@ import unittest
 import torch

-import optimization_pytorch as optimization
+import optimization as optimization


 class OptimizationTest(unittest.TestCase):
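The page does not show OptimizationTest's body, only the import rename (note that `import optimization as optimization` is equivalent to plain `import optimization`; the alias is a leftover from the `optimization_pytorch` name). For reference, the usual shape of such a test is to minimize a tiny objective and assert convergence; a generic sketch using torch.optim.Adam as a stand-in, since the repository's own optimizer class is not visible on this page:

    import unittest
    import torch

    class OptimizerConvergenceTest(unittest.TestCase):
        # Generic optimizer unit-test pattern: fit a 3-element parameter
        # vector to a target with MSE and check it converges.
        def test_adam_converges(self):
            w = torch.tensor([0.1, -0.2, -0.1], requires_grad=True)
            target = torch.tensor([0.4, 0.2, -0.5])
            criterion = torch.nn.MSELoss()
            optimizer = torch.optim.Adam([w], lr=0.2)  # stand-in for the repo's optimizer
            for _ in range(100):
                loss = criterion(w, target)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            self.assertTrue(torch.allclose(w, target, atol=1e-2))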
tensorflow_code/tokenization_test.py → tests/tokenization_test.py
@@ -17,45 +17,44 @@ from __future__ import division
 from __future__ import print_function

 import os
-import tempfile
+import unittest

-from tensorflow_code import tokenization
-import tensorflow as tf
+import tokenization as tokenization


-class TokenizationTest(tf.test.TestCase):
+class TokenizationTest(unittest.TestCase):

     def test_full_tokenizer(self):
         vocab_tokens = [
             "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un",
             "runn", "##ing", ","
         ]
-        with tempfile.NamedTemporaryFile(delete=False) as vocab_writer:
+        with open("/tmp/bert_tokenizer_test.txt", "w") as vocab_writer:
             vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))

             vocab_file = vocab_writer.name
             tokenizer = tokenization.FullTokenizer(vocab_file)

-        os.unlink(vocab_file)
+        os.remove(vocab_file)

         tokens = tokenizer.tokenize(u"UNwant\u00E9d,running")
-        self.assertAllEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"])
+        self.assertListEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"])

-        self.assertAllEqual(
+        self.assertListEqual(
             tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])

     def test_basic_tokenizer_lower(self):
         tokenizer = tokenization.BasicTokenizer(do_lower_case=True)

-        self.assertAllEqual(
+        self.assertListEqual(
             tokenizer.tokenize(u" \tHeLLo!how  \n Are yoU?  "),
             ["hello", "!", "how", "are", "you", "?"])
-        self.assertAllEqual(tokenizer.tokenize(u"H\u00E9llo"), ["hello"])
+        self.assertListEqual(tokenizer.tokenize(u"H\u00E9llo"), ["hello"])

     def test_basic_tokenizer_no_lower(self):
         tokenizer = tokenization.BasicTokenizer(do_lower_case=False)

-        self.assertAllEqual(
+        self.assertListEqual(
             tokenizer.tokenize(u" \tHeLLo!how  \n Are yoU?  "),
             ["HeLLo", "!", "how", "Are", "yoU", "?"])
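One portability note on the hunk above: the port trades tempfile.NamedTemporaryFile for a hard-coded /tmp path, which breaks on platforms without /tmp and can collide between concurrent test runs. A sketch of a portable variant that keeps the new open/os.remove structure (vocab_tokens and tokenization as in the surrounding test):

    import os
    import tempfile

    # delete=False so the file survives the with-block and FullTokenizer can reopen it.
    with tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False) as vocab_writer:
        vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
        vocab_file = vocab_writer.name

    tokenizer = tokenization.FullTokenizer(vocab_file)
    os.remove(vocab_file)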
@@ -70,13 +69,13 @@ class TokenizationTest(tf.test.TestCase):
             vocab[token] = i

         tokenizer = tokenization.WordpieceTokenizer(vocab=vocab)

-        self.assertAllEqual(tokenizer.tokenize(""), [])
+        self.assertListEqual(tokenizer.tokenize(""), [])

-        self.assertAllEqual(
+        self.assertListEqual(
             tokenizer.tokenize("unwanted running"),
             ["un", "##want", "##ed", "runn", "##ing"])

-        self.assertAllEqual(
+        self.assertListEqual(
             tokenizer.tokenize("unwantedX running"), ["[UNK]", "runn", "##ing"])

     def test_convert_tokens_to_ids(self):
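For context on the expectations in this hunk: WordpieceTokenizer implements greedy longest-match-first subword splitting, and a word that cannot be fully covered by vocabulary pieces collapses to [UNK]. A self-contained sketch of that algorithm (independent of this repository's tokenization.py):

    def wordpiece_tokenize(word, vocab, unk="[UNK]"):
        # Greedy longest-match-first: repeatedly take the longest prefix found
        # in the vocabulary; continuation pieces carry a "##" prefix.
        pieces, start = [], 0
        while start < len(word):
            end = len(word)
            cur = None
            while start < end:
                piece = word[start:end]
                if start > 0:
                    piece = "##" + piece
                if piece in vocab:
                    cur = piece
                    break
                end -= 1
            if cur is None:
                return [unk]  # no full cover -> the whole word is unknown
            pieces.append(cur)
            start = end
        return pieces

    vocab = {"un", "##want", "##ed", "runn", "##ing", "[UNK]"}
    assert wordpiece_tokenize("unwanted", vocab) == ["un", "##want", "##ed"]
    assert wordpiece_tokenize("unwantedX", vocab) == ["[UNK]"]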
@@ -89,7 +88,7 @@ class TokenizationTest(tf.test.TestCase):
         for (i, token) in enumerate(vocab_tokens):
             vocab[token] = i

-        self.assertAllEqual(
+        self.assertListEqual(
             tokenization.convert_tokens_to_ids(
                 vocab, ["un", "##want", "##ed", "runn", "##ing"]), [7, 4, 5, 8, 9])
@@ -121,5 +120,5 @@ class TokenizationTest(tf.test.TestCase):
         self.assertFalse(tokenization._is_punctuation(u" "))


-if __name__ == "__main__":
-    tf.test.main()
+if __name__ == '__main__':
+    unittest.main()
tokenization_pytorch.py → tokenization.py
File moved (no content changes).
tokenization_test_pytorch.py
deleted 100644 → 0
This diff is collapsed; the deleted file's contents are not shown on this page.
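With the tests now living under tests/ while modeling.py, optimization.py and tokenization.py stay at the repository root, the test files must be run with the root on sys.path. One way, via a hypothetical helper script at the repository root (the *_test.py pattern matches the file names moved in this commit):

    # run_tests.py -- hypothetical helper, executed from the repository root
    import sys
    import unittest

    sys.path.insert(0, ".")  # make modeling.py / optimization.py / tokenization.py importable

    suite = unittest.defaultTestLoader.discover("tests", pattern="*_test.py")
    unittest.TextTestRunner(verbosity=2).run(suite)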