Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
31c56f2e
Unverified
Commit
31c56f2e
authored
Dec 24, 2019
by
Anthony MOI
Browse files
Fix style
parent
951ae99b
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
126 additions
and
73 deletions
+126
-73
src/transformers/tokenization_bert.py
src/transformers/tokenization_bert.py
+48
-26
src/transformers/tokenization_gpt2.py
src/transformers/tokenization_gpt2.py
+21
-7
src/transformers/tokenization_utils.py
src/transformers/tokenization_utils.py
+57
-40
No files found.
src/transformers/tokenization_bert.py
View file @
31c56f2e
...
...
@@ -20,7 +20,7 @@ import logging
import
os
import
unicodedata
from
.tokenization_utils
import
PreTrainedTokenizer
,
Fast
PreTrainedTokenizer
from
.tokenization_utils
import
Fast
PreTrainedTokenizer
,
PreTrainedTokenizer
logger
=
logging
.
getLogger
(
__name__
)
...
...
@@ -526,42 +526,64 @@ def _is_punctuation(char):
return
True
return
False
class
BertTokenizerFast
(
FastPreTrainedTokenizer
):
vocab_files_names
=
VOCAB_FILES_NAMES
pretrained_vocab_files_map
=
PRETRAINED_VOCAB_FILES_MAP
pretrained_init_configuration
=
PRETRAINED_INIT_CONFIGURATION
max_model_input_sizes
=
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
def
__init__
(
self
,
vocab_file
,
do_lower_case
=
True
,
do_basic_tokenize
=
True
,
never_split
=
None
,
unk_token
=
"[UNK]"
,
sep_token
=
"[SEP]"
,
pad_token
=
"[PAD]"
,
cls_token
=
"[CLS]"
,
mask_token
=
"[MASK]"
,
tokenize_chinese_chars
=
True
,
max_length
=
None
,
pad_to_max_length
=
False
,
stride
=
0
,
truncation_strategy
=
'longest_first'
,
add_special_tokens
=
True
,
**
kwargs
):
def
__init__
(
self
,
vocab_file
,
do_lower_case
=
True
,
do_basic_tokenize
=
True
,
never_split
=
None
,
unk_token
=
"[UNK]"
,
sep_token
=
"[SEP]"
,
pad_token
=
"[PAD]"
,
cls_token
=
"[CLS]"
,
mask_token
=
"[MASK]"
,
tokenize_chinese_chars
=
True
,
max_length
=
None
,
pad_to_max_length
=
False
,
stride
=
0
,
truncation_strategy
=
"longest_first"
,
add_special_tokens
=
True
,
**
kwargs
):
try
:
from
tokenizers
import
Tokenizer
,
models
,
pre_tokenizers
,
decoders
,
processors
super
(
BertTokenizerFast
,
self
).
__init__
(
unk_token
=
unk_token
,
sep_token
=
sep_token
,
pad_token
=
pad_token
,
cls_token
=
cls_token
,
mask_token
=
mask_token
,
**
kwargs
)
self
.
_tokenizer
=
Tokenizer
(
models
.
WordPiece
.
from_files
(
vocab_file
,
unk_token
=
unk_token
))
super
(
BertTokenizerFast
,
self
).
__init__
(
unk_token
=
unk_token
,
sep_token
=
sep_token
,
pad_token
=
pad_token
,
cls_token
=
cls_token
,
mask_token
=
mask_token
,
**
kwargs
)
self
.
_tokenizer
=
Tokenizer
(
models
.
WordPiece
.
from_files
(
vocab_file
,
unk_token
=
unk_token
))
self
.
_update_special_tokens
()
self
.
_tokenizer
.
with_pre_tokenizer
(
pre_tokenizers
.
BertPreTokenizer
.
new
(
self
.
_tokenizer
.
with_pre_tokenizer
(
pre_tokenizers
.
BertPreTokenizer
.
new
(
do_basic_tokenize
=
do_basic_tokenize
,
do_lower_case
=
do_lower_case
,
tokenize_chinese_chars
=
tokenize_chinese_chars
,
never_split
=
never_split
if
never_split
is
not
None
else
[],
))
)
)
self
.
_tokenizer
.
with_decoder
(
decoders
.
WordPiece
.
new
())
if
add_special_tokens
:
self
.
_tokenizer
.
with_post_processor
(
processors
.
BertProcessing
.
new
(
self
.
_tokenizer
.
with_post_processor
(
processors
.
BertProcessing
.
new
(
(
sep_token
,
self
.
_tokenizer
.
token_to_id
(
sep_token
)),
(
cls_token
,
self
.
_tokenizer
.
token_to_id
(
cls_token
)),
))
)
)
if
max_length
is
not
None
:
self
.
_tokenizer
.
with_truncation
(
max_length
,
stride
,
truncation_strategy
)
self
.
_tokenizer
.
with_padding
(
...
...
@@ -569,7 +591,7 @@ class BertTokenizerFast(FastPreTrainedTokenizer):
self
.
padding_side
,
self
.
pad_token_id
,
self
.
pad_token_type_id
,
self
.
pad_token
self
.
pad_token
,
)
self
.
_decoder
=
decoders
.
WordPiece
.
new
()
...
...
src/transformers/tokenization_gpt2.py
View file @
31c56f2e
...
...
@@ -22,7 +22,7 @@ from functools import lru_cache
import
regex
as
re
from
.tokenization_utils
import
PreTrainedTokenizer
,
Fast
PreTrainedTokenizer
from
.tokenization_utils
import
Fast
PreTrainedTokenizer
,
PreTrainedTokenizer
logger
=
logging
.
getLogger
(
__name__
)
...
...
@@ -247,19 +247,33 @@ class GPT2Tokenizer(PreTrainedTokenizer):
return
vocab_file
,
merge_file
class
GPT2TokenizerFast
(
FastPreTrainedTokenizer
):
vocab_files_names
=
VOCAB_FILES_NAMES
pretrained_vocab_files_map
=
PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes
=
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
def
__init__
(
self
,
vocab_file
,
merges_file
,
unk_token
=
"<|endoftext|>"
,
bos_token
=
"<|endoftext|>"
,
eos_token
=
"<|endoftext|>"
,
pad_to_max_length
=
False
,
add_prefix_space
=
False
,
max_length
=
None
,
stride
=
0
,
truncation_strategy
=
'longest_first'
,
**
kwargs
):
def
__init__
(
self
,
vocab_file
,
merges_file
,
unk_token
=
"<|endoftext|>"
,
bos_token
=
"<|endoftext|>"
,
eos_token
=
"<|endoftext|>"
,
pad_to_max_length
=
False
,
add_prefix_space
=
False
,
max_length
=
None
,
stride
=
0
,
truncation_strategy
=
"longest_first"
,
**
kwargs
):
try
:
from
tokenizers
import
Tokenizer
,
models
,
pre_tokenizers
,
decoders
super
(
GPT2TokenizerFast
,
self
).
__init__
(
bos_token
=
bos_token
,
eos_token
=
eos_token
,
unk_token
=
unk_token
,
**
kwargs
)
super
(
GPT2TokenizerFast
,
self
).
__init__
(
bos_token
=
bos_token
,
eos_token
=
eos_token
,
unk_token
=
unk_token
,
**
kwargs
)
self
.
_tokenizer
=
Tokenizer
(
models
.
BPE
.
from_files
(
vocab_file
,
merges_file
))
self
.
_update_special_tokens
()
...
...
@@ -272,7 +286,7 @@ class GPT2TokenizerFast(FastPreTrainedTokenizer):
self
.
padding_side
,
self
.
pad_token_id
if
self
.
pad_token_id
is
not
None
else
0
,
self
.
pad_token_type_id
,
self
.
pad_token
if
self
.
pad_token
is
not
None
else
""
self
.
pad_token
if
self
.
pad_token
is
not
None
else
""
,
)
self
.
_decoder
=
decoders
.
ByteLevel
.
new
()
...
...
src/transformers/tokenization_utils.py
View file @
31c56f2e
...
...
@@ -1411,6 +1411,7 @@ class PreTrainedTokenizer(object):
)
return
out_string
class
FastPreTrainedTokenizer
(
PreTrainedTokenizer
):
def
__init__
(
self
,
**
kwargs
):
super
(
FastPreTrainedTokenizer
,
self
).
__init__
(
**
kwargs
)
...
...
@@ -1438,12 +1439,14 @@ class FastPreTrainedTokenizer(PreTrainedTokenizer):
self
.
tokenizer
.
add_special_tokens
(
self
.
all_special_tokens
)
@
staticmethod
def
_convert_encoding
(
encoding
,
def
_convert_encoding
(
encoding
,
return_tensors
=
None
,
return_token_type_ids
=
True
,
return_attention_mask
=
True
,
return_overflowing_tokens
=
False
,
return_special_tokens_mask
=
False
):
return_special_tokens_mask
=
False
,
):
encoding_dict
=
{
"input_ids"
:
encoding
.
ids
,
}
...
...
@@ -1458,14 +1461,14 @@ class FastPreTrainedTokenizer(PreTrainedTokenizer):
encoding_dict
[
"special_tokens_mask"
]
=
encoding
.
special_tokens_mask
# Prepare inputs as tensors if asked
if
return_tensors
==
'
tf
'
and
is_tf_available
():
if
return_tensors
==
"
tf
"
and
is_tf_available
():
encoding_dict
[
"input_ids"
]
=
tf
.
constant
([
encoding_dict
[
"input_ids"
]])
encoding_dict
[
"token_type_ids"
]
=
tf
.
constant
([
encoding_dict
[
"token_type_ids"
]])
if
"attention_mask"
in
encoding_dict
:
encoding_dict
[
"attention_mask"
]
=
tf
.
constant
([
encoding_dict
[
"attention_mask"
]])
elif
return_tensors
==
'
pt
'
and
is_torch_available
():
elif
return_tensors
==
"
pt
"
and
is_torch_available
():
encoding_dict
[
"input_ids"
]
=
torch
.
tensor
([
encoding_dict
[
"input_ids"
]])
encoding_dict
[
"token_type_ids"
]
=
torch
.
tensor
([
encoding_dict
[
"token_type_ids"
]])
...
...
@@ -1474,11 +1477,14 @@ class FastPreTrainedTokenizer(PreTrainedTokenizer):
elif
return_tensors
is
not
None
:
logger
.
warning
(
"Unable to convert output to tensors format {}, PyTorch or TensorFlow is not available."
.
format
(
return_tensors
))
return_tensors
)
)
return
encoding_dict
def
encode_plus
(
self
,
def
encode_plus
(
self
,
text
,
text_pair
=
None
,
return_tensors
=
None
,
...
...
@@ -1486,14 +1492,17 @@ class FastPreTrainedTokenizer(PreTrainedTokenizer):
return_attention_mask
=
True
,
return_overflowing_tokens
=
False
,
return_special_tokens_mask
=
False
,
**
kwargs
):
**
kwargs
):
encoding
=
self
.
tokenizer
.
encode
(
text
,
text_pair
)
return
self
.
_convert_encoding
(
encoding
,
return
self
.
_convert_encoding
(
encoding
,
return_tensors
=
return_tensors
,
return_token_type_ids
=
return_token_type_ids
,
return_attention_mask
=
return_attention_mask
,
return_overflowing_tokens
=
return_overflowing_tokens
,
return_special_tokens_mask
=
return_special_tokens_mask
)
return_special_tokens_mask
=
return_special_tokens_mask
,
)
def
tokenize
(
self
,
text
):
return
self
.
tokenizer
.
encode
(
text
).
tokens
...
...
@@ -1510,19 +1519,26 @@ class FastPreTrainedTokenizer(PreTrainedTokenizer):
def
add_tokens
(
self
,
new_tokens
):
self
.
tokenizer
.
add_tokens
(
new_tokens
)
def
encode_batch
(
self
,
texts
,
def
encode_batch
(
self
,
texts
,
return_tensors
=
None
,
return_token_type_ids
=
True
,
return_attention_mask
=
True
,
return_overflowing_tokens
=
False
,
return_special_tokens_mask
=
False
):
return
[
self
.
_convert_encoding
(
encoding
,
return_special_tokens_mask
=
False
,
):
return
[
self
.
_convert_encoding
(
encoding
,
return_tensors
=
return_tensors
,
return_token_type_ids
=
return_token_type_ids
,
return_attention_mask
=
return_attention_mask
,
return_overflowing_tokens
=
return_overflowing_tokens
,
return_special_tokens_mask
=
return_special_tokens_mask
)
for
encoding
in
self
.
tokenizer
.
encode_batch
(
texts
)]
return_special_tokens_mask
=
return_special_tokens_mask
,
)
for
encoding
in
self
.
tokenizer
.
encode_batch
(
texts
)
]
def
decode
(
self
,
token_ids
,
skip_special_tokens
=
False
,
clean_up_tokenization_spaces
=
True
):
text
=
self
.
tokenizer
.
decode
(
token_ids
,
skip_special_tokens
)
...
...
@@ -1534,6 +1550,7 @@ class FastPreTrainedTokenizer(PreTrainedTokenizer):
return
text
def
decode_batch
(
self
,
ids_batch
,
skip_special_tokens
=
False
,
clear_up_tokenization_spaces
=
True
):
return
[
self
.
clean_up_tokenization
(
text
)
if
clear_up_tokenization_spaces
else
text
for
text
in
self
.
tokenizer
.
decode_batch
(
ids_batch
,
skip_special_tokens
)]
\ No newline at end of file
return
[
self
.
clean_up_tokenization
(
text
)
if
clear_up_tokenization_spaces
else
text
for
text
in
self
.
tokenizer
.
decode_batch
(
ids_batch
,
skip_special_tokens
)
]
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment