Commit 89e60e48 authored by wanglch

Initial commit

model:
name_or_path: Qwen/Qwen2-VL-7B-Instruct
arch: causal
use_flash_attn: true
wandb:
project: pdelfin
entity: ai2-llm
generate:
max_length: 8192
train_data:
seed: 1337
cache_location: /data/jakep/pdfdata/pdelfin_cache
sources:
- name: openai_batch_data_v5_1_train
response_glob_path: /data/jakep/pdfdata/openai_batch_data_v5_1_train_done/*.json
target_longest_image_dim: [1024]
target_anchor_text_len: [6000]
- name: openai_batch_data_v5_1_iabooks_train
response_glob_path: /data/jakep/pdfdata/openai_batch_data_v5_1_iabooks_train_done/*.json
target_longest_image_dim: [1024]
target_anchor_text_len: [6000]
valid_data:
cache_location: /data/jakep/pdfdata/pdelfin_cache
metric_for_best_model: openai_batch_data_v5_1_eval_loss
sources:
# These tend to be small, so it's fine to load them straight from S3
- name: openai_batch_data_v5_1_eval
response_glob_path: s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_eval/*.json
target_longest_image_dim: [1024]
target_anchor_text_len: [6000]
- name: openai_batch_data_v5_1_iabooks_eval
response_glob_path: s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_iabooks_eval/*.json
target_longest_image_dim: [1024]
target_anchor_text_len: [6000]
# Mostly pulled from https://github.com/QwenLM/Qwen2/blob/main/examples/sft/finetune.sh
hparams:
batch_size: 1
eval_batch_size: 1
gradient_accumulation_steps: 4
gradient_checkpointing: true
clip_grad_norm: 1.0
learning_rate: 1e-4
max_steps: 10000
pad_multiple_of: 16
log_every_steps: 10
eval_every_steps: 100
optim: adamw_torch
lr_scheduler: cosine
weight_decay: 0.01
warmup_ratio: 0.03
# From https://github.com/QwenLM/Qwen2/blob/main/examples/sft/finetune.py
lora:
rank: 32
alpha: 32
dropout: 0.05
task_type: causal_lm
target_modules:
- q_proj
- k_proj
- v_proj
- o_proj
- gate_proj
- up_proj
- down_proj
- visual.blocks.[0-9]+.attn.qkv
- visual.blocks.[0-9]+.attn.proj
- visual.blocks.[0-9]+.mlp.fc1
- visual.blocks.[0-9]+.mlp.fc2
- visual.merger.mlp.0
- visual.merger.mlp.2
save:
path: s3://ai2-oe-data/jakep/experiments/qwen2vl-pdf/v1/models/
save_every_steps: 1000
max_workers: 10
model:
name_or_path: Qwen/Qwen2-VL-7B-Instruct
arch: causal
use_flash_attn: true
wandb:
project: pdelfin
entity: ai2-llm
generate:
max_length: 8192
train_data:
seed: 1337
cache_location: /data/jakep/pdfdata/pdelfin_cache
sources:
- name: openai_batch_data_v5_1_train
response_glob_path: /data/jakep/pdfdata/openai_batch_data_v5_1_train_done/*.json
target_longest_image_dim: [1024]
target_anchor_text_len: [6000]
- name: openai_batch_data_v5_1_iabooks_train
response_glob_path: /data/jakep/pdfdata/openai_batch_data_v5_1_iabooks_train_done/*.json
target_longest_image_dim: [1024]
target_anchor_text_len: [6000]
valid_data:
cache_location: /data/jakep/pdfdata/pdelfin_cache
metric_for_best_model: openai_batch_data_v5_1_eval_loss
sources:
# These tend to be small, so it's fine to load them straight from S3
- name: openai_batch_data_v5_1_eval
response_glob_path: s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_eval/*.json
target_longest_image_dim: [1024]
target_anchor_text_len: [6000]
- name: openai_batch_data_v5_1_iabooks_eval
response_glob_path: s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_iabooks_eval/*.json
target_longest_image_dim: [1024]
target_anchor_text_len: [6000]
# Mostly pulled from https://github.com/QwenLM/Qwen2/blob/main/examples/sft/finetune.sh
hparams:
batch_size: 1
eval_batch_size: 1
gradient_accumulation_steps: 4
gradient_checkpointing: true
clip_grad_norm: 1.0
learning_rate: 1e-6
max_steps: 10000
pad_multiple_of: 16
log_every_steps: 10
eval_every_steps: 100
optim: adamw_torch
lr_scheduler: cosine
weight_decay: 0.01
warmup_ratio: 0.03
save:
path: s3://ai2-oe-data/jakep/experiments/qwen2vl-pdf/v1/models/
save_every_steps: 9500
max_workers: 10
import json
from logging import Logger
from typing import Optional, Type
import smart_open
import torch
from peft.peft_model import PeftModel
from transformers import (
AutoModelForCausalLM,
AutoModelForSeq2SeqLM,
AutoModelWithLMHead,
AutoTokenizer,
)
from .config import ModelConfig
from .loggers import get_logger
from .paths import cached_path, exists, get_cache_dir, join_path, resource_to_filename
__all__ = ["load_model", "cache_merged_model"]
def get_model_cls(config: ModelConfig) -> Type[AutoModelWithLMHead]:
if config.arch == "seq2seq":
return AutoModelForSeq2SeqLM # pyright: ignore
elif config.arch == "causal" or config.arch == "vllm":
return AutoModelForCausalLM # pyright: ignore
else:
raise ValueError(f"Unsupported model architecture: {config.arch}")
def get_adapter_config(config: ModelConfig) -> dict:
local_path = cached_path(config.name_or_path)
if exists(adapter_config_path := join_path("", local_path, "adapter_config.json")):
with smart_open.open(adapter_config_path, "rt", encoding="utf-8") as f:
return json.load(f)
return {}
def load_model(config: ModelConfig, logger: Optional[Logger] = None) -> AutoModelWithLMHead:
logger = logger or get_logger(__file__, level="INFO")
logger.info(f"Loading model from {config.name_or_path}")
local_path = cached_path(config.name_or_path)
if local_path != config.name_or_path:
logger.info(f"Model cached at {local_path}")
if exists(adapter_config_path := join_path("", local_path, "adapter_config.json")):
logger.info(f"Loading LoRA adapter from {adapter_config_path}")
with smart_open.open(adapter_config_path) as f:
adapter_config = json.load(f)
base_model_name_or_path = adapter_config["base_model_name_or_path"]
enable_lora = True
else:
base_model_name_or_path = local_path
enable_lora = False
model = get_model_cls(config).from_pretrained(
base_model_name_or_path,
device_map="auto",
trust_remote_code=config.trust_remote_code,
# low_cpu_mem_usage=config.low_cpu_mem_usage,
use_flash_attention_2=True if config.use_flash_attn else False,
revision=config.model_revision,
torch_dtype=torch.bfloat16 if config.use_flash_attn else getattr(torch, config.dtype),
)
logger.info(f"Successfully loaded base model from {base_model_name_or_path}")
if enable_lora:
peft_model = PeftModel.from_pretrained(model, local_path)
model = peft_model.merge_and_unload()
logger.info(f"Successfully loaded LoRA adapter from base model: {base_model_name_or_path}")
return model
def cache_merged_model(config: ModelConfig, logger: Optional[Logger] = None) -> str:
logger = logger or get_logger(__file__, level="INFO")
base_local_path = cached_path(config.name_or_path)
adapter_config = get_adapter_config(config)
if not adapter_config:
logger.info("No adapter config found; using base model")
return base_local_path
local_fn = resource_to_filename(json.dumps({"adapter": adapter_config, "model": config.name_or_path}))
merged_local_path = f"{get_cache_dir()}/{local_fn}"
if not exists(merged_local_path):
model = load_model(config=config, logger=logger)
tokenizer = AutoTokenizer.from_pretrained(base_local_path)
logger.info(f"Saving merged model to {merged_local_path}")
model.save_pretrained(merged_local_path)
tokenizer.save_pretrained(merged_local_path)
return merged_local_path
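# A minimal usage sketch (illustrative only): build a ModelConfig like the `model:`
# blocks in the YAML configs above and load it through the helpers in this module.
# load_model detects adapter_config.json automatically, so a LoRA checkpoint path
# works the same way as the base checkpoint used here.
def _example_load_qwen2vl() -> None:
    config = ModelConfig(name_or_path="Qwen/Qwen2-VL-7B-Instruct", arch="causal", use_flash_attn=True)
    model = load_model(config)
    merged_path = cache_merged_model(config)  # returns the base path when no adapter is present
    print(type(model).__name__, merged_path)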
"""
Utilities to work with an OmegaConf structured config object
From Dolma Toolkit: https://github.com/allenai/dolma/blob/64886d9db15bd99acea9e28740ae20a510875dfb/python/dolma/cli/__init__.py
Author: Luca Soldaini (@soldni)
""" # noqa: E501
from argparse import ArgumentParser, Namespace
from collections.abc import Iterable
from copy import deepcopy
from dataclasses import Field
from dataclasses import field as dataclass_field
from dataclasses import is_dataclass
from logging import warning
from typing import (
Any,
Dict,
Literal,
Optional,
Protocol,
Type,
TypeVar,
Union,
get_args,
get_origin,
)
import smart_open
from necessary import necessary
from omegaconf import MISSING, DictConfig, ListConfig
from omegaconf import OmegaConf as om
from omegaconf.errors import OmegaConfBaseException
from rich.console import Console
from rich.syntax import Syntax
from yaml import safe_load # type: ignore
from .errors import DolmaRefineError
__all__ = ["field", "namespace_to_nested_omegaconf", "print_config", "make_cli", "read_config", "to_native_types"]
T = TypeVar("T", bound=Any)
D = TypeVar("D", bound="DataClass")
A = TypeVar("A", bound="ArgumentParser")
def _field_nargs(default: Any) -> Union[Literal["?"], Literal["*"]]:
# return '*' if default is a non-string/bytes iterable, else '?'
if isinstance(default, (str, bytes)):
return "?"
if isinstance(default, Iterable):
return "*"
return "?"
def field(default: T = MISSING, help: Optional[str] = None, **extra: Any) -> T:
metadata = {"help": help, "type": type(default), "default": default, "nargs": _field_nargs(default), **extra}
return dataclass_field(default_factory=lambda: deepcopy(default), metadata=metadata)
class DataClass(Protocol):
__dataclass_fields__: Dict[str, Field]
def read_config(path: Union[None, str]) -> Dict[str, Any]:
"""Read a configuration file if it exists"""
if path is None:
return {}
try:
with smart_open.open(path, mode="rt") as f:
return dict(safe_load(f))
except FileNotFoundError as ex:
raise DolmaRefineError(f"Config file not found: {path}") from ex
except Exception as ex:
raise DolmaRefineError(f"Error while reading config file: {path}") from ex
def save_config(config: Union[dict, DictConfig, list, ListConfig, DataClass], path: str) -> None:
"""Save a configuration to a file"""
if isinstance(config, (list, dict)):
config = om.create(config)
elif is_dataclass(config):
config = om.structured(config)
with smart_open.open(path, mode="wt") as f:
f.write(om.to_yaml(config))
def _make_parser(parser: A, config: Type[DataClass], prefix: Optional[str] = None) -> A:
for field_name, dt_field in config.__dataclass_fields__.items():
# get type from annotations or metadata
typ_ = config.__annotations__.get(field_name, dt_field.metadata.get("type", MISSING))
if typ_ is MISSING:
warning(f"No type annotation for field {field_name} in {config.__name__}")
continue
# join prefix and field name
field_name = f"{prefix}.{field_name}" if prefix else field_name
# This section here is to handle Optional[T] types; we only care for cases where T is a dataclass
# So we first check if type is Union since Optional[T] is just a shorthand for Union[T, None]
# and that the union contains only one non-None type
if get_origin(typ_) == Union:
# get all non-None types
args = [a for a in get_args(typ_) if a is not type(None)] # noqa: E721
if len(args) == 1:
# simple Optional[T] type
typ_ = args[0]
# here's where we check if T is a dataclass
if is_dataclass(typ_):
# recursively add subparsers
_make_parser(parser, typ_, prefix=field_name) # type: ignore
continue
if typ_ is bool:
# for boolean values, we add two arguments: --field_name and --no-field_name
parser.add_argument(
f"--{field_name}",
help=dt_field.metadata.get("help"),
dest=field_name,
action="store_true",
default=MISSING,
)
parser.add_argument(
f"--no-{field_name}",
help=f"Disable {field_name}",
dest=field_name,
action="store_false",
default=MISSING,
)
else:
# else it's just a normal argument
parser.add_argument(
f"--{field_name}",
help=dt_field.metadata.get("help"),
nargs=dt_field.metadata.get("nargs", "?"),
default=MISSING,
)
return parser
def make_nested_dict(key: str, value: Any, d: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
d = d or {}
if "." in key:
key, rest = key.split(".", 1)
value = make_nested_dict(rest, value, d.get(key))
# the value was provided (is not MISSING constant) and is not an empty dict or list
if value != MISSING and (not isinstance(value, (dict, list)) or len(value) > 0):
d[key] = value
return d
def to_native_types(obj: Any, resolve: bool = True, throw_on_missing: bool = True, enum_to_str: bool = True) -> Any:
"""Converts an OmegaConf object to native types (dicts, lists, etc.)"""
# convert dataclass to structured config
if hasattr(obj, "to_dict"):
# huggingface objects have a to_dict method, we prefer that
obj = obj.to_dict()
elif is_dataclass(obj):
# we go through structured config instead and hope for the best
obj = om.to_container(om.structured(obj))
if isinstance(obj, DictConfig) or isinstance(obj, ListConfig):
obj = om.to_container(obj, resolve=resolve, throw_on_missing=throw_on_missing, enum_to_str=enum_to_str)
if isinstance(obj, dict):
return {k: to_native_types(v) for k, v in obj.items()}
elif isinstance(obj, list):
return [to_native_types(v) for v in obj]
else:
return obj
def namespace_to_nested_omegaconf(args: Namespace, structured: Type[T], config: Optional[dict] = None) -> T:
nested_config_dict: Dict[str, Any] = {}
for key, value in vars(args).items():
nested_config_dict = make_nested_dict(key, value, nested_config_dict)
untyped_config: DictConfig = om.merge(
om.create(config or {}), om.create(nested_config_dict)
) # pyright: ignore (pylance is confused because om.create might return a DictConfig or a ListConfig)
# resolve any interpolations in the config
om.resolve(untyped_config)
# create structured config from cli dataclass
base_structured_config: DictConfig = om.structured(structured)
# merge with options parsed from config file and command line
merged_config = om.merge(base_structured_config, untyped_config)
# check for type
if not isinstance(merged_config, DictConfig):
raise DolmaRefineError(f"Expected a DictConfig, got {type(merged_config).__name__}")
# try resolving all cross references in the config, raise a DolmaConfigError if it fails
try:
om.resolve(merged_config)
except OmegaConfBaseException as ex:
raise DolmaRefineError(f"Invalid error while parsing key `{ex.full_key}`: {type(ex).__name__}") from ex
return merged_config # pyright: ignore
def print_config(config: Any, console: Optional[Console] = None) -> None:
if not isinstance(config, (DictConfig, ListConfig)):
config = om.create(config)
# print the config as yaml using a rich syntax highlighter
console = console or Console()
yaml_config = om.to_yaml(config, sort_keys=True).strip()
highlighted = Syntax(code=yaml_config, lexer="yaml", theme="ansi_dark")
console.print(highlighted)
def _patch_old_omegaconf():
"""Monkey patch omegaconf below version 2.3.0 to support custom resolver returning
lists or dicts. Applies patch https://github.com/omry/omegaconf/pull/1093"""
if necessary(("omegaconf", "2.4.0"), soft=True):
# no need to patch
return
if getattr(_patch_old_omegaconf, "__patched__", False):
# already patched
return
from omegaconf import _impl # pylint: disable=import-outside-toplevel
from omegaconf import ( # pylint: disable=import-outside-toplevel
Container,
Node,
ValueNode,
)
from omegaconf._utils import ( # noqa: F401 # pylint: disable=import-outside-toplevel
_ensure_container,
_get_value,
is_primitive_container,
is_structured_config,
)
from omegaconf.errors import ( # pylint: disable=import-outside-toplevel
InterpolationToMissingValueError,
)
from omegaconf.nodes import ( # pylint: disable=import-outside-toplevel
InterpolationResultNode,
)
def _resolve_container_value(cfg: Container, key: Any) -> None:
node = cfg._get_child(key) # pylint: disable=protected-access
assert isinstance(node, Node)
if node._is_interpolation(): # pylint: disable=protected-access
try:
resolved = node._dereference_node() # pylint: disable=protected-access
except InterpolationToMissingValueError:
node._set_value(MISSING) # pylint: disable=protected-access
else:
if isinstance(resolved, Container):
_impl._resolve(resolved) # pylint: disable=protected-access
if isinstance(resolved, InterpolationResultNode):
resolved_value = _get_value(resolved)
if is_primitive_container(resolved_value) or is_structured_config(resolved_value):
resolved = _ensure_container(resolved_value)
if isinstance(resolved, Container) and isinstance(node, ValueNode):
cfg[key] = resolved
else:
node._set_value(_get_value(resolved)) # pylint: disable=protected-access
else:
_impl._resolve(node) # pylint: disable=protected-access
# set new function and mark as patched
setattr(_impl, "_resolve_container_value", _resolve_container_value)
setattr(_patch_old_omegaconf, "__patched__", True)
# actually executes the patch
_patch_old_omegaconf()
def make_cli(config_cls: Type[D], _config_flag: str = "config", _dryrun_flag: str = "dryrun") -> D:
"""Create a CLI parser for a dataclass and parse the arguments into a structured config object."""
if hasattr(config_cls, _config_flag):
raise DolmaRefineError(f"`{_config_flag}` is a reserved attribute; remove it from `{config_cls.__name__}`")
if hasattr(config_cls, _dryrun_flag):
raise DolmaRefineError(f"`{_dryrun_flag}` is a reserved attribute; remove it from `{config_cls.__name__}`")
parser = ArgumentParser()
parser.add_argument(f"-{_config_flag[0]}", f"--{_config_flag}", help="Path to config file", default=None, type=str)
parser.add_argument(
f"-{_dryrun_flag[0]}",
f"--{_dryrun_flag}",
help="Dry run mode: print config and exit",
action="store_true",
default=False,
)
parser = _make_parser(parser, config_cls)
args = parser.parse_args()
parsed_config: Dict[str, Any] = {}
if (config_path := getattr(args, _config_flag)) is not None:
parsed_config = read_config(config_path)
delattr(args, _config_flag)
only_dryrun = getattr(args, _dryrun_flag, False)
delattr(args, _dryrun_flag)
full_config = namespace_to_nested_omegaconf(args, config_cls, parsed_config)
print_config(full_config)
if only_dryrun:
exit(0)
return full_config
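# A minimal usage sketch of `make_cli` (the `_ExampleConfig` dataclass below is
# hypothetical and exists only for illustration; the real configs live in config.py):
from dataclasses import dataclass as _dataclass


@_dataclass
class _ExampleConfig:
    name: str = field(help="Run name", default="debug")
    steps: int = field(help="Number of steps to run", default=10)


def _example_make_cli() -> None:
    # Parses `-c/--config some.yaml` plus per-field overrides such as `--name` and
    # `--steps`, merges them over the dataclass defaults, prints the result, and
    # returns a structured OmegaConf object typed as _ExampleConfig.
    cfg = make_cli(_ExampleConfig)
    print(cfg.name, cfg.steps)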
from smart_open import register_compressor
__all__ = ["mk_compression"]
def mk_compression():
def _handle_zst(file_obj, mode):
try:
import zstandard as zstd
except ImportError:
raise ImportError("zstandard is required for zstd support")
return zstd.open(file_obj, mode)
register_compressor(".zstd", _handle_zst)
register_compressor(".zst", _handle_zst)
from dataclasses import dataclass
from typing import List, Optional
from peft import TaskType # pyright: ignore
from .cli import field
@dataclass
class ModelConfig:
"""Configuration for loading a model; includes model name and type."""
name_or_path: str = field(help="The model name or path to load; must be compatible with huggingface transformers.")
arch: str = field(help="The model type to load; can be 'seq2seq', 'causal', or 'vllm'")
dtype: str = field(help="The precision to use for the model", default="bfloat16")
use_flash_attn: bool = field(help="Whether to use the flash attention for the model.", default=False)
trust_remote_code: bool = field(help="Whether to trust remote code for the model.", default=False)
low_cpu_mem_usage: bool = field(help="Whether to use low cpu memory usage for the model.", default=False)
fast_tokenizer: bool = field(help="Whether to use the fast tokenizer for the model.", default=True)
model_revision: Optional[str] = field(help="The model revision to use for the model.", default=None)
@dataclass
class GenerateConfig:
max_length: int = field(help="The maximum length of the generated text", default=4096)
temperature: float = field(default=0.2, help="The temperature to use for generation")
top_k: int = field(default=50, help="The top k to use for generation")
top_p: float = field(default=1.0, help="The top p to use for generation")
num_beams: int = field(default=1, help="The number of beams to use for generation")
truncate_prompt_tokens: bool = field(default=True, help="Whether to truncate the prompt tokens for generation")
max_num_seqs: int = field(default=16, help="The maximum number of sequences to generate")
@dataclass
class WandbConfig:
entity: str = field(help="The wandb entity to use for logging", default="ai2-llm")
project: str = field(help="The wandb project to use for logging", default="pdf-qwen2vl")
wandb_api_key: Optional[str] = field(help="The wandb api key to use for logging", default=None)
mode: str = field(help="The wandb mode to use for logging; set to `offline` to disable syncing", default="online")
watch: str = field(help="The wandb watch to use for logging", default="false")
@dataclass
class AwsConfig:
profile: Optional[str] = field(help="The aws profile to use for s3 access", default=None)
access_key_id: Optional[str] = field(help="The aws access key id to use for s3 access", default=None)
secret_access_key: Optional[str] = field(help="The aws secret access key to use for s3 access", default=None)
default_region: Optional[str] = field(help="The default region to use for s3 access", default=None)
@dataclass
class SourceConfig:
name: str = field(help="The name of the source")
response_glob_path: str = field(help="Glob path (local or s3://) to the batch API response JSON files returned by OpenAI")
target_longest_image_dim: list[int] = field(help="Candidate longest-edge dimensions (in pixels) to render the PDF page image at; one is sampled per example")
target_anchor_text_len: list[int] = field(help="Candidate maximum anchor text lengths (aka prompt hint); one is sampled per example")
@dataclass
class DataConfig:
seed: int = field(default=42, help="The seed to use for data loading")
cache_location: Optional[str] = field(help="Location to store s3 pdfs that need to be used to compute page images", default=None)
metric_for_best_model: Optional[str] = field(help="metric to pass to trainer args to use for picking best model checkpoint at end", default=None)
sources: List[SourceConfig] = field(help="The source configurations")
@dataclass
class HyperparamConfig:
batch_size: int = field(default=8, help="The batch size to use for training")
eval_batch_size: Optional[int] = field(default=None, help="The batch size to use for evaluation; default is the same as the training batch size")
learning_rate: float = field(default=2e-5, help="The learning rate to use for training")
max_steps: int = field(default=-1, help="The maximum number of steps to train the model")
pad_multiple_of: int = field(default=16, help="The padding multiple to use for the model")
log_every_steps: int = field(default=5, help="The number of steps to log training metrics")
eval_every_steps: int = field(default=100, help="The number of steps to evaluate the model")
weight_decay: float = field(default=0.0, help="The weight decay to use for training")
warmup_steps: int = field(default=0, help="The number of warmup steps to use for training")
warmup_ratio: float = field(default=0.0, help="The ratio of warmup steps to use for training")
lr_scheduler: str = field(default="linear", help="The learning rate scheduler to use for training")
gradient_accumulation_steps: int = field(default=1, help="The number of gradient accumulation steps to use for training")
gradient_checkpointing: bool = field(default=False, help="Whether to use gradient checkpointing for training")
seed: int = field(default=42, help="The seed to use for training")
reduce_loss: str = field(default="mean", help="The loss reduction to use for training")
clip_grad_norm: float = field(default=0.0, help="The gradient norm to clip to for training")
optim: str = field(default="adamw_torch", help="The optimizer to use for training")
find_unused_parameters: bool = field(default=False, help="Whether to find unused parameters for training")
@dataclass
class SaveConfig:
path: str = field(default="./results", help="The output directory to save the model")
limit: Optional[int] = field(default=None, help="The number of checkpoints to save")
save_every_steps: int = field(default="${hparams.eval_every_steps}", help="The number of steps to save the model") # type: ignore
@dataclass
class LoraConfig:
rank: int = field(default=16, help="The rank of the LoRA attention")
alpha: int = field(default=16, help="The alpha parameter for LoRA scaling")
dropout: float = field(default=0.05, help="The dropout probability for LoRA layers")
bias: str = field(default="none", help="The bias to use for LoRA layers (none, all, or lora_only)")
task_type: str = field(default=TaskType.CAUSAL_LM, help="The task type for the model")
target_modules: List[str] = field(
default=["k_proj", "q_proj", "v_proj", "o_proj", "gate_proj", "down_proj", "up_proj"],
help="The target modules in the model that will be replaced with LoRA layers",
)
@dataclass
class TrainConfig:
model: ModelConfig = field(default=ModelConfig(), help="The model configuration")
lora: Optional[LoraConfig] = field(default=None, help="The LoRA configuration")
aws: AwsConfig = field(default=AwsConfig(), help="Configuration for AWS S3")
wandb: WandbConfig = field(default=WandbConfig(), help="Configuration for Weights and Biases")
train_data: DataConfig = field(default=DataConfig(), help="Configuration for the training data")
valid_data: DataConfig = field(default=DataConfig(), help="Configuration for the validation data")
generate: GenerateConfig = field(default=GenerateConfig(), help="Configuration for text generation")
num_proc: int = field(default=1, help="The maximum number of workers to use for data processing")
max_workers: int = field(default=1, help="The maximum number of workers to use for data loaders")
hparams: HyperparamConfig = field(default=HyperparamConfig(), help="Hyperparameters for training")
save: SaveConfig = field(default=SaveConfig(), help="Configuration for saving the model")
@dataclass
class DemoConfig:
title: str = field(default="# Dolma Rewriter Demo")
description: str = field(default="Internal use only, **DO NOT SHARE OUTSIDE AI2**.")
share: bool = field(default=False, help="Share the demo publicly.")
model: ModelConfig = field(default=ModelConfig())
generate: GenerateConfig = field(default=GenerateConfig())
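# A minimal sketch of how a YAML file like the training configs at the top of this
# commit maps onto TrainConfig. Plain OmegaConf calls stand in here for the full
# `make_cli` flow from cli.py; the override values mirror the first YAML config.
def _example_train_config():
    from omegaconf import OmegaConf

    base = OmegaConf.structured(TrainConfig)
    overrides = OmegaConf.create(
        {
            "model": {"name_or_path": "Qwen/Qwen2-VL-7B-Instruct", "arch": "causal", "use_flash_attn": True},
            "hparams": {"learning_rate": 1e-4, "max_steps": 10000, "gradient_checkpointing": True},
            "save": {"path": "./results/qwen2vl-pdf", "save_every_steps": 1000},
        }
    )
    return OmegaConf.merge(base, overrides)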
class DolmaRefineError(RuntimeError): ...
import logging
import multiprocessing
from typing import Union
LOGGER_PREFIX = "dolma-refine"
def get_logger(name: str, level: Union[int, str] = logging.WARN) -> logging.Logger:
if (proc_name := multiprocessing.current_process().name) == "MainProcess":
proc_name = "main"
proc_name = proc_name.replace(" ", "_")
# set the log level
level = level if isinstance(level, int) else getattr(logging, level.strip().upper(), logging.WARN)
# set name
name = f"{LOGGER_PREFIX}.{proc_name}.{name}"
logger = logging.getLogger(name)
logger.setLevel(level)
# add handler
if not logger.handlers:
handler = logging.StreamHandler()
formatter = logging.Formatter("[%(asctime)s %(name)s %(levelname)s] %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
handler.setFormatter(formatter)
logger.addHandler(handler)
return logger
def reset_level(level: Union[int, str]) -> None:
"""
Reset the log level for all Dolma loggers.
Args:
level (Union[int, str]): The log level to set. It can be either an integer
representing the log level (e.g., logging.DEBUG) or a string
representing the log level name (e.g., 'debug').
Returns:
None
"""
if isinstance(level, str):
if (level_tmp := getattr(logging, level.strip().upper(), None)) is not None:
level = level_tmp
else:
raise ValueError(f"Invalid log level: {level}")
for logger in logging.Logger.manager.loggerDict.values():
if isinstance(logger, logging.Logger):
if logger.name.startswith(LOGGER_PREFIX):
logger.setLevel(level)
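# A minimal usage sketch of the logging helpers above:
def _example_logging() -> None:
    logger = get_logger(__name__, level="info")
    logger.info("starting up")
    reset_level("warning")  # quiets every dolma-refine logger afterwards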
import glob
import os
import re
from concurrent.futures import ThreadPoolExecutor
from functools import partial, reduce
from hashlib import sha256
from itertools import chain
from pathlib import Path
from shutil import copyfileobj
from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple, Union
from urllib.parse import urlparse
import platformdirs
import smart_open
from fsspec import AbstractFileSystem, get_filesystem_class
from smart_open.compression import get_supported_extensions
from .loggers import LOGGER_PREFIX, get_logger
__all__ = [
"glob_path",
"sub_prefix",
"add_suffix",
"sub_suffix",
"make_relative",
"mkdir_p",
"split_path",
"join_path",
"is_glob",
"split_glob",
"partition_path",
]
FS_KWARGS: Dict[str, Dict[str, Any]] = {
"": {"auto_mkdir": True},
}
RE_ANY_ESCAPE = re.compile(r"(?<!\\)(\*\?\[\])")
RE_GLOB_STAR_ESCAPE = re.compile(r"(?<!\\)\*")
RE_GLOB_ONE_ESCAPE = re.compile(r"(?<!\\)\?")
RE_GLOB_OPEN_ESCAPE = re.compile(r"(?<!\\)\[")
RE_GLOB_CLOSE_ESCAPE = re.compile(r"(?<!\\)\]")
ESCAPE_SYMBOLS_MAP = {"*": "\u2581", "?": "\u2582", "[": "\u2583", "]": "\u2584"}
REVERSE_ESCAPE_SYMBOLS_MAP = {v: k for k, v in ESCAPE_SYMBOLS_MAP.items()}
PATCHED_GLOB = False
LOGGER = get_logger(__name__)
def get_fs(path: Union[Path, str]) -> AbstractFileSystem:
"""
Get the filesystem class for a given path.
"""
path = str(path)
protocol = urlparse(path).scheme
fs = get_filesystem_class(protocol)(**FS_KWARGS.get(protocol, {}))
global PATCHED_GLOB # pylint: disable=global-statement
# patch glob method to support recursive globbing
if protocol == "" and not PATCHED_GLOB:
fs.glob = partial(glob.glob, recursive=True)
# only patch once
PATCHED_GLOB = True
return fs
def _escape_glob(s: Union[str, Path]) -> str:
"""
Escape glob characters in a string.
"""
s = str(s)
s = RE_GLOB_STAR_ESCAPE.sub(ESCAPE_SYMBOLS_MAP["*"], s)
s = RE_GLOB_ONE_ESCAPE.sub(ESCAPE_SYMBOLS_MAP["?"], s)
s = RE_GLOB_OPEN_ESCAPE.sub(ESCAPE_SYMBOLS_MAP["["], s)
s = RE_GLOB_CLOSE_ESCAPE.sub(ESCAPE_SYMBOLS_MAP["]"], s)
return s
def _unescape_glob(s: Union[str, Path]) -> str:
"""
Unescape glob characters in a string.
"""
s = str(s)
for k, v in REVERSE_ESCAPE_SYMBOLS_MAP.items():
s = s.replace(k, v)
return s
def _pathify(path: Union[Path, str]) -> Tuple[str, Path]:
"""
Return the protocol and path of a given path.
"""
path = _escape_glob(str(path))
parsed = urlparse(path)
path = Path(f"{parsed.netloc}/{parsed.path}") if parsed.netloc else Path(parsed.path)
return parsed.scheme, path
def _unpathify(protocol: str, path: Path) -> str:
"""
Return a path from its protocol and path components.
"""
path_str = _unescape_glob(str(path))
if protocol:
path_str = f"{protocol}://{path_str.lstrip('/')}"
return path_str
def remove_params(path: str) -> str:
"""
Remove parameters from a path.
"""
parsed = urlparse(path)
return (f"{parsed.scheme}://" if parsed.scheme else "") + f"{parsed.netloc}{parsed.path}"
def is_local(path: str) -> bool:
"""
Check if a path is local.
"""
prot, _ = _pathify(path)
return prot == "" or prot == "file"
def copy_file(src: str, dest: str) -> None:
"""Copy a file using shutil.copyfileobj for efficient chunked copying."""
with smart_open.open(src, "rb") as src_file, smart_open.open(dest, "wb") as dest_file:
copyfileobj(src_file, dest_file)
def copy_dir(src: str, dst: str, src_fs: Optional[AbstractFileSystem] = None, dst_fs: Optional[AbstractFileSystem] = None):
"""Copy a directory using a ThreadPoolExecutor for parallel file copying."""
src_fs = src_fs or get_fs(src)
dst_fs = dst_fs or get_fs(dst)
logger = get_logger(__name__)
with ThreadPoolExecutor(max_workers=8) as executor:
futures = []
for src_path in glob_path(src, yield_dirs=True, fs=src_fs):
rel_path = sub_prefix(src_path, src)
dest_path = join_path("", dst, rel_path)
if is_dir(src_path, fs=src_fs):
# Recursively copy directories
copy_dir(src=src_path, dst=dest_path, src_fs=src_fs, dst_fs=dst_fs)
else:
# File; copy over using the executor for parallelism
logger.info(f"Copying {src_path} to {dest_path}")
futures.append(executor.submit(copy_file, src_path, dest_path))
# Wait for all futures to complete
for future in futures:
future.result() # This will raise an exception if any of the threads failed
def delete_file(path: str, ignore_missing: bool = False, fs: Optional[AbstractFileSystem] = None) -> bool:
"""Delete a file."""
fs = fs or get_fs(path)
try:
fs.rm(path)
deleted = True
except FileNotFoundError as ex:
if not ignore_missing:
raise ex
deleted = False
return deleted
def get_size(path: str, fs: Optional[AbstractFileSystem] = None) -> int:
"""Get the size of a file"""
fs = fs or get_fs(path)
if not exists(path, fs=fs):
raise ValueError(f"Path {path} does not exist")
if is_dir(path, fs=fs):
raise ValueError(f"Path {path} is a directory")
return fs.info(path)["size"]
def delete_dir(path: str, ignore_missing: bool = False, fs: Optional[AbstractFileSystem] = None) -> bool:
"""Delete a directory."""
fs = fs or get_fs(path)
try:
fs.rm(path, recursive=True)
deleted = True
except FileNotFoundError as ex:
if not ignore_missing:
raise ex
deleted = False
return deleted
def partition_path(path: str) -> Tuple[str, Tuple[str, ...], Tuple[str, ...]]:
"""Partition a path into its protocol, symbols before a glob, and symbols after a glob."""
# split the path into its protocol and path components
prot, path_obj = _pathify(path)
# we need to first figure out if this path has a glob by checking if any of the escaped symbols for
# globs are in the path.
glob_locs = [i for i, p in enumerate(path_obj.parts) if any(c in p for c in REVERSE_ESCAPE_SYMBOLS_MAP)]
# make the path components before the glob
pre_glob_path = path_obj.parts[: glob_locs[0]] if glob_locs else path_obj.parts
pre_glob_path = tuple(_unescape_glob(p) for p in pre_glob_path)
# make the path components after the glob
post_glob_path = path_obj.parts[glob_locs[0] + 1 :] if glob_locs else ()
post_glob_path = tuple(_unescape_glob(p) for p in post_glob_path)
return prot, pre_glob_path, post_glob_path
def split_path(path: str) -> Tuple[str, Tuple[str, ...]]:
"""
Split a path into its protocol and path components.
"""
protocol, _path = _pathify(path)
return protocol, tuple(_unescape_glob(p) for p in _path.parts)
def join_path(protocol: Union[str, None], *parts: Union[str, Iterable[str]]) -> str:
"""
Join a path from its protocol and path components.
"""
all_prots, all_parts = zip(*(_pathify(p) for p in chain.from_iterable([p] if isinstance(p, str) else p for p in parts)))
path = str(Path(*all_parts)).rstrip("/")
protocol = protocol or str(all_prots[0])
if protocol:
path = f"{protocol}://{path.lstrip('/')}"
return _unescape_glob(path)
def glob_path(
path: Union[Path, str],
hidden_files: bool = False,
autoglob_dirs: bool = True,
recursive_dirs: bool = False,
yield_dirs: bool = True,
fs: Optional[AbstractFileSystem] = None,
) -> Iterator[str]:
"""
Expand a glob path into a list of paths.
"""
protocol, parsed_path = _pathify(path)
fs = fs or get_fs(path)
if autoglob_dirs and fs.isdir(path):
path = join_path(protocol, _unescape_glob(parsed_path), "*")
if "*" not in str(path):
# nothing to glob
yield str(path)
return
for gl in fs.glob(path):
gl = str(gl)
if not hidden_files and Path(gl).name.startswith("."):
continue
if fs.isdir(gl):
if recursive_dirs:
yield from glob_path(
gl,
hidden_files=hidden_files,
autoglob_dirs=autoglob_dirs,
recursive_dirs=recursive_dirs,
yield_dirs=yield_dirs,
fs=fs,
)
if yield_dirs:
yield join_path(protocol, gl)
else:
yield join_path(protocol, gl)
def sub_prefix(a: str, b: str) -> str:
"""
Return the relative path of a with respect to b.
"""
prot_a, path_a = _pathify(a)
prot_b, path_b = _pathify(b)
if prot_a != prot_b:
raise ValueError(f"Protocols of {a} and {b} do not match")
try:
diff = str(path_a.relative_to(path_b))
except ValueError:
diff = join_path(prot_a, path_a.parts)
return _unescape_glob(diff)
def sub_suffix(a: str, b: str) -> str:
"""
Remove b from the end of a.
"""
prot_a, path_a = _pathify(a)
prot_b, path_b = _pathify(b)
if prot_b:
raise ValueError(f"{b} is not a relative path")
sub_path = re.sub(f"{path_b}$", "", str(path_a))
sub_prot = f"{prot_a}://" if prot_a else ""
# need to trim '/' from the end if (a) '/' is not the only symbol in the path or
# (b) there is a protocol so absolute paths don't make sense
if sub_path != "/" or sub_prot:
sub_path = sub_path.rstrip("/")
return _unescape_glob(sub_prot + sub_path)
def add_suffix(a: str, b: str) -> str:
"""
Return the path of a joined with b.
"""
prot_a, path_a = _pathify(a)
prot_b, path_b = _pathify(b)
if prot_b:
raise ValueError(f"{b} is not a relative path")
return join_path(prot_a, str(path_a / path_b))
def exists(path: str, fs: Optional[AbstractFileSystem] = None) -> bool:
"""Check if a path exists."""
fs = fs or get_fs(path)
return fs.exists(path)
def is_dir(path: str, fs: Optional[AbstractFileSystem] = None) -> bool:
"""Check if a path is a directory."""
fs = fs or get_fs(path)
if exists(path, fs=fs):
return fs.isdir(path)
return False
def is_file(path: str, fs: Optional[AbstractFileSystem] = None) -> bool:
"""Check if a path is a file."""
fs = fs or get_fs(path)
if exists(path, fs=fs):
return fs.isfile(path)
return False
def parent(path: str) -> str:
"""Get the parent directory of a path; if the parent is the root, return the root."""
prot, parts = split_path(path)
if len(parts) == 1:
return path
return join_path(prot, *parts[:-1])
def mkdir_p(path: str, fs: Optional[AbstractFileSystem] = None) -> None:
"""
Create a directory if it does not exist.
"""
if is_glob(path):
raise ValueError(f"Cannot create directory with glob pattern: {path}")
fs = fs or get_fs(path)
fs.makedirs(path, exist_ok=True)
def make_relative(paths: List[str]) -> Tuple[str, List[str]]:
"""Find minimum longest root shared among all paths"""
if len(paths) == 0:
raise ValueError("Cannot make relative path of empty list")
common_prot, common_parts, _ = partition_path(paths[0])
for path in paths:
current_prot, current_parts, _ = partition_path(path)
if current_prot != common_prot:
raise ValueError(f"Protocols of {path} and {paths[0]} do not match")
for i in range(min(len(common_parts), len(current_parts))):
if common_parts[i] != current_parts[i]:
common_parts = common_parts[:i]
break
if len(common_parts) > 0:
common_path = (f"{common_prot}://" if common_prot else "") + str(Path(*common_parts))
relative_paths = [sub_prefix(path, common_path) for path in paths]
else:
common_path = f"{common_prot}://" if common_prot else ""
relative_paths = [_unpathify("", _pathify(path)[1]) for path in paths]
return common_path, relative_paths
def is_glob(path: str) -> bool:
"""
Check if a path contains a glob wildcard.
"""
return bool(re.search(r"(?<!\\)[*?[\]]", path))
def split_glob(path: str) -> Tuple[str, str]:
"""
Partition a path on the first wildcard.
"""
if not is_glob(path):
# it's not a glob, so it's all path
return path, ""
if path[0] == "*":
# starts with a glob, so it's all glob
return "", path
protocol, parts = split_path(path)
i = min(i for i, c in enumerate(parts) if is_glob(c))
if i == 0:
# no path, so it's all glob
return protocol, join_path("", *parts)
path = join_path(protocol, *parts[:i])
rest = join_path("", *parts[i:])
return path, rest
def get_cache_dir() -> str:
"""
Returns the path to the cache directory for the Dolma toolkit.
If the directory does not exist, it will be created.
Returns:
str: The path to the cache directory.
"""
loc = platformdirs.user_cache_dir(LOGGER_PREFIX)
mkdir_p(loc)
return loc
def resource_to_filename(resource: Union[str, bytes]) -> str:
"""
Convert a ``resource`` into a hashed filename in a repeatable way. Preserves the file extensions.
"""
_, (*_, orig_filename) = split_path(remove_params(str(resource)))
_, extensions = split_basename_and_extension(orig_filename)
resource_bytes = str(resource).encode("utf-8")
resource_hash = sha256(resource_bytes)
hash_filename = resource_hash.hexdigest() + extensions
return hash_filename
def cached_path(path: str, fs: Optional[AbstractFileSystem] = None) -> str:
"""
Returns the cached path for a given resource.
If the resource is already available locally, the function returns the path as is.
Otherwise, it downloads the resource from the specified path and saves it in the cache directory.
Args:
path (str): The path to the resource.
Returns:
str: The cached path of the resource.
"""
if is_local(path):
# local paths need no caching
return path
destination = f"{get_cache_dir()}/{resource_to_filename(path)}"
remote_fs = fs or get_fs(path)
local_fs = get_fs(destination)
if exists(destination, fs=local_fs):
LOGGER.info(f"Using cached file {destination} for {path}")
return destination
if is_dir(path, fs=remote_fs):
for sub_path in glob_path(path, fs=remote_fs):
rel_path = sub_prefix(sub_path, path)
dest_path = join_path("", destination, rel_path)
mkdir_p(parent(dest_path), fs=local_fs)
LOGGER.info(f"Downloading {sub_path} to {dest_path}")
with smart_open.open(sub_path, "rb") as src, smart_open.open(dest_path, "wb") as dest:
dest.write(src.read())
else:
LOGGER.info(f"Downloading {path} to {destination}")
with smart_open.open(path, "rb") as src, smart_open.open(destination, "wb") as dest:
dest.write(src.read())
return destination
def split_basename_and_extension(path: str) -> Tuple[str, str]:
"""
Get the path and extension from a given file path. If a file has multiple
extensions, they will be joined with a period, e.g. "foo/bar/baz.tar.gz"
will return ("foo/bar/baz", ".tar.gz"). If the file has no extension, the
second element of the tuple will be an empty string. Works with both local
and remote (e.g. s3://) paths.
Args:
path (str): The file path.
Returns:
Tuple[str, str]: A tuple containing the path and extension.
"""
prot, (*parts, filename) = split_path(path)
base, *ext_parts = filename.split(".")
ext = ("." + ".".join(ext_parts)) if ext_parts else ""
return join_path(prot, *parts, base), ext
def decompress_path(path: str, dest: Optional[str] = None) -> str:
"""
Decompresses a file at the given path and returns the path to the decompressed file.
Args:
path (str): The path to the file to be decompressed.
dest (str, optional): The destination path for the decompressed file.
If not provided, a destination path will be computed based on the original
file name and the cache directory.
Returns:
str: The path to the decompressed file. If the file cannot be decompressed,
the original path will be returned.
"""
for supported_ext in get_supported_extensions():
# not the supported extension
if not path.endswith(supported_ext):
continue
if dest is None:
# compute the name for the decompressed file; to do this, we first hash the
# resource into a filename and then drop the compression suffix from its extension.
base_fn, ext = split_basename_and_extension(resource_to_filename(path))
# to get the decompressed file name, we remove the bit of the extension that
# indicates the compression type.
decompressed_fn = base_fn + ext.replace(supported_ext, "")
# finally, we get cache directory and join the decompressed file name to it
dest = join_path("", get_cache_dir(), decompressed_fn)
# here we do the actual decompression
with smart_open.open(path, "rb") as fr, smart_open.open(dest, "wb") as fw:
fw.write(fr.read())
# return the path to the decompressed file
return dest
# already decompressed or can't be decompressed
return path
def split_ext(path: str) -> Tuple[str, Tuple[str, ...], str]:
"""
Split a path into its protocol and extensions.
"""
prot, parts = split_path(path)
if not parts:
return prot, (), ""
filename = parts[-1]
extensions = []
while True:
filename, ext = os.path.splitext(filename)
if not ext:
break
extensions.append(ext)
return prot, (*parts[:-1], filename), "".join(reversed(extensions))
def get_unified_path(paths: List[str]) -> str:
"""Get a unified path for a list of paths."""
if len(paths) == 1:
# if there is only one path, we don't need to unify anything
return paths[0]
# get shared root for all paths; we will put the unified path here
root, relative = make_relative(paths)
# get the extension from the first path; assume all paths have the same extension
_, _, ext = split_ext(relative[0])
# hash all the sorted relative paths in order to get a unique name
# the type: ignore is needed because mypy fails to infer the type of the lambda
# (the "or" ensures that the lambda returns the same type as the first argument, which is a hash)
h = reduce(lambda h, p: h.update(p.encode()) or h, sorted(relative), sha256()) # type: ignore
# return the unified path
return join_path(root, h.hexdigest() + ext)
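# A minimal usage sketch of the path helpers above (the S3 prefix is a hypothetical
# location used only for illustration):
def _example_paths() -> None:
    prefix = "s3://my-bucket/experiments/run1"
    for p in glob_path(prefix):
        print(p, "dir" if is_dir(p) else "file")
    # join_path builds a path from protocol + parts; cached_path downloads remote
    # resources into the local cache directory and returns the local copy.
    local = cached_path(join_path("s3", "my-bucket", "experiments", "run1", "config.yaml"))
    print("cached at", local)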
import os
from dataclasses import dataclass
from typing import Optional
@dataclass
class BeakerState:
job_id: Optional[str] = None
job_kind: Optional[str] = None
task_id: Optional[str] = None
experiment_id: Optional[str] = None
replica_rank: Optional[str] = None
leader_replica_hostname: Optional[str] = None
leader_replica_node_id: Optional[str] = None
user_id: Optional[str] = None
def __post_init__(self):
for key, value in os.environ.items():
if not key.startswith("BEAKER_"):
continue
setattr(self, key[len("BEAKER_"):].lower(), value)
@property
def url(self) -> Optional[str]:
if self.job_id:
return f"https://beaker.org/jobs/{self.job_id}"
return None
import glob
import logging
import os
import re
from typing import Optional
import boto3
from datasets import Dataset, load_dataset
from filelock import FileLock
from olmocr.data.renderpdf import get_pdf_media_box_width_height
from olmocr.prompts.anchor import get_anchor_text
from olmocr.s3_utils import parse_custom_id, parse_s3_path
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Quiet logs from pypdf and smart open
logging.getLogger("pypdf").setLevel(logging.ERROR)
logging.getLogger("smart_open").setLevel(logging.ERROR)
def list_dataset_files(s3_glob_path: str):
"""
Lists files in the specified S3 path that match the glob pattern.
"""
if s3_glob_path.startswith("s3://"):
s3 = boto3.client("s3")
match = re.match(r"s3://([^/]+)/(.+)", s3_glob_path)
if not match:
logger.error(f"Invalid S3 path: {s3_glob_path}")
raise ValueError(f"Invalid S3 path: {s3_glob_path}")
bucket, prefix_pattern = match.groups()
prefix = prefix_pattern.split("*")[0] # Extract prefix before the wildcard
paginator = s3.get_paginator("list_objects_v2")
pages = paginator.paginate(Bucket=bucket, Prefix=prefix)
files = []
pattern = re.compile(prefix_pattern.replace("*", ".*"))
for page in pages:
for obj in page.get("Contents", []):
key = obj["Key"]
if pattern.fullmatch(key):
files.append(f"s3://{bucket}/{key}")
return files
else:
return glob.glob(s3_glob_path)
def load_jsonl_into_ds(s3_glob_path: str, first_n_files: Optional[int] = None) -> Dataset:
"""
Loads JSONL files from the specified S3 path into a Hugging Face Dataset.
"""
all_json_files = list_dataset_files(s3_glob_path)
if first_n_files:
all_json_files = all_json_files[:first_n_files]
# Use datasets library to load JSON files from S3
dataset = load_dataset(
"json",
data_files=all_json_files,
)
return dataset
def extract_openai_batch_response(example):
custom_id = example.get("custom_id", None)
# Parse the custom id into an s3 document path and page number (1-indexed)
s3_path, page_num = parse_custom_id(custom_id)
response_body = example.get("response", {}).get("body", {})
choices = response_body.get("choices", [])
response = ""
finish_reason = ""
if choices:
first_choice = choices[0]
message = first_choice.get("message", {})
response = message.get("content", "")
finish_reason = first_choice.get("finish_reason", "")
# TODO Maybe in the future we can parse the response (which is a structured JSON document itself)
# into its own columns
return {"s3_path": s3_path, "page_num": page_num, "response": response, "finish_reason": finish_reason}
def _cache_s3_file(s3_path: str, local_cache_dir: str):
"""
Downloads an S3 object to a local cache directory, ensuring no two writers corrupt the same file.
"""
bucket, key = parse_s3_path(s3_path)
# Define the local file path
local_file_path = os.path.join(local_cache_dir, bucket + "__" + key.replace("/", "_"))
os.makedirs(os.path.dirname(local_file_path), exist_ok=True)
lock_file = f"{local_file_path}.lock"
# Use a file lock to prevent concurrent writes
with FileLock(lock_file):
if not os.path.exists(local_file_path):
logger.info(f"Downloading {s3_path} to {local_file_path}")
s3_client = boto3.client("s3", aws_access_key_id=os.getenv("DS_AWS_ACCESS_KEY_ID"), aws_secret_access_key=os.getenv("DS_AWS_SECRET_ACCESS_KEY"))
s3_client.download_file(bucket, key, local_file_path)
# otherwise the file is already cached locally; skip the download
return local_file_path
def cache_s3_files(dataset: Dataset, pdf_cache_location: str, num_proc: int = 32) -> Dataset:
"""
Caches all S3 paths in the dataset to the local cache directory.
"""
# Define the download function to use in parallel processing
def cache_file(example):
s3_path = example["s3_path"]
if s3_path:
# Download the file and cache it locally
local_path = _cache_s3_file(s3_path, pdf_cache_location)
return {"local_pdf_path": local_path}
return {"local_pdf_path": None}
# Map the caching function to the dataset (with parallelism if needed)
dataset = dataset.map(cache_file, num_proc=num_proc, load_from_cache_file=False)
return dataset
def build_finetuning_dataset(response_glob_path: str, pdf_cache_location: Optional[str] = None, num_proc: int = 32) -> Dataset:
if pdf_cache_location is None:
pdf_cache_location = os.path.join(os.path.expanduser("~"), ".cache", "olmocr_pdfs")
logger.info("Loading fine tuning dataset from OpenAI style batch responses")
response_data = load_jsonl_into_ds(response_glob_path)
response_data = response_data["train"]
response_data = response_data.map(extract_openai_batch_response, remove_columns=response_data.column_names, num_proc=num_proc)
# Don't include data where the model cut off due to a length issue, or moderation issue
logger.info("Filtering on finish_reason == stop")
final_dataset = response_data.filter(lambda x: x["finish_reason"] == "stop", num_proc=num_proc)
# Cache all the s3_paths that were accessed to a local storage location,
final_dataset = cache_s3_files(final_dataset, pdf_cache_location, num_proc)
# Filter out pages where you cannot get an anchor text generated, to prevent errors during actual training
def _can_create_anchor_text(example):
try:
anchor_text = get_anchor_text(example["local_pdf_path"], example["page_num"], pdf_engine="pdfreport", target_length=4000)
_ = get_pdf_media_box_width_height(example["local_pdf_path"], example["page_num"])
return anchor_text is not None
except Exception:
logger.exception("Could not generate anchor text for file, be sure you have all dependencies installed")
return False
final_dataset = final_dataset.filter(_can_create_anchor_text, num_proc=num_proc)
return final_dataset
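# A minimal usage sketch (the glob mirrors the training configs earlier in this commit
# and is purely illustrative; any local or s3:// glob of batch response JSONL works):
def _example_build_dataset() -> Dataset:
    return build_finetuning_dataset(
        response_glob_path="/data/jakep/pdfdata/openai_batch_data_v5_1_train_done/*.json",
        num_proc=8,
    )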
import base64
import random
from io import BytesIO
from typing import Union
import numpy as np
import torch # Make sure to import torch as it's used in the DataCollator
from PIL import Image
from olmocr.data.renderpdf import render_pdf_to_base64png
from olmocr.prompts import build_finetuning_prompt
from olmocr.prompts.anchor import get_anchor_text
def prepare_data_for_qwen2_training(example, processor, target_longest_image_dim: Union[int, list[int]], target_anchor_text_len: Union[int, list[int]]):
if isinstance(target_longest_image_dim, list):
target_longest_image_dim = random.choice(target_longest_image_dim)
if isinstance(target_anchor_text_len, list):
target_anchor_text_len = random.choice(target_anchor_text_len)
anchor_text = get_anchor_text(example["local_pdf_path"], example["page_num"], pdf_engine="pdfreport", target_length=target_anchor_text_len)
base64_page_image = render_pdf_to_base64png(example["local_pdf_path"], example["page_num"], target_longest_image_dim=target_longest_image_dim)
# Prepare messages
messages = [
{
"role": "user",
"content": [
{"type": "image", "image": base64_page_image},
{"type": "text", "text": build_finetuning_prompt(anchor_text)},
],
}
]
# Apply chat template to get the text
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
# Decode image from base64
main_image = Image.open(BytesIO(base64.b64decode(base64_page_image)))
# Process inputs using processor
inputs = processor(
text=[text],
images=[main_image],
padding=True,
return_tensors="np",
)
# Get labels by tokenizing the output text
labels = processor(text=[example["response"]], padding=True, return_tensors="np")
# Append "<|im_end|>\n" to the labels, because this is what it would look like
# if we passed the whole message stream in there
im_end_tokens = processor.tokenizer("<|im_end|>\n", add_special_tokens=False)["input_ids"]
im_end_tokens = np.array(im_end_tokens, dtype=inputs.input_ids.dtype) # Ensure correct dtype
# Handle the case where labels['input_ids'] is empty
if labels["input_ids"].shape[1] == 0:
labels_input_ids_0 = np.array([], dtype=inputs.input_ids.dtype)
else:
labels_input_ids_0 = labels["input_ids"][0].astype(inputs.input_ids.dtype)
labels["input_ids"] = np.concatenate([labels_input_ids_0, im_end_tokens])
labels["input_ids"] = np.expand_dims(labels["input_ids"], axis=0)
# Concatenate input_ids and labels
input_ids = np.concatenate([inputs.input_ids[0], labels.input_ids[0]], axis=0)
# Every token participates fully in attention
attention_mask = np.ones_like(input_ids)
# Create labels, masking the input portion with -100
labels_full = np.full_like(input_ids, fill_value=-100)
labels_full[len(inputs.input_ids[0]) :] = labels.input_ids[0]
# TODO Maybe cap the max length
# Return as dict, including pixel_values
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"labels": labels_full,
"pixel_values": inputs.pixel_values,
"image_grid_thw": inputs["image_grid_thw"][0],
}
def batch_prepare_data_for_qwen2_training(batch, processor, target_longest_image_dim: list[int], target_anchor_text_len: list[int]):
# Process each example in the batch using the helper function
processed_examples = []
for i in range(len(batch["response"])):
example = {"local_pdf_path": batch["local_pdf_path"][i], "page_num": batch["page_num"][i], "response": batch["response"][i]}
processed_example = prepare_data_for_qwen2_training(
example, processor, target_longest_image_dim=target_longest_image_dim, target_anchor_text_len=target_anchor_text_len
)
processed_examples.append(processed_example)
return {
"input_ids": [x["input_ids"] for x in processed_examples],
"attention_mask": [x["attention_mask"] for x in processed_examples],
"labels": [x["labels"] for x in processed_examples],
"pixel_values": [x["pixel_values"] for x in processed_examples],
"image_grid_thw": [x["image_grid_thw"] for x in processed_examples],
}
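# A minimal sketch of wiring the Qwen2-VL preparation into a Hugging Face dataset on
# the fly. The AutoProcessor checkpoint and the with_transform wiring are assumptions
# consistent with the function signatures above, not a prescribed training entry point.
def _example_qwen2_transform(dataset):
    from functools import partial

    from transformers import AutoProcessor

    processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
    transform = partial(
        batch_prepare_data_for_qwen2_training,
        processor=processor,
        target_longest_image_dim=[1024],
        target_anchor_text_len=[6000],
    )
    # with_transform applies the collation lazily, per batch, at access time
    return dataset.with_transform(transform)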
def prepare_data_for_molmo_training(example, processor, target_longest_image_dim: Union[int, list[int]], target_anchor_text_len: Union[int, list[int]]):
if isinstance(target_longest_image_dim, list):
target_longest_image_dim = random.choice(target_longest_image_dim)
if isinstance(target_anchor_text_len, list):
target_anchor_text_len = random.choice(target_anchor_text_len)
anchor_text = get_anchor_text(example["local_pdf_path"], example["page_num"], pdf_engine="pdfreport", target_length=target_anchor_text_len)
base64_page_image = render_pdf_to_base64png(example["local_pdf_path"], example["page_num"], target_longest_image_dim=target_longest_image_dim)
# Decode image from base64
main_image = Image.open(BytesIO(base64.b64decode(base64_page_image)))
# Process the input text and image
inputs = processor.process(
images=[main_image],
text=build_finetuning_prompt(anchor_text),
)
# Get labels by tokenizing the output text
labels = processor.tokenizer(example["response"], return_tensors="np")["input_ids"][0]
# Concatenate input_ids and labels
full_input_ids = torch.cat([inputs["input_ids"], torch.from_numpy(labels)], dim=0)
labels_full = torch.cat([torch.ones_like(inputs["input_ids"]) * -100, torch.from_numpy(labels)], dim=0)
# Create a full attention mask
attention_mask = torch.ones_like(full_input_ids)
# image_input_idx does not need adjustment as images are inserted before labels
image_input_idx = inputs["image_input_idx"]
return {
"input_ids": full_input_ids,
"labels": labels_full,
"images": inputs["images"],
"image_input_idx": image_input_idx,
"image_masks": inputs["image_masks"],
"attention_mask": attention_mask,
}
def batch_prepare_data_for_molmo_training(batch, processor, target_longest_image_dim: list[int], target_anchor_text_len: list[int]):
# Assume batch size 1 and process the single example
example = {"local_pdf_path": batch["local_pdf_path"][0], "page_num": batch["page_num"][0], "response": batch["response"][0]}
processed_example = prepare_data_for_molmo_training(
example, processor, target_longest_image_dim=target_longest_image_dim, target_anchor_text_len=target_anchor_text_len
)
# Return in the same format as the qwen2 function
return {
"input_ids": [processed_example["input_ids"]],
"attention_mask": [processed_example["attention_mask"]],
"labels": [processed_example["labels"]],
"images": [processed_example["images"]],
"image_input_idx": [processed_example["image_input_idx"]],
"image_masks": [processed_example["image_masks"]],
}
import argparse
import concurrent.futures
import json
import os
import boto3
import torch
from smart_open import smart_open
from transformers import Qwen2VLForConditionalGeneration
from olmocr.s3_utils import parse_s3_path
s3_client = boto3.client("s3")
def download_file_from_s3(bucket_name, key, local_file_path):
"""Download a single file from S3."""
s3_client.download_file(bucket_name, key, local_file_path)
print(f"Downloaded {key} to {local_file_path}")
def download_model_from_s3(bucket_name, model_s3_key, local_model_dir):
if not os.path.exists(local_model_dir):
os.makedirs(local_model_dir)
# List objects in the S3 model path
response = s3_client.list_objects_v2(Bucket=bucket_name, Prefix=model_s3_key)
objects = response.get("Contents", [])
# Prepare list of download tasks
download_tasks = []
for obj in objects:
key = obj["Key"]
if key.endswith("/"):
continue # Skip directories
local_file_path = os.path.join(local_model_dir, os.path.basename(key))
download_tasks.append((bucket_name, key, local_file_path))
# Use a ThreadPoolExecutor to download files in parallel
with concurrent.futures.ThreadPoolExecutor() as executor:
futures = [executor.submit(download_file_from_s3, bucket_name, key, local_file_path) for bucket_name, key, local_file_path in download_tasks]
# Wait for all downloads to complete and handle any exceptions
for future in concurrent.futures.as_completed(futures):
try:
future.result() # This will raise any exceptions encountered during download
except Exception as e:
print(f"Error downloading file: {e}")
def upload_file_to_s3(local_file_path, bucket_name, s3_key):
"""Upload a single file to S3."""
try:
s3_client.upload_file(local_file_path, bucket_name, s3_key)
print(f"Uploaded {local_file_path} to s3://{bucket_name}/{s3_key}")
except Exception as e:
print(f"Error uploading {local_file_path} to s3://{bucket_name}/{s3_key}: {e}")
def save_model_to_s3(local_model_dir, bucket_name, s3_model_key):
"""Upload the model directory to S3 in parallel."""
# Collect all file paths to be uploaded
upload_tasks = []
for root, dirs, files in os.walk(local_model_dir):
for file in files:
local_file_path = os.path.join(root, file)
s3_key = os.path.join(s3_model_key, file)
upload_tasks.append((local_file_path, bucket_name, s3_key))
# Use a ThreadPoolExecutor to upload files in parallel
with concurrent.futures.ThreadPoolExecutor() as executor:
futures = [executor.submit(upload_file_to_s3, local_file_path, bucket_name, s3_key) for local_file_path, bucket_name, s3_key in upload_tasks]
# Wait for all uploads to complete and handle any exceptions
for future in concurrent.futures.as_completed(futures):
try:
future.result() # This will raise any exceptions encountered during upload
except Exception as e:
print(f"Error during upload: {e}")
def main():
parser = argparse.ArgumentParser(description="Fix up a Qwen2VL checkpoint saved on s3 or otherwise, so that it will load properly in vllm/birr")
parser.add_argument("s3_path", type=str, help="S3 path to the Hugging Face checkpoint.")
args = parser.parse_args()
qwen_replacement_files = [
# Config is special to fix rope config
"s3://ai2-oe-data/artifacts/Qwen2-VL-7B-Instruct/config.json",
# Tokenizer and preprocessor are just not saved in the usual flow
"https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct/resolve/main/tokenizer.json",
"https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct/resolve/main/tokenizer_config.json",
"https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct/resolve/main/vocab.json",
"https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct/resolve/main/merges.txt",
"https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct/resolve/main/generation_config.json",
"https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct/resolve/main/chat_template.json",
"https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct/resolve/main/preprocessor_config.json",
]
# Now, download the config.json from the original path and verify the architectures
config_path = os.path.join(args.s3_path, "config.json")
with smart_open(config_path, "r") as f:
config_data = json.load(f)
assert config_data["architectures"] == ["Qwen2VLForConditionalGeneration"]
if config_data["torch_dtype"] == "float32":
print("Detected model is float32, this is probably an FSDP checkpoint")
print("Saving to _bf16 location with adjusted parameters")
bucket, prefix = parse_s3_path(args.s3_path)
td = "/tmp/qwen2_checkpoint_saving"
download_model_from_s3(bucket, prefix, td)
print("Downloaded entire model from s3, resaving as bfloat16")
model = Qwen2VLForConditionalGeneration.from_pretrained(td)
model = model.to(torch.bfloat16)
os.makedirs(os.path.join(td, "bf16_checkpoint"), exist_ok=True)
print("Saving...")
model.save_pretrained(os.path.join(td, "bf16_checkpoint"))
print("Uploading")
save_model_to_s3(os.path.join(td, "bf16_checkpoint"), bucket, prefix.rstrip("/") + "/bf16")
args.s3_path = args.s3_path.rstrip("/") + "/bf16"
# Iterate over each file in the replacement list
for replacement_file in qwen_replacement_files:
filename = os.path.basename(replacement_file)
dest_path = os.path.join(args.s3_path, filename)
with smart_open(replacement_file, "rb") as src_file:
data = src_file.read()
with smart_open(dest_path, "wb") as dest_file:
dest_file.write(data)
print("Model updated successfully.")
if __name__ == "__main__":
main()
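# Example invocation of the checkpoint fix-up script above (script name and S3 path are
# illustrative assumptions, not taken from this repo):
#   python fix_qwen2vl_checkpoint.py s3://my-bucket/experiments/qwen2vl/checkpoint-10000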
# Script to generate parquet dataset files to upload to hugging face
# Input is a dataset location /data/jakep/pdfdata/openai_batch_data_v5_1_iabooks_train_done/*.json
# Each json line has a custom id that looks like {"custom_id": "s3://ai2-s2-pdfs/de80/a57e6c57b45796d2e020173227f7eae44232.pdf-1", ... more data}
# The script takes a path to the input dataset and an SQLite database location,
# and builds a parquet file with rows of the form: "id", "url", "page_number", "response"
# where id is the output of parse_pdf_hash plus "-" plus the page number,
# url is the result of get_uri_from_db,
# and response is NormalizedEntry.text
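# Example invocation (script name, database path, and output name are illustrative only):
#   python build_hf_parquet.py '/data/jakep/pdfdata/openai_batch_data_v5_1_iabooks_train_done/*.json' \
#       pdf_mapping.db --output olmocr_mix.parquet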
import argparse
import concurrent.futures
import glob
import json
import multiprocessing
import os
import re
import sqlite3
from dataclasses import dataclass
from typing import Dict, List, Optional, Set, Tuple
from urllib.parse import urlparse
import boto3
import pandas as pd
from pypdf import PdfReader, PdfWriter
from tqdm import tqdm
def parse_pdf_hash(pretty_pdf_path: str) -> Optional[str]:
"""
Extracts a hash from a pretty PDF S3 URL.
For example, given:
s3://ai2-s2-pdfs/de80/a57e6c57b45796d2e020173227f7eae44232.pdf-1
it will return "de80a57e6c57b45796d2e020173227f7eae44232".
"""
# Allow an optional "-<number>" at the end.
if pretty_pdf_path.startswith("s3://ai2-s2-pdfs/"):
pattern = r"s3://ai2-s2-pdfs/([a-f0-9]{4})/([a-f0-9]+)\.pdf(?:-\d+)?$"
match = re.match(pattern, pretty_pdf_path)
if match:
return match.group(1) + match.group(2)
return None
elif pretty_pdf_path.startswith("s3://ai2-oe-data/reganh/iabooks/"):
return urlparse(pretty_pdf_path).path.split("/")[-1]
else:
raise NotImplementedError()
def get_uri_from_db(db_path: str, pdf_hash: str) -> Optional[str]:
"""
Looks up the URL for the given pdf_hash in the sqlite database.
Assumes there is a table called 'pdf_mapping' with a column 'uri'.
"""
conn = sqlite3.connect(db_path)
cursor = conn.cursor()
cursor.execute("SELECT uri FROM pdf_mapping WHERE pdf_hash = ?", (pdf_hash,))
result = cursor.fetchone()
conn.close()
return result[0].strip() if result and result[0] else None
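# Expected schema for the lookup above (created by the WARC-indexing script in this repo):
#   CREATE TABLE pdf_mapping (pdf_hash TEXT PRIMARY KEY, uri TEXT)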
@dataclass(frozen=True)
class NormalizedEntry:
s3_path: str
pagenum: int
text: Optional[str]
finish_reason: Optional[str]
error: Optional[str] = None
@staticmethod
def from_goldkey(goldkey: str, **kwargs):
"""
Constructs a NormalizedEntry from a goldkey string.
The goldkey is expected to be of the format:
<s3_path>-<page_number>
"""
s3_path = goldkey[: goldkey.rindex("-")]
page_num = int(goldkey[goldkey.rindex("-") + 1 :])
return NormalizedEntry(s3_path, page_num, **kwargs)
@property
def goldkey(self):
return f"{self.s3_path}-{self.pagenum}"
def normalize_json_entry(data: dict) -> NormalizedEntry:
"""
Normalizes a JSON entry from any of the supported formats.
It supports:
- Birr: looks for an "outputs" field.
- Already normalized entries: if they contain s3_path, pagenum, etc.
- OpenAI: where the response is in data["response"]["body"]["choices"].
- SGLang: where the response is in data["response"]["choices"].
"""
if "outputs" in data:
# Birr case
if data["outputs"] is None:
text = None
finish_reason = None
else:
text = data["outputs"][0]["text"]
finish_reason = data["outputs"][0]["finish_reason"]
return NormalizedEntry.from_goldkey(
goldkey=data["custom_id"],
text=text,
finish_reason=finish_reason,
error=data.get("completion_error", None),
)
elif all(field in data for field in ["s3_path", "pagenum", "text", "error", "finish_reason"]):
# Already normalized
return NormalizedEntry(**data)
elif "response" in data and "body" in data["response"] and "choices" in data["response"]["body"]:
return NormalizedEntry.from_goldkey(
goldkey=data["custom_id"],
text=data["response"]["body"]["choices"][0]["message"]["content"],
finish_reason=data["response"]["body"]["choices"][0]["finish_reason"],
)
else:
raise ValueError("Unsupported JSON format")
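# Illustrative OpenAI-batch line handled by the branch above (fields trimmed, values assumed):
# {"custom_id": "s3://ai2-s2-pdfs/de80/a57e6c57b45796d2e020173227f7eae44232.pdf-1",
#  "response": {"body": {"choices": [{"message": {"content": "...page text..."},
#                                     "finish_reason": "stop"}]}}}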
def parse_s3_url(s3_url: str) -> Tuple[str, str]:
"""
Parses an S3 URL of the form s3://bucket/key and returns (bucket, key).
"""
if not s3_url.startswith("s3://"):
raise ValueError(f"Invalid S3 URL: {s3_url}")
s3_path = s3_url[5:]
bucket, key = s3_path.split("/", 1)
return bucket, key
def download_pdf_to_cache(s3_url: str, cache_dir: str) -> Optional[str]:
"""
Downloads the PDF from the given S3 URL into the specified cache directory.
The destination filename is based on the parsed PDF hash.
Returns the path to the downloaded PDF.
"""
try:
bucket, key = parse_s3_url(s3_url)
s3_client = boto3.client("s3")
pdf_hash = parse_pdf_hash(s3_url)
if not pdf_hash:
# Fallback: use a sanitized version of the s3_url
pdf_hash = re.sub(r"\W+", "_", s3_url)
dest_path = os.path.join(cache_dir, f"{pdf_hash}.pdf")
# Avoid re-downloading if already exists
if not os.path.exists(dest_path):
s3_client.download_file(bucket, key, dest_path)
return dest_path
except Exception as e:
print(f"Error downloading {s3_url}: {e}")
return None
def process_pdf_page(s3_url: str, page_number: int, combined_id: str, output_pdf_dir: str, pdf_cache: Dict[str, str]) -> Optional[str]:
"""
Extracts the specified page (1-indexed) from the cached PDF corresponding to s3_url.
Writes a new single-page PDF to the output_pdf_dir using the combined_id as the filename.
Returns the relative path to the new PDF (e.g., "pdfs/<combined_id>.pdf").
"""
try:
local_cached_pdf = pdf_cache.get(s3_url)
if not local_cached_pdf or not os.path.exists(local_cached_pdf):
print(f"Cached PDF not found for {s3_url}")
return None
reader = PdfReader(local_cached_pdf)
# pypdf uses 0-indexed page numbers
page_index = page_number - 1
if page_index < 0 or page_index >= len(reader.pages):
print(f"Page number {page_number} out of range for PDF {s3_url}")
return None
writer = PdfWriter()
writer.add_page(reader.pages[page_index])
output_filename = f"{combined_id}.pdf"
output_path = os.path.join(output_pdf_dir, output_filename)
with open(output_path, "wb") as f_out:
writer.write(f_out)
# Return the relative path (assuming pdfs/ folder is relative to the parquet file location)
return os.path.join("pdfs", output_filename)
except Exception as e:
print(f"Error processing PDF page for {s3_url} page {page_number}: {e}")
return None
def process_file(file_path: str, db_path: str, output_pdf_dir: str, pdf_cache: Dict[str, str]) -> Tuple[List[dict], int]:
"""
Process a single file and return a tuple:
(list of valid rows, number of rows skipped due to missing URL or PDF extraction/filtering).
For each JSON entry, the function:
- Normalizes the JSON.
- Skips entries whose response contains the word "resume" (any case) along with either an email address or a phone number.
- Extracts the PDF hash and builds the combined id.
- Looks up the corresponding URL from the sqlite database.
- Extracts the specified page from the cached PDF and writes it to output_pdf_dir.
- Outputs a row with "id", "url", "page_number", "response".
"""
rows = []
missing_count = 0
email_regex = r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b"
phone_regex = r"\b(?:\+?\d{1,3}[-.\s]?)?(?:\(?\d{3}\)?[-.\s]?)?\d{3}[-.\s]?\d{4}\b"
try:
with open(file_path, "r", encoding="utf-8") as f:
for line_num, line in enumerate(f, start=1):
line = line.strip()
if not line:
continue
try:
data = json.loads(line)
except json.JSONDecodeError as e:
print(f"Skipping invalid JSON at {file_path}:{line_num} - {e}")
continue
try:
normalized = normalize_json_entry(data)
except Exception as e:
print(f"Error normalizing entry at {file_path}:{line_num} - {e}")
continue
# Apply filter: skip if response contains "resume" (any case) and an email or phone number.
response_text = normalized.text if normalized.text else ""
if re.search(r"resume", response_text, re.IGNORECASE) and (re.search(email_regex, response_text) or re.search(phone_regex, response_text)):
print(f"Skipping entry due to resume and contact info in response at {file_path}:{line_num}")
continue
# Extract the PDF hash from the s3_path.
pdf_hash = parse_pdf_hash(normalized.s3_path)
if pdf_hash is None:
print(f"Could not parse pdf hash from {normalized.s3_path} at {file_path}:{line_num}")
continue
# The output id is the pdf hash plus '-' plus the page number.
combined_id = f"{pdf_hash}-{normalized.pagenum}"
# Look up the corresponding URL from the sqlite database.
url = get_uri_from_db(db_path, pdf_hash)
if not url:
print(f"Missing URL for pdf hash {pdf_hash} at {file_path}:{line_num}")
missing_count += 1
continue
# Process PDF: extract the specified page from the cached PDF.
local_pdf_path = process_pdf_page(normalized.s3_path, normalized.pagenum, combined_id, output_pdf_dir, pdf_cache)
if local_pdf_path is None:
print(f"Skipping entry because PDF processing failed for {normalized.s3_path} page {normalized.pagenum} at {file_path}:{line_num}")
missing_count += 1
continue
row = {
"id": combined_id,
"url": url,
"page_number": normalized.pagenum,
"response": normalized.text,
}
rows.append(row)
except Exception as e:
print(f"Error processing file {file_path}: {e}")
return rows, missing_count
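# Illustrative row produced by process_file above (values are assumptions):
#   {"id": "de80a57e6c57b45796d2e020173227f7eae44232-1",
#    "url": "https://example.org/source.pdf",
#    "page_number": 1,
#    "response": "...extracted page text..."}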
def scan_file_for_s3_urls(file_path: str) -> Set[str]:
"""
Scans a single file and returns a set of unique S3 URLs found in the JSON entries.
"""
urls = set()
try:
with open(file_path, "r", encoding="utf-8") as f:
for line in f:
line = line.strip()
if not line:
continue
try:
data = json.loads(line)
normalized = normalize_json_entry(data)
urls.add(normalized.s3_path)
except Exception:
# Skip entries that cannot be normalized
continue
except Exception as e:
print(f"Error reading file {file_path}: {e}")
return urls
def main():
parser = argparse.ArgumentParser(description="Generate a Parquet dataset file for HuggingFace upload.")
parser.add_argument(
"input_dataset",
help="Input dataset file pattern (e.g., '/data/jakep/pdfdata/openai_batch_data_v5_1_iabooks_train_done/*.json')",
)
parser.add_argument("db_path", help="Path to the SQLite database file.")
parser.add_argument("--output", default="output.parquet", help="Output Parquet file path.")
args = parser.parse_args()
files = glob.glob(args.input_dataset)
print(f"Found {len(files)} files matching pattern: {args.input_dataset}")
# Determine output directory and create 'pdfs' subfolder.
output_abs_path = os.path.abspath(args.output)
output_dir = os.path.dirname(output_abs_path)
pdfs_dir = os.path.join(output_dir, "pdfs")
os.makedirs(pdfs_dir, exist_ok=True)
# Create a temporary directory for caching PDFs.
pdf_cache_dir = "/tmp/pdf_cache"
os.makedirs(pdf_cache_dir, exist_ok=True)
print(f"Caching PDFs to temporary directory: {pdf_cache_dir}")
# ---------------------------------------------------------------------
# Step 1: Scan input files to collect all unique S3 URLs using a ProcessPoolExecutor.
unique_s3_urls: Set[str] = set()
print("Scanning input files to collect unique PDF URLs...")
num_cpus = multiprocessing.cpu_count()
with concurrent.futures.ProcessPoolExecutor(max_workers=num_cpus * 4) as executor:
results = list(tqdm(executor.map(scan_file_for_s3_urls, files), total=len(files), desc="Scanning files"))
for url_set in results:
unique_s3_urls |= url_set
print(f"Found {len(unique_s3_urls)} unique PDF URLs.")
# ---------------------------------------------------------------------
# Step 2: Download all unique PDFs to the cache directory.
pdf_cache: Dict[str, str] = {}
print("Caching PDFs from S3...")
with concurrent.futures.ProcessPoolExecutor(max_workers=num_cpus * 8) as executor:
future_to_url = {executor.submit(download_pdf_to_cache, s3_url, pdf_cache_dir): s3_url for s3_url in unique_s3_urls}
for future in tqdm(concurrent.futures.as_completed(future_to_url), total=len(future_to_url), desc="Downloading PDFs"):
s3_url = future_to_url[future]
try:
local_path = future.result()
if local_path:
pdf_cache[s3_url] = local_path
else:
print(f"Failed to cache PDF for {s3_url}")
except Exception as e:
print(f"Error caching PDF for {s3_url}: {e}")
# ---------------------------------------------------------------------
# Step 3: Process input files using the precached PDFs.
all_rows = []
total_missing = 0
print("Processing files...")
with concurrent.futures.ProcessPoolExecutor() as executor:
futures = {executor.submit(process_file, file_path, args.db_path, pdfs_dir, pdf_cache): file_path for file_path in files}
for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures), desc="Processing files"):
file_path = futures[future]
try:
rows, missing_count = future.result()
all_rows.extend(rows)
total_missing += missing_count
except Exception as e:
print(f"Error processing file {file_path}: {e}")
if all_rows:
df = pd.DataFrame(all_rows)
# Set the "id" column as the index.
df.set_index("id", inplace=True)
df.to_parquet(args.output)
valid_count = len(df)
total_processed = valid_count + total_missing
print(f"Successfully wrote {valid_count} rows to {args.output}")
print(f"Rows skipped due to missing URL/PDF or filtering: {total_missing} out of {total_processed} processed rows")
else:
print("No valid rows to write. Exiting.")
if __name__ == "__main__":
main()
import logging
import os
import tarfile
from math import ceil
from huggingface_hub import HfApi
# Configuration
pdf_dir = "pdfs" # Directory with PDF files (flat structure)
tarball_dir = "tarballs" # Directory where tar.gz files will be saved
os.makedirs(tarball_dir, exist_ok=True)
repo_id = "allenai/olmOCR-mix-0225" # Hugging Face dataset repo ID
# Set up logging to file
logging.basicConfig(filename="upload.log", level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
def process_chunk(args):
"""
Worker function to create a tar.gz file for a given chunk.
Returns a tuple: (chunk_index, success (bool), message).
"""
chunk_index, chunk_files = args
tarball_name = f"pdf_chunk_{chunk_index:04d}.tar.gz"
tarball_path = os.path.join(tarball_dir, tarball_name)
try:
with tarfile.open(tarball_path, "w:gz") as tar:
for pdf_filename in chunk_files:
pdf_path = os.path.join(pdf_dir, pdf_filename)
# Add the file with its basename to maintain a flat structure
tar.add(pdf_path, arcname=pdf_filename)
logging.info(f"Chunk {chunk_index:04d}: Created '{tarball_name}' with {len(chunk_files)} PDFs.")
return chunk_index, True, "Success"
except Exception as e:
error_msg = f"Chunk {chunk_index:04d}: Error creating '{tarball_name}': {e}"
logging.error(error_msg)
return chunk_index, False, error_msg
def main():
# List all PDF files (assuming a flat directory)
try:
pdf_files = sorted([f for f in os.listdir(pdf_dir) if f.lower().endswith(".pdf")])
except Exception as e:
logging.error(f"Error listing PDFs in '{pdf_dir}': {e}")
return
total_files = len(pdf_files)
chunk_size = 5000
total_chunks = ceil(total_files / chunk_size)
logging.info(f"Found {total_files} PDFs; dividing into {total_chunks} chunks of up to {chunk_size} files each.")
# # Enumerate chunks (starting at 0000)
# chunks = []
# for idx in range(total_chunks):
# start = idx * chunk_size
# end = start + chunk_size
# chunk_files = pdf_files[start:end]
# chunks.append((idx, chunk_files))
# # Create tarballs in parallel
# results = []
# with ProcessPoolExecutor() as executor:
# futures = {executor.submit(process_chunk, chunk): chunk for chunk in chunks}
# for future in tqdm(as_completed(futures), total=len(futures), desc="Creating tarballs"):
# try:
# result = future.result()
# results.append(result)
# chunk_index, success, message = result
# if not success:
# logging.error(f"Chunk {chunk_index:04d} failed: {message}")
# except Exception as e:
# logging.error(f"Unexpected error processing a chunk: {e}")
# # Abort upload if any tarball creation failed
# failed_chunks = [r for r in results if not r[1]]
# if failed_chunks:
# logging.error(f"{len(failed_chunks)} chunk(s) failed to create. Aborting upload.")
# return
# All tarballs created successfully; now upload the entire tarball directory
api = HfApi()
logging.info("Starting upload of tarballs folder to Hugging Face Hub...")
# This will upload all files in tarball_dir to the repo under "pdf_tarballs"
api.upload_large_folder(
folder_path=tarball_dir,
repo_id=repo_id,
# path_in_repo="pdf_tarballs",
repo_type="dataset",
)
logging.info("Successfully uploaded tarballs folder to Hugging Face Hub.")
if __name__ == "__main__":
main()
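# The resulting archives are named pdf_chunk_0000.tar.gz, pdf_chunk_0001.tar.gz, ... and keep a
# flat layout, so a consumer could unpack one with, for example:
#   tar -xzf pdf_chunk_0000.tar.gz -C pdfs/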
#!/usr/bin/env python3
import argparse
import sqlite3
from concurrent.futures import ThreadPoolExecutor
from functools import partial
import boto3
from tqdm import tqdm
from warcio.archiveiterator import ArchiveIterator
def parse_s3_path(s3_path):
"""
Parses an S3 path of the form s3://bucket/prefix and returns the bucket and prefix.
"""
if not s3_path.startswith("s3://"):
raise ValueError("S3 path must start with s3://")
without_prefix = s3_path[5:]
parts = without_prefix.split("/", 1)
bucket = parts[0]
prefix = parts[1] if len(parts) > 1 else ""
return bucket, prefix
def list_s3_warc_objects(s3_path, suffix=".warc.gz"):
"""
Lists all objects under the given S3 path that end with the provided suffix.
Uses a paginator to handle large result sets.
"""
bucket, prefix = parse_s3_path(s3_path)
s3_client = boto3.client("s3")
paginator = s3_client.get_paginator("list_objects_v2")
warc_keys = []
for page in paginator.paginate(Bucket=bucket, Prefix=prefix):
if "Contents" in page:
for obj in page["Contents"]:
key = obj["Key"]
if key.endswith(suffix):
warc_keys.append(key)
return bucket, warc_keys, s3_client
def extract_target_uri_s3(bucket, key, s3_client, head_bytes=1048576):
"""
Retrieves the first head_bytes bytes (1 MB by default) from the S3 object using a range request,
and extracts the WARC-Target-URI from the headers of the first record that carries one.
"""
target_uri = None
try:
response = s3_client.get_object(Bucket=bucket, Key=key, Range=f"bytes=0-{head_bytes-1}")
stream = response["Body"]
for record in ArchiveIterator(stream):
for name, value in record.rec_headers.headers:
if name == "WARC-Target-URI":
target_uri = value
break
if target_uri:
break # Only use the first valid response record
except Exception as e:
tqdm.write(f"Error processing s3://{bucket}/{key}: {e}")
return target_uri
def create_db(db_path):
"""
Creates (or opens) the SQLite database and ensures that the pdf_mapping table exists,
including an index on pdf_hash.
"""
conn = sqlite3.connect(db_path)
cursor = conn.cursor()
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS pdf_mapping (
pdf_hash TEXT PRIMARY KEY,
uri TEXT
)
"""
)
cursor.execute(
"""
CREATE INDEX IF NOT EXISTS idx_pdf_hash ON pdf_mapping (pdf_hash)
"""
)
conn.commit()
return conn
def process_warc_file(key, bucket, s3_client):
"""
Processes a single WARC file from S3 and returns a tuple (pdf_hash, uri)
if successful, otherwise returns None.
"""
uri = extract_target_uri_s3(bucket, key, s3_client, head_bytes=1048576)
if uri:
# Derive pdf_hash as the file's basename with .warc.gz replaced by .pdf.
pdf_hash = key.split("/")[-1].replace(".warc.gz", ".pdf")
return (pdf_hash, uri)
else:
tqdm.write(f"Warning: No valid response record found in s3://{bucket}/{key}")
return None
def process_s3_folder(s3_path, db_path):
"""
Lists all .warc.gz files under the provided S3 path, then processes each file in parallel
to extract the target URI from the WARC record headers. The resulting mapping (derived from the file's
basename with .warc.gz replaced by .pdf) is stored in the SQLite database.
"""
bucket, warc_keys, s3_client = list_s3_warc_objects(s3_path, suffix=".warc.gz")
conn = create_db(db_path)
cursor = conn.cursor()
# Process WARC files concurrently using ThreadPoolExecutor.
results = []
func = partial(process_warc_file, bucket=bucket, s3_client=s3_client)
with ThreadPoolExecutor() as executor:
for result in tqdm(executor.map(func, warc_keys), total=len(warc_keys), desc="Processing S3 WARC files"):
if result is not None:
results.append(result)
# Bulk insert into the database.
conn.execute("BEGIN")
for pdf_hash, uri in results:
cursor.execute("INSERT OR REPLACE INTO pdf_mapping (pdf_hash, uri) VALUES (?, ?)", (pdf_hash, uri))
conn.commit()
conn.close()
def main():
parser = argparse.ArgumentParser(description="Create an SQLite database mapping PDF file names to target URIs from S3 WARC files.")
parser.add_argument("s3_path", help="S3 path (e.g., s3://bucket/prefix) containing .warc.gz files")
parser.add_argument("db_file", help="Path for the output SQLite database file")
args = parser.parse_args()
process_s3_folder(args.s3_path, args.db_file)
if __name__ == "__main__":
main()
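# Illustrative check of the resulting database (the database file name is an assumption):
#   sqlite3 pdf_mapping.db "SELECT pdf_hash, uri FROM pdf_mapping LIMIT 5"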
import base64
from io import BytesIO
import torch
import torch.distributed
from PIL import Image
from transformers import AutoConfig, AutoProcessor, Qwen2_5_VLForConditionalGeneration
from olmocr.data.renderpdf import render_pdf_to_base64png
from olmocr.prompts.anchor import get_anchor_text
from olmocr.prompts.prompts import build_openai_silver_data_prompt
@torch.no_grad()
def run_inference(model_name: str):
config = AutoConfig.from_pretrained(model_name)
processor = AutoProcessor.from_pretrained(model_name)
    # If the checkpoint doesn't load, change the rope_scaling "type": "mrope" entry in config.json to "default"
# model = Qwen2VLForConditionalGeneration.from_pretrained(model_name, device_map="auto", config=config)
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model_name, device_map="auto", config=config)
model.eval()
# local_pdf_path = os.path.join(os.path.dirname(__file__), "..", "..", "tests", "gnarly_pdfs", "horribleocr.pdf")
local_pdf_path = "/root/brochure.pdf"
page = 1
image_base64 = render_pdf_to_base64png(local_pdf_path, page, 1024)
anchor_text = get_anchor_text(local_pdf_path, page, pdf_engine="pdfreport")
messages = [
{
"role": "user",
"content": [
{"type": "text", "text": build_openai_silver_data_prompt(anchor_text)},
{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_base64}"}},
],
}
]
# Preparation for inference
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
main_image = Image.open(BytesIO(base64.b64decode(image_base64)))
inputs = processor(
text=[text],
images=[main_image],
padding=True,
return_tensors="pt",
)
inputs = inputs.to("cuda")
output_ids = model.generate(**inputs, temperature=0.8, do_sample=True, max_new_tokens=1500)
generated_ids = [output_ids[len(input_ids) :] for input_ids, output_ids in zip(inputs["input_ids"], output_ids)]
output_text = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
print(output_text[0])
def main():
run_inference(model_name="Qwen/Qwen2.5-VL-7B-Instruct")
if __name__ == "__main__":
main()
from transformers import AutoProcessor
from olmocr.train.core.cli import make_cli
from olmocr.train.core.config import TrainConfig
from .utils import make_dataset
def main():
train_config = make_cli(TrainConfig) # pyright: ignore
processor = AutoProcessor.from_pretrained(train_config.model.name_or_path, trust_remote_code=True)
train_dataset, valid_dataset = make_dataset(train_config, processor)
print("Training dataset........")
print(train_dataset)
train_example = train_dataset[0]
print(train_example)
print({(x, y.shape) for x, y in train_example.items()})
print("\nTokens")
print(processor.tokenizer.batch_decode(train_example["input_ids"]))
print("\n\n")
print("Validation dataset........")
print(valid_dataset)
print(valid_dataset[list(valid_dataset.keys())[0]][0])
print("\n\n")
print("Datasets loaded into hugging face cache directory")
# data_collator = TruncatingCollator(
# max_length=4096
# )
# train_dataloader = DataLoader(train_dataset, batch_size=1, num_workers=4, shuffle=False, collate_fn=data_collator)
# max_seen_len = 0
# for index, entry in tqdm(enumerate(train_dataloader)):
# if index == 0:
# print(entry)
# num_input_tokens = entry["input_ids"].shape[1]
# max_seen_len = max(max_seen_len, num_input_tokens)
# print(max_seen_len)
if __name__ == "__main__":
main()