import time

import torch
from transformers import AutoModelForCausalLM, LlamaTokenizerFast

from utils import get_defualt_parser, inference, print_output

if __name__ == "__main__":
    # CLI entry point: load a grok-1 checkpoint, run generation on each
    # prompt passed via --text, and report load/generation latencies.
    parser = get_defualt_parser()
    args = parser.parse_args()
    start = time.time()

    # Grok-1 weights are distributed in bf16; setting the default dtype
    # keeps any lazily-created tensors consistent with the checkpoint.
    torch.set_default_dtype(torch.bfloat16)
    model = AutoModelForCausalLM.from_pretrained(
        args.pretrained,
        trust_remote_code=True,  # grok-1 ships custom modeling code
        device_map="auto",       # shard across available devices
        torch_dtype=torch.bfloat16,
    )
    model.eval()
    init_time = time.time() - start

    # A transformers-compatible version of the grok-1 tokenizer by Xenova
    # https://huggingface.co/Xenova/grok-1-tokenizer
    tokenizer = LlamaTokenizerFast.from_pretrained("Xenova/grok-1-tokenizer")

    for text in args.text:
        output = inference(
            model,
            tokenizer,
            text,
            max_new_tokens=args.max_new_tokens,
            do_sample=args.do_sample,
            temperature=args.temperature,
            top_k=args.top_k,
            top_p=args.top_p,
        )
        print_output(text, tokenizer.decode(output))

    overall_time = time.time() - start
    gen_latency = overall_time - init_time
    # Guard against an empty prompt list so the summary never raises
    # ZeroDivisionError.
    avg_gen_latency = gen_latency / len(args.text) if args.text else 0.0
    print(
        f"Initializing time: {init_time:.2f} seconds.\n"
        f"Overall time: {overall_time:.2f} seconds. \n"
        f"Generation latency: {gen_latency:.2f} seconds. \n"
        f"Average generation latency: {avg_gen_latency:.2f} seconds. \n"
    )