chenpangpang / transformers
"...git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "e85f029bd4e4b1bdf3e679893fb6447e4d6b2c79"
Commit 197d74f9, authored Feb 20, 2020 by Joe Davison

Add get_vocab method to PretrainedTokenizer

Parent: ea8eba35
Showing 12 changed files with 62 additions and 0 deletions (+62, −0):
src/transformers/tokenization_albert.py        +5 −0
src/transformers/tokenization_bert.py          +3 −0
src/transformers/tokenization_ctrl.py          +3 −0
src/transformers/tokenization_gpt2.py          +3 −0
src/transformers/tokenization_openai.py        +3 −0
src/transformers/tokenization_t5.py            +5 −0
src/transformers/tokenization_transfo_xl.py    +3 −0
src/transformers/tokenization_utils.py         +4 −0
src/transformers/tokenization_xlm.py           +3 −0
src/transformers/tokenization_xlm_roberta.py   +5 −0
src/transformers/tokenization_xlnet.py         +5 −0
tests/test_tokenization_common.py              +20 −0
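For context, a minimal usage sketch of the new method (the "bert-base-uncased" checkpoint name is only an illustrative assumption, not something this commit depends on):

from transformers import BertTokenizer

# Any pretrained tokenizer works; "bert-base-uncased" is just an example checkpoint.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# get_vocab() returns the full {token: index} mapping, including added tokens.
vocab = tokenizer.get_vocab()
assert len(vocab) == len(tokenizer)
assert vocab["hello"] == tokenizer.convert_tokens_to_ids("hello")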
src/transformers/tokenization_albert.py
@@ -114,6 +114,11 @@ class AlbertTokenizer(PreTrainedTokenizer):
     def vocab_size(self):
         return len(self.sp_model)
 
+    def get_vocab(self):
+        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
+        vocab.update(self.added_tokens_encoder)
+        return vocab
+
     def __getstate__(self):
         state = self.__dict__.copy()
         state["sp_model"] = None
src/transformers/tokenization_bert.py
@@ -195,6 +195,9 @@ class BertTokenizer(PreTrainedTokenizer):
     def vocab_size(self):
         return len(self.vocab)
 
+    def get_vocab(self):
+        return dict(self.vocab, **self.added_tokens_encoder)
+
     def _tokenize(self, text):
         split_tokens = []
         if self.do_basic_tokenize:
src/transformers/tokenization_ctrl.py
@@ -147,6 +147,9 @@ class CTRLTokenizer(PreTrainedTokenizer):
     def vocab_size(self):
         return len(self.encoder)
 
+    def get_vocab(self):
+        return dict(self.encoder, **self.added_tokens_encoder)
+
     def bpe(self, token):
         if token in self.cache:
             return self.cache[token]
src/transformers/tokenization_gpt2.py
@@ -149,6 +149,9 @@ class GPT2Tokenizer(PreTrainedTokenizer):
     def vocab_size(self):
         return len(self.encoder)
 
+    def get_vocab(self):
+        return dict(self.encoder, **self.added_tokens_encoder)
+
     def bpe(self, token):
         if token in self.cache:
             return self.cache[token]
src/transformers/tokenization_openai.py
@@ -125,6 +125,9 @@ class OpenAIGPTTokenizer(PreTrainedTokenizer):
     def vocab_size(self):
         return len(self.encoder)
 
+    def get_vocab(self):
+        return dict(self.encoder, **self.added_tokens_encoder)
+
     def bpe(self, token):
         word = tuple(token[:-1]) + (token[-1] + "</w>",)
         if token in self.cache:
src/transformers/tokenization_t5.py
@@ -119,6 +119,11 @@ class T5Tokenizer(PreTrainedTokenizer):
     def vocab_size(self):
         return self.sp_model.get_piece_size() + self._extra_ids
 
+    def get_vocab(self):
+        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
+        vocab.update(self.added_tokens_encoder)
+        return vocab
+
     def __getstate__(self):
         state = self.__dict__.copy()
         state["sp_model"] = None
src/transformers/tokenization_transfo_xl.py
@@ -273,6 +273,9 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
     def vocab_size(self):
         return len(self.idx2sym)
 
+    def get_vocab(self):
+        return dict(self.sym2idx, **self.added_tokens_encoder)
+
     def _tokenize(self, line, add_eos=False, add_double_eos=False):
         line = line.strip()
         # convert to lower case
src/transformers/tokenization_utils.py
@@ -286,6 +286,10 @@ class PreTrainedTokenizer(object):
         """ Ids of all the additional special tokens in the vocabulary (list of integers). Log an error if used while not having been set. """
         return self.convert_tokens_to_ids(self.additional_special_tokens)
 
+    def get_vocab(self):
+        """ Returns the vocabulary as a dict of {token: index} pairs. `tokenizer.get_vocab()[token]` is equivalent to `tokenizer.convert_tokens_to_ids(token)` when `token` is in the vocab. """
+        raise NotImplementedError()
+
     def __init__(self, max_len=None, **kwargs):
         self._bos_token = None
         self._eos_token = None
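Because the base-class method only raises NotImplementedError, every concrete tokenizer has to override it, as the files above and below do. A minimal sketch of what a hypothetical custom subclass might look like under this contract (ToyTokenizer and its word-level vocab are illustrative assumptions, not part of the library):

from transformers import PreTrainedTokenizer

class ToyTokenizer(PreTrainedTokenizer):
    """Hypothetical whitespace tokenizer, used only to illustrate the get_vocab contract."""

    def __init__(self, vocab, **kwargs):
        super().__init__(**kwargs)
        self.vocab = dict(vocab)                                   # token -> id
        self.ids_to_tokens = {i: t for t, i in self.vocab.items()}

    @property
    def vocab_size(self):
        return len(self.vocab)

    def get_vocab(self):
        # Same pattern as the dict-backed tokenizers in this commit:
        # merge the base vocab with any tokens added after initialization.
        return dict(self.vocab, **self.added_tokens_encoder)

    def _tokenize(self, text):
        return text.split()

    def _convert_token_to_id(self, token):
        return self.vocab.get(token, self.vocab.get("<unk>", 0))

    def _convert_id_to_token(self, index):
        return self.ids_to_tokens.get(index, "<unk>")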
src/transformers/tokenization_xlm.py
@@ -662,6 +662,9 @@ class XLMTokenizer(PreTrainedTokenizer):
     def vocab_size(self):
         return len(self.encoder)
 
+    def get_vocab(self):
+        return dict(self.encoder, **self.added_tokens_encoder)
+
     def bpe(self, token):
         word = tuple(token[:-1]) + (token[-1] + "</w>",)
         if token in self.cache:
src/transformers/tokenization_xlm_roberta.py
@@ -190,6 +190,11 @@ class XLMRobertaTokenizer(PreTrainedTokenizer):
     def vocab_size(self):
         return len(self.sp_model) + len(self.fairseq_tokens_to_ids)
 
+    def get_vocab(self):
+        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
+        vocab.update(self.added_tokens_encoder)
+        return vocab
+
     def _tokenize(self, text):
         return self.sp_model.EncodeAsPieces(text)
src/transformers/tokenization_xlnet.py
@@ -114,6 +114,11 @@ class XLNetTokenizer(PreTrainedTokenizer):
     def vocab_size(self):
         return len(self.sp_model)
 
+    def get_vocab(self):
+        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
+        vocab.update(self.added_tokens_encoder)
+        return vocab
+
     def __getstate__(self):
         state = self.__dict__.copy()
         state["sp_model"] = None
tests/test_tokenization_common.py
@@ -542,3 +542,23 @@ class TokenizerTesterMixin:
         print(new_tokenizer.init_kwargs)
         assert tokenizer.init_kwargs["random_argument"] is True
         assert new_tokenizer.init_kwargs["random_argument"] is False
+
+    def test_get_vocab(self):
+        tokenizer = self.get_tokenizer()
+        vocab = tokenizer.get_vocab()
+
+        self.assertIsInstance(vocab, dict)
+        self.assertEqual(len(vocab), len(tokenizer))
+
+        for word, ind in vocab.items():
+            self.assertEqual(tokenizer.convert_tokens_to_ids(word), ind)
+            self.assertEqual(tokenizer.convert_ids_to_tokens(ind), word)
+
+        tokenizer.add_tokens(["asdfasdfasdfasdf"])
+        vocab = tokenizer.get_vocab()
+        self.assertIsInstance(vocab, dict)
+        self.assertEqual(len(vocab), len(tokenizer))
+
+        for word, ind in vocab.items():
+            self.assertEqual(tokenizer.convert_tokens_to_ids(word), ind)
+            self.assertEqual(tokenizer.convert_ids_to_tokens(ind), word)
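The test lives in TokenizerTesterMixin, so it runs once for every tokenizer-specific test class that mixes it in. Its second half exercises added tokens; the same behaviour can be checked interactively (the "gpt2" checkpoint name and the added token below are illustrative assumptions only):

from transformers import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

# Tokens added after loading show up in get_vocab() too, so the round trip still holds.
tokenizer.add_tokens(["<my_new_token>"])
vocab = tokenizer.get_vocab()
assert vocab["<my_new_token>"] == tokenizer.convert_tokens_to_ids("<my_new_token>")
assert len(vocab) == len(tokenizer)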