# profile_endpoint.py
1
2
3
4
5
6
7
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import argparse
import logging
import os

8
from utils.profile_decode import profile_decode
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
from utils.profile_prefill import profile_prefill

# Console logging for this script: records at INFO and above are written
# through a StreamHandler with a second-resolution timestamp prefix.
formatter = logging.Formatter(
    fmt="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)

console_handler = logging.StreamHandler()
console_handler.setFormatter(formatter)
console_handler.setLevel(logging.INFO)

logger = logging.getLogger(__name__)
logger.addHandler(console_handler)
logger.setLevel(logging.INFO)

def _build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the endpoint profiling script."""
    parser = argparse.ArgumentParser(
        description="profile a given endpoint's performance for prefill or decode"
    )
    parser.add_argument(
        "--mode",
        type=str,
        required=True,
        choices=["prefill", "decode"],
        help="mode to profile",
    )
    parser.add_argument(
        "--model_name",
        type=str,
        required=True,
        help="model name",
    )
    parser.add_argument(
        "--url",
        type=str,
        required=True,
        help="base url of the endpoint",
    )
    parser.add_argument(
        "--num_gpus",
        type=int,
        required=True,
        help="number of gpus",
    )
    parser.add_argument(
        "--max_kv_tokens",
        type=int,
        required=False,
        default=0,
        help="max kv tokens of the endpoint (only used for decode)",
    )
    parser.add_argument(
        "--work_dir",
        type=str,
        default="endpoint_profiling_results/",
        help="work directory to save the results",
    )
    parser.add_argument(
        "--max_context_length",
        type=int,
        default=16384,
        help="max context length of the endpoint",
    )
    parser.add_argument(
        "--interpolation_granularity",
        type=int,
        default=8,
        help="interpolation granularity for the results",
    )
    return parser


def main(argv=None) -> None:
    """Parse the command line and run the requested profiling mode.

    Args:
        argv: Argument strings to parse. ``None`` lets argparse read
            ``sys.argv[1:]``, which is what the script entry point uses.

    Raises:
        SystemExit: If the arguments are invalid, including ``--mode decode``
            without a positive ``--max_kv_tokens``.
        ValueError: If ``--mode`` is a value this function has no handler for.
    """
    parser = _build_parser()
    args = parser.parse_args(argv)

    # Validate with parser.error rather than `assert`: assertions are removed
    # under `python -O`, which would let decode run with max_kv_tokens == 0.
    # The check runs before any filesystem side effect so that a rejected
    # invocation does not leave an empty work directory behind.
    if args.mode == "decode" and args.max_kv_tokens <= 0:
        parser.error("max_kv_tokens must be provided for decode")

    os.makedirs(args.work_dir, exist_ok=True)
    if args.mode == "prefill":
        profile_prefill(
            args.work_dir,
            args.model_name,
            args.url,
            args.num_gpus,
            args.max_context_length,
            args.interpolation_granularity,
        )
    elif args.mode == "decode":
        profile_decode(
            args.work_dir,
            args.model_name,
            args.url,
            args.num_gpus,
            args.max_kv_tokens,
            args.max_context_length,
            args.interpolation_granularity,
        )
    else:
        # Not reachable while `choices` above lists only the handled modes;
        # kept so a newly added choice without a handler fails loudly.
        raise ValueError(f"Invalid mode: {args.mode}")


if __name__ == "__main__":
    main()