utils.py 14.1 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import json
import logging
import queue
import threading
from enum import Enum

import gradio as gr
import numpy as np
from aiconfigurator.webapp.components.profiling import (
    create_performance_results_section,
    create_profiling_ui_components,
    inject_profiling_assets,
    load_profiling_javascript,
)

from benchmarks.profiler.utils.defaults import GPU_COST_PER_HOUR
from benchmarks.profiler.utils.pareto import compute_pareto

logger = logging.getLogger(__name__)


class PlotType(str, Enum):
    """Enum for the three plot/config types in the WebUI.

    Members also subclass ``str``, so each one compares equal to its plain
    string value (for example ``PlotType.COST == "cost"``). This module uses
    the members both as keys into the nested chart/table data dict and in
    comparisons against the ``plotType`` string decoded from a selection
    payload.
    """

    # Prefill results section.
    PREFILL = "prefill"
    # Decode results section.
    DECODE = "decode"
    # Combined prefill + decode cost section.
    COST = "cost"


# Color palette for chart datasets, as "#rrggbb" hex strings.
# The populate_* helpers in this module pick a color with
# ``CHART_COLORS[idx % len(CHART_COLORS)]``, so colors start repeating once
# there are more datasets than palette entries.
# TODO: handle case with more than 8 lines
CHART_COLORS = [
    "#1f77b4",  # blue
    "#ff7f0e",  # orange
    "#2ca02c",  # green
    "#d62728",  # red
    "#9467bd",  # purple
    "#8c564b",  # brown
    "#e377c2",  # pink
    "#7f7f7f",  # gray
]

# How long to block waiting for the user to pick a configuration in the WebUI.
# Passed as the ``timeout`` argument of the selection queue's ``get`` call,
# i.e. the value is in seconds (3600 s = 1 hour).
# TODO: is this too long?
WEB_UI_SELECTION_TIMEOUT = 3600


def populate_prefill_data(data, prefill_data):
    """Fill in the prefill chart labels, chart points and table rows.

    Writes into ``data[PlotType.PREFILL]`` in place and returns ``None``.
    Nothing is written when ``prefill_data.num_gpus`` is empty.

    Args:
        data: Nested dict holding the WebUI payload. The entry under
            ``PlotType.PREFILL`` must already provide ``["chart"]`` with a
            first element in ``["datasets"]``, and ``["table"]``.
        prefill_data: Object exposing the parallel sequences ``num_gpus``,
            ``ttft``, ``thpt_per_gpu`` and ``parallel_mapping_labels``; the
            i-th element of each describes the i-th profiled configuration.
    """
    gpu_counts = prefill_data.num_gpus
    if not gpu_counts:
        return

    # One label per distinct GPU count, in ascending order.
    labels = [f"{count} GPUs" for count in sorted(set(gpu_counts))]
    section = data[PlotType.PREFILL]
    section["chart"]["labels"] = labels

    # Materialize the per-configuration rows once; both the chart and the
    # table are derived from the same sequence, and the position in it is the
    # index that ties a chart point to its table row.
    rows = list(
        zip(
            gpu_counts,
            prefill_data.ttft,
            prefill_data.thpt_per_gpu,
            prefill_data.parallel_mapping_labels,
        )
    )

    section["chart"]["datasets"][0]["data"] = [
        {
            "x": round(latency, 2),
            "y": round(throughput, 2),
            "gpu": count,
            "tableIdx": row,
            "gpuLabel": f"{count} GPUs [{mapping}]",
        }
        for row, (count, latency, throughput, mapping) in enumerate(rows)
    ]

    # TODO: Add actual config YAML data
    section["table"]["data"] = [
        [
            count,
            round(latency, 2),
            round(throughput, 2),
            f"prefill_config_{row}.yaml",
        ]
        for row, (count, latency, throughput, _mapping) in enumerate(rows)
    ]


def populate_decode_data(data, decode_data):
    """Fill in the decode chart datasets and table rows.

    Writes into ``data[PlotType.DECODE]`` in place and returns ``None``.
    Nothing is written when ``decode_data.num_gpus`` is empty.

    Args:
        data: Nested dict holding the WebUI payload. The entry under
            ``PlotType.DECODE`` must already provide ``["chart"]`` and
            ``["table"]``.
        decode_data: Object exposing the parallel sequences ``num_gpus``,
            ``itl``, ``thpt_per_gpu`` and ``parallel_mapping_labels``; the
            i-th element of each describes the i-th profiled configuration.
    """
    gpu_counts = decode_data.num_gpus
    if not gpu_counts:
        return

    samples = list(
        zip(
            gpu_counts,
            decode_data.itl,
            decode_data.thpt_per_gpu,
            decode_data.parallel_mapping_labels,
        )
    )

    # Bucket the chart points by GPU count; every bucket becomes one dataset.
    # "tableIdx" is the position in the original (ungrouped) sequence so a
    # chart point can be matched to its table row.
    points_by_gpu: dict[int, list[dict[str, float | int]]] = {}
    for row, (count, latency, throughput, _mapping) in enumerate(samples):
        points_by_gpu.setdefault(count, []).append(
            {"x": round(latency, 2), "y": round(throughput, 2), "tableIdx": row}
        )

    # One dataset per GPU count in ascending order; the palette wraps around
    # when there are more GPU counts than colors.
    series = []
    for slot, count in enumerate(sorted(points_by_gpu)):
        shade = CHART_COLORS[slot % len(CHART_COLORS)]
        series.append(
            {
                "label": f"{count} GPUs",
                "data": points_by_gpu[count],
                "backgroundColor": shade,
                "borderColor": shade,
            }
        )
    data[PlotType.DECODE]["chart"]["datasets"] = series

    # Table rows keep the original ordering (not grouped by GPU count).
    data[PlotType.DECODE]["table"]["data"] = [
        [
            count,
            round(latency, 2),
            round(throughput, 2),
            f"decode_config_{row}.yaml",
        ]
        for row, (count, latency, throughput, _mapping) in enumerate(samples)
    ]


def populate_cost_data(data, prefill_data, decode_data, args):
    """Populate cost chart and table data with pareto-optimal configurations.

    Every pareto-optimal prefill configuration is paired with every
    pareto-optimal decode configuration. Each prefill configuration becomes
    one chart dataset (one line); each pairing becomes one chart point and one
    table row, both tagged with the same running ``tableIdx``.

    Writes into ``data[PlotType.COST]`` in place and returns ``None``:

    * ``["chart"]["datasets"]``: one dataset per prefill configuration
    * ``["table"]["data"]``: one row per (prefill, decode) pairing
    * ``["index_mapping"]``: ``{str(table_idx): [prefill_idx, decode_idx]}``
      where the indices point into the original (non-pareto) sequences of
      ``prefill_data`` / ``decode_data``

    Nothing is written when either ``num_gpus`` sequence is empty.

    Args:
        data: Nested dict holding the WebUI payload; the entry under
            ``PlotType.COST`` must already provide ``["chart"]`` and
            ``["table"]``.
        prefill_data: Object exposing ``num_gpus``, ``ttft`` and
            ``thpt_per_gpu``.
        decode_data: Object exposing ``num_gpus``, ``itl`` and
            ``thpt_per_gpu``.
        args: Object exposing ``isl`` and ``osl``, used as multipliers in the
            prefill and decode cost terms respectively.
    """
    if not prefill_data.num_gpus or not decode_data.num_gpus:
        return

    # Compute pareto front for prefill (minimize TTFT, maximize throughput)
    p_ttft, p_thpt, prefill_pareto_indices = compute_pareto(
        prefill_data.ttft, prefill_data.thpt_per_gpu
    )

    # Compute pareto front for decode (minimize ITL, maximize throughput)
    d_itl, d_thpt, decode_pareto_indices = compute_pareto(
        decode_data.itl, decode_data.thpt_per_gpu
    )

    # Convert to numpy arrays
    p_ttft = np.array(p_ttft)
    p_thpt = np.array(p_thpt)
    d_itl = np.array(d_itl)
    d_thpt = np.array(d_thpt)

    # Everything below depends only on the decode configuration, so compute it
    # once here instead of once per prefill line.
    # NOTE(review): the int() coercion assumes compute_pareto may hand back
    # NumPy integers, which the json module cannot serialize; it is a no-op
    # for plain Python ints.
    decode_points = [
        (
            _d_itl,
            _d_thpt,
            # Decode cost
            args.osl * 1000 / _d_thpt * GPU_COST_PER_HOUR / 3600,
            # X-axis: tokens per user (based on ITL)
            1000 / _d_itl,
            int(decode_pareto_indices[d_idx]),
        )
        for d_idx, (_d_itl, _d_thpt) in enumerate(zip(d_itl, d_thpt))
    ]

    # Generate cost datasets - one line per prefill config
    cost_datasets = []
    table_data = []
    cost_index_mapping = {}  # Map cost table row idx -> (prefill_idx, decode_idx)
    table_idx = 0

    for p_idx, (_p_ttft, _p_thpt) in enumerate(zip(p_ttft, p_thpt)):
        # Prefill cost and original index are fixed for this line
        prefill_cost = args.isl * 1000 / _p_thpt * GPU_COST_PER_HOUR / 3600
        orig_prefill_idx = int(prefill_pareto_indices[p_idx])

        line_data = []
        for (
            _d_itl,
            _d_thpt,
            decode_cost,
            tokens_per_user,
            orig_decode_idx,
        ) in decode_points:
            total_cost = prefill_cost + decode_cost

            line_data.append(
                {
                    "x": round(tokens_per_user, 2),
                    "y": round(total_cost, 2),
                    "tableIdx": table_idx,
                }
            )

            # Store mapping from cost table row to original indices
            cost_index_mapping[table_idx] = (orig_prefill_idx, orig_decode_idx)

            # Add to table data
            table_data.append(
                [
                    round(_p_ttft, 2),
                    round(_p_thpt, 2),
                    round(_d_itl, 2),
                    round(_d_thpt, 2),
                    round(tokens_per_user, 2),
                    round(total_cost, 2),
                    f"cost_config_{table_idx}.yaml",  # TODO: Add actual config
                ]
            )
            table_idx += 1

        # Create dataset for this prefill config; the palette wraps around
        # when there are more prefill configs than colors.
        color = CHART_COLORS[p_idx % len(CHART_COLORS)]
        cost_datasets.append(
            {
                "label": f"TTFT: {_p_ttft:.2f}ms",
                "data": line_data,
                "backgroundColor": color,
                "borderColor": color,
            }
        )

    data[PlotType.COST]["chart"]["datasets"] = cost_datasets
    data[PlotType.COST]["table"]["data"] = table_data

    # Store the index mapping in the JSON for reference (JSON object keys
    # must be strings, hence str(k)).
    data[PlotType.COST]["index_mapping"] = {
        str(k): list(v) for k, v in cost_index_mapping.items()
    }


def create_selection_handler(
    data_dict, selection_queue, prefill_selection, decode_selection
):
    """Create a selection handler closure for the WebUI.

    The returned callback understands three kinds of selection, distinguished
    by the ``plotType`` field of the JSON payload:

    * ``cost``: the row is resolved through
      ``data_dict[PlotType.COST]["index_mapping"]`` to a
      ``(prefill_idx, decode_idx)`` pair, which is queued immediately.
    * ``prefill`` / ``decode``: the row index is remembered in
      ``prefill_selection["idx"]`` / ``decode_selection["idx"]``; the pair is
      queued as soon as both are set.

    Malformed or incomplete payloads are logged and ignored; the callback
    never raises for them and never queues a pair containing ``None``.

    Args:
        data_dict: Parsed JSON data containing cost index mapping
        selection_queue: Queue to communicate selections to main thread
        prefill_selection: Dict tracking prefill selection state (key ``"idx"``)
        decode_selection: Dict tracking decode selection state (key ``"idx"``)

    Returns:
        Callable: Selection handler function for Gradio
    """

    def submit_if_complete():
        """Queue the (prefill, decode) pair once both halves are selected."""
        prefill_idx = prefill_selection["idx"]
        decode_idx = decode_selection["idx"]
        if prefill_idx is None or decode_idx is None:
            # Covers both "other table not clicked yet" and a payload that
            # carried no rowIndex: never submit a pair containing None.
            missing = "prefill" if prefill_idx is None else "decode"
            logger.info(f"Waiting for {missing} selection...")
            return
        logger.info(
            f"Both selections complete: Prefill={prefill_idx}, Decode={decode_idx}"
        )
        selection_queue.put((prefill_idx, decode_idx))

    def handle_selection(selection_json):
        """Handle datapoint selection from table.

        Args:
            selection_json: JSON string expected to decode to an object with
                ``plotType`` and ``rowIndex`` fields. Empty or blank input is
                ignored.
        """
        if not selection_json or selection_json.strip() == "":
            return

        try:
            selection = json.loads(selection_json)
            if not isinstance(selection, dict):
                logger.warning(
                    f"Ignoring malformed selection payload: {selection_json}"
                )
                return

            plot_type = selection.get("plotType")
            row_idx = selection.get("rowIndex")

            logger.info(f"Selection received: {plot_type}, row {row_idx}")

            if plot_type == PlotType.COST:
                # Cost selection - use index mapping to get original indices.
                # Mapping keys are strings because they come from JSON.
                cost_index_mapping = data_dict[PlotType.COST].get("index_mapping", {})
                mapping_entry = cost_index_mapping.get(str(row_idx))
                if not mapping_entry:
                    logger.warning(
                        f"No cost index mapping for row {row_idx}; selection ignored"
                    )
                    return

                prefill_idx, decode_idx = mapping_entry
                if prefill_idx is None or decode_idx is None:
                    logger.warning(
                        f"Incomplete cost index mapping for row {row_idx}; selection ignored"
                    )
                    return

                logger.info(
                    f"Cost selection determines: Prefill={prefill_idx}, Decode={decode_idx}"
                )
                # Auto-submit for cost selection
                selection_queue.put((prefill_idx, decode_idx))
            elif plot_type == PlotType.PREFILL:
                prefill_selection["idx"] = row_idx
                logger.info(f"Prefill selected: {row_idx}")
                submit_if_complete()
            elif plot_type == PlotType.DECODE:
                decode_selection["idx"] = row_idx
                logger.info(f"Decode selected: {row_idx}")
                submit_if_complete()
            else:
                logger.warning(
                    f"Ignoring selection with unknown plot type: {plot_type}"
                )

        except Exception as e:
            # This runs as a UI callback, so a bad payload must not propagate;
            # log with the traceback so the failure is still diagnosable.
            logger.exception(f"Error handling selection: {e}")

    return handle_selection


def create_gradio_interface(json_data_str, handle_selection):
    """Create the Gradio interface for configuration selection.

    Builds a ``gr.Blocks`` page from the shared profiling components and wires
    three events:

    * clicking the selection button calls ``handle_selection`` with the
      current value of the selection input;
    * a change of the JSON data component runs a client-side snippet that
      calls ``window.initializeVisualizations`` (no Python function involved);
    * page load runs the profiling JavaScript and pushes ``json_data_str``
      into the JSON data component.

    Args:
        json_data_str: JSON string containing profiling data
        handle_selection: Selection handler function

    Returns:
        gr.Blocks: Configured Gradio demo
    """
    with gr.Blocks(title="Configuration Selection") as demo:
        # Create hidden UI components (reused from AIC profiling module)
        ui_components = create_profiling_ui_components()
        selection_input = ui_components["selection_input"]
        selection_button = ui_components["selection_button"]
        json_data = ui_components["json_data"]

        # Inject CSS and modal (reused from AIC profiling module)
        inject_profiling_assets()

        gr.Markdown("# 📊 Profiling Results - Select Configuration")
        gr.Markdown(
            """
            **Two ways to select prefill and decode configs:**
            1. **Cost Analysis** (recommended): Click any row in the Cost Analysis table - automatically determines both prefill and decode
            2. **Individual**: Click one row in the Prefill table AND one row in the Decode table
            The selection will be processed automatically once complete.

            > 📝 **Note:** The dotted red line in the prefill and decode charts are default TTFT and ITL SLAs if not specified.

            > ⚠️ **Warning:** The TTFT values here represent the ideal case when requests arrive uniformly, minimizing queueing. Real-world TTFT may be higher than profiling results. To mitigate the issue, planner uses [correction factors](https://github.com/ai-dynamo/dynamo/blob/main/docs/planner/sla_planner.md#2-correction-factor-calculation) to adjust dynamically at runtime.
            """
        )

        # Performance Results Section (reused from AIC profiling module)
        create_performance_results_section()

        # Handle selection button
        selection_button.click(
            fn=handle_selection,
            inputs=[selection_input],
            outputs=[],
        )

        # Trigger visualization when JSON data changes. fn=None means only the
        # JavaScript snippet runs; the guard skips empty data and the case
        # where the visualization script has not been loaded yet.
        json_data.change(
            fn=None,
            inputs=[json_data],
            outputs=[],
            js=(
                "(data) => { if (data && data.trim() && window.initializeVisualizations) "
                "window.initializeVisualizations(data); }"
            ),
        )

        # Load JavaScript and data automatically on page load
        def load_data():
            """Load profiling data."""
            return json_data_str

        demo.load(
            fn=load_data, inputs=[], outputs=[json_data], js=load_profiling_javascript()
        )

    return demo


def wait_for_selection(demo, selection_queue, port):
    """Launch the demo and wait for user selection.

    Starts the Gradio server from a daemon thread, then blocks the calling
    thread on ``selection_queue`` for at most ``WEB_UI_SELECTION_TIMEOUT``
    seconds. The demo is closed before returning, whether a selection
    arrived, the wait timed out, or the wait was interrupted by another
    exception.

    Args:
        demo: Gradio demo instance
        selection_queue: Queue to receive selection from UI; items are
            ``(prefill_idx, decode_idx)`` pairs
        port: Port number for the WebUI

    Returns:
        tuple[int, int]: (selected_prefill_idx, selected_decode_idx), or
        ``(0, 0)`` as the default when no selection is made before the timeout
    """

    # Launch the interface in a separate thread
    def launch_thread():
        demo.launch(
            server_name="0.0.0.0",
            server_port=port,
            share=False,
            prevent_thread_lock=True,
        )

    thread = threading.Thread(target=launch_thread, daemon=True)
    thread.start()

    logger.info(f"WebUI launched. Waiting for user selection on http://0.0.0.0:{port}")
    logger.info("Please select a row from the Cost Analysis table")

    # Block and wait for selection
    try:
        selected_prefill_idx, selected_decode_idx = selection_queue.get(
            timeout=WEB_UI_SELECTION_TIMEOUT
        )
    except queue.Empty:
        # Message is derived from the constant so it cannot go stale when the
        # timeout is tuned.
        logger.error(
            f"Selection timeout - no selection made within {WEB_UI_SELECTION_TIMEOUT} seconds"
        )
        # Return default
        return 0, 0
    else:
        logger.info(
            f"User selected: Prefill={selected_prefill_idx}, Decode={selected_decode_idx}"
        )
        return selected_prefill_idx, selected_decode_idx
    finally:
        # Close the demo on every exit path, including unexpected exceptions
        # such as KeyboardInterrupt while waiting.
        demo.close()