Commit 7399dae1 authored by Baber

fix aggregation

parent 19d54607
@@ -11,21 +11,56 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
import asyncio
import glob
import os
import shutil
import urllib.request
from functools import cache
from typing import Dict
import html2text
import requests
import httpx
from bs4 import BeautifulSoup
from tqdm import tqdm
from tqdm.asyncio import tqdm as async_tqdm
@cache
def get_essays():
async def fetch_url(client: httpx.AsyncClient, url: str) -> str:
response = await client.get(url)
response.raise_for_status()
return response.text
async def process_html_essay(
client: httpx.AsyncClient, url: str, h: html2text.HTML2Text, temp_folder: str
) -> None:
filename = url.split("/")[-1].replace(".html", ".txt")
try:
content = await fetch_url(client, url)
soup = BeautifulSoup(content, "html.parser")
specific_tag = soup.find("font")
if specific_tag:
parsed = h.handle(str(specific_tag))
with open(
os.path.join(temp_folder, filename), "w", encoding="utf-8"
) as file:
file.write(parsed)
except Exception as e:
print(f"Failed to download {filename}: {str(e)}")
async def process_text_essay(
client: httpx.AsyncClient, url: str, temp_folder: str
) -> None:
filename = url.split("/")[-1]
try:
content = await fetch_url(client, url)
with open(os.path.join(temp_folder, filename), "w", encoding="utf-8") as file:
file.write(content)
except Exception as e:
print(f"Failed to download {filename}: {str(e)}")
async def get_essays() -> Dict[str, str]:
temp_folder_repo = "essay_repo"
temp_folder_html = "essay_html"
os.makedirs(temp_folder_repo, exist_ok=True)
@@ -38,62 +73,126 @@ def get_essays():
h.reference_links = False
h.mark_code = False
url = "https://raw.githubusercontent.com/NVIDIA/RULER/main/scripts/data/synthetic/json/PaulGrahamEssays_URLs.txt"
response = requests.get(url)
response.raise_for_status()
# The content is now in memory as a string
content = response.text
# If you want to process it line by line:
urls = content.splitlines()
for url in tqdm(urls):
if ".html" in url:
filename = url.split("/")[-1].replace(".html", ".txt")
try:
with urllib.request.urlopen(url) as website:
content = website.read().decode("unicode_escape", "utf-8")
soup = BeautifulSoup(content, "html.parser")
specific_tag = soup.find("font")
parsed = h.handle(str(specific_tag))
with open(os.path.join(temp_folder_html, filename), "w") as file:
file.write(parsed)
url_list = "https://raw.githubusercontent.com/NVIDIA/RULER/main/scripts/data/synthetic/json/PaulGrahamEssays_URLs.txt"
except Exception as e:
print(f"Fail download {filename}, ({e})")
async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client:
# Fetch URL list
content = await fetch_url(client, url_list)
urls = content.splitlines()
else:
filename = url.split("/")[-1]
try:
with urllib.request.urlopen(url) as website:
content = website.read().decode("utf-8")
# Separate HTML and text URLs
html_urls = [url for url in urls if ".html" in url]
text_urls = [url for url in urls if ".html" not in url]
with open(os.path.join(temp_folder_repo, filename), "w") as file:
file.write(content)
# Process HTML essays
html_tasks = [
process_html_essay(client, url, h, temp_folder_html) for url in html_urls
]
await async_tqdm.gather(*html_tasks, desc="Downloading HTML essays")
except Exception as e:
print(f"Fail download {filename}, ({e})")
# Process text essays
text_tasks = [
process_text_essay(client, url, temp_folder_repo) for url in text_urls
]
await async_tqdm.gather(*text_tasks, desc="Downloading text essays")
# Collect results
files_repo = sorted(glob.glob(os.path.join(temp_folder_repo, "*.txt")))
files_html = sorted(glob.glob(os.path.join(temp_folder_html, "*.txt")))
print(
f"Downloaded {len(files_repo)} essays from `https://github.com/gkamradt/LLMTest_NeedleInAHaystack/`"
)
print(f"Downloaded {len(files_html)} essays from `http://www.paulgraham.com/`")
# print(
# f"Downloaded {len(files_repo)} essays from `https://github.com/gkamradt/LLMTest_NeedleInAHaystack/`"
# )
# print(f"Downloaded {len(files_html)} essays from `http://www.paulgraham.com/`")
# Combine all texts
text = ""
for file in files_repo + files_html:
with open(file, "r") as f:
with open(file, "r", encoding="utf-8") as f:
text += f.read()
# Cleanup
shutil.rmtree(temp_folder_repo)
shutil.rmtree(temp_folder_html)
return {"text": text}
# with open('PaulGrahamEssays.json', 'w') as f:
# json.dump({"text": text}, f)
#
# shutil.rmtree(temp_folder_repo)
# shutil.rmtree(temp_folder_html)
def get_all_essays() -> Dict[str, str]:
"""Synchronous wrapper for get_essays()"""
return asyncio.run(get_essays())
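# Minimal usage sketch (not part of this commit): synchronous call sites, such as
# the haystack builder in utils.py, keep calling a plain function while the
# downloads run concurrently inside asyncio.run():
#
#     essays = get_all_essays()
#     haystack_words = re.sub(r"\s+", " ", essays["text"]).split(" ")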
# @cache
# def get_essays():
# temp_folder_repo = "essay_repo"
# temp_folder_html = "essay_html"
# os.makedirs(temp_folder_repo, exist_ok=True)
# os.makedirs(temp_folder_html, exist_ok=True)
#
# h = html2text.HTML2Text()
# h.ignore_images = True
# h.ignore_tables = True
# h.escape_all = True
# h.reference_links = False
# h.mark_code = False
#
# url = "https://raw.githubusercontent.com/NVIDIA/RULER/main/scripts/data/synthetic/json/PaulGrahamEssays_URLs.txt"
# response = requests.get(url)
# response.raise_for_status()
#
# # The content is now in memory as a string
# content = response.text
#
# # If you want to process it line by line:
# urls = content.splitlines()
#
# for url in tqdm(urls):
# if ".html" in url:
# filename = url.split("/")[-1].replace(".html", ".txt")
# try:
# with urllib.request.urlopen(url) as website:
# content = website.read().decode("unicode_escape", "utf-8")
# soup = BeautifulSoup(content, "html.parser")
# specific_tag = soup.find("font")
# parsed = h.handle(str(specific_tag))
#
# with open(os.path.join(temp_folder_html, filename), "w") as file:
# file.write(parsed)
#
# except Exception as e:
# print(f"Fail download {filename}, ({e})")
#
# else:
# filename = url.split("/")[-1]
# try:
# with urllib.request.urlopen(url) as website:
# content = website.read().decode("utf-8")
#
# with open(os.path.join(temp_folder_repo, filename), "w") as file:
# file.write(content)
#
# except Exception as e:
# print(f"Fail download {filename}, ({e})")
#
# files_repo = sorted(glob.glob(os.path.join(temp_folder_repo, "*.txt")))
# files_html = sorted(glob.glob(os.path.join(temp_folder_html, "*.txt")))
# print(
# f"Download {len(files_repo)} essays from `https://github.com/gkamradt/LLMTest_NeedleInAHaystack/`"
# )
# print(f"Download {len(files_html)} essays from `http://www.paulgraham.com/`")
#
# text = ""
# for file in files_repo + files_html:
# with open(file, "r") as f:
# text += f.read()
#
# shutil.rmtree(temp_folder_repo)
# shutil.rmtree(temp_folder_html)
# return {"text": text}
#
# # with open('PaulGrahamEssays.json', 'w') as f:
# # json.dump({"text": text}, f)
# #
# # shutil.rmtree(temp_folder_repo)
# # shutil.rmtree(temp_folder_html)
import os
import random
import uuid
from linecache import cache
from functools import lru_cache
from typing import List
import numpy as np
import wonderwords
@@ -40,6 +43,11 @@ NLTK_MIN_VERSION = "3.9.1"
RANK = os.environ.get("LOCAL_RANK", "0")
@lru_cache(maxsize=1024)
def cached_sent_tokenize(text: str) -> List[str]:
return sent_tokenize(text)
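# Illustrative note (not part of this commit): lru_cache keys on the exact text
# string, so re-tokenizing the same haystack text across generated samples is
# served from the cache instead of rerunning NLTK. A hypothetical check:
#
#     cached_sent_tokenize(text)                 # first call runs sent_tokenize
#     cached_sent_tokenize(text)                 # second call hits the cache
#     cached_sent_tokenize.cache_info().hits     # -> 1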
def download_nltk_resources():
"""Download 'punkt' if not already installed"""
assert (
@@ -119,7 +127,7 @@ def generate_input_output(
if type_haystack == "essay":
assert isinstance(haystack, list)
text = " ".join(haystack[:num_haystack])
document_sents = sent_tokenize(text.strip())
document_sents = cached_sent_tokenize(text.strip())
insertion_positions = (
[0]
+ sorted(
@@ -288,7 +296,9 @@ def generate_samples(
"max_length": max_seq_length,
}
if formatted_output["outputs"][0] not in formatted_output["input"]:
COUNT += 1
assert (
False
), f"Needle not in input: {formatted_output}. Something went wrong."
write_jsons.append(formatted_output)
print(COUNT)
return write_jsons
@@ -7,7 +7,7 @@ output_type: generate_until
test_split: test
download_dataset: !function utils.niah_single_1
doc_to_text: "{{input}}"
doc_to_target: "{{outputs[0]}}" #" {{answer.split('### ')[-1].rstrip()}}"
doc_to_target: "{{outputs[0]}}"
process_results: !function utils.process_results
metric_list:
@@ -36,4 +36,4 @@ generation_kwargs:
until: []
repeats: 1
metadata:
version: 3.0
version: 1.0
@@ -3,13 +3,13 @@ import itertools
import json
import os
import re
from functools import partial
from functools import partial, cache
from typing import Literal
import datasets
from transformers import AutoTokenizer
from lm_eval.tasks.ruler.essays import get_essays
from lm_eval.tasks.ruler.essays import get_essays, get_all_essays
from lm_eval.tasks.ruler.prepare import generate_samples
@@ -31,10 +31,11 @@ STOP_WORDS = ""
RANDOM_SEED = 42
@cache
def get_haystack(type_haystack: Literal["essay", "repeat", "needle"]):
NEEDLE = "One of the special magic {type_needle_v} for {key} is: {value}."
if type_haystack == "essay":
essay = get_essays()["text"]
essay = get_all_essays()["text"]
# essay = json.load(open(essay))["text"]
haystack = re.sub(r"\s+", " ", essay).split(" ")
elif type_haystack == "repeat":
@@ -155,7 +156,7 @@ niah_multiquery = lambda: flatten(
)
def postprocess_pred(predict_str: str):
def postprocess_pred(predict_str: str) -> str:
predict_str = predict_str.strip()
# Remove all non-printable characters
@@ -165,16 +166,18 @@ def postprocess_pred(predict_str: str):
return predict_str
def process_results(doc, results):
def process_results(doc: dict, results: list[str]) -> dict[str, float]:
# hacky: set all other lengths to -1
metrics = {str(length): -1.0 for length in SEQ_LENGTHS}
input_len = doc["max_length"]
acc = 1.0 if postprocess_pred(results[0]) in doc["input"] else 0.0
metrics[str(next(length for length in SEQ_LENGTHS if input_len <= length))] = acc
metrics[str(input_len)] = acc
return metrics
def aggregate_metrics(metrics):
return {
length: sum(metric[length] for metric in metrics) / len(metrics)
for length in SEQ_LENGTHS
}
def aggregate_metrics(metrics: list[float]) -> float:
res = [x for x in metrics if x != -1]
if not res:
# we don't have any samples with this length
return 0.0
return sum(res) / len(res)
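# Illustrative sketch (not part of this commit): how the new aggregation behaves.
# process_results now writes the accuracy under the document's own max_length key
# and a -1.0 placeholder for every other entry in SEQ_LENGTHS; aggregate_metrics
# drops the placeholders before averaging, so each length is scored only over the
# samples that actually have that length. Assuming SEQ_LENGTHS = [4096, 8192]:
#
#     per_doc = [
#         {"4096": 1.0, "8192": -1.0},   # 4096-token doc, answer found
#         {"4096": 0.0, "8192": -1.0},   # 4096-token doc, answer missed
#         {"4096": -1.0, "8192": 1.0},   # 8192-token doc, answer found
#     ]
#     aggregate_metrics([d["4096"] for d in per_doc])   # -> 0.5
#     aggregate_metrics([d["8192"] for d in per_doc])   # -> 1.0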