Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
c81c03ee
"vscode:/vscode.git/clone" did not exist on "7f9bf1f2068d343246d81715844f6dea003ac449"
Commit
c81c03ee
authored
Jul 08, 2025
by
Baber
Browse files
cleanup
parent
674611e9
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
70 additions
and
21 deletions
+70
-21
lm_eval/api/model.py
lm_eval/api/model.py
+59
-11
lm_eval/config/task.py
lm_eval/config/task.py
+11
-10
No files found.
lm_eval/api/model.py
View file @
c81c03ee
...
@@ -24,7 +24,7 @@ T = TypeVar("T", bound="LM")
...
@@ -24,7 +24,7 @@ T = TypeVar("T", bound="LM")
class
LM
(
abc
.
ABC
):
class
LM
(
abc
.
ABC
):
def
__init__
(
self
)
->
None
:
def
__init__
(
self
)
->
None
:
"""Defines the interface that should be implemented by all LM subclasses.
"""Defines the interface that should be implemented by all LM subclasses.
LMs are assumed to take text (strings) as input and yield strings as output
LMs are assumed to take text (strings) as input and yield strings
or logprobabilities
as output
(inputs/outputs should be tokenization-agnostic.)
(inputs/outputs should be tokenization-agnostic.)
"""
"""
...
@@ -34,7 +34,7 @@ class LM(abc.ABC):
...
@@ -34,7 +34,7 @@ class LM(abc.ABC):
self
.
cache_hook
:
"CacheHook"
=
CacheHook
(
None
)
self
.
cache_hook
:
"CacheHook"
=
CacheHook
(
None
)
@
abc
.
abstractmethod
@
abc
.
abstractmethod
def
loglikelihood
(
self
,
requests
)
->
list
[
tuple
[
float
,
bool
]]:
def
loglikelihood
(
self
,
requests
:
list
[
Instance
]
)
->
list
[
tuple
[
float
,
bool
]]:
"""Compute log-likelihood of generating a continuation from a context.
"""Compute log-likelihood of generating a continuation from a context.
Downstream tasks should attempt to use loglikelihood instead of other
Downstream tasks should attempt to use loglikelihood instead of other
LM calls whenever possible.
LM calls whenever possible.
...
@@ -59,7 +59,7 @@ class LM(abc.ABC):
...
@@ -59,7 +59,7 @@ class LM(abc.ABC):
pass
pass
@
abc
.
abstractmethod
@
abc
.
abstractmethod
def
loglikelihood_rolling
(
self
,
requests
)
->
list
[
float
]:
def
loglikelihood_rolling
(
self
,
requests
:
list
[
Instance
]
)
->
list
[
float
]:
"""Compute full log-likelihood of a string, with no truncation, for perplexity computation
"""Compute full log-likelihood of a string, with no truncation, for perplexity computation
- We will use the full max context length of the model.
- We will use the full max context length of the model.
- For inputs that exceed the max context length, we divide the tokenized string into chunks of up to
- For inputs that exceed the max context length, we divide the tokenized string into chunks of up to
...
@@ -67,7 +67,7 @@ class LM(abc.ABC):
...
@@ -67,7 +67,7 @@ class LM(abc.ABC):
- IMPORTANT: Each document's loglikelihood/perplexity is computed *separately*, unlike other implementations
- IMPORTANT: Each document's loglikelihood/perplexity is computed *separately*, unlike other implementations
which may simply concatenate multiple documents together.
which may simply concatenate multiple documents together.
- IMPORTANT: We maximize the amount of context for each prediction. Specifically, for inputs that we break into
- IMPORTANT: We maximize the amount of context for each prediction. Specifically, for inputs that we break into
multiple chunks, the last input will still
a
full-sized context.
multiple chunks, the last input will still
have
full-sized context.
Example:
Example:
Input tokens: [ 0 1 2 3 4 5 6 7 8 9 ]
Input tokens: [ 0 1 2 3 4 5 6 7 8 9 ]
Prefix: BOS/EOS
Prefix: BOS/EOS
...
@@ -101,7 +101,7 @@ class LM(abc.ABC):
...
@@ -101,7 +101,7 @@ class LM(abc.ABC):
# TODO: Add an optional max length
# TODO: Add an optional max length
@
abc
.
abstractmethod
@
abc
.
abstractmethod
def
generate_until
(
self
,
requests
)
->
list
[
str
]:
def
generate_until
(
self
,
requests
:
list
[
Instance
]
)
->
list
[
str
]:
"""Generate greedily until a stopping sequence
"""Generate greedily until a stopping sequence
:param requests: list[Instance]
:param requests: list[Instance]
...
@@ -118,7 +118,7 @@ class LM(abc.ABC):
...
@@ -118,7 +118,7 @@ class LM(abc.ABC):
pass
pass
def
apply_chat_template
(
def
apply_chat_template
(
self
,
chat_history
:
list
[
dict
[
str
,
str
]
],
add_generation_prompt
=
True
self
,
chat_history
:
list
[
dict
],
add_generation_prompt
=
True
)
->
str
:
)
->
str
:
"""
"""
Defines how to transform few-shot examples provided as chat history into a format that can be used as input to the LM.
Defines how to transform few-shot examples provided as chat history into a format that can be used as input to the LM.
...
@@ -177,6 +177,7 @@ class LM(abc.ABC):
...
@@ -177,6 +177,7 @@ class LM(abc.ABC):
@
property
@
property
def
rank
(
self
)
->
int
:
def
rank
(
self
)
->
int
:
"""Returns the rank of the current process in a distributed setting."""
# used in the case of parallelism. Hardcoded to
# used in the case of parallelism. Hardcoded to
# ensure no errors arise using API models which do
# ensure no errors arise using API models which do
# not support multi-device parallelism nor expect it.
# not support multi-device parallelism nor expect it.
...
@@ -184,6 +185,7 @@ class LM(abc.ABC):
...
@@ -184,6 +185,7 @@ class LM(abc.ABC):
@
property
@
property
def
world_size
(
self
)
->
int
:
def
world_size
(
self
)
->
int
:
"""Returns the total number of processes in a distributed setting."""
# used in the case of parallelism. Hardcoded to
# used in the case of parallelism. Hardcoded to
# ensure no errors arise using API models which do
# ensure no errors arise using API models which do
# not support multi-device parallelism nor expect it.
# not support multi-device parallelism nor expect it.
...
@@ -208,6 +210,7 @@ class LM(abc.ABC):
...
@@ -208,6 +210,7 @@ class LM(abc.ABC):
return
""
return
""
def
set_cache_hook
(
self
,
cache_hook
:
"CacheHook"
)
->
None
:
def
set_cache_hook
(
self
,
cache_hook
:
"CacheHook"
)
->
None
:
"""Sets the cache hook for the LM, which is used to cache responses from the LM."""
self
.
cache_hook
=
cache_hook
self
.
cache_hook
=
cache_hook
...
@@ -219,6 +222,7 @@ def hash_args(attr: str, args: Iterable[Any]) -> str:
...
@@ -219,6 +222,7 @@ def hash_args(attr: str, args: Iterable[Any]) -> str:
class
CacheHook
:
class
CacheHook
:
def
__init__
(
self
,
cachinglm
:
Optional
[
"CachingLM"
])
->
None
:
def
__init__
(
self
,
cachinglm
:
Optional
[
"CachingLM"
])
->
None
:
"""CacheHook is used to cache responses from the LM."""
if
cachinglm
is
None
:
if
cachinglm
is
None
:
self
.
dbdict
:
Optional
[
"SqliteDict"
]
=
None
self
.
dbdict
:
Optional
[
"SqliteDict"
]
=
None
return
return
...
@@ -226,6 +230,7 @@ class CacheHook:
...
@@ -226,6 +230,7 @@ class CacheHook:
self
.
dbdict
=
cachinglm
.
dbdict
self
.
dbdict
=
cachinglm
.
dbdict
def
add_partial
(
self
,
attr
:
str
,
req
:
Iterable
[
Any
],
res
:
Any
)
->
None
:
def
add_partial
(
self
,
attr
:
str
,
req
:
Iterable
[
Any
],
res
:
Any
)
->
None
:
"""Adds a partial result to the cache."""
if
self
.
dbdict
is
None
:
if
self
.
dbdict
is
None
:
return
return
hsh
=
hash_args
(
attr
,
req
)
hsh
=
hash_args
(
attr
,
req
)
...
@@ -328,11 +333,12 @@ class TemplateLM(LM):
...
@@ -328,11 +333,12 @@ class TemplateLM(LM):
@
property
@
property
@
abc
.
abstractmethod
@
abc
.
abstractmethod
def
eot_token_id
(
self
)
->
int
:
def
eot_token_id
(
self
)
->
int
:
"""Returns the token ID for the end-of-text token (e.g., EOS)."""
pass
pass
@
property
@
property
def
prefix_token_id
(
self
)
->
int
:
def
prefix_token_id
(
self
)
->
int
:
# it is used as prefix for loglikelihood
"""Returns the token ID for the prefix token (e.g., BOS or EOS)."""
return
self
.
eot_token_id
return
self
.
eot_token_id
@
abc
.
abstractmethod
@
abc
.
abstractmethod
...
@@ -344,8 +350,24 @@ class TemplateLM(LM):
...
@@ -344,8 +350,24 @@ class TemplateLM(LM):
@
abc
.
abstractmethod
@
abc
.
abstractmethod
def
_loglikelihood_tokens
(
def
_loglikelihood_tokens
(
self
,
requests
:
list
[
"Instance"
],
**
kwargs
self
,
requests
:
list
[
tuple
[
tuple
[
str
,
str
],
list
[
int
],
list
[
int
]]
],
**
kwargs
)
->
list
[
tuple
[
float
,
bool
]]:
)
->
list
[
tuple
[
float
,
bool
]]:
"""Called by loglikelihood to compute log likelihoods for a list of requests.
Args:
requests: list[tuple[tuple[str, str], list[int], list[int]]]
A list of tuples where each tuple contains:
- (context, continuation) as a tuple of strings
- context_enc: list of token IDs for the context
- continuation_enc: list of token IDs for the continuation
Returns:
list[tuple[float, bool]]
A list of tuples where each tuple contains:
- logprob: float, the (summed) log probability of the continuation given the context
- isgreedy: bool, whether the continuation would be generated by greedy sampling from the context
See LM.loglikelihood for more details.
"""
pass
pass
def
_encode_pair
(
def
_encode_pair
(
...
@@ -353,8 +375,7 @@ class TemplateLM(LM):
...
@@ -353,8 +375,7 @@ class TemplateLM(LM):
)
->
tuple
[
list
[
int
],
list
[
int
]]:
)
->
tuple
[
list
[
int
],
list
[
int
]]:
"""Encodes a pair of context and continuation strings into token IDs.
"""Encodes a pair of context and continuation strings into token IDs.
Ensures that encode(context + continuation) == encode(context) + encode(continuation)
We encode using encode(context+continuation) and then split into context and continuation.
"""
"""
import
transformers
import
transformers
...
@@ -380,6 +401,10 @@ class TemplateLM(LM):
...
@@ -380,6 +401,10 @@ class TemplateLM(LM):
def
loglikelihood
(
def
loglikelihood
(
self
,
requests
:
list
[
"Instance"
],
disable_tqdm
:
bool
=
False
self
,
requests
:
list
[
"Instance"
],
disable_tqdm
:
bool
=
False
)
->
list
[
tuple
[
float
,
bool
]]:
)
->
list
[
tuple
[
float
,
bool
]]:
"""Compute log-likelihood of generating a continuation from a context.
This calls `_loglikelihood_tokens` to compute the log likelihoods for a list of requests, after encoding.
"""
new_reqs
=
[]
new_reqs
=
[]
for
context
,
continuation
in
[
req
.
args
for
req
in
requests
]:
for
context
,
continuation
in
[
req
.
args
for
req
in
requests
]:
if
context
==
""
:
if
context
==
""
:
...
@@ -399,10 +424,33 @@ class TemplateLM(LM):
...
@@ -399,10 +424,33 @@ class TemplateLM(LM):
def
loglikelihood_rolling
(
def
loglikelihood_rolling
(
self
,
requests
,
disable_tqdm
:
bool
=
False
self
,
requests
,
disable_tqdm
:
bool
=
False
)
->
list
[
float
]:
)
->
list
[
float
]:
"""Compute rolling log-likelihood of a sequence using non-overlapping windows.
See LM.loglikelihood_rolling for more details.
"""
pass
pass
@
abc
.
abstractmethod
@
abc
.
abstractmethod
def
generate_until
(
self
,
requests
,
disable_tqdm
:
bool
=
False
)
->
list
[
str
]:
def
generate_until
(
self
,
requests
,
disable_tqdm
:
bool
=
False
)
->
list
[
str
]:
"""Generate until a stopping sequence.
Args:
requests: list[Instance]
A list of Instance objects with property `args` which returns a tuple (context, gen_kwargs).
context: str
Context string
gen_kwargs: dict
A dictionary of keyword arguments to pass to the generation function e.g. top_k, until, etc.
Returns:
list[continuation, ...]
A list of model generated continuations.
continuation: str
The generated continuation.
See LM.generate_until for more details.
"""
pass
pass
def
chat_template
(
self
,
chat_template
:
Union
[
bool
,
str
]
=
False
)
->
Optional
[
str
]:
def
chat_template
(
self
,
chat_template
:
Union
[
bool
,
str
]
=
False
)
->
Optional
[
str
]:
...
...
lm_eval/config/task.py
View file @
c81c03ee
...
@@ -21,7 +21,7 @@ class RepeatConfig:
...
@@ -21,7 +21,7 @@ class RepeatConfig:
repeats
:
int
=
1
repeats
:
int
=
1
metric_fn
:
Union
[
str
,
Callable
]
=
"pass@N"
metric_fn
:
Union
[
str
,
Callable
]
=
"pass@N"
kwargs
:
Optional
[
dict
]
=
None
kwargs
:
Optional
[
dict
]
=
field
(
default_factory
=
dict
)
@
dataclass
@
dataclass
...
@@ -30,7 +30,7 @@ class FilterConfig:
...
@@ -30,7 +30,7 @@ class FilterConfig:
name
:
str
name
:
str
fn
:
Optional
[
Callable
]
=
None
fn
:
Optional
[
Callable
]
=
None
kwargs
:
Optional
[
dict
]
=
None
kwargs
:
Optional
[
dict
]
=
field
(
default_factory
=
dict
)
@
dataclass
@
dataclass
...
@@ -123,13 +123,13 @@ class DatasetConfig:
...
@@ -123,13 +123,13 @@ class DatasetConfig:
name
:
Optional
[
str
]
=
None
name
:
Optional
[
str
]
=
None
kwargs
:
Optional
[
dict
]
=
field
(
default_factory
=
dict
)
kwargs
:
Optional
[
dict
]
=
field
(
default_factory
=
dict
)
custom
:
Optional
[
Callable
]
=
None
custom
:
Optional
[
Callable
]
=
None
metadata
:
Optional
[
dict
]
=
None
metadata
:
Optional
[
dict
]
=
field
(
default_factory
=
dict
)
@
dataclass
@
dataclass
class
TaskConfig
(
dict
):
class
TaskConfig
(
dict
):
# task naming/registry
# task naming/registry
task
:
Optional
[
str
]
=
None
task
:
str
task_alias
:
Optional
[
str
]
=
None
task_alias
:
Optional
[
str
]
=
None
tag
:
Optional
[
Union
[
str
,
list
]]
=
None
tag
:
Optional
[
Union
[
str
,
list
]]
=
None
# HF dataset options.
# HF dataset options.
...
@@ -171,13 +171,14 @@ class TaskConfig(dict):
...
@@ -171,13 +171,14 @@ class TaskConfig(dict):
should_decontaminate
:
bool
=
False
should_decontaminate
:
bool
=
False
doc_to_decontamination_query
:
Optional
[
str
]
=
None
doc_to_decontamination_query
:
Optional
[
str
]
=
None
gen_prefix
:
Optional
[
str
]
=
None
gen_prefix
:
Optional
[
str
]
=
None
metadata
:
Optional
[
dict
]
=
(
metadata
:
Optional
[
dict
]
=
field
(
None
# by default, not used in the code. allows for users to pass arbitrary info to tasks
default_factory
=
dict
)
)
# by default, not used in the code. allows for users to pass arbitrary info to tasks
_metric_list
:
list
[
MetricConfig
]
=
None
_metric_list
:
list
[
MetricConfig
]
=
None
_filter_list
:
list
[
FilterConfig
]
=
None
_filter_list
:
list
[
FilterConfig
]
=
None
ds_cfg
:
DatasetConfig
=
None
ds_cfg
:
DatasetConfig
=
field
(
init
=
False
)
fewshot_cfg
:
FewshotConfig
=
None
fewshot_cfg
:
FewshotConfig
=
field
(
init
=
False
)
def
__post_init__
(
self
)
->
None
:
def
__post_init__
(
self
)
->
None
:
### ---setup generation kwargs--- ###
### ---setup generation kwargs--- ###
...
@@ -218,7 +219,7 @@ class TaskConfig(dict):
...
@@ -218,7 +219,7 @@ class TaskConfig(dict):
name
=
self
.
dataset_name
,
name
=
self
.
dataset_name
,
kwargs
=
self
.
dataset_kwargs
,
kwargs
=
self
.
dataset_kwargs
,
custom
=
self
.
custom_dataset
,
custom
=
self
.
custom_dataset
,
metadata
=
self
.
metadata
,
metadata
=
self
.
metadata
or
{}
,
)
)
# ---setup fewshot config--- #
# ---setup fewshot config--- #
_fewshot_cfg
=
self
.
fewshot_config
if
self
.
fewshot_config
is
not
None
else
{}
_fewshot_cfg
=
self
.
fewshot_config
if
self
.
fewshot_config
is
not
None
else
{}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment