import copy
import functools
import json
import os
import sys
import time
from collections import defaultdict
from pathlib import Path
from typing import Optional
import safetensors
import torch
import torch.nn as nn
from tqdm import tqdm
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
from transformers.pytorch_utils import Conv1D
from ..._utils import pad_vocab_size, release_gc
from ...layers import MoeConfig
from ...logger import logger
from ...mapping import Mapping
from ...quantization import QuantAlgo
from ..convert_utils import load_calib_dataset
from ..modeling_utils import PretrainedConfig, QuantConfig, optimize_model
from .weight import load_from_hf_checkpoint, load_from_hf_safetensors
try:
from transformers import (
LlavaConfig,
LlavaForConditionalGeneration,
LlavaNextConfig,
LlavaNextForConditionalGeneration,
)
except ImportError:
pass
try:
pass
except ImportError:
pass
def generate_int8(weights, act_range, is_qkv=False, multi_query_mode=False):
"""
This function has two purposes:
- compute quantized weights, scaled either per-tensor or per-column
- compute scaling factors
Depending on the GEMM API (CUTLASS/CUBLAS) the required scaling factors differ.
CUTLASS uses two sets of scaling factors. One for the activation X, one for the weight W.
    CUBLAS only has one (we can't do per-row scaling), so we must provide a pre-multiplied scaling factor.
Here is the list of what we need (T means per-tensor, C per-column):
- scale_x_orig_quant puts fp activation into the quantized range (i.e. [-128, 127], for int8). Used before the GEMM. (T)
- scale_y_quant_orig puts quantized activation into the fp range. Used if the GEMM outputs int8. (T)
- scale_w_quant_orig puts weights from quant range to fp range (used with CUTLASS) (T, C)
- scale_y_accum_quant puts the GEMM result (XW) from accumulation range (int32)
to quant range (int8) (used for CUBLAS) (T, C)
Note that we don't do anything special about row-parallel GEMM. Theoretically, we could have per-GPU scaling factors too,
but then the model would change depending on the number of GPUs used.
For QKV projection, the behavior is special. Even if we have a single matrix to perform QKV projection, we consider it
as three different matrices: Q, K, and V. So per-tensor actually means one scaling factor for each Q, K and V.
For our GEMM implementation to respect this behavior, we use per-column mode and replicate values along columns.
"""
# compute weight scaling factors for fp->int8 and int8->fp
if is_qkv and not multi_query_mode:
scale_w_orig_quant_t = (
127.0 / act_range["w"].reshape(3, -1).max(dim=-1, keepdims=True)[0]
)
scale_w_orig_quant_c = 127.0 / act_range["w"].reshape(3, -1)
elif is_qkv and multi_query_mode:
hidden_dim = weights.shape[0]
local_dim = act_range["w"].shape[0]
kv_dim = (local_dim - hidden_dim) // 2
scale_w_q = act_range["w"][0:hidden_dim]
scale_w_k = act_range["w"][hidden_dim : hidden_dim + kv_dim]
scale_w_v = act_range["w"][-kv_dim:]
scale_w_qkv_t = torch.concat(
[
scale_w_q.max(dim=0, keepdim=True)[0],
scale_w_k.max(dim=0, keepdim=True)[0],
scale_w_v.max(dim=0, keepdim=True)[0],
]
)
scale_w_orig_quant_t = 127.0 / scale_w_qkv_t
scale_w_orig_quant_c = 127.0 / act_range["w"]
else:
scale_w_orig_quant_t = 127.0 / act_range["w"].max()
scale_w_orig_quant_c = 127.0 / act_range["w"]
scale_w_quant_orig_t = 1.0 / scale_w_orig_quant_t
scale_w_quant_orig_c = 1.0 / scale_w_orig_quant_c
scale_w_orig_quant_c = scale_w_orig_quant_c.to(torch.float32)
scale_w_orig_quant_t = scale_w_orig_quant_t.to(torch.float32)
# compute the rest of needed scaling factors
scale_x_orig_quant_t = 127.0 / act_range["x"].max()
scale_y_orig_quant_t = 127.0 / act_range["y"].max()
scale_y_quant_orig_t = act_range["y"].max() / 127.0
scale_y_accum_quant_t = scale_y_orig_quant_t / (
scale_x_orig_quant_t * scale_w_orig_quant_t
)
scale_y_accum_quant_c = scale_y_orig_quant_t / (
scale_x_orig_quant_t * scale_w_orig_quant_c
)
if is_qkv and not multi_query_mode:
scale_y_accum_quant_t = torch.broadcast_to(
scale_y_accum_quant_t, scale_w_orig_quant_c.shape
)
scale_w_quant_orig_t = torch.broadcast_to(
scale_w_quant_orig_t, scale_w_orig_quant_c.shape
)
if is_qkv and multi_query_mode:
scale_q_y_accum_t = torch.broadcast_to(
scale_y_accum_quant_t[0], scale_w_q.shape
)
scale_k_y_accum_t = torch.broadcast_to(
scale_y_accum_quant_t[1], scale_w_k.shape
)
scale_v_y_accum_t = torch.broadcast_to(
scale_y_accum_quant_t[2], scale_w_v.shape
)
scale_y_accum_quant_t = torch.concat(
[scale_q_y_accum_t, scale_k_y_accum_t, scale_v_y_accum_t]
)
scale_w_quant_orig_t = torch.concat(
[
torch.broadcast_to(scale_w_quant_orig_t[0], scale_w_q.shape),
torch.broadcast_to(scale_w_quant_orig_t[1], scale_w_k.shape),
torch.broadcast_to(scale_w_quant_orig_t[2], scale_w_v.shape),
]
)
to_i8 = lambda x: x.round().clip(-127, 127).to(torch.int8)
if is_qkv and multi_query_mode:
weight_int8 = to_i8(weights / scale_w_quant_orig_t)
else:
weight_int8 = to_i8(weights * scale_w_orig_quant_t)
return {
"weight.int8": weight_int8,
"weight.int8.col": to_i8(weights * scale_w_orig_quant_c),
"scale_x_orig_quant": scale_x_orig_quant_t.to(torch.float32),
"scale_w_quant_orig": scale_w_quant_orig_t.to(torch.float32),
"scale_w_quant_orig.col": scale_w_quant_orig_c.to(torch.float32),
"scale_y_accum_quant": scale_y_accum_quant_t.to(torch.float32),
"scale_y_accum_quant.col": scale_y_accum_quant_c.to(torch.float32),
"scale_y_quant_orig": scale_y_quant_orig_t.to(torch.float32),
}
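# Hedged usage sketch (not part of the conversion flow): generate_int8 expects the weight
# transposed to [in_features, out_features] and a per-module dict of amax statistics with
# "x"/"y"/"w" keys, such as one entry of the act_scales returned by capture_activation_range
# below. All names and shapes here are toy values chosen for illustration only.
def _example_generate_int8():
    w_t = torch.randn(16, 32)                # [in, out] toy weight, already transposed
    act_range = {
        "x": torch.rand(16) + 0.5,           # per-channel amax of the GEMM input
        "y": torch.rand(32) + 0.5,           # per-channel amax of the GEMM output
        "w": w_t.abs().max(dim=0)[0],        # per-column amax of the weight
    }
    vals = generate_int8(w_t, act_range)
    assert vals["weight.int8"].dtype == torch.int8
    # the per-tensor activation scale is simply 127 / amax(x)
    torch.testing.assert_close(vals["scale_x_orig_quant"], 127.0 / act_range["x"].max())
    return vals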
@torch.no_grad()
def apply_smoothing(
scales,
gemm_weights,
layernorm_weights=None,
layernorm_bias=None,
dtype=torch.float32,
layernorm_1p=False,
):
if not isinstance(gemm_weights, list):
gemm_weights = [gemm_weights]
if layernorm_weights is not None:
assert layernorm_weights.numel() == scales.numel()
layernorm_weights.div_(scales).to(dtype)
if layernorm_bias is not None:
assert layernorm_bias.numel() == scales.numel()
layernorm_bias.div_(scales).to(dtype)
if layernorm_1p:
layernorm_weights += (1 / scales) - 1
for gemm in gemm_weights:
gemm.mul_(scales.view(1, -1)).to(dtype)
@torch.no_grad()
def smooth_gemm(
gemm_weights,
act_scales,
layernorm_weights=None,
layernorm_bias=None,
alpha=0.5,
weight_scales=None,
):
if not isinstance(gemm_weights, list):
gemm_weights = [gemm_weights]
orig_dtype = gemm_weights[0].dtype
for gemm in gemm_weights:
# gemm_weights are expected to be transposed
assert gemm.shape[1] == act_scales.numel()
if weight_scales is None:
weight_scales = torch.cat(
[gemm.abs().max(dim=0, keepdim=True)[0] for gemm in gemm_weights], dim=0
)
weight_scales = weight_scales.max(dim=0)[0]
weight_scales.to(float).clamp(min=1e-5)
scales = (
act_scales.to(gemm_weights[0].device).to(float).pow(alpha)
/ weight_scales.pow(1 - alpha)
).clamp(min=1e-5)
apply_smoothing(scales, gemm_weights, layernorm_weights, layernorm_bias, orig_dtype)
return scales
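# Hedged illustration of the SmoothQuant scales computed above (alpha balances activation
# versus weight outliers, see https://arxiv.org/abs/2211.10438). The helper below is a toy
# sketch under assumed shapes, not used by the converter: it checks that smoothing is a pure
# reparameterization, i.e. (x / s) @ (w * s).T == x @ w.T up to floating-point error.
def _example_smooth_gemm(alpha=0.5):
    w = torch.randn(8, 4)                # [out_features, in_features], like an nn.Linear weight
    w_ref = w.clone()
    x_amax = torch.rand(4) + 0.1         # per-input-channel activation amax (the "x" statistics)
    s = smooth_gemm(w, x_amax, alpha=alpha).float()  # w is scaled in place by s along dim 1
    x = torch.randn(2, 4)
    torch.testing.assert_close((x / s) @ w.t(), x @ w_ref.t(), rtol=1e-4, atol=1e-4)
    return s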
@torch.no_grad()
def smooth_gemm_fc1_gate(
fc1_weights,
gate_weights,
act_scales,
layernorm_weights=None,
layernorm_bias=None,
alpha=0.5,
weight_scales=None,
):
gemm_weights = []
if not isinstance(fc1_weights, list):
fc1_weights = [fc1_weights]
if not isinstance(gate_weights, list):
gate_weights = [gate_weights]
for i in range(len(fc1_weights)):
gemm_weight = torch.cat([fc1_weights[i], gate_weights[i]], dim=0)
gemm_weights.append(gemm_weight)
orig_dtype = gemm_weights[0].dtype
for gemm in gemm_weights:
# gemm_weights are expected to be transposed
assert gemm.shape[1] == act_scales.numel()
if weight_scales is None:
weight_scales = torch.cat(
[gemm.abs().max(dim=0, keepdim=True)[0] for gemm in gemm_weights], dim=0
)
weight_scales = weight_scales.max(dim=0)[0]
weight_scales.to(float).clamp(min=1e-5)
scales = (
act_scales.to(gemm_weights[0].device).to(float).pow(alpha)
/ weight_scales.pow(1 - alpha)
).clamp(min=1e-5)
apply_smoothing(
scales,
fc1_weights + gate_weights,
layernorm_weights,
layernorm_bias,
orig_dtype,
)
return scales
@torch.no_grad()
def smooth_llama_model(model, scales, alpha, llama_qkv_para, llama_smoother):
# Smooth the activation and weights with smoother = $\diag{s}$
for name, module in model.named_modules():
if not isinstance(
module, LlamaDecoderLayer
) and not module.__class__.__name__ in [
"InternLMDecoderLayer",
"MistralDecoderLayer",
]:
continue
# qkv_proj
layer_name_q = name + ".self_attn.q_proj"
layer_name_k = name + ".self_attn.k_proj"
layer_name_v = name + ".self_attn.v_proj"
layer_name_qkv = name + ".self_attn.qkv_proj"
weight = torch.cat(
[
module.self_attn.q_proj.weight,
module.self_attn.k_proj.weight,
module.self_attn.v_proj.weight,
],
dim=0,
)
smoother = smooth_gemm(
weight,
scales[layer_name_q]["x"],
module.input_layernorm.weight,
None,
alpha,
)
scales[layer_name_qkv]["x"] = scales[layer_name_q]["x"] / smoother
scales[layer_name_qkv]["w"] = weight.abs().max(dim=1)[0]
scales[layer_name_qkv]["y"] = torch.cat(
[
scales[layer_name_q]["y"],
scales[layer_name_k]["y"],
scales[layer_name_v]["y"],
],
dim=0,
)
# see transpose_weights function
llama_qkv_para[layer_name_qkv] = weight.transpose(0, 1)
# =================================================================
layer_name = name + ".self_attn.o_proj"
smoother = smooth_gemm(
module.self_attn.o_proj.weight, scales[layer_name]["x"], None, None, alpha
)
llama_smoother[layer_name] = smoother.float()
scales[layer_name]["x"] = scales[layer_name]["x"] / smoother
scales[layer_name]["w"] = module.self_attn.o_proj.weight.abs().max(dim=1)[0]
# ==================================================================
fc1_layer_name = name + ".mlp.gate_proj"
gate_layer_name = name + ".mlp.up_proj"
smoother = smooth_gemm_fc1_gate(
module.mlp.gate_proj.weight,
module.mlp.up_proj.weight,
scales[fc1_layer_name]["x"],
module.post_attention_layernorm.weight,
None,
alpha,
)
scales[fc1_layer_name]["x"] = scales[fc1_layer_name]["x"] / smoother
scales[fc1_layer_name]["w"] = module.mlp.gate_proj.weight.abs().max(dim=1)[0]
scales[gate_layer_name]["x"] = scales[gate_layer_name]["x"] / smoother
scales[gate_layer_name]["w"] = module.mlp.up_proj.weight.abs().max(dim=1)[0]
# ==================================================================
layer_name = name + ".mlp.down_proj"
smoother = smooth_gemm(
module.mlp.down_proj.weight, scales[layer_name]["x"], None, None, alpha
)
llama_smoother[layer_name] = smoother.float()
scales[layer_name]["x"] = scales[layer_name]["x"] / smoother
scales[layer_name]["w"] = module.mlp.down_proj.weight.abs().max(dim=1)[0]
# ==================================================================
if hasattr(module, "residual_mlp"):
fc1_layer_name = name + ".residual_mlp.w1"
gate_layer_name = name + ".residual_mlp.w3"
smoother = smooth_gemm_fc1_gate(
module.residual_mlp.w1.weight,
module.residual_mlp.w3.weight,
scales[fc1_layer_name]["x"],
module.residual_layernorm.weight,
None,
alpha,
)
scales[fc1_layer_name]["x"] = scales[fc1_layer_name]["x"] / smoother
scales[fc1_layer_name]["w"] = module.residual_mlp.w1.weight.abs().max(
dim=1
)[0]
scales[gate_layer_name]["x"] = scales[gate_layer_name]["x"] / smoother
scales[gate_layer_name]["w"] = module.residual_mlp.w3.weight.abs().max(
dim=1
)[0]
# ==================================================================
layer_name = name + ".residual_mlp.w2"
smoother = smooth_gemm(
module.residual_mlp.w2.weight,
scales[layer_name]["x"],
None,
None,
alpha,
)
llama_smoother[layer_name] = smoother.float()
scales[layer_name]["x"] = scales[layer_name]["x"] / smoother
scales[layer_name]["w"] = module.residual_mlp.w2.weight.abs().max(dim=1)[0]
@torch.no_grad()
def capture_activation_range(model, tokenizer, dataset, num_samples=512, seq_len=512):
model.eval()
device = next(model.parameters()).device
act_scales = defaultdict(lambda: {"x": None, "y": None, "w": None})
tokenizer.pad_token = tokenizer.eos_token
def stat_tensor(name, tensor, act_scales, key):
hidden_dim = tensor.shape[-1]
tensor = tensor.view(-1, hidden_dim).abs().detach()
        coming_max = torch.max(tensor, dim=0)[0].float()
        if act_scales[name][key] is None:
            act_scales[name][key] = coming_max
        else:
            act_scales[name][key] = torch.max(act_scales[name][key], coming_max)
def stat_input_hook(m, x, y, name):
if isinstance(x, tuple):
x = x[0]
stat_tensor(name, x, act_scales, "x")
stat_tensor(name, y, act_scales, "y")
if act_scales[name]["w"] is None:
act_scales[name]["w"] = m.weight.abs().clip(1e-8, None).max(dim=1)[0]
hooks = []
for name, m in model.named_modules():
if isinstance(m, nn.Linear) or isinstance(m, Conv1D):
hooks.append(
m.register_forward_hook(functools.partial(stat_input_hook, name=name))
)
for i in tqdm(range(num_samples), desc="calibrating model"):
datapoint = dataset[i : i + 1]
line = copy.copy(datapoint)
line[0] = line[0] + " TL;DR: "
line[0] = line[0].strip()
line[0] = line[0].replace(" n't", "n't")
input_ids = tokenizer(
line, return_tensors="pt", max_length=seq_len, padding=True, truncation=True
).input_ids.to(device)
model(input_ids)
for h in hooks:
h.remove()
return act_scales
def split(v, tp_size, idx, dim=0):
if tp_size == 1:
return v
if len(v.shape) == 1:
return torch.chunk(v, tp_size)[idx].contiguous()
else:
return torch.chunk(v, tp_size, dim=dim)[idx]
def split_qkv_tp(v, n_head, n_hidden, tensor_parallel, rank):
"""
Splits the QKV matrix according to tensor parallelism
"""
v = v.reshape(3, n_hidden, n_hidden)
split_v = split(v, tensor_parallel, rank, dim=1)
split_v = split_v.reshape(3 * (n_hidden // tensor_parallel), n_hidden)
return split_v
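# Hedged sanity sketch (illustration only): for MHA the fused QKV weight [3 * hidden, hidden]
# is viewed as [3, hidden, hidden], each of Q/K/V is split along its output dimension across
# ranks, and every rank ends up with a [3 * hidden // tp, hidden] shard.
def _example_split_qkv_tp():
    n_head, hidden, tp = 8, 32, 4
    qkv = torch.randn(3 * hidden, hidden)
    shards = [split_qkv_tp(qkv, n_head, hidden, tp, r) for r in range(tp)]
    assert all(s.shape == (3 * hidden // tp, hidden) for s in shards)
    return shards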
def split_qkv_bias_tp(v, n_head, n_hidden, tensor_parallel, rank):
"""
Splits the QKV bias according to tensor parallelism
"""
v = v.reshape(3, n_hidden)
split_v = split(v, tensor_parallel, rank, dim=1)
split_v = split_v.reshape(3 * (n_hidden // tensor_parallel))
return split_v
def split_matrix_tp(v, tensor_parallel, rank, dim):
return split(v, tensor_parallel, rank, dim=dim)
def get_weight(config, prefix, dtype):
if config[prefix + ".weight"].dtype != dtype:
config[prefix + ".weight"].data = config[prefix + ".weight"].to(dtype)
return config[prefix + ".weight"].detach()
def get_bias(config, prefix, dtype):
if config[prefix + ".bias"].dtype != dtype:
config[prefix + ".bias"].data = config[prefix + ".bias"].to(dtype)
return config[prefix + ".bias"].detach()
def get_weight_and_bias(config, prefix, dtype):
return get_weight(config, prefix, dtype), get_bias(config, prefix, dtype)
def get_tllm_linear_weight(
weight,
prefix,
bias=None,
use_weight_only=False,
plugin_weight_only_quant_type=torch.int8,
dtype="float32",
use_gemm_woq_plugin=True,
postfix="weight",
quant_scale_name=None,
):
results = {}
if use_weight_only:
if weight.dim() > 2:
v = weight.transpose(1, 2).contiguous()
else:
v = weight.t().contiguous()
processed_torch_weights, torch_weight_scales = (
torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix(
v.cpu(), plugin_weight_only_quant_type
)
)
if not use_gemm_woq_plugin:
results[prefix + postfix] = v.to(dtype)
else:
results[prefix + postfix] = processed_torch_weights
if quant_scale_name is not None:
results[quant_scale_name] = torch_weight_scales
else:
results[prefix + "per_channel_scale"] = torch_weight_scales
else:
results[prefix + postfix] = weight
if bias is not None:
results[prefix + "bias"] = bias
return results
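# Hedged sketch of the unquantized path (illustration only): with use_weight_only=False the
# helper simply records the already TP-split weight (and optional bias) under TRT-LLM key
# names; the weight-only path additionally calls the trtllm symmetric-quantization op, which
# requires the TensorRT-LLM extension to be loaded, so it is not exercised here.
def _example_get_tllm_linear_weight():
    w = torch.randn(16, 32)
    out = get_tllm_linear_weight(w, "transformer.layers.0.mlp.fc.", bias=None)
    assert set(out) == {"transformer.layers.0.mlp.fc.weight"}
    return out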
def dup_kv_weight(v, num_head, tp_size):
assert tp_size % num_head == 0
reps = tp_size // num_head
head_size = v.shape[0] // num_head
v = v.reshape(num_head, head_size, -1)[:, None, :, :].expand(
num_head, reps, head_size, v.shape[1]
)
return v.reshape(num_head * reps * head_size, -1).clone().detach()
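# Hedged sketch (illustration only): when tp_size is larger than the number of KV heads,
# each KV head is replicated so that every tensor-parallel rank still owns one full head.
def _example_dup_kv_weight():
    num_kv_heads, head_size, hidden, tp = 2, 16, 48, 4
    kv = torch.randn(num_kv_heads * head_size, hidden)
    dup = dup_kv_weight(kv, num_kv_heads, tp)
    assert dup.shape == (tp * head_size, hidden)  # 2 KV heads replicated to fill 4 ranks
    return dup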
def get_tllm_linear_sq_weight(
vals,
prefix,
shape,
tensor_parallel,
is_qkv=False,
per_token=False,
per_channel=False,
last_prefix=None,
bias=None,
smoother_value=None,
smoother_shape=None,
rank=0,
cat_dim=0,
multi_query_mode=False,
):
results = {}
def multi_query_split(data, local_dim, head_size, tp_size, cur_rank):
q, k, v = torch.split(data, [local_dim, head_size, head_size], dim=-1)
q_split = torch.chunk(q, tp_size, dim=-1)
k_split = torch.chunk(k, tp_size, dim=-1)
v_split = torch.chunk(v, tp_size, dim=-1)
return [
torch.concat((q_split[ii], k_split[ii], v_split[ii]), axis=-1)
for ii in range(tp_size)
][cur_rank]
col_shape = shape if (is_qkv or per_channel) else [1, 1]
if per_token:
if per_channel:
original_weights = torch.Tensor(vals["weight.int8.col"]).cuda()
else:
original_weights = torch.Tensor(vals["weight.int8"]).cuda()
local_dim = original_weights.shape[0]
head_size = (original_weights.shape[1] - local_dim) // 2
if multi_query_mode:
cur_weights = multi_query_split(
original_weights, local_dim, head_size, tensor_parallel, rank
)
else:
cur_weights = torch.chunk(original_weights, tensor_parallel, dim=cat_dim)[
rank
]
if is_qkv:
hidden_dim = cur_weights.shape[0]
cur_weights = cur_weights.reshape(hidden_dim, -1)
results[prefix + "weight"] = cur_weights.t().contiguous()
if smoother_value is None:
results[last_prefix] = torch.Tensor([1.0]).to(torch.float32).cuda()
if per_channel:
cur_per_channel_value = vals["scale_w_quant_orig.col"]
if smoother_value is None:
if multi_query_mode:
cur_per_channel_value = multi_query_split(
vals["scale_w_quant_orig.col"],
local_dim,
head_size,
tensor_parallel,
rank,
)
else:
cur_per_channel_value = torch.chunk(
vals["scale_w_quant_orig.col"], tensor_parallel, dim=cat_dim
)[rank]
else:
cur_per_channel_value = vals["scale_w_quant_orig"]
if is_qkv:
if multi_query_mode:
cur_per_channel_value = multi_query_split(
vals["scale_w_quant_orig"],
local_dim,
head_size,
tensor_parallel,
rank,
)
else:
cur_per_channel_value = torch.chunk(
vals["scale_w_quant_orig"], tensor_parallel, dim=cat_dim
)[rank]
results[prefix + "per_channel_scale"] = cur_per_channel_value.reshape(
col_shape
).contiguous()
else:
if per_channel:
original_weights = torch.Tensor(vals["weight.int8.col"]).cuda()
else:
original_weights = torch.Tensor(vals["weight.int8"]).cuda()
local_dim = original_weights.shape[0]
head_size = (original_weights.shape[1] - local_dim) // 2
if multi_query_mode:
cur_weights = multi_query_split(
original_weights, local_dim, head_size, tensor_parallel, rank
)
else:
cur_weights = torch.chunk(original_weights, tensor_parallel, dim=cat_dim)[
rank
]
if is_qkv:
hidden_dim = cur_weights.shape[0]
cur_weights = cur_weights.reshape(hidden_dim, -1)
results[prefix + "weight"] = cur_weights.t().contiguous()
if per_channel:
cur_per_channel_value = vals["scale_y_accum_quant.col"]
if smoother_value is None:
if multi_query_mode:
cur_per_channel_value = multi_query_split(
vals["scale_y_accum_quant.col"],
local_dim,
head_size,
tensor_parallel,
rank,
)
else:
cur_per_channel_value = torch.chunk(
vals["scale_y_accum_quant.col"], tensor_parallel, dim=cat_dim
)[rank]
else:
cur_per_channel_value = vals["scale_y_accum_quant"]
# QKV is always per_channel
if is_qkv:
if multi_query_mode:
cur_per_channel_value = multi_query_split(
vals["scale_y_accum_quant"],
local_dim,
head_size,
tensor_parallel,
rank,
)
else:
cur_per_channel_value = torch.chunk(
vals["scale_y_accum_quant"], tensor_parallel, dim=cat_dim
)[rank]
results[prefix + "per_channel_scale"] = (
torch.Tensor(cur_per_channel_value)
.to(torch.float32)
.reshape(col_shape)
.contiguous()
.cuda()
)
results[prefix + "act_scale"] = (
torch.Tensor([[vals["scale_y_quant_orig"]]])
.to(torch.float32)
.contiguous()
.cuda()
)
results[last_prefix] = (
torch.Tensor([vals["scale_x_orig_quant"]])
.to(torch.float32)
.contiguous()
.cuda()
)
if smoother_value is not None:
cur_smoother_value = torch.chunk(smoother_value, tensor_parallel, dim=cat_dim)[
rank
]
results[prefix + "smoother"] = (
cur_smoother_value.reshape(smoother_shape).contiguous().to(torch.float32)
)
if bias is not None:
results[prefix + "bias"] = bias
return results
def convert_hf_llama(
hf_model,
mapping,
vocab_size=32000,
dtype="float32",
use_parallel_embedding=False,
sharding_dim=0,
use_weight_only=False,
share_embedding_table=False,
residual_mlp=False,
use_gemm_woq_plugin=False,
plugin_weight_only_quant_type=torch.int8,
use_smooth_quant=False,
per_channel=False,
per_token=False,
int8_kv_cache=False,
act_range=[],
qkv_para=[],
smoother=[],
moe_config=None,
):
weights = {}
tik = time.time()
tensor_parallel = mapping.tp_size
model_params = dict(hf_model.named_parameters())
dtype = getattr(torch, dtype)
num_attention_heads = hf_model.config.num_attention_heads
hidden_size = hf_model.config.hidden_size
head_size = hidden_size // num_attention_heads
intermediate_size = hf_model.config.intermediate_size
num_key_value_heads = getattr(
hf_model.config, "num_key_value_heads", num_attention_heads
)
mha_mode = num_key_value_heads == num_attention_heads
layers_range = mapping.pp_layers(hf_model.config.num_hidden_layers)
def convert_layer(l):
prefix = f"model.layers.{l}."
tllm_prex = f"transformer.layers.{l - layers_range[0]}."
q_weight = get_weight(model_params, prefix + "self_attn.q_proj", dtype)
k_weight = get_weight(model_params, prefix + "self_attn.k_proj", dtype)
v_weight = get_weight(model_params, prefix + "self_attn.v_proj", dtype)
if not mha_mode:
if num_key_value_heads < tensor_parallel:
# duplicate the KV heads up to tensor_parallel
k_weight = dup_kv_weight(k_weight, num_key_value_heads, tensor_parallel)
v_weight = dup_kv_weight(v_weight, num_key_value_heads, tensor_parallel)
assert (k_weight.shape[0] % (mapping.tp_size * head_size)) == 0
assert (v_weight.shape[0] % (mapping.tp_size * head_size)) == 0
wq = split(q_weight, mapping.tp_size, mapping.tp_rank)
wk = split(k_weight, mapping.tp_size, mapping.tp_rank)
wv = split(v_weight, mapping.tp_size, mapping.tp_rank)
split_v = torch.concat((wq, wk, wv))
else:
qkv_weight = torch.cat([q_weight, k_weight, v_weight], dim=0)
split_v = split_qkv_tp(
qkv_weight,
num_attention_heads,
hidden_size,
tensor_parallel,
mapping.tp_rank,
)
if prefix + "self_attn.q_proj.bias" in model_params:
# only used in Internlm 7B models
q_bias = get_bias(model_params, prefix + "self_attn.q_proj", dtype)
k_bias = get_bias(model_params, prefix + "self_attn.k_proj", dtype)
v_bias = get_bias(model_params, prefix + "self_attn.v_proj", dtype)
qkv_bias = torch.cat((q_bias, k_bias, v_bias))
split_bias_v = split_qkv_bias_tp(
qkv_bias,
num_attention_heads,
hidden_size,
tensor_parallel,
mapping.tp_rank,
)
else:
split_bias_v = None
if use_smooth_quant:
qkv_weight = qkv_para[prefix + "self_attn.qkv_proj"]
qkv_out_dim = qkv_weight.shape[1]
if not mha_mode:
local_dim = qkv_weight.shape[0]
kv_hidden_size = (qkv_weight.shape[-1] - local_dim) // 2
qkv_weight = qkv_weight.reshape(
local_dim, local_dim + 2 * kv_hidden_size
)
else:
qkv_weight = qkv_weight.reshape(hidden_size, 3, hidden_size)
int8_weights = generate_int8(
qkv_weight,
act_range.get(prefix + "self_attn.qkv_proj"),
is_qkv=True,
multi_query_mode=bool(not mha_mode),
)
weights.update(
get_tllm_linear_sq_weight(
int8_weights,
tllm_prex + "attention.qkv.",
[1, qkv_out_dim // tensor_parallel],
tensor_parallel,
is_qkv=True,
bias=split_bias_v,
per_token=per_token,
per_channel=per_channel,
last_prefix=tllm_prex + "input_layernorm.scale_to_int",
smoother_value=None,
smoother_shape=None,
rank=mapping.tp_rank,
cat_dim=-1,
multi_query_mode=bool(not mha_mode),
)
)
else:
weights.update(
get_tllm_linear_weight(
split_v,
tllm_prex + "attention.qkv.",
split_bias_v,
use_weight_only,
plugin_weight_only_quant_type,
dtype,
use_gemm_woq_plugin,
)
)
if int8_kv_cache:
qkv_y = torch.cat(
[
act_range.get(prefix + "self_attn.q_proj")["y"],
act_range.get(prefix + "self_attn.k_proj")["y"],
act_range.get(prefix + "self_attn.v_proj")["y"],
],
dim=0,
)
int8_kv_scales = qkv_y.max() / 127.0
kv_cache_weights = {}
kv_cache_weights[tllm_prex + "attention.kv_cache_scaling_factor"] = (
int8_kv_scales.reshape([1])
)
weights.update(kv_cache_weights)
attn_dense_weight = get_weight(model_params, prefix + "self_attn.o_proj", dtype)
split_v = split_matrix_tp(
attn_dense_weight, tensor_parallel, mapping.tp_rank, dim=1
)
if prefix + "self_attn.o_proj.bias" in model_params:
attn_dense_bias = get_bias(model_params, prefix + "self_attn.o_proj", dtype)
else:
attn_dense_bias = None
if use_smooth_quant:
attn_dense_weight = attn_dense_weight.t()
int8_weights = generate_int8(
attn_dense_weight, act_range.get(prefix + "self_attn.o_proj")
)
weights.update(
get_tllm_linear_sq_weight(
int8_weights,
tllm_prex + "attention.dense.",
[1, hidden_size],
tensor_parallel,
is_qkv=False,
bias=attn_dense_bias,
per_token=per_token,
per_channel=per_channel,
last_prefix=tllm_prex + "attention.quantization_scaling_factor",
smoother_value=smoother[(prefix + "self_attn.o_proj")],
smoother_shape=[1, hidden_size // tensor_parallel],
rank=mapping.tp_rank,
cat_dim=0,
)
)
else:
weights.update(
get_tllm_linear_weight(
split_v,
tllm_prex + "attention.dense.",
attn_dense_bias,
use_weight_only,
plugin_weight_only_quant_type,
dtype,
use_gemm_woq_plugin,
)
)
if moe_config and moe_config.has_moe():
rank_experts = list(range(moe_config.num_experts))
if moe_config.tp_mode == moe_config.ParallelismMode.EXPERT_PARALLEL:
rank_experts = mapping.ep_experts(moe_config.num_experts)
for suffix in ["w1", "w2", "w3"]:
model_params[
f"model.layers.{l}.block_sparse_moe.experts.{suffix}.weight"
] = torch.stack(
[
model_params[
f"model.layers.{l}.block_sparse_moe.experts.{expert}.{suffix}.weight"
].detach()
for expert in rank_experts
]
)
w3 = model_params[f"model.layers.{l}.block_sparse_moe.experts.w3.weight"]
w2 = model_params[f"model.layers.{l}.block_sparse_moe.experts.w2.weight"]
w1 = model_params[f"model.layers.{l}.block_sparse_moe.experts.w1.weight"]
if moe_config.tp_mode == moe_config.ParallelismMode.TENSOR_PARALLEL:
w3 = split(w3, mapping.tp_size, mapping.tp_rank, dim=1)
w2 = split(w2, mapping.tp_size, mapping.tp_rank, dim=2)
w1 = split(w1, mapping.tp_size, mapping.tp_rank, dim=1)
model_params[f"model.layers.{l}.block_sparse_moe.experts.w3w1.weight"] = (
torch.concat([w3, w1], dim=-2)
)
model_params[f"model.layers.{l}.block_sparse_moe.experts.w2.weight"] = w2
## block_sparse_moe.experts.w2.weight
moe_experts_w2_weights = get_weight(
model_params, prefix + "block_sparse_moe.experts.w2", dtype
)
weights.update(
get_tllm_linear_weight(
moe_experts_w2_weights,
tllm_prex + "mlp.proj.",
None,
use_weight_only,
plugin_weight_only_quant_type,
dtype,
use_gemm_woq_plugin,
)
)
##block_sparse_moe.experts.w3w1.weight
moe_experts_w3w1_weights = get_weight(
model_params, prefix + "block_sparse_moe.experts.w3w1", dtype
)
weights.update(
get_tllm_linear_weight(
moe_experts_w3w1_weights,
tllm_prex + "mlp.fc.",
None,
use_weight_only,
plugin_weight_only_quant_type,
dtype,
use_gemm_woq_plugin,
)
)
if residual_mlp:
residual_mlp_gate_weights = get_weight(
model_params, prefix + "residual_mlp.w3", dtype
)
if use_smooth_quant:
residual_mlp_gate_weights = residual_mlp_gate_weights.t()
int8_weights = generate_int8(
residual_mlp_gate_weights,
act_range.get(prefix + "residual_mlp.w3"),
)
weights.update(
get_tllm_linear_sq_weight(
int8_weights,
tllm_prex + "residual_mlp.gate.",
[1, hidden_size // tensor_parallel],
tensor_parallel,
is_qkv=False,
per_token=per_token,
per_channel=per_channel,
last_prefix=tllm_prex + "post_layernorm.scale_to_int",
smoother_value=None,
smoother_shape=None,
rank=mapping.tp_rank,
cat_dim=-1,
)
)
else:
split_v = split_matrix_tp(
residual_mlp_gate_weights,
tensor_parallel,
mapping.tp_rank,
dim=0,
)
weights.update(
get_tllm_linear_weight(
split_v,
tllm_prex + "residual_mlp.gate.",
None,
use_weight_only,
plugin_weight_only_quant_type,
dtype,
use_gemm_woq_plugin,
)
)
residual_mlp_fc_weight = get_weight(
model_params, prefix + "residual_mlp.w1", dtype
)
if use_smooth_quant:
residual_mlp_fc_weight = residual_mlp_fc_weight.t() # verified
int8_weights = generate_int8(
residual_mlp_fc_weight,
act_range.get(prefix + "residual_mlp.w1"),
)
weights.update(
get_tllm_linear_sq_weight(
int8_weights,
tllm_prex + "residual_mlp.fc.",
[1, hidden_size // tensor_parallel],
tensor_parallel,
is_qkv=False,
per_token=per_token,
per_channel=per_channel,
last_prefix=tllm_prex + "post_layernorm.scale_to_int",
smoother_value=None,
smoother_shape=None,
rank=mapping.tp_rank,
cat_dim=-1,
)
)
else:
split_v = split_matrix_tp(
residual_mlp_fc_weight, tensor_parallel, mapping.tp_rank, dim=0
)
weights.update(
get_tllm_linear_weight(
split_v,
tllm_prex + "residual_mlp.fc.",
None,
use_weight_only,
plugin_weight_only_quant_type,
dtype,
use_gemm_woq_plugin,
)
)
residual_mlp_proj_weight = get_weight(
model_params, prefix + "residual_mlp.w2", dtype
)
if use_smooth_quant:
residual_mlp_proj_weight = residual_mlp_proj_weight.t()
int8_weights = generate_int8(
residual_mlp_proj_weight,
act_range.get(prefix + "residual_mlp.w2"),
)
weights.update(
get_tllm_linear_sq_weight(
int8_weights,
tllm_prex + "residual_mlp.proj.",
[1, hidden_size],
tensor_parallel,
is_qkv=False,
per_token=per_token,
per_channel=per_channel,
last_prefix=tllm_prex
+ "residual_mlp.quantization_scaling_factor",
smoother_value=smoother[prefix + "residual_mlp.w2"],
smoother_shape=[1, hidden_size // tensor_parallel],
rank=mapping.tp_rank,
cat_dim=0,
)
)
else:
split_v = split_matrix_tp(
residual_mlp_proj_weight,
tensor_parallel,
mapping.tp_rank,
dim=1,
)
weights.update(
get_tllm_linear_weight(
split_v,
tllm_prex + "residual_mlp.proj.",
None,
use_weight_only,
plugin_weight_only_quant_type,
dtype,
use_gemm_woq_plugin,
)
)
moe_experts_gate_weights = get_weight(
model_params, prefix + "block_sparse_moe.gate", torch.float32
)
weights.update(
get_tllm_linear_weight(
moe_experts_gate_weights,
tllm_prex + "mlp.router.",
None,
False, # Router should never be quantized
plugin_weight_only_quant_type,
dtype,
use_gemm_woq_plugin,
)
)
else:
mlp_gate_weight = get_weight(model_params, prefix + "mlp.up_proj", dtype)
split_v = split_matrix_tp(
mlp_gate_weight, tensor_parallel, mapping.tp_rank, dim=0
)
if use_smooth_quant:
mlp_gate_weight = mlp_gate_weight.t()
int8_weights = generate_int8(
mlp_gate_weight, act_range.get(prefix + "mlp.up_proj")
)
weights.update(
get_tllm_linear_sq_weight(
int8_weights,
tllm_prex + "mlp.gate.",
[1, intermediate_size // tensor_parallel],
tensor_parallel,
is_qkv=False,
per_token=per_token,
per_channel=per_channel,
last_prefix=tllm_prex + "post_layernorm.scale_to_int",
smoother_value=None,
smoother_shape=None,
rank=mapping.tp_rank,
cat_dim=-1,
)
)
else:
weights.update(
get_tllm_linear_weight(
split_v,
tllm_prex + "mlp.gate.",
None,
use_weight_only,
plugin_weight_only_quant_type,
dtype,
use_gemm_woq_plugin,
)
)
mlp_fc_weight = get_weight(model_params, prefix + "mlp.gate_proj", dtype)
split_v = split_matrix_tp(
mlp_fc_weight, tensor_parallel, mapping.tp_rank, dim=0
)
if use_smooth_quant:
mlp_fc_weight = mlp_fc_weight.t() # verified
int8_weights = generate_int8(
mlp_fc_weight, act_range.get(prefix + "mlp.gate_proj")
)
weights.update(
get_tllm_linear_sq_weight(
int8_weights,
tllm_prex + "mlp.fc.",
[1, intermediate_size // tensor_parallel],
tensor_parallel,
is_qkv=False,
per_token=per_token,
per_channel=per_channel,
last_prefix=tllm_prex + "post_layernorm.scale_to_int",
smoother_value=None,
smoother_shape=None,
rank=mapping.tp_rank,
cat_dim=-1,
)
)
else:
weights.update(
get_tllm_linear_weight(
split_v,
tllm_prex + "mlp.fc.",
None,
use_weight_only,
plugin_weight_only_quant_type,
dtype,
use_gemm_woq_plugin,
)
)
mlp_proj_weight = get_weight(model_params, prefix + "mlp.down_proj", dtype)
split_v = split_matrix_tp(
mlp_proj_weight, tensor_parallel, mapping.tp_rank, dim=1
)
if use_smooth_quant:
mlp_proj_weight = mlp_proj_weight.t()
int8_weights = generate_int8(
mlp_proj_weight, act_range.get(prefix + "mlp.down_proj")
)
weights.update(
get_tllm_linear_sq_weight(
int8_weights,
tllm_prex + "mlp.proj.",
[1, hidden_size],
tensor_parallel,
is_qkv=False,
per_token=per_token,
per_channel=per_channel,
last_prefix=tllm_prex + "mlp.quantization_scaling_factor",
smoother_value=smoother[prefix + "mlp.down_proj"],
smoother_shape=[1, intermediate_size // tensor_parallel],
rank=mapping.tp_rank,
cat_dim=0,
)
)
else:
weights.update(
get_tllm_linear_weight(
split_v,
tllm_prex + "mlp.proj.",
None,
use_weight_only,
plugin_weight_only_quant_type,
dtype,
use_gemm_woq_plugin,
)
)
# Layer norms do not use tensor parallelism
input_ln_weight = get_weight(model_params, prefix + "input_layernorm", dtype)
weights[tllm_prex + "input_layernorm.weight"] = input_ln_weight
post_ln_weight = get_weight(
model_params, prefix + "post_attention_layernorm", dtype
)
weights[tllm_prex + "post_layernorm.weight"] = post_ln_weight
if residual_mlp:
residual_ln_weight = get_weight(
model_params, prefix + "residual_layernorm", dtype
)
weights[tllm_prex + "residual_layernorm.weight"] = residual_ln_weight
cur_block_weights = [
weight_name
for weight_name in model_params
if weight_name.find(prefix) != -1
]
for weight_name in cur_block_weights:
model_params[weight_name] = None
for l in layers_range:
convert_layer(l)
release_gc()
v = get_weight(model_params, "model.embed_tokens", dtype)
if hf_model.config.tie_word_embeddings:
# lm_head.weight has the same weights as embedding
if mapping.is_last_pp_rank():
if vocab_size % mapping.tp_size != 0:
# padding
vocab_size_padded = pad_vocab_size(vocab_size, mapping.tp_size)
pad_width = vocab_size_padded - vocab_size
v = torch.nn.functional.pad(v, (0, 0, 0, pad_width), "constant", 0)
weights["lm_head.weight"] = split(v, mapping.tp_size, mapping.tp_rank)
if use_parallel_embedding:
v = split_matrix_tp(v, mapping.tp_size, mapping.tp_rank, dim=sharding_dim)
if mapping.is_first_pp_rank():
weights["transformer.vocab_embedding.weight"] = v
# if not use_parallel_embedding:
# weights['transformer.vocab_embedding.weight'] = embed_w
# else:
# assert hf_model.config.vocab_size % tensor_parallel == 0
# weights['transformer.vocab_embedding.weight'] = split_matrix_tp(
# embed_w, tensor_parallel, rank
lm_head_weights = get_weight(model_params, "lm_head", dtype)
if mapping.is_last_pp_rank():
if vocab_size % mapping.tp_size != 0:
# padding
vocab_size_padded = pad_vocab_size(vocab_size, mapping.tp_size)
pad_width = vocab_size_padded - vocab_size
lm_head_weights = torch.nn.functional.pad(
lm_head_weights, (0, 0, 0, pad_width), "constant", value=0
)
weights["lm_head.weight"] = split_matrix_tp(
lm_head_weights, tensor_parallel, mapping.tp_rank, dim=0
)
ln_f_w = get_weight(model_params, "model.norm", dtype)
weights["transformer.ln_f.weight"] = ln_f_w
tok = time.time()
t = time.strftime("%H:%M:%S", time.gmtime(tok - tik))
print(f"Weights loaded. Total time: {t}")
return weights
def smooth_quant(
model,
model_dir,
calib_dataset,
dataset_cache_dir,
smoothquant: Optional[float] = None,
):
assert model is not None
act_range = {}
llama_qkv_para = {}
# smoother for inputs of self_attn.o_proj and mlp.down_proj
llama_smoother = {}
os.environ["TOKENIZERS_PARALLELISM"] = os.environ.get(
"TOKENIZERS_PARALLELISM", "false"
)
tokenizer = AutoTokenizer.from_pretrained(
model_dir, trust_remote_code=True, use_fast=False, padding_side="left"
)
dataset = load_calib_dataset(calib_dataset, cache_dir=dataset_cache_dir)
act_range = capture_activation_range(model, tokenizer, dataset)
if smoothquant is not None:
smooth_llama_model(
model, act_range, smoothquant, llama_qkv_para, llama_smoother
)
return act_range, llama_qkv_para, llama_smoother
def create_config_from_hugging_face(
hf_model,
dtype,
mapping,
quantization: QuantConfig = None,
override_fields: dict = {},
):
config = {}
hf_config = AutoConfig.from_pretrained(hf_model, trust_remote_code=True)
if hf_config.model_type == "llava":
# LLaVA = Vision model + Llama LLM
        # We load a llava config and use its text config as the llama config
hf_config = LlavaConfig.from_pretrained(hf_model).text_config
if hf_config.model_type == "llava_next":
        # LLaVA-NeXT = Vision model + Llama LLM
        # We load a LlavaNext config and use its text config as the llama config
hf_config = LlavaNextConfig.from_pretrained(hf_model).text_config
# TODO: directly assign the hf_config fields to the config dict w/o creating these local vars
# same for from_meta and from_cli_args
n_head = hf_config.num_attention_heads
inter_size = hf_config.intermediate_size
n_layer = hf_config.num_hidden_layers
n_embd = hf_config.hidden_size
n_kv_head = getattr(hf_config, "num_key_value_heads", n_head)
rms_norm_eps = hf_config.rms_norm_eps
vocab_size = hf_config.vocab_size
n_positions = hf_config.max_position_embeddings
hidden_act = hf_config.hidden_act
config["rotary_scaling"] = getattr(hf_config, "rope_scaling", None)
rotary_base = getattr(hf_config, "rope_theta", 10000.0)
config["residual_mlp"] = getattr(hf_config, "parallel_attn_mlp_res", False)
if hf_config.model_type == "mixtral" or hf_config.model_type == "arctic":
# HF LLaMA-type models are implicitly using gated activation.
# With our MoE implementation, we must make it explicit
hidden_act = "swiglu"
config["moe_normalization_mode"] = (
MoeConfig.ExpertScaleNormalizationMode.RENORMALIZE
)
else:
config["moe_normalization_mode"] = None
moe_num_experts = getattr(hf_config, "num_local_experts", 0)
moe_top_k = getattr(hf_config, "num_experts_per_tok", 0)
moe_tp_mode = MoeConfig.ParallelismMode.TENSOR_PARALLEL
architecture = hf_config.architectures[0]
# VILA model, force to use llama config
if hf_config.model_type == "llava_llama":
architecture = "LlamaForCausalLM"
attn_bias = getattr(hf_config, "bias", False) or getattr(
hf_config, "attention_bias", False
)
config.update(
{
"architecture": architecture,
"dtype": dtype,
"logits_dtype": "float32",
"num_hidden_layers": n_layer,
"num_attention_heads": n_head,
"hidden_size": n_embd,
"intermediate_size": inter_size,
"num_key_value_heads": n_kv_head,
"vocab_size": vocab_size,
"position_embedding_type": "rope_gpt_neox",
"max_position_embeddings": n_positions,
"hidden_act": hidden_act,
"rotary_base": rotary_base,
"norm_epsilon": rms_norm_eps,
"moe_num_experts": moe_num_experts,
"moe_top_k": moe_top_k,
"moe_tp_mode": moe_tp_mode,
            # TODO: should map directly from the Mapping object to the TRT-LLM checkpoint fields
"mapping": {
"world_size": mapping.tp_size * mapping.pp_size,
"tp_size": mapping.tp_size,
"pp_size": mapping.pp_size,
},
"attn_bias": attn_bias,
}
)
config["quantization"] = quantization.asdict()
config.update(override_fields)
moe_config = MoeConfig(
config["moe_num_experts"],
config["moe_top_k"],
config["moe_tp_mode"],
config["moe_normalization_mode"],
).validate()
use_weight_only = config["quantization"]["quant_algo"] in [
QuantAlgo.W8A16,
QuantAlgo.W4A16,
QuantAlgo.FP8,
]
if use_weight_only and moe_config.has_moe():
config["quantization"]["exclude_modules"].append("router")
print("-----Debug config: ", config)
return config
def from_hugging_face(
cls,
model_dir,
dtype,
*,
mapping,
quantization: QuantConfig = None,
load_by_shard=False,
load_model_on_cpu=False,
override_fields={},
skip_loading_weights=False,
preloaded_model=None,
):
"""Create a LLaMAForCausalLM object from give parameters"""
assert model_dir is not None
if isinstance(model_dir, Path): # some code relies on this as string
model_dir = str(model_dir)
# register VILA model
if "vila" in model_dir:
sys.path.append(model_dir + "/../VILA")
from llava.model import LlavaConfig, LlavaLlamaForCausalLM
AutoConfig.register("llava_llama", LlavaConfig)
AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM)
if override_fields.get("share_embedding_table", False):
logger.warning(
"Llama model does not support share_embedding_table; setting share_embedding_table=False"
)
override_fields["share_embedding_table"] = False
config = create_config_from_hugging_face(
model_dir, dtype, mapping, quantization, override_fields=override_fields
)
pretrained_config = PretrainedConfig.from_dict(config)
pretrained_config.set_rank(mapping.rank) # TODO:remove this hack
llama = cls.from_config(pretrained_config)
llama = optimize_model(
llama,
use_parallel_embedding=pretrained_config.use_parallel_embedding,
share_embedding_table=pretrained_config.share_embedding_table,
)
if skip_loading_weights:
return llama
model = preloaded_model
if (
model is None and not load_by_shard
): # when load by shard, no need to create complete hf model
have_safetensors = any(
[f.endswith(".safetensors") for f in os.listdir(model_dir)]
)
hf_config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
if hf_config.model_type == "llava":
hf_llava = LlavaForConditionalGeneration.from_pretrained(
model_dir, torch_dtype="auto"
)
model = hf_llava.language_model
elif hf_config.model_type == "llava_next":
hf_llava_next = LlavaNextForConditionalGeneration.from_pretrained(
model_dir, torch_dtype="auto"
)
model = hf_llava_next.language_model
else:
# TODO: Remove WAR after `load_from_hf_safetensors` supports weight-only quantization
if not have_safetensors or config["quantization"]["quant_algo"] is not None:
model = AutoModelForCausalLM.from_pretrained(
model_dir,
device_map="auto" if not load_model_on_cpu else "cpu",
torch_dtype="auto",
trust_remote_code=True,
)
if load_by_shard:
weights = load_from_hf_checkpoint(model_dir, mapping, pretrained_config)
elif model is not None:
weights = load_weights_from_hf(config=config, mapping=mapping, model=model)
else:
weights = load_from_hf_safetensors(
model_dir=model_dir, config=pretrained_config, mapping=mapping
)
llama.load(weights)
return llama
def quantize(
dtype,
model_dir,
output_dir,
mapping,
quantization: QuantConfig,
*,
calib_dataset="cnn_dailymail",
override_fields={},
dataset_cache_dir: Optional[str] = None,
):
"""
    Quantize and save the model as a TRT-LLM checkpoint to output_dir
"""
    # TODO: currently only SmoothQuant and KV cache quantization are supported; more quant algorithms need to be supported by calling modelopt
config = create_config_from_hugging_face(
model_dir, dtype, mapping, quantization, override_fields=override_fields
)
with open(os.path.join(output_dir, "config.json"), "w") as f:
json.dump(config, f, indent=4)
assert (
mapping.rank == -1
), "You shall call quantize only once in one rank, assert rank==-1 for precaution"
act_range = {}
llama_qkv_para = {}
# smoother for inputs of self_attn.o_proj and mlp.down_proj
llama_smoother = {}
model = None
assert config["quantization"]["quant_algo"] == quantization.quant_algo
int8_kv_cache = quantization.kv_cache_quant_algo == QuantAlgo.INT8
use_smooth_quant = (
quantization.quant_algo is not None
and quantization.quant_algo.startswith("W8A8_SQ")
)
assert (
use_smooth_quant or int8_kv_cache
), "Call from_hugging_face when there is no quantization"
if use_smooth_quant:
assert (
quantization.smoothquant_val is not None
), "A smooth value must be specified when using smooth quant"
assert model_dir is not None
## only load and call smooth quant routine once for all ranks
hf_config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
assert (
"llava" not in hf_config.model_type
), "Smooth quant llava/vila is not supported yet"
model = AutoModelForCausalLM.from_pretrained(
model_dir,
device_map="auto",
torch_dtype="auto" if not use_smooth_quant else torch.float16,
trust_remote_code=True,
)
act_range, llama_qkv_para, llama_smoother = smooth_quant(
model, model_dir, calib_dataset, dataset_cache_dir, quantization.smoothquant_val
)
for rank in range(mapping.world_size):
        # Build a per-rank Mapping rather than mutating the caller's mapping in place; the caller's mapping is rank-agnostic since quantize is called from only one rank
ranked_mapping = Mapping(
world_size=mapping.world_size,
rank=rank,
tp_size=mapping.tp_size,
pp_size=mapping.pp_size,
)
weights = load_weights_from_hf(
config=config,
mapping=ranked_mapping,
model=model,
# for smooth quant only
act_range=act_range,
llama_qkv_para=llama_qkv_para,
llama_smoother=llama_smoother,
)
safetensors.torch.save_file(
weights, os.path.join(output_dir, f"rank{rank}.safetensors")
)
del weights
def load_weights_from_hf(
*, config, mapping, model, act_range={}, llama_qkv_para={}, llama_smoother={}
):
# TODO: simplify the parameters here
assert model is not None
plugin_weight_only_quant_type = (
None # the value does not matter when use_weight_only is False
)
quant_algo = config["quantization"]["quant_algo"]
if quant_algo == QuantAlgo.W8A16:
plugin_weight_only_quant_type = torch.int8
elif quant_algo == QuantAlgo.W4A16:
plugin_weight_only_quant_type = torch.quint4x2
moe_config = MoeConfig(
config["moe_num_experts"],
config["moe_top_k"],
config["moe_tp_mode"],
config["moe_normalization_mode"],
).validate()
use_weight_only = quant_algo in [QuantAlgo.W8A16, QuantAlgo.W4A16]
use_smooth_quant = quant_algo is not None and quant_algo.startswith("W8A8_SQ")
per_channel_sq = use_smooth_quant and "PER_CHANNEL" in quant_algo
per_token_sq = use_smooth_quant and "PER_TOKEN" in quant_algo
use_int8_kv_cache = config["quantization"]["kv_cache_quant_algo"] == QuantAlgo.INT8
weights = convert_hf_llama(
model,
mapping,
vocab_size=config["vocab_size"],
dtype=config["dtype"],
use_weight_only=use_weight_only,
use_gemm_woq_plugin=not config.get("disable_weight_only_quant_plugin", False),
plugin_weight_only_quant_type=plugin_weight_only_quant_type,
use_parallel_embedding=config.get("use_parallel_embedding", False),
sharding_dim=config.get("embedding_sharding_dim", 0),
share_embedding_table=config.get("share_embedding_table", False),
residual_mlp=config["residual_mlp"],
use_smooth_quant=use_smooth_quant,
per_channel=per_channel_sq,
per_token=per_token_sq,
int8_kv_cache=use_int8_kv_cache,
act_range=act_range,
qkv_para=llama_qkv_para,
smoother=llama_smoother,
moe_config=moe_config,
)
return weights
# from llava.constants import (
# IMAGE_TOKEN_INDEX,
# DEFAULT_IMAGE_TOKEN,
# DEFAULT_IM_START_TOKEN,
# DEFAULT_IM_END_TOKEN,
# IMAGE_PLACEHOLDER,
# )
# from llava.conversation import conv_templates, SeparatorStyle
# from llava.model.builder import load_pretrained_model
# from llava.utils import disable_torch_init
# from llava.mm_utils import (
# process_images,
# tokenizer_image_token,
# get_model_name_from_path,
# )
import argparse
import json
import os
import sys
import time
import traceback
from concurrent.futures import ThreadPoolExecutor, as_completed
import tensorrt_llm
from tensorrt_llm._utils import release_gc
from tensorrt_llm.layers import MoeConfig
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.models import LLaMAForCausalLM
from tensorrt_llm.models.llama.weight import load_from_gptq_llama
from tensorrt_llm.models.modeling_utils import QuantConfig
from tensorrt_llm.quantization import QuantAlgo
def parse_arguments():
parser = argparse.ArgumentParser()
parser.add_argument("--model_dir", type=str, default=None)
parser.add_argument("--meta_ckpt_dir", type=str, default=None)
parser.add_argument(
"--tp_size", type=int, default=1, help="N-way tensor parallelism size"
)
parser.add_argument(
"--pp_size", type=int, default=1, help="N-way pipeline parallelism size"
)
parser.add_argument(
"--dtype",
type=str,
default="float16",
choices=["float32", "bfloat16", "float16"],
)
parser.add_argument("--vocab_size", type=int, default=32000)
parser.add_argument("--n_positions", type=int, default=2048)
parser.add_argument("--n_layer", type=int, default=32)
parser.add_argument("--n_head", type=int, default=32)
parser.add_argument("--n_kv_head", type=int, default=None)
parser.add_argument("--n_embd", type=int, default=4096)
parser.add_argument("--inter_size", type=int, default=11008)
parser.add_argument("--rms_norm_eps", type=float, default=1e-06)
parser.add_argument(
"--use_weight_only",
default=False,
action="store_true",
help="Quantize weights for the various GEMMs to INT4/INT8."
"See --weight_only_precision to set the precision",
)
parser.add_argument(
"--disable_weight_only_quant_plugin",
default=False,
action="store_true",
help="By default, using plugin implementation for weight quantization. Enabling disable_weight_only_quant_plugin flag will use ootb implementation instead of plugin."
"You must also use --use_weight_only for that argument to have an impact.",
)
parser.add_argument(
"--weight_only_precision",
const="int8",
type=str,
nargs="?",
default="int8",
choices=["int8", "int4", "int4_gptq"],
help="Define the precision for the weights when using weight-only quantization."
"You must also use --use_weight_only for that argument to have an impact.",
)
parser.add_argument(
"--calib_dataset",
type=str,
default="ccdv/cnn_dailymail",
help="The huggingface dataset name or the local directory of the dataset for calibration.",
)
parser.add_argument(
"--smoothquant",
"-sq",
type=float,
default=None,
help="Set the α parameter (see https://arxiv.org/pdf/2211.10438.pdf)"
" to Smoothquant the model, and output int8 weights."
" A good first try is 0.5. Must be in [0, 1]",
)
parser.add_argument(
"--per_channel",
action="store_true",
default=False,
help="By default, we use a single static scaling factor for the GEMM's result. "
"per_channel instead uses a different static scaling factor for each channel. "
"The latter is usually more accurate, but a little slower.",
)
parser.add_argument(
"--per_token",
action="store_true",
default=False,
help="By default, we use a single static scaling factor to scale activations in the int8 range. "
"per_token chooses at run time, and for each token, a custom scaling factor. "
"The latter is usually more accurate, but a little slower.",
)
parser.add_argument(
"--int8_kv_cache",
default=False,
action="store_true",
help="By default, we use dtype for KV cache. int8_kv_cache chooses int8 quantization for KV",
)
parser.add_argument(
"--modelopt_quant_ckpt_path",
type=str,
default=None,
help="Path of a quantized model checkpoint in .npz format",
)
parser.add_argument(
"--per_group",
default=False,
action="store_true",
help="By default, we use a single static scaling factor to scale weights in the int4 range. "
"per_group chooses at run time, and for each group, a custom scaling factor. "
"The flag is built for GPTQ/AWQ quantization.",
)
parser.add_argument(
"--load_by_shard",
action="store_true",
help="Load a pretrained model shard-by-shard.",
)
parser.add_argument("--hidden_act", type=str, default="silu")
parser.add_argument("--rotary_base", type=float, default=10000.0)
parser.add_argument(
"--group_size",
type=int,
default=128,
help="Group size used in GPTQ quantization.",
) # AWQ is only supported by quantize.py script
parser.add_argument(
"--dataset-cache-dir",
type=str,
default=None,
help="cache dir to load the hugging face dataset",
)
parser.add_argument("--load_model_on_cpu", action="store_true")
parser.add_argument(
"--use_parallel_embedding",
action="store_true",
default=False,
help="By default embedding parallelism is disabled. By setting this flag, embedding parallelism is enabled",
)
parser.add_argument(
"--embedding_sharding_dim",
type=int,
default=0,
choices=[0, 1],
help="By default the embedding lookup table is sharded along vocab dimension (embedding_sharding_dim=0). "
"To shard it along hidden dimension, set embedding_sharding_dim=1"
"Note: embedding sharing is only enabled when embedding_sharding_dim = 0",
)
parser.add_argument(
"--use_embedding_sharing",
action="store_true",
default=False,
help="Try to reduce the engine size by sharing the embedding lookup table between two layers."
"Note: the flag might not take effect when the criteria are not met.",
)
parser.add_argument(
"--output_dir",
type=str,
default="tllm_checkpoint",
help="The path to save the TensorRT-LLM checkpoint",
)
parser.add_argument(
"--workers",
type=int,
default=1,
help="The number of workers for converting checkpoint in parallel",
)
parser.add_argument(
"--moe_num_experts",
default=0,
type=int,
help="Specify the number of experts to use for MOE layers",
)
parser.add_argument(
"--moe_top_k",
default=0,
type=int,
help="Specify the top_k value to use for MOE layers. Default to 1 if --moe_num_experts is set",
)
parser.add_argument(
"--moe_tp_mode",
default=MoeConfig.ParallelismMode.TENSOR_PARALLEL,
type=int,
help="Controls how to distribute experts in TP. Check layers/moe.py for accepted values",
)
parser.add_argument(
"--moe_renorm_mode",
default=MoeConfig.ExpertScaleNormalizationMode.RENORMALIZE,
type=int,
help="Controls renormalization after gate logits. Check layers/moe.py for accepted values",
)
parser.add_argument(
"--save_config_only",
action="store_true",
default=False,
help="Only save the model config w/o read and converting weights, be careful, this is for debug only",
)
args = parser.parse_args()
    # Change the default to be consistent with the CLI help.
if args.moe_num_experts and args.moe_top_k == 0:
args.moe_top_k = 1
return args
def args_to_quantization(args: argparse.Namespace) -> QuantConfig:
"""return config dict with quantization info based on the command line args"""
quant_config = QuantConfig()
quant_config.exclude_modules = ["lm_head"]
if args.use_weight_only:
if args.weight_only_precision == "int8":
quant_config.quant_algo = QuantAlgo.W8A16
elif args.weight_only_precision == "int4":
quant_config.quant_algo = QuantAlgo.W4A16
elif args.smoothquant:
quant_config.smoothquant_val = args.smoothquant
if args.per_channel:
if args.per_token:
quant_config.quant_algo = QuantAlgo.W8A8_SQ_PER_CHANNEL_PER_TOKEN_PLUGIN
else:
quant_config.quant_algo = (
QuantAlgo.W8A8_SQ_PER_CHANNEL_PER_TENSOR_PLUGIN
)
else:
if args.per_token:
quant_config.quant_algo = QuantAlgo.W8A8_SQ_PER_TENSOR_PER_TOKEN_PLUGIN
else:
quant_config.quant_algo = QuantAlgo.W8A8_SQ_PER_TENSOR_PLUGIN
if args.int8_kv_cache:
quant_config.kv_cache_quant_algo = QuantAlgo.INT8
if args.weight_only_precision == "int4_gptq":
quant_config.group_size = args.group_size
quant_config.has_zero_point = True
quant_config.pre_quant_scale = False
quant_config.quant_algo = QuantAlgo.W4A16_GPTQ
return quant_config
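# Hedged sketch (illustration only) of how the CLI flags map to a QuantAlgo: the Namespace
# below only carries the attributes this function reads and is not a full parse_arguments()
# result. For instance, `--use_weight_only --weight_only_precision int4` selects W4A16.
def _example_args_to_quantization():
    args = argparse.Namespace(
        use_weight_only=True,
        weight_only_precision="int4",
        smoothquant=None,
        per_channel=False,
        per_token=False,
        int8_kv_cache=False,
        group_size=128,
    )
    quant_config = args_to_quantization(args)
    assert quant_config.quant_algo == QuantAlgo.W4A16
    return quant_config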
def convert_and_save_meta(args, rank):
mapping = Mapping(
world_size=args.tp_size * args.pp_size,
tp_size=args.tp_size,
pp_size=args.pp_size,
rank=rank,
)
assert not args_to_quantization(
args
).quant_mode.has_any_quant(), (
"quantization from meta checkpoint or empty model were never supported"
)
llama = LLaMAForCausalLM.from_meta_ckpt(
args.meta_ckpt_dir,
args.dtype,
mapping,
use_parallel_embedding=args.use_parallel_embedding,
embedding_sharding_dim=args.embedding_sharding_dim,
)
llama.save_checkpoint(args.output_dir, save_config=(rank == 0))
def args_to_build_options(args):
return {
"use_parallel_embedding": args.use_parallel_embedding,
"embedding_sharding_dim": args.embedding_sharding_dim,
"share_embedding_table": args.use_embedding_sharing,
"disable_weight_only_quant_plugin": args.disable_weight_only_quant_plugin,
}
def from_cli_args(args):
n_kv_head = args.n_kv_head if args.n_kv_head is not None else args.n_head
config = {
"architecture": "LlamaForCausalLM",
"dtype": args.dtype,
"logits_dtype": "float32",
"num_hidden_layers": args.n_layer,
"num_attention_heads": args.n_head,
"hidden_size": args.n_embd,
"intermediate_size": args.inter_size,
"num_key_value_heads": n_kv_head,
"vocab_size": args.vocab_size,
"position_embedding_type": "rope_gpt_neox",
"max_position_embeddings": args.n_positions,
"hidden_act": args.hidden_act,
"rotary_base": args.rotary_base,
"norm_epsilon": args.rms_norm_eps,
"moe_num_experts": args.moe_num_experts,
"moe_top_k": args.moe_top_k,
"moe_tp_mode": args.moe_tp_mode,
"moe_normalization_mode": args.moe_renorm_mode,
"mapping": {
"world_size": args.tp_size * args.pp_size,
"tp_size": args.tp_size,
"pp_size": args.pp_size,
},
"quantization": args_to_quantization(args).asdict(),
}
config.update(args_to_build_options(args))
return config
def preload_model(model_dir, load_model_on_cpu):
use_safetensors = True
from transformers import AutoConfig, AutoModelForCausalLM
if "vila" in model_dir:
use_safetensors = False
sys.path.append(model_dir + "/../VILA")
from llava.model import LlavaConfig, LlavaLlamaForCausalLM
AutoConfig.register("llava_llama", LlavaConfig)
AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM)
hf_config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
model_cls = AutoModelForCausalLM
if hf_config.model_type == "llava":
use_safetensors = False
from transformers import LlavaForConditionalGeneration
model_cls = LlavaForConditionalGeneration
use_safetensors = (
any([f.endswith(".safetensors") for f in os.listdir(model_dir)])
and use_safetensors
)
if use_safetensors:
return None
model = model_cls.from_pretrained(
model_dir,
device_map="auto" if not load_model_on_cpu else "cpu",
torch_dtype="auto",
trust_remote_code=True,
)
if hf_config.model_type == "llava":
model = model.language_model
return model
def convert_and_save_hf(args):
model_dir = args.model_dir
load_model_on_cpu = args.load_model_on_cpu
load_by_shard = args.load_by_shard
world_size = args.tp_size * args.pp_size
    # Need to convert the CLI args to key-value pairs and override them in the generated config dict.
    # Ideally these fields would be moved out of the config and passed into the build API; they are
    # kept here for compatibility for now, until the refactor is done.
override_fields = {"moe_tp_mode": args.moe_tp_mode}
quantization = args_to_quantization(args)
override_fields.update(args_to_build_options(args))
if args.smoothquant is not None or args.int8_kv_cache:
assert (
not args.load_by_shard
), "When using quantization, TRT-LLM needs to load the whole HF model, thus load by shard not supported"
assert (
not args.load_model_on_cpu
), "When using quantization, TRT-LLM needs to load the model to GPU"
mapping = Mapping(
world_size=world_size,
            rank=-1,  # intentionally set to -1 to catch accidental use of the rank
tp_size=args.tp_size,
pp_size=args.pp_size,
)
LLaMAForCausalLM.quantize(
args.model_dir,
args.output_dir,
quantization,
dtype=args.dtype,
mapping=mapping,
calib_dataset=args.calib_dataset,
override_fields=override_fields,
dataset_cache_dir=args.dataset_cache_dir,
)
else:
        # When not loading by shard, preload one complete model and then slice per-rank weights from it;
        # this avoids re-reading the checkpoint from disk for every rank
hf_model = (
preload_model(model_dir, load_model_on_cpu)
if not args.load_by_shard
else None
)
def convert_and_save_rank(args, rank):
mapping = Mapping(
world_size=world_size,
rank=rank,
tp_size=args.tp_size,
pp_size=args.pp_size,
)
llama = LLaMAForCausalLM.from_hugging_face(
model_dir,
args.dtype,
mapping=mapping,
quantization=quantization,
load_by_shard=load_by_shard,
load_model_on_cpu=load_model_on_cpu,
override_fields=override_fields,
preloaded_model=hf_model,
)
llama.save_checkpoint(args.output_dir, save_config=(rank == 0))
del llama
execute(args.workers, [convert_and_save_rank] * world_size, args)
release_gc()
def convert_and_save_gptq(args, rank):
mapping = Mapping(
world_size=args.tp_size * args.pp_size,
tp_size=args.tp_size,
rank=rank,
pp_size=args.pp_size,
)
llama = LLaMAForCausalLM.from_hugging_face(
args.model_dir,
args.dtype,
mapping=mapping,
quantization=args_to_quantization(args),
skip_loading_weights=True,
)
weights = load_from_gptq_llama(llama.config, args.modelopt_quant_ckpt_path)
llama.load(weights)
llama.save_checkpoint(args.output_dir, rank == 0)
def execute(workers, func, args):
if workers == 1:
for rank, f in enumerate(func):
f(args, rank)
else:
with ThreadPoolExecutor(max_workers=workers) as p:
futures = [p.submit(f, args, rank) for rank, f in enumerate(func)]
exceptions = []
for future in as_completed(futures):
try:
future.result()
except Exception as e:
traceback.print_exc()
exceptions.append(e)
assert (
len(exceptions) == 0
), "Checkpoint conversion failed, please check error log."
def main():
print(tensorrt_llm.__version__)
args = parse_arguments()
world_size = args.tp_size * args.pp_size
tik = time.time()
if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir)
if (
args.model_dir is None and args.meta_ckpt_dir is None
): # generate fake config.json
config = from_cli_args(args)
with open(os.path.join(args.output_dir, "config.json"), "w") as f:
json.dump(config, f, indent=4)
elif args.meta_ckpt_dir is not None:
assert (
args.model_dir is None
), "Shall not specify both meta checkpoint dir and hugging face dir"
execute(args.workers, [convert_and_save_meta] * world_size, args)
elif args.weight_only_precision == "int4_gptq":
assert args.model_dir is not None
assert args.modelopt_quant_ckpt_path is not None
execute(args.workers, [convert_and_save_gptq] * world_size, args)
else: # all other non-gptq paths from hf model
assert args.model_dir is not None
assert (
args.modelopt_quant_ckpt_path is None
), "only gptq weights only needs this option"
convert_and_save_hf(args)
tok = time.time()
t = time.strftime("%H:%M:%S", time.gmtime(tok - tik))
print(f"Total time of converting checkpoints: {t}")
if __name__ == "__main__":
main()
cp convert.py /usr/local/lib/python3.10/dist-packages/tensorrt_llm/models/llama/convert.py
from safetensors.torch import load_file, safe_open
from safetensors.torch import save_file
import argparse
def parse_arguments():
parser = argparse.ArgumentParser()
parser.add_argument(
"--huggingface_repo_dir",
type=str,
)
parser.add_argument(
"--thirdparty_repo_dir",
type=str,
)
parser.add_argument(
"--merged_repo_dir",
type=str,
)
return parser.parse_args()
args = parse_arguments()
import shutil
shutil.copytree(args.huggingface_repo_dir, args.merged_repo_dir)
import torch
hf_weights_dict = dict()
hf_wgt_names = [
"model-00001-of-00004.safetensors",
"model-00002-of-00004.safetensors",
"model-00003-of-00004.safetensors",
"model-00004-of-00004.safetensors",
]
for wgt in hf_wgt_names:
ori_weights = load_file(args.huggingface_repo_dir + wgt)
for key, value in ori_weights.items():
if key == "language_model.lm_head.weight":
hf_weights_dict[key] = value
elif key == "language_model.model.embed_tokens.weight":
hf_weights_dict[key] = value
weights = [
"model-00001-of-00004.safetensors",
"model-00002-of-00004.safetensors",
"model-00003-of-00004.safetensors",
"model-00004-of-00004.safetensors",
]
for wgt in weights:
ori_weights = load_file(args.thirdparty_repo_dir + wgt)
new_weights = dict()
for key, value in ori_weights.items():
if key == "lm_head.weight":
new_key = "language_model.lm_head.weight"
elif key == "model.embed_tokens.weight":
new_key = "language_model.model.embed_tokens.weight"
elif key == "model.image_newline":
new_key = "image_newline"
elif "model.layers." in key:
new_key = key.replace("model", "language_model.model")
elif key == "model.norm.weight":
new_key = "language_model.model.norm.weight"
elif key == "model.mm_projector.0.bias":
new_key = "multi_modal_projector.linear_1.bias"
elif key == "model.mm_projector.0.weight":
new_key = "multi_modal_projector.linear_1.weight"
elif key == "model.mm_projector.2.bias":
new_key = "multi_modal_projector.linear_2.bias"
elif key == "model.mm_projector.2.weight":
new_key = "multi_modal_projector.linear_2.weight"
elif "model.vision_tower.vision_tower" in key:
new_key = key.replace("model.vision_tower.vision_tower", "vision_tower")
if new_key == "language_model.lm_head.weight":
value = torch.cat(
(value, hf_weights_dict["language_model.lm_head.weight"][32000:]), dim=0
)
elif new_key == "language_model.model.embed_tokens.weight":
value = torch.cat(
(
value,
hf_weights_dict["language_model.model.embed_tokens.weight"][32000:],
),
dim=0,
)
new_weights[new_key] = value
save_file(new_weights, args.merged_repo_dir + wgt, metadata={"format": "pt"})
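# For reference, the remapping above renames thirdparty LLaVA-style keys to the HuggingFace
# llava/llava_next layout, e.g. (examples taken directly from the branches above):
#   lm_head.weight                          -> language_model.lm_head.weight
#   model.embed_tokens.weight               -> language_model.model.embed_tokens.weight
#   model.layers.0.self_attn.q_proj.weight  -> language_model.model.layers.0.self_attn.q_proj.weight
#   model.mm_projector.0.weight             -> multi_modal_projector.linear_1.weight
#   model.vision_tower.vision_tower.<rest>  -> vision_tower.<rest>
# For lm_head and embed_tokens, rows from index 32000 onward are appended from the original
# HuggingFace checkpoint so the merged tensors keep the HF vocabulary size.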
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import json
import time
from pathlib import Path
# isort: off
import torch
import tensorrt as trt
# isort: on
import numpy as np
from transformers import (
AutoConfig,
AutoModelForSeq2SeqLM,
AutoTokenizer,
BartForConditionalGeneration,
MBartForConditionalGeneration,
T5ForConditionalGeneration,
)
import tensorrt_llm
from tensorrt_llm import logger
from tensorrt_llm._utils import torch_to_numpy, trt_dtype_to_torch
from tensorrt_llm.lora_manager import LoraManager
from tensorrt_llm.runtime import ModelConfig, SamplingConfig
def get_engine_name(rank):
return "rank{}.engine".format(rank)
def print_tensor(tensor_name, tensor, num_elements=10):
if tensor.dtype in (torch.int32, torch.int64):
tensor = tensor.to(dtype=float)
print(
f"{tensor_name}: mean={tensor.abs().mean().item():.3f}, sum={tensor.abs().sum().item():.3f}, max={tensor.abs().max().item():.3f}"
)
    # Passing num_elements=-1 prints the whole tensor
if num_elements < 0:
num_elements = torch.numel(tensor)
print(f"{tensor.flatten()[:num_elements]}")
print("Tensor Shape: ", tensor.size())
print("")
def read_config(config_path: Path):
with open(config_path, "r") as f:
config = json.load(f)
builder_config = config["build_config"]
plugin_config = builder_config["plugin_config"]
pretrained_config = config["pretrained_config"]
lora_config = builder_config["lora_config"]
auto_parallel_config = builder_config["auto_parallel_config"]
use_gpt_attention_plugin = plugin_config["gpt_attention_plugin"]
remove_input_padding = plugin_config["remove_input_padding"]
use_lora_plugin = plugin_config["lora_plugin"]
tp_size = pretrained_config["mapping"]["tp_size"]
pp_size = pretrained_config["mapping"]["pp_size"]
gpus_per_node = auto_parallel_config["gpus_per_node"]
world_size = tp_size * pp_size
assert (
world_size == tensorrt_llm.mpi_world_size()
), f"Engine world size ({world_size}) != Runtime world size ({tensorrt_llm.mpi_world_size()})"
num_heads = pretrained_config["num_attention_heads"]
hidden_size = pretrained_config["hidden_size"]
head_size = pretrained_config["head_size"]
vocab_size = pretrained_config["vocab_size"]
max_batch_size = builder_config["max_batch_size"]
max_beam_width = builder_config["max_beam_width"]
num_layers = pretrained_config["num_hidden_layers"]
num_kv_heads = pretrained_config.get("num_kv_heads", num_heads)
assert (num_heads % tp_size) == 0
num_heads = num_heads // tp_size
hidden_size = hidden_size // tp_size
num_kv_heads = (num_kv_heads + tp_size - 1) // tp_size
cross_attention = pretrained_config["architecture"] == "DecoderModel"
skip_cross_qkv = pretrained_config.get("skip_cross_qkv", False)
has_position_embedding = pretrained_config["has_position_embedding"]
    has_token_type_embedding = "type_vocab_size" in pretrained_config
use_custom_all_reduce = plugin_config.get("use_custom_all_reduce", False)
dtype = pretrained_config["dtype"]
paged_kv_cache = plugin_config["paged_kv_cache"]
tokens_per_block = plugin_config["tokens_per_block"]
gather_context_logits = builder_config.get("gather_context_logits", False)
gather_generation_logits = builder_config.get("gather_generation_logits", False)
max_prompt_embedding_table_size = builder_config.get(
"max_prompt_embedding_table_size", 0
)
model_config = ModelConfig(
num_heads=num_heads,
num_kv_heads=num_kv_heads,
hidden_size=hidden_size,
head_size=head_size,
max_batch_size=max_batch_size,
max_beam_width=max_beam_width,
vocab_size=vocab_size,
num_layers=num_layers,
gpt_attention_plugin=use_gpt_attention_plugin,
remove_input_padding=remove_input_padding,
paged_kv_cache=paged_kv_cache,
tokens_per_block=tokens_per_block,
cross_attention=cross_attention,
has_position_embedding=has_position_embedding,
has_token_type_embedding=has_token_type_embedding,
use_custom_all_reduce=use_custom_all_reduce,
dtype=dtype,
gather_context_logits=gather_context_logits,
gather_generation_logits=gather_generation_logits,
max_prompt_embedding_table_size=max_prompt_embedding_table_size,
lora_plugin=use_lora_plugin,
lora_target_modules=lora_config.get("lora_target_modules"),
trtllm_modules_to_hf_modules=lora_config.get("trtllm_modules_to_hf_modules"),
skip_cross_qkv=skip_cross_qkv,
)
return model_config, tp_size, pp_size, gpus_per_node, dtype
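# A minimal sketch (not used by the runtime path) of the per-rank sharding arithmetic applied
# in read_config(); the concrete numbers are illustrative assumptions, not values read from a
# real engine config.
def _demo_per_rank_shapes(num_heads=32, num_kv_heads=8, hidden_size=4096, tp_size=4):
    assert (num_heads % tp_size) == 0
    return {
        "num_heads": num_heads // tp_size,  # 32 / 4 = 8 heads per rank
        "hidden_size": hidden_size // tp_size,  # 4096 / 4 = 1024 per rank
        # ceil-division keeps at least one KV head per rank when tp_size > num_kv_heads (GQA/MQA)
        "num_kv_heads": (num_kv_heads + tp_size - 1) // tp_size,  # 8 / 4 = 2 per rank
    }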
def parse_arguments():
parser = argparse.ArgumentParser()
parser.add_argument("--max_new_tokens", type=int, default=64)
parser.add_argument("--log_level", type=str, default="error")
parser.add_argument("--engine_dir", "-i", type=str, default="trt_engines")
parser.add_argument("--engine_name", type=str, default="enc_dec")
parser.add_argument(
"--model_name",
type=str,
help="HuggingFace model name or FairSeq model path",
default="t5-small",
)
parser.add_argument(
"--num_beams", type=int, help="Use beam search if num_beams >1", default=1
)
parser.add_argument(
"--debug_mode",
help="Whether or not to turn on the debug mode",
action="store_true",
)
parser.add_argument(
"--compare_hf_fp32",
help="Compare results with HuggingFace FP32",
action="store_true",
)
parser.add_argument("--lora_dir", type=str, default=None, nargs="+")
parser.add_argument("--lora_task_uids", type=str, default=None, nargs="+")
parser.add_argument(
"--output_encoder_npy",
help="Store tensors like encoder outputs used for testing enc-dec C++ runtime.",
action="store_true",
)
return parser.parse_args()
class TRTLLMEncDecModel:
def __init__(
self,
engine_name,
engine_dir,
lora_dir=None,
lora_task_uids=None,
debug_mode=False,
skip_encoder=False,
stream: torch.cuda.Stream = None,
):
        # in a multi-node setup, it's important to call set_device at the very beginning so that .to('cuda') refers to the current device
        # accordingly, all input & output tensors should be moved to the current device
        # otherwise, they default to 'cuda:0'
self.runtime_rank = tensorrt_llm.mpi_rank()
device_id = self.runtime_rank % torch.cuda.device_count()
torch.cuda.set_device(device_id)
self.device = torch.cuda.current_device()
self.skip_encoder = skip_encoder
self.lora_task_uids = lora_task_uids
        # when enc-dec runs by itself, stream can be None and we create a new stream here
        # when enc-dec runs as a component in a bigger workflow (e.g., multimodal), earlier components may have produced results on their own stream, so that stream should be passed in to avoid unnecessary stream synchronization
self.stream = stream
if self.stream is None:
self.stream = torch.cuda.Stream(self.device)
torch.cuda.set_stream(self.stream)
engine_dir = Path(engine_dir)
def engine_setup(component):
# model config
config_path = engine_dir / component / "config.json"
logger.info(f"Using config path {config_path}")
model_config, tp_size, pp_size, gpus_per_node, dtype = read_config(
config_path
)
# MGMN config
world_size = tp_size * pp_size
runtime_rank = tensorrt_llm.mpi_rank()
assert (
runtime_rank < world_size
), "Runtime GPU rank exceeds MPI world size. Did you launch more MPI processes than required?"
runtime_mapping = tensorrt_llm.Mapping(
world_size,
runtime_rank,
tp_size=tp_size,
pp_size=pp_size,
gpus_per_node=gpus_per_node,
)
# load engine
engine_fname = get_engine_name(runtime_rank)
with open(engine_dir / component / engine_fname, "rb") as f:
engine_buffer = f.read()
return model_config, runtime_mapping, engine_buffer
        # Note: the encoder and decoder don't necessarily have the same TP & PP config
if not skip_encoder:
(
self.encoder_model_config,
self.encoder_runtime_mapping,
encoder_engine_buffer,
) = engine_setup(component="encoder")
# for Pipeline Parallelism in encoder
self.nccl_comm = torch.classes.trtllm.NcclCommunicatorOp(
self.encoder_runtime_mapping.tp_size,
self.encoder_runtime_mapping.pp_size,
self.encoder_runtime_mapping.rank,
)
# session setup
self.encoder_session = tensorrt_llm.runtime.Session.from_serialized_engine(
encoder_engine_buffer
)
# encoder lora manager setup
if self.encoder_model_config.lora_plugin:
self.encoder_lora_manager = LoraManager()
# TODO: this is only for bart
self.encoder_lora_manager.load_from_hf(
model_dirs=lora_dir,
model_config=self.encoder_model_config,
runtime_mapping=self.encoder_runtime_mapping,
component="encoder",
)
else:
self.encoder_lora_manager = None
else:
(
self.encoder_model_config,
self.encoder_runtime_mapping,
encoder_engine_buffer,
) = (None, None, None)
self.nccl_comm, self.encoder_session = None, None
(
self.decoder_model_config,
self.decoder_runtime_mapping,
decoder_engine_buffer,
) = engine_setup(component="decoder")
self.decoder_session = tensorrt_llm.runtime.GenerationSession(
self.decoder_model_config,
decoder_engine_buffer,
self.decoder_runtime_mapping,
debug_mode=debug_mode,
)
# decoder lora manager setup
if self.decoder_model_config.lora_plugin:
self.decoder_lora_manager = LoraManager()
# TODO: this is only for bart
self.decoder_lora_manager.load_from_hf(
model_dirs=lora_dir,
model_config=self.decoder_model_config,
runtime_mapping=self.decoder_runtime_mapping,
component="decoder",
)
else:
self.decoder_lora_manager = None
@classmethod
def from_engine(
cls,
engine_name,
engine_dir,
lora_dir=None,
lora_task_uids=None,
debug_mode=False,
skip_encoder=False,
stream=None,
):
return cls(
engine_name,
engine_dir,
lora_dir,
lora_task_uids,
debug_mode=debug_mode,
skip_encoder=skip_encoder,
stream=stream,
)
def process_input(
self, input_ids, remove_input_padding=False, pad_token_id=0, prompt_tasks=None
):
if remove_input_padding:
# in remove padding mode --> flatten input, calculate actual length and max length
# Note: 1st token should never be removed, even if it is pad_token_id
first_ids = input_ids[:, 0]
input_ids = input_ids[:, 1:]
input_lengths = 1 + (input_ids != pad_token_id).sum(dim=1).type(
torch.IntTensor
).to(
self.device
) # [batch_size]
new_ids = []
for i in range(len(input_ids)):
row = input_ids[i, :]
row = row[row != pad_token_id]
new_ids.append(
torch.cat((torch.IntTensor([first_ids[i]]).to(self.device), row))
)
input_ids = torch.cat(new_ids) # [num_tokens]
if prompt_tasks is not None:
prompt_tasks = prompt_tasks[: input_ids.shape[0]]
else:
# in padding mode --> keep input, just calculate actual length and max length
            # Note: the 1st token should always count, even if it is pad_token_id; e.g., the decoder start id in enc-dec models could be a single pad_token_id, and we should still count it
input_lengths = torch.tensor(
1
+ (input_ids[:, 1:] != pad_token_id)
.sum(dim=1)
.type(torch.IntTensor)
.to(self.device),
dtype=torch.int32,
device=self.device,
)
max_input_length = torch.max(input_lengths).item()
return input_ids, input_lengths, max_input_length, prompt_tasks
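    # Worked example (illustrative values) of the remove_input_padding branch above, with
    # pad_token_id = 0:
    #   input_ids = [[5, 7, 0, 0],
    #                [9, 3, 4, 0]]
    #   the first column is always kept, the remaining tokens are filtered per row, and rows are
    #   concatenated: input_ids -> [5, 7, 9, 3, 4], input_lengths -> [2, 3], max_input_length -> 3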
def encoder_run(
self,
input_ids,
input_lengths,
max_input_length,
position_ids=None,
token_type_ids=None,
debug_mode=False,
prompt_embedding_table=None,
prompt_tasks=None,
prompt_vocab_size=None,
attention_mask=None,
):
# each engine has hidden_dim/TP, don't forget to multiply TP
hidden_size = (
self.encoder_model_config.hidden_size * self.encoder_runtime_mapping.tp_size
)
if input_ids.dim() == 1:
hidden_states_shape = (input_ids.shape[0], hidden_size) # [num_tokens,D]
else:
hidden_states_shape = (
input_ids.shape[0],
input_ids.shape[1],
hidden_size,
) # [BS,seqlen,D]
hidden_states_dtype = lambda name: trt_dtype_to_torch(
self.encoder_session.engine.get_tensor_dtype(name)
)
        # input tensors: only the first PP rank takes ids as input, the others take hidden_states as input
inputs = {}
if self.encoder_runtime_mapping.is_first_pp_rank():
inputs["input_ids"] = input_ids.contiguous()
if self.encoder_model_config.has_position_embedding:
if position_ids is None:
if self.encoder_model_config.remove_input_padding:
position_ids = [
torch.arange(
sample_length,
dtype=torch.int32,
device=input_ids.device,
)
for sample_length in torch_to_numpy(input_lengths)
]
position_ids = torch.cat(position_ids)
else:
bsz, seq_len = input_ids.shape[:2]
position_ids = torch.arange(
seq_len, dtype=torch.int32, device=input_ids.device
).expand(bsz, -1)
inputs["position_ids"] = position_ids.contiguous()
if self.encoder_model_config.has_token_type_embedding:
inputs["token_type_ids"] = token_type_ids.contiguous()
if self.encoder_model_config.max_prompt_embedding_table_size > 0:
inputs["prompt_embedding_table"] = prompt_embedding_table.contiguous()
inputs["tasks"] = prompt_tasks.contiguous()
inputs["prompt_vocab_size"] = prompt_vocab_size.contiguous()
else:
# just need a placeholder, engine will call NCCL to recv and fill data from previous rank
inputs["hidden_states_input"] = torch.empty(
hidden_states_shape,
dtype=hidden_states_dtype("hidden_states_input"),
device=self.device,
).contiguous()
if (
attention_mask is not None
and not self.encoder_model_config.gpt_attention_plugin
):
inputs["attention_mask"] = attention_mask.contiguous()
inputs["input_lengths"] = input_lengths
# use shape info to pass max length info in remove padding mode
inputs["max_input_length"] = torch.empty(
(max_input_length,),
dtype=hidden_states_dtype("max_input_length"),
device=self.device,
).contiguous()
batch_size = input_lengths.size(0)
inputs["host_request_types"] = torch.IntTensor([0] * batch_size).to("cpu")
if self.encoder_model_config.remove_input_padding:
inputs["host_context_lengths"] = input_lengths.to("cpu")
if (
self.encoder_model_config.lora_plugin
and self.encoder_lora_manager is not None
):
inputs.update(
self.encoder_lora_manager.input_buffers(
self.lora_task_uids,
self.encoder_runtime_mapping,
self.encoder_model_config.num_layers,
)
)
# Note: runtime.Session's run() method will set input/output tensor address, here we only need to provide tensor shape
self.encoder_session.set_shapes(inputs)
        # output tensors: only the last PP rank produces the final encoder output, the others produce intermediate hidden_states output that needs to be broadcast later
outputs = {}
if self.encoder_runtime_mapping.is_last_pp_rank():
outputs["encoder_output"] = torch.empty(
hidden_states_shape,
dtype=hidden_states_dtype("encoder_output"),
device=self.device,
).contiguous()
else:
outputs["hidden_states_output"] = torch.empty(
hidden_states_shape,
dtype=hidden_states_dtype("hidden_states_output"),
device=self.device,
).contiguous()
# -------------------------------------------
if debug_mode:
engine = self.encoder_session.engine
context = self.encoder_session.context
# setup debugging buffer for the encoder
for i in range(self.encoder_session.engine.num_io_tensors):
name = engine.get_tensor_name(i)
if (
engine.get_tensor_mode(name) == trt.TensorIOMode.OUTPUT
and name not in outputs.keys()
):
dtype = engine.get_tensor_dtype(name)
shape = context.get_tensor_shape(name)
outputs[name] = torch.zeros(
tuple(shape),
dtype=trt_dtype_to_torch(dtype),
device=self.device,
)
context.set_tensor_address(name, outputs[name].data_ptr())
# -------------------------------------------
# TRT session run
# Note: need cuda stream ID, not a torch Stream
ok = self.encoder_session.run(inputs, outputs, self.stream.cuda_stream)
assert ok, "Runtime execution failed"
self.stream.synchronize()
# Tensor Parallelism is handled by model/engine definition
# But we need to broadcast among PP group at the end of encoder's Pipeline Parallelism
        # After this, all ranks should receive the encoder output, and the world might be re-configured using the decoder's TP-PP config
def pp_communicate_encoder_output(encoder_output):
if self.encoder_runtime_mapping.is_last_pp_rank():
for pp_rank in self.encoder_runtime_mapping.pp_group:
if pp_rank != self.encoder_runtime_mapping.rank:
self.nccl_comm.send(encoder_output, pp_rank)
return encoder_output
else:
self.nccl_comm.recv(
encoder_output, self.encoder_runtime_mapping.pp_group[-1]
)
return encoder_output
if self.encoder_runtime_mapping.has_pp():
            # use the hidden_states output buffer to receive the output, as the shapes are the same
encoder_output_buf = (
outputs["encoder_output"]
if self.encoder_runtime_mapping.is_last_pp_rank()
else outputs["hidden_states_output"]
)
encoder_output = pp_communicate_encoder_output(encoder_output_buf)
else:
encoder_output = outputs["encoder_output"]
# -------------------------------------------
if (
debug_mode and self.encoder_runtime_mapping.tp_rank == 0
        ):  # only tp_rank 0 prints the encoder output
torch.cuda.synchronize()
# use print_tensor() to print the tensors registered in the encoder network
print("--------------------------------------")
print("Debug output for Encoder")
print("--------------------------------------")
print("Registered output tensors are: ", outputs.keys())
for k, v in outputs.items():
print_tensor(k, v, num_elements=30)
print_tensor("encoder_output", encoder_output)
print("--------------------------------------")
# -------------------------------------------
return encoder_output
def generate(
self,
encoder_input_ids,
decoder_input_ids,
max_new_tokens,
num_beams=1,
pad_token_id=None,
eos_token_id=None,
bos_token_id=None,
debug_mode=False,
return_dict=False,
prompt_embedding_table=None,
prompt_tasks=None,
prompt_vocab_size=None,
attention_mask=None,
time_encoder=False,
return_encoder_output=False,
):
## ensure all externally provided tensors are on the correct device.
encoder_input_ids = encoder_input_ids.to(self.device)
decoder_input_ids = decoder_input_ids.to(self.device)
if attention_mask is not None:
attention_mask = torch.tensor(
attention_mask, dtype=torch.int32, device=self.device
)
## encoder run
encoder_remove_input_padding = (
self.encoder_model_config.remove_input_padding
if self.encoder_model_config
else self.decoder_model_config.remove_input_padding
)
(
encoder_input_ids,
encoder_input_lengths,
encoder_max_input_length,
prompt_tasks,
) = self.process_input(
encoder_input_ids, encoder_remove_input_padding, pad_token_id, prompt_tasks
)
if not self.skip_encoder:
logger.info(f"Rank {self.runtime_rank} Running encoder engine ...")
if time_encoder:
tik = time.time()
encoder_output = self.encoder_run(
encoder_input_ids,
encoder_input_lengths,
encoder_max_input_length,
debug_mode=debug_mode,
prompt_embedding_table=prompt_embedding_table,
prompt_tasks=prompt_tasks,
prompt_vocab_size=prompt_vocab_size,
attention_mask=attention_mask,
)
if time_encoder:
tok = time.time()
print(f"TRT-LLM Encoder time {(tok-tik)*1000}ms")
else:
encoder_output = prompt_embedding_table
if encoder_input_ids.dim() > 1:
encoder_output = encoder_output.unsqueeze(0)
## decoder run
logger.info(f"Rank {self.runtime_rank} Running decoder engine ...")
decoder_input_ids, decoder_input_lengths, decoder_max_input_length, _ = (
self.process_input(
decoder_input_ids,
self.decoder_model_config.remove_input_padding,
pad_token_id,
)
)
# `cross_attention_mask` in context phase [batch_size, query_len, encoder_input_len]
# where query_len happens to be 1 in current cases, but not necessarily always, and
# `cross_attention_mask` in generation phase [batch_size, 1, encoder_input_len] where
# the query_len is always 1 since we have kv cache.
cross_attention_mask = None
if attention_mask is not None:
cross_attention_mask = torch.tensor(
attention_mask, dtype=torch.int32, device=self.device
).reshape(attention_mask.shape[0], 1, attention_mask.shape[1])
# generation config
sampling_config = SamplingConfig(
end_id=eos_token_id,
pad_id=pad_token_id,
num_beams=num_beams,
min_length=1,
return_dict=return_dict,
)
sampling_config.update(
output_cum_log_probs=return_dict, output_log_probs=return_dict
)
# decoder autoregressive generation
self.decoder_session.setup(
decoder_input_lengths.size(0),
decoder_max_input_length,
max_new_tokens,
num_beams,
max_attention_window_size=None,
encoder_max_input_length=encoder_max_input_length,
lora_manager=self.decoder_lora_manager,
lora_uids=self.lora_task_uids,
)
output = self.decoder_session.decode(
decoder_input_ids,
decoder_input_lengths,
sampling_config,
encoder_output=encoder_output,
encoder_input_lengths=encoder_input_lengths,
return_dict=return_dict,
cross_attention_mask=cross_attention_mask,
)
if return_encoder_output:
return output, encoder_output
return output
def test_fairseq_models(args):
    ## Note: NMT is the only FairSeq model here. Adding a FairSeq dependency is too heavy for the CI workflow, hence we use fixed input/output ids for the correctness check and leave the FairSeq code in comments. Users can follow the Encoder-Decoder README to install FairSeq and test locally.
"""
from fairseq.models.transformer import TransformerModel
fairseq_model = TransformerModel.from_pretrained(model_name_or_path=args.model_name, data_name_or_path=args.model_name, bpe='subword_nmt', tokenizer='moses').cuda()
input_text = "Good Morning! How are you doing today?"
input_ids = fairseq_model.encode(input_text)
tik = time.time()
# Note: FairSeq sampling=True results are not deterministic, disable during accuracy check
fairseq_output_ids = fairseq_model.generate(input_ids, beam=1, sampling=False) #
    tok = time.time()
fairseq_output_ids = fairseq_output_ids[0]['tokens']
fairseq_output_text = fairseq_model.decode(fairseq_output_ids)
print("--------------------------------------")
print("input text: ", input_text)
print("input ids: ", input_ids) # [9938, 5384, 9328, 812, 3619, 53, 181, 3829, 1735, 171, 2]
print("fairseq_output ids: ", fairseq_output_ids) # [9804, 391, 4, 4625, 167, 25, 1003, 5123, 17, 167, 1466, 1234, 171, 2]
print("fairseq_output text: ", fairseq_output_text) # "Bonjour, Comment vous en tirez-vous aujourd'hui ?"
print(f"FairSeq E2E time {(tok-tik)*1000}ms")
print("--------------------------------------")
"""
max_new_tokens = args.max_new_tokens
bos_token_id = 2
pad_token_id = 0
eos_token_id = 2
decoder_start_token_id = bos_token_id
input_ids = torch.tensor([9938, 5384, 9328, 812, 3619, 53, 181, 3829, 1735, 171, 2])
fairseq_output_ids = torch.tensor(
[9804, 391, 4, 4625, 167, 25, 1003, 5123, 17, 167, 1466, 1234, 171, 2]
)
input_ids = torch.tensor([input_ids.tolist()]).type(torch.IntTensor).cuda()
decoder_input_ids = torch.IntTensor([[decoder_start_token_id]]).cuda()
decoder_input_ids = decoder_input_ids.repeat((input_ids.shape[0], 1))
tllm_model = TRTLLMEncDecModel.from_engine(
args.engine_name, args.engine_dir, debug_mode=args.debug_mode
)
inference_dtype = tllm_model.encoder_model_config.dtype
    return_dict = False  # when return_dict=True, retrieve outputs by key
tik = time.time()
tllm_output = tllm_model.generate(
encoder_input_ids=input_ids,
decoder_input_ids=decoder_input_ids,
max_new_tokens=max_new_tokens,
num_beams=args.num_beams,
bos_token_id=bos_token_id,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
debug_mode=args.debug_mode,
)
tok = time.time()
torch.cuda.synchronize()
if return_dict:
tllm_output_ids = tllm_output["output_ids"]
else:
tllm_output_ids = tllm_output
if tensorrt_llm.mpi_rank() == 0:
output_ids = tllm_output_ids[:, 0, :]
output_ids = output_ids[output_ids != eos_token_id]
fairseq_output_ids = fairseq_output_ids[fairseq_output_ids != eos_token_id]
print("--------------------------------------")
print("TRT-LLM output_ids: ", output_ids)
print(f"TRT-LLM E2E time {(tok-tik)*1000}ms")
print("Precision:", inference_dtype)
print("--------------------------------------")
assert (
output_ids.tolist() == fairseq_output_ids.tolist()
), f"TRT-LLM output ids {output_ids} does not match Fairseq ids {fairseq_output_ids}"
if __name__ == "__main__":
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
args = parse_arguments()
logger.set_level(args.log_level)
# FairSeq NMT test logic is different from HuggingFace models
if "wmt" in args.model_name:
test_fairseq_models(args)
exit()
test_remove_padding = True
if not test_remove_padding:
if "t5" in args.model_name:
input_text = "translate English to German: The house is wonderful, radiating timeless charm and offering a warm, inviting interior with beautiful details and a serene backyard."
elif "bart" in args.model_name:
input_text = "The tower is 324 metres (1,063 ft) tall, about the same height as an 81-storey building, and the tallest structure in Paris. Its base is square, measuring 125 metres (410 ft) on each side. During its construction, the Eiffel Tower surpassed the Washington Monument to become the tallest man-made structure in the world, a title it held for 41 years until the Chrysler Building in New York City was finished in 1930. It was the first structure to reach a height of 300 metres. Due to the addition of a broadcasting aerial at the top of the tower in 1957, it is now taller than the Chrysler Building by 5.2 metres (17 ft). Excluding transmitters, the Eiffel Tower is the second tallest free-standing structure in France after the Millau Viaduct."
else:
raise RuntimeError("Unsupported model type!")
else:
input_text = [
"translate English to German: The house is wonderful.",
"summarize: I am a high-performance inference optimizer and runtime.",
"During its construction, the Eiffel Tower surpassed the Washington Monument to become the tallest man-made structure in the world",
]
# TRT-LLM runtime
tllm_model = TRTLLMEncDecModel.from_engine(
args.engine_name,
args.engine_dir,
args.lora_dir,
args.lora_task_uids,
debug_mode=args.debug_mode,
)
inference_dtype = tllm_model.encoder_model_config.dtype
if inference_dtype == "float32":
if "byt5" in args.model_name:
print(
"ByT5 models tokenize input by bytes instead of words, causing the input text in this example to be longer than the default value during build stage. Please adjust --max_input_len during trtllm-build to select the right length limit for ByT5 models."
)
else:
input_text.append(
'Summarize this article in one sentence.\n\nKristine Watts (Molie Weeks) is broken apart, missing her lover; she is not able to overcome her love for him that is lost in the past. She hires a stranger (Douglas Davis) and gives a list of her mistakes to him with things to fix. But time is irreversible and sometimes the cure for the pain is a tragic end.\n\nThe first point that impresses in "The Cure" is the stylish cinematography that alternates black and white with color. The concise and sharp screenplay is capable to develop a tragic and bleak tale of love with an unexpected plot point in the very end in less than eight minutes. The soundtrack is beautiful but the volume is a little loud and associated to the fact that English is not my native language, in some moments I needed to repeat some words whispered by the narrator. The unknown lead actress has magnificent performance and is extremely gorgeous. I hope to have a chance to see her again on the screen. Last but not the least, the debut of the director and writer Ryan Jafri could not be better. My vote is nine.\n\nTitle (Brazil): Not Available',
)
tokenizer = AutoTokenizer.from_pretrained(
args.model_name
) # TODO: use model path instead
tokenized_inputs = tokenizer(input_text, return_tensors="pt", padding=True)
max_new_tokens = args.max_new_tokens
input_ids = tokenized_inputs.input_ids.type(torch.IntTensor).to(
"cuda"
) # [batch_size, padded_length]
# by default int64, must cast to int32! otherwise C++ kernel will interpret as [a, 0, b, 0, c, 0, ...]
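    # e.g., the int64 id 5 is the little-endian byte pattern of the int32 pair (5, 0), so a batch
    # [a, b, c] read as int32 by the C++ runtime would look like [a, 0, b, 0, c, 0]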
CPP_RESULTS_SAVED_DIR = "cpp/tests/resources/data/enc_dec"
if tensorrt_llm.mpi_rank() == 0:
if args.output_encoder_npy:
if not os.path.isdir(CPP_RESULTS_SAVED_DIR):
                os.makedirs(CPP_RESULTS_SAVED_DIR, exist_ok=True)
np_input_ids = tokenized_inputs.input_ids.type(torch.IntTensor)
np_input_ids = np_input_ids.numpy()
np.save(
os.path.join(CPP_RESULTS_SAVED_DIR, "enc_input_ids.npy"), np_input_ids
)
input_lengths = (
tokenized_inputs.attention_mask.sum(dim=1).type(torch.IntTensor).numpy()
)
np.save(
os.path.join(CPP_RESULTS_SAVED_DIR, "enc_input_lengths.npy"),
input_lengths,
)
print("--------------------------------------")
print(
f"BOS={tokenizer.bos_token_id}, PAD={tokenizer.pad_token_id}, EOS={tokenizer.eos_token_id}"
)
print("input text: ", input_text)
print("input ids: ", input_ids)
print("input lengths: ", tokenized_inputs.attention_mask.sum(dim=1))
print("--------------------------------------")
model_config = AutoConfig.from_pretrained(args.model_name)
# start_id for decoder (could add more input_ids as forced_decoder_ids)
decoder_input_ids = torch.IntTensor([[model_config.decoder_start_token_id]]).to(
"cuda"
)
decoder_input_ids = decoder_input_ids.repeat((input_ids.shape[0], 1))
# simple comparison with HF on FP32
if args.compare_hf_fp32:
if tensorrt_llm.mpi_rank() == 0:
hf_model = (
AutoModelForSeq2SeqLM.from_pretrained(
args.model_name, # TODO: use model path instead
# torch_dtype=torch.float16 if '16' in dtype else torch.float32, # TODO: use matched torch dtype
)
.to("cuda")
.eval()
) # TODO: create config model path instead
assert type(hf_model) in (
T5ForConditionalGeneration,
BartForConditionalGeneration,
MBartForConditionalGeneration,
), "Unsupported model!"
if args.lora_dir is not None:
assert (
len(args.lora_dir) >= 1
), "At least one lora model dir is required"
# we can only test single lora with HF
from peft import PeftModel
hf_model = (
PeftModel.from_pretrained(hf_model, args.lora_dir[0])
.to("cuda")
.eval()
)
tik = time.time()
hf_gen_output = hf_model.generate(
input_ids=input_ids,
decoder_input_ids=decoder_input_ids,
max_new_tokens=max_new_tokens,
num_beams=args.num_beams,
bos_token_id=tokenizer.bos_token_id,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
use_cache=True,
# control logits processors
no_repeat_ngram_size=0, # disable no repeat post-processor
forced_bos_token_id=None, # disable forced first/last token
forced_eos_token_id=None,
min_length=0,
# for debug
output_scores=True,
output_hidden_states=True,
return_dict_in_generate=True,
)
# get hf output scores
hf_output_ids = hf_gen_output.sequences
# convert to logits
torch.cuda.synchronize()
tok = time.time()
output_ids = hf_output_ids.squeeze(dim=1)
hf_output_text = tokenizer.batch_decode(
output_ids, skip_special_tokens=True
)
decoder_input_lengths = (decoder_input_ids != tokenizer.pad_token_id).sum(
dim=1
)
output_gen_lengths = (output_ids != tokenizer.eos_token_id).sum(
dim=1
) - decoder_input_lengths
print("--------------------------------------")
print("HF output_ids: ", output_ids)
print("HF output text: ", hf_output_text)
print("HF output generated lengths: ", output_gen_lengths)
print(f"HF E2E time {(tok-tik)*1000}ms")
print("--------------------------------------")
    return_dict = False  # when return_dict=True, retrieve outputs by key
tik = time.time()
tllm_output = tllm_model.generate(
encoder_input_ids=input_ids,
decoder_input_ids=decoder_input_ids,
max_new_tokens=max_new_tokens,
num_beams=args.num_beams,
bos_token_id=tokenizer.bos_token_id,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
debug_mode=args.debug_mode,
return_dict=return_dict,
attention_mask=tokenized_inputs.attention_mask,
time_encoder=True,
return_encoder_output=args.output_encoder_npy and tensorrt_llm.mpi_rank() == 0,
)
tok = time.time()
if args.output_encoder_npy and tensorrt_llm.mpi_rank() == 0:
tllm_output, encoder_output = tllm_output
encoder_output = encoder_output.cpu().numpy()
np.save(
os.path.join(CPP_RESULTS_SAVED_DIR, "encoder_output.npy"), encoder_output
)
if return_dict:
tllm_output_ids = tllm_output["output_ids"]
else:
tllm_output_ids = tllm_output
if tensorrt_llm.mpi_rank() == 0:
output_ids = tllm_output_ids[:, 0, :]
output_text = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
decoder_input_lengths = (decoder_input_ids != tokenizer.pad_token_id).sum(dim=1)
output_gen_lengths = (output_ids != tokenizer.eos_token_id).sum(
dim=1
) - decoder_input_lengths
print("--------------------------------------")
print("TRT-LLM output_ids: ", output_ids)
print("TRT-LLM output text: ", output_text)
print("TRT-LLM output generated lengths: ", output_gen_lengths)
print(f"TRT-LLM E2E time {(tok-tik)*1000}ms")
print("Precision:", inference_dtype)
print("--------------------------------------")
# simple accuracy check
if args.compare_hf_fp32:
from difflib import SequenceMatcher
match_rate = SequenceMatcher(
None, "\n".join(output_text), "\n".join(hf_output_text)
).ratio()
print(output_text)
print(hf_output_text)
if inference_dtype != "float32":
print("")
print(
f"[CAVEAT] Comparing TRT-LLM {inference_dtype} results with HF float32 results. Close match are not expected!"
)
assert match_rate > 0.8, f"Incorrect results! Match rate {match_rate}"
else:
assert match_rate > 0.95, f"Incorrect results! Match rate {match_rate}"
print(
f"TRT-LLM results match HF FP32 results with literal match rate {match_rate}"
)
import argparse
import json
import os
import sys
from pathlib import Path
from typing import Tuple, List, Union
from torchvision.transforms import InterpolationMode
from torchvision import transforms
import requests
# isort: off
import torch
import tensorrt as trt
# isort: on
from huggingface_hub import hf_hub_download
from PIL import Image
from transformers import (
AutoConfig,
AutoProcessor,
AutoTokenizer,
Blip2Processor,
NougatProcessor,
NougatTokenizerFast,
)
import tensorrt_llm
import tensorrt_llm.profiler as profiler
from tensorrt_llm import logger
from tensorrt_llm._utils import str_dtype_to_trt
from tensorrt_llm.runtime import ModelRunner, Session, TensorInfo
import pandas as pd
from run import TRTLLMEncDecModel
import tqdm
class Preprocss:
def __init__(
self,
image_size: int,
):
mean = (0.48145466, 0.4578275, 0.40821073)
std = (0.26862954, 0.26130258, 0.27577711)
self.image_transform = transforms.Compose(
[
transforms.Resize(
(image_size, image_size), interpolation=InterpolationMode.BICUBIC
),
transforms.ToTensor(),
transforms.Normalize(mean=mean, std=std),
]
)
def encode(self, image_list):
images = []
for image in image_list:
image = image.convert("RGB")
images.append(self.image_transform(image))
images = torch.stack(images, dim=0)
return images
image_pre_obj = Preprocss(336)
def parse_arguments():
parser = argparse.ArgumentParser()
parser.add_argument("--max_new_tokens", type=int, default=30)
parser.add_argument("--batch_size", type=int, default=1)
parser.add_argument("--log_level", type=str, default="info")
parser.add_argument(
"--visual_engine_dir",
type=str,
default=None,
help="Directory containing visual TRT engines",
)
parser.add_argument(
"--llm_engine_dir",
type=str,
default=None,
help="Directory containing TRT-LLM engines",
)
parser.add_argument(
"--hf_model_dir", type=str, default=None, help="Directory containing tokenizer"
)
parser.add_argument("--content", type=str, default=None)
parser.add_argument(
"--image_file", type=str, default="images/demo1.jpeg"
    )  # 'images/demo1.jpeg'
parser.add_argument("--input_file", type=str, default=None) # 'images/demo.csv'
parser.add_argument(
"--output_file", type=str, default=None
) # 'images/demo_res.csv'
parser.add_argument(
"--mode",
choices=["caption_zh", "caption_en", "insert_content"],
default="caption_zh",
)
parser.add_argument(
"--num_beams", type=int, help="Use beam search if num_beams >1", default=1
)
parser.add_argument("--top_k", type=int, default=1)
parser.add_argument("--top_p", type=float, default=0.0)
parser.add_argument("--temperature", type=float, default=1.0)
parser.add_argument("--repetition_penalty", type=float, default=1.0)
parser.add_argument(
"--run_profiling",
action="store_true",
help="Profile runtime over several iterations",
)
parser.add_argument(
"--check_accuracy", action="store_true", help="Check correctness of text output"
)
return parser.parse_args()
def trt_dtype_to_torch(dtype):
if dtype == trt.float16:
return torch.float16
elif dtype == trt.float32:
return torch.float32
elif dtype == trt.int32:
return torch.int32
elif dtype == trt.bfloat16:
return torch.bfloat16
else:
raise TypeError("%s is not supported" % dtype)
class MultimodalModelRunner:
def __init__(self, args):
self.args = args
self.runtime_rank = tensorrt_llm.mpi_rank()
device_id = self.runtime_rank % torch.cuda.device_count()
torch.cuda.set_device(device_id)
self.device = "cuda:%d" % (device_id)
self.stream = torch.cuda.Stream(torch.cuda.current_device())
torch.cuda.set_stream(self.stream)
# parse model type from visual engine config
with open(os.path.join(self.args.visual_engine_dir, "config.json"), "r") as f:
config = json.load(f)
self.model_type = config["builder_config"]["model_type"]
self.vision_precision = config["builder_config"]["precision"]
if self.model_type == "pix2struct":
self.vision_precision = "float16"
self.decoder_llm = not (
"t5" in self.model_type or self.model_type in ["nougat", "pix2struct"]
) # BLIP2-T5, pix2struct and Nougat are using encoder-decoder models as LLMs
self.profiling_iterations = 20
self.init_image_encoder()
self.init_tokenizer()
self.init_llm()
def init_tokenizer(self):
if self.model_type == "nougat":
self.tokenizer = NougatTokenizerFast.from_pretrained(self.args.hf_model_dir)
elif self.model_type == "neva":
from sentencepiece import SentencePieceProcessor
sp = SentencePieceProcessor(
os.path.join(self.args.hf_model_dir, "tokenizer.model")
)
class return_obj:
def __init__(self, input_ids):
self.input_ids = input_ids
def __getitem__(self, name):
                    if name == "input_ids":
return self.input_ids
else:
raise AttributeError(f"'return_obj' has no item '{name}'")
# sentencepiece does not follow the same interface as HF
class HFTokenizerInterface:
def encode(self, x, return_tensors=None, **kwargs):
out = sp.encode(x)
if return_tensors == "pt":
out = torch.tensor(out)
return return_obj(out)
def __call__(self, x, return_tensors=None, **kwargs):
return self.encode(x, return_tensors, **kwargs)
def decode(self, x, **kwargs):
return sp.decode(x.tolist())
def batch_decode(self, x, **kwargs):
return self.decode(x, **kwargs)
self.tokenizer = HFTokenizerInterface()
self.tokenizer.eos_token_id = sp.eos_id()
self.tokenizer.bos_token_id = sp.bos_id()
self.tokenizer.pad_token_id = sp.pad_id()
else:
self.tokenizer = AutoTokenizer.from_pretrained(
self.args.hf_model_dir, use_fast=False, use_legacy=False
)
self.tokenizer.padding_side = "right"
def init_image_encoder(self):
vision_encoder_path = os.path.join(
self.args.visual_engine_dir, "visual_encoder.engine"
)
logger.info(f"Loading engine from {vision_encoder_path}")
with open(vision_encoder_path, "rb") as f:
engine_buffer = f.read()
logger.info(f"Creating session from engine {vision_encoder_path}")
self.visual_encoder_session = Session.from_serialized_engine(engine_buffer)
def init_llm(self):
if self.decoder_llm:
self.model = ModelRunner.from_dir(
self.args.llm_engine_dir,
rank=tensorrt_llm.mpi_rank(),
debug_mode=False,
stream=self.stream,
)
self.model_config = self.model.session._model_config
self.runtime_mapping = self.model.session.mapping
else:
self.model = TRTLLMEncDecModel.from_engine(
os.path.basename(self.args.hf_model_dir),
self.args.llm_engine_dir,
skip_encoder=self.model_type in ["nougat", "pix2struct"],
debug_mode=False,
stream=self.stream,
)
if self.model_type in ["nougat", "pix2struct"]:
self.model_config = self.model.decoder_model_config
self.runtime_mapping = self.model.decoder_runtime_mapping
else:
self.model_config = self.model.encoder_model_config
self.runtime_mapping = self.model.encoder_runtime_mapping
def preprocess(self, warmup, pre_prompt, post_prompt, image, attention_mask):
if self.model_type == "kosmos-2":
input_ids = image["input_ids"].clone()
image_mask = image["image_embeds_position_mask"]
image = image["pixel_values"]
input_ids += image_mask * (self.model_config.vocab_size - 4)
input_ids = input_ids.expand(self.args.batch_size, *input_ids.shape[1:])
length = input_ids.shape[1]
if not warmup:
profiler.start("Vision")
visual_features, visual_atts = self.get_visual_features(
(
torch.stack(image["image_patches"], dim=0)
if self.model_type == "fuyu"
else image
),
attention_mask,
)
if not warmup:
profiler.stop("Vision")
if self.model_type == "fuyu":
visual_features = visual_features.squeeze()
input_ids = image["input_ids"].to(torch.int32)
image_patches_indices = image["image_patches_indices"].to(torch.int32)
input_ids = input_ids.expand(self.args.batch_size, *input_ids.shape[1:])
image_patches_indices = image_patches_indices.expand(
self.args.batch_size, *image_patches_indices.shape[1:]
)
input_ids = self.ptuning_setup_fuyu(input_ids, image_patches_indices)
input_ids = torch.stack(input_ids, dim=0).to("cpu")
length = input_ids.shape[1]
elif self.model_type == "kosmos-2":
visual_features = visual_features.squeeze()
else:
pre_input_ids = self.tokenizer(
pre_prompt, return_tensors="pt", padding=True
).input_ids
if post_prompt[0] is not None:
post_input_ids = self.tokenizer(
post_prompt, return_tensors="pt", padding=True
).input_ids
length = (
pre_input_ids.shape[1]
+ post_input_ids.shape[1]
+ visual_atts.shape[1]
)
else:
post_input_ids = None
length = pre_input_ids.shape[1] + visual_atts.shape[1]
input_lengths = torch.IntTensor([length] * args.batch_size).to(torch.int32)
if self.model_type in ["fuyu", "kosmos-2"]:
return input_ids, input_lengths, [visual_features], visual_features
input_ids, ptuning_args = self.setup_fake_prompts(
visual_features, pre_input_ids, post_input_ids, input_lengths
)
return input_ids, input_lengths, ptuning_args, visual_features
def generate(
self,
pre_prompt,
post_prompt,
image,
decoder_input_ids,
max_new_tokens,
attention_mask,
warmup,
):
if not warmup:
profiler.start("Generate")
input_ids, input_lengths, ptuning_args, visual_features = self.preprocess(
warmup, pre_prompt, post_prompt, image, attention_mask
)
if warmup:
return None
profiler.start("LLM")
if self.decoder_llm:
end_id = self.tokenizer.eos_token_id
if "opt" in self.model_type and "blip2" in self.model_type:
# For BLIP2-OPT, model outputs a "\n" at the end.
# we avoid it by using newline as the end token
end_id = self.tokenizer.encode("\n", add_special_tokens=False)[0]
ptuning_args[0] = torch.stack([ptuning_args[0]])
output_ids = self.model.generate(
input_ids,
sampling_config=None,
prompt_table=ptuning_args[0],
max_new_tokens=max_new_tokens,
end_id=end_id,
pad_id=(
self.tokenizer.pad_token_id
if self.tokenizer.pad_token_id is not None
else self.tokenizer.all_special_ids[0]
),
top_k=self.args.top_k,
top_p=self.args.top_p,
temperature=self.args.temperature,
repetition_penalty=self.args.repetition_penalty,
num_beams=self.args.num_beams,
output_sequence_lengths=False,
return_dict=False,
)
else:
if self.model_type in ["nougat", "pix2struct"]:
# Trim encoder input_ids to match visual features shape
ids_shape = (self.args.batch_size, visual_features.shape[1])
if self.model_type == "nougat":
input_ids = torch.zeros(ids_shape, dtype=torch.int32)
elif self.model_type == "pix2struct":
input_ids = torch.ones(ids_shape, dtype=torch.int32)
output_ids = self.model.generate(
input_ids,
decoder_input_ids,
max_new_tokens,
num_beams=self.args.num_beams,
bos_token_id=self.tokenizer.bos_token_id,
pad_token_id=self.tokenizer.pad_token_id,
eos_token_id=self.tokenizer.eos_token_id,
debug_mode=False,
prompt_embedding_table=ptuning_args[0],
prompt_tasks=ptuning_args[1],
prompt_vocab_size=ptuning_args[2],
attention_mask=attention_mask,
)
# Reset input_lengths to match decoder_input_ids
input_lengths = torch.ones(input_lengths.shape, dtype=input_lengths.dtype)
profiler.stop("LLM")
if tensorrt_llm.mpi_rank() == 0:
# Extract a list of tensors of shape beam_width x output_ids.
output_beams_list = [
self.tokenizer.batch_decode(
output_ids[batch_idx, :, input_lengths[batch_idx] :],
skip_special_tokens=True,
)
for batch_idx in range(self.args.batch_size)
]
stripped_text = [
[
output_beams_list[batch_idx][beam_idx].strip()
for beam_idx in range(self.args.num_beams)
]
for batch_idx in range(self.args.batch_size)
]
profiler.stop("Generate")
return stripped_text
else:
profiler.stop("Generate")
return None
def get_visual_features(self, image, attention_mask):
visual_features = {
"input": image.to(
tensorrt_llm._utils.str_dtype_to_torch(self.vision_precision)
)
}
if attention_mask is not None:
visual_features["attention_mask"] = attention_mask
tensor_info = [
TensorInfo("input", str_dtype_to_trt(self.vision_precision), image.shape)
]
if attention_mask is not None:
tensor_info.append(
TensorInfo("attention_mask", trt.DataType.INT32, attention_mask.shape)
)
visual_output_info = self.visual_encoder_session.infer_shapes(tensor_info)
visual_outputs = {
t.name: torch.empty(
tuple(t.shape), dtype=trt_dtype_to_torch(t.dtype), device=image.device
)
for t in visual_output_info
}
ok = self.visual_encoder_session.run(
visual_features, visual_outputs, self.stream.cuda_stream
)
assert ok, "Runtime execution failed for vision encoder session"
self.stream.synchronize()
image_embeds = visual_outputs["output"]
image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(
image.device
)
return image_embeds, image_atts
def setup_fake_prompts(
self, visual_features, pre_input_ids, post_input_ids, input_lengths
):
        # Assemble fake prompt ids that actually point to the image embeddings
fake_prompt_id = torch.arange(
self.model_config.vocab_size,
self.model_config.vocab_size
+ visual_features.shape[0] * visual_features.shape[1],
)
fake_prompt_id = fake_prompt_id.reshape(
visual_features.shape[0], visual_features.shape[1]
)
if "cogvlm" in self.model_type:
input_ids = (
torch.cat(
[pre_input_ids[:, 0:1], fake_prompt_id, pre_input_ids[:, 1:]], dim=1
)
.contiguous()
.to(torch.int32)
)
else:
if post_input_ids is not None:
input_ids = [pre_input_ids, fake_prompt_id, post_input_ids]
else:
input_ids = [fake_prompt_id, pre_input_ids]
input_ids = torch.cat(input_ids, dim=1).contiguous().to(torch.int32)
if self.decoder_llm or self.runtime_mapping.is_first_pp_rank():
ptuning_args = self.ptuning_setup(visual_features, input_ids, input_lengths)
else:
ptuning_args = [None, None, None]
return input_ids, ptuning_args
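    # Illustrative example (assumed numbers) of the fake-prompt trick above: with
    # vocab_size = 32000 and visual_features of shape [1, 576, hidden], fake_prompt_id covers
    # ids 32000..32575. These out-of-vocabulary ids are spliced between pre_input_ids and
    # post_input_ids, and ptuning_setup below maps them to rows of the prompt (image embedding)
    # table rather than to the regular token embedding.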
def ptuning_setup_fuyu(self, input_ids, image_patches_indices):
res_input_ids = []
for cur_input_ids, cur_image_patches_indices in zip(
input_ids, image_patches_indices
):
# Truncate input_ids to the length of image_patches_indices
cur_image_patches_indices = cur_image_patches_indices[: len(cur_input_ids)]
# Get ids of the image_patches
non_zero_mask = cur_image_patches_indices != -1
# Replace input_ids with image_patches_indices values (where the patches are placed)
cur_input_ids = cur_input_ids.masked_scatter(
non_zero_mask,
cur_image_patches_indices[non_zero_mask] + self.model_config.vocab_size,
)
res_input_ids.append(cur_input_ids)
return res_input_ids
def ptuning_setup(self, prompt_table, input_ids, input_lengths):
hidden_size = self.model_config.hidden_size * self.runtime_mapping.tp_size
if prompt_table is not None:
task_vocab_size = torch.tensor(
[prompt_table.shape[1]],
dtype=torch.int32,
).cuda()
prompt_table = prompt_table.view(
(prompt_table.shape[0] * prompt_table.shape[1], prompt_table.shape[2])
)
assert (
prompt_table.shape[1] == hidden_size
), "Prompt table dimensions do not match hidden size"
prompt_table = prompt_table.cuda().to(
dtype=tensorrt_llm._utils.str_dtype_to_torch(self.model_config.dtype)
)
else:
prompt_table = torch.empty([1, hidden_size]).cuda()
task_vocab_size = torch.zeros([1]).cuda()
if self.model_config.remove_input_padding:
tasks = torch.zeros([torch.sum(input_lengths)], dtype=torch.int32).cuda()
if self.decoder_llm:
tasks = tasks.unsqueeze(0)
else:
tasks = torch.zeros(input_ids.shape, dtype=torch.int32).cuda()
return [prompt_table, tasks, task_vocab_size]
def load_test_image(self):
if "vila" in self.model_type:
img_url = "https://github.com/Efficient-Large-Model/VILA/raw/main/demo_images/av.png"
image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
elif "nougat" in self.model_type:
filepath = hf_hub_download(
repo_id="hf-internal-testing/fixtures_docvqa",
filename="nougat_paper.png",
repo_type="dataset",
)
image = Image.open(filepath)
elif "fuyu" in self.model_type:
filepath = hf_hub_download(
repo_id="adept/fuyu-8b", filename="skateboard.png", repo_type="model"
)
image = Image.open(filepath)
elif "kosmos" in self.model_type:
img_url = "https://huggingface.co/microsoft/kosmos-2-patch14-224/resolve/main/snowman.png"
image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
elif "pix2struct" in self.model_type:
img_url = "https://raw.githubusercontent.com/vis-nlp/ChartQA/main/ChartQA%20Dataset/val/png/multi_col_40963.png"
image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
else:
img_url = "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/assets/merlion.png"
image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
return image
def setup_inputs(self, input_text, raw_image):
attention_mask = None
if "blip2" in self.model_type:
processor = Blip2Processor.from_pretrained(self.model_type)
image = processor(raw_image, input_text, return_tensors="pt")[
"pixel_values"
]
if input_text is None:
input_text = "Question: which city is this? Answer:"
pre_prompt = input_text
post_prompt = None
elif "nougat" in self.model_type:
processor = NougatProcessor.from_pretrained(self.args.hf_model_dir)
image = processor(raw_image, return_tensors="pt")["pixel_values"]
            # Nougat doesn't need a text prompt (mBART uses a single token to start generation), so just leave a dummy one here
if input_text is None:
input_text = "Question: which city is this? Answer:"
pre_prompt = input_text
post_prompt = None
elif "cogvlm" in self.model_type:
image_size = 490
dtype = torch.bfloat16
transform = transforms.Compose(
[
transforms.Resize(
(image_size, image_size),
interpolation=transforms.InterpolationMode.BICUBIC,
),
transforms.ToTensor(),
transforms.Normalize(
(0.48145466, 0.4578275, 0.40821073),
(0.26862954, 0.26130258, 0.27577711),
),
]
)
image = transform(raw_image).to(dtype).unsqueeze(0)
if input_text is None:
input_text = " [INST] which city is this? [/INST] "
pre_prompt = input_text
post_prompt = None
elif self.model_type == "pix2struct":
image_processor = AutoProcessor.from_pretrained(args.hf_model_dir)
if input_text is None:
input_text = ""
inputs = image_processor(
images=raw_image,
text=input_text,
return_tensors="pt",
)
image = inputs["flattened_patches"]
image = image.expand(self.args.batch_size, -1, -1).contiguous()
attention_mask = inputs["attention_mask"].to(self.device).to(torch.int)
attention_mask = attention_mask.expand(args.batch_size, -1).contiguous()
pre_prompt = ""
post_prompt = None
elif "neva" in self.model_type:
image_size = 384
dtype = torch.float32
transform = transforms.Compose(
[
transforms.Resize(
(image_size, image_size),
interpolation=transforms.InterpolationMode.BICUBIC,
),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]
)
image = transform(raw_image).to(dtype).unsqueeze(0)
if input_text is None:
input_text = "Hi! What is in this image?"
pre_prompt = "<extra_id_0>System\n\n<extra_id_1>User\n"
post_prompt = f"\n{input_text}\n<extra_id_1>Assistant\n"
elif self.model_type in ["llava", "vila", "fuyu", "kosmos-2", "llava_next"]:
# LLaVA and VILA
if self.model_type == "llava":
pre_prompt = "USER:\n"
if input_text is None:
input_text = "Question: which city is this? Answer:"
elif self.model_type == "llava_next":
pre_prompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: "
elif self.model_type == "vila":
pre_prompt = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: "
if input_text is None:
input_text = "Please describe the traffic condition."
elif self.model_type == "fuyu":
pre_prompt = "Describe this image:"
if input_text is None:
input_text = "Answer the following VQAv2 question based on the image: How many people are in the image?\n"
elif self.model_type == "kosmos-2":
pre_prompt = ""
if input_text is None:
input_text = "<grounding>An image of"
if self.model_type not in ["fuyu", "kosmos-2"]:
post_prompt = input_text + " ASSISTANT:"
else:
post_prompt = None
if self.model_type == "vila":
sys.path.append(self.args.hf_model_dir + "/../VILA")
from llava.model import LlavaLlamaForCausalLM
model = LlavaLlamaForCausalLM.from_pretrained(
self.args.hf_model_dir, torch_dtype=torch.float16
)
vision_tower = model.get_vision_tower()
image_processor = vision_tower.image_processor
image = image_processor(images=raw_image, return_tensors="pt")[
"pixel_values"
]
else:
# processor = AutoProcessor.from_pretrained(
# self.args.hf_model_dir)
# if self.model_type in ['fuyu', 'kosmos-2']:
# image = processor(text=input_text,
# images=raw_image,
# return_tensors='pt')
# else:
# image = processor(text=input_text,
# images=raw_image,
# return_tensors="pt")['pixel_values']
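                # NOTE: image_pre_obj is assumed to be an image-preprocessing
                # object initialized elsewhere in this script; the stock
                # AutoProcessor path above is kept commented out for reference.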
image = image_pre_obj.encode(raw_image).cuda()
# Repeat inputs to match batch size
pre_prompt = [pre_prompt] * self.args.batch_size
post_prompt = [post_prompt] * self.args.batch_size
if self.model_type not in ["fuyu", "pix2struct", "kosmos-2"]:
            image = image.expand(self.args.batch_size, -1, -1, -1).contiguous()
image = image.to(self.device)
# Generate decoder_input_ids for enc-dec models
# Custom prompts can be added as:
# decoder_input_ids = model.tokenizer(decoder_prompt).input_ids
if self.decoder_llm:
decoder_input_ids = None
else:
            config = AutoConfig.from_pretrained(self.args.hf_model_dir)
            decoder_start_id = config.decoder_start_token_id  # T5
            if decoder_start_id is None:
                decoder_start_id = config.decoder.bos_token_id  # Nougat
            decoder_input_ids = torch.IntTensor([[decoder_start_id]])
            decoder_input_ids = decoder_input_ids.repeat((self.args.batch_size, 1))
return (
input_text,
pre_prompt,
post_prompt,
image,
decoder_input_ids,
attention_mask,
)
def run(self, input_text, input_image, max_new_tokens):
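        # Prepare inputs, run one warmup generation, then the real generation;
        # only rank 0 prints the result.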
(
input_text,
pre_prompt,
post_prompt,
processed_image,
decoder_input_ids,
attention_mask,
        ) = self.setup_inputs(input_text, input_image)
        self.generate(
pre_prompt,
post_prompt,
processed_image,
decoder_input_ids,
max_new_tokens,
attention_mask=attention_mask,
warmup=True,
)
        num_iters = self.profiling_iterations if self.args.run_profiling else 1
        output_text = self.generate(
pre_prompt,
post_prompt,
processed_image,
decoder_input_ids,
max_new_tokens,
attention_mask=attention_mask,
warmup=False,
)
# for _ in range(2):
# output_text = model.generate(pre_prompt,
# post_prompt,
# processed_image,
# decoder_input_ids,
# max_new_tokens,
# attention_mask=attention_mask,
# warmup=False)
# from datetime import datetime
# torch.cuda.synchronize()
# a = datetime.now()
# for _ in range(num_iters):
# output_text = model.generate(pre_prompt,
# post_prompt,
# processed_image,
# decoder_input_ids,
# max_new_tokens,
# attention_mask=attention_mask,
# warmup=False)
# torch.cuda.synchronize()
# b = datetime.now()
# print("cost time : ", (b - a).total_seconds() / num_iters)
if self.runtime_rank == 0:
self.print_result(input_text, output_text)
return output_text
def print_result(self, input_text, output_text):
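        # Log the Q/A pair and generated-token count; optionally verify batch
        # consistency and per-model expected answers, and report per-stage
        # latencies when profiling is enabled.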
logger.info("---------------------------------------------------------")
if self.model_type != "nougat":
logger.info(f"\n[Q] {input_text}")
logger.info(f"\n[A] {output_text[0]}")
        if self.args.num_beams == 1:
output_ids = self.tokenizer(output_text[0][0], add_special_tokens=False)[
"input_ids"
]
logger.info(f"Generated {len(output_ids)} tokens")
if self.args.check_accuracy:
for i in range(self.args.batch_size - 1):
                if output_text[i] != output_text[i + 1]:
logger.info(f"Output {i} and {i + 1} do not match")
assert False
if self.model_type != "nougat":
if self.model_type == "vila":
assert (
output_text[0][0].lower()
== "the traffic condition in the image is quite busy, with multiple cars and bicycles sharing the road. there are also pedestrians walking on"
)
elif self.model_type == "fuyu":
assert output_text[0][0].lower() == "4"
elif self.model_type == "pix2struct":
assert (
"characteristic | cat food, day | cat food, wet | cat treats"
in output_text[0][0].lower()
)
elif self.model_type == "neva":
assert "singapore" in output_text[0][0].lower()
elif self.model_type == "kosmos-2":
assert "snowman" in output_text[0][0].lower()
else:
assert output_text[0][0].lower() == "singapore"
if self.args.run_profiling:
msec_per_batch = (
lambda name: 1000
* profiler.elapsed_time_in_sec(name)
/ self.profiling_iterations
)
logger.info("Latencies per batch (msec)")
logger.info("TRT vision encoder: %.1f" % (msec_per_batch("Vision")))
logger.info("TRTLLM LLM generate: %.1f" % (msec_per_batch("LLM")))
logger.info("Multimodal generate: %.1f" % (msec_per_batch("Generate")))
logger.info("---------------------------------------------------------")
if __name__ == "__main__":
os.environ["TOKENIZERS_PARALLELISM"] = "false"
args = parse_arguments()
if args.mode == "caption_zh":
query = "描述这张图片"
elif args.mode == "caption_en":
query = "Please describe the content of this image"
elif args.mode == "insert_content":
assert args.content is not None
query = f"根据提示词“{args.content}”,描述这张图片"
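    # Gloss: "描述这张图片" = "Describe this image";
    # "根据提示词“…”,描述这张图片" = "Describe this image based on the given keyword".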
tensorrt_llm.logger.set_level(args.log_level)
model = MultimodalModelRunner(args)
    if args.input_file is not None:
df = pd.read_csv(args.input_file)
text_zh = []
for i in tqdm.tqdm(range(len(df))):
img_path = df.loc[i]["img_path"]
raw_image = Image.open(img_path)
res = model.run(query, [raw_image], args.max_new_tokens)
text_zh.append(res)
df["text_zh"] = text_zh
df.to_csv(args.output_file, index=False, encoding="utf-8-sig")
else:
raw_image = Image.open(args.image_file)
res = model.run(query, [raw_image], args.max_new_tokens)
print(res)
timm==0.9.5
diffusers==0.21.2
peft==0.10.0
protobuf==3.19.0
transformers==4.39.1
accelerate==0.29.3
loguru==0.7.2
einops==0.7.0
sentencepiece==0.1.99
cuda-python==11.7.1
nvidia-pyindex==1.0.9
pandas==2.0.3
gradio==3.50.2
huggingface_hub==0.25.2
from pathlib import Path
from loguru import logger
from mllm.dialoggen_demo import DialogGen
from hydit.config import get_args
from hydit.inference_controlnet import End2End
from torchvision import transforms as T
import numpy as np
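# ToTensor() scales pixels to [0, 1] and Normalize([0.5], [0.5]) maps them to
# [-1, 1], which is assumed to be the input range expected for the condition image.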
norm_transform = T.Compose(
[
T.ToTensor(),
T.Normalize([0.5], [0.5]),
]
)
from PIL import Image
def inferencer():
args = get_args()
models_root_path = Path(args.model_root)
if not models_root_path.exists():
        raise ValueError(f"`models_root` does not exist: {models_root_path}")
# Load models
gen = End2End(args, models_root_path)
# Try to enhance prompt
if args.enhance:
logger.info("Loading DialogGen model (for prompt enhancement)...")
enhancer = DialogGen(str(models_root_path / "dialoggen"), args.load_4bit)
logger.info("DialogGen model loaded.")
else:
enhancer = None
return args, gen, enhancer
if __name__ == "__main__":
args, gen, enhancer = inferencer()
if enhancer:
logger.info("Prompt Enhancement...")
success, enhanced_prompt = enhancer(args.prompt)
if not success:
logger.info("Sorry, the prompt is not compliant, refuse to draw.")
exit()
logger.info(f"Enhanced prompt: {enhanced_prompt}")
else:
enhanced_prompt = None
# Run inference
logger.info("Generating images...")
height, width = args.image_size
condition = (
Image.open(args.condition_image_path).convert("RGB").resize((width, height))
)
image = norm_transform(condition)
image = image.unsqueeze(0).cuda()
results = gen.predict(
args.prompt,
height=height,
width=width,
image=image,
seed=args.seed,
enhanced_prompt=enhanced_prompt,
negative_prompt=args.negative,
infer_steps=args.infer_steps,
guidance_scale=args.cfg_scale,
batch_size=args.batch_size,
src_size_cond=args.size_cond,
use_style_cond=args.use_style_cond,
)
images = results["images"]
# Save images
save_dir = Path("results")
save_dir.mkdir(exist_ok=True)
# Find the first available index
all_files = list(save_dir.glob("*.png"))
if all_files:
start = max([int(f.stem) for f in all_files]) + 1
else:
start = 0
for idx, pil_img in enumerate(images):
save_path = save_dir / f"{idx + start}.png"
pil_img.save(save_path)
logger.info(f"Save to {save_path}")
from pathlib import Path
from loguru import logger
from mllm.dialoggen_demo import DialogGen
from hydit.config import get_args
from hydit.inference_ipadapter import End2End
from torchvision import transforms as T
import numpy as np
norm_transform = T.Compose(
[
T.ToTensor(),
T.Normalize([0.5], [0.5]),
]
)
from PIL import Image
def inferencer():
args = get_args()
models_root_path = Path(args.model_root)
if not models_root_path.exists():
        raise ValueError(f"`models_root` does not exist: {models_root_path}")
# Load models
gen = End2End(args, models_root_path)
# Try to enhance prompt
if args.enhance:
logger.info("Loading DialogGen model (for prompt enhancement)...")
enhancer = DialogGen(str(models_root_path / "dialoggen"), args.load_4bit)
logger.info("DialogGen model loaded.")
else:
enhancer = None
return args, gen, enhancer
if __name__ == "__main__":
args, gen, enhancer = inferencer()
if enhancer:
logger.info("Prompt Enhancement...")
success, enhanced_prompt = enhancer(args.prompt)
if not success:
logger.info("Sorry, the prompt is not compliant, refuse to draw.")
exit()
logger.info(f"Enhanced prompt: {enhanced_prompt}")
else:
enhanced_prompt = None
    # Run inference
logger.info("Generating images...")
height, width = args.image_size
ref_image = Image.open(args.ref_image_path).convert("RGB")
i_scale = args.i_scale
results = gen.predict(
args.prompt,
height=height,
width=width,
image=ref_image,
i_scale=i_scale,
t_scale=1,
seed=3333,
enhanced_prompt=enhanced_prompt,
negative_prompt=args.negative,
infer_steps=args.infer_steps,
guidance_scale=3,
batch_size=args.batch_size,
src_size_cond=args.size_cond,
)
images = results["images"]
save_dir = Path("results")
save_dir.mkdir(exist_ok=True)
# Find the first available index
all_files = list(save_dir.glob("*.png"))
if all_files:
start = max([int(f.stem) for f in all_files]) + 1
else:
start = 0
for idx, pil_img in enumerate(images):
save_path = save_dir / f"{idx + start}.png"
pil_img.save(save_path)
logger.info(f"Save to {save_path}")
from pathlib import Path
from loguru import logger
from mllm.dialoggen_demo import DialogGen
from hydit.config import get_args
from hydit.inference import End2End
def inferencer():
args = get_args()
models_root_path = Path(args.model_root)
if not models_root_path.exists():
        raise ValueError(f"`models_root` does not exist: {models_root_path}")
# Load models
gen = End2End(args, models_root_path)
# Try to enhance prompt
if args.enhance:
logger.info("Loading DialogGen model (for prompt enhancement)...")
enhancer = DialogGen(str(models_root_path / "dialoggen"), args.load_4bit)
logger.info("DialogGen model loaded.")
else:
enhancer = None
return args, gen, enhancer
if __name__ == "__main__":
args, gen, enhancer = inferencer()
if enhancer:
logger.info("Prompt Enhancement...")
success, enhanced_prompt = enhancer(args.prompt)
if not success:
logger.info("Sorry, the prompt is not compliant, refuse to draw.")
exit()
logger.info(f"Enhanced prompt: {enhanced_prompt}")
else:
enhanced_prompt = None
# Run inference
logger.info("Generating images...")
height, width = args.image_size
results = gen.predict(
args.prompt,
height=height,
width=width,
seed=args.seed,
enhanced_prompt=enhanced_prompt,
negative_prompt=args.negative,
infer_steps=args.infer_steps,
guidance_scale=args.cfg_scale,
batch_size=args.batch_size,
src_size_cond=args.size_cond,
use_style_cond=args.use_style_cond,
)
images = results["images"]
# Save images
save_dir = Path("results")
save_dir.mkdir(exist_ok=True)
# Find the first available index
all_files = list(save_dir.glob("*.png"))
if all_files:
start = max([int(f.stem) for f in all_files]) + 1
else:
start = 0
for idx, pil_img in enumerate(images):
save_path = save_dir / f"{idx + start}.png"
pil_img.save(save_path)
logger.info(f"Save to {save_path}")
#!/bin/bash
test_base=./tests  # Test directory
export CUDA_VISIBLE_DEVICES=3  # GPU to use
for file in $(find "$test_base" -maxdepth 1 -name 'test_*.sh'); do
    # Strip the leading './' from the path to get the file name
filename=$(basename "$file")
echo "################################"
echo "Running tests in $filename..."
bash "$file"
echo "################################"
done
#!/bin/bash
task_name="infer_controlnet_canny"
log_file="${task_name}.log"
python sample_controlnet.py --infer-mode torch --no-enhance --load-key distill --infer-steps 50 --control-type canny --prompt "在夜晚的酒店门前,一座古老的中国风格的狮子雕像矗立着,它的眼睛闪烁着光芒,仿佛在守护着这座建筑。背景是夜晚的酒店前,构图方式是特写,平视,居中构图。这张照片呈现了真实摄影风格,蕴含了中国雕塑文化,同时展现了神秘氛围" --condition-image-path controlnet/asset/input/canny.jpg --control-weight 1.0 > "$log_file" 2>&1
exit_status=$?
if [ $exit_status -eq 0 ]; then
echo -e "\033[0;32m$task_name Passed\033[0m"
else
echo -e "\033[0;31m$task_name Failed\033[0m"
fi
###
task_name="infer_controlnet_depth"
log_file="${task_name}.log"
python sample_controlnet.py --infer-mode torch --no-enhance --load-key distill --infer-steps 50 --control-type depth --prompt "在夜晚的酒店门前,一座古老的中国风格的狮子雕像矗立着,它的眼睛闪烁着光芒,仿佛在守护着这座建筑。背景是夜晚的酒店前,构图方式是特写,平视,居中构图。这张照片呈现了真实摄影风格,蕴含了中国雕塑文化,同时展现了神秘氛围" --condition-image-path controlnet/asset/input/depth.jpg --control-weight 1.0 > "$log_file" 2>&1
exit_status=$?
if [ $exit_status -eq 0 ]; then
echo -e "\033[0;32m$task_name Passed\033[0m"
else
echo -e "\033[0;31m$task_name Failed\033[0m"
fi
###
task_name="infer_controlnet_pose"
log_file="${task_name}.log"
python sample_controlnet.py --infer-mode torch --no-enhance --load-key distill --infer-steps 50 --control-type pose --prompt "在夜晚的酒店门前,一座古老的中国风格的狮子雕像矗立着,它的眼睛闪烁着光芒,仿佛在守护着这座建筑。背景是夜晚的酒店前,构图方式是特写,平视,居中构图。这张照片呈现了真实摄影风格,蕴含了中国雕塑文化,同时展现了神秘氛围" --condition-image-path controlnet/asset/input/pose.jpg --control-weight 1.0 > "$log_file" 2>&1
exit_status=$?
if [ $exit_status -eq 0 ]; then
echo -e "\033[0;32m$task_name Passed\033[0m"
else
echo -e "\033[0;31m$task_name Failed\033[0m"
fi
#!/bin/bash
task_name="infer_ipadapter.sh"
log_file="${task_name}.log"
python3 sample_ipadapter.py --infer-mode torch --ref-image-path ipadapter/asset/input/tiger.png --i-scale 1.0 --prompt "一只老虎在海洋中游泳,背景是海洋。构图方式是居中构图,呈现了动漫风格和文化,营造了平静的氛围。" --infer-steps 30 --is-ipa True --load-key distill > "$log_file" 2>&1
exit_status=$?
if [ $exit_status -eq 0 ]; then
echo -e "\033[0;32m$task_name Passed\033[0m"
else
echo -e "\033[0;31m$task_name Failed\033[0m"
fi
#!/bin/bash
task_name="infer_text2img_flash_attn"
log_file="${task_name}.log"
python sample_t2i.py --infer-mode fa --infer-steps 30 --prompt "青花瓷风格,一只可爱的哈士奇" --no-enhance --load-key distill > "$log_file" 2>&1
exit_status=$?
if [ $exit_status -eq 0 ]; then
echo -e "\033[0;32m$task_name Passed\033[0m"
else
echo -e "\033[0;31m$task_name Failed\033[0m"
fi
###
task_name="infer_text2img_raw_attn"
log_file="${task_name}.log"
python sample_t2i.py --infer-mode torch --infer-steps 30 --prompt "青花瓷风格,一只可爱的哈士奇" --no-enhance --load-key distill > "$log_file" 2>&1
exit_status=$?
if [ $exit_status -eq 0 ]; then
echo -e "\033[0;32m$task_name Passed\033[0m"
else
echo -e "\033[0;31m$task_name Failed\033[0m"
fi
#!/bin/bash
# ==============================================================================
# Description: Export ONNX model and build TensorRT engine.
# ==============================================================================
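# Usage (assumed invocation): bash <this script> [1.0|1.1|1.2]
# The version argument is optional and defaults to 1.2.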
# Check Hydit Version.
if [ -z "$1" ]; then
HYDIT_VERSION=1.2
elif [ "$1" == "1.0" ]; then
HYDIT_VERSION=1.0
elif [ "$1" == "1.1" ]; then
HYDIT_VERSION=1.1
elif [ "$1" == "1.2" ]; then
HYDIT_VERSION=1.2
else
echo "Failed. Hydit Only Has Version: 1.0, 1.1, 1.2!"
exit 1
fi
echo "Hydit Version: "${HYDIT_VERSION}
export MODEL_ROOT=ckpts
export ONNX_WORKDIR=${MODEL_ROOT}/onnx_model
echo "MODEL_ROOT=${MODEL_ROOT}"
echo "ONNX_WORKDIR=${ONNX_WORKDIR}"
# Remove old directories.
if [ -d "${ONNX_WORKDIR}" ]; then
echo "Remove old ONNX directories..."
rm -r ${ONNX_WORKDIR}
fi
# Inspect the project directory.
SCRIPT_PATH="$( cd "$( dirname "$0" )" && pwd )"
PROJECT_DIR=$(dirname "$SCRIPT_PATH")
export PYTHONPATH=${PROJECT_DIR}:${PYTHONPATH}
echo "PYTHONPATH=${PYTHONPATH}"
cd ${PROJECT_DIR}
echo "Change directory to ${PROJECT_DIR}"
# ----------------------------------------
# 1. Export ONNX model.
# ----------------------------------------
# Sleep for reading the message.
sleep 2s
echo "Exporting ONNX model..."
if [ ${HYDIT_VERSION} == "1.2" ]; then
echo "Export ONNX for Hydit Version 1.2"
python trt/export_onnx.py --model-root ${MODEL_ROOT} --onnx-workdir ${ONNX_WORKDIR} --infer-mode torch
elif [ ${HYDIT_VERSION} == "1.1" ]; then
echo "Export ONNX for Hydit Version 1.1"
python trt/export_onnx.py --model-root ./HunyuanDiT-v1.1 --onnx-workdir ${ONNX_WORKDIR} --infer-mode torch --use-style-cond --size-cond 1024 1024 --beta-end 0.03
elif [ ${HYDIT_VERSION} == "1.0" ]; then
echo "Export ONNX for Hydit Version 1.0"
python trt/export_onnx.py --model-root ./HunyuanDiT-v1.0 --onnx-workdir ${ONNX_WORKDIR} --infer-mode torch --use-style-cond --size-cond 1024 1024 --beta-end 0.03
fi
echo "Exporting ONNX model finished"
# ----------------------------------------
# 2. Build TensorRT engine.
# ----------------------------------------
echo "Building TensorRT engine..."
ENGINE_DIR="${MODEL_ROOT}/t2i/model_trt/engine"
mkdir -p ${ENGINE_DIR}
ENGINE_PATH=${ENGINE_DIR}/model_onnx.plan
PLUGIN_PATH=${MODEL_ROOT}/t2i/model_trt/fmha_plugins/10.1_plugin_cuda11/fMHAPlugin.so
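# The min/opt/max shapes below enable dynamic latent resolutions: x is 2x4xHxW
# with H and W between 90 and 160 (optimum 128x128), and cos_cis_img/sin_cis_img
# carry (H/2)*(W/2) rows to match (2025, 4096 and 6400 respectively).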
if [ ${HYDIT_VERSION} == "1.2" ]; then
trtexec \
--onnx=${ONNX_WORKDIR}/export_modified_fmha/model.onnx \
--fp16 \
--saveEngine=${ENGINE_PATH} \
--minShapes=x:2x4x90x90,t:2,encoder_hidden_states:2x77x1024,text_embedding_mask:2x77,encoder_hidden_states_t5:2x256x2048,text_embedding_mask_t5:2x256,cos_cis_img:2025x88,sin_cis_img:2025x88 \
--optShapes=x:2x4x128x128,t:2,encoder_hidden_states:2x77x1024,text_embedding_mask:2x77,encoder_hidden_states_t5:2x256x2048,text_embedding_mask_t5:2x256,cos_cis_img:4096x88,sin_cis_img:4096x88 \
--maxShapes=x:2x4x160x160,t:2,encoder_hidden_states:2x77x1024,text_embedding_mask:2x77,encoder_hidden_states_t5:2x256x2048,text_embedding_mask_t5:2x256,cos_cis_img:6400x88,sin_cis_img:6400x88 \
--shapes=x:2x4x128x128,t:2,encoder_hidden_states:2x77x1024,text_embedding_mask:2x77,encoder_hidden_states_t5:2x256x2048,text_embedding_mask_t5:2x256,cos_cis_img:4096x88,sin_cis_img:4096x88 \
--verbose \
--staticPlugins=${PLUGIN_PATH} \
--stronglyTyped
else
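    # Hydit 1.0/1.1 engines additionally take the image_meta_size and style
    # inputs used for the size/style conditioning in those versions.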
trtexec \
--onnx=${ONNX_WORKDIR}/export_modified_fmha/model.onnx \
--fp16 \
--saveEngine=${ENGINE_PATH} \
--minShapes=x:2x4x90x90,t:2,encoder_hidden_states:2x77x1024,text_embedding_mask:2x77,encoder_hidden_states_t5:2x256x2048,text_embedding_mask_t5:2x256,image_meta_size:2x6,style:2,cos_cis_img:2025x88,sin_cis_img:2025x88 \
--optShapes=x:2x4x128x128,t:2,encoder_hidden_states:2x77x1024,text_embedding_mask:2x77,encoder_hidden_states_t5:2x256x2048,text_embedding_mask_t5:2x256,image_meta_size:2x6,style:2,cos_cis_img:4096x88,sin_cis_img:4096x88 \
--maxShapes=x:2x4x160x160,t:2,encoder_hidden_states:2x77x1024,text_embedding_mask:2x77,encoder_hidden_states_t5:2x256x2048,text_embedding_mask_t5:2x256,image_meta_size:2x6,style:2,cos_cis_img:6400x88,sin_cis_img:6400x88 \
--shapes=x:2x4x128x128,t:2,encoder_hidden_states:2x77x1024,text_embedding_mask:2x77,encoder_hidden_states_t5:2x256x2048,text_embedding_mask_t5:2x256,image_meta_size:2x6,style:2,cos_cis_img:4096x88,sin_cis_img:4096x88 \
--verbose \
--builderOptimizationLevel=4 \
--staticPlugins=${PLUGIN_PATH} \
--stronglyTyped
fi