Commit cf8a257c authored by Baber's avatar Baber
Browse files

fix metrics

parent 76e517d1
...@@ -322,10 +322,11 @@ class Task(abc.ABC): ...@@ -322,10 +322,11 @@ class Task(abc.ABC):
elif self.has_validation_docs(): elif self.has_validation_docs():
return self.validation_docs() return self.validation_docs()
else: else:
eval_logger.warning( if self.config.get("num_fewshot", 0) > 0:
f"[Task: {self.config.task}] has_training_docs and has_validation_docs are False" eval_logger.warning(
", using test_docs as fewshot_docs but this is not recommended." f"[Task: {self.config.task}] has_training_docs and has_validation_docs are False"
) ", using test_docs as fewshot_docs but this is not recommended."
)
return self.test_docs() return self.test_docs()
def _process_doc(self, doc: dict) -> dict: def _process_doc(self, doc: dict) -> dict:
......
...@@ -134,7 +134,10 @@ def rouge_zh_score(predictions: list[str], references: list[str], **kwargs) -> f ...@@ -134,7 +134,10 @@ def rouge_zh_score(predictions: list[str], references: list[str], **kwargs) -> f
def f1_score(predictions: list[str], references: list[str], **kwargs): def f1_score(predictions: list[str], references: list[str], **kwargs):
prediction, ground_truth = predictions[0], references[0] try:
prediction, ground_truth = predictions[0], references[0]
except:
return 0.0
common = Counter(prediction) & Counter(ground_truth) common = Counter(prediction) & Counter(ground_truth)
num_same = sum(common.values()) num_same = sum(common.values())
if num_same == 0: if num_same == 0:
...@@ -152,7 +155,11 @@ def qa_f1_score(predictions: list[str], references: list[str], **kwargs) -> floa ...@@ -152,7 +155,11 @@ def qa_f1_score(predictions: list[str], references: list[str], **kwargs) -> floa
prediction_tokens = normalized_prediction.split() prediction_tokens = normalized_prediction.split()
ground_truth_tokens = normalized_ground_truth.split() ground_truth_tokens = normalized_ground_truth.split()
return f1_score(prediction_tokens, ground_truth_tokens) try:
res = f1_score(prediction_tokens, ground_truth_tokens)
except:
return 0.0
return res
def qa_f1_zh_score(predictions: list[str], references: list[str], **kwargs) -> float: def qa_f1_zh_score(predictions: list[str], references: list[str], **kwargs) -> float:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment