# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import json
import logging
import queue

from benchmarks.profiler.utils.defaults import DEFAULT_GPU_COST_PER_HOUR
from benchmarks.profiler.webui.utils import (
    create_gpu_cost_update_handler,
    create_gradio_interface,
    create_selection_handler,
    generate_config_data,
    wait_for_selection,
)

# Module logger setup: records at INFO and above are sent to a dedicated
# stream handler using a "time - name - level - message" layout.
formatter = logging.Formatter(
    fmt="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)

console_handler = logging.StreamHandler()
console_handler.setFormatter(formatter)
console_handler.setLevel(logging.INFO)

logger = logging.getLogger(__name__)
logger.addHandler(console_handler)
logger.setLevel(logging.INFO)


def pick_config_with_webui(prefill_data, decode_data, args):
    """
    Launch WebUI for user to pick configurations.

    Builds the configuration data once with the default GPU cost, wires up
    the selection and GPU-cost-update callbacks, creates the Gradio
    interface, and then hands control to ``wait_for_selection``.

    Args:
        prefill_data: PrefillProfileData instance
        decode_data: DecodeProfileData instance
        args: Arguments containing SLA targets and output_dir; ``webui_port``
            is also read here to choose the port the WebUI is served on.

    Returns:
        tuple[int, int]: (selected_prefill_idx, selected_decode_idx), as
        returned by ``wait_for_selection``.
        NOTE(review): the selection queue is typed to carry ``int | None``
        entries, so confirm whether ``None`` can reach the caller.
    """
    # Generate JSON data (also writes default JSON file for convenience)
    data_dict = generate_config_data(
        prefill_data,
        decode_data,
        args,
        gpu_cost_per_hour=DEFAULT_GPU_COST_PER_HOUR,
        write_to_disk=True,
    )
    json_data_str = json.dumps(data_dict)

    # Lazy %-style arguments: the message is only formatted if it is emitted.
    logger.info("Launching WebUI on port %s...", args.webui_port)

    # Queue to communicate selection from UI to main thread
    selection_queue: queue.Queue[tuple[int | None, int | None]] = queue.Queue()

    # Track individual selections; the dicts are mutable holders shared with
    # the selection handler.
    prefill_selection = {"idx": None}
    decode_selection = {"idx": None}

    # Shared mutable reference to the current data; presumably lets the
    # GPU-cost handler swap in regenerated data that the selection handler
    # then sees — TODO confirm against the helper implementations.
    data_dict_ref = {"data": data_dict}

    # Create selection handler and Gradio interface
    handle_selection = create_selection_handler(
        data_dict_ref, selection_queue, prefill_selection, decode_selection
    )
    update_gpu_cost_per_hour = create_gpu_cost_update_handler(
        prefill_data=prefill_data,
        decode_data=decode_data,
        args=args,
        data_dict_ref=data_dict_ref,
        default_gpu_cost_per_hour=DEFAULT_GPU_COST_PER_HOUR,
    )

    demo = create_gradio_interface(
        json_data_str,
        handle_selection,
        update_json_data_fn=update_gpu_cost_per_hour,
        default_gpu_cost_per_hour=DEFAULT_GPU_COST_PER_HOUR,
    )

    return wait_for_selection(demo, selection_queue, args.webui_port)