"...git@developer.sourcefind.cn:chenpangpang/open-webui.git" did not exist on "75e51ecf6d373e152c9c182ca7d339fc50052253"
Unverified Commit 43b9d938 authored by Suraj Patil's avatar Suraj Patil Committed by GitHub
Browse files

[example/glue] fix compute_metrics_fn for bart like models (#7248)

* fix compute_metrics_fn

* p.predictions -> preds

* apply suggestions
parent 39062d05
...@@ -150,10 +150,11 @@ def main(): ...@@ -150,10 +150,11 @@ def main():
def build_compute_metrics_fn(task_name: str) -> Callable[[EvalPrediction], Dict]: def build_compute_metrics_fn(task_name: str) -> Callable[[EvalPrediction], Dict]:
def compute_metrics_fn(p: EvalPrediction): def compute_metrics_fn(p: EvalPrediction):
preds = p.predictions[0] if type(p.predictions) == tuple else p.predictions
if output_mode == "classification": if output_mode == "classification":
preds = np.argmax(p.predictions, axis=1) preds = np.argmax(preds, axis=1)
elif output_mode == "regression": else: # regression
preds = np.squeeze(p.predictions) preds = np.squeeze(preds)
return glue_compute_metrics(task_name, preds, p.label_ids) return glue_compute_metrics(task_name, preds, p.label_ids)
return compute_metrics_fn return compute_metrics_fn
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment