"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "fa8ee8e85546469678c00e91e0866c8429fccef8"
Unverified Commit 3deffc1d authored by Joe Davison's avatar Joe Davison Committed by GitHub
Browse files

Zero shot classification pipeline (#5760)

* add initial zero-shot pipeline

* change default args

* update default template

* add label string splitting

* add str labels support, remove nli from name

* style

* add input validation and working tf defaults

* tests

* quality check

* add docstring to __call__

* add slow tests

* Change truncation to only_first

also lower precision on tests for readability

* style
parent 1246b20f
...@@ -816,6 +816,159 @@ class TextClassificationPipeline(Pipeline): ...@@ -816,6 +816,159 @@ class TextClassificationPipeline(Pipeline):
] ]
class ZeroShotClassificationArgumentHandler(ArgumentHandler):
    """
    Handles arguments for zero-shot text classification by turning each possible label into an NLI
    premise/hypothesis pair.
    """

    def _parse_labels(self, labels):
        # Accept either a list of labels or a single comma-separated string of labels.
        if isinstance(labels, str):
            labels = [label.strip() for label in labels.split(",")]
        return labels

    def __call__(self, sequences, labels, hypothesis_template):
        """Expand sequences and candidate labels into premise/hypothesis pairs.

        Args:
            sequences: A string or list of strings to classify.
            labels: Candidate labels as a list of strings or a comma-separated string.
            hypothesis_template: A format string with a ``{}`` slot where each label is inserted.

        Returns:
            A flat list of ``[sequence, formatted_hypothesis]`` pairs, ordered sequence-major
            (all labels for the first sequence, then all labels for the second, ...).

        Raises:
            ValueError: If no labels or no sequences are provided, or if the template contains
                no formatting slot for the label.
        """
        # Parse labels BEFORE validating so that the template check below formats a full
        # label rather than the first character of a comma-separated label string.
        labels = self._parse_labels(labels)
        if len(labels) == 0 or len(sequences) == 0:
            raise ValueError("You must include at least one label and at least one sequence.")
        # If formatting with a label leaves the template unchanged, it has no "{}" slot.
        if hypothesis_template.format(labels[0]) == hypothesis_template:
            raise ValueError(
                (
                    'The provided hypothesis_template "{}" was not able to be formatted with the target labels. '
                    "Make sure the passed template includes formatting syntax such as {{}} where the label should go."
                ).format(hypothesis_template)
            )

        if isinstance(sequences, str):
            sequences = [sequences]

        sequence_pairs = []
        for sequence in sequences:
            sequence_pairs.extend([[sequence, hypothesis_template.format(label)] for label in labels])
        return sequence_pairs
class ZeroShotClassificationPipeline(Pipeline):
    """
    NLI-based zero-shot classification pipeline using a ModelForSequenceClassification head with models trained on
    NLI tasks.

    Any combination of sequences and labels can be passed and each combination will be posed as a premise/hypothesis
    pair and passed to the pre-trained model. The logit for `entailment` is then taken as the logit for the
    candidate label being valid. Any NLI model can be used as long as the first output logit corresponds to
    `contradiction` and the last to `entailment`.

    This pipeline can currently be loaded from the :func:`~transformers.pipeline` method using the following task
    identifier(s):

    - "zero-shot-classification"

    The models that this pipeline can use are models that have been fine-tuned on a Natural Language Inference task.
    See the up-to-date list of available models on
    `huggingface.co/models <https://huggingface.co/models?search=nli>`__.

    Arguments:
        model (:obj:`~transformers.PreTrainedModel` or :obj:`~transformers.TFPreTrainedModel`):
            The model that will be used by the pipeline to make predictions. This needs to be a model inheriting from
            :class:`~transformers.PreTrainedModel` for PyTorch and :class:`~transformers.TFPreTrainedModel` for
            TensorFlow.
        tokenizer (:obj:`~transformers.PreTrainedTokenizer`):
            The tokenizer that will be used by the pipeline to encode data for the model. This object inherits from
            :class:`~transformers.PreTrainedTokenizer`.
        modelcard (:obj:`str` or :class:`~transformers.ModelCard`, `optional`, defaults to :obj:`None`):
            Model card attributed to the model for this pipeline.
        framework (:obj:`str`, `optional`, defaults to :obj:`None`):
            The framework to use, either "pt" for PyTorch or "tf" for TensorFlow. The specified framework must be
            installed.

            If no framework is specified, will default to the one currently installed. If no framework is specified
            and both frameworks are installed, will default to PyTorch.
        args_parser (:class:`~transformers.pipelines.ArgumentHandler`, `optional`, defaults to :obj:`None`):
            Reference to the object in charge of parsing supplied pipeline parameters.
        device (:obj:`int`, `optional`, defaults to :obj:`-1`):
            Device ordinal for CPU/GPU supports. Setting this to -1 will leverage CPU, >=0 will run the model
            on the associated CUDA device id.
    """

    def __init__(self, args_parser=ZeroShotClassificationArgumentHandler(), *args, **kwargs):
        super().__init__(*args, args_parser=args_parser, **kwargs)

    def _parse_and_tokenize(self, *args, padding=True, add_special_tokens=True, **kwargs):
        """
        Parse arguments and tokenize with truncation="only_first" so that the hypothesis (label)
        is never truncated, only the premise (sequence).
        """
        inputs = self._args_parser(*args, **kwargs)
        inputs = self.tokenizer(
            inputs,
            add_special_tokens=add_special_tokens,
            return_tensors=self.framework,
            padding=padding,
            truncation="only_first",
        )

        return inputs

    @staticmethod
    def _stable_softmax(logits):
        """Softmax over the last axis, with the per-row max subtracted first so that
        np.exp cannot overflow for large-magnitude logits (mathematically identical)."""
        shifted = logits - logits.max(-1, keepdims=True)
        exp = np.exp(shifted)
        return exp / exp.sum(-1, keepdims=True)

    def __call__(self, sequences, candidate_labels, hypothesis_template="This example is {}.", multi_class=False):
        """
        NLI-based zero-shot classification. Any combination of sequences and labels can be passed and each
        combination will be posed as a premise/hypothesis pair and passed to the pre-trained model. The logit for
        `entailment` is then taken as the logit for the candidate label being valid. Any NLI model can be used as
        long as the first output logit corresponds to `contradiction` and the last to `entailment`.

        Args:
            sequences (:obj:`str` or obj:`List`):
                The sequence or sequences to classify. Truncated if model input is too large.
            candidate_labels (:obj:`str` or obj:`List`):
                The set of possible class labels to classify each sequence into. Can be a single label, a string of
                comma-separated labels, or a list of labels.
            hypothesis_template (obj:`str`, defaults to "This example is {}."):
                The template used to turn each label into an NLI-style hypothesis. This template must include a {}
                or similar syntax for the candidate label to be inserted into the template. For example, the default
                template is "This example is {}." With the candidate label "sports", this would be fed into the model
                like `<cls> sequence to classify <sep> This example is sports . <sep>`. The default template works
                well in many cases, but it may be worthwhile to experiment with different templates depending on the
                task setting.
            multi_class (obj:`bool`, defaults to False):
                When False, it is assumed that only one candidate label can be true, and the scores are normalized
                such that the sum of the label likelihoods for each sequence is 1. When True, the labels are
                considered independent and probabilities are normalized for each candidate by doing a softmax of
                the entailment score vs. the contradiction score.

        Returns:
            A dict (or list of dicts, one per input sequence) with keys "sequence", "labels" (sorted by
            descending score) and "scores".
        """
        outputs = super().__call__(sequences, candidate_labels, hypothesis_template)
        num_sequences = 1 if isinstance(sequences, str) else len(sequences)
        candidate_labels = self._args_parser._parse_labels(candidate_labels)
        # Logits come back flat, one row per (sequence, label) pair; regroup by sequence.
        reshaped_outputs = outputs.reshape((num_sequences, len(candidate_labels), -1))

        if len(candidate_labels) == 1:
            # With a single label there is nothing to normalize against, so force the
            # independent entailment-vs-contradiction formulation.
            multi_class = True

        if not multi_class:
            # softmax the "entailment" logits over all candidate labels
            entail_logits = reshaped_outputs[..., -1]
            scores = self._stable_softmax(entail_logits)
        else:
            # softmax over the entailment vs. contradiction dim for each label independently
            entail_contr_logits = reshaped_outputs[..., [0, -1]]
            scores = self._stable_softmax(entail_contr_logits)
            scores = scores[..., 1]

        result = []
        for iseq in range(num_sequences):
            top_inds = list(reversed(scores[iseq].argsort()))
            result.append(
                {
                    "sequence": sequences if num_sequences == 1 else sequences[iseq],
                    "labels": [candidate_labels[i] for i in top_inds],
                    "scores": scores[iseq][top_inds].tolist(),
                }
            )

        if len(result) == 1:
            return result[0]
        return result
class FillMaskPipeline(Pipeline): class FillMaskPipeline(Pipeline):
""" """
Masked language modeling prediction pipeline using ModelWithLMHead head. See the Masked language modeling prediction pipeline using ModelWithLMHead head. See the
...@@ -1813,6 +1966,16 @@ SUPPORTED_TASKS = { ...@@ -1813,6 +1966,16 @@ SUPPORTED_TASKS = {
"pt": AutoModelWithLMHead if is_torch_available() else None, "pt": AutoModelWithLMHead if is_torch_available() else None,
"default": {"model": {"pt": "gpt2", "tf": "gpt2"}}, "default": {"model": {"pt": "gpt2", "tf": "gpt2"}},
}, },
"zero-shot-classification": {
"impl": ZeroShotClassificationPipeline,
"tf": TFAutoModelForSequenceClassification if is_tf_available() else None,
"pt": AutoModelForSequenceClassification if is_torch_available() else None,
"default": {
"model": {"pt": "facebook/bart-large-mnli", "tf": "roberta-large-mnli"},
"config": {"pt": "facebook/bart-large-mnli", "tf": "roberta-large-mnli"},
"tokenizer": {"pt": "facebook/bart-large-mnli", "tf": "roberta-large-mnli"},
},
},
} }
......
...@@ -318,6 +318,138 @@ class MonoColumnInputTestCase(unittest.TestCase): ...@@ -318,6 +318,138 @@ class MonoColumnInputTestCase(unittest.TestCase):
QA_FINETUNED_MODELS = ["sshleifer/tiny-distilbert-base-cased-distilled-squad"] QA_FINETUNED_MODELS = ["sshleifer/tiny-distilbert-base-cased-distilled-squad"]
class ZeroShotClassificationPipelineTests(unittest.TestCase):
    def _test_scores_sum_to_one(self, result):
        # Scores for a single sequence should form a probability distribution.
        # Use the builtin sum() rather than a manual accumulation loop that shadows it.
        self.assertAlmostEqual(sum(result["scores"]), 1.0)

    def _test_zero_shot_pipeline(self, nlp):
        """Smoke-test every supported input shape plus a set of invalid inputs."""
        output_keys = {"sequence", "labels", "scores"}
        valid_mono_inputs = [
            {"sequences": "Who are you voting for in 2020?", "candidate_labels": "politics"},
            {"sequences": "Who are you voting for in 2020?", "candidate_labels": ["politics"]},
            {"sequences": "Who are you voting for in 2020?", "candidate_labels": "politics, public health"},
            {"sequences": "Who are you voting for in 2020?", "candidate_labels": ["politics", "public health"]},
            {"sequences": ["Who are you voting for in 2020?"], "candidate_labels": "politics"},
            {
                "sequences": "Who are you voting for in 2020?",
                "candidate_labels": "politics",
                "hypothesis_template": "This text is about {}",
            },
        ]
        valid_multi_input = {
            "sequences": ["Who are you voting for in 2020?", "What is the capital of Spain?"],
            "candidate_labels": "politics",
        }
        invalid_inputs = [
            {"sequences": None, "candidate_labels": "politics"},
            {"sequences": "", "candidate_labels": "politics"},
            {"sequences": "Who are you voting for in 2020?", "candidate_labels": None},
            {"sequences": "Who are you voting for in 2020?", "candidate_labels": ""},
            {
                "sequences": "Who are you voting for in 2020?",
                "candidate_labels": "politics",
                "hypothesis_template": None,
            },
            {
                "sequences": "Who are you voting for in 2020?",
                "candidate_labels": "politics",
                "hypothesis_template": "",
            },
            {
                "sequences": "Who are you voting for in 2020?",
                "candidate_labels": "politics",
                "hypothesis_template": "Template without formatting syntax.",
            },
        ]
        self.assertIsNotNone(nlp)

        for mono_input in valid_mono_inputs:
            mono_result = nlp(**mono_input)
            self.assertIsInstance(mono_result, dict)
            if len(mono_result["labels"]) > 1:
                self._test_scores_sum_to_one(mono_result)

            for key in output_keys:
                self.assertIn(key, mono_result)

        multi_result = nlp(**valid_multi_input)
        self.assertIsInstance(multi_result, list)
        self.assertIsInstance(multi_result[0], dict)
        self.assertEqual(len(multi_result), len(valid_multi_input["sequences"]))

        for result in multi_result:
            for key in output_keys:
                self.assertIn(key, result)

            if len(result["labels"]) > 1:
                self._test_scores_sum_to_one(result)

        for bad_input in invalid_inputs:
            self.assertRaises(Exception, nlp, **bad_input)

    def _test_zero_shot_pipeline_outputs(self, nlp):
        """Check concrete label rankings and (low-precision) scores for a real model."""
        inputs = [
            {
                "sequences": "Who are you voting for in 2020?",
                "candidate_labels": ["politics", "public health", "science"],
            },
            {
                "sequences": "The dominant sequence transduction models are based on complex recurrent or convolutional neural networks in an encoder-decoder configuration. The best performing models also connect the encoder and decoder through an attention mechanism. We propose a new simple network architecture, the Transformer, based solely on attention mechanisms, dispensing with recurrence and convolutions entirely. Experiments on two machine translation tasks show these models to be superior in quality while being more parallelizable and requiring significantly less time to train. Our model achieves 28.4 BLEU on the WMT 2014 English-to-German translation task, improving over the existing best results, including ensembles by over 2 BLEU. On the WMT 2014 English-to-French translation task, our model establishes a new single-model state-of-the-art BLEU score of 41.8 after training for 3.5 days on eight GPUs, a small fraction of the training costs of the best models from the literature. We show that the Transformer generalizes well to other tasks by applying it successfully to English constituency parsing both with large and limited training data.",
                "candidate_labels": ["machine learning", "statistics", "translation", "vision"],
                "multi_class": True,
            },
        ]

        expected_outputs = [
            {
                "sequence": "Who are you voting for in 2020?",
                "labels": ["politics", "public health", "science"],
                "scores": [0.975, 0.015, 0.008],
            },
            {
                "sequence": "The dominant sequence transduction models are based on complex recurrent or convolutional neural networks in an encoder-decoder configuration. The best performing models also connect the encoder and decoder through an attention mechanism. We propose a new simple network architecture, the Transformer, based solely on attention mechanisms, dispensing with recurrence and convolutions entirely. Experiments on two machine translation tasks show these models to be superior in quality while being more parallelizable and requiring significantly less time to train. Our model achieves 28.4 BLEU on the WMT 2014 English-to-German translation task, improving over the existing best results, including ensembles by over 2 BLEU. On the WMT 2014 English-to-French translation task, our model establishes a new single-model state-of-the-art BLEU score of 41.8 after training for 3.5 days on eight GPUs, a small fraction of the training costs of the best models from the literature. We show that the Transformer generalizes well to other tasks by applying it successfully to English constituency parsing both with large and limited training data.",
                "labels": ["translation", "machine learning", "vision", "statistics"],
                "scores": [0.817, 0.712, 0.018, 0.017],
            },
        ]

        # "pipeline_input" rather than "input" to avoid shadowing the builtin.
        for pipeline_input, expected_output in zip(inputs, expected_outputs):
            output = nlp(**pipeline_input)
            for key in output:
                if key == "scores":
                    for output_score, expected_score in zip(output[key], expected_output[key]):
                        self.assertAlmostEqual(output_score, expected_score, places=2)
                else:
                    self.assertEqual(output[key], expected_output[key])

    @require_torch
    def test_torch_zero_shot_classification(self):
        for model_name in TEXT_CLASSIF_FINETUNED_MODELS:
            nlp = pipeline(task="zero-shot-classification", model=model_name, tokenizer=model_name)
            self._test_zero_shot_pipeline(nlp)

    @require_tf
    def test_tf_zero_shot_classification(self):
        for model_name in TEXT_CLASSIF_FINETUNED_MODELS:
            nlp = pipeline(task="zero-shot-classification", model=model_name, tokenizer=model_name, framework="tf")
            self._test_zero_shot_pipeline(nlp)

    @slow
    @require_torch
    def test_torch_zero_shot_outputs(self):
        nlp = pipeline(task="zero-shot-classification", model="roberta-large-mnli")
        self._test_zero_shot_pipeline_outputs(nlp)

    @slow
    @require_tf
    def test_tf_zero_shot_outputs(self):
        nlp = pipeline(task="zero-shot-classification", model="roberta-large-mnli", framework="tf")
        self._test_zero_shot_pipeline_outputs(nlp)
class QAPipelineTests(unittest.TestCase): class QAPipelineTests(unittest.TestCase):
def _test_qa_pipeline(self, nlp): def _test_qa_pipeline(self, nlp):
output_keys = {"score", "answer", "start", "end"} output_keys = {"score", "answer", "start", "end"}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment