Unverified Commit da503ea0 authored by Vijay S Kalmath's avatar Vijay S Kalmath Committed by GitHub
Browse files

Migrate metrics used in flax examples to Evaluate (#18348)

Currently, flax examples use the `load_metric` function from the
Datasets library; this commit migrates those calls to the `load` function
from the Evaluate library.
parent a2586795
...@@ -4,4 +4,5 @@ conllu ...@@ -4,4 +4,5 @@ conllu
nltk nltk
rouge-score rouge-score
seqeval seqeval
tensorboard tensorboard
\ No newline at end of file evaluate >= 0.2.0
\ No newline at end of file
...@@ -31,10 +31,11 @@ from typing import Callable, Optional ...@@ -31,10 +31,11 @@ from typing import Callable, Optional
import datasets import datasets
import nltk # Here to have a nice missing dependency error message early on import nltk # Here to have a nice missing dependency error message early on
import numpy as np import numpy as np
from datasets import Dataset, load_dataset, load_metric from datasets import Dataset, load_dataset
from PIL import Image from PIL import Image
from tqdm import tqdm from tqdm import tqdm
import evaluate
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
import optax import optax
...@@ -811,7 +812,7 @@ def main(): ...@@ -811,7 +812,7 @@ def main():
yield batch yield batch
# Metric # Metric
metric = load_metric("rouge") metric = evaluate.load("rouge")
def postprocess_text(preds, labels): def postprocess_text(preds, labels):
preds = [pred.strip() for pred in preds] preds = [pred.strip() for pred in preds]
......
...@@ -32,9 +32,10 @@ from typing import Any, Callable, Dict, Optional, Tuple ...@@ -32,9 +32,10 @@ from typing import Any, Callable, Dict, Optional, Tuple
import datasets import datasets
import numpy as np import numpy as np
from datasets import load_dataset, load_metric from datasets import load_dataset
from tqdm import tqdm from tqdm import tqdm
import evaluate
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
import optax import optax
...@@ -776,7 +777,7 @@ def main(): ...@@ -776,7 +777,7 @@ def main():
references = [{"id": ex["id"], "answers": ex[answer_column_name]} for ex in examples] references = [{"id": ex["id"], "answers": ex[answer_column_name]} for ex in examples]
return EvalPrediction(predictions=formatted_predictions, label_ids=references) return EvalPrediction(predictions=formatted_predictions, label_ids=references)
metric = load_metric("squad_v2" if data_args.version_2_with_negative else "squad") metric = evaluate.load("squad_v2" if data_args.version_2_with_negative else "squad")
def compute_metrics(p: EvalPrediction): def compute_metrics(p: EvalPrediction):
return metric.compute(predictions=p.predictions, references=p.label_ids) return metric.compute(predictions=p.predictions, references=p.label_ids)
......
...@@ -33,9 +33,10 @@ from typing import Callable, Optional ...@@ -33,9 +33,10 @@ from typing import Callable, Optional
import datasets import datasets
import nltk # Here to have a nice missing dependency error message early on import nltk # Here to have a nice missing dependency error message early on
import numpy as np import numpy as np
from datasets import Dataset, load_dataset, load_metric from datasets import Dataset, load_dataset
from tqdm import tqdm from tqdm import tqdm
import evaluate
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
import optax import optax
...@@ -656,7 +657,7 @@ def main(): ...@@ -656,7 +657,7 @@ def main():
) )
# Metric # Metric
metric = load_metric("rouge") metric = evaluate.load("rouge")
def postprocess_text(preds, labels): def postprocess_text(preds, labels):
preds = [pred.strip() for pred in preds] preds = [pred.strip() for pred in preds]
......
...@@ -27,9 +27,10 @@ from typing import Any, Callable, Dict, Optional, Tuple ...@@ -27,9 +27,10 @@ from typing import Any, Callable, Dict, Optional, Tuple
import datasets import datasets
import numpy as np import numpy as np
from datasets import load_dataset, load_metric from datasets import load_dataset
from tqdm import tqdm from tqdm import tqdm
import evaluate
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
import optax import optax
...@@ -570,9 +571,9 @@ def main(): ...@@ -570,9 +571,9 @@ def main():
p_eval_step = jax.pmap(eval_step, axis_name="batch") p_eval_step = jax.pmap(eval_step, axis_name="batch")
if data_args.task_name is not None: if data_args.task_name is not None:
metric = load_metric("glue", data_args.task_name) metric = evaluate.load("glue", data_args.task_name)
else: else:
metric = load_metric("accuracy") metric = evaluate.load("accuracy")
logger.info(f"===== Starting training ({num_epochs} epochs) =====") logger.info(f"===== Starting training ({num_epochs} epochs) =====")
train_time = 0 train_time = 0
......
...@@ -29,9 +29,10 @@ from typing import Any, Callable, Dict, Optional, Tuple ...@@ -29,9 +29,10 @@ from typing import Any, Callable, Dict, Optional, Tuple
import datasets import datasets
import numpy as np import numpy as np
from datasets import ClassLabel, load_dataset, load_metric from datasets import ClassLabel, load_dataset
from tqdm import tqdm from tqdm import tqdm
import evaluate
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
import optax import optax
...@@ -646,7 +647,7 @@ def main(): ...@@ -646,7 +647,7 @@ def main():
p_eval_step = jax.pmap(eval_step, axis_name="batch") p_eval_step = jax.pmap(eval_step, axis_name="batch")
metric = load_metric("seqeval") metric = evaluate.load("seqeval")
def get_labels(y_pred, y_true): def get_labels(y_pred, y_true):
    # Transform predictions and references tensors to numpy arrays     # Transform predictions and references tensors to numpy arrays
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment