Commit 2581b885 authored by jerrrrry's avatar jerrrrry
Browse files

Initial commit

parents
Pipeline #3320 canceled with stages
import argparse
import os
import shutil
import sys
import tarfile
from time import time
import yaml
# isort: off
import torch
import tensorrt as trt
from tensorrt_llm.builder import Builder
# isort: on
import torch.nn.functional as F
from PIL import Image
from transformers import (
AutoConfig,
AutoModel,
AutoModelForCausalLM,
AutoModelForVision2Seq,
AutoProcessor,
Blip2ForConditionalGeneration,
Blip2Processor,
FuyuForCausalLM,
FuyuProcessor,
LlavaForConditionalGeneration,
LlavaNextForConditionalGeneration,
NougatProcessor,
Pix2StructForConditionalGeneration,
VisionEncoderDecoderModel,
)
def parse_arguments():
    """Build and parse the command-line options for visual-engine building."""
    parser = argparse.ArgumentParser()
    supported_models = [
        "opt-2.7b",
        "opt-6.7b",
        "flan-t5-xl",
        "flan-t5-xxl",
        "llava",
        "llava_next",
        "vila",
        "nougat",
        "cogvlm",
        "fuyu",
        "pix2struct",
        "neva",
        "kosmos-2",
    ]
    parser.add_argument(
        "--model_type",
        type=str,
        default=None,
        choices=supported_models,
        help="Model type",
    )
    parser.add_argument(
        "--model_path",
        type=str,
        default=None,
        help="Huggingface repo, local directory with weights or path to checkpoint file",
    )
    parser.add_argument(
        "--vila_path", type=str, default=None, help="Path to VILA source code directory"
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default=None,
        help="Directory where visual TRT engines are saved",
    )
    parser.add_argument(
        "--max_batch_size",
        type=int,
        default=4,
        help="Maximum batch size for input images",
    )
    return parser.parse_args()
class VisionEngineBuilder:
    """Prepares the output directory and dispatches engine building by model type."""

    def __init__(self, args):
        # Prefer GPU when available; all builders move tensors to args.device.
        args.device = torch.device("cuda") if torch.cuda.is_available() else "cpu"
        if args.output_dir is None:
            model_name = args.model_path.split("/")[-1].split(".")[0]
            args.output_dir = "visual_engines/%s" % model_name
        if not os.path.exists(args.output_dir):
            os.makedirs(args.output_dir)
        self.args = args

    def build(self):
        """Route to the per-model engine builder based on args.model_type."""
        args = self.args
        # BLIP-2 backends are selected by substring ("opt-*" / "flan-t5-*").
        if "opt" in args.model_type or "t5" in args.model_type:
            build_blip2_engine(args)
            return
        if args.model_type == "vila":
            assert (
                args.vila_path is not None
            ), "Please clone and provide VILA source code path"
            build_vila_engine(args)
            return
        dispatch = {
            "pix2struct": build_pix2struct_engine,
            "llava": build_llava_engine,
            "llava_next": build_llava_next_engine,
            "nougat": build_nougat_engine,
            "cogvlm": build_cogvlm_engine,
            "fuyu": build_fuyu_engine,
            "neva": build_neva_engine,
            "kosmos-2": build_kosmos_engine,
        }
        try:
            builder_fn = dispatch[args.model_type]
        except KeyError:
            raise RuntimeError(f"Invalid model type {args.model_type}")
        builder_fn(args)
def export_visual_wrapper_onnx(
    visual_wrapper,
    input,
    output_dir,
    input_names=None,
    dynamic_axes=None,
):
    """Export the wrapped vision model to ``<output_dir>/onnx/visual_encoder.onnx``.

    Args:
        visual_wrapper: torch.nn.Module producing the visual features.
        input: example tensor (or tuple of tensors) used for tracing.
        output_dir: directory under which the onnx/ subdirectory is created.
        input_names: ONNX input names; defaults to ``["input"]``.
        dynamic_axes: dynamic-axis spec; defaults to a dynamic batch dim on "input".
    """
    # BUG FIX: the defaults were mutable default arguments (a shared list and
    # dict); build them per call instead.
    if input_names is None:
        input_names = ["input"]
    if dynamic_axes is None:
        dynamic_axes = {"input": {0: "batch"}}
    logger.log(trt.Logger.INFO, "Exporting onnx")
    os.makedirs(f"{output_dir}/onnx", exist_ok=True)
    torch.onnx.export(
        visual_wrapper,
        input,
        f"{output_dir}/onnx/visual_encoder.onnx",
        opset_version=17,
        input_names=input_names,
        output_names=["output"],
        dynamic_axes=dynamic_axes,
    )
def build_trt_engine(
    model_type, input_sizes, output_dir, max_batch_size, dtype=torch.float16
):
    """Parse the exported ONNX visual encoder and build a serialized TRT engine.

    Args:
        model_type: model identifier written into the engine config.
        input_sizes: either a flat int list (per-image shape, e.g. [3, H, W])
            or three int lists [min, opt, max] for dynamic non-batch dims.
        output_dir: directory containing onnx/ and receiving the .engine
            and config.json files.
        max_batch_size: maximum batch dimension of the optimization profile
            (opt is max/2, min is 1).
        dtype: builder precision (default fp16).

    Raises:
        ValueError: if input_sizes has neither of the two accepted shapes.
        RuntimeError: if TensorRT fails to build the engine.
    """
    part_name = "visual_encoder"
    onnx_file = "%s/onnx/%s.onnx" % (output_dir, part_name)
    engine_file = "%s/%s.engine" % (output_dir, part_name)
    config_file = "%s/%s" % (output_dir, "config.json")
    logger.log(trt.Logger.INFO, "Building TRT engine for %s" % part_name)
    builder = trt.Builder(logger)
    # Explicit-batch network: the batch dim is part of the tensor shape.
    network = builder.create_network(
        1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    )
    profile = builder.create_optimization_profile()
    # tensorrt_llm's Builder wraps the raw TRT builder config and also
    # handles serializing config.json at the end.
    config_wrapper = Builder().create_builder_config(
        precision=str(dtype).split(".")[-1], model_type=model_type
    )
    config = config_wrapper.trt_builder_config
    parser = trt.OnnxParser(network, logger)
    with open(onnx_file, "rb") as model:
        if not parser.parse(model.read(), os.path.abspath(onnx_file)):
            logger.log(trt.Logger.ERROR, "Failed parsing %s" % onnx_file)
            for error in range(parser.num_errors):
                logger.log(trt.Logger.ERROR, parser.get_error(error))
        logger.log(trt.Logger.INFO, "Succeeded parsing %s" % onnx_file)
    # Delete onnx files since we don't need them now
    # shutil.rmtree(f'{output_dir}/onnx')
    nBS = -1  # -1 marks the batch dim as dynamic
    nMinBS = 1
    nOptBS = max(nMinBS, int(max_batch_size / 2))
    nMaxBS = max_batch_size
    inputT = network.get_input(0)
    # input sizes can be a list of ints (e.g., [3, H, W]) when inputs are images,
    # or a list of three int lists (e.g., [[1, 1, 2700], [1, 500, 2700], [1, 4096, 2700]]).
    assert isinstance(input_sizes, list), "input_sizes must be a list"
    if isinstance(input_sizes[0], int):
        logger.log(trt.Logger.INFO, f"Processed input sizes {input_sizes}")
        inputT.shape = [nBS, *input_sizes]
        # Static non-batch dims: min/opt/max only differ in the batch dim.
        min_size = opt_size = max_size = input_sizes
    elif len(input_sizes) == 3 and isinstance(input_sizes[0], list):
        min_size, opt_size, max_size = input_sizes
        logger.log(
            trt.Logger.INFO,
            f"Processed min/opt/max input sizes {min_size}/{opt_size}/{max_size}",
        )
    else:
        raise ValueError(f"invalid input sizes: {input_sizes}")
    profile.set_shape(
        inputT.name, [nMinBS, *min_size], [nOptBS, *opt_size], [nMaxBS, *max_size]
    )
    if model_type == "pix2struct":
        # pix2struct also has an attention-mask input of shape (batch, P).
        inputT = network.get_input(1)
        P = input_sizes[0]  # Number of patches
        inputT.shape = [nBS, P]
        profile.set_shape(inputT.name, [nMinBS, P], [nOptBS, P], [nMaxBS, P])
    config.add_optimization_profile(profile)
    t0 = time()
    engine_string = builder.build_serialized_network(network, config)
    t1 = time()
    if engine_string is None:
        raise RuntimeError("Failed building %s" % (engine_file))
    else:
        logger.log(
            trt.Logger.INFO, "Succeeded building %s in %d s" % (engine_file, t1 - t0)
        )
        with open(engine_file, "wb") as f:
            f.write(engine_string)
    Builder.save_config(config_wrapper, config_file)
def build_blip2_engine(args):
    """Export the BLIP-2 vision stack (ViT + Q-Former + projector) and build its TRT engine."""
    hf_model_name = "Salesforce/blip2-" + args.model_type
    processor = Blip2Processor.from_pretrained(hf_model_name)
    dummy_image = Image.new("RGB", [10, 10])  # dummy image
    question = "Question: what is this? Answer:"
    inputs = processor(dummy_image, question, return_tensors="pt").to(
        args.device, torch.float16
    )
    image = inputs["pixel_values"]

    class Blip2VisionWrapper(torch.nn.Module):
        """Runs vision model -> Q-Former -> language projection."""

        def __init__(self, vision_model, qformer, projector, query_tokens):
            super().__init__()
            self.vision_model = vision_model
            self.qformer = qformer
            self.projector = projector
            self.query_tokens = query_tokens

        def forward(self, image):
            features = self.vision_model(image)[0]
            qformer_output = self.qformer(
                query_embeds=self.query_tokens,
                encoder_hidden_states=features,
                return_dict=True,
            )
            return self.projector(qformer_output.last_hidden_state)

    model = Blip2ForConditionalGeneration.from_pretrained(
        hf_model_name, torch_dtype=torch.float16
    )
    wrapper = Blip2VisionWrapper(
        model.vision_model, model.qformer, model.language_projection, model.query_tokens
    )
    wrapper.to(args.device)
    export_visual_wrapper_onnx(wrapper, image, args.output_dir)
    channels, height, width = image.shape[1], image.shape[2], image.shape[3]
    build_trt_engine(
        hf_model_name,
        [channels, height, width],  # [3, H, W]
        args.output_dir,
        args.max_batch_size,
    )
def build_pix2struct_engine(args):
    """Export the Pix2Struct encoder (patch embeddings + transformer + layernorm) to TRT."""
    processor = AutoProcessor.from_pretrained(args.model_path)
    dummy_image = Image.new("RGB", [10, 10])  # dummy image
    dtype = torch.float16
    inputs = processor(text="dummy", images=dummy_image, return_tensors="pt")
    image = inputs["flattened_patches"].to(args.device, dtype)
    attention_mask = inputs["attention_mask"].to(args.device, torch.int)

    class pix2structVisionWrapper(torch.nn.Module):
        """Embeddings -> transformer encoder -> final layernorm."""

        def __init__(self, encoder):
            super().__init__()
            self.encoder = encoder

        def forward(self, image, attention_mask):
            vision_x = self.encoder.embeddings(image)
            img_features = self.encoder.encoder(vision_x, attention_mask=attention_mask)
            img_features = self.encoder.layernorm(img_features[0])
            return img_features

    model = Pix2StructForConditionalGeneration.from_pretrained(
        args.model_path, torch_dtype=dtype
    )
    wrapper = pix2structVisionWrapper(model.encoder.to(args.device))
    # Input: (batch, num_patches, hidden); mask: (batch, num_patches).
    # The number of patches varies with image size but stays in a narrow
    # range, so the patch axis is kept static and an attention mask is
    # exported instead of making it a dynamic axis.
    export_visual_wrapper_onnx(
        wrapper,
        (image, attention_mask),
        args.output_dir,
        input_names=["input", "attention_mask"],
        dynamic_axes={"input": {0: "batch"}, "attention_mask": {0: "batch"}},
    )
    # NOTE(review): weights are loaded in float16 but the engine is built in
    # bfloat16 — presumably intentional for accuracy; confirm.
    build_trt_engine(
        args.model_type,
        [image.shape[1], image.shape[2]],  # Number of Patches, Hidden Dimension
        args.output_dir,
        args.max_batch_size,
        torch.bfloat16,
    )
def build_llava_engine(args):
    """Export the LLaVA vision tower + multimodal projector and build its TRT engine."""
    processor = AutoProcessor.from_pretrained(args.model_path)
    dummy_image = Image.new("RGB", [10, 10])  # dummy image
    image = processor(text="dummy", images=dummy_image, return_tensors="pt")[
        "pixel_values"
    ].to(args.device, torch.float16)

    class LlavaVisionWrapper(torch.nn.Module):
        """Selects one hidden-state layer from the tower and projects it."""

        def __init__(self, tower, projector, feature_layer):
            super().__init__()
            self.tower = tower
            self.projector = projector
            self.feature_layer = feature_layer

        def forward(self, image):
            hidden_states = self.tower(
                image, output_hidden_states=True
            ).hidden_states
            # Drop the leading CLS token before projecting.
            features = hidden_states[self.feature_layer][:, 1:]
            return self.projector(features)

    model = LlavaForConditionalGeneration.from_pretrained(
        args.model_path, torch_dtype=torch.float16
    )
    wrapper = LlavaVisionWrapper(
        model.vision_tower.to(args.device),
        model.multi_modal_projector.to(args.device),
        model.config.vision_feature_layer,
    )
    export_visual_wrapper_onnx(wrapper, image, args.output_dir)
    channels, height, width = image.shape[1], image.shape[2], image.shape[3]
    build_trt_engine(
        args.model_type,
        [channels, height, width],  # [3, H, W]
        args.output_dir,
        args.max_batch_size,
    )
def build_llava_next_engine(args):
    """Export the LLaVA-NeXT vision tower + projector and build its TRT engine.

    A real image is downloaded instead of a synthetic dummy because
    LLaVA-NeXT tiles the input into a variable number of patches; a tiny
    dummy image would not exercise the patch dimension realistically.
    """
    processor = AutoProcessor.from_pretrained(args.model_path)
    import requests

    url = "https://www.ilankelman.org/stopsigns/australia.jpg"
    raw_image = Image.open(requests.get(url, stream=True).raw)
    image = processor(text="dummy", images=raw_image, return_tensors="pt")[
        "pixel_values"
    ].to(args.device, torch.float16)

    class LlavaVisionWrapper(torch.nn.Module):
        """Selects one hidden-state layer from the tower and projects it."""

        def __init__(self, tower, projector, feature_layer):
            super().__init__()
            self.tower = tower
            self.projector = projector
            self.feature_layer = feature_layer

        def forward(self, image):
            all_hidden_states = self.tower(
                image, output_hidden_states=True
            ).hidden_states
            # Drop the leading CLS token before projecting.
            features = all_hidden_states[self.feature_layer][:, 1:]
            return self.projector(features)

    model = LlavaNextForConditionalGeneration.from_pretrained(
        args.model_path, torch_dtype=torch.float16
    )
    wrapper = LlavaVisionWrapper(
        model.vision_tower.to(args.device),
        model.multi_modal_projector.to(args.device),
        model.config.vision_feature_layer,
    )
    # Flatten (batch, num_patches, C, H, W) into (batch * num_patches, C, H, W)
    # so the exported encoder only ever sees 4-D input.
    pixel_values = image
    image_num_patches = [pixel_values.shape[1]]
    if pixel_values.dim() == 5:
        _pixel_values_list = [
            pix_val[:num_patch]
            for pix_val, num_patch in zip(pixel_values, image_num_patches)
        ]
        pixel_values = torch.cat(_pixel_values_list, dim=0)
    elif pixel_values.dim() != 4:
        raise ValueError(
            f"pixel_values of shape {pixel_values.shape}, expect to be of 4 or 5 dimensions"
        )
    # (leftover debug print removed)
    image = pixel_values
    export_visual_wrapper_onnx(wrapper, image, args.output_dir)
    build_trt_engine(
        args.model_type,
        [image.shape[1], image.shape[2], image.shape[3]],  # [3, H, W]
        args.output_dir,
        args.max_batch_size,
    )
def build_vila_engine(args):
    """Export the VILA vision tower + mm projector and build its TRT engine.

    VILA is not in the public HF model zoo, so its modeling code is imported
    from a local checkout supplied via --vila_path.
    """
    sys.path.append(args.vila_path)
    from llava.model import LlavaLlamaForCausalLM

    # BUG FIX: the checkpoint was previously loaded twice (once for the image
    # processor, once for the wrapper), doubling load time and host memory.
    model = LlavaLlamaForCausalLM.from_pretrained(
        args.model_path, torch_dtype=torch.float16
    )
    vision_tower = model.get_vision_tower()
    image_processor = vision_tower.image_processor
    raw_image = Image.new("RGB", [10, 10])  # dummy image
    image = image_processor(images=raw_image, return_tensors="pt")["pixel_values"].to(
        args.device, torch.float16
    )

    class VilaVisionWrapper(torch.nn.Module):
        """Vision tower followed by the multimodal projector."""

        def __init__(self, tower, projector):
            super().__init__()
            self.tower = tower
            self.projector = projector

        def forward(self, image):
            features = self.tower(image)
            return self.projector(features)

    wrapper = VilaVisionWrapper(
        model.get_model().get_vision_tower().to(args.device),
        model.get_model().mm_projector.to(args.device),
    )
    export_visual_wrapper_onnx(wrapper, image, args.output_dir)
    build_trt_engine(
        args.model_type,
        [image.shape[1], image.shape[2], image.shape[3]],  # [3, H, W]
        args.output_dir,
        args.max_batch_size,
    )
def build_nougat_engine(args):
    """Export the Nougat Swin encoder and build its TRT engine."""
    processor = NougatProcessor.from_pretrained(args.model_path)
    dummy_image = Image.new("RGB", [10, 10])  # dummy image
    pixel_values = processor(dummy_image, return_tensors="pt")["pixel_values"]
    image = pixel_values.to(args.device, torch.float16)

    class SwinEncoderWrapper(torch.nn.Module):
        """Thin wrapper returning only the encoder's last hidden state."""

        def __init__(self, encoder):
            super().__init__()
            self.encoder = encoder

        def forward(self, image):
            return self.encoder(image).last_hidden_state

    model = VisionEncoderDecoderModel.from_pretrained(
        args.model_path, torch_dtype=torch.float16
    )
    wrapper = SwinEncoderWrapper(model.get_encoder().to(args.device))
    export_visual_wrapper_onnx(wrapper, image, args.output_dir)
    channels, height, width = image.shape[1], image.shape[2], image.shape[3]
    build_trt_engine(
        args.model_type,
        [channels, height, width],  # [3, H, W]
        args.output_dir,
        args.max_batch_size,
    )
def build_cogvlm_engine(args):
    """Export the CogVLM ViT encoder and build its TRT engine."""
    hf_config = AutoConfig.from_pretrained(args.model_path, trust_remote_code=True)
    image_size = hf_config.vision_config["image_size"]
    dtype = hf_config.torch_dtype
    # Dummy image tensor; only the shape and dtype matter for export.
    image = torch.empty(
        1, 3, image_size, image_size, dtype=dtype, device=args.device
    )

    class CogVlmVisionWrapper(torch.nn.Module):
        """Pass-through wrapper around the CogVLM vision module."""

        def __init__(self, encoder):
            super().__init__()
            self.encoder = encoder

        def forward(self, image):
            return self.encoder(image)

    cogvlm = AutoModelForCausalLM.from_pretrained(
        args.model_path, torch_dtype=dtype, trust_remote_code=True
    )
    wrapper = CogVlmVisionWrapper(cogvlm.model.vision.to(args.device).eval())
    export_visual_wrapper_onnx(wrapper, image, args.output_dir)
    build_trt_engine(
        args.model_type,
        [image.shape[1], image.shape[2], image.shape[3]],  # [3, H, W]
        args.output_dir,
        args.max_batch_size,
        dtype,
    )
def build_fuyu_engine(args):
    """Export Fuyu's patch-embedding linear layer and build its TRT engine."""
    processor = FuyuProcessor.from_pretrained(args.model_path)
    dummy_image = Image.new("RGB", [10, 10])
    patches = processor(text="dummy", images=dummy_image, return_tensors="pt")[
        "image_patches"
    ][0]
    image = patches.to(args.device, torch.float16).unsqueeze(0)

    class FuyuEncoderWrapper(torch.nn.Module):
        """Projects flattened image patches and merges img/patch dims."""

        def __init__(self, linear):
            super().__init__()
            self.linear = linear.to(torch.float16)

        def forward(self, patches):
            return self.linear(patches).flatten(0, 1)

    model = FuyuForCausalLM.from_pretrained(args.model_path, torch_dtype=torch.float16)
    wrapper = FuyuEncoderWrapper(model.vision_embed_tokens).to(args.device)
    export_visual_wrapper_onnx(
        wrapper,
        image,
        args.output_dir,
        dynamic_axes={"input": {0: "batch", 2: "patch"}},
    )
    # Shapes are [nImgs, nImgPatches, nDims]:
    #   nImgs is always one since each query has exactly one image,
    #   nImgPatches depends on image size (patch size: 30x30),
    #   nDims is 30x30x3=2700 (patch size x color channels).
    build_trt_engine(
        args.model_type,
        [[1, 1, 2700], [1, 500, 2700], [1, 4096, 2700]],
        args.output_dir,
        args.max_batch_size,
    )
def build_neva_engine(args):
    """Build a TRT engine for the NeVA vision encoder from a NeMo checkpoint.

    Reads model_config.yaml and the weights out of the .nemo tar archive,
    reconstructs the mlp2x_gelu vision connector from the checkpoint weights,
    and exports encoder + connector as a single ONNX graph.
    """
    # extract NeMo checkpoint
    with tarfile.open(args.model_path) as tar:
        nemo_config = yaml.safe_load(tar.extractfile("./model_config.yaml"))
        try:
            # trained without TP
            mp0_weights = torch.load(
                tar.extractfile("./model_weights.ckpt"), map_location=args.device
            )
        except KeyError:
            # trained with TP
            mp0_weights = torch.load(
                tar.extractfile("./mp_rank_00/model_weights.ckpt"),
                map_location=args.device,
            )
    vision_config = nemo_config["mm_cfg"]["vision_encoder"]

    class VisionEncoderWrapper(torch.nn.Module):
        # Runs the HF vision encoder and passes the second-to-last hidden
        # state (minus the leading CLS token) through the connector MLP.
        def __init__(self, encoder, connector):
            super().__init__()
            self.encoder = encoder
            self.connector = connector

        def forward(self, images):
            vision_x = self.encoder(pixel_values=images, output_hidden_states=True)
            vision_x = vision_x.hidden_states[-2]
            vision_x = vision_x[:, 1:]
            vision_x = self.connector(vision_x)
            return vision_x

    encoder = AutoModel.from_pretrained(
        vision_config["from_pretrained"],
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
    )
    vision_encoder = encoder.vision_model
    hf_config = encoder.config
    dtype = hf_config.torch_dtype
    # connector: only the "mlp2x_gelu" adapter layout is supported here
    assert nemo_config["mm_cfg"]["mm_mlp_adapter_type"] == "mlp2x_gelu"
    vision_connector = torch.nn.Sequential(
        torch.nn.Linear(
            vision_config["hidden_size"], nemo_config["hidden_size"], bias=True
        ),
        torch.nn.GELU(),
        torch.nn.Linear(
            nemo_config["hidden_size"], nemo_config["hidden_size"], bias=True
        ),
    ).to(dtype=dtype)
    key_prefix = "model.embedding.word_embeddings.adapter_layer.mm_projector_adapter.mm_projector"
    # Load the two Linear layers (Sequential indices 0 and 2; index 1 is GELU).
    for layer in range(0, 3, 2):
        vision_connector[layer].load_state_dict(
            {
                "weight": mp0_weights[f"{key_prefix}.{layer}.weight"].to(dtype),
                "bias": mp0_weights[f"{key_prefix}.{layer}.bias"].to(dtype),
            }
        )
    # export the whole wrapper
    wrapper = VisionEncoderWrapper(vision_encoder, vision_connector).to(
        args.device, dtype
    )
    image_size = hf_config.vision_config.image_size
    dummy_image = torch.empty(
        1, 3, image_size, image_size, dtype=dtype, device=args.device
    )  # dummy image
    export_visual_wrapper_onnx(wrapper, dummy_image, args.output_dir)
    build_trt_engine(
        args.model_type,
        [3, image_size, image_size],  # [3, H, W]
        args.output_dir,
        args.max_batch_size,
        dtype,
    )
def build_kosmos_engine(args):
    """Export the Kosmos-2 vision encoder + image-to-text projection to TRT."""
    processor = AutoProcessor.from_pretrained(args.model_path)
    dummy_image = Image.new("RGB", [10, 10])  # dummy image
    image = processor(text="dummy", images=dummy_image, return_tensors="pt")[
        "pixel_values"
    ].to(args.device, torch.float16)

    class VisionEncoderWrapper(torch.nn.Module):
        """Encoder -> post layernorm -> L2 normalize -> projection."""

        def __init__(self, encoder, connector):
            super().__init__()
            self.encoder = encoder
            self.connector = connector

        def forward(self, images):
            vision_x = self.encoder(images, output_hidden_states=True)
            img_features = self.encoder.model.post_layernorm(vision_x.last_hidden_state)
            img_features = F.normalize(img_features, dim=-1)
            img_features, _ = self.connector(img_features)
            return img_features

    model = AutoModelForVision2Seq.from_pretrained(
        args.model_path, torch_dtype=torch.float16
    )
    wrapper = VisionEncoderWrapper(
        model.vision_model.to(args.device),
        model.image_to_text_projection.to(args.device),
    )
    export_visual_wrapper_onnx(wrapper, image, args.output_dir)
    channels, height, width = image.shape[1], image.shape[2], image.shape[3]
    build_trt_engine(
        args.model_type,
        [channels, height, width],  # [3, H, W]
        args.output_dir,
        args.max_batch_size,
    )
if __name__ == "__main__":
    # NOTE: `logger` is intentionally a module-level global; the build_*
    # helpers and build_trt_engine above all reference it.
    logger = trt.Logger(trt.Logger.INFO)
    args = parse_arguments()
    builder = VisionEngineBuilder(args)
    builder.build()
import copy
import functools
import json
import os
import sys
import time
from collections import defaultdict
from pathlib import Path
from typing import Optional
import safetensors
import torch
import torch.nn as nn
from tqdm import tqdm
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
from transformers.pytorch_utils import Conv1D
from ..._utils import pad_vocab_size, release_gc
from ...layers import MoeConfig
from ...logger import logger
from ...mapping import Mapping
from ...quantization import QuantAlgo
from ..convert_utils import load_calib_dataset
from ..modeling_utils import PretrainedConfig, QuantConfig, optimize_model
from .weight import load_from_hf_checkpoint, load_from_hf_safetensors
try:
from transformers import (
LlavaConfig,
LlavaForConditionalGeneration,
LlavaNextConfig,
LlavaNextForConditionalGeneration,
)
except ImportError:
pass
try:
pass
except ImportError:
pass
def generate_int8(weights, act_range, is_qkv=False, multi_query_mode=False):
    """
    This function has two purposes:
      - compute quantized weights, scaled either per-tensor or per-column
      - compute scaling factors

    Depending on the GEMM API (CUTLASS/CUBLAS) the required scaling factors differ.
    CUTLASS uses two sets of scaling factors. One for the activation X, one for the weight W.
    CUBLAS only has one (we can't do per-row scaling). So we must provide pre-multiplied scaling factor.

    Here is the list of what we need (T means per-tensor, C per-column):
      - scale_x_orig_quant puts fp activation into the quantized range (i.e. [-128, 127], for int8). Used before the GEMM. (T)
      - scale_y_quant_orig puts quantized activation into the fp range. Used if the GEMM outputs int8. (T)
      - scale_w_quant_orig puts weights from quant range to fp range (used with CUTLASS) (T, C)
      - scale_y_accum_quant puts the GEMM result (XW) from accumulation range (int32)
        to quant range (int8) (used for CUBLAS) (T, C)

    Note that we don't do anything special about row-parallel GEMM. Theoretically, we could have per-GPU scaling factors too,
    but then the model would change depending on the number of GPUs used.

    For QKV projection, the behavior is special. Even if we have a single matrix to perform QKV projection, we consider it
    as three different matrices: Q, K, and V. So per-tensor actually means one scaling factor for each Q, K and V.
    For our GEMM implementation to respect this behavior, we use per-column mode and replicate values along columns.
    """
    # compute weight scaling factors for fp->int8 and int8->fp
    if is_qkv and not multi_query_mode:
        # One per-tensor scale per Q/K/V sub-matrix, plus per-column scales.
        scale_w_orig_quant_t = (
            127.0 / act_range["w"].reshape(3, -1).max(dim=-1, keepdims=True)[0]
        )
        scale_w_orig_quant_c = 127.0 / act_range["w"].reshape(3, -1)
    elif is_qkv and multi_query_mode:
        # GQA/MQA: K and V have fewer output channels than Q, so the fused
        # range tensor is split by offsets instead of reshape(3, -1).
        hidden_dim = weights.shape[0]
        local_dim = act_range["w"].shape[0]
        kv_dim = (local_dim - hidden_dim) // 2
        scale_w_q = act_range["w"][0:hidden_dim]
        scale_w_k = act_range["w"][hidden_dim : hidden_dim + kv_dim]
        scale_w_v = act_range["w"][-kv_dim:]
        scale_w_qkv_t = torch.concat(
            [
                scale_w_q.max(dim=0, keepdim=True)[0],
                scale_w_k.max(dim=0, keepdim=True)[0],
                scale_w_v.max(dim=0, keepdim=True)[0],
            ]
        )
        scale_w_orig_quant_t = 127.0 / scale_w_qkv_t
        scale_w_orig_quant_c = 127.0 / act_range["w"]
    else:
        scale_w_orig_quant_t = 127.0 / act_range["w"].max()
        scale_w_orig_quant_c = 127.0 / act_range["w"]
    scale_w_quant_orig_t = 1.0 / scale_w_orig_quant_t
    scale_w_quant_orig_c = 1.0 / scale_w_orig_quant_c
    scale_w_orig_quant_c = scale_w_orig_quant_c.to(torch.float32)
    scale_w_orig_quant_t = scale_w_orig_quant_t.to(torch.float32)
    # compute the rest of needed scaling factors
    scale_x_orig_quant_t = 127.0 / act_range["x"].max()
    scale_y_orig_quant_t = 127.0 / act_range["y"].max()
    scale_y_quant_orig_t = act_range["y"].max() / 127.0
    scale_y_accum_quant_t = scale_y_orig_quant_t / (
        scale_x_orig_quant_t * scale_w_orig_quant_t
    )
    scale_y_accum_quant_c = scale_y_orig_quant_t / (
        scale_x_orig_quant_t * scale_w_orig_quant_c
    )
    if is_qkv and not multi_query_mode:
        # Replicate the 3 per-tensor values along columns so the returned
        # tensors have the same per-column shape as the C variants.
        scale_y_accum_quant_t = torch.broadcast_to(
            scale_y_accum_quant_t, scale_w_orig_quant_c.shape
        )
        scale_w_quant_orig_t = torch.broadcast_to(
            scale_w_quant_orig_t, scale_w_orig_quant_c.shape
        )
    if is_qkv and multi_query_mode:
        # Same replication, but each of Q/K/V is broadcast to its own width.
        scale_q_y_accum_t = torch.broadcast_to(
            scale_y_accum_quant_t[0], scale_w_q.shape
        )
        scale_k_y_accum_t = torch.broadcast_to(
            scale_y_accum_quant_t[1], scale_w_k.shape
        )
        scale_v_y_accum_t = torch.broadcast_to(
            scale_y_accum_quant_t[2], scale_w_v.shape
        )
        scale_y_accum_quant_t = torch.concat(
            [scale_q_y_accum_t, scale_k_y_accum_t, scale_v_y_accum_t]
        )
        scale_w_quant_orig_t = torch.concat(
            [
                torch.broadcast_to(scale_w_quant_orig_t[0], scale_w_q.shape),
                torch.broadcast_to(scale_w_quant_orig_t[1], scale_w_k.shape),
                torch.broadcast_to(scale_w_quant_orig_t[2], scale_w_v.shape),
            ]
        )
    # round-and-clip helper for fp -> int8
    to_i8 = lambda x: x.round().clip(-127, 127).to(torch.int8)
    if is_qkv and multi_query_mode:
        weight_int8 = to_i8(weights / scale_w_quant_orig_t)
    else:
        weight_int8 = to_i8(weights * scale_w_orig_quant_t)
    return {
        "weight.int8": weight_int8,
        "weight.int8.col": to_i8(weights * scale_w_orig_quant_c),
        "scale_x_orig_quant": scale_x_orig_quant_t.to(torch.float32),
        "scale_w_quant_orig": scale_w_quant_orig_t.to(torch.float32),
        "scale_w_quant_orig.col": scale_w_quant_orig_c.to(torch.float32),
        "scale_y_accum_quant": scale_y_accum_quant_t.to(torch.float32),
        "scale_y_accum_quant.col": scale_y_accum_quant_c.to(torch.float32),
        "scale_y_quant_orig": scale_y_quant_orig_t.to(torch.float32),
    }
@torch.no_grad()
def apply_smoothing(
    scales,
    gemm_weights,
    layernorm_weights=None,
    layernorm_bias=None,
    dtype=torch.float32,
    layernorm_1p=False,
):
    """Fold per-channel smoothing scales into weights, in place.

    Divides the preceding layernorm weight/bias by ``scales`` and multiplies
    each GEMM weight's input columns by ``scales`` so the layer's overall
    product is unchanged.

    Args:
        scales: 1-D per-input-channel smoothing factors.
        gemm_weights: a weight tensor or list of (out, in) weight tensors
            with ``in == scales.numel()``; modified in place.
        layernorm_weights / layernorm_bias: optional tensors divided in place.
        dtype: kept for interface compatibility; all updates are in place so
            tensors keep their original dtype (the original code's trailing
            ``.to(dtype)`` calls created copies that were discarded).
        layernorm_1p: when True, compensate for a "+1" layernorm
            parameterization (weight stored as w - 1).
    """
    if not isinstance(gemm_weights, list):
        gemm_weights = [gemm_weights]
    if layernorm_weights is not None:
        assert layernorm_weights.numel() == scales.numel()
        # In-place division; the discarded no-op `.to(dtype)` was removed.
        layernorm_weights.div_(scales)
    if layernorm_bias is not None:
        assert layernorm_bias.numel() == scales.numel()
        layernorm_bias.div_(scales)
    if layernorm_1p:
        layernorm_weights += (1 / scales) - 1
    for gemm in gemm_weights:
        gemm.mul_(scales.view(1, -1))
@torch.no_grad()
def smooth_gemm(
    gemm_weights,
    act_scales,
    layernorm_weights=None,
    layernorm_bias=None,
    alpha=0.5,
    weight_scales=None,
):
    """SmoothQuant a GEMM: compute per-channel scales and fold them in place.

    Computes ``act_scales**alpha / weight_scales**(1-alpha)`` per input
    channel, applies it to the weights (and the optional preceding layernorm)
    via ``apply_smoothing``, and returns the scales.
    """
    if not isinstance(gemm_weights, list):
        gemm_weights = [gemm_weights]
    orig_dtype = gemm_weights[0].dtype
    for gemm in gemm_weights:
        # gemm_weights are expected to be transposed, i.e. (out, in)
        assert gemm.shape[1] == act_scales.numel()
    if weight_scales is None:
        weight_scales = torch.cat(
            [gemm.abs().max(dim=0, keepdim=True)[0] for gemm in gemm_weights], dim=0
        )
        weight_scales = weight_scales.max(dim=0)[0]
    # BUG FIX: the original computed `weight_scales.to(float).clamp(min=1e-5)`
    # but discarded the result, so all-zero columns could divide by zero
    # below; assign it as clearly intended.
    weight_scales = weight_scales.to(float).clamp(min=1e-5)
    scales = (
        act_scales.to(gemm_weights[0].device).to(float).pow(alpha)
        / weight_scales.pow(1 - alpha)
    ).clamp(min=1e-5)
    apply_smoothing(scales, gemm_weights, layernorm_weights, layernorm_bias, orig_dtype)
    return scales
@torch.no_grad()
def smooth_gemm_fc1_gate(
    fc1_weights,
    gate_weights,
    act_scales,
    layernorm_weights=None,
    layernorm_bias=None,
    alpha=0.5,
    weight_scales=None,
):
    """SmoothQuant the fc1 and gate projections with one shared smoother.

    fc1 and gate consume the same activations, so their weights are
    concatenated per layer when computing the shared per-channel scales,
    and both weight sets are smoothed in place. Returns the scales.
    """
    gemm_weights = []
    if not isinstance(fc1_weights, list):
        fc1_weights = [fc1_weights]
    if not isinstance(gate_weights, list):
        gate_weights = [gate_weights]
    for i in range(len(fc1_weights)):
        gemm_weight = torch.cat([fc1_weights[i], gate_weights[i]], dim=0)
        gemm_weights.append(gemm_weight)
    orig_dtype = gemm_weights[0].dtype
    for gemm in gemm_weights:
        # gemm_weights are expected to be transposed, i.e. (out, in)
        assert gemm.shape[1] == act_scales.numel()
    if weight_scales is None:
        weight_scales = torch.cat(
            [gemm.abs().max(dim=0, keepdim=True)[0] for gemm in gemm_weights], dim=0
        )
        weight_scales = weight_scales.max(dim=0)[0]
    # BUG FIX: the original discarded the result of
    # `weight_scales.to(float).clamp(min=1e-5)`; assign it so all-zero
    # columns cannot produce a division by zero below.
    weight_scales = weight_scales.to(float).clamp(min=1e-5)
    scales = (
        act_scales.to(gemm_weights[0].device).to(float).pow(alpha)
        / weight_scales.pow(1 - alpha)
    ).clamp(min=1e-5)
    apply_smoothing(
        scales,
        fc1_weights + gate_weights,
        layernorm_weights,
        layernorm_bias,
        orig_dtype,
    )
    return scales
@torch.no_grad()
def smooth_llama_model(model, scales, alpha, llama_qkv_para, llama_smoother):
    """Apply SmoothQuant smoothing to every decoder layer of a LLaMA-family model.

    For each decoder layer, folds activation outliers into the weights of the
    QKV projection, attention output projection, gated MLP, and (if present)
    residual MLP. Updates ``scales`` in place, records the fused QKV weight
    (transposed) in ``llama_qkv_para`` and the o_proj/down_proj/w2 smoothers
    in ``llama_smoother``.
    """
    # Smooth the activation and weights with smoother = $\diag{s}$
    for name, module in model.named_modules():
        # Only decoder layers of the supported families are smoothed.
        if not isinstance(
            module, LlamaDecoderLayer
        ) and not module.__class__.__name__ in [
            "InternLMDecoderLayer",
            "MistralDecoderLayer",
        ]:
            continue
        # qkv_proj
        layer_name_q = name + ".self_attn.q_proj"
        layer_name_k = name + ".self_attn.k_proj"
        layer_name_v = name + ".self_attn.v_proj"
        layer_name_qkv = name + ".self_attn.qkv_proj"
        # Fuse Q/K/V so one smoother (driven by q_proj's input range,
        # which all three share) serves the fused projection.
        weight = torch.cat(
            [
                module.self_attn.q_proj.weight,
                module.self_attn.k_proj.weight,
                module.self_attn.v_proj.weight,
            ],
            dim=0,
        )
        smoother = smooth_gemm(
            weight,
            scales[layer_name_q]["x"],
            module.input_layernorm.weight,
            None,
            alpha,
        )
        scales[layer_name_qkv]["x"] = scales[layer_name_q]["x"] / smoother
        scales[layer_name_qkv]["w"] = weight.abs().max(dim=1)[0]
        scales[layer_name_qkv]["y"] = torch.cat(
            [
                scales[layer_name_q]["y"],
                scales[layer_name_k]["y"],
                scales[layer_name_v]["y"],
            ],
            dim=0,
        )
        # see transpose_weights function
        llama_qkv_para[layer_name_qkv] = weight.transpose(0, 1)
        # =================================================================
        # Attention output projection: no preceding layernorm to fold into,
        # so the smoother must be kept and applied at runtime.
        layer_name = name + ".self_attn.o_proj"
        smoother = smooth_gemm(
            module.self_attn.o_proj.weight, scales[layer_name]["x"], None, None, alpha
        )
        llama_smoother[layer_name] = smoother.float()
        scales[layer_name]["x"] = scales[layer_name]["x"] / smoother
        scales[layer_name]["w"] = module.self_attn.o_proj.weight.abs().max(dim=1)[0]
        # ==================================================================
        # Gated MLP: gate_proj and up_proj share inputs, so they share one
        # smoother folded into post_attention_layernorm.
        fc1_layer_name = name + ".mlp.gate_proj"
        gate_layer_name = name + ".mlp.up_proj"
        smoother = smooth_gemm_fc1_gate(
            module.mlp.gate_proj.weight,
            module.mlp.up_proj.weight,
            scales[fc1_layer_name]["x"],
            module.post_attention_layernorm.weight,
            None,
            alpha,
        )
        scales[fc1_layer_name]["x"] = scales[fc1_layer_name]["x"] / smoother
        scales[fc1_layer_name]["w"] = module.mlp.gate_proj.weight.abs().max(dim=1)[0]
        scales[gate_layer_name]["x"] = scales[gate_layer_name]["x"] / smoother
        scales[gate_layer_name]["w"] = module.mlp.up_proj.weight.abs().max(dim=1)[0]
        # ==================================================================
        layer_name = name + ".mlp.down_proj"
        smoother = smooth_gemm(
            module.mlp.down_proj.weight, scales[layer_name]["x"], None, None, alpha
        )
        llama_smoother[layer_name] = smoother.float()
        scales[layer_name]["x"] = scales[layer_name]["x"] / smoother
        scales[layer_name]["w"] = module.mlp.down_proj.weight.abs().max(dim=1)[0]
        # ==================================================================
        # Optional residual MLP branch (e.g. some hybrid architectures).
        if hasattr(module, "residual_mlp"):
            fc1_layer_name = name + ".residual_mlp.w1"
            gate_layer_name = name + ".residual_mlp.w3"
            smoother = smooth_gemm_fc1_gate(
                module.residual_mlp.w1.weight,
                module.residual_mlp.w3.weight,
                scales[fc1_layer_name]["x"],
                module.residual_layernorm.weight,
                None,
                alpha,
            )
            scales[fc1_layer_name]["x"] = scales[fc1_layer_name]["x"] / smoother
            scales[fc1_layer_name]["w"] = module.residual_mlp.w1.weight.abs().max(
                dim=1
            )[0]
            scales[gate_layer_name]["x"] = scales[gate_layer_name]["x"] / smoother
            scales[gate_layer_name]["w"] = module.residual_mlp.w3.weight.abs().max(
                dim=1
            )[0]
            # ==================================================================
            layer_name = name + ".residual_mlp.w2"
            smoother = smooth_gemm(
                module.residual_mlp.w2.weight,
                scales[layer_name]["x"],
                None,
                None,
                alpha,
            )
            llama_smoother[layer_name] = smoother.float()
            scales[layer_name]["x"] = scales[layer_name]["x"] / smoother
            scales[layer_name]["w"] = module.residual_mlp.w2.weight.abs().max(dim=1)[0]
@torch.no_grad()
def capture_activation_range(model, tokenizer, dataset, num_samples=512, seq_len=512):
    """Run calibration samples through the model and record activation ranges.

    Registers forward hooks on every Linear/Conv1D module, feeds
    ``num_samples`` text samples, and returns a dict mapping module name to
    {"x": per-channel input max, "y": per-channel output max,
    "w": per-output-channel weight max}.
    """
    model.eval()
    device = next(model.parameters()).device
    act_scales = defaultdict(lambda: {"x": None, "y": None, "w": None})
    # Padding is required for the batched tokenization below.
    tokenizer.pad_token = tokenizer.eos_token

    def stat_tensor(name, tensor, act_scales, key):
        # Track the running per-channel max of |activation|.
        hidden_dim = tensor.shape[-1]
        tensor = tensor.view(-1, hidden_dim).abs().detach()
        comming_max = torch.max(tensor, dim=0)[0].float()
        if act_scales[name][key] is None:
            act_scales[name][key] = comming_max
        else:
            act_scales[name][key] = torch.max(act_scales[name][key], comming_max)

    def stat_input_hook(m, x, y, name):
        # Forward hook: record input ("x"), output ("y"), and weight ("w", once).
        if isinstance(x, tuple):
            x = x[0]
        stat_tensor(name, x, act_scales, "x")
        stat_tensor(name, y, act_scales, "y")
        if act_scales[name]["w"] is None:
            act_scales[name]["w"] = m.weight.abs().clip(1e-8, None).max(dim=1)[0]

    hooks = []
    for name, m in model.named_modules():
        if isinstance(m, nn.Linear) or isinstance(m, Conv1D):
            hooks.append(
                m.register_forward_hook(functools.partial(stat_input_hook, name=name))
            )
    for i in tqdm(range(num_samples), desc="calibrating model"):
        datapoint = dataset[i : i + 1]
        line = copy.copy(datapoint)
        # Mimic the summarization prompt format used for calibration.
        line[0] = line[0] + " TL;DR: "
        line[0] = line[0].strip()
        line[0] = line[0].replace(" n't", "n't")
        input_ids = tokenizer(
            line, return_tensors="pt", max_length=seq_len, padding=True, truncation=True
        ).input_ids.to(device)
        model(input_ids)
    for h in hooks:
        h.remove()
    return act_scales
def split(v, tp_size, idx, dim=0):
    """Return rank ``idx``'s shard of ``v`` when split ``tp_size`` ways.

    1-D tensors are always chunked along dim 0 and made contiguous;
    higher-rank tensors are chunked along ``dim`` (view, not contiguous).
    With ``tp_size == 1`` the input is returned untouched.
    """
    if tp_size == 1:
        return v
    if v.ndim == 1:
        return torch.chunk(v, tp_size)[idx].contiguous()
    return torch.chunk(v, tp_size, dim=dim)[idx]
def split_qkv_tp(v, n_head, n_hidden, tensor_parallel, rank):
    """
    Splits the fused QKV weight matrix according to tensor parallelism.

    ``v`` has shape ``(3 * n_hidden, n_hidden)``; each of Q, K and V is
    sharded along its output dimension and the shards are re-fused to shape
    ``(3 * n_hidden // tensor_parallel, n_hidden)``.  ``n_head`` is unused
    but kept for signature compatibility with the bias variant.
    """
    qkv = v.reshape(3, n_hidden, n_hidden)
    if tensor_parallel > 1:
        qkv = torch.chunk(qkv, tensor_parallel, dim=1)[rank]
    return qkv.reshape(3 * (n_hidden // tensor_parallel), n_hidden)
def split_qkv_bias_tp(v, n_head, n_hidden, tensor_parallel, rank):
    """
    Splits the fused QKV bias according to tensor parallelism.

    ``v`` is a flat vector of length ``3 * n_hidden``; each of the Q, K and
    V bias segments is sharded and the shards are re-flattened to length
    ``3 * n_hidden // tensor_parallel``.  ``n_head`` is unused but kept for
    signature compatibility.
    """
    bias = v.reshape(3, n_hidden)
    if tensor_parallel > 1:
        bias = torch.chunk(bias, tensor_parallel, dim=1)[rank]
    return bias.reshape(3 * (n_hidden // tensor_parallel))
def split_matrix_tp(v, tensor_parallel, rank, dim):
    """Shard matrix ``v`` ``tensor_parallel`` ways along ``dim`` for ``rank``.

    Mirrors ``split``: 1-D tensors are chunked along dim 0 and made
    contiguous, higher-rank tensors along ``dim``; ``tensor_parallel == 1``
    is a no-op.
    """
    if tensor_parallel == 1:
        return v
    if len(v.shape) == 1:
        return torch.chunk(v, tensor_parallel)[rank].contiguous()
    return torch.chunk(v, tensor_parallel, dim=dim)[rank]
def get_weight(config, prefix, dtype):
    """Fetch ``<prefix>.weight`` from ``config``, casting it in place to ``dtype``.

    The cast mutates the stored tensor (via ``.data``) so repeated lookups
    stay cheap; a detached view is returned.
    """
    param = config[prefix + ".weight"]
    if param.dtype != dtype:
        param.data = param.to(dtype)
    return param.detach()
def get_bias(config, prefix, dtype):
    """Fetch ``<prefix>.bias`` from ``config``, casting it in place to ``dtype``.

    Same in-place cast semantics as ``get_weight``; returns a detached view.
    """
    param = config[prefix + ".bias"]
    if param.dtype != dtype:
        param.data = param.to(dtype)
    return param.detach()
def get_weight_and_bias(config, prefix, dtype):
    """Return ``(weight, bias)`` for ``prefix``, each cast in place to ``dtype``."""
    tensors = []
    for suffix in (".weight", ".bias"):
        param = config[prefix + suffix]
        if param.dtype != dtype:
            param.data = param.to(dtype)
        tensors.append(param.detach())
    return tuple(tensors)
def get_tllm_linear_weight(
    weight,
    prefix,
    bias=None,
    use_weight_only=False,
    plugin_weight_only_quant_type=torch.int8,
    dtype="float32",
    use_gemm_woq_plugin=True,
    postfix="weight",
    quant_scale_name=None,
):
    """Package one linear layer's tensors as a ``{name: tensor}`` dict.

    Without ``use_weight_only`` the weight (and optional bias) are passed
    through unchanged.  With it, the weight is transposed so the reduction
    axis is last, quantized via the ``trtllm`` symmetric weight-only op, and
    the per-channel scales are stored under ``quant_scale_name`` (or
    ``<prefix>per_channel_scale`` by default).
    """
    results = {}
    if not use_weight_only:
        results[prefix + postfix] = weight
    else:
        if weight.dim() > 2:
            transposed = weight.transpose(1, 2).contiguous()
        else:
            transposed = weight.t().contiguous()
        quantized, weight_scales = (
            torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix(
                transposed.cpu(), plugin_weight_only_quant_type
            )
        )
        # Without the GEMM plugin the unquantized (cast) weight is stored.
        results[prefix + postfix] = (
            quantized if use_gemm_woq_plugin else transposed.to(dtype)
        )
        scale_key = (
            quant_scale_name
            if quant_scale_name is not None
            else prefix + "per_channel_scale"
        )
        results[scale_key] = weight_scales
    if bias is not None:
        results[prefix + "bias"] = bias
    return results
def dup_kv_weight(v, num_head, tp_size):
    """Replicate each KV head so every TP rank gets at least one head.

    ``v`` has shape ``(num_head * head_size, hidden)``.  Each head is
    repeated ``tp_size // num_head`` times consecutively, giving shape
    ``(tp_size * head_size, hidden)``.  Requires ``tp_size`` to be a
    multiple of ``num_head``.
    """
    assert tp_size % num_head == 0
    reps = tp_size // num_head
    head_size = v.shape[0] // num_head
    heads = v.reshape(num_head, head_size, -1)
    # Broadcast a new repetition axis, then flatten back to 2-D.
    tiled = heads.unsqueeze(1).expand(num_head, reps, head_size, v.shape[1])
    return tiled.reshape(num_head * reps * head_size, -1).clone().detach()
def get_tllm_linear_sq_weight(
    vals,
    prefix,
    shape,
    tensor_parallel,
    is_qkv=False,
    per_token=False,
    per_channel=False,
    last_prefix=None,
    bias=None,
    smoother_value=None,
    smoother_shape=None,
    rank=0,
    cat_dim=0,
    multi_query_mode=False,
):
    """Assemble the TRT-LLM tensor dict for one SmoothQuant (W8A8) linear.

    ``vals`` is the dict of int8 weights and scale tensors produced by
    ``generate_int8`` (keys like ``weight.int8[.col]``,
    ``scale_w_quant_orig[.col]``, ``scale_y_accum_quant[.col]``, etc.).
    ``prefix`` is the TRT-LLM parameter prefix and ``last_prefix`` names the
    preceding layer's activation-scale entry.  ``per_token``/``per_channel``
    select which scale layout is emitted; ``multi_query_mode`` changes how
    the fused QKV tensor is sharded (unequal Q vs. K/V widths).
    NOTE(review): all outputs are moved to CUDA here, so a visible GPU is
    assumed - confirm before running on CPU-only hosts.
    Returns a dict of named tensors to merge into the checkpoint.
    """
    results = {}
    def multi_query_split(data, local_dim, head_size, tp_size, cur_rank):
        # Split Q (width local_dim) and K/V (width head_size each) separately,
        # then re-fuse this rank's shards so GQA/MQA layouts shard correctly.
        q, k, v = torch.split(data, [local_dim, head_size, head_size], dim=-1)
        q_split = torch.chunk(q, tp_size, dim=-1)
        k_split = torch.chunk(k, tp_size, dim=-1)
        v_split = torch.chunk(v, tp_size, dim=-1)
        return [
            torch.concat((q_split[ii], k_split[ii], v_split[ii]), axis=-1)
            for ii in range(tp_size)
        ][cur_rank]
    # Per-channel (or QKV) scales keep the caller-provided shape; otherwise
    # scales are scalars stored as [1, 1].
    col_shape = shape if (is_qkv or per_channel) else [1, 1]
    if per_token:
        # Per-token dynamic activation quantization: only weight scales
        # (scale_w_quant_orig*) are stored; activation scales are computed
        # at runtime.
        if per_channel:
            original_weights = torch.Tensor(vals["weight.int8.col"]).cuda()
        else:
            original_weights = torch.Tensor(vals["weight.int8"]).cuda()
        local_dim = original_weights.shape[0]
        head_size = (original_weights.shape[1] - local_dim) // 2
        if multi_query_mode:
            cur_weights = multi_query_split(
                original_weights, local_dim, head_size, tensor_parallel, rank
            )
        else:
            cur_weights = torch.chunk(original_weights, tensor_parallel, dim=cat_dim)[
                rank
            ]
        if is_qkv:
            hidden_dim = cur_weights.shape[0]
            cur_weights = cur_weights.reshape(hidden_dim, -1)
        results[prefix + "weight"] = cur_weights.t().contiguous()
        if smoother_value is None:
            # No smoother: the previous layer's scale entry becomes a no-op.
            results[last_prefix] = torch.Tensor([1.0]).to(torch.float32).cuda()
        if per_channel:
            cur_per_channel_value = vals["scale_w_quant_orig.col"]
            if smoother_value is None:
                if multi_query_mode:
                    cur_per_channel_value = multi_query_split(
                        vals["scale_w_quant_orig.col"],
                        local_dim,
                        head_size,
                        tensor_parallel,
                        rank,
                    )
                else:
                    cur_per_channel_value = torch.chunk(
                        vals["scale_w_quant_orig.col"], tensor_parallel, dim=cat_dim
                    )[rank]
        else:
            cur_per_channel_value = vals["scale_w_quant_orig"]
            if is_qkv:
                if multi_query_mode:
                    cur_per_channel_value = multi_query_split(
                        vals["scale_w_quant_orig"],
                        local_dim,
                        head_size,
                        tensor_parallel,
                        rank,
                    )
                else:
                    cur_per_channel_value = torch.chunk(
                        vals["scale_w_quant_orig"], tensor_parallel, dim=cat_dim
                    )[rank]
        results[prefix + "per_channel_scale"] = cur_per_channel_value.reshape(
            col_shape
        ).contiguous()
    else:
        # Static (per-tensor) activation quantization: store accumulation
        # scales, the output activation scale, and the input scale under
        # ``last_prefix``.
        if per_channel:
            original_weights = torch.Tensor(vals["weight.int8.col"]).cuda()
        else:
            original_weights = torch.Tensor(vals["weight.int8"]).cuda()
        local_dim = original_weights.shape[0]
        head_size = (original_weights.shape[1] - local_dim) // 2
        if multi_query_mode:
            cur_weights = multi_query_split(
                original_weights, local_dim, head_size, tensor_parallel, rank
            )
        else:
            cur_weights = torch.chunk(original_weights, tensor_parallel, dim=cat_dim)[
                rank
            ]
        if is_qkv:
            hidden_dim = cur_weights.shape[0]
            cur_weights = cur_weights.reshape(hidden_dim, -1)
        results[prefix + "weight"] = cur_weights.t().contiguous()
        if per_channel:
            cur_per_channel_value = vals["scale_y_accum_quant.col"]
            if smoother_value is None:
                if multi_query_mode:
                    cur_per_channel_value = multi_query_split(
                        vals["scale_y_accum_quant.col"],
                        local_dim,
                        head_size,
                        tensor_parallel,
                        rank,
                    )
                else:
                    cur_per_channel_value = torch.chunk(
                        vals["scale_y_accum_quant.col"], tensor_parallel, dim=cat_dim
                    )[rank]
        else:
            cur_per_channel_value = vals["scale_y_accum_quant"]
            # QKV is always per_channel
            if is_qkv:
                if multi_query_mode:
                    cur_per_channel_value = multi_query_split(
                        vals["scale_y_accum_quant"],
                        local_dim,
                        head_size,
                        tensor_parallel,
                        rank,
                    )
                else:
                    cur_per_channel_value = torch.chunk(
                        vals["scale_y_accum_quant"], tensor_parallel, dim=cat_dim
                    )[rank]
        results[prefix + "per_channel_scale"] = (
            torch.Tensor(cur_per_channel_value)
            .to(torch.float32)
            .reshape(col_shape)
            .contiguous()
            .cuda()
        )
        results[prefix + "act_scale"] = (
            torch.Tensor([[vals["scale_y_quant_orig"]]])
            .to(torch.float32)
            .contiguous()
            .cuda()
        )
        results[last_prefix] = (
            torch.Tensor([vals["scale_x_orig_quant"]])
            .to(torch.float32)
            .contiguous()
            .cuda()
        )
    if smoother_value is not None:
        # SmoothQuant smoother is sharded the same way as the weight's
        # reduction dimension.
        cur_smoother_value = torch.chunk(smoother_value, tensor_parallel, dim=cat_dim)[
            rank
        ]
        results[prefix + "smoother"] = (
            cur_smoother_value.reshape(smoother_shape).contiguous().to(torch.float32)
        )
    if bias is not None:
        results[prefix + "bias"] = bias
    return results
def convert_hf_llama(
    hf_model,
    mapping,
    vocab_size=32000,
    dtype="float32",
    use_parallel_embedding=False,
    sharding_dim=0,
    use_weight_only=False,
    share_embedding_table=False,
    residual_mlp=False,
    use_gemm_woq_plugin=False,
    plugin_weight_only_quant_type=torch.int8,
    use_smooth_quant=False,
    per_channel=False,
    per_token=False,
    int8_kv_cache=False,
    act_range=[],
    qkv_para=[],
    smoother=[],
    moe_config=None,
):
    """Convert a HuggingFace LLaMA-family model into a TRT-LLM weight dict.

    Walks every transformer layer assigned to this pipeline rank, shards
    each tensor for ``mapping.tp_size``-way tensor parallelism, and applies
    the requested quantization (weight-only or SmoothQuant, plus optional
    int8 KV cache).  Covers the plain gated-MLP path, the MoE path
    (``moe_config``) and the residual-MLP variant used inside the MoE path.

    NOTE(review): ``act_range``/``qkv_para``/``smoother`` default to lists
    but are indexed with string keys below, so callers must pass dicts
    whenever smooth quant or int8 KV cache is enabled - confirm call sites.

    Returns:
        dict mapping TRT-LLM parameter names to tensors.
    """
    weights = {}
    tik = time.time()
    tensor_parallel = mapping.tp_size
    model_params = dict(hf_model.named_parameters())
    dtype = getattr(torch, dtype)
    num_attention_heads = hf_model.config.num_attention_heads
    hidden_size = hf_model.config.hidden_size
    head_size = hidden_size // num_attention_heads
    intermediate_size = hf_model.config.intermediate_size
    num_key_value_heads = getattr(
        hf_model.config, "num_key_value_heads", num_attention_heads
    )
    # MHA when KV head count equals attention head count; otherwise GQA/MQA.
    mha_mode = num_key_value_heads == num_attention_heads
    layers_range = mapping.pp_layers(hf_model.config.num_hidden_layers)
    def convert_layer(l):
        # Convert one decoder layer; results go into the enclosing `weights`.
        prefix = f"model.layers.{l}."
        tllm_prex = f"transformer.layers.{l - layers_range[0]}."
        # ---- fused QKV projection ----
        q_weight = get_weight(model_params, prefix + "self_attn.q_proj", dtype)
        k_weight = get_weight(model_params, prefix + "self_attn.k_proj", dtype)
        v_weight = get_weight(model_params, prefix + "self_attn.v_proj", dtype)
        if not mha_mode:
            if num_key_value_heads < tensor_parallel:
                # duplicate the KV heads up to tensor_parallel
                k_weight = dup_kv_weight(k_weight, num_key_value_heads, tensor_parallel)
                v_weight = dup_kv_weight(v_weight, num_key_value_heads, tensor_parallel)
            assert (k_weight.shape[0] % (mapping.tp_size * head_size)) == 0
            assert (v_weight.shape[0] % (mapping.tp_size * head_size)) == 0
            wq = split(q_weight, mapping.tp_size, mapping.tp_rank)
            wk = split(k_weight, mapping.tp_size, mapping.tp_rank)
            wv = split(v_weight, mapping.tp_size, mapping.tp_rank)
            split_v = torch.concat((wq, wk, wv))
        else:
            qkv_weight = torch.cat([q_weight, k_weight, v_weight], dim=0)
            split_v = split_qkv_tp(
                qkv_weight,
                num_attention_heads,
                hidden_size,
                tensor_parallel,
                mapping.tp_rank,
            )
        if prefix + "self_attn.q_proj.bias" in model_params:
            # only used in Internlm 7B models
            q_bias = get_bias(model_params, prefix + "self_attn.q_proj", dtype)
            k_bias = get_bias(model_params, prefix + "self_attn.k_proj", dtype)
            v_bias = get_bias(model_params, prefix + "self_attn.v_proj", dtype)
            qkv_bias = torch.cat((q_bias, k_bias, v_bias))
            split_bias_v = split_qkv_bias_tp(
                qkv_bias,
                num_attention_heads,
                hidden_size,
                tensor_parallel,
                mapping.tp_rank,
            )
        else:
            split_bias_v = None
        if use_smooth_quant:
            # SmoothQuant path uses the pre-fused QKV captured during
            # calibration (qkv_para) rather than the per-projection weights.
            qkv_weight = qkv_para[prefix + "self_attn.qkv_proj"]
            qkv_out_dim = qkv_weight.shape[1]
            if not mha_mode:
                local_dim = qkv_weight.shape[0]
                kv_hidden_size = (qkv_weight.shape[-1] - local_dim) // 2
                qkv_weight = qkv_weight.reshape(
                    local_dim, local_dim + 2 * kv_hidden_size
                )
            else:
                qkv_weight = qkv_weight.reshape(hidden_size, 3, hidden_size)
            int8_weights = generate_int8(
                qkv_weight,
                act_range.get(prefix + "self_attn.qkv_proj"),
                is_qkv=True,
                multi_query_mode=bool(not mha_mode),
            )
            weights.update(
                get_tllm_linear_sq_weight(
                    int8_weights,
                    tllm_prex + "attention.qkv.",
                    [1, qkv_out_dim // tensor_parallel],
                    tensor_parallel,
                    is_qkv=True,
                    bias=split_bias_v,
                    per_token=per_token,
                    per_channel=per_channel,
                    last_prefix=tllm_prex + "input_layernorm.scale_to_int",
                    smoother_value=None,
                    smoother_shape=None,
                    rank=mapping.tp_rank,
                    cat_dim=-1,
                    multi_query_mode=bool(not mha_mode),
                )
            )
        else:
            weights.update(
                get_tllm_linear_weight(
                    split_v,
                    tllm_prex + "attention.qkv.",
                    split_bias_v,
                    use_weight_only,
                    plugin_weight_only_quant_type,
                    dtype,
                    use_gemm_woq_plugin,
                )
            )
        if int8_kv_cache:
            # KV-cache scale: max observed QKV output over 127 (symmetric int8).
            qkv_y = torch.cat(
                [
                    act_range.get(prefix + "self_attn.q_proj")["y"],
                    act_range.get(prefix + "self_attn.k_proj")["y"],
                    act_range.get(prefix + "self_attn.v_proj")["y"],
                ],
                dim=0,
            )
            int8_kv_scales = qkv_y.max() / 127.0
            kv_cache_weights = {}
            kv_cache_weights[tllm_prex + "attention.kv_cache_scaling_factor"] = (
                int8_kv_scales.reshape([1])
            )
            weights.update(kv_cache_weights)
        # ---- attention output projection ----
        attn_dense_weight = get_weight(model_params, prefix + "self_attn.o_proj", dtype)
        split_v = split_matrix_tp(
            attn_dense_weight, tensor_parallel, mapping.tp_rank, dim=1
        )
        if prefix + "self_attn.o_proj.bias" in model_params:
            attn_dense_bias = get_bias(model_params, prefix + "self_attn.o_proj", dtype)
        else:
            attn_dense_bias = None
        if use_smooth_quant:
            attn_dense_weight = attn_dense_weight.t()
            int8_weights = generate_int8(
                attn_dense_weight, act_range.get(prefix + "self_attn.o_proj")
            )
            weights.update(
                get_tllm_linear_sq_weight(
                    int8_weights,
                    tllm_prex + "attention.dense.",
                    [1, hidden_size],
                    tensor_parallel,
                    is_qkv=False,
                    bias=attn_dense_bias,
                    per_token=per_token,
                    per_channel=per_channel,
                    last_prefix=tllm_prex + "attention.quantization_scaling_factor",
                    smoother_value=smoother[(prefix + "self_attn.o_proj")],
                    smoother_shape=[1, hidden_size // tensor_parallel],
                    rank=mapping.tp_rank,
                    cat_dim=0,
                )
            )
        else:
            weights.update(
                get_tllm_linear_weight(
                    split_v,
                    tllm_prex + "attention.dense.",
                    attn_dense_bias,
                    use_weight_only,
                    plugin_weight_only_quant_type,
                    dtype,
                    use_gemm_woq_plugin,
                )
            )
        # ---- feed-forward: MoE path ----
        if moe_config and moe_config.has_moe():
            rank_experts = list(range(moe_config.num_experts))
            if moe_config.tp_mode == moe_config.ParallelismMode.EXPERT_PARALLEL:
                rank_experts = mapping.ep_experts(moe_config.num_experts)
            # Stack per-expert tensors into one (experts, out, in) tensor.
            for suffix in ["w1", "w2", "w3"]:
                model_params[
                    f"model.layers.{l}.block_sparse_moe.experts.{suffix}.weight"
                ] = torch.stack(
                    [
                        model_params[
                            f"model.layers.{l}.block_sparse_moe.experts.{expert}.{suffix}.weight"
                        ].detach()
                        for expert in rank_experts
                    ]
                )
            w3 = model_params[f"model.layers.{l}.block_sparse_moe.experts.w3.weight"]
            w2 = model_params[f"model.layers.{l}.block_sparse_moe.experts.w2.weight"]
            w1 = model_params[f"model.layers.{l}.block_sparse_moe.experts.w1.weight"]
            if moe_config.tp_mode == moe_config.ParallelismMode.TENSOR_PARALLEL:
                w3 = split(w3, mapping.tp_size, mapping.tp_rank, dim=1)
                w2 = split(w2, mapping.tp_size, mapping.tp_rank, dim=2)
                w1 = split(w1, mapping.tp_size, mapping.tp_rank, dim=1)
            model_params[f"model.layers.{l}.block_sparse_moe.experts.w3w1.weight"] = (
                torch.concat([w3, w1], dim=-2)
            )
            model_params[f"model.layers.{l}.block_sparse_moe.experts.w2.weight"] = w2
            ## block_sparse_moe.experts.w2.weight
            moe_experts_w2_weights = get_weight(
                model_params, prefix + "block_sparse_moe.experts.w2", dtype
            )
            weights.update(
                get_tllm_linear_weight(
                    moe_experts_w2_weights,
                    tllm_prex + "mlp.proj.",
                    None,
                    use_weight_only,
                    plugin_weight_only_quant_type,
                    dtype,
                    use_gemm_woq_plugin,
                )
            )
            ##block_sparse_moe.experts.w3w1.weight
            moe_experts_w3w1_weights = get_weight(
                model_params, prefix + "block_sparse_moe.experts.w3w1", dtype
            )
            weights.update(
                get_tllm_linear_weight(
                    moe_experts_w3w1_weights,
                    tllm_prex + "mlp.fc.",
                    None,
                    use_weight_only,
                    plugin_weight_only_quant_type,
                    dtype,
                    use_gemm_woq_plugin,
                )
            )
            # Residual (dense) MLP that runs in parallel with the MoE block.
            if residual_mlp:
                residual_mlp_gate_weights = get_weight(
                    model_params, prefix + "residual_mlp.w3", dtype
                )
                if use_smooth_quant:
                    residual_mlp_gate_weights = residual_mlp_gate_weights.t()
                    int8_weights = generate_int8(
                        residual_mlp_gate_weights,
                        act_range.get(prefix + "residual_mlp.w3"),
                    )
                    weights.update(
                        get_tllm_linear_sq_weight(
                            int8_weights,
                            tllm_prex + "residual_mlp.gate.",
                            [1, hidden_size // tensor_parallel],
                            tensor_parallel,
                            is_qkv=False,
                            per_token=per_token,
                            per_channel=per_channel,
                            last_prefix=tllm_prex + "post_layernorm.scale_to_int",
                            smoother_value=None,
                            smoother_shape=None,
                            rank=mapping.tp_rank,
                            cat_dim=-1,
                        )
                    )
                else:
                    split_v = split_matrix_tp(
                        residual_mlp_gate_weights,
                        tensor_parallel,
                        mapping.tp_rank,
                        dim=0,
                    )
                    weights.update(
                        get_tllm_linear_weight(
                            split_v,
                            tllm_prex + "residual_mlp.gate.",
                            None,
                            use_weight_only,
                            plugin_weight_only_quant_type,
                            dtype,
                            use_gemm_woq_plugin,
                        )
                    )
                residual_mlp_fc_weight = get_weight(
                    model_params, prefix + "residual_mlp.w1", dtype
                )
                if use_smooth_quant:
                    residual_mlp_fc_weight = residual_mlp_fc_weight.t()  # verified
                    int8_weights = generate_int8(
                        residual_mlp_fc_weight,
                        act_range.get(prefix + "residual_mlp.w1"),
                    )
                    weights.update(
                        get_tllm_linear_sq_weight(
                            int8_weights,
                            tllm_prex + "residual_mlp.fc.",
                            [1, hidden_size // tensor_parallel],
                            tensor_parallel,
                            is_qkv=False,
                            per_token=per_token,
                            per_channel=per_channel,
                            last_prefix=tllm_prex + "post_layernorm.scale_to_int",
                            smoother_value=None,
                            smoother_shape=None,
                            rank=mapping.tp_rank,
                            cat_dim=-1,
                        )
                    )
                else:
                    split_v = split_matrix_tp(
                        residual_mlp_fc_weight, tensor_parallel, mapping.tp_rank, dim=0
                    )
                    weights.update(
                        get_tllm_linear_weight(
                            split_v,
                            tllm_prex + "residual_mlp.fc.",
                            None,
                            use_weight_only,
                            plugin_weight_only_quant_type,
                            dtype,
                            use_gemm_woq_plugin,
                        )
                    )
                residual_mlp_proj_weight = get_weight(
                    model_params, prefix + "residual_mlp.w2", dtype
                )
                if use_smooth_quant:
                    residual_mlp_proj_weight = residual_mlp_proj_weight.t()
                    int8_weights = generate_int8(
                        residual_mlp_proj_weight,
                        act_range.get(prefix + "residual_mlp.w2"),
                    )
                    weights.update(
                        get_tllm_linear_sq_weight(
                            int8_weights,
                            tllm_prex + "residual_mlp.proj.",
                            [1, hidden_size],
                            tensor_parallel,
                            is_qkv=False,
                            per_token=per_token,
                            per_channel=per_channel,
                            last_prefix=tllm_prex
                            + "residual_mlp.quantization_scaling_factor",
                            smoother_value=smoother[prefix + "residual_mlp.w2"],
                            smoother_shape=[1, hidden_size // tensor_parallel],
                            rank=mapping.tp_rank,
                            cat_dim=0,
                        )
                    )
                else:
                    split_v = split_matrix_tp(
                        residual_mlp_proj_weight,
                        tensor_parallel,
                        mapping.tp_rank,
                        dim=1,
                    )
                    weights.update(
                        get_tllm_linear_weight(
                            split_v,
                            tllm_prex + "residual_mlp.proj.",
                            None,
                            use_weight_only,
                            plugin_weight_only_quant_type,
                            dtype,
                            use_gemm_woq_plugin,
                        )
                    )
            moe_experts_gate_weights = get_weight(
                model_params, prefix + "block_sparse_moe.gate", torch.float32
            )
            weights.update(
                get_tllm_linear_weight(
                    moe_experts_gate_weights,
                    tllm_prex + "mlp.router.",
                    None,
                    False,  # Router should never be quantized
                    plugin_weight_only_quant_type,
                    dtype,
                    use_gemm_woq_plugin,
                )
            )
        # ---- feed-forward: plain gated MLP path ----
        else:
            mlp_gate_weight = get_weight(model_params, prefix + "mlp.up_proj", dtype)
            split_v = split_matrix_tp(
                mlp_gate_weight, tensor_parallel, mapping.tp_rank, dim=0
            )
            if use_smooth_quant:
                mlp_gate_weight = mlp_gate_weight.t()
                int8_weights = generate_int8(
                    mlp_gate_weight, act_range.get(prefix + "mlp.up_proj")
                )
                weights.update(
                    get_tllm_linear_sq_weight(
                        int8_weights,
                        tllm_prex + "mlp.gate.",
                        [1, intermediate_size // tensor_parallel],
                        tensor_parallel,
                        is_qkv=False,
                        per_token=per_token,
                        per_channel=per_channel,
                        last_prefix=tllm_prex + "post_layernorm.scale_to_int",
                        smoother_value=None,
                        smoother_shape=None,
                        rank=mapping.tp_rank,
                        cat_dim=-1,
                    )
                )
            else:
                weights.update(
                    get_tllm_linear_weight(
                        split_v,
                        tllm_prex + "mlp.gate.",
                        None,
                        use_weight_only,
                        plugin_weight_only_quant_type,
                        dtype,
                        use_gemm_woq_plugin,
                    )
                )
            mlp_fc_weight = get_weight(model_params, prefix + "mlp.gate_proj", dtype)
            split_v = split_matrix_tp(
                mlp_fc_weight, tensor_parallel, mapping.tp_rank, dim=0
            )
            if use_smooth_quant:
                mlp_fc_weight = mlp_fc_weight.t()  # verified
                int8_weights = generate_int8(
                    mlp_fc_weight, act_range.get(prefix + "mlp.gate_proj")
                )
                weights.update(
                    get_tllm_linear_sq_weight(
                        int8_weights,
                        tllm_prex + "mlp.fc.",
                        [1, intermediate_size // tensor_parallel],
                        tensor_parallel,
                        is_qkv=False,
                        per_token=per_token,
                        per_channel=per_channel,
                        last_prefix=tllm_prex + "post_layernorm.scale_to_int",
                        smoother_value=None,
                        smoother_shape=None,
                        rank=mapping.tp_rank,
                        cat_dim=-1,
                    )
                )
            else:
                weights.update(
                    get_tllm_linear_weight(
                        split_v,
                        tllm_prex + "mlp.fc.",
                        None,
                        use_weight_only,
                        plugin_weight_only_quant_type,
                        dtype,
                        use_gemm_woq_plugin,
                    )
                )
            mlp_proj_weight = get_weight(model_params, prefix + "mlp.down_proj", dtype)
            split_v = split_matrix_tp(
                mlp_proj_weight, tensor_parallel, mapping.tp_rank, dim=1
            )
            if use_smooth_quant:
                mlp_proj_weight = mlp_proj_weight.t()
                int8_weights = generate_int8(
                    mlp_proj_weight, act_range.get(prefix + "mlp.down_proj")
                )
                weights.update(
                    get_tllm_linear_sq_weight(
                        int8_weights,
                        tllm_prex + "mlp.proj.",
                        [1, hidden_size],
                        tensor_parallel,
                        is_qkv=False,
                        per_token=per_token,
                        per_channel=per_channel,
                        last_prefix=tllm_prex + "mlp.quantization_scaling_factor",
                        smoother_value=smoother[prefix + "mlp.down_proj"],
                        smoother_shape=[1, intermediate_size // tensor_parallel],
                        rank=mapping.tp_rank,
                        cat_dim=0,
                    )
                )
            else:
                weights.update(
                    get_tllm_linear_weight(
                        split_v,
                        tllm_prex + "mlp.proj.",
                        None,
                        use_weight_only,
                        plugin_weight_only_quant_type,
                        dtype,
                        use_gemm_woq_plugin,
                    )
                )
        # Layer norms do not use tensor parallelism
        input_ln_weight = get_weight(model_params, prefix + "input_layernorm", dtype)
        weights[tllm_prex + "input_layernorm.weight"] = input_ln_weight
        post_ln_weight = get_weight(
            model_params, prefix + "post_attention_layernorm", dtype
        )
        weights[tllm_prex + "post_layernorm.weight"] = post_ln_weight
        if residual_mlp:
            residual_ln_weight = get_weight(
                model_params, prefix + "residual_layernorm", dtype
            )
            weights[tllm_prex + "residual_layernorm.weight"] = residual_ln_weight
        # Drop this layer's HF tensors once converted to bound peak memory.
        cur_block_weights = [
            weight_name
            for weight_name in model_params
            if weight_name.find(prefix) != -1
        ]
        for weight_name in cur_block_weights:
            model_params[weight_name] = None
    for l in layers_range:
        convert_layer(l)
        release_gc()
    # ---- embeddings, lm_head and final layer norm ----
    v = get_weight(model_params, "model.embed_tokens", dtype)
    if hf_model.config.tie_word_embeddings:
        # lm_head.weight has the same weights as embedding
        if mapping.is_last_pp_rank():
            if vocab_size % mapping.tp_size != 0:
                # padding
                vocab_size_padded = pad_vocab_size(vocab_size, mapping.tp_size)
                pad_width = vocab_size_padded - vocab_size
                v = torch.nn.functional.pad(v, (0, 0, 0, pad_width), "constant", 0)
            weights["lm_head.weight"] = split(v, mapping.tp_size, mapping.tp_rank)
    if use_parallel_embedding:
        v = split_matrix_tp(v, mapping.tp_size, mapping.tp_rank, dim=sharding_dim)
    if mapping.is_first_pp_rank():
        weights["transformer.vocab_embedding.weight"] = v
    # if not use_parallel_embedding:
    #     weights['transformer.vocab_embedding.weight'] = embed_w
    # else:
    #     assert hf_model.config.vocab_size % tensor_parallel == 0
    #     weights['transformer.vocab_embedding.weight'] = split_matrix_tp(
    #         embed_w, tensor_parallel, rank
    # NOTE(review): this lookup runs even when tie_word_embeddings is set and
    # overwrites the tied "lm_head.weight" computed above; it requires a
    # separate "lm_head.weight" entry in model_params - confirm for tied
    # checkpoints.
    lm_head_weights = get_weight(model_params, "lm_head", dtype)
    if mapping.is_last_pp_rank():
        if vocab_size % mapping.tp_size != 0:
            # padding
            vocab_size_padded = pad_vocab_size(vocab_size, mapping.tp_size)
            pad_width = vocab_size_padded - vocab_size
            lm_head_weights = torch.nn.functional.pad(
                lm_head_weights, (0, 0, 0, pad_width), "constant", value=0
            )
        weights["lm_head.weight"] = split_matrix_tp(
            lm_head_weights, tensor_parallel, mapping.tp_rank, dim=0
        )
    ln_f_w = get_weight(model_params, "model.norm", dtype)
    weights["transformer.ln_f.weight"] = ln_f_w
    tok = time.time()
    t = time.strftime("%H:%M:%S", time.gmtime(tok - tik))
    print(f"Weights loaded. Total time: {t}")
    return weights
def smooth_quant(
    model,
    model_dir,
    calib_dataset,
    dataset_cache_dir,
    smoothquant: Optional[float] = None,
):
    """Calibrate activation ranges and optionally apply SmoothQuant smoothing.

    Loads the tokenizer and calibration dataset, captures per-layer
    activation ranges, and - when ``smoothquant`` (the alpha value) is given -
    runs ``smooth_llama_model`` to fill the fused-QKV and smoother dicts.

    Returns:
        ``(act_range, llama_qkv_para, llama_smoother)``; the latter holds
        smoothers for the inputs of ``self_attn.o_proj`` and ``mlp.down_proj``.
    """
    assert model is not None
    qkv_para = {}
    smoother = {}
    os.environ["TOKENIZERS_PARALLELISM"] = os.environ.get(
        "TOKENIZERS_PARALLELISM", "false"
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_dir, trust_remote_code=True, use_fast=False, padding_side="left"
    )
    calib_data = load_calib_dataset(calib_dataset, cache_dir=dataset_cache_dir)
    act_range = capture_activation_range(model, tokenizer, calib_data)
    if smoothquant is not None:
        smooth_llama_model(model, act_range, smoothquant, qkv_para, smoother)
    return act_range, qkv_para, smoother
def create_config_from_hugging_face(
    hf_model,
    dtype,
    mapping,
    quantization: QuantConfig = None,
    override_fields: Optional[dict] = None,
):
    """Build the TRT-LLM checkpoint config dict from a HuggingFace config.

    Args:
        hf_model: HF repo id or local directory to read the config from.
        dtype: weight dtype name stored in the config (e.g. "float16").
        mapping: parallelism Mapping; only tp/pp sizes are recorded.
        quantization: QuantConfig whose ``asdict()`` goes under "quantization".
        override_fields: optional dict of config entries applied last,
            overriding the derived values.

    Returns:
        A plain dict describing architecture, shapes, parallel mapping and
        quantization for the TRT-LLM checkpoint.
    """
    # Fix: avoid the mutable `{}` default argument (shared across calls).
    override_fields = override_fields or {}
    config = {}
    hf_config = AutoConfig.from_pretrained(hf_model, trust_remote_code=True)
    if hf_config.model_type == "llava":
        # LLaVA = Vision model + Llama LLM
        # We load a llava config and use its text config as the llama config
        hf_config = LlavaConfig.from_pretrained(hf_model).text_config
    if hf_config.model_type == "llava_next":
        # Same for LLaVA-NeXT: only the text (LLM) half matters here.
        hf_config = LlavaNextConfig.from_pretrained(hf_model).text_config
    # TODO: directly assign the hf_config fields to the config dict w/o creating these local vars
    # same for from_meta and from_cli_args
    n_head = hf_config.num_attention_heads
    inter_size = hf_config.intermediate_size
    n_layer = hf_config.num_hidden_layers
    n_embd = hf_config.hidden_size
    n_kv_head = getattr(hf_config, "num_key_value_heads", n_head)
    rms_norm_eps = hf_config.rms_norm_eps
    vocab_size = hf_config.vocab_size
    n_positions = hf_config.max_position_embeddings
    hidden_act = hf_config.hidden_act
    config["rotary_scaling"] = getattr(hf_config, "rope_scaling", None)
    rotary_base = getattr(hf_config, "rope_theta", 10000.0)
    config["residual_mlp"] = getattr(hf_config, "parallel_attn_mlp_res", False)
    if hf_config.model_type == "mixtral" or hf_config.model_type == "arctic":
        # HF LLaMA-type models are implicitly using gated activation.
        # With our MoE implementation, we must make it explicit
        hidden_act = "swiglu"
        config["moe_normalization_mode"] = (
            MoeConfig.ExpertScaleNormalizationMode.RENORMALIZE
        )
    else:
        config["moe_normalization_mode"] = None
    moe_num_experts = getattr(hf_config, "num_local_experts", 0)
    moe_top_k = getattr(hf_config, "num_experts_per_tok", 0)
    moe_tp_mode = MoeConfig.ParallelismMode.TENSOR_PARALLEL
    architecture = hf_config.architectures[0]
    # VILA model, force to use llama config
    if hf_config.model_type == "llava_llama":
        architecture = "LlamaForCausalLM"
    attn_bias = getattr(hf_config, "bias", False) or getattr(
        hf_config, "attention_bias", False
    )
    config.update(
        {
            "architecture": architecture,
            "dtype": dtype,
            "logits_dtype": "float32",
            "num_hidden_layers": n_layer,
            "num_attention_heads": n_head,
            "hidden_size": n_embd,
            "intermediate_size": inter_size,
            "num_key_value_heads": n_kv_head,
            "vocab_size": vocab_size,
            "position_embedding_type": "rope_gpt_neox",
            "max_position_embeddings": n_positions,
            "hidden_act": hidden_act,
            "rotary_base": rotary_base,
            "norm_epsilon": rms_norm_eps,
            "moe_num_experts": moe_num_experts,
            "moe_top_k": moe_top_k,
            "moe_tp_mode": moe_tp_mode,
            # TODO: should have directly map from the Mapping object to the TRT-LLM checkpoint fields
            "mapping": {
                "world_size": mapping.tp_size * mapping.pp_size,
                "tp_size": mapping.tp_size,
                "pp_size": mapping.pp_size,
            },
            "attn_bias": attn_bias,
        }
    )
    config["quantization"] = quantization.asdict()
    # Overrides are applied last so callers can force any derived field.
    config.update(override_fields)
    moe_config = MoeConfig(
        config["moe_num_experts"],
        config["moe_top_k"],
        config["moe_tp_mode"],
        config["moe_normalization_mode"],
    ).validate()
    use_weight_only = config["quantization"]["quant_algo"] in [
        QuantAlgo.W8A16,
        QuantAlgo.W4A16,
        QuantAlgo.FP8,
    ]
    if use_weight_only and moe_config.has_moe():
        # The MoE router must stay unquantized.  Fix: create the exclusion
        # list when the quantization config did not provide one, instead of
        # raising KeyError/AttributeError.
        if config["quantization"].get("exclude_modules") is None:
            config["quantization"]["exclude_modules"] = []
        config["quantization"]["exclude_modules"].append("router")
    # Fix: removed leftover debug `print` that polluted stdout.
    return config
def from_hugging_face(
    cls,
    model_dir,
    dtype,
    *,
    mapping,
    quantization: QuantConfig = None,
    load_by_shard=False,
    load_model_on_cpu=False,
    override_fields=None,
    skip_loading_weights=False,
    preloaded_model=None,
):
    """Create a LLaMAForCausalLM object from the given parameters.

    Args:
        cls: model class to instantiate via ``cls.from_config``.
        model_dir: HF repo id or local checkpoint directory.
        dtype: weight dtype name, e.g. "float16".
        mapping: tensor/pipeline parallel Mapping for the current rank.
        quantization: optional QuantConfig forwarded into the config.
        load_by_shard: load weights shard-by-shard instead of instantiating
            the complete HF model in memory.
        load_model_on_cpu: keep the HF model on CPU instead of device_map="auto".
        override_fields: optional config overrides; a copy is taken so the
            caller's dict is never mutated.
        skip_loading_weights: return the built architecture without weights.
        preloaded_model: reuse an already-instantiated HF model.
    """
    assert model_dir is not None
    if isinstance(model_dir, Path):  # some code relies on this as string
        model_dir = str(model_dir)
    # Fix: work on a copy.  The original used a mutable `{}` default and
    # assigned into it below, mutating both the shared default object and
    # the caller's dict across invocations.
    override_fields = dict(override_fields) if override_fields else {}
    # register VILA model
    if "vila" in model_dir:
        sys.path.append(model_dir + "/../VILA")
        from llava.model import LlavaConfig, LlavaLlamaForCausalLM

        AutoConfig.register("llava_llama", LlavaConfig)
        AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM)
    if override_fields.get("share_embedding_table", False):
        logger.warning(
            "Llama model does not support share_embedding_table; setting share_embedding_table=False"
        )
        override_fields["share_embedding_table"] = False
    config = create_config_from_hugging_face(
        model_dir, dtype, mapping, quantization, override_fields=override_fields
    )
    pretrained_config = PretrainedConfig.from_dict(config)
    pretrained_config.set_rank(mapping.rank)  # TODO:remove this hack
    llama = cls.from_config(pretrained_config)
    llama = optimize_model(
        llama,
        use_parallel_embedding=pretrained_config.use_parallel_embedding,
        share_embedding_table=pretrained_config.share_embedding_table,
    )
    if skip_loading_weights:
        return llama
    model = preloaded_model
    if (
        model is None and not load_by_shard
    ):  # when load by shard, no need to create complete hf model
        have_safetensors = any(
            f.endswith(".safetensors") for f in os.listdir(model_dir)
        )
        hf_config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
        if hf_config.model_type == "llava":
            # For LLaVA-style models only the language half is converted.
            hf_llava = LlavaForConditionalGeneration.from_pretrained(
                model_dir, torch_dtype="auto"
            )
            model = hf_llava.language_model
        elif hf_config.model_type == "llava_next":
            hf_llava_next = LlavaNextForConditionalGeneration.from_pretrained(
                model_dir, torch_dtype="auto"
            )
            model = hf_llava_next.language_model
        else:
            # TODO: Remove WAR after `load_from_hf_safetensors` supports weight-only quantization
            if not have_safetensors or config["quantization"]["quant_algo"] is not None:
                model = AutoModelForCausalLM.from_pretrained(
                    model_dir,
                    device_map="auto" if not load_model_on_cpu else "cpu",
                    torch_dtype="auto",
                    trust_remote_code=True,
                )
    if load_by_shard:
        weights = load_from_hf_checkpoint(model_dir, mapping, pretrained_config)
    elif model is not None:
        weights = load_weights_from_hf(config=config, mapping=mapping, model=model)
    else:
        # Unquantized + safetensors present: stream weights directly.
        weights = load_from_hf_safetensors(
            model_dir=model_dir, config=pretrained_config, mapping=mapping
        )
    llama.load(weights)
    return llama
def quantize(
    dtype,
    model_dir,
    output_dir,
    mapping,
    quantization: QuantConfig,
    *,
    calib_dataset="cnn_dailymail",
    override_fields={},
    dataset_cache_dir: Optional[str] = None,
):
    """Quantize the model and save it as a TRT-LLM checkpoint in ``output_dir``.

    Runs SmoothQuant/int8-KV calibration exactly once on the full HF model,
    then writes ``config.json`` plus one ``rank{i}.safetensors`` file per
    rank.  Must be called with a rank-agnostic mapping (``rank == -1``) and
    only from a single process.
    """
    # TODO: currently only smooth quant and kv cache quantization are supported, needs to support mode quant algorithm calling modelopt
    config = create_config_from_hugging_face(
        model_dir, dtype, mapping, quantization, override_fields=override_fields
    )
    with open(os.path.join(output_dir, "config.json"), "w") as f:
        json.dump(config, f, indent=4)
    assert (
        mapping.rank == -1
    ), "You shall call quantize only once in one rank, assert rank==-1 for precaution"
    act_range = {}
    llama_qkv_para = {}
    # smoother for inputs of self_attn.o_proj and mlp.down_proj
    llama_smoother = {}
    model = None
    assert config["quantization"]["quant_algo"] == quantization.quant_algo
    int8_kv_cache = quantization.kv_cache_quant_algo == QuantAlgo.INT8
    use_smooth_quant = (
        quantization.quant_algo is not None
        and quantization.quant_algo.startswith("W8A8_SQ")
    )
    assert (
        use_smooth_quant or int8_kv_cache
    ), "Call from_hugging_face when there is no quantization"
    if use_smooth_quant:
        assert (
            quantization.smoothquant_val is not None
        ), "A smooth value must be specified when using smooth quant"
    assert model_dir is not None
    ## only load and call smooth quant routine once for all ranks
    hf_config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
    assert (
        "llava" not in hf_config.model_type
    ), "Smooth quant llava/vila is not supported yet"
    model = AutoModelForCausalLM.from_pretrained(
        model_dir,
        device_map="auto",
        # SmoothQuant calibration runs in fp16; otherwise keep native dtype.
        torch_dtype="auto" if not use_smooth_quant else torch.float16,
        trust_remote_code=True,
    )
    act_range, llama_qkv_para, llama_smoother = smooth_quant(
        model, model_dir, calib_dataset, dataset_cache_dir, quantization.smoothquant_val
    )
    # Convert and save the weights once per rank, reusing the single
    # calibration result computed above.
    for rank in range(mapping.world_size):
        # To avoid changing the mapping arg in-place, also the given mapping from caller is rank agnostic, since quantize is called from only one rank
        ranked_mapping = Mapping(
            world_size=mapping.world_size,
            rank=rank,
            tp_size=mapping.tp_size,
            pp_size=mapping.pp_size,
        )
        weights = load_weights_from_hf(
            config=config,
            mapping=ranked_mapping,
            model=model,
            # for smooth quant only
            act_range=act_range,
            llama_qkv_para=llama_qkv_para,
            llama_smoother=llama_smoother,
        )
        safetensors.torch.save_file(
            weights, os.path.join(output_dir, f"rank{rank}.safetensors")
        )
        del weights
def load_weights_from_hf(
    *, config, mapping, model, act_range={}, llama_qkv_para={}, llama_smoother={}
):
    """Translate a loaded HF model into TRT-LLM weights for one rank.

    Decodes the quantization settings from ``config["quantization"]`` and
    forwards everything to ``convert_hf_llama``.  The ``act_range`` /
    ``llama_qkv_para`` / ``llama_smoother`` dicts are only consulted when
    smooth quant or int8 KV cache is enabled.
    """
    # TODO: simplify the parameters here
    assert model is not None
    quant_algo = config["quantization"]["quant_algo"]
    if quant_algo == QuantAlgo.W8A16:
        weight_only_dtype = torch.int8
    elif quant_algo == QuantAlgo.W4A16:
        weight_only_dtype = torch.quint4x2
    else:
        # Value is irrelevant when weight-only quantization is off.
        weight_only_dtype = None
    moe_config = MoeConfig(
        config["moe_num_experts"],
        config["moe_top_k"],
        config["moe_tp_mode"],
        config["moe_normalization_mode"],
    ).validate()
    use_weight_only = quant_algo in (QuantAlgo.W8A16, QuantAlgo.W4A16)
    use_smooth_quant = quant_algo is not None and quant_algo.startswith("W8A8_SQ")
    return convert_hf_llama(
        model,
        mapping,
        vocab_size=config["vocab_size"],
        dtype=config["dtype"],
        use_weight_only=use_weight_only,
        use_gemm_woq_plugin=not config.get("disable_weight_only_quant_plugin", False),
        plugin_weight_only_quant_type=weight_only_dtype,
        use_parallel_embedding=config.get("use_parallel_embedding", False),
        sharding_dim=config.get("embedding_sharding_dim", 0),
        share_embedding_table=config.get("share_embedding_table", False),
        residual_mlp=config["residual_mlp"],
        use_smooth_quant=use_smooth_quant,
        per_channel=use_smooth_quant and "PER_CHANNEL" in quant_algo,
        per_token=use_smooth_quant and "PER_TOKEN" in quant_algo,
        int8_kv_cache=config["quantization"]["kv_cache_quant_algo"] == QuantAlgo.INT8,
        act_range=act_range,
        qkv_para=llama_qkv_para,
        smoother=llama_smoother,
        moe_config=moe_config,
    )
# from llava.constants import (
# IMAGE_TOKEN_INDEX,
# DEFAULT_IMAGE_TOKEN,
# DEFAULT_IM_START_TOKEN,
# DEFAULT_IM_END_TOKEN,
# IMAGE_PLACEHOLDER,
# )
# from llava.conversation import conv_templates, SeparatorStyle
# from llava.model.builder import load_pretrained_model
# from llava.utils import disable_torch_init
# from llava.mm_utils import (
# process_images,
# tokenizer_image_token,
# get_model_name_from_path,
# )
import argparse
import json
import os
import sys
import time
import traceback
from concurrent.futures import ThreadPoolExecutor, as_completed
import tensorrt_llm
from tensorrt_llm._utils import release_gc
from tensorrt_llm.layers import MoeConfig
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.models import LLaMAForCausalLM
from tensorrt_llm.models.llama.weight import load_from_gptq_llama
from tensorrt_llm.models.modeling_utils import QuantConfig
from tensorrt_llm.quantization import QuantAlgo
def parse_arguments():
    """Parse CLI options for LLaMA HF/meta -> TensorRT-LLM checkpoint conversion.

    Returns:
        argparse.Namespace carrying checkpoint paths, parallelism sizes,
        architecture hyper-parameters (used only when generating a fake
        config without weights), and quantization / MoE options.
    """
    parser = argparse.ArgumentParser()
    # Input checkpoint locations: HF directory or meta consolidated checkpoint.
    parser.add_argument("--model_dir", type=str, default=None)
    parser.add_argument("--meta_ckpt_dir", type=str, default=None)
    parser.add_argument(
        "--tp_size", type=int, default=1, help="N-way tensor parallelism size"
    )
    parser.add_argument(
        "--pp_size", type=int, default=1, help="N-way pipeline parallelism size"
    )
    parser.add_argument(
        "--dtype",
        type=str,
        default="float16",
        choices=["float32", "bfloat16", "float16"],
    )
    # Architecture hyper-parameters; only consulted when writing a fake config
    # (i.e. neither --model_dir nor --meta_ckpt_dir was given).
    parser.add_argument("--vocab_size", type=int, default=32000)
    parser.add_argument("--n_positions", type=int, default=2048)
    parser.add_argument("--n_layer", type=int, default=32)
    parser.add_argument("--n_head", type=int, default=32)
    parser.add_argument("--n_kv_head", type=int, default=None)
    parser.add_argument("--n_embd", type=int, default=4096)
    parser.add_argument("--inter_size", type=int, default=11008)
    parser.add_argument("--rms_norm_eps", type=float, default=1e-06)
    # Weight-only quantization options.
    parser.add_argument(
        "--use_weight_only",
        default=False,
        action="store_true",
        help="Quantize weights for the various GEMMs to INT4/INT8."
        "See --weight_only_precision to set the precision",
    )
    parser.add_argument(
        "--disable_weight_only_quant_plugin",
        default=False,
        action="store_true",
        help="By default, using plugin implementation for weight quantization. Enabling disable_weight_only_quant_plugin flag will use ootb implementation instead of plugin."
        "You must also use --use_weight_only for that argument to have an impact.",
    )
    parser.add_argument(
        "--weight_only_precision",
        const="int8",
        type=str,
        nargs="?",
        default="int8",
        choices=["int8", "int4", "int4_gptq"],
        help="Define the precision for the weights when using weight-only quantization."
        "You must also use --use_weight_only for that argument to have an impact.",
    )
    parser.add_argument(
        "--calib_dataset",
        type=str,
        default="ccdv/cnn_dailymail",
        help="The huggingface dataset name or the local directory of the dataset for calibration.",
    )
    # SmoothQuant / int8 KV-cache options.
    parser.add_argument(
        "--smoothquant",
        "-sq",
        type=float,
        default=None,
        help="Set the α parameter (see https://arxiv.org/pdf/2211.10438.pdf)"
        " to Smoothquant the model, and output int8 weights."
        " A good first try is 0.5. Must be in [0, 1]",
    )
    parser.add_argument(
        "--per_channel",
        action="store_true",
        default=False,
        help="By default, we use a single static scaling factor for the GEMM's result. "
        "per_channel instead uses a different static scaling factor for each channel. "
        "The latter is usually more accurate, but a little slower.",
    )
    parser.add_argument(
        "--per_token",
        action="store_true",
        default=False,
        help="By default, we use a single static scaling factor to scale activations in the int8 range. "
        "per_token chooses at run time, and for each token, a custom scaling factor. "
        "The latter is usually more accurate, but a little slower.",
    )
    parser.add_argument(
        "--int8_kv_cache",
        default=False,
        action="store_true",
        help="By default, we use dtype for KV cache. int8_kv_cache chooses int8 quantization for KV",
    )
    # GPTQ / grouped-quantization options.
    parser.add_argument(
        "--modelopt_quant_ckpt_path",
        type=str,
        default=None,
        help="Path of a quantized model checkpoint in .npz format",
    )
    parser.add_argument(
        "--per_group",
        default=False,
        action="store_true",
        help="By default, we use a single static scaling factor to scale weights in the int4 range. "
        "per_group chooses at run time, and for each group, a custom scaling factor. "
        "The flag is built for GPTQ/AWQ quantization.",
    )
    parser.add_argument(
        "--load_by_shard",
        action="store_true",
        help="Load a pretrained model shard-by-shard.",
    )
    parser.add_argument("--hidden_act", type=str, default="silu")
    parser.add_argument("--rotary_base", type=float, default=10000.0)
    parser.add_argument(
        "--group_size",
        type=int,
        default=128,
        help="Group size used in GPTQ quantization.",
    )  # AWQ is only supported by quantize.py script
    parser.add_argument(
        "--dataset-cache-dir",
        type=str,
        default=None,
        help="cache dir to load the hugging face dataset",
    )
    parser.add_argument("--load_model_on_cpu", action="store_true")
    # Embedding parallelism / sharing options.
    parser.add_argument(
        "--use_parallel_embedding",
        action="store_true",
        default=False,
        help="By default embedding parallelism is disabled. By setting this flag, embedding parallelism is enabled",
    )
    parser.add_argument(
        "--embedding_sharding_dim",
        type=int,
        default=0,
        choices=[0, 1],
        help="By default the embedding lookup table is sharded along vocab dimension (embedding_sharding_dim=0). "
        "To shard it along hidden dimension, set embedding_sharding_dim=1"
        "Note: embedding sharing is only enabled when embedding_sharding_dim = 0",
    )
    parser.add_argument(
        "--use_embedding_sharing",
        action="store_true",
        default=False,
        help="Try to reduce the engine size by sharing the embedding lookup table between two layers."
        "Note: the flag might not take effect when the criteria are not met.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="tllm_checkpoint",
        help="The path to save the TensorRT-LLM checkpoint",
    )
    parser.add_argument(
        "--workers",
        type=int,
        default=1,
        help="The number of workers for converting checkpoint in parallel",
    )
    # Mixture-of-Experts options.
    parser.add_argument(
        "--moe_num_experts",
        default=0,
        type=int,
        help="Specify the number of experts to use for MOE layers",
    )
    parser.add_argument(
        "--moe_top_k",
        default=0,
        type=int,
        help="Specify the top_k value to use for MOE layers. Default to 1 if --moe_num_experts is set",
    )
    parser.add_argument(
        "--moe_tp_mode",
        default=MoeConfig.ParallelismMode.TENSOR_PARALLEL,
        type=int,
        help="Controls how to distribute experts in TP. Check layers/moe.py for accepted values",
    )
    parser.add_argument(
        "--moe_renorm_mode",
        default=MoeConfig.ExpertScaleNormalizationMode.RENORMALIZE,
        type=int,
        help="Controls renormalization after gate logits. Check layers/moe.py for accepted values",
    )
    parser.add_argument(
        "--save_config_only",
        action="store_true",
        default=False,
        help="Only save the model config w/o read and converting weights, be careful, this is for debug only",
    )
    args = parser.parse_args()
    # changing the default to be consistent as the cli help said.
    if args.moe_num_experts and args.moe_top_k == 0:
        args.moe_top_k = 1
    return args
def args_to_quantization(args: argparse.Namespace) -> QuantConfig:
    """Build a QuantConfig describing the quantization mode implied by CLI flags.

    Weight-only (int8/int4) and SmoothQuant are mutually exclusive; an
    int4_gptq precision overrides the algorithm with W4A16_GPTQ settings.
    """
    cfg = QuantConfig()
    cfg.exclude_modules = ["lm_head"]
    if args.use_weight_only:
        woq_algos = {"int8": QuantAlgo.W8A16, "int4": QuantAlgo.W4A16}
        algo = woq_algos.get(args.weight_only_precision)
        if algo is not None:
            cfg.quant_algo = algo
    elif args.smoothquant:
        cfg.smoothquant_val = args.smoothquant
        # (per_channel, per_token) -> SmoothQuant plugin variant
        sq_algos = {
            (True, True): QuantAlgo.W8A8_SQ_PER_CHANNEL_PER_TOKEN_PLUGIN,
            (True, False): QuantAlgo.W8A8_SQ_PER_CHANNEL_PER_TENSOR_PLUGIN,
            (False, True): QuantAlgo.W8A8_SQ_PER_TENSOR_PER_TOKEN_PLUGIN,
            (False, False): QuantAlgo.W8A8_SQ_PER_TENSOR_PLUGIN,
        }
        cfg.quant_algo = sq_algos[(bool(args.per_channel), bool(args.per_token))]
    if args.int8_kv_cache:
        cfg.kv_cache_quant_algo = QuantAlgo.INT8
    if args.weight_only_precision == "int4_gptq":
        # GPTQ settings take precedence over any algorithm chosen above.
        cfg.group_size = args.group_size
        cfg.has_zero_point = True
        cfg.pre_quant_scale = False
        cfg.quant_algo = QuantAlgo.W4A16_GPTQ
    return cfg
def convert_and_save_meta(args, rank):
    """Convert one rank's weights from a meta (consolidated) checkpoint and save it."""
    ranked_mapping = Mapping(
        world_size=args.tp_size * args.pp_size,
        tp_size=args.tp_size,
        pp_size=args.pp_size,
        rank=rank,
    )
    # Quantized conversion is not available on the meta-checkpoint path.
    assert not args_to_quantization(
        args
    ).quant_mode.has_any_quant(), (
        "quantization from meta checkpoint or empty model were never supported"
    )
    model = LLaMAForCausalLM.from_meta_ckpt(
        args.meta_ckpt_dir,
        args.dtype,
        ranked_mapping,
        use_parallel_embedding=args.use_parallel_embedding,
        embedding_sharding_dim=args.embedding_sharding_dim,
    )
    # Only rank 0 writes the shared config alongside its weights.
    model.save_checkpoint(args.output_dir, save_config=(rank == 0))
def args_to_build_options(args):
    """Collect embedding/weight-only build options from the parsed CLI args."""
    mapping = (
        ("use_parallel_embedding", "use_parallel_embedding"),
        ("embedding_sharding_dim", "embedding_sharding_dim"),
        ("share_embedding_table", "use_embedding_sharing"),
        ("disable_weight_only_quant_plugin", "disable_weight_only_quant_plugin"),
    )
    return {out_key: getattr(args, attr) for out_key, attr in mapping}
def from_cli_args(args):
    """Build a LLaMA pretrained-config dict purely from CLI arguments (no weights)."""
    # KV heads default to the attention head count when not set explicitly.
    kv_heads = args.n_head if args.n_kv_head is None else args.n_kv_head
    config = dict(
        architecture="LlamaForCausalLM",
        dtype=args.dtype,
        logits_dtype="float32",
        num_hidden_layers=args.n_layer,
        num_attention_heads=args.n_head,
        hidden_size=args.n_embd,
        intermediate_size=args.inter_size,
        num_key_value_heads=kv_heads,
        vocab_size=args.vocab_size,
        position_embedding_type="rope_gpt_neox",
        max_position_embeddings=args.n_positions,
        hidden_act=args.hidden_act,
        rotary_base=args.rotary_base,
        norm_epsilon=args.rms_norm_eps,
        moe_num_experts=args.moe_num_experts,
        moe_top_k=args.moe_top_k,
        moe_tp_mode=args.moe_tp_mode,
        moe_normalization_mode=args.moe_renorm_mode,
        mapping=dict(
            world_size=args.tp_size * args.pp_size,
            tp_size=args.tp_size,
            pp_size=args.pp_size,
        ),
        quantization=args_to_quantization(args).asdict(),
    )
    # Embedding/build options override or extend the base config.
    config.update(args_to_build_options(args))
    return config
def preload_model(model_dir, load_model_on_cpu):
    """Preload the full HF model once, or return None when safetensors
    shards can be read lazily per rank instead.
    """
    from transformers import AutoConfig, AutoModelForCausalLM

    safetensors_ok = True
    if "vila" in model_dir:
        # VILA ships custom model classes next to the checkpoint; register them.
        safetensors_ok = False
        sys.path.append(model_dir + "/../VILA")
        from llava.model import LlavaConfig, LlavaLlamaForCausalLM

        AutoConfig.register("llava_llama", LlavaConfig)
        AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM)
    hf_config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
    model_cls = AutoModelForCausalLM
    if hf_config.model_type == "llava":
        safetensors_ok = False
        from transformers import LlavaForConditionalGeneration

        model_cls = LlavaForConditionalGeneration
    has_safetensors = any(f.endswith(".safetensors") for f in os.listdir(model_dir))
    if safetensors_ok and has_safetensors:
        # Lazy per-rank loading is possible; skip the full preload.
        return None
    model = model_cls.from_pretrained(
        model_dir,
        device_map="cpu" if load_model_on_cpu else "auto",
        torch_dtype="auto",
        trust_remote_code=True,
    )
    # Only the language tower of LLaVA is converted here.
    return model.language_model if hf_config.model_type == "llava" else model
def convert_and_save_hf(args):
    """Convert a HuggingFace LLaMA checkpoint into per-rank TRT-LLM checkpoints.

    SmoothQuant / int8 KV-cache requests go through the calibration-based
    quantize path; everything else converts weights rank by rank, optionally
    preloading the whole model once to avoid re-reading it from disk.
    """
    model_dir = args.model_dir
    load_model_on_cpu = args.load_model_on_cpu
    load_by_shard = args.load_by_shard
    world_size = args.tp_size * args.pp_size
    # Need to convert the cli args to the key-value pairs and override them in the generate config dict.
    # Ideally these fields will be moved out of the config and pass them into build API, keep them here for compatibility purpose for now,
    # before the refactor is done.
    override_fields = {"moe_tp_mode": args.moe_tp_mode}
    quantization = args_to_quantization(args)
    override_fields.update(args_to_build_options(args))
    if args.smoothquant is not None or args.int8_kv_cache:
        assert (
            not args.load_by_shard
        ), "When using quantization, TRT-LLM needs to load the whole HF model, thus load by shard not supported"
        assert (
            not args.load_model_on_cpu
        ), "When using quantization, TRT-LLM needs to load the model to GPU"
        mapping = Mapping(
            world_size=world_size,
            rank=-1,  # intentionally make -1 to avoid mistake
            tp_size=args.tp_size,
            pp_size=args.pp_size,
        )
        LLaMAForCausalLM.quantize(
            args.model_dir,
            args.output_dir,
            quantization,
            dtype=args.dtype,
            mapping=mapping,
            calib_dataset=args.calib_dataset,
            override_fields=override_fields,
            dataset_cache_dir=args.dataset_cache_dir,
        )
    else:
        # When not loading by shard, preload one complete model and then slice per rank weights from this
        # this saves the disk reloading time
        hf_model = (
            preload_model(model_dir, load_model_on_cpu)
            if not args.load_by_shard
            else None
        )

        def convert_and_save_rank(args, rank):
            # Convert and persist the weight shard owned by `rank`.
            mapping = Mapping(
                world_size=world_size,
                rank=rank,
                tp_size=args.tp_size,
                pp_size=args.pp_size,
            )
            llama = LLaMAForCausalLM.from_hugging_face(
                model_dir,
                args.dtype,
                mapping=mapping,
                quantization=quantization,
                load_by_shard=load_by_shard,
                load_model_on_cpu=load_model_on_cpu,
                override_fields=override_fields,
                preloaded_model=hf_model,
            )
            llama.save_checkpoint(args.output_dir, save_config=(rank == 0))
            del llama

        execute(args.workers, [convert_and_save_rank] * world_size, args)
        release_gc()
def convert_and_save_gptq(args, rank):
    """Load GPTQ-quantized weights from a .npz checkpoint and save one rank."""
    rank_mapping = Mapping(
        world_size=args.tp_size * args.pp_size,
        tp_size=args.tp_size,
        rank=rank,
        pp_size=args.pp_size,
    )
    # Build the model skeleton only; the GPTQ weights come from the .npz file.
    llama = LLaMAForCausalLM.from_hugging_face(
        args.model_dir,
        args.dtype,
        mapping=rank_mapping,
        quantization=args_to_quantization(args),
        skip_loading_weights=True,
    )
    llama.load(load_from_gptq_llama(llama.config, args.modelopt_quant_ckpt_path))
    llama.save_checkpoint(args.output_dir, rank == 0)
def execute(workers, func, args):
    """Run one conversion callable per rank, serially or in a thread pool.

    Each entry of `func` is invoked as fn(args, rank); in the pooled path
    all exceptions are collected and reported at the end.
    """
    if workers == 1:
        for rank, fn in enumerate(func):
            fn(args, rank)
        return
    with ThreadPoolExecutor(max_workers=workers) as pool:
        pending = [pool.submit(fn, args, rank) for rank, fn in enumerate(func)]
        errors = []
        for done in as_completed(pending):
            try:
                done.result()
            except Exception as err:
                traceback.print_exc()
                errors.append(err)
        assert (
            len(errors) == 0
        ), "Checkpoint conversion failed, please check error log."
def main():
    """CLI entry point: dispatch to the appropriate conversion path and time it."""
    print(tensorrt_llm.__version__)
    args = parse_arguments()
    world_size = args.tp_size * args.pp_size
    start = time.time()
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)
    if (
        args.model_dir is None and args.meta_ckpt_dir is None
    ):  # generate fake config.json
        # No weights given: write the synthetic config only.
        with open(os.path.join(args.output_dir, "config.json"), "w") as f:
            json.dump(from_cli_args(args), f, indent=4)
    elif args.meta_ckpt_dir is not None:
        assert (
            args.model_dir is None
        ), "Shall not specify both meta checkpoint dir and hugging face dir"
        execute(args.workers, [convert_and_save_meta] * world_size, args)
    elif args.weight_only_precision == "int4_gptq":
        assert args.model_dir is not None
        assert args.modelopt_quant_ckpt_path is not None
        execute(args.workers, [convert_and_save_gptq] * world_size, args)
    else:
        # All remaining (non-GPTQ) paths start from a HF model directory.
        assert args.model_dir is not None
        assert (
            args.modelopt_quant_ckpt_path is None
        ), "only gptq weights only needs this option"
        convert_and_save_hf(args)
    elapsed = time.strftime("%H:%M:%S", time.gmtime(time.time() - start))
    print(f"Total time of converting checkpoints: {elapsed}")
# Run the converter only when executed directly, not on import.
if __name__ == "__main__":
    main()
# NOTE: the following is a shell command (not Python) — run it separately:
# cp convert.py /usr/local/lib/python3.10/dist-packages/tensorrt_llm/models/llama/convert.py
from safetensors.torch import load_file, safe_open
from safetensors.torch import save_file
import argparse
def parse_arguments():
    """Parse the three repo paths used by the safetensors weight-merging script."""
    parser = argparse.ArgumentParser()
    for flag in ("--huggingface_repo_dir", "--thirdparty_repo_dir", "--merged_repo_dir"):
        parser.add_argument(flag, type=str)
    return parser.parse_args()
import os

args = parse_arguments()
import shutil

# Start the merged repo as a copy of the HF repo; the weight shards below
# are then rewritten from the third-party checkpoint.
shutil.copytree(args.huggingface_repo_dir, args.merged_repo_dir)
import torch

# The four shards present in both repos (llava-v1.6 style sharding).
SHARD_NAMES = [
    "model-00001-of-00004.safetensors",
    "model-00002-of-00004.safetensors",
    "model-00003-of-00004.safetensors",
    "model-00004-of-00004.safetensors",
]

# Keep the HF embedding / lm_head rows beyond the base 32000-token vocab so
# the extended vocabulary survives the merge.
hf_weights_dict = dict()
for shard in SHARD_NAMES:
    # os.path.join works whether or not the dir argument has a trailing slash.
    ori_weights = load_file(os.path.join(args.huggingface_repo_dir, shard))
    for key, value in ori_weights.items():
        if key in (
            "language_model.lm_head.weight",
            "language_model.model.embed_tokens.weight",
        ):
            hf_weights_dict[key] = value

for shard in SHARD_NAMES:
    ori_weights = load_file(os.path.join(args.thirdparty_repo_dir, shard))
    new_weights = dict()
    for key, value in ori_weights.items():
        # Map third-party (llava) tensor names onto the HF layout.
        if key == "lm_head.weight":
            new_key = "language_model.lm_head.weight"
        elif key == "model.embed_tokens.weight":
            new_key = "language_model.model.embed_tokens.weight"
        elif key == "model.image_newline":
            new_key = "image_newline"
        elif "model.layers." in key:
            new_key = key.replace("model", "language_model.model")
        elif key == "model.norm.weight":
            new_key = "language_model.model.norm.weight"
        elif key == "model.mm_projector.0.bias":
            new_key = "multi_modal_projector.linear_1.bias"
        elif key == "model.mm_projector.0.weight":
            new_key = "multi_modal_projector.linear_1.weight"
        elif key == "model.mm_projector.2.bias":
            new_key = "multi_modal_projector.linear_2.bias"
        elif key == "model.mm_projector.2.weight":
            new_key = "multi_modal_projector.linear_2.weight"
        elif "model.vision_tower.vision_tower" in key:
            new_key = key.replace("model.vision_tower.vision_tower", "vision_tower")
        else:
            # BUGFIX: previously `new_key` silently kept the value from the
            # prior iteration (or raised NameError on the first key), so an
            # unmapped tensor could overwrite an unrelated one. Pass
            # unrecognized keys through unchanged instead.
            new_key = key
        # Re-extend the vocab-sized tensors with the extra HF rows (>32000).
        if new_key == "language_model.lm_head.weight":
            value = torch.cat(
                (value, hf_weights_dict["language_model.lm_head.weight"][32000:]), dim=0
            )
        elif new_key == "language_model.model.embed_tokens.weight":
            value = torch.cat(
                (
                    value,
                    hf_weights_dict["language_model.model.embed_tokens.weight"][32000:],
                ),
                dim=0,
            )
        new_weights[new_key] = value
    save_file(new_weights, os.path.join(args.merged_repo_dir, shard), metadata={"format": "pt"})
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import json
import time
from pathlib import Path
# isort: off
import torch
import tensorrt as trt
# isort: on
import numpy as np
from transformers import (
AutoConfig,
AutoModelForSeq2SeqLM,
AutoTokenizer,
BartForConditionalGeneration,
MBartForConditionalGeneration,
T5ForConditionalGeneration,
)
import tensorrt_llm
from tensorrt_llm import logger
from tensorrt_llm._utils import torch_to_numpy, trt_dtype_to_torch
from tensorrt_llm.lora_manager import LoraManager
from tensorrt_llm.runtime import ModelConfig, SamplingConfig
def get_engine_name(rank):
    """Return the serialized engine filename for the given runtime rank."""
    return f"rank{rank}.engine"
def print_tensor(tensor_name, tensor, num_elements=10):
    """Print abs-value summary statistics and a leading slice of a tensor."""
    # Integer tensors are cast to float so mean() is defined.
    if tensor.dtype in (torch.int32, torch.int64):
        tensor = tensor.to(dtype=float)
    magnitudes = tensor.abs()
    print(
        f"{tensor_name}: mean={magnitudes.mean().item():.3f}, sum={magnitudes.sum().item():.3f}, max={magnitudes.max().item():.3f}"
    )
    # Pass num_elements=-1 will print the whole tensor
    if num_elements < 0:
        num_elements = torch.numel(tensor)
    print(f"{tensor.flatten()[:num_elements]}")
    print("Tensor Shape: ", tensor.size())
    print("")
def read_config(config_path: Path):
    """Read an engine's config.json and build the runtime ModelConfig.

    Args:
        config_path: path to one engine component's config.json.

    Returns:
        Tuple of (model_config, tp_size, pp_size, gpus_per_node, dtype).

    Raises:
        AssertionError: when the engine world size does not match the MPI
            world size, or num_heads is not divisible by tp_size.
    """
    with open(config_path, "r") as f:
        config = json.load(f)
    builder_config = config["build_config"]
    plugin_config = builder_config["plugin_config"]
    pretrained_config = config["pretrained_config"]
    lora_config = builder_config["lora_config"]
    auto_parallel_config = builder_config["auto_parallel_config"]
    use_gpt_attention_plugin = plugin_config["gpt_attention_plugin"]
    remove_input_padding = plugin_config["remove_input_padding"]
    use_lora_plugin = plugin_config["lora_plugin"]
    tp_size = pretrained_config["mapping"]["tp_size"]
    pp_size = pretrained_config["mapping"]["pp_size"]
    gpus_per_node = auto_parallel_config["gpus_per_node"]
    world_size = tp_size * pp_size
    assert (
        world_size == tensorrt_llm.mpi_world_size()
    ), f"Engine world size ({world_size}) != Runtime world size ({tensorrt_llm.mpi_world_size()})"
    num_heads = pretrained_config["num_attention_heads"]
    hidden_size = pretrained_config["hidden_size"]
    head_size = pretrained_config["head_size"]
    vocab_size = pretrained_config["vocab_size"]
    max_batch_size = builder_config["max_batch_size"]
    max_beam_width = builder_config["max_beam_width"]
    num_layers = pretrained_config["num_hidden_layers"]
    num_kv_heads = pretrained_config.get("num_kv_heads", num_heads)
    # Scale per-rank sizes down by the tensor-parallel degree.
    assert (num_heads % tp_size) == 0
    num_heads = num_heads // tp_size
    hidden_size = hidden_size // tp_size
    num_kv_heads = (num_kv_heads + tp_size - 1) // tp_size
    # DecoderModel architecture means this component cross-attends over encoder output.
    cross_attention = pretrained_config["architecture"] == "DecoderModel"
    skip_cross_qkv = pretrained_config.get("skip_cross_qkv", False)
    has_position_embedding = pretrained_config["has_position_embedding"]
    # BUGFIX: pretrained_config is a plain dict, so the previous
    # hasattr(pretrained_config, "type_vocab_size") was always False;
    # use key membership instead.
    has_token_type_embedding = "type_vocab_size" in pretrained_config
    use_custom_all_reduce = plugin_config.get("use_custom_all_reduce", False)
    dtype = pretrained_config["dtype"]
    paged_kv_cache = plugin_config["paged_kv_cache"]
    tokens_per_block = plugin_config["tokens_per_block"]
    gather_context_logits = builder_config.get("gather_context_logits", False)
    gather_generation_logits = builder_config.get("gather_generation_logits", False)
    max_prompt_embedding_table_size = builder_config.get(
        "max_prompt_embedding_table_size", 0
    )
    model_config = ModelConfig(
        num_heads=num_heads,
        num_kv_heads=num_kv_heads,
        hidden_size=hidden_size,
        head_size=head_size,
        max_batch_size=max_batch_size,
        max_beam_width=max_beam_width,
        vocab_size=vocab_size,
        num_layers=num_layers,
        gpt_attention_plugin=use_gpt_attention_plugin,
        remove_input_padding=remove_input_padding,
        paged_kv_cache=paged_kv_cache,
        tokens_per_block=tokens_per_block,
        cross_attention=cross_attention,
        has_position_embedding=has_position_embedding,
        has_token_type_embedding=has_token_type_embedding,
        use_custom_all_reduce=use_custom_all_reduce,
        dtype=dtype,
        gather_context_logits=gather_context_logits,
        gather_generation_logits=gather_generation_logits,
        max_prompt_embedding_table_size=max_prompt_embedding_table_size,
        lora_plugin=use_lora_plugin,
        lora_target_modules=lora_config.get("lora_target_modules"),
        trtllm_modules_to_hf_modules=lora_config.get("trtllm_modules_to_hf_modules"),
        skip_cross_qkv=skip_cross_qkv,
    )
    return model_config, tp_size, pp_size, gpus_per_node, dtype
def parse_arguments():
    """Command-line options for the enc-dec TRT-LLM runtime example."""
    runtime_parser = argparse.ArgumentParser()
    # Generation / engine basics.
    runtime_parser.add_argument("--max_new_tokens", type=int, default=64)
    runtime_parser.add_argument("--log_level", type=str, default="error")
    runtime_parser.add_argument("--engine_dir", "-i", type=str, default="trt_engines")
    runtime_parser.add_argument("--engine_name", type=str, default="enc_dec")
    runtime_parser.add_argument(
        "--model_name",
        type=str,
        help="HuggingFace model name or FairSeq model path",
        default="t5-small",
    )
    runtime_parser.add_argument(
        "--num_beams", type=int, help="Use beam search if num_beams >1", default=1
    )
    # Debugging / comparison flags.
    runtime_parser.add_argument(
        "--debug_mode",
        help="Whether or not to turn on the debug mode",
        action="store_true",
    )
    runtime_parser.add_argument(
        "--compare_hf_fp32",
        help="Compare results with HuggingFace FP32",
        action="store_true",
    )
    # LoRA options.
    runtime_parser.add_argument("--lora_dir", type=str, default=None, nargs="+")
    runtime_parser.add_argument("--lora_task_uids", type=str, default=None, nargs="+")
    runtime_parser.add_argument(
        "--output_encoder_npy",
        help="Store tensors like encoder outputs used for testing enc-dec C++ runtime.",
        action="store_true",
    )
    return runtime_parser.parse_args()
class TRTLLMEncDecModel:
    def __init__(
        self,
        engine_name,
        engine_dir,
        lora_dir=None,
        lora_task_uids=None,
        debug_mode=False,
        skip_encoder=False,
        stream: torch.cuda.Stream = None,
    ):
        """Load encoder/decoder TRT engines from `engine_dir` and create runtime sessions.

        Args:
            engine_name: engine identifier (not referenced further in this constructor).
            engine_dir: root directory holding `encoder/` and `decoder/` engine folders.
            lora_dir: optional LoRA checkpoint directories for both components.
            lora_task_uids: LoRA task uids, stored for use at generation time.
            debug_mode: enable debug tensors in the decoder session.
            skip_encoder: skip encoder setup entirely (decoder-only pipelines).
            stream: existing torch.cuda.Stream to reuse; a new one is created when None.
        """
        # in multi-node setup, it's important to set_device at the very beginning so .to('cuda') refers to current device
        # accordingly, all input & output tensors should be moved to current device
        # otherwise, it's default to 'cuda:0'
        self.runtime_rank = tensorrt_llm.mpi_rank()
        device_id = self.runtime_rank % torch.cuda.device_count()
        torch.cuda.set_device(device_id)
        self.device = torch.cuda.current_device()
        self.skip_encoder = skip_encoder
        self.lora_task_uids = lora_task_uids
        # when enc-dec runs by itself, stream can be None and we create new stream here
        # when enc-dec has to run as a component in a bigger workflow (e.g., multimodal), earlier components in the workflow may have results in its stream, which we should pass that stream in to avoid unnecessary stream sync
        self.stream = stream
        if self.stream is None:
            self.stream = torch.cuda.Stream(self.device)
        torch.cuda.set_stream(self.stream)
        engine_dir = Path(engine_dir)

        def engine_setup(component):
            """Read one component's config, build its Mapping, and load its engine bytes."""
            # model config
            config_path = engine_dir / component / "config.json"
            logger.info(f"Using config path {config_path}")
            model_config, tp_size, pp_size, gpus_per_node, dtype = read_config(
                config_path
            )
            # MGMN config
            world_size = tp_size * pp_size
            runtime_rank = tensorrt_llm.mpi_rank()
            assert (
                runtime_rank < world_size
            ), "Runtime GPU rank exceeds MPI world size. Did you launch more MPI processes than required?"
            runtime_mapping = tensorrt_llm.Mapping(
                world_size,
                runtime_rank,
                tp_size=tp_size,
                pp_size=pp_size,
                gpus_per_node=gpus_per_node,
            )
            # load engine
            engine_fname = get_engine_name(runtime_rank)
            with open(engine_dir / component / engine_fname, "rb") as f:
                engine_buffer = f.read()
            return model_config, runtime_mapping, engine_buffer

        # Note: encoder and decoder doesn't necessarily have the same TP & PP config
        if not skip_encoder:
            (
                self.encoder_model_config,
                self.encoder_runtime_mapping,
                encoder_engine_buffer,
            ) = engine_setup(component="encoder")
            # for Pipeline Parallelism in encoder
            self.nccl_comm = torch.classes.trtllm.NcclCommunicatorOp(
                self.encoder_runtime_mapping.tp_size,
                self.encoder_runtime_mapping.pp_size,
                self.encoder_runtime_mapping.rank,
            )
            # session setup
            self.encoder_session = tensorrt_llm.runtime.Session.from_serialized_engine(
                encoder_engine_buffer
            )
            # encoder lora manager setup
            if self.encoder_model_config.lora_plugin:
                self.encoder_lora_manager = LoraManager()
                # TODO: this is only for bart
                self.encoder_lora_manager.load_from_hf(
                    model_dirs=lora_dir,
                    model_config=self.encoder_model_config,
                    runtime_mapping=self.encoder_runtime_mapping,
                    component="encoder",
                )
            else:
                self.encoder_lora_manager = None
        else:
            # Encoder skipped: keep attributes present but unset so later code
            # can test them uniformly.
            (
                self.encoder_model_config,
                self.encoder_runtime_mapping,
                encoder_engine_buffer,
            ) = (None, None, None)
            self.nccl_comm, self.encoder_session = None, None
        (
            self.decoder_model_config,
            self.decoder_runtime_mapping,
            decoder_engine_buffer,
        ) = engine_setup(component="decoder")
        self.decoder_session = tensorrt_llm.runtime.GenerationSession(
            self.decoder_model_config,
            decoder_engine_buffer,
            self.decoder_runtime_mapping,
            debug_mode=debug_mode,
        )
        # decoder lora manager setup
        if self.decoder_model_config.lora_plugin:
            self.decoder_lora_manager = LoraManager()
            # TODO: this is only for bart
            self.decoder_lora_manager.load_from_hf(
                model_dirs=lora_dir,
                model_config=self.decoder_model_config,
                runtime_mapping=self.decoder_runtime_mapping,
                component="decoder",
            )
        else:
            self.decoder_lora_manager = None
@classmethod
def from_engine(
cls,
engine_name,
engine_dir,
lora_dir=None,
lora_task_uids=None,
debug_mode=False,
skip_encoder=False,
stream=None,
):
return cls(
engine_name,
engine_dir,
lora_dir,
lora_task_uids,
debug_mode=debug_mode,
skip_encoder=skip_encoder,
stream=stream,
)
def process_input(
self, input_ids, remove_input_padding=False, pad_token_id=0, prompt_tasks=None
):
if remove_input_padding:
# in remove padding mode --> flatten input, calculate actual length and max length
# Note: 1st token should never be removed, even if it is pad_token_id
first_ids = input_ids[:, 0]
input_ids = input_ids[:, 1:]
input_lengths = 1 + (input_ids != pad_token_id).sum(dim=1).type(
torch.IntTensor
).to(
self.device
) # [batch_size]
new_ids = []
for i in range(len(input_ids)):
row = input_ids[i, :]
row = row[row != pad_token_id]
new_ids.append(
torch.cat((torch.IntTensor([first_ids[i]]).to(self.device), row))
)
input_ids = torch.cat(new_ids) # [num_tokens]
if prompt_tasks is not None:
prompt_tasks = prompt_tasks[: input_ids.shape[0]]
else:
# in padding mode --> keep input, just calculate actual length and max length
# Note: 1st token should always count, even if it is pad_token_id. e.g., decoder start id in enc-dec models could be a single pad_token_id, we should count
input_lengths = torch.tensor(
1
+ (input_ids[:, 1:] != pad_token_id)
.sum(dim=1)
.type(torch.IntTensor)
.to(self.device),
dtype=torch.int32,
device=self.device,
)
max_input_length = torch.max(input_lengths).item()
return input_ids, input_lengths, max_input_length, prompt_tasks
    def encoder_run(
        self,
        input_ids,
        input_lengths,
        max_input_length,
        position_ids=None,
        token_type_ids=None,
        debug_mode=False,
        prompt_embedding_table=None,
        prompt_tasks=None,
        prompt_vocab_size=None,
        attention_mask=None,
    ):
        """Run the encoder TRT engine and return its output hidden states.

        Builds the engine's input/output tensor dictionaries (handling
        pipeline-parallel placeholders, optional prompt-tuning tables and
        LoRA buffers), executes the session on `self.stream`, and broadcasts
        the final output across the encoder PP group so every rank returns it.
        """
        # each engine has hidden_dim/TP, don't forget to multiply TP
        hidden_size = (
            self.encoder_model_config.hidden_size * self.encoder_runtime_mapping.tp_size
        )
        if input_ids.dim() == 1:
            hidden_states_shape = (input_ids.shape[0], hidden_size)  # [num_tokens,D]
        else:
            hidden_states_shape = (
                input_ids.shape[0],
                input_ids.shape[1],
                hidden_size,
            )  # [BS,seqlen,D]
        hidden_states_dtype = lambda name: trt_dtype_to_torch(
            self.encoder_session.engine.get_tensor_dtype(name)
        )
        # input tensors. only first PP rank has id input, others are hidden_states input
        inputs = {}
        if self.encoder_runtime_mapping.is_first_pp_rank():
            inputs["input_ids"] = input_ids.contiguous()
            if self.encoder_model_config.has_position_embedding:
                if position_ids is None:
                    # Derive default position ids matching the padding mode.
                    if self.encoder_model_config.remove_input_padding:
                        position_ids = [
                            torch.arange(
                                sample_length,
                                dtype=torch.int32,
                                device=input_ids.device,
                            )
                            for sample_length in torch_to_numpy(input_lengths)
                        ]
                        position_ids = torch.cat(position_ids)
                    else:
                        bsz, seq_len = input_ids.shape[:2]
                        position_ids = torch.arange(
                            seq_len, dtype=torch.int32, device=input_ids.device
                        ).expand(bsz, -1)
                inputs["position_ids"] = position_ids.contiguous()
            if self.encoder_model_config.has_token_type_embedding:
                inputs["token_type_ids"] = token_type_ids.contiguous()
            if self.encoder_model_config.max_prompt_embedding_table_size > 0:
                # Prompt-tuning (p-tuning) table inputs.
                inputs["prompt_embedding_table"] = prompt_embedding_table.contiguous()
                inputs["tasks"] = prompt_tasks.contiguous()
                inputs["prompt_vocab_size"] = prompt_vocab_size.contiguous()
        else:
            # just need a placeholder, engine will call NCCL to recv and fill data from previous rank
            inputs["hidden_states_input"] = torch.empty(
                hidden_states_shape,
                dtype=hidden_states_dtype("hidden_states_input"),
                device=self.device,
            ).contiguous()
        if (
            attention_mask is not None
            and not self.encoder_model_config.gpt_attention_plugin
        ):
            inputs["attention_mask"] = attention_mask.contiguous()
        inputs["input_lengths"] = input_lengths
        # use shape info to pass max length info in remove padding mode
        inputs["max_input_length"] = torch.empty(
            (max_input_length,),
            dtype=hidden_states_dtype("max_input_length"),
            device=self.device,
        ).contiguous()
        batch_size = input_lengths.size(0)
        inputs["host_request_types"] = torch.IntTensor([0] * batch_size).to("cpu")
        if self.encoder_model_config.remove_input_padding:
            inputs["host_context_lengths"] = input_lengths.to("cpu")
        if (
            self.encoder_model_config.lora_plugin
            and self.encoder_lora_manager is not None
        ):
            inputs.update(
                self.encoder_lora_manager.input_buffers(
                    self.lora_task_uids,
                    self.encoder_runtime_mapping,
                    self.encoder_model_config.num_layers,
                )
            )
        # Note: runtime.Session's run() method will set input/output tensor address, here we only need to provide tensor shape
        self.encoder_session.set_shapes(inputs)
        # output tensors. only last PP rank final encoder output, others are intermediate hidden_states output. Need broadcast later
        outputs = {}
        if self.encoder_runtime_mapping.is_last_pp_rank():
            outputs["encoder_output"] = torch.empty(
                hidden_states_shape,
                dtype=hidden_states_dtype("encoder_output"),
                device=self.device,
            ).contiguous()
        else:
            outputs["hidden_states_output"] = torch.empty(
                hidden_states_shape,
                dtype=hidden_states_dtype("hidden_states_output"),
                device=self.device,
            ).contiguous()
        # -------------------------------------------
        if debug_mode:
            engine = self.encoder_session.engine
            context = self.encoder_session.context
            # setup debugging buffer for the encoder
            for i in range(self.encoder_session.engine.num_io_tensors):
                name = engine.get_tensor_name(i)
                if (
                    engine.get_tensor_mode(name) == trt.TensorIOMode.OUTPUT
                    and name not in outputs.keys()
                ):
                    dtype = engine.get_tensor_dtype(name)
                    shape = context.get_tensor_shape(name)
                    outputs[name] = torch.zeros(
                        tuple(shape),
                        dtype=trt_dtype_to_torch(dtype),
                        device=self.device,
                    )
                    context.set_tensor_address(name, outputs[name].data_ptr())
        # -------------------------------------------
        # TRT session run
        # Note: need cuda stream ID, not a torch Stream
        ok = self.encoder_session.run(inputs, outputs, self.stream.cuda_stream)
        assert ok, "Runtime execution failed"
        self.stream.synchronize()
        # Tensor Parallelism is handled by model/engine definition
        # But we need to broadcast among PP group at the end of encoder's Pipeline Parallelism
        # After this, all ranks should recv the encoder output, and world might be re-configured using decoder's TP-PP config
        def pp_communicate_encoder_output(encoder_output):
            # Last PP rank sends to every peer; the others receive in place.
            if self.encoder_runtime_mapping.is_last_pp_rank():
                for pp_rank in self.encoder_runtime_mapping.pp_group:
                    if pp_rank != self.encoder_runtime_mapping.rank:
                        self.nccl_comm.send(encoder_output, pp_rank)
                return encoder_output
            else:
                self.nccl_comm.recv(
                    encoder_output, self.encoder_runtime_mapping.pp_group[-1]
                )
                return encoder_output
        if self.encoder_runtime_mapping.has_pp():
            # use hidden_states output buffer to receive output as the shapes are same
            encoder_output_buf = (
                outputs["encoder_output"]
                if self.encoder_runtime_mapping.is_last_pp_rank()
                else outputs["hidden_states_output"]
            )
            encoder_output = pp_communicate_encoder_output(encoder_output_buf)
        else:
            encoder_output = outputs["encoder_output"]
        # -------------------------------------------
        if (
            debug_mode and self.encoder_runtime_mapping.tp_rank == 0
        ):  # only tp_rank 0 print encoder output
            torch.cuda.synchronize()
            # use print_tensor() to print the tensors registered in the encoder network
            print("--------------------------------------")
            print("Debug output for Encoder")
            print("--------------------------------------")
            print("Registered output tensors are: ", outputs.keys())
            for k, v in outputs.items():
                print_tensor(k, v, num_elements=30)
            print_tensor("encoder_output", encoder_output)
            print("--------------------------------------")
        # -------------------------------------------
        return encoder_output
    def generate(
        self,
        encoder_input_ids,
        decoder_input_ids,
        max_new_tokens,
        num_beams=1,
        pad_token_id=None,
        eos_token_id=None,
        bos_token_id=None,
        debug_mode=False,
        return_dict=False,
        prompt_embedding_table=None,
        prompt_tasks=None,
        prompt_vocab_size=None,
        attention_mask=None,
        time_encoder=False,
        return_encoder_output=False,
    ):
        """Run the full encoder-decoder pipeline: encoder pass (unless skipped),
        then autoregressive decoder generation.

        Args:
            encoder_input_ids: token ids for the encoder.
            decoder_input_ids: start ids for the decoder.
            max_new_tokens: decoder generation budget.
            num_beams: beam width (1 means greedy/sampling).
            pad_token_id / eos_token_id / bos_token_id: special token ids.
            debug_mode: forwarded to the encoder run for tensor dumping.
            return_dict: if True, decoder returns a dict (incl. log probs).
            prompt_embedding_table: prompt-tuning table; when the encoder is
                skipped it is used directly as the encoder output.
            prompt_tasks / prompt_vocab_size: prompt-tuning metadata.
            attention_mask: encoder-side mask; also reused below to build the
                decoder's cross-attention mask.
            time_encoder: print encoder wall-clock time.
            return_encoder_output: additionally return the encoder output.

        Returns:
            Decoder output (ids tensor, or dict when return_dict=True);
            a (output, encoder_output) tuple when return_encoder_output=True.
        """
        ## ensure all externally provided tensors are on the correct device.
        encoder_input_ids = encoder_input_ids.to(self.device)
        decoder_input_ids = decoder_input_ids.to(self.device)
        if attention_mask is not None:
            attention_mask = torch.tensor(
                attention_mask, dtype=torch.int32, device=self.device
            )
        ## encoder run
        # When the encoder is skipped there is no encoder config; fall back to
        # the decoder's padding-removal setting for input processing.
        encoder_remove_input_padding = (
            self.encoder_model_config.remove_input_padding
            if self.encoder_model_config
            else self.decoder_model_config.remove_input_padding
        )
        (
            encoder_input_ids,
            encoder_input_lengths,
            encoder_max_input_length,
            prompt_tasks,
        ) = self.process_input(
            encoder_input_ids, encoder_remove_input_padding, pad_token_id, prompt_tasks
        )
        if not self.skip_encoder:
            logger.info(f"Rank {self.runtime_rank} Running encoder engine ...")
            if time_encoder:
                tik = time.time()
            encoder_output = self.encoder_run(
                encoder_input_ids,
                encoder_input_lengths,
                encoder_max_input_length,
                debug_mode=debug_mode,
                prompt_embedding_table=prompt_embedding_table,
                prompt_tasks=prompt_tasks,
                prompt_vocab_size=prompt_vocab_size,
                attention_mask=attention_mask,
            )
            if time_encoder:
                tok = time.time()
                print(f"TRT-LLM Encoder time {(tok-tik)*1000}ms")
        else:
            # Skip-encoder path (e.g. Nougat/pix2struct): the prompt table IS
            # the encoder output.
            encoder_output = prompt_embedding_table
            # Re-add a batch dim when inputs are batched (padded) tensors.
            if encoder_input_ids.dim() > 1:
                encoder_output = encoder_output.unsqueeze(0)
        ## decoder run
        logger.info(f"Rank {self.runtime_rank} Running decoder engine ...")
        decoder_input_ids, decoder_input_lengths, decoder_max_input_length, _ = (
            self.process_input(
                decoder_input_ids,
                self.decoder_model_config.remove_input_padding,
                pad_token_id,
            )
        )
        # `cross_attention_mask` in context phase [batch_size, query_len, encoder_input_len]
        # where query_len happens to be 1 in current cases, but not necessarily always, and
        # `cross_attention_mask` in generation phase [batch_size, 1, encoder_input_len] where
        # the query_len is always 1 since we have kv cache.
        cross_attention_mask = None
        if attention_mask is not None:
            cross_attention_mask = torch.tensor(
                attention_mask, dtype=torch.int32, device=self.device
            ).reshape(attention_mask.shape[0], 1, attention_mask.shape[1])
        # generation config
        sampling_config = SamplingConfig(
            end_id=eos_token_id,
            pad_id=pad_token_id,
            num_beams=num_beams,
            min_length=1,
            return_dict=return_dict,
        )
        # Log-prob outputs are only materialized when a dict is requested.
        sampling_config.update(
            output_cum_log_probs=return_dict, output_log_probs=return_dict
        )
        # decoder autoregressive generation
        self.decoder_session.setup(
            decoder_input_lengths.size(0),
            decoder_max_input_length,
            max_new_tokens,
            num_beams,
            max_attention_window_size=None,
            encoder_max_input_length=encoder_max_input_length,
            lora_manager=self.decoder_lora_manager,
            lora_uids=self.lora_task_uids,
        )
        output = self.decoder_session.decode(
            decoder_input_ids,
            decoder_input_lengths,
            sampling_config,
            encoder_output=encoder_output,
            encoder_input_lengths=encoder_input_lengths,
            return_dict=return_dict,
            cross_attention_mask=cross_attention_mask,
        )
        if return_encoder_output:
            return output, encoder_output
        return output
def test_fairseq_models(args):
    """Correctness check for the FairSeq NMT (WMT) encoder-decoder engine.

    Compares TRT-LLM generation against hard-coded reference ids captured
    from FairSeq offline (see note below), so CI needs no FairSeq install.
    """
    ## Note: NMT is the only FairSeq model. Adding FairSeq dependency is too heavy for the CI workflow, hence we used fixed input/output ids for correctness check and leave FairSeq code in comments. Users can follow Encoder-Decoder's README to install FairSeq and test locally.
    """
    from fairseq.models.transformer import TransformerModel
    fairseq_model = TransformerModel.from_pretrained(model_name_or_path=args.model_name, data_name_or_path=args.model_name, bpe='subword_nmt', tokenizer='moses').cuda()
    input_text = "Good Morning! How are you doing today?"
    input_ids = fairseq_model.encode(input_text)
    tik = time.time()
    # Note: FairSeq sampling=True results are not deterministic, disable during accuracy check
    fairseq_output_ids = fairseq_model.generate(input_ids, beam=1, sampling=False) #
    tik = time.time()
    fairseq_output_ids = fairseq_output_ids[0]['tokens']
    fairseq_output_text = fairseq_model.decode(fairseq_output_ids)
    print("--------------------------------------")
    print("input text: ", input_text)
    print("input ids: ", input_ids) # [9938, 5384, 9328, 812, 3619, 53, 181, 3829, 1735, 171, 2]
    print("fairseq_output ids: ", fairseq_output_ids) # [9804, 391, 4, 4625, 167, 25, 1003, 5123, 17, 167, 1466, 1234, 171, 2]
    print("fairseq_output text: ", fairseq_output_text) # "Bonjour, Comment vous en tirez-vous aujourd'hui ?"
    print(f"FairSeq E2E time {(tok-tik)*1000}ms")
    print("--------------------------------------")
    """
    max_new_tokens = args.max_new_tokens
    # FairSeq WMT convention: BOS and EOS share id 2, PAD is 0.
    bos_token_id = 2
    pad_token_id = 0
    eos_token_id = 2
    decoder_start_token_id = bos_token_id
    # Fixed reference ids captured from FairSeq (see commented code above).
    input_ids = torch.tensor([9938, 5384, 9328, 812, 3619, 53, 181, 3829, 1735, 171, 2])
    fairseq_output_ids = torch.tensor(
        [9804, 391, 4, 4625, 167, 25, 1003, 5123, 17, 167, 1466, 1234, 171, 2]
    )
    # Add a batch dim and cast to int32, as the runtime expects.
    input_ids = torch.tensor([input_ids.tolist()]).type(torch.IntTensor).cuda()
    decoder_input_ids = torch.IntTensor([[decoder_start_token_id]]).cuda()
    decoder_input_ids = decoder_input_ids.repeat((input_ids.shape[0], 1))
    tllm_model = TRTLLMEncDecModel.from_engine(
        args.engine_name, args.engine_dir, debug_mode=args.debug_mode
    )
    inference_dtype = tllm_model.encoder_model_config.dtype
    return_dict = False  # when set return_dict=True, get outputs by key
    tik = time.time()
    tllm_output = tllm_model.generate(
        encoder_input_ids=input_ids,
        decoder_input_ids=decoder_input_ids,
        max_new_tokens=max_new_tokens,
        num_beams=args.num_beams,
        bos_token_id=bos_token_id,
        pad_token_id=pad_token_id,
        eos_token_id=eos_token_id,
        debug_mode=args.debug_mode,
    )
    tok = time.time()
    torch.cuda.synchronize()
    if return_dict:
        tllm_output_ids = tllm_output["output_ids"]
    else:
        tllm_output_ids = tllm_output
    # Only rank 0 validates and prints.
    if tensorrt_llm.mpi_rank() == 0:
        # Take beam 0 and strip EOS before comparing against the reference.
        output_ids = tllm_output_ids[:, 0, :]
        output_ids = output_ids[output_ids != eos_token_id]
        fairseq_output_ids = fairseq_output_ids[fairseq_output_ids != eos_token_id]
        print("--------------------------------------")
        print("TRT-LLM output_ids: ", output_ids)
        print(f"TRT-LLM E2E time {(tok-tik)*1000}ms")
        print("Precision:", inference_dtype)
        print("--------------------------------------")
        assert (
            output_ids.tolist() == fairseq_output_ids.tolist()
        ), f"TRT-LLM output ids {output_ids} does not match Fairseq ids {fairseq_output_ids}"
if __name__ == "__main__":
    import os

    # Avoid tokenizers' fork warning/deadlock when combined with MPI workers.
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    args = parse_arguments()
    logger.set_level(args.log_level)
    # FairSeq NMT test logic is different from HuggingFace models
    if "wmt" in args.model_name:
        test_fairseq_models(args)
        exit()
    # Toggle between a single padded input (False) and a batch exercising the
    # remove-input-padding path (True).
    test_remove_padding = True
    if not test_remove_padding:
        if "t5" in args.model_name:
            input_text = "translate English to German: The house is wonderful, radiating timeless charm and offering a warm, inviting interior with beautiful details and a serene backyard."
        elif "bart" in args.model_name:
            input_text = "The tower is 324 metres (1,063 ft) tall, about the same height as an 81-storey building, and the tallest structure in Paris. Its base is square, measuring 125 metres (410 ft) on each side. During its construction, the Eiffel Tower surpassed the Washington Monument to become the tallest man-made structure in the world, a title it held for 41 years until the Chrysler Building in New York City was finished in 1930. It was the first structure to reach a height of 300 metres. Due to the addition of a broadcasting aerial at the top of the tower in 1957, it is now taller than the Chrysler Building by 5.2 metres (17 ft). Excluding transmitters, the Eiffel Tower is the second tallest free-standing structure in France after the Millau Viaduct."
        else:
            raise RuntimeError("Unsupported model type!")
    else:
        input_text = [
            "translate English to German: The house is wonderful.",
            "summarize: I am a high-performance inference optimizer and runtime.",
            "During its construction, the Eiffel Tower surpassed the Washington Monument to become the tallest man-made structure in the world",
        ]
    # TRT-LLM runtime
    tllm_model = TRTLLMEncDecModel.from_engine(
        args.engine_name,
        args.engine_dir,
        args.lora_dir,
        args.lora_task_uids,
        debug_mode=args.debug_mode,
    )
    inference_dtype = tllm_model.encoder_model_config.dtype
    # FP32 engines get one extra long input (unless ByT5, whose byte-level
    # tokenization would overflow the default max_input_len).
    if inference_dtype == "float32":
        if "byt5" in args.model_name:
            print(
                "ByT5 models tokenize input by bytes instead of words, causing the input text in this example to be longer than the default value during build stage. Please adjust --max_input_len during trtllm-build to select the right length limit for ByT5 models."
            )
        else:
            input_text.append(
                'Summarize this article in one sentence.\n\nKristine Watts (Molie Weeks) is broken apart, missing her lover; she is not able to overcome her love for him that is lost in the past. She hires a stranger (Douglas Davis) and gives a list of her mistakes to him with things to fix. But time is irreversible and sometimes the cure for the pain is a tragic end.\n\nThe first point that impresses in "The Cure" is the stylish cinematography that alternates black and white with color. The concise and sharp screenplay is capable to develop a tragic and bleak tale of love with an unexpected plot point in the very end in less than eight minutes. The soundtrack is beautiful but the volume is a little loud and associated to the fact that English is not my native language, in some moments I needed to repeat some words whispered by the narrator. The unknown lead actress has magnificent performance and is extremely gorgeous. I hope to have a chance to see her again on the screen. Last but not the least, the debut of the director and writer Ryan Jafri could not be better. My vote is nine.\n\nTitle (Brazil): Not Available',
            )
    tokenizer = AutoTokenizer.from_pretrained(
        args.model_name
    )  # TODO: use model path instead
    tokenized_inputs = tokenizer(input_text, return_tensors="pt", padding=True)
    max_new_tokens = args.max_new_tokens
    input_ids = tokenized_inputs.input_ids.type(torch.IntTensor).to(
        "cuda"
    )  # [batch_size, padded_length]
    # by default int64, must cast to int32! otherwise C++ kernel will interpret as [a, 0, b, 0, c, 0, ...]
CPP_RESULTS_SAVED_DIR = "cpp/tests/resources/data/enc_dec"
if tensorrt_llm.mpi_rank() == 0:
if args.output_encoder_npy:
if not os.path.isdir(CPP_RESULTS_SAVED_DIR):
os.mkdir(os.path.join(CPP_RESULTS_SAVED_DIR))
np_input_ids = tokenized_inputs.input_ids.type(torch.IntTensor)
np_input_ids = np_input_ids.numpy()
np.save(
os.path.join(CPP_RESULTS_SAVED_DIR, "enc_input_ids.npy"), np_input_ids
)
input_lengths = (
tokenized_inputs.attention_mask.sum(dim=1).type(torch.IntTensor).numpy()
)
np.save(
os.path.join(CPP_RESULTS_SAVED_DIR, "enc_input_lengths.npy"),
input_lengths,
)
print("--------------------------------------")
print(
f"BOS={tokenizer.bos_token_id}, PAD={tokenizer.pad_token_id}, EOS={tokenizer.eos_token_id}"
)
print("input text: ", input_text)
print("input ids: ", input_ids)
print("input lengths: ", tokenized_inputs.attention_mask.sum(dim=1))
print("--------------------------------------")
    model_config = AutoConfig.from_pretrained(args.model_name)
    # start_id for decoder (could add more input_ids as forced_decoder_ids)
    decoder_input_ids = torch.IntTensor([[model_config.decoder_start_token_id]]).to(
        "cuda"
    )
    decoder_input_ids = decoder_input_ids.repeat((input_ids.shape[0], 1))
    # simple comparison with HF on FP32
    if args.compare_hf_fp32:
        if tensorrt_llm.mpi_rank() == 0:
            hf_model = (
                AutoModelForSeq2SeqLM.from_pretrained(
                    args.model_name,  # TODO: use model path instead
                    # torch_dtype=torch.float16 if '16' in dtype else torch.float32, # TODO: use matched torch dtype
                )
                .to("cuda")
                .eval()
            )  # TODO: create config model path instead
            assert type(hf_model) in (
                T5ForConditionalGeneration,
                BartForConditionalGeneration,
                MBartForConditionalGeneration,
            ), "Unsupported model!"
            if args.lora_dir is not None:
                assert (
                    len(args.lora_dir) >= 1
                ), "At least one lora model dir is required"
                # we can only test single lora with HF
                from peft import PeftModel

                hf_model = (
                    PeftModel.from_pretrained(hf_model, args.lora_dir[0])
                    .to("cuda")
                    .eval()
                )
            tik = time.time()
            hf_gen_output = hf_model.generate(
                input_ids=input_ids,
                decoder_input_ids=decoder_input_ids,
                max_new_tokens=max_new_tokens,
                num_beams=args.num_beams,
                bos_token_id=tokenizer.bos_token_id,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
                use_cache=True,
                # control logits processors
                no_repeat_ngram_size=0,  # disable no repeat post-processor
                forced_bos_token_id=None,  # disable forced first/last token
                forced_eos_token_id=None,
                min_length=0,
                # for debug
                output_scores=True,
                output_hidden_states=True,
                return_dict_in_generate=True,
            )
            # get hf output scores
            hf_output_ids = hf_gen_output.sequences
            # convert to logits
            torch.cuda.synchronize()
            tok = time.time()
            output_ids = hf_output_ids.squeeze(dim=1)
            hf_output_text = tokenizer.batch_decode(
                output_ids, skip_special_tokens=True
            )
            # Generated length = non-EOS tokens minus the decoder prompt length.
            decoder_input_lengths = (decoder_input_ids != tokenizer.pad_token_id).sum(
                dim=1
            )
            output_gen_lengths = (output_ids != tokenizer.eos_token_id).sum(
                dim=1
            ) - decoder_input_lengths
            print("--------------------------------------")
            print("HF output_ids: ", output_ids)
            print("HF output text: ", hf_output_text)
            print("HF output generated lengths: ", output_gen_lengths)
            print(f"HF E2E time {(tok-tik)*1000}ms")
            print("--------------------------------------")
    return_dict = False  # when set return_dict=True, get outputs by key
    tik = time.time()
    tllm_output = tllm_model.generate(
        encoder_input_ids=input_ids,
        decoder_input_ids=decoder_input_ids,
        max_new_tokens=max_new_tokens,
        num_beams=args.num_beams,
        bos_token_id=tokenizer.bos_token_id,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        debug_mode=args.debug_mode,
        return_dict=return_dict,
        attention_mask=tokenized_inputs.attention_mask,
        time_encoder=True,
        return_encoder_output=args.output_encoder_npy and tensorrt_llm.mpi_rank() == 0,
    )
    tok = time.time()
    # Save the raw encoder output for the C++ runtime tests (rank 0 only).
    if args.output_encoder_npy and tensorrt_llm.mpi_rank() == 0:
        tllm_output, encoder_output = tllm_output
        encoder_output = encoder_output.cpu().numpy()
        np.save(
            os.path.join(CPP_RESULTS_SAVED_DIR, "encoder_output.npy"), encoder_output
        )
    if return_dict:
        tllm_output_ids = tllm_output["output_ids"]
    else:
        tllm_output_ids = tllm_output
    if tensorrt_llm.mpi_rank() == 0:
        # Beam 0 only for reporting/accuracy.
        output_ids = tllm_output_ids[:, 0, :]
        output_text = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
        decoder_input_lengths = (decoder_input_ids != tokenizer.pad_token_id).sum(dim=1)
        output_gen_lengths = (output_ids != tokenizer.eos_token_id).sum(
            dim=1
        ) - decoder_input_lengths
        print("--------------------------------------")
        print("TRT-LLM output_ids: ", output_ids)
        print("TRT-LLM output text: ", output_text)
        print("TRT-LLM output generated lengths: ", output_gen_lengths)
        print(f"TRT-LLM E2E time {(tok-tik)*1000}ms")
        print("Precision:", inference_dtype)
        print("--------------------------------------")
        # simple accuracy check
        if args.compare_hf_fp32:
            from difflib import SequenceMatcher

            # Literal character-level similarity between TRT-LLM and HF text.
            match_rate = SequenceMatcher(
                None, "\n".join(output_text), "\n".join(hf_output_text)
            ).ratio()
            print(output_text)
            print(hf_output_text)
            # Looser threshold for reduced-precision engines vs HF float32.
            if inference_dtype != "float32":
                print("")
                print(
                    f"[CAVEAT] Comparing TRT-LLM {inference_dtype} results with HF float32 results. Close match are not expected!"
                )
                assert match_rate > 0.8, f"Incorrect results! Match rate {match_rate}"
            else:
                assert match_rate > 0.95, f"Incorrect results! Match rate {match_rate}"
            print(
                f"TRT-LLM results match HF FP32 results with literal match rate {match_rate}"
            )
import argparse
import json
import os
import sys
from pathlib import Path
from typing import Tuple, List, Union
from torchvision.transforms import InterpolationMode
from torchvision import transforms
import requests
# isort: off
import torch
import tensorrt as trt
# isort: on
from huggingface_hub import hf_hub_download
from PIL import Image
from torchvision import transforms
from transformers import (
AutoConfig,
AutoProcessor,
AutoTokenizer,
Blip2Processor,
NougatProcessor,
NougatTokenizerFast,
)
import tensorrt_llm
import tensorrt_llm.profiler as profiler
from tensorrt_llm import logger
from tensorrt_llm._utils import str_dtype_to_trt
from tensorrt_llm.runtime import ModelRunner, Session, TensorInfo
import pandas as pd
from run import TRTLLMEncDecModel
import tqdm
class Preprocss:
    """CLIP-style image preprocessing: resize, tensor conversion, normalize."""

    # Normalization constants (identical values to the original inline tuples).
    _MEAN = (0.48145466, 0.4578275, 0.40821073)
    _STD = (0.26862954, 0.26130258, 0.27577711)

    def __init__(
        self,
        image_size: int,
    ):
        self.image_transform = transforms.Compose(
            [
                transforms.Resize(
                    (image_size, image_size),
                    interpolation=InterpolationMode.BICUBIC,
                ),
                transforms.ToTensor(),
                transforms.Normalize(mean=self._MEAN, std=self._STD),
            ]
        )

    def encode(self, image_list):
        """Convert a list of PIL images into one stacked, normalized tensor."""
        tensors = [self.image_transform(img.convert("RGB")) for img in image_list]
        return torch.stack(tensors, dim=0)
# Module-level preprocessor; 336 is the square input resolution fed to the visual encoder.
image_pre_obj = Preprocss(336)
def parse_arguments():
    """Build and parse the CLI options for the multimodal demo runner."""
    p = argparse.ArgumentParser()
    # Generation and logging controls.
    p.add_argument("--max_new_tokens", type=int, default=30)
    p.add_argument("--batch_size", type=int, default=1)
    p.add_argument("--log_level", type=str, default="info")
    # Engine / model locations.
    p.add_argument(
        "--visual_engine_dir",
        type=str,
        default=None,
        help="Directory containing visual TRT engines",
    )
    p.add_argument(
        "--llm_engine_dir",
        type=str,
        default=None,
        help="Directory containing TRT-LLM engines",
    )
    p.add_argument(
        "--hf_model_dir", type=str, default=None, help="Directory containing tokenizer"
    )
    # Inputs and outputs.
    p.add_argument("--content", type=str, default=None)
    p.add_argument("--image_file", type=str, default="images/demo1.jpeg")
    p.add_argument("--input_file", type=str, default=None)  # 'images/demo.csv'
    p.add_argument("--output_file", type=str, default=None)  # 'images/demo_res.csv'
    p.add_argument(
        "--mode",
        choices=["caption_zh", "caption_en", "insert_content"],
        default="caption_zh",
    )
    # Sampling parameters.
    p.add_argument(
        "--num_beams", type=int, help="Use beam search if num_beams >1", default=1
    )
    p.add_argument("--top_k", type=int, default=1)
    p.add_argument("--top_p", type=float, default=0.0)
    p.add_argument("--temperature", type=float, default=1.0)
    p.add_argument("--repetition_penalty", type=float, default=1.0)
    # Diagnostics.
    p.add_argument(
        "--run_profiling",
        action="store_true",
        help="Profile runtime over several iterations",
    )
    p.add_argument(
        "--check_accuracy", action="store_true", help="Check correctness of text output"
    )
    return p.parse_args()
def trt_dtype_to_torch(dtype):
    """Map a TensorRT dtype to the corresponding torch dtype.

    Raises:
        TypeError: if the TRT dtype has no torch equivalent here.
    """
    mapping = {
        trt.float16: torch.float16,
        trt.float32: torch.float32,
        trt.int32: torch.int32,
        trt.bfloat16: torch.bfloat16,
    }
    try:
        return mapping[dtype]
    except KeyError:
        raise TypeError("%s is not supported" % dtype)
class MultimodalModelRunner:
    def __init__(self, args):
        """Bind this rank to a GPU, read the visual engine config, and build
        the vision encoder, tokenizer and LLM runtime."""
        self.args = args
        self.runtime_rank = tensorrt_llm.mpi_rank()
        # Round-robin rank -> local GPU mapping.
        device_id = self.runtime_rank % torch.cuda.device_count()
        torch.cuda.set_device(device_id)
        self.device = "cuda:%d" % (device_id)
        # Dedicated stream shared by the vision and LLM sessions.
        self.stream = torch.cuda.Stream(torch.cuda.current_device())
        torch.cuda.set_stream(self.stream)
        # parse model type from visual engine config
        with open(os.path.join(self.args.visual_engine_dir, "config.json"), "r") as f:
            config = json.load(f)
        self.model_type = config["builder_config"]["model_type"]
        self.vision_precision = config["builder_config"]["precision"]
        # pix2struct's vision engine always runs in fp16 regardless of config.
        if self.model_type == "pix2struct":
            self.vision_precision = "float16"
        self.decoder_llm = not (
            "t5" in self.model_type or self.model_type in ["nougat", "pix2struct"]
        )  # BLIP2-T5, pix2struct and Nougat are using encoder-decoder models as LLMs
        self.profiling_iterations = 20
        self.init_image_encoder()
        self.init_tokenizer()
        self.init_llm()
    def init_tokenizer(self):
        """Create a tokenizer matching the model family.

        nougat uses NougatTokenizerFast; neva wraps a raw SentencePiece model
        behind an HF-like adapter; everything else uses AutoTokenizer.
        """
        if self.model_type == "nougat":
            self.tokenizer = NougatTokenizerFast.from_pretrained(self.args.hf_model_dir)
        elif self.model_type == "neva":
            from sentencepiece import SentencePieceProcessor

            sp = SentencePieceProcessor(
                os.path.join(self.args.hf_model_dir, "tokenizer.model")
            )

            # Minimal stand-in for HF's BatchEncoding: exposes .input_ids and
            # item access by name.
            class return_obj:

                def __init__(self, input_ids):
                    self.input_ids = input_ids

                def __getitem__(self, name):
                    # NOTE(review): `name in "input_ids"` is a substring test,
                    # so any substring of "input_ids" matches — presumably only
                    # ever called with "input_ids"; verify against callers.
                    if name in "input_ids":
                        return self.input_ids
                    else:
                        raise AttributeError(f"'return_obj' has no item '{name}'")

            # sentencepiece does not follow the same interface as HF
            class HFTokenizerInterface:

                def encode(self, x, return_tensors=None, **kwargs):
                    out = sp.encode(x)
                    if return_tensors == "pt":
                        out = torch.tensor(out)
                    return return_obj(out)

                def __call__(self, x, return_tensors=None, **kwargs):
                    return self.encode(x, return_tensors, **kwargs)

                def decode(self, x, **kwargs):
                    return sp.decode(x.tolist())

                def batch_decode(self, x, **kwargs):
                    return self.decode(x, **kwargs)

            self.tokenizer = HFTokenizerInterface()
            self.tokenizer.eos_token_id = sp.eos_id()
            self.tokenizer.bos_token_id = sp.bos_id()
            self.tokenizer.pad_token_id = sp.pad_id()
        else:
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.args.hf_model_dir, use_fast=False, use_legacy=False
            )
        self.tokenizer.padding_side = "right"
    def init_image_encoder(self):
        """Deserialize the vision encoder TRT engine into a runtime Session."""
        vision_encoder_path = os.path.join(
            self.args.visual_engine_dir, "visual_encoder.engine"
        )
        logger.info(f"Loading engine from {vision_encoder_path}")
        with open(vision_encoder_path, "rb") as f:
            engine_buffer = f.read()
        logger.info(f"Creating session from engine {vision_encoder_path}")
        self.visual_encoder_session = Session.from_serialized_engine(engine_buffer)
def init_llm(self):
if self.decoder_llm:
self.model = ModelRunner.from_dir(
self.args.llm_engine_dir,
rank=tensorrt_llm.mpi_rank(),
debug_mode=False,
stream=self.stream,
)
self.model_config = self.model.session._model_config
self.runtime_mapping = self.model.session.mapping
else:
self.model = TRTLLMEncDecModel.from_engine(
os.path.basename(self.args.hf_model_dir),
self.args.llm_engine_dir,
skip_encoder=self.model_type in ["nougat", "pix2struct"],
debug_mode=False,
stream=self.stream,
)
if self.model_type in ["nougat", "pix2struct"]:
self.model_config = self.model.decoder_model_config
self.runtime_mapping = self.model.decoder_runtime_mapping
else:
self.model_config = self.model.encoder_model_config
self.runtime_mapping = self.model.encoder_runtime_mapping
def preprocess(self, warmup, pre_prompt, post_prompt, image, attention_mask):
if self.model_type == "kosmos-2":
input_ids = image["input_ids"].clone()
image_mask = image["image_embeds_position_mask"]
image = image["pixel_values"]
input_ids += image_mask * (self.model_config.vocab_size - 4)
input_ids = input_ids.expand(self.args.batch_size, *input_ids.shape[1:])
length = input_ids.shape[1]
if not warmup:
profiler.start("Vision")
visual_features, visual_atts = self.get_visual_features(
(
torch.stack(image["image_patches"], dim=0)
if self.model_type == "fuyu"
else image
),
attention_mask,
)
if not warmup:
profiler.stop("Vision")
if self.model_type == "fuyu":
visual_features = visual_features.squeeze()
input_ids = image["input_ids"].to(torch.int32)
image_patches_indices = image["image_patches_indices"].to(torch.int32)
input_ids = input_ids.expand(self.args.batch_size, *input_ids.shape[1:])
image_patches_indices = image_patches_indices.expand(
self.args.batch_size, *image_patches_indices.shape[1:]
)
input_ids = self.ptuning_setup_fuyu(input_ids, image_patches_indices)
input_ids = torch.stack(input_ids, dim=0).to("cpu")
length = input_ids.shape[1]
elif self.model_type == "kosmos-2":
visual_features = visual_features.squeeze()
else:
pre_input_ids = self.tokenizer(
pre_prompt, return_tensors="pt", padding=True
).input_ids
if post_prompt[0] is not None:
post_input_ids = self.tokenizer(
post_prompt, return_tensors="pt", padding=True
).input_ids
length = (
pre_input_ids.shape[1]
+ post_input_ids.shape[1]
+ visual_atts.shape[1]
)
else:
post_input_ids = None
length = pre_input_ids.shape[1] + visual_atts.shape[1]
input_lengths = torch.IntTensor([length] * args.batch_size).to(torch.int32)
if self.model_type in ["fuyu", "kosmos-2"]:
return input_ids, input_lengths, [visual_features], visual_features
input_ids, ptuning_args = self.setup_fake_prompts(
visual_features, pre_input_ids, post_input_ids, input_lengths
)
return input_ids, input_lengths, ptuning_args, visual_features
    def generate(
        self,
        pre_prompt,
        post_prompt,
        image,
        decoder_input_ids,
        max_new_tokens,
        attention_mask,
        warmup,
    ):
        """Run the full multimodal pipeline: vision preprocess + LLM generate.

        Returns a per-batch list of per-beam decoded strings on rank 0, or
        None on other ranks / during warmup.
        """
        if not warmup:
            profiler.start("Generate")
        input_ids, input_lengths, ptuning_args, visual_features = self.preprocess(
            warmup, pre_prompt, post_prompt, image, attention_mask
        )
        # Warmup only exercises the vision path; skip generation entirely.
        if warmup:
            return None
        profiler.start("LLM")
        if self.decoder_llm:
            end_id = self.tokenizer.eos_token_id
            # NOTE(review): requires model_type to contain both substrings
            # (e.g. a combined "blip2-opt" identifier) — confirm config values.
            if "opt" in self.model_type and "blip2" in self.model_type:
                # For BLIP2-OPT, model outputs a "\n" at the end.
                # we avoid it by using newline as the end token
                end_id = self.tokenizer.encode("\n", add_special_tokens=False)[0]
            # Add the leading batch dim the prompt-table API expects.
            ptuning_args[0] = torch.stack([ptuning_args[0]])
            output_ids = self.model.generate(
                input_ids,
                sampling_config=None,
                prompt_table=ptuning_args[0],
                max_new_tokens=max_new_tokens,
                end_id=end_id,
                pad_id=(
                    self.tokenizer.pad_token_id
                    if self.tokenizer.pad_token_id is not None
                    else self.tokenizer.all_special_ids[0]
                ),
                top_k=self.args.top_k,
                top_p=self.args.top_p,
                temperature=self.args.temperature,
                repetition_penalty=self.args.repetition_penalty,
                num_beams=self.args.num_beams,
                output_sequence_lengths=False,
                return_dict=False,
            )
        else:
            if self.model_type in ["nougat", "pix2struct"]:
                # Trim encoder input_ids to match visual features shape
                ids_shape = (self.args.batch_size, visual_features.shape[1])
                if self.model_type == "nougat":
                    input_ids = torch.zeros(ids_shape, dtype=torch.int32)
                elif self.model_type == "pix2struct":
                    input_ids = torch.ones(ids_shape, dtype=torch.int32)
            output_ids = self.model.generate(
                input_ids,
                decoder_input_ids,
                max_new_tokens,
                num_beams=self.args.num_beams,
                bos_token_id=self.tokenizer.bos_token_id,
                pad_token_id=self.tokenizer.pad_token_id,
                eos_token_id=self.tokenizer.eos_token_id,
                debug_mode=False,
                prompt_embedding_table=ptuning_args[0],
                prompt_tasks=ptuning_args[1],
                prompt_vocab_size=ptuning_args[2],
                attention_mask=attention_mask,
            )
            # Reset input_lengths to match decoder_input_ids
            input_lengths = torch.ones(input_lengths.shape, dtype=input_lengths.dtype)
        profiler.stop("LLM")
        if tensorrt_llm.mpi_rank() == 0:
            # Extract a list of tensors of shape beam_width x output_ids.
            output_beams_list = [
                self.tokenizer.batch_decode(
                    output_ids[batch_idx, :, input_lengths[batch_idx] :],
                    skip_special_tokens=True,
                )
                for batch_idx in range(self.args.batch_size)
            ]
            stripped_text = [
                [
                    output_beams_list[batch_idx][beam_idx].strip()
                    for beam_idx in range(self.args.num_beams)
                ]
                for batch_idx in range(self.args.batch_size)
            ]
            profiler.stop("Generate")
            return stripped_text
        else:
            profiler.stop("Generate")
            return None
def get_visual_features(self, image, attention_mask):
visual_features = {
"input": image.to(
tensorrt_llm._utils.str_dtype_to_torch(self.vision_precision)
)
}
if attention_mask is not None:
visual_features["attention_mask"] = attention_mask
tensor_info = [
TensorInfo("input", str_dtype_to_trt(self.vision_precision), image.shape)
]
if attention_mask is not None:
tensor_info.append(
TensorInfo("attention_mask", trt.DataType.INT32, attention_mask.shape)
)
visual_output_info = self.visual_encoder_session.infer_shapes(tensor_info)
visual_outputs = {
t.name: torch.empty(
tuple(t.shape), dtype=trt_dtype_to_torch(t.dtype), device=image.device
)
for t in visual_output_info
}
ok = self.visual_encoder_session.run(
visual_features, visual_outputs, self.stream.cuda_stream
)
assert ok, "Runtime execution failed for vision encoder session"
self.stream.synchronize()
image_embeds = visual_outputs["output"]
image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(
image.device
)
return image_embeds, image_atts
def setup_fake_prompts(
self, visual_features, pre_input_ids, post_input_ids, input_lengths
):
# Assemble fake prompts which points to image embedding actually
fake_prompt_id = torch.arange(
self.model_config.vocab_size,
self.model_config.vocab_size
+ visual_features.shape[0] * visual_features.shape[1],
)
fake_prompt_id = fake_prompt_id.reshape(
visual_features.shape[0], visual_features.shape[1]
)
if "cogvlm" in self.model_type:
input_ids = (
torch.cat(
[pre_input_ids[:, 0:1], fake_prompt_id, pre_input_ids[:, 1:]], dim=1
)
.contiguous()
.to(torch.int32)
)
else:
if post_input_ids is not None:
input_ids = [pre_input_ids, fake_prompt_id, post_input_ids]
else:
input_ids = [fake_prompt_id, pre_input_ids]
input_ids = torch.cat(input_ids, dim=1).contiguous().to(torch.int32)
if self.decoder_llm or self.runtime_mapping.is_first_pp_rank():
ptuning_args = self.ptuning_setup(visual_features, input_ids, input_lengths)
else:
ptuning_args = [None, None, None]
return input_ids, ptuning_args
def ptuning_setup_fuyu(self, input_ids, image_patches_indices):
res_input_ids = []
for cur_input_ids, cur_image_patches_indices in zip(
input_ids, image_patches_indices
):
# Truncate input_ids to the length of image_patches_indices
cur_image_patches_indices = cur_image_patches_indices[: len(cur_input_ids)]
# Get ids of the image_patches
non_zero_mask = cur_image_patches_indices != -1
# Replace input_ids with image_patches_indices values (where the patches are placed)
cur_input_ids = cur_input_ids.masked_scatter(
non_zero_mask,
cur_image_patches_indices[non_zero_mask] + self.model_config.vocab_size,
)
res_input_ids.append(cur_input_ids)
return res_input_ids
def ptuning_setup(self, prompt_table, input_ids, input_lengths):
hidden_size = self.model_config.hidden_size * self.runtime_mapping.tp_size
if prompt_table is not None:
task_vocab_size = torch.tensor(
[prompt_table.shape[1]],
dtype=torch.int32,
).cuda()
prompt_table = prompt_table.view(
(prompt_table.shape[0] * prompt_table.shape[1], prompt_table.shape[2])
)
assert (
prompt_table.shape[1] == hidden_size
), "Prompt table dimensions do not match hidden size"
prompt_table = prompt_table.cuda().to(
dtype=tensorrt_llm._utils.str_dtype_to_torch(self.model_config.dtype)
)
else:
prompt_table = torch.empty([1, hidden_size]).cuda()
task_vocab_size = torch.zeros([1]).cuda()
if self.model_config.remove_input_padding:
tasks = torch.zeros([torch.sum(input_lengths)], dtype=torch.int32).cuda()
if self.decoder_llm:
tasks = tasks.unsqueeze(0)
else:
tasks = torch.zeros(input_ids.shape, dtype=torch.int32).cuda()
return [prompt_table, tasks, task_vocab_size]
def load_test_image(self):
    """Download and return a demo image matching the current model type."""

    def fetch(url):
        # Stream the image over HTTP and normalize to RGB.
        return Image.open(requests.get(url, stream=True).raw).convert("RGB")

    model_type = self.model_type
    if "vila" in model_type:
        return fetch(
            "https://github.com/Efficient-Large-Model/VILA/raw/main/demo_images/av.png"
        )
    if "nougat" in model_type:
        path = hf_hub_download(
            repo_id="hf-internal-testing/fixtures_docvqa",
            filename="nougat_paper.png",
            repo_type="dataset",
        )
        return Image.open(path)
    if "fuyu" in model_type:
        path = hf_hub_download(
            repo_id="adept/fuyu-8b", filename="skateboard.png", repo_type="model"
        )
        return Image.open(path)
    if "kosmos" in model_type:
        return fetch(
            "https://huggingface.co/microsoft/kosmos-2-patch14-224/resolve/main/snowman.png"
        )
    if "pix2struct" in model_type:
        return fetch(
            "https://raw.githubusercontent.com/vis-nlp/ChartQA/main/ChartQA%20Dataset/val/png/multi_col_40963.png"
        )
    # Default demo image (Singapore Merlion) for the remaining model types.
    return fetch(
        "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/assets/merlion.png"
    )
def setup_inputs(self, input_text, raw_image):
    """Preprocess the image and build prompts for the configured model type.

    Fix: several branches read the module-global ``args`` instead of the
    instance's ``self.args`` — made consistent throughout.

    Returns:
        (input_text, pre_prompt, post_prompt, image, decoder_input_ids,
        attention_mask); ``decoder_input_ids`` is None for decoder-only LLMs
        and ``attention_mask`` is only set for pix2struct.
    """
    attention_mask = None
    if "blip2" in self.model_type:
        # NOTE(review): from_pretrained() receives the model *type* string
        # here, while other branches use self.args.hf_model_dir — confirm.
        processor = Blip2Processor.from_pretrained(self.model_type)
        image = processor(raw_image, input_text, return_tensors="pt")[
            "pixel_values"
        ]
        if input_text is None:
            input_text = "Question: which city is this? Answer:"
        pre_prompt = input_text
        post_prompt = None
    elif "nougat" in self.model_type:
        processor = NougatProcessor.from_pretrained(self.args.hf_model_dir)
        image = processor(raw_image, return_tensors="pt")["pixel_values"]
        # Nougat doesn't need text prompt (mBART use single token to start generation), just leave a dummy one here
        if input_text is None:
            input_text = "Question: which city is this? Answer:"
        pre_prompt = input_text
        post_prompt = None
    elif "cogvlm" in self.model_type:
        image_size = 490
        dtype = torch.bfloat16
        transform = transforms.Compose(
            [
                transforms.Resize(
                    (image_size, image_size),
                    interpolation=transforms.InterpolationMode.BICUBIC,
                ),
                transforms.ToTensor(),
                # CLIP-style normalization constants.
                transforms.Normalize(
                    (0.48145466, 0.4578275, 0.40821073),
                    (0.26862954, 0.26130258, 0.27577711),
                ),
            ]
        )
        image = transform(raw_image).to(dtype).unsqueeze(0)
        if input_text is None:
            input_text = " [INST] which city is this? [/INST] "
        pre_prompt = input_text
        post_prompt = None
    elif self.model_type == "pix2struct":
        # was module-global `args`; use the instance copy
        image_processor = AutoProcessor.from_pretrained(self.args.hf_model_dir)
        if input_text is None:
            input_text = ""
        inputs = image_processor(
            images=raw_image,
            text=input_text,
            return_tensors="pt",
        )
        image = inputs["flattened_patches"]
        image = image.expand(self.args.batch_size, -1, -1).contiguous()
        attention_mask = inputs["attention_mask"].to(self.device).to(torch.int)
        attention_mask = attention_mask.expand(self.args.batch_size, -1).contiguous()
        pre_prompt = ""
        post_prompt = None
    elif "neva" in self.model_type:
        image_size = 384
        dtype = torch.float32
        transform = transforms.Compose(
            [
                transforms.Resize(
                    (image_size, image_size),
                    interpolation=transforms.InterpolationMode.BICUBIC,
                ),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ]
        )
        image = transform(raw_image).to(dtype).unsqueeze(0)
        if input_text is None:
            input_text = "Hi! What is in this image?"
        pre_prompt = "<extra_id_0>System\n\n<extra_id_1>User\n"
        post_prompt = f"\n{input_text}\n<extra_id_1>Assistant\n"
    elif self.model_type in ["llava", "vila", "fuyu", "kosmos-2", "llava_next"]:
        # LLaVA and VILA
        if self.model_type == "llava":
            pre_prompt = "USER:\n"
            if input_text is None:
                input_text = "Question: which city is this? Answer:"
        elif self.model_type == "llava_next":
            # NOTE(review): no default input_text here — post_prompt
            # concatenation below fails if the caller passes None. Confirm.
            pre_prompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: "
        elif self.model_type == "vila":
            pre_prompt = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: "
            if input_text is None:
                input_text = "Please describe the traffic condition."
        elif self.model_type == "fuyu":
            pre_prompt = "Describe this image:"
            if input_text is None:
                input_text = "Answer the following VQAv2 question based on the image: How many people are in the image?\n"
        elif self.model_type == "kosmos-2":
            pre_prompt = ""
            if input_text is None:
                input_text = "<grounding>An image of"
        if self.model_type not in ["fuyu", "kosmos-2"]:
            post_prompt = input_text + " ASSISTANT:"
        else:
            post_prompt = None
        if self.model_type == "vila":
            sys.path.append(self.args.hf_model_dir + "/../VILA")
            from llava.model import LlavaLlamaForCausalLM

            model = LlavaLlamaForCausalLM.from_pretrained(
                self.args.hf_model_dir, torch_dtype=torch.float16
            )
            vision_tower = model.get_vision_tower()
            image_processor = vision_tower.image_processor
            image = image_processor(images=raw_image, return_tensors="pt")[
                "pixel_values"
            ]
        else:
            # NOTE(review): `image_pre_obj` is not defined in this module —
            # it appears to come from elsewhere; confirm it is in scope.
            image = image_pre_obj.encode(raw_image).cuda()
    # Repeat inputs to match batch size
    pre_prompt = [pre_prompt] * self.args.batch_size
    post_prompt = [post_prompt] * self.args.batch_size
    if self.model_type not in ["fuyu", "pix2struct", "kosmos-2"]:
        image = image.expand(self.args.batch_size, -1, -1, -1).contiguous()
    image = image.to(self.device)
    # Generate decoder_input_ids for enc-dec models
    # Custom prompts can be added as:
    # decoder_input_ids = model.tokenizer(decoder_prompt).input_ids
    if self.decoder_llm:
        decoder_input_ids = None
    else:
        config = AutoConfig.from_pretrained(self.args.hf_model_dir)
        decoder_start_id = config.decoder_start_token_id  # T5
        if decoder_start_id is None:
            decoder_start_id = config.decoder.bos_token_id  # Nougat
        decoder_input_ids = torch.IntTensor([[decoder_start_id]])
        decoder_input_ids = decoder_input_ids.repeat((self.args.batch_size, 1))
    return (
        input_text,
        pre_prompt,
        post_prompt,
        image,
        decoder_input_ids,
        attention_mask,
    )
def run(self, input_text, input_image, max_new_tokens):
    """Run end-to-end multimodal generation for one (text, image) pair.

    Fixes: the original called the module-global ``model`` instead of
    ``self``, so the method only worked when a global of that name
    happened to exist; it also computed ``num_iters`` and immediately
    overwrote it with a constant that was never used (dead code, removed
    along with the commented-out profiling loop).

    Returns:
        The generated text batch from ``self.generate``.
    """
    (
        input_text,
        pre_prompt,
        post_prompt,
        processed_image,
        decoder_input_ids,
        attention_mask,
    ) = self.setup_inputs(input_text, input_image)
    # Warmup pass: primes kernels/buffers; its output is discarded.
    self.generate(
        pre_prompt,
        post_prompt,
        processed_image,
        decoder_input_ids,
        max_new_tokens,
        attention_mask=attention_mask,
        warmup=True,
    )
    output_text = self.generate(
        pre_prompt,
        post_prompt,
        processed_image,
        decoder_input_ids,
        max_new_tokens,
        attention_mask=attention_mask,
        warmup=False,
    )
    # Only rank 0 prints to avoid duplicate logs under tensor parallelism.
    if self.runtime_rank == 0:
        self.print_result(input_text, output_text)
    return output_text
def print_result(self, input_text, output_text):
    """Log the Q/A pair, optionally check accuracy, and report profiling.

    Fix: read ``self.args.num_beams`` instead of the module-global ``args``.
    """
    logger.info("---------------------------------------------------------")
    if self.model_type != "nougat":
        logger.info(f"\n[Q] {input_text}")
    logger.info(f"\n[A] {output_text[0]}")
    if self.args.num_beams == 1:
        # Re-tokenize the first hypothesis to report the generated length.
        output_ids = self.tokenizer(output_text[0][0], add_special_tokens=False)[
            "input_ids"
        ]
        logger.info(f"Generated {len(output_ids)} tokens")
    if self.args.check_accuracy:
        # All batch entries received identical inputs, so outputs must match.
        for i in range(self.args.batch_size - 1):
            assert (
                output_text[i] == output_text[i + 1]
            ), f"Output {i} and {i + 1} do not match"
        if self.model_type != "nougat":
            # Golden answers for the bundled demo images.
            if self.model_type == "vila":
                assert (
                    output_text[0][0].lower()
                    == "the traffic condition in the image is quite busy, with multiple cars and bicycles sharing the road. there are also pedestrians walking on"
                )
            elif self.model_type == "fuyu":
                assert output_text[0][0].lower() == "4"
            elif self.model_type == "pix2struct":
                assert (
                    "characteristic | cat food, day | cat food, wet | cat treats"
                    in output_text[0][0].lower()
                )
            elif self.model_type == "neva":
                assert "singapore" in output_text[0][0].lower()
            elif self.model_type == "kosmos-2":
                assert "snowman" in output_text[0][0].lower()
            else:
                assert output_text[0][0].lower() == "singapore"
    if self.args.run_profiling:
        msec_per_batch = (
            lambda name: 1000
            * profiler.elapsed_time_in_sec(name)
            / self.profiling_iterations
        )
        logger.info("Latencies per batch (msec)")
        logger.info("TRT vision encoder: %.1f" % (msec_per_batch("Vision")))
        logger.info("TRTLLM LLM generate: %.1f" % (msec_per_batch("LLM")))
        logger.info("Multimodal generate: %.1f" % (msec_per_batch("Generate")))
    logger.info("---------------------------------------------------------")
if __name__ == "__main__":
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    args = parse_arguments()
    # Build the query prompt for the selected captioning mode.
    if args.mode == "caption_zh":
        query = "描述这张图片"
    elif args.mode == "caption_en":
        query = "Please describe the content of this image"
    elif args.mode == "insert_content":
        assert args.content is not None
        query = f"根据提示词“{args.content}”,描述这张图片"
    else:
        # Previously `query` stayed unbound here and crashed later with a
        # confusing NameError; fail fast with a clear message instead.
        raise ValueError(f"Unsupported mode: {args.mode}")
    tensorrt_llm.logger.set_level(args.log_level)
    model = MultimodalModelRunner(args)
    if args.input_file is not None:
        # Batch mode: caption every image listed in the CSV.
        df = pd.read_csv(args.input_file)
        text_zh = []
        for i in tqdm.tqdm(range(len(df))):
            img_path = df.loc[i]["img_path"]
            raw_image = Image.open(img_path)
            res = model.run(query, [raw_image], args.max_new_tokens)
            text_zh.append(res)
        df["text_zh"] = text_zh
        # utf-8-sig keeps Chinese text readable when opened in Excel.
        df.to_csv(args.output_file, index=False, encoding="utf-8-sig")
    else:
        raw_image = Image.open(args.image_file)
        res = model.run(query, [raw_image], args.max_new_tokens)
        print(res)
timm==0.9.5
diffusers==0.21.2
peft==0.10.0
protobuf==3.19.0
transformers==4.39.1
accelerate==0.29.3
loguru==0.7.2
einops==0.7.0
sentencepiece==0.1.99
cuda-python==11.7.1
nvidia-pyindex==1.0.9
pandas==2.0.3
gradio==3.50.2
huggingface_hub==0.25.2
\ No newline at end of file
from pathlib import Path
from loguru import logger
from mllm.dialoggen_demo import DialogGen
from hydit.config import get_args
from hydit.inference_controlnet import End2End
from torchvision import transforms as T
import numpy as np
norm_transform = T.Compose(
[
T.ToTensor(),
T.Normalize([0.5], [0.5]),
]
)
from PIL import Image
def inferencer():
    """Parse CLI args, load the End2End pipeline and optional prompt enhancer."""
    args = get_args()
    models_root_path = Path(args.model_root)
    if not models_root_path.exists():
        raise ValueError(f"`models_root` not exists: {models_root_path}")

    # Load models
    gen = End2End(args, models_root_path)

    # DialogGen is only loaded when prompt enhancement is requested.
    enhancer = None
    if args.enhance:
        logger.info("Loading DialogGen model (for prompt enhancement)...")
        enhancer = DialogGen(str(models_root_path / "dialoggen"), args.load_4bit)
        logger.info("DialogGen model loaded.")
    return args, gen, enhancer
if __name__ == "__main__":
    args, gen, enhancer = inferencer()
    # Optionally rewrite the prompt with DialogGen before generation.
    if enhancer:
        logger.info("Prompt Enhancement...")
        success, enhanced_prompt = enhancer(args.prompt)
        if not success:
            logger.info("Sorry, the prompt is not compliant, refuse to draw.")
            exit()
        logger.info(f"Enhanced prompt: {enhanced_prompt}")
    else:
        enhanced_prompt = None

    # Run inference
    logger.info("Generating images...")
    height, width = args.image_size
    condition = (
        Image.open(args.condition_image_path).convert("RGB").resize((width, height))
    )
    image = norm_transform(condition)
    image = image.unsqueeze(0).cuda()
    results = gen.predict(
        args.prompt,
        height=height,
        width=width,
        image=image,
        seed=args.seed,
        enhanced_prompt=enhanced_prompt,
        negative_prompt=args.negative,
        infer_steps=args.infer_steps,
        guidance_scale=args.cfg_scale,
        batch_size=args.batch_size,
        src_size_cond=args.size_cond,
        use_style_cond=args.use_style_cond,
    )
    images = results["images"]

    # Save images
    save_dir = Path("results")
    save_dir.mkdir(exist_ok=True)
    # Find the first available index. Ignore non-numeric filenames so a
    # stray file like "preview.png" cannot crash int(f.stem).
    indices = [int(f.stem) for f in save_dir.glob("*.png") if f.stem.isdigit()]
    start = max(indices) + 1 if indices else 0
    for idx, pil_img in enumerate(images):
        save_path = save_dir / f"{idx + start}.png"
        pil_img.save(save_path)
        logger.info(f"Save to {save_path}")
from pathlib import Path
from loguru import logger
from mllm.dialoggen_demo import DialogGen
from hydit.config import get_args
from hydit.inference_ipadapter import End2End
from torchvision import transforms as T
import numpy as np
norm_transform = T.Compose(
[
T.ToTensor(),
T.Normalize([0.5], [0.5]),
]
)
from PIL import Image
def inferencer():
    """Build the IP-Adapter End2End pipeline plus an optional DialogGen enhancer."""
    args = get_args()
    models_root_path = Path(args.model_root)
    if not models_root_path.exists():
        raise ValueError(f"`models_root` not exists: {models_root_path}")

    # Load models
    gen = End2End(args, models_root_path)

    # Prompt enhancement is opt-in via --enhance.
    if not args.enhance:
        return args, gen, None
    logger.info("Loading DialogGen model (for prompt enhancement)...")
    enhancer = DialogGen(str(models_root_path / "dialoggen"), args.load_4bit)
    logger.info("DialogGen model loaded.")
    return args, gen, enhancer
if __name__ == "__main__":
    args, gen, enhancer = inferencer()
    # Optionally rewrite the prompt with DialogGen before generation.
    if enhancer:
        logger.info("Prompt Enhancement...")
        success, enhanced_prompt = enhancer(args.prompt)
        if not success:
            logger.info("Sorry, the prompt is not compliant, refuse to draw.")
            exit()
        logger.info(f"Enhanced prompt: {enhanced_prompt}")
    else:
        enhanced_prompt = None

    # Run inference with the IP-Adapter reference image.
    logger.info("Generating images...")
    height, width = args.image_size
    ref_image = Image.open(args.ref_image_path).convert("RGB")
    i_scale = args.i_scale
    results = gen.predict(
        args.prompt,
        height=height,
        width=width,
        image=ref_image,
        i_scale=i_scale,
        t_scale=1,
        seed=3333,
        enhanced_prompt=enhanced_prompt,
        negative_prompt=args.negative,
        infer_steps=args.infer_steps,
        guidance_scale=3,
        batch_size=args.batch_size,
        src_size_cond=args.size_cond,
    )
    images = results["images"]
    save_dir = Path("results")
    save_dir.mkdir(exist_ok=True)
    # Find the first available index. Ignore non-numeric filenames so a
    # stray file like "preview.png" cannot crash int(f.stem).
    indices = [int(f.stem) for f in save_dir.glob("*.png") if f.stem.isdigit()]
    start = max(indices) + 1 if indices else 0
    for idx, pil_img in enumerate(images):
        save_path = save_dir / f"{idx + start}.png"
        pil_img.save(save_path)
        logger.info(f"Save to {save_path}")
from pathlib import Path
from loguru import logger
from mllm.dialoggen_demo import DialogGen
from hydit.config import get_args
from hydit.inference import End2End
def inferencer():
    """Load the text-to-image End2End pipeline and optional DialogGen enhancer."""
    args = get_args()
    models_root_path = Path(args.model_root)
    if not models_root_path.exists():
        raise ValueError(f"`models_root` not exists: {models_root_path}")

    # Load models
    gen = End2End(args, models_root_path)

    # The enhancer is heavyweight; skip it unless explicitly enabled.
    enhancer = None
    if args.enhance:
        logger.info("Loading DialogGen model (for prompt enhancement)...")
        enhancer = DialogGen(str(models_root_path / "dialoggen"), args.load_4bit)
        logger.info("DialogGen model loaded.")
    return args, gen, enhancer
if __name__ == "__main__":
    args, gen, enhancer = inferencer()
    # Optionally rewrite the prompt with DialogGen before generation.
    if enhancer:
        logger.info("Prompt Enhancement...")
        success, enhanced_prompt = enhancer(args.prompt)
        if not success:
            logger.info("Sorry, the prompt is not compliant, refuse to draw.")
            exit()
        logger.info(f"Enhanced prompt: {enhanced_prompt}")
    else:
        enhanced_prompt = None

    # Run inference
    logger.info("Generating images...")
    height, width = args.image_size
    results = gen.predict(
        args.prompt,
        height=height,
        width=width,
        seed=args.seed,
        enhanced_prompt=enhanced_prompt,
        negative_prompt=args.negative,
        infer_steps=args.infer_steps,
        guidance_scale=args.cfg_scale,
        batch_size=args.batch_size,
        src_size_cond=args.size_cond,
        use_style_cond=args.use_style_cond,
    )
    images = results["images"]

    # Save images
    save_dir = Path("results")
    save_dir.mkdir(exist_ok=True)
    # Find the first available index. Ignore non-numeric filenames so a
    # stray file like "preview.png" cannot crash int(f.stem).
    indices = [int(f.stem) for f in save_dir.glob("*.png") if f.stem.isdigit()]
    start = max(indices) + 1 if indices else 0
    for idx, pil_img in enumerate(images):
        save_path = save_dir / f"{idx + start}.png"
        pil_img.save(save_path)
        logger.info(f"Save to {save_path}")
from pathlib import Path
from loguru import logger
from mllm.dialoggen_demo import DialogGen
from hydit.config import get_args
from hydit.inference import End2End
def inferencer():
"""Load the End2End pipeline (and optional DialogGen enhancer) from CLI args.

Returns:
    (args, gen, enhancer) where ``enhancer`` is None unless ``--enhance``
    was given.

Raises:
    ValueError: if ``args.model_root`` does not exist on disk.
"""
args = get_args()
models_root_path = Path(args.model_root)
if not models_root_path.exists():
raise ValueError(f"`models_root` not exists: {models_root_path}")
# Load models
gen = End2End(args, models_root_path)
# Try to enhance prompt
if args.enhance:
logger.info("Loading DialogGen model (for prompt enhancement)...")
# DialogGen weights live under <model_root>/dialoggen; load_4bit trades
# memory for quality.
enhancer = DialogGen(str(models_root_path / "dialoggen"), args.load_4bit)
logger.info("DialogGen model loaded.")
else:
enhancer = None
return args, gen, enhancer
if __name__ == "__main__":
    args, gen, enhancer = inferencer()
    # Optionally rewrite the prompt with DialogGen before generation.
    if enhancer:
        logger.info("Prompt Enhancement...")
        success, enhanced_prompt = enhancer(args.prompt)
        if not success:
            logger.info("Sorry, the prompt is not compliant, refuse to draw.")
            exit()
        logger.info(f"Enhanced prompt: {enhanced_prompt}")
    else:
        enhanced_prompt = None

    # Warmup: a short 5-step run to prime kernels; its output is discarded.
    height, width = args.image_size
    results = gen.predict(
        args.prompt,
        height=height,
        width=width,
        seed=args.seed,
        enhanced_prompt=enhanced_prompt,
        negative_prompt=args.negative,
        infer_steps=5,
        guidance_scale=args.cfg_scale,
        batch_size=args.batch_size,
        src_size_cond=args.size_cond,
        use_style_cond=args.use_style_cond,
    )

    # Run inference
    logger.info("Generating images...")
    height, width = args.image_size
    results = gen.predict(
        args.prompt,
        height=height,
        width=width,
        seed=args.seed,
        enhanced_prompt=enhanced_prompt,
        negative_prompt=args.negative,
        infer_steps=args.infer_steps,
        guidance_scale=args.cfg_scale,
        batch_size=args.batch_size,
        src_size_cond=args.size_cond,
        use_style_cond=args.use_style_cond,
    )
    images = results["images"]

    # Save images
    save_dir = Path("results")
    save_dir.mkdir(exist_ok=True)
    # Find the first available index. Ignore non-numeric filenames so a
    # stray file like "preview.png" cannot crash int(f.stem).
    indices = [int(f.stem) for f in save_dir.glob("*.png") if f.stem.isdigit()]
    start = max(indices) + 1 if indices else 0
    for idx, pil_img in enumerate(images):
        save_path = save_dir / f"{idx + start}.png"
        pil_img.save(save_path)
        logger.info(f"Save to {save_path}")
#!/bin/bash
# Run every test_*.sh script found directly under the test directory.
#
# Fix: the original iterated over $(find ...) unquoted, which word-splits
# on whitespace in filenames; a plain glob is both simpler and safe.
test_base=./tests          # test directory
export CUDA_VISIBLE_DEVICES=3   # pin the GPU used by all test scripts
for file in "$test_base"/test_*.sh; do
    # When no test scripts exist the glob stays literal; skip it.
    [ -e "$file" ] || continue
    filename=$(basename "$file")
    echo "################################"
    echo "Running tests in $filename..."
    bash "$file"
    echo "################################"
done
#!/bin/bash
# Smoke tests for ControlNet sampling: one run per control type
# (canny / depth / pose). Each task redirects all output to its own log
# file and reports Passed/Failed from the Python exit status.
task_name="infer_controlnet_canny"
log_file="${task_name}.log"
python sample_controlnet.py --infer-mode torch --no-enhance --load-key distill --infer-steps 50 --control-type canny --prompt "在夜晚的酒店门前,一座古老的中国风格的狮子雕像矗立着,它的眼睛闪烁着光芒,仿佛在守护着这座建筑。背景是夜晚的酒店前,构图方式是特写,平视,居中构图。这张照片呈现了真实摄影风格,蕴含了中国雕塑文化,同时展现了神秘氛围" --condition-image-path controlnet/asset/input/canny.jpg --control-weight 1.0 > "$log_file" 2>&1
# $? must be read immediately after the command it checks.
exit_status=$?
if [ $exit_status -eq 0 ]; then
echo -e "\033[0;32m$task_name Passed\033[0m"
else
echo -e "\033[0;31m$task_name Failed\033[0m"
fi
###
task_name="infer_controlnet_depth"
log_file="${task_name}.log"
python sample_controlnet.py --infer-mode torch --no-enhance --load-key distill --infer-steps 50 --control-type depth --prompt "在夜晚的酒店门前,一座古老的中国风格的狮子雕像矗立着,它的眼睛闪烁着光芒,仿佛在守护着这座建筑。背景是夜晚的酒店前,构图方式是特写,平视,居中构图。这张照片呈现了真实摄影风格,蕴含了中国雕塑文化,同时展现了神秘氛围" --condition-image-path controlnet/asset/input/depth.jpg --control-weight 1.0 > "$log_file" 2>&1
exit_status=$?
if [ $exit_status -eq 0 ]; then
echo -e "\033[0;32m$task_name Passed\033[0m"
else
echo -e "\033[0;31m$task_name Failed\033[0m"
fi
###
task_name="infer_controlnet_pose"
log_file="${task_name}.log"
python sample_controlnet.py --infer-mode torch --no-enhance --load-key distill --infer-steps 50 --control-type pose --prompt "在夜晚的酒店门前,一座古老的中国风格的狮子雕像矗立着,它的眼睛闪烁着光芒,仿佛在守护着这座建筑。背景是夜晚的酒店前,构图方式是特写,平视,居中构图。这张照片呈现了真实摄影风格,蕴含了中国雕塑文化,同时展现了神秘氛围" --condition-image-path controlnet/asset/input/pose.jpg --control-weight 1.0 > "$log_file" 2>&1
exit_status=$?
if [ $exit_status -eq 0 ]; then
echo -e "\033[0;32m$task_name Passed\033[0m"
else
echo -e "\033[0;31m$task_name Failed\033[0m"
fi
#!/bin/bash
# Smoke test for IP-Adapter sampling; logs go to <task_name>.log and the
# result is reported from the command's exit status.
task_name="infer_ipadapter.sh"
log_file="${task_name}.log"
if python3 sample_ipadapter.py --infer-mode torch --ref-image-path ipadapter/asset/input/tiger.png --i-scale 1.0 --prompt 一只老虎在海洋中游泳,背景是海洋。构图方式是居中构图,呈现了动漫风格和文化,营造了平静的氛围。 --infer-steps 30 --is-ipa True --load-key distill > "$log_file" 2>&1; then
    echo -e "\033[0;32m$task_name Passed\033[0m"
else
    echo -e "\033[0;31m$task_name Failed\033[0m"
fi
#!/bin/bash
# Smoke tests for text-to-image sampling, once with FlashAttention
# (--infer-mode fa) and once with the plain torch attention path.
task_name="infer_text2img_flash_attn"
log_file="${task_name}.log"
python sample_t2i.py --infer-mode fa --infer-steps 30 --prompt "青花瓷风格,一只可爱的哈士奇" --no-enhance --load-key distill > "$log_file" 2>&1
# $? must be read immediately after the command it checks.
exit_status=$?
if [ $exit_status -eq 0 ]; then
echo -e "\033[0;32m$task_name Passed\033[0m"
else
echo -e "\033[0;31m$task_name Failed\033[0m"
fi
###
task_name="infer_text2img_raw_attn"
log_file="${task_name}.log"
python sample_t2i.py --infer-mode torch --infer-steps 30 --prompt "青花瓷风格,一只可爱的哈士奇" --no-enhance --load-key distill > "$log_file" 2>&1
exit_status=$?
if [ $exit_status -eq 0 ]; then
echo -e "\033[0;32m$task_name Passed\033[0m"
else
echo -e "\033[0;31m$task_name Failed\033[0m"
fi
# ==============================================================================
# Description: Export ONNX model and build TensorRT engine.
# ==============================================================================
# Check Hydit Version.
# Defaults to 1.2 when no argument is given; only 1.0/1.1/1.2 are valid.
if [ -z "$1" ]; then
HYDIT_VERSION=1.2
elif [ "$1" == "1.0" ]; then
HYDIT_VERSION=1.0
elif [ "$1" == "1.1" ]; then
HYDIT_VERSION=1.1
elif [ "$1" == "1.2" ]; then
HYDIT_VERSION=1.2
else
echo "Failed. Hydit Only Has Version: 1.0, 1.1, 1.2!"
exit 1
fi
echo "Hydit Version: "${HYDIT_VERSION}
export MODEL_ROOT=ckpts
export ONNX_WORKDIR=${MODEL_ROOT}/onnx_model
echo "MODEL_ROOT=${MODEL_ROOT}"
echo "ONNX_WORKDIR=${ONNX_WORKDIR}"
# Remove old directories.
if [ -d "${ONNX_WORKDIR}" ]; then
echo "Remove old ONNX directories..."
rm -r ${ONNX_WORKDIR}
fi
# Inspect the project directory.
# SCRIPT_PATH resolves to this script's directory; the project root is its
# parent, which is put on PYTHONPATH so trt/export_onnx.py can import hydit.
SCRIPT_PATH="$( cd "$( dirname "$0" )" && pwd )"
PROJECT_DIR=$(dirname "$SCRIPT_PATH")
export PYTHONPATH=${PROJECT_DIR}:${PYTHONPATH}
echo "PYTHONPATH=${PYTHONPATH}"
cd ${PROJECT_DIR}
echo "Change directory to ${PROJECT_DIR}"
# ----------------------------------------
# 1. Export ONNX model.
# ----------------------------------------
# Sleep for reading the message.
sleep 2s
echo "Exporting ONNX model..."
# v1.0/1.1 additionally need style conditioning and fixed size conditioning;
# v1.2 dropped those inputs.
if [ ${HYDIT_VERSION} == "1.2" ]; then
echo "Export ONNX for Hydit Version 1.2"
python trt/export_onnx.py --model-root ${MODEL_ROOT} --onnx-workdir ${ONNX_WORKDIR} --infer-mode torch
elif [ ${HYDIT_VERSION} == "1.1" ]; then
echo "Export ONNX for Hydit Version 1.1"
python trt/export_onnx.py --model-root ./HunyuanDiT-v1.1 --onnx-workdir ${ONNX_WORKDIR} --infer-mode torch --use-style-cond --size-cond 1024 1024 --beta-end 0.03
elif [ ${HYDIT_VERSION} == "1.0" ]; then
echo "Export ONNX for Hydit Version 1.0"
python trt/export_onnx.py --model-root ./HunyuanDiT-v1.0 --onnx-workdir ${ONNX_WORKDIR} --infer-mode torch --use-style-cond --size-cond 1024 1024 --beta-end 0.03
fi
echo "Exporting ONNX model finished"
# ----------------------------------------
# 2. Build TensorRT engine.
# ----------------------------------------
echo "Building TensorRT engine..."
ENGINE_DIR="${MODEL_ROOT}/t2i/model_trt/engine"
mkdir -p ${ENGINE_DIR}
ENGINE_PATH=${ENGINE_DIR}/model_onnx.plan
# Pre-built fused-MHA TensorRT plugin shipped with the checkpoints.
PLUGIN_PATH=${MODEL_ROOT}/t2i/model_trt/fmha_plugins/10.1_plugin_cuda11/fMHAPlugin.so
# min/opt/max shapes define the dynamic-shape optimization profile; batch is
# fixed at 2 (classifier-free guidance pairs). The v1.0/1.1 branch carries
# the extra image_meta_size and style inputs.
if [ ${HYDIT_VERSION} == "1.2" ]; then
trtexec \
--onnx=${ONNX_WORKDIR}/export_modified_fmha/model.onnx \
--fp16 \
--saveEngine=${ENGINE_PATH} \
--minShapes=x:2x4x90x90,t:2,encoder_hidden_states:2x77x1024,text_embedding_mask:2x77,encoder_hidden_states_t5:2x256x2048,text_embedding_mask_t5:2x256,cos_cis_img:2025x88,sin_cis_img:2025x88 \
--optShapes=x:2x4x128x128,t:2,encoder_hidden_states:2x77x1024,text_embedding_mask:2x77,encoder_hidden_states_t5:2x256x2048,text_embedding_mask_t5:2x256,cos_cis_img:4096x88,sin_cis_img:4096x88 \
--maxShapes=x:2x4x160x160,t:2,encoder_hidden_states:2x77x1024,text_embedding_mask:2x77,encoder_hidden_states_t5:2x256x2048,text_embedding_mask_t5:2x256,cos_cis_img:6400x88,sin_cis_img:6400x88 \
--shapes=x:2x4x128x128,t:2,encoder_hidden_states:2x77x1024,text_embedding_mask:2x77,encoder_hidden_states_t5:2x256x2048,text_embedding_mask_t5:2x256,cos_cis_img:4096x88,sin_cis_img:4096x88 \
--verbose \
--staticPlugins=${PLUGIN_PATH} \
--stronglyTyped
else
trtexec \
--onnx=${ONNX_WORKDIR}/export_modified_fmha/model.onnx \
--fp16 \
--saveEngine=${ENGINE_PATH} \
--minShapes=x:2x4x90x90,t:2,encoder_hidden_states:2x77x1024,text_embedding_mask:2x77,encoder_hidden_states_t5:2x256x2048,text_embedding_mask_t5:2x256,image_meta_size:2x6,style:2,cos_cis_img:2025x88,sin_cis_img:2025x88 \
--optShapes=x:2x4x128x128,t:2,encoder_hidden_states:2x77x1024,text_embedding_mask:2x77,encoder_hidden_states_t5:2x256x2048,text_embedding_mask_t5:2x256,image_meta_size:2x6,style:2,cos_cis_img:4096x88,sin_cis_img:4096x88 \
--maxShapes=x:2x4x160x160,t:2,encoder_hidden_states:2x77x1024,text_embedding_mask:2x77,encoder_hidden_states_t5:2x256x2048,text_embedding_mask_t5:2x256,image_meta_size:2x6,style:2,cos_cis_img:6400x88,sin_cis_img:6400x88 \
--shapes=x:2x4x128x128,t:2,encoder_hidden_states:2x77x1024,text_embedding_mask:2x77,encoder_hidden_states_t5:2x256x2048,text_embedding_mask_t5:2x256,image_meta_size:2x6,style:2,cos_cis_img:4096x88,sin_cis_img:4096x88 \
--verbose \
--builderOptimizationLevel=4 \
--staticPlugins=${PLUGIN_PATH} \
--stronglyTyped
fi
from pathlib import Path
import torch
from loguru import logger
from hydit.config import get_args
from hydit.modules.models import HunYuanDiT, HUNYUAN_DIT_CONFIG
import numpy as np
import onnx
import onnx_graphsurgeon as gs
import polygraphy.backend.onnx.loader
from copy import deepcopy
def _to_tuple(val):
if isinstance(val, (list, tuple)):
if len(val) == 1:
val = [val[0], val[0]]
elif len(val) == 2:
val = tuple(val)
else:
raise ValueError(f"Invalid value: {val}")
elif isinstance(val, (int, float)):
val = (val, val)
else:
raise ValueError(f"Invalid value: {val}")
return val
class ExportONNX(object):
def __init__(self, args, models_root_path):
    """Record paths and pre-create the per-stage ONNX output folders."""
    self.args = args
    self.model = None

    # Set device and disable gradient
    self.device = "cuda" if torch.cuda.is_available() else "cpu"
    torch.set_grad_enabled(False)

    # Check arguments
    t2i_root_path = Path(models_root_path) / "t2i"
    self.root = t2i_root_path
    logger.info(f"Got text-to-image model root path: {t2i_root_path}")

    # One sub-folder per stage of the export pipeline:
    # raw export -> LayerNorm-patched -> fused-MHA.
    self.onnx_workdir = Path(self.args.onnx_workdir)
    stage_outputs = {
        "onnx_export": self.onnx_workdir / "export/model.onnx",
        "onnx_modify": self.onnx_workdir / "export_modified/model.onnx",
        "onnx_fmha": self.onnx_workdir / "export_modified_fmha/model.onnx",
    }
    for attr_name, onnx_path in stage_outputs.items():
        onnx_path.parent.mkdir(parents=True, exist_ok=True)
        setattr(self, attr_name, onnx_path)
def load_model(self):
"""Build the HunYuan-DiT module in fp16 and load the selected checkpoint.

Reads ``self.args.model`` / ``image_size`` / ``load_key`` and leaves the
model in eval mode on ``self.device``.

Raises:
    ValueError: if the expected checkpoint file does not exist.
"""
# ========================================================================
# Create model structure and load the checkpoint
logger.info(f"Building HunYuan-DiT model...")
model_config = HUNYUAN_DIT_CONFIG[self.args.model]
image_size = _to_tuple(self.args.image_size)
# Latent space is 8x downsampled relative to pixel space.
latent_size = (image_size[0] // 8, image_size[1] // 8)
model_dir = self.root / "model"
model_path = model_dir / f"pytorch_model_{self.args.load_key}.pt"
if not model_path.exists():
raise ValueError(f"model_path not exists: {model_path}")
# Build model structure
self.model = (
HunYuanDiT(
self.args,
input_size=latent_size,
**model_config,
log_fn=logger.info,
)
.half()
.to(self.device)
) # Force to use fp16
# Load model checkpoint
logger.info(f"Loading torch model {model_path}...")
# NOTE(review): torch.load without weights_only=True unpickles arbitrary
# objects — only load trusted checkpoints.
state_dict = torch.load(model_path, map_location=lambda storage, loc: storage)
self.model.load_state_dict(state_dict)
self.model.eval()
logger.info(f"Loading torch model finished")
logger.info("==================================================")
logger.info(f" Model is ready. ")
logger.info("==================================================")
def export(self):
"""Export the loaded model to ONNX with dummy fp16 inputs.

Two input signatures are emitted: the v1.2 form (no style conditioning)
and the v1.0/1.1 form which adds ``image_meta_size`` and ``style``.
Batch dimension is 2 throughout (CFG pairs); dynamic axes cover batch,
spatial size, and the rotary-embedding sequence length.
"""
if self.model is None:
self.load_model()
# Construct model inputs
# Dummy tensors sized for the "opt" profile: 1024x1024 images -> 128x128
# latents, 77 CLIP tokens, 256 T5 tokens, 4096x88 rotary tables.
latent_model_input = torch.randn(2, 4, 128, 128, device=self.device).half()
t_expand = torch.randint(0, 1000, [2], device=self.device).half()
prompt_embeds = torch.randn(2, 77, 1024, device=self.device).half()
attention_mask = torch.randint(0, 2, [2, 77], device=self.device).long()
prompt_embeds_t5 = torch.randn(2, 256, 2048, device=self.device).half()
attention_mask_t5 = torch.randint(0, 2, [2, 256], device=self.device).long()
ims = torch.tensor(
[[1024, 1024, 1024, 1024, 0, 0], [1024, 1024, 1024, 1024, 0, 0]],
device=self.device,
).half()
style = torch.tensor([0, 0], device=self.device).long()
# (cos, sin) rotary position-embedding tables.
freqs_cis_img = (
torch.randn(4096, 88),
torch.randn(4096, 88),
)
save_to = self.onnx_export
logger.info(f"Exporting ONNX model {save_to}...")
logger.info(f"Exporting ONNX external data {save_to.parent}...")
# Hydit version 1.2
if not self.args.use_style_cond:
model_args = (
latent_model_input,
t_expand,
prompt_embeds,
attention_mask,
prompt_embeds_t5,
attention_mask_t5,
freqs_cis_img[0],
freqs_cis_img[1],
)
torch.onnx.export(
self.model,
model_args,
str(save_to),
export_params=True,
opset_version=17,
do_constant_folding=True,
input_names=[
"x",
"t",
"encoder_hidden_states",
"text_embedding_mask",
"encoder_hidden_states_t5",
"text_embedding_mask_t5",
"cos_cis_img",
"sin_cis_img",
],
output_names=["output"],
dynamic_axes={
"x": {0: "2B", 2: "H", 3: "W"},
"t": {0: "2B"},
"encoder_hidden_states": {0: "2B"},
"text_embedding_mask": {0: "2B"},
"encoder_hidden_states_t5": {0: "2B"},
"text_embedding_mask_t5": {0: "2B"},
"cos_cis_img": {0: "seqlen"},
"sin_cis_img": {0: "seqlen"},
},
)
# Hydit version 1.0 or 1.1
else:
# Same signature plus image_meta_size and style conditioning inputs.
model_args = (
latent_model_input,
t_expand,
prompt_embeds,
attention_mask,
prompt_embeds_t5,
attention_mask_t5,
freqs_cis_img[0],
freqs_cis_img[1],
ims,
style,
)
torch.onnx.export(
self.model,
model_args,
str(save_to),
export_params=True,
opset_version=17,
do_constant_folding=True,
input_names=[
"x",
"t",
"encoder_hidden_states",
"text_embedding_mask",
"encoder_hidden_states_t5",
"text_embedding_mask_t5",
"cos_cis_img",
"sin_cis_img",
"image_meta_size",
"style",
],
output_names=["output"],
dynamic_axes={
"x": {0: "2B", 2: "H", 3: "W"},
"t": {0: "2B"},
"encoder_hidden_states": {0: "2B"},
"text_embedding_mask": {0: "2B"},
"encoder_hidden_states_t5": {0: "2B"},
"text_embedding_mask_t5": {0: "2B"},
"image_meta_size": {0: "2B"},
"style": {0: "2B"},
"cos_cis_img": {0: "seqlen"},
"sin_cis_img": {0: "seqlen"},
},
)
logger.info("Exporting onnx finished")
def postprocessing(self):
    """Patch the exported ONNX graph so every LayerNorm runs in fp32.

    Two rewrites are applied while walking the graph once:
      1. The node named ``/final_layer/norm_final/LayerNormalization`` gets
         explicit gamma (all ones) and beta (all zeros) constants, since the
         export left them implicit.
      2. Every ``LayerNormalization`` node is wrapped in a Cast-to-fp32 /
         Cast-back-to-fp16 pair, and its scale/bias constants are converted
         to fp32, to avoid fp16 precision loss inside the normalization.

    Reads the model from ``self.onnx_export`` (with external weight data in
    the same directory) and writes the modified graph to
    ``self.onnx_modify`` with external data split one file per tensor.
    """
    load_from = self.onnx_export
    save_to = self.onnx_modify
    logger.info(f"Postprocessing ONNX model {load_from}...")
    # Load the graph first, then pull in the external weight files that
    # torch.onnx.export scattered next to the model.
    onnxModel = onnx.load(str(load_from), load_external_data=False)
    onnx.load_external_data_for_model(onnxModel, str(load_from.parent))
    graph = gs.import_onnx(onnxModel)
    # ADD GAMMA BETA FOR LN
    for node in graph.nodes:
        if node.name == "/final_layer/norm_final/LayerNormalization":
            # NOTE(review): 1408 is assumed to be the final hidden width of
            # this model — confirm against the network config if it changes.
            constantKernel = gs.Constant(
                "final_layer.norm_final.weight",
                np.ascontiguousarray(np.ones((1408,), dtype=np.float32)),
            )
            constantBias = gs.Constant(
                "final_layer.norm_final.bias",
                np.ascontiguousarray(np.zeros((1408,), dtype=np.float32)),
            )
            node.inputs = [node.inputs[0], constantKernel, constantBias]
        if node.op == "LayerNormalization":
            # Insert: input -> Cast(fp32) -> LayerNorm(fp32) -> Cast(fp16).
            input_fp32 = gs.Variable(name=node.name + "_input_tensor_fp32")
            cast_fp32_node = gs.Node(
                op="Cast",
                name=node.name + "_cast_to_fp32",
                attrs={"to": onnx.TensorProto.FLOAT},
                inputs=[node.inputs[0]],
                outputs=[input_fp32],
            )
            # Re-materialize scale/bias as fp32 constants (deepcopy so the
            # original fp16 initializers are not mutated in place).
            scale = np.array(
                deepcopy(node.inputs[1].values.tolist()), dtype=np.float32
            )
            bias = np.array(
                deepcopy(node.inputs[2].values.tolist()), dtype=np.float32
            )
            scale_constant = gs.Constant(
                node.inputs[1].name + "_fp32",
                np.ascontiguousarray(scale.reshape(-1)),
            )
            bias_constant = gs.Constant(
                node.inputs[2].name + "_fp32",
                np.ascontiguousarray(bias.reshape(-1)),
            )
            node.inputs = [input_fp32, scale_constant, bias_constant]
            # Reroute the original output through a fp16 cast so downstream
            # consumers keep seeing the tensor they were wired to.
            out = node.outputs[0]
            output_fp32 = gs.Variable(name=node.name + "_output_tensor_fp32")
            node.outputs = [output_fp32]
            cast_fp16_node = gs.Node(
                op="Cast",
                name=node.name + "_cast_to_fp16",
                attrs={"to": onnx.TensorProto.FLOAT16},
                inputs=[output_fp32],
                outputs=[out],
            )
            graph.nodes.append(cast_fp16_node)
            graph.nodes.append(cast_fp32_node)
    graph.cleanup().toposort()
    onnx.save(
        gs.export_onnx(graph.cleanup()),
        str(save_to),
        save_as_external_data=True,
        all_tensors_to_one_file=False,
        location=str(save_to.parent),
    )
    logger.info(f"Postprocessing ONNX model finished: {save_to}")
def fuse_attn(self):
    """Fuse attention subgraphs into ``fMHAPlugin`` nodes.

    Folds constants in the post-processed model (``self.onnx_modify``),
    then scans for the attention signature: a Softmax whose producer is a
    MatMul and whose consumer chain is MatMul -> Transpose.  Each match is
    replaced by a single ``fMHAPlugin`` node taking (q, k, v).

    Two layouts are handled:
      - ``attn1`` (self-attention): extra Transpose nodes are inserted so
        q and v reach the plugin with perm [0, 2, 1, 3], matching k.
      - ``attn2`` (cross-attention): q/k/v are taken directly from the
        inputs of the existing Transpose nodes.

    Nodes whose name contains ``pooler`` are left untouched.  The fused
    graph is written to ``self.onnx_fmha``.
    """
    load_from = self.onnx_modify
    save_to = self.onnx_fmha
    logger.info(f"FuseAttn ONNX model {load_from}...")
    onnx_graph = polygraphy.backend.onnx.loader.fold_constants(
        onnx.load(str(load_from)),
        allow_onnxruntime_shape_inference=True,
    )
    graph = gs.import_onnx(onnx_graph)
    cnt = 0
    for node in graph.nodes:
        if (
            node.op == "Softmax"
            and node.i().op == "MatMul"
            and node.o().op == "MatMul"
            and node.o().o().op == "Transpose"
        ):
            if "pooler" in node.name:
                continue
            if "attn1" in node.name:
                matmul_0 = node.i()
                transpose = matmul_0.i(1, 0)
                transpose.attrs["perm"] = [0, 2, 1, 3]
                k = transpose.outputs[0]
                q = gs.Variable(
                    "transpose_0_v_{}".format(cnt), np.dtype(np.float16)
                )
                transpose_0 = gs.Node(
                    "Transpose",
                    "Transpose_0_{}".format(cnt),
                    attrs={"perm": [0, 2, 1, 3]},
                    inputs=[matmul_0.inputs[0]],
                    outputs=[q],
                )
                graph.nodes.append(transpose_0)
                matmul_1 = node.o()
                v = gs.Variable(
                    "transpose_1_v_{}".format(cnt), np.dtype(np.float16)
                )
                transpose_1 = gs.Node(
                    "Transpose",
                    "Transpose_1_{}".format(cnt),
                    attrs={"perm": [0, 2, 1, 3]},
                    inputs=[matmul_1.inputs[1]],
                    outputs=[v],
                )
                graph.nodes.append(transpose_1)
                output_variable = node.o().o().outputs[0]
                fMHA = gs.Node(
                    "fMHAPlugin",
                    "fMHAPlugin_1_{}".format(cnt),
                    inputs=[q, k, v],
                    outputs=[output_variable],
                )
                graph.nodes.append(fMHA)
                # Detach the old Transpose so cleanup() removes the dead
                # MatMul/Softmax/MatMul chain.
                node.o().o().outputs = []
                cnt = cnt + 1
            elif "attn2" in node.name:
                matmul_0 = node.i()
                transpose_q = matmul_0.i()
                transpose_k = matmul_0.i(1, 0)
                matmul_1 = node.o()
                transpose_v = matmul_1.i(1, 0)
                q = transpose_q.inputs[0]
                k = transpose_k.inputs[0]
                v = transpose_v.inputs[0]
                output_variable = node.o().o().outputs[0]
                fMHA = gs.Node(
                    "fMHAPlugin",
                    "fMHAPlugin_1_{}".format(cnt),
                    inputs=[q, k, v],
                    outputs=[output_variable],
                )
                graph.nodes.append(fMHA)
                node.o().o().outputs = []
                cnt = cnt + 1
    # BUGFIX: the original `logger.info("mha count: ", cnt)` passed `cnt` as a
    # %-style format argument to a string with no placeholder, so the count
    # never appeared in the log (and a stray `print` duplicated the message).
    logger.info("mha count: %d", cnt)
    onnx.save(
        gs.export_onnx(graph.cleanup()),
        str(save_to),
        save_as_external_data=True,
    )
    logger.info(f"FuseAttn ONNX model finished: {save_to}")
if __name__ == "__main__":
    # Parse CLI options and validate the model root before exporting.
    cli_args = get_args()
    root = Path(cli_args.model_root)
    if not root.exists():
        raise ValueError(f"`models_root` not exists: {root}")
    # Run the full pipeline: raw export -> fp32 LayerNorm patch -> MHA fusion.
    onnx_exporter = ExportONNX(cli_args, root)
    onnx_exporter.export()
    onnx_exporter.postprocessing()
    onnx_exporter.fuse_attn()
# ==============================================================================
# Description: Install TensorRT and prepare the environment for TensorRT.
# ==============================================================================

# ----------------------------------------
# Check the system, tools and arguments
# ----------------------------------------
# Check system. Only support TensorRT on Linux (MacOS is not supported.)
if [ "$(uname)" != "Linux" ]; then
    echo "Only support TensorRT on Linux"
    exit 1
fi

# Check if the model_trt path is provided. If not, use the default path.
if [ -z "$1" ]; then
    MODEL_TRT_DIR=$(cd ckpts/t2i/model_trt; pwd)
else
    MODEL_TRT_DIR=$(cd "$1"; pwd)
fi

# Check if the model_trt path exists.
if [ ! -d "${MODEL_TRT_DIR}" ]; then
    echo "The model_trt directory (${MODEL_TRT_DIR}) does not exist. Please specify the path by:"
    echo " sh trt/install.sh <model_trt_dir>"
    exit 1
fi

# Check if ldconfig exists (required to register the TensorRT libraries).
if [ ! -x "$(command -v ldconfig)" ]; then
    echo "ldconfig is not installed. Please install it first."
    exit 1
fi

export TENSORRT_VERSION='10.1.0.27'
TENSORRT_PACKAGE="${MODEL_TRT_DIR}/TensorRT-${TENSORRT_VERSION}.tar.gz"

# Check if the TensorRT package is downloaded.
if [ ! -f "${TENSORRT_PACKAGE}" ]; then
    echo "The TensorRT package (${TENSORRT_PACKAGE}) does not exist. Please download it first with following steps:"
    echo "1. cd HunyuanDiT"
    echo "2. huggingface-cli download Tencent-Hunyuan/HunyuanDiT-TensorRT --local-dir ./ckpts/t2i/model_trt"
    exit 1
else
    echo "Found TensorRT package: ${TENSORRT_PACKAGE}"
fi

# ----------------------------------------
# Start to install TensorRT
# ----------------------------------------
# Extract the TensorRT package.
echo "Extracting the TensorRT package..."
tar xf "${TENSORRT_PACKAGE}" -C "${MODEL_TRT_DIR}"
TENSORRT_DIR="${MODEL_TRT_DIR}/TensorRT-${TENSORRT_VERSION}"
echo "Extracting the TensorRT package finished"

# Add the TensorRT library path to the system library path.
# BUGFIX: the shared libraries live in the extracted tree at
# ${TENSORRT_DIR}/lib/ (as activate.sh below also assumes), not at
# ${MODEL_TRT_DIR}/lib/, which does not exist.
echo "${TENSORRT_DIR}/lib/" >> /etc/ld.so.conf.d/nvidia.conf && ldconfig

# Install the TensorRT Python wheel.
echo "Installing the TensorRT Python wheel..."
# Get python version, e.g., cp38 for Python 3.8; cp310 for Python 3.10
PYTHON_VERSION=$(python -c 'import sys; print(f"cp{sys.version_info.major}{sys.version_info.minor}")')
# The wheel glob must stay unquoted so the shell expands it.
python -m pip install --no-cache-dir ${TENSORRT_DIR}/python/tensorrt*-${PYTHON_VERSION}*
echo "Installing the TensorRT Python wheel finished"

# Prepare activate.sh and deactivate.sh next to this script. The single-quoted
# lines are written literally and expand TENSORRT_DIR when sourced.
{
    echo "TENSORRT_DIR=${TENSORRT_DIR}"
    echo 'export LD_LIBRARY_PATH=${TENSORRT_DIR}/lib/:$LD_LIBRARY_PATH'
    echo 'export LIBRARY_PATH=${TENSORRT_DIR}/lib/:$LIBRARY_PATH'
    echo 'export PATH=${TENSORRT_DIR}/bin/:$PATH'
} > $(dirname "$0")/activate.sh
{
    echo "TENSORRT_DIR=${TENSORRT_DIR}"
    echo 'export LD_LIBRARY_PATH=${LD_LIBRARY_PATH/${TENSORRT_DIR}\/lib\/:}'
    echo 'export LIBRARY_PATH=${LIBRARY_PATH/${TENSORRT_DIR}\/lib\/:}'
    echo 'export PATH=${PATH/${TENSORRT_DIR}\/bin\/:}'
} > $(dirname "$0")/deactivate.sh
# Copyright (c) OpenMMLab. All rights reserved.
"""This file holding some environment constant for sharing by other files."""
import os
import os.path as osp
import subprocess
import sys
from collections import OrderedDict, defaultdict
import numpy as np
import torch
def is_rocm_pytorch() -> bool:
    """Check whether the PyTorch is compiled on ROCm.

    Returns:
        bool: True only when this torch build reports a HIP version and
        ``ROCM_HOME`` resolves to an actual ROCm install.
    """
    is_rocm = False
    if TORCH_VERSION != "parrots":
        try:
            from torch.utils.cpp_extension import ROCM_HOME

            # Replaced the redundant `True if (...) else False` ternary with
            # the boolean expression itself.
            is_rocm = torch.version.hip is not None and ROCM_HOME is not None
        except ImportError:
            pass
    return is_rocm


# Version string of the active torch backend ("parrots" on Parrots builds).
TORCH_VERSION = torch.__version__
def get_build_config():
    """Obtain the build information of PyTorch or Parrots."""
    if TORCH_VERSION != "parrots":
        return torch.__config__.show()
    # Parrots backend exposes its build info through its own config module.
    from parrots.config import get_build_info

    return get_build_info()
# Probe for the MUSA backend once at import time; any failure (missing
# package, broken install) simply means MUSA is unavailable.
try:
    import torch_musa  # noqa: F401
except Exception:
    IS_MUSA_AVAILABLE = False
else:
    IS_MUSA_AVAILABLE = True


def is_musa_available() -> bool:
    """Return True if the torch_musa backend could be imported."""
    return IS_MUSA_AVAILABLE
def is_cuda_available() -> bool:
    """Returns True if cuda devices exist."""
    # Thin wrapper so callers depend on this module, not on torch directly.
    return bool(torch.cuda.is_available())
def _get_cuda_home():
    """Resolve the toolkit home directory (CUDA, ROCm, or Parrots)."""
    if TORCH_VERSION == "parrots":
        from parrots.utils.build_extension import CUDA_HOME

        return CUDA_HOME
    if is_rocm_pytorch():
        # ROCm builds expose the toolkit root as ROCM_HOME.
        from torch.utils.cpp_extension import ROCM_HOME

        return ROCM_HOME
    from torch.utils.cpp_extension import CUDA_HOME

    return CUDA_HOME
def _get_musa_home():
return os.environ.get("MUSA_HOME")
def collect_env():
    """Collect the information of the running environments.

    Returns:
        dict: The environment information. The following fields are contained.

            - sys.platform: The variable of ``sys.platform``.
            - Python: Python version.
            - CUDA available: Bool, indicating if CUDA is available.
            - GPU devices: Device type of each GPU.
            - CUDA_HOME (optional): The env var ``CUDA_HOME``.
            - NVCC (optional): NVCC version.
            - GCC: GCC version, "n/a" if GCC is not installed.
            - MSVC: Microsoft Virtual C++ Compiler version, Windows only.
            - PyTorch: PyTorch version.
            - PyTorch compiling details: The output of \
                ``torch.__config__.show()``.
            - TorchVision (optional): TorchVision version.
            - OpenCV (optional): OpenCV version.
    """
    # Imported lazily: only needed for the MSVC probe's exception type below.
    from distutils import errors

    env_info = OrderedDict()
    env_info["sys.platform"] = sys.platform
    env_info["Python"] = sys.version.replace("\n", "")
    cuda_available = is_cuda_available()
    musa_available = is_musa_available()
    env_info["CUDA available"] = cuda_available
    env_info["MUSA available"] = musa_available
    # Record the current numpy RNG seed so runs can be reproduced.
    env_info["numpy_random_seed"] = np.random.get_state()[1][0]
    if cuda_available:
        # Group GPU indices by device name, e.g. "GPU 0,1": "A100".
        devices = defaultdict(list)
        for k in range(torch.cuda.device_count()):
            devices[torch.cuda.get_device_name(k)].append(str(k))
        for name, device_ids in devices.items():
            env_info["GPU " + ",".join(device_ids)] = name
        CUDA_HOME = _get_cuda_home()
        env_info["CUDA_HOME"] = CUDA_HOME
        if CUDA_HOME is not None and osp.isdir(CUDA_HOME):
            if CUDA_HOME == "/opt/rocm":
                # ROCm install: query hipcc instead of nvcc.
                try:
                    nvcc = osp.join(CUDA_HOME, "hip/bin/hipcc")
                    nvcc = subprocess.check_output(f'"{nvcc}" --version', shell=True)
                    nvcc = nvcc.decode("utf-8").strip()
                    release = nvcc.rfind("HIP version:")
                    # rfind("") == len(nvcc), i.e. slice runs to end of output.
                    build = nvcc.rfind("")
                    nvcc = nvcc[release:build].strip()
                except subprocess.SubprocessError:
                    nvcc = "Not Available"
            else:
                try:
                    nvcc = osp.join(CUDA_HOME, "bin/nvcc")
                    nvcc = subprocess.check_output(f'"{nvcc}" -V', shell=True)
                    nvcc = nvcc.decode("utf-8").strip()
                    # Keep only the release/build portion of `nvcc -V` output.
                    release = nvcc.rfind("Cuda compilation tools")
                    build = nvcc.rfind("Build ")
                    nvcc = nvcc[release:build].strip()
                except subprocess.SubprocessError:
                    nvcc = "Not Available"
            env_info["NVCC"] = nvcc
    elif musa_available:
        devices = defaultdict(list)
        for k in range(torch.musa.device_count()):
            devices[torch.musa.get_device_name(k)].append(str(k))
        for name, device_ids in devices.items():
            env_info["GPU " + ",".join(device_ids)] = name
        MUSA_HOME = _get_musa_home()
        env_info["MUSA_HOME"] = MUSA_HOME
        if MUSA_HOME is not None and osp.isdir(MUSA_HOME):
            try:
                mcc = osp.join(MUSA_HOME, "bin/mcc")
                # NOTE(review): the probe output is discarded; only the path
                # (or "Not Available" on failure) is recorded.
                subprocess.check_output(f'"{mcc}" -v', shell=True)
            except subprocess.SubprocessError:
                mcc = "Not Available"
            env_info["mcc"] = mcc
    try:
        # Check C++ Compiler.
        # For Unix-like, sysconfig has 'CC' variable like 'gcc -pthread ...',
        # indicating the compiler used, we use this to get the compiler name
        import io
        import sysconfig

        cc = sysconfig.get_config_var("CC")
        if cc:
            cc = osp.basename(cc.split()[0])
            cc_info = subprocess.check_output(f"{cc} --version", shell=True)
            env_info["GCC"] = cc_info.decode("utf-8").partition("\n")[0].strip()
        else:
            # on Windows, cl.exe is not in PATH. We need to find the path.
            # distutils.ccompiler.new_compiler() returns a msvccompiler
            # object and after initialization, path to cl.exe is found.
            import locale
            import os

            from distutils.ccompiler import new_compiler

            ccompiler = new_compiler()
            ccompiler.initialize()
            # cl.exe prints its banner (version line) to stderr.
            cc = subprocess.check_output(
                f"{ccompiler.cc}", stderr=subprocess.STDOUT, shell=True
            )
            encoding = (
                os.device_encoding(sys.stdout.fileno()) or locale.getpreferredencoding()
            )
            env_info["MSVC"] = cc.decode(encoding).partition("\n")[0].strip()
            env_info["GCC"] = "n/a"
    except (subprocess.CalledProcessError, errors.DistutilsPlatformError):
        env_info["GCC"] = "n/a"
    except io.UnsupportedOperation as e:
        # JupyterLab on Windows changes sys.stdout, which has no `fileno` attr
        # Refer to: https://github.com/open-mmlab/mmengine/issues/931
        # TODO: find a solution to get compiler info in Windows JupyterLab,
        # while preserving backward-compatibility in other systems.
        env_info["MSVC"] = f"n/a, reason: {str(e)}"
    env_info["PyTorch"] = torch.__version__
    env_info["PyTorch compiling details"] = get_build_config()
    try:
        import torchvision

        env_info["TorchVision"] = torchvision.__version__
    except ModuleNotFoundError:
        pass
    try:
        import cv2

        env_info["OpenCV"] = cv2.__version__
    except ImportError:
        pass
    return env_info
if __name__ == "__main__":
    # Dump the collected environment, one "key: value" line per entry.
    for key, value in collect_env().items():
        print(f"{key}: {value}")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment