"tests/models/vscode:/vscode.git/clone" did not exist on "836921fdeb498820b71dcc7b70e990e828f4c6bc"
Unverified Commit da503ea0 authored by Vijay S Kalmath, committed by GitHub

Migrate metrics used in flax examples to Evaluate (#18348)

Currently, the flax examples use the `load_metric` function from the
Datasets library; this commit migrates those calls to the `load` function
from the Evaluate library.
parent a2586795
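
For reference, only the metric loader changes; the `compute` call sites stay the same. A minimal before/after sketch (the `"rouge"` metric and the toy predictions/references here are illustrative only, not taken from the examples):

```python
# Old API: metrics were loaded through the Datasets library.
# from datasets import load_metric
# metric = load_metric("rouge")

# New API: metrics live in the standalone Evaluate library.
import evaluate

metric = evaluate.load("rouge")
result = metric.compute(
    predictions=["hello there", "general kenobi"],
    references=["hello there", "general kenobi"],
)
print(result["rouge1"])
```

Running this requires `evaluate >= 0.2.0` plus the metric's own dependencies (`rouge-score` and `nltk` for ROUGE), which the updated requirements file already lists.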
@@ -5,3 +5,4 @@ nltk
 rouge-score
 seqeval
 tensorboard
+evaluate >= 0.2.0
\ No newline at end of file
@@ -31,10 +31,11 @@ from typing import Callable, Optional
 import datasets
 import nltk  # Here to have a nice missing dependency error message early on
 import numpy as np
-from datasets import Dataset, load_dataset, load_metric
+from datasets import Dataset, load_dataset
 from PIL import Image
 from tqdm import tqdm
 
+import evaluate
 import jax
 import jax.numpy as jnp
 import optax
@@ -811,7 +812,7 @@ def main():
             yield batch
 
     # Metric
-    metric = load_metric("rouge")
+    metric = evaluate.load("rouge")
 
     def postprocess_text(preds, labels):
         preds = [pred.strip() for pred in preds]
@@ -32,9 +32,10 @@ from typing import Any, Callable, Dict, Optional, Tuple
 import datasets
 import numpy as np
-from datasets import load_dataset, load_metric
+from datasets import load_dataset
 from tqdm import tqdm
 
+import evaluate
 import jax
 import jax.numpy as jnp
 import optax
@@ -776,7 +777,7 @@ def main():
         references = [{"id": ex["id"], "answers": ex[answer_column_name]} for ex in examples]
         return EvalPrediction(predictions=formatted_predictions, label_ids=references)
 
-    metric = load_metric("squad_v2" if data_args.version_2_with_negative else "squad")
+    metric = evaluate.load("squad_v2" if data_args.version_2_with_negative else "squad")
 
     def compute_metrics(p: EvalPrediction):
         return metric.compute(predictions=p.predictions, references=p.label_ids)
@@ -33,9 +33,10 @@ from typing import Callable, Optional
 import datasets
 import nltk  # Here to have a nice missing dependency error message early on
 import numpy as np
-from datasets import Dataset, load_dataset, load_metric
+from datasets import Dataset, load_dataset
 from tqdm import tqdm
 
+import evaluate
 import jax
 import jax.numpy as jnp
 import optax
@@ -656,7 +657,7 @@ def main():
     )
 
     # Metric
-    metric = load_metric("rouge")
+    metric = evaluate.load("rouge")
 
     def postprocess_text(preds, labels):
         preds = [pred.strip() for pred in preds]
@@ -27,9 +27,10 @@ from typing import Any, Callable, Dict, Optional, Tuple
 import datasets
 import numpy as np
-from datasets import load_dataset, load_metric
+from datasets import load_dataset
 from tqdm import tqdm
 
+import evaluate
 import jax
 import jax.numpy as jnp
 import optax
@@ -570,9 +571,9 @@ def main():
     p_eval_step = jax.pmap(eval_step, axis_name="batch")
 
     if data_args.task_name is not None:
-        metric = load_metric("glue", data_args.task_name)
+        metric = evaluate.load("glue", data_args.task_name)
     else:
-        metric = load_metric("accuracy")
+        metric = evaluate.load("accuracy")
 
     logger.info(f"===== Starting training ({num_epochs} epochs) =====")
     train_time = 0
@@ -29,9 +29,10 @@ from typing import Any, Callable, Dict, Optional, Tuple
 import datasets
 import numpy as np
-from datasets import ClassLabel, load_dataset, load_metric
+from datasets import ClassLabel, load_dataset
 from tqdm import tqdm
 
+import evaluate
 import jax
 import jax.numpy as jnp
 import optax
@@ -646,7 +647,7 @@ def main():
     p_eval_step = jax.pmap(eval_step, axis_name="batch")
 
-    metric = load_metric("seqeval")
+    metric = evaluate.load("seqeval")
 
     def get_labels(y_pred, y_true):
         # Transform predictions and references tensos to numpy arrays
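
Only the loading call changes in each example: a metric returned by `evaluate.load` exposes the same `compute` and `add_batch` interface the old `load_metric` objects did, so the surrounding evaluation loops are untouched. A hedged sketch of that interface, using the GLUE/MRPC metric and made-up toy batches purely for illustration:

```python
import evaluate

# Same pattern the text-classification example relies on: pick the GLUE
# sub-metric for a given task name, or plain accuracy otherwise.
metric = evaluate.load("glue", "mrpc")

# Toy batches standing in for the output of the pmapped eval step.
eval_batches = [
    {"predictions": [1, 0, 1], "labels": [1, 0, 0]},
    {"predictions": [0, 1], "labels": [0, 1]},
]

for batch in eval_batches:
    metric.add_batch(predictions=batch["predictions"], references=batch["labels"])

print(metric.compute())  # e.g. {"accuracy": ..., "f1": ...} for MRPC
```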