Commit 7399dae1 authored by Baber

fix aggregation

parent 19d54607
@@ -11,21 +11,56 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License
+import asyncio
 import glob
 import os
 import shutil
-import urllib.request
-from functools import cache
+from typing import Dict
 
 import html2text
-import requests
+import httpx
 from bs4 import BeautifulSoup
-from tqdm import tqdm
+from tqdm.asyncio import tqdm as async_tqdm
 
 
-@cache
-def get_essays():
+async def fetch_url(client: httpx.AsyncClient, url: str) -> str:
+    response = await client.get(url)
+    response.raise_for_status()
+    return response.text
+
+
+async def process_html_essay(
+    client: httpx.AsyncClient, url: str, h: html2text.HTML2Text, temp_folder: str
+) -> None:
+    filename = url.split("/")[-1].replace(".html", ".txt")
+    try:
+        content = await fetch_url(client, url)
+        soup = BeautifulSoup(content, "html.parser")
+        specific_tag = soup.find("font")
+        if specific_tag:
+            parsed = h.handle(str(specific_tag))
+            with open(
+                os.path.join(temp_folder, filename), "w", encoding="utf-8"
+            ) as file:
+                file.write(parsed)
+    except Exception as e:
+        print(f"Failed to download {filename}: {str(e)}")
+
+
+async def process_text_essay(
+    client: httpx.AsyncClient, url: str, temp_folder: str
+) -> None:
+    filename = url.split("/")[-1]
+    try:
+        content = await fetch_url(client, url)
+        with open(os.path.join(temp_folder, filename), "w", encoding="utf-8") as file:
+            file.write(content)
+    except Exception as e:
+        print(f"Failed to download {filename}: {str(e)}")
+
+
+async def get_essays() -> Dict[str, str]:
     temp_folder_repo = "essay_repo"
     temp_folder_html = "essay_html"
     os.makedirs(temp_folder_repo, exist_ok=True)
@@ -38,62 +73,126 @@ def get_essays():
     h.reference_links = False
     h.mark_code = False
 
-    url = "https://raw.githubusercontent.com/NVIDIA/RULER/main/scripts/data/synthetic/json/PaulGrahamEssays_URLs.txt"
-    response = requests.get(url)
-    response.raise_for_status()
-
-    # The content is now in memory as a string
-    content = response.text
-
-    # If you want to process it line by line:
-    urls = content.splitlines()
-
-    for url in tqdm(urls):
-        if ".html" in url:
-            filename = url.split("/")[-1].replace(".html", ".txt")
-            try:
-                with urllib.request.urlopen(url) as website:
-                    content = website.read().decode("unicode_escape", "utf-8")
-                    soup = BeautifulSoup(content, "html.parser")
-                    specific_tag = soup.find("font")
-                    parsed = h.handle(str(specific_tag))
-
-                    with open(os.path.join(temp_folder_html, filename), "w") as file:
-                        file.write(parsed)
-
-            except Exception as e:
-                print(f"Fail download {filename}, ({e})")
-
-        else:
-            filename = url.split("/")[-1]
-            try:
-                with urllib.request.urlopen(url) as website:
-                    content = website.read().decode("utf-8")
-
-                    with open(os.path.join(temp_folder_repo, filename), "w") as file:
-                        file.write(content)
-
-            except Exception as e:
-                print(f"Fail download {filename}, ({e})")
+    url_list = "https://raw.githubusercontent.com/NVIDIA/RULER/main/scripts/data/synthetic/json/PaulGrahamEssays_URLs.txt"
+
+    async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client:
+        # Fetch URL list
+        content = await fetch_url(client, url_list)
+        urls = content.splitlines()
+
+        # Separate HTML and text URLs
+        html_urls = [url for url in urls if ".html" in url]
+        text_urls = [url for url in urls if ".html" not in url]
+
+        # Process HTML essays
+        html_tasks = [
+            process_html_essay(client, url, h, temp_folder_html) for url in html_urls
+        ]
+        await async_tqdm.gather(*html_tasks, desc="Downloading HTML essays")
+
+        # Process text essays
+        text_tasks = [
+            process_text_essay(client, url, temp_folder_repo) for url in text_urls
+        ]
+        await async_tqdm.gather(*text_tasks, desc="Downloading text essays")
 
+    # Collect results
     files_repo = sorted(glob.glob(os.path.join(temp_folder_repo, "*.txt")))
     files_html = sorted(glob.glob(os.path.join(temp_folder_html, "*.txt")))
-    print(
-        f"Download {len(files_repo)} essays from `https://github.com/gkamradt/LLMTest_NeedleInAHaystack/`"
-    )
-    print(f"Download {len(files_html)} essays from `http://www.paulgraham.com/`")
+    # print(
+    #     f"Downloaded {len(files_repo)} essays from `https://github.com/gkamradt/LLMTest_NeedleInAHaystack/`"
+    # )
+    # print(f"Downloaded {len(files_html)} essays from `http://www.paulgraham.com/`")
 
+    # Combine all texts
     text = ""
     for file in files_repo + files_html:
-        with open(file, "r") as f:
+        with open(file, "r", encoding="utf-8") as f:
             text += f.read()
 
+    # Cleanup
     shutil.rmtree(temp_folder_repo)
     shutil.rmtree(temp_folder_html)
     return {"text": text}
 
-    # with open('PaulGrahamEssays.json', 'w') as f:
-    #     json.dump({"text": text}, f)
-    #
-    # shutil.rmtree(temp_folder_repo)
-    # shutil.rmtree(temp_folder_html)
+
+def get_all_essays() -> Dict[str, str]:
+    """Synchronous wrapper for get_essays()"""
+    return asyncio.run(get_essays())
+
+
+# @cache
+# def get_essays():
+#     temp_folder_repo = "essay_repo"
+#     temp_folder_html = "essay_html"
+#     os.makedirs(temp_folder_repo, exist_ok=True)
+#     os.makedirs(temp_folder_html, exist_ok=True)
+#
+#     h = html2text.HTML2Text()
+#     h.ignore_images = True
+#     h.ignore_tables = True
+#     h.escape_all = True
+#     h.reference_links = False
+#     h.mark_code = False
+#
+#     url = "https://raw.githubusercontent.com/NVIDIA/RULER/main/scripts/data/synthetic/json/PaulGrahamEssays_URLs.txt"
+#     response = requests.get(url)
+#     response.raise_for_status()
+#
+#     # The content is now in memory as a string
+#     content = response.text
+#
+#     # If you want to process it line by line:
+#     urls = content.splitlines()
+#
+#     for url in tqdm(urls):
+#         if ".html" in url:
+#             filename = url.split("/")[-1].replace(".html", ".txt")
+#             try:
+#                 with urllib.request.urlopen(url) as website:
+#                     content = website.read().decode("unicode_escape", "utf-8")
+#                     soup = BeautifulSoup(content, "html.parser")
+#                     specific_tag = soup.find("font")
+#                     parsed = h.handle(str(specific_tag))
+#
+#                     with open(os.path.join(temp_folder_html, filename), "w") as file:
+#                         file.write(parsed)
+#
+#             except Exception as e:
+#                 print(f"Fail download {filename}, ({e})")
+#
+#         else:
+#             filename = url.split("/")[-1]
+#             try:
+#                 with urllib.request.urlopen(url) as website:
+#                     content = website.read().decode("utf-8")
+#
+#                     with open(os.path.join(temp_folder_repo, filename), "w") as file:
+#                         file.write(content)
+#
+#             except Exception as e:
+#                 print(f"Fail download {filename}, ({e})")
+#
+#     files_repo = sorted(glob.glob(os.path.join(temp_folder_repo, "*.txt")))
+#     files_html = sorted(glob.glob(os.path.join(temp_folder_html, "*.txt")))
+#     print(
+#         f"Download {len(files_repo)} essays from `https://github.com/gkamradt/LLMTest_NeedleInAHaystack/`"
+#     )
+#     print(f"Download {len(files_html)} essays from `http://www.paulgraham.com/`")
+#
+#     text = ""
+#     for file in files_repo + files_html:
+#         with open(file, "r") as f:
+#             text += f.read()
+#
+#     shutil.rmtree(temp_folder_repo)
+#     shutil.rmtree(temp_folder_html)
+#     return {"text": text}
+#
+#     # with open('PaulGrahamEssays.json', 'w') as f:
+#     #     json.dump({"text": text}, f)
+#     #
+#     #     shutil.rmtree(temp_folder_repo)
+#     #     shutil.rmtree(temp_folder_html)
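Note: the essays module now downloads everything concurrently: a shared httpx.AsyncClient fetches the RULER URL list, HTML and plain-text essays are split into two task groups, tqdm.asyncio drives them with progress bars, and get_all_essays() is the synchronous entry point. Below is a minimal sketch of the same pattern, not part of the commit; the URLs and the fetch/fetch_all helper names are placeholders chosen for illustration.

# Minimal sketch of the concurrent-download pattern used above.
# EXAMPLE_URLS and the helper names are illustrative only.
import asyncio

import httpx
from tqdm.asyncio import tqdm as async_tqdm

EXAMPLE_URLS = [
    "https://www.example.com/",
    "https://www.example.org/",
]


async def fetch(client: httpx.AsyncClient, url: str) -> str:
    response = await client.get(url)
    response.raise_for_status()
    return response.text


async def fetch_all(urls: list[str]) -> list[str]:
    async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client:
        tasks = [fetch(client, url) for url in urls]
        # tqdm.asyncio's gather wraps asyncio.gather with a progress bar
        return await async_tqdm.gather(*tasks, desc="Downloading")


if __name__ == "__main__":
    pages = asyncio.run(fetch_all(EXAMPLE_URLS))
    print([len(page) for page in pages])

Because get_essays() is now a coroutine, callers that expect a plain function go through get_all_essays(), which simply wraps it in asyncio.run(); get_haystack() in utils.py only has to change the function name it calls.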
 import os
 import random
 import uuid
-from linecache import cache
+from functools import lru_cache
+from typing import List
 
 import numpy as np
 import wonderwords
@@ -40,6 +43,11 @@ NLTK_MIN_VERSION = "3.9.1"
 RANK = os.environ.get("LOCAL_RANK", "0")
 
 
+@lru_cache(maxsize=1024)
+def cached_sent_tokenize(text: str) -> List[str]:
+    return sent_tokenize(text)
+
+
 def download_nltk_resources():
     """Download 'punkt' if not already installed"""
     assert (
@@ -119,7 +127,7 @@ def generate_input_output(
     if type_haystack == "essay":
         assert isinstance(haystack, list)
         text = " ".join(haystack[:num_haystack])
-        document_sents = sent_tokenize(text.strip())
+        document_sents = cached_sent_tokenize(text.strip())
         insertion_positions = (
             [0]
             + sorted(
@@ -288,7 +296,9 @@ def generate_samples(
             "max_length": max_seq_length,
         }
         if formatted_output["outputs"][0] not in formatted_output["input"]:
-            COUNT += 1
+            assert (
+                False
+            ), f"Needle not in input: {formatted_output}. Something went wrong."
         write_jsons.append(formatted_output)
     print(COUNT)
     return write_jsons
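Note: generate_input_output() tokenizes the same concatenated essay text repeatedly across samples, so wrapping sent_tokenize in functools.lru_cache avoids redoing that work. A small sketch of the memoization behaviour follows; the ". "-splitting tokenizer is a stand-in so the example runs without NLTK's punkt data.

# Sketch of the lru_cache memoization added above; str.split stands in for
# nltk's sent_tokenize so there is no external data dependency.
from functools import lru_cache
from typing import List


@lru_cache(maxsize=1024)
def cached_sent_tokenize(text: str) -> List[str]:
    # placeholder tokenizer instead of sent_tokenize(text)
    return text.split(". ")


haystack = "First sentence. Second sentence. Third sentence."
cached_sent_tokenize(haystack)  # first call: tokenizes and caches
cached_sent_tokenize(haystack)  # second call: served from the cache
print(cached_sent_tokenize.cache_info())  # CacheInfo(hits=1, misses=1, maxsize=1024, currsize=1)

The other change in this hunk turns a silent counter into a hard failure: if the needle is ever missing from the generated input, generate_samples() now raises immediately instead of incrementing COUNT.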
@@ -7,7 +7,7 @@ output_type: generate_until
 test_split: test
 download_dataset: !function utils.niah_single_1
 doc_to_text: "{{input}}"
-doc_to_target: "{{outputs[0]}}" #" {{answer.split('### ')[-1].rstrip()}}"
+doc_to_target: "{{outputs[0]}}"
 process_results: !function utils.process_results
 metric_list:
@@ -36,4 +36,4 @@ generation_kwargs:
   until: []
 repeats: 1
 metadata:
-  version: 3.0
+  version: 1.0
@@ -3,13 +3,13 @@ import itertools
 import json
 import os
 import re
-from functools import partial
+from functools import partial, cache
 from typing import Literal
 
 import datasets
 from transformers import AutoTokenizer
 
-from lm_eval.tasks.ruler.essays import get_essays
+from lm_eval.tasks.ruler.essays import get_essays, get_all_essays
 from lm_eval.tasks.ruler.prepare import generate_samples
@@ -31,10 +31,11 @@ STOP_WORDS = ""
 RANDOM_SEED = 42
 
 
+@cache
 def get_haystack(type_haystack: Literal["essay", "repeat", "needle"]):
     NEEDLE = "One of the special magic {type_needle_v} for {key} is: {value}."
     if type_haystack == "essay":
-        essay = get_essays()["text"]
+        essay = get_all_essays()["text"]
         # essay = json.load(open(essay))["text"]
         haystack = re.sub(r"\s+", " ", essay).split(" ")
     elif type_haystack == "repeat":
@@ -155,7 +156,7 @@ niah_multiquery = lambda: flatten(
 )
 
 
-def postprocess_pred(predict_str: str):
+def postprocess_pred(predict_str: str) -> str:
     predict_str = predict_str.strip()
 
     # Remove all non-printable characters
@@ -165,16 +166,18 @@ def postprocess_pred(predict_str: str):
     return predict_str
 
 
-def process_results(doc, results):
+def process_results(doc: dict, results: list[str]) -> dict[str, float]:
+    # hacky: set all other lengths to -1
     metrics = {str(length): -1.0 for length in SEQ_LENGTHS}
     input_len = doc["max_length"]
     acc = 1.0 if postprocess_pred(results[0]) in doc["input"] else 0.0
-    metrics[str(next(length for length in SEQ_LENGTHS if input_len <= length))] = acc
+    metrics[str(input_len)] = acc
    return metrics
 
 
-def aggregate_metrics(metrics):
-    return {
-        length: sum(metric[length] for metric in metrics) / len(metrics)
-        for length in SEQ_LENGTHS
-    }
+def aggregate_metrics(metrics: list[int]) -> float:
+    res = [x for x in metrics if x != -1]
+    if not res:
+        # we don't have any samples with this length
+        return 0.0
+    return sum(res) / len(res)
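Note: this is the aggregation fix the commit message refers to. process_results() records the accuracy under the document's own max_length key and leaves every other sequence length at a -1 sentinel; the old aggregate_metrics() averaged each length's column over all documents, sentinels included, which skewed the scores, while the new version drops the sentinels and returns 0.0 when a length has no samples. A worked sketch follows; the SEQ_LENGTHS values, documents, and predictions are invented, and postprocess_pred() is omitted for brevity.

# Worked sketch of the per-length aggregation; all data below is made up.
SEQ_LENGTHS = [4096, 8192]


def process_results(doc: dict, results: list[str]) -> dict[str, float]:
    # every length starts at the -1 sentinel; only this doc's length gets a score
    metrics = {str(length): -1.0 for length in SEQ_LENGTHS}
    acc = 1.0 if results[0] in doc["input"] else 0.0
    metrics[str(doc["max_length"])] = acc
    return metrics


def aggregate_metrics(metrics: list[float]) -> float:
    res = [x for x in metrics if x != -1]
    if not res:
        # no samples for this length
        return 0.0
    return sum(res) / len(res)


docs = [
    {"max_length": 4096, "input": "needle A hidden in hay"},
    {"max_length": 4096, "input": "only hay here"},
    {"max_length": 8192, "input": "needle B hidden in hay"},
]
preds = ["needle A", "needle A", "needle B"]
per_doc = [process_results(d, [p]) for d, p in zip(docs, preds)]

# the harness aggregates each length's column separately
print(aggregate_metrics([m["4096"] for m in per_doc]))  # 0.5 (one hit, one miss)
print(aggregate_metrics([m["8192"] for m in per_doc]))  # 1.0
print(aggregate_metrics([]))  # 0.0: a length with no samples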