Unverified Commit 69c2f650 authored by Yang Yong (雍洋), committed by GitHub

Remove outdated models (#348)

parent 08d2f46a
#
# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import argparse
import os
import tensorrt as trt
from loguru import logger
from .common_runtime import *
try:
    # FileNotFoundError does not exist on Python 2, so fall back to IOError there.
FileNotFoundError
except NameError:
FileNotFoundError = IOError
def GiB(val):
return val * 1 << 30
def add_help(description):
parser = argparse.ArgumentParser(description=description, formatter_class=argparse.ArgumentDefaultsHelpFormatter)
args, _ = parser.parse_known_args()
def find_sample_data(description="Runs a TensorRT Python sample", subfolder="", find_files=[], err_msg=""):
"""
Parses sample arguments.
Args:
        description (str): Description of the sample.
        subfolder (str): The subfolder containing data relevant to this sample.
        find_files (List[str]): A list of filenames to find. Each filename will be replaced with an absolute path.
        err_msg (str): Message appended to the error raised when a file cannot be located.
    Returns:
        Tuple[List[str], List[str]]: The data directory paths and the absolute paths of the located files.
"""
# Standard command-line arguments for all samples.
kDEFAULT_DATA_ROOT = os.path.join(os.sep, "usr", "src", "tensorrt", "data")
parser = argparse.ArgumentParser(description=description, formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument(
"-d",
"--datadir",
help="Location of the TensorRT sample data directory, and any additional data directories.",
action="append",
default=[kDEFAULT_DATA_ROOT],
)
args, _ = parser.parse_known_args()
def get_data_path(data_dir):
# If the subfolder exists, append it to the path, otherwise use the provided path as-is.
data_path = os.path.join(data_dir, subfolder)
if not os.path.exists(data_path):
if data_dir != kDEFAULT_DATA_ROOT:
                logger.warning("{:} does not exist. Trying {:} instead.".format(data_path, data_dir))
data_path = data_dir
# Make sure data directory exists.
if not (os.path.exists(data_path)) and data_dir != kDEFAULT_DATA_ROOT:
            logger.warning("{:} does not exist. Please provide the correct data path with the -d option.".format(data_path))
return data_path
data_paths = [get_data_path(data_dir) for data_dir in args.datadir]
return data_paths, locate_files(data_paths, find_files, err_msg)
def locate_files(data_paths, filenames, err_msg=""):
"""
Locates the specified files in the specified data directories.
If a file exists in multiple data directories, the first directory is used.
Args:
data_paths (List[str]): The data directories.
        filenames (List[str]): The names of the files to find.
Returns:
List[str]: The absolute paths of the files.
Raises:
FileNotFoundError if a file could not be located.
"""
found_files = [None] * len(filenames)
for data_path in data_paths:
# Find all requested files.
for index, (found, filename) in enumerate(zip(found_files, filenames)):
if not found:
file_path = os.path.abspath(os.path.join(data_path, filename))
if os.path.exists(file_path):
found_files[index] = file_path
# Check that all files were found
for f, filename in zip(found_files, filenames):
if not f or not os.path.exists(f):
raise FileNotFoundError("Could not find {:}. Searched in data paths: {:}\n{:}".format(filename, data_paths, err_msg))
return found_files
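# Illustrative usage sketch for the two helpers above. The sample name, subfolder,
# and file list ("input.ppm") are hypothetical; directories passed with -d are
# searched in addition to the default /usr/src/tensorrt/data root.
def _example_find_sample_data():
    data_paths, (input_path,) = find_sample_data(
        description="Runs a TensorRT Python sample",
        subfolder="resnet50",
        find_files=["input.ppm"],
        err_msg="Pass -d with the directory that contains input.ppm.",
    )
    return data_paths, input_path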
# Sets up the builder to use the timing cache file, and creates it if it does not already exist
def setup_timing_cache(config: trt.IBuilderConfig, timing_cache_path: os.PathLike):
buffer = b""
if os.path.exists(timing_cache_path):
with open(timing_cache_path, mode="rb") as timing_cache_file:
buffer = timing_cache_file.read()
timing_cache: trt.ITimingCache = config.create_timing_cache(buffer)
config.set_timing_cache(timing_cache, True)
# Saves the config's timing cache to file
def save_timing_cache(config: trt.IBuilderConfig, timing_cache_path: os.PathLike):
timing_cache: trt.ITimingCache = config.get_timing_cache()
with open(timing_cache_path, "wb") as timing_cache_file:
timing_cache_file.write(memoryview(timing_cache.serialize()))
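# A minimal sketch of how the two timing-cache helpers above are typically used
# around an engine build. The ONNX path and cache filename are placeholders.
def _example_build_with_timing_cache(onnx_path="model.onnx", cache_path="timing.cache"):
    trt_logger = trt.Logger(trt.Logger.WARNING)
    builder = trt.Builder(trt_logger)
    network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
    parser = trt.OnnxParser(network, trt_logger)
    if not parser.parse_from_file(onnx_path):
        raise RuntimeError(f"Failed to parse {onnx_path}")
    config = builder.create_builder_config()
    setup_timing_cache(config, cache_path)  # reuse tactic timings from previous builds, if any
    engine_bytes = builder.build_serialized_network(network, config)
    save_timing_cache(config, cache_path)  # persist newly measured tactics for the next build
    return engine_bytes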
#
# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import ctypes
from typing import List, Optional, Union
import numpy as np
import tensorrt as trt
from cuda import cuda, cudart
def check_cuda_err(err):
if isinstance(err, cuda.CUresult):
if err != cuda.CUresult.CUDA_SUCCESS:
raise RuntimeError("Cuda Error: {}".format(err))
    elif isinstance(err, cudart.cudaError_t):
        if err != cudart.cudaError_t.cudaSuccess:
            raise RuntimeError("Cuda Runtime Error: {}".format(err))
    else:
        raise RuntimeError("Unknown error type: {}".format(err))
def cuda_call(call):
err, res = call[0], call[1:]
check_cuda_err(err)
if len(res) == 1:
res = res[0]
return res
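# Small illustration of the calling convention handled by cuda_call(): cuda-python
# returns an (error, *results) tuple from every call; cuda_call() checks the error
# and unwraps a single result. Assumes a CUDA device is available.
def _example_cuda_call():
    dev_ptr = cuda_call(cudart.cudaMalloc(1024))  # single result -> unwrapped device pointer
    cuda_call(cudart.cudaMemset(dev_ptr, 0, 1024))  # no result beyond the error code
    cuda_call(cudart.cudaFree(dev_ptr))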
class HostDeviceMem:
"""Pair of host and device memory, where the host memory is wrapped in a numpy array"""
def __init__(self, size: int, dtype: Optional[np.dtype] = None):
dtype = dtype or np.dtype(np.uint8)
nbytes = size * dtype.itemsize
host_mem = cuda_call(cudart.cudaMallocHost(nbytes))
pointer_type = ctypes.POINTER(np.ctypeslib.as_ctypes_type(dtype))
self._host = np.ctypeslib.as_array(ctypes.cast(host_mem, pointer_type), (size,))
self._device = cuda_call(cudart.cudaMalloc(nbytes))
self._nbytes = nbytes
@property
def host(self) -> np.ndarray:
return self._host
@host.setter
def host(self, data: Union[np.ndarray, bytes]):
if isinstance(data, np.ndarray):
if data.size > self.host.size:
raise ValueError(f"Tried to fit an array of size {data.size} into host memory of size {self.host.size}")
np.copyto(self.host[: data.size], data.flat, casting="safe")
else:
assert self.host.dtype == np.uint8
self.host[: self.nbytes] = np.frombuffer(data, dtype=np.uint8)
@property
def device(self) -> int:
return self._device
@property
def nbytes(self) -> int:
return self._nbytes
def __str__(self):
return f"Host:\n{self.host}\nDevice:\n{self.device}\nSize:\n{self.nbytes}\n"
def __repr__(self):
return self.__str__()
def free(self):
cuda_call(cudart.cudaFree(self.device))
cuda_call(cudart.cudaFreeHost(self.host.ctypes.data))
# Allocates all buffers required for an engine, i.e. host/device inputs/outputs.
# If engine uses dynamic shapes, specify a profile to find the maximum input & output size.
def allocate_buffers(engine: trt.ICudaEngine, profile_idx: Optional[int] = None):
inputs = []
outputs = []
bindings = []
stream = cuda_call(cudart.cudaStreamCreate())
tensor_names = [engine.get_tensor_name(i) for i in range(engine.num_io_tensors)]
for binding in tensor_names:
# get_tensor_profile_shape returns (min_shape, optimal_shape, max_shape)
# Pick out the max shape to allocate enough memory for the binding.
shape = engine.get_tensor_shape(binding) if profile_idx is None else engine.get_tensor_profile_shape(binding, profile_idx)[-1]
shape_valid = np.all([s >= 0 for s in shape])
if not shape_valid and profile_idx is None:
raise ValueError(f"Binding {binding} has dynamic shape, " + "but no profile was specified.")
size = trt.volume(shape)
trt_type = engine.get_tensor_dtype(binding)
# Allocate host and device buffers
try:
dtype = np.dtype(trt.nptype(trt_type))
bindingMemory = HostDeviceMem(size, dtype)
except TypeError: # no numpy support: create a byte array instead (BF16, FP8, INT4)
size = int(size * trt_type.itemsize)
bindingMemory = HostDeviceMem(size)
# Append the device buffer to device bindings.
bindings.append(int(bindingMemory.device))
# Append to the appropriate list.
if engine.get_tensor_mode(binding) == trt.TensorIOMode.INPUT:
inputs.append(bindingMemory)
else:
outputs.append(bindingMemory)
return inputs, outputs, bindings, stream
# Frees the resources allocated in allocate_buffers
def free_buffers(inputs: List[HostDeviceMem], outputs: List[HostDeviceMem], stream: cudart.cudaStream_t):
for mem in inputs + outputs:
mem.free()
cuda_call(cudart.cudaStreamDestroy(stream))
# Wrapper for cudaMemcpy which infers copy size and does error checking
def memcpy_host_to_device(device_ptr: int, host_arr: np.ndarray):
nbytes = host_arr.size * host_arr.itemsize
cuda_call(cudart.cudaMemcpy(device_ptr, host_arr, nbytes, cudart.cudaMemcpyKind.cudaMemcpyHostToDevice))
# Wrapper for cudaMemcpy which infers copy size and does error checking
def memcpy_device_to_host(host_arr: np.ndarray, device_ptr: int):
nbytes = host_arr.size * host_arr.itemsize
cuda_call(cudart.cudaMemcpy(host_arr, device_ptr, nbytes, cudart.cudaMemcpyKind.cudaMemcpyDeviceToHost))
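# Sketch of using HostDeviceMem with the memcpy wrappers above: stage data in
# pinned host memory, copy it to the device, copy it back, then release both buffers.
def _example_host_device_roundtrip():
    mem = HostDeviceMem(1024, np.dtype(np.float32))
    mem.host = np.arange(1024, dtype=np.float32)
    memcpy_host_to_device(mem.device, mem.host)
    memcpy_device_to_host(mem.host, mem.device)
    result = mem.host.copy()
    mem.free()
    return result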
def _do_inference_base(inputs, outputs, stream, execute_async_func):
# Transfer input data to the GPU.
kind = cudart.cudaMemcpyKind.cudaMemcpyHostToDevice
[cuda_call(cudart.cudaMemcpyAsync(inp.device, inp.host, inp.nbytes, kind, stream)) for inp in inputs]
# Run inference.
execute_async_func()
# Transfer predictions back from the GPU.
kind = cudart.cudaMemcpyKind.cudaMemcpyDeviceToHost
[cuda_call(cudart.cudaMemcpyAsync(out.host, out.device, out.nbytes, kind, stream)) for out in outputs]
# Synchronize the stream
cuda_call(cudart.cudaStreamSynchronize(stream))
# Return only the host outputs.
return [out.host for out in outputs]
# This function is generalized for multiple inputs/outputs.
# inputs and outputs are expected to be lists of HostDeviceMem objects.
def do_inference(context, engine, bindings, inputs, outputs, stream):
def execute_async_func():
context.execute_async_v3(stream_handle=stream)
# Setup context tensor address.
num_io = engine.num_io_tensors
for i in range(num_io):
context.set_tensor_address(engine.get_tensor_name(i), bindings[i])
return _do_inference_base(inputs, outputs, stream, execute_async_func)
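# End-to-end sketch tying the helpers together for a static-shape engine. The
# engine path is a placeholder; input shapes and dtypes depend on the actual engine.
def _example_run_engine(engine_path="model.engine"):
    trt_logger = trt.Logger(trt.Logger.WARNING)
    runtime = trt.Runtime(trt_logger)
    with open(engine_path, "rb") as f:
        engine = runtime.deserialize_cuda_engine(f.read())
    context = engine.create_execution_context()
    inputs, outputs, bindings, stream = allocate_buffers(engine)
    for inp in inputs:
        inp.host = np.random.random(inp.host.shape).astype(inp.host.dtype)  # dummy input data
    results = do_inference(context, engine, bindings, inputs, outputs, stream)
    free_buffers(inputs, outputs, stream)
    return results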
@@ -5,16 +5,12 @@ import torch.distributed as dist
 from loguru import logger
 from lightx2v.common.ops import *
-from lightx2v.models.runners.cogvideox.cogvidex_runner import CogvideoxRunner  # noqa: F401
-from lightx2v.models.runners.hunyuan.hunyuan_runner import HunyuanRunner  # noqa: F401
 from lightx2v.models.runners.qwen_image.qwen_image_runner import QwenImageRunner  # noqa: F401
 from lightx2v.models.runners.wan.wan_animate_runner import WanAnimateRunner  # noqa: F401
 from lightx2v.models.runners.wan.wan_audio_runner import Wan22AudioRunner, WanAudioRunner  # noqa: F401
-from lightx2v.models.runners.wan.wan_causvid_runner import WanCausVidRunner  # noqa: F401
 from lightx2v.models.runners.wan.wan_distill_runner import WanDistillRunner  # noqa: F401
 from lightx2v.models.runners.wan.wan_runner import Wan22MoeRunner, WanRunner  # noqa: F401
 from lightx2v.models.runners.wan.wan_sf_runner import WanSFRunner  # noqa: F401
-from lightx2v.models.runners.wan.wan_skyreels_v2_df_runner import WanSkyreelsV2DFRunner  # noqa: F401
 from lightx2v.models.runners.wan.wan_vace_runner import WanVaceRunner  # noqa: F401
 from lightx2v.utils.envs import *
 from lightx2v.utils.input_info import set_input_info
@@ -40,13 +36,9 @@ def main():
         required=True,
         choices=[
             "wan2.1",
-            "hunyuan",
             "wan2.1_distill",
-            "wan2.1_causvid",
-            "wan2.1_skyreels_v2_df",
             "wan2.1_vace",
             "wan2.1_sf",
-            "cogvideox",
             "seko_talk",
             "wan2.2_moe",
             "wan2.2",
import torch
from loguru import logger
from transformers import AutoTokenizer, CLIPTextModel
class TextEncoderHFClipModel:
def __init__(self, model_path, device):
self.device = device
self.model_path = model_path
self.init()
self.load()
def init(self):
self.max_length = 77
def load(self):
self.model = CLIPTextModel.from_pretrained(self.model_path).to(torch.float16).to(self.device)
self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, padding_side="right")
def to_cpu(self):
self.model = self.model.to("cpu")
def to_cuda(self):
self.model = self.model.to("cuda")
@torch.no_grad()
def infer(self, text, config):
if config.cpu_offload:
self.to_cuda()
tokens = self.tokenizer(
text,
return_length=False,
return_overflowing_tokens=False,
return_attention_mask=True,
truncation=True,
max_length=self.max_length,
padding="max_length",
return_tensors="pt",
).to("cuda")
outputs = self.model(
input_ids=tokens["input_ids"],
attention_mask=tokens["attention_mask"],
output_hidden_states=False,
)
last_hidden_state = outputs["pooler_output"]
if config.cpu_offload:
self.to_cpu()
return last_hidden_state, tokens["attention_mask"]
if __name__ == "__main__":
    from types import SimpleNamespace

    model_path = ""
    model = TextEncoderHFClipModel(model_path, torch.device("cuda"))
    text = "A cat walks on the grass, realistic style."
    # infer() needs a config object that exposes `cpu_offload`
    outputs = model.infer(text, SimpleNamespace(cpu_offload=False))
    logger.info(outputs)
import torch
from loguru import logger
from transformers import AutoModel, AutoTokenizer
class TextEncoderHFLlamaModel:
def __init__(self, model_path, device):
self.device = device
self.model_path = model_path
self.init()
self.load()
def init(self):
self.max_length = 351
self.hidden_state_skip_layer = 2
self.crop_start = 95
self.prompt_template = (
"<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: "
"1. The main content and theme of the video."
"2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects."
"3. Actions, events, behaviors temporal relationships, physical movement changes of the objects."
"4. background environment, light, style and atmosphere."
"5. camera angles, movements, and transitions used in the video:<|eot_id|>"
"<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
)
def load(self):
self.model = AutoModel.from_pretrained(self.model_path, low_cpu_mem_usage=True).to(torch.float16).to(self.device)
self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, padding_side="right")
def to_cpu(self):
self.model = self.model.to("cpu")
def to_cuda(self):
self.model = self.model.to("cuda")
@torch.no_grad()
def infer(self, text, config):
if config.cpu_offload:
self.to_cuda()
text = self.prompt_template.format(text)
tokens = self.tokenizer(
text,
return_length=False,
return_overflowing_tokens=False,
return_attention_mask=True,
truncation=True,
max_length=self.max_length,
padding="max_length",
return_tensors="pt",
).to("cuda")
outputs = self.model(
input_ids=tokens["input_ids"],
attention_mask=tokens["attention_mask"],
output_hidden_states=True,
)
last_hidden_state = outputs.hidden_states[-(self.hidden_state_skip_layer + 1)][:, self.crop_start :]
attention_mask = tokens["attention_mask"][:, self.crop_start :]
if config.cpu_offload:
self.to_cpu()
return last_hidden_state, attention_mask
if __name__ == "__main__":
    from types import SimpleNamespace

    model_path = ""
    model = TextEncoderHFLlamaModel(model_path, torch.device("cuda"))
    text = "A cat walks on the grass, realistic style."
    # infer() needs a config object that exposes `cpu_offload`
    outputs = model.infer(text, SimpleNamespace(cpu_offload=False))
    logger.info(outputs)
import torch
from transformers import AutoTokenizer, CLIPImageProcessor, LlavaForConditionalGeneration
def generate_crop_size_list(base_size=256, patch_size=32, max_ratio=4.0):
"""generate crop size list
Args:
        base_size (int, optional): the base size used to generate the buckets. Defaults to 256.
        patch_size (int, optional): the stride used to generate the buckets. Defaults to 32.
        max_ratio (float, optional): the max allowed ratio between h and w, relative to base_size. Defaults to 4.0.
    Returns:
        list: the generated list of crop sizes
"""
num_patches = round((base_size / patch_size) ** 2)
assert max_ratio >= 1.0
crop_size_list = []
wp, hp = num_patches, 1
while wp > 0:
if max(wp, hp) / min(wp, hp) <= max_ratio:
crop_size_list.append((wp * patch_size, hp * patch_size))
if (hp + 1) * wp <= num_patches:
hp += 1
else:
wp -= 1
return crop_size_list
def get_closest_ratio(height: float, width: float, ratios: list, buckets: list):
"""get the closest ratio in the buckets
Args:
        height (float): video height
        width (float): video width
        ratios (list): aspect ratios of the buckets
        buckets (list): buckets generated by `generate_crop_size_list`
    Returns:
        the closest bucket size and its corresponding aspect ratio
"""
aspect_ratio = float(height) / float(width)
diff_ratios = ratios - aspect_ratio
if aspect_ratio >= 1:
indices = [(index, x) for index, x in enumerate(diff_ratios) if x <= 0]
else:
indices = [(index, x) for index, x in enumerate(diff_ratios) if x > 0]
closest_ratio_id = min(indices, key=lambda pair: abs(pair[1]))[0]
closest_size = buckets[closest_ratio_id]
closest_ratio = ratios[closest_ratio_id]
return closest_size, closest_ratio
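# Sketch of how the two helpers above combine: build the bucket list once, then pick
# the bucket whose aspect ratio is closest to an input frame. get_closest_ratio()
# compares against height / width, so `ratios` must be in h/w units; treating each
# bucket as (width, height) is an assumption about the caller's convention.
def _example_pick_bucket(height=720, width=1280):
    import numpy as np

    buckets = generate_crop_size_list(base_size=256, patch_size=32)
    ratios = np.array([b[1] / b[0] for b in buckets])  # assumed (w, h) buckets -> h / w ratios
    closest_size, closest_ratio = get_closest_ratio(height, width, ratios, buckets)
    return closest_size, closest_ratio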
class TextEncoderHFLlavaModel:
def __init__(self, model_path, device):
self.device = device
self.model_path = model_path
self.init()
self.load()
def init(self):
self.max_length = 359
self.hidden_state_skip_layer = 2
self.crop_start = 103
self.double_return_token_id = 271
self.image_emb_len = 576
self.text_crop_start = self.crop_start - 1 + self.image_emb_len
self.image_crop_start = 5
self.image_crop_end = 581
self.image_embed_interleave = 4
self.prompt_template = (
"<|start_header_id|>system<|end_header_id|>\n\n<image>\nDescribe the video by detailing the following aspects according to the reference image: "
"1. The main content and theme of the video."
"2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects."
"3. Actions, events, behaviors temporal relationships, physical movement changes of the objects."
"4. background environment, light, style and atmosphere."
"5. camera angles, movements, and transitions used in the video:<|eot_id|>\n\n"
"<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
"<|start_header_id|>assistant<|end_header_id|>\n\n"
)
def load(self):
self.model = LlavaForConditionalGeneration.from_pretrained(self.model_path, low_cpu_mem_usage=True).to(torch.float16).to(self.device)
self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, padding_side="right")
self.processor = CLIPImageProcessor.from_pretrained(self.model_path)
def to_cpu(self):
self.model = self.model.to("cpu")
def to_cuda(self):
self.model = self.model.to("cuda")
@torch.no_grad()
def infer(self, text, img, config):
if config.cpu_offload:
self.to_cuda()
text = self.prompt_template.format(text)
tokens = self.tokenizer(
text,
return_length=False,
return_overflowing_tokens=False,
return_attention_mask=True,
truncation=True,
max_length=self.max_length,
padding="max_length",
return_tensors="pt",
).to("cuda")
image_outputs = self.processor(img, return_tensors="pt")["pixel_values"].to(self.device)
attention_mask = tokens["attention_mask"].to(self.device)
outputs = self.model(input_ids=tokens["input_ids"], attention_mask=attention_mask, output_hidden_states=True, pixel_values=image_outputs)
last_hidden_state = outputs.hidden_states[-(self.hidden_state_skip_layer + 1)]
batch_indices, last_double_return_token_indices = torch.where(tokens["input_ids"] == self.double_return_token_id)
last_double_return_token_indices = last_double_return_token_indices.reshape(1, -1)[:, -1]
assistant_crop_start = last_double_return_token_indices - 1 + self.image_emb_len - 4
assistant_crop_end = last_double_return_token_indices - 1 + self.image_emb_len
attention_mask_assistant_crop_start = last_double_return_token_indices - 4
attention_mask_assistant_crop_end = last_double_return_token_indices
text_last_hidden_state = torch.cat([last_hidden_state[0, self.text_crop_start : assistant_crop_start[0].item()], last_hidden_state[0, assistant_crop_end[0].item() :]])
text_attention_mask = torch.cat([attention_mask[0, self.crop_start : attention_mask_assistant_crop_start[0].item()], attention_mask[0, attention_mask_assistant_crop_end[0].item() :]])
image_last_hidden_state = last_hidden_state[0, self.image_crop_start : self.image_crop_end]
image_attention_mask = torch.ones(image_last_hidden_state.shape[0]).to(last_hidden_state.device).to(attention_mask.dtype)
text_last_hidden_state.unsqueeze_(0)
text_attention_mask.unsqueeze_(0)
image_last_hidden_state.unsqueeze_(0)
image_attention_mask.unsqueeze_(0)
image_last_hidden_state = image_last_hidden_state[:, :: self.image_embed_interleave, :]
image_attention_mask = image_attention_mask[:, :: self.image_embed_interleave]
last_hidden_state = torch.cat([image_last_hidden_state, text_last_hidden_state], dim=1)
attention_mask = torch.cat([image_attention_mask, text_attention_mask], dim=1)
if config.cpu_offload:
self.to_cpu()
return last_hidden_state, attention_mask
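# Illustrative driver for the LLaVA text+image encoder above. The checkpoint path,
# the use of PIL for loading the reference image, and the minimal config object are
# all assumptions for the sketch, not part of this file.
def _example_llava_encode(image_path="ref.png"):
    from types import SimpleNamespace

    from PIL import Image

    enc = TextEncoderHFLlavaModel("/path/to/llava-text-encoder", torch.device("cuda"))
    cfg = SimpleNamespace(cpu_offload=False)
    return enc.infer("A cat walks on the grass, realistic style.", Image.open(image_path), cfg)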
import os
import torch
from transformers import T5EncoderModel, T5Tokenizer
class T5EncoderModel_v1_1_xxl:
def __init__(self, config):
self.config = config
self.model = T5EncoderModel.from_pretrained(os.path.join(config.model_path, "text_encoder")).to(torch.bfloat16).to(torch.device("cuda"))
self.tokenizer = T5Tokenizer.from_pretrained(os.path.join(config.model_path, "tokenizer"), padding_side="right")
def to_cpu(self):
self.model = self.model.to("cpu")
def to_cuda(self):
self.model = self.model.to("cuda")
def infer(self, texts, config):
text_inputs = self.tokenizer(
texts,
padding="max_length",
max_length=config.text_len,
truncation=True,
add_special_tokens=True,
return_tensors="pt",
).to("cuda")
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(texts, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, config.text_len - 1 : -1])
            print(f"The following part of your input was truncated because `max_sequence_length` is set to {config.text_len} tokens: {removed_text}")
prompt_embeds = self.model(text_input_ids.to(torch.device("cuda")))[0]
return prompt_embeds
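# Hypothetical usage sketch. The config only needs the fields the class reads
# (model_path and text_len); 226 mirrors CogVideoX's usual text length and is an
# assumption, not something defined in this file.
def _example_t5_encode():
    from types import SimpleNamespace

    cfg = SimpleNamespace(model_path="/path/to/cogvideox", text_len=226)
    encoder = T5EncoderModel_v1_1_xxl(cfg)
    return encoder.infer(["A cat walks on the grass, realistic style."], cfg)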
import torch
class CogvideoxPostInfer:
def __init__(self, config):
self.config = config
def ada_layernorm(self, weight_mm, weight_ln, x, temb):
temb = torch.nn.functional.silu(temb)
temb = weight_mm.apply(temb)
shift, scale = temb.chunk(2, dim=1)
x = weight_ln.apply(x) * (1 + scale) + shift
return x
def infer(self, weight, hidden_states, encoder_hidden_states, temb, infer_shapes):
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=0)
hidden_states = weight.norm_final.apply(hidden_states)
hidden_states = hidden_states[self.config.text_len :,]
hidden_states = self.ada_layernorm(weight.norm_out_linear, weight.norm_out_norm, hidden_states, temb=temb)
hidden_states = weight.proj_out.apply(hidden_states)
p = self.config["patch_size"]
p_t = self.config["patch_size_t"]
num_frames, _, height, width = infer_shapes
output = hidden_states.reshape((num_frames + p_t - 1) // p_t, height // p, width // p, -1, p_t, p, p)
output = output.permute(0, 4, 3, 1, 5, 2, 6).flatten(5, 6).flatten(3, 4).flatten(0, 1)
return output
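# For reference, ada_layernorm() above is the standard AdaLayerNorm:
#     shift, scale = chunk(Linear(SiLU(temb)), 2)
#     out = LayerNorm(x) * (1 + scale) + shift
# A plain-torch sketch of the same computation; the nn modules here stand in for
# the weight wrappers used above.
def _adaln_reference(x, temb, linear: torch.nn.Linear, norm: torch.nn.LayerNorm):
    shift, scale = linear(torch.nn.functional.silu(temb)).chunk(2, dim=1)
    return norm(x) * (1 + scale) + shift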
import torch
from diffusers.models.embeddings import get_3d_sincos_pos_embed, get_timestep_embedding
class CogvideoxPreInfer:
def __init__(self, config):
self.config = config
self.use_positional_embeddings = not self.config.use_rotary_positional_embeddings
self.inner_dim = self.config.transformer_num_attention_heads * self.config.transformer_attention_head_dim
self.freq_shift = 0
self.flip_sin_to_cos = True
self.scale = 1
self.act = "silu"
def _get_positional_embeddings(self, sample_height, sample_width, sample_frames, device):
post_patch_height = sample_height // self.config.patch_size
post_patch_width = sample_width // self.config.patch_size
post_time_compression_frames = (sample_frames - 1) // self.config.transformer_temporal_compression_ratio + 1
num_patches = post_patch_height * post_patch_width * post_time_compression_frames
pos_embedding = get_3d_sincos_pos_embed(
self.inner_dim,
(post_patch_width, post_patch_height),
post_time_compression_frames,
self.config.transformer_spatial_interpolation_scale,
self.config.transformer_temporal_interpolation_scale,
device=device,
output_type="pt",
)
pos_embedding = pos_embedding.flatten(0, 1)
joint_pos_embedding = pos_embedding.new_zeros(1, self.config.text_len + num_patches, self.inner_dim, requires_grad=False)
joint_pos_embedding.data[:, self.config.text_len :].copy_(pos_embedding)
return joint_pos_embedding
def infer(self, weights, hidden_states, timestep, encoder_hidden_states):
t_emb = get_timestep_embedding(
timestep,
self.inner_dim,
flip_sin_to_cos=self.flip_sin_to_cos,
downscale_freq_shift=self.freq_shift,
scale=self.scale,
)
t_emb = t_emb.to(dtype=hidden_states.dtype)
sample = weights.time_embedding_linear_1.apply(t_emb)
sample = torch.nn.functional.silu(sample)
emb = weights.time_embedding_linear_2.apply(sample)
text_embeds = weights.patch_embed_text_proj.apply(encoder_hidden_states)
num_frames, channels, height, width = hidden_states.shape
infer_shapes = (num_frames, channels, height, width)
p = self.config.patch_size
p_t = self.config.patch_size_t
image_embeds = hidden_states.permute(0, 2, 3, 1)
image_embeds = image_embeds.reshape(num_frames // p_t, p_t, height // p, p, width // p, p, channels)
image_embeds = image_embeds.permute(0, 2, 4, 6, 1, 3, 5).flatten(3, 6).flatten(0, 2)
image_embeds = weights.patch_embed_proj.apply(image_embeds)
embeds = torch.cat([text_embeds, image_embeds], dim=0).contiguous()
if self.use_positional_embeddings or self.config.transformer_use_learned_positional_embeddings:
            if self.config.transformer_use_learned_positional_embeddings and (self.config.transformer_sample_width != width or self.config.transformer_sample_height != height):
                raise ValueError(
                    "It is currently not possible to generate videos at a different resolution than the defaults. This should only be the case with 'THUDM/CogVideoX-5b-I2V'. "
                    "If you think this is incorrect, please open an issue at https://github.com/huggingface/diffusers/issues."
)
pre_time_compression_frames = (num_frames - 1) * self.config.transformer_temporal_compression_ratio + 1
if self.config.transformer_sample_height != height or self.config.transformer_sample_width != width or self.config.transformer_sample_frames != pre_time_compression_frames:
pos_embedding = self._get_positional_embeddings(height, width, pre_time_compression_frames, device=embeds.device)[0]
else:
pos_embedding = self.pos_embedding[0]
pos_embedding = pos_embedding.to(dtype=embeds.dtype)
embeds = embeds + pos_embedding
hidden_states = embeds
text_seq_length = encoder_hidden_states.shape[0]
encoder_hidden_states = hidden_states[:text_seq_length, :]
hidden_states = hidden_states[text_seq_length:, :]
return hidden_states, encoder_hidden_states, emb, infer_shapes
import torch
import torch.nn.functional as F
def apply_rotary_emb(x, freqs_cis, use_real=True, use_real_unbind_dim=-1):
"""
Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
tensors contain rotary embeddings and are returned as real tensors.
Args:
x (`torch.Tensor`):
Query or key tensor to apply rotary embeddings. [B, H, S, D] xk (torch.Tensor): Key tensor to apply
freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)
Returns:
Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
"""
if use_real:
cos, sin = freqs_cis # [S, D]
cos = cos[None]
sin = sin[None]
cos, sin = cos.to(x.device), sin.to(x.device)
if use_real_unbind_dim == -1:
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(2)
elif use_real_unbind_dim == -2:
# Used for Stable Audio
x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2)
x_rotated = torch.cat([-x_imag, x_real], dim=-1)
else:
raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")
out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
return out
else:
# used for lumina
x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
freqs_cis = freqs_cis.unsqueeze(2)
x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(2)
return x_out.type_as(x)
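# Minimal shape check for the use_real=True path, matching how cogvideox_attention()
# below calls apply_rotary_emb(): x is [heads, seq, head_dim] and freqs_cis is a
# (cos, sin) pair of [seq, head_dim] tensors.
def _example_apply_rotary_emb(heads=8, seq=16, head_dim=64):
    x = torch.randn(heads, seq, head_dim)
    cos = torch.randn(seq, head_dim)
    sin = torch.randn(seq, head_dim)
    out = apply_rotary_emb(x, (cos, sin))
    assert out.shape == x.shape
    return out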
class CogvideoxTransformerInfer:
def __init__(self, config):
self.config = config
self.attn_type = "torch_sdpa"
def set_scheduler(self, scheduler):
self.scheduler = scheduler
def infer(self, weights, hidden_states, encoder_hidden_states, temb):
image_rotary_emb = self.scheduler.image_rotary_emb
for i in range(self.config.transformer_num_layers):
hidden_states, encoder_hidden_states = self.infer_block(
weights.blocks_weights[i],
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
)
return hidden_states, encoder_hidden_states
def cogvideox_norm1(self, weights, hidden_states, encoder_hidden_states, temb):
temb = torch.nn.functional.silu(temb)
temb = weights.norm1_linear.apply(temb)
shift, scale, gate, enc_shift, enc_scale, enc_gate = temb.chunk(6, dim=1)
hidden_states = weights.norm1_norm.apply(hidden_states) * (1 + scale)[:, :] + shift[:, :]
encoder_hidden_states = weights.norm1_norm.apply(encoder_hidden_states) * (1 + enc_scale)[:, :] + enc_shift[:, :]
return hidden_states, encoder_hidden_states, gate, enc_gate
def cogvideox_norm2(self, weights, hidden_states, encoder_hidden_states, temb):
temb = torch.nn.functional.silu(temb)
temb = weights.norm2_linear.apply(temb)
shift, scale, gate, enc_shift, enc_scale, enc_gate = temb.chunk(6, dim=1)
hidden_states = weights.norm2_norm.apply(hidden_states) * (1 + scale)[:, :] + shift[:, :]
encoder_hidden_states = weights.norm2_norm.apply(encoder_hidden_states) * (1 + enc_scale)[:, :] + enc_shift[:, :]
return hidden_states, encoder_hidden_states, gate, enc_gate
def cogvideox_attention(self, weights, hidden_states, encoder_hidden_states, image_rotary_emb):
text_seq_length = encoder_hidden_states.size(0)
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=0)
query = weights.attn1_to_q.apply(hidden_states)
key = weights.attn1_to_k.apply(hidden_states)
value = weights.attn1_to_v.apply(hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // self.config.transformer_num_attention_heads
query = query.view(-1, self.config.transformer_num_attention_heads, head_dim).transpose(0, 1)
key = key.view(-1, self.config.transformer_num_attention_heads, head_dim).transpose(0, 1)
value = value.view(-1, self.config.transformer_num_attention_heads, head_dim).transpose(0, 1)
query = weights.attn1_norm_q.apply(query)
key = weights.attn1_norm_k.apply(key)
query[:, text_seq_length:] = apply_rotary_emb(query[:, text_seq_length:], image_rotary_emb)
key[:, text_seq_length:] = apply_rotary_emb(key[:, text_seq_length:], image_rotary_emb)
hidden_states = F.scaled_dot_product_attention(query[None], key[None], value[None], attn_mask=None, dropout_p=0.0, is_causal=False)
hidden_states = hidden_states.transpose(1, 2).reshape(1, -1, self.config.transformer_num_attention_heads * head_dim)
hidden_states = hidden_states.squeeze(0)
hidden_states = weights.attn1_to_out.apply(hidden_states)
encoder_hidden_states, hidden_states = hidden_states.split([text_seq_length, hidden_states.size(0) - text_seq_length], dim=0)
return hidden_states, encoder_hidden_states
def cogvideox_ff(self, weights, hidden_states):
hidden_states = weights.ff_net_0_proj.apply(hidden_states)
hidden_states = torch.nn.functional.gelu(hidden_states, approximate="tanh")
hidden_states = weights.ff_net_2_proj.apply(hidden_states)
return hidden_states
@torch.no_grad()
def infer_block(self, weights, hidden_states, encoder_hidden_states, temb, image_rotary_emb):
text_seq_length = encoder_hidden_states.size(0)
norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.cogvideox_norm1(weights, hidden_states, encoder_hidden_states, temb)
attn_hidden_states, attn_encoder_hidden_states = self.cogvideox_attention(
weights,
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_encoder_hidden_states,
image_rotary_emb=image_rotary_emb,
)
hidden_states = hidden_states + gate_msa * attn_hidden_states
encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states
norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.cogvideox_norm2(weights, hidden_states, encoder_hidden_states, temb)
norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=0)
ff_output = self.cogvideox_ff(weights, norm_hidden_states)
hidden_states = hidden_states + gate_ff * ff_output[text_seq_length:,]
encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:text_seq_length,]
return hidden_states, encoder_hidden_states
import glob
import json
import math
import os
import torch
from safetensors import safe_open
from lightx2v.models.networks.cogvideox.infer.post_infer import CogvideoxPostInfer
from lightx2v.models.networks.cogvideox.infer.pre_infer import CogvideoxPreInfer
from lightx2v.models.networks.cogvideox.infer.transformer_infer import CogvideoxTransformerInfer
from lightx2v.models.networks.cogvideox.weights.post_weights import CogvideoxPostWeights
from lightx2v.models.networks.cogvideox.weights.pre_weights import CogvideoxPreWeights
from lightx2v.models.networks.cogvideox.weights.transformers_weights import CogvideoxTransformerWeights
from lightx2v.utils.envs import *
class CogvideoxModel:
pre_weight_class = CogvideoxPreWeights
post_weight_class = CogvideoxPostWeights
transformer_weight_class = CogvideoxTransformerWeights
def __init__(self, config):
self.config = config
self.device = torch.device("cuda")
self._init_infer_class()
self._init_weights()
self._init_infer()
def _init_infer_class(self):
self.pre_infer_class = CogvideoxPreInfer
self.post_infer_class = CogvideoxPostInfer
self.transformer_infer_class = CogvideoxTransformerInfer
def _load_safetensor_to_dict(self, file_path):
with safe_open(file_path, framework="pt") as f:
tensor_dict = {key: f.get_tensor(key).to(GET_DTYPE()).cuda() for key in f.keys()}
return tensor_dict
def _load_ckpt(self):
safetensors_pattern = os.path.join(self.config.model_path, "transformer", "*.safetensors")
safetensors_files = glob.glob(safetensors_pattern)
if not safetensors_files:
            raise FileNotFoundError(f"No .safetensors files found in directory: {self.config.model_path}")
weight_dict = {}
for file_path in safetensors_files:
file_weights = self._load_safetensor_to_dict(file_path)
weight_dict.update(file_weights)
return weight_dict
def _init_weights(self):
weight_dict = self._load_ckpt()
with open(os.path.join(self.config.model_path, "transformer", "config.json"), "r") as f:
transformer_cfg = json.load(f)
# init weights
self.pre_weight = self.pre_weight_class(transformer_cfg)
self.transformer_weights = self.transformer_weight_class(transformer_cfg)
self.post_weight = self.post_weight_class(transformer_cfg)
# load weights
self.pre_weight.load_weights(weight_dict)
self.transformer_weights.load_weights(weight_dict)
self.post_weight.load_weights(weight_dict)
def _init_infer(self):
self.pre_infer = self.pre_infer_class(self.config)
self.transformer_infer = self.transformer_infer_class(self.config)
self.post_infer = self.post_infer_class(self.config)
def set_scheduler(self, scheduler):
self.scheduler = scheduler
self.transformer_infer.set_scheduler(scheduler)
def to_cpu(self):
self.pre_weight.to_cpu()
self.post_weight.to_cpu()
self.transformer_weights.to_cpu()
def to_cuda(self):
self.pre_weight.to_cuda()
self.post_weight.to_cuda()
self.transformer_weights.to_cuda()
@torch.no_grad()
def infer(self, inputs):
t = self.scheduler.timesteps[self.scheduler.step_index]
text_encoder_output = inputs["text_encoder_output"]["context"]
do_classifier_free_guidance = self.config.guidance_scale > 1.0
latent_model_input = self.scheduler.latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
timestep = t.expand(latent_model_input.shape[0])
hidden_states, encoder_hidden_states, emb, infer_shapes = self.pre_infer.infer(
self.pre_weight,
latent_model_input[0],
timestep,
text_encoder_output[0],
)
hidden_states, encoder_hidden_states = self.transformer_infer.infer(
self.transformer_weights,
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=emb,
)
noise_pred = self.post_infer.infer(self.post_weight, hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=emb, infer_shapes=infer_shapes)
noise_pred = noise_pred.float()
if self.config.use_dynamic_cfg: # True
self.scheduler.guidance_scale = 1 + self.scheduler.guidance_scale * ((1 - math.cos(math.pi * ((self.scheduler.infer_steps - t.item()) / self.scheduler.infer_steps) ** 5.0)) / 2)
if do_classifier_free_guidance: # False
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self.scheduler.guidance_scale * (noise_pred_text - noise_pred_uncond)
self.scheduler.noise_pred = noise_pred
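# Hypothetical driver sketch for CogvideoxModel.infer(). It only exercises the
# scheduler attributes read above (timesteps, step_index, latents, noise_pred);
# the `scheduler.step()` call is an assumed hook for advancing the latents and is
# not defined in this file.
def _example_denoise_loop(model, scheduler, inputs):
    model.set_scheduler(scheduler)
    for i in range(len(scheduler.timesteps)):
        scheduler.step_index = i
        model.infer(inputs)  # writes scheduler.noise_pred
        scheduler.step()  # assumption: consumes noise_pred and updates scheduler.latents
    return scheduler.latents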