Unverified Commit 6a02e980 authored by Nicolas Patry, committed by GitHub

LlamaTokenizerFast Fix (.., from_slow=True). (#22630)

parent 09a9888f
@@ -12,12 +12,25 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import os
+from shutil import copyfile
+from typing import Optional, Tuple
+
 from ...tokenization_utils_fast import PreTrainedTokenizerFast
+from ...utils import is_sentencepiece_available, logging
 from ...utils.versions import require_version
 
 
 require_version("tokenizers>=0.13.3")
 
+if is_sentencepiece_available():
+    from .tokenization_llama import LlamaTokenizer
+else:
+    LlamaTokenizer = None
+
+logger = logging.get_logger(__name__)
+
+VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model", "tokenizer_file": "tokenizer.json"}
 
 class LlamaTokenizerFast(PreTrainedTokenizerFast):
     """
@@ -59,6 +72,8 @@ class LlamaTokenizerFast(PreTrainedTokenizerFast):
             token instead.
     """
 
+    vocab_files_names = VOCAB_FILES_NAMES
+    slow_tokenizer_class = LlamaTokenizer
     padding_side = "left"
 
     def __init__(
@@ -80,3 +95,25 @@ class LlamaTokenizerFast(PreTrainedTokenizerFast):
             eos_token=eos_token,
             **kwargs,
         )
+        self.vocab_file = vocab_file
+        self.can_save_slow_tokenizer = False if not self.vocab_file else True
+
+    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
+        if not self.can_save_slow_tokenizer:
+            raise ValueError(
+                "Your fast tokenizer does not have the necessary information to save the vocabulary for a slow "
+                "tokenizer."
+            )
+
+        if not os.path.isdir(save_directory):
+            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
+            return
+        out_vocab_file = os.path.join(
+            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
+        )
+
+        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
+            copyfile(self.vocab_file, out_vocab_file)
+
+        return (out_vocab_file,)
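
For reference, a minimal usage sketch of the path this commit fixes (not part of the diff; the checkpoint paths are hypothetical and assume a SentencePiece tokenizer.model is present in the directory):

from transformers import LlamaTokenizerFast

# from_slow=True forces the fast tokenizer to be rebuilt by converting the
# slow SentencePiece tokenizer; this requires LlamaTokenizer to be importable,
# which is what the is_sentencepiece_available() guard above provides.
tokenizer = LlamaTokenizerFast.from_pretrained("path/to/llama_checkpoint", from_slow=True)

# Because __init__ now records vocab_file, can_save_slow_tokenizer is True and
# save_vocabulary() copies tokenizer.model alongside tokenizer.json on save.
tokenizer.save_pretrained("path/to/output_dir")

Conversely, a tokenizer loaded from a tokenizer.json alone has no vocab_file, so can_save_slow_tokenizer is False and save_vocabulary() raises the ValueError above rather than silently omitting the SentencePiece model.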