"...composable_kernel-1.git" did not exist on "66edb2590d47a4dd4208e10998b19d0318b1cd71"
Unverified Commit d17a5b94 authored by Haodong Duan's avatar Haodong Duan Committed by GitHub
Browse files

[Refine] Refine PR #122 (#123)

* update

* update
parent 191a3f6f
...@@ -21,8 +21,7 @@ class HuggingfaceEvaluator(BaseEvaluator): ...@@ -21,8 +21,7 @@ class HuggingfaceEvaluator(BaseEvaluator):
def __init__(self, metric: str, seed: int = 0) -> None: def __init__(self, metric: str, seed: int = 0) -> None:
self.metric = metric self.metric = metric
random.seed(seed) self.seed = seed
np.random.seed(seed)
super().__init__() super().__init__()
def _preprocess(self, predictions: List, references: List) -> dict: def _preprocess(self, predictions: List, references: List) -> dict:
...@@ -61,6 +60,11 @@ class HuggingfaceEvaluator(BaseEvaluator): ...@@ -61,6 +60,11 @@ class HuggingfaceEvaluator(BaseEvaluator):
Returns: Returns:
dict: calculated scores. dict: calculated scores.
""" """
random_state = random.getstate()
np_random_state = np.random.get_state()
random.seed(self.seed)
np.random.seed(self.seed)
if len(predictions) != len(references): if len(predictions) != len(references):
return { return {
'error': 'error':
...@@ -70,7 +74,10 @@ class HuggingfaceEvaluator(BaseEvaluator): ...@@ -70,7 +74,10 @@ class HuggingfaceEvaluator(BaseEvaluator):
} }
metric = evaluate.load(self.metric) metric = evaluate.load(self.metric)
scores = metric.compute(**self._preprocess(predictions, references)) scores = metric.compute(**self._preprocess(predictions, references))
return self._postprocess(scores) result = self._postprocess(scores)
random.setstate(random_state)
np.random.set_state(np_random_state)
return result
@ICL_EVALUATORS.register_module() @ICL_EVALUATORS.register_module()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment