gpt_dynamic_inference.py 10.7 KB
Newer Older
jerrrrry's avatar
jerrrrry committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.

from argparse import ArgumentParser
from collections import defaultdict
from typing import Dict, List, Tuple

import torch
from tqdm import tqdm

from megatron.core.inference.contexts import (
    ContextOverflowError,
    DynamicInferenceContext,
)
from megatron.core.inference.engines import DynamicInferenceEngine
from megatron.core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper import GPTInferenceWrapper
from megatron.core.inference.sampling_params import SamplingParams
from megatron.core.inference.text_generation_controllers.text_generation_controller import TextGenerationController
from megatron.core.transformer.module import MegatronModule
from megatron.training import (
    get_args,
    get_model as _get_model,
    get_tokenizer,
    initialize_megatron,
)
from megatron.training.checkpointing import load_checkpoint
from pretrain_gpt import model_provider

from .utils import add_common_inference_args, build_requests, get_curr_time, Request


def add_dynamic_inference_args(parser: ArgumentParser) -> ArgumentParser:
    """Register the command-line arguments used by the dynamic inference script.

    This is passed to ``initialize_megatron`` as ``extra_args_provider``.
    At present it only installs the arguments shared by all inference
    scripts (``add_common_inference_args``).

    Args:
        parser (ArgumentParser): Parser to which the arguments are added.

    Return:
        (ArgumentParser) The same ``parser`` object.
    """
    add_common_inference_args(parser)
    return parser


def get_model() -> MegatronModule:
    """Build the GPT model and load its checkpoint for inference.

    Requires ``args.load`` to be set. Before loading, this forces
    ``args.exit_on_missing_checkpoint = True``.

    Return:
        (MegatronModule) The single model chunk (virtual pipeline
        parallelism is not supported), switched to eval mode.
    """

    args = get_args()

    # Build model (inference only, so no DDP wrapper).
    model = _get_model(model_provider, wrap_with_ddp=False)

    # Load checkpoint. The two `None` arguments are presumably the optimizer
    # and LR scheduler, which inference does not need -- confirm against
    # megatron.training.checkpointing.load_checkpoint.
    assert args.load is not None, (
        "Dynamic inference requires a checkpoint; please set --load."
    )
    args.exit_on_missing_checkpoint = True
    load_checkpoint(model, None, None)

    # `_get_model` returns a list of model chunks. More than one chunk means
    # virtual pipeline parallelism, which this script does not support.
    assert len(model) == 1, (
        "Expected a single model chunk (virtual pipeline parallelism is not "
        f"supported), got {len(model)}."
    )
    model = model[0]

    # Eval mode.
    model.eval()

    return model


def get_inference_context(
    requests: List[Request],
    sampling_params: SamplingParams,
):
    """Create the dynamic inference context (KV cache and other inference state).

    The context's maximum sequence length is sized to hold the longest prompt
    in ``requests`` plus ``sampling_params.num_tokens_to_generate`` new tokens.
    Buffer sizing comes from the ``args.inference_dynamic_batching_*`` settings.

    Args:
        requests (List[Request]): Requests that will be served. Only their
            ``prompt_tokens`` lengths are inspected here.
        sampling_params (SamplingParams): Supplies the generation length.

    Return:
        (DynamicInferenceContext) The newly constructed context.
    """

    args = get_args()

    longest_prompt = max(len(req.prompt_tokens) for req in requests)
    seq_len_limit = longest_prompt + sampling_params.num_tokens_to_generate

    # With grouped-query attention, the per-layer KV head count is the number
    # of query groups rather than the number of attention heads.
    if args.group_query_attention:
        kv_head_count = args.num_query_groups
    else:
        kv_head_count = args.num_attention_heads

    return DynamicInferenceContext(
        params_dtype=args.params_dtype,
        num_layers=args.num_layers,
        kv_channels=args.kv_channels,
        num_attention_heads=kv_head_count,
        max_sequence_length=seq_len_limit,
        buffer_size_gb=args.inference_dynamic_batching_buffer_size_gb,
        buffer_guaranteed_fraction=args.inference_dynamic_batching_buffer_guaranteed_fraction,
        buffer_overflow_factor=args.inference_dynamic_batching_buffer_overflow_factor,
        max_requests_override=args.inference_dynamic_batching_max_requests_override,
        max_tokens_override=args.inference_dynamic_batching_max_tokens_override,
    )


def get_inference_controller(
    model: MegatronModule,
    context: DynamicInferenceContext,
) -> TextGenerationController:
    """Build the text generation controller that drives the wrapped model.

    The model is wrapped in a ``GPTInferenceWrapper`` bound to ``context``.
    The wrapper is then flagged as pipeline-parallel or not, based on this
    rank's pipeline stage.

    Args:
        model (MegatronModule): Megatron GPT model.
        context (DynamicInferenceContext): Context for managing KV cache.

    Return:
        (TextGenerationController) Inference text generation controller.
    """
    args = get_args()
    tokenizer = get_tokenizer()

    wrapped_model = GPTInferenceWrapper(model, args, context)

    # Mirrors AbstractModelInferenceWrapper.prep_model_for_inference(): the
    # model is pipeline-parallel unless this rank is both the first and the
    # last pipeline stage.
    from megatron.core import parallel_state
    single_stage = (
        parallel_state.is_pipeline_first_stage()
        and parallel_state.is_pipeline_last_stage()
    )
    wrapped_model.model_is_pipeline_parallel = not single_stage

    return TextGenerationController(wrapped_model, tokenizer)


def run_inference(
    requests: List[Request],
    sampling_params: SamplingParams,
    engine: DynamicInferenceEngine,
) -> Tuple[Dict[str, List[float]], List[float], List[float]]:
    """Add requests to the engine as they arrive and step it until all finish.

    Each request's arrival time is ``request.time_offset`` plus the time at
    which this call starts. Every iteration does two things:
      * adds all requests whose arrival time has passed, stopping early if
        the context raises ``ContextOverflowError``;
      * steps the engine once, appending each sampled token to its request.

    Args:
        requests (List[Request]): Requests to add and process. Their
            ``time_arrival``, ``time_start``, ``time_end``, ``state`` and
            ``output_tokens`` fields are updated in place. The id passed to
            the engine for each request is its index in this list.
        sampling_params (SamplingParams): Sampling params for the logits.
        engine (DynamicInferenceEngine): Inference engine that manages generating tokens.

    Return:
        A tuple ``(step_times, add_times, output_times)``:
          * ``step_times`` maps ``"prefill"`` / ``"decode"`` to the step times
            reported by ``engine.step`` for steps that produced a result.
          * ``add_times`` holds the time spent adding requests on each
            iteration.
          * ``output_times`` holds the time spent recording outputs for each
            step that produced a result.
    """

    # Initialize request arrival times.
    base_arrival_time = get_curr_time()
    for request in requests:
        request.time_arrival = request.time_offset + base_arrival_time

    num_requests_total = len(requests)
    num_requests_added = 0
    step_times = {"prefill": [], "decode": []}
    add_times = []
    output_times = []

    # `with` guarantees the progress bar is closed, even if a step raises.
    with tqdm(total=num_requests_total) as tbar:
        while True:
            curr_time = get_curr_time()

            # Add requests with 'earlier' arrival time, in order. Stop at the
            # first request that has not arrived yet, or when the context is
            # full. In the latter case the same request is retried on the
            # next iteration.
            add_start = get_curr_time()
            while num_requests_added < num_requests_total:
                request = requests[num_requests_added]
                if request.time_arrival > curr_time:
                    break
                try:
                    # Using `prompt_text` instead of `prompt_tokens` for fair comparison.
                    engine.add_request(num_requests_added, request.prompt_text)
                except ContextOverflowError:
                    break
                request.time_start = get_curr_time()
                request.state = "started"
                num_requests_added += 1
                tbar.update(1)
            add_times.append(get_curr_time() - add_start)

            # Step inference engine (i.e., generate a token for each active
            # request). The phase is read before stepping, since the step
            # may change it.
            is_decode_only = engine.context.is_decode_only()
            result, step_time = engine.step(sampling_params, verbose=True)

            # Append output tokens.
            if result is not None:
                output_start = get_curr_time()

                step_times["decode" if is_decode_only else "prefill"].append(step_time)

                request_ids, finished_request_ids, sample = result
                # Build the finished set once. A membership test against a
                # tensor costs an elementwise comparison (and possibly a device
                # sync) per lookup.
                if torch.is_tensor(finished_request_ids):
                    finished_request_ids = finished_request_ids.tolist()
                finished = set(finished_request_ids)

                for request_id, token in zip(request_ids.tolist(), sample.tolist()):
                    request = requests[request_id]
                    request.output_tokens.append(token)
                    if request_id in finished:
                        request.time_end = get_curr_time()
                        request.state = "finished"

                output_times.append(get_curr_time() - output_start)

            # Done once the engine is idle and every request has been added.
            if not engine.has_unfinished_requests() and num_requests_added >= num_requests_total:
                break

    return step_times, add_times, output_times


def mean_or_zero(values: List[float]) -> float:
    """Return the arithmetic mean of ``values``, or 0.0 when it is empty.

    Used by the timing summary, where a phase (prefill or decode) may have
    recorded no steps at all.
    """
    return sum(values) / len(values) if values else 0.0


if __name__ == "__main__":

    # Initialize Megatron.
    initialize_megatron(
        extra_args_provider=add_dynamic_inference_args,
        args_defaults={'no_load_rng': True,
                       'no_load_optim': True},
    )

    args = get_args()
    tokenizer = get_tokenizer()

    # Sampling params.
    sampling_params = SamplingParams(
        temperature=args.temperature,
        top_k=args.top_k,
        top_p=args.top_p,
        return_log_probs=args.return_log_probs,
        num_tokens_to_generate=args.num_tokens_to_generate,
    )

    # Model, requests, context, controller.
    model = get_model()
    requests = build_requests(args, tokenizer)
    context = get_inference_context(requests, sampling_params)
    controller = get_inference_controller(model, context)

    # Inference engine.
    engine = DynamicInferenceEngine(controller,
                                    context,
                                    termination_id=tokenizer.eod,
                                    enable_cuda_graph=args.enable_cuda_graph,
                                    random_seed=args.seed)

    # Print setup.
    setup_prefix = "dynamic | cg %d | %s | bf %.0f, flw %.1f [r %d, t %d], gtd %.2f [r %d] ... reqs %d" % (
        args.enable_cuda_graph,
        (
            f"<user prompts, n {len(args.prompts)}>"
            if args.prompts else
            "<auto prompts> %s, %d, %.1e, %.1e" % (
                "(%s)" % " ".join(map(str, args.num_tokens_to_prompt)),
                args.num_tokens_to_generate,
                args.incoming_requests_duration,
                args.incoming_requests_per_sec,
            )
        ),
        args.inference_dynamic_batching_buffer_size_gb,
        args.inference_dynamic_batching_buffer_overflow_factor,
        context.max_requests,
        context.max_tokens,
        args.inference_dynamic_batching_buffer_guaranteed_fraction,
        context.gtd_request_count,
        len(requests),
    )
    print("~~~")
    print(setup_prefix)
    print("~~~")

    # Run and time test.
    t = get_curr_time()
    step_times, add_times, output_times = run_inference(requests, sampling_params, engine)
    # NOTE(review): `total_time` (wall clock of run_inference) is not reported.
    # The summary's "total time" below is step time + add time only.
    total_time = get_curr_time() - t

    # Validate all requests finished.
    unfinished = [idx for idx, request in enumerate(requests) if request.state != "finished"]
    assert not unfinished, f"Requests did not finish: {unfinished}"

    # Detokenize outputs.
    for request in requests:
        request.output_text = tokenizer.detokenize(request.output_tokens)

    # Print unique prompts + outputs.
    if torch.distributed.get_rank() == 0:

        print("~~~~ Unique prompts + outputs. ~~~~")

        # Map requests by their prompt.
        unique_prompt_map = defaultdict(list)
        for request_idx, request in enumerate(requests):
            unique_prompt_map[request.prompt_text].append(request_idx)

        # Print unique prompts + outputs. A single f-string is used because
        # applying `%` to a string that embeds the prompt would treat any '%'
        # in the prompt as a format directive.
        for unique_idx, (prompt_text, request_idxs) in enumerate(unique_prompt_map.items()):
            request = requests[request_idxs[0]]
            output_text = request.output_text.replace("\n", "\\n")
            print(f"{unique_idx}/{len(unique_prompt_map)} [{len(request_idxs)}]. {prompt_text} ... {output_text}")

    # Timing results. Means fall back to 0.0 for a phase with no recorded
    # steps, instead of raising ZeroDivisionError.
    prefill_times = step_times["prefill"]
    decode_times = step_times["decode"]
    prefill_total = sum(prefill_times)
    decode_total = sum(decode_times)
    add_total = sum(add_times)
    stats = torch.cuda.memory_stats()
    print("~~~")
    print("%s ... mem %.1f/%.1f ... total time: %.3f ... step time: total %.3f [ p %.3f, d %.3f ], mean [ p %.3f, d %.3f ], count [ p %d, d %d ] ... add time: %.3f, output time: %.3f." % (
        setup_prefix,
        stats["allocated_bytes.all.peak"] / (1024**3),
        stats["reserved_bytes.all.peak"] / (1024**3),
        prefill_total + decode_total + add_total,
        prefill_total + decode_total,
        prefill_total,
        decode_total,
        mean_or_zero(prefill_times),
        mean_or_zero(decode_times),
        len(prefill_times),
        len(decode_times),
        add_total,
        sum(output_times),
    ))
    print("~~~")