Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
00048838
Commit
00048838
authored
Jul 22, 2025
by
Baber
Browse files
nit
parent
55be51ea
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
40 additions
and
27 deletions
+40
-27
lm_eval/api/samplers.py
lm_eval/api/samplers.py
+30
-18
lm_eval/api/task.py
lm_eval/api/task.py
+4
-5
lm_eval/config/task.py
lm_eval/config/task.py
+6
-4
No files found.
lm_eval/api/samplers.py
View file @
00048838
from
__future__
import
annotations
import
logging
import
logging
import
warnings
import
warnings
from
collections.abc
import
Iterable
,
Sequence
from
functools
import
partial
from
functools
import
partial
from
typing
import
TYPE_CHECKING
,
Iterable
,
Optional
,
Sequence
,
Union
from
typing
import
TYPE_CHECKING
,
Any
import
datasets
import
datasets
...
@@ -18,9 +21,9 @@ class ContextSampler:
...
@@ -18,9 +21,9 @@ class ContextSampler:
def
__init__
(
def
__init__
(
self
,
self
,
docs
:
list
[
dict
],
docs
:
list
[
dict
],
task
:
Union
[
"Task"
,
"
ConfigurableTask
"
]
,
task
:
Task
|
ConfigurableTask
,
fewshot_indices
:
Optional
[
Iterable
]
=
None
,
fewshot_indices
:
Iterable
|
None
=
None
,
rnd
:
Optional
[
"Random"
]
=
None
,
rnd
:
Random
|
None
=
None
,
)
->
None
:
)
->
None
:
self
.
rnd
=
rnd
self
.
rnd
=
rnd
if
not
self
.
rnd
:
if
not
self
.
rnd
:
...
@@ -75,7 +78,7 @@ class ContextSampler:
...
@@ -75,7 +78,7 @@ class ContextSampler:
)
)
self
.
docs
=
self
.
docs
.
select
(
fewshot_indices
)
self
.
docs
=
self
.
docs
.
select
(
fewshot_indices
)
def
get_context
(
self
,
doc
:
dict
,
num_fewshot
:
int
,
gen_prefix
:
str
=
None
):
def
get_context
(
self
,
doc
:
dict
,
num_fewshot
:
int
,
gen_prefix
:
str
|
None
=
None
):
# draw an extra fewshot sample if using same split as evaluating on
# draw an extra fewshot sample if using same split as evaluating on
prefix
=
gen_prefix
+
" "
if
gen_prefix
else
""
prefix
=
gen_prefix
+
" "
if
gen_prefix
else
""
n_samples
=
(
n_samples
=
(
...
@@ -95,10 +98,13 @@ class ContextSampler:
...
@@ -95,10 +98,13 @@ class ContextSampler:
for
doc
in
selected_docs
:
for
doc
in
selected_docs
:
doc_content
=
self
.
doc_to_text
(
doc
)
doc_content
=
self
.
doc_to_text
(
doc
)
doc_target
=
self
.
doc_to_target
(
doc
)
doc_target
=
self
.
doc_to_target
(
doc
)
if
self
.
config
.
doc_to_choice
is
None
or
isinstance
(
doc_content
,
str
):
if
(
self
.
config
.
doc_to_choice
is
None
and
isinstance
(
doc_content
,
str
)
)
or
isinstance
(
doc_content
,
str
):
labeled_examples
+=
doc_content
labeled_examples
+=
doc_content
else
:
else
:
labeled_examples
+=
self
.
doc_to_choice
(
doc
)[
doc_content
]
if
isinstance
(
doc_content
,
int
):
labeled_examples
+=
self
.
doc_to_choice
(
doc
)[
doc_content
]
if
doc_target
!=
""
:
if
doc_target
!=
""
:
if
self
.
target_delimiter
.
isspace
()
and
str
(
doc_target
)[
0
].
isspace
():
if
self
.
target_delimiter
.
isspace
()
and
str
(
doc_target
)[
0
].
isspace
():
...
@@ -126,7 +132,7 @@ class ContextSampler:
...
@@ -126,7 +132,7 @@ class ContextSampler:
doc
:
dict
,
doc
:
dict
,
num_fewshot
:
int
,
num_fewshot
:
int
,
fewshot_as_multiturn
:
bool
=
False
,
fewshot_as_multiturn
:
bool
=
False
,
gen_prefix
:
Optional
[
str
]
=
None
,
gen_prefix
:
str
|
None
=
None
,
):
):
# TODO: Do we need any other delimiter
# TODO: Do we need any other delimiter
prefix
=
gen_prefix
+
" "
if
gen_prefix
else
""
prefix
=
gen_prefix
+
" "
if
gen_prefix
else
""
...
@@ -181,16 +187,22 @@ class ContextSampler:
...
@@ -181,16 +187,22 @@ class ContextSampler:
return
chat_history
return
chat_history
# @classmethod
# def from_fewshot_dfg(cls, cfg: FewshotConfig):
# if not
def
sample
(
self
,
n
:
int
)
->
Sequence
[
dict
]:
def
sample
(
self
,
n
:
int
)
->
Sequence
[
dict
]:
"""
"""
Draw `n` samples from our fewshot docs. This method should be overridden by subclasses.
Draw `n` samples from our fewshot docs. This method should be overridden by subclasses.
"""
"""
assert
self
.
rnd
is
not
None
,
(
"Error: `rnd` must be set to a random.Random instance before sampling."
)
return
self
.
rnd
.
sample
(
self
.
docs
,
n
)
return
self
.
rnd
.
sample
(
self
.
docs
,
n
)
class
FirstNSampler
(
ContextSampler
):
class
FirstNSampler
(
ContextSampler
):
def
sample
(
self
,
n
:
int
)
->
Sequence
[
dict
]:
def
sample
(
self
,
n
:
int
)
->
Sequence
[
dict
[
str
,
Any
]
]:
"""
"""
Draw the first `n` samples in order from the specified split.
Draw the first `n` samples in order from the specified split.
Used for tasks with "canonical" ordered fewshot examples, such as MMLU and CMMLU.
Used for tasks with "canonical" ordered fewshot examples, such as MMLU and CMMLU.
...
@@ -202,22 +214,22 @@ class FirstNSampler(ContextSampler):
...
@@ -202,22 +214,22 @@ class FirstNSampler(ContextSampler):
class
BalancedSampler
(
ContextSampler
):
class
BalancedSampler
(
ContextSampler
):
def
sample
(
self
,
n
:
int
)
->
None
:
def
sample
(
self
,
n
:
int
):
"""
"""
TODO: this should return approximately class-balanced samples from our fewshot examples.
TODO: this should return approximately class-balanced samples from our fewshot examples.
TODO: what order should they be in? maybe random?
TODO: what order should they be in? maybe random?
"""
"""
pass
raise
NotImplementedError
class
ManualSampler
(
ContextSampler
):
class
ManualSampler
(
ContextSampler
):
def
sample
(
self
,
n
:
int
)
->
None
:
def
sample
(
self
,
n
:
int
):
""" """
""" """
pass
raise
NotImplementedError
SAMPLER_REGISTRY
=
{
SAMPLER_REGISTRY
:
dict
[
str
,
type
[
ContextSampler
]]
=
{
"default"
:
ContextSampler
,
"default"
:
ContextSampler
,
"first_n"
:
FirstNSampler
,
"first_n"
:
FirstNSampler
,
}
}
...
@@ -226,7 +238,7 @@ SAMPLER_REGISTRY = {
...
@@ -226,7 +238,7 @@ SAMPLER_REGISTRY = {
def
get_sampler
(
name
:
str
):
def
get_sampler
(
name
:
str
):
try
:
try
:
return
SAMPLER_REGISTRY
[
name
]
return
SAMPLER_REGISTRY
[
name
]
except
KeyError
:
except
KeyError
as
e
:
raise
Value
Error
(
raise
Key
Error
(
f
"Attempted to use contextsampler '
{
name
}
', but no sampling strategy for this name found! Supported model names:
{
', '
.
join
(
SAMPLER_REGISTRY
.
keys
())
}
"
f
"Attempted to use contextsampler '
{
name
}
', but no sampling strategy for this name found! Supported model names:
{
', '
.
join
(
SAMPLER_REGISTRY
.
keys
())
}
"
)
)
from
e
lm_eval/api/task.py
View file @
00048838
...
@@ -990,13 +990,12 @@ class ConfigurableTask(Task):
...
@@ -990,13 +990,12 @@ class ConfigurableTask(Task):
"""
"""
return
doc
return
doc
def
doc_to_text
(
self
,
doc
:
dict
,
doc_to_text
:
int
|
str
|
Callable
|
None
=
None
):
def
doc_to_text
(
self
,
doc
:
dict
,
doc_to_text
:
int
|
str
|
Callable
[...,
str
]
|
None
=
None
)
->
str
:
# if self.prompt is not None:
# if self.prompt is not None:
# doc_to_text = self.prompt
# doc_to_text = self.prompt
if
doc_to_text
is
not
None
:
doc_to_text
=
doc_to_text
or
self
.
config
.
doc_to_text
doc_to_text
=
doc_to_text
else
:
doc_to_text
=
self
.
config
.
doc_to_text
if
isinstance
(
doc_to_text
,
int
):
if
isinstance
(
doc_to_text
,
int
):
return
doc_to_text
return
doc_to_text
...
...
lm_eval/config/task.py
View file @
00048838
...
@@ -3,7 +3,7 @@ from __future__ import annotations
...
@@ -3,7 +3,7 @@ from __future__ import annotations
import
logging
import
logging
from
collections.abc
import
Iterable
from
collections.abc
import
Iterable
from
dataclasses
import
asdict
,
dataclass
,
field
from
dataclasses
import
asdict
,
dataclass
,
field
from
typing
import
TYPE_CHECKING
,
Callable
from
typing
import
TYPE_CHECKING
,
Any
,
Callable
from
lm_eval.api.filter
import
FilterEnsemble
from
lm_eval.api.filter
import
FilterEnsemble
from
lm_eval.api.instance
import
OutputType
from
lm_eval.api.instance
import
OutputType
...
@@ -45,7 +45,9 @@ class FewshotConfig:
...
@@ -45,7 +45,9 @@ class FewshotConfig:
split
:
str
|
None
=
None
split
:
str
|
None
=
None
sampler
:
str
|
Callable
=
"default"
sampler
:
str
|
Callable
=
"default"
samples
:
Callable
[[],
list
[
dict
]]
|
list
[
dict
]
|
None
=
None
samples
:
Callable
[[],
list
[
dict
]]
|
list
[
dict
]
|
None
=
None
process_docs
:
Callable
[[
list
[
dict
]],
Iterable
[
dict
]]
|
None
=
None
process_docs
:
Callable
[[
list
[
dict
[
str
,
Any
]]],
Iterable
[
dict
[
str
,
Any
]]]
|
None
=
(
None
)
fewshot_indices
:
list
[
int
]
|
None
=
None
fewshot_indices
:
list
[
int
]
|
None
=
None
rnd
:
int
=
field
(
init
=
False
,
default
=
False
)
rnd
:
int
=
field
(
init
=
False
,
default
=
False
)
...
@@ -82,7 +84,7 @@ class FewshotConfig:
...
@@ -82,7 +84,7 @@ class FewshotConfig:
"samples must be either a list of dicts or a callable returning a list"
"samples must be either a list of dicts or a callable returning a list"
)
)
def
get_docs
(
self
,
dataset
)
->
Iterable
[
dict
]
|
None
:
def
get_docs
(
self
,
dataset
)
->
Iterable
[
dict
[
str
,
Any
]
]
|
None
:
"""Get processed documents from configured source."""
"""Get processed documents from configured source."""
raw_docs
=
self
.
_get_raw_docs
(
dataset
)
raw_docs
=
self
.
_get_raw_docs
(
dataset
)
if
raw_docs
is
None
:
if
raw_docs
is
None
:
...
@@ -93,7 +95,7 @@ class FewshotConfig:
...
@@ -93,7 +95,7 @@ class FewshotConfig:
return
raw_docs
return
raw_docs
@
property
@
property
def
get_sampler
(
self
):
def
get_sampler
(
self
)
->
Callable
[...,
Any
]
|
None
:
from
lm_eval.api
import
samplers
from
lm_eval.api
import
samplers
if
isinstance
(
self
.
sampler
,
str
):
if
isinstance
(
self
.
sampler
,
str
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment