gaoqiong / lm-evaluation-harness · Commits

Commit c81c03ee, authored Jul 08, 2025 by Baber
parent 674611e9

    cleanup

Showing 2 changed files with 70 additions and 21 deletions:

    lm_eval/api/model.py    +59 -11
    lm_eval/config/task.py  +11 -10
File: lm_eval/api/model.py
@@ -24,7 +24,7 @@ T = TypeVar("T", bound="LM")
 class LM(abc.ABC):
     def __init__(self) -> None:
         """Defines the interface that should be implemented by all LM subclasses.
-        LMs are assumed to take text (strings) as input and yield strings as output
+        LMs are assumed to take text (strings) as input and yield strings or logprobabilities as output
         (inputs/outputs should be tokenization-agnostic.)
         """
@@ -34,7 +34,7 @@ class LM(abc.ABC):
         self.cache_hook: "CacheHook" = CacheHook(None)

     @abc.abstractmethod
-    def loglikelihood(self, requests) -> list[tuple[float, bool]]:
+    def loglikelihood(self, requests: list[Instance]) -> list[tuple[float, bool]]:
         """Compute log-likelihood of generating a continuation from a context.
         Downstream tasks should attempt to use loglikelihood instead of other
         LM calls whenever possible.
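Note: each Instance passed to loglikelihood carries a (context, continuation) string pair in its `args`. A hypothetical request, for orientation (the strings are made up):

    context = "Question: Is the sky blue?\nAnswer:"
    continuation = " yes"
    # the model returns (logprob of " yes" given the context, whether " yes"
    # is what greedy decoding would have produced)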
@@ -59,7 +59,7 @@ class LM(abc.ABC):
         pass

     @abc.abstractmethod
-    def loglikelihood_rolling(self, requests) -> list[float]:
+    def loglikelihood_rolling(self, requests: list[Instance]) -> list[float]:
         """Compute full log-likelihood of a string, with no truncation, for perplexity computation
         - We will use the full max context length of the model.
         - For inputs that exceed the max context length, we divide the tokenized string into chunks of up to
@@ -67,7 +67,7 @@ class LM(abc.ABC):
         - IMPORTANT: Each document's loglikelihood/perplexity is computed *separately*, unlike other implementations
           which may simply concatenate multiple documents together.
         - IMPORTANT: We maximize the amount of context for each prediction. Specifically, for inputs that we break into
-          multiple chunks, the last input will still a full-sized context.
+          multiple chunks, the last input will still have full-sized context.
         Example:
             Input tokens: [ 0 1 2 3 4 5 6 7 8 9 ]
             Prefix: BOS/EOS
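Note: a minimal standalone sketch of the windowing scheme this docstring describes, illustrative only (the harness has its own utility for this; `max_len` is a hypothetical model context length, and the window size of 4 below is chosen only to match the docstring's 10-token example):

    def rolling_windows(tokens: list[int], prefix_token: int, max_len: int):
        # Score tokens left to right in non-overlapping chunks; each chunk's
        # inputs are the max_len positions ending at the chunk, so even the
        # final short chunk is conditioned on a full-sized context.
        shifted = [prefix_token] + tokens  # inputs are targets shifted right by one
        windows, scored = [], 0
        while scored < len(tokens):
            end = min(scored + max_len, len(tokens))
            inputs = shifted[max(0, end - max_len):end]
            targets = tokens[scored:end]  # only the not-yet-scored suffix
            windows.append((inputs, targets))
            scored = end
        return windows

    # With 10 input tokens and max_len=4 this yields
    # ([BOS, 0, 1, 2], [0, 1, 2, 3]), ([3, 4, 5, 6], [4, 5, 6, 7]),
    # and ([5, 6, 7, 8], [8, 9]) - the last chunk still gets a full window.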
@@ -101,7 +101,7 @@ class LM(abc.ABC):
     # TODO: Add an optional max length
     @abc.abstractmethod
-    def generate_until(self, requests) -> list[str]:
+    def generate_until(self, requests: list[Instance]) -> list[str]:
         """Generate greedily until a stopping sequence

         :param requests: list[Instance]
@@ -118,7 +118,7 @@ class LM(abc.ABC):
         pass

     def apply_chat_template(
-        self, chat_history: list[dict[str, str]], add_generation_prompt=True
+        self, chat_history: list[dict], add_generation_prompt=True
     ) -> str:
         """
         Defines how to transform few-shot examples provided as chat history into a format that can be used as input to the LM.
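Note: chat history uses the familiar role/content message format; the annotation is presumably relaxed from list[dict[str, str]] to list[dict] so that entries whose "content" is not a plain string (e.g., structured multimodal content) also type-check. A hypothetical call:

    chat_history = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What is the capital of France?"},
    ]
    prompt = lm.apply_chat_template(chat_history, add_generation_prompt=True)
    # `prompt` is a single string ready to feed to the model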
@@ -177,6 +177,7 @@ class LM(abc.ABC):
     @property
     def rank(self) -> int:
+        """Returns the rank of the current process in a distributed setting."""
         # used in the case of parallelism. Hardcoded to
         # ensure no errors arise using API models which do
         # not support multi-device parallelism nor expect it.
@@ -184,6 +185,7 @@ class LM(abc.ABC):
     @property
     def world_size(self) -> int:
+        """Returns the total number of processes in a distributed setting."""
         # used in the case of parallelism. Hardcoded to
         # ensure no errors arise using API models which do
         # not support multi-device parallelism nor expect it.
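Note: these properties exist so request-scheduling code can shard work across processes; a typical (hypothetical) use:

    # Each process handles an interleaved slice of the requests; for API
    # models the hardcoded rank=0, world_size=1 makes this a no-op.
    my_requests = requests[lm.rank :: lm.world_size]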
@@ -208,6 +210,7 @@ class LM(abc.ABC):
         return ""

     def set_cache_hook(self, cache_hook: "CacheHook") -> None:
+        """Sets the cache hook for the LM, which is used to cache responses from the LM."""
         self.cache_hook = cache_hook
@@ -219,6 +222,7 @@ def hash_args(attr: str, args: Iterable[Any]) -> str:
 class CacheHook:
     def __init__(self, cachinglm: Optional["CachingLM"]) -> None:
+        """CacheHook is used to cache responses from the LM."""
         if cachinglm is None:
             self.dbdict: Optional["SqliteDict"] = None
             return
@@ -226,6 +230,7 @@ class CacheHook:
         self.dbdict = cachinglm.dbdict

     def add_partial(self, attr: str, req: Iterable[Any], res: Any) -> None:
+        """Adds a partial result to the cache."""
         if self.dbdict is None:
             return
         hsh = hash_args(attr, req)
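Note: the body of hash_args is not part of this diff; one plausible implementation consistent with the signature shown in the hunk header above (the hashing and serialization choices here are assumptions, not the project's actual code):

    import hashlib
    import json

    def hash_args(attr: str, args) -> str:
        # Key the cache on the method name plus a stable serialization of the
        # request arguments (sketch only; the real serialization may differ).
        payload = json.dumps([attr, list(args)], default=str)
        return hashlib.sha256(payload.encode("utf-8")).hexdigest()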
@@ -328,11 +333,12 @@ class TemplateLM(LM):
     @property
     @abc.abstractmethod
     def eot_token_id(self) -> int:
+        """Returns the token ID for the end-of-text token (e.g., EOS)."""
         pass

     @property
     def prefix_token_id(self) -> int:
-        # it is used as prefix for loglikelihood
+        """Returns the token ID for the prefix token (e.g., BOS or EOS)."""
         return self.eot_token_id

     @abc.abstractmethod
@@ -344,8 +350,24 @@ class TemplateLM(LM):
     @abc.abstractmethod
     def _loglikelihood_tokens(
-        self, requests: list["Instance"], **kwargs
+        self, requests: list[tuple[tuple[str, str], list[int], list[int]]], **kwargs
     ) -> list[tuple[float, bool]]:
+        """Called by loglikelihood to compute log likelihoods for a list of requests.
+
+        Args:
+            requests: list[tuple[tuple[str, str], list[int], list[int]]]
+                A list of tuples where each tuple contains:
+                - (context, continuation) as a tuple of strings
+                - context_enc: list of token IDs for the context
+                - continuation_enc: list of token IDs for the continuation
+
+        Returns:
+            list[tuple[float, bool]]
+                A list of tuples where each tuple contains:
+                - logprob: float, the (summed) log probability of the continuation given the context
+                - isgreedy: bool, whether the continuation would be generated by greedy sampling from the context
+
+        See LM.loglikelihood for more details.
+        """
         pass

     def _encode_pair(
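Note: an example of one request in the newly documented format (token IDs invented for illustration):

    request = (
        ("The capital of France is", " Paris"),  # (context, continuation)
        [101, 202, 303, 404, 505],               # context_enc: context token IDs
        [606],                                   # continuation_enc: continuation token IDs
    )
    # _loglikelihood_tokens would return [(logprob_of_continuation, is_greedy)]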
@@ -353,8 +375,7 @@ class TemplateLM(LM):
     ) -> tuple[list[int], list[int]]:
+        """Encodes a pair of context and continuation strings into token IDs.
+
+        Ensures that encode(context + continuation) == encode(context) + encode(continuation)
+        We encode using encode(context+continuation) and then split into context and continuation.
+        """
         import transformers
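Note: the split matters because subword tokenizers can merge characters across the context/continuation boundary, so encoding the two pieces separately need not concatenate to the tokens of the joined string. A simplified sketch of the approach the docstring describes (assuming a tokenizer object with an `encode` method; not the harness's exact implementation):

    def encode_pair_sketch(tokenizer, context: str, continuation: str):
        # Encode the joined string once, then split it at the length of the
        # context's own encoding, so the two halves together reproduce
        # encode(context + continuation).
        whole_enc = tokenizer.encode(context + continuation)
        context_enc = tokenizer.encode(context)
        n = len(context_enc)
        return context_enc, whole_enc[n:]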
@@ -380,6 +401,10 @@ class TemplateLM(LM):
     def loglikelihood(
         self, requests: list["Instance"], disable_tqdm: bool = False
     ) -> list[tuple[float, bool]]:
+        """Compute log-likelihood of generating a continuation from a context.
+
+        This calls `_loglikelihood_tokens` to compute the log likelihoods for a list of requests, after encoding.
+        """
         new_reqs = []
         for context, continuation in [req.args for req in requests]:
             if context == "":
@@ -399,10 +424,33 @@ class TemplateLM(LM):
     def loglikelihood_rolling(
         self, requests, disable_tqdm: bool = False
     ) -> list[float]:
+        """Compute rolling log-likelihood of a sequence using non-overlapping windows.
+
+        See LM.loglikelihood_rolling for more details.
+        """
         pass

     @abc.abstractmethod
-    def generate_until(self, requests, disable_tqdm: bool = False) -> list[str]:
+    def generate_until(
+        self, requests, disable_tqdm: bool = False
+    ) -> list[str]:
+        """Generate until a stopping sequence.
+
+        Args:
+            requests: list[Instance]
+                A list of Instance objects with property `args` which returns a tuple (context, gen_kwargs).
+                context: str
+                    Context string
+                gen_kwargs: dict
+                    A dictionary of keyword arguments to pass to the generation function e.g. top_k, until, etc.
+
+        Returns:
+            list[continuation, ...]
+                A list of model generated continuations.
+                continuation: str
+                    The generated continuation.
+
+        See LM.generate_until for more details.
+        """
         pass

     def chat_template(self, chat_template: Union[bool, str] = False) -> Optional[str]:
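Note: a hypothetical `args` tuple for one generate_until request, using common generation kwargs (the exact keys accepted depend on the backend):

    context = "Question: What is the capital of France?\nAnswer:"
    gen_kwargs = {"until": ["\n"], "max_gen_toks": 32, "temperature": 0.0}
    # the model generates greedily from `context` and stops at the first "\n"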
File: lm_eval/config/task.py
@@ -21,7 +21,7 @@ class RepeatConfig:
     repeats: int = 1
     metric_fn: Union[str, Callable] = "pass@N"
-    kwargs: Optional[dict] = None
+    kwargs: Optional[dict] = field(default_factory=dict)


 @dataclass
@@ -30,7 +30,7 @@ class FilterConfig:
     name: str
     fn: Optional[Callable] = None
-    kwargs: Optional[dict] = None
+    kwargs: Optional[dict] = field(default_factory=dict)


 @dataclass
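Note: replacing None with field(default_factory=dict) gives every instance its own fresh empty dict (dataclasses forbid a bare {} default because it would be shared between instances), so downstream code can write into kwargs without a None check. A quick self-contained illustration of the pattern:

    from dataclasses import dataclass, field
    from typing import Callable, Optional

    @dataclass
    class FilterConfig:
        name: str
        fn: Optional[Callable] = None
        kwargs: Optional[dict] = field(default_factory=dict)

    cfg = FilterConfig(name="take_first")
    cfg.kwargs["k"] = 1  # safe: kwargs defaults to {} rather than None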
@@ -123,13 +123,13 @@ class DatasetConfig:
     name: Optional[str] = None
     kwargs: Optional[dict] = field(default_factory=dict)
     custom: Optional[Callable] = None
-    metadata: Optional[dict] = None
+    metadata: Optional[dict] = field(default_factory=dict)


 @dataclass
 class TaskConfig(dict):
     # task naming/registry
-    task: Optional[str] = None
+    task: str
     task_alias: Optional[str] = None
     tag: Optional[Union[str, list]] = None
     # HF dataset options.
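Note: with `task: str` (no default), the task name becomes a required constructor argument, and since `task` is the first field of TaskConfig the dataclass field ordering stays valid. Illustration only; a real TaskConfig also needs dataset options and is typically built from a task YAML:

    TaskConfig(task="hellaswag")  # OK: task name supplied
    TaskConfig()                  # TypeError: missing required argument 'task'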
@@ -171,13 +171,14 @@ class TaskConfig(dict):
     should_decontaminate: bool = False
     doc_to_decontamination_query: Optional[str] = None
     gen_prefix: Optional[str] = None
-    metadata: Optional[dict] = (
-        None  # by default, not used in the code. allows for users to pass arbitrary info to tasks
-    )
+    metadata: Optional[dict] = field(
+        default_factory=dict
+    )  # by default, not used in the code. allows for users to pass arbitrary info to tasks
     _metric_list: list[MetricConfig] = None
     _filter_list: list[FilterConfig] = None
-    ds_cfg: DatasetConfig = None
-    fewshot_cfg: FewshotConfig = None
+    ds_cfg: DatasetConfig = field(init=False)
+    fewshot_cfg: FewshotConfig = field(init=False)

     def __post_init__(self) -> None:
         ### ---setup generation kwargs--- ###
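Note: field(init=False) removes ds_cfg and fewshot_cfg from the generated __init__ entirely; they are derived inside __post_init__ rather than caller-supplied. A self-contained sketch of the pattern (names are hypothetical):

    from dataclasses import dataclass, field

    @dataclass
    class Example:
        raw: dict
        derived: dict = field(init=False)  # not accepted by __init__

        def __post_init__(self) -> None:
            # computed from other fields after construction, mirroring how
            # TaskConfig builds ds_cfg and fewshot_cfg
            self.derived = {k: str(v) for k, v in self.raw.items()}

    Example(raw={"a": 1}).derived  # {'a': '1'}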
@@ -218,7 +219,7 @@ class TaskConfig(dict):
             name=self.dataset_name,
             kwargs=self.dataset_kwargs,
             custom=self.custom_dataset,
-            metadata=self.metadata,
+            metadata=self.metadata or {},
         )
         # ---setup fewshot config--- #
         _fewshot_cfg = self.fewshot_config if self.fewshot_config is not None else {}
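Note: even with the field's default now an empty dict, a caller can still pass metadata=None explicitly; the `or {}` normalizes that before DatasetConfig sees it:

    # `x or {}` maps both None and {} to {}, leaving real dicts untouched:
    assert (None or {}) == {}
    assert ({"num_fewshot": 5} or {}) == {"num_fewshot": 5}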