select_config.py 2.44 KB
Newer Older
1
2
3
4
5
6
7
8
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import json
import logging
import queue

from benchmarks.profiler.webui.utils import (
9
10
    add_profiling_error,
    clear_profiling_errors,
11
12
    create_gradio_interface,
    create_selection_handler,
13
    generate_config_data,
14
15
16
    wait_for_selection,
)

17
18
19
# Public API of this module. The two error helpers are imported from
# benchmarks.profiler.webui.utils and re-exported so profiler modules can
# reach them through this module.
__all__ = [
    "pick_config_with_webui",
    "add_profiling_error",
    "clear_profiling_errors",
]

20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
# Module-level logging setup: INFO and above go to a dedicated stream handler
# using a timestamped "time - name - level - message" layout.
formatter = logging.Formatter(
    fmt="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)

console_handler = logging.StreamHandler()
console_handler.setFormatter(formatter)
console_handler.setLevel(logging.INFO)

logger = logging.getLogger(__name__)
logger.addHandler(console_handler)
logger.setLevel(logging.INFO)


def pick_config_with_webui(prefill_data, decode_data, args):
    """
    Launch WebUI for user to pick configurations.

    Builds the configuration data from the profiling results (also written to
    disk via ``write_to_disk=True``), serializes it to JSON for the frontend,
    wires up the selection handler and Gradio interface, and then returns
    whatever ``wait_for_selection`` yields for the launched interface.

    Args:
        prefill_data: PrefillProfileData instance
        decode_data: DecodeProfileData instance
        args: Arguments containing SLA targets and output_dir. ``webui_port``
            is read from it here to launch the interface.

    Returns:
        tuple[int, int]: (selected_prefill_idx, selected_decode_idx)
    """
    # Note: Don't clear profiling errors here - they should be accumulated
    # during the profiling run and displayed in the WebUI.
    # clear_profiling_errors() should be called at the start of a new profiling run.

    # Generate JSON data with GPU hours (frontend handles cost conversion)
    data_dict = generate_config_data(
        prefill_data,
        decode_data,
        args,
        write_to_disk=True,
    )
    json_data_str = json.dumps(data_dict)

    # Lazy %-style arguments: the message text is the same as before, but it is
    # only formatted if the record is actually emitted.
    logger.info("Launching WebUI on port %s...", args.webui_port)

    # Queue to communicate selection from UI to main thread
    selection_queue: queue.Queue[tuple[int | None, int | None]] = queue.Queue()

    # Track individual selections; mutable dicts so the handler can update them
    prefill_selection = {"idx": None}
    decode_selection = {"idx": None}

    # Create selection handler and Gradio interface. The data is passed inside
    # a dict wrapper, as create_selection_handler expects a reference holder.
    data_dict_ref = {"data": data_dict}
    handle_selection = create_selection_handler(
        data_dict_ref, selection_queue, prefill_selection, decode_selection
    )

    # Note: GPU hours -> Cost conversion is handled by frontend JavaScript (gpu_cost_toggle.js)
    demo = create_gradio_interface(
        json_data_str,
        handle_selection,
    )

    return wait_for_selection(demo, selection_queue, args.webui_port)