# inference_example.py
1
2
3
import argparse

import torch
4
from colossal_llama.dataset.conversation import default_conversation
5
6
from transformers import AutoModelForCausalLM, AutoTokenizer

7
8
9
10
11
12
from colossalai.logging import get_dist_logger

# Module-level logger obtained from colossalai.logging.get_dist_logger; used by load_model and generate below.
logger = get_dist_logger()


def load_model(model_path, device="cuda", **kwargs):
    """Load a causal language model and its tokenizer from ``model_path``.

    Args:
        model_path: Hugging Face repo name or local directory that holds both the
            model weights and the tokenizer files.
        device: Device the model is moved to after loading (e.g. ``"cuda:0"``).
        **kwargs: Extra keyword arguments forwarded unchanged to
            ``AutoModelForCausalLM.from_pretrained``.

    Returns:
        A ``(model, tokenizer)`` tuple. The tokenizer is configured with
        ``padding_side="left"``.

    Raises:
        ImportError: If the tokenizer cannot be loaded from ``model_path``. The
            underlying ``OSError`` is chained as ``__cause__``.
    """
    logger.info("Please check whether the tokenizer and model weights are properly stored in the same folder.")

    model = AutoModelForCausalLM.from_pretrained(model_path, **kwargs)
    model.to(device)

    try:
        # Left padding: decoder-only generation generally expects prompts to be
        # right-aligned so new tokens follow the prompt directly.
        tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side="left")
    except OSError as err:
        # ImportError is kept for compatibility with existing callers; chain the
        # original OSError explicitly so the real cause stays visible.
        raise ImportError(
            "Tokenizer not found. Please check if the tokenizer exists or the model path is correct."
        ) from err

    return model, tokenizer


@torch.inference_mode()
def generate(args):
    """Run one prompt through the model and return the decoded completion.

    Args:
        args: Parsed command-line namespace. Reads ``model_path``, ``device``,
            ``prompt_style``, ``input_txt``, ``max_new_tokens``, ``do_sample``,
            ``temperature``, ``top_k`` and ``top_p``.

    Returns:
        The generated text only: the prompt tokens are stripped and special
        tokens are skipped during decoding.
    """
    model, tokenizer = load_model(model_path=args.model_path, device=args.device)

    if args.prompt_style == "sft":
        # SFT-style prompt: wrap the input in the default conversation template
        # and leave the "Assistant" turn open for the model to complete.
        conversation = default_conversation.copy()
        conversation.append_message("Human", args.input_txt)
        conversation.append_message("Assistant", None)
        input_txt = conversation.get_prompt()
    else:
        # Pretrained-style prompt: raw text followed by a fixed separator.
        BASE_INFERENCE_SUFFIX = "\n\n->\n\n"
        input_txt = f"{args.input_txt}{BASE_INFERENCE_SUFFIX}"

    inputs = tokenizer(input_txt, return_tensors="pt").to(args.device)
    num_input_tokens = inputs["input_ids"].shape[-1]

    generation_kwargs = {
        "max_new_tokens": args.max_new_tokens,
        "do_sample": args.do_sample,
        "num_return_sequences": 1,
    }
    if args.do_sample:
        # temperature / top_k / top_p only influence sampling; passing them with
        # greedy decoding has no effect other than warnings from generate().
        generation_kwargs.update(
            temperature=args.temperature,
            top_k=args.top_k,
            top_p=args.top_p,
        )

    output = model.generate(**inputs, **generation_kwargs)
    # generate() returns prompt + completion; keep only the new tokens of the
    # single returned sequence, slicing before the device-to-host copy.
    response = tokenizer.decode(output[0, num_input_tokens:].cpu(), skip_special_tokens=True)
    logger.info(f"\nHuman: {args.input_txt} \n\nAssistant: \n{response}")

    return response


if __name__ == "__main__":

    def _str2bool(value):
        """Parse a command-line boolean such as "True", "false", "1" or "no".

        argparse's ``type=bool`` is a trap: ``bool("False")`` is True, so any
        non-empty value would enable the flag. The empty string still maps to
        False, matching what ``bool("")`` produced before.
        """
        lowered = value.strip().lower()
        if lowered in ("true", "t", "yes", "y", "1"):
            return True
        if lowered in ("false", "f", "no", "n", "0", ""):
            return False
        raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}")

    parser = argparse.ArgumentParser(description="Colossal-LLaMA-2 inference Process.")
    parser.add_argument(
        "--model_path",
        type=str,
        default="hpcai-tech/Colossal-LLaMA-2-7b-base",
        help="HF repo name or local path of the model",
    )
    parser.add_argument("--device", type=str, default="cuda:0", help="Set the device")
    parser.add_argument(
        "--max_new_tokens",
        type=int,
        default=512,
        help=" Set maximum numbers of tokens to generate, ignoring the number of tokens in the prompt",
    )
    # Use an explicit converter so "--do_sample False" actually disables sampling.
    parser.add_argument("--do_sample", type=_str2bool, default=True, help="Set whether or not to use sampling")
    parser.add_argument("--temperature", type=float, default=0.3, help="Set temperature value")
    parser.add_argument("--top_k", type=int, default=50, help="Set top_k value for top-k-filtering")
    parser.add_argument("--top_p", type=float, default=0.95, help="Set top_p value for generation")
    parser.add_argument("--input_txt", type=str, default="明月松间照,", help="The prompt input to the model")
    parser.add_argument("--prompt_style", choices=["sft", "pretrained"], default="sft", help="The style of the prompt")
    args = parser.parse_args()
    generate(args)