Unverified Commit c2e3fa0b authored by Nicolas Patry's avatar Nicolas Patry Committed by GitHub
Browse files

Fixing single candidate_label return. (#24023)

parent 6307312d
......@@ -141,6 +141,8 @@ class ZeroShotImageClassificationPipeline(Pipeline):
if self.framework == "pt":
probs = logits.softmax(dim=-1).squeeze(-1)
scores = probs.tolist()
if not isinstance(scores, list):
scores = [scores]
elif self.framework == "tf":
probs = stable_softmax(logits, axis=-1)
scores = probs.numpy().tolist()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment