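"""Inference example for Colossal-LLaMA-2.

Loads a causal LM and its tokenizer with HuggingFace Transformers, builds a
prompt in either SFT-conversation or plain-continuation style, and logs the
generated response.
"""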
import argparse

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from colossal_llama2.dataset.conversation import default_conversation

from colossalai.logging import get_dist_logger

logger = get_dist_logger()


def load_model(model_path, device="cuda", **kwargs):
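    """Load a causal LM and its matching tokenizer from a HF repo name or local path."""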
    logger.info("Please check whether the tokenizer and model weights are properly stored in the same folder.")
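    # Any extra kwargs (e.g. torch_dtype) are forwarded to from_pretrained.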
    model = AutoModelForCausalLM.from_pretrained(model_path, **kwargs)
    model.to(device)

    try:
        tokenizer = AutoTokenizer.from_pretrained(model_path)
    except OSError as e:
        # Re-raise as OSError rather than ImportError: a missing tokenizer is a
        # file/path problem, not an import problem.
        raise OSError(
            "Tokenizer not found. Please check if the tokenizer exists or the model path is correct."
        ) from e

    return model, tokenizer


# Run generation under torch.inference_mode() so no autograd state is tracked.
@torch.inference_mode()
def generate(args):
    model, tokenizer = load_model(model_path=args.model_path, device=args.device)

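    # Build the prompt: the SFT model expects the conversation template,
    # while the pretrained base model is fed a plain continuation suffix.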
    if args.prompt_style == "sft":
        conversation = default_conversation.copy()
        conversation.append_message("Human", args.input_txt)
        input_txt = conversation.get_prompt()
    else:
        BASE_INFERENCE_SUFFIX = "\n\n->\n\n"
        input_txt = f"{args.input_txt}{BASE_INFERENCE_SUFFIX}"

    inputs = tokenizer(input_txt, return_tensors="pt").to(args.device)
    num_input_tokens = inputs["input_ids"].shape[-1]
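    # Generate up to max_new_tokens continuation tokens with the sampling
    # settings supplied on the command line.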
    output = model.generate(
        **inputs,
        max_new_tokens=args.max_new_tokens,
        do_sample=args.do_sample,
        temperature=args.temperature,
        top_k=args.top_k,
        top_p=args.top_p,
        num_return_sequences=1,
    )
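    # Slice off the prompt tokens so only the newly generated text is decoded.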
    response = tokenizer.decode(output.cpu()[0, num_input_tokens:], skip_special_tokens=True)
    logger.info(f"Question: {input_txt} \n\n Answer: \n{response}")
    return response


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Colossal-LLaMA-2 inference process.")
    parser.add_argument(
        "--model_path",
        type=str,
        default="hpcai-tech/Colossal-LLaMA-2-7b-base",
        help="HF repo name or local path of the model",
    )
    parser.add_argument("--device", type=str, default="cuda:0", help="Set the device")
    parser.add_argument(
        "--max_new_tokens",
        type=int,
        default=512,
        help=" Set maximum numbers of tokens to generate, ignoring the number of tokens in the prompt",
    )
    parser.add_argument("--do_sample", type=bool, default=True, help="Set whether or not to use sampling")
    parser.add_argument("--temperature", type=float, default=0.3, help="Set temperature value")
    parser.add_argument("--top_k", type=int, default=50, help="Set top_k value for top-k-filtering")
    parser.add_argument("--top_p", type=float, default=0.95, help="Set top_p value for generation")
    parser.add_argument("--input_txt", type=str, default="明月松间照,", help="The prompt input to the model")
    parser.add_argument("--prompt_style", choices=["sft", "pretrained"], default="sft", help="The style of the prompt")
    args = parser.parse_args()
    generate(args)
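

# Example invocation (assumes this file is saved as inference_example.py and a
# CUDA device is available; the default checkpoint is pulled from the HF Hub):
#   python inference_example.py \
#       --model_path hpcai-tech/Colossal-LLaMA-2-7b-base \
#       --device cuda:0 \
#       --max_new_tokens 256 \
#       --input_txt "明月松间照,"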