Commit c64bf9a9 authored by lintangsutawika

change all mentions of `greedy_until` to `generate_until`

parent 04ca5671
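
The rename is mechanical: the `generate_until` output type and the matching `LM.generate_until` method replace every `greedy_until` occurrence in the registry constants, task configs, model backends, and task YAMLs below. A minimal caller-side sketch of the renamed interface, assuming the `(context, generation_kwargs)` request tuples used throughout this diff (`EchoLM` and the sample request are illustrative, not part of the codebase):

```python
from typing import List, Tuple


class EchoLM:
    """Stand-in model backend; only the method name matters here."""

    def generate_until(self, requests: List[Tuple[str, dict]]) -> List[str]:
        # formerly `greedy_until`: return one generated string per (context, gen_kwargs) pair
        return ["" for _context, _gen_kwargs in requests]


lm = EchoLM()
# before this commit: lm.greedy_until(...)
print(lm.generate_until([("Question: 2+2=?\nAnswer:", {"until": ["\n"]})]))
```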
@@ -44,7 +44,7 @@ ALL_OUTPUT_TYPES = [
     "loglikelihood",
     "multiple_choice",
     "loglikelihood_rolling",
-    "greedy_until",
+    "generate_until",
 ]
@@ -80,7 +80,7 @@ class TaskConfig(dict):
     num_fewshot: int = 0
     # scoring options
     metric_list: list = None
-    output_type: str = "greedy_until"
+    output_type: str = "generate_until"
     generation_kwargs: dict = None
     repeats: int = 1
     filter_list: Union[str, list] = None
@@ -97,11 +97,11 @@ class TaskConfig(dict):
             self.dataset_path = inspect.getfile(import_module(self.dataset_path))
         if self.generation_kwargs is not None:
-            if self.output_type != "greedy_until":
+            if self.output_type != "generate_until":
                 eval_logger.warning(
-                    f"[{self.task}] passed `generation_kwargs`, but not using `output_type: greedy_until`!"
+                    f"[{self.task}] passed `generation_kwargs`, but not using `output_type: generate_until`!"
                 )
-                assert self.output_type != "greedy_until"
+                assert self.output_type != "generate_until"
             if "temperature" in self.generation_kwargs:
                 self.generation_kwargs["temperature"] = float(
@@ -111,7 +111,7 @@ class TaskConfig(dict):
             if "until" not in self.generation_kwargs:
                 self.generation_kwargs["until"] = [self.fewshot_delimiter]
         else:
-            if self.output_type == "greedy_until":
+            if self.output_type == "generate_until":
                 # ensure that we greedily generate in absence of explicit arguments otherwise
                 self.generation_kwargs = {
                     "until": None
@@ -958,7 +958,7 @@ class ConfigurableTask(Task):
             )
             return request_list
-        elif self.OUTPUT_TYPE == "greedy_until":
+        elif self.OUTPUT_TYPE == "generate_until":
            arguments = (ctx, self.config.generation_kwargs)
            return Instance(
@@ -1070,7 +1070,7 @@ class ConfigurableTask(Task):
             acc_mutual_info = 1.0 if np.argmax(lls_mutual_info) == gold else 0.0
             result_dict["acc_mutual_info"] = acc_mutual_info
-        elif self.OUTPUT_TYPE == "greedy_until":
+        elif self.OUTPUT_TYPE == "generate_until":
             gold = self.doc_to_target(doc)
             result = results[0]
             if self.config.doc_to_choice is not None:
@@ -1134,7 +1134,7 @@ class ConfigurableTask(Task):
         else:
             raise ValueError(
                 f"Passed invalid output_type '{self.OUTPUT_TYPE}' ! Please use one of ",
-                "'loglikelihood', 'loglikelihood_rolling', 'greedy_until' or 'multiple_choice'",
+                "'loglikelihood', 'loglikelihood_rolling', 'generate_until' or 'multiple_choice'",
             )
         return result_dict
...
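
The `TaskConfig` hunks above determine how `generation_kwargs` interact with the new output type: passing them alongside a non-generation `output_type` triggers a warning, a missing `until` falls back to the few-shot delimiter, and `generate_until` tasks with no kwargs at all get a greedy default. A standalone sketch of that behaviour, assuming a hypothetical helper `resolve_generation_kwargs` (not a harness function):

```python
import logging

eval_logger = logging.getLogger("lm-eval")


def resolve_generation_kwargs(output_type, generation_kwargs, fewshot_delimiter="\n\n"):
    # mirrors the TaskConfig.__post_init__ logic shown in the hunks above
    if generation_kwargs is not None:
        if output_type != "generate_until":
            eval_logger.warning(
                "passed `generation_kwargs`, but not using `output_type: generate_until`!"
            )
        if "temperature" in generation_kwargs:
            generation_kwargs["temperature"] = float(generation_kwargs["temperature"])
        if "until" not in generation_kwargs:
            generation_kwargs["until"] = [fewshot_delimiter]
        return generation_kwargs
    if output_type == "generate_until":
        # greedily generate in the absence of explicit arguments
        return {"until": None}
    return None


print(resolve_generation_kwargs("generate_until", None))                  # {'until': None}
print(resolve_generation_kwargs("generate_until", {"temperature": "0"}))  # cast + default until
```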
import os
import yaml
from lm_eval import utils
from lm_eval.tasks import register_configurable_task, check_prompt_config
from lm_eval.logger import eval_logger
from lm_eval.api.registry import (
TASK_REGISTRY,
GROUP_REGISTRY,
ALL_TASKS,
)
def include_benchmarks(task_dir: str) -> None:
for root, subdirs, file_list in os.walk(task_dir):
if (subdirs == [] or "__pycache__" in subdirs) and (len(file_list) > 0):
for f in file_list:
if f.endswith(".yaml"):
try:
benchmark_path = os.path.join(root, f)
with open(benchmark_path, "rb") as file:
yaml_config = yaml.full_load(file)
if "prompts" in yaml_config:
continue # Skip it
assert "group" in yaml_config
group = yaml_config["group"]
all_task_list = yaml_config["task"]
config_list = [
task for task in all_task_list if type(task) != str
]
task_list = [
task for task in all_task_list if type(task) == str
]
for task_config in config_list:
yaml_dir = os.path.dirname(benchmark_path)
task_config = utils.load_yaml_config(
yaml_config=task_config, yaml_dir=yaml_dir
)
if "use_prompt" in task_config:
if "yaml" in task_config["use_prompt"]:
task_config["use_prompt"] = os.path.join(
root, task_config["use_prompt"]
)
var_configs = check_prompt_config(
{
**task_config,
**{"group": group},
}
)
for config in var_configs:
register_configurable_task(config)
task_names = utils.pattern_match(task_list, ALL_TASKS)
for task in task_names:
if task in TASK_REGISTRY:
if group in GROUP_REGISTRY:
GROUP_REGISTRY[group].append(task)
else:
GROUP_REGISTRY[group] = [task]
ALL_TASKS.add(group)
except Exception as error:
eval_logger.warning(
"Failed to load benchmark in\n"
f" {benchmark_path}\n"
" Benchmark will not be added to registry\n"
f" Error: {error}"
)
task_dir = os.path.dirname(os.path.abspath(__file__)) + "/"
include_benchmarks(task_dir)
@@ -138,7 +138,7 @@ please install anthropic via `pip install lm-eval[anthropic]` or `pip install -e
     def _loglikelihood_tokens(self, requests, disable_tqdm: bool = False):
         raise NotImplementedError("No support for logits.")
-    def greedy_until(self, requests) -> List[str]:
+    def generate_until(self, requests) -> List[str]:
         if not requests:
             return []
@@ -164,7 +164,7 @@ please install anthropic via `pip install lm-eval[anthropic]` or `pip install -e
                 )
                 res.append(response)
-                self.cache_hook.add_partial("greedy_until", request, response)
+                self.cache_hook.add_partial("generate_until", request, response)
             except anthropic.APIConnectionError as e:  # type: ignore # noqa: F821
                 eval_logger.critical(f"Server unreachable: {e.__cause__}")
                 break
@@ -179,7 +179,7 @@ please install anthropic via `pip install lm-eval[anthropic]` or `pip install -e
         raise NotImplementedError()
     def _model_generate(self, context, max_length, eos_token_id):
-        # Isn't used because we override greedy_until
+        # Isn't used because we override generate_until
         raise NotImplementedError()
     def loglikelihood(self, requests):
...
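
Every model backend touched by this commit follows the same pattern when it overrides the generation entry point: loop over the requests, collect one string per request, and record partial results in the cache under the new `"generate_until"` key. A minimal sketch of that pattern (`FixedResponseLM` and `NoOpCacheHook` are illustrative stand-ins, not harness classes):

```python
from typing import List, Tuple


class NoOpCacheHook:
    def add_partial(self, kind: str, request, result) -> None:
        # a real cache hook would persist (kind, request) -> result for reuse
        pass


class FixedResponseLM:
    def __init__(self) -> None:
        self.cache_hook = NoOpCacheHook()

    def generate_until(self, requests: List[Tuple[str, dict]]) -> List[str]:
        res = []
        for context, gen_kwargs in requests:
            response = "dummy completion"  # a real backend would call its model or API here
            res.append(response)
            # cached under "generate_until" instead of the old "greedy_until" key
            self.cache_hook.add_partial("generate_until", (context, gen_kwargs), response)
        return res


print(FixedResponseLM().generate_until([("Hello", {"until": ["\n"]})]))
```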
@@ -20,7 +20,7 @@ class DummyLM(LM):
         return res
-    def greedy_until(self, requests):
+    def generate_until(self, requests):
         res = []
         for ctx, _ in requests:
...
@@ -813,7 +813,7 @@ class HFLM(LM):
         return re_ord.get_original(res)
-    def greedy_until(self, requests):
+    def generate_until(self, requests):
         res = defaultdict(list)
         re_ords = {}
@@ -930,7 +930,7 @@ class HFLM(LM):
                res[key].append(s)
                self.cache_hook.add_partial(
-                    "greedy_until", (context, gen_kwargs), s
+                    "generate_until", (context, gen_kwargs), s
                )
                pbar.update(1)
            # reorder this group of results back to original unsorted form
...
@@ -203,7 +203,7 @@ class OpenaiCompletionsLM(LM):
             self.cache_hook.add_partial("loglikelihood", cache_key, answer)
         return re_ord.get_original(res)
-    def greedy_until(self, requests) -> List[str]:
+    def generate_until(self, requests) -> List[str]:
         if not requests:
             return []
         res = []
@@ -260,7 +260,7 @@ class OpenaiCompletionsLM(LM):
                # partial caching
                self.cache_hook.add_partial(
-                    "greedy_until", (context, {"until": until_}), s
+                    "generate_until", (context, {"until": until_}), s
                )
                res.append(s)
@@ -271,7 +271,7 @@ class OpenaiCompletionsLM(LM):
         raise NotImplementedError()
     def _model_generate(self, context, max_length, eos_token_id):
-        # Isn't used because we override greedy_until
+        # Isn't used because we override generate_until
         raise NotImplementedError()
     def loglikelihood_rolling(self, requests) -> List[float]:
...
@@ -58,7 +58,7 @@ class TextSynthLM(LM):
     @property
     def eot_token_id(self):
-        # Isn't used because we override loglikelihood, loglikelihood_rolling and greedy_until
+        # Isn't used because we override loglikelihood, loglikelihood_rolling and generate_until
         raise NotImplementedError()
     @property
@@ -72,20 +72,20 @@ class TextSynthLM(LM):
     @property
     def batch_size(self):
-        # Isn't used because we override loglikelihood, loglikelihood_rolling and greedy_until
+        # Isn't used because we override loglikelihood, loglikelihood_rolling and generate_until
         raise NotImplementedError()
     @property
     def device(self):
-        # Isn't used because we override loglikelihood, loglikelihood_rolling and greedy_until
+        # Isn't used because we override loglikelihood, loglikelihood_rolling and generate_until
         raise NotImplementedError()
     def tok_encode(self, string: str):
-        # Isn't used because we override loglikelihood, loglikelihood_rolling and greedy_until
+        # Isn't used because we override loglikelihood, loglikelihood_rolling and generate_until
         raise NotImplementedError()
     def tok_decode(self, tokens):
-        # Isn't used because we override loglikelihood, loglikelihood_rolling and greedy_until
+        # Isn't used because we override loglikelihood, loglikelihood_rolling and generate_until
         raise NotImplementedError()
     def loglikelihood(self, requests):
@@ -122,7 +122,7 @@ class TextSynthLM(LM):
                "input tokenization support from TextSynth."
            )
-    def greedy_until(self, requests):
+    def generate_until(self, requests):
         if not requests:
             return []
@@ -146,7 +146,7 @@ class TextSynthLM(LM):
                s = resp["text"]
                res.append(s)
-                self.cache_hook.add_partial("greedy_until", (inp, request_args), s)
+                self.cache_hook.add_partial("generate_until", (inp, request_args), s)
            else:
                logger.error(
                    f"The following response does not contain generated `text`. "
@@ -160,5 +160,5 @@ class TextSynthLM(LM):
         raise NotImplementedError()
     def _model_generate(self, context, max_length, eos_token_id):
-        # Isn't used because we override greedy_until
+        # Isn't used because we override generate_until
         raise NotImplementedError()
@@ -98,7 +98,7 @@ def check_prompt_config(
                        ]
                    )
                },
-                **{"output_type": "greedy_until"},
+                **{"output_type": "generate_until"},
            }
        )
    else:
...
 task: babi
 dataset_path: Muennighoff/babi
 dataset_name: null
-output_type: greedy_until
+output_type: generate_until
 training_split: train
 validation_split: valid
 test_split: test
...
 group: bbh_flan_cot_fewshot
 dataset_path: lukaemon/bbh
-output_type: greedy_until
+output_type: generate_until
 test_split: test
 doc_to_target: "{{target}}"
 metric_list:
...
 group: bbh_flan_cot_zeroshot
 dataset_path: lukaemon/bbh
-output_type: greedy_until
+output_type: generate_until
 test_split: test
 doc_to_target: "{{target}}"
 metric_list:
...
 group: bbh_flan_fewshot
 dataset_path: lukaemon/bbh
-output_type: greedy_until
+output_type: generate_until
 test_split: test
 doc_to_target: "{{target}}"
 metric_list:
...
 group: bbh_flan_zeroshot
 dataset_path: lukaemon/bbh
-output_type: greedy_until
+output_type: generate_until
 test_split: test
 doc_to_target: "{{target}}"
 metric_list:
...
 group: flan-cot
-output_type: greedy_until
+output_type: generate_until
 validation_split: validation
 doc_to_target: "{{answer}}"
 metric_list:
...
-output_type: greedy_until
+output_type: generate_until
 validation_split: validation
 metric_list:
   - metric: exact_match
...
@@ -6,7 +6,7 @@ task:
     use_prompt: promptsource:*
     training_split: train
     validation_split: validation
-    output_type: greedy_until
+    output_type: generate_until
     metric_list:
       - metric: exact_match
         aggregation: mean
@@ -19,7 +19,7 @@ task:
     use_prompt: promptsource:*
     training_split: train
     validation_split: validation
-    output_type: greedy_until
+    output_type: generate_until
     metric_list:
       - metric: exact_match
         aggregation: mean
@@ -32,7 +32,7 @@ task:
     use_prompt: promptsource:*
     training_split: train
     validation_split: validation
-    output_type: greedy_until
+    output_type: generate_until
     metric_list:
       - metric: exact_match
         aggregation: mean
@@ -44,7 +44,7 @@ task:
     use_prompt: promptsource:*
     training_split: train
     validation_split: validation
-    output_type: greedy_until
+    output_type: generate_until
     metric_list:
       - metric: exact_match
         aggregation: mean
@@ -56,7 +56,7 @@ task:
     use_prompt: promptsource:*
     training_split: train_r1
     validation_split: dev_r1
-    output_type: greedy_until
+    output_type: generate_until
     metric_list:
       - metric: exact_match
         aggregation: mean
@@ -68,7 +68,7 @@ task:
     use_prompt: promptsource:*
     training_split: train_r2
     validation_split: dev_r2
-    output_type: greedy_until
+    output_type: generate_until
     metric_list:
       - metric: exact_match
         aggregation: mean
@@ -80,7 +80,7 @@ task:
     use_prompt: promptsource:*
     training_split: train_r3
     validation_split: dev_r3
-    output_type: greedy_until
+    output_type: generate_until
     metric_list:
       - metric: exact_match
         aggregation: mean
@@ -93,7 +93,7 @@ task:
     use_prompt: promptsource:*
     training_split: train
     validation_split: validation
-    output_type: greedy_until
+    output_type: generate_until
     metric_list:
       - metric: exact_match
         aggregation: mean
@@ -105,7 +105,7 @@ task:
     use_prompt: promptsource:*
     training_split: train
     validation_split: validation
-    output_type: greedy_until
+    output_type: generate_until
     metric_list:
       - metric: exact_match
         aggregation: mean
@@ -118,7 +118,7 @@ task:
     use_prompt: promptsource:*
     training_split: train
     validation_split: validation
-    output_type: greedy_until
+    output_type: generate_until
     metric_list:
       - metric: exact_match
         aggregation: mean
...
@@ -175,8 +175,8 @@ all_subtasks = [
 def main() -> None:
     for path, task_type in zip(
-        ["multiple_choice", "greedy_until"],
-        ["multiple_choice_template_yaml", "greedy_until_template_yaml"],
+        ["multiple_choice", "generate_until"],
+        ["multiple_choice_template_yaml", "generate_until_template_yaml"],
     ):
         os.makedirs(path, exist_ok=True)
         for task in all_subtasks:
...
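
The generator script above writes one small YAML per BIG-bench subtask, now under the `generate_until` name; the two generated files below show the resulting shape. A sketch of the effect, assuming a per-subtask output file named `<subtask>.yaml` inside each directory (the filename and the truncated subtask list are assumptions; the field values match the generated files shown below):

```python
import os

all_subtasks = ["abstract_narrative_understanding", "anachronisms"]  # truncated example list

for path, template in [
    ("multiple_choice", "multiple_choice_template_yaml"),
    ("generate_until", "generate_until_template_yaml"),
]:
    os.makedirs(path, exist_ok=True)
    for task in all_subtasks:
        yaml_body = (
            "# Generated by utils.py\n"
            f"dataset_name: {task}_zero_shot\n"
            f"include: ../{template}\n"
            f"task: bigbench_{task}_{path}\n"
        )
        with open(os.path.join(path, f"{task}.yaml"), "w") as f:
            f.write(yaml_body)
```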
 # Generated by utils.py
 dataset_name: abstract_narrative_understanding_zero_shot
-include: ../greedy_until_template_yaml
-task: bigbench_abstract_narrative_understanding_greedy_until
+include: ../generate_until_template_yaml
+task: bigbench_abstract_narrative_understanding_generate_until
 # Generated by utils.py
 dataset_name: anachronisms_zero_shot
-include: ../greedy_until_template_yaml
-task: bigbench_anachronisms_greedy_until
+include: ../generate_until_template_yaml
+task: bigbench_anachronisms_generate_until