chenpangpang/transformers · commit dcc9bb32

Commit dcc9bb32, authored Sep 19, 2019 by LysandreJik (parent af23b626)

Modified encode to return only lists. Added a more complete encode_plus method
Showing 1 changed file with 105 additions and 8 deletions.

pytorch_transformers/tokenization_utils.py  +105 -8  (view file @ dcc9bb32)
...
@@ -535,7 +535,7 @@ class PreTrainedTokenizer(object):
"""
"""
if
pair
:
if
pair
:
initial_tokens_len
=
sum
([
len
(
encoded
)
for
encoded
in
self
.
encode
(
"This is a sequence"
,
"This is another"
)
]
)
initial_tokens_len
=
len
(
self
.
encode
(
"This is a sequence"
)
+
self
.
encode
(
"This is another"
))
final_tokens
=
self
.
encode
(
"This is a sequence"
,
"This is another"
,
add_special_tokens
=
True
)
final_tokens
=
self
.
encode
(
"This is a sequence"
,
"This is another"
,
add_special_tokens
=
True
)
# In some models (e.g. GPT-2), there is no sequence pair encoding.
# In some models (e.g. GPT-2), there is no sequence pair encoding.
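With this change, encode always returns one flat list of ids, so the number of special tokens wrapped around a pair can no longer be read off a per-sequence structure; it is instead recovered by encoding the two texts separately and comparing lengths. A minimal sketch of that arithmetic with hypothetical toy ids (no real tokenizer involved):

    # Hypothetical ids standing in for the self.encode(...) calls above.
    ids_a = [7, 8, 9]                           # encode("This is a sequence")
    ids_b = [4, 5]                              # encode("This is another")
    ids_pair = [101, 7, 8, 9, 102, 4, 5, 102]   # encode(a, b, add_special_tokens=True), BERT-style

    initial_tokens_len = len(ids_a + ids_b)
    num_added_tokens = len(ids_pair) - initial_tokens_len   # -> 3 for a BERT-like pair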
...
@@ -693,10 +693,39 @@ class PreTrainedTokenizer(object):
     def _convert_token_to_id(self, token):
         raise NotImplementedError
 
-    def encode(self, text, text_pair=None, add_special_tokens=False, output_mask=False, max_length=None, **kwargs):
+    def encode(self, text, text_pair=None, add_special_tokens=False, **kwargs):
         """
         Converts a string in a sequence of ids (integer), using the tokenizer and vocabulary.
+        Same as doing ``self.convert_tokens_to_ids(self.tokenize(text))``.
+
+        Args:
+            text: The first sequence to be encoded.
+            text_pair: Optional second sequence to be encoded.
+            add_special_tokens: if set to ``True``, the sequences will be encoded with the special tokens relative
+                to their model.
+        """
+        if text_pair is None:
+            if add_special_tokens:
+                sequence_tokens = self.convert_tokens_to_ids(self.tokenize(text, **kwargs))
+                return self.add_special_tokens_single_sentence(sequence_tokens)
+            else:
+                ids = self.convert_tokens_to_ids(self.tokenize(text, **kwargs))
+                return ids
+
+        first_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text, **kwargs)]
+        second_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text_pair, **kwargs)]
+
+        if add_special_tokens:
+            return self.add_special_tokens_sentences_pair(first_sentence_tokens, second_sentence_tokens)
+        else:
+            logger.warning("No special tokens were added. The two sequences have been concatenated.")
+            return first_sentence_tokens + second_sentence_tokens
+
+    def encode_plus(self, text, text_pair=None, add_special_tokens=False, output_mask=False, max_length=None, **kwargs):
+        """
+        Converts a string in a sequence of ids (integer), using the tokenizer and vocabulary.
 
         Same as doing ``self.convert_tokens_to_ids(self.tokenize(text))``.
 
         Args:
...
@@ -709,6 +738,69 @@ class PreTrainedTokenizer(object):
             max_length: if set to a number, will limit the total sequence returned so that it has a maximum length.
             **kwargs: passed to the `self.tokenize()` method
         """
+        information = {}
+
+        if text_pair is None:
+            n_added_tokens = self.num_added_tokens()
+            if add_special_tokens:
+                sequence_tokens = self.convert_tokens_to_ids(self.tokenize(text, **kwargs))
+                if max_length:
+                    information["overflowing_tokens"] = sequence_tokens[max_length - n_added_tokens:]
+                    sequence_tokens = sequence_tokens[:max_length - n_added_tokens]
+                sequence = self.add_special_tokens_single_sentence(sequence_tokens)
+            else:
+                sequence_tokens = self.convert_tokens_to_ids(self.tokenize(text, **kwargs))
+                if max_length:
+                    information["overflowing_tokens"] = sequence_tokens[max_length:]
+                    sequence_tokens = sequence_tokens[:max_length]
+                sequence = sequence_tokens
+
+            if output_mask:
+                information["mask"] = [0] * len(sequence)
+            information["sequence"] = sequence
+        else:
+            first_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text, **kwargs)]
+            second_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text_pair, **kwargs)]
+            f_len, s_len = len(first_sentence_tokens), len(second_sentence_tokens)
+            n_added_tokens = self.num_added_tokens(pair=True)
+
+            if add_special_tokens:
+                if max_length:
+                    if len(first_sentence_tokens) + n_added_tokens >= max_length:
+                        logger.warning("The first sequence is longer than the maximum specified length. This sequence will not be truncated.")
+                    else:
+                        if f_len + s_len + self.num_added_tokens(pair=True) > max_length:
+                            information["overflowing_tokens"] = second_sentence_tokens[max_length - f_len - n_added_tokens:]
+                            second_sentence_tokens = second_sentence_tokens[:max_length - f_len - n_added_tokens]
+
+                encoded_sequence = self.add_special_tokens_sentences_pair(first_sentence_tokens, second_sentence_tokens, output_mask)
+                if output_mask:
+                    sequence, information["mask"] = encoded_sequence
+                else:
+                    sequence = encoded_sequence
+                information["sequence"] = sequence
+            else:
+                logger.warning("No special tokens were added. The two sequences have been concatenated.")
+                sequence = first_sentence_tokens + second_sentence_tokens
+                if max_length:
+                    information["overflowing_tokens"] = sequence[max_length:]
+                    sequence = sequence[:max_length]
+                if output_mask:
+                    information["mask"] = [0] * len(sequence)
+                information["sequence"] = sequence
+
+        return information
+
         if text_pair is None:
             if add_special_tokens:
                 sequence_tokens = self.convert_tokens_to_ids(self.tokenize(text, **kwargs))
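In the pair branch above, encode_plus reserves room for the model's special tokens before truncating the second sequence against max_length. A self-contained sketch of that slicing, with made-up ids and a hypothetical n_added_tokens of 3 (what a BERT-style pair would add):

    first_sentence_tokens = list(range(6))           # f_len = 6
    second_sentence_tokens = list(range(100, 110))   # s_len = 10
    n_added_tokens = 3                               # hypothetical, e.g. [CLS] a [SEP] b [SEP]
    max_length = 12

    cut = max_length - len(first_sentence_tokens) - n_added_tokens   # room left for the second sequence
    overflowing_tokens = second_sentence_tokens[cut:]                # the 7 ids that no longer fit
    second_sentence_tokens = second_sentence_tokens[:cut]

    # first + truncated second + special tokens fill max_length exactly
    assert len(first_sentence_tokens) + len(second_sentence_tokens) + n_added_tokens == max_length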
...
@@ -725,12 +817,17 @@ class PreTrainedTokenizer(object):
         if add_special_tokens:
             if max_length:
                 if len(first_sentence_tokens) + self.num_added_tokens(pair=True) >= max_length:
                     logger.warning("The first sequence is longer than the maximum specified length. This sequence will not be truncated.")
                 else:
-                    if len(second_sentence_tokens) + len(first_sentence_tokens) + self.num_added_tokens(pair=True) > max_length:
-                        second_sentence_tokens = second_sentence_tokens[:max_length - len(first_sentence_tokens) - self.num_added_tokens(pair=True)]
+                    if len(second_sentence_tokens) + len(first_sentence_tokens) + self.num_added_tokens(
+                            pair=True) > max_length:
+                        second_sentence_tokens = second_sentence_tokens[
+                            :max_length - len(first_sentence_tokens) - self.num_added_tokens(pair=True)]
             return self.add_special_tokens_sentences_pair(first_sentence_tokens, second_sentence_tokens, output_mask)
         else:
             if max_length:
                 first_sentence_tokens = first_sentence_tokens[:max_length]
...
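Taken together, after this commit encode hands back a plain list of ids, while encode_plus returns a dictionary with a "sequence" entry plus optional "mask" and "overflowing_tokens" entries. A rough usage sketch, assuming the pytorch_transformers package at this commit and a pretrained BERT vocabulary (the keys shown are the ones set in the diff above, so treat this as illustrative rather than a stable API):

    from pytorch_transformers import BertTokenizer

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

    # encode: always a flat list of token ids
    ids = tokenizer.encode("This is a sequence", "This is another", add_special_tokens=True)
    print(ids)

    # encode_plus: a dict; "overflowing_tokens" is filled when max_length truncates,
    # "mask" when output_mask=True
    info = tokenizer.encode_plus("This is a sequence", add_special_tokens=True,
                                 output_mask=True, max_length=8)
    print(info["sequence"], info.get("mask"), info.get("overflowing_tokens"))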