chenpangpang / transformers · Commits · 2ae3be54

Unverified commit 2ae3be54, authored Jan 18, 2022 by Suraj Patil, committed by GitHub on Jan 18, 2022.

[MBartTokenizer] remove dep on xlm-roberta tokenizer (#15201)
parent 84c60a7b

Showing 2 changed files with 219 additions and 70 deletions (+219 −70):

src/transformers/models/mbart/tokenization_mbart.py        +143 −37
src/transformers/models/mbart/tokenization_mbart_fast.py    +76 −33
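The commit inlines the SentencePiece-based implementation into `MBartTokenizer` and `MBartTokenizerFast` instead of inheriting it from the XLM-RoBERTa tokenizers. The public API is meant to stay the same; a minimal sketch of typical usage (checkpoint name taken from the diff, example sentence arbitrary):

```python
from transformers import MBartTokenizer

# Same call pattern before and after this commit; only the class hierarchy changes.
tokenizer = MBartTokenizer.from_pretrained("facebook/mbart-large-cc25", src_lang="en_XX", tgt_lang="ro_RO")
batch = tokenizer("UN Chief Says There Is No Military Solution in Syria", return_tensors="pt")
print(batch["input_ids"])
```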
src/transformers/models/mbart/tokenization_mbart.py  (view file @ 2ae3be54)

@@ -13,16 +13,20 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import os
 from contextlib import contextmanager
-from typing import List, Optional
+from shutil import copyfile
+from typing import Any, Dict, List, Optional, Tuple
 
-from ...tokenization_utils import BatchEncoding
+import sentencepiece as spm
+
+from ...tokenization_utils import AddedToken, BatchEncoding, PreTrainedTokenizer
 from ...utils import logging
-from ..xlm_roberta.tokenization_xlm_roberta import XLMRobertaTokenizer
 
 
 logger = logging.get_logger(__name__)
 
+SPIECE_UNDERLINE = "▁"
+
 VOCAB_FILES_NAMES = {"vocab_file": "sentencepiece.bpe.model"}
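With the XLM-RoBERTa import gone, the slow tokenizer talks to sentencepiece directly and forwards the new `sp_model_kwargs` argument to the `SentencePieceProcessor`. A hedged sketch of what that allows; the sampling options are standard sentencepiece parameters, not something this diff introduces:

```python
from transformers import MBartTokenizer

# sp_model_kwargs is passed straight to sentencepiece.SentencePieceProcessor, so
# subword-regularization sampling can be enabled on the slow tokenizer.
tokenizer = MBartTokenizer.from_pretrained(
    "facebook/mbart-large-cc25",
    sp_model_kwargs={"enable_sampling": True, "nbest_size": -1, "alpha": 0.1},
)
print(tokenizer.tokenize("Compression"))  # segmentation may differ run to run when sampling is on
```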
@@ -38,41 +42,17 @@ PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
"facebook/mbart-large-cc25"
:
1024
,
"facebook/mbart-large-cc25"
:
1024
,
}
}
FAIRSEQ_LANGUAGE_CODES
=
[
# fmt: off
"ar_AR"
,
FAIRSEQ_LANGUAGE_CODES
=
[
"ar_AR"
,
"cs_CZ"
,
"de_DE"
,
"en_XX"
,
"es_XX"
,
"et_EE"
,
"fi_FI"
,
"fr_XX"
,
"gu_IN"
,
"hi_IN"
,
"it_IT"
,
"ja_XX"
,
"kk_KZ"
,
"ko_KR"
,
"lt_LT"
,
"lv_LV"
,
"my_MM"
,
"ne_NP"
,
"nl_XX"
,
"ro_RO"
,
"ru_RU"
,
"si_LK"
,
"tr_TR"
,
"vi_VN"
,
"zh_CN"
]
"cs_CZ"
,
# fmt: on
"de_DE"
,
"en_XX"
,
"es_XX"
,
class
MBartTokenizer
(
PreTrainedTokenizer
):
"et_EE"
,
"fi_FI"
,
"fr_XX"
,
"gu_IN"
,
"hi_IN"
,
"it_IT"
,
"ja_XX"
,
"kk_KZ"
,
"ko_KR"
,
"lt_LT"
,
"lv_LV"
,
"my_MM"
,
"ne_NP"
,
"nl_XX"
,
"ro_RO"
,
"ru_RU"
,
"si_LK"
,
"tr_TR"
,
"vi_VN"
,
"zh_CN"
,
]
class
MBartTokenizer
(
XLMRobertaTokenizer
):
"""
"""
Construct an MBART tokenizer.
Construct an MBART tokenizer.
[`MBartTokenizer`] is a subclass of [`XLMRobertaTokenizer`]. Refer to superclass [`XLMRoberta
Tokenizer`]
for usage
Adapted from [`RobertaTokenizer`] and [`XLNet
Tokenizer`]
. Based on
examples and documentation concerning the initialization parameters and other methods
.
[SentencePiece](https://github.com/google/sentencepiece)
.
The tokenization method is `<tokens> <eos> <language code>` for source language documents, and ``<language code>
The tokenization method is `<tokens> <eos> <language code>` for source language documents, and ``<language code>
<tokens> <eos>``` for target language documents.
<tokens> <eos>``` for target language documents.
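The docstring keeps the description of the mBART format: source text is encoded as `<tokens> <eos> <language code>` and target text as `<language code> <tokens> <eos>`. A quick sketch of what that looks like in practice (the exact subword pieces are illustrative, not taken from the diff):

```python
from transformers import MBartTokenizer

tokenizer = MBartTokenizer.from_pretrained("facebook/mbart-large-cc25", src_lang="en_XX", tgt_lang="ro_RO")
ids = tokenizer("Hello world")["input_ids"]
print(tokenizer.convert_ids_to_tokens(ids))
# e.g. ['▁Hello', '▁world', '</s>', 'en_XX']   # <tokens> <eos> <language code>
```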
@@ -94,22 +74,66 @@ class MBartTokenizer(XLMRobertaTokenizer):
     vocab_files_names = VOCAB_FILES_NAMES
     max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
     pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
+    model_input_names = ["input_ids", "attention_mask"]
 
     prefix_tokens: List[int] = []
     suffix_tokens: List[int] = []
 
     def __init__(
         self,
-        *args,
+        vocab_file,
+        bos_token="<s>",
+        eos_token="</s>",
+        sep_token="</s>",
+        cls_token="<s>",
+        unk_token="<unk>",
+        pad_token="<pad>",
+        mask_token="<mask>",
         tokenizer_file=None,
         src_lang=None,
         tgt_lang=None,
+        sp_model_kwargs: Optional[Dict[str, Any]] = None,
         additional_special_tokens=None,
         **kwargs
     ):
+        # Mask token behave like a normal word, i.e. include the space before it
+        mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token
+
+        self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
+
         super().__init__(
-            *args,
+            bos_token=bos_token,
+            eos_token=eos_token,
+            unk_token=unk_token,
+            sep_token=sep_token,
+            cls_token=cls_token,
+            pad_token=pad_token,
+            mask_token=mask_token,
             tokenizer_file=tokenizer_file,
             src_lang=src_lang,
             tgt_lang=tgt_lang,
             additional_special_tokens=additional_special_tokens,
+            sp_model_kwargs=self.sp_model_kwargs,
             **kwargs,
         )
+
+        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
+        self.sp_model.Load(str(vocab_file))
+        self.vocab_file = vocab_file
+
+        # Original fairseq vocab and spm vocab must be "aligned":
+        # Vocab    |    0    |    1    |   2    |    3    |  4  |  5  |  6  |   7   |   8   |  9
+        # -------- | ------- | ------- | ------ | ------- | --- | --- | --- | ----- | ----- | ----
+        # fairseq  | '<s>'   | '<pad>' | '</s>' | '<unk>' | ',' | '.' | '▁' | 's'   | '▁de' | '-'
+        # spm      | '<unk>' | '<s>'   | '</s>' | ','     | '.' | '▁' | 's' | '▁de' | '-'   | '▁a'
+
+        # Mimic fairseq token-to-id alignment for the first 4 token
+        self.fairseq_tokens_to_ids = {"<s>": 0, "<pad>": 1, "</s>": 2, "<unk>": 3}
+
+        # The first "real" token "," has position 4 in the original fairseq vocab and position 3 in the spm vocab
+        self.fairseq_offset = 1
+
         self.sp_model_size = len(self.sp_model)
         self.lang_code_to_id = {
             code: self.sp_model_size + i + self.fairseq_offset for i, code in enumerate(FAIRSEQ_LANGUAGE_CODES)
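The constructor rebuilds fairseq's id layout on top of the raw SentencePiece vocabulary: ids 0-3 are reserved for `<s>`, `<pad>`, `</s>`, `<unk>`, every SentencePiece id is shifted by `fairseq_offset = 1`, and the 25 language codes are appended after the shifted vocabulary. A toy illustration of that arithmetic (the vocabulary size is an assumed number, not read from a checkpoint):

```python
# Assumed, illustrative numbers -- only the arithmetic mirrors the diff.
sp_model_size = 250000      # stands in for len(self.sp_model)
fairseq_offset = 1          # spm ids shift up by one so the fairseq specials fit at 0..3

lang_code_to_id = {
    code: sp_model_size + i + fairseq_offset
    for i, code in enumerate(["ar_AR", "cs_CZ", "de_DE"])
}
print(lang_code_to_id)      # {'ar_AR': 250001, 'cs_CZ': 250002, 'de_DE': 250003}
```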
@@ -132,6 +156,22 @@ class MBartTokenizer(XLMRobertaTokenizer):
         self.tgt_lang = tgt_lang
         self.set_src_lang_special_tokens(self._src_lang)
 
+    def __getstate__(self):
+        state = self.__dict__.copy()
+        state["sp_model"] = None
+        state["sp_model_proto"] = self.sp_model.serialized_model_proto()
+        return state
+
+    def __setstate__(self, d):
+        self.__dict__ = d
+
+        # for backward compatibility
+        if not hasattr(self, "sp_model_kwargs"):
+            self.sp_model_kwargs = {}
+
+        self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
+        self.sp_model.LoadFromSerializedProto(self.sp_model_proto)
+
     @property
     def vocab_size(self):
         return len(self.sp_model) + len(self.lang_code_to_id) + self.fairseq_offset + 1  # Plus 1 for the mask token
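The `__getstate__`/`__setstate__` pair is needed because the underlying `SentencePieceProcessor` is not picklable: on pickling it is dropped and replaced by its serialized model proto, and on unpickling the processor is rebuilt from that proto. A minimal sketch (checkpoint name reused from the diff):

```python
import pickle

from transformers import MBartTokenizer

tokenizer = MBartTokenizer.from_pretrained("facebook/mbart-large-cc25")
restored = pickle.loads(pickle.dumps(tokenizer))   # sp_model is rebuilt inside __setstate__
assert restored.vocab_size == tokenizer.vocab_size
```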
@@ -202,6 +242,31 @@ class MBartTokenizer(XLMRobertaTokenizer):
         # We don't expect to process pairs, but leave the pair logic for API consistency
         return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens
 
+    def create_token_type_ids_from_sequences(
+        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+    ) -> List[int]:
+        """
+        Create a mask from the two sequences passed to be used in a sequence-pair classification task. mBART does not
+        make use of token type ids, therefore a list of zeros is returned.
+
+        Args:
+            token_ids_0 (`List[int]`):
+                List of IDs.
+            token_ids_1 (`List[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+
+        Returns:
+            `List[int]`: List of zeros.
+        """
+
+        sep = [self.sep_token_id]
+        cls = [self.cls_token_id]
+
+        if token_ids_1 is None:
+            return len(cls + token_ids_0 + sep) * [0]
+        return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]
+
     def _build_translation_inputs(
         self, raw_inputs, return_tensors: str, src_lang: Optional[str], tgt_lang: Optional[str], **extra_kwargs
     ):
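As the docstring states, mBART has no token type ids, so the method only returns zeros whose count matches the special-token layout. A quick check (the input ids are placeholders):

```python
from transformers import MBartTokenizer

tokenizer = MBartTokenizer.from_pretrained("facebook/mbart-large-cc25")
print(tokenizer.create_token_type_ids_from_sequences([11, 12, 13]))
# [0, 0, 0, 0, 0]  -> cls + three tokens + sep, all zeros
```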
@@ -214,6 +279,47 @@ class MBartTokenizer(XLMRobertaTokenizer):
         inputs["forced_bos_token_id"] = tgt_lang_id
         return inputs
 
+    def get_vocab(self):
+        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
+        vocab.update(self.added_tokens_encoder)
+        return vocab
+
+    def _tokenize(self, text: str) -> List[str]:
+        return self.sp_model.encode(text, out_type=str)
+
+    def _convert_token_to_id(self, token):
+        """Converts a token (str) in an id using the vocab."""
+        if token in self.fairseq_tokens_to_ids:
+            return self.fairseq_tokens_to_ids[token]
+        spm_id = self.sp_model.PieceToId(token)
+
+        # Need to return unknown token if the SP model returned 0
+        return spm_id + self.fairseq_offset if spm_id else self.unk_token_id
+
+    def _convert_id_to_token(self, index):
+        """Converts an index (integer) in a token (str) using the vocab."""
+        if index in self.fairseq_ids_to_tokens:
+            return self.fairseq_ids_to_tokens[index]
+        return self.sp_model.IdToPiece(index - self.fairseq_offset)
+
+    def convert_tokens_to_string(self, tokens):
+        """Converts a sequence of tokens (strings for sub-words) in a single string."""
+        out_string = "".join(tokens).replace(SPIECE_UNDERLINE, " ").strip()
+        return out_string
+
+    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
+        if not os.path.isdir(save_directory):
+            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
+            return
+        out_vocab_file = os.path.join(
+            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
+        )
+
+        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
+            copyfile(self.vocab_file, out_vocab_file)
+
+        return (out_vocab_file,)
+
     def prepare_seq2seq_batch(
         self,
         src_texts: List[str],
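The conversion methods reproduce fairseq's special-token handling on top of SentencePiece: known specials resolve through `fairseq_tokens_to_ids`, everything else is the SentencePiece id plus the offset, and `convert_tokens_to_string` turns the `▁` markers back into spaces. A short sketch of the observable behaviour (the piece strings are illustrative):

```python
from transformers import MBartTokenizer

tokenizer = MBartTokenizer.from_pretrained("facebook/mbart-large-cc25")

print(tokenizer.convert_tokens_to_ids("<unk>"))                 # 3 -- the fairseq id, not the raw SentencePiece id 0
print(tokenizer.convert_tokens_to_string(["▁UN", "▁Chief"]))    # "UN Chief" -- SPIECE_UNDERLINE becomes a space
```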
src/transformers/models/mbart/tokenization_mbart_fast.py  (view file @ 2ae3be54)

@@ -13,15 +13,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import os
 from contextlib import contextmanager
-from typing import List, Optional
+from shutil import copyfile
+from typing import List, Optional, Tuple
 
 from tokenizers import processors
 
 from ...file_utils import is_sentencepiece_available
-from ...tokenization_utils import BatchEncoding
+from ...tokenization_utils import AddedToken, BatchEncoding
+from ...tokenization_utils_fast import PreTrainedTokenizerFast
 from ...utils import logging
-from ..xlm_roberta.tokenization_xlm_roberta_fast import XLMRobertaTokenizerFast
 
 
 if is_sentencepiece_available():
@@ -51,36 +53,12 @@ PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
"facebook/mbart-large-cc25"
:
1024
,
"facebook/mbart-large-cc25"
:
1024
,
}
}
FAIRSEQ_LANGUAGE_CODES
=
[
# fmt: off
"ar_AR"
,
FAIRSEQ_LANGUAGE_CODES
=
[
"ar_AR"
,
"cs_CZ"
,
"de_DE"
,
"en_XX"
,
"es_XX"
,
"et_EE"
,
"fi_FI"
,
"fr_XX"
,
"gu_IN"
,
"hi_IN"
,
"it_IT"
,
"ja_XX"
,
"kk_KZ"
,
"ko_KR"
,
"lt_LT"
,
"lv_LV"
,
"my_MM"
,
"ne_NP"
,
"nl_XX"
,
"ro_RO"
,
"ru_RU"
,
"si_LK"
,
"tr_TR"
,
"vi_VN"
,
"zh_CN"
]
"cs_CZ"
,
# fmt: on
"de_DE"
,
"en_XX"
,
"es_XX"
,
class
MBartTokenizerFast
(
PreTrainedTokenizerFast
):
"et_EE"
,
"fi_FI"
,
"fr_XX"
,
"gu_IN"
,
"hi_IN"
,
"it_IT"
,
"ja_XX"
,
"kk_KZ"
,
"ko_KR"
,
"lt_LT"
,
"lv_LV"
,
"my_MM"
,
"ne_NP"
,
"nl_XX"
,
"ro_RO"
,
"ru_RU"
,
"si_LK"
,
"tr_TR"
,
"vi_VN"
,
"zh_CN"
,
]
class
MBartTokenizerFast
(
XLMRobertaTokenizerFast
):
"""
"""
Construct a "fast" MBART tokenizer (backed by HuggingFace's *tokenizers* library). Based on
Construct a "fast" MBART tokenizer (backed by HuggingFace's *tokenizers* library). Based on
[BPE](https://huggingface.co/docs/tokenizers/python/latest/components.html?highlight=BPE#models).
[BPE](https://huggingface.co/docs/tokenizers/python/latest/components.html?highlight=BPE#models).
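Because both classes now derive from the generic base tokenizers rather than from the XLM-RoBERTa ones, code that relied on the old inheritance relationship changes behaviour. A hedged illustration of that consequence:

```python
from transformers import MBartTokenizerFast, XLMRobertaTokenizerFast

fast = MBartTokenizerFast.from_pretrained("facebook/mbart-large-cc25")
print(isinstance(fast, XLMRobertaTokenizerFast))  # False after this commit; True before it
```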
@@ -111,6 +89,7 @@ class MBartTokenizerFast(XLMRobertaTokenizerFast):
     vocab_files_names = VOCAB_FILES_NAMES
     max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
     pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
+    model_input_names = ["input_ids", "attention_mask"]
     slow_tokenizer_class = MBartTokenizer
 
     prefix_tokens: List[int] = []
@@ -120,20 +99,40 @@ class MBartTokenizerFast(XLMRobertaTokenizerFast):
         self,
         vocab_file=None,
         tokenizer_file=None,
+        bos_token="<s>",
+        eos_token="</s>",
+        sep_token="</s>",
+        cls_token="<s>",
+        unk_token="<unk>",
+        pad_token="<pad>",
+        mask_token="<mask>",
         src_lang=None,
         tgt_lang=None,
         additional_special_tokens=None,
         **kwargs
     ):
+        # Mask token behave like a normal word, i.e. include the space before it
+        mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token
+
         super().__init__(
             vocab_file=vocab_file,
             tokenizer_file=tokenizer_file,
+            bos_token=bos_token,
+            eos_token=eos_token,
+            sep_token=sep_token,
+            cls_token=cls_token,
+            unk_token=unk_token,
+            pad_token=pad_token,
+            mask_token=mask_token,
             src_lang=src_lang,
             tgt_lang=tgt_lang,
             additional_special_tokens=additional_special_tokens,
             **kwargs,
         )
 
+        self.vocab_file = vocab_file
+        self.can_save_slow_tokenizer = False if not self.vocab_file else True
+
         _additional_special_tokens = FAIRSEQ_LANGUAGE_CODES.copy()
 
         if additional_special_tokens is not None:
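The fast tokenizer now records whether it still has access to the SentencePiece vocab file; that flag gates the new `save_vocabulary` further down. A minimal sketch:

```python
from transformers import MBartTokenizerFast

fast = MBartTokenizerFast.from_pretrained("facebook/mbart-large-cc25")
print(fast.can_save_slow_tokenizer)  # True when a sentencepiece.bpe.model file was available at load time
```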
@@ -190,6 +189,31 @@ class MBartTokenizerFast(XLMRobertaTokenizerFast):
         # We don't expect to process pairs, but leave the pair logic for API consistency
         return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens
 
+    def create_token_type_ids_from_sequences(
+        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+    ) -> List[int]:
+        """
+        Create a mask from the two sequences passed to be used in a sequence-pair classification task. mBART does not
+        make use of token type ids, therefore a list of zeros is returned.
+
+        Args:
+            token_ids_0 (`List[int]`):
+                List of IDs.
+            token_ids_1 (`List[int]`, *optional*):
+                Optional second list of IDs for sequence pairs.
+
+        Returns:
+            `List[int]`: List of zeros.
+        """
+
+        sep = [self.sep_token_id]
+        cls = [self.cls_token_id]
+
+        if token_ids_1 is None:
+            return len(cls + token_ids_0 + sep) * [0]
+        return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]
+
     def _build_translation_inputs(
         self, raw_inputs, return_tensors: str, src_lang: Optional[str], tgt_lang: Optional[str], **extra_kwargs
     ):
@@ -253,3 +277,22 @@ class MBartTokenizerFast(XLMRobertaTokenizerFast):
             pair=prefix_tokens_str + ["$A", "$B"] + suffix_tokens_str,
             special_tokens=list(zip(prefix_tokens_str + suffix_tokens_str, self.prefix_tokens + self.suffix_tokens)),
         )
+
+    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
+        if not self.can_save_slow_tokenizer:
+            raise ValueError(
+                "Your fast tokenizer does not have the necessary information to save the vocabulary for a slow "
+                "tokenizer."
+            )
+
+        if not os.path.isdir(save_directory):
+            logger.error(f"Vocabulary path ({save_directory}) should be a directory.")
+            return
+        out_vocab_file = os.path.join(
+            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
+        )
+
+        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
+            copyfile(self.vocab_file, out_vocab_file)
+
+        return (out_vocab_file,)
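With `save_vocabulary` implemented on both classes, a fast tokenizer loaded with a SentencePiece file can be saved and then reloaded as the slow tokenizer. A sketch of that round trip (the directory name is arbitrary):

```python
from transformers import MBartTokenizer, MBartTokenizerFast

fast = MBartTokenizerFast.from_pretrained("facebook/mbart-large-cc25")
fast.save_pretrained("./mbart-tokenizer")        # copies sentencepiece.bpe.model via save_vocabulary
slow = MBartTokenizer.from_pretrained("./mbart-tokenizer")
```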