"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "2b81f72be9fa6d69734ae27cfcbfd72b04988fe4"
Unverified Commit 63be8e6f authored by Moses Hohman's avatar Moses Hohman Committed by GitHub
Browse files

Fix typo in classification function selection logic to improve code consistency (#32031)

Make problem_type condition consistent with num_labels condition

The latter condition generally overrides the former, so this is more of a code reading issue. I'm not sure the bug would ever actually get triggered under normal use.
parent 72fb02c4
@@ -171,9 +171,9 @@ class ImageClassificationPipeline(Pipeline):
     def postprocess(self, model_outputs, function_to_apply=None, top_k=5):
         if function_to_apply is None:
-            if self.model.config.problem_type == "multi_label_classification" or self.model.config.num_labels == 1:
+            if self.model.config.problem_type == "single_label_classification" or self.model.config.num_labels == 1:
                 function_to_apply = ClassificationFunction.SIGMOID
-            elif self.model.config.problem_type == "single_label_classification" or self.model.config.num_labels > 1:
+            elif self.model.config.problem_type == "multi_label_classification" or self.model.config.num_labels > 1:
                 function_to_apply = ClassificationFunction.SOFTMAX
             elif hasattr(self.model.config, "function_to_apply") and function_to_apply is None:
                 function_to_apply = self.model.config.function_to_apply
...
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment