Commit f107ae29 authored by lintangsutawika's avatar lintangsutawika
Browse files

ported changes here

parent 2a9da9fb
......@@ -389,25 +389,30 @@ class ConfigurableTask(Task):
self._metric_list = {}
self._aggregation_list = {}
self._higher_is_better = {}
for (metric_name, aggregation, higher_is_better) in self._config.metric_list:
self._metric_kwargs = {}
for metric_config in self._config.metric_list:
self._aggregation_list[metric_name] = get_aggregation(aggregation)
self._higher_is_better[metric_name] = higher_is_better
metric_name = metric_config['name']
aggregation = metric_config['aggregation']
higher_is_better = metric_config['higher_is_better']
kwargs = {key: metric_config[key] for key in metric_config if key not in ['name', 'aggregation', 'higher_is_better']}
self._metric_list[metric_name] = get_metric(metric_name)
self._aggregation_list[metric_name] = AGGREGATION_REGISTRY[aggregation]
self._higher_is_better[metric_name] = higher_is_better
if metric_name in METRIC_REGISTRY.keys():
self._metric_list[metric_name] = METRIC_REGISTRY[metric_name]
else:
try:
metric_object = evaluate.load(metric_name)
self._metric_list[metric_name] = metric_object
self._metric_kwargs[metric_name] = kwargs
# if metric_name in METRIC_REGISTRY.keys():
# self._metric_list[metric_name] = METRIC_REGISTRY[metric_name]
# else:
# try:
# metric_object = evaluate.load(metric_name)
# self._metric_list[metric_name] = metric_object
# except Exception as ex:
# raise Warning(
# "{} not found in the evaluate library!".format(metric_name),
# "Please check https://huggingface.co/evaluate-metric",
# )
except Exception as ex:
raise Warning(
"{} not found in the evaluate library!".format(metric_name),
"Please check https://huggingface.co/evaluate-metric",
)
self.download(data_dir, cache_dir, download_mode)
self._training_docs = None
......@@ -468,8 +473,19 @@ class ConfigurableTask(Task):
def construct_requests(self, doc, ctx, **kwargs):
if self.OUTPUT_TYPE == "greedy_until":
return Instance(request_type=self.OUTPUT_TYPE, doc=doc, arguments=(ctx, "\n\n"), id_=0, **kwargs)
if self.output_type == "loglikelihood":
arguments=(ctx, self.doc_to_target(doc))
elif self.output_type == "loglikelihood_rolling":
arguments=(self.doc_to_target(doc),)
elif self.output_type == "greedy_until":
arguments=(ctx, "\n\n")
return Instance(
request_type=self.output_type,
doc=doc,
arguments=arguments,
**kwargs
)
def process_results(self, doc, results):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment