Commit c64bf9a9 authored by lintangsutawika

change all mentions of `greedy_until` to `generate_until`

parent 04ca5671
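For orientation: `generate_until` is the free-form generation request type, and each backend exposes a method of that name returning one string per request. Below is a minimal standalone sketch of that contract (not the harness's `LM` base class; the class name and the fake continuation are placeholders), assuming each request carries a context string plus generation kwargs with an optional `until` stop list.

```python
from typing import List, Tuple


class EchoLM:
    """Toy stand-in for a backend implementing the renamed method.

    Real backends subclass the harness LM class; this only illustrates the
    input/output contract: one generated string per request.
    """

    def generate_until(self, requests: List[Tuple[str, dict]]) -> List[str]:
        res = []
        for context, gen_kwargs in requests:
            # `context` would be the prompt fed to the model; this toy ignores it
            until = gen_kwargs.get("until") or []
            # stand-in for a real model call: emit a fixed continuation,
            # then truncate at the first stop sequence, as a backend would
            continuation = " 4\n\nQ: and the next question?"
            for stop in until:
                continuation = continuation.split(stop, 1)[0]
            res.append(continuation)
        return res


if __name__ == "__main__":
    lm = EchoLM()
    print(lm.generate_until([("Q: 2+2=\nA:", {"until": ["\n\n"]})]))
```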
......@@ -44,7 +44,7 @@ ALL_OUTPUT_TYPES = [
"loglikelihood",
"multiple_choice",
"loglikelihood_rolling",
"greedy_until",
"generate_until",
]
......@@ -80,7 +80,7 @@ class TaskConfig(dict):
num_fewshot: int = 0
# scoring options
metric_list: list = None
output_type: str = "greedy_until"
output_type: str = "generate_until"
generation_kwargs: dict = None
repeats: int = 1
filter_list: Union[str, list] = None
......@@ -97,11 +97,11 @@ class TaskConfig(dict):
self.dataset_path = inspect.getfile(import_module(self.dataset_path))
if self.generation_kwargs is not None:
if self.output_type != "greedy_until":
if self.output_type != "generate_until":
eval_logger.warning(
f"[{self.task}] passed `generation_kwargs`, but not using `output_type: greedy_until`!"
f"[{self.task}] passed `generation_kwargs`, but not using `output_type: generate_until`!"
)
assert self.output_type != "greedy_until"
assert self.output_type != "generate_until"
if "temperature" in self.generation_kwargs:
self.generation_kwargs["temperature"] = float(
......@@ -111,7 +111,7 @@ class TaskConfig(dict):
if "until" not in self.generation_kwargs:
self.generation_kwargs["until"] = [self.fewshot_delimiter]
else:
if self.output_type == "greedy_until":
if self.output_type == "generate_until":
# ensure that we greedily generate in absence of explicit arguments otherwise
self.generation_kwargs = {
"until": None
......@@ -958,7 +958,7 @@ class ConfigurableTask(Task):
)
return request_list
elif self.OUTPUT_TYPE == "greedy_until":
elif self.OUTPUT_TYPE == "generate_until":
arguments = (ctx, self.config.generation_kwargs)
return Instance(
......@@ -1070,7 +1070,7 @@ class ConfigurableTask(Task):
acc_mutual_info = 1.0 if np.argmax(lls_mutual_info) == gold else 0.0
result_dict["acc_mutual_info"] = acc_mutual_info
elif self.OUTPUT_TYPE == "greedy_until":
elif self.OUTPUT_TYPE == "generate_until":
gold = self.doc_to_target(doc)
result = results[0]
if self.config.doc_to_choice is not None:
......@@ -1134,7 +1134,7 @@ class ConfigurableTask(Task):
else:
raise ValueError(
f"Passed invalid output_type '{self.OUTPUT_TYPE}' ! Please use one of ",
"'loglikelihood', 'loglikelihood_rolling', 'greedy_until' or 'multiple_choice'",
"'loglikelihood', 'loglikelihood_rolling', 'generate_until' or 'multiple_choice'",
)
return result_dict
......
import os
import yaml
from lm_eval import utils
from lm_eval.tasks import register_configurable_task, check_prompt_config
from lm_eval.logger import eval_logger
from lm_eval.api.registry import (
TASK_REGISTRY,
GROUP_REGISTRY,
ALL_TASKS,
)
def include_benchmarks(task_dir: str) -> None:
for root, subdirs, file_list in os.walk(task_dir):
if (subdirs == [] or "__pycache__" in subdirs) and (len(file_list) > 0):
for f in file_list:
if f.endswith(".yaml"):
try:
benchmark_path = os.path.join(root, f)
with open(benchmark_path, "rb") as file:
yaml_config = yaml.full_load(file)
if "prompts" in yaml_config:
continue # Skip it
assert "group" in yaml_config
group = yaml_config["group"]
all_task_list = yaml_config["task"]
config_list = [
task for task in all_task_list if type(task) != str
]
task_list = [
task for task in all_task_list if type(task) == str
]
for task_config in config_list:
yaml_dir = os.path.dirname(benchmark_path)
task_config = utils.load_yaml_config(
yaml_config=task_config, yaml_dir=yaml_dir
)
if "use_prompt" in task_config:
if "yaml" in task_config["use_prompt"]:
task_config["use_prompt"] = os.path.join(
root, task_config["use_prompt"]
)
var_configs = check_prompt_config(
{
**task_config,
**{"group": group},
}
)
for config in var_configs:
register_configurable_task(config)
task_names = utils.pattern_match(task_list, ALL_TASKS)
for task in task_names:
if task in TASK_REGISTRY:
if group in GROUP_REGISTRY:
GROUP_REGISTRY[group].append(task)
else:
GROUP_REGISTRY[group] = [task]
ALL_TASKS.add(group)
except Exception as error:
eval_logger.warning(
"Failed to load benchmark in\n"
f" {benchmark_path}\n"
" Benchmark will not be added to registry\n"
f" Error: {error}"
)
task_dir = os.path.dirname(os.path.abspath(__file__)) + "/"
include_benchmarks(task_dir)
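For reference, each benchmark YAML that this walker registers needs a `group` name and a `task` list mixing already-registered task names (plain strings) with inline task configs (dicts); files containing a `prompts` key are skipped. A minimal sketch of that shape, using placeholder names and split the same way the loader splits it:

```python
import yaml

example = yaml.full_load(
    """
group: my_benchmark
task:
  - arc_easy                      # plain string: matched against ALL_TASKS by name
  - task: my_generation_subtask   # dict: loaded via utils.load_yaml_config and registered
    output_type: generate_until
    dataset_path: my_org/my_dataset
"""
)
assert "group" in example
inline_configs = [t for t in example["task"] if not isinstance(t, str)]
plain_names = [t for t in example["task"] if isinstance(t, str)]
```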
......@@ -138,7 +138,7 @@ please install anthropic via `pip install lm-eval[anthropic]` or `pip install -e
def _loglikelihood_tokens(self, requests, disable_tqdm: bool = False):
raise NotImplementedError("No support for logits.")
def greedy_until(self, requests) -> List[str]:
def generate_until(self, requests) -> List[str]:
if not requests:
return []
......@@ -164,7 +164,7 @@ please install anthropic via `pip install lm-eval[anthropic]` or `pip install -e
)
res.append(response)
self.cache_hook.add_partial("greedy_until", request, response)
self.cache_hook.add_partial("generate_until", request, response)
except anthropic.APIConnectionError as e: # type: ignore # noqa: F821
eval_logger.critical(f"Server unreachable: {e.__cause__}")
break
......@@ -179,7 +179,7 @@ please install anthropic via `pip install lm-eval[anthropic]` or `pip install -e
raise NotImplementedError()
def _model_generate(self, context, max_length, eos_token_id):
# Isn't used because we override greedy_until
# Isn't used because we override generate_until
raise NotImplementedError()
def loglikelihood(self, requests):
......
......@@ -20,7 +20,7 @@ class DummyLM(LM):
return res
def greedy_until(self, requests):
def generate_until(self, requests):
res = []
for ctx, _ in requests:
......
......@@ -813,7 +813,7 @@ class HFLM(LM):
return re_ord.get_original(res)
def greedy_until(self, requests):
def generate_until(self, requests):
res = defaultdict(list)
re_ords = {}
......@@ -930,7 +930,7 @@ class HFLM(LM):
res[key].append(s)
self.cache_hook.add_partial(
"greedy_until", (context, gen_kwargs), s
"generate_until", (context, gen_kwargs), s
)
pbar.update(1)
# reorder this group of results back to original unsorted form
......
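One practical consequence of the rename: partial results are cached under the request-type string passed to `cache_hook.add_partial`, so entries written by earlier runs as `greedy_until` will not be found under `generate_until`. A toy illustration of that keying idea (not the harness's actual cache, which may hash and store results differently):

```python
import hashlib
import json


class ToyPartialCache:
    """Dict-backed stand-in showing why the request-type string is part of the key."""

    def __init__(self):
        self._store = {}

    @staticmethod
    def _key(req_type, args):
        payload = json.dumps([req_type, args], sort_keys=True, default=str)
        return hashlib.sha256(payload.encode()).hexdigest()

    def add_partial(self, req_type, args, result):
        self._store[self._key(req_type, args)] = result

    def get(self, req_type, args):
        return self._store.get(self._key(req_type, args))


cache = ToyPartialCache()
cache.add_partial("greedy_until", ("prompt", {"until": ["\n"]}), "cached output")
# same arguments, new request-type name: the old entry is not reused
assert cache.get("generate_until", ("prompt", {"until": ["\n"]})) is None
```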
......@@ -203,7 +203,7 @@ class OpenaiCompletionsLM(LM):
self.cache_hook.add_partial("loglikelihood", cache_key, answer)
return re_ord.get_original(res)
def greedy_until(self, requests) -> List[str]:
def generate_until(self, requests) -> List[str]:
if not requests:
return []
res = []
......@@ -260,7 +260,7 @@ class OpenaiCompletionsLM(LM):
# partial caching
self.cache_hook.add_partial(
"greedy_until", (context, {"until": until_}), s
"generate_until", (context, {"until": until_}), s
)
res.append(s)
......@@ -271,7 +271,7 @@ class OpenaiCompletionsLM(LM):
raise NotImplementedError()
def _model_generate(self, context, max_length, eos_token_id):
# Isn't used because we override greedy_until
# Isn't used because we override generate_until
raise NotImplementedError()
def loglikelihood_rolling(self, requests) -> List[float]:
......
......@@ -58,7 +58,7 @@ class TextSynthLM(LM):
@property
def eot_token_id(self):
# Isn't used because we override loglikelihood, loglikelihood_rolling and greedy_until
# Isn't used because we override loglikelihood, loglikelihood_rolling and generate_until
raise NotImplementedError()
@property
......@@ -72,20 +72,20 @@ class TextSynthLM(LM):
@property
def batch_size(self):
# Isn't used because we override loglikelihood, loglikelihood_rolling and greedy_until
# Isn't used because we override loglikelihood, loglikelihood_rolling and generate_until
raise NotImplementedError()
@property
def device(self):
# Isn't used because we override loglikelihood, loglikelihood_rolling and greedy_until
# Isn't used because we override loglikelihood, loglikelihood_rolling and generate_until
raise NotImplementedError()
def tok_encode(self, string: str):
# Isn't used because we override loglikelihood, loglikelihood_rolling and greedy_until
# Isn't used because we override loglikelihood, loglikelihood_rolling and generate_until
raise NotImplementedError()
def tok_decode(self, tokens):
# Isn't used because we override loglikelihood, loglikelihood_rolling and greedy_until
# Isn't used because we override loglikelihood, loglikelihood_rolling and generate_until
raise NotImplementedError()
def loglikelihood(self, requests):
......@@ -122,7 +122,7 @@ class TextSynthLM(LM):
"input tokenization support from TextSynth."
)
def greedy_until(self, requests):
def generate_until(self, requests):
if not requests:
return []
......@@ -146,7 +146,7 @@ class TextSynthLM(LM):
s = resp["text"]
res.append(s)
self.cache_hook.add_partial("greedy_until", (inp, request_args), s)
self.cache_hook.add_partial("generate_until", (inp, request_args), s)
else:
logger.error(
f"The following response does not contain generated `text`. "
......@@ -160,5 +160,5 @@ class TextSynthLM(LM):
raise NotImplementedError()
def _model_generate(self, context, max_length, eos_token_id):
# Isn't used because we override greedy_until
# Isn't used because we override generate_until
raise NotImplementedError()
......@@ -98,7 +98,7 @@ def check_prompt_config(
]
)
},
**{"output_type": "greedy_until"},
**{"output_type": "generate_until"},
}
)
else:
......
task: babi
dataset_path: Muennighoff/babi
dataset_name: null
output_type: greedy_until
output_type: generate_until
training_split: train
validation_split: valid
test_split: test
......
group: bbh_flan_cot_fewshot
dataset_path: lukaemon/bbh
output_type: greedy_until
output_type: generate_until
test_split: test
doc_to_target: "{{target}}"
metric_list:
......
group: bbh_flan_cot_zeroshot
dataset_path: lukaemon/bbh
output_type: greedy_until
output_type: generate_until
test_split: test
doc_to_target: "{{target}}"
metric_list:
......
group: bbh_flan_fewshot
dataset_path: lukaemon/bbh
output_type: greedy_until
output_type: generate_until
test_split: test
doc_to_target: "{{target}}"
metric_list:
......
group: bbh_flan_zeroshot
dataset_path: lukaemon/bbh
output_type: greedy_until
output_type: generate_until
test_split: test
doc_to_target: "{{target}}"
metric_list:
......
group: flan-cot
output_type: greedy_until
output_type: generate_until
validation_split: validation
doc_to_target: "{{answer}}"
metric_list:
......
output_type: greedy_until
output_type: generate_until
validation_split: validation
metric_list:
- metric: exact_match
......
......@@ -6,7 +6,7 @@ task:
use_prompt: promptsource:*
training_split: train
validation_split: validation
output_type: greedy_until
output_type: generate_until
metric_list:
- metric: exact_match
aggregation: mean
......@@ -19,7 +19,7 @@ task:
use_prompt: promptsource:*
training_split: train
validation_split: validation
output_type: greedy_until
output_type: generate_until
metric_list:
- metric: exact_match
aggregation: mean
......@@ -32,7 +32,7 @@ task:
use_prompt: promptsource:*
training_split: train
validation_split: validation
output_type: greedy_until
output_type: generate_until
metric_list:
- metric: exact_match
aggregation: mean
......@@ -44,7 +44,7 @@ task:
use_prompt: promptsource:*
training_split: train
validation_split: validation
output_type: greedy_until
output_type: generate_until
metric_list:
- metric: exact_match
aggregation: mean
......@@ -56,7 +56,7 @@ task:
use_prompt: promptsource:*
training_split: train_r1
validation_split: dev_r1
output_type: greedy_until
output_type: generate_until
metric_list:
- metric: exact_match
aggregation: mean
......@@ -68,7 +68,7 @@ task:
use_prompt: promptsource:*
training_split: train_r2
validation_split: dev_r2
output_type: greedy_until
output_type: generate_until
metric_list:
- metric: exact_match
aggregation: mean
......@@ -80,7 +80,7 @@ task:
use_prompt: promptsource:*
training_split: train_r3
validation_split: dev_r3
output_type: greedy_until
output_type: generate_until
metric_list:
- metric: exact_match
aggregation: mean
......@@ -93,7 +93,7 @@ task:
use_prompt: promptsource:*
training_split: train
validation_split: validation
output_type: greedy_until
output_type: generate_until
metric_list:
- metric: exact_match
aggregation: mean
......@@ -105,7 +105,7 @@ task:
use_prompt: promptsource:*
training_split: train
validation_split: validation
output_type: greedy_until
output_type: generate_until
metric_list:
- metric: exact_match
aggregation: mean
......@@ -118,7 +118,7 @@ task:
use_prompt: promptsource:*
training_split: train
validation_split: validation
output_type: greedy_until
output_type: generate_until
metric_list:
- metric: exact_match
aggregation: mean
......
......@@ -175,8 +175,8 @@ all_subtasks = [
def main() -> None:
for path, task_type in zip(
["multiple_choice", "greedy_until"],
["multiple_choice_template_yaml", "greedy_until_template_yaml"],
["multiple_choice", "generate_until"],
["multiple_choice_template_yaml", "generate_until_template_yaml"],
):
os.makedirs(path, exist_ok=True)
for task in all_subtasks:
......
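Each generated per-subtask file below simply includes the shared template for its output type. A rough sketch, not the actual `utils.py`, of how one such file could be written from the `(path, task_type)` pairs iterated in `main()` (the helper name is hypothetical):

```python
import os


def write_subtask_yaml(path: str, template: str, task: str) -> None:
    """Write one include-style YAML like the generated files shown below (approximate)."""
    os.makedirs(path, exist_ok=True)
    suffix = "generate_until" if path == "generate_until" else "multiple_choice"
    with open(os.path.join(path, f"{task}.yaml"), "w") as f:
        f.write("# Generated by utils.py\n")
        f.write(f"dataset_name: {task}_zero_shot\n")
        f.write(f"include: ../{template}\n")
        f.write(f"task: bigbench_{task}_{suffix}\n")


# e.g. write_subtask_yaml("generate_until", "generate_until_template_yaml", "anachronisms")
```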
# Generated by utils.py
dataset_name: abstract_narrative_understanding_zero_shot
include: ../greedy_until_template_yaml
task: bigbench_abstract_narrative_understanding_greedy_until
include: ../generate_until_template_yaml
task: bigbench_abstract_narrative_understanding_generate_until
# Generated by utils.py
dataset_name: anachronisms_zero_shot
include: ../greedy_until_template_yaml
task: bigbench_anachronisms_greedy_until
include: ../generate_until_template_yaml
task: bigbench_anachronisms_generate_until