Commit f107ae29 authored by lintangsutawika's avatar lintangsutawika
Browse files

ported changes here

parent 2a9da9fb
...@@ -389,25 +389,30 @@ class ConfigurableTask(Task): ...@@ -389,25 +389,30 @@ class ConfigurableTask(Task):
self._metric_list = {} self._metric_list = {}
self._aggregation_list = {} self._aggregation_list = {}
self._higher_is_better = {} self._higher_is_better = {}
for (metric_name, aggregation, higher_is_better) in self._config.metric_list: self._metric_kwargs = {}
for metric_config in self._config.metric_list:
self._aggregation_list[metric_name] = get_aggregation(aggregation) metric_name = metric_config['name']
self._higher_is_better[metric_name] = higher_is_better aggregation = metric_config['aggregation']
higher_is_better = metric_config['higher_is_better']
kwargs = {key: metric_config[key] for key in metric_config if key not in ['name', 'aggregation', 'higher_is_better']}
self._metric_list[metric_name] = get_metric(metric_name) self._aggregation_list[metric_name] = AGGREGATION_REGISTRY[aggregation]
self._higher_is_better[metric_name] = higher_is_better
if metric_name in METRIC_REGISTRY.keys():
self._metric_list[metric_name] = METRIC_REGISTRY[metric_name]
else:
try:
metric_object = evaluate.load(metric_name)
self._metric_list[metric_name] = metric_object
self._metric_kwargs[metric_name] = kwargs
# if metric_name in METRIC_REGISTRY.keys(): except Exception as ex:
# self._metric_list[metric_name] = METRIC_REGISTRY[metric_name] raise Warning(
# else: "{} not found in the evaluate library!".format(metric_name),
# try: "Please check https://huggingface.co/evaluate-metric",
# metric_object = evaluate.load(metric_name) )
# self._metric_list[metric_name] = metric_object
# except Exception as ex:
# raise Warning(
# "{} not found in the evaluate library!".format(metric_name),
# "Please check https://huggingface.co/evaluate-metric",
# )
self.download(data_dir, cache_dir, download_mode) self.download(data_dir, cache_dir, download_mode)
self._training_docs = None self._training_docs = None
...@@ -468,8 +473,19 @@ class ConfigurableTask(Task): ...@@ -468,8 +473,19 @@ class ConfigurableTask(Task):
def construct_requests(self, doc, ctx, **kwargs): def construct_requests(self, doc, ctx, **kwargs):
if self.OUTPUT_TYPE == "greedy_until": if self.output_type == "loglikelihood":
return Instance(request_type=self.OUTPUT_TYPE, doc=doc, arguments=(ctx, "\n\n"), id_=0, **kwargs) arguments=(ctx, self.doc_to_target(doc))
elif self.output_type == "loglikelihood_rolling":
arguments=(self.doc_to_target(doc),)
elif self.output_type == "greedy_until":
arguments=(ctx, "\n\n")
return Instance(
request_type=self.output_type,
doc=doc,
arguments=arguments,
**kwargs
)
def process_results(self, doc, results): def process_results(self, doc, results):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment