Unverified Commit 81fe8afa authored by Nicolas Patry, committed by GitHub

Adding support for `hidden_states` and `attentions` in unbatching (#14420)
parent f25a9332
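
For background: when a model is loaded with `output_hidden_states=True` or `output_attentions=True`, those outputs come back as tuples with one tensor per layer rather than as a single tensor, which is the case the unbatching code below has to handle. A minimal sketch of that shape difference (the tiny checkpoint name is taken from the test added in this commit; shapes are illustrative):

```python
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

name = "Narsil/tiny-distilbert-sequence-classification"
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModelForSequenceClassification.from_pretrained(
    name, output_hidden_states=True, output_attentions=True
)

inputs = tokenizer("This is great !", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

print(type(outputs.logits))         # <class 'torch.Tensor'> -> indexable per sample
print(type(outputs.hidden_states))  # <class 'tuple'>        -> one tensor per layer
print(type(outputs.attentions))     # <class 'tuple'>        -> one tensor per layer
```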
@@ -747,9 +747,14 @@ if is_torch_available():
             else:
                 loader_batched = {}
                 for k, element in self._loader_batch_data.items():
-                    if k == "past_key_values":
-                        continue
-                    if isinstance(element[self._loader_batch_index], torch.Tensor):
+                    if k in {"hidden_states", "past_key_values", "attentions"} and isinstance(element, tuple):
+                        if isinstance(element[0], torch.Tensor):
+                            loader_batched[k] = tuple(el[self._loader_batch_index].unsqueeze(0) for el in element)
+                        elif isinstance(element[0], np.ndarray):
+                            loader_batched[k] = tuple(
+                                np.expand_dims(el[self._loader_batch_index], 0) for el in element
+                            )
+                    elif isinstance(element[self._loader_batch_index], torch.Tensor):
                         loader_batched[k] = element[self._loader_batch_index].unsqueeze(0)
                     elif isinstance(element[self._loader_batch_index], np.ndarray):
                         loader_batched[k] = np.expand_dims(element[self._loader_batch_index], 0)
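
Standalone, the new branch behaves as in this sketch (the dict below is a hypothetical stand-in for `self._loader_batch_data`; names and shapes are illustrative, not part of the commit):

```python
import torch

# Hypothetical stand-in for self._loader_batch_data: "logits" is an ordinary
# batched tensor, while "hidden_states" is a tuple with one tensor per layer.
loader_batch_data = {
    "logits": torch.randn(4, 2),  # (batch, num_labels)
    "hidden_states": tuple(torch.randn(4, 7, 16) for _ in range(3)),
}
loader_batch_index = 1  # which item of the batch to peel off

loader_batched = {}
for k, element in loader_batch_data.items():
    if k in {"hidden_states", "past_key_values", "attentions"} and isinstance(element, tuple):
        # Slice each per-layer tensor separately and restore a batch dim of 1.
        loader_batched[k] = tuple(el[loader_batch_index].unsqueeze(0) for el in element)
    elif isinstance(element[loader_batch_index], torch.Tensor):
        loader_batched[k] = element[loader_batch_index].unsqueeze(0)

print(loader_batched["logits"].shape)            # torch.Size([1, 2])
print(loader_batched["hidden_states"][0].shape)  # torch.Size([1, 7, 16])
```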
@@ -27,6 +27,7 @@ from transformers import (
     TOKENIZER_MAPPING,
     AutoFeatureExtractor,
     AutoTokenizer,
+    DistilBertForSequenceClassification,
     IBertConfig,
     RobertaConfig,
     TextClassificationPipeline,
@@ -322,6 +323,19 @@ class CommonPipelineTest(unittest.TestCase):
             results.append(out)
         self.assertEqual(len(results), 10)
 
+    @require_torch
+    def test_unbatch_attentions_hidden_states(self):
+        model = DistilBertForSequenceClassification.from_pretrained(
+            "Narsil/tiny-distilbert-sequence-classification", output_hidden_states=True, output_attentions=True
+        )
+        tokenizer = AutoTokenizer.from_pretrained("Narsil/tiny-distilbert-sequence-classification")
+        text_classifier = TextClassificationPipeline(model=model, tokenizer=tokenizer)
+        # Used to throw an error because `hidden_states` are a tuple of tensors
+        # instead of the expected tensor.
+        outputs = text_classifier(["This is great !"] * 20, batch_size=32)
+        self.assertEqual(len(outputs), 20)
+
 
 @is_pipeline_test
 class PipelinePadTest(unittest.TestCase):
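
The new test can be run locally with `python -m pytest -sv tests/test_pipelines_common.py -k test_unbatch_attentions_hidden_states`; the file path is an assumption based on the `CommonPipelineTest` class shown above.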