Unverified commit 4557b6df, authored by Hongkuan Zhou and committed by GitHub
Browse files

feat: predict log1p(y) instead of y in load predictor if ARIMA collapses (#5545)


Signed-off-by: hongkuanz <hongkuanz@nvidia.com>
parent 228c6871
...@@ -18,19 +18,22 @@ import math ...@@ -18,19 +18,22 @@ import math
import warnings import warnings
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from datetime import datetime, timedelta from datetime import datetime, timedelta
from enum import Enum
import pandas as pd import pandas as pd
import pmdarima import pmdarima
from prophet import Prophet from prophet import Prophet
logger = logging.getLogger("cmdstanpy") from dynamo.runtime.logging import configure_dynamo_logging
logger.addHandler(logging.NullHandler())
logger.propagate = False
logger.setLevel(logging.CRITICAL)
# Suppress sklearn deprecation warnings configure_dynamo_logging()
warnings.filterwarnings("ignore", category=FutureWarning) logger = logging.getLogger(__name__)
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings(
"ignore",
category=FutureWarning,
message=".*force_all_finite.*",
)
class BasePredictor(ABC): class BasePredictor(ABC):
...@@ -89,57 +92,136 @@ class ConstantPredictor(BasePredictor): ...@@ -89,57 +92,136 @@ class ConstantPredictor(BasePredictor):
# Auto ARIMA model from pmdarima # Auto ARIMA model from pmdarima
class ARIMAPredictor(BasePredictor):
    """Load predictor backed by ``pmdarima.auto_arima``.

    The model is first fitted on the raw observations. If the selected
    order collapses to ``(0, d, 0)`` (no AR and no MA terms), the predictor
    refits on ``log1p(y)`` and from then on works in log space, converting
    forecasts back with ``expm1``. After the initial fit the model is only
    updated incrementally with the observations queued since the last
    prediction, so ``auto_arima`` is not re-run on every step.
    """

    class Mode(str, Enum):
        """Space in which the ARIMA model is fitted and updated."""

        RAW = "raw"
        LOG1P = "log1p"

    def __init__(self, window_size=100, minimum_data_points=5):
        """Create an unfitted predictor.

        Args:
            window_size: Maximum number of most recent points kept in the
                buffers.
            minimum_data_points: Forwarded to ``BasePredictor``; below this
                many raw points ``predict_next`` returns the last value.
        """
        super().__init__(minimum_data_points=minimum_data_points)
        self.window_size = window_size  # How many past points to use
        self.model = None
        # Raw (original-scale, clamped to >= 0) values. Kept separately so the
        # model can be fitted in raw space first and refitted in log1p space.
        self._raw_buffer: list[float] = []
        # Raw points observed since the model was last fitted/updated.
        self._pending_raw_updates: list[float] = []
        # Space the current model (and data_buffer) lives in.
        self._mode: ARIMAPredictor.Mode = ARIMAPredictor.Mode.RAW

    def get_last_value(self):
        """Return the most recent observation in original scale.

        Prefers the raw buffer because ``data_buffer`` holds log1p values
        while in LOG1P mode. Returns 0 when no data has been recorded.
        """
        if self._raw_buffer:
            return float(self._raw_buffer[-1])
        if not self.data_buffer:
            return 0
        return float(self.data_buffer[-1])

    def add_data_point(self, value):
        """Record a new observation.

        The raw value is handed to ``BasePredictor.add_data_point`` first;
        the point is mirrored into the raw buffer and the pending-update
        queue only if the base class actually appended it (it may skip
        values — the original code refers to this as idle skipping).
        """
        prev_len = len(self.data_buffer)
        super().add_data_point(value)

        if len(self.data_buffer) > prev_len:
            # Negative inputs are clamped so log1p is always defined.
            raw = max(0.0, float(self.data_buffer[-1]))
            self._raw_buffer.append(raw)
            self._pending_raw_updates.append(raw)
            # In log1p mode data_buffer must stay in model space.
            if self._mode == ARIMAPredictor.Mode.LOG1P:
                self.data_buffer[-1] = math.log1p(raw)

        # Keep only the last window_size points in both buffers.
        if len(self.data_buffer) > self.window_size:
            self.data_buffer = self.data_buffer[-self.window_size :]
        if len(self._raw_buffer) > self.window_size:
            self._raw_buffer = self._raw_buffer[-self.window_size :]

    @staticmethod
    def _describe_model(model):
        """Return ``(order, seasonal_order, aic)`` of *model* for logging.

        Missing attributes and a failing ``aic()`` call yield ``None``.
        """
        order = getattr(model, "order", None)
        seasonal_order = getattr(model, "seasonal_order", None)
        try:
            aic = float(model.aic())  # type: ignore[attr-defined]
        except Exception:
            aic = None
        return order, seasonal_order, aic

    @staticmethod
    def _is_collapsed(order):
        """Return True if *order* is a 3-tuple ``(p, d, q)`` with p == q == 0."""
        try:
            p, _, q = order
        except (TypeError, ValueError):
            # order is None, not iterable, or not of length 3.
            return False
        return p == 0 and q == 0

    def _fit_initial_model(self):
        """Fit the model in raw space, falling back to log1p on collapse.

        Exceptions from the raw-space fit propagate to the caller. A failure
        of the log1p refit is logged and the raw-space model is kept.
        """
        # Always try raw space first.
        self._mode = ARIMAPredictor.Mode.RAW
        self.model = pmdarima.auto_arima(
            self.data_buffer,
            suppress_warnings=True,
            error_action="ignore",
        )
        order, seasonal_order, aic = self._describe_model(self.model)
        logger.info(
            "ARIMA selected order=%s seasonal_order=%s aic=%s",
            order,
            seasonal_order,
            aic,
        )

        if not self._is_collapsed(order):
            return

        # Build the log buffer/model in locals; only the fit itself is
        # guarded so that a failure can never leave mode, model and buffer
        # in different spaces.
        log_buffer = [math.log1p(v) for v in self._raw_buffer]
        try:
            log_model = pmdarima.auto_arima(
                log_buffer,
                suppress_warnings=True,
                error_action="ignore",
            )
        except Exception as e:
            logger.warning(
                "ARIMA log1p fallback fit failed: %s, keeping raw-space model", e
            )
            return

        # Swap mode + buffer + model together.
        self._mode = ARIMAPredictor.Mode.LOG1P
        self.data_buffer = log_buffer
        self.model = log_model

        order2, seasonal_order2, aic2 = self._describe_model(self.model)
        logger.info(
            "Detect ARIMA model collapses to (0,d,0), fallback to log1p(y) "
            "to better handle spiky time series. "
            "ARIMA (fallback log1p) selected order=%s seasonal_order=%s aic=%s",
            order2,
            seasonal_order2,
            aic2,
        )

    def predict_next(self):
        """Predict the next value in original scale.

        Returns:
            The last observed value when there are fewer than
            ``minimum_data_points`` raw points or when fitting, updating or
            predicting fails (including a non-finite forecast); the constant
            itself when all buffered raw values are equal; otherwise the
            one-step ARIMA forecast clamped to be non-negative.
        """
        if len(self._raw_buffer) < self.minimum_data_points:
            return self.get_last_value()

        # pmdarima will predict 0 for constant data, so return the constant
        # directly instead of asking the model.
        if len(set(self._raw_buffer)) == 1:
            return float(self._raw_buffer[0])

        try:
            if self.model is None:
                # Fit once on all history; later calls only update.
                self._fit_initial_model()
            elif self._pending_raw_updates:
                # Feed observations made since the last predict, converted
                # to the space the model was fitted in.
                upd = (
                    [math.log1p(v) for v in self._pending_raw_updates]
                    if self._mode == ARIMAPredictor.Mode.LOG1P
                    else self._pending_raw_updates
                )
                self.model.update(upd)
            # Model is now up-to-date through the latest observed point.
            self._pending_raw_updates = []

            forecast = float(self.model.predict(n_periods=1)[0])
            if self._mode == ARIMAPredictor.Mode.LOG1P:
                forecast = math.expm1(forecast)
            # max(0.0, nan) would silently yield 0.0; treat a non-finite
            # forecast as a failure so the last value is used instead.
            if not math.isfinite(forecast):
                raise ValueError(f"non-finite forecast {forecast}")
            return max(0.0, forecast)
        except Exception as e:
            # Log the specific error for debugging
            logger.warning("ARIMA prediction failed: %s, using last value", e)
            self._pending_raw_updates = []
            return self.get_last_value()
......
...@@ -38,10 +38,10 @@ sys.modules["dynamo.runtime.logging"] = mock_runtime.logging ...@@ -38,10 +38,10 @@ sys.modules["dynamo.runtime.logging"] = mock_runtime.logging
# Now import after mocking # Now import after mocking
from dynamo.planner.utils.planner_core import Metrics, Planner # noqa: E402 from dynamo.planner.utils.planner_core import Metrics, Planner # noqa: E402
pytestmark = [pytest.mark.pre_merge, pytest.mark.gpu_0]
@pytest.fixture @pytest.fixture
@pytest.mark.pre_merge
@pytest.mark.gpu_0
def planner(): def planner():
"""Set up test environment with mocked dependencies.""" """Set up test environment with mocked dependencies."""
# Create mock arguments # Create mock arguments
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment