Unverified Commit 417fd28f authored by Nick Hill's avatar Nick Hill Committed by GitHub
Browse files

[Model Runner V2] Fix pooling (#36019)


Signed-off-by: default avatarNick Hill <nickhill123@gmail.com>
parent 7faba503
......@@ -95,8 +95,8 @@ class AsyncPoolingOutput(AsyncModelRunnerOutput):
self.copy_event.record(copy_stream)
def get_output(self) -> ModelRunnerOutput:
pooler_output = list(self.pooler_output_cpu.unbind(dim=0))
self.copy_event.synchronize()
pooler_output = self.pooler_output_cpu.unbind(dim=0)
if self.is_valid_cpu is not None:
is_valid_cpu = self.is_valid_cpu.tolist()
for i, is_valid in enumerate(is_valid_cpu):
......
......@@ -1117,7 +1117,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# The prior execute_model call must have failed.
return None
input_batch, _, _, _, hidden_states, _, kv_connector_output = (
input_batch, _, _, _, hidden_states, _, kv_connector_output, _ = (
self.execute_model_state
)
self.execute_model_state = None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment