Unverified commit 4557b6df, authored by Hongkuan Zhou, committed by GitHub
Browse files

feat: predict log1p(y) instead of y in load predictor if ARIMA collapses (#5545)


Signed-off-by: hongkuanz <hongkuanz@nvidia.com>
parent 228c6871
......@@ -18,19 +18,22 @@ import math
import warnings
from abc import ABC, abstractmethod
from datetime import datetime, timedelta
from enum import Enum
import pandas as pd
import pmdarima
from prophet import Prophet
logger = logging.getLogger("cmdstanpy")
logger.addHandler(logging.NullHandler())
logger.propagate = False
logger.setLevel(logging.CRITICAL)
from dynamo.runtime.logging import configure_dynamo_logging
# Suppress sklearn deprecation warnings
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning)
configure_dynamo_logging()
logger = logging.getLogger(__name__)
warnings.filterwarnings(
"ignore",
category=FutureWarning,
message=".*force_all_finite.*",
)
class BasePredictor(ABC):
......@@ -89,57 +92,136 @@ class ConstantPredictor(BasePredictor):
# Auto ARIMA model from pmdarima
class ARIMAPredictor(BasePredictor):
    """One-step-ahead load predictor backed by ``pmdarima.auto_arima``.

    The model is fitted once and afterwards only updated incrementally with
    the observations queued since the previous prediction, which avoids
    re-running ``auto_arima()`` on every step.

    The first fit always happens in raw space. If the selected order has
    ``p == 0`` and ``q == 0`` (the model "collapsed" to ``(0, d, 0)``), a
    second model is fitted on ``log1p(y)``; when that succeeds the predictor
    switches to log1p space and converts forecasts back with ``expm1``.
    """

    class Mode(str, Enum):
        """Space in which the ARIMA model is fitted and updated."""

        RAW = "raw"
        LOG1P = "log1p"

    def __init__(self, window_size=100, minimum_data_points=5):
        """Create an unfitted predictor.

        Args:
            window_size: Maximum number of most recent points kept in the
                buffers.
            minimum_data_points: Forwarded to ``BasePredictor``; below this
                many raw points ``predict_next`` returns the last value.
        """
        super().__init__(minimum_data_points=minimum_data_points)
        self.window_size = window_size  # How many past points to use
        self.model = None
        # Keep raw values so we can fit in raw space first, then fall back
        # to log1p space.
        self._raw_buffer: list[float] = []
        # Pending raw points to incrementally update the fitted model with.
        self._pending_raw_updates: list[float] = []
        # Current modeling space.
        self._mode: ARIMAPredictor.Mode = ARIMAPredictor.Mode.RAW

    def get_last_value(self):
        """Return the last observed value in the original (raw) scale."""
        if self._raw_buffer:
            return float(self._raw_buffer[-1])
        if not self.data_buffer:
            return 0.0
        return float(self.data_buffer[-1])

    def add_data_point(self, value):
        """Record a new observation and queue it for the next model update.

        The raw value is handed to ``BasePredictor.add_data_point`` first so
        that any skipping it performs sees the untransformed value; the point
        is only queued when ``data_buffer`` actually grew.
        """
        prev_len = len(self.data_buffer)
        super().add_data_point(value)

        if len(self.data_buffer) > prev_len:
            # Negative inputs are clamped to zero before being stored, so
            # that log1p() is always defined for the stored value.
            raw = max(0.0, float(self.data_buffer[-1]))
            self._raw_buffer.append(raw)
            self._pending_raw_updates.append(raw)
            # If we are in log1p mode, keep data_buffer in model space.
            # NOTE(review): in RAW mode data_buffer keeps the unclamped
            # value while _raw_buffer holds the clamped one -- confirm that
            # negative inputs cannot occur or that this is intended.
            if self._mode == ARIMAPredictor.Mode.LOG1P:
                self.data_buffer[-1] = math.log1p(raw)

        # Keep only the last window_size points in both buffers.
        if len(self.data_buffer) > self.window_size:
            self.data_buffer = self.data_buffer[-self.window_size :]
        if len(self._raw_buffer) > self.window_size:
            self._raw_buffer = self._raw_buffer[-self.window_size :]

    def predict_next(self):
        """Predict the next value in the original scale, never below 0.0.

        Returns the last observed value when there are fewer than
        ``minimum_data_points`` raw points or when fitting, updating or
        forecasting raises.
        """
        if len(self._raw_buffer) < self.minimum_data_points:
            return self.get_last_value()

        # Check if all values are the same (constant data).
        # pmdarima will predict 0 for constant data, so return the constant.
        if len(set(self._raw_buffer)) == 1:
            return float(self._raw_buffer[0])

        try:
            if self.model is None:
                # Fit auto ARIMA once; later calls only update incrementally.
                self._fit_initial_model()
            elif self._pending_raw_updates:
                # Feed the observations seen since the last predict, in the
                # space the current model was fitted in.
                if self._mode == ARIMAPredictor.Mode.LOG1P:
                    upd = [math.log1p(v) for v in self._pending_raw_updates]
                else:
                    upd = self._pending_raw_updates
                self.model.update(upd)

            # Model is now up-to-date through the latest observed point.
            self._pending_raw_updates = []

            forecast = float(self.model.predict(n_periods=1)[0])
            if self._mode == ARIMAPredictor.Mode.LOG1P:
                return max(0.0, math.expm1(forecast))
            return max(0.0, forecast)
        except Exception as e:
            # Log the specific error for debugging.
            logger.warning("ARIMA prediction failed: %s, using last value", e)
            self._pending_raw_updates = []
            return self.get_last_value()

    @staticmethod
    def _model_summary(model):
        """Return ``(order, seasonal_order, aic)`` of *model* for logging.

        Any item that cannot be read is reported as ``None``.
        """
        order = getattr(model, "order", None)
        seasonal_order = getattr(model, "seasonal_order", None)
        try:
            aic = float(model.aic())  # type: ignore[attr-defined]
        except Exception:
            aic = None
        return order, seasonal_order, aic

    def _fit_initial_model(self):
        """Fit the raw-space model, then try the log1p fallback if needed."""
        # Always try raw space first.
        self._mode = ARIMAPredictor.Mode.RAW
        self.model = pmdarima.auto_arima(
            self.data_buffer,
            suppress_warnings=True,
            error_action="ignore",
        )
        order, seasonal_order, aic = self._model_summary(self.model)
        logger.info(
            "ARIMA selected order=%s seasonal_order=%s aic=%s",
            order,
            seasonal_order,
            aic,
        )
        self._try_log1p_fallback(order)

    def _try_log1p_fallback(self, order):
        """Switch to a log1p-space model when *order* is ``(0, d, 0)``.

        The log buffer and model are built in locals and only swapped in
        after the fit succeeded, so a failure leaves the raw-space model,
        buffer and mode untouched.
        """
        try:
            if order is None or len(order) != 3:
                return
            p, _, q = order
            if p != 0 or q != 0:
                return
            log_buffer = [math.log1p(v) for v in self._raw_buffer]
            log_model = pmdarima.auto_arima(
                log_buffer,
                suppress_warnings=True,
                error_action="ignore",
            )
        except Exception as e:
            # If the fallback fails, keep the raw-space model.
            logger.warning(
                "ARIMA log1p fallback failed: %s, keeping raw-space model", e
            )
            return

        # Swap mode + model + buffer together.
        self._mode = ARIMAPredictor.Mode.LOG1P
        self.data_buffer = log_buffer
        self.model = log_model

        order2, seasonal_order2, aic2 = self._model_summary(log_model)
        logger.info(
            "Detect ARIMA model collapses to (0,d,0), fallback to log1p(y) "
            "to better handle spiky time series. "
            "ARIMA (fallback log1p) selected order=%s seasonal_order=%s aic=%s",
            order2,
            seasonal_order2,
            aic2,
        )
......
......@@ -38,10 +38,10 @@ sys.modules["dynamo.runtime.logging"] = mock_runtime.logging
# Now import after mocking
from dynamo.planner.utils.planner_core import Metrics, Planner # noqa: E402
pytestmark = [pytest.mark.pre_merge, pytest.mark.gpu_0]
@pytest.fixture
@pytest.mark.pre_merge
@pytest.mark.gpu_0
def planner():
"""Set up test environment with mocked dependencies."""
# Create mock arguments
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment