gaoqiong / lm-evaluation-harness · Commits

Commit 451e73f1, authored May 19, 2025 by Baber
add classes to inputs/outputs
Parent: e30978c7

Showing 6 changed files with 106 additions and 25 deletions (+106 −25)
lm_eval/api/instance.py        +9  −7
lm_eval/api/model.py           +15 −6
lm_eval/api/task.py            +13 −3
lm_eval/api/types.py           +49 −0
lm_eval/evaluator.py           +4  −0
lm_eval/models/huggingface.py  +16 −9
lm_eval/api/instance.py

 from dataclasses import dataclass, field
-from typing import Literal, Optional, Tuple
+from typing import Generic, Literal, Optional, Tuple, TypeVar
+
+from lm_eval.api.types import GenerateInput, LoglikelihoodInput


 OutputType = Literal[
     "loglikelihood", "loglikelihood_rolling", "generate_until", "multiple_choice"
 ]

+T = TypeVar("T", LoglikelihoodInput, GenerateInput)
+

 @dataclass
-class Instance:
+class Instance(Generic[T]):
     request_type: OutputType
     doc: dict
-    arguments: tuple
+    arguments: T
     idx: int
     metadata: Tuple[Optional[str], Optional[int], Optional[int]] = field(
         default_factory=lambda: (None, None, None)
     )

...

@@ -29,10 +33,8 @@ class Instance:
         self.task_name, self.doc_id, self.repeats = self.metadata

     @property
-    def args(self):
+    def args(self) -> T:
         """
         Returns (string,) where `string` is the string to calculate loglikelihood over
         """
-        return (
-            self.arguments if isinstance(self.arguments, tuple) else (self.arguments,)
-        )
+        return self.arguments
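To make the change above concrete, here is a minimal sketch (not part of the commit) of constructing the now-generic Instance with a typed payload; the document contents are illustrative. With arguments holding a dataclass, args simply returns it instead of applying tuple-wrapping heuristics.

from lm_eval.api.instance import Instance
from lm_eval.api.types import LoglikelihoodInput

req: Instance[LoglikelihoodInput] = Instance(
    request_type="loglikelihood",
    doc={"question": "2 + 2 =", "answer": "4"},  # illustrative document
    arguments=LoglikelihoodInput(context="2 + 2 =", continuation=" 4"),
    idx=0,
)

# fields are read by name rather than by tuple position
assert req.args.context == "2 + 2 ="
assert req.args.continuation == " 4"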
lm_eval/api/model.py

...

@@ -8,6 +8,11 @@ from typing import TYPE_CHECKING, Any, Iterable, Optional, Type, TypeVar, Union
 from tqdm import tqdm

 from lm_eval import utils
+from lm_eval.api.instance import Instance
+from lm_eval.api.types import (
+    LoglikelihoodInput,
+    LoglikelihoodOutput,
+)

 if TYPE_CHECKING:

...

@@ -34,7 +39,7 @@ class LM(abc.ABC):
         self.cache_hook: "CacheHook" = CacheHook(None)

     @abc.abstractmethod
-    def loglikelihood(self, requests) -> list[tuple[float, bool]]:
+    def loglikelihood(self, requests: list[Instance]) -> list[tuple[float, bool]]:
         """Compute log-likelihood of generating a continuation from a context.
         Downstream tasks should attempt to use loglikelihood instead of other
         LM calls whenever possible.

...

@@ -59,7 +64,7 @@ class LM(abc.ABC):
         pass

     @abc.abstractmethod
-    def loglikelihood_rolling(self, requests) -> list[float]:
+    def loglikelihood_rolling(self, requests: list[Instance]) -> list[float]:
         """Compute full log-likelihood of a string, with no truncation, for perplexity computation
         - We will use the full max context length of the model.
         - For inputs that exceed the max context length, we divide the tokenized string into chunks of up to

...

@@ -101,7 +106,7 @@ class LM(abc.ABC):
     # TODO: Add an optional max length
     @abc.abstractmethod
-    def generate_until(self, requests) -> list[str]:
+    def generate_until(self, requests: list[Instance]) -> list[str]:
         """Generate greedily until a stopping sequence

         :param requests: list[Instance]

...

@@ -376,7 +381,9 @@ class TemplateLM(LM):
         self, requests: list["Instance"], disable_tqdm: bool = False
     ) -> list[tuple[float, bool]]:
         new_reqs = []
-        for context, continuation in [req.args for req in requests]:
+        for context, continuation in (
+            (req.args.context, req.args.continuation) for req in requests
+        ):
             if context == "":
                 # BOS or EOS as context
                 context_enc, continuation_enc = (

...

@@ -392,12 +399,14 @@ class TemplateLM(LM):
     @abc.abstractmethod
     def loglikelihood_rolling(
-        self, requests, disable_tqdm: bool = False
+        self, requests: list[Instance], disable_tqdm: bool = False
     ) -> list[float]:
         pass

     @abc.abstractmethod
-    def generate_until(self, requests, disable_tqdm: bool = False) -> list[str]:
+    def generate_until(
+        self, requests: list[Instance], disable_tqdm: bool = False
+    ) -> list[str]:
         pass

     def chat_template(self, chat_template: Union[bool, str] = False) -> Optional[str]:

...
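As a rough illustration of the updated signatures, the following hypothetical subclass implements the three abstract request methods against list[Instance] (placeholder scores only; it assumes these are the only methods that must be overridden):

from lm_eval.api.instance import Instance
from lm_eval.api.model import LM


class DummyLM(LM):
    def loglikelihood(self, requests: list[Instance]) -> list[tuple[float, bool]]:
        # request arguments are LoglikelihoodInput objects, read by field name
        return [(-float(len(req.args.continuation or "")), False) for req in requests]

    def loglikelihood_rolling(self, requests: list[Instance]) -> list[float]:
        return [-float(len(req.args.context)) for req in requests]

    def generate_until(self, requests: list[Instance]) -> list[str]:
        # each req.args is a GenerateInput carrying the prompt and generation kwargs
        return ["" for _ in requests]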
lm_eval/api/task.py

...

@@ -36,6 +36,7 @@ from lm_eval.api.registry import (
     get_metric_aggregation,
     is_higher_better,
 )
+from lm_eval.api.types import GenerateInput, LoglikelihoodInput
 from lm_eval.caching.cache import load_from_cache, save_to_cache
 from lm_eval.filters import build_filter_ensemble
 from lm_eval.prompts import get_prompt

...

@@ -1493,6 +1494,13 @@ class ConfigurableTask(Task):
         elif self.OUTPUT_TYPE == "generate_until":
             arguments = (ctx, deepcopy(self.config.generation_kwargs))
+        else:
+            raise ValueError(
+                f"Unsupported OUTPUT_TYPE: '{self.OUTPUT_TYPE}'. "
+                f"Expected one of: 'loglikelihood', 'loglikelihood_rolling', "
+                f"'multiple_choice', 'generate_until'"
+            )
+
         multimodal_arg = {}
         if (
             self.config.doc_to_image

...

@@ -1521,7 +1529,7 @@ class ConfigurableTask(Task):
                 Instance(
                     request_type="loglikelihood",
                     doc=doc,
-                    arguments=arg,
+                    arguments=LoglikelihoodInput(context=arg[0], continuation=arg[1]),
                     idx=i,
                     **kwargs,
                 )

...

@@ -1533,7 +1541,9 @@ class ConfigurableTask(Task):
         return Instance(
             request_type=self.OUTPUT_TYPE,
             doc=doc,
-            arguments=arguments,
+            arguments=LoglikelihoodInput(*arguments)
+            if self.OUTPUT_TYPE in ["loglikelihood", "loglikelihood_rolling"]
+            else GenerateInput(*arguments),
             idx=0,
             **kwargs,
         )

...

@@ -1846,7 +1856,7 @@ class MultipleChoiceTask(Task):

 class PerplexityTask(Task):
-    OUTPUT_TYPE = "loglikelihood_rolling"
+    OUTPUT_TYPE: OutputType = "loglikelihood_rolling"

     def has_training_docs(self) -> bool:
         return False

...
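The request-construction dispatch above can be restated in isolation as a small sketch (a hypothetical helper, not harness code) showing which dataclass each OUTPUT_TYPE maps onto:

from lm_eval.api.types import GenerateInput, LoglikelihoodInput


def wrap_arguments(output_type: str, arguments: tuple):
    # loglikelihood-style requests carry (context, continuation) or (string,);
    # generate_until carries (prompt, generation_kwargs)
    if output_type in ["loglikelihood", "loglikelihood_rolling"]:
        return LoglikelihoodInput(*arguments)
    return GenerateInput(*arguments)


wrap_arguments("loglikelihood", ("2 + 2 =", " 4"))
wrap_arguments("loglikelihood_rolling", ("some long document text",))
wrap_arguments("generate_until", ("2 + 2 =", {"max_gen_toks": 8}))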
lm_eval/api/types.py (new file, 100644)

from dataclasses import dataclass
from typing import Optional


@dataclass
class GenerateInput:
    """
    Inputs for the generate function.
    """

    prompt: str
    gen_kwargs: dict
    multimodal_arg: Optional[dict] = None


@dataclass
class GenerateOutput:
    """
    Outputs for the generate function.
    """

    text: str
    metadata: dict = None


@dataclass
class LoglikelihoodInput:
    """
    Inputs for the loglikelihood function.
    """

    context: str
    continuation: Optional[str] = None


@dataclass
class LoglikelihoodOutput:
    """
    Outputs for the loglikelihood function.
    """

    loglikelihood: float
    is_greedy: Optional[bool] = None
    ctx_tokens: Optional[list[int]] = None
    cont_tokens: Optional[list[int]] = None
    metadata: Optional[dict] = None

    def __iter__(self):
        return iter((self.loglikelihood, self.is_greedy))
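A small sketch of how these dataclasses behave, in particular why LoglikelihoodOutput defines __iter__: callers that still unpack (loglikelihood, is_greedy) tuples keep working, while the richer fields stay available by name.

from lm_eval.api.types import GenerateOutput, LoglikelihoodOutput

out = LoglikelihoodOutput(loglikelihood=-3.2, is_greedy=True, cont_tokens=[11, 42])
ll, greedy = out                      # tuple-style unpacking via __iter__
assert (ll, greedy) == (-3.2, True)
assert out.cont_tokens == [11, 42]    # extra metadata remains accessible by name

gen = GenerateOutput(text="Paris")
assert gen.text == "Paris"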
lm_eval/evaluator.py

...

@@ -560,6 +560,8 @@ def evaluate(
         # create `K` copies of each request `req` based off `K = req.repeats`
         cloned_reqs = []
         for req in reqs:
+            # Note: [req] * req.repeats creates multiple references to the same request object,
+            # not separate copies. This means all repeated entries point to the same req.resps list
             cloned_reqs.extend([req] * req.repeats)

         if (lm.world_size > 1) and (padding_requests[reqtype] > 0):

...

@@ -567,6 +569,8 @@ def evaluate(
                 cloned_reqs.extend([req] * req.repeats)

         # run requests through model
+        # Since cloned_reqs contains references to original objects, each response
+        # automatically gets appended to the correct req.resps list
         resps = getattr(lm, reqtype)(cloned_reqs)

         # put responses from model into a list of length K for each request.

...
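The two added comments describe reference semantics rather than a code change. A toy demonstration (not harness code) of the aliasing they describe:

class Req:
    def __init__(self, repeats: int):
        self.repeats = repeats
        self.resps = []


req = Req(repeats=3)
cloned = [req] * req.repeats        # three references to the same object, not copies
cloned[0].resps.append("resp_a")
cloned[2].resps.append("resp_b")
assert req.resps == ["resp_a", "resp_b"]   # every alias appends to the one shared list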
lm_eval/models/huggingface.py

...

@@ -27,6 +27,7 @@ from lm_eval import utils
 from lm_eval.api.instance import Instance
 from lm_eval.api.model import TemplateLM
 from lm_eval.api.registry import register_model
+from lm_eval.api.types import GenerateInput, GenerateOutput, LoglikelihoodOutput
 from lm_eval.models.utils import (
     Collator,
     clear_torch_cache,

...

@@ -965,7 +966,7 @@ class HFLM(TemplateLM):
     def loglikelihood_rolling(
         self, requests: List[Instance], disable_tqdm: bool = False
-    ) -> List[float]:
+    ) -> list[LoglikelihoodOutput]:
         adaptive_batch_size = None
         if self.batch_size == "auto":
             # using rolling window with maximum context

...

@@ -1025,7 +1026,7 @@ class HFLM(TemplateLM):
                     override_bs=len(batch_windows),
                 )
                 # Store results with their request indices
-                all_nlls.extend(zip(batch_indices, batch_nlls))
+                all_nlls.extend(zip(batch_indices, (x.loglikelihood for x in batch_nlls)))

         # Remove padding if necessary
         if (self.world_size > 1) and (pad_amnt > 0):

...

@@ -1038,8 +1039,8 @@ class HFLM(TemplateLM):
             # Get all nlls for this request
             request_nlls = all_nlls[current_idx : current_idx + window_count]
             # Sum up the nlls for this request (discarding is_greedy)
-            request_total = sum(nll[0] for _, nll in request_nlls)
-            loglikelihoods.append(request_total)
+            request_total = sum(nll for nll in request_nlls)
+            loglikelihoods.append(LoglikelihoodOutput(loglikelihood=request_total))
             current_idx += window_count

             string = requests[len(loglikelihoods) - 1].args[0]

...

@@ -1071,7 +1072,7 @@ class HFLM(TemplateLM):
         requests: List[Tuple[Tuple[str, str], List[int], List[int]]],
         disable_tqdm: bool = False,
         override_bs: int = None,
-    ) -> List[Tuple[float, bool]]:
+    ) -> List[LoglikelihoodOutput]:
         # TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
         res = []

...

@@ -1286,7 +1287,13 @@ class HFLM(TemplateLM):
                 # Answer: (log prob, is-exact-match)
                 answer = (float(logits.sum()), bool(max_equal))

-                res.append(answer)
+                res.append(
+                    LoglikelihoodOutput(
+                        *answer,
+                        ctx_tokens=ctx_tokens,
+                        cont_tokens=cont_toks.tolist(),
+                    )
+                )

                 if request_str is not None:
                     # special case: loglikelihood_rolling produces a number of loglikelihood requests

...

@@ -1302,8 +1309,8 @@ class HFLM(TemplateLM):
         return re_ord.get_original(res)

     def generate_until(
-        self, requests: List[Instance], disable_tqdm: bool = False
-    ) -> List[str]:
+        self, requests: List[Instance[GenerateInput]], disable_tqdm: bool = False
+    ) -> List[GenerateOutput]:
         res = []

         def _collate(req: Tuple[str, dict]):

...

@@ -1420,7 +1427,7 @@ class HFLM(TemplateLM):
                         # for seq2seq case where self.tok_decode(self.eot_token_id) = ''
                         s = s.split(term)[0]

-                res.append(s)
+                res.append(GenerateOutput(text=s))
                 self.cache_hook.add_partial("generate_until", (context, gen_kwargs), s)
                 pbar.update(1)

...
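Downstream code consuming HFLM results would now receive dataclasses instead of bare strings and tuples. A hedged sketch, assuming lm is an HFLM instance, reqs are task-built Instance objects, and that loglikelihood results are propagated to callers as LoglikelihoodOutput objects:

gen_outs = lm.generate_until(reqs)
texts = [o.text for o in gen_outs]                  # GenerateOutput wraps the string

ll_outs = lm.loglikelihood(reqs)
scores = [(o.loglikelihood, o.is_greedy) for o in ll_outs]
ll, greedy = ll_outs[0]                             # tuple unpacking still works via __iter__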