import argparse import os import shutil import sys import tarfile from time import time import yaml # isort: off import torch import tensorrt as trt from tensorrt_llm.builder import Builder # isort: on import torch.nn.functional as F from PIL import Image from transformers import ( AutoConfig, AutoModel, AutoModelForCausalLM, AutoModelForVision2Seq, AutoProcessor, Blip2ForConditionalGeneration, Blip2Processor, FuyuForCausalLM, FuyuProcessor, LlavaForConditionalGeneration, LlavaNextForConditionalGeneration, NougatProcessor, Pix2StructForConditionalGeneration, VisionEncoderDecoderModel, ) def parse_arguments(): parser = argparse.ArgumentParser() parser.add_argument( "--model_type", type=str, default=None, choices=[ "opt-2.7b", "opt-6.7b", "flan-t5-xl", "flan-t5-xxl", "llava", "llava_next", "vila", "nougat", "cogvlm", "fuyu", "pix2struct", "neva", "kosmos-2", ], help="Model type", ) parser.add_argument( "--model_path", type=str, default=None, help="Huggingface repo, local directory with weights or path to checkpoint file", ) parser.add_argument( "--vila_path", type=str, default=None, help="Path to VILA source code directory" ) parser.add_argument( "--output_dir", type=str, default=None, help="Directory where visual TRT engines are saved", ) parser.add_argument( "--max_batch_size", type=int, default=4, help="Maximum batch size for input images", ) return parser.parse_args() class VisionEngineBuilder: def __init__(self, args): args.device = torch.device("cuda") if torch.cuda.is_available() else "cpu" if args.output_dir is None: args.output_dir = "visual_engines/%s" % ( args.model_path.split("/")[-1].split(".")[0] ) if not os.path.exists(args.output_dir): os.makedirs(args.output_dir) self.args = args def build(self): args = self.args if "opt" in args.model_type or "t5" in args.model_type: build_blip2_engine(args) elif args.model_type == "pix2struct": build_pix2struct_engine(args) elif args.model_type == "llava": build_llava_engine(args) elif args.model_type == "llava_next": build_llava_next_engine(args) elif args.model_type == "vila": assert ( args.vila_path is not None ), "Please clone and provide VILA source code path" build_vila_engine(args) elif args.model_type == "nougat": build_nougat_engine(args) elif args.model_type == "cogvlm": build_cogvlm_engine(args) elif args.model_type == "fuyu": build_fuyu_engine(args) elif args.model_type == "neva": build_neva_engine(args) elif args.model_type == "kosmos-2": build_kosmos_engine(args) else: raise RuntimeError(f"Invalid model type {args.model_type}") def export_visual_wrapper_onnx( visual_wrapper, input, output_dir, input_names=["input"], dynamic_axes={"input": {0: "batch"}}, ): logger.log(trt.Logger.INFO, "Exporting onnx") os.makedirs(f"{output_dir}/onnx", exist_ok=True) torch.onnx.export( visual_wrapper, input, f"{output_dir}/onnx/visual_encoder.onnx", opset_version=17, input_names=input_names, output_names=["output"], dynamic_axes=dynamic_axes, ) def build_trt_engine( model_type, input_sizes, output_dir, max_batch_size, dtype=torch.float16 ): part_name = "visual_encoder" onnx_file = "%s/onnx/%s.onnx" % (output_dir, part_name) engine_file = "%s/%s.engine" % (output_dir, part_name) config_file = "%s/%s" % (output_dir, "config.json") logger.log(trt.Logger.INFO, "Building TRT engine for %s" % part_name) builder = trt.Builder(logger) network = builder.create_network( 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH) ) profile = builder.create_optimization_profile() config_wrapper = Builder().create_builder_config( precision=str(dtype).split(".")[-1], model_type=model_type ) config = config_wrapper.trt_builder_config parser = trt.OnnxParser(network, logger) with open(onnx_file, "rb") as model: if not parser.parse(model.read(), os.path.abspath(onnx_file)): logger.log(trt.Logger.ERROR, "Failed parsing %s" % onnx_file) for error in range(parser.num_errors): logger.log(trt.Logger.ERROR, parser.get_error(error)) logger.log(trt.Logger.INFO, "Succeeded parsing %s" % onnx_file) # Delete onnx files since we don't need them now # shutil.rmtree(f'{output_dir}/onnx') nBS = -1 nMinBS = 1 nOptBS = max(nMinBS, int(max_batch_size / 2)) nMaxBS = max_batch_size inputT = network.get_input(0) # input sizes can be a list of ints (e.g., [3, H, W]) when inputs are images, # or a list of three int lists (e.g., [[1, 1, 2700], [1, 500, 2700], [1, 4096, 2700]]). assert isinstance(input_sizes, list), "input_sizes must be a list" if isinstance(input_sizes[0], int): logger.log(trt.Logger.INFO, f"Processed input sizes {input_sizes}") inputT.shape = [nBS, *input_sizes] min_size = opt_size = max_size = input_sizes elif len(input_sizes) == 3 and isinstance(input_sizes[0], list): min_size, opt_size, max_size = input_sizes logger.log( trt.Logger.INFO, f"Processed min/opt/max input sizes {min_size}/{opt_size}/{max_size}", ) else: raise ValueError(f"invalid input sizes: {input_sizes}") profile.set_shape( inputT.name, [nMinBS, *min_size], [nOptBS, *opt_size], [nMaxBS, *max_size] ) if model_type == "pix2struct": inputT = network.get_input(1) P = input_sizes[0] # Number of patches inputT.shape = [nBS, P] profile.set_shape(inputT.name, [nMinBS, P], [nOptBS, P], [nMaxBS, P]) config.add_optimization_profile(profile) t0 = time() engine_string = builder.build_serialized_network(network, config) t1 = time() if engine_string is None: raise RuntimeError("Failed building %s" % (engine_file)) else: logger.log( trt.Logger.INFO, "Succeeded building %s in %d s" % (engine_file, t1 - t0) ) with open(engine_file, "wb") as f: f.write(engine_string) Builder.save_config(config_wrapper, config_file) def build_blip2_engine(args): model_type = "Salesforce/blip2-" + args.model_type processor = Blip2Processor.from_pretrained(model_type) raw_image = Image.new("RGB", [10, 10]) # dummy image prompt = "Question: what is this? Answer:" inputs = processor(raw_image, prompt, return_tensors="pt").to( args.device, torch.float16 ) image = inputs["pixel_values"] class Blip2VisionWrapper(torch.nn.Module): def __init__(self, vision_model, qformer, projector, query_tokens): super().__init__() self.vision_model = vision_model self.qformer = qformer self.projector = projector self.query_tokens = query_tokens def forward(self, image): features = self.vision_model(image)[0] qformer_output = self.qformer( query_embeds=self.query_tokens, encoder_hidden_states=features, return_dict=True, ) return self.projector(qformer_output.last_hidden_state) model = Blip2ForConditionalGeneration.from_pretrained( model_type, torch_dtype=torch.float16 ) wrapper = Blip2VisionWrapper( model.vision_model, model.qformer, model.language_projection, model.query_tokens ) wrapper.to(args.device) export_visual_wrapper_onnx(wrapper, image, args.output_dir) build_trt_engine( model_type, [image.shape[1], image.shape[2], image.shape[3]], # [3, H, W] args.output_dir, args.max_batch_size, ) def build_pix2struct_engine(args): processor = AutoProcessor.from_pretrained(args.model_path) raw_image = Image.new("RGB", [10, 10]) # dummy image dtype = torch.float16 inputs = processor(text="dummy", images=raw_image, return_tensors="pt") image = inputs["flattened_patches"].to(args.device, dtype) attention_mask = inputs["attention_mask"].to(args.device, torch.int) class pix2structVisionWrapper(torch.nn.Module): def __init__(self, encoder): super().__init__() self.encoder = encoder def forward(self, image, attention_mask): vision_x = self.encoder.embeddings(image) img_features = self.encoder.encoder(vision_x, attention_mask=attention_mask) img_features = self.encoder.layernorm(img_features[0]) return img_features model = Pix2StructForConditionalGeneration.from_pretrained( args.model_path, torch_dtype=dtype ) wrapper = pix2structVisionWrapper(model.encoder.to(args.device)) # input shape: batch size, number of patches, hidden dimension # attention mask shape: batch size, number of patches # The number of image patches can vary depending on the image size, but it typically # falls within a relatively narrow range. To improve performance, we can avoid using # dynamic axis for the input patches and instead use a fixed number of patches along # with an attention mask. export_visual_wrapper_onnx( wrapper, (image, attention_mask), args.output_dir, input_names=["input", "attention_mask"], dynamic_axes={"input": {0: "batch"}, "attention_mask": {0: "batch"}}, ) build_trt_engine( args.model_type, [image.shape[1], image.shape[2]], # Number of Patches, Hidden Dimension args.output_dir, args.max_batch_size, torch.bfloat16, ) def build_llava_engine(args): processor = AutoProcessor.from_pretrained(args.model_path) raw_image = Image.new("RGB", [10, 10]) # dummy image image = processor(text="dummy", images=raw_image, return_tensors="pt")[ "pixel_values" ].to(args.device, torch.float16) class LlavaVisionWrapper(torch.nn.Module): def __init__(self, tower, projector, feature_layer): super().__init__() self.tower = tower self.projector = projector self.feature_layer = feature_layer def forward(self, image): all_hidden_states = self.tower( image, output_hidden_states=True ).hidden_states features = all_hidden_states[self.feature_layer][:, 1:] return self.projector(features) model = LlavaForConditionalGeneration.from_pretrained( args.model_path, torch_dtype=torch.float16 ) wrapper = LlavaVisionWrapper( model.vision_tower.to(args.device), model.multi_modal_projector.to(args.device), model.config.vision_feature_layer, ) export_visual_wrapper_onnx(wrapper, image, args.output_dir) build_trt_engine( args.model_type, [image.shape[1], image.shape[2], image.shape[3]], # [3, H, W] args.output_dir, args.max_batch_size, ) def build_llava_next_engine(args): processor = AutoProcessor.from_pretrained(args.model_path) # raw_image = Image.new('RGB', [10, 10]) # dummy image import requests url = "https://www.ilankelman.org/stopsigns/australia.jpg" raw_image = Image.open(requests.get(url, stream=True).raw) image = processor(text="dummy", images=raw_image, return_tensors="pt")[ "pixel_values" ].to(args.device, torch.float16) class LlavaVisionWrapper(torch.nn.Module): def __init__(self, tower, projector, feature_layer): super().__init__() self.tower = tower self.projector = projector self.feature_layer = feature_layer def forward(self, image): all_hidden_states = self.tower( image, output_hidden_states=True ).hidden_states features = all_hidden_states[self.feature_layer][:, 1:] return self.projector(features) model = LlavaNextForConditionalGeneration.from_pretrained( args.model_path, torch_dtype=torch.float16 ) wrapper = LlavaVisionWrapper( model.vision_tower.to(args.device), model.multi_modal_projector.to(args.device), model.config.vision_feature_layer, ) # 2. Merge text and images # ! infer image_num_patches from image_sizes pixel_values = image image_num_patches = [pixel_values.shape[1]] # figure out if pixel_values is concatenated or stacked if image.dim() == 5: # stacking when input is (batch_size, num_patches, num_channels, height, width) _pixel_values_list = [ pix_val[:num_patch] for pix_val, num_patch in zip(pixel_values, image_num_patches) ] pixel_values = torch.cat(_pixel_values_list, dim=0) elif pixel_values.dim() != 4: # otherwise has to be stacked from list of (num_patches, num_channels, height, width) raise ValueError( f"pixel_values of shape {pixel_values.shape}, expect to be of 4 or 5 dimensions" ) print("------Debug image: ", pixel_values, pixel_values.shape) image = pixel_values export_visual_wrapper_onnx(wrapper, image, args.output_dir) build_trt_engine( args.model_type, [image.shape[1], image.shape[2], image.shape[3]], # [3, H, W] args.output_dir, args.max_batch_size, ) def build_vila_engine(args): # Note: VILA model is not in public HF model zoo yet. We need to explicitly import from the git repo sys.path.append(args.vila_path) from llava.model import LlavaLlamaForCausalLM model = LlavaLlamaForCausalLM.from_pretrained( args.model_path, torch_dtype=torch.float16 ) vision_tower = model.get_vision_tower() image_processor = vision_tower.image_processor raw_image = Image.new("RGB", [10, 10]) # dummy image image = image_processor(images=raw_image, return_tensors="pt")["pixel_values"].to( args.device, torch.float16 ) class VilaVisionWrapper(torch.nn.Module): def __init__(self, tower, projector): super().__init__() self.tower = tower self.projector = projector def forward(self, image): features = self.tower(image) return self.projector(features) model = LlavaLlamaForCausalLM.from_pretrained( args.model_path, torch_dtype=torch.float16 ) wrapper = VilaVisionWrapper( model.get_model().get_vision_tower().to(args.device), model.get_model().mm_projector.to(args.device), ) export_visual_wrapper_onnx(wrapper, image, args.output_dir) build_trt_engine( args.model_type, [image.shape[1], image.shape[2], image.shape[3]], # [3, H, W] args.output_dir, args.max_batch_size, ) def build_nougat_engine(args): processor = NougatProcessor.from_pretrained(args.model_path) raw_image = Image.new("RGB", [10, 10]) # dummy image image = processor(raw_image, return_tensors="pt")["pixel_values"].to( args.device, torch.float16 ) class SwinEncoderWrapper(torch.nn.Module): def __init__(self, encoder): super().__init__() self.encoder = encoder def forward(self, image): return self.encoder(image).last_hidden_state model = VisionEncoderDecoderModel.from_pretrained( args.model_path, torch_dtype=torch.float16 ) swin_encoder = model.get_encoder().to(args.device) wrapper = SwinEncoderWrapper(swin_encoder) export_visual_wrapper_onnx(wrapper, image, args.output_dir) build_trt_engine( args.model_type, [image.shape[1], image.shape[2], image.shape[3]], # [3, H, W] args.output_dir, args.max_batch_size, ) def build_cogvlm_engine(args): hf_config = AutoConfig.from_pretrained(args.model_path, trust_remote_code=True) image_size = hf_config.vision_config["image_size"] dtype = hf_config.torch_dtype image = torch.empty( 1, 3, image_size, image_size, dtype=dtype, device=args.device ) # dummy image class CogVlmVisionWrapper(torch.nn.Module): def __init__(self, encoder): super().__init__() self.encoder = encoder def forward(self, image): return self.encoder(image) cogvlm = AutoModelForCausalLM.from_pretrained( args.model_path, torch_dtype=dtype, trust_remote_code=True ) vit_encoder = cogvlm.model.vision.to(args.device).eval() wrapper = CogVlmVisionWrapper(vit_encoder) export_visual_wrapper_onnx(wrapper, image, args.output_dir) build_trt_engine( args.model_type, [image.shape[1], image.shape[2], image.shape[3]], # [3, H, W] args.output_dir, args.max_batch_size, dtype, ) def build_fuyu_engine(args): processor = FuyuProcessor.from_pretrained(args.model_path) raw_image = Image.new("RGB", [10, 10]) image = ( processor(text="dummy", images=raw_image, return_tensors="pt")["image_patches"][ 0 ] .to(args.device, torch.float16) .unsqueeze(0) ) class FuyuEncoderWrapper(torch.nn.Module): def __init__(self, linear): super().__init__() self.linear = linear.to(torch.float16) def forward(self, patches): return self.linear(patches).flatten(0, 1) model = FuyuForCausalLM.from_pretrained(args.model_path, torch_dtype=torch.float16) vision_encoder = model.vision_embed_tokens wrapper = FuyuEncoderWrapper(vision_encoder).to(args.device) export_visual_wrapper_onnx( wrapper, image, args.output_dir, dynamic_axes={"input": {0: "batch", 2: "patch"}}, ) build_trt_engine( args.model_type, # [nImgs, nImgPatches, nDims] # nImgs is always one since each query has exactly one image # nImgPatches depends on image size (patch size: 30x30) # nDims is 30x30x3=2700 (patch size x color channels) [[1, 1, 2700], [1, 500, 2700], [1, 4096, 2700]], args.output_dir, args.max_batch_size, ) def build_neva_engine(args): # extract NeMo checkpoint with tarfile.open(args.model_path) as tar: nemo_config = yaml.safe_load(tar.extractfile("./model_config.yaml")) try: # trained without TP mp0_weights = torch.load( tar.extractfile("./model_weights.ckpt"), map_location=args.device ) except KeyError: # trained with TP mp0_weights = torch.load( tar.extractfile("./mp_rank_00/model_weights.ckpt"), map_location=args.device, ) vision_config = nemo_config["mm_cfg"]["vision_encoder"] class VisionEncoderWrapper(torch.nn.Module): def __init__(self, encoder, connector): super().__init__() self.encoder = encoder self.connector = connector def forward(self, images): vision_x = self.encoder(pixel_values=images, output_hidden_states=True) vision_x = vision_x.hidden_states[-2] vision_x = vision_x[:, 1:] vision_x = self.connector(vision_x) return vision_x encoder = AutoModel.from_pretrained( vision_config["from_pretrained"], torch_dtype=torch.bfloat16, trust_remote_code=True, ) vision_encoder = encoder.vision_model hf_config = encoder.config dtype = hf_config.torch_dtype # connector assert nemo_config["mm_cfg"]["mm_mlp_adapter_type"] == "mlp2x_gelu" vision_connector = torch.nn.Sequential( torch.nn.Linear( vision_config["hidden_size"], nemo_config["hidden_size"], bias=True ), torch.nn.GELU(), torch.nn.Linear( nemo_config["hidden_size"], nemo_config["hidden_size"], bias=True ), ).to(dtype=dtype) key_prefix = "model.embedding.word_embeddings.adapter_layer.mm_projector_adapter.mm_projector" for layer in range(0, 3, 2): vision_connector[layer].load_state_dict( { "weight": mp0_weights[f"{key_prefix}.{layer}.weight"].to(dtype), "bias": mp0_weights[f"{key_prefix}.{layer}.bias"].to(dtype), } ) # export the whole wrapper wrapper = VisionEncoderWrapper(vision_encoder, vision_connector).to( args.device, dtype ) image_size = hf_config.vision_config.image_size dummy_image = torch.empty( 1, 3, image_size, image_size, dtype=dtype, device=args.device ) # dummy image export_visual_wrapper_onnx(wrapper, dummy_image, args.output_dir) build_trt_engine( args.model_type, [3, image_size, image_size], # [3, H, W] args.output_dir, args.max_batch_size, dtype, ) def build_kosmos_engine(args): processor = AutoProcessor.from_pretrained(args.model_path) raw_image = Image.new("RGB", [10, 10]) # dummy image image = processor(text="dummy", images=raw_image, return_tensors="pt")[ "pixel_values" ].to(args.device, torch.float16) class VisionEncoderWrapper(torch.nn.Module): def __init__(self, encoder, connector): super().__init__() self.encoder = encoder self.connector = connector def forward(self, images): vision_x = self.encoder(images, output_hidden_states=True) img_features = self.encoder.model.post_layernorm(vision_x.last_hidden_state) img_features = F.normalize(img_features, dim=-1) img_features, _ = self.connector(img_features) return img_features model = AutoModelForVision2Seq.from_pretrained( args.model_path, torch_dtype=torch.float16 ) wrapper = VisionEncoderWrapper( model.vision_model.to(args.device), model.image_to_text_projection.to(args.device), ) export_visual_wrapper_onnx(wrapper, image, args.output_dir) build_trt_engine( args.model_type, [image.shape[1], image.shape[2], image.shape[3]], # [3, H, W] args.output_dir, args.max_batch_size, ) if __name__ == "__main__": logger = trt.Logger(trt.Logger.INFO) args = parse_arguments() builder = VisionEngineBuilder(args) builder.build()