Unverified Commit 36c4bb28 authored by Yuanheng Zhao's avatar Yuanheng Zhao Committed by GitHub
Browse files

[Fix] Grok-1 use tokenizer from the same pretrained path (#5532)

* [fix] use tokenizer from the same pretrained path

* trust remote code
parent 00525f77
import time import time
import torch import torch
from transformers import AutoModelForCausalLM, LlamaTokenizerFast from transformers import AutoModelForCausalLM, AutoTokenizer
from utils import get_defualt_parser, inference, print_output from utils import get_defualt_parser, inference, print_output
if __name__ == "__main__": if __name__ == "__main__":
...@@ -9,6 +9,9 @@ if __name__ == "__main__": ...@@ -9,6 +9,9 @@ if __name__ == "__main__":
args = parser.parse_args() args = parser.parse_args()
start = time.time() start = time.time()
torch.set_default_dtype(torch.bfloat16) torch.set_default_dtype(torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(args.pretrained, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained( model = AutoModelForCausalLM.from_pretrained(
args.pretrained, args.pretrained,
trust_remote_code=True, trust_remote_code=True,
...@@ -18,10 +21,6 @@ if __name__ == "__main__": ...@@ -18,10 +21,6 @@ if __name__ == "__main__":
model.eval() model.eval()
init_time = time.time() - start init_time = time.time() - start
# A transformers-compatible version of the grok-1 tokenizer by Xenova
# https://huggingface.co/Xenova/grok-1-tokenizer
tokenizer = LlamaTokenizerFast.from_pretrained("Xenova/grok-1-tokenizer")
for text in args.text: for text in args.text:
output = inference( output = inference(
model, model,
......
...@@ -2,7 +2,7 @@ import time ...@@ -2,7 +2,7 @@ import time
import torch import torch
from grok1_policy import Grok1ForCausalLMPolicy from grok1_policy import Grok1ForCausalLMPolicy
from transformers import AutoModelForCausalLM, LlamaTokenizerFast from transformers import AutoModelForCausalLM, AutoTokenizer
from utils import get_defualt_parser, inference, print_output from utils import get_defualt_parser, inference, print_output
import colossalai import colossalai
...@@ -27,6 +27,9 @@ if __name__ == "__main__": ...@@ -27,6 +27,9 @@ if __name__ == "__main__":
) )
booster = Booster(plugin=plugin) booster = Booster(plugin=plugin)
torch.set_default_dtype(torch.bfloat16) torch.set_default_dtype(torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(args.pretrained, trust_remote_code=True)
with LazyInitContext(default_device=get_current_device()): with LazyInitContext(default_device=get_current_device()):
model = AutoModelForCausalLM.from_pretrained( model = AutoModelForCausalLM.from_pretrained(
args.pretrained, trust_remote_code=True, torch_dtype=torch.bfloat16 args.pretrained, trust_remote_code=True, torch_dtype=torch.bfloat16
...@@ -35,10 +38,6 @@ if __name__ == "__main__": ...@@ -35,10 +38,6 @@ if __name__ == "__main__":
model.eval() model.eval()
init_time = time.time() - start init_time = time.time() - start
# A transformers-compatible version of the grok-1 tokenizer by Xenova
# https://huggingface.co/Xenova/grok-1-tokenizer
tokenizer = LlamaTokenizerFast.from_pretrained("Xenova/grok-1-tokenizer")
for text in args.text: for text in args.text:
output = inference( output = inference(
model.unwrap(), model.unwrap(),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment