Unverified Commit ad7f714d authored by Chukwuma Nwaugha's avatar Chukwuma Nwaugha Committed by GitHub
Browse files

hfrunner.classify should return list[list[float]] not list[str] (#29671)


Signed-off-by: default avatarChukwuma Nwaugha <nwaughac@gmail.com>
parent f4341f45
......@@ -459,14 +459,17 @@ class HfRunner:
embeddings.append(embedding)
return embeddings
def classify(self, prompts: list[str]) -> list[str]:
def classify(self, prompts: list[str]) -> list[list[float]]:
# output is final logits
all_inputs = self.get_inputs(prompts)
outputs = []
outputs: list[list[float]] = []
problem_type = getattr(self.config, "problem_type", "")
for inputs in all_inputs:
output = self.model(**self.wrap_device(inputs))
assert isinstance(output.logits, torch.Tensor)
if problem_type == "regression":
logits = output.logits[0].tolist()
elif problem_type == "multi_label_classification":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment