Unverified Commit 63be8e6f authored by Moses Hohman's avatar Moses Hohman Committed by GitHub
Browse files

Fix typo in classification function selection logic to improve code consistency (#32031)

Make problem_type condition consistent with num_labels condition

The latter condition generally overrides the former, so this is more of a code reading issue. I'm not sure the bug would ever actually get triggered under normal use.
parent 72fb02c4
...@@ -171,9 +171,9 @@ class ImageClassificationPipeline(Pipeline): ...@@ -171,9 +171,9 @@ class ImageClassificationPipeline(Pipeline):
def postprocess(self, model_outputs, function_to_apply=None, top_k=5): def postprocess(self, model_outputs, function_to_apply=None, top_k=5):
if function_to_apply is None: if function_to_apply is None:
if self.model.config.problem_type == "multi_label_classification" or self.model.config.num_labels == 1: if self.model.config.problem_type == "single_label_classification" or self.model.config.num_labels == 1:
function_to_apply = ClassificationFunction.SIGMOID function_to_apply = ClassificationFunction.SIGMOID
elif self.model.config.problem_type == "single_label_classification" or self.model.config.num_labels > 1: elif self.model.config.problem_type == "multi_label_classification" or self.model.config.num_labels > 1:
function_to_apply = ClassificationFunction.SOFTMAX function_to_apply = ClassificationFunction.SOFTMAX
elif hasattr(self.model.config, "function_to_apply") and function_to_apply is None: elif hasattr(self.model.config, "function_to_apply") and function_to_apply is None:
function_to_apply = self.model.config.function_to_apply function_to_apply = self.model.config.function_to_apply
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment