chenpangpang / transformers · Commits · c6bea084

Commit c6bea084 · authored Feb 13, 2019 by thomwolf · parent e7cfc46f

OpenAI GPT Tokenizer can fall back to using BERT BasicTokenizer
Showing 1 changed file with 25 additions and 10 deletions.

pytorch_pretrained_bert/tokenization_openai.py (+25, -10)
```diff
@@ -26,6 +26,7 @@ from io import open
 from tqdm import tqdm

 from .file_utils import cached_path
+from .tokenization import BasicTokenizer

 logger = logging.getLogger(__name__)
```
```diff
@@ -72,8 +73,9 @@ class OpenAIGPTTokenizer(object):
     """
     BPE tokenizer. Peculiarities:
         - lower case all inputs
-        - uses SpaCy tokenizer
-        - special tokens: additional symbols (ex: "__classify__") to add to a vocabulary.
+        - uses SpaCy tokenizer and ftfy for pre-BPE tokenization if they are installed, fallback to BERT's BasicTokenizer if not.
+        - argument special_tokens and function set_special_tokens:
+            can be used to add additional symbols (ex: "__classify__") to a vocabulary.
     """

     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
```
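For context, a minimal usage sketch of the behavior the new docstring describes. It assumes the `openai-gpt` pretrained shortcut name and that `from_pretrained` forwards `special_tokens` to the constructor through `**kwargs`; neither detail is shown in this hunk.

```python
from pytorch_pretrained_bert.tokenization_openai import OpenAIGPTTokenizer

# 'openai-gpt' is assumed to be the pretrained shortcut that resolves to
# vocab and merges files via cached_path; special_tokens is assumed to be
# forwarded to the constructor through **kwargs.
tokenizer = OpenAIGPTTokenizer.from_pretrained('openai-gpt',
                                               special_tokens=['__classify__'])
```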
```diff
@@ -122,12 +124,15 @@ class OpenAIGPTTokenizer(object):
         try:
             import ftfy
             import spacy
+            self.nlp = spacy.load('en', disable=['parser', 'tagger', 'ner', 'textcat'])
+            self.fix_text = ftfy.fix_text
         except ImportError:
-            raise ImportError("Please install ftfy and spacy to use OpenAI GPT tokenizer.")
+            logger.warning("ftfy or spacy is not installed using BERT BasicTokenizer instead of SpaCy & ftfy.")
+            self.nlp = BasicTokenizer(do_lower_case=True,
+                                      never_split=special_tokens if special_tokens is not None else [])
+            self.fix_text = None

         self.max_len = max_len if max_len is not None else int(1e12)
-        self.nlp = spacy.load('en', disable=['parser', 'tagger', 'ner', 'textcat'])
-        self.fix_text = ftfy.fix_text
         self.encoder = json.load(open(vocab_file, encoding="utf-8"))
         self.decoder = {v: k for k, v in self.encoder.items()}
         merges = open(merges_file, encoding='utf-8').read().split('\n')[1:-1]
```
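A small sketch of how a caller might detect which backend was selected. It relies only on the convention visible in this hunk, where `fix_text` is set to `None` exactly when the BasicTokenizer fallback is taken; `tokenizer` is the instance from the sketch above, and the printed strings are illustrative.

```python
# fix_text doubles as a backend flag: None means the BERT BasicTokenizer
# fallback is in use; otherwise SpaCy + ftfy were imported successfully.
if tokenizer.fix_text is None:
    print("fallback: BERT BasicTokenizer")
else:
    print("original pipeline: SpaCy + ftfy")
```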
```diff
@@ -150,6 +155,9 @@ class OpenAIGPTTokenizer(object):
             return
         self.special_tokens = dict((tok, len(self.encoder) + i) for i, tok in enumerate(special_tokens))
         self.special_tokens_decoder = {v: k for k, v in self.special_tokens.items()}
+        if self.fix_text is None:
+            # Using BERT's BasicTokenizer: we can update the tokenizer
+            self.nlp.never_split = special_tokens
         logger.info("Special tokens {}".format(self.special_tokens))

     def bpe(self, token):
```
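A hedged sketch of calling `set_special_tokens`, the method this hunk modifies: new symbols get ids just past the BPE vocabulary (per the `len(self.encoder) + i` expression), and under the fallback `never_split` is refreshed so the symbols survive basic tokenization. That `convert_tokens_to_ids` consults the special-token map is an assumption not shown on this page.

```python
# Register extra symbols after construction; each is assigned an id just
# past the end of the BPE vocabulary. With the BasicTokenizer fallback
# active, never_split is updated so these symbols are not split apart.
tokenizer.set_special_tokens(['__classify__', '__sep__'])
ids = tokenizer.convert_tokens_to_ids(['__classify__', '__sep__'])
```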
```diff
@@ -198,9 +206,16 @@ class OpenAIGPTTokenizer(object):
     def tokenize(self, text):
         """ Tokenize a string. """
         split_tokens = []
-        text = self.nlp(text_standardize(self.fix_text(text)))
-        for token in text:
-            split_tokens.extend([t for t in self.bpe(token.text.lower()).split(' ')])
+        if self.fix_text is None:
+            # Using BERT's BasicTokenizer
+            text = self.nlp.tokenize(text)
+            for token in text:
+                split_tokens.extend([t for t in self.bpe(token).split(' ')])
+        else:
+            # Using SpaCy & ftfy (original tokenization process of OpenAI GPT)
+            text = self.nlp(text_standardize(self.fix_text(text)))
+            for token in text:
+                split_tokens.extend([t for t in self.bpe(token.text.lower()).split(' ')])
         return split_tokens

     def convert_tokens_to_ids(self, tokens):
```
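Both branches produce the same shape of output, a flat list of BPE sub-word strings, so callers do not need to know which backend ran. A minimal sketch (the input string is illustrative):

```python
# Same call regardless of backend; only the pre-BPE splitting differs.
tokens = tokenizer.tokenize("Hello world, this is the OpenAI GPT tokenizer!")
ids = tokenizer.convert_tokens_to_ids(tokens)
```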
```diff
@@ -219,8 +234,8 @@ class OpenAIGPTTokenizer(object):
         if len(ids) > self.max_len:
             raise ValueError(
                 "Token indices sequence length is longer than the specified maximum "
-                " sequence length for this BERT model ({} > {}). Running this"
-                " sequence through BERT will result in indexing errors".format(len(ids), self.max_len)
+                " sequence length for this OpenAI GPT model ({} > {}). Running this"
+                " sequence through the model will result in indexing errors".format(len(ids), self.max_len)
            )
         return ids
```
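Since the constructor sets `max_len` to `int(1e12)` when no limit is given, this `ValueError` only fires for models with a real limit. A hypothetical caller-side guard (the truncation policy is an assumption, not part of this commit):

```python
long_text = "some very long document " * 1000  # placeholder input

# Hypothetical guard: truncate the token list to the model's maximum
# sequence length before converting, avoiding the ValueError above.
tokens = tokenizer.tokenize(long_text)
ids = tokenizer.convert_tokens_to_ids(tokens[:tokenizer.max_len])
```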