Commit 1b9205c9 authored by yangzhong's avatar yangzhong
Browse files

v1.0

parents
Pipeline #2931 failed with stages
in 0 seconds
"""
Based on: https://github.com/lucidrains/flamingo-pytorch
"""
import re
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from einops_exts import rearrange_many
from torch import einsum, nn
from transformers.modeling_outputs import CausalLMOutputWithPast
from typing import Optional
from dataclasses import dataclass
@dataclass
class VLMOutputWithPast(CausalLMOutputWithPast):
"""
VLMOutputWithPast is a wrapper around CausalLMOutputWithPast that adds the following attributes:
past_media_locations: Optional[torch.Tensor] = None,
past_vision_tokens: Optional[torch.Tensor] = None,
"""
past_media_locations: Optional[torch.Tensor] = None
past_vision_tokens: Optional[torch.Tensor] = None
def exists(val):
return val is not None
def FeedForward(dim, mult=4):
inner_dim = int(dim * mult)
return nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, inner_dim, bias=False),
nn.GELU(),
nn.Linear(inner_dim, dim, bias=False),
)
class VisionTokenizer(nn.Module):
def __init__(self, dim_media, num_tokens_per_media):
super().__init__()
self.dim_media = dim_media
self.num_tokens_per_media = num_tokens_per_media
class PerceiverAttention(nn.Module):
def __init__(self, *, dim, dim_head=64, heads=8):
super().__init__()
self.scale = dim_head**-0.5
self.heads = heads
inner_dim = dim_head * heads
self.norm_media = nn.LayerNorm(dim)
self.norm_latents = nn.LayerNorm(dim)
self.to_q = nn.Linear(dim, inner_dim, bias=False)
self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
self.to_out = nn.Linear(inner_dim, dim, bias=False)
def forward(self, x, latents, vision_attn_masks=None):
"""
Args:
x (torch.Tensor): image features
shape (b, T, n1, D)
latent (torch.Tensor): latent features
shape (b, T, n2, D)
"""
x = self.norm_media(x)
latents = self.norm_latents(latents)
h = self.heads
q = self.to_q(latents)
kv_input = torch.cat((x, latents), dim=-2) # TODO: Change the shape of vision attention mask according to this.
if vision_attn_masks is not None:
vision_attn_masks = torch.cat((vision_attn_masks,
torch.ones((latents.shape[0], latents.shape[-2]), dtype=latents.dtype, device=latents.device)),
dim=-1)
k, v = self.to_kv(kv_input).chunk(2, dim=-1)
q, k, v = rearrange_many((q, k, v), "b t n (h d) -> b h t n d", h=h)
q = q * self.scale
# attention
sim = einsum("... i d, ... j d -> ... i j", q, k)
# Apply vision attention mask here.
# Reference: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html#torch.nn.functional.scaled_dot_product_attention
if vision_attn_masks is not None:
attn_bias = torch.zeros((q.size(0), 1, 1, q.size(-2), k.size(-2)), dtype=q.dtype, device=q.device)
vision_attn_masks = repeat(vision_attn_masks, 'b n -> b 1 1 l n', l=q.size(-2))
attn_bias.masked_fill_(vision_attn_masks.logical_not(), float("-inf"))
sim += attn_bias
sim = sim - sim.amax(dim=-1, keepdim=True).detach()
attn = sim.softmax(dim=-1)
out = einsum("... i j, ... j d -> ... i d", attn, v)
out = rearrange(out, "b h t n d -> b t n (h d)", h=h)
return self.to_out(out)
class PerceiverResampler(VisionTokenizer):
def __init__(
self,
*,
dim,
dim_inner=None,
depth=6,
dim_head=96,
heads=16,
num_latents=128,
max_num_media=None,
max_num_frames=None,
ff_mult=4,
):
"""
Perceiver module which takes in image features and outputs image tokens.
Args:
dim (int): dimension of the incoming image features
dim_inner (int, optional): final dimension to project the incoming image features to;
also the final dimension of the outputted features. If None, no projection is used, and dim_inner = dim.
depth (int, optional): number of layers. Defaults to 6.
dim_head (int, optional): dimension of each head. Defaults to 64.
heads (int, optional): number of heads. Defaults to 8.
num_latents (int, optional): number of latent tokens to use in the Perceiver;
also corresponds to number of tokens per sequence to output. Defaults to 64.
max_num_media (int, optional): maximum number of media per sequence to input into the Perceiver
and keep positional embeddings for. If None, no positional embeddings are used.
max_num_frames (int, optional): maximum number of frames to input into the Perceiver
and keep positional embeddings for. If None, no positional embeddings are used.
ff_mult (int, optional): dimension multiplier for the feedforward network. Defaults to 4.
"""
if dim_inner is not None:
projection = nn.Linear(dim, dim_inner)
else:
projection = None
dim_inner = dim
super().__init__(dim_media=dim, num_tokens_per_media=num_latents)
self.projection = projection
self.latents = nn.Parameter(torch.randn(num_latents, dim))
# positional embeddings
self.frame_embs = (
nn.Parameter(torch.randn(max_num_frames, dim))
if exists(max_num_frames)
else None
)
self.media_time_embs = (
nn.Parameter(torch.randn(max_num_media, 1, dim))
if exists(max_num_media)
else None
)
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(
nn.ModuleList(
[
PerceiverAttention(
dim=dim, dim_head=dim_head, heads=heads
),
FeedForward(dim=dim, mult=ff_mult),
]
)
)
self.norm = nn.LayerNorm(dim)
def forward(self, x, vision_attn_masks):
"""
Args:
x (torch.Tensor): image features
shape (b, T, F, v, D)
vision_attn_masks (torch.Tensor): attention masks for padded visiont tokens (i.e., x)
shape (b, v)
Returns:
shape (b, T, n, D) where n is self.num_latents
"""
b, T, F, v = x.shape[:4]
# frame and media time embeddings
if exists(self.frame_embs):
frame_embs = repeat(self.frame_embs[:F], "F d -> b T F v d", b=b, T=T, v=v)
x = x + frame_embs
x = rearrange(
x, "b T F v d -> b T (F v) d"
) # flatten the frame and spatial dimensions
if exists(self.media_time_embs):
x = x + self.media_time_embs[:T]
# blocks
latents = self.latents
latents = repeat(latents, "n d -> b T n d", b=b, T=T)
for attn, ff in self.layers:
latents = attn(x, latents, vision_attn_masks) + latents
latents = ff(latents) + latents
if exists(self.projection):
return self.projection(self.norm(latents))
else:
return self.norm(latents)
class LinearPatchProjection(VisionTokenizer):
"""Linear projection from patch features to image tokens."""
def __init__(self, mm_projector_type, *, dim_visual, dim_out, num_patches):
super().__init__(dim_media=dim_visual, num_tokens_per_media=num_patches)
if mm_projector_type == 'linear':
self.proj = nn.Linear(dim_visual, dim_out)
else:
mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', mm_projector_type)
if mlp_gelu_match:
mlp_depth = int(mlp_gelu_match.group(1))
modules = [nn.Linear(dim_visual, dim_out)]
for _ in range(1, mlp_depth):
modules.append(nn.GELU())
modules.append(nn.Linear(dim_out, dim_out))
self.proj = nn.Sequential(*modules)
else:
raise ValueError(f'Unknown projector type: {mm_projector_type}')
def forward(self, x):
B = x.shape[0]
x = rearrange(x, "b T F v d -> (b T) (F v) d")
x = self.proj(x)
return rearrange(x, "(b T) n d -> b T n d", b=B)
# gated cross attention
class MaskedCrossAttention(nn.Module):
def __init__(
self,
*,
dim,
dim_visual,
dim_head=64,
heads=8,
only_attend_immediate_media=True,
):
super().__init__()
self.scale = dim_head**-0.5
self.heads = heads
inner_dim = dim_head * heads
self.norm = nn.LayerNorm(dim)
self.to_q = nn.Linear(dim, inner_dim, bias=False)
self.to_kv = nn.Linear(dim_visual, inner_dim * 2, bias=False)
self.to_out = nn.Linear(inner_dim, dim, bias=False)
# whether for text to only attend to immediate preceding image, or all previous images
self.only_attend_immediate_media = only_attend_immediate_media
def forward(self, x, media, media_locations=None):
"""
Args:
x (torch.Tensor): text features
shape (B, T_txt, D_txt)
media (torch.Tensor): image features
shape (B, T_img, n, D_img) where n is the dim of the latents
media_locations: boolean mask identifying the media tokens in x
shape (B, T_txt_all)
T_txt_all >= T_txt
If T_txt_all > T_txt, then the last T_txt text_times are used
"""
T_txt = x.shape[1]
assert (
T_txt <= media_locations.shape[1]
), "current text cannot be longer than conditioned media locations"
_, T_img, n = media.shape[:3]
h = self.heads
x = self.norm(x)
q = self.to_q(x)
media = rearrange(media, "b t n d -> b (t n) d")
k, v = self.to_kv(media).chunk(2, dim=-1)
q, k, v = rearrange_many((q, k, v), "b n (h d) -> b h n d", h=h)
q = q * self.scale
sim = einsum("... i d, ... j d -> ... i j", q, k)
if exists(media_locations):
media_time = torch.arange(T_img, device=x.device) + 1
# at each boolean of True, increment the time counter (relative to media time)
text_time = media_locations.cumsum(dim=-1)[:, -T_txt:]
# text time must equal media time if only attending to most immediate image
# otherwise, as long as text time is greater than media time (if attending to all previous images / media)
mask_op = torch.eq if self.only_attend_immediate_media else torch.ge
text_to_media_mask = mask_op(
rearrange(text_time, "b i -> b 1 i 1"),
repeat(media_time, "j -> 1 1 1 (j n)", n=n),
)
sim = sim.masked_fill(~text_to_media_mask, -torch.finfo(sim.dtype).max)
sim = sim - sim.amax(dim=-1, keepdim=True).detach()
attn = sim.softmax(dim=-1)
if exists(media_locations) and self.only_attend_immediate_media:
# any text without a preceding media needs to have attention zeroed out
text_without_media_mask = text_time == 0
text_without_media_mask = rearrange(
text_without_media_mask, "b i -> b 1 i 1"
)
attn = attn.masked_fill(text_without_media_mask, 0.0)
out = einsum("... i j, ... j d -> ... i d", attn, v)
out = rearrange(out, "b h n d -> b n (h d)")
return self.to_out(out)
class GatedCrossAttentionBlock(nn.Module):
def __init__(
self,
*,
dim,
dim_visual,
dim_head=64,
heads=8,
ff_mult=4,
only_attend_immediate_media=True,
):
super().__init__()
self.attn = MaskedCrossAttention(
dim=dim,
dim_visual=dim_visual,
dim_head=dim_head,
heads=heads,
only_attend_immediate_media=only_attend_immediate_media,
)
self.attn_gate = nn.Parameter(torch.tensor([0.0]))
self.ff = FeedForward(dim, mult=ff_mult)
self.ff_gate = nn.Parameter(torch.tensor([0.0]))
def forward(
self,
x,
media,
media_locations=None,
):
x = (
self.attn(
x,
media,
media_locations=media_locations,
)
* self.attn_gate.tanh()
+ x
)
x = self.ff(x) * self.ff_gate.tanh() + x
return x
class QFormerWithProjection(VisionTokenizer):
"""
Based on BLIP-2 (https://arxiv.org/pdf/2301.12597.pdf)
In the BLIP-2 paper, Q-former is initialized with BERT-base weights,
so dim_inner = 768, num_hidden_layers = 12, and intermediate_size = 3072
"""
def __init__(
self,
dim_input,
dim_out,
dim_inner=768,
num_hidden_layers=12,
num_query_tokens=32,
):
super().__init__(dim_media=dim_out, num_tokens_per_media=num_query_tokens)
# initialize the qformer
from transformers import Blip2QFormerModel, Blip2QFormerConfig
self.qformer = Blip2QFormerModel(
Blip2QFormerConfig(
encoder_hidden_size=dim_input,
hidden_size=dim_inner,
num_hidden_layers=num_hidden_layers,
)
)
self.query_tokens = nn.Parameter(
torch.zeros(1, num_query_tokens, dim_inner)
)
self.proj = nn.Linear(dim_inner, dim_out)
def forward(self, x):
"""
Args:
x (torch.Tensor): image features
shape (B, T, F, v, D)
Returns:
shape (B, T, n, D) where n is num_query_tokens
"""
# HF class expects three dimensional input
B, T = x.shape[:2]
x = rearrange(x, "b T F v d -> (b T) (F v) d")
# get the outputs
image_attention_mask = torch.ones(
x.size()[:-1], dtype=torch.long, device=x.device
)
query_tokens = self.query_tokens.expand(x.shape[0], -1, -1)
query_outputs = self.qformer(
query_embeds=query_tokens,
encoder_hidden_states=x,
encoder_attention_mask=image_attention_mask,
output_attentions=False,
output_hidden_states=False,
return_dict=True,
)
query_output = query_outputs[0]
query_output = self.proj(query_output)
# reshape
query_output = rearrange(query_output, "(b T) n d -> b T n d", b=B)
return query_output
# Both DecoupledEmbedding and DecoupledLinear are taken from https://github.com/huggingface/transformers/blob/v4.32.1/src/transformers/models/idefics/modeling_idefics.py and renamed for clarity
class DecoupledEmbedding(nn.Embedding):
# Derived from https://pytorch.org/docs/stable/_modules/torch/nn/modules/sparse.html#Embedding
"""
Implements a decoupling of parameters to allow freezing (or not) a subset of the embeddings. In practise, the
regular `weight` can be trained or frozen (i.e. `partially_freeze=True`), and if `num_additional_embeddings` > 0,
then it will create `num_additional_embeddings` additional parameters that are always trained. If
`num_additional_embeddings=0`, then the module defaults back to the regular behavior of `nn.Embedding`.
"""
def __init__(
self,
max_original_id: int,
num_additional_embeddings: int = 0,
_weight: torch.Tensor = None,
num_original_embeddings: int = None,
embedding_dim: int = None,
partially_freeze=True,
device=None,
dtype=None,
pad_token_id=None,
) -> None:
"""
Args:
max_original_id (`int`):
The largest token id that should be embedded using the regular embedding (regular `weight`).
This is usually len(tokenizer) - 1 before additional tokens are added.
Note that this may not equal self.weight.shape[0]
num_additional_embeddings (`int`):
Number of additional tokens to initialize an Embedding matrix for (`additional_weight`).
_weight (`torch.Tensor`, *optional*, defaults to `None`): The regular weight tensor.
If provided, this sets the `num_original_embeddings` and `embedding_dim` parameters.
num_original_embeddings (`int`):
self.weight.shape[0]
embedding_dim (`int`):
The size of each embedding vector
partially_freeze: (`bool`, *optional*, defaults to `True`):
If `True`, the regular `weight` will be frozen. `additional_weight` is never frozen.
padding_idx (`int`, *optional*):
The padding index (needs to be less than num_embeddings)
Note: there are a lot of other parameters to initialize a standard `nn.Embedding` such as `padding_idx`,
`max_norm` or `norm_type`. We are not supporting these.
"""
# validate args
if pad_token_id is not None and pad_token_id > max_original_id:
raise ValueError(
f"pad_token_id must be <= max_original_id. Got {pad_token_id} and {max_original_id}."
+ "If the original tokenizer does not have a pad_token_id, use pad_token_id=None."
)
if _weight is not None:
assert (num_original_embeddings is None) or (
_weight.shape[0] == num_original_embeddings
), f"num_original_embeddings={num_original_embeddings} but _weight.shape[0]={_weight.shape[0]}"
assert (embedding_dim is None) or (
_weight.shape[1] == embedding_dim
), f"embedding_dim={embedding_dim} but _weight.shape[1]={_weight.shape[1]}"
num_original_embeddings = _weight.shape[0]
embedding_dim = _weight.shape[1]
else:
assert (
num_original_embeddings is not None
), "num_original_embeddings must be provided if _weight is not provided"
assert (
embedding_dim is not None
), "embedding_dim must be provided if _weight is not provided"
super().__init__(
num_embeddings=num_original_embeddings,
embedding_dim=embedding_dim,
device=device,
dtype=dtype,
padding_idx=pad_token_id,
_weight=_weight,
)
self.max_original_id = max_original_id
self.padding_idx = pad_token_id
self.num_additional_embeddings = num_additional_embeddings
if self.num_additional_embeddings > 0:
self.additional_embedding = nn.Embedding(
num_embeddings=self.num_additional_embeddings,
embedding_dim=embedding_dim,
device=device,
dtype=dtype,
)
self.set_requires_grad(
require_regular_grad=not partially_freeze, require_additional_grad=True
)
def set_requires_grad(self, require_regular_grad, require_additional_grad):
"""
Helper function to separately set the requires_grad flag for the regular weight and the additional weight.
"""
self.weight.requires_grad_(require_regular_grad)
self.additional_embedding.requires_grad_(require_additional_grad)
def forward(self, input_ids):
"""
we have 2 embeddings, with different indices - one pretrained self.weight and another
self.additional_embedding.weight that is being trained.
in order to make a lookup of the input ids, we:
1. find out the indices of the entries belonging to the 2nd embedding
2. extract those values while subtracting the size of the first embedding (num_embeddings), since the 2nd
embedding starts from 0 and not num_embeddings
3. perform the 2nd embedding lookup
4. now we handle the 1st embedding, we overwrite indices belonging to the 2nd embedding with a padding index
5. perform the 1st embedding lookup
6. now we overwrite the values in the 1st embedding lookup with the values of the 2nd embedding lookup
note: for the 1st embedding lookup we could have looked up only the low indices and not do the padding, but
then we have to create a new tensor and populate it with 2 tensors that are spread out across various indices -
i.e. not a simple concat - I haven't benchmarked the complex case if it's any faster, given that seqlens are
usually relatively short it's probably not faster or if faster not by much - but might be a good idea to
measure.
"""
if self.num_additional_embeddings == 0:
return F.embedding(input_ids, self.weight)
# Clone so that we don't modify the original input_ids later on
input_ids = input_ids.clone()
additional_vocab_indices = torch.where(input_ids > self.max_original_id)
input_ids_additional_vocab = input_ids[additional_vocab_indices]
additional_embeddings = self.additional_embedding(
input_ids_additional_vocab - self.max_original_id - 1
)
# for successful lookup replace input_ids with 0, the results of these will be discarded anyway
input_ids[additional_vocab_indices] = 0
full_vector = F.embedding(input_ids, self.weight)
# overwrite the records with high indices
full_vector[additional_vocab_indices] = additional_embeddings
return full_vector
def extra_repr(self) -> str:
return "num_original_embeddings={}, num_additional_embeddings={}, embedding_dim={}, partially_freeze={}".format(
self.max_original_id + 1,
self.num_additional_embeddings,
self.embedding_dim,
(not self.weight.requires_grad),
)
class DecoupledLinear(nn.Linear):
# Derived from https://pytorch.org/docs/stable/_modules/torch/nn/modules/linear.html#Linear
"""
Implements a decoupling of parameters to allow freezing (or not) a subset of the parameters. In practise, the
regular `weight` can be trained or frozen (i.e. `partially_freeze=True`), and if `additional_out_features` > 0,
then it will create `additional_out_features * in_features` additional parameters that are always trained. If
`additional_out_features=0`, then the module defaults back to the regular behavior of `nn.Linear`.
"""
def __init__(
self,
max_original_id: int,
additional_out_features: int = 0,
_weight: torch.Tensor = None,
_bias: torch.Tensor = None,
in_features: int = None,
original_out_features: int = None,
bias: bool = True,
partially_freeze: bool = True,
device=None,
dtype=None,
) -> None:
"""
Args:
max_original_id (`int`): The largest token id that should be extracted from the regular weight.
This is usually len(tokenizer) - 1 before additional tokens are added.
Note that this may not equal original_out_features - 1
_weight: torch.Tensor, *optional*, defaults to `None`. The regular weight tensor.
If provided, this sets the `in_features` and `original_out_features` parameters.
_bias: torch.Tensor, *optional*, defaults to `None`. The regular bias tensor.
in_features: int. Input hidden size.
original_out_features: int. Original out_features of the language model's get_output_embeddings() function.
additional_out_features: int. Number of additional trainable dimensions.
bias: bool. Whether to include a bias term.
partially_freeze: bool, *optional*, defaults to `True`): If `True`, the regular `weight` will be frozen.
"""
# argument validation
if _weight is not None:
assert (_weight.shape[0] == original_out_features) or (
original_out_features is None
), f"original_out_features={original_out_features} but _weight.shape[0]={_weight.shape[0]}"
assert (_weight.shape[1] == in_features) or (
in_features is None
), f"in_features={in_features} but _weight.shape[1]={_weight.shape[1]}"
in_features = _weight.shape[1]
original_out_features = _weight.shape[0]
else:
assert (
in_features is not None
), "in_features must be provided if _weight is not provided"
assert (
original_out_features is not None
), "original_out_features must be provided if _weight is not provided"
if _bias is not None:
assert bias is True, "bias must be True if _bias is provided"
# initialize original linear
super().__init__(
in_features,
original_out_features,
bias,
device,
dtype)
# set weight and bias manually
if _weight is not None:
self.weight = nn.Parameter(_weight)
if _bias is not None:
self.bias = nn.Parameter(_bias)
self.in_features = in_features
self.original_out_features = original_out_features
self.max_original_id = max_original_id
# initialize additional linear
self.additional_out_features = additional_out_features
self.has_bias = bias
if additional_out_features > 0:
self.additional_fc = nn.Linear(
in_features=in_features,
out_features=additional_out_features,
bias=self.has_bias,
device=device,
dtype=dtype,
)
self.set_requires_grad(
require_regular_grad=not partially_freeze, require_additional_grad=True
)
def set_requires_grad(self, require_regular_grad, require_additional_grad):
"""
Helper function to separately set the requires_grad flag for the regular weight and the additional weight.
"""
self.weight.requires_grad_(require_regular_grad)
if self.has_bias:
self.bias.requires_grad_(require_regular_grad)
self.additional_fc.requires_grad_(require_additional_grad)
def forward(self, input: torch.Tensor) -> torch.Tensor:
output = F.linear(input, self.weight, self.bias)
output = output[..., : self.max_original_id + 1]
if self.additional_out_features > 0:
additional_features = F.linear(
input, self.additional_fc.weight, self.additional_fc.bias
)
output = torch.cat((output, additional_features), -1)
return output
def extra_repr(self) -> str:
"""Overwriting `nn.Linear.extra_repr` to include new parameters."""
return "in_features={}, out_features={}, additional_out_features={}, bias={}, partially_freeze={}".format(
self.in_features,
self.max_original_id + 1,
self.additional_out_features,
self.bias is not None,
(not self.weight.requires_grad or not self.bias.requires_grad),
)
import torch
def extend_instance(obj, mixin):
"""Apply mixins to a class instance after creation"""
base_cls = obj.__class__
base_cls_name = obj.__class__.__name__
obj.__class__ = type(
base_cls_name, (mixin, base_cls), {}
) # mixin needs to go first for our forward() logic to work
def hasattr_recursive(obj, att):
"""
Check if obj has nested attribute
Example: hasattr_recursive(obj, 'a.b.c') is equivalent to hasattr(obj, 'a') and hasattr(obj.a, 'b') and hasattr(obj.a.b, 'c')
"""
if att == "":
return True
i = att.find(".")
if i < 0:
return hasattr(obj, att)
else:
try:
return hasattr_recursive(getattr(obj, att[:i]), att[i + 1 :])
except:
return False
def getattr_recursive(obj, att):
"""
Return nested attribute of obj
Example: getattr_recursive(obj, 'a.b.c') is equivalent to obj.a.b.c
"""
if att == "":
return obj
i = att.find(".")
if i < 0:
return getattr(obj, att)
else:
return getattr_recursive(getattr(obj, att[:i]), att[i + 1 :])
def setattr_recursive(obj, att, val):
"""
Set nested attribute of obj
Example: setattr_recursive(obj, 'a.b.c', val) is equivalent to obj.a.b.c = val
"""
if "." in att:
obj = getattr_recursive(obj, ".".join(att.split(".")[:-1]))
setattr(obj, att.split(".")[-1], val)
def apply_with_stopping_condition(
module, apply_fn, apply_condition=None, stopping_condition=None, **other_args
):
if stopping_condition(module):
return
if apply_condition(module):
apply_fn(module, **other_args)
for child in module.children():
apply_with_stopping_condition(
child,
apply_fn,
apply_condition=apply_condition,
stopping_condition=stopping_condition,
**other_args
)
def stack_with_padding(list_of_tensors, padding_value=0, padding_side="right"):
"""
Stack a list of tensors with padding on one side
Args:
list_of_tensors (list[torch.Tensor]): List of tensors to stack
padding_value (int, optional): Value to pad with. Defaults to 0.
padding_side (str, optional): Side to pad on. Defaults to "right".
Returns:
torch.Tensor: Stacked tensors
"""
max_tokens = max(tensor.size(0) for tensor in list_of_tensors)
padded_tensors = []
for tensor in list_of_tensors:
num_tokens = tensor.size(0)
if len(tensor.size()) == 1:
padding = torch.full(
(max_tokens - num_tokens,),
padding_value,
dtype=tensor.dtype,
device=tensor.device,
)
else:
padding = torch.full(
(max_tokens - num_tokens, tensor.size(1)),
padding_value,
dtype=tensor.dtype,
device=tensor.device,
)
padded_tensor = (
torch.cat((tensor, padding), dim=0)
if padding_side == "right"
else torch.cat((padding, tensor), dim=0)
)
padded_tensors.append(padded_tensor)
return torch.stack(padded_tensors)
def num_params(module, filter_to_trainable=False):
"""Returns the number of parameters in the module, or optionally only the trainable parameters"""
if filter_to_trainable:
return sum(p.numel() for p in module.parameters() if p.requires_grad)
else:
return sum(p.numel() for p in module.parameters())
import torch
from einops import rearrange
from torch import nn
from typing import List, Optional, Tuple, Union
from .utils import extend_instance, stack_with_padding, num_params, getattr_recursive
from .cross_attn_lm import CrossAttentionMixin
from .helpers import DecoupledEmbedding, DecoupledLinear, VLMOutputWithPast
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers import CLIPVisionModel
from transformers.models.siglip.modeling_siglip import SiglipVisionTransformer
from open_flamingo.train.any_res_data_utils import get_anyres_image_grid_shape, unpad_image
class VLM(nn.Module):
"""
Generic vision-language model (VLM) class.
A VLM consists of four components:
1. A vision encoder that extracts features from pixels, e.g. CLIP
input: (B, T_img, F, C, H, W)
output: (B, T_img, F, v, d)
2. A vision tokenizer that converts these features to visual token-like embeddings, e.g. Perceiver, or a linear projection head
input: (B, T_img, F, v, d)
output: (B, T_img, n, d)
3. A fusion method that allows the language model to attend to these tokens, e.g. cross-attention, or placing the tokens directly in the language model's input sequence
4. A language model
"""
def __init__(
self,
vision_encoder: nn.Module,
vision_tokenizer: nn.Module,
lang_model: nn.Module,
initial_tokenizer_len: int,
pad_token_id: int,
gradient_checkpointing: bool = False,
base_img_size: Optional[int] = None,
):
"""
Args:
vision_encoder (nn.Module): e.g. CLIP
vision_tokenizer (nn.Module): e.g. PerceiverResampler
lang_model (nn.Module): e.g. MPT
initial_tokenizer_len (int): size of the original tokenizer vocab
pad_token_id (int): id of the pad token
gradient_checkpointing (bool, optional): Whether to use gradient checkpointing. Defaults to False.
"""
super().__init__()
# save dimension information
self.lang_embedding_dim = lang_model.get_input_embeddings().weight.shape[1]
if hasattr(lang_model.config, "d_model"):
self.lang_hidden_dim = lang_model.config.d_model # mpt uses d_model
else:
self.lang_hidden_dim = lang_model.config.hidden_size
self.vis_embedding_dim = vision_tokenizer.dim_media
self.num_tokens_per_vis = vision_tokenizer.num_tokens_per_media
# core components
self.vision_encoder = vision_encoder
self.vision_tokenizer = vision_tokenizer
self.lang_model = lang_model
if base_img_size is None:
if isinstance(self.vision_encoder, CLIPVisionModel) or isinstance(self.vision_encoder, SiglipVisionTransformer):
base_img_size = self.vision_encoder.config.image_size
else:
base_img_size = self.vision_encoder.image_size[0]
self.base_img_size = base_img_size
# lm embeddings
self.pad_token_id = pad_token_id
self.initial_tokenizer_len = initial_tokenizer_len
input_embeds = DecoupledEmbedding(
max_original_id=initial_tokenizer_len - 1,
num_additional_embeddings=len(self.special_tokens),
_weight=self.lang_model.get_input_embeddings().weight,
pad_token_id=self.pad_token_id,
)
if hasattr(input_embeds, "additional_embedding"):
input_embeds.additional_embedding.weight.data.normal_(
mean=0.0,
std=self.lang_model.config.initializer_range
if hasattr(self.lang_model.config, "initializer_range")
else 0.02,
)
self.lang_model.set_input_embeddings(input_embeds)
out_embeds = DecoupledLinear(
max_original_id=initial_tokenizer_len - 1,
additional_out_features=len(self.special_tokens),
_weight=self.lang_model.get_output_embeddings().weight,
_bias=self.lang_model.get_output_embeddings().bias if hasattr(self.lang_model.get_output_embeddings(), "bias") else None,
)
if hasattr(out_embeds, "additional_fc"):
out_embeds.additional_fc.weight.data.normal_(
mean=0.0,
std=self.lang_model.config.initializer_range
if hasattr(self.lang_model.config, "initializer_range")
else 0.02,
)
self.lang_model.set_output_embeddings(out_embeds)
# gradient checkpointing
self.vision_tokenizer._use_gradient_checkpointing = gradient_checkpointing
def forward(
self,
vision_x: Optional[torch.Tensor],
lang_x: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
past_key_values: Optional[
List[Union[torch.Tensor, Tuple[torch.Tensor]]]
] = None,
past_media_locations: Optional[torch.Tensor] = None,
past_vision_tokens: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = False,
**kwargs,
):
"""
Args:
vision_x: Vision input
shape (B, T_img, F, C, H, W) with F=1
only F = 1 is supported (single-frame videos)
if T_img > the number of media tokens in the corresponding input_ids (lang_x),
only the first number of media tokens in lang_x are used
lang_x: Language input ids, with media tokens denoting where
visual media should be inserted.
shape (B, T_txt)
attention_mask: Attention mask. Defaults to None.
labels: Labels. Defaults to None.
shape (B, T_txt)
past_key_values (Tuple[torch.Tensor]], optional): Past key value pairs for each of the T_txt previous tokens in the language model. Defaults to None.
list of length = number of decoder layers in the LM
exact implementation depends on LM, see Hugging Face docs
past_media_locations (torch.Tensor, optional): boolean mask denoting which of the previous T_txt tokens were media tokens. Defaults to None.
shape (B, T_txt)
past_vision_tokens (torch.Tensor, optional): Previous vision tokens. Defaults to None.
use_cache (Optional[bool], optional): Whether to use cache. Defaults to False.
If True, includes key_values, media_locations, and vision_tokens in the output.
"""
assert not (past_vision_tokens is None) ^ (
past_media_locations is None
), "past_vision_tokens and past_media_locations must both be None or both be not None"
# convert pixels to vision tokens
if vision_x is not None:
vision_features = self._encode_vision_x(vision_x=vision_x)
vision_tokens = self.vision_tokenizer(vision_features)
else:
vision_tokens = None
# fuse the vision and language tokens
new_inputs = self._prepare_inputs_for_forward(
vision_tokens=vision_tokens,
lang_x=lang_x,
attention_mask=attention_mask,
labels=labels,
past_key_values=past_key_values,
past_media_locations=past_media_locations,
padding_side="right",
past_vision_tokens=past_vision_tokens,
)
output = self.lang_model(
**new_inputs,
use_cache=use_cache,
past_key_values=past_key_values,
**kwargs,
)
# postprocessing may be needed, e.g. to remove extra tokens from logits that were inserted into the language stream
# or to add the past_vision_tokens and past_media_locations to the output
output = self._postprocess_outputs_from_forward(
output=output,
lang_x=lang_x,
vision_tokens=vision_tokens,
use_cache=use_cache,
past_vision_tokens=past_vision_tokens,
past_media_locations=past_media_locations,
)
# postforward hooks
self._post_forward_hook()
return output
def _encode_vision_x_anyres(self, samples, device):
assert self.anyres_grids is not None
image_raw = samples["image"] # list of patch list in of shape [1, N_patch, C, H, W]
image_sizes = samples["image_size"]
# Image_raw can be a list of list of patches, when a `samples` has multiple images.
if isinstance(image_raw[0], list):
images = [x.squeeze(0) for sample_img in image_raw for x in sample_img]
image_sizes = [s for sample_sizes in image_sizes for s in sample_sizes]
else:
# assert isinstance(image_raw[0], torch.Tensor), f"Unkown image type: {image_raw[0]}"
# concate list of patches into one big patch for any res encoding.
images = [x.squeeze(0) for x in image_raw] # [N_patch, C, H, W]
image = torch.cat(images, dim=0) # [\sum{B}{N_patch_i}, C, H, W]
image = image.to(device)
with torch.no_grad():
if self.vision_encoder.__class__.__name__ == "TimmModel":
image_embeds = self.vision_encoder.trunk.forward_features(image)
elif self.vision_encoder.__class__.__name__ in ['CLIPVisionModel', 'SiglipVisionTransformer']:
image_embeds = self.vision_encoder(image, interpolate_pos_encoding=True).last_hidden_state
else:
image_embeds = self.vision_encoder(image)[1] # OpenCLIP returns tuples
if self.vision_encoder.__class__.__name__ == "TimmModel":
grid_size = self.vision_encoder.trunk.patch_embed.grid_size
elif self.vision_encoder.__class__.__name__ in ['CLIPVisionModel', 'SiglipVisionTransformer']:
grid_size_base = self.base_img_size // self.vision_encoder.config.patch_size
grid_size = (grid_size_base, grid_size_base)
else:
grid_size = self.vision_encoder.grid_size
height, width = grid_size
if not image_embeds.shape[1] == height * width:
assert image_embeds.shape[1] == height * width + 1 # For vision encoders that has [CLS] token.
image_embeds = image_embeds[:, 1:, :] # Drop the cls token for each patch.
n_vis_token_per_patch = image_embeds.shape[1]
# Split encoded patches and merge patch features
# 1. Get the raw sizes from samples, and split the image embeds [\sum_{B}(N_patch_i), N_tok(16*16), C]
split_sizes = [image.shape[0] for image in images]
image_embeds = torch.split(image_embeds, split_sizes, dim=0)
# 2. For each image (consist of a list of patches), merge the patches spatially (of shape [C, n_patch_height, n_patch_width])
new_image_embeds = []
patch_attn_masks = []
max_n_img_token = -1
for idx, patch_embeds in enumerate(image_embeds):
if patch_embeds.shape[0] > 1:
# 3. Flatten the patch features and get [C, n_patch_height * (n_patch_width+1)]
base_patch_embeds = patch_embeds[0] # TODO: prepend the CLS token for th base patch embeds (of the resized entire image).
patch_embeds = patch_embeds[1:]
assert height * width == base_patch_embeds.shape[0]
num_patch_width, num_patch_height = get_anyres_image_grid_shape(image_sizes[idx],
self.anyres_grids,
self.base_img_size) # Hardcoded grid_pinpoints.
patch_embeds = patch_embeds.view(num_patch_height, num_patch_width, height, width, -1)
patch_embeds = patch_embeds.permute(4, 0, 2, 1, 3).contiguous()
patch_embeds = patch_embeds.flatten(1, 2).flatten(2, 3)
patch_embeds, patch_attn_mask = unpad_image(patch_embeds, image_sizes[idx], self.anyres_patch_sampling)
if hasattr(self, 'image_newline'):
patch_embeds = torch.cat((
patch_embeds,
self.image_newline[:, None, None].expand(*patch_embeds.shape[:-1], 1)
), dim=-1)
if self.anyres_patch_sampling:
patch_embeds = patch_embeds.view(-1, num_patch_height, num_patch_width, height*width)
patch_embeds = patch_embeds.flatten(1, 2).permute(1, 2, 0)
assert patch_attn_mask is not None
patch_attn_mask = patch_attn_mask.view(num_patch_height, num_patch_width, height*width)
patch_attn_mask = patch_attn_mask.flatten(0, 1)
patch_embeds = torch.cat((base_patch_embeds.unsqueeze(0), patch_embeds), dim=0)
patch_attn_mask = torch.cat((torch.ones(n_vis_token_per_patch, device=patch_embeds.device).unsqueeze(0), patch_attn_mask), dim=0)
else:
patch_embeds = patch_embeds.flatten(1, 2).transpose(0, 1)
patch_embeds = torch.cat((base_patch_embeds, patch_embeds), dim=0)
else:
patch_embeds = patch_embeds[0].unsqueeze(0) if self.anyres_patch_sampling else patch_embeds[0]
patch_attn_mask = torch.ones(n_vis_token_per_patch, device=patch_embeds.device).unsqueeze(0) if self.anyres_patch_sampling else None
if hasattr(self, 'image_newline'):
patch_embeds = torch.cat((
patch_embeds,
self.image_newline[None]
), dim=0)
if not self.anyres_patch_sampling:
max_n_img_token = max(patch_embeds.shape[0], max_n_img_token)
new_image_embeds.append(patch_embeds)
patch_attn_masks.append(patch_attn_mask)
if self.anyres_patch_sampling:
# Return individual patches for independent token downsampling.
return new_image_embeds, patch_attn_masks
# 4. Pad and concat the list of image_embeds [N_tok_i, C] together into a batch. Also modify the query attention mask.
image_embeds = []
image_atts = []
for image_embed in new_image_embeds:
n_img_token = image_embed.shape[0]
img_attn = torch.ones((max_n_img_token), dtype=torch.long, device=image_embed.device)
if n_img_token < max_n_img_token:
padded_embed = torch.zeros((max_n_img_token, image_embed.shape[-1]), dtype=image_embed.dtype, device=image_embed.device)
padded_embed[:n_img_token, :] = image_embed
img_attn[n_img_token:] = 0 # Mask out the padded entries.
else:
padded_embed = image_embed
image_embeds.append(padded_embed)
image_atts.append(img_attn)
image_embeds = torch.stack(image_embeds, dim=0) # Shape [B, N_tok_longest, C_dim]
image_atts = torch.stack(image_atts, dim=0) # Shape [B, N_tok_longest, C_dim]
# TODO: reshape image_embeds and image_atts to "b T F v d"
image_embeds = image_embeds[:, None, None, :, :]
# image_atts = image_atts[:, None, None, :, :]
return image_embeds, image_atts
def _encode_vision_x(self, vision_x: torch.Tensor):
"""
Compute media tokens from vision input by passing it through vision encoder and conditioning language model.
Args:
vision_x: Vision input
shape (B, T_img, F, C, H, W)
Images in the same chunk are collated along T_img, and frames are collated along F
Currently only F=1 is supported (single-frame videos)
rearrange code based on https://github.com/dhansmair/flamingo-mini
"""
assert vision_x.ndim == 6, "vision_x should be of shape (b, T_img, F, C, H, W)"
b, T, F = vision_x.shape[:3]
vision_x = rearrange(vision_x, "b T F c h w -> (b T F) c h w")
with torch.no_grad():
if self.vision_encoder.__class__.__name__ == "TimmModel":
vision_x = self.vision_encoder.trunk.forward_features(vision_x)
elif self.vision_encoder.__class__.__name__ in ['CLIPVisionModel', 'SiglipVisionTransformer']:
vision_x = self.vision_encoder(vision_x).last_hidden_state
else:
vision_x = self.vision_encoder(vision_x)[1] # OpenCLIP returns tuples
vision_x = rearrange(vision_x, "(b T F) v d -> b T F v d", b=b, T=T, F=F)
return vision_x
def _concat_vision_cache(
self, lang_x, vision_tokens, past_vision_tokens, past_media_locations, use_cache
):
"""
Helper function to include the past vision tokens and past media locations in the output.
"""
if use_cache:
if past_media_locations is not None and past_vision_tokens is not None:
if vision_tokens is not None:
updated_vision_tokens = torch.cat(
[
past_vision_tokens,
vision_tokens,
],
dim=1,
)
else:
updated_vision_tokens = past_vision_tokens
updated_media_locations = torch.cat(
[
past_media_locations,
lang_x == self.media_token_id,
],
dim=1,
)
else:
updated_vision_tokens = vision_tokens
updated_media_locations = lang_x == self.media_token_id
else:
updated_vision_tokens = None
updated_media_locations = None
return updated_vision_tokens, updated_media_locations
def generate(
self,
vision_x: torch.Tensor,
lang_x: torch.Tensor,
attention_mask: torch.Tensor = None,
past_key_values: Optional[
List[Union[torch.Tensor, Tuple[torch.Tensor]]]
] = None,
past_media_locations: Optional[torch.Tensor] = None,
past_vision_tokens: Optional[torch.Tensor] = None,
**kwargs,
):
"""
Generate text conditioned on vision and language inputs.
Args:
vision_x (torch.Tensor): Vision input
shape (B, T_img, F, C, H, W)
see documentation for forward
lang_x (torch.Tensor): Language input
shape (B, T_txt)
attention_mask (torch.Tensor, optional): Attention mask. Defaults to None.
**kwargs: see generate documentation in Hugging Face CausalLM models.
Returns:
torch.Tensor: lang_x with generated tokens appended to it
"""
num_beams = kwargs.pop("num_beams", 1)
# convert pixels to vision tokens
if vision_x is not None:
vision_features = self._encode_vision_x(vision_x=vision_x)
vision_tokens = self.vision_tokenizer(vision_features)
else:
vision_tokens = None
# fuse the vision and language tokens
# for xattn, vision_x and media_location are repeat_interleaved s.t.
# the total batch size is B * num_beams
new_inputs = self._prepare_inputs_for_forward(
vision_tokens=vision_tokens,
lang_x=lang_x,
attention_mask=attention_mask,
past_key_values=past_key_values,
past_media_locations=past_media_locations,
past_vision_tokens=past_vision_tokens,
padding_side="left",
num_beams=num_beams,
)
output = self.lang_model.generate(
**new_inputs,
past_key_values=past_key_values,
num_beams=num_beams,
use_cache=True,
**kwargs,
)
self._post_forward_hook()
return output
@property
def num_trainable_params(self):
"""Print the number of trainable parameters"""
return num_params(self, filter_to_trainable=True)
def set_trainable(self):
"""
Freeze appropriate parameters in the model.
"""
raise NotImplementedError
def group_params_by_weight_decay(self):
"""
Return a tuple of (params to optimize w/ weight decay, params to optimize w/o weight decay)
"""
params_with_wd, params_without_wd = [], []
for n, p in self.named_parameters():
if p.requires_grad:
if self._should_apply_weight_decay(n):
params_with_wd.append(p)
else:
params_without_wd.append(p)
return params_with_wd, params_without_wd
def _should_apply_weight_decay(self, parameter_name):
"""
Return whether weight decay should be applied to a parameter.
"""
raise NotImplementedError
@property
def special_tokens(self):
"""
Returns a dict mapping from the attribute name of a special token to its string format,
e.g. "media_token": "<image>"
"""
assert (
"media_token" in self._special_tokens
), "VLMs need to request that the tokenizer add a media_token and call set_special_token_ids to set self.media_token_id"
return self._special_tokens
@property
def special_token_ids(self):
"""
Returns a list of the special token ids
"""
return [getattr(self, f"{att_name}_id") for att_name in self.special_tokens]
def set_special_token_ids(self, string_to_ids):
"""
Args:
string_to_ids (dict): mapping from token string to id
"""
assert set(self.special_tokens.values()).issubset(set(string_to_ids.keys()))
for att_name, token_str in self.special_tokens.items():
token_id = string_to_ids[token_str]
setattr(self, f"{att_name}_id", token_id)
setattr(self.lang_model, f"{att_name}_id", token_id)
def init_gradient_checkpointing(self):
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
checkpoint_wrapper,
CheckpointWrapper,
CheckpointImpl,
apply_activation_checkpointing,
)
from functools import partial
non_reentrant_wrapper = partial(
checkpoint_wrapper,
checkpoint_impl=CheckpointImpl.NO_REENTRANT,
)
apply_activation_checkpointing(
self,
checkpoint_wrapper_fn=non_reentrant_wrapper,
check_fn=lambda m: getattr(m, "_use_gradient_checkpointing", False)
and not isinstance(m, CheckpointWrapper),
)
class VLMWithCrossAttention(VLM):
"""
VLM using cross-attention to fuse vision and language tokens.
"""
def __init__(
self,
vision_encoder: nn.Module,
vision_tokenizer: nn.Module,
lang_model: nn.Module,
initial_tokenizer_len: int,
pad_token_id: int,
gradient_checkpointing: bool = False,
decoder_layers_attr_name: str = None,
cross_attn_every_n_layers: int = None,
):
extend_instance(lang_model, CrossAttentionMixin)
super().__init__(
vision_encoder=vision_encoder,
vision_tokenizer=vision_tokenizer,
lang_model=lang_model,
initial_tokenizer_len=initial_tokenizer_len,
pad_token_id=pad_token_id,
gradient_checkpointing=gradient_checkpointing,
)
self.lang_model.set_decoder_layers_attr_name(decoder_layers_attr_name)
self.decoder_layers_attr_name = decoder_layers_attr_name
self.lang_model.init_cross_attention_layers(
lang_hidden_size=self.lang_hidden_dim,
vis_hidden_size=self.vis_embedding_dim,
cross_attn_every_n_layers=cross_attn_every_n_layers,
gradient_checkpointing=gradient_checkpointing,
)
def _prepare_inputs_for_forward(
self,
vision_tokens: torch.Tensor,
lang_x: torch.Tensor,
attention_mask: torch.Tensor,
labels: torch.Tensor = None,
past_key_values=None,
past_media_locations: torch.Tensor = None,
past_vision_tokens: torch.Tensor = None,
padding_side: str = "right", # noop for cross-attention models
num_beams: int = 1,
):
"""Each xattn layer needs to save the vision tokens and the locations of the media tokens in the language sequence"""
self.lang_model._condition_media_before_forward(
input_ids=lang_x,
vision_tokens=vision_tokens,
past_media_locations=past_media_locations,
past_vision_tokens=past_vision_tokens,
num_beams=num_beams,
)
if past_key_values is not None:
past_key_values = [
(
k.repeat_interleave(num_beams, dim=0),
v.repeat_interleave(num_beams, dim=0)
)
for k, v in past_key_values
]
return {
"input_ids": lang_x,
"attention_mask": attention_mask,
"labels": labels,
}
def _postprocess_outputs_from_forward(
self,
output: CausalLMOutputWithPast,
lang_x: torch.Tensor,
vision_tokens: torch.Tensor,
past_vision_tokens: torch.Tensor,
past_media_locations: torch.Tensor,
use_cache: bool = False,
):
"""Include the past vision tokens and past media locations in the output"""
updated_vision_tokens, updated_media_locations = self._concat_vision_cache(
lang_x=lang_x,
vision_tokens=vision_tokens,
past_vision_tokens=past_vision_tokens,
past_media_locations=past_media_locations,
use_cache=use_cache,
)
output = VLMOutputWithPast(
loss=output.loss,
logits=output.logits,
past_key_values=output.past_key_values,
hidden_states=output.hidden_states,
attentions=output.attentions,
past_media_locations=updated_media_locations,
past_vision_tokens=updated_vision_tokens,
)
return output
def _post_forward_hook(self):
# clear the conditioned layers
self.lang_model.clear_conditioned_layers()
def get_fsdp_lambda_fn(self):
"""
Returns the lambda function used to decide how to perform FSDP wrapping.
"""
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
CheckpointWrapper,
)
from .helpers import GatedCrossAttentionBlock
decoder_block_class = getattr_recursive(
self.lang_model, self.decoder_layers_attr_name
)[0].__class__
def lambda_fn(module: nn.Module):
# we want FSDP(ckpt(module)), not ckpt(FSDP(module))
if getattr(module, "_use_gradient_checkpointing", False) and not isinstance(
module, CheckpointWrapper
):
return False
if module is self.vision_tokenizer:
return True
if isinstance(module, GatedCrossAttentionBlock):
return True
if isinstance(module, decoder_block_class):
return True
return lambda_fn
@property
def num_params_per_module(self):
"""Print the number of parameters per module in the model"""
num_xattn_params = num_params(self.lang_model.gated_cross_attn_layers)
return "\n".join(
[
f"Vision encoder: {num_params(self.vision_encoder):,} parameters",
f"Vision tokenizer: {num_params(self.vision_tokenizer):,} parameters",
f"Cross attention: {num_xattn_params:,} parameters",
f"Language model: {num_params(self.lang_model) - num_xattn_params:,} parameters",
]
)
@property
def num_trainable_params_per_module(self):
"""Print the number of trainable parameters per module in the model"""
num_xattn_params = num_params(
self.lang_model.gated_cross_attn_layers, filter_to_trainable=True
)
return "\n".join(
[
f"Vision encoder: {num_params(self.vision_encoder, filter_to_trainable=True):,} trainable parameters",
f"Vision tokenizer: {num_params(self.vision_tokenizer, filter_to_trainable=True):,} trainable parameters",
f"Cross attention: {num_xattn_params:,} trainable parameters",
f"Language model: {num_params(self.lang_model, filter_to_trainable=True) - num_xattn_params:,} trainable parameters",
]
)
class VLMWithLanguageStream(VLM):
"""
VLM that fuses modalities by inserting vision tokens directly into the language stream.
"""
def __init__(
self,
vision_encoder: nn.Module,
vision_tokenizer: nn.Module,
lang_model: nn.Module,
initial_tokenizer_len: int,
pad_token_id: int,
decoder_layers_attr_name: str = None,
gradient_checkpointing: bool = False,
base_img_size: Optional[int] = None,
):
super().__init__(
vision_encoder=vision_encoder,
vision_tokenizer=vision_tokenizer,
lang_model=lang_model,
initial_tokenizer_len=initial_tokenizer_len,
pad_token_id=pad_token_id,
base_img_size=base_img_size,
gradient_checkpointing=gradient_checkpointing,
)
self.decoder_layers_attr_name = decoder_layers_attr_name
for block in getattr_recursive(self.lang_model, self.decoder_layers_attr_name):
block._use_gradient_checkpointing = gradient_checkpointing
def _prepare_inputs_for_forward(
self,
vision_tokens: torch.Tensor,
lang_x: torch.Tensor,
attention_mask: torch.Tensor,
labels: torch.Tensor = None,
past_key_values=None,
vision_attention_mask: Optional[torch.Tensor] = None,
past_media_locations: torch.Tensor = None,
past_vision_tokens: torch.Tensor = None,
padding_side: str = "left",
num_beams: int = 1,
):
"""
Insert the vision tokens directly into the language stream/
This requires us to modify the input_ids, attention_mask, and labels.
"""
if past_key_values is not None:
past_len = past_key_values[0][0].shape[2]
assert attention_mask.shape[1] == past_len + lang_x.shape[1], (
"Attention_mask must be as long as the entire past len (including image tokens) and current input IDs. "
+ "Check that you've expanded the attention mask to account for past image tokens."
)
if vision_tokens is None:
return {
"input_ids": lang_x,
"attention_mask": attention_mask,
"labels": labels,
}
# get the language embeddings
lang_embeds = self.lang_model.get_input_embeddings()(lang_x)
# build up the multimodal embeddings
B = lang_x.shape[0]
has_labels = labels is not None
multimodal_embeds = []
multimodal_attention_mask = []
multimodal_labels = [] if has_labels else None
for i in range(B):
# get index of <image> tokens in lang_x[i]
image_token_idxs = torch.where(lang_x[i] == self.media_token_id)[0]
if len(image_token_idxs) == 0:
multimodal_embeds.append(lang_embeds[i].clone())
multimodal_attention_mask.append(attention_mask[i].clone())
if has_labels:
multimodal_labels.append(labels[i].clone())
continue
# since an image is represented by self.num_tokens_per_vis tokens, we need to offset the image_token_idxs
# loop through the image_token_idxs and insert the vision tokens
new_embed = lang_embeds[i].clone()
new_attention_mask = (
attention_mask[i].clone() if attention_mask is not None else None
)
if has_labels:
new_label = labels[i].clone()
for img_num in range(len(image_token_idxs)):
img_idx = image_token_idxs[img_num]
# Get vision token attention mask for padded llava-style any resolution image tokens.
if self.image_aspect_ratio =='anyres':
num_vis_tokens = vision_tokens[i][img_num].shape[0]
if vision_attention_mask is not None:
vis_attention_mask = vision_attention_mask[i][img_num]
else:
vis_attention_mask = torch.ones(
num_vis_tokens, dtype=torch.long
).to(attention_mask.device)
else:
assert (
vision_tokens[i][img_num].shape[0] == self.num_tokens_per_vis
), f"vision token number mismatch: image embedding ({vision_tokens[i][img_num].shape[0]}) \
vs. model.num_tokens_per_vis ({self.num_tokens_per_vis})"
# By default, vision tokens are not padded.
num_vis_tokens = self.num_tokens_per_vis
vis_attention_mask = torch.ones(
num_vis_tokens, dtype=torch.long
).to(attention_mask.device)
# Offset the rest of image tokens with current num_vis_tokens
for j in range(img_num+1, len(image_token_idxs)):
image_token_idxs[j] += (num_vis_tokens - 1)
new_embed = torch.cat(
(
new_embed[:img_idx],
vision_tokens[i][img_num],
new_embed[img_idx + 1 :],
),
dim=0,
)
new_attention_mask = torch.cat(
(
new_attention_mask[:img_idx],
vis_attention_mask,
new_attention_mask[img_idx + 1 :],
),
dim=0,
)
if has_labels:
new_label = torch.cat(
(
new_label[:img_idx],
torch.ones(num_vis_tokens, dtype=torch.long).to(
labels.device
)
* -100,
new_label[img_idx + 1 :],
),
dim=0,
)
multimodal_embeds.append(new_embed)
multimodal_attention_mask.append(new_attention_mask)
if has_labels:
multimodal_labels.append(new_label)
# stack
multimodal_embeds = stack_with_padding(
multimodal_embeds,
padding_value=self.pad_token_id,
padding_side=padding_side,
)
multimodal_attention_mask = stack_with_padding(
multimodal_attention_mask,
padding_value=0,
padding_side=padding_side,
)
if has_labels:
multimodal_labels = stack_with_padding(
multimodal_labels,
padding_value=-100,
padding_side=padding_side,
)
return {
"inputs_embeds": multimodal_embeds,
"attention_mask": multimodal_attention_mask,
"labels": multimodal_labels,
}
def _postprocess_outputs_from_forward(
self,
output: CausalLMOutputWithPast,
lang_x: torch.Tensor,
vision_tokens: torch.Tensor,
past_vision_tokens: torch.Tensor,
past_media_locations: torch.Tensor,
use_cache: bool = False,
):
# Include the past vision tokens and past media locations in the output
updated_vision_tokens, updated_media_locations = self._concat_vision_cache(
lang_x=lang_x,
vision_tokens=vision_tokens,
past_vision_tokens=past_vision_tokens,
past_media_locations=past_media_locations,
use_cache=use_cache,
)
# return logits that are the same shape as the original input_ids
logits = output.logits
batch_logits = []
B, T_txt = lang_x.shape
for i in range(B):
sequence_logits = []
logits_j = 0
img_id = 0
for j in range(T_txt):
if lang_x[i, j] != self.media_token_id:
sequence_logits.append(logits[i, logits_j])
logits_j += 1
else:
# append the logit for the first image token, then skip over the rest
# note: the model actually learns to predict <im_patch>, not <image>
sequence_logits.append(logits[i, logits_j])
# logits_j += self.num_tokens_per_vis
# Offset in account of dynamic num_vis_tokens.
logits_j += vision_tokens[i][img_id].shape[0]
img_id += 1
sequence_logits = torch.stack(sequence_logits, dim=0) # (B, vocab_size)
batch_logits.append(sequence_logits)
batch_logits = torch.stack(batch_logits, dim=0) # (B, T_txt, vocab_size)
# The final logits shape should be the same as the original input_ids shape
assert batch_logits.shape[:2] == (B, T_txt)
# assemble the output
output = VLMOutputWithPast(
loss=output.loss,
logits=batch_logits,
past_key_values=output.past_key_values,
hidden_states=output.hidden_states,
attentions=output.attentions,
past_media_locations=updated_media_locations,
past_vision_tokens=updated_vision_tokens,
)
return output
def _post_forward_hook(self):
pass
def get_fsdp_lambda_fn(self):
"""
Returns the lambda function used to decide how to perform FSDP wrapping.
"""
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
CheckpointWrapper,
)
decoder_block_class = getattr_recursive(
self.lang_model, self.decoder_layers_attr_name
)[0].__class__
def lambda_fn(module: nn.Module):
if getattr(module, "_use_gradient_checkpointing", False) and not isinstance(
module, CheckpointWrapper
):
return False
if module is self.vision_tokenizer:
return True
if isinstance(module, decoder_block_class):
return True
return lambda_fn
def get_fsdp_wrapping_policy(self):
"""
Returns the policy used to decide how to perform FSDP wrapping.
"""
from torch.distributed.fsdp.wrap import _or_policy, _module_wrap_policy, transformer_auto_wrap_policy
from open_clip.transformer import VisionTransformer, ResidualAttentionBlock
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
from transformers.models.phi.modeling_phi import PhiDecoderLayer
# for Phi-3 hot fiix
try:
import importlib
commit_hash = str(type(self.lang_model)).split('instruct.')[1].split('.modeling')[0]
module_name = f"transformers_modules.microsoft.Phi-3-mini-128k-instruct.{commit_hash}.modeling_phi3"
module = importlib.import_module(module_name)
Phi3DecoderLayer = module.Phi3DecoderLayer
import_phi3 = True
except IndexError:
import_phi3 = False
# hard code the wrap module name
# vision
if isinstance(self.vision_encoder, SiglipVisionModel):
from transformers import SiglipVisionModel
vit_wrap_policy = functools.partial(_module_wrap_policy, module_classes={SiglipVisionModel})
from transformers.models.siglip.modeling_siglip import SiglipEncoderLayer, SiglipVisionTransformer, SiglipVisionEmbeddings, SiglipMultiheadAttentionPoolingHead
# import torch.nn.LayerNorm as LayerNorm
transformer_layer_cls_vit = {SiglipEncoderLayer, SiglipVisionTransformer, SiglipVisionEmbeddings, SiglipMultiheadAttentionPoolingHead}
vision_transformer_block_policy = functools.partial(transformer_auto_wrap_policy, transformer_layer_cls=transformer_layer_cls_vit)
vision_wrap_policy = functools.partial(_or_policy, policies=[vit_wrap_policy, vision_transformer_block_policy])
else:
vit_wrap_policy = functools.partial(_module_wrap_policy, module_classes={VisionTransformer, TimmModel})
# vit_wrap_policy = functools.partial(_module_wrap_policy, module_classes={VisionTransformer})
# transformer_layer_cls_vit = {ResidualAttentionBlock}
transformer_layer_cls_vit = {ResidualAttentionBlock, Block}
# transformer_layer_cls_vit = {Block}
vision_transformer_block_policy = functools.partial(transformer_auto_wrap_policy, transformer_layer_cls=transformer_layer_cls_vit)
vision_wrap_policy = functools.partial(_or_policy, policies=[vit_wrap_policy, vision_transformer_block_policy])
# llm
transformer_layer_cls={LlamaDecoderLayer, PhiDecoderLayer}
if import_phi3:
transformer_layer_cls.add(Phi3DecoderLayer)
llm_transformer_block_policy = functools.partial(transformer_auto_wrap_policy, transformer_layer_cls=transformer_layer_cls)
# vision_tokenizer
vis_tokenizer_policy = functools.partial(_module_wrap_policy, module_classes={LinearPatchProjection, PerceiverResampler})
return functools.partial(
_or_policy,
policies = [
vision_wrap_policy,
llm_transformer_block_policy,
vis_tokenizer_policy
])
@property
def num_params_per_module(self):
"""Print the number of parameters per module in the model"""
return "\n".join(
[
f"Vision encoder: {num_params(self.vision_encoder):,} parameters",
f"Vision tokenizer: {num_params(self.vision_tokenizer):,} parameters",
f"Language model: {num_params(self.lang_model):,} parameters",
]
)
@property
def num_trainable_params_per_module(self):
"""Print the number of trainable parameters per module in the model"""
return "\n".join(
[
f"Vision encoder: {num_params(self.vision_encoder, filter_to_trainable=True):,} trainable parameters",
f"Vision tokenizer: {num_params(self.vision_tokenizer, filter_to_trainable=True):,} trainable parameters",
f"Language model: {num_params(self.lang_model, filter_to_trainable=True):,} trainable parameters",
]
)
import torch
from einops import rearrange
from torch import nn
from typing import List, Optional, Tuple, Union
import os
from .helpers import PerceiverResampler
from .vlm import VLMWithLanguageStream
class XGenMMPerceiver(VLMWithLanguageStream):
def __init__(
self,
vision_encoder: nn.Module,
lang_model: nn.Module,
vis_feature_dim: int,
initial_tokenizer_len: int,
pad_token_id: int,
decoder_layers_attr_name: str = None,
gradient_checkpointing: bool = False,
base_img_size: Optional[int] = None,
image_aspect_ratio: str = 'anyres',
anyres_patch_sampling: bool = True,
num_vision_tokens: int = 128,
):
"""
Args:
vision_encoder (nn.Module): HF CLIPModel
lang_encoder (nn.Module): HF causal language model
vis_feature_dim (int): final dimension of the visual features outputted by the vision_encoder
initial_tokenizer_len (int): size of the tokenizer vocab
padding_token_id (int): id of the padding token. None if no padding token; then a padding token
will be inserted into self.special_tokens, which factory.py fills after creating new tokens
decoder_layers_attr_name (str, optional): name of the decoder layers attribute. Defaults to None.
gradient_checkpointing (bool, optional): whether to use gradient checkpointing. Defaults to False.
"""
self._special_tokens = {
"media_token": "<image>",
"image_placeholder_token": "<image placeholder>",
"end_of_trunk_token": "<|endofchunk|>",
}
lang_embedding_dim = lang_model.get_input_embeddings().weight.shape[1]
super().__init__(
vision_encoder=vision_encoder,
vision_tokenizer=PerceiverResampler(
dim=vis_feature_dim, dim_inner=lang_embedding_dim,
num_latents=num_vision_tokens,
),
lang_model=lang_model,
initial_tokenizer_len=initial_tokenizer_len,
gradient_checkpointing=gradient_checkpointing,
base_img_size=base_img_size,
decoder_layers_attr_name=decoder_layers_attr_name,
pad_token_id=pad_token_id,
)
self.image_aspect_ratio = image_aspect_ratio
self.anyres_patch_sampling = anyres_patch_sampling
self.anyres_grids = None
def set_trainable(self):
"""
Unfreeze everything except the vision_encoder
"""
self.requires_grad_(True)
self.vision_encoder.requires_grad_(False)
def _should_apply_weight_decay(self, parameter_name):
"""
Kosmos applies 0.01 weight deacy to everything
"""
return True
def forward(
self,
vision_x: Optional[torch.Tensor],
lang_x: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
image_size: Optional[Tuple] = None,
past_key_values: Optional[
List[Union[torch.Tensor, Tuple[torch.Tensor]]]
] = None,
past_media_locations: Optional[torch.Tensor] = None,
past_vision_tokens: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = False,
**kwargs,
):
"""
Args:
vision_x: Vision input
shape (B, T_img, F, C, H, W) with F=1
only F = 1 is supported (single-frame videos)
if T_img > the number of media tokens in the corresponding input_ids (lang_x),
only the first number of media tokens in lang_x are used
lang_x: Language input ids, with media tokens denoting where
visual media should be inserted.
shape (B, T_txt)
attention_mask: Attention mask. Defaults to None.
labels: Labels. Defaults to None.
shape (B, T_txt)
past_key_values (Tuple[torch.Tensor]], optional): Past key value pairs for each of the T_txt previous tokens in the language model. Defaults to None.
list of length = number of decoder layers in the LM
exact implementation depends on LM, see Hugging Face docs
past_media_locations (torch.Tensor, optional): boolean mask denoting which of the previous T_txt tokens were media tokens. Defaults to None.
shape (B, T_txt)
past_vision_tokens (torch.Tensor, optional): Previous vision tokens. Defaults to None.
use_cache (Optional[bool], optional): Whether to use cache. Defaults to False.
If True, includes key_values, media_locations, and vision_tokens in the output.
"""
assert not (past_vision_tokens is None) ^ (
past_media_locations is None
), "past_vision_tokens and past_media_locations must both be None or both be not None"
# convert pixels to vision tokens
vision_attention_mask = None
if vision_x is not None:
if self.image_aspect_ratio == 'anyres':
input_dict = dict(image=vision_x, image_size=image_size)
vision_features, vision_attn_masks = self._encode_vision_x_anyres(input_dict, lang_x.device)
else:
vision_features = self._encode_vision_x(vision_x=vision_x)
vision_attn_masks = None
if self.anyres_patch_sampling:
split_sizes = [feature.shape[0] for feature in vision_features]
# Nested splits for multi-image samples.
if isinstance(vision_x[0], list):
nt_images = [len(images) for images in vision_x]
split_split_sizes = []
img_id = 0
for nt in nt_images:
split_split_sizes.append(split_sizes[img_id:img_id+nt])
img_id += nt
else:
nt_images = [1] * len(vision_x)
split_split_sizes = split_sizes
vision_features = torch.cat(vision_features, dim=0)
vision_features = vision_features[:, None, None, :, :] # Expand dimensions.
vision_attn_masks = torch.cat(vision_attn_masks, dim=0)
vision_tokens = self.vision_tokenizer(vision_features, vision_attn_masks)
# Post-processing: Split the batches into groups of patches and concatenate them together.
if self.anyres_patch_sampling:
# assert isinstance(vision_x, list)
if isinstance(vision_x[0], list):
vision_token_groups = torch.split(vision_tokens, list(sum(nt_img) for nt_img in split_split_sizes), dim=0)
vision_tokens = []
for sample_id, patch_vis_tokens in enumerate(vision_token_groups):
patch_vis_token_groups = torch.split(patch_vis_tokens, split_split_sizes[sample_id], dim=0) # [Np*nt, 1, v, d] -> [[Np_t, 1, v, d], ...]
flatten_vision_tokens = []
# padded_attn_masks = []
for image_vis_token in patch_vis_token_groups:
image_vis_token = image_vis_token.flatten(0, 2) # [Np, 1, v, d] -> [Np*v, d]
flatten_vision_tokens.append(image_vis_token)
vision_tokens_i = flatten_vision_tokens
vision_tokens.append(vision_tokens_i)
else:
vision_token_groups = torch.split(vision_tokens, split_sizes, dim=0)
vision_tokens = []
for patch_vis_tokens in vision_token_groups:
patch_vis_tokens = patch_vis_tokens.flatten(0, 2) # [Np, 1, v, d] -> [Np*v, d]
vision_tokens.append(patch_vis_tokens.unsqueeze(0)) # Add the nt dimension.
else:
vision_tokens = None
# fuse the vision and language tokens
new_inputs = self._prepare_inputs_for_forward(
vision_tokens=vision_tokens,
lang_x=lang_x,
attention_mask=attention_mask,
vision_attention_mask=vision_attention_mask,
labels=labels,
past_key_values=past_key_values,
past_media_locations=past_media_locations,
padding_side="right",
past_vision_tokens=past_vision_tokens,
)
output = self.lang_model(
**new_inputs,
use_cache=use_cache,
past_key_values=past_key_values,
**kwargs,
)
# postforward hooks
self._post_forward_hook()
return output
def generate(
self,
vision_x: torch.Tensor,
lang_x: torch.Tensor,
image_size: Optional[Tuple] = None,
attention_mask: torch.Tensor = None,
past_key_values: Optional[
List[Union[torch.Tensor, Tuple[torch.Tensor]]]
] = None,
past_media_locations: Optional[torch.Tensor] = None,
past_vision_tokens: Optional[torch.Tensor] = None,
**kwargs,
):
"""
Generate text conditioned on vision and language inputs.
Args:
vision_x (torch.Tensor): Vision input
shape (B, T_img, F, C, H, W)
see documentation for forward
lang_x (torch.Tensor): Language input
shape (B, T_txt)
attention_mask (torch.Tensor, optional): Attention mask. Defaults to None.
**kwargs: see generate documentation in Hugging Face CausalLM models.
Returns:
torch.Tensor: lang_x with generated tokens appended to it
"""
num_beams = kwargs.pop("num_beams", 1)
# convert pixels to vision tokens
vision_attention_mask = None
if vision_x is not None:
if self.image_aspect_ratio == 'anyres':
input_dict = dict(image=vision_x, image_size=image_size)
vision_features, vision_attn_masks = self._encode_vision_x_anyres(input_dict, lang_x.device)
else:
vision_features = self._encode_vision_x(vision_x=vision_x)
vision_attn_masks = None
if self.anyres_patch_sampling:
split_sizes = [feature.shape[0] for feature in vision_features]
# Nested splits for multi-image samples.
if isinstance(vision_x[0], list):
nt_images = [len(images) for images in vision_x]
split_split_sizes = []
img_id = 0
for nt in nt_images:
split_split_sizes.append(split_sizes[img_id:img_id+nt])
img_id += nt
else:
nt_images = [1] * len(vision_x)
split_split_sizes = split_sizes
vision_features = torch.cat(vision_features, dim=0)
vision_features = vision_features[:, None, None, :, :] # Expand dimensions.
vision_attn_masks = torch.cat(vision_attn_masks, dim=0)
vision_tokens = self.vision_tokenizer(vision_features, vision_attn_masks)
# Post-processing: Split the batches into groups of patches and concatenate them together.
if self.anyres_patch_sampling:
assert isinstance(vision_x, list)
if isinstance(vision_x[0], list):
vision_token_groups = torch.split(vision_tokens, list(sum(nt_img) for nt_img in split_split_sizes), dim=0)
vision_tokens = []
for sample_id, patch_vis_tokens in enumerate(vision_token_groups):
# Pad the image tokens within a sample.
patch_vis_token_groups = torch.split(patch_vis_tokens, split_split_sizes[sample_id], dim=0) # [Np*nt, 1, v, d] -> [[Np_t, 1, v, d], ...]
flatten_vision_tokens = []
for image_vis_token in patch_vis_token_groups:
image_vis_token = image_vis_token.flatten(0, 2) # [Np, 1, v, d] -> [Np*v, d]
flatten_vision_tokens.append(image_vis_token)
vision_tokens_i = flatten_vision_tokens
vision_tokens.append(vision_tokens_i)
else:
# Padding. FIXME: padding here might not be necessary?
vision_token_groups = torch.split(vision_tokens, split_sizes, dim=0)
# Padding.
vision_tokens = []
for patch_vis_tokens in vision_token_groups:
patch_vis_tokens = patch_vis_tokens.flatten(0, 2) # [Np, 1, v, d] -> [Np*v, d]
vision_tokens.append(patch_vis_tokens.unsqueeze(0)) # Add the nt dimension.
else:
vision_tokens = None
# fuse the vision and language tokens
new_inputs = self._prepare_inputs_for_forward(
vision_tokens=vision_tokens,
lang_x=lang_x,
attention_mask=attention_mask,
vision_attention_mask=vision_attention_mask,
past_key_values=past_key_values,
past_media_locations=past_media_locations,
past_vision_tokens=past_vision_tokens,
padding_side="left",
num_beams=num_beams,
)
if past_key_values is not None:
output = self.lang_model.generate(
**new_inputs,
past_key_values=past_key_values,
num_beams=num_beams,
use_cache=True,
**kwargs,
)
else:
output = self.lang_model.generate(
**new_inputs,
num_beams=num_beams,
use_cache=True,
**kwargs,
)
self._post_forward_hook()
return output
# OpenFlamingo Training
We provide efficient data loading and distributed training code.
To train with OpenFlamingo, please ensure your environment matches that of `environment.yml`.
Table of contents:
* [Data](#data)
* [Example commands](#example-training-command)
* [Distributed training](#distributed-training)
## Data
Our codebase uses [WebDataset](https://github.com/webdataset/webdataset) to efficiently load `.tar` files containing image and text sequences. We recommend resampling shards with replacement during training using the `--dataset_resampled` flag.
Supported pretraining datasets
* LAION-2B
* Multimodal C4 (MMC4)
* ChatGPT-generated sequences from OpenFlamingo [technical report](https://arxiv.org/abs/2308.01390)
We plan to add additional datasets in the future, and we welcome contributions! If you'd like to add support for a pretraining dataset, please open a PR.
### LAION-2B Dataset
[LAION-2B](https://arxiv.org/abs/2210.08402) contains 2B web-scraped (image, text) pairs.
We use [img2dataset](https://github.com/rom1504/img2dataset) to download this dataset into tar files.
### Multimodal C4 Dataset
We train on the full version of [Multimodal C4 (MMC4)](https://github.com/allenai/mmc4), which includes 103M documents of web-scraped, interleaved image-text sequences. During training, we truncate sequences to 256 text tokens and six images per sequence.
Our codebase expects `.tar` files containing `.json` files, which include raw images encoded in base64.
We provide scripts to convert MMC4 to this format:
1. Download the MMC4 shards into `.zip` files using [the MMC4-provided scripts](https://github.com/allenai/mmc4/tree/main/scripts) (e.g., `fewer_facesv2.sh`).
2. Download the MMC4 raw images into an image directory using [the MMC4-provided scripts](https://github.com/allenai/mmc4/tree/main/scripts) (e.g., `download_images.py`).
2. Run `scripts/convert_mmc4_to_wds.py` to convert the downloaded items into the expected tar files.
### ChatGPT-generated sequences
A subset of our models (listed below) were also trained on experimental ChatGPT-generated (image, text) sequences, where images are pulled from LAION. The shards containing these sequences can be found at [this CodaLab worksheet](https://worksheets.codalab.org/worksheets/0xdcd888ff7c754ae680c5e038f6ed1d9b). We are unable to distribute raw images in the released shards; images must be pre-downloaded from the urls in the json files and converted to base64 before using this data for training in our codebase.
Models trained with ChatGPT-generated sequences:
* OpenFlamingo-4B-vitl-rpj3b
* OpenFlamingo-4B-vitl-rpj3b-langinstruct
## Example training command
We provide sample Slurm training scripts in `scripts/`. You can also modify the following command:
```
torchrun --nnodes=1 --nproc_per_node=4 train.py \
--lm_path anas-awadalla/mpt-1b-redpajama-200b \
--tokenizer_path anas-awadalla/mpt-1b-redpajama-200b \
--cross_attn_every_n_layers 1 \
--dataset_resampled \
--batch_size_mmc4 32 \
--batch_size_laion 64 \
--train_num_samples_mmc4 125000\
--train_num_samples_laion 250000 \
--loss_multiplier_laion 0.2 \
--workers=4 \
--run_name OpenFlamingo-3B-vitl-mpt1b \
--num_epochs 480 \
--warmup_steps 1875 \
--mmc4_textsim_threshold 0.24 \
--laion_shards "/path/to/shards/shard-{0000..0999}.tar" \
--mmc4_shards "/path/to/shards/shard-{0000..0999}.tar" \
--report_to_wandb
```
*Note: The MPT-1B [base](https://huggingface.co/mosaicml/mpt-1b-redpajama-200b) and [instruct](https://huggingface.co/mosaicml/mpt-1b-redpajama-200b-dolly) modeling code does not accept the `labels` kwarg or compute cross-entropy loss directly within `forward()`, as expected by our codebase. We suggest using a modified version of the MPT-1B models found [here](https://huggingface.co/anas-awadalla/mpt-1b-redpajama-200b) and [here](https://huggingface.co/anas-awadalla/mpt-1b-redpajama-200b-dolly).*
## Distributed training
Our codebase supports distributed training using three frameworks:
* Pytorch's [DistributedDataParallel](https://pytorch.org/docs/stable/torch.nn.parallel.DistributedDataParallel.html). This is the default method used by `train.py`.
* Pytorch's [FullyShardedDataParallel](https://pytorch.org/docs/stable/fsdp.html) (FSDP). Use the `--fsdp` flag.
Note that you should use exactly one of these training methods.
`train/distributed.py` contains utilities to help with setting up distributed training using Slurm / `torchrun`. See example scripts in the `scripts` directory.
### FSDP notes
To use FSDP, make sure to use Pytorch (> 2.0.1).
We support two sharding strategies for FSDP: full sharding (model sharing across all nodes and GPUs) or hybrid sharding (model sharding across GPUs within nodes, data parallel between nodes). The former saves GPU memory; the latter saves on communication costs.
# Adopted from https://github.com/haotian-liu/LLaVA. Below is the original copyright:
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import ast
import math
from PIL import Image
import torch
def unpad_image(tensor, original_size, keep_original_shape=False):
"""
Unpads a PyTorch tensor of a padded and resized image.
Args:
tensor (torch.Tensor): The image tensor, assumed to be in CxHxW format.
original_size (tuple): The original size of the image (height, width).
Returns:
torch.Tensor: The unpadded image tensor.
"""
original_width, original_height = original_size
current_height, current_width = tensor.shape[1:]
original_aspect_ratio = original_width / original_height
current_aspect_ratio = current_width / current_height
if original_aspect_ratio > current_aspect_ratio:
scale_factor = current_width / original_width
new_height = int(original_height * scale_factor)
padding = (current_height - new_height) // 2
if keep_original_shape:
attention_mask = torch.ones((current_height, current_width), device=tensor.device)
attention_mask[:padding, :] = 0
attention_mask[current_height - padding:, :] = 0
return tensor, attention_mask
else:
unpadded_tensor = tensor[:, padding:current_height - padding, :]
return unpadded_tensor, None
else:
scale_factor = current_height / original_height
new_width = int(original_width * scale_factor)
padding = (current_width - new_width) // 2
if keep_original_shape:
attention_mask = torch.ones((current_height, current_width), device=tensor.device)
attention_mask[:, :padding] = 0
attention_mask[:, current_width - padding:] = 0
return tensor, attention_mask
else:
unpadded_tensor = tensor[:, :, padding:current_width - padding]
return unpadded_tensor, None
def select_best_resolution(original_size, possible_resolutions):
"""
Selects the best resolution from a list of possible resolutions based on the original size.
Args:
original_size (tuple): The original size of the image in the format (width, height).
possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].
Returns:
tuple: The best fit resolution in the format (width, height).
"""
original_width, original_height = original_size
best_fit = None
max_effective_resolution = 0
min_wasted_resolution = float('inf')
for width, height in possible_resolutions:
scale = min(width / original_width, height / original_height)
downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
wasted_resolution = (width * height) - effective_resolution
if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution):
max_effective_resolution = effective_resolution
min_wasted_resolution = wasted_resolution
best_fit = (width, height)
return best_fit
def resize_and_pad_image(image, target_resolution):
"""
Resize and pad an image to a target resolution while maintaining aspect ratio.
Args:
image (PIL.Image.Image): The input image.
target_resolution (tuple): The target resolution (width, height) of the image.
Returns:
PIL.Image.Image: The resized and padded image.
"""
original_width, original_height = image.size
target_width, target_height = target_resolution
scale_w = target_width / original_width
scale_h = target_height / original_height
if scale_w < scale_h:
new_width = target_width
new_height = min(math.ceil(original_height * scale_w), target_height)
else:
new_height = target_height
new_width = min(math.ceil(original_width * scale_h), target_width)
# Resize the image
resized_image = image.resize((new_width, new_height))
new_image = Image.new('RGB', (target_width, target_height), (0, 0, 0))
paste_x = (target_width - new_width) // 2
paste_y = (target_height - new_height) // 2
new_image.paste(resized_image, (paste_x, paste_y))
return new_image
def divide_to_patches(image, patch_size):
"""
Divides an image into patches of a specified size.
Args:
image (PIL.Image.Image): The input image.
patch_size (int): The size of each patch.
Returns:
list: A list of PIL.Image.Image objects representing the patches.
"""
patches = []
width, height = image.size
for i in range(0, height, patch_size):
for j in range(0, width, patch_size):
box = (j, i, j + patch_size, i + patch_size)
patch = image.crop(box)
patches.append(patch)
return patches
def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
"""
Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
Args:
image_size (tuple): The size of the input image in the format (width, height).
grid_pinpoints (str): A string representation of a list of possible resolutions.
patch_size (int): The size of each image patch.
Returns:
tuple: The shape of the image patch grid in the format (width, height).
"""
if type(grid_pinpoints) is list:
possible_resolutions = grid_pinpoints
else:
possible_resolutions = ast.literal_eval(grid_pinpoints)
width, height = select_best_resolution(image_size, possible_resolutions)
return width // patch_size, height // patch_size
def process_anyres_image(image, processor, grid_pinpoints):
"""
Process an image with variable resolutions.
Args:
image (PIL.Image.Image): The input image to be processed.
processor: The image processor object.
grid_pinpoints (str): A string representation of a list of possible resolutions.
Returns:
torch.Tensor: A tensor containing the processed image patches.
"""
# FIXME: determine grid_pinpoints from image sizes.
if type(grid_pinpoints) is list:
possible_resolutions = grid_pinpoints
else:
possible_resolutions = ast.literal_eval(grid_pinpoints)
best_resolution = select_best_resolution(image.size, possible_resolutions)
image_padded = resize_and_pad_image(image, best_resolution)
processor_size = processor.transforms[0].size
patches = divide_to_patches(image_padded, processor_size[0])
image_original_resize = image.resize((processor_size[0], processor_size[0]))
image_patches = [image_original_resize] + patches
image_patches = [processor(image_patch)
for image_patch in image_patches]
return torch.stack(image_patches, dim=0)
def expand2square(pil_img, background_color):
width, height = pil_img.size
if width == height:
return pil_img
elif width > height:
result = Image.new(pil_img.mode, (width, width), background_color)
result.paste(pil_img, (0, (width - height) // 2))
return result
else:
result = Image.new(pil_img.mode, (height, height), background_color)
result.paste(pil_img, ((height - width) // 2, 0))
return result
def process_images(images, image_processor, model_cfg):
image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
new_images = []
if image_aspect_ratio == 'pad':
for image in images:
image = expand2square(image, tuple(int(x*255) for x in image_processor.transforms[-1].mean))
image = image_processor(image)
new_images.append(image)
elif image_aspect_ratio in ["anyres", "anyres-legacy"]:
base_img_size = image_processor.transforms[0].size[0]
for image in images:
image = process_anyres_image(image, image_processor, [[base_img_size,base_img_size*2],
[base_img_size*2,base_img_size],
[base_img_size*2,base_img_size*2],
[base_img_size*3,base_img_size],
[base_img_size,base_img_size*3]])
new_images.append(image)
else:
return image_processor(images)
if all(x.shape == new_images[0].shape for x in new_images):
new_images = torch.stack(new_images, dim=0)
return new_images
import dataclasses
from enum import auto, Enum
from typing import List, Tuple
class SeparatorStyle(Enum):
"""Different separator style."""
SINGLE = auto()
TWO = auto()
MPT = auto()
PLAIN = auto()
LLAMA_2 = auto()
PHI_3 = auto()
@dataclasses.dataclass
class Conversation:
"""A class that keeps all conversation history."""
system: str
roles: List[str]
messages: List[List[str]]
offset: int
sep_style: SeparatorStyle = SeparatorStyle.SINGLE
sep: str = "###"
sep2: str = None
version: str = "Unknown"
skip_next: bool = False
def get_prompt(self):
messages = self.messages
if len(messages) > 0 and type(messages[0][1]) is tuple:
messages = self.messages.copy()
init_role, init_msg = messages[0].copy()
init_msg = init_msg[0].replace("<image>", "").strip()
if 'mmtag' in self.version:
messages[0] = (init_role, init_msg)
messages.insert(0, (self.roles[0], "<Image><image></Image>"))
messages.insert(1, (self.roles[1], "Received."))
else:
messages[0] = (init_role, "<image>\n" + init_msg)
if self.sep_style == SeparatorStyle.SINGLE:
ret = self.system + self.sep
for role, message in messages:
if message:
if type(message) is tuple:
message, _, _ = message
ret += role + ": " + message + self.sep
else:
ret += role + ":"
elif self.sep_style == SeparatorStyle.TWO:
seps = [self.sep, self.sep2]
ret = self.system + seps[0]
for i, (role, message) in enumerate(messages):
if message:
if type(message) is tuple:
message, _, _ = message
ret += role + ": " + message + seps[i % 2]
else:
ret += role + ":"
elif self.sep_style == SeparatorStyle.MPT:
ret = self.system + self.sep
for role, message in messages:
if message:
if type(message) is tuple:
message, _, _ = message
ret += role + message + self.sep
else:
ret += role
elif self.sep_style == SeparatorStyle.LLAMA_2:
wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n" if len(msg) > 0 else msg
wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
ret = ""
for i, (role, message) in enumerate(messages):
if i == 0:
assert message, "first message should not be none"
assert role == self.roles[0], "first message should come from user"
if message:
if type(message) is tuple:
message, _, _ = message
if i == 0: message = wrap_sys(self.system) + message
if i % 2 == 0:
message = wrap_inst(message)
ret += self.sep + message
else:
ret += " " + message + " " + self.sep2
else:
ret += ""
ret = ret.lstrip(self.sep)
elif self.sep_style == SeparatorStyle.PLAIN:
seps = [self.sep, self.sep2]
ret = self.system
for i, (role, message) in enumerate(messages):
if message:
if type(message) is tuple:
message, _, _ = message
ret += message + seps[i % 2]
else:
ret += ""
elif self.sep_style == SeparatorStyle.PHI_3:
seps = [self.sep, self.sep2] # []
if self.system != "":
ret = '<|system|>' + '\n' + self.system + '<|end|>' + '\n'
else:
# Phi-3 w/o system prompt.
ret = ""
for i, (role, message) in enumerate(messages):
if message:
if type(message) is tuple:
message, _, _ = message
ret += role + '\n' + message + '<|end|>' + '\n'
else:
ret += '<|assistant|>' + '\n'
else:
raise ValueError(f"Invalid style: {self.sep_style}")
return ret
def append_message(self, role, message):
self.messages.append([role, message])
def get_images(self, return_pil=False):
images = []
for i, (role, msg) in enumerate(self.messages[self.offset:]):
if i % 2 == 0:
if type(msg) is tuple:
import base64
from io import BytesIO
from PIL import Image
msg, image, image_process_mode = msg
if image_process_mode == "Pad":
def expand2square(pil_img, background_color=(122, 116, 104)):
width, height = pil_img.size
if width == height:
return pil_img
elif width > height:
result = Image.new(pil_img.mode, (width, width), background_color)
result.paste(pil_img, (0, (width - height) // 2))
return result
else:
result = Image.new(pil_img.mode, (height, height), background_color)
result.paste(pil_img, ((height - width) // 2, 0))
return result
image = expand2square(image)
elif image_process_mode in ["Default", "Crop"]:
pass
elif image_process_mode == "Resize":
image = image.resize((336, 336))
else:
raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
max_hw, min_hw = max(image.size), min(image.size)
aspect_ratio = max_hw / min_hw
max_len, min_len = 800, 400
shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
longest_edge = int(shortest_edge * aspect_ratio)
W, H = image.size
if longest_edge != max(image.size):
if H > W:
H, W = longest_edge, shortest_edge
else:
H, W = shortest_edge, longest_edge
image = image.resize((W, H))
if return_pil:
images.append(image)
else:
buffered = BytesIO()
image.save(buffered, format="PNG")
img_b64_str = base64.b64encode(buffered.getvalue()).decode()
images.append(img_b64_str)
return images
def to_gradio_chatbot(self):
ret = []
for i, (role, msg) in enumerate(self.messages[self.offset:]):
if i % 2 == 0:
if type(msg) is tuple:
import base64
from io import BytesIO
msg, image, image_process_mode = msg
max_hw, min_hw = max(image.size), min(image.size)
aspect_ratio = max_hw / min_hw
max_len, min_len = 800, 400
shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
longest_edge = int(shortest_edge * aspect_ratio)
W, H = image.size
if H > W:
H, W = longest_edge, shortest_edge
else:
H, W = shortest_edge, longest_edge
image = image.resize((W, H))
buffered = BytesIO()
image.save(buffered, format="JPEG")
img_b64_str = base64.b64encode(buffered.getvalue()).decode()
img_str = f'<img src="data:image/png;base64,{img_b64_str}" alt="user upload image" />'
msg = img_str + msg.replace('<image>', '').strip()
ret.append([msg, None])
else:
ret.append([msg, None])
else:
ret[-1][-1] = msg
return ret
def copy(self):
return Conversation(
system=self.system,
roles=self.roles,
messages=[[x, y] for x, y in self.messages],
offset=self.offset,
sep_style=self.sep_style,
sep=self.sep,
sep2=self.sep2,
version=self.version)
def dict(self):
if len(self.get_images()) > 0:
return {
"system": self.system,
"roles": self.roles,
"messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
"offset": self.offset,
"sep": self.sep,
"sep2": self.sep2,
}
return {
"system": self.system,
"roles": self.roles,
"messages": self.messages,
"offset": self.offset,
"sep": self.sep,
"sep2": self.sep2,
}
conv_vicuna_v0 = Conversation(
system="A chat between a curious human and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
roles=("Human", "Assistant"),
messages=(
("Human", "What are the key differences between renewable and non-renewable energy sources?"),
("Assistant",
"Renewable energy sources are those that can be replenished naturally in a relatively "
"short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
"Non-renewable energy sources, on the other hand, are finite and will eventually be "
"depleted, such as coal, oil, and natural gas. Here are some key differences between "
"renewable and non-renewable energy sources:\n"
"1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
"energy sources are finite and will eventually run out.\n"
"2. Environmental impact: Renewable energy sources have a much lower environmental impact "
"than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
"and other negative effects.\n"
"3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
"have lower operational costs than non-renewable sources.\n"
"4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
"locations than non-renewable sources.\n"
"5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
"situations and needs, while non-renewable sources are more rigid and inflexible.\n"
"6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
"non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
),
offset=2,
sep_style=SeparatorStyle.SINGLE,
sep="###",
)
conv_vicuna_v1 = Conversation(
system="A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions.",
roles=("USER", "ASSISTANT"),
version="v1",
messages=(),
offset=0,
sep_style=SeparatorStyle.TWO,
sep=" ",
sep2="</s>",
)
conv_mistral_instruct = Conversation(
system="",
roles=("USER", "ASSISTANT"),
version="llama_v2",
messages=(),
offset=0,
sep_style=SeparatorStyle.LLAMA_2,
sep="",
sep2="</s>",
)
conv_phi_3 = Conversation(
system="A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions.",
roles=("<|user|>", "<|assistant|>"),
version="phi_3",
messages=(),
offset=0,
sep_style=SeparatorStyle.PHI_3,
sep="<s>",
sep2="<|end|>",
)
default_conversation = conv_vicuna_v1
conv_templates = {
"default": conv_vicuna_v0,
"v0": conv_vicuna_v0,
"v1": conv_vicuna_v1,
"vicuna_v1": conv_vicuna_v1,
"mistral_instruct": conv_mistral_instruct,
"phi_3": conv_phi_3,
}
if __name__ == "__main__":
print(default_conversation.get_prompt())
"""
Util functions for initializing webdataset objects
"""
import ast
import json
import logging
import os
import random
import sys
from dataclasses import dataclass
from multiprocessing import Value
import braceexpand
import webdataset as wds
from torch.utils.data import DataLoader, IterableDataset, get_worker_info
from torch.utils.data.distributed import DistributedSampler
from webdataset.filters import _shuffle
from webdataset.tariterators import (
base_plus_ext,
tar_file_expander,
url_opener,
valid_sample,
)
try:
import horovod.torch as hvd
except ImportError:
hvd = None
class SharedEpoch:
def __init__(self, epoch: int = 0):
self.shared_epoch = Value("i", epoch)
def set_value(self, epoch):
self.shared_epoch.value = epoch
def get_value(self):
return self.shared_epoch.value
@dataclass
class DataInfo:
"""
DataInfo is a dataclass that holds information about a dataset.
"""
name: str
dataloader: DataLoader
batch_size: int
loss_multiplier: int
sampler: DistributedSampler = None
shared_epoch: SharedEpoch = None
def set_epoch(self, epoch):
if self.shared_epoch is not None:
self.shared_epoch.set_value(epoch)
if self.sampler is not None and isinstance(self.sampler, DistributedSampler):
self.sampler.set_epoch(epoch)
def get_dataset_size(shards):
"""
Get the number of samples in a dataset and the number of shards in a dataset
based on the shards list.
Returns None for the number of samples if is undefined.
One can define the number of samples using a sizes.json file in the same directory
or a __len__ file in the same directory.
"""
shards_list = list(braceexpand.braceexpand(shards))
dir_path = os.path.dirname(shards[0])
sizes_filename = os.path.join(dir_path, "sizes.json")
len_filename = os.path.join(dir_path, "__len__")
if os.path.exists(sizes_filename):
sizes = json.load(open(sizes_filename, "r"))
total_size = sum(
[
int(sizes[os.path.basename(shard)])
if os.path.basename(shard) in sizes
else 0
for shard in shards_list
]
)
elif os.path.exists(len_filename):
# FIXME this used to be eval(open(...)) but that seemed rather unsafe
total_size = ast.literal_eval(open(len_filename, "r").read())
else:
total_size = None # num samples undefined
num_shards = len(shards_list)
return total_size, num_shards
def log_and_continue(exn):
"""Call in an exception handler to ignore any exception, issue a warning, and continue."""
if "images in sample" not in repr(exn):
logging.warning(f"Handling webdataset error ({repr(exn)}). Ignoring.")
return True
def group_by_keys_nothrow(
data, keys=base_plus_ext, lcase=True, suffixes=None, handler=None
):
"""Return function over iterator that groups key, value pairs into samples.
:param keys: function that splits the key into key and extension (base_plus_ext)
:param lcase: convert suffixes to lower case (Default value = True)
"""
current_sample = None
for filesample in data:
assert isinstance(filesample, dict)
fname, value = filesample["fname"], filesample["data"]
prefix, suffix = keys(fname)
if prefix is None:
continue
if lcase:
suffix = suffix.lower()
# FIXME webdataset version throws if suffix in current_sample, but we have a potential for
# this happening in the current LAION400m dataset if a tar ends with same prefix as the next
# begins, rare, but can happen since prefix aren't unique across tar files in that dataset
if (
current_sample is None
or prefix != current_sample["__key__"]
or suffix in current_sample
):
if valid_sample(current_sample):
yield current_sample
current_sample = dict(__key__=prefix, __url__=filesample["__url__"])
if suffixes is None or suffix in suffixes:
current_sample[suffix] = value
if valid_sample(current_sample):
yield current_sample
def tarfile_to_samples_nothrow(src, handler=log_and_continue):
# NOTE this is a re-impl of the webdataset impl with group_by_keys that doesn't throw
streams = url_opener(src, handler=handler)
files = tar_file_expander(streams, handler=handler)
samples = group_by_keys_nothrow(files, handler=handler)
return samples
def pytorch_worker_seed(increment=0):
"""get dataloader worker seed from pytorch"""
worker_info = get_worker_info()
if worker_info is not None:
# favour using the seed already created for pytorch dataloader workers if it exists
seed = worker_info.seed
if increment:
# space out seed increments so they can't overlap across workers in different iterations
seed += increment * max(1, worker_info.num_workers)
return seed
# fallback to wds rank based seed
return wds.utils.pytorch_worker_seed()
class detshuffle2(wds.PipelineStage):
def __init__(
self,
bufsize=1000,
initial=100,
seed=0,
epoch=-1,
):
self.bufsize = bufsize
self.initial = initial
self.seed = seed
self.epoch = epoch
def run(self, src):
if isinstance(self.epoch, SharedEpoch):
epoch = self.epoch.get_value()
else:
# NOTE: this is epoch tracking is problematic in a multiprocess (dataloader workers or train)
# situation as different workers may wrap at different times (or not at all).
self.epoch += 1
epoch = self.epoch
rng = random.Random()
if self.seed < 0:
# If seed is negative, we use the worker's seed, this will be different across all nodes/workers
seed = pytorch_worker_seed(epoch)
else:
# This seed to be deterministic AND the same across all nodes/workers in each epoch
seed = self.seed + epoch
rng.seed(seed)
return _shuffle(src, self.bufsize, self.initial, rng)
class ResampledShards2(IterableDataset):
"""An iterable dataset yielding a list of urls."""
def __init__(
self,
urls,
nshards=sys.maxsize,
worker_seed=None,
deterministic=False,
epoch=-1,
):
"""Sample shards from the shard list with replacement.
:param urls: a list of URLs as a Python list or brace notation string
"""
super().__init__()
urls = wds.shardlists.expand_urls(urls)
self.urls = urls
assert isinstance(self.urls[0], str)
self.nshards = nshards
self.rng = random.Random()
self.worker_seed = worker_seed
self.deterministic = deterministic
self.epoch = epoch
def __iter__(self):
"""Return an iterator over the shards."""
if isinstance(self.epoch, SharedEpoch):
epoch = self.epoch.get_value()
else:
# NOTE: this is epoch tracking is problematic in a multiprocess (dataloader workers or train)
# situation as different workers may wrap at different times (or not at all).
self.epoch += 1
epoch = self.epoch
if self.deterministic:
# reset seed w/ epoch if deterministic
if self.worker_seed is None:
# pytorch worker seed should be deterministic due to being init by arg.seed + rank + worker id
seed = pytorch_worker_seed(epoch)
else:
seed = self.worker_seed() + epoch
self.rng.seed(seed)
for _ in range(self.nshards):
yield dict(url=self.rng.choice(self.urls))
"""
Util functions for distributed training and FSDP.
"""
import os
import torch
##################################
# SLURM setup; Credit: open_clip #
##################################
try:
import horovod.torch as hvd
except ImportError:
hvd = None
def is_global_master(args):
return args.rank == 0
def is_local_master(args):
return args.local_rank == 0
def is_master(args, local=False):
return is_local_master(args) if local else is_global_master(args)
def is_using_horovod():
# NOTE w/ horovod run, OMPI vars should be set, but w/ SLURM PMI vars will be set
# Differentiating between horovod and DDP use via SLURM may not be possible, so horovod arg still required...
ompi_vars = ["OMPI_COMM_WORLD_RANK", "OMPI_COMM_WORLD_SIZE"]
pmi_vars = ["PMI_RANK", "PMI_SIZE"]
if all([var in os.environ for var in ompi_vars]) or all(
[var in os.environ for var in pmi_vars]
):
return True
else:
return False
def is_using_distributed():
if "WORLD_SIZE" in os.environ:
return int(os.environ["WORLD_SIZE"]) > 1
if "SLURM_NTASKS" in os.environ:
return int(os.environ["SLURM_NTASKS"]) > 1
return False
def world_info_from_env():
local_rank = 0
for v in (
"LOCAL_RANK",
"MPI_LOCALRANKID",
"SLURM_LOCALID",
"OMPI_COMM_WORLD_LOCAL_RANK",
):
if v in os.environ:
local_rank = int(os.environ[v])
break
global_rank = 0
for v in ("RANK", "PMI_RANK", "SLURM_PROCID", "OMPI_COMM_WORLD_RANK"):
if v in os.environ:
global_rank = int(os.environ[v])
break
world_size = 1
for v in ("WORLD_SIZE", "PMI_SIZE", "SLURM_NTASKS", "OMPI_COMM_WORLD_SIZE"):
if v in os.environ:
world_size = int(os.environ[v])
break
return local_rank, global_rank, world_size
def init_distributed_device(args):
# Distributed training = training on more than one GPU.
# Works in both single and multi-node scenarios.
args.distributed = False
args.world_size = 1
args.rank = 0 # global rank
args.local_rank = 0
if args.horovod:
assert hvd is not None, "Horovod is not installed"
hvd.init()
args.local_rank = int(hvd.local_rank())
args.rank = hvd.rank()
args.world_size = hvd.size()
args.distributed = True
os.environ["LOCAL_RANK"] = str(args.local_rank)
os.environ["RANK"] = str(args.rank)
os.environ["WORLD_SIZE"] = str(args.world_size)
elif is_using_distributed():
if "SLURM_PROCID" in os.environ:
# DDP via SLURM
args.local_rank, args.rank, args.world_size = world_info_from_env()
# SLURM var -> torch.distributed vars in case needed
os.environ["LOCAL_RANK"] = str(args.local_rank)
os.environ["RANK"] = str(args.rank)
os.environ["WORLD_SIZE"] = str(args.world_size)
torch.distributed.init_process_group(
backend=args.dist_backend,
init_method=args.dist_url,
world_size=args.world_size,
rank=args.rank,
)
else:
# DDP via torchrun, torch.distributed.launch
args.local_rank, _, _ = world_info_from_env()
torch.distributed.init_process_group(
backend=args.dist_backend, init_method=args.dist_url
)
args.world_size = torch.distributed.get_world_size()
args.rank = torch.distributed.get_rank()
args.distributed = True
else:
# needed to run on single gpu
torch.distributed.init_process_group(
backend=args.dist_backend,
init_method=args.dist_url,
world_size=1,
rank=0,
)
if torch.cuda.is_available():
if args.distributed and not args.no_set_device_rank:
device = "cuda:%d" % args.local_rank
else:
device = "cuda:0"
torch.cuda.set_device(device)
else:
device = "cpu"
args.device = device
device = torch.device(device)
return device
#####################################
# FSDP util functions #
#####################################
def get_fsdp_mixed_precision_policy(
precision: str,
reduce_param_precision=False,
reduce_communication_precision=True,
reduce_buffer_precision=True,
):
"""
Returns the FSDP mixed precision policy for a given precision.
"""
if "bfloat16" in precision or "bf16" in precision:
cast_dtype = torch.bfloat16
elif precision == "fp16":
cast_dtype = torch.float16
else:
cast_dtype = torch.float32
if cast_dtype == torch.float32:
return None
from torch.distributed.fsdp import MixedPrecision
return MixedPrecision(
param_dtype=cast_dtype if reduce_param_precision else torch.float32,
reduce_dtype=cast_dtype if reduce_communication_precision else torch.float32,
buffer_dtype=cast_dtype if reduce_buffer_precision else torch.float32,
)
def get_fsdp_config(
args,
device_id,
):
"""
Return kwargs for FSDP wrapper.
This includes some hard-coded settings.
"""
# init MixedPrecision
mp_policy = get_fsdp_mixed_precision_policy(
args.precision,
reduce_param_precision=False,
reduce_communication_precision=True,
reduce_buffer_precision=True,
)
# init FSDP
from torch.distributed.fsdp import (
ShardingStrategy,
BackwardPrefetch,
CPUOffload
)
if args.fsdp_sharding_strategy == "full":
sharding_strategy = ShardingStrategy.FULL_SHARD
elif args.fsdp_sharding_strategy == "hybrid":
sharding_strategy = ShardingStrategy.HYBRID_SHARD
elif args.fsdp_sharding_strategy == "shard_grad_op":
sharding_strategy = ShardingStrategy.SHARD_GRAD_OP
elif args.fsdp_sharding_strategy == "hybrid_shard_grad_op":
sharding_strategy = ShardingStrategy._HYBRID_SHARD_ZERO2
elif args.fsdp_sharding_strategy == "no_shard":
sharding_strategy = ShardingStrategy.NO_SHARD
else:
raise ValueError(
f"Invalid sharding strategy: {args.fsdp_sharding_strategy}. Supported: full, hybrid, shard_grad_op, hybrid_shard_grad_op, no_shard"
)
if args.cpu_offload_gradients:
cpu_offload = CPUOffload(offload_params=True)
else:
cpu_offload = None
return dict(
cpu_offload=cpu_offload,
device_id=device_id,
sync_module_states=True, # broadcast loaded ckpt from rank 0 -> all ranks
sharding_strategy=sharding_strategy,
use_orig_params=True,
mixed_precision=mp_policy,
forward_prefetch=True,
backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
limit_all_gathers=True,
)
def get_fsdp_checkpoint_config(args):
"""
Return kwargs for FSDP checkpointing.
"""
from torch.distributed.fsdp import (
FullStateDictConfig,
StateDictType,
)
from torch.distributed.fsdp.api import FullOptimStateDictConfig
# to avoid GPU OOM when loading/saving ckpts, load/save on CPU
# this is slow
return dict(
state_dict_type=StateDictType.FULL_STATE_DICT,
state_dict_config=FullStateDictConfig(rank0_only=True, offload_to_cpu=True),
optim_state_dict_config=FullOptimStateDictConfig(
rank0_only=True, offload_to_cpu=True
),
)
""" Main training script """
import argparse
from datetime import datetime
import os
from omegaconf import OmegaConf
import torch
import wandb
import functools
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy
from open_flamingo import create_model_and_transforms, SUPPORTED_MODEL_FAMILIES
from open_flamingo.train.distributed import (
init_distributed_device,
world_info_from_env,
get_fsdp_config,
get_fsdp_checkpoint_config,
)
from open_flamingo.train.sft_data_utils import make_supervised_data_module
from open_flamingo.train.train_utils import (
finetune_one_epoch,
random_seed,
find_most_recent_checkpoint,
load_checkpoint,
save_checkpoint,
)
from open_flamingo.train.losses import (
SUPPORTED_LOSSES,
get_loss_fn,
)
from transformers import (
get_constant_schedule_with_warmup,
get_cosine_schedule_with_warmup,
get_linear_schedule_with_warmup,
)
def parse_tuple_list(input_string):
try:
tuples = input_string.strip().strip('()').split('),(')
# Convert each item in the list to a tuple
tuple_list = [tuple(map(int, item.split(','))) for item in tuples]
return tuple_list
except Exception as e:
raise argparse.ArgumentTypeError(f"Invalid tuple list format: {input_string}. Error: {e}")
def main():
parser = argparse.ArgumentParser()
# model configuration args
parser.add_argument(
"--model_family", default="kosmos-instruct", type=str, choices=SUPPORTED_MODEL_FAMILIES
)
parser.add_argument("--vision_encoder_path", default="ViT-SO400M-14-SigLIP-384", type=str)
parser.add_argument("--vision_encoder_pretrained", default="webli", type=str)
parser.add_argument("--lm_path", default="facebook/opt-1.3b", type=str)
parser.add_argument(
"--tokenizer_path",
default="facebook/opt-30b",
type=str,
help="path to tokenizer",
)
parser.add_argument(
"--cross_attn_every_n_layers",
type=int,
default=1,
help="how often to add a cross-attention layer after each transformer layer",
)
parser.add_argument(
"--num_vision_tokens",
type=int, default=64, help="number of query tokens used for resampling vision features.",
)
parser.add_argument("--pretrained", type=str, default=None, help="pretrained weights for fine-tuning.")
parser.add_argument("--pretrained_vision_tokenizer", type=str, default=None, help="pretrained vl connector for fine-tuning.")
# training args
parser.add_argument(
"--loss", type=str, choices=SUPPORTED_LOSSES, default="supervised_finetune"
)
parser.add_argument(
"--run_name",
type=str,
default="openflamingo3B",
help="used to name saving directory and wandb run",
)
parser.add_argument(
"--resume_from_checkpoint",
type=str,
help="path to checkpoint to resume from, this should contain model, optimizer, and lr_scheduler states. if there exists a checkpoint in the dir named run_name, we will resume from that checkpoint by default.",
default=None,
)
parser.add_argument(
"--delete_previous_checkpoint",
action="store_true",
help="delete previous checkpoint when saving new checkpoint",
)
parser.add_argument(
"--no_save_optim_state",
action="store_true",
help="do not save optimizer states when saving checkpoints",
)
parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--learning_rate", default=1e-4, type=float)
parser.add_argument(
"--lr_scheduler",
default="constant",
type=str,
help="constant, linear, or cosine",
)
parser.add_argument("--warmup_steps", default=5000, type=int)
parser.add_argument("--weight_decay", default=0.1, type=float)
parser.add_argument(
"--precision",
choices=["amp_bf16", "amp_bfloat16", "bf16", "fp16", "fp32"],
default="fp32",
help="Floating point precision.",
)
parser.add_argument(
"--gradient_checkpointing",
action="store_true",
help="whether to train with gradient/activation checkpointing",
)
parser.add_argument(
"--num_epochs",
type=int,
default=1,
help="we define an 'epoch' as a fixed number of examples specified by train_num_samples, not a pass through the entire dataset",
)
parser.add_argument("--offline", action="store_true")
parser.add_argument(
"--logging_steps", type=int, default=100, help="log loss every n steps"
)
parser.add_argument(
"--checkpoint_steps", type=int, default=5000, help="log loss every n steps"
)
# data args
# TODO: load a data args yaml file
parser.add_argument(
"--data_path",
default="/export/home/LLaVA/playground/data/llava_v1_5_mix665k_ocr_tagged_vqa_placeholder.json",
type=str
)
parser.add_argument("--batch_size", type=int, default=8)
parser.add_argument("--workers", type=int, default=1)
parser.add_argument("--data_sampler_group_by_length", default=False, action="store_true")
# Legacy Llava data args
parser.add_argument("--is_multimodal", type=bool, default=True)
parser.add_argument("--mm_use_im_start_end", default=False, action="store_true")
parser.add_argument("--conv_template_name", type=str, default=None)
# Any resolution
parser.add_argument("--image_aspect_ratio", type=str, default='pad')
parser.add_argument(
"--anyres_patch_sampling",
default=False,
action="store_true",
)
parser.add_argument('--anyres_grids',
type=parse_tuple_list,
default="(1,2),(2,1),(2,2),(3,1),(1,3)",
help="List of tuples in the format (1,2),(3,4),...")
# distributed training args
parser.add_argument(
"--dist-url",
default="env://",
type=str,
help="url used to set up distributed training",
)
parser.add_argument(
"--dist-backend", default="nccl", type=str, help="distributed backend"
)
parser.add_argument(
"--horovod",
default=False,
action="store_true",
help="Use horovod for distributed training.",
)
parser.add_argument(
"--no-set-device-rank",
default=False,
action="store_true",
help="Don't set device index from local rank (when CUDA_VISIBLE_DEVICES restricted to one per proc).",
)
parser.add_argument(
'--local-rank',
default=0,
type=int,
help='Local rank for distributed training'
)
# fsdp args
parser.add_argument(
"--fsdp",
default=False,
action="store_true",
help="Use FullyShardedDataParallel for distributed training. Not supported for some models, e.g. OPT.",
)
parser.add_argument(
"--fsdp_sharding_strategy", default="full", type=str, choices=["full", "hybrid", "shard_grad_op", "hybrid_shard_grad_op", "no_shard"]
)
# wandb args
parser.add_argument("--report_to_wandb", default=False, action="store_true")
parser.add_argument(
"--wandb_project",
type=str,
)
parser.add_argument(
"--wandb_entity",
type=str,
)
parser.add_argument(
"--save_checkpoints_to_wandb",
default=False,
action="store_true",
help="save checkpoints to wandb",
)
parser.add_argument(
"--dryrun",
default=False,
action="store_true",
)
parser.add_argument(
'--use_flash_attention_2',
default=False, action='store_true',
help='Use Flash Attention 2.0 for language model.'
)
parser.add_argument(
'--unfreeze_vision_encoder',
default=False, action='store_true',
help='Unfreeze vision encoder during training.'
)
parser.add_argument(
'--vision_encoder_precision',
default='fp32',
choices=["bf16", "fp32"],
help='Precision of the vision encoder during training.'
)
parser.add_argument(
'--cpu_offload_gradients',
default=False, action='store_true',
help='This specifies whether to offload parameters to CPU when not involved in computation. If True, then this offloads gradients to CPU as well, meaning that the optimizer step runs on CPU.'
)
args = parser.parse_args()
if args.save_checkpoints_to_wandb and not args.report_to_wandb:
raise ValueError("save_checkpoints_to_wandb requires report_to_wandb")
if args.fsdp:
assert (
torch.__version__ > "2.0.1"
), "FSDP requires torch > 2.0.1"
# Set up distributed training
args.local_rank, args.rank, args.world_size = world_info_from_env()
if args.rank == 0:
print(f"Initializing distributed training with {args.world_size} GPUs.")
if args.offline:
os.environ["WANDB_MODE"] = "offline"
os.environ["TRANSFORMERS_OFFLINE"] = "1"
device_id = init_distributed_device(args)
random_seed(args.seed)
# Initialize model
if args.model_family == "flamingo":
additional_kwargs={"cross_attn_every_n_layers": args.cross_attn_every_n_layers}
elif args.model_family in ['xgenmm_v1']:
additional_kwargs = {
"image_aspect_ratio": args.image_aspect_ratio,
"num_vision_tokens": args.num_vision_tokens,
"anyres_patch_sampling": args.anyres_patch_sampling,
}
else:
additional_kwargs = {}
model, image_processor, tokenizer = create_model_and_transforms(
args.vision_encoder_path,
args.vision_encoder_pretrained,
args.lm_path,
args.tokenizer_path if args.tokenizer_path else args.lm_path,
model_family=args.model_family,
pretrained_vision_tokenizer=args.pretrained_vision_tokenizer,
use_local_files=args.offline,
gradient_checkpointing=args.gradient_checkpointing,
verbose=(args.rank == 0),
**additional_kwargs,
)
random_seed(args.seed, args.rank)
# Initialize wandb logging
now = datetime.now().strftime("%Y%m%d%H%M")[:-1]
if args.rank == 0 and args.report_to_wandb:
config = vars(args)
# print("-----------wandb:Config:", config)
wandb.init(
project=args.wandb_project,
name=f"{args.run_name}-{now}",
config=vars(args),
settings=wandb.Settings(init_timeout=120)
)
# Load model checkpoint (on CPU)
if args.fsdp:
args.fsdp_checkpoint_config = get_fsdp_checkpoint_config(args)
# if args do not specify a checkpoint to resume from, resume from most recent checkpoint
resume_from_step = 0
if os.path.exists(f"{args.run_name}") and args.resume_from_checkpoint is None:
args.resume_from_checkpoint = find_most_recent_checkpoint(args)
if (
args.resume_from_checkpoint is not None
):
resume_from_epoch, resume_from_step, checkpoint = load_checkpoint(args, model)
print(f"Resume training from epoch {resume_from_epoch}, step {resume_from_step}...")
else:
resume_from_epoch = 0
resume_from_step = 0
# Load pretrained weights.
if args.resume_from_checkpoint is None and not args.dryrun:
if args.pretrained_vision_tokenizer is None:
assert os.path.exists(args.pretrained), "Must fine-tune from a pretrained weight."
if args.pretrained is not None:
_, _, checkpoint = load_checkpoint(args, model, pretrained=True)
print("Finished loading checkpoint...")
# Initialize gradient checkpointing
if args.gradient_checkpointing:
model.init_gradient_checkpointing()
# Initialize FSDP / DDP, and ensure the model is on GPU
if args.fsdp:
auto_wrap_policy = functools.partial(
lambda_auto_wrap_policy, lambda_fn=model.get_fsdp_lambda_fn()
)
wrapper_kwargs = get_fsdp_config(args, device_id)
distributed_model = FSDP(
model, auto_wrap_policy=auto_wrap_policy, **wrapper_kwargs
)
print("Finished FSDP wrapping...")
else:
model = model.to(device_id)
distributed_model = DDP(model, device_ids=[device_id])
# Initialize optimizer
params_with_wd, params_without_wd = model.group_params_by_weight_decay()
optimizer = torch.optim.AdamW(
[
{"params": params_with_wd, "weight_decay": args.weight_decay},
{"params": params_without_wd, "weight_decay": 0.0},
],
lr=args.learning_rate,
)
# load optimizer checkpoint
if args.resume_from_checkpoint is not None:
optim_state_dict = checkpoint["optimizer_state_dict"]
if args.fsdp:
# FSDP.set_state_dict_type(
# distributed_model,
# **args.fsdp_checkpoint_config,
# )
optim_state_dict = FSDP.optim_state_dict_to_load(
model=distributed_model, optim=optimizer, optim_state_dict=optim_state_dict
)
optimizer.load_state_dict(optim_state_dict)
# Initialize datasets
if args.data_path.split('.')[-1] == 'yaml':
# Loading a mixture of datasets with sampling ratios.
data_config = OmegaConf.load(args.data_path)
if args.rank == 0:
print("================== Data mixture config ===================")
print(data_config)
print("==========================================================")
args.data_path = dict(data_config.data_path)
train_dataset, total_num_samples = make_supervised_data_module(tokenizer=tokenizer,
image_processor=image_processor,
data_args=args)
print("-----------------------")
data_loader = train_dataset.dataloader
for batch in data_loader:
print(batch)
break
print(total_num_samples)
# Update anyres grid.
args.anyres_grids = train_dataset.dataloader.dataset.anyres_grids
model.anyres_grids = args.anyres_grids
# TODO: Summarize training data stats (dataset, portion, etc.)
total_training_steps = (
total_num_samples
// (args.batch_size * args.gradient_accumulation_steps * args.world_size)
) * args.num_epochs
if args.rank == 0:
print(f"Total training steps: {total_training_steps}")
# Initialize lr scheduler
if args.lr_scheduler == "linear":
lr_scheduler = get_linear_schedule_with_warmup(
optimizer,
num_warmup_steps=args.warmup_steps,
num_training_steps=total_training_steps,
)
elif args.lr_scheduler == "cosine":
lr_scheduler = get_cosine_schedule_with_warmup(
optimizer,
num_warmup_steps=args.warmup_steps,
num_training_steps=total_training_steps,
)
else:
lr_scheduler = get_constant_schedule_with_warmup(
optimizer, num_warmup_steps=args.warmup_steps
)
# load lr scheduler checkpoint
if args.resume_from_checkpoint is not None:
lr_scheduler.load_state_dict(checkpoint["lr_scheduler_state_dict"])
# Initialize the loss fn
loss_fn = get_loss_fn(args.loss)
# check wrapping
if args.rank == 0:
print(distributed_model)
# Start training!
print(f"Start running training on rank {args.rank}.")
for epoch in range(resume_from_epoch, args.num_epochs):
train_dataset.set_epoch(epoch)
finetune_one_epoch(
args=args,
resume_from_step=resume_from_step,
model=distributed_model,
epoch=epoch,
dataset=train_dataset,
compute_loss_fn=loss_fn,
tokenizer=tokenizer,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
device_id=device_id,
wandb=wandb,
)
save_checkpoint(distributed_model, optimizer, lr_scheduler, epoch, args)
if __name__ == "__main__":
main()
\ No newline at end of file
from open_flamingo.src.vlm import VLM
import torch
from typing import List, Optional
SUPPORTED_LOSSES = ["next_token_prediction",
"supervised_finetune"]
def get_loss_fn(loss_name):
if loss_name == "next_token_prediction":
return NextTokenPrediction()
elif loss_name == "supervised_finetune":
return SupervisedPrediction()
else:
raise ValueError(
f"Loss {loss_name} not supported. Supported losses: {SUPPORTED_LOSSES}"
)
class Loss:
@property
def name(self):
raise NotImplementedError
def __call__(
self,
model: VLM,
tokenizer,
images: torch.Tensor,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
autocast: callable,
):
"""
Args:
model: VLM model
images: images tensor, already moved to device and cast to appropriate dtype
shape (B, T_img, F, C, H, W)
input_ids: input ids tensor, already moved to device and cast to appropriate dtype
shape (B, T_text)
attention_mask: attention mask tensor, already moved to device and cast to appropriate dtype
shape (B, T_text)
autocast: autocast context manager
Return:
loss: scalar loss
"""
raise NotImplementedError
class NextTokenPrediction(Loss):
@property
def name(self):
return "next_token_prediction"
def __call__(
self,
model: VLM,
tokenizer,
images: torch.Tensor,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
autocast: callable,
):
# set up labels; language model is expected to handle shifting
labels = input_ids.clone()
labels[labels == tokenizer.pad_token_id] = -100
special_token_ids = torch.Tensor(unwrap_model(model).special_token_ids).to(
labels.device
)
labels[torch.isin(labels, special_token_ids)] = -100 # TODO: dont want to remove loss on <|endofchunk|> tokens
labels = labels.to(input_ids.device)
# call forward
with autocast():
loss = model(
vision_x=images,
lang_x=input_ids,
attention_mask=attention_mask,
labels=labels,
)[0]
return loss
class SupervisedPrediction(Loss):
@property
def name(self):
return "supervised_finetune"
def __call__(
self,
model: VLM,
tokenizer,
images: torch.Tensor,
input_ids: torch.Tensor,
labels: torch.Tensor,
attention_mask: torch.Tensor,
autocast: callable,
image_size: Optional[torch.Tensor] = None,
):
# set up labels; language model is expected to handle shifting
labels[labels == tokenizer.pad_token_id] = -100
special_token_ids = torch.Tensor(unwrap_model(model).special_token_ids).to(
labels.device
)
labels[torch.isin(labels, special_token_ids)] = -100 # TODO: dont want to remove loss on <|endofchunk|> tokens
labels = labels.to(input_ids.device)
# call forward
with autocast():
loss = model(
vision_x=images,
image_size=image_size,
lang_x=input_ids,
attention_mask=attention_mask,
labels=labels,
)[0]
return loss
def unwrap_model(model):
"""
Unwrap a model from a DataParallel or DistributedDataParallel wrapper.
"""
if isinstance(
model, (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)
):
return model.module
else:
return model
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment