Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
0bad3ace
Commit
0bad3ace
authored
Jul 23, 2025
by
Baber
Browse files
nit
parent
43388406
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
20 additions
and
20 deletions
+20
-20
lm_eval/api/model.py
lm_eval/api/model.py
+20
-20
No files found.
lm_eval/api/model.py
View file @
0bad3ace
from
__future__
import
annotations
import
abc
import
abc
import
hashlib
import
hashlib
import
json
import
json
import
logging
import
logging
import
os
import
os
from
typing
import
TYPE_CHECKING
,
Any
,
Iterable
,
Optional
,
Type
,
TypeVar
,
Union
from
collections.abc
import
Iterable
from
typing
import
TYPE_CHECKING
,
Any
,
TypeVar
from
tqdm
import
tqdm
from
tqdm
import
tqdm
...
@@ -31,7 +34,7 @@ class LM(abc.ABC):
...
@@ -31,7 +34,7 @@ class LM(abc.ABC):
# set rank and world size to a single process, by default.
# set rank and world size to a single process, by default.
self
.
_rank
=
0
self
.
_rank
=
0
self
.
_world_size
=
1
self
.
_world_size
=
1
self
.
cache_hook
:
"
CacheHook
"
=
CacheHook
(
None
)
self
.
cache_hook
:
CacheHook
=
CacheHook
(
None
)
@
abc
.
abstractmethod
@
abc
.
abstractmethod
def
loglikelihood
(
self
,
requests
)
->
list
[
tuple
[
float
,
bool
]]:
def
loglikelihood
(
self
,
requests
)
->
list
[
tuple
[
float
,
bool
]]:
...
@@ -101,7 +104,7 @@ class LM(abc.ABC):
...
@@ -101,7 +104,7 @@ class LM(abc.ABC):
# TODO: Add an optional max length
# TODO: Add an optional max length
@
abc
.
abstractmethod
@
abc
.
abstractmethod
def
generate_until
(
self
,
requests
:
list
[
"
Instance
"
])
->
list
[
str
]:
def
generate_until
(
self
,
requests
:
list
[
Instance
])
->
list
[
str
]:
"""Generate greedily until a stopping sequence
"""Generate greedily until a stopping sequence
:param requests: list[Instance]
:param requests: list[Instance]
...
@@ -137,7 +140,7 @@ class LM(abc.ABC):
...
@@ -137,7 +140,7 @@ class LM(abc.ABC):
@
classmethod
@
classmethod
def
create_from_arg_string
(
def
create_from_arg_string
(
cls
:
T
ype
[
T
],
arg_string
:
str
,
additional_config
:
Optional
[
dict
]
=
None
cls
:
t
ype
[
T
],
arg_string
:
str
,
additional_config
:
dict
|
None
=
None
)
->
T
:
)
->
T
:
"""
"""
Creates an instance of the LM class using the given argument string and additional config.
Creates an instance of the LM class using the given argument string and additional config.
...
@@ -156,7 +159,7 @@ class LM(abc.ABC):
...
@@ -156,7 +159,7 @@ class LM(abc.ABC):
@
classmethod
@
classmethod
def
create_from_arg_obj
(
def
create_from_arg_obj
(
cls
:
T
ype
[
T
],
arg_dict
:
dict
,
additional_config
:
Optional
[
dict
]
=
None
cls
:
t
ype
[
T
],
arg_dict
:
dict
,
additional_config
:
dict
|
None
=
None
)
->
T
:
)
->
T
:
"""
"""
Creates an instance of the LM class using the given arg_obj
Creates an instance of the LM class using the given arg_obj
...
@@ -201,7 +204,7 @@ class LM(abc.ABC):
...
@@ -201,7 +204,7 @@ class LM(abc.ABC):
"To use this model with chat templates, please implement the 'tokenizer_name' property."
"To use this model with chat templates, please implement the 'tokenizer_name' property."
)
)
def
chat_template
(
self
,
chat_template
:
Union
[
bool
,
str
]
=
False
)
->
Optional
[
str
]
:
def
chat_template
(
self
,
chat_template
:
bool
|
str
=
False
)
->
str
|
None
:
"""Returns the chat template structure for user/assistant messages if a template is provided.
"""Returns the chat template structure for user/assistant messages if a template is provided.
This method is intended to be overridden in a subclass to define a specific chat template format.
This method is intended to be overridden in a subclass to define a specific chat template format.
For models that do not support chat templates, this method returns None by default.
For models that do not support chat templates, this method returns None by default.
...
@@ -209,7 +212,7 @@ class LM(abc.ABC):
...
@@ -209,7 +212,7 @@ class LM(abc.ABC):
return
""
return
""
def
set_cache_hook
(
self
,
cache_hook
:
"
CacheHook
"
)
->
None
:
def
set_cache_hook
(
self
,
cache_hook
:
CacheHook
)
->
None
:
"""Sets the cache hook for the LM, which is used to cache responses from the LM."""
"""Sets the cache hook for the LM, which is used to cache responses from the LM."""
self
.
cache_hook
=
cache_hook
self
.
cache_hook
=
cache_hook
...
@@ -221,10 +224,10 @@ def hash_args(attr: str, args: Iterable[Any]) -> str:
...
@@ -221,10 +224,10 @@ def hash_args(attr: str, args: Iterable[Any]) -> str:
class
CacheHook
:
class
CacheHook
:
def
__init__
(
self
,
cachinglm
:
Optional
[
"
CachingLM
"
]
)
->
None
:
def
__init__
(
self
,
cachinglm
:
CachingLM
|
None
)
->
None
:
"""CacheHook is used to cache responses from the LM."""
"""CacheHook is used to cache responses from the LM."""
if
cachinglm
is
None
:
if
cachinglm
is
None
:
self
.
dbdict
:
Optional
[
"
SqliteDict
"
]
=
None
self
.
dbdict
:
SqliteDict
|
None
=
None
return
return
self
.
dbdict
=
cachinglm
.
dbdict
self
.
dbdict
=
cachinglm
.
dbdict
...
@@ -238,7 +241,7 @@ class CacheHook:
...
@@ -238,7 +241,7 @@ class CacheHook:
class
CachingLM
:
class
CachingLM
:
def
__init__
(
self
,
lm
:
"
LM
"
,
cache_db
:
str
)
->
None
:
def
__init__
(
self
,
lm
:
LM
,
cache_db
:
str
)
->
None
:
"""LM wrapper that returns cached results if they exist, and uses the underlying LM if not.
"""LM wrapper that returns cached results if they exist, and uses the underlying LM if not.
:param lm: LM
:param lm: LM
...
@@ -263,7 +266,7 @@ class CachingLM:
...
@@ -263,7 +266,7 @@ class CachingLM:
eval_logger
.
debug
(
f
"Passing through attribute '
{
attr
}
' to underlying LM"
)
eval_logger
.
debug
(
f
"Passing through attribute '
{
attr
}
' to underlying LM"
)
return
lm_attr
return
lm_attr
def
_fn
(
requests
:
list
[
"
Instance
"
])
->
list
[
"
Instance
"
]:
def
_fn
(
requests
:
list
[
Instance
])
->
list
[
Instance
]:
res
=
[]
res
=
[]
remaining_reqs
=
[]
remaining_reqs
=
[]
warned
=
False
warned
=
False
...
@@ -295,11 +298,8 @@ class CachingLM:
...
@@ -295,11 +298,8 @@ class CachingLM:
eval_logger
.
info
(
eval_logger
.
info
(
f
"Cached requests:
{
len
(
requests
)
-
len
(
remaining_reqs
)
}
, Requests remaining:
{
len
(
remaining_reqs
)
}
"
f
"Cached requests:
{
len
(
requests
)
-
len
(
remaining_reqs
)
}
, Requests remaining:
{
len
(
remaining_reqs
)
}
"
)
)
if
remaining_reqs
:
# actually run the LM on the requests that do not have cached results
rem_res
=
getattr
(
self
.
lm
,
attr
)(
remaining_reqs
)
if
remaining_reqs
else
[]
rem_res
=
getattr
(
self
.
lm
,
attr
)(
remaining_reqs
)
else
:
rem_res
=
[]
# stick the new ones back into the list and also cache any of the new ones
# stick the new ones back into the list and also cache any of the new ones
resptr
=
0
resptr
=
0
...
@@ -318,7 +318,7 @@ class CachingLM:
...
@@ -318,7 +318,7 @@ class CachingLM:
return
_fn
return
_fn
def
get_cache_hook
(
self
)
->
"
CacheHook
"
:
def
get_cache_hook
(
self
)
->
CacheHook
:
return
CacheHook
(
self
)
return
CacheHook
(
self
)
...
@@ -395,7 +395,7 @@ class TemplateLM(LM):
...
@@ -395,7 +395,7 @@ class TemplateLM(LM):
return
context_enc
,
continuation_enc
return
context_enc
,
continuation_enc
def
loglikelihood
(
def
loglikelihood
(
self
,
requests
:
list
[
"
Instance
"
],
disable_tqdm
:
bool
=
False
self
,
requests
:
list
[
Instance
],
disable_tqdm
:
bool
=
False
)
->
list
[
tuple
[
float
,
bool
]]:
)
->
list
[
tuple
[
float
,
bool
]]:
"""Compute log-likelihood of generating a continuation from a context.
"""Compute log-likelihood of generating a continuation from a context.
...
@@ -428,7 +428,7 @@ class TemplateLM(LM):
...
@@ -428,7 +428,7 @@ class TemplateLM(LM):
@
abc
.
abstractmethod
@
abc
.
abstractmethod
def
generate_until
(
def
generate_until
(
self
,
requests
:
list
[
"
Instance
"
],
disable_tqdm
:
bool
=
False
self
,
requests
:
list
[
Instance
],
disable_tqdm
:
bool
=
False
)
->
list
[
str
]:
)
->
list
[
str
]:
"""Generate until a stopping sequence.
"""Generate until a stopping sequence.
...
@@ -449,7 +449,7 @@ class TemplateLM(LM):
...
@@ -449,7 +449,7 @@ class TemplateLM(LM):
"""
"""
pass
pass
def
chat_template
(
self
,
chat_template
:
Union
[
bool
,
str
]
=
False
)
->
Optional
[
str
]
:
def
chat_template
(
self
,
chat_template
:
bool
|
str
=
False
)
->
str
|
None
:
"""
"""
Assumes tokenizer has a chat_template attribute (self.tokenizer.chat_template: dict | str)
Assumes tokenizer has a chat_template attribute (self.tokenizer.chat_template: dict | str)
Set and get the appropriate chat template for the model.
Set and get the appropriate chat template for the model.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment