sla_sweep.py 3.82 KB
Newer Older
1
2
3
4
5
6
7
8
9
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
import os
from abc import ABC, abstractmethod
from dataclasses import dataclass

from typing_extensions import override

SLA_EPS: float = 1.0e-8
"""
Small offset folded into the strict comparisons (`<` and `>`) so that a metric
exactly equal to its target yields a positive margin instead of `0`, keeping
strict and non-strict criteria distinguishable.
"""


@dataclass
class SLACriterionBase(ABC):
    """
    Abstract base for one SLA criterion that compares a metric to `target`.

    Subclasses decide how the margin is measured (`compute_margin`) and how
    the condition is rendered as text (`format_cond`).
    """

    target: float

    @abstractmethod
    def compute_margin(self, actual: float) -> float:
        """
        Measure `actual` against this criterion.

        A result of `0` or below means the criterion holds; a positive result
        is the distance by which `actual` misses the target.
        """
        raise NotImplementedError

    @abstractmethod
    def format_cond(self, lhs: str) -> str:
        """Render this criterion as text, using `lhs` as its left-hand side."""
        raise NotImplementedError

    def print_and_compute_margin(
        self,
        metrics: dict[str, float],
        metrics_key: str,
    ) -> float:
        """
        Evaluate `metrics[metrics_key]` against this criterion.

        Prints a single "Validating SLA: ... | PASSED/FAILED" line and returns
        the margin produced by `compute_margin`.

        Raises:
            KeyError: if `metrics_key` is absent from `metrics`.
        """
        observed = metrics[metrics_key]
        margin = self.compute_margin(observed)

        rendered = self.format_cond(f"{metrics_key} = {observed:.2f}")
        if margin <= 0:
            verdict = "PASSED"
        else:
            verdict = "FAILED"
        print(f"Validating SLA: {rendered} | {verdict}")

        return margin


@dataclass
class SLALessThan(SLACriterionBase):
    """
    SLA criterion `metric < target`.

    Because of the `SLA_EPS` offset, a metric exactly equal to the target gets
    a positive (failing) margin, at least for values small enough that the
    offset is not lost to floating-point rounding.
    """

    @override
    def compute_margin(self, actual: float) -> float:
        # Nudge the observed value up so that equality is not treated as met.
        nudged = actual + SLA_EPS
        return nudged - self.target

    @override
    def format_cond(self, lhs: str) -> str:
        return f"{lhs}<{self.target:.2f}"


@dataclass
class SLALessThanOrEqualTo(SLACriterionBase):
    """SLA criterion `metric <= target`; equality yields a margin of `0`."""

    @override
    def compute_margin(self, actual: float) -> float:
        # Zero exactly at the target, so equality counts as meeting the SLA.
        excess = actual - self.target
        return excess

    @override
    def format_cond(self, lhs: str) -> str:
        return f"{lhs}<={self.target:.2f}"


@dataclass
class SLAGreaterThan(SLACriterionBase):
    """
    SLA criterion `metric > target`.

    Because of the `SLA_EPS` offset, a metric exactly equal to the target gets
    a positive (failing) margin, at least for values small enough that the
    offset is not lost to floating-point rounding.
    """

    @override
    def compute_margin(self, actual: float) -> float:
        # Raise the bar slightly so that equality is not treated as met.
        raised_bar = self.target + SLA_EPS
        return raised_bar - actual

    @override
    def format_cond(self, lhs: str) -> str:
        return f"{lhs}>{self.target:.2f}"


@dataclass
class SLAGreaterThanOrEqualTo(SLACriterionBase):
    """SLA criterion `metric >= target`; equality yields a margin of `0`."""

    @override
    def compute_margin(self, actual: float) -> float:
        # Zero exactly at the target, so equality counts as meeting the SLA.
        shortfall = self.target - actual
        return shortfall

    @override
    def format_cond(self, lhs: str) -> str:
        return f"{lhs}>={self.target:.2f}"


# Maps each comparison operator to the criterion class it constructs.
# NOTE: The ordering is important! `SLASweepItem.from_record` uses the first
# operator that prefixes a constraint string, so longer operators must come
# before their one-character prefixes ("<=" before "<", ">=" before ">").
# Sorting by descending key length enforces this even if operators are added
# out of order. The sort is stable, so equal-length keys keep the order below.
SLA_CRITERIA: dict[str, type[SLACriterionBase]] = dict(
    sorted(
        {
            "<=": SLALessThanOrEqualTo,
            ">=": SLAGreaterThanOrEqualTo,
            "<": SLALessThan,
            ">": SLAGreaterThan,
        }.items(),
        key=lambda item: len(item[0]),
        reverse=True,
    )
)


class SLASweep(list["SLASweepItem"]):
    """An ordered list of `SLASweepItem`s parsed from JSON-style records."""

    @classmethod
    def read_json(cls, filepath: os.PathLike):
        """
        Load an SLA sweep from the JSON file at `filepath`.

        The file must contain a JSON list of objects; see `from_records`.
        """
        with open(filepath, "rb") as f:
            records = json.load(f)

        return cls.from_records(records)

    @classmethod
    def from_records(cls, records: list[dict[str, str]]):
        """
        Build an `SLASweep` with one `SLASweepItem` per record, in order.

        Raises:
            TypeError: if `records` is not a list, or any element is not a dict.
            Errors raised by `SLASweepItem.from_record` propagate unchanged.
        """
        if not isinstance(records, list):
            raise TypeError(
                f"The SLA sweep should be a list of dictionaries, "
                f"but found type: {type(records)}"
            )

        items = []
        for idx, record in enumerate(records):
            # Validate each element up front so a malformed entry yields a
            # clear error instead of an AttributeError inside `from_record`.
            if not isinstance(record, dict):
                raise TypeError(
                    f"The SLA sweep should be a list of dictionaries, "
                    f"but the item at index {idx} has type: {type(record)}"
                )
            items.append(SLASweepItem.from_record(record))

        return cls(items)


class SLASweepItem(dict[str, SLACriterionBase]):
    """One point of an SLA sweep: a mapping from metric key to its criterion."""

    @classmethod
    def from_record(cls, record: dict[str, str]):
        """
        Parse a record such as `{"<metric>": "<=200"}` into an `SLASweepItem`.

        Each value must be a string made of an operator from `SLA_CRITERIA`
        followed by a numeric target. Surrounding whitespace is ignored.

        Raises:
            TypeError: if a constraint value is not a string.
            ValueError: if a constraint does not start with a known operator,
                or the text after the operator is not a valid number.
        """
        sla_criteria: dict[str, SLACriterionBase] = {}

        for metric_key, metric_value in record.items():
            if not isinstance(metric_value, str):
                raise TypeError(
                    f"SLA constraint '{metric_key}' should be a string "
                    f"such as '<=100', but found type: {type(metric_value)}"
                )

            # Tolerate surrounding whitespace, e.g. " <= 100 ".
            constraint = metric_value.strip()

            # Relies on SLA_CRITERIA listing longer operators first, so that
            # "<=" is matched before its prefix "<".
            for op_key, criterion_cls in SLA_CRITERIA.items():
                if not constraint.startswith(op_key):
                    continue

                target_text = constraint.removeprefix(op_key)
                try:
                    target = float(target_text)
                except ValueError as err:
                    raise ValueError(
                        f"Invalid target value for "
                        f"SLA constraint '{metric_key}={metric_value}'. "
                        f"Expected a number after the operator '{op_key}'."
                    ) from err

                sla_criteria[metric_key] = criterion_cls(target)
                break
            else:
                raise ValueError(
                    f"Invalid operator for "
                    f"SLA constraint '{metric_key}={metric_value}'. "
                    f"Valid operators are: {sorted(SLA_CRITERIA)}",
                )

        return cls(sla_criteria)

    def as_text(self, sep: str = ", ") -> str:
        """
        Render each criterion via its `format_cond`, using the metric key as
        the left-hand side, and join the results with `sep`.
        """
        return sep.join(v.format_cond(k) for k, v in self.items())