Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
2ae3be54
Unverified
Commit
2ae3be54
authored
Jan 18, 2022
by
Suraj Patil
Committed by
GitHub
Jan 18, 2022
Browse files
[MBartTokenizer] remove dep on xlm-roberta tokenizer (#15201)
parent
84c60a7b
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
219 additions
and
70 deletions
+219
-70
src/transformers/models/mbart/tokenization_mbart.py
src/transformers/models/mbart/tokenization_mbart.py
+143
-37
src/transformers/models/mbart/tokenization_mbart_fast.py
src/transformers/models/mbart/tokenization_mbart_fast.py
+76
-33
No files found.
src/transformers/models/mbart/tokenization_mbart.py
View file @
2ae3be54
...
@@ -13,16 +13,20 @@
...
@@ -13,16 +13,20 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
import
os
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
typing
import
List
,
Optional
from
shutil
import
copyfile
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
from
...tokenization_utils
import
BatchEncoding
import
sentencepiece
as
spm
from
...tokenization_utils
import
AddedToken
,
BatchEncoding
,
PreTrainedTokenizer
from
...utils
import
logging
from
...utils
import
logging
from
..xlm_roberta.tokenization_xlm_roberta
import
XLMRobertaTokenizer
logger
=
logging
.
get_logger
(
__name__
)
logger
=
logging
.
get_logger
(
__name__
)
SPIECE_UNDERLINE
=
"▁"
VOCAB_FILES_NAMES
=
{
"vocab_file"
:
"sentencepiece.bpe.model"
}
VOCAB_FILES_NAMES
=
{
"vocab_file"
:
"sentencepiece.bpe.model"
}
...
@@ -38,41 +42,17 @@ PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
...
@@ -38,41 +42,17 @@ PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
"facebook/mbart-large-cc25"
:
1024
,
"facebook/mbart-large-cc25"
:
1024
,
}
}
FAIRSEQ_LANGUAGE_CODES
=
[
# fmt: off
"ar_AR"
,
FAIRSEQ_LANGUAGE_CODES
=
[
"ar_AR"
,
"cs_CZ"
,
"de_DE"
,
"en_XX"
,
"es_XX"
,
"et_EE"
,
"fi_FI"
,
"fr_XX"
,
"gu_IN"
,
"hi_IN"
,
"it_IT"
,
"ja_XX"
,
"kk_KZ"
,
"ko_KR"
,
"lt_LT"
,
"lv_LV"
,
"my_MM"
,
"ne_NP"
,
"nl_XX"
,
"ro_RO"
,
"ru_RU"
,
"si_LK"
,
"tr_TR"
,
"vi_VN"
,
"zh_CN"
]
"cs_CZ"
,
# fmt: on
"de_DE"
,
"en_XX"
,
"es_XX"
,
class
MBartTokenizer
(
PreTrainedTokenizer
):
"et_EE"
,
"fi_FI"
,
"fr_XX"
,
"gu_IN"
,
"hi_IN"
,
"it_IT"
,
"ja_XX"
,
"kk_KZ"
,
"ko_KR"
,
"lt_LT"
,
"lv_LV"
,
"my_MM"
,
"ne_NP"
,
"nl_XX"
,
"ro_RO"
,
"ru_RU"
,
"si_LK"
,
"tr_TR"
,
"vi_VN"
,
"zh_CN"
,
]
class
MBartTokenizer
(
XLMRobertaTokenizer
):
"""
"""
Construct an MBART tokenizer.
Construct an MBART tokenizer.
[`MBartTokenizer`] is a subclass of [`XLMRobertaTokenizer`]. Refer to superclass [`XLMRoberta
Tokenizer`]
for usage
Adapted from [`RobertaTokenizer`] and [`XLNet
Tokenizer`]
. Based on
examples and documentation concerning the initialization parameters and other methods
.
[SentencePiece](https://github.com/google/sentencepiece)
.
The tokenization method is `<tokens> <eos> <language code>` for source language documents, and ``<language code>
The tokenization method is `<tokens> <eos> <language code>` for source language documents, and ``<language code>
<tokens> <eos>``` for target language documents.
<tokens> <eos>``` for target language documents.
...
@@ -94,22 +74,66 @@ class MBartTokenizer(XLMRobertaTokenizer):
...
@@ -94,22 +74,66 @@ class MBartTokenizer(XLMRobertaTokenizer):
vocab_files_names
=
VOCAB_FILES_NAMES
vocab_files_names
=
VOCAB_FILES_NAMES
max_model_input_sizes
=
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
max_model_input_sizes
=
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
pretrained_vocab_files_map
=
PRETRAINED_VOCAB_FILES_MAP
pretrained_vocab_files_map
=
PRETRAINED_VOCAB_FILES_MAP
model_input_names
=
[
"input_ids"
,
"attention_mask"
]
prefix_tokens
:
List
[
int
]
=
[]
prefix_tokens
:
List
[
int
]
=
[]
suffix_tokens
:
List
[
int
]
=
[]
suffix_tokens
:
List
[
int
]
=
[]
def
__init__
(
def
__init__
(
self
,
*
args
,
tokenizer_file
=
None
,
src_lang
=
None
,
tgt_lang
=
None
,
additional_special_tokens
=
None
,
**
kwargs
self
,
vocab_file
,
bos_token
=
"<s>"
,
eos_token
=
"</s>"
,
sep_token
=
"</s>"
,
cls_token
=
"<s>"
,
unk_token
=
"<unk>"
,
pad_token
=
"<pad>"
,
mask_token
=
"<mask>"
,
tokenizer_file
=
None
,
src_lang
=
None
,
tgt_lang
=
None
,
sp_model_kwargs
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
additional_special_tokens
=
None
,
**
kwargs
):
):
# Mask token behave like a normal word, i.e. include the space before it
mask_token
=
AddedToken
(
mask_token
,
lstrip
=
True
,
rstrip
=
False
)
if
isinstance
(
mask_token
,
str
)
else
mask_token
self
.
sp_model_kwargs
=
{}
if
sp_model_kwargs
is
None
else
sp_model_kwargs
super
().
__init__
(
super
().
__init__
(
*
args
,
bos_token
=
bos_token
,
eos_token
=
eos_token
,
unk_token
=
unk_token
,
sep_token
=
sep_token
,
cls_token
=
cls_token
,
pad_token
=
pad_token
,
mask_token
=
mask_token
,
tokenizer_file
=
tokenizer_file
,
tokenizer_file
=
tokenizer_file
,
src_lang
=
src_lang
,
src_lang
=
src_lang
,
tgt_lang
=
tgt_lang
,
tgt_lang
=
tgt_lang
,
additional_special_tokens
=
additional_special_tokens
,
additional_special_tokens
=
additional_special_tokens
,
sp_model_kwargs
=
self
.
sp_model_kwargs
,
**
kwargs
,
**
kwargs
,
)
)
self
.
sp_model
=
spm
.
SentencePieceProcessor
(
**
self
.
sp_model_kwargs
)
self
.
sp_model
.
Load
(
str
(
vocab_file
))
self
.
vocab_file
=
vocab_file
# Original fairseq vocab and spm vocab must be "aligned":
# Vocab | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9
# -------- | ------- | ------- | ------ | ------- | --- | --- | --- | ----- | ----- | ----
# fairseq | '<s>' | '<pad>' | '</s>' | '<unk>' | ',' | '.' | '▁' | 's' | '▁de' | '-'
# spm | '<unk>' | '<s>' | '</s>' | ',' | '.' | '▁' | 's' | '▁de' | '-' | '▁a'
# Mimic fairseq token-to-id alignment for the first 4 token
self
.
fairseq_tokens_to_ids
=
{
"<s>"
:
0
,
"<pad>"
:
1
,
"</s>"
:
2
,
"<unk>"
:
3
}
# The first "real" token "," has position 4 in the original fairseq vocab and position 3 in the spm vocab
self
.
fairseq_offset
=
1
self
.
sp_model_size
=
len
(
self
.
sp_model
)
self
.
sp_model_size
=
len
(
self
.
sp_model
)
self
.
lang_code_to_id
=
{
self
.
lang_code_to_id
=
{
code
:
self
.
sp_model_size
+
i
+
self
.
fairseq_offset
for
i
,
code
in
enumerate
(
FAIRSEQ_LANGUAGE_CODES
)
code
:
self
.
sp_model_size
+
i
+
self
.
fairseq_offset
for
i
,
code
in
enumerate
(
FAIRSEQ_LANGUAGE_CODES
)
...
@@ -132,6 +156,22 @@ class MBartTokenizer(XLMRobertaTokenizer):
...
@@ -132,6 +156,22 @@ class MBartTokenizer(XLMRobertaTokenizer):
self
.
tgt_lang
=
tgt_lang
self
.
tgt_lang
=
tgt_lang
self
.
set_src_lang_special_tokens
(
self
.
_src_lang
)
self
.
set_src_lang_special_tokens
(
self
.
_src_lang
)
def
__getstate__
(
self
):
state
=
self
.
__dict__
.
copy
()
state
[
"sp_model"
]
=
None
state
[
"sp_model_proto"
]
=
self
.
sp_model
.
serialized_model_proto
()
return
state
def
__setstate__
(
self
,
d
):
self
.
__dict__
=
d
# for backward compatibility
if
not
hasattr
(
self
,
"sp_model_kwargs"
):
self
.
sp_model_kwargs
=
{}
self
.
sp_model
=
spm
.
SentencePieceProcessor
(
**
self
.
sp_model_kwargs
)
self
.
sp_model
.
LoadFromSerializedProto
(
self
.
sp_model_proto
)
@
property
@
property
def
vocab_size
(
self
):
def
vocab_size
(
self
):
return
len
(
self
.
sp_model
)
+
len
(
self
.
lang_code_to_id
)
+
self
.
fairseq_offset
+
1
# Plus 1 for the mask token
return
len
(
self
.
sp_model
)
+
len
(
self
.
lang_code_to_id
)
+
self
.
fairseq_offset
+
1
# Plus 1 for the mask token
...
@@ -202,6 +242,31 @@ class MBartTokenizer(XLMRobertaTokenizer):
...
@@ -202,6 +242,31 @@ class MBartTokenizer(XLMRobertaTokenizer):
# We don't expect to process pairs, but leave the pair logic for API consistency
# We don't expect to process pairs, but leave the pair logic for API consistency
return
self
.
prefix_tokens
+
token_ids_0
+
token_ids_1
+
self
.
suffix_tokens
return
self
.
prefix_tokens
+
token_ids_0
+
token_ids_1
+
self
.
suffix_tokens
def
create_token_type_ids_from_sequences
(
self
,
token_ids_0
:
List
[
int
],
token_ids_1
:
Optional
[
List
[
int
]]
=
None
)
->
List
[
int
]:
"""
Create a mask from the two sequences passed to be used in a sequence-pair classification task. mBART does not
make use of token type ids, therefore a list of zeros is returned.
Args:
token_ids_0 (`List[int]`):
List of IDs.
token_ids_1 (`List[int]`, *optional*):
Optional second list of IDs for sequence pairs.
Returns:
`List[int]`: List of zeros.
"""
sep
=
[
self
.
sep_token_id
]
cls
=
[
self
.
cls_token_id
]
if
token_ids_1
is
None
:
return
len
(
cls
+
token_ids_0
+
sep
)
*
[
0
]
return
len
(
cls
+
token_ids_0
+
sep
+
sep
+
token_ids_1
+
sep
)
*
[
0
]
def
_build_translation_inputs
(
def
_build_translation_inputs
(
self
,
raw_inputs
,
return_tensors
:
str
,
src_lang
:
Optional
[
str
],
tgt_lang
:
Optional
[
str
],
**
extra_kwargs
self
,
raw_inputs
,
return_tensors
:
str
,
src_lang
:
Optional
[
str
],
tgt_lang
:
Optional
[
str
],
**
extra_kwargs
):
):
...
@@ -214,6 +279,47 @@ class MBartTokenizer(XLMRobertaTokenizer):
...
@@ -214,6 +279,47 @@ class MBartTokenizer(XLMRobertaTokenizer):
inputs
[
"forced_bos_token_id"
]
=
tgt_lang_id
inputs
[
"forced_bos_token_id"
]
=
tgt_lang_id
return
inputs
return
inputs
def
get_vocab
(
self
):
vocab
=
{
self
.
convert_ids_to_tokens
(
i
):
i
for
i
in
range
(
self
.
vocab_size
)}
vocab
.
update
(
self
.
added_tokens_encoder
)
return
vocab
def
_tokenize
(
self
,
text
:
str
)
->
List
[
str
]:
return
self
.
sp_model
.
encode
(
text
,
out_type
=
str
)
def
_convert_token_to_id
(
self
,
token
):
"""Converts a token (str) in an id using the vocab."""
if
token
in
self
.
fairseq_tokens_to_ids
:
return
self
.
fairseq_tokens_to_ids
[
token
]
spm_id
=
self
.
sp_model
.
PieceToId
(
token
)
# Need to return unknown token if the SP model returned 0
return
spm_id
+
self
.
fairseq_offset
if
spm_id
else
self
.
unk_token_id
def
_convert_id_to_token
(
self
,
index
):
"""Converts an index (integer) in a token (str) using the vocab."""
if
index
in
self
.
fairseq_ids_to_tokens
:
return
self
.
fairseq_ids_to_tokens
[
index
]
return
self
.
sp_model
.
IdToPiece
(
index
-
self
.
fairseq_offset
)
def
convert_tokens_to_string
(
self
,
tokens
):
"""Converts a sequence of tokens (strings for sub-words) in a single string."""
out_string
=
""
.
join
(
tokens
).
replace
(
SPIECE_UNDERLINE
,
" "
).
strip
()
return
out_string
def
save_vocabulary
(
self
,
save_directory
:
str
,
filename_prefix
:
Optional
[
str
]
=
None
)
->
Tuple
[
str
]:
if
not
os
.
path
.
isdir
(
save_directory
):
logger
.
error
(
f
"Vocabulary path (
{
save_directory
}
) should be a directory"
)
return
out_vocab_file
=
os
.
path
.
join
(
save_directory
,
(
filename_prefix
+
"-"
if
filename_prefix
else
""
)
+
VOCAB_FILES_NAMES
[
"vocab_file"
]
)
if
os
.
path
.
abspath
(
self
.
vocab_file
)
!=
os
.
path
.
abspath
(
out_vocab_file
):
copyfile
(
self
.
vocab_file
,
out_vocab_file
)
return
(
out_vocab_file
,)
def
prepare_seq2seq_batch
(
def
prepare_seq2seq_batch
(
self
,
self
,
src_texts
:
List
[
str
],
src_texts
:
List
[
str
],
...
...
src/transformers/models/mbart/tokenization_mbart_fast.py
View file @
2ae3be54
...
@@ -13,15 +13,17 @@
...
@@ -13,15 +13,17 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
import
os
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
typing
import
List
,
Optional
from
shutil
import
copyfile
from
typing
import
List
,
Optional
,
Tuple
from
tokenizers
import
processors
from
tokenizers
import
processors
from
...file_utils
import
is_sentencepiece_available
from
...file_utils
import
is_sentencepiece_available
from
...tokenization_utils
import
BatchEncoding
from
...tokenization_utils
import
AddedToken
,
BatchEncoding
from
...tokenization_utils_fast
import
PreTrainedTokenizerFast
from
...utils
import
logging
from
...utils
import
logging
from
..xlm_roberta.tokenization_xlm_roberta_fast
import
XLMRobertaTokenizerFast
if
is_sentencepiece_available
():
if
is_sentencepiece_available
():
...
@@ -51,36 +53,12 @@ PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
...
@@ -51,36 +53,12 @@ PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
"facebook/mbart-large-cc25"
:
1024
,
"facebook/mbart-large-cc25"
:
1024
,
}
}
FAIRSEQ_LANGUAGE_CODES
=
[
# fmt: off
"ar_AR"
,
FAIRSEQ_LANGUAGE_CODES
=
[
"ar_AR"
,
"cs_CZ"
,
"de_DE"
,
"en_XX"
,
"es_XX"
,
"et_EE"
,
"fi_FI"
,
"fr_XX"
,
"gu_IN"
,
"hi_IN"
,
"it_IT"
,
"ja_XX"
,
"kk_KZ"
,
"ko_KR"
,
"lt_LT"
,
"lv_LV"
,
"my_MM"
,
"ne_NP"
,
"nl_XX"
,
"ro_RO"
,
"ru_RU"
,
"si_LK"
,
"tr_TR"
,
"vi_VN"
,
"zh_CN"
]
"cs_CZ"
,
# fmt: on
"de_DE"
,
"en_XX"
,
"es_XX"
,
class
MBartTokenizerFast
(
PreTrainedTokenizerFast
):
"et_EE"
,
"fi_FI"
,
"fr_XX"
,
"gu_IN"
,
"hi_IN"
,
"it_IT"
,
"ja_XX"
,
"kk_KZ"
,
"ko_KR"
,
"lt_LT"
,
"lv_LV"
,
"my_MM"
,
"ne_NP"
,
"nl_XX"
,
"ro_RO"
,
"ru_RU"
,
"si_LK"
,
"tr_TR"
,
"vi_VN"
,
"zh_CN"
,
]
class
MBartTokenizerFast
(
XLMRobertaTokenizerFast
):
"""
"""
Construct a "fast" MBART tokenizer (backed by HuggingFace's *tokenizers* library). Based on
Construct a "fast" MBART tokenizer (backed by HuggingFace's *tokenizers* library). Based on
[BPE](https://huggingface.co/docs/tokenizers/python/latest/components.html?highlight=BPE#models).
[BPE](https://huggingface.co/docs/tokenizers/python/latest/components.html?highlight=BPE#models).
...
@@ -111,6 +89,7 @@ class MBartTokenizerFast(XLMRobertaTokenizerFast):
...
@@ -111,6 +89,7 @@ class MBartTokenizerFast(XLMRobertaTokenizerFast):
vocab_files_names
=
VOCAB_FILES_NAMES
vocab_files_names
=
VOCAB_FILES_NAMES
max_model_input_sizes
=
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
max_model_input_sizes
=
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
pretrained_vocab_files_map
=
PRETRAINED_VOCAB_FILES_MAP
pretrained_vocab_files_map
=
PRETRAINED_VOCAB_FILES_MAP
model_input_names
=
[
"input_ids"
,
"attention_mask"
]
slow_tokenizer_class
=
MBartTokenizer
slow_tokenizer_class
=
MBartTokenizer
prefix_tokens
:
List
[
int
]
=
[]
prefix_tokens
:
List
[
int
]
=
[]
...
@@ -120,20 +99,40 @@ class MBartTokenizerFast(XLMRobertaTokenizerFast):
...
@@ -120,20 +99,40 @@ class MBartTokenizerFast(XLMRobertaTokenizerFast):
self
,
self
,
vocab_file
=
None
,
vocab_file
=
None
,
tokenizer_file
=
None
,
tokenizer_file
=
None
,
bos_token
=
"<s>"
,
eos_token
=
"</s>"
,
sep_token
=
"</s>"
,
cls_token
=
"<s>"
,
unk_token
=
"<unk>"
,
pad_token
=
"<pad>"
,
mask_token
=
"<mask>"
,
src_lang
=
None
,
src_lang
=
None
,
tgt_lang
=
None
,
tgt_lang
=
None
,
additional_special_tokens
=
None
,
additional_special_tokens
=
None
,
**
kwargs
**
kwargs
):
):
# Mask token behave like a normal word, i.e. include the space before it
mask_token
=
AddedToken
(
mask_token
,
lstrip
=
True
,
rstrip
=
False
)
if
isinstance
(
mask_token
,
str
)
else
mask_token
super
().
__init__
(
super
().
__init__
(
vocab_file
=
vocab_file
,
vocab_file
=
vocab_file
,
tokenizer_file
=
tokenizer_file
,
tokenizer_file
=
tokenizer_file
,
bos_token
=
bos_token
,
eos_token
=
eos_token
,
sep_token
=
sep_token
,
cls_token
=
cls_token
,
unk_token
=
unk_token
,
pad_token
=
pad_token
,
mask_token
=
mask_token
,
src_lang
=
src_lang
,
src_lang
=
src_lang
,
tgt_lang
=
tgt_lang
,
tgt_lang
=
tgt_lang
,
additional_special_tokens
=
additional_special_tokens
,
additional_special_tokens
=
additional_special_tokens
,
**
kwargs
,
**
kwargs
,
)
)
self
.
vocab_file
=
vocab_file
self
.
can_save_slow_tokenizer
=
False
if
not
self
.
vocab_file
else
True
_additional_special_tokens
=
FAIRSEQ_LANGUAGE_CODES
.
copy
()
_additional_special_tokens
=
FAIRSEQ_LANGUAGE_CODES
.
copy
()
if
additional_special_tokens
is
not
None
:
if
additional_special_tokens
is
not
None
:
...
@@ -190,6 +189,31 @@ class MBartTokenizerFast(XLMRobertaTokenizerFast):
...
@@ -190,6 +189,31 @@ class MBartTokenizerFast(XLMRobertaTokenizerFast):
# We don't expect to process pairs, but leave the pair logic for API consistency
# We don't expect to process pairs, but leave the pair logic for API consistency
return
self
.
prefix_tokens
+
token_ids_0
+
token_ids_1
+
self
.
suffix_tokens
return
self
.
prefix_tokens
+
token_ids_0
+
token_ids_1
+
self
.
suffix_tokens
def
create_token_type_ids_from_sequences
(
self
,
token_ids_0
:
List
[
int
],
token_ids_1
:
Optional
[
List
[
int
]]
=
None
)
->
List
[
int
]:
"""
Create a mask from the two sequences passed to be used in a sequence-pair classification task. mBART does not
make use of token type ids, therefore a list of zeros is returned.
Args:
token_ids_0 (`List[int]`):
List of IDs.
token_ids_1 (`List[int]`, *optional*):
Optional second list of IDs for sequence pairs.
Returns:
`List[int]`: List of zeros.
"""
sep
=
[
self
.
sep_token_id
]
cls
=
[
self
.
cls_token_id
]
if
token_ids_1
is
None
:
return
len
(
cls
+
token_ids_0
+
sep
)
*
[
0
]
return
len
(
cls
+
token_ids_0
+
sep
+
sep
+
token_ids_1
+
sep
)
*
[
0
]
def
_build_translation_inputs
(
def
_build_translation_inputs
(
self
,
raw_inputs
,
return_tensors
:
str
,
src_lang
:
Optional
[
str
],
tgt_lang
:
Optional
[
str
],
**
extra_kwargs
self
,
raw_inputs
,
return_tensors
:
str
,
src_lang
:
Optional
[
str
],
tgt_lang
:
Optional
[
str
],
**
extra_kwargs
):
):
...
@@ -253,3 +277,22 @@ class MBartTokenizerFast(XLMRobertaTokenizerFast):
...
@@ -253,3 +277,22 @@ class MBartTokenizerFast(XLMRobertaTokenizerFast):
pair
=
prefix_tokens_str
+
[
"$A"
,
"$B"
]
+
suffix_tokens_str
,
pair
=
prefix_tokens_str
+
[
"$A"
,
"$B"
]
+
suffix_tokens_str
,
special_tokens
=
list
(
zip
(
prefix_tokens_str
+
suffix_tokens_str
,
self
.
prefix_tokens
+
self
.
suffix_tokens
)),
special_tokens
=
list
(
zip
(
prefix_tokens_str
+
suffix_tokens_str
,
self
.
prefix_tokens
+
self
.
suffix_tokens
)),
)
)
def
save_vocabulary
(
self
,
save_directory
:
str
,
filename_prefix
:
Optional
[
str
]
=
None
)
->
Tuple
[
str
]:
if
not
self
.
can_save_slow_tokenizer
:
raise
ValueError
(
"Your fast tokenizer does not have the necessary information to save the vocabulary for a slow "
"tokenizer."
)
if
not
os
.
path
.
isdir
(
save_directory
):
logger
.
error
(
f
"Vocabulary path (
{
save_directory
}
) should be a directory."
)
return
out_vocab_file
=
os
.
path
.
join
(
save_directory
,
(
filename_prefix
+
"-"
if
filename_prefix
else
""
)
+
VOCAB_FILES_NAMES
[
"vocab_file"
]
)
if
os
.
path
.
abspath
(
self
.
vocab_file
)
!=
os
.
path
.
abspath
(
out_vocab_file
):
copyfile
(
self
.
vocab_file
,
out_vocab_file
)
return
(
out_vocab_file
,)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment