Unverified commit d924ca33, authored by ben, committed by GitHub

Merge pull request #2 from EleutherAI/multigpu-feature-minor-edits

Multigpu feature minor edits
Parents: 650d3c76, c77fa461
@@ -3,3 +3,5 @@ env
data/
lm_cache
.idea
+*.egg-info/
@@ -12,6 +12,7 @@ repos:
  - id: check-merge-conflict
  - id: check-symlinks
  - id: check-yaml
+   args: ['--unsafe']
  - id: destroyed-symlinks
  - id: detect-private-key
  - id: end-of-file-fixer
...
@@ -45,7 +45,7 @@ python main.py \
    --device cuda:0
```

Additional arguments can be provided to the model constructor using the `--model_args` flag. Most notably, this supports the common practice of using the `revisions` feature on the Hub to store partially trained checkpoints:

```bash
python main.py \
@@ -78,7 +78,7 @@ python main.py \
    --tasks lambada_openai,hellaswag
```

While this functionality is only officially maintained for the official OpenAI API, it also tends to work, with minor modifications, for other hosting services that use the same API, such as [goose.ai](goose.ai). We also have an implementation for the [TextSynth](https://textsynth.com/index.html) API, using `--model textsynth`.

To verify the data integrity of the tasks you're performing in addition to running the tasks themselves, you can use the `--check_integrity` flag:
@@ -129,7 +129,7 @@ When reporting eval harness results, please also report the version of each task

## Test Set Decontamination

To address concerns about train / test contamination, we provide utilities for comparing results on a benchmark using only the data points not found in the model's training set. Unfortunately, outside of models trained on the Pile and C4, it is very rare that people who train models disclose the contents of the training data. However, this utility can be useful to evaluate models you have trained on private data, provided you are willing to pre-compute the necessary indices. We provide computed indices for 13-gram exact match deduplication against the Pile, and plan to add additional precomputed dataset indices in the future (including C4 and min-hash LSH deduplication).

For details on text decontamination, see the [decontamination guide](./docs/decontamination.md).
...
from . import metrics

METRIC_REGISTRY = {
    "matthews_corrcoef": metrics.matthews_corrcoef,
    "f1_score": metrics.f1_score,
    "perplexity": metrics.perplexity,
    "bleu": metrics.bleu,
    "chrf": metrics.chrf,
    "ter": metrics.ter,
}

AGGREGATION_REGISTRY = {
    "mean": metrics.mean,
    "median": metrics.median,
    "perplexity": metrics.perplexity,
}

HIGHER_IS_BETTER_REGISTRY = {
    "matthews_corrcoef": True,
    "f1_score": True,
    "perplexity": False,
    "bleu": True,
    "chrf": True,
    "ter": False,
    "acc": True,
    "acc_norm": True,
    "acc_mutual_info": True,
    "word_perplexity": False,
    "byte_perplexity": False,
    "bits_per_byte": False,
}
\ No newline at end of file
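
For orientation, here is a minimal, self-contained sketch of how registries like these are typically consumed: pick the aggregation for a metric's per-example values, and use the higher-is-better flag to decide which direction counts as an improvement. The registries below are toy stand-ins, not the module above.

```python
from statistics import mean, median

# Toy stand-ins mirroring the registries above (illustrative only).
AGGREGATION_REGISTRY = {"mean": mean, "median": median}
HIGHER_IS_BETTER_REGISTRY = {"acc": True, "word_perplexity": False}


def aggregate(metric_values, aggregation="mean"):
    """Collapse per-example metric values into one task-level score."""
    return AGGREGATION_REGISTRY[aggregation](metric_values)


def best_score(metric_name, scores):
    """Pick the preferred score for a metric, respecting its direction."""
    return (max if HIGHER_IS_BETTER_REGISTRY[metric_name] else min)(scores)


print(aggregate([1, 0, 1, 1]))                     # 0.75
print(best_score("acc", [0.61, 0.75]))             # 0.75 (higher is better)
print(best_score("word_perplexity", [12.3, 9.8]))  # 9.8 (lower is better)
```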
@@ -3,6 +3,7 @@ from typing import List
from lm_eval.api.instance import Instance


class Filter:
    """
    Filter classes operate on a per-task level.
@@ -26,6 +27,7 @@ class Filter:
        """
        return resps


@dataclass
class FilterEnsemble:
    """
@@ -34,21 +36,23 @@ class FilterEnsemble:
    `task.apply_filters` should use a list of FilterEnsemble classes that it stores, to apply each
    pipeline separately.
    """

    name: str
    filters: List[Filter]

    def apply(self, instances: List[Instance]):
        resps = [inst.resps for inst in instances]  # operate just on the model responses
        for f in self.filters:
            # apply filters in sequence
            out = f.apply(resps)
            resps = out  # TODO: handle the case where a filter returns multiple "buckets"

        # add the end results after filtering to filtered_requests of their respective source instances.
        # has key `self.name`: each FilterEnsemble applied in a given run should use a different name.
        for inst, resp in zip(instances, resps):
            inst.filtered_resps[self.name] = resp
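
To make the `Filter` / `FilterEnsemble` contract concrete, here is a self-contained sketch with toy stand-ins (names such as `LowercaseFilter`, `FakeInstance`, and `ToyFilterEnsemble` are illustrative, not part of the harness): a filter maps the per-instance response lists to new response lists, and the ensemble stores the final result under its `name`.

```python
from dataclasses import dataclass, field
from typing import Dict, List


class LowercaseFilter:
    """Toy filter: lowercase every sampled response of every instance."""

    def apply(self, resps: List[List[str]]) -> List[List[str]]:
        return [[r.lower() for r in instance_resps] for instance_resps in resps]


@dataclass
class FakeInstance:
    resps: List[str]
    filtered_resps: Dict[str, list] = field(default_factory=dict)


@dataclass
class ToyFilterEnsemble:
    name: str
    filters: list

    def apply(self, instances):
        resps = [inst.resps for inst in instances]
        for f in self.filters:  # apply filters in sequence
            resps = f.apply(resps)
        for inst, resp in zip(instances, resps):
            inst.filtered_resps[self.name] = resp


instances = [FakeInstance(resps=["Answer: 42", "ANSWER: 7"])]
ToyFilterEnsemble(name="lowercase", filters=[LowercaseFilter()]).apply(instances)
print(instances[0].filtered_resps["lowercase"])  # ['answer: 42', 'answer: 7']
```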
from dataclasses import dataclass, field
from typing import Literal, Tuple


@dataclass
class Instance:
    request_type: str = Literal["loglikelihood", "loglikelihood_rolling", "greedy_until"]
    doc: dict = None
    arguments: tuple = None
    idx: int = None
@@ -25,4 +28,6 @@ class Instance:
        """
        Returns (string,) where `string` is the string to calculate loglikelihood over
        """
        return self.arguments if isinstance(self.arguments, tuple) else (self.arguments,)
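
The `args` property simply normalizes `arguments` to a tuple so downstream code can always unpack it the same way. A tiny stand-alone illustration, using a stand-in class rather than the `Instance` above:

```python
from dataclasses import dataclass


@dataclass
class ToyInstance:
    arguments: tuple = None

    @property
    def args(self):
        # always hand back a tuple, even if a bare string was stored
        return self.arguments if isinstance(self.arguments, tuple) else (self.arguments,)


print(ToyInstance(arguments="The cat sat on the mat.").args)
# ('The cat sat on the mat.',)
print(ToyInstance(arguments=("context", " continuation")).args)
# ('context', ' continuation')
```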
@@ -10,6 +10,7 @@ import evaluate

AGGREGATION_REGISTRY = {}

METRIC_REGISTRY = {
    "acc": None,
    "acc_norm": None,
@@ -18,6 +19,21 @@ METRIC_REGISTRY = {
    "byte_perplexity": None,
}

+HIGHER_IS_BETTER_REGISTRY = {
+    "matthews_corrcoef": True,
+    "f1_score": True,
+    "perplexity": False,
+    "bleu": True,
+    "chrf": True,
+    "ter": False,
+    "acc": True,
+    "acc_norm": True,
+    "acc_mutual_info": True,
+    "word_perplexity": False,
+    "byte_perplexity": False,
+    "bits_per_byte": False,
+}


def register_metric(name):
    # TODO: do we want to enforce a certain interface to registered metrics?
@@ -38,12 +54,14 @@ def get_metric(name):
        return METRIC_REGISTRY[name]
    except KeyError:
        # TODO: change this print to logging?
        print(
            f"Could not find registered metric '{name}' in lm-eval, \
searching in HF Evaluate library..."
        )
        try:
            metric_object = evaluate.load(name)
            return metric_object.compute
-        except:
+        except Exception:
            raise Warning(
                "{} not found in the evaluate library!".format(name),
                "Please check https://huggingface.co/evaluate-metric",
...
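
When a metric name is missing from `METRIC_REGISTRY`, the code above falls back to the Hugging Face `evaluate` library. A hedged sketch of what that fallback hands back (this assumes `evaluate` is installed and the `exact_match` metric can be downloaded):

```python
import evaluate

# get_metric("exact_match") would effectively return this compute() callable.
exact_match = evaluate.load("exact_match")
compute = exact_match.compute

print(compute(predictions=["Paris", "4"], references=["Paris", "5"]))
# -> {'exact_match': 0.5}
```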
@@ -6,14 +6,15 @@ from lm_eval import utils

MODEL_REGISTRY = {}


def register_model(*names):
    # either pass a list or a single alias.
    # function receives them as a tuple of strings
    def decorate(cls):
        for name in names:
            assert issubclass(cls, LM), f"Model '{name}' ({cls.__name__}) must extend LM class"
            assert (
...
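
A self-contained sketch of the decorator pattern used here, with a toy `LM` base class and registry rather than the harness's own: every alias passed to `register_model` ends up mapping to the decorated class.

```python
MODEL_REGISTRY = {}


class LM:  # stand-in base class
    pass


def register_model(*names):
    def decorate(cls):
        for name in names:
            assert issubclass(cls, LM), f"Model '{name}' ({cls.__name__}) must extend LM class"
            MODEL_REGISTRY[name] = cls
        return cls

    return decorate


@register_model("toy", "toy-alias")
class ToyLM(LM):
    pass


print(MODEL_REGISTRY["toy"] is ToyLM)        # True
print(MODEL_REGISTRY["toy-alias"] is ToyLM)  # True
```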
import os

task_registry = {}
group_registry = {}
task2func_index = {}
func2task_index = {}


def register_task(name):
    def wrapper(func):
        task_registry[name] = func
        func2task_index[func.__name__] = name
        task2func_index[name] = func.__name__
        return func

    return wrapper


def register_group(name):
    def wrapper(func):
        func_name = func2task_index[func.__name__]
        if name in group_registry:
            group_registry[name].append(func_name)
        else:
            group_registry[name] = [func_name]
        return func

    return wrapper
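
Usage sketch for the two decorators above. The import path and the task builder are hypothetical (the real module name in the repo may differ); note that `register_group` must sit outside `register_task`, because it looks the function up in `func2task_index`, which `register_task` populates first.

```python
# Hypothetical import path; adjust to wherever the registry module above lives.
from registry import register_task, register_group, task_registry, group_registry


@register_group("toy_math")
@register_task("toy_addition")
def toy_addition():
    """Hypothetical task builder; the real harness would return a Task object."""
    return {"question": "1 + 1", "answer": "2"}


print(task_registry["toy_addition"] is toy_addition)  # True
print(group_registry["toy_math"])                     # ['toy_addition']
```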
class Sampler:
    def __init__(self, docs, task, fewshot_indices=None, rnd=None):
        self.rnd = rnd
@@ -16,11 +13,14 @@ class Sampler:
        if fewshot_indices:  # subset few-shot docs from
            self.docs = self.docs.select(fewshot_indices)

    def get_context(self, doc, num_fewshot):
        # draw an extra fewshot sample if using same split as evaluating on
        n_samples = (
            num_fewshot + 1
            if self.config.fewshot_split == self.config.test_split
            else num_fewshot
        )

        # draw `n_samples` docs from fewshot_docs
        fewshotex = self.sample(n_samples)
@@ -51,7 +51,6 @@ class Sampler:

class BalancedSampler(Sampler):
    def sample(self, n):
        """
        TODO: this should return approximately class-balanced samples from our fewshot examples.
@@ -60,12 +59,10 @@ class BalancedSampler(Sampler):
        pass


class ManualSampler(Sampler):
    def sample(self, n):
        """ """
        pass
...
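
`BalancedSampler.sample` is still a TODO. One possible way to fill it in (a sketch only, using a made-up `label` field rather than anything the harness defines) is to bucket the few-shot docs by label and round-robin across the buckets:

```python
import random
from collections import defaultdict


def balanced_sample(docs, n, label_key="label", rnd=None):
    """Draw ~class-balanced few-shot examples by cycling over per-label buckets."""
    rnd = rnd or random.Random(1234)
    buckets = defaultdict(list)
    for doc in docs:
        buckets[doc[label_key]].append(doc)
    for bucket in buckets.values():
        rnd.shuffle(bucket)

    sampled, labels = [], list(buckets)
    i = 0
    while len(sampled) < n and any(buckets.values()):
        label = labels[i % len(labels)]
        if buckets[label]:
            sampled.append(buckets[label].pop())
        i += 1
    return sampled


docs = [{"text": f"doc{i}", "label": i % 2} for i in range(10)]
print([d["label"] for d in balanced_sample(docs, 4)])  # e.g. [0, 1, 0, 1]
```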
[One file's diff is collapsed in the web view and not shown here.]
import random
import itertools
import collections

import torch
import numpy as np

import lm_eval.api
import lm_eval.api.metrics
import lm_eval.models
import lm_eval.tasks

from lm_eval.utils import (
    positional_deprecated,
    run_task_tests,
    make_table,
    create_iterator,
    get_git_commit_hash,
)
from lm_eval.logger import eval_logger

@positional_deprecated
def simple_evaluate(
@@ -65,7 +79,7 @@ def simple_evaluate(
        assert isinstance(model, lm_eval.api.model.LM)
        lm = model

-    task_dict = lm_eval.api.task.get_task_dict(tasks, num_fewshot=num_fewshot)
+    task_dict = lm_eval.tasks.get_task_dict(tasks, num_fewshot=num_fewshot)

    if check_integrity:
        run_task_tests(task_list=tasks)
@@ -73,7 +87,6 @@ def simple_evaluate(
    results = evaluate(
        lm=lm,
        task_dict=task_dict,
-        num_fewshot=num_fewshot,
        limit=limit,
        bootstrap_iters=bootstrap_iters,
        decontamination_ngrams_path=decontamination_ngrams_path,
@@ -91,13 +104,12 @@ def simple_evaluate(
            "limit": limit,
            "bootstrap_iters": bootstrap_iters,
        }
+        results["git_hash"] = get_git_commit_hash()
        return results
    else:
        return None


decontaminate_suffix = "_decontaminate"
@@ -105,7 +117,6 @@ decontaminate_suffix = "_decontaminate"
def evaluate(
    lm,
    task_dict,
-    num_fewshot=0,
    limit=None,
    bootstrap_iters=100000,
    decontamination_ngrams_path=None,
@@ -126,15 +137,15 @@ def evaluate(
        Dictionary of results
    """

    # decontaminate = decontamination_ngrams_path is not None

    results = collections.defaultdict(dict)
    versions = collections.defaultdict(dict)

    requests = collections.defaultdict(list)
    # requests_origin = collections.defaultdict(list)
    # docs = {}

    # get lists of each type of request
    for task_name, task in task_dict.items():
@@ -146,14 +157,21 @@ def evaluate(
        # rnd.seed(42)
        # rnd.shuffle(task_docs)

        task.build_all_requests(limit=limit, rank=lm.rank, world_size=lm.world_size)

        # aggregate Instances by LM method requested to get output.
        reqtype = (
            "loglikelihood"
            if task.OUTPUT_TYPE == "multiple_choice"
            else task.OUTPUT_TYPE
        )  # TODO: this is hacky, fix in task.py
        requests[reqtype].extend(task.instances)

        if lm.world_size > 1:
            instances_rnk = torch.tensor(len(task._instances), device=lm.device)
            gathered_item = (
                lm.accelerator.gather(instances_rnk).cpu().detach().numpy().tolist()
            )
            # compute number of pseudobatches to pad with (FSDP/DDP require even batches among ranks)
            numpad = max(gathered_item) - gathered_item[lm.rank]
@@ -161,7 +179,7 @@ def evaluate(
    ### Run LM on inputs, get all outputs ###
    # execute each type of request
    for reqtype, reqs in requests.items():
-        print("Running", reqtype, "requests")
+        eval_logger.info("Running {} requests".format(reqtype))
        # create `K` copies of each request `req` based off `K = req.repeats`
        cloned_reqs = []
        for req in reqs:
@@ -186,7 +204,6 @@ def evaluate(
    for task_name, task in task_dict.items():
        task.apply_filters()

    ### Collect values of metrics on all datapoints ###
    # TODO: make metric configurable, add metric registry
    vals = collections.defaultdict(list)
@@ -196,12 +213,22 @@
        # calculate values for each filter setup (TODO: make getting list of keys cleaner)
        # TODO: make it possible to use a different metric per key
        for key in task.instances[0].filtered_resps.keys():
            doc_iterator = (
                itertools.islice(
                    enumerate(task.test_docs()), lm.rank, limit, lm.world_size
                )
                if task.has_test_docs()
                else itertools.islice(
                    enumerate(task.validation_docs()), lm.rank, limit, lm.world_size
                )
            )
            for doc_id, doc in doc_iterator:
                # subset instances to only this document id ; sort by idx
                requests = list(filter(lambda x: x.doc_id == doc_id, task.instances))
                requests.sort(key=lambda x: x.idx)
                metrics = task.process_results(
                    doc, [req.filtered_resps[key] for req in requests]
                )
                for metric, value in metrics.items():
                    vals[(task_name, key, metric)].append(value)
@@ -217,18 +244,22 @@
            # distributed gather requires all ranks to have same dimensions
            # so we pad out with float32 min value
            pad_value = torch.finfo(torch.float32).min
            metrics_tensor = torch.tensor(items, device=lm.device)

            original_dtype = metrics_tensor.dtype  # store original dtype
            torch_device_tensor = lm.accelerator.pad_across_processes(
                metrics_tensor.to(torch.float32), pad_index=pad_value
            )
            gathered_item = lm.accelerator.gather(torch_device_tensor)

            if numitem > 0:
                gathered_filtered = gathered_item[gathered_item[:, 0] != pad_value]
            else:
                gathered_filtered = gathered_item[gathered_item != pad_value]

            gathered_item = (
                gathered_filtered.to(original_dtype).cpu().detach().numpy().tolist()
            )
            # reconvert if we were passed a tuple of values
            if numitem > 0:
                gathered_item = [tuple(g) for g in gathered_item]
@@ -238,13 +269,14 @@
        vals = vals_torch

    if lm.rank == 0:
        ### Aggregate results over all datapoints ###
        # aggregate results ; run bootstrap CIs
        for (task_name, key, metric), items in vals.items():
            task = task_dict[task_name]
            results[task_name][metric + " - filter=" + key] = task.aggregation()[metric](items)

            # hotfix: bleu, chrf, ter seem to be really expensive to bootstrap
            # so we run them less iterations. still looking for a cleaner way to do this
@@ -257,7 +289,9 @@
            )
            if stderr is not None:
                results[task_name][metric + " - filter=" + key + "_stderr"] = stderr(items)

    return {"results": dict(results), "versions": dict(versions)}
...
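
For reference, a hedged sketch of driving the evaluator from Python rather than through `main.py`. The keyword names mirror what this diff shows (`tasks`, `num_fewshot`, `limit`, `bootstrap_iters`, `check_integrity`), but exact argument handling at this commit may differ, so treat it as illustrative only.

```python
from lm_eval import evaluator

results = evaluator.simple_evaluate(
    model="hf-causal",                      # alias registered via @register_model above
    model_args="pretrained=gpt2",           # forwarded to the model constructor
    tasks=["lambada_openai", "hellaswag"],
    num_fewshot=0,
    limit=10,                               # debug run on a handful of documents
)
if results is not None:                     # non-zero ranks return None
    print(results["results"])
    print(results["git_hash"])              # recorded via get_git_commit_hash()
```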
@@ -9,7 +9,7 @@ FILTER_REGISTRY = {
    # TODO: implement this filter. either it should take in an arbitrary "scoring"/reward function
    # that takes an input and returns a scalar and then should select the max reward,
    # or should implement different filters for different ways of handling a reward model's inference.
    # "arg_max": selection.ArgMaxFilter,
}
@@ -17,16 +17,19 @@ def get_filter(filter_name):
    return FILTER_REGISTRY[filter_name]


-def build_filter_ensemble(name, components):
+def build_filter_ensemble(filter_name, components):
    """
    Create a filtering pipeline.
    """
    filters = []
-    for step in components:
-        # create a filter given its name in the registry
-        f = get_filter(step)()  # TODO: pass kwargs to filters properly
+    for (function, kwargs) in components:
+        if kwargs is None:
+            f = get_filter(function)()
+        else:
+            # create a filter given its name in the registry
+            f = get_filter(function)(**kwargs)  # TODO: pass kwargs to filters properly
        # add the filter as a pipeline step
        filters.append(f)
-    return FilterEnsemble(name=name, filters=filters)
+    return FilterEnsemble(name=filter_name, filters=filters)
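
A hedged usage sketch of the new `(name, kwargs)` component format, assuming `build_filter_ensemble` is in scope; the registry keys `"regex"` and `"take_first"` are assumptions about `FILTER_REGISTRY`'s contents rather than confirmed entries.

```python
# Illustrative only: component names below are assumed registry keys.
ensemble = build_filter_ensemble(
    "get-answer",
    [
        ("regex", {"regex_pattern": r"#### (\-?[0-9\.\,]+)"}),  # kwargs forwarded to the filter
        ("take_first", None),                                   # None -> construct with defaults
    ],
)
ensemble.apply(instances)  # fills inst.filtered_resps["get-answer"] on each instance
```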
@@ -4,19 +4,15 @@ from lm_eval.api.filter import Filter


class RegexFilter(Filter):
    """ """

-    def __init__(self, regex=r"#### (\-?[0-9\.\,]+)", fallback="[invalid]"):
+    def __init__(self, regex_pattern=r"#### (\-?[0-9\.\,]+)", fallback="[invalid]"):
        """
        pass a string `regex` to run `re.compile(r"regex")` on.
        `fallback` defines the output returned if no matches for the regex are located.
        """
        self.regex_pattern = regex_pattern
        self.regex = re.compile(regex_pattern)
        self.fallback = fallback

    def apply(self, resps):
@@ -30,7 +26,7 @@ class RegexFilter(Filter):
            match = self.regex.search(resp)
            if match:
                match = match.group(1).strip()
                match = match.replace(",", "")  # assignment needed: str.replace is not in-place
                # TODO: should we assume any other filtering is performed?
            else:
                match = self.fallback
...
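
A small usage sketch of `RegexFilter`, assuming the class above is importable and that `apply` returns one filtered list per instance; the default pattern targets GSM8K-style `#### <number>` answers.

```python
f = RegexFilter()  # defaults: regex_pattern=r"#### (\-?[0-9\.\,]+)", fallback="[invalid]"

model_outputs = [
    ["The total is 6 + 7 = 13.\n#### 13"],  # one instance with one sampled response
    ["I am not sure."],                      # no match -> fallback string
]
print(f.apply(model_outputs))
# expected: [['13'], ['[invalid]']]
```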
from lm_eval.api.filter import Filter


class TakeFirstFilter:
    def __init__(self):
        """
        Can define custom behavior here, if an individual instantiation of a Filter class should have state.
...
import logging

logging.basicConfig(
    format="%(asctime)s,%(msecs)03d %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s",
    datefmt="%Y-%m-%d:%H:%M:%S",
    level=logging.INFO,
)
eval_logger = logging.getLogger("lm-eval")
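
Any module in the package can then share the same configured logger, as the evaluator and model code in this PR now do:

```python
from lm_eval.logger import eval_logger

eval_logger.info("Running %s requests", "loglikelihood")
eval_logger.warning("Only %d GPU(s) visible to this process", 1)
```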
@@ -6,11 +6,13 @@ from tqdm import tqdm
import torch.nn.functional as F

from lm_eval import utils
+from lm_eval.logger import eval_logger
from lm_eval.api.model import LM, register_model

from accelerate import Accelerator
from itertools import islice


@register_model("hf-causal", "gpt2")
class HFLM(LM):
    def __init__(
@@ -48,7 +50,7 @@ class HFLM(LM):
            self._world_size = 1
        else:
            self._device = "cpu"

        # TODO: update this to be less of a hack once subfolder is fixed in HF
        revision = revision + ("/" + subfolder if subfolder is not None else "")
@@ -72,10 +74,12 @@
        if gpus > 1:
            accelerator = Accelerator()
            if gpus > accelerator.num_processes:
                warning = (
                    "WARNING: The number of total system GPUs does not match the number of spawned processes. "
                    "If you would like to use data parallelism, please launch the script "
                    "with 'accelerate launch *script*'. "
                    f"Current run will proceed with {accelerator.num_processes} devices."
                )
                print(warning)
                self._rank = accelerator.local_process_index
                self._world_size = accelerator.num_processes
@@ -91,7 +95,6 @@
                self._rank = self.accelerator.local_process_index
                self._world_size = self.accelerator.num_processes

    @property
    def eot_token_id(self):
        # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
@@ -100,14 +103,16 @@
    @property
    def max_length(self):
        try:
            if hasattr(self, "accelerator"):
                return self.accelerator.unwrap_model(self.gpt2).config.n_ctx
            else:
                return self.gpt2.config.n_ctx
        except AttributeError:
            # gptneoconfig doesn't have n_ctx apparently
            if hasattr(self, "accelerator"):
                return self.accelerator.unwrap_model(self.gpt2).config.max_position_embeddings
            else:
                return self.gpt2.config.max_position_embeddings
@@ -150,7 +155,11 @@
    def _model_generate(self, context, max_length, eos_token_id):
        return self.gpt2.generate(
            context,
            max_length=max_length,
            pad_token_id=eos_token_id,
            eos_token_id=eos_token_id,
            do_sample=False,
        )
    def loglikelihood(self, requests):
@@ -173,7 +182,7 @@
        # TODO: automatic batch size detection for vectorization
        loglikelihoods = []

        for (string,) in tqdm([req.args for req in requests], disable=(self.rank != 0)):
            rolling_token_windows = list(
                map(
                    utils.make_disjoint_window,
@@ -193,13 +202,15 @@
            pad_amnt = 0
            if self.world_size > 1:
                # TODO: Comment on what we do here
                mytensor = torch.tensor(len(rolling_token_windows), device=self.device)
                gathered = (
                    self.accelerator.gather(mytensor).cpu().detach().numpy().tolist()
                )

                pad_amnt = max(gathered) - gathered[self.rank]
                if pad_amnt > 0:
                    rolling_token_windows += pad_amnt * [rolling_token_windows[0]]

            string_nll = self._loglikelihood_tokens(
                rolling_token_windows, disable_tqdm=True
@@ -214,7 +225,6 @@
            string_nll = sum(string_nll)
            loglikelihoods.append(string_nll)

        return loglikelihoods

    def _loglikelihood_tokens(self, requests, disable_tqdm=False):
@@ -235,7 +245,8 @@
        # TODO: automatic (variable) batch size detection for vectorization
        re_ord = utils.Reorderer(requests, _collate)
        for chunk in utils.chunks(
            tqdm(re_ord.get_reordered(), disable=(disable_tqdm or (self.rank != 0))),
            self.batch_size,
        ):
            inps = []
...
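
The multi-GPU additions above rely on one recurring trick: when ranks hold different numbers of items, pad each rank up to the maximum count before the collective gather, then drop the padding afterwards. A framework-free sketch of that bookkeeping (toy lists, no `accelerate` involved):

```python
def pad_for_gather(per_rank_items, pad_value=None):
    """Pad each rank's list to the max length so a gather sees equal shapes."""
    target = max(len(items) for items in per_rank_items)
    return [items + [pad_value] * (target - len(items)) for items in per_rank_items]


def gather_and_strip(padded_per_rank, pad_value=None):
    """Concatenate all ranks' padded lists and drop the padding again."""
    gathered = [x for rank_items in padded_per_rank for x in rank_items]
    return [x for x in gathered if x is not pad_value]


# rank 0 built 3 requests, rank 1 built 2 (an uneven split of 5 documents)
per_rank = [[0.1, 0.2, 0.3], [0.4, 0.5]]
padded = pad_for_gather(per_rank)
print(padded)                    # [[0.1, 0.2, 0.3], [0.4, 0.5, None]]
print(gather_and_strip(padded))  # [0.1, 0.2, 0.3, 0.4, 0.5]
```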
import os
import time

import transformers
import numpy as np
from tqdm import tqdm

from lm_eval import utils
from lm_eval.api.model import LM, register_model


def get_result(response, ctxlen):
...
from lm_eval.logger import eval_logger
from promptsource.templates import DatasetTemplates

# TODO: decide whether we want jinja2 or f-string prompts. would it be cursed to support both?
# Prompt library.
# Stores prompts in a dictionary indexed by 2 levels:
@@ -6,17 +9,40 @@
PROMPT_REGISTRY = {
    "qa-basic": {
        "question-newline-answer": "Question: {{question}}\nAnswer:",
        "q-newline-a": "Q: {{question}}\nA:",
    },
}


-def get_prompt(prompt_id: str):
+def get_prompt(prompt_id: str, dataset_name=None, subset_name=None):
    # unpack prompt name
    category_name, prompt_name = prompt_id.split(":")
    if subset_name is None:
        dataset_full_name = dataset_name
    else:
        dataset_full_name = f"{dataset_name}-{subset_name}"
    eval_logger.info(f"Loading prompt from {category_name} for {dataset_full_name}")
    if category_name == "promptsource":
        try:
            if subset_name is None:
                prompts = DatasetTemplates(dataset_name=dataset_name)
            else:
                prompts = DatasetTemplates(
                    dataset_name=dataset_name, subset_name=subset_name
                )
        except Exception:
            raise ValueError(f"{dataset_name} and {subset_name} not found")
        if prompt_name in prompts.all_template_names:
            return prompts[prompt_name]
        else:
            raise ValueError(
                f"{prompt_name} not in prompt list {prompts.all_template_names}"
            )
    else:
        try:
            return PROMPT_REGISTRY[category_name][prompt_name]
        except Exception:
            raise ValueError(
                f"expected only a single `:` as separator between \
prompt category and name, but got `{prompt_id}` instead"
            )
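
Usage sketch for the extended `get_prompt`, assuming the module above is importable; the promptsource template and dataset names in the second call are placeholders, not verified entries.

```python
# Library prompts are addressed as "<category>:<prompt name>".
prompt = get_prompt("qa-basic:question-newline-answer")
print(prompt)  # Question: {{question}}\nAnswer:

# promptsource-backed prompts additionally need the dataset (and optional subset):
template = get_prompt(
    "promptsource:Movie Expressed Sentiment",  # placeholder template name
    dataset_name="rotten_tomatoes",
)
```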