gaoqiong / lm-evaluation-harness · Commits

Commit 7399dae1
authored Dec 19, 2024 by Baber

fix aggregation

parent 19d54607
Showing 4 changed files with 177 additions and 65 deletions
lm_eval/tasks/ruler/essays.py    +149 -50
lm_eval/tasks/ruler/prepare.py   +12 -2
lm_eval/tasks/ruler/sniah_1.yaml +2 -2
lm_eval/tasks/ruler/utils.py     +14 -11
lm_eval/tasks/ruler/essays.py (view file @ 7399dae1)

...
@@ -11,21 +11,56 @@
Updated code in this hunk (the previous urllib.request, requests, and tqdm imports are dropped in favor of asyncio, httpx, and tqdm.asyncio, and async download helpers are added):

# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
import asyncio
import glob
import os
import shutil
from functools import cache
from typing import Dict

import html2text
import httpx
from bs4 import BeautifulSoup
from tqdm.asyncio import tqdm as async_tqdm


@cache
async def fetch_url(client: httpx.AsyncClient, url: str) -> str:
    response = await client.get(url)
    response.raise_for_status()
    return response.text


async def process_html_essay(
    client: httpx.AsyncClient, url: str, h: html2text.HTML2Text, temp_folder: str
) -> None:
    filename = url.split("/")[-1].replace(".html", ".txt")
    try:
        content = await fetch_url(client, url)
        soup = BeautifulSoup(content, "html.parser")
        specific_tag = soup.find("font")
        if specific_tag:
            parsed = h.handle(str(specific_tag))
            with open(
                os.path.join(temp_folder, filename), "w", encoding="utf-8"
            ) as file:
                file.write(parsed)
    except Exception as e:
        print(f"Failed to download {filename}: {str(e)}")


async def process_text_essay(
    client: httpx.AsyncClient, url: str, temp_folder: str
) -> None:
    filename = url.split("/")[-1]
    try:
        content = await fetch_url(client, url)
        with open(
            os.path.join(temp_folder, filename), "w", encoding="utf-8"
        ) as file:
            file.write(content)
    except Exception as e:
        print(f"Failed to download {filename}: {str(e)}")


async def get_essays() -> Dict[str, str]:
    temp_folder_repo = "essay_repo"
    temp_folder_html = "essay_html"
    os.makedirs(temp_folder_repo, exist_ok=True)
...
@@ -38,62 +73,126 @@ def get_essays():
Updated body of get_essays(); the synchronous requests/urllib.request download loop it replaces is kept, commented out, at the end of the file (see below):

    h.reference_links = False
    h.mark_code = False

    url_list = "https://raw.githubusercontent.com/NVIDIA/RULER/main/scripts/data/synthetic/json/PaulGrahamEssays_URLs.txt"

    async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client:
        # Fetch URL list
        content = await fetch_url(client, url_list)
        urls = content.splitlines()

        # Separate HTML and text URLs
        html_urls = [url for url in urls if ".html" in url]
        text_urls = [url for url in urls if ".html" not in url]

        # Process HTML essays
        html_tasks = [
            process_html_essay(client, url, h, temp_folder_html) for url in html_urls
        ]
        await async_tqdm.gather(*html_tasks, desc="Downloading HTML essays")

        # Process text essays
        text_tasks = [
            process_text_essay(client, url, temp_folder_repo) for url in text_urls
        ]
        await async_tqdm.gather(*text_tasks, desc="Downloading text essays")

    # Collect results
    files_repo = sorted(glob.glob(os.path.join(temp_folder_repo, "*.txt")))
    files_html = sorted(glob.glob(os.path.join(temp_folder_html, "*.txt")))
    # print(
    #     f"Downloaded {len(files_repo)} essays from `https://github.com/gkamradt/LLMTest_NeedleInAHaystack/`"
    # )
    # print(f"Downloaded {len(files_html)} essays from `http://www.paulgraham.com/`")

    # Combine all texts
    text = ""
    for file in files_repo + files_html:
        with open(file, "r", encoding="utf-8") as f:
            text += f.read()

    # Cleanup
    shutil.rmtree(temp_folder_repo)
    shutil.rmtree(temp_folder_html)

    return {"text": text}


def get_all_essays() -> Dict[str, str]:
    """Synchronous wrapper for get_essays()"""
    return asyncio.run(get_essays())
# @cache
# def get_essays():
# temp_folder_repo = "essay_repo"
# temp_folder_html = "essay_html"
# os.makedirs(temp_folder_repo, exist_ok=True)
# os.makedirs(temp_folder_html, exist_ok=True)
#
# h = html2text.HTML2Text()
# h.ignore_images = True
# h.ignore_tables = True
# h.escape_all = True
# h.reference_links = False
# h.mark_code = False
#
# url = "https://raw.githubusercontent.com/NVIDIA/RULER/main/scripts/data/synthetic/json/PaulGrahamEssays_URLs.txt"
# response = requests.get(url)
# response.raise_for_status()
#
# # The content is now in memory as a string
# content = response.text
#
# # If you want to process it line by line:
# urls = content.splitlines()
#
# for url in tqdm(urls):
# if ".html" in url:
# filename = url.split("/")[-1].replace(".html", ".txt")
# try:
# with urllib.request.urlopen(url) as website:
# content = website.read().decode("unicode_escape", "utf-8")
# soup = BeautifulSoup(content, "html.parser")
# specific_tag = soup.find("font")
# parsed = h.handle(str(specific_tag))
#
# with open(os.path.join(temp_folder_html, filename), "w") as file:
# file.write(parsed)
#
# except Exception as e:
# print(f"Fail download {filename}, ({e})")
#
# else:
# filename = url.split("/")[-1]
# try:
# with urllib.request.urlopen(url) as website:
# content = website.read().decode("utf-8")
#
# with open(os.path.join(temp_folder_repo, filename), "w") as file:
# file.write(content)
#
# except Exception as e:
# print(f"Fail download {filename}, ({e})")
#
# files_repo = sorted(glob.glob(os.path.join(temp_folder_repo, "*.txt")))
# files_html = sorted(glob.glob(os.path.join(temp_folder_html, "*.txt")))
# print(
# f"Download {len(files_repo)} essays from `https://github.com/gkamradt/LLMTest_NeedleInAHaystack/`"
# )
# print(f"Download {len(files_html)} essays from `http://www.paulgraham.com/`")
#
# text = ""
# for file in files_repo + files_html:
# with open(file, "r") as f:
# text += f.read()
#
# shutil.rmtree(temp_folder_repo)
# shutil.rmtree(temp_folder_html)
# return {"text": text}
#
# # with open('PaulGrahamEssays.json', 'w') as f:
# # json.dump({"text": text}, f)
# #
# # shutil.rmtree(temp_folder_repo)
# # shutil.rmtree(temp_folder_html)
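As a side note on the structure above: all downloads share one httpx.AsyncClient, are awaited together through tqdm.asyncio's gather, and the async entry point is wrapped for synchronous callers with asyncio.run. A minimal, self-contained sketch of that shape, with placeholder URLs and function names (not the commit's code):

import asyncio

import httpx
from tqdm.asyncio import tqdm as async_tqdm


async def fetch(client: httpx.AsyncClient, url: str) -> str:
    # one GET per URL, raising on HTTP errors like fetch_url() above
    response = await client.get(url)
    response.raise_for_status()
    return response.text


async def fetch_all(urls: list[str]) -> list[str]:
    # a single client is reused for every request
    async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client:
        tasks = [fetch(client, url) for url in urls]
        # like asyncio.gather, but ticks a progress bar as coroutines finish
        return await async_tqdm.gather(*tasks, desc="Downloading")


def fetch_all_sync(urls: list[str]) -> list[str]:
    # synchronous entry point, mirroring get_all_essays()
    return asyncio.run(fetch_all(urls))


if __name__ == "__main__":
    pages = fetch_all_sync(["https://example.com", "https://example.org"])
    print([len(p) for p in pages])

tqdm.asyncio's gather behaves like asyncio.gather while reporting progress per completed coroutine, which is why the rewritten get_essays() no longer needs the explicit for-loop over tqdm(urls).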
lm_eval/tasks/ruler/prepare.py (view file @ 7399dae1)
 import os
 import random
 import uuid
+from linecache import cache
+from functools import lru_cache
+from typing import List

 import numpy as np
 import wonderwords
...
@@ -40,6 +43,11 @@ NLTK_MIN_VERSION = "3.9.1"
 RANK = os.environ.get("LOCAL_RANK", "0")
+
+
+@lru_cache(maxsize=1024)
+def cached_sent_tokenize(text: str) -> List[str]:
+    return sent_tokenize(text)


 def download_nltk_resources():
     """Download 'punkt' if not already installed"""
     assert (
...
@@ -119,7 +127,7 @@ def generate_input_output(
     if type_haystack == "essay":
         assert isinstance(haystack, list)
         text = " ".join(haystack[:num_haystack])
-        document_sents = sent_tokenize(text.strip())
+        document_sents = cached_sent_tokenize(text.strip())
         insertion_positions = (
             [0]
             + sorted(
...
@@ -288,7 +296,9 @@ def generate_samples(
"max_length"
:
max_seq_length
,
"max_length"
:
max_seq_length
,
}
}
if
formatted_output
[
"outputs"
][
0
]
not
in
formatted_output
[
"input"
]:
if
formatted_output
[
"outputs"
][
0
]
not
in
formatted_output
[
"input"
]:
COUNT
+=
1
assert
(
False
),
f
"Needle not in input:
{
formatted_output
}
. Something went wrong."
write_jsons
.
append
(
formatted_output
)
write_jsons
.
append
(
formatted_output
)
print
(
COUNT
)
print
(
COUNT
)
return
write_jsons
return
write_jsons
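The cached_sent_tokenize helper added above exists because generate_input_output tokenizes largely repeated essay prefixes over and over; lru_cache memoizes the result keyed on the exact input string. A small illustrative sketch of the effect, with a toy splitter standing in for nltk's sent_tokenize (illustrative names only):

from functools import lru_cache
from typing import List


@lru_cache(maxsize=1024)
def cached_tokenize(text: str) -> List[str]:
    # stand-in for nltk.tokenize.sent_tokenize; the real call is slow enough
    # that memoizing repeated inputs is worthwhile
    return text.split(". ")


if __name__ == "__main__":
    doc = "First sentence. Second sentence. Third sentence."
    for _ in range(3):
        cached_tokenize(doc)  # only the first call does any work
    print(cached_tokenize.cache_info())  # hits=2, misses=1

The cache only pays off when identical strings recur, which the shared essay-prefix construction above tends to produce.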
lm_eval/tasks/ruler/sniah_1.yaml (view file @ 7399dae1)
...
@@ -7,7 +7,7 @@ output_type: generate_until
 test_split: test
 download_dataset: !function utils.niah_single_1
 doc_to_text: "{{input}}"
-doc_to_target: "{{outputs[0]}}" #" {{answer.split('### ')[-1].rstrip()}}"
+doc_to_target: "{{outputs[0]}}"
 process_results: !function utils.process_results
 metric_list:
...
@@ -36,4 +36,4 @@ generation_kwargs:
   until: []
 repeats: 1
 metadata:
-  version: 3.0
+  version: 1.0
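For orientation, doc_to_text and doc_to_target are Jinja-style templates rendered against each generated document. Roughly, as a standalone sketch only (using jinja2 directly on a made-up document; the harness performs this substitution internally):

from jinja2 import Template

doc = {"input": "... long haystack ... the special magic number is 42.", "outputs": ["42"]}
prompt = Template("{{input}}").render(**doc)
target = Template("{{outputs[0]}}").render(**doc)
print(target)  # 42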
lm_eval/tasks/ruler/utils.py (view file @ 7399dae1)
...
@@ -3,13 +3,13 @@ import itertools
 import json
 import os
 import re
-from functools import partial
+from functools import partial, cache
 from typing import Literal

 import datasets
 from transformers import AutoTokenizer

-from lm_eval.tasks.ruler.essays import get_essays
+from lm_eval.tasks.ruler.essays import get_essays, get_all_essays
 from lm_eval.tasks.ruler.prepare import generate_samples
...
@@ -31,10 +31,11 @@ STOP_WORDS = ""
 RANDOM_SEED = 42


+@cache
 def get_haystack(type_haystack: Literal["essay", "repeat", "needle"]):
     NEEDLE = "One of the special magic {type_needle_v} for {key} is: {value}."
     if type_haystack == "essay":
-        essay = get_essays()["text"]
+        essay = get_all_essays()["text"]
         # essay = json.load(open(essay))["text"]
         haystack = re.sub(r"\s+", " ", essay).split(" ")
     elif type_haystack == "repeat":
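The new @cache on get_haystack means the essay corpus is downloaded and split at most once per haystack type in a process, even when several RULER subtasks ask for it. A minimal sketch of that memoization pattern with a simulated expensive call (not the repo's code):

from functools import cache


@cache
def get_corpus(kind: str) -> list[str]:
    # stands in for the expensive one-time work (fetching and splitting essays)
    print(f"building corpus for {kind!r} ...")
    return [f"{kind}-token-{i}" for i in range(5)]


if __name__ == "__main__":
    get_corpus("essay")   # does the work and prints the message
    get_corpus("essay")   # served from the cache, nothing printed
    get_corpus("repeat")  # a different argument is built once as well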
...
@@ -155,7 +156,7 @@ niah_multiquery = lambda: flatten(
 )


-def postprocess_pred(predict_str: str):
+def postprocess_pred(predict_str: str) -> str:
     predict_str = predict_str.strip()

     # Remove all non-printable characters
...
@@ -165,16 +166,18 @@ def postprocess_pred(predict_str: str)
     return predict_str


-def process_results(doc, results):
+def process_results(doc: dict, results: list[str]) -> dict[str, float]:
+    # hacky: set all other lengths to -1
     metrics = {str(length): -1.0 for length in SEQ_LENGTHS}
     input_len = doc["max_length"]
     acc = 1.0 if postprocess_pred(results[0]) in doc["input"] else 0.0
-    metrics[str(next(length for length in SEQ_LENGTHS if input_len <= length))] = acc
+    metrics[str(input_len)] = acc
     return metrics


-def aggregate_metrics(metrics):
-    return {
-        length: sum(metric[length] for metric in metrics) / len(metrics)
-        for length in SEQ_LENGTHS
-    }
+def aggregate_metrics(metrics: list[int]) -> float:
+    res = [x for x in metrics if x != -1]
+    if not res:
+        # we don't have any samples with this length
+        return 0.0
+    return sum(res) / len(res)
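Taken together, the aggregation fix works like this: process_results scores a document only under its own max_length and leaves every other sequence length at the -1.0 placeholder, and aggregate_metrics then averages only the real values for each length. A self-contained sketch with made-up sequence lengths, documents, and predictions (SEQ_LENGTHS and the sample data are illustrative, and postprocessing is reduced to a bare strip()):

SEQ_LENGTHS = (4096, 8192, 16384)


def process_results(doc: dict, results: list[str]) -> dict[str, float]:
    # every length starts at the -1.0 placeholder; only the doc's own length is scored
    metrics = {str(length): -1.0 for length in SEQ_LENGTHS}
    acc = 1.0 if results[0].strip() in doc["input"] else 0.0
    metrics[str(doc["max_length"])] = acc
    return metrics


def aggregate_metrics(metrics: list[float]) -> float:
    # drop placeholders so documents of other lengths do not drag the mean toward -1
    res = [x for x in metrics if x != -1]
    if not res:
        return 0.0  # no samples at this length
    return sum(res) / len(res)


if __name__ == "__main__":
    docs = [
        {"input": "the needle is 42", "max_length": 4096},
        {"input": "no match here", "max_length": 4096},
        {"input": "another needle, 7", "max_length": 8192},
    ]
    preds = ["42", "42", "7"]
    per_doc = [process_results(d, [p]) for d, p in zip(docs, preds)]
    for length in SEQ_LENGTHS:
        print(length, aggregate_metrics([m[str(length)] for m in per_doc]))
    # 4096 -> 0.5, 8192 -> 1.0, 16384 -> 0.0 (no scored samples)

The previous aggregate_metrics averaged over every document, so the -1.0 placeholders from documents of other lengths were pulled into each mean; filtering them out is the aggregation fix the commit title refers to.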