Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
aacd2123
Unverified
Commit
aacd2123
authored
Sep 09, 2021
by
Nicolas Patry
Committed by
GitHub
Sep 09, 2021
Browse files
Fixing #13381 (#13400)
* Fixing #13381 * Enabling automatic LED models.
parent
db514a75
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
40 additions
and
8 deletions
+40
-8
src/transformers/pipelines/zero_shot_classification.py
src/transformers/pipelines/zero_shot_classification.py
+26
-8
tests/test_pipelines_zero_shot.py
tests/test_pipelines_zero_shot.py
+14
-0
No files found.
src/transformers/pipelines/zero_shot_classification.py
View file @
aacd2123
...
...
@@ -88,7 +88,7 @@ class ZeroShotClassificationPipeline(Pipeline):
hypothesis_template
,
padding
=
True
,
add_special_tokens
=
True
,
truncation
=
TruncationStrategy
.
DO_NOT_TRUNCATE
,
truncation
=
TruncationStrategy
.
ONLY_FIRST
,
**
kwargs
):
"""
...
...
@@ -113,6 +113,7 @@ class ZeroShotClassificationPipeline(Pipeline):
)
inputs
.
append
(
model_input
)
else
:
try
:
inputs
=
self
.
tokenizer
(
sequence_pairs
,
add_special_tokens
=
add_special_tokens
,
...
...
@@ -120,6 +121,23 @@ class ZeroShotClassificationPipeline(Pipeline):
padding
=
padding
,
truncation
=
truncation
,
)
except
Exception
as
e
:
if
"too short"
in
str
(
e
):
# tokenizers might yell that we want to truncate
# to a value that is not even reached by the input.
# In that case we don't want to truncate.
# It seems there's not a really better way to catch that
# exception.
inputs
=
self
.
tokenizer
(
sequence_pairs
,
add_special_tokens
=
add_special_tokens
,
return_tensors
=
return_tensors
,
padding
=
padding
,
truncation
=
TruncationStrategy
.
DO_NOT_TRUNCATE
,
)
else
:
raise
e
return
inputs
...
...
tests/test_pipelines_zero_shot.py
View file @
aacd2123
...
...
@@ -105,6 +105,20 @@ class ZeroShotClassificationPipelineTests(unittest.TestCase, metaclass=PipelineT
zero_shot_classifier
.
model
.
config
.
label2id
=
original_label2id
self
.
assertEqual
(
original_entailment
,
zero_shot_classifier
.
entailment_id
)
@
require_torch
def
test_truncation
(
self
):
zero_shot_classifier
=
pipeline
(
"zero-shot-classification"
,
model
=
"sshleifer/tiny-distilbert-base-cased-distilled-squad"
,
framework
=
"pt"
,
)
# There was a regression in 4.10 for this
# Adding a test so we don't make the mistake again.
# https://github.com/huggingface/transformers/issues/13381#issuecomment-912343499
zero_shot_classifier
(
"Who are you voting for in 2020?"
*
100
,
candidate_labels
=
[
"politics"
,
"public health"
,
"science"
]
)
@
require_torch
def
test_small_model_pt
(
self
):
zero_shot_classifier
=
pipeline
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment