"tests/test_zero/vscode:/vscode.git/clone" did not exist on "ae02d4e4f70e8ba4f8ae1058ac48bd08b06b6d24"
inference.py 2.59 KB
Newer Older
Fazzie-Maqianli's avatar
Fazzie-Maqianli committed
1
2
3
4
import argparse

import torch
from coati.models.bloom import BLOOMActor
5
from coati.models.generation import generate
Fazzie-Maqianli's avatar
Fazzie-Maqianli committed
6
from coati.models.gpt import GPTActor
7
from coati.models.llama import LlamaActor
Fazzie-Maqianli's avatar
Fazzie-Maqianli committed
8
from coati.models.opt import OPTActor
9
from transformers import AutoTokenizer, BloomTokenizerFast, GPT2Tokenizer, LlamaTokenizer
Fazzie-Maqianli's avatar
Fazzie-Maqianli committed
10
11
12
13


def eval(args):
    """Load an actor model, generate a completion for ``args.input``, and print it.

    Args:
        args: parsed CLI namespace with fields:
            model: one of ``"gpt2"``, ``"bloom"``, ``"opt"``, ``"llama"``.
            pretrain: HF hub name or local path used to construct the actor.
            model_path: optional path to a fine-tuned state dict to load.
            input: prompt string to complete.
            max_length: maximum generated sequence length.

    Raises:
        ValueError: if ``args.model`` names an unsupported architecture.
    """
    # configure model
    if args.model == "gpt2":
        actor = GPTActor(pretrained=args.pretrain)
    elif args.model == "bloom":
        actor = BLOOMActor(pretrained=args.pretrain)
    elif args.model == "opt":
        actor = OPTActor(pretrained=args.pretrain)
    elif args.model == "llama":
        actor = LlamaActor(pretrained=args.pretrain)
    else:
        raise ValueError(f'Unsupported model "{args.model}"')

    actor.to(torch.cuda.current_device())
    if args.model_path is not None:
        # map_location="cpu" makes loading robust when the checkpoint was
        # saved on a different device; load_state_dict then copies the
        # tensors into the actor, which is already on the current GPU.
        state_dict = torch.load(args.model_path, map_location="cpu")
        actor.load_state_dict(state_dict)

    # configure tokenizer
    if args.model == "gpt2":
        tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
        tokenizer.pad_token = tokenizer.eos_token
    elif args.model == "bloom":
        tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-560m")
        tokenizer.pad_token = tokenizer.eos_token
    elif args.model == "opt":
        tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
        tokenizer.pad_token = tokenizer.eos_token
    elif args.model == "llama":
        tokenizer = LlamaTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
        # BUG FIX: the original assigned "<\s>" (a literal backslash-s),
        # which is not a real token; LLaMA's end-of-sequence token is "</s>".
        tokenizer.eos_token = "</s>"
        # LLaMA has no dedicated pad token; reuse the unk token for padding.
        tokenizer.pad_token = tokenizer.unk_token
    else:
        raise ValueError(f'Unsupported model "{args.model}"')

    actor.eval()
    input_ids = tokenizer.encode(args.input, return_tensors="pt").to(torch.cuda.current_device())
    outputs = generate(
        actor, input_ids, max_length=args.max_length, do_sample=True, top_k=50, top_p=0.95, num_return_sequences=1
    )
    output = tokenizer.batch_decode(outputs[0], skip_special_tokens=True)
    print(f"[Output]: {''.join(output)}")
Fazzie-Maqianli's avatar
Fazzie-Maqianli committed
54
55


56
if __name__ == "__main__":
Fazzie-Maqianli's avatar
Fazzie-Maqianli committed
57
    parser = argparse.ArgumentParser()
58
    parser.add_argument("--model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"])
Fazzie-Maqianli's avatar
Fazzie-Maqianli committed
59
    # We suggest to use the pretrained model from HuggingFace, use pretrain to configure model
60
61
62
63
    parser.add_argument("--pretrain", type=str, default=None)
    parser.add_argument("--model_path", type=str, default=None)
    parser.add_argument("--input", type=str, default="Question: How are you ? Answer:")
    parser.add_argument("--max_length", type=int, default=100)
Fazzie-Maqianli's avatar
Fazzie-Maqianli committed
64
65
    args = parser.parse_args()
    eval(args)