utils.py 25.2 KB
Newer Older
1
2
3
4
5
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import json
import logging
6
import os
7
8
import queue
import threading
9
import time
10
from enum import Enum
11
from pathlib import Path
12
13
14

import gradio as gr
import numpy as np
15
import yaml
16
17
18
19
20
21
22
23
24
25
26
27
from aiconfigurator.webapp.components.profiling import (
    create_performance_results_section,
    create_profiling_ui_components,
    inject_profiling_assets,
    load_profiling_javascript,
)

from benchmarks.profiler.utils.pareto import compute_pareto

logger = logging.getLogger(__name__)


28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
# Module-level event used to coordinate shutdown of the WebUI: the selection
# handler sets it right before enqueueing a completed selection, and
# wait_for_selection() clears it and later waits on it before closing the demo.
_selection_complete = threading.Event()

# Module-level list of profiling error messages. add_profiling_error() appends
# to it and create_gradio_interface() renders its contents at the top of the
# WebUI page.
_profiling_errors: list[str] = []


def add_profiling_error(error_message: str) -> None:
    """Record a profiling error for the WebUI and write it to the log.

    Args:
        error_message: Text of the error; stored as-is in the module-level
            error list and logged at ERROR level.
    """
    log_line = f"Profiling error: {error_message}"
    _profiling_errors.append(error_message)
    logger.error(log_line)


def get_profiling_errors() -> list[str]:
    """Return a snapshot of the recorded profiling errors.

    Returns:
        A new list holding the current error messages; mutating it does not
        affect the module-level error state.
    """
    return list(_profiling_errors)


def clear_profiling_errors() -> None:
    """Remove every recorded profiling error (in place)."""
    del _profiling_errors[:]


def generate_dgd_worker_config_yaml(
    parallel_mapping,
    engine_type: str,
    model: str | None = None,
    backend: str | None = None,
    ttft_or_itl: float | None = None,
    thpt_per_gpu: float | None = None,
) -> str:
    """
    Render a single worker service as a commented YAML snippet for the WebUI.

    The snippet is a block of ``#`` comment lines (engine, GPU count,
    parallelization label, optional profiled metrics, model and backend)
    followed by the YAML dump of one service entry.

    Args:
        parallel_mapping: Object exposing ``get_num_gpus()`` and ``label()``
            (a ParallelizationMapping instance).
        engine_type: "prefill" or "decode"; also used to name the service.
        model: Model name/path, added to the header when truthy.
        backend: Backend name (sglang, vllm, trtllm), added when truthy.
        ttft_or_itl: TTFT (prefill) or ITL (decode) in ms; only reported for
            the two known engine types.
        thpt_per_gpu: Throughput per GPU in tokens/s/GPU.

    Returns:
        The header comments and the YAML body joined into one string.
    """
    gpu_count = parallel_mapping.get_num_gpus()

    comment_lines = [
        "# DynamoGraphDeployment Worker Config",
        f"# Engine: {engine_type}",
        f"# Num GPUs: {gpu_count}",
        f"# Parallelization: {parallel_mapping.label()}",
    ]

    # The latency metric is named after the engine type; unknown engine types
    # get no latency line at all.
    if ttft_or_itl is not None:
        if engine_type == "prefill":
            comment_lines.append(f"# Profiled TTFT: {round(ttft_or_itl, 2)} ms")
        elif engine_type == "decode":
            comment_lines.append(f"# Profiled ITL: {round(ttft_or_itl, 2)} ms")

    if thpt_per_gpu is not None:
        comment_lines.append(
            f"# Profiled Throughput: {round(thpt_per_gpu, 2)} tokens/s/GPU"
        )

    if model:
        comment_lines.append(f"# Model: {model}")
    if backend:
        comment_lines.append(f"# Backend: {backend}")

    comment_lines.extend(
        [
            "#",
            "# Note: Final config generated after selection includes",
            "# backend-specific args and planner configuration.",
        ]
    )

    # Structure only: the actual args vary by backend and are not shown here.
    service_entry = {
        f"{engine_type.capitalize()}Worker": {
            "componentType": "worker",
            "subComponentType": engine_type,
            "replicas": 1,
            "resources": {
                "limits": {
                    "gpu": str(gpu_count),
                }
            },
        }
    }
    yaml_body = yaml.dump(service_entry, default_flow_style=False, sort_keys=False)

    header = "\n".join(comment_lines)
    return f"{header}\n{yaml_body}"


def generate_dgd_config_yaml_for_display(
    prefill_mapping,
    decode_mapping,
    model: str | None = None,
    backend: str | None = None,
) -> str:
    """
    Render a combined prefill + decode DGD preview as commented YAML.

    The result is a block of ``#`` comment lines (GPU counts, parallelization
    labels, optional model and backend) followed by the YAML dump of a
    DynamoGraphDeployment skeleton containing both worker services.

    Args:
        prefill_mapping: Object exposing ``get_num_gpus()`` and ``label()``
            for the prefill worker (a ParallelizationMapping).
        decode_mapping: Same, for the decode worker.
        model: Model name/path, added to the header when truthy.
        backend: Backend name, added to the header when truthy.

    Returns:
        The header comments and the YAML body joined into one string.
    """

    def _worker_service(sub_component: str, gpu_count) -> dict:
        # Key order matters: the dump below keeps insertion order.
        return {
            "componentType": "worker",
            "subComponentType": sub_component,
            "replicas": 1,
            "resources": {
                "limits": {"gpu": str(gpu_count)},
            },
        }

    n_prefill_gpus = prefill_mapping.get_num_gpus()
    n_decode_gpus = decode_mapping.get_num_gpus()

    deployment = {
        "apiVersion": "nvidia.com/v1alpha1",
        "kind": "DynamoGraphDeployment",
        "spec": {
            "services": {
                "PrefillWorker": _worker_service("prefill", n_prefill_gpus),
                "DecodeWorker": _worker_service("decode", n_decode_gpus),
            }
        },
    }

    preview_lines = [
        "# DynamoGraphDeployment Configuration Preview",
        f"# Prefill: {n_prefill_gpus} GPU(s), {prefill_mapping.label()}",
        f"# Decode: {n_decode_gpus} GPU(s), {decode_mapping.label()}",
    ]
    if model:
        preview_lines.append(f"# Model: {model}")
    if backend:
        preview_lines.append(f"# Backend: {backend}")
    preview_lines += [
        "#",
        "# Full config with planner saved to: config_with_planner.yaml",
    ]

    yaml_body = yaml.dump(deployment, default_flow_style=False, sort_keys=False)
    return "\n".join([*preview_lines, yaml_body])


204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
class PlotType(str, Enum):
    """Enum for the three plot/config types in the WebUI.

    Members also subclass ``str``, so they compare equal to their plain string
    values. This module relies on that to use members as keys into the
    JSON-loaded data dict and to compare them against the ``plotType`` string
    received from the frontend.
    """

    # Prefill chart/table section of the WebUI data.
    PREFILL = "prefill"
    # Decode chart/table section of the WebUI data.
    DECODE = "decode"
    # Combined GPU-hours ("cost") chart/table section of the WebUI data.
    COST = "cost"


# Color palette for chart datasets (hex RGB strings).
# Consumers in this module pick a color with `idx % len(CHART_COLORS)`, so
# colors repeat once more than eight datasets are drawn.
# TODO: handle case with more than 8 lines
CHART_COLORS = [
    "#1f77b4",  # blue
    "#ff7f0e",  # orange
    "#2ca02c",  # green
    "#d62728",  # red
    "#9467bd",  # purple
    "#8c564b",  # brown
    "#e377c2",  # pink
    "#7f7f7f",  # gray
]

# How long wait_for_selection() blocks on the selection queue before giving
# up, in seconds (3600 s = 1 hour).
# TODO: is this too long?
WEB_UI_SELECTION_TIMEOUT = 3600


229
230
231
232
233
234
235
236
237
def generate_config_data(
    prefill_data,
    decode_data,
    args,
    write_to_disk: bool = True,
):
    """
    Generate the WebUI data dict (and optionally its JSON file) from profiling results.

    The function loads ``data_template.json`` from this module's directory,
    fills in the SLA target lines and the GPU-hours chart title, then
    delegates to the ``populate_*`` helpers for the chart/table contents.

    Note: This function computes GPU hours (not cost). The frontend handles
    cost calculation when the user provides a GPU cost per hour value.

    Args:
        prefill_data: PrefillProfileData instance
        decode_data: DecodeProfileData instance
        args: Arguments containing SLA targets (ttft, itl, isl, osl). The
            ``output_dir`` attribute is only read when ``write_to_disk`` is true.
        write_to_disk: Whether to write the generated JSON to
            args.output_dir/webui_data.json

    Returns:
        dict: Data dict for WebUI consumption.
    """
    # Load template (JSON text is read explicitly as UTF-8 rather than with
    # the platform's locale encoding).
    template_path = Path(__file__).parent / "data_template.json"
    with open(template_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    # Set SLA targets
    prefill_target = data[PlotType.PREFILL]["chart"]["target_line"]
    prefill_target["value"] = args.ttft
    prefill_target["label"] = f"Target TTFT: {args.ttft} ms"

    decode_target = data[PlotType.DECODE]["chart"]["target_line"]
    decode_target["value"] = args.itl
    decode_target["label"] = f"Target ITL: {args.itl} ms"

    data[PlotType.COST]["chart"][
        "title"
    ] = f"GPU Hours Per 1000 i{args.isl}o{args.osl} requests"

    # Populate data sections
    populate_prefill_data(data, prefill_data, args)
    populate_decode_data(data, decode_data, args)
    populate_cost_data(data, prefill_data, decode_data, args)

    # Save JSON file (optional)
    if write_to_disk:
        # The output path is only built here so that callers who just want
        # the dict do not need to provide args.output_dir.
        output_path = os.path.join(args.output_dir, "webui_data.json")
        # An empty output_dir yields an empty dirname, which os.makedirs
        # rejects; fall back to the current directory in that case.
        os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
        with open(output_path, "w", encoding="utf-8") as f:
            json.dump(data, f, indent=4)
        logger.info(f"Generated WebUI config data at {output_path}")

    return data


288
def populate_prefill_data(data, prefill_data, args):
    """Populate prefill chart and table data.

    Fills ``data[PlotType.PREFILL]`` in place: the chart labels (one per
    distinct GPU count), the points of the first chart dataset, and the table
    rows. Does nothing when ``prefill_data.num_gpus`` is empty.

    Chart points and table rows are built in the same pass over the same
    zipped sequences, so every point's ``tableIdx`` refers to an existing
    table row even if the input sequences differ in length.

    Args:
        data: WebUI data dict (loaded from the template); mutated in place.
        prefill_data: Object providing the parallel sequences ``num_gpus``,
            ``ttft``, ``thpt_per_gpu``, ``parallel_mapping_labels`` and
            ``parallel_mappings``.
        args: Arguments object; optional ``model`` and ``backend`` attributes
            are forwarded to the per-row config snippet.
    """
    if not prefill_data.num_gpus:
        return

    section = data[PlotType.PREFILL]

    # Get unique GPU counts for labels
    unique_gpus = sorted(set(prefill_data.num_gpus))
    section["chart"]["labels"] = [f"{gpu} GPUs" for gpu in unique_gpus]

    # Loop-invariant lookups for the per-row YAML snippet.
    model = getattr(args, "model", None)
    backend = getattr(args, "backend", None)

    chart_data = []
    table_data = []
    for i, (gpu, ttft, thpt, label, mapping) in enumerate(
        zip(
            prefill_data.num_gpus,
            prefill_data.ttft,
            prefill_data.thpt_per_gpu,
            prefill_data.parallel_mapping_labels,
            prefill_data.parallel_mappings,
        )
    ):
        chart_data.append(
            {
                "x": round(ttft, 2),
                "y": round(thpt, 2),
                "gpu": gpu,
                "tableIdx": i,
                "gpuLabel": f"{gpu} GPUs [{label}]",
            }
        )

        # Generate DGD worker config YAML for display
        config_yaml = generate_dgd_worker_config_yaml(
            parallel_mapping=mapping,
            engine_type="prefill",
            model=model,
            backend=backend,
            ttft_or_itl=ttft,
            thpt_per_gpu=thpt,
        )
        table_data.append([gpu, round(ttft, 2), round(thpt, 2), config_yaml])

    section["chart"]["datasets"][0]["data"] = chart_data
    section["table"]["data"] = table_data


342
def populate_decode_data(data, decode_data, args):
    """Populate decode chart and table data.

    Fills ``data[PlotType.DECODE]`` in place: one chart dataset per distinct
    GPU count (sorted by GPU count, colored from ``CHART_COLORS``) and the
    table rows. Does nothing when ``decode_data.num_gpus`` is empty.

    Chart points and table rows are built in the same pass over the same
    zipped sequences, so every point's ``tableIdx`` refers to an existing
    table row even if the input sequences differ in length.

    Args:
        data: WebUI data dict (loaded from the template); mutated in place.
        decode_data: Object providing the parallel sequences ``num_gpus``,
            ``itl``, ``thpt_per_gpu``, ``parallel_mapping_labels`` and
            ``parallel_mappings``.
        args: Arguments object; optional ``model`` and ``backend`` attributes
            are forwarded to the per-row config snippet.
    """
    if not decode_data.num_gpus:
        return

    # Loop-invariant lookups for the per-row YAML snippet.
    model = getattr(args, "model", None)
    backend = getattr(args, "backend", None)

    # Group chart points by GPU count (one dataset per count) while building
    # the table rows in the same pass.
    gpu_groups: dict[int, list[dict[str, float | int]]] = {}
    table_data = []
    for i, (gpu, itl, thpt, _label, mapping) in enumerate(
        zip(
            decode_data.num_gpus,
            decode_data.itl,
            decode_data.thpt_per_gpu,
            decode_data.parallel_mapping_labels,
            decode_data.parallel_mappings,
        )
    ):
        gpu_groups.setdefault(gpu, []).append(
            {"x": round(itl, 2), "y": round(thpt, 2), "tableIdx": i}
        )

        # Generate DGD worker config YAML for display
        config_yaml = generate_dgd_worker_config_yaml(
            parallel_mapping=mapping,
            engine_type="decode",
            model=model,
            backend=backend,
            ttft_or_itl=itl,
            thpt_per_gpu=thpt,
        )
        table_data.append([gpu, round(itl, 2), round(thpt, 2), config_yaml])

    # Create datasets for each GPU count with different colors; the palette
    # wraps around when there are more groups than colors.
    datasets = []
    for idx, (gpu, points) in enumerate(sorted(gpu_groups.items())):
        color = CHART_COLORS[idx % len(CHART_COLORS)]
        datasets.append(
            {
                "label": f"{gpu} GPUs",
                "data": points,
                "backgroundColor": color,
                "borderColor": color,
            }
        )

    data[PlotType.DECODE]["chart"]["datasets"] = datasets
    data[PlotType.DECODE]["table"]["data"] = table_data


399
400
401
402
403
404
def populate_cost_data(
    data,
    prefill_data,
    decode_data,
    args,
):
    """Populate cost chart and table data with pareto-optimal configurations.

    For every (prefill pareto point, decode pareto point) pair this adds one
    chart point and one table row, and records which original prefill/decode
    indices the row corresponds to under ``data[PlotType.COST]["index_mapping"]``
    (keys are the table row index as a string). One chart dataset (line) is
    emitted per prefill pareto point. Does nothing when either side has no
    profiled points.

    Note: This function computes GPU hours (not cost). The frontend handles
    cost calculation when the user provides a GPU cost per hour value.

    Args:
        data: WebUI data dict (loaded from the template); mutated in place.
        prefill_data: Object providing ``num_gpus``, ``ttft``,
            ``thpt_per_gpu`` and ``parallel_mappings``.
        decode_data: Object providing ``num_gpus``, ``itl``,
            ``thpt_per_gpu`` and ``parallel_mappings``.
        args: Arguments object providing ``isl`` and ``osl``; optional
            ``model`` and ``backend`` attributes are forwarded to the
            per-row config snippet.
    """
    if not prefill_data.num_gpus or not decode_data.num_gpus:
        return

    # Compute pareto front for prefill (minimize TTFT, maximize throughput)
    p_ttft, p_thpt, prefill_pareto_indices = compute_pareto(
        prefill_data.ttft, prefill_data.thpt_per_gpu
    )

    # Compute pareto front for decode (minimize ITL, maximize throughput)
    d_itl, d_thpt, decode_pareto_indices = compute_pareto(
        decode_data.itl, decode_data.thpt_per_gpu
    )

    # Convert to numpy arrays
    p_ttft = np.array(p_ttft)
    p_thpt = np.array(p_thpt)
    d_itl = np.array(d_itl)
    d_thpt = np.array(d_thpt)

    model = getattr(args, "model", None)
    backend = getattr(args, "backend", None)

    # Everything that depends only on the decode point is computed once here
    # instead of once per (prefill, decode) pair.
    decode_points = []
    for d_idx, (_d_itl, _d_thpt) in enumerate(zip(d_itl, d_thpt)):
        # Cast to a plain int so the index mapping stays JSON-serializable
        # regardless of the index type compute_pareto returns.
        orig_decode_idx = int(decode_pareto_indices[d_idx])
        decode_mapping = decode_data.parallel_mappings[orig_decode_idx]
        decode_num_gpus = decode_mapping.get_num_gpus()

        # Decode GPU hours per 1000 requests (scaled by num_gpus)
        decode_gpu_hours = args.osl * 1000 / _d_thpt / 3600 * decode_num_gpus

        # X-axis: tokens per user (based on ITL)
        tokens_per_user = 1000 / _d_itl

        decode_points.append(
            (
                orig_decode_idx,
                decode_mapping,
                _d_itl,
                _d_thpt,
                decode_gpu_hours,
                tokens_per_user,
            )
        )

    # Generate cost datasets - one line per prefill config
    cost_datasets = []
    table_data = []
    index_mapping = {}  # str(cost table row idx) -> [prefill_idx, decode_idx]
    table_idx = 0

    for p_idx, (_p_ttft, _p_thpt) in enumerate(zip(p_ttft, p_thpt)):
        # Get prefill config details for this pareto point
        orig_prefill_idx = int(prefill_pareto_indices[p_idx])
        prefill_mapping = prefill_data.parallel_mappings[orig_prefill_idx]
        prefill_num_gpus = prefill_mapping.get_num_gpus()

        # Calculate prefill GPU hours per 1000 requests
        # GPU hours = (tokens_per_request * num_requests) / (tokens_per_second_per_gpu * 3600) * num_gpus
        prefill_gpu_hours = args.isl * 1000 / _p_thpt / 3600 * prefill_num_gpus

        # For each decode config, calculate total GPU hours
        line_data = []
        for (
            orig_decode_idx,
            decode_mapping,
            _d_itl,
            _d_thpt,
            decode_gpu_hours,
            tokens_per_user,
        ) in decode_points:
            total_gpu_hours = prefill_gpu_hours + decode_gpu_hours

            line_data.append(
                {
                    "x": round(tokens_per_user, 2),
                    "y": round(total_gpu_hours, 4),
                    "tableIdx": table_idx,
                }
            )

            # Store mapping from cost table row to original indices
            index_mapping[str(table_idx)] = [orig_prefill_idx, orig_decode_idx]

            # Generate DGD config YAML for display
            config_yaml = generate_dgd_config_yaml_for_display(
                prefill_mapping=prefill_mapping,
                decode_mapping=decode_mapping,
                model=model,
                backend=backend,
            )

            # Add to table data (GPU hours, not cost - frontend handles cost conversion)
            table_data.append(
                [
                    round(_p_ttft, 2),
                    round(_p_thpt, 2),
                    round(_d_itl, 2),
                    round(_d_thpt, 2),
                    round(tokens_per_user, 2),
                    round(total_gpu_hours, 4),
                    config_yaml,
                ]
            )
            table_idx += 1

        # Create dataset for this prefill config; the palette wraps around
        # when there are more lines than colors.
        color = CHART_COLORS[p_idx % len(CHART_COLORS)]
        cost_datasets.append(
            {
                "label": f"TTFT: {_p_ttft:.2f}ms",
                "data": line_data,
                "backgroundColor": color,
                "borderColor": color,
            }
        )

    data[PlotType.COST]["chart"]["datasets"] = cost_datasets
    data[PlotType.COST]["table"]["data"] = table_data

    # Store the index mapping in the JSON for reference
    data[PlotType.COST]["index_mapping"] = index_mapping


def create_selection_handler(
    data_dict_ref, selection_queue, prefill_selection, decode_selection
):
    """Create a selection handler closure for the WebUI.

    Args:
        data_dict_ref: Dict wrapper holding the latest parsed JSON data under
            the "data" key (mutated when UI inputs change)
        selection_queue: Queue to communicate selections to main thread;
            receives ``(prefill_idx, decode_idx)`` tuples
        prefill_selection: Dict tracking prefill selection state under "idx"
        decode_selection: Dict tracking decode selection state under "idx"

    Returns:
        Callable: Selection handler function for Gradio that returns a status message
    """

    def handle_selection(selection_json):
        """Handle datapoint selection from table.

        Args:
            selection_json: JSON string with "plotType" and "rowIndex" keys;
                empty/blank input is ignored.

        Returns:
            str: Status message to display in the UI (empty string when there
            is nothing to report)
        """
        if not selection_json or selection_json.strip() == "":
            return ""

        try:
            data_dict = data_dict_ref["data"]
            selection = json.loads(selection_json)
            plot_type = selection.get("plotType")
            row_idx = selection.get("rowIndex")

            logger.info(f"Selection received: {plot_type}, row {row_idx}")

            if plot_type == PlotType.COST:
                # Cost selection - use index mapping to get original indices
                cost_index_mapping = data_dict[PlotType.COST].get("index_mapping", {})
                mapping_entry = cost_index_mapping.get(str(row_idx))

                if mapping_entry:
                    prefill_idx, decode_idx = mapping_entry
                    if prefill_idx is not None and decode_idx is not None:
                        logger.info(
                            f"Cost selection determines: Prefill={prefill_idx}, Decode={decode_idx}"
                        )
                        # Signal selection complete and put in queue; the
                        # event is set first so it is already visible when
                        # the waiting thread wakes up from the queue.
                        _selection_complete.set()
                        selection_queue.put((prefill_idx, decode_idx))
                        return f"✅ Configuration selected! Prefill config #{prefill_idx}, Decode config #{decode_idx}. Processing..."

                # The row has no usable (prefill, decode) mapping: report it
                # instead of silently ignoring the click.
                logger.warning(f"No index mapping found for cost row {row_idx}")
                return f"❌  Error: no configuration found for GPU Hours row {row_idx}"

            elif plot_type == PlotType.PREFILL:
                prefill_selection["idx"] = row_idx
                logger.info(f"Prefill selected: {row_idx}")
                # Check if we have both selections
                if decode_selection["idx"] is not None:
                    logger.info(
                        f"Both selections complete: Prefill={row_idx}, Decode={decode_selection['idx']}"
                    )
                    _selection_complete.set()
                    selection_queue.put((row_idx, decode_selection["idx"]))
                    return f"✅  Configuration selected! Prefill config #{row_idx}, Decode config #{decode_selection['idx']}. Processing..."
                else:
                    return f"ℹ️  Prefill config #{row_idx} selected. Please select a Decode configuration."

            elif plot_type == PlotType.DECODE:
                decode_selection["idx"] = row_idx
                logger.info(f"Decode selected: {row_idx}")
                # Check if we have both selections
                if prefill_selection["idx"] is not None:
                    logger.info(
                        f"Both selections complete: Prefill={prefill_selection['idx']}, Decode={row_idx}"
                    )
                    _selection_complete.set()
                    selection_queue.put((prefill_selection["idx"], row_idx))
                    return f"✅  Configuration selected! Prefill config #{prefill_selection['idx']}, Decode config #{row_idx}. Processing..."
                else:
                    return f"ℹ️  Decode config #{row_idx} selected. Please select a Prefill configuration."

            # Unknown plot type: nothing to report.
            return ""

        except Exception as e:
            # UI boundary: never let a malformed payload escape into Gradio;
            # log with traceback and show the error text to the user.
            logger.exception(f"Error handling selection: {e}")
            return f"❌  Error: {str(e)}"

    return handle_selection


597
598
599
600
def create_gradio_interface(
    json_data_str,
    handle_selection,
):
    """Create the Gradio interface for configuration selection.

    The page consists of the hidden profiling components, a title, an optional
    warning box listing recorded profiling errors, usage instructions, a
    status line for selection feedback, and the performance results section.

    Args:
        json_data_str: JSON string containing profiling data; delivered to the
            hidden ``json_data`` component on page load
        handle_selection: Selection handler function; called with the value of
            the hidden selection input, its return value is shown as status

    Returns:
        gr.Blocks: Configured Gradio demo
    """
    with gr.Blocks(title="Configuration Selection") as demo:
        # Create hidden UI components (reused from AIC profiling module)
        ui_components = create_profiling_ui_components()
        selection_input = ui_components["selection_input"]
        selection_button = ui_components["selection_button"]
        json_data = ui_components["json_data"]

        # Inject CSS and modal (reused from AIC profiling module)
        inject_profiling_assets()

        gr.Markdown("# 📊 Profiling Results - Select Configuration")

        # Display any profiling errors/warnings at the top
        profiling_errors = get_profiling_errors()
        if profiling_errors:
            error_text = "\n".join(f"- {err}" for err in profiling_errors)
            # NOTE(review): the HTML lines keep their source indentation while
            # the list items start at column 0 - confirm this renders as
            # intended in the Markdown component.
            gr.Markdown(
                f"""
                <div style="background-color: #fff3cd; border: 1px solid #ffc107; border-radius: 4px; padding: 10px; margin-bottom: 10px;">
                <strong>⚠️ Profiling Warnings/Errors:</strong>

{error_text}
                </div>
                """
            )

        gr.Markdown(
            """
            **Two ways to select prefill and decode configs:**
            1. **GPU Hours Analysis** (recommended): Select any row in the GPU Hours table - automatically determines both prefill and decode
            2. **Individual**: Select one row in the Prefill table AND one row in the Decode table
            The selection will be processed automatically once complete.

            > 📝 **Note:** The dotted red line in the prefill and decode charts are default TTFT and ITL SLAs if not specified.

            > ⚠️ **Warning:** The TTFT values here represent the ideal case when requests arrive uniformly, minimizing queueing. Real-world TTFT may be higher than profiling results. To mitigate the issue, planner uses [correction factors](https://github.com/ai-dynamo/dynamo/blob/main/docs/planner/sla_planner.md#2-correction-factor-calculation) to adjust dynamically at runtime.

            > 💡 **Tip:** Use the GPU cost checkbox and input in the charts section to convert GPU hours to cost.
            """
        )

        # Status message display for selection feedback
        selection_status = gr.Markdown(
            value="",
            elem_id="selection_status",
        )

        # Performance Results Section (reused from AIC profiling module)
        create_performance_results_section()

        # Handle selection button - the handler's return value is rendered as
        # the status message
        selection_button.click(
            fn=handle_selection,
            inputs=[selection_input],
            outputs=[selection_status],
        )

        # Trigger visualization when JSON data changes (client-side only:
        # no Python callback, just the JS snippet)
        json_data.change(
            fn=None,
            inputs=[json_data],
            outputs=[],
            js=(
                "(data) => { if (data && data.trim() && window.initializeVisualizations) "
                "window.initializeVisualizations(data); }"
            ),
        )

        # Load JavaScript and data automatically on page load
        def load_data():
            """Load profiling data."""
            return json_data_str

        demo.load(
            fn=load_data, inputs=[], outputs=[json_data], js=load_profiling_javascript()
        )

    return demo


def wait_for_selection(demo, selection_queue, port):
    """Launch the demo and wait for user selection.

    Blocks for up to ``WEB_UI_SELECTION_TIMEOUT`` seconds. The demo is closed
    before returning, whether a selection arrived, the wait timed out, or an
    unexpected exception is propagating.

    Args:
        demo: Gradio demo instance
        selection_queue: Queue to receive selection from UI; items are
            ``(prefill_idx, decode_idx)`` tuples
        port: Port number for the WebUI

    Returns:
        tuple[int, int]: (selected_prefill_idx, selected_decode_idx), or
        ``(0, 0)`` as a default when no selection is made before the timeout
    """
    # Reset the selection complete event before the UI becomes reachable, so
    # a selection made right after launch cannot be wiped out by a late clear.
    _selection_complete.clear()

    # Launch the interface in a separate thread
    def launch_thread():
        demo.launch(
            server_name="0.0.0.0",
            server_port=port,
            share=False,
            prevent_thread_lock=True,
        )

    thread = threading.Thread(target=launch_thread, daemon=True)
    thread.start()

    logger.info(f"WebUI launched. Waiting for user selection on http://0.0.0.0:{port}")
    logger.info("Please select a row from the Cost Analysis table")

    # Block and wait for selection
    try:
        selected_prefill_idx, selected_decode_idx = selection_queue.get(
            timeout=WEB_UI_SELECTION_TIMEOUT
        )
        logger.info(
            f"User selected: Prefill={selected_prefill_idx}, Decode={selected_decode_idx}"
        )

        # Wait for the selection handler to complete and give UI time to show success message
        if _selection_complete.wait(timeout=2.0):
            # Give extra time for the UI to display the success message
            time.sleep(1.0)

        return selected_prefill_idx, selected_decode_idx

    except queue.Empty:
        # The message is derived from the constant so it stays accurate if
        # the timeout is ever changed.
        logger.error(
            f"Selection timeout - no selection made within {WEB_UI_SELECTION_TIMEOUT} seconds"
        )
        # Return default
        return 0, 0

    finally:
        # Close the demo gracefully on every exit path
        demo.close()