Unverified Commit d17a5b94 authored by Haodong Duan's avatar Haodong Duan Committed by GitHub
Browse files

[Refine] Refine PR #122 (#123)

* update

* update
parent 191a3f6f
......@@ -21,8 +21,7 @@ class HuggingfaceEvaluator(BaseEvaluator):
def __init__(self, metric: str, seed: int = 0) -> None:
self.metric = metric
random.seed(seed)
np.random.seed(seed)
self.seed = seed
super().__init__()
def _preprocess(self, predictions: List, references: List) -> dict:
......@@ -61,6 +60,11 @@ class HuggingfaceEvaluator(BaseEvaluator):
Returns:
dict: calculated scores.
"""
random_state = random.getstate()
np_random_state = np.random.get_state()
random.seed(self.seed)
np.random.seed(self.seed)
if len(predictions) != len(references):
return {
'error':
......@@ -70,7 +74,10 @@ class HuggingfaceEvaluator(BaseEvaluator):
}
metric = evaluate.load(self.metric)
scores = metric.compute(**self._preprocess(predictions, references))
return self._postprocess(scores)
result = self._postprocess(scores)
random.setstate(random_state)
np.random.set_state(np_random_state)
return result
@ICL_EVALUATORS.register_module()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment