Commit aae4edb5 authored by Lysandre's avatar Lysandre
Browse files

Addressing review comment

parent 43b9d938
...@@ -150,7 +150,7 @@ def main(): ...@@ -150,7 +150,7 @@ def main():
def build_compute_metrics_fn(task_name: str) -> Callable[[EvalPrediction], Dict]: def build_compute_metrics_fn(task_name: str) -> Callable[[EvalPrediction], Dict]:
def compute_metrics_fn(p: EvalPrediction): def compute_metrics_fn(p: EvalPrediction):
preds = p.predictions[0] if type(p.predictions) == tuple else p.predictions preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
if output_mode == "classification": if output_mode == "classification":
preds = np.argmax(preds, axis=1) preds = np.argmax(preds, axis=1)
else: # regression else: # regression
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment