Unverified Commit 4dd79783 authored by Roger Wang's avatar Roger Wang Committed by GitHub
Browse files

[Bugfix] Fix regression on pooling models from PR#29621 (#29921)


Signed-off-by: default avatarRoger Wang <hey@rogerw.io>
Co-authored-by: default avatargemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
parent 5cdd6645
...@@ -134,11 +134,17 @@ class EmbeddingItems( ...@@ -134,11 +134,17 @@ class EmbeddingItems(
or a list of embedding tensors (one per item). or a list of embedding tensors (one per item).
""" """
def _unwrap(
self, item: torch.Tensor | MediaWithBytes[torch.Tensor]
) -> torch.Tensor:
"""Extract media from wrapper if present."""
return item.media if isinstance(item, MediaWithBytes) else item
def get_count(self) -> int: def get_count(self) -> int:
return len(self.data) return len(self.data)
def get(self, index: int) -> torch.Tensor: def get(self, index: int) -> torch.Tensor:
return self.data[index] return self._unwrap(self.data[index])
def get_processor_data(self) -> Mapping[str, object]: def get_processor_data(self) -> Mapping[str, object]:
return {} return {}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment