Commit 1d9ad5d4 authored by chenzk's avatar chenzk
Browse files

v1.0

parents
Pipeline #2728 failed with stages
in 0 seconds
# --- Controller / worker heartbeat settings (serving infrastructure) ---
CONTROLLER_HEART_BEAT_EXPIRATION = 30  # seconds without a heartbeat before a worker is considered dead
WORKER_HEART_BEAT_INTERVAL = 15  # seconds between worker heartbeat pings
LOGDIR = "."  # directory where log files are written

# Model Constants
IGNORE_INDEX = -100  # label id skipped by the loss (matches PyTorch CrossEntropyLoss default ignore_index)
# IMAGE_TOKEN_INDEX = -200
DEFAULT_IMAGE_TOKEN = "<image>"  # textual placeholder for an image in a prompt
DEFAULT_IM_START_TOKEN = "[IMG]"  # marks the start of an image span
DEFAULT_IM_END_TOKEN = "[/IMG]"  # marks the end of an image span
# Superseded token ids (earlier tokenizer layout):
# IMAGE_TOKEN_IDX = 32002
# DEFAULT_IM_START_TOKEN_IDX = 32000
# DEFAULT_IM_END_TOKEN_IDX = 32001
# Current token ids — NOTE(review): these look tokenizer-specific (Qwen/Llama
# vocab ranges); verify against the deployed tokenizer before changing.
IMAGE_TOKEN_IDX = 151667
DEFAULT_IM_START_TOKEN_IDX = 128257
DEFAULT_IM_END_TOKEN_IDX = 128258
UND_IMAGE_TOKEN_IDX = 151655  # token id for images used on the understanding path
# N_QUERY = 729
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"  # per-patch placeholder token
IMAGE_PLACEHOLDER = "<image-placeholder>"  # alternate placeholder spelling used by some callers
import dataclasses
from enum import auto, Enum
from typing import List, Tuple
import base64
from io import BytesIO
from PIL import Image
class SeparatorStyle(Enum):
    """Different separator styles used to render a conversation into a prompt.

    Fix: ``Conversation.get_prompt`` compares against ``SeparatorStyle.LLAMA_3``
    and ``SeparatorStyle.GEMMA``, which were missing from this enum — any style
    evaluated after the CHATML branch (MPT, GEMMA, LLAMA_2, PLAIN) raised
    AttributeError. Both members are added here.
    """
    SINGLE = auto()
    TWO = auto()
    MPT = auto()
    PLAIN = auto()
    LLAMA_2 = auto()
    LLAMA_3 = auto()   # was referenced by get_prompt() but never defined
    GEMMA = auto()     # was referenced by get_prompt() but never defined
    CHATML = auto()
    QWEN = auto()
@dataclasses.dataclass
class Conversation:
    """A class that keeps all conversation history.

    Holds the system prompt, role names, the message list, and the separator
    configuration used to render the history into a single prompt string for
    a given model family (selected via ``sep_style``).
    """
    system: str                       # system prompt prepended to the dialogue
    roles: List[str]                  # (user role name, assistant role name)
    messages: List[List[str]]         # [role, message] pairs; message may be a tuple carrying an image
    offset: int                       # messages before this index are few-shot examples
    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
    sep: str = "###"
    sep2: str = None
    version: str = "Unknown"
    skip_next: bool = False
    # Fix: the LLAMA_3 branch of get_prompt() reads ``self.tokenizer``, but no
    # such field existed, so reaching that branch raised AttributeError instead
    # of the intended ValueError. Declared here with a None default (backward
    # compatible with all existing constructor calls).
    tokenizer: object = None

    def get_prompt(self):
        """Render the conversation into a single prompt string.

        Raises:
            ValueError: for an unknown ``sep_style``, or for LLAMA_3 when no
                tokenizer has been attached.
        """
        messages = self.messages
        # If the first user message carries an image tuple, normalize it so the
        # "<image>" placeholder sits where the template expects it. Work on a
        # shallow copy so self.messages is not mutated.
        if len(messages) > 0 and type(messages[0][1]) is tuple:
            messages = self.messages.copy()
            init_role, init_msg = messages[0].copy()
            init_msg = init_msg[0]
            if "mmtag" in self.version:
                # mmtag versions wrap the image in explicit <Image> markers as
                # an extra exchange at the front.
                init_msg = init_msg.replace("<image>", "").strip()
                messages[0] = (init_role, init_msg)
                messages.insert(0, (self.roles[0], "<Image><image></Image>"))
                messages.insert(1, (self.roles[1], "Received."))
            elif not init_msg.startswith("<image>"):
                # Force the placeholder to the front of the first message.
                init_msg = init_msg.replace("<image>", "").strip()
                messages[0] = (init_role, "<image>\n" + init_msg)
            else:
                messages[0] = (init_role, init_msg)

        if self.sep_style == SeparatorStyle.SINGLE:
            ret = self.system + self.sep
            for role, message in messages:
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += role + ": " + message + self.sep
                else:
                    # Empty message: emit the role header as a generation cue.
                    ret += role + ":"
        elif self.sep_style == SeparatorStyle.TWO:
            # Alternate two separators (typically " " after user, EOS after assistant).
            seps = [self.sep, self.sep2]
            ret = self.system + seps[0]
            for i, (role, message) in enumerate(messages):
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += role + ": " + message + seps[i % 2]
                else:
                    ret += role + ":"
        elif self.sep_style == SeparatorStyle.CHATML:
            ret = "" if self.system == "" else self.system + self.sep + "\n"
            for role, message in messages:
                if message:
                    if type(message) is tuple:
                        # CHATML repeats the placeholder once per image.
                        message, images, _ = message
                        message = "<image>" * len(images) + message
                    ret += role + "\n" + message + self.sep + "\n"
                else:
                    ret += role + "\n"
            return ret
        elif self.sep_style == SeparatorStyle.LLAMA_3:
            # Delegates formatting to the tokenizer's chat template.
            if self.tokenizer is None:
                raise ValueError("Llama 3 tokenizer is not available. Make sure you have the necessary permissions.")
            chat_template_messages = [{"role": "system", "content": self.system}]
            for role, message in messages:
                if message:
                    if type(message) is tuple:
                        # NOTE: LLAMA_3 uses 2-tuples (message, images), unlike
                        # the 3-tuples used by the other styles.
                        message, images = message
                        message = "<image>" * len(images) + message
                    chat_template_messages.append({"role": role, "content": message})
            return self.tokenizer.apply_chat_template(chat_template_messages, tokenize=False, add_generation_prompt=True)
        elif self.sep_style == SeparatorStyle.MPT:
            ret = self.system + self.sep
            for role, message in messages:
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += role + message + self.sep
                else:
                    ret += role
        elif self.sep_style == SeparatorStyle.GEMMA:
            ret = ""
            for i, (role, message) in enumerate(messages):
                assert role == self.roles[i % 2], "Conversation should alternate user/assistant/user/assistant/..."
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += role + message + self.sep
                else:
                    ret += role
        elif self.sep_style == SeparatorStyle.LLAMA_2:
            wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n" if len(msg) > 0 else msg
            wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
            ret = ""
            for i, (role, message) in enumerate(messages):
                if i == 0:
                    assert message, "first message should not be none"
                    assert role == self.roles[0], "first message should come from user"
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    if i == 0:
                        # System prompt is folded into the first user turn.
                        message = wrap_sys(self.system) + message
                    if i % 2 == 0:
                        message = wrap_inst(message)
                        ret += self.sep + message
                    else:
                        ret += " " + message + " " + self.sep2
                else:
                    ret += ""
            ret = ret.lstrip(self.sep)
        elif self.sep_style == SeparatorStyle.PLAIN:
            # Roles are omitted entirely; only message bodies and separators.
            seps = [self.sep, self.sep2]
            ret = self.system
            for i, (role, message) in enumerate(messages):
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += message + seps[i % 2]
                else:
                    ret += ""
        else:
            raise ValueError(f"Invalid style: {self.sep_style}")
        return ret

    def append_message(self, role, message):
        """Append a [role, message] pair to the history."""
        self.messages.append([role, message])

    def process_image(self, image, image_process_mode, return_pil=False, image_format='PNG', max_len=1344, min_len=672):
        """Normalize a PIL image per ``image_process_mode`` and bound its size.

        Args:
            image: PIL image to process.
            image_process_mode: "Pad" (letterbox to square), "Default"/"Crop"
                (leave as-is), or "Resize" (force 336x336).
            return_pil: return the PIL image itself instead of base64.
            image_format: encoding used for the base64 result.
            max_len / min_len: bounds applied when the image is too large.

        Returns:
            PIL.Image.Image if return_pil, else a base64-encoded string.

        Raises:
            ValueError: on an unknown ``image_process_mode``.
        """
        if image_process_mode == "Pad":
            def expand2square(pil_img, background_color=(122, 116, 104)):
                # Letterbox onto a square canvas, centering the original.
                width, height = pil_img.size
                if width == height:
                    return pil_img
                elif width > height:
                    result = Image.new(pil_img.mode, (width, width), background_color)
                    result.paste(pil_img, (0, (width - height) // 2))
                    return result
                else:
                    result = Image.new(pil_img.mode, (height, height), background_color)
                    result.paste(pil_img, ((height - width) // 2, 0))
                    return result
            image = expand2square(image)
        elif image_process_mode in ["Default", "Crop"]:
            pass
        elif image_process_mode == "Resize":
            image = image.resize((336, 336))
        else:
            raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
        if max(image.size) > max_len:
            # Downscale so the long edge fits, preserving aspect ratio.
            max_hw, min_hw = max(image.size), min(image.size)
            aspect_ratio = max_hw / min_hw
            shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
            longest_edge = int(shortest_edge * aspect_ratio)
            W, H = image.size
            if H > W:
                H, W = longest_edge, shortest_edge
            else:
                H, W = shortest_edge, longest_edge
            image = image.resize((W, H))
        if return_pil:
            return image
        else:
            buffered = BytesIO()
            image.save(buffered, format=image_format)
            img_b64_str = base64.b64encode(buffered.getvalue()).decode()
            return img_b64_str

    def get_images(self, return_pil=False):
        """Collect processed images from the user turns of the live history."""
        images = []
        for i, (role, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:  # user turns only
                if type(msg) is tuple:
                    msg, image, image_process_mode = msg
                    image = self.process_image(image, image_process_mode, return_pil=return_pil)
                    images.append(image)
        return images

    def to_gradio_chatbot(self):
        """Convert the history into Gradio chatbot [user, assistant] pairs."""
        ret = []
        for i, (role, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:
                if type(msg) is tuple:
                    msg, image, image_process_mode = msg
                    img_b64_str = self.process_image(
                        image, "Default", return_pil=False,
                        image_format='JPEG')
                    img_str = f'<img src="data:image/jpeg;base64,{img_b64_str}" alt="user upload image" />'
                    msg = img_str + msg.replace('<image>', '').strip()
                    ret.append([msg, None])
                else:
                    ret.append([msg, None])
            else:
                # Assistant reply fills the open slot of the last pair.
                ret[-1][-1] = msg
        return ret

    def copy(self):
        """Return a copy with a fresh message list (shallow per message)."""
        return Conversation(
            system=self.system,
            roles=self.roles,
            messages=[[x, y] for x, y in self.messages],
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            version=self.version,
            tokenizer=self.tokenizer)

    def dict(self):
        """Serialize to a plain dict; image tuples collapse to their text."""
        if len(self.get_images()) > 0:
            return {
                "system": self.system,
                "roles": self.roles,
                "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
                "offset": self.offset,
                "sep": self.sep,
                "sep2": self.sep2,
            }
        return {
            "system": self.system,
            "roles": self.roles,
            "messages": self.messages,
            "offset": self.offset,
            "sep": self.sep,
            "sep2": self.sep2,
        }
conv_vicuna_v0 = Conversation(
system="A chat between a curious human and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
roles=("Human", "Assistant"),
messages=(
("Human", "What are the key differences between renewable and non-renewable energy sources?"),
("Assistant",
"Renewable energy sources are those that can be replenished naturally in a relatively "
"short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
"Non-renewable energy sources, on the other hand, are finite and will eventually be "
"depleted, such as coal, oil, and natural gas. Here are some key differences between "
"renewable and non-renewable energy sources:\n"
"1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
"energy sources are finite and will eventually run out.\n"
"2. Environmental impact: Renewable energy sources have a much lower environmental impact "
"than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
"and other negative effects.\n"
"3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
"have lower operational costs than non-renewable sources.\n"
"4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
"locations than non-renewable sources.\n"
"5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
"situations and needs, while non-renewable sources are more rigid and inflexible.\n"
"6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
"non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
),
offset=2,
sep_style=SeparatorStyle.SINGLE,
sep="###",
)
conv_vicuna_v1 = Conversation(
system="A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions.",
roles=("USER", "ASSISTANT"),
version="v1",
messages=(),
offset=0,
sep_style=SeparatorStyle.TWO,
sep=" ",
sep2="</s>",
)
conv_llama_2 = Conversation(
system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""",
roles=("USER", "ASSISTANT"),
version="llama_v2",
messages=(),
offset=0,
sep_style=SeparatorStyle.LLAMA_2,
sep="<s>",
sep2="</s>",
)
conv_blip3o_llama_2 = Conversation(
system="You are a helpful language and vision assistant. "
"You are able to understand the visual content that the user provides, "
"and assist the user with a variety of tasks using natural language.",
roles=("USER", "ASSISTANT"),
version="llama_v2",
messages=(),
offset=0,
sep_style=SeparatorStyle.LLAMA_2,
sep="<s>",
sep2="</s>",
)
conv_mpt = Conversation(
system="""<|im_start|>system
A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""",
roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
version="mpt",
messages=(),
offset=0,
sep_style=SeparatorStyle.MPT,
sep="<|im_end|>",
)
conv_blip3o_plain = Conversation(
system="",
roles=("", ""),
messages=(
),
offset=0,
sep_style=SeparatorStyle.PLAIN,
sep="\n",
)
conv_blip3o_v0 = Conversation(
system="A chat between a curious human and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
roles=("Human", "Assistant"),
messages=(
),
offset=0,
sep_style=SeparatorStyle.SINGLE,
sep="###",
)
conv_blip3o_v0_mmtag = Conversation(
system="A chat between a curious user and an artificial intelligence assistant. "
"The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
"The visual content will be provided with the following format: <Image>visual content</Image>.",
roles=("Human", "Assistant"),
messages=(
),
offset=0,
sep_style=SeparatorStyle.SINGLE,
sep="###",
version="v0_mmtag",
)
conv_blip3o_v1 = Conversation(
system="A chat between a curious human and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
roles=("USER", "ASSISTANT"),
version="v1",
messages=(),
offset=0,
sep_style=SeparatorStyle.TWO,
sep=" ",
sep2="</s>",
)
conv_blip3o_v1_mmtag = Conversation(
system="A chat between a curious user and an artificial intelligence assistant. "
"The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
"The visual content will be provided with the following format: <Image>visual content</Image>.",
roles=("USER", "ASSISTANT"),
messages=(),
offset=0,
sep_style=SeparatorStyle.TWO,
sep=" ",
sep2="</s>",
version="v1_mmtag",
)
conv_mistral_instruct = Conversation(
system="",
roles=("USER", "ASSISTANT"),
version="llama_v2",
messages=(),
offset=0,
sep_style=SeparatorStyle.LLAMA_2,
sep="",
sep2="</s>",
)
conv_chatml_direct = Conversation(
system="""<|im_start|>system
Answer the questions.""",
roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
version="mpt",
messages=(),
offset=0,
sep_style=SeparatorStyle.MPT,
sep="<|im_end|>",
)
conv_llama3 = Conversation(
system="""<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nA chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.""",
roles=("<|start_header_id|>user<|end_header_id|>\n\n", "<|start_header_id|>assistant<|end_header_id|>\n\n"),
version="llama3",
messages=(),
offset=0,
sep_style=SeparatorStyle.MPT,
sep="<|eot_id|>",
)
conv_qwen = Conversation(
system="""<|im_start|>system
You are a helpful assistant.""",
roles=("<|im_start|>user", "<|im_start|>assistant"),
version="qwen",
messages=[],
offset=0,
sep_style=SeparatorStyle.CHATML,
sep="<|im_end|>",
)
default_conversation = conv_llama3
conv_templates = {
"default": conv_vicuna_v0,
"v0": conv_vicuna_v0,
"v1": conv_vicuna_v1,
"vicuna_v1": conv_vicuna_v1,
"llama_2": conv_llama_2,
"mistral_instruct": conv_mistral_instruct,
"chatml_direct": conv_chatml_direct,
"mistral_direct": conv_chatml_direct,
"plain": conv_blip3o_plain,
"v0_plain": conv_blip3o_plain,
"blip3o_v0": conv_blip3o_v0,
"v0_mmtag": conv_blip3o_v0_mmtag,
"blip3o_v1": conv_blip3o_v1,
"v1_mmtag": conv_blip3o_v1_mmtag,
"blip3o_llama_2": conv_blip3o_llama_2,
"llama3": conv_llama3,
"qwen": conv_qwen,
"mpt": conv_mpt,
}
if __name__ == "__main__":
    # Quick manual check: render the default template as a prompt string.
    print(default_conversation.get_prompt())
\ No newline at end of file
from PIL import Image
from io import BytesIO
import base64
import torch
import math
import ast
from transformers import StoppingCriteria
from blip3o.constants import IMAGE_TOKEN_IDX
def select_best_resolution(original_size, possible_resolutions):
    """Pick the candidate resolution that best fits the original image.

    Each candidate is scored by the effective resolution it preserves after
    aspect-preserving downscaling (capped at the original pixel count), with
    wasted canvas area as the tie-breaker; ties keep the earliest candidate.

    Args:
        original_size (tuple): Original image size as (width, height).
        possible_resolutions (list): Candidate (width, height) pairs.

    Returns:
        tuple: The best-fit resolution in the format (width, height).
    """
    orig_w, orig_h = original_size

    def fit_score(candidate):
        cand_w, cand_h = candidate
        scale = min(cand_w / orig_w, cand_h / orig_h)
        effective = min(int(orig_w * scale) * int(orig_h * scale), orig_w * orig_h)
        wasted = cand_w * cand_h - effective
        # Maximize effective resolution first, then minimize wasted area.
        return (effective, -wasted)

    best_w, best_h = max(possible_resolutions, key=fit_score)
    return (best_w, best_h)
def resize_and_pad_image(image, target_resolution):
    """Resize and pad an image to a target resolution while maintaining aspect ratio.

    The image is scaled to touch the target box on its limiting axis, then
    centered on a black RGB canvas of the target size.

    Args:
        image (PIL.Image.Image): The input image.
        target_resolution (tuple): Target (width, height).

    Returns:
        PIL.Image.Image: The resized and padded image.
    """
    src_w, src_h = image.size
    tgt_w, tgt_h = target_resolution
    scale_w = tgt_w / src_w
    scale_h = tgt_h / src_h

    # Fit along the tighter axis; clamp the other side to the target box.
    if scale_w < scale_h:
        new_w = tgt_w
        new_h = min(math.ceil(src_h * scale_w), tgt_h)
    else:
        new_h = tgt_h
        new_w = min(math.ceil(src_w * scale_h), tgt_w)

    canvas = Image.new('RGB', (tgt_w, tgt_h), (0, 0, 0))
    canvas.paste(image.resize((new_w, new_h)), ((tgt_w - new_w) // 2, (tgt_h - new_h) // 2))
    return canvas
def divide_to_patches(image, patch_size):
    """Divide an image into square patches of a specified size.

    Patches are produced row-major (top-to-bottom, left-to-right); edge
    patches may extend past the image and are padded by PIL's crop.

    Args:
        image (PIL.Image.Image): The input image.
        patch_size (int): The side length of each patch.

    Returns:
        list: PIL.Image.Image patches.
    """
    img_w, img_h = image.size
    return [
        image.crop((left, top, left + patch_size, top + patch_size))
        for top in range(0, img_h, patch_size)
        for left in range(0, img_w, patch_size)
    ]
def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
    """Calculate the patch-grid shape chosen for an image of any resolution.

    Args:
        image_size (tuple): Input image size as (width, height).
        grid_pinpoints (str | list): Candidate resolutions, or their string
            representation (parsed with ast.literal_eval).
        patch_size (int): Side length of each image patch.

    Returns:
        tuple: Grid shape as (columns, rows), i.e. (width // patch, height // patch).
    """
    if type(grid_pinpoints) is list:
        candidates = grid_pinpoints
    else:
        candidates = ast.literal_eval(grid_pinpoints)
    best_w, best_h = select_best_resolution(image_size, candidates)
    return best_w // patch_size, best_h // patch_size
def process_anyres_image(image, processor, grid_pinpoints):
    """Process an image with variable resolutions into a stack of patch tensors.

    The image is fitted to its best grid resolution, tiled into crops, and a
    low-resolution global view is prepended before running the processor.

    Args:
        image (PIL.Image.Image): The input image to be processed.
        processor: The image processor object (HF-style, with crop_size/size).
        grid_pinpoints (str | list): Candidate resolutions or their repr string.

    Returns:
        torch.Tensor: Stacked processed views, shape (1 + n_patches, C, H, W).
    """
    if type(grid_pinpoints) is list:
        candidates = grid_pinpoints
    else:
        candidates = ast.literal_eval(grid_pinpoints)
    target = select_best_resolution(image.size, candidates)
    padded = resize_and_pad_image(image, target)
    tiles = divide_to_patches(padded, processor.crop_size['height'])

    # Global view: the whole image at the processor's base resolution.
    base = processor.size['shortest_edge']
    views = [image.resize((base, base))] + tiles

    tensors = [
        processor.preprocess(view, return_tensors='pt')['pixel_values'][0]
        for view in views
    ]
    return torch.stack(tensors, dim=0)
def load_image_from_base64(image):
    """Decode a base64 string into a PIL image."""
    raw_bytes = base64.b64decode(image)
    return Image.open(BytesIO(raw_bytes))
def expand2square(pil_img, background_color):
    """Letterbox an image onto a square canvas of its longer side.

    The original image is centered; the surrounding area is filled with
    ``background_color``. A square input is returned unchanged.
    """
    img_w, img_h = pil_img.size
    if img_w == img_h:
        return pil_img
    side = max(img_w, img_h)
    canvas = Image.new(pil_img.mode, (side, side), background_color)
    canvas.paste(pil_img, ((side - img_w) // 2, (side - img_h) // 2))
    return canvas
def process_images(images, image_processor, model_cfg):
    """Preprocess a list of PIL images per the model's aspect-ratio policy.

    "pad" squares each image with the processor's mean color; "anyres" tiles
    each image via process_anyres_image; anything else defers to the
    processor's default batched call.

    Returns:
        torch.Tensor when all results share a shape, else a list of tensors.
    """
    policy = getattr(model_cfg, "image_aspect_ratio", None)
    if policy == 'pad':
        fill = tuple(int(channel * 255) for channel in image_processor.image_mean)
        processed = [
            image_processor.preprocess(expand2square(img, fill), return_tensors='pt')['pixel_values'][0]
            for img in images
        ]
    elif policy == "anyres":
        processed = [
            process_anyres_image(img, image_processor, model_cfg.image_grid_pinpoints)
            for img in images
        ]
    else:
        return image_processor(images, return_tensors='pt')['pixel_values']
    if all(tensor.shape == processed[0].shape for tensor in processed):
        processed = torch.stack(processed, dim=0)
    return processed
def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_IDX, return_tensors=None):
    """Tokenize a prompt, replacing each "<image>" placeholder with an image token id.

    Text between placeholders is tokenized separately; if chunks begin with a
    BOS token it is kept once at the very start and stripped from every
    subsequent chunk.

    Args:
        prompt (str): Prompt text containing zero or more "<image>" markers.
        tokenizer: Tokenizer with __call__ -> .input_ids and .bos_token_id.
        image_token_index (int): Id inserted for each placeholder.
        return_tensors (str | None): 'pt' for a torch.LongTensor, None for a list.

    Raises:
        ValueError: for an unsupported ``return_tensors`` value.
    """
    chunks = [tokenizer(piece).input_ids for piece in prompt.split('<image>')]

    ids = []
    skip = 0
    if chunks and chunks[0] and chunks[0][0] == tokenizer.bos_token_id:
        # Keep BOS exactly once, and drop it from every chunk below.
        skip = 1
        ids.append(chunks[0][0])
    separator = [image_token_index] * (skip + 1)
    for pos, chunk in enumerate(chunks):
        if pos > 0:
            ids.extend(separator[skip:])
        ids.extend(chunk[skip:])

    if return_tensors is not None:
        if return_tensors == 'pt':
            return torch.tensor(ids, dtype=torch.long)
        raise ValueError(f'Unsupported tensor type: {return_tensors}')
    return ids
def get_model_name_from_path(model_path):
    """Derive a short model name from a filesystem path.

    Checkpoint directories keep their parent directory as a prefix, so
    "run/checkpoint-100" becomes "run_checkpoint-100".
    """
    parts = model_path.strip("/").split("/")
    leaf = parts[-1]
    if leaf.startswith('checkpoint-'):
        return f"{parts[-2]}_{leaf}"
    return leaf
class KeywordsStoppingCriteria(StoppingCriteria):
    """Stop generation once every sequence in the batch ends with a keyword.

    Matching is attempted first on token ids (exact suffix match), then by
    decoding the generated tail and searching for the keyword as a substring.
    """

    def __init__(self, keywords, tokenizer, input_ids):
        # keywords: list of strings whose appearance should stop generation.
        # input_ids: prompt ids; their length marks where generation starts.
        self.keywords = keywords
        self.keyword_ids = []
        self.max_keyword_len = 0
        for keyword in keywords:
            cur_keyword_ids = tokenizer(keyword).input_ids
            # Drop a leading BOS so the ids can match mid-sequence output.
            if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
                cur_keyword_ids = cur_keyword_ids[1:]
            if len(cur_keyword_ids) > self.max_keyword_len:
                self.max_keyword_len = len(cur_keyword_ids)
            self.keyword_ids.append(torch.tensor(cur_keyword_ids))
        self.tokenizer = tokenizer
        # Number of prompt tokens; only tokens after this index are scanned.
        self.start_len = input_ids.shape[1]

    def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True if a single sequence (batch size 1) has hit a keyword."""
        # Only look at as many tail tokens as the longest keyword needs.
        offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
        # Lazily move cached keyword ids to the output device.
        self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
        for keyword_id in self.keyword_ids:
            truncated_output_ids = output_ids[0, -keyword_id.shape[0]:]
            if torch.equal(truncated_output_ids, keyword_id):
                return True
        # Fallback: decode the generated tail and search for the raw string.
        outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
        for keyword in self.keywords:
            if keyword in outputs:
                return True
        return False

    def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Stop only when all sequences in the batch have produced a keyword."""
        outputs = []
        for i in range(output_ids.shape[0]):
            outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores))
        return all(outputs)
from .language_model.blip3o_llama import blip3oLlamaForCausalLM, blip3oConfig
from .language_model.blip3o_qwen import blip3oQwenForCausalLM, blip3oQwenConfig
"""
Usage:
python3 -m fastchat.model.apply_delta --base ~/model_weights/llama-7b --target ~/model_weights/vicuna-7b --delta lmsys/vicuna-7b-delta
"""
import argparse
import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM
from blip3o import blip3oLlamaForCausalLM
def apply_delta(base_model_path, target_model_path, delta_path):
    """Add a base model's weights into a delta checkpoint and save the result.

    Loads the base model and the delta model, adds each base tensor into the
    matching delta tensor in place (handling the enlarged embedding/lm-head
    rows), then saves the reconstructed model and tokenizer to
    ``target_model_path``.

    Args:
        base_model_path: Path or hub id of the original base model.
        target_model_path: Directory where the merged model is written.
        delta_path: Path or hub id of the delta weights.
    """
    print("Loading base model")
    base = AutoModelForCausalLM.from_pretrained(
        base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)

    print("Loading delta")
    delta = blip3oLlamaForCausalLM.from_pretrained(delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
    delta_tokenizer = AutoTokenizer.from_pretrained(delta_path)

    print("Applying delta")
    # Snapshot once: state_dict() rebuilds the whole mapping on every call,
    # and the original looped code called it up to four times per parameter.
    base_state = base.state_dict()
    for name, param in tqdm(delta.state_dict().items(), desc="Applying delta"):
        if name not in base_state:
            # Projector weights exist only in the delta model.
            assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model'
            continue
        base_param = base_state[name]
        if param.data.shape == base_param.shape:
            param.data += base_param
        else:
            # Vocabulary was extended: add the base weights into the leading rows.
            assert name in ['model.embed_tokens.weight', 'lm_head.weight'], \
                f'{name} dimension mismatch: {param.data.shape} vs {base_param.shape}'
            param.data[:base_param.shape[0], :base_param.shape[1]] += base_param

    print("Saving target model")
    delta.save_pretrained(target_model_path)
    delta_tokenizer.save_pretrained(target_model_path)
if __name__ == "__main__":
    # CLI entry point: all three paths are required.
    parser = argparse.ArgumentParser()
    parser.add_argument("--base-model-path", type=str, required=True)
    parser.add_argument("--target-model-path", type=str, required=True)
    parser.add_argument("--delta-path", type=str, required=True)
    args = parser.parse_args()
    apply_delta(args.base_model_path, args.target_model_path, args.delta_path)
from abc import ABC, abstractmethod
import torch
import torch.nn as nn
import torch.nn.functional as F
from .multimodal_encoder.builder import build_vision_tower, build_gen_vision_tower, build_dit
from .multimodal_projector.builder import build_vision_projector, build_down_projector, build_gen_vision_projector
from blip3o.constants import IGNORE_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, IMAGE_TOKEN_IDX, DEFAULT_IM_START_TOKEN_IDX, DEFAULT_IM_END_TOKEN_IDX, UND_IMAGE_TOKEN_IDX
class blip3oMetaModel:
    """Mixin that attaches multimodal (understanding + generation) modules.

    Intended to be mixed into a language-model class; ``__init__`` forwards
    ``config`` to the cooperating base class via super(). Which modules get
    built depends on attributes present on ``config``.
    """

    def __init__(self, config):
        super(blip3oMetaModel, self).__init__(config)

        # Understanding path: built only when a vision tower is configured.
        if hasattr(config, "mm_vision_tower"):
            # self.vision_tower = build_vision_tower(config, delay_load=True)
            # self.mm_projector = build_vision_projector(config)
            self.down_projector = build_down_projector(config)
            if 'unpad' in getattr(config, 'mm_patch_merge_type', ''):
                self.image_newline = nn.Parameter(
                    torch.empty(config.hidden_size, dtype=self.dtype)
                )

        # Generation path: vision tower, learnable latent queries and DiT.
        if hasattr(config, "gen_vision_tower"):
            self.gen_vision_tower = build_gen_vision_tower(config, delay_load=True)
            # self.gen_projector = build_gen_vision_projector(config)
            # Learnable query embeddings, one set shared across the batch.
            self.latent_queries = nn.Parameter(torch.randn(1, config.n_query, config.hidden_size))
            print(f" latent query size {self.latent_queries.shape}")
            if 'unpad' in getattr(config, 'mm_patch_merge_type', ''):
                self.image_newline = nn.Parameter(
                    torch.empty(config.hidden_size, dtype=self.dtype)
                )
            # build_dit returns the diffusion transformer, VAE and scheduler.
            self.dit, self.vae, self.noise_scheduler = build_dit(config)

    # def get_vision_tower(self):
    #     vision_tower = getattr(self, 'vision_tower', None)
    #     if type(vision_tower) is list:
    #         vision_tower = vision_tower[0]
    #     return vision_tower

    def get_gen_vision_tower(self):
        """Return the generation vision tower (unwrapping the FSDP list form)."""
        gen_vision_tower = getattr(self, 'gen_vision_tower', None)
        if type(gen_vision_tower) is list:
            gen_vision_tower = gen_vision_tower[0]
        return gen_vision_tower

    def initialize_vision_modules(self, model_args, fsdp=None):
        """Build or reload the generation-side modules from training args.

        Args:
            model_args: Namespace-like object with tower/projector settings.
            fsdp: Optional FSDP config; when non-empty the tower is wrapped
                in a single-element list (FSDP sharding convention).
        """
        gen_vision_tower = model_args.gen_vision_tower
        mm_vision_select_layer = model_args.mm_vision_select_layer
        mm_vision_select_feature = model_args.mm_vision_select_feature
        pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter
        pretrain_gen_mlp_adapter = model_args.pretrain_gen_mlp_adapter
        mm_patch_merge_type = model_args.mm_patch_merge_type

        self.config.gen_vision_tower = gen_vision_tower
        self.config.vision_tower_pretrained = getattr(model_args, "vision_tower_pretrained", "")

        # DiT: build fresh if absent, otherwise unfreeze the loaded one.
        if getattr(self, 'dit', None) is None:
            print("random initiation the DiT !!!")
            self.dit, self.vae, self.noise_scheduler = build_dit(model_args)
        else:
            print("DiT load from checkpoint!!!")
            for p in self.dit.parameters():
                p.requires_grad = True

        # Generation vision tower: build or reload weights.
        if self.get_gen_vision_tower() is None:
            gen_vision_tower = build_gen_vision_tower(model_args)
            if fsdp is not None and len(fsdp) > 0:
                self.gen_vision_tower = [gen_vision_tower]
            else:
                self.gen_vision_tower = gen_vision_tower
        else:
            if fsdp is not None and len(fsdp) > 0:
                gen_vision_tower = self.gen_vision_tower[0]
            else:
                gen_vision_tower = self.gen_vision_tower
            gen_vision_tower.load_model()

        # Persist the multimodal settings onto the model config.
        self.config.use_mm_proj = True
        self.config.mm_projector_type = getattr(model_args, 'mm_projector_type', 'linear')
        # self.config.gen_projector_type = getattr(model_args, 'gen_projector_type', 'linear')
        self.config.gen_hidden_size = gen_vision_tower.hidden_size
        self.config.mm_vision_select_layer = mm_vision_select_layer
        self.config.mm_vision_select_feature = mm_vision_select_feature
        self.config.mm_patch_merge_type = mm_patch_merge_type
        self.config.n_query = model_args.n_query
        self.config.gen_pooling = model_args.gen_pooling

        # if getattr(self, 'mm_projector', None) is None:
        #     print("random initiation the mm_project !!!")
        #     self.mm_projector = build_vision_projector(self.config)
        #     if 'unpad' in mm_patch_merge_type:
        #         embed_std = 1 / torch.sqrt(torch.tensor(self.config.hidden_size, dtype=self.dtype))
        #         self.image_newline = nn.Parameter(
        #             torch.randn(self.config.hidden_size, dtype=self.dtype) * embed_std
        #         )
        # else:
        #     # In case it is frozen by LoRA
        #     for p in self.mm_projector.parameters():
        #         p.requires_grad = True

        # Down projector: build fresh or unfreeze the loaded one.
        if getattr(self, 'down_projector', None) is None:
            print("random initiation the down_projector !!!")
            self.down_projector = build_down_projector(self.config)
        else:
            # In case it is frozen by LoRA
            for p in self.down_projector.parameters():
                p.requires_grad = True

        # Latent queries: initialize randomly or make checkpoint values trainable.
        if getattr(self, 'latent_queries', None) is None:
            print("random initiation the latent_queries !!!")
            self.latent_queries = nn.Parameter(torch.randn(1, self.config.n_query, self.config.hidden_size))
        else:
            print("latent_queries load from checkpoint!!!")
            self.latent_queries.requires_grad = True

        if pretrain_mm_mlp_adapter is not None:
            mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')
            # Extract the sub-dict for a module prefix from a flat state dict.
            def get_w(weights, keyword):
                return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k}
            # NOTE(review): the loaded weights are currently never applied —
            # the load_state_dict call below is commented out. Confirm intent.
            # self.mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector'))
def unpad_image(tensor, original_size):
    """Remove the letterbox padding from a padded-and-resized image tensor.

    Reverses the centering pad applied when the original image was fitted
    into the tensor's canvas, cropping along the axis that was padded.

    Args:
        tensor (torch.Tensor): Image tensor in CxHxW layout.
        original_size (tuple): Original PIL image size as (width, height).

    Returns:
        torch.Tensor: The unpadded image tensor.
    """
    src_w, src_h = original_size
    cur_h, cur_w = tensor.shape[1:]

    src_ratio = src_w / src_h
    cur_ratio = cur_w / cur_h

    if src_ratio > cur_ratio:
        # Original is wider: padding was added above and below.
        scale = cur_w / src_w
        kept = int(src_h * scale)
        pad = (cur_h - kept) // 2
        return tensor[:, pad:cur_h - pad, :]

    # Original is taller (or equal): padding was added left and right.
    scale = cur_h / src_h
    kept = int(src_w * scale)
    pad = (cur_w - kept) // 2
    return tensor[:, :, pad:cur_w - pad]
class blip3oMetaForCausalLM(ABC):
    @abstractmethod
    def get_model(self):
        """Return the underlying blip3oMetaModel instance (implemented by subclasses)."""
        pass
    def get_vision_tower(self):
        # Delegates to the model; NOTE(review): the model's get_vision_tower
        # is commented out in this file, so this may raise at runtime — verify.
        return self.get_model().get_vision_tower()
    def get_gen_vision_tower(self):
        """Return the generation-side vision tower from the wrapped model."""
        return self.get_model().get_gen_vision_tower()
def encode_image(self, images):
# breakpoint()
gen_vision_tower = self.get_gen_vision_tower()
device = gen_vision_tower.device
images = images.to(device)
prompt_image_embeds = gen_vision_tower(images)
if 'early' in self.get_gen_pooling():
prompt_image_embeds = self.pool_img(prompt_image_embeds)
num_img, _, c = prompt_image_embeds.shape
# prompt_image_embeds = prompt_image_embeds.contiguous().view(-1, c)
# ------------- compute similarity -------
all_dist = 0
count = 0
for i in range(2, prompt_image_embeds.shape[1]-1):
diff = (prompt_image_embeds[:,i,:].unsqueeze(1) - prompt_image_embeds[:,:i,:])
dist = torch.sqrt(diff.square().sum(-1)).min().item()
all_dist+=dist
count+=1
all_dist /= count
# self.dist = all_dist
# print(self.dist)
return prompt_image_embeds
def get_mm_projector(self):
return self.get_model().mm_projector
def get_gen_projector(self):
return None
def get_n_query(self):
return self.get_model().config.n_query
def get_gen_pooling(self):
return self.get_model().config.gen_pooling
def pool_img(self, image_features):
num_img, n, c = image_features.shape
gen_pooling = self.get_gen_pooling()
# n_query = self.get_n_query()
stride = int(gen_pooling.split('_')[-1])
sqrt_n = int(n**0.5)
image_features = image_features.permute(0, 2, 1).view(num_img, c, sqrt_n, sqrt_n)
image_features = F.avg_pool2d(image_features, kernel_size=(stride, stride), stride=stride)
# image_features = image_features.view(num_img, c, -1).permute(0,2,1).contiguous()
return image_features
def get_sigmas(self, timesteps, device, n_dim=4, dtype=torch.float32):
sigmas = self.get_model().noise_scheduler.sigmas.to(device=device, dtype=dtype)
schedule_timesteps = self.get_model().noise_scheduler.timesteps.to(device=device)
timesteps = timesteps.to(device)
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < n_dim:
sigma = sigma.unsqueeze(-1)
return sigma
def mask_drop(self, latents, drop_prob=0.1):
if drop_prob <= 0:
return latents
mask = torch.bernoulli(torch.zeros(latents.shape[0], device=latents.device, dtype=latents.dtype) + drop_prob)
while len(mask.shape) < len(latents.shape):
mask = mask.unsqueeze(-1)
mask = 1 - mask # need to flip 0 <-> 1
return latents * mask
def prepare_inputs_labels_for_multimodal(
self, input_ids, position_ids, attention_mask, past_key_values, labels,
gen_images, und_images, grid_thw, i_s_pos, image_sizes=None
):
pad_ids = 128256
vision_tower = self.visual
gen_vision_tower = self.get_gen_vision_tower()
if (gen_images is None and und_images is None) or input_ids.shape[1] == 1:
return input_ids, position_ids, attention_mask, past_key_values, None, labels, None, None, None
prompt_image_embeds = gen_vision_tower(gen_images) # TODO: check dimension
if 'early' in self.get_gen_pooling():
prompt_image_embeds = self.pool_img(prompt_image_embeds)
target_image_embeds = torch.clone(prompt_image_embeds).detach()
latent_queries = self.get_model().latent_queries.repeat(input_ids.shape[0], 1, 1)
H = latent_queries.shape[-1]
latent_queries = latent_queries.contiguous().view(-1, H)
# if not gen_images is None:
# prompt_image_embeds = gen_vision_tower(gen_images) # TODO: check dimension
# if 'early' in self.get_gen_pooling():
# prompt_image_embeds = self.pool_img(prompt_image_embeds)
# # num_img, _, c = prompt_image_embeds.shape # [batch, 729, 1152]
# # prompt_image_embeds = prompt_image_embeds.contiguous().view(-1, c)
# target_image_embeds = torch.clone(prompt_image_embeds).detach()
# # prompt_image_embeds = gen_projector(prompt_image_embeds)
# latent_queries = self.get_model().latent_queries.repeat(input_ids.shape[0], 1, 1)
# H = latent_queries.shape[-1]
# latent_queries = latent_queries.contiguous().view(-1, H)
# else:
# target_image_embeds = None
# num_img = und_images.shape[0]
# dummy = torch.zeros(num_img, 3, 448, 448 , dtype=und_images.dtype, device=und_images.device) # TODO
# temp = gen_vision_tower(dummy)[:,:729,:]
# num_img, _, c = temp.shape
# temp = temp.contiguous().view(-1, c) * 1e-20
# # temp = gen_projector(temp) * 1e-9
# latent_queries = self.get_model().latent_queries.repeat(input_ids.shape[0], 1, 1)
# H = latent_queries.shape[-1]
# latent_queries = latent_queries.contiguous().view(-1, H)
if not und_images is None:
und_image_embeds = vision_tower(und_images, grid_thw=grid_thw)
# _, c = und_image_embeds.shape
# batch_size = und_images.shape[0]
# und_image_embeds = und_image_embeds.view(batch_size, -1, c)
# und_image_embeds = und_image_embeds.contiguous().view(-1, c)
# und_image_embeds = mm_projector(und_image_embeds)
# else:
# num_img = input_ids.shape[0]
# dummy = torch.zeros(num_img, 3, 384, 384 , dtype=gen_images.dtype, device=gen_images.device) # clip (3, 336, 336)
# temp = vision_tower(dummy)
# if 'early' in self.get_gen_pooling():
# temp = temp[:,:64,:]
# num_img, _, c = temp.shape
# temp = temp.contiguous().view(-1, c)
# temp = mm_projector(temp) * 1e-20
# latent_queries += temp
image_idx = (input_ids == IMAGE_TOKEN_IDX)
und_image_idx = (input_ids == UND_IMAGE_TOKEN_IDX)
# img_indicator = torch.clone(image_idx)
output_indicator = labels != -100
input_indicator = labels == -100
# img_loss_indicator = torch.logical_and(output_indicator, image_idx)
# img_loss_indicator = torch.cat(
# [img_loss_indicator[:, 1:], img_loss_indicator[:, :1]], dim=1)
# img_indicator = torch.cat(
# [img_indicator[:, 1:], img_indicator[:, :1]], dim=1)
# if not target_image_embeds is None:
# target_image_embeds = target_image_embeds[-img_loss_indicator.sum():,:]
text_embeds = self.get_model().embed_tokens(input_ids)
# N_QUERY = self.get_n_query()
gen_img_idx = torch.logical_and(output_indicator, image_idx)
# if not target_image_embeds is None:
text_embeds = text_embeds.clone()
text_embeds[gen_img_idx] = latent_queries
# text_embeds[gen_img_idx] = prompt_image_embeds.to(text_embeds.device)[:gen_img_idx.sum(),:]
# target_image_embeds = target_image_embeds.to(text_embeds.device)[:gen_img_idx.sum(),:]
und_img_idx = torch.logical_and(input_indicator, und_image_idx)
if not und_images is None:
text_embeds[und_img_idx] = und_image_embeds.to(text_embeds.device)[:und_img_idx.sum(), :]
labels[image_idx] = -100
return None, position_ids, attention_mask, past_key_values, text_embeds, labels, target_image_embeds
def initialize_vision_tokenizer(self, model_args, tokenizer):
if model_args.mm_use_im_patch_token:
tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
self.resize_token_embeddings(len(tokenizer))
if model_args.mm_use_im_start_end:
num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
self.resize_token_embeddings(len(tokenizer))
if num_new_tokens > 0:
input_embeddings = self.get_input_embeddings().weight.data
output_embeddings = self.get_output_embeddings().weight.data
input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
dim=0, keepdim=True)
output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
dim=0, keepdim=True)
input_embeddings[-num_new_tokens:] = input_embeddings_avg
output_embeddings[-num_new_tokens:] = output_embeddings_avg
if model_args.tune_mm_mlp_adapter:
for p in self.get_input_embeddings().parameters():
p.requires_grad = True
for p in self.get_output_embeddings().parameters():
p.requires_grad = False
if model_args.pretrain_mm_mlp_adapter:
mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location='cpu')
embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight']
assert num_new_tokens == 2
if input_embeddings.shape == embed_tokens_weight.shape:
input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:]
elif embed_tokens_weight.shape[0] == num_new_tokens:
input_embeddings[-num_new_tokens:] = embed_tokens_weight
else:
raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}.")
elif model_args.mm_use_im_patch_token:
if model_args.tune_mm_mlp_adapter:
for p in self.get_input_embeddings().parameters():
p.requires_grad = False
for p in self.get_output_embeddings().parameters():
p.requires_grad = False
import os
import warnings
import shutil
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
import torch
from blip3o.model import *
from blip3o.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from blip3o.train.train import smart_tokenizer_and_embedding_resize
def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", device="cuda", use_flash_attn=False, **kwargs):
    """Load a blip3o checkpoint together with its tokenizer.

    Args:
        model_path: Path or hub id of the checkpoint to load.
        model_base: Unused here; kept for interface compatibility.
        model_name: Unused here; kept for interface compatibility.
        load_8bit: Load with bitsandbytes 8-bit quantization.
        load_4bit: Load with bitsandbytes nf4 4-bit quantization.
        device_map: Device placement forwarded to ``from_pretrained``.
        device: Target device; any value other than "cuda" pins the whole
            model onto that single device.
        use_flash_attn: Select the flash-attention-2 implementation.
        **kwargs: Extra keyword arguments forwarded to ``from_pretrained``.

    Returns:
        Tuple of (tokenizer, model, context_len).
    """
    kwargs = {"device_map": device_map, **kwargs}
    if device != "cuda":
        # Pin every module onto the requested device instead of auto-sharding.
        kwargs['device_map'] = {"": device}
    if load_8bit:
        kwargs['load_in_8bit'] = True
    elif load_4bit:
        kwargs['load_in_4bit'] = True
        kwargs['quantization_config'] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type='nf4'
        )
    else:
        kwargs['torch_dtype'] = torch.float16
    if use_flash_attn:
        kwargs['attn_implementation'] = 'flash_attention_2'
    tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
    # FIX: the assembled kwargs (device_map, quantization config, dtype,
    # attention implementation) were previously built and then discarded,
    # and the model was unconditionally moved to 'cuda:0' — which silently
    # ignored load_8bit/load_4bit/device/use_flash_attn. Forward them.
    model = blip3oQwenForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
    mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
    mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
    if mm_use_im_patch_token:
        tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
    if mm_use_im_start_end:
        tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
    # Grow the embedding matrix to cover any newly added special tokens.
    model.resize_token_embeddings(len(tokenizer))
    if hasattr(model.config, "max_sequence_length"):
        context_len = model.config.max_sequence_length
    else:
        context_len = 2048
    return tokenizer, model, context_len
\ No newline at end of file
import argparse
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from blip3o.model import *
from blip3o.model.utils import auto_upgrade
def consolidate_ckpt(src_path, dst_path):
    """Re-save the checkpoint at ``src_path`` (model + tokenizer) into ``dst_path``."""
    print("Loading model")
    auto_upgrade(src_path)
    model = AutoModelForCausalLM.from_pretrained(src_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
    tokenizer = AutoTokenizer.from_pretrained(src_path, use_fast=False)
    model.save_pretrained(dst_path)
    tokenizer.save_pretrained(dst_path)
if __name__ == "__main__":
    # CLI entry point: consolidate a checkpoint from --src into --dst.
    cli = argparse.ArgumentParser()
    cli.add_argument("--src", type=str, required=True)
    cli.add_argument("--dst", type=str, required=True)
    opts = cli.parse_args()
    consolidate_ckpt(opts.src, opts.dst)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment