Unverified Commit d9aef67e authored by Hongkuan Zhou's avatar Hongkuan Zhou Committed by GitHub
Browse files

feat: add a knob to turn off correction factor in sla planner (#2511)

parent cae5822a
...@@ -80,6 +80,7 @@ class SLAPlannerDefaults(BasePlannerDefaults): ...@@ -80,6 +80,7 @@ class SLAPlannerDefaults(BasePlannerDefaults):
itl = 0.05 # in seconds itl = 0.05 # in seconds
load_predictor = "arima" # ["constant", "arima", "prophet"] load_predictor = "arima" # ["constant", "arima", "prophet"]
load_prediction_window_size = 50 # predict load using how many recent load samples load_prediction_window_size = 50 # predict load using how many recent load samples
no_correction = False # disable correction factor, might be useful under some conditions like long cold start time
class VllmComponentName: class VllmComponentName:
......
...@@ -141,6 +141,12 @@ if __name__ == "__main__": ...@@ -141,6 +141,12 @@ if __name__ == "__main__":
default=SLAPlannerDefaults.prometheus_port, default=SLAPlannerDefaults.prometheus_port,
help="Prometheus port", help="Prometheus port",
) )
parser.add_argument(
"--no-correction",
action="store_true",
default=SLAPlannerDefaults.no_correction,
help="Disable correction factor",
)
args = parser.parse_args() args = parser.parse_args()
asyncio.run(init_planner(args)) asyncio.run(init_planner(args))
...@@ -11,7 +11,7 @@ from typing import Optional ...@@ -11,7 +11,7 @@ from typing import Optional
from prometheus_client import Gauge, start_http_server from prometheus_client import Gauge, start_http_server
from dynamo.planner import KubernetesConnector, __version__ from dynamo.planner import KubernetesConnector
from dynamo.planner.defaults import WORKER_COMPONENT_NAMES, SLAPlannerDefaults from dynamo.planner.defaults import WORKER_COMPONENT_NAMES, SLAPlannerDefaults
from dynamo.planner.utils.load_predictor import LOAD_PREDICTORS from dynamo.planner.utils.load_predictor import LOAD_PREDICTORS
from dynamo.planner.utils.perf_interpolation import ( from dynamo.planner.utils.perf_interpolation import (
...@@ -19,7 +19,7 @@ from dynamo.planner.utils.perf_interpolation import ( ...@@ -19,7 +19,7 @@ from dynamo.planner.utils.perf_interpolation import (
PrefillInterpolator, PrefillInterpolator,
) )
from dynamo.planner.utils.prometheus import PrometheusAPIClient from dynamo.planner.utils.prometheus import PrometheusAPIClient
from dynamo.runtime import DistributedRuntime, dynamo_worker from dynamo.runtime import DistributedRuntime
from dynamo.runtime.logging import configure_dynamo_logging from dynamo.runtime.logging import configure_dynamo_logging
configure_dynamo_logging() configure_dynamo_logging()
...@@ -90,6 +90,7 @@ class Planner: ...@@ -90,6 +90,7 @@ class Planner:
self.p_correction_factor = 1.0 self.p_correction_factor = 1.0
self.d_correction_factor = 1.0 self.d_correction_factor = 1.0
self.no_correction = args.no_correction
self.prometheus_port = args.prometheus_port self.prometheus_port = args.prometheus_port
...@@ -204,40 +205,41 @@ class Planner: ...@@ -204,40 +205,41 @@ class Planner:
self.osl_predictor.add_data_point(self.last_metrics.osl) self.osl_predictor.add_data_point(self.last_metrics.osl)
async def make_adjustments(self): async def make_adjustments(self):
try: if not self.no_correction:
# Skip adjustment if no traffic try:
if not self.last_metrics.is_valid(): # Skip adjustment if no traffic
if not self.last_metrics.is_valid():
logger.info(
"Metrics contain None or NaN values (no active requests), skipping adjustment"
)
return
self.p_endpoints, self.d_endpoints = await self.get_workers_info()
logger.info( logger.info(
"Metrics contain None or NaN values (no active requests), skipping adjustment" f"Number of prefill workers: {len(self.p_endpoints)}, number of decode workers: {len(self.d_endpoints)}"
) )
return
self.p_endpoints, self.d_endpoints = await self.get_workers_info() # first correct the prediction correction factor
logger.info( # for TTFT, we expect the correction factor to be << 1 due to queuing delay
f"Number of prefill workers: {len(self.p_endpoints)}, number of decode workers: {len(self.d_endpoints)}" expect_ttft = self.prefill_interpolator.interpolate_ttft(
) self.last_metrics.isl
)
# first correct the prediction correction factor self.p_correction_factor = self.last_metrics.ttft / expect_ttft
# for TTFT, we expect the correction factor to be << 1 due to queuing delay # for ITL, we expect the correction factor to be close to 1
expect_ttft = self.prefill_interpolator.interpolate_ttft( expect_itl = self.decode_interpolator.interpolate_itl(
self.last_metrics.isl concurrency=self.last_metrics.num_req # type: ignore
) / len(self.d_endpoints)
self.p_correction_factor = self.last_metrics.ttft / expect_ttft * self.last_metrics.request_duration # type: ignore
# for ITL, we expect the correction factor to be close to 1 / self.args.adjustment_interval,
expect_itl = self.decode_interpolator.interpolate_itl( context_length=self.last_metrics.isl + self.last_metrics.osl / 2, # type: ignore
concurrency=self.last_metrics.num_req # type: ignore )
/ len(self.d_endpoints) self.d_correction_factor = self.last_metrics.itl / expect_itl
* self.last_metrics.request_duration # type: ignore logger.info(
/ self.args.adjustment_interval, f"Correction factors: TTFT: {self.p_correction_factor:.3f}, ITL: {self.d_correction_factor:.3f}"
context_length=self.last_metrics.isl + self.last_metrics.osl / 2, # type: ignore )
) except Exception as e:
self.d_correction_factor = self.last_metrics.itl / expect_itl logger.error(f"Failed to correct prediction factors: {e}")
logger.info( return
f"Correction factors: TTFT: {self.p_correction_factor:.3f}, ITL: {self.d_correction_factor:.3f}"
)
except Exception as e:
logger.error(f"Failed to correct prediction factors: {e}")
return
try: try:
# predict the next load # predict the next load
...@@ -360,116 +362,3 @@ class Planner: ...@@ -360,116 +362,3 @@ class Planner:
async def start_sla_planner(runtime: DistributedRuntime, args: argparse.Namespace): async def start_sla_planner(runtime: DistributedRuntime, args: argparse.Namespace):
planner = Planner(runtime, args) planner = Planner(runtime, args)
await planner.run() await planner.run()
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Common planner arguments
parser.add_argument(
"--version", action="version", version=f"Dynamo Planner {__version__}"
)
parser.add_argument(
"--environment",
type=str,
default=SLAPlannerDefaults.environment,
help="Environment to run the planner in (local, kubernetes)",
)
parser.add_argument(
"--no-operation",
action="store_true",
default=SLAPlannerDefaults.no_operation,
help="Do not make any adjustments, just observe the metrics",
)
parser.add_argument(
"--log-dir",
type=str,
default=SLAPlannerDefaults.log_dir,
help="Tensorboard logging directory",
)
parser.add_argument(
"--adjustment-interval",
type=int,
default=SLAPlannerDefaults.adjustment_interval,
help="Interval in seconds between scaling adjustments",
)
parser.add_argument(
"--max-gpu-budget",
type=int,
default=SLAPlannerDefaults.max_gpu_budget,
help="Maximum number of GPUs to use",
)
parser.add_argument(
"--min-endpoint",
type=int,
default=SLAPlannerDefaults.min_endpoint,
help="Minimum number of endpoints to keep for prefill/decode workers",
)
parser.add_argument(
"--decode-engine-num-gpu",
type=int,
default=SLAPlannerDefaults.decode_engine_num_gpu,
help="Number of GPUs per decode engine",
)
parser.add_argument(
"--prefill-engine-num-gpu",
type=int,
default=SLAPlannerDefaults.prefill_engine_num_gpu,
help="Number of GPUs per prefill engine",
)
# SLA-planner specific arguments
parser.add_argument(
"--prometheus-endpoint",
type=str,
default=SLAPlannerDefaults.prometheus_endpoint,
help="Prometheus endpoint url",
)
parser.add_argument(
"--profile-results-dir",
type=str,
default=SLAPlannerDefaults.profile_results_dir,
help="Directory to pre-deployment profiling results",
)
parser.add_argument(
"--isl",
type=int,
default=SLAPlannerDefaults.isl,
help="Input sequence length",
)
parser.add_argument(
"--osl",
type=int,
default=SLAPlannerDefaults.osl,
help="Output sequence length",
)
parser.add_argument(
"--ttft",
type=float,
default=SLAPlannerDefaults.ttft,
help="Time to first token (in seconds)",
)
parser.add_argument(
"--itl",
type=float,
default=SLAPlannerDefaults.itl,
help="Inter-token latency (in seconds)",
)
parser.add_argument(
"--load-predictor",
type=str,
default=SLAPlannerDefaults.load_predictor,
help="Load predictor to use",
)
parser.add_argument(
"--load-prediction-window-size",
type=int,
default=SLAPlannerDefaults.load_prediction_window_size,
help="Window size for load prediction",
)
parser.add_argument(
"--prometheus-port",
type=int,
default=SLAPlannerDefaults.prometheus_port,
help="Prometheus port for metrics server (0 to disable)",
)
args = parser.parse_args()
asyncio.run(dynamo_worker()(start_sla_planner)(args))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment