gaoqiong / lm-evaluation-harness · Commits

Commit 764f6fb2, authored Jan 17, 2025 by Baber
Commit message: better api
Parent: a61b3ee6

Showing 4 changed files with 39 additions and 36 deletions (+39 -36):

- lm_eval/api/task.py  +5 -8
- lm_eval/tasks/ruler/prepare.py  +5 -7
- lm_eval/tasks/ruler/utils.py  +28 -19
- lm_eval/tasks/ruler/vt_utils.py  +1 -2
lm_eval/api/task.py

```diff
@@ -60,7 +60,7 @@ class TaskConfig(dict):
     # HF dataset options.
     # which dataset to use,
     # and what splits for what purpose
-    download_dataset: Optional[bool] = None
+    download_dataset: Optional[Callable] = None
     dataset_path: Optional[str] = None
     dataset_name: Optional[str] = None
     dataset_kwargs: Optional[dict] = None
@@ -822,13 +822,10 @@ class ConfigurableTask(Task):
             self.download(self.config.dataset_kwargs)
         else:
             self.dataset = self.config.download_dataset(
-                metadata=self.config.metadata.get(
-                    "tokenizer",
-                    self.config.metadata.get("pretrained"),
-                ),
-                **self.config.dataset_kwargs
-                if self.config.dataset_kwargs is not None
-                else {},
-            )
+                metadata=self.config.metadata,
+                **self.config.dataset_kwargs
+                if self.config.dataset_kwargs is not None
+                else {},
+            )
         self._training_docs = None
         self._fewshot_docs = None
```
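With `download_dataset` typed as a `Callable` and now invoked with the whole `metadata` dict (instead of a pre-extracted tokenizer name), a task config can plug in its own dataset builder. A minimal sketch of a callable matching this calling convention; the function name and returned contents are hypothetical, and the split-to-Dataset return shape is inferred from the `self.dataset = ...` assignment above:

```python
from typing import Optional

import datasets


def my_download_dataset(metadata: Optional[dict] = None, **dataset_kwargs) -> dict:
    # The whole metadata dict is forwarded; the callable decides which keys it needs
    # (the ruler tasks read "tokenizer" with a fallback to "pretrained").
    metadata = metadata or {}
    tokenizer_name = metadata.get("tokenizer", metadata.get("pretrained"))
    print(f"building dataset for tokenizer={tokenizer_name}, extra kwargs={dataset_kwargs}")
    # ConfigurableTask assigns the return value to self.dataset.
    return {
        "test": datasets.Dataset.from_list(
            [{"input": "example document", "outputs": ["example answer"]}]
        )
    }
```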
lm_eval/tasks/ruler/prepare.py

```diff
@@ -193,7 +193,7 @@ def generate_input_output(
         query=query,
     )

-    return input_text, answers
+    return input_text, answers, query


 def generate_samples(
@@ -213,7 +213,7 @@ def generate_samples(
     remove_newline_tab: bool = False,
     random_seed: int = 42,
     TOKENIZER=None,
-):
+) -> list[dict]:
     assert TOKENIZER is not None, "TOKENIZER is not defined."
     num_needle_k = max(num_needle_k, num_needle_q)
     write_jsons = []
@@ -233,7 +233,7 @@ def generate_samples(
     total_tokens = 0  # Track the total tokens generated for the first example
     while total_tokens + tokens_to_generate < max_seq_length:
-        input_text, answer = generate_input_output(
+        input_text, answer, query = generate_input_output(
            num_haystack,
            haystack,
            type_haystack=type_haystack,
@@ -247,9 +247,6 @@ def generate_samples(
        )
        # Calculate the number of tokens in the example
        total_tokens = len(TOKENIZER(input_text + " ".join(answer)).input_ids)
-        # print(
-        #     f"Max length {max_seq_length} | Current length {total_tokens + tokens_to_generate} | Haystack: {num_haystack}"
-        # )
        if total_tokens + tokens_to_generate > max_seq_length:
            num_haystack -= incremental
            break
@@ -270,7 +267,7 @@ def generate_samples(
        used_haystack = num_haystack
        while True:
            try:
-                input_text, answer = generate_input_output(
+                input_text, answer, query = generate_input_output(
                    used_haystack,
                    haystack,
                    type_haystack=type_haystack,
@@ -301,6 +298,7 @@ def generate_samples(
            "outputs": answer,
            "length": length,
            "max_length": max_seq_length,
+            "gen_prefix": f"The special magic {type_needle_v} for {query} mentioned in the provided text are",
        }
        if formatted_output["outputs"][0] not in formatted_output["input"]:
            assert (
```
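Because `generate_input_output` now also returns the query, `generate_samples` can render a per-sample `gen_prefix`. An illustrative sample shape under this change; the field names come from the diff, every value is made up:

```python
type_needle_v, query = "numbers", "avocado"  # stand-ins for the real loop variables
sample = {
    "input": "<haystack text with the hidden needle> ...",
    "outputs": ["7281946"],
    "length": 4096,
    "max_length": 4096,
    "gen_prefix": f"The special magic {type_needle_v} for {query} mentioned in the provided text are",
}
print(sample["gen_prefix"])
```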
lm_eval/tasks/ruler/utils.py

```diff
@@ -2,7 +2,7 @@
 import itertools
 import re
 from functools import cache
-from typing import Literal
+from typing import Literal, Generator, Union, TYPE_CHECKING

 import datasets
 from transformers import AutoTokenizer
@@ -10,9 +10,16 @@ from transformers import AutoTokenizer
 from lm_eval.tasks.ruler.essays import get_all_essays
 from lm_eval.tasks.ruler.prepare import generate_samples

+if TYPE_CHECKING:
+    import transformers
+

 @cache
-def get_tokenizer(pretrained):
+def get_tokenizer(
+    **kwargs,
+) -> Union["transformers.PreTrainedTokenizer", "transformers.PreTrainedTokenizerFast"]:
+    kwargs = kwargs.get("metadata", {})
+    pretrained = kwargs.get("tokenizer", kwargs.get("pretrained", {}))
     assert pretrained, "No tokenizer or pretrained provided."
     print("using tokenizer ", pretrained)
     return AutoTokenizer.from_pretrained(pretrained, trust_remote_code=True)
```
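`get_tokenizer` is now keyword-driven: it pulls the checkpoint name out of `kwargs["metadata"]`, preferring `tokenizer` and falling back to `pretrained`. A standalone sketch of just that lookup, using a hypothetical helper name so the resolution order is explicit without loading a real tokenizer:

```python
def resolve_tokenizer_name(**kwargs) -> str:
    # Mirrors the lookup in get_tokenizer above (helper name is illustrative only).
    metadata = kwargs.get("metadata", {})
    name = metadata.get("tokenizer", metadata.get("pretrained"))
    assert name, "No tokenizer or pretrained provided."
    return name


print(resolve_tokenizer_name(metadata={"pretrained": "gpt2"}))  # -> gpt2
print(resolve_tokenizer_name(metadata={"tokenizer": "t5-small", "pretrained": "gpt2"}))  # -> t5-small
```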
lm_eval/tasks/ruler/utils.py (continued)

```diff
@@ -36,7 +43,9 @@ RANDOM_SEED = 42


 @cache
-def get_haystack(type_haystack: Literal["essay", "repeat", "needle"]):
+def get_haystack(
+    type_haystack: Literal["essay", "repeat", "needle"],
+) -> Union[list[str], str]:
     NEEDLE = "One of the special magic {type_needle_v} for {key} is: {value}."
     if type_haystack == "essay":
         essay = get_all_essays()["text"]
@@ -51,7 +60,7 @@ def get_haystack(type_haystack: Literal["essay", "repeat", "needle"]):
     return haystack


-def flatten(df):
+def flatten(df: Generator) -> dict[str, datasets.Dataset]:
     return {
         "test": datasets.Dataset.from_list(
             list(itertools.chain.from_iterable(df)), split=datasets.Split.TEST
```
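`flatten` gains annotations but keeps its behavior: it chains the per-sequence-length sample lists yielded by a generator into a single `test` split. A self-contained sketch assuming that shape; the function body restates the one in the diff and the sample contents are made up:

```python
import itertools

import datasets


def flatten(df):
    # Chain the per-length lists of sample dicts into one datasets.Dataset test split.
    return {
        "test": datasets.Dataset.from_list(
            list(itertools.chain.from_iterable(df)), split=datasets.Split.TEST
        )
    }


per_length_samples = (
    [{"input": f"doc {i} at length {n}", "outputs": [str(i)]} for i in range(2)]
    for n in (1024, 2048)
)
print(flatten(per_length_samples)["test"].num_rows)  # -> 4
```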
lm_eval/tasks/ruler/utils.py (continued)

```diff
@@ -60,7 +69,7 @@ def flatten(df):


 # ruff: noqa
-niah_single_1 = lambda x: flatten(
+niah_single_1 = lambda **kwargs: flatten(
     generate_samples(
         get_haystack(type_haystack="repeat"),
         max_seq_length=seq,
@@ -68,7 +77,7 @@ niah_single_1 = lambda x: flatten(
         type_haystack="repeat",
         type_needle_k="words",
         type_needle_v="numbers",
-        TOKENIZER=get_tokenizer(x),
+        TOKENIZER=get_tokenizer(**kwargs),
     )
     for seq in SEQ_LENGTHS
 )
@@ -86,7 +95,7 @@ niah_single_2 = lambda x: flatten(
     for seq in SEQ_LENGTHS
 )  # noqa
-niah_single_3 = lambda x: flatten(
+niah_single_3 = lambda **kwargs: flatten(
     generate_samples(
         get_haystack(type_haystack="essay"),
         max_seq_length=seq,
@@ -94,12 +103,12 @@ niah_single_3 = lambda x: flatten(
         type_haystack="essay",
         type_needle_k="words",
         type_needle_v="uuids",
-        TOKENIZER=get_tokenizer(x),
+        TOKENIZER=get_tokenizer(**kwargs),
     )
     for seq in SEQ_LENGTHS
 )  # noqa
-niah_multikey_1 = lambda x: flatten(
+niah_multikey_1 = lambda **kwargs: flatten(
     generate_samples(
         get_haystack(type_haystack="essay"),
         max_seq_length=seq,
@@ -108,12 +117,12 @@ niah_multikey_1 = lambda x: flatten(
         type_needle_k="words",
         type_needle_v="numbers",
         num_needle_k=4,
-        TOKENIZER=get_tokenizer(x),
+        TOKENIZER=get_tokenizer(**kwargs),
     )
     for seq in SEQ_LENGTHS
 )  # noqa
-niah_multikey_2 = lambda x: flatten(
+niah_multikey_2 = lambda **kwargs: flatten(
     generate_samples(
         get_haystack(type_haystack="needle"),
         max_seq_length=seq,
@@ -121,12 +130,12 @@ niah_multikey_2 = lambda x: flatten(
         type_haystack="needle",
         type_needle_k="words",
         type_needle_v="numbers",
-        TOKENIZER=get_tokenizer(x),
+        TOKENIZER=get_tokenizer(**kwargs),
     )
     for seq in SEQ_LENGTHS
 )  # noqa
-niah_multikey_3 = lambda x: flatten(
+niah_multikey_3 = lambda **kwargs: flatten(
     generate_samples(
         get_haystack(type_haystack="needle"),
         max_seq_length=seq,
@@ -134,12 +143,12 @@ niah_multikey_3 = lambda x: flatten(
         type_haystack="needle",
         type_needle_k="uuids",
         type_needle_v="uuids",
-        TOKENIZER=get_tokenizer(x),
+        TOKENIZER=get_tokenizer(**kwargs),
     )
     for seq in SEQ_LENGTHS
 )  # noqa
-niah_multivalue = lambda x: flatten(
+niah_multivalue = lambda **kwargs: flatten(
     generate_samples(
         get_haystack(type_haystack="essay"),
         max_seq_length=seq,
@@ -148,12 +157,12 @@ niah_multivalue = lambda x: flatten(
         type_needle_k="words",
         type_needle_v="numbers",
         num_needle_v=4,
-        TOKENIZER=get_tokenizer(x),
+        TOKENIZER=get_tokenizer(**kwargs),
     )
     for seq in SEQ_LENGTHS
 )  # noqa
-niah_multiquery = lambda x: flatten(
+niah_multiquery = lambda **kwargs: flatten(
     generate_samples(
         get_haystack(type_haystack="essay"),
         max_seq_length=seq,
@@ -162,7 +171,7 @@ niah_multiquery = lambda x: flatten(
         type_needle_k="words",
         type_needle_v="numbers",
         num_needle_q=4,
-        TOKENIZER=get_tokenizer(x),
+        TOKENIZER=get_tokenizer(**kwargs),
     )
     for seq in SEQ_LENGTHS
 )
```
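Each `niah_*` entry point changes from `lambda x: ...` (where `x` was the pretrained name handed to `get_tokenizer`) to `lambda **kwargs: ...`, so the keyword arguments passed by `download_dataset(metadata=..., **dataset_kwargs)` flow straight through to `get_tokenizer`. A purely illustrative sketch of that pattern with `fake_`-prefixed stand-ins; none of these helper names exist in the repo and no real tokenizer is loaded:

```python
SEQ_LENGTHS = (1024, 2048)  # stand-in for the module-level constant


def fake_get_tokenizer(**kwargs):
    # Same lookup order as get_tokenizer: metadata -> tokenizer -> pretrained.
    metadata = kwargs.get("metadata", {})
    return metadata.get("tokenizer", metadata.get("pretrained"))


def fake_generate_samples(max_seq_length, TOKENIZER=None):
    return [{"input": f"sample at {max_seq_length}", "tokenizer": TOKENIZER}]


# Old: niah_single_1 = lambda x: flatten(...); new: keyword-driven, matching the
# download_dataset call in lm_eval/api/task.py.
fake_niah_single_1 = lambda **kwargs: [
    fake_generate_samples(max_seq_length=seq, TOKENIZER=fake_get_tokenizer(**kwargs))
    for seq in SEQ_LENGTHS
]

print(fake_niah_single_1(metadata={"pretrained": "gpt2"}))
```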
lm_eval/tasks/ruler/vt_utils.py

```diff
@@ -24,8 +24,7 @@ import numpy as np
 from tqdm import tqdm
-from transformers import AutoTokenizer

-from lm_eval.tasks.ruler.prepare import SEQ_LENGTHS
+from lm_eval.tasks.ruler.utils import SEQ_LENGTHS


 TASKS = {
     "variable_tracking": {
```