Commit 89e60e48 authored by wanglch

Initial commit
from transformers import PretrainedConfig
class MolmoConfig(PretrainedConfig):
model_type = "molmo"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
vocab_size=50304,
embedding_size=50304,
hidden_size=4096,
intermediate_size=11008,
num_hidden_layers=32,
num_attention_heads=32,
num_key_value_heads=None,
max_position_embeddings=2048,
initializer_range=0.02,
use_cache=True,
layer_norm_eps: float = 1e-5,
rope_theta=10000.0,
clip_qkv=None,
qkv_bias: bool = False,
weight_tying: bool = False,
use_position_ids: bool = True,
tie_word_embeddings: bool = True,
attention_layer_norm: bool = False,
norm_after: bool = False,
layer_norm_type: str = "rms",
**kwargs,
):
self.vocab_size = vocab_size
self.embedding_size = embedding_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.layer_norm_eps = layer_norm_eps
self.weight_tying = weight_tying
self.use_position_ids = use_position_ids
self.attention_layer_norm = attention_layer_norm
self.num_key_value_heads = num_key_value_heads
self.initializer_range = initializer_range
self.use_cache = use_cache
self.rope_theta = rope_theta
self.clip_qkv = clip_qkv
self.qkv_bias = qkv_bias
self.norm_after = norm_after
self.tie_word_embeddings = tie_word_embeddings
self.layer_norm_type = layer_norm_type
super().__init__(
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
MolmoConfig.register_for_auto_class()
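# Usage sketch (illustrative; the values shown simply repeat the defaults above and are not the
# settings of any particular released checkpoint):
#   config = MolmoConfig(hidden_size=4096, num_hidden_layers=32, num_attention_heads=32)
#   config.save_pretrained("molmo-config")              # writes config.json
#   config = MolmoConfig.from_pretrained("molmo-config")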
"""Image processor class for Molmo"""
from typing import List, Optional, Union
import einops
import numpy as np
import torch
import torchvision.transforms
from torchvision.transforms import InterpolationMode
from torchvision.transforms.functional import convert_image_dtype
from transformers.image_processing_utils import BaseImageProcessor
from transformers.image_utils import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD, ImageInput
from transformers.processing_utils import ImagesKwargs
from transformers.utils import logging
logger = logging.get_logger(__name__)
def pad_to_bounding_box(image, offset_height, offset_width, target_height, target_width, value=0):
height, width = image.shape[:2]
after_padding_width = target_width - offset_width - width
after_padding_height = target_height - offset_height - height
return np.pad(image, [[offset_height, after_padding_height], [offset_width, after_padding_width], [0, 0]], constant_values=value)
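# Example (illustrative): placing a 2x3x1 array at offset (1, 1) inside a 4x5 canvas pads with
# [[1, 1], [1, 1], [0, 0]]:
#   pad_to_bounding_box(np.zeros((2, 3, 1)), 1, 1, 4, 5).shape == (4, 5, 1)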
def normalize_image(image, offset, scale):
image -= np.array(offset, dtype=np.float32)[None, None, :]
image /= np.array(scale, dtype=np.float32)[None, None, :]
return image
def resize_and_pad(
image,
desired_output_size,
resize_method="torch-bilinear",
pad_value=0,
normalize=True,
image_mean=OPENAI_CLIP_MEAN,
image_std=OPENAI_CLIP_STD,
):
desired_height, desired_width = desired_output_size
height, width = image.shape[:2]
# Cast into float32 since the training code did this in float32 and it (very rarely) affects
# the results after rounding.
image_scale_y = np.array(desired_height, np.float32) / np.array(height, np.float32)
image_scale_x = np.array(desired_width, np.float32) / np.array(width, np.float32)
image_scale = min(image_scale_x, image_scale_y)
scaled_height = int(np.array(height, np.float32) * image_scale)
scaled_width = int(np.array(width, np.float32) * image_scale)
if resize_method == "tensorflow":
# This is how the original training code did resizing; it can produce slightly different
# results than using torch resize, so we keep it just in case
import tensorflow as tf
image = tf.image.convert_image_dtype(tf.constant(image), dtype=tf.float32)
image = tf.image.resize(
image,
[scaled_height, scaled_width],
method=tf.image.ResizeMethod.BILINEAR,
antialias=True,
)
image = tf.clip_by_value(image, 0.0, 1.0)
image = image.numpy()
elif resize_method == "torch-bilinear":
image = torch.permute(torch.from_numpy(image), [2, 0, 1])
image = convert_image_dtype(image) # resize in float32 to match the training code
image = torchvision.transforms.Resize([scaled_height, scaled_width], InterpolationMode.BILINEAR, antialias=True)(image)
image = torch.clip(image, 0.0, 1.0)
image = torch.permute(image, [1, 2, 0]).numpy()
else:
raise NotImplementedError(resize_method)
top_pad = (desired_height - scaled_height) // 2
left_pad = (desired_width - scaled_width) // 2
padding = [[top_pad, desired_height - scaled_height - top_pad], [left_pad, desired_width - scaled_width - left_pad], [0, 0]]
image_mask = np.pad(np.ones_like(image[:, :, 0], dtype=bool), padding[:2])
image = np.pad(image, padding, constant_values=pad_value)
if normalize:
image = normalize_image(image, offset=image_mean, scale=image_std)
return image, image_mask
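# Example (illustrative): a 480x640 RGB image resized into a 336x336 canvas keeps its aspect
# ratio (scale = min(336/640, 336/480) = 0.525 -> 252x336) and is padded by 42 px on the top
# and bottom; the returned mask is True only over the resized image pixels.
#   img, mask = resize_and_pad(np.zeros((480, 640, 3), dtype=np.uint8), (336, 336))
#   # img.shape == (336, 336, 3), mask.shape == (336, 336)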
def select_tiling(h, w, patch_size, max_num_patches):
"""Decide how best to divide an image of size [h, w] into up to max_num_patches crops of size patch_size"""
original_size = np.stack([h, w]) # [2]
original_res = h * w
tilings = []
for i in range(1, max_num_patches + 1):
for j in range(1, max_num_patches + 1):
if i * j <= max_num_patches:
tilings.append((i, j))
# sort so argmin and argmax favour smaller tilings in the event of a tie
tilings.sort(key=lambda x: (x[0] * x[1], x[0]))
candidate_tilings = np.array(tilings, dtype=np.int32) # [n_resolutions, 2]
candidate_resolutions = candidate_tilings * patch_size # [n_resolutions, 2]
# How much we would need to scale the image to fit exactly in each tiling
original_size = np.stack([h, w], dtype=np.float32) # [2]
required_scale_d = candidate_resolutions.astype(np.float32) / original_size
required_scale = np.min(required_scale_d, axis=-1, keepdims=True) # [n_resolutions, 1]
if np.all(required_scale < 1):
# We are forced to downscale, so try to minimize the amount of downscaling
ix = np.argmax(required_scale)
else:
# Pick the resolution that required the least upscaling so that it most closely fits the image
required_scale = np.where(required_scale < 1.0, 10e9, required_scale)
ix = np.argmin(required_scale)
return candidate_tilings[ix]
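# Worked example (illustrative values): for a 400x800 image with a 336-px crop window and at
# most 6 crops, the only tiling whose resolution covers the image in both dimensions is 2x3
# (672x1008 px), so select_tiling(400, 800, 336, 6) returns [2, 3]; if no tiling can avoid
# downscaling, the one needing the least downscaling is chosen instead.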
class MolmoImagesKwargs(ImagesKwargs, total=False):
max_crops: Optional[int]
overlap_margins: Optional[List[int]]
base_image_input_size: Optional[List[int]]
image_token_length_w: Optional[int]
image_token_length_h: Optional[int]
image_patch_size: Optional[int]
image_padding_mask: Optional[bool]
class MolmoImageProcessor(BaseImageProcessor):
"""Preprocess images and multi-modal inputs"""
def __init__(
self,
max_crops: int = 12,
overlap_margins: List[int] = (4, 4),
base_image_input_size: List[int] = (336, 336),
image_token_length_w: int = 12,
image_token_length_h: int = 12,
image_patch_size: int = 14,
image_padding_mask: bool = True,
do_normalize: bool = True,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
**kwargs,
):
super().__init__(**kwargs)
self.max_crops = max_crops
self.overlap_margins = overlap_margins
self.base_image_input_size = base_image_input_size
self.image_token_length_w = image_token_length_w
self.image_token_length_h = image_token_length_h
self.image_patch_size = image_patch_size
self.image_padding_mask = image_padding_mask
self.do_normalize = do_normalize
self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
def image_to_patches_and_tokens(
self,
image: ImageInput,
image_patch_token_id: int,
image_col_token_id: int,
image_start_token_id: int,
image_end_token_id: int,
max_crops: Optional[int] = None,
overlap_margins: Optional[List[int]] = None,
base_image_input_size: Optional[Union[int, List[int]]] = None,
image_token_length_w: Optional[int] = None,
image_token_length_h: Optional[int] = None,
image_patch_size: Optional[int] = None,
):
if isinstance(base_image_input_size, int):
base_image_input_size = (base_image_input_size, base_image_input_size)
base_image_input_d = image_patch_size
tokens_per_image = image_token_length_w * image_token_length_h
image_base_patch_w = base_image_input_size[1] // base_image_input_d
image_base_patch_h = base_image_input_size[0] // base_image_input_d
original_image_h, original_image_w = image.shape[:2]
crop_size = base_image_input_size[0]
# Discard this many patches from the (left/top, right/bottom) of crops
left_margin, right_margin = overlap_margins
# left_margin, right_margin = 2, 2
assert left_margin % 2 == 0 # Required for compatibility with 2x2 pooling
total_margin_pixels = base_image_input_d * (right_margin + left_margin) # pixels removed per dim
crop_patches = base_image_input_size[0] // base_image_input_d # patches per crop dim
crop_window_patches = crop_patches - (right_margin + left_margin) # usable patches
crop_window_size = crop_window_patches * base_image_input_d
tiling = select_tiling(original_image_h - total_margin_pixels, original_image_w - total_margin_pixels, crop_window_size, max_crops)
src, img_mask = resize_and_pad(image, [tiling[0] * crop_window_size + total_margin_pixels, tiling[1] * crop_window_size + total_margin_pixels])
# Now we have to split the image into crops, while keeping track of how each patch in
# each crop should be ordered in the global image; this requires a lot of tricky bookkeeping
n_crops = tiling[0] * tiling[1]
patches_arr = []
mask_arr = []
patch_ordering_arr = []
# We assume 2x2 pooling, but can allow padding the right/bottom with extra
# patches if the number of patches per side is not even
assert (crop_patches + 1) // 2 == image_token_length_h
assert (crop_patches + 1) // 2 == image_token_length_w
on = 0
on_patch = 0
for i in range(tiling[0]):
y0 = i * crop_window_size
if i == 0:
crop_y0 = 0
else:
crop_y0 = left_margin // 2
crop_h = image_base_patch_h - (right_margin + left_margin)
if i == 0:
crop_h += left_margin
if i == (tiling[0] - 1):
crop_h += right_margin
for j in range(tiling[1]):
x0 = j * crop_window_size
if j == 0:
crop_x0 = 0
else:
crop_x0 = left_margin // 2
crop_w = image_base_patch_w - (right_margin + left_margin)
if j == 0:
crop_w += left_margin
if j == (tiling[1] - 1):
crop_w += right_margin
pooled_w = (crop_w + 1) // 2
pooled_h = (crop_h + 1) // 2
patch_ordering_arr.append(
pad_to_bounding_box(
np.reshape(np.arange(on, on + pooled_h * pooled_w, dtype=np.int32), (pooled_h, pooled_w, 1)),
crop_y0,
crop_x0,
image_token_length_h,
image_token_length_w,
value=-1,
)[:, :, 0]
)
patches_arr.append(src[y0 : y0 + crop_size, x0 : x0 + crop_size])
mask_arr.append(img_mask[y0 : y0 + crop_size, x0 : x0 + crop_size])
on += pooled_h * pooled_w
on_patch += 1
patches = np.stack(patches_arr)
patch_ordering = np.stack(patch_ordering_arr)
img_mask = np.stack(mask_arr)
# Switch to [n_crops, n_patches, pixels_per_patch] format
image_layout_impatch_w, image_layout_impatch_h = tiling[0], tiling[1]
patches = einops.rearrange(
patches, "p (h dh) (w dw) c -> p (h w) (dh dw c)", dh=base_image_input_d, dw=base_image_input_d, h=image_base_patch_h, w=image_base_patch_w
)
img_mask = einops.rearrange(
img_mask, "p (h dh) (w dw) -> p (h w) (dh dw)", dh=base_image_input_d, dw=base_image_input_d, h=image_base_patch_h, w=image_base_patch_w
)
img_mask = img_mask.astype(np.float32).mean(axis=-1)
patch_ordering = np.reshape(patch_ordering, [-1])
valid = patch_ordering >= 0
# Transpose order, to get left-to-right order instead of crop-by-crop order
patch_ordering_rh = np.reshape(patch_ordering, [tiling[0], tiling[1], image_token_length_h, image_token_length_w])
patch_ordering_rh = np.transpose(patch_ordering_rh, [0, 2, 1, 3])
patch_ordering_rh = np.reshape(patch_ordering_rh, [-1])
# The transpose will screw up which patches are masked; project the
# new order into the sparse structure of `patch_ordering` to fix this
patch_ordering[valid] = patch_ordering_rh[patch_ordering_rh >= 0]
# Now build the output tokens
h = tiling[0] * crop_window_patches + (right_margin + left_margin)
w = tiling[1] * crop_window_patches + (right_margin + left_margin)
per_row = np.full(
((w + 1) // 2,),
image_patch_token_id,
)
per_row = np.concatenate([per_row, [image_col_token_id]], 0)
joint = np.tile(per_row, [(h + 1) // 2])
joint = [[image_start_token_id], joint, [image_end_token_id]]
# Finally do the same for the global image
resized, _ = resize_and_pad(image, base_image_input_size)
resized = einops.rearrange(
resized, "(h dh) (w dw) c -> (h w) (dh dw c)", dh=base_image_input_d, dw=base_image_input_d, h=image_base_patch_h, w=image_base_patch_w
)
patches = np.concatenate([np.expand_dims(resized, 0), patches], 0)
# Global image goes first, so the order of patches in previous crops gets increased
patch_ordering = np.where(patch_ordering >= 0, patch_ordering + tokens_per_image, -1)
patch_ordering = np.concatenate([np.arange(0, tokens_per_image), patch_ordering], 0)
per_row = np.full(
(image_token_length_w,),
image_patch_token_id,
)
per_row = np.concatenate([per_row, [image_col_token_id]], 0)
extra_tokens = np.tile(per_row, [image_token_length_h])
joint = [
[image_start_token_id],
extra_tokens,
[image_end_token_id],
] + joint
joint = np.concatenate(joint, 0)
img_mask = np.pad(img_mask, [[0, 1], [0, 0]], constant_values=-1)
return patches, joint, patch_ordering, img_mask
def build_image_input_idx(
self,
image_tokens: np.ndarray,
patch_order: np.ndarray,
image_patch_token_id: int,
no_image: Optional[bool] = None,
image_token_length_w: Optional[int] = None,
image_token_length_h: Optional[int] = None,
):
"""Converts `patch_order` into a mapping of token_id -> patch_id"""
tokens_per_image = image_token_length_w * image_token_length_h
if no_image is not None and no_image:
return np.zeros((0, tokens_per_image), np.int32)
# Indices to insert the patches
image_input_idx = image_tokens == image_patch_token_id
image_input_idx = np.nonzero(image_input_idx)[0].astype(np.int32)
if patch_order is not None:
n_tokens = image_input_idx.shape[0]
patch_order = np.reshape(patch_order, [-1])
n_patches = patch_order.shape[0]
valid = patch_order >= 0
n_valid_patches = valid.sum()
assert len(image_input_idx) == n_valid_patches
sorted_patch_ixs = np.zeros([n_tokens], np.int32)
sorted_patch_ixs[patch_order[valid]] = np.arange(n_valid_patches, dtype=np.int32)
# Project the inverted mapping into same sparse structure
sorted_patch_ixs_ex = np.full(np.shape(patch_order), -1)
sorted_patch_ixs_ex[valid] = sorted_patch_ixs
# Do the gather and then re-mask outputs that were masked in `sorted_patch_ixs`
valid = (sorted_patch_ixs_ex >= 0).astype(np.int32)
image_input_idx = image_input_idx[sorted_patch_ixs_ex * valid]
image_input_idx = image_input_idx * valid - 100 * (1 - valid)
image_input_idx = np.reshape(image_input_idx, [-1, tokens_per_image])
return image_input_idx
def preprocess(
self,
image: np.ndarray,
image_patch_token_id: int,
image_col_token_id: int,
image_start_token_id: int,
image_end_token_id: int,
max_crops: Optional[int] = None,
overlap_margins: Optional[List[int]] = None,
base_image_input_size: Optional[Union[int, List[int]]] = None,
image_token_length_w: Optional[int] = None,
image_token_length_h: Optional[int] = None,
image_patch_size: Optional[int] = None,
**kwargs,
):
"""Preprocesses an image
Returns:
crops: (n_crops, n_patches, patch_dim) individual crops, `n_crops` might
change between images but the other dimensions are fixed
tokens: (n_tokens,) int32 tokens, pad tokens indicate where to insert the
patch features, might include other special tokens as well
image_idx: (n_crops, n_patches) index in `tokens` to put the patch features from the
crops after pooling, negative values indicate patch features to exclude
padding_mask: (n_crops, n_patches) what percent of each crop is padding, can be None
if the image mask is not being used.
"""
max_crops = max_crops or self.max_crops
overlap_margins = overlap_margins or self.overlap_margins
base_image_input_size = base_image_input_size or self.base_image_input_size
image_token_length_w = image_token_length_w or self.image_token_length_w
image_token_length_h = image_token_length_h or self.image_token_length_h
image_patch_size = image_patch_size or self.image_patch_size
crops, image_tokens, patch_ordering, img_mask = self.image_to_patches_and_tokens(
image,
image_patch_token_id,
image_col_token_id,
image_start_token_id,
image_end_token_id,
max_crops,
overlap_margins,
base_image_input_size,
image_token_length_w,
image_token_length_h,
image_patch_size,
)
patch_idx = self.build_image_input_idx(
image_tokens,
patch_ordering,
image_patch_token_id,
image_token_length_w=image_token_length_w,
image_token_length_h=image_token_length_h,
)
return crops, image_tokens, patch_idx, img_mask
def multimodal_preprocess(
self,
images: np.ndarray,
tokens: List[int],
image_idx: np.ndarray,
sequence_length: int,
image_patch_token_id: int,
image_col_token_id: int,
image_start_token_id: int,
image_end_token_id: int,
**kwargs,
):
"""Merge images and text tokens into multi-modal features for the model
:param images: images to use as input
:param tokens: input text tokens
:param image_idx: where to insert the images into `tokens`
:param image_patch_token_id: token id for tokens that will contain image features
:param image_col_token_id: token id for image column special tokens
:param image_start_token_id: token id for image start special tokens
:param image_end_token_id: token id for image end special tokens
:param kwargs: overrides for the preprocessor's default arguments
"""
max_total_crops = kwargs.get("max_crops") or self.max_crops
image_token_length_w = kwargs.get("image_token_length_w") or self.image_token_length_w
image_token_length_h = kwargs.get("image_token_length_h") or self.image_token_length_h
image_patch_size = kwargs.get("image_patch_size") or self.image_patch_size
base_image_input_size = kwargs.get("base_image_input_size") or self.base_image_input_size
image_num_patch = (
base_image_input_size[0] // image_patch_size,
base_image_input_size[1] // image_patch_size,
)
image_padding_mask = kwargs.get("image_padding_mask") or self.image_padding_mask
tokens_per_image = image_token_length_w * image_token_length_h
n_pixels = image_patch_size * image_patch_size * 3
n_patches = image_num_patch[0] * image_num_patch[1]
if images is None:
return {
"input_ids": tokens,
}
else:
n = len(images)
all_crops = []
all_image_idx = []
out_tokens = []
all_crop_masks = []
for ix in range(n):
token_ix = image_idx[ix]
crops, image_tokens, patch_idx, img_mask = self.preprocess(
images[ix],
image_patch_token_id,
image_col_token_id,
image_start_token_id,
image_end_token_id,
**kwargs,
)
if token_ix == -1: # -1 is an image inserted at the very start
start = 0
token_ix = 0
end = 0
else:
start = 0 if ix == 0 else image_idx[ix - 1] + 1
end = token_ix + 1
all_image_idx.append(patch_idx + token_ix)
all_crops.append(crops)
out_tokens.append(tokens[start:token_ix])
out_tokens.append(image_tokens)
if ix == (n - 1):
out_tokens.append(tokens[end:])
if image_padding_mask:
all_crop_masks.append(img_mask)
input_ids = np.concatenate(out_tokens, 0)
images = np.concatenate(all_crops, 0)
image_input_idx = np.concatenate(all_image_idx, 0)
if image_padding_mask:
image_masks = np.concatenate(all_crop_masks, 0)
else:
image_masks = None
out = {"input_ids": input_ids, "images": images, "image_input_idx": image_input_idx}
if image_masks is not None:
out["image_masks"] = image_masks
return out
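# Usage sketch (illustrative; the special-token ids below are placeholders, not the real Molmo
# vocabulary ids, and the array contents are dummies):
#   processor = MolmoImageProcessor()
#   out = processor.multimodal_preprocess(
#       images=np.zeros((1, 480, 640, 3), dtype=np.uint8),
#       tokens=np.asarray([10, 11, 12], dtype=np.int32),
#       image_idx=np.asarray([-1], dtype=np.int32),   # -1 inserts the image at the very start
#       sequence_length=1536,
#       image_patch_token_id=1, image_col_token_id=2,
#       image_start_token_id=3, image_end_token_id=4,
#   )
#   # out holds "input_ids", "images", "image_input_idx" and, by default, "image_masks"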
MolmoImageProcessor.register_for_auto_class()
# type: ignore
import logging
import math
from copy import deepcopy
from dataclasses import dataclass, replace
from enum import Enum
from typing import (
Any,
Callable,
Dict,
List,
MutableMapping,
Optional,
Sequence,
Tuple,
Union,
cast,
)
import torch
from einops import einops
from torch import nn
from torch.nn import functional as F
from transformers import GenerationConfig, PreTrainedModel
from transformers.cache_utils import Cache
from transformers.modeling_outputs import CausalLMOutputWithPast, ModelOutput
from transformers.models.auto import AutoModelForCausalLM
from .config_molmo import MolmoConfig
log = logging.getLogger(__name__)
class BufferCache(dict, MutableMapping[str, torch.Tensor]):
"""
Cache for attention biases and other things that would normally be stored as buffers.
We avoid using buffers because we've run into various issues doing so with FSDP.
In general it appears the way FSDP handles buffers is not well-defined.
It doesn't shard them but apparently it does synchronize them across processes, which we want to avoid
since (A) it isn't necessary, and (B) we sometimes have `-inf` in these biases which might get turned into
NaNs when they're synchronized due to casting or some other issue.
"""
class StrEnum(str, Enum):
def __str__(self) -> str:
return self.value
def __repr__(self) -> str:
return f"'{str(self)}'"
class ImageProjectType(StrEnum):
mlp = "mlp"
mlpx2 = "2mlp"
linear = "linear"
class ImagePooling2DType(StrEnum):
attention = "attention"
attention_meanq = "attention-meanq"
attention_2wide = "attention_2wide"
attention_v2 = "attention-v2"
none = "none"
stack = "stack"
class ActivationType(StrEnum):
quick_gelu = "quick_gelu"
gelu = "gelu"
gelu_tanh = "gelu_tanh"
relu = "relu"
silu = "silu"
llama_geglu = "llama_geglu"
llama_geglu_tanh = "llama_geglu_tanh"
llama_swiglu = "llama_swiglu"
swiglu = "swiglu"
def ensure_finite_(x: torch.Tensor, check_neg_inf: bool = True, check_pos_inf: bool = False):
"""
Modify ``x`` in place to replace ``float("-inf")`` with the minimum value of the dtype when ``check_neg_inf``
is ``True`` and to replace ``float("inf")`` with the maximum value of the dtype when ``check_pos_inf`` is ``True``.
"""
if check_neg_inf:
x.masked_fill_(x == float("-inf"), torch.finfo(x.dtype).min)
if check_pos_inf:
x.masked_fill_(x == float("inf"), torch.finfo(x.dtype).max)
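# Example (illustrative): ensure_finite_(torch.tensor([float("-inf"), 0.0])) replaces the -inf
# entry with torch.finfo(torch.float32).min in place and leaves finite values untouched.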
class MolmoConfigurationError(Exception):
pass
def _non_meta_init_device(config) -> torch.device:
if config.init_device is not None and config.init_device != "meta":
return torch.device(config.init_device)
else:
return torch.device("cuda" if torch.cuda.is_available() else "cpu")
class RotaryEmbedding(nn.Module):
"""
[Rotary positional embeddings (RoPE)](https://arxiv.org/abs/2104.09864).
"""
def __init__(self, config: MolmoConfig, cache: BufferCache):
super().__init__()
self.config = config
self.__cache = cache
# Warm up cache.
self.get_rotary_embedding(config.max_position_embeddings or config.max_sequence_length, _non_meta_init_device(config))
def get_rotary_embedding(self, seq_len: int, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
if (
(pos_sin := self.__cache.get("rope_pos_sin")) is not None
and (pos_cos := self.__cache.get("rope_pos_cos")) is not None
and pos_sin.shape[-2] >= seq_len
and pos_cos.shape[-2] >= seq_len
):
if pos_sin.device != device:
pos_sin = pos_sin.to(device)
self.__cache["rope_pos_sin"] = pos_sin
if pos_cos.device != device:
pos_cos = pos_cos.to(device)
self.__cache["rope_pos_cos"] = pos_cos
return pos_sin[:, :, :seq_len, :], pos_cos[:, :, :seq_len, :]
with torch.autocast(device.type, enabled=False):
dim = self.config.d_model // self.config.n_heads
inv_freq = 1.0 / (self.config.rope_theta ** (torch.arange(0, dim, 2, device=device, dtype=torch.float) / dim))
seq = torch.arange(seq_len, device=device, dtype=torch.float)
freqs = torch.einsum("i , j -> i j", seq, inv_freq)
if self.config.rope_impl == "interleave":
positions = freqs.repeat_interleave(2, dim=-1)
else:
positions = torch.cat((freqs, freqs), dim=-1)
pos_sin, pos_cos = positions.sin()[None, None, :, :], positions.cos()[None, None, :, :]
self.__cache["rope_pos_sin"] = pos_sin
self.__cache["rope_pos_cos"] = pos_cos
return pos_sin, pos_cos
def rotate_half(self, x: torch.Tensor) -> torch.Tensor:
B, nh, T, hs = x.size()
x = x.view(B, nh, T, 2, hs // 2)
x1, x2 = x.unbind(dim=-2)
return torch.cat((-x2, x1), dim=-1)
def rotate_every_two(self, x: torch.Tensor) -> torch.Tensor:
B, nh, T, hs = x.size()
x = x.view(B, nh, T, hs // 2, 2)
x1, x2 = x.unbind(dim=-1)
x = torch.stack((-x2, x1), dim=-1)
return x.view(B, nh, T, hs)
def apply_rotary_pos_emb(self, pos_sin: torch.Tensor, pos_cos: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
if self.config.rope_impl == "interleave":
return ((t * pos_cos) + (self.rotate_every_two(t) * pos_sin)).to(t.dtype)
else:
return ((t * pos_cos) + (self.rotate_half(t) * pos_sin)).to(t.dtype)
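# Both layouts implement the standard RoPE update t' = t * cos + rotate(t) * sin, differing
# only in whether feature pairs are interleaved (rotate_every_two) or split into two
# concatenated halves (rotate_half), matching how `positions` is built in get_rotary_embedding.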
def forward(self, q: torch.Tensor, k: torch.Tensor, position_ids: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
if self.config.rope_full_precision:
q_, k_ = q.float(), k.float()
else:
q_, k_ = q, k
with torch.autocast(q.device.type, enabled=False):
batch_size = q_.shape[0]
query_len, key_len = q_.shape[-2], k_.shape[-2] # could be different if layer_past not None
if position_ids is not None:
freqs_cis_len = self.config.max_position_embeddings or self.config.max_sequence_length
else:
freqs_cis_len = key_len
pos_sin, pos_cos = self.get_rotary_embedding(freqs_cis_len, q_.device)
pos_sin = pos_sin.type_as(q_)
pos_cos = pos_cos.type_as(q_)
if position_ids is not None:
assert query_len == key_len, "Query and key lengths must be equal when using position IDs."
pos_sin = pos_sin[0, 0][position_ids].view((batch_size, 1, key_len, pos_sin.shape[-1]))
pos_cos = pos_cos[0, 0][position_ids].view((batch_size, 1, key_len, pos_cos.shape[-1]))
q_ = self.apply_rotary_pos_emb(
pos_sin[:, :, key_len - query_len : key_len, :],
pos_cos[:, :, key_len - query_len : key_len, :],
q_,
)
k_ = self.apply_rotary_pos_emb(pos_sin, pos_cos, k_)
return q_.type_as(q), k_.type_as(k)
class MolmoBlock(nn.Module):
"""
A base class for transformer block implementations.
"""
def __init__(self, layer_id: int, config: MolmoConfig, cache: BufferCache):
super().__init__()
self.layer_id = layer_id
self.config = config
self.hidden_size = config.mlp_hidden_size if config.mlp_hidden_size is not None else config.mlp_ratio * config.d_model
self.__cache = cache
self._activation_checkpoint_fn = None
# Dropout.
self.dropout = Dropout(config.residual_dropout)
# Layer norms.
self.k_norm: Optional[LayerNormBase] = None
self.q_norm: Optional[LayerNormBase] = None
if config.attention_layer_norm:
assert config.effective_n_kv_heads is not None
self.k_norm = LayerNormBase.build(
config,
size=(config.d_model // config.n_heads) * config.effective_n_kv_heads,
elementwise_affine=config.attention_layer_norm_with_affine,
)
self.q_norm = LayerNormBase.build(config, elementwise_affine=config.attention_layer_norm_with_affine)
# Make sure QKV clip coefficient is positive, otherwise it's not well-defined.
if config.clip_qkv is not None:
assert config.clip_qkv > 0
# Activation function.
self.act = Activation.build(config)
assert (self.act.output_multiplier * self.hidden_size) % 1 == 0
# Attention output projection.
input_dim = config.d_model
self.attn_out = nn.Linear(input_dim, config.d_model, bias=config.include_bias, device=config.init_device)
# Feed-forward output projection.
self.ff_out = nn.Linear(
int(self.act.output_multiplier * self.hidden_size),
config.d_model,
bias=config.include_bias,
device=config.init_device,
)
self.ff_out._is_residual = True # type: ignore
# Rotary embeddings.
if self.config.rope:
self.rotary_emb = RotaryEmbedding(config, self.__cache)
self.flash_attn_func = None
if config.attention_type == "flash":
try:
from flash_attn import flash_attn_func # type: ignore
self.flash_attn_func = flash_attn_func
except ModuleNotFoundError:
pass
def reset_parameters(self):
if self.k_norm is not None:
self.k_norm.reset_parameters()
if self.q_norm is not None:
self.q_norm.reset_parameters()
init_weights(
self.config,
self.attn_out,
d=self.config.d_model,
layer_id=self.layer_id,
type_of_module=ModuleType.out_module,
)
init_weights(
self.config,
self.ff_out,
d=self.ff_out.in_features,
layer_id=self.layer_id,
type_of_module=ModuleType.out_module,
)
@classmethod
def _cast_attn_bias(cls, bias: torch.Tensor, input_dtype: torch.dtype) -> torch.Tensor:
target_dtype = input_dtype
# NOTE: `is_autocast_enabled()` only checks for CUDA autocast, so we use the separate function
# `is_autocast_cpu_enabled()` for CPU autocast.
# See https://github.com/pytorch/pytorch/issues/110966.
if bias.device.type == "cuda" and torch.is_autocast_enabled():
target_dtype = torch.get_autocast_gpu_dtype()
elif bias.device.type == "cpu" and torch.is_autocast_cpu_enabled():
target_dtype = torch.get_autocast_cpu_dtype()
if bias.dtype != target_dtype:
bias = bias.to(target_dtype)
ensure_finite_(bias, check_neg_inf=True, check_pos_inf=False)
return bias
def _scaled_dot_product_attention(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
dropout_p: float = 0.0,
response_dropout_p: float = 0.0,
is_causal: bool = False,
) -> torch.Tensor:
"""
Computes scaled dot product attention on query, key and value tensors, using an optional
attention mask if passed, and applying dropout if a probability greater than 0.0 is specified.
"""
if attn_mask is not None:
attn_mask = attn_mask.to(q.device)
if self.flash_attn_func is not None and attn_mask is None:
r = self.flash_attn_func(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), dropout_p=dropout_p, causal=is_causal)
return r.transpose(1, 2)
else:
# torch's sdpa doesn't support GQA, so we're doing this
assert k.size(1) == v.size(1)
num_kv_heads = k.size(1)
num_q_heads = q.size(1)
if num_q_heads != num_kv_heads:
assert num_q_heads % num_kv_heads == 0
k = k.repeat_interleave(num_q_heads // num_kv_heads, dim=1, output_size=num_q_heads)
v = v.repeat_interleave(num_q_heads // num_kv_heads, dim=1, output_size=num_q_heads)
return F.scaled_dot_product_attention(
q,
k,
v,
attn_mask=attn_mask,
dropout_p=dropout_p,
is_causal=is_causal,
)
def attention(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
attention_bias: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
B, T, C = q.size() # batch size, sequence length, d_model
dtype = k.dtype
# Optionally apply layer norm to keys and queries.
if self.q_norm is not None and self.k_norm is not None:
q = self.q_norm(q).to(dtype=dtype)
k = self.k_norm(k).to(dtype=dtype)
# Move head forward to be next to the batch dim.
# shape: (B, nh, T, hs)
q = q.view(B, T, self.config.n_heads, C // self.config.n_heads).transpose(1, 2)
# shape: (B, n_kv_h, T, hs)
k = k.view(B, T, self.config.effective_n_kv_heads, C // self.config.n_heads).transpose(1, 2)
# shape: (B, n_kv_h, T, hs)
v = v.view(B, T, self.config.effective_n_kv_heads, C // self.config.n_heads).transpose(1, 2)
if self.config.use_position_ids and self.config.rope:
# Apply rotary embeddings
q, k = self.rotary_emb(q, k, position_ids=position_ids)
if layer_past is not None:
past_key, past_value = layer_past
k = torch.cat((past_key.to(k.device), k), dim=-2)
v = torch.cat((past_value.to(v.device), v), dim=-2)
present = (k, v) if use_cache else None
query_len, key_len = q.shape[-2], k.shape[-2] # could be different if layer_past not None
if not self.config.use_position_ids and self.config.rope:
# Apply rotary embeddings
q, k = self.rotary_emb(q, k)
if attention_bias is not None:
# Resize and cast attention bias.
# The current dtype of the attention bias might not match the dtype that the SDP attn function will
# run in if AMP is enabled, and this can be a problem if some tokens are masked out due to padding
# as down-casting the attention bias to the autocast precision will result in -infs, which will
# cause the SDP attn function to produce NaNs.
attention_bias = self._cast_attn_bias(attention_bias[:, :, key_len - query_len : key_len, :key_len], dtype)
# Get the attention scores.
# shape: (B, nh, T, hs)
att = self._scaled_dot_product_attention(
q,
k,
v,
attn_mask=attention_bias,
dropout_p=0.0 if not self.training else self.config.attention_dropout,
response_dropout_p=0.0 if not self.training else self.config.response_attention_dropout,
is_causal=attention_bias is None,
)
# Re-assemble all head outputs side-by-side.
att = att.transpose(1, 2).contiguous().view(B, T, C)
# Apply output projection.
return self.attn_out(att), present
def forward(
self,
x: torch.Tensor,
attention_bias: Optional[torch.FloatTensor] = None,
position_ids: Optional[torch.Tensor] = None,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
raise NotImplementedError
@classmethod
def build(cls, layer_id: int, config: MolmoConfig, cache: BufferCache):
return MolmoSequentialBlock(layer_id, config, cache)
class MolmoSequentialBlock(MolmoBlock):
"""
This is a typical transformer block where the output is computed as ``MLP(LN(x + Attention(LN(x))))``
(plus another skip connection).
"""
def __init__(self, layer_id: int, config: MolmoConfig, cache: BufferCache):
super().__init__(layer_id, config, cache)
# Layer norms.
self.attn_norm = LayerNorm.build(config)
self.ff_norm = LayerNorm.build(config)
# Attention input projection. Projects x -> (q, k, v)
head_dim = config.d_model // config.n_heads
self.fused_dims = (
config.d_model,
config.effective_n_kv_heads * head_dim,
config.effective_n_kv_heads * head_dim,
)
self.att_proj = nn.Linear(config.d_model, sum(self.fused_dims), bias=config.include_bias or config.qkv_bias, device=config.init_device)
# Feed-forward input projection.
self.ff_proj = nn.Linear(config.d_model, self.hidden_size, bias=config.include_bias, device=config.init_device)
def reset_parameters(self):
super().reset_parameters()
self.attn_norm.reset_parameters()
self.ff_norm.reset_parameters()
# NOTE: the standard deviation for these weights does not depend on the layer.
init_weights(self.config, self.att_proj, d=self.config.d_model, layer_id=None, type_of_module=ModuleType.in_module)
init_weights(self.config, self.ff_proj, d=self.config.d_model, layer_id=None, type_of_module=ModuleType.in_module)
def forward(
self,
x: torch.Tensor,
attention_bias: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
# Get query, key, value projections.
# shape:
# - for regular attn q, k, v: (batch_size, seq_len, d_model)
# - for multi-query attn q: (batch_size, seq_len, d_model)
# k, v: (batch_size, seq_len, d_model // n_heads)
# - for group query attn q: (batch_size, seq_len, d_model)
# k, v: (batch_size, seq_len, d_model // n_kv_heads)
if not self.config.norm_after:
if self._activation_checkpoint_fn is not None:
atten_in = self._activation_checkpoint_fn(self.attn_norm, x)
else:
atten_in = self.attn_norm(x)
else:
atten_in = x
qkv = self.att_proj(atten_in)
if self.config.clip_qkv is not None:
qkv.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
q, k, v = qkv.split(self.fused_dims, dim=-1)
# Get attention scores.
if self._activation_checkpoint_fn is not None:
att, cache = self._activation_checkpoint_fn( # type: ignore
self.attention, q, k, v, attention_bias, position_ids=position_ids, layer_past=layer_past, use_cache=use_cache
)
else:
att, cache = self.attention(q, k, v, attention_bias, position_ids=position_ids, layer_past=layer_past, use_cache=use_cache)
if self.config.norm_after:
if self._activation_checkpoint_fn is not None:
att = self._activation_checkpoint_fn(self.attn_norm, att)
else:
att = self.attn_norm(att)
# Add attention scores.
# shape: (B, T, C)
x = x + self.dropout(att)
# Add feed-forward projection.
# shape: (batch_size, seq_len, d_model)
og_x = x
if not self.config.norm_after:
if self._activation_checkpoint_fn is not None:
x = self._activation_checkpoint_fn(self.ff_norm, x) # type: ignore
else:
x = self.ff_norm(x)
x = self.ff_proj(x)
if self._activation_checkpoint_fn is not None:
x = self._activation_checkpoint_fn(self.act, x) # type: ignore
else:
x = self.act(x)
x = self.ff_out(x)
if self.config.norm_after:
if self._activation_checkpoint_fn is not None:
x = self._activation_checkpoint_fn(self.ff_norm, x) # type: ignore
else:
x = self.ff_norm(x)
x = self.dropout(x)
x = og_x + x
return x, cache
class Embedding(nn.Module):
def __init__(
self,
num_embeddings: int,
num_new_embeddings: int,
features: int,
device: Union[str, torch.device],
initializer_range: float = 0.02,
new_embed_initializer_range: float = 0.02,
):
super().__init__()
self.initializer_range = initializer_range
self.new_embed_initializer_range = new_embed_initializer_range
self.embedding = nn.Parameter(
torch.zeros(num_embeddings, features, device=device),
)
self.new_embedding = nn.Parameter(
torch.zeros(num_new_embeddings, features, device=device),
)
def reset_parameters(self):
nn.init.normal_(self.embedding, std=self.initializer_range)
nn.init.normal_(self.new_embedding, std=self.new_embed_initializer_range)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return F.embedding(x, torch.cat([self.embedding, self.new_embedding], dim=0))
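# Example (illustrative): with num_embeddings=8 and num_new_embeddings=2, token ids 0-7 index
# rows of `embedding` while ids 8-9 index rows of `new_embedding`, since the two tables are
# concatenated before the lookup.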
class Dropout(nn.Dropout):
def __init__(
self,
p: float = 0.5,
inplace: bool = False,
mask_p: float = 0,
broadcast_dims: Sequence[int] = (),
):
super().__init__(p, inplace)
self.mask_p = mask_p
self.broadcast_dims = broadcast_dims
def forward(self, input: torch.Tensor) -> torch.Tensor:
"""
:param input: A tensor of shape `(batch_size, seq_len, embed_dim)`
"""
if self.p == 0.0 and (self.mask_p is None or self.mask_p == 0.0):
return input
else:
if self.p > 0.0 and len(self.broadcast_dims) > 0 and self.training:
keep_prob = 1.0 - self.p
dropout_shape = list(input.shape)
for dim in self.broadcast_dims:
dropout_shape[dim] = 1
keep = input.new_empty(dropout_shape).bernoulli_(keep_prob)
multiplier = keep.broadcast_to(input.shape)
multiplier.div_(keep_prob)
return input * multiplier
else:
return F.dropout(input, self.p, self.training, self.inplace)
@dataclass
class VisionBackboneConfig:
image_default_input_size: Tuple[int, int] = (336, 336)
image_patch_size: int = 14
image_pos_patch_size: int = 14
image_emb_dim: int = 1024
image_num_heads: int = 16
image_num_key_value_heads: int = 16
image_num_layers: int = 24
image_head_dim: int = 64
image_mlp_dim: int = 4096
image_mlp_activations: str = "gelu"
image_dropout_rate: float = 0.0
image_num_pos: int = 577
image_norm_eps: float = 1e-5
attention_dropout: float = 0.0
residual_dropout: float = 0.0
initializer_range: float = 0.02
fsdp_wrap: bool = False
resize_mode: str = "default"
def __post_init__(self):
self.image_default_input_size = tuple(self.image_default_input_size) # type: ignore[assignment]
@property
def image_num_patch(self):
h, w = self.image_default_input_size
return h // self.image_patch_size, w // self.image_patch_size
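# Example: with the defaults above (336x336 inputs, 14-px patches), image_num_patch is (24, 24),
# i.e. 576 patches plus the class token, which matches image_num_pos = 577.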
@dataclass
class FullMolmoConfig:
d_model: int = 768
n_heads: int = 12
n_kv_heads: Optional[int] = None
qkv_bias: bool = False
clip_qkv: Optional[float] = None
n_layers: int = 12
mlp_ratio: int = 4
mlp_hidden_size: Optional[int] = None
activation_type: str = "swiglu"
block_group_size: int = 1
rope: bool = True
rope_full_precision: bool = True
rope_theta: float = 10000.0
rope_impl: str = "interleave"
vision_backbone: Optional[VisionBackboneConfig] = None
attention_type: str = "sdpa"
float32_attention: bool = True
attention_dropout: float = 0.1
response_attention_dropout: float = 0.0
multi_query_attention: Optional[bool] = None
attention_layer_norm: bool = False
residual_dropout: float = 0.1
embedding_dropout: float = 0.1
layer_norm_type: str = "default"
layer_norm_with_affine: bool = True
layer_norm_eps: Optional[float] = None
attention_layer_norm_with_affine: bool = True
max_sequence_length: int = 1024
max_position_embeddings: Optional[int] = None
include_bias: bool = True
bias_for_layer_norm: Optional[bool] = None
scale_logits: bool = False
vocab_size: int = 50257
embedding_size: Optional[int] = 50304
additional_vocab_size: Optional[int] = None
new_embedding_init_range: float = 0.02
weight_tying: bool = True
pad_token_id: int = -1
init_device: Optional[str] = None
init_std: float = 0.02
init_cutoff_factor: Optional[float] = None
norm_after: bool = False
precision: Optional[str] = None
image_padding_embed: Optional[str] = None
vit_layers: Tuple = (-1,)
image_pooling_h: int = 2
image_pooling_w: int = 2
image_pooling_2d: str = "attention"
image_projector: str = "mlp"
image_feature_dropout: float = 0.0
initializer_range: float = 0.02
normalize_input_embeds: bool = False
use_position_ids: bool = True
@property
def effective_n_kv_heads(self) -> int:
if self.n_kv_heads is None:
if self.multi_query_attention is True:
return 1
else:
return self.n_heads
else:
if self.multi_query_attention is None:
return self.n_kv_heads
if self.multi_query_attention:
n_kv_heads_should_be = 1
else:
n_kv_heads_should_be = self.n_heads
if self.n_kv_heads == n_kv_heads_should_be:
return n_kv_heads_should_be
else:
raise MolmoConfigurationError("You can't set `multi_query_attention` and `n_kv_heads` at the same time.")
@property
def image_num_patch(self):
assert self.vision_backbone is not None
return self.vision_backbone.image_num_patch
@property
def image_patch_size(self):
assert self.vision_backbone is not None
return self.vision_backbone.image_patch_size
def llm_patches_per_crop(self):
h, w = self.image_num_patch
# Round up in case we need to pad the image features for pooling
h = (h + self.image_pooling_h - 1) // self.image_pooling_h
w = (w + self.image_pooling_w - 1) // self.image_pooling_w
return h, w
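# Example: with a (24, 24) patch grid and the default 2x2 pooling this returns (12, 12),
# matching the image processor defaults image_token_length_h = image_token_length_w = 12.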
def _expand_token(token, batch_size: int):
return token.view(1, 1, -1).expand(batch_size, -1, -1)
class ViTMLP(nn.Module):
def __init__(self, config: FullMolmoConfig):
super().__init__()
self.config = config
v_cfg = config.vision_backbone
self.w1 = nn.Linear(
v_cfg.image_emb_dim,
v_cfg.image_mlp_dim,
bias=True,
device=config.init_device,
)
# Activation function.
cfg = deepcopy(config)
cfg.activation_type = v_cfg.image_mlp_activations
self.act = Activation.build(cfg)
self.w2 = nn.Linear(
v_cfg.image_mlp_dim,
v_cfg.image_emb_dim,
bias=True,
device=config.init_device,
)
def reset_parameters(self):
v_cfg = self.config.vision_backbone
nn.init.trunc_normal_(self.w1.weight, std=math.sqrt(1 / v_cfg.image_emb_dim), a=-2.0, b=2.0)
nn.init.trunc_normal_(self.w2.weight, std=math.sqrt(1 / v_cfg.image_mlp_dim), a=-2.0, b=2.0)
nn.init.zeros_(self.w1.bias)
nn.init.zeros_(self.w2.bias)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.w1(x)
x = self.act(x)
x = self.w2(x)
return x
class ResidualAttentionBlock(nn.Module):
def __init__(self, config: FullMolmoConfig):
super().__init__()
self.config = config
v_cfg = config.vision_backbone
self.attention = MultiHeadDotProductAttention(config)
self.feed_forward = ViTMLP(config)
self.attention_norm = nn.LayerNorm(
v_cfg.image_emb_dim,
eps=v_cfg.image_norm_eps,
device=config.init_device,
)
self.ffn_norm = nn.LayerNorm(
v_cfg.image_emb_dim,
eps=v_cfg.image_norm_eps,
device=config.init_device,
)
def reset_parameters(self):
self.attention.reset_parameters()
self.feed_forward.reset_parameters()
self.attention_norm.reset_parameters()
self.ffn_norm.reset_parameters()
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = x + self.attention(self.attention_norm(x))
x = x + self.feed_forward(self.ffn_norm(x))
return x
class BlockCollection(nn.Module):
def __init__(self, config: FullMolmoConfig):
super().__init__()
self.config = config
self.grad_checkpointing: bool = False
v_cfg = config.vision_backbone
self.resblocks = nn.ModuleList([ResidualAttentionBlock(config) for _ in range(v_cfg.image_num_layers)])
def reset_parameters(self):
for r in self.resblocks:
r.reset_parameters()
def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
hidden_states = []
for r in self.resblocks:
x = r(x)
hidden_states.append(x)
return hidden_states
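# Note: hidden states from every layer are returned so the vision backbone can later select
# specific ViT layers via `config.vit_layers` (see OLMoPretrainedVisionBackbone.encode_image).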
class LayerNormFp32(nn.LayerNorm):
def forward(self, x: torch.Tensor) -> torch.Tensor:
orig_type = x.dtype
x = F.layer_norm(x.to(torch.float32), self.normalized_shape, self.weight.to(torch.float32), self.bias.to(torch.float32), self.eps)
return x.to(orig_type)
class VisionTransformer(nn.Module):
def __init__(self, config: FullMolmoConfig):
super().__init__()
self.config = config
v_cfg = config.vision_backbone
# class embeddings and positional embeddings
self.scale = v_cfg.image_emb_dim**-0.5
self.class_embedding = nn.Parameter(
torch.zeros(v_cfg.image_emb_dim, device=config.init_device),
)
self.num_prefix_tokens: int = 1
self.positional_embedding = nn.Parameter(
torch.zeros(v_cfg.image_num_pos, v_cfg.image_emb_dim, device=config.init_device),
)
image_patch_size = v_cfg.image_patch_size
self.patch_embedding = nn.Linear(
image_patch_size * image_patch_size * 3,
v_cfg.image_emb_dim,
bias=False,
device=config.init_device,
)
self.pre_ln = LayerNormFp32(
v_cfg.image_emb_dim,
eps=v_cfg.image_norm_eps,
)
self.transformer = BlockCollection(config)
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.transformer.grad_checkpointing = enable
def reset_parameters(self):
nn.init.normal_(self.class_embedding, std=self.scale)
nn.init.normal_(self.positional_embedding, std=self.scale)
nn.init.normal_(self.patch_embedding.weight, std=0.02)
self.pre_ln.reset_parameters()
self.transformer.reset_parameters()
def add_pos_emb(self, x: torch.Tensor, patch_num: int) -> torch.Tensor:
cls_emb = self.positional_embedding[0:1]
pos_emb = self.positional_embedding[1:]
pos_emb = pos_emb.reshape((int(math.sqrt(pos_emb.shape[0])), int(math.sqrt(pos_emb.shape[0])), pos_emb.shape[1]))
(patch_num_0, patch_num_1) = patch_num
if pos_emb.shape[0] != patch_num_0 or pos_emb.shape[1] != patch_num_1:
# Derived from https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
# antialias: default True in jax.image.resize
pos_emb = pos_emb.unsqueeze(0).permute(0, 3, 1, 2)
pos_emb = F.interpolate(
pos_emb,
size=(patch_num_0, patch_num_1),
mode="bicubic",
align_corners=False,
antialias=True,
)
pos_emb = pos_emb.permute(0, 2, 3, 1).squeeze(0)
pos_emb = pos_emb.reshape(-1, pos_emb.shape[-1])
x = x + torch.cat([cls_emb[None, :, :], pos_emb[None, :, :]], dim=1).to(x.dtype)
return x
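# Note: the class-token embedding is left untouched, while the patch-grid positional embeddings
# are bicubically resized whenever the crop's patch grid differs from the pretraining grid.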
def forward(self, x: torch.Tensor, patch_num: Optional[Tuple[int, int]] = None) -> List[torch.Tensor]:
"""
:param x: (batch_size, num_patch, n_pixels)
"""
if patch_num is None:
patch_num = self.config.vision_backbone.image_num_patch
B, N, D = x.shape
x = self.patch_embedding(x)
# class embeddings and positional embeddings
x = torch.cat([_expand_token(self.class_embedding, x.shape[0]).to(x.dtype), x], dim=1)
x = self.add_pos_emb(x, patch_num)
x = self.pre_ln(x)
hidden_states = self.transformer(x)
return hidden_states
class MultiHeadDotProductAttention(nn.Module):
def __init__(self, config: FullMolmoConfig, use_bias: bool = True, is_vit_layer: Optional[bool] = True):
super().__init__()
self.config = config
self.use_bias = use_bias
v_cfg = config.vision_backbone
self.embed_dim = v_cfg.image_emb_dim
self.num_heads = v_cfg.image_num_heads
self.head_dim = v_cfg.image_head_dim
self.num_key_value_heads = v_cfg.image_num_key_value_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.initializer_range = v_cfg.initializer_range
self.is_vit_layer = is_vit_layer
nlayers = 1 if (is_vit_layer or config.vit_layers is None) else len(config.vit_layers)
self.wq = nn.Linear(
nlayers * self.embed_dim,
self.num_heads * self.head_dim,
bias=use_bias,
device=config.init_device,
)
self.wk = nn.Linear(
nlayers * self.embed_dim,
self.num_key_value_heads * self.head_dim,
bias=use_bias,
device=config.init_device,
)
self.wv = nn.Linear(
nlayers * self.embed_dim,
self.num_key_value_heads * self.head_dim,
bias=use_bias,
device=config.init_device,
)
self.wo = nn.Linear(
self.num_heads * self.head_dim,
self.embed_dim,
bias=use_bias,
device=config.init_device,
)
self.attention_dropout: Optional[Dropout] = None
if v_cfg.attention_dropout > 0:
self.attention_dropout = Dropout(v_cfg.attention_dropout, broadcast_dims=(0, 1))
self.residual_dropout = Dropout(v_cfg.residual_dropout)
def reset_parameters(self):
nn.init.normal_(self.wq.weight, std=self.initializer_range)
nn.init.normal_(self.wk.weight, std=self.initializer_range)
nn.init.normal_(self.wv.weight, std=self.initializer_range)
nn.init.normal_(self.wo.weight, std=self.initializer_range)
if self.use_bias:
nn.init.constant_(self.wq.bias, 0)
nn.init.constant_(self.wk.bias, 0)
nn.init.constant_(self.wv.bias, 0)
nn.init.constant_(self.wo.bias, 0)
def _split_heads(self, hidden_states, num_heads) -> torch.Tensor:
return hidden_states.reshape(hidden_states.shape[:2] + (num_heads, self.head_dim))
def _merge_heads(self, hidden_states) -> torch.Tensor:
return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,))
def forward(self, inputs_q: torch.Tensor, inputs_kv: Optional[torch.Tensor] = None) -> torch.Tensor:
if inputs_kv is not None:
inputs_k = inputs_kv
inputs_v = inputs_kv
else:
inputs_k = inputs_q
inputs_v = inputs_q
xq, xk, xv = self.wq(inputs_q), self.wk(inputs_k), self.wv(inputs_v)
xq = self._split_heads(xq, self.num_heads)
xk = self._split_heads(xk, self.num_key_value_heads)
xv = self._split_heads(xv, self.num_key_value_heads)
if self.num_heads != self.num_key_value_heads:
xk = xk.repeat_interleave(self.num_key_value_groups, dim=2, output_size=self.num_heads)
xv = xv.repeat_interleave(self.num_key_value_groups, dim=2, output_size=self.num_heads)
og_dtype = xq.dtype
if self.config.float32_attention:
xq = xq.to(torch.float)
xk = xk.to(torch.float)
if self.config.attention_type == "direct":
attn_weights = torch.einsum("...qhd,...khd->...hqk", xq / math.sqrt(xq.size(-1)), xk)
attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(xq.dtype)
if self.attention_dropout is not None:
attn_weights = self.attention_dropout(attn_weights)
attn_output = torch.einsum("...hqk,...khd->...qhd", attn_weights.to(xv.dtype), xv)
elif self.config.attention_type == "sdpa":
if self.config.float32_attention and not torch.is_autocast_enabled():
xv = xv.to(torch.float32)
attn_output = F.scaled_dot_product_attention(
xq.transpose(1, 2).contiguous(),
xk.transpose(1, 2).contiguous(),
xv.transpose(1, 2).contiguous(),
is_causal=False,
dropout_p=self.config.vision_backbone.attention_dropout,
).transpose(1, 2)
else:
raise NotImplementedError(self.config.attention_type)
attn_output = attn_output.to(og_dtype)
attn_output = self._merge_heads(attn_output)
attn_output = self.wo(attn_output)
attn_output = self.residual_dropout(attn_output)
return attn_output
class MultiHeadAttentionPool(nn.Module):
def __init__(
self,
config: FullMolmoConfig,
factor: int = 1,
use_bias: bool = True,
dropout: bool = True,
output_layer: bool = True,
mean_residual: bool = False,
query: str = "mean",
is_vit_layer: Optional[bool] = True,
):
super().__init__()
self.config = config
self.factor = factor
self.use_bias = use_bias
self.dropout = dropout
self.output_layer = output_layer
self.mean_residual = mean_residual
self.query = query
v_cfg = config.vision_backbone
input_dim = v_cfg.image_emb_dim
self.embed_dim = v_cfg.image_emb_dim * factor
self.num_heads = v_cfg.image_num_heads
self.head_dim = v_cfg.image_head_dim * factor
self.num_key_value_heads = v_cfg.image_num_key_value_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.initializer_range = v_cfg.initializer_range
nlayers = 1 if (is_vit_layer or config.vit_layers is None) else len(config.vit_layers)
if query != "vector":
self.wq = nn.Linear(
nlayers * input_dim,
self.num_heads * self.head_dim,
bias=use_bias,
device=config.init_device,
)
self.wk = nn.Linear(
nlayers * input_dim,
self.num_key_value_heads * self.head_dim,
bias=use_bias,
device=config.init_device,
)
self.wv = nn.Linear(
nlayers * input_dim,
self.num_key_value_heads * self.head_dim,
bias=use_bias,
device=config.init_device,
)
if query == "vector":
self.attention_query = nn.Parameter(
torch.zeros(
1,
self.num_key_value_heads * self.head_dim,
device=config.init_device,
),
)
if output_layer:
self.wo = nn.Linear(
self.num_heads * self.head_dim,
self.embed_dim,
bias=use_bias,
device=config.init_device,
)
self.attention_dropout = Dropout(v_cfg.attention_dropout, broadcast_dims=(0, 1))
if dropout:
self.residual_dropout = Dropout(v_cfg.residual_dropout)
def reset_parameters(self):
if self.query != "vector":
nn.init.normal_(self.wq.weight, std=self.initializer_range)
nn.init.normal_(self.wk.weight, std=self.initializer_range)
nn.init.normal_(self.wv.weight, std=self.initializer_range)
if self.output_layer:
nn.init.normal_(self.wo.weight, std=self.initializer_range)
if self.use_bias:
if self.query != "vector":
nn.init.constant_(self.wq.bias, 0)
nn.init.constant_(self.wk.bias, 0)
nn.init.constant_(self.wv.bias, 0)
if self.output_layer:
nn.init.constant_(self.wo.bias, 0)
if self.query == "vector":
nn.init.normal_(self.attention_query, std=self.initializer_range)
def _split_heads(self, hidden_states, num_heads):
return hidden_states.reshape(hidden_states.shape[:2] + (num_heads, self.head_dim))
def _merge_heads(self, hidden_states):
return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,))
def forward(self, inputs_kv: torch.Tensor) -> torch.Tensor:
xk, xv = self.wk(inputs_kv), self.wv(inputs_kv)
if self.query == "mean":
inputs_q = inputs_kv.mean(dim=1, keepdim=True)
xq = self.wq(inputs_q)
elif self.query == "first":
inputs_q = inputs_kv[:, :1]
xq = self.wq(inputs_q)
elif self.query == "vector":
xq = self.attention_query.expand(inputs_kv.size(0), -1, -1)
elif self.query == "constant":
inputs_q = torch.ones_like(inputs_kv[:, :1]) / math.sqrt(inputs_kv.shape[-1])
xq = self.wq(inputs_q)
else:
raise ValueError(f"Unknown query type: {self.query}")
xq = self._split_heads(xq, self.num_heads)
xk = self._split_heads(xk, self.num_key_value_heads)
xv = self._split_heads(xv, self.num_key_value_heads)
if self.num_heads != self.num_key_value_heads:
xk = xk.repeat_interleave(self.num_key_value_groups, dim=2, output_size=self.num_heads)
xv = xv.repeat_interleave(self.num_key_value_groups, dim=2, output_size=self.num_heads)
xq = xq.to(torch.float)
xk = xk.to(torch.float)
xq = xq / math.sqrt(xq.size(-1))
attn_weights = torch.einsum("...qhd,...khd->...hqk", xq, xk)
attn_weights = F.softmax(attn_weights, dim=-1).to(xq.dtype)
attn_weights = self.attention_dropout(attn_weights).to(xv.dtype)
attn_output = torch.einsum("...hqk,...khd->...qhd", attn_weights, xv)
attn_output = self._merge_heads(attn_output)
if self.output_layer:
attn_output = self.wo(attn_output)
if self.dropout:
attn_output = self.residual_dropout(attn_output)
if self.mean_residual:
attn_output += inputs_kv.mean(dim=1, keepdim=True)
return attn_output
class MLP(nn.Module):
def __init__(self, config: FullMolmoConfig, input_dim: int, dropout: float = 0.0):
super().__init__()
self.config = config
self.hidden_size = config.mlp_hidden_size if config.mlp_hidden_size is not None else config.mlp_ratio * config.d_model
self.initializer_range = config.initializer_range
self.w1 = nn.Linear(
input_dim,
self.hidden_size // 2,
bias=False,
device=config.init_device,
)
self.w2 = nn.Linear(
self.hidden_size // 2,
config.d_model,
bias=False,
device=config.init_device,
)
self.w3 = nn.Linear(
input_dim,
self.hidden_size // 2,
bias=False,
device=config.init_device,
)
# Activation function.
self.act = Activation.build(config)
self.dropout = Dropout(dropout)
def reset_parameters(self):
nn.init.normal_(self.w1.weight, std=self.initializer_range)
nn.init.normal_(self.w2.weight, std=self.initializer_range)
nn.init.normal_(self.w3.weight, std=self.initializer_range)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.w2(self.act(self.w1(x), self.w3(x)))
x = self.dropout(x)
return x
class Residual(nn.Module):
def __init__(self, submodule: nn.Module):
super().__init__()
self.submodule = submodule
def reset_parameters(self):
self.submodule.reset_parameters()
def forward(self, x: torch.Tensor) -> torch.Tensor:
return x + self.submodule(x)
class OLMoVisionBackbone(nn.Module):
def __init__(self, config: FullMolmoConfig):
super().__init__()
self.config = config
self.image_vit = VisionTransformer(config)
input_dim: int = None
self.image_pooling_2d: nn.Module = None
if config.image_pooling_2d in {ImagePooling2DType.attention, ImagePooling2DType.attention_meanq}:
self.image_pooling_2d = MultiHeadDotProductAttention(config, is_vit_layer=False)
input_dim = config.vision_backbone.image_emb_dim
elif config.image_pooling_2d == ImagePooling2DType.attention_2wide:
cfg = deepcopy(config)
cfg.vision_backbone.image_emb_dim *= 2
cfg.vision_backbone.image_head_dim *= 2
self.image_pooling_2d = MultiHeadDotProductAttention(cfg, is_vit_layer=False)
input_dim = cfg.vision_backbone.image_emb_dim
elif config.image_pooling_2d == ImagePooling2DType.attention_v2:
assert config.vit_layers is not None
use_bias = True
dropout = True
output_layer = True
query = "mean"
mean_residual = False
factor = len(config.vit_layers)
self.image_pooling_2d = MultiHeadAttentionPool(
config,
factor=factor,
use_bias=use_bias,
dropout=dropout,
output_layer=output_layer,
mean_residual=mean_residual,
query=query,
is_vit_layer=False,
)
input_dim = config.vision_backbone.image_emb_dim * factor
elif config.image_pooling_2d in [ImagePooling2DType.none, ImagePooling2DType.stack]:
self.image_pooling_2d = None
nlayers = 1 if config.vit_layers is None else len(config.vit_layers)
input_dim = nlayers * config.vision_backbone.image_emb_dim
else:
raise NotImplementedError(f"Unknown image pooling 2D method: {config.image_pooling_2d}")
self.input_dim = input_dim
# `MLP` assumes the activation takes two inputs, so it must be a 'llama' version
if config.activation_type == ActivationType.swiglu:
mlp_config = replace(config, activation_type=ActivationType.llama_swiglu)
elif config.activation_type == ActivationType.gelu:
mlp_config = replace(config, activation_type=ActivationType.llama_geglu)
else:
mlp_config = config
if config.image_projector == ImageProjectType.mlpx2:
self.image_projector = nn.ModuleList([MLP(mlp_config, input_dim), Residual(MLP(config, input_dim))])
elif config.image_projector == ImageProjectType.mlp:
self.image_projector = MLP(mlp_config, input_dim)
elif config.image_projector == ImageProjectType.linear:
self.image_projector = nn.Linear(
input_dim,
config.d_model,
bias=False,
device=config.init_device,
)
else:
raise NotImplementedError(f"Unknown image projector: {config.image_projector}")
self.image_feature_dropout = Dropout(config.image_feature_dropout)
def reset_parameters(self):
if self.image_pooling_2d is not None:
self.image_pooling_2d.reset_parameters()
if self.config.image_projector == "2mlp":
for module in self.image_projector:
module.reset_parameters()
elif self.config.image_projector == "linear":
nn.init.xavier_uniform_(self.image_projector.weight)
else:
self.image_projector.reset_parameters()
def forward(self, images: torch.Tensor, image_masks: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
raise NotImplementedError
class OLMoPretrainedVisionBackbone(OLMoVisionBackbone):
def __init__(self, config: FullMolmoConfig):
super().__init__(config)
v_cfg = self.config.vision_backbone
self.grad_checkpointing = True
self.num_prefix_tokens = self.image_vit.num_prefix_tokens
assert self.num_prefix_tokens in {0, 1}, "Only 0 or 1 prefix tokens are supported"
self.pad_embed = None
if config.image_padding_embed:
image_dim = v_cfg.image_emb_dim * len(self.config.vit_layers)
if config.image_padding_embed in ["pad_embed", "regress"]:
self.pad_embed = nn.Parameter(torch.zeros((image_dim,), device=config.init_device))
elif config.image_padding_embed == "pad_and_partial_pad":
self.pad_embed = nn.Parameter(torch.zeros((2, image_dim), device=config.init_device))
else:
raise ValueError(config.image_padding_embed)
def reset_parameters(self):
super().reset_parameters()
self.image_vit.reset_parameters()
def encode_image(self, images: torch.Tensor) -> torch.Tensor:
"""
:param images: (batch_size, num_crops, num_patch, n_pixels)
"""
cfg = self.config
v_cfg = self.config.vision_backbone
B, T, N, D = images.shape
mask = ~torch.all(images.view(B * T, N, D) == -1, dim=(1, 2), keepdim=True)
# Output all hidden states
# n_layers x (batch_num_crops, (1+)n_tokens, image_emb_dim)
images = images.view(B * T, N, D)
image_features = self.image_vit(images)
if cfg.vit_layers is not None:
features = []
for layer in cfg.vit_layers:
features.append(image_features[layer])
image_features = torch.cat(features, dim=-1)
else:
image_features = image_features[-1]
cls_embed: torch.Tensor = None
if self.num_prefix_tokens > 0:
cls_embed = image_features[:, 0]
image_features = image_features[:, 1:]
image_features = image_features * mask
image_features = image_features.view(B, T, N, -1)
cls_embed = cls_embed.view(B, T, -1) if cls_embed is not None else None
return image_features, cls_embed
def forward(self, images: torch.Tensor, image_masks: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
cfg = self.config
# image_features: (batch_size, num_crops(=num_image), num_patch, n_layers * image_emb_dim)
batch_size, num_image = images.shape[:2]
image_features, cls_embed = self.encode_image(images)
if cfg.image_padding_embed:
assert image_masks is not None
if cfg.image_padding_embed == "pad_embed":
all_pad = (image_masks == 0).to(dtype=torch.float32)
pad_embed = self.pad_embed[None, None, None, :]
image_features = image_features + pad_embed * torch.unsqueeze(all_pad, -1)
elif cfg.image_padding_embed == "regress":
pad_embed = self.pad_embed[None, None, None, :]
image_features = image_features + pad_embed * torch.unsqueeze(torch.maximum(image_masks, torch.zeros_like(image_masks)), -1)
elif cfg.image_padding_embed == "pad_and_partial_pad":
pad_embed = self.pad_embed[:, None, None, None, :]
all_pad = image_masks == 0
partial_pad = torch.logical_and(image_masks < 1, torch.logical_not(all_pad)).to(dtype=image_features.dtype)
all_pad = all_pad.to(dtype=image_features.dtype)
image_features = image_features + pad_embed[0] * torch.unsqueeze(all_pad, -1)
image_features = image_features + pad_embed[1] * torch.unsqueeze(partial_pad, -1)
else:
raise ValueError(cfg.image_padding_embed)
image_features = self.image_feature_dropout(image_features)
if cls_embed is not None:
cls_embed = self.image_feature_dropout(cls_embed)
image_features = image_features.reshape(
(batch_size, num_image) + cfg.image_num_patch + (-1,),
)
if cfg.image_num_patch[0] % cfg.image_pooling_h == 1:
# Pad so we can still pool 2x2 patches
image_features = F.pad(
image_features,
(0, 0, 0, 1, 0, 1, 0, 0, 0, 0),
)
# image pooling
image_features = einops.rearrange(
image_features,
"b n (h dh) (w dw) c -> (b n h w) (dh dw) c",
dh=cfg.image_pooling_h,
dw=cfg.image_pooling_w,
)
if cfg.image_pooling_2d == ImagePooling2DType.attention_meanq:
query = image_features.mean(-2, keepdim=True)
image_features = self.image_pooling_2d(query, image_features)
elif cfg.image_pooling_2d not in {ImagePooling2DType.none, ImagePooling2DType.stack}:
if self.grad_checkpointing:
from torch.utils.checkpoint import checkpoint
image_features = checkpoint(self.image_pooling_2d, image_features[:, :1, :], image_features, use_reentrant=False)
else:
image_features = self.image_pooling_2d(image_features[:, :1, :], image_features)
h, w = cfg.llm_patches_per_crop()
image_features = image_features.reshape(batch_size, num_image, h * w, -1)
# MLP layer to map the feature.
if self.grad_checkpointing:
from torch.utils.checkpoint import checkpoint
image_features = checkpoint(self.image_projector, image_features, use_reentrant=False)
else:
image_features = self.image_projector(image_features)
# image_features: (batch_size, num_image, num_patch, d_model)
# cls_embed: (batch_size, num_image, d_model)
return image_features, cls_embed
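# Rough shape walk-through of the forward above (illustrative only; the concrete
# numbers assume the default 336x336 crops with 14x14 pixel patches, i.e.
# image_num_patch = (24, 24), and 2x2 attention pooling):
#   images              (batch, num_crops, 576, n_pixels)
#   encode_image     -> (batch, num_crops, 576, n_layers * image_emb_dim)
#   einops rearrange -> (batch * num_crops * 12 * 12, 4, C)   # 2x2 patch groups
#   image_pooling_2d -> one pooled token per 2x2 group
#   reshape          -> (batch, num_crops, 144, C)            # h * w = 12 * 12
#   image_projector  -> (batch, num_crops, 144, d_model)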
class ModuleType(str, Enum):
in_module = "in"
out_module = "out"
emb = "emb"
final_out = "final_out"
def init_weights(
config: FullMolmoConfig,
module: Union[nn.Linear, nn.Embedding],
d: Optional[int] = None,
layer_id: Optional[int] = None,
std_factor: float = 1.0,
type_of_module: Optional[ModuleType] = None,
) -> None:
d = d if d is not None else config.d_model
std = config.init_std * std_factor
if config.init_cutoff_factor is not None:
cutoff_value = config.init_cutoff_factor * std
nn.init.trunc_normal_(module.weight, mean=0.0, std=std, a=-cutoff_value, b=cutoff_value)
else:
nn.init.normal_(module.weight, mean=0.0, std=std)
class LlamaSwiGLU(nn.Module):
def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
return F.silu(x1) * x2
@property
def output_multiplier(self) -> float:
return 0.5
class SwiGLU(nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
x, gate = x.chunk(2, dim=-1)
return F.silu(gate) * x
@property
def output_multiplier(self) -> float:
return 0.5
class Activation(nn.Module):
def __init__(self, config: FullMolmoConfig):
super().__init__()
self.config = config
def forward(self, x: torch.Tensor) -> torch.Tensor:
raise NotImplementedError
@property
def output_multiplier(self) -> float:
raise NotImplementedError
@classmethod
def build(cls, config: FullMolmoConfig) -> "Activation":
if config.activation_type == "quick_gelu":
return QuickGELU(config)
elif config.activation_type == "gelu":
return cast(Activation, GELU(approximate="none"))
elif config.activation_type == "gelu_tanh":
return cast(Activation, GELU(approximate="tanh"))
elif config.activation_type == "relu":
return cast(Activation, ReLU(inplace=False))
elif config.activation_type == "silu":
return cast(Activation, SiLU(inplace=False))
# elif config.activation_type == "llama_geglu":
# return LlamaGEGLU(config)
# elif config.activation_type == "llama_geglu_tanh":
# return LlamaGEGLUTanh(config)
elif config.activation_type == "llama_swiglu":
return LlamaSwiGLU()
elif config.activation_type == "swiglu":
return SwiGLU()
else:
raise NotImplementedError(f"Unknown activation: '{config.activation_type}'")
class QuickGELU(Activation):
def forward(self, x: torch.Tensor) -> torch.Tensor:
return x * torch.sigmoid(1.702 * x)
@property
def output_multiplier(self) -> float:
return 1.0
class GELU(nn.GELU):
@property
def output_multiplier(self) -> float:
return 1.0
class ReLU(nn.ReLU):
@property
def output_multiplier(self) -> float:
return 1.0
class SiLU(nn.SiLU):
@property
def output_multiplier(self) -> float:
return 1.0
def causal_attention_bias(seq_len: int, device: torch.device) -> torch.FloatTensor:
att_bias = torch.triu(
torch.ones(seq_len, seq_len, device=device, dtype=torch.float),
diagonal=1,
)
att_bias.masked_fill_(att_bias == 1, torch.finfo(att_bias.dtype).min)
return att_bias.view(1, 1, seq_len, seq_len) # type: ignore
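# Illustrative example: for seq_len = 3 the bias built above is
#   [[  0, min, min],
#    [  0,   0, min],
#    [  0,   0,   0]]
# where `min` is torch.finfo(torch.float).min, returned with shape (1, 1, 3, 3)
# so it can be added directly to the attention scores before the softmax.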
def get_causal_attention_bias(cache: BufferCache, seq_len: int, device: torch.device) -> torch.Tensor:
if (causal_bias := cache.get("causal_attention_bias")) is not None and causal_bias.shape[-1] >= seq_len:
if causal_bias.device != device:
causal_bias = causal_bias.to(device)
cache["causal_attention_bias"] = causal_bias
return causal_bias
with torch.autocast(device.type, enabled=False):
causal_bias = causal_attention_bias(seq_len, device)
cache["causal_attention_bias"] = causal_bias
return causal_bias
class LayerNormBase(nn.Module):
def __init__(
self,
config: MolmoConfig,
*,
size: Optional[int] = None,
elementwise_affine: Optional[bool] = True,
eps: float = 1e-05,
weight_initializer: Optional[Callable] = torch.ones,
bias_initializer: Optional[Callable] = torch.zeros,
):
super().__init__()
self.config = config
self.eps = self.config.layer_norm_eps or eps
self.normalized_shape = (size or config.d_model,)
if elementwise_affine or (elementwise_affine is None and self.config.layer_norm_with_affine):
self.weight = nn.Parameter(weight_initializer(self.normalized_shape, device=config.init_device))
use_bias = self.config.bias_for_layer_norm
if use_bias is None:
use_bias = self.config.include_bias
if use_bias:
self.bias = nn.Parameter(bias_initializer(self.normalized_shape, device=config.init_device))
else:
self.register_parameter("bias", None)
else:
self.register_parameter("bias", None)
self.register_parameter("weight", None)
@classmethod
def build(cls, config: FullMolmoConfig, size: Optional[int] = None, **kwargs):
if config.layer_norm_type == "default":
return LayerNorm(config, size=size, low_precision=False, **kwargs)
elif config.layer_norm_type == "low_precision":
return LayerNorm(config, size=size, low_precision=True, **kwargs)
elif config.layer_norm_type == "rms":
return RMSLayerNorm(config, size=size, **kwargs)
else:
raise NotImplementedError(f"Unknown LayerNorm type: '{config.layer_norm_type}'")
class RMSLayerNorm(LayerNormBase):
"""
RMS layer norm, a simplified :class:`LayerNorm` implementation
"""
def __init__(
self,
config: FullMolmoConfig,
size: Optional[int] = None,
elementwise_affine: Optional[bool] = None,
eps: float = 1e-5,
):
super().__init__(config, size=size, elementwise_affine=elementwise_affine, eps=eps)
def forward(self, x: torch.Tensor) -> torch.Tensor:
with torch.autocast(enabled=False, device_type=x.device.type):
og_dtype = x.dtype
x = x.to(torch.float32)
variance = x.pow(2).mean(-1, keepdim=True)
x = x * torch.rsqrt(variance + self.eps)
x = x.to(og_dtype)
if self.weight is not None:
if self.bias is not None:
return self.weight * x + self.bias
else:
return self.weight * x
else:
return x
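# Equivalent formula for the forward above (computed in float32 with autocast
# disabled, then cast back to the input dtype):
#   y = weight * x / sqrt(mean(x_i ** 2) + eps) + bias
# with `weight` and `bias` applied only when the layer is elementwise affine.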
class LayerNorm(LayerNormBase):
"""
The default :class:`LayerNorm` implementation which can optionally run in low precision.
"""
def __init__(
self,
config: FullMolmoConfig,
size: Optional[int] = None,
low_precision: bool = False,
elementwise_affine: Optional[bool] = None,
eps: float = 1e-05,
):
super().__init__(config, size=size, elementwise_affine=elementwise_affine, eps=eps)
self.low_precision = low_precision
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.low_precision:
module_device = x.device
downcast_x = self._cast_if_autocast_enabled(x)
downcast_weight = self._cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight
downcast_bias = self._cast_if_autocast_enabled(self.bias) if self.bias is not None else self.bias
with torch.autocast(enabled=False, device_type=module_device.type):
return F.layer_norm(downcast_x, self.normalized_shape, weight=downcast_weight, bias=downcast_bias, eps=self.eps)
else:
return F.layer_norm(x, self.normalized_shape, weight=self.weight, bias=self.bias, eps=self.eps)
class Molmo(nn.Module):
def __init__(self, config: FullMolmoConfig, init_params: bool = True):
super().__init__()
self.config = config
self.__cache = BufferCache()
# Validate config.
if self.config.embedding_size is not None and self.config.embedding_size != self.config.vocab_size:
if self.config.embedding_size < self.config.vocab_size:
raise MolmoConfigurationError("embedding size should be at least as big as vocab size")
elif self.config.embedding_size % 128 != 0:
import warnings
warnings.warn("Embedding size is not a multiple of 128! This could hurt throughput performance.", UserWarning)
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_mem_efficient_sdp(
True
) # jakep: I found that setting this to true in torch 2.5.1 greatly increased performance (6sec/it from 22sec/it)
wte = None
if self.config.additional_vocab_size is not None:
wte = Embedding(
config.embedding_size or config.vocab_size,
config.additional_vocab_size,
config.d_model,
device=config.init_device,
initializer_range=config.initializer_range,
new_embed_initializer_range=config.new_embedding_init_range,
)
else:
wte = nn.Embedding(config.embedding_size or config.vocab_size, config.d_model, device=config.init_device)
self.transformer = nn.ModuleDict(
dict(
wte=wte,
emb_drop=Dropout(config.embedding_dropout),
ln_f=LayerNorm.build(config),
)
)
blocks = [MolmoBlock.build(i, config, self.__cache) for i in range(config.n_layers)]
if self.config.block_group_size > 1:
raise NotImplementedError()
else:
self.transformer.update({"blocks": nn.ModuleList(blocks)})
if not self.config.rope:
self.transformer.update({"wpe": nn.Embedding(config.max_sequence_length, config.d_model, device=config.init_device)})
if not config.weight_tying:
self.transformer.update(
{
"ff_out": nn.Linear(
config.d_model,
config.embedding_size or config.vocab_size,
bias=config.include_bias,
device=config.init_device,
)
}
)
self.vision_backbone: Optional[OLMoVisionBackbone] = None
if config.vision_backbone is not None:
self.vision_backbone = OLMoPretrainedVisionBackbone(config)
self.__num_fwd_flops: Optional[int] = None
self.gradient_checkpointing = False
def reset_parameters(self):
if self.vision_backbone is not None:
self.vision_backbone.reset_parameters()
self.reset_non_vision_parameters()
def reset_non_vision_parameters(self):
self.transformer.wte.reset_parameters()
if hasattr(self.transformer.wte, "new_embedding"):
nn.init.normal_(self.transformer.wte.new_embedding, std=self.config.new_embedding_init_range)
if hasattr(self.transformer, "wpe"):
nn.init.normal_(self.transformer.wpe.weight, mean=0.0, std=1.0)
self.transformer.ln_f.reset_parameters() # type: ignore
if hasattr(self.transformer, "ff_out"):
nn.init.normal_(self.transformer.ff_out.weight, mean=0.0, std=0.02)
if self.config.block_group_size == 1:
for block in self.transformer.blocks:
block.reset_parameters()
else:
for block_group in self.transformer.block_groups:
block_group.reset_parameters()
def forward(
self,
input_ids: torch.LongTensor,
input_embeddings: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
attention_bias: Optional[torch.Tensor] = None,
response_mask: Optional[torch.Tensor] = None,
images: Optional[torch.Tensor] = None,
image_masks: Optional[torch.Tensor] = None,
image_input_idx: Optional[torch.Tensor] = None,
subsegment_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
past_key_values: Optional[Sequence[Tuple[torch.Tensor, torch.Tensor]]] = None,
use_cache: bool = False,
last_logits_only: bool = False,
output_hidden_states: Optional[bool] = None,
append_last_valid_logits: Optional[torch.Tensor] = None,
) -> ModelOutput:
"""
:param input_ids: A tensor of shape `(batch_size, seq_len)`.
:param input_embeddings: A tensor of shape `(batch_size, seq_len, d_model)` with input
embeddings. When provided, it is treated as the output of the input embedding layer.
:param attention_mask: A tensor of shape `(batch_size, seq_len)` that indicates
which input IDs are masked. A `1` value in the mask means that
the corresponding input ID should *not* be ignored. A `0` means
that the corresponding input ID is masked.
This has the same meaning as the `attention_mask` in HuggingFace's `transformers`
library.
:param attention_bias: A tensor of shape `(batch_size, 1, seq_len, seq_len)`,
`(1, 1, seq_len, seq_len)`, or `(seq_len, seq_len)`. This is used
to introduce causal or other biases.
If the tensor is a bool or byte tensor, a `True` or `1` at `attention_bias[:, :, i, j]`
indicates that the i-th element in the sequence is allowed to attend to the j-th
element in the sequence.
If the tensor is a float tensor, it will just be added to the attention
scores before the softmax.
The default is causal, which corresponds to a lower-diagonal byte matrix of ones.
:param response_mask: A tensor of shape `(batch_size, seq_len)` that indicates
the response mask. A `1` value in the mask means that the corresponding token
is a response token. A `0` means that the corresponding token is not
a response token.
:param past_key_values: Pre-computed keys and values for each attention block.
Can be used to speed up sequential decoding. The `input_ids` which have
their past given to this model should not be passed as `input_ids` as they have already been computed.
:param use_cache: If `True`, return key and value tensors for each block.
:param last_logits_only: If `True`, only compute the logits for the last token of each sequence.
This can speed up decoding when you only care about the next token.
"""
output_hidden_states = output_hidden_states if output_hidden_states is not None else False
if past_key_values:
assert len(past_key_values) == self.config.n_layers
has_image = images is not None
assert not (has_image and input_embeddings is not None), "Cannot provide both images and input embeddings."
assert not (has_image and past_key_values is not None), "Cached key and values should not be used with images."
batch_size, seq_len = input_ids.size() if input_embeddings is None else input_embeddings.size()[:2]
if past_key_values is None:
past_length = 0
else:
past_length = past_key_values[0][0].size(-2)
if self.config.use_position_ids and attention_mask is None:
attention_mask = input_ids != -1
if subsegment_ids is not None:
assert not use_cache, "Subsegment_ids cannot be used with cache."
subsegment_mask = subsegment_ids.unsqueeze(2) <= subsegment_ids.unsqueeze(1)
attention_mask = subsegment_mask.to(attention_mask.dtype) * attention_mask.unsqueeze(2) * attention_mask.unsqueeze(1)
if position_ids is None:
raise ValueError("Positioned ids must be given if using subsegment_ids")
else:
if self.config.use_position_ids and position_ids is None:
position_ids = torch.clamp(
torch.cumsum(attention_mask.to(torch.int32), dim=-1) - 1,
min=0,
).broadcast_to((batch_size, attention_mask.shape[-1]))
# Get embeddings of input.
# shape: (batch_size, seq_len, d_model)
if input_ids is not None:
input_ids = input_ids * (input_ids != -1).to(input_ids.dtype)
x = self.transformer.wte(input_ids) if input_embeddings is None else input_embeddings # type: ignore
num_image: Optional[int] = None
if images is not None:
# shape: (batch_size, num_image, num_patch, d_model)
# cls_embed: (batch_size, num_image, d_model)
image_features, cls_embed = self.vision_backbone(images, image_masks)
num_image, num_patch = image_features.shape[1:3]
assert image_input_idx.shape == (batch_size, num_image, num_patch)
# Insert the image features into the embedding.
image_features = image_features.view(batch_size, num_image * num_patch, -1)
image_input_idx = image_input_idx.view(batch_size, num_image * num_patch)
valid = image_input_idx >= 0
batch_idx = torch.arange(batch_size, device=x.device)
batch_idx = torch.tile(batch_idx[:, None], [1, image_features.shape[1]])
# For hf demo/endpoint
image_features = image_features.to(x.device)
x[batch_idx[valid], image_input_idx[valid]] += image_features[valid]
if not self.config.rope:
# Get positional embeddings.
# shape: (1, seq_len)
pos = torch.arange(past_length, past_length + seq_len, dtype=torch.long, device=x.device).unsqueeze(0)
# shape: (1, seq_len, d_model)
pos_emb = self.transformer.wpe(pos) # type: ignore
x = pos_emb + x
# Add input + positional embeddings and apply dropout.
# shape: (batch_size, seq_len, d_model)
x = self.transformer.emb_drop(x) # type: ignore
# normalized
if self.config.normalize_input_embeds:
x = x * (self.config.d_model**0.5)
# Transform the attention mask into what the blocks expect.
if attention_mask is not None:
# shape: (batch_size, 1, 1, seq_len)
if len(attention_mask.shape) == 2:
attention_mask = attention_mask[:, : past_length + seq_len]
attention_mask = attention_mask.to(dtype=torch.float).view(batch_size, -1)[:, None, None, :]
else:
attention_mask = attention_mask.unsqueeze(1).to(dtype=torch.float)
attention_mask = (1.0 - attention_mask) * torch.finfo(attention_mask.dtype).min
# Merge attention mask with attention bias.
if (
attention_bias is not None
or attention_mask is not None
# NOTE (epwalsh): we need to initialize the attn bias in order for attn to work properly
# with key+value cache. Otherwise `F.scaled_dot_product_attention()` doesn't seem to compute
# scores correctly.
or past_key_values is not None
):
if attention_bias is None:
attention_bias = get_causal_attention_bias(self.__cache, past_length + seq_len, x.device)
elif attention_bias.dtype in (torch.int8, torch.bool):
attention_bias = attention_bias.to(dtype=torch.float)
attention_bias.masked_fill_(attention_bias == 0.0, torch.finfo(attention_bias.dtype).min)
# Transform to the right shape and data type.
mask_len = seq_len
if attention_mask is not None:
mask_len = attention_mask.shape[-1]
elif past_key_values is not None:
mask_len = past_key_values[0][0].shape[-2] + seq_len
attention_bias = attention_bias[:, :, :mask_len, :mask_len].to(dtype=torch.float)
# Add in the masking bias.
if attention_mask is not None:
attention_bias = attention_bias + attention_mask
# Might get -infs after adding attention mask, since dtype.min + dtype.min = -inf.
# `F.scaled_dot_product_attention()` doesn't handle -inf like you'd expect, instead
# it can produce NaNs.
ensure_finite_(attention_bias, check_neg_inf=True, check_pos_inf=False)
attn_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = [] if use_cache else None
# decoder layers
all_hidden_states = []
# Apply blocks one-by-one.
if self.config.block_group_size == 1:
for block_idx, block in enumerate(self.transformer.blocks):
if output_hidden_states:
# add hidden states
all_hidden_states.append(x)
layer_past = None if past_key_values is None else past_key_values[block_idx]
if self.gradient_checkpointing and self.training:
x, cache = self._gradient_checkpointing_func(
block, x, attention_bias=attention_bias, position_ids=position_ids, layer_past=layer_past, use_cache=use_cache
)
else:
x, cache = block(x, attention_bias=attention_bias, position_ids=position_ids, layer_past=layer_past, use_cache=use_cache)
if attn_key_values is not None:
assert cache is not None
attn_key_values.append(cache)
else:
for group_idx, block_group in enumerate(self.transformer.block_groups):
if output_hidden_states:
# add hidden states
all_hidden_states.append(x)
layers_past = (
None
if past_key_values is None
else past_key_values[group_idx * self.config.block_group_size : (group_idx + 1) * self.config.block_group_size]
)
x, cache = block_group(x, attention_bias=attention_bias, position_ids=position_ids, layers_past=layers_past, use_cache=use_cache)
if attn_key_values is not None:
assert cache is not None
attn_key_values.extend(cache)
if last_logits_only:
# shape: (batch_size, 1, d_model)
if append_last_valid_logits is not None:
last_valid_output = x[torch.arange(x.shape[0], device=x.device), append_last_valid_logits.to(x.device)]
x = last_valid_output.unsqueeze(1)
else:
x = x[:, -1, :].unsqueeze(1)
# Apply final layer norm.
# shape: (batch_size, seq_len or 1, d_model)
x = self.transformer.ln_f(x) # type: ignore
if output_hidden_states:
# add final hidden state post-final-layernorm, following HuggingFace's convention
all_hidden_states.append(x)
# Get logits.
# shape: (batch_size, seq_len or 1, vocab_size)
if self.config.weight_tying:
logits = F.linear(x, self.transformer.wte.weight, None) # type: ignore
else:
logits = self.transformer.ff_out(x) # type: ignore
if self.config.scale_logits:
logits.mul_(1 / math.sqrt(self.config.d_model))
if not last_logits_only and append_last_valid_logits is not None:
last_valid_logit = logits[torch.arange(logits.shape[0], device=logits.device), append_last_valid_logits]
logits = torch.cat([logits[:, :-1], last_valid_logit[:, None]], dim=1)
return ModelOutput(logits=logits, attn_key_values=attn_key_values, hidden_states=tuple(all_hidden_states) if output_hidden_states else None) # type: ignore[arg-type]
class MolmoForCausalLM(PreTrainedModel):
config_class = MolmoConfig
supports_gradient_checkpointing = True
base_model_prefix = "model"
_no_split_modules = ["MolmoBlock"]
def __init__(self, config: MolmoConfig, model: Optional[Molmo] = None, init_params: bool = False):
super().__init__(config)
if not model:
full_config = FullMolmoConfig(
image_padding_embed="pad_and_partial_pad",
image_pooling_2d="attention-meanq",
attention_layer_norm=config.attention_layer_norm,
rope_impl="llama",
vocab_size=config.vocab_size,
max_sequence_length=config.max_position_embeddings,
qkv_bias=config.qkv_bias,
norm_after=config.norm_after,
embedding_size=config.embedding_size,
attention_type="sdpa",
embedding_dropout=0,
attention_dropout=0,
residual_dropout=0,
rope=True,
weight_tying=False,
include_bias=False,
d_model=config.hidden_size,
mlp_hidden_size=config.intermediate_size,
n_layers=config.num_hidden_layers,
additional_vocab_size=128,
n_heads=config.num_attention_heads,
n_kv_heads=config.num_key_value_heads,
rope_theta=config.rope_theta,
layer_norm_eps=config.layer_norm_eps,
layer_norm_type=config.layer_norm_type,
vit_layers=[-2, -9],
vision_backbone=VisionBackboneConfig(
image_default_input_size=(336, 336),
image_patch_size=14,
image_pos_patch_size=14,
image_emb_dim=1024,
image_num_heads=16,
image_num_key_value_heads=16,
image_num_layers=23,
image_head_dim=64,
image_mlp_dim=4096,
image_mlp_activations="quick_gelu",
image_dropout_rate=0.0,
image_num_pos=577,
image_norm_eps=1e-5,
attention_dropout=0.0,
residual_dropout=0.0,
initializer_range=0.02,
),
)
self.model = Molmo(full_config, init_params=init_params)
else:
self.model = model
def forward(
self,
input_ids: torch.LongTensor = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
attention_bias: Optional[torch.Tensor] = None,
response_mask: Optional[torch.Tensor] = None,
images: Optional[torch.Tensor] = None,
image_masks: Optional[torch.Tensor] = None,
image_input_idx: Optional[torch.Tensor] = None,
subsegment_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
labels: Optional[torch.LongTensor] = None,
loss_masks: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = None,
last_logits_only: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
append_last_valid_logits: Optional[torch.Tensor] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[
Cache
] = None, # This is a hack mitigation of an issue in transformers `4.39.x` https://github.com/huggingface/transformers/issues/29426
) -> Union[Tuple, CausalLMOutputWithPast]:
if use_cache is None:
use_cache = self.config.use_cache
if output_attentions:
raise ValueError("output_attentions is not yet supported in Molmo")
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model.forward(
input_ids=input_ids,
input_embeddings=inputs_embeds,
attention_mask=attention_mask,
attention_bias=attention_bias,
response_mask=response_mask,
images=images,
image_masks=image_masks,
image_input_idx=image_input_idx,
subsegment_ids=subsegment_ids,
position_ids=position_ids,
past_key_values=past_key_values,
use_cache=use_cache,
last_logits_only=last_logits_only,
output_hidden_states=output_hidden_states,
append_last_valid_logits=append_last_valid_logits,
)
logits = outputs.logits
hidden_states = outputs.hidden_states
loss = None
if labels is not None:
if loss_masks is not None:
loss_masks = loss_masks * (loss_masks > 0)
batch_size_in_tokens = max(loss_masks.sum().item(), 1)
labels = labels.long()
labels.masked_fill_(~(loss_masks > 0), -100)
labels = labels.view(-1)
logits_for_loss = logits.to(torch.float32).view(-1, logits.size(-1))
loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-100, reduction="none")
loss = loss_fct(logits_for_loss, labels)
loss = loss.view(input_ids.shape[0], -1)
loss = loss * loss_masks
loss = loss.sum() / batch_size_in_tokens
use_zloss = getattr(self.config, "softmax_auxiliary_loss", False)
if use_zloss:
z_squared = logits_for_loss.logsumexp(-1).pow(2)
z_loss = self.config.softmax_auxiliary_loss_scale * z_squared
z_loss = z_loss.view(input_ids.shape[0], -1)
z_loss = z_loss * loss_masks
z_loss = z_loss.sum() / batch_size_in_tokens
loss += z_loss
else:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = torch.nn.CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.embedding_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.attn_key_values,
hidden_states=hidden_states,
)
def can_generate(self) -> bool:
return True
@torch.no_grad()
def generate_from_batch(
self,
batch: Dict[str, Any],
generation_config: Optional[GenerationConfig] = None,
**kwargs,
):
if generation_config is not None:
assert generation_config.use_cache
images = batch.get("images")
image_masks = batch.get("image_masks")
image_input_idx = batch.get("image_input_idx")
# Validate inputs.
input_ids = batch["input_ids"]
batch_size, seq_len = input_ids.shape
attention_mask = batch.get("attention_mask", None)
max_new_tokens = generation_config.max_new_tokens
assert max_new_tokens is not None
mask_len = seq_len + max_new_tokens if self.config.use_position_ids else seq_len
position_ids: Optional[torch.Tensor] = None
append_last_valid_logits: Optional[torch.Tensor] = None
if self.config.use_position_ids and attention_mask is None:
attention_mask = input_ids != -1
position_ids = torch.clamp(torch.cumsum(attention_mask.to(torch.int32), dim=-1) - 1, min=0)
append_last_valid_logits = attention_mask.long().sum(dim=-1) - 1
attention_mask = torch.cat(
[attention_mask, attention_mask.new_ones((batch_size, max_new_tokens))],
dim=1,
)
if attention_mask is not None:
assert attention_mask.shape == (batch_size, mask_len)
out = super().generate(
batch["input_ids"],
generation_config,
attention_mask=attention_mask,
images=images,
image_masks=image_masks,
image_input_idx=image_input_idx,
position_ids=position_ids,
append_last_valid_logits=append_last_valid_logits,
**kwargs,
)
return out
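# Illustrative, commented-out usage sketch (the `processor`, `pil_image`, and
# generation settings are assumptions, mirroring the typical Molmo recipe):
#
#   inputs = processor.process(images=[pil_image], text="Describe this image.")
#   batch = {k: v.to(model.device).unsqueeze(0) for k, v in inputs.items()}
#   output = model.generate_from_batch(
#       batch,
#       GenerationConfig(max_new_tokens=200, stop_strings="<|endoftext|>"),
#       tokenizer=processor.tokenizer,
#   )
#   new_tokens = output[0, batch["input_ids"].shape[1]:]
#   print(processor.tokenizer.decode(new_tokens, skip_special_tokens=True))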
def prepare_inputs_for_generation(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple]] = None, **kwargs):
if past_key_values:
# This is because we want the model to only process the last generated token.
input_ids = input_ids[:, -1:]
if self.config.use_position_ids:
attention_mask = kwargs.get("attention_mask")
images = kwargs.get("images")
image_masks = kwargs.get("image_masks")
image_input_idx = kwargs.get("image_input_idx")
position_ids = kwargs.get("position_ids")
append_last_valid_logits = kwargs.get("append_last_valid_logits")
model_inputs = {
"input_ids": input_ids,
"attention_mask": attention_mask,
"position_ids": position_ids,
"past_key_values": past_key_values,
"use_cache": True,
"last_logits_only": True,
}
if past_key_values is None:
model_inputs["images"] = images
model_inputs["image_masks"] = image_masks
model_inputs["image_input_idx"] = image_input_idx
model_inputs["append_last_valid_logits"] = append_last_valid_logits
else:
model_inputs = {"input_ids": input_ids, "past_key_values": past_key_values}
model_inputs.update(kwargs)
model_inputs["use_cache"] = kwargs.pop("use_cache", self.config.use_cache)
return model_inputs
def _update_model_kwargs_for_generation(
self,
outputs: ModelOutput,
model_kwargs: Dict[str, Any],
is_encoder_decoder: bool = False,
num_new_tokens: int = 1,
) -> Dict[str, Any]:
if self.config.use_position_ids:
model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1
if "append_last_valid_logits" in model_kwargs:
del model_kwargs["append_last_valid_logits"]
if "images" in model_kwargs:
del model_kwargs["images"]
del model_kwargs["image_masks"]
del model_kwargs["image_input_idx"]
cache_name, cache = super()._extract_past_from_model_output(outputs)
model_kwargs[cache_name] = cache
model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + num_new_tokens
return model_kwargs
def get_input_embeddings(self) -> torch.nn.Module:
return self.model.transformer.wte
def set_input_embeddings(self, value: torch.nn.Module):
self.model.transformer.wte = value
def get_output_embeddings(self):
if self.config.weight_tying:
return self.model.transformer.wte
else:
return self.model.transformer.ff_out
def set_output_embeddings(self, value: torch.nn.Module):
if self.config.weight_tying:
self.model.transformer.wte = value
else:
self.model.transformer.ff_out = value
def tie_weights(self):
"""
This function is intentionally left as a no-op.
Weight tying is handled as follows:
- When the model is initialized, the `ff_out` layer is conditionally defined based on the `weight_tying` configuration.
See: `if not config.weight_tying: self.transformer.update(...)` in `olmo/model.py`.
- When computing logits, the `wte` weights are used directly if `weight_tying` is enabled.
See: `if self.config.weight_tying: logits = F.linear(x, self.transformer.wte.weight, None)` in the `forward` method.
Therefore, there is no need to explicitly tie the weights in this function.
"""
pass
def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of: Optional[int] = None) -> torch.nn.Embedding:
"""
Resizes input token embeddings matrix of the model if `new_num_tokens != config.embedding_size`.
Takes care of tying weights embeddings afterwards if the model class has a `tie_weights()` method.
Arguments:
new_num_tokens (`int`, *optional*):
The new number of tokens in the embedding matrix. Increasing the size will add newly initialized
vectors at the end. Reducing the size will remove vectors from the end. If not provided or `None`, just
returns a pointer to the input tokens `torch.nn.Embedding` module of the model without doing anything.
pad_to_multiple_of (`int`, *optional*):
If set will pad the embedding matrix to a multiple of the provided value. If `new_num_tokens` is set to
`None` will just pad the embedding to a multiple of `pad_to_multiple_of`.
This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
`>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128. For more
details about this, or help on choosing the correct value for resizing, refer to this guide:
https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc
Return:
`torch.nn.Embedding`: Pointer to the input tokens Embeddings Module of the model.
Note:
This method differs from the base class implementation by resizing the `embedding_size` attribute of the
model configuration instead of the `vocab_size`. It also includes a warning if the resized `embedding_size`
is less than the `vocab_size`. In OLMo, `embedding_size` refers to the number of rows in the token embedding
matrix (which may be padded beyond the vocabulary for throughput), while `vocab_size` refers to the number of
unique tokens in the vocabulary.
"""
model_embeds = self._resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
if new_num_tokens is None and pad_to_multiple_of is None:
return model_embeds
# Update base model and current model config
self.config.embedding_size = model_embeds.weight.shape[0]
self.model.config.embedding_size = model_embeds.weight.shape[0]
# Check if the embedding size is less than the vocab size
if self.config.embedding_size < self.config.vocab_size:
warning_message = (
f"Resizing token embeddings to size {self.config.embedding_size}, which is less than the vocab size "
f"{self.config.vocab_size} defined in the model configuration. Make sure your tokenizer's vocabulary "
"size is less than or equal to the new token embedding size."
)
log.warning(warning_message)
# Tie weights again if needed
self.tie_weights()
return model_embeds
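# Illustrative, commented-out example (the numbers are made up): resizing updates
# `config.embedding_size` rather than `config.vocab_size`.
#
#   new_size = model.config.embedding_size + 8
#   model.resize_token_embeddings(new_size)
#   assert model.config.embedding_size == new_size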
# Always register for multi-modal features
AutoModelForCausalLM.register(MolmoConfig, MolmoForCausalLM)
"""
Processor class for Molmo.
"""
from typing import Optional
from PIL import ImageOps
from PIL.Image import Image
try:
from typing import Unpack
except ImportError:
from typing_extensions import Unpack
import numpy as np
import torch
from transformers import AutoTokenizer
from transformers.image_utils import ImageInput
from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, TextKwargs
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
from transformers.utils import logging
from .image_preprocessing_molmo import MolmoImageProcessor, MolmoImagesKwargs
logger = logging.get_logger(__name__)
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"
DEFAULT_IM_COL_TOKEN = "<im_col>"
IMAGE_PROMPT = "<|image|>"
EXTRA_TOKENS = (DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_COL_TOKEN, IMAGE_PROMPT)
def get_special_token_ids(tokenizer):
ids = tokenizer.encode("".join(EXTRA_TOKENS), add_special_tokens=False)
assert len(ids) == len(EXTRA_TOKENS)
return {k: i for k, i in zip(EXTRA_TOKENS, ids)}
class MolmoTextKwargs(TextKwargs, total=False):
style: Optional[str]
system_prompt: Optional[str]
message_format: Optional[str]
always_start_with_space: Optional[bool]
sequence_length: Optional[int]
class MolmoProcessorKwargs(ProcessingKwargs, total=False):
text_kwargs: MolmoTextKwargs
images_kwargs: MolmoImagesKwargs
_defaults = {
"images_kwargs": {
"max_crops": 12,
"overlap_margins": [4, 4],
"base_image_input_size": [336, 336],
"image_token_length_w": 12,
"image_token_length_h": 12,
"image_patch_size": 14,
"image_padding_mask": True,
},
"text_kwargs": {
"style": "long_caption",
"system_prompt": "none",
"message_format": "role",
"always_start_with_space": True,
"sequence_length": 1536,
"padding": False,
},
}
class MolmoProcessor(ProcessorMixin):
attributes = ["image_processor", "tokenizer"]
image_processor_class = "AutoImageProcessor"
tokenizer_class = ("GPT2Tokenizer", "GPT2TokenizerFast")
def __init__(self, image_processor: MolmoImageProcessor = None, tokenizer: AutoTokenizer = None, **kwargs):
# self.image_processor = image_processor
# self.tokenizer = tokenizer
super().__init__(image_processor, tokenizer)
self._special_tokens = None
@property
def special_token_ids(self):
if self._special_tokens is None:
self._special_tokens = get_special_token_ids(self.tokenizer)
return self._special_tokens
def get_tokens_input(self, prompt, message_format, always_start_with_space):
if message_format == "none" or message_format is None:
pass
elif message_format == "role":
prompt = "User: " + prompt + " Assistant:"
else:
raise NotImplementedError(f"Message format {message_format} not implemented")
if always_start_with_space:
prompt = " " + prompt
tokens = self.tokenizer.encode(prompt, add_special_tokens=False)
return tokens
def process(
self,
text: TextInput = None,
images: ImageInput = None,
*,
tokens: Optional[PreTokenizedInput] = None,
**kwargs: Unpack[MolmoProcessorKwargs],
):
output_kwargs = self._merge_kwargs(
MolmoProcessorKwargs,
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
**kwargs,
)
if tokens is None:
tokens = self.get_tokens_input(
text,
output_kwargs["text_kwargs"]["message_format"],
output_kwargs["text_kwargs"]["always_start_with_space"],
)
image_token_id = self.special_token_ids[IMAGE_PROMPT]
if images is not None:
if not isinstance(images, (list, tuple)):
images = [images]
image_arrays = []
for image in images:
if isinstance(image, Image):
image = image.convert("RGB")
# Handle images with EXIF orientation tags, which PIL will ignore by default
# https://github.com/python-pillow/Pillow/issues/4703
img = ImageOps.exif_transpose(image)
image_arrays.append(np.array(img))
else:
assert len(image.shape) == 3 and image.shape[-1] == 3
image_arrays.append(image.astype(np.uint8))
images = image_arrays
# For now only support inserting images at the start
image_idx = [-1] * len(images)
else:
image_idx = None
sequence_length = output_kwargs["text_kwargs"]["sequence_length"]
image_patch_token_id = self.special_token_ids[DEFAULT_IMAGE_PATCH_TOKEN]
image_col_token_id = self.special_token_ids[DEFAULT_IM_COL_TOKEN]
image_start_token_id = self.special_token_ids[DEFAULT_IM_START_TOKEN]
image_end_token_id = self.special_token_ids[DEFAULT_IM_END_TOKEN]
out = self.image_processor.multimodal_preprocess(
images=images,
image_idx=image_idx,
tokens=np.asarray(tokens).astype(np.int32),
sequence_length=sequence_length,
image_patch_token_id=image_patch_token_id,
image_col_token_id=image_col_token_id,
image_start_token_id=image_start_token_id,
image_end_token_id=image_end_token_id,
**output_kwargs["images_kwargs"],
)
# Prepend BOS
# qwen2 and olmo do not have a BOS, and instead use EOS as a generic separator token.
bos = self.tokenizer.bos_token_id or self.tokenizer.eos_token_id
decoder_input_tokens = np.pad(out["input_ids"], [[1, 0]], constant_values=bos)
out["input_ids"] = decoder_input_tokens
if "image_input_idx" in out:
# Shift patch mapping up by one since we added BOS
image_input_idx = out["image_input_idx"]
out["image_input_idx"] = np.where(image_input_idx < 0, image_input_idx, image_input_idx + 1)
for k, v in out.items():
out[k] = torch.from_numpy(v)
return out
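# Rough sketch of what `process` typically returns for a single image + prompt
# (key names follow their usage in modeling_molmo.py; shapes are indicative only
# and depend on the crop settings and kwargs):
#   input_ids:       1-D int tensor, BOS + prompt tokens + image special tokens
#   images:          (num_crops, patches_per_crop, patch_pixels) preprocessed crops
#   image_input_idx: per-crop indices into input_ids for the image patch tokens,
#                    already shifted by +1 for the prepended BOS, -1 where unused
#   image_masks:     per-patch padding information consumed by the image padding embedding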
MolmoProcessor.register_for_auto_class()
import logging
import os
from logging import Logger
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Optional
import torch
import torch.distributed
import wandb
from datasets.utils import disable_progress_bars
from datasets.utils.logging import set_verbosity
from peft import LoraConfig, get_peft_model # pyright: ignore
from transformers import (
AutoProcessor,
Qwen2VLForConditionalGeneration,
Trainer,
TrainerCallback,
TrainingArguments,
)
from transformers.integrations import WandbCallback
from transformers.trainer_callback import TrainerControl, TrainerState
from transformers.trainer_utils import get_last_checkpoint
from olmocr.train.core.cli import make_cli, save_config, to_native_types
from olmocr.train.core.config import TrainConfig
from olmocr.train.core.loggers import get_logger
from olmocr.train.core.paths import copy_dir, join_path
from olmocr.train.core.state import BeakerState
from .utils import (
RunName,
TruncatingCollator,
get_local_dir,
log_trainable_parameters,
make_dataset,
setup_environment,
)
class CheckpointUploadCallback(TrainerCallback):
def __init__(self, save_path: str, logger: Optional[Logger] = None):
self.save_path = save_path
self.logger = logger or get_logger(self.__class__.__name__)
def on_save(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
if state.is_local_process_zero:
latest_checkpoint = get_last_checkpoint(args.output_dir)
if not latest_checkpoint:
return
dir_name = Path(latest_checkpoint).name
copy_dir(str(latest_checkpoint), f"{self.save_path}/{dir_name}")
self.logger.info("Saved checkpoint to %s", f"{self.save_path}/{dir_name}")
def update_wandb_config(config: TrainConfig, trainer: Trainer, model: torch.nn.Module):
# finding wandb callback
callbacks = [c for c in trainer.callback_handler.callbacks if isinstance(c, WandbCallback)] # pyright: ignore
if not callbacks:
raise ValueError("WandbCallback not found in trainer callbacks")
wandb_callback = callbacks[0]
peft_config = to_native_types(getattr(model, "peft_config", {}))
script_config = to_native_types(config)
beaker_envs = {k: v for k, v in os.environ.items() if k.lower().startswith("beaker")}
on_setup_fn = wandb_callback.setup
def setup_and_update(args, state, model, **kwargs):
on_setup_fn(args=args, state=state, model=model, **kwargs)
wandb.config.update({"peft": peft_config}, allow_val_change=True)
wandb.config.update({"script": script_config}, allow_val_change=True)
wandb.config.update({"beaker": beaker_envs}, allow_val_change=True)
if (run := wandb.run) and (beaker_url := BeakerState().url):
run.notes = beaker_url
wandb_callback.setup = setup_and_update
def get_rank() -> int:
if torch.distributed.is_available() and torch.distributed.is_initialized():
return torch.distributed.get_rank()
return 0
def run_train(config: TrainConfig):
if get_rank() == 0:
logger_level = logging.INFO
else:
logger_level = logging.WARN
disable_progress_bars()
logger = get_logger(__name__, level=logger_level)
set_verbosity(logger_level)
run_name = RunName.get(config)
setup_environment(aws_config=config.aws, wandb_config=config.wandb, WANDB_RUN_GROUP=run_name.group)
processor = AutoProcessor.from_pretrained(config.model.name_or_path, trust_remote_code=True)
train_dataset, valid_dataset = make_dataset(config, processor)
logger.info(train_dataset)
logger.info(valid_dataset)
if "qwen" in config.model.name_or_path.lower():
model = Qwen2VLForConditionalGeneration.from_pretrained(
config.model.name_or_path, torch_dtype=torch.bfloat16, _attn_implementation="flash_attention_2" if config.model.use_flash_attn else None
)
else:
from .molmo.config_molmo import MolmoConfig
from .molmo.modeling_molmo import MolmoForCausalLM
model_config = MolmoConfig.from_pretrained(config.model.name_or_path, trust_remote_code=True)
if model_config.max_position_embeddings < config.generate.max_length:
logger.warning(
f"ALERT, force adjusting model config max_position_embeddings upwards from {model_config.max_position_embeddings} to {config.generate.max_length}"
)
model_config.max_position_embeddings = config.generate.max_length
if config.model.use_flash_attn:
model_config.attention_type = "flash"
model = MolmoForCausalLM.from_pretrained(config.model.name_or_path, torch_dtype=torch.bfloat16, config=model_config, trust_remote_code=True)
logger.info(model)
if config.lora is not None:
peft_config = LoraConfig(
r=config.lora.rank,
lora_alpha=config.lora.alpha,
lora_dropout=config.lora.dropout,
bias=config.lora.bias, # pyright: ignore
task_type=config.lora.task_type,
target_modules=list(config.lora.target_modules),
)
model = get_peft_model(model=model, peft_config=peft_config)
log_trainable_parameters(model=model, logger=logger)
save_path = join_path("", config.save.path, run_name.run)
# Make sure directory exists if local
if not save_path.startswith("s3://"):
os.makedirs(os.path.dirname(save_path), exist_ok=True)
save_config(config, join_path("", save_path, "config.yaml")) # pyright: ignore
with TemporaryDirectory() as output_dir:
training_args = TrainingArguments(
run_name=run_name.run,
logging_steps=config.hparams.log_every_steps,
output_dir=output_dir,
eval_strategy="steps",
report_to="wandb",
# report_to=[], # disable logging to wandb, we will use a custom callback
optim=config.hparams.optim,
eval_steps=config.hparams.eval_every_steps,
learning_rate=config.hparams.learning_rate,
per_device_train_batch_size=config.hparams.batch_size,
per_device_eval_batch_size=config.hparams.eval_batch_size or config.hparams.batch_size,
gradient_checkpointing=config.hparams.gradient_checkpointing,
gradient_checkpointing_kwargs=(
dict(use_reentrant=False) # from this issue: https://github.com/huggingface/peft/issues/1142
if config.hparams.gradient_checkpointing and config.lora is not None
else {}
),
gradient_accumulation_steps=config.hparams.gradient_accumulation_steps,
max_steps=config.hparams.max_steps,
weight_decay=config.hparams.weight_decay,
dataloader_num_workers=config.max_workers,
load_best_model_at_end=True,
save_strategy="steps",
ddp_find_unused_parameters=config.hparams.find_unused_parameters,
save_steps=config.save.save_every_steps,
warmup_steps=config.hparams.warmup_steps,
warmup_ratio=config.hparams.warmup_ratio,
bf16=True,
label_names=["labels"], # fix from https://github.com/huggingface/transformers/issues/22885
max_grad_norm=config.hparams.clip_grad_norm,
remove_unused_columns=False,
eval_on_start=True,
metric_for_best_model=config.valid_data.metric_for_best_model,
)
data_collator = TruncatingCollator(max_length=config.generate.max_length)
checkpoint_callback = CheckpointUploadCallback(save_path=save_path, logger=logger)
# Initialize Trainer
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=valid_dataset,
tokenizer=processor.tokenizer,
data_collator=data_collator,
callbacks=[checkpoint_callback],
)
# Train the model
trainer.train() # pyright: ignore
if get_rank() == 0:
with get_local_dir(join_path("", save_path, "best")) as best_dir:
if config.lora is not None:
logger.info("Merging LoRA adapters into the base model...")
model = model.merge_and_unload()
logger.info("LoRA adapters merged successfully.")
model.save_pretrained(best_dir)
logger.info("Saved best model to %s", best_dir)
# Uncomment to test speed of data loader
# train_dataloader = DataLoader(formatted_dataset["train"], batch_size=1, num_workers=4, shuffle=False)
# for entry in tqdm(train_dataloader):
# print("Step!")
# model.forward(**{k: v.to("cuda:0") for (k,v) in entry.items()})
def main():
train_config = make_cli(TrainConfig) # pyright: ignore
run_train(train_config)
if __name__ == "__main__":
main()
import json
import multiprocessing
import os
import random
from contextlib import contextmanager
from dataclasses import dataclass
from datetime import datetime
from functools import partial
from hashlib import sha1
from logging import Logger
from tempfile import TemporaryDirectory
from typing import Dict, Generator, List, Optional, TypeVar
import torch
from accelerate import Accelerator
from accelerate.utils import PrecisionType
from datasets import Dataset, DatasetDict, concatenate_datasets
from transformers import AutoProcessor
from olmocr.train.dataloader import build_finetuning_dataset
from olmocr.train.dataprep import (
batch_prepare_data_for_molmo_training,
batch_prepare_data_for_qwen2_training,
)
from .core.cli import to_native_types
from .core.config import AwsConfig, DataConfig, SourceConfig, TrainConfig, WandbConfig
from .core.loggers import get_logger
from .core.paths import copy_dir, is_local
from .core.state import BeakerState
T = TypeVar("T")
def accelerator_to_dtype(accelerator: Accelerator) -> torch.dtype:
pt = PrecisionType(accelerator.mixed_precision)
if pt == PrecisionType.FP16:
return torch.float16
elif pt == PrecisionType.BF16:
return torch.bfloat16
elif pt == PrecisionType.FP8:
return torch.float8_e4m3fn
return torch.float32
def get_rawdataset_from_source(data_config: DataConfig, source: SourceConfig) -> Dataset:
return build_finetuning_dataset(source.response_glob_path, pdf_cache_location=data_config.cache_location)
def make_dataset(config: TrainConfig, processor: AutoProcessor) -> tuple[Dataset, Dataset]:
random.seed(config.train_data.seed)
if "qwen" in config.model.name_or_path.lower():
batch_fn = batch_prepare_data_for_qwen2_training
elif "molmo" in config.model.name_or_path.lower():
batch_fn = batch_prepare_data_for_molmo_training
else:
raise NotImplementedError("Model format not supported")
# Retrieve the two target lengths from the first source for comparison
first_source = config.train_data.sources[0]
target_longest_image_dim = first_source.target_longest_image_dim
target_anchor_text_len = first_source.target_anchor_text_len
# Verify that all sources have the same target lengths
for source in config.train_data.sources:
if source.target_longest_image_dim != target_longest_image_dim:
raise ValueError(f"Inconsistent target_longest_image_dim found in source {source}")
if source.target_anchor_text_len != target_anchor_text_len:
raise ValueError(f"Inconsistent target_anchor_text_len found in source {source}")
# Concatenate the datasets first; the datasets library does not allow applying the transform before concatenation
train_dataset = concatenate_datasets([get_rawdataset_from_source(config.train_data, source) for source in config.train_data.sources])
# Apply the transform to the concatenated dataset
train_dataset = train_dataset.with_transform(
partial(
batch_fn,
processor=processor,
target_longest_image_dim=list(target_longest_image_dim),
target_anchor_text_len=list(target_anchor_text_len),
)
)
# Validation sets get put into a datasetdict so each can report a loss separately
valid_dataset = DatasetDict(
**{
source.name: get_rawdataset_from_source(config.valid_data, source).with_transform(
partial(
batch_fn,
processor=processor,
target_longest_image_dim=list(source.target_longest_image_dim),
target_anchor_text_len=list(source.target_anchor_text_len),
)
)
for source in config.valid_data.sources
}
)
return train_dataset, valid_dataset
def setup_environment(aws_config: Optional[AwsConfig] = None, wandb_config: Optional[WandbConfig] = None, **kwargs: str):
multiprocessing.set_start_method("spawn", force=True)
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "false"
if wandb_config:
os.environ["WANDB_WATCH"] = "false"
for key, value in to_native_types(wandb_config or {}).items():
if value is not None:
os.environ[f"WANDB_{key.upper()}"] = str(value)
for key, value in to_native_types(aws_config or {}).items():
if value is not None:
os.environ[f"AWS_{key.upper()}"] = str(value)
os.environ.update(kwargs)
@dataclass
class RunName:
run: str
group: str
@classmethod
def get(cls, config: TrainConfig, accelerator: Optional[Accelerator] = None) -> "RunName":
job_rank = f"-{accelerator.process_index}" if accelerator else ""
if beaker_job_id := BeakerState().job_id:
job_id = f"-{beaker_job_id}"
else:
job_id = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
(config_hash := sha1()).update(json.dumps(to_native_types(config)).encode())
model_name = config.model.name_or_path.replace("/", "_")
group_name = f"{model_name}-{config_hash.hexdigest()[:6]}"
run_name = f"{group_name}{job_id}{job_rank}"
return cls(group=group_name, run=run_name)
@contextmanager
def override_torch_threads(n: int):
torch_num_threads = torch.get_num_threads()
torch.set_num_threads(n)
yield
torch.set_num_threads(torch_num_threads)
@contextmanager
def temp_args(obj: T, **kwargs) -> Generator[T, None, None]:
orig = {k: getattr(obj, k) for k in kwargs.keys()}
for k, v in kwargs.items():
setattr(obj, k, v)
yield obj
for k, v in orig.items():
setattr(obj, k, v)
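# Illustrative usage of the two context managers above (commented out; `args` is a
# hypothetical object with an `eval_steps` attribute):
#
#   with override_torch_threads(1):
#       ...  # runs single-threaded, previous thread count restored on exit
#
#   with temp_args(args, eval_steps=10):
#       ...  # args.eval_steps == 10 inside the block, restored afterwards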
def log_trainable_parameters(model: torch.nn.Module, logger: Optional[Logger] = None):
"""
Logs the names and total number of trainable parameters in the model.
"""
trainable_params = 0
all_param = 0
for name, param in model.named_parameters():
all_param += param.numel()
if param.requires_grad:
(logger or get_logger(__name__)).info(f"training with {name}")
trainable_params += param.numel()
(logger or get_logger(__name__)).info(
"trainable params: %s || all params: %s || trainable%%: %s",
f"{trainable_params:,}",
f"{all_param:,}",
f"{trainable_params / all_param:.2%}",
)
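def _log_trainable_parameters_example():
    # Illustrative sketch (not part of the original module): freeze one layer of a toy model and
    # report the trainable fraction, as a PEFT-style setup would.
    import torch

    model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 2))
    for param in model[0].parameters():
        param.requires_grad = False
    log_trainable_parameters(model)  # logs each trainable tensor plus the overall percentage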
class TruncatingCollator:
def __init__(self, max_length: int):
self.max_length = max_length
def __call__(self, batch: List[Dict]) -> Dict:
# Assert that we are only handling batch size 1 for now
assert len(batch) == 1, "Only batch size 1 is supported for now"
if "pixel_values" in batch[0]:
# Qwen2 case
truncated_input_ids = torch.tensor(batch[0]["input_ids"][: self.max_length]).unsqueeze(0)
truncated_attention_mask = torch.tensor(batch[0]["attention_mask"][: self.max_length]).unsqueeze(0)
truncated_labels = torch.tensor(batch[0]["labels"][: self.max_length]).unsqueeze(0)
return {
"input_ids": truncated_input_ids,
"attention_mask": truncated_attention_mask,
"labels": truncated_labels,
"pixel_values": torch.tensor(batch[0]["pixel_values"]).unsqueeze(0),
"image_grid_thw": torch.tensor(batch[0]["image_grid_thw"]).unsqueeze(0),
}
elif "image_input_idx" in batch[0]:
# molmo case
truncated_input_ids = batch[0]["input_ids"][: self.max_length].unsqueeze(0)
truncated_attention_mask = batch[0]["attention_mask"][: self.max_length].unsqueeze(0)
truncated_labels = batch[0]["labels"][: self.max_length].unsqueeze(0)
return {
"input_ids": truncated_input_ids,
"attention_mask": truncated_attention_mask,
"labels": truncated_labels,
"images": batch[0]["images"].unsqueeze(0),
"image_input_idx": batch[0]["image_input_idx"].unsqueeze(0),
"image_masks": batch[0]["image_masks"].unsqueeze(0),
}
else:
raise NotImplementedError()
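def _truncating_collator_example():
    # Illustrative sketch (not part of the original module): a molmo-style item with placeholder
    # tensor shapes, truncated to max_length=4 and batched to batch size 1.
    item = {
        "input_ids": torch.arange(10),
        "attention_mask": torch.ones(10, dtype=torch.long),
        "labels": torch.arange(10),
        "images": torch.zeros(1, 3, 336, 336),
        "image_input_idx": torch.zeros(1, 144, dtype=torch.long),
        "image_masks": torch.ones(1, 144),
    }
    batch = TruncatingCollator(max_length=4)([item])
    print({k: tuple(v.shape) for k, v in batch.items()})  # text fields come out as (1, 4)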
@contextmanager
def get_local_dir(output_dir: str):
with TemporaryDirectory() as tmp_dir:
if is_local(output_dir):
yield output_dir
else:
yield tmp_dir
copy_dir(tmp_dir, output_dir)
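def _get_local_dir_example(output_dir: str = "/tmp/example_output"):
    # Illustrative sketch (not part of the original module): files written under the yielded
    # directory end up in output_dir; for a remote output_dir (e.g. an s3:// path) they are
    # written to a temporary directory and copied over on exit. The path above is a placeholder.
    os.makedirs(output_dir, exist_ok=True)
    with get_local_dir(output_dir) as local_dir:
        with open(os.path.join(local_dir, "note.txt"), "w") as f:
            f.write("hello")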
_MAJOR = "0"
_MINOR = "1"
# On main and in a nightly release the patch should be one ahead of the last
# released build.
_PATCH = "59"
# This is mainly for nightly builds which have the suffix ".dev$DATE". See
# https://semver.org/#is-v123-a-semantic-version for the semantics.
_SUFFIX = ""
VERSION_SHORT = "{0}.{1}".format(_MAJOR, _MINOR)
VERSION = "{0}.{1}.{2}{3}".format(_MAJOR, _MINOR, _PATCH, _SUFFIX)
import argparse
import glob
import html
import json
import os
import tempfile
from concurrent.futures import ThreadPoolExecutor, as_completed
import boto3
import markdown2
import smart_open
from botocore.exceptions import NoCredentialsError, PartialCredentialsError
from jinja2 import Template
from tqdm import tqdm
from olmocr.data.renderpdf import render_pdf_to_base64webp
from olmocr.s3_utils import get_s3_bytes, parse_s3_path
def read_jsonl(paths):
"""
Generator that yields lines from multiple JSONL files.
Supports both local and S3 paths.
"""
for path in paths:
try:
with smart_open.smart_open(path, "r", encoding="utf-8") as f:
for line in f:
yield line.strip()
except Exception as e:
print(f"Error reading {path}: {e}")
def generate_presigned_url(s3_client, bucket_name, key_name):
try:
response = s3_client.generate_presigned_url(
"get_object", Params={"Bucket": bucket_name, "Key": key_name}, ExpiresIn=3600 * 24 * 7 - 100 # Link expires in 1 week
)
return response
except (NoCredentialsError, PartialCredentialsError) as e:
print(f"Error generating presigned URL: {e}")
return None
def process_document(data, s3_client, template, output_dir):
id_ = data.get("id")
text = data.get("text", "")
attributes = data.get("attributes", {})
pdf_page_numbers = attributes.get("pdf_page_numbers", [])
metadata = data.get("metadata", {})
source_file = metadata.get("Source-File")
# Generate base64 image of the corresponding PDF page
local_pdf = tempfile.NamedTemporaryFile("wb+", suffix=".pdf", delete=False)
try:
pdf_bytes = get_s3_bytes(s3_client, source_file)
if pdf_bytes is None:
print(f"Failed to retrieve PDF from {source_file}")
return
local_pdf.write(pdf_bytes)
local_pdf.flush()
pages = []
for span in pdf_page_numbers:
start_index, end_index, page_num = span
page_text = text[start_index:end_index]
# Detect and convert Markdown to HTML
page_text = html.escape(page_text, quote=True).replace("&lt;br&gt;", "<br>")
page_text = markdown2.markdown(page_text, extras=["tables"])
base64_image = render_pdf_to_base64webp(local_pdf.name, page_num)
pages.append({"page_num": page_num, "text": page_text, "image": base64_image})
except Exception as e:
print(f"Error processing document ID {id_}: {e}")
return
finally:
local_pdf.close()
os.unlink(local_pdf.name)
# Generate pre-signed URL if source_file is an S3 path
s3_link = None
if source_file and source_file.startswith("s3://"):
bucket_name, key_name = parse_s3_path(source_file)
s3_link = generate_presigned_url(s3_client, bucket_name, key_name)
# Render the HTML using the Jinja template
try:
html_content = template.render(id=id_, pages=pages, s3_link=s3_link)
except Exception as e:
print(f"Error rendering HTML for document ID {id_}: {e}")
return
# Write the HTML content to a file
try:
safe_source = source_file.replace("s3://", "").replace("/", "_").replace(".", "_") if source_file else f"id_{id_}"
filename = f"{safe_source}.html"
filepath = os.path.join(output_dir, filename)
with open(filepath, "w", encoding="utf-8") as f:
f.write(html_content)
except Exception as e:
print(f"Error writing HTML file for document ID {id_}: {e}")
def main(jsonl_paths, output_dir, template_path, s3_profile_name):
if not os.path.exists(output_dir):
os.makedirs(output_dir)
# Expand glob patterns for local paths
expanded_paths = []
for path in jsonl_paths:
if path.startswith("s3://"):
expanded_paths.append(path)
else:
matched = glob.glob(path)
if not matched:
print(f"No files matched the pattern: {path}")
expanded_paths.extend(matched)
if not expanded_paths:
print("No JSONL files to process.")
return
# Load the Jinja template
try:
with open(os.path.join(os.path.dirname(__file__), template_path), "r", encoding="utf-8") as template_file:
template_content = template_file.read()
template = Template(template_content)
except Exception as e:
print(f"Error loading template: {e}")
return
# Initialize S3 client for generating presigned URLs
try:
workspace_session = boto3.Session(profile_name=s3_profile_name)
s3_client = workspace_session.client("s3")
except Exception as e:
print(f"Error initializing S3 client: {e}")
return
# Create ThreadPoolExecutor
with ThreadPoolExecutor() as executor:
futures = []
for line in read_jsonl(expanded_paths):
if not line:
continue
try:
data = json.loads(line)
except json.JSONDecodeError as e:
print(f"Invalid JSON line: {e}")
continue
future = executor.submit(process_document, data, s3_client, template, output_dir)
futures.append(future)
for _ in tqdm(as_completed(futures), total=len(futures), desc="Processing documents"):
pass # Progress bar updates automatically
print(f"Output HTML-viewable pages to directory: {args.output_dir}")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Generate HTML pages from one or more JSONL files with pre-signed S3 links.")
parser.add_argument("jsonl_paths", nargs="+", help="Path(s) to the JSONL file(s) (local or s3://). Supports glob patterns for local paths.")
parser.add_argument("--output_dir", default="dolma_previews", help="Directory to save HTML files")
parser.add_argument("--template_path", default="dolmaviewer_template.html", help="Path to the Jinja2 template file")
parser.add_argument("--s3_profile", default=None, help="S3 profile to use for accessing the source documents to render them in the viewer.")
args = parser.parse_args()
main(args.jsonl_paths, args.output_dir, args.template_path, args.s3_profile)
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<title>{{ id }}</title>
<style>
/* CSS styles */
body {
font-family: Arial, sans-serif;
background-color: #f0f0f0;
margin: 0;
padding: 0;
display: flex;
justify-content: center;
}
.document {
background-color: #fff;
padding: 40px;
margin: 20px;
width: 60%;
box-shadow: 0px 0px 10px rgba(0, 0, 0, 0.1);
line-height: 1.8;
}
.page-section {
display: flex;
flex-direction: row;
margin-bottom: 20px;
transition: background-color 0.3s ease;
}
.page-section:hover {
background-color: #f5f5f5;
}
.page-section .text {
flex: 2;
padding: 10px;
text-align: justify;
}
.page-section .image {
flex: 1;
padding: 10px;
}
.page-section img {
max-width: 100%;
height: auto;
border: 1px solid #ccc;
}
table {
width: 100%;
border-collapse: collapse; /* Ensures that borders are collapsed to give a clean look */
margin-bottom: 1.5em; /* Adds some space below the table */
}
th, td {
border: 1px solid #ddd; /* 1px solid border for table cells */
padding: 12px 15px; /* Adds some padding for better spacing inside the cells */
text-align: left; /* Aligns text to the left */
vertical-align: top; /* Aligns content to the top of the cell */
font-size: 14px; /* Adjusts font size for readability */
}
th {
background-color: #f4f4f4; /* Light background for table headers */
font-weight: bold; /* Bolds header text */
text-transform: uppercase; /* Makes header text uppercase */
letter-spacing: 0.05em; /* Adds slight spacing between letters for readability */
border-bottom: 2px solid #ccc; /* Slightly thicker bottom border for headers */
}
tr:nth-child(even) {
background-color: #f9f9f9; /* Alternates row background color */
}
tr:hover {
background-color: #f1f1f1; /* Highlights row on hover for better interaction */
}
td img {
max-width: 100%; /* Ensures any images in table cells scale properly */
height: auto;
display: block;
}
table caption {
caption-side: bottom; /* Position caption at the bottom of the table */
text-align: right;
font-size: 12px;
color: #777;
padding: 5px 0;
}
</style>
<script type="text/javascript">
window.MathJax = {
tex: {
inlineMath: [['$', '$'], ['\\(', '\\)']],
displayMath: [['$$', '$$'], ['\\[', '\\]']]
},
options: {
skipHtmlTags: ['script', 'noscript', 'style', 'textarea', 'pre'],
processHtmlClass: 'mathjax-process' // Class name for areas where LaTeX should be processed
}
};
</script>
<script type="text/javascript" id="MathJax-script" async
src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js">
</script>
</head>
<body>
<div class="document">
{% for page in pages %}
<div class="page-section" id="page-{{ page.page_num }}">
<div class="text">{{ page.text|safe }}</div>
{% if page.image %}
<div class="image">
<a href="{{ s3_link }}#page={{ page.page_num }}" target="_blank">
<img src="data:image/webp;base64,{{ page.image }}" alt="Page {{ page.page_num }} Image">
</a>
</div>
{% endif %}
</div>
{% endfor %}
</div>
</body>
</html>
import abc
import asyncio
import datetime
import hashlib
import logging
import os
import random
from asyncio import Queue
from dataclasses import dataclass
from typing import Any, List, Optional
from olmocr.s3_utils import (
download_zstd_csv,
expand_s3_glob,
parse_s3_path,
upload_zstd_csv,
)
logger = logging.getLogger(__name__)
@dataclass
class WorkItem:
"""Represents a single work item in the queue"""
hash: str
work_paths: List[str]
class WorkQueue(abc.ABC):
"""
Base class defining the interface for a work queue.
"""
@abc.abstractmethod
async def populate_queue(self, work_paths: List[str], items_per_group: int) -> None:
"""
Add new items to the work queue. The specifics will vary depending on
whether this is a local or S3-backed queue.
Args:
work_paths: Each individual path that we will process over
items_per_group: Number of items to group together in a single work item
"""
pass
@abc.abstractmethod
async def initialize_queue(self) -> None:
"""
Load the work queue from the relevant store (local or remote)
and initialize it for processing.
For example, this might remove already completed work items and randomize
the order before adding them to an internal queue.
"""
pass
@abc.abstractmethod
async def is_completed(self, work_hash: str) -> bool:
"""
Check if a work item has been completed.
Args:
work_hash: Hash of the work item to check
Returns:
True if the work is completed, False otherwise
"""
pass
@abc.abstractmethod
async def get_work(self, worker_lock_timeout_secs: int = 1800) -> Optional[WorkItem]:
"""
Get the next available work item that isn't completed or locked.
Args:
worker_lock_timeout_secs: Number of seconds before considering
a worker lock stale (default 30 mins)
Returns:
WorkItem if work is available, None if queue is empty
"""
pass
@abc.abstractmethod
async def mark_done(self, work_item: WorkItem) -> None:
"""
Mark a work item as done by removing its lock file
or performing any other cleanup.
Args:
work_item: The WorkItem to mark as done
"""
pass
@property
@abc.abstractmethod
def size(self) -> int:
"""Get current size of work queue"""
pass
@staticmethod
def _compute_workgroup_hash(work_paths: List[str]) -> str:
"""
Compute a deterministic hash for a group of paths.
Args:
work_paths: List of paths (local or S3)
Returns:
SHA1 hash of the sorted paths
"""
sha1 = hashlib.sha1()
for path in sorted(work_paths):
sha1.update(path.encode("utf-8"))
return sha1.hexdigest()
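def _workgroup_hash_example():
    # Illustrative sketch (not part of the original module): the group hash is deterministic and
    # order-insensitive because paths are sorted before hashing. The paths are placeholders.
    a = WorkQueue._compute_workgroup_hash(["s3://bucket/a.pdf", "s3://bucket/b.pdf"])
    b = WorkQueue._compute_workgroup_hash(["s3://bucket/b.pdf", "s3://bucket/a.pdf"])
    assert a == b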
# --------------------------------------------------------------------------------------
# Local Helpers for reading/writing the index CSV (compressed with zstd) to disk
# --------------------------------------------------------------------------------------
try:
import zstandard
except ImportError:
zstandard = None
def download_zstd_csv_local(local_path: str) -> List[str]:
"""
Download a zstd-compressed CSV from a local path.
If the file doesn't exist, returns an empty list.
"""
if not os.path.exists(local_path):
return []
if not zstandard:
raise RuntimeError("zstandard package is required for local zstd CSV operations.")
with open(local_path, "rb") as f:
dctx = zstandard.ZstdDecompressor()
data = dctx.decompress(f.read())
lines = data.decode("utf-8").splitlines()
return lines
def upload_zstd_csv_local(local_path: str, lines: List[str]) -> None:
"""
Upload a zstd-compressed CSV to a local path.
"""
if not zstandard:
raise RuntimeError("zstandard package is required for local zstd CSV operations.")
data = "\n".join(lines).encode("utf-8")
cctx = zstandard.ZstdCompressor()
compressed_data = cctx.compress(data)
# Ensure parent directories exist
os.makedirs(os.path.dirname(local_path), exist_ok=True)
with open(local_path, "wb") as f:
f.write(compressed_data)
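def _zstd_csv_roundtrip_example(path: str = "/tmp/olmocr_example/work_index_list.csv.zstd"):
    # Illustrative sketch (not part of the original module): the index round-trips through the
    # zstd-compressed CSV helpers unchanged. The path and entries are placeholders.
    lines = ["hash1,s3://bucket/a.pdf,s3://bucket/b.pdf"]
    upload_zstd_csv_local(path, lines)
    assert download_zstd_csv_local(path) == lines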
# --------------------------------------------------------------------------------------
# LocalWorkQueue Implementation
# --------------------------------------------------------------------------------------
class LocalWorkQueue(WorkQueue):
"""
A local in-memory and on-disk WorkQueue implementation, which uses
a local workspace directory to store the queue index, lock files,
and completed results for persistent resumption across process restarts.
"""
def __init__(self, workspace_path: str):
"""
Initialize the local work queue.
Args:
workspace_path: Local directory path where the queue index,
results, and locks are stored.
"""
self.workspace_path = os.path.abspath(workspace_path)
os.makedirs(self.workspace_path, exist_ok=True)
# Local index file (compressed)
self._index_path = os.path.join(self.workspace_path, "work_index_list.csv.zstd")
# Output directory for completed tasks
self._results_dir = os.path.join(self.workspace_path, "results")
os.makedirs(self._results_dir, exist_ok=True)
# Directory for lock files
self._locks_dir = os.path.join(self.workspace_path, "worker_locks")
os.makedirs(self._locks_dir, exist_ok=True)
# Internal queue
self._queue: Queue[Any] = Queue()
async def populate_queue(self, work_paths: List[str], items_per_group: int) -> None:
"""
Add new items to the work queue (local version).
Args:
work_paths: Each individual path (local in this context)
that we will process over
items_per_group: Number of items to group together in a single work item
"""
# Treat them as local paths, but keep variable name for consistency
all_paths = set(work_paths)
logger.info(f"Found {len(all_paths):,} total paths")
# Load existing work groups from local index
existing_lines = await asyncio.to_thread(download_zstd_csv_local, self._index_path)
existing_groups = {}
for line in existing_lines:
if line.strip():
parts = line.strip().split(",")
group_hash = parts[0]
group_paths = parts[1:]
existing_groups[group_hash] = group_paths
existing_path_set = {p for paths in existing_groups.values() for p in paths}
new_paths = all_paths - existing_path_set
logger.info(f"{len(new_paths):,} new paths to add to the workspace")
if not new_paths:
return
# Create new work groups
new_groups = []
current_group = []
for path in sorted(new_paths):
current_group.append(path)
if len(current_group) == items_per_group:
group_hash = self._compute_workgroup_hash(current_group)
new_groups.append((group_hash, current_group))
current_group = []
if current_group:
group_hash = self._compute_workgroup_hash(current_group)
new_groups.append((group_hash, current_group))
logger.info(f"Created {len(new_groups):,} new work groups")
# Combine and save updated work groups
combined_groups = existing_groups.copy()
for group_hash, group_paths in new_groups:
combined_groups[group_hash] = group_paths
combined_lines = [",".join([group_hash] + group_paths) for group_hash, group_paths in combined_groups.items()]
if new_groups:
# Write the combined data back to disk in zstd CSV format
await asyncio.to_thread(upload_zstd_csv_local, self._index_path, combined_lines)
async def initialize_queue(self) -> None:
"""
Load the work queue from the local index file and initialize it for processing.
Removes already completed work items and randomizes the order.
"""
# 1) Read the index
work_queue_lines = await asyncio.to_thread(download_zstd_csv_local, self._index_path)
work_queue = {parts[0]: parts[1:] for line in work_queue_lines if (parts := line.strip().split(",")) and line.strip()}
# 2) Determine which items are completed by scanning local results/*.jsonl
if not os.path.isdir(self._results_dir):
os.makedirs(self._results_dir, exist_ok=True)
done_work_items = [f for f in os.listdir(self._results_dir) if f.startswith("output_") and f.endswith(".jsonl")]
done_work_hashes = {fn[len("output_") : -len(".jsonl")] for fn in done_work_items}
# 3) Filter out completed items
remaining_work_hashes = set(work_queue) - done_work_hashes
remaining_items = [WorkItem(hash=hash_, work_paths=work_queue[hash_]) for hash_ in remaining_work_hashes]
random.shuffle(remaining_items)
# 4) Initialize our in-memory queue
self._queue = asyncio.Queue()
for item in remaining_items:
await self._queue.put(item)
logger.info(f"Initialized local queue with {self._queue.qsize()} work items")
async def is_completed(self, work_hash: str) -> bool:
"""
Check if a work item has been completed locally by seeing if
output_{work_hash}.jsonl is present in the results directory.
Args:
work_hash: Hash of the work item to check
"""
output_file = os.path.join(self._results_dir, f"output_{work_hash}.jsonl")
return os.path.exists(output_file)
async def get_work(self, worker_lock_timeout_secs: int = 1800) -> Optional[WorkItem]:
"""
Get the next available work item that isn't completed or locked.
Args:
worker_lock_timeout_secs: Number of seconds before considering
a worker lock stale (default 30 mins)
Returns:
WorkItem if work is available, None if queue is empty
"""
while True:
try:
work_item = self._queue.get_nowait()
except asyncio.QueueEmpty:
return None
# Check if work is already completed
if await self.is_completed(work_item.hash):
logger.debug(f"Work item {work_item.hash} already completed, skipping")
self._queue.task_done()
continue
# Check for worker lock
lock_file = os.path.join(self._locks_dir, f"output_{work_item.hash}.jsonl")
if os.path.exists(lock_file):
# Check modification time
mtime = datetime.datetime.fromtimestamp(os.path.getmtime(lock_file), datetime.timezone.utc)
if (datetime.datetime.now(datetime.timezone.utc) - mtime).total_seconds() > worker_lock_timeout_secs:
# Lock is stale, we can take this work
logger.debug(f"Found stale lock for {work_item.hash}, taking work item")
else:
# Lock is active, skip this work
logger.debug(f"Work item {work_item.hash} is locked by another worker, skipping")
self._queue.task_done()
continue
# Create our lock file (touch an empty file)
try:
with open(lock_file, "wb") as f:
f.write(b"")
except Exception as e:
logger.warning(f"Failed to create lock file for {work_item.hash}: {e}")
self._queue.task_done()
continue
return work_item
async def mark_done(self, work_item: WorkItem) -> None:
"""
Mark a work item as done by removing its lock file.
Args:
work_item: The WorkItem to mark as done
"""
lock_file = os.path.join(self._locks_dir, f"output_{work_item.hash}.jsonl")
if os.path.exists(lock_file):
try:
os.remove(lock_file)
except Exception as e:
logger.warning(f"Failed to delete lock file for {work_item.hash}: {e}")
self._queue.task_done()
@property
def size(self) -> int:
"""Get current size of local work queue"""
return self._queue.qsize()
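async def _local_work_queue_example(workspace: str = "/tmp/olmocr_workspace"):
    # Illustrative sketch (not part of the original module): the typical lifecycle of a local
    # queue. The workspace path and document names are placeholders.
    queue = LocalWorkQueue(workspace)
    await queue.populate_queue(["doc1.pdf", "doc2.pdf", "doc3.pdf"], items_per_group=2)
    await queue.initialize_queue()
    while (item := await queue.get_work()) is not None:
        logger.info(f"Processing group {item.hash} with {len(item.work_paths)} paths")
        # ... process item.work_paths and write results/output_{item.hash}.jsonl here ...
        await queue.mark_done(item)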
# --------------------------------------------------------------------------------------
# S3WorkQueue Implementation
# --------------------------------------------------------------------------------------
class S3WorkQueue(WorkQueue):
"""
Manages a work queue stored in S3 that coordinates work across multiple workers.
The queue maintains a list of work items, where each work item is a group of s3 paths
that should be processed together.
Each work item gets a hash, and completed work items will have their results
stored in s3://workspace_path/results/output_[hash].jsonl
    This is the source of truth for which work items are done.
When a worker takes an item off the queue, it will write an empty s3 file to
s3://workspace_path/worker_locks/output_[hash].jsonl
    The queue order is randomized on each worker, so workers pull random work items to operate on.
    When a worker pulls an item, the queue checks whether it has already been completed. If so,
    it immediately fetches the next item. If a lock file was created within a configurable
    timeout (30 mins by default), then that work item is also skipped.
    The lock file is deleted once the worker is done with that item.
"""
def __init__(self, s3_client, workspace_path: str):
"""
Initialize the work queue.
Args:
s3_client: Boto3 S3 client to use for operations
workspace_path: S3 path where work queue and results are stored
"""
self.s3_client = s3_client
self.workspace_path = workspace_path.rstrip("/")
self._index_path = os.path.join(self.workspace_path, "work_index_list.csv.zstd")
self._output_glob = os.path.join(self.workspace_path, "results", "*.jsonl")
self._queue: Queue[Any] = Queue()
async def populate_queue(self, work_paths: List[str], items_per_group: int) -> None:
"""
Add new items to the work queue.
Args:
work_paths: Each individual s3 path that we will process over
items_per_group: Number of items to group together in a single work item
"""
all_paths = set(work_paths)
logger.info(f"Found {len(all_paths):,} total paths")
# Load existing work groups
existing_lines = await asyncio.to_thread(download_zstd_csv, self.s3_client, self._index_path)
existing_groups = {}
for line in existing_lines:
if line.strip():
parts = line.strip().split(",")
group_hash = parts[0]
group_paths = parts[1:]
existing_groups[group_hash] = group_paths
existing_path_set = {path for paths in existing_groups.values() for path in paths}
# Find new paths to process
new_paths = all_paths - existing_path_set
logger.info(f"{len(new_paths):,} new paths to add to the workspace")
if not new_paths:
return
# Create new work groups
new_groups = []
current_group = []
for path in sorted(new_paths):
current_group.append(path)
if len(current_group) == items_per_group:
group_hash = self._compute_workgroup_hash(current_group)
new_groups.append((group_hash, current_group))
current_group = []
if current_group:
group_hash = self._compute_workgroup_hash(current_group)
new_groups.append((group_hash, current_group))
logger.info(f"Created {len(new_groups):,} new work groups")
# Combine and save updated work groups
combined_groups = existing_groups.copy()
for group_hash, group_paths in new_groups:
combined_groups[group_hash] = group_paths
combined_lines = [",".join([group_hash] + group_paths) for group_hash, group_paths in combined_groups.items()]
if new_groups:
await asyncio.to_thread(upload_zstd_csv, self.s3_client, self._index_path, combined_lines)
async def initialize_queue(self) -> None:
"""
Load the work queue from S3 and initialize it for processing.
Removes already completed work items and randomizes the order.
"""
# Load work items and completed items in parallel
download_task = asyncio.to_thread(download_zstd_csv, self.s3_client, self._index_path)
expand_task = asyncio.to_thread(expand_s3_glob, self.s3_client, self._output_glob)
work_queue_lines, done_work_items = await asyncio.gather(download_task, expand_task)
# Process work queue lines
work_queue = {parts[0]: parts[1:] for line in work_queue_lines if (parts := line.strip().split(",")) and line.strip()}
# Get set of completed work hashes
done_work_hashes = {
os.path.basename(item)[len("output_") : -len(".jsonl")]
for item in done_work_items
if os.path.basename(item).startswith("output_") and os.path.basename(item).endswith(".jsonl")
}
# Find remaining work and shuffle
remaining_work_hashes = set(work_queue) - done_work_hashes
remaining_items = [WorkItem(hash=hash_, work_paths=work_queue[hash_]) for hash_ in remaining_work_hashes]
random.shuffle(remaining_items)
# Initialize queue
self._queue = asyncio.Queue()
for item in remaining_items:
await self._queue.put(item)
logger.info(f"Initialized queue with {self._queue.qsize()} work items")
async def is_completed(self, work_hash: str) -> bool:
"""
Check if a work item has been completed.
Args:
work_hash: Hash of the work item to check
Returns:
True if the work is completed, False otherwise
"""
output_s3_path = os.path.join(self.workspace_path, "results", f"output_{work_hash}.jsonl")
bucket, key = parse_s3_path(output_s3_path)
try:
await asyncio.to_thread(self.s3_client.head_object, Bucket=bucket, Key=key)
return True
except self.s3_client.exceptions.ClientError:
return False
async def get_work(self, worker_lock_timeout_secs: int = 1800) -> Optional[WorkItem]:
"""
Get the next available work item that isn't completed or locked.
Args:
worker_lock_timeout_secs: Number of seconds before considering a worker lock stale (default 30 mins)
Returns:
WorkItem if work is available, None if queue is empty
"""
while True:
try:
work_item = self._queue.get_nowait()
except asyncio.QueueEmpty:
return None
# Check if work is already completed
if await self.is_completed(work_item.hash):
logger.debug(f"Work item {work_item.hash} already completed, skipping")
self._queue.task_done()
continue
# Check for worker lock
lock_path = os.path.join(self.workspace_path, "worker_locks", f"output_{work_item.hash}.jsonl")
bucket, key = parse_s3_path(lock_path)
try:
response = await asyncio.to_thread(self.s3_client.head_object, Bucket=bucket, Key=key)
# Check if lock is stale
last_modified = response["LastModified"]
if (datetime.datetime.now(datetime.timezone.utc) - last_modified).total_seconds() > worker_lock_timeout_secs:
# Lock is stale, we can take this work
logger.debug(f"Found stale lock for {work_item.hash}, taking work item")
else:
# Lock is active, skip this work
logger.debug(f"Work item {work_item.hash} is locked by another worker, skipping")
self._queue.task_done()
continue
except self.s3_client.exceptions.ClientError:
# No lock exists, we can take this work
pass
# Create our lock file
try:
await asyncio.to_thread(self.s3_client.put_object, Bucket=bucket, Key=key, Body=b"")
except Exception as e:
logger.warning(f"Failed to create lock file for {work_item.hash}: {e}")
self._queue.task_done()
continue
return work_item
async def mark_done(self, work_item: WorkItem) -> None:
"""
Mark a work item as done by removing its lock file.
Args:
work_item: The WorkItem to mark as done
"""
lock_path = os.path.join(self.workspace_path, "worker_locks", f"output_{work_item.hash}.jsonl")
bucket, key = parse_s3_path(lock_path)
try:
await asyncio.to_thread(self.s3_client.delete_object, Bucket=bucket, Key=key)
except Exception as e:
logger.warning(f"Failed to delete lock file for {work_item.hash}: {e}")
self._queue.task_done()
@property
def size(self) -> int:
"""Get current size of work queue"""
return self._queue.qsize()
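async def _s3_work_queue_example(workspace: str = "s3://my-bucket/olmocr-workspace"):
    # Illustrative sketch (not part of the original module): the S3 queue follows the same
    # lifecycle as the local one, but coordinates many workers through lock objects under
    # worker_locks/. The bucket and paths are placeholders.
    import boto3

    queue = S3WorkQueue(boto3.client("s3"), workspace)
    await queue.populate_queue(["s3://my-bucket/pdfs/a.pdf", "s3://my-bucket/pdfs/b.pdf"], items_per_group=2)
    await queue.initialize_queue()
    while (item := await queue.get_work()) is not None:
        # ... process item.work_paths, then upload results/output_{item.hash}.jsonl ...
        await queue.mark_done(item)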
import torch
import base64
import urllib.request
from io import BytesIO
from PIL import Image
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
from olmocr.data.renderpdf import render_pdf_to_base64png
from olmocr.prompts import build_finetuning_prompt
from olmocr.prompts.anchor import get_anchor_text
# Initialize the model
model = Qwen2VLForConditionalGeneration.from_pretrained("allenai/olmOCR-7B-0225-preview", torch_dtype=torch.bfloat16).eval()
processor = AutoProcessor.from_pretrained("allenai/olmOCR-7B-0225-preview")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# Grab a sample PDF
urllib.request.urlretrieve("https://molmo.allenai.org/paper.pdf", "./paper.pdf")
# Render page 1 to an image
image_base64 = render_pdf_to_base64png("./paper.pdf", 1, target_longest_image_dim=1024)
# Build the prompt, using document metadata
anchor_text = get_anchor_text("./paper.pdf", 1, pdf_engine="pdfreport", target_length=4000)
prompt = build_finetuning_prompt(anchor_text)
print('prompt:', prompt)
# Build the full prompt
messages = [
{
"role": "user",
"content": [
{"type": "text", "text": prompt},
{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_base64}"}},
],
}
]
# Apply the chat template and processor
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
main_image = Image.open(BytesIO(base64.b64decode(image_base64)))
inputs = processor(
text=[text],
images=[main_image],
padding=True,
return_tensors="pt",
)
inputs = {key: value.to(device) for (key, value) in inputs.items()}
# Generate the output
output = model.generate(
**inputs,
temperature=0.8,
max_new_tokens=50,
num_return_sequences=1,
do_sample=True,
)
# Decode the output
prompt_length = inputs["input_ids"].shape[1]
new_tokens = output[:, prompt_length:]
text_output = processor.tokenizer.batch_decode(
new_tokens, skip_special_tokens=True
)
print(text_output)
# ['{"primary_language":"en","is_rotation_valid":true,"rotation_correction":0,"is_table":false,"is_diagram":false,"natural_text":"Molmo and PixMo:\\nOpen Weights and Open Data\\nfor State-of-the']
[build-system]
requires = ["setuptools", "wheel"]
build-backend = "setuptools.build_meta"
[project]
# See https://setuptools.pypa.io/en/latest/userguide/quickstart.html for more project configuration options.
name = "olmocr"
dynamic = ["version"]
readme = "README.md"
classifiers = [
"Intended Audience :: Science/Research",
"Development Status :: 3 - Alpha",
"License :: OSI Approved :: Apache Software License",
"Programming Language :: Python :: 3",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
]
authors = [
{name = "Allen Institute for Artificial Intelligence", email = "jakep@allenai.org"}
]
requires-python = ">=3.11"
dependencies = [
"cached-path",
"smart_open",
"pypdf>=5.2.0",
"pypdfium2",
"cryptography",
"lingua-language-detector",
"Pillow",
"ftfy",
"bleach",
"markdown2",
"filelock",
"orjson",
"requests",
"zstandard",
"boto3",
"httpx",
"torch>=2.5.1",
"transformers==4.46.2",
"beaker-py",
]
license = {file = "LICENSE"}
[project.urls]
Homepage = "https://github.com/allenai/olmocr"
Repository = "https://github.com/allenai/olmocr"
Changelog = "https://github.com/allenai/olmocr/blob/main/CHANGELOG.md"
# Documentation = "https://olmocr.readthedocs.io/"
[project.optional-dependencies]
dev = [
"ruff",
"mypy",
"black",
"isort",
"pytest",
"pytest-sphinx",
"pytest-cov",
"twine>=1.11.0",
"build",
"setuptools",
"wheel",
"Sphinx>=4.3.0,<7.1.0",
"furo==2023.7.26",
"myst-parser>=1.0,<2.1",
"sphinx-copybutton==0.5.2",
"sphinx-autobuild==2021.3.14",
"sphinx-autodoc-typehints==1.23.3",
"packaging",
"necessary",
"peft",
"datasets",
"omegaconf",
"spacy",
]
bench = [
"tinyhost",
"fuzzysearch",
"rapidfuzz",
"sequence_align",
"syntok",
"google-genai",
"google-generativeai",
"playwright",
"mistralai",
]
train = [
"torch",
"torchvision",
"accelerate",
"datasets",
"peft",
"wandb",
"omegaconf",
"s3fs",
"necessary",
"einops",
"transformers>=4.45.1"
]
elo = [
"numpy",
"scipy",
"pandas",
"matplotlib"
]
[tool.setuptools.packages.find]
exclude = [
"*.tests",
"*.tests.*",
"tests.*",
"tests",
"docs*",
"scripts*"
]
[tool.setuptools]
include-package-data = true
[tool.setuptools.package-data]
olmocr = [
"py.typed",
"viewer/*.html",
"eval/*.html",
]
[tool.setuptools.dynamic]
version = {attr = "olmocr.version.VERSION"}
[tool.black]
line-length = 160
include = '\.pyi?$'
exclude = '''
(
__pycache__
| \.git
| \.mypy_cache
| \.pytest_cache
| \.vscode
| \.venv
| \bdist\b
| \bdoc\b
)
'''
[tool.isort]
profile = "black"
multi_line_output = 3
# You can override these pyright settings by adding a personal pyrightconfig.json file.
[tool.pyright]
reportPrivateImportUsage = false
[tool.ruff]
line-length = 160
target-version = "py311"
exclude = ["olmocr/train/molmo", "tests/*"]
ignore = ["E722"] #igore bare except
[tool.ruff.per-file-ignores]
"__init__.py" = ["F401"]
[tool.mypy]
ignore_missing_imports = true
no_site_packages = true
check_untyped_defs = true
exclude = ["olmocr/train/molmo/", "tests/*"]
[[tool.mypy.overrides]]
module = "tests.*"
strict_optional = false
[tool.pytest.ini_options]
testpaths = "tests/"
python_classes = [
"Test*",
"*Test"
]
log_format = "%(asctime)s - %(levelname)s - %(name)s - %(message)s"
log_level = "DEBUG"
markers = [
"nonci: mark test as not intended for CI runs"
]
FROM --platform=linux/amd64 nvidia/cuda:11.8.0-cudnn8-runtime-ubuntu20.04
RUN apt-get update -y && apt-get install -y software-properties-common \
&& add-apt-repository ppa:deadsnakes/ppa \
&& apt-get -y update
# Install requirements specific to pdfs
RUN apt-get update && apt-get -y install python3-apt
RUN echo "ttf-mscorefonts-installer msttcorefonts/accepted-mscorefonts-eula select true" | debconf-set-selections
RUN apt-get update -y && apt-get install -y poppler-utils ttf-mscorefonts-installer msttcorefonts fonts-crosextra-caladea fonts-crosextra-carlito gsfonts lcdf-typetools
RUN apt-get update -y && apt-get install -y --no-install-recommends \
git \
python3.11 \
python3.11-dev \
python3.11-distutils \
ca-certificates \
build-essential \
curl \
unzip
RUN rm -rf /var/lib/apt/lists/* \
&& unlink /usr/bin/python3 \
&& ln -s /usr/bin/python3.11 /usr/bin/python3 \
&& ln -s /usr/bin/python3 /usr/bin/python \
&& curl -sS https://bootstrap.pypa.io/get-pip.py | python \
&& pip3 install -U pip
RUN apt-get update && apt-get -y install python3.11-venv
ADD --chmod=755 https://astral.sh/uv/install.sh /install.sh
RUN /install.sh && rm /install.sh
ENV PYTHONUNBUFFERED=1
WORKDIR /root
COPY pyproject.toml pyproject.toml
COPY olmocr/version.py olmocr/version.py
RUN /root/.local/bin/uv pip install --system --no-cache -e .
RUN /root/.local/bin/uv pip install --system --no-cache sgl-kernel==0.0.3.post1 --force-reinstall --no-deps
RUN /root/.local/bin/uv pip install --system --no-cache "sglang[all]==0.4.2" --find-links https://flashinfer.ai/whl/cu124/torch2.4/flashinfer/
COPY olmocr olmocr
WORKDIR /root
COPY olmocr olmocr
RUN python3 -m sglang.launch_server --help
RUN python3 -m olmocr.pipeline --help
FROM gcr.io/ai2-beaker-core/public/cqgl31u2ba5vrtuc91og:latest
# Update the package list and install libaio-dev and gnupg2
RUN apt update && apt-get install -y libaio-dev gnupg2
# Add NVIDIA package repository keys
RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/3bf863cc.pub \
&& apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64/7fa2af80.pub \
&& apt-get -y update
# Set up the NVIDIA CUDA repository
RUN apt-get install -y software-properties-common \
&& add-apt-repository "deb https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/ /" \
&& apt-get update
# Install CUDA toolkit and nvcc 12.1
RUN apt-get install -y cuda-nvcc-12-1
# Get flash attention setup
RUN pip install flash-attn --no-build-isolation
# Install PDF utilities
RUN apt-get install -y poppler-utils
RUN echo "ttf-mscorefonts-installer msttcorefonts/accepted-mscorefonts-eula select true" | debconf-set-selections
RUN apt-get install -y ttf-mscorefonts-installer msttcorefonts fonts-crosextra-caladea fonts-crosextra-carlito gsfonts lcdf-typetools
set -ex
export NCCL_DEBUG=INFO NCCL_SOCKET_IFNAME=ib NCCL_IB_HCA="^=mlx5_bond_0"