Commit 7a60e044 authored by wanglch

Initial commit
export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
python -m torch.distributed.launch \
--nproc_per_node=${NPROC_PER_NODE:-8} \
--nnodes=${WORLD_SIZE:-1} \
--node_rank=${RANK:-0} \
--master_addr=${MASTER_ADDR:-127.0.0.1} \
--master_port=${MASTER_PORT:-12345} \
./eval.py \
--model_name minicpm \
--model_path \
--generate_method interleave \
--eval_textVQA \
--eval_docVQA \
--answer_path ./answers \
--batchsize 1
python ./transform_docvqatest_for_submission.py \
--input_file_path \
--output_file_path
import argparse
import json
if __name__ == "__main__":
parser = argparse.ArgumentParser()
    parser.add_argument("--input_file_path", type=str, default="", help="path to the original output json.")
parser.add_argument("--output_file_path", type=str, default="", help="path to where you want to save the processed json.")
args = parser.parse_args()
with open(args.input_file_path , 'r') as f:
data = json.load(f)
transformed_data = [{"questionId": item["question_id"], "answer": item["answer"].replace("</s>", "")} for item in data]
with open(args.output_file_path, 'w') as f:
json.dump(transformed_data, f)
import copy
import json
import logging
import math
import os
from dataclasses import dataclass, field
from typing import Dict, List, Optional
import numpy as np
import torch
from PIL import Image
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset
from transformers import AutoProcessor, AutoTokenizer
llama3_chat_template = "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}"
class SupervisedDataset(Dataset):
"""Dataset for supervised fine-tuning."""
def __init__(
self,
raw_data,
transform,
tokenizer,
slice_config,
llm_type="minicpm",
patch_size=14,
query_nums=64,
batch_vision=False,
):
super(SupervisedDataset, self).__init__()
self.raw_data = raw_data
self.tokenizer = tokenizer
self.transform = transform
self.slice_config = slice_config
self.llm_type = llm_type
self.patch_size = patch_size
self.query_nums=query_nums
self.batch_vision = batch_vision
def __len__(self):
return len(self.raw_data)
def __getitem__(self, i) -> Dict[str, torch.Tensor]:
image = Image.open(self.raw_data[i]["image"]).convert("RGB")
ret = preprocess(
image,
self.raw_data[i]["conversations"],
self.tokenizer,
self.transform,
query_nums=self.query_nums,
slice_config=self.slice_config,
llm_type=self.llm_type,
patch_size=self.patch_size,
batch_vision=self.batch_vision,
)
ret = dict(
input_ids=ret["input_ids"],
position_ids=ret["position_ids"],
labels=ret["target"],
attention_mask=torch.ones_like(ret["input_ids"], dtype=torch.bool),
pixel_values=ret["pixel_values"],
tgt_sizes=ret["tgt_sizes"],
image_bound=ret["image_bound"],
)
return ret
def data_collator(examples, padding_value=0, max_length=2048):
def trim_and_pad(seq, batch_first, padding_value):
return pad_sequence([s[:max_length] for s in seq], batch_first=True, padding_value=padding_value)
input_ids = trim_and_pad(
[example["input_ids"] for example in examples],
batch_first=True,
padding_value=padding_value,
)
position_ids = trim_and_pad(
[example["position_ids"] for example in examples],
batch_first=True,
padding_value=padding_value,
)
targets = trim_and_pad(
[example["labels"] for example in examples],
batch_first=True,
padding_value=-100,
)
attention_mask = trim_and_pad(
[example["attention_mask"] for example in examples],
batch_first=True,
padding_value=padding_value,
)
pixel_values = [example["pixel_values"] for example in examples]
image_bound = [example["image_bound"] for example in examples]
tgt_sizes = [example["tgt_sizes"] for example in examples]
return {
"input_ids": input_ids,
"position_ids": position_ids,
"labels": targets,
"attention_mask": attention_mask,
"image_bound": image_bound,
"tgt_sizes": tgt_sizes,
"pixel_values": pixel_values,
}
def conversation_to_ids(conversation, tokenizer, llm_type=None):
"""
for single image multi-turn conversation
conversation: [{'role': 'user', 'content': 'Describe this image'},
{'role': 'assistant', 'content': 'This is a cat.'}]
"""
if llm_type == "llama3":
input_ids, context, raw_msg = conversation_to_ids_llama3(
conversation, tokenizer
)
else:
input_ids, context, raw_msg = conversation_to_ids_minicpm(
conversation, tokenizer
)
ids = torch.from_numpy(np.hstack(input_ids, dtype=np.int32))
context = torch.from_numpy(np.hstack(context, dtype=np.int8))
# build target
target = torch.full_like(ids, -100, dtype=torch.int32)
for i in range(1, len(ids)):
if context[i] == 0:
target[i - 1] = ids[i]
if context[i] == 1 and context[i - 1] == 0:
if hasattr(tokenizer, "eot_id"):
target[i - 1] = tokenizer.eot_id
else:
target[i - 1] = tokenizer.eos_id
# build image bound
image_start_tokens = torch.where(ids == tokenizer.im_start_id)[0]
image_start_tokens += 1
image_end_tokens = torch.where(ids == tokenizer.im_end_id)[0]
if len(image_start_tokens) != len(image_end_tokens):
        print("number of image start tokens != number of image end tokens")
if len(image_start_tokens) > 0:
image_bound = torch.hstack(
[image_start_tokens.unsqueeze(-1), image_end_tokens.unsqueeze(-1)]
)
else:
image_bound = []
position_ids = torch.arange(ids.size(0)).long()
return {
"input_ids": ids,
"target": target,
"image_bound": image_bound,
"raw_msg": raw_msg,
"position_ids": position_ids
}
def conversation_to_ids_minicpm(conversation, tokenizer):
raw_msg = ""
input_ids = []
context = []
for idx, msg in enumerate(conversation):
role = msg["role"]
message = msg["content"]
assert role in ["user", "assistant"]
if role == "user":
prefix = "<用户>"
else:
prefix = "<AI>"
# append eos
if idx == len(conversation) - 1:
message = message + tokenizer.eos_token
prefix_ids = tokenizer.encode(prefix)[1:] # remove bos
message_ids = tokenizer.encode(message)[1:]
input_ids.append(prefix_ids)
input_ids.append(message_ids)
context.append(np.ones((len(prefix_ids),), dtype=np.int8))
if role == "assistant":
context.append(np.zeros((len(message_ids),), dtype=np.int8))
else:
context.append(np.ones((len(message_ids),), dtype=np.int8))
raw_msg += prefix + message
return input_ids, context, raw_msg
def conversation_to_ids_llama3(conversation, tokenizer):
raw_msg = ""
input_ids = []
context = []
raw_msg = tokenizer.apply_chat_template(
conversation, tokenize=False, add_generation_prompt=False, chat_template=llama3_chat_template,
)
input_ids = tokenizer.apply_chat_template(
conversation, tokenize=True, add_generation_prompt=False, chat_template=llama3_chat_template,
)
input_ids = np.array(input_ids)
start_header_idxs = np.where(
input_ids == tokenizer.convert_tokens_to_ids("<|start_header_id|>")
)[0]
assistant_idxs = np.where(
input_ids == tokenizer.convert_tokens_to_ids("assistant")
)[0]
end_header_idxs = np.where(
input_ids == tokenizer.convert_tokens_to_ids("<|end_header_id|>")
)[0]
eot_idxs = np.where(
input_ids == tokenizer.convert_tokens_to_ids("<|eot_id|>"))[0]
context = np.ones_like(input_ids, dtype=np.int8)
for assistant_idx in assistant_idxs:
if assistant_idx in set((start_header_idxs + end_header_idxs) / 2):
st = assistant_idx + 3 # assistant<|end_header_id|>\n\n
for eot_idx in eot_idxs:
if eot_idx > st:
context[st: eot_idx + 1] = 0
break
input_ids = np.hstack(input_ids)
context = np.hstack(context)
return input_ids, context, raw_msg
def preprocess(
image,
conversation,
tokenizer,
transform,
query_nums=64,
slice_config=None,
llm_type=None,
patch_size=14,
batch_vision=False,
):
"""
single image preprocess, the image will be placed at the top of the conversation
"""
conversation = copy.deepcopy(conversation)
    assert len(conversation) > 1, "conversation must contain more than one message"
assert conversation[0]["role"] == "user", "the first role must be user"
if slice_config is not None:
assert isinstance(slice_config, Dict)
assert "patch_size" in slice_config
assert "max_slice_nums" in slice_config
assert "scale_resolution" in slice_config
default_image_placeholder = (
tokenizer.im_start + tokenizer.unk_token * query_nums + tokenizer.im_end
)
if slice_config:
images = []
source_image, patches, best_grid = slice_image(
image,
slice_config["max_slice_nums"],
slice_config["scale_resolution"],
slice_config["patch_size"],
)
images.append(source_image)
image_placeholder = default_image_placeholder
if len(patches) > 0:
for i in range(len(patches)):
for j in range(len(patches[0])):
images.append(patches[i][j])
image_placeholder += get_grid_placeholder(
tokenizer, best_grid, query_nums)
images = [transform(i) for i in images]
else:
images = [transform(image)]
image_placeholder = default_image_placeholder
if "<image>" in conversation[0]["content"]:
conversation[0]["content"] = conversation[0]["content"].replace(
"<image>", image_placeholder
)
else:
conversation[0]["content"] = (
image_placeholder + "\n" + conversation[0]["content"]
)
input_dict = conversation_to_ids(conversation, tokenizer, llm_type)
if batch_vision:
tgt_sizes = []
reshape_images = []
for image in images:
H, W = image.shape[1:]
reshape_image = reshape_by_patch(image, patch_size)
reshape_images.append(reshape_image)
tgt_sizes.append([H // patch_size, W // patch_size])
if tgt_sizes:
tgt_sizes = torch.Tensor(tgt_sizes).type(torch.int32)
input_dict["pixel_values"] = reshape_images
input_dict["tgt_sizes"] = tgt_sizes
else:
input_dict["pixel_values"] = images
input_dict["tgt_sizes"] = []
return input_dict
def slice_image(
image, max_slice_nums=9, scale_resolution=448, patch_size=14, never_split=False
):
original_size = image.size
original_width, original_height = original_size
log_ratio = math.log(original_width / original_height)
ratio = original_width * original_height / \
(scale_resolution * scale_resolution)
multiple = min(math.ceil(ratio), max_slice_nums)
source_image = None
best_grid = None
patches = []
if multiple <= 1 or never_split:
# dont need to slice, upsample
best_size = find_best_resize(
original_size, scale_resolution, patch_size, allow_upscale=True
)
source_image = image.resize(best_size, Image.Resampling.BICUBIC)
else:
candidate_split_grids_nums = []
for i in [multiple - 1, multiple, multiple + 1]:
if i == 1 or i > max_slice_nums:
continue
candidate_split_grids_nums.append(i)
# source image, down-sampling and ensure divided by patch_size
best_resize = find_best_resize(
original_size, scale_resolution, patch_size)
source_image = image.copy().resize(best_resize, Image.Resampling.BICUBIC)
candidate_grids = []
# find best grid
for split_grids_nums in candidate_split_grids_nums:
m = 1
while m <= split_grids_nums:
if split_grids_nums % m == 0:
candidate_grids.append([m, split_grids_nums // m])
m += 1
best_grid = [1, 1]
min_error = float("inf")
for grid in candidate_grids:
error = abs(log_ratio - math.log(grid[0] / grid[1]))
if error < min_error:
best_grid = grid
min_error = error
refine_size = get_refine_size(
original_size, best_grid, scale_resolution, patch_size, allow_upscale=True
)
refine_image = image.resize(refine_size, Image.Resampling.BICUBIC)
patches = split_to_patches(refine_image, best_grid)
return source_image, patches, best_grid
def ensure_divide(length, patch_size):
return max(round(length / patch_size) * patch_size, patch_size)
def find_best_resize(original_size, scale_resolution, patch_size, allow_upscale=False):
width, height = original_size
if (width * height > scale_resolution * scale_resolution) or allow_upscale:
r = width / height
height = int(scale_resolution / math.sqrt(r))
width = int(height * r)
best_width = ensure_divide(width, patch_size)
best_height = ensure_divide(height, patch_size)
return (best_width, best_height)
def get_refine_size(
original_size, grid, scale_resolution, patch_size, allow_upscale=False
):
width, height = original_size
grid_x, grid_y = grid
refine_width = ensure_divide(width, grid_x)
refine_height = ensure_divide(height, grid_y)
grid_width = refine_width / grid_x
grid_height = refine_height / grid_y
best_grid_size = find_best_resize(
(grid_width, grid_height),
scale_resolution,
patch_size,
allow_upscale=allow_upscale,
)
refine_size = (best_grid_size[0] * grid_x, best_grid_size[1] * grid_y)
return refine_size
def split_to_patches(image, grid):
patches = []
width, height = image.size
grid_x = int(width / grid[0])
grid_y = int(height / grid[1])
for i in range(0, height, grid_y):
images = []
for j in range(0, width, grid_x):
box = (j, i, j + grid_x, i + grid_y)
patch = image.crop(box)
images.append(patch)
patches.append(images)
return patches
def get_grid_placeholder(tokenizer, grid, query_num):
image_placeholder = (
tokenizer.im_start + tokenizer.unk_token * query_num + tokenizer.im_end
)
cols = grid[0]
rows = grid[1]
slices = []
for i in range(rows):
lines = []
for j in range(cols):
lines.append(image_placeholder)
slices.append("".join(lines))
slice_placeholder = tokenizer.slice_start + \
"\n".join(slices) + tokenizer.slice_end
return slice_placeholder
def reshape_by_patch(image_tensor, patch_size):
"""
:param image_tensor: shape [3, H, W]
:param patch_size:
:return: [3, patch_size, HW/patch_size]
"""
patches = torch.nn.functional.unfold(
image_tensor, (patch_size, patch_size), stride=(patch_size, patch_size)
)
patches = patches.reshape(image_tensor.size(0), patch_size, patch_size, -1)
patches = patches.permute(0, 1, 3, 2).reshape(
image_tensor.size(0), patch_size, -1)
return patches
{
"fp16": {
"enabled": "auto",
"loss_scale": 0,
"loss_scale_window": 1000,
"initial_scale_power": 16,
"hysteresis": 2,
"min_loss_scale": 1
},
"bf16": {
"enabled": "auto"
},
"optimizer": {
"type": "AdamW",
"params": {
"lr": "auto",
"betas": "auto",
"eps": "auto",
"weight_decay": "auto"
}
},
"scheduler": {
"type": "WarmupLR",
"params": {
"warmup_min_lr": "auto",
"warmup_max_lr": "auto",
"warmup_num_steps": "auto"
}
},
"zero_optimization": {
"stage": 2,
"offload_optimizer": {
"device": "none",
"pin_memory": true
},
"allgather_partitions": true,
"allgather_bucket_size": 2e8,
"overlap_comm": true,
"reduce_scatter": true,
"reduce_bucket_size": 2e8,
"contiguous_gradients": true
},
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"steps_per_print": 100,
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"wall_clock_breakdown": false
}
{
"fp16": {
"enabled": "auto",
"loss_scale": 0,
"loss_scale_window": 1000,
"initial_scale_power": 16,
"hysteresis": 2,
"min_loss_scale": 1
},
"bf16": {
"enabled": "auto"
},
"optimizer": {
"type": "AdamW",
"params": {
"lr": "auto",
"betas": "auto",
"eps": "auto",
"weight_decay": "auto"
}
},
"scheduler": {
"type": "WarmupLR",
"params": {
"warmup_min_lr": "auto",
"warmup_max_lr": "auto",
"warmup_num_steps": "auto"
}
},
"zero_optimization": {
"stage": 3,
"offload_optimizer": {
"device": "none",
"pin_memory": true
},
"offload_param": {
"device": "none",
"pin_memory": true
},
"overlap_comm": true,
"contiguous_gradients": true,
"sub_group_size": 1e9,
"reduce_bucket_size": "auto",
"stage3_prefetch_bucket_size": "auto",
"stage3_param_persistence_threshold": "auto",
"stage3_max_live_parameters": 1e9,
"stage3_max_reuse_distance": 1e9,
"stage3_gather_16bit_weights_on_model_save": true
},
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"steps_per_print": 100,
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"wall_clock_breakdown": false
}
import glob
import json
import logging
import os
from dataclasses import dataclass, field
from functools import partial
from typing import Dict, List, Optional, Union, Literal, Tuple
from types import MethodType
import torch
import transformers
from accelerate.utils import DistributedType
from deepspeed import zero
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
from transformers import AutoModel, AutoTokenizer
from transformers.integrations import deepspeed
from dataset import SupervisedDataset, data_collator
from trainer import CPMTrainer
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
@dataclass
class ModelArguments:
model_name_or_path: Optional[str] = field(default="openbmb/MiniCPM-V-2")
@dataclass
class DataArguments:
data_path: str = field(
default=None, metadata={"help": "Path to the training data."}
)
eval_data_path: str = field(
default=None, metadata={"help": "Path to the evaluation data."}
)
@dataclass
class TrainingArguments(transformers.TrainingArguments):
cache_dir: Optional[str] = field(default=None)
optim: str = field(default="adamw_torch")
model_max_length: int = field(
default=2048,
metadata={
"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
},
)
tune_vision: Optional[bool] = field(default=True)
tune_llm: Optional[bool] = field(default=True)
llm_type: str = field(default="minicpm")
use_lora: Optional[bool] = field(default=False)
max_slice_nums: Optional[int] = field(default=9)
@dataclass
class LoraArguments:
lora_r: int = 64
lora_alpha: int = 64
lora_dropout: float = 0.05
lora_target_modules: str = r"llm\..*layers\.\d+\.self_attn\.(q_proj|k_proj|v_proj)"
lora_weight_path: str = ""
lora_bias: str = "none"
q_lora: bool = False
lora_modules_to_save: str = ""
lora_layer_replication: Optional[List[Tuple[int, int]]] = None
lora_layers_to_transform: Optional[List[int]] = None
lora_layers_pattern: Optional[str] = None
def maybe_zero_3(param):
if hasattr(param, "ds_id"):
assert param.ds_status == ZeroParamStatus.NOT_AVAILABLE
with zero.GatheredParameters([param]):
param = param.data.detach().cpu().clone()
else:
param = param.detach().cpu().clone()
return param
# Borrowed from peft.utils.get_peft_model_state_dict
def get_peft_state_maybe_zero_3(named_params, bias):
if bias == "none":
to_return = {k: t for k, t in named_params if "lora_" in k}
elif bias == "all":
to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k}
elif bias == "lora_only":
to_return = {}
maybe_lora_bias = {}
lora_bias_names = set()
for k, t in named_params:
if "lora_" in k:
to_return[k] = t
bias_name = k.split("lora_")[0] + "bias"
lora_bias_names.add(bias_name)
elif "bias" in k:
maybe_lora_bias[k] = t
        for k, t in maybe_lora_bias.items():
            if k in lora_bias_names:
                to_return[k] = t
else:
raise NotImplementedError
to_return = {k: maybe_zero_3(v) for k, v in to_return.items()}
return to_return
local_rank = None
def rank0_print(*args):
if local_rank == 0:
print(*args)
def safe_save_model_for_hf_trainer(trainer, output_dir: str, bias="none"):
"""Collects the state dict and dump to disk."""
# check if zero3 mode enabled
if deepspeed.is_deepspeed_zero3_enabled():
state_dict = trainer.model_wrapped._zero3_consolidated_16bit_state_dict()
else:
if trainer.args.use_lora:
state_dict = get_peft_state_maybe_zero_3(
trainer.model.named_parameters(), bias
)
else:
state_dict = trainer.model.state_dict()
if trainer.args.should_save and trainer.args.local_rank == 0:
trainer._save(output_dir, state_dict=state_dict)
def make_supervised_data_module(
tokenizer: transformers.PreTrainedTokenizer,
data_args,
transform,
data_collator=None,
llm_type="minicpm",
slice_config=None,
patch_size=14,
query_nums=64,
batch_vision=False,
max_length=2048,
) -> Dict:
"""Make dataset and collator for supervised fine-tuning."""
dataset_cls = SupervisedDataset
rank0_print("Loading data...")
train_json = json.load(open(data_args.data_path, "r"))
train_dataset = dataset_cls(
train_json,
transform,
tokenizer,
slice_config=slice_config,
llm_type=llm_type,
patch_size=patch_size,
query_nums=query_nums,
batch_vision=batch_vision,
)
if data_args.eval_data_path:
eval_json = json.load(open(data_args.eval_data_path, "r"))
eval_dataset = dataset_cls(
eval_json,
transform,
tokenizer,
slice_config=slice_config,
llm_type=llm_type,
patch_size=patch_size,
query_nums=query_nums,
batch_vision=batch_vision,
)
else:
eval_dataset = None
return dict(
train_dataset=train_dataset,
eval_dataset=eval_dataset,
data_collator= partial(data_collator, max_length=max_length),
)
def get_parameter_number(model):
trainable_params, all_param = 0, 0
for param in model.parameters():
num_params = param.numel()
# if using DS Zero 3 and the weights are initialized empty
if num_params == 0 and hasattr(param, "ds_numel"):
num_params = param.ds_numel
all_param += num_params
if param.requires_grad:
trainable_params += num_params
return {'Total': all_param, 'Trainable': trainable_params}
local_rank = 0
def train():
global local_rank
parser = transformers.HfArgumentParser(
(ModelArguments, DataArguments, TrainingArguments, LoraArguments)
)
(
model_args,
data_args,
training_args,
lora_args,
) = parser.parse_args_into_dataclasses()
if getattr(training_args, "deepspeed", None) :
training_args.distributed_state.distributed_type = DistributedType.DEEPSPEED
compute_dtype = (
torch.float16
if training_args.fp16
else (torch.bfloat16 if training_args.bf16 else torch.float32)
)
local_rank = training_args.local_rank
world_size = int(os.environ.get("WORLD_SIZE", 1))
ddp = world_size != 1
device_map = None
if lora_args.q_lora:
device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)} if ddp else None
if len(training_args.fsdp) > 0 or deepspeed.is_deepspeed_zero3_enabled():
            logging.warning(
                "FSDP or ZeRO-3 is incompatible with QLoRA."
            )
model = AutoModel.from_pretrained(
model_args.model_name_or_path,
trust_remote_code=True,
torch_dtype=compute_dtype,
device_map=device_map,
)
tokenizer = AutoTokenizer.from_pretrained(
model_args.model_name_or_path, trust_remote_code=True
)
if not training_args.tune_vision:
model.vpm.requires_grad_(False)
if not training_args.tune_llm:
model.llm.requires_grad_(False)
if training_args.use_lora:
if training_args.use_lora and training_args.tune_llm:
            raise ValueError("LoRA cannot be combined with full LLM tuning; set tune_llm to false when use_lora is enabled.")
rank0_print("Currently using LoRA for fine-tuning the MiniCPM-V model.")
for name, param in model.llm.named_parameters():
param.requires_grad = False
lora_config = LoraConfig(
r=lora_args.lora_r,
lora_alpha=lora_args.lora_alpha,
target_modules=lora_args.lora_target_modules,
lora_dropout=lora_args.lora_dropout,
bias=lora_args.lora_bias,
layers_to_transform=lora_args.lora_layers_to_transform,
task_type="CAUSAL_LM",
)
if not hasattr(model, 'get_input_embeddings'):
def get_input_embeddings(self):
return self.llm.get_input_embeddings()
model.get_input_embeddings = MethodType(get_input_embeddings, model)
if lora_args.q_lora:
model = prepare_model_for_kbit_training(
model, use_gradient_checkpointing=training_args.gradient_checkpointing
)
model = get_peft_model(model, lora_config)
model.base_model.resampler.requires_grad_(True)
model.base_model.llm.model.embed_tokens.weight.requires_grad_(True)
if training_args.tune_vision:
model.base_model.vpm.requires_grad_(True)
if training_args.gradient_checkpointing:
model.enable_input_require_grads()
rank0_print(get_parameter_number(model))
llm_type = training_args.llm_type
rank0_print(f'llm_type={llm_type}')
# Load data
if hasattr(model.config, "slice_config"):
model.config.slice_config.max_slice_nums = training_args.max_slice_nums
slice_config = model.config.slice_config.to_dict()
else:
model.config.max_slice_nums = training_args.max_slice_nums
slice_config = model.config.to_dict()
if hasattr(model.config, "batch_vision_input"):
batch_vision = model.config.batch_vision_input
else:
batch_vision = False
data_module = make_supervised_data_module(
tokenizer=tokenizer,
data_args=data_args,
transform=model.transform,
data_collator=data_collator,
slice_config=slice_config,
llm_type=llm_type,
patch_size=model.config.patch_size,
query_nums=model.config.query_num,
batch_vision=batch_vision,
max_length=training_args.model_max_length,
)
trainer = CPMTrainer(
model=model,
tokenizer=tokenizer,
args=training_args,
**data_module,
)
trainer.train()
trainer.save_state()
safe_save_model_for_hf_trainer(
trainer=trainer,
output_dir=training_args.output_dir,
bias=lora_args.lora_bias)
if __name__ == "__main__":
train()
#!/bin/bash
GPUS_PER_NODE=8
NNODES=1
NODE_RANK=0
MASTER_ADDR=localhost
MASTER_PORT=6001
MODEL="openbmb/MiniCPM-Llama3-V-2_5" # or openbmb/MiniCPM-V-2
# ATTENTION: specify the path to your training data, which should be a json file consisting of a list of conversations.
# See the section for finetuning in README for more information.
DATA="path/to/training_data"
EVAL_DATA="path/to/test_data"
LLM_TYPE="llama3" # if use openbmb/MiniCPM-V-2, please set LLM_TYPE=minicpm
DISTRIBUTED_ARGS="
--nproc_per_node $GPUS_PER_NODE \
--nnodes $NNODES \
--node_rank $NODE_RANK \
--master_addr $MASTER_ADDR \
--master_port $MASTER_PORT
"
torchrun $DISTRIBUTED_ARGS finetune.py \
--model_name_or_path $MODEL \
--llm_type $LLM_TYPE \
--data_path $DATA \
--eval_data_path $EVAL_DATA \
--remove_unused_columns false \
--label_names "labels" \
--prediction_loss_only false \
--bf16 false \
--bf16_full_eval false \
--fp16 true \
--fp16_full_eval true \
--do_train \
--do_eval \
--tune_vision true \
--tune_llm true \
--model_max_length 2048 \
--max_slice_nums 9 \
--max_steps 10000 \
--eval_steps 1000 \
--output_dir output/output_minicpmv2 \
--logging_dir output/output_minicpmv2 \
--logging_strategy "steps" \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 1 \
--evaluation_strategy "steps" \
--save_strategy "steps" \
--save_steps 1000 \
--save_total_limit 10 \
--learning_rate 1e-6 \
--weight_decay 0.1 \
--adam_beta2 0.95 \
--warmup_ratio 0.01 \
--lr_scheduler_type "cosine" \
--logging_steps 1 \
--gradient_checkpointing true \
--deepspeed ds_config_zero2.json \
--report_to "tensorboard"
# MiniCPM-V Finetuning
We offer official scripts for easily finetuning the pretrained **MiniCPM-Llama3-V 2.5** and **MiniCPM-V 2.0** on downstream tasks. Our finetuning scripts use the Hugging Face Transformers `Trainer` and DeepSpeed by default.
### Data preparation
To prepare your finetuning data, format each sample as a dictionary containing an id, an image path, and a list of conversations, then save the samples in a JSON file.
For a vision-language sample with an image, provide **\<image\>** in the content to mark where the image embeddings should be inserted. If \<image\> is not provided, the image is placed at the front of the conversation.
<details>
<summary>
<b>vision-language example (vl_finetune_data.json) with 1 sample.</b>
</summary>
```json
[
    {
        "id": "0",
        "image": "path/to/image_0.jpg",
        "conversations": [
            {
                "role": "user",
                "content": "<image>\nHow many desserts are on the white plate?"
            },
            {
                "role": "assistant",
                "content": "There are three desserts on the white plate."
            },
            {
                "role": "user",
                "content": "What type of desserts are they?"
            },
            {
                "role": "assistant",
                "content": "The desserts are cakes with bananas and pecans on top. They share similarities with donuts, but the presence of bananas and pecans differentiates them."
            },
            {
                "role": "user",
                "content": "What is the setting of the image?"
            },
            {
                "role": "assistant",
                "content": "The image is set on a table top with a plate containing the three desserts."
            }
        ]
    }
]
```
</details>
### Full-parameter finetuning
Full-parameter finetuning updates all LLM parameters during training. Please specify the correct MODEL path, DATA path, and LLM_TYPE in the shell script.
```shell
MODEL="openbmb/MiniCPM-Llama3-V-2_5" # or openbmb/MiniCPM-V-2
DATA="path/to/training_data" # json file
EVAL_DATA="path/to/test_data" # json file
LLM_TYPE="llama3" # if use openbmb/MiniCPM-V-2, please set LLM_TYPE=minicpm
```
To launch your training, run the following script:
```
sh finetune_ds.sh
```
Note that Llama 3 uses different chat templates for training and inference. We modified the chat template for training, so please restore the original chat template when running inference on the trained checkpoint; a minimal sketch is shown below.
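One way to do this, as a sketch only: copy the original template from the base model's tokenizer back onto the trained one. The paths below (`openbmb/MiniCPM-Llama3-V-2_5` as the base model and `output/output_minicpmv2` as the training output directory) are assumptions; adjust them to your setup.
```python
# Sketch: restore the original Llama-3 chat template after training.
from transformers import AutoTokenizer

base_tok = AutoTokenizer.from_pretrained("openbmb/MiniCPM-Llama3-V-2_5", trust_remote_code=True)
trained_tok = AutoTokenizer.from_pretrained("output/output_minicpmv2", trust_remote_code=True)  # assumption: your output_dir

trained_tok.chat_template = base_tok.chat_template        # overwrite the training-time template
trained_tok.save_pretrained("output/output_minicpmv2")    # persists it in tokenizer_config.json
```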
### LoRA finetuning
LoRA enables lightweight model tuning by updating only a small subset of parameters. We provide a LoRA implementation based on `peft`. To launch your training, run the following script:
```
sh finetune_lora.sh
```
After training, you can load the model by pointing to the adapter directory. We advise using an absolute path for your pretrained model, because LoRA saves only the adapter, and the absolute path recorded in the adapter configuration JSON file is used to locate the pretrained base model.
```python
import torch
from peft import AutoPeftModelForCausalLM
path_to_adapter="path_to_adapter"
model = AutoPeftModelForCausalLM.from_pretrained(
# path to the output directory
path_to_adapter,
device_map="auto",
trust_remote_code=True
).eval()
vpm_resampler_embedtokens_weight = torch.load(f"{path_to_adapter}/vpm_resampler_embedtokens.pt")
msg = model.load_state_dict(vpm_resampler_embedtokens_weight, strict=False)
```
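Optionally, if you want a standalone checkpoint that no longer depends on `peft` at load time, you can merge the adapter into the base model. This is a sketch only, continuing from the snippet above and assuming your setup supports peft's standard `merge_and_unload()`; the output path is a placeholder.
```python
# Sketch: merge the LoRA weights into the base model and save a standalone checkpoint.
merged_model = model.merge_and_unload()          # folds the LoRA deltas into the base weights
merged_model.save_pretrained("path_to_merged_model", safe_serialization=True)
```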
### Model Fine-tuning Memory Usage Statistics
The following table presents the memory usage when fine-tuning on NVIDIA A100 (80 GiB) GPUs with different numbers of GPUs. Fine-tuning was performed with DeepSpeed ZeRO-3, gradient checkpointing, and offloading of the optimizer and parameter memory to the CPU, with a maximum length of 2048 and a batch size of 1. You can refer to the [DeepSpeed ZeRO stage guide](https://huggingface.co/docs/transformers/v4.41.2/en/deepspeed#select-a-zero-stage) to reduce memory cost further.
| Fine-tuning Method | GPUs: 2 | GPUs: 4 | GPUs: 8 |
|--------------------|---------|---------|---------|
| LoRA Fine-tuning | 14.4 GiB | 13.6 GiB | 13.1 GiB |
| Full Parameters Fine-tuning | 16.0 GiB | 15.8 GiB | 15.63 GiB |
### Notes
- **Fine-tuning Method**: Displays two different fine-tuning strategies, LoRA fine-tuning and Full parameters fine-tuning.
- **Number of GPUs**: The table lists the memory usage for configurations with 2, 4, and 8 GPUs.
- **Memory Usage**: Expressed in GiB, this shows the required memory for each fine-tuning method under corresponding GPU configurations.
- **Out of memory**: Indicates that the memory was insufficient for full parameters fine-tuning under the current GPU configurations.
### Finetuning FAQs
<details>
<summary>Q: What can I do when I encounter Out of Memory (OOM) issues while training large models?</summary>
A: The following strategies may help resolve or mitigate OOM issues during training:
#### Adjust Model Hyperparameters
- **Reduce `model_max_length`**: Decreasing the maximum sequence length the model processes significantly reduces the memory required per step, for example from 2048 down to 1200 or another value suitable for your dataset.
```
--model_max_length 1200
```
- **Lower the batch size**: Reducing the amount of data processed in each batch decreases memory consumption.
```
--per_device_train_batch_size 1
```
- **Reduce the number of image slices (`max_slice_nums`)**: When handling large images, lowering the maximum number of slices (default 9) reduces the memory required for image encoding.
```
--max_slice_nums 9
```
#### Reduce Training Model Parameters
- **Do not train the vision module (`vpm`)**: You can set the corresponding hyperparameter in the finetune script to skip training the vision module and save memory.
```
--tune_vision false
```
- **Use LoRA finetuning**: Refer to the [LoRA finetuning](#LoRA-finetuning) section.
#### Optimize with DeepSpeed
- **Configure DeepSpeed ZeRO Stage 2**: Use the following configuration to offload the optimizer state to the CPU, reducing GPU memory pressure:
```json
"zero_optimization": {
"stage": 2,
"offload_optimizer": {
"device": "cpu",
"pin_memory": true
}
}
```
- **Configure DeepSpeed ZeRO Stage 3**: Further offload the model parameters and optimizer state to the CPU, reducing GPU memory usage even more:
```json
"zero_optimization": {
"stage": 3,
"offload_optimizer": {
"device": "cpu",
"pin_memory": true
},
"offload_param": {
"device": "cpu",
"pin_memory": true
}
}
```
You can visit [huggingface deepspeed](https://huggingface.co/docs/transformers/deepspeed) to find out more about how to use DeepSpeed.
</details>
<details>
<summary>Q: I encounter an error when using AutoPeftModelForCausalLM to load a checkpoint that was fine-tuned with LoRA</summary>
A: The error described in [issue 168](https://github.com/OpenBMB/MiniCPM-V/issues/168) occurs because the model lacks the `get_input_embeddings` and `set_input_embeddings` methods. Follow these steps to resolve the issue (a patch sketch is given after the steps):
1. **Reload the Fine-Tuned Model:** Make sure you correctly load the checkpoint that was fine-tuned with LoRA. Use the following code example as a guide:
```python
from peft import AutoPeftModelForCausalLM
model = AutoPeftModelForCausalLM.from_pretrained(
'path_to_your_fine_tuned_checkpoint', # Path to your fine-tuned checkpoint directory
device_map='auto',
trust_remote_code=True
).eval()
```
2. **Update the `model_minicpmv.py` File:**
- **Verification:** Verify that your `model_minicpmv.py` file is the latest version.
- **Update Hugging Face Library Code:** If the issue persists after updating the file, consider updating the related code in the Hugging Face library.
- **Direct File Copy:** For a quick resolution, directly download and copy the latest `model_minicpmv.py` file into your project. This file is available from the following sources:
- [MiniCPM-Llama3-V-2_5 on Hugging Face](https://huggingface.co/openbmb/MiniCPM-Llama3-V-2_5/tree/main)
- [MiniCPM-V-2 on Hugging Face](https://huggingface.co/openbmb/MiniCPM-V-2)
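Alternatively, you can load the base model yourself, apply the same `MethodType` workaround used in `finetune.py`, and then attach the adapter with `PeftModel.from_pretrained`. This is a sketch only; it assumes the language model is exposed as `model.llm` (as in this repository), and the base-model path and `path_to_adapter` are placeholders.
```python
# Sketch: load the base model, patch the missing embedding accessors, then attach the LoRA adapter.
from types import MethodType

import torch
from peft import PeftModel
from transformers import AutoModel

model = AutoModel.from_pretrained(
    "openbmb/MiniCPM-Llama3-V-2_5",   # assumption: your base checkpoint
    trust_remote_code=True,
    torch_dtype=torch.float16,
)

def get_input_embeddings(self):
    return self.llm.get_input_embeddings()

def set_input_embeddings(self, value):
    self.llm.set_input_embeddings(value)

if not hasattr(model, "get_input_embeddings"):
    model.get_input_embeddings = MethodType(get_input_embeddings, model)
if not hasattr(model, "set_input_embeddings"):
    model.set_input_embeddings = MethodType(set_input_embeddings, model)

model = PeftModel.from_pretrained(model, "path_to_adapter").eval()
```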
</details>
<details>
<summary>Q: How do I use the `flash_attention_2` implementation when loading a pretrained model?</summary>
A: If your environment supports `flash_attn2`, you can add an argument `_attn_implementation="flash_attention_2"` when using the `AutoModel.from_pretrained` method to load a model. For example:
```python
model = AutoModel.from_pretrained('model_name', _attn_implementation="flash_attention_2")
```
</details>
<details>
<summary>Q: What if our data is resized to 512? Can we use the original image size instead?</summary>
A: Our model supports up to 1344x1344 lossless encoding. If you are currently resizing your images to 512, you might want to try using the original image sizes instead. Our system automatically includes a high-definition image encoding scheme by default.
</details>
<details>
<summary>Q: What should we do if we encounter out-of-memory (OOM) errors?</summary>
A: If you experience OOM issues, consider reducing the per-device batch size (`per_device_train_batch_size`). To keep the effective total batch size unchanged, increase `gradient_accumulation_steps` accordingly. This lets you manage memory usage while still processing the desired amount of data per optimizer step.
</details>
<details>
<summary>Q: How can we determine the maximum length for our training data, and what if we do not want to train the vision encoder?</summary>
A: I recommend using the function [here](https://github.com/OpenBMB/MiniCPM-V/blob/main/finetune/dataset.py#L220) to sample the lengths of your training data; note that the `input_ids` length includes the image portion. Once you have determined the maximum length, you can specify it in the startup command via `--model_max_length xxx`. A rough sketch of such a length check is shown below.
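The sketch builds the dataset the same way `finetune.py` does and reports length statistics. The model path, data path, and `llm_type` below are assumptions; adapt them to your setup.
```python
# Sketch: report input_ids length statistics for your training JSON.
import json

import numpy as np
import torch
from transformers import AutoModel, AutoTokenizer

from dataset import SupervisedDataset  # dataset.py from this repo

model_path = "openbmb/MiniCPM-Llama3-V-2_5"   # assumption: your base checkpoint
data_path = "path/to/training_data.json"      # assumption: your finetuning JSON

model = AutoModel.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.float16)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

if hasattr(model.config, "slice_config"):
    slice_config = model.config.slice_config.to_dict()
else:
    slice_config = model.config.to_dict()

dataset = SupervisedDataset(
    json.load(open(data_path)),
    model.transform,
    tokenizer,
    slice_config=slice_config,
    llm_type="llama3",  # assumption: use "minicpm" for MiniCPM-V 2.0
    patch_size=model.config.patch_size,
    query_nums=model.config.query_num,
    batch_vision=getattr(model.config, "batch_vision_input", False),
)

lengths = [len(dataset[i]["input_ids"]) for i in range(len(dataset))]
print("max:", max(lengths), " p95:", int(np.percentile(lengths, 95)))
```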
Additionally, if you prefer not to train the vision encoder, you can add `--tune_vision false` to your command.
</details>
<details>
<summary>Q: How can we adjust training hyperparameters when using LoRA to train our model?</summary>
A: You can refer to the [LoRA documentation](https://huggingface.co/docs/peft/en/package_reference/lora#peft.LoraConfig) for guidance on adjusting your training hyperparameters when using LoRA. This documentation provides detailed information on configuring various parameters specific to the LoRA adaptation technique.
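For orientation, here is a sketch of the `LoraConfig` that `finetune.py` assembles from the `LoraArguments` defaults; the same values can also be set via the corresponding command-line flags (e.g. `--lora_r`, `--lora_alpha`, `--lora_dropout`, `--lora_target_modules`).
```python
# Sketch: the LoRA configuration built by finetune.py from its default arguments.
from peft import LoraConfig

lora_config = LoraConfig(
    r=64,             # LoRA rank
    lora_alpha=64,    # scaling factor
    lora_dropout=0.05,
    target_modules=r"llm\..*layers\.\d+\.self_attn\.(q_proj|k_proj|v_proj)",
    bias="none",
    task_type="CAUSAL_LM",
)
```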
</details>
#### Customizing Hyperparameters
To tailor the training process according to your specific requirements, you can adjust various hyperparameters. For comprehensive documentation on available hyperparameters and their functionalities, you can refer to the [official Transformers documentation](https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments) and [Lora documentation](https://huggingface.co/docs/peft/en/package_reference/lora#peft.LoraConfig). Experimentation and fine-tuning of these parameters are essential for achieving optimal model performance tailored to your specific task and dataset.
import torch
import torch.nn as nn
import deepspeed
from transformers import Trainer
from transformers.trainer_pt_utils import nested_detach
from transformers.utils import is_sagemaker_mp_enabled
from transformers.trainer import *
from transformers.integrations import is_deepspeed_zero3_enabled
class CPMTrainer(Trainer):
def compute_loss(self, model, inputs, return_outputs=False):
if "labels" in inputs:
labels = inputs.pop("labels")
else:
labels = None
self.model.resampler.pos_embed = self.model.resampler.pos_embed.to(self.model.device)
if is_deepspeed_zero3_enabled():
with deepspeed.zero.GatheredParameters(self.model.resampler.attn.parameters(), modifier_rank=0):
if not self.args.use_lora:
outputs = self.model(data = inputs, use_cache=False)
else:
with self.model._enable_peft_forward_hooks(**inputs):
outputs = self.model.base_model(data = inputs, use_cache=False)
else:
if not self.args.use_lora:
outputs = self.model(data = inputs, use_cache=False)
else:
with self.model._enable_peft_forward_hooks(**inputs):
outputs = self.model.base_model(data = inputs, use_cache=False)
if labels is not None:
# Flatten the tokens
loss_fct = nn.CrossEntropyLoss()
logits = outputs.logits.view(-1,
self.model.config.vocab_size).contiguous()
labels = labels.view(-1).long().contiguous()
# Enable model parallelism
labels = labels.to(logits.device)
loss = loss_fct(logits, labels)
else:
if isinstance(outputs, dict) and "loss" not in outputs:
raise ValueError(
"The model did not return a loss from the inputs, only the following keys: "
f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}."
)
# We don't use .loss here since the model may return tuples instead of ModelOutput.
loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
return (loss, outputs) if return_outputs else loss
def prediction_step(
self,
model: nn.Module,
inputs: Dict[str, Union[torch.Tensor, Any]],
prediction_loss_only: bool,
ignore_keys: Optional[List[str]] = None,
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
"""
Perform an evaluation step on `model` using `inputs`.
Subclass and override to inject custom behavior.
Args:
model (`nn.Module`):
The model to evaluate.
inputs (`Dict[str, Union[torch.Tensor, Any]]`):
The inputs and targets of the model.
The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
argument `labels`. Check your model's documentation for all accepted arguments.
prediction_loss_only (`bool`):
Whether or not to return the loss only.
ignore_keys (`List[str]`, *optional*):
A list of keys in the output of your model (if it is a dictionary) that should be ignored when
gathering predictions.
Return:
Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss,
logits and labels (each being optional).
"""
has_labels = (
False
if len(self.label_names) == 0
else all(inputs.get(k) is not None for k in self.label_names)
)
# For CLIP-like models capable of returning loss values.
# If `return_loss` is not specified or being `None` in `inputs`, we check if the default value of `return_loss`
# is `True` in `model.forward`.
return_loss = inputs.get("return_loss", None)
if return_loss is None:
return_loss = self.can_return_loss
loss_without_labels = (
True if len(self.label_names) == 0 and return_loss else False
)
inputs = self._prepare_inputs(inputs)
if ignore_keys is None:
if hasattr(self.model, "config"):
ignore_keys = getattr(
self.model.config, "keys_to_ignore_at_inference", []
)
else:
ignore_keys = []
# labels may be popped when computing the loss (label smoothing for instance) so we grab them first.
if has_labels or loss_without_labels:
labels = nested_detach(tuple(inputs.get(name)
for name in self.label_names))
if len(labels) == 1:
labels = labels[0]
else:
labels = None
with torch.no_grad():
if is_sagemaker_mp_enabled():
raw_outputs = smp_forward_only(model, inputs)
if has_labels or loss_without_labels:
if isinstance(raw_outputs, dict):
loss_mb = raw_outputs["loss"]
logits_mb = tuple(
v
for k, v in raw_outputs.items()
if k not in ignore_keys + ["loss"]
)
else:
loss_mb = raw_outputs[0]
logits_mb = raw_outputs[1:]
loss = loss_mb.reduce_mean().detach().cpu()
logits = smp_nested_concat(logits_mb)
else:
loss = None
if isinstance(raw_outputs, dict):
logits_mb = tuple(
v for k, v in raw_outputs.items() if k not in ignore_keys
)
else:
logits_mb = raw_outputs
logits = smp_nested_concat(logits_mb)
else:
if has_labels or loss_without_labels:
with self.compute_loss_context_manager():
loss, outputs = self.compute_loss(
model, inputs, return_outputs=True
)
loss = loss.mean().detach()
if isinstance(outputs, dict):
logits = tuple(
v
for k, v in outputs.items()
if k not in ignore_keys + ["loss"]
)
else:
logits = outputs[1:]
else:
loss = None
with self.compute_loss_context_manager():
outputs = model(**inputs)
if isinstance(outputs, dict):
logits = tuple(
v for k, v in outputs.items() if k not in ignore_keys
)
else:
logits = outputs
# TODO: this needs to be fixed and made cleaner later.
if self.args.past_index >= 0:
self._past = outputs[self.args.past_index - 1]
if prediction_loss_only:
return (loss, None, None)
logits = nested_detach(logits)
if len(logits) == 1:
logits = logits[0]
return (loss, logits, labels)
def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
"""
Perform a training step on a batch of inputs.
Subclass and override to inject custom behavior.
Args:
model (`nn.Module`):
The model to train.
inputs (`Dict[str, Union[torch.Tensor, Any]]`):
The inputs and targets of the model.
The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
argument `labels`. Check your model's documentation for all accepted arguments.
Return:
`torch.Tensor`: The tensor with training loss on this batch.
"""
model.train()
inputs = self._prepare_inputs(inputs)
if is_sagemaker_mp_enabled():
loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps)
return loss_mb.reduce_mean().detach().to(self.args.device)
with self.compute_loss_context_manager():
loss = self.compute_loss(model, inputs)
del inputs
torch.cuda.empty_cache()
if self.args.n_gpu > 1:
loss = loss.mean() # mean() to average on multi-gpu parallel training
if self.use_apex:
with amp.scale_loss(loss, self.optimizer) as scaled_loss:
scaled_loss.backward()
else:
if is_deepspeed_zero3_enabled():
with deepspeed.zero.GatheredParameters(self.model.resampler.attn.parameters(), modifier_rank=0):
self.accelerator.backward(loss)
else:
self.accelerator.backward(loss)
return loss.detach() / self.args.gradient_accumulation_steps
def _save(self, output_dir: Optional[str] = None, state_dict=None):
# If we are executing this function, we are the process zero, so we don't check for that.
output_dir = output_dir if output_dir is not None else self.args.output_dir
os.makedirs(output_dir, exist_ok=True)
logger.info(f"Saving model checkpoint to {output_dir}")
supported_classes = (PreTrainedModel,) if not is_peft_available() else (PreTrainedModel, PeftModel)
# Save a trained model and configuration using `save_pretrained()`.
# They can then be reloaded using `from_pretrained()`
if not isinstance(self.model, supported_classes):
if state_dict is None:
state_dict = self.model.state_dict()
if isinstance(unwrap_model(self.model), supported_classes):
unwrap_model(self.model).save_pretrained(
output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
)
else:
logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
if self.args.save_safetensors:
safetensors.torch.save_file(
state_dict, os.path.join(output_dir, SAFE_WEIGHTS_NAME), metadata={"format": "pt"}
)
else:
torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
else:
if self.args.use_lora:
from collections import OrderedDict
state_dict_vision = OrderedDict()
for key, values in state_dict.items():
if 'vpm' in key or 'resampler' in key or 'embed_tokens' in key:
state_dict_vision[key] = values
self.model.save_pretrained(
output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
)
torch.save(state_dict_vision, f"{output_dir}/vpm_resampler_embedtokens.pt", )
else:
self.model.save_pretrained(
output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
)
if self.tokenizer is not None:
self.tokenizer.save_pretrained(output_dir)
# Good practice: save your training arguments together with the trained model
torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
#!/bin/bash
export HIP_VISIBLE_DEVICES=0,1,2,3
GPUS_PER_NODE=4
NNODES=1
NODE_RANK=0
MASTER_ADDR=localhost
MASTER_PORT=29500
MODEL="/home/wanglch/projects/MiniCPM-V/MiniCPM-Llama3-V-2_5-base" # or openbmb/MiniCPM-V-2
# ATTENTION: specify the path to your training data, which should be a json file consisting of a list of conversations.
# See the section for finetuning in README for more information.
DATA="/home/wanglch/projects/MiniCPM-V/data/self_build/train_data/train_data.json"
EVAL_DATA="/home/wanglch/projects/MiniCPM-V/data/self_build/eval_data/eval_data.json"
LLM_TYPE="llama3" # if use openbmb/MiniCPM-V-2, please set LLM_TYPE=minicpm
DISTRIBUTED_ARGS="
--nproc_per_node $GPUS_PER_NODE \
--nnodes $NNODES \
--node_rank $NODE_RANK \
--master_addr $MASTER_ADDR \
--master_port $MASTER_PORT
"
torchrun $DISTRIBUTED_ARGS finetune.py \
--model_name_or_path $MODEL \
--llm_type $LLM_TYPE \
--data_path $DATA \
--eval_data_path $EVAL_DATA \
--remove_unused_columns false \
--label_names "labels" \
--prediction_loss_only false \
--bf16 false \
--bf16_full_eval false \
--fp16 true \
--fp16_full_eval true \
--do_train \
--do_eval \
--tune_vision true \
--tune_llm false \
--use_lora true \
--lora_target_modules "llm\..*layers\.\d+\.self_attn\.(q_proj|k_proj)" \
--model_max_length 2048 \
--max_slice_nums 9 \
--max_steps 100 \
--eval_steps 10 \
--output_dir /home/wanglch/projects/saves/MiniCPM-Llama3-V-2_5/lora_train_dtk \
--logging_dir /home/wanglch/projects/saves/MiniCPM-Llama3-V-2_5/lora_train_dtk \
--logging_strategy "steps" \
--per_device_train_batch_size 2 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 1 \
--evaluation_strategy "steps" \
--save_strategy "steps" \
--save_steps 100 \
--save_total_limit 10 \
--learning_rate 1e-6 \
--weight_decay 0.1 \
--adam_beta2 0.95 \
--warmup_ratio 0.01 \
--lr_scheduler_type "cosine" \
--logging_steps 1 \
--gradient_checkpointing true \
--deepspeed ds_config_zero2.json \
--report_to "tensorboard" # wandb
## MiniCPM-V 1.0
> Archived at: 2024-05-19
MiniCPM-V 1.0 is an efficient version with promising performance for deployment. The model is built based on SigLip-400M and [MiniCPM-2.4B](https://github.com/OpenBMB/MiniCPM/), connected by a perceiver resampler. Notable features of MiniCPM-V 1.0 include:
- ⚡️ **High Efficiency.**
MiniCPM-V 1.0 can be **efficiently deployed on most GPU cards and personal computers**, and **even on end devices such as mobile phones**. In terms of visual encoding, we compress the image representations into 64 tokens via a perceiver resampler, which is significantly fewer than other LMMs based on MLP architecture (typically > 512 tokens). This allows MiniCPM-V 1.0 to operate with **much less memory cost and higher speed during inference**.
- 🔥 **Promising Performance.**
MiniCPM-V 1.0 achieves **state-of-the-art performance** on multiple benchmarks (including MMMU, MME, and MMBench) among models with comparable sizes, surpassing existing LMMs built on Phi-2. It even **achieves comparable or better performance than the 9.6B Qwen-VL-Chat**.
- 🙌 **Bilingual Support.**
MiniCPM-V 1.0 is **the first end-deployable LMM supporting bilingual multimodal interaction in English and Chinese**. This is achieved by generalizing multimodal capabilities across languages, a technique from the ICLR 2024 spotlight [paper](https://arxiv.org/abs/2308.12038).
### Evaluation
<div align="center">
<table style="margin: 0px auto;">
<thead>
<tr>
<th align="left">Model</th>
<th>Size</th>
<th nowrap="nowrap" >Visual Tokens</th>
<th>MME</th>
<th nowrap="nowrap" >MMB dev (en)</th>
<th nowrap="nowrap" >MMB dev (zh)</th>
<th nowrap="nowrap" >MMMU val</th>
<th nowrap="nowrap" >CMMMU val</th>
</tr>
</thead>
<tbody align="center">
<tr>
<td align="left">LLaVA-Phi</td>
<td align="right">3B</td>
<td>576</td>
<td>1335</td>
<td>59.8</td>
<td>- </td>
<td>- </td>
<td>- </td>
</tr>
<tr>
<td nowrap="nowrap" align="left">MobileVLM</td>
<td align="right">3B</td>
<td>144</td>
<td>1289</td>
<td>59.6</td>
<td>- </td>
<td>- </td>
<td>- </td>
</tr>
<tr>
<td nowrap="nowrap" align="left" >Imp-v1</td>
<td align="right">3B</td>
<td>576</td>
<td>1434</td>
<td>66.5</td>
<td>- </td>
<td>- </td>
<td>- </td>
</tr>
<tr>
<td nowrap="nowrap" align="left" >Qwen-VL-Chat</td>
<td align="right" >9.6B</td>
<td>256</td>
<td>1487</td>
<td>60.6 </td>
<td>56.7 </td>
<td>35.9 </td>
<td>30.7 </td>
</tr>
<tr>
<td nowrap="nowrap" align="left" >CogVLM</td>
<td align="right">17.4B </td>
<td>1225</td>
<td>1438 </td>
<td>63.7 </td>
<td>53.8 </td>
<td>32.1 </td>
<td>- </td>
</tr>
<tr>
<td nowrap="nowrap" align="left" ><b>MiniCPM-V 1.0</b></td>
<td align="right">3B </td>
<td>64</td>
<td>1452 </td>
<td>67.9 </td>
<td>65.3 </td>
<td>37.2 </td>
<td>32.1 </td>
</tr>
</tbody>
</table>
</div>
### Examples
We deploy MiniCPM-V 1.0 on end devices. The demo video is a raw screen recording on a OnePlus 9R without editing.
<table align="center">
<p align="center">
<img src="assets/gif_cases/蛇_cn.gif" width=36%/>
<img src="assets/gif_cases/Mushroom_en.gif" width=36%/>
</p>
</table>
## Install
1. Clone this repository and navigate to the source folder
```bash
git clone https://github.com/OpenBMB/OmniLMM.git
cd OmniLMM
```
2. Create conda environment
```Shell
conda create -n OmniLMM python=3.10 -y
conda activate OmniLMM
```
3. Install dependencies
```shell
pip install -r requirements.txt
```
## Inference
### Model Zoo
| Model | Description | Download Link |
|:----------------------|:-------------------|:---------------:|
| MiniCPM-V 1.0 | The efficient version for end device deployment. | [🤗](https://huggingface.co/openbmb/MiniCPM-V) &nbsp;&nbsp; [<img src="./assets/modelscope_logo.png" width="20px"></img>](https://modelscope.cn/models/OpenBMB/MiniCPM-V/files) |
### Multi-turn Conversation
Please refer to the following code to run `MiniCPM-V 1.0`.
<div align="center">
<img src="assets/worldmap_ck.jpg" width="500px">
</div>
```python
import json
from chat import OmniLMMChat, img2base64
chat_model = OmniLMMChat('openbmb/MiniCPM-V')
im_64 = img2base64('./assets/worldmap_ck.jpg')
# First round chat
msgs = [{"role": "user", "content": "What is interesting about this image?"}]
inputs = {"image": im_64, "question": json.dumps(msgs)}
answer = chat_model.chat(inputs)
print(answer)
# Second round chat
# pass history context of multi-turn conversation
msgs.append({"role": "assistant", "content": answer})
msgs.append({"role": "user", "content": "Where is China in the image"})
inputs = {"image": im_64, "question": json.dumps(msgs)}
answer = chat_model.chat(inputs)
print(answer)
```
### Inference on Mac
<details>
<summary>Click to view example, MiniCPM-V 1.0 can run on Mac with MPS (Apple silicon or AMD GPUs). </summary>
```python
# test.py
import torch
from PIL import Image
from transformers import AutoModel, AutoTokenizer
model = AutoModel.from_pretrained('openbmb/MiniCPM-V', trust_remote_code=True, torch_dtype=torch.bfloat16)
model = model.to(device='mps', dtype=torch.float16)
tokenizer = AutoTokenizer.from_pretrained('openbmb/MiniCPM-V', trust_remote_code=True)
model.eval()
image = Image.open('./assets/worldmap_ck.jpg').convert('RGB')
question = 'What is interesting about this image?'
msgs = [{'role': 'user', 'content': question}]
answer, context, _ = model.chat(
image=image,
msgs=msgs,
context=None,
tokenizer=tokenizer,
sampling=True
)
print(answer)
```
Run with command:
```shell
PYTORCH_ENABLE_MPS_FALLBACK=1 python test.py
```
</details>
### Deployment on Mobile Phone
Currently MiniCPM-V 1.0 can be deployed on mobile phones with Android and Harmony operating systems. 🚀 Try it out [here](https://github.com/OpenBMB/mlc-MiniCPM).
# Unique model identifier
modelCode = 691
# Model name
modelName=minicpm-v_pytorch
# Model description
modelDescription=Multimodal OCR large model, deployable on end-side devices
# Application scenarios
appScenario=Inference, OCR, finance, education, government, research, transportation, broadcast media
# Framework type
frameType=pytorch
## OmniLMM-12B
> OmniLMM-12B was released in the early stage of this project. We recommend using our [latest models](./README_zh.md) for more efficient inference and stronger performance.
> Archived at: 2024-05-19
**OmniLMM-12B** is the most capable version in the current series. The model is built by initializing from EVA02-5B and Zephyr-7B-β, connected with a perceiver resampler, and trained on multimodal data with a curriculum learning approach. The model has three notable features:
- 🔥 **Leading Performance.**
OmniLMM-12B achieves **leading performance** among models of comparable size on multiple benchmarks (including MME, MMBench, and SEED-Bench), and it has acquired fairly rich multimodal world knowledge.
- 🏆 **Trustworthy Behavior.**
Hallucination is a well-known problem for multimodal LLMs: models often generate text inconsistent with the facts in the image (e.g., confidently describing objects that do not exist in the picture). OmniLMM-12B is **the first open-source multimodal LLM with strong overall capability that is aligned via multimodal RLHF** (using the [RLHF-V](https://rlhf-v.github.io/) [CVPR'24] series of techniques). It achieves **the best level among open-source models** on the [MMHal-Bench](https://huggingface.co/datasets/Shengcao1006/MMHal-Bench) hallucination benchmark and **outperforms GPT-4V** on [Object HalBench](https://arxiv.org/abs/2312.00849).
- 🕹 **Real-time Multimodal Interaction.**
We experimented with combining OmniLMM-12B and GPT-3.5 (a text-only model) into a **real-time multimodal interactive assistant**. The model accepts a video stream from the camera and uses tools to handle speech input and output. Although still preliminary, we find that the assistant can **reproduce some of the interesting cases in the Gemini demo video** without any video editing.
### Evaluation Results <!-- omit in toc -->
<div align="center">
<img src=assets/radar_omnilmm12b.png width=66% />
</div>
<details>
<summary> Detailed evaluation results on MME, MMBench, MMMU, MMHal-Bench, Object HalBench, SeedBench-I, LLaVA Bench (W), and MathVista. </summary>
<table>
<thead>
<tr>
<th align="left">Model</th>
<th>Size</th>
<th>MME</th>
<th nowrap="nowrap">MMB dev (en)</th>
<th nowrap="nowrap" >MMMU val</th>
<th nowrap="nowrap" >MMHal-Bench</th>
<th nowrap="nowrap" >Object HalBench</th>
<th nowrap="nowrap" >SeedBench-I</th>
<th>MathVista</th>
<th nowrap="nowrap" >LLaVA Bench</th>
</tr>
</thead>
<tbody align="center">
<tr>
<td align="left">GPT-4V†</td>
<td>-</td>
<td>1771.5</td>
<td>75.1 </td>
<td>56.8</td>
<td>3.53 / 70.8</td>
<td>86.4 / 92.7</td>
<td>71.6 </td>
<td>47.8 </td>
<td>93.1 </td>
</tr>
<tr>
<td nowrap="nowrap" align="left">Qwen-VL-Plus†</td>
<td>-</td>
<td>2183.4</td>
<td>66.2 </td>
<td>45.2</td>
<td>- </td>
<td>- </td>
<td>65.7 </td>
<td>36.0 </td>
<td>73.7 </td>
</tr>
<tr>
<td align="left">Yi-VL 6B</td>
<td align="right">6.7B </td>
<td>1915.1 </td>
<td>68.6 </td>
<td>40.3 </td>
<td>- </td>
<td>- </td>
<td>67.5 </td>
<td>28.8 </td>
<td>51.9 </td>
</tr>
<tr>
<td nowrap="nowrap" align="left" >Qwen-VL-Chat</td>
<td align="right">9.6B</td>
<td>1860.0</td>
<td>60.6 </td>
<td>35.9</td>
<td>2.93 / 59.4</td>
<td>56.2 / 80.0</td>
<td>64.8 </td>
<td>33.8 </td>
<td>67.7 </td>
</tr>
<tr>
<td align="left" >CogVLM-Chat</td>
<td align="right">17.4B</td>
<td>1736.6</td>
<td>63.7 </td>
<td>32.1 </td>
<td>2.68 / 52.1 </td>
<td>73.6 / 87.4 </td>
<td>68.8 </td>
<td>34.7 </td>
<td>73.9 </td>
</tr>
<tr>
<td align="left" >LLaVA 1.5</td>
<td align="right">13.6B </td>
<td>1808.4 </td>
<td>68.2 </td>
<td>36.4 </td>
<td>2.71 / 51.0 </td>
<td>53.7 / 77.4 </td>
<td>68.1 </td>
<td>26.4 </td>
<td>64.6 </td>
</tr>
<tr>
<td nowrap="nowrap" align="left" ><b>OmniLMM-12B</b></td>
<td align="right">11.6B </td>
<td>1935.8 </td>
<td>71.6 </td>
<td>40.7 </td>
<td>3.45 / 68.8 </td>
<td>90.3 / 95.5 </td>
<td>71.1 </td>
<td>34.9 </td>
<td>72.0 </td>
</tr>
</tbody>
</table>
<small>†: proprietary models</small>
<br>
</details>
### Examples <!-- omit in toc -->
<table align="center" >
<p align="center" >
<img src="assets/omnilmm-12b-examples_2.png" />
</p>
</table>
We combine OmniLMM-12B with ChatGPT-3.5 (a text-only model) to build a **real-time multimodal interactive assistant**. OmniLMM-12B converts video frames into corresponding image descriptions and feeds them to ChatGPT-3.5, which generates responses to user instructions. The demo video is unedited.
<div align="center" >
<video controls src="https://github.com/OpenBMB/OmniLMM/assets/157115220/8fec13bf-bb47-4bf8-8f8c-d0b716a964ec" type="video/mp4" width=80%/>
</div>
## Online Demo
You are welcome to try the web demo of our models via the following links: [OmniLMM-12B](http://120.92.209.146:8081) and [MiniCPM-V 2.0](http://120.92.209.146:80).
## Install
1. Clone our repository and navigate to the source folder
```bash
git clone https://github.com/OpenBMB/MiniCPM-V.git
cd MiniCPM-V
```
2. Create a conda environment
```Shell
conda create -n MiniCPMV python=3.10 -y
conda activate MiniCPMV
```
3. Install dependencies
```shell
pip install -r requirements.txt
```
## Inference
### Model Zoo
| Model | Description | Download |
|:----------------------|:-------------------|:---------------:|
| OmniLMM-12B | The most capable version | [🤗](https://huggingface.co/openbmb/OmniLMM-12B) &nbsp;&nbsp; [<img src="./assets/modelscope_logo.png" width="20px"></img>](https://modelscope.cn/models/OpenBMB/OmniLMM-12B/files) |
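The snippet below is a minimal inference sketch rather than the project's documented API: it assumes the MiniCPM-V 2.0 checkpoint (`openbmb/MiniCPM-V-2`) exposes a `chat`-style interface through `trust_remote_code`, and `test.jpg` is a placeholder image path. Please check the corresponding model card for the exact interface and recommended generation settings.
```python
import torch
from PIL import Image
from transformers import AutoModel, AutoTokenizer

# Placeholder repo id and image path -- replace with the checkpoint/image you actually use.
model_id = "openbmb/MiniCPM-V-2"
model = AutoModel.from_pretrained(model_id, trust_remote_code=True, torch_dtype=torch.bfloat16)
model = model.to(device="cuda", dtype=torch.bfloat16).eval()
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

image = Image.open("test.jpg").convert("RGB")
msgs = [{"role": "user", "content": "Describe this image."}]

# `chat` is the remote-code interface assumed here; see the model card for its exact signature.
answer, context, _ = model.chat(image=image, msgs=msgs, context=None,
                                tokenizer=tokenizer, sampling=True, temperature=0.7)
print(answer)
```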
CONTROLLER_HEART_BEAT_EXPIRATION = 30
WORKER_HEART_BEAT_INTERVAL = 15
LOGDIR = "."
import dataclasses
from enum import auto, Enum
from typing import List, Tuple
class SeparatorStyle(Enum):
"""Different separator style."""
SINGLE = auto()
TWO = auto()
@dataclasses.dataclass
class Conversation:
"""A class that keeps all conversation history."""
system: str
roles: List[str]
messages: List[List[str]]
offset: int
sep_style: SeparatorStyle = SeparatorStyle.SINGLE
sep: str = "###"
sep2: str = None
version: str = "Unknown"
skip_next: bool = False
def get_prompt(self):
if self.sep_style == SeparatorStyle.SINGLE:
ret = self.system + self.sep
for role, message in self.messages:
if message:
if type(message) is tuple:
message, _, _ = message
ret += role + ": " + message + self.sep
else:
ret += role + ":"
return ret
elif self.sep_style == SeparatorStyle.TWO:
seps = [self.sep, self.sep2]
ret = self.system + seps[0]
for i, (role, message) in enumerate(self.messages):
if message:
if type(message) is tuple:
message, _, _ = message
ret += role + ": " + message + seps[i % 2]
else:
ret += role + ":"
return ret
else:
raise ValueError(f"Invalid style: {self.sep_style}")
def append_message(self, role, message):
self.messages.append([role, message])
def get_images(self, return_pil=False):
images = []
for i, (role, msg) in enumerate(self.messages[self.offset:]):
if i % 2 == 0:
if type(msg) is tuple:
import base64
from io import BytesIO
from PIL import Image
msg, image, image_process_mode = msg
if image_process_mode == "Pad":
def expand2square(pil_img, background_color=(122, 116, 104)):
width, height = pil_img.size
if width == height:
return pil_img
elif width > height:
result = Image.new(
pil_img.mode, (width, width), background_color)
result.paste(
pil_img, (0, (width - height) // 2))
return result
else:
result = Image.new(
pil_img.mode, (height, height), background_color)
result.paste(
pil_img, ((height - width) // 2, 0))
return result
image = expand2square(image)
elif image_process_mode == "Crop":
pass
elif image_process_mode == "Resize":
image = image.resize((224, 224))
else:
raise ValueError(
f"Invalid image_process_mode: {image_process_mode}")
max_hw, min_hw = max(image.size), min(image.size)
aspect_ratio = max_hw / min_hw
max_len, min_len = 800, 400
shortest_edge = int(
min(max_len / aspect_ratio, min_len, min_hw))
longest_edge = int(shortest_edge * aspect_ratio)
W, H = image.size
if H > W:
H, W = longest_edge, shortest_edge
else:
H, W = shortest_edge, longest_edge
image = image.resize((W, H))
if return_pil:
images.append(image)
else:
buffered = BytesIO()
image.save(buffered, format="JPEG")
img_b64_str = base64.b64encode(
buffered.getvalue()).decode()
images.append(img_b64_str)
return images
def to_gradio_chatbot(self):
ret = []
for i, (role, msg) in enumerate(self.messages[self.offset:]):
if i % 2 == 0:
if type(msg) is tuple:
import base64
from io import BytesIO
msg, image, image_process_mode = msg
max_hw, min_hw = max(image.size), min(image.size)
aspect_ratio = max_hw / min_hw
max_len, min_len = 800, 400
shortest_edge = int(
min(max_len / aspect_ratio, min_len, min_hw))
longest_edge = int(shortest_edge * aspect_ratio)
W, H = image.size
if H > W:
H, W = longest_edge, shortest_edge
else:
H, W = shortest_edge, longest_edge
image = image.resize((W, H))
# image = image.resize((224, 224))
buffered = BytesIO()
image.save(buffered, format="JPEG")
img_b64_str = base64.b64encode(
buffered.getvalue()).decode()
                    img_str = f'<img src="data:image/jpeg;base64,{img_b64_str}" alt="user upload image" />'  # buffer holds JPEG bytes, so declare the JPEG MIME type
msg = msg.replace('<image>', img_str)
ret.append([msg, None])
else:
ret[-1][-1] = msg
return ret
def copy(self):
return Conversation(
system=self.system,
roles=self.roles,
messages=[[x, y] for x, y in self.messages],
offset=self.offset,
sep_style=self.sep_style,
sep=self.sep,
sep2=self.sep2)
def dict(self):
if len(self.get_images()) > 0:
return {
"system": self.system,
"roles": self.roles,
"messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
"offset": self.offset,
"sep": self.sep,
"sep2": self.sep2,
}
return {
"system": self.system,
"roles": self.roles,
"messages": self.messages,
"offset": self.offset,
"sep": self.sep,
"sep2": self.sep2,
}
conv_v1 = Conversation(
system="A chat between a curious human and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
roles=("Human", "Assistant"),
messages=(
("Human", "Give three tips for staying healthy."),
("Assistant",
"Sure, here are three tips for staying healthy:\n"
"1. Exercise regularly: Regular physical activity can help improve your overall health and wellbeing. "
"It can also help reduce your risk of chronic conditions such as obesity, diabetes, heart disease, "
"and certain cancers. Aim for at least 150 minutes of moderate-intensity aerobic exercise or "
"75 minutes of vigorous-intensity aerobic exercise per week, along with muscle-strengthening "
"activities at least two days per week.\n"
"2. Eat a balanced diet: Eating a balanced diet that is rich in fruits, "
"vegetables, whole grains, lean proteins, and healthy fats can help support "
"your overall health. Try to limit your intake of processed and high-sugar foods, "
"and aim to drink plenty of water throughout the day.\n"
"3. Get enough sleep: Getting enough quality sleep is essential for your physical "
"and mental health. Adults should aim for seven to nine hours of sleep per night. "
"Establish a regular sleep schedule and try to create a relaxing bedtime routine to "
"help improve the quality of your sleep.")
),
offset=2,
sep_style=SeparatorStyle.SINGLE,
sep="###",
)
conv_v1_2 = Conversation(
system="A chat between a curious human and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
roles=("Human", "Assistant"),
messages=(
("Human", "What are the key differences between renewable and non-renewable energy sources?"),
("Assistant",
"Renewable energy sources are those that can be replenished naturally in a relatively "
"short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
"Non-renewable energy sources, on the other hand, are finite and will eventually be "
"depleted, such as coal, oil, and natural gas. Here are some key differences between "
"renewable and non-renewable energy sources:\n"
"1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
"energy sources are finite and will eventually run out.\n"
"2. Environmental impact: Renewable energy sources have a much lower environmental impact "
"than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
"and other negative effects.\n"
"3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
"have lower operational costs than non-renewable sources.\n"
"4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
"locations than non-renewable sources.\n"
"5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
"situations and needs, while non-renewable sources are more rigid and inflexible.\n"
"6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
"non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
),
offset=2,
sep_style=SeparatorStyle.SINGLE,
sep="###",
)
conv_vicuna_v1_1 = Conversation(
system="A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions.",
roles=("USER", "ASSISTANT"),
version="v1",
messages=(),
offset=0,
sep_style=SeparatorStyle.TWO,
sep=" ",
sep2="</s>",
)
conv_bair_v1 = Conversation(
system="BEGINNING OF CONVERSATION:",
roles=("USER", "GPT"),
messages=(),
offset=0,
sep_style=SeparatorStyle.TWO,
sep=" ",
sep2="</s>",
)
simple_conv = Conversation(
system="You are LLaVA, a large language model trained by UW Madison WAIV Lab, based on LLaMA architecture."
"You are designed to assist human with a variety of tasks using natural language."
"Follow the instructions carefully.",
roles=("Human", "Assistant"),
messages=(
("Human", "Hi!"),
("Assistant", "Hi there! How can I help you today?\n")
),
offset=2,
sep_style=SeparatorStyle.SINGLE,
sep="###",
)
simple_conv_multimodal = Conversation(
system="A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions.",
roles=("Human", "Assistant"),
messages=(
),
offset=0,
sep_style=SeparatorStyle.SINGLE,
sep="###",
)
simple_conv_legacy = Conversation(
system="You are LLaVA, a large language model trained by UW Madison WAIV Lab."
"You are designed to assist human with a variety of tasks using natural language."
"Follow the instructions carefully.",
roles=("Human", "Assistant"),
messages=(
("Human", "Hi!\n\n### Response:"),
("Assistant", "Hi there! How can I help you today?\n")
),
offset=2,
sep_style=SeparatorStyle.SINGLE,
sep="###",
)
conv_llava_v1 = Conversation(
system="You are LLaVA, a large language and vision assistant trained by UW Madison WAIV Lab."
"You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
"Follow the instructions carefully and explain your answers in detail.",
roles=("USER", "ASSISTANT"),
version="v1",
messages=(),
offset=0,
sep_style=SeparatorStyle.TWO,
sep=" ",
sep2="</s>",
)
default_conversation = conv_v1_2
conv_templates = {
"default": conv_v1_2,
"simple": simple_conv,
"simple_legacy": simple_conv_legacy,
"multimodal": simple_conv_multimodal,
"llava_v1": conv_llava_v1,
# fastchat
"v1": conv_v1_2,
"bair_v1": conv_bair_v1,
"vicuna_v1_1": conv_vicuna_v1_1,
}
if __name__ == "__main__":
print(default_conversation.get_prompt())
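    # Example (sketch): build a prompt from one of the registered templates.
    # Only objects defined in this file are used here.
    conv = conv_templates["vicuna_v1_1"].copy()
    conv.append_message(conv.roles[0], "What is shown in this image?\n<image>")
    conv.append_message(conv.roles[1], None)
    print(conv.get_prompt())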
from .omnilmm import OmniLMMForCausalLM
import gc
import math
import timm
import torch
from torch import Tensor
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from typing import List, Optional, Tuple, Union
from transformers import AutoConfig, AutoModelForCausalLM
from transformers import MistralForCausalLM, MistralModel, MistralConfig
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from omnilmm.model.utils import build_transform
from omnilmm.model.resampler import Resampler
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"
class OmniLMMConfig(MistralConfig):
model_type = "omnilmm"
class Identity(torch.nn.Identity):
def forward(self, input: Tensor, **kwargs) -> Tensor:
return super().forward(input)
def create_vision_module(config):
vision_tower = timm.create_model('eva02_enormous_patch14_clip_224.laion2b_plus',
pretrained=False,
num_classes=0,
dynamic_img_size=True,
dynamic_img_pad=True)
if isinstance(vision_tower, timm.models.VisionTransformer):
if vision_tower.attn_pool is not None:
vision_tower.attn_pool = Identity()
# use 2nd last layer's output
vision_tower.blocks[-1] = Identity()
embed_dim = config.hidden_size
resampler = Resampler(
grid_size=int(math.sqrt(config.num_query)),
embed_dim=embed_dim,
num_heads=embed_dim // 128,
kv_dim=vision_tower.embed_dim,
)
return vision_tower, resampler
class OmniLMMModel(MistralModel):
config_class = OmniLMMConfig
def __init__(self, config: OmniLMMConfig, mm_vision_tower=None, mm_hidden_size=None, tune_clip=True):
super(OmniLMMModel, self).__init__(config)
if hasattr(config, "mm_vision_tower"):
vision_tower, resampler = create_vision_module(config)
# print(__file__, 'skip loading vision tower weights')
# HACK: for FSDP
self.vision_tower = [vision_tower]
self.resampler = resampler
if tune_clip:
self.vision_tower = self.vision_tower[0]
self.vision_config = lambda x: None
def initialize_vision_modules(self, vision_tower, no_randaug, num_query, image_size, tune_clip=False):
self.config.mm_vision_tower = vision_tower
self.config.use_mm_proj = True
self.config.num_query = num_query
self.config.image_size = image_size
if not hasattr(self, 'vision_tower'):
vision_tower, resampler = create_vision_module(self.config)
state_dict = torch.load(
'/tt/data/public/multimodal/multimodal_model_ckpts/timm/eva02_enormous_patch14_clip_224.laion2b_plus.pt')
vision_tower.load_state_dict(state_dict, strict=False)
del state_dict
gc.collect()
else:
if isinstance(self.vision_tower, list):
vision_tower = self.vision_tower[0]
else:
vision_tower = self.vision_tower
resampler = self.resampler
self.vision_tower = vision_tower if tune_clip else [vision_tower]
self.resampler = resampler
train_img_transform = build_transform(
is_train=True, randaug=not no_randaug, input_size=self.config.image_size, std_mode='OPENAI_CLIP')
eval_img_transform = build_transform(
is_train=False, input_size=self.config.image_size, std_mode='OPENAI_CLIP')
return dict(
image_processor=(train_img_transform, eval_img_transform),
image_token_len=num_query,
vision_config=self.vision_config
)
def get_vision_embedding(self, pixel_values):
if isinstance(self.vision_tower, list):
vision_tower = self.vision_tower[0] # HACK: for FSDP
else:
vision_tower = self.vision_tower
dtype = vision_tower.pos_embed.data.dtype
vision_embedding = vision_tower.forward_features(
pixel_values.type(dtype))
if hasattr(vision_tower, 'num_prefix_tokens') and vision_tower.num_prefix_tokens > 0:
vision_embedding = vision_embedding[:,
vision_tower.num_prefix_tokens:]
res = self.resampler(vision_embedding)
return res
def get_vllm_embedding(self, data):
if 'vision_hidden_states' not in data:
pixel_values_list = data['pixel_values']
vision_hidden_states = []
for pixel_values in pixel_values_list:
if len(pixel_values) > 0:
vision_hidden_states.append(self.get_vision_embedding(pixel_values.unsqueeze(0))[0])
else:
vision_hidden_states.append([])
else:
vision_hidden_states = data['vision_hidden_states']
#vllm_embedding = self.llm.model.embed_tokens(data['input_ids']) * self.llm.config.scale_emb
inputs_embeds = self.embed_tokens(data['input_ids'])
vision_hidden_states = [i.type(inputs_embeds.dtype)
if isinstance(i, torch.Tensor) else i for i in vision_hidden_states
]
# HACK: replace back original embeddings for LLaVA pretraining
orig_embeds_params = getattr(self, 'orig_embeds_params', None)
        # Defined here so that text-only samples in the loop below do not hit a NameError;
        # mirrors the dummy features used in OmniLMMModel.forward to keep the graph connected.
        dummy_image_features = torch.zeros(
            self.config.num_query,
            self.config.hidden_size,
            device=inputs_embeds.device,
            dtype=inputs_embeds.dtype)
        new_input_embeds = []
        cur_image_idx = 0
for cur_input_ids, cur_input_embeds in zip(data['input_ids'], inputs_embeds):
if (cur_input_ids == self.vision_config.im_patch_token).sum() == 0:
# multimodal LLM, but the current sample is not multimodal
cur_input_embeds = cur_input_embeds + (0. * dummy_image_features).sum()
new_input_embeds.append(cur_input_embeds)
continue
if self.vision_config.use_im_start_end:
cur_image_features = vision_hidden_states[cur_image_idx]
num_patches = cur_image_features.shape[0]
if (cur_input_ids == self.vision_config.im_start_token).sum() != (cur_input_ids == self.vision_config.im_end_token).sum():
raise ValueError(
"The number of image start tokens and image end tokens should be the same.")
image_start_tokens = torch.where(
cur_input_ids == self.vision_config.im_start_token)[0]
for image_start_token_pos in image_start_tokens:
cur_image_features = vision_hidden_states[cur_image_idx].to(
device=cur_input_embeds.device)
num_patches = cur_image_features.shape[0]
if cur_input_ids[image_start_token_pos + num_patches + 1] != self.vision_config.im_end_token:
raise ValueError(
"The image end token should follow the image start token.")
if orig_embeds_params is not None:
cur_new_input_embeds = torch.cat((cur_input_embeds[:image_start_token_pos].detach(), cur_input_embeds[image_start_token_pos:image_start_token_pos+1], cur_image_features,
cur_input_embeds[image_start_token_pos + num_patches + 1:image_start_token_pos + num_patches + 2], cur_input_embeds[image_start_token_pos + num_patches + 2:].detach()), dim=0)
else:
cur_new_input_embeds = torch.cat(
(cur_input_embeds[:image_start_token_pos+1], cur_image_features, cur_input_embeds[image_start_token_pos + num_patches + 1:]), dim=0)
cur_image_idx += 1
new_input_embeds.append(cur_new_input_embeds)
else:
raise NotImplementedError
inputs_embeds = torch.stack(new_input_embeds, dim=0)
return inputs_embeds, vision_hidden_states
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
images: Optional[torch.FloatTensor] = None,
return_dict: Optional[bool] = None,
**kwargs
) -> Union[Tuple, BaseModelOutputWithPast]:
# HACK: replace back original embeddings for LLaVA pretraining
orig_embeds_params = getattr(self, 'orig_embeds_params', None)
if inputs_embeds is None and past_key_values is None:
inputs_embeds = self.embed_tokens(input_ids)
vision_tower = getattr(self, 'vision_tower', None)
if vision_tower is not None and (input_ids.shape[1] != 1 or self.training) and images is not None:
if type(images) is list:
image_features = []
for image in images:
image_forward_out = self.get_vision_embedding(image.unsqueeze(0))[
0]
image_features.append(image_forward_out)
else:
image_features = self.get_vision_embedding(images)
dummy_image_features = torch.zeros(
self.config.num_query,
self.config.hidden_size,
device=inputs_embeds.device,
dtype=inputs_embeds.dtype)
new_input_embeds = []
cur_image_idx = 0
for cur_input_ids, cur_input_embeds in zip(input_ids, inputs_embeds):
if (cur_input_ids == self.vision_config.im_patch_token).sum() == 0:
# multimodal LLM, but the current sample is not multimodal
cur_input_embeds = cur_input_embeds + \
(0. * dummy_image_features).sum()
new_input_embeds.append(cur_input_embeds)
continue
if self.vision_config.use_im_start_end:
cur_image_features = image_features[cur_image_idx]
num_patches = cur_image_features.shape[0]
if (cur_input_ids == self.vision_config.im_start_token).sum() != (cur_input_ids == self.vision_config.im_end_token).sum():
raise ValueError(
"The number of image start tokens and image end tokens should be the same.")
image_start_tokens = torch.where(
cur_input_ids == self.vision_config.im_start_token)[0]
for image_start_token_pos in image_start_tokens:
cur_image_features = image_features[cur_image_idx].to(
device=cur_input_embeds.device)
num_patches = cur_image_features.shape[0]
if cur_input_ids[image_start_token_pos + num_patches + 1] != self.vision_config.im_end_token:
raise ValueError(
"The image end token should follow the image start token.")
if orig_embeds_params is not None:
cur_new_input_embeds = torch.cat((cur_input_embeds[:image_start_token_pos].detach(), cur_input_embeds[image_start_token_pos:image_start_token_pos+1], cur_image_features,
cur_input_embeds[image_start_token_pos + num_patches + 1:image_start_token_pos + num_patches + 2], cur_input_embeds[image_start_token_pos + num_patches + 2:].detach()), dim=0)
else:
cur_new_input_embeds = torch.cat(
(cur_input_embeds[:image_start_token_pos+1], cur_image_features, cur_input_embeds[image_start_token_pos + num_patches + 1:]), dim=0)
cur_image_idx += 1
new_input_embeds.append(cur_new_input_embeds)
else:
raise NotImplementedError
inputs_embeds = torch.stack(new_input_embeds, dim=0)
input_ids = None
return super(OmniLMMModel, self).forward(
input_ids=input_ids, attention_mask=attention_mask, past_key_values=past_key_values,
inputs_embeds=inputs_embeds, use_cache=use_cache,
output_attentions=output_attentions, output_hidden_states=output_hidden_states,
return_dict=return_dict,
**kwargs
)
class OmniLMMForCausalLM(MistralForCausalLM):
config_class = OmniLMMConfig
def __init__(self, config, mm_vision_tower=None, tune_clip=True):
super(MistralForCausalLM, self).__init__(config)
self.model = OmniLMMModel(
config, mm_vision_tower=mm_vision_tower, tune_clip=tune_clip)
self.lm_head = nn.Linear(
config.hidden_size, config.vocab_size, bias=False)
# Initialize weights and apply final processing
self.post_init()
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
images: Optional[torch.FloatTensor] = None,
return_dict: Optional[bool] = None,
**kwargs
) -> Union[Tuple, CausalLMOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# print(f'@@@ At forward, labels: {labels.shape}-{labels}', flush=True)
# print(f'@@@ At forward, input_ids: {input_ids.shape}-{input_ids}', flush=True)
# print(f'@@@ At forward, input_ids: {attention_mask.shape}-{attention_mask}', flush=True)
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
images=images,
**kwargs
)
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model/pipeline parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
# TODO could be removed for generate_vllm()
def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
):
if past_key_values:
input_ids = input_ids[:, -1:]
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
model_inputs.update(
{
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"attention_mask": attention_mask,
"images": kwargs.get("images", None),
}
)
return model_inputs
def generate_vllm(
self,
input_ids: torch.LongTensor = None,
images: Optional[torch.FloatTensor] = None,
vision_hidden_states=None,
return_vision_hidden_states=False,
**kwargs
):
model_inputs = {'input_ids': input_ids}
if vision_hidden_states is None:
model_inputs['pixel_values'] = images
else:
model_inputs['vision_hidden_states'] = vision_hidden_states
with torch.inference_mode():
inputs_embeds, vision_hidden_states = self.model.get_vllm_embedding(model_inputs)
result = self.generate(
inputs_embeds=inputs_embeds,
**kwargs
)
if return_vision_hidden_states:
return result, vision_hidden_states
return result
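    # Usage note (sketch, not part of the original file): a caller that has tokenized a
    # prompt containing <im_start><im_patch>...<im_end> spans can run, e.g.,
    #     output_ids, vh = model.generate_vllm(input_ids=input_ids, images=pixel_values,
    #                                          return_vision_hidden_states=True,
    #                                          max_new_tokens=128)
    # and pass vision_hidden_states=vh on later calls to skip re-encoding the same image.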
def initialize_vision_tokenizer(self, mm_use_im_start_end, tokenizer, device,
tune_mm_mlp_adapter=False):
self.model.vision_config.use_im_start_end = mm_use_im_start_end
tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
self.resize_token_embeddings(len(tokenizer))
if mm_use_im_start_end:
num_new_tokens = tokenizer.add_tokens(
[DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
self.resize_token_embeddings(len(tokenizer))
self.model.vision_config.im_start_token, self.model.vision_config.im_end_token = tokenizer.convert_tokens_to_ids(
[DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN])
if num_new_tokens > 0:
input_embeddings = self.get_input_embeddings().weight.data
output_embeddings = self.get_output_embeddings().weight.data
input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
dim=0, keepdim=True)
output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
dim=0, keepdim=True)
input_embeddings[-num_new_tokens:] = input_embeddings_avg
output_embeddings[-num_new_tokens:] = output_embeddings_avg
# for new sft data
num_new_tokens = tokenizer.add_tokens(
['<box>', '</box>', '<ref>', '</ref>', '<quad>', '</quad>'], special_tokens=True)
self.resize_token_embeddings(len(tokenizer))
if num_new_tokens > 0:
input_embeddings = self.get_input_embeddings().weight.data
output_embeddings = self.get_output_embeddings().weight.data
input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
dim=0, keepdim=True)
output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
dim=0, keepdim=True)
input_embeddings[-num_new_tokens:] = input_embeddings_avg
output_embeddings[-num_new_tokens:] = output_embeddings_avg
if tune_mm_mlp_adapter:
self.model.orig_embeds_params = [
self.get_input_embeddings().weight.data.clone().to(device=device)]
for p in self.get_input_embeddings().parameters():
p.requires_grad = True
for p in self.get_output_embeddings().parameters():
p.requires_grad = False
self.model.vision_config.im_patch_token = tokenizer.convert_tokens_to_ids(
[DEFAULT_IMAGE_PATCH_TOKEN])[0]
        print(f'Tokenizer: {tokenizer}\n patch_token_id: {self.model.vision_config.im_patch_token}, vision_config: {self.model.vision_config}', flush=True)
# exit()
AutoConfig.register("omnilmm", OmniLMMConfig)
AutoModelForCausalLM.register(OmniLMMConfig, OmniLMMForCausalLM)
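# The registrations above let the generic Auto* loaders resolve checkpoints whose config.json
# declares model_type="omnilmm", once this module has been imported. A minimal sketch
# (the checkpoint path below is a placeholder):
#
#     from transformers import AutoModelForCausalLM
#     model = AutoModelForCausalLM.from_pretrained("/path/to/omnilmm-12b-checkpoint")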