Unverified Commit 31909ca3 authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

fix(replay): short-circuit .npz planner imports (#8050)


Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
parent 4e6c3964
...@@ -55,15 +55,15 @@ def resolve_planner_profile_data( ...@@ -55,15 +55,15 @@ def resolve_planner_profile_data(
Raises: Raises:
FileNotFoundError: If path doesn't contain valid profile data in any supported format. FileNotFoundError: If path doesn't contain valid profile data in any supported format.
""" """
if planner_profile_data is None:
return ProfileDataResult(npz_path=None, tmpdir=None)
from .utils.planner_profiler_perf_data_converter import ( from .utils.planner_profiler_perf_data_converter import (
convert_profile_results_to_npz, convert_profile_results_to_npz,
is_mocker_format_npz, is_mocker_format_npz,
is_profile_results_dir, is_profile_results_dir,
) )
if planner_profile_data is None:
return ProfileDataResult(npz_path=None, tmpdir=None)
# Case 1: Already a mocker-format NPZ file # Case 1: Already a mocker-format NPZ file
if is_mocker_format_npz(planner_profile_data): if is_mocker_format_npz(planner_profile_data):
logger.info(f"Using mocker-format NPZ file: {planner_profile_data}") logger.info(f"Using mocker-format NPZ file: {planner_profile_data}")
......
...@@ -32,12 +32,6 @@ from typing import Any ...@@ -32,12 +32,6 @@ from typing import Any
import numpy as np import numpy as np
from dynamo.planner.core.perf_model import DecodeRegressionModel, PrefillRegressionModel
from dynamo.planner.monitoring.perf_metrics import (
_convert_decode_profiling,
_convert_prefill_profiling,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -61,6 +55,15 @@ def convert_profile_results_to_npz( ...@@ -61,6 +55,15 @@ def convert_profile_results_to_npz(
Returns: Returns:
Path to the generated NPZ file. Path to the generated NPZ file.
""" """
from dynamo.planner.core.perf_model import (
DecodeRegressionModel,
PrefillRegressionModel,
)
from dynamo.planner.monitoring.perf_metrics import (
_convert_decode_profiling,
_convert_prefill_profiling,
)
profile_results_dir = str(Path(profile_results_dir).resolve()) profile_results_dir = str(Path(profile_results_dir).resolve())
output_path = Path(output_path) output_path = Path(output_path)
......
...@@ -27,15 +27,17 @@ class PlannerProfileDataResult(Protocol): ...@@ -27,15 +27,17 @@ class PlannerProfileDataResult(Protocol):
def resolve_planner_profile_data( def resolve_planner_profile_data(
planner_profile_data: Path | None, planner_profile_data: Path | None,
) -> PlannerProfileDataResult: ) -> PlannerProfileDataResult:
if planner_profile_data is None:
return SimpleNamespace(npz_path=None)
if planner_profile_data.suffix == ".npz":
return SimpleNamespace(npz_path=planner_profile_data)
try: try:
module = importlib.import_module("dynamo.mocker.args") module = importlib.import_module("dynamo.mocker.args")
except ImportError: except ImportError:
if planner_profile_data is None:
return SimpleNamespace(npz_path=None)
return SimpleNamespace( return SimpleNamespace(
npz_path=planner_profile_data npz_path=None,
if planner_profile_data.suffix == ".npz"
else None
) )
return module.resolve_planner_profile_data(planner_profile_data) return module.resolve_planner_profile_data(planner_profile_data)
...@@ -62,13 +64,16 @@ def _load_engine_args(raw_args: str | None): ...@@ -62,13 +64,16 @@ def _load_engine_args(raw_args: str | None):
"worker_type must be one of 'aggregated', 'prefill', or 'decode'" "worker_type must be one of 'aggregated', 'prefill', or 'decode'"
) )
if "planner_profile_data" in raw: if "planner_profile_data" in raw:
profile_data_result = resolve_planner_profile_data( if raw["planner_profile_data"] is None:
Path(raw["planner_profile_data"])
)
if profile_data_result.npz_path is not None:
raw["planner_profile_data"] = str(profile_data_result.npz_path)
else:
del raw["planner_profile_data"] del raw["planner_profile_data"]
else:
profile_data_result = resolve_planner_profile_data(
Path(raw["planner_profile_data"])
)
if profile_data_result.npz_path is not None:
raw["planner_profile_data"] = str(profile_data_result.npz_path)
else:
del raw["planner_profile_data"]
return MockEngineArgs.from_json(json.dumps(raw)) return MockEngineArgs.from_json(json.dumps(raw))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment