Unverified Commit 98581954 authored by Arthur's avatar Arthur Committed by GitHub
Browse files

add fast support and option (#22724)



* add fast support and option

* update based on review

* Apply suggestions from code review
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/models/llama/convert_llama_weights_to_hf.py
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>

* nit

* add print

* fixup

---------
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 10fab90f
...@@ -17,12 +17,22 @@ import json ...@@ -17,12 +17,22 @@ import json
import math import math
import os import os
import shutil import shutil
import warnings
import torch import torch
from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer
# Prefer the fast (tokenizers-backed) LLaMA tokenizer when the installed
# environment supports it; otherwise fall back to the slow tokenizer.
try:
    from transformers import LlamaTokenizerFast
except ImportError as err:
    # Surface the original import failure first, then explain the consequence.
    warnings.warn(err)
    warnings.warn(
        "The converted tokenizer will be the `slow` tokenizer. To use the fast, update your `tokenizers` library and re-run the tokenizer conversion"
    )
    # Sentinel checked later to decide which tokenizer class to instantiate.
    LlamaTokenizerFast = None
""" """
Sample usage: Sample usage:
...@@ -232,9 +242,10 @@ def write_model(model_path, input_base_path, model_size): ...@@ -232,9 +242,10 @@ def write_model(model_path, input_base_path, model_size):
def write_tokenizer(tokenizer_path, input_tokenizer_path):
    """Convert the SentencePiece model at `input_tokenizer_path` into a
    Hugging Face tokenizer and save it under `tokenizer_path`.

    Uses `LlamaTokenizerFast` when the import at module level succeeded,
    otherwise falls back to the slow `LlamaTokenizer`.
    """
    print(f"Fetching the tokenizer from {input_tokenizer_path}.")
    # Initialize the tokenizer based on the `spm` model
    tokenizer_class = LlamaTokenizer if LlamaTokenizerFast is None else LlamaTokenizerFast
    # FIX: the original line lacked the `f` prefix, so the literal text
    # "{tokenizer_class}" / "{tokenizer_path}" was printed instead of the
    # interpolated values.
    print(f"Saving a {tokenizer_class} to {tokenizer_path}")
    tokenizer = tokenizer_class(input_tokenizer_path)
    tokenizer.save_pretrained(tokenizer_path)
...@@ -259,10 +270,8 @@ def main(): ...@@ -259,10 +270,8 @@ def main():
input_base_path=os.path.join(args.input_dir, args.model_size), input_base_path=os.path.join(args.input_dir, args.model_size),
model_size=args.model_size, model_size=args.model_size,
) )
write_tokenizer( spm_path = os.path.join(args.input_dir, "tokenizer.model")
tokenizer_path=args.output_dir, write_tokenizer(args.output_dir, spm_path)
input_tokenizer_path=os.path.join(args.input_dir, "tokenizer.model"),
)
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment