Unverified commit b0959cfd, authored by Hongkuan Zhou, committed by GitHub
Browse files

feat: warmup dataset for planner load predictor (#5529)


Signed-off-by: hongkuanz <hongkuanz@nvidia.com>
parent 0f137de4
......@@ -56,13 +56,15 @@ If `--output-file` is not specified, the output will use the input filename with
**Filtering:**
- `--model`: Filter by model (`ChatGPT` or `GPT-4`), None for no filtering
- `--log-type`: Filter by log type (`Conversation log` or `API log`), None for no filtering
- `--num-prompt`: Limit number of rows in the final output, None for no filtering
- `--skip-num-prompt`: Skip the first N rows after filtering (default: 0). Applied **before** `--num-prompt`.
- `--num-prompt`: Limit number of rows in the final output, None for no filtering (applied **after** `--skip-num-prompt`)
**Timestamp Adjustment:**
- `--speed-ratio`: Adjust request timing (default: 1.0)
- Values > 1: Speed up (e.g., 2.0 = 2x faster)
- Values < 1: Slow down (e.g., 0.5 = 2x slower)
- Formula: `new_timestamp = old_timestamp / speed_ratio`
- After filtering/skip/cap and speed-ratio adjustment, timestamps are shifted so the first kept request starts at `t=0`.
**Hash Generation:**
- `--block-size`: Block size in mooncake traces (default: 128)
......
......@@ -41,6 +41,12 @@ def parse_args():
default=None,
help="Limit the number of rows to output after filtering. If not specified, all rows are output.",
)
parser.add_argument(
"--skip-num-prompt",
type=int,
default=0,
help="Skip the first N rows after filtering (before applying --num-prompt). Default: 0",
)
parser.add_argument(
"--speed-ratio",
type=float,
......@@ -80,7 +86,7 @@ def load_csv(filepath):
return None
def apply_filters(df, model=None, log_type=None, num_prompt=None):
def apply_filters(df, model=None, log_type=None, skip_num_prompt=0, num_prompt=None):
"""
Apply filters to the DataFrame.
......@@ -88,6 +94,7 @@ def apply_filters(df, model=None, log_type=None, num_prompt=None):
df: Input DataFrame
model: Model to filter by (ChatGPT or GPT-4)
log_type: Log type to filter by (Conversation log or API log)
skip_num_prompt: Number of rows to skip after filtering (before capping)
num_prompt: Number of rows to keep after filtering
Returns:
......@@ -105,12 +112,18 @@ def apply_filters(df, model=None, log_type=None, num_prompt=None):
filtered_df = filtered_df[filtered_df["Log Type"] == log_type]
print(f"After log type filter ({log_type}): {len(filtered_df)} rows")
# Skip rows (before capping)
if skip_num_prompt and skip_num_prompt > 0:
filtered_df = filtered_df.iloc[skip_num_prompt:]
print(f"After skip_num_prompt ({skip_num_prompt}): {len(filtered_df)} rows")
# Apply num_prompt limit
if num_prompt is not None:
filtered_df = filtered_df.head(num_prompt)
print(f"After num_prompt limit ({num_prompt}): {len(filtered_df)} rows")
return filtered_df
# Reset index so downstream iterrows() uses a clean, deterministic range
return filtered_df.reset_index(drop=True)
def apply_speed_ratio(df, speed_ratio):
......@@ -142,6 +155,31 @@ def apply_speed_ratio(df, speed_ratio):
return adjusted_df
def offset_timestamps_to_zero(df):
    """
    Shift the "Timestamp" column so the earliest request lands at t=0.

    Args:
        df: DataFrame expected to carry a "Timestamp" column in seconds

    Returns:
        The input DataFrame itself (not a copy) when it is empty, has no
        "Timestamp" column, or its minimum timestamp is NaN or already 0;
        otherwise a copy whose "Timestamp" values have the minimum subtracted.
    """
    has_rows = len(df) > 0
    if not has_rows or "Timestamp" not in df.columns:
        return df

    earliest = df["Timestamp"].min()
    nothing_to_shift = pd.isna(earliest) or earliest == 0
    if nothing_to_shift:
        return df

    # Work on a copy so the caller's DataFrame is left untouched.
    shifted_df = df.copy()
    shifted_df["Timestamp"] = shifted_df["Timestamp"] - float(earliest)
    message = (
        f"Offset timestamps so first request starts at t=0 (subtracted {earliest:.6f}s)"
    )
    print(message)
    return shifted_df
def convert_to_mooncake(df, block_size, num_hash_blocks):
"""
Convert DataFrame to mooncake format.
......@@ -243,6 +281,46 @@ def print_statistics(df):
print(f" Total requests: {len(df)}")
print(f" Duration: {duration_s:.2f} seconds")
print(f" Average RPS: {avg_rps:.2f}")
# Request rate vs time (ASCII plot, 60-col width)
plot_width = 60
# Target ~20 bins; clamp to at least 1s bins for stability
target_bins = 20
bin_size_s = max(1.0, duration_s / target_bins)
num_bins = max(1, math.ceil(duration_s / bin_size_s))
counts = [0] * num_bins
# Compute per-bin counts using timestamps relative to start
for ts_ms in df["timestamp"].tolist():
rel_s = (ts_ms / 1000.0) - min_timestamp_s
idx = int(rel_s / bin_size_s)
if idx < 0:
idx = 0
elif idx >= num_bins:
idx = num_bins - 1
counts[idx] += 1
rates = [c / bin_size_s for c in counts]
peak_rps = max(rates) if rates else 0.0
print("\nRequest rate vs time:")
print(f" Bin: {bin_size_s:.2f}s, Peak RPS: {peak_rps:.2f}")
if peak_rps > 0:
# Use dynamic, fixed-width labels so the bars align
max_time_s = max(0.0, duration_s)
digits = max(1, len(str(int(math.ceil(max_time_s)))))
label_width = (2 * digits) + 2 # "{start}-{end}s"
bar_width = max(1, plot_width - label_width - 3) # " | "
for i, rps in enumerate(rates):
start_s = i * bin_size_s
end_s = min((i + 1) * bin_size_s, duration_s)
bar_len = int(round((rps / peak_rps) * bar_width))
bar = "#" * max(0, min(bar_width, bar_len))
label = f"{start_s:>{digits}.0f}-{end_s:>{digits}.0f}s"
line = f"{label} | {bar}"
print(line[:plot_width])
else:
print("\nRequest Rate:")
print(f" Total requests: {len(df)}")
......@@ -268,11 +346,16 @@ def main():
print("\nApplying filters...")
print(f"Initial rows: {len(df)}")
filtered_df = apply_filters(
df, model=args.model, log_type=args.log_type, num_prompt=args.num_prompt
df,
model=args.model,
log_type=args.log_type,
skip_num_prompt=args.skip_num_prompt,
num_prompt=args.num_prompt,
)
# Apply Speedup
adjusted_df = apply_speed_ratio(filtered_df, args.speed_ratio)
adjusted_df = offset_timestamps_to_zero(adjusted_df)
# Convert to mooncake format
print("\nConverting to mooncake format...")
......
......@@ -31,6 +31,10 @@ def create_dryrun_plot(
d_thpt: list,
safe_d_thpt: list,
output_path: str,
warmup_time: list | None = None,
warmup_rr: list | None = None,
warmup_isl: list | None = None,
warmup_osl: list | None = None,
) -> None:
"""
Create a comprehensive dryrun plot with 4 subplots showing various metrics over time.
......@@ -50,12 +54,27 @@ def create_dryrun_plot(
d_thpt: List of actual decode throughputs
safe_d_thpt: List of safe decode throughput limits
output_path: Path where the plot should be saved
warmup_time: Optional list of warmup time points (negative seconds)
warmup_rr: Optional list of warmup request rates (same units as rr)
warmup_isl: Optional list of warmup input sequence lengths
warmup_osl: Optional list of warmup output sequence lengths
"""
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 10))
# Plot 1: Request Rate
if warmup_time is not None and warmup_rr is not None and len(warmup_time) > 0:
ax1.plot(
warmup_time,
warmup_rr,
"b-",
alpha=0.35,
linewidth=2,
label="Warmup Request Rate",
)
ax1.plot(time, rr, "b-", label="Actual Request Rate", linewidth=2)
ax1.plot(time, est_rr, "r--", label="Predicted Request Rate", linewidth=2)
if warmup_time is not None and warmup_rr is not None:
ax1.axvline(0, color="k", linestyle=":", linewidth=2, label="Warmup Boundary")
ax1.set_xlabel("Time (s)")
ax1.set_ylabel("Request Rate")
ax1.set_ylim(bottom=0)
......@@ -64,10 +83,34 @@ def create_dryrun_plot(
ax1.grid(True, alpha=0.3)
# Plot 2: Sequence Lengths
if (
warmup_time is not None
and warmup_isl is not None
and warmup_osl is not None
and len(warmup_time) > 0
):
ax2.plot(
warmup_time,
warmup_isl,
"g-",
alpha=0.35,
linewidth=2,
label="Warmup ISL",
)
ax2.plot(
warmup_time,
warmup_osl,
"m-",
alpha=0.35,
linewidth=2,
label="Warmup OSL",
)
ax2.plot(time, isl, "g-", label="Actual ISL", linewidth=2)
ax2.plot(time, est_isl, "g--", label="Predicted ISL", linewidth=2)
ax2.plot(time, osl, "m-", label="Actual OSL", linewidth=2)
ax2.plot(time, est_osl, "m--", label="Predicted OSL", linewidth=2)
if warmup_time is not None and warmup_isl is not None and warmup_osl is not None:
ax2.axvline(0, color="k", linestyle=":", linewidth=2, label="Warmup Boundary")
ax2.set_xlabel("Time (s)")
ax2.set_ylabel("Num Tokens")
ax2.set_ylim(bottom=0)
......
......@@ -39,16 +39,28 @@ class BasePredictor(ABC):
def __init__(self, minimum_data_points=5):
    """Set up an empty observation buffer and the idle-skip state.

    Args:
        minimum_data_points: Stored on the instance as ``minimum_data_points``.
    """
    # Idle-skip flag: starts False so that zero-valued points are ignored
    # until the first non-zero datapoint arrives. This holds even when the
    # buffer has been preloaded with historical data, so the initial
    # post-deployment idle period (a run of zeros) is not recorded.
    self._seen_nonzero_since_idle_reset = False
    self.data_buffer = []
    self.minimum_data_points = minimum_data_points
def reset_idle_skip(self) -> None:
    """Re-arm idle-period skipping (e.g., after warmup, before live traffic).

    Clears the "seen a non-zero value" flag so that subsequent zero-valued
    data points are skipped again until the next non-zero value is added.
    """
    self._seen_nonzero_since_idle_reset = False
def add_data_point(self, value):
"""Add new data point to the buffer"""
if math.isnan(value):
value = 0
if len(self.data_buffer) == 0 and value == 0:
# skip the beginning idle period
if value == 0 and not self._seen_nonzero_since_idle_reset:
# Skip the beginning idle period (leading zeros) even if data_buffer
# is pre-warmed with historical data.
return
else:
if value != 0:
self._seen_nonzero_since_idle_reset = True
self.data_buffer.append(value)
def get_last_value(self):
......@@ -81,9 +93,16 @@ class ARIMAPredictor(BasePredictor):
super().__init__(minimum_data_points=minimum_data_points)
self.window_size = window_size # How many past points to use
self.model = None
# Pending points to incrementally update the fitted model with.
# This avoids re-running auto_arima() on every step.
self._pending_updates: list[float] = []
def add_data_point(self, value):
prev_len = len(self.data_buffer)
super().add_data_point(value)
if len(self.data_buffer) > prev_len:
# Only queue updates if the value wasn't skipped by BasePredictor.
self._pending_updates.append(float(self.data_buffer[-1]))
# Keep only the last window_size points
if len(self.data_buffer) > self.window_size:
self.data_buffer = self.data_buffer[-self.window_size :]
......@@ -99,12 +118,20 @@ class ARIMAPredictor(BasePredictor):
return self.data_buffer[0] # Return the constant value
try:
# Fit auto ARIMA model
# Fit auto ARIMA model once, then only do incremental updates.
if self.model is None:
self.model = pmdarima.auto_arima(
self.data_buffer,
suppress_warnings=True,
error_action="ignore",
)
else:
# Incrementally update model with any new observations since last predict.
if self._pending_updates:
self.model.update(self._pending_updates)
# Clear pending updates: model is now up-to-date through the latest observed point.
self._pending_updates = []
# Make prediction
forecast = self.model.predict(n_periods=1)
......@@ -112,6 +139,7 @@ class ARIMAPredictor(BasePredictor):
except Exception as e:
# Log the specific error for debugging
logger.warning(f"ARIMA prediction failed: {e}, using last value")
self._pending_updates = []
return self.get_last_value()
......@@ -124,6 +152,7 @@ class ProphetPredictor(BasePredictor):
self.step_size = step_size
self.start_date = datetime(2024, 1, 1) # Base date for generating timestamps
self.data_buffer = [] # Override to store dicts instead of values
self._seen_nonzero_since_idle_reset = False
def add_data_point(self, value):
"""Add new data point to the buffer"""
......@@ -131,9 +160,13 @@ class ProphetPredictor(BasePredictor):
timestamp = self.start_date + timedelta(seconds=self.curr_step)
value = 0 if math.isnan(value) else value
if len(self.data_buffer) == 0 and value == 0:
# skip the beginning idle period
if value == 0 and not self._seen_nonzero_since_idle_reset:
# skip the beginning idle period (leading zeros), even if pre-warmed
return
if value != 0:
self._seen_nonzero_since_idle_reset = True
self.data_buffer.append({"ds": timestamp, "y": value})
self.curr_step += 1
......
......@@ -109,6 +109,12 @@ def create_sla_planner_parser() -> argparse.ArgumentParser:
default=SLAPlannerDefaults.load_prediction_window_size,
help="Load prediction window size",
)
parser.add_argument(
"--load-predictor-warmup-trace",
type=str,
default=None,
help="Optional path to a mooncake-style JSONL trace file used to warm up load predictors before observing live traffic",
)
parser.add_argument(
"--metric-pulling-prometheus-endpoint",
type=str,
......
......@@ -165,6 +165,36 @@ class Planner:
window_size=args.load_prediction_window_size,
)
# Optional warmup: preload predictors with historical observations from a
# mooncake-style JSONL trace (request_count/avg_isl/avg_osl per interval).
if getattr(args, "load_predictor_warmup_trace", None):
warmup_trace = args.load_predictor_warmup_trace
try:
metrics = extract_metrics_from_mooncake(
warmup_trace, args.adjustment_interval
)
for m in metrics:
self.num_req_predictor.add_data_point(float(m["request_count"]))
self.isl_predictor.add_data_point(float(m["avg_isl"]))
self.osl_predictor.add_data_point(float(m["avg_osl"]))
logger.info(
f"Warmed load predictors with {len(metrics)} intervals from {warmup_trace}"
)
except Exception as e:
logger.warning(
f"Failed to warm load predictors from {warmup_trace}: {e}"
)
finally:
# Even with warmup data, ignore the initial post-deploy idle
# period (leading zeros) when live metrics start coming in.
for p in (
self.num_req_predictor,
self.isl_predictor,
self.osl_predictor,
):
if hasattr(p, "reset_idle_skip"):
p.reset_idle_skip()
if "use-pre-swept-results" in args.profile_results_dir:
config_list = args.profile_results_dir.split(":")
configs = {
......@@ -625,6 +655,13 @@ class Planner:
def dryrun_run(self):
"""Run planner in dry-run mode with dataset"""
warmup_metrics = None
if getattr(self.args, "load_predictor_warmup_trace", None):
warmup_metrics = extract_metrics_from_mooncake(
self.args.load_predictor_warmup_trace,
self.args.adjustment_interval,
)
metrics = extract_metrics_from_mooncake(
self.args.dataset, self.args.adjustment_interval
)
......@@ -723,6 +760,18 @@ class Planner:
# plot the results
from dynamo.planner.utils.dryrun_plot_utils import create_dryrun_plot
warmup_time = None
warmup_rr = None
warmup_isl = None
warmup_osl = None
if warmup_metrics:
interval = self.args.adjustment_interval
n = len(warmup_metrics)
warmup_time = [-(n - i) * interval for i in range(n)]
warmup_rr = [m["request_count"] for m in warmup_metrics]
warmup_isl = [m["avg_isl"] for m in warmup_metrics]
warmup_osl = [m["avg_osl"] for m in warmup_metrics]
create_dryrun_plot(
time=time,
rr=rr,
......@@ -738,6 +787,10 @@ class Planner:
d_thpt=d_thpt,
safe_d_thpt=safe_d_thpt,
output_path=self.args.output_plot,
warmup_time=warmup_time,
warmup_rr=warmup_rr,
warmup_isl=warmup_isl,
warmup_osl=warmup_osl,
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment