benchmark_throughput.py 24.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
"""Benchmark offline inference throughput."""
3
import argparse
4
import dataclasses
5
import json
6
import os
7
8
import random
import time
9
import warnings
10
from typing import Any, Optional, Union
11

12
import torch
13
import uvloop
14
from benchmark_dataset import (BurstGPTDataset, ConversationDataset,
15
16
17
                               InstructCoderDataset, RandomDataset,
                               SampleRequest, ShareGPTDataset, SonnetDataset,
                               VisionArenaDataset)
18
from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
19
from tqdm import tqdm
20
21
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          PreTrainedTokenizerBase)
22

23
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
24
25
from vllm.entrypoints.openai.api_server import (
    build_async_engine_client_from_engine_args)
26
from vllm.inputs import TextPrompt, TokensPrompt
27
from vllm.lora.request import LoRARequest
28
from vllm.outputs import RequestOutput
29
from vllm.sampling_params import BeamSearchParams
30
from vllm.utils import FlexibleArgumentParser, merge_async_iterators
31

32

Woosuk Kwon's avatar
Woosuk Kwon committed
33
def run_vllm(
    requests: list[SampleRequest],
    n: int,
    engine_args: EngineArgs,
    disable_detokenize: bool = False,
) -> tuple[float, Optional[list[RequestOutput]]]:
    """Benchmark offline generation through the synchronous ``LLM`` API.

    Args:
        requests: Sampled benchmark requests. A request's ``prompt`` is
            either text or a dict carrying ``prompt_token_ids``.
        n: Sequences generated per prompt; used as the beam width on the
            beam-search path.
        engine_args: Engine configuration used to construct the ``LLM``.
        disable_detokenize: When True, skip detokenization so it is not
            included in the measured time.

    Returns:
        ``(elapsed_seconds, outputs)``. ``outputs`` is what
        ``LLM.generate`` returned on the sampling path, or ``None`` on the
        beam-search path.

    Raises:
        AssertionError: If some request's prompt plus expected output does
            not fit in the model's ``max_model_len``, or (beam-search path
            only) if LoRA is enabled or the expected output lengths differ.
    """
    from vllm import LLM, SamplingParams
    llm = LLM(**dataclasses.asdict(engine_args))
    assert all(
        llm.llm_engine.model_config.max_model_len >= (
            request.prompt_len + request.expected_output_len)
        for request in requests), (
            "Please ensure that max_model_len is greater than the sum of"
            " prompt_len and expected_output_len for all requests.")

    # Add the requests to the engine.
    prompts: list[Union[TextPrompt, TokensPrompt]] = []
    sampling_params: list[SamplingParams] = []
    for request in requests:
        # Pre-tokenized prompts are dicts. Check the type first: on a str
        # prompt, ``in`` is a substring search, so a text prompt that merely
        # mentions "prompt_token_ids" would be mis-routed and then indexed
        # with a string key.
        if (isinstance(request.prompt, dict)
                and "prompt_token_ids" in request.prompt):
            prompts.append(
                TokensPrompt(
                    prompt_token_ids=request.prompt["prompt_token_ids"],
                    multi_modal_data=request.multi_modal_data))
        else:
            prompts.append(
                TextPrompt(prompt=request.prompt,
                           multi_modal_data=request.multi_modal_data))
        sampling_params.append(
            SamplingParams(
                n=n,
                temperature=1.0,
                top_p=1.0,
                ignore_eos=True,
                max_tokens=request.expected_output_len,
                detokenize=not disable_detokenize,
            ))
    lora_requests: Optional[list[LoRARequest]] = None
    if engine_args.enable_lora:
        lora_requests = [request.lora_request for request in requests]

    # The beam-search path below is disabled; flip this flag to exercise it.
    use_beam_search = False

    outputs = None
    if not use_beam_search:
        start = time.perf_counter()
        outputs = llm.generate(prompts,
                               sampling_params,
                               lora_request=lora_requests,
                               use_tqdm=True)
        end = time.perf_counter()
    else:
        assert lora_requests is None, "BeamSearch API does not support LoRA"
        prompts = [request.prompt for request in requests]
        # output_len should be the same for all requests. Read it by
        # attribute, like every other SampleRequest access in this file.
        output_len = requests[0].expected_output_len
        for request in requests:
            assert request.expected_output_len == output_len
        start = time.perf_counter()
        llm.beam_search(
            prompts,
            BeamSearchParams(
                beam_width=n,
                max_tokens=output_len,
                ignore_eos=True,
            ))
        end = time.perf_counter()
    return end - start, outputs


def run_vllm_chat(
        requests: list[SampleRequest],
        n: int,
        engine_args: EngineArgs,
        disable_detokenize: bool = False) -> tuple[float, list[RequestOutput]]:
    """Benchmark generation through ``LLM.chat``.

    Prefer this path ONLY for multimodal models: ``LLM.chat`` applies the
    chat formatting and handles multimodal inputs. For text-only models use
    ``run_vllm()`` instead.

    Args:
        requests: Sampled requests whose ``prompt`` is passed to ``chat``.
        n: Sequences generated per prompt.
        engine_args: Engine configuration used to construct the ``LLM``.
        disable_detokenize: When True, skip detokenization during timing.

    Returns:
        ``(elapsed_seconds, outputs)``, where ``outputs`` is the value
        returned by ``LLM.chat``.
    """
    from vllm import LLM, SamplingParams
    llm = LLM(**dataclasses.asdict(engine_args))

    fits_context = (
        llm.llm_engine.model_config.max_model_len >= (
            req.prompt_len + req.expected_output_len)
        for req in requests)
    assert all(fits_context), (
        "Please ensure that max_model_len is greater than the sum of "
        "prompt_len and expected_output_len for all requests.")

    chat_prompts = [req.prompt for req in requests]
    per_request_params: list[SamplingParams] = [
        SamplingParams(
            n=n,
            temperature=1.0,
            top_p=1.0,
            ignore_eos=True,
            max_tokens=req.expected_output_len,
            detokenize=not disable_detokenize,
        ) for req in requests
    ]

    started = time.perf_counter()
    outputs = llm.chat(chat_prompts, per_request_params, use_tqdm=True)
    finished = time.perf_counter()
    return finished - started, outputs
136
137


138
async def run_vllm_async(
    requests: list[SampleRequest],
    n: int,
    engine_args: AsyncEngineArgs,
    disable_frontend_multiprocessing: bool = False,
    disable_detokenize: bool = False,
) -> float:
    """Benchmark offline generation through the async engine client.

    All requests are submitted to the client built by
    ``build_async_engine_client_from_engine_args``. Every output stream is
    then drained, and the wall-clock time is measured.

    Args:
        requests: Sampled benchmark requests. A request's ``prompt`` is
            either text or a dict carrying ``prompt_token_ids``.
        n: Sequences generated per prompt.
        engine_args: Async engine configuration.
        disable_frontend_multiprocessing: Forwarded to the client builder.
        disable_detokenize: When True, skip detokenization during timing.

    Returns:
        Elapsed wall-clock seconds from first submission to last output.

    Raises:
        AssertionError: If some request does not fit in ``max_model_len``.
    """
    from vllm import SamplingParams

    async with build_async_engine_client_from_engine_args(
            engine_args, disable_frontend_multiprocessing) as llm:
        assert all(
            llm.model_config.max_model_len >= (request.prompt_len +
                                               request.expected_output_len)
            for request in requests), (
                "Please ensure that max_model_len is greater than the sum of"
                " prompt_len and expected_output_len for all requests.")

        # Add the requests to the engine.
        prompts: list[Union[TextPrompt, TokensPrompt]] = []
        sampling_params: list[SamplingParams] = []
        lora_requests: list[Optional[LoRARequest]] = []
        for request in requests:
            # Pre-tokenized prompts are dicts. Check the type first so a
            # text prompt mentioning "prompt_token_ids" is not treated as
            # pre-tokenized by a substring match.
            if (isinstance(request.prompt, dict)
                    and "prompt_token_ids" in request.prompt):
                prompts.append(
                    TokensPrompt(
                        prompt_token_ids=request.prompt["prompt_token_ids"],
                        multi_modal_data=request.multi_modal_data))
            else:
                prompts.append(
                    TextPrompt(prompt=request.prompt,
                               multi_modal_data=request.multi_modal_data))
            sampling_params.append(
                SamplingParams(
                    n=n,
                    temperature=1.0,
                    top_p=1.0,
                    ignore_eos=True,
                    max_tokens=request.expected_output_len,
                    detokenize=not disable_detokenize,
                ))
            lora_requests.append(request.lora_request)

        generators = []
        start = time.perf_counter()
        for i, (prompt, sp, lr) in enumerate(
                zip(prompts, sampling_params, lora_requests)):
            generator = llm.generate(prompt,
                                     sp,
                                     lora_request=lr,
                                     request_id=f"test{i}")
            generators.append(generator)
        all_gens = merge_async_iterators(*generators)
        # Drain every stream; only the elapsed time is of interest here.
        async for _ in all_gens:
            pass
        end = time.perf_counter()
        return end - start


194
def run_hf(
    requests: list[SampleRequest],
    model: str,
    tokenizer: PreTrainedTokenizerBase,
    n: int,
    max_batch_size: int,
    trust_remote_code: bool,
    disable_detokenize: bool = False,
) -> float:
    """Benchmark sampling with a Hugging Face ``transformers`` model on CUDA.

    Requests are grouped greedily, in order. A batch stops growing when it
    reaches ``max_batch_size``, when the input runs out, or when adding the
    next request would push (longest prompt + longest expected output)
    above 2048 tokens. Each batch is then sampled with ``generate``.

    Args:
        requests: Sampled benchmark requests; ``prompt`` is used as text.
        model: Model name or path passed to ``from_pretrained``.
        tokenizer: Tokenizer used to encode (and optionally decode) batches.
        n: Number of sequences returned per prompt.
        max_batch_size: Upper bound on requests per batch.
        trust_remote_code: Forwarded to ``from_pretrained``.
        disable_detokenize: When True, skip ``batch_decode`` so decoding
            time is not measured.

    Returns:
        Elapsed wall-clock seconds for generating all batches.
    """
    llm = AutoModelForCausalLM.from_pretrained(
        model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code)
    if llm.config.model_type == "llama":
        # To enable padding in the HF backend.
        tokenizer.pad_token = tokenizer.eos_token
    llm = llm.cuda()

    # Token budget (longest prompt + longest output) for growing a batch.
    max_batch_tokens = 2048

    pbar = tqdm(total=len(requests))
    start = time.perf_counter()
    batch: list[str] = []
    max_prompt_len = 0
    max_output_len = 0
    for i, request in enumerate(requests):
        # SampleRequest fields are read by attribute, as everywhere else in
        # this file, rather than tuple-unpacked.
        batch.append(request.prompt)
        max_prompt_len = max(max_prompt_len, request.prompt_len)
        max_output_len = max(max_output_len, request.expected_output_len)
        if len(batch) < max_batch_size and i != len(requests) - 1:
            # Check if we can add more requests to the batch.
            next_request = requests[i + 1]
            if (max(max_prompt_len, next_request.prompt_len) +
                    max(max_output_len, next_request.expected_output_len)
                    <= max_batch_tokens):
                # We can add more requests to the batch.
                continue

        # Generate the sequences.
        input_ids = tokenizer(batch, return_tensors="pt",
                              padding=True).input_ids
        llm_outputs = llm.generate(
            input_ids=input_ids.cuda(),
            do_sample=True,
            num_return_sequences=n,
            temperature=1.0,
            top_p=1.0,
            use_cache=True,
            max_new_tokens=max_output_len,
        )
        if not disable_detokenize:
            # Include the decoding time.
            tokenizer.batch_decode(llm_outputs, skip_special_tokens=True)
        pbar.update(len(batch))

        # Clear the batch.
        batch = []
        max_prompt_len = 0
        max_output_len = 0
    end = time.perf_counter()
    # Close after taking the end timestamp so it is not part of the timing.
    pbar.close()
    return end - start


254
def run_mii(
    requests: list[SampleRequest],
    model: str,
    tensor_parallel_size: int,
    output_len: int,
) -> float:
    """Benchmark generation with DeepSpeed-MII.

    Starts an MII server for *model*, times a single ``generate`` call over
    every request prompt, then asks the server to terminate.

    Args:
        requests: Sampled benchmark requests; only ``prompt`` is used.
        model: Model name or path served by MII.
        tensor_parallel_size: Forwarded to ``serve`` as ``tensor_parallel``.
        output_len: Forwarded as ``max_new_tokens`` for every prompt.
            NOTE(review): ``main`` passes ``args.output_len``, which is
            ``None`` unless ``--output-len`` is given; confirm that MII
            accepts ``None`` here.

    Returns:
        Elapsed wall-clock seconds of the ``generate`` call.
    """
    from mii import client, serve

    # Build the inputs before starting the server, so nothing between
    # ``serve`` and the ``try`` can fail and skip the shutdown below.
    prompts = [request.prompt for request in requests]
    llm = serve(model, tensor_parallel=tensor_parallel_size)
    try:
        start = time.perf_counter()
        llm.generate(prompts, max_new_tokens=output_len)
        end = time.perf_counter()
    finally:
        # Terminate the server started above even if generation raises.
        # Use a distinct name instead of rebinding the imported ``client``.
        mii_client = client(model)
        mii_client.terminate_server()
    return end - start


272
def save_to_pytorch_benchmark_format(args: argparse.Namespace,
                                     results: dict[str, Any]) -> None:
    """Mirror *results* into a PyTorch-benchmark-format JSON file.

    ``convert_to_pytorch_benchmark_format`` builds the records. When it
    yields nothing, no file is written. Otherwise the records are written
    to ``<output_json without extension>.pytorch.json``.
    """
    metrics = {
        "requests_per_second": [results["requests_per_second"]],
        "tokens_per_second": [results["tokens_per_second"]],
    }
    extra_info = {
        key: results[key]
        for key in ("elapsed_time", "num_requests", "total_num_tokens")
    }
    pt_records = convert_to_pytorch_benchmark_format(args=args,
                                                     metrics=metrics,
                                                     extra_info=extra_info)
    if not pt_records:
        return
    # Don't use json suffix here as we don't want CI to pick it up
    stem = os.path.splitext(args.output_json)[0]
    write_to_json(f"{stem}.pytorch.json", pt_records)
288
289


290
291
292
293
294
295
296
297
298
299
300
301
302
303
def get_requests(args, tokenizer):
    """Sample benchmark requests from the dataset selected on the CLI.

    Picks a dataset class from ``args.dataset_name`` / ``args.dataset_path``.
    Builds the constructor and ``sample()`` keyword arguments from *args*,
    drops ``sample()`` arguments whose value is ``None``, and returns the
    sampled requests.

    Raises:
        ValueError: For an unknown ``--dataset-name``, or for an ``hf``
            dataset path that none of the supported HF dataset classes
            handles.
        AssertionError: For the sonnet dataset when the tokenizer exposes
            no chat template.
    """
    # Common parameters for all dataset types.
    common_kwargs = {
        "dataset_path": args.dataset_path,
        "random_seed": args.seed,
    }
    sample_kwargs = {
        "tokenizer": tokenizer,
        "lora_path": args.lora_path,
        "max_loras": args.max_loras,
        "num_requests": args.num_prompts,
        "input_len": args.input_len,
        "output_len": args.output_len,
    }

    if args.dataset_path is None or args.dataset_name == "random":
        sample_kwargs["range_ratio"] = args.random_range_ratio
        sample_kwargs["prefix_len"] = args.prefix_len
        dataset_cls = RandomDataset
    elif args.dataset_name == "sharegpt":
        dataset_cls = ShareGPTDataset
        if args.backend == "vllm-chat":
            sample_kwargs["enable_multimodal_chat"] = True
    elif args.dataset_name == "sonnet":
        # Some tokenizers lack ``default_chat_template``. Look both
        # attributes up defensively so a missing attribute does not surface
        # as an AttributeError instead of the message below.
        assert (getattr(tokenizer, "chat_template", None)
                or getattr(tokenizer, "default_chat_template", None)), (
                    "Tokenizer/model must have chat template for sonnet "
                    "dataset.")
        dataset_cls = SonnetDataset
        sample_kwargs["prefix_len"] = args.prefix_len
        sample_kwargs["return_prompt_formatted"] = True
    elif args.dataset_name == "burstgpt":
        dataset_cls = BurstGPTDataset
    elif args.dataset_name == "hf":
        if args.dataset_path in VisionArenaDataset.SUPPORTED_DATASET_PATHS:
            dataset_cls = VisionArenaDataset
            common_kwargs['dataset_subset'] = None
            common_kwargs['dataset_split'] = "train"
            sample_kwargs["enable_multimodal_chat"] = True
        elif args.dataset_path in InstructCoderDataset.SUPPORTED_DATASET_PATHS:
            dataset_cls = InstructCoderDataset
            common_kwargs['dataset_split'] = "train"
        elif args.dataset_path in ConversationDataset.SUPPORTED_DATASET_PATHS:
            dataset_cls = ConversationDataset
            common_kwargs['dataset_subset'] = args.hf_subset
            common_kwargs['dataset_split'] = args.hf_split
            sample_kwargs["enable_multimodal_chat"] = True
        else:
            # Without this branch ``dataset_cls`` would be unbound below.
            raise ValueError(
                f"{args.dataset_path} is not supported by hf dataset.")
    else:
        raise ValueError(f"Unknown dataset name: {args.dataset_name}")
    # Remove None values
    sample_kwargs = {k: v for k, v in sample_kwargs.items() if v is not None}
    return dataset_cls(**common_kwargs).sample(**sample_kwargs)


343
def main(args: argparse.Namespace):
    """Sample requests, run the selected backend, and report throughput.

    Prints requests/s, total tokens/s and output tokens/s. Token counts come
    from the engine's ``RequestOutput`` objects when the backend returns
    them, and from the sampled requests' expected lengths otherwise. When
    ``args.output_json`` is set, the results are also written there and
    passed to ``save_to_pytorch_benchmark_format``.
    """
    if args.seed is None:
        args.seed = 0
    print(args)
    random.seed(args.seed)

    # Sample the requests.
    tokenizer = AutoTokenizer.from_pretrained(
        args.tokenizer, trust_remote_code=args.trust_remote_code)
    requests = get_requests(args, tokenizer)
    is_multi_modal = any(req.multi_modal_data is not None for req in requests)

    request_outputs: Optional[list[RequestOutput]] = None
    backend = args.backend
    if backend == "vllm" and args.async_engine:
        elapsed_time = uvloop.run(
            run_vllm_async(requests, args.n,
                           AsyncEngineArgs.from_cli_args(args),
                           args.disable_frontend_multiprocessing,
                           args.disable_detokenize))
    elif backend == "vllm":
        elapsed_time, request_outputs = run_vllm(
            requests, args.n, EngineArgs.from_cli_args(args),
            args.disable_detokenize)
    elif backend == "hf":
        assert args.tensor_parallel_size == 1
        elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
                              args.hf_max_batch_size, args.trust_remote_code,
                              args.disable_detokenize)
    elif backend == "mii":
        elapsed_time = run_mii(requests, args.model, args.tensor_parallel_size,
                               args.output_len)
    elif backend == "vllm-chat":
        elapsed_time, request_outputs = run_vllm_chat(
            requests, args.n, EngineArgs.from_cli_args(args),
            args.disable_detokenize)
    else:
        raise ValueError(f"Unknown backend: {args.backend}")

    if request_outputs:
        # The vllm and vllm-chat backends hand back RequestOutput objects,
        # so the real token counts are taken from them.
        finished = [
            ro for ro in request_outputs if isinstance(ro, RequestOutput)
        ]
        total_prompt_tokens = sum(
            len(ro.prompt_token_ids) if ro.prompt_token_ids else 0
            for ro in finished)
        total_output_tokens = sum(
            len(o.token_ids) for ro in finished for o in ro.outputs if o)
        total_num_tokens = total_prompt_tokens + total_output_tokens
    else:
        total_num_tokens = sum(r.prompt_len + r.expected_output_len
                               for r in requests)
        total_output_tokens = sum(r.expected_output_len for r in requests)
        total_prompt_tokens = total_num_tokens - total_output_tokens

    if is_multi_modal and args.backend != "vllm-chat":
        print("\033[91mWARNING\033[0m: Multi-modal request with "
              f"{args.backend} backend detected. The "
              "following metrics are not accurate because image tokens are not"
              " counted. See vllm-project/vllm/issues/9778 for details.")
        # TODO(vllm-project/vllm/issues/9778): Count multi-modal token length.
        # The vllm-chat backend already counts image tokens.

    num_requests = len(requests)
    print(f"Throughput: {num_requests / elapsed_time:.2f} requests/s, "
          f"{total_num_tokens / elapsed_time:.2f} total tokens/s, "
          f"{total_output_tokens / elapsed_time:.2f} output tokens/s")
    print(f"Total num prompt tokens:  {total_prompt_tokens}")
    print(f"Total num output tokens:  {total_output_tokens}")

    # Output JSON results if specified
    if not args.output_json:
        return
    results = {
        "elapsed_time": elapsed_time,
        "num_requests": num_requests,
        "total_num_tokens": total_num_tokens,
        "requests_per_second": num_requests / elapsed_time,
        "tokens_per_second": total_num_tokens / elapsed_time,
    }
    with open(args.output_json, "w") as f:
        json.dump(results, f, indent=4)
    save_to_pytorch_benchmark_format(args, results)
429

430

431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
def validate_args(args):
    """
    Validate command-line arguments.

    Some fields of *args* are normalized in place:
    - ``dataset_path`` is taken from the deprecated ``--dataset`` when given.
    - ``tokenizer`` falls back to ``model``.
    - ``dataset_name`` becomes ``'random'`` when no dataset path is set.

    Options ignored by the chosen dataset trigger ``warnings.warn``.
    Incompatible backend/option combinations are rejected.

    Raises:
        ValueError: If the backend is unsupported, a required option is
            missing, or options conflict with the backend or dataset.
    """

    # === Deprecation and Defaulting ===
    if args.dataset is not None:
        warnings.warn(
            "The '--dataset' argument will be deprecated in the next release. "
            "Please use '--dataset-name' and '--dataset-path' instead.",
            stacklevel=2)
        args.dataset_path = args.dataset

    if not getattr(args, "tokenizer", None):
        args.tokenizer = args.model

    # === Backend Validation ===
    valid_backends = {"vllm", "hf", "mii", "vllm-chat"}
    if args.backend not in valid_backends:
        raise ValueError(f"Unsupported backend: {args.backend}")

    # === Dataset Configuration ===
    if not args.dataset and not args.dataset_path:
        print(
            "When dataset path is not set, it will default to random dataset")
        args.dataset_name = 'random'
        if args.input_len is None:
            raise ValueError("input_len must be provided for a random dataset")

    # === Dataset Name Specific Checks ===
    # --hf-subset and --hf-split: only used when dataset_name is 'hf'.
    # Messages use implicit string concatenation; a backslash-newline inside
    # a literal would embed the next line's indentation in the text.
    if args.dataset_name != "hf" and (
            getattr(args, "hf_subset", None) is not None
            or getattr(args, "hf_split", None) is not None):
        warnings.warn(
            "--hf-subset and --hf-split will be ignored "
            "since --dataset-name is not 'hf'.",
            stacklevel=2)
    elif args.dataset_name == "hf":
        # Each supported HF dataset works with exactly one backend. Raise
        # rather than assert so the check survives ``python -O`` and matches
        # the ValueErrors used elsewhere in this function.
        if args.dataset_path in VisionArenaDataset.SUPPORTED_DATASET_PATHS:
            if args.backend != "vllm-chat":
                raise ValueError("VisionArenaDataset needs to use vllm-chat "
                                 "as the backend.")
        elif args.dataset_path in InstructCoderDataset.SUPPORTED_DATASET_PATHS:
            if args.backend != "vllm":
                raise ValueError("InstructCoder dataset needs to use vllm "
                                 "as the backend.")
        elif args.dataset_path in ConversationDataset.SUPPORTED_DATASET_PATHS:
            if args.backend != "vllm-chat":
                raise ValueError("ConversationDataset needs to use vllm-chat "
                                 "as the backend.")
        else:
            raise ValueError(
                f"{args.dataset_path} is not supported by hf dataset.")

    # --random-range-ratio: only used when dataset_name is 'random'
    if args.dataset_name != 'random' and args.random_range_ratio is not None:
        warnings.warn(
            "--random-range-ratio will be ignored since "
            "--dataset-name is not 'random'.",
            stacklevel=2)

    # --prefix-len: only used when dataset_name is 'random', 'sonnet', or not
    # set.
    if (args.dataset_name not in {"random", "sonnet", None}
            and args.prefix_len is not None):
        warnings.warn(
            "--prefix-len will be ignored since --dataset-name is not "
            "'random', 'sonnet', or not set.",
            stacklevel=2)

    # === LoRA Settings ===
    if getattr(args, "enable_lora", False) and args.backend != "vllm":
        raise ValueError(
            "LoRA benchmarking is only supported for vLLM backend")
    if getattr(args, "enable_lora", False) and args.lora_path is None:
        raise ValueError("LoRA path must be provided when enable_lora is True")

    # === Backend-specific Validations ===
    if args.backend == "hf" and args.hf_max_batch_size is None:
        raise ValueError("HF max batch size is required for HF backend")
    if args.backend != "hf" and args.hf_max_batch_size is not None:
        raise ValueError("HF max batch size is only for HF backend.")

    if args.backend in {"hf", "mii"} and getattr(args, "quantization",
                                                 None) is not None:
        raise ValueError("Quantization is only for vLLM backend.")

    if args.backend == "mii" and args.dtype != "auto":
        raise ValueError("dtype must be auto for MII backend.")
    if args.backend == "mii" and args.n != 1:
        raise ValueError("n must be 1 for MII backend.")
    if args.backend == "mii" and args.tokenizer != args.model:
        raise ValueError(
            "Tokenizer must be the same as the model for MII backend.")


520
if __name__ == "__main__":
    parser = FlexibleArgumentParser(description="Benchmark the throughput.")
    parser.add_argument("--backend",
                        type=str,
                        choices=["vllm", "hf", "mii", "vllm-chat"],
                        default="vllm")
    parser.add_argument(
        "--dataset-name",
        type=str,
        choices=["sharegpt", "random", "sonnet", "burstgpt", "hf"],
        help="Name of the dataset to benchmark on.",
        default="sharegpt")
    # Help strings use implicit concatenation: a backslash-newline inside a
    # literal would embed the continuation line's indentation in the text.
    parser.add_argument(
        "--dataset",
        type=str,
        default=None,
        help="Path to the ShareGPT dataset, will be deprecated in "
        "the next release. The dataset is expected to "
        "be a json in form of list[dict[..., conversations: "
        "list[dict[..., value: <prompt_or_response>]]]]")
    parser.add_argument("--dataset-path",
                        type=str,
                        default=None,
                        help="Path to the dataset")
    parser.add_argument("--input-len",
                        type=int,
                        default=None,
                        help="Input prompt length for each request")
    parser.add_argument("--output-len",
                        type=int,
                        default=None,
                        help="Output length for each request. Overrides the "
                        "output length from the dataset.")
    parser.add_argument("--n",
                        type=int,
                        default=1,
                        help="Number of generated sequences per prompt.")
    parser.add_argument("--num-prompts",
                        type=int,
                        default=1000,
                        help="Number of prompts to process.")
    parser.add_argument("--hf-max-batch-size",
                        type=int,
                        default=None,
                        help="Maximum batch size for HF backend.")
    parser.add_argument(
        '--output-json',
        type=str,
        default=None,
        help='Path to save the throughput results in JSON format.')
    parser.add_argument("--async-engine",
                        action='store_true',
                        default=False,
                        help="Use vLLM async engine rather than LLM class.")
    parser.add_argument("--disable-frontend-multiprocessing",
                        action='store_true',
                        default=False,
                        help="Disable decoupled async engine frontend.")
    parser.add_argument(
        "--disable-detokenize",
        action="store_true",
        help=("Do not detokenize the response (i.e. do not include "
              "detokenization time in the measurement)"))
    # LoRA
    parser.add_argument(
        "--lora-path",
        type=str,
        default=None,
        help="Path to the lora adapters to use. This can be an absolute path, "
        "a relative path, or a Hugging Face model identifier.")
    parser.add_argument("--prefix-len",
                        type=int,
                        default=None,
                        help="Number of prefix tokens per request. "
                        "This is for the RandomDataset and SonnetDataset")
    # random dataset
    parser.add_argument(
        "--random-range-ratio",
        type=float,
        default=None,
        help="Range of sampled ratio of input/output length, "
        "used only for RandomDataSet.",
    )
    # hf dataset
    parser.add_argument("--hf-subset",
                        type=str,
                        default=None,
                        help="Subset of the HF dataset.")
    parser.add_argument("--hf-split",
                        type=str,
                        default=None,
                        help="Split of the HF dataset.")

    # Engine options (model, tokenizer, LoRA, parallelism, ...) come from
    # the vLLM engine argument definitions.
    parser = AsyncEngineArgs.add_cli_args(parser)
    args = parser.parse_args()
    if args.tokenizer is None:
        args.tokenizer = args.model
    validate_args(args)
    main(args)