Commit 1b9205c9 authored by yangzhong

v1.0

import torch
def extend_instance(obj, mixin):
"""Apply mixins to a class instance after creation"""
base_cls = obj.__class__
base_cls_name = obj.__class__.__name__
obj.__class__ = type(
base_cls_name, (mixin, base_cls), {}
) # mixin needs to go first for our forward() logic to work
def hasattr_recursive(obj, att):
"""
Check if obj has nested attribute
Example: hasattr_recursive(obj, 'a.b.c') is equivalent to hasattr(obj, 'a') and hasattr(obj.a, 'b') and hasattr(obj.a.b, 'c')
"""
if att == "":
return True
i = att.find(".")
if i < 0:
return hasattr(obj, att)
else:
try:
return hasattr_recursive(getattr(obj, att[:i]), att[i + 1 :])
except AttributeError:
return False
def getattr_recursive(obj, att):
"""
Return nested attribute of obj
Example: getattr_recursive(obj, 'a.b.c') is equivalent to obj.a.b.c
"""
if att == "":
return obj
i = att.find(".")
if i < 0:
return getattr(obj, att)
else:
return getattr_recursive(getattr(obj, att[:i]), att[i + 1 :])
def setattr_recursive(obj, att, val):
"""
Set nested attribute of obj
Example: setattr_recursive(obj, 'a.b.c', val) is equivalent to obj.a.b.c = val
"""
if "." in att:
obj = getattr_recursive(obj, ".".join(att.split(".")[:-1]))
setattr(obj, att.split(".")[-1], val)
def apply_with_stopping_condition(
module, apply_fn, apply_condition=None, stopping_condition=None, **other_args
):
if stopping_condition(module):
return
if apply_condition(module):
apply_fn(module, **other_args)
for child in module.children():
apply_with_stopping_condition(
child,
apply_fn,
apply_condition=apply_condition,
stopping_condition=stopping_condition,
**other_args
)
def stack_with_padding(list_of_tensors, padding_value=0, padding_side="right"):
"""
Stack a list of tensors with padding on one side
Args:
list_of_tensors (list[torch.Tensor]): List of tensors to stack
padding_value (int, optional): Value to pad with. Defaults to 0.
padding_side (str, optional): Side to pad on. Defaults to "right".
Returns:
torch.Tensor: Stacked tensors
"""
max_tokens = max(tensor.size(0) for tensor in list_of_tensors)
padded_tensors = []
for tensor in list_of_tensors:
num_tokens = tensor.size(0)
if len(tensor.size()) == 1:
padding = torch.full(
(max_tokens - num_tokens,),
padding_value,
dtype=tensor.dtype,
device=tensor.device,
)
else:
padding = torch.full(
(max_tokens - num_tokens, tensor.size(1)),
padding_value,
dtype=tensor.dtype,
device=tensor.device,
)
padded_tensor = (
torch.cat((tensor, padding), dim=0)
if padding_side == "right"
else torch.cat((padding, tensor), dim=0)
)
padded_tensors.append(padded_tensor)
return torch.stack(padded_tensors)
def num_params(module, filter_to_trainable=False):
"""Returns the number of parameters in the module, or optionally only the trainable parameters"""
if filter_to_trainable:
return sum(p.numel() for p in module.parameters() if p.requires_grad)
else:
return sum(p.numel() for p in module.parameters())
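# Illustrative usage sketch (an addition, not part of the original file): a quick,
# hand-checked demo of the padding helper above using arbitrary example tensors.
if __name__ == "__main__":
    seqs = [torch.tensor([1, 2, 3]), torch.tensor([4, 5])]
    right = stack_with_padding(seqs, padding_value=0, padding_side="right")
    # right -> tensor([[1, 2, 3],
    #                  [4, 5, 0]])
    left = stack_with_padding(seqs, padding_value=0, padding_side="left")
    # left -> tensor([[1, 2, 3],
    #                 [0, 4, 5]])
    # num_params counts all parameters: nn.Linear(4, 2) has 4 * 2 weights + 2 biases = 10.
    print(right.shape, left.shape, num_params(torch.nn.Linear(4, 2)))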
import torch
from einops import rearrange
from torch import nn
from typing import List, Optional, Tuple, Union
import os
from .helpers import PerceiverResampler
from .vlm import VLMWithLanguageStream
class XGenMMPerceiver(VLMWithLanguageStream):
def __init__(
self,
vision_encoder: nn.Module,
lang_model: nn.Module,
vis_feature_dim: int,
initial_tokenizer_len: int,
pad_token_id: int,
decoder_layers_attr_name: str = None,
gradient_checkpointing: bool = False,
base_img_size: Optional[int] = None,
image_aspect_ratio: str = 'anyres',
anyres_patch_sampling: bool = True,
num_vision_tokens: int = 128,
):
"""
Args:
vision_encoder (nn.Module): HF CLIPModel
lang_model (nn.Module): HF causal language model
vis_feature_dim (int): final dimension of the visual features output by the vision_encoder
initial_tokenizer_len (int): size of the tokenizer vocab
pad_token_id (int): id of the padding token. None if there is no padding token; in that case, a padding token
will be added to self.special_tokens, which factory.py fills in after creating the new tokens
decoder_layers_attr_name (str, optional): name of the decoder layers attribute. Defaults to None.
gradient_checkpointing (bool, optional): whether to use gradient checkpointing. Defaults to False.
"""
self._special_tokens = {
"media_token": "<image>",
"image_placeholder_token": "<image placeholder>",
"end_of_trunk_token": "<|endofchunk|>",
}
lang_embedding_dim = lang_model.get_input_embeddings().weight.shape[1]
super().__init__(
vision_encoder=vision_encoder,
vision_tokenizer=PerceiverResampler(
dim=vis_feature_dim, dim_inner=lang_embedding_dim,
num_latents=num_vision_tokens,
),
lang_model=lang_model,
initial_tokenizer_len=initial_tokenizer_len,
gradient_checkpointing=gradient_checkpointing,
base_img_size=base_img_size,
decoder_layers_attr_name=decoder_layers_attr_name,
pad_token_id=pad_token_id,
)
self.image_aspect_ratio = image_aspect_ratio
self.anyres_patch_sampling = anyres_patch_sampling
self.anyres_grids = None
def set_trainable(self):
"""
Unfreeze everything except the vision_encoder
"""
self.requires_grad_(True)
self.vision_encoder.requires_grad_(False)
def _should_apply_weight_decay(self, parameter_name):
"""
Kosmos applies 0.01 weight decay to everything
"""
return True
def forward(
self,
vision_x: Optional[torch.Tensor],
lang_x: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
image_size: Optional[Tuple] = None,
past_key_values: Optional[
List[Union[torch.Tensor, Tuple[torch.Tensor]]]
] = None,
past_media_locations: Optional[torch.Tensor] = None,
past_vision_tokens: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = False,
**kwargs,
):
"""
Args:
vision_x: Vision input
shape (B, T_img, F, C, H, W) with F=1
only F = 1 is supported (single-frame videos)
if T_img exceeds the number of media tokens in the corresponding input_ids (lang_x),
only the first images are used (one per media token in lang_x)
lang_x: Language input ids, with media tokens denoting where
visual media should be inserted.
shape (B, T_txt)
attention_mask: Attention mask. Defaults to None.
labels: Labels. Defaults to None.
shape (B, T_txt)
past_key_values (Tuple[torch.Tensor]], optional): Past key value pairs for each of the T_txt previous tokens in the language model. Defaults to None.
list of length = number of decoder layers in the LM
exact implementation depends on LM, see Hugging Face docs
past_media_locations (torch.Tensor, optional): boolean mask denoting which of the previous T_txt tokens were media tokens. Defaults to None.
shape (B, T_txt)
past_vision_tokens (torch.Tensor, optional): Previous vision tokens. Defaults to None.
use_cache (Optional[bool], optional): Whether to use cache. Defaults to False.
If True, includes key_values, media_locations, and vision_tokens in the output.
"""
assert not (past_vision_tokens is None) ^ (
past_media_locations is None
), "past_vision_tokens and past_media_locations must both be None or both be not None"
# convert pixels to vision tokens
vision_attention_mask = None
if vision_x is not None:
if self.image_aspect_ratio == 'anyres':
input_dict = dict(image=vision_x, image_size=image_size)
vision_features, vision_attn_masks = self._encode_vision_x_anyres(input_dict, lang_x.device)
else:
vision_features = self._encode_vision_x(vision_x=vision_x)
vision_attn_masks = None
if self.anyres_patch_sampling:
split_sizes = [feature.shape[0] for feature in vision_features]
# Nested splits for multi-image samples.
if isinstance(vision_x[0], list):
nt_images = [len(images) for images in vision_x]
split_split_sizes = []
img_id = 0
for nt in nt_images:
split_split_sizes.append(split_sizes[img_id:img_id+nt])
img_id += nt
else:
nt_images = [1] * len(vision_x)
split_split_sizes = split_sizes
vision_features = torch.cat(vision_features, dim=0)
vision_features = vision_features[:, None, None, :, :] # Expand dimensions.
vision_attn_masks = torch.cat(vision_attn_masks, dim=0)
vision_tokens = self.vision_tokenizer(vision_features, vision_attn_masks)
# Post-processing: Split the batches into groups of patches and concatenate them together.
if self.anyres_patch_sampling:
# assert isinstance(vision_x, list)
if isinstance(vision_x[0], list):
vision_token_groups = torch.split(vision_tokens, list(sum(nt_img) for nt_img in split_split_sizes), dim=0)
vision_tokens = []
for sample_id, patch_vis_tokens in enumerate(vision_token_groups):
patch_vis_token_groups = torch.split(patch_vis_tokens, split_split_sizes[sample_id], dim=0) # [Np*nt, 1, v, d] -> [[Np_t, 1, v, d], ...]
flatten_vision_tokens = []
# padded_attn_masks = []
for image_vis_token in patch_vis_token_groups:
image_vis_token = image_vis_token.flatten(0, 2) # [Np, 1, v, d] -> [Np*v, d]
flatten_vision_tokens.append(image_vis_token)
vision_tokens_i = flatten_vision_tokens
vision_tokens.append(vision_tokens_i)
else:
vision_token_groups = torch.split(vision_tokens, split_sizes, dim=0)
vision_tokens = []
for patch_vis_tokens in vision_token_groups:
patch_vis_tokens = patch_vis_tokens.flatten(0, 2) # [Np, 1, v, d] -> [Np*v, d]
vision_tokens.append(patch_vis_tokens.unsqueeze(0)) # Add the nt dimension.
else:
vision_tokens = None
# fuse the vision and language tokens
new_inputs = self._prepare_inputs_for_forward(
vision_tokens=vision_tokens,
lang_x=lang_x,
attention_mask=attention_mask,
vision_attention_mask=vision_attention_mask,
labels=labels,
past_key_values=past_key_values,
past_media_locations=past_media_locations,
padding_side="right",
past_vision_tokens=past_vision_tokens,
)
output = self.lang_model(
**new_inputs,
use_cache=use_cache,
past_key_values=past_key_values,
**kwargs,
)
# postforward hooks
self._post_forward_hook()
return output
def generate(
self,
vision_x: torch.Tensor,
lang_x: torch.Tensor,
image_size: Optional[Tuple] = None,
attention_mask: torch.Tensor = None,
past_key_values: Optional[
List[Union[torch.Tensor, Tuple[torch.Tensor]]]
] = None,
past_media_locations: Optional[torch.Tensor] = None,
past_vision_tokens: Optional[torch.Tensor] = None,
**kwargs,
):
"""
Generate text conditioned on vision and language inputs.
Args:
vision_x (torch.Tensor): Vision input
shape (B, T_img, F, C, H, W)
see documentation for forward
lang_x (torch.Tensor): Language input
shape (B, T_txt)
attention_mask (torch.Tensor, optional): Attention mask. Defaults to None.
**kwargs: see generate documentation in Hugging Face CausalLM models.
Returns:
torch.Tensor: lang_x with generated tokens appended to it
"""
num_beams = kwargs.pop("num_beams", 1)
# convert pixels to vision tokens
vision_attention_mask = None
if vision_x is not None:
if self.image_aspect_ratio == 'anyres':
input_dict = dict(image=vision_x, image_size=image_size)
vision_features, vision_attn_masks = self._encode_vision_x_anyres(input_dict, lang_x.device)
else:
vision_features = self._encode_vision_x(vision_x=vision_x)
vision_attn_masks = None
if self.anyres_patch_sampling:
split_sizes = [feature.shape[0] for feature in vision_features]
# Nested splits for multi-image samples.
if isinstance(vision_x[0], list):
nt_images = [len(images) for images in vision_x]
split_split_sizes = []
img_id = 0
for nt in nt_images:
split_split_sizes.append(split_sizes[img_id:img_id+nt])
img_id += nt
else:
nt_images = [1] * len(vision_x)
split_split_sizes = split_sizes
vision_features = torch.cat(vision_features, dim=0)
vision_features = vision_features[:, None, None, :, :] # Expand dimensions.
vision_attn_masks = torch.cat(vision_attn_masks, dim=0)
vision_tokens = self.vision_tokenizer(vision_features, vision_attn_masks)
# Post-processing: Split the batches into groups of patches and concatenate them together.
if self.anyres_patch_sampling:
assert isinstance(vision_x, list)
if isinstance(vision_x[0], list):
vision_token_groups = torch.split(vision_tokens, list(sum(nt_img) for nt_img in split_split_sizes), dim=0)
vision_tokens = []
for sample_id, patch_vis_tokens in enumerate(vision_token_groups):
# Pad the image tokens within a sample.
patch_vis_token_groups = torch.split(patch_vis_tokens, split_split_sizes[sample_id], dim=0) # [Np*nt, 1, v, d] -> [[Np_t, 1, v, d], ...]
flatten_vision_tokens = []
for image_vis_token in patch_vis_token_groups:
image_vis_token = image_vis_token.flatten(0, 2) # [Np, 1, v, d] -> [Np*v, d]
flatten_vision_tokens.append(image_vis_token)
vision_tokens_i = flatten_vision_tokens
vision_tokens.append(vision_tokens_i)
else:
# Padding. FIXME: padding here might not be necessary?
vision_token_groups = torch.split(vision_tokens, split_sizes, dim=0)
vision_tokens = []
for patch_vis_tokens in vision_token_groups:
patch_vis_tokens = patch_vis_tokens.flatten(0, 2) # [Np, 1, v, d] -> [Np*v, d]
vision_tokens.append(patch_vis_tokens.unsqueeze(0)) # Add the nt dimension.
else:
vision_tokens = None
# fuse the vision and language tokens
new_inputs = self._prepare_inputs_for_forward(
vision_tokens=vision_tokens,
lang_x=lang_x,
attention_mask=attention_mask,
vision_attention_mask=vision_attention_mask,
past_key_values=past_key_values,
past_media_locations=past_media_locations,
past_vision_tokens=past_vision_tokens,
padding_side="left",
num_beams=num_beams,
)
if past_key_values is not None:
output = self.lang_model.generate(
**new_inputs,
past_key_values=past_key_values,
num_beams=num_beams,
use_cache=True,
**kwargs,
)
else:
output = self.lang_model.generate(
**new_inputs,
num_beams=num_beams,
use_cache=True,
**kwargs,
)
self._post_forward_hook()
return output
# OpenFlamingo Training
We provide efficient data loading and distributed training code.
To train with OpenFlamingo, please ensure your environment matches that of `environment.yml`.
Table of contents:
* [Data](#data)
* [Example training command](#example-training-command)
* [Distributed training](#distributed-training)
## Data
Our codebase uses [WebDataset](https://github.com/webdataset/webdataset) to efficiently load `.tar` files containing image and text sequences. We recommend resampling shards with replacement during training using the `--dataset_resampled` flag.
Supported pretraining datasets
* LAION-2B
* Multimodal C4 (MMC4)
* ChatGPT-generated sequences from OpenFlamingo [technical report](https://arxiv.org/abs/2308.01390)
We plan to add additional datasets in the future, and we welcome contributions! If you'd like to add support for a pretraining dataset, please open a PR.
### LAION-2B Dataset
[LAION-2B](https://arxiv.org/abs/2210.08402) contains 2B web-scraped (image, text) pairs.
We use [img2dataset](https://github.com/rom1504/img2dataset) to download this dataset into tar files.
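For example, a minimal download sketch using img2dataset's Python entry point (the metadata path and column names below are placeholders; consult the img2dataset documentation for the full option list):
```
from img2dataset import download

# Download (image, text) pairs listed in LAION metadata into webdataset .tar shards.
download(
    url_list="/path/to/laion2b-metadata.parquet",  # placeholder path
    input_format="parquet",
    url_col="URL",
    caption_col="TEXT",
    output_format="webdataset",
    output_folder="/path/to/shards",
    image_size=256,
    processes_count=16,
    thread_count=64,
)
```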
### Multimodal C4 Dataset
We train on the full version of [Multimodal C4 (MMC4)](https://github.com/allenai/mmc4), which includes 103M documents of web-scraped, interleaved image-text sequences. During training, we truncate sequences to 256 text tokens and six images per sequence.
Our codebase expects `.tar` files containing `.json` files, which include raw images encoded in base64.
We provide scripts to convert MMC4 to this format:
1. Download the MMC4 shards into `.zip` files using [the MMC4-provided scripts](https://github.com/allenai/mmc4/tree/main/scripts) (e.g., `fewer_facesv2.sh`).
2. Download the MMC4 raw images into an image directory using [the MMC4-provided scripts](https://github.com/allenai/mmc4/tree/main/scripts) (e.g., `download_images.py`).
3. Run `scripts/convert_mmc4_to_wds.py` to convert the downloaded items into the expected tar files (see the sketch below for the expected sample layout).
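For reference, a minimal sketch of the target format: one `.json` payload per document with images embedded as base64, written into `.tar` shards with `webdataset.TarWriter`. The field names below are illustrative placeholders, not the exact schema produced by `scripts/convert_mmc4_to_wds.py`.
```
import base64
import json
import webdataset as wds

# Hypothetical document; field names are placeholders for the real MMC4 schema.
with open("example.jpg", "rb") as f:
    img_b64 = base64.b64encode(f.read()).decode("utf-8")
doc = {"text_list": ["An example sentence."], "images_base64": [img_b64]}

# One sample per document, keyed by a unique id, with a single .json payload.
with wds.TarWriter("shard-0000.tar") as sink:
    sink.write({"__key__": "doc_000000", "json": json.dumps(doc).encode("utf-8")})
```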
### ChatGPT-generated sequences
A subset of our models (listed below) was also trained on experimental ChatGPT-generated (image, text) sequences, where images are pulled from LAION. The shards containing these sequences can be found at [this CodaLab worksheet](https://worksheets.codalab.org/worksheets/0xdcd888ff7c754ae680c5e038f6ed1d9b). We are unable to distribute raw images in the released shards; images must be pre-downloaded from the URLs in the json files and converted to base64 before using this data for training in our codebase (see the sketch after the model list below).
Models trained with ChatGPT-generated sequences:
* OpenFlamingo-4B-vitl-rpj3b
* OpenFlamingo-4B-vitl-rpj3b-langinstruct
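A minimal sketch of that pre-processing step, assuming a hypothetical `image_url` field in the released json files (the actual field names may differ):
```
import base64
import json
import requests

# Hypothetical sample from a released shard; only the URL field matters here.
sample = {"image_url": "https://example.com/image.jpg", "text": "An example caption."}

# Download the raw image and embed it as base64 before training with this data.
resp = requests.get(sample["image_url"], timeout=10)
resp.raise_for_status()
sample["image_base64"] = base64.b64encode(resp.content).decode("utf-8")

with open("sample_with_image.json", "w") as f:
    json.dump(sample, f)
```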
## Example training command
We provide sample Slurm training scripts in `scripts/`. You can also modify the following command:
```
torchrun --nnodes=1 --nproc_per_node=4 train.py \
--lm_path anas-awadalla/mpt-1b-redpajama-200b \
--tokenizer_path anas-awadalla/mpt-1b-redpajama-200b \
--cross_attn_every_n_layers 1 \
--dataset_resampled \
--batch_size_mmc4 32 \
--batch_size_laion 64 \
--train_num_samples_mmc4 125000\
--train_num_samples_laion 250000 \
--loss_multiplier_laion 0.2 \
--workers=4 \
--run_name OpenFlamingo-3B-vitl-mpt1b \
--num_epochs 480 \
--warmup_steps 1875 \
--mmc4_textsim_threshold 0.24 \
--laion_shards "/path/to/shards/shard-{0000..0999}.tar" \
--mmc4_shards "/path/to/shards/shard-{0000..0999}.tar" \
--report_to_wandb
```
*Note: The MPT-1B [base](https://huggingface.co/mosaicml/mpt-1b-redpajama-200b) and [instruct](https://huggingface.co/mosaicml/mpt-1b-redpajama-200b-dolly) modeling code does not accept the `labels` kwarg or compute cross-entropy loss directly within `forward()`, as expected by our codebase. We suggest using a modified version of the MPT-1B models found [here](https://huggingface.co/anas-awadalla/mpt-1b-redpajama-200b) and [here](https://huggingface.co/anas-awadalla/mpt-1b-redpajama-200b-dolly).*
## Distributed training
Our codebase supports distributed training using two frameworks:
* Pytorch's [DistributedDataParallel](https://pytorch.org/docs/stable/torch.nn.parallel.DistributedDataParallel.html). This is the default method used by `train.py`.
* Pytorch's [FullyShardedDataParallel](https://pytorch.org/docs/stable/fsdp.html) (FSDP). Use the `--fsdp` flag.
Note that you should use exactly one of these training methods.
`train/distributed.py` contains utilities to help with setting up distributed training using Slurm / `torchrun`. See example scripts in the `scripts` directory.
### FSDP notes
To use FSDP, make sure you are using PyTorch > 2.0.1.
We support two sharding strategies for FSDP: full sharding (the model is sharded across all nodes and GPUs) or hybrid sharding (the model is sharded across the GPUs within each node, with data parallelism between nodes). The former saves GPU memory; the latter saves on communication costs.
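A minimal sketch of the difference, assuming the process group is already initialized (the actual wrapping in `train.py` additionally sets mixed precision, prefetching, and an auto-wrap policy):
```
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy

def wrap_model(model: nn.Module, hybrid: bool, device_id: int) -> FSDP:
    # FULL_SHARD: parameters, gradients, and optimizer state sharded across every rank.
    # HYBRID_SHARD: sharded across ranks within each node, replicated (DDP-style) across nodes.
    strategy = ShardingStrategy.HYBRID_SHARD if hybrid else ShardingStrategy.FULL_SHARD
    return FSDP(model, sharding_strategy=strategy, device_id=device_id)
```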
# Adopted from https://github.com/haotian-liu/LLaVA. Below is the original copyright:
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import ast
import math
from PIL import Image
import torch
def unpad_image(tensor, original_size, keep_original_shape=False):
"""
Unpads a PyTorch tensor of a padded and resized image.
Args:
tensor (torch.Tensor): The image tensor, assumed to be in CxHxW format.
original_size (tuple): The original size of the image (width, height).
keep_original_shape (bool, optional): If True, keep the padded shape and return an attention mask that zeroes out the padded region instead of cropping. Defaults to False.
Returns:
tuple: The unpadded (or original) image tensor and an attention mask (None when cropping).
"""
original_width, original_height = original_size
current_height, current_width = tensor.shape[1:]
original_aspect_ratio = original_width / original_height
current_aspect_ratio = current_width / current_height
if original_aspect_ratio > current_aspect_ratio:
scale_factor = current_width / original_width
new_height = int(original_height * scale_factor)
padding = (current_height - new_height) // 2
if keep_original_shape:
attention_mask = torch.ones((current_height, current_width), device=tensor.device)
attention_mask[:padding, :] = 0
attention_mask[current_height - padding:, :] = 0
return tensor, attention_mask
else:
unpadded_tensor = tensor[:, padding:current_height - padding, :]
return unpadded_tensor, None
else:
scale_factor = current_height / original_height
new_width = int(original_width * scale_factor)
padding = (current_width - new_width) // 2
if keep_original_shape:
attention_mask = torch.ones((current_height, current_width), device=tensor.device)
attention_mask[:, :padding] = 0
attention_mask[:, current_width - padding:] = 0
return tensor, attention_mask
else:
unpadded_tensor = tensor[:, :, padding:current_width - padding]
return unpadded_tensor, None
def select_best_resolution(original_size, possible_resolutions):
"""
Selects the best resolution from a list of possible resolutions based on the original size.
Args:
original_size (tuple): The original size of the image in the format (width, height).
possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].
Returns:
tuple: The best fit resolution in the format (width, height).
"""
original_width, original_height = original_size
best_fit = None
max_effective_resolution = 0
min_wasted_resolution = float('inf')
for width, height in possible_resolutions:
scale = min(width / original_width, height / original_height)
downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
wasted_resolution = (width * height) - effective_resolution
if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution):
max_effective_resolution = effective_resolution
min_wasted_resolution = wasted_resolution
best_fit = (width, height)
return best_fit
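# Illustrative, hand-checked example (an addition, not part of the original file): with the
# anyres grids used by process_images() below for a 384px base size, a 1000x600 (width x height)
# image maps to the 768x768 grid, which keeps the most effective pixels with the least padding.
if __name__ == "__main__":
    grids = [(384, 768), (768, 384), (768, 768), (1152, 384), (384, 1152)]
    assert select_best_resolution((1000, 600), grids) == (768, 768)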
def resize_and_pad_image(image, target_resolution):
"""
Resize and pad an image to a target resolution while maintaining aspect ratio.
Args:
image (PIL.Image.Image): The input image.
target_resolution (tuple): The target resolution (width, height) of the image.
Returns:
PIL.Image.Image: The resized and padded image.
"""
original_width, original_height = image.size
target_width, target_height = target_resolution
scale_w = target_width / original_width
scale_h = target_height / original_height
if scale_w < scale_h:
new_width = target_width
new_height = min(math.ceil(original_height * scale_w), target_height)
else:
new_height = target_height
new_width = min(math.ceil(original_width * scale_h), target_width)
# Resize the image
resized_image = image.resize((new_width, new_height))
new_image = Image.new('RGB', (target_width, target_height), (0, 0, 0))
paste_x = (target_width - new_width) // 2
paste_y = (target_height - new_height) // 2
new_image.paste(resized_image, (paste_x, paste_y))
return new_image
def divide_to_patches(image, patch_size):
"""
Divides an image into patches of a specified size.
Args:
image (PIL.Image.Image): The input image.
patch_size (int): The size of each patch.
Returns:
list: A list of PIL.Image.Image objects representing the patches.
"""
patches = []
width, height = image.size
for i in range(0, height, patch_size):
for j in range(0, width, patch_size):
box = (j, i, j + patch_size, i + patch_size)
patch = image.crop(box)
patches.append(patch)
return patches
def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
"""
Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
Args:
image_size (tuple): The size of the input image in the format (width, height).
grid_pinpoints (str): A string representation of a list of possible resolutions.
patch_size (int): The size of each image patch.
Returns:
tuple: The shape of the image patch grid in the format (width, height).
"""
if type(grid_pinpoints) is list:
possible_resolutions = grid_pinpoints
else:
possible_resolutions = ast.literal_eval(grid_pinpoints)
width, height = select_best_resolution(image_size, possible_resolutions)
return width // patch_size, height // patch_size
def process_anyres_image(image, processor, grid_pinpoints):
"""
Process an image with variable resolutions.
Args:
image (PIL.Image.Image): The input image to be processed.
processor: The image processor object.
grid_pinpoints (str): A string representation of a list of possible resolutions.
Returns:
torch.Tensor: A tensor containing the processed image patches.
"""
# FIXME: determine grid_pinpoints from image sizes.
if type(grid_pinpoints) is list:
possible_resolutions = grid_pinpoints
else:
possible_resolutions = ast.literal_eval(grid_pinpoints)
best_resolution = select_best_resolution(image.size, possible_resolutions)
image_padded = resize_and_pad_image(image, best_resolution)
processor_size = processor.transforms[0].size
patches = divide_to_patches(image_padded, processor_size[0])
image_original_resize = image.resize((processor_size[0], processor_size[0]))
image_patches = [image_original_resize] + patches
image_patches = [processor(image_patch)
for image_patch in image_patches]
return torch.stack(image_patches, dim=0)
def expand2square(pil_img, background_color):
width, height = pil_img.size
if width == height:
return pil_img
elif width > height:
result = Image.new(pil_img.mode, (width, width), background_color)
result.paste(pil_img, (0, (width - height) // 2))
return result
else:
result = Image.new(pil_img.mode, (height, height), background_color)
result.paste(pil_img, ((height - width) // 2, 0))
return result
def process_images(images, image_processor, model_cfg):
image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
new_images = []
if image_aspect_ratio == 'pad':
for image in images:
image = expand2square(image, tuple(int(x*255) for x in image_processor.transforms[-1].mean))
image = image_processor(image)
new_images.append(image)
elif image_aspect_ratio in ["anyres", "anyres-legacy"]:
base_img_size = image_processor.transforms[0].size[0]
for image in images:
image = process_anyres_image(image, image_processor, [[base_img_size,base_img_size*2],
[base_img_size*2,base_img_size],
[base_img_size*2,base_img_size*2],
[base_img_size*3,base_img_size],
[base_img_size,base_img_size*3]])
new_images.append(image)
else:
return image_processor(images)
if all(x.shape == new_images[0].shape for x in new_images):
new_images = torch.stack(new_images, dim=0)
return new_images
import dataclasses
from enum import auto, Enum
from typing import List, Tuple
class SeparatorStyle(Enum):
"""Different separator style."""
SINGLE = auto()
TWO = auto()
MPT = auto()
PLAIN = auto()
LLAMA_2 = auto()
PHI_3 = auto()
@dataclasses.dataclass
class Conversation:
"""A class that keeps all conversation history."""
system: str
roles: List[str]
messages: List[List[str]]
offset: int
sep_style: SeparatorStyle = SeparatorStyle.SINGLE
sep: str = "###"
sep2: str = None
version: str = "Unknown"
skip_next: bool = False
def get_prompt(self):
messages = self.messages
if len(messages) > 0 and type(messages[0][1]) is tuple:
messages = self.messages.copy()
init_role, init_msg = messages[0].copy()
init_msg = init_msg[0].replace("<image>", "").strip()
if 'mmtag' in self.version:
messages[0] = (init_role, init_msg)
messages.insert(0, (self.roles[0], "<Image><image></Image>"))
messages.insert(1, (self.roles[1], "Received."))
else:
messages[0] = (init_role, "<image>\n" + init_msg)
if self.sep_style == SeparatorStyle.SINGLE:
ret = self.system + self.sep
for role, message in messages:
if message:
if type(message) is tuple:
message, _, _ = message
ret += role + ": " + message + self.sep
else:
ret += role + ":"
elif self.sep_style == SeparatorStyle.TWO:
seps = [self.sep, self.sep2]
ret = self.system + seps[0]
for i, (role, message) in enumerate(messages):
if message:
if type(message) is tuple:
message, _, _ = message
ret += role + ": " + message + seps[i % 2]
else:
ret += role + ":"
elif self.sep_style == SeparatorStyle.MPT:
ret = self.system + self.sep
for role, message in messages:
if message:
if type(message) is tuple:
message, _, _ = message
ret += role + message + self.sep
else:
ret += role
elif self.sep_style == SeparatorStyle.LLAMA_2:
wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n" if len(msg) > 0 else msg
wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
ret = ""
for i, (role, message) in enumerate(messages):
if i == 0:
assert message, "first message should not be none"
assert role == self.roles[0], "first message should come from user"
if message:
if type(message) is tuple:
message, _, _ = message
if i == 0: message = wrap_sys(self.system) + message
if i % 2 == 0:
message = wrap_inst(message)
ret += self.sep + message
else:
ret += " " + message + " " + self.sep2
else:
ret += ""
ret = ret.lstrip(self.sep)
elif self.sep_style == SeparatorStyle.PLAIN:
seps = [self.sep, self.sep2]
ret = self.system
for i, (role, message) in enumerate(messages):
if message:
if type(message) is tuple:
message, _, _ = message
ret += message + seps[i % 2]
else:
ret += ""
elif self.sep_style == SeparatorStyle.PHI_3:
seps = [self.sep, self.sep2] # []
if self.system != "":
ret = '<|system|>' + '\n' + self.system + '<|end|>' + '\n'
else:
# Phi-3 w/o system prompt.
ret = ""
for i, (role, message) in enumerate(messages):
if message:
if type(message) is tuple:
message, _, _ = message
ret += role + '\n' + message + '<|end|>' + '\n'
else:
ret += '<|assistant|>' + '\n'
else:
raise ValueError(f"Invalid style: {self.sep_style}")
return ret
def append_message(self, role, message):
self.messages.append([role, message])
def get_images(self, return_pil=False):
images = []
for i, (role, msg) in enumerate(self.messages[self.offset:]):
if i % 2 == 0:
if type(msg) is tuple:
import base64
from io import BytesIO
from PIL import Image
msg, image, image_process_mode = msg
if image_process_mode == "Pad":
def expand2square(pil_img, background_color=(122, 116, 104)):
width, height = pil_img.size
if width == height:
return pil_img
elif width > height:
result = Image.new(pil_img.mode, (width, width), background_color)
result.paste(pil_img, (0, (width - height) // 2))
return result
else:
result = Image.new(pil_img.mode, (height, height), background_color)
result.paste(pil_img, ((height - width) // 2, 0))
return result
image = expand2square(image)
elif image_process_mode in ["Default", "Crop"]:
pass
elif image_process_mode == "Resize":
image = image.resize((336, 336))
else:
raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
max_hw, min_hw = max(image.size), min(image.size)
aspect_ratio = max_hw / min_hw
max_len, min_len = 800, 400
shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
longest_edge = int(shortest_edge * aspect_ratio)
W, H = image.size
if longest_edge != max(image.size):
if H > W:
H, W = longest_edge, shortest_edge
else:
H, W = shortest_edge, longest_edge
image = image.resize((W, H))
if return_pil:
images.append(image)
else:
buffered = BytesIO()
image.save(buffered, format="PNG")
img_b64_str = base64.b64encode(buffered.getvalue()).decode()
images.append(img_b64_str)
return images
def to_gradio_chatbot(self):
ret = []
for i, (role, msg) in enumerate(self.messages[self.offset:]):
if i % 2 == 0:
if type(msg) is tuple:
import base64
from io import BytesIO
msg, image, image_process_mode = msg
max_hw, min_hw = max(image.size), min(image.size)
aspect_ratio = max_hw / min_hw
max_len, min_len = 800, 400
shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
longest_edge = int(shortest_edge * aspect_ratio)
W, H = image.size
if H > W:
H, W = longest_edge, shortest_edge
else:
H, W = shortest_edge, longest_edge
image = image.resize((W, H))
buffered = BytesIO()
image.save(buffered, format="JPEG")
img_b64_str = base64.b64encode(buffered.getvalue()).decode()
img_str = f'<img src="data:image/jpeg;base64,{img_b64_str}" alt="user upload image" />'
msg = img_str + msg.replace('<image>', '').strip()
ret.append([msg, None])
else:
ret.append([msg, None])
else:
ret[-1][-1] = msg
return ret
def copy(self):
return Conversation(
system=self.system,
roles=self.roles,
messages=[[x, y] for x, y in self.messages],
offset=self.offset,
sep_style=self.sep_style,
sep=self.sep,
sep2=self.sep2,
version=self.version)
def dict(self):
if len(self.get_images()) > 0:
return {
"system": self.system,
"roles": self.roles,
"messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
"offset": self.offset,
"sep": self.sep,
"sep2": self.sep2,
}
return {
"system": self.system,
"roles": self.roles,
"messages": self.messages,
"offset": self.offset,
"sep": self.sep,
"sep2": self.sep2,
}
conv_vicuna_v0 = Conversation(
system="A chat between a curious human and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
roles=("Human", "Assistant"),
messages=(
("Human", "What are the key differences between renewable and non-renewable energy sources?"),
("Assistant",
"Renewable energy sources are those that can be replenished naturally in a relatively "
"short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
"Non-renewable energy sources, on the other hand, are finite and will eventually be "
"depleted, such as coal, oil, and natural gas. Here are some key differences between "
"renewable and non-renewable energy sources:\n"
"1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
"energy sources are finite and will eventually run out.\n"
"2. Environmental impact: Renewable energy sources have a much lower environmental impact "
"than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
"and other negative effects.\n"
"3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
"have lower operational costs than non-renewable sources.\n"
"4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
"locations than non-renewable sources.\n"
"5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
"situations and needs, while non-renewable sources are more rigid and inflexible.\n"
"6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
"non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
),
offset=2,
sep_style=SeparatorStyle.SINGLE,
sep="###",
)
conv_vicuna_v1 = Conversation(
system="A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions.",
roles=("USER", "ASSISTANT"),
version="v1",
messages=(),
offset=0,
sep_style=SeparatorStyle.TWO,
sep=" ",
sep2="</s>",
)
conv_mistral_instruct = Conversation(
system="",
roles=("USER", "ASSISTANT"),
version="llama_v2",
messages=(),
offset=0,
sep_style=SeparatorStyle.LLAMA_2,
sep="",
sep2="</s>",
)
conv_phi_3 = Conversation(
system="A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions.",
roles=("<|user|>", "<|assistant|>"),
version="phi_3",
messages=(),
offset=0,
sep_style=SeparatorStyle.PHI_3,
sep="<s>",
sep2="<|end|>",
)
default_conversation = conv_vicuna_v1
conv_templates = {
"default": conv_vicuna_v0,
"v0": conv_vicuna_v0,
"v1": conv_vicuna_v1,
"vicuna_v1": conv_vicuna_v1,
"mistral_instruct": conv_mistral_instruct,
"phi_3": conv_phi_3,
}
if __name__ == "__main__":
print(default_conversation.get_prompt())
"""
Util functions for initializing webdataset objects
"""
import ast
import json
import logging
import os
import random
import sys
from dataclasses import dataclass
from multiprocessing import Value
import braceexpand
import webdataset as wds
from torch.utils.data import DataLoader, IterableDataset, get_worker_info
from torch.utils.data.distributed import DistributedSampler
from webdataset.filters import _shuffle
from webdataset.tariterators import (
base_plus_ext,
tar_file_expander,
url_opener,
valid_sample,
)
try:
import horovod.torch as hvd
except ImportError:
hvd = None
class SharedEpoch:
def __init__(self, epoch: int = 0):
self.shared_epoch = Value("i", epoch)
def set_value(self, epoch):
self.shared_epoch.value = epoch
def get_value(self):
return self.shared_epoch.value
@dataclass
class DataInfo:
"""
DataInfo is a dataclass that holds information about a dataset.
"""
name: str
dataloader: DataLoader
batch_size: int
loss_multiplier: int
sampler: DistributedSampler = None
shared_epoch: SharedEpoch = None
def set_epoch(self, epoch):
if self.shared_epoch is not None:
self.shared_epoch.set_value(epoch)
if self.sampler is not None and isinstance(self.sampler, DistributedSampler):
self.sampler.set_epoch(epoch)
def get_dataset_size(shards):
"""
Get the number of samples in a dataset and the number of shards in a dataset
based on the shards list.
Returns None for the number of samples if it is undefined.
The number of samples can be defined via a sizes.json file or a __len__ file
located in the same directory as the shards.
"""
shards_list = list(braceexpand.braceexpand(shards))
dir_path = os.path.dirname(shards_list[0])
sizes_filename = os.path.join(dir_path, "sizes.json")
len_filename = os.path.join(dir_path, "__len__")
if os.path.exists(sizes_filename):
sizes = json.load(open(sizes_filename, "r"))
total_size = sum(
[
int(sizes[os.path.basename(shard)])
if os.path.basename(shard) in sizes
else 0
for shard in shards_list
]
)
elif os.path.exists(len_filename):
# FIXME this used to be eval(open(...)) but that seemed rather unsafe
total_size = ast.literal_eval(open(len_filename, "r").read())
else:
total_size = None # num samples undefined
num_shards = len(shards_list)
return total_size, num_shards
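# Illustrative only (an addition, not part of the original file): a hypothetical sizes.json
# that matches the convention above, mapping each shard basename to its sample count.
if __name__ == "__main__":
    example_sizes = {"shard-0000.tar": 10000, "shard-0001.tar": 9871}
    with open("/tmp/sizes.json", "w") as f:
        json.dump(example_sizes, f)
    print(get_dataset_size("/tmp/shard-{0000..0001}.tar"))  # -> (19871, 2)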
def log_and_continue(exn):
"""Call in an exception handler to ignore any exception, issue a warning, and continue."""
if "images in sample" not in repr(exn):
logging.warning(f"Handling webdataset error ({repr(exn)}). Ignoring.")
return True
def group_by_keys_nothrow(
data, keys=base_plus_ext, lcase=True, suffixes=None, handler=None
):
"""Return function over iterator that groups key, value pairs into samples.
:param keys: function that splits the key into key and extension (base_plus_ext)
:param lcase: convert suffixes to lower case (Default value = True)
"""
current_sample = None
for filesample in data:
assert isinstance(filesample, dict)
fname, value = filesample["fname"], filesample["data"]
prefix, suffix = keys(fname)
if prefix is None:
continue
if lcase:
suffix = suffix.lower()
# FIXME the webdataset version throws if the suffix is already in current_sample, but we have a potential for
# this happening in the current LAION-400M dataset if a tar ends with the same prefix the next
# begins with; rare, but it can happen since prefixes aren't unique across tar files in that dataset
if (
current_sample is None
or prefix != current_sample["__key__"]
or suffix in current_sample
):
if valid_sample(current_sample):
yield current_sample
current_sample = dict(__key__=prefix, __url__=filesample["__url__"])
if suffixes is None or suffix in suffixes:
current_sample[suffix] = value
if valid_sample(current_sample):
yield current_sample
def tarfile_to_samples_nothrow(src, handler=log_and_continue):
# NOTE this is a re-impl of the webdataset impl with group_by_keys that doesn't throw
streams = url_opener(src, handler=handler)
files = tar_file_expander(streams, handler=handler)
samples = group_by_keys_nothrow(files, handler=handler)
return samples
def pytorch_worker_seed(increment=0):
"""get dataloader worker seed from pytorch"""
worker_info = get_worker_info()
if worker_info is not None:
# favour using the seed already created for pytorch dataloader workers if it exists
seed = worker_info.seed
if increment:
# space out seed increments so they can't overlap across workers in different iterations
seed += increment * max(1, worker_info.num_workers)
return seed
# fallback to wds rank based seed
return wds.utils.pytorch_worker_seed()
class detshuffle2(wds.PipelineStage):
def __init__(
self,
bufsize=1000,
initial=100,
seed=0,
epoch=-1,
):
self.bufsize = bufsize
self.initial = initial
self.seed = seed
self.epoch = epoch
def run(self, src):
if isinstance(self.epoch, SharedEpoch):
epoch = self.epoch.get_value()
else:
# NOTE: this epoch tracking is problematic in a multiprocess (dataloader workers or train)
# situation as different workers may wrap at different times (or not at all).
self.epoch += 1
epoch = self.epoch
rng = random.Random()
if self.seed < 0:
# If seed is negative, we use the worker's seed, this will be different across all nodes/workers
seed = pytorch_worker_seed(epoch)
else:
# This seed is deterministic AND the same across all nodes/workers in each epoch
seed = self.seed + epoch
rng.seed(seed)
return _shuffle(src, self.bufsize, self.initial, rng)
class ResampledShards2(IterableDataset):
"""An iterable dataset yielding a list of urls."""
def __init__(
self,
urls,
nshards=sys.maxsize,
worker_seed=None,
deterministic=False,
epoch=-1,
):
"""Sample shards from the shard list with replacement.
:param urls: a list of URLs as a Python list or brace notation string
"""
super().__init__()
urls = wds.shardlists.expand_urls(urls)
self.urls = urls
assert isinstance(self.urls[0], str)
self.nshards = nshards
self.rng = random.Random()
self.worker_seed = worker_seed
self.deterministic = deterministic
self.epoch = epoch
def __iter__(self):
"""Return an iterator over the shards."""
if isinstance(self.epoch, SharedEpoch):
epoch = self.epoch.get_value()
else:
# NOTE: this epoch tracking is problematic in a multiprocess (dataloader workers or train)
# situation as different workers may wrap at different times (or not at all).
self.epoch += 1
epoch = self.epoch
if self.deterministic:
# reset seed w/ epoch if deterministic
if self.worker_seed is None:
# pytorch worker seed should be deterministic due to being init by arg.seed + rank + worker id
seed = pytorch_worker_seed(epoch)
else:
seed = self.worker_seed() + epoch
self.rng.seed(seed)
for _ in range(self.nshards):
yield dict(url=self.rng.choice(self.urls))
"""
Util functions for distributed training and FSDP.
"""
import os
import torch
##################################
# SLURM setup; Credit: open_clip #
##################################
try:
import horovod.torch as hvd
except ImportError:
hvd = None
def is_global_master(args):
return args.rank == 0
def is_local_master(args):
return args.local_rank == 0
def is_master(args, local=False):
return is_local_master(args) if local else is_global_master(args)
def is_using_horovod():
# NOTE w/ horovod run, OMPI vars should be set, but w/ SLURM PMI vars will be set
# Differentiating between horovod and DDP use via SLURM may not be possible, so horovod arg still required...
ompi_vars = ["OMPI_COMM_WORLD_RANK", "OMPI_COMM_WORLD_SIZE"]
pmi_vars = ["PMI_RANK", "PMI_SIZE"]
if all([var in os.environ for var in ompi_vars]) or all(
[var in os.environ for var in pmi_vars]
):
return True
else:
return False
def is_using_distributed():
if "WORLD_SIZE" in os.environ:
return int(os.environ["WORLD_SIZE"]) > 1
if "SLURM_NTASKS" in os.environ:
return int(os.environ["SLURM_NTASKS"]) > 1
return False
def world_info_from_env():
local_rank = 0
for v in (
"LOCAL_RANK",
"MPI_LOCALRANKID",
"SLURM_LOCALID",
"OMPI_COMM_WORLD_LOCAL_RANK",
):
if v in os.environ:
local_rank = int(os.environ[v])
break
global_rank = 0
for v in ("RANK", "PMI_RANK", "SLURM_PROCID", "OMPI_COMM_WORLD_RANK"):
if v in os.environ:
global_rank = int(os.environ[v])
break
world_size = 1
for v in ("WORLD_SIZE", "PMI_SIZE", "SLURM_NTASKS", "OMPI_COMM_WORLD_SIZE"):
if v in os.environ:
world_size = int(os.environ[v])
break
return local_rank, global_rank, world_size
def init_distributed_device(args):
# Distributed training = training on more than one GPU.
# Works in both single and multi-node scenarios.
args.distributed = False
args.world_size = 1
args.rank = 0 # global rank
args.local_rank = 0
if args.horovod:
assert hvd is not None, "Horovod is not installed"
hvd.init()
args.local_rank = int(hvd.local_rank())
args.rank = hvd.rank()
args.world_size = hvd.size()
args.distributed = True
os.environ["LOCAL_RANK"] = str(args.local_rank)
os.environ["RANK"] = str(args.rank)
os.environ["WORLD_SIZE"] = str(args.world_size)
elif is_using_distributed():
if "SLURM_PROCID" in os.environ:
# DDP via SLURM
args.local_rank, args.rank, args.world_size = world_info_from_env()
# SLURM var -> torch.distributed vars in case needed
os.environ["LOCAL_RANK"] = str(args.local_rank)
os.environ["RANK"] = str(args.rank)
os.environ["WORLD_SIZE"] = str(args.world_size)
torch.distributed.init_process_group(
backend=args.dist_backend,
init_method=args.dist_url,
world_size=args.world_size,
rank=args.rank,
)
else:
# DDP via torchrun, torch.distributed.launch
args.local_rank, _, _ = world_info_from_env()
torch.distributed.init_process_group(
backend=args.dist_backend, init_method=args.dist_url
)
args.world_size = torch.distributed.get_world_size()
args.rank = torch.distributed.get_rank()
args.distributed = True
else:
# needed to run on single gpu
torch.distributed.init_process_group(
backend=args.dist_backend,
init_method=args.dist_url,
world_size=1,
rank=0,
)
if torch.cuda.is_available():
if args.distributed and not args.no_set_device_rank:
device = "cuda:%d" % args.local_rank
else:
device = "cuda:0"
torch.cuda.set_device(device)
else:
device = "cpu"
args.device = device
device = torch.device(device)
return device
#####################################
# FSDP util functions #
#####################################
def get_fsdp_mixed_precision_policy(
precision: str,
reduce_param_precision=False,
reduce_communication_precision=True,
reduce_buffer_precision=True,
):
"""
Returns the FSDP mixed precision policy for a given precision.
"""
if "bfloat16" in precision or "bf16" in precision:
cast_dtype = torch.bfloat16
elif precision == "fp16":
cast_dtype = torch.float16
else:
cast_dtype = torch.float32
if cast_dtype == torch.float32:
return None
from torch.distributed.fsdp import MixedPrecision
return MixedPrecision(
param_dtype=cast_dtype if reduce_param_precision else torch.float32,
reduce_dtype=cast_dtype if reduce_communication_precision else torch.float32,
buffer_dtype=cast_dtype if reduce_buffer_precision else torch.float32,
)
def get_fsdp_config(
args,
device_id,
):
"""
Return kwargs for FSDP wrapper.
This includes some hard-coded settings.
"""
# init MixedPrecision
mp_policy = get_fsdp_mixed_precision_policy(
args.precision,
reduce_param_precision=False,
reduce_communication_precision=True,
reduce_buffer_precision=True,
)
# init FSDP
from torch.distributed.fsdp import (
ShardingStrategy,
BackwardPrefetch,
CPUOffload
)
if args.fsdp_sharding_strategy == "full":
sharding_strategy = ShardingStrategy.FULL_SHARD
elif args.fsdp_sharding_strategy == "hybrid":
sharding_strategy = ShardingStrategy.HYBRID_SHARD
elif args.fsdp_sharding_strategy == "shard_grad_op":
sharding_strategy = ShardingStrategy.SHARD_GRAD_OP
elif args.fsdp_sharding_strategy == "hybrid_shard_grad_op":
sharding_strategy = ShardingStrategy._HYBRID_SHARD_ZERO2
elif args.fsdp_sharding_strategy == "no_shard":
sharding_strategy = ShardingStrategy.NO_SHARD
else:
raise ValueError(
f"Invalid sharding strategy: {args.fsdp_sharding_strategy}. Supported: full, hybrid, shard_grad_op, hybrid_shard_grad_op, no_shard"
)
if args.cpu_offload_gradients:
cpu_offload = CPUOffload(offload_params=True)
else:
cpu_offload = None
return dict(
cpu_offload=cpu_offload,
device_id=device_id,
sync_module_states=True, # broadcast loaded ckpt from rank 0 -> all ranks
sharding_strategy=sharding_strategy,
use_orig_params=True,
mixed_precision=mp_policy,
forward_prefetch=True,
backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
limit_all_gathers=True,
)
def get_fsdp_checkpoint_config(args):
"""
Return kwargs for FSDP checkpointing.
"""
from torch.distributed.fsdp import (
FullStateDictConfig,
StateDictType,
)
from torch.distributed.fsdp.api import FullOptimStateDictConfig
# to avoid GPU OOM when loading/saving ckpts, load/save on CPU
# this is slow
return dict(
state_dict_type=StateDictType.FULL_STATE_DICT,
state_dict_config=FullStateDictConfig(rank0_only=True, offload_to_cpu=True),
optim_state_dict_config=FullOptimStateDictConfig(
rank0_only=True, offload_to_cpu=True
),
)
""" Main training script """
import argparse
from datetime import datetime
import os
from omegaconf import OmegaConf
import torch
import wandb
import functools
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy
from open_flamingo import create_model_and_transforms, SUPPORTED_MODEL_FAMILIES
from open_flamingo.train.distributed import (
init_distributed_device,
world_info_from_env,
get_fsdp_config,
get_fsdp_checkpoint_config,
)
from open_flamingo.train.sft_data_utils import make_supervised_data_module
from open_flamingo.train.train_utils import (
finetune_one_epoch,
random_seed,
find_most_recent_checkpoint,
load_checkpoint,
save_checkpoint,
)
from open_flamingo.train.losses import (
SUPPORTED_LOSSES,
get_loss_fn,
)
from transformers import (
get_constant_schedule_with_warmup,
get_cosine_schedule_with_warmup,
get_linear_schedule_with_warmup,
)
def parse_tuple_list(input_string):
try:
tuples = input_string.strip().strip('()').split('),(')
# Convert each item in the list to a tuple
tuple_list = [tuple(map(int, item.split(','))) for item in tuples]
return tuple_list
except Exception as e:
raise argparse.ArgumentTypeError(f"Invalid tuple list format: {input_string}. Error: {e}")
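# Illustrative only (an addition): the default --anyres_grids string below parses as follows.
assert parse_tuple_list("(1,2),(2,1),(2,2),(3,1),(1,3)") == [(1, 2), (2, 1), (2, 2), (3, 1), (1, 3)]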
def main():
parser = argparse.ArgumentParser()
# model configuration args
parser.add_argument(
"--model_family", default="kosmos-instruct", type=str, choices=SUPPORTED_MODEL_FAMILIES
)
parser.add_argument("--vision_encoder_path", default="ViT-SO400M-14-SigLIP-384", type=str)
parser.add_argument("--vision_encoder_pretrained", default="webli", type=str)
parser.add_argument("--lm_path", default="facebook/opt-1.3b", type=str)
parser.add_argument(
"--tokenizer_path",
default="facebook/opt-30b",
type=str,
help="path to tokenizer",
)
parser.add_argument(
"--cross_attn_every_n_layers",
type=int,
default=1,
help="how often to add a cross-attention layer after each transformer layer",
)
parser.add_argument(
"--num_vision_tokens",
type=int, default=64, help="number of query tokens used for resampling vision features.",
)
parser.add_argument("--pretrained", type=str, default=None, help="pretrained weights for fine-tuning.")
parser.add_argument("--pretrained_vision_tokenizer", type=str, default=None, help="pretrained vl connector for fine-tuning.")
# training args
parser.add_argument(
"--loss", type=str, choices=SUPPORTED_LOSSES, default="supervised_finetune"
)
parser.add_argument(
"--run_name",
type=str,
default="openflamingo3B",
help="used to name saving directory and wandb run",
)
parser.add_argument(
"--resume_from_checkpoint",
type=str,
help="path to checkpoint to resume from, this should contain model, optimizer, and lr_scheduler states. if there exists a checkpoint in the dir named run_name, we will resume from that checkpoint by default.",
default=None,
)
parser.add_argument(
"--delete_previous_checkpoint",
action="store_true",
help="delete previous checkpoint when saving new checkpoint",
)
parser.add_argument(
"--no_save_optim_state",
action="store_true",
help="do not save optimizer states when saving checkpoints",
)
parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--learning_rate", default=1e-4, type=float)
parser.add_argument(
"--lr_scheduler",
default="constant",
type=str,
help="constant, linear, or cosine",
)
parser.add_argument("--warmup_steps", default=5000, type=int)
parser.add_argument("--weight_decay", default=0.1, type=float)
parser.add_argument(
"--precision",
choices=["amp_bf16", "amp_bfloat16", "bf16", "fp16", "fp32"],
default="fp32",
help="Floating point precision.",
)
parser.add_argument(
"--gradient_checkpointing",
action="store_true",
help="whether to train with gradient/activation checkpointing",
)
parser.add_argument(
"--num_epochs",
type=int,
default=1,
help="we define an 'epoch' as a fixed number of examples specified by train_num_samples, not a pass through the entire dataset",
)
parser.add_argument("--offline", action="store_true")
parser.add_argument(
"--logging_steps", type=int, default=100, help="log loss every n steps"
)
parser.add_argument(
"--checkpoint_steps", type=int, default=5000, help="log loss every n steps"
)
# data args
# TODO: load a data args yaml file
parser.add_argument(
"--data_path",
default="/export/home/LLaVA/playground/data/llava_v1_5_mix665k_ocr_tagged_vqa_placeholder.json",
type=str
)
parser.add_argument("--batch_size", type=int, default=8)
parser.add_argument("--workers", type=int, default=1)
parser.add_argument("--data_sampler_group_by_length", default=False, action="store_true")
# Legacy Llava data args
parser.add_argument("--is_multimodal", type=bool, default=True)
parser.add_argument("--mm_use_im_start_end", default=False, action="store_true")
parser.add_argument("--conv_template_name", type=str, default=None)
# Any resolution
parser.add_argument("--image_aspect_ratio", type=str, default='pad')
parser.add_argument(
"--anyres_patch_sampling",
default=False,
action="store_true",
)
parser.add_argument('--anyres_grids',
type=parse_tuple_list,
default="(1,2),(2,1),(2,2),(3,1),(1,3)",
help="List of tuples in the format (1,2),(3,4),...")
# distributed training args
parser.add_argument(
"--dist-url",
default="env://",
type=str,
help="url used to set up distributed training",
)
parser.add_argument(
"--dist-backend", default="nccl", type=str, help="distributed backend"
)
parser.add_argument(
"--horovod",
default=False,
action="store_true",
help="Use horovod for distributed training.",
)
parser.add_argument(
"--no-set-device-rank",
default=False,
action="store_true",
help="Don't set device index from local rank (when CUDA_VISIBLE_DEVICES restricted to one per proc).",
)
parser.add_argument(
'--local-rank',
default=0,
type=int,
help='Local rank for distributed training'
)
# fsdp args
parser.add_argument(
"--fsdp",
default=False,
action="store_true",
help="Use FullyShardedDataParallel for distributed training. Not supported for some models, e.g. OPT.",
)
parser.add_argument(
"--fsdp_sharding_strategy", default="full", type=str, choices=["full", "hybrid", "shard_grad_op", "hybrid_shard_grad_op", "no_shard"]
)
# wandb args
parser.add_argument("--report_to_wandb", default=False, action="store_true")
parser.add_argument(
"--wandb_project",
type=str,
)
parser.add_argument(
"--wandb_entity",
type=str,
)
parser.add_argument(
"--save_checkpoints_to_wandb",
default=False,
action="store_true",
help="save checkpoints to wandb",
)
parser.add_argument(
"--dryrun",
default=False,
action="store_true",
)
parser.add_argument(
'--use_flash_attention_2',
default=False, action='store_true',
help='Use Flash Attention 2.0 for language model.'
)
parser.add_argument(
'--unfreeze_vision_encoder',
default=False, action='store_true',
help='Unfreeze vision encoder during training.'
)
parser.add_argument(
'--vision_encoder_precision',
default='fp32',
choices=["bf16", "fp32"],
help='Precision of the vision encoder during training.'
)
parser.add_argument(
'--cpu_offload_gradients',
default=False, action='store_true',
help='This specifies whether to offload parameters to CPU when not involved in computation. If True, then this offloads gradients to CPU as well, meaning that the optimizer step runs on CPU.'
)
args = parser.parse_args()
if args.save_checkpoints_to_wandb and not args.report_to_wandb:
raise ValueError("save_checkpoints_to_wandb requires report_to_wandb")
if args.fsdp:
assert (
torch.__version__ > "2.0.1"
), "FSDP requires torch > 2.0.1"
# Set up distributed training
args.local_rank, args.rank, args.world_size = world_info_from_env()
if args.rank == 0:
print(f"Initializing distributed training with {args.world_size} GPUs.")
if args.offline:
os.environ["WANDB_MODE"] = "offline"
os.environ["TRANSFORMERS_OFFLINE"] = "1"
device_id = init_distributed_device(args)
random_seed(args.seed)
# Initialize model
if args.model_family == "flamingo":
additional_kwargs={"cross_attn_every_n_layers": args.cross_attn_every_n_layers}
elif args.model_family in ['xgenmm_v1']:
additional_kwargs = {
"image_aspect_ratio": args.image_aspect_ratio,
"num_vision_tokens": args.num_vision_tokens,
"anyres_patch_sampling": args.anyres_patch_sampling,
}
else:
additional_kwargs = {}
model, image_processor, tokenizer = create_model_and_transforms(
args.vision_encoder_path,
args.vision_encoder_pretrained,
args.lm_path,
args.tokenizer_path if args.tokenizer_path else args.lm_path,
model_family=args.model_family,
pretrained_vision_tokenizer=args.pretrained_vision_tokenizer,
use_local_files=args.offline,
gradient_checkpointing=args.gradient_checkpointing,
verbose=(args.rank == 0),
**additional_kwargs,
)
random_seed(args.seed, args.rank)
# Initialize wandb logging
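    # strftime("%Y%m%d%H%M") has minute resolution; dropping the last character keeps a
    # 10-minute-granularity timestamp used as the wandb run-name suffix below.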
now = datetime.now().strftime("%Y%m%d%H%M")[:-1]
if args.rank == 0 and args.report_to_wandb:
config = vars(args)
# print("-----------wandb:Config:", config)
wandb.init(
project=args.wandb_project,
name=f"{args.run_name}-{now}",
config=vars(args),
settings=wandb.Settings(init_timeout=120)
)
# Load model checkpoint (on CPU)
if args.fsdp:
args.fsdp_checkpoint_config = get_fsdp_checkpoint_config(args)
# if args do not specify a checkpoint to resume from, resume from most recent checkpoint
resume_from_step = 0
if os.path.exists(f"{args.run_name}") and args.resume_from_checkpoint is None:
args.resume_from_checkpoint = find_most_recent_checkpoint(args)
    if args.resume_from_checkpoint is not None:
resume_from_epoch, resume_from_step, checkpoint = load_checkpoint(args, model)
print(f"Resume training from epoch {resume_from_epoch}, step {resume_from_step}...")
else:
resume_from_epoch = 0
resume_from_step = 0
# Load pretrained weights.
if args.resume_from_checkpoint is None and not args.dryrun:
if args.pretrained_vision_tokenizer is None:
assert os.path.exists(args.pretrained), "Must fine-tune from a pretrained weight."
if args.pretrained is not None:
_, _, checkpoint = load_checkpoint(args, model, pretrained=True)
print("Finished loading checkpoint...")
# Initialize gradient checkpointing
if args.gradient_checkpointing:
model.init_gradient_checkpointing()
# Initialize FSDP / DDP, and ensure the model is on GPU
if args.fsdp:
auto_wrap_policy = functools.partial(
lambda_auto_wrap_policy, lambda_fn=model.get_fsdp_lambda_fn()
)
wrapper_kwargs = get_fsdp_config(args, device_id)
distributed_model = FSDP(
model, auto_wrap_policy=auto_wrap_policy, **wrapper_kwargs
)
print("Finished FSDP wrapping...")
else:
model = model.to(device_id)
distributed_model = DDP(model, device_ids=[device_id])
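        # DDP keeps a full replica of the model on each rank; device_ids pins this
        # process's replica and its gradient synchronization to the local GPU.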
# Initialize optimizer
params_with_wd, params_without_wd = model.group_params_by_weight_decay()
optimizer = torch.optim.AdamW(
[
{"params": params_with_wd, "weight_decay": args.weight_decay},
{"params": params_without_wd, "weight_decay": 0.0},
],
lr=args.learning_rate,
)
# load optimizer checkpoint
if args.resume_from_checkpoint is not None:
optim_state_dict = checkpoint["optimizer_state_dict"]
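        # With FSDP, the saved optimizer state must first be converted to match the
        # wrapped model's flattened/sharded parameters; FSDP.optim_state_dict_to_load
        # performs that mapping so each rank loads only the state it owns.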
if args.fsdp:
# FSDP.set_state_dict_type(
# distributed_model,
# **args.fsdp_checkpoint_config,
# )
optim_state_dict = FSDP.optim_state_dict_to_load(
model=distributed_model, optim=optimizer, optim_state_dict=optim_state_dict
)
optimizer.load_state_dict(optim_state_dict)
# Initialize datasets
if args.data_path.split('.')[-1] == 'yaml':
# Loading a mixture of datasets with sampling ratios.
data_config = OmegaConf.load(args.data_path)
if args.rank == 0:
print("================== Data mixture config ===================")
print(data_config)
print("==========================================================")
args.data_path = dict(data_config.data_path)
train_dataset, total_num_samples = make_supervised_data_module(tokenizer=tokenizer,
image_processor=image_processor,
data_args=args)
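    # Debug: inspect the first batch and the reported sample count (prints on every rank).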
print("-----------------------")
data_loader = train_dataset.dataloader
for batch in data_loader:
print(batch)
break
print(total_num_samples)
# Update anyres grid.
args.anyres_grids = train_dataset.dataloader.dataset.anyres_grids
model.anyres_grids = args.anyres_grids
# TODO: Summarize training data stats (dataset, portion, etc.)
total_training_steps = (
total_num_samples
// (args.batch_size * args.gradient_accumulation_steps * args.world_size)
) * args.num_epochs
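    # Worked example with hypothetical numbers: total_num_samples=1_000_000, batch_size=8,
    # gradient_accumulation_steps=2, world_size=16 gives 1_000_000 // 256 = 3906 optimizer
    # steps per epoch, i.e. total_training_steps = 7812 for num_epochs=2.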
if args.rank == 0:
print(f"Total training steps: {total_training_steps}")
# Initialize lr scheduler
if args.lr_scheduler == "linear":
lr_scheduler = get_linear_schedule_with_warmup(
optimizer,
num_warmup_steps=args.warmup_steps,
num_training_steps=total_training_steps,
)
elif args.lr_scheduler == "cosine":
lr_scheduler = get_cosine_schedule_with_warmup(
optimizer,
num_warmup_steps=args.warmup_steps,
num_training_steps=total_training_steps,
)
else:
lr_scheduler = get_constant_schedule_with_warmup(
optimizer, num_warmup_steps=args.warmup_steps
)
# load lr scheduler checkpoint
if args.resume_from_checkpoint is not None:
lr_scheduler.load_state_dict(checkpoint["lr_scheduler_state_dict"])
# Initialize the loss fn
loss_fn = get_loss_fn(args.loss)
# check wrapping
if args.rank == 0:
print(distributed_model)
# Start training!
print(f"Start running training on rank {args.rank}.")
for epoch in range(resume_from_epoch, args.num_epochs):
train_dataset.set_epoch(epoch)
finetune_one_epoch(
args=args,
resume_from_step=resume_from_step,
model=distributed_model,
epoch=epoch,
dataset=train_dataset,
compute_loss_fn=loss_fn,
tokenizer=tokenizer,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
device_id=device_id,
wandb=wandb,
)
save_checkpoint(distributed_model, optimizer, lr_scheduler, epoch, args)
if __name__ == "__main__":
main()
\ No newline at end of file
from open_flamingo.src.vlm import VLM
import torch
from typing import List, Optional
SUPPORTED_LOSSES = ["next_token_prediction",
"supervised_finetune"]
def get_loss_fn(loss_name):
if loss_name == "next_token_prediction":
return NextTokenPrediction()
elif loss_name == "supervised_finetune":
return SupervisedPrediction()
else:
raise ValueError(
f"Loss {loss_name} not supported. Supported losses: {SUPPORTED_LOSSES}"
)
class Loss:
@property
def name(self):
raise NotImplementedError
def __call__(
self,
model: VLM,
tokenizer,
images: torch.Tensor,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
autocast: callable,
):
"""
Args:
            model: VLM model
            tokenizer: tokenizer whose pad_token_id and special token ids are used to mask the labels
images: images tensor, already moved to device and cast to appropriate dtype
shape (B, T_img, F, C, H, W)
input_ids: input ids tensor, already moved to device and cast to appropriate dtype
shape (B, T_text)
attention_mask: attention mask tensor, already moved to device and cast to appropriate dtype
shape (B, T_text)
autocast: autocast context manager
Return:
loss: scalar loss
"""
raise NotImplementedError
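# Illustrative call site (names follow the training script in this commit): the training
# loop resolves a loss object once via get_loss_fn(args.loss) and invokes it per batch, e.g.
#   loss = compute_loss_fn(model, tokenizer, images, input_ids, attention_mask, autocast)
# SupervisedPrediction additionally receives precomputed labels and optional image sizes.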
class NextTokenPrediction(Loss):
@property
def name(self):
return "next_token_prediction"
def __call__(
self,
model: VLM,
tokenizer,
images: torch.Tensor,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
autocast: callable,
):
# set up labels; language model is expected to handle shifting
labels = input_ids.clone()
labels[labels == tokenizer.pad_token_id] = -100
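        # -100 is the default ignore_index of torch.nn.CrossEntropyLoss, so padded
        # positions (and the special tokens masked below) contribute no loss.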
        special_token_ids = torch.tensor(
            unwrap_model(model).special_token_ids, device=labels.device
        )
        labels[torch.isin(labels, special_token_ids)] = -100  # TODO: don't want to remove loss on <|endofchunk|> tokens
labels = labels.to(input_ids.device)
# call forward
with autocast():
loss = model(
vision_x=images,
lang_x=input_ids,
attention_mask=attention_mask,
labels=labels,
)[0]
return loss
class SupervisedPrediction(Loss):
@property
def name(self):
return "supervised_finetune"
def __call__(
self,
model: VLM,
tokenizer,
images: torch.Tensor,
input_ids: torch.Tensor,
labels: torch.Tensor,
attention_mask: torch.Tensor,
autocast: callable,
image_size: Optional[torch.Tensor] = None,
):
# set up labels; language model is expected to handle shifting
labels[labels == tokenizer.pad_token_id] = -100
        special_token_ids = torch.tensor(
            unwrap_model(model).special_token_ids, device=labels.device
        )
        labels[torch.isin(labels, special_token_ids)] = -100  # TODO: don't want to remove loss on <|endofchunk|> tokens
labels = labels.to(input_ids.device)
# call forward
with autocast():
loss = model(
vision_x=images,
image_size=image_size,
lang_x=input_ids,
attention_mask=attention_mask,
labels=labels,
)[0]
return loss
def unwrap_model(model):
"""
Unwrap a model from a DataParallel or DistributedDataParallel wrapper.
"""
if isinstance(
model, (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)
):
return model.module
else:
return model
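# Example (illustrative): unwrap_model(DDP(model)).special_token_ids reads the attribute of
# the underlying model rather than of the wrapper. FSDP wrappers are not unwrapped here;
# they are returned unchanged.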