"official/modeling/multitask/task_sampler.py" did not exist on "983837ff860367690d0d50ab222bc5023f18550e"
Unverified Commit 1060b68d authored by LSinev's avatar LSinev Committed by GitHub
Browse files

Try to make existing tests run little bit faster (#1905)

parent 4902aaaf
...@@ -15,11 +15,11 @@ base_url = "https://matthoffner-ggml-llm-api.hf.space" ...@@ -15,11 +15,11 @@ base_url = "https://matthoffner-ggml-llm-api.hf.space"
def gguf_completion_mock(base_url=None, **kwargs): def gguf_completion_mock(base_url=None, **kwargs):
# Generate a hash from the parameters # Generate a hash from the parameters
hash_kwargs = {"base_url": base_url, **kwargs} hash_kwargs = {"base_url": base_url, **kwargs}
hash = hashlib.sha256( parameters_hash = hashlib.sha256(
json.dumps(hash_kwargs, sort_keys=True).encode("utf-8") json.dumps(hash_kwargs, sort_keys=True).encode("utf-8")
).hexdigest() ).hexdigest()
fname = f"./tests/testdata/gguf_test_{hash}.pkl" fname = f"./tests/testdata/gguf_test_{parameters_hash}.pkl"
if os.path.exists(fname): if os.path.exists(fname):
with open(fname, "rb") as fh: with open(fname, "rb") as fh:
......
from __future__ import annotations from __future__ import annotations
import os
import sys import sys
from pathlib import Path from pathlib import Path
import numpy as np import numpy as np
import torch import torch
import lm_eval.tasks as tasks from lm_eval import tasks
from lm_eval.api.instance import Instance from lm_eval.api.instance import Instance
from lm_eval.models.huggingface import HFLM from lm_eval.models.huggingface import HFLM
os.environ["TOKENIZERS_PARALLELISM"] = "false"
task_manager = tasks.TaskManager() task_manager = tasks.TaskManager()
TEST_STRING = "foo bar"
class Test_HFLM: class Test_HFLM:
torch.use_deterministic_algorithms(True) torch.use_deterministic_algorithms(True)
...@@ -107,7 +111,7 @@ class Test_HFLM: ...@@ -107,7 +111,7 @@ class Test_HFLM:
file_path = dir_path / f"outputs_log_{self.version_minor}.txt" file_path = dir_path / f"outputs_log_{self.version_minor}.txt"
file_path = file_path.resolve() file_path = file_path.resolve()
with open(file_path, "w") as f: with open(file_path, "w", encoding="utf-8") as f:
f.write("\n".join(str(x) for x in _res)) f.write("\n".join(str(x) for x in _res))
assert np.allclose(_res, _RES, atol=1e-2) assert np.allclose(_res, _RES, atol=1e-2)
# check indices for Multiple Choice # check indices for Multiple Choice
...@@ -126,19 +130,19 @@ class Test_HFLM: ...@@ -126,19 +130,19 @@ class Test_HFLM:
assert np.allclose(res, self.ROLLING_RES, atol=1e-1) assert np.allclose(res, self.ROLLING_RES, atol=1e-1)
def test_toc_encode(self) -> None: def test_toc_encode(self) -> None:
res = self.LM.tok_encode("foo bar") res = self.LM.tok_encode(TEST_STRING)
assert res == [12110, 2534] assert res == [12110, 2534]
def test_toc_decode(self) -> None: def test_toc_decode(self) -> None:
res = self.LM.tok_decode([12110, 2534]) res = self.LM.tok_decode([12110, 2534])
assert res == "foo bar" assert res == TEST_STRING
def test_batch_encode(self) -> None: def test_batch_encode(self) -> None:
res = self.LM.tok_batch_encode(["foo bar", "bar foo"])[0].tolist() res = self.LM.tok_batch_encode([TEST_STRING, "bar foo"])[0].tolist()
assert res == [[12110, 2534], [2009, 17374]] assert res == [[12110, 2534], [2009, 17374]]
def test_model_generate(self) -> None: def test_model_generate(self) -> None:
context = self.LM.tok_batch_encode(["foo bar"])[0] context = self.LM.tok_batch_encode([TEST_STRING])[0]
res = self.LM._model_generate(context, max_length=10, stop=["\n\n"]) res = self.LM._model_generate(context, max_length=10, stop=["\n\n"])
res = self.LM.tok_decode(res[0]) res = self.LM.tok_decode(res[0])
assert res == "foo bar\n<bazhang>!info bar" assert res == "foo bar\n<bazhang>!info bar"
import pytest import pytest
import lm_eval.evaluator as evaluator from lm_eval import evaluator
from lm_eval.api.registry import get_model from lm_eval.api.registry import get_model
......
...@@ -6,7 +6,7 @@ import pytest ...@@ -6,7 +6,7 @@ import pytest
from optimum.intel import OVModelForCausalLM from optimum.intel import OVModelForCausalLM
from transformers import AutoTokenizer from transformers import AutoTokenizer
import lm_eval.evaluator as evaluator from lm_eval import evaluator
from lm_eval.api.registry import get_model from lm_eval.api.registry import get_model
...@@ -46,7 +46,7 @@ def test_evaluator(model_id, task): ...@@ -46,7 +46,7 @@ def test_evaluator(model_id, task):
random.seed(42) random.seed(42)
for _ in reqs: for _ in reqs:
res.append((-random.random(), False)) res.extend([(-random.random(), False)])
return res return res
...@@ -57,7 +57,7 @@ def test_evaluator(model_id, task): ...@@ -57,7 +57,7 @@ def test_evaluator(model_id, task):
res = [] res = []
random.seed(42) random.seed(42)
for _ in reqs: for _ in reqs:
res.append(-random.random()) res.extend([-random.random()])
return res return res
...@@ -79,7 +79,7 @@ def test_ov_config(): ...@@ -79,7 +79,7 @@ def test_ov_config():
model_id = "hf-internal-testing/tiny-random-gpt2" model_id = "hf-internal-testing/tiny-random-gpt2"
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
config_file = str(Path(tmpdirname) / "ov_config.json") config_file = str(Path(tmpdirname) / "ov_config.json")
with open(Path(config_file), "w") as f: with open(Path(config_file), "w", encoding="utf-8") as f:
f.write('{"DYNAMIC_QUANTIZATION_GROUP_SIZE" : "32"}') f.write('{"DYNAMIC_QUANTIZATION_GROUP_SIZE" : "32"}')
lm = get_model("openvino").create_from_arg_string( lm = get_model("openvino").create_from_arg_string(
f"pretrained={model_id},ov_config={config_file}" f"pretrained={model_id},ov_config={config_file}"
......
...@@ -3,7 +3,7 @@ from typing import List ...@@ -3,7 +3,7 @@ from typing import List
import pytest import pytest
import torch import torch
import lm_eval.tasks as tasks from lm_eval import tasks
from lm_eval.api.instance import Instance from lm_eval.api.instance import Instance
......
# import lm_eval.base as base import os
from typing import List from typing import List
import pytest import pytest
# import lm_eval.models as models
import lm_eval.api as api import lm_eval.api as api
import lm_eval.evaluator as evaluator import lm_eval.evaluator as evaluator
from lm_eval import tasks from lm_eval import tasks
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# TODO: more fine grained unit tests rather than this big honking integration # TODO: more fine grained unit tests rather than this big honking integration
# test once we break evaluator into smaller, more manageable pieces # test once we break evaluator into smaller, more manageable pieces
......
import os
from collections import defaultdict from collections import defaultdict
from lm_eval.decontamination.janitor import ( from lm_eval.decontamination.janitor import (
...@@ -9,23 +10,41 @@ from lm_eval.decontamination.janitor import ( ...@@ -9,23 +10,41 @@ from lm_eval.decontamination.janitor import (
) )
os.environ["TOKENIZERS_PARALLELISM"] = "false"
TEST_SEQUENCE = (
"Hello my name is Bob, I like eating pizza, chicken, chips and ice cream. Maybe I should eat some"
" more salad but it's so booooring. I just... like eating pizza, chicken, chips and ice cream so much."
)
JANITOR_EXPECTED = (
"This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing "
" characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. "
)
JANITOR_FILTH1 = "filth lots of dirty filthy filth"
JANITOR_FILTH2 = "filth lots of filthy dirty filth"
def simple_ngram(sequence, n): def simple_ngram(sequence, n):
ngrams = list() ngrams = list()
ngram = [] ngram = []
for x in sequence: for x in sequence:
ngram.append(x) ngram.extend([x])
if len(ngram) == n: if len(ngram) == n:
ngrams.append(tuple(ngram)) ngrams.extend([tuple(ngram)])
ngram = ngram[1:] ngram = ngram[1:]
return ngrams return ngrams
def test_form_ngrams(): def test_form_ngrams():
sequence = ( sequence = TEST_SEQUENCE
"Hello my name is Bob, I like eating pizza, chicken, chips and ice cream. Maybe I should eat some"
" more salad but it's so booooring. I just... like eating pizza, chicken, chips and ice cream so much."
)
n_values = [1, 2, 3, 5, 13] n_values = [1, 2, 3, 5, 13]
for n in n_values: for n in n_values:
...@@ -36,10 +55,7 @@ def test_form_ngrams(): ...@@ -36,10 +55,7 @@ def test_form_ngrams():
def test_word_ngrams(): def test_word_ngrams():
sequence = ( sequence = TEST_SEQUENCE
"Hello my name is Bob, I like eating pizza, chicken, chips and ice cream. Maybe I should eat some"
" more salad but it's so booooring. I just... like eating pizza, chicken, chips and ice cream so much."
)
words = sequence.split() words = sequence.split()
...@@ -53,10 +69,7 @@ def test_word_ngrams(): ...@@ -53,10 +69,7 @@ def test_word_ngrams():
def test_split_indices(): def test_split_indices():
sequence = ( sequence = TEST_SEQUENCE
"Hello my name is Bob, I like eating pizza, chicken, chips and ice cream. Maybe I should eat some"
" more salad but it's so booooring. I just... like eating pizza, chicken, chips and ice cream so much."
)
comparison = [] comparison = []
current_word = "" current_word = ""
...@@ -65,12 +78,18 @@ def test_split_indices(): ...@@ -65,12 +78,18 @@ def test_split_indices():
current_word += c current_word += c
else: else:
if current_word: if current_word:
comparison.append((current_word, (i - len(current_word), i - 1))) comparison.extend([(current_word, (i - len(current_word), i - 1))])
current_word = "" current_word = ""
if current_word: if current_word:
comparison.append( len_sequence = len(sequence)
(current_word, (len(sequence) - len(current_word), len(sequence) - 1)) comparison.extend(
[
(
current_word,
(len_sequence - len(current_word), len_sequence - 1),
)
]
) )
current_word = "" current_word = ""
...@@ -80,10 +99,7 @@ def test_split_indices(): ...@@ -80,10 +99,7 @@ def test_split_indices():
def test_word_ngrams_indices(): def test_word_ngrams_indices():
sequence = ( sequence = TEST_SEQUENCE
"Hello my name is Bob, I like eating pizza, chicken, chips and ice cream. Maybe I should eat some"
" more salad but it's so booooring. I just... like eating pizza, chicken, chips and ice cream so much."
)
n_values = [1, 2, 3, 5, 13] n_values = [1, 2, 3, 5, 13]
...@@ -100,14 +116,13 @@ def test_word_ngrams_indices(): ...@@ -100,14 +116,13 @@ def test_word_ngrams_indices():
tracker[ngram] = end + 1 tracker[ngram] = end + 1
# ignore partial word matches # ignore partial word matches
if (start != 0 and sequence[start - 1] != " ") or ( if not (
end != len(sequence) - 1 and sequence[end + 1] != " " (start != 0 and sequence[start - 1] != " ")
or (end != len(sequence) - 1 and sequence[end + 1] != " ")
): ):
pass
else:
break break
comparison.append((ngram, (start, end))) comparison.extend([(ngram, (start, end))])
result_to_test = list(word_ngrams_indices(sequence, n)) result_to_test = list(word_ngrams_indices(sequence, n))
assert len(result_to_test) == len(comparison) assert len(result_to_test) == len(comparison)
...@@ -184,17 +199,6 @@ def test_janitor2(): ...@@ -184,17 +199,6 @@ def test_janitor2():
filth = "filth" filth = "filth"
expected_result = (
"This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing "
" characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. "
)
janitor = Janitor( janitor = Janitor(
ngram_n=1, window_to_remove=200, too_dirty_cutoff=10, minimum_slice_length=200 ngram_n=1, window_to_remove=200, too_dirty_cutoff=10, minimum_slice_length=200
) )
...@@ -207,7 +211,7 @@ def test_janitor2(): ...@@ -207,7 +211,7 @@ def test_janitor2():
result = janitor.clean_python(sequence) result = janitor.clean_python(sequence)
result = "".join(result) result = "".join(result)
assert result == expected_result assert result == JANITOR_EXPECTED
def test_janitor3(): def test_janitor3():
...@@ -229,19 +233,6 @@ def test_janitor3(): ...@@ -229,19 +233,6 @@ def test_janitor3():
"This is a @line #containing a certain number of characters, 76 to be exact. " "This is a @line #containing a certain number of characters, 76 to be exact. "
) )
filth = "filth lots of dirty filthy filth"
expected_result = (
"This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing "
" characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. "
)
janitor = Janitor( janitor = Janitor(
ngram_n=6, window_to_remove=200, too_dirty_cutoff=10, minimum_slice_length=200 ngram_n=6, window_to_remove=200, too_dirty_cutoff=10, minimum_slice_length=200
) )
...@@ -249,12 +240,12 @@ def test_janitor3(): ...@@ -249,12 +240,12 @@ def test_janitor3():
result = "".join(result) result = "".join(result)
assert result == sequence assert result == sequence
janitor.register_contaminant(filth) janitor.register_contaminant(JANITOR_FILTH1)
assert janitor.dirt_ngrams == {filth} assert janitor.dirt_ngrams == {JANITOR_FILTH1}
result = janitor.clean_python(sequence) result = janitor.clean_python(sequence)
result = "".join(result) result = "".join(result)
assert result == expected_result assert result == JANITOR_EXPECTED
def test_janitor4(): def test_janitor4():
...@@ -284,19 +275,6 @@ def test_janitor4(): ...@@ -284,19 +275,6 @@ def test_janitor4():
"This is a @line #containing a certain number of characters, 76 to be exact. " "This is a @line #containing a certain number of characters, 76 to be exact. "
) )
filth = "filth lots of dirty filthy filth"
expected_result = (
"This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing "
" characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. "
)
janitor = Janitor( janitor = Janitor(
ngram_n=6, window_to_remove=200, too_dirty_cutoff=10, minimum_slice_length=200 ngram_n=6, window_to_remove=200, too_dirty_cutoff=10, minimum_slice_length=200
) )
...@@ -304,12 +282,12 @@ def test_janitor4(): ...@@ -304,12 +282,12 @@ def test_janitor4():
result = "".join(result) result = "".join(result)
assert result == sequence assert result == sequence
janitor.register_contaminant(filth) janitor.register_contaminant(JANITOR_FILTH1)
assert janitor.dirt_ngrams == {filth} assert janitor.dirt_ngrams == {JANITOR_FILTH1}
result = janitor.clean_python(sequence) result = janitor.clean_python(sequence)
result = "".join(result) result = "".join(result)
assert result == expected_result assert result == JANITOR_EXPECTED
def test_janitor5(): def test_janitor5():
...@@ -338,18 +316,7 @@ def test_janitor5(): ...@@ -338,18 +316,7 @@ def test_janitor5():
"This is a @line #containing a certain number of characters, 76 to be exact. " "This is a @line #containing a certain number of characters, 76 to be exact. "
) )
filths = ["filth lots of dirty filthy filth", "filth lots of filthy dirty filth"] filths = [JANITOR_FILTH1, JANITOR_FILTH2]
expected_result = (
"This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing "
" characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. "
)
janitor = Janitor( janitor = Janitor(
ngram_n=6, window_to_remove=200, too_dirty_cutoff=10, minimum_slice_length=200 ngram_n=6, window_to_remove=200, too_dirty_cutoff=10, minimum_slice_length=200
...@@ -364,7 +331,7 @@ def test_janitor5(): ...@@ -364,7 +331,7 @@ def test_janitor5():
result = janitor.clean_python(sequence) result = janitor.clean_python(sequence)
result = "".join(result) result = "".join(result)
assert result == expected_result assert result == JANITOR_EXPECTED
def test_janitor6(): def test_janitor6():
...@@ -401,18 +368,7 @@ def test_janitor6(): ...@@ -401,18 +368,7 @@ def test_janitor6():
"This is a @line #containing a certain number of characters, 76 to be exact. " "This is a @line #containing a certain number of characters, 76 to be exact. "
) )
filths = ["filth lots of dirty filthy filth", "filth lots of filthy dirty filth"] filths = [JANITOR_FILTH1, JANITOR_FILTH2]
expected_result = (
"This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing "
" characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. "
)
janitor = Janitor( janitor = Janitor(
ngram_n=6, window_to_remove=200, too_dirty_cutoff=10, minimum_slice_length=200 ngram_n=6, window_to_remove=200, too_dirty_cutoff=10, minimum_slice_length=200
...@@ -427,7 +383,7 @@ def test_janitor6(): ...@@ -427,7 +383,7 @@ def test_janitor6():
result = janitor.clean_python(sequence) result = janitor.clean_python(sequence)
result = "".join(result) result = "".join(result)
assert result == expected_result assert result == JANITOR_EXPECTED
def test_janitor7(): def test_janitor7():
...@@ -465,7 +421,7 @@ def test_janitor7(): ...@@ -465,7 +421,7 @@ def test_janitor7():
"This is a @line #containing a certain number of characters, 76 to be exact. " "This is a @line #containing a certain number of characters, 76 to be exact. "
) )
filths = ["filth lots of dirty filthy filth", "filth lots of filthy dirty filth"] filths = [JANITOR_FILTH1, JANITOR_FILTH2]
expected_result = "" expected_result = ""
...@@ -488,20 +444,3 @@ def test_janitor7(): ...@@ -488,20 +444,3 @@ def test_janitor7():
def test_janitor8(): def test_janitor8():
# This will test the save and load contams # This will test the save and load contams
pass pass
# source = """ ,, I'm a very !dirty,, ,, dirty boy. Clean me daddy. \n\nhe he he hehe heh. lastword """ * 2
# contaminant = "dirty boy. Clean he he"
# jan = Janitor(ngram_n=3)
# jan.register_contaminant(contaminant)
# cleaned = " ".join(jan.clean(source))
# for contam in jan.dirt_ngrams:
# assert contam not in cleaned, contam
# filename = "data/saved_contam"
# jan.save_contamination_ngrams(filename)
# jan = Janitor(ngram_n=3)
# jan.load_contamination_ngrams(filename)
# cleaned = " ".join(jan.clean(source))
# for contam in jan.dirt_ngrams:
# assert contam not in cleaned, contam
# import lm_eval.base as base
import importlib import importlib
import os import os
import sys import sys
from datetime import datetime from datetime import datetime
from typing import List, Tuple from typing import List, Optional, Tuple
import pytest import pytest
import torch import torch
# import lm_eval.models as models
from lm_eval.caching.cache import PATH from lm_eval.caching.cache import PATH
...@@ -43,7 +41,7 @@ def clear_cache(): ...@@ -43,7 +41,7 @@ def clear_cache():
# leaving tasks here to allow for the option to select specific task files # leaving tasks here to allow for the option to select specific task files
def get_cache_files(tasks: List[str] = None) -> Tuple[List[str], List[str]]: def get_cache_files(tasks: Optional[List[str]] = None) -> Tuple[List[str], List[str]]:
cache_files = os.listdir(PATH) cache_files = os.listdir(PATH)
file_task_names = [] file_task_names = []
...@@ -51,7 +49,7 @@ def get_cache_files(tasks: List[str] = None) -> Tuple[List[str], List[str]]: ...@@ -51,7 +49,7 @@ def get_cache_files(tasks: List[str] = None) -> Tuple[List[str], List[str]]:
for file in cache_files: for file in cache_files:
file_without_prefix = file.split("-")[1] file_without_prefix = file.split("-")[1]
file_without_prefix_and_suffix = file_without_prefix.split(".")[0] file_without_prefix_and_suffix = file_without_prefix.split(".")[0]
file_task_names.append(file_without_prefix_and_suffix) file_task_names.extend([file_without_prefix_and_suffix])
return cache_files, file_task_names return cache_files, file_task_names
...@@ -113,10 +111,11 @@ if __name__ == "__main__": ...@@ -113,10 +111,11 @@ if __name__ == "__main__":
# test_requests_caching_refresh, # test_requests_caching_refresh,
# test_requests_caching_delete, # test_requests_caching_delete,
] ]
# Lookups of global names within a loop is inefficient, so copy to a local variable outside of the loop first
default_tasks = DEFAULT_TASKS
for test_func in tests: for test_func in tests:
clear_cache() clear_cache()
test_func(tasks=DEFAULT_TASKS) test_func(tasks=default_tasks)
print("Tests pass") print("Tests pass")
......
import os
from itertools import islice from itertools import islice
import pytest import pytest
...@@ -8,6 +9,7 @@ from lm_eval.api.task import ConfigurableTask ...@@ -8,6 +9,7 @@ from lm_eval.api.task import ConfigurableTask
from .utils import new_tasks from .utils import new_tasks
os.environ["TOKENIZERS_PARALLELISM"] = "false"
task_manager = tasks.TaskManager() task_manager = tasks.TaskManager()
# Default Task # Default Task
TASKS = ["arc_easy"] TASKS = ["arc_easy"]
...@@ -87,7 +89,6 @@ class TestNewTasks: ...@@ -87,7 +89,6 @@ class TestNewTasks:
) )
if "multiple_choice" in task._config.output_type: if "multiple_choice" in task._config.output_type:
_array = [task.doc_to_choice(doc) for doc in arr] _array = [task.doc_to_choice(doc) for doc in arr]
# assert all(len(x) == 4 for x in _array)
assert all(isinstance(x, list) for x in _array) assert all(isinstance(x, list) for x in _array)
assert all(isinstance(x[0], str) for x in _array) assert all(isinstance(x[0], str) for x in _array)
...@@ -101,9 +102,6 @@ class TestNewTasks: ...@@ -101,9 +102,6 @@ class TestNewTasks:
_array_target = [task.doc_to_target(doc) for doc in arr] _array_target = [task.doc_to_target(doc) for doc in arr]
if task._config.output_type == "multiple_choice": if task._config.output_type == "multiple_choice":
assert all(isinstance(label, int) for label in _array_target) assert all(isinstance(label, int) for label in _array_target)
# _array_text = [task.doc_to_text(doc) for doc in arr]
# Not working
# assert all(tgt[0] == " " or txt[-1] == "\n" if len(txt) != 0 else True for txt, tgt in zip(_array_text, _array_target))
def test_build_all_requests(self, task_class, limit): def test_build_all_requests(self, task_class, limit):
task_class.build_all_requests(rank=1, limit=limit, world_size=1) task_class.build_all_requests(rank=1, limit=limit, world_size=1)
...@@ -118,5 +116,4 @@ class TestNewTasks: ...@@ -118,5 +116,4 @@ class TestNewTasks:
else list(islice(task.validation_docs(), limit)) else list(islice(task.validation_docs(), limit))
) )
requests = [task.construct_requests(doc, task.doc_to_text(doc)) for doc in arr] requests = [task.construct_requests(doc, task.doc_to_text(doc)) for doc in arr]
# assert all(isinstance(doc, list) for doc in requests)
assert len(requests) == limit if limit else True assert len(requests) == limit if limit else True
...@@ -41,7 +41,7 @@ def test_get_rolling_token_windows_v1(): ...@@ -41,7 +41,7 @@ def test_get_rolling_token_windows_v1():
pred_length = 0 pred_length = 0
output = [] output = []
for input_tokens, pred_tokens in generator: for input_tokens, pred_tokens in generator:
output.append((input_tokens, pred_tokens)) output.extend([(input_tokens, pred_tokens)])
pred_length += len(pred_tokens) pred_length += len(pred_tokens)
assert pred_length == len(x) assert pred_length == len(x)
assert gold == output assert gold == output
...@@ -70,7 +70,7 @@ def test_get_rolling_token_windows_v2(): ...@@ -70,7 +70,7 @@ def test_get_rolling_token_windows_v2():
pred_length = 0 pred_length = 0
output = [] output = []
for input_tokens, pred_tokens in generator: for input_tokens, pred_tokens in generator:
output.append((input_tokens, pred_tokens)) output.extend([(input_tokens, pred_tokens)])
pred_length += len(pred_tokens) pred_length += len(pred_tokens)
assert pred_length == len(x) assert pred_length == len(x)
assert gold == output assert gold == output
...@@ -115,7 +115,7 @@ def test_get_rolling_token_windows_v3(): ...@@ -115,7 +115,7 @@ def test_get_rolling_token_windows_v3():
pred_length = 0 pred_length = 0
output = [] output = []
for input_tokens, pred_tokens in generator: for input_tokens, pred_tokens in generator:
output.append((input_tokens, pred_tokens)) output.extend([(input_tokens, pred_tokens)])
pred_length += len(pred_tokens) pred_length += len(pred_tokens)
assert pred_length == len(x) assert pred_length == len(x)
assert gold == output assert gold == output
...@@ -156,7 +156,7 @@ def test_get_rolling_token_windows_v4(): ...@@ -156,7 +156,7 @@ def test_get_rolling_token_windows_v4():
pred_length = 0 pred_length = 0
output = [] output = []
for input_tokens, pred_tokens in generator: for input_tokens, pred_tokens in generator:
output.append((input_tokens, pred_tokens)) output.extend([(input_tokens, pred_tokens)])
pred_length += len(pred_tokens) pred_length += len(pred_tokens)
assert pred_length == len(x) assert pred_length == len(x)
assert gold == output assert gold == output
...@@ -185,7 +185,7 @@ def test_get_rolling_token_windows_v5(): ...@@ -185,7 +185,7 @@ def test_get_rolling_token_windows_v5():
pred_length = 0 pred_length = 0
output = [] output = []
for input_tokens, pred_tokens in generator: for input_tokens, pred_tokens in generator:
output.append((input_tokens, pred_tokens)) output.extend([(input_tokens, pred_tokens)])
pred_length += len(pred_tokens) pred_length += len(pred_tokens)
assert pred_length == len(x) assert pred_length == len(x)
assert gold == output assert gold == output
...@@ -210,7 +210,7 @@ def test_get_rolling_token_windows_v6(): ...@@ -210,7 +210,7 @@ def test_get_rolling_token_windows_v6():
pred_length = 0 pred_length = 0
output = [] output = []
for input_tokens, pred_tokens in generator: for input_tokens, pred_tokens in generator:
output.append((input_tokens, pred_tokens)) output.extend([(input_tokens, pred_tokens)])
pred_length += len(pred_tokens) pred_length += len(pred_tokens)
assert pred_length == len(x) assert pred_length == len(x)
assert gold == output assert gold == output
...@@ -273,26 +273,26 @@ class TestCollator: ...@@ -273,26 +273,26 @@ class TestCollator:
generation_samples = self.make_generate_sample(int(end)) generation_samples = self.make_generate_sample(int(end))
gens = Collator(generation_samples, _collate_gen, group_by="gen_kwargs") gens = Collator(generation_samples, _collate_gen, group_by="gen_kwargs")
chunks = gens.get_batched(n=int(batch_size), batch_fn=None) chunks_gen = gens.get_batched(n=int(batch_size), batch_fn=None)
output = [] output = []
for chunks in chunks:
# check batching
group_one = end // 2 group_one = end // 2
group_two = end - end // 2 group_two = end - end // 2
is_batch = batch_size != 0
for chunks in chunks_gen:
# check batching
assert ( assert (
len(chunks) <= batch_size len(chunks) <= batch_size
if batch_size != 0 if is_batch
else len(chunks) in [group_one, group_two] else len(chunks) in [group_one, group_two]
) )
# check if reorder-er is working correctly # check if reorder-er is working correctly
assert all( chunk_lengths = [len(chunk[0]) for chunk in chunks]
len(chunks[i][0]) <= len(chunks[i - 1][0]) assert chunk_lengths == sorted(chunk_lengths, reverse=True)
for i in range(1, len(chunks))
)
# check if grouping correctly # check if grouping correctly
assert all(x[1] == chunks[0][1] for x in chunks) chunk_to_compare = chunks[0][1]
assert all(x[1] == chunk_to_compare for x in chunks)
for x in chunks: for x in chunks:
output.append(x) output.extend([x])
reordered_output = gens.get_original(output) reordered_output = gens.get_original(output)
# check get original # check get original
assert reordered_output == generation_samples assert reordered_output == generation_samples
...@@ -305,18 +305,17 @@ class TestCollator: ...@@ -305,18 +305,17 @@ class TestCollator:
loglikelihood_samples, loglikelihood_samples,
_collate_log, _collate_log,
) )
chunks = loglikelihoods.get_batched(n=int(batch_size), batch_fn=None) chunks_gen = loglikelihoods.get_batched(n=int(batch_size), batch_fn=None)
output = [] output = []
for chunks in chunks: is_batch = batch_size != 0
for chunks in chunks_gen:
# check batching # check batching
assert len(chunks) <= batch_size if batch_size != 0 else len(chunks) == end assert len(chunks) <= batch_size if is_batch else len(chunks) == end
# check reorder # check reorder
assert all( chunk_lengths = [len(chunk[1]) for chunk in chunks]
len(chunks[i][1]) <= len(chunks[i - 1][1]) assert chunk_lengths == sorted(chunk_lengths, reverse=True)
for i in range(1, len(chunks))
)
for x in chunks: for x in chunks:
output.append(x[1]) output.extend([x[1]])
# check indices # check indices
reordered_output = loglikelihoods.get_original(output) reordered_output = loglikelihoods.get_original(output)
assert reordered_output == [x[1] for x in loglikelihood_samples] assert reordered_output == [x[1] for x in loglikelihood_samples]
...@@ -335,18 +334,17 @@ class TestCollator: ...@@ -335,18 +334,17 @@ class TestCollator:
group_fn=lambda a: a[-2] + a[-1][:-1], group_fn=lambda a: a[-2] + a[-1][:-1],
group_by="contexts", group_by="contexts",
) )
chunks = loglikelihoods.get_batched(n=int(batch_size), batch_fn=None) chunks_gen = loglikelihoods.get_batched(n=int(batch_size), batch_fn=None)
output = [] output = []
outputs_ = [] outputs_ = []
for chunks in chunks: is_batch = batch_size != 0
for chunks in chunks_gen:
# check batching # check batching
if batch_size != 0: if is_batch:
assert len(chunks) <= batch_size assert len(chunks) <= batch_size
# check reorder # check reorder
assert all( chunk_lengths = [len(chunk[1]) for chunk in chunks]
len(chunks[i][1]) <= len(chunks[i - 1][1]) assert chunk_lengths == sorted(chunk_lengths, reverse=True)
for i in range(1, len(chunks))
)
for x in chunks: for x in chunks:
for request_str, cont_toks, logits in loglikelihoods.get_cache( for request_str, cont_toks, logits in loglikelihoods.get_cache(
req_str="".join(x[0]), req_str="".join(x[0]),
...@@ -356,8 +354,8 @@ class TestCollator: ...@@ -356,8 +354,8 @@ class TestCollator:
.unsqueeze(0) .unsqueeze(0)
.unsqueeze(0), .unsqueeze(0),
): ):
output.append(x[1]) output.extend([x[1]])
outputs_.append(cont_toks) outputs_.extend([cont_toks])
assert len(output) == len(outputs_) assert len(output) == len(outputs_)
# check indices # check indices
reordered_output = loglikelihoods.get_original(output) reordered_output = loglikelihoods.get_original(output)
......
...@@ -12,9 +12,9 @@ from lm_eval.utils import load_yaml_config ...@@ -12,9 +12,9 @@ from lm_eval.utils import load_yaml_config
# reads a text file and returns a list of words # reads a text file and returns a list of words
# used to read the output of the changed txt from tj-actions/changed-files # used to read the output of the changed txt from tj-actions/changed-files
def load_changed_files(file_path: str) -> List[str]: def load_changed_files(file_path: str) -> List[str]:
with open(file_path, "r") as f: with open(file_path, "r", encoding="utf-8") as f:
content = f.read() content = f.read()
words_list = [x for x in content.split()] words_list = list(content.split())
return words_list return words_list
...@@ -25,7 +25,7 @@ def load_changed_files(file_path: str) -> List[str]: ...@@ -25,7 +25,7 @@ def load_changed_files(file_path: str) -> List[str]:
def parser(full_path: List[str]) -> List[str]: def parser(full_path: List[str]) -> List[str]:
_output = set() _output = set()
for x in full_path: for x in full_path:
if os.path.exists(x) and x.endswith(".yaml"): if x.endswith(".yaml") and os.path.exists(x):
config = load_yaml_config(x, mode="simple") config = load_yaml_config(x, mode="simple")
if isinstance(config["task"], str): if isinstance(config["task"], str):
_output.add(config["task"]) _output.add(config["task"])
...@@ -40,10 +40,9 @@ def new_tasks() -> Union[List[str], None]: ...@@ -40,10 +40,9 @@ def new_tasks() -> Union[List[str], None]:
# If tasks folder has changed then we get the list of files from FILENAME # If tasks folder has changed then we get the list of files from FILENAME
# and parse the yaml files to get the task names. # and parse the yaml files to get the task names.
return parser(load_changed_files(FILENAME)) return parser(load_changed_files(FILENAME))
elif os.getenv("API") is not None: if os.getenv("API") is not None:
# Or if API has changed then we set the ENV variable API to True # Or if API has changed then we set the ENV variable API to True
# and run given tasks. # and run given tasks.
return ["arc_easy", "hellaswag", "piqa", "wikitext"] return ["arc_easy", "hellaswag", "piqa", "wikitext"]
# if both not true just do arc_easy # if both not true just do arc_easy
else: return None
return
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment