Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
0ea82b24
Commit
0ea82b24
authored
Sep 24, 2019
by
LysandreJik
Browse files
Updated tests
parent
9d44236f
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
9 additions
and
4 deletions
+9
-4
pytorch_transformers/tests/tokenization_tests_commons.py
pytorch_transformers/tests/tokenization_tests_commons.py
+8
-3
pytorch_transformers/tokenization_utils.py
pytorch_transformers/tokenization_utils.py
+1
-1
No files found.
pytorch_transformers/tests/tokenization_tests_commons.py
View file @
0ea82b24
...
@@ -264,9 +264,14 @@ class CommonTestCases:
...
@@ -264,9 +264,14 @@ class CommonTestCases:
assert
len
(
truncated_sequence
)
==
len
(
sequence
)
-
2
assert
len
(
truncated_sequence
)
==
len
(
sequence
)
-
2
assert
truncated_sequence
==
truncated_second_sequence
assert
truncated_sequence
==
truncated_second_sequence
def
test_
tokens_sent_to_encod
e
(
self
):
def
test_
encode_input_typ
e
(
self
):
tokenizer
=
self
.
get_tokenizer
()
tokenizer
=
self
.
get_tokenizer
()
sequence
=
"Let's encode this sequence"
sequence
=
"Let's encode this sequence"
tokens
=
tokenizer
.
encode
(
sequence
)
tokenizer
.
encode
(
tokens
,
add_special_tokens
=
True
)
tokens
=
tokenizer
.
tokenize
(
sequence
)
input_ids
=
tokenizer
.
convert_tokens_to_ids
(
tokens
)
formatted_input
=
tokenizer
.
encode
(
sequence
,
add_special_tokens
=
True
)
assert
tokenizer
.
encode
(
tokens
,
add_special_tokens
=
True
)
==
formatted_input
assert
tokenizer
.
encode
(
input_ids
,
add_special_tokens
=
True
)
==
formatted_input
pytorch_transformers/tokenization_utils.py
View file @
0ea82b24
...
@@ -744,7 +744,7 @@ class PreTrainedTokenizer(object):
...
@@ -744,7 +744,7 @@ class PreTrainedTokenizer(object):
def
get_input_ids
(
text
):
def
get_input_ids
(
text
):
if
isinstance
(
text
,
six
.
string_types
):
if
isinstance
(
text
,
six
.
string_types
):
input_ids
=
self
.
convert_tokens_to_ids
(
self
.
tokenize
(
text
,
**
kwargs
))
input_ids
=
self
.
convert_tokens_to_ids
(
self
.
tokenize
(
text
,
**
kwargs
))
elif
isinstance
(
text
,
(
list
,
tuple
))
and
len
(
text
)
>
0
and
isinstance
(
text
[
0
],
s
tr
):
elif
isinstance
(
text
,
(
list
,
tuple
))
and
len
(
text
)
>
0
and
isinstance
(
text
[
0
],
s
ix
.
string_types
):
input_ids
=
self
.
convert_tokens_to_ids
(
text
)
input_ids
=
self
.
convert_tokens_to_ids
(
text
)
elif
isinstance
(
text
,
(
list
,
tuple
))
and
len
(
text
)
>
0
and
isinstance
(
text
[
0
],
int
):
elif
isinstance
(
text
,
(
list
,
tuple
))
and
len
(
text
)
>
0
and
isinstance
(
text
[
0
],
int
):
input_ids
=
text
input_ids
=
text
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment