Commit 4225df50 authored by baberabb

test pythia-70m and pytest only test_evaluator.py

parent d9b547b7
@@ -44,7 +44,7 @@ jobs:
         flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
     - name: Test with pytest
       run: |
-        pytest -vv --cov=lm_eval/ tests/
-    - name: Upload to codecov
-      run: |
-        bash <(curl -s https://codecov.io/bash) -t $CODECOV_TOKEN
+        pytest -vv tests/evaluator.py
+    # - name: Upload to codecov
+    #   run: |
+    #     bash <(curl -s https://codecov.io/bash) -t $CODECOV_TOKEN
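A rough local equivalent of the narrowed CI invocation, as a sketch only, run through pytest's Python entry point. The tests/evaluator.py path is copied from the workflow line above; the commit message suggests the module may actually be tests/test_evaluator.py.

    # Sketch: run only the evaluator test module, as the workflow now does,
    # instead of the full tests/ tree with coverage reporting.
    import pytest

    # Path taken verbatim from the updated workflow step; adjust if the file
    # is actually named tests/test_evaluator.py.
    raise SystemExit(pytest.main(["-vv", "tests/evaluator.py"]))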
@@ -14,58 +14,70 @@ import pytest
 # TODO: more fine grained unit tests rather than this big honking integration
 # test once we break evaluator into smaller, more manageable pieces
-def test_evaluator():
-    TASK = ["arc_easy"]
-    LIMIT = 10
+# @pytest.mark.parametrize("taskname,task_class", tasks.TASK_REGISTRY.items())
+@pytest.mark.parametrize(
+    ("task_name,limit,model,model_args"),
+    [
+        (
+            ["arc_easy"],
+            10,
+            "hf",
+            "pretrained=EleutherAI/pythia-160m,dtype=float32,device=cpu",
+        )
+    ],
+)
+def test_evaluator(task_name: list[str], limit: int, model: str, model_args: str):
+    task_name = task_name
+    limit = 10
+    model, model_args = model, model_args
     # task_dict = tasks.get_task_dict(task)
     # TODO: re-add cachingLM
     # os.system("rm test_cache.db")
     # lm = base.CachingLM(models.get_model("dummy")(), "test_cache.db")
-    lm = registry.get_model("dummy")()
-    def ll_fn(reqs):
-        for ctx, cont in [req.args for req in reqs]:
-            if len(ctx) == 0:
-                continue
-            # space convention
-            assert ctx[-1] != " "
-            assert cont[0] == " " or ctx[-1] == "\n"
-
-        res = []
-
-        random.seed(42)
-        for _ in reqs:
-            res.append((-random.random(), False))
-
-        return res
-
-    def ll_perp_fn(reqs):
-        for (string,) in reqs:
-            assert isinstance(string, str)
-
-        res = []
-        random.seed(42)
-        for _ in reqs:
-            res.append(-random.random())
-
-        return res
-
-    lm.loglikelihood = ll_fn
-    lm.loglikelihood_rolling = ll_perp_fn
+    # lm = registry.get_model("dummy")()
+    # def ll_fn(reqs):
+    #     for ctx, cont in [req.args for req in reqs]:
+    #         if len(ctx) == 0:
+    #             continue
+    #         # space convention
+    #         assert ctx[-1] != " "
+    #         assert cont[0] == " " or ctx[-1] == "\n"
+    #
+    #     res = []
+    #
+    #     random.seed(42)
+    #     for _ in reqs:
+    #         res.append((-random.random(), False))
+    #
+    #     return res
+    #
+    # def ll_perp_fn(reqs):
+    #     for (string,) in reqs:
+    #         assert isinstance(string, str)
+    #
+    #     res = []
+    #     random.seed(42)
+    #     for _ in reqs:
+    #         res.append(-random.random())
+    #
+    #     return res
+    #
+    # lm.loglikelihood = ll_fn
+    # lm.loglikelihood_rolling = ll_perp_fn
     e1 = evaluator.simple_evaluate(
-        model="dummy",
-        tasks=TASK,
-        limit=LIMIT,
-        bootstrap_iters=10,
+        model=model,
+        tasks=task_name,
+        limit=limit,
+        model_args=model_args,
     )
     e2 = evaluator.simple_evaluate(
-        model="dummy",
-        tasks=TASK,
-        limit=LIMIT,
-        bootstrap_iters=10,
+        model=model,
+        tasks=task_name,
+        limit=limit,
+        model_args=model_args,
     )
     # check that caching is working
...
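For context, a minimal sketch of what the single parametrized case added here exercises when run outside pytest. Assumptions: lm_eval is importable as in this repository, the EleutherAI/pythia-160m checkpoint can be downloaded, and the final comparison mirrors the caching check hinted at in the trailing comment, whose exact assertion is truncated from this diff.

    # Sketch only: mirrors the one parametrize tuple in the diff above.
    from lm_eval import evaluator

    task_name = ["arc_easy"]
    limit = 10
    model = "hf"
    model_args = "pretrained=EleutherAI/pythia-160m,dtype=float32,device=cpu"

    # Run the same evaluation twice with identical arguments.
    e1 = evaluator.simple_evaluate(
        model=model, tasks=task_name, limit=limit, model_args=model_args
    )
    e2 = evaluator.simple_evaluate(
        model=model, tasks=task_name, limit=limit, model_args=model_args
    )

    # Assumed check: both runs should yield identical metric dictionaries.
    assert e1["results"] == e2["results"]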