Commit a68c3fa4 authored by haileyschoelkopf's avatar haileyschoelkopf
Browse files

fix mutual info normalization

parent 6449ab1a
...@@ -52,7 +52,6 @@ class TaskConfig(dict): ...@@ -52,7 +52,6 @@ class TaskConfig(dict):
task: str = None task: str = None
group: Union[str, list] = None group: Union[str, list] = None
reference: str = None
dataset_path: str = None dataset_path: str = None
dataset_name: str = None dataset_name: str = None
...@@ -67,6 +66,8 @@ class TaskConfig(dict): ...@@ -67,6 +66,8 @@ class TaskConfig(dict):
doc_to_target: Union[Callable, str] = None doc_to_target: Union[Callable, str] = None
use_prompt: str = None use_prompt: str = None
description: str = "" description: str = ""
target_delimiter: str = " "
fewshot_delimiter: str = "\n\n"
num_fewshot: int = 0 num_fewshot: int = 0
batch_size: int = 1 batch_size: int = 1
...@@ -76,8 +77,6 @@ class TaskConfig(dict): ...@@ -76,8 +77,6 @@ class TaskConfig(dict):
gold_alias: Union[Callable, str] = None gold_alias: Union[Callable, str] = None
output_type: str = "greedy_until" output_type: str = "greedy_until"
generation_kwargs: dict = None generation_kwargs: dict = None
target_delimiter: str = " "
fewshot_delimiter: str = "\n\n"
filter_list: Union[str, list] = None filter_list: Union[str, list] = None
should_decontaminate: bool = False should_decontaminate: bool = False
doc_to_decontamination_query: str = None doc_to_decontamination_query: str = None
...@@ -343,7 +342,7 @@ class Task(abc.ABC): ...@@ -343,7 +342,7 @@ class Task(abc.ABC):
fewshot_ctx = self.fewshot_context( fewshot_ctx = self.fewshot_context(
doc, self._config.num_fewshot, rnd=random.Random() doc, self._config.num_fewshot, rnd=random.Random()
) )
# TODO: hardcoded for now: # of runs on each input to be 2. # TODO: we should override this if doing greedy gen so users don't waste time+compute # TODO: we should override this if doing greedy gen so users don't waste time+compute
inst = self.construct_requests( inst = self.construct_requests(
doc=doc, doc=doc,
ctx=fewshot_ctx, ctx=fewshot_ctx,
...@@ -773,7 +772,7 @@ class ConfigurableTask(Task): ...@@ -773,7 +772,7 @@ class ConfigurableTask(Task):
Instance( Instance(
request_type="loglikelihood", request_type="loglikelihood",
doc=doc, doc=doc,
arguments=("", "{}".format(choice)), arguments=("", " {}".format(choice)),
idx=i, idx=i,
**kwargs, **kwargs,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment