Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
1060b68d
Unverified
Commit
1060b68d
authored
May 31, 2024
by
LSinev
Committed by
GitHub
May 31, 2024
Browse files
Try to make existing tests run little bit faster (#1905)
parent
4902aaaf
Changes
11
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
118 additions
and
182 deletions
+118
-182
tests/models/test_gguf.py
tests/models/test_gguf.py
+2
-2
tests/models/test_huggingface.py
tests/models/test_huggingface.py
+10
-6
tests/models/test_neuralmagic.py
tests/models/test_neuralmagic.py
+1
-1
tests/models/test_openvino.py
tests/models/test_openvino.py
+4
-4
tests/models/test_vllm.py
tests/models/test_vllm.py
+1
-1
tests/test_evaluator.py
tests/test_evaluator.py
+2
-2
tests/test_janitor.py
tests/test_janitor.py
+53
-114
tests/test_requests_caching.py
tests/test_requests_caching.py
+6
-7
tests/test_tasks.py
tests/test_tasks.py
+2
-5
tests/test_utils.py
tests/test_utils.py
+32
-34
tests/utils.py
tests/utils.py
+5
-6
No files found.
tests/models/test_gguf.py
View file @
1060b68d
...
...
@@ -15,11 +15,11 @@ base_url = "https://matthoffner-ggml-llm-api.hf.space"
def
gguf_completion_mock
(
base_url
=
None
,
**
kwargs
):
# Generate a hash from the parameters
hash_kwargs
=
{
"base_url"
:
base_url
,
**
kwargs
}
hash
=
hashlib
.
sha256
(
parameters_
hash
=
hashlib
.
sha256
(
json
.
dumps
(
hash_kwargs
,
sort_keys
=
True
).
encode
(
"utf-8"
)
).
hexdigest
()
fname
=
f
"./tests/testdata/gguf_test_
{
hash
}
.pkl"
fname
=
f
"./tests/testdata/gguf_test_
{
parameters_
hash
}
.pkl"
if
os
.
path
.
exists
(
fname
):
with
open
(
fname
,
"rb"
)
as
fh
:
...
...
tests/models/test_huggingface.py
View file @
1060b68d
from
__future__
import
annotations
import
os
import
sys
from
pathlib
import
Path
import
numpy
as
np
import
torch
import
lm_eval
.tasks
as
tasks
from
lm_eval
import
tasks
from
lm_eval.api.instance
import
Instance
from
lm_eval.models.huggingface
import
HFLM
os
.
environ
[
"TOKENIZERS_PARALLELISM"
]
=
"false"
task_manager
=
tasks
.
TaskManager
()
TEST_STRING
=
"foo bar"
class
Test_HFLM
:
torch
.
use_deterministic_algorithms
(
True
)
...
...
@@ -107,7 +111,7 @@ class Test_HFLM:
file_path
=
dir_path
/
f
"outputs_log_
{
self
.
version_minor
}
.txt"
file_path
=
file_path
.
resolve
()
with
open
(
file_path
,
"w"
)
as
f
:
with
open
(
file_path
,
"w"
,
encoding
=
"utf-8"
)
as
f
:
f
.
write
(
"
\n
"
.
join
(
str
(
x
)
for
x
in
_res
))
assert
np
.
allclose
(
_res
,
_RES
,
atol
=
1e-2
)
# check indices for Multiple Choice
...
...
@@ -126,19 +130,19 @@ class Test_HFLM:
assert
np
.
allclose
(
res
,
self
.
ROLLING_RES
,
atol
=
1e-1
)
def
test_toc_encode
(
self
)
->
None
:
res
=
self
.
LM
.
tok_encode
(
"foo bar"
)
res
=
self
.
LM
.
tok_encode
(
TEST_STRING
)
assert
res
==
[
12110
,
2534
]
def
test_toc_decode
(
self
)
->
None
:
res
=
self
.
LM
.
tok_decode
([
12110
,
2534
])
assert
res
==
"foo bar"
assert
res
==
TEST_STRING
def
test_batch_encode
(
self
)
->
None
:
res
=
self
.
LM
.
tok_batch_encode
([
"foo bar"
,
"bar foo"
])[
0
].
tolist
()
res
=
self
.
LM
.
tok_batch_encode
([
TEST_STRING
,
"bar foo"
])[
0
].
tolist
()
assert
res
==
[[
12110
,
2534
],
[
2009
,
17374
]]
def
test_model_generate
(
self
)
->
None
:
context
=
self
.
LM
.
tok_batch_encode
([
"foo bar"
])[
0
]
context
=
self
.
LM
.
tok_batch_encode
([
TEST_STRING
])[
0
]
res
=
self
.
LM
.
_model_generate
(
context
,
max_length
=
10
,
stop
=
[
"
\n\n
"
])
res
=
self
.
LM
.
tok_decode
(
res
[
0
])
assert
res
==
"foo bar
\n
<bazhang>!info bar"
tests/models/test_neuralmagic.py
View file @
1060b68d
import
pytest
import
lm_eval
.evaluator
as
evaluator
from
lm_eval
import
evaluator
from
lm_eval.api.registry
import
get_model
...
...
tests/models/test_openvino.py
View file @
1060b68d
...
...
@@ -6,7 +6,7 @@ import pytest
from
optimum.intel
import
OVModelForCausalLM
from
transformers
import
AutoTokenizer
import
lm_eval
.evaluator
as
evaluator
from
lm_eval
import
evaluator
from
lm_eval.api.registry
import
get_model
...
...
@@ -46,7 +46,7 @@ def test_evaluator(model_id, task):
random
.
seed
(
42
)
for
_
in
reqs
:
res
.
app
end
((
-
random
.
random
(),
False
))
res
.
ext
end
(
[
(
-
random
.
random
(),
False
)
]
)
return
res
...
...
@@ -57,7 +57,7 @@ def test_evaluator(model_id, task):
res
=
[]
random
.
seed
(
42
)
for
_
in
reqs
:
res
.
app
end
(
-
random
.
random
())
res
.
ext
end
(
[
-
random
.
random
()
]
)
return
res
...
...
@@ -79,7 +79,7 @@ def test_ov_config():
model_id
=
"hf-internal-testing/tiny-random-gpt2"
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
config_file
=
str
(
Path
(
tmpdirname
)
/
"ov_config.json"
)
with
open
(
Path
(
config_file
),
"w"
)
as
f
:
with
open
(
Path
(
config_file
),
"w"
,
encoding
=
"utf-8"
)
as
f
:
f
.
write
(
'{"DYNAMIC_QUANTIZATION_GROUP_SIZE" : "32"}'
)
lm
=
get_model
(
"openvino"
).
create_from_arg_string
(
f
"pretrained=
{
model_id
}
,ov_config=
{
config_file
}
"
...
...
tests/models/test_vllm.py
View file @
1060b68d
...
...
@@ -3,7 +3,7 @@ from typing import List
import
pytest
import
torch
import
lm_eval
.tasks
as
tasks
from
lm_eval
import
tasks
from
lm_eval.api.instance
import
Instance
...
...
tests/test_evaluator.py
View file @
1060b68d
#
import
lm_eval.base as base
import
os
from
typing
import
List
import
pytest
# import lm_eval.models as models
import
lm_eval.api
as
api
import
lm_eval.evaluator
as
evaluator
from
lm_eval
import
tasks
os
.
environ
[
"TOKENIZERS_PARALLELISM"
]
=
"false"
# TODO: more fine grained unit tests rather than this big honking integration
# test once we break evaluator into smaller, more manageable pieces
...
...
tests/test_janitor.py
View file @
1060b68d
import
os
from
collections
import
defaultdict
from
lm_eval.decontamination.janitor
import
(
...
...
@@ -9,23 +10,41 @@ from lm_eval.decontamination.janitor import (
)
os
.
environ
[
"TOKENIZERS_PARALLELISM"
]
=
"false"
TEST_SEQUENCE
=
(
"Hello my name is Bob, I like eating pizza, chicken, chips and ice cream. Maybe I should eat some"
" more salad but it's so booooring. I just... like eating pizza, chicken, chips and ice cream so much."
)
JANITOR_EXPECTED
=
(
"This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing "
" characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. "
)
JANITOR_FILTH1
=
"filth lots of dirty filthy filth"
JANITOR_FILTH2
=
"filth lots of filthy dirty filth"
def
simple_ngram
(
sequence
,
n
):
ngrams
=
list
()
ngram
=
[]
for
x
in
sequence
:
ngram
.
app
end
(
x
)
ngram
.
ext
end
(
[
x
]
)
if
len
(
ngram
)
==
n
:
ngrams
.
app
end
(
tuple
(
ngram
))
ngrams
.
ext
end
(
[
tuple
(
ngram
)
]
)
ngram
=
ngram
[
1
:]
return
ngrams
def
test_form_ngrams
():
sequence
=
(
"Hello my name is Bob, I like eating pizza, chicken, chips and ice cream. Maybe I should eat some"
" more salad but it's so booooring. I just... like eating pizza, chicken, chips and ice cream so much."
)
sequence
=
TEST_SEQUENCE
n_values
=
[
1
,
2
,
3
,
5
,
13
]
for
n
in
n_values
:
...
...
@@ -36,10 +55,7 @@ def test_form_ngrams():
def
test_word_ngrams
():
sequence
=
(
"Hello my name is Bob, I like eating pizza, chicken, chips and ice cream. Maybe I should eat some"
" more salad but it's so booooring. I just... like eating pizza, chicken, chips and ice cream so much."
)
sequence
=
TEST_SEQUENCE
words
=
sequence
.
split
()
...
...
@@ -53,10 +69,7 @@ def test_word_ngrams():
def
test_split_indices
():
sequence
=
(
"Hello my name is Bob, I like eating pizza, chicken, chips and ice cream. Maybe I should eat some"
" more salad but it's so booooring. I just... like eating pizza, chicken, chips and ice cream so much."
)
sequence
=
TEST_SEQUENCE
comparison
=
[]
current_word
=
""
...
...
@@ -65,12 +78,18 @@ def test_split_indices():
current_word
+=
c
else
:
if
current_word
:
comparison
.
app
end
((
current_word
,
(
i
-
len
(
current_word
),
i
-
1
)))
comparison
.
ext
end
(
[
(
current_word
,
(
i
-
len
(
current_word
),
i
-
1
))
]
)
current_word
=
""
if
current_word
:
comparison
.
append
(
(
current_word
,
(
len
(
sequence
)
-
len
(
current_word
),
len
(
sequence
)
-
1
))
len_sequence
=
len
(
sequence
)
comparison
.
extend
(
[
(
current_word
,
(
len_sequence
-
len
(
current_word
),
len_sequence
-
1
),
)
]
)
current_word
=
""
...
...
@@ -80,10 +99,7 @@ def test_split_indices():
def
test_word_ngrams_indices
():
sequence
=
(
"Hello my name is Bob, I like eating pizza, chicken, chips and ice cream. Maybe I should eat some"
" more salad but it's so booooring. I just... like eating pizza, chicken, chips and ice cream so much."
)
sequence
=
TEST_SEQUENCE
n_values
=
[
1
,
2
,
3
,
5
,
13
]
...
...
@@ -100,14 +116,13 @@ def test_word_ngrams_indices():
tracker
[
ngram
]
=
end
+
1
# ignore partial word matches
if
(
start
!=
0
and
sequence
[
start
-
1
]
!=
" "
)
or
(
end
!=
len
(
sequence
)
-
1
and
sequence
[
end
+
1
]
!=
" "
if
not
(
(
start
!=
0
and
sequence
[
start
-
1
]
!=
" "
)
or
(
end
!=
len
(
sequence
)
-
1
and
sequence
[
end
+
1
]
!=
" "
)
):
pass
else
:
break
comparison
.
app
end
((
ngram
,
(
start
,
end
)))
comparison
.
ext
end
(
[
(
ngram
,
(
start
,
end
))
]
)
result_to_test
=
list
(
word_ngrams_indices
(
sequence
,
n
))
assert
len
(
result_to_test
)
==
len
(
comparison
)
...
...
@@ -184,17 +199,6 @@ def test_janitor2():
filth
=
"filth"
expected_result
=
(
"This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing "
" characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. "
)
janitor
=
Janitor
(
ngram_n
=
1
,
window_to_remove
=
200
,
too_dirty_cutoff
=
10
,
minimum_slice_length
=
200
)
...
...
@@ -207,7 +211,7 @@ def test_janitor2():
result
=
janitor
.
clean_python
(
sequence
)
result
=
""
.
join
(
result
)
assert
result
==
expected_result
assert
result
==
JANITOR_EXPECTED
def
test_janitor3
():
...
...
@@ -229,19 +233,6 @@ def test_janitor3():
"This is a @line #containing a certain number of characters, 76 to be exact. "
)
filth
=
"filth lots of dirty filthy filth"
expected_result
=
(
"This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing "
" characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. "
)
janitor
=
Janitor
(
ngram_n
=
6
,
window_to_remove
=
200
,
too_dirty_cutoff
=
10
,
minimum_slice_length
=
200
)
...
...
@@ -249,12 +240,12 @@ def test_janitor3():
result
=
""
.
join
(
result
)
assert
result
==
sequence
janitor
.
register_contaminant
(
filth
)
assert
janitor
.
dirt_ngrams
==
{
filth
}
janitor
.
register_contaminant
(
JANITOR_FILTH1
)
assert
janitor
.
dirt_ngrams
==
{
JANITOR_FILTH1
}
result
=
janitor
.
clean_python
(
sequence
)
result
=
""
.
join
(
result
)
assert
result
==
expected_result
assert
result
==
JANITOR_EXPECTED
def
test_janitor4
():
...
...
@@ -284,19 +275,6 @@ def test_janitor4():
"This is a @line #containing a certain number of characters, 76 to be exact. "
)
filth
=
"filth lots of dirty filthy filth"
expected_result
=
(
"This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing "
" characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. "
)
janitor
=
Janitor
(
ngram_n
=
6
,
window_to_remove
=
200
,
too_dirty_cutoff
=
10
,
minimum_slice_length
=
200
)
...
...
@@ -304,12 +282,12 @@ def test_janitor4():
result
=
""
.
join
(
result
)
assert
result
==
sequence
janitor
.
register_contaminant
(
filth
)
assert
janitor
.
dirt_ngrams
==
{
filth
}
janitor
.
register_contaminant
(
JANITOR_FILTH1
)
assert
janitor
.
dirt_ngrams
==
{
JANITOR_FILTH1
}
result
=
janitor
.
clean_python
(
sequence
)
result
=
""
.
join
(
result
)
assert
result
==
expected_result
assert
result
==
JANITOR_EXPECTED
def
test_janitor5
():
...
...
@@ -338,18 +316,7 @@ def test_janitor5():
"This is a @line #containing a certain number of characters, 76 to be exact. "
)
filths
=
[
"filth lots of dirty filthy filth"
,
"filth lots of filthy dirty filth"
]
expected_result
=
(
"This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing "
" characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. "
)
filths
=
[
JANITOR_FILTH1
,
JANITOR_FILTH2
]
janitor
=
Janitor
(
ngram_n
=
6
,
window_to_remove
=
200
,
too_dirty_cutoff
=
10
,
minimum_slice_length
=
200
...
...
@@ -364,7 +331,7 @@ def test_janitor5():
result
=
janitor
.
clean_python
(
sequence
)
result
=
""
.
join
(
result
)
assert
result
==
expected_result
assert
result
==
JANITOR_EXPECTED
def
test_janitor6
():
...
...
@@ -401,18 +368,7 @@ def test_janitor6():
"This is a @line #containing a certain number of characters, 76 to be exact. "
)
filths
=
[
"filth lots of dirty filthy filth"
,
"filth lots of filthy dirty filth"
]
expected_result
=
(
"This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing "
" characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. "
"This is a @line #containing a certain number of characters, 76 to be exact. "
)
filths
=
[
JANITOR_FILTH1
,
JANITOR_FILTH2
]
janitor
=
Janitor
(
ngram_n
=
6
,
window_to_remove
=
200
,
too_dirty_cutoff
=
10
,
minimum_slice_length
=
200
...
...
@@ -427,7 +383,7 @@ def test_janitor6():
result
=
janitor
.
clean_python
(
sequence
)
result
=
""
.
join
(
result
)
assert
result
==
expected_result
assert
result
==
JANITOR_EXPECTED
def
test_janitor7
():
...
...
@@ -465,7 +421,7 @@ def test_janitor7():
"This is a @line #containing a certain number of characters, 76 to be exact. "
)
filths
=
[
"filth lots of dirty filthy filth"
,
"filth lots of filthy dirty filth"
]
filths
=
[
JANITOR_FILTH1
,
JANITOR_FILTH2
]
expected_result
=
""
...
...
@@ -488,20 +444,3 @@ def test_janitor7():
def
test_janitor8
():
# This will test the save and load contams
pass
# source = """ ,, I'm a very !dirty,, ,, dirty boy. Clean me daddy. \n\nhe he he hehe heh. lastword """ * 2
# contaminant = "dirty boy. Clean he he"
# jan = Janitor(ngram_n=3)
# jan.register_contaminant(contaminant)
# cleaned = " ".join(jan.clean(source))
# for contam in jan.dirt_ngrams:
# assert contam not in cleaned, contam
# filename = "data/saved_contam"
# jan.save_contamination_ngrams(filename)
# jan = Janitor(ngram_n=3)
# jan.load_contamination_ngrams(filename)
# cleaned = " ".join(jan.clean(source))
# for contam in jan.dirt_ngrams:
# assert contam not in cleaned, contam
tests/test_requests_caching.py
View file @
1060b68d
# import lm_eval.base as base
import
importlib
import
os
import
sys
from
datetime
import
datetime
from
typing
import
List
,
Tuple
from
typing
import
List
,
Optional
,
Tuple
import
pytest
import
torch
# import lm_eval.models as models
from
lm_eval.caching.cache
import
PATH
...
...
@@ -43,7 +41,7 @@ def clear_cache():
# leaving tasks here to allow for the option to select specific task files
def
get_cache_files
(
tasks
:
List
[
str
]
=
None
)
->
Tuple
[
List
[
str
],
List
[
str
]]:
def
get_cache_files
(
tasks
:
Optional
[
List
[
str
]
]
=
None
)
->
Tuple
[
List
[
str
],
List
[
str
]]:
cache_files
=
os
.
listdir
(
PATH
)
file_task_names
=
[]
...
...
@@ -51,7 +49,7 @@ def get_cache_files(tasks: List[str] = None) -> Tuple[List[str], List[str]]:
for
file
in
cache_files
:
file_without_prefix
=
file
.
split
(
"-"
)[
1
]
file_without_prefix_and_suffix
=
file_without_prefix
.
split
(
"."
)[
0
]
file_task_names
.
app
end
(
file_without_prefix_and_suffix
)
file_task_names
.
ext
end
(
[
file_without_prefix_and_suffix
]
)
return
cache_files
,
file_task_names
...
...
@@ -113,10 +111,11 @@ if __name__ == "__main__":
# test_requests_caching_refresh,
# test_requests_caching_delete,
]
# Lookups of global names within a loop is inefficient, so copy to a local variable outside of the loop first
default_tasks
=
DEFAULT_TASKS
for
test_func
in
tests
:
clear_cache
()
test_func
(
tasks
=
DEFAULT_TASKS
)
test_func
(
tasks
=
default_tasks
)
print
(
"Tests pass"
)
...
...
tests/test_tasks.py
View file @
1060b68d
import
os
from
itertools
import
islice
import
pytest
...
...
@@ -8,6 +9,7 @@ from lm_eval.api.task import ConfigurableTask
from
.utils
import
new_tasks
os
.
environ
[
"TOKENIZERS_PARALLELISM"
]
=
"false"
task_manager
=
tasks
.
TaskManager
()
# Default Task
TASKS
=
[
"arc_easy"
]
...
...
@@ -87,7 +89,6 @@ class TestNewTasks:
)
if
"multiple_choice"
in
task
.
_config
.
output_type
:
_array
=
[
task
.
doc_to_choice
(
doc
)
for
doc
in
arr
]
# assert all(len(x) == 4 for x in _array)
assert
all
(
isinstance
(
x
,
list
)
for
x
in
_array
)
assert
all
(
isinstance
(
x
[
0
],
str
)
for
x
in
_array
)
...
...
@@ -101,9 +102,6 @@ class TestNewTasks:
_array_target
=
[
task
.
doc_to_target
(
doc
)
for
doc
in
arr
]
if
task
.
_config
.
output_type
==
"multiple_choice"
:
assert
all
(
isinstance
(
label
,
int
)
for
label
in
_array_target
)
# _array_text = [task.doc_to_text(doc) for doc in arr]
# Not working
# assert all(tgt[0] == " " or txt[-1] == "\n" if len(txt) != 0 else True for txt, tgt in zip(_array_text, _array_target))
def
test_build_all_requests
(
self
,
task_class
,
limit
):
task_class
.
build_all_requests
(
rank
=
1
,
limit
=
limit
,
world_size
=
1
)
...
...
@@ -118,5 +116,4 @@ class TestNewTasks:
else
list
(
islice
(
task
.
validation_docs
(),
limit
))
)
requests
=
[
task
.
construct_requests
(
doc
,
task
.
doc_to_text
(
doc
))
for
doc
in
arr
]
# assert all(isinstance(doc, list) for doc in requests)
assert
len
(
requests
)
==
limit
if
limit
else
True
tests/test_utils.py
View file @
1060b68d
...
...
@@ -41,7 +41,7 @@ def test_get_rolling_token_windows_v1():
pred_length
=
0
output
=
[]
for
input_tokens
,
pred_tokens
in
generator
:
output
.
app
end
((
input_tokens
,
pred_tokens
))
output
.
ext
end
(
[
(
input_tokens
,
pred_tokens
)
]
)
pred_length
+=
len
(
pred_tokens
)
assert
pred_length
==
len
(
x
)
assert
gold
==
output
...
...
@@ -70,7 +70,7 @@ def test_get_rolling_token_windows_v2():
pred_length
=
0
output
=
[]
for
input_tokens
,
pred_tokens
in
generator
:
output
.
app
end
((
input_tokens
,
pred_tokens
))
output
.
ext
end
(
[
(
input_tokens
,
pred_tokens
)
]
)
pred_length
+=
len
(
pred_tokens
)
assert
pred_length
==
len
(
x
)
assert
gold
==
output
...
...
@@ -115,7 +115,7 @@ def test_get_rolling_token_windows_v3():
pred_length
=
0
output
=
[]
for
input_tokens
,
pred_tokens
in
generator
:
output
.
app
end
((
input_tokens
,
pred_tokens
))
output
.
ext
end
(
[
(
input_tokens
,
pred_tokens
)
]
)
pred_length
+=
len
(
pred_tokens
)
assert
pred_length
==
len
(
x
)
assert
gold
==
output
...
...
@@ -156,7 +156,7 @@ def test_get_rolling_token_windows_v4():
pred_length
=
0
output
=
[]
for
input_tokens
,
pred_tokens
in
generator
:
output
.
app
end
((
input_tokens
,
pred_tokens
))
output
.
ext
end
(
[
(
input_tokens
,
pred_tokens
)
]
)
pred_length
+=
len
(
pred_tokens
)
assert
pred_length
==
len
(
x
)
assert
gold
==
output
...
...
@@ -185,7 +185,7 @@ def test_get_rolling_token_windows_v5():
pred_length
=
0
output
=
[]
for
input_tokens
,
pred_tokens
in
generator
:
output
.
app
end
((
input_tokens
,
pred_tokens
))
output
.
ext
end
(
[
(
input_tokens
,
pred_tokens
)
]
)
pred_length
+=
len
(
pred_tokens
)
assert
pred_length
==
len
(
x
)
assert
gold
==
output
...
...
@@ -210,7 +210,7 @@ def test_get_rolling_token_windows_v6():
pred_length
=
0
output
=
[]
for
input_tokens
,
pred_tokens
in
generator
:
output
.
app
end
((
input_tokens
,
pred_tokens
))
output
.
ext
end
(
[
(
input_tokens
,
pred_tokens
)
]
)
pred_length
+=
len
(
pred_tokens
)
assert
pred_length
==
len
(
x
)
assert
gold
==
output
...
...
@@ -273,26 +273,26 @@ class TestCollator:
generation_samples
=
self
.
make_generate_sample
(
int
(
end
))
gens
=
Collator
(
generation_samples
,
_collate_gen
,
group_by
=
"gen_kwargs"
)
chunks
=
gens
.
get_batched
(
n
=
int
(
batch_size
),
batch_fn
=
None
)
chunks
_gen
=
gens
.
get_batched
(
n
=
int
(
batch_size
),
batch_fn
=
None
)
output
=
[]
for
chunks
in
chunks
:
group_one
=
end
//
2
group_two
=
end
-
end
//
2
is_batch
=
batch_size
!=
0
for
chunks
in
chunks_gen
:
# check batching
group_one
=
end
//
2
group_two
=
end
-
end
//
2
assert
(
len
(
chunks
)
<=
batch_size
if
batch
_size
!=
0
if
is_
batch
else
len
(
chunks
)
in
[
group_one
,
group_two
]
)
# check if reorder-er is working correctly
assert
all
(
len
(
chunks
[
i
][
0
])
<=
len
(
chunks
[
i
-
1
][
0
])
for
i
in
range
(
1
,
len
(
chunks
))
)
chunk_lengths
=
[
len
(
chunk
[
0
])
for
chunk
in
chunks
]
assert
chunk_lengths
==
sorted
(
chunk_lengths
,
reverse
=
True
)
# check if grouping correctly
assert
all
(
x
[
1
]
==
chunks
[
0
][
1
]
for
x
in
chunks
)
chunk_to_compare
=
chunks
[
0
][
1
]
assert
all
(
x
[
1
]
==
chunk_to_compare
for
x
in
chunks
)
for
x
in
chunks
:
output
.
app
end
(
x
)
output
.
ext
end
(
[
x
]
)
reordered_output
=
gens
.
get_original
(
output
)
# check get original
assert
reordered_output
==
generation_samples
...
...
@@ -305,18 +305,17 @@ class TestCollator:
loglikelihood_samples
,
_collate_log
,
)
chunks
=
loglikelihoods
.
get_batched
(
n
=
int
(
batch_size
),
batch_fn
=
None
)
chunks
_gen
=
loglikelihoods
.
get_batched
(
n
=
int
(
batch_size
),
batch_fn
=
None
)
output
=
[]
for
chunks
in
chunks
:
is_batch
=
batch_size
!=
0
for
chunks
in
chunks_gen
:
# check batching
assert
len
(
chunks
)
<=
batch_size
if
batch
_size
!=
0
else
len
(
chunks
)
==
end
assert
len
(
chunks
)
<=
batch_size
if
is_
batch
else
len
(
chunks
)
==
end
# check reorder
assert
all
(
len
(
chunks
[
i
][
1
])
<=
len
(
chunks
[
i
-
1
][
1
])
for
i
in
range
(
1
,
len
(
chunks
))
)
chunk_lengths
=
[
len
(
chunk
[
1
])
for
chunk
in
chunks
]
assert
chunk_lengths
==
sorted
(
chunk_lengths
,
reverse
=
True
)
for
x
in
chunks
:
output
.
app
end
(
x
[
1
])
output
.
ext
end
(
[
x
[
1
]
]
)
# check indices
reordered_output
=
loglikelihoods
.
get_original
(
output
)
assert
reordered_output
==
[
x
[
1
]
for
x
in
loglikelihood_samples
]
...
...
@@ -335,18 +334,17 @@ class TestCollator:
group_fn
=
lambda
a
:
a
[
-
2
]
+
a
[
-
1
][:
-
1
],
group_by
=
"contexts"
,
)
chunks
=
loglikelihoods
.
get_batched
(
n
=
int
(
batch_size
),
batch_fn
=
None
)
chunks
_gen
=
loglikelihoods
.
get_batched
(
n
=
int
(
batch_size
),
batch_fn
=
None
)
output
=
[]
outputs_
=
[]
for
chunks
in
chunks
:
is_batch
=
batch_size
!=
0
for
chunks
in
chunks_gen
:
# check batching
if
batch
_size
!=
0
:
if
is_
batch
:
assert
len
(
chunks
)
<=
batch_size
# check reorder
assert
all
(
len
(
chunks
[
i
][
1
])
<=
len
(
chunks
[
i
-
1
][
1
])
for
i
in
range
(
1
,
len
(
chunks
))
)
chunk_lengths
=
[
len
(
chunk
[
1
])
for
chunk
in
chunks
]
assert
chunk_lengths
==
sorted
(
chunk_lengths
,
reverse
=
True
)
for
x
in
chunks
:
for
request_str
,
cont_toks
,
logits
in
loglikelihoods
.
get_cache
(
req_str
=
""
.
join
(
x
[
0
]),
...
...
@@ -356,8 +354,8 @@ class TestCollator:
.
unsqueeze
(
0
)
.
unsqueeze
(
0
),
):
output
.
app
end
(
x
[
1
])
outputs_
.
app
end
(
cont_toks
)
output
.
ext
end
(
[
x
[
1
]
]
)
outputs_
.
ext
end
(
[
cont_toks
]
)
assert
len
(
output
)
==
len
(
outputs_
)
# check indices
reordered_output
=
loglikelihoods
.
get_original
(
output
)
...
...
tests/utils.py
View file @
1060b68d
...
...
@@ -12,9 +12,9 @@ from lm_eval.utils import load_yaml_config
# reads a text file and returns a list of words
# used to read the output of the changed txt from tj-actions/changed-files
def
load_changed_files
(
file_path
:
str
)
->
List
[
str
]:
with
open
(
file_path
,
"r"
)
as
f
:
with
open
(
file_path
,
"r"
,
encoding
=
"utf-8"
)
as
f
:
content
=
f
.
read
()
words_list
=
[
x
for
x
in
content
.
split
()
]
words_list
=
list
(
content
.
split
()
)
return
words_list
...
...
@@ -25,7 +25,7 @@ def load_changed_files(file_path: str) -> List[str]:
def
parser
(
full_path
:
List
[
str
])
->
List
[
str
]:
_output
=
set
()
for
x
in
full_path
:
if
os
.
path
.
exists
(
x
)
and
x
.
endswith
(
".yaml"
):
if
x
.
endswith
(
".yaml"
)
and
os
.
path
.
exists
(
x
):
config
=
load_yaml_config
(
x
,
mode
=
"simple"
)
if
isinstance
(
config
[
"task"
],
str
):
_output
.
add
(
config
[
"task"
])
...
...
@@ -40,10 +40,9 @@ def new_tasks() -> Union[List[str], None]:
# If tasks folder has changed then we get the list of files from FILENAME
# and parse the yaml files to get the task names.
return
parser
(
load_changed_files
(
FILENAME
))
el
if
os
.
getenv
(
"API"
)
is
not
None
:
if
os
.
getenv
(
"API"
)
is
not
None
:
# Or if API has changed then we set the ENV variable API to True
# and run given tasks.
return
[
"arc_easy"
,
"hellaswag"
,
"piqa"
,
"wikitext"
]
# if both not true just do arc_easy
else
:
return
return
None
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment