Commit 1b9205c9 authored by yangzhong's avatar yangzhong
Browse files

v1.0

parents
Pipeline #2931 failed with stages
in 0 seconds
import time
from contextlib import suppress
import torch
from tqdm import tqdm
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
import os
import shutil
import wandb
import glob
from data_utils import DataInfo
import random
import numpy as np
import torch.nn as nn
def _dataset_loss_multiplier(args, dataset_info):
    """Return the scalar applied to a dataset's loss before backward."""
    # NOTE(review): DataInfo is defined in data_utils and not visible here.
    # A per-dataset ``loss_multiplier`` attribute is used when present;
    # otherwise fall back to the ``args.loss_multiplier_<name>`` convention the
    # previous body used (``loss_multiplier_laion`` / ``loss_multiplier_mmc4``),
    # and finally to 1.0. Confirm which source is authoritative.
    multiplier = getattr(dataset_info, "loss_multiplier", None)
    if multiplier is None:
        multiplier = getattr(args, f"loss_multiplier_{dataset_info.name}", 1.0)
    return multiplier


def train_one_epoch(
    args,
    model,
    epoch,
    datasets: [DataInfo],
    compute_loss_fn: callable,
    tokenizer,
    optimizer,
    lr_scheduler,
    device_id,
    wandb,
):
    """
    Helper function for running one epoch of training.
    Handles logging, calling forward, backward, gradient clipping, and optimizer step.

    One batch is drawn from every dataset per step (the dataloaders are zipped),
    and backward is called once per dataset so gradients accumulate across them.

    Args:
        args (argparse.Namespace): arguments from command line
        model: DDP / FSDP wrapped model
        epoch (int): epoch number
        datasets (list): list of DataInfos, one for each dataset, to train on
        compute_loss_fn (callable): function that given the model and inputs, calls forward
            and returns a loss
        tokenizer: tokenizer for the language model
        optimizer: optimizer to step
        lr_scheduler: learning rate scheduler
        device_id (int): GPU device ID for this rank
        wandb: wandb object for logging
    """
    # calculate the number of steps in an epoch
    num_batches_per_epoch = datasets[0].dataloader.num_batches
    total_training_steps = num_batches_per_epoch * args.num_epochs

    # set up model and autocast
    model.train()
    autocast = get_autocast(args.precision)

    # set up logging
    step_time_m = AverageMeter()
    data_time_m = AverageMeter()
    end = time.time()

    # loop through the batches in this epoch
    for step_num, batches in tqdm(
        enumerate(zip(*[dataset.dataloader for dataset in datasets])),
        disable=args.rank != 0,
        total=total_training_steps,
        initial=(epoch * num_batches_per_epoch),
    ):
        data_time_m.update(time.time() - end)
        global_step = step_num + epoch * num_batches_per_epoch

        # call compute_loss_fn on each dataset; call backward before continuing
        losses_to_log = {}
        batch_metadata_to_log = {}
        for dataset_ix, (images, (input_ids, attention_mask)) in enumerate(batches):
            dataset_info = datasets[dataset_ix]

            # unpack the batch and move to device
            images = images.to(device_id, non_blocking=True)
            input_ids = input_ids.to(device_id, non_blocking=True)
            attention_mask = attention_mask.to(device_id, non_blocking=True)

            # save some metadata for logging
            batch_metadata_to_log[f"{dataset_info.name}_num_tokens"] = (
                attention_mask.sum().item()
            )
            batch_metadata_to_log[f"{dataset_info.name}_num_images"] = (
                (input_ids == unwrap_model(model).media_token_id).sum().item()
            )

            # forward pass; same calling convention as finetune_one_epoch
            # (autocast handed to the loss function, loss returned directly).
            # The earlier call passed an undefined ``labels`` and indexed [0].
            dataset_loss = compute_loss_fn(
                model=model,
                tokenizer=tokenizer,
                images=images,
                input_ids=input_ids,
                attention_mask=attention_mask,
                autocast=autocast,
            )
            losses_to_log[f"train_loss_{dataset_info.name}"] = dataset_loss.item()

            # scale for gradient accumulation, then weight this dataset's loss
            divided_loss = dataset_loss / args.gradient_accumulation_steps
            (divided_loss * _dataset_loss_multiplier(args, dataset_info)).backward()

        # FIXME: the embedding-gradient masking that used to follow (keep only
        # the <image> / <|endofchunk|> rows) depended on media_token_id and
        # endofchunk_token_id, which were never defined in this function; it is
        # left out here exactly as it is disabled in finetune_one_epoch.

        # clip gradient norm
        if args.fsdp:
            model.clip_grad_norm_(1.0, norm_type=2.0)
        else:
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        # step optimizer and log
        if (((step_num + 1) % args.gradient_accumulation_steps) == 0) or (
            step_num == num_batches_per_epoch - 1
        ):
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad(set_to_none=True)

            # step time and reset end outside of rank 0
            step_time_m.update(time.time() - end)
            end = time.time()

            # rank 0 logging
            if args.rank == 0 and args.report_to_wandb:
                # calculate samples per second
                throughput_metrics = compute_throughput(
                    args,
                    datasets,
                    batch_metadata_to_log,
                    step_time_m,
                )
                wandb.log(
                    {
                        "global_step": global_step,
                        "lr": optimizer.param_groups[0]["lr"],
                        "data_time": data_time_m.avg,
                        "step_time": step_time_m.avg,
                        **throughput_metrics,
                        **losses_to_log,
                    },
                    commit=True,
                )
                step_time_m.reset()
                data_time_m.reset()

        # Log loss to console
        if ((step_num + 1) % args.logging_steps == 0) and args.rank == 0:
            print(
                f"Step {step_num+1}/{num_batches_per_epoch} of epoch {epoch+1}/{args.num_epochs} complete. Losses: "
                + "// ".join([f"{k}: {v:.3f}" for k, v in losses_to_log.items()])
            )
def finetune_one_epoch(
    args,
    resume_from_step,
    model,
    epoch,
    dataset: DataInfo,
    compute_loss_fn: callable,
    tokenizer,
    optimizer,
    lr_scheduler,
    device_id,
    wandb,
):
    """
    Run one epoch of fine-tuning over a single dataset.
    Handles logging, calling forward, backward, gradient clipping, optimizer
    step, and periodic step-level checkpointing.

    Args:
        args (argparse.Namespace): arguments from command line
        resume_from_step (int): batches with an index below this value are
            read from the dataloader and discarded without training
        model: DDP / FSDP wrapped model
        epoch (int): epoch number
        dataset (DataInfo): the dataset to train on
        compute_loss_fn (callable): function that given the model and inputs, calls forward
            and returns a loss
        tokenizer: tokenizer for the language model
        optimizer: optimizer to step
        lr_scheduler: learning rate scheduler
        device_id (int): GPU device ID for this rank
        wandb: wandb object for logging
    """
    # number of steps in one epoch and over the whole run
    num_batches_per_epoch = len(dataset.dataloader)
    total_training_steps = num_batches_per_epoch * args.num_epochs

    # put the model in training mode and pick the autocast factory
    model.train()
    autocast = get_autocast(args.precision)

    # timing meters for logging
    step_time_m = AverageMeter()
    data_time_m = AverageMeter()
    end = time.time()

    progress = tqdm(
        enumerate(dataset.dataloader),
        disable=args.rank != 0,
        total=total_training_steps,
        initial=epoch * num_batches_per_epoch,
    )
    for step_num, samples in progress:
        # fast-forward to the step we are resuming from
        if step_num < resume_from_step:
            continue

        data_time_m.update(time.time() - end)
        global_step = step_num + epoch * num_batches_per_epoch

        # move the batch to this rank's device; a list of images is passed
        # through to compute_loss_fn untouched
        images = samples["images"]
        if not isinstance(images, list):
            images = images.to(device_id, non_blocking=True)
        input_ids = samples["input_ids"].to(device_id, non_blocking=True)
        attention_mask = samples["attention_mask"].to(device_id, non_blocking=True)
        labels = samples["labels"].to(device_id, non_blocking=True)

        # per-batch counts, used only for throughput logging
        batch_metadata_to_log = {
            f"{dataset.name}_num_tokens": attention_mask.sum().item(),
            f"{dataset.name}_num_images": (
                (input_ids == unwrap_model(model).media_token_id).sum().item()
            ),
        }

        # forward pass
        loss = compute_loss_fn(
            model=model,
            tokenizer=tokenizer,
            images=images,
            image_size=samples['image_size'],
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            autocast=autocast,
        )
        losses_to_log = {"train_loss": loss.item()}

        # backward on the loss scaled for gradient accumulation
        divided_loss = loss / args.gradient_accumulation_steps
        divided_loss.backward()

        if args.dryrun:
            # dry run: discard this batch's gradients and never step
            del loss
            del divided_loss
            optimizer.zero_grad(set_to_none=True)
            continue

        # FIXME: Where are the special tokens added/defined?
        # Masking of the input-embedding gradients (so only the added <image>
        # and <|endofchunk|> rows are updated) is disabled here until
        # media_token_id / endofchunk_token_id are available in this function.

        # clip gradient norm
        if args.fsdp:
            model.clip_grad_norm_(1.0, norm_type=2.0)
        else:
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        # step the optimizer at an accumulation boundary or on the last batch
        at_accumulation_boundary = (
            (step_num + 1) % args.gradient_accumulation_steps
        ) == 0
        is_last_batch = step_num == num_batches_per_epoch - 1
        if at_accumulation_boundary or is_last_batch:
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad(set_to_none=True)

            # step time is tracked on every rank
            step_time_m.update(time.time() - end)
            end = time.time()

            # rank 0 logging
            if args.rank == 0 and args.report_to_wandb:
                throughput_metrics = compute_throughput(
                    args,
                    [dataset],
                    batch_metadata_to_log,
                    step_time_m,
                )
                wandb.log(
                    {
                        "global_step": global_step,
                        "lr": optimizer.param_groups[0]["lr"],
                        **losses_to_log,
                        "data_time": data_time_m.avg,
                        "step_time": step_time_m.avg,
                        **throughput_metrics,
                    },
                    commit=True,
                )
                step_time_m.reset()
                data_time_m.reset()

        # Log loss to console
        if ((step_num + 1) % args.logging_steps == 0) and args.rank == 0:
            print(
                f"Step {step_num+1}/{num_batches_per_epoch} of epoch {epoch+1}/{args.num_epochs} complete. Losses: "
                + "// ".join([f"{k}: {v:.3f}" for k, v in losses_to_log.items()])
            )

        # periodic step-level checkpoint; called on every rank
        if (step_num + 1) % args.checkpoint_steps == 0:
            save_checkpoint(model, optimizer, lr_scheduler, epoch, args, step=step_num)
def get_autocast(precision, cache_enabled=True):
    """
    Parses the precision argument and returns an autocast context-manager factory.

    Every branch returns a zero-argument callable, so callers can uniformly
    write ``with autocast():``. The "amp" branch used to return an already
    constructed autocast object, whose ``__call__`` expects a function to
    decorate and therefore failed when invoked with no arguments.

    Args:
        precision (str): "amp", "amp_bfloat16" / "amp_bf16", or anything else
            for no autocast
        cache_enabled (bool): forwarded to ``torch.cuda.amp.autocast``
    Returns:
        callable: calling it with no arguments yields a context manager
    """
    if precision == "amp":
        return lambda: torch.cuda.amp.autocast(cache_enabled=cache_enabled)
    if precision in ("amp_bfloat16", "amp_bf16"):
        return lambda: torch.cuda.amp.autocast(
            dtype=torch.bfloat16, cache_enabled=cache_enabled
        )
    # no autocast requested: ``suppress()`` is a context manager that does nothing
    return suppress
def random_seed(seed=42, rank=0):
    """Seed torch (CPU and CUDA), numpy and `random` with ``seed + rank``."""
    rank_seed = seed + rank
    # same order as before: torch, torch.cuda, numpy, stdlib random
    for seed_fn in (
        torch.manual_seed,
        torch.cuda.manual_seed,
        np.random.seed,
        random.seed,
    ):
        seed_fn(rank_seed)
def unwrap_model(model):
    """
    Return the underlying module of a DataParallel / DistributedDataParallel
    wrapper; any other object is returned unchanged.
    """
    wrapper_types = (nn.DataParallel, nn.parallel.DistributedDataParallel)
    return model.module if isinstance(model, wrapper_types) else model
################################
# Helper functions for logging #
################################
class AverageMeter(object):
    """Tracks the latest value together with a running weighted average."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear the latest value, the running sum, the count and the average."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the average."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
def compute_throughput(
    args,
    datasets,
    batch_metadata,
    step_time_m,
):
    """
    Build the throughput entries for logging: samples, tokens and images per
    second, both per GPU and scaled by the world size.

    Rates are based on the most recent step time (``step_time_m.val``). Token
    and image counts come from ``batch_metadata``, so the world-size totals for
    those two are estimates extrapolated from the calling rank.
    """
    log = {}
    for dataset in datasets:
        name = dataset.name
        per_gpu_rates = {
            "samples": (
                args.gradient_accumulation_steps
                * dataset.batch_size
                / step_time_m.val
            ),
            "tokens": (
                args.gradient_accumulation_steps
                * batch_metadata[f"{name}_num_tokens"]
                / step_time_m.val
            ),
            "images": (
                args.gradient_accumulation_steps
                * batch_metadata[f"{name}_num_images"]
                / step_time_m.val
            ),
        }
        for unit, rate in per_gpu_rates.items():
            log[f"{name}_{unit}_per_second_per_gpu"] = rate
            log[f"{name}_{unit}_per_second"] = rate * args.world_size
    return log
####################################################
# Helper functions for checkpoint loading / saving #
####################################################
def find_most_recent_checkpoint(args):
    """
    Return the path of the highest-numbered ``checkpoint_<N>.pt`` under
    ``args.run_name``, or None when the directory has no such file.
    """

    def _checkpoint_number(path):
        # ".../checkpoint_12.pt" -> 12
        return int(path.split("_")[-1].split(".")[0])

    candidates = glob.glob(f"{args.run_name}/checkpoint_*.pt")
    if not candidates:
        print(f"Found no checkpoints for run {args.run_name}.")
        return None

    resume_from_checkpoint = sorted(candidates, key=_checkpoint_number)[-1]
    print(f"Found checkpoint {resume_from_checkpoint} for run {args.run_name}.")
    return resume_from_checkpoint
def load_checkpoint(args, model, pretrained=False):
    """
    Load model weights from a checkpoint file and work out where to resume.

    The file is ``args.pretrained`` when ``pretrained`` is true, otherwise
    ``args.resume_from_checkpoint``. Optimizer and LR-scheduler state are not
    restored here; whatever the file holds besides the model weights is handed
    back in the returned checkpoint dict.

    Returns:
        tuple: (resume_from_epoch, resume_from_step, checkpoint)
    """
    ckpt_path = args.pretrained if pretrained else args.resume_from_checkpoint
    if args.rank == 0:
        print(f"Loading checkpoint from {ckpt_path}")
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(ckpt_path, map_location="cpu")

    # pull the weights out of the checkpoint, or treat the whole file as weights
    if "model_state_dict" in checkpoint:
        raw_state = checkpoint.pop("model_state_dict")
    else:
        print("No 'model_state_dict' found. Using entire checkpoint as model state dict.")
        raw_state = checkpoint
    # drop every "module." occurrence from the parameter names
    msd = {name.replace("module.", ""): value for name, value in raw_state.items()}

    # if the saved latents do not match the current model's shape, keep the
    # model's own (freshly initialised) latents instead
    latents_key = 'vision_tokenizer.latents'
    if latents_key in msd.keys():
        msd_current = model.state_dict()
        if msd_current[latents_key].shape != msd[latents_key].shape:
            msd[latents_key] = msd_current[latents_key]  # Random re-init.

    # an epoch-level checkpoint resumes at the next epoch ...
    resume_from_epoch = None if pretrained else checkpoint["epoch"] + 1
    # ... while a step-level checkpoint resumes inside the saved epoch
    if 'step' in checkpoint and checkpoint["step"] is not None:
        resume_from_step = checkpoint["step"] + 1
        resume_from_epoch = checkpoint["epoch"]
    else:
        resume_from_step = 0

    if args.fsdp:
        FSDP.set_state_dict_type(
            model,
            **args.fsdp_checkpoint_config,
        )
    result = model.load_state_dict(msd, strict=False)
    # Print missing and unexpected keys
    print("Missing keys:", result.missing_keys)
    print("Unexpected keys:", result.unexpected_keys)
    return resume_from_epoch, resume_from_step, checkpoint
def filter_state_dict_to_trainable(model, state_dict):
    """
    Strip frozen parameters and known duplicate entries from a model state dict.

    Frozen parameters are deleted from ``state_dict`` in place; the returned
    dict is a new one that additionally omits the duplicate language-model keys.

    Exception (carried over from the original documentation): embeddings are
    meant to be kept even if frozen, so that the new <image> / <|endofchunk|>
    tokens stay consistent across initializations.
    """
    # first, remove frozen params
    for param_name, param in model.named_parameters():
        if "fsdp" in param_name or param.requires_grad:
            continue
        key = param_name.replace("._checkpoint_wrapped_module", "")
        if key not in state_dict:
            print(f"WARNING: filtering but {key} not in state_dict")
            continue
        del state_dict[key]

    # second, remove additional duplicate params
    duplicate_markers = (
        "lang_model.old_decoder_blocks",
        "lang_model.gated_cross_attn_layers",
    )
    return {
        key: value
        for key, value in state_dict.items()
        if not any(marker in key for marker in duplicate_markers)
    }
def save_checkpoint(model, optimizer, lr_scheduler, epoch, args, step=None):
    """
    Save training checkpoint with model, optimizer, and lr_scheduler state.

    The state dicts are gathered on every rank (the FSDP calls are made
    regardless of rank); only rank 0 writes the file. The file is named
    ``checkpoint_<step>.pt`` when ``step`` is given, else
    ``checkpoint_<epoch>.pt``.

    Args:
        model: DDP / FSDP wrapped model
        optimizer: optimizer whose state is saved
        lr_scheduler: learning rate scheduler whose state is saved
        epoch (int): current epoch, stored in the checkpoint
        args (argparse.Namespace): arguments from command line
        step (int, optional): step within the epoch for step-level checkpoints
    """
    torch.cuda.empty_cache()  # (Sometimes this is necessary to avoid OOM errors when saving checkpoints)

    if args.fsdp:
        FSDP.set_state_dict_type(
            model,
            **args.fsdp_checkpoint_config,
        )
        model_state = model.state_dict()
        optim_state = FSDP.optim_state_dict(model, optimizer)
    else:
        model_state = model.state_dict()
        optim_state = optimizer.state_dict()

    if args.rank != 0:
        return

    model_state = filter_state_dict_to_trainable(model, model_state)

    # exist_ok avoids the check-then-create race of exists() + makedirs()
    os.makedirs(args.run_name, exist_ok=True)

    checkpoint_dict = {
        "epoch": epoch,
        "step": step,
        "model_state_dict": model_state,
        "optimizer_state_dict": optim_state,
        "lr_scheduler_state_dict": lr_scheduler.state_dict(),
    }
    # optimizer / scheduler state is only dropped for epoch-level checkpoints
    if args.no_save_optim_state and step is None:
        del checkpoint_dict['optimizer_state_dict']
        del checkpoint_dict['lr_scheduler_state_dict']

    if step is not None:
        save_name = f"{args.run_name}/checkpoint_{step}.pt"
    else:
        save_name = f"{args.run_name}/checkpoint_{epoch}.pt"
    print(f"Saving checkpoint to {save_name}")
    torch.save(checkpoint_dict, save_name)
    if args.report_to_wandb and args.save_checkpoints_to_wandb:
        wandb.save(f"{save_name}")

    if args.delete_previous_checkpoint:
        if epoch > 0:
            previous = f"{args.run_name}/checkpoint_{epoch-1}.pt"
            # Step- and epoch-numbered files share one naming scheme, so the
            # "previous" name can be the file just written, or may already be
            # gone after an earlier cleanup; neither case should abort training.
            if previous != save_name:
                with suppress(FileNotFoundError):
                    os.remove(previous)
        else:
            # never consider the checkpoint that was just written
            older_checkpoints = [
                path
                for path in glob.glob(f"{args.run_name}/checkpoint_*.pt")
                if path != save_name
            ]
            if older_checkpoints:
                oldest_checkpoint = sorted(
                    older_checkpoints,
                    key=lambda x: int(x.split("_")[-1].split(".")[0]),
                )[0]
                os.remove(f"{oldest_checkpoint}")
torchvision
braceexpand
webdataset
tqdm
wandb
\ No newline at end of file
#!/bin/bash
# Launch xGen-MM (BLIP-3) instruction fine-tuning on 4 devices of a single node.
# Console output is mirrored to ${exp_name}/terminal_output.log.

# Without pipefail the pipeline's status is tee's, which hides a failed run.
set -o pipefail

exp_name="finetune-xgenmmv1-phi3_4k_instruct"
data_path="/blip-3_pytorch/data_configs/example_data_config.yaml"

# The run directory must exist before tee can open the log file inside it.
mkdir -p "${exp_name}"

pretrained_ckpt="/blip-3_pytorch/pretrain_model/xgen-mm-phi3-mini-base-r-v1.5.pt"

HIP_VISIBLE_DEVICES=4,5,6,7 python -m torch.distributed.run --nproc_per_node=4 --nnodes=1 --master_port 9650 /blip-3_pytorch/open_flamingo/train/instruction_finetune.py \
    --lm_path /blip-3_pytorch/pretrain_model/Phi-3-mini-4k-instruct \
    --tokenizer_path /blip-3_pytorch/pretrain_model/Phi-3-mini-4k-instruct \
    --conv_template_name phi_3 \
    --vision_encoder_path /blip-3_pytorch/pretrain_model/siglip-so400m-patch14-384 \
    --vision_encoder_pretrained google \
    --model_family 'xgenmm_v1' \
    --num_vision_tokens 128 \
    --pretrained "${pretrained_ckpt}" \
    --data_path "${data_path}" \
    --data_sampler_group_by_length \
    --image_aspect_ratio anyres --anyres_patch_sampling \
    --batch_size 8 \
    --fsdp \
    --no_save_optim_state \
    --gradient_checkpointing \
    --fsdp_sharding_strategy hybrid \
    --workers 4 \
    --num_epochs 1 \
    --warmup_steps 2000 \
    --learning_rate 2e-5 \
    --weight_decay 0.0 \
    --lr_scheduler cosine \
    --precision amp_bf16 \
    --run_name "${exp_name}" 2>&1 | tee "${exp_name}/terminal_output.log"

# Optional Weights & Biases flags; add them to the command above to enable:
# --report_to_wandb \
# --wandb_project "blip3-xgenmm-finetune" \
from pathlib import Path
from setuptools import find_packages, setup
if __name__ == "__main__":
    # Resolve every input file relative to this script. README.md was already
    # read this way, while the requirements files were opened relative to the
    # current working directory and so failed when invoked from elsewhere.
    project_root = Path(__file__).parent

    long_description = (project_root / "README.md").read_text(encoding="utf-8")
    REQUIREMENTS = (
        (project_root / "requirements.txt").read_text(encoding="utf-8").splitlines()
    )
    # requirements-eval.txt is intentionally not read; the "eval" extra is disabled.
    TRAINING = (
        (project_root / "requirements-training.txt")
        .read_text(encoding="utf-8")
        .splitlines()
    )

    setup(
        name="open_flamingo",
        packages=find_packages(),
        include_package_data=True,
        version="2.0.1",
        license="MIT",
        description="An open-source framework for training large multimodal models",
        long_description=long_description,
        long_description_content_type="text/markdown",
        data_files=[(".", ["README.md"])],
        keywords=["machine learning"],
        install_requires=REQUIREMENTS,
        extras_require={
            "training": TRAINING,
            # sorted() keeps the combined list deterministic across builds
            "all": sorted(set(REQUIREMENTS + TRAINING)),
        },
        classifiers=[
            "Development Status :: 4 - Beta",
            "Intended Audience :: Developers",
            "Topic :: Scientific/Engineering :: Artificial Intelligence",
            "License :: OSI Approved :: MIT License",
            "Programming Language :: Python :: 3.9",
        ],
    )
{
"architectures": [
"XGenMMModelForConditionalGeneration"
],
"auto_map": {
"AutoConfig": "modeling_xgenmm.XGenMMConfig",
"AutoModelForVision2Seq": "modeling_xgenmm.XGenMMModelForConditionalGeneration"
},
"model_type": "xgenmm",
"text_config": {
"attention_dropout": 0.0,
"embd_pdrop": 0.0,
"hidden_act": "silu",
"hidden_size": 3072,
"initial_tokenizer_len": 32012,
"initializer_range": 0.02,
"intermediate_size": 8192,
"max_position_embeddings": 4096,
"model_type": "phi3",
"num_attention_heads": 32,
"num_hidden_layers": 32,
"num_key_value_heads": 32,
"original_max_position_embeddings": 4096,
"partial_rotary_factor": 1.0,
"resid_pdrop": 0.0,
"rms_norm_eps": 1e-05,
"rope_scaling": null,
"rope_theta": 10000.0,
"sliding_window": 2047,
"torch_dtype": "bfloat16",
"use_cache": true,
"vocab_size": 32064
},
"torch_dtype": "float32",
"transformers_version": "4.51.1",
"vision_encoder_config": {
"anyres_patch_sampling": true,
"image_aspect_ratio": "anyres",
"model_name": "google/siglip-so400m-patch14-384",
"model_type": "xgenmm_vision_encoder"
},
"vision_tokenizer_config": {
"image_aspect_ratio": "none",
"lang_embedding_dim": 3072,
"model_type": "xgenmm_vision_tokenizer",
"num_vis_tokens": 128,
"vis_feature_dim": 1152
}
}
import ast
import math
from einops import rearrange, repeat
from einops_exts import rearrange_many
from einops import rearrange
from PIL import Image
import torch
from torch import einsum, nn
from typing import List, Optional, Tuple, Union
import torch.nn.functional as F
from transformers.modeling_outputs import CausalLMOutputWithPast
from dataclasses import dataclass
from transformers import CLIPVisionModel
from transformers import PreTrainedModel, AutoModelForCausalLM, AutoModel
from transformers import PretrainedConfig, logging, CONFIG_MAPPING
from transformers.models.siglip.modeling_siglip import SiglipVisionTransformer
logger = logging.get_logger(__name__)
class XGenMMVisionEncoderConfig(PretrainedConfig):
    """Configuration of the vision encoder: records which vision model to use."""

    model_type = "xgenmm_vision_encoder"

    def __init__(
        self,
        model_name: str = "google/siglip-so400m-patch14-384",
        **kwargs,
    ):
        # set our own field first, then let the base class consume the rest
        self.model_name = model_name
        super().__init__(**kwargs)
class XGenMMVisionTokenizerConfig(PretrainedConfig):
    """
    Configuration of the vision tokenizer: the visual feature width, the
    language embedding width, the number of visual tokens, and the
    image-aspect-ratio mode.
    """

    model_type = "xgenmm_vision_tokenizer"

    def __init__(
        self,
        vis_feature_dim: int = 1152,
        lang_embedding_dim: int = 3072,
        num_vis_tokens: int = 128,
        image_aspect_ratio: str = "none",
        **kwargs,
    ):
        # store the tokenizer-specific fields before delegating to the base class
        self.vis_feature_dim = vis_feature_dim
        self.lang_embedding_dim = lang_embedding_dim
        self.num_vis_tokens = num_vis_tokens
        self.image_aspect_ratio = image_aspect_ratio
        super().__init__(**kwargs)
class XGenMMConfig(PretrainedConfig):
    """
    Top-level xGen-MM configuration, composed of three sub-configurations:
    the vision encoder, the vision tokenizer and the text (language) model.
    """

    model_type = "xgenmm"

    def __init__(
        self,
        vision_encoder_config: Optional[dict] = None,
        vision_tokenizer_config: Optional[dict] = None,
        text_config: Optional[dict] = None,
        **kwargs,
    ):
        """
        Args:
            vision_encoder_config (dict, optional): keyword arguments for
                XGenMMVisionEncoderConfig; defaults to anyres settings.
            vision_tokenizer_config (dict, optional): keyword arguments for
                XGenMMVisionTokenizerConfig; defaults to its own defaults.
            text_config (dict, optional): keyword arguments for the text model
                config class chosen by its "model_type" entry ("phi3" when the
                entry is absent); defaults to the Phi-3 values below.
        Raises:
            ValueError: if the resulting text config lacks
                `initial_tokenizer_len` or `pad_token_id`.
        """
        if vision_encoder_config is None:
            vision_encoder_config = {
                "image_aspect_ratio": "anyres",
                "anyres_patch_sampling": True,
            }
            logger.info(
                "vision_encoder_config is None. initializing the XGenMMVisionEncoderConfig with default values."
            )

        if vision_tokenizer_config is None:
            vision_tokenizer_config = {}
            logger.info(
                "vision_tokenizer_config is None. Initializing the XGenMMVisionTokenizerConfig with default values."
            )

        if text_config is None:
            text_config = {
                "initial_tokenizer_len": 32012,
                # NOTE(review): this literal used to list pad_token_id,
                # bos_token_id and eos_token_id twice. Python keeps the last
                # value of a repeated key, so the effective pad_token_id was
                # 32000 and the earlier 32011 was dead. The effective values
                # are kept here; confirm which pad id is actually intended.
                "pad_token_id": 32000,
                "bos_token_id": 1,
                "eos_token_id": 32000,
                "vocab_size": 32064,
                "hidden_size": 3072,
                "intermediate_size": 8192,
                "num_hidden_layers": 32,
                "num_attention_heads": 32,
                "num_key_value_heads": 32,
                "resid_pdrop": 0.0,
                "embd_pdrop": 0.0,
                "attention_dropout": 0.0,
                "hidden_act": "silu",
                "max_position_embeddings": 4096,
                "original_max_position_embeddings": 4096,
                "initializer_range": 0.02,
                "rms_norm_eps": 1e-05,
                "use_cache": True,
                "rope_theta": 10000.0,
                "rope_scaling": None,
                "sliding_window": 2047,
                "return_dict": True,
                "output_hidden_states": False,
                "output_attentions": False,
                "torchscript": False,
                "torch_dtype": "bfloat16",
                "use_bfloat16": False,
                "tf_legacy_loss": False,
                "pruned_heads": {},
                "tie_word_embeddings": False,
                "chunk_size_feed_forward": 0,
                "is_encoder_decoder": False,
                "is_decoder": False,
                "cross_attention_hidden_size": None,
                "add_cross_attention": False,
                "tie_encoder_decoder": False,
                "max_length": 20,
                "min_length": 0,
                "do_sample": False,
                "early_stopping": False,
                "num_beams": 1,
                "num_beam_groups": 1,
                "diversity_penalty": 0.0,
                "temperature": 1.0,
                "top_k": 50,
                "top_p": 1.0,
                "typical_p": 1.0,
                "repetition_penalty": 1.0,
                "length_penalty": 1.0,
                "no_repeat_ngram_size": 0,
                "encoder_no_repeat_ngram_size": 0,
                "bad_words_ids": None,
                "num_return_sequences": 1,
                "output_scores": False,
                "return_dict_in_generate": False,
                "forced_bos_token_id": None,
                "forced_eos_token_id": None,
                "remove_invalid_values": False,
                "exponential_decay_length_penalty": None,
                "suppress_tokens": None,
                "begin_suppress_tokens": None,
                "finetuning_task": None,
                "id2label": {0: "LABEL_0", 1: "LABEL_1"},
                "label2id": {"LABEL_0": 0, "LABEL_1": 1},
                "tokenizer_class": None,
                "prefix": None,
                "sep_token_id": None,
                "decoder_start_token_id": None,
                "task_specific_params": None,
                "problem_type": None,
                "model_type": "phi3",
            }
            logger.info(
                "text_config is None. Initializing the text config with default values (`Phi3Config`)."
            )

        self.vision_encoder_config = XGenMMVisionEncoderConfig(**vision_encoder_config)

        self.vision_tokenizer_config = XGenMMVisionTokenizerConfig(
            **vision_tokenizer_config
        )

        # the text config class is looked up by model type, defaulting to Phi-3
        text_model_type = text_config.get("model_type", "phi3")
        self.text_config = CONFIG_MAPPING[text_model_type](**text_config)

        # both keys must be present in the serialized text config
        text_config_dict = self.text_config.to_dict()
        for key in ["initial_tokenizer_len", "pad_token_id"]:
            if key not in text_config_dict:
                raise ValueError(f"The key `{key}` is missing in the text_config.")

        super().__init__(**kwargs)

    @classmethod
    def from_vision_encoder_vision_tokenizer_text_configs(
        cls,
        vision_encoder_config: XGenMMVisionEncoderConfig,
        vision_tokenizer_config: XGenMMVisionTokenizerConfig,
        text_config: PretrainedConfig,
        **kwargs,
    ):
        """Alternate constructor taking config objects instead of plain dicts."""
        return cls(
            vision_encoder_config=vision_encoder_config.to_dict(),
            vision_tokenizer_config=vision_tokenizer_config.to_dict(),
            text_config=text_config.to_dict(),
            **kwargs,
        )
def has_fn(model, fn_name):
    """Return True when `model` has an attribute `fn_name` that is callable."""
    candidate = getattr(model, fn_name, None)
    return callable(candidate)
def exists(val):
    """Return True for anything except None (falsy values such as 0 count as existing)."""
    if val is None:
        return False
    return True
def num_params(module, filter_to_trainable=False):
    """Count the elements of the module's parameters; with `filter_to_trainable`, only those that require grad."""
    total = 0
    for param in module.parameters():
        if filter_to_trainable and not param.requires_grad:
            continue
        total += param.numel()
    return total
def hasattr_recursive(obj, att):
    """
    Check if obj has nested attribute
    Example: hasattr_recursive(obj, 'a.b.c') is equivalent to hasattr(obj, 'a') and hasattr(obj.a, 'b') and hasattr(obj.a.b, 'c')

    An empty path is considered present. Only AttributeError is treated as
    "missing", matching the built-in hasattr used for the final segment; the
    previous bare ``except`` also swallowed KeyboardInterrupt, SystemExit and
    unrelated errors raised while resolving an intermediate attribute.
    """
    while att != "":
        head, dot, att = att.partition(".")
        if not dot:
            # last segment: plain hasattr decides
            return hasattr(obj, head)
        try:
            obj = getattr(obj, head)
        except AttributeError:
            return False
    return True
def getattr_recursive(obj, att):
    """
    Resolve a dotted attribute path on obj and return the result.
    Example: getattr_recursive(obj, 'a.b.c') is equivalent to obj.a.b.c
    An empty path returns obj itself.
    """
    remaining = att
    while remaining != "":
        # peel off one segment at a time, left to right
        head, _, remaining = remaining.partition(".")
        obj = getattr(obj, head)
    return obj
def setattr_recursive(obj, att, val):
    """
    Assign val to a dotted attribute path on obj.
    Example: setattr_recursive(obj, 'a.b.c', val) is equivalent to obj.a.b.c = val
    """
    # everything before the last dot names the owner, the rest is the attribute
    owner_path, dot, leaf = att.rpartition(".")
    if dot:
        obj = getattr_recursive(obj, owner_path)
    setattr(obj, leaf, val)
def stack_with_padding(list_of_tensors, padding_value=0, padding_side="right"):
    """
    Stack a list of tensors with padding on one side

    Tensors are padded along their first dimension up to the longest one and
    then stacked into a new leading batch dimension. Trailing dimensions are
    carried over unchanged, so tensors of any rank >= 1 are supported (the
    earlier version built the padding for 1-D and 2-D tensors only).

    Args:
        list_of_tensors (list[torch.Tensor]): List of tensors to stack
        padding_value (int, optional): Value to pad with. Defaults to 0.
        padding_side (str, optional): Side to pad on. Defaults to "right";
            any other value pads on the left.
    Returns:
        torch.Tensor: Stacked tensors
    Raises:
        ValueError: if list_of_tensors is empty
    """
    if len(list_of_tensors) == 0:
        raise ValueError("stack_with_padding requires at least one tensor")

    max_tokens = max(tensor.size(0) for tensor in list_of_tensors)
    padded_tensors = []
    for tensor in list_of_tensors:
        # pad only the first dimension; keep every trailing dimension as is
        pad_shape = (max_tokens - tensor.size(0), *tensor.shape[1:])
        padding = torch.full(
            pad_shape,
            padding_value,
            dtype=tensor.dtype,
            device=tensor.device,
        )
        pieces = (tensor, padding) if padding_side == "right" else (padding, tensor)
        padded_tensors.append(torch.cat(pieces, dim=0))
    return torch.stack(padded_tensors)
def check_embedding_fns(lang_model):
    """
    Checks for and attempts to set {get/set}_{input/output}_embeddings functions to the model

    For each accessor the model does not already provide as a callable, a
    lambda is attached that reads or writes a known attribute path:
    ``transformer.wte`` (MPT) or ``model.decoder.embed_tokens`` (OPT) for the
    input embeddings, and ``lm_head`` for the output embeddings.

    Raises:
        ValueError: if an accessor is missing and no known path exists
    """
    if not has_fn(lang_model, "get_input_embeddings"):
        if hasattr_recursive(lang_model, "transformer.wte"):  # MPT
            lang_model.get_input_embeddings = lambda: lang_model.transformer.wte
        elif hasattr_recursive(lang_model, "model.decoder.embed_tokens"):  # OPT
            # read the same path that was just checked (and that the setter
            # below writes); the getter used to drop the leading ".model"
            lang_model.get_input_embeddings = (
                lambda: lang_model.model.decoder.embed_tokens
            )
        else:
            raise ValueError(
                "We require the language encoder to have a get_input_embeddings method but we couldn't determine the name of the input embeddings attribute. Please supply this manually in factory.py."
            )

    if not has_fn(lang_model, "set_input_embeddings"):
        if hasattr_recursive(lang_model, "transformer.wte"):  # MPT
            lang_model.set_input_embeddings = lambda x: setattr_recursive(
                lang_model, "transformer.wte", x
            )
        elif hasattr_recursive(lang_model, "model.decoder.embed_tokens"):  # OPT
            lang_model.set_input_embeddings = lambda x: setattr_recursive(
                lang_model, "model.decoder.embed_tokens", x
            )
        else:
            raise ValueError(
                "We require the language encoder to have a set_input_embeddings method but we couldn't determine the name of the input embeddings attribute. Please supply this manually in factory.py."
            )

    if not has_fn(lang_model, "get_output_embeddings"):
        if hasattr_recursive(lang_model, "lm_head"):
            lang_model.get_output_embeddings = lambda: lang_model.lm_head
        else:
            raise ValueError(
                "We require the language encoder to have a get_output_embeddings method but we couldn't determine the name of the output embeddings attribute. Please supply this manually in factory.py."
            )

    if not has_fn(lang_model, "set_output_embeddings"):
        if hasattr_recursive(lang_model, "lm_head"):
            lang_model.set_output_embeddings = lambda x: setattr_recursive(
                lang_model, "lm_head", x
            )
        else:
            raise ValueError(
                "We require the language encoder to have a set_output_embeddings method but we couldn't determine the name of the output embeddings attribute. Please supply this manually in factory.py."
            )
def has_fn(model, fn_name):
    """Tell whether `model` exposes a callable attribute named `fn_name`."""
    try:
        member = getattr(model, fn_name)
    except AttributeError:
        # a missing attribute counts as "no such function"
        return False
    return callable(member)
def unpad_image(tensor, original_size, keep_original_shape=False):
    """
    Remove (or mask out) the symmetric padding of a padded-and-resized image tensor.
    Args:
        tensor (torch.Tensor): The image tensor; its last two dimensions are read as
            (height, width) and it must have exactly one leading dimension.
        original_size (tuple): The original image size, unpacked as (width, height).
        keep_original_shape (bool): If True, the tensor is returned untouched together
            with a (height, width) mask that is 0 on the padded border and 1 elsewhere.
            If False, the padded border is cropped away and no mask is returned.
    Returns:
        tuple: (tensor, attention_mask) when keep_original_shape is True,
            otherwise (unpadded_tensor, None).
    """
    orig_w, orig_h = original_size
    cur_h, cur_w = tensor.shape[1:]

    # The padding sits above/below the content when the original image is
    # relatively wider than the padded tensor, and left/right otherwise.
    pad_along_height = (orig_w / orig_h) > (cur_w / cur_h)
    if pad_along_height:
        content_extent = int(orig_h * (cur_w / orig_w))
        pad = (cur_h - content_extent) // 2
    else:
        content_extent = int(orig_w * (cur_h / orig_h))
        pad = (cur_w - content_extent) // 2

    if not keep_original_shape:
        if pad_along_height:
            return tensor[:, pad : cur_h - pad, :], None
        return tensor[:, :, pad : cur_w - pad], None

    mask = torch.ones((cur_h, cur_w), device=tensor.device)
    if pad_along_height:
        mask[:pad, :] = 0
        mask[cur_h - pad :, :] = 0
    else:
        mask[:, :pad] = 0
        mask[:, cur_w - pad :] = 0
    return tensor, mask
def expand2square(pil_img, background_color):
    """
    Pad an image to a square canvas, centring it along the shorter side.
    Args:
        pil_img: image object exposing `.size` as (width, height) and `.mode`
        background_color: fill colour handed to `Image.new` for the new canvas
    Returns:
        `pil_img` itself when it is already square, otherwise a new square image
        with `pil_img` pasted in the middle of the shorter axis.
    """
    w, h = pil_img.size
    if w == h:
        return pil_img

    side = max(w, h)
    # Wide images are shifted down, tall images are shifted right.
    offset = (0, (w - h) // 2) if w > h else ((h - w) // 2, 0)
    canvas = Image.new(pil_img.mode, (side, side), background_color)
    canvas.paste(pil_img, offset)
    return canvas
class VisionTokenizer(nn.Module):
    """Base module that records the media feature size and the number of tokens emitted per media item."""

    def __init__(self, dim_media, num_tokens_per_media):
        """
        Args:
            dim_media: stored as `self.dim_media`
            num_tokens_per_media: stored as `self.num_tokens_per_media`
        """
        super().__init__()
        self.num_tokens_per_media = num_tokens_per_media
        self.dim_media = dim_media
class PerceiverAttention(nn.Module):
    """Attention layer whose queries come from the latents and whose keys/values come from media features plus the latents."""

    def __init__(self, *, dim, dim_head=64, heads=8):
        """
        Args:
            dim (int): feature dimension of both the media features and the latents
            dim_head (int, optional): dimension of each attention head. Defaults to 64.
            heads (int, optional): number of attention heads. Defaults to 8.
        """
        super().__init__()
        self.scale = dim_head**-0.5
        self.heads = heads
        inner_dim = dim_head * heads

        self.norm_media = nn.LayerNorm(dim)
        self.norm_latents = nn.LayerNorm(dim)

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

    def forward(self, x, latents, vision_attn_masks=None):
        """
        Args:
            x (torch.Tensor): image features
                shape (b, T, n1, D)
            latents (torch.Tensor): latent features
                shape (b, T, n2, D)
            vision_attn_masks (torch.Tensor, optional): mask over the image features;
                it is extended with ones for the latents and repeated with the
                pattern "b n -> b 1 1 l n", so it is expected to be 2-D.
        Returns:
            torch.Tensor: output of `to_out` applied to the attended latents
        """
        media = self.norm_media(x)
        normed_latents = self.norm_latents(latents)
        n_heads = self.heads

        q = self.to_q(normed_latents)
        # Keys and values cover the media tokens followed by the latents themselves.
        kv_source = torch.cat((media, normed_latents), dim=-2)
        k, v = self.to_kv(kv_source).chunk(2, dim=-1)
        q, k, v = rearrange_many((q, k, v), "b t n (h d) -> b h t n d", h=n_heads)
        q = q * self.scale

        # attention logits
        sim = einsum("... i d, ... j d -> ... i j", q, k)

        if vision_attn_masks is not None:
            # Latent positions are always attendable, so their mask entries are ones.
            latent_mask = torch.ones(
                (normed_latents.shape[0], normed_latents.shape[-2]),
                dtype=normed_latents.dtype,
                device=normed_latents.device,
            )
            key_mask = torch.cat((vision_attn_masks, latent_mask), dim=-1)
            # Additive bias: 0 where attention is allowed, -inf where it is masked.
            # Reference: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html#torch.nn.functional.scaled_dot_product_attention
            attn_bias = torch.zeros(
                (q.size(0), 1, 1, q.size(-2), k.size(-2)),
                dtype=q.dtype,
                device=q.device,
            )
            key_mask = repeat(key_mask, "b n -> b 1 1 l n", l=q.size(-2))
            attn_bias.masked_fill_(key_mask.logical_not(), float("-inf"))
            sim += attn_bias

        # Subtract the row maximum before the softmax for numerical stability.
        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
        weights = sim.softmax(dim=-1)

        attended = einsum("... i j, ... j d -> ... i d", weights, v)
        attended = rearrange(attended, "b h t n d -> b t n (h d)", h=n_heads)
        return self.to_out(attended)
def FeedForward(dim, mult=4):
    """
    Build the feed-forward sub-block: LayerNorm -> Linear -> GELU -> Linear (both linears without bias).
    Args:
        dim (int): input and output feature dimension
        mult (int or float, optional): the hidden width is int(dim * mult). Defaults to 4.
    Returns:
        nn.Sequential: the assembled block
    """
    hidden_dim = int(dim * mult)
    stages = [
        nn.LayerNorm(dim),
        nn.Linear(dim, hidden_dim, bias=False),
        nn.GELU(),
        nn.Linear(hidden_dim, dim, bias=False),
    ]
    return nn.Sequential(*stages)
class PerceiverResampler(VisionTokenizer):
    """Perceiver-style resampler that turns image features into a fixed number of latent tokens."""

    def __init__(
        self,
        *,
        dim,
        dim_inner=None,
        depth=6,
        dim_head=96,
        heads=16,
        num_latents=128,
        max_num_media=None,
        max_num_frames=None,
        ff_mult=4,
    ):
        """
        Perceiver module which takes in image features and outputs image tokens.
        Args:
            dim (int): dimension of the incoming image features
            dim_inner (int, optional): final dimension to project the outputted features to.
                If None, no projection is used and the output keeps dimension `dim`.
            depth (int, optional): number of layers. Defaults to 6.
            dim_head (int, optional): dimension of each head. Defaults to 96.
            heads (int, optional): number of heads. Defaults to 16.
            num_latents (int, optional): number of latent tokens to use in the Perceiver;
                also corresponds to number of tokens per sequence to output. Defaults to 128.
            max_num_media (int, optional): maximum number of media per sequence to input into the Perceiver
                and keep positional embeddings for. If None, no media-time embeddings are used.
            max_num_frames (int, optional): maximum number of frames to input into the Perceiver
                and keep positional embeddings for. If None, no frame embeddings are used.
            ff_mult (int, optional): dimension multiplier for the feedforward network. Defaults to 4.
        """
        # The projection is built before the parent initialiser runs and is only
        # attached to the module afterwards.
        projection = nn.Linear(dim, dim_inner) if dim_inner is not None else None

        super().__init__(dim_media=dim, num_tokens_per_media=num_latents)
        self.projection = projection
        self.latents = nn.Parameter(torch.randn(num_latents, dim))

        # positional embeddings
        if exists(max_num_frames):
            self.frame_embs = nn.Parameter(torch.randn(max_num_frames, dim))
        else:
            self.frame_embs = None
        if exists(max_num_media):
            self.media_time_embs = nn.Parameter(torch.randn(max_num_media, 1, dim))
        else:
            self.media_time_embs = None

        self.layers = nn.ModuleList([])
        for _ in range(depth):
            attention = PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads)
            feed_forward = FeedForward(dim=dim, mult=ff_mult)
            self.layers.append(nn.ModuleList([attention, feed_forward]))

        self.norm = nn.LayerNorm(dim)

    def forward(self, x, vision_attn_masks=None):
        """
        Args:
            x (torch.Tensor): image features
                shape (b, T, F, v, D)
            vision_attn_masks (torch.Tensor): attention masks for padded vision tokens (i.e., x)
                shape (b, v)
        Returns:
            shape (b, T, n, D) where n is self.num_latents
        """
        batch, n_media, n_frames, n_tokens = x.shape[:4]

        # frame and media time embeddings
        if exists(self.frame_embs):
            x = x + repeat(
                self.frame_embs[:n_frames],
                "F d -> b T F v d",
                b=batch,
                T=n_media,
                v=n_tokens,
            )
        # flatten the frame and spatial dimensions
        x = rearrange(x, "b T F v d -> b T (F v) d")
        if exists(self.media_time_embs):
            x = x + self.media_time_embs[:n_media]

        # blocks: each layer is attention followed by feed-forward, both residual
        latents = repeat(self.latents, "n d -> b T n d", b=batch, T=n_media)
        for attention, feed_forward in self.layers:
            latents = attention(x, latents, vision_attn_masks) + latents
            latents = feed_forward(latents) + latents

        if not exists(self.projection):
            return self.norm(latents)
        return self.projection(self.norm(latents))
class DecoupledEmbedding(nn.Embedding):
    # Derived from https://pytorch.org/docs/stable/_modules/torch/nn/modules/sparse.html#Embedding
    """
    Implements a decoupling of parameters to allow freezing (or not) a subset of the embeddings. In practice, the
    regular `weight` can be trained or frozen (i.e. `partially_freeze=True`), and if `num_additional_embeddings` > 0,
    then it will create `num_additional_embeddings` additional parameters that are always trained. If
    `num_additional_embeddings=0`, then the module defaults back to the regular behavior of `nn.Embedding`.
    """

    def __init__(
        self,
        max_original_id: int,
        num_additional_embeddings: int = 0,
        _weight: torch.Tensor = None,
        num_original_embeddings: int = None,
        embedding_dim: int = None,
        partially_freeze=True,
        device=None,
        dtype=None,
        pad_token_id=None,
    ) -> None:
        """
        Args:
            max_original_id (`int`):
                The largest token id that should be embedded using the regular embedding (regular `weight`).
                This is usually len(tokenizer) - 1 before additional tokens are added.
                Note that this may not equal self.weight.shape[0]
            num_additional_embeddings (`int`):
                Number of additional tokens to initialize an Embedding matrix for (`additional_embedding`).
            _weight (`torch.Tensor`, *optional*, defaults to `None`): The regular weight tensor.
                If provided, this sets the `num_original_embeddings` and `embedding_dim` parameters.
            num_original_embeddings (`int`):
                self.weight.shape[0]
            embedding_dim (`int`):
                The size of each embedding vector
            partially_freeze: (`bool`, *optional*, defaults to `True`):
                If `True`, the regular `weight` will be frozen. `additional_embedding` is never frozen.
            pad_token_id (`int`, *optional*):
                The padding index, forwarded to `nn.Embedding` as `padding_idx`.
                Must be <= max_original_id.
        Raises:
            ValueError: if `pad_token_id` is larger than `max_original_id`.
            AssertionError: if the size arguments disagree with `_weight`, or are missing
                when `_weight` is not given.
        Note: other `nn.Embedding` options such as `max_norm` or `norm_type` are not supported.
        """
        # validate args
        if pad_token_id is not None and pad_token_id > max_original_id:
            raise ValueError(
                f"pad_token_id must be <= max_original_id. Got {pad_token_id} and {max_original_id}."
                + "If the original tokenizer does not have a pad_token_id, use pad_token_id=None."
            )
        if _weight is not None:
            assert (num_original_embeddings is None) or (
                _weight.shape[0] == num_original_embeddings
            ), f"num_original_embeddings={num_original_embeddings} but _weight.shape[0]={_weight.shape[0]}"
            assert (embedding_dim is None) or (
                _weight.shape[1] == embedding_dim
            ), f"embedding_dim={embedding_dim} but _weight.shape[1]={_weight.shape[1]}"
            num_original_embeddings = _weight.shape[0]
            embedding_dim = _weight.shape[1]
        else:
            assert (
                num_original_embeddings is not None
            ), "num_original_embeddings must be provided if _weight is not provided"
            assert (
                embedding_dim is not None
            ), "embedding_dim must be provided if _weight is not provided"

        super().__init__(
            num_embeddings=num_original_embeddings,
            embedding_dim=embedding_dim,
            device=device,
            dtype=dtype,
            padding_idx=pad_token_id,
            _weight=_weight,
        )
        self.max_original_id = max_original_id
        self.padding_idx = pad_token_id
        self.num_additional_embeddings = num_additional_embeddings
        # The extra table only exists when additional tokens were requested.
        if self.num_additional_embeddings > 0:
            self.additional_embedding = nn.Embedding(
                num_embeddings=self.num_additional_embeddings,
                embedding_dim=embedding_dim,
                device=device,
                dtype=dtype,
            )
        self.set_requires_grad(
            require_regular_grad=not partially_freeze, require_additional_grad=True
        )

    def set_requires_grad(self, require_regular_grad, require_additional_grad):
        """
        Helper function to separately set the requires_grad flag for the regular weight and the additional weight.
        `require_additional_grad` is ignored when the module has no additional embeddings.
        """
        self.weight.requires_grad_(require_regular_grad)
        # `additional_embedding` is absent when num_additional_embeddings == 0;
        # touching it unconditionally would raise AttributeError from __init__.
        if self.num_additional_embeddings > 0:
            self.additional_embedding.requires_grad_(require_additional_grad)

    def forward(self, input_ids):
        """
        we have 2 embeddings, with different indices - one pretrained self.weight and another
        self.additional_embedding.weight that is being trained.
        in order to make a lookup of the input ids, we:
        1. find out the indices of the entries belonging to the 2nd embedding
        2. extract those values while subtracting the size of the first embedding (max_original_id + 1), since the
           2nd embedding starts from 0 and not max_original_id + 1
        3. perform the 2nd embedding lookup
        4. now we handle the 1st embedding, we overwrite indices belonging to the 2nd embedding with index 0
        5. perform the 1st embedding lookup
        6. now we overwrite the values in the 1st embedding lookup with the values of the 2nd embedding lookup
        note: for the 1st embedding lookup we could have looked up only the low indices and not do the padding, but
        then we have to create a new tensor and populate it with 2 tensors that are spread out across various indices -
        i.e. not a simple concat - I haven't benchmarked the complex case if it's any faster, given that seqlens are
        usually relatively short it's probably not faster or if faster not by much - but might be a good idea to
        measure.
        """
        if self.num_additional_embeddings == 0:
            return F.embedding(input_ids, self.weight)

        # Clone so that we don't modify the original input_ids later on
        input_ids = input_ids.clone()
        additional_vocab_indices = torch.where(input_ids > self.max_original_id)
        input_ids_additional_vocab = input_ids[additional_vocab_indices]
        additional_embeddings = self.additional_embedding(
            input_ids_additional_vocab - self.max_original_id - 1
        )

        # for successful lookup replace input_ids with 0, the results of these will be discarded anyway
        input_ids[additional_vocab_indices] = 0
        full_vector = F.embedding(input_ids, self.weight)

        # overwrite the records with high indices
        full_vector[additional_vocab_indices] = additional_embeddings
        return full_vector

    def extra_repr(self) -> str:
        """Describe the module using the decoupled sizes instead of the raw `nn.Embedding` ones."""
        return "num_original_embeddings={}, num_additional_embeddings={}, embedding_dim={}, partially_freeze={}".format(
            self.max_original_id + 1,
            self.num_additional_embeddings,
            self.embedding_dim,
            (not self.weight.requires_grad),
        )
class DecoupledLinear(nn.Linear):
    # Derived from https://pytorch.org/docs/stable/_modules/torch/nn/modules/linear.html#Linear
    """
    Implements a decoupling of parameters to allow freezing (or not) a subset of the parameters. In practice, the
    regular `weight` can be trained or frozen (i.e. `partially_freeze=True`), and if `additional_out_features` > 0,
    then it will create `additional_out_features * in_features` additional parameters that are always trained. If
    `additional_out_features=0`, then the module defaults back to the regular behavior of `nn.Linear`
    (with the output truncated to the first `max_original_id + 1` features).
    """

    def __init__(
        self,
        max_original_id: int,
        additional_out_features: int = 0,
        _weight: torch.Tensor = None,
        _bias: torch.Tensor = None,
        in_features: int = None,
        original_out_features: int = None,
        bias: bool = True,
        partially_freeze: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        """
        Args:
            max_original_id (`int`): The largest token id that should be extracted from the regular weight.
                This is usually len(tokenizer) - 1 before additional tokens are added.
                Note that this may not equal original_out_features - 1
            additional_out_features: int. Number of additional trainable dimensions.
            _weight: torch.Tensor, *optional*, defaults to `None`. The regular weight tensor.
                If provided, this sets the `in_features` and `original_out_features` parameters.
            _bias: torch.Tensor, *optional*, defaults to `None`. The regular bias tensor.
                When it is `None` and `bias` is True, the regular bias keeps the default
                `nn.Linear` initialization.
            in_features: int. Input hidden size.
            original_out_features: int. Original out_features of the language model's get_output_embeddings() function.
            bias: bool. Whether to include a bias term (applies to both the regular and the additional layer).
            partially_freeze: bool, *optional*, defaults to `True`): If `True`, the regular `weight`
                (and bias, if any) will be frozen.
        Raises:
            AssertionError: if the size arguments disagree with `_weight`, are missing when
                `_weight` is not given, or `_bias` is given while `bias` is not True.
        """
        # argument validation
        if _weight is not None:
            assert (_weight.shape[0] == original_out_features) or (
                original_out_features is None
            ), f"original_out_features={original_out_features} but _weight.shape[0]={_weight.shape[0]}"
            assert (_weight.shape[1] == in_features) or (
                in_features is None
            ), f"in_features={in_features} but _weight.shape[1]={_weight.shape[1]}"
            in_features = _weight.shape[1]
            original_out_features = _weight.shape[0]
        else:
            assert (
                in_features is not None
            ), "in_features must be provided if _weight is not provided"
            assert (
                original_out_features is not None
            ), "original_out_features must be provided if _weight is not provided"
        if _bias is not None:
            assert bias is True, "bias must be True if _bias is provided"

        # initialize original linear
        super().__init__(in_features, original_out_features, bias, device, dtype)

        # set weight and bias manually
        if _weight is not None:
            self.weight = nn.Parameter(_weight)
        if _bias is not None:
            self.bias = nn.Parameter(_bias)
        self.in_features = in_features
        self.original_out_features = original_out_features
        self.max_original_id = max_original_id

        # initialize additional linear (only when extra output features were requested)
        self.additional_out_features = additional_out_features
        self.has_bias = bias
        if additional_out_features > 0:
            self.additional_fc = nn.Linear(
                in_features=in_features,
                out_features=additional_out_features,
                bias=self.has_bias,
                device=device,
                dtype=dtype,
            )
        self.set_requires_grad(
            require_regular_grad=not partially_freeze, require_additional_grad=True
        )

    def set_requires_grad(self, require_regular_grad, require_additional_grad):
        """
        Helper function to separately set the requires_grad flag for the regular weight and the additional weight.
        `require_additional_grad` is ignored when the module has no additional output features.
        """
        self.weight.requires_grad_(require_regular_grad)
        if self.bias is not None:
            self.bias.requires_grad_(require_regular_grad)
        # `additional_fc` is absent when additional_out_features == 0;
        # touching it unconditionally would raise AttributeError from __init__.
        if self.additional_out_features > 0:
            self.additional_fc.requires_grad_(require_additional_grad)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """
        Project `input` with the regular weight, keep the first `max_original_id + 1` outputs,
        and append the outputs of the additional layer (if any) along the last dimension.
        """
        output = F.linear(input, self.weight, self.bias)
        output = output[..., : self.max_original_id + 1]

        if self.additional_out_features > 0:
            additional_features = F.linear(
                input, self.additional_fc.weight, self.additional_fc.bias
            )
            output = torch.cat((output, additional_features), -1)
        return output

    def extra_repr(self) -> str:
        """Overwriting `nn.Linear.extra_repr` to include new parameters."""
        # The bias may be None (bias=False), so only inspect its flag when it exists.
        bias_frozen = self.bias is not None and not self.bias.requires_grad
        return "in_features={}, out_features={}, additional_out_features={}, bias={}, partially_freeze={}".format(
            self.in_features,
            self.max_original_id + 1,
            self.additional_out_features,
            self.bias is not None,
            (not self.weight.requires_grad or bias_frozen),
        )
class VLM(nn.Module):
"""
Generic vision-language model (VLM) class.
A VLM consists of four components:
1. A vision encoder that extracts features from pixels, e.g. CLIP
input: (B, T_img, F, C, H, W)
output: (B, T_img, F, v, d)
2. A vision tokenizer that converts these features to visual token-like embeddings, e.g. Perceiver, or a linear projection head
input: (B, T_img, F, v, d)
output: (B, T_img, n, d)
3. A fusion method that allows the language model to attend to these tokens, e.g. cross-attention, or placing the tokens directly in the language model's input sequence
4. A language model
"""
def __init__(
    self,
    vision_encoder: nn.Module,
    vision_tokenizer: nn.Module,
    lang_model: nn.Module,
    initial_tokenizer_len: int,
    pad_token_id: int,
    gradient_checkpointing: bool = False,
):
    """
    Wire the three components together and swap the language model's input/output
    embeddings for decoupled versions that have extra rows for the special tokens.
    Args:
        vision_encoder (nn.Module): e.g. CLIP
        vision_tokenizer (nn.Module): e.g. PerceiverResampler
        lang_model (nn.Module): e.g. MPT
        initial_tokenizer_len (int): size of the original tokenizer vocab
        pad_token_id (int): id of the pad token
        gradient_checkpointing (bool, optional): Whether to use gradient checkpointing. Defaults to False.
    """
    super().__init__()

    # save dimension information
    lm_config = lang_model.config
    self.lang_embedding_dim = lang_model.get_input_embeddings().weight.shape[1]
    # mpt uses d_model, other configs expose hidden_size
    self.lang_hidden_dim = (
        lm_config.d_model if hasattr(lm_config, "d_model") else lm_config.hidden_size
    )
    self.vis_embedding_dim = vision_tokenizer.dim_media
    self.num_tokens_per_vis = vision_tokenizer.num_tokens_per_media

    # core components
    self.vision_encoder = vision_encoder
    self.vision_tokenizer = vision_tokenizer
    self.lang_model = lang_model

    # lm embeddings
    self.pad_token_id = pad_token_id
    self.initial_tokenizer_len = initial_tokenizer_len
    n_special = len(self.special_tokens)
    last_original_id = initial_tokenizer_len - 1
    # std used for the freshly created special-token rows
    init_std = getattr(self.lang_model.config, "initializer_range", 0.02)

    input_embeds = DecoupledEmbedding(
        max_original_id=last_original_id,
        num_additional_embeddings=n_special,
        _weight=self.lang_model.get_input_embeddings().weight,
        pad_token_id=self.pad_token_id,
    )
    if hasattr(input_embeds, "additional_embedding"):
        input_embeds.additional_embedding.weight.data.normal_(mean=0.0, std=init_std)
    self.lang_model.set_input_embeddings(input_embeds)

    # NOTE(review): when the original head has no bias, `_bias` is None but
    # DecoupledLinear still defaults to bias=True, so a default-initialized bias is
    # created - confirm this is intended.
    original_head = self.lang_model.get_output_embeddings()
    out_embeds = DecoupledLinear(
        max_original_id=last_original_id,
        additional_out_features=n_special,
        _weight=original_head.weight,
        _bias=getattr(original_head, "bias", None),
    )
    if hasattr(out_embeds, "additional_fc"):
        out_embeds.additional_fc.weight.data.normal_(mean=0.0, std=init_std)
    self.lang_model.set_output_embeddings(out_embeds)

    # gradient checkpointing
    self.vision_tokenizer._use_gradient_checkpointing = gradient_checkpointing
def forward(
self,
vision_x: Optional[torch.Tensor],
lang_x: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
past_key_values: Optional[
List[Union[torch.Tensor, Tuple[torch.Tensor]]]
] = None,
past_media_locations: Optional[torch.Tensor] = None,
past_vision_tokens: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = False,
**kwargs,
):
"""
Args:
vision_x: Vision input
shape (B, T_img, F, C, H, W) with F=1
only F = 1 is supported (single-frame videos)
if T_img > the number of media tokens in the corresponding input_ids (lang_x),
only the first number of media tokens in lang_x are used
lang_x: Language input ids, with media tokens denoting where
visual media should be inserted.
shape (B, T_txt)
attention_mask: Attention mask. Defaults to None.
labels: Labels. Defaults to None.
shape (B, T_txt)
past_key_values (Tuple[torch.Tensor]], optional): Past key value pairs for each of the T_txt previous tokens in the language model. Defaults to None.
list of length = number of decoder layers in the LM
exact implementation depends on LM, see Hugging Face docs
past_media_locations (torch.Tensor, optional): boolean mask denoting which of the previous T_txt tokens were media tokens. Defaults to None.
shape (B, T_txt)
past_vision_tokens (torch.Tensor, optional): Previous vision tokens. Defaults to None.
use_cache (Optional[bool], optional): Whether to use cache. Defaults to False.
If True, includes key_values, media_locations, and vision_tokens in the output.
"""
assert not (past_vision_tokens is None) ^ (
past_media_locations is None
), "past_vision_tokens and past_media_locations must both be None or both be not None"
# convert pixels to vision tokens
if vision_x is not None:
vision_features = self._encode_vision_x(vision_x=vision_x)
vision_tokens = self.vision_tokenizer(vision_features)
else:
vision_tokens = None
# fuse the vision and language tokens
new_inputs = self._prepare_inputs_for_forward(
vision_tokens=vision_tokens,
lang_x=lang_x,
attention_mask=attention_mask,
labels=labels,
past_key_values=past_key_values,
past_media_locations=past_media_locations,
padding_side="right",
past_vision_tokens=past_vision_tokens,
)
output = self.lang_model(
**new_inputs,
use_cache=use_cache,
past_key_values=past_key_values,
**kwargs,
)
# postprocessing may be needed, e.g. to remove extra tokens from logits that were inserted into the language stream
# or to add the past_vision_tokens and past_media_locations to the output
output = self._postprocess_outputs_from_forward(
output=output,
lang_x=lang_x,
vision_tokens=vision_tokens,
use_cache=use_cache,
past_vision_tokens=past_vision_tokens,
past_media_locations=past_media_locations,
)
# postforward hooks
self._post_forward_hook()
return output
def _encode_vision_x_anyres(self, samples, device):
    """
    Encode any-resolution images: run every patch through the vision encoder, then
    merge the patch features belonging to each image.
    Args:
        samples: mapping read with the keys "image" (a list of patch tensors, or a list of
            such lists when a sample holds several images) and "image_size" (the matching
            original image sizes).
        device: device the concatenated patches are moved to before encoding.
    Returns:
        If `self.anyres_patch_sampling` is truthy: (list of per-image patch embeddings,
        list of per-image patch attention masks).
        Otherwise: (image_embeds, image_atts), where image_embeds is the zero-padded batch
        indexed as [B, 1, 1, N_tok_longest, C] and image_atts is a [B, N_tok_longest] mask
        that is 0 at padded positions.
    """
    assert self.anyres_grids is not None
    image_raw = samples[
        "image"
    ]  # list of patch list in of shape [1, N_patch, C, H, W]
    image_sizes = samples["image_size"]
    # Image_raw can be a list of list of patches, when a `samples` has multiple images.
    if isinstance(image_raw[0], list):
        images = [x.squeeze(0) for sample_img in image_raw for x in sample_img]
        image_sizes = [s for sample_sizes in image_sizes for s in sample_sizes]
    else:
        # assert isinstance(image_raw[0], torch.Tensor), f"Unknown image type: {image_raw[0]}"
        # concatenate the list of patches into one big batch for any res encoding.
        images = [x.squeeze(0) for x in image_raw]  # [N_patch, C, H, W]
    image = torch.cat(images, dim=0)  # [\sum{B}{N_patch_i}, C, H, W]
    image = image.to(device)
    # The vision encoder is run without gradient tracking.
    with torch.no_grad():
        if self.vision_encoder.__class__.__name__ == "TimmModel":
            image_embeds = self.vision_encoder.trunk.forward_features(image)
        elif self.vision_encoder.__class__.__name__ in [
            "CLIPVisionModel",
            "SiglipVisionTransformer",
        ]:
            image_embeds = self.vision_encoder(image).last_hidden_state
        else:
            image_embeds = self.vision_encoder(image)[1]  # OpenCLIP returns tuples
    # Side length of one base image, read from wherever this encoder type stores it.
    if isinstance(self.vision_encoder, CLIPVisionModel) or isinstance(
        self.vision_encoder, SiglipVisionTransformer
    ):
        base_img_size = self.vision_encoder.config.image_size
    else:
        base_img_size = self.vision_encoder.image_size[0]
    # Token grid (height, width) produced by the encoder for a single patch.
    if self.vision_encoder.__class__.__name__ == "TimmModel":
        grid_size = self.vision_encoder.trunk.patch_embed.grid_size
    elif self.vision_encoder.__class__.__name__ in [
        "CLIPVisionModel",
        "SiglipVisionTransformer",
    ]:
        grid_size_base = (
            self.vision_encoder.config.image_size
            // self.vision_encoder.config.patch_size
        )
        grid_size = (grid_size_base, grid_size_base)
    else:
        grid_size = self.vision_encoder.grid_size
    height, width = grid_size
    if not image_embeds.shape[1] == height * width:
        assert (
            image_embeds.shape[1] == height * width + 1
        )  # For vision encoders that has [CLS] token.
        image_embeds = image_embeds[:, 1:, :]  # Drop the cls token for each patch.
    n_vis_token_per_patch = image_embeds.shape[1]
    # Split encoded patches and merge patch features
    # 1. Get the raw sizes from samples, and split the image embeds [\sum_{B}(N_patch_i), N_tok(16*16), C]
    split_sizes = [image.shape[0] for image in images]
    image_embeds = torch.split(image_embeds, split_sizes, dim=0)
    # 2. For each image (consist of a list of patches), merge the patches spatially (of shape [C, n_patch_height, n_patch_width])
    new_image_embeds = []
    patch_attn_masks = []
    max_n_img_token = -1  # longest token count seen; only tracked when not patch sampling
    for idx, patch_embeds in enumerate(image_embeds):
        if patch_embeds.shape[0] > 1:
            # 3. Flatten the patch features and get [C, n_patch_height * (n_patch_width+1)]
            base_patch_embeds = patch_embeds[
                0
            ]  # TODO: prepend the CLS token for th base patch embeds (of the resized entire image).
            patch_embeds = patch_embeds[1:]
            assert height * width == base_patch_embeds.shape[0]
            num_patch_width, num_patch_height = get_anyres_image_grid_shape(
                image_sizes[idx], self.anyres_grids, base_img_size
            )  # Hardcoded grid_pinpoints.
            patch_embeds = patch_embeds.view(
                num_patch_height, num_patch_width, height, width, -1
            )
            # Rearrange to [C, rows of patches * height, cols of patches * width].
            patch_embeds = patch_embeds.permute(4, 0, 2, 1, 3).contiguous()
            patch_embeds = patch_embeds.flatten(1, 2).flatten(2, 3)
            # Third argument is keep_original_shape: with patch sampling the padding is
            # masked rather than cropped.
            patch_embeds, patch_attn_mask = unpad_image(
                patch_embeds, image_sizes[idx], self.anyres_patch_sampling
            )
            if hasattr(self, "image_newline"):
                # Append one newline feature column at the end of every row.
                patch_embeds = torch.cat(
                    (
                        patch_embeds,
                        self.image_newline[:, None, None].expand(
                            *patch_embeds.shape[:-1], 1
                        ),
                    ),
                    dim=-1,
                )
            if self.anyres_patch_sampling:
                patch_embeds = patch_embeds.view(
                    -1, num_patch_height, num_patch_width, height * width
                )
                patch_embeds = patch_embeds.flatten(1, 2).permute(1, 2, 0)
                assert patch_attn_mask is not None
                patch_attn_mask = patch_attn_mask.view(
                    num_patch_height, num_patch_width, height * width
                )
                patch_attn_mask = patch_attn_mask.flatten(0, 1)
                # The base (whole-image) patch goes first and is fully visible.
                patch_embeds = torch.cat(
                    (base_patch_embeds.unsqueeze(0), patch_embeds), dim=0
                )
                patch_attn_mask = torch.cat(
                    (
                        torch.ones(
                            n_vis_token_per_patch, device=patch_embeds.device
                        ).unsqueeze(0),
                        patch_attn_mask,
                    ),
                    dim=0,
                )
            else:
                patch_embeds = patch_embeds.flatten(1, 2).transpose(0, 1)
                patch_embeds = torch.cat((base_patch_embeds, patch_embeds), dim=0)
        else:
            # Single-patch image: only the base patch exists.
            patch_embeds = (
                patch_embeds[0].unsqueeze(0)
                if self.anyres_patch_sampling
                else patch_embeds[0]
            )
            patch_attn_mask = (
                torch.ones(
                    n_vis_token_per_patch, device=patch_embeds.device
                ).unsqueeze(0)
                if self.anyres_patch_sampling
                else None
            )
            # NOTE(review): with anyres_patch_sampling on, patch_embeds is 3-D here while
            # image_newline[None] is 2-D - confirm the two options are never combined.
            if hasattr(self, "image_newline"):
                patch_embeds = torch.cat(
                    (patch_embeds, self.image_newline[None]), dim=0
                )
        if not self.anyres_patch_sampling:
            max_n_img_token = max(patch_embeds.shape[0], max_n_img_token)
        new_image_embeds.append(patch_embeds)
        patch_attn_masks.append(patch_attn_mask)
    if self.anyres_patch_sampling:
        # Return individual patches for independent token downsampling.
        return new_image_embeds, patch_attn_masks
    # 4. Pad and concat the list of image_embeds [N_tok_i, C] together into a batch. Also modify the query attention mask.
    image_embeds = []
    image_atts = []
    for image_embed in new_image_embeds:
        n_img_token = image_embed.shape[0]
        img_attn = torch.ones(
            (max_n_img_token), dtype=torch.long, device=image_embed.device
        )
        if n_img_token < max_n_img_token:
            padded_embed = torch.zeros(
                (max_n_img_token, image_embed.shape[-1]),
                dtype=image_embed.dtype,
                device=image_embed.device,
            )
            padded_embed[:n_img_token, :] = image_embed
            img_attn[n_img_token:] = 0  # Mask out the padded entries.
        else:
            padded_embed = image_embed
        image_embeds.append(padded_embed)
        image_atts.append(img_attn)
    image_embeds = torch.stack(
        image_embeds, dim=0
    )  # Shape [B, N_tok_longest, C_dim]
    image_atts = torch.stack(image_atts, dim=0)  # Shape [B, N_tok_longest]
    # TODO: reshape image_embeds and image_atts to "b T F v d"
    image_embeds = image_embeds[:, None, None, :, :]
    # image_atts = image_atts[:, None, None, :, :]
    return image_embeds, image_atts
def _encode_vision_x(self, vision_x: torch.Tensor):
    """
    Run the vision encoder (without gradient tracking) over every frame of every media item.
    Args:
        vision_x: Vision input
            shape (B, T_img, F, C, H, W)
            Images in the same chunk are collated along T_img, and frames are collated along F
            Currently only F=1 is supported (single-frame videos)
    Returns:
        Encoder features regrouped with the pattern "(b T F) v d -> b T F v d".
    rearrange code based on https://github.com/dhansmair/flamingo-mini
    """
    assert vision_x.ndim == 6, "vision_x should be of shape (b, T_img, F, C, H, W)"
    batch, n_media, n_frames = vision_x.shape[:3]

    # Fold batch, media and frame axes together so the encoder sees a flat image batch.
    flat_images = rearrange(vision_x, "b T F c h w -> (b T F) c h w")

    encoder = self.vision_encoder
    encoder_kind = encoder.__class__.__name__
    with torch.no_grad():
        if encoder_kind == "TimmModel":
            features = encoder.trunk.forward_features(flat_images)
        elif encoder_kind in ["CLIPVisionModel", "SiglipVisionTransformer"]:
            features = encoder(flat_images).last_hidden_state
        else:
            # OpenCLIP returns tuples
            features = encoder(flat_images)[1]

    return rearrange(
        features, "(b T F) v d -> b T F v d", b=batch, T=n_media, F=n_frames
    )
def _concat_vision_cache(
self, lang_x, vision_tokens, past_vision_tokens, past_media_locations, use_cache
):
"""
Helper function to include the past vision tokens and past media locations in the output.
"""
if use_cache:
if past_media_locations is not None and past_vision_tokens is not None:
if vision_tokens is not None:
updated_vision_tokens = torch.cat(
[
past_vision_tokens,
vision_tokens,
],
dim=1,
)
else:
updated_vision_tokens = past_vision_tokens
updated_media_locations = torch.cat(
[
past_media_locations,
lang_x == self.media_token_id,
],
dim=1,
)
else:
updated_vision_tokens = vision_tokens
updated_media_locations = lang_x == self.media_token_id
else:
updated_vision_tokens = None
updated_media_locations = None
return updated_vision_tokens, updated_media_locations
def generate(
self,
vision_x: torch.Tensor,
lang_x: torch.Tensor,
attention_mask: torch.Tensor = None,
past_key_values: Optional[
List[Union[torch.Tensor, Tuple[torch.Tensor]]]
] = None,
past_media_locations: Optional[torch.Tensor] = None,
past_vision_tokens: Optional[torch.Tensor] = None,
**kwargs,
):
"""
Generate text conditioned on vision and language inputs.
Args:
vision_x (torch.Tensor): Vision input
shape (B, T_img, F, C, H, W)
see documentation for forward
lang_x (torch.Tensor): Language input
shape (B, T_txt)
attention_mask (torch.Tensor, optional): Attention mask. Defaults to None.
**kwargs: see generate documentation in Hugging Face CausalLM models.
Returns:
torch.Tensor: lang_x with generated tokens appended to it
"""
num_beams = kwargs.pop("num_beams", 1)
# convert pixels to vision tokens
if vision_x is not None:
vision_features = self._encode_vision_x(vision_x=vision_x)
vision_tokens = self.vision_tokenizer(vision_features)
else:
vision_tokens = None
# fuse the vision and language tokens
# for xattn, vision_x and media_location are repeat_interleaved s.t.
# the total batch size is B * num_beams
new_inputs = self._prepare_inputs_for_forward(
vision_tokens=vision_tokens,
lang_x=lang_x,
attention_mask=attention_mask,
past_key_values=past_key_values,
past_media_locations=past_media_locations,
past_vision_tokens=past_vision_tokens,
padding_side="left",
num_beams=num_beams,
)
output = self.lang_model.generate(
**new_inputs,
past_key_values=past_key_values,
num_beams=num_beams,
use_cache=True,
**kwargs,
)
self._post_forward_hook()
return output
@property
def num_trainable_params(self):
    """Return the number of trainable parameters, as computed by num_params."""
    trainable_count = num_params(self, filter_to_trainable=True)
    return trainable_count
def set_trainable(self):
    """
    Freeze appropriate parameters in the model.

    Abstract hook: this base implementation always raises
    NotImplementedError; subclasses are expected to override it.
    """
    raise NotImplementedError()
def group_params_by_weight_decay(self):
    """
    Split the trainable parameters into two lists.

    Returns:
        tuple: (params to optimize w/ weight decay, params to optimize w/o weight decay).
        Parameters with requires_grad == False appear in neither list.
    """
    decayed, undecayed = [], []
    for name, param in self.named_parameters():
        if not param.requires_grad:
            # frozen parameters are not handed to the optimizer at all
            continue
        bucket = decayed if self._should_apply_weight_decay(name) else undecayed
        bucket.append(param)
    return decayed, undecayed
def _should_apply_weight_decay(self, parameter_name):
    """
    Return whether weight decay should be applied to a parameter.

    Abstract hook: this base implementation always raises
    NotImplementedError; subclasses are expected to override it.
    """
    raise NotImplementedError()
@property
def special_tokens(self):
    """
    Returns a dict mapping from the attribute name of a special token to its string format,
    e.g. "media_token": "<image>"

    Asserts that a "media_token" entry is present in self._special_tokens.
    """
    token_map = self._special_tokens
    assert (
        "media_token" in token_map
    ), "VLMs need to request that the tokenizer add a media_token and call set_special_token_ids to set self.media_token_id"
    return token_map
@property
def special_token_ids(self):
    """
    Returns a list of the special token ids, one per entry of
    self.special_tokens, read from the matching "<name>_id" attribute.
    """
    ids = []
    for attr_name in self.special_tokens:
        ids.append(getattr(self, f"{attr_name}_id"))
    return ids
def set_special_token_ids(self, string_to_ids):
    """
    Store the id of every special token on both this object and self.lang_model,
    under the attribute name "<att_name>_id".

    Args:
        string_to_ids (dict): mapping from token string to id; must contain
            every string in self.special_tokens (checked with an assert).
    """
    required_strings = set(self.special_tokens.values())
    assert required_strings.issubset(set(string_to_ids.keys()))
    for attr_base, token_string in self.special_tokens.items():
        id_attr = f"{attr_base}_id"
        resolved_id = string_to_ids[token_string]
        for target in (self, self.lang_model):
            setattr(target, id_attr, resolved_id)
def init_gradient_checkpointing(self):
    """
    Wrap submodules flagged with `_use_gradient_checkpointing` in a
    non-reentrant activation-checkpoint wrapper.

    Modules that are already a CheckpointWrapper are left alone.
    """
    from functools import partial
    from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
        CheckpointImpl,
        CheckpointWrapper,
        apply_activation_checkpointing,
        checkpoint_wrapper,
    )

    def wants_checkpointing(module):
        # opt-in flag set on the module, and skip anything already wrapped
        return getattr(
            module, "_use_gradient_checkpointing", False
        ) and not isinstance(module, CheckpointWrapper)

    wrap_non_reentrant = partial(
        checkpoint_wrapper,
        checkpoint_impl=CheckpointImpl.NO_REENTRANT,
    )
    apply_activation_checkpointing(
        self,
        checkpoint_wrapper_fn=wrap_non_reentrant,
        check_fn=wants_checkpointing,
    )
@dataclass
class VLMOutputWithPast(CausalLMOutputWithPast):
    """
    CausalLMOutputWithPast extended with the vision-side cache.

    Additional attributes (both default to None):
        past_media_locations (Optional[torch.Tensor])
        past_vision_tokens (Optional[torch.Tensor])
    """

    # extra fields appended after those inherited from CausalLMOutputWithPast
    past_media_locations: Optional[torch.Tensor] = None
    past_vision_tokens: Optional[torch.Tensor] = None
def exists(val):
    """Return True for any value other than None (falsy values such as 0 count as existing)."""
    if val is None:
        return False
    return True
def FeedForward(dim, mult=4):
    """
    Build a LayerNorm -> Linear -> GELU -> Linear block.

    Args:
        dim: input and output feature size.
        mult: hidden size multiplier; the hidden width is int(dim * mult).
    Returns:
        nn.Sequential containing the four layers; both Linear layers are bias-free.
    """
    hidden = int(dim * mult)
    layers = [
        nn.LayerNorm(dim),
        nn.Linear(dim, hidden, bias=False),
        nn.GELU(),
        nn.Linear(hidden, dim, bias=False),
    ]
    return nn.Sequential(*layers)
class VLMWithLanguageStream(VLM):
    """
    VLM that fuses modalities by inserting vision tokens directly into the language stream.
    """

    def __init__(
        self,
        vision_encoder: nn.Module,
        vision_tokenizer: nn.Module,
        lang_model: nn.Module,
        initial_tokenizer_len: int,
        pad_token_id: int,
        decoder_layers_attr_name: str = None,
        gradient_checkpointing: bool = False,
    ):
        """
        Args:
            vision_encoder (nn.Module): forwarded to the parent constructor
            vision_tokenizer (nn.Module): forwarded to the parent constructor
            lang_model (nn.Module): forwarded to the parent constructor
            initial_tokenizer_len (int): forwarded to the parent constructor
            pad_token_id (int): forwarded to the parent constructor
            decoder_layers_attr_name (str, optional): attribute path of the decoder
                layers inside lang_model. When given, every layer found there gets
                its `_use_gradient_checkpointing` flag set. Defaults to None.
            gradient_checkpointing (bool, optional): value written to that flag.
                Defaults to False.
        """
        super().__init__(
            vision_encoder=vision_encoder,
            vision_tokenizer=vision_tokenizer,
            lang_model=lang_model,
            initial_tokenizer_len=initial_tokenizer_len,
            pad_token_id=pad_token_id,
            gradient_checkpointing=gradient_checkpointing,
        )
        self.decoder_layers_attr_name = decoder_layers_attr_name
        if decoder_layers_attr_name is not None:
            for block in getattr_recursive(
                self.lang_model, self.decoder_layers_attr_name
            ):
                block._use_gradient_checkpointing = gradient_checkpointing

    def _prepare_inputs_for_forward(
        self,
        vision_tokens: torch.Tensor,
        lang_x: torch.Tensor,
        attention_mask: torch.Tensor,
        labels: torch.Tensor = None,
        past_key_values=None,
        vision_attention_mask: Optional[torch.Tensor] = None,
        past_media_locations: torch.Tensor = None,
        past_vision_tokens: torch.Tensor = None,
        padding_side: str = "left",
        num_beams: int = 1,
    ):
        """
        Insert the vision tokens directly into the language stream.
        This requires us to modify the input_ids, attention_mask, and labels.

        For every <image> token found in lang_x, the `num_tokens_per_vis`
        embedding positions starting at that token are overwritten with the
        corresponding entry of vision_tokens; the same positions are set to 1
        in the attention mask and to -100 in the labels.

        Args:
            vision_tokens: indexed as vision_tokens[sample][image_number].
                When None, the language inputs are returned untouched.
            lang_x: token ids, indexed per sample along dim 0.
            attention_mask: mask aligned with lang_x. May be None, in which
                case no mask is built and None is returned for it.
            labels: optional labels aligned with lang_x.
            past_key_values: when given, only used to validate the length of
                attention_mask.
            vision_attention_mask, past_media_locations, past_vision_tokens,
            num_beams: accepted for interface compatibility; unused here.
            padding_side: forwarded to stack_with_padding.

        Returns:
            dict with "input_ids" (no vision tokens) or "inputs_embeds"
            (vision tokens given), plus "attention_mask" and "labels".
        """
        if past_key_values is not None:
            past_len = past_key_values[0][0].shape[2]
            assert attention_mask.shape[1] == past_len + lang_x.shape[1], (
                "Attention_mask must be as long as the entire past len (including image tokens) and current input IDs. "
                + "Check that you've expanded the attention mask to account for past image tokens."
            )

        if vision_tokens is None:
            return {
                "input_ids": lang_x,
                "attention_mask": attention_mask,
                "labels": labels,
            }

        # get the language embeddings
        lang_embeds = self.lang_model.get_input_embeddings()(lang_x)

        B = lang_x.shape[0]
        has_mask = attention_mask is not None
        has_labels = labels is not None
        num_vis = self.num_tokens_per_vis

        # Fillers written over each image span; built once, directly on the
        # target device, instead of once per image.
        vis_mask_fill = (
            torch.ones(num_vis, dtype=torch.long, device=attention_mask.device)
            if has_mask
            else None
        )
        vis_label_fill = (
            torch.full((num_vis,), -100, dtype=torch.long, device=labels.device)
            if has_labels
            else None
        )

        # build up the multimodal embeddings
        multimodal_embeds = []
        multimodal_attention_mask = [] if has_mask else None
        multimodal_labels = [] if has_labels else None
        for i in range(B):
            # get index of <image> tokens in lang_x[i]
            image_token_idxs = torch.where(lang_x[i] == self.media_token_id)[0]

            new_embed = lang_embeds[i].clone()
            new_attention_mask = attention_mask[i].clone() if has_mask else None
            new_label = labels[i].clone() if has_labels else None

            # loop through the image_token_idxs and insert the vision tokens;
            # a sample without <image> tokens simply keeps its clones
            for img_num, img_idx in enumerate(image_token_idxs):
                span_end = img_idx + num_vis
                new_embed = torch.cat(
                    (
                        new_embed[:img_idx],
                        vision_tokens[i][img_num],
                        new_embed[span_end:],
                    ),
                    dim=0,
                )
                if has_mask:
                    new_attention_mask = torch.cat(
                        (
                            new_attention_mask[:img_idx],
                            vis_mask_fill,
                            new_attention_mask[span_end:],
                        ),
                        dim=0,
                    )
                if has_labels:
                    new_label = torch.cat(
                        (
                            new_label[:img_idx],
                            vis_label_fill,
                            new_label[span_end:],
                        ),
                        dim=0,
                    )

            multimodal_embeds.append(new_embed)
            if has_mask:
                multimodal_attention_mask.append(new_attention_mask)
            if has_labels:
                multimodal_labels.append(new_label)

        # stack
        multimodal_embeds = stack_with_padding(
            multimodal_embeds,
            padding_value=self.pad_token_id,
            padding_side=padding_side,
        )
        if has_mask:
            multimodal_attention_mask = stack_with_padding(
                multimodal_attention_mask,
                padding_value=0,
                padding_side=padding_side,
            )
        if has_labels:
            multimodal_labels = stack_with_padding(
                multimodal_labels,
                padding_value=-100,
                padding_side=padding_side,
            )

        return {
            "inputs_embeds": multimodal_embeds,
            "attention_mask": multimodal_attention_mask,
            "labels": multimodal_labels,
        }

    def _postprocess_outputs_from_forward(
        self,
        output: CausalLMOutputWithPast,
        lang_x: torch.Tensor,
        vision_tokens: torch.Tensor,
        past_vision_tokens: torch.Tensor,
        past_media_locations: torch.Tensor,
        use_cache: bool = False,
    ):
        """
        Re-align the language model output with the original input_ids.

        The returned logits have one entry per position of lang_x: for an
        <image> token only the logit of its first vision position is kept and
        the remaining `num_tokens_per_vis - 1` positions are skipped.

        Returns:
            VLMOutputWithPast carrying the re-aligned logits, the loss / cache /
            hidden states / attentions of `output`, and the updated vision cache
            from _concat_vision_cache.
        """
        # Include the past vision tokens and past media locations in the output
        updated_vision_tokens, updated_media_locations = self._concat_vision_cache(
            lang_x=lang_x,
            vision_tokens=vision_tokens,
            past_vision_tokens=past_vision_tokens,
            past_media_locations=past_media_locations,
            use_cache=use_cache,
        )

        # return logits that are the same shape as the original input_ids
        logits = output.logits
        B, T_txt = lang_x.shape

        # A text token occupies one logit position, an <image> token occupies
        # num_tokens_per_vis of them. The logit to keep for input position j
        # therefore sits at the exclusive running sum of those widths.
        # note: the model actually learns to predict <im_patch>, not <image>
        is_media = lang_x == self.media_token_id
        width = 1 + is_media.long() * (self.num_tokens_per_vis - 1)
        first_pos = (torch.cumsum(width, dim=1) - width).to(logits.device)
        batch_idx = torch.arange(B, device=logits.device).unsqueeze(1)
        batch_logits = logits[batch_idx, first_pos]  # (B, T_txt, vocab_size)

        # The final logits shape should be the same as the original input_ids shape
        assert batch_logits.shape[:2] == (B, T_txt)

        # assemble the output
        output = VLMOutputWithPast(
            loss=output.loss,
            logits=batch_logits,
            past_key_values=output.past_key_values,
            hidden_states=output.hidden_states,
            attentions=output.attentions,
            past_media_locations=updated_media_locations,
            past_vision_tokens=updated_vision_tokens,
        )
        return output

    def _post_forward_hook(self):
        """No per-call cleanup is needed for this model; intentionally a no-op."""
        pass

    @property
    def num_params_per_module(self):
        """Return a newline-separated summary of the number of parameters per module in the model"""
        return "\n".join(
            [
                f"Vision encoder: {num_params(self.vision_encoder):,} parameters",
                f"Vision tokenizer: {num_params(self.vision_tokenizer):,} parameters",
                f"Language model: {num_params(self.lang_model):,} parameters",
            ]
        )

    @property
    def num_trainable_params_per_module(self):
        """Return a newline-separated summary of the number of trainable parameters per module in the model"""
        return "\n".join(
            [
                f"Vision encoder: {num_params(self.vision_encoder, filter_to_trainable=True):,} trainable parameters",
                f"Vision tokenizer: {num_params(self.vision_tokenizer, filter_to_trainable=True):,} trainable parameters",
                f"Language model: {num_params(self.lang_model, filter_to_trainable=True):,} trainable parameters",
            ]
        )
class XGenMMPerceiver(VLMWithLanguageStream):
    """VLMWithLanguageStream variant that registers the xGen-MM special tokens."""

    def __init__(
        self,
        vision_encoder: nn.Module,
        vision_tokenizer: nn.Module,
        lang_model: nn.Module,
        initial_tokenizer_len: int,
        pad_token_id: int,
        decoder_layers_attr_name: str = None,
        gradient_checkpointing: bool = False,
        image_aspect_ratio: str = "none",
    ):
        """
        Args:
            vision_encoder (nn.Module): HF CLIPModel
            vision_tokenizer (nn.Module): module that turns vision features into vision tokens
            lang_model (nn.Module): HF causal language model
            initial_tokenizer_len (int): size of the tokenizer vocab
            pad_token_id (int): id of the padding token. None if no padding token; then a padding token
                will be inserted into self.special_tokens, which factory.py fills after creating new tokens
            decoder_layers_attr_name (str, optional): name of the decoder layers attribute. Defaults to None.
            gradient_checkpointing (bool, optional): whether to use gradient checkpointing. Defaults to False.
            image_aspect_ratio (str, optional): stored on the instance as
                self.image_aspect_ratio. Defaults to "none".
        """
        # Set before super().__init__() runs, as in the original ordering.
        self._special_tokens = {
            "media_token": "<image>",
            "image_placeholder_token": "<image placeholder>",
            "end_of_trunk_token": "<|endofchunk|>",
        }
        super().__init__(
            vision_encoder=vision_encoder,
            vision_tokenizer=vision_tokenizer,
            lang_model=lang_model,
            initial_tokenizer_len=initial_tokenizer_len,
            gradient_checkpointing=gradient_checkpointing,
            decoder_layers_attr_name=decoder_layers_attr_name,
            pad_token_id=pad_token_id,
        )
        self.image_aspect_ratio = image_aspect_ratio

    def set_trainable(self):
        """
        Unfreeze everything except the vision_encoder
        """
        self.requires_grad_(True)
        self.vision_encoder.requires_grad_(False)

    def _should_apply_weight_decay(self, parameter_name):
        """
        Kosmos applies 0.01 weight decay to everything
        """
        return True

    def generate(
        self,
        vision_x: torch.Tensor,
        lang_x: torch.Tensor,
        image_size: Optional[Tuple] = None,
        attention_mask: torch.Tensor = None,
        past_key_values: Optional[
            List[Union[torch.Tensor, Tuple[torch.Tensor]]]
        ] = None,
        past_media_locations: Optional[torch.Tensor] = None,
        past_vision_tokens: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        """
        Generate text conditioned on vision and language inputs.
        Args:
            vision_x (torch.Tensor): Vision input
                shape (B, T_img, F, C, H, W)
                see documentation for forward
            lang_x (torch.Tensor): Language input
                shape (B, T_txt)
            image_size (tuple, optional): accepted for interface compatibility;
                not used by this implementation.
            attention_mask (torch.Tensor, optional): Attention mask. Defaults to None.
            past_key_values (optional): forwarded to _prepare_inputs_for_forward,
                and to lang_model.generate only when it is not None.
            past_media_locations, past_vision_tokens (optional): forwarded to
                _prepare_inputs_for_forward.
            **kwargs: see generate documentation in Hugging Face CausalLM models.
                "num_beams" is popped from here (default 1).
        Returns:
            torch.Tensor: lang_x with generated tokens appended to it
        """
        num_beams = kwargs.pop("num_beams", 1)

        # convert pixels to vision tokens
        vision_attention_mask = None  # no vision mask is produced on this path
        vision_tokens = None
        if vision_x is not None:
            vision_features = self._encode_vision_x(vision_x=vision_x)
            vision_tokens = self.vision_tokenizer(vision_features)

        # fuse the vision and language tokens
        # for xattn, vision_x and media_location are repeat_interleaved s.t.
        # the total batch size is B * num_beams
        new_inputs = self._prepare_inputs_for_forward(
            vision_tokens=vision_tokens,
            lang_x=lang_x,
            attention_mask=attention_mask,
            vision_attention_mask=vision_attention_mask,
            past_key_values=past_key_values,
            past_media_locations=past_media_locations,
            past_vision_tokens=past_vision_tokens,
            padding_side="left",
            num_beams=num_beams,
        )

        # Only hand a cache to the language model when one exists; the
        # keyword is omitted entirely otherwise.
        cache_kwargs = (
            {} if past_key_values is None else {"past_key_values": past_key_values}
        )
        output = self.lang_model.generate(
            **new_inputs,
            **cache_kwargs,
            num_beams=num_beams,
            use_cache=True,
            **kwargs,
        )
        self._post_forward_hook()
        return output
class XGenMMVisionEncoder(PreTrainedModel):
    """PreTrainedModel wrapper around the single supported vision backbone."""

    main_input_name = "pixel_values"
    config_class = XGenMMVisionEncoderConfig

    def __init__(self, config: XGenMMVisionEncoderConfig):
        """Load the backbone named by config.model_name; any other name than the supported one raises ValueError."""
        super().__init__(config)
        requested_name = config.model_name
        is_supported = requested_name == "google/siglip-so400m-patch14-384"
        if not is_supported:
            raise ValueError(
                f"Unsupported model {config.model_name}. New vision models will be added soon."
            )
        self.model = AutoModel.from_pretrained(requested_name)

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Return self.model.encode_image(pixel_values)."""
        # NOTE: a check that pixel_values is 4D (bs, c, h, w) is disabled here.
        image_features = self.model.encode_image(pixel_values)
        return image_features
# vision tokenizer
class XGenMMVisionTokenizer(PreTrainedModel):
    """PreTrainedModel wrapper that owns a PerceiverResampler as self.model."""

    config_class = XGenMMVisionTokenizerConfig

    def __init__(self, config: XGenMMVisionTokenizerConfig):
        """Build the resampler from the config's feature, embedding and token-count settings."""
        super().__init__(config)
        resampler_kwargs = {
            "dim": config.vis_feature_dim,
            "dim_inner": config.lang_embedding_dim,
            "num_latents": config.num_vis_tokens,
        }
        self.model = PerceiverResampler(**resampler_kwargs)

    def forward(self, vision_features: torch.Tensor, vision_attn_masks: torch.Tensor):
        """Delegate to the resampler, passing both arguments positionally."""
        resampled = self.model(vision_features, vision_attn_masks)
        return resampled
# XGenMM model
class XGenMMModelForConditionalGeneration(PreTrainedModel):
    """Top-level PreTrainedModel that assembles an XGenMMPerceiver as self.vlm."""

    config_class = XGenMMConfig

    def __init__(self, config: XGenMMConfig):
        """
        Build the vision encoder, language model and vision tokenizer from
        `config` and combine them into self.vlm.
        """
        super().__init__(config)

        # vision encoder: only the vision tower of the loaded model is kept
        loaded_vision_model = AutoModel.from_pretrained(
            config.vision_encoder_config.model_name
        )
        vision_encoder = loaded_vision_model.vision_model

        # language model: instantiated from config
        language_model = AutoModelForCausalLM.from_config(config.text_config)
        check_embedding_fns(language_model)

        # Update _tied_weights_keys using the base model used.
        base_tied_keys = language_model._tied_weights_keys
        if base_tied_keys is not None:
            self._tied_weights_keys = [
                f"language_model.{k}" for k in base_tied_keys
            ]

        # vision tokenizer: its language embedding dim must match the LM's
        tokenizer_config = config.vision_tokenizer_config
        lm_embedding_dim = language_model.get_input_embeddings().weight.shape[1]
        if tokenizer_config.lang_embedding_dim != lm_embedding_dim:
            overwrite = lm_embedding_dim
            config.vision_tokenizer_config.lang_embedding_dim = overwrite
            print(
                f"Warning: The language embedding dimension in the vision tokenizer config is different from the language model's embedding dimension. Overwriting the language embedding dimension in the vision tokenizer config to {overwrite}."
            )
        tokenizer_wrapper = XGenMMVisionTokenizer(config.vision_tokenizer_config)
        vision_tokenizer = tokenizer_wrapper.model

        self.vlm = XGenMMPerceiver(
            vision_encoder=vision_encoder,
            vision_tokenizer=vision_tokenizer,
            lang_model=language_model,
            initial_tokenizer_len=config.text_config.initial_tokenizer_len,
            pad_token_id=config.text_config.pad_token_id,
            image_aspect_ratio=config.vision_encoder_config.image_aspect_ratio,
        )

        # Initialize weights and apply final processing
        self.post_init()

    @torch.no_grad()
    def generate(
        self,
        pixel_values: torch.FloatTensor,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        **generate_kwargs,
    ) -> torch.LongTensor:
        """
        Switch self.vlm to eval mode and delegate to self.vlm.generate.

        pixel_values is passed as vision_x and input_ids as lang_x; any extra
        keyword arguments are forwarded unchanged.
        """
        self.vlm = self.vlm.eval()
        generated = self.vlm.generate(
            vision_x=pixel_values,
            lang_x=input_ids,
            attention_mask=attention_mask,
            **generate_kwargs,
        )
        return generated

    def update_special_tokens(self, tokenizer):
        """
        Register the VLM's special tokens with `tokenizer`, update the language
        model's configured vocab size, and record the resulting token ids on
        the VLM. Returns the same tokenizer.
        """
        special_strings = list(self.vlm.special_tokens.values())
        tokenizer.add_special_tokens({"additional_special_tokens": special_strings})
        self.vlm.lang_model.config.vocab_size = len(tokenizer)

        string_to_id = {}
        for token_string in self.vlm.special_tokens.values():
            string_to_id[token_string] = tokenizer.convert_tokens_to_ids(token_string)
        self.vlm.set_special_token_ids(string_to_id)
        return tokenizer
run-20241216_113420-xvp9nqy9/logs/debug-internal.log
\ No newline at end of file
run-20241216_113420-xvp9nqy9/logs/debug.log
\ No newline at end of file
run-20241216_113420-xvp9nqy9
\ No newline at end of file
_wandb:
value:
cli_version: 0.19.1
m: []
python_version: 3.10.12
t:
"1":
- 1
- 11
- 41
- 49
- 55
- 71
"2":
- 1
- 11
- 41
- 49
- 55
- 71
"3":
- 13
- 16
- 23
- 55
"4": 3.10.12
"5": 0.19.1
"6": 4.47.0
"8":
- 5
"12": 0.19.1
"13": linux-x86_64
anyres_grids:
value:
- - 1
- 2
- - 2
- 1
- - 2
- 2
- - 3
- 1
- - 1
- 3
anyres_patch_sampling:
value: true
batch_size:
value: 8
checkpoint_steps:
value: 5000
conv_template_name:
value: phi_3
cpu_offload_gradients:
value: false
cross_attn_every_n_layers:
value: 1
data_path:
value: /mnt/xgen-mm/LAVIS/data_configs/example_data_config.yaml
data_sampler_group_by_length:
value: true
delete_previous_checkpoint:
value: false
device:
value: cuda:0
dist_backend:
value: nccl
dist_url:
value: env://
distributed:
value: true
dryrun:
value: false
fsdp:
value: true
fsdp_sharding_strategy:
value: hybrid
gradient_accumulation_steps:
value: 1
gradient_checkpointing:
value: true
horovod:
value: false
image_aspect_ratio:
value: anyres
is_multimodal:
value: true
learning_rate:
value: 2e-05
lm_path:
value: microsoft/Phi-3-mini-4k-instruct
local_rank:
value: 0
logging_steps:
value: 100
loss:
value: supervised_finetune
lr_scheduler:
value: cosine
mm_use_im_start_end:
value: false
model_family:
value: xgenmm_v1
no_save_optim_state:
value: true
no_set_device_rank:
value: false
num_epochs:
value: 1
num_vision_tokens:
value: 128
offline:
value: false
precision:
value: amp_bf16
pretrained:
value: /mnt/xgen-mm/xgen-mm-phi3-mini-base-r-v1.5.pt
pretrained_vision_tokenizer:
value: null
rank:
value: 0
report_to_wandb:
value: true
resume_from_checkpoint:
value: null
run_name:
value: finetune-xgenmmv1-phi3_4k_instruct-example_data_config
save_checkpoints_to_wandb:
value: false
seed:
value: 42
tokenizer_path:
value: microsoft/Phi-3-mini-4k-instruct
unfreeze_vision_encoder:
value: false
use_flash_attention_2:
value: false
vision_encoder_path:
value: google/siglip-so400m-patch14-384
vision_encoder_precision:
value: fp32
vision_encoder_pretrained:
value: google
wandb_entity:
value: null
wandb_project:
value: blip3-xgenmm-finetune
warmup_steps:
value: 2000
weight_decay:
value: 0
workers:
value: 4
world_size:
value: 8
Found no checkpoints for run finetune-xgenmmv1-phi3_4k_instruct-example_data_config.
Loading checkpoint from /mnt/xgen-mm/xgen-mm-phi3-mini-base-r-v1.5.pt
Traceback (most recent call last):
File "/mnt/xgen-mm/LAVIS/open_flamingo/train/instruction_finetune.py", line 453, in <module>
main()
File "/mnt/xgen-mm/LAVIS/open_flamingo/train/instruction_finetune.py", line 333, in main
_, _, checkpoint = load_checkpoint(args, model, pretrained=True)
File "/mnt/xgen-mm/LAVIS/open_flamingo/train/train_utils.py", line 527, in load_checkpoint
msd = checkpoint.pop("model_state_dict")
KeyError: 'model_state_dict'
Traceback (most recent call last):
File "/mnt/xgen-mm/LAVIS/open_flamingo/train/instruction_finetune.py", line 453, in <module>
main()
File "/mnt/xgen-mm/LAVIS/open_flamingo/train/instruction_finetune.py", line 333, in main
_, _, checkpoint = load_checkpoint(args, model, pretrained=True)
File "/mnt/xgen-mm/LAVIS/open_flamingo/train/train_utils.py", line 527, in load_checkpoint
msd = checkpoint.pop("model_state_dict")
KeyError: 'model_state_dict'
setuptools==65.5.0
protobuf==3.20.3
einops-exts==0.0.4
GitPython==3.1.43
wcwidth==0.2.13
ftfy==6.3.1
braceexpand==0.1.7
omegaconf==2.3.0
wandb==0.19.1
smmap==5.0.1
webdataset==0.2.100
sentry-sdk==2.19.2
gitdb==4.0.11
docker-pycreds==0.4.0
antlr4-python3-runtime==4.9.3
setproctitle==1.3.4
transformers==4.47.0
tokenizers==0.21.0
annotated-types==0.7.0
torch==2.1.0+das.opt2.dtk24043
fused-dense-lib==0.1.0+das.opt1.dtk24043
python-dateutil==2.9.0.post0
jiter==0.7.0
diskcache==5.6.3
vllm==0.5.0+das.opt3.dtk24043
requests==2.32.3
triton==2.1.0+das.opt1.dtk24042
Pygments==2.18.0
pandas==2.2.3
pytest-asyncio==0.24.0
bitsandbytes==0.42.0+das.opt1.dtk24042
jsonschema-specifications==2024.10.1
xformers==0.0.25+das.opt1.dtk24042
humanfriendly==10.0
sentencepiece==0.2.0
yacs==0.1.8
safetensors==0.4.5
cloudpickle==3.1.0
lmdeploy==0.2.6+das.opt1.dtk24042
layer-check-pt==1.2.3+das.dtk24042
lightop==0.4.0+das.dtk24042
torch-spline-conv==1.2.1+das.opt1.dtk24042
ninja==1.11.1.1
fire==0.7.0
dill==0.3.8
outlines==0.1.1
h11==0.14.0
peft==0.9.0
certifi==2024.8.30
websockets==13.1
sympy==1.12.1
shortuuid==1.0.13
hypothesis==5.35.1
pluggy==1.5.0
rich==13.9.4
opencv-python==4.10.0.84
zipp==3.20.2
tiktoken==0.8.0
pillow==11.0.0
flash-attn==2.6.1+das.opt1.dtk24043
faiss==1.7.2+das.dtk24042
addict==2.4.0
tzdata==2024.2
starlette==0.41.2
tomli==2.0.2
aiosignal==1.3.1
mpmath==1.3.0
pyarrow==18.0.0
xxhash==3.5.0
fvcore==0.1.5.post20221221
fastpt==1.1.0+das.dtk24042
networkx==3.4.2
distro==1.9.0
referencing==0.35.1
regex==2024.9.11
pydantic_core==2.23.4
yarl==1.17.1
jsonschema==4.23.0
matplotlib==3.9.2
psutil==6.1.0
coloredlogs==15.0.1
cmake==3.30.5
multiprocess==0.70.16
uvloop==0.21.0
numpy==1.24.3
frozenlist==1.5.0
fonttools==4.54.1
einops==0.8.0
airportsdata==20241001
Jinja2==3.1.4
pynvml==11.5.3
datasets==3.1.0
mmengine-lite==0.10.5
attrs==24.2.0
tqdm==4.66.6
aitemplate==0.0.2+das.dtk24042
click==8.1.7
rotary-emb==0.1.0+das.opt1.dtk24043
watchfiles==0.24.0
mmcv==2.0.1+das.opt1.dtk24042
mdurl==0.1.2
aiohttp==3.10.10
accelerate==1.1.0
packaging==24.1
markdown-it-py==3.0.0
pytz==2024.2
contourpy==1.3.0
six==1.16.0
msgpack==1.1.0
torch-sparse==0.6.16+das.opt1.dtk24042
yapf==0.40.2
fastapi==0.115.4
diffusers==0.29.0+das.opt1.dtk24042
exceptiongroup==1.2.2
pytest==8.3.3
dropout-layer-norm==0.1.0+das.opt1.dtk24043
torch-cluster==1.6.0+das.opt1.dtk24042
openai==1.54.0
tinycudann==1.7+das.opt1.dtk24042
sortedcontainers==2.4.0
python-dotenv==1.0.1
ctranslate2==4.1.0+das.opt1.dtk24042
fsspec==2024.9.0
apex==1.1.0+das.opt1.dtk24042
interegular==0.3.3
prometheus_client==0.21.0
multidict==6.1.0
flatbuffers==24.3.25
anyio==4.6.2.post1
uvicorn==0.32.0
scipy==1.14.1
charset-normalizer==3.4.0
filelock==3.16.1
pytorch3d==0.7.6+das.dtk24042
importlib_metadata==8.5.0
xentropy-cuda-lib==0.1.0+das.opt1.dtk24043
torchvision==0.16.0+das.opt1.dtk24042
pydantic==2.9.2
tabulate==0.9.0
httpcore==1.0.6
onnxruntime==1.15.0+das.opt1.dtk24042
iniconfig==2.0.0
propcache==0.2.0
PyYAML==6.0.2
lmslim==0.1.0+das.dtk24042
torchaudio==2.1.2+das.opt1.dtk24042
cycler==0.12.1
sniffio==1.3.1
MarkupSafe==3.0.2
deepspeed==0.14.2+das.opt1.dtk24042
portalocker==2.10.1
urllib3==2.2.3
kiwisolver==1.4.7
iopath==0.1.10
py-cpuinfo==9.0.0
lm-format-enforcer==0.10.1
pycountry==24.6.1
aiohappyeyeballs==2.4.3
fastmoe==1.1.0+das.dtk24042
lark==1.2.2
platformdirs==4.3.6
idna==3.10
torch-scatter==2.1.0+das.opt1.dtk24042
hjson==3.1.0
termcolor==2.5.0
async-timeout==4.0.3
mmengine==0.10.5
typing_extensions==4.12.2
pyparsing==3.2.0
httpx==0.27.2
rpds-py==0.20.1
ray==2.38.0
outlines_core==0.1.14
huggingface-hub==0.26.2
prometheus-fastapi-instrumentator==7.0.0
httptools==0.6.4
nest-asyncio==1.6.0
pip==24.3.1
open-flamingo==2.0.1
open-flamingo==2.0.1
open_clip_torch==2.29.0
{
"os": "Linux-4.18.0-372.9.1.el8.x86_64-x86_64-with-glibc2.31",
"python": "CPython 3.10.12",
"startedAt": "2024-12-14T10:02:20.994023Z",
"args": [
"--lm_path",
"microsoft/Phi-3-mini-4k-instruct",
"--tokenizer_path",
"microsoft/Phi-3-mini-4k-instruct",
"--conv_template_name",
"phi_3",
"--vision_encoder_path",
"google/siglip-so400m-patch14-384",
"--vision_encoder_pretrained",
"google",
"--model_family",
"xgenmm_v1",
"--num_vision_tokens",
"128",
"--pretrained",
"/mnt/xgen-mm/xgen-mm-phi3-mini-base-r-v1.5.pt",
"--data_path",
"/mnt/xgen-mm/LAVIS/data_configs/example_data_config.yaml",
"--data_sampler_group_by_length",
"--image_aspect_ratio",
"anyres",
"--anyres_patch_sampling",
"--batch_size",
"8",
"--fsdp",
"--no_save_optim_state",
"--gradient_checkpointing",
"--fsdp_sharding_strategy",
"hybrid",
"--workers",
"4",
"--num_epochs",
"1",
"--warmup_steps",
"2000",
"--learning_rate",
"2e-5",
"--weight_decay",
"0.0",
"--lr_scheduler",
"cosine",
"--precision",
"amp_bf16",
"--report_to_wandb",
"--wandb_project",
"blip3-xgenmm-finetune",
"--run_name",
"finetune-xgenmmv1-phi3_4k_instruct-example_data_config"
],
"program": "/mnt/xgen-mm/LAVIS/open_flamingo/train/instruction_finetune.py",
"codePath": "open_flamingo/train/instruction_finetune.py",
"git": {
"remote": "https://ghp.ci/github.com/salesforce/LAVIS.git",
"commit": "d699f7e54fbe7072c1fbef3b61a4f5e6d3591bd3"
},
"email": "2470381734@qq.com",
"root": "/mnt/xgen-mm/LAVIS",
"host": "K100-AI02",
"executable": "/usr/local/bin/python",
"codePathLocal": "open_flamingo/train/instruction_finetune.py",
"cpu_count": 88,
"cpu_count_logical": 176,
"disk": {
"/": {
"total": "3779395256320",
"used": "3174349447168"
}
},
"memory": {
"total": "1081531023360"
},
"cpu": {
"count": 88,
"countLogical": 176
}
}
\ No newline at end of file
{"_wandb":{"runtime":12}}
\ No newline at end of file
/root/.cache/wandb/logs/core-debug-20241214_180220.log
\ No newline at end of file
{"time":"2024-12-14T18:02:20.995120164+08:00","level":"INFO","msg":"using version","core version":"0.19.1"}
{"time":"2024-12-14T18:02:20.995130861+08:00","level":"INFO","msg":"created symlink","path":"/mnt/xgen-mm/LAVIS/wandb/run-20241214_180220-9iv99jfi/logs/debug-core.log"}
{"time":"2024-12-14T18:02:21.108926827+08:00","level":"INFO","msg":"created new stream","id":"9iv99jfi"}
{"time":"2024-12-14T18:02:21.108972039+08:00","level":"INFO","msg":"stream: started","id":"9iv99jfi"}
{"time":"2024-12-14T18:02:21.108988541+08:00","level":"INFO","msg":"writer: Do: started","stream_id":"9iv99jfi"}
{"time":"2024-12-14T18:02:21.108991633+08:00","level":"INFO","msg":"sender: started","stream_id":"9iv99jfi"}
{"time":"2024-12-14T18:02:21.108998232+08:00","level":"INFO","msg":"handler: started","stream_id":"9iv99jfi"}
{"time":"2024-12-14T18:02:21.92970439+08:00","level":"INFO","msg":"Starting system monitor"}
{"time":"2024-12-14T18:02:33.461093379+08:00","level":"INFO","msg":"stream: closing","id":"9iv99jfi"}
{"time":"2024-12-14T18:02:33.461182045+08:00","level":"INFO","msg":"Stopping system monitor"}
{"time":"2024-12-14T18:02:33.461615216+08:00","level":"INFO","msg":"Stopped system monitor"}
{"time":"2024-12-14T18:02:34.411264263+08:00","level":"INFO","msg":"fileTransfer: Close: file transfer manager closed"}
{"time":"2024-12-14T18:02:34.898427114+08:00","level":"INFO","msg":"handler: closed","stream_id":"9iv99jfi"}
{"time":"2024-12-14T18:02:34.89845283+08:00","level":"INFO","msg":"writer: Close: closed","stream_id":"9iv99jfi"}
{"time":"2024-12-14T18:02:34.898480833+08:00","level":"INFO","msg":"sender: closed","stream_id":"9iv99jfi"}
{"time":"2024-12-14T18:02:34.898747604+08:00","level":"INFO","msg":"stream: closed","id":"9iv99jfi"}
2024-12-14 18:02:20,988 INFO MainThread:14981 [wandb_setup.py:_flush():68] Current SDK version is 0.19.1
2024-12-14 18:02:20,988 INFO MainThread:14981 [wandb_setup.py:_flush():68] Configure stats pid to 14981
2024-12-14 18:02:20,988 INFO MainThread:14981 [wandb_setup.py:_flush():68] Loading settings from /root/.config/wandb/settings
2024-12-14 18:02:20,988 INFO MainThread:14981 [wandb_setup.py:_flush():68] Loading settings from /mnt/xgen-mm/LAVIS/wandb/settings
2024-12-14 18:02:20,988 INFO MainThread:14981 [wandb_setup.py:_flush():68] Loading settings from environment variables
2024-12-14 18:02:20,988 INFO MainThread:14981 [wandb_init.py:_log_setup():528] Logging user logs to /mnt/xgen-mm/LAVIS/wandb/run-20241214_180220-9iv99jfi/logs/debug.log
2024-12-14 18:02:20,988 INFO MainThread:14981 [wandb_init.py:_log_setup():529] Logging internal logs to /mnt/xgen-mm/LAVIS/wandb/run-20241214_180220-9iv99jfi/logs/debug-internal.log
2024-12-14 18:02:20,988 INFO MainThread:14981 [wandb_init.py:init():644] calling init triggers
2024-12-14 18:02:20,988 INFO MainThread:14981 [wandb_init.py:init():650] wandb.init called with sweep_config: {}
config: {'model_family': 'xgenmm_v1', 'vision_encoder_path': 'google/siglip-so400m-patch14-384', 'vision_encoder_pretrained': 'google', 'lm_path': 'microsoft/Phi-3-mini-4k-instruct', 'tokenizer_path': 'microsoft/Phi-3-mini-4k-instruct', 'cross_attn_every_n_layers': 1, 'num_vision_tokens': 128, 'pretrained': '/mnt/xgen-mm/xgen-mm-phi3-mini-base-r-v1.5.pt', 'pretrained_vision_tokenizer': None, 'loss': 'supervised_finetune', 'run_name': 'finetune-xgenmmv1-phi3_4k_instruct-example_data_config', 'resume_from_checkpoint': None, 'delete_previous_checkpoint': False, 'no_save_optim_state': True, 'gradient_accumulation_steps': 1, 'seed': 42, 'learning_rate': 2e-05, 'lr_scheduler': 'cosine', 'warmup_steps': 2000, 'weight_decay': 0.0, 'precision': 'amp_bf16', 'gradient_checkpointing': True, 'num_epochs': 1, 'offline': False, 'logging_steps': 100, 'checkpoint_steps': 5000, 'data_path': '/mnt/xgen-mm/LAVIS/data_configs/example_data_config.yaml', 'batch_size': 8, 'workers': 4, 'data_sampler_group_by_length': True, 'is_multimodal': True, 'mm_use_im_start_end': False, 'conv_template_name': 'phi_3', 'image_aspect_ratio': 'anyres', 'anyres_patch_sampling': True, 'anyres_grids': [(1, 2), (2, 1), (2, 2), (3, 1), (1, 3)], 'dist_url': 'env://', 'dist_backend': 'nccl', 'horovod': False, 'no_set_device_rank': False, 'local_rank': 0, 'fsdp': True, 'fsdp_sharding_strategy': 'hybrid', 'report_to_wandb': True, 'wandb_project': 'blip3-xgenmm-finetune', 'wandb_entity': None, 'save_checkpoints_to_wandb': False, 'dryrun': False, 'use_flash_attention_2': False, 'unfreeze_vision_encoder': False, 'vision_encoder_precision': 'fp32', 'cpu_offload_gradients': False, 'rank': 0, 'world_size': 8, 'distributed': True, 'device': 'cuda:0'}
2024-12-14 18:02:20,988 INFO MainThread:14981 [wandb_init.py:init():680] starting backend
2024-12-14 18:02:20,988 INFO MainThread:14981 [wandb_init.py:init():684] sending inform_init request
2024-12-14 18:02:20,993 INFO MainThread:14981 [backend.py:_multiprocessing_setup():104] multiprocessing start_methods=fork,spawn,forkserver, using: spawn
2024-12-14 18:02:20,993 INFO MainThread:14981 [wandb_init.py:init():697] backend started and connected
2024-12-14 18:02:20,995 INFO MainThread:14981 [wandb_init.py:init():790] updated telemetry
2024-12-14 18:02:21,005 INFO MainThread:14981 [wandb_init.py:init():822] communicating run to backend with 90.0 second timeout
2024-12-14 18:02:21,923 INFO MainThread:14981 [wandb_init.py:init():874] starting run threads in backend
2024-12-14 18:02:22,048 INFO MainThread:14981 [wandb_run.py:_console_start():2374] atexit reg
2024-12-14 18:02:22,048 INFO MainThread:14981 [wandb_run.py:_redirect():2224] redirect: wrap_raw
2024-12-14 18:02:22,048 INFO MainThread:14981 [wandb_run.py:_redirect():2289] Wrapping output streams.
2024-12-14 18:02:22,048 INFO MainThread:14981 [wandb_run.py:_redirect():2314] Redirects installed.
2024-12-14 18:02:22,050 INFO MainThread:14981 [wandb_init.py:init():916] run started, returning control to user process
2024-12-14 18:02:33,461 WARNING MsgRouterThr:14981 [router.py:message_loop():75] message_loop has been closed
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment