Unverified Commit 29f12dd9 authored by Lintang Sutawika, committed by GitHub

Merge branch 'big-refactor' into benchmark-scripts

parents e37698df 4168c05f
@@ -55,7 +55,7 @@ jobs:
       - name: Install dependencies
         run: |
           python -m pip install --upgrade pip
-          pip install -e '.[testing]' --extra-index-url https://download.pytorch.org/whl/cpu
+          pip install -e '.[testing,anthropic,sentencepiece]' --extra-index-url https://download.pytorch.org/whl/cpu
           # Install optional git dependencies
           # pip install bleurt@https://github.com/google-research/bleurt/archive/b610120347ef22b494b6d69b4316e303f5932516.zip#egg=bleurt
           # if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
......
@@ -26,13 +26,13 @@ Dataset configuration options:
 - **validation_split** (`str`, *optional*) — Split in the dataset to use as the validation split.
 - **test_split** (`str`, *optional*) — Split in the dataset to use as the test split.
 - **fewshot_split** (`str`, *optional*) — Split in the dataset to draw few-shot exemplars from. Should not be `None` if `num_fewshot > 0`, and should generally differ from the split being evaluated.
+- **process_docs** (`Callable`, *optional*) — Optionally define a function to apply to each HF dataset split, to preprocess all documents before they are fed into prompt template rendering or other evaluation steps. Can be used to rename dataset columns, or to process documents into a format closer to that expected by a prompt template.

 Prompting / in-context formatting options:
-- **template_aliases** (`str`, *optional*) — A field for inputting additional Jinja2 content. Intended not to render as text after applying a Jinja template, but to instead define variables within Jinja that will be used within the written prompts (for example, mapping the dataset column `label` to the new name `gold`).
-- **use_prompt** (`str`, *optional*) — Name of prompt in promptsource to use. If defined, will overwrite doc_to_text and doc_to_target and make template_aliases unused.
+- **use_prompt** (`str`, *optional*) — Name of prompt in promptsource to use. If defined, will overwrite doc_to_text, doc_to_target, and doc_to_choice.
 - **doc_to_text** (`Union[Callable, str]`, *optional*) — Jinja2, f-string, or function to process a sample into the appropriate input for the model.
-- **doc_to_target** (`Union[Callable, str]`, *optional*) — Jinja2, f-string, or function to process a sample into the appropriate target output for the model.
+- **doc_to_target** (`Union[Callable, str]`, *optional*) — Jinja2, f-string, or function to process a sample into the appropriate target output for the model. For multiple choice tasks, this should return an index into the list of possible answer choices.
-- **doc_to_choice** (`Union[Callable, str]`, *optional*) — Jinja2, f-string, or function to process a sample into possible choices for `multiple_choice`.
+- **doc_to_choice** (`Union[Callable, str]`, *optional*) — Jinja2, f-string, or function to process a sample into a list of possible string choices for `multiple_choice` tasks. Left undefined for `greedy_until` tasks.
 - **gold_alias** (`str`, *optional*, defaults to `None`) — If provided, used to generate the reference answer that is scored against. Used in cases where `doc_to_target` should be the "target string" format appended to each example's input for a fewshot exemplar, so `doc_to_target` is used for fewshot examples, but the input to the metric function as `gold` is taken from `gold_alias`.
 - **fewshot_delimiter** (`str`, *optional*, defaults to `"\n\n"`) — String to insert between few-shot examples.
 - **target_delimiter** (`str`, *optional*, defaults to `" "`) — String to insert between input and target output for the datapoint being tested.
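As an aside, here is a minimal sketch of how the `doc_to_*` fields can be supplied as Python callables (via `!function`) for a hypothetical dataset with `question`, `choices`, and `label` columns. The file name and column names are assumptions for illustration, not part of the change above:

```python
# utils.py (hypothetical): referenced from a task YAML as, e.g.,
# `doc_to_text: !function utils.doc_to_text`

def doc_to_text(doc: dict) -> str:
    # The input string shown to the model for one document.
    return f"Question: {doc['question']}\nAnswer:"

def doc_to_choice(doc: dict) -> list:
    # The finite set of answer strings for a `multiple_choice` task.
    return list(doc["choices"])

def doc_to_target(doc: dict) -> int:
    # Index into doc_to_choice(doc) of the gold answer.
    return int(doc["label"])
```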
@@ -160,7 +160,7 @@ Thus, given the 64 responses from our LM on each document, we can report metrics
 You can use Python functions for certain arguments by using the `!function` operator after the argument name, followed by `<filename>.<pythonfunctionname>`. This feature can be used for the following arguments:
 1. `doc_to_text`
 2. `doc_to_target`
-3. `gold_alias`
+3. `doc_to_choice`
 4. `aggregation` for a `metric` in `metric_list`

 ## (No Longer Recommended) Direct `Task` Subclassing
......
@@ -60,16 +60,44 @@ fewshot_split: <split name to draw fewshot examples from, or `null`>
 ```
 though if this is not set, we will default to train/validation/test sets, in that order.
+
+Finally, our dataset may not already be in the exact format we want. Maybe we have to strip whitespace and special characters via a regex from our dataset's "question" field! Or maybe we just want to rename its columns to match a convention we'll be using for our prompts.
+
+Let's create a Python file in the directory where we're writing our YAML file:
+```bash
+touch lm_eval/tasks/<dataset_name>/utils.py
+```
+Now, in `utils.py` we'll write a function to process each split of our dataset:
+```python
+import datasets
+
+def process_docs(dataset: datasets.Dataset):
+    def _helper(doc):
+        # modifies the contents of a single
+        # document in our dataset.
+        doc["choices"] = [doc["choice1"], doc["choice2"], doc["wrong_answer"]]
+        doc["gold"] = doc["label"]
+        return doc
+
+    return dataset.map(_helper)  # returns a datasets.Dataset object
+```
+Now, in our YAML config file we'll use the `!function` constructor, and tell the config where our imported Python function will come from. At runtime, before doing anything else, we will preprocess our dataset according to this function!
+```yaml
+process_docs: !function utils.process_docs
+```
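To make the effect concrete, here is a small sketch (separate from the change above) that runs the same kind of `process_docs` helper on a toy split; the column values are invented for illustration:

```python
import datasets

def process_docs(dataset: datasets.Dataset) -> datasets.Dataset:
    # Same shape of helper as in the snippet above.
    def _helper(doc):
        doc["choices"] = [doc["choice1"], doc["choice2"], doc["wrong_answer"]]
        doc["gold"] = doc["label"]
        return doc
    return dataset.map(_helper)

split = datasets.Dataset.from_dict(
    {"choice1": ["Paris"], "choice2": ["Lyon"], "wrong_answer": ["Berlin"], "label": [0]}
)
processed = process_docs(split)
print(processed[0]["choices"])  # ['Paris', 'Lyon', 'Berlin']
print(processed[0]["gold"])     # 0
```

The derived columns (`choices`, `gold` in this toy case) can then be referenced directly from `doc_to_choice` and `doc_to_target`.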
 ### Writing a prompt with Jinja 2
 The next thing we need to do is decide what format to use when presenting the data to the LM. This is our **prompt**, where we'll define both an input and output format.

 We support the [Jinja 2](https://jinja.palletsprojects.com/en/3.1.x/) templating language for writing prompts. In practice, this means you can take your dataset's columns and do many basic string manipulations to place each document into prompted format.

-To write a prompt, users are required to write two YAML fields in Jinja as strings:
+To write a prompt, users are required to write two or three YAML fields in Jinja as strings:
 ```yaml
 doc_to_text:
 doc_to_target:
+doc_to_choice:
 ```
 Suppose our dataset has a `"question"` field, and an `"answer"` field, which are both strings. We want the model to see, if given a `document` object that is a row of our dataset:
 ```
@@ -101,10 +129,9 @@ For tasks which are multiple choice (a fixed, finite set of label words per each
 An annotated example in the case of SciQ is as follows:
 ```yaml
-template_aliases: "{% set answer_choices = [distractor1, distractor2, distractor3, correct_answer] %}{% set gold = 3 %}" # `template_aliases` must set the list of possible answer choices to the jinja variable `answer_choices` (List[str]), and set what the index within `answer_choices` of this doc's gold label (correct answer choice).
 doc_to_text: "{{support.lstrip()}}\nQuestion: {{question}}\nAnswer:" # This is the input portion of the prompt for this doc. It will have " {{choice}}" appended to it as target for each choice in answer_choices.
-doc_to_target: "{{answer_choices[gold]}}" # this contains the gold-standard answer choice, selected via indexing to index `gold` in the answer choice list.
-gold_alias: "{{gold}}" # this must be castable to an integer. It must output only the index within `answer_choices` that is the correct label.
+doc_to_target: 3 # this contains the index into the answer choice list of the correct answer.
+doc_to_choice: "{{[distractor1, distractor2, distractor3, correct_answer]}}"
 ```
 Task implementers are thus able to decide what the answer choices should be for a document, and what prompt format to use.
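As a side illustration, this is roughly how those two Jinja templates render for a made-up SciQ-style document; all field values below are invented:

```python
from jinja2 import Template

doc = {
    "support": "  The mitochondrion produces most of the cell's ATP.",
    "question": "Which organelle produces most of the cell's ATP?",
    "distractor1": "ribosome",
    "distractor2": "nucleus",
    "distractor3": "chloroplast",
    "correct_answer": "mitochondrion",
}

doc_to_text = Template("{{support.lstrip()}}\nQuestion: {{question}}\nAnswer:")
doc_to_choice = Template("{{[distractor1, distractor2, distractor3, correct_answer]}}")

print(doc_to_text.render(**doc))
# The mitochondrion produces most of the cell's ATP.
# Question: Which organelle produces most of the cell's ATP?
# Answer:
print(doc_to_choice.render(**doc))
# ['ribosome', 'nucleus', 'chloroplast', 'mitochondrion']
```

With `doc_to_target: 3`, the gold answer is the fourth entry (`correct_answer`), and each candidate choice is appended to the rendered input after the `target_delimiter` (a single space by default).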
......
@@ -81,7 +81,6 @@ DEFAULT_METRIC_REGISTRY = {
     ],
     "loglikelihood_rolling": ["word_perplexity", "byte_perplexity", "bits_per_byte"],
     "multiple_choice": ["acc", "acc_norm"],
-    "winograd_schema": ["acc"],
     "greedy_until": ["exact_match"],
 }
......
@@ -65,7 +65,7 @@ class TaskConfig(dict):
     fewshot_split: str = None  # TODO: assert that this not None if num_fewshot > 0. (?) assert if this is same split as one evaling (?)
     # formatting / prompting options.
     # see docs/advanced_task_guide.md for more info
-    template_aliases: Union[str, list] = None
+    process_docs: Callable = None
     doc_to_text: Union[Callable, str] = None
     doc_to_target: Union[Callable, str] = None
     doc_to_choice: Union[Callable, str, dict, list] = None
@@ -89,24 +89,13 @@ class TaskConfig(dict):
     metadata: str = None  # by default, not used in the code. allows for users to pass arbitrary info to tasks

     def __post_init__(self):
-        # allow user-specified aliases so that users can
-        # force prompt-compatibility for some prompt regardless of
-        # field names in prompt
-        if self.template_aliases:
-            if type(self.doc_to_text) == str:
-                self.doc_to_text = self.template_aliases + self.doc_to_text
-            if type(self.doc_to_target) == str:
-                self.doc_to_target = self.template_aliases + self.doc_to_target
-            if type(self.gold_alias) == str:
-                self.gold_alias = self.template_aliases + self.gold_alias
         if self.generation_kwargs is not None:
             if self.output_type != "greedy_until":
                 eval_logger.warning(
-                    "passed `generation_kwargs`, but not using a generation request type!"
+                    "passed `generation_kwargs`, but not using `output_type: greedy_until`!"
                 )
+                assert self.output_type != "greedy_until"

             if "temperature" in self.generation_kwargs:
                 self.generation_kwargs["temperature"] = float(
@@ -131,6 +120,9 @@ class TaskConfig(dict):
     def __getitem__(self, item):
         return getattr(self, item)

+    def __setitem__(self, item, value):
+        return setattr(self, item, value)
+
     def to_dict(self):
         """dumps the current config as a dictionary object, as a printable format.
         null fields will not be printed.
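For context, the added `__setitem__` is what allows dict-style mutation of a task's config, which the `simple_evaluate` change further down relies on. A small sketch, assuming `TaskConfig` accepts these fields as keyword arguments (field values here are hypothetical):

```python
from lm_eval.api.task import TaskConfig

config = TaskConfig(task="demo_task", num_fewshot=0)  # hypothetical values
config["num_fewshot"] = 5          # routed through __setitem__ -> setattr
assert config["num_fewshot"] == 5  # routed through __getitem__ -> getattr
```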
@@ -632,10 +624,6 @@ class ConfigurableTask(Task):
                 list(self.fewshot_docs()), self, rnd=random.Random(1234)
             )

-        if self._config.template_aliases is not None:
-            for key, alias in self._config.template_aliases:
-                self.dataset.rename_column(key, alias)
-
         if self.has_test_docs():
             docs = self.test_docs()
         elif self.has_validation_docs():
@@ -693,15 +681,25 @@ class ConfigurableTask(Task):
         return False

     def training_docs(self):
-        if self._config.training_split is not None:
+        if self.has_training_docs():
+            if self._config.process_docs is not None:
+                return self._config.process_docs(
+                    self.dataset[self._config.training_split]
+                )
             return self.dataset[self._config.training_split]

     def validation_docs(self):
-        if self._config.validation_split is not None:
+        if self.has_validation_docs():
+            if self._config.process_docs is not None:
+                return self._config.process_docs(
+                    self.dataset[self._config.validation_split]
+                )
             return self.dataset[self._config.validation_split]

     def test_docs(self):
-        if self._config.test_split is not None:
+        if self.has_test_docs():
+            if self._config.process_docs is not None:
+                return self._config.process_docs(self.dataset[self._config.test_split])
             return self.dataset[self._config.test_split]

     def fewshot_docs(self):
......
@@ -35,7 +35,7 @@ def simple_evaluate(
     model,
     model_args=None,
     tasks=[],
-    num_fewshot=0,
+    num_fewshot=None,
     batch_size=None,
     max_batch_size=None,
     device=None,
@@ -112,7 +112,17 @@ def simple_evaluate(
             + "_rank" + str(lm.rank) + ".db",
         )

-    task_dict = lm_eval.tasks.get_task_dict(tasks, num_fewshot=num_fewshot)
+    task_dict = lm_eval.tasks.get_task_dict(tasks)
+    for task_name in task_dict.keys():
+        config = task_dict[task_name]._config
+
+        if num_fewshot is not None:
+            if config["num_fewshot"] > 0:
+                default_num_fewshot = config["num_fewshot"]
+                eval_logger.warning(
+                    f"Overwriting default num_fewshot of {task_name} from {default_num_fewshot} to {num_fewshot}"
+                )
+
+            task_dict[task_name]._config["num_fewshot"] = num_fewshot

     if check_integrity:
         run_task_tests(task_list=tasks)
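In other words, the global `num_fewshot` now only overrides a task's own YAML default when it is explicitly passed. A rough usage sketch; the model type, model args, and task name below are placeholders, not prescribed by the change:

```python
from lm_eval.evaluator import simple_evaluate

# With num_fewshot=None (the new default), every task keeps the num_fewshot
# from its own config; passing an integer overrides all listed tasks and
# logs a warning for tasks whose default was nonzero.
results = simple_evaluate(
    model="hf",                    # placeholder model type
    model_args="pretrained=gpt2",  # placeholder model args
    tasks=["arithmetic_2da"],      # task name taken from this diff
    num_fewshot=5,                 # overrides the task's own default
)
```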
@@ -134,7 +144,6 @@ def simple_evaluate(
         if isinstance(model, str)
         else model.model.config._name_or_path,
         "model_args": model_args,
-        "num_fewshot": num_fewshot,
         "batch_size": batch_size,
         "batch_sizes": list(lm.batch_sizes.values())
         if hasattr(lm, "batch_sizes")
@@ -169,8 +178,6 @@ def evaluate(
         Language Model
     :param task_dict: dict[str, Task]
         Dictionary of tasks. Tasks will be taken to have name task.EVAL_HARNESS_NAME if defined and type(task).__name__ otherwise.
-    :param num_fewshot: int
-        Number of examples in few-shot context
     :param limit: int, optional
         Limit the number of examples per task (only use this for testing)
     :param bootstrap_iters:
......
@@ -3,21 +3,28 @@ from lm_eval.api.model import LM
 from lm_eval.api.registry import register_model
 from tqdm import tqdm
 import time
+import anthropic
+from lm_eval.logger import eval_logger
+from typing import List, Literal, Any


 def anthropic_completion(
-    client, model, prompt, max_tokens_to_sample, temperature, stop
+    client: anthropic.Anthropic,
+    model: str,
+    prompt: str,
+    max_tokens_to_sample: int,
+    temperature: float,
+    stop: List[str],
+    **kwargs: Any,
 ):
     """Query Anthropic API for completion.

     Retry with back-off until they respond
     """
-    import anthropic
-
     backoff_time = 3
     while True:
         try:
-            response = client.completion(
+            response = client.completions.create(
                 prompt=f"{anthropic.HUMAN_PROMPT} {prompt}{anthropic.AI_PROMPT}",
                 model=model,
                 # NOTE: Claude really likes to do CoT, and overly aggressive stop sequences
@@ -25,36 +32,53 @@ def anthropic_completion(
                 stop_sequences=[anthropic.HUMAN_PROMPT] + stop,
                 max_tokens_to_sample=max_tokens_to_sample,
                 temperature=temperature,
+                **kwargs,
+            )
+            return response.completion
+        except anthropic.RateLimitError as e:
+            eval_logger.warning(
+                f"RateLimitError occurred: {e.__cause__}\n Retrying in {backoff_time} seconds"
             )
-            return response["completion"]
-        except RuntimeError:
-            # TODO: I don't actually know what error Anthropic raises when it times out
-            # So err update this error when we find out.
-            import traceback
-
-            traceback.print_exc()
             time.sleep(backoff_time)
             backoff_time *= 1.5


 @register_model("anthropic")
 class AnthropicLM(LM):
-    REQ_CHUNK_SIZE = 20
+    REQ_CHUNK_SIZE = 20  # TODO: not used

-    def __init__(self, model):
-        """
+    def __init__(
+        self,
+        batch_size: int = 1,
+        model: str = "claude-2.0",
+        max_tokens_to_sample: int = 256,
+        temperature: float = 0,  # defaults to 1
+        **kwargs,  # top_p, top_k, etc.
+    ):
+        """Anthropic API wrapper.

         :param model: str
-            Anthropic model e.g. claude-instant-v1
+            Anthropic model e.g. 'claude-instant-v1', 'claude-2'
+        :param max_tokens_to_sample: int
+            Maximum number of tokens to sample from the model
+        :param temperature: float
+            Sampling temperature
+        :param kwargs: Any
+            Additional model_args to pass to the API client
         """
         super().__init__()
-        import anthropic
-
         self.model = model
-        self.client = anthropic.Client(os.environ["ANTHROPIC_API_KEY"])
+        # defaults to os.environ.get("ANTHROPIC_API_KEY")
+        self.client = anthropic.Anthropic()
+        self.temperature = temperature
+        self.max_tokens_to_sample = max_tokens_to_sample
+        self.tokenizer = self.client.get_tokenizer()
+        self.kwargs = kwargs

     @property
     def eot_token_id(self):
+        # Not sure but anthropic.AI_PROMPT -> [203, 203, 50803, 30]
         raise NotImplementedError("No idea about anthropic tokenization.")

     @property
@@ -63,23 +87,23 @@ class AnthropicLM(LM):

     @property
     def max_gen_toks(self):
-        return 256
+        return self.max_tokens_to_sample

     @property
     def batch_size(self):
         # Isn't used because we override _loglikelihood_tokens
-        raise NotImplementedError()
+        raise NotImplementedError("No support for logits.")

     @property
     def device(self):
         # Isn't used because we override _loglikelihood_tokens
-        raise NotImplementedError()
+        raise NotImplementedError("No support for logits.")

-    def tok_encode(self, string: str):
-        raise NotImplementedError("No idea about anthropic tokenization.")
+    def tok_encode(self, string: str) -> List[int]:
+        return self.tokenizer.encode(string).ids

-    def tok_decode(self, tokens):
-        raise NotImplementedError("No idea about anthropic tokenization.")
+    def tok_decode(self, tokens: List[int]) -> str:
+        return self.tokenizer.decode(tokens)

     def _loglikelihood_tokens(self, requests, disable_tqdm=False):
         raise NotImplementedError("No support for logits.")
@@ -92,20 +116,31 @@ class AnthropicLM(LM):
         res = []
         for request in tqdm(requests):
-            inp = request[0]
-            request_args = request[1]
-            until = request_args["until"]
-            response = anthropic_completion(
-                client=self.client,
-                model=self.model,
-                prompt=inp,
-                max_tokens_to_sample=self.max_gen_toks,
-                temperature=0.0,  # TODO: implement non-greedy sampling for Anthropic
-                stop=until,
-            )
-            res.append(response)
-
-            self.cache_hook.add_partial("greedy_until", request, response)
+            try:
+                inp = request[0]
+                request_args = request[1]
+                # generation_kwargs
+                until = request_args.get("until")
+                max_gen_toks = request_args.get("max_gen_toks", self.max_length)
+                temperature = request_args.get("temperature", self.temperature)
+                response = anthropic_completion(
+                    client=self.client,
+                    model=self.model,
+                    prompt=inp,
+                    max_tokens_to_sample=max_gen_toks,
+                    temperature=temperature,  # TODO: implement non-greedy sampling for Anthropic
+                    stop=until,
+                    **self.kwargs,
+                )
+                res.append(response)
+
+                self.cache_hook.add_partial("greedy_until", request, response)
+            except anthropic.APIConnectionError as e:
+                eval_logger.critical(f"Server unreachable: {e.__cause__}")
+                break
+            except anthropic.APIStatusError as e:
+                eval_logger.critical(f"API error {e.status_code}: {e.message}")
+                break

         return res
@@ -116,3 +151,9 @@ class AnthropicLM(LM):
     def _model_generate(self, context, max_length, eos_token_id):
         # Isn't used because we override greedy_until
         raise NotImplementedError()
+
+    def loglikelihood(self, requests):
+        raise NotImplementedError("No support for logits.")
+
+    def loglikelihood_rolling(self, requests):
+        raise NotImplementedError("No support for logits.")
 import torch
 import transformers
-from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
+from transformers.models.auto.modeling_auto import (
+    MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
+    MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
+)
 from peft import __version__ as PEFT_VERSION, PeftModel
 import copy
@@ -147,6 +150,18 @@ class HFLM(LM):
         if getattr(self._config, "model_type") in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
             self.AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM
+        elif (
+            not getattr(self._config, "model_type")
+            in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
+        ):
+            if not trust_remote_code:
+                eval_logger.warning(
+                    "HF model type is neither marked as CausalLM or Seq2SeqLM. \
+                    This is expected if your model requires `trust_remote_code=True` but may be an error otherwise."
+                )
+            # if model type is neither in HF transformers causal or seq2seq model registries
+            # then we default to AutoModelForCausalLM
+            self.AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM
         else:
             self.AUTO_MODEL_CLASS = transformers.AutoModelForSeq2SeqLM
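For reference, the branch above reduces to a lookup of `config.model_type` in two transformers registries; a standalone sketch of the same decision (the helper name is mine, not from the diff):

```python
import transformers
from transformers.models.auto.modeling_auto import (
    MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
    MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
)

def pick_auto_class(model_type: str):
    # Known decoder-only types -> CausalLM; known encoder-decoder types -> Seq2SeqLM;
    # anything unknown (e.g. some trust_remote_code models) falls back to CausalLM.
    if model_type in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
        return transformers.AutoModelForCausalLM
    if model_type in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES:
        return transformers.AutoModelForSeq2SeqLM
    return transformers.AutoModelForCausalLM

print(pick_auto_class("gpt2").__name__)  # AutoModelForCausalLM
print(pick_auto_class("t5").__name__)    # AutoModelForSeq2SeqLM
```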
@@ -634,8 +649,10 @@ class HFLM(LM):
                 contlen = len(cont_toks)

                 # take only logits in the continuation
                 # (discard context toks if decoder-only ; discard right-padding)
+                # also discards + checks for "virtual tokens" in the causal LM's input window
+                # from prompt/prefix tuning tokens, if applicable
                 ctx_len = (
-                    inplen
+                    inplen + (logits.shape[0] - padding_len_inp)
                     if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM
                     else None
                 )
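To see what the new `ctx_len` arithmetic appears to account for, a quick numeric sketch with made-up sizes (e.g. a PEFT prompt-tuning adapter that prepends virtual tokens):

```python
# Hypothetical numbers, for illustration only.
inplen = 7            # true (unpadded) input length of this example
padding_len_inp = 10  # longest input in the batch, i.e. the padded width
logits_len = 30       # per-example logits length: 10 padded slots + 20 virtual tokens

ctx_len = inplen + (logits_len - padding_len_inp)
print(ctx_len)  # 27: continuation logits are sliced relative to this offset,
                # so the prepended virtual-token positions are skipped
```

When no virtual tokens are present, `logits_len == padding_len_inp` and `ctx_len` reduces to `inplen`, matching the previous behavior.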
......
@@ -3,7 +3,7 @@ This list keeps track of which tasks' implementations have been ported to YAML /
 Boxes should be checked iff tasks are implemented in the refactor and tested for regression. Tasks should be struck through if checked *against original introducing paper* implementation or popularizing implementation. (WIP) Denotes that there exists a PR or person working on this task already.

-- [ ] Glue (WIP)
+- [ ] Glue (Lintang)
 - [x] SuperGlue
 - [ ] CoQA
 - [ ] DROP
@@ -20,14 +20,14 @@ Boxes should be checked iff tasks are implemented in the refactor and tested for
 - [x] QA4MRE
 - [ ] TriviaQA
 - [x] AI2 ARC
-- [ ] LogiQA (WIP)
+- [ ] LogiQA [(WIP)](https://github.com/EleutherAI/lm-evaluation-harness/pull/711)
 - [x] HellaSwag
 - [x] SWAG
 - [x] OpenBookQA
-- [ ] SQuADv2 (WIP)
+- [ ] SQuADv2
 - [x] RACE
 - [x] HeadQA
-- [ ] MathQA (WIP)
+- [x] MathQA
 - [ ] WebQs
 - [ ] WSC273
 - [x] Winogrande
@@ -37,28 +37,27 @@ Boxes should be checked iff tasks are implemented in the refactor and tested for
 - [ ] TruthfulQA (mc2)
 - [ ] TruthfulQA (gen)
 - [ ] MuTual
-- [ ] Hendrycks Math (WIP)
+- [ ] Hendrycks Math
-- [ ] Asdiv (WIP)
+- [ ] Asdiv
 - [ ] GSM8k
 - [x] Arithmetic
-- [ ] MMMLU
+- [ ] MMMLU (Hailey)
-- [ ] Translation (WMT) suite
+- [ ] Translation (WMT) suite (Hailey)
 - [x] Unscramble
 - [x] ~~Pile (perplexity)~~
-- [ ] BLiMP
+- [ ] BLiMP (Lintang)
 - [x] ToxiGen
 - [ ] StoryCloze
-- [ ] NaturalQs (WIP)
+- [ ] NaturalQs
 - [ ] CrowS-Pairs
 - [ ] XCopa
 - [ ] BIG-Bench
 - [ ] XStoryCloze
-- [ ] XWinograd
+- [x] XWinograd
 - [ ] PAWS-X
 - [ ] XNLI
 - [ ] MGSM
 - [ ] SCROLLS
-- [ ] JSON Task (reference: https://github.com/EleutherAI/lm-evaluation-harness/pull/481)
 - [ ] Babi

 # Novel Tasks
......
@@ -6,7 +6,6 @@ dataset_name: arithmetic_1dc
 output_type: loglikelihood
 validation_split: validation
 test_split: null
-template_aliases: ""
 doc_to_text: "{{context}}"
 doc_to_target: "{{completion}}"
 metric_list:
......
-group:
-  - arithmetic
+include: arithmetic_1dc.yaml
 task: arithmetic_2da
-dataset_path: EleutherAI/arithmetic
 dataset_name: arithmetic_2da
-output_type: loglikelihood
-validation_split: validation
-test_split: null
-template_aliases: ""
-doc_to_text: "{{context}}"
-doc_to_target: "{{completion}}"
-metric_list:
-  - metric: acc
-    aggregation: mean
-    higher_is_better: true

-group:
-  - arithmetic
+include: arithmetic_1dc.yaml
 task: arithmetic_2dm
-dataset_path: EleutherAI/arithmetic
 dataset_name: arithmetic_2dm
-output_type: loglikelihood
-validation_split: validation
-test_split: null
-template_aliases: ""
-doc_to_text: "{{context}}"
-doc_to_target: "{{completion}}"
-metric_list:
-  - metric: acc
-    aggregation: mean
-    higher_is_better: true

-group:
-  - arithmetic
+include: arithmetic_1dc.yaml
 task: arithmetic_2ds
-dataset_path: EleutherAI/arithmetic
 dataset_name: arithmetic_2ds
-output_type: loglikelihood
-validation_split: validation
-test_split: null
-template_aliases: ""
-doc_to_text: "{{context}}"
-doc_to_target: "{{completion}}"
-metric_list:
-  - metric: acc
-    aggregation: mean
-    higher_is_better: true

-group:
-  - arithmetic
+include: arithmetic_1dc.yaml
 task: arithmetic_3da
-dataset_path: EleutherAI/arithmetic
 dataset_name: arithmetic_3da
-output_type: loglikelihood
-validation_split: validation
-test_split: null
-template_aliases: ""
-doc_to_text: "{{context}}"
-doc_to_target: "{{completion}}"
-metric_list:
-  - metric: acc
-    aggregation: mean
-    higher_is_better: true

-group:
-  - arithmetic
+include: arithmetic_1dc.yaml
 task: arithmetic_3ds
-dataset_path: EleutherAI/arithmetic
 dataset_name: arithmetic_3ds
-output_type: loglikelihood
-validation_split: validation
-test_split: null
-template_aliases: ""
-doc_to_text: "{{context}}"
-doc_to_target: "{{completion}}"
-metric_list:
-  - metric: acc
-    aggregation: mean
-    higher_is_better: true

-group:
-  - arithmetic
+include: arithmetic_1dc.yaml
 task: arithmetic_4da
-dataset_path: EleutherAI/arithmetic
 dataset_name: arithmetic_4da
-output_type: loglikelihood
-validation_split: validation
-test_split: null
-template_aliases: ""
-doc_to_text: "{{context}}"
-doc_to_target: "{{completion}}"
-metric_list:
-  - metric: acc
-    aggregation: mean
-    higher_is_better: true

-group:
-  - arithmetic
+include: arithmetic_1dc.yaml
 task: arithmetic_4ds
-dataset_path: EleutherAI/arithmetic
 dataset_name: arithmetic_4ds
-output_type: loglikelihood
-validation_split: validation
-test_split: null
-template_aliases: ""
-doc_to_text: "{{context}}"
-doc_to_target: "{{completion}}"
-metric_list:
-  - metric: acc
-    aggregation: mean
-    higher_is_better: true

-group:
-  - arithmetic
+include: arithmetic_1dc.yaml
 task: arithmetic_5da
-dataset_path: EleutherAI/arithmetic
 dataset_name: arithmetic_5da
-output_type: loglikelihood
-validation_split: validation
-test_split: null
-template_aliases: ""
-doc_to_text: "{{context}}"
-doc_to_target: "{{completion}}"
-metric_list:
-  - metric: acc
-    aggregation: mean
-    higher_is_better: true

-group:
-  - arithmetic
+include: arithmetic_1dc.yaml
 task: arithmetic_5ds
-dataset_path: EleutherAI/arithmetic
 dataset_name: arithmetic_5ds
-output_type: loglikelihood
-validation_split: validation
-test_split: null
-template_aliases: ""
-doc_to_text: "{{context}}"
-doc_to_target: "{{completion}}"
-metric_list:
-  - metric: acc
-    aggregation: mean
-    higher_is_better: true
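A rough sketch of the kind of merge the `include` key implies: load the referenced YAML, then let the including file's own keys take precedence. The harness's actual loader may differ; the dictionaries below are inline stand-ins for the two YAML files shown above:

```python
base = {                       # stand-in for arithmetic_1dc.yaml (shared settings)
    "group": ["arithmetic"],
    "task": "arithmetic_1dc",
    "dataset_path": "EleutherAI/arithmetic",
    "dataset_name": "arithmetic_1dc",
    "output_type": "loglikelihood",
    "doc_to_text": "{{context}}",
    "doc_to_target": "{{completion}}",
}
child = {                      # stand-in for arithmetic_2da.yaml
    "include": "arithmetic_1dc.yaml",
    "task": "arithmetic_2da",
    "dataset_name": "arithmetic_2da",
}

# Child keys (task, dataset_name) override the included base config.
merged = {**base, **{k: v for k, v in child.items() if k != "include"}}
print(merged["task"], merged["dataset_name"], merged["output_type"])
# arithmetic_2da arithmetic_2da loglikelihood
```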
 group:
   - hendrycks_ethics
 task: ethics_cm
-dataset_path: hails/hendrycks_ethics
+dataset_path: EleutherAI/hendrycks_ethics
 dataset_name: commonsense
 output_type: multiple_choice
 training_split: train
......