"...git@developer.sourcefind.cn:2222/OpenDAS/vllm_cscc.git" did not exist on "11b556878b958043e9c026919d4cba0236c85143"
Unverified commit 417fd28f, authored by Nick Hill, committed by GitHub
Browse files

[Model Runner V2] Fix pooling (#36019)


Signed-off-by: Nick Hill <nickhill123@gmail.com>
parent 7faba503
...@@ -95,8 +95,8 @@ class AsyncPoolingOutput(AsyncModelRunnerOutput): ...@@ -95,8 +95,8 @@ class AsyncPoolingOutput(AsyncModelRunnerOutput):
self.copy_event.record(copy_stream) self.copy_event.record(copy_stream)
def get_output(self) -> ModelRunnerOutput: def get_output(self) -> ModelRunnerOutput:
pooler_output = list(self.pooler_output_cpu.unbind(dim=0))
self.copy_event.synchronize() self.copy_event.synchronize()
pooler_output = self.pooler_output_cpu.unbind(dim=0)
if self.is_valid_cpu is not None: if self.is_valid_cpu is not None:
is_valid_cpu = self.is_valid_cpu.tolist() is_valid_cpu = self.is_valid_cpu.tolist()
for i, is_valid in enumerate(is_valid_cpu): for i, is_valid in enumerate(is_valid_cpu):
......
...@@ -1117,7 +1117,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1117,7 +1117,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# The prior execute_model call must have failed. # The prior execute_model call must have failed.
return None return None
input_batch, _, _, _, hidden_states, _, kv_connector_output = ( input_batch, _, _, _, hidden_states, _, kv_connector_output, _ = (
self.execute_model_state self.execute_model_state
) )
self.execute_model_state = None self.execute_model_state = None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment