"src/vscode:/vscode.git/clone" did not exist on "f7a1de58baffdaca898e189b71ce20eebd1cf225"
dataloader.py 6.2 KB
Newer Older
wanglch's avatar
wanglch committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
import glob
import logging
import os
import re
from typing import Optional

import boto3
from datasets import Dataset, DatasetDict, load_dataset
from filelock import FileLock

from olmocr.data.renderpdf import get_pdf_media_box_width_height
from olmocr.prompts.anchor import get_anchor_text
from olmocr.s3_utils import parse_custom_id, parse_s3_path

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Quiet logs from pypdf and smart open
logging.getLogger("pypdf").setLevel(logging.ERROR)
logging.getLogger("smart_open").setLevel(logging.ERROR)


def list_dataset_files(s3_glob_path: str):
    """
    Lists files in the specified S3 path that match the glob pattern.
    """
    if s3_glob_path.startswith("s3://"):
        s3 = boto3.client("s3")
        match = re.match(r"s3://([^/]+)/(.+)", s3_glob_path)
        if not match:
            logger.error(f"Invalid S3 path: {s3_glob_path}")
            raise ValueError(f"Invalid S3 path: {s3_glob_path}")

        bucket, prefix_pattern = match.groups()
        prefix = prefix_pattern.split("*")[0]  # Extract prefix before the wildcard
        paginator = s3.get_paginator("list_objects_v2")
        pages = paginator.paginate(Bucket=bucket, Prefix=prefix)

        files = []
        # Convert the glob pattern to a regex, escaping regex metacharacters in the literal parts
        pattern = re.compile(".*".join(re.escape(part) for part in prefix_pattern.split("*")))
        for page in pages:
            for obj in page.get("Contents", []):
                key = obj["Key"]
                if pattern.fullmatch(key):
                    files.append(f"s3://{bucket}/{key}")
        return files
    else:
        return glob.glob(s3_glob_path)
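
# A minimal usage sketch; the bucket and directory below are hypothetical and
# only illustrate the two supported path styles:
#
#   list_dataset_files("s3://my-bucket/batch_output/*.jsonl")  # S3 glob
#   list_dataset_files("/data/batch_output/*.jsonl")           # local filesystem glob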


def load_jsonl_into_ds(s3_glob_path: str, first_n_files: Optional[int] = None) -> DatasetDict:
    """
    Loads JSONL files matching the given S3 (or local) glob path into a Hugging Face
    DatasetDict, keyed by the default "train" split.
    """
    all_json_files = list_dataset_files(s3_glob_path)

    if first_n_files:
        all_json_files = all_json_files[:first_n_files]

    # Use the datasets library to load the JSONL files (S3 URIs or local paths)
    dataset = load_dataset(
        "json",
        data_files=all_json_files,
    )

    return dataset
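
# Usage sketch (the path is hypothetical): load_dataset returns a DatasetDict,
# so callers index into the default "train" split:
#
#   ds = load_jsonl_into_ds("s3://my-bucket/batch_output/*.jsonl")["train"]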


def extract_openai_batch_response(example):
    custom_id = example.get("custom_id", None)

    # Parse the custom id into an s3 document path and page number (1-indexed)
    s3_path, page_num = parse_custom_id(custom_id)

    response_body = example.get("response", {}).get("body", {})
    choices = response_body.get("choices", [])
    response = ""
    finish_reason = ""
    if choices:
        first_choice = choices[0]
        message = first_choice.get("message", {})
        response = message.get("content", "")
        finish_reason = first_choice.get("finish_reason", "")

    # TODO: In the future we may parse the response (which is itself a structured JSON document)
    # into its own columns

    return {"s3_path": s3_path, "page_num": page_num, "response": response, "finish_reason": finish_reason}


def _cache_s3_file(s3_path: str, local_cache_dir: str):
    """
    Downloads an S3 object to a local cache directory, ensuring no two writers corrupt the same file.
    """
    bucket, key = parse_s3_path(s3_path)

    # Define the local file path
    local_file_path = os.path.join(local_cache_dir, bucket + "__" + key.replace("/", "_"))
    os.makedirs(os.path.dirname(local_file_path), exist_ok=True)
    lock_file = f"{local_file_path}.lock"

    # Use a file lock to prevent concurrent writes
    with FileLock(lock_file):
        if not os.path.exists(local_file_path):
            logger.info(f"Downloading {s3_path} to {local_file_path}")
            # Credentials come from the DS_AWS_* environment variables (boto3 falls back to its default chain if unset)
            s3_client = boto3.client("s3", aws_access_key_id=os.getenv("DS_AWS_ACCESS_KEY_ID"), aws_secret_access_key=os.getenv("DS_AWS_SECRET_ACCESS_KEY"))
            s3_client.download_file(bucket, key, local_file_path)
        else:
            logger.debug(f"File {local_file_path} already exists, skipping download.")

    return local_file_path
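
# For example (hypothetical object), the flattening scheme above maps
# "s3://my-bucket/docs/report.pdf" to "<local_cache_dir>/my-bucket__docs_report.pdf",
# with a sibling ".lock" file guarding the download against concurrent writers.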


def cache_s3_files(dataset: Dataset, pdf_cache_location: str, num_proc: int = 32) -> Dataset:
    """
    Caches all S3 paths in the dataset to the local cache directory.
    """

    # Define the download function to use in parallel processing
    def cache_file(example):
        s3_path = example["s3_path"]
        if s3_path:
            # Download the file and cache it locally
            local_path = _cache_s3_file(s3_path, pdf_cache_location)
            return {"local_pdf_path": local_path}
        return {"local_pdf_path": None}

    # Map the caching function to the dataset (with parallelism if needed)
    dataset = dataset.map(cache_file, num_proc=num_proc, load_from_cache_file=False)

    return dataset


def build_finetuning_dataset(response_glob_path: str, pdf_cache_location: Optional[str] = None, num_proc: int = 32) -> Dataset:
    if pdf_cache_location is None:
        pdf_cache_location = os.path.join(os.path.expanduser("~"), ".cache", "olmocr_pdfs")

    logger.info("Loading fine tuning dataset from OpenAI style batch responses")
    response_data = load_jsonl_into_ds(response_glob_path)
    response_data = response_data["train"]

    response_data = response_data.map(extract_openai_batch_response, remove_columns=response_data.column_names, num_proc=num_proc)

    # Drop rows where the model stopped early, e.g. due to a length or moderation issue
    logger.info("Filtering on finish_reason == stop")
    final_dataset = response_data.filter(lambda x: x["finish_reason"] == "stop", num_proc=num_proc)

    # Cache every s3_path that was accessed to a local storage location
    final_dataset = cache_s3_files(final_dataset, pdf_cache_location, num_proc)

    # Filter out pages for which anchor text cannot be generated, to prevent errors during training
    def _can_create_anchor_text(example):
        try:
            anchor_text = get_anchor_text(example["local_pdf_path"], example["page_num"], pdf_engine="pdfreport", target_length=4000)
            _ = get_pdf_media_box_width_height(example["local_pdf_path"], example["page_num"])
            return anchor_text is not None
        except Exception:
            logger.exception("Could not generate anchor text for file; make sure all dependencies are installed")
            return False

    final_dataset = final_dataset.filter(_can_create_anchor_text, num_proc=num_proc)

    return final_dataset
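

if __name__ == "__main__":
    # Minimal end-to-end sketch; the glob path is hypothetical, and the PDF
    # cache defaults to ~/.cache/olmocr_pdfs when pdf_cache_location is None.
    dataset = build_finetuning_dataset("s3://my-bucket/batch_output/*.jsonl")
    print(dataset)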