Commit 81028572 authored by luopl

init

from .visual_tokenizer.clip_visual_tokenizer import ClipVisualTokenizerConfig, ClipVisualTokenizer
from .visual_tokenizer.siglip_visual_tokenizer import SiglipVisualTokenizerConfig, SiglipVisualTokenizer
from typing import Union, Optional
from transformers import PretrainedConfig, AutoConfig
class OvisConfig(PretrainedConfig):
model_type = "ovis"
def __init__(
self,
llm_config: Optional[Union[PretrainedConfig, dict]] = None,
visual_tokenizer_config: Optional[Union[PretrainedConfig, dict]] = None,
multimodal_max_length=8192,
hidden_size=None,
conversation_formatter_class=None,
llm_attn_implementation=None,
disable_tie_weight=False,
**kwargs
):
super().__init__(**kwargs)
if llm_config is not None:
assert isinstance(llm_config, (PretrainedConfig, dict)), \
f"expect `llm_config` to be instance of PretrainedConfig or dict, but got {type(llm_config)} type"
if not isinstance(llm_config, PretrainedConfig):
model_type = llm_config['model_type']
llm_config.pop('model_type')
llm_config = AutoConfig.for_model(model_type, **llm_config)
self.llm_config = llm_config
if visual_tokenizer_config is not None:
assert isinstance(visual_tokenizer_config, (PretrainedConfig, dict)), \
f"expect `visual_tokenizer_config` to be instance of PretrainedConfig or dict, but got {type(visual_tokenizer_config)} type"
if not isinstance(visual_tokenizer_config, PretrainedConfig):
model_type = visual_tokenizer_config['model_type']
visual_tokenizer_config.pop('model_type')
visual_tokenizer_config = AutoConfig.for_model(model_type, **visual_tokenizer_config)
self.visual_tokenizer_config = visual_tokenizer_config
self.multimodal_max_length = multimodal_max_length
self.hidden_size = hidden_size
self.conversation_formatter_class = conversation_formatter_class
self.llm_attn_implementation = llm_attn_implementation
self.disable_tie_weight = disable_tie_weight
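# Illustrative usage (added note, not part of the original file; the values below are
# hypothetical): nested configs may be passed as plain dicts, in which case the
# 'model_type' key is popped and the sub-config is rebuilt via AutoConfig.for_model.
#
#   config = OvisConfig(
#       llm_config={'model_type': 'qwen2', 'hidden_size': 3584},
#       visual_tokenizer_config={'model_type': 'siglip_visual_tokenizer', 'vocab_size': 65536},
#       multimodal_max_length=8192,
#       conversation_formatter_class='QwenConversationFormatter',
#   )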
from abc import ABC, abstractmethod
from typing import List, Dict
from ovis.util.constants import IMAGE_TOKEN_ID, IGNORE_ID, IMAGE_TOKEN
class ConversationFormatter(ABC):
support_tokenizer_types = None
def __init__(self, tokenizer):
tokenizer_type = type(tokenizer).__name__
assert tokenizer_type in self.support_tokenizer_types, \
f'Invalid tokenizer type, expected one of `{self.support_tokenizer_types}`, but got `{tokenizer_type}`'
self.tokenizer = tokenizer
self.image_token = IMAGE_TOKEN
self.image_token_id = IMAGE_TOKEN_ID
self.ignore_id = IGNORE_ID
def _tokenize_with_image_symbol(self, text):
text_chunks = [self.tokenizer(chunk, add_special_tokens=False).input_ids for chunk in
text.split(self.image_token)]
token_ids = []
num_chunk = len(text_chunks)
for i, chunk in enumerate(text_chunks):
token_ids.extend(chunk)
if i < num_chunk - 1:
token_ids.append(self.image_token_id)
return token_ids
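# Example of the splitting above (added note; the token ids are hypothetical): for
# "Describe <image> briefly", the text splits on IMAGE_TOKEN into two chunks, each
# chunk is tokenized without special tokens, and IMAGE_TOKEN_ID (-200) is inserted
# between them, e.g. [1234, 567] + [-200] + [891] -> [1234, 567, -200, 891].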
@abstractmethod
def format(self, conversations: List[Dict], generation_preface=None):
pass
@abstractmethod
def format_query(self, query, generation_preface=""):
pass
class QwenConversationFormatter(ConversationFormatter):
support_tokenizer_types = ['QWenTokenizer', 'Qwen2TokenizerFast']
def __init__(self, tokenizer):
super().__init__(tokenizer)
self.from2role = {
"system": "<|im_start|>system\n",
"human": "<|im_start|>user\n",
"gpt": "<|im_start|>assistant\n",
}
self.gpt_token_num = None
self.im_end = "<|im_end|>\n"
self.default_system_prompt = "You are a helpful assistant."
def format(self, conversations: List[Dict], generation_preface=None):
if self.gpt_token_num is None:
self.gpt_token_num = len(self.tokenizer(self.from2role["gpt"], add_special_tokens=False).input_ids)
if conversations[0]["from"] != "system":
conversations.insert(0, {
"from": "system",
"value": self.default_system_prompt
})
if generation_preface is not None:
conversations.append({
"from": "gpt",
"value": generation_preface
})
prompt = ""
input_ids = []
labels = []
num_conversation = len(conversations)
for i, conversation in enumerate(conversations):
frm = conversation["from"]
role = self.from2role[frm]
message = conversation["value"]
text = role + message
if i < num_conversation - 1 or generation_preface is None:
text += self.im_end
prompt += text
token_ids = self._tokenize_with_image_symbol(text)
input_ids.extend(token_ids)
label_ids = [self.ignore_id] * len(token_ids)
if frm == "gpt" and generation_preface is None:
# learning `\n` following `im_end` is meaningless, so the last `\n` token is ignored in label
label_ids[self.gpt_token_num:-1] = token_ids[self.gpt_token_num:-1]
labels.extend(label_ids)
assert self._tokenize_with_image_symbol(prompt) == input_ids
assert len(input_ids) == len(labels)
return prompt, input_ids, labels
def format_query(self, query, generation_preface=""):
prompt, input_ids, _ = self.format([{
"from": "human",
"value": query
}], generation_preface=generation_preface)
return prompt, input_ids
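# Illustrative sketch (added note; `tokenizer` is assumed to be a Qwen2TokenizerFast
# instance): `format` builds a ChatML-style prompt and masks everything except the
# assistant turns in `labels`.
#
#   formatter = QwenConversationFormatter(tokenizer)
#   prompt, input_ids, labels = formatter.format([
#       {"from": "human", "value": "<image>\nWhat is in the picture?"},
#       {"from": "gpt", "value": "A cat."},
#   ])
#   # prompt begins with "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"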
class Llama3ConversationFormatter(ConversationFormatter):
support_tokenizer_types = ['PreTrainedTokenizerFast']
def __init__(self, tokenizer):
super().__init__(tokenizer)
self.from2role = {
"system": "<|start_header_id|>system<|end_header_id|>\n\n",
"human": "<|start_header_id|>user<|end_header_id|>\n\n",
"gpt": "<|start_header_id|>assistant<|end_header_id|>\n\n",
}
self.gpt_token_num = None
self.im_end = "<|eot_id|>"
self.default_system_prompt = "You are a helpful and honest multimodal assistant."
self.bos_token = "<|begin_of_text|>"
self.bos_token_ids = None
def format(self, conversations: List[Dict], generation_preface=None):
if self.gpt_token_num is None:
self.gpt_token_num = len(self.tokenizer(self.from2role["gpt"], add_special_tokens=False).input_ids)
if self.bos_token_ids is None:
self.bos_token_ids = self.tokenizer(self.bos_token, add_special_tokens=False).input_ids
if conversations[0]["from"] != "system":
conversations.insert(0, {
"from": "system",
"value": self.default_system_prompt
})
if generation_preface is not None:
conversations.append({
"from": "gpt",
"value": generation_preface
})
prompt = "" + self.bos_token
input_ids = [] + self.bos_token_ids
labels = [] + [IGNORE_ID] * len(input_ids)
num_conversation = len(conversations)
for i, conversation in enumerate(conversations):
frm = conversation["from"]
role = self.from2role[frm]
message = conversation["value"].strip()
text = role + message
if i < num_conversation - 1 or generation_preface is None:
text += self.im_end
prompt += text
token_ids = self._tokenize_with_image_symbol(text)
input_ids.extend(token_ids)
label_ids = [self.ignore_id] * len(token_ids)
if frm == "gpt":
label_ids[self.gpt_token_num:] = token_ids[self.gpt_token_num:]
labels.extend(label_ids)
assert self._tokenize_with_image_symbol(prompt) == input_ids
assert len(input_ids) == len(labels)
return prompt, input_ids, labels
def format_query(self, query, generation_preface=""):
prompt, input_ids, _ = self.format([{
"from": "human",
"value": query
}], generation_preface=generation_preface)
return prompt, input_ids
class GemmaConversationFormatter(ConversationFormatter):
support_tokenizer_types = ['GemmaTokenizer', 'GemmaTokenizerFast']
def __init__(self, tokenizer):
super().__init__(tokenizer)
# Gemma does not support a system prompt
self.from2role = {
"human": "<start_of_turn>user\n",
"gpt": "<start_of_turn>model\n",
}
self.gpt_token_num = None
self.im_end = "<end_of_turn>\n"
self.bos_token = "<bos>"
self.bos_token_ids = None
def format(self, conversations: List[Dict], generation_preface=None):
if self.gpt_token_num is None:
self.gpt_token_num = len(self.tokenizer(self.from2role["gpt"], add_special_tokens=False).input_ids)
if self.bos_token_ids is None:
self.bos_token_ids = self.tokenizer(self.bos_token, add_special_tokens=False).input_ids
if conversations[0]["from"] == "system":
raise ValueError("Gemma does not support system prompt")
if generation_preface is not None:
conversations.append({
"from": "gpt",
"value": generation_preface
})
prompt = "" + self.bos_token
input_ids = [] + self.bos_token_ids
labels = [] + [IGNORE_ID] * len(input_ids)
num_conversation = len(conversations)
for i, conversation in enumerate(conversations):
frm = conversation["from"]
role = self.from2role[frm]
message = conversation["value"].strip()
text = role + message
if i < num_conversation - 1 or generation_preface is None:
text += self.im_end
prompt += text
token_ids = self._tokenize_with_image_symbol(text)
input_ids.extend(token_ids)
label_ids = [self.ignore_id] * len(token_ids)
if frm == "gpt":
# learning `\n` following `im_end` is meaningless, so the last `\n` token is ignored in label
label_ids[self.gpt_token_num:-1] = token_ids[self.gpt_token_num:-1]
labels.extend(label_ids)
assert self._tokenize_with_image_symbol(prompt) == input_ids
assert len(input_ids) == len(labels)
return prompt, input_ids, labels
def format_query(self, query, generation_preface=""):
prompt, input_ids, _ = self.format([{
"from": "human",
"value": query
}], generation_preface=generation_preface)
return prompt, input_ids
import logging
import os
from datetime import datetime
from importlib import import_module
from typing import List, Union, Callable, Optional, Dict
import PIL.Image
import deepspeed
import torch
from torch import Tensor
from torch.nn import init
from transformers import PreTrainedModel, AutoConfig, AutoModel, AutoTokenizer, AutoModelForCausalLM
from transformers.cache_utils import HybridCache
from transformers.generation.utils import GenerateOutput
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled, deepspeed_config
from ovis.model.configuration_ovis import OvisConfig
from ovis.model.conversation_formatter import ConversationFormatter
from ovis.util.constants import IGNORE_ID, BEGIN_LINE, END_LINE, IMAGE_ATOM_ID, IMAGE_INDICATOR_IDS, \
IMAGE_TOKEN_ID
from ovis.util.utils import rank0_print
class VisualEmbedding(torch.nn.Embedding):
def forward(self, visual_tokens: Tensor) -> Tensor:
if visual_tokens.dtype in [torch.int8, torch.int16, torch.int32, torch.int64, torch.long]:
return super().forward(visual_tokens)
return torch.matmul(visual_tokens, self.weight)
def reset_parameters(self, mean=0., std=1.) -> None:
init.normal_(self.weight, mean=mean, std=std)
self._fill_padding_idx_with_zero()
class OvisPreTrainedModel(PreTrainedModel):
config_class = OvisConfig
base_model_prefix = "ovis"
class Ovis(OvisPreTrainedModel):
def __init__(self, config: OvisConfig, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
if kwargs.get('train_from_scratch'):
self.llm = kwargs['llm']
self.generation_config = self.llm.generation_config
self.config.llm_config = self.llm.config
self.config.hidden_size = self.llm.config.hidden_size # for deepspeed auto configuration
self.text_tokenizer = kwargs['text_tokenizer']
self.visual_tokenizer = kwargs['visual_tokenizer']
self.config.visual_tokenizer_config = self.visual_tokenizer.config
else:
attn_kwargs = dict()
if self.config.llm_attn_implementation:
attn_kwargs['attn_implementation'] = self.config.llm_attn_implementation
self.llm = AutoModelForCausalLM.from_config(self.config.llm_config, **attn_kwargs)
assert self.config.hidden_size == self.llm.config.hidden_size, "hidden size mismatch"
self.text_tokenizer = AutoTokenizer.from_pretrained(self.config.name_or_path)
self.visual_tokenizer = AutoModel.from_config(self.config.visual_tokenizer_config,
image_processor_name_or_path=self.config.name_or_path)
# initialize vte
if is_deepspeed_zero3_enabled():
with deepspeed.zero.Init(config_dict_or_path=deepspeed_config()):
self.vte = VisualEmbedding(self.config.visual_tokenizer_config.vocab_size, self.config.hidden_size)
else:
self.vte = VisualEmbedding(self.config.visual_tokenizer_config.vocab_size, self.config.hidden_size,
device=self.visual_tokenizer.device, dtype=self.visual_tokenizer.dtype)
def _merge_modules(modules_list: tuple):
merged_modules = []
for modules in modules_list:
merged_modules.extend(modules if modules else [])
return merged_modules
self._no_split_modules = _merge_modules((self.llm._no_split_modules, self.visual_tokenizer._no_split_modules))
self._skip_keys_device_placement = self.llm._skip_keys_device_placement
self._keep_in_fp32_modules = _merge_modules(
(self.llm._keep_in_fp32_modules, self.visual_tokenizer._keep_in_fp32_modules))
self.is_parallelizable = all((self.llm.is_parallelizable, self.visual_tokenizer.is_parallelizable))
self.supports_gradient_checkpointing = all(
(self.llm.supports_gradient_checkpointing, self.visual_tokenizer.supports_gradient_checkpointing))
self._supports_flash_attn_2 = all(
(self.llm._supports_flash_attn_2, self.visual_tokenizer._supports_flash_attn_2))
self._supports_sdpa = all((self.llm._supports_sdpa, self.visual_tokenizer._supports_sdpa))
def get_text_tokenizer(self):
return self.text_tokenizer
def get_visual_tokenizer(self):
return self.visual_tokenizer
def tie_weights(self):
if not self.config.disable_tie_weight:
self.get_llm().tie_weights()
def re_init_vte(self, mean, std):
vte = self.get_vte()
rank0_print(BEGIN_LINE)
rank0_print(f'[{datetime.now()}] Before re-initialization of vte: ')
with deepspeed.zero.GatheredParameters([vte.weight]):
rank0_print(f'vte.weight: {vte.weight}')
with deepspeed.zero.GatheredParameters([vte.weight], modifier_rank=0):
if not is_deepspeed_zero3_enabled() or deepspeed.comm.get_rank() == 0:
vte.reset_parameters(mean, std)
rank0_print(f'[{datetime.now()}] After re-initialization of vte:')
with deepspeed.zero.GatheredParameters([vte.weight]):
rank0_print(f'vte.weight: {vte.weight}')
rank0_print(END_LINE)
def get_monitor_tensors(self):
monitor_tensors = dict(
wte=self.get_wte().weight,
lm_head=self.get_lm_head().weight,
vte=self.get_vte().weight
)
monitor_tensors.update(
{f'visual_tokenizer_{k}': v for k, v in self.get_visual_tokenizer().get_monitor_tensors().items()})
return monitor_tensors
def get_lm_head(self):
return self.get_llm().get_output_embeddings()
def get_llm(self):
return self.llm
def get_vte(self):
return self.vte
def get_wte(self):
return self.llm.get_input_embeddings()
def get_conversation_formatter(self) -> ConversationFormatter:
if getattr(self, 'conversation_formatter', None) is None:
self.conversation_formatter = getattr(import_module(".conversation_formatter", __package__),
self.config.conversation_formatter_class)(self.text_tokenizer)
return self.conversation_formatter
def forward(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
labels: Optional[torch.Tensor],
pixel_values: List[Optional[torch.Tensor]],
**kwargs
):
assert self.training, "`forward` can only be used in training. For inference, use `generate`."
_, inputs_embeds, labels, attention_mask = self.merge_multimodal(
text_input_ids=input_ids,
text_attention_masks=attention_mask,
text_labels=labels,
pixel_values=pixel_values
)
return self.llm(inputs_embeds=inputs_embeds, labels=labels, attention_mask=attention_mask, **kwargs)
def merge_multimodal(
self,
text_input_ids: torch.Tensor,
text_attention_masks: torch.Tensor,
text_labels: Optional[torch.Tensor],
pixel_values: List[Optional[torch.Tensor]]
):
input_device = text_input_ids.device
visual_vocab_size = self.get_visual_tokenizer().config.vocab_size
visual_indicator_embeds = self.get_vte()(
torch.tensor(
list(range(visual_vocab_size - 5, visual_vocab_size)),
dtype=torch.long,
device=self.get_visual_tokenizer().device
)
).to(device=input_device)
if self.training:
# During training, to be compatible with deepspeed zero, each sample has to include a pixel_values tensor.
# For a text-only sample, one can simply use an all-zero tensor as pixel_values; it will be ignored
# (see below in this function), so the gradient is not affected.
num_images = [x.shape[0] for x in pixel_values]
visual_tokens = self.visual_tokenizer(torch.cat([x for x in pixel_values], dim=0))
visual_embeds = torch.split(self.get_vte()(visual_tokens).to(dtype=self.dtype, device=input_device),
split_size_or_sections=num_images, dim=0)
visual_input_ids = torch.split(torch.argmax(visual_tokens, dim=-1).to(device=input_device),
split_size_or_sections=num_images, dim=0)
visual_labels = [torch.full(x.shape, IGNORE_ID, dtype=torch.long, device=input_device) for x in
visual_input_ids]
else:
# During inference, a sample may be text-only, with `None` as its pixel_values entry
num_images = [x.shape[0] if x is not None else 0 for x in pixel_values]
if sum(num_images) > 0:
visual_tokens = self.visual_tokenizer(torch.cat([x for x in pixel_values if x is not None], dim=0))
visual_embeds = torch.split(self.get_vte()(visual_tokens).to(dtype=self.dtype, device=input_device),
split_size_or_sections=num_images, dim=0)
visual_input_ids = torch.split(torch.argmax(visual_tokens, dim=-1).to(device=input_device),
split_size_or_sections=num_images, dim=0)
visual_labels = [torch.full(x.shape, IGNORE_ID, dtype=torch.long, device=input_device) for x in
visual_input_ids]
else:
# just placeholders
visual_embeds = [None] * len(num_images)
visual_input_ids = [None] * len(num_images)
visual_labels = [None] * len(num_images)
# just placeholders
text_labels = torch.full(text_input_ids.shape, IGNORE_ID, dtype=torch.long, device=input_device)
input_embeds = []
attention_masks = []
labels = []
for text_input_id, text_label, text_attention_mask, visual_embed, visual_input_id, visual_label in zip(
text_input_ids, text_labels, text_attention_masks, visual_embeds, visual_input_ids, visual_labels
):
placeholder_token_mask = torch.lt(text_input_id, 0)
text_embed = self.get_wte()(torch.masked_fill(text_input_id, placeholder_token_mask, 0))
for i, indicator_id in enumerate(IMAGE_INDICATOR_IDS):
text_embed[text_input_id == indicator_id] = visual_indicator_embeds[i]
image_atom_positions = torch.where(torch.eq(text_input_id, IMAGE_ATOM_ID))[0].tolist()
if len(image_atom_positions) > 0:
input_embed_parts = []
attention_mask_parts = []
label_parts = []
prev_image_atom_position = -1
for index, image_atom_position in enumerate(image_atom_positions):
input_embed_parts.append(
text_embed[prev_image_atom_position + 1:image_atom_position, :])
label_parts.append(
text_label[prev_image_atom_position + 1:image_atom_position])
attention_mask_parts.append(
text_attention_mask[prev_image_atom_position + 1:image_atom_position])
input_embed_parts.append(visual_embed[index])
attention_mask_parts.append(
torch.ones_like(visual_label[index], dtype=torch.bool))
label_parts.append(visual_label[index])
prev_image_atom_position = image_atom_position
if prev_image_atom_position + 1 < text_input_id.shape[0]:
input_embed_parts.append(
text_embed[prev_image_atom_position + 1:, :])
attention_mask_parts.append(
text_attention_mask[prev_image_atom_position + 1:])
label_parts.append(
text_label[prev_image_atom_position + 1:])
input_embed = torch.cat(input_embed_parts, dim=0)
attention_mask = torch.cat(attention_mask_parts, dim=0)
label = torch.cat(label_parts, dim=0)
else:
input_embed = text_embed
attention_mask = text_attention_mask
label = text_label
if self.training:
# Keep visual_embed & visual_indicator_embeds in the backward graph,
# for compatibility with deepspeed zero and DDP.
input_embed += torch.sum(visual_embed * 0.0) + torch.sum(visual_indicator_embeds * 0.0)
input_embeds.append(input_embed)
attention_masks.append(attention_mask)
labels.append(label)
if self.training: # padding to self.config.multimodal_max_length for increased training speed
padding_size = max(0, self.config.multimodal_max_length - len(input_embeds[0]))
input_embeds[0] = torch.nn.ConstantPad2d((0, 0, 0, padding_size), 0.0)(input_embeds[0])
attention_masks[0] = torch.nn.ConstantPad1d((0, padding_size), False)(attention_masks[0])
labels[0] = torch.nn.ConstantPad1d((0, padding_size), IGNORE_ID)(labels[0])
batch_input_embeds = torch.nn.utils.rnn.pad_sequence(input_embeds, batch_first=True, padding_value=0.0)[:,
:self.config.multimodal_max_length, :]
batch_attention_mask = torch.nn.utils.rnn.pad_sequence(attention_masks, batch_first=True, padding_value=False)[
:,
:self.config.multimodal_max_length]
batch_labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_ID)[:,
:self.config.multimodal_max_length]
return visual_input_ids, batch_input_embeds, batch_labels, batch_attention_mask
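# Added note on merge_multimodal: text_input_ids carry negative placeholder ids,
# IMAGE_ATOM_ID (-300) and IMAGE_INDICATOR_IDS (-301..-305). Indicator positions are
# replaced with the last five rows of the visual embedding table (vte), and each
# IMAGE_ATOM_ID position is expanded into the visual-token embeddings of one image or
# sub-image frame of pixel_values, so the returned sequences are longer than the text inputs.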
def preprocess_inputs(
self,
text_or_conversations: Union[List[Dict], str],
images: Optional[List[PIL.Image.Image]],
max_partition=9,
generation_preface='',
return_labels=False,
propagate_exception=True
):
# convert text to conversations
if isinstance(text_or_conversations, str):
conversations = [{
"from": "human",
"value": text_or_conversations
}]
elif isinstance(text_or_conversations, list):
conversations = text_or_conversations
else:
raise ValueError(f'Invalid type of `text_or_conversations`, expected `List[Dict]` or `str`,'
f' but got {type(text_or_conversations)}')
# format conversations
prompt, raw_input_ids, raw_labels = self.get_conversation_formatter().format(
conversations, generation_preface=generation_preface)
# place image placeholders
input_ids = []
labels = []
pixel_values = []
invalidate_label = False
image_token_indices = [i for i, v in enumerate(raw_input_ids) if v == IMAGE_TOKEN_ID]
last_image_token_index = -1
for i in range(len(image_token_indices)):
head = 0 if i == 0 else image_token_indices[i - 1] + 1
tail = image_token_indices[i]
last_image_token_index = tail
input_ids.extend(raw_input_ids[head:tail])
labels.extend(raw_labels[head:tail])
try:
image = images[i]
raw_pixel_values, image_placeholders = self.visual_tokenizer.preprocess_image(
image, max_partition=max_partition)
except Exception as e:
if propagate_exception:
raise e
logging.exception(e)
invalidate_label = True
raw_pixel_values, image_placeholders = self.visual_tokenizer.mock_input()
input_ids.extend(image_placeholders)
labels.extend([IGNORE_ID] * len(image_placeholders))
pixel_values.append(raw_pixel_values)
input_ids.extend(raw_input_ids[last_image_token_index + 1:])
labels.extend(raw_labels[last_image_token_index + 1:])
# return tensors
input_ids = torch.tensor(input_ids, dtype=torch.long)
labels = torch.tensor([IGNORE_ID] * len(labels) if invalidate_label else labels, dtype=torch.long)
pixel_values = torch.cat(pixel_values, dim=0) if len(pixel_values) > 0 else None
if return_labels:
return prompt, input_ids, pixel_values, labels
else:
return prompt, input_ids, pixel_values
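# Illustrative call (added note; the file name is a placeholder):
#   prompt, input_ids, pixel_values = model.preprocess_inputs(
#       '<image>\nDescribe this image.', [PIL.Image.open('example.jpg')], max_partition=9)
#   # input_ids: 1-D LongTensor mixing text token ids with negative image placeholders
#   # pixel_values: [num_sub_images, 3, H, W] float tensor, or None for text-only input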
def save_pretrained(
self,
save_directory: Union[str, os.PathLike],
is_main_process: bool = True,
state_dict: Optional[dict] = None,
save_function: Callable = torch.save,
push_to_hub: bool = False,
max_shard_size: Union[int, str] = "5GB",
safe_serialization: bool = True,
variant: Optional[str] = None,
token: Optional[Union[str, bool]] = None,
save_peft_format: bool = True,
**kwargs
):
super().save_pretrained(save_directory,
is_main_process=is_main_process,
state_dict=state_dict,
save_function=save_function,
safe_serialization=safe_serialization)
self.get_text_tokenizer().save_pretrained(save_directory)
self.get_visual_tokenizer().get_image_processor().save_pretrained(save_directory)
# uncommenting the following will additionally save a separate visual tokenizer
# visual_tokenizer_directory = os.path.join(save_directory, 'visual_tokenizer')
# self.get_visual_tokenizer().save_pretrained(visual_tokenizer_directory,
# is_main_process=is_main_process,
# state_dict=None,
# save_function=save_function,
# safe_serialization=safe_serialization)
# self.get_visual_tokenizer().get_image_processor().save_pretrained(visual_tokenizer_directory)
def _get_hybrid_cache_for_llm(self, max_batch_size: int, max_cache_len: int):
cache_cls = HybridCache
llm = self.get_llm()
need_new_cache = (
not hasattr(llm, "_cache")
or (not isinstance(llm._cache, cache_cls))
or llm._cache.max_batch_size != max_batch_size
or llm._cache.max_cache_len < max_cache_len
)
if need_new_cache:
if hasattr(llm.config, "_pre_quantization_dtype"):
cache_dtype = llm.config._pre_quantization_dtype
else:
cache_dtype = llm.dtype
llm._cache = cache_cls(
config=llm.config,
max_batch_size=max_batch_size,
max_cache_len=max_cache_len,
device=llm.device,
dtype=cache_dtype,
)
else:
llm._cache.reset()
return llm._cache
# TODO: support batch generation
def generate(
self,
inputs: Optional[torch.Tensor] = None,
**kwargs
) -> Union[GenerateOutput, torch.LongTensor]:
assert inputs.shape[0] == 1, 'Currently, only `batch_size=1` is supported'
_, inputs_embeds, labels, attention_mask = self.merge_multimodal(
text_input_ids=inputs,
text_attention_masks=kwargs.pop('attention_mask'),
text_labels=None,
pixel_values=kwargs.pop('pixel_values')
)
if getattr(self.generation_config, 'cache_implementation', None) == 'hybrid':  # mainly for Gemma2
kwargs['past_key_values'] = self._get_hybrid_cache_for_llm(
getattr(kwargs, "num_beams", 1), kwargs['max_new_tokens'] + inputs_embeds.shape[-2])
self.get_llm()._supports_cache_class = True
kwargs['cache_implementation'] = None
return self.llm.generate(inputs=None, inputs_embeds=inputs_embeds, attention_mask=attention_mask, **kwargs)
AutoConfig.register("ovis", OvisConfig)
AutoModelForCausalLM.register(OvisConfig, Ovis)
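# Minimal inference sketch (added note; the checkpoint name, dtype and device handling
# are assumptions; see ovis/serve/runner.py for the full pipeline):
#
#   model = Ovis.from_pretrained('AIDC-AI/Ovis1.6-Gemma2-9B', torch_dtype=torch.bfloat16).eval().cuda()
#   prompt, input_ids, pixel_values = model.preprocess_inputs('<image>\nDescribe this image.', [image])
#   attention_mask = torch.ne(input_ids, model.get_text_tokenizer().pad_token_id)
#   output_ids = model.generate(
#       input_ids.unsqueeze(0).cuda(),
#       pixel_values=[pixel_values.to(device='cuda', dtype=torch.bfloat16)],
#       attention_mask=attention_mask.unsqueeze(0).cuda(),
#       max_new_tokens=512)
#   print(model.get_text_tokenizer().decode(output_ids[0], skip_special_tokens=True))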
from typing import Union, Optional
import PIL.Image
import torch
from torch.nn.functional import softmax, gumbel_softmax, pad
from transformers import PretrainedConfig, PreTrainedModel, AutoImageProcessor, AutoModel, AutoConfig
from ovis.util.constants import IMAGE_INDICATOR_IDS, IMAGE_ATOM_ID
class BaseVisualTokenizerConfig(PretrainedConfig):
def __init__(
self,
vocab_size=16384,
tokenize_function="softmax",
tau=1.0,
depths=None,
drop_cls_token=False,
backbone_config: Optional[Union[PretrainedConfig, dict]] = None,
hidden_stride: int = 1,
**kwargs
):
super().__init__(**kwargs)
self.vocab_size = vocab_size
self.tokenize_function = tokenize_function
self.tau = tau
if isinstance(depths, str):
depths = [int(x) for x in depths.split('|')]
self.depths = depths
self.backbone_kwargs = {}
self.drop_cls_token = drop_cls_token
if backbone_config is not None:
assert isinstance(backbone_config, (PretrainedConfig, dict)), \
f"expect `backbone_config` to be instance of PretrainedConfig or dict, but got {type(backbone_config)} type"
if not isinstance(backbone_config, PretrainedConfig):
model_type = backbone_config['model_type']
backbone_config.pop('model_type')
backbone_config = AutoConfig.for_model(model_type, **backbone_config)
self.backbone_config = backbone_config
self.hidden_stride = hidden_stride
class BaseVisualTokenizer(PreTrainedModel):
base_model_prefix = "backbone"
main_input_name = None
_image_processor_class = None
_image_processor_kwargs = {}
_backbone_class = None
_backbone_name_or_path = None
def __init__(self, config: BaseVisualTokenizerConfig, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)
if kwargs.get('train_from_scratch'):
self.image_processor = self._image_processor_class.from_pretrained(self._backbone_name_or_path,
**self._image_processor_kwargs)
self.backbone = self._backbone_class.from_pretrained(self._backbone_name_or_path,
**self.config.backbone_kwargs)
self.config.backbone_config = self.backbone.config
else:
self.image_processor = AutoImageProcessor.from_pretrained(kwargs['image_processor_name_or_path'])
self.backbone = AutoModel.from_config(self.config.backbone_config)
head_dim = self.config.vocab_size - len(IMAGE_INDICATOR_IDS) # reserved tokens for IMAGE_INDICATORS
self.head = torch.nn.Sequential(
torch.nn.Linear(
self.backbone.config.hidden_size * self.config.hidden_stride * self.config.hidden_stride, head_dim,
bias=False
),
torch.nn.LayerNorm(head_dim)
)
assert all((self.image_processor.do_resize,
not getattr(self.image_processor, 'do_center_crop', False),
self.image_processor.do_rescale,
self.image_processor.do_normalize
)), f"image_processor `{self.image_processor}` is not supported currently"
def get_backbone(self):
return self.backbone
def get_monitor_tensors(self):
raise NotImplementedError
def get_image_processor(self):
return self.image_processor
def mock_input(self):
height, width = self.get_image_size()
return torch.zeros(1, 3, height, width), self.construct_image_placeholders((1, 1))
def get_head(self):
return self.head
def get_image_size(self):
raise NotImplementedError
@staticmethod
def construct_image_placeholders(grid):
image_placeholders = [IMAGE_INDICATOR_IDS[0], IMAGE_ATOM_ID, IMAGE_INDICATOR_IDS[1]]
if grid[0] * grid[1] > 1:
for r in range(grid[0]):
for c in range(grid[1]):
image_placeholders.append(IMAGE_ATOM_ID)
if c < grid[1] - 1:
image_placeholders.append(IMAGE_INDICATOR_IDS[2])
if r < grid[0] - 1:
image_placeholders.append(IMAGE_INDICATOR_IDS[3])
image_placeholders.append(IMAGE_INDICATOR_IDS[4])
return image_placeholders
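# Example (added note): for grid=(2, 2) the placeholder sequence is
#   [-301, -300, -302,        # whole-image atom wrapped by opening indicators
#    -300, -303, -300,        # row 0: two atoms separated by a column indicator
#    -304,                    # row separator
#    -300, -303, -300,        # row 1
#    -305]                    # closing indicator
# i.e. one IMAGE_ATOM_ID per frame in pixel_values (whole image + 4 crops), which
# merge_multimodal later expands into visual-token embeddings.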
def preprocess_image(self, image: PIL.Image.Image, max_partition=9, covering_threshold=0.9, convert_to_rgb=True):
def _preprocess(img: PIL.Image.Image, side):
# first resize and preprocess
w, h = img.size
if w == h:
new_width = new_height = side
elif w > h:
new_width = side
new_height = int(h / w * new_width)
else:
new_height = side
new_width = int(w / h * new_height)
new_size = dict(height=new_height, width=new_width)
pixel_values = self.image_processor.preprocess(img, size=new_size, return_tensors='pt')['pixel_values']
# then pad to square
square_values = torch.zeros([1, 3, side, side], dtype=pixel_values.dtype, device=pixel_values.device)
new_height, new_width = pixel_values.shape[2:]
if new_height == new_width:
square_values[:, :, :, :] = pixel_values
elif new_height > new_width:
from_index = (side - new_width) // 2
square_values[:, :, :, from_index:from_index + new_width] = pixel_values
else:
from_index = (side - new_height) // 2
square_values[:, :, from_index:from_index + new_height, :] = pixel_values
return square_values
def _partition(img, grid):
w, h = img.size
row_height = h // grid[0]
col_width = w // grid[1]
partition = []
for row in range(grid[0]):
for col in range(grid[1]):
left = col * col_width
upper = row * row_height
right = w if col == grid[1] - 1 else (col + 1) * col_width
lower = h if row == grid[0] - 1 else (row + 1) * row_height
partition.append((left, upper, right, lower))
return partition
def _covering_area(left, upper, right, lower, side):
w = right - left
h = lower - upper
w, h = max(w, h), min(w, h)
if w > side:
h = h / w * side
w = side
return w * h
def _get_best_grid(img, side):
img_area = img.size[0] * img.size[1]
candidate_grids = []
for i in range(1, max_partition + 1):
for j in range(1, max_partition + 1):
if i * j <= max_partition:
candidate_grids.append((i, j))
all_grids = []
good_grids = []
for grid in candidate_grids:
partition = _partition(img, grid)
covering_ratio = sum([_covering_area(*p, side) for p in partition]) / img_area
assert covering_ratio <= 1.0
all_grids.append((grid, covering_ratio))
if covering_ratio > covering_threshold:
good_grids.append((grid, covering_ratio))
if len(good_grids) > 0:
# pick the good partition with minimum #sub_images and break the tie using covering_ratio
return sorted(good_grids, key=lambda x: (x[0][0] * x[0][1], -x[1]))[0][0]
else:
# pick the partition with maximum covering_ratio and break the tie using #sub_images
return sorted(all_grids, key=lambda x: (-x[1], x[0][0] * x[0][1]))[0][0]
if convert_to_rgb and image.mode != 'RGB':
image = image.convert('RGB')
sides = self.get_image_size()
if sides[0] != sides[1]:
raise ValueError('get_image_size() returns non-square size')
side = sides[0]
grid = _get_best_grid(image, side)
partition = _partition(image, grid)
crops = [image.crop(p) for p in partition]
if len(crops) > 1:
crops.insert(0, image)
pixel_values = torch.cat([_preprocess(crop, side) for crop in crops], dim=0)
image_placeholders = self.construct_image_placeholders(grid)
return pixel_values, image_placeholders
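# Added note on preprocess_image: candidate grids are all (rows, cols) with
# rows * cols <= max_partition. Grids whose crops (with the longer edge clamped to `side`)
# cover more than `covering_threshold` of the original area are preferred, and among those
# the one with the fewest crops wins; otherwise the grid with the largest covering ratio
# is used. When more than one crop is produced, the whole image is prepended as an extra frame.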
def get_backbone_layer(self, index):
return self.backbone.vision_model.encoder.layers[index]
def tokenize(self, logits):
def st_argmax(y_soft, dim):  # straight-through argmax
index = y_soft.max(dim, keepdim=True)[1]
y_hard = torch.zeros_like(y_soft, memory_format=torch.legacy_contiguous_format).scatter_(dim, index, 1.0)
ret = y_hard - y_soft.detach() + y_soft
return ret
if self.config.tokenize_function == 'softmax':
tokens = softmax(logits, dim=-1)
elif self.config.tokenize_function == 'gumbel_argmax':
tokens = gumbel_softmax(logits, tau=self.config.tau, hard=True)
elif self.config.tokenize_function == 'st_argmax':
tokens = st_argmax(logits, dim=-1)
else:
raise ValueError(
f'Invalid `tokenize_function`, expected softmax, gumbel_argmax or st_argmax, but got {self.config.tokenize_function}')
return tokens
def encode(self, pixel_values):
output = self.backbone(pixel_values, output_hidden_states=True, return_dict=True)
features = output.hidden_states[-1]
if self.config.drop_cls_token:
features = features[:, 1:, :]
# merge `hidden_stride * hidden_stride` neighboring hidden states to reduce the token sequence length
# e.g., for hidden_stride=3, this leads to a token length reduction: 729 -> 81 for siglip
if self.config.hidden_stride > 1:
n, l, d = features.shape  # this `d` may differ from the `d` above
sqrt_l = int(l ** 0.5)
assert sqrt_l ** 2 == l, "The token sequence length should be a perfect square."
features = features.reshape(n, sqrt_l, sqrt_l, d)
pl = (self.config.hidden_stride - (sqrt_l % self.config.hidden_stride)) % self.config.hidden_stride
features = pad(features, (0, 0, 0, pl, 0, pl), "constant", 0)
sqrt_l += pl
features = features.reshape(n, sqrt_l // self.config.hidden_stride, self.config.hidden_stride,
sqrt_l // self.config.hidden_stride, self.config.hidden_stride, d)
features = features.permute(0, 1, 3, 2, 4, 5) # [n, sqrt_l/hs, sqrt_l/hs, hs, hs, d]
features = features.flatten(3) # [n, sqrt_l/hs, sqrt_l/hs, hs*hs*d]
features = features.reshape(
n, -1, self.config.hidden_stride * self.config.hidden_stride * d)
return features
def forward(self, pixel_values) -> torch.Tensor: # [BatchSize, ImageShape] -> [BatchSize, #Token, VocabSize]
features = self.encode(pixel_values)
logits = self.head(features)
tokens = self.tokenize(logits)
# tokens' shape is [BatchSize, #Token, VocabSize - 5], so pad with a zero tensor of shape
# [BatchSize, #Token, 5]; afterwards tokens' shape becomes [BatchSize, #Token, VocabSize]
batch_size, token_len, _ = tokens.shape
padding_tensor = torch.zeros(size=(batch_size, token_len, len(IMAGE_INDICATOR_IDS)),
dtype=tokens.dtype,
device=tokens.device,
layout=tokens.layout,
requires_grad=False)
tokens = torch.cat((tokens, padding_tensor), dim=2)
return tokens
from transformers import AutoConfig, AutoModel
from transformers import CLIPVisionModel, CLIPImageProcessor
from .base_visual_tokenizer import BaseVisualTokenizerConfig, BaseVisualTokenizer
MODEL_TYPE = "clip_visual_tokenizer"
class ClipVisualTokenizerConfig(BaseVisualTokenizerConfig):
model_type = MODEL_TYPE
def __init__(self, **kwargs):
super().__init__(**kwargs)
if self.depths:
assert len(self.depths) == 1
self.backbone_kwargs['num_hidden_layers'] = self.depths[0]
class ClipVisualTokenizer(BaseVisualTokenizer):
config_class = ClipVisualTokenizerConfig
supports_gradient_checkpointing = True
_no_split_modules = ["CLIPEncoderLayer"]
_image_processor_class = CLIPImageProcessor
_image_processor_kwargs = dict(do_center_crop=False)
_backbone_class = CLIPVisionModel
_backbone_name_or_path = "openai/clip-vit-large-patch14-336"
def get_monitor_tensors(self):
return dict(
backbone_bottom=self.backbone.vision_model.encoder.layers[0].self_attn.k_proj.weight,
backbone_top=self.backbone.vision_model.encoder.layers[-1].self_attn.out_proj.weight,
head=self.head[0].weight
)
def get_image_size(self):
height = self.image_processor.crop_size["height"]
width = self.image_processor.crop_size["width"]
return height, width
AutoConfig.register(MODEL_TYPE, ClipVisualTokenizerConfig)
AutoModel.register(ClipVisualTokenizerConfig, ClipVisualTokenizer)
from transformers import AutoConfig, AutoModel
from transformers import SiglipVisionModel, SiglipImageProcessor
from .base_visual_tokenizer import BaseVisualTokenizerConfig, BaseVisualTokenizer
MODEL_TYPE = "siglip_visual_tokenizer"
class SiglipVisualTokenizerConfig(BaseVisualTokenizerConfig):
model_type = MODEL_TYPE
def __init__(self, **kwargs):
super().__init__(**kwargs)
if self.drop_cls_token:
self.drop_cls_token = False
if self.depths:
assert len(self.depths) == 1
self.backbone_kwargs['num_hidden_layers'] = self.depths[0]
class SiglipVisualTokenizer(BaseVisualTokenizer):
config_class = SiglipVisualTokenizerConfig
supports_gradient_checkpointing = True
_no_split_modules = ["SiglipVisionTransformer"]
_image_processor_class = SiglipImageProcessor
_image_processor_kwargs = {}
_backbone_class = SiglipVisionModel
_backbone_name_or_path = "google/siglip-so400m-patch14-384"
def get_monitor_tensors(self):
return dict(
backbone_bottom=self.backbone.vision_model.encoder.layers[0].self_attn.k_proj.weight,
backbone_top=self.backbone.vision_model.encoder.layers[-1].self_attn.out_proj.weight,
head=self.head[0].weight
)
def get_image_size(self):
height = self.image_processor.size["height"]
width = self.image_processor.size["width"]
return height, width
AutoConfig.register(MODEL_TYPE, SiglipVisualTokenizerConfig)
AutoModel.register(SiglipVisualTokenizerConfig, SiglipVisualTokenizer)
from dataclasses import field, dataclass
from typing import Optional, Union, List
import torch
from PIL import Image
from ovis.model.modeling_ovis import Ovis
from ovis.util.constants import IMAGE_TOKEN
import time
@dataclass
class RunnerArguments:
model_path: str
max_new_tokens: int = field(default=512)
do_sample: bool = field(default=False)
top_p: Optional[float] = field(default=None)
top_k: Optional[int] = field(default=None)
temperature: Optional[float] = field(default=None)
max_partition: int = field(default=9)
class OvisRunner:
def __init__(self, args: RunnerArguments):
self.model_path = args.model_path
self.dtype = torch.bfloat16
self.device = torch.cuda.current_device()
self.model = Ovis.from_pretrained(self.model_path, torch_dtype=self.dtype, multimodal_max_length=8192)
self.model = self.model.eval().to(device=self.device)
self.eos_token_id = self.model.generation_config.eos_token_id
self.text_tokenizer = self.model.get_text_tokenizer()
self.pad_token_id = self.text_tokenizer.pad_token_id
self.visual_tokenizer = self.model.get_visual_tokenizer()
self.conversation_formatter = self.model.get_conversation_formatter()
self.image_placeholder = IMAGE_TOKEN
self.max_partition = args.max_partition
self.gen_kwargs = dict(
max_new_tokens=args.max_new_tokens,
do_sample=args.do_sample,
top_p=args.top_p,
top_k=args.top_k,
temperature=args.temperature,
repetition_penalty=None,
eos_token_id=self.eos_token_id,
pad_token_id=self.pad_token_id,
use_cache=True
)
def preprocess(self, inputs: List[Union[Image.Image, str]]):
# for a single image plus a single text input, ensure the image comes first
if len(inputs) == 2 and isinstance(inputs[0], str) and isinstance(inputs[1], Image.Image):
inputs = reversed(inputs)
# build query
query = ''
images = []
for data in inputs:
if isinstance(data, Image.Image):
query += self.image_placeholder + '\n'
images.append(data)
elif isinstance(data, str):
query += data.replace(self.image_placeholder, '')
elif data is not None:
raise RuntimeError(f'Invalid input type, expected `PIL.Image.Image` or `str`, but got {type(data)}')
# format conversation
prompt, input_ids, pixel_values = self.model.preprocess_inputs(
query, images, max_partition=self.max_partition)
attention_mask = torch.ne(input_ids, self.text_tokenizer.pad_token_id)
input_ids = input_ids.unsqueeze(0).to(device=self.device)
attention_mask = attention_mask.unsqueeze(0).to(device=self.device)
if pixel_values is not None:
pixel_values = [pixel_values.to(device=self.device, dtype=self.dtype)]
else:
pixel_values = [None]
return prompt, input_ids, attention_mask, pixel_values
def run(self, inputs: List[Union[Image.Image, str]]):
prompt, input_ids, attention_mask, pixel_values = self.preprocess(inputs)
start_time = time.time()
output_ids = self.model.generate(
input_ids,
pixel_values=pixel_values,
attention_mask=attention_mask,
**self.gen_kwargs
)
output = self.text_tokenizer.decode(output_ids[0], skip_special_tokens=True)
end_time = time.time()
print(f"times****:{end_time-start_time}")
input_token_len = input_ids.shape[1]
output_token_len = output_ids.shape[1]
response = dict(
prompt=prompt,
output=output,
prompt_tokens=input_token_len,
total_tokens=input_token_len + output_token_len
)
return response
if __name__ == '__main__':
runner_args = RunnerArguments(model_path='AIDC-AI/Ovis1.6-Gemma2-9B')
runner = OvisRunner(runner_args)
image = Image.open('OST_120.png')
text = 'Please describe this image'
response = runner.run([image, text])
print(response['output'])
import argparse
import os.path
import gradio as gr
from gradio.components import Textbox, Image
from ovis.serve.runner import RunnerArguments, OvisRunner
class Server:
def __init__(self, runner: OvisRunner):
self.runner = runner
def __call__(self, image, text):
response = self.runner.run([image, text])
output = response["output"]
return output
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Ovis Server')
parser.add_argument('--model_path', type=str, required=True)
parser.add_argument('--flagging_dir', type=str, default=os.path.expanduser('~/ovis-flagged'))
parser.add_argument('--max_partition', type=int, default=9)
parser.add_argument('--port', type=int, required=True)
args = parser.parse_args()
os.makedirs(args.flagging_dir, exist_ok=True)
runner_args = RunnerArguments(
model_path=args.model_path,
max_partition=args.max_partition
)
demo = gr.Interface(
fn=Server(OvisRunner(runner_args)),
inputs=[Image(type='pil', label='image'),
Textbox(placeholder='Enter your text here...', label='prompt')],
outputs=gr.Markdown(),
title=args.model_path.split('/')[-1],
flagging_dir=args.flagging_dir
)
demo.launch(server_port=args.port)
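# Example launch (added note; the script path, model path and port are placeholders):
#   python ovis/serve/server.py --model_path AIDC-AI/Ovis1.6-Gemma2-9B --port 7860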
from dataclasses import dataclass, field
from typing import Optional
import transformers
@dataclass
class ModelArguments:
llm_name_or_path: Optional[str] = field(default=None)
visual_tokenizer_type: str = field(default=None)
visual_vocab_size: int = field(default=8192)
visual_drop_cls_token: bool = field(default=False)
visual_tokenize_function: str = field(default='softmax')
visual_tau: float = field(default=1.0)
visual_depths: Optional[str] = field(default=None)
visual_hidden_stride: int = field(default=1)
multimodal_max_length: int = field(default=2048)
conversation_formatter_class: str = field(default=None)
pad_token_id: Optional[int] = field(default=None)
llm_attn_implementation: Optional[str] = field(default=None)
disable_tie_weight: bool = field(default=False)
@dataclass
class TrainingArguments(transformers.TrainingArguments):
dataset_names: Optional[str] = field(default=None) # a|b|c
dataset_info: Optional[str] = field(default='dataset_info_v1_6')
ovis_pretrained_path: Optional[str] = field(default=None)
visual_tokenizer_pretrained_path: Optional[str] = field(default=None)
caption_template: Optional[str] = field(default=None)
stage: Optional[int] = field(default=None)
train_modules: Optional[str] = field(default=None)
cache_dir: Optional[str] = field(default=None)
optim: str = field(default="adamw_torch")
visual_max_tau: float = field(default=5.0)
visual_min_tau: float = field(default=0.05)
save_safetensors: bool = field(default=True)
monitor_step: int = field(default=100)
vte_re_init: bool = field(default=False)
text_max_length: int = field(default=1024)
max_partitions: str = field(default="9|1|1")
def __post_init__(self):
if self.gradient_checkpointing:
self.gradient_checkpointing_kwargs = {"use_reentrant": False}
if self.stage < 3:
self.save_safetensors = False
super().__post_init__()
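# Added note: several fields above are pipe-separated strings, e.g. (hypothetical values)
#   --dataset_names  pretrain_caption|sft_conversation
#   --train_modules  visual_tokenizer|vte     (or `all` for full finetuning)
#   --max_partitions 9|1|1                    (single-image / multi-image / reserved)
#   --visual_depths  12                       (truncate the visual backbone to 12 layers)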
import deepspeed
import torch
from transformers import TrainerCallback, TrainingArguments, TrainerState, TrainerControl
from ovis.util.constants import END_LINE, BEGIN_LINE
from ovis.util.utils import rank0_print
class TuneTauCallback(TrainerCallback):
def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
visual_tokenizer = kwargs['model'].get_visual_tokenizer()
current_step = state.global_step
max_step = state.max_steps
ratio = current_step / max_step
visual_tokenizer.config.tau = args.visual_max_tau - (args.visual_max_tau - args.visual_min_tau) * ratio
class MonitorCallback(TrainerCallback):
def _monitoring(self, model, step):
with torch.no_grad():
with deepspeed.zero.GatheredParameters(model.get_monitor_tensors().values()):
for k, v in model.get_monitor_tensors().items():
rank0_print(BEGIN_LINE)
rank0_print(f'{k} @ step {step} with sum: {v.sum().item()} and content: ')
rank0_print(v)
rank0_print(END_LINE)
def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
model = kwargs['model']
step = state.global_step
if step % args.monitor_step == 0 or step == 10: # monitor at step 10 for fast check
self._monitoring(model, step)
def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
model = kwargs['model']
step = state.global_step
self._monitoring(model, step)
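# Added note: TuneTauCallback linearly anneals the visual tokenizer's Gumbel-softmax
# temperature over training:
#   tau(step) = visual_max_tau - (visual_max_tau - visual_min_tau) * step / max_steps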
import logging
from datetime import datetime
from typing import Dict
import pandas
import torch
from ovis.train.dataset.multimodal_dataset import MultimodalDataset
from ovis.util.constants import IMAGE_TOKEN, IGNORE_ID
from ovis.util.utils import rank0_print
class CaptionDataset(MultimodalDataset):
def load(self):
rank0_print(f"[{datetime.now()}] Loading dataset {self.name} from {self.meta_file} begin")
samples = pandas.read_parquet(self.meta_file, engine='pyarrow')
rank0_print(f"[{datetime.now()}] Loading dataset {self.name} end")
return samples
def __getitem__(self, i: int) -> Dict[str, torch.Tensor]:
sample = self.samples.iloc[i]
text = sample['caption']
image_path = sample['image_path']
# read and preprocess image
pixel_values, image_placeholders = self.visual_tokenizer.mock_input()
valid_image = False
image, e = self.read_image(image_path)
if image is None:
logging.warning(
f'reading image failed with index: {i}, image path: {image_path}, and exception: {e}')
else:
try:
pixel_values, image_placeholders = self.visual_tokenizer.preprocess_image(
image, max_partition=self.max_partitions[0])
valid_image = True
except Exception as e:
logging.warning(
f'preprocessing image failed with index: {i}, image path: {image_path}, and exception: {e}')
# preprocess text
if text is None:
logging.warning(f'text is `None`, index: {i}')
text = ""
if not valid_image:
logging.warning(f'image is invalid, so the text is set to empty, index: {i}, image path: {image_path}')
text = ""
text = text.replace(IMAGE_TOKEN, '').strip()
head, tail = self.caption_template.split(IMAGE_TOKEN)
head_ids = self.text_tokenizer(head, add_special_tokens=False).input_ids
tail_ids = self.text_tokenizer(tail, add_special_tokens=False).input_ids
text_ids = self.text_tokenizer(text, add_special_tokens=False).input_ids
input_ids = head_ids + image_placeholders + tail_ids + text_ids
labels = [IGNORE_ID] * (len(input_ids) - len(text_ids)) + text_ids
input_ids = input_ids[:self.text_max_length]
labels = labels[:self.text_max_length]
input_ids = torch.tensor(input_ids, dtype=torch.long)
labels = torch.tensor(labels, dtype=torch.long)
return dict(
pixel_values=pixel_values,
input_ids=input_ids,
labels=labels
)
import copy
import json
import logging
from datetime import datetime
from typing import Dict
import torch
from ovis.train.dataset.multimodal_dataset import MultimodalDataset
from ovis.util.utils import rank0_print
class ConversationDataset(MultimodalDataset):
def load(self):
rank0_print(f"[{datetime.now()}] Loading dataset {self.name} from {self.meta_file} begin")
with open(self.meta_file, 'r', encoding='utf-8') as f:
samples = json.load(f)
rank0_print(f'#samples: {len(samples)}')
rank0_print(f'sample: {samples[0]}')
rank0_print(f"[{datetime.now()}] Loading dataset {self.name} end")
return samples
def __getitem__(self, i: int) -> Dict[str, torch.Tensor]:
sample = self.samples[i]
conversations = copy.deepcopy(sample["conversations"])
images = None
max_partition = None
if 'image' in sample:
image_paths = sample['image']
if isinstance(image_paths, str):
image_paths = [image_paths]
images = []
for image_path in image_paths:
image, e = self.read_image(image_path)
if image is None:
logging.warning(
f'reading image failed with index: {i}, image path: {image_path}, and exception: {e}')
images = None
break
images.append(image)
elif 'video' in sample:
raise RuntimeError('video is not yet supported')
if images:
max_partition = self.max_partitions[0] if len(images) == 1 else self.max_partitions[1]
prompt, input_ids, pixel_values, labels = self.model.preprocess_inputs(
conversations,
images,
max_partition=max_partition,
generation_preface=None,
return_labels=True,
propagate_exception=False
)
if pixel_values is None:
pixel_values, _ = self.visual_tokenizer.mock_input()
input_ids = input_ids[:self.text_max_length]
labels = labels[:self.text_max_length]
return dict(
pixel_values=pixel_values,
input_ids=input_ids,
labels=labels
)
import logging
import os
from typing import Dict, Sequence, Union, List
import torch
from PIL import Image
from torch.utils.data import Dataset
from transformers import PreTrainedTokenizer
from ovis.model.modeling_ovis import Ovis
from ovis.train.arguments import TrainingArguments
from ovis.util.constants import IGNORE_ID
class MultimodalDataset(Dataset):
def __init__(self, name: str, info: Dict, model: Ovis, training_args: TrainingArguments):
self.name = name
self.meta_file = info['meta_file']
self.image_dir = info['image_dir']
self.caption_template = info.get('caption_template', None)
self.text_tokenizer = model.get_text_tokenizer()
self.visual_tokenizer = model.get_visual_tokenizer()
self.image_height, self.image_width = self.visual_tokenizer.get_image_size()
self.model = model
self.text_max_length = training_args.text_max_length
self.max_partitions = [int(m.strip()) for m in training_args.max_partitions.split('|')]
self.samples = self.load()
def load(self):
raise NotImplementedError
def __getitem__(self, i: int) -> Dict[str, torch.Tensor]:
raise NotImplementedError
def __len__(self):
return len(self.samples)
def read_image(self, path):
try:
full_path = os.path.join(self.image_dir, path)
image = Image.open(full_path).convert('RGB')
return image, None
except Exception as e:
return None, e
class DataCollatorForMultimodalDataset:
def __init__(self, text_tokenizer: PreTrainedTokenizer):
self.text_tokenizer = text_tokenizer
def __call__(self, instances: Sequence[Dict]) -> Dict[str, Union[torch.Tensor, List[torch.Tensor]]]:
pixel_values, input_ids, labels = tuple([instance[key] for instance in instances]
for key in ("pixel_values", "input_ids", "labels"))
input_ids = torch.nn.utils.rnn.pad_sequence(
input_ids,
batch_first=True,
padding_value=self.text_tokenizer.pad_token_id)
attention_mask = torch.ne(input_ids, self.text_tokenizer.pad_token_id)
labels = torch.nn.utils.rnn.pad_sequence(
labels,
batch_first=True,
padding_value=IGNORE_ID)
num_valid_label = torch.not_equal(labels, IGNORE_ID).sum().item()
if num_valid_label == 0:
logging.warning(
f'[DataCollatorForMultimodalDataset] All labels in a batch are ignored, which may lead to training instability\n{input_ids=}\n{attention_mask=}\n{labels=}')
return dict(
input_ids=input_ids,
attention_mask=attention_mask,
labels=labels,
pixel_values=pixel_values
)
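# Illustrative collator output (added note; the shapes are hypothetical): for a batch of
# two samples with 12 and 9 tokens, DataCollatorForMultimodalDataset returns
#   input_ids:      LongTensor [2, 12]   right-padded with pad_token_id
#   attention_mask: BoolTensor [2, 12]   False at padding positions
#   labels:         LongTensor [2, 12]   padded with IGNORE_ID (-100)
#   pixel_values:   list of 2 tensors, each [num_sub_images, 3, H, W]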
import json
import os
import pathlib
import deepspeed
import torch
import transformers
from deepspeed import get_accelerator
from torch.utils.data import ConcatDataset
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel, AutoConfig
from transformers import Trainer
from transformers.integrations.deepspeed import unset_hf_deepspeed_config, set_hf_deepspeed_config
from callback import TuneTauCallback, MonitorCallback
from ovis.model.configuration_ovis import OvisConfig
from ovis.model.modeling_ovis import Ovis
from ovis.train.arguments import ModelArguments, TrainingArguments
from ovis.train.dataset.caption_dataset import CaptionDataset
from ovis.train.dataset.conversation_dataset import ConversationDataset
from ovis.train.dataset.multimodal_dataset import DataCollatorForMultimodalDataset
from ovis.util.constants import BEGIN_LINE, END_LINE
from ovis.util.utils import smart_unit, rank0_print
def train():
# parse args
parser = transformers.HfArgumentParser(
(ModelArguments, TrainingArguments))
model_args, training_args = parser.parse_args_into_dataclasses()
# save args to checkpoint dir
with training_args.main_process_first(local=False):
if training_args.process_index == 0:
def args2dict(args):
return {k: str(v) for k, v in args.__dict__.items()}
args_log = json.dumps(dict(
model_args=args2dict(model_args),
training_args=args2dict(training_args)
), ensure_ascii=False, indent=2)
print(args_log)
os.makedirs(training_args.output_dir, exist_ok=True)
with open(os.path.join(training_args.output_dir, 'model_training_args.json'), 'w',
encoding='utf-8') as f:
f.write(args_log + '\n')
# construct or load ovis model
if not training_args.ovis_pretrained_path: # construct model (S1)
# 1. construct ovis config
ovis_config = OvisConfig(
multimodal_max_length=model_args.multimodal_max_length,
conversation_formatter_class=model_args.conversation_formatter_class,
llm_attn_implementation=model_args.llm_attn_implementation
)
# 2. load pretrained llm and text tokenizer
attn_kwargs = dict()
if model_args.llm_attn_implementation:
attn_kwargs['attn_implementation'] = model_args.llm_attn_implementation
llm = AutoModelForCausalLM.from_pretrained(model_args.llm_name_or_path, **attn_kwargs)
text_tokenizer = AutoTokenizer.from_pretrained(model_args.llm_name_or_path)
if text_tokenizer.pad_token_id is None and model_args.pad_token_id is not None:
text_tokenizer.pad_token_id = model_args.pad_token_id
# 3. construct visual tokenizer
# deepspeed zero.Init with bfloat16 fails for visual_tokenizer, so temporarily disable zero.Init here
unset_hf_deepspeed_config()
if training_args.visual_tokenizer_pretrained_path is not None:
visual_tokenizer = AutoModel.from_pretrained(
training_args.visual_tokenizer_pretrained_path,
image_processor_name_or_path=training_args.visual_tokenizer_pretrained_path
)
else:
visual_tokenizer_config = AutoConfig.for_model(
model_type=model_args.visual_tokenizer_type + "_visual_tokenizer",
vocab_size=model_args.visual_vocab_size,
tokenize_function=model_args.visual_tokenize_function,
tau=model_args.visual_tau,
depths=model_args.visual_depths,
drop_cls_token=model_args.visual_drop_cls_token,
hidden_stride=model_args.visual_hidden_stride,
)
visual_tokenizer = AutoModel.from_config(visual_tokenizer_config, train_from_scratch=True)
visual_tokenizer = visual_tokenizer.to(
device=torch.device(get_accelerator().device_name(os.getenv("LOCAL_RANK"))))
if getattr(training_args, 'hf_deepspeed_config', None) is not None:
set_hf_deepspeed_config(training_args.hf_deepspeed_config)
# 4. construct ovis model
model = Ovis(ovis_config, llm=llm, text_tokenizer=text_tokenizer, visual_tokenizer=visual_tokenizer,
train_from_scratch=True)
else: # load pretrained ovis model
model, loading_info = Ovis.from_pretrained(training_args.ovis_pretrained_path,
multimodal_max_length=model_args.multimodal_max_length,
output_loading_info=True)
rank0_print(BEGIN_LINE)
rank0_print(f'Loading info of Ovis:\n{loading_info}')
rank0_print(END_LINE)
training_args.vte_re_init = False
model.get_llm().config.use_cache = False
model.config.use_cache = False
text_tokenizer = model.get_text_tokenizer()
rank0_print(BEGIN_LINE)
rank0_print(f'model.config:\n{model.config}')
rank0_print(END_LINE)
# maybe re-init vte
if training_args.vte_re_init:
with deepspeed.zero.GatheredParameters([model.get_wte().weight]):
mean = model.get_wte().weight.mean().item()
std = model.get_wte().weight.std().item()
rank0_print(f'Statistics of embedding table of LLM: {mean=}, {std=}')
model.re_init_vte(mean, std)
# select train modules
model.requires_grad_(False)
for module in training_args.train_modules.split('|'):
if module == 'all':
model.requires_grad_(True)
elif module == 'llm':
model.get_llm().requires_grad_(True)
elif module == 'visual_tokenizer':
model.get_visual_tokenizer().requires_grad_(True)
elif module == 'visual_tokenizer.backbone':
model.get_visual_tokenizer().get_backbone().requires_grad_(True)
elif module.startswith('visual_tokenizer.backbone.layer.'):
layer_index = int(module[len('visual_tokenizer.backbone.layer.'):])
layer = model.get_visual_tokenizer().get_backbone_layer(layer_index)
layer.requires_grad_(True)
elif module == 'visual_tokenizer.head':
model.get_visual_tokenizer().get_head().requires_grad_(True)
elif module == 'vte':
model.get_vte().requires_grad_(True)
else:
raise ValueError(f'Invalid train module name: {module}')
rank0_print(BEGIN_LINE)
rank0_print('Parameters to train:')
for name, param in model.named_parameters():
if param.requires_grad:
rank0_print(name)
rank0_print(f'LLM\'s attn implementation: {model.get_llm().config._attn_implementation}')
rank0_print(END_LINE)
# construct data module
datasets = []
dataset_info_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),
f'dataset/{training_args.dataset_info}.json')
with open(dataset_info_path, 'r', encoding='utf-8') as f:
dataset_info = json.load(f)
for name in training_args.dataset_names.split('|'):
info = dataset_info[name]
data_format = info['data_format']
if data_format == 'caption':
dataset = CaptionDataset(name, info, model, training_args)
elif data_format == 'conversation':
dataset = ConversationDataset(name, info, model, training_args)
else:
raise ValueError(f'Invalid data format `{data_format}` for dataset `{name}`')
datasets.append(dataset)
data_module = dict(
train_dataset=ConcatDataset(datasets),
data_collator=DataCollatorForMultimodalDataset(text_tokenizer)
)
# train
train_callbacks = [MonitorCallback]
if model_args.visual_tokenize_function == 'gumbel_argmax':
train_callbacks.append(TuneTauCallback)
trainer = Trainer(
model=model,
args=training_args,
callbacks=train_callbacks,
**data_module
)
rank0_print(BEGIN_LINE)
rank0_print('Dataset sample tensor:')
rank0_print(data_module['train_dataset'][0])
rank0_print(END_LINE)
rank0_print(BEGIN_LINE)
rank0_print('Dataset sample input_ids decoding:')
rank0_print(text_tokenizer.decode([x for x in data_module['train_dataset'][0]['input_ids'] if x >= 0]))
rank0_print(END_LINE)
rank0_print(BEGIN_LINE)
rank0_print('Dataset sample labels decoding:')
rank0_print(text_tokenizer.decode([x for x in data_module['train_dataset'][0]['labels'] if x >= 0]))
rank0_print(END_LINE)
rank0_print(BEGIN_LINE)
rank0_print(f'#param of model: {smart_unit(model.num_parameters())}')
rank0_print(f'#param of llm: {smart_unit(model.get_llm().num_parameters())}')
rank0_print(f'#param of visual_tokenizer: {smart_unit(model.get_visual_tokenizer().num_parameters())}')
rank0_print(f'#param of vte: {smart_unit(model.get_vte().weight.numel())}')
rank0_print(END_LINE)
if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
trainer.train(resume_from_checkpoint=True)
else:
trainer.train()
trainer.save_state()
# save model
model.get_llm().config.use_cache = True
model.config.use_cache = True
trainer.save_model()
if __name__ == '__main__':
train()
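# Example launch (added note; all paths, names and hyper-parameters are placeholders):
#   deepspeed ovis/train/train.py \
#       --llm_name_or_path Qwen/Qwen2.5-7B-Instruct \
#       --visual_tokenizer_type siglip \
#       --conversation_formatter_class QwenConversationFormatter \
#       --dataset_names my_caption_set \
#       --train_modules visual_tokenizer|vte \
#       --stage 1 \
#       --output_dir checkpoints/ovis-s1 \
#       --deepspeed zero_stage3.json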
# Model Constants
IGNORE_ID = -100
IMAGE_TOKEN_ID = -200
IMAGE_TOKEN = "<image>"
IMAGE_ATOM_ID = -300
IMAGE_INDICATOR_IDS = [-301, -302, -303, -304, -305]
# Log & Print
BEGIN_LINE = '========================************========================'
END_LINE = '------------------------------------------------------------'
import os
from importlib import import_module
def rank0_print(*args):
if int(os.getenv("LOCAL_PROCESS_RANK", os.getenv("LOCAL_RANK", 0))) == 0:
print(*args)
def smart_unit(num):
if num / 1.0e9 >= 1:
return f'{num / 1.0e9:.2f}B'
else:
return f'{num / 1.0e6:.2f}M'
def import_class_from_string(full_class_string):
# Split the path to get separate module and class names
module_path, _, class_name = full_class_string.rpartition('.')
# Import the module using the module path
module = import_module(module_path)
# Get the class from the imported module
cls = getattr(module, class_name)
return cls
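# Example (added note):
#   cls = import_class_from_string('ovis.model.modeling_ovis.Ovis')
#   # equivalent to: from ovis.model.modeling_ovis import Ovis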