Commit 02e841ce authored by lintangsutawika's avatar lintangsutawika
Browse files

Merge branch 'main' of https://github.com/EleutherAI/lm-evaluation-harness into t5v2-alt-plus

parents 90ad5db7 e74ec966
......@@ -2,7 +2,7 @@
exclude: ^tests/testdata/
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.1.0
rev: v4.5.0
hooks:
- id: check-added-large-files
- id: check-ast
......@@ -29,7 +29,7 @@ repos:
args: [--fix=lf]
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.1.8
rev: v0.2.2
hooks:
# Run the linter.
- id: ruff
......@@ -38,7 +38,7 @@ repos:
# Run the formatter.
- id: ruff-format
- repo: https://github.com/codespell-project/codespell
rev: v2.1.0
rev: v2.2.6
hooks:
- id: codespell
exclude: >
......
......@@ -321,7 +321,7 @@ lm_eval \
--log_samples
```
In the stdout, you will find the link to the W&B run page as well as link to the generated report. You can find an example of this workflow in [examples/visualize-wandb.ipynb](examples/visualize-wandb.ipynb).
In the stdout, you will find the link to the W&B run page as well as link to the generated report. You can find an example of this workflow in [examples/visualize-wandb.ipynb](examples/visualize-wandb.ipynb), and an example of how to integrate it beyond the CLI.
## How to Contribute or Learn More?
......
......@@ -19,7 +19,7 @@ LM Evaluation Harness uses [ruff](https://github.com/astral-sh/ruff) for linting
You can install linters and dev tools via
```pip install lm_eval[dev]```
```pip install lm_eval[dev]``` or ```pip install -e ".[dev]"```
Then, run
......
......@@ -2,15 +2,14 @@
## Usage
Simply add a "--decontamination_ngrams_path" when running \__main\__.py. The provided directory should contain
The provided directory should contain
the ngram files and info.json produced in "Pile Ngram Generation" further down.
```bash
python -m lm_eval \
--model gpt2 \
--device 0 \
--tasks sciq \
--decontamination_ngrams_path path/containing/training/set/ngrams
--tasks sciq
```
## Background
......@@ -70,5 +69,3 @@ python -m scripts/clean_training_data/compress_and_package \
-output path/to/final/directory \
-procs 8
```
Congratulations, the final directory can now be passed to lm-evaulation-harness with the "--decontamination_ngrams_path" argument.
......@@ -36,8 +36,6 @@ This mode supports a number of command-line arguments, the details of which can
- `--cache_requests` : Can be "true", "refresh", or "delete". "true" means that the cache should be used. "refresh" means that you wish to regenerate the cache, which you should run if you change your dataset configuration for a given task. "delete" will delete the cache. Cached files are stored under lm_eval/cache/.cache unless you specify a different path via the environment variable: `LM_HARNESS_CACHE_PATH`. e.g. `LM_HARNESS_CACHE_PATH=~/Documents/cache_for_lm_harness`.
- `--decontamination_ngrams_path` : Deprecated, see (this commit)[https://github.com/EleutherAI/lm-evaluation-harness/commit/00209e10f6e27edf5d766145afaf894079b5fe10] or older for a working decontamination-checker tool.
- `--check_integrity` : If this flag is used, the library tests for each task selected are run to confirm task integrity.
- `--write_out` : Used for diagnostic purposes to observe the format of task documents passed to a model. If this flag is used, then prints the prompt and gold target string for the first document of each task.
......@@ -50,7 +48,7 @@ This mode supports a number of command-line arguments, the details of which can
* `--seed`: Set seed for python's random, numpy and torch. Accepts a comma-separated list of 3 values for python's random, numpy, and torch seeds, respectively, or a single integer to set the same seed for all three. The values are either an integer or 'None' to not set the seed. Default is `0,1234,1234` (for backward compatibility). E.g. `--seed 0,None,8` sets `random.seed(0)` and `torch.manual_seed(8)`. Here numpy's seed is not set since the second value is `None`. E.g, `--seed 42` sets all three seeds to 42.
* `--wandb_args`: Tracks logging to Weights and Biases for evaluation runs and includes args passed to `wandb.init`, such as `project` and `job_type`. Full list (here.)[https://docs.wandb.ai/ref/python/init]
* `--wandb_args`: Tracks logging to Weights and Biases for evaluation runs and includes args passed to `wandb.init`, such as `project` and `job_type`. Full list (here.)[https://docs.wandb.ai/ref/python/init]. e.g., ```--wandb_args project=test-project,name=test-run```
## External Library Usage
......
......@@ -30,9 +30,10 @@ Dataset configuration options:
Prompting / in-context formatting options:
- **use_prompt** (`str`, *optional*) — Name of prompt in promptsource to use. if defined, will overwrite doc_to_text, doc_to_target, and doc_to_choice.
- **doc_to_text** (`Union[Callable, str]`, *optional*) — Jinja2, f-string, or function to process a sample into the appropriate input for the model
- **doc_to_target** (`Union[Callable, str]`, *optional*) — Jinja2, f-string, or function to process a sample into the appropriate target output for the model. For multiple choice tasks, this should return an index into
- **doc_to_choice** (`Union[Callable, str]`, *optional*) — Jinja2, f-string, or function to process a sample into a list of possible string choices for `multiple_choice` tasks. Left undefined for `generate_until` tasks.
- **description** (`str`, *optional*) — An optional prepended Jinja2 template or string which will be prepended to the few-shot examples passed into the model, often describing the task or providing instructions to a model, such as `"The following are questions (with answers) about {{subject}}.\n\n"`. No delimiters or spacing are inserted between the description and the first few-shot example.
- **doc_to_text** (`Union[Callable, str]`, *optional*) — Jinja2 template, string, or function to process a sample into the appropriate input for the model
- **doc_to_target** (`Union[Callable, str]`, *optional*) — Jinja2 template, string, or function to process a sample into the appropriate target output for the model. For multiple choice tasks, this should return an index into
- **doc_to_choice** (`Union[Callable, str]`, *optional*) — Jinja2 template, string, or function to process a sample into a list of possible string choices for `multiple_choice` tasks. Left undefined for `generate_until` tasks.
- **fewshot_delimiter** (`str`, *optional*, defaults to "\n\n") — String to insert between few-shot examples.
- **target_delimiter** (`str`, *optional*, defaults to `" "`) — String to insert between input and target output for the datapoint being tested.
......
......@@ -67,6 +67,7 @@
"outputs": [],
"source": [
"import wandb\n",
"\n",
"wandb.login()"
]
},
......@@ -104,6 +105,43 @@
" --wandb_args project=lm-eval-harness-integration \\\n",
" --log_samples"
]
},
{
"cell_type": "markdown",
"id": "e974cabdbe70b667",
"metadata": {},
"source": ""
},
{
"cell_type": "markdown",
"id": "5178ca9445b844e4",
"metadata": {},
"source": "W&B can also be initialized programmatically for use outside the CLI to parse and log the results."
},
{
"cell_type": "code",
"execution_count": null,
"id": "c6a421b2cf3ddac5",
"metadata": {},
"outputs": [],
"source": [
"import lm_eval\n",
"from lm_eval.logging_utils import WandbLogger\n",
"\n",
"results = lm_eval.simple_evaluate(\n",
" model=\"hf\",\n",
" model_args=\"pretrained=microsoft/phi-2,trust_remote_code=True\",\n",
" tasks=\"hellaswag,mmlu_abstract_algebra\",\n",
" log_samples=True,\n",
")\n",
"\n",
"wandb_logger = WandbLogger(\n",
" project=\"lm-eval-harness-integration\", job_type=\"eval\"\n",
") # or empty if wandb.init(...) already called before\n",
"wandb_logger.post_init(results)\n",
"wandb_logger.log_eval_result()\n",
"wandb_logger.log_eval_samples(results[\"samples\"]) # if log_samples"
]
}
],
"metadata": {
......
......@@ -14,7 +14,10 @@ from lm_eval import evaluator, utils
from lm_eval.evaluator import request_caching_arg_to_dict
from lm_eval.logging_utils import WandbLogger
from lm_eval.tasks import TaskManager, include_path, initialize_tasks
from lm_eval.utils import make_table
from lm_eval.utils import make_table, simple_parse_args_string
DEFAULT_RESULTS_FILE = "results.json"
def _handle_non_serializable(o):
......@@ -127,7 +130,6 @@ def parse_eval_args() -> argparse.Namespace:
choices=["true", "refresh", "delete"],
help="Speed up evaluation by caching the building of dataset requests. `None` if not caching.",
)
parser.add_argument("--decontamination_ngrams_path", default=None) # TODO: not used
parser.add_argument(
"--check_integrity",
action="store_true",
......@@ -201,6 +203,12 @@ def parse_eval_args() -> argparse.Namespace:
"E.g, `--seed 42` sets all three seeds to 42."
),
)
parser.add_argument(
"--trust_remote_code",
action="store_true",
help="Sets trust_remote_code to True to execute code to create HF Datasets from the Hub",
)
return parser.parse_args()
......@@ -210,7 +218,7 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
args = parse_eval_args()
if args.wandb_args:
wandb_logger = WandbLogger(args)
wandb_logger = WandbLogger(**simple_parse_args_string(args.wandb_args))
eval_logger = utils.eval_logger
eval_logger.setLevel(getattr(logging, f"{args.verbosity}"))
......@@ -220,7 +228,9 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
if args.predict_only:
args.log_samples = True
if (args.log_samples or args.predict_only) and not args.output_path:
assert args.output_path, "Specify --output_path"
raise ValueError(
"Specify --output_path if providing --log_samples or --predict_only"
)
initialize_tasks(args.verbosity)
task_manager = TaskManager(args.verbosity, include_path=args.include_path)
......@@ -271,12 +281,13 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
if args.output_path:
path = Path(args.output_path)
# check if file or 'dir/results.json' exists
if path.is_file() or Path(args.output_path).joinpath("results.json").is_file():
if path.is_file():
raise FileExistsError(f"File already exists at {path}")
output_path_file = path.joinpath(DEFAULT_RESULTS_FILE)
if output_path_file.is_file():
eval_logger.warning(
f"File already exists at {path}. Results will be overwritten."
f"File {output_path_file} already exists. Results will be overwritten."
)
output_path_file = path.joinpath("results.json")
assert not path.is_file(), "File already exists"
# if path json then get parent dir
elif path.suffix in (".json", ".jsonl"):
output_path_file = path
......@@ -284,7 +295,14 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
path = path.parent
else:
path.mkdir(parents=True, exist_ok=True)
output_path_file = path.joinpath("results.json")
# Respect user's value passed in via CLI, otherwise default to True and add to comma-separated model args
if args.trust_remote_code:
os.environ["HF_DATASETS_TRUST_REMOTE_CODE"] = str(args.trust_remote_code)
args.model_args = (
args.model_args
+ f",trust_remote_code={os.environ['HF_DATASETS_TRUST_REMOTE_CODE']}"
)
eval_logger.info(f"Selected Tasks: {task_names}")
eval_logger.info("Loading selected tasks...")
......@@ -303,17 +321,17 @@ def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
device=args.device,
use_cache=args.use_cache,
limit=args.limit,
decontamination_ngrams_path=args.decontamination_ngrams_path,
check_integrity=args.check_integrity,
write_out=args.write_out,
log_samples=args.log_samples,
gen_kwargs=args.gen_kwargs,
task_manager=task_manager,
verbosity=args.verbosity,
predict_only=args.predict_only,
**request_caching_args,
random_seed=args.seed[0],
numpy_random_seed=args.seed[1],
torch_random_seed=args.seed[2],
**request_caching_args,
)
if results is not None:
......
from dataclasses import dataclass, field
from typing import Literal, Tuple
from typing import Literal, Optional, Tuple
OutputType = Literal[
"loglikelihood", "loglikelihood_rolling", "generate_until", "multiple_choice"
]
@dataclass
class Instance:
request_type: Literal[
"loglikelihood",
"loglikelihood_rolling",
"generate_until",
"multiple_choice",
]
request_type: OutputType
doc: dict
arguments: tuple
idx: int
metadata: Tuple[str, int, int] = field(
metadata: Tuple[Optional[str], Optional[int], Optional[int]] = field(
default_factory=lambda: (None, None, None)
) # TODO: better typehints here
)
resps: list = field(default_factory=list)
filtered_resps: dict = field(default_factory=dict)
# initialized after init
task_name: str = None
doc_id: str = None
repeats: str = None
task_name: Optional[str] = None
doc_id: Optional[int] = None
repeats: Optional[int] = None
def __post_init__(self) -> None:
# unpack metadata field
......
......@@ -54,7 +54,7 @@ class LM(abc.ABC):
pass
@abc.abstractmethod
def loglikelihood_rolling(self, requests) -> List[Tuple[float, bool]]:
def loglikelihood_rolling(self, requests) -> List[Tuple[float]]:
"""Compute full log-likelihood of a string, with no truncation, for perplexity computation
- We will use the full max context length of the model.
- For inputs that exceed the max context length, we divide the tokenized string into chunks of up to
......@@ -83,15 +83,13 @@ class LM(abc.ABC):
2. For the last pair, we provide the full context, but only score the last two tokens
:param requests: list[Instance]
A list of Instance objects with property `args` which returns a tuple (context, continuation).
A list of Instance objects with property `args` which returns a tuple (context,).
string: str
String for which we are computing per-token loglikelihood
:return: list[tuple[float, bool]]
A list of pairs (logprob, isgreedy)
String for which we are computing overall loglikelihood
:return: list[tuple[float]]
A list of tuples (logprob,)
logprob: float
The log probability of `continuation`
isgreedy:
Whether `continuation` would be generated by greedy sampling from `context`
The log probability of `context` conditioned on the EOT token.
"""
pass
......@@ -306,7 +304,9 @@ class TemplateLM(LM):
return context_enc, continuation_enc
def loglikelihood(self, requests) -> List[Tuple[float, bool]]:
def loglikelihood(
self, requests, disable_tqdm: bool = False
) -> List[Tuple[float, bool]]:
new_reqs = []
for context, continuation in [req.args for req in requests]:
if context == "":
......@@ -320,12 +320,14 @@ class TemplateLM(LM):
new_reqs.append(((context, continuation), context_enc, continuation_enc))
return self._loglikelihood_tokens(new_reqs)
return self._loglikelihood_tokens(new_reqs, disable_tqdm=disable_tqdm)
@abc.abstractmethod
def loglikelihood_rolling(self, requests) -> List[Tuple[float, bool]]:
def loglikelihood_rolling(
self, requests, disable_tqdm: bool = False
) -> List[Tuple[float, bool]]:
pass
@abc.abstractmethod
def generate_until(self, requests) -> List[str]:
def generate_until(self, requests, disable_tqdm: bool = False) -> List[str]:
pass
......@@ -7,7 +7,18 @@ from collections.abc import Callable
from copy import deepcopy
from dataclasses import asdict, dataclass
from inspect import getsource
from typing import Any, Iterator, List, Literal, Tuple, Union
from typing import (
Any,
Dict,
Iterable,
Iterator,
List,
Literal,
Mapping,
Optional,
Tuple,
Union,
)
import datasets
import numpy as np
......@@ -15,12 +26,8 @@ from tqdm import tqdm
from lm_eval import utils
from lm_eval.api import samplers
from lm_eval.api.instance import Instance
from lm_eval.api.metrics import (
bits_per_byte,
mean,
weighted_perplexity,
)
from lm_eval.api.instance import Instance, OutputType
from lm_eval.api.metrics import bits_per_byte, mean, weighted_perplexity
from lm_eval.api.registry import (
AGGREGATION_REGISTRY,
DEFAULT_METRIC_REGISTRY,
......@@ -63,56 +70,54 @@ class GroupConfig(dict):
@dataclass
class TaskConfig(dict):
# task naming/registry
task: str = None
task_alias: str = None
group: Union[str, list] = None
group_alias: Union[str, list] = None
task: Optional[str] = None
task_alias: Optional[str] = None
group: Optional[Union[str, list]] = None
group_alias: Optional[Union[str, list]] = None
# HF dataset options.
# which dataset to use,
# and what splits for what purpose
dataset_path: str = None
dataset_name: str = None
dataset_kwargs: dict = None
training_split: str = None
validation_split: str = None
test_split: str = None
fewshot_split: str = None # TODO: assert that this not None if num_fewshot > 0. (?) assert if this is same split as one evaling (?)
dataset_path: Optional[str] = None
dataset_name: Optional[str] = None
dataset_kwargs: Optional[dict] = None
training_split: Optional[str] = None
validation_split: Optional[str] = None
test_split: Optional[str] = None
fewshot_split: Optional[
str
] = None # TODO: assert that this not None if num_fewshot > 0. (?) assert if this is same split as one evaling (?)
# formatting / prompting options.
# see docs/advanced_task_guide.md for more info
process_docs: Callable = None
doc_to_text: Union[Callable, str] = None
doc_to_target: Union[Callable, str] = None
doc_to_choice: Union[Callable, str, dict, list] = None
process_results: Union[Callable, str] = None
use_prompt: str = None
process_docs: Optional[Callable] = None
doc_to_text: Optional[Union[Callable, str]] = None
doc_to_target: Optional[Union[Callable, str]] = None
doc_to_choice: Optional[Union[Callable, str, dict, list]] = None
process_results: Optional[Union[Callable, str]] = None
use_prompt: Optional[str] = None
description: str = ""
target_delimiter: str = " "
fewshot_delimiter: str = "\n\n"
fewshot_config: dict = None
fewshot_config: Optional[dict] = None
# runtime configuration options
num_fewshot: int = None
num_fewshot: Optional[int] = None
# scoring options
metric_list: list = None
output_type: Literal[
"loglikelihood",
"loglikelihood_rolling",
"generate_until",
"multiple_choice",
] = "generate_until"
generation_kwargs: dict = None
metric_list: Optional[list] = None
output_type: OutputType = "generate_until"
generation_kwargs: Optional[dict] = None
repeats: int = 1
filter_list: Union[str, list] = None
filter_list: Optional[Union[str, list]] = None
should_decontaminate: bool = False
doc_to_decontamination_query: str = None
metadata: dict = None # by default, not used in the code. allows for users to pass arbitrary info to tasks
doc_to_decontamination_query: Optional[str] = None
metadata: Optional[
dict
] = None # by default, not used in the code. allows for users to pass arbitrary info to tasks
def __post_init__(self) -> None:
if self.generation_kwargs is not None:
if self.output_type != "generate_until":
eval_logger.warning(
raise ValueError(
f"[{self.task}] passed `generation_kwargs`, but not using `output_type: generate_until`!"
)
assert self.output_type != "generate_until"
if "temperature" in self.generation_kwargs:
self.generation_kwargs["temperature"] = float(
......@@ -193,23 +198,23 @@ class Task(abc.ABC):
{"question": ..., question, answer)
"""
VERSION = None
VERSION: Optional[Union[int, str]] = None
# The name of the `Task` benchmark as denoted in the HuggingFace datasets Hub
# or a path to a custom `datasets` loading script.
DATASET_PATH: str = None
DATASET_PATH: Optional[str] = None
# The name of a subset within `DATASET_PATH`.
DATASET_NAME: str = None
DATASET_NAME: Optional[str] = None
OUTPUT_TYPE: str = None
OUTPUT_TYPE: Optional[OutputType] = None
def __init__(
self,
data_dir=None,
cache_dir=None,
download_mode=None,
config=None,
data_dir: Optional[str] = None,
cache_dir: Optional[str] = None,
download_mode: Optional[datasets.DownloadMode] = None,
config: Optional[Mapping] = None, # Union[dict, TaskConfig]
) -> None:
"""
:param data_dir: str
......@@ -233,15 +238,20 @@ class Task(abc.ABC):
Fresh download and fresh dataset.
"""
self.download(data_dir, cache_dir, download_mode)
self._training_docs = None
self._fewshot_docs = None
self._instances = None
self._training_docs: Optional[list] = None
self._fewshot_docs: Optional[list] = None
self._instances: Optional[List[Instance]] = None
self._config = TaskConfig({**config}) if config else TaskConfig()
self._config: TaskConfig = TaskConfig({**config}) if config else TaskConfig()
self._filters = [build_filter_ensemble("none", [["take_first", None]])]
def download(self, data_dir=None, cache_dir=None, download_mode=None) -> None:
def download(
self,
data_dir: Optional[str] = None,
cache_dir: Optional[str] = None,
download_mode=None,
) -> None:
"""Downloads and returns the task dataset.
Override this method to download the dataset from a custom API.
......@@ -275,7 +285,7 @@ class Task(abc.ABC):
)
@property
def config(self):
def config(self) -> TaskConfig:
"""Returns the TaskConfig associated with this class."""
return self._config
......@@ -294,28 +304,28 @@ class Task(abc.ABC):
"""Whether the task has a test set"""
pass
def training_docs(self):
def training_docs(self) -> Iterable:
"""
:return: Iterable[obj]
A iterable of any object, that doc_to_text can handle
"""
return []
def validation_docs(self):
def validation_docs(self) -> Iterable:
"""
:return: Iterable[obj]
A iterable of any object, that doc_to_text can handle
"""
return []
def test_docs(self):
def test_docs(self) -> Iterable:
"""
:return: Iterable[obj]
A iterable of any object, that doc_to_text can handle
"""
return []
def fewshot_docs(self):
def fewshot_docs(self) -> Iterable:
"""
:return: Iterable[obj]
A iterable of any object, that doc_to_text can handle
......@@ -331,7 +341,7 @@ class Task(abc.ABC):
)
return self.test_docs()
def _process_doc(self, doc):
def _process_doc(self, doc: dict) -> dict:
"""
Override this to process (detokenize, strip, replace, etc.) individual
documents. This can be used in a map over documents of a data split.
......@@ -355,11 +365,10 @@ class Task(abc.ABC):
return rnd.sample(self._training_docs, k)
def doc_to_decontamination_query(self, doc) -> None:
print(
def doc_to_decontamination_query(self, doc):
raise NotImplementedError(
"Override doc_to_decontamination_query with document specific decontamination query."
)
assert False
@abc.abstractmethod
def doc_to_text(self, doc):
......@@ -451,7 +460,8 @@ class Task(abc.ABC):
self._instances = flattened_instances
assert len(self._instances) != 0, "task.build_requests() did not find any docs!"
if len(self._instances) == 0:
raise ValueError("task.build_requests() did not find any docs!")
if cache_requests and (not cached_instances or rewrite_requests_cache):
save_to_cache(file_name=cache_key, obj=instances)
......@@ -544,9 +554,10 @@ class Task(abc.ABC):
:returns: str
The fewshot context.
"""
assert (
rnd is not None
), "A `random.Random` generator argument must be provided to `rnd`"
if rnd is None:
raise ValueError(
"A `random.Random` generator argument must be provided to `rnd`"
)
description = description if description else ""
......@@ -582,7 +593,7 @@ class Task(abc.ABC):
example = self.doc_to_text(doc)
return description + labeled_examples + example
def apply_filters(self):
def apply_filters(self) -> Optional[List[Instance]]:
"""Iterates over FilterEnsembles and applies them to instances"""
if hasattr(self, "_filters"):
for f in self._filters:
......@@ -644,7 +655,9 @@ class Task(abc.ABC):
elif self.has_validation_docs():
return self.validation_docs()
else:
assert False, f"Task dataset (path={self.DATASET_PATH}, name={self.DATASET_NAME}) must have valid or test docs!"
raise ValueError(
f"Task dataset (path={self.DATASET_PATH}, name={self.DATASET_NAME}) must have valid or test docs!"
)
def doc_iterator(
self, *, rank: int = 0, limit: Union[int, None] = None, world_size: int = 1
......@@ -665,7 +678,11 @@ class ConfigurableTask(Task):
CONFIG = None
def __init__(
self, data_dir=None, cache_dir=None, download_mode=None, config: dict = None
self,
data_dir=None,
cache_dir=None,
download_mode=None,
config: Optional[dict] = None,
) -> None: # TODO no super() call here
# Get pre-configured attributes
self._config = self.CONFIG
......@@ -688,7 +705,10 @@ class ConfigurableTask(Task):
self.VERSION = self.config.metadata["version"]
if self.config.output_type is not None:
assert self.config.output_type in ALL_OUTPUT_TYPES
if self.config.output_type not in ALL_OUTPUT_TYPES:
raise ValueError(
f"Got invalid output_type '{self.config.output_type}', must be in '{','.join(ALL_OUTPUT_TYPES)}'"
)
self.OUTPUT_TYPE = self.config.output_type
if self.config.dataset_path is not None:
......@@ -715,7 +735,10 @@ class ConfigurableTask(Task):
self._higher_is_better[metric_name] = is_higher_better(metric_name)
else:
for metric_config in self.config.metric_list:
assert "metric" in metric_config
if "metric" not in metric_config:
raise ValueError(
"'metric' key not provided for an entry in 'metric_list', must be specified!"
)
metric_name = metric_config["metric"]
kwargs = {
key: metric_config[key]
......@@ -860,7 +883,7 @@ class ConfigurableTask(Task):
f'Both target_delimiter "{self.config.target_delimiter}" and target choice: "{choice}" do not have whitespace, ignore if the language you are evaluating on does not require/use whitespace'
)
def download(self, dataset_kwargs=None) -> None:
def download(self, dataset_kwargs: Optional[Dict[str, Any]] = None) -> None:
self.dataset = datasets.load_dataset(
path=self.DATASET_PATH,
name=self.DATASET_NAME,
......@@ -922,7 +945,7 @@ class ConfigurableTask(Task):
return super().fewshot_docs()
@utils.positional_deprecated
def fewshot_context(self, doc, num_fewshot):
def fewshot_context(self, doc: str, num_fewshot: int) -> str:
"""Returns a fewshot context string that is made up of a prepended description
(if provided), the `num_fewshot` number of examples, and an appended prompt example.
......@@ -933,14 +956,14 @@ class ConfigurableTask(Task):
:returns: str
The fewshot context.
"""
if description := self.config.description:
description = utils.apply_template(self.config.description, doc)
if num_fewshot == 0:
# always prepend the (possibly empty) task description
labeled_examples = self.config.description
labeled_examples = description
else:
labeled_examples = self.config.description + self.sampler.get_context(
doc, num_fewshot
)
labeled_examples = description + self.sampler.get_context(doc, num_fewshot)
example = self.doc_to_text(doc)
if self.multiple_input:
......@@ -986,7 +1009,7 @@ class ConfigurableTask(Task):
)
)
def _process_doc(self, doc):
def _process_doc(self, doc: dict) -> dict:
"""
Override this to process (detokenize, strip, replace, etc.) individual
documents. This can be used in a map over documents of a data split.
......@@ -1031,7 +1054,7 @@ class ConfigurableTask(Task):
print(type(doc_to_text))
raise TypeError
def doc_to_target(self, doc: dict) -> Union[int, str, list]:
def doc_to_target(self, doc: Mapping) -> Union[int, str, list]:
if self.prompt is not None:
doc_to_target = self.prompt
else:
......@@ -1213,7 +1236,8 @@ class ConfigurableTask(Task):
# then we are doing mutual info.
# this stores the "dryrun" / unconditional answer loglikelihoods
lls_unconditional = lls[1::2]
assert len(lls_unconditional) == len(choices)
if len(lls_unconditional) != len(choices):
raise ValueError
# and this stores our "regular" conditional loglikelihoods
lls = lls[::2]
......@@ -1376,7 +1400,7 @@ class ConfigurableTask(Task):
class MultipleChoiceTask(Task):
OUTPUT_TYPE: str = "loglikelihood"
OUTPUT_TYPE = "loglikelihood"
def doc_to_target(self, doc: dict) -> str:
return " " + doc["choices"][doc["gold"]]
......@@ -1394,7 +1418,7 @@ class MultipleChoiceTask(Task):
for i, choice in enumerate(doc["choices"])
]
def process_results(self, doc: dict, results: List[Tuple[float, bool]]) -> dict:
def process_results(self, doc: dict, results: Iterable[Tuple[float, bool]]) -> dict:
results = [
res[0] for res in results
] # only retain loglikelihoods, discard is_greedy TODO: do we need is_greedy anywhere?
......@@ -1429,13 +1453,17 @@ class PerplexityTask(Task):
return False
def fewshot_examples(self, k: int, rnd) -> List:
assert k == 0
if k != 0:
raise ValueError(
"The number of fewshot examples must be 0 for perplexity tasks."
)
return []
def fewshot_context(self, doc: dict, num_fewshot: int) -> Literal[""]:
assert (
num_fewshot == 0
), "The number of fewshot examples must be 0 for perplexity tasks."
if num_fewshot != 0:
raise ValueError(
"The number of fewshot examples must be 0 for perplexity tasks."
)
return ""
......@@ -1455,8 +1483,9 @@ class PerplexityTask(Task):
def doc_to_target(self, doc):
return doc
def construct_requests(self, doc: dict, ctx: Union[str, None], **kwargs):
assert not ctx
def construct_requests(self, doc: dict, ctx: Optional[str], **kwargs):
if bool(ctx):
raise ValueError
return Instance(
request_type=self.OUTPUT_TYPE,
......@@ -1466,7 +1495,7 @@ class PerplexityTask(Task):
**kwargs,
)
def process_results(self, doc: dict, results: float) -> dict:
def process_results(self, doc: dict, results: Tuple[float]) -> dict:
(loglikelihood,) = results
words = self.count_words(self.doc_to_target(doc))
bytes_ = self.count_bytes(self.doc_to_target(doc))
......
import collections
import itertools
import logging
import random
from typing import TYPE_CHECKING, Optional, Union
from collections import defaultdict
from typing import TYPE_CHECKING, List, Optional, Union
import numpy as np
import torch
......@@ -10,6 +10,7 @@ import torch
import lm_eval.api.metrics
import lm_eval.api.registry
import lm_eval.models
from lm_eval.caching.cache import delete_cache
from lm_eval.evaluator_utils import (
consolidate_results,
get_sample_size,
......@@ -20,25 +21,19 @@ from lm_eval.evaluator_utils import (
)
from lm_eval.logging_utils import add_env_info, get_git_commit_hash
from lm_eval.tasks import TaskManager, get_task_dict
from lm_eval.utils import (
eval_logger,
positional_deprecated,
simple_parse_args_string,
)
from lm_eval.utils import eval_logger, positional_deprecated, simple_parse_args_string
if TYPE_CHECKING:
from lm_eval.api.model import LM
from lm_eval.tasks import Task
from lm_eval.caching.cache import delete_cache
@positional_deprecated
def simple_evaluate(
model,
model_args: Optional[Union[str, dict, None]] = None,
tasks=None,
model_args: Optional[Union[str, dict]] = None,
tasks: Optional[List[Union[str, dict, object]]] = None,
num_fewshot: Optional[int] = None,
batch_size: Optional[int] = None,
max_batch_size: Optional[int] = None,
......@@ -50,11 +45,10 @@ def simple_evaluate(
limit: Optional[Union[int, float]] = None,
bootstrap_iters: int = 100000,
check_integrity: bool = False,
decontamination_ngrams_path=None,
write_out: bool = False,
log_samples: bool = True,
gen_kwargs: str = None,
task_manager: TaskManager = None,
gen_kwargs: Optional[str] = None,
task_manager: Optional[TaskManager] = None,
verbosity: str = "INFO",
predict_only: bool = False,
random_seed: int = 0,
......@@ -136,14 +130,16 @@ def simple_evaluate(
if tasks is None:
tasks = []
assert (
tasks != []
), "No tasks specified, or no tasks found. Please verify the task names."
if len(tasks) == 0:
raise ValueError(
"No tasks specified, or no tasks found. Please verify the task names."
)
if gen_kwargs is not None:
gen_kwargs = simple_parse_args_string(gen_kwargs)
eval_logger.warning(
"generation_kwargs specified through cli, these settings will update set parameters in yaml tasks. Ensure 'do_sample=True' for non-greedy decoding!"
"generation_kwargs specified through cli, these settings will update set parameters in yaml tasks. "
"Ensure 'do_sample=True' for non-greedy decoding!"
)
if gen_kwargs == "":
gen_kwargs = None
......@@ -152,7 +148,7 @@ def simple_evaluate(
if model_args is None:
model_args = ""
elif isinstance(model_args, dict):
if isinstance(model_args, dict):
lm = lm_eval.api.registry.get_model(model).create_from_arg_obj(
model_args,
{
......@@ -172,7 +168,8 @@ def simple_evaluate(
},
)
else:
assert isinstance(model, lm_eval.api.model.LM)
if not isinstance(model, lm_eval.api.model.LM):
raise TypeError
lm = model
if use_cache is not None:
......@@ -237,7 +234,6 @@ def simple_evaluate(
cache_requests=cache_requests,
rewrite_requests_cache=rewrite_requests_cache,
bootstrap_iters=bootstrap_iters,
decontamination_ngrams_path=decontamination_ngrams_path,
write_out=write_out,
log_samples=log_samples,
verbosity=verbosity,
......@@ -272,18 +268,14 @@ def simple_evaluate(
return None
decontaminate_suffix = "_decontaminate"
@positional_deprecated
def evaluate(
lm: "LM",
task_dict,
limit: Optional[int] = None,
cache_requests=False,
rewrite_requests_cache=False,
cache_requests: bool = False,
rewrite_requests_cache: bool = False,
bootstrap_iters: Optional[int] = 100000,
decontamination_ngrams_path=None,
write_out: bool = False,
log_samples: bool = True,
verbosity: str = "INFO",
......@@ -307,21 +299,21 @@ def evaluate(
"""
eval_logger.setLevel(getattr(logging, f"{verbosity}"))
# decontaminate = decontamination_ngrams_path is not None
# tracks all Instances/requests a model must generate output on.
requests = collections.defaultdict(list)
requests = defaultdict(list)
# stores the amount to pad out reqs per req. type so that
# number of fwd passes per distributed rank is equal
padding_requests = collections.defaultdict(int)
padding_requests = defaultdict(int)
# get lists of group hierarchy and each type of request
task_hierarchy, eval_tasks = get_task_list(task_dict)
if not log_samples:
assert all(
if not all(
"bypass" not in getattr(task_output.task, "_metric_fn_list", {}).keys()
for task_output in eval_tasks
), "log_samples must be True for 'bypass' only tasks"
):
raise ValueError("log_samples must be True for 'bypass' metric-only tasks")
for task_output in eval_tasks:
task: Task = task_output.task
limit = get_sample_size(task, limit)
......@@ -348,10 +340,16 @@ def evaluate(
gathered_item = (
lm.accelerator.gather(instances_rnk).cpu().detach().numpy().tolist()
)
# "multiple_choice" task types dispatch (several) "loglikelihood" request types
reqtype = (
"loglikelihood"
if task.OUTPUT_TYPE == "multiple_choice"
else task.OUTPUT_TYPE
)
# compute number of pseudo-batches to pad with (FSDP/DDP require even batches among ranks)
numpad = max(gathered_item) - gathered_item[lm.rank]
padding_requests[task.OUTPUT_TYPE] += numpad
# todo: may not account for padding in cases like SquadV2 which has multiple req types
padding_requests[reqtype] += numpad
### Run LM on inputs, get all outputs ###
# execute each type of request
......@@ -388,7 +386,7 @@ def evaluate(
# # unpack results and sort back in order and return control to Task
# TODO: make it possible to use a different metric per filter
# Pre-process task.instances to group by doc_id
instances_by_doc_id = collections.defaultdict(list)
instances_by_doc_id = defaultdict(list)
for instance in task.instances:
instances_by_doc_id[instance.doc_id].append(instance)
# Sort instances within each group
......@@ -515,10 +513,9 @@ def evaluate(
results[group]["samples"] = sum(sizes)
results_agg = collections.defaultdict(dict)
groups_agg = collections.defaultdict(dict)
results_agg = defaultdict(dict)
groups_agg = defaultdict(dict)
all_tasks_list = list(task_hierarchy.keys())
left_tasks_list = []
while True:
add_tasks_list = list(k for k in results_agg.keys())
left_tasks_list = sorted(list(set(all_tasks_list) - set(add_tasks_list)))
......@@ -542,7 +539,7 @@ def evaluate(
results_dict = {
"results": dict(results_agg.items()),
**({"groups": dict(groups_agg.items())} if bool(groups_agg) else {}),
"group_subtasks": {k: v for k, v in reversed(task_hierarchy.items())},
"group_subtasks": dict(reversed(task_hierarchy.items())),
"configs": dict(sorted(configs.items())),
"versions": dict(sorted(versions.items())),
"n-shot": dict(sorted(num_fewshot.items())),
......@@ -558,11 +555,9 @@ def evaluate(
def request_caching_arg_to_dict(cache_requests: str) -> dict:
request_caching_args = {
"cache_requests": (
True if cache_requests == "true" or cache_requests == "refresh" else False
),
"rewrite_requests_cache": True if cache_requests == "refresh" else False,
"delete_requests_cache": True if cache_requests == "delete" else False,
"cache_requests": cache_requests in {"true", "refresh"},
"rewrite_requests_cache": cache_requests == "refresh",
"delete_requests_cache": cache_requests == "delete",
}
return request_caching_args
......@@ -15,6 +15,7 @@ FILTER_REGISTRY = {
"lowercase": transformation.LowercaseFilter,
"uppercase": transformation.UppercaseFilter,
"map": transformation.MapFilter,
"multi_choice_regex": extraction.MultiChoiceRegexFilter,
# TODO: implement this filter. either it should take in an arbitrary "scoring"/reward function
# that takes an input and returns a scalar and then should select the max reward,
# or should implement different filters for different ways of handling a reward model's inference.
......
import re
import sys
import unicodedata
from lm_eval.api.filter import Filter
......@@ -67,3 +69,115 @@ class WhitespaceFilter(Filter):
filtered_resps = [filter_set(resp) for resp in resps]
return filtered_resps
class MultiChoiceRegexFilter(RegexFilter):
"""
A filter used to extract a model's answer on multiple choice questions with
letter answers. assumes each document has a "choices" field
containing the list of answer choices and that the answer label symbols
are of the form (A), (B), (C), ... or A, B, C.
"""
def __init__(
self,
regex_pattern: str = r"#### (\-?[0-9\.\,]+)",
group_select=0,
fallback: str = "[invalid]",
ignore_case=False,
ignore_punctuation=False,
regexes_to_ignore=None,
) -> None:
"""
regex_pattern: The basic regex pattern to use. If fails to match, we will use the customized match procedure
- step 1 : We parse the choices between ([A-Z])s then try to find these choices in the response.
- step 2 : We parse the choice with regex :[\s]*([A-?]), where ? varies by number of choices.
group_select: Selects the (group_select)th match from the findall result.
ignore_case: Ignores the case during step 1 matching
ignore_punctuation: Remove the punctuation during step 1 matching
regexes_to_ignore: Remove these regexes during step 1 matching
"""
super().__init__(regex_pattern, group_select, fallback)
self.ignore_case = ignore_case
self.ignore_punctuation = ignore_punctuation
self.regexes_to_ignore = regexes_to_ignore
def apply(self, resps, docs):
# here, we assume we have a list, in which each element is
# a list of model responses for some particular input/target pair.
# so we process each of these (same input/target response sets)
# independently (and keep them a list.)
def find_match(regex, resp, convert_dict={}):
match = regex.findall(resp)
if match:
match = match[self.group_select]
if isinstance(match, tuple):
match = [m for m in match if m][0]
match = match.strip()
if match and match in convert_dict:
match = convert_dict[match]
return match
punct_tbl = dict.fromkeys(
i
for i in range(sys.maxunicode)
if unicodedata.category(chr(i)).startswith("P")
)
def filter_ignores(st):
if self.regexes_to_ignore is not None:
for s in self.regexes_to_ignore:
st = re.sub(s, "", st)
if self.ignore_case:
st = st.lower()
if self.ignore_punctuation:
# https://stackoverflow.com/a/266162
st = st.translate(punct_tbl)
return st
filtered_resps = []
for r, doc in zip(resps, docs):
fallback_regexes = []
choice_to_alpha = {}
next_alpha = "A"
without_paren_fallback_regexes = []
without_paren_to_target = {}
choices = doc["choices"]
for c in choices:
m = filter_ignores(c.strip())
fallback_regexes.append(f"{re.escape(m)}")
choice_to_alpha[m] = f"({next_alpha})"
without_paren_fallback_regexes.append(next_alpha)
without_paren_to_target[next_alpha] = f"({next_alpha})"
next_alpha = chr(ord(next_alpha) + 1)
fallback_regex = re.compile("|".join(fallback_regexes))
without_paren_fallback_regex = "|".join(without_paren_fallback_regexes)
without_paren_fallback_regex = re.compile(
f":[\s]*({without_paren_fallback_regex})"
)
filtered = []
for resp in r:
match = find_match(self.regex, resp)
if not match:
match = find_match(
fallback_regex, filter_ignores(resp), choice_to_alpha
)
if not match:
match = find_match(
without_paren_fallback_regex, resp, without_paren_to_target
)
if not match:
match = self.fallback
filtered.append(match)
filtered_resps.append(filtered)
return filtered_resps
......@@ -13,24 +13,9 @@ from packaging.version import Version
from torch.utils.collect_env import get_pretty_env_info
from transformers import __version__ as trans_version
from lm_eval.utils import simple_parse_args_string
logger = logging.getLogger(__name__)
try:
import wandb
assert Version(wandb.__version__) >= Version("0.13.6")
if Version(wandb.__version__) < Version("0.13.6"):
wandb.require("report-editing:v0")
except Exception as e:
logger.warning(
"To use the wandb reporting functionality please install wandb>=0.13.6.\n"
"To install the latest version of wandb run `pip install wandb --upgrade`\n"
f"{e}"
)
def remove_none_pattern(input_string: str) -> Tuple[str, bool]:
"""Remove the ',none' substring from the input_string if it exists at the end.
......@@ -83,14 +68,31 @@ def get_wandb_printer() -> Literal["Printer"]:
class WandbLogger:
def __init__(self, args: Any) -> None:
"""Initialize the WandbLogger.
def __init__(self, **kwargs) -> None:
"""Attaches to wandb logger if already initialized. Otherwise, passes kwargs to wandb.init()
Args:
results (Dict[str, Any]): The results dictionary.
args (Any): Arguments for configuration.
kwargs Optional[Any]: Arguments for configuration.
Parse and log the results returned from evaluator.simple_evaluate() with:
wandb_logger.post_init(results)
wandb_logger.log_eval_result()
wandb_logger.log_eval_samples(results["samples"])
"""
self.wandb_args: Dict[str, Any] = simple_parse_args_string(args.wandb_args)
try:
import wandb
assert Version(wandb.__version__) >= Version("0.13.6")
if Version(wandb.__version__) < Version("0.13.6"):
wandb.require("report-editing:v0")
except Exception as e:
logger.warning(
"To use the wandb reporting functionality please install wandb>=0.13.6.\n"
"To install the latest version of wandb run `pip install wandb --upgrade`\n"
f"{e}"
)
self.wandb_args: Dict[str, Any] = kwargs
# initialize a W&B run
if wandb.run is None:
......@@ -164,6 +166,8 @@ class WandbLogger:
]
def make_table(columns: List[str], key: str = "results"):
import wandb
table = wandb.Table(columns=columns)
results = copy.deepcopy(self.results)
......@@ -202,6 +206,8 @@ class WandbLogger:
def _log_results_as_artifact(self) -> None:
"""Log results as JSON artifact to W&B."""
import wandb
dumped = json.dumps(
self.results, indent=2, default=_handle_non_serializable, ensure_ascii=False
)
......@@ -320,6 +326,8 @@ class WandbLogger:
def _log_samples_as_artifact(
self, data: List[Dict[str, Any]], task_name: str
) -> None:
import wandb
# log the samples as an artifact
dumped = json.dumps(
data,
......
......@@ -147,7 +147,7 @@ please install anthropic via `pip install lm-eval[anthropic]` or `pip install -e
def _loglikelihood_tokens(self, requests, disable_tqdm: bool = False):
raise NotImplementedError("No support for logits.")
def generate_until(self, requests) -> List[str]:
def generate_until(self, requests, disable_tqdm: bool = False) -> List[str]:
try:
import anthropic
except ModuleNotFoundError:
......@@ -162,7 +162,7 @@ please install anthropic via `pip install lm-eval[anthropic]` or `pip install -e
_requests: List[Tuple[str, dict]] = [req.args for req in requests]
res = []
for request in tqdm(_requests):
for request in tqdm(_requests, disable=disable_tqdm):
try:
inp = request[0]
request_args = request[1]
......@@ -199,8 +199,8 @@ please install anthropic via `pip install lm-eval[anthropic]` or `pip install -e
# Isn't used because we override generate_until
raise NotImplementedError()
def loglikelihood(self, requests):
def loglikelihood(self, requests, disable_tqdm: bool = False):
raise NotImplementedError("No support for logits.")
def loglikelihood_rolling(self, requests):
def loglikelihood_rolling(self, requests, disable_tqdm: bool = False):
raise NotImplementedError("No support for logits.")
import random
from tqdm import tqdm
from lm_eval.api.model import LM
from lm_eval.api.registry import register_model
......@@ -13,27 +15,27 @@ class DummyLM(LM):
def create_from_arg_string(cls, arg_string, additional_config=None):
return cls()
def loglikelihood(self, requests):
def loglikelihood(self, requests, disable_tqdm: bool = False):
res = []
for _ in requests:
for _ in tqdm(requests, disable=disable_tqdm):
res.append((-random.random(), False))
return res
def generate_until(self, requests):
def generate_until(self, requests, disable_tqdm: bool = False):
res = []
for ctx, _ in requests:
for ctx, _ in tqdm(requests, disable=disable_tqdm):
res.append("lol")
assert ctx.strip() != ""
return res
def loglikelihood_rolling(self, requests):
def loglikelihood_rolling(self, requests, disable_tqdm: bool = False):
res = []
for _ in requests:
for _ in tqdm(requests, disable=disable_tqdm):
res.append(-random.random())
return res
......@@ -70,11 +70,13 @@ class GGUFLM(LM):
else:
raise Exception(f"Failed to get a valid response after {retries} retries.")
def loglikelihood(self, requests):
def loglikelihood(self, requests, disable_tqdm: bool = False):
if not requests:
return []
res = []
for context, continuation in tqdm([req.args for req in requests]):
for context, continuation in tqdm(
[req.args for req in requests], disable=disable_tqdm
):
response = self.gguf_completion(context=context, continuation=continuation)
if response and "choices" in response and response["choices"]:
choice = response["choices"][0]
......@@ -97,12 +99,12 @@ class GGUFLM(LM):
assert False
return res
def generate_until(self, requests):
def generate_until(self, requests, disable_tqdm: bool = False):
if not requests:
return []
res = []
for request in tqdm([req.args for req in requests]):
for request in tqdm([req.args for req in requests], disable=disable_tqdm):
inp = request[0]
request_args = request[1]
until = request_args.get("until", ["</s>"])
......@@ -122,7 +124,7 @@ class GGUFLM(LM):
res.append(None) # Add default value in case of error
return res
def loglikelihood_rolling(self, requests):
def loglikelihood_rolling(self, requests, disable_tqdm: bool = False):
raise NotImplementedError(
"loglikelihood_rolling not yet supported for GGUF models"
)
......@@ -790,7 +790,9 @@ class HFLM(TemplateLM):
return logits
def loglikelihood_rolling(self, requests: List[Instance]) -> List[float]:
def loglikelihood_rolling(
self, requests: List[Instance], disable_tqdm: bool = False
) -> List[float]:
loglikelihoods = []
adaptive_batch_size = None
......@@ -801,7 +803,9 @@ class HFLM(TemplateLM):
print(f"Determined Largest batch size: {batch_size}")
adaptive_batch_size = batch_size
for (string,) in tqdm([req.args for req in requests], disable=(self.rank != 0)):
for (string,) in tqdm(
[req.args for req in requests], disable=(disable_tqdm or (self.rank != 0))
):
rolling_token_windows = list(
map(
utils.make_disjoint_window,
......@@ -887,7 +891,7 @@ class HFLM(TemplateLM):
def _lookup_one_token_cont(req: Tuple[Tuple[str, str], List[int], List[int]]):
"""Defines the key to group and lookup one-token continuations"""
# Use with group_by="contexts" (optional)"
# allows for the creation of a lookup, so we can re-use logits in case of one-token continuations.
# allows for the creation of a lookup, so we can reuse logits in case of one-token continuations.
# speeds up some multiple-choice tasks proportionally to the number of choices.
# groups requests by context+continuation[:-1] and infer on one request/group.
return req[-2] + req[-1][:-1]
......@@ -1079,7 +1083,9 @@ class HFLM(TemplateLM):
return re_ord.get_original(res)
def generate_until(self, requests: List[Instance]) -> List[str]:
def generate_until(
self, requests: List[Instance], disable_tqdm: bool = False
) -> List[str]:
res = []
def _collate(req: Tuple[str, dict]):
......@@ -1095,7 +1101,7 @@ class HFLM(TemplateLM):
pbar = tqdm(
total=len(requests),
disable=(self.rank != 0),
disable=(disable_tqdm or (self.rank != 0)),
desc="Running generate_until requests",
)
adaptive_batch_size = None
......@@ -1151,8 +1157,12 @@ class HFLM(TemplateLM):
raise ValueError(
f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}"
)
# add EOS token to stop sequences
eos = self.tok_decode(self.eot_token_id)
if not until:
until = [self.tok_decode(self.eot_token_id)]
until = [eos]
else:
until.append(eos)
if "max_gen_toks" in kwargs.keys():
max_gen_toks = kwargs.pop("max_gen_toks")
else:
......
......@@ -447,12 +447,14 @@ class NEURON_HF(TemplateLM):
return logits
def loglikelihood_rolling(self, requests):
def loglikelihood_rolling(self, requests, disable_tqdm: bool = False):
loglikelihoods = []
adaptive_batch_size = None
for (string,) in tqdm([req.args for req in requests], disable=(self.rank != 0)):
for (string,) in tqdm(
[req.args for req in requests], disable=(disable_tqdm or (self.rank != 0))
):
rolling_token_windows = list(
map(
utils.make_disjoint_window,
......@@ -616,7 +618,7 @@ class NEURON_HF(TemplateLM):
return re_ord.get_original(res)
def generate_until(self, requests):
def generate_until(self, requests, disable_tqdm: bool = False):
res = defaultdict(list)
re_ords = {}
......@@ -638,7 +640,7 @@ class NEURON_HF(TemplateLM):
# within each set of reqs for given kwargs, we reorder by token length, descending.
re_ords[key] = utils.Reorderer([req.args for req in reqs], _collate)
pbar = tqdm(total=len(requests), disable=(self.rank != 0))
pbar = tqdm(total=len(requests), disable=(disable_tqdm or (self.rank != 0)))
# for each different set of kwargs, we execute all requests, by batch.
for key, re_ord in re_ords.items():
......@@ -666,8 +668,12 @@ class NEURON_HF(TemplateLM):
raise ValueError(
f"Expected `kwargs` to be of type `dict` but got {kwargs}"
)
# add EOS token to stop sequences
eos = self.tok_decode(self.eot_token_id)
if not until:
until = [self.tok_decode(self.eot_token_id)]
until = [eos]
else:
until.append(eos)
if "max_gen_toks" in kwargs.keys():
max_gen_toks = kwargs.pop("max_gen_toks")
else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment