Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
aacd2123
Unverified
Commit
aacd2123
authored
Sep 09, 2021
by
Nicolas Patry
Committed by
GitHub
Sep 09, 2021
Browse files
Fixing #13381 (#13400)
* Fixing #13381 * Enabling automatic LED models.
parent
db514a75
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
40 additions
and
8 deletions
+40
-8
src/transformers/pipelines/zero_shot_classification.py
src/transformers/pipelines/zero_shot_classification.py
+26
-8
tests/test_pipelines_zero_shot.py
tests/test_pipelines_zero_shot.py
+14
-0
No files found.
src/transformers/pipelines/zero_shot_classification.py
View file @
aacd2123
...
@@ -88,7 +88,7 @@ class ZeroShotClassificationPipeline(Pipeline):
...
@@ -88,7 +88,7 @@ class ZeroShotClassificationPipeline(Pipeline):
hypothesis_template
,
hypothesis_template
,
padding
=
True
,
padding
=
True
,
add_special_tokens
=
True
,
add_special_tokens
=
True
,
truncation
=
TruncationStrategy
.
DO_NOT_TRUNCATE
,
truncation
=
TruncationStrategy
.
ONLY_FIRST
,
**
kwargs
**
kwargs
):
):
"""
"""
...
@@ -113,6 +113,7 @@ class ZeroShotClassificationPipeline(Pipeline):
...
@@ -113,6 +113,7 @@ class ZeroShotClassificationPipeline(Pipeline):
)
)
inputs
.
append
(
model_input
)
inputs
.
append
(
model_input
)
else
:
else
:
try
:
inputs
=
self
.
tokenizer
(
inputs
=
self
.
tokenizer
(
sequence_pairs
,
sequence_pairs
,
add_special_tokens
=
add_special_tokens
,
add_special_tokens
=
add_special_tokens
,
...
@@ -120,6 +121,23 @@ class ZeroShotClassificationPipeline(Pipeline):
...
@@ -120,6 +121,23 @@ class ZeroShotClassificationPipeline(Pipeline):
padding
=
padding
,
padding
=
padding
,
truncation
=
truncation
,
truncation
=
truncation
,
)
)
except
Exception
as
e
:
if
"too short"
in
str
(
e
):
# tokenizers might yell that we want to truncate
# to a value that is not even reached by the input.
# In that case we don't want to truncate.
# It seems there's not a really better way to catch that
# exception.
inputs
=
self
.
tokenizer
(
sequence_pairs
,
add_special_tokens
=
add_special_tokens
,
return_tensors
=
return_tensors
,
padding
=
padding
,
truncation
=
TruncationStrategy
.
DO_NOT_TRUNCATE
,
)
else
:
raise
e
return
inputs
return
inputs
...
...
tests/test_pipelines_zero_shot.py
View file @
aacd2123
...
@@ -105,6 +105,20 @@ class ZeroShotClassificationPipelineTests(unittest.TestCase, metaclass=PipelineT
...
@@ -105,6 +105,20 @@ class ZeroShotClassificationPipelineTests(unittest.TestCase, metaclass=PipelineT
zero_shot_classifier
.
model
.
config
.
label2id
=
original_label2id
zero_shot_classifier
.
model
.
config
.
label2id
=
original_label2id
self
.
assertEqual
(
original_entailment
,
zero_shot_classifier
.
entailment_id
)
self
.
assertEqual
(
original_entailment
,
zero_shot_classifier
.
entailment_id
)
@
require_torch
def
test_truncation
(
self
):
zero_shot_classifier
=
pipeline
(
"zero-shot-classification"
,
model
=
"sshleifer/tiny-distilbert-base-cased-distilled-squad"
,
framework
=
"pt"
,
)
# There was a regression in 4.10 for this
# Adding a test so we don't make the mistake again.
# https://github.com/huggingface/transformers/issues/13381#issuecomment-912343499
zero_shot_classifier
(
"Who are you voting for in 2020?"
*
100
,
candidate_labels
=
[
"politics"
,
"public health"
,
"science"
]
)
@
require_torch
@
require_torch
def
test_small_model_pt
(
self
):
def
test_small_model_pt
(
self
):
zero_shot_classifier
=
pipeline
(
zero_shot_classifier
=
pipeline
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment