Commit 727428ec authored by jerrrrry

Initial commit CI/CD

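# ---------------------------------------------------------------------------
# LLaVA-MPT language model: wraps Hugging Face's MptModel/MptForCausalLM with
# the LLaVA multimodal mixins and registers the "llava_mpt" model type.
# ---------------------------------------------------------------------------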
# Copyright 2023 Haotian Liu
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple
import torch
from transformers import (
AutoConfig,
AutoModelForCausalLM,
MptConfig,
MptForCausalLM,
MptModel,
)
from llava.model.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
class LlavaMptConfig(MptConfig):
model_type = "llava_mpt"
class LlavaMptModel(LlavaMetaModel, MptModel):
config_class = LlavaMptConfig
def __init__(self, config: MptConfig):
config.hidden_size = config.d_model
super(LlavaMptModel, self).__init__(config)
def embed_tokens(self, x):
return self.wte(x)
class LlavaMptForCausalLM(MptForCausalLM, LlavaMetaForCausalLM):
config_class = LlavaMptConfig
supports_gradient_checkpointing = True
def __init__(self, config):
super(MptForCausalLM, self).__init__(config)
self.transformer = LlavaMptModel(config)
self.lm_head = torch.nn.Linear(
config.hidden_size, config.vocab_size, bias=False
)
# Initialize weights and apply final processing
self.post_init()
def get_model(self):
return self.transformer
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, LlavaMptModel):
module.gradient_checkpointing = value
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
attention_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
images=None,
):
        # prepare_inputs_labels_for_multimodal (defined in llava_arch below) also
        # takes position_ids and returns six values; MPT does not use position_ids,
        # so pass None and discard the returned one.
        (input_ids, _, attention_mask, past_key_values, inputs_embeds, labels) = (
            self.prepare_inputs_labels_for_multimodal(
                input_ids, None, attention_mask, past_key_values, labels, images
            )
        )
return super().forward(
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
labels=labels,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs
):
images = kwargs.pop("images", None)
_inputs = super().prepare_inputs_for_generation(
input_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
**kwargs
)
_inputs["images"] = images
return _inputs
AutoConfig.register("llava_mpt", LlavaMptConfig)
AutoModelForCausalLM.register(LlavaMptConfig, LlavaMptForCausalLM)
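# Minimal usage sketch (hypothetical checkpoint path): once the registrations
# above have run, a "llava_mpt" checkpoint resolves through the Auto classes:
#   model = AutoModelForCausalLM.from_pretrained("path/to/llava-mpt-checkpoint")
#
# ---------------------------------------------------------------------------
# LLaVA architecture mixins (imported above as llava.model.llava_arch):
# LlavaMetaModel owns the vision tower and mm_projector; LlavaMetaForCausalLM
# splices projected image features into the token embedding sequence.
# ---------------------------------------------------------------------------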
# Copyright 2023 Haotian Liu
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
import torch
import torch.nn as nn
from .multimodal_encoder.builder import build_vision_tower
from .multimodal_projector.builder import build_vision_projector
from llava.constants import (
IGNORE_INDEX,
IMAGE_TOKEN_INDEX,
DEFAULT_IMAGE_PATCH_TOKEN,
DEFAULT_IM_START_TOKEN,
DEFAULT_IM_END_TOKEN,
)
from llava.mm_utils import get_anyres_image_grid_shape
class LlavaMetaModel:
def __init__(self, config):
super(LlavaMetaModel, self).__init__(config)
if hasattr(config, "mm_vision_tower"):
self.vision_tower = build_vision_tower(config, delay_load=True)
self.mm_projector = build_vision_projector(config)
if "unpad" in getattr(config, "mm_patch_merge_type", ""):
self.image_newline = nn.Parameter(
torch.empty(config.hidden_size, dtype=self.dtype)
)
def get_vision_tower(self):
vision_tower = getattr(self, "vision_tower", None)
if type(vision_tower) is list:
vision_tower = vision_tower[0]
return vision_tower
def initialize_vision_modules(self, model_args, fsdp=None):
vision_tower = model_args.vision_tower
mm_vision_select_layer = model_args.mm_vision_select_layer
mm_vision_select_feature = model_args.mm_vision_select_feature
pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter
mm_patch_merge_type = model_args.mm_patch_merge_type
self.config.mm_vision_tower = vision_tower
if self.get_vision_tower() is None:
vision_tower = build_vision_tower(model_args)
if fsdp is not None and len(fsdp) > 0:
self.vision_tower = [vision_tower]
else:
self.vision_tower = vision_tower
else:
if fsdp is not None and len(fsdp) > 0:
vision_tower = self.vision_tower[0]
else:
vision_tower = self.vision_tower
vision_tower.load_model()
self.config.use_mm_proj = True
self.config.mm_projector_type = getattr(
model_args, "mm_projector_type", "linear"
)
self.config.mm_hidden_size = vision_tower.hidden_size
self.config.mm_vision_select_layer = mm_vision_select_layer
self.config.mm_vision_select_feature = mm_vision_select_feature
self.config.mm_patch_merge_type = mm_patch_merge_type
if getattr(self, "mm_projector", None) is None:
self.mm_projector = build_vision_projector(self.config)
if "unpad" in mm_patch_merge_type:
embed_std = 1 / torch.sqrt(
torch.tensor(self.config.hidden_size, dtype=self.dtype)
)
self.image_newline = nn.Parameter(
torch.randn(self.config.hidden_size, dtype=self.dtype) * embed_std
)
else:
# In case it is frozen by LoRA
for p in self.mm_projector.parameters():
p.requires_grad = True
if pretrain_mm_mlp_adapter is not None:
mm_projector_weights = torch.load(
pretrain_mm_mlp_adapter, map_location="cpu"
)
def get_w(weights, keyword):
return {
k.split(keyword + ".")[1]: v
for k, v in weights.items()
if keyword in k
}
self.mm_projector.load_state_dict(
get_w(mm_projector_weights, "mm_projector")
)
def unpad_image(tensor, original_size):
"""
Unpads a PyTorch tensor of a padded and resized image.
Args:
tensor (torch.Tensor): The image tensor, assumed to be in CxHxW format.
    original_size (tuple): The original size of the image as (width, height).
Returns:
torch.Tensor: The unpadded image tensor.
"""
original_width, original_height = original_size
current_height, current_width = tensor.shape[1:]
original_aspect_ratio = original_width / original_height
current_aspect_ratio = current_width / current_height
if original_aspect_ratio > current_aspect_ratio:
scale_factor = current_width / original_width
new_height = int(original_height * scale_factor)
padding = (current_height - new_height) // 2
unpadded_tensor = tensor[:, padding : current_height - padding, :]
else:
scale_factor = current_height / original_height
new_width = int(original_width * scale_factor)
padding = (current_width - new_width) // 2
unpadded_tensor = tensor[:, :, padding : current_width - padding]
return unpadded_tensor
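# Example (hypothetical shapes): a 24x24 patch-feature map that came from a
# 640x480 image padded to a square. The width fills the grid, so unpad_image
# trims the rows that were added as vertical padding:
#   feats = torch.randn(1024, 24, 24)          # C x H x W
#   trimmed = unpad_image(feats, (640, 480))   # -> 1024 x 18 x 24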
class LlavaMetaForCausalLM(ABC):
@abstractmethod
def get_model(self):
pass
def get_vision_tower(self):
return self.get_model().get_vision_tower()
def encode_images(self, images):
image_features = self.get_model().get_vision_tower()(images)
image_features = self.get_model().mm_projector(image_features)
return image_features
def prepare_inputs_labels_for_multimodal(
self,
input_ids,
position_ids,
attention_mask,
past_key_values,
labels,
images,
image_sizes=None,
):
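        # Replace every IMAGE_TOKEN_INDEX placeholder in input_ids with that
        # image's projected patch features, mark those positions with
        # IGNORE_INDEX in the labels, and re-pad the batch to the new max length.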
vision_tower = self.get_vision_tower()
if vision_tower is None or images is None or input_ids.shape[1] == 1:
return (
input_ids,
position_ids,
attention_mask,
past_key_values,
None,
labels,
)
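        # A list of tensors or a 5-D batch means several crops per sample
        # (anyres tiling): encode all crops in one pass, then split per image.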
if type(images) is list or images.ndim == 5:
if type(images) is list:
images = [x.unsqueeze(0) if x.ndim == 3 else x for x in images]
concat_images = torch.cat([image for image in images], dim=0)
image_features = self.encode_images(concat_images)
split_sizes = [image.shape[0] for image in images]
image_features = torch.split(image_features, split_sizes, dim=0)
mm_patch_merge_type = getattr(self.config, "mm_patch_merge_type", "flat")
image_aspect_ratio = getattr(self.config, "image_aspect_ratio", "square")
if mm_patch_merge_type == "flat":
image_features = [x.flatten(0, 1) for x in image_features]
elif mm_patch_merge_type.startswith("spatial"):
new_image_features = []
for image_idx, image_feature in enumerate(image_features):
if image_feature.shape[0] > 1:
base_image_feature = image_feature[0]
image_feature = image_feature[1:]
height = width = self.get_vision_tower().num_patches_per_side
assert height * width == base_image_feature.shape[0]
if image_aspect_ratio == "anyres":
num_patch_width, num_patch_height = (
get_anyres_image_grid_shape(
image_sizes[image_idx],
self.config.image_grid_pinpoints,
self.get_vision_tower().config.image_size,
)
)
image_feature = image_feature.view(
num_patch_height, num_patch_width, height, width, -1
)
else:
raise NotImplementedError
if "unpad" in mm_patch_merge_type:
image_feature = image_feature.permute(
4, 0, 2, 1, 3
).contiguous()
image_feature = image_feature.flatten(1, 2).flatten(2, 3)
image_feature = unpad_image(
image_feature, image_sizes[image_idx]
)
image_feature = torch.cat(
(
image_feature,
self.model.image_newline[:, None, None]
.expand(*image_feature.shape[:-1], 1)
.to(image_feature.device),
),
dim=-1,
)
image_feature = image_feature.flatten(1, 2).transpose(0, 1)
else:
image_feature = image_feature.permute(
0, 2, 1, 3, 4
).contiguous()
image_feature = image_feature.flatten(0, 3)
image_feature = torch.cat(
(base_image_feature, image_feature), dim=0
)
else:
image_feature = image_feature[0]
if "unpad" in mm_patch_merge_type:
image_feature = torch.cat(
(
image_feature,
self.model.image_newline[None].to(
image_feature.device
),
),
dim=0,
)
new_image_features.append(image_feature)
image_features = new_image_features
else:
raise ValueError(
f"Unexpected mm_patch_merge_type: {self.config.mm_patch_merge_type}"
)
else:
image_features = self.encode_images(images)
# TODO: image start / end is not implemented here to support pretraining.
if getattr(self.config, "tune_mm_mlp_adapter", False) and getattr(
self.config, "mm_use_im_start_end", False
):
raise NotImplementedError
# Let's just add dummy tensors if they do not exist,
# it is a headache to deal with None all the time.
# But it is not ideal, and if you have a better idea,
# please open an issue / submit a PR, thanks.
_labels = labels
_position_ids = position_ids
_attention_mask = attention_mask
if attention_mask is None:
attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
else:
attention_mask = attention_mask.bool()
if position_ids is None:
position_ids = torch.arange(
0, input_ids.shape[1], dtype=torch.long, device=input_ids.device
)
if labels is None:
labels = torch.full_like(input_ids, IGNORE_INDEX)
# remove the padding using attention_mask -- FIXME
_input_ids = input_ids
input_ids = [
cur_input_ids[cur_attention_mask]
for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)
]
labels = [
cur_labels[cur_attention_mask]
for cur_labels, cur_attention_mask in zip(labels, attention_mask)
]
new_input_embeds = []
new_labels = []
cur_image_idx = 0
for batch_idx, cur_input_ids in enumerate(input_ids):
num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
if num_images == 0:
cur_image_features = image_features[cur_image_idx]
cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
cur_input_embeds = torch.cat(
[cur_input_embeds_1, cur_image_features[0:0]], dim=0
)
new_input_embeds.append(cur_input_embeds)
new_labels.append(labels[batch_idx])
cur_image_idx += 1
continue
image_token_indices = (
[-1]
+ torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist()
+ [cur_input_ids.shape[0]]
)
cur_input_ids_noim = []
cur_labels = labels[batch_idx]
cur_labels_noim = []
for i in range(len(image_token_indices) - 1):
cur_input_ids_noim.append(
cur_input_ids[
image_token_indices[i] + 1 : image_token_indices[i + 1]
]
)
cur_labels_noim.append(
cur_labels[image_token_indices[i] + 1 : image_token_indices[i + 1]]
)
split_sizes = [x.shape[0] for x in cur_labels_noim]
cur_input_embeds = self.get_model().embed_tokens(
torch.cat(cur_input_ids_noim)
)
cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
cur_new_input_embeds = []
cur_new_labels = []
for i in range(num_images + 1):
cur_new_input_embeds.append(cur_input_embeds_no_im[i])
cur_new_labels.append(cur_labels_noim[i])
if i < num_images:
cur_image_features = image_features[cur_image_idx]
cur_image_idx += 1
cur_new_input_embeds.append(cur_image_features)
cur_new_labels.append(
torch.full(
(cur_image_features.shape[0],),
IGNORE_INDEX,
device=cur_labels.device,
dtype=cur_labels.dtype,
)
)
cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds]
cur_new_input_embeds = torch.cat(cur_new_input_embeds)
cur_new_labels = torch.cat(cur_new_labels)
new_input_embeds.append(cur_new_input_embeds)
new_labels.append(cur_new_labels)
# Truncate sequences to max length as image embeddings can make the sequence longer
tokenizer_model_max_length = getattr(
self.config, "tokenizer_model_max_length", None
)
if tokenizer_model_max_length is not None:
new_input_embeds = [
x[:tokenizer_model_max_length] for x in new_input_embeds
]
new_labels = [x[:tokenizer_model_max_length] for x in new_labels]
# Combine them
max_len = max(x.shape[0] for x in new_input_embeds)
batch_size = len(new_input_embeds)
new_input_embeds_padded = []
new_labels_padded = torch.full(
(batch_size, max_len),
IGNORE_INDEX,
dtype=new_labels[0].dtype,
device=new_labels[0].device,
)
attention_mask = torch.zeros(
(batch_size, max_len),
dtype=attention_mask.dtype,
device=attention_mask.device,
)
position_ids = torch.zeros(
(batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device
)
for i, (cur_new_embed, cur_new_labels) in enumerate(
zip(new_input_embeds, new_labels)
):
cur_len = cur_new_embed.shape[0]
if getattr(self.config, "tokenizer_padding_side", "right") == "left":
new_input_embeds_padded.append(
torch.cat(
(
torch.zeros(
(max_len - cur_len, cur_new_embed.shape[1]),
dtype=cur_new_embed.dtype,
device=cur_new_embed.device,
),
cur_new_embed,
),
dim=0,
)
)
if cur_len > 0:
new_labels_padded[i, -cur_len:] = cur_new_labels
attention_mask[i, -cur_len:] = True
position_ids[i, -cur_len:] = torch.arange(
0, cur_len, dtype=position_ids.dtype, device=position_ids.device
)
else:
new_input_embeds_padded.append(
torch.cat(
(
cur_new_embed,
torch.zeros(
(max_len - cur_len, cur_new_embed.shape[1]),
dtype=cur_new_embed.dtype,
device=cur_new_embed.device,
),
),
dim=0,
)
)
if cur_len > 0:
new_labels_padded[i, :cur_len] = cur_new_labels
attention_mask[i, :cur_len] = True
position_ids[i, :cur_len] = torch.arange(
0, cur_len, dtype=position_ids.dtype, device=position_ids.device
)
new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)
if _labels is None:
new_labels = None
else:
new_labels = new_labels_padded
if _attention_mask is None:
attention_mask = None
else:
attention_mask = attention_mask.to(dtype=_attention_mask.dtype)
if _position_ids is None:
position_ids = None
return (
None,
position_ids,
attention_mask,
past_key_values,
new_input_embeds,
new_labels,
)
def initialize_vision_tokenizer(self, model_args, tokenizer):
if model_args.mm_use_im_patch_token:
tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
self.resize_token_embeddings(len(tokenizer))
if model_args.mm_use_im_start_end:
num_new_tokens = tokenizer.add_tokens(
[DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
)
self.resize_token_embeddings(len(tokenizer))
if num_new_tokens > 0:
input_embeddings = self.get_input_embeddings().weight.data
output_embeddings = self.get_output_embeddings().weight.data
input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
dim=0, keepdim=True
)
output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
dim=0, keepdim=True
)
input_embeddings[-num_new_tokens:] = input_embeddings_avg
output_embeddings[-num_new_tokens:] = output_embeddings_avg
if model_args.tune_mm_mlp_adapter:
for p in self.get_input_embeddings().parameters():
p.requires_grad = True
for p in self.get_output_embeddings().parameters():
p.requires_grad = False
if model_args.pretrain_mm_mlp_adapter:
mm_projector_weights = torch.load(
model_args.pretrain_mm_mlp_adapter, map_location="cpu"
)
embed_tokens_weight = mm_projector_weights["model.embed_tokens.weight"]
assert num_new_tokens == 2
if input_embeddings.shape == embed_tokens_weight.shape:
input_embeddings[-num_new_tokens:] = embed_tokens_weight[
-num_new_tokens:
]
elif embed_tokens_weight.shape[0] == num_new_tokens:
input_embeddings[-num_new_tokens:] = embed_tokens_weight
else:
raise ValueError(
f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}."
)
elif model_args.mm_use_im_patch_token:
if model_args.tune_mm_mlp_adapter:
for p in self.get_input_embeddings().parameters():
p.requires_grad = False
for p in self.get_output_embeddings().parameters():
p.requires_grad = False
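# ---------------------------------------------------------------------------
# Delta-weight tool (run as `python3 -m llava.model.make_delta`, see the usage
# string below): subtracts base LLM weights from a LLaVA checkpoint.
# ---------------------------------------------------------------------------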
"""
Usage:
python3 -m llava.model.make_delta --base ~/model_weights/llama-7b --target ~/model_weights/llava-7b --delta ~/model_weights/llava-7b-delta --hub-repo-id liuhaotian/llava-7b-delta
"""
import argparse
import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM
from llava.model.utils import auto_upgrade
def make_delta(base_model_path, target_model_path, delta_path, hub_repo_id):
print("Loading base model")
base = AutoModelForCausalLM.from_pretrained(
base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True
)
print("Loading target model")
auto_upgrade(target_model_path)
target = AutoModelForCausalLM.from_pretrained(
target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True
)
print("Calculating delta")
for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"):
if name not in base.state_dict():
assert name in [
"model.mm_projector.weight",
"model.mm_projector.bias",
], f"{name} not in base model"
continue
if param.data.shape == base.state_dict()[name].shape:
param.data -= base.state_dict()[name]
else:
assert name in [
"model.embed_tokens.weight",
"lm_head.weight",
], f"{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}"
bparam = base.state_dict()[name]
param.data[: bparam.shape[0], : bparam.shape[1]] -= bparam
print("Saving delta")
if hub_repo_id:
kwargs = {"push_to_hub": True, "repo_id": hub_repo_id}
else:
kwargs = {}
target.save_pretrained(delta_path, **kwargs)
target_tokenizer = AutoTokenizer.from_pretrained(target_model_path)
target_tokenizer.save_pretrained(delta_path, **kwargs)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--base-model-path", type=str, required=True)
parser.add_argument("--target-model-path", type=str, required=True)
parser.add_argument("--delta-path", type=str, required=True)
parser.add_argument("--hub-repo-id", type=str, default=None)
args = parser.parse_args()
make_delta(
args.base_model_path, args.target_model_path, args.delta_path, args.hub_repo_id
)
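# ---------------------------------------------------------------------------
# Vision tower builder (imported above as .multimodal_encoder.builder).
# ---------------------------------------------------------------------------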
import os
from .clip_encoder import CLIPVisionTower
def build_vision_tower(vision_tower_cfg, **kwargs):
vision_tower = getattr(
vision_tower_cfg,
"mm_vision_tower",
getattr(vision_tower_cfg, "vision_tower", None),
)
is_absolute_path_exists = os.path.exists(vision_tower)
if (
is_absolute_path_exists
or vision_tower.startswith("openai")
or vision_tower.startswith("laion")
or "ShareGPT4V" in vision_tower
):
return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
raise ValueError(f"Unknown vision tower: {vision_tower}")
import torch
import torch.nn as nn
from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig
class CLIPVisionTower(nn.Module):
def __init__(self, vision_tower, args, delay_load=False):
super().__init__()
self.is_loaded = False
self.vision_tower_name = vision_tower
self.select_layer = args.mm_vision_select_layer
self.select_feature = getattr(args, "mm_vision_select_feature", "patch")
if not delay_load:
self.load_model()
elif getattr(args, "unfreeze_mm_vision_tower", False):
self.load_model()
else:
self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)
def load_model(self, device_map=None):
if self.is_loaded:
print(
"{} is already loaded, `load_model` called again, skipping.".format(
self.vision_tower_name
)
)
return
self.image_processor = CLIPImageProcessor.from_pretrained(
self.vision_tower_name
)
self.vision_tower = CLIPVisionModel.from_pretrained(
self.vision_tower_name, device_map=device_map
)
self.vision_tower.requires_grad_(False)
self.is_loaded = True
def feature_select(self, image_forward_outs):
image_features = image_forward_outs.hidden_states[self.select_layer]
if self.select_feature == "patch":
image_features = image_features[:, 1:]
elif self.select_feature == "cls_patch":
image_features = image_features
else:
raise ValueError(f"Unexpected select feature: {self.select_feature}")
return image_features
@torch.no_grad()
def forward(self, images):
if type(images) is list:
image_features = []
for image in images:
image_forward_out = self.vision_tower(
image.to(device=self.device, dtype=self.dtype).unsqueeze(0),
output_hidden_states=True,
)
image_feature = self.feature_select(image_forward_out).to(image.dtype)
image_features.append(image_feature)
else:
image_forward_outs = self.vision_tower(
images.to(device=self.device, dtype=self.dtype),
output_hidden_states=True,
)
image_features = self.feature_select(image_forward_outs).to(images.dtype)
return image_features
@property
def dummy_feature(self):
return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
@property
def dtype(self):
return self.vision_tower.dtype
@property
def device(self):
return self.vision_tower.device
@property
def config(self):
if self.is_loaded:
return self.vision_tower.config
else:
return self.cfg_only
@property
def hidden_size(self):
return self.config.hidden_size
@property
def num_patches_per_side(self):
return self.config.image_size // self.config.patch_size
@property
def num_patches(self):
return (self.config.image_size // self.config.patch_size) ** 2
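# ---------------------------------------------------------------------------
# Vision projector builder (imported above as .multimodal_projector.builder):
# maps vision-tower features (mm_hidden_size) into the LM embedding space.
# ---------------------------------------------------------------------------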
import torch
import torch.nn as nn
import re
class IdentityMap(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x, *args, **kwargs):
return x
@property
def config(self):
return {"mm_projector_type": "identity"}
class SimpleResBlock(nn.Module):
def __init__(self, channels):
super().__init__()
self.pre_norm = nn.LayerNorm(channels)
self.proj = nn.Sequential(
nn.Linear(channels, channels), nn.GELU(), nn.Linear(channels, channels)
)
def forward(self, x):
x = self.pre_norm(x)
return x + self.proj(x)
def build_vision_projector(config, delay_load=False, **kwargs):
projector_type = getattr(config, "mm_projector_type", "linear")
if projector_type == "linear":
return nn.Linear(config.mm_hidden_size, config.hidden_size)
mlp_gelu_match = re.match(r"^mlp(\d+)x_gelu$", projector_type)
if mlp_gelu_match:
mlp_depth = int(mlp_gelu_match.group(1))
modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
for _ in range(1, mlp_depth):
modules.append(nn.GELU())
modules.append(nn.Linear(config.hidden_size, config.hidden_size))
return nn.Sequential(*modules)
if projector_type == "identity":
return IdentityMap()
raise ValueError(f"Unknown projector type: {projector_type}")
from transformers import AutoConfig
def auto_upgrade(config):
cfg = AutoConfig.from_pretrained(config)
if "llava" in config and "llava" not in cfg.model_type:
assert cfg.model_type == "llama"
print(
"You are using newer LLaVA code base, while the checkpoint of v0 is from older code base."
)
print(
"You must upgrade the checkpoint to the new code base (this can be done automatically)."
)
confirm = input("Please confirm that you want to upgrade the checkpoint. [Y/N]")
if confirm.lower() in ["y", "yes"]:
print("Upgrading checkpoint...")
assert len(cfg.architectures) == 1
setattr(cfg.__class__, "model_type", "llava")
cfg.architectures[0] = "LlavaLlamaForCausalLM"
cfg.save_pretrained(config)
print("Checkpoint upgraded.")
else:
print("Checkpoint upgrade aborted.")
exit(1)
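# ---------------------------------------------------------------------------
# Interactive CLI chat client: loads a pretrained LLaVA model and streams
# generations for image + text prompts in a terminal loop.
# ---------------------------------------------------------------------------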
import argparse
import torch
from llava.constants import (
IMAGE_TOKEN_INDEX,
DEFAULT_IMAGE_TOKEN,
DEFAULT_IM_START_TOKEN,
DEFAULT_IM_END_TOKEN,
)
from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import (
process_images,
tokenizer_image_token,
get_model_name_from_path,
)
from PIL import Image
import requests
from PIL import Image
from io import BytesIO
from transformers import TextStreamer
def load_image(image_file):
if image_file.startswith("http://") or image_file.startswith("https://"):
response = requests.get(image_file)
image = Image.open(BytesIO(response.content)).convert("RGB")
else:
image = Image.open(image_file).convert("RGB")
return image
def main(args):
# Model
disable_torch_init()
model_name = get_model_name_from_path(args.model_path)
tokenizer, model, image_processor, context_len = load_pretrained_model(
args.model_path,
args.model_base,
model_name,
args.load_8bit,
args.load_4bit,
device=args.device,
)
if "llama-2" in model_name.lower():
conv_mode = "llava_llama_2"
elif "mistral" in model_name.lower():
conv_mode = "mistral_instruct"
elif "v1.6-34b" in model_name.lower():
conv_mode = "chatml_direct"
elif "v1" in model_name.lower():
conv_mode = "llava_v1"
elif "mpt" in model_name.lower():
conv_mode = "mpt"
else:
conv_mode = "llava_v0"
if args.conv_mode is not None and conv_mode != args.conv_mode:
print(
"[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}".format(
conv_mode, args.conv_mode, args.conv_mode
)
)
else:
args.conv_mode = conv_mode
conv = conv_templates[args.conv_mode].copy()
if "mpt" in model_name.lower():
roles = ("user", "assistant")
else:
roles = conv.roles
if args.image_file is not None:
image = load_image(args.image_file)
image_size = image.size
# Similar operation in model_worker.py
image_tensor = process_images([image], image_processor, model.config)
if type(image_tensor) is list:
image_tensor = [
image.to(model.device, dtype=torch.float16) for image in image_tensor
]
else:
image_tensor = image_tensor.to(model.device, dtype=torch.float16)
else:
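        # No image file given: fall back to a dummy image flag and an all-zero
        # tensor so the <image> prompt path still runs for text-only chat.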
image = True
image_size = (1024, 1024)
image_tensor = torch.zeros(1, 5, 3, 336, 336)
image_tensor = image_tensor.to(model.device, dtype=torch.float16)
while True:
try:
inp = input(f"{roles[0]}: ")
except EOFError:
inp = ""
if not inp:
print("exit...")
break
print(f"{roles[1]}: ", end="")
if image is not None:
# first message
if model.config.mm_use_im_start_end:
inp = (
DEFAULT_IM_START_TOKEN
+ DEFAULT_IMAGE_TOKEN
+ DEFAULT_IM_END_TOKEN
+ "\n"
+ inp
)
else:
inp = inp.replace(DEFAULT_IMAGE_TOKEN, "").strip()
inp = DEFAULT_IMAGE_TOKEN + "\n" + inp
inp = inp.strip()
image = None
conv.append_message(conv.roles[0], inp)
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()
input_ids = (
tokenizer_image_token(
prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
)
.unsqueeze(0)
.to(model.device)
)
stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
keywords = [stop_str]
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
with torch.inference_mode():
output_ids = model.generate(
input_ids,
images=image_tensor,
image_sizes=[image_size],
do_sample=True if args.temperature > 0 else False,
temperature=args.temperature,
max_new_tokens=args.max_new_tokens,
streamer=streamer,
use_cache=True,
)
outputs = tokenizer.decode(output_ids[0]).strip()
conv.messages[-1][-1] = outputs
if args.debug:
print("\n", {"prompt": prompt, "outputs": outputs}, "\n")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
parser.add_argument("--model-base", type=str, default=None)
parser.add_argument("--image-file", type=str, default=None)
parser.add_argument("--device", type=str, default="cuda")
parser.add_argument("--conv-mode", type=str, default=None)
parser.add_argument("--temperature", type=float, default=0.2)
parser.add_argument("--max-new-tokens", type=int, default=512)
parser.add_argument("--load-8bit", action="store_true")
parser.add_argument("--load-4bit", action="store_true")
parser.add_argument("--debug", action="store_true")
args = parser.parse_args()
main(args)
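# ---------------------------------------------------------------------------
# Serving stack, part 1: the controller (see module docstring below) tracks
# registered workers via heart beats and dispatches requests to them.
# ---------------------------------------------------------------------------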
"""
A controller manages distributed workers.
It sends worker addresses to clients.
"""
import argparse
import asyncio
import dataclasses
from enum import Enum, auto
import json
import logging
import time
from typing import List, Union
import threading
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
import numpy as np
import requests
import uvicorn
from llava.constants import CONTROLLER_HEART_BEAT_EXPIRATION
from llava.utils import build_logger, server_error_msg
logger = build_logger("controller", "controller.log")
class DispatchMethod(Enum):
LOTTERY = auto()
SHORTEST_QUEUE = auto()
@classmethod
def from_str(cls, name):
if name == "lottery":
return cls.LOTTERY
elif name == "shortest_queue":
return cls.SHORTEST_QUEUE
else:
raise ValueError(f"Invalid dispatch method")
@dataclasses.dataclass
class WorkerInfo:
model_names: List[str]
speed: int
queue_length: int
check_heart_beat: bool
    last_heart_beat: float
def heart_beat_controller(controller):
while True:
time.sleep(CONTROLLER_HEART_BEAT_EXPIRATION)
        controller.remove_stale_workers_by_expiration()
class Controller:
def __init__(self, dispatch_method: str):
# Dict[str -> WorkerInfo]
self.worker_info = {}
self.dispatch_method = DispatchMethod.from_str(dispatch_method)
self.heart_beat_thread = threading.Thread(
target=heart_beat_controller, args=(self,), daemon=True
)
self.heart_beat_thread.start()
logger.info("Init controller")
def register_worker(
self, worker_name: str, check_heart_beat: bool, worker_status: dict
):
if worker_name not in self.worker_info:
logger.info(f"Register a new worker: {worker_name}")
else:
logger.info(f"Register an existing worker: {worker_name}")
if not worker_status:
worker_status = self.get_worker_status(worker_name)
if not worker_status:
return False
self.worker_info[worker_name] = WorkerInfo(
worker_status["model_names"],
worker_status["speed"],
worker_status["queue_length"],
check_heart_beat,
time.time(),
)
logger.info(f"Register done: {worker_name}, {worker_status}")
return True
def get_worker_status(self, worker_name: str):
try:
r = requests.post(worker_name + "/worker_get_status", timeout=5)
except requests.exceptions.RequestException as e:
logger.error(f"Get status fails: {worker_name}, {e}")
return None
if r.status_code != 200:
logger.error(f"Get status fails: {worker_name}, {r}")
return None
return r.json()
def remove_worker(self, worker_name: str):
del self.worker_info[worker_name]
def refresh_all_workers(self):
old_info = dict(self.worker_info)
self.worker_info = {}
for w_name, w_info in old_info.items():
if not self.register_worker(w_name, w_info.check_heart_beat, None):
logger.info(f"Remove stale worker: {w_name}")
def list_models(self):
model_names = set()
for w_name, w_info in self.worker_info.items():
model_names.update(w_info.model_names)
return list(model_names)
def get_worker_address(self, model_name: str):
if self.dispatch_method == DispatchMethod.LOTTERY:
worker_names = []
worker_speeds = []
for w_name, w_info in self.worker_info.items():
if model_name in w_info.model_names:
worker_names.append(w_name)
worker_speeds.append(w_info.speed)
worker_speeds = np.array(worker_speeds, dtype=np.float32)
norm = np.sum(worker_speeds)
if norm < 1e-4:
return ""
worker_speeds = worker_speeds / norm
if True: # Directly return address
pt = np.random.choice(np.arange(len(worker_names)), p=worker_speeds)
worker_name = worker_names[pt]
return worker_name
# Check status before returning
while True:
pt = np.random.choice(np.arange(len(worker_names)), p=worker_speeds)
worker_name = worker_names[pt]
if self.get_worker_status(worker_name):
break
else:
self.remove_worker(worker_name)
worker_speeds[pt] = 0
norm = np.sum(worker_speeds)
if norm < 1e-4:
return ""
worker_speeds = worker_speeds / norm
continue
return worker_name
elif self.dispatch_method == DispatchMethod.SHORTEST_QUEUE:
worker_names = []
worker_qlen = []
for w_name, w_info in self.worker_info.items():
if model_name in w_info.model_names:
worker_names.append(w_name)
worker_qlen.append(w_info.queue_length / w_info.speed)
if len(worker_names) == 0:
return ""
min_index = np.argmin(worker_qlen)
w_name = worker_names[min_index]
self.worker_info[w_name].queue_length += 1
logger.info(
f"names: {worker_names}, queue_lens: {worker_qlen}, ret: {w_name}"
)
return w_name
else:
raise ValueError(f"Invalid dispatch method: {self.dispatch_method}")
def receive_heart_beat(self, worker_name: str, queue_length: int):
if worker_name not in self.worker_info:
logger.info(f"Receive unknown heart beat. {worker_name}")
return False
self.worker_info[worker_name].queue_length = queue_length
self.worker_info[worker_name].last_heart_beat = time.time()
logger.info(f"Receive heart beat. {worker_name}")
return True
    def remove_stale_workers_by_expiration(self):
expire = time.time() - CONTROLLER_HEART_BEAT_EXPIRATION
to_delete = []
for worker_name, w_info in self.worker_info.items():
if w_info.check_heart_beat and w_info.last_heart_beat < expire:
to_delete.append(worker_name)
for worker_name in to_delete:
self.remove_worker(worker_name)
def worker_api_generate_stream(self, params):
worker_addr = self.get_worker_address(params["model"])
if not worker_addr:
logger.info(f"no worker: {params['model']}")
ret = {
"text": server_error_msg,
"error_code": 2,
}
yield json.dumps(ret).encode() + b"\0"
try:
response = requests.post(
worker_addr + "/worker_generate_stream",
json=params,
stream=True,
timeout=5,
)
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
if chunk:
yield chunk + b"\0"
except requests.exceptions.RequestException as e:
logger.info(f"worker timeout: {worker_addr}")
ret = {
"text": server_error_msg,
"error_code": 3,
}
yield json.dumps(ret).encode() + b"\0"
# Let the controller act as a worker to achieve hierarchical
# management. This can be used to connect isolated sub networks.
def worker_api_get_status(self):
model_names = set()
speed = 0
queue_length = 0
for w_name in self.worker_info:
worker_status = self.get_worker_status(w_name)
if worker_status is not None:
model_names.update(worker_status["model_names"])
speed += worker_status["speed"]
queue_length += worker_status["queue_length"]
return {
"model_names": list(model_names),
"speed": speed,
"queue_length": queue_length,
}
app = FastAPI()
@app.post("/register_worker")
async def register_worker(request: Request):
data = await request.json()
controller.register_worker(
data["worker_name"], data["check_heart_beat"], data.get("worker_status", None)
)
@app.post("/refresh_all_workers")
async def refresh_all_workers():
models = controller.refresh_all_workers()
@app.post("/list_models")
async def list_models():
models = controller.list_models()
return {"models": models}
@app.post("/get_worker_address")
async def get_worker_address(request: Request):
data = await request.json()
addr = controller.get_worker_address(data["model"])
return {"address": addr}
@app.post("/receive_heart_beat")
async def receive_heart_beat(request: Request):
data = await request.json()
exist = controller.receive_heart_beat(data["worker_name"], data["queue_length"])
return {"exist": exist}
@app.post("/worker_generate_stream")
async def worker_api_generate_stream(request: Request):
params = await request.json()
generator = controller.worker_api_generate_stream(params)
return StreamingResponse(generator)
@app.post("/worker_get_status")
async def worker_api_get_status(request: Request):
return controller.worker_api_get_status()
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="localhost")
parser.add_argument("--port", type=int, default=21001)
parser.add_argument(
"--dispatch-method",
type=str,
choices=["lottery", "shortest_queue"],
default="shortest_queue",
)
args = parser.parse_args()
logger.info(f"args: {args}")
controller = Controller(args.dispatch_method)
uvicorn.run(app, host=args.host, port=args.port, log_level="info")
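# ---------------------------------------------------------------------------
# Serving stack, part 2: Gradio web server. Builds the chat UI, forwards
# prompts to a worker chosen by the controller, and streams responses back.
# ---------------------------------------------------------------------------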
import argparse
import datetime
import json
import os
import time
import gradio as gr
import requests
from llava.conversation import default_conversation, conv_templates, SeparatorStyle
from llava.constants import LOGDIR
from llava.utils import (
build_logger,
server_error_msg,
violates_moderation,
moderation_msg,
)
import hashlib
logger = build_logger("gradio_web_server", "gradio_web_server.log")
headers = {"User-Agent": "LLaVA Client"}
no_change_btn = gr.Button()
enable_btn = gr.Button(interactive=True)
disable_btn = gr.Button(interactive=False)
priority = {
"vicuna-13b": "aaaaaaa",
"koala-13b": "aaaaaab",
}
def get_conv_log_filename():
t = datetime.datetime.now()
name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
return name
def get_model_list():
ret = requests.post(args.controller_url + "/refresh_all_workers")
assert ret.status_code == 200
ret = requests.post(args.controller_url + "/list_models")
models = ret.json()["models"]
models.sort(key=lambda x: priority.get(x, x))
logger.info(f"Models: {models}")
return models
get_window_url_params = """
function() {
const params = new URLSearchParams(window.location.search);
url_params = Object.fromEntries(params);
console.log(url_params);
return url_params;
}
"""
def load_demo(url_params, request: gr.Request):
logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
dropdown_update = gr.Dropdown(visible=True)
if "model" in url_params:
model = url_params["model"]
if model in models:
dropdown_update = gr.Dropdown(value=model, visible=True)
state = default_conversation.copy()
return state, dropdown_update
def load_demo_refresh_model_list(request: gr.Request):
logger.info(f"load_demo. ip: {request.client.host}")
models = get_model_list()
state = default_conversation.copy()
dropdown_update = gr.Dropdown(
choices=models,
value=models[0] if len(models) > 0 else "",
)
return state, dropdown_update
def vote_last_response(state, vote_type, model_selector, request: gr.Request):
with open(get_conv_log_filename(), "a") as fout:
data = {
"tstamp": round(time.time(), 4),
"type": vote_type,
"model": model_selector,
"state": state.dict(),
"ip": request.client.host,
}
fout.write(json.dumps(data) + "\n")
def upvote_last_response(state, model_selector, request: gr.Request):
logger.info(f"upvote. ip: {request.client.host}")
vote_last_response(state, "upvote", model_selector, request)
return ("",) + (disable_btn,) * 3
def downvote_last_response(state, model_selector, request: gr.Request):
logger.info(f"downvote. ip: {request.client.host}")
vote_last_response(state, "downvote", model_selector, request)
return ("",) + (disable_btn,) * 3
def flag_last_response(state, model_selector, request: gr.Request):
logger.info(f"flag. ip: {request.client.host}")
vote_last_response(state, "flag", model_selector, request)
return ("",) + (disable_btn,) * 3
def regenerate(state, image_process_mode, request: gr.Request):
logger.info(f"regenerate. ip: {request.client.host}")
state.messages[-1][-1] = None
prev_human_msg = state.messages[-2]
if type(prev_human_msg[1]) in (tuple, list):
prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode)
state.skip_next = False
return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
def clear_history(request: gr.Request):
logger.info(f"clear_history. ip: {request.client.host}")
state = default_conversation.copy()
return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
def add_text(state, text, image, image_process_mode, request: gr.Request):
logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}")
if len(text) <= 0 and image is None:
state.skip_next = True
return (state, state.to_gradio_chatbot(), "", None) + (no_change_btn,) * 5
if args.moderate:
flagged = violates_moderation(text)
if flagged:
state.skip_next = True
return (state, state.to_gradio_chatbot(), moderation_msg, None) + (
no_change_btn,
) * 5
text = text[:1536] # Hard cut-off
if image is not None:
text = text[:1200] # Hard cut-off for images
if "<image>" not in text:
# text = '<Image><image></Image>' + text
text = text + "\n<image>"
text = (text, image, image_process_mode)
state = default_conversation.copy()
state.append_message(state.roles[0], text)
state.append_message(state.roles[1], None)
state.skip_next = False
return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
def http_bot(
state, model_selector, temperature, top_p, max_new_tokens, request: gr.Request
):
logger.info(f"http_bot. ip: {request.client.host}")
start_tstamp = time.time()
model_name = model_selector
if state.skip_next:
# This generate call is skipped due to invalid inputs
yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
return
if len(state.messages) == state.offset + 2:
# First round of conversation
if "llava" in model_name.lower():
if "llama-2" in model_name.lower():
template_name = "llava_llama_2"
elif "mistral" in model_name.lower() or "mixtral" in model_name.lower():
if "orca" in model_name.lower():
template_name = "mistral_orca"
elif "hermes" in model_name.lower():
template_name = "chatml_direct"
else:
template_name = "mistral_instruct"
elif "llava-v1.6-34b" in model_name.lower():
template_name = "chatml_direct"
elif "v1" in model_name.lower():
if "mmtag" in model_name.lower():
template_name = "v1_mmtag"
elif (
"plain" in model_name.lower()
and "finetune" not in model_name.lower()
):
template_name = "v1_mmtag"
else:
template_name = "llava_v1"
elif "mpt" in model_name.lower():
template_name = "mpt"
else:
if "mmtag" in model_name.lower():
template_name = "v0_mmtag"
elif (
"plain" in model_name.lower()
and "finetune" not in model_name.lower()
):
template_name = "v0_mmtag"
else:
template_name = "llava_v0"
elif "mpt" in model_name:
template_name = "mpt_text"
elif "llama-2" in model_name:
template_name = "llama_2"
else:
template_name = "vicuna_v1"
new_state = conv_templates[template_name].copy()
new_state.append_message(new_state.roles[0], state.messages[-2][1])
new_state.append_message(new_state.roles[1], None)
state = new_state
# Query worker address
controller_url = args.controller_url
ret = requests.post(
controller_url + "/get_worker_address", json={"model": model_name}
)
worker_addr = ret.json()["address"]
logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}")
# No available worker
if worker_addr == "":
state.messages[-1][-1] = server_error_msg
yield (
state,
state.to_gradio_chatbot(),
disable_btn,
disable_btn,
disable_btn,
enable_btn,
enable_btn,
)
return
# Construct prompt
prompt = state.get_prompt()
all_images = state.get_images(return_pil=True)
all_image_hash = [hashlib.md5(image.tobytes()).hexdigest() for image in all_images]
for image, hash in zip(all_images, all_image_hash):
t = datetime.datetime.now()
filename = os.path.join(
LOGDIR, "serve_images", f"{t.year}-{t.month:02d}-{t.day:02d}", f"{hash}.jpg"
)
if not os.path.isfile(filename):
os.makedirs(os.path.dirname(filename), exist_ok=True)
image.save(filename)
# Make requests
pload = {
"model": model_name,
"prompt": prompt,
"temperature": float(temperature),
"top_p": float(top_p),
"max_new_tokens": min(int(max_new_tokens), 1536),
"stop": (
state.sep
if state.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT]
else state.sep2
),
"images": f"List of {len(state.get_images())} images: {all_image_hash}",
}
logger.info(f"==== request ====\n{pload}")
pload["images"] = state.get_images()
state.messages[-1][-1] = "▌"
yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
try:
# Stream output
response = requests.post(
worker_addr + "/worker_generate_stream",
headers=headers,
json=pload,
stream=True,
timeout=10,
)
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
if chunk:
data = json.loads(chunk.decode())
if data["error_code"] == 0:
output = data["text"][len(prompt) :].strip()
state.messages[-1][-1] = output + "▌"
yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
else:
output = data["text"] + f" (error_code: {data['error_code']})"
state.messages[-1][-1] = output
yield (state, state.to_gradio_chatbot()) + (
disable_btn,
disable_btn,
disable_btn,
enable_btn,
enable_btn,
)
return
time.sleep(0.03)
except requests.exceptions.RequestException as e:
state.messages[-1][-1] = server_error_msg
yield (state, state.to_gradio_chatbot()) + (
disable_btn,
disable_btn,
disable_btn,
enable_btn,
enable_btn,
)
return
state.messages[-1][-1] = state.messages[-1][-1][:-1]
yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
finish_tstamp = time.time()
logger.info(f"{output}")
with open(get_conv_log_filename(), "a") as fout:
data = {
"tstamp": round(finish_tstamp, 4),
"type": "chat",
"model": model_name,
"start": round(start_tstamp, 4),
"finish": round(finish_tstamp, 4),
"state": state.dict(),
"images": all_image_hash,
"ip": request.client.host,
}
fout.write(json.dumps(data) + "\n")
block_css = """
#buttons button {
min-width: min(120px,100%);
}
"""
def build_demo(embed_mode, cur_dir=None, concurrency_count=10):
textbox = gr.Textbox(
show_label=False, placeholder="Enter text and press ENTER", container=False
)
with gr.Blocks(title="Hunyuan", theme=gr.themes.Default(), css=block_css) as demo:
state = gr.State()
with gr.Row():
with gr.Column(scale=3):
with gr.Row(elem_id="model_selector_row"):
model_selector = gr.Dropdown(
choices=models,
value=models[0] if len(models) > 0 else "",
interactive=True,
show_label=False,
container=False,
)
imagebox = gr.Image(type="pil")
image_process_mode = gr.Radio(
["Crop", "Resize", "Pad", "Default"],
value="Default",
label="Preprocess for non-square image",
visible=False,
)
if cur_dir is None:
cur_dir = os.path.dirname(os.path.abspath(__file__))
with gr.Accordion("Parameters", open=False) as parameter_row:
temperature = gr.Slider(
minimum=0.0,
maximum=1.0,
value=0.2,
step=0.1,
interactive=True,
label="Temperature",
)
top_p = gr.Slider(
minimum=0.0,
maximum=1.0,
value=0.7,
step=0.1,
interactive=True,
label="Top P",
)
max_output_tokens = gr.Slider(
minimum=0,
maximum=1024,
value=512,
step=64,
interactive=True,
label="Max output tokens",
)
with gr.Column(scale=8):
chatbot = gr.Chatbot(
elem_id="chatbot",
label="Hunyuan",
height=650,
layout="panel",
)
with gr.Row():
with gr.Column(scale=8):
textbox.render()
with gr.Column(scale=1, min_width=50):
submit_btn = gr.Button(value="Send", variant="primary")
with gr.Row(elem_id="buttons") as button_row:
regenerate_btn = gr.Button(
value="🔄 Regenerate", interactive=False
)
clear_btn = gr.Button(value="🗑️ Clear", interactive=False)
url_params = gr.JSON(visible=False)
btn_list = [regenerate_btn, clear_btn]
regenerate_btn.click(
regenerate,
[state, image_process_mode],
[state, chatbot, textbox, imagebox] + btn_list,
).then(
http_bot,
[state, model_selector, temperature, top_p, max_output_tokens],
[state, chatbot] + btn_list,
# concurrency_limit=concurrency_count
)
clear_btn.click(
clear_history,
None,
[state, chatbot, textbox, imagebox] + btn_list,
queue=False,
)
textbox.submit(
add_text,
[state, textbox, imagebox, image_process_mode],
[state, chatbot, textbox, imagebox] + btn_list,
queue=False,
).then(
http_bot,
[state, model_selector, temperature, top_p, max_output_tokens],
[state, chatbot] + btn_list,
)
submit_btn.click(
add_text,
[state, textbox, imagebox, image_process_mode],
[state, chatbot, textbox, imagebox] + btn_list,
).then(
http_bot,
[state, model_selector, temperature, top_p, max_output_tokens],
[state, chatbot] + btn_list,
)
if args.model_list_mode == "once":
demo.load(
load_demo,
[url_params],
[state, model_selector],
js=get_window_url_params,
)
elif args.model_list_mode == "reload":
demo.load(
load_demo_refresh_model_list, None, [state, model_selector], queue=False
)
else:
raise ValueError(f"Unknown model list mode: {args.model_list_mode}")
return demo
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="0.0.0.0")
parser.add_argument("--port", type=int)
parser.add_argument("--controller-url", type=str, default="http://localhost:21001")
parser.add_argument("--concurrency-count", type=int, default=16)
parser.add_argument(
"--model-list-mode", type=str, default="once", choices=["once", "reload"]
)
parser.add_argument("--share", action="store_true")
parser.add_argument("--moderate", action="store_true")
parser.add_argument("--embed", action="store_true")
args = parser.parse_args()
logger.info(f"args: {args}")
models = get_model_list()
logger.info(args)
demo = build_demo(args.embed, concurrency_count=args.concurrency_count)
demo.queue(api_open=False).launch(
server_name=args.host, server_port=args.port, share=args.share
)
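# ---------------------------------------------------------------------------
# Serving stack, part 3: model worker (see module docstring below). Hosts one
# LLaVA model, registers with the controller, and streams generations.
# ---------------------------------------------------------------------------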
"""
A model worker executes the model.
"""
import argparse
import asyncio
import json
import time
import threading
import uuid
from fastapi import FastAPI, Request, BackgroundTasks
from fastapi.responses import StreamingResponse
import requests
import torch
import uvicorn
from functools import partial
from llava.constants import WORKER_HEART_BEAT_INTERVAL
from llava.utils import build_logger, server_error_msg, pretty_print_semaphore
from llava.model.builder import load_pretrained_model
from llava.mm_utils import process_images, load_image_from_base64, tokenizer_image_token
from llava.constants import (
IMAGE_TOKEN_INDEX,
DEFAULT_IMAGE_TOKEN,
DEFAULT_IM_START_TOKEN,
DEFAULT_IM_END_TOKEN,
)
from transformers import TextIteratorStreamer
from threading import Thread
GB = 1 << 30
worker_id = str(uuid.uuid4())[:6]
logger = build_logger("model_worker", f"model_worker_{worker_id}.log")
global_counter = 0
model_semaphore = None
def heart_beat_worker(controller):
while True:
time.sleep(WORKER_HEART_BEAT_INTERVAL)
controller.send_heart_beat()
class ModelWorker:
def __init__(
self,
controller_addr,
worker_addr,
worker_id,
no_register,
model_path,
model_base,
model_name,
load_8bit,
load_4bit,
device,
use_flash_attn=False,
):
self.controller_addr = controller_addr
self.worker_addr = worker_addr
self.worker_id = worker_id
if model_path.endswith("/"):
model_path = model_path[:-1]
if model_name is None:
model_paths = model_path.split("/")
if model_paths[-1].startswith("checkpoint-"):
self.model_name = model_paths[-2] + "_" + model_paths[-1]
else:
self.model_name = model_paths[-1]
else:
self.model_name = model_name
self.device = device
logger.info(f"Loading the model {self.model_name} on worker {worker_id} ...")
self.tokenizer, self.model, self.image_processor, self.context_len = (
load_pretrained_model(
model_path,
model_base,
self.model_name,
load_8bit,
load_4bit,
device=self.device,
use_flash_attn=use_flash_attn,
)
)
self.is_multimodal = "llava" in self.model_name.lower()
if not no_register:
self.register_to_controller()
self.heart_beat_thread = threading.Thread(
target=heart_beat_worker, args=(self,), daemon=True
)
self.heart_beat_thread.start()
def register_to_controller(self):
logger.info("Register to controller")
url = self.controller_addr + "/register_worker"
data = {
"worker_name": self.worker_addr,
"check_heart_beat": True,
"worker_status": self.get_status(),
}
r = requests.post(url, json=data)
assert r.status_code == 200
def send_heart_beat(self):
logger.info(
f"Send heart beat. Models: {[self.model_name]}. "
f"Semaphore: {pretty_print_semaphore(model_semaphore)}. "
f"global_counter: {global_counter}"
)
url = self.controller_addr + "/receive_heart_beat"
while True:
try:
ret = requests.post(
url,
json={
"worker_name": self.worker_addr,
"queue_length": self.get_queue_length(),
},
timeout=5,
)
exist = ret.json()["exist"]
break
except requests.exceptions.RequestException as e:
logger.error(f"heart beat error: {e}")
time.sleep(5)
if not exist:
self.register_to_controller()
def get_queue_length(self):
if model_semaphore is None:
return 0
else:
return (
args.limit_model_concurrency
- model_semaphore._value
+ (
len(model_semaphore._waiters)
if model_semaphore._waiters is not None
else 0
)
)
def get_status(self):
return {
"model_names": [self.model_name],
"speed": 1,
"queue_length": self.get_queue_length(),
}
@torch.inference_mode()
def generate_stream(self, params):
tokenizer, model, image_processor = (
self.tokenizer,
self.model,
self.image_processor,
)
prompt = params["prompt"]
ori_prompt = prompt
images = params.get("images", None)
num_image_tokens = 0
if images is not None and len(images) > 0 and self.is_multimodal:
if len(images) > 0:
if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN):
raise ValueError(
"Number of images does not match number of <image> tokens in prompt"
)
images = [load_image_from_base64(image) for image in images]
image_sizes = [image.size for image in images]
images = process_images(images, image_processor, model.config)
if type(images) is list:
images = [
image.to(self.model.device, dtype=torch.float16)
for image in images
]
else:
images = images.to(self.model.device, dtype=torch.float16)
replace_token = DEFAULT_IMAGE_TOKEN
if getattr(self.model.config, "mm_use_im_start_end", False):
replace_token = (
DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
)
prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)
num_image_tokens = (
prompt.count(replace_token) * model.get_vision_tower().num_patches
)
else:
images = None
image_sizes = None
image_args = {"images": images, "image_sizes": image_sizes}
else:
images = None
image_args = {}
temperature = float(params.get("temperature", 1.0))
top_p = float(params.get("top_p", 1.0))
max_context_length = getattr(model.config, "max_position_embeddings", 2048)
max_new_tokens = min(int(params.get("max_new_tokens", 256)), 1024)
stop_str = params.get("stop", None)
do_sample = True if temperature > 0.001 else False
input_ids = (
tokenizer_image_token(
prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
)
.unsqueeze(0)
.to(self.device)
)
keywords = [stop_str]
# stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
streamer = TextIteratorStreamer(
tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15
)
max_new_tokens = min(
max_new_tokens, max_context_length - input_ids.shape[-1] - num_image_tokens
)
if max_new_tokens < 1:
yield json.dumps(
{
"text": ori_prompt
+ "Exceeds max token length. Please start a new conversation, thanks.",
"error_code": 0,
}
).encode() + b"\0"
return
thread = Thread(
target=model.generate,
kwargs=dict(
inputs=input_ids,
do_sample=do_sample,
temperature=temperature,
top_p=top_p,
max_new_tokens=max_new_tokens,
streamer=streamer,
use_cache=True,
**image_args,
),
)
thread.start()
generated_text = ori_prompt
for new_text in streamer:
generated_text += new_text
if generated_text.endswith(stop_str):
generated_text = generated_text[: -len(stop_str)]
yield json.dumps({"text": generated_text, "error_code": 0}).encode() + b"\0"
def generate_stream_gate(self, params):
try:
for x in self.generate_stream(params):
yield x
except ValueError as e:
print("Caught ValueError:", e)
ret = {
"text": server_error_msg,
"error_code": 1,
}
yield json.dumps(ret).encode() + b"\0"
except torch.cuda.CudaError as e:
print("Caught torch.cuda.CudaError:", e)
ret = {
"text": server_error_msg,
"error_code": 1,
}
yield json.dumps(ret).encode() + b"\0"
except Exception as e:
print("Caught Unknown Error", e)
ret = {
"text": server_error_msg,
"error_code": 1,
}
yield json.dumps(ret).encode() + b"\0"
app = FastAPI()
def release_model_semaphore(fn=None):
model_semaphore.release()
if fn is not None:
fn()
@app.post("/worker_generate_stream")
async def generate_stream(request: Request):
global model_semaphore, global_counter
global_counter += 1
params = await request.json()
if model_semaphore is None:
model_semaphore = asyncio.Semaphore(args.limit_model_concurrency)
await model_semaphore.acquire()
worker.send_heart_beat()
generator = worker.generate_stream_gate(params)
background_tasks = BackgroundTasks()
background_tasks.add_task(
partial(release_model_semaphore, fn=worker.send_heart_beat)
)
return StreamingResponse(generator, background=background_tasks)
@app.post("/worker_get_status")
async def get_status(request: Request):
return worker.get_status()
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="localhost")
parser.add_argument("--port", type=int, default=21002)
parser.add_argument("--worker-address", type=str, default="http://localhost:21002")
parser.add_argument(
"--controller-address", type=str, default="http://localhost:21001"
)
parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
parser.add_argument("--model-base", type=str, default=None)
parser.add_argument("--model-name", type=str)
parser.add_argument("--device", type=str, default="cuda")
parser.add_argument(
"--multi-modal",
action="store_true",
help="Multimodal mode is automatically detected with model name, please make sure `llava` is included in the model path.",
)
parser.add_argument("--limit-model-concurrency", type=int, default=5)
parser.add_argument("--stream-interval", type=int, default=1)
parser.add_argument("--no-register", action="store_true")
parser.add_argument("--load-8bit", action="store_true")
parser.add_argument("--load-4bit", action="store_true")
parser.add_argument("--use-flash-attn", action="store_true")
args = parser.parse_args()
logger.info(f"args: {args}")
if args.multi_modal:
logger.warning(
"Multimodal mode is automatically detected with model name, please make sure `llava` is included in the model path."
)
worker = ModelWorker(
args.controller_address,
args.worker_address,
worker_id,
args.no_register,
args.model_path,
args.model_base,
args.model_name,
args.load_8bit,
args.load_4bit,
args.device,
use_flash_attn=args.use_flash_attn,
)
uvicorn.run(app, host=args.host, port=args.port, log_level="info")
"""
Manually register workers.
Usage:
python3 -m fastchat.serve.register_worker --controller http://localhost:21001 --worker-name http://localhost:21002
"""
import argparse
import requests
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--controller-address", type=str)
parser.add_argument("--worker-name", type=str)
parser.add_argument("--check-heart-beat", action="store_true")
args = parser.parse_args()
url = args.controller_address + "/register_worker"
data = {
"worker_name": args.worker_name,
"check_heart_beat": args.check_heart_beat,
"worker_status": None,
}
r = requests.post(url, json=data)
assert r.status_code == 200
"""
A model worker executes the model.
"""
import argparse
import asyncio
from concurrent.futures import ThreadPoolExecutor
import json
import time
import threading
import uuid
from fastapi import FastAPI, Request, BackgroundTasks
from fastapi.responses import StreamingResponse
import requests
import re
import uvicorn
from functools import partial
from llava.constants import WORKER_HEART_BEAT_INTERVAL
from llava.utils import build_logger, server_error_msg, pretty_print_semaphore
from llava.mm_utils import (
process_images,
load_image_from_base64,
tokenizer_image_token,
expand2square,
)
from llava.constants import DEFAULT_IMAGE_TOKEN
import sglang as sgl
from sglang.backend.runtime_endpoint import RuntimeEndpoint
GB = 1 << 30
worker_id = str(uuid.uuid4())[:6]
logger = build_logger("model_worker", f"model_worker_{worker_id}.log")
global_counter = 0
model_semaphore = None
def heart_beat_worker(controller):
while True:
time.sleep(WORKER_HEART_BEAT_INTERVAL)
controller.send_heart_beat()
@sgl.function
def pipeline(s, prompt, max_tokens):
for p in prompt:
if type(p) is str:
s += p
else:
s += sgl.image(p)
s += sgl.gen("response", max_tokens=max_tokens)
class ModelWorker:
def __init__(
self,
controller_addr,
worker_addr,
sgl_endpoint,
worker_id,
no_register,
model_name,
):
self.controller_addr = controller_addr
self.worker_addr = worker_addr
self.worker_id = worker_id
# Select backend
backend = RuntimeEndpoint(sgl_endpoint)
sgl.set_default_backend(backend)
model_path = backend.model_info["model_path"]
if model_path.endswith("/"):
model_path = model_path[:-1]
if model_name is None:
model_paths = model_path.split("/")
if model_paths[-1].startswith("checkpoint-"):
self.model_name = model_paths[-2] + "_" + model_paths[-1]
else:
self.model_name = model_paths[-1]
else:
self.model_name = model_name
logger.info(
f"Loading the SGLANG model {self.model_name} on worker {worker_id} ..."
)
if not no_register:
self.register_to_controller()
self.heart_beat_thread = threading.Thread(
target=heart_beat_worker, args=(self,), daemon=True
)
self.heart_beat_thread.start()
def register_to_controller(self):
logger.info("Register to controller")
url = self.controller_addr + "/register_worker"
data = {
"worker_name": self.worker_addr,
"check_heart_beat": True,
"worker_status": self.get_status(),
}
r = requests.post(url, json=data)
assert r.status_code == 200
def send_heart_beat(self):
logger.info(
f"Send heart beat. Models: {[self.model_name]}. "
f"Semaphore: {pretty_print_semaphore(model_semaphore)}. "
f"global_counter: {global_counter}"
)
url = self.controller_addr + "/receive_heart_beat"
while True:
try:
ret = requests.post(
url,
json={
"worker_name": self.worker_addr,
"queue_length": self.get_queue_length(),
},
timeout=5,
)
exist = ret.json()["exist"]
break
except requests.exceptions.RequestException as e:
logger.error(f"heart beat error: {e}")
time.sleep(5)
if not exist:
self.register_to_controller()
def get_queue_length(self):
if model_semaphore is None:
return 0
else:
return (
args.limit_model_concurrency
- model_semaphore._value
+ (
len(model_semaphore._waiters)
if model_semaphore._waiters is not None
else 0
)
)
def get_status(self):
return {
"model_names": [self.model_name],
"speed": 1,
"queue_length": self.get_queue_length(),
}
async def generate_stream(self, params):
ori_prompt = prompt = params["prompt"]
images = params.get("images", None)
if images is not None and len(images) > 0:
if len(images) > 0:
if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN):
raise ValueError(
"Number of images does not match number of <image> tokens in prompt"
)
images = [load_image_from_base64(image) for image in images]
# FIXME: for image-start/end token
# replace_token = DEFAULT_IMAGE_TOKEN
# if getattr(self.model.config, 'mm_use_im_start_end', False):
# replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
# prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)
prompt = prompt.replace(
" " + DEFAULT_IMAGE_TOKEN + "\n", DEFAULT_IMAGE_TOKEN
)
prompt_split = prompt.split(DEFAULT_IMAGE_TOKEN)
prompt = []
for i in range(len(prompt_split)):
prompt.append(prompt_split[i])
if i < len(images):
prompt.append(images[i])
else:
prompt = [prompt]
temperature = float(params.get("temperature", 1.0))
top_p = float(params.get("top_p", 1.0))
# max_context_length = getattr(model.config, 'max_position_embeddings', 2048)
max_new_tokens = min(int(params.get("max_new_tokens", 256)), 1024)
stop_str = params.get("stop", None)
stop_str = [stop_str] if stop_str is not None else None
print(
{
"prompt": prompt,
"max_new_tokens": max_new_tokens,
"temperature": temperature,
"top_p": top_p,
}
)
state = pipeline.run(
prompt, max_new_tokens, temperature=temperature, top_p=top_p, stream=True
)
generated_text = ori_prompt
async for text_outputs in state.text_async_iter(var_name="response"):
generated_text += text_outputs
yield json.dumps({"text": generated_text, "error_code": 0}).encode() + b"\0"
async def generate_stream_gate(self, params):
try:
async for x in self.generate_stream(params):
yield x
except ValueError as e:
print("Caught ValueError:", e)
ret = {
"text": server_error_msg,
"error_code": 1,
}
yield json.dumps(ret).encode() + b"\0"
except Exception as e:
print("Caught Unknown Error", e)
ret = {
"text": server_error_msg,
"error_code": 1,
}
yield json.dumps(ret).encode() + b"\0"
app = FastAPI()
def release_model_semaphore(fn=None):
model_semaphore.release()
if fn is not None:
fn()
@app.post("/worker_generate_stream")
async def generate_stream(request: Request):
global model_semaphore, global_counter
global_counter += 1
params = await request.json()
if model_semaphore is None:
model_semaphore = asyncio.Semaphore(args.limit_model_concurrency)
await model_semaphore.acquire()
worker.send_heart_beat()
generator = worker.generate_stream_gate(params)
background_tasks = BackgroundTasks()
background_tasks.add_task(
partial(release_model_semaphore, fn=worker.send_heart_beat)
)
return StreamingResponse(generator, background=background_tasks)
@app.post("/worker_get_status")
async def get_status(request: Request):
return worker.get_status()
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="localhost")
parser.add_argument("--port", type=int, default=21002)
parser.add_argument("--worker-address", type=str, default="http://localhost:21002")
parser.add_argument(
"--controller-address", type=str, default="http://localhost:21001"
)
parser.add_argument("--model-name", type=str)
parser.add_argument("--sgl-endpoint", type=str)
parser.add_argument("--limit-model-concurrency", type=int, default=5)
parser.add_argument("--stream-interval", type=int, default=1)
parser.add_argument("--no-register", action="store_true")
args = parser.parse_args()
logger.info(f"args: {args}")
worker = ModelWorker(
args.controller_address,
args.worker_address,
args.sgl_endpoint,
worker_id,
args.no_register,
args.model_name,
)
uvicorn.run(app, host=args.host, port=args.port, log_level="info")
import argparse
import json
import requests
from llava.conversation import default_conversation
def main():
if args.worker_address:
worker_addr = args.worker_address
else:
controller_addr = args.controller_address
ret = requests.post(controller_addr + "/refresh_all_workers")
ret = requests.post(controller_addr + "/list_models")
models = ret.json()["models"]
models.sort()
print(f"Models: {models}")
ret = requests.post(
controller_addr + "/get_worker_address", json={"model": args.model_name}
)
worker_addr = ret.json()["address"]
print(f"worker_addr: {worker_addr}")
if worker_addr == "":
return
conv = default_conversation.copy()
conv.append_message(conv.roles[0], args.message)
prompt = conv.get_prompt()
headers = {"User-Agent": "LLaVA Client"}
pload = {
"model": args.model_name,
"prompt": prompt,
"max_new_tokens": args.max_new_tokens,
"temperature": 0.7,
"stop": conv.sep,
}
response = requests.post(
worker_addr + "/worker_generate_stream",
headers=headers,
json=pload,
stream=True,
)
print(prompt.replace(conv.sep, "\n"), end="")
for chunk in response.iter_lines(
chunk_size=8192, decode_unicode=False, delimiter=b"\0"
):
if chunk:
data = json.loads(chunk.decode("utf-8"))
output = data["text"].split(conv.sep)[-1]
print(output, end="\r")
print("")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--controller-address", type=str, default="http://localhost:21001"
)
parser.add_argument("--worker-address", type=str)
parser.add_argument("--model-name", type=str, default="facebook/opt-350m")
parser.add_argument("--max-new-tokens", type=int, default=32)
parser.add_argument(
"--message", type=str, default="Tell me a story with more than 1000 words."
)
args = parser.parse_args()
main()
import datetime
import json
import logging
import logging.handlers
import os
import sys
import requests
from llava.constants import LOGDIR
server_error_msg = (
"**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
)
moderation_msg = (
"YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN."
)
handler = None
def build_logger(logger_name, logger_filename):
global handler
formatter = logging.Formatter(
fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
# Set the format of root handlers
if not logging.getLogger().handlers:
logging.basicConfig(level=logging.INFO)
logging.getLogger().handlers[0].setFormatter(formatter)
# Redirect stdout and stderr to loggers
stdout_logger = logging.getLogger("stdout")
stdout_logger.setLevel(logging.INFO)
sl = StreamToLogger(stdout_logger, logging.INFO)
sys.stdout = sl
stderr_logger = logging.getLogger("stderr")
stderr_logger.setLevel(logging.ERROR)
sl = StreamToLogger(stderr_logger, logging.ERROR)
sys.stderr = sl
# Get logger
logger = logging.getLogger(logger_name)
logger.setLevel(logging.INFO)
# Add a file handler for all loggers
if handler is None:
os.makedirs(LOGDIR, exist_ok=True)
filename = os.path.join(LOGDIR, logger_filename)
handler = logging.handlers.TimedRotatingFileHandler(
filename, when="D", utc=True, encoding="UTF-8"
)
handler.setFormatter(formatter)
for name, item in logging.root.manager.loggerDict.items():
if isinstance(item, logging.Logger):
item.addHandler(handler)
return logger
class StreamToLogger(object):
"""
Fake file-like stream object that redirects writes to a logger instance.
"""
def __init__(self, logger, log_level=logging.INFO):
self.terminal = sys.stdout
self.logger = logger
self.log_level = log_level
self.linebuf = ""
def __getattr__(self, attr):
return getattr(self.terminal, attr)
def write(self, buf):
temp_linebuf = self.linebuf + buf
self.linebuf = ""
for line in temp_linebuf.splitlines(True):
# From the io.TextIOWrapper docs:
# On output, if newline is None, any '\n' characters written
# are translated to the system default line separator.
# By default sys.stdout.write() expects '\n' newlines and then
# translates them so this is still cross platform.
if line[-1] == "\n":
self.logger.log(self.log_level, line.rstrip())
else:
self.linebuf += line
def flush(self):
if self.linebuf != "":
self.logger.log(self.log_level, self.linebuf.rstrip())
self.linebuf = ""
def disable_torch_init():
"""
Disable the redundant torch default initialization to accelerate model creation.
"""
import torch
setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
def violates_moderation(text):
"""
Check whether the text violates OpenAI moderation API.
"""
url = "https://api.openai.com/v1/moderations"
headers = {
"Content-Type": "application/json",
"Authorization": "Bearer " + os.environ["OPENAI_API_KEY"],
}
text = text.replace("\n", "")
data = "{" + '"input": ' + f'"{text}"' + "}"
data = data.encode("utf-8")
try:
ret = requests.post(url, headers=headers, data=data, timeout=5)
flagged = ret.json()["results"][0]["flagged"]
except requests.exceptions.RequestException as e:
flagged = False
except KeyError as e:
flagged = False
return flagged
def pretty_print_semaphore(semaphore):
if semaphore is None:
return "None"
return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"
import os
import argparse
import pandas as pd
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--img_dir", type=str, default="mllm/images")
parser.add_argument("--input_file", type=str, default="mllm/images/demo.csv")
args = parser.parse_args()
df = pd.DataFrame(columns=["img_path"])
df["img_path"] = [
os.path.join(args.img_dir, fn)
for fn in os.listdir(args.img_dir)
if fn.endswith(".png")
]
df.to_csv(args.input_file, index=False)
print("csv file saved to: ", args.input_file)
# Hunyuan-MLLM-TRTLLM
We provide a TensorRT-LLM version (int8 weight-only precision) of Hunyuan-Captioner for inference acceleration (Linux only).
## Hunyuan-Captioner-TRTLLM
### Instructions
a. Retrieve and launch the docker container
For a list of supported hardware, see the [**Frameworks Support Matrix**](https://docs.nvidia.com/deeplearning/frameworks/support-matrix/index.html).
Note that NVIDIA's official documentation does not cover support for consumer-grade graphics cards. Our tests show that the 4090 and 3080 work, but for other consumer-grade cards we cannot rule out performance problems or bugs, so please experiment on your own hardware.
```bash
docker pull nvcr.io/nvidia/tritonserver:24.06-trtllm-python-py3
docker run --rm --ipc=host --runtime=nvidia --gpus all --entrypoint /bin/bash -it nvcr.io/nvidia/tritonserver:24.06-trtllm-python-py3
```
b. Download the Torch model
```shell
huggingface-cli download Tencent-Hunyuan/HunyuanCaptioner --local-dir ../ckpts/captioner
```
c. Build the TRTLLM engine
```shell
sh build_trtllm.sh
```
### Inference
Our model supports three modes: **directly generating a Chinese caption**, **generating a Chinese caption based on specific knowledge**, and **directly generating an English caption**. The injected information can be either accurate cues or noisy labels (e.g., raw descriptions crawled from the internet); the model generates reliable and accurate descriptions based on both the inserted information and the image content. The prompt templates are listed in the table below, followed by an illustrative sketch of how they map to the modes.
| Mode | Prompt Template | Description |
| --- | --- | --- |
| caption_zh | 描述这张图片 | Caption in Chinese |
| insert_content | 根据提示词“{}”,描述这张图片 | Caption with inserted knowledge |
| caption_en | Please describe the content of this image | Caption in English |
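For illustration only, here is a minimal sketch of how these templates might be filled in per mode. The `PROMPT_TEMPLATES` dict and `build_prompt` helper are assumptions made for this README, not functions shipped in the repository:

```python
# Hypothetical helper (not part of the shipped scripts): map each mode to its
# prompt template from the table above and fill in the inserted knowledge.
PROMPT_TEMPLATES = {
    "caption_zh": "描述这张图片",
    "insert_content": "根据提示词“{}”,描述这张图片",
    "caption_en": "Please describe the content of this image",
}

def build_prompt(mode: str, content: str = "") -> str:
    """Return the text prompt for the chosen captioning mode."""
    template = PROMPT_TEMPLATES[mode]
    return template.format(content) if mode == "insert_content" else template

# Example: build_prompt("insert_content", "宫保鸡丁")
# -> '根据提示词“宫保鸡丁”,描述这张图片'
```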
a. Single-image inference in Chinese
```bash
python3 run_llava.py --max_new_tokens 512 --hf_model_dir ./llava-v1.6-mistral-7b-hf-merged/ --visual_engine_dir visual_engines/ --llm_engine_dir trt_engines/llava/int8/1-gpu --mode caption_zh --image_file ../images/demo1.png
```
b. Insert specific knowledge into the caption
```bash
python3 run_llava.py --max_new_tokens 512 --hf_model_dir ./llava-v1.6-mistral-7b-hf-merged/ --visual_engine_dir visual_engines/ --llm_engine_dir trt_engines/llava/int8/1-gpu --mode insert_content --image_file ../images/demo2.png --content 宫保鸡丁
```
c. Single-image inference in English
```bash
python3 run_llava.py --max_new_tokens 512 --hf_model_dir ./llava-v1.6-mistral-7b-hf-merged/ --visual_engine_dir visual_engines/ --llm_engine_dir trt_engines/llava/int8/1-gpu --mode caption_en --image_file ../images/demo1.png
```
d. Multi-image inference in Chinese
```bash
### Convert multiple images to a CSV file.
python3 ../make_csv.py --img_dir ../images --input_file ../images/demo.csv
### Run inference on multiple images.
python3 run_llava.py --max_new_tokens 512 --hf_model_dir ./llava-v1.6-mistral-7b-hf-merged/ --visual_engine_dir visual_engines/ --llm_engine_dir trt_engines/llava/int8/1-gpu --mode caption_zh --input_file ../images/demo.csv --output_file ../images/demo_res.csv
```
### Benchmark
| Hardware | GPU Memory Usage (GB) | TRTLLM Inference Duration (s) |
| --- | --- | --- |
| A100 | 8.9 | 0.73 |
| 4090 | 8.7 | 0.73 |
| 3080 | 8.7 | 1.16 |
# Set up the build environment (script provided alongside this file).
sh env.sh
# Download the base llava-v1.6-mistral-7b HF checkpoint.
git clone https://huggingface.co/llava-hf/llava-v1.6-mistral-7b-hf
# Merge the Hunyuan-Captioner weights into the HF llava-v1.6 layout.
python3 rename_key.py --huggingface_repo_dir ./llava-v1.6-mistral-7b-hf/ --thirdparty_repo_dir ../ckpts/captioner/ --merged_repo_dir ./llava-v1.6-mistral-7b-hf-merged/
# Convert the merged checkpoint to an int8 weight-only TRT-LLM checkpoint.
python3 convert_checkpoint.py --model_dir ./llava-v1.6-mistral-7b-hf-merged/ --output_dir tmp/trt_models/llava/int8/1-gpu --dtype float16 --use_weight_only --weight_only_precision int8
# Build the language-model engine.
trtllm-build --checkpoint_dir tmp/trt_models/llava/int8/1-gpu --output_dir trt_engines/llava/int8/1-gpu --gemm_plugin float16 --max_batch_size 1 --max_input_len 2048 --max_output_len 512 --max_multimodal_len 576
# Build the TensorRT engine for the vision encoder (llava_next).
python3 build_visual_engine.py --model_path ./llava-v1.6-mistral-7b-hf-merged/ --model_type llava_next
import argparse
import os
import shutil
import sys
import tarfile
from time import time
import yaml
# isort: off
import torch
import tensorrt as trt
from tensorrt_llm.builder import Builder
# isort: on
import torch.nn.functional as F
from PIL import Image
from transformers import (
AutoConfig,
AutoModel,
AutoModelForCausalLM,
AutoModelForVision2Seq,
AutoProcessor,
Blip2ForConditionalGeneration,
Blip2Processor,
FuyuForCausalLM,
FuyuProcessor,
LlavaForConditionalGeneration,
LlavaNextForConditionalGeneration,
NougatProcessor,
Pix2StructForConditionalGeneration,
VisionEncoderDecoderModel,
)
def parse_arguments():
parser = argparse.ArgumentParser()
parser.add_argument(
"--model_type",
type=str,
default=None,
choices=[
"opt-2.7b",
"opt-6.7b",
"flan-t5-xl",
"flan-t5-xxl",
"llava",
"llava_next",
"vila",
"nougat",
"cogvlm",
"fuyu",
"pix2struct",
"neva",
"kosmos-2",
],
help="Model type",
)
parser.add_argument(
"--model_path",
type=str,
default=None,
help="Huggingface repo, local directory with weights or path to checkpoint file",
)
parser.add_argument(
"--vila_path", type=str, default=None, help="Path to VILA source code directory"
)
parser.add_argument(
"--output_dir",
type=str,
default=None,
help="Directory where visual TRT engines are saved",
)
parser.add_argument(
"--max_batch_size",
type=int,
default=4,
help="Maximum batch size for input images",
)
return parser.parse_args()
class VisionEngineBuilder:
def __init__(self, args):
args.device = torch.device("cuda") if torch.cuda.is_available() else "cpu"
if args.output_dir is None:
args.output_dir = "visual_engines/%s" % (
args.model_path.split("/")[-1].split(".")[0]
)
if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir)
self.args = args
def build(self):
args = self.args
if "opt" in args.model_type or "t5" in args.model_type:
build_blip2_engine(args)
elif args.model_type == "pix2struct":
build_pix2struct_engine(args)
elif args.model_type == "llava":
build_llava_engine(args)
elif args.model_type == "llava_next":
build_llava_next_engine(args)
elif args.model_type == "vila":
assert (
args.vila_path is not None
), "Please clone and provide VILA source code path"
build_vila_engine(args)
elif args.model_type == "nougat":
build_nougat_engine(args)
elif args.model_type == "cogvlm":
build_cogvlm_engine(args)
elif args.model_type == "fuyu":
build_fuyu_engine(args)
elif args.model_type == "neva":
build_neva_engine(args)
elif args.model_type == "kosmos-2":
build_kosmos_engine(args)
else:
raise RuntimeError(f"Invalid model type {args.model_type}")
def export_visual_wrapper_onnx(
visual_wrapper,
input,
output_dir,
input_names=["input"],
dynamic_axes={"input": {0: "batch"}},
):
logger.log(trt.Logger.INFO, "Exporting onnx")
os.makedirs(f"{output_dir}/onnx", exist_ok=True)
torch.onnx.export(
visual_wrapper,
input,
f"{output_dir}/onnx/visual_encoder.onnx",
opset_version=17,
input_names=input_names,
output_names=["output"],
dynamic_axes=dynamic_axes,
)
def build_trt_engine(
model_type, input_sizes, output_dir, max_batch_size, dtype=torch.float16
):
part_name = "visual_encoder"
onnx_file = "%s/onnx/%s.onnx" % (output_dir, part_name)
engine_file = "%s/%s.engine" % (output_dir, part_name)
config_file = "%s/%s" % (output_dir, "config.json")
logger.log(trt.Logger.INFO, "Building TRT engine for %s" % part_name)
builder = trt.Builder(logger)
network = builder.create_network(
1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
)
profile = builder.create_optimization_profile()
config_wrapper = Builder().create_builder_config(
precision=str(dtype).split(".")[-1], model_type=model_type
)
config = config_wrapper.trt_builder_config
parser = trt.OnnxParser(network, logger)
with open(onnx_file, "rb") as model:
if not parser.parse(model.read(), os.path.abspath(onnx_file)):
logger.log(trt.Logger.ERROR, "Failed parsing %s" % onnx_file)
for error in range(parser.num_errors):
logger.log(trt.Logger.ERROR, parser.get_error(error))
logger.log(trt.Logger.INFO, "Succeeded parsing %s" % onnx_file)
# Delete onnx files since we don't need them now
# shutil.rmtree(f'{output_dir}/onnx')
nBS = -1
nMinBS = 1
nOptBS = max(nMinBS, int(max_batch_size / 2))
nMaxBS = max_batch_size
inputT = network.get_input(0)
# input sizes can be a list of ints (e.g., [3, H, W]) when inputs are images,
# or a list of three int lists (e.g., [[1, 1, 2700], [1, 500, 2700], [1, 4096, 2700]]).
assert isinstance(input_sizes, list), "input_sizes must be a list"
if isinstance(input_sizes[0], int):
logger.log(trt.Logger.INFO, f"Processed input sizes {input_sizes}")
inputT.shape = [nBS, *input_sizes]
min_size = opt_size = max_size = input_sizes
elif len(input_sizes) == 3 and isinstance(input_sizes[0], list):
min_size, opt_size, max_size = input_sizes
logger.log(
trt.Logger.INFO,
f"Processed min/opt/max input sizes {min_size}/{opt_size}/{max_size}",
)
else:
raise ValueError(f"invalid input sizes: {input_sizes}")
profile.set_shape(
inputT.name, [nMinBS, *min_size], [nOptBS, *opt_size], [nMaxBS, *max_size]
)
if model_type == "pix2struct":
inputT = network.get_input(1)
P = input_sizes[0] # Number of patches
inputT.shape = [nBS, P]
profile.set_shape(inputT.name, [nMinBS, P], [nOptBS, P], [nMaxBS, P])
config.add_optimization_profile(profile)
t0 = time()
engine_string = builder.build_serialized_network(network, config)
t1 = time()
if engine_string is None:
raise RuntimeError("Failed building %s" % (engine_file))
else:
logger.log(
trt.Logger.INFO, "Succeeded building %s in %d s" % (engine_file, t1 - t0)
)
with open(engine_file, "wb") as f:
f.write(engine_string)
Builder.save_config(config_wrapper, config_file)
def build_blip2_engine(args):
model_type = "Salesforce/blip2-" + args.model_type
processor = Blip2Processor.from_pretrained(model_type)
raw_image = Image.new("RGB", [10, 10]) # dummy image
prompt = "Question: what is this? Answer:"
inputs = processor(raw_image, prompt, return_tensors="pt").to(
args.device, torch.float16
)
image = inputs["pixel_values"]
class Blip2VisionWrapper(torch.nn.Module):
def __init__(self, vision_model, qformer, projector, query_tokens):
super().__init__()
self.vision_model = vision_model
self.qformer = qformer
self.projector = projector
self.query_tokens = query_tokens
def forward(self, image):
features = self.vision_model(image)[0]
qformer_output = self.qformer(
query_embeds=self.query_tokens,
encoder_hidden_states=features,
return_dict=True,
)
return self.projector(qformer_output.last_hidden_state)
model = Blip2ForConditionalGeneration.from_pretrained(
model_type, torch_dtype=torch.float16
)
wrapper = Blip2VisionWrapper(
model.vision_model, model.qformer, model.language_projection, model.query_tokens
)
wrapper.to(args.device)
export_visual_wrapper_onnx(wrapper, image, args.output_dir)
build_trt_engine(
model_type,
[image.shape[1], image.shape[2], image.shape[3]], # [3, H, W]
args.output_dir,
args.max_batch_size,
)
def build_pix2struct_engine(args):
processor = AutoProcessor.from_pretrained(args.model_path)
raw_image = Image.new("RGB", [10, 10]) # dummy image
dtype = torch.float16
inputs = processor(text="dummy", images=raw_image, return_tensors="pt")
image = inputs["flattened_patches"].to(args.device, dtype)
attention_mask = inputs["attention_mask"].to(args.device, torch.int)
class pix2structVisionWrapper(torch.nn.Module):
def __init__(self, encoder):
super().__init__()
self.encoder = encoder
def forward(self, image, attention_mask):
vision_x = self.encoder.embeddings(image)
img_features = self.encoder.encoder(vision_x, attention_mask=attention_mask)
img_features = self.encoder.layernorm(img_features[0])
return img_features
model = Pix2StructForConditionalGeneration.from_pretrained(
args.model_path, torch_dtype=dtype
)
wrapper = pix2structVisionWrapper(model.encoder.to(args.device))
# input shape: batch size, number of patches, hidden dimension
# attention mask shape: batch size, number of patches
# The number of image patches can vary depending on the image size, but it typically
# falls within a relatively narrow range. To improve performance, we can avoid using
# dynamic axis for the input patches and instead use a fixed number of patches along
# with an attention mask.
export_visual_wrapper_onnx(
wrapper,
(image, attention_mask),
args.output_dir,
input_names=["input", "attention_mask"],
dynamic_axes={"input": {0: "batch"}, "attention_mask": {0: "batch"}},
)
build_trt_engine(
args.model_type,
[image.shape[1], image.shape[2]], # Number of Patches, Hidden Dimension
args.output_dir,
args.max_batch_size,
torch.bfloat16,
)
def build_llava_engine(args):
processor = AutoProcessor.from_pretrained(args.model_path)
raw_image = Image.new("RGB", [10, 10]) # dummy image
image = processor(text="dummy", images=raw_image, return_tensors="pt")[
"pixel_values"
].to(args.device, torch.float16)
class LlavaVisionWrapper(torch.nn.Module):
def __init__(self, tower, projector, feature_layer):
super().__init__()
self.tower = tower
self.projector = projector
self.feature_layer = feature_layer
def forward(self, image):
all_hidden_states = self.tower(
image, output_hidden_states=True
).hidden_states
features = all_hidden_states[self.feature_layer][:, 1:]
return self.projector(features)
model = LlavaForConditionalGeneration.from_pretrained(
args.model_path, torch_dtype=torch.float16
)
wrapper = LlavaVisionWrapper(
model.vision_tower.to(args.device),
model.multi_modal_projector.to(args.device),
model.config.vision_feature_layer,
)
export_visual_wrapper_onnx(wrapper, image, args.output_dir)
build_trt_engine(
args.model_type,
[image.shape[1], image.shape[2], image.shape[3]], # [3, H, W]
args.output_dir,
args.max_batch_size,
)
def build_llava_next_engine(args):
processor = AutoProcessor.from_pretrained(args.model_path)
# raw_image = Image.new('RGB', [10, 10]) # dummy image
import requests
url = "https://www.ilankelman.org/stopsigns/australia.jpg"
raw_image = Image.open(requests.get(url, stream=True).raw)
image = processor(text="dummy", images=raw_image, return_tensors="pt")[
"pixel_values"
].to(args.device, torch.float16)
class LlavaVisionWrapper(torch.nn.Module):
def __init__(self, tower, projector, feature_layer):
super().__init__()
self.tower = tower
self.projector = projector
self.feature_layer = feature_layer
def forward(self, image):
all_hidden_states = self.tower(
image, output_hidden_states=True
).hidden_states
features = all_hidden_states[self.feature_layer][:, 1:]
return self.projector(features)
model = LlavaNextForConditionalGeneration.from_pretrained(
args.model_path, torch_dtype=torch.float16
)
wrapper = LlavaVisionWrapper(
model.vision_tower.to(args.device),
model.multi_modal_projector.to(args.device),
model.config.vision_feature_layer,
)
    # Flatten the per-image patch dimension so the vision wrapper sees a
    # 4-D (num_patches, channels, height, width) tensor for ONNX export.
pixel_values = image
image_num_patches = [pixel_values.shape[1]]
# figure out if pixel_values is concatenated or stacked
if image.dim() == 5:
# stacking when input is (batch_size, num_patches, num_channels, height, width)
_pixel_values_list = [
pix_val[:num_patch]
for pix_val, num_patch in zip(pixel_values, image_num_patches)
]
pixel_values = torch.cat(_pixel_values_list, dim=0)
elif pixel_values.dim() != 4:
# otherwise has to be stacked from list of (num_patches, num_channels, height, width)
raise ValueError(
f"pixel_values of shape {pixel_values.shape}, expect to be of 4 or 5 dimensions"
)
print("------Debug image: ", pixel_values, pixel_values.shape)
image = pixel_values
export_visual_wrapper_onnx(wrapper, image, args.output_dir)
build_trt_engine(
args.model_type,
[image.shape[1], image.shape[2], image.shape[3]], # [3, H, W]
args.output_dir,
args.max_batch_size,
)
def build_vila_engine(args):
# Note: VILA model is not in public HF model zoo yet. We need to explicitly import from the git repo
sys.path.append(args.vila_path)
from llava.model import LlavaLlamaForCausalLM
model = LlavaLlamaForCausalLM.from_pretrained(
args.model_path, torch_dtype=torch.float16
)
vision_tower = model.get_vision_tower()
image_processor = vision_tower.image_processor
raw_image = Image.new("RGB", [10, 10]) # dummy image
image = image_processor(images=raw_image, return_tensors="pt")["pixel_values"].to(
args.device, torch.float16
)
class VilaVisionWrapper(torch.nn.Module):
def __init__(self, tower, projector):
super().__init__()
self.tower = tower
self.projector = projector
def forward(self, image):
features = self.tower(image)
return self.projector(features)
model = LlavaLlamaForCausalLM.from_pretrained(
args.model_path, torch_dtype=torch.float16
)
wrapper = VilaVisionWrapper(
model.get_model().get_vision_tower().to(args.device),
model.get_model().mm_projector.to(args.device),
)
export_visual_wrapper_onnx(wrapper, image, args.output_dir)
build_trt_engine(
args.model_type,
[image.shape[1], image.shape[2], image.shape[3]], # [3, H, W]
args.output_dir,
args.max_batch_size,
)
def build_nougat_engine(args):
processor = NougatProcessor.from_pretrained(args.model_path)
raw_image = Image.new("RGB", [10, 10]) # dummy image
image = processor(raw_image, return_tensors="pt")["pixel_values"].to(
args.device, torch.float16
)
class SwinEncoderWrapper(torch.nn.Module):
def __init__(self, encoder):
super().__init__()
self.encoder = encoder
def forward(self, image):
return self.encoder(image).last_hidden_state
model = VisionEncoderDecoderModel.from_pretrained(
args.model_path, torch_dtype=torch.float16
)
swin_encoder = model.get_encoder().to(args.device)
wrapper = SwinEncoderWrapper(swin_encoder)
export_visual_wrapper_onnx(wrapper, image, args.output_dir)
build_trt_engine(
args.model_type,
[image.shape[1], image.shape[2], image.shape[3]], # [3, H, W]
args.output_dir,
args.max_batch_size,
)
def build_cogvlm_engine(args):
hf_config = AutoConfig.from_pretrained(args.model_path, trust_remote_code=True)
image_size = hf_config.vision_config["image_size"]
dtype = hf_config.torch_dtype
image = torch.empty(
1, 3, image_size, image_size, dtype=dtype, device=args.device
) # dummy image
class CogVlmVisionWrapper(torch.nn.Module):
def __init__(self, encoder):
super().__init__()
self.encoder = encoder
def forward(self, image):
return self.encoder(image)
cogvlm = AutoModelForCausalLM.from_pretrained(
args.model_path, torch_dtype=dtype, trust_remote_code=True
)
vit_encoder = cogvlm.model.vision.to(args.device).eval()
wrapper = CogVlmVisionWrapper(vit_encoder)
export_visual_wrapper_onnx(wrapper, image, args.output_dir)
build_trt_engine(
args.model_type,
[image.shape[1], image.shape[2], image.shape[3]], # [3, H, W]
args.output_dir,
args.max_batch_size,
dtype,
)
def build_fuyu_engine(args):
processor = FuyuProcessor.from_pretrained(args.model_path)
raw_image = Image.new("RGB", [10, 10])
image = (
processor(text="dummy", images=raw_image, return_tensors="pt")["image_patches"][
0
]
.to(args.device, torch.float16)
.unsqueeze(0)
)
class FuyuEncoderWrapper(torch.nn.Module):
def __init__(self, linear):
super().__init__()
self.linear = linear.to(torch.float16)
def forward(self, patches):
return self.linear(patches).flatten(0, 1)
model = FuyuForCausalLM.from_pretrained(args.model_path, torch_dtype=torch.float16)
vision_encoder = model.vision_embed_tokens
wrapper = FuyuEncoderWrapper(vision_encoder).to(args.device)
export_visual_wrapper_onnx(
wrapper,
image,
args.output_dir,
dynamic_axes={"input": {0: "batch", 2: "patch"}},
)
build_trt_engine(
args.model_type,
# [nImgs, nImgPatches, nDims]
# nImgs is always one since each query has exactly one image
# nImgPatches depends on image size (patch size: 30x30)
# nDims is 30x30x3=2700 (patch size x color channels)
[[1, 1, 2700], [1, 500, 2700], [1, 4096, 2700]],
args.output_dir,
args.max_batch_size,
)
def build_neva_engine(args):
# extract NeMo checkpoint
with tarfile.open(args.model_path) as tar:
nemo_config = yaml.safe_load(tar.extractfile("./model_config.yaml"))
try:
# trained without TP
mp0_weights = torch.load(
tar.extractfile("./model_weights.ckpt"), map_location=args.device
)
except KeyError:
# trained with TP
mp0_weights = torch.load(
tar.extractfile("./mp_rank_00/model_weights.ckpt"),
map_location=args.device,
)
vision_config = nemo_config["mm_cfg"]["vision_encoder"]
class VisionEncoderWrapper(torch.nn.Module):
def __init__(self, encoder, connector):
super().__init__()
self.encoder = encoder
self.connector = connector
def forward(self, images):
vision_x = self.encoder(pixel_values=images, output_hidden_states=True)
vision_x = vision_x.hidden_states[-2]
vision_x = vision_x[:, 1:]
vision_x = self.connector(vision_x)
return vision_x
encoder = AutoModel.from_pretrained(
vision_config["from_pretrained"],
torch_dtype=torch.bfloat16,
trust_remote_code=True,
)
vision_encoder = encoder.vision_model
hf_config = encoder.config
dtype = hf_config.torch_dtype
# connector
assert nemo_config["mm_cfg"]["mm_mlp_adapter_type"] == "mlp2x_gelu"
vision_connector = torch.nn.Sequential(
torch.nn.Linear(
vision_config["hidden_size"], nemo_config["hidden_size"], bias=True
),
torch.nn.GELU(),
torch.nn.Linear(
nemo_config["hidden_size"], nemo_config["hidden_size"], bias=True
),
).to(dtype=dtype)
key_prefix = "model.embedding.word_embeddings.adapter_layer.mm_projector_adapter.mm_projector"
for layer in range(0, 3, 2):
vision_connector[layer].load_state_dict(
{
"weight": mp0_weights[f"{key_prefix}.{layer}.weight"].to(dtype),
"bias": mp0_weights[f"{key_prefix}.{layer}.bias"].to(dtype),
}
)
# export the whole wrapper
wrapper = VisionEncoderWrapper(vision_encoder, vision_connector).to(
args.device, dtype
)
image_size = hf_config.vision_config.image_size
dummy_image = torch.empty(
1, 3, image_size, image_size, dtype=dtype, device=args.device
) # dummy image
export_visual_wrapper_onnx(wrapper, dummy_image, args.output_dir)
build_trt_engine(
args.model_type,
[3, image_size, image_size], # [3, H, W]
args.output_dir,
args.max_batch_size,
dtype,
)
def build_kosmos_engine(args):
processor = AutoProcessor.from_pretrained(args.model_path)
raw_image = Image.new("RGB", [10, 10]) # dummy image
image = processor(text="dummy", images=raw_image, return_tensors="pt")[
"pixel_values"
].to(args.device, torch.float16)
class VisionEncoderWrapper(torch.nn.Module):
def __init__(self, encoder, connector):
super().__init__()
self.encoder = encoder
self.connector = connector
def forward(self, images):
vision_x = self.encoder(images, output_hidden_states=True)
img_features = self.encoder.model.post_layernorm(vision_x.last_hidden_state)
img_features = F.normalize(img_features, dim=-1)
img_features, _ = self.connector(img_features)
return img_features
model = AutoModelForVision2Seq.from_pretrained(
args.model_path, torch_dtype=torch.float16
)
wrapper = VisionEncoderWrapper(
model.vision_model.to(args.device),
model.image_to_text_projection.to(args.device),
)
export_visual_wrapper_onnx(wrapper, image, args.output_dir)
build_trt_engine(
args.model_type,
[image.shape[1], image.shape[2], image.shape[3]], # [3, H, W]
args.output_dir,
args.max_batch_size,
)
if __name__ == "__main__":
logger = trt.Logger(trt.Logger.INFO)
args = parse_arguments()
builder = VisionEngineBuilder(args)
builder.build()