Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
3471ff0d
Unverified
Commit
3471ff0d
authored
Dec 24, 2019
by
Anthony MOI
Browse files
FastPreTrainedTokenizer
parent
81db12c3
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
127 additions
and
0 deletions
+127
-0
src/transformers/tokenization_utils.py
src/transformers/tokenization_utils.py
+127
-0
No files found.
src/transformers/tokenization_utils.py
View file @
3471ff0d
...
...
@@ -1410,3 +1410,130 @@ class PreTrainedTokenizer(object):
.
replace
(
" 're"
,
"'re"
)
)
return
out_string
class FastPreTrainedTokenizer(PreTrainedTokenizer):
    """Tokenizer base class that delegates its work to a backend tokenizer object.

    The backend is read from ``self._tokenizer`` (encoding, decoding and
    vocabulary handling) and ``self._decoder`` (used by
    ``convert_tokens_to_string``). This class never assigns either attribute
    itself; presumably concrete subclasses set them -- confirm against the
    subclasses. While they are ``None`` the ``tokenizer`` / ``decoder``
    properties raise ``NotImplementedError``.
    """

    # Class-level defaults: without them the ``is None`` checks in the
    # properties below would fail with AttributeError on an instance whose
    # backend was never assigned, instead of raising NotImplementedError.
    _tokenizer = None
    _decoder = None

    def __init__(self, **kwargs):
        """Forward all keyword arguments to ``PreTrainedTokenizer.__init__``."""
        super(FastPreTrainedTokenizer, self).__init__(**kwargs)

    @property
    def tokenizer(self):
        """Return the backend tokenizer.

        Raises:
            NotImplementedError: if ``self._tokenizer`` is ``None``.
        """
        if self._tokenizer is None:
            raise NotImplementedError
        return self._tokenizer

    @property
    def decoder(self):
        """Return the backend decoder.

        Raises:
            NotImplementedError: if ``self._decoder`` is ``None``.
        """
        if self._decoder is None:
            raise NotImplementedError
        return self._decoder

    @property
    def vocab_size(self):
        """Return ``self.tokenizer.get_vocab_size(False)``.

        NOTE(review): the flag presumably excludes added tokens from the
        count -- confirm against the backend API.
        """
        return self.tokenizer.get_vocab_size(False)

    def __len__(self):
        """Return ``self.tokenizer.get_vocab_size(True)``.

        NOTE(review): the flag presumably includes added tokens in the
        count -- confirm against the backend API.
        """
        return self.tokenizer.get_vocab_size(True)

    def _update_special_tokens(self):
        """Register ``self.all_special_tokens`` with the backend tokenizer."""
        self.tokenizer.add_special_tokens(self.all_special_tokens)

    @staticmethod
    def _convert_encoding(
        encoding,
        return_tensors=None,
        return_token_type_ids=True,
        return_attention_mask=True,
        return_overflowing_tokens=False,
        return_special_tokens_mask=False
    ):
        """Turn a backend encoding object into a plain dictionary.

        Args:
            encoding: object exposing ``ids``, ``type_ids``, ``attention_mask``,
                ``overflowing`` and ``special_tokens_mask`` attributes (only the
                attributes selected by the flags below are read).
            return_tensors: ``None`` to keep the raw values, ``'tf'`` or ``'pt'``
                to wrap them in TensorFlow / PyTorch tensors.
            return_token_type_ids: include ``"token_type_ids"`` in the result.
            return_attention_mask: include ``"attention_mask"`` in the result.
            return_overflowing_tokens: include ``"overflowing_tokens"``; an empty
                list is used when ``encoding.overflowing`` is ``None``.
            return_special_tokens_mask: include ``"special_tokens_mask"``.

        Returns:
            dict: always contains ``"input_ids"`` plus the requested entries.
            When tensors are requested and the framework is available,
            ``"input_ids"``, ``"token_type_ids"`` and ``"attention_mask"``
            (those that are present) are each wrapped in a one-element outer
            list and converted to a tensor.
        """
        encoding_dict = {"input_ids": encoding.ids}
        if return_token_type_ids:
            encoding_dict["token_type_ids"] = encoding.type_ids
        if return_attention_mask:
            encoding_dict["attention_mask"] = encoding.attention_mask
        if return_overflowing_tokens:
            overflowing = encoding.overflowing
            encoding_dict["overflowing_tokens"] = overflowing.ids if overflowing is not None else []
        if return_special_tokens_mask:
            encoding_dict["special_tokens_mask"] = encoding.special_tokens_mask

        # Prepare inputs as tensors if asked. Only keys actually present are
        # converted: "token_type_ids" and "attention_mask" are optional, and
        # indexing a missing one would raise KeyError.
        tensor_keys = ("input_ids", "token_type_ids", "attention_mask")
        if return_tensors == 'tf' and is_tf_available():
            for key in tensor_keys:
                if key in encoding_dict:
                    encoding_dict[key] = tf.constant([encoding_dict[key]])
        elif return_tensors == 'pt' and is_torch_available():
            for key in tensor_keys:
                if key in encoding_dict:
                    encoding_dict[key] = torch.tensor([encoding_dict[key]])
        elif return_tensors is not None:
            logger.warning(
                "Unable to convert output to tensors format {}, PyTorch or TensorFlow is not available.".format(
                    return_tensors
                )
            )

        return encoding_dict

    def encode_plus(
        self,
        text,
        text_pair=None,
        return_tensors=None,
        return_token_type_ids=True,
        return_attention_mask=True,
        return_overflowing_tokens=False,
        return_special_tokens_mask=False,
        **kwargs
    ):
        """Encode ``text`` (and optional ``text_pair``) with the backend tokenizer.

        The ``return_*`` flags are forwarded to ``_convert_encoding``, which
        builds the returned dictionary. Any extra ``**kwargs`` are accepted
        but not used by this method.
        """
        encoding = self.tokenizer.encode(text, text_pair)
        return self._convert_encoding(
            encoding,
            return_tensors=return_tensors,
            return_token_type_ids=return_token_type_ids,
            return_attention_mask=return_attention_mask,
            return_overflowing_tokens=return_overflowing_tokens,
            return_special_tokens_mask=return_special_tokens_mask
        )

    def tokenize(self, text):
        """Return the ``tokens`` attribute of the backend encoding of ``text``."""
        return self.tokenizer.encode(text).tokens

    def _convert_token_to_id_with_added_voc(self, token):
        """Look up ``token`` through the backend's ``token_to_id``."""
        return self.tokenizer.token_to_id(token)

    def _convert_id_to_token(self, index):
        """Look up ``index`` (coerced with ``int``) through the backend's ``id_to_token``."""
        return self.tokenizer.id_to_token(int(index))

    def convert_tokens_to_string(self, tokens):
        """Join ``tokens`` into a string using the backend decoder."""
        return self.decoder.decode(tokens)

    def add_tokens(self, new_tokens):
        """Add ``new_tokens`` to the backend tokenizer.

        Returns:
            Whatever the backend's ``add_tokens`` returns. The value used to
            be discarded; presumably it is the number of tokens added --
            confirm against the backend API.
        """
        return self.tokenizer.add_tokens(new_tokens)

    def encode_batch(
        self,
        texts,
        return_tensors=None,
        return_token_type_ids=True,
        return_attention_mask=True,
        return_overflowing_tokens=False,
        return_special_tokens_mask=False
    ):
        """Encode several inputs at once with the backend's ``encode_batch``.

        Returns:
            list: one dictionary per backend encoding, each built by
            ``_convert_encoding`` with the given ``return_*`` flags.
        """
        return [
            self._convert_encoding(
                encoding,
                return_tensors=return_tensors,
                return_token_type_ids=return_token_type_ids,
                return_attention_mask=return_attention_mask,
                return_overflowing_tokens=return_overflowing_tokens,
                return_special_tokens_mask=return_special_tokens_mask
            )
            for encoding in self.tokenizer.encode_batch(texts)
        ]

    def decode(self, token_ids, skip_special_tokens=False, clean_up_tokenization_spaces=True):
        """Decode ``token_ids`` to text with the backend tokenizer.

        When ``clean_up_tokenization_spaces`` is true the decoded text is
        passed through ``self.clean_up_tokenization`` before being returned.
        """
        text = self.tokenizer.decode(token_ids, skip_special_tokens)
        if clean_up_tokenization_spaces:
            return self.clean_up_tokenization(text)
        return text

    def decode_batch(
        self,
        ids_batch,
        skip_special_tokens=False,
        clear_up_tokenization_spaces=True,
        clean_up_tokenization_spaces=None
    ):
        """Decode several id sequences at once with the backend's ``decode_batch``.

        Args:
            ids_batch: batch of id sequences, forwarded to the backend.
            skip_special_tokens: forwarded to the backend.
            clear_up_tokenization_spaces: original (misspelled) name of the
                clean-up flag, kept so existing callers continue to work.
            clean_up_tokenization_spaces: same flag spelled as in ``decode``;
                when not ``None`` it takes precedence over the misspelled one.

        Returns:
            list: decoded strings, each passed through
            ``self.clean_up_tokenization`` when the clean-up flag is true.
        """
        if clean_up_tokenization_spaces is None:
            clean_up_tokenization_spaces = clear_up_tokenization_spaces

        texts = self.tokenizer.decode_batch(ids_batch, skip_special_tokens)
        if not clean_up_tokenization_spaces:
            # Still hand back a fresh list, as the comprehension below does.
            return list(texts)
        return [self.clean_up_tokenization(text) for text in texts]
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment