Unverified Commit bd5b29eb authored by Lintang Sutawika, committed by GitHub

Merge pull request #609 from EleutherAI/misc-cleanup-refactor

[Refactor] Misc. cleanup of dead code
parents 46d3bead fbd712f7
@@ -281,7 +281,7 @@ class Task(abc.ABC):
         else:
             eval_logger.warning(
                 "has_training_docs and has_validation_docs are False"
-                ", using test_docs but this is not recommended."
+                ", using test_docs as fewshot_docs but this is not recommended."
             )
             return self.test_docs()
@@ -342,7 +342,8 @@ class Task(abc.ABC):
         fewshot_ctx = self.fewshot_context(
             doc, self._config.num_fewshot, rnd=random.Random()
         )
-        # TODO: we should override this if doing greedy gen so users don't waste time+compute
+        # TODO: we should override self._config.repeats if doing greedy gen so users don't waste time+compute
         inst = self.construct_requests(
             doc=doc,
             ctx=fewshot_ctx,
...
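For context, the warning changed above fires when a task has neither training nor validation docs to draw few-shot examples from. A minimal sketch of that fallback order (illustrative only, using the `has_*`/`*_docs` accessors that appear elsewhere in this codebase):

```python
def pick_fewshot_docs(task):
    # Prefer held-out splits that are not being evaluated on.
    if task.has_training_docs():
        return task.training_docs()
    if task.has_validation_docs():
        return task.validation_docs()
    # Falling back to the evaluation split is what the warning flags:
    # few-shot examples may then overlap with scored documents.
    return task.test_docs()
```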
@@ -195,11 +195,6 @@ def evaluate(
         versions[task_name] = task.VERSION
         configs[task_name] = dict(task.dump_config())

-        # deterministically shuffle docs and chop off the first `limit` because sometimes docs are in some kind of order
-        # task_docs = list(task_doc_func())
-        # rnd = random.Random()
-        # rnd.seed(42)
-        # rnd.shuffle(task_docs)
         if limit is not None:
             if task.has_test_docs():
                 task_docs = task.test_docs()
@@ -257,13 +252,12 @@ def evaluate(
         task.apply_filters()

     ### Collect values of metrics on all datapoints ###
-    # TODO: make metric configurable, add metric registry
     vals = collections.defaultdict(list)

     # unpack results and sort back in order and return control to Task
     for task_name, task in task_dict.items():
-        # calculate values for each filter setup (TODO: make getting list of keys cleaner)
-        # TODO: make it possible to use a different metric per key
+        # TODO: make it possible to use a different metric per filter
+        # iterate over different filters used
        for key in task.instances[0].filtered_resps.keys():
            doc_iterator = (
                itertools.islice(
...
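The two new comments describe iterating over every filter applied to a task's responses. A simplified sketch of that collection pattern, assuming each instance exposes a `filtered_resps` dict keyed by filter name (`score_fn` is a hypothetical stand-in for the task's metric logic, not a harness API):

```python
import collections

def collect_per_filter(instances, score_fn):
    vals = collections.defaultdict(list)
    # One pass per filter key; every instance is scored under each filter.
    for key in instances[0].filtered_resps.keys():
        for inst in instances:
            vals[key].append(score_fn(inst.doc, inst.filtered_resps[key]))
    return vals
```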
@@ -124,28 +124,6 @@ def get_task_dict(task_name_list: List[Union[str, dict, Task]], **kwargs):
                 get_task_name_from_object(task_element): task_element,
             }

-    # task_name_from_registry_dict = {
-    #     task_name: get_task(
-    #         task_name=task_name,
-    #         task_config=config
-    #     )
-    #     for group_name in task_name_list for task_name in GROUP_REGISTRY[group_name]
-    #     if (isinstance(group_name, str)) and (group_name in GROUP_REGISTRY)
-    # }
-    # task_name_from_config_dict = {
-    #     get_task_name_from_config(task_config): ConfigurableTask(
-    #         config=task_config
-    #     )
-    #     for task_config in task_name_list
-    #     if isinstance(task_config, dict)
-    # }
-    # # TODO: Do we still need this?
-    # task_name_from_object_dict = {
-    #     get_task_name_from_object(task_object): task_object
-    #     for task_object in task_name_list
-    #     if isinstance(task_object, Task)
-    # }
     assert set(task_name_from_registry_dict.keys()).isdisjoint(
         set(task_name_from_object_dict.keys())
     )
...
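The assert retained above guards against the same task name arriving from two different sources (registry string, config dict, or Task object). A minimal sketch of that merge-if-disjoint pattern:

```python
def merge_task_dicts(*source_dicts):
    merged = {}
    for d in source_dicts:
        # Refuse to silently overwrite a task registered under the same name.
        assert set(merged).isdisjoint(d), "duplicate task names across sources"
        merged.update(d)
    return merged
```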
@@ -25,7 +25,7 @@ metric_list:
     regexes_to_ignore:
       - ","
       - "\\$"
-delimiter: "\n\n"
+fewshot_delimiter: "\n\n"
 generation_kwargs:
   until:
     - "Q:"
...
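The renamed `fewshot_delimiter` key controls how few-shot examples are stitched into the prompt. A rough sketch of that assembly, assuming a Q:/A: example format consistent with the `until: ["Q:"]` stop sequence in this config (the formatting itself is illustrative, not the harness's exact template):

```python
def build_prompt(fewshot_examples, query, fewshot_delimiter="\n\n"):
    parts = [
        f"Q: {ex['question']}\nA: {ex['answer']}" for ex in fewshot_examples
    ]
    parts.append(f"Q: {query}\nA:")
    # fewshot_delimiter ("\n\n" here) separates examples from each other
    # and from the final query.
    return fewshot_delimiter.join(parts)
```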
@@ -23,15 +23,6 @@ from itertools import islice
 from lm_eval.logger import eval_logger


-class ExitCodeError(Exception):
-    pass
-
-
-def sh(x):
-    if os.system(x):
-        raise ExitCodeError()
-
-
 def escaped_split(text, sep_char, maxsplit=-1):
     """Split text into a list on occurrences of the given separation
     character `sep_char`. The separation character may be escaped by a
@@ -181,26 +172,6 @@ def make_disjoint_window(pair):
     return a[: len(a) - (len(b) - 1)], b


-def select_continuation_from_batch_left_padding(
-    generations: Union[List[List[int]], torch.Tensor], max_context_size: int
-):
-    """Select the continuation from the batch, removing prompts of different lengths.
-
-    Args:
-        generations (Union[List[List[int]], torch.Tensor]):
-            A tensor or list-of-lists of shape [batch_size, sequence length].
-        max_context_size (int):
-            The size of the biggest context; generations will proceed from that
-            index.
-
-    Example:
-        PAD PAD Continue : The dog chased the cat  [every day of the week]
-        Riddle me this   : The dog chased the cat  [yesterday] PAD PAD PAD PAD
-
-    Output:
-        [every day of the week]
-        [yesterday] PAD PAD PAD PAD
-    """
-    return generations[:, max_context_size:]
-
-
 class Reorderer:
     def __init__(self, arr, fn):
         self.size = len(arr)
@@ -396,9 +367,10 @@ def get_git_commit_hash():
     Source: https://github.com/EleutherAI/gpt-neox/blob/b608043be541602170bfcfb8ec9bf85e8a0799e0/megatron/neox_arguments/neox_args.py#L42
     """
     try:
         git_hash = subprocess.check_output(["git", "describe", "--always"]).strip()
         git_hash = git_hash.decode()
-    except subprocess.CalledProcessError:
+    except subprocess.CalledProcessError or FileNotFoundError:
+        # FileNotFoundError occurs when git not installed on system
         git_hash = None
     return git_hash
...
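One note on the new except clause: in Python, `except A or B:` evaluates `A or B` as a boolean expression first, so only `CalledProcessError` is actually caught by the line added above. A sketch of the same intent expressed with an exception tuple:

```python
import subprocess

def get_git_commit_hash():
    try:
        git_hash = subprocess.check_output(["git", "describe", "--always"]).strip()
        git_hash = git_hash.decode()
    except (subprocess.CalledProcessError, FileNotFoundError):
        # FileNotFoundError is raised when git is not installed on the system.
        git_hash = None
    return git_hash
```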
@@ -6,14 +6,18 @@ import lm_eval.models
 def test_description():
     seed = 42
     num_examples = 1
-    task_names = ["hellaswag", "winogrande"]
+    task_names = ["arc_challenge", "lambada"]
     description_dict = {
-        "hellaswag": "Label for the relevant action:\nSentences describing context, with an incomplete sentence trailing answer that plausibly completes the situation.",
-        "winogrande": "Winograd schema sentence including a either a ___ blank with a missing word, making the pronoun ambiguous, or the same with the word filled in.",
+        "arc_challenge": "Label for the relevant action:\nSentences describing context, with an incomplete sentence trailing answer that plausibly completes the situation.",
+        "lambada": "Winograd schema sentence including a either a ___ blank with a missing word, making the pronoun ambiguous, or the same with the word filled in.",
     }

     task_dict = lm_eval.tasks.get_task_dict(task_names)
     for task_name, task in task_dict.items():
+        # patch description field in task (# TODO: make this much more cleaned up)
+        task._config.description = description_dict[task_name]
         rnd = random.Random()
         rnd.seed(seed)
...
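With the description patched onto `task._config`, one way the test could verify the effect is to check that it appears at the start of a built context. This is a hypothetical continuation only; `fewshot_context` is called with the same arguments shown earlier in this diff:

```python
def check_description(task, description, num_fewshot, rnd):
    # The built few-shot context should begin with the injected description.
    doc = next(iter(task.validation_docs()))
    ctx = task.fewshot_context(doc, num_fewshot, rnd=rnd)
    assert ctx.startswith(description)
```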
 import os
-import lm_eval.base as base
+# import lm_eval.base as base
+import lm_eval.api.registry as registry
 import lm_eval.tasks as tasks
-import lm_eval.models as models
+# import lm_eval.models as models
 import lm_eval.evaluator as evaluator
 import random
 import pytest
@@ -15,8 +19,10 @@ import pytest
 def test_evaluator(taskname, task_class):
     task_dict = tasks.get_task_dict([taskname])

-    os.system("rm test_cache.db")
-    lm = base.CachingLM(models.get_model("dummy")(), "test_cache.db")
+    # TODO: re-add cachingLM
+    # os.system("rm test_cache.db")
+    # lm = base.CachingLM(models.get_model("dummy")(), "test_cache.db")
+    lm = registry.get_model("dummy")()

     def ll_fn(reqs):
         for ctx, cont in reqs:
...
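`registry.get_model("dummy")` resolves a stub model so the evaluator plumbing can be exercised without running a real LM. An illustrative stand-in for what such a dummy model might look like (class and method names here are assumptions for illustration, not the harness's actual implementation):

```python
import random

class ToyDummyLM:
    def loglikelihood(self, requests):
        # Return a (log-probability, is-greedy) pair per request.
        return [(-random.random(), False) for _ in requests]

    def greedy_until(self, requests):
        # Return a fixed string per generation request.
        return ["dummy response" for _ in requests]
```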
 import pytest
-import lm_eval.metrics as metrics
+import lm_eval.api.metrics as metrics
 import random
...
 import lm_eval.tasks as tasks
-import lm_eval.base as base
 import pytest
 from itertools import islice
@@ -100,5 +100,5 @@ def test_documents_and_requests(taskname, task_class):
         reqs = [reqs]

     # todo: mock lm after refactoring evaluator.py to not be a mess
-    for req in reqs:
-        assert isinstance(req, base.Request)
+    # for req in reqs:
+    #     assert isinstance(req, base.Request)