gpt_batch_inference.py 5.19 KB
Newer Older
xingjinliang's avatar
xingjinliang committed
1
2
3
4
5
6
7
8
import os
from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import InferenceWrapperConfig
from pretrain_gpt import model_provider
import torch
import sys
from argparse import Namespace
from megatron.core.inference.engines.abstract_engine import AbstractEngine
from megatron.core.inference.engines.mcore_engine import MCoreEngine
silencealiang's avatar
add  
silencealiang committed
9
from megatron.core.inference.sampling_params import SamplingParams
xingjinliang's avatar
xingjinliang committed
10
11
from megatron.core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper import GPTInferenceWrapper
from megatron.core.inference.inference_request import InferenceRequest
silencealiang's avatar
add  
silencealiang committed
12
from megatron.core.inference.text_generation_controllers.text_generation_controller import TextGenerationController
xingjinliang's avatar
xingjinliang committed
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
from megatron.core.transformer.module import MegatronModule
# Append the absolute path of the directory two levels above this file to
# sys.path so modules located there can be imported when the script is run
# directly (presumably the repository root -- TODO confirm against the layout).
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
                                             os.path.pardir, os.path.pardir)))

from megatron.training import get_args
from megatron.training import get_tokenizer
from megatron.training.checkpointing import load_checkpoint
from megatron.core import mpu
from megatron.training.initialize import initialize_megatron
from megatron.training import get_model
from typing import List

def add_text_generate_args(parser):
    """Register the text-generation command-line options on ``parser``.

    The options are added to an argument group titled ``'text generation'``.

    Args:
        parser: An ``argparse``-style parser exposing ``add_argument_group``.

    Returns:
        The same ``parser`` object, now carrying the extra options.
    """
    text_gen_group = parser.add_argument_group(title='text generation')

    # (flag, add_argument keyword options), registered in declaration order.
    option_specs = (
        ("--temperature",
         dict(type=float, default=1.0, help='Sampling temperature.')),
        ("--top_k",
         dict(type=int, default=1, help='Top k sampling.')),
        ("--top_p",
         dict(type=float, default=0.0, help='Top p sampling.')),
        ("--return-log-probs",
         dict(action='store_true', default=False,
              help='Return the log probabilities of the final output tokens')),
        ("--num-tokens-to-generate",
         dict(type=int, default=30,
              help='Number of tokens to generate for each prompt')),
        ("--prompts",
         dict(metavar='N', type=str, nargs='+',
              help='Input prompts with each prompt within quotes and seperated by space')),
        ("--max-batch-size",
         dict(type=int, default=1,
              help='Max number of prompts to process at once')),
    )
    for flag, options in option_specs:
        text_gen_group.add_argument(flag, **options)

    return parser


def get_inference_engine(args: Namespace, model: MegatronModule) -> AbstractEngine:
    """Build the MCore inference engine used to generate text with ``model``.

    The model is wrapped in a ``GPTInferenceWrapper`` configured from ``args``,
    paired with the tokenizer returned by ``get_tokenizer()`` inside a
    ``TextGenerationController``, and handed to an ``MCoreEngine``. Only the
    MCore engine is constructed here; no other backend is selected.

    Args:
        args (Namespace): The user arguments parsed from command line. Read
            here: ``hidden_size``, ``inference_batch_times_seqlen_threshold``,
            ``fp32_residual_connection``, ``params_dtype``,
            ``padded_vocab_size`` and ``max_batch_size``.
        model (MegatronModule): The megatron model to run inference with.

    Returns:
        AbstractEngine: An ``MCoreEngine`` limited to ``args.max_batch_size``.
    """
    tokenizer = get_tokenizer()

    # Each wrapper-config field is copied from the same-named attribute on args.
    config_fields = (
        'hidden_size',
        'inference_batch_times_seqlen_threshold',
        'fp32_residual_connection',
        'params_dtype',
        'padded_vocab_size',
    )
    wrapper_config = InferenceWrapperConfig(
        **{field: getattr(args, field) for field in config_fields}
    )

    wrapped_model = GPTInferenceWrapper(model, wrapper_config)
    controller = TextGenerationController(
        inference_wrapped_model=wrapped_model,
        tokenizer=tokenizer,
    )
    return MCoreEngine(
        text_generation_controller=controller,
        max_batch_size=args.max_batch_size,
    )
            
def main():
    """Run batch text generation for the prompts given on the command line.

    Steps: initialize Megatron (registering the text-generation options),
    build the model and load its checkpoint, create the inference engine,
    generate completions for ``args.prompts``, and print one summary dict per
    prompt on the process whose ``torch.distributed`` rank is 0.
    """
    # Note: these defaults can be overwritten via the corresponding command
    # line params (check arguments.py file).
    # Micro batch size does not need to be set by the user. (It is calculated
    # based on the inference-batch-times-seqlen-threshold argument.)
    megatron_defaults = {
        'no_load_rng': True,
        'no_load_optim': True,
        'micro_batch_size': 1,
        'exit_on_missing_checkpoint': True,
    }
    initialize_megatron(extra_args_provider=add_text_generate_args,
                        args_defaults=megatron_defaults)

    # Set up model and load checkpoint. get_model's result is indexed below,
    # so the checkpoint is loaded into that sequence before taking element 0.
    model_chunks = get_model(model_provider, wrap_with_ddp=False)
    load_checkpoint(model_chunks, None, None)
    gpt_model = model_chunks[0]

    args = get_args()

    engine = get_inference_engine(args, gpt_model)

    sampling_params = SamplingParams(
        temperature=args.temperature,
        top_k=args.top_k,
        top_p=args.top_p,
        return_log_probs=args.return_log_probs,
        num_tokens_to_generate=args.num_tokens_to_generate,
    )

    results: List[InferenceRequest] = engine.generate(
        prompts=args.prompts, sampling_params=sampling_params
    )

    # Only rank 0 reports, so the output is printed by a single process.
    if torch.distributed.get_rank() != 0:
        return

    for idx, request in enumerate(results):
        print(f' \n------------- RESULT FOR PROMPT {idx} --------------- ')
        summary = {
            'id': request.request_id,
            'input_prompt': request.prompt,
            'generated_text': request.generated_text,
            'generated_tokens': request.generated_tokens,
        }
        print(summary)

# Script entry point: run main() only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()