Commit 7a60e044 authored by wanglch's avatar wanglch
Browse files

Initial commit

parents
Pipeline #1185 canceled with stages
# Copyright (c) Alibaba Cloud.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from collections import OrderedDict
import math
import requests
from io import BytesIO
from functools import partial
from PIL import Image
from typing import Callable, Optional, Sequence, Tuple, List, Union
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.init import trunc_normal_
from torchvision import transforms
from torchvision.transforms import InterpolationMode
def get_abs_pos(abs_pos, tgt_size):
# abs_pos: L, C
# tgt_size: M
# return: M, C
src_size = int(math.sqrt(abs_pos.size(0)))
tgt_size = int(math.sqrt(tgt_size))
dtype = abs_pos.dtype
if src_size != tgt_size:
return F.interpolate(
abs_pos.float().reshape(1, src_size, src_size, -1).permute(0, 3, 1, 2),
size=(tgt_size, tgt_size),
mode="bicubic",
align_corners=False,
).permute(0, 2, 3, 1).flatten(0, 2).to(dtype=dtype)
else:
return abs_pos
# https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
"""
grid_size: int of the grid height and width
return:
pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
"""
grid_h = np.arange(grid_size, dtype=np.float32)
grid_w = np.arange(grid_size, dtype=np.float32)
grid = np.meshgrid(grid_w, grid_h) # here w goes first
grid = np.stack(grid, axis=0)
grid = grid.reshape([2, 1, grid_size, grid_size])
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
if cls_token:
pos_embed = np.concatenate(
[np.zeros([1, embed_dim]), pos_embed], axis=0)
return pos_embed
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
assert embed_dim % 2 == 0
# use half of dimensions to encode grid_h
emb_h = get_1d_sincos_pos_embed_from_grid(
embed_dim // 2, grid[0]) # (H*W, D/2)
emb_w = get_1d_sincos_pos_embed_from_grid(
embed_dim // 2, grid[1]) # (H*W, D/2)
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
return emb
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
"""
embed_dim: output dimension for each position
pos: a list of positions to be encoded: size (M,)
out: (M, D)
"""
assert embed_dim % 2 == 0
omega = np.arange(embed_dim // 2, dtype=np.float32)
omega /= embed_dim / 2.
omega = 1. / 10000 ** omega # (D/2,)
pos = pos.reshape(-1) # (M,)
out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
emb_sin = np.sin(out) # (M, D/2)
emb_cos = np.cos(out) # (M, D/2)
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
return emb
class Resampler(nn.Module):
"""
A 2D perceiver-resampler network with one cross attention layers by
(grid_size**2) learnable queries and 2d sincos pos_emb
Outputs:
A tensor with the shape of (grid_size**2, embed_dim)
"""
def __init__(
self,
grid_size,
embed_dim,
num_heads,
kv_dim=None,
norm_layer=partial(nn.LayerNorm, eps=1e-6)
):
super().__init__()
self.num_queries = grid_size ** 2
self.embed_dim = embed_dim
self.num_heads = num_heads
self.pos_embed = nn.Parameter(
torch.from_numpy(get_2d_sincos_pos_embed(
embed_dim, grid_size)).float()
).requires_grad_(False)
self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim))
trunc_normal_(self.query, std=.02)
if kv_dim is not None and kv_dim != embed_dim:
self.kv_proj = nn.Linear(kv_dim, embed_dim, bias=False)
else:
self.kv_proj = nn.Identity()
self.attn = nn.MultiheadAttention(embed_dim, num_heads)
self.ln_q = norm_layer(embed_dim)
self.ln_kv = norm_layer(embed_dim)
self.ln_post = norm_layer(embed_dim)
self.proj = nn.Parameter(
(embed_dim ** -0.5) * torch.randn(embed_dim, embed_dim))
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def forward(self, x, attn_mask=None):
pos_embed = get_abs_pos(self.pos_embed, x.size(1))
x = self.kv_proj(x)
x = self.ln_kv(x).permute(1, 0, 2)
N = x.shape[1]
q = self.ln_q(self.query)
# print((self._repeat(q, N) + self.pos_embed.unsqueeze(1)).dtype, (x + pos_embed.unsqueeze(1)).dtype, x.dtype)
out = self.attn(
self._repeat(q, N) + self.pos_embed.unsqueeze(1),
x + pos_embed.unsqueeze(1),
x,
attn_mask=attn_mask)[0]
x = out.permute(1, 0, 2)
x = self.ln_post(x)
x = x @ self.proj
return x
def _repeat(self, query, N: int):
return query.unsqueeze(1).repeat(1, N, 1)
from torchvision import transforms
from timm.data.transforms import RandomResizedCropAndInterpolation
from timm.data.constants import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from transformers import AutoConfig
from PIL import Image
from io import BytesIO
import torch.distributed as dist
import numpy as np
import pickle
import base64
import cv2
import os
import torch
from transformers import AutoConfig, StoppingCriteria
try:
from timm.data.constants import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
except ImportError:
OPENAI_CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
OPENAI_CLIP_STD = (0.26862954, 0.26130258, 0.27577711)
def auto_upgrade(config):
cfg = AutoConfig.from_pretrained(config)
if 'llava' in config and cfg.model_type != 'llava':
print("You are using newer LLaVA code base, while the checkpoint of v0 is from older code base.")
print("You must upgrade the checkpoint to the new code base (this can be done automatically).")
confirm = input(
"Please confirm that you want to upgrade the checkpoint. [Y/N]")
if confirm.lower() in ["y", "yes"]:
print("Upgrading checkpoint...")
assert len(cfg.architectures) == 1
setattr(cfg.__class__, "model_type", "llava")
cfg.architectures[0] = 'LlavaLlamaForCausalLM'
cfg.save_pretrained(config)
print("Checkpoint upgraded.")
else:
print("Checkpoint upgrade aborted.")
exit(1)
class KeywordsStoppingCriteria(StoppingCriteria):
def __init__(self, keywords, tokenizer, input_ids):
self.keywords = keywords
self.tokenizer = tokenizer
self.start_len = None
self.input_ids = input_ids
def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
if self.start_len is None:
self.start_len = self.input_ids.shape[1]
else:
outputs = self.tokenizer.batch_decode(
output_ids[:, self.start_len:], skip_special_tokens=True)[0]
for keyword in self.keywords:
if keyword in outputs:
return True
return False
def auto_upgrade(config):
cfg = AutoConfig.from_pretrained(config)
if 'llava' in config and cfg.model_type != 'llava':
print("You are using newer LLaVA code base, while the checkpoint of v0 is from older code base.")
print("You must upgrade the checkpoint to the new code base (this can be done automatically).")
confirm = input(
"Please confirm that you want to upgrade the checkpoint. [Y/N]")
if confirm.lower() in ["y", "yes"]:
print("Upgrading checkpoint...")
assert len(cfg.architectures) == 1
setattr(cfg.__class__, "model_type", "llava")
cfg.architectures[0] = 'LlavaLlamaForCausalLM'
cfg.save_pretrained(config)
print("Checkpoint upgraded.")
else:
print("Checkpoint upgrade aborted.")
exit(1)
# aug functions
def identity_func(img):
return img
def autocontrast_func(img, cutoff=0):
'''
same output as PIL.ImageOps.autocontrast
'''
n_bins = 256
def tune_channel(ch):
n = ch.size
cut = cutoff * n // 100
if cut == 0:
high, low = ch.max(), ch.min()
else:
hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
low = np.argwhere(np.cumsum(hist) > cut)
low = 0 if low.shape[0] == 0 else low[0]
high = np.argwhere(np.cumsum(hist[::-1]) > cut)
high = n_bins - 1 if high.shape[0] == 0 else n_bins - 1 - high[0]
if high <= low:
table = np.arange(n_bins)
else:
scale = (n_bins - 1) / (high - low)
table = np.arange(n_bins) * scale - low * scale
table[table < 0] = 0
table[table > n_bins - 1] = n_bins - 1
table = table.clip(0, 255).astype(np.uint8)
return table[ch]
channels = [tune_channel(ch) for ch in cv2.split(img)]
out = cv2.merge(channels)
return out
def equalize_func(img):
'''
same output as PIL.ImageOps.equalize
PIL's implementation is different from cv2.equalize
'''
n_bins = 256
def tune_channel(ch):
hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
non_zero_hist = hist[hist != 0].reshape(-1)
step = np.sum(non_zero_hist[:-1]) // (n_bins - 1)
if step == 0:
return ch
n = np.empty_like(hist)
n[0] = step // 2
n[1:] = hist[:-1]
table = (np.cumsum(n) // step).clip(0, 255).astype(np.uint8)
return table[ch]
channels = [tune_channel(ch) for ch in cv2.split(img)]
out = cv2.merge(channels)
return out
def rotate_func(img, degree, fill=(0, 0, 0)):
'''
like PIL, rotate by degree, not radians
'''
H, W = img.shape[0], img.shape[1]
center = W / 2, H / 2
M = cv2.getRotationMatrix2D(center, degree, 1)
out = cv2.warpAffine(img, M, (W, H), borderValue=fill)
return out
def solarize_func(img, thresh=128):
'''
same output as PIL.ImageOps.posterize
'''
table = np.array([el if el < thresh else 255 - el for el in range(256)])
table = table.clip(0, 255).astype(np.uint8)
out = table[img]
return out
def color_func(img, factor):
'''
same output as PIL.ImageEnhance.Color
'''
# implementation according to PIL definition, quite slow
# degenerate = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)[:, :, np.newaxis]
# out = blend(degenerate, img, factor)
# M = (
# np.eye(3) * factor
# + np.float32([0.114, 0.587, 0.299]).reshape(3, 1) * (1. - factor)
# )[np.newaxis, np.newaxis, :]
M = (
np.float32([
[0.886, -0.114, -0.114],
[-0.587, 0.413, -0.587],
[-0.299, -0.299, 0.701]]) * factor
+ np.float32([[0.114], [0.587], [0.299]])
)
out = np.matmul(img, M).clip(0, 255).astype(np.uint8)
return out
def contrast_func(img, factor):
"""
same output as PIL.ImageEnhance.Contrast
"""
mean = np.sum(np.mean(img, axis=(0, 1)) * np.array([0.114, 0.587, 0.299]))
table = np.array([(
el - mean) * factor + mean
for el in range(256)
]).clip(0, 255).astype(np.uint8)
out = table[img]
return out
def brightness_func(img, factor):
'''
same output as PIL.ImageEnhance.Contrast
'''
table = (np.arange(256, dtype=np.float32) *
factor).clip(0, 255).astype(np.uint8)
out = table[img]
return out
def sharpness_func(img, factor):
'''
The differences the this result and PIL are all on the 4 boundaries, the center
areas are same
'''
kernel = np.ones((3, 3), dtype=np.float32)
kernel[1][1] = 5
kernel /= 13
degenerate = cv2.filter2D(img, -1, kernel)
if factor == 0.0:
out = degenerate
elif factor == 1.0:
out = img
else:
out = img.astype(np.float32)
degenerate = degenerate.astype(np.float32)[1:-1, 1:-1, :]
out[1:-1, 1:-1, :] = degenerate + factor * \
(out[1:-1, 1:-1, :] - degenerate)
out = out.astype(np.uint8)
return out
def shear_x_func(img, factor, fill=(0, 0, 0)):
H, W = img.shape[0], img.shape[1]
M = np.float32([[1, factor, 0], [0, 1, 0]])
out = cv2.warpAffine(img, M, (W, H), borderValue=fill,
flags=cv2.INTER_LINEAR).astype(np.uint8)
return out
def translate_x_func(img, offset, fill=(0, 0, 0)):
'''
same output as PIL.Image.transform
'''
H, W = img.shape[0], img.shape[1]
M = np.float32([[1, 0, -offset], [0, 1, 0]])
out = cv2.warpAffine(img, M, (W, H), borderValue=fill,
flags=cv2.INTER_LINEAR).astype(np.uint8)
return out
def translate_y_func(img, offset, fill=(0, 0, 0)):
'''
same output as PIL.Image.transform
'''
H, W = img.shape[0], img.shape[1]
M = np.float32([[1, 0, 0], [0, 1, -offset]])
out = cv2.warpAffine(img, M, (W, H), borderValue=fill,
flags=cv2.INTER_LINEAR).astype(np.uint8)
return out
def posterize_func(img, bits):
'''
same output as PIL.ImageOps.posterize
'''
out = np.bitwise_and(img, np.uint8(255 << (8 - bits)))
return out
def shear_y_func(img, factor, fill=(0, 0, 0)):
H, W = img.shape[0], img.shape[1]
M = np.float32([[1, 0, 0], [factor, 1, 0]])
out = cv2.warpAffine(img, M, (W, H), borderValue=fill,
flags=cv2.INTER_LINEAR).astype(np.uint8)
return out
def cutout_func(img, pad_size, replace=(0, 0, 0)):
replace = np.array(replace, dtype=np.uint8)
H, W = img.shape[0], img.shape[1]
rh, rw = np.random.random(2)
pad_size = pad_size // 2
ch, cw = int(rh * H), int(rw * W)
x1, x2 = max(ch - pad_size, 0), min(ch + pad_size, H)
y1, y2 = max(cw - pad_size, 0), min(cw + pad_size, W)
out = img.copy()
out[x1:x2, y1:y2, :] = replace
return out
# level to args
def enhance_level_to_args(MAX_LEVEL):
def level_to_args(level):
return ((level / MAX_LEVEL) * 1.8 + 0.1,)
return level_to_args
def shear_level_to_args(MAX_LEVEL, replace_value):
def level_to_args(level):
level = (level / MAX_LEVEL) * 0.3
if np.random.random() > 0.5:
level = -level
return (level, replace_value)
return level_to_args
def translate_level_to_args(translate_const, MAX_LEVEL, replace_value):
def level_to_args(level):
level = (level / MAX_LEVEL) * float(translate_const)
if np.random.random() > 0.5:
level = -level
return (level, replace_value)
return level_to_args
def cutout_level_to_args(cutout_const, MAX_LEVEL, replace_value):
def level_to_args(level):
level = int((level / MAX_LEVEL) * cutout_const)
return (level, replace_value)
return level_to_args
def solarize_level_to_args(MAX_LEVEL):
def level_to_args(level):
level = int((level / MAX_LEVEL) * 256)
return (level, )
return level_to_args
def none_level_to_args(level):
return ()
def posterize_level_to_args(MAX_LEVEL):
def level_to_args(level):
level = int((level / MAX_LEVEL) * 4)
return (level, )
return level_to_args
def rotate_level_to_args(MAX_LEVEL, replace_value):
def level_to_args(level):
level = (level / MAX_LEVEL) * 30
if np.random.random() < 0.5:
level = -level
return (level, replace_value)
return level_to_args
func_dict = {
'Identity': identity_func,
'AutoContrast': autocontrast_func,
'Equalize': equalize_func,
'Rotate': rotate_func,
'Solarize': solarize_func,
'Color': color_func,
'Contrast': contrast_func,
'Brightness': brightness_func,
'Sharpness': sharpness_func,
'ShearX': shear_x_func,
'TranslateX': translate_x_func,
'TranslateY': translate_y_func,
'Posterize': posterize_func,
'ShearY': shear_y_func,
}
translate_const = 10
MAX_LEVEL = 10
replace_value = (128, 128, 128)
arg_dict = {
'Identity': none_level_to_args,
'AutoContrast': none_level_to_args,
'Equalize': none_level_to_args,
'Rotate': rotate_level_to_args(MAX_LEVEL, replace_value),
'Solarize': solarize_level_to_args(MAX_LEVEL),
'Color': enhance_level_to_args(MAX_LEVEL),
'Contrast': enhance_level_to_args(MAX_LEVEL),
'Brightness': enhance_level_to_args(MAX_LEVEL),
'Sharpness': enhance_level_to_args(MAX_LEVEL),
'ShearX': shear_level_to_args(MAX_LEVEL, replace_value),
'TranslateX': translate_level_to_args(
translate_const, MAX_LEVEL, replace_value
),
'TranslateY': translate_level_to_args(
translate_const, MAX_LEVEL, replace_value
),
'Posterize': posterize_level_to_args(MAX_LEVEL),
'ShearY': shear_level_to_args(MAX_LEVEL, replace_value),
}
class RandomAugment(object):
def __init__(self, N=2, M=10, isPIL=False, augs=[]):
self.N = N
self.M = M
self.isPIL = isPIL
if augs:
self.augs = augs
else:
self.augs = list(arg_dict.keys())
def get_random_ops(self):
sampled_ops = np.random.choice(self.augs, self.N)
return [(op, 0.5, self.M) for op in sampled_ops]
def __call__(self, img):
if self.isPIL:
img = np.array(img)
ops = self.get_random_ops()
for name, prob, level in ops:
if np.random.random() > prob:
continue
args = arg_dict[name](level)
img = func_dict[name](img, *args)
return img
def build_transform(is_train, randaug=True, input_size=224, interpolation='bicubic', std_mode='IMAGENET_INCEPTION'):
if std_mode == 'IMAGENET_INCEPTION':
mean = IMAGENET_INCEPTION_MEAN
std = IMAGENET_INCEPTION_STD
elif std_mode == 'OPENAI_CLIP':
mean = OPENAI_CLIP_MEAN
std = OPENAI_CLIP_STD
else:
raise NotImplementedError
if is_train:
crop_scale = float(os.environ.get('TRAIN_CROP_SCALE', 0.9999))
t = [
RandomResizedCropAndInterpolation(
input_size, scale=(crop_scale, 1.0), interpolation='bicubic'),
# transforms.RandomHorizontalFlip(),
]
if randaug and os.environ.get('TRAIN_DO_AUG', 'False') == 'True':
print(f'@@@@@ Do random aug during training', flush=True)
t.append(
RandomAugment(
2, 7, isPIL=True,
augs=[
'Identity', 'AutoContrast', 'Equalize', 'Brightness', 'Sharpness',
'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate',
]))
else:
print(f'@@@@@ Skip random aug during training', flush=True)
t += [
transforms.ToTensor(),
transforms.Normalize(mean=mean, std=std),
]
t = transforms.Compose(t)
else:
t = transforms.Compose([
transforms.Resize((input_size, input_size),
interpolation=transforms.InterpolationMode.BICUBIC),
transforms.ToTensor(),
transforms.Normalize(mean=mean, std=std)
])
return t
def img2b64(img_path):
img = Image.open(img_path) # path to file
img_buffer = BytesIO()
img.save(img_buffer, format=img.format)
byte_data = img_buffer.getvalue()
base64_str = base64.b64encode(byte_data) # bytes
base64_str = base64_str.decode("utf-8") # str
return base64_str
def str2b64(str):
return base64.b64encode(str.encode('utf-8')).decode('utf-8')
def b642str(b64):
return base64.b64decode(b64).decode('utf-8')
def is_dist_avail_and_initialized():
if not dist.is_available():
return False
if not dist.is_initialized():
return False
return True
def get_world_size():
if not is_dist_avail_and_initialized():
return 1
return dist.get_world_size()
def get_rank():
if not is_dist_avail_and_initialized():
return 0
return dist.get_rank()
def all_gather(data):
"""
Run all_gather on arbitrary picklable data (not necessarily tensors)
Args:
data: any picklable object
Returns:
list[data]: list of data gathered from each rank
"""
world_size = get_world_size()
if world_size == 1:
return [data]
# serialized to a Tensor
buffer = pickle.dumps(data)
storage = torch.ByteStorage.from_buffer(buffer)
tensor = torch.ByteTensor(storage).to("cuda")
# obtain Tensor size of each rank
local_size = torch.LongTensor([tensor.numel()]).to("cuda")
size_list = [torch.LongTensor([0]).to("cuda") for _ in range(world_size)]
dist.all_gather(size_list, local_size)
size_list = [int(size.item()) for size in size_list]
max_size = max(size_list)
# receiving Tensor from all ranks
# we pad the tensor because torch all_gather does not support
# gathering tensors of different shapes
tensor_list = []
for _ in size_list:
tensor_list.append(torch.ByteTensor(size=(max_size,)).to("cuda"))
if local_size != max_size:
padding = torch.ByteTensor(size=(max_size - local_size,)).to("cuda")
tensor = torch.cat((tensor, padding), dim=0)
dist.all_gather(tensor_list, tensor)
data_list = []
for size, tensor in zip(size_list, tensor_list):
buffer = tensor.cpu().numpy().tobytes()[:size]
data_list.append(pickle.loads(buffer))
return data_list
def mean(lst):
return sum(lst) / len(lst)
def stop_gradient_by_name(name: str):
def apply_fn(module):
if hasattr(module, name):
getattr(module, name).requires_grad_(False)
return apply_fn
import os
import gc
import copy
import time
import torch
import warnings
import transformers
import numpy as np
from typing import Dict, Optional, Sequence
from omnilmm import conversation as conversation_lib
IGNORE_INDEX = -100
DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"
def _tokenize_fn(strings: Sequence[str],
tokenizer: transformers.PreTrainedTokenizer) -> Dict:
"""Tokenize a list of strings."""
tokenized_list = [
tokenizer(
text,
return_tensors="pt",
padding="longest",
max_length=tokenizer.model_max_length,
truncation=True,
) for text in strings
]
input_ids = labels = [
tokenized.input_ids[0] for tokenized in tokenized_list
]
input_ids_lens = labels_lens = [
tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item()
for tokenized in tokenized_list
]
return dict(
input_ids=input_ids,
labels=labels,
input_ids_lens=input_ids_lens,
labels_lens=labels_lens,
)
def omni_preprocess(sources,
tokenizer: transformers.PreTrainedTokenizer,
generation=False):
system_content = 'You are an artificial intelligence assistant, which gives helpful, detailed, and polite answers to the human\'s questions.'
ignore_index = -100
response_template = '\n<|assistant|>\n'
instruction_template = '\n<|user|>\n'
response_token_ids = tokenizer.encode(
response_template, add_special_tokens=False)
instruction_token_ids = tokenizer.encode(
instruction_template, add_special_tokens=False)
batch_input_ids = []
batch_labels = []
for i in range(len(sources)):
new_source = []
prev_role = 'unexpect'
for conv_turn in sources[i]:
role = conv_turn['from'] if 'from' in conv_turn else conv_turn['role']
content = conv_turn['value'] if 'value' in conv_turn else conv_turn['content']
role = 'user' if role == 'human' else role
role = 'assistant' if role == 'gpt' else role
assert role in ['user', 'assistant']
assert role != prev_role, f'role={role}, prev_role={prev_role}'
prev_role = role
new_turn = {
'role': role,
'content': content
}
new_source.append(new_turn)
if new_source[0]['role'] != 'system':
new_source.insert(0, {'role': 'system', 'content': system_content})
# TODO: this automatically add '\n' to the end
res_text = tokenizer.apply_chat_template(
new_source, tokenize=False, add_generation_prompt=generation)
if not generation:
res_text = res_text.strip()
conversations_tokenized = _tokenize_fn([res_text], tokenizer)
res_input_ids = conversations_tokenized["input_ids"][0]
# since labels and input_ids are reference towards the same object
res_labels = copy.deepcopy(conversations_tokenized["labels"][0])
response_token_ids_idxs = []
human_token_ids_idxs = []
for assistant_idx in np.where(res_labels == response_token_ids[0])[0]:
# find the indexes of the start of a response.
if (response_token_ids == res_labels[assistant_idx: assistant_idx + len(
response_token_ids)].tolist()
):
response_token_ids_idxs.append(
assistant_idx + len(response_token_ids))
if len(response_token_ids_idxs) == 0:
warnings.warn(
f"Could not find response key `{response_template}` in the "
f'following instance: @===>{tokenizer.decode(res_input_ids)}<===@ '
f'Raw text is @===>{res_text}<===@'
f'Raw source is @===>{new_source}<===@'
f"This instance will be ignored in loss calculation. "
f"Note, if this happens often, consider increasing the `max_seq_length`."
)
res_labels[:] = ignore_index
human_token_ids = instruction_token_ids
for human_idx in np.where(res_labels == human_token_ids[0])[0]:
# find the indexes of the start of a human answer.
if human_token_ids == res_labels[human_idx: human_idx + len(human_token_ids)].tolist():
human_token_ids_idxs.append(human_idx)
if len(human_token_ids_idxs) == 0:
warnings.warn(
f"Could not find instruction key `{instruction_template}` in the "
f'following instance: @===>{tokenizer.decode(res_input_ids)}<===@ '
f'Raw text is @===>{res_text}<===@'
f'Raw source is @===>{new_source}<===@'
f"This instance will be ignored in loss calculation. "
f"Note, if this happens often, consider increasing the `max_seq_length`."
)
res_labels[:] = ignore_index
for idx, (start, end) in enumerate(zip(human_token_ids_idxs, response_token_ids_idxs)):
# Make pytorch loss function ignore all non response tokens
if idx != 0:
res_labels[start:end] = ignore_index
else:
res_labels[:end] = ignore_index
if len(response_token_ids_idxs) < len(human_token_ids_idxs):
res_labels[human_token_ids_idxs[-1]:] = ignore_index
batch_input_ids.append(res_input_ids)
batch_labels.append(res_labels)
return dict(input_ids=batch_input_ids, labels=batch_labels)
import datetime
import logging
import logging.handlers
import os
import sys
import requests
from omnilmm.constants import LOGDIR
server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN."
handler = None
def build_logger(logger_name, logger_filename):
global handler
formatter = logging.Formatter(
fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
# Set the format of root handlers
if not logging.getLogger().handlers:
logging.basicConfig(level=logging.INFO)
logging.getLogger().handlers[0].setFormatter(formatter)
# Redirect stdout and stderr to loggers
stdout_logger = logging.getLogger("stdout")
stdout_logger.setLevel(logging.INFO)
sl = StreamToLogger(stdout_logger, logging.INFO)
sys.stdout = sl
stderr_logger = logging.getLogger("stderr")
stderr_logger.setLevel(logging.ERROR)
sl = StreamToLogger(stderr_logger, logging.ERROR)
sys.stderr = sl
# Get logger
logger = logging.getLogger(logger_name)
logger.setLevel(logging.INFO)
# Add a file handler for all loggers
if handler is None:
os.makedirs(LOGDIR, exist_ok=True)
filename = os.path.join(LOGDIR, logger_filename)
handler = logging.handlers.TimedRotatingFileHandler(
filename, when='D', utc=True)
handler.setFormatter(formatter)
for name, item in logging.root.manager.loggerDict.items():
if isinstance(item, logging.Logger):
item.addHandler(handler)
return logger
class StreamToLogger(object):
"""
Fake file-like stream object that redirects writes to a logger instance.
"""
def __init__(self, logger, log_level=logging.INFO):
self.terminal = sys.stdout
self.logger = logger
self.log_level = log_level
self.linebuf = ''
def __getattr__(self, attr):
return getattr(self.terminal, attr)
def write(self, buf):
temp_linebuf = self.linebuf + buf
self.linebuf = ''
for line in temp_linebuf.splitlines(True):
# From the io.TextIOWrapper docs:
# On output, if newline is None, any '\n' characters written
# are translated to the system default line separator.
# By default sys.stdout.write() expects '\n' newlines and then
# translates them so this is still cross platform.
if line[-1] == '\n':
self.logger.log(self.log_level, line.rstrip())
else:
self.linebuf += line
def flush(self):
if self.linebuf != '':
self.logger.log(self.log_level, self.linebuf.rstrip())
self.linebuf = ''
def disable_torch_init():
"""
Disable the redundant torch default initialization to accelerate model creation.
"""
import torch
setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
def violates_moderation(text):
"""
Check whether the text violates OpenAI moderation API.
"""
url = "https://api.openai.com/v1/moderations"
headers = {"Content-Type": "application/json",
"Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]}
text = text.replace("\n", "")
data = "{" + '"input": ' + f'"{text}"' + "}"
data = data.encode("utf-8")
try:
ret = requests.post(url, headers=headers, data=data, timeout=5)
flagged = ret.json()["results"][0]["flagged"]
except requests.exceptions.RequestException as e:
flagged = False
except KeyError as e:
flagged = False
return flagged
def pretty_print_semaphore(semaphore):
if semaphore is None:
return "None"
return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"
## OmniLMM-12B
> OmniLMM-12B is released at early time of this project. We recommond you to use our [recently released models](./README_en.md), for better performance and efficiency.
> Archieve at: 2024-05-19
**OmniLMM-12B** is the most capable version. The model is built based on EVA02-5B and Zephyr-7B-β, connected with a perceiver resampler layer, and trained on multimodal data in a curriculum fashion. The model has three notable features:
- 🔥 **Strong Performance.**
OmniLMM-12B achieves **leading performance** among models with comparable sizes, surpassing established LMMs on multiple benchmarks (including MME, MMBench, SEED-Bench, etc). The model also endows rich multi-modal world knowledge.
- 🏆 **Trustworthy Behavior.**
LMMs are known for suffering from hallucination, often generating text that is not factually grounded in images (e.g., faithfully describing non-existing objects in images). OmniLMM-12B is **the first state-of-the-art open-source LMM aligned via multimodal RLHF for trustworthy behavior** (using the recent [RLHF-V](https://rlhf-v.github.io/) technique). It **ranks #1** among open-source models on [MMHal-Bench](https://huggingface.co/datasets/Shengcao1006/MMHal-Bench), and **outperforms GPT-4V** on [Object HalBench](https://arxiv.org/abs/2312.00849).
- 🕹 **Real-time Multimodal Interaction.**
We combine the OmniLMM-12B and GPT-3.5 (text-only) into a **real-time multimodal interactive assistant**. The assistant accepts video streams from the camera and speech streams from the microphone and emits speech output. While still primary, we find the model can **replicate some of the fun cases shown in the Gemini Demo video, without any video edition**.
### Evaluation <!-- omit in toc -->
<div align="center">
<img src=assets/radar_omnilmm12b.png width=66% />
</div>
<details>
<summary>Click to view results on MME, MMBench, MMMU, MMBench, MMHal-Bench, Object HalBench, SeedBench, LLaVA Bench, MathVista. </summary>
<table>
<thead>
<tr>
<th align="left">Model</th>
<th>Size</th>
<th>MME</th>
<th nowrap="nowrap">MMB dev (en)</th>
<th nowrap="nowrap" >MMMU val</th>
<th nowrap="nowrap" >MMHal-Bench</th>
<th nowrap="nowrap" >Object HalBench</th>
<th nowrap="nowrap" >SeedBench-I</th>
<th>MathVista</th>
<th nowrap="nowrap" >LLaVA Bench</th>
</tr>
</thead>
<tbody align="center">
<tr>
<td align="left">GPT-4V†</td>
<td>-</td>
<td>1771.5</td>
<td>75.1 </td>
<td>56.8</td>
<td>3.53 / 70.8</td>
<td>86.4 / 92.7</td>
<td>71.6 </td>
<td>47.8 </td>
<td>93.1 </td>
</tr>
<tr>
<td nowrap="nowrap" align="left">Qwen-VL-Plus†</td>
<td>-</td>
<td>2183.4</td>
<td>66.2 </td>
<td>45.2</td>
<td>- </td>
<td>- </td>
<td>65.7 </td>
<td>36.0 </td>
<td>73.7 </td>
</tr>
<tr>
<td align="left">Yi-VL 6B</td>
<td align="right">6.7B </td>
<td>1915.1 </td>
<td>68.6 </td>
<td>40.3 </td>
<td>- </td>
<td>- </td>
<td>67.5 </td>
<td>28.8 </td>
<td>51.9 </td>
</tr>
<tr>
<td nowrap="nowrap" align="left" >Qwen-VL-Chat</td>
<td align="right">9.6B</td>
<td>1860.0</td>
<td>60.6 </td>
<td>35.9</td>
<td>2.93 / 59.4</td>
<td>56.2 / 80.0</td>
<td>64.8 </td>
<td>33.8 </td>
<td>67.7 </td>
</tr>
<tr>
<td align="left" >CogVLM-Chat</td>
<td align="right">17.4B</td>
<td>1736.6</td>
<td>63.7 </td>
<td>32.1 </td>
<td>2.68 / 52.1 </td>
<td>73.6 / 87.4 </td>
<td>68.8 </td>
<td>34.7 </td>
<td>73.9 </td>
</tr>
<tr>
<td align="left" >LLaVA 1.5</td>
<td align="right">13.6B </td>
<td>1808.4 </td>
<td>68.2 </td>
<td>36.4 </td>
<td>2.71 / 51.0 </td>
<td>53.7 / 77.4 </td>
<td>68.1 </td>
<td>26.4 </td>
<td>64.6 </td>
</tr>
<tr>
<td nowrap="nowrap" align="left" ><b>OmniLMM-12B</b></td>
<td align="right">11.6B </td>
<td>1935.8 </td>
<td>71.6 </td>
<td>40.7 </td>
<td>3.45 / 68.8 </td>
<td>90.3 / 95.5 </td>
<td>71.1 </td>
<td>34.9 </td>
<td>72.0 </td>
</tr>
</tbody>
</table>
<small>†: Proprietary models</small>
<br>
</details>
### Examples <!-- omit in toc -->
<table align="center" >
<p align="center" >
<img src="assets/omnilmm-12b-examples_2.png" />
</p>
</table>
We combine the OmniLMM-12B and GPT-3.5 (text-only) into a **real-time multimodal interactive assistant**. Video frames are described in text using OmniLMM-12B, and ChatGPT 3.5 (text-only) is employed to generate response according to the descriptions and user prompts. The demo video is a raw recording without edition.
<div align="center" >
<video controls src="https://github.com/OpenBMB/OmniLMM/assets/157115220/485a8f52-fb4d-4eca-8fee-506347efcfc6" type="video/mp4" width=80%/>
</div>
### Model Zoo
| Model | Description | Download Link |
|:----------------------|:-------------------|:---------------:|
| OmniLMM-12B | The most capable version with leading performance. | [🤗](https://huggingface.co/openbmb/OmniLMM-12B) &nbsp;&nbsp; [<img src="./assets/modelscope_logo.png" width="20px"></img>](https://modelscope.cn/models/OpenBMB/OmniLMM-12B/files) |
#!/bin/bash
HIP_VISIBLE_DEVICES=4,5,6,7
GPUS_PER_NODE=4
torchrun $DISTRIBUTED_ARGS finetune.py \
--model_name_or_path $MODEL \
--llm_type $LLM_TYPE \
--data_path $DATA \
--eval_data_path $EVAL_DATA \
--remove_unused_columns false \
--label_names "labels" \
--prediction_loss_only false \
--bf16 false \
--bf16_full_eval false \
--fp16 true \
--fp16_full_eval true \
--do_train \
--do_eval \
--tune_vision true \
--tune_llm false \
--use_lora true \
--lora_target_modules "llm\..*layers\.\d+\.self_attn\.(q_proj|k_proj)" \
--model_max_length 2048 \
--max_slice_nums 9 \
--max_steps 100 \
--eval_steps 10 \
--output_dir /home/wanglch/projects/saves/MiniCPM-Llama3-V-2_5/lora_train_dtk \
--logging_dir /home/wanglch/projects/saves/MiniCPM-Llama3-V-2_5/lora_train_dtk \
--logging_strategy "steps" \
--per_device_train_batch_size 2 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 1 \
--evaluation_strategy "steps" \
--save_strategy "steps" \
--save_steps 100 \
--save_total_limit 10 \
--learning_rate 1e-6 \
--weight_decay 0.1 \
--adam_beta2 0.95 \
--warmup_ratio 0.01 \
--lr_scheduler_type "cosine" \
--logging_steps 1 \
--gradient_checkpointing true \
--deepspeed ds_config_zero2.json \
--report_to "tensorboard" # wandb
{'train_runtime': 305.4486, 'train_samples_per_second': 2.619, 'train_steps_per_second': 0.327, 'train_loss': 0.06511555195844267, 'epoch': 100.0}
{'eval_loss': 0.389315664768219, 'eval_runtime': 0.2161, 'eval_samples_per_second': 4.628, 'eval_steps_per_second': 4.628, 'epoch': 100.0}
\ No newline at end of file
#!/bin/bash
HIP_VISIBLE_DEVICES=0,2,3,4
GPUS_PER_NODE=4
NNODES=1
NODE_RANK=0
MASTER_ADDR=localhost
MASTER_PORT=29500
MODEL="/home/wanglch/projects/MiniCPM-V/MiniCPM-Llama3-V-2_5-base" # or openbmb/MiniCPM-V-2
# ATTENTION: specify the path to your training data, which should be a json file consisting of a list of conversations.
# See the section for finetuning in README for more information.
DATA="/home/wanglch/projects/MiniCPM-V/data/self_build/train_data/train_data.json"
EVAL_DATA="/home/wanglch/projects/MiniCPM-V/data/self_build/eval_data/eval_data.json"
LLM_TYPE="llama3" # if use openbmb/MiniCPM-V-2, please set LLM_TYPE=minicpm
DISTRIBUTED_ARGS="
--nproc_per_node $GPUS_PER_NODE \
--nnodes $NNODES \
--node_rank $NODE_RANK \
--master_addr $MASTER_ADDR \
--master_port $MASTER_PORT
"
torchrun $DISTRIBUTED_ARGS finetune.py \
--model_name_or_path $MODEL \
--llm_type $LLM_TYPE \
--data_path $DATA \
--eval_data_path $EVAL_DATA \
--remove_unused_columns false \
--label_names "labels" \
--prediction_loss_only false \
--bf16 false \
--bf16_full_eval false \
--fp16 true \
--fp16_full_eval true \
--do_train \
--do_eval \
--tune_vision true \
--tune_llm false \
--use_lora true \
--lora_target_modules "llm\..*layers\.\d+\.self_attn\.(q_proj|k_proj)" \
--model_max_length 2048 \
--max_slice_nums 9 \
--max_steps 100 \
--eval_steps 10 \
--output_dir /home/wanglch/projects/saves/MiniCPM-Llama3-V-2_5/lora_train_dtk \
--logging_dir /home/wanglch/projects/saves/MiniCPM-Llama3-V-2_5/lora_train_dtk \
--logging_strategy "steps" \
--per_device_train_batch_size 2 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 1 \
--evaluation_strategy "steps" \
--save_strategy "steps" \
--save_steps 100 \
--save_total_limit 10 \
--learning_rate 1e-6 \
--weight_decay 0.1 \
--adam_beta2 0.95 \
--warmup_ratio 0.01 \
--lr_scheduler_type "cosine" \
--logging_steps 1 \
--gradient_checkpointing true \
--deepspeed ds_config_zero2.json \
--report_to "tensorboard" # wandb
{'train_runtime': 961.2071, 'train_samples_per_second': 0.832, 'train_steps_per_second': 0.104, 'train_loss': 0.06350907939166063, 'epoch': 100.0}
{'eval_loss': 0.3914967179298401, 'eval_runtime': 0.8044, 'eval_samples_per_second': 1.243, 'eval_steps_per_second': 1.243, 'epoch': 100.0}
\ No newline at end of file
packaging==23.2
deepspeed
peft
addict==2.4.0
editdistance==0.6.2
einops==0.7.0
fairscale==0.4.0
jsonlines==4.0.0
markdown2==2.4.10
matplotlib==3.7.4
more_itertools==10.1.0
nltk==3.8.1
numpy==1.24.4
opencv_python_headless==4.5.5.64
openpyxl==3.1.2
Pillow
sacrebleu==2.3.2
seaborn==0.13.0
shortuuid==1.0.11
spacy==3.7.2
timm==0.9.10
torch
torchvision
tqdm==4.66.1
protobuf==4.25.0
transformers==4.40.0
typing_extensions==4.8.0
uvicorn==0.24.0.post1
#xformers==0.0.22.post7
#flash_attn==2.3.4
sentencepiece==0.1.99
accelerate==0.30.1
socksio==1.0.0
gradio
gradio_client
streamlit
tensorboard
\ No newline at end of file
#!/usr/bin/env python
# encoding: utf-8
import gradio as gr
from PIL import Image
import traceback
import re
import torch
import argparse
from transformers import AutoModel, AutoTokenizer
# README, How to run demo on different devices
# For Nvidia GPUs support BF16 (like A100, H100, RTX3090)
# python web_demo.py --device cuda --dtype bf16
# For Nvidia GPUs do NOT support BF16 (like V100, T4, RTX2080)
# python web_demo.py --device cuda --dtype fp16
# For Mac with MPS (Apple silicon or AMD GPUs).
# PYTORCH_ENABLE_MPS_FALLBACK=1 python web_demo.py --device mps --dtype fp16
# Argparser
parser = argparse.ArgumentParser(description='demo')
parser.add_argument('--device', type=str, default='cuda', help='cuda or mps')
parser.add_argument('--dtype', type=str, default='bf16', help='bf16 or fp16')
args = parser.parse_args()
device = args.device
assert device in ['cuda', 'mps']
if args.dtype == 'bf16':
if device == 'mps':
print('Warning: MPS does not support bf16, will use fp16 instead')
dtype = torch.float16
else:
dtype = torch.bfloat16
else:
dtype = torch.float16
# Load model
model_path = 'openbmb/MiniCPM-V-2'
model = AutoModel.from_pretrained(model_path, trust_remote_code=True).to(dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = model.to(device=device, dtype=dtype)
model.eval()
ERROR_MSG = "Error, please retry"
model_name = 'MiniCPM-V 2.0'
form_radio = {
'choices': ['Beam Search', 'Sampling'],
#'value': 'Beam Search',
'value': 'Sampling',
'interactive': True,
'label': 'Decode Type'
}
# Beam Form
num_beams_slider = {
'minimum': 0,
'maximum': 5,
'value': 3,
'step': 1,
'interactive': True,
'label': 'Num Beams'
}
repetition_penalty_slider = {
'minimum': 0,
'maximum': 3,
'value': 1.2,
'step': 0.01,
'interactive': True,
'label': 'Repetition Penalty'
}
repetition_penalty_slider2 = {
'minimum': 0,
'maximum': 3,
'value': 1.05,
'step': 0.01,
'interactive': True,
'label': 'Repetition Penalty'
}
max_new_tokens_slider = {
'minimum': 1,
'maximum': 4096,
'value': 1024,
'step': 1,
'interactive': True,
'label': 'Max New Tokens'
}
top_p_slider = {
'minimum': 0,
'maximum': 1,
'value': 0.8,
'step': 0.05,
'interactive': True,
'label': 'Top P'
}
top_k_slider = {
'minimum': 0,
'maximum': 200,
'value': 100,
'step': 1,
'interactive': True,
'label': 'Top K'
}
temperature_slider = {
'minimum': 0,
'maximum': 2,
'value': 0.7,
'step': 0.05,
'interactive': True,
'label': 'Temperature'
}
def create_component(params, comp='Slider'):
if comp == 'Slider':
return gr.Slider(
minimum=params['minimum'],
maximum=params['maximum'],
value=params['value'],
step=params['step'],
interactive=params['interactive'],
label=params['label']
)
elif comp == 'Radio':
return gr.Radio(
choices=params['choices'],
value=params['value'],
interactive=params['interactive'],
label=params['label']
)
elif comp == 'Button':
return gr.Button(
value=params['value'],
interactive=True
)
def chat(img, msgs, ctx, params=None, vision_hidden_states=None):
default_params = {"num_beams":3, "repetition_penalty": 1.2, "max_new_tokens": 1024}
if params is None:
params = default_params
if img is None:
return -1, "Error, invalid image, please upload a new image", None, None
try:
image = img.convert('RGB')
answer, context, _ = model.chat(
image=image,
msgs=msgs,
context=None,
tokenizer=tokenizer,
**params
)
res = re.sub(r'(<box>.*</box>)', '', answer)
res = res.replace('<ref>', '')
res = res.replace('</ref>', '')
res = res.replace('<box>', '')
answer = res.replace('</box>', '')
return 0, answer, None, None
except Exception as err:
print(err)
traceback.print_exc()
return -1, ERROR_MSG, None, None
def upload_img(image, _chatbot, _app_session):
image = Image.fromarray(image)
_app_session['sts']=None
_app_session['ctx']=[]
_app_session['img']=image
_chatbot.append(('', 'Image uploaded successfully, you can talk to me now'))
return _chatbot, _app_session
def respond(_question, _chat_bot, _app_cfg, params_form, num_beams, repetition_penalty, repetition_penalty_2, top_p, top_k, temperature):
if _app_cfg.get('ctx', None) is None:
_chat_bot.append((_question, 'Please upload an image to start'))
return '', _chat_bot, _app_cfg
_context = _app_cfg['ctx'].copy()
if _context:
_context.append({"role": "user", "content": _question})
else:
_context = [{"role": "user", "content": _question}]
print('<User>:', _question)
if params_form == 'Beam Search':
params = {
'sampling': False,
'num_beams': num_beams,
'repetition_penalty': repetition_penalty,
"max_new_tokens": 896
}
else:
params = {
'sampling': True,
'top_p': top_p,
'top_k': top_k,
'temperature': temperature,
'repetition_penalty': repetition_penalty_2,
"max_new_tokens": 896
}
code, _answer, _, sts = chat(_app_cfg['img'], _context, None, params)
print('<Assistant>:', _answer)
_context.append({"role": "assistant", "content": _answer})
_chat_bot.append((_question, _answer))
if code == 0:
_app_cfg['ctx']=_context
_app_cfg['sts']=sts
return '', _chat_bot, _app_cfg
def regenerate_button_clicked(_question, _chat_bot, _app_cfg, params_form, num_beams, repetition_penalty, repetition_penalty_2, top_p, top_k, temperature):
if len(_chat_bot) <= 1:
_chat_bot.append(('Regenerate', 'No question for regeneration.'))
return '', _chat_bot, _app_cfg
elif _chat_bot[-1][0] == 'Regenerate':
return '', _chat_bot, _app_cfg
else:
_question = _chat_bot[-1][0]
_chat_bot = _chat_bot[:-1]
_app_cfg['ctx'] = _app_cfg['ctx'][:-2]
return respond(_question, _chat_bot, _app_cfg, params_form, num_beams, repetition_penalty, repetition_penalty_2, top_p, top_k, temperature)
with gr.Blocks() as demo:
with gr.Row():
with gr.Column(scale=1, min_width=300):
params_form = create_component(form_radio, comp='Radio')
with gr.Accordion("Beam Search") as beams_according:
num_beams = create_component(num_beams_slider)
repetition_penalty = create_component(repetition_penalty_slider)
with gr.Accordion("Sampling") as sampling_according:
top_p = create_component(top_p_slider)
top_k = create_component(top_k_slider)
temperature = create_component(temperature_slider)
repetition_penalty_2 = create_component(repetition_penalty_slider2)
regenerate = create_component({'value': 'Regenerate'}, comp='Button')
with gr.Column(scale=3, min_width=500):
app_session = gr.State({'sts':None,'ctx':None,'img':None})
bt_pic = gr.Image(label="Upload an image to start")
chat_bot = gr.Chatbot(label=f"Chat with {model_name}")
txt_message = gr.Textbox(label="Input text")
regenerate.click(
regenerate_button_clicked,
[txt_message, chat_bot, app_session, params_form, num_beams, repetition_penalty, repetition_penalty_2, top_p, top_k, temperature],
[txt_message, chat_bot, app_session]
)
txt_message.submit(
respond,
[txt_message, chat_bot, app_session, params_form, num_beams, repetition_penalty, repetition_penalty_2, top_p, top_k, temperature],
[txt_message, chat_bot, app_session]
)
bt_pic.upload(lambda: None, None, chat_bot, queue=False).then(upload_img, inputs=[bt_pic,chat_bot,app_session], outputs=[chat_bot,app_session])
# launch
demo.launch(share=False, debug=True, show_api=False, server_port=8080, server_name="0.0.0.0")
#!/usr/bin/env python
# encoding: utf-8
import gradio as gr
from PIL import Image
import traceback
import re
import torch
import argparse
from transformers import AutoModel, AutoTokenizer
# README, How to run demo on different devices
# For Nvidia GPUs.
# python web_demo_2.5.py --device cuda
# For Mac with MPS (Apple silicon or AMD GPUs).
# PYTORCH_ENABLE_MPS_FALLBACK=1 python web_demo_2.5.py --device mps
# Argparser
parser = argparse.ArgumentParser(description='demo')
parser.add_argument('--device', type=str, default='cuda', help='cuda or mps')
args = parser.parse_args()
device = args.device
assert device in ['cuda', 'mps']
# Load model
model_path = 'openbmb/MiniCPM-Llama3-V-2_5'
if 'int4' in model_path:
if device == 'mps':
print('Error: running int4 model with bitsandbytes on Mac is not supported right now.')
exit()
model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
else:
model = AutoModel.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.float16, device_map=device)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model.eval()
ERROR_MSG = "Error, please retry"
model_name = 'MiniCPM-V 2.5'
form_radio = {
'choices': ['Beam Search', 'Sampling'],
#'value': 'Beam Search',
'value': 'Sampling',
'interactive': True,
'label': 'Decode Type'
}
# Beam Form
num_beams_slider = {
'minimum': 0,
'maximum': 5,
'value': 3,
'step': 1,
'interactive': True,
'label': 'Num Beams'
}
repetition_penalty_slider = {
'minimum': 0,
'maximum': 3,
'value': 1.2,
'step': 0.01,
'interactive': True,
'label': 'Repetition Penalty'
}
repetition_penalty_slider2 = {
'minimum': 0,
'maximum': 3,
'value': 1.05,
'step': 0.01,
'interactive': True,
'label': 'Repetition Penalty'
}
max_new_tokens_slider = {
'minimum': 1,
'maximum': 4096,
'value': 1024,
'step': 1,
'interactive': True,
'label': 'Max New Tokens'
}
top_p_slider = {
'minimum': 0,
'maximum': 1,
'value': 0.8,
'step': 0.05,
'interactive': True,
'label': 'Top P'
}
top_k_slider = {
'minimum': 0,
'maximum': 200,
'value': 100,
'step': 1,
'interactive': True,
'label': 'Top K'
}
temperature_slider = {
'minimum': 0,
'maximum': 2,
'value': 0.7,
'step': 0.05,
'interactive': True,
'label': 'Temperature'
}
def create_component(params, comp='Slider'):
if comp == 'Slider':
return gr.Slider(
minimum=params['minimum'],
maximum=params['maximum'],
value=params['value'],
step=params['step'],
interactive=params['interactive'],
label=params['label']
)
elif comp == 'Radio':
return gr.Radio(
choices=params['choices'],
value=params['value'],
interactive=params['interactive'],
label=params['label']
)
elif comp == 'Button':
return gr.Button(
value=params['value'],
interactive=True
)
def chat(img, msgs, ctx, params=None, vision_hidden_states=None):
default_params = {"num_beams":3, "repetition_penalty": 1.2, "max_new_tokens": 1024}
if params is None:
params = default_params
if img is None:
return -1, "Error, invalid image, please upload a new image", None, None
try:
image = img.convert('RGB')
answer = model.chat(
image=image,
msgs=msgs,
tokenizer=tokenizer,
**params
)
res = re.sub(r'(<box>.*</box>)', '', answer)
res = res.replace('<ref>', '')
res = res.replace('</ref>', '')
res = res.replace('<box>', '')
answer = res.replace('</box>', '')
return 0, answer, None, None
except Exception as err:
print(err)
traceback.print_exc()
return -1, ERROR_MSG, None, None
def upload_img(image, _chatbot, _app_session):
image = Image.fromarray(image)
_app_session['sts']=None
_app_session['ctx']=[]
_app_session['img']=image
_chatbot.append(('', 'Image uploaded successfully, you can talk to me now'))
return _chatbot, _app_session
def respond(_question, _chat_bot, _app_cfg, params_form, num_beams, repetition_penalty, repetition_penalty_2, top_p, top_k, temperature):
if _app_cfg.get('ctx', None) is None:
_chat_bot.append((_question, 'Please upload an image to start'))
return '', _chat_bot, _app_cfg
_context = _app_cfg['ctx'].copy()
if _context:
_context.append({"role": "user", "content": _question})
else:
_context = [{"role": "user", "content": _question}]
print('<User>:', _question)
if params_form == 'Beam Search':
params = {
'sampling': False,
'num_beams': num_beams,
'repetition_penalty': repetition_penalty,
"max_new_tokens": 896
}
else:
params = {
'sampling': True,
'top_p': top_p,
'top_k': top_k,
'temperature': temperature,
'repetition_penalty': repetition_penalty_2,
"max_new_tokens": 896
}
code, _answer, _, sts = chat(_app_cfg['img'], _context, None, params)
print('<Assistant>:', _answer)
_context.append({"role": "assistant", "content": _answer})
_chat_bot.append((_question, _answer))
if code == 0:
_app_cfg['ctx']=_context
_app_cfg['sts']=sts
return '', _chat_bot, _app_cfg
def regenerate_button_clicked(_question, _chat_bot, _app_cfg, params_form, num_beams, repetition_penalty, repetition_penalty_2, top_p, top_k, temperature):
if len(_chat_bot) <= 1:
_chat_bot.append(('Regenerate', 'No question for regeneration.'))
return '', _chat_bot, _app_cfg
elif _chat_bot[-1][0] == 'Regenerate':
return '', _chat_bot, _app_cfg
else:
_question = _chat_bot[-1][0]
_chat_bot = _chat_bot[:-1]
_app_cfg['ctx'] = _app_cfg['ctx'][:-2]
return respond(_question, _chat_bot, _app_cfg, params_form, num_beams, repetition_penalty, repetition_penalty_2, top_p, top_k, temperature)
with gr.Blocks() as demo:
with gr.Row():
with gr.Column(scale=1, min_width=300):
params_form = create_component(form_radio, comp='Radio')
with gr.Accordion("Beam Search") as beams_according:
num_beams = create_component(num_beams_slider)
repetition_penalty = create_component(repetition_penalty_slider)
with gr.Accordion("Sampling") as sampling_according:
top_p = create_component(top_p_slider)
top_k = create_component(top_k_slider)
temperature = create_component(temperature_slider)
repetition_penalty_2 = create_component(repetition_penalty_slider2)
regenerate = create_component({'value': 'Regenerate'}, comp='Button')
with gr.Column(scale=3, min_width=500):
app_session = gr.State({'sts':None,'ctx':None,'img':None})
bt_pic = gr.Image(label="Upload an image to start")
chat_bot = gr.Chatbot(label=f"Chat with {model_name}")
txt_message = gr.Textbox(label="Input text")
regenerate.click(
regenerate_button_clicked,
[txt_message, chat_bot, app_session, params_form, num_beams, repetition_penalty, repetition_penalty_2, top_p, top_k, temperature],
[txt_message, chat_bot, app_session]
)
txt_message.submit(
respond,
[txt_message, chat_bot, app_session, params_form, num_beams, repetition_penalty, repetition_penalty_2, top_p, top_k, temperature],
[txt_message, chat_bot, app_session]
)
bt_pic.upload(lambda: None, None, chat_bot, queue=False).then(upload_img, inputs=[bt_pic,chat_bot,app_session], outputs=[chat_bot,app_session])
# launch
demo.launch(share=False, debug=True, show_api=False, server_port=8080, server_name="0.0.0.0")
CUDA_VISIBLE_DEVICES=0 python web_demo_2.5.py --device cuda
\ No newline at end of file
CUDA_VISIBLE_DEVICES=0,1,2,3 python web_demo_2.5.py --device cuda
\ No newline at end of file
import streamlit as st
from PIL import Image
import torch
from transformers import AutoModel, AutoTokenizer
# Model path
model_path = "openbmb/MiniCPM-Llama3-V-2_5"
# User and assistant names
U_NAME = "User"
A_NAME = "Assistant"
# Set page configuration
st.set_page_config(
page_title="MiniCPM-Llama3-V-2_5 Streamlit",
page_icon=":robot:",
layout="wide"
)
# Load model and tokenizer
@st.cache_resource
def load_model_and_tokenizer():
print(f"load_model_and_tokenizer from {model_path}")
model = AutoModel.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.float16).to(device="cuda")
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
return model, tokenizer
# Initialize session state
if 'model' not in st.session_state:
st.session_state.model, st.session_state.tokenizer = load_model_and_tokenizer()
st.session_state.model.eval()
print("model and tokenizer had loaded completed!")
# Initialize session state
if 'chat_history' not in st.session_state:
st.session_state.chat_history = []
# Sidebar settings
sidebar_name = st.sidebar.title("MiniCPM-Llama3-V-2_5 Streamlit")
max_length = st.sidebar.slider("max_length", 0, 4096, 2048, step=2)
repetition_penalty = st.sidebar.slider("repetition_penalty", 0.0, 2.0, 1.05, step=0.01)
top_p = st.sidebar.slider("top_p", 0.0, 1.0, 0.8, step=0.01)
top_k = st.sidebar.slider("top_k", 0, 100, 100, step=1)
temperature = st.sidebar.slider("temperature", 0.0, 1.0, 0.7, step=0.01)
# Clear chat history button
buttonClean = st.sidebar.button("Clear chat history", key="clean")
if buttonClean:
st.session_state.chat_history = []
st.session_state.response = ""
if torch.cuda.is_available():
torch.cuda.empty_cache()
st.rerun()
# Display chat history
for i, message in enumerate(st.session_state.chat_history):
if message["role"] == "user":
with st.chat_message(name="user", avatar="user"):
if message["image"] is not None:
st.image(message["image"], caption='User uploaded image', width=448, use_column_width=False)
continue
elif message["content"] is not None:
st.markdown(message["content"])
else:
with st.chat_message(name="model", avatar="assistant"):
st.markdown(message["content"])
# Select mode
selected_mode = st.sidebar.selectbox("Select mode", ["Text", "Image"])
if selected_mode == "Image":
# Image mode
uploaded_image = st.sidebar.file_uploader("Upload image", key=1, type=["jpg", "jpeg", "png"],
accept_multiple_files=False)
if uploaded_image is not None:
st.image(uploaded_image, caption='User uploaded image', width=468, use_column_width=False)
# Add uploaded image to chat history
st.session_state.chat_history.append({"role": "user", "content": None, "image": uploaded_image})
# User input box
user_text = st.chat_input("Enter your question")
if user_text:
with st.chat_message(U_NAME, avatar="user"):
st.session_state.chat_history.append({"role": "user", "content": user_text, "image": None})
st.markdown(f"{U_NAME}: {user_text}")
# Generate reply using the model
model = st.session_state.model
tokenizer = st.session_state.tokenizer
with st.chat_message(A_NAME, avatar="assistant"):
# If the previous message contains an image, pass the image to the model
if len(st.session_state.chat_history) > 1 and st.session_state.chat_history[-2]["image"] is not None:
uploaded_image = st.session_state.chat_history[-2]["image"]
imagefile = Image.open(uploaded_image).convert('RGB')
msgs = [{"role": "user", "content": user_text}]
res = model.chat(image=imagefile, msgs=msgs, context=None, tokenizer=tokenizer,
sampling=True, top_p=top_p, top_k=top_k, repetition_penalty=repetition_penalty,
temperature=temperature, stream=True)
# Collect the generated_text str
generated_text = st.write_stream(res)
st.session_state.chat_history.append({"role": "model", "content": generated_text, "image": None})
st.divider()
import streamlit as st
from PIL import Image
import torch
from transformers import AutoModel, AutoTokenizer
# Model path
model_path = "openbmb/MiniCPM-V-2"
# User and assistant names
U_NAME = "User"
A_NAME = "Assistant"
# Set page configuration
st.set_page_config(
page_title="Minicpm-V-2 Streamlit",
page_icon=":robot:",
layout="wide"
)
# Load model and tokenizer
@st.cache_resource
def load_model_and_tokenizer():
print(f"load_model_and_tokenizer from {model_path}")
model = AutoModel.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.bfloat16).to(
device="cuda:0", dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
return model, tokenizer
# Initialize session state
if 'model' not in st.session_state:
st.session_state.model, st.session_state.tokenizer = load_model_and_tokenizer()
print("model and tokenizer had loaded completed!")
# Initialize session state
if 'chat_history' not in st.session_state:
st.session_state.chat_history = []
# Sidebar settings
sidebar_name = st.sidebar.title("Minicpm-V-2 Streamlit")
max_length = st.sidebar.slider("max_length", 0, 4096, 2048, step=2)
top_p = st.sidebar.slider("top_p", 0.0, 1.0, 0.8, step=0.01)
temperature = st.sidebar.slider("temperature", 0.0, 1.0, 0.7, step=0.01)
# Clear chat history button
buttonClean = st.sidebar.button("Clear chat history", key="clean")
if buttonClean:
st.session_state.chat_history = []
st.session_state.response = ""
if torch.cuda.is_available():
torch.cuda.empty_cache()
st.rerun()
# Display chat history
for i, message in enumerate(st.session_state.chat_history):
if message["role"] == "user":
with st.chat_message(name="user", avatar="user"):
if message["image"] is not None:
st.image(message["image"], caption='User uploaded image', width=468, use_column_width=False)
continue
elif message["content"] is not None:
st.markdown(message["content"])
else:
with st.chat_message(name="model", avatar="assistant"):
st.markdown(message["content"])
# Select mode
selected_mode = st.sidebar.selectbox("Select mode", ["Text", "Image"])
if selected_mode == "Image":
# Image mode
uploaded_image = st.sidebar.file_uploader("Upload image", key=1, type=["jpg", "jpeg", "png"], accept_multiple_files=False)
if uploaded_image is not None:
st.image(uploaded_image, caption='User uploaded image', width=468, use_column_width=False)
# Add uploaded image to chat history
st.session_state.chat_history.append({"role": "user", "content": None, "image": uploaded_image})
# User input box
user_text = st.chat_input("Enter your question")
if user_text:
with st.chat_message(U_NAME, avatar="user"):
st.session_state.chat_history.append({"role": "user", "content": user_text, "image": None})
st.markdown(f"{U_NAME}: {user_text}")
# Generate reply using the model
model = st.session_state.model
tokenizer = st.session_state.tokenizer
with st.chat_message(A_NAME, avatar="assistant"):
# If the previous message contains an image, pass the image to the model
if len(st.session_state.chat_history) > 1 and st.session_state.chat_history[-2]["image"] is not None:
uploaded_image = st.session_state.chat_history[-2]["image"]
imagefile = Image.open(uploaded_image).convert('RGB')
msgs = [{"role": "user", "content": user_text}]
res, context, _ = model.chat(image=imagefile, msgs=msgs, context=None, tokenizer=tokenizer,
sampling=True,top_p=top_p,temperature=temperature)
st.markdown(f"{A_NAME}: {res}")
st.session_state.chat_history.append({"role": "model", "content": res, "image": None})
st.divider()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment