run_llama_inference.py 3.9 KB
Newer Older
1
2
3
4
5
6
7
import argparse

import torch
import torch.distributed as dist
from transformers import LlamaForCausalLM, LlamaTokenizer

import colossalai
8
from colossalai.accelerator import get_accelerator
9
10
from colossalai.inference import InferenceEngine
from colossalai.testing import spawn
11
12
13
14
15

# Demo prompts that are tokenized as one batch and fed to the inference engine.
INPUT_TEXTS = [
    "What is the longest river in the world?",
    "Explain the difference between process and thread in computer science.",
]
16
17
18
19


def run_inference(args):
    """Generate completions for ``INPUT_TEXTS`` with a (optionally quantized) LLaMA model.

    The function loads the tokenizer and model, wraps the model in an
    ``InferenceEngine``, generates for the module-level prompts and prints the
    decoded results on rank 0 only. It calls ``dist.get_rank()``, so the
    process group must already be set up by the caller.

    Args:
        args: Parsed command-line namespace. The attributes read here are
            ``model_path``, ``tokenizer_path``, ``max_input_len``,
            ``max_output_len``, ``batch_size``, ``micro_batch_size``,
            ``tp_size``, ``pp_size``, ``quant``, ``dtype`` and, on the
            ``"smoothquant"`` path only, ``smoothquant_base_name``.

    Raises:
        ValueError: If ``args.quant`` is not ``None``, ``"gptq"`` or
            ``"smoothquant"``.
    """
    quant = args.quant
    # Fail fast on an unknown quantization mode: without this check no branch
    # below would bind ``model`` and the error would only surface later as an
    # UnboundLocalError, after the tokenizer has already been loaded.
    if quant not in (None, "gptq", "smoothquant"):
        raise ValueError(f"Unsupported quantization type {quant!r}; expected None, 'gptq' or 'smoothquant'.")

    llama_model_path = args.model_path
    # Fall back to the model directory when no separate tokenizer path is given.
    llama_tokenize_path = args.tokenizer_path or args.model_path

    max_input_len = args.max_input_len
    max_output_len = args.max_output_len
    max_batch_size = args.batch_size
    micro_batch_size = args.micro_batch_size
    tp_size = args.tp_size
    pp_size = args.pp_size
    rank = dist.get_rank()

    tokenizer = LlamaTokenizer.from_pretrained(llama_tokenize_path, padding_side="left")
    # Reuse the EOS token for padding so the prompts can be padded into one batch.
    tokenizer.pad_token_id = tokenizer.eos_token_id

    if quant is None:
        model = LlamaForCausalLM.from_pretrained(llama_model_path, pad_token_id=tokenizer.pad_token_id)
    elif quant == "gptq":
        # Imported lazily so auto_gptq is only required when GPTQ is requested.
        from auto_gptq import AutoGPTQForCausalLM

        model = AutoGPTQForCausalLM.from_quantized(
            llama_model_path, inject_fused_attention=False, device=torch.cuda.current_device()
        )
    else:  # quant == "smoothquant" (guaranteed by the validation above)
        from colossalai.inference.quant.smoothquant.models.llama import SmoothLlamaForCausalLM

        model = SmoothLlamaForCausalLM.from_quantized(llama_model_path, model_basename=args.smoothquant_base_name)
        model = model.cuda()

    engine = InferenceEngine(
        tp_size=tp_size,
        pp_size=pp_size,
        model=model,
        max_input_len=max_input_len,
        max_output_len=max_output_len,
        max_batch_size=max_batch_size,
        micro_batch_size=micro_batch_size,
        quant=quant,
        dtype=args.dtype,
    )

    # Pad to the longest prompt in the batch and truncate at max_input_len.
    inputs = tokenizer(INPUT_TEXTS, return_tensors="pt", padding="longest", max_length=max_input_len, truncation=True)
    # Move every tensor of the encoded batch to this process's current device.
    inputs = {k: v.to(get_accelerator().get_current_device()) for k, v in inputs.items()}
    outputs = engine.generate(inputs)

    # Only rank 0 decodes and prints, so the output appears once.
    if rank == 0:
        output_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)
        for input_text, output_text in zip(INPUT_TEXTS, output_texts):
            print(f"Input: {input_text}")
            print(f"Output: {output_text}")
68
69
70
71
72
73
74
75
76
77


def run_tp_pipeline_inference(rank, world_size, port, args):
    """Per-process worker: launch ColossalAI, then run the inference demo.

    Args:
        rank: Rank of this process, forwarded to ``colossalai.launch``.
        world_size: Total number of processes, forwarded to ``colossalai.launch``.
        port: Port forwarded to ``colossalai.launch`` (host is ``localhost``).
        args: Parsed command-line namespace, passed unchanged to ``run_inference``.
    """
    launch_options = {
        "config": {},
        "rank": rank,
        "world_size": world_size,
        "host": "localhost",
        "port": port,
        "backend": "nccl",
    }
    colossalai.launch(**launch_options)
    run_inference(args)


if __name__ == "__main__":
    # Command-line entry point: parse the options, then start one worker
    # process per (tensor-parallel x pipeline-parallel) rank.
    parser = argparse.ArgumentParser()
    parser.add_argument("-p", "--model_path", type=str, help="Model path", required=True)
    # NOTE(review): --input is accepted but not read by run_inference in this
    # file, which always uses INPUT_TEXTS; confirm whether it should be wired in.
    parser.add_argument("-i", "--input", default="What is the longest river in the world?")
    parser.add_argument("-t", "--tokenizer_path", type=str, help="Tokenizer path", default=None)
    parser.add_argument(
        "-q",
        "--quant",
        type=str,
        choices=["gptq", "smoothquant"],
        default=None,
        help="quantization type: 'gptq' or 'smoothquant'",
    )
    parser.add_argument("--smoothquant_base_name", type=str, default=None, help="smoothquant base name")
    parser.add_argument("--tp_size", type=int, default=1, help="Tensor parallel size")
    parser.add_argument("--pp_size", type=int, default=1, help="Pipeline parallel size")
    parser.add_argument("-b", "--batch_size", type=int, default=4, help="Maximum batch size")
    parser.add_argument("--max_input_len", type=int, default=2048, help="Maximum input length")
    parser.add_argument("--max_output_len", type=int, default=64, help="Maximum output length")
    parser.add_argument("--micro_batch_size", type=int, default=1, help="Micro batch size")
    parser.add_argument("--dtype", default="fp16", type=str)

    args = parser.parse_args()
    # One process per parallel rank: tensor-parallel size times pipeline-parallel size.
    spawn(run_tp_pipeline_inference, nprocs=args.tp_size * args.pp_size, args=args)