Commit 02e841ce authored by lintangsutawika

Merge branch 'main' of https://github.com/EleutherAI/lm-evaluation-harness into t5v2-alt-plus

parents 90ad5db7 e74ec966
@@ -2,7 +2,7 @@
 exclude: ^tests/testdata/
 repos:
   - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v4.1.0
+    rev: v4.5.0
     hooks:
       - id: check-added-large-files
       - id: check-ast
@@ -29,7 +29,7 @@ repos:
       args: [--fix=lf]
   - repo: https://github.com/astral-sh/ruff-pre-commit
     # Ruff version.
-    rev: v0.1.8
+    rev: v0.2.2
     hooks:
       # Run the linter.
       - id: ruff
@@ -38,7 +38,7 @@ repos:
       # Run the formatter.
       - id: ruff-format
   - repo: https://github.com/codespell-project/codespell
-    rev: v2.1.0
+    rev: v2.2.6
    hooks:
      - id: codespell
        exclude: >
......
@@ -321,7 +321,7 @@ lm_eval \
     --log_samples
 ```
-In the stdout, you will find the link to the W&B run page as well as link to the generated report. You can find an example of this workflow in [examples/visualize-wandb.ipynb](examples/visualize-wandb.ipynb).
+In the stdout, you will find the link to the W&B run page as well as link to the generated report. You can find an example of this workflow in [examples/visualize-wandb.ipynb](examples/visualize-wandb.ipynb), and an example of how to integrate it beyond the CLI.

 ## How to Contribute or Learn More?
......
@@ -19,7 +19,7 @@ LM Evaluation Harness uses [ruff](https://github.com/astral-sh/ruff) for linting
 You can install linters and dev tools via
-```pip install lm_eval[dev]```
+```pip install lm_eval[dev]``` or ```pip install -e ".[dev]"```
 Then, run
......
@@ -2,15 +2,14 @@
 ## Usage

-Simply add a "--decontamination_ngrams_path" when running \__main\__.py. The provided directory should contain
+The provided directory should contain
 the ngram files and info.json produced in "Pile Ngram Generation" further down.

 ```bash
 python -m lm_eval \
     --model gpt2 \
     --device 0 \
-    --tasks sciq \
-    --decontamination_ngrams_path path/containing/training/set/ngrams
+    --tasks sciq
 ```

 ## Background
@@ -70,5 +69,3 @@ python -m scripts/clean_training_data/compress_and_package \
     -output path/to/final/directory \
     -procs 8
 ```
-
-Congratulations, the final directory can now be passed to lm-evaulation-harness with the "--decontamination_ngrams_path" argument.
@@ -36,8 +36,6 @@ This mode supports a number of command-line arguments, the details of which can
 - `--cache_requests` : Can be "true", "refresh", or "delete". "true" means that the cache should be used. "refresh" means that you wish to regenerate the cache, which you should run if you change your dataset configuration for a given task. "delete" will delete the cache. Cached files are stored under lm_eval/cache/.cache unless you specify a different path via the environment variable: `LM_HARNESS_CACHE_PATH`. e.g. `LM_HARNESS_CACHE_PATH=~/Documents/cache_for_lm_harness`.
-- `--decontamination_ngrams_path` : Deprecated, see (this commit)[https://github.com/EleutherAI/lm-evaluation-harness/commit/00209e10f6e27edf5d766145afaf894079b5fe10] or older for a working decontamination-checker tool.
 - `--check_integrity` : If this flag is used, the library tests for each task selected are run to confirm task integrity.
 - `--write_out` : Used for diagnostic purposes to observe the format of task documents passed to a model. If this flag is used, then prints the prompt and gold target string for the first document of each task.
@@ -50,7 +48,7 @@ This mode supports a number of command-line arguments, the details of which can
 * `--seed`: Set seed for python's random, numpy and torch. Accepts a comma-separated list of 3 values for python's random, numpy, and torch seeds, respectively, or a single integer to set the same seed for all three. The values are either an integer or 'None' to not set the seed. Default is `0,1234,1234` (for backward compatibility). E.g. `--seed 0,None,8` sets `random.seed(0)` and `torch.manual_seed(8)`. Here numpy's seed is not set since the second value is `None`. E.g, `--seed 42` sets all three seeds to 42.
-* `--wandb_args`: Tracks logging to Weights and Biases for evaluation runs and includes args passed to `wandb.init`, such as `project` and `job_type`. Full list (here.)[https://docs.wandb.ai/ref/python/init]
+* `--wandb_args`: Tracks logging to Weights and Biases for evaluation runs and includes args passed to `wandb.init`, such as `project` and `job_type`. Full list (here.)[https://docs.wandb.ai/ref/python/init]. e.g., ```--wandb_args project=test-project,name=test-run```

 ## External Library Usage
......
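The `--seed` triple documented above maps onto keyword arguments of `lm_eval.simple_evaluate` (they appear further down in this diff as `random_seed`, `numpy_random_seed`, and `torch_random_seed`). A minimal Python sketch of the programmatic counterpart of `--seed 0,None,8`, assuming the package is installed and that passing `None` skips seeding numpy just as `'None'` does on the CLI; the model and task choices are illustrative:

```python
import lm_eval

# Programmatic counterpart of `lm_eval ... --seed 0,None,8` (sketch; values are illustrative).
results = lm_eval.simple_evaluate(
    model="hf",
    model_args="pretrained=gpt2",
    tasks=["sciq"],
    random_seed=0,           # python's random
    numpy_random_seed=None,  # leave numpy unseeded, mirroring `None` in the CLI
    torch_random_seed=8,     # torch.manual_seed
)
```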
@@ -30,9 +30,10 @@ Dataset configuration options:
 Prompting / in-context formatting options:
 - **use_prompt** (`str`, *optional*) — Name of prompt in promptsource to use. if defined, will overwrite doc_to_text, doc_to_target, and doc_to_choice.
-- **doc_to_text** (`Union[Callable, str]`, *optional*) — Jinja2, f-string, or function to process a sample into the appropriate input for the model
-- **doc_to_target** (`Union[Callable, str]`, *optional*) — Jinja2, f-string, or function to process a sample into the appropriate target output for the model. For multiple choice tasks, this should return an index into
-- **doc_to_choice** (`Union[Callable, str]`, *optional*) — Jinja2, f-string, or function to process a sample into a list of possible string choices for `multiple_choice` tasks. Left undefined for `generate_until` tasks.
+- **description** (`str`, *optional*) — An optional prepended Jinja2 template or string which will be prepended to the few-shot examples passed into the model, often describing the task or providing instructions to a model, such as `"The following are questions (with answers) about {{subject}}.\n\n"`. No delimiters or spacing are inserted between the description and the first few-shot example.
+- **doc_to_text** (`Union[Callable, str]`, *optional*) — Jinja2 template, string, or function to process a sample into the appropriate input for the model
+- **doc_to_target** (`Union[Callable, str]`, *optional*) — Jinja2 template, string, or function to process a sample into the appropriate target output for the model. For multiple choice tasks, this should return an index into
+- **doc_to_choice** (`Union[Callable, str]`, *optional*) — Jinja2 template, string, or function to process a sample into a list of possible string choices for `multiple_choice` tasks. Left undefined for `generate_until` tasks.
 - **fewshot_delimiter** (`str`, *optional*, defaults to "\n\n") — String to insert between few-shot examples.
 - **target_delimiter** (`str`, *optional*, defaults to `" "`) — String to insert between input and target output for the datapoint being tested.
......
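The prompting options listed above are backed by the `TaskConfig` dataclass whose annotations are tightened later in this commit. Tasks are normally configured through YAML, but as a minimal sketch (assuming `TaskConfig` is importable from `lm_eval.api.task`, the file edited below; field values are illustrative):

```python
from lm_eval.api.task import TaskConfig

# Illustrative config exercising the documented prompting fields.
config = TaskConfig(
    task="demo_multiple_choice",
    output_type="multiple_choice",
    description="The following are questions (with answers) about {{subject}}.\n\n",
    doc_to_text="{{question}}",
    doc_to_choice=["yes", "no"],
    doc_to_target="{{answer}}",   # for multiple choice this should resolve to an index into doc_to_choice
    fewshot_delimiter="\n\n",
    target_delimiter=" ",
)
```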
@@ -67,6 +67,7 @@
    "outputs": [],
    "source": [
     "import wandb\n",
+    "\n",
     "wandb.login()"
    ]
   },
@@ -104,6 +105,43 @@
     "    --wandb_args project=lm-eval-harness-integration \\\n",
     "    --log_samples"
    ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "e974cabdbe70b667",
+   "metadata": {},
+   "source": ""
+  },
+  {
+   "cell_type": "markdown",
+   "id": "5178ca9445b844e4",
+   "metadata": {},
+   "source": "W&B can also be initialized programmatically for use outside the CLI to parse and log the results."
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "c6a421b2cf3ddac5",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import lm_eval\n",
+    "from lm_eval.logging_utils import WandbLogger\n",
+    "\n",
+    "results = lm_eval.simple_evaluate(\n",
+    "    model=\"hf\",\n",
+    "    model_args=\"pretrained=microsoft/phi-2,trust_remote_code=True\",\n",
+    "    tasks=\"hellaswag,mmlu_abstract_algebra\",\n",
+    "    log_samples=True,\n",
+    ")\n",
+    "\n",
+    "wandb_logger = WandbLogger(\n",
+    "    project=\"lm-eval-harness-integration\", job_type=\"eval\"\n",
+    ") # or empty if wandb.init(...) already called before\n",
+    "wandb_logger.post_init(results)\n",
+    "wandb_logger.log_eval_result()\n",
+    "wandb_logger.log_eval_samples(results[\"samples\"]) # if log_samples"
+   ]
+  }
  ],
  "metadata": {
......
@@ -14,7 +14,10 @@ from lm_eval import evaluator, utils
 from lm_eval.evaluator import request_caching_arg_to_dict
 from lm_eval.logging_utils import WandbLogger
 from lm_eval.tasks import TaskManager, include_path, initialize_tasks
-from lm_eval.utils import make_table
+from lm_eval.utils import make_table, simple_parse_args_string
+
+
+DEFAULT_RESULTS_FILE = "results.json"


 def _handle_non_serializable(o):
@@ -127,7 +130,6 @@ def parse_eval_args() -> argparse.Namespace:
         choices=["true", "refresh", "delete"],
         help="Speed up evaluation by caching the building of dataset requests. `None` if not caching.",
     )
-    parser.add_argument("--decontamination_ngrams_path", default=None)  # TODO: not used
     parser.add_argument(
         "--check_integrity",
         action="store_true",
@@ -201,6 +203,12 @@ def parse_eval_args() -> argparse.Namespace:
             "E.g, `--seed 42` sets all three seeds to 42."
         ),
     )
+    parser.add_argument(
+        "--trust_remote_code",
+        action="store_true",
+        help="Sets trust_remote_code to True to execute code to create HF Datasets from the Hub",
+    )
     return parser.parse_args()
@@ -210,7 +218,7 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
         args = parse_eval_args()

     if args.wandb_args:
-        wandb_logger = WandbLogger(args)
+        wandb_logger = WandbLogger(**simple_parse_args_string(args.wandb_args))

     eval_logger = utils.eval_logger
     eval_logger.setLevel(getattr(logging, f"{args.verbosity}"))
@@ -220,7 +228,9 @@
     if args.predict_only:
         args.log_samples = True
     if (args.log_samples or args.predict_only) and not args.output_path:
-        assert args.output_path, "Specify --output_path"
+        raise ValueError(
+            "Specify --output_path if providing --log_samples or --predict_only"
+        )

     initialize_tasks(args.verbosity)
     task_manager = TaskManager(args.verbosity, include_path=args.include_path)
@@ -271,12 +281,13 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
     if args.output_path:
         path = Path(args.output_path)
         # check if file or 'dir/results.json' exists
-        if path.is_file() or Path(args.output_path).joinpath("results.json").is_file():
+        if path.is_file():
+            raise FileExistsError(f"File already exists at {path}")
+        output_path_file = path.joinpath(DEFAULT_RESULTS_FILE)
+        if output_path_file.is_file():
             eval_logger.warning(
-                f"File already exists at {path}. Results will be overwritten."
+                f"File {output_path_file} already exists. Results will be overwritten."
             )
-            output_path_file = path.joinpath("results.json")
-            assert not path.is_file(), "File already exists"
         # if path json then get parent dir
         elif path.suffix in (".json", ".jsonl"):
             output_path_file = path
@@ -284,7 +295,14 @@
             path = path.parent
         else:
             path.mkdir(parents=True, exist_ok=True)
-            output_path_file = path.joinpath("results.json")
+
+    # Respect user's value passed in via CLI, otherwise default to True and add to comma-separated model args
+    if args.trust_remote_code:
+        os.environ["HF_DATASETS_TRUST_REMOTE_CODE"] = str(args.trust_remote_code)
+        args.model_args = (
+            args.model_args
+            + f",trust_remote_code={os.environ['HF_DATASETS_TRUST_REMOTE_CODE']}"
+        )

     eval_logger.info(f"Selected Tasks: {task_names}")
     eval_logger.info("Loading selected tasks...")
@@ -303,17 +321,17 @@
         device=args.device,
         use_cache=args.use_cache,
         limit=args.limit,
-        decontamination_ngrams_path=args.decontamination_ngrams_path,
         check_integrity=args.check_integrity,
         write_out=args.write_out,
         log_samples=args.log_samples,
         gen_kwargs=args.gen_kwargs,
         task_manager=task_manager,
+        verbosity=args.verbosity,
         predict_only=args.predict_only,
-        **request_caching_args,
         random_seed=args.seed[0],
         numpy_random_seed=args.seed[1],
         torch_random_seed=args.seed[2],
+        **request_caching_args,
     )

     if results is not None:
......
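The new `WandbLogger(**simple_parse_args_string(args.wandb_args))` construction above can be reproduced outside the CLI. A short sketch (assumes `wandb` is installed; constructing the logger starts or attaches to a run):

```python
from lm_eval.logging_utils import WandbLogger
from lm_eval.utils import simple_parse_args_string

# The CLI value of `--wandb_args` is parsed into a plain kwargs dict first...
wandb_kwargs = simple_parse_args_string("project=lm-eval-harness-integration,job_type=eval")
print(wandb_kwargs)  # {'project': 'lm-eval-harness-integration', 'job_type': 'eval'}

# ...and then splatted into the logger, which attaches to an already-initialized
# wandb run or passes the kwargs on to wandb.init().
wandb_logger = WandbLogger(**wandb_kwargs)
```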
 from dataclasses import dataclass, field
-from typing import Literal, Tuple
+from typing import Literal, Optional, Tuple
+
+OutputType = Literal[
+    "loglikelihood", "loglikelihood_rolling", "generate_until", "multiple_choice"
+]


 @dataclass
 class Instance:
-    request_type: Literal[
-        "loglikelihood",
-        "loglikelihood_rolling",
-        "generate_until",
-        "multiple_choice",
-    ]
+    request_type: OutputType
     doc: dict
     arguments: tuple
     idx: int
-    metadata: Tuple[str, int, int] = field(
+    metadata: Tuple[Optional[str], Optional[int], Optional[int]] = field(
         default_factory=lambda: (None, None, None)
-    )  # TODO: better typehints here
+    )
     resps: list = field(default_factory=list)
     filtered_resps: dict = field(default_factory=dict)

     # initialized after init
-    task_name: str = None
-    doc_id: str = None
-    repeats: str = None
+    task_name: Optional[str] = None
+    doc_id: Optional[int] = None
+    repeats: Optional[int] = None

     def __post_init__(self) -> None:
         # unpack metadata field
......
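A quick sketch of constructing an `Instance` under the new annotations (field values are illustrative). `request_type` must now be one of the four `OutputType` literals, and `task_name` / `doc_id` / `repeats` stay `None` until the metadata tuple is filled in by the task:

```python
from lm_eval.api.instance import Instance

inst = Instance(
    request_type="loglikelihood",                  # must be a valid OutputType literal
    doc={"question": "2 + 2 = ?", "answer": "4"},
    arguments=("2 + 2 = ?", " 4"),                 # (context, continuation)
    idx=0,
)
print(inst.task_name, inst.doc_id, inst.repeats)   # None None None (unpacked from the default metadata)
```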
@@ -54,7 +54,7 @@ class LM(abc.ABC):
         pass

     @abc.abstractmethod
-    def loglikelihood_rolling(self, requests) -> List[Tuple[float, bool]]:
+    def loglikelihood_rolling(self, requests) -> List[Tuple[float]]:
         """Compute full log-likelihood of a string, with no truncation, for perplexity computation
         - We will use the full max context length of the model.
         - For inputs that exceed the max context length, we divide the tokenized string into chunks of up to
@@ -83,15 +83,13 @@ class LM(abc.ABC):
         2. For the last pair, we provide the full context, but only score the last two tokens

         :param requests: list[Instance]
-            A list of Instance objects with property `args` which returns a tuple (context, continuation).
+            A list of Instance objects with property `args` which returns a tuple (context,).
             string: str
-                String for which we are computing per-token loglikelihood
-        :return: list[tuple[float, bool]]
-            A list of pairs (logprob, isgreedy)
+                String for which we are computing overall loglikelihood
+        :return: list[tuple[float]]
+            A list of tuples (logprob,)
             logprob: float
-                The log probability of `continuation`
-            isgreedy:
-                Whether `continuation` would be generated by greedy sampling from `context`
+                The log probability of `context` conditioned on the EOT token.
         """
         pass
@@ -306,7 +304,9 @@ class TemplateLM(LM):
         return context_enc, continuation_enc

-    def loglikelihood(self, requests) -> List[Tuple[float, bool]]:
+    def loglikelihood(
+        self, requests, disable_tqdm: bool = False
+    ) -> List[Tuple[float, bool]]:
         new_reqs = []
         for context, continuation in [req.args for req in requests]:
             if context == "":
@@ -320,12 +320,14 @@ class TemplateLM(LM):
             new_reqs.append(((context, continuation), context_enc, continuation_enc))

-        return self._loglikelihood_tokens(new_reqs)
+        return self._loglikelihood_tokens(new_reqs, disable_tqdm=disable_tqdm)

     @abc.abstractmethod
-    def loglikelihood_rolling(self, requests) -> List[Tuple[float, bool]]:
+    def loglikelihood_rolling(
+        self, requests, disable_tqdm: bool = False
+    ) -> List[Tuple[float, bool]]:
         pass

     @abc.abstractmethod
-    def generate_until(self, requests) -> List[str]:
+    def generate_until(self, requests, disable_tqdm: bool = False) -> List[str]:
         pass
@@ -7,7 +7,18 @@ from collections.abc import Callable
 from copy import deepcopy
 from dataclasses import asdict, dataclass
 from inspect import getsource
-from typing import Any, Iterator, List, Literal, Tuple, Union
+from typing import (
+    Any,
+    Dict,
+    Iterable,
+    Iterator,
+    List,
+    Literal,
+    Mapping,
+    Optional,
+    Tuple,
+    Union,
+)

 import datasets
 import numpy as np
@@ -15,12 +26,8 @@ from tqdm import tqdm
 from lm_eval import utils
 from lm_eval.api import samplers
-from lm_eval.api.instance import Instance
-from lm_eval.api.metrics import (
-    bits_per_byte,
-    mean,
-    weighted_perplexity,
-)
+from lm_eval.api.instance import Instance, OutputType
+from lm_eval.api.metrics import bits_per_byte, mean, weighted_perplexity
 from lm_eval.api.registry import (
     AGGREGATION_REGISTRY,
     DEFAULT_METRIC_REGISTRY,
@@ -63,56 +70,54 @@ class GroupConfig(dict):
 @dataclass
 class TaskConfig(dict):
     # task naming/registry
-    task: str = None
-    task_alias: str = None
-    group: Union[str, list] = None
-    group_alias: Union[str, list] = None
+    task: Optional[str] = None
+    task_alias: Optional[str] = None
+    group: Optional[Union[str, list]] = None
+    group_alias: Optional[Union[str, list]] = None
     # HF dataset options.
     # which dataset to use,
     # and what splits for what purpose
-    dataset_path: str = None
-    dataset_name: str = None
-    dataset_kwargs: dict = None
-    training_split: str = None
-    validation_split: str = None
-    test_split: str = None
-    fewshot_split: str = None  # TODO: assert that this not None if num_fewshot > 0. (?) assert if this is same split as one evaling (?)
+    dataset_path: Optional[str] = None
+    dataset_name: Optional[str] = None
+    dataset_kwargs: Optional[dict] = None
+    training_split: Optional[str] = None
+    validation_split: Optional[str] = None
+    test_split: Optional[str] = None
+    fewshot_split: Optional[
+        str
+    ] = None  # TODO: assert that this not None if num_fewshot > 0. (?) assert if this is same split as one evaling (?)
     # formatting / prompting options.
     # see docs/advanced_task_guide.md for more info
-    process_docs: Callable = None
-    doc_to_text: Union[Callable, str] = None
-    doc_to_target: Union[Callable, str] = None
-    doc_to_choice: Union[Callable, str, dict, list] = None
-    process_results: Union[Callable, str] = None
-    use_prompt: str = None
+    process_docs: Optional[Callable] = None
+    doc_to_text: Optional[Union[Callable, str]] = None
+    doc_to_target: Optional[Union[Callable, str]] = None
+    doc_to_choice: Optional[Union[Callable, str, dict, list]] = None
+    process_results: Optional[Union[Callable, str]] = None
+    use_prompt: Optional[str] = None
     description: str = ""
     target_delimiter: str = " "
     fewshot_delimiter: str = "\n\n"
-    fewshot_config: dict = None
+    fewshot_config: Optional[dict] = None
     # runtime configuration options
-    num_fewshot: int = None
+    num_fewshot: Optional[int] = None
     # scoring options
-    metric_list: list = None
-    output_type: Literal[
-        "loglikelihood",
-        "loglikelihood_rolling",
-        "generate_until",
-        "multiple_choice",
-    ] = "generate_until"
-    generation_kwargs: dict = None
+    metric_list: Optional[list] = None
+    output_type: OutputType = "generate_until"
+    generation_kwargs: Optional[dict] = None
     repeats: int = 1
-    filter_list: Union[str, list] = None
+    filter_list: Optional[Union[str, list]] = None
     should_decontaminate: bool = False
-    doc_to_decontamination_query: str = None
-    metadata: dict = None  # by default, not used in the code. allows for users to pass arbitrary info to tasks
+    doc_to_decontamination_query: Optional[str] = None
+    metadata: Optional[
+        dict
+    ] = None  # by default, not used in the code. allows for users to pass arbitrary info to tasks

     def __post_init__(self) -> None:
         if self.generation_kwargs is not None:
             if self.output_type != "generate_until":
-                eval_logger.warning(
+                raise ValueError(
                     f"[{self.task}] passed `generation_kwargs`, but not using `output_type: generate_until`!"
                 )
-                assert self.output_type != "generate_until"

             if "temperature" in self.generation_kwargs:
                 self.generation_kwargs["temperature"] = float(
@@ -193,23 +198,23 @@ class Task(abc.ABC):
         {"question": ..., question, answer)
     """

-    VERSION = None
+    VERSION: Optional[Union[int, str]] = None
     # The name of the `Task` benchmark as denoted in the HuggingFace datasets Hub
     # or a path to a custom `datasets` loading script.
-    DATASET_PATH: str = None
+    DATASET_PATH: Optional[str] = None
     # The name of a subset within `DATASET_PATH`.
-    DATASET_NAME: str = None
+    DATASET_NAME: Optional[str] = None

-    OUTPUT_TYPE: str = None
+    OUTPUT_TYPE: Optional[OutputType] = None

     def __init__(
         self,
-        data_dir=None,
-        cache_dir=None,
-        download_mode=None,
-        config=None,
+        data_dir: Optional[str] = None,
+        cache_dir: Optional[str] = None,
+        download_mode: Optional[datasets.DownloadMode] = None,
+        config: Optional[Mapping] = None,  # Union[dict, TaskConfig]
     ) -> None:
         """
         :param data_dir: str
@@ -233,15 +238,20 @@
                 Fresh download and fresh dataset.
         """
         self.download(data_dir, cache_dir, download_mode)
-        self._training_docs = None
-        self._fewshot_docs = None
-        self._instances = None
+        self._training_docs: Optional[list] = None
+        self._fewshot_docs: Optional[list] = None
+        self._instances: Optional[List[Instance]] = None

-        self._config = TaskConfig({**config}) if config else TaskConfig()
+        self._config: TaskConfig = TaskConfig({**config}) if config else TaskConfig()

         self._filters = [build_filter_ensemble("none", [["take_first", None]])]

-    def download(self, data_dir=None, cache_dir=None, download_mode=None) -> None:
+    def download(
+        self,
+        data_dir: Optional[str] = None,
+        cache_dir: Optional[str] = None,
+        download_mode=None,
+    ) -> None:
         """Downloads and returns the task dataset.
         Override this method to download the dataset from a custom API.
@@ -275,7 +285,7 @@
             )

     @property
-    def config(self):
+    def config(self) -> TaskConfig:
         """Returns the TaskConfig associated with this class."""
         return self._config

@@ -294,28 +304,28 @@
         """Whether the task has a test set"""
         pass

-    def training_docs(self):
+    def training_docs(self) -> Iterable:
         """
         :return: Iterable[obj]
             A iterable of any object, that doc_to_text can handle
         """
         return []

-    def validation_docs(self):
+    def validation_docs(self) -> Iterable:
         """
         :return: Iterable[obj]
             A iterable of any object, that doc_to_text can handle
         """
         return []

-    def test_docs(self):
+    def test_docs(self) -> Iterable:
         """
         :return: Iterable[obj]
             A iterable of any object, that doc_to_text can handle
         """
         return []

-    def fewshot_docs(self):
+    def fewshot_docs(self) -> Iterable:
         """
         :return: Iterable[obj]
             A iterable of any object, that doc_to_text can handle
@@ -331,7 +341,7 @@
             )
             return self.test_docs()

-    def _process_doc(self, doc):
+    def _process_doc(self, doc: dict) -> dict:
         """
         Override this to process (detokenize, strip, replace, etc.) individual
         documents. This can be used in a map over documents of a data split.
@@ -355,11 +365,10 @@
         return rnd.sample(self._training_docs, k)

-    def doc_to_decontamination_query(self, doc) -> None:
-        print(
+    def doc_to_decontamination_query(self, doc):
+        raise NotImplementedError(
             "Override doc_to_decontamination_query with document specific decontamination query."
         )
-        assert False

     @abc.abstractmethod
     def doc_to_text(self, doc):
@@ -451,7 +460,8 @@
         self._instances = flattened_instances
-        assert len(self._instances) != 0, "task.build_requests() did not find any docs!"
+        if len(self._instances) == 0:
+            raise ValueError("task.build_requests() did not find any docs!")

         if cache_requests and (not cached_instances or rewrite_requests_cache):
             save_to_cache(file_name=cache_key, obj=instances)
@@ -544,9 +554,10 @@
         :returns: str
             The fewshot context.
         """
-        assert (
-            rnd is not None
-        ), "A `random.Random` generator argument must be provided to `rnd`"
+        if rnd is None:
+            raise ValueError(
+                "A `random.Random` generator argument must be provided to `rnd`"
+            )

         description = description if description else ""
@@ -582,7 +593,7 @@
         example = self.doc_to_text(doc)
         return description + labeled_examples + example

-    def apply_filters(self):
+    def apply_filters(self) -> Optional[List[Instance]]:
         """Iterates over FilterEnsembles and applies them to instances"""
         if hasattr(self, "_filters"):
             for f in self._filters:
@@ -644,7 +655,9 @@
         elif self.has_validation_docs():
             return self.validation_docs()
         else:
-            assert False, f"Task dataset (path={self.DATASET_PATH}, name={self.DATASET_NAME}) must have valid or test docs!"
+            raise ValueError(
+                f"Task dataset (path={self.DATASET_PATH}, name={self.DATASET_NAME}) must have valid or test docs!"
+            )

     def doc_iterator(
         self, *, rank: int = 0, limit: Union[int, None] = None, world_size: int = 1
@@ -665,7 +678,11 @@ class ConfigurableTask(Task):
     CONFIG = None

     def __init__(
-        self, data_dir=None, cache_dir=None, download_mode=None, config: dict = None
+        self,
+        data_dir=None,
+        cache_dir=None,
+        download_mode=None,
+        config: Optional[dict] = None,
     ) -> None:  # TODO no super() call here
         # Get pre-configured attributes
         self._config = self.CONFIG
@@ -688,7 +705,10 @@ class ConfigurableTask(Task):
             self.VERSION = self.config.metadata["version"]

         if self.config.output_type is not None:
-            assert self.config.output_type in ALL_OUTPUT_TYPES
+            if self.config.output_type not in ALL_OUTPUT_TYPES:
+                raise ValueError(
+                    f"Got invalid output_type '{self.config.output_type}', must be in '{','.join(ALL_OUTPUT_TYPES)}'"
+                )
             self.OUTPUT_TYPE = self.config.output_type

         if self.config.dataset_path is not None:
@@ -715,7 +735,10 @@ class ConfigurableTask(Task):
                 self._higher_is_better[metric_name] = is_higher_better(metric_name)
         else:
             for metric_config in self.config.metric_list:
-                assert "metric" in metric_config
+                if "metric" not in metric_config:
+                    raise ValueError(
+                        "'metric' key not provided for an entry in 'metric_list', must be specified!"
+                    )
                 metric_name = metric_config["metric"]
                 kwargs = {
                     key: metric_config[key]
@@ -860,7 +883,7 @@ class ConfigurableTask(Task):
                         f'Both target_delimiter "{self.config.target_delimiter}" and target choice: "{choice}" do not have whitespace, ignore if the language you are evaluating on does not require/use whitespace'
                     )

-    def download(self, dataset_kwargs=None) -> None:
+    def download(self, dataset_kwargs: Optional[Dict[str, Any]] = None) -> None:
         self.dataset = datasets.load_dataset(
             path=self.DATASET_PATH,
             name=self.DATASET_NAME,
@@ -922,7 +945,7 @@
             return super().fewshot_docs()

     @utils.positional_deprecated
-    def fewshot_context(self, doc, num_fewshot):
+    def fewshot_context(self, doc: str, num_fewshot: int) -> str:
         """Returns a fewshot context string that is made up of a prepended description
         (if provided), the `num_fewshot` number of examples, and an appended prompt example.
@@ -933,14 +956,14 @@
         :returns: str
             The fewshot context.
         """
+        if description := self.config.description:
+            description = utils.apply_template(self.config.description, doc)

         if num_fewshot == 0:
             # always prepend the (possibly empty) task description
-            labeled_examples = self.config.description
+            labeled_examples = description
         else:
-            labeled_examples = self.config.description + self.sampler.get_context(
-                doc, num_fewshot
-            )
+            labeled_examples = description + self.sampler.get_context(doc, num_fewshot)

         example = self.doc_to_text(doc)
         if self.multiple_input:
@@ -986,7 +1009,7 @@ class ConfigurableTask(Task):
                 )
             )

-    def _process_doc(self, doc):
+    def _process_doc(self, doc: dict) -> dict:
         """
         Override this to process (detokenize, strip, replace, etc.) individual
         documents. This can be used in a map over documents of a data split.
@@ -1031,7 +1054,7 @@ class ConfigurableTask(Task):
             print(type(doc_to_text))
             raise TypeError

-    def doc_to_target(self, doc: dict) -> Union[int, str, list]:
+    def doc_to_target(self, doc: Mapping) -> Union[int, str, list]:
         if self.prompt is not None:
             doc_to_target = self.prompt
         else:
@@ -1213,7 +1236,8 @@ class ConfigurableTask(Task):
             # then we are doing mutual info.
             # this stores the "dryrun" / unconditional answer loglikelihoods
             lls_unconditional = lls[1::2]
-            assert len(lls_unconditional) == len(choices)
+            if len(lls_unconditional) != len(choices):
+                raise ValueError
             # and this stores our "regular" conditional loglikelihoods
             lls = lls[::2]
@@ -1376,7 +1400,7 @@
 class MultipleChoiceTask(Task):
-    OUTPUT_TYPE: str = "loglikelihood"
+    OUTPUT_TYPE = "loglikelihood"

     def doc_to_target(self, doc: dict) -> str:
         return " " + doc["choices"][doc["gold"]]
@@ -1394,7 +1418,7 @@ class MultipleChoiceTask(Task):
             for i, choice in enumerate(doc["choices"])
         ]

-    def process_results(self, doc: dict, results: List[Tuple[float, bool]]) -> dict:
+    def process_results(self, doc: dict, results: Iterable[Tuple[float, bool]]) -> dict:
         results = [
             res[0] for res in results
         ]  # only retain loglikelihoods, discard is_greedy TODO: do we need is_greedy anywhere?
@@ -1429,13 +1453,17 @@ class PerplexityTask(Task):
         return False

     def fewshot_examples(self, k: int, rnd) -> List:
-        assert k == 0
+        if k != 0:
+            raise ValueError(
+                "The number of fewshot examples must be 0 for perplexity tasks."
+            )
         return []

     def fewshot_context(self, doc: dict, num_fewshot: int) -> Literal[""]:
-        assert (
-            num_fewshot == 0
-        ), "The number of fewshot examples must be 0 for perplexity tasks."
+        if num_fewshot != 0:
+            raise ValueError(
+                "The number of fewshot examples must be 0 for perplexity tasks."
+            )

         return ""
@@ -1455,8 +1483,9 @@
     def doc_to_target(self, doc):
         return doc

-    def construct_requests(self, doc: dict, ctx: Union[str, None], **kwargs):
-        assert not ctx
+    def construct_requests(self, doc: dict, ctx: Optional[str], **kwargs):
+        if bool(ctx):
+            raise ValueError

         return Instance(
             request_type=self.OUTPUT_TYPE,
@@ -1466,7 +1495,7 @@
             **kwargs,
         )

-    def process_results(self, doc: dict, results: float) -> dict:
+    def process_results(self, doc: dict, results: Tuple[float]) -> dict:
         (loglikelihood,) = results
         words = self.count_words(self.doc_to_target(doc))
         bytes_ = self.count_bytes(self.doc_to_target(doc))
......
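The reworked `fewshot_context` above now runs the task description through `utils.apply_template`, so a Jinja2 description can reference fields of the document being scored. A small sketch (document fields are illustrative):

```python
from lm_eval import utils

doc = {"subject": "abstract_algebra", "question": "...", "answer": "..."}
description = "The following are questions (with answers) about {{subject}}.\n\n"

# Mirrors `utils.apply_template(self.config.description, doc)` in ConfigurableTask.fewshot_context.
print(utils.apply_template(description, doc))
# The following are questions (with answers) about abstract_algebra.
```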
-import collections
 import itertools
 import logging
 import random
-from typing import TYPE_CHECKING, Optional, Union
+from collections import defaultdict
+from typing import TYPE_CHECKING, List, Optional, Union

 import numpy as np
 import torch
@@ -10,6 +10,7 @@ import torch
 import lm_eval.api.metrics
 import lm_eval.api.registry
 import lm_eval.models
+from lm_eval.caching.cache import delete_cache
 from lm_eval.evaluator_utils import (
     consolidate_results,
     get_sample_size,
@@ -20,25 +21,19 @@
 )
 from lm_eval.logging_utils import add_env_info, get_git_commit_hash
 from lm_eval.tasks import TaskManager, get_task_dict
-from lm_eval.utils import (
-    eval_logger,
-    positional_deprecated,
-    simple_parse_args_string,
-)
+from lm_eval.utils import eval_logger, positional_deprecated, simple_parse_args_string

 if TYPE_CHECKING:
     from lm_eval.api.model import LM
     from lm_eval.tasks import Task

-from lm_eval.caching.cache import delete_cache
-

 @positional_deprecated
 def simple_evaluate(
     model,
-    model_args: Optional[Union[str, dict, None]] = None,
-    tasks=None,
+    model_args: Optional[Union[str, dict]] = None,
+    tasks: Optional[List[Union[str, dict, object]]] = None,
     num_fewshot: Optional[int] = None,
     batch_size: Optional[int] = None,
     max_batch_size: Optional[int] = None,
@@ -50,11 +45,10 @@ def simple_evaluate(
     limit: Optional[Union[int, float]] = None,
     bootstrap_iters: int = 100000,
     check_integrity: bool = False,
-    decontamination_ngrams_path=None,
     write_out: bool = False,
     log_samples: bool = True,
-    gen_kwargs: str = None,
-    task_manager: TaskManager = None,
+    gen_kwargs: Optional[str] = None,
+    task_manager: Optional[TaskManager] = None,
     verbosity: str = "INFO",
     predict_only: bool = False,
     random_seed: int = 0,
@@ -136,14 +130,16 @@ def simple_evaluate(
     if tasks is None:
         tasks = []
-    assert (
-        tasks != []
-    ), "No tasks specified, or no tasks found. Please verify the task names."
+    if len(tasks) == 0:
+        raise ValueError(
+            "No tasks specified, or no tasks found. Please verify the task names."
+        )

     if gen_kwargs is not None:
         gen_kwargs = simple_parse_args_string(gen_kwargs)
         eval_logger.warning(
-            "generation_kwargs specified through cli, these settings will update set parameters in yaml tasks. Ensure 'do_sample=True' for non-greedy decoding!"
+            "generation_kwargs specified through cli, these settings will update set parameters in yaml tasks. "
+            "Ensure 'do_sample=True' for non-greedy decoding!"
         )
         if gen_kwargs == "":
             gen_kwargs = None
@@ -152,7 +148,7 @@ def simple_evaluate(
         if model_args is None:
             model_args = ""

-        elif isinstance(model_args, dict):
+        if isinstance(model_args, dict):
             lm = lm_eval.api.registry.get_model(model).create_from_arg_obj(
                 model_args,
                 {
@@ -172,7 +168,8 @@
                 },
             )
     else:
-        assert isinstance(model, lm_eval.api.model.LM)
+        if not isinstance(model, lm_eval.api.model.LM):
+            raise TypeError
         lm = model

     if use_cache is not None:
@@ -237,7 +234,6 @@
         cache_requests=cache_requests,
         rewrite_requests_cache=rewrite_requests_cache,
         bootstrap_iters=bootstrap_iters,
-        decontamination_ngrams_path=decontamination_ngrams_path,
         write_out=write_out,
         log_samples=log_samples,
         verbosity=verbosity,
@@ -272,18 +268,14 @@
         return None

-decontaminate_suffix = "_decontaminate"
-

 @positional_deprecated
 def evaluate(
     lm: "LM",
     task_dict,
     limit: Optional[int] = None,
-    cache_requests=False,
-    rewrite_requests_cache=False,
+    cache_requests: bool = False,
+    rewrite_requests_cache: bool = False,
     bootstrap_iters: Optional[int] = 100000,
-    decontamination_ngrams_path=None,
     write_out: bool = False,
     log_samples: bool = True,
     verbosity: str = "INFO",
@@ -307,21 +299,21 @@ def evaluate(
     """
     eval_logger.setLevel(getattr(logging, f"{verbosity}"))
-    # decontaminate = decontamination_ngrams_path is not None

     # tracks all Instances/requests a model must generate output on.
-    requests = collections.defaultdict(list)
+    requests = defaultdict(list)
     # stores the amount to pad out reqs per req. type so that
     # number of fwd passes per distributed rank is equal
-    padding_requests = collections.defaultdict(int)
+    padding_requests = defaultdict(int)

     # get lists of group hierarchy and each type of request
     task_hierarchy, eval_tasks = get_task_list(task_dict)
     if not log_samples:
-        assert all(
+        if not all(
             "bypass" not in getattr(task_output.task, "_metric_fn_list", {}).keys()
             for task_output in eval_tasks
-        ), "log_samples must be True for 'bypass' only tasks"
+        ):
+            raise ValueError("log_samples must be True for 'bypass' metric-only tasks")

     for task_output in eval_tasks:
         task: Task = task_output.task
         limit = get_sample_size(task, limit)
@@ -348,10 +340,16 @@ def evaluate(
             gathered_item = (
                 lm.accelerator.gather(instances_rnk).cpu().detach().numpy().tolist()
             )
+            # "multiple_choice" task types dispatch (several) "loglikelihood" request types
+            reqtype = (
+                "loglikelihood"
+                if task.OUTPUT_TYPE == "multiple_choice"
+                else task.OUTPUT_TYPE
+            )
             # compute number of pseudo-batches to pad with (FSDP/DDP require even batches among ranks)
             numpad = max(gathered_item) - gathered_item[lm.rank]
-            padding_requests[task.OUTPUT_TYPE] += numpad
+            # todo: may not account for padding in cases like SquadV2 which has multiple req types
+            padding_requests[reqtype] += numpad

     ### Run LM on inputs, get all outputs ###
     # execute each type of request
@@ -388,7 +386,7 @@ def evaluate(
         # # unpack results and sort back in order and return control to Task
         # TODO: make it possible to use a different metric per filter
         # Pre-process task.instances to group by doc_id
-        instances_by_doc_id = collections.defaultdict(list)
+        instances_by_doc_id = defaultdict(list)
         for instance in task.instances:
             instances_by_doc_id[instance.doc_id].append(instance)
         # Sort instances within each group
@@ -515,10 +513,9 @@
                 results[group]["samples"] = sum(sizes)

-        results_agg = collections.defaultdict(dict)
-        groups_agg = collections.defaultdict(dict)
+        results_agg = defaultdict(dict)
+        groups_agg = defaultdict(dict)
         all_tasks_list = list(task_hierarchy.keys())
-        left_tasks_list = []
         while True:
             add_tasks_list = list(k for k in results_agg.keys())
             left_tasks_list = sorted(list(set(all_tasks_list) - set(add_tasks_list)))
@@ -542,7 +539,7 @@
         results_dict = {
             "results": dict(results_agg.items()),
             **({"groups": dict(groups_agg.items())} if bool(groups_agg) else {}),
-            "group_subtasks": {k: v for k, v in reversed(task_hierarchy.items())},
+            "group_subtasks": dict(reversed(task_hierarchy.items())),
             "configs": dict(sorted(configs.items())),
             "versions": dict(sorted(versions.items())),
             "n-shot": dict(sorted(num_fewshot.items())),
@@ -558,11 +555,9 @@
 def request_caching_arg_to_dict(cache_requests: str) -> dict:
     request_caching_args = {
-        "cache_requests": (
-            True if cache_requests == "true" or cache_requests == "refresh" else False
-        ),
-        "rewrite_requests_cache": True if cache_requests == "refresh" else False,
-        "delete_requests_cache": True if cache_requests == "delete" else False,
+        "cache_requests": cache_requests in {"true", "refresh"},
+        "rewrite_requests_cache": cache_requests == "refresh",
+        "delete_requests_cache": cache_requests == "delete",
     }

     return request_caching_args
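The rewritten `request_caching_arg_to_dict` above is behavior-preserving; a quick sketch of the mapping it produces for each accepted `--cache_requests` value:

```python
from lm_eval.evaluator import request_caching_arg_to_dict

for value in ("true", "refresh", "delete"):
    print(value, request_caching_arg_to_dict(value))
# true {'cache_requests': True, 'rewrite_requests_cache': False, 'delete_requests_cache': False}
# refresh {'cache_requests': True, 'rewrite_requests_cache': True, 'delete_requests_cache': False}
# delete {'cache_requests': False, 'rewrite_requests_cache': False, 'delete_requests_cache': True}
```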
@@ -15,6 +15,7 @@ FILTER_REGISTRY = {
     "lowercase": transformation.LowercaseFilter,
     "uppercase": transformation.UppercaseFilter,
     "map": transformation.MapFilter,
+    "multi_choice_regex": extraction.MultiChoiceRegexFilter,
     # TODO: implement this filter. either it should take in an arbitrary "scoring"/reward function
     # that takes an input and returns a scalar and then should select the max reward,
     # or should implement different filters for different ways of handling a reward model's inference.
......
 import re
+import sys
+import unicodedata

 from lm_eval.api.filter import Filter

@@ -67,3 +69,115 @@ class WhitespaceFilter(Filter):
         filtered_resps = [filter_set(resp) for resp in resps]

         return filtered_resps
class MultiChoiceRegexFilter(RegexFilter):
"""
A filter used to extract a model's answer on multiple choice questions with
letter answers. assumes each document has a "choices" field
containing the list of answer choices and that the answer label symbols
are of the form (A), (B), (C), ... or A, B, C.
"""
def __init__(
self,
regex_pattern: str = r"#### (\-?[0-9\.\,]+)",
group_select=0,
fallback: str = "[invalid]",
ignore_case=False,
ignore_punctuation=False,
regexes_to_ignore=None,
) -> None:
"""
regex_pattern: The basic regex pattern to use. If fails to match, we will use the customized match procedure
- step 1 : We parse the choices between ([A-Z])s then try to find these choices in the response.
- step 2 : We parse the choice with regex :[\s]*([A-?]), where ? varies by number of choices.
group_select: Selects the (group_select)th match from the findall result.
ignore_case: Ignores the case during step 1 matching
ignore_punctuation: Remove the punctuation during step 1 matching
regexes_to_ignore: Remove these regexes during step 1 matching
"""
super().__init__(regex_pattern, group_select, fallback)
self.ignore_case = ignore_case
self.ignore_punctuation = ignore_punctuation
self.regexes_to_ignore = regexes_to_ignore
def apply(self, resps, docs):
# here, we assume we have a list, in which each element is
# a list of model responses for some particular input/target pair.
# so we process each of these (same input/target response sets)
# independently (and keep them a list.)
def find_match(regex, resp, convert_dict={}):
match = regex.findall(resp)
if match:
match = match[self.group_select]
if isinstance(match, tuple):
match = [m for m in match if m][0]
match = match.strip()
if match and match in convert_dict:
match = convert_dict[match]
return match
punct_tbl = dict.fromkeys(
i
for i in range(sys.maxunicode)
if unicodedata.category(chr(i)).startswith("P")
)
def filter_ignores(st):
if self.regexes_to_ignore is not None:
for s in self.regexes_to_ignore:
st = re.sub(s, "", st)
if self.ignore_case:
st = st.lower()
if self.ignore_punctuation:
# https://stackoverflow.com/a/266162
st = st.translate(punct_tbl)
return st
filtered_resps = []
for r, doc in zip(resps, docs):
fallback_regexes = []
choice_to_alpha = {}
next_alpha = "A"
without_paren_fallback_regexes = []
without_paren_to_target = {}
choices = doc["choices"]
for c in choices:
m = filter_ignores(c.strip())
fallback_regexes.append(f"{re.escape(m)}")
choice_to_alpha[m] = f"({next_alpha})"
without_paren_fallback_regexes.append(next_alpha)
without_paren_to_target[next_alpha] = f"({next_alpha})"
next_alpha = chr(ord(next_alpha) + 1)
fallback_regex = re.compile("|".join(fallback_regexes))
without_paren_fallback_regex = "|".join(without_paren_fallback_regexes)
without_paren_fallback_regex = re.compile(
f":[\s]*({without_paren_fallback_regex})"
)
filtered = []
for resp in r:
match = find_match(self.regex, resp)
if not match:
match = find_match(
fallback_regex, filter_ignores(resp), choice_to_alpha
)
if not match:
match = find_match(
without_paren_fallback_regex, resp, without_paren_to_target
)
if not match:
match = self.fallback
filtered.append(match)
filtered_resps.append(filtered)
return filtered_resps
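The new filter can also be exercised directly, outside a task config. Below is a minimal sketch, assuming the class is importable from `lm_eval.filters.extraction` (as the registry entry above suggests); the document, choices, and responses are made-up illustrations, not harness data.

```python
# Minimal sketch of the new MultiChoiceRegexFilter; all values below are illustrative.
from lm_eval.filters.extraction import MultiChoiceRegexFilter  # assumed import path

# Primary pattern matches letter answers written as "(A)", "(B)", ...;
# the choice-text fallback kicks in when no letter is found.
mc_filter = MultiChoiceRegexFilter(
    regex_pattern=r"(\([A-Z]\))",
    ignore_case=True,
    ignore_punctuation=True,
)

docs = [{"choices": ["red", "green", "blue"]}]
# One input/target pair with two sampled responses.
resps = [["I think the answer is (B).", "The answer is: green"]]

print(mc_filter.apply(resps, docs))  # -> [['(B)', '(B)']]
```

The second response never names a letter, so the fallback maps the choice text "green" back to its label via `choice_to_alpha`.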
...@@ -13,24 +13,9 @@ from packaging.version import Version ...@@ -13,24 +13,9 @@ from packaging.version import Version
from torch.utils.collect_env import get_pretty_env_info from torch.utils.collect_env import get_pretty_env_info
from transformers import __version__ as trans_version from transformers import __version__ as trans_version
from lm_eval.utils import simple_parse_args_string
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
try:
import wandb
assert Version(wandb.__version__) >= Version("0.13.6")
if Version(wandb.__version__) < Version("0.13.6"):
wandb.require("report-editing:v0")
except Exception as e:
logger.warning(
"To use the wandb reporting functionality please install wandb>=0.13.6.\n"
"To install the latest version of wandb run `pip install wandb --upgrade`\n"
f"{e}"
)
def remove_none_pattern(input_string: str) -> Tuple[str, bool]: def remove_none_pattern(input_string: str) -> Tuple[str, bool]:
"""Remove the ',none' substring from the input_string if it exists at the end. """Remove the ',none' substring from the input_string if it exists at the end.
...@@ -83,14 +68,31 @@ def get_wandb_printer() -> Literal["Printer"]: ...@@ -83,14 +68,31 @@ def get_wandb_printer() -> Literal["Printer"]:
class WandbLogger: class WandbLogger:
def __init__(self, args: Any) -> None: def __init__(self, **kwargs) -> None:
"""Initialize the WandbLogger. """Attaches to wandb logger if already initialized. Otherwise, passes kwargs to wandb.init()
Args: Args:
results (Dict[str, Any]): The results dictionary. kwargs Optional[Any]: Arguments for configuration.
args (Any): Arguments for configuration.
Parse and log the results returned from evaluator.simple_evaluate() with:
wandb_logger.post_init(results)
wandb_logger.log_eval_result()
wandb_logger.log_eval_samples(results["samples"])
""" """
self.wandb_args: Dict[str, Any] = simple_parse_args_string(args.wandb_args) try:
import wandb
assert Version(wandb.__version__) >= Version("0.13.6")
if Version(wandb.__version__) < Version("0.13.6"):
wandb.require("report-editing:v0")
except Exception as e:
logger.warning(
"To use the wandb reporting functionality please install wandb>=0.13.6.\n"
"To install the latest version of wandb run `pip install wandb --upgrade`\n"
f"{e}"
)
self.wandb_args: Dict[str, Any] = kwargs
# initialize a W&B run # initialize a W&B run
if wandb.run is None: if wandb.run is None:
...@@ -164,6 +166,8 @@ class WandbLogger: ...@@ -164,6 +166,8 @@ class WandbLogger:
] ]
def make_table(columns: List[str], key: str = "results"): def make_table(columns: List[str], key: str = "results"):
import wandb
table = wandb.Table(columns=columns) table = wandb.Table(columns=columns)
results = copy.deepcopy(self.results) results = copy.deepcopy(self.results)
...@@ -202,6 +206,8 @@ class WandbLogger: ...@@ -202,6 +206,8 @@ class WandbLogger:
def _log_results_as_artifact(self) -> None: def _log_results_as_artifact(self) -> None:
"""Log results as JSON artifact to W&B.""" """Log results as JSON artifact to W&B."""
import wandb
dumped = json.dumps( dumped = json.dumps(
self.results, indent=2, default=_handle_non_serializable, ensure_ascii=False self.results, indent=2, default=_handle_non_serializable, ensure_ascii=False
) )
...@@ -320,6 +326,8 @@ class WandbLogger: ...@@ -320,6 +326,8 @@ class WandbLogger:
def _log_samples_as_artifact( def _log_samples_as_artifact(
self, data: List[Dict[str, Any]], task_name: str self, data: List[Dict[str, Any]], task_name: str
) -> None: ) -> None:
import wandb
# log the samples as an artifact # log the samples as an artifact
dumped = json.dumps( dumped = json.dumps(
data, data,
......
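With the new `**kwargs` signature, the logger no longer parses `args.wandb_args` itself; callers pass `wandb.init()` keyword arguments directly and then follow the flow described in the added docstring. A minimal sketch, assuming the import path below and using placeholder project names:

```python
# Sketch of the post-change usage; project/entity names are placeholders.
from lm_eval import simple_evaluate
from lm_eval.logging_utils import WandbLogger  # assumed import path

results = simple_evaluate(
    model="hf",
    model_args="pretrained=gpt2",
    tasks=["sciq"],
    log_samples=True,  # needed so results["samples"] exists
)

# kwargs are stored as wandb_args and forwarded to wandb.init() if no run is active yet.
wandb_logger = WandbLogger(project="lm-eval-harness", job_type="eval")
wandb_logger.post_init(results)
wandb_logger.log_eval_result()
wandb_logger.log_eval_samples(results["samples"])
```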
...@@ -147,7 +147,7 @@ please install anthropic via `pip install lm-eval[anthropic]` or `pip install -e ...@@ -147,7 +147,7 @@ please install anthropic via `pip install lm-eval[anthropic]` or `pip install -e
def _loglikelihood_tokens(self, requests, disable_tqdm: bool = False): def _loglikelihood_tokens(self, requests, disable_tqdm: bool = False):
raise NotImplementedError("No support for logits.") raise NotImplementedError("No support for logits.")
def generate_until(self, requests) -> List[str]: def generate_until(self, requests, disable_tqdm: bool = False) -> List[str]:
try: try:
import anthropic import anthropic
except ModuleNotFoundError: except ModuleNotFoundError:
...@@ -162,7 +162,7 @@ please install anthropic via `pip install lm-eval[anthropic]` or `pip install -e ...@@ -162,7 +162,7 @@ please install anthropic via `pip install lm-eval[anthropic]` or `pip install -e
_requests: List[Tuple[str, dict]] = [req.args for req in requests] _requests: List[Tuple[str, dict]] = [req.args for req in requests]
res = [] res = []
for request in tqdm(_requests): for request in tqdm(_requests, disable=disable_tqdm):
try: try:
inp = request[0] inp = request[0]
request_args = request[1] request_args = request[1]
...@@ -199,8 +199,8 @@ please install anthropic via `pip install lm-eval[anthropic]` or `pip install -e ...@@ -199,8 +199,8 @@ please install anthropic via `pip install lm-eval[anthropic]` or `pip install -e
# Isn't used because we override generate_until # Isn't used because we override generate_until
raise NotImplementedError() raise NotImplementedError()
def loglikelihood(self, requests): def loglikelihood(self, requests, disable_tqdm: bool = False):
raise NotImplementedError("No support for logits.") raise NotImplementedError("No support for logits.")
def loglikelihood_rolling(self, requests): def loglikelihood_rolling(self, requests, disable_tqdm: bool = False):
raise NotImplementedError("No support for logits.") raise NotImplementedError("No support for logits.")
import random import random
from tqdm import tqdm
from lm_eval.api.model import LM from lm_eval.api.model import LM
from lm_eval.api.registry import register_model from lm_eval.api.registry import register_model
...@@ -13,27 +15,27 @@ class DummyLM(LM): ...@@ -13,27 +15,27 @@ class DummyLM(LM):
def create_from_arg_string(cls, arg_string, additional_config=None): def create_from_arg_string(cls, arg_string, additional_config=None):
return cls() return cls()
def loglikelihood(self, requests): def loglikelihood(self, requests, disable_tqdm: bool = False):
res = [] res = []
for _ in requests: for _ in tqdm(requests, disable=disable_tqdm):
res.append((-random.random(), False)) res.append((-random.random(), False))
return res return res
def generate_until(self, requests): def generate_until(self, requests, disable_tqdm: bool = False):
res = [] res = []
for ctx, _ in requests: for ctx, _ in tqdm(requests, disable=disable_tqdm):
res.append("lol") res.append("lol")
assert ctx.strip() != "" assert ctx.strip() != ""
return res return res
def loglikelihood_rolling(self, requests): def loglikelihood_rolling(self, requests, disable_tqdm: bool = False):
res = [] res = []
for _ in requests: for _ in tqdm(requests, disable=disable_tqdm):
res.append(-random.random()) res.append(-random.random())
return res return res
...@@ -70,11 +70,13 @@ class GGUFLM(LM): ...@@ -70,11 +70,13 @@ class GGUFLM(LM):
else: else:
raise Exception(f"Failed to get a valid response after {retries} retries.") raise Exception(f"Failed to get a valid response after {retries} retries.")
def loglikelihood(self, requests): def loglikelihood(self, requests, disable_tqdm: bool = False):
if not requests: if not requests:
return [] return []
res = [] res = []
for context, continuation in tqdm([req.args for req in requests]): for context, continuation in tqdm(
[req.args for req in requests], disable=disable_tqdm
):
response = self.gguf_completion(context=context, continuation=continuation) response = self.gguf_completion(context=context, continuation=continuation)
if response and "choices" in response and response["choices"]: if response and "choices" in response and response["choices"]:
choice = response["choices"][0] choice = response["choices"][0]
...@@ -97,12 +99,12 @@ class GGUFLM(LM): ...@@ -97,12 +99,12 @@ class GGUFLM(LM):
assert False assert False
return res return res
def generate_until(self, requests): def generate_until(self, requests, disable_tqdm: bool = False):
if not requests: if not requests:
return [] return []
res = [] res = []
for request in tqdm([req.args for req in requests]): for request in tqdm([req.args for req in requests], disable=disable_tqdm):
inp = request[0] inp = request[0]
request_args = request[1] request_args = request[1]
until = request_args.get("until", ["</s>"]) until = request_args.get("until", ["</s>"])
...@@ -122,7 +124,7 @@ class GGUFLM(LM): ...@@ -122,7 +124,7 @@ class GGUFLM(LM):
res.append(None) # Add default value in case of error res.append(None) # Add default value in case of error
return res return res
def loglikelihood_rolling(self, requests): def loglikelihood_rolling(self, requests, disable_tqdm: bool = False):
raise NotImplementedError( raise NotImplementedError(
"loglikelihood_rolling not yet supported for GGUF models" "loglikelihood_rolling not yet supported for GGUF models"
) )
...@@ -790,7 +790,9 @@ class HFLM(TemplateLM): ...@@ -790,7 +790,9 @@ class HFLM(TemplateLM):
return logits return logits
def loglikelihood_rolling(self, requests: List[Instance]) -> List[float]: def loglikelihood_rolling(
self, requests: List[Instance], disable_tqdm: bool = False
) -> List[float]:
loglikelihoods = [] loglikelihoods = []
adaptive_batch_size = None adaptive_batch_size = None
...@@ -801,7 +803,9 @@ class HFLM(TemplateLM): ...@@ -801,7 +803,9 @@ class HFLM(TemplateLM):
print(f"Determined Largest batch size: {batch_size}") print(f"Determined Largest batch size: {batch_size}")
adaptive_batch_size = batch_size adaptive_batch_size = batch_size
for (string,) in tqdm([req.args for req in requests], disable=(self.rank != 0)): for (string,) in tqdm(
[req.args for req in requests], disable=(disable_tqdm or (self.rank != 0))
):
rolling_token_windows = list( rolling_token_windows = list(
map( map(
utils.make_disjoint_window, utils.make_disjoint_window,
...@@ -887,7 +891,7 @@ class HFLM(TemplateLM): ...@@ -887,7 +891,7 @@ class HFLM(TemplateLM):
def _lookup_one_token_cont(req: Tuple[Tuple[str, str], List[int], List[int]]): def _lookup_one_token_cont(req: Tuple[Tuple[str, str], List[int], List[int]]):
"""Defines the key to group and lookup one-token continuations""" """Defines the key to group and lookup one-token continuations"""
# Use with group_by="contexts" (optional)" # Use with group_by="contexts" (optional)"
# allows for the creation of a lookup, so we can re-use logits in case of one-token continuations. # allows for the creation of a lookup, so we can reuse logits in case of one-token continuations.
# speeds up some multiple-choice tasks proportionally to the number of choices. # speeds up some multiple-choice tasks proportionally to the number of choices.
# groups requests by context+continuation[:-1] and infer on one request/group. # groups requests by context+continuation[:-1] and infer on one request/group.
return req[-2] + req[-1][:-1] return req[-2] + req[-1][:-1]
...@@ -1079,7 +1083,9 @@ class HFLM(TemplateLM): ...@@ -1079,7 +1083,9 @@ class HFLM(TemplateLM):
return re_ord.get_original(res) return re_ord.get_original(res)
def generate_until(self, requests: List[Instance]) -> List[str]: def generate_until(
self, requests: List[Instance], disable_tqdm: bool = False
) -> List[str]:
res = [] res = []
def _collate(req: Tuple[str, dict]): def _collate(req: Tuple[str, dict]):
...@@ -1095,7 +1101,7 @@ class HFLM(TemplateLM): ...@@ -1095,7 +1101,7 @@ class HFLM(TemplateLM):
pbar = tqdm( pbar = tqdm(
total=len(requests), total=len(requests),
disable=(self.rank != 0), disable=(disable_tqdm or (self.rank != 0)),
desc="Running generate_until requests", desc="Running generate_until requests",
) )
adaptive_batch_size = None adaptive_batch_size = None
...@@ -1151,8 +1157,12 @@ class HFLM(TemplateLM): ...@@ -1151,8 +1157,12 @@ class HFLM(TemplateLM):
raise ValueError( raise ValueError(
f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}" f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}"
) )
# add EOS token to stop sequences
eos = self.tok_decode(self.eot_token_id)
if not until: if not until:
until = [self.tok_decode(self.eot_token_id)] until = [eos]
else:
until.append(eos)
if "max_gen_toks" in kwargs.keys(): if "max_gen_toks" in kwargs.keys():
max_gen_toks = kwargs.pop("max_gen_toks") max_gen_toks = kwargs.pop("max_gen_toks")
else: else:
......
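Besides threading `disable_tqdm` through, the behavioral change in this hunk is that the EOS string is now always appended to the stop sequences instead of being used only when none were supplied. A standalone illustration of that logic (not the harness code itself; `</s>` is just a placeholder EOS):

```python
# Standalone illustration of the new stop-sequence handling; not harness code.
def build_stop_sequences(until, eos="</s>"):
    # Before: eos served only as the default when no stop strings were given.
    # After: eos is always included alongside any user-provided stop strings.
    if not until:
        return [eos]
    return until + [eos]

print(build_stop_sequences([]))         # ['</s>']
print(build_stop_sequences(["\n\n"]))   # ['\n\n', '</s>']
```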
...@@ -447,12 +447,14 @@ class NEURON_HF(TemplateLM): ...@@ -447,12 +447,14 @@ class NEURON_HF(TemplateLM):
return logits return logits
def loglikelihood_rolling(self, requests): def loglikelihood_rolling(self, requests, disable_tqdm: bool = False):
loglikelihoods = [] loglikelihoods = []
adaptive_batch_size = None adaptive_batch_size = None
for (string,) in tqdm([req.args for req in requests], disable=(self.rank != 0)): for (string,) in tqdm(
[req.args for req in requests], disable=(disable_tqdm or (self.rank != 0))
):
rolling_token_windows = list( rolling_token_windows = list(
map( map(
utils.make_disjoint_window, utils.make_disjoint_window,
...@@ -616,7 +618,7 @@ class NEURON_HF(TemplateLM): ...@@ -616,7 +618,7 @@ class NEURON_HF(TemplateLM):
return re_ord.get_original(res) return re_ord.get_original(res)
def generate_until(self, requests): def generate_until(self, requests, disable_tqdm: bool = False):
res = defaultdict(list) res = defaultdict(list)
re_ords = {} re_ords = {}
...@@ -638,7 +640,7 @@ class NEURON_HF(TemplateLM): ...@@ -638,7 +640,7 @@ class NEURON_HF(TemplateLM):
# within each set of reqs for given kwargs, we reorder by token length, descending. # within each set of reqs for given kwargs, we reorder by token length, descending.
re_ords[key] = utils.Reorderer([req.args for req in reqs], _collate) re_ords[key] = utils.Reorderer([req.args for req in reqs], _collate)
pbar = tqdm(total=len(requests), disable=(self.rank != 0)) pbar = tqdm(total=len(requests), disable=(disable_tqdm or (self.rank != 0)))
# for each different set of kwargs, we execute all requests, by batch. # for each different set of kwargs, we execute all requests, by batch.
for key, re_ord in re_ords.items(): for key, re_ord in re_ords.items():
...@@ -666,8 +668,12 @@ class NEURON_HF(TemplateLM): ...@@ -666,8 +668,12 @@ class NEURON_HF(TemplateLM):
raise ValueError( raise ValueError(
f"Expected `kwargs` to be of type `dict` but got {kwargs}" f"Expected `kwargs` to be of type `dict` but got {kwargs}"
) )
# add EOS token to stop sequences
eos = self.tok_decode(self.eot_token_id)
if not until: if not until:
until = [self.tok_decode(self.eot_token_id)] until = [eos]
else:
until.append(eos)
if "max_gen_toks" in kwargs.keys(): if "max_gen_toks" in kwargs.keys():
max_gen_toks = kwargs.pop("max_gen_toks") max_gen_toks = kwargs.pop("max_gen_toks")
else: else:
......