# This script will build a set of scores for the accuracy of a given pdf conversion tactic against a gold dataset
import argparse
import hashlib
import json
import logging
import os
import random
import sys
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional

import boto3
import zstandard
from smart_open import register_compressor, smart_open
from tqdm import tqdm

from .dolma_refine.aligners import HirschbergAligner
from .dolma_refine.metrics import DocumentEditSimilarity
from .dolma_refine.segmenters import SpacySegmenter
from .evalhtml import create_review_html

# Suppress pypdf log messages below ERROR.
logging.getLogger("pypdf").setLevel(logging.ERROR)


# Local cache directory for gold data (created by load_gold_data).
CACHE_DIR = os.path.join(Path.home(), ".cache", "pdf_gold_data_cache")

# Module-wide S3 client, shared by the download/listing helpers below.
s3_client = boto3.client("s3")


def _handle_zst(file_obj, mode):
    # Adapter so smart_open can transparently open zstandard-compressed streams.
    return zstandard.open(file_obj, mode)


# Register the handler for both common zstandard suffixes.
register_compressor(".zstd", _handle_zst)
register_compressor(".zst", _handle_zst)


# Helper function to download files from S3
def download_from_s3(s3_path: str, local_path: str):
    """Fetch the object at *s3_path* (an s3:// URL) into the local file *local_path*."""
    without_scheme = s3_path.replace("s3://", "")
    bucket_name, key = without_scheme.split("/", 1)
    s3_client.download_file(bucket_name, key, local_path)


def is_debugging():
    """Return True when a trace function is installed (i.e. a debugger is attached)."""
    trace_fn = sys.gettrace()
    return trace_fn is not None


# Create a hash to store file contents and check for changes
def compute_file_hash(file_path: str) -> str:
    """Return the MD5 hex digest of the file at *file_path*, read in 4 KiB chunks."""
    digest = hashlib.md5()
    with open(file_path, "rb") as f:
        while chunk := f.read(4096):
            digest.update(chunk)
    return digest.hexdigest()


# A single method which can take in any format json entry (openai regular, openai structured, birr)
# and normalize it to a common structure for use later in the
@dataclass(frozen=True)
class NormalizedEntry:
    """Common representation of one page's transcription result."""

    s3_path: str
    pagenum: int
    text: Optional[str]
    finish_reason: Optional[str]
    error: Optional[str] = None

    @staticmethod
    def from_goldkey(goldkey: str, **kwargs):
        """Build an entry from a goldkey of the form "<s3_path>-<pagenum>"."""
        # Split on the LAST dash: the s3 path itself may contain dashes.
        path_part, page_part = goldkey.rsplit("-", 1)
        return NormalizedEntry(path_part, int(page_part), **kwargs)

    @property
    def goldkey(self):
        """Unique page key: "<s3_path>-<pagenum>" (inverse of from_goldkey)."""
        return f"{self.s3_path}-{self.pagenum}"


def normalize_json_entry(data: dict) -> NormalizedEntry:
    """Normalize one raw JSONL record into a NormalizedEntry.

    Supports four input shapes:
      - birr output (has an "outputs" key),
      - an already-normalized record (has every NormalizedEntry field),
      - OpenAI batch output ("response" -> "body" -> "choices"),
      - SGLang output ("response" -> "choices").

    Model content is expected to be JSON carrying a "natural_text" field; when it
    is not, the raw content is used as-is.
    """

    def _unwrap_natural_text(raw):
        # Return raw's "natural_text" field if raw is JSON of that shape, else raw itself.
        # TypeError/KeyError cover content that parses as JSON but is not a
        # {"natural_text": ...} object (e.g. a bare number) — previously these
        # cases raised instead of falling back to the raw text.
        try:
            return json.loads(raw)["natural_text"]
        except (json.JSONDecodeError, TypeError, KeyError):
            return raw

    if "outputs" in data:
        # Birr case
        if data["outputs"] is None:
            text = None
            finish_reason = None
        else:
            text = data["outputs"][0]["text"]
            finish_reason = data["outputs"][0]["finish_reason"]

        if text is not None:
            text = _unwrap_natural_text(text)

        return NormalizedEntry.from_goldkey(goldkey=data["custom_id"], text=text, finish_reason=finish_reason, error=data.get("completion_error", None))

    if all(field in data for field in ["s3_path", "pagenum", "text", "error", "finish_reason"]):
        # Already-normalized case: the record maps 1:1 onto NormalizedEntry fields.
        return NormalizedEntry(**data)

    if "response" in data and "body" in data["response"] and "choices" in data["response"]["body"]:
        # OpenAI case
        choice = data["response"]["body"]["choices"][0]
    else:
        # SGLang case
        choice = data["response"]["choices"][0]

    return NormalizedEntry.from_goldkey(
        goldkey=data["custom_id"],
        text=_unwrap_natural_text(choice["message"]["content"]),
        finish_reason=choice["finish_reason"],
    )


# Load every .json file from GOLD_DATA_S3_PATH (and saves it to some temp folder for quick loading next time)
# returns map from  "custom_id" ex. "s3://ai2-s2-pdfs/39ce/3db4516cd6e7d7f8e580a494c7a665a6a16a.pdf-4" (where the -4 means page 4)
# to the gold standard text
def load_gold_data(gold_data_path: str, max_workers: int = 8) -> dict:
    """
    Load gold data from JSONL files in a multithreaded manner.

    Args:
        gold_data_path (str): Path to the directory containing JSONL files.
        max_workers (int, optional): Maximum number of threads to use. Defaults to 8.

    Returns:
        dict: Mapping goldkey ("<s3_path>-<pagenum>") -> gold standard page text.
    """
    if not os.path.exists(CACHE_DIR):
        os.makedirs(CACHE_DIR)

    gold_data: Dict[str, str] = {}
    total_errors = 0
    total_overruns = 0

    gold_jsonl_files: List[str] = list_jsonl_files(gold_data_path)

    def process_file(path: str) -> tuple:
        """Process a single JSONL file and return its data and error counts."""
        file_gold_data = {}
        file_errors = 0
        file_overruns = 0

        with smart_open(path, "r") as f:
            for line in f:
                data = json.loads(line)
                data = normalize_json_entry(data)

                if data.error is not None:
                    file_errors += 1
                elif data.finish_reason != "stop":
                    # Generation was cut off (length overrun etc.); text is unreliable.
                    file_overruns += 1
                else:
                    file_gold_data[data.goldkey] = data.text

        return file_gold_data, file_errors, file_overruns

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Map each future back to its source path so failures can be attributed.
        futures = {executor.submit(process_file, path): path for path in gold_jsonl_files}

        # Gather results as they complete
        for future in as_completed(futures):
            try:
                file_gold_data, file_errors, file_overruns = future.result()
                gold_data.update(file_gold_data)
                total_errors += file_errors
                total_overruns += file_overruns
            except Exception as e:
                print(f"Error processing {futures[future]}: {e}")

    print(f"Loaded {len(gold_data):,} gold data entries for comparison")
    print(f"Gold processing errors: {total_errors}")
    print(f"Gold overrun errors: {total_overruns}")
    print("-----------------------------------------------------------")

    return gold_data


# Helper function to list all .jsonl files from a directory or an S3 bucket
def list_jsonl_files(path: str) -> list:
    """Return all JSON/JSONL file paths under *path* (local directory or s3:// prefix).

    Both the ".zst" and ".zstd" zstandard suffixes are accepted, matching the
    compressors registered with smart_open at module import; previously the
    ".zst" variants were silently skipped.
    """
    valid_endings = [".json", ".jsonl", ".json.zst", ".jsonl.zst", ".json.zstd", ".jsonl.zstd"]
    jsonl_files = []

    if path.startswith("s3://"):
        bucket_name, prefix = path.replace("s3://", "").split("/", 1)
        paginator = s3_client.get_paginator("list_objects_v2")
        pages = paginator.paginate(Bucket=bucket_name, Prefix=prefix)

        for page in pages:
            for obj in page.get("Contents", []):
                if any(obj["Key"].endswith(ending) for ending in valid_endings):
                    jsonl_files.append(f"s3://{bucket_name}/{obj['Key']}")

    else:
        # If it's a local directory, list all matching files recursively
        for root, _, files in os.walk(path):
            for file in files:
                if any(file.endswith(ending) for ending in valid_endings):
                    jsonl_files.append(os.path.join(root, file))

    return jsonl_files


# Takes in a path to a local directory or s3://[bucket]/[prefix path] where your jsonl files are stored
# This is most likely the output location of the refiner
# Expecting each jsonl line to include {s3_path: [path to original pdf], page: [pagenum], text: [proper page text]}
# Returns the average Levenshtein distance match between the data
def process_jsonl_file(jsonl_file, gold_data, comparer):
    """Score every page in one eval JSONL file against the gold data.

    Returns a 7-tuple: (sum of alignments, char-weighted alignment sum,
    total gold chars, pages scored, error count, overrun count, per-page data).
    """
    page_data = {}
    alignment_sum: float = 0.0
    char_weighted_sum: float = 0.0
    pages_scored = 0
    chars_scored = 0
    error_count = 0
    overrun_count = 0

    with smart_open(jsonl_file, "r") as f:
        for line in f:
            entry = normalize_json_entry(json.loads(line))

            # Only pages that have a gold transcription can be scored.
            if entry.goldkey not in gold_data:
                continue

            gold_text = gold_data[entry.goldkey] or ""
            eval_text = entry.text or ""

            if entry.error is not None:
                error_count += 1
                eval_text = f"[Error processing this page: {entry.error}]"
            elif entry.finish_reason != "stop":
                overrun_count += 1
                eval_text += f"\n[Error processing this page: overrun {entry.finish_reason}]"

            # Two effectively-empty pages count as a perfect match.
            if len(gold_text.strip()) < 3 and len(eval_text.strip()) < 3:
                alignment = 1.0
            else:
                alignment = comparer.compute(gold_text, eval_text)

            page_data[entry.goldkey] = {
                "s3_path": entry.s3_path,
                "page": entry.pagenum,
                "gold_text": gold_text,
                "eval_text": eval_text,
                "alignment": alignment,
            }

            alignment_sum += alignment
            char_weighted_sum += alignment * len(gold_text)
            chars_scored += len(gold_text)
            pages_scored += 1

    return alignment_sum, char_weighted_sum, chars_scored, pages_scored, error_count, overrun_count, page_data


def do_eval(gold_data_path: str, eval_data_path: str, review_page_name: str, review_page_size: int) -> tuple[float, list[dict]]:
    """Compare eval JSONL outputs against the gold data and build HTML review pages.

    Args:
        gold_data_path: Local path or s3:// prefix holding the gold JSONL files.
        eval_data_path: Local path or s3:// prefix holding the eval JSONL files.
        review_page_name: Base name for the "_worst.html"/"_sample.html" pages.
        review_page_size: Number of entries to include on each review page.

    Returns:
        (mean char-weighted alignment score, randomly sampled page eval entries)

    Raises:
        ValueError: If no .jsonl files are found under eval_data_path.
    """
    gold_data = load_gold_data(gold_data_path)

    total_alignment_score = 0
    total_char_alignment_score = 0
    total_weight = 0
    total_pages = 0
    total_errors = 0
    total_overruns = 0
    total_pages_compared = set()

    page_eval_data = []

    segmenter = SpacySegmenter("spacy")
    aligner = HirschbergAligner(match_score=1, mismatch_score=-1, indel_score=-1)
    comparer = DocumentEditSimilarity(segmenter=segmenter, aligner=aligner)

    # List all .jsonl files in the directory or S3 bucket
    jsonl_files = list_jsonl_files(eval_data_path)

    if not jsonl_files:
        raise ValueError("No .jsonl files found in the specified path.")

    print(f"Found {len(jsonl_files):,} files to evaluate")

    # Threads while debugging so breakpoints still fire; processes otherwise for parallelism.
    with ProcessPoolExecutor() if not is_debugging() else ThreadPoolExecutor() as executor:
        # Prepare the future tasks
        futures = [executor.submit(process_jsonl_file, jsonl_file, gold_data, comparer) for jsonl_file in jsonl_files]

        # Process each future as it completes
        for future in tqdm(as_completed(futures), total=len(jsonl_files)):
            alignment_score, char_weighted_score, chars, pages, errors, overruns, page_data = future.result()

            # Aggregate statistics
            total_alignment_score += alignment_score
            total_char_alignment_score += char_weighted_score
            total_weight += chars
            total_pages += pages
            total_errors += errors
            total_overruns += overruns
            total_pages_compared |= page_data.keys()

            page_eval_data.extend(page_data.values())

    print(f"Compared {len(total_pages_compared):,} pages")
    print(f"Found {total_errors} errors in the eval set, and {total_overruns} cases of length overruns")
    # max(..., 1) guards against ZeroDivisionError when no eval pages matched the gold set.
    print(f"Mean page-weighted alignment: {total_alignment_score / max(total_pages, 1):.3f}")
    print(f"Mean char-weighted alignment: {total_char_alignment_score / max(total_weight, 1):.3f}")
    print("")
    print("...creating review page")

    # Select the lowest-alignment pages for the "worst" review page
    page_eval_data.sort(key=lambda x: x["alignment"])
    create_review_html(page_eval_data[:review_page_size], filename=review_page_name + "_worst.html")

    # Select random entries for the sample page; cap the sample size so random.sample
    # does not raise ValueError when fewer pages than review_page_size were scored.
    page_eval_data = random.sample(page_eval_data, min(review_page_size, len(page_eval_data)))
    create_review_html(page_eval_data, filename=review_page_name + "_sample.html")

    # FIX: previously returned total_alignment_score / total_weight, dividing the
    # page-summed score by the char weight; return the char-weighted mean, matching
    # the "Mean char-weighted alignment" printed above.
    return total_char_alignment_score / max(total_weight, 1), page_eval_data


if __name__ == "__main__":
    # Command-line entry point: score eval output against gold data and emit review pages.
    arg_parser = argparse.ArgumentParser(description="Transform JSONL files by extracting and renaming specific fields.")
    arg_parser.add_argument("--name", default="review_page", help="What name to give to this evaluation/comparison")
    arg_parser.add_argument(
        "--review_size",
        type=int,
        default=20,
        help="Number of entries to show on the generated review page",
    )
    # Positional arguments: gold data first, then the eval data to score against it.
    arg_parser.add_argument(
        "gold_data_path",
        type=str,
        help='Path to the gold data directory containing JSONL files. Can be a local path or S3 URL. Can be openai "done" data, or birr "done" data',
    )
    arg_parser.add_argument(
        "eval_data_path",
        type=str,
        help='Path to the eval data directory containing JSONL files. Can be a local path or S3 URL. Can be openai "done" data, or birr "done" data',
    )

    cli_args = arg_parser.parse_args()

    result = do_eval(
        gold_data_path=cli_args.gold_data_path,
        eval_data_path=cli_args.eval_data_path,
        review_page_name=cli_args.name,
        review_page_size=cli_args.review_size,
    )