Unverified commit ef332026 authored by Hailey Schoelkopf, committed by GitHub

Merge pull request #931 from EleutherAI/fix-generate-until

[Refactor] Generate_until rename
parents a8d130ab e66ba123
@@ -155,14 +155,14 @@ A full accounting of the supported and planned libraries + APIs can be seen below

| API or Inference Server | Implemented? | `--model <xxx>` name | Models supported: | Request Types: |
|-----------------------------|---------------------------------|----------------------------------------------------------------------------------|--------------------------------------|----------------------------------------------------------|
| OpenAI Completions | :heavy_check_mark: | `openai`, `openai-completions`, `gooseai` | up to `code-davinci-002` | `generate_until`, `loglikelihood`, `loglikelihood_rolling` |
| OpenAI ChatCompletions | :x: Not yet - needs help! | N/A | (link here?) | `generate_until` (no logprobs) |
| Anthropic | :heavy_check_mark: | `anthropic` | [Supported Anthropic Engines](https://docs.anthropic.com/claude/reference/selecting-a-model) | `generate_until` (no logprobs) |
| GooseAI | :heavy_check_mark: (not separately maintained) | `openai`, `openai-completions`, `gooseai` (same interface as OpenAI Completions) | | `generate_until`, `loglikelihood`, `loglikelihood_rolling` |
| Textsynth | Needs testing | `textsynth` | ??? | `generate_until`, `loglikelihood`, `loglikelihood_rolling` |
| Cohere | :hourglass: - blocked on Cohere API bug | N/A | [All `cohere.generate()` engines](https://docs.cohere.com/docs/models) | `generate_until`, `loglikelihood`, `loglikelihood_rolling` |
| GGML | :hourglass: [PR](https://github.com/EleutherAI/lm-evaluation-harness/pull/617) | N/A | ??? | `generate_until`, `loglikelihood`, `loglikelihood_rolling` |
| vLLM | :x: Not yet - needs help! | N/A | All HF models | `generate_until` (no logprobs) |
| Your inference server here! | ... | ... | ... | ... |
It is on our roadmap to create task variants designed to enable models which do not serve logprobs/loglikelihoods to be compared with the generation performance of open-source models.

...
@@ -57,7 +57,7 @@ import lm_eval

my_model = initialize_my_model() # create your model (could be running finetuning with some custom modeling code)
...
lm_obj = Your_LM(model=my_model, batch_size=16) # instantiate an LM subclass that takes your initialized model and can run `Your_LM.loglikelihood()`, `Your_LM.loglikelihood_rolling()`, `Your_LM.generate_until()`

results = lm_eval.simple_evaluate( # call simple_evaluate
    model=lm_obj,

@@ -83,7 +83,7 @@ from my_tasks import MyTask1 # suppose you've defined a custom lm_eval.api.Task

my_model = initialize_my_model() # create your model (could be running finetuning with some custom modeling code)
...
lm_obj = Your_LM(model=my_model, batch_size=16) # instantiate an LM subclass that takes your initialized model and can run `Your_LM.loglikelihood()`, `Your_LM.loglikelihood_rolling()`, `Your_LM.generate_until()`

...
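To make the truncated snippets above concrete, here is a minimal sketch of the same flow. The names `initialize_my_model` and `Your_LM` are the hypothetical placeholders already used in the README snippet, and the `tasks` / `num_fewshot` keyword arguments and the `"results"` key of the return value are assumptions based on the harness documentation rather than something shown in this diff.

```python
import lm_eval

# `initialize_my_model` and `Your_LM` are hypothetical names from the snippet above:
# any model object plus an lm_eval.api.model.LM subclass implementing
# loglikelihood(), loglikelihood_rolling(), and generate_until().
my_model = initialize_my_model()
lm_obj = Your_LM(model=my_model, batch_size=16)

# Only `model=` is visible in the diff above; `tasks` and `num_fewshot`
# are assumed keyword arguments of simple_evaluate.
results = lm_eval.simple_evaluate(
    model=lm_obj,
    tasks=["hellaswag"],  # any registered task name
    num_fewshot=0,
)

# The returned dict is assumed to carry per-task metrics under "results".
print(results["results"])
```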
@@ -44,26 +44,24 @@ class MyCustomLM(LM):
    #...
    def generate_until(self, requests: list[Instance]) -> list[str]:
        #...
    #...
```
Where `Instance` is a dataclass defined in [`lm_eval.api.instance`](https://github.com/EleutherAI/lm-evaluation-harness/blob/big-refactor/lm_eval/api/instance.py) with property `args` which returns a tuple of (context, continuation).
We support three types of requests, consisting of different interactions / measurements with an autoregressive LM.

All three request types take as input `requests` of type `list[Instance]` that have a matching `Instance.request_type` to the method name.

- `generate_until`
  - Each request contains `Instance.args : Tuple[str, dict]` containing 1. an input string to the LM and 2. a dictionary of keyword arguments used to control generation parameters.
  - Text is generated from the model conditioned on the input string, stopping according to the supplied generation parameters, and the generated text is returned as a string.
- `loglikelihood`
  - Each request contains `Instance.args : Tuple[str, str]` containing 1. an input string to the LM and 2. a target string. The log probability of the target string, conditioned on the input string, is returned.
- `loglikelihood_rolling`
  - Each request contains `Instance.args : Tuple[str]`, a single input string. The log probability of the full string under the model is returned; this is used for perplexity-style evaluations.
## Registration

...
@@ -32,7 +32,7 @@ Prompting / in-context formatting options:

- **use_prompt** (`str`, *optional*) — Name of prompt in promptsource to use. If defined, will overwrite doc_to_text, doc_to_target, and doc_to_choice.
- **doc_to_text** (`Union[Callable, str]`, *optional*) — Jinja2, f-string, or function to process a sample into the appropriate input for the model.
- **doc_to_target** (`Union[Callable, str]`, *optional*) — Jinja2, f-string, or function to process a sample into the appropriate target output for the model. For multiple choice tasks, this should return an index into the answer choices.
- **doc_to_choice** (`Union[Callable, str]`, *optional*) — Jinja2, f-string, or function to process a sample into a list of possible string choices for `multiple_choice` tasks. Left undefined for `generate_until` tasks.
- **fewshot_delimiter** (`str`, *optional*, defaults to "\n\n") — String to insert between few-shot examples.
- **target_delimiter** (`str`, *optional*, defaults to `" "`) — String to insert between input and target output for the datapoint being tested.

@@ -42,7 +42,7 @@ Runtime configuration options:

Scoring details:

- **metric_list** (`str`, *optional*, defaults to None) — A list of metrics to use for evaluation. See docs for expected format.
- **output_type** (`str`, *optional*, defaults to "generate_until") — Selects the type of model output for the given task. Options are `generate_until`, `loglikelihood`, `loglikelihood_rolling`, and `multiple_choice`.
- **generation_kwargs** (`dict`, *optional*) — Auxiliary arguments for the `generate` function from HF transformers library. Advanced keyword arguments may not be supported for non-HF LM classes.
- **repeats** (`int`, *optional*, defaults to 1) — Number of repeated runs through the model for each sample. Can be used for cases such as self-consistency.
- **filter_list** (`Union[str, list]`, *optional*) — List of filters to postprocess model outputs. See below for further detail on the filter API.

...
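Purely as illustration of how these fields combine for a generative task, here is a sketch written as a Python dict; in practice these options live in a task YAML file, and the task name, templates, and the exact `metric_list` entry format shown here are assumptions, not values taken from this diff.

```python
# Hypothetical option set for a generative task, mirroring the fields documented above.
my_task_options = {
    "task": "my_generative_task",  # invented task name
    "doc_to_text": "Question: {{question}}\nAnswer:",  # Jinja2 template for the model input
    "doc_to_target": "{{answer}}",  # Jinja2 template for the gold target
    "target_delimiter": " ",  # inserted between input and target
    "fewshot_delimiter": "\n\n",  # inserted between few-shot examples
    "output_type": "generate_until",  # generative task, so doc_to_choice is left undefined
    "generation_kwargs": {"until": ["\n\n"], "max_gen_toks": 64},  # passed through to generation
    "metric_list": [{"metric": "exact_match"}],  # entry format assumed; see the metric_list docs
    "repeats": 1,
}
```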
@@ -4,7 +4,7 @@ from typing import Literal, Tuple

@dataclass
class Instance:
    request_type: Literal["loglikelihood", "loglikelihood_rolling", "generate_until"]
    doc: dict
    arguments: tuple
    idx: int

...
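Based only on the fields visible in this hunk (the dataclass continues below the fold, and those remaining fields are assumed to have defaults), a generation request might be built like this; the prompt, document, and generation kwargs are invented.

```python
from lm_eval.api.instance import Instance

# Hypothetical generate_until request: `arguments` carries the input string and
# the generation-control kwargs, matching request_type = "generate_until".
req = Instance(
    request_type="generate_until",
    doc={"question": "What is 2 + 2?"},  # the source dataset example
    arguments=("Question: What is 2 + 2?\nAnswer:", {"until": ["\n"], "max_gen_toks": 16}),
    idx=0,  # index of this request within its document
)

print(req.args)  # `args` is the property exposing `arguments` as a tuple
```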
@@ -212,7 +212,7 @@ def f1_fn(items):  # This is a passthrough function

@register_metric(
    metric="bleu",
    higher_is_better=True,
    output_type="generate_until",
    aggregation="bleu",
)
def bleu_fn(items):  # This is a passthrough function

@@ -222,7 +222,7 @@ def bleu_fn(items):  # This is a passthrough function

@register_metric(
    metric="chrf",
    higher_is_better=True,
    output_type="generate_until",
    aggregation="chrf",
)
def chrf_fn(items):  # This is a passthrough function

@@ -232,7 +232,7 @@ def chrf_fn(items):  # This is a passthrough function

@register_metric(
    metric="ter",
    higher_is_better=True,
    output_type="generate_until",
    aggregation="ter",
)
def ter_fn(items):  # This is a passthrough function

...
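The same decorator pattern can register a new metric for generative tasks. The sketch below is illustrative only: the metric name `my_match` is invented, the import path and the availability of a `"mean"` aggregation are assumptions, and the passthrough body mirrors what the built-in functions above appear to do.

```python
from lm_eval.api.registry import register_metric  # import path assumed


@register_metric(
    metric="my_match",  # invented metric name
    higher_is_better=True,
    output_type="generate_until",
    aggregation="mean",  # assumed to be a registered aggregation
)
def my_match_fn(items):  # passthrough, like bleu_fn / chrf_fn / ter_fn above
    return items
```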
@@ -211,12 +211,12 @@ class CachingLM:
        )
        for req in tqdm(requests):
            hsh = hash_args(attr, req.args)
            if attr == "generate_until" and req.args[1].get("do_sample", False):
                # when we are doing non-greedy generation, don't use the cache
                # (else every "randomly sampled" generation would be identical for repeats > 1).
                if not warned:
                    eval_logger.warning(
                        f"Arguments to lm.generate_until() '{req.args[1]}' include non-deterministic sampling. Caching will not be performed for such requests."
                    )
                    warned = True
                res.append(None)

...
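A short, hypothetical illustration of the condition above: a `generate_until` request whose generation kwargs enable sampling is still hashed, but deliberately not served from or written to the cache.

```python
# Hypothetical args for a generate_until request: (input string, generation kwargs).
args = ("Translate to French: cheese", {"do_sample": True, "temperature": 0.7})

attr = "generate_until"
if attr == "generate_until" and args[1].get("do_sample", False):
    # Sampled generations are non-deterministic, so caching them would make
    # every "random" repeat identical -- the CachingLM wrapper skips such requests.
    print("request will bypass the cache")
```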
@@ -81,7 +81,7 @@ DEFAULT_METRIC_REGISTRY = {
    ],
    "loglikelihood_rolling": ["word_perplexity", "byte_perplexity", "bits_per_byte"],
    "multiple_choice": ["acc", "acc_norm"],
    "generate_until": ["exact_match"],
}

@@ -171,7 +171,6 @@ def is_higher_better(metric_name):
    try:
        return HIGHER_IS_BETTER_REGISTRY[metric_name]
    except KeyError:
        eval_logger.warning(
            f"higher_is_better not specified for metric '{metric_name}'!"
        )
@@ -23,7 +23,7 @@ class DryrunLM(LM):
        return res

    def generate_until(self, requests):
        res = []
        for ctx, _ in requests:

...
@@ -15,10 +15,10 @@ class Test_HFLM:
    multiple_choice_task = tasks.TASK_REGISTRY.get("arc_easy")()  # type: ignore
    multiple_choice_task.build_all_requests(limit=10, rank=0, world_size=1)
    MULTIPLE_CH: list[Instance] = multiple_choice_task.instances
    generate_until_task = tasks.TASK_REGISTRY.get("gsm8k_yaml")()  # type: ignore
    generate_until_task.build_all_requests(limit=10, rank=0, world_size=1)
    generate_until_task._config.generation_kwargs["max_gen_toks"] = 10
    generate_until: list[Instance] = generate_until_task.instances
    rolling_task = tasks.TASK_REGISTRY.get("wikitext")()  # type: ignore
    rolling_task.build_all_requests(limit=10, rank=0, world_size=1)
    ROLLING: list[Instance] = rolling_task.instances

@@ -65,7 +65,7 @@ class Test_HFLM:
        -52.70050811767578,
        -56.25089645385742,
    ]
    generate_until_RES = [
        " The average of $2.50 each is $",
        " A robe takes 2 bolts of blue fiber and half",
        " $50,000 in repairs.",

@@ -109,9 +109,9 @@ class Test_HFLM:
        ), np.argmax(np.array(_res).reshape(-1, 4), axis=1)
        assert (argmax_RES == argmax_res).all()

    def test_generate_until(self) -> None:
        res = self.LM.generate_until(self.generate_until)
        assert res == self.generate_until_RES

    def test_logliklihood_rolling(self) -> None:
        res = self.LM.loglikelihood_rolling(self.ROLLING)

...
@@ -78,7 +78,7 @@ def test_gpt2():
    # test empty context
    gpt2.loglikelihood([("", "test")])

    (gen,) = gpt2.generate_until(
        [("The quick brown fox jumps over the lazy", [".", "\n"])]
    )

@@ -204,7 +204,7 @@ def test_gpt3():
    # test empty context
    gpt3.loglikelihood([("", "test")])

    (gen,) = gpt3.generate_until(
        [("The quick brown fox jumps over the lazy", [".", "\n"])]
    )

@@ -300,7 +300,7 @@ def test_textsynth():
    # test empty context
    textsynth.loglikelihood([("", "test")])

    (gen,) = textsynth.generate_until(
        [("The quick brown fox jumps over the lazy", [".", "\n"])]
    )

...
@@ -98,9 +98,9 @@ def test_versions_stable(taskname, task_class):
        return res

    def generate_until(reqs):
        res = []
        assert_target_hashed(f"{taskname}-v{task_class.VERSION}-generate_until", reqs)

        for ctx, _ in [req.args for req in reqs]:
            res.append("lol")

@@ -110,7 +110,7 @@ def test_versions_stable(taskname, task_class):
    lm.loglikelihood = ll_fn
    lm.loglikelihood_rolling = ll_perp_fn
    lm.generate_until = generate_until

    limit = None
    result = evaluator.evaluate(

...