Unverified Commit 9ba41558 authored by Robert Shaw's avatar Robert Shaw Committed by GitHub
Browse files

[BugFix] Fix Embedding Models with TP>1 (#5075)

parent d4f39859
......@@ -79,6 +79,10 @@ class EmbeddingModelRunner(ModelRunner):
execute_model_kwargs.update({"image_input": multi_modal_input})
hidden_states = model_executable(**execute_model_kwargs)
# Only perform pooling in the driver worker.
if not self.is_driver_worker:
return None
return self.model.pooler(hidden_states=hidden_states,
pooling_metadata=pooling_metadata)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment