gaoqiong / lm-evaluation-harness

Commit 5b8a7506
Authored Jun 23, 2025 by Baber
Parent: ba1d4483

Commit message: remove other schemas. work on metrics

Showing 8 changed files with 118 additions and 115 deletions:
lm_eval/api/filter.py          +7   -8
lm_eval/api/instance.py        +8   -7
lm_eval/api/model.py           +7   -7
lm_eval/api/schemas.py         +62  -57
lm_eval/api/task.py            +20  -13
lm_eval/evaluator.py           +1   -3
lm_eval/filters/selection.py   +1   -1
lm_eval/models/huggingface.py  +12  -19
lm_eval/api/filter.py

@@ -3,7 +3,6 @@ from dataclasses import dataclass
 from typing import Callable, Iterable, List, Union

 from lm_eval.api.instance import Instance
-from lm_eval.api.schemas import GenerateOutput


 class Filter(ABC):

@@ -47,13 +46,13 @@ class FilterEnsemble:
         resps, docs = zip(*((inst.resps, inst.doc) for inst in instances))
         # TODO: add backward
-        # unwrap responses from GenerateOutput as the filters expect strings
-        resps = tuple(
-            [
-                item.text if isinstance(item, GenerateOutput) else str(item)
-                for item in sublist
-            ]
-            for sublist in resps
-        )
+        # resps = tuple(
+        #     [
+        #         item.text if isinstance(item, GenerateOutput) else item
+        #         for item in sublist
+        #     ]
+        #     for sublist in resps
+        # )

         for f in self.filters:
             # apply filters in sequence
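With the GenerateOutput unwrapping commented out, FilterEnsemble.apply once again hands each instance's raw response list straight to the filters, so filters should expect plain strings rather than schema objects. A minimal sketch of the data shape involved (the filter and values below are illustrative, not part of this commit):

    # Each element of `resps` is the list of raw string responses for one document.
    resps = (
        ["Paris", "paris"],   # responses for doc 0
        ["Berlin"],           # responses for doc 1
    )
    docs = ({"q": "capital of France?"}, {"q": "capital of Germany?"})

    # A filter's apply() maps directly over these lists of strings.
    def lowercase_apply(resps, docs):
        return [[r.lower() for r in sublist] for sublist in resps]

    print(lowercase_apply(resps, docs))   # [['paris', 'paris'], ['berlin']]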
lm_eval/api/instance.py

 from dataclasses import dataclass, field
-from typing import Generic, Literal, Optional, Tuple, TypeVar, Union
+from typing import Literal, Optional, Tuple

-from lm_eval.api.schemas import GenerateInput, LoglikelihoodInput
+# from lm_eval.api.schemas import GenerateInput, LoglikelihoodInput


 OutputType = Literal[
     "loglikelihood",
     "loglikelihood_rolling",
     "generate_until",
     "multiple_choice",
 ]

-T = TypeVar("T", LoglikelihoodInput, GenerateInput)
+# T = TypeVar("T", LoglikelihoodInput, GenerateInput)


 @dataclass
-class Instance(Generic[T]):
+class Instance:
     request_type: OutputType
     doc: dict
-    arguments: T
+    arguments: tuple
     idx: int
     metadata: Tuple[Optional[str], Optional[int], Optional[int]] = field(
         default_factory=lambda: (None, None, None)
     )
-    resps: list[Union[GenerateInput, LoglikelihoodInput]] = field(default_factory=list)
+    resps: list = field(default_factory=list)
     filtered_resps: dict = field(default_factory=dict)

     # initialized after init

@@ -33,7 +34,7 @@ class Instance(Generic[T]):
         self.task_name, self.doc_id, self.repeats = self.metadata

     @property
-    def args(self) -> T:
+    def args(self):
         """
         Returns (string,) where `string` is the string to calculate loglikelihood over
         """
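Because Instance is no longer generic over an input schema, its arguments field is a plain tuple again. A small sketch of constructing a loglikelihood instance under the reverted interface (the values are made up, and it assumes args simply passes the stored tuple through):

    from lm_eval.api.instance import Instance

    inst = Instance(
        request_type="loglikelihood",
        doc={"question": "2 + 2 =", "answer": "4"},
        arguments=("2 + 2 =", " 4"),    # plain (context, continuation) tuple
        idx=0,
        metadata=("my_task", 0, 1),     # (task_name, doc_id, repeats)
    )

    context, continuation = inst.args   # assuming args returns the tuple unchanged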
lm_eval/api/model.py

@@ -9,10 +9,12 @@ from tqdm import tqdm
 from lm_eval import utils
 from lm_eval.api.instance import Instance
-from lm_eval.api.schemas import (
-    LoglikelihoodInput,
-    LoglikelihoodOutput,
-)
+# from lm_eval.api.schemas import (
+#     LoglikelihoodInput,
+#     LoglikelihoodOutput,
+# )


 if TYPE_CHECKING:

@@ -381,9 +383,7 @@ class TemplateLM(LM):
         self, requests: list["Instance"], disable_tqdm: bool = False
     ) -> list[tuple[float, bool]]:
         new_reqs = []
-        for context, continuation in (
-            (req.args.context, req.args.continuation) for req in requests
-        ):
+        for context, continuation in [req.args for req in requests]:
             if context == "":
                 # BOS or EOS as context
                 context_enc, continuation_enc = (
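TemplateLM.loglikelihood now unpacks each request's args as a plain (context, continuation) pair instead of reading .context and .continuation off a LoglikelihoodInput. A self-contained sketch of the new iteration pattern (the stand-in request class below is illustrative only):

    class FakeRequest:
        """Stand-in for Instance: args is now a plain (context, continuation) tuple."""
        def __init__(self, args):
            self.args = args

    requests = [
        FakeRequest(("The capital of France is", " Paris")),
        FakeRequest(("", " Hello")),
    ]

    for context, continuation in [req.args for req in requests]:
        if context == "":
            print("empty context; BOS/EOS handling applies for:", continuation)
        else:
            print(context, "->", continuation)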
lm_eval/api/schemas.py

@@ -2,61 +2,60 @@ from dataclasses import dataclass
 from typing import Optional


-@dataclass
-class GenerateInput:
-    """
-    Inputs for the generate function.
-    """
-
-    prompt: str
-    gen_kwargs: dict
-    multimodal_arg: Optional[dict] = None
-
-    def __iter__(self):
-        return (
-            iter((self.prompt, self.gen_kwargs))
-            if not self.multimodal_arg
-            else iter((self.prompt, self.gen_kwargs, self.multimodal_arg))
-        )
-
-    def __getitem__(self, item: int):
-        return [self.prompt, self.gen_kwargs][item]
-
-
-@dataclass
-class GenerateOutput:
-    """
-    Outputs for the generate function.
-    """
-
-    text: str
-    metadata: dict = None
-
-
-@dataclass
-class LoglikelihoodInput:
-    """
-    Inputs for the loglikelihood function.
-    """
-
-    context: str
-    continuation: Optional[str] = None
-
-
-@dataclass
-class LoglikelihoodOutput:
-    """
-    Outputs for the loglikelihood function.
-    """
-
-    loglikelihood: float
-    is_greedy: Optional[bool] = None
-    ctx_tokens: Optional[list[int]] = None
-    cont_tokens: Optional[list[int]] = None
-    metadata: Optional[dict] = None
-
-    def __iter__(self):
-        return iter((self.loglikelihood, self.is_greedy))
+# @dataclass
+# class GenerateInput:
+#     """
+#     Inputs for the generate function.
+#     """
+#
+#     prompt: str
+#     gen_kwargs: dict
+#     multimodal_arg: Optional[dict] = None
+#
+#     def __iter__(self):
+#         return (
+#             iter((self.prompt, self.gen_kwargs))
+#             if not self.multimodal_arg
+#             else iter((self.prompt, self.gen_kwargs, self.multimodal_arg))
+#         )
+#
+#     def __getitem__(self, item: int):
+#         return [self.prompt, self.gen_kwargs][item]
+#
+#
+# @dataclass
+# class GenerateOutput:
+#     """
+#     Outputs for the generate function.
+#     """
+#
+#     text: str
+#     metadata: dict = None
+#
+#
+# @dataclass
+# class LoglikelihoodInput:
+#     """
+#     Inputs for the loglikelihood function.
+#     """
+#
+#     context: str
+#     continuation: Optional[str] = None
+#
+#
+# class LoglikelihoodOutput(NamedTuple):
+#     """
+#     Outputs for the loglikelihood function.
+#     """
+#
+#     loglikelihood: float
+#     is_greedy: Optional[bool] = None
+#     ctx_tokens: Optional[list[int]] = None
+#     cont_tokens: Optional[list[int]] = None
+#     metadata: Optional[dict] = None
+#     def __iter__(self):
+#         return iter((self.loglikelihood, self.is_greedy))


 @dataclass

@@ -66,7 +65,7 @@ class MetricResult:
     """

     doc_id: str | int | None
-    scores: list[dict[str, float]] | None
+    scores: list[dict[str, float]] | dict
     filter_key: str = None
     metric_name: str = None
     metadata: Optional[dict] = None

@@ -76,6 +75,8 @@ class MetricResult:
             return iter([])
         # Group values by metric key
+        if not isinstance(self.scores, list):
+            self.scores = [self.scores]
         grouped = {}
         for score_dict in self.scores:
             for key, value in score_dict.items():

@@ -99,4 +100,8 @@ class MetricResult:
     def metric_keys(self) -> list[str]:
         if self.scores is None:
             return []
-        return list(self.scores[0].keys()) if self.scores else []
+        return (
+            list(self.scores[0].keys())
+            if isinstance(self.scores, list)
+            else list(self.scores.keys())
+        )
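MetricResult.scores can now be either a list of per-response score dicts or a single dict; __iter__ wraps a lone dict in a list before grouping values by metric key, and metric_keys branches on the type. A small sketch of both accepted shapes (the scores below are invented):

    from lm_eval.api.schemas import MetricResult

    # A single score dict for a document.
    single = MetricResult(doc_id=0, scores={"acc": 1.0, "f1": 0.9}, filter_key="none")

    # Several responses for one document: a list of score dicts.
    multi = MetricResult(doc_id=1, scores=[{"acc": 1.0}, {"acc": 0.0}], filter_key="none")

    print(single.metric_keys())   # ['acc', 'f1']
    print(multi.metric_keys())    # ['acc']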
lm_eval/api/task.py

@@ -37,7 +37,7 @@ from lm_eval.api.registry import (
     get_metric_aggregation,
     is_higher_better,
 )
-from lm_eval.api.schemas import GenerateInput, LoglikelihoodInput, MetricResult
+from lm_eval.api.schemas import MetricResult
 from lm_eval.caching.cache import load_from_cache, save_to_cache
 from lm_eval.filters import build_filter_ensemble
 from lm_eval.prompts import get_prompt

@@ -1531,7 +1531,8 @@ class ConfigurableTask(Task):
                 Instance(
                     request_type="loglikelihood",
                     doc=doc,
-                    arguments=LoglikelihoodInput(context=arg[0], continuation=arg[1]),
+                    arguments=arg,
+                    # arguments=LoglikelihoodInput(context=arg[0], continuation=arg[1]),
                     idx=i,
                     **kwargs,
                 )

@@ -1543,9 +1544,9 @@ class ConfigurableTask(Task):
         return Instance(
             request_type=self.OUTPUT_TYPE,
             doc=doc,
-            arguments=LoglikelihoodInput(*arguments)
-            if self.OUTPUT_TYPE in ["loglikelihood", "loglikelihood_rolling"]
-            else GenerateInput(*arguments),
+            arguments=arguments,
+            # if self.OUTPUT_TYPE in ["loglikelihood", "loglikelihood_rolling"]
+            # else GenerateInput(*arguments),
             idx=0,
             **kwargs,
         )

@@ -1819,15 +1820,21 @@ class ConfigurableTask(Task):
         for doc_id, doc in doc_iterator:
             # doc_id_true = indices[doc_id] if indices else doc_id
             requests = instances_by_doc_id[doc_id]
-            metrics = [
-                self.process_results(doc, response)
-                for req in requests
-                for response in (
-                    req.filtered_resps[filter_key]
-                    if isinstance(req.filtered_resps[filter_key], list)
-                    else [req.filtered_resps[filter_key]]
-                )
-            ]
+            if len(requests) > 1:
+                # if one doc has multiple instances then calculate metric together
+                metrics = [
+                    self.process_results(
+                        doc, [req.filtered_resps[filter_key] for req in requests]
+                    )
+                ]
+            else:
+                metrics = [
+                    self.process_results(doc, response)
+                    for req in requests
+                    for response in (
+                        req.filtered_resps[filter_key]
+                        if isinstance(req.filtered_resps[filter_key], list)
+                        else [req.filtered_resps[filter_key]]
+                    )
+                ]
             all_metrics[filter_key].append(
                 MetricResult(scores=metrics, doc_id=doc_id, filter_key=filter_key)
             )
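The new branch means that when one document spawned several instances (for example, one per answer choice), all of their filtered responses are handed to process_results in a single call, while the single-instance path still scores each response individually. A condensed, self-contained sketch of that control flow (the helper and objects below are illustrative stand-ins, not the task code itself):

    def score_doc(doc, requests, filter_key, process_results):
        """Stand-in for the per-document scoring branch added in this commit."""
        if len(requests) > 1:
            # one doc, several instances: score all filtered responses together
            return [process_results(doc, [r.filtered_resps[filter_key] for r in requests])]
        return [
            process_results(doc, resp)
            for r in requests
            for resp in (
                r.filtered_resps[filter_key]
                if isinstance(r.filtered_resps[filter_key], list)
                else [r.filtered_resps[filter_key]]
            )
        ]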
lm_eval/evaluator.py

@@ -647,9 +647,7 @@ def evaluate(
                         ensure_ascii=False,
                     )
                 ),
-                "prompt_hash": hash_string(
-                    requests[0].arguments.prompt
-                ),
+                "prompt_hash": hash_string(requests[0].arguments[0]),
                 "target_hash": hash_string(str(target)),
             }
             example.update(
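Since arguments is a tuple again, the per-example prompt hash is computed from its first element rather than a .prompt attribute. A tiny illustration (the hash_string below is a stand-in for the helper evaluator.py already uses):

    import hashlib

    def hash_string(s: str) -> str:
        # stand-in for the existing helper
        return hashlib.sha256(s.encode("utf-8")).hexdigest()

    arguments = ("Translate to French: cheese", {"max_gen_toks": 32})
    prompt_hash = hash_string(arguments[0])   # hash the prompt string at index 0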
lm_eval/filters/selection.py

@@ -20,7 +20,7 @@ class TakeFirstFilter(Filter):
         """
         Assuming each entry of `resps` is a list of model responses, we discard all but the first response.
         """
-        return map(lambda r: r, resps)
+        return map(lambda r: r[0], resps)


 @register_filter("take_first_k")
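With responses back to plain lists of strings, TakeFirstFilter again indexes into each list instead of passing it through unchanged. A quick illustration with invented responses:

    resps = [["first answer", "second answer"], ["only answer"]]

    passthrough = list(map(lambda r: r, resps))     # previous behaviour: lists pass through
    take_first = list(map(lambda r: r[0], resps))   # new behaviour: ['first answer', 'only answer']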
lm_eval/models/huggingface.py

@@ -27,7 +27,6 @@ from lm_eval import utils
 from lm_eval.api.instance import Instance
 from lm_eval.api.model import TemplateLM
 from lm_eval.api.registry import register_model
-from lm_eval.api.schemas import GenerateInput, GenerateOutput, LoglikelihoodOutput
 from lm_eval.models.utils import (
     Collator,
     clear_torch_cache,

@@ -966,7 +965,7 @@ class HFLM(TemplateLM):
     def loglikelihood_rolling(
         self, requests: List[Instance], disable_tqdm: bool = False
-    ) -> list[LoglikelihoodOutput]:
+    ) -> List[float]:
         adaptive_batch_size = None
         if self.batch_size == "auto":
             # using rolling window with maximum context

@@ -1026,7 +1025,7 @@ class HFLM(TemplateLM):
                 override_bs=len(batch_windows),
             )
             # Store results with their request indices
-            all_nlls.extend(zip(batch_indices, (x.loglikelihood for x in batch_nlls)))
+            all_nlls.extend(zip(batch_indices, batch_nlls))

         # Remove padding if necessary
         if (self.world_size > 1) and (pad_amnt > 0):

@@ -1039,8 +1038,8 @@ class HFLM(TemplateLM):
             # Get all nlls for this request
             request_nlls = all_nlls[current_idx : current_idx + window_count]
             # Sum up the nlls for this request (discarding is_greedy)
-            request_total = sum(nll for nll in request_nlls)
-            loglikelihoods.append(LoglikelihoodOutput(loglikelihood=request_total))
+            request_total = sum(nll[0] for _, nll in request_nlls)
+            loglikelihoods.append(request_total)

             current_idx += window_count
             string = requests[len(loglikelihoods) - 1].args[0]

@@ -1072,7 +1071,7 @@ class HFLM(TemplateLM):
         requests: List[Tuple[Tuple[str, str], List[int], List[int]]],
         disable_tqdm: bool = False,
         override_bs: int = None,
-    ) -> List[LoglikelihoodOutput]:
+    ) -> List[Tuple[float, bool]]:
         # TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
         res = []

@@ -1287,13 +1286,7 @@ class HFLM(TemplateLM):
             # Answer: (log prob, is-exact-match)
             answer = (float(logits.sum()), bool(max_equal))

-            res.append(
-                LoglikelihoodOutput(
-                    *answer,
-                    ctx_tokens=ctx_tokens,
-                    cont_tokens=cont_toks.tolist(),
-                )
-            )
+            res.append(answer)

             if request_str is not None:
                 # special case: loglikelihood_rolling produces a number of loglikelihood requests

@@ -1309,8 +1302,8 @@ class HFLM(TemplateLM):
         return re_ord.get_original(res)

     def generate_until(
-        self, requests: List[Instance[GenerateInput]], disable_tqdm: bool = False
-    ) -> List[GenerateOutput]:
+        self, requests: List[Instance], disable_tqdm: bool = False
+    ) -> List[str]:
         res = []

         def _collate(req: Tuple[str, dict]):

@@ -1321,8 +1314,8 @@ class HFLM(TemplateLM):
            # padded context length. this is useful to simplify the batching logic and more importantly to make
            # automatic adaptive batches much much easier to implement
            # - any OOMs will happen right away rather than near the end
-            toks = self.tok_encode(req.prompt)
-            return -len(toks), req.prompt
+            toks = self.tok_encode(req[0])
+            return -len(toks), req[0]

         pbar = tqdm(
             total=len(requests),

@@ -1358,7 +1351,7 @@ class HFLM(TemplateLM):
             [reg.args for reg in requests],
             sort_fn=_collate,
             group_by="gen_kwargs",
-            group_fn=lambda x: x.gen_kwargs,
+            group_fn=lambda x: x[1],
         )
         chunks = re_ords.get_batched(n=batch_size, batch_fn=batch_fn)
         eos = self.tok_decode(self.eot_token_id, skip_special_tokens=False)

@@ -1427,7 +1420,7 @@ class HFLM(TemplateLM):
                         # for seq2seq case where self.tok_decode(self.eot_token_id) = ''
                         s = s.split(term)[0]

-                res.append(GenerateOutput(text=s))
+                res.append(s)

                 self.cache_hook.add_partial("generate_until", (context, gen_kwargs), s)
                 pbar.update(1)
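Taken together, the HFLM methods revert to their pre-schema return shapes: _loglikelihood_tokens and loglikelihood yield (log-prob, is-greedy) tuples, loglikelihood_rolling yields floats, and generate_until yields plain strings. A sketch of the shapes a caller should now expect (the values are invented):

    loglikelihood_results = [(-1.37, True), (-5.02, False)]   # List[Tuple[float, bool]]
    rolling_results = [-42.7, -13.1]                          # List[float]
    generation_results = ["Paris is the capital.", "42"]      # List[str]

    for logprob, is_greedy in loglikelihood_results:
        print(f"logprob={logprob:.2f} greedy={is_greedy}")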