Unverified commit d924ca33 authored by ben, committed by GitHub

Merge pull request #2 from EleutherAI/multigpu-feature-minor-edits

Multigpu feature minor edits
parents 650d3c76 c77fa461
......@@ -3,3 +3,5 @@ env
data/
lm_cache
.idea
*.egg-info/
......@@ -12,6 +12,7 @@ repos:
- id: check-merge-conflict
- id: check-symlinks
- id: check-yaml
args: ['--unsafe']
- id: destroyed-symlinks
- id: detect-private-key
- id: end-of-file-fixer
......
......@@ -45,7 +45,7 @@ python main.py \
--device cuda:0
```
Additional arguments can be provided to the model constructor using the `--model_args` flag. Most notably, this supports the common practice of using the `revisions` feature on the Hub to store partialy trained checkpoints:
Additional arguments can be provided to the model constructor using the `--model_args` flag. Most notably, this supports the common practice of using the `revisions` feature on the Hub to store partially trained checkpoints:
```bash
python main.py \
......@@ -64,8 +64,8 @@ To use with [PEFT](https://github.com/huggingface/peft), take the call you would
python main.py \
--model hf-causal \
--model_args pretrained=EleutherAI/gpt-j-6b,peft=nomic-ai/gpt4all-j-lora \
--tasks openbookqa,arc_easy,winogrande,hellaswag,arc_challenge,piqa,boolq \
--device cuda:0
--tasks openbookqa,arc_easy,winogrande,hellaswag,arc_challenge,piqa,boolq \
--device cuda:0
```
Our library also supports the OpenAI API:
......@@ -78,7 +78,7 @@ python main.py \
--tasks lambada_openai,hellaswag
```
While this functionality is only officially mantained for the official OpenAI API, it tends to also work for other hosting services that use the same API such as [goose.ai](goose.ai) with minor modification. We also have an implementation for the [TextSynth](https://textsynth.com/index.html) API, using `--model textsynth`.
While this functionality is only officially maintained for the official OpenAI API, it tends to also work for other hosting services that use the same API such as [goose.ai](goose.ai) with minor modification. We also have an implementation for the [TextSynth](https://textsynth.com/index.html) API, using `--model textsynth`.
To verify the data integrity of the tasks you're performing in addition to running the tasks themselves, you can use the `--check_integrity` flag:
......@@ -129,7 +129,7 @@ When reporting eval harness results, please also report the version of each task
## Test Set Decontamination
To address concerns about train / test contamination, we provide utilities for comparing results on a benchmark using only the data points nto found in the model trainign set. Unfortunately, outside of models trained on the Pile ans C4, its very rare that people who train models disclose the contents of the training data. However this utility can be useful to evaluate models you have trained on private data, provided you are willing to pre-compute the necessary indices. We provide computed indices for 13-gram exact match deduplication against the Pile, and plan to add additional precomputed dataset indices in the future (including C4 and min-hash LSH deduplication).
To address concerns about train / test contamination, we provide utilities for comparing results on a benchmark using only the data points not found in the model training set. Unfortunately, outside of models trained on the Pile and C4, it's very rare that people who train models disclose the contents of the training data. However, this utility can be useful to evaluate models you have trained on private data, provided you are willing to pre-compute the necessary indices. We provide computed indices for 13-gram exact match deduplication against the Pile, and plan to add additional precomputed dataset indices in the future (including C4 and min-hash LSH deduplication).
For details on text decontamination, see the [decontamination guide](./docs/decontamination.md).
......
from . import metrics
METRIC_REGISTRY = {
"matthews_corrcoef": metrics.matthews_corrcoef,
"f1_score": metrics.f1_score,
"perplexity": metrics.perplexity,
"bleu": metrics.bleu,
"chrf": metrics.chrf,
"ter": metrics.ter,
}
AGGREGATION_REGISTRY = {
"mean": metrics.mean,
"median": metrics.median,
"perplexity": metrics.perplexity,
}
HIGHER_IS_BETTER_REGISTRY = {
"matthews_corrcoef": True,
"f1_score": True,
"perplexity": False,
"bleu": True,
"chrf": True,
"ter": False,
"acc": True,
"acc_norm": True,
"acc_mutual_info": True,
"word_perplexity": False,
"byte_perplexity": False,
"bits_per_byte": False,
}
\ No newline at end of file
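For orientation, a minimal self-contained sketch (names and scores are illustrative, not library code) of how these registries are meant to be consulted together: the metric registry maps a metric name to its scoring function, the aggregation registry maps a name to the function that reduces per-document scores, and the higher-is-better registry records the direction of improvement.

```python
# Minimal sketch, not library code: toy registries mirroring a subset of the
# ones above, consulted to aggregate per-document scores and report direction.
from statistics import mean, median

AGGREGATION_REGISTRY = {"mean": mean, "median": median}
HIGHER_IS_BETTER_REGISTRY = {"acc": True, "perplexity": False}

def aggregate(metric_name, aggregation_name, per_doc_scores):
    agg = AGGREGATION_REGISTRY[aggregation_name]
    direction = "higher" if HIGHER_IS_BETTER_REGISTRY[metric_name] else "lower"
    return agg(per_doc_scores), f"{direction} is better"

print(aggregate("acc", "mean", [1.0, 0.0, 1.0]))  # (0.666..., 'higher is better')
```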
......@@ -3,9 +3,10 @@ from typing import List
from lm_eval.api.instance import Instance
class Filter:
"""
Filter classes operate on a per-task level.
Filter classes operate on a per-task level.
They take all model outputs (`instance.resps` for all `task.instances`)
across all instances of a task, and perform operations.
In a single run, one can configure any number of separate filters or lists of filters.
......@@ -25,30 +26,33 @@ class Filter:
[<filtered resps for instance 0>, <filtered resps for instance 1>]
"""
return resps
@dataclass
class FilterEnsemble:
"""
FilterEnsemble creates a pipeline applying multiple filters.
Its intended usage is to stack multiple post-processing steps in order.
`task.apply_filters` should use a list of FilterEnsemble classes that it stores, to apply each
Its intended usage is to stack multiple post-processing steps in order.
`task.apply_filters` should use a list of FilterEnsemble classes that it stores, to apply each
pipeline separately.
"""
name: str
name: str
filters: List[Filter]
def apply(self, instances: List[Instance]):
resps = [inst.resps for inst in instances] # operate just on the model responses
resps = [
inst.resps for inst in instances
] # operate just on the model responses
for f in self.filters:
# apply filters in sequence
out = f.apply(resps)
resps = out # TODO: handle the case where a filter returns multiple "buckets"
resps = (
out # TODO: handle the case where a filter returns multiple "buckets"
)
# add the end results after filtering to filtered_requests of their respective source instances.
# has key `self.name`: each FilterEnsemble applied in a given run should use a different name.
for inst, resp in zip(instances, resps):
inst.filtered_resps[self.name] = resp
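As a rough usage sketch (the subclass, responses, and document are made up, not part of the library), a custom `Filter` can be chained inside a `FilterEnsemble` and applied to a list of `Instance` objects; the filtered outputs land in `inst.filtered_resps` under the ensemble's name:

```python
# Illustrative sketch only: a trivial Filter subclass applied via FilterEnsemble.
from lm_eval.api.filter import Filter, FilterEnsemble
from lm_eval.api.instance import Instance

class StripFilter(Filter):
    def __init__(self):
        pass  # no state needed for this toy filter

    def apply(self, resps):
        # resps holds one list of model responses per instance
        return [[r.strip() for r in inst_resps] for inst_resps in resps]

inst = Instance(
    request_type="greedy_until",
    doc={"question": "2+2?"},
    arguments=("Q: 2+2?\nA:", "\n\n"),
    idx=0,
    metadata=("toy_task", 0, 1),  # (task_name, doc_id, repeats)
)
inst.resps = ["  4  "]  # pretend model output with stray whitespace

FilterEnsemble(name="strip", filters=[StripFilter()]).apply([inst])
print(inst.filtered_resps["strip"])  # ['4']
```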
from dataclasses import dataclass, field
from typing import Literal, Tuple
@dataclass
class Instance:
request_type: str = Literal["loglikelihood", "loglikelihood_rolling", "greedy_until"]
request_type: str = Literal[
"loglikelihood", "loglikelihood_rolling", "greedy_until"
]
doc: dict = None
arguments: tuple = None
idx: int = None
metadata: tuple = Tuple[str, int, int] # TODO: better typehints here
metadata: tuple = Tuple[str, int, int] # TODO: better typehints here
resps: list = field(default_factory=list)
filtered_resps: dict = field(default_factory=dict)
......@@ -19,10 +22,12 @@ class Instance:
def __post_init__(self):
# unpack metadata field
self.task_name, self.doc_id, self.repeats = self.metadata
@property
def args(self):
"""
Returns (string,) where `string` is the string to calculate loglikelihood over
"""
return self.arguments if isinstance(self.arguments, tuple) else (self.arguments,)
return (
self.arguments if isinstance(self.arguments, tuple) else (self.arguments,)
)
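For reference, a small hedged sketch (values are made up) showing what `__post_init__` and the `args` property give you on a constructed `Instance`:

```python
# Illustrative sketch: metadata unpacks into task_name/doc_id/repeats, and
# `args` always yields a tuple even when `arguments` is a bare string.
from lm_eval.api.instance import Instance

inst = Instance(
    request_type="loglikelihood_rolling",
    doc={"text": "The quick brown fox"},
    arguments="The quick brown fox",  # a bare string, not a tuple
    idx=0,
    metadata=("toy_task", 42, 1),
)
print(inst.task_name, inst.doc_id, inst.repeats)  # toy_task 42 1
print(inst.args)                                  # ('The quick brown fox',)
```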
......@@ -10,6 +10,7 @@ import evaluate
AGGREGATION_REGISTRY = {}
METRIC_REGISTRY = {
"acc": None,
"acc_norm": None,
......@@ -18,6 +19,21 @@ METRIC_REGISTRY = {
"byte_perplexity": None,
}
HIGHER_IS_BETTER_REGISTRY = {
"matthews_corrcoef": True,
"f1_score": True,
"perplexity": False,
"bleu": True,
"chrf": True,
"ter": False,
"acc": True,
"acc_norm": True,
"acc_mutual_info": True,
"word_perplexity": False,
"byte_perplexity": False,
"bits_per_byte": False,
}
def register_metric(name):
# TODO: do we want to enforce a certain interface to registered metrics?
......@@ -28,7 +44,7 @@ def register_metric(name):
METRIC_REGISTRY[name] = fn
return fn
return decorate
......@@ -38,12 +54,14 @@ def get_metric(name):
return METRIC_REGISTRY[name]
except KeyError:
# TODO: change this print to logging?
print(f"Could not find registered metric '{name}' in lm-eval, \
searching in HF Evaluate library...")
print(
f"Could not find registered metric '{name}' in lm-eval, \
searching in HF Evaluate library..."
)
try:
metric_object = evaluate.load(name)
return metric_object.compute
except:
except Exception:
raise Warning(
"{} not found in the evaluate library!".format(name),
"Please check https://huggingface.co/evaluate-metric",
......@@ -59,7 +77,7 @@ def register_aggregation(name):
AGGREGATION_REGISTRY[name] = fn
return fn
return decorate
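A hedged sketch of how the registration decorators and `get_metric` above are intended to be used (the toy metric and aggregation are made up; unknown names fall through to `evaluate.load` from the HF Evaluate library):

```python
# Illustrative sketch only: register a toy metric and aggregation, then resolve
# the metric by name from the registry.
from lm_eval.api.metrics import get_metric, register_aggregation, register_metric

@register_metric("exact_match_toy")
def exact_match_toy(prediction, reference):
    return float(prediction.strip() == reference.strip())

@register_aggregation("mean_toy")
def mean_toy(items):
    return sum(items) / len(items)

metric_fn = get_metric("exact_match_toy")  # found in METRIC_REGISTRY, no HF lookup
print(metric_fn("4", " 4 "))               # 1.0
```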
......
......@@ -6,14 +6,15 @@ from lm_eval import utils
MODEL_REGISTRY = {}
def register_model(*names):
# either pass a list or a single alias.
# function receives them as a tuple of strings
def decorate(cls):
for name in names:
assert (
issubclass(cls, LM)
for name in names:
assert issubclass(
cls, LM
), f"Model '{name}' ({cls.__name__}) must extend LM class"
assert (
......@@ -22,7 +23,7 @@ def register_model(*names):
MODEL_REGISTRY[name] = cls
return cls
return decorate
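For reference, a hedged sketch of registering a model backend (the class and alias are made up; the method names mirror the request types used throughout this diff, and the real `LM` interface may require more than is stubbed here):

```python
# Illustrative stub only: registering a dummy LM subclass under a made-up alias.
from lm_eval.api.model import LM, register_model

@register_model("my-dummy-lm")
class DummyLM(LM):
    def loglikelihood(self, requests):
        return [(0.0, False) for _ in requests]  # (logprob, is_greedy) per request

    def loglikelihood_rolling(self, requests):
        return [0.0 for _ in requests]

    def greedy_until(self, requests):
        return ["" for _ in requests]
```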
......
import os
task_registry = {}
group_registry = {}
task2func_index = {}
func2task_index = {}
def register_task(name):
def wrapper(func):
task_registry[name] = func
func2task_index[func.__name__] = name
task2func_index[name] = func.__name__
return func
return wrapper
def register_group(name):
def wrapper(func):
func_name = func2task_index[func.__name__]
if name in group_registry:
group_registry[name].append(func_name)
else:
group_registry[name] = [func_name]
return func
return wrapper
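A hedged sketch of stacking these decorators (the task and group names are illustrative, and the import path is an assumption since this diff does not show the module path). `register_task` must sit closest to the function so it populates `func2task_index` before `register_group` looks the function up:

```python
# Illustrative sketch; the import path is assumed and the names are made up.
# Decorators apply bottom-up, so register_task runs before register_group.
from lm_eval.tasks import register_group, register_task  # path assumed

@register_group("toy_group")
@register_task("toy_task")
def toy_task():
    # in practice this would build and return the task (or its config)
    return {"dataset_path": "toy", "num_fewshot": 0}
```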
class Sampler:
def __init__(self, docs, task, fewshot_indices=None, rnd=None):
self.rnd = rnd
......@@ -12,15 +9,18 @@ class Sampler:
self.delimiter = self.config.delimiter
self.docs = docs # HF dataset split, provided by task._fewshot_docs()
if fewshot_indices: # subset few-shot docs from
self.docs = docs # HF dataset split, provided by task._fewshot_docs()
if fewshot_indices: # subset few-shot docs from
self.docs = self.docs.select(fewshot_indices)
def get_context(self, doc, num_fewshot):
# draw an extra fewshot sample if using same split as evaluting on
n_samples = num_fewshot + 1 if self.config.fewshot_split == self.config.test_split else num_fewshot
# draw an extra fewshot sample if using same split as evaluating on
n_samples = (
num_fewshot + 1
if self.config.fewshot_split == self.config.test_split
else num_fewshot
)
# draw `n_samples` docs from fewshot_docs
fewshotex = self.sample(n_samples)
......@@ -28,16 +28,16 @@ class Sampler:
# get rid of the doc that's the one we're evaluating, if it's in the fewshot
# TODO: should we just stop people from using fewshot from same split as evaluating?
selected_docs = [x for x in fewshotex if x != doc][:num_fewshot]
labeled_examples = (
self.delimiter.join(
[
self.task.doc_to_text(doc) + self.task.doc_to_target(doc)
for doc in selected_docs
]
)
+ self.delimiter
self.delimiter.join(
[
self.task.doc_to_text(doc) + self.task.doc_to_target(doc)
for doc in selected_docs
]
)
+ self.delimiter
)
# only returns the fewshot context! Does not append the document, do this outside the object
return labeled_examples
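To make the flow concrete, a hedged sketch of how a task might drive the sampler (the `task` and `doc` arguments are placeholders; elsewhere in this diff the sampler is built as `samplers.Sampler(list(task.fewshot_docs()), task, rnd=random.Random())`):

```python
# Illustrative sketch only: `task` and `doc` are placeholders for real objects.
import random

from lm_eval.api import samplers

def build_prompt(task, doc, num_fewshot: int = 2) -> str:
    sampler = samplers.Sampler(list(task.fewshot_docs()), task, rnd=random.Random())
    fewshot_ctx = sampler.get_context(doc, num_fewshot)
    # get_context returns only the labeled examples (joined by the delimiter);
    # the current document's own prompt is appended outside the sampler.
    return fewshot_ctx + task.doc_to_text(doc)
```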
......@@ -51,25 +51,22 @@ class Sampler:
class BalancedSampler(Sampler):
def sample(self, n):
"""
TODO: this should return approximately class-balanced samples from our fewshot examples.
TODO: this should return approximately class-balanced samples from our fewshot examples.
TODO: what order should they be in? maybe random?
"""
pass
class ManualSampler(Sampler):
class ManualSampler(Sampler):
def sample(self, n):
"""
"""
pass
""" """
pass
# TODO: how should we do design here? might be better to have a single sampler and pass more kwargs at init.
# TODO: how should we do design here? might be better to have a single sampler and pass more kwargs at init.
# Depends what's easier for new user to add own functionality on top of
# types of sampler:
......
......@@ -3,6 +3,7 @@ from dataclasses import dataclass
import re
import ast
import yaml
import evaluate
import random
import itertools
......@@ -11,35 +12,49 @@ import functools
import datasets
import numpy as np
from typing import List, Union
from typing import Union
from collections.abc import Callable
from lm_eval.api.metrics import METRIC_REGISTRY, AGGREGATION_REGISTRY
from lm_eval.api import HIGHER_IS_BETTER_REGISTRY
from lm_eval.api.instance import Instance
from lm_eval.api.metrics import get_metric, get_aggregation, mean, weighted_perplexity, bits_per_byte
from lm_eval import utils
from lm_eval.api import samplers
from lm_eval.api.instance import Instance
from lm_eval.api.metrics import (
METRIC_REGISTRY,
AGGREGATION_REGISTRY,
HIGHER_IS_BETTER_REGISTRY,
get_metric,
get_aggregation,
mean,
weighted_perplexity,
bits_per_byte,
)
from lm_eval.logger import eval_logger
from lm_eval.prompts import get_prompt
from lm_eval.filters import build_filter_ensemble
from lm_eval.api import samplers
@dataclass
class TaskConfig(dict):
task: str = None
group: str = None
names: str = None
task_name: str = None # TODO: deprecate this, it'll be set in __post_init__ to be names[0]
reference: str = None
task_name: str = (
None # TODO: deprecate this, it'll be set in __post_init__ to be names[0]
)
base_task: str = None
dataset_path: str = None
dataset_name: str = None
training_split: str = None
validation_split: str = None
test_split: str = None
fewshot_split: str = None # TODO: assert that this not None if num_fewshot > 0. (?) assert if this is same split as one evaling (?)
template_aliases: str = ""
doc_to_text: str = ""
doc_to_target: str = ""
fewshot_split: str = None # TODO: assert that this not None if num_fewshot > 0. (?) assert if this is same split as one evaling (?)
template_aliases: str = None
doc_to_text: Union[Callable, str] = None
doc_to_target: Union[Callable, str] = None
num_fewshot: int = 0
batch_size: int = 1
......@@ -49,20 +64,26 @@ class TaskConfig(dict):
gold_alias: str = None
output_type: str = "greedy_until"
delimiter: str = "\n\n"
filters: str = None #TODO: need to make this typehint `list`?
normalization: str = None # TODO: add length-normalization of various types, mutual info
filter_list: Union[str, list] = None
normalization: str = (
None # TODO: add length-normalization of various types, mutual info
)
should_decontaminate: bool = False
doc_to_decontamination_query: str = None
use_prompt: str = None
metadata: str = None # by default, not used in the code. allows for users to pass arbitrary info to tasks
metadata: str = None # by default, not used in the code. allows for users to pass arbitrary info to tasks
def __post_init__(self):
# allow user-specified aliases so that users can
# force prompt-compatibility for some prompt regardless of
# field names in prompt
self.doc_to_text = self.template_aliases + self.doc_to_text
self.doc_to_target = self.template_aliases + self.doc_to_target
if self.template_aliases is not None:
if type(self.doc_to_text) == str:
self.doc_to_text = self.template_aliases + self.doc_to_text
if type(self.doc_to_target) == str:
self.doc_to_target = self.template_aliases + self.doc_to_target
# set "task_name" metadata field based on the "primary" name set
if self.names:
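For orientation, a hedged sketch of building a `TaskConfig` directly; every field value below is illustrative (in practice these configs are usually loaded from YAML and handed to `ConfigurableTask`), and the import path is assumed:

```python
# Illustrative sketch only: field values are made up, import path assumed.
from lm_eval.api.task import TaskConfig

config = TaskConfig(
    names=["toy_boolq"],
    dataset_path="super_glue",
    dataset_name="boolq",
    validation_split="validation",
    test_split="validation",
    doc_to_text="{{passage}}\nQuestion: {{question}}?\nAnswer:",
    doc_to_target=" {{label}}",
    output_type="loglikelihood",
    num_fewshot=0,
    metric_list=[{"metric": "acc", "aggregation": "mean", "higher_is_better": True}],
)
```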
......@@ -83,6 +104,7 @@ class Task(abc.ABC):
"""
VERSION = None
# The name of the `Task` benchmark as denoted in the HuggingFace datasets Hub
# or a path to a custom `datasets` loading script.
DATASET_PATH: str = None
......@@ -91,6 +113,7 @@ class Task(abc.ABC):
DATASET_NAME: str = None
OUTPUT_TYPE: str = None
def __init__(
self,
data_dir=None,
......@@ -129,12 +152,15 @@ class Task(abc.ABC):
if not hasattr(self, "_filters"):
self._filters = []
for name, components in self._config.get("filters", [["none", ["take_first"]]]):
for name, components in self._config.get(
"filters", [["none", ["take_first"]]]
):
filter_pipeline = build_filter_ensemble(name, components)
self._filters.append(filter_pipeline)
self.sampler = samplers.Sampler(self.fewshot_docs(), self, rnd=random.Random()) # TODO: pass the correct docs in here
self.sampler = samplers.Sampler(
list(self.fewshot_docs()), self, rnd=random.Random()
) # TODO: pass the correct docs in here
def download(self, data_dir=None, cache_dir=None, download_mode=None):
"""Downloads and returns the task dataset.
......@@ -215,7 +241,10 @@ class Task(abc.ABC):
elif self.has_validation_docs():
return self.validation_docs()
else:
# TODO: should we allow this case to occur? / should raise a warning here
eval_logger.warning(
"has_training_docs and has_validation_docs are False"
"using test_docs but this is not recommended."
)
return self.test_docs()
def _process_doc(self, doc):
......@@ -268,20 +297,24 @@ class Task(abc.ABC):
), f"Task dataset (path={self.DATASET_PATH}, name={self.DATASET_NAME}) must have valid or test docs!"
instances = []
for doc_id, doc in utils.create_iterator(enumerate(docs), rank, world_size, limit):
for doc_id, doc in utils.create_iterator(
enumerate(docs), rank, world_size, limit
):
# sample fewshot context #TODO: need to offset doc_id by rank now!
fewshot_ctx = self.fewshot_context(
doc, self._config.num_fewshot, rnd=random.Random()
)
# TODO: hardcoded for now: # of runs on each input to be 2. # TODO: we should override this if doing greedy gen so users don't waste time+compute
inst = self.construct_requests(doc=doc, ctx=fewshot_ctx, metadata=(self._config["task_name"], doc_id, 1))
inst = self.construct_requests(
doc=doc,
ctx=fewshot_ctx,
metadata=(self._config["task_name"], doc_id, self._config.repeats),
)
if not isinstance(inst, list):
inst = [inst]
instances.extend(inst)
self._instances = instances
assert len(self._instances) != 0, "task.build_requests() did not find any docs!"
......@@ -302,7 +335,7 @@ class Task(abc.ABC):
whichever is the main split used.
:param repeats: int
TODO: update this docstring
The number of times each instance in a dataset is inferred on. Defaults to 1,
The number of times each instance in a dataset is inferred on. Defaults to 1,
can be increased for techniques like majority voting.
"""
pass
......@@ -370,7 +403,6 @@ class Task(abc.ABC):
if num_fewshot == 0:
labeled_examples = ""
else:
labeled_examples = self.sampler.get_context(doc, self._config.num_fewshot)
# for sets with no training docs, draw from other set *but ensure no overlap with current doc*
......@@ -417,13 +449,21 @@ class ConfigurableTask(Task):
def __init__(
self, data_dir=None, cache_dir=None, download_mode=None, config: dict = None
):
# if we are a subclass that has the CONFIG class attr set, ignore whatever is passed.
# Get pre-configured attributes
self._config = self.CONFIG
# else, if a config was passed as kwarg: use it
if (self._config is None) and config:
# Use new configurations if there was no preconfiguration
if self._config is None:
self._config = TaskConfig(**config)
# Overwrite configs
else:
if config is not None:
self._config.__dict__.update(config)
if self._config is None:
raise ValueError("Must pass a config to ConfigurableTask, either in cls.CONFIG or `config` kwarg")
raise ValueError(
"Must pass a config to ConfigurableTask, either in cls.CONFIG or `config` kwarg"
)
if self._config.output_type is not None:
self.OUTPUT_TYPE = self._config.output_type
......@@ -441,16 +481,22 @@ class ConfigurableTask(Task):
self._higher_is_better = {}
for metric_config in self._config.metric_list:
metric_name = metric_config['metric']
aggregation = metric_config['aggregation']
higher_is_better = metric_config['higher_is_better']
kwargs = {key: metric_config[key] for key in metric_config if key not in ['metric', 'aggregation', 'higher_is_better']}
metric_name = metric_config["metric"]
aggregation = metric_config["aggregation"]
higher_is_better = metric_config["higher_is_better"]
kwargs = {
key: metric_config[key]
for key in metric_config
if key not in ["metric", "aggregation", "higher_is_better"]
}
self._aggregation_list[metric_name] = AGGREGATION_REGISTRY[aggregation]
if metric_name in METRIC_REGISTRY.keys():
self._metric_list[metric_name] = METRIC_REGISTRY[metric_name]
self._higher_is_better[metric_name] = HIGHER_IS_BETTER_REGISTRY[metric_name]
self._higher_is_better[metric_name] = HIGHER_IS_BETTER_REGISTRY[
metric_name
]
else:
self._higher_is_better[metric_name] = higher_is_better
try:
......@@ -458,7 +504,7 @@ class ConfigurableTask(Task):
self._metric_list[metric_name] = metric_object
self._metric_kwargs[metric_name] = kwargs
except Exception as ex:
except Exception:
raise Warning(
"{} not found in the evaluate library!".format(metric_name),
"Please check https://huggingface.co/evaluate-metric",
......@@ -468,13 +514,38 @@ class ConfigurableTask(Task):
self._training_docs = None
self._fewshot_docs = None
self._filters = []
for name, components in self._config.get("filters", [["none", ["take_first"]]]):
filter_pipeline = build_filter_ensemble(name, components)
if self._config.filter_list is not None:
for filter_config in self._config.filter_list:
for filter_pipeline in filter_config:
filter_name = filter_config["name"]
filter_functions = filter_config["filter"]
components = []
for function in filter_functions:
kwargs = {
key: function[key] for key in function if key != "function"
}
components.append([function["function"], kwargs])
filter_pipeline = build_filter_ensemble(filter_name, components)
self._filters.append(filter_pipeline)
self.sampler = samplers.Sampler(list(self.fewshot_docs()), self, rnd=random.Random()) # TODO: pass the correct docs in here
else:
self._filters = [
build_filter_ensemble("take_first", [["take_first", None]])
]
if self._config.use_prompt is not None:
eval_logger.info(f"loading prompt {self._config.use_prompt}")
self.prompt = get_prompt(
self._config.use_prompt, self.DATASET_PATH, self.DATASET_NAME
)
else:
self.prompt = None
if self.fewshot_docs() is not None:
self.sampler = samplers.Sampler(
list(self.fewshot_docs()), self, rnd=random.Random()
) # TODO: pass the correct docs in here
def has_training_docs(self):
if self._config.training_split is not None:
......@@ -507,12 +578,16 @@ class ConfigurableTask(Task):
return self.dataset[self._config.test_split]
def fewshot_docs(self):
if self._config.fewshot_split:
return self.dataset[self._config.fewshot_split]
else:
# TODO: warn user if fewshot split isn't explicitly set
if (self._config.num_fewshot > 0) and (self._config.fewshot_split is None):
eval_logger.warning(
"num_fewshot > 0 but fewshot_split is None. "
"using preconfigured rule."
)
return super().fewshot_docs()
elif self._config.fewshot_split is not None:
return self.dataset[self._config.fewshot_split]
def should_decontaminate(self):
return self._config.should_decontaminate
......@@ -532,67 +607,90 @@ class ConfigurableTask(Task):
return doc
def doc_to_text(self, doc):
if self._config.use_prompt is not None:
doc_to_text = get_prompt(self._config.use_prompt)
if self.prompt is not None:
doc_to_text = self.prompt
else:
doc_to_text = self._config.doc_to_text
return utils.apply_template(doc_to_text, doc)
if type(doc_to_text) == str:
return utils.apply_template(doc_to_text, doc)
elif callable(doc_to_text):
return doc_to_text(doc)
if hasattr(doc_to_text, "apply"):
return doc_to_text.apply(doc)[0]
else:
print(type(doc_to_text))
raise TypeError
def doc_to_target(self, doc):
return utils.apply_template(self._config.doc_to_target, doc)
if self.prompt is not None:
doc_to_target = self.prompt
else:
doc_to_target = self._config.doc_to_target
if type(doc_to_target) == str:
return utils.apply_template(doc_to_target, doc)
elif callable(doc_to_target):
return doc_to_target(doc)
elif hasattr(doc_to_target, "apply"):
return doc_to_target.apply(doc)[1]
else:
raise TypeError
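The dispatch above accepts three shapes for `doc_to_text` / `doc_to_target`: a Jinja-style template string, a plain callable, or a promptsource template object exposing `.apply`. A small self-contained sketch mirroring that logic (the document, templates, and toy renderer are made up):

```python
# Illustrative sketch mirroring the dispatch above; the string branch uses a toy
# one-variable renderer instead of utils.apply_template.
def render_doc_to_text(doc_to_text, doc):
    if isinstance(doc_to_text, str):
        return doc_to_text.replace("{{question}}", doc["question"])  # toy renderer
    elif callable(doc_to_text):
        return doc_to_text(doc)
    elif hasattr(doc_to_text, "apply"):
        # promptsource templates: apply() returns [rendered_text, rendered_target]
        return doc_to_text.apply(doc)[0]
    raise TypeError(f"unsupported doc_to_text type: {type(doc_to_text)}")

doc = {"question": "What is 2+2?"}
print(render_doc_to_text("Question: {{question}}\nAnswer:", doc))
print(render_doc_to_text(lambda d: f"Q: {d['question']} A:", doc))
```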
def construct_requests(self, doc, ctx, **kwargs):
if self.OUTPUT_TYPE == "loglikelihood":
arguments=(ctx, self.doc_to_target(doc))
arguments = (ctx, self.doc_to_target(doc))
elif self.OUTPUT_TYPE == "loglikelihood_rolling":
arguments=(self.doc_to_target(doc),)
arguments = (self.doc_to_target(doc),)
elif self.OUTPUT_TYPE == "multiple_choice":
# we pass the user-defined answer_choices var (in aliases) and translate the result to a Python list.
# TODO: any cleaner way to do this?
choices = ast.literal_eval(utils.apply_template(self._config.template_aliases + "{{answer_choices}}", doc))
choices = ast.literal_eval(
utils.apply_template(
self._config.template_aliases + "{{answer_choices}}", doc
)
)
request_list = [
Instance(
request_type="loglikelihood",
doc=doc,
doc=doc,
arguments=(ctx, " {}".format(choice)),
idx=i,
**kwargs,
)
for i, choice in enumerate(choices)
for i, choice in enumerate(choices)
]
# TODO: we should raise a warning telling users this will at most ~2x runtime.
if "acc_mutual_info" in self._metric_list.keys():
# if we are calculating multiple choice accuracy
# using mutual information instead of raw loglikelihood as metric, need unconditional lls.
# here mutual info refers to calculating
# here mutual info refers to calculating
# log(P(choice|ctx) / P(choice)) = log(P(choice|ctx)) - log(P(choice))
# in other words normalizing by subtracting the unconditional logprob of each choice.
request_list.extend(
[
Instance(
request_type="loglikelihood",
doc=doc,
doc=doc,
arguments=("", "{}".format(choice)),
idx=i,
**kwargs,
)
for i, choice in enumerate(choices)
for i, choice in enumerate(choices)
]
)
return request_list
elif self.OUTPUT_TYPE == "greedy_until":
arguments=(ctx, self._config.delimiter)
arguments = (ctx, self._config.delimiter)
return Instance(
request_type=self.OUTPUT_TYPE,
doc=doc,
arguments=arguments,
idx=0,
**kwargs
)
request_type=self.OUTPUT_TYPE, doc=doc, arguments=arguments, idx=0, **kwargs
)
def process_results(self, doc, results):
......@@ -611,11 +709,20 @@ class ConfigurableTask(Task):
"bits_per_byte": (loglikelihood, bytes_),
}
elif self.OUTPUT_TYPE == "multiple_choice":
lls = [res[0] for res in results] # only retain loglikelihoods, discard is_greedy
lls = [
res[0] for res in results
] # only retain loglikelihoods, discard is_greedy
gold = int(self.doc_to_target(doc))
# retrieve choices in List[str] form, to compute choice lengths, etc.
choices = ast.literal_eval(utils.apply_template(self._config.template_aliases + "{{answer_choices}}", doc))
if 2 * len(choices) == len(lls) and "acc_mutual_info" in self._metric_list.keys():
choices = ast.literal_eval(
utils.apply_template(
self._config.template_aliases + "{{answer_choices}}", doc
)
)
if (
2 * len(choices) == len(lls)
and "acc_mutual_info" in self._metric_list.keys()
):
# then we are doing mutual info.
# this stores the "dryrun" / unconditional answer loglikelihoods
lls_unconditional = lls[1::2]
......@@ -636,12 +743,16 @@ class ConfigurableTask(Task):
if "exact_match" in self._metric_list.keys():
# TODO: this gets score of 0 on arc_challenge for pythia-70m. need to test that this works properly
is_greedy = [res[1] for res in results] # take only the `is_greedy` results
is_greedy = is_greedy[gold] # take value for the gold answer
is_greedy = [
res[1] for res in results
] # take only the `is_greedy` results
is_greedy = is_greedy[gold] # take value for the gold answer
result_dict["exact_match"] = int(is_greedy)
if "acc_mutual_info" in self._metric_list.keys():
lls_mutual_info = [ll_c - ll_u for ll_c, ll_u in zip(lls, lls_unconditional)]
lls_mutual_info = [
ll_c - ll_u for ll_c, ll_u in zip(lls, lls_unconditional)
]
acc_mutual_info = 1.0 if np.argmax(lls_mutual_info) == gold else 0.0
result_dict["acc_mutual_info"] = acc_mutual_info
......@@ -654,15 +765,14 @@ class ConfigurableTask(Task):
for key, result in zip(self._metric_list.keys(), results):
_dict = self._metric_list[key].compute(
references=[gold],
predictions=[result],
**self._metric_kwargs[key]
references=[gold], predictions=[result], **self._metric_kwargs[key]
)
result_dict[key] = _dict[key]
else:
raise ValueError(f"Passed invalid output_type '{self.OUTPUT_TYPE}' ! Please use one of ",
"'loglikelihood', 'loglikelihood_rolling', 'greedy_until'"
raise ValueError(
f"Passed invalid output_type '{self.OUTPUT_TYPE}' ! Please use one of ",
"'loglikelihood', 'loglikelihood_rolling', 'greedy_until'",
)
return result_dict
......@@ -683,17 +793,21 @@ class MultipleChoiceTask(Task):
def construct_requests(self, doc, ctx, **kwargs):
# TODO: add mutual info here?
return [Instance(
return [
Instance(
request_type="loglikelihood",
doc=doc,
doc=doc,
arguments=(ctx, " {}".format(choice)),
idx=i,
**kwargs,
)
for i, choice in enumerate(doc["choices"])]
for i, choice in enumerate(doc["choices"])
]
def process_results(self, doc, results):
results = [res[0] for res in results] # only retain loglikelihoods, discard is_greedy TODO: do we need is_greedy anywhere?
results = [
res[0] for res in results
] # only retain loglikelihoods, discard is_greedy TODO: do we need is_greedy anywhere?
gold = doc["gold"]
acc = 1.0 if np.argmax(results) == gold else 0.0
......@@ -718,7 +832,7 @@ class MultipleChoiceTask(Task):
}
class PerplexityTask(Task, abc.ABC):
class PerplexityTask(Task):
OUTPUT_TYPE = "loglikelihood_rolling"
......@@ -729,9 +843,7 @@ class PerplexityTask(Task, abc.ABC):
assert k == 0
return []
def fewshot_context(
self, doc, num_fewshot, rnd=None
):
def fewshot_context(self, doc, num_fewshot, rnd=None):
assert (
num_fewshot == 0
), "The number of fewshot examples must be 0 for perplexity tasks."
......@@ -760,7 +872,13 @@ class PerplexityTask(Task, abc.ABC):
def construct_requests(self, doc, ctx, **kwargs):
assert not ctx
return Instance(request_type=self.OUTPUT_TYPE, doc=doc, arguments=(self.doc_to_target(doc),), idx=0, **kwargs)
return Instance(
request_type=self.OUTPUT_TYPE,
doc=doc,
arguments=(self.doc_to_target(doc),),
idx=0,
**kwargs,
)
def process_results(self, doc, results):
(loglikelihood,) = results
......@@ -787,118 +905,3 @@ class PerplexityTask(Task, abc.ABC):
def count_words(cls, doc):
"""Downstream tasks with custom word boundaries should override this!"""
return len(re.split(r"\s+", doc))
# TODO: confirm we want this to go in this file
TASK_REGISTRY = {}
ALL_TASKS = []
def register_task(*names):
# either pass a list or a single alias.
# function receives them as a tuple of strings
def decorate(cls):
for name in names:
assert (
issubclass(cls, Task)
), f"Task '{name}' ({cls.__name__}) must extend Task class"
assert (
name not in TASK_REGISTRY
), f"Task named '{name}' conflicts with existing task! Please register with a non-conflicting alias instead."
TASK_REGISTRY[name] = cls
ALL_TASKS = sorted(list(TASK_REGISTRY)) # TODO: this doesn't seem to import right.
return cls
return decorate
def register_yaml_task(yaml_path):
# same goal as register_task() but used to register yamls
import yaml
with open(yaml_path, "r") as f:
config = yaml.load(f, yaml.Loader)
from functools import partial
# TODO: strip whitespace from name?
# TODO: ensure num_fewshot overrides the config vals
def decorate(names, cls):
for name in names:
assert (
issubclass(cls, Task)
), f"Task '{name}' ({cls.__name__}) must extend Task class"
assert (
name not in TASK_REGISTRY
), f"Task named '{name}' conflicts with existing task! Please register with a non-conflicting alias instead."
TASK_REGISTRY[name] = cls
ALL_TASKS = sorted(list(TASK_REGISTRY)) # TODO: this doesn't seem to import properly.
return cls
# we create a subclass that has subclass attr CONFIG = our yaml config, and decorate with the config's specified aliases
names = config['names']
yaml_task = decorate(
names,
type(config['names'][0] + 'ConfigurableTask', (ConfigurableTask,), {'CONFIG': TaskConfig(**config)})
)
##### Task registry utils and setup.
# ALL_TASKS = sorted(list(TASK_REGISTRY))
def get_task(task_name):
try:
return TASK_REGISTRY[task_name]
except KeyError:
print("Available tasks:")
pprint(TASK_REGISTRY)
raise KeyError(f"Missing task {task_name}")
def get_task_name_from_object(task_object):
for name, class_ in TASK_REGISTRY.items():
if class_ is task_object:
return name
# TODO: scrap this
# this gives a mechanism for non-registered tasks to have a custom name anyways when reporting
return (
task_object.EVAL_HARNESS_NAME
if hasattr(task_object, "EVAL_HARNESS_NAME")
else type(task_object).__name__
)
def get_task_name_from_config(task_config):
return "configurable_{dataset_path}_{dataset_name}".format(**task_config)
def get_task_dict(task_name_list: List[Union[str, dict, Task]], num_fewshot=None): # TODO: pass num_fewshot and other cmdline overrides in a better way
task_name_dict = {
task_name: get_task(task_name)(config={"num_fewshot": num_fewshot if num_fewshot else 0, "task_name": task_name})
for task_name in task_name_list
if isinstance(task_name, str)
}
task_name_from_config_dict = {
get_task_name_from_config(task_config): ConfigurableTask(
config=task_config
)
for task_config in task_name_list
if isinstance(task_config, dict)
}
task_name_from_object_dict = {
get_task_name_from_object(task_object): task_object
for task_object in task_name_list
if isinstance(task_object, Task)
}
assert set(task_name_dict.keys()).isdisjoint(set(task_name_from_object_dict.keys()))
return {
**task_name_dict,
**task_name_from_config_dict,
**task_name_from_object_dict,
}
\ No newline at end of file
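A hedged sketch of calling `get_task_dict` with a mix of registered task names and an ad-hoc config dict (the task name is assumed to be registered, and the config values are illustrative; elsewhere in this diff the evaluator reaches it as `lm_eval.tasks.get_task_dict`):

```python
# Illustrative sketch only: "lambada_openai" is assumed to be registered, and the
# inline config values are examples. The dict entry becomes a ConfigurableTask
# keyed as "configurable_super_glue_boolq".
import lm_eval.tasks

task_dict = lm_eval.tasks.get_task_dict(
    [
        "lambada_openai",
        {
            "dataset_path": "super_glue",
            "dataset_name": "boolq",
            "output_type": "loglikelihood",
        },
    ],
    num_fewshot=0,
)
```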
import collections
import random
import itertools
import collections
import torch
import numpy as np
import random
import lm_eval.api
import lm_eval.api.metrics
import lm_eval.models
import lm_eval.tasks
import lm_eval.api
from lm_eval.utils import positional_deprecated, run_task_tests, make_table, create_iterator
import torch
import lm_eval.models
from lm_eval.utils import (
positional_deprecated,
run_task_tests,
make_table,
create_iterator,
get_git_commit_hash,
)
from lm_eval.logger import eval_logger
@positional_deprecated
def simple_evaluate(
......@@ -65,7 +79,7 @@ def simple_evaluate(
assert isinstance(model, lm_eval.api.model.LM)
lm = model
task_dict = lm_eval.api.task.get_task_dict(tasks, num_fewshot=num_fewshot)
task_dict = lm_eval.tasks.get_task_dict(tasks, num_fewshot=num_fewshot)
if check_integrity:
run_task_tests(task_list=tasks)
......@@ -73,7 +87,6 @@ def simple_evaluate(
results = evaluate(
lm=lm,
task_dict=task_dict,
num_fewshot=num_fewshot,
limit=limit,
bootstrap_iters=bootstrap_iters,
decontamination_ngrams_path=decontamination_ngrams_path,
......@@ -91,13 +104,12 @@ def simple_evaluate(
"limit": limit,
"bootstrap_iters": bootstrap_iters,
}
results["git_hash"] = get_git_commit_hash()
return results
else:
return None
decontaminate_suffix = "_decontaminate"
......@@ -105,7 +117,6 @@ decontaminate_suffix = "_decontaminate"
def evaluate(
lm,
task_dict,
num_fewshot=0,
limit=None,
bootstrap_iters=100000,
decontamination_ngrams_path=None,
......@@ -126,47 +137,54 @@ def evaluate(
Dictionary of results
"""
decontaminate = decontamination_ngrams_path is not None
# decontaminate = decontamination_ngrams_path is not None
results = collections.defaultdict(dict)
versions = collections.defaultdict(dict)
requests = collections.defaultdict(list)
requests_origin = collections.defaultdict(list)
# requests_origin = collections.defaultdict(list)
docs = {}
# docs = {}
# get lists of each type of request
for task_name, task in task_dict.items():
versions[task_name] = task.VERSION
# deterministically shuffle docs and chop off the first `limit` because sometimes docs are in some kind of order
# task_docs = list(task_doc_func())
# rnd = random.Random()
# rnd.seed(42)
# rnd.shuffle(task_docs)
task.build_all_requests(limit=limit, rank = lm.rank, world_size = lm.world_size)
task.build_all_requests(limit=limit, rank=lm.rank, world_size=lm.world_size)
# aggregate Instances by LM method requested to get output.
reqtype = "loglikelihood" if task.OUTPUT_TYPE == "multiple_choice" else task.OUTPUT_TYPE #TODO: this is hacky, fix in task.py
requests[reqtype].extend(task.instances)
reqtype = (
"loglikelihood"
if task.OUTPUT_TYPE == "multiple_choice"
else task.OUTPUT_TYPE
) # TODO: this is hacky, fix in task.py
requests[reqtype].extend(task.instances)
if lm.world_size > 1:
instances_rnk = torch.tensor(len(task._instances), device = lm.device)
gathered_item = lm.accelerator.gather(instances_rnk).cpu().detach().numpy().tolist()
instances_rnk = torch.tensor(len(task._instances), device=lm.device)
gathered_item = (
lm.accelerator.gather(instances_rnk).cpu().detach().numpy().tolist()
)
# compute number of pseudobatches to pad with (FSDP/DDP require even batches among ranks)
numpad = max(gathered_item) - gathered_item[lm.rank]
### Run LM on inputs, get all outputs ###
# execute each type of request
for reqtype, reqs in requests.items():
print("Running", reqtype, "requests")
eval_logger.info("Running {} requests".format(reqtype))
# create `K` copies of each request `req` based off `K = req.repeats`
cloned_reqs = []
for req in reqs:
cloned_reqs.extend([req] * req.repeats)
if (lm.world_size > 1) and (numpad > 0):
for _ in range(numpad):
cloned_reqs.extend([req] * req.repeats)
......@@ -186,9 +204,8 @@ def evaluate(
for task_name, task in task_dict.items():
task.apply_filters()
### Collect values of metrics on all datapoints ###
# TODO: make metric configurable, add metric registry
# TODO: make metric configurable, add metric registry
vals = collections.defaultdict(list)
# unpack results and sort back in order and return control to Task
......@@ -196,55 +213,70 @@ def evaluate(
# calculate values for each filter setup (TODO: make getting list of keys cleaner)
# TODO: make it possible to use a different metric per key
for key in task.instances[0].filtered_resps.keys():
doc_iterator = itertools.islice(enumerate(task.test_docs()), lm.rank, limit, lm.world_size) if task.has_test_docs() else itertools.islice(enumerate(task.validation_docs()), lm.rank, limit, lm.world_size)
doc_iterator = (
itertools.islice(
enumerate(task.test_docs()), lm.rank, limit, lm.world_size
)
if task.has_test_docs()
else itertools.islice(
enumerate(task.validation_docs()), lm.rank, limit, lm.world_size
)
)
for doc_id, doc in doc_iterator:
# subset instances to only this document id ; sort by idx
requests = list(filter(lambda x: x.doc_id == doc_id, task.instances))
requests.sort(key=lambda x: x.idx)
metrics = task.process_results(doc, [req.filtered_resps[key] for req in requests])
metrics = task.process_results(
doc, [req.filtered_resps[key] for req in requests]
)
for metric, value in metrics.items():
vals[(task_name, key, metric)].append(value)
if lm.world_size > 1:
# if multigpu, then gather data across all ranks
# if multigpu, then gather data across all ranks
vals_torch = collections.defaultdict(list)
for (task_name, key, metric), items in vals.items():
numitem = 0
numitem = 0
if type(items[0]) == tuple:
numitem = len(items[0])
numitem = len(items[0])
# distributed gather requires all ranks to have same dimensions
# so we pad out with float32 min value
pad_value = torch.finfo(torch.float32).min
metrics_tensor = torch.tensor(items, device = lm.device)
original_dtype = metrics_tensor.dtype # store original dtype
torch_device_tensor = lm.accelerator.pad_across_processes(metrics_tensor.to(torch.float32), pad_index = pad_value)
metrics_tensor = torch.tensor(items, device=lm.device)
original_dtype = metrics_tensor.dtype # store original dtype
torch_device_tensor = lm.accelerator.pad_across_processes(
metrics_tensor.to(torch.float32), pad_index=pad_value
)
gathered_item = lm.accelerator.gather(torch_device_tensor)
if numitem > 0:
gathered_filtered = gathered_item[gathered_item[:,0] != pad_value]
gathered_filtered = gathered_item[gathered_item[:, 0] != pad_value]
else:
gathered_filtered = gathered_item[gathered_item != pad_value]
gathered_item = gathered_filtered.to(original_dtype).cpu().detach().numpy().tolist()
gathered_item = (
gathered_filtered.to(original_dtype).cpu().detach().numpy().tolist()
)
# reconvert if we were passed a tuple of values
if numitem > 0:
gathered_item = [tuple(g) for g in gathered_item]
if lm.rank == 0:
vals_torch[(task_name, key, metric)] = gathered_item
vals = vals_torch
vals = vals_torch
if lm.rank == 0:
### Aggregate results over all datapoints ###
# aggregate results ; run bootstrap CIs
for (task_name, key, metric), items in vals.items():
task = task_dict[task_name]
results[task_name][metric + " - filter=" + key] = task.aggregation()[metric](items)
results[task_name][metric + " - filter=" + key] = task.aggregation()[
metric
](items)
# hotfix: bleu, chrf, ter seem to be really expensive to bootstrap
# so we run them less iterations. still looking for a cleaner way to do this
......@@ -257,7 +289,9 @@ def evaluate(
)
if stderr is not None:
results[task_name][metric + " - filter=" + key + "_stderr"] = stderr(items)
results[task_name][metric + " - filter=" + key + "_stderr"] = stderr(
items
)
return {"results": dict(results), "versions": dict(versions)}
......
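Tying the evaluator together, a hedged sketch of driving an evaluation from Python. The keyword names follow the `simple_evaluate` signature touched in this diff; the module paths (evaluator, `HFLM`) and constructor kwargs are assumptions, and the model/task choices are examples:

```python
# Illustrative sketch only: module paths are assumed and model/task choices are
# examples. simple_evaluate accepts an already-constructed LM instance.
import lm_eval.evaluator
from lm_eval.models.gpt2 import HFLM  # import path assumed

lm = HFLM(pretrained="EleutherAI/pythia-70m", device="cuda", batch_size=1)

results = lm_eval.evaluator.simple_evaluate(
    model=lm,
    tasks=["lambada_openai", "hellaswag"],
    num_fewshot=0,
    limit=10,              # small subset for a quick smoke test
    bootstrap_iters=1000,
    check_integrity=False,
)
print(results["results"])
```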
......@@ -6,10 +6,10 @@ from . import extraction
FILTER_REGISTRY = {
"take_first": selection.TakeFirstFilter,
"regex": extraction.RegexFilter,
# TODO: implement this filter. either it should take in an arbitrary "scoring"/reward function
# that takes an input and returns a scalar and then should select the max reward,
# TODO: implement this filter. either it should take in an arbitrary "scoring"/reward function
# that takes an input and returns a scalar and then should select the max reward,
# or should implement different filters for different ways of handling a reward model's inference.
#"arg_max": selection.ArgMaxFilter,
# "arg_max": selection.ArgMaxFilter,
}
......@@ -17,16 +17,19 @@ def get_filter(filter_name):
return FILTER_REGISTRY[filter_name]
def build_filter_ensemble(name, components):
def build_filter_ensemble(filter_name, components):
"""
Create a filtering pipeline.
"""
filters = []
for step in components:
for (function, kwargs) in components:
if kwargs is None:
f = get_filter(function)()
else:
# create a filter given its name in the registry
f = get_filter(step)() # TODO: pass kwargs to filters properly
# add the filter as a pipeline step
filters.append(f)
f = get_filter(function)(**kwargs) # TODO: pass kwargs to filters properly
# add the filter as a pipeline step
filters.append(f)
return FilterEnsemble(name=name, filters=filters)
return FilterEnsemble(name=filter_name, filters=filters)
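A hedged sketch of the new component format: each entry is a `(registry_name, kwargs_or_None)` pair, so `RegexFilter` can receive its keyword arguments while `TakeFirstFilter` takes none (the ensemble name and regex are illustrative):

```python
# Illustrative sketch: build a two-step pipeline from the filter registry.
from lm_eval.filters import build_filter_ensemble

ensemble = build_filter_ensemble(
    "get-answer",                                             # arbitrary ensemble name
    [
        ["regex", {"regex_pattern": r"#### (\-?[0-9\.\,]+)"}],  # RegexFilter(**kwargs)
        ["take_first", None],                                   # TakeFirstFilter()
    ],
)
# ensemble.apply(instances) then writes the filtered responses into
# inst.filtered_resps["get-answer"] for each instance.
```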
......@@ -4,7 +4,7 @@ from lm_eval.api.filter import Filter
class DecontaminationFilter(Filter):
"""
A filter which evaluates
A filter which evaluates
"""
name = "track_decontamination"
......@@ -12,7 +12,7 @@ class DecontaminationFilter(Filter):
def __init__(self, path):
"""
TODO: make sure only ever run one time on the train set (should this be cached as a class var? keyed by value for "path").
TODO: make sure only ever run one time on the train set (should this be cached as a class var? keyed by value for "path").
should further cache result on a given (task_name, doc_id)
"""
self._decontam_results = None
......@@ -21,4 +21,4 @@ class DecontaminationFilter(Filter):
"""
Return {"no_contamination", "only_contamination"} keys for the 2 different subsets
"""
pass
\ No newline at end of file
pass
......@@ -4,24 +4,20 @@ from lm_eval.api.filter import Filter
class RegexFilter(Filter):
"""
""" """
"""
def __init__(self, regex=r"#### (\-?[0-9\.\,]+)", fallback="[invalid]"):
def __init__(self, regex_pattern=r"#### (\-?[0-9\.\,]+)", fallback="[invalid]"):
"""
pass a string `regex` to run `re.compile(r"regex")` on.
`fallback` defines the output returned if no matches for the regex are located.
"""
self.regex_pattern = regex
self.regex = re.compile(regex)
self.regex_pattern = regex_pattern
self.regex = re.compile(regex_pattern)
self.fallback = fallback
def apply(self, resps):
# here, we assume we have a list, in which each element is
# a list of model responses for some particular input/target pair.
# a list of model responses for some particular input/target pair.
# so we process each of these (same input/target response sets)
# independently (and keep them a list.)
def filter_set(inst):
......@@ -30,7 +26,7 @@ class RegexFilter(Filter):
match = self.regex.search(resp)
if match:
match = match.group(1).strip()
match_str.replace(",", "")
match.replace(",", "")
# TODO: should we assume any other filtering is performed?
else:
match = self.fallback
......
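For concreteness, a standalone sketch of what the default pattern extracts from a `#### <answer>`-style response (the response text is made up):

```python
# Illustrative sketch of the default RegexFilter pattern and fallback value.
import re

regex = re.compile(r"#### (\-?[0-9\.\,]+)")  # default regex_pattern above
fallback = "[invalid]"                        # default fallback above

resp = "12 * 3 = 36, and 36 - 4 = 32.\n#### 32"
match = regex.search(resp)
answer = match.group(1).strip() if match else fallback
print(answer)  # 32
```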
from lm_eval.api.filter import Filter
class TakeFirstFilter:
class TakeFirstFilter:
def __init__(self):
"""
Can define custom behavior here, if an individual instantiation of a Filter class should have state.
......@@ -11,4 +11,4 @@ class TakeFirstFilter:
"""
Assuming each entry of `resps` is a list of model responses, we discard all but the first response.
"""
return map(lambda r: r[0], resps)
\ No newline at end of file
return map(lambda r: r[0], resps)
import logging
logging.basicConfig(
format="%(asctime)s,%(msecs)03d %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s",
datefmt="%Y-%m-%d:%H:%M:%S",
level=logging.INFO,
)
eval_logger = logging.getLogger("lm-eval")
......@@ -6,11 +6,13 @@ from tqdm import tqdm
import torch.nn.functional as F
from lm_eval import utils
from lm_eval.logger import eval_logger
from lm_eval.api.model import LM, register_model
from accelerate import Accelerator
from itertools import islice
@register_model("hf-causal", "gpt2")
class HFLM(LM):
def __init__(
......@@ -28,10 +30,10 @@ class HFLM(LM):
assert isinstance(device, str)
assert isinstance(pretrained, str)
assert isinstance(batch_size, int)
gpus = torch.cuda.device_count()
if gpus <= 1:
if device:
if device:
if device not in ["cuda", "cpu"]:
device = int(device)
self._device = torch.device(device)
......@@ -48,7 +50,7 @@ class HFLM(LM):
self._world_size = 1
else:
self._device = 'cpu'
self._device = "cpu"
# TODO: update this to be less of a hack once subfolder is fixed in HF
revision = revision + ("/" + subfolder if subfolder is not None else "")
......@@ -72,10 +74,12 @@ class HFLM(LM):
if gpus > 1:
accelerator = Accelerator()
if gpus > accelerator.num_processes:
warning = ("WARNING: The number of total system GPUs does not match the number of spawned processes. "
"If you would like to use data parallelism, please launch the script "
"with 'accelerate launch *script*'. "
f"Current run will proceed with {accelerator.num_processes} devices.")
warning = (
"WARNING: The number of total system GPUs does not match the number of spawned processes. "
"If you would like to use data parallelism, please launch the script "
"with 'accelerate launch *script*'. "
f"Current run will proceed with {accelerator.num_processes} devices."
)
print(warning)
self._rank = accelerator.local_process_index
self._world_size = accelerator.num_processes
......@@ -90,7 +94,6 @@ class HFLM(LM):
self._rank = self.accelerator.local_process_index
self._world_size = self.accelerator.num_processes
@property
def eot_token_id(self):
......@@ -100,14 +103,16 @@ class HFLM(LM):
@property
def max_length(self):
try:
if hasattr(self, 'accelerator'):
if hasattr(self, "accelerator"):
return self.accelerator.unwrap_model(self.gpt2).config.n_ctx
else:
return self.gpt2.config.n_ctx
except AttributeError:
# gptneoconfig doesn't have n_ctx apparently
if hasattr(self, 'accelerator'):
return self.accelerator.unwrap_model(self.gpt2).config.max_position_embeddings
if hasattr(self, "accelerator"):
return self.accelerator.unwrap_model(
self.gpt2
).config.max_position_embeddings
else:
return self.gpt2.config.max_position_embeddings
......@@ -122,7 +127,7 @@ class HFLM(LM):
@property
def device(self):
return self._device
@property
def rank(self):
return self._rank
......@@ -150,7 +155,11 @@ class HFLM(LM):
def _model_generate(self, context, max_length, eos_token_id):
return self.gpt2.generate(
context, max_length=max_length, pad_token_id=eos_token_id, eos_token_id=eos_token_id, do_sample=False
context,
max_length=max_length,
pad_token_id=eos_token_id,
eos_token_id=eos_token_id,
do_sample=False,
)
def loglikelihood(self, requests):
......@@ -173,7 +182,7 @@ class HFLM(LM):
# TODO: automatic batch size detection for vectorization
loglikelihoods = []
for (string,) in tqdm([req.args for req in requests],disable=(self.rank != 0)):
for (string,) in tqdm([req.args for req in requests], disable=(self.rank != 0)):
rolling_token_windows = list(
map(
utils.make_disjoint_window,
......@@ -185,22 +194,24 @@ class HFLM(LM):
),
)
)
rolling_token_windows = [(None,) + x for x in rolling_token_windows]
# TODO: extract out this call so it only gets called once and also somehow figure out partial caching for
# that
pad_amnt = 0
pad_amnt = 0
if self.world_size > 1:
#TODO: Comment on what we do here
mytensor = torch.tensor(len(rolling_token_windows), device = self.device)
gathered = self.accelerator.gather(mytensor).cpu().detach().numpy().tolist()
# TODO: Comment on what we do here
mytensor = torch.tensor(len(rolling_token_windows), device=self.device)
gathered = (
self.accelerator.gather(mytensor).cpu().detach().numpy().tolist()
)
pad_amnt = max(gathered) - gathered[self.rank]
if pad_amnt > 0:
rolling_token_windows += pad_amnt*[rolling_token_windows[0]]
rolling_token_windows += pad_amnt * [rolling_token_windows[0]]
string_nll = self._loglikelihood_tokens(
rolling_token_windows, disable_tqdm=True
)
......@@ -210,10 +221,9 @@ class HFLM(LM):
else:
# discard is_greedy
string_nll = [x[0] for x in string_nll]
string_nll = sum(string_nll)
loglikelihoods.append(string_nll)
return loglikelihoods
......@@ -235,9 +245,10 @@ class HFLM(LM):
# TODO: automatic (variable) batch size detection for vectorization
re_ord = utils.Reorderer(requests, _collate)
for chunk in utils.chunks(
tqdm(re_ord.get_reordered(), disable=(disable_tqdm or (self.rank != 0))), self.batch_size
):
tqdm(re_ord.get_reordered(), disable=(disable_tqdm or (self.rank != 0))),
self.batch_size,
):
inps = []
cont_toks_list = []
inplens = []
......
import os
import numpy as np
import time
import transformers
from lm_eval.api.model import LM, register_model
from lm_eval import utils
import numpy as np
from tqdm import tqdm
import time
from lm_eval import utils
from lm_eval.api.model import LM, register_model
def get_result(response, ctxlen):
......
from lm_eval.logger import eval_logger
from promptsource.templates import DatasetTemplates
# TODO: decide whether we want jinja2 or f-string prompts. would it be cursed to support both?
# Prompt library.
# Prompt library.
# Stores prompts in a dictionary indexed by 2 levels:
# prompt category name, and prompt name.
# This allows us to access prompts
PROMPT_REGISTRY = {
"qa-basic": {
"question-newline-answer": "Question: {{question}}\nAnswer:",
"q-newline-a": "Q: {question}\nA:"
"q-newline-a": "Q: {{question}}\nA:",
},
}
def get_prompt(prompt_id: str):
# unpack prompt name
try:
category_name, prompt_name = prompt_id.split(":")
except:
raise ValueError(
f"expected only a single `:` as separator between \
prompt category and name, but got `{prompt_id}` instead"
def get_prompt(prompt_id: str, dataset_name=None, subset_name=None):
# unpack prompt name
category_name, prompt_name = prompt_id.split(":")
if subset_name is None:
dataset_full_name = dataset_name
else:
dataset_full_name = f"{dataset_name}-{subset_name}"
eval_logger.info(f"Loading prompt from {category_name} for {dataset_full_name}")
if category_name == "promptsource":
try:
if subset_name is None:
prompts = DatasetTemplates(dataset_name=dataset_name)
else:
prompts = DatasetTemplates(
dataset_name=dataset_name, subset_name=subset_name
)
except Exception:
raise ValueError(f"{dataset_name} and {subset_name} not found")
if prompt_name in prompts.all_template_names:
return prompts[prompt_name]
else:
raise ValueError(
f"{prompt_name} not in prompt list {prompts.all_template_names}"
)
else:
try:
return PROMPT_REGISTRY[category_name][prompt_name]
except Exception:
raise ValueError(
f"expected only a single `:` as separator between \
prompt category and name, but got `{prompt_id}` instead"
)
return PROMPT_REGISTRY[category_name][prompt_name]
\ No newline at end of file
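A hedged sketch of the two lookup paths in `get_prompt`: the built-in registry (a `"category:prompt name"` id) and the promptsource branch. The promptsource template name and the example document fields below are assumptions for illustration:

```python
# Illustrative sketch only: the promptsource template name ("GPT-3 Style") and
# the example document are assumptions, not guaranteed to exist.
from lm_eval.prompts import get_prompt

# 1) built-in registry lookup
template = get_prompt("qa-basic:question-newline-answer")
# -> "Question: {{question}}\nAnswer:"

# 2) promptsource lookup: returns a promptsource template object with .apply()
ps_template = get_prompt(
    "promptsource:GPT-3 Style", dataset_name="super_glue", subset_name="boolq"
)
doc = {"passage": "Grass is green.", "question": "is grass green", "label": 1}
text, target = ps_template.apply(doc)
```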