"examples/vscode:/vscode.git/clone" did not exist on "44250d44f6ab20d0674796b6a8d0ba57a4bc230e"
Unverified commit 4abafffe, authored by Yan Ru Pei and committed by GitHub
Browse files

fix(mocker): cut AIC replay overhead (#7692)


Signed-off-by: PeaBrane <yanrpei@gmail.com>
parent 95c45509
......@@ -22,6 +22,7 @@ DEFAULT_BACKEND_VERSIONS = {
"vllm": "0.12.0",
"sglang": "0.5.6.post2",
}
# Number of consecutive decode steps that share one sampled op-latency lookup
# in AicSession._predict_generation_latency: the decode loop advances by this
# stride and multiplies each sampled step latency by the steps it stands for,
# so a larger stride means fewer op queries per prediction.
DEFAULT_STATIC_STRIDE = 32
class AicSession:
......@@ -58,6 +59,10 @@ class AicSession:
self._session = InferenceSession(
model=model, database=database, backend=backend
)
self._database = database
self._model = model
# AIC models consistently expose model_path, but some do not surface model_name.
self._model_name = getattr(model, "model_name", None) or model_path
self._config = config
logger.info(
"AIC session initialized: backend=%s, system=%s, model=%s, tp=%d",
......@@ -67,23 +72,70 @@ class AicSession:
tp_size,
)
def _predict_context_latency(self, batch_size: int, isl: int, prefix: int) -> float:
    """Sum the latency of every context (prefill) op for one static batch.

    Each op in ``self._model.context_ops`` is queried directly against the
    stored database and the results are added up.

    Args:
        batch_size: Number of sequences in the batch.
        isl: Input sequence length, including any cached prefix.
        prefix: Number of leading tokens that are already cached.

    Returns:
        The summed result of all context op queries, as a float.

    Raises:
        ValueError: If ``isl - prefix`` is not positive, i.e. there is no
            new token left to process.
    """
    new_tokens = isl - prefix
    if new_tokens <= 0:
        raise ValueError(
            f"isl must be greater than prefix, got isl={isl}, prefix={prefix}"
        )

    token_rows = batch_size * new_tokens
    # Arguments that are identical for every op in this batch.
    shared_kwargs = {
        "batch_size": batch_size,
        "beam_width": 1,
        "s": new_tokens,
        "prefix": prefix,
        "model_name": self._model_name,
        "seq_imbalance_correction_scale": 1.0,
    }

    latency = 0.0
    for context_op in self._model.context_ops:
        # AIC operations identify kernels via Operation._name; there is no
        # public name accessor.
        kernel_name = getattr(context_op, "_name", "")
        if "logits_gemm" in kernel_name:
            # The logits GEMM is sized per sequence rather than per token.
            rows = batch_size
        else:
            rows = token_rows
        latency += float(context_op.query(self._database, x=rows, **shared_kwargs))
    return latency
def _predict_generation_latency(self, batch_size: int, isl: int, osl: int) -> float:
    """Estimate total decode latency by sampling one step per stride window.

    The ``osl - 1`` decode steps are split into windows of
    ``DEFAULT_STATIC_STRIDE`` steps. Only the first step of each window is
    queried against the generation ops; its latency is multiplied by the
    number of steps the window covers (the last window may be shorter).

    Args:
        batch_size: Number of sequences decoded together.
        isl: Input sequence length preceding the first generated token.
        osl: Output sequence length.

    Returns:
        The accumulated latency as a float; ``0.0`` when ``osl <= 1``.
    """
    decode_steps = osl - 1
    if decode_steps <= 0:
        return 0.0

    # BaseModel stores speculative decode width on _nextn, which
    # generation_ops scale by.
    scaled_batch = batch_size * (self._model._nextn + 1)

    latency = 0.0
    for window_start in range(0, decode_steps, DEFAULT_STATIC_STRIDE):
        query_kwargs = {
            "x": scaled_batch,
            "batch_size": scaled_batch,
            "beam_width": 1,
            "s": isl + window_start + 1,
            "model_name": self._model_name,
            "gen_seq_imbalance_correction_scale": 1.0,
        }
        sampled_step = 0.0
        for gen_op in self._model.generation_ops:
            sampled_step += float(gen_op.query(self._database, **query_kwargs))
        # The final window may hold fewer than a full stride of steps.
        steps_in_window = min(DEFAULT_STATIC_STRIDE, decode_steps - window_start)
        latency += sampled_step * steps_in_window
    return latency
def predict_prefill(
    self, batch_size: int, isl: int, prefix: int, osl: int
) -> float:
    """Predict prefill latency in ms. Parameters match AIC RuntimeConfig.

    The prediction is delegated to ``_predict_context_latency``, which
    queries the context ops directly instead of building a RuntimeConfig
    and running a full static session.

    Args:
        batch_size: Number of sequences in the batch.
        isl: Input sequence length, including any cached prefix.
        prefix: Requested cached-prefix length; clamped to ``isl - 1``.
        osl: Output sequence length. Not used by the prefill estimate;
            kept so the signature stays aligned with RuntimeConfig.

    Returns:
        The value returned by ``_predict_context_latency``.

    Raises:
        ValueError: Propagated from ``_predict_context_latency`` when no new
            token remains after clamping (for example ``isl <= 0``).
    """
    # AIC requires at least 1 new token (isl > prefix)
    actual_prefix = min(prefix, isl - 1) if isl > 0 else 0
    return self._predict_context_latency(batch_size, isl, actual_prefix)
def predict_decode(self, batch_size: int, isl: int, osl: int) -> float:
    """Predict decode (generation) latency in ms.

    The prediction is delegated to ``_predict_generation_latency``, which
    queries the generation ops directly instead of building a RuntimeConfig
    and running a full static session.

    Args:
        batch_size: Number of sequences decoded together.
        isl: Input sequence length preceding the first generated token.
        osl: Output sequence length.

    Returns:
        The value returned by ``_predict_generation_latency``.
    """
    return self._predict_generation_latency(batch_size, isl, osl)
def create_session(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment