"tools/vscode:/vscode.git/clone" did not exist on "07419768607d76ee16b4f8d641ee7f1990ec55d8"
Unverified Commit 36c4bb28 authored by Yuanheng Zhao's avatar Yuanheng Zhao Committed by GitHub
Browse files

[Fix] Grok-1 use tokenizer from the same pretrained path (#5532)

* [fix] use tokenizer from the same pretrained path

* trust remote code
parent 00525f77
import time
import torch
from transformers import AutoModelForCausalLM, LlamaTokenizerFast
from transformers import AutoModelForCausalLM, AutoTokenizer
from utils import get_defualt_parser, inference, print_output
if __name__ == "__main__":
......@@ -9,6 +9,9 @@ if __name__ == "__main__":
args = parser.parse_args()
start = time.time()
torch.set_default_dtype(torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(args.pretrained, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
args.pretrained,
trust_remote_code=True,
......@@ -18,10 +21,6 @@ if __name__ == "__main__":
model.eval()
init_time = time.time() - start
# A transformers-compatible version of the grok-1 tokenizer by Xenova
# https://huggingface.co/Xenova/grok-1-tokenizer
tokenizer = LlamaTokenizerFast.from_pretrained("Xenova/grok-1-tokenizer")
for text in args.text:
output = inference(
model,
......
......@@ -2,7 +2,7 @@ import time
import torch
from grok1_policy import Grok1ForCausalLMPolicy
from transformers import AutoModelForCausalLM, LlamaTokenizerFast
from transformers import AutoModelForCausalLM, AutoTokenizer
from utils import get_defualt_parser, inference, print_output
import colossalai
......@@ -27,6 +27,9 @@ if __name__ == "__main__":
)
booster = Booster(plugin=plugin)
torch.set_default_dtype(torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(args.pretrained, trust_remote_code=True)
with LazyInitContext(default_device=get_current_device()):
model = AutoModelForCausalLM.from_pretrained(
args.pretrained, trust_remote_code=True, torch_dtype=torch.bfloat16
......@@ -35,10 +38,6 @@ if __name__ == "__main__":
model.eval()
init_time = time.time() - start
# A transformers-compatible version of the grok-1 tokenizer by Xenova
# https://huggingface.co/Xenova/grok-1-tokenizer
tokenizer = LlamaTokenizerFast.from_pretrained("Xenova/grok-1-tokenizer")
for text in args.text:
output = inference(
model.unwrap(),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment