# Script to generate parquet dataset files to upload to Hugging Face
# Input is a dataset location /data/jakep/pdfdata/openai_batch_data_v5_1_iabooks_train_done/*.json
# Each json line has a custom id that looks like {"custom_id": "s3://ai2-s2-pdfs/de80/a57e6c57b45796d2e020173227f7eae44232.pdf-1", ... more data}

# Fix this script so that it works, takes a path to an input dataset and a sqlite database location,
# and builds a parquet file with rows of the form: "id", "url", "page_number", "response",
# where "id" is the output of parse_pdf_hash plus "-" plus the page number,
# "url" is the result of get_uri_from_db,
# and "response" is NormalizedEntry.text.
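#
# Example invocation (the sqlite filename below is illustrative; substitute your own paths):
#   python convertjsontoparquet.py \
#       '/data/jakep/pdfdata/openai_batch_data_v5_1_iabooks_train_done/*.json' \
#       pdf_mapping.sqlite --output output.parquet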
import argparse
import concurrent.futures
import glob
import json
import multiprocessing
import os
import re
import sqlite3
from dataclasses import dataclass
from typing import Dict, List, Optional, Set, Tuple
from urllib.parse import urlparse

import boto3
import pandas as pd
from pypdf import PdfReader, PdfWriter
from tqdm import tqdm


def parse_pdf_hash(pretty_pdf_path: str) -> Optional[str]:
    """
    Extracts a hash from a pretty PDF S3 URL.
    For example, given:
      s3://ai2-s2-pdfs/de80/a57e6c57b45796d2e020173227f7eae44232.pdf-1
    it will return "de80a57e6c57b45796d2e020173227f7eae44232".
    """
    # Allow an optional "-<number>" at the end.
    if pretty_pdf_path.startswith("s3://ai2-s2-pdfs/"):
        pattern = r"s3://ai2-s2-pdfs/([a-f0-9]{4})/([a-f0-9]+)\.pdf(?:-\d+)?$"
        match = re.match(pattern, pretty_pdf_path)
        if match:
            return match.group(1) + match.group(2)
        return None
    elif pretty_pdf_path.startswith("s3://ai2-oe-data/reganh/iabooks/"):
        return urlparse(pretty_pdf_path).path.split("/")[-1]
    else:
        raise NotImplementedError()


def get_uri_from_db(db_path: str, pdf_hash: str) -> Optional[str]:
    """
    Looks up the URL for the given pdf_hash in the sqlite database.
    Assumes there is a table called 'pdf_mapping' with a column 'uri'.
    """
    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()
    cursor.execute("SELECT uri FROM pdf_mapping WHERE pdf_hash = ?", (pdf_hash,))
    result = cursor.fetchone()
    conn.close()
    return result[0].strip() if result and result[0] else None
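
# The query above assumes a pre-existing mapping table along these lines (this script never
# creates it; only the pdf_hash and uri columns are actually required):
#   CREATE TABLE pdf_mapping (pdf_hash TEXT, uri TEXT);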


@dataclass(frozen=True)
class NormalizedEntry:
    s3_path: str
    pagenum: int
    text: Optional[str]
    finish_reason: Optional[str]
    error: Optional[str] = None

    @staticmethod
    def from_goldkey(goldkey: str, **kwargs):
        """
        Constructs a NormalizedEntry from a goldkey string.
        The goldkey is expected to be of the format:
          <s3_path>-<page_number>
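        For example, "s3://ai2-s2-pdfs/de80/a57e6c57b45796d2e020173227f7eae44232.pdf-1"
        splits into s3_path "s3://ai2-s2-pdfs/de80/a57e6c57b45796d2e020173227f7eae44232.pdf" and pagenum 1.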
        """
        s3_path = goldkey[: goldkey.rindex("-")]
        page_num = int(goldkey[goldkey.rindex("-") + 1 :])
        return NormalizedEntry(s3_path, page_num, **kwargs)

    @property
    def goldkey(self):
        return f"{self.s3_path}-{self.pagenum}"


def normalize_json_entry(data: dict) -> NormalizedEntry:
    """
    Normalizes a JSON entry from any of the supported formats.
    It supports:
      - Birr: looks for an "outputs" field.
      - Already normalized entries: if they contain s3_path, pagenum, etc.
      - OpenAI: where the response is in data["response"]["body"]["choices"].
      - SGLang: where the response is in data["response"]["choices"].
    """
    if "outputs" in data:
        # Birr case
        if data["outputs"] is None:
            text = None
            finish_reason = None
        else:
            text = data["outputs"][0]["text"]
            finish_reason = data["outputs"][0]["finish_reason"]

        return NormalizedEntry.from_goldkey(
            goldkey=data["custom_id"],
            text=text,
            finish_reason=finish_reason,
            error=data.get("completion_error", None),
        )
    elif all(field in data for field in ["s3_path", "pagenum", "text", "error", "finish_reason"]):
        # Already normalized
        return NormalizedEntry(**data)
    elif "response" in data and "body" in data["response"] and "choices" in data["response"]["body"]:
        return NormalizedEntry.from_goldkey(
            goldkey=data["custom_id"],
            text=data["response"]["body"]["choices"][0]["message"]["content"],
            finish_reason=data["response"]["body"]["choices"][0]["finish_reason"],
        )
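    elif "response" in data and "choices" in data["response"]:
        # SGLang case (format described in the docstring above); assumes the same
        # choices[0].message.content / finish_reason layout as the OpenAI response body.
        return NormalizedEntry.from_goldkey(
            goldkey=data["custom_id"],
            text=data["response"]["choices"][0]["message"]["content"],
            finish_reason=data["response"]["choices"][0]["finish_reason"],
        )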
    else:
        raise ValueError("Unsupported JSON format")
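
# A minimal OpenAI-batch-format line handled above looks roughly like this (fields abbreviated):
#   {"custom_id": "s3://ai2-s2-pdfs/de80/....pdf-1",
#    "response": {"body": {"choices": [{"message": {"content": "..."}, "finish_reason": "stop"}]}}}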


def parse_s3_url(s3_url: str) -> Tuple[str, str]:
    """
    Parses an S3 URL of the form s3://bucket/key and returns (bucket, key).
    """
    if not s3_url.startswith("s3://"):
        raise ValueError(f"Invalid S3 URL: {s3_url}")
    s3_path = s3_url[5:]
    bucket, key = s3_path.split("/", 1)
    return bucket, key


def download_pdf_to_cache(s3_url: str, cache_dir: str) -> Optional[str]:
    """
    Downloads the PDF from the given S3 URL into the specified cache directory.
    The destination filename is based on the parsed PDF hash.
    Returns the path to the downloaded PDF.
    """
    try:
        bucket, key = parse_s3_url(s3_url)
        s3_client = boto3.client("s3")
        pdf_hash = parse_pdf_hash(s3_url)
        if not pdf_hash:
            # Fallback: use a sanitized version of the s3_url
            pdf_hash = re.sub(r"\W+", "_", s3_url)
        dest_path = os.path.join(cache_dir, f"{pdf_hash}.pdf")
        # Avoid re-downloading if already exists
        if not os.path.exists(dest_path):
            s3_client.download_file(bucket, key, dest_path)
        return dest_path
    except Exception as e:
        print(f"Error downloading {s3_url}: {e}")
        return None


def process_pdf_page(s3_url: str, page_number: int, combined_id: str, output_pdf_dir: str, pdf_cache: Dict[str, str]) -> Optional[str]:
    """
    Extracts the specified page (1-indexed) from the cached PDF corresponding to s3_url.
    Writes a new single-page PDF to the output_pdf_dir using the combined_id as the filename.
    Returns the relative path to the new PDF (e.g., "pdfs/<combined_id>.pdf").
    """
    try:
        local_cached_pdf = pdf_cache.get(s3_url)
        if not local_cached_pdf or not os.path.exists(local_cached_pdf):
            print(f"Cached PDF not found for {s3_url}")
            return None
        reader = PdfReader(local_cached_pdf)
        # pypdf uses 0-indexed page numbers
        page_index = page_number - 1
        if page_index < 0 or page_index >= len(reader.pages):
            print(f"Page number {page_number} out of range for PDF {s3_url}")
            return None
        writer = PdfWriter()
        writer.add_page(reader.pages[page_index])
        output_filename = f"{combined_id}.pdf"
        output_path = os.path.join(output_pdf_dir, output_filename)
        with open(output_path, "wb") as f_out:
            writer.write(f_out)
        # Return the relative path (assuming pdfs/ folder is relative to the parquet file location)
        return os.path.join("pdfs", output_filename)
    except Exception as e:
        print(f"Error processing PDF page for {s3_url} page {page_number}: {e}")
        return None


def process_file(file_path: str, db_path: str, output_pdf_dir: str, pdf_cache: Dict[str, str]) -> Tuple[List[dict], int]:
    """
    Process a single file and return a tuple:
      (list of valid rows, number of rows skipped due to missing URL or PDF extraction/filtering).
    For each JSON entry, the function:
      - Normalizes the JSON.
      - Skips entries whose response contains the word "resume" (any case) along with either an email address or a phone number.
      - Extracts the PDF hash and builds the combined id.
      - Looks up the corresponding URL from the sqlite database.
      - Extracts the specified page from the cached PDF and writes it to output_pdf_dir.
      - Outputs a row with "id", "url", "page_number", "response".
    """
    rows = []
    missing_count = 0
    email_regex = r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b"
    phone_regex = r"\b(?:\+?\d{1,3}[-.\s]?)?(?:\(?\d{3}\)?[-.\s]?)?\d{3}[-.\s]?\d{4}\b"
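    # email_regex matches standard user@domain.tld addresses; phone_regex is intended to match
    # common US-style numbers such as 555-123-4567, (555) 123-4567, or +1 555.123.4567.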

    try:
        with open(file_path, "r", encoding="utf-8") as f:
            for line_num, line in enumerate(f, start=1):
                line = line.strip()
                if not line:
                    continue
                try:
                    data = json.loads(line)
                except json.JSONDecodeError as e:
                    print(f"Skipping invalid JSON at {file_path}:{line_num} - {e}")
                    continue

                try:
                    normalized = normalize_json_entry(data)
                except Exception as e:
                    print(f"Error normalizing entry at {file_path}:{line_num} - {e}")
                    continue

                # Apply filter: skip if response contains "resume" (any case) and an email or phone number.
                response_text = normalized.text if normalized.text else ""
                if re.search(r"resume", response_text, re.IGNORECASE) and (re.search(email_regex, response_text) or re.search(phone_regex, response_text)):
                    print(f"Skipping entry due to resume and contact info in response at {file_path}:{line_num}")
                    continue

                # Extract the PDF hash from the s3_path.
                pdf_hash = parse_pdf_hash(normalized.s3_path)
                if pdf_hash is None:
                    print(f"Could not parse pdf hash from {normalized.s3_path} at {file_path}:{line_num}")
                    continue

                # The output id is the pdf hash plus '-' plus the page number.
                combined_id = f"{pdf_hash}-{normalized.pagenum}"

                # Look up the corresponding URL from the sqlite database.
                url = get_uri_from_db(db_path, pdf_hash)
                if not url:
                    print(f"Missing URL for pdf hash {pdf_hash} at {file_path}:{line_num}")
                    missing_count += 1
                    continue

                # Process PDF: extract the specified page from the cached PDF.
                local_pdf_path = process_pdf_page(normalized.s3_path, normalized.pagenum, combined_id, output_pdf_dir, pdf_cache)
                if local_pdf_path is None:
                    print(f"Skipping entry because PDF processing failed for {normalized.s3_path} page {normalized.pagenum} at {file_path}:{line_num}")
                    missing_count += 1
                    continue

                row = {
                    "id": combined_id,
                    "url": url,
                    "page_number": normalized.pagenum,
                    "response": normalized.text,
                }
                rows.append(row)
    except Exception as e:
        print(f"Error processing file {file_path}: {e}")
    return rows, missing_count


def scan_file_for_s3_urls(file_path: str) -> Set[str]:
    """
    Scans a single file and returns a set of unique S3 URLs found in the JSON entries.
    """
    urls = set()
    try:
        with open(file_path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                try:
                    data = json.loads(line)
                    normalized = normalize_json_entry(data)
                    urls.add(normalized.s3_path)
                except Exception:
                    # Skip entries that cannot be normalized
                    continue
    except Exception as e:
        print(f"Error reading file {file_path}: {e}")
    return urls


def main():
    parser = argparse.ArgumentParser(description="Generate a Parquet dataset file for HuggingFace upload.")
    parser.add_argument(
        "input_dataset",
        help="Input dataset file pattern (e.g., '/data/jakep/pdfdata/openai_batch_data_v5_1_iabooks_train_done/*.json')",
    )
    parser.add_argument("db_path", help="Path to the SQLite database file.")
    parser.add_argument("--output", default="output.parquet", help="Output Parquet file path.")

    args = parser.parse_args()

    files = glob.glob(args.input_dataset)
    print(f"Found {len(files)} files matching pattern: {args.input_dataset}")

    # Determine output directory and create 'pdfs' subfolder.
    output_abs_path = os.path.abspath(args.output)
    output_dir = os.path.dirname(output_abs_path)
    pdfs_dir = os.path.join(output_dir, "pdfs")
    os.makedirs(pdfs_dir, exist_ok=True)

    # Create a temporary directory for caching PDFs.
    pdf_cache_dir = "/tmp/pdf_cache"
    os.makedirs(pdf_cache_dir, exist_ok=True)

    print(f"Caching PDFs to temporary directory: {pdf_cache_dir}")

    # ---------------------------------------------------------------------
    # Step 1: Scan input files to collect all unique S3 URLs using a ProcessPoolExecutor.
    unique_s3_urls: Set[str] = set()
    print("Scanning input files to collect unique PDF URLs...")
    num_cpus = multiprocessing.cpu_count()
    with concurrent.futures.ProcessPoolExecutor(max_workers=num_cpus * 4) as executor:
        results = list(tqdm(executor.map(scan_file_for_s3_urls, files), total=len(files), desc="Scanning files"))
    for url_set in results:
        unique_s3_urls |= url_set

    print(f"Found {len(unique_s3_urls)} unique PDF URLs.")

    # ---------------------------------------------------------------------
    # Step 2: Download all unique PDFs to the cache directory.
    pdf_cache: Dict[str, str] = {}
    print("Caching PDFs from S3...")
    with concurrent.futures.ProcessPoolExecutor(max_workers=num_cpus * 8) as executor:
        future_to_url = {executor.submit(download_pdf_to_cache, s3_url, pdf_cache_dir): s3_url for s3_url in unique_s3_urls}
        for future in tqdm(concurrent.futures.as_completed(future_to_url), total=len(future_to_url), desc="Downloading PDFs"):
            s3_url = future_to_url[future]
            try:
                local_path = future.result()
                if local_path:
                    pdf_cache[s3_url] = local_path
                else:
                    print(f"Failed to cache PDF for {s3_url}")
            except Exception as e:
                print(f"Error caching PDF for {s3_url}: {e}")

    # ---------------------------------------------------------------------
    # Step 3: Process input files using the precached PDFs.
    all_rows = []
    total_missing = 0
    print("Processing files...")
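    # Note: pdf_cache is passed to each worker by pickling, so every process_file call
    # receives its own read-only copy of the mapping; mutations would not propagate back.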
    with concurrent.futures.ProcessPoolExecutor() as executor:
        futures = {executor.submit(process_file, file_path, args.db_path, pdfs_dir, pdf_cache): file_path for file_path in files}
        for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures), desc="Processing files"):
            file_path = futures[future]
            try:
                rows, missing_count = future.result()
                all_rows.extend(rows)
                total_missing += missing_count
            except Exception as e:
                print(f"Error processing file {file_path}: {e}")

    if all_rows:
        df = pd.DataFrame(all_rows)
        # Set the "id" column as the index.
        df.set_index("id", inplace=True)
        df.to_parquet(args.output)

        valid_count = len(df)
        total_processed = valid_count + total_missing
        print(f"Successfully wrote {valid_count} rows to {args.output}")
        print(f"Rows skipped due to missing URL/PDF or filtering: {total_missing} out of {total_processed} processed rows")
    else:
        print("No valid rows to write. Exiting.")


if __name__ == "__main__":
    main()