Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
6fc2ac49
Commit
6fc2ac49
authored
Jul 12, 2025
by
Baber
Browse files
fix circular
parent
a9c16905
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
39 additions
and
39 deletions
+39
-39
lm_eval/api/task.py
lm_eval/api/task.py
+12
-12
lm_eval/prompts/__init__.py
lm_eval/prompts/__init__.py
+3
-2
lm_eval/tasks/__init__.py
lm_eval/tasks/__init__.py
+0
-24
lm_eval/utils.py
lm_eval/utils.py
+23
-0
tests/test_utils.py
tests/test_utils.py
+1
-1
No files found.
lm_eval/api/task.py
View file @
6fc2ac49
...
...
@@ -24,11 +24,9 @@ import datasets
import
numpy
as
np
from
tqdm
import
tqdm
import
lm_eval.tasks
from
lm_eval
import
utils
from
lm_eval.api
import
samplers
from
lm_eval.api.instance
import
Instance
,
OutputType
from
lm_eval.api.metrics
import
bits_per_byte
,
mean
,
weighted_perplexity
from
lm_eval.api.registry
import
(
AGGREGATION_REGISTRY
,
DEFAULT_METRIC_REGISTRY
,
...
...
@@ -1125,7 +1123,7 @@ class ConfigurableTask(Task):
# get task description
if
description
:
=
self
.
config
.
description
:
description
=
lm_eval
.
task
s
.
apply_template
(
self
.
config
.
description
,
doc
)
description
=
util
s
.
apply_template
(
self
.
config
.
description
,
doc
)
# create system prompt based on the provided system instruction and description
if
system_instruction
is
not
None
and
description
:
...
...
@@ -1260,7 +1258,7 @@ class ConfigurableTask(Task):
return
doc_to_decontamination_query
(
doc
)
else
:
return
ast
.
literal_eval
(
lm_eval
.
task
s
.
apply_template
(
util
s
.
apply_template
(
self
.
config
.
doc_to_decontamination_query
,
doc
)
)
...
...
@@ -1293,7 +1291,7 @@ class ConfigurableTask(Task):
# else:
return
doc
[
doc_to_text
]
else
:
text_string
=
lm_eval
.
task
s
.
apply_template
(
doc_to_text
,
doc
)
text_string
=
util
s
.
apply_template
(
doc_to_text
,
doc
)
if
text_string
.
isdigit
()
and
self
.
_config
.
doc_to_choice
is
not
None
:
return
ast
.
literal_eval
(
text_string
)
else
:
...
...
@@ -1329,7 +1327,7 @@ class ConfigurableTask(Task):
# else:
return
doc
[
doc_to_target
]
else
:
target_string
=
lm_eval
.
task
s
.
apply_template
(
doc_to_target
,
doc
)
target_string
=
util
s
.
apply_template
(
doc_to_target
,
doc
)
if
target_string
.
isdigit
()
and
self
.
_config
.
doc_to_choice
is
not
None
:
return
ast
.
literal_eval
(
target_string
)
elif
(
...
...
@@ -1372,9 +1370,7 @@ class ConfigurableTask(Task):
if
doc_to_choice
in
self
.
features
:
return
doc
[
doc_to_choice
]
else
:
return
ast
.
literal_eval
(
lm_eval
.
tasks
.
apply_template
(
doc_to_choice
,
doc
)
)
return
ast
.
literal_eval
(
utils
.
apply_template
(
doc_to_choice
,
doc
))
elif
isinstance
(
doc_to_choice
,
list
):
return
doc_to_choice
elif
isinstance
(
doc_to_choice
,
dict
):
...
...
@@ -1403,7 +1399,7 @@ class ConfigurableTask(Task):
if
doc_to_image
in
self
.
features
:
return
doc
[
doc_to_image
]
else
:
return
ast
.
literal_eval
(
lm_eval
.
task
s
.
apply_template
(
doc_to_image
,
doc
))
return
ast
.
literal_eval
(
util
s
.
apply_template
(
doc_to_image
,
doc
))
elif
callable
(
doc_to_image
):
return
doc_to_image
(
doc
)
else
:
...
...
@@ -1426,7 +1422,7 @@ class ConfigurableTask(Task):
if
doc_to_audio
in
self
.
features
:
return
doc
[
doc_to_audio
]
else
:
return
ast
.
literal_eval
(
lm_eval
.
task
s
.
apply_template
(
doc_to_audio
,
doc
))
return
ast
.
literal_eval
(
util
s
.
apply_template
(
doc_to_audio
,
doc
))
elif
callable
(
doc_to_audio
):
return
doc_to_audio
(
doc
)
else
:
...
...
@@ -1437,7 +1433,7 @@ class ConfigurableTask(Task):
if
gen_prefix
in
self
.
features
:
return
doc
[
gen_prefix
]
else
:
return
lm_eval
.
task
s
.
apply_template
(
gen_prefix
,
doc
)
return
util
s
.
apply_template
(
gen_prefix
,
doc
)
return
None
def
construct_requests
(
...
...
@@ -1802,6 +1798,8 @@ class MultipleChoiceTask(Task):
}
def
aggregation
(
self
)
->
dict
:
from
lm_eval.api.metrics
import
mean
return
{
"acc"
:
mean
,
"acc_norm"
:
mean
,
...
...
@@ -1868,6 +1866,8 @@ class PerplexityTask(Task):
}
def
aggregation
(
self
)
->
dict
:
from
lm_eval.api.metrics
import
bits_per_byte
,
weighted_perplexity
return
{
"word_perplexity"
:
weighted_perplexity
,
"byte_perplexity"
:
weighted_perplexity
,
...
...
lm_eval/prompts/__init__.py
View file @
6fc2ac49
...
...
@@ -4,6 +4,7 @@ import os
from
typing
import
Dict
import
lm_eval.tasks
import
lm_eval.utils
from
lm_eval
import
utils
...
...
@@ -123,7 +124,7 @@ class PromptString:
if
"doc_to_choice"
in
self
.
prompt_string
:
raise
NotImplementedError
(
"Not yet implemented to accept doc_to_choice"
)
text_string
=
lm_eval
.
task
s
.
apply_template
(
doc_to_text
,
doc
)
target_string
=
lm_eval
.
task
s
.
apply_template
(
doc_to_target
,
doc
)
text_string
=
lm_eval
.
util
s
.
apply_template
(
doc_to_text
,
doc
)
target_string
=
lm_eval
.
util
s
.
apply_template
(
doc_to_target
,
doc
)
return
[
text_string
,
target_string
]
lm_eval/tasks/__init__.py
View file @
6fc2ac49
...
...
@@ -3,7 +3,6 @@ import functools
import
importlib.util
import
inspect
import
logging
import
re
import
sys
from
functools
import
partial
from
glob
import
iglob
...
...
@@ -11,7 +10,6 @@ from pathlib import Path
from
typing
import
Any
,
Callable
,
Dict
,
Generator
,
List
,
Mapping
,
Optional
,
Union
import
yaml
from
jinja2
import
BaseLoader
,
Environment
,
StrictUndefined
from
yaml
import
YAMLError
from
lm_eval
import
utils
...
...
@@ -177,28 +175,6 @@ def load_yaml_config(
return
final_cfg
def
regex_replace
(
string
,
pattern
,
repl
,
count
:
int
=
0
):
"""Implements the `re.sub` function as a custom Jinja filter."""
return
re
.
sub
(
pattern
,
repl
,
string
,
count
=
count
)
@
functools
.
lru_cache
(
maxsize
=
256
)
def
_compile_tpl
(
src
:
str
):
return
apply_template
.
_env
.
from_string
(
src
)
def
apply_template
(
template
:
str
,
doc
:
dict
)
->
str
:
if
not
hasattr
(
apply_template
,
"_env"
):
apply_template
.
_env
=
Environment
(
loader
=
BaseLoader
(),
undefined
=
StrictUndefined
,
keep_trailing_newline
=
True
,
)
apply_template
.
_env
.
filters
[
"regex_replace"
]
=
regex_replace
return
_compile_tpl
(
template
).
render
(
**
doc
)
def
iter_yaml_files
(
root
:
Path
)
->
Generator
[
Path
,
Any
,
None
]:
# '**/*.yaml' is handled internally by os.scandir.
for
path
in
iglob
(
"**/*.yaml"
,
root_dir
=
root
,
recursive
=
True
):
...
...
lm_eval/utils.py
View file @
6fc2ac49
...
...
@@ -13,6 +13,7 @@ from itertools import islice
from
typing
import
Any
,
Callable
,
Generator
,
List
,
Optional
,
Tuple
import
numpy
as
np
from
jinja2
import
BaseLoader
,
Environment
,
StrictUndefined
SPACING
=
" "
*
47
...
...
@@ -511,3 +512,25 @@ def hash_dict_images(data_dict):
if
importlib
.
util
.
find_spec
(
"PIL"
)
else
data_dict
)
def
regex_replace
(
string
,
pattern
,
repl
,
count
:
int
=
0
):
"""Implements the `re.sub` function as a custom Jinja filter."""
return
re
.
sub
(
pattern
,
repl
,
string
,
count
=
count
)
@
functools
.
lru_cache
(
maxsize
=
256
)
def
_compile_tpl
(
src
:
str
):
return
apply_template
.
_env
.
from_string
(
src
)
def
apply_template
(
template
:
str
,
doc
:
dict
)
->
str
:
if
not
hasattr
(
apply_template
,
"_env"
):
apply_template
.
_env
=
Environment
(
loader
=
BaseLoader
(),
undefined
=
StrictUndefined
,
keep_trailing_newline
=
True
,
)
apply_template
.
_env
.
filters
[
"regex_replace"
]
=
regex_replace
return
_compile_tpl
(
template
).
render
(
**
doc
)
tests/test_utils.py
View file @
6fc2ac49
...
...
@@ -11,8 +11,8 @@ from lm_eval.api.metrics import (
stderr_for_metric
,
)
from
lm_eval.models.utils
import
Collator
from
lm_eval.tasks
import
apply_template
from
lm_eval.utils
import
(
apply_template
,
get_rolling_token_windows
,
make_disjoint_window
,
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment