Unverified Commit 1f843991 authored by atturaioe's avatar atturaioe Committed by GitHub
Browse files

Migrate metric to Evaluate in Pytorch examples (#18369)

* Migrate metric to Evaluate in pytorch examples

* Remove unused imports
parent 25ec12ea
...@@ -26,8 +26,9 @@ from typing import Optional ...@@ -26,8 +26,9 @@ from typing import Optional
import datasets import datasets
import numpy as np import numpy as np
from datasets import load_dataset, load_metric from datasets import load_dataset
import evaluate
import transformers import transformers
from transformers import ( from transformers import (
AutoConfig, AutoConfig,
...@@ -349,7 +350,7 @@ def main(): ...@@ -349,7 +350,7 @@ def main():
) )
# Get the metric function # Get the metric function
metric = load_metric("xnli") metric = evaluate.load("xnli")
# You can define your custom compute_metrics function. It takes an `EvalPrediction` object (a namedtuple with a # You can define your custom compute_metrics function. It takes an `EvalPrediction` object (a namedtuple with a
# predictions and label_ids field) and has to return a dictionary string to float. # predictions and label_ids field) and has to return a dictionary string to float.
......
...@@ -27,8 +27,9 @@ from typing import Optional ...@@ -27,8 +27,9 @@ from typing import Optional
import datasets import datasets
import numpy as np import numpy as np
from datasets import ClassLabel, load_dataset, load_metric from datasets import ClassLabel, load_dataset
import evaluate
import transformers import transformers
from transformers import ( from transformers import (
AutoConfig, AutoConfig,
...@@ -504,7 +505,7 @@ def main(): ...@@ -504,7 +505,7 @@ def main():
data_collator = DataCollatorForTokenClassification(tokenizer, pad_to_multiple_of=8 if training_args.fp16 else None) data_collator = DataCollatorForTokenClassification(tokenizer, pad_to_multiple_of=8 if training_args.fp16 else None)
# Metrics # Metrics
metric = load_metric("seqeval") metric = evaluate.load("seqeval")
def compute_metrics(p): def compute_metrics(p):
predictions, labels = p predictions, labels = p
......
...@@ -28,10 +28,11 @@ from pathlib import Path ...@@ -28,10 +28,11 @@ from pathlib import Path
import datasets import datasets
import torch import torch
from datasets import ClassLabel, load_dataset, load_metric from datasets import ClassLabel, load_dataset
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from tqdm.auto import tqdm from tqdm.auto import tqdm
import evaluate
import transformers import transformers
from accelerate import Accelerator from accelerate import Accelerator
from accelerate.logging import get_logger from accelerate.logging import get_logger
...@@ -580,7 +581,7 @@ def main(): ...@@ -580,7 +581,7 @@ def main():
accelerator.init_trackers("ner_no_trainer", experiment_config) accelerator.init_trackers("ner_no_trainer", experiment_config)
# Metrics # Metrics
metric = load_metric("seqeval") metric = evaluate.load("seqeval")
def get_labels(predictions, references): def get_labels(predictions, references):
# Transform predictions and references tensos to numpy arrays # Transform predictions and references tensos to numpy arrays
......
...@@ -26,8 +26,9 @@ from typing import Optional ...@@ -26,8 +26,9 @@ from typing import Optional
import datasets import datasets
import numpy as np import numpy as np
from datasets import load_dataset, load_metric from datasets import load_dataset
import evaluate
import transformers import transformers
from transformers import ( from transformers import (
AutoConfig, AutoConfig,
...@@ -522,7 +523,7 @@ def main(): ...@@ -522,7 +523,7 @@ def main():
) )
# Metric # Metric
metric = load_metric("sacrebleu") metric = evaluate.load("sacrebleu")
def postprocess_text(preds, labels): def postprocess_text(preds, labels):
preds = [pred.strip() for pred in preds] preds = [pred.strip() for pred in preds]
......
...@@ -29,10 +29,11 @@ from pathlib import Path ...@@ -29,10 +29,11 @@ from pathlib import Path
import datasets import datasets
import numpy as np import numpy as np
import torch import torch
from datasets import load_dataset, load_metric from datasets import load_dataset
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from tqdm.auto import tqdm from tqdm.auto import tqdm
import evaluate
import transformers import transformers
from accelerate import Accelerator from accelerate import Accelerator
from accelerate.logging import get_logger from accelerate.logging import get_logger
...@@ -562,7 +563,7 @@ def main(): ...@@ -562,7 +563,7 @@ def main():
experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
accelerator.init_trackers("translation_no_trainer", experiment_config) accelerator.init_trackers("translation_no_trainer", experiment_config)
metric = load_metric("sacrebleu") metric = evaluate.load("sacrebleu")
def postprocess_text(preds, labels): def postprocess_text(preds, labels):
preds = [pred.strip() for pred in preds] preds = [pred.strip() for pred in preds]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment