Unverified commit 5898b135, authored by Nick Hill, committed by GitHub
Browse files

[BugFix] Fix KVConnectorOutput TPU breakage (#22598)


Signed-off-by: Nick Hill <nhill@redhat.com>
parent b799f4b9
......@@ -179,6 +179,13 @@ def create_model_runner_output(
sampled_token = EOS_TOKEN_ID if use_eos else 0
sampled_token_ids = [[sampled_token] for _ in req_ids]
kv_connector_output = None if (
finished_sending is None
and finished_recving is None) else KVConnectorOutput(
finished_sending=finished_sending,
finished_recving=finished_recving,
)
# Make output data structure.
return ModelRunnerOutput(
req_ids=req_ids,
......@@ -188,10 +195,7 @@ def create_model_runner_output(
logprobs=None,
prompt_logprobs_dict={},
pooler_output=None,
kv_connector_output=KVConnectorOutput(
finished_sending=finished_sending,
finished_recving=finished_recving,
),
kv_connector_output=kv_connector_output,
)
......
......@@ -1151,8 +1151,8 @@ class Scheduler(SchedulerInterface):
scheduler the request during the next step.
"""
assert self.connector is not None
self.connector.update_connector_output(kv_connector_output)
if self.connector is not None:
self.connector.update_connector_output(kv_connector_output)
# KV Connector:: update recv and send status from last step.
for req_id in (kv_connector_output.finished_recving or ()):
......
......@@ -1138,6 +1138,13 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
i, target_slice] = valid_sampled_token_ids[i]
req_state.output_token_ids.extend(valid_sampled_token_ids[i])
kv_connector_output = None if (
finished_sending is None
and finished_recving is None) else KVConnectorOutput(
finished_sending=finished_sending,
finished_recving=finished_recving,
)
model_runner_output = ModelRunnerOutput(
req_ids=req_ids,
req_id_to_index=self.input_batch.req_id_to_index,
......@@ -1146,10 +1153,8 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
logprobs=logprobs_lists,
prompt_logprobs_dict=prompt_logprobs_dict,
pooler_output=[],
kv_connector_output=KVConnectorOutput(
finished_sending=finished_sending,
finished_recving=finished_recving,
))
kv_connector_output=kv_connector_output,
)
# Check there are no new graphs compiled - all the graphs should be
# captured and compiled during warm up.
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment