# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0


def compute_pareto(x, y):
    """
    Compute the pareto front (top-left is better) for the given x and y values.

    "Top-left is better" means smaller x and larger y are preferred. A point
    is dropped when another point is at least as good in both coordinates
    (x <= and y >=) and strictly better in at least one. Among exact
    duplicates, only the first one in input order is kept.

    Points whose x or y value is NaN are ignored. NaN cannot be ordered, so
    keeping it would make the sort inconsistent and could let dominated
    points onto the front.

    Args:
        x: sequence of x values (lower is better).
        y: sequence of y values (higher is better), same length as ``x``.

    Returns:
        tuple: (xs, ys, indices) where:
            - xs: list of x values on the pareto front, strictly ascending
            - ys: list of y values on the pareto front, strictly ascending
            - indices: list of original indices corresponding to the pareto points

    Raises:
        ValueError: if ``x`` and ``y`` have different lengths.
    """
    # Validate inputs
    if x is None or y is None:
        return [], [], []

    if len(x) != len(y):
        raise ValueError("x and y must have the same length")

    # Keep the original indices and drop NaN coordinates (NaN != NaN), because
    # they cannot be ordered and would make the sort below inconsistent.
    # An empty input simply produces an empty list here.
    points = [
        (px, py, idx)
        for idx, (px, py) in enumerate(zip(x, y))
        if px == px and py == py
    ]

    # Sort by x ascending, then y descending, so that among equal x values the
    # best (largest y) point is visited first. list.sort is stable, so exact
    # duplicates keep their input order and the first one wins.
    points.sort(key=lambda p: (p[0], -p[1]))

    # Single sweep: every point already visited has x <= the current x, so the
    # current point is non-dominated iff its y beats the best y seen so far.
    # The last kept point holds that best y. The first visited point is always
    # kept, even when its y is -inf.
    pareto = []
    for px, py, idx in points:
        if not pareto or py > pareto[-1][1]:
            pareto.append((px, py, idx))

    # The sweep visits x in ascending order and never keeps two points with the
    # same x, so the front is already sorted by x (and y) ascending.
    xs = [px for px, _, _ in pareto]
    ys = [py for _, py, _ in pareto]
    indices = [idx for _, _, idx in pareto]
    return xs, ys, indices