utils.py 12 KB
Newer Older
jerrrrry's avatar
jerrrrry committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.

import json
import itertools
import random
import time
import torch
from argparse import ArgumentParser, Namespace
from tqdm import tqdm
from typing import Any, List, Optional

from megatron.core.inference.inference_request import DynamicInferenceRequest
from megatron.core.inference.contexts import DynamicInferenceContext
from megatron.core.transformer.module import MegatronModule



def add_common_inference_args(parser: ArgumentParser) -> ArgumentParser:
    """Common inference arguments.

    Adds a 'Common inference' argument group to `parser` (sampling, prompt
    sources, simulated request arrival, output and coordinator options).

    Args:
        parser (ArgumentParser): Parser to extend in place.

    Returns:
        ArgumentParser: The same `parser`, for chaining.
    """

    group = parser.add_argument_group(title='Common inference')

    # Note: `--top_k` / `--top_p` use underscores, unlike the other flags; kept
    # as-is because callers and scripts already depend on these spellings.
    group.add_argument("--temperature", type=float, default=1.0, help='Sampling temperature.')
    group.add_argument("--top_k", type=int, default=1, help='Top k sampling.')
    group.add_argument("--top_p", type=float, default=0.0, help='Top p sampling.')
    group.add_argument(
        "--return-log-probs",
        action='store_true',
        default=False,
        help='Return the log probabilities of the final output tokens',
    )
    group.add_argument(
        "--prompts",
        metavar='N',
        type=str,
        nargs='+',
        help='Input prompts with each prompt within quotes and separated by space',
    )
    group.add_argument(
        "--num-tokens-to-prompt",
        type=int,
        nargs="+",
        default=[64, 1024],
        help='Number of tokens to use for simulated prompts. This should be a '
        'space-separated pair of integers, and the generated prompt lengths will '
        'be uniformly sampled within this range.',
    )
    group.add_argument(
        "--num-tokens-to-generate",
        type=int,
        default=30,
        help='Number of tokens to generate for each prompt',
    )
    group.add_argument(
        "--top-n-logprobs",
        type=int,
        default=0,
        help='Return the top n logprobs for the generated tokens and their corresponding token as a dictionary',
    )
    group.add_argument(
        "--incoming-requests-per-step",
        type=int,
        default=None,
        help="Add a deterministic number of requests per step. This arg is "
        "prioritized over `--incoming-requests-per-sec` below (which is non-"
        "deterministic). Note that the number of requests added per step is "
        "additionally limited by the inference context's `max_requests`, "
        "`max_tokens`, and KV buffer size.",
    )
    group.add_argument(
        "--incoming-requests-per-sec",
        type=float,
        default=100.0,
        help="Simulated number of requests per second. Set to -1 to add all requests together.",
    )
    group.add_argument(
        "--incoming-requests-duration",
        type=float,
        default=10.0,
        help="Total amount of time to simulate that requests are "
        "arriving. Multiply this value with "
        "`--incoming-requests-per-sec` to get the approximate "
        "total number of requests. Set to -1 to add all requests together.",
    )
    group.add_argument(
        "--model-provider",
        choices=["mamba", "gpt"],
        default="gpt",
        help="Model provider",
    )
    group.add_argument(
        "--output-path",
        type=str,
        default=None,
        help="Path to save generations as JSON",
    )
    group.add_argument(
        "--output-every-n-results",
        type=int,
        default=1,
        help="To minimize the output file size of larger runs, only write the "
        "results of every `n` requests.",
    )
    group.add_argument(
        "--prompt-file",
        type=str,
        help='Jsonl file containing input prompts, where each item (i.e., line) '
        'contains the field \'text\' where the value is the prompt. All other '
        'fields within each item are ignored, and may be customized for each '
        'application.',
    )
    group.add_argument(
        "--prompt-file-num-truncate",
        type=int,
        help='Number of samples to use from the loaded prompt file (see '
        '`--prompt-file` above). The first `--prompt-file-num-truncate` samples '
        'will be used, in order.',
    )
    group.add_argument(
        "--inference-coordinator-port",
        type=int,
        help="This port will be used to set up the inference coordinator on node-0",
        default=12346,
    )
    group.add_argument(
        "--use-flashinfer-fused-rope",
        action='store_true',
        default=False,
        help='Use flashinfer fused rope implementation.',
    )

    return parser


def get_curr_time() -> float:
    """Get synchronized time across ranks.

    Reads `time.time_ns()` locally and, when `torch.distributed` is
    initialized, overwrites it with rank 0's reading via `broadcast`.

    Returns:
        float: Wall-clock time in seconds.
    """
    # Keep the tensor on the current CUDA device when one exists (as the
    # original always did, presumably for an NCCL process group); fall back to
    # CPU so the helper still works on CPU-only hosts instead of crashing.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    curr_time = torch.tensor([time.time_ns()], dtype=torch.long, device=device)
    if torch.distributed.is_initialized():
        torch.distributed.broadcast(curr_time, src=0)
    # int / int true division: integer nanoseconds -> float seconds.
    return curr_time.item() / 10**9


class Request:
    """Mutable record for one request as it moves through the engine.

    Construction tokenizes the prompt immediately; the timing fields, output
    fields and `state` start empty / "not-started" and are filled in by
    whoever drives the inference engine.

    Args:
        prompt_text (str): Prompt text.
        time_offset (float): Artificial time offset for simulating incoming
            requests. This value is later added to the `base_arrival_time` to
            simulate the requests arrival time.
        tokenizer (Any): Object exposing `tokenize(text)`, used once on the
            prompt.
    """

    def __init__(self, prompt_text: str, time_offset: float, tokenizer: Any):
        # Input side: raw prompt plus its tokenization.
        self.prompt_text = prompt_text
        self.prompt_tokens = tokenizer.tokenize(prompt_text)

        # Output side: populated as generation proceeds.
        self.output_text = None
        self.output_tokens = []

        # Timing: the simulated offset is known up front; the rest are set later.
        self.time_offset = time_offset
        self.time_arrival = self.time_start = self.time_end = None

        self.state = "not-started"

    def __str__(self) -> str:
        fields = (
            f"state '{self.state}'",
            f"toffset {self.time_offset:.1e}",
            f"prompt len {len(self.prompt_tokens)}",
            f"output len {len(self.output_tokens)}",
            f"'{self.prompt_text}'",
        )
        return "; ".join(fields)


def get_time_offsets(
    seed: int | None,
    incoming_requests_per_step: int,
    incoming_requests_per_sec: float,
    num_requests: int,
) -> list[float]:
    """Get example time offsets."""

    # Time offsets to add all requests at once.
    if incoming_requests_per_step is not None or incoming_requests_per_sec <= 0:
        return [-1] * num_requests

    # if num_requests is not None:
    incoming_requests_duration = num_requests / incoming_requests_per_sec
    incoming_requests_duration *= 2 # extra margin, to accomodate time sampling

    random.seed(seed)
    
    import simpy  # Guard against this import in test case

    # Generate random time offsets.
    def arrival(r):
        while True:
            yield env.timeout(random.expovariate(r))
            time_offsets.append(env.now)

    time_offsets = []
    env = simpy.Environment()
    env.process(arrival(incoming_requests_per_sec))
    env.run(incoming_requests_duration)
    
    # Ensure at least a single request.
    if len(time_offsets) == 0:
        time_offsets = [0.0]

    # Truncate to num_requests.
    assert len(time_offsets) >= num_requests
    time_offsets = time_offsets[:num_requests]

    return time_offsets


def get_cli_requests(args: Namespace, tokenizer: Any) -> list[Request]:
    """Build requests from the prompts given on the command line (`--prompts`).

    Args:
        args (Namespace): Parsed arguments. Reads `prompts`, `seed`,
            `incoming_requests_per_step` and `incoming_requests_per_sec`.
        tokenizer (Any): Tokenizer passed through to each `Request`.

    Returns:
        list[Request]: One request per prompt, paired with its time offset.
    """
    prompts = args.prompts
    offsets = get_time_offsets(
        args.seed,
        args.incoming_requests_per_step,
        args.incoming_requests_per_sec,
        len(prompts),
    )
    return [
        Request(prompt, offset, tokenizer)
        for prompt, offset in zip(prompts, offsets)
    ]


def get_synthetic_requests(args: Namespace, tokenizer: Any) -> list[Request]:
    """Get example requests.

    Creates `int(incoming_requests_per_sec * incoming_requests_duration)`
    requests. Each prompt is `"hi "` repeated a random number of times, drawn
    with `random.randint(*args.num_tokens_to_prompt)`.

    Args:
        args (Namespace): Parsed arguments. Reads `seed`,
            `incoming_requests_per_step`, `incoming_requests_per_sec`,
            `incoming_requests_duration` and `num_tokens_to_prompt`.
        tokenizer (Any): Tokenizer passed through to each `Request`.

    Returns:
        list[Request]: One synthetic request per sampled time offset.
    """
    num_requests = int(args.incoming_requests_per_sec * args.incoming_requests_duration)
    offsets = get_time_offsets(
        args.seed,
        args.incoming_requests_per_step,
        args.incoming_requests_per_sec,
        num_requests,
    )

    # Prompt lengths are drawn only after the offsets have been computed.
    # `get_time_offsets` may re-seed the global `random` module, so this order
    # keeps seeded runs reproducible.
    requests = []
    for offset in offsets:
        prompt_len = random.randint(*args.num_tokens_to_prompt)
        requests.append(Request("hi " * prompt_len, offset, tokenizer))
    return requests


def get_requests_from_file(args: Namespace, tokenizer: Any) -> list[Request]:
    """Get requests from a file.

    The file is JSONL. Every non-blank line is a JSON object whose `"text"`
    field is the prompt; all other fields are ignored. If
    `args.prompt_file_num_truncate` is set, reading stops once that many
    prompts have been collected.

    Args:
        args (Namespace): Parsed arguments. Reads `prompt_file`,
            `prompt_file_num_truncate`, `seed`, `incoming_requests_per_step`
            and `incoming_requests_per_sec`.
        tokenizer (Any): Tokenizer passed through to each `Request`.

    Returns:
        list[Request]: One request per loaded prompt, in file order.

    Raises:
        ValueError: If `args.prompt_file` is not set.
    """
    if not args.prompt_file:
        raise ValueError("Prompt file is required to read requests from a file.")

    # Count lines only to size the progress bar. The context manager makes
    # sure this handle is closed promptly.
    with open(args.prompt_file, encoding="utf-8") as f:
        n_lines = sum(1 for _ in f)

    # Load prompts by streaming the file rather than materializing readlines().
    prompts = []
    with open(args.prompt_file, encoding="utf-8") as f:
        for line in tqdm(f, "read prompt file", total=n_lines):
            # Tolerate blank lines (e.g. trailing newlines); json.loads("")
            # would raise.
            if not line.strip():
                continue
            prompts.append(json.loads(line)["text"])
            if len(prompts) == args.prompt_file_num_truncate:
                break

    # Get time offsets.
    time_offsets: list[float] = get_time_offsets(
        args.seed,
        args.incoming_requests_per_step,
        args.incoming_requests_per_sec,
        len(prompts),
    )

    # Init requests.
    requests = [
        Request(p, t, tokenizer)
        for p, t in tqdm(zip(prompts, time_offsets), "init requests", total=len(prompts))
    ]

    return requests


def build_requests(args: Namespace, tokenizer: Any) -> list[Request]:
    """Build the request list from whichever prompt source `args` selects.

    Priority: command-line `--prompts`, then `--prompt-file`, otherwise
    synthetic prompts.

    Args:
        args (Namespace): Parsed arguments; `prompts` and `prompt_file` pick
            the source.
        tokenizer (Any): Tokenizer passed through to the chosen builder.

    Returns:
        list[Request]: Requests from the selected source.

    Raises:
        ValueError: If both `--prompts` and `--prompt-file` are given.
    """
    has_cli_prompts = bool(args.prompts)
    if has_cli_prompts and args.prompt_file:
        raise ValueError("Cannot use both --prompts and --prompt-file")
    if has_cli_prompts:
        return get_cli_requests(args, tokenizer)
    if args.prompt_file:
        return get_requests_from_file(args, tokenizer)
    return get_synthetic_requests(args, tokenizer)


def get_model_size_str(model):
    """Return a short human-readable parameter count for `model`.

    The value is the total of `p.numel()` over `model.parameters()`, floored
    to the largest matching unit: "t" (1e12), "b" (1e9), "m" (1e6) or
    "k" (1e3). Smaller counts are printed as a plain integer, e.g.
    1_500_000_000 -> "1b" and 999 -> "999".

    Args:
        model: Any object exposing `parameters()` whose items have `numel()`.

    Returns:
        str: The formatted size.
    """
    n = sum(p.numel() for p in model.parameters())
    for exp, suffix in ((12, "t"), (9, "b"), (6, "m"), (3, "k")):
        unit = 10**exp
        # `>=` so exact powers (e.g. 1000) get the larger unit ("1k", not "1000").
        if n >= unit:
            return "%d%s" % (n // unit, suffix)
    # Counts below 1000, including 0 and 1, which previously raised.
    return "%d" % n


def build_dynamic_engine_setup_prefix(
    args: Namespace,
    model: MegatronModule,
    context: DynamicInferenceContext,
    requests: list[DynamicInferenceRequest],
):
    """
    Returns a compact, pipe-separated summary of the dynamic-batching setup.

    Six fields are joined with `" | "`, in this order:
    1. the model's parameter count (via `get_model_size_str`);
    2. the literal `dynamic`;
    3. the CUDA graph config;
    4. the request config;
    5. the buffer limits;
    6. the guaranteed-request config.

    Example output (values are illustrative):

    `1b | dynamic | graphs 1:256 | requests: synth(64, 1024), n 1000, g 30, dur 1.0e+01 r/sec 1.0e+02 | bf 20, 1.0 [r 512, t 8192] | gtd 0.50 [r 64]` # pylint: disable=line-too-long

    Args:
        args (Namespace): Command-line arguments for this run. Reads
            `cuda_graph_impl`, `prompts`, `prompt_file`, `num_tokens_to_prompt`,
            `num_tokens_to_generate`, `incoming_requests_per_step`,
            `incoming_requests_per_sec`, `incoming_requests_duration`, and the
            `inference_dynamic_batching_buffer_*` options.
        model (MegatronModule): Model whose parameter count is reported.
        context (DynamicInferenceContext): Supplies `max_requests`,
            `max_tokens` and `gtd_request_count`. When
            `args.cuda_graph_impl == "local"` it also supplies
            `cuda_graph_token_counts`; only the first and last entries of that
            are shown.
        requests (List[DynamicInferenceRequest]): Inference requests; only
            their count is reported.

    Returns:
        A configuration string for logging.
    """
    # CUDA graph config
    if args.cuda_graph_impl == "local":
        cg_str = (
            f"graphs {context.cuda_graph_token_counts[0]}:"
            f"{context.cuda_graph_token_counts[-1]}"
        )
    else:
        cg_str = "--"

    # Prompt description
    prompt_src_str = (
        "cli" if args.prompts else
        "file" if args.prompt_file else
        f"synth({', '.join(map(str, args.num_tokens_to_prompt))})"
    )
    request_str = (
        f"requests: {prompt_src_str}, "
        f"n {len(requests):d}, g {args.num_tokens_to_generate:d}, "
    )
    # Arrival pattern: time-based (duration & rate) unless a fixed
    # per-step request count was requested.
    if args.incoming_requests_per_step is None:
        request_str += (
            f"dur {args.incoming_requests_duration:.1e} "
            f"r/sec {args.incoming_requests_per_sec:.1e}"
        )
    else:
        request_str += f"r/step {args.incoming_requests_per_step}"

    # Buffer limits config
    flw = args.inference_dynamic_batching_buffer_overflow_factor
    flw_str = "no overflow" if flw is None else f"{flw:.1f}"
    buffer_limits_str = (
        f"bf {args.inference_dynamic_batching_buffer_size_gb:.0f}, {flw_str} "
        f"[r {context.max_requests}, t {context.max_tokens}]"
    )

    # Guaranteed request config
    guaranteed_fraction_str = (
        f"gtd {args.inference_dynamic_batching_buffer_guaranteed_fraction:.2f} "
        f"[r {context.gtd_request_count}]"
    )

    parts = [
        get_model_size_str(model),
        "dynamic",
        cg_str,
        request_str,
        buffer_limits_str,
        guaranteed_fraction_str,
    ]

    return " | ".join(parts)