"examples/vscode:/vscode.git/clone" did not exist on "ca5b681a1c1074f2121336a61ab0d5ca4fa47bf3"
Unverified Commit 0d5c8dfc authored by Hongkuan Zhou, committed by GitHub
Browse files

feat: add kalman filter as load predictor (#5554)


Signed-off-by: hongkuanz <hongkuanz@nvidia.com>
parent cb352d95
...@@ -63,12 +63,21 @@ class SLAPlannerDefaults(BasePlannerDefaults): ...@@ -63,12 +63,21 @@ class SLAPlannerDefaults(BasePlannerDefaults):
"http://prometheus-kube-prometheus-prometheus.monitoring.svc.cluster.local:9090", "http://prometheus-kube-prometheus-prometheus.monitoring.svc.cluster.local:9090",
) )
profile_results_dir = "profiling_results" profile_results_dir = "profiling_results"
isl = 3000 # in number of tokens isl = 3000 # in number of tokens
osl = 150 # in number of tokens osl = 150 # in number of tokens
ttft = 500.0 # in milliseconds ttft = 500.0 # in milliseconds
itl = 50.0 # in milliseconds itl = 50.0 # in milliseconds
load_predictor = "arima" # ["constant", "arima", "prophet"]
load_prediction_window_size = 50 # predict load using how many recent load samples # for load predictor
load_predictor = "arima" # ["constant", "arima", "kalman", "prophet"]
prophet_window_size = 50
load_predictor_log1p = False
kalman_q_level = 1.0
kalman_q_trend = 0.1
kalman_r = 10.0
kalman_min_points = 5
no_correction = False # disable correction factor, might be useful under some conditions like long cold start time no_correction = False # disable correction factor, might be useful under some conditions like long cold start time
......
...@@ -17,11 +17,14 @@ import logging ...@@ -17,11 +17,14 @@ import logging
import math import math
import warnings import warnings
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from argparse import Namespace
from datetime import datetime, timedelta from datetime import datetime, timedelta
from enum import Enum from enum import Enum
import numpy as np
import pandas as pd import pandas as pd
import pmdarima import pmdarima
from filterpy.kalman import KalmanFilter
from prophet import Prophet from prophet import Prophet
from dynamo.runtime.logging import configure_dynamo_logging from dynamo.runtime.logging import configure_dynamo_logging
...@@ -35,6 +38,19 @@ warnings.filterwarnings( ...@@ -35,6 +38,19 @@ warnings.filterwarnings(
message=".*force_all_finite.*", message=".*force_all_finite.*",
) )
# Prophet and its cmdstanpy backend log heavily. For each of their loggers:
# raise the threshold to WARNING, stop records from propagating to ancestor
# loggers, and attach a NullHandler. Planner logs themselves stay at INFO.
for _name in (
    "prophet",
    "prophet.forecaster",
    "prophet.models",
    "cmdstanpy",
    "cmdstanpy.model",
):
    _l = logging.getLogger(_name)
    _l.setLevel(logging.WARNING)
    _l.propagate = False
    _l.addHandler(logging.NullHandler())
class BasePredictor(ABC): class BasePredictor(ABC):
"""Base class for all load predictors""" """Base class for all load predictors"""
...@@ -83,7 +99,7 @@ class ConstantPredictor(BasePredictor): ...@@ -83,7 +99,7 @@ class ConstantPredictor(BasePredictor):
Assume load is constant and predict the next load to be the same as most recent load Assume load is constant and predict the next load to be the same as most recent load
""" """
def __init__(self, **kwargs): def __init__(self, _args: Namespace):
super().__init__(minimum_data_points=1) super().__init__(minimum_data_points=1)
def predict_next(self): def predict_next(self):
...@@ -96,16 +112,19 @@ class ARIMAPredictor(BasePredictor): ...@@ -96,16 +112,19 @@ class ARIMAPredictor(BasePredictor):
RAW = "raw" RAW = "raw"
LOG1P = "log1p" LOG1P = "log1p"
def __init__(self, window_size=100, minimum_data_points=5): def __init__(self, args: Namespace):
super().__init__(minimum_data_points=minimum_data_points) super().__init__(minimum_data_points=5)
self.window_size = window_size # How many past points to use
self.model = None self.model = None
# Keep raw values so we can fit in raw space first, then fallback to log1p space. # Keep raw values so we can fit in raw space first, then fallback to log1p space.
self._raw_buffer: list[float] = [] self._raw_buffer: list[float] = []
# Pending raw points to incrementally update the fitted model with. # Pending raw points to incrementally update the fitted model with.
self._pending_raw_updates: list[float] = [] self._pending_raw_updates: list[float] = []
# Current modeling space # Shared log1p knob across predictors. Back-compat: `--arima-mode=log1p`.
self._mode: ARIMAPredictor.Mode = ARIMAPredictor.Mode.RAW use_log1p = bool(getattr(args, "load_predictor_log1p", False))
self._requested_mode = (
ARIMAPredictor.Mode.LOG1P if use_log1p else ARIMAPredictor.Mode.RAW
)
self._mode: ARIMAPredictor.Mode = self._requested_mode
def get_last_value(self): def get_last_value(self):
"""Return last value in original scale.""" """Return last value in original scale."""
...@@ -123,14 +142,9 @@ class ARIMAPredictor(BasePredictor): ...@@ -123,14 +142,9 @@ class ARIMAPredictor(BasePredictor):
raw = max(0.0, float(self.data_buffer[-1])) raw = max(0.0, float(self.data_buffer[-1]))
self._raw_buffer.append(raw) self._raw_buffer.append(raw)
self._pending_raw_updates.append(raw) self._pending_raw_updates.append(raw)
# If we are in log1p mode, keep data_buffer in model space. # Keep `data_buffer` in the model space.
if self._mode == ARIMAPredictor.Mode.LOG1P: if self._mode == ARIMAPredictor.Mode.LOG1P:
self.data_buffer[-1] = math.log1p(raw) self.data_buffer[-1] = math.log1p(raw)
# Keep only the last window_size points
if len(self.data_buffer) > self.window_size:
self.data_buffer = self.data_buffer[-self.window_size :]
if len(self._raw_buffer) > self.window_size:
self._raw_buffer = self._raw_buffer[-self.window_size :]
def predict_next(self): def predict_next(self):
"""Predict the next value(s)""" """Predict the next value(s)"""
...@@ -145,8 +159,11 @@ class ARIMAPredictor(BasePredictor): ...@@ -145,8 +159,11 @@ class ARIMAPredictor(BasePredictor):
try: try:
# Fit auto ARIMA model once, then only do incremental updates. # Fit auto ARIMA model once, then only do incremental updates.
if self.model is None: if self.model is None:
# Always try raw space first # First fit: honor requested mode
self._mode = ARIMAPredictor.Mode.RAW self._mode = self._requested_mode
if self._mode == ARIMAPredictor.Mode.LOG1P:
# Ensure model buffer is in log-space
self.data_buffer = [math.log1p(v) for v in self._raw_buffer]
self.model = pmdarima.auto_arima( self.model = pmdarima.auto_arima(
self.data_buffer, self.data_buffer,
suppress_warnings=True, suppress_warnings=True,
...@@ -163,11 +180,15 @@ class ARIMAPredictor(BasePredictor): ...@@ -163,11 +180,15 @@ class ARIMAPredictor(BasePredictor):
f"ARIMA selected order={order} seasonal_order={seasonal_order} aic={aic}" f"ARIMA selected order={order} seasonal_order={seasonal_order} aic={aic}"
) )
# If raw collapses to (0,d,0), fallback to log1p(y) # If user requested raw and it collapses to (0,d,0), fallback to log1p(y)
try: try:
if order is not None and len(order) == 3: if order is not None and len(order) == 3:
p, _, q = order p, _, q = order
if p == 0 and q == 0: if (
p == 0
and q == 0
and self._requested_mode == ARIMAPredictor.Mode.RAW
):
# Build log buffer/model in locals and only swap on success # Build log buffer/model in locals and only swap on success
log_buffer = [math.log1p(v) for v in self._raw_buffer] log_buffer = [math.log1p(v) for v in self._raw_buffer]
log_model = pmdarima.auto_arima( log_model = pmdarima.auto_arima(
...@@ -227,11 +248,18 @@ class ARIMAPredictor(BasePredictor): ...@@ -227,11 +248,18 @@ class ARIMAPredictor(BasePredictor):
# Time-series forecasting model from Meta # Time-series forecasting model from Meta
class ProphetPredictor(BasePredictor): class ProphetPredictor(BasePredictor):
def __init__(self, window_size=100, step_size=3600, minimum_data_points=5): def __init__(self, args: Namespace):
super().__init__(minimum_data_points=minimum_data_points) super().__init__(minimum_data_points=5)
self.window_size = window_size self._use_log1p = bool(getattr(args, "load_predictor_log1p", False))
# Window size is only used by Prophet (to bound refit cost).
self.window_size = getattr(
args,
"prophet_window_size",
getattr(args, "load_prediction_window_size", 50),
)
self.curr_step = 0 self.curr_step = 0
self.step_size = step_size # Use adjustment_interval as step size (seconds per observation)
self.step_size = getattr(args, "adjustment_interval", 3600)
self.start_date = datetime(2024, 1, 1) # Base date for generating timestamps self.start_date = datetime(2024, 1, 1) # Base date for generating timestamps
self.data_buffer = [] # Override to store dicts instead of values self.data_buffer = [] # Override to store dicts instead of values
self._seen_nonzero_since_idle_reset = False self._seen_nonzero_since_idle_reset = False
...@@ -239,7 +267,7 @@ class ProphetPredictor(BasePredictor): ...@@ -239,7 +267,7 @@ class ProphetPredictor(BasePredictor):
def add_data_point(self, value): def add_data_point(self, value):
"""Add new data point to the buffer""" """Add new data point to the buffer"""
# Use proper datetime for Prophet # Use proper datetime for Prophet
timestamp = self.start_date + timedelta(seconds=self.curr_step) timestamp = self.start_date + timedelta(seconds=self.curr_step * self.step_size)
value = 0 if math.isnan(value) else value value = 0 if math.isnan(value) else value
if value == 0 and not self._seen_nonzero_since_idle_reset: if value == 0 and not self._seen_nonzero_since_idle_reset:
...@@ -249,6 +277,8 @@ class ProphetPredictor(BasePredictor): ...@@ -249,6 +277,8 @@ class ProphetPredictor(BasePredictor):
if value != 0: if value != 0:
self._seen_nonzero_since_idle_reset = True self._seen_nonzero_since_idle_reset = True
if self._use_log1p:
value = math.log1p(max(0.0, value))
self.data_buffer.append({"ds": timestamp, "y": value}) self.data_buffer.append({"ds": timestamp, "y": value})
self.curr_step += 1 self.curr_step += 1
...@@ -260,7 +290,8 @@ class ProphetPredictor(BasePredictor): ...@@ -260,7 +290,8 @@ class ProphetPredictor(BasePredictor):
"""Get the last value from the buffer""" """Get the last value from the buffer"""
if not self.data_buffer: if not self.data_buffer:
return 0 return 0
return self.data_buffer[-1]["y"] y = float(self.data_buffer[-1]["y"])
return max(0.0, math.expm1(y)) if self._use_log1p else y
def predict_next(self): def predict_next(self):
"""Predict the next value""" """Predict the next value"""
...@@ -282,11 +313,93 @@ class ProphetPredictor(BasePredictor): ...@@ -282,11 +313,93 @@ class ProphetPredictor(BasePredictor):
# Make prediction # Make prediction
forecast = model.predict(future_df) forecast = model.predict(future_df)
return forecast["yhat"].iloc[0] yhat = float(forecast["yhat"].iloc[0])
return max(0.0, math.expm1(yhat)) if self._use_log1p else yhat
class KalmanPredictor(BasePredictor):
    """
    Online one-step-ahead load predictor backed by a 2-state Kalman filter.

    The filter tracks a local linear trend, with state ``[level, trend]``:
        level_t = level_{t-1} + trend_{t-1} + w
        trend_t = trend_{t-1} + u
    Only the level is observed. Each accepted observation updates the filter
    immediately, and ``predict_next`` returns the predicted level for the next
    step. Intended for low-latency smoothing and short-horizon forecasting.
    """

    def __init__(self, args: Namespace):
        """
        Build the filter from planner arguments.

        Attributes read from ``args`` (each falls back to a default if absent):
            kalman_min_points (5): passed to the base class as
                ``minimum_data_points``.
            load_predictor_log1p / kalman_log1p (False): if either is truthy,
                the filter models ``log1p(y)`` instead of ``y``.
            kalman_q_level (1.0), kalman_q_trend (0.1): process-noise
                variances for level and trend, floored at 1e-12.
            kalman_r (10.0): measurement-noise variance, floored at 1e-9.
        """
        super().__init__(minimum_data_points=getattr(args, "kalman_min_points", 5))

        # The shared flag and the older Kalman-specific flag both enable log space.
        shared_log1p = bool(getattr(args, "load_predictor_log1p", False))
        self._use_log1p = shared_log1p or bool(getattr(args, "kalman_log1p", False))

        level_noise = getattr(args, "kalman_q_level", 1.0)
        trend_noise = getattr(args, "kalman_q_trend", 0.1)
        measurement_noise = getattr(args, "kalman_r", 10.0)

        kf = KalmanFilter(dim_x=2, dim_z=1)
        # State vector is [level, trend]; start both at zero.
        kf.x = np.array([[0.0], [0.0]], dtype=float)
        # Transition: the level advances by the trend, the trend persists.
        kf.F = np.array([[1.0, 1.0], [0.0, 1.0]], dtype=float)
        # Measurement: only the level is observed.
        kf.H = np.array([[1.0, 0.0]], dtype=float)
        # Scale up the initial state covariance.
        kf.P *= 1000.0
        # Floor the noise terms so they stay strictly positive.
        kf.R = np.array([[max(1e-9, float(measurement_noise))]], dtype=float)
        kf.Q = np.diag(
            [
                max(1e-12, float(level_noise)),
                max(1e-12, float(trend_noise)),
            ]
        )
        self._kf = kf

        self._initialized = False
        # predict_next() may be called several times per interval. The forecast
        # is cached so the filter advances at most once per observation.
        self._has_cached_pred = False
        self._cached_pred: float = 0.0

    def add_data_point(self, value):
        """Record ``value`` via the base class and fold it into the filter."""
        size_before = len(self.data_buffer)
        super().add_data_point(value)
        if len(self.data_buffer) == size_before:
            # The buffer did not grow, so there is nothing new to feed the filter.
            return

        observation = float(self.data_buffer[-1])
        if self._use_log1p:
            observation = math.log1p(max(0.0, observation))

        if self._initialized:
            # Advance the filter unless predict_next() already did so this step.
            if not self._has_cached_pred:
                self._kf.predict()
            self._kf.update(np.array([[observation]], dtype=float))
        else:
            # First observation seeds the level; the trend starts at zero.
            self._kf.x = np.array([[observation], [0.0]], dtype=float)
            self._initialized = True

        # This step is consumed; the next predict_next() must recompute.
        self._has_cached_pred = False

    def predict_next(self):
        """
        Return the one-step-ahead forecast in the original scale.

        Before the filter has been seeded, this falls back to
        ``get_last_value()``. Repeated calls between observations return the
        cached forecast without advancing the filter again.
        """
        if not self._initialized:
            return self.get_last_value()

        if not self._has_cached_pred:
            # Advance the filter one step and remember the predicted level.
            self._kf.predict()
            self._cached_pred = float(self._kf.x[0][0])
            self._has_cached_pred = True

        if self._use_log1p:
            # Map back from log space, clamped at zero.
            return max(0.0, math.expm1(self._cached_pred))
        return self._cached_pred
LOAD_PREDICTORS = { LOAD_PREDICTORS = {
"constant": ConstantPredictor, "constant": ConstantPredictor,
"arima": ARIMAPredictor, "arima": ARIMAPredictor,
"kalman": KalmanPredictor,
"prophet": ProphetPredictor, "prophet": ProphetPredictor,
} }
...@@ -101,13 +101,19 @@ def create_sla_planner_parser() -> argparse.ArgumentParser: ...@@ -101,13 +101,19 @@ def create_sla_planner_parser() -> argparse.ArgumentParser:
parser.add_argument( parser.add_argument(
"--load-predictor", "--load-predictor",
default=SLAPlannerDefaults.load_predictor, default=SLAPlannerDefaults.load_predictor,
help="Load predictor type", help="Load predictor type (constant, arima, kalman, prophet)",
) )
parser.add_argument( parser.add_argument(
"--load-prediction-window-size", "--load-predictor-log1p",
action="store_true",
default=SLAPlannerDefaults.load_predictor_log1p,
help="Model log1p(y) instead of y in the selected load predictor (ARIMA/Kalman/Prophet)",
)
parser.add_argument(
"--prophet-window-size",
type=int, type=int,
default=SLAPlannerDefaults.load_prediction_window_size, default=SLAPlannerDefaults.prophet_window_size,
help="Load prediction window size", help="Prophet history window size",
) )
parser.add_argument( parser.add_argument(
"--load-predictor-warmup-trace", "--load-predictor-warmup-trace",
...@@ -115,6 +121,30 @@ def create_sla_planner_parser() -> argparse.ArgumentParser: ...@@ -115,6 +121,30 @@ def create_sla_planner_parser() -> argparse.ArgumentParser:
default=None, default=None,
help="Optional path to a mooncake-style JSONL trace file used to warm up load predictors before observing live traffic", help="Optional path to a mooncake-style JSONL trace file used to warm up load predictors before observing live traffic",
) )
parser.add_argument(
"--kalman-q-level",
type=float,
default=SLAPlannerDefaults.kalman_q_level,
help="Kalman process noise for level (higher = more responsive)",
)
parser.add_argument(
"--kalman-q-trend",
type=float,
default=SLAPlannerDefaults.kalman_q_trend,
help="Kalman process noise for trend (higher = faster trend changes)",
)
parser.add_argument(
"--kalman-r",
type=float,
default=SLAPlannerDefaults.kalman_r,
help="Kalman measurement noise (lower = remember less / react more to new measurements)",
)
parser.add_argument(
"--kalman-min-points",
type=int,
default=SLAPlannerDefaults.kalman_min_points,
help="Minimum number of points before Kalman predictor returns forecasts",
)
parser.add_argument( parser.add_argument(
"--metric-pulling-prometheus-endpoint", "--metric-pulling-prometheus-endpoint",
type=str, type=str,
......
...@@ -155,15 +155,11 @@ class Planner: ...@@ -155,15 +155,11 @@ class Planner:
args.namespace, args.namespace,
) )
self.num_req_predictor = LOAD_PREDICTORS[args.load_predictor]( predictor_cls = LOAD_PREDICTORS[args.load_predictor]
window_size=args.load_prediction_window_size, # Predictors read configuration from `args` directly.
) self.num_req_predictor = predictor_cls(args)
self.isl_predictor = LOAD_PREDICTORS[args.load_predictor]( self.isl_predictor = predictor_cls(args)
window_size=args.load_prediction_window_size, self.osl_predictor = predictor_cls(args)
)
self.osl_predictor = LOAD_PREDICTORS[args.load_predictor](
window_size=args.load_prediction_window_size,
)
# Optional warmup: preload predictors with historical observations from a # Optional warmup: preload predictors with historical observations from a
# mooncake-style JSONL trace (request_count/avg_isl/avg_osl per interval). # mooncake-style JSONL trace (request_count/avg_isl/avg_osl per interval).
......
...@@ -17,6 +17,7 @@ aiofiles ...@@ -17,6 +17,7 @@ aiofiles
aiperf @ git+https://github.com/ai-dynamo/aiperf.git@54cd6dc820bff8bfebc875da104e59d745e14f75 aiperf @ git+https://github.com/ai-dynamo/aiperf.git@54cd6dc820bff8bfebc875da104e59d745e14f75
av==15.0.0 av==15.0.0
fastapi==0.120.1 fastapi==0.120.1
filterpy==1.4.5
ftfy==6.3.1 ftfy==6.3.1
genai-perf==0.0.15 genai-perf==0.0.15
grpcio-tools<=1.76.0 # May have platform-specific builds grpcio-tools<=1.76.0 # May have platform-specific builds
......
...@@ -34,7 +34,7 @@ flowchart LR ...@@ -34,7 +34,7 @@ flowchart LR
## Features ## Features
* **SLA-driven scaling**: Automatically scales prefill/decode workers to meet TTFT and ITL targets * **SLA-driven scaling**: Automatically scales prefill/decode workers to meet TTFT and ITL targets
* **Predictive load forecasting**: Uses ARIMA, Prophet, or constant predictors to forecast future load * **Predictive load forecasting**: Uses ARIMA, Prophet, Kalman, or constant predictors to forecast future load
* **Performance interpolation**: Leverages profiling results data from pre-deployment profiling for accurate scaling decisions * **Performance interpolation**: Leverages profiling results data from pre-deployment profiling for accurate scaling decisions
* **Correction factors**: Adapts to real-world performance deviations from profiled data * **Correction factors**: Adapts to real-world performance deviations from profiled data
...@@ -55,7 +55,7 @@ See [Pre-Deployment Profiling](../benchmarks/sla_driven_profiling.md) for detail ...@@ -55,7 +55,7 @@ See [Pre-Deployment Profiling](../benchmarks/sla_driven_profiling.md) for detail
## Load Prediction ## Load Prediction
The SLA planner use load predictor to predict the number of requests, ISL, and OSL in the next adjustment interval. Currently, three load prediction model is supported: The SLA planner uses a load predictor to forecast the number of requests, ISL, and OSL in the next adjustment interval. Currently, four load prediction models are supported:
### Constant Predictor ### Constant Predictor
- **Use case**: Stable and long prediction interval - **Use case**: Stable and long prediction interval
...@@ -66,11 +66,33 @@ The SLA planner use load predictor to predict the number of requests, ISL, and O ...@@ -66,11 +66,33 @@ The SLA planner use load predictor to predict the number of requests, ISL, and O
- **Use case**: Time-series data with trends and seasonality - **Use case**: Time-series data with trends and seasonality
- **Behavior**: Uses auto-ARIMA to fit optimal model parameters - **Behavior**: Uses auto-ARIMA to fit optimal model parameters
- **Configuration**: `load-predictor: "arima"` - **Configuration**: `load-predictor: "arima"`
- **Tunable parameters**:
- `--load-predictor-log1p`: model `log1p(y)` instead of `y`. If not set, ARIMA starts in raw space, and if it collapses to `(0,d,0)`, it falls back to `log1p` automatically.
### Kalman Predictor
- **Use case**: Low-latency online forecasting (observe 1 → predict 1) with smooth adaptation
- **Behavior**: Local linear trend Kalman filter (fast online updates; good default when ARIMA collapses to mean-only)
- **Configuration**: `load-predictor: "kalman"`
- **Tunable parameters**:
- `--kalman-q-level`: process noise for level (higher = more responsive)
- `--kalman-q-trend`: process noise for trend (higher = trend changes faster)
- `--kalman-r`: measurement noise (lower = trusts new measurements more)
- `--kalman-min-points`: minimum points before forecasting
- `--load-predictor-log1p`: model `log1p(y)` instead of `y` (often helps request-rate/count series)
### Prophet Predictor ### Prophet Predictor
- **Use case**: Complex seasonal patterns and trend changes - **Use case**: Complex seasonal patterns and trend changes
- **Behavior**: Facebook's [Prophet](https://facebook.github.io/prophet/) model for time-series forecasting - **Behavior**: Facebook's [Prophet](https://facebook.github.io/prophet/) model for time-series forecasting
- **Configuration**: `load-predictor: "prophet"` - **Configuration**: `load-predictor: "prophet"`
- **Tunable parameters**:
- `--prophet-window-size`: bounds internal history to control refit cost
- `--load-predictor-log1p`: model `log1p(y)` instead of `y`
### Warm-starting Load Predictors (Optional)
You can warm-start the load predictors with a mooncake-style JSONL trace file to provide historical context before live traffic is observed:
- **CLI argument**: `--load-predictor-warmup-trace <path/to/trace.jsonl>`
- **Effect**: preloads the predictors with historical request-count / ISL / OSL samples extracted from the trace.
## Scaling Algorithm ## Scaling Algorithm
......
...@@ -90,6 +90,8 @@ STUB_MODULES = [ ...@@ -90,6 +90,8 @@ STUB_MODULES = [
"matplotlib.pyplot", "matplotlib.pyplot",
"pmdarima", "pmdarima",
"prophet", "prophet",
"filterpy",
"filterpy.kalman",
"scipy", "scipy",
"scipy.interpolate", "scipy.interpolate",
"nats", "nats",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment