chenpangpang / transformers

Commit b450a7fa, authored Feb 18, 2019 by thomwolf

clean up tokenization - fix python 2 tests

parent d44db114
Showing 1 changed file with 17 additions and 5 deletions:

pytorch_pretrained_bert/tokenization_gpt2.py  (+17 −5)
@@ -20,14 +20,19 @@ import json
 import logging
 import os
 import regex as re
 import sys
 from io import open
-from functools import lru_cache
-from tqdm import tqdm
-
+try:
+    from functools import lru_cache
+except ImportError:
+    # Just a dummy decorator to get the checks to run on python2
+    # because honestly I don't want to support a byte-level unicode BPE tokenizer on python 2 right now.
+    def lru_cache(func):
+        def func_wrapper(*inputs, **args):
+            return func(inputs, args)
+        return func_wrapper

 from .file_utils import cached_path
-from .tokenization import BasicTokenizer

 logger = logging.getLogger(__name__)
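Worth noting about the fallback above: OpenAI's reference GPT-2 encoder applies the decorator with parentheses (@lru_cache() on bytes_to_unicode), which a stub taking func directly does not match, and func_wrapper forwards its arguments without the unpacking stars. A minimal sketch of a no-op stub that stays compatible with the parenthesized call form; this is an illustrative alternative, not the committed code:

# Illustrative sketch (not from this commit): a no-op lru_cache stub for
# Python 2 that supports the parenthesized form @lru_cache().
try:
    from functools import lru_cache
except ImportError:
    def lru_cache(*cache_args, **cache_kwargs):
        # Ignore maxsize-style arguments and return a decorator that
        # leaves the wrapped function untouched.
        def decorator(func):
            return func
        return decorator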
@@ -125,7 +130,8 @@ class GPT2Tokenizer(object):
         tokenizer = cls(resolved_vocab_file, resolved_merges_file, *inputs, **kwargs)
         return tokenizer

-    def __init__(self, vocab_file, merges_file, errors='replace'):
+    def __init__(self, vocab_file, merges_file, errors='replace', max_len=None):
+        self.max_len = max_len if max_len is not None else int(1e12)
         self.encoder = json.load(open(vocab_file))
         self.decoder = {v: k for k, v in self.encoder.items()}
         self.errors = errors  # how to handle errors in decoding
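A quick usage sketch of the new max_len parameter ('vocab.json' and 'merges.txt' are placeholder paths): passing max_len caps the token-sequence length the tokenizer will accept, while leaving it unset falls back to the effectively unlimited int(1e12) sentinel.

# Usage sketch; the vocab/merges paths are placeholders.
from pytorch_pretrained_bert.tokenization_gpt2 import GPT2Tokenizer

tokenizer = GPT2Tokenizer('vocab.json', 'merges.txt', max_len=1024)
print(tokenizer.max_len)   # 1024

unbounded = GPT2Tokenizer('vocab.json', 'merges.txt')
print(unbounded.max_len)   # int(1e12): effectively no limit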
@@ -188,6 +194,12 @@ class GPT2Tokenizer(object):
         for token in re.findall(self.pat, text):
             token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
             bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
+        if len(bpe_tokens) > self.max_len:
+            raise ValueError(
+                "Token indices sequence length is longer than the specified maximum "
+                " sequence length for this OpenAI GPT-2 model ({} > {}). Running this"
+                " sequence through the model will result in indexing errors".format(len(bpe_tokens), self.max_len)
+            )
         return bpe_tokens

     def decode(self, tokens):
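And a sketch of how the new guard behaves (assuming the enclosing method is encode(), as in OpenAI's reference tokenizer; the diff does not show its name): once the BPE token count exceeds max_len, the call fails loudly instead of returning indices the model's position embeddings cannot handle.

# Hypothetical behaviour sketch; assumes the method shown above is encode().
tokenizer = GPT2Tokenizer('vocab.json', 'merges.txt', max_len=4)
try:
    tokenizer.encode("a sentence that is comfortably longer than four BPE tokens")
except ValueError as exc:
    print(exc)  # reports the actual length vs. the configured max_len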