Commit f91d2ea3 authored by mashun1's avatar mashun1
Browse files

hunyuandit

parents
# Copyright 2023 Haotian Liu
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
from transformers import AutoConfig, AutoModelForCausalLM, \
LlamaConfig, LlamaModel, LlamaForCausalLM
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.generation.utils import GenerateOutput
from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
class LlavaConfig(LlamaConfig):
model_type = "llava_llama"
class LlavaLlamaModel(LlavaMetaModel, LlamaModel):
config_class = LlavaConfig
def __init__(self, config: LlamaConfig):
super(LlavaLlamaModel, self).__init__(config)
class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM):
config_class = LlavaConfig
def __init__(self, config):
super(LlamaForCausalLM, self).__init__(config)
self.model = LlavaLlamaModel(config)
self.pretraining_tp = config.pretraining_tp
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Initialize weights and apply final processing
self.post_init()
def get_model(self):
return self.model
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
images: Optional[torch.FloatTensor] = None,
image_sizes: Optional[List[List[int]]] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
if inputs_embeds is None:
(
input_ids,
position_ids,
attention_mask,
past_key_values,
inputs_embeds,
labels
) = self.prepare_inputs_labels_for_multimodal(
input_ids,
position_ids,
attention_mask,
past_key_values,
labels,
images,
image_sizes
)
return super().forward(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
labels=labels,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict
)
@torch.no_grad()
def generate(
self,
inputs: Optional[torch.Tensor] = None,
images: Optional[torch.Tensor] = None,
image_sizes: Optional[torch.Tensor] = None,
**kwargs,
) -> Union[GenerateOutput, torch.LongTensor]:
position_ids = kwargs.pop("position_ids", None)
attention_mask = kwargs.pop("attention_mask", None)
if "inputs_embeds" in kwargs:
raise NotImplementedError("`inputs_embeds` is not supported")
if images is not None:
(
inputs,
position_ids,
attention_mask,
_,
inputs_embeds,
_
) = self.prepare_inputs_labels_for_multimodal(
inputs,
position_ids,
attention_mask,
None,
None,
images,
image_sizes=image_sizes
)
else:
inputs_embeds = self.get_model().embed_tokens(inputs)
return super().generate(
position_ids=position_ids,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
**kwargs
)
def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
inputs_embeds=None, **kwargs):
images = kwargs.pop("images", None)
image_sizes = kwargs.pop("image_sizes", None)
inputs = super().prepare_inputs_for_generation(
input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
)
if images is not None:
inputs['images'] = images
if image_sizes is not None:
inputs['image_sizes'] = image_sizes
return inputs
AutoConfig.register("llava_llama", LlavaConfig)
AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM)
# Copyright 2023 Haotian Liu
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from transformers import AutoConfig, AutoModelForCausalLM, \
MistralConfig, MistralModel, MistralForCausalLM
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.generation.utils import GenerateOutput
from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
class LlavaMistralConfig(MistralConfig):
model_type = "llava_mistral"
class LlavaMistralModel(LlavaMetaModel, MistralModel):
config_class = LlavaMistralConfig
def __init__(self, config: MistralConfig):
super(LlavaMistralModel, self).__init__(config)
class LlavaMistralForCausalLM(MistralForCausalLM, LlavaMetaForCausalLM):
config_class = LlavaMistralConfig
def __init__(self, config):
super(MistralForCausalLM, self).__init__(config)
self.model = LlavaMistralModel(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Initialize weights and apply final processing
self.post_init()
def get_model(self):
return self.model
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
images: Optional[torch.FloatTensor] = None,
image_sizes: Optional[List[List[int]]] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
if inputs_embeds is None:
(
input_ids,
position_ids,
attention_mask,
past_key_values,
inputs_embeds,
labels
) = self.prepare_inputs_labels_for_multimodal(
input_ids,
position_ids,
attention_mask,
past_key_values,
labels,
images,
image_sizes
)
return super().forward(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
labels=labels,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict
)
@torch.no_grad()
def generate(
self,
inputs: Optional[torch.Tensor] = None,
images: Optional[torch.Tensor] = None,
image_sizes: Optional[torch.Tensor] = None,
**kwargs,
) -> Union[GenerateOutput, torch.LongTensor]:
position_ids = kwargs.pop("position_ids", None)
attention_mask = kwargs.pop("attention_mask", None)
if "inputs_embeds" in kwargs:
raise NotImplementedError("`inputs_embeds` is not supported")
if images is not None:
(
inputs,
position_ids,
attention_mask,
_,
inputs_embeds,
_
) = self.prepare_inputs_labels_for_multimodal(
inputs,
position_ids,
attention_mask,
None,
None,
images,
image_sizes=image_sizes
)
else:
inputs_embeds = self.get_model().embed_tokens(inputs)
return super().generate(
position_ids=position_ids,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
**kwargs
)
def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
inputs_embeds=None, **kwargs):
images = kwargs.pop("images", None)
image_sizes = kwargs.pop("image_sizes", None)
inputs = super().prepare_inputs_for_generation(
input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
)
if images is not None:
inputs['images'] = images
if image_sizes is not None:
inputs['image_sizes'] = image_sizes
return inputs
AutoConfig.register("llava_mistral", LlavaMistralConfig)
AutoModelForCausalLM.register(LlavaMistralConfig, LlavaMistralForCausalLM)
# Copyright 2023 Haotian Liu
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple
import torch
from transformers import AutoConfig, AutoModelForCausalLM, \
MptConfig, MptForCausalLM, MptModel
from llava.model.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
class LlavaMptConfig(MptConfig):
model_type = "llava_mpt"
class LlavaMptModel(LlavaMetaModel, MptModel):
config_class = LlavaMptConfig
def __init__(self, config: MptConfig):
config.hidden_size = config.d_model
super(LlavaMptModel, self).__init__(config)
def embed_tokens(self, x):
return self.wte(x)
class LlavaMptForCausalLM(MptForCausalLM, LlavaMetaForCausalLM):
config_class = LlavaMptConfig
supports_gradient_checkpointing = True
def __init__(self, config):
super(MptForCausalLM, self).__init__(config)
self.transformer = LlavaMptModel(config)
self.lm_head = torch.nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Initialize weights and apply final processing
self.post_init()
def get_model(self):
return self.transformer
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, LlavaMptModel):
module.gradient_checkpointing = value
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
attention_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
images=None):
input_ids, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images)
return super().forward(
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
labels=labels,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
images = kwargs.pop("images", None)
_inputs = super().prepare_inputs_for_generation(
input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
)
_inputs['images'] = images
return _inputs
AutoConfig.register("llava_mpt", LlavaMptConfig)
AutoModelForCausalLM.register(LlavaMptConfig, LlavaMptForCausalLM)
# Copyright 2023 Haotian Liu
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
import torch
import torch.nn as nn
from .multimodal_encoder.builder import build_vision_tower
from .multimodal_projector.builder import build_vision_projector
from llava.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.mm_utils import get_anyres_image_grid_shape
class LlavaMetaModel:
def __init__(self, config):
super(LlavaMetaModel, self).__init__(config)
if hasattr(config, "mm_vision_tower"):
self.vision_tower = build_vision_tower(config, delay_load=True)
self.mm_projector = build_vision_projector(config)
if 'unpad' in getattr(config, 'mm_patch_merge_type', ''):
self.image_newline = nn.Parameter(
torch.empty(config.hidden_size, dtype=self.dtype)
)
def get_vision_tower(self):
vision_tower = getattr(self, 'vision_tower', None)
if type(vision_tower) is list:
vision_tower = vision_tower[0]
return vision_tower
def initialize_vision_modules(self, model_args, fsdp=None):
vision_tower = model_args.vision_tower
mm_vision_select_layer = model_args.mm_vision_select_layer
mm_vision_select_feature = model_args.mm_vision_select_feature
pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter
mm_patch_merge_type = model_args.mm_patch_merge_type
self.config.mm_vision_tower = vision_tower
if self.get_vision_tower() is None:
vision_tower = build_vision_tower(model_args)
if fsdp is not None and len(fsdp) > 0:
self.vision_tower = [vision_tower]
else:
self.vision_tower = vision_tower
else:
if fsdp is not None and len(fsdp) > 0:
vision_tower = self.vision_tower[0]
else:
vision_tower = self.vision_tower
vision_tower.load_model()
self.config.use_mm_proj = True
self.config.mm_projector_type = getattr(model_args, 'mm_projector_type', 'linear')
self.config.mm_hidden_size = vision_tower.hidden_size
self.config.mm_vision_select_layer = mm_vision_select_layer
self.config.mm_vision_select_feature = mm_vision_select_feature
self.config.mm_patch_merge_type = mm_patch_merge_type
if getattr(self, 'mm_projector', None) is None:
self.mm_projector = build_vision_projector(self.config)
if 'unpad' in mm_patch_merge_type:
embed_std = 1 / torch.sqrt(torch.tensor(self.config.hidden_size, dtype=self.dtype))
self.image_newline = nn.Parameter(
torch.randn(self.config.hidden_size, dtype=self.dtype) * embed_std
)
else:
# In case it is frozen by LoRA
for p in self.mm_projector.parameters():
p.requires_grad = True
if pretrain_mm_mlp_adapter is not None:
mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')
def get_w(weights, keyword):
return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k}
self.mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector'))
def unpad_image(tensor, original_size):
"""
Unpads a PyTorch tensor of a padded and resized image.
Args:
tensor (torch.Tensor): The image tensor, assumed to be in CxHxW format.
original_size (tuple): The original size of the image (height, width).
Returns:
torch.Tensor: The unpadded image tensor.
"""
original_width, original_height = original_size
current_height, current_width = tensor.shape[1:]
original_aspect_ratio = original_width / original_height
current_aspect_ratio = current_width / current_height
if original_aspect_ratio > current_aspect_ratio:
scale_factor = current_width / original_width
new_height = int(original_height * scale_factor)
padding = (current_height - new_height) // 2
unpadded_tensor = tensor[:, padding:current_height - padding, :]
else:
scale_factor = current_height / original_height
new_width = int(original_width * scale_factor)
padding = (current_width - new_width) // 2
unpadded_tensor = tensor[:, :, padding:current_width - padding]
return unpadded_tensor
class LlavaMetaForCausalLM(ABC):
@abstractmethod
def get_model(self):
pass
def get_vision_tower(self):
return self.get_model().get_vision_tower()
def encode_images(self, images):
image_features = self.get_model().get_vision_tower()(images)
image_features = self.get_model().mm_projector(image_features)
return image_features
def prepare_inputs_labels_for_multimodal(
self, input_ids, position_ids, attention_mask, past_key_values, labels,
images, image_sizes=None
):
vision_tower = self.get_vision_tower()
if vision_tower is None or images is None or input_ids.shape[1] == 1:
return input_ids, position_ids, attention_mask, past_key_values, None, labels
if type(images) is list or images.ndim == 5:
if type(images) is list:
images = [x.unsqueeze(0) if x.ndim == 3 else x for x in images]
concat_images = torch.cat([image for image in images], dim=0)
image_features = self.encode_images(concat_images)
split_sizes = [image.shape[0] for image in images]
image_features = torch.split(image_features, split_sizes, dim=0)
mm_patch_merge_type = getattr(self.config, 'mm_patch_merge_type', 'flat')
image_aspect_ratio = getattr(self.config, 'image_aspect_ratio', 'square')
if mm_patch_merge_type == 'flat':
image_features = [x.flatten(0, 1) for x in image_features]
elif mm_patch_merge_type.startswith('spatial'):
new_image_features = []
for image_idx, image_feature in enumerate(image_features):
if image_feature.shape[0] > 1:
base_image_feature = image_feature[0]
image_feature = image_feature[1:]
height = width = self.get_vision_tower().num_patches_per_side
assert height * width == base_image_feature.shape[0]
if image_aspect_ratio == 'anyres':
num_patch_width, num_patch_height = get_anyres_image_grid_shape(image_sizes[image_idx], self.config.image_grid_pinpoints, self.get_vision_tower().config.image_size)
image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1)
else:
raise NotImplementedError
if 'unpad' in mm_patch_merge_type:
image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
image_feature = image_feature.flatten(1, 2).flatten(2, 3)
image_feature = unpad_image(image_feature, image_sizes[image_idx])
image_feature = torch.cat((
image_feature,
self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device)
), dim=-1)
image_feature = image_feature.flatten(1, 2).transpose(0, 1)
else:
image_feature = image_feature.permute(0, 2, 1, 3, 4).contiguous()
image_feature = image_feature.flatten(0, 3)
image_feature = torch.cat((base_image_feature, image_feature), dim=0)
else:
image_feature = image_feature[0]
if 'unpad' in mm_patch_merge_type:
image_feature = torch.cat((
image_feature,
self.model.image_newline[None].to(image_feature.device)
), dim=0)
new_image_features.append(image_feature)
image_features = new_image_features
else:
raise ValueError(f"Unexpected mm_patch_merge_type: {self.config.mm_patch_merge_type}")
else:
image_features = self.encode_images(images)
# TODO: image start / end is not implemented here to support pretraining.
if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
raise NotImplementedError
# Let's just add dummy tensors if they do not exist,
# it is a headache to deal with None all the time.
# But it is not ideal, and if you have a better idea,
# please open an issue / submit a PR, thanks.
_labels = labels
_position_ids = position_ids
_attention_mask = attention_mask
if attention_mask is None:
attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
else:
attention_mask = attention_mask.bool()
if position_ids is None:
position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
if labels is None:
labels = torch.full_like(input_ids, IGNORE_INDEX)
# remove the padding using attention_mask -- FIXME
_input_ids = input_ids
input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)]
labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]
new_input_embeds = []
new_labels = []
cur_image_idx = 0
for batch_idx, cur_input_ids in enumerate(input_ids):
num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
if num_images == 0:
cur_image_features = image_features[cur_image_idx]
cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0)
new_input_embeds.append(cur_input_embeds)
new_labels.append(labels[batch_idx])
cur_image_idx += 1
continue
image_token_indices = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
cur_input_ids_noim = []
cur_labels = labels[batch_idx]
cur_labels_noim = []
for i in range(len(image_token_indices) - 1):
cur_input_ids_noim.append(cur_input_ids[image_token_indices[i]+1:image_token_indices[i+1]])
cur_labels_noim.append(cur_labels[image_token_indices[i]+1:image_token_indices[i+1]])
split_sizes = [x.shape[0] for x in cur_labels_noim]
cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim))
cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
cur_new_input_embeds = []
cur_new_labels = []
for i in range(num_images + 1):
cur_new_input_embeds.append(cur_input_embeds_no_im[i])
cur_new_labels.append(cur_labels_noim[i])
if i < num_images:
cur_image_features = image_features[cur_image_idx]
cur_image_idx += 1
cur_new_input_embeds.append(cur_image_features)
cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype))
cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds]
cur_new_input_embeds = torch.cat(cur_new_input_embeds)
cur_new_labels = torch.cat(cur_new_labels)
new_input_embeds.append(cur_new_input_embeds)
new_labels.append(cur_new_labels)
# Truncate sequences to max length as image embeddings can make the sequence longer
tokenizer_model_max_length = getattr(self.config, 'tokenizer_model_max_length', None)
if tokenizer_model_max_length is not None:
new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
new_labels = [x[:tokenizer_model_max_length] for x in new_labels]
# Combine them
max_len = max(x.shape[0] for x in new_input_embeds)
batch_size = len(new_input_embeds)
new_input_embeds_padded = []
new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device)
attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)
for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
cur_len = cur_new_embed.shape[0]
if getattr(self.config, 'tokenizer_padding_side', 'right') == "left":
new_input_embeds_padded.append(torch.cat((
torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device),
cur_new_embed
), dim=0))
if cur_len > 0:
new_labels_padded[i, -cur_len:] = cur_new_labels
attention_mask[i, -cur_len:] = True
position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
else:
new_input_embeds_padded.append(torch.cat((
cur_new_embed,
torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)
), dim=0))
if cur_len > 0:
new_labels_padded[i, :cur_len] = cur_new_labels
attention_mask[i, :cur_len] = True
position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)
if _labels is None:
new_labels = None
else:
new_labels = new_labels_padded
if _attention_mask is None:
attention_mask = None
else:
attention_mask = attention_mask.to(dtype=_attention_mask.dtype)
if _position_ids is None:
position_ids = None
return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels
def initialize_vision_tokenizer(self, model_args, tokenizer):
if model_args.mm_use_im_patch_token:
tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
self.resize_token_embeddings(len(tokenizer))
if model_args.mm_use_im_start_end:
num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
self.resize_token_embeddings(len(tokenizer))
if num_new_tokens > 0:
input_embeddings = self.get_input_embeddings().weight.data
output_embeddings = self.get_output_embeddings().weight.data
input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
dim=0, keepdim=True)
output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
dim=0, keepdim=True)
input_embeddings[-num_new_tokens:] = input_embeddings_avg
output_embeddings[-num_new_tokens:] = output_embeddings_avg
if model_args.tune_mm_mlp_adapter:
for p in self.get_input_embeddings().parameters():
p.requires_grad = True
for p in self.get_output_embeddings().parameters():
p.requires_grad = False
if model_args.pretrain_mm_mlp_adapter:
mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location='cpu')
embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight']
assert num_new_tokens == 2
if input_embeddings.shape == embed_tokens_weight.shape:
input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:]
elif embed_tokens_weight.shape[0] == num_new_tokens:
input_embeddings[-num_new_tokens:] = embed_tokens_weight
else:
raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}.")
elif model_args.mm_use_im_patch_token:
if model_args.tune_mm_mlp_adapter:
for p in self.get_input_embeddings().parameters():
p.requires_grad = False
for p in self.get_output_embeddings().parameters():
p.requires_grad = False
"""
Usage:
python3 -m llava.model.make_delta --base ~/model_weights/llama-7b --target ~/model_weights/llava-7b --delta ~/model_weights/llava-7b-delta --hub-repo-id liuhaotian/llava-7b-delta
"""
import argparse
import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM
from llava.model.utils import auto_upgrade
def make_delta(base_model_path, target_model_path, delta_path, hub_repo_id):
print("Loading base model")
base = AutoModelForCausalLM.from_pretrained(
base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
print("Loading target model")
auto_upgrade(target_model_path)
target = AutoModelForCausalLM.from_pretrained(target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
print("Calculating delta")
for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"):
if name not in base.state_dict():
assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model'
continue
if param.data.shape == base.state_dict()[name].shape:
param.data -= base.state_dict()[name]
else:
assert name in ['model.embed_tokens.weight', 'lm_head.weight'], f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}'
bparam = base.state_dict()[name]
param.data[:bparam.shape[0], :bparam.shape[1]] -= bparam
print("Saving delta")
if hub_repo_id:
kwargs = {"push_to_hub": True, "repo_id": hub_repo_id}
else:
kwargs = {}
target.save_pretrained(delta_path, **kwargs)
target_tokenizer = AutoTokenizer.from_pretrained(target_model_path)
target_tokenizer.save_pretrained(delta_path, **kwargs)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--base-model-path", type=str, required=True)
parser.add_argument("--target-model-path", type=str, required=True)
parser.add_argument("--delta-path", type=str, required=True)
parser.add_argument("--hub-repo-id", type=str, default=None)
args = parser.parse_args()
make_delta(args.base_model_path, args.target_model_path, args.delta_path, args.hub_repo_id)
import os
from .clip_encoder import CLIPVisionTower
def build_vision_tower(vision_tower_cfg, **kwargs):
vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None))
is_absolute_path_exists = os.path.exists(vision_tower)
if is_absolute_path_exists or vision_tower.startswith("openai") or vision_tower.startswith("laion") or "ShareGPT4V" in vision_tower:
return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
raise ValueError(f'Unknown vision tower: {vision_tower}')
import torch
import torch.nn as nn
from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig
from pathlib import Path
ckpt_dir = Path(__file__).resolve().parent.parent.parent.parent.parent / "ckpts" / "dialoggen"
class CLIPVisionTower(nn.Module):
def __init__(self, vision_tower, args, delay_load=False):
super().__init__()
self.is_loaded = False
self.vision_tower_name = str(ckpt_dir / vision_tower)
self.select_layer = args.mm_vision_select_layer
self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
if not delay_load:
self.load_model()
elif getattr(args, 'unfreeze_mm_vision_tower', False):
self.load_model()
else:
self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)
def load_model(self, device_map=None):
if self.is_loaded:
print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name))
return
self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map)
self.vision_tower.requires_grad_(False)
self.is_loaded = True
def feature_select(self, image_forward_outs):
image_features = image_forward_outs.hidden_states[self.select_layer]
if self.select_feature == 'patch':
image_features = image_features[:, 1:]
elif self.select_feature == 'cls_patch':
image_features = image_features
else:
raise ValueError(f'Unexpected select feature: {self.select_feature}')
return image_features
@torch.no_grad()
def forward(self, images):
if type(images) is list:
image_features = []
for image in images:
image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
image_feature = self.feature_select(image_forward_out).to(image.dtype)
image_features.append(image_feature)
else:
image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
image_features = self.feature_select(image_forward_outs).to(images.dtype)
return image_features
@property
def dummy_feature(self):
return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
@property
def dtype(self):
return self.vision_tower.dtype
@property
def device(self):
return self.vision_tower.device
@property
def config(self):
if self.is_loaded:
return self.vision_tower.config
else:
return self.cfg_only
@property
def hidden_size(self):
return self.config.hidden_size
@property
def num_patches_per_side(self):
return self.config.image_size // self.config.patch_size
@property
def num_patches(self):
return (self.config.image_size // self.config.patch_size) ** 2
import torch
import torch.nn as nn
import re
class IdentityMap(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x, *args, **kwargs):
return x
@property
def config(self):
return {"mm_projector_type": 'identity'}
class SimpleResBlock(nn.Module):
def __init__(self, channels):
super().__init__()
self.pre_norm = nn.LayerNorm(channels)
self.proj = nn.Sequential(
nn.Linear(channels, channels),
nn.GELU(),
nn.Linear(channels, channels)
)
def forward(self, x):
x = self.pre_norm(x)
return x + self.proj(x)
def build_vision_projector(config, delay_load=False, **kwargs):
projector_type = getattr(config, 'mm_projector_type', 'linear')
if projector_type == 'linear':
return nn.Linear(config.mm_hidden_size, config.hidden_size)
mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
if mlp_gelu_match:
mlp_depth = int(mlp_gelu_match.group(1))
modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
for _ in range(1, mlp_depth):
modules.append(nn.GELU())
modules.append(nn.Linear(config.hidden_size, config.hidden_size))
return nn.Sequential(*modules)
if projector_type == 'identity':
return IdentityMap()
raise ValueError(f'Unknown projector type: {projector_type}')
from transformers import AutoConfig
def auto_upgrade(config):
cfg = AutoConfig.from_pretrained(config)
if 'llava' in config and 'llava' not in cfg.model_type:
assert cfg.model_type == 'llama'
print("You are using newer LLaVA code base, while the checkpoint of v0 is from older code base.")
print("You must upgrade the checkpoint to the new code base (this can be done automatically).")
confirm = input("Please confirm that you want to upgrade the checkpoint. [Y/N]")
if confirm.lower() in ["y", "yes"]:
print("Upgrading checkpoint...")
assert len(cfg.architectures) == 1
setattr(cfg.__class__, "model_type", "llava")
cfg.architectures[0] = 'LlavaLlamaForCausalLM'
cfg.save_pretrained(config)
print("Checkpoint upgraded.")
else:
print("Checkpoint upgrade aborted.")
exit(1)
import datetime
import logging
import logging.handlers
import os
import sys
import requests
from llava.constants import LOGDIR
server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN."
handler = None
def build_logger(logger_name, logger_filename):
global handler
formatter = logging.Formatter(
fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
# Set the format of root handlers
if not logging.getLogger().handlers:
logging.basicConfig(level=logging.INFO)
logging.getLogger().handlers[0].setFormatter(formatter)
# Redirect stdout and stderr to loggers
stdout_logger = logging.getLogger("stdout")
stdout_logger.setLevel(logging.INFO)
sl = StreamToLogger(stdout_logger, logging.INFO)
sys.stdout = sl
stderr_logger = logging.getLogger("stderr")
stderr_logger.setLevel(logging.ERROR)
sl = StreamToLogger(stderr_logger, logging.ERROR)
sys.stderr = sl
# Get logger
logger = logging.getLogger(logger_name)
logger.setLevel(logging.INFO)
# Add a file handler for all loggers
if handler is None:
os.makedirs(LOGDIR, exist_ok=True)
filename = os.path.join(LOGDIR, logger_filename)
handler = logging.handlers.TimedRotatingFileHandler(
filename, when='D', utc=True, encoding='UTF-8')
handler.setFormatter(formatter)
for name, item in logging.root.manager.loggerDict.items():
if isinstance(item, logging.Logger):
item.addHandler(handler)
return logger
class StreamToLogger(object):
"""
Fake file-like stream object that redirects writes to a logger instance.
"""
def __init__(self, logger, log_level=logging.INFO):
self.terminal = sys.stdout
self.logger = logger
self.log_level = log_level
self.linebuf = ''
def __getattr__(self, attr):
return getattr(self.terminal, attr)
def write(self, buf):
temp_linebuf = self.linebuf + buf
self.linebuf = ''
for line in temp_linebuf.splitlines(True):
# From the io.TextIOWrapper docs:
# On output, if newline is None, any '\n' characters written
# are translated to the system default line separator.
# By default sys.stdout.write() expects '\n' newlines and then
# translates them so this is still cross platform.
if line[-1] == '\n':
self.logger.log(self.log_level, line.rstrip())
else:
self.linebuf += line
def flush(self):
if self.linebuf != '':
self.logger.log(self.log_level, self.linebuf.rstrip())
self.linebuf = ''
def disable_torch_init():
"""
Disable the redundant torch default initialization to accelerate model creation.
"""
import torch
setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
def violates_moderation(text):
"""
Check whether the text violates OpenAI moderation API.
"""
url = "https://api.openai.com/v1/moderations"
headers = {"Content-Type": "application/json",
"Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]}
text = text.replace("\n", "")
data = "{" + '"input": ' + f'"{text}"' + "}"
data = data.encode("utf-8")
try:
ret = requests.post(url, headers=headers, data=data, timeout=5)
flagged = ret.json()["results"][0]["flagged"]
except requests.exceptions.RequestException as e:
flagged = False
except KeyError as e:
flagged = False
return flagged
def pretty_print_semaphore(semaphore):
if semaphore is None:
return "None"
return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"
name: HunyuanDiT
channels:
- pytorch
- nvidia
dependencies:
- python=3.8.12
- pytorch=1.13.1
- pip
一只聪明的狐狸走在阔叶树林里, 旁边是一条小溪, 细节真实, 摄影
湖水清澈,天空湛蓝,阳光灿烂。一只优雅的白天鹅在湖边游泳。它周围有几只小鸭子,看起来非常可爱,整个画面给人一种宁静祥和的感觉。
太阳微微升起,花园里的玫瑰花瓣上露珠晶莹剔透,一只瓢虫正在爬向露珠,背景是清晨的花园,微距镜头
一位女明星,中国人,头发是黑色,衣服是纯白色短袖,人物风格清新,城市背景
后印象主义风格,一条古老的石板路上面散落着金黄色的树叶。路旁的风车在静谧地转动,后面竖着两个风车。背景是一片向日葵田,蓝天上飘着几朵白云
一幅细致的油画描绘了一只年轻獾轻轻嗅着一朵明亮的黄色玫瑰时错综复杂的皮毛。背景是一棵大树干的粗糙纹理,獾的爪子轻轻地挖进树皮。在柔和的背景中,一个宁静的瀑布倾泻而下,它的水在绿色植物中闪烁着蓝色。
渔舟唱晚
请将杞人忧天的样子画出来
一只长靴猫手持亮银色的宝剑,身着铠甲,眼神坚毅,站在一堆金币上,背景是暗色调的洞穴,图像上有金币的光影点缀。
插画风格,一只狐狸和一只刺猬坐在水边的石头上,刺猬手里拿着一杯茶,狐狸旁边放着一个玻璃杯。周围是茂密的绿色植物和树木,阳光透过树叶洒在水面上,画面宁静温馨。
泥塑风格,一座五彩斑斓的花园在画面中展现,各种各样的花朵,绿色的叶子和一只正在嬉戏的小猫形成了一幅生动的图像,背景是蓝天和白云
枯藤老树昏鸦,小桥流水人家
一张细致的照片捕捉到了一尊雕像的形象,这尊雕像酷似一位古代法老,头上出人意料地戴着一副青铜蒸汽朋克护目镜。这座雕像穿着复古时髦,一件清爽的白色T恤和一件合身的黑色皮夹克,与传统的头饰形成鲜明对比。背景是简单的纯色,突出了雕像的非传统服装和蒸汽朋克眼镜的复杂细节。
一朵鲜艳的红色玫瑰花,花瓣撒有一些水珠,晶莹剔透,特写镜头,
一只可爱的猫, 细节真实, 摄影
飞流直下三千尺,疑是银河落九天
成语“鲤鱼跃龙门”
一颗新鲜的草莓特写,红色的外表,表面布满许多种子,背景是淡绿色的叶子
九寨沟
摄影风格,在画面中心是一盘热气腾腾的麻婆豆腐,豆腐呈白色,上面撒着一层红色的辣酱,有些许绿色的葱花点缀,背景是深色木质餐桌,桌子上放有辣椒和葱花作为点缀。
一位年轻女子站在春季的火车站月台上。她身着蓝灰色长风衣,白色衬衫。她的深棕色头发扎成低马尾,几缕碎发随风飘扬。她的眼神充满期待,阳光洒在她温暖的脸庞上。
一只优雅的白鹤在湖边静静地站立,它的身体纯白色,翅膀轻轻展开,背景是湖面和远处的山脉
国画风格,苏州园林中的小桥流水,周围是郁郁葱葱的树,池塘里有几朵绽放的荷花,背景是宁静的江南水乡
现实主义风格,画面主要描述一个巴洛克风格的花瓶,带有金色的装饰边框,花瓶上盛开着各种色彩鲜艳的花,白色背景
醉后不知天在水,满船清梦压星河
长城
一个亚洲中年男士在夕阳下的公园长椅上静坐。他穿着一件深蓝色的针织毛衣和灰色裤子。他的头发略显花白,手中拿着一本敞开的书。面带微笑,眼神温和,周围是落日余晖和四周的绿树。
风格是写实,画面主要描述一个亚洲戏曲艺术家正在表演,她穿着华丽的戏服,脸上戴着精致的面具,身姿优雅,背景是古色古香的舞台,镜头是近景
\ No newline at end of file
import argparse
from .constants import *
from .modules.models import HUNYUAN_DIT_CONFIG
def get_args(default_args=None):
parser = argparse.ArgumentParser()
# Basic
parser.add_argument("--prompt", type=str, default="一只小猫", help="The prompt for generating images.")
parser.add_argument("--model-root", type=str, default="ckpts", help="Model root path.")
parser.add_argument("--image-size", type=int, nargs='+', default=[1024, 1024],
help='Image size (h, w). If a single value is provided, the image will be treated to '
'(value, value).')
parser.add_argument("--infer-mode", type=str, choices=["fa", "torch", "trt"], default="torch",
help="Inference mode")
# HunYuan-DiT
parser.add_argument("--model", type=str, choices=list(HUNYUAN_DIT_CONFIG.keys()), default='DiT-g/2')
parser.add_argument("--norm", type=str, default="layer", help="Normalization layer type")
parser.add_argument("--load-key", type=str, choices=["ema", "module"], default="ema", help="Load model key for HunYuanDiT checkpoint.")
parser.add_argument('--size-cond', type=int, nargs='+', default=[1024, 1024],
help="Size condition used in sampling. 2 values are required for height and width. "
"If a single value is provided, the image will be treated to (value, value).")
parser.add_argument("--cfg-scale", type=float, default=6.0, help="Guidance scale for classifier-free.")
# Prompt enhancement
parser.add_argument("--enhance", action="store_true", help="Enhance prompt with dialoggen.")
parser.add_argument("--no-enhance", dest="enhance", action="store_false")
parser.add_argument("--load-4bit", help="load DialogGen model with 4bit quantization.", action="store_true")
parser.set_defaults(enhance=True)
# Diffusion
parser.add_argument("--learn-sigma", action="store_true", help="Learn extra channels for sigma.")
parser.add_argument("--no-learn-sigma", dest="learn_sigma", action="store_false")
parser.set_defaults(learn_sigma=True)
parser.add_argument("--predict-type", type=str, choices=list(PREDICT_TYPE), default="v_prediction",
help="Diffusion predict type")
parser.add_argument("--noise-schedule", type=str, choices=list(NOISE_SCHEDULES), default="scaled_linear",
help="Noise schedule")
parser.add_argument("--beta-start", type=float, default=0.00085, help="Beta start value")
parser.add_argument("--beta-end", type=float, default=0.03, help="Beta end value")
# Text condition
parser.add_argument("--text-states-dim", type=int, default=1024, help="Hidden size of CLIP text encoder.")
parser.add_argument("--text-len", type=int, default=77, help="Token length of CLIP text encoder output.")
parser.add_argument("--text-states-dim-t5", type=int, default=2048, help="Hidden size of CLIP text encoder.")
parser.add_argument("--text-len-t5", type=int, default=256, help="Token length of T5 text encoder output.")
parser.add_argument("--negative", type=str, default=None, help="Negative prompt.")
# Acceleration
parser.add_argument("--use_fp16", action="store_true", help="Use FP16 precision.")
parser.add_argument("--no-fp16", dest="use_fp16", action="store_false")
parser.set_defaults(use_fp16=True)
parser.add_argument("--onnx-workdir", type=str, default="onnx_model", help="Path to save ONNX model")
# Sampling
parser.add_argument("--batch-size", type=int, default=1, help="Per-GPU batch size")
parser.add_argument("--sampler", type=str, choices=SAMPLER_FACTORY, default="ddpm", help="Diffusion sampler")
parser.add_argument("--infer-steps", type=int, default=100, help="Inference steps")
parser.add_argument('--seed', type=int, default=42, help="A seed for all the prompts.")
# App
parser.add_argument("--lang", type=str, default="zh", choices=["zh", "en"], help="Language")
args = parser.parse_args(default_args)
return args
# =======================================================
NOISE_SCHEDULES = {
"linear",
"scaled_linear",
"squaredcos_cap_v2",
}
PREDICT_TYPE = {
"epsilon",
"sample",
"v_prediction",
}
# =======================================================
NEGATIVE_PROMPT = '错误的眼睛,糟糕的人脸,毁容,糟糕的艺术,变形,多余的肢体,模糊的颜色,模糊,重复,病态,残缺,'
# =======================================================
TRT_MAX_BATCH_SIZE = 1
TRT_MAX_WIDTH = 1280
TRT_MAX_HEIGHT = 1280
# =======================================================
# Constants about models
# =======================================================
SAMPLER_FACTORY = {
'ddpm': {
'scheduler': 'DDPMScheduler',
'name': 'DDPM',
'kwargs': {
'steps_offset': 1,
'clip_sample': False,
'clip_sample_range': 1.0,
'beta_schedule': 'scaled_linear',
'beta_start': 0.00085,
'beta_end': 0.03,
'prediction_type': 'v_prediction',
}
},
'ddim': {
'scheduler': 'DDIMScheduler',
'name': 'DDIM',
'kwargs': {
'steps_offset': 1,
'clip_sample': False,
'clip_sample_range': 1.0,
'beta_schedule': 'scaled_linear',
'beta_start': 0.00085,
'beta_end': 0.03,
'prediction_type': 'v_prediction',
}
},
'dpmms': {
'scheduler': 'DPMSolverMultistepScheduler',
'name': 'DPMMS',
'kwargs': {
'beta_schedule': 'scaled_linear',
'beta_start': 0.00085,
'beta_end': 0.03,
'prediction_type': 'v_prediction',
'trained_betas': None,
'solver_order': 2,
'algorithm_type': 'dpmsolver++',
}
},
}
This diff is collapsed.
import random
import time
from pathlib import Path
import numpy as np
import torch
# For reproducibility
# torch.backends.cudnn.benchmark = False
# torch.backends.cudnn.deterministic = True
from diffusers import schedulers
from diffusers.models import AutoencoderKL
from loguru import logger
from transformers import BertModel, BertTokenizer
from transformers.modeling_utils import logger as tf_logger
from .constants import SAMPLER_FACTORY, NEGATIVE_PROMPT, TRT_MAX_WIDTH, TRT_MAX_HEIGHT, TRT_MAX_BATCH_SIZE
from .diffusion.pipeline import StableDiffusionPipeline
from .modules.models import HunYuanDiT, HUNYUAN_DIT_CONFIG
from .modules.posemb_layers import get_2d_rotary_pos_embed, get_fill_resize_and_crop
from .modules.text_encoder import MT5Embedder
from .utils.tools import set_seeds
class Resolution:
def __init__(self, width, height):
self.width = width
self.height = height
def __str__(self):
return f'{self.height}x{self.width}'
class ResolutionGroup:
def __init__(self):
self.data = [
Resolution(768, 768), # 1:1
Resolution(1024, 1024), # 1:1
Resolution(1280, 1280), # 1:1
Resolution(1024, 768), # 4:3
Resolution(1152, 864), # 4:3
Resolution(1280, 960), # 4:3
Resolution(768, 1024), # 3:4
Resolution(864, 1152), # 3:4
Resolution(960, 1280), # 3:4
Resolution(1280, 768), # 16:9
Resolution(768, 1280), # 9:16
]
self.supported_sizes = set([(r.width, r.height) for r in self.data])
def is_valid(self, width, height):
return (width, height) in self.supported_sizes
STANDARD_RATIO = np.array([
1.0, # 1:1
4.0 / 3.0, # 4:3
3.0 / 4.0, # 3:4
16.0 / 9.0, # 16:9
9.0 / 16.0, # 9:16
])
STANDARD_SHAPE = [
[(768, 768), (1024, 1024), (1280, 1280)], # 1:1
[(1024, 768), (1152, 864), (1280, 960)], # 4:3
[(768, 1024), (864, 1152), (960, 1280)], # 3:4
[(1280, 768)], # 16:9
[(768, 1280)], # 9:16
]
STANDARD_AREA = [
np.array([w * h for w, h in shapes])
for shapes in STANDARD_SHAPE
]
def get_standard_shape(target_width, target_height):
"""
Map image size to standard size.
"""
target_ratio = target_width / target_height
closest_ratio_idx = np.argmin(np.abs(STANDARD_RATIO - target_ratio))
closest_area_idx = np.argmin(np.abs(STANDARD_AREA[closest_ratio_idx] - target_width * target_height))
width, height = STANDARD_SHAPE[closest_ratio_idx][closest_area_idx]
return width, height
def _to_tuple(val):
if isinstance(val, (list, tuple)):
if len(val) == 1:
val = [val[0], val[0]]
elif len(val) == 2:
val = tuple(val)
else:
raise ValueError(f"Invalid value: {val}")
elif isinstance(val, (int, float)):
val = (val, val)
else:
raise ValueError(f"Invalid value: {val}")
return val
def get_pipeline(args, vae, text_encoder, tokenizer, model, device, rank,
embedder_t5, infer_mode, sampler=None):
"""
Get scheduler and pipeline for sampling. The sampler and pipeline are both
based on diffusers and make some modifications.
Returns
-------
pipeline: StableDiffusionPipeline
sampler_name: str
"""
sampler = sampler or args.sampler
# Load sampler from factory
kwargs = SAMPLER_FACTORY[sampler]['kwargs']
scheduler = SAMPLER_FACTORY[sampler]['scheduler']
# Update sampler according to the arguments
kwargs['beta_schedule'] = args.noise_schedule
kwargs['beta_start'] = args.beta_start
kwargs['beta_end'] = args.beta_end
kwargs['prediction_type'] = args.predict_type
# Build scheduler according to the sampler.
scheduler_class = getattr(schedulers, scheduler)
scheduler = scheduler_class(**kwargs)
# Set timesteps for inference steps.
scheduler.set_timesteps(args.infer_steps, device)
# Only enable progress bar for rank 0
progress_bar_config = {} if rank == 0 else {'disable': True}
pipeline = StableDiffusionPipeline(vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
unet=model,
scheduler=scheduler,
feature_extractor=None,
safety_checker=None,
requires_safety_checker=False,
progress_bar_config=progress_bar_config,
embedder_t5=embedder_t5,
infer_mode=infer_mode,
)
pipeline = pipeline.to(device)
return pipeline, sampler
class End2End(object):
def __init__(self, args, models_root_path):
self.args = args
# Check arguments
t2i_root_path = Path(models_root_path) / "t2i"
self.root = t2i_root_path
logger.info(f"Got text-to-image model root path: {t2i_root_path}")
# Set device and disable gradient
self.device = "cuda" if torch.cuda.is_available() else "cpu"
torch.set_grad_enabled(False)
# Disable BertModel logging checkpoint info
tf_logger.setLevel('ERROR')
# ========================================================================
logger.info(f"Loading CLIP Text Encoder...")
text_encoder_path = self.root / "clip_text_encoder"
self.clip_text_encoder = BertModel.from_pretrained(str(text_encoder_path), False, revision=None).to(self.device)
logger.info(f"Loading CLIP Text Encoder finished")
# ========================================================================
logger.info(f"Loading CLIP Tokenizer...")
tokenizer_path = self.root / "tokenizer"
self.tokenizer = BertTokenizer.from_pretrained(str(tokenizer_path))
logger.info(f"Loading CLIP Tokenizer finished")
# ========================================================================
logger.info(f"Loading T5 Text Encoder and T5 Tokenizer...")
t5_text_encoder_path = self.root / 'mt5'
embedder_t5 = MT5Embedder(t5_text_encoder_path, torch_dtype=torch.float16, max_length=256)
self.embedder_t5 = embedder_t5
logger.info(f"Loading t5_text_encoder and t5_tokenizer finished")
# ========================================================================
logger.info(f"Loading VAE...")
vae_path = self.root / "sdxl-vae-fp16-fix"
self.vae = AutoencoderKL.from_pretrained(str(vae_path)).to(self.device)
logger.info(f"Loading VAE finished")
# ========================================================================
# Create model structure and load the checkpoint
logger.info(f"Building HunYuan-DiT model...")
model_config = HUNYUAN_DIT_CONFIG[self.args.model]
self.patch_size = model_config['patch_size']
self.head_size = model_config['hidden_size'] // model_config['num_heads']
self.resolutions, self.freqs_cis_img = self.standard_shapes() # Used for TensorRT models
self.image_size = _to_tuple(self.args.image_size)
latent_size = (self.image_size[0] // 8, self.image_size[1] // 8)
self.infer_mode = self.args.infer_mode
if self.infer_mode in ['fa', 'torch']:
model_dir = self.root / "model"
model_path = model_dir / f"pytorch_model_{self.args.load_key}.pt"
if not model_path.exists():
raise ValueError(f"model_path not exists: {model_path}")
# Build model structure
self.model = HunYuanDiT(self.args,
input_size=latent_size,
**model_config,
log_fn=logger.info,
).half().to(self.device) # Force to use fp16
# Load model checkpoint
logger.info(f"Loading torch model {model_path}...")
state_dict = torch.load(model_path, map_location=lambda storage, loc: storage)
self.model.load_state_dict(state_dict)
self.model.eval()
logger.info(f"Loading torch model finished")
elif self.infer_mode == 'trt':
from .modules.trt.hcf_model import TRTModel
trt_dir = self.root / "model_trt"
engine_dir = trt_dir / "engine"
plugin_path = trt_dir / "fmha_plugins/9.2_plugin_cuda11/fMHAPlugin.so"
model_name = "model_onnx"
logger.info(f"Loading TensorRT model {engine_dir}/{model_name}...")
self.model = TRTModel(model_name=model_name,
engine_dir=str(engine_dir),
image_height=TRT_MAX_HEIGHT,
image_width=TRT_MAX_WIDTH,
text_maxlen=args.text_len,
embedding_dim=args.text_states_dim,
plugin_path=str(plugin_path),
max_batch_size=TRT_MAX_BATCH_SIZE,
)
logger.info(f"Loading TensorRT model finished")
else:
raise ValueError(f"Unknown infer_mode: {self.infer_mode}")
# ========================================================================
# Build inference pipeline. We use a customized StableDiffusionPipeline.
logger.info(f"Loading inference pipeline...")
self.pipeline, self.sampler = self.load_sampler()
logger.info(f'Loading pipeline finished')
# ========================================================================
self.default_negative_prompt = NEGATIVE_PROMPT
logger.info("==================================================")
logger.info(f" Model is ready. ")
logger.info("==================================================")
def load_sampler(self, sampler=None):
pipeline, sampler = get_pipeline(self.args,
self.vae,
self.clip_text_encoder,
self.tokenizer,
self.model,
device=self.device,
rank=0,
embedder_t5=self.embedder_t5,
infer_mode=self.infer_mode,
sampler=sampler,
)
return pipeline, sampler
def calc_rope(self, height, width):
th = height // 8 // self.patch_size
tw = width // 8 // self.patch_size
base_size = 512 // 8 // self.patch_size
start, stop = get_fill_resize_and_crop((th, tw), base_size)
sub_args = [start, stop, (th, tw)]
rope = get_2d_rotary_pos_embed(self.head_size, *sub_args)
return rope
def standard_shapes(self):
resolutions = ResolutionGroup()
freqs_cis_img = {}
for reso in resolutions.data:
freqs_cis_img[str(reso)] = self.calc_rope(reso.height, reso.width)
return resolutions, freqs_cis_img
def predict(self,
user_prompt,
height=1024,
width=1024,
seed=None,
enhanced_prompt=None,
negative_prompt=None,
infer_steps=100,
guidance_scale=6,
batch_size=1,
src_size_cond=(1024, 1024),
sampler=None,
):
# ========================================================================
# Arguments: seed
# ========================================================================
if seed is None:
seed = random.randint(0, 1_000_000)
if not isinstance(seed, int):
raise TypeError(f"`seed` must be an integer, but got {type(seed)}")
generator = set_seeds(seed)
# ========================================================================
# Arguments: target_width, target_height
# ========================================================================
if width <= 0 or height <= 0:
raise ValueError(f"`height` and `width` must be positive integers, got height={height}, width={width}")
logger.info(f"Input (height, width) = ({height}, {width})")
if self.infer_mode in ['fa', 'torch']:
# We must force height and width to align to 16 and to be an integer.
target_height = int((height // 16) * 16)
target_width = int((width // 16) * 16)
logger.info(f"Align to 16: (height, width) = ({target_height}, {target_width})")
elif self.infer_mode == 'trt':
target_width, target_height = get_standard_shape(width, height)
logger.info(f"Align to standard shape: (height, width) = ({target_height}, {target_width})")
else:
raise ValueError(f"Unknown infer_mode: {self.infer_mode}")
# ========================================================================
# Arguments: prompt, new_prompt, negative_prompt
# ========================================================================
if not isinstance(user_prompt, str):
raise TypeError(f"`user_prompt` must be a string, but got {type(user_prompt)}")
user_prompt = user_prompt.strip()
prompt = user_prompt
if enhanced_prompt is not None:
if not isinstance(enhanced_prompt, str):
raise TypeError(f"`enhanced_prompt` must be a string, but got {type(enhanced_prompt)}")
enhanced_prompt = enhanced_prompt.strip()
prompt = enhanced_prompt
# negative prompt
if negative_prompt is None or negative_prompt == '':
negative_prompt = self.default_negative_prompt
if not isinstance(negative_prompt, str):
raise TypeError(f"`negative_prompt` must be a string, but got {type(negative_prompt)}")
# ========================================================================
# Arguments: style. (A fixed argument. Don't Change it.)
# ========================================================================
style = torch.as_tensor([0, 0] * batch_size, device=self.device)
# ========================================================================
# Inner arguments: image_meta_size (Please refer to SDXL.)
# ========================================================================
if isinstance(src_size_cond, int):
src_size_cond = [src_size_cond, src_size_cond]
if not isinstance(src_size_cond, (list, tuple)):
raise TypeError(f"`src_size_cond` must be a list or tuple, but got {type(src_size_cond)}")
if len(src_size_cond) != 2:
raise ValueError(f"`src_size_cond` must be a tuple of 2 integers, but got {len(src_size_cond)}")
size_cond = list(src_size_cond) + [target_width, target_height, 0, 0]
image_meta_size = torch.as_tensor([size_cond] * 2 * batch_size, device=self.device)
# ========================================================================
start_time = time.time()
logger.debug(f"""
prompt: {user_prompt}
enhanced prompt: {enhanced_prompt}
seed: {seed}
(height, width): {(target_height, target_width)}
negative_prompt: {negative_prompt}
batch_size: {batch_size}
guidance_scale: {guidance_scale}
infer_steps: {infer_steps}
image_meta_size: {size_cond}
""")
reso = f'{target_height}x{target_width}'
if reso in self.freqs_cis_img:
freqs_cis_img = self.freqs_cis_img[reso]
else:
freqs_cis_img = self.calc_rope(target_height, target_width)
if sampler is not None and sampler != self.sampler:
self.pipeline, self.sampler = self.load_sampler(sampler)
samples = self.pipeline(
height=target_height,
width=target_width,
prompt=prompt,
negative_prompt=negative_prompt,
num_images_per_prompt=batch_size,
guidance_scale=guidance_scale,
num_inference_steps=infer_steps,
image_meta_size=image_meta_size,
style=style,
return_dict=False,
generator=generator,
freqs_cis_img=freqs_cis_img,
use_fp16=self.args.use_fp16,
learn_sigma=self.args.learn_sigma,
)[0]
gen_time = time.time() - start_time
logger.debug(f"Success, time: {gen_time}")
return {
'images': samples,
'seed': seed,
}
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment