utils.py 25.4 KB
Newer Older
1
2
3
4
5
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import json
import logging
6
import os
7
8
import queue
import threading
9
import time
10
from enum import Enum
11
from pathlib import Path
12
13
14

import gradio as gr
import numpy as np
15
import yaml
16
17
18
19
20
21
22
from aiconfigurator.webapp.components.profiling import (
    create_performance_results_section,
    create_profiling_ui_components,
    inject_profiling_assets,
    load_profiling_javascript,
)

23
24
25
26
27
from benchmarks.profiler.utils.dgd_generation import (
    generate_decode_service_config_preview,
    generate_prefill_decode_services_config_preview,
    generate_prefill_service_config_preview,
)
28
29
30
31
32
from benchmarks.profiler.utils.pareto import compute_pareto

logger = logging.getLogger(__name__)


33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
# Signalled by the WebUI selection handler once a complete (prefill, decode) choice
# has been queued; wait_for_selection() clears it before waiting and checks it
# before shutting the UI down.
_selection_complete: threading.Event = threading.Event()

# Messages recorded via add_profiling_error(); shown as a warning banner in the WebUI.
_profiling_errors: list[str] = list()


def add_profiling_error(error_message: str) -> None:
    """Record a profiling error so the WebUI can display it, and log it.

    Args:
        error_message: Human-readable description of the problem to show.
    """
    logger.error("Profiling error: %s", error_message)
    _profiling_errors.append(error_message)


def get_profiling_errors() -> list[str]:
    """Return a snapshot of the recorded profiling errors.

    The returned list is a shallow copy, so callers may mutate it freely
    without affecting the module-level error store.

    Returns:
        List of error messages, oldest first.
    """
    snapshot = list(_profiling_errors)
    return snapshot


def clear_profiling_errors() -> None:
    """Empty the recorded profiling errors in place (the list object is kept)."""
    del _profiling_errors[:]


64
65
66
67
68
def dump_yaml_with_header(header_lines: list[str], obj: dict) -> str:
    """Render ``obj`` as YAML preceded by a comment header (WebUI config previews).

    The header lines are emitted verbatim (callers pass lines starting with ``#``),
    followed by a lone ``#`` separator line, then the YAML body with the dict's
    insertion order preserved (``sort_keys=False``).
    """
    comment_block = "\n".join([*header_lines, "#"])
    yaml_text = yaml.safe_dump(obj, sort_keys=False)
    return comment_block + "\n" + yaml_text
69
70


71
72
73
74
75
76
77
def _maybe_add_model_backend_header_lines(header_lines: list[str], args) -> None:
    """Append ``# Model:`` / ``# Backend:`` lines to ``header_lines`` in place.

    Each line is added only when ``args`` has the attribute and its value is
    truthy; missing attributes, ``None`` and empty strings are skipped.
    """
    for attr_name, prefix in (("model", "# Model: "), ("backend", "# Backend: ")):
        value = getattr(args, attr_name, None)
        if not value:
            continue
        header_lines.append(f"{prefix}{value}")
78
79


80
81
82
83
84
85
86
87
88
def build_single_service_preview_header_lines(
    *,
    service_name: str,
    engine_type: str,
    mapping,
    ttft_or_itl_ms: float | None,
    thpt_per_gpu: float | None,
    args,
) -> list[str]:
    """Build the ``#``-comment header for a single-service (prefill or decode) preview.

    Args:
        service_name: Name of the generated service shown in the header.
        engine_type: ``"prefill"`` or ``"decode"``; selects whether the latency
            value is labelled TTFT or ITL. Any other value omits the latency line.
        mapping: Parallel mapping object; must provide ``get_num_gpus()`` and ``label()``.
        ttft_or_itl_ms: Profiled TTFT (prefill) or ITL (decode) in ms, or ``None``.
        thpt_per_gpu: Profiled throughput in tokens/s/GPU, or ``None``.
        args: Profiler arguments; ``model``/``backend`` lines are added when set.

    Returns:
        Header lines, each starting with ``#``.
    """
    header_lines = [
        "# DynamoGraphDeployment Service Config Preview",
        f"# Service: {service_name}",
        f"# Engine: {engine_type}",
        f"# Num GPUs: {mapping.get_num_gpus()}",
        f"# Parallelization: {mapping.label()}",
    ]

    if engine_type == "prefill" and ttft_or_itl_ms is not None:
        header_lines.append(f"# Profiled TTFT: {round(ttft_or_itl_ms, 2)} ms")
    if engine_type == "decode" and ttft_or_itl_ms is not None:
        header_lines.append(f"# Profiled ITL: {round(ttft_or_itl_ms, 2)} ms")
    if thpt_per_gpu is not None:
        header_lines.append(
            f"# Profiled Throughput: {round(thpt_per_gpu, 2)} tokens/s/GPU"
        )

    _maybe_add_model_backend_header_lines(header_lines, args)
    header_lines.append(
        "# Note: This is a service-only preview. Final config includes planner."
    )
    return header_lines
109
110


111
112
113
114
def build_two_service_preview_header_lines(
    *,
    prefill_service_name: str,
    decode_service_name: str,
    prefill_mapping,
    decode_mapping,
    prefill_ttft_ms: float | None,
    prefill_thpt_per_gpu: float | None,
    decode_itl_ms: float | None,
    decode_thpt_per_gpu: float | None,
    args,
) -> list[str]:
    """Build the ``#``-comment header for a combined prefill + decode services preview.

    Args:
        prefill_service_name: Name of the prefill service in the preview.
        decode_service_name: Name of the decode service in the preview.
        prefill_mapping: Prefill parallel mapping (``get_num_gpus()``/``label()``).
        decode_mapping: Decode parallel mapping (``get_num_gpus()``/``label()``).
        prefill_ttft_ms: Profiled prefill TTFT in ms, or ``None`` to omit.
        prefill_thpt_per_gpu: Profiled prefill throughput (tokens/s/GPU), or ``None``.
        decode_itl_ms: Profiled decode ITL in ms, or ``None`` to omit.
        decode_thpt_per_gpu: Profiled decode throughput (tokens/s/GPU), or ``None``.
        args: Profiler arguments; ``model``/``backend`` lines are added when set.

    Returns:
        Header lines, each starting with ``#``.
    """
    header_lines = [
        "# DynamoGraphDeployment Services Config Preview",
        f"# Prefill service: {prefill_service_name} ({prefill_mapping.get_num_gpus()} GPU(s), {prefill_mapping.label()})",
        f"# Decode service: {decode_service_name} ({decode_mapping.get_num_gpus()} GPU(s), {decode_mapping.label()})",
    ]

    # Optional profiled metrics; ``None`` values are simply left out.
    if prefill_ttft_ms is not None:
        header_lines.append(f"# Profiled TTFT: {round(prefill_ttft_ms, 2)} ms")
    if decode_itl_ms is not None:
        header_lines.append(f"# Profiled ITL: {round(decode_itl_ms, 2)} ms")
    if prefill_thpt_per_gpu is not None:
        header_lines.append(
            f"# Profiled Prefill Throughput: {round(prefill_thpt_per_gpu, 2)} tokens/s/GPU"
        )
    if decode_thpt_per_gpu is not None:
        header_lines.append(
            f"# Profiled Decode Throughput: {round(decode_thpt_per_gpu, 2)} tokens/s/GPU"
        )

    _maybe_add_model_backend_header_lines(header_lines, args)
    header_lines.append(
        "# Note: This is a services-only preview. Final config includes planner."
    )
    return header_lines
145
146


147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
class PlotType(str, Enum):
    """Identifiers of the WebUI's three plot/config sections.

    The ``str`` mixin makes each member compare and hash like its plain string
    value, so members can index dicts loaded from JSON (e.g. ``data["prefill"]``).
    """

    PREFILL = "prefill"  # TTFT vs. throughput chart and table
    DECODE = "decode"  # ITL vs. throughput chart and table
    COST = "cost"  # GPU-hours chart and table for prefill/decode pairs


# Colors assigned to chart datasets in order. Callers index with
# ``idx % len(CHART_COLORS)``, so colors repeat once there are more than eight series.
# TODO: handle case with more than 8 lines
CHART_COLORS = [
    "#1f77b4",  # blue
    "#ff7f0e",  # orange
    "#2ca02c",  # green
    "#d62728",  # red
    "#9467bd",  # purple
    "#8c564b",  # brown
    "#e377c2",  # pink
    "#7f7f7f",  # gray
]

# How long (seconds) to wait for the user to pick a configuration in the WebUI: one hour.
# TODO: is this too long?
WEB_UI_SELECTION_TIMEOUT = 60 * 60


172
173
174
175
176
177
178
179
180
def generate_config_data(
    prefill_data,
    decode_data,
    args,
    write_to_disk: bool = True,
):
    """
    Generate JSON data file for WebUI from profiling results.

    Starts from ``data_template.json`` (located next to this module), fills in the
    SLA target lines and the cost chart title, then populates the prefill, decode
    and cost sections.

    Note: This function computes GPU hours (not cost). The frontend handles
    cost calculation when the user provides a GPU cost per hour value.

    Args:
        prefill_data: PrefillProfileData instance
        decode_data: DecodeProfileData instance
        args: Arguments containing SLA targets (ttft, itl, isl, osl); ``output_dir``
            is only read when ``write_to_disk`` is true
        write_to_disk: Whether to write the generated JSON to args.output_dir/webui_data.json

    Returns:
        dict: Data dict for WebUI consumption.
    """
    template_path = Path(__file__).parent / "data_template.json"
    with open(template_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    # SLA target lines drawn on the prefill (TTFT) and decode (ITL) charts.
    prefill_target = data[PlotType.PREFILL]["chart"]["target_line"]
    prefill_target["value"] = args.ttft
    prefill_target["label"] = f"Target TTFT: {args.ttft} ms"

    decode_target = data[PlotType.DECODE]["chart"]["target_line"]
    decode_target["value"] = args.itl
    decode_target["label"] = f"Target ITL: {args.itl} ms"

    data[PlotType.COST]["chart"][
        "title"
    ] = f"GPU Hours Per 1000 i{args.isl}o{args.osl} requests"

    populate_prefill_data(data, prefill_data, args)
    populate_decode_data(data, decode_data, args)
    populate_cost_data(data, prefill_data, decode_data, args)

    if write_to_disk:
        output_dir = args.output_dir
        output_path = os.path.join(output_dir, "webui_data.json")
        # os.makedirs("") raises FileNotFoundError, so only create a directory
        # when one was actually given (an empty output_dir means the CWD).
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        with open(output_path, "w", encoding="utf-8") as f:
            json.dump(data, f, indent=4)
        logger.info("Generated WebUI config data at %s", output_path)

    return data


231
def populate_prefill_data(data, prefill_data, args):
    """Populate prefill chart and table data.

    Writes into ``data[PlotType.PREFILL]``:

    * ``chart.labels``: one ``"<n> GPUs"`` label per distinct GPU count (ascending).
    * ``chart.datasets[0].data``: one point per profiled config with x = TTFT (ms),
      y = throughput (tokens/s/GPU), and ``tableIdx`` = the config's table row.
    * ``table.data``: rows ``[num_gpus, ttft, thpt, config_yaml]`` where
      ``config_yaml`` is a commented prefill-service config preview.

    Does nothing when ``prefill_data.num_gpus`` is empty.
    """
    if not prefill_data.num_gpus:
        return

    section = data[PlotType.PREFILL]

    unique_gpus = sorted(set(prefill_data.num_gpus))
    section["chart"]["labels"] = [f"{gpu} GPUs" for gpu in unique_gpus]

    # Chart points: tableIdx ties each point back to its row in the table below.
    section["chart"]["datasets"][0]["data"] = [
        {
            "x": round(ttft, 2),
            "y": round(thpt, 2),
            "gpu": gpu,
            "tableIdx": i,
            "gpuLabel": f"{gpu} GPUs [{label}]",
        }
        for i, (gpu, ttft, thpt, label) in enumerate(
            zip(
                prefill_data.num_gpus,
                prefill_data.ttft,
                prefill_data.thpt_per_gpu,
                prefill_data.parallel_mapping_labels,
            )
        )
    ]

    # Loop-invariant: the node size is the same for every preview.
    num_gpus_per_node = getattr(args, "num_gpus_per_node", 8)

    table_data = []
    # The labels list stays in the zip so the row count truncates exactly as before.
    for gpu, ttft, thpt, _label, mapping in zip(
        prefill_data.num_gpus,
        prefill_data.ttft,
        prefill_data.thpt_per_gpu,
        prefill_data.parallel_mapping_labels,
        prefill_data.parallel_mappings,
    ):
        config_obj = generate_prefill_service_config_preview(
            config_path=args.config,
            args=args,
            best_prefill_mapping=mapping,
            num_gpus_per_node=num_gpus_per_node,
        )
        service_name = next(iter(config_obj.keys()))
        header_lines = build_single_service_preview_header_lines(
            service_name=service_name,
            engine_type="prefill",
            mapping=mapping,
            ttft_or_itl_ms=ttft,
            thpt_per_gpu=thpt,
            args=args,
        )
        config_yaml = dump_yaml_with_header(header_lines, config_obj)
        table_data.append([gpu, round(ttft, 2), round(thpt, 2), config_yaml])
    section["table"]["data"] = table_data


292
def populate_decode_data(data, decode_data, args):
    """Populate decode chart and table data.

    Writes into ``data[PlotType.DECODE]``:

    * ``chart.datasets``: one dataset per distinct GPU count (ascending), colored from
      ``CHART_COLORS``; points have x = ITL (ms), y = throughput (tokens/s/GPU) and
      ``tableIdx`` = the config's table row.
    * ``table.data``: rows ``[num_gpus, itl, thpt, config_yaml]`` where
      ``config_yaml`` is a commented decode-service config preview.

    Does nothing when ``decode_data.num_gpus`` is empty.
    """
    if not decode_data.num_gpus:
        return

    section = data[PlotType.DECODE]

    # Group chart points by GPU count so each count is drawn as its own series.
    gpu_groups: dict[int, list[dict[str, float | int]]] = {}
    for i, (gpu, itl, thpt, _label) in enumerate(
        zip(
            decode_data.num_gpus,
            decode_data.itl,
            decode_data.thpt_per_gpu,
            decode_data.parallel_mapping_labels,
        )
    ):
        gpu_groups.setdefault(gpu, []).append(
            {"x": round(itl, 2), "y": round(thpt, 2), "tableIdx": i}
        )

    datasets = []
    for idx, (gpu, points) in enumerate(sorted(gpu_groups.items())):
        color = CHART_COLORS[idx % len(CHART_COLORS)]
        datasets.append(
            {
                "label": f"{gpu} GPUs",
                "data": points,
                "backgroundColor": color,
                "borderColor": color,
            }
        )
    section["chart"]["datasets"] = datasets

    # Loop-invariant: the node size is the same for every preview.
    num_gpus_per_node = getattr(args, "num_gpus_per_node", 8)

    table_data = []
    # The labels list stays in the zip so the row count truncates exactly as before.
    for gpu, itl, thpt, _label, mapping in zip(
        decode_data.num_gpus,
        decode_data.itl,
        decode_data.thpt_per_gpu,
        decode_data.parallel_mapping_labels,
        decode_data.parallel_mappings,
    ):
        config_obj = generate_decode_service_config_preview(
            config_path=args.config,
            args=args,
            best_decode_mapping=mapping,
            num_gpus_per_node=num_gpus_per_node,
        )
        service_name = next(iter(config_obj.keys()))
        header_lines = build_single_service_preview_header_lines(
            service_name=service_name,
            engine_type="decode",
            mapping=mapping,
            ttft_or_itl_ms=itl,
            thpt_per_gpu=thpt,
            args=args,
        )
        config_yaml = dump_yaml_with_header(header_lines, config_obj)
        table_data.append([gpu, round(itl, 2), round(thpt, 2), config_yaml])
    section["table"]["data"] = table_data


356
357
358
359
360
361
def _gpu_hours_per_1000_requests(tokens_per_request, thpt_per_gpu):
    """Return GPU-hours to process ``tokens_per_request`` tokens for 1000 requests.

    ``thpt_per_gpu`` is already normalized per GPU (tokens/s/GPU), so
    ``tokens / thpt_per_gpu`` is GPU-seconds regardless of how many GPUs the
    engine spans. Multiplying by the engine's GPU count again would double count.
    """
    return tokens_per_request * 1000 / thpt_per_gpu / 3600


def populate_cost_data(
    data,
    prefill_data,
    decode_data,
    args,
):
    """Populate cost chart and table data with pareto-optimal configurations.

    Every (prefill pareto point, decode pareto point) pair becomes one chart point
    and one table row. Each prefill pareto point is drawn as its own line, with
    x = tokens/s per user (1000 / ITL ms) and y = total GPU hours per 1000 requests
    of ``args.isl`` input / ``args.osl`` output tokens. ``index_mapping`` maps each
    table row (as a string key) to ``[prefill_idx, decode_idx]`` in the original
    profiling data.

    Note: This function computes GPU hours (not cost). The frontend handles
    cost calculation when the user provides a GPU cost per hour value.
    """
    if not prefill_data.num_gpus or not decode_data.num_gpus:
        return

    # Pareto fronts: prefill minimizes TTFT and maximizes throughput;
    # decode minimizes ITL and maximizes throughput.
    p_ttft, p_thpt, prefill_pareto_indices = compute_pareto(
        prefill_data.ttft, prefill_data.thpt_per_gpu
    )
    d_itl, d_thpt, decode_pareto_indices = compute_pareto(
        decode_data.itl, decode_data.thpt_per_gpu
    )

    p_ttft = np.array(p_ttft)
    p_thpt = np.array(p_thpt)
    d_itl = np.array(d_itl)
    d_thpt = np.array(d_thpt)

    num_gpus_per_node = getattr(args, "num_gpus_per_node", 8)

    cost_datasets = []
    table_data = []
    cost_index_mapping = {}  # cost table row idx -> (prefill_idx, decode_idx)

    for p_idx, (_p_ttft, _p_thpt) in enumerate(zip(p_ttft, p_thpt)):
        # int(): pareto indices may be numpy integers, which json.dump cannot serialize.
        orig_prefill_idx = int(prefill_pareto_indices[p_idx])
        prefill_mapping = prefill_data.parallel_mappings[orig_prefill_idx]
        prefill_gpu_hours = _gpu_hours_per_1000_requests(args.isl, _p_thpt)

        line_data = []
        for d_idx, (_d_itl, _d_thpt) in enumerate(zip(d_itl, d_thpt)):
            orig_decode_idx = int(decode_pareto_indices[d_idx])
            decode_mapping = decode_data.parallel_mappings[orig_decode_idx]

            total_gpu_hours = prefill_gpu_hours + _gpu_hours_per_1000_requests(
                args.osl, _d_thpt
            )
            # X-axis: tokens/s seen by a single user, derived from ITL in ms.
            tokens_per_user = 1000 / _d_itl
            table_idx = len(table_data)

            line_data.append(
                {
                    "x": round(tokens_per_user, 2),
                    "y": round(total_gpu_hours, 4),
                    "tableIdx": table_idx,
                }
            )
            cost_index_mapping[table_idx] = (orig_prefill_idx, orig_decode_idx)

            services_obj = generate_prefill_decode_services_config_preview(
                config_path=args.config,
                args=args,
                best_prefill_mapping=prefill_mapping,
                best_decode_mapping=decode_mapping,
                num_gpus_per_node=num_gpus_per_node,
            )
            # Service names are backend-dependent. The prefill service is taken to be
            # first and the decode service second (dict insertion order).
            # NOTE(review): selecting by a per-service role field would be more robust.
            service_names = list(services_obj.keys())
            prefill_service_name = service_names[0]
            decode_service_name = (
                service_names[1] if len(service_names) > 1 else service_names[0]
            )
            header_lines = build_two_service_preview_header_lines(
                prefill_service_name=prefill_service_name,
                decode_service_name=decode_service_name,
                prefill_mapping=prefill_mapping,
                decode_mapping=decode_mapping,
                prefill_ttft_ms=float(_p_ttft),
                prefill_thpt_per_gpu=float(_p_thpt),
                decode_itl_ms=float(_d_itl),
                decode_thpt_per_gpu=float(_d_thpt),
                args=args,
            )
            config_yaml = dump_yaml_with_header(header_lines, services_obj)

            # GPU hours, not cost: the frontend handles cost conversion.
            table_data.append(
                [
                    round(_p_ttft, 2),
                    round(_p_thpt, 2),
                    round(_d_itl, 2),
                    round(_d_thpt, 2),
                    round(tokens_per_user, 2),
                    round(total_gpu_hours, 4),
                    config_yaml,
                ]
            )

        color = CHART_COLORS[p_idx % len(CHART_COLORS)]
        cost_datasets.append(
            {
                "label": f"TTFT: {_p_ttft:.2f}ms",
                "data": line_data,
                "backgroundColor": color,
                "borderColor": color,
            }
        )

    cost_section = data[PlotType.COST]
    cost_section["chart"]["datasets"] = cost_datasets
    cost_section["table"]["data"] = table_data
    # JSON object keys must be strings; the selection handler looks rows up by str(row).
    cost_section["index_mapping"] = {
        str(k): list(v) for k, v in cost_index_mapping.items()
    }


def create_selection_handler(
    data_dict_ref, selection_queue, prefill_selection, decode_selection
):
    """Create a selection handler closure for the WebUI.

    Args:
        data_dict_ref: Dict wrapper holding the latest parsed JSON data under
            ``"data"`` (mutated when UI inputs change)
        selection_queue: Queue to communicate selections to main thread; receives a
            ``(prefill_idx, decode_idx)`` tuple once a selection is complete
        prefill_selection: Dict tracking prefill selection state (``"idx"`` key)
        decode_selection: Dict tracking decode selection state (``"idx"`` key)

    Returns:
        Callable: Selection handler function for Gradio that returns a status message
    """

    def _finish(prefill_idx, decode_idx):
        """Mark the selection complete and hand the pair to the waiting thread."""
        _selection_complete.set()
        selection_queue.put((prefill_idx, decode_idx))

    def handle_selection(selection_json):
        """Handle datapoint selection from table.

        Args:
            selection_json: JSON string with ``plotType`` (``"prefill"``,
                ``"decode"`` or ``"cost"``) and ``rowIndex``; empty input is ignored.

        Returns:
            str: Status message to display in the UI
        """
        if not selection_json or selection_json.strip() == "":
            return ""

        try:
            data_dict = data_dict_ref["data"]
            selection = json.loads(selection_json)
            plot_type = selection.get("plotType")
            row_idx = selection.get("rowIndex")

            logger.info(f"Selection received: {plot_type}, row {row_idx}")

            known_types = (PlotType.COST, PlotType.PREFILL, PlotType.DECODE)
            if plot_type in known_types and row_idx is None:
                # Without a row we would otherwise record/echo "#None".
                logger.warning("Selection is missing 'rowIndex': %s", selection_json)
                return "❌  Error: selection is missing a row index."

            if plot_type == PlotType.COST:
                # A cost row maps back to one (prefill, decode) pair.
                cost_index_mapping = data_dict[PlotType.COST].get("index_mapping", {})
                mapping_entry = cost_index_mapping.get(str(row_idx))
                if mapping_entry:
                    prefill_idx, decode_idx = mapping_entry
                    if prefill_idx is not None and decode_idx is not None:
                        logger.info(
                            f"Cost selection determines: Prefill={prefill_idx}, Decode={decode_idx}"
                        )
                        _finish(prefill_idx, decode_idx)
                        return f"✅ Configuration selected! Prefill config #{prefill_idx}, Decode config #{decode_idx}. Processing..."
                # Previously this fell through silently with an empty status.
                logger.warning("No GPU Hours index mapping for row %s", row_idx)
                return f"❌  Error: could not resolve GPU Hours row #{row_idx}."

            if plot_type == PlotType.PREFILL:
                prefill_selection["idx"] = row_idx
                logger.info(f"Prefill selected: {row_idx}")
                decode_idx = decode_selection["idx"]
                if decode_idx is None:
                    return f"ℹ️  Prefill config #{row_idx} selected. Please select a Decode configuration."
                logger.info(
                    f"Both selections complete: Prefill={row_idx}, Decode={decode_idx}"
                )
                _finish(row_idx, decode_idx)
                return f"✅  Configuration selected! Prefill config #{row_idx}, Decode config #{decode_idx}. Processing..."

            if plot_type == PlotType.DECODE:
                decode_selection["idx"] = row_idx
                logger.info(f"Decode selected: {row_idx}")
                prefill_idx = prefill_selection["idx"]
                if prefill_idx is None:
                    return f"ℹ️  Decode config #{row_idx} selected. Please select a Prefill configuration."
                logger.info(
                    f"Both selections complete: Prefill={prefill_idx}, Decode={row_idx}"
                )
                _finish(prefill_idx, row_idx)
                return f"✅  Configuration selected! Prefill config #{prefill_idx}, Decode config #{row_idx}. Processing..."

            return ""

        except Exception as e:
            # UI boundary: report the error instead of crashing the Gradio callback,
            # but keep the traceback in the logs.
            logger.exception("Error handling selection: %s", e)
            return f"❌  Error: {str(e)}"

    return handle_selection


573
574
575
576
def create_gradio_interface(
    json_data_str,
    handle_selection,
):
    """Create the Gradio interface for configuration selection.

    Args:
        json_data_str: JSON string containing profiling data; loaded into the hidden
            ``json_data`` component on page load, whose change event calls
            ``window.initializeVisualizations``.
        handle_selection: Selection handler function; receives the hidden selection
            input and returns a Markdown status message.

    Returns:
        gr.Blocks: Configured Gradio demo (not launched)
    """
    with gr.Blocks(title="Configuration Selection") as demo:
        # Create hidden UI components (reused from AIC profiling module)
        ui_components = create_profiling_ui_components()
        selection_input = ui_components["selection_input"]
        selection_button = ui_components["selection_button"]
        json_data = ui_components["json_data"]

        # Inject CSS and modal (reused from AIC profiling module)
        inject_profiling_assets()

        gr.Markdown("# 📊 Profiling Results - Select Configuration")

        # Display any profiling errors/warnings at the top
        profiling_errors = get_profiling_errors()
        if profiling_errors:
            error_text = "\n".join(f"- {err}" for err in profiling_errors)
            # Built line by line with no leading whitespace. In a triple-quoted form the
            # interpolated error lines start at column 0 while the <div> lines are
            # indented, so no common-indent dedent can strip them, and Markdown renders
            # 4+ leading spaces as a code block instead of HTML.
            gr.Markdown(
                '<div style="background-color: #fff3cd; border: 1px solid #ffc107; '
                'border-radius: 4px; padding: 10px; margin-bottom: 10px;">\n'
                "<strong>⚠️ Profiling Warnings/Errors:</strong>\n"
                "\n"
                f"{error_text}\n"
                "</div>\n"
            )

        gr.Markdown(
            """
            **Two ways to select prefill and decode configs:**

            1. **GPU Hours Analysis** (recommended): Select any row in the GPU Hours table - automatically determines both prefill and decode
            2. **Individual**: Select one row in the Prefill table AND one row in the Decode table

            The selection will be processed automatically once complete.

            **Chart Reference Points:** 🔴 Max Throughput Under SLA · 🟡 Max Throughput Overall · 🟢 Latency-Optimized (lowest latency under SLA)

            > 📝 **Note:** The dotted red line in the prefill and decode charts are default TTFT and ITL SLAs if not specified.

            > ⚠️ **Warning:** The TTFT values here represent the ideal case when requests arrive uniformly, minimizing queueing. Real-world TTFT may be higher than profiling results. To mitigate the issue, planner uses [correction factors](https://github.com/ai-dynamo/dynamo/blob/main/docs/planner/sla_planner.md#2-correction-factor-calculation) to adjust dynamically at runtime.

            > 💡 **Tip:** Use the GPU cost checkbox and input in the charts section to convert GPU hours to cost.
            """
        )

        # Status message display for selection feedback
        selection_status = gr.Markdown(
            value="",
            elem_id="selection_status",
        )

        # Performance Results Section (reused from AIC profiling module)
        create_performance_results_section()

        # The (hidden) selection button forwards the JS-written selection JSON
        # to the handler and shows its returned status message.
        selection_button.click(
            fn=handle_selection,
            inputs=[selection_input],
            outputs=[selection_status],
        )

        # Trigger visualization when JSON data changes
        json_data.change(
            fn=None,
            inputs=[json_data],
            outputs=[],
            js=(
                "(data) => { if (data && data.trim() && window.initializeVisualizations) "
                "window.initializeVisualizations(data); }"
            ),
        )

        # Load JavaScript and data automatically on page load
        def load_data():
            """Return the profiling JSON for the hidden ``json_data`` component."""
            return json_data_str

        demo.load(
            fn=load_data, inputs=[], outputs=[json_data], js=load_profiling_javascript()
        )

    return demo


def wait_for_selection(demo, selection_queue, port):
    """Launch the demo and wait for user selection.

    Args:
        demo: Gradio demo instance
        selection_queue: Queue to receive selection from UI
        port: Port number for the WebUI

    Returns:
        tuple[int, int]: (selected_prefill_idx, selected_decode_idx). If nothing is
        selected within ``WEB_UI_SELECTION_TIMEOUT`` seconds, ``(0, 0)`` is returned.
        The demo is closed on every exit path.
    """
    # Reset before launching: clearing after launch could wipe the signal of a
    # selection that arrives immediately.
    _selection_complete.clear()

    def launch_thread():
        demo.launch(
            server_name="0.0.0.0",
            server_port=port,
            share=False,
            prevent_thread_lock=True,
        )

    thread = threading.Thread(target=launch_thread, daemon=True)
    thread.start()

    logger.info(f"WebUI launched. Waiting for user selection on http://0.0.0.0:{port}")
    logger.info(
        "Please select a row from the GPU Hours table, "
        "or one row each from the Prefill and Decode tables"
    )

    try:
        try:
            selected_prefill_idx, selected_decode_idx = selection_queue.get(
                timeout=WEB_UI_SELECTION_TIMEOUT
            )
        except queue.Empty:
            logger.error(
                "Selection timeout - no selection made within "
                f"{WEB_UI_SELECTION_TIMEOUT} seconds"
            )
            # Return default indices
            return 0, 0

        logger.info(
            f"User selected: Prefill={selected_prefill_idx}, Decode={selected_decode_idx}"
        )

        # Wait for the selection handler to complete and give the UI time to show
        # its success message before the server goes away.
        if _selection_complete.wait(timeout=2.0):
            time.sleep(1.0)

        return selected_prefill_idx, selected_decode_idx
    finally:
        demo.close()