"...resnet50_tensorflow.git" did not exist on "9b57f41ce21cd7264c52140c9ab31cdfc5169fcd"
Unverified Commit 81fe8afa authored by Nicolas Patry's avatar Nicolas Patry Committed by GitHub
Browse files

Adding support for `hidden_states` and `attentions` in unbatching (#14420)

support.
parent f25a9332
......@@ -747,9 +747,14 @@ if is_torch_available():
else:
loader_batched = {}
for k, element in self._loader_batch_data.items():
if k == "past_key_values":
continue
if isinstance(element[self._loader_batch_index], torch.Tensor):
if k in {"hidden_states", "past_key_values", "attentions"} and isinstance(element, tuple):
if isinstance(element[0], torch.Tensor):
loader_batched[k] = tuple(el[self._loader_batch_index].unsqueeze(0) for el in element)
elif isinstance(element[0], np.ndarray):
loader_batched[k] = tuple(
np.expand_dims(el[self._loader_batch_index], 0) for el in element
)
elif isinstance(element[self._loader_batch_index], torch.Tensor):
loader_batched[k] = element[self._loader_batch_index].unsqueeze(0)
elif isinstance(element[self._loader_batch_index], np.ndarray):
loader_batched[k] = np.expand_dims(element[self._loader_batch_index], 0)
......
......@@ -27,6 +27,7 @@ from transformers import (
TOKENIZER_MAPPING,
AutoFeatureExtractor,
AutoTokenizer,
DistilBertForSequenceClassification,
IBertConfig,
RobertaConfig,
TextClassificationPipeline,
......@@ -322,6 +323,19 @@ class CommonPipelineTest(unittest.TestCase):
results.append(out)
self.assertEqual(len(results), 10)
@require_torch
def test_unbatch_attentions_hidden_states(self):
model = DistilBertForSequenceClassification.from_pretrained(
"Narsil/tiny-distilbert-sequence-classification", output_hidden_states=True, output_attentions=True
)
tokenizer = AutoTokenizer.from_pretrained("Narsil/tiny-distilbert-sequence-classification")
text_classifier = TextClassificationPipeline(model=model, tokenizer=tokenizer)
# Used to throw an error because `hidden_states` are a tuple of tensors
# instead of the expected tensor.
outputs = text_classifier(["This is great !"] * 20, batch_size=32)
self.assertEqual(len(outputs), 20)
@is_pipeline_test
class PipelinePadTest(unittest.TestCase):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment