chenpangpang / transformers · Commits · 78cf7b4a

added code to raise value error for bert tokenizer for covert_tokens_to_indices

Commit 78cf7b4a, authored Dec 18, 2018 by Patrick Lewis
Parent: 786cc412
Changes: 2 changed files, with 53 additions and 13 deletions (+53 -13)

    pytorch_pretrained_bert/tokenization.py    +33 -11
    tests/tokenization_test.py                 +20 -2
pytorch_pretrained_bert/tokenization.py (view file @ 78cf7b4a)
@@ -36,6 +36,15 @@ PRETRAINED_VOCAB_ARCHIVE_MAP = {
     'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt",
     'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt",
 }
+PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
+    'bert-base-uncased': 512,
+    'bert-large-uncased': 512,
+    'bert-base-cased': 512,
+    'bert-large-cased': 512,
+    'bert-base-multilingual-uncased': 512,
+    'bert-base-multilingual-cased': 512,
+    'bert-base-chinese': 512,
+}
 VOCAB_NAME = 'vocab.txt'
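The new PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP records, for each released BERT checkpoint, how many positional embeddings the model has (512 for all of them), i.e. the longest token sequence it can index. A minimal sketch of how such a map can be consulted, with a fallback for names it does not contain; lookup_max_len is a hypothetical helper, not part of this commit:

# Hypothetical helper, not part of the commit: look up the positional-embedding
# limit for a model name, falling back to "effectively unlimited" otherwise.
POSITIONAL_EMBEDDINGS_SIZE = {'bert-base-uncased': 512, 'bert-base-chinese': 512}  # excerpt

def lookup_max_len(pretrained_model_name):
    # Unknown names get int(1e12), the same sentinel the tokenizer uses for "no limit".
    return POSITIONAL_EMBEDDINGS_SIZE.get(pretrained_model_name, int(1e12))

print(lookup_max_len('bert-base-uncased'))   # 512
print(lookup_max_len('my-local-vocab.txt'))  # 1000000000000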
@@ -65,7 +74,8 @@ def whitespace_tokenize(text):
 class BertTokenizer(object):
     """Runs end-to-end tokenization: punctuation splitting + wordpiece"""

-    def __init__(self, vocab_file, do_lower_case=True):
+    def __init__(self, vocab_file, do_lower_case=True, max_len=None):
         if not os.path.isfile(vocab_file):
             raise ValueError(
                 "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained "
@@ -75,6 +85,7 @@ class BertTokenizer(object):
             [(ids, tok) for tok, ids in self.vocab.items()])
         self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
         self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
+        self.max_len = max_len if max_len is not None else int(1e12)

     def tokenize(self, text):
         split_tokens = []
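BertTokenizer.__init__ gains an optional max_len argument; when it is omitted, self.max_len falls back to int(1e12), so existing callers are unaffected. A minimal usage sketch, assuming the package is installed; the vocabulary path and tokens below are made up for illustration:

# Minimal sketch; the toy vocab file is a stand-in, not a real BERT vocabulary.
from pytorch_pretrained_bert.tokenization import BertTokenizer

with open("/tmp/toy_vocab.txt", "w") as f:
    f.write("\n".join(["[UNK]", "[CLS]", "[SEP]", "hello", "world"]) + "\n")

limited = BertTokenizer("/tmp/toy_vocab.txt", max_len=4)  # new keyword argument
default = BertTokenizer("/tmp/toy_vocab.txt")             # max_len defaults to int(1e12)
print(limited.max_len, default.max_len)                   # 4 1000000000000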
@@ -88,6 +99,12 @@ class BertTokenizer(object):
         ids = []
         for token in tokens:
             ids.append(self.vocab[token])
+        if len(ids) > self.max_len:
+            raise ValueError(
+                "Token indices sequence length is longer than the specified maximum "
+                " sequence length for this BERT model ({} > {}). Running this"
+                " sequence through BERT will result in indexing errors".format(len(ids), self.max_len)
+            )
         return ids

     def convert_ids_to_tokens(self, ids):
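This is the check the commit title refers to: convert_tokens_to_ids now raises a ValueError as soon as the id sequence exceeds self.max_len, instead of silently producing indices that would overflow BERT's positional embeddings. A hedged sketch of how calling code might respond, reusing the toy tokenizer from the previous sketch; truncation is just one possible strategy, the commit itself only raises:

# Continues the toy tokenizer above (max_len=4); truncation shown as one option.
tokens = limited.tokenize("hello world hello world hello world")  # 6 tokens > max_len
try:
    ids = limited.convert_tokens_to_ids(tokens)
except ValueError:
    ids = limited.convert_tokens_to_ids(tokens[:limited.max_len])
print(ids)  # 4 ids, e.g. [3, 4, 3, 4] with the toy vocab above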
@@ -126,6 +143,11 @@ class BertTokenizer(object):
         else:
             logger.info("loading vocabulary file {} from cache at {}".format(
                 vocab_file, resolved_vocab_file))
+        if pretrained_model_name in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP:
+            # if we're using a pretrained model, ensure the tokenizer wont index sequences longer
+            # than the number of positional embeddings
+            max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name]
+            kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
         # Instantiate tokenizer.
         tokenizer = cls(resolved_vocab_file, *inputs, **kwargs)
         return tokenizer
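When a tokenizer is loaded through from_pretrained with one of the known model names, max_len is now pre-populated from the positional-embeddings map, and min() ensures a caller-supplied max_len can only tighten the limit, never raise it past 512. A standalone sketch of that capping rule; effective_max_len is a hypothetical helper used only to illustrate the kwargs logic:

# Hypothetical helper mirroring the kwargs logic above; not part of the commit.
def effective_max_len(user_max_len=None, model_limit=512):
    kwargs = {} if user_max_len is None else {'max_len': user_max_len}
    kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), model_limit)
    return kwargs['max_len']

print(effective_max_len())       # 512  (no user value -> model limit)
print(effective_max_len(128))    # 128  (the user may lower the limit)
print(effective_max_len(2048))   # 512  (the user cannot exceed the model limit)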
@@ -193,7 +215,7 @@ class BasicTokenizer(object):
             i += 1

         return ["".join(x) for x in output]

     def _tokenize_chinese_chars(self, text):
         """Adds whitespace around any CJK character."""
         output = []
@@ -218,17 +240,17 @@ class BasicTokenizer(object):
         # space-separated words, so they are not treated specially and handled
         # like the all of the other languages.
         if ((cp >= 0x4E00 and cp <= 0x9FFF) or  #
                 (cp >= 0x3400 and cp <= 0x4DBF) or  #
                 (cp >= 0x20000 and cp <= 0x2A6DF) or  #
                 (cp >= 0x2A700 and cp <= 0x2B73F) or  #
                 (cp >= 0x2B740 and cp <= 0x2B81F) or  #
                 (cp >= 0x2B820 and cp <= 0x2CEAF) or
                 (cp >= 0xF900 and cp <= 0xFAFF) or  #
                 (cp >= 0x2F800 and cp <= 0x2FA1F)):  #
             return True

         return False

     def _clean_text(self, text):
         """Performs invalid character removal and whitespace cleanup on text."""
         output = []
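The last two hunks in tokenization.py appear to leave the logic unchanged (matching old/new line counts, indentation-only differences). For reference, the CJK range test shown above expressed as a standalone function; is_cjk is a hypothetical name chosen for illustration, inside BasicTokenizer it is a private helper operating on a single codepoint:

# Standalone, hypothetical re-statement of the CJK codepoint ranges above.
def is_cjk(char):
    cp = ord(char)
    return ((0x4E00 <= cp <= 0x9FFF) or (0x3400 <= cp <= 0x4DBF) or
            (0x20000 <= cp <= 0x2A6DF) or (0x2A700 <= cp <= 0x2B73F) or
            (0x2B740 <= cp <= 0x2B81F) or (0x2B820 <= cp <= 0x2CEAF) or
            (0xF900 <= cp <= 0xFAFF) or (0x2F800 <= cp <= 0x2FA1F))

print(is_cjk(u"\u535A"))  # True  (CJK ideograph used in the Chinese test below)
print(is_cjk(u"a"))       # False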
tests/tokenization_test.py (view file @ 78cf7b4a)
@@ -44,12 +44,30 @@ class TokenizationTest(unittest.TestCase):
         self.assertListEqual(
             tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])

+    def test_full_tokenizer_raises_error_for_long_sequences(self):
+        vocab_tokens = [
+            "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un",
+            "runn", "##ing", ","
+        ]
+        with open("/tmp/bert_tokenizer_test.txt", "w") as vocab_writer:
+            vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
+            vocab_file = vocab_writer.name
+
+        tokenizer = BertTokenizer(vocab_file, max_len=10)
+        os.remove(vocab_file)
+        tokens = tokenizer.tokenize(u"the cat sat on the mat in the summer time")
+        indices = tokenizer.convert_tokens_to_ids(tokens)
+        self.assertListEqual(indices, [0 for _ in range(10)])
+
+        tokens = tokenizer.tokenize(u"the cat sat on the mat in the summer time .")
+        self.assertRaises(ValueError, tokenizer.convert_tokens_to_ids, tokens)
+
     def test_chinese(self):
         tokenizer = BasicTokenizer()

         self.assertListEqual(
             tokenizer.tokenize(u"ah\u535A\u63A8zz"),
             [u"ah", u"\u535A", u"\u63A8", u"zz"])

     def test_basic_tokenizer_lower(self):
         tokenizer = BasicTokenizer(do_lower_case=True)
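The new test exercises both sides of the limit: with max_len=10, the ten out-of-vocabulary words map to ten [UNK] ids (index 0) without error, while the eleven-token variant (the trailing ".") must raise ValueError. A sketch of running just this test with unittest, assuming the repository root is on sys.path so the tests module resolves:

# Sketch only; import path assumed from the repository layout shown above.
import unittest
from tests.tokenization_test import TokenizationTest

suite = unittest.TestSuite()
suite.addTest(TokenizationTest("test_full_tokenizer_raises_error_for_long_sequences"))
unittest.TextTestRunner(verbosity=2).run(suite)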