Commit a50bcc53 authored by Dongz, committed by Yang Yong(雍洋)

add lint feature and minor fixes (#7)

* [minor]: optimize Dockerfile for fewer layers

* [feature]: add pre-commit lint, update README with contribution guidance

* [minor]: fix run shell script permissions

* [auto]: first lint pass without rule F, fix rule E

* [minor]: fix Dockerfile error
parent 3b460075
@@ -3,7 +3,7 @@ import torch.distributed as dist

def pre_process(latent_model_input, freqs_cos, freqs_sin):
-    '''
+    """
    Preprocess the input latent model data and frequency data, splitting them to fit distributed computation.

    Args:
@@ -13,7 +13,7 @@ def pre_process(latent_model_input, freqs_cos, freqs_sin):
    Returns:
        tuple: the processed latent_model_input, freqs_cos, freqs_sin and the split dimension split_dim
-    '''
+    """
    # Get the world size and the rank of the current process
    world_size = dist.get_world_size()
    cur_rank = dist.get_rank()
@@ -25,7 +25,7 @@ def pre_process(latent_model_input, freqs_cos, freqs_sin):
        split_dim = -1  # split along the width
    else:
        raise ValueError(f"Cannot split video sequence into world size ({world_size}) parts evenly")

    # Get the temporal size and the processed height and width
    temporal_size, h, w = latent_model_input.shape[2], latent_model_input.shape[3] // 2, latent_model_input.shape[4] // 2
@@ -62,7 +62,7 @@ def post_process(output, split_dim):
    # Create a list to hold the outputs from all processes
    gathered_outputs = [torch.empty_like(output) for _ in range(world_size)]

    # Gather the outputs from all processes
    dist.all_gather(gathered_outputs, output)
...
@@ -3,7 +3,7 @@ import torch.distributed as dist

def pre_process(latent_model_input, freqs_cos, freqs_sin):
-    '''
+    """
    Preprocess the input latent model data and frequency data, splitting them to fit distributed computation.

    Args:
@@ -13,7 +13,7 @@ def pre_process(latent_model_input, freqs_cos, freqs_sin):
    Returns:
        tuple: the processed latent_model_input, freqs_cos, freqs_sin and the split dimension split_dim
-    '''
+    """
    # Get the world size and the rank of the current process
    world_size = dist.get_world_size()
    cur_rank = dist.get_rank()
@@ -25,7 +25,7 @@ def pre_process(latent_model_input, freqs_cos, freqs_sin):
        split_dim = -1  # split along the width
    else:
        raise ValueError(f"Cannot split video sequence into world size ({world_size}) parts evenly")

    # Get the temporal size and the processed height and width
    temporal_size, h, w = latent_model_input.shape[2], latent_model_input.shape[3] // 2, latent_model_input.shape[4] // 2
@@ -62,7 +62,7 @@ def post_process(output, split_dim):
    # Create a list to hold the outputs from all processes
    gathered_outputs = [torch.empty_like(output) for _ in range(world_size)]

    # Gather the outputs from all processes
    dist.all_gather(gathered_outputs, output)
...
@@ -7,23 +7,22 @@ def pre_process(x):
    world_size = dist.get_world_size()
    cur_rank = dist.get_rank()
-    x = torch.chunk(
-        x, world_size, dim=0
-    )[cur_rank]
+    x = torch.chunk(x, world_size, dim=0)[cur_rank]
    return x


def post_process(x):
    # Get the world size of the current process group
    world_size = dist.get_world_size()
    # Create a list to hold the outputs from all processes
    gathered_x = [torch.empty_like(x) for _ in range(world_size)]
    # Gather the outputs from all processes
    dist.all_gather(gathered_x, x)
    # Concatenate the outputs from all processes along the specified dimension
    combined_output = torch.cat(gathered_x, dim=0)
    return combined_output  # return the combined output
\ No newline at end of file
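For orientation, a minimal runnable sketch (not part of this commit) of the split/gather pattern these two helpers implement: each rank keeps its chunk of the batch for the forward pass, and all_gather reassembles the full tensor afterwards. The gloo backend, the port, and the run_demo entry point are illustrative assumptions:

```python
import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def pre_process(x):
    # Keep only this rank's slice of the batch dimension.
    world_size = dist.get_world_size()
    cur_rank = dist.get_rank()
    return torch.chunk(x, world_size, dim=0)[cur_rank]


def post_process(x):
    # Reassemble the full batch from every rank's slice.
    world_size = dist.get_world_size()
    gathered = [torch.empty_like(x) for _ in range(world_size)]
    dist.all_gather(gathered, x)
    return torch.cat(gathered, dim=0)


def run_demo(rank, world_size):
    # gloo backend so the sketch also runs on CPU-only machines (assumption).
    dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29500", rank=rank, world_size=world_size)
    full = torch.arange(8, dtype=torch.float32).reshape(4, 2)  # identical on every rank
    local = pre_process(full)        # each rank holds 4 // world_size rows
    restored = post_process(local)   # identical to `full` on every rank
    assert torch.equal(full, restored)
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(run_demo, args=(2,), nprocs=2)
```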
@@ -33,15 +33,11 @@ def GiB(val):

def add_help(description):
-    parser = argparse.ArgumentParser(
-        description=description, formatter_class=argparse.ArgumentDefaultsHelpFormatter
-    )
+    parser = argparse.ArgumentParser(description=description, formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    args, _ = parser.parse_known_args()


-def find_sample_data(
-    description="Runs a TensorRT Python sample", subfolder="", find_files=[], err_msg=""
-):
+def find_sample_data(description="Runs a TensorRT Python sample", subfolder="", find_files=[], err_msg=""):
    """
    Parses sample arguments.
@@ -56,9 +52,7 @@ def find_sample_data(
    # Standard command-line arguments for all samples.
    kDEFAULT_DATA_ROOT = os.path.join(os.sep, "usr", "src", "tensorrt", "data")
-    parser = argparse.ArgumentParser(
-        description=description, formatter_class=argparse.ArgumentDefaultsHelpFormatter
-    )
+    parser = argparse.ArgumentParser(description=description, formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument(
        "-d",
        "--datadir",
@@ -73,21 +67,11 @@ def find_sample_data(
        data_path = os.path.join(data_dir, subfolder)
        if not os.path.exists(data_path):
            if data_dir != kDEFAULT_DATA_ROOT:
-                print(
-                    "WARNING: "
-                    + data_path
-                    + " does not exist. Trying "
-                    + data_dir
-                    + " instead."
-                )
+                print("WARNING: " + data_path + " does not exist. Trying " + data_dir + " instead.")
            data_path = data_dir
        # Make sure data directory exists.
        if not (os.path.exists(data_path)) and data_dir != kDEFAULT_DATA_ROOT:
-            print(
-                "WARNING: {:} does not exist. Please provide the correct data path with the -d option.".format(
-                    data_path
-                )
-            )
+            print("WARNING: {:} does not exist. Please provide the correct data path with the -d option.".format(data_path))
        return data_path

    data_paths = [get_data_path(data_dir) for data_dir in args.datadir]
@@ -121,11 +105,7 @@ def locate_files(data_paths, filenames, err_msg=""):
    # Check that all files were found
    for f, filename in zip(found_files, filenames):
        if not f or not os.path.exists(f):
-            raise FileNotFoundError(
-                "Could not find {:}. Searched in data paths: {:}\n{:}".format(
-                    filename, data_paths, err_msg
-                )
-            )
+            raise FileNotFoundError("Could not find {:}. Searched in data paths: {:}\n{:}".format(filename, data_paths, err_msg))
    return found_files
@@ -143,4 +123,4 @@ def setup_timing_cache(config: trt.IBuilderConfig, timing_cache_path: os.PathLike):

def save_timing_cache(config: trt.IBuilderConfig, timing_cache_path: os.PathLike):
    timing_cache: trt.ITimingCache = config.get_timing_cache()
    with open(timing_cache_path, "wb") as timing_cache_file:
        timing_cache_file.write(memoryview(timing_cache.serialize()))
\ No newline at end of file
@@ -22,6 +22,7 @@ import numpy as np
import tensorrt as trt
from cuda import cuda, cudart


def check_cuda_err(err):
    if isinstance(err, cuda.CUresult):
        if err != cuda.CUresult.CUDA_SUCCESS:
@@ -32,6 +33,7 @@ def check_cuda_err(err):
    else:
        raise RuntimeError("Unknown error type: {}".format(err))


def cuda_call(call):
    err, res = call[0], call[1:]
    check_cuda_err(err)
@@ -42,6 +44,7 @@ def cuda_call(call):
class HostDeviceMem:
    """Pair of host and device memory, where the host memory is wrapped in a numpy array"""

    def __init__(self, size: int, dtype: Optional[np.dtype] = None):
        dtype = dtype or np.dtype(np.uint8)
        nbytes = size * dtype.itemsize
@@ -60,13 +63,11 @@ class HostDeviceMem:
    def host(self, data: Union[np.ndarray, bytes]):
        if isinstance(data, np.ndarray):
            if data.size > self.host.size:
-                raise ValueError(
-                    f"Tried to fit an array of size {data.size} into host memory of size {self.host.size}"
-                )
-            np.copyto(self.host[:data.size], data.flat, casting='safe')
+                raise ValueError(f"Tried to fit an array of size {data.size} into host memory of size {self.host.size}")
+            np.copyto(self.host[: data.size], data.flat, casting="safe")
        else:
            assert self.host.dtype == np.uint8
-            self.host[:self.nbytes] = np.frombuffer(data, dtype=np.uint8)
+            self.host[: self.nbytes] = np.frombuffer(data, dtype=np.uint8)

    @property
    def device(self) -> int:
@@ -101,8 +102,7 @@ def allocate_buffers(engine: trt.ICudaEngine, profile_idx: Optional[int] = None):
        shape = engine.get_tensor_shape(binding) if profile_idx is None else engine.get_tensor_profile_shape(binding, profile_idx)[-1]
        shape_valid = np.all([s >= 0 for s in shape])
        if not shape_valid and profile_idx is None:
-            raise ValueError(f"Binding {binding} has dynamic shape, " +\
-                "but no profile was specified.")
+            raise ValueError(f"Binding {binding} has dynamic shape, " + "but no profile was specified.")
        size = trt.volume(shape)
        trt_type = engine.get_tensor_dtype(binding)
@@ -110,7 +110,7 @@ def allocate_buffers(engine: trt.ICudaEngine, profile_idx: Optional[int] = None):
        try:
            dtype = np.dtype(trt.nptype(trt_type))
            bindingMemory = HostDeviceMem(size, dtype)
-        except TypeError: # no numpy support: create a byte array instead (BF16, FP8, INT4)
+        except TypeError:  # no numpy support: create a byte array instead (BF16, FP8, INT4)
            size = int(size * trt_type.itemsize)
            bindingMemory = HostDeviceMem(size)
@@ -137,6 +137,7 @@ def memcpy_host_to_device(device_ptr: int, host_arr: np.ndarray):
    nbytes = host_arr.size * host_arr.itemsize
    cuda_call(cudart.cudaMemcpy(device_ptr, host_arr, nbytes, cudart.cudaMemcpyKind.cudaMemcpyHostToDevice))


# Wrapper for cudaMemcpy which infers copy size and does error checking
def memcpy_device_to_host(host_arr: np.ndarray, device_ptr: int):
    nbytes = host_arr.size * host_arr.itemsize
@@ -163,8 +164,9 @@ def _do_inference_base(inputs, outputs, stream, execute_async_func):
def do_inference(context, engine, bindings, inputs, outputs, stream):
    def execute_async_func():
        context.execute_async_v3(stream_handle=stream)

    # Setup context tensor address.
    num_io = engine.num_io_tensors
    for i in range(num_io):
        context.set_tensor_address(engine.get_tensor_name(i), bindings[i])
    return _do_inference_base(inputs, outputs, stream, execute_async_func)
\ No newline at end of file
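A rough driver for the helpers in this file, sketched under assumptions: the engine path and input array are placeholders, the module is imported as common, and allocate_buffers is assumed to return the usual (inputs, outputs, bindings, stream) tuple from NVIDIA's sample utilities:

```python
# Hypothetical driver for the helpers above; "sample.engine" and the input are placeholders.
import numpy as np
import tensorrt as trt

import common  # the utility module shown in this diff

TRT_LOGGER = trt.Logger(trt.Logger.WARNING)


def run(engine_path="sample.engine", input_array=None):
    # Deserialize a prebuilt engine and create an execution context.
    runtime = trt.Runtime(TRT_LOGGER)
    with open(engine_path, "rb") as f:
        engine = runtime.deserialize_cuda_engine(f.read())
    context = engine.create_execution_context()

    # One HostDeviceMem pair per I/O tensor, plus a CUDA stream (assumed return layout).
    inputs, outputs, bindings, stream = common.allocate_buffers(engine)

    # Stage the input in pinned host memory, then run and fetch the outputs.
    inputs[0].host = np.ascontiguousarray(input_array)
    results = common.do_inference(context, engine=engine, bindings=bindings, inputs=inputs, outputs=outputs, stream=stream)
    return results
```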
@@ -8,64 +8,54 @@ class MemoryEfficientBlocks(nn.Module):
        self.block_class = block_class
        self.num_blocks = num_blocks
        self.block_params = block_params

        # Initialize two blocks
-        self.active_blocks = nn.ModuleList([
-            block_class(**block_params) for _ in range(2)
-        ])
+        self.active_blocks = nn.ModuleList([block_class(**block_params) for _ in range(2)])

        # Create dedicated CUDA streams for weight loading and set their priorities
        self.compute_stream = torch.cuda.Stream(priority=-1)  # high priority
        self.load_stream = torch.cuda.Stream(priority=0)  # normal priority

        # Pre-allocate pinned memory for asynchronous transfers
        self.pinned_memory = torch.cuda.empty_cache()
        torch.cuda.memory.set_per_process_memory_fraction(0.8)  # limit GPU memory usage

        # Buffer for the preloaded weights
        # self.next_weights = None
        self.weight_buffer = []
        # self.current_block_idx = 0

    def initialize_weights(self, checkpoint, key):
        """Load all block weights into CPU memory"""
        # checkpoint = torch.load(checkpoint_path, map_location='cpu')
        for i in range(self.num_blocks):
-            block_weights = {
-                k.replace(f'{key}.{i}.', ''): v
-                for k, v in checkpoint.items()
-                if f'{key}.{i}.' in k
-            }
+            block_weights = {k.replace(f"{key}.{i}.", ""): v for k, v in checkpoint.items() if f"{key}.{i}." in k}
            self.weight_buffer.append(block_weights)

    def prefetch_weights(self, block_idx):
        """Prefetch the next block's weights on the separate CUDA stream"""
        with torch.cuda.stream(self.load_stream):
            next_weights = self.weight_buffer[block_idx]
-            next_weights = {
-                k: v.cuda(non_blocking=True)
-                for k, v in next_weights.items()
-            }
+            next_weights = {k: v.cuda(non_blocking=True) for k, v in next_weights.items()}
            self.active_blocks[1].load_state_dict(next_weights)

    def swap_blocks(self):
        """Swap the two blocks and update their weights"""
        # Wait for the computation to finish
        self.compute_stream.synchronize()
        # Wait for the weight loading to finish
        self.load_stream.synchronize()
        # Swap the blocks
-        self.active_blocks[0], self.active_blocks[1] = \
-            self.active_blocks[1], self.active_blocks[0]
+        self.active_blocks[0], self.active_blocks[1] = self.active_blocks[1], self.active_blocks[0]

    def forward(self, *args, **kwargs):
        """Forward pass that overlaps computation with weight loading"""
        # import pdb; pdb.set_trace()
        for i in range(self.num_blocks):
            if i == 0:
                self.active_blocks[0].load_state_dict(self.weight_buffer[0])

            # Run the current block on the main compute stream
            with torch.cuda.stream(self.compute_stream):
                current_block = self.active_blocks[0]
@@ -75,10 +65,10 @@ class MemoryEfficientBlocks(nn.Module):
            # Prefetch the next block's weights on the separate stream
            if i < self.num_blocks - 1:
                self.prefetch_weights(i + 1)

            # Swap the blocks and update weights
            self.swap_blocks()

            # Update the inputs in args with the current outputs
            args = list(args)
            if len(outputs) == 1:
@@ -87,5 +77,5 @@ class MemoryEfficientBlocks(nn.Module):
            for i in range(len(outputs)):
                args[i] = outputs[i]
            args = tuple(args)
        return outputs
\ No newline at end of file
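MemoryEfficientBlocks keeps only two blocks resident: the slot that is computing runs on a high-priority stream while the idle slot's weights are prefetched on a second stream, and the slots are swapped once both streams synchronize. A stripped-down sketch of that ping-pong pattern (illustrative only; the toy Linear block, the shapes, and the pinned-memory staging are assumptions, not code from this commit):

```python
import torch
import torch.nn as nn


def make_block():
    # Assumed toy block; the repository passes its own transformer block class instead.
    return nn.Linear(64, 64).cuda()


compute_stream = torch.cuda.Stream(priority=-1)  # high priority
load_stream = torch.cuda.Stream(priority=0)      # normal priority

# CPU-side weights for every block, pinned so .cuda(non_blocking=True) can overlap with compute.
num_blocks = 4
cpu_weights = [{k: v.pin_memory() for k, v in make_block().cpu().state_dict().items()} for _ in range(num_blocks)]

active = [make_block(), make_block()]  # slot 0 computes, slot 1 is being filled
active[0].load_state_dict({k: v.cuda() for k, v in cpu_weights[0].items()})

x = torch.randn(8, 64, device="cuda")
for i in range(num_blocks):
    with torch.cuda.stream(compute_stream):
        x = active[0](x)  # compute block i on the compute stream
    if i < num_blocks - 1:
        with torch.cuda.stream(load_stream):
            nxt = {k: v.cuda(non_blocking=True) for k, v in cpu_weights[i + 1].items()}
            active[1].load_state_dict(nxt)  # stage block i+1 into the idle slot
    compute_stream.synchronize()
    load_stream.synchronize()
    active[0], active[1] = active[1], active[0]  # ping-pong the slots
print(x.shape)
```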
@@ -26,7 +26,7 @@ class Conv2dWeightTemplate(metaclass=ABCMeta):
        self.config = config


-@CONV2D_WEIGHT_REGISTER('Default')
+@CONV2D_WEIGHT_REGISTER("Default")
class Conv2dWeight(Conv2dWeightTemplate):
    def __init__(self, weight_name, bias_name, stride=1, padding=0, dilation=1, groups=1):
        super().__init__(weight_name, bias_name, stride, padding, dilation, groups)
@@ -36,15 +36,7 @@ class Conv2dWeight(Conv2dWeightTemplate):
        self.bias = weight_dict[self.bias_name].cuda() if self.bias_name is not None else None

    def apply(self, input_tensor):
-        input_tensor = torch.nn.functional.conv2d(
-            input_tensor,
-            weight=self.weight,
-            bias=self.bias,
-            stride=self.stride,
-            padding=self.padding,
-            dilation=self.dilation,
-            groups=self.groups
-        )
+        input_tensor = torch.nn.functional.conv2d(input_tensor, weight=self.weight, bias=self.bias, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=self.groups)
        return input_tensor

    def to_cpu(self):
...
@@ -26,7 +26,7 @@ class Conv3dWeightTemplate(metaclass=ABCMeta):
        self.config = config


-@CONV3D_WEIGHT_REGISTER('Default')
+@CONV3D_WEIGHT_REGISTER("Default")
class Conv3dWeight(Conv3dWeightTemplate):
    def __init__(self, weight_name, bias_name, stride=1, padding=0, dilation=1, groups=1):
        super().__init__(weight_name, bias_name, stride, padding, dilation, groups)
@@ -36,15 +36,7 @@ class Conv3dWeight(Conv3dWeightTemplate):
        self.bias = weight_dict[self.bias_name].cuda() if self.bias_name is not None else None

    def apply(self, input_tensor):
-        input_tensor = torch.nn.functional.conv3d(
-            input_tensor,
-            weight=self.weight,
-            bias=self.bias,
-            stride=self.stride,
-            padding=self.padding,
-            dilation=self.dilation,
-            groups=self.groups
-        )
+        input_tensor = torch.nn.functional.conv3d(input_tensor, weight=self.weight, bias=self.bias, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=self.groups)
        return input_tensor

    def to_cpu(self):
@@ -58,11 +50,11 @@ class Conv3dWeight(Conv3dWeightTemplate):
            self.bias = self.bias.cuda()


-@CONV3D_WEIGHT_REGISTER('Defaultt-Force-BF16')
+@CONV3D_WEIGHT_REGISTER("Defaultt-Force-BF16")
class Conv3dWeightForceBF16(Conv3dWeight):
    def __init__(self, weight_name, bias_name, stride=1, padding=0, dilation=1, groups=1):
        super().__init__(weight_name, bias_name, stride, padding, dilation, groups)

    def load(self, weight_dict):
        self.weight = weight_dict[self.weight_name].to(torch.bfloat16).cuda()
        self.bias = weight_dict[self.bias_name].to(torch.bfloat16).cuda() if self.bias_name is not None else None
\ No newline at end of file
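Both conv modules register themselves by name through decorators from lightx2v.utils.registry_factory, which this diff does not show. A minimal sketch of what such a registry plausibly looks like (an assumption for illustration, not the repository's actual implementation):

```python
class Register:
    """Maps a string key to a class via a decorator, so configs can pick implementations by name."""

    def __init__(self):
        self._registry = {}

    def __call__(self, key):
        def decorator(cls):
            self._registry[key] = cls
            return cls
        return decorator

    def __getitem__(self, key):
        return self._registry[key]


CONV3D_WEIGHT_REGISTER = Register()


@CONV3D_WEIGHT_REGISTER("Default")
class Conv3dWeightDemo:
    pass


# Lookup by the registered name, as done with MM_WEIGHT_REGISTER["..."] later in this commit.
assert CONV3D_WEIGHT_REGISTER["Default"] is Conv3dWeightDemo
```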
@@ -3,6 +3,7 @@ from abc import ABCMeta, abstractmethod
from vllm import _custom_ops as ops
from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER
from lightx2v.utils.quant_utils import IntegerQuantizer, FloatQuantizer

try:
    import q8_kernels.functional as Q8F
except ImportError:
@@ -28,7 +29,7 @@ class MMWeightTemplate(metaclass=ABCMeta):
        self.config = config


-@MM_WEIGHT_REGISTER('Default')
+@MM_WEIGHT_REGISTER("Default")
class MMWeight(MMWeightTemplate):
    def __init__(self, weight_name, bias_name):
        super().__init__(weight_name, bias_name)
@@ -57,7 +58,7 @@ class MMWeight(MMWeightTemplate):
            self.bias = self.bias.cuda()


-@MM_WEIGHT_REGISTER('Default-Force-FP32')
+@MM_WEIGHT_REGISTER("Default-Force-FP32")
class MMWeightForceFP32(MMWeight):
    def __init__(self, weight_name, bias_name):
        super().__init__(weight_name, bias_name)
@@ -69,29 +70,30 @@ class MMWeightForceFP32(MMWeight):
            self.bias = self.bias.to(torch.float32)


-@MM_WEIGHT_REGISTER('W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm')
+@MM_WEIGHT_REGISTER("W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm")
class MMWeightWfp8channelAfp8channeldynamicVllm(MMWeightTemplate):
-    '''
+    """
    Name: W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm
    Quant MM:
        Weight: fp8 perchannel sym
        Act: fp8 perchannel dynamic sym
        Kernel: vllm
-    '''
+    """

    def __init__(self, weight_name, bias_name):
        super().__init__(weight_name, bias_name)

    def load(self, weight_dict):
-        if self.config.get('weight_auto_quant', True):
+        if self.config.get("weight_auto_quant", True):
            self.weight = weight_dict[self.weight_name].to(torch.float32).cuda()
-            w_quantizer = FloatQuantizer('e4m3', True, 'channel')
+            w_quantizer = FloatQuantizer("e4m3", True, "channel")
            self.weight, self.weight_scale, _ = w_quantizer.real_quant_tensor(self.weight)
            self.weight = self.weight.to(torch.float8_e4m3fn).t().cuda()
            self.weight_scale = self.weight_scale.to(torch.float32).cuda()
        else:
            self.weight = weight_dict[self.weight_name].t().cuda()
-            self.weight_scale = weight_dict[self.weight_name.rstrip(".weight") + '.weight_scale'].cuda()
+            self.weight_scale = weight_dict[self.weight_name.rstrip(".weight") + ".weight_scale"].cuda()
        self.bias = weight_dict[self.bias_name].cuda() if self.bias_name is not None else None

    def apply(self, input_tensor):
@@ -116,29 +118,30 @@ class MMWeightWfp8channelAfp8channeldynamicVllm(MMWeightTemplate):
            self.bias = self.bias.cuda()


-@MM_WEIGHT_REGISTER('W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm')
+@MM_WEIGHT_REGISTER("W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm")
class MMWeightWint8channelAint8channeldynamicVllm(MMWeightTemplate):
-    '''
+    """
    Name: W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm
    Quant MM:
        Weight: int8 perchannel sym
        Act: int8 perchannel dynamic sym
        Kernel: vllm
-    '''
+    """

    def __init__(self, weight_name, bias_name):
        super().__init__(weight_name, bias_name)

    def load(self, weight_dict):
-        if self.config.get('weight_auto_quant', True):
+        if self.config.get("weight_auto_quant", True):
            self.weight = weight_dict[self.weight_name].to(torch.float32).cuda()
-            w_quantizer = IntegerQuantizer(8, True, 'channel')
+            w_quantizer = IntegerQuantizer(8, True, "channel")
            self.weight, self.weight_scale, _ = w_quantizer.real_quant_tensor(self.weight)
            self.weight = self.weight.to(torch.int8).t().cuda()
            self.weight_scale = self.weight_scale.to(torch.float32).cuda()
        else:
            self.weight = weight_dict[self.weight_name].t().cuda()
-            self.weight_scale = weight_dict[self.weight_name.rstrip(".weight") + '.weight_scale'].cuda()
+            self.weight_scale = weight_dict[self.weight_name.rstrip(".weight") + ".weight_scale"].cuda()
        self.bias = weight_dict[self.bias_name].cuda() if self.bias_name is not None else None

    def apply(self, input_tensor):
@@ -163,29 +166,30 @@ class MMWeightWint8channelAint8channeldynamicVllm(MMWeightTemplate):
            self.bias = self.bias.cuda()


-@MM_WEIGHT_REGISTER('W-int8-channel-sym-A-int8-channel-sym-dynamic-Q8F')
+@MM_WEIGHT_REGISTER("W-int8-channel-sym-A-int8-channel-sym-dynamic-Q8F")
class MMWeightWint8channelAint8channeldynamicQ8F(MMWeightTemplate):
-    '''
+    """
    Name: W-int8-channel-sym-A-int8-channel-sym-dynamic-Q8F
    Quant MM:
        Weight: int8 perchannel sym
        Act: int8 perchannel dynamic sym
        Kernel: Q8F
-    '''
+    """

    def __init__(self, weight_name, bias_name):
        super().__init__(weight_name, bias_name)

    def load(self, weight_dict):
-        if self.config.get('weight_auto_quant', True):
+        if self.config.get("weight_auto_quant", True):
            self.weight = weight_dict[self.weight_name].cuda()
-            w_quantizer = IntegerQuantizer(8, True, 'channel')
+            w_quantizer = IntegerQuantizer(8, True, "channel")
            self.weight, self.weight_scale, _ = w_quantizer.real_quant_tensor(self.weight)
            self.weight = self.weight.to(torch.int8)
            self.weight_scale = self.weight_scale.to(torch.float32)
        else:
            self.weight = weight_dict[self.weight_name].cuda()
-            self.weight_scale = weight_dict[self.weight_name.rstrip(".weight") + '.weight_scale'].cuda()
+            self.weight_scale = weight_dict[self.weight_name.rstrip(".weight") + ".weight_scale"].cuda()
        self.bias = weight_dict[self.bias_name].float().cuda() if self.bias_name is not None else None

    def apply(self, input_tensor, act=None):
@@ -206,29 +210,30 @@ class MMWeightWint8channelAint8channeldynamicQ8F(MMWeightTemplate):
            self.bias = self.bias.cuda()


-@MM_WEIGHT_REGISTER('W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Q8F')
+@MM_WEIGHT_REGISTER("W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Q8F")
class MMWeightWfp8channelAfp8channeldynamicQ8F(MMWeightTemplate):
-    '''
+    """
    Name: W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Q8F
    Quant MM:
        Weight: fp8 perchannel sym
        Act: fp8 perchannel dynamic sym
        Kernel: Q8F
-    '''
+    """

    def __init__(self, weight_name, bias_name):
        super().__init__(weight_name, bias_name)

    def load(self, weight_dict):
-        if self.config.get('weight_auto_quant', True):
+        if self.config.get("weight_auto_quant", True):
            self.weight = weight_dict[self.weight_name].cuda()
-            w_quantizer = FloatQuantizer('e4m3', True, 'channel')
+            w_quantizer = FloatQuantizer("e4m3", True, "channel")
            self.weight, self.weight_scale, _ = w_quantizer.real_quant_tensor(self.weight)
            self.weight = self.weight.to(torch.float8_e4m3fn)
            self.weight_scale = self.weight_scale.to(torch.float32)
        else:
            self.weight = weight_dict[self.weight_name].cuda()
-            self.weight_scale = weight_dict[self.weight_name.rstrip(".weight") + '.weight_scale'].cuda()
+            self.weight_scale = weight_dict[self.weight_name.rstrip(".weight") + ".weight_scale"].cuda()
        self.bias = weight_dict[self.bias_name].float().cuda() if self.bias_name is not None else None

    def apply(self, input_tensor):
@@ -249,41 +254,40 @@ class MMWeightWfp8channelAfp8channeldynamicQ8F(MMWeightTemplate):
            self.bias = self.bias.cuda()


-if __name__ == '__main__':
+if __name__ == "__main__":
    weight_dict = {
-        'xx.weight': torch.randn(8192, 4096).to(torch.float8_e4m3fn),
-        'xx.bias': torch.randn(8192).to(torch.bfloat16),
-        'xx.weight_scale': torch.randn(8192, 1).to(torch.float32),
+        "xx.weight": torch.randn(8192, 4096).to(torch.float8_e4m3fn),
+        "xx.bias": torch.randn(8192).to(torch.bfloat16),
+        "xx.weight_scale": torch.randn(8192, 1).to(torch.float32),
    }
-    mm_weight = MM_WEIGHT_REGISTER['W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm']('xx.weight', 'xx.bias')
-    mm_weight.set_config({'weight_auto_quant': False})
+    mm_weight = MM_WEIGHT_REGISTER["W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm"]("xx.weight", "xx.bias")
+    mm_weight.set_config({"weight_auto_quant": False})
    mm_weight.load(weight_dict)
    input_tensor = torch.randn(1024, 4096).to(torch.bfloat16).cuda()
    output_tensor = mm_weight.apply(input_tensor)
    print(output_tensor.shape)

    weight_dict = {
-        'xx.weight': torch.randn(8192, 4096),
-        'xx.bias': torch.randn(8192).to(torch.bfloat16),
+        "xx.weight": torch.randn(8192, 4096),
+        "xx.bias": torch.randn(8192).to(torch.bfloat16),
    }
-    mm_weight = MM_WEIGHT_REGISTER['W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm']('xx.weight', 'xx.bias')
-    mm_weight.set_config({'weight_auto_quant': True})
+    mm_weight = MM_WEIGHT_REGISTER["W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm"]("xx.weight", "xx.bias")
+    mm_weight.set_config({"weight_auto_quant": True})
    mm_weight.load(weight_dict)
    input_tensor = torch.randn(1024, 4096).to(torch.bfloat16).cuda()
    output_tensor = mm_weight.apply(input_tensor)
    print(output_tensor.shape)

    weight_dict = {
-        'xx.weight': torch.randn(8192, 4096),
-        'xx.bias': torch.randn(8192).to(torch.bfloat16),
+        "xx.weight": torch.randn(8192, 4096),
+        "xx.bias": torch.randn(8192).to(torch.bfloat16),
    }
-    mm_weight = MM_WEIGHT_REGISTER['W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm']('xx.weight', 'xx.bias')
-    mm_weight.set_config({'weight_auto_quant': True})
+    mm_weight = MM_WEIGHT_REGISTER["W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm"]("xx.weight", "xx.bias")
+    mm_weight.set_config({"weight_auto_quant": True})
    mm_weight.load(weight_dict)
    input_tensor = torch.randn(1024, 4096).to(torch.bfloat16).cuda()
    output_tensor = mm_weight.apply(input_tensor)
    print(output_tensor.shape)
\ No newline at end of file
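All of the quantized MM classes above follow the same recipe: quantize the weight per output channel with a symmetric scale, keep the scale in fp32, and let the kernel (vllm or Q8F) dequantize on the fly. A standalone sketch of the int8 per-channel step, independent of the repository's IntegerQuantizer API (which is not shown in this diff):

```python
import torch


def quant_int8_per_channel_sym(w: torch.Tensor):
    """Symmetric per-output-channel int8 quantization: one scale per row of the (out, in) weight."""
    amax = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-8)  # per-channel absolute maximum
    scale = amax / 127.0                                      # symmetric: zero-point is 0
    q = torch.clamp(torch.round(w / scale), -128, 127).to(torch.int8)
    return q, scale.to(torch.float32)


def dequant(q: torch.Tensor, scale: torch.Tensor):
    return q.float() * scale


w = torch.randn(8192, 4096)
q, scale = quant_int8_per_channel_sym(w)
print((dequant(q, scale) - w).abs().max())  # quantization error, small relative to each row's amax
```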
@@ -4,50 +4,41 @@ from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER
from lightx2v.utils.quant_utils import IntegerQuantizer, FloatQuantizer


-@MM_WEIGHT_REGISTER('Calib')
+@MM_WEIGHT_REGISTER("Calib")
class MMWeightCalib(MMWeight):
    def __init__(self, weight_name, bias_name):
        super().__init__(weight_name, bias_name)

    def load(self, weight_dict):
-        assert self.config and self.config.get('mm_type', 'Default') != 'Default'
+        assert self.config and self.config.get("mm_type", "Default") != "Default"
        self.weight = weight_dict[self.weight_name]
        self.get_quantizer()
        shape_and_dtype = self.get_quant_shape_and_dtype(self.weight.shape)
        self.realq_weight, self.scales, self.zeros = self.w_quantizer.real_quant_tensor(self.weight)
-        self.realq_weight = self.realq_weight.view(shape_and_dtype['tensor'][0]).contiguous().to(shape_and_dtype['tensor'][1])
-        self.scales = self.scales.view(shape_and_dtype['scales'][0]).contiguous().to(shape_and_dtype['scales'][1])
+        self.realq_weight = self.realq_weight.view(shape_and_dtype["tensor"][0]).contiguous().to(shape_and_dtype["tensor"][1])
+        self.scales = self.scales.view(shape_and_dtype["scales"][0]).contiguous().to(shape_and_dtype["scales"][1])
        if self.zeros is not None:
-            self.zeros = self.zeros.view(shape_and_dtype['zeros'][0]).contiguous().to(shape_and_dtype['zeros'][1])
+            self.zeros = self.zeros.view(shape_and_dtype["zeros"][0]).contiguous().to(shape_and_dtype["zeros"][1])

    def apply(self, input_tensor):
        return super().apply(input_tensor)

    def get_quantizer(self):
-        if self.config['mm_type'] == 'W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm':
-            self.w_setting = {
-                'bit': 'e4m3',
-                'symmetric': True,
-                'granularity': 'channel'
-            }
-            self.a_setting = {
-                'bit': 'e4m3',
-                'symmetric': True,
-                'granularity': 'channel'
-            }
+        if self.config["mm_type"] == "W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm":
+            self.w_setting = {"bit": "e4m3", "symmetric": True, "granularity": "channel"}
+            self.a_setting = {"bit": "e4m3", "symmetric": True, "granularity": "channel"}
            self.w_quantizer = FloatQuantizer(**self.w_setting)
            self.a_quantizer = FloatQuantizer(**self.a_setting)
            self.act_dynamic_quant = True
        else:
-            raise NotImplementedError(f'Unsupported mm_type: {self.config["mm_type"]}')
+            raise NotImplementedError(f"Unsupported mm_type: {self.config['mm_type']}")

    def get_quant_shape_and_dtype(self, shape):
-        if self.config['mm_type'] == 'W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm':
+        if self.config["mm_type"] == "W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm":
            return {
-                'tensor': (shape, torch.float8_e5m2),
-                'scales': ((shape[0], 1), torch.float32),
-                'zeros': None,
+                "tensor": (shape, torch.float8_e5m2),
+                "scales": ((shape[0], 1), torch.float32),
+                "zeros": None,
            }
        else:
-            raise NotImplementedError(f'Unsupported mm_type: {self.config["mm_type"]}')
+            raise NotImplementedError(f"Unsupported mm_type: {self.config['mm_type']}")
@@ -35,7 +35,7 @@ class LNWeightTemplate(metaclass=ABCMeta):
            self.bias = self.bias.cuda()


-@LN_WEIGHT_REGISTER('Default')
+@LN_WEIGHT_REGISTER("Default")
class LNWeight(LNWeightTemplate):
    def __init__(self, weight_name, bias_name, eps=1e-6):
        super().__init__(weight_name, bias_name, eps)
...
@@ -28,7 +28,7 @@ class RMSWeightTemplate(metaclass=ABCMeta):
            self.weight = self.weight.cuda()


-@RMS_WEIGHT_REGISTER('Default')
+@RMS_WEIGHT_REGISTER("Default")
class RMSWeight(RMSWeightTemplate):
    def __init__(self, weight_name, eps=1e-6):
        super().__init__(weight_name, eps)
@@ -39,7 +39,7 @@ class RMSWeight(RMSWeightTemplate):
        return input_tensor


-@RMS_WEIGHT_REGISTER('FP32')
+@RMS_WEIGHT_REGISTER("FP32")
class RMSWeightFP32(RMSWeight):
    def __init__(self, weight_name, eps=1e-6):
        super().__init__(weight_name, eps)
@@ -52,7 +52,7 @@ class RMSWeightFP32(RMSWeight):
        return input_tensor


-@RMS_WEIGHT_REGISTER('sgl-kernel')
+@RMS_WEIGHT_REGISTER("sgl-kernel")
class RMSWeightSgl(RMSWeight):
    def __init__(self, weight_name, eps=1e-6):
        super().__init__(weight_name, eps)
...
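The RMSWeight variants registered above wrap an RMS normalization applied with a loaded weight; as a reference, a plain-PyTorch sketch of that normalization (the repository's fused sgl-kernel path is not reproduced here) is:

```python
import torch


def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Normalize by the root-mean-square over the last dimension, then apply the learned scale.
    rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    return (x / rms) * weight


x = torch.randn(2, 16, 128)
w = torch.ones(128)
print(rms_norm(x, w).shape)  # torch.Size([2, 16, 128])
```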
@@ -14,9 +14,9 @@ from lightx2v.text2v.models.text_encoders.hf.t5.tokenizer import HuggingfaceTokenizer
from .xlm_roberta import XLMRoberta

__all__ = [
-    'XLMRobertaCLIP',
-    'clip_xlm_roberta_vit_h_14',
-    'CLIPModel',
+    "XLMRobertaCLIP",
+    "clip_xlm_roberta_vit_h_14",
+    "CLIPModel",
]
@@ -27,38 +27,27 @@ def pos_interpolate(pos, seq_len):
    src_grid = int(math.sqrt(pos.size(1)))
    tar_grid = int(math.sqrt(seq_len))
    n = pos.size(1) - src_grid * src_grid
-    return torch.cat([
-        pos[:, :n],
-        F.interpolate(
-            pos[:, n:].float().reshape(1, src_grid, src_grid, -1).permute(
-                0, 3, 1, 2),
-            size=(tar_grid, tar_grid),
-            mode='bicubic',
-            align_corners=False).flatten(2).transpose(1, 2)
-    ],
-        dim=1)
+    return torch.cat(
+        [
+            pos[:, :n],
+            F.interpolate(pos[:, n:].float().reshape(1, src_grid, src_grid, -1).permute(0, 3, 1, 2), size=(tar_grid, tar_grid), mode="bicubic", align_corners=False).flatten(2).transpose(1, 2),
+        ],
+        dim=1,
+    )


class QuickGELU(nn.Module):
    def forward(self, x):
        return x * torch.sigmoid(1.702 * x)


class LayerNorm(nn.LayerNorm):
    def forward(self, x):
        return super().forward(x.float()).type_as(x)


class SelfAttention(nn.Module):
-    def __init__(self,
-                 dim,
-                 num_heads,
-                 causal=False,
-                 attn_dropout=0.0,
-                 proj_dropout=0.0):
+    def __init__(self, dim, num_heads, causal=False, attn_dropout=0.0, proj_dropout=0.0):
        assert dim % num_heads == 0
        super().__init__()
        self.dim = dim
@@ -82,7 +71,7 @@ class SelfAttention(nn.Module):
        q, k, v = self.to_qkv(x).view(b, s, 3, n, d).unbind(2)

        # compute attention
-        x = attention(q=q, k=k, v=v, attention_type='torch_sdpa')
+        x = attention(q=q, k=k, v=v, attention_type="torch_sdpa")
        x = x.reshape(b, s, c)

        # output
@@ -92,7 +81,6 @@
class SwiGLU(nn.Module):
    def __init__(self, dim, mid_dim):
        super().__init__()
        self.dim = dim
@@ -110,18 +98,8 @@ class SwiGLU(nn.Module):
class AttentionBlock(nn.Module):
-    def __init__(self,
-                 dim,
-                 mlp_ratio,
-                 num_heads,
-                 post_norm=False,
-                 causal=False,
-                 activation='quick_gelu',
-                 attn_dropout=0.0,
-                 proj_dropout=0.0,
-                 norm_eps=1e-5):
-        assert activation in ['quick_gelu', 'gelu', 'swi_glu']
+    def __init__(self, dim, mlp_ratio, num_heads, post_norm=False, causal=False, activation="quick_gelu", attn_dropout=0.0, proj_dropout=0.0, norm_eps=1e-5):
+        assert activation in ["quick_gelu", "gelu", "swi_glu"]
        super().__init__()
        self.dim = dim
        self.mlp_ratio = mlp_ratio
@@ -132,16 +110,12 @@ class AttentionBlock(nn.Module):
        # layers
        self.norm1 = LayerNorm(dim, eps=norm_eps)
-        self.attn = SelfAttention(dim, num_heads, causal, attn_dropout,
-                                  proj_dropout)
+        self.attn = SelfAttention(dim, num_heads, causal, attn_dropout, proj_dropout)
        self.norm2 = LayerNorm(dim, eps=norm_eps)
-        if activation == 'swi_glu':
+        if activation == "swi_glu":
            self.mlp = SwiGLU(dim, int(dim * mlp_ratio))
        else:
-            self.mlp = nn.Sequential(
-                nn.Linear(dim, int(dim * mlp_ratio)),
-                QuickGELU() if activation == 'quick_gelu' else nn.GELU(),
-                nn.Linear(int(dim * mlp_ratio), dim), nn.Dropout(proj_dropout))
+            self.mlp = nn.Sequential(nn.Linear(dim, int(dim * mlp_ratio)), QuickGELU() if activation == "quick_gelu" else nn.GELU(), nn.Linear(int(dim * mlp_ratio), dim), nn.Dropout(proj_dropout))

    def forward(self, x):
        if self.post_norm:
@@ -154,14 +128,7 @@ class AttentionBlock(nn.Module):
class AttentionPool(nn.Module):
-    def __init__(self,
-                 dim,
-                 mlp_ratio,
-                 num_heads,
-                 activation='gelu',
-                 proj_dropout=0.0,
-                 norm_eps=1e-5):
+    def __init__(self, dim, mlp_ratio, num_heads, activation="gelu", proj_dropout=0.0, norm_eps=1e-5):
        assert dim % num_heads == 0
        super().__init__()
        self.dim = dim
@@ -178,10 +145,7 @@ class AttentionPool(nn.Module):
        self.to_kv = nn.Linear(dim, dim * 2)
        self.proj = nn.Linear(dim, dim)
        self.norm = LayerNorm(dim, eps=norm_eps)
-        self.mlp = nn.Sequential(
-            nn.Linear(dim, int(dim * mlp_ratio)),
-            QuickGELU() if activation == 'quick_gelu' else nn.GELU(),
-            nn.Linear(int(dim * mlp_ratio), dim), nn.Dropout(proj_dropout))
+        self.mlp = nn.Sequential(nn.Linear(dim, int(dim * mlp_ratio)), QuickGELU() if activation == "quick_gelu" else nn.GELU(), nn.Linear(int(dim * mlp_ratio), dim), nn.Dropout(proj_dropout))

    def forward(self, x):
        """
@@ -194,7 +158,7 @@ class AttentionPool(nn.Module):
        k, v = self.to_kv(x).view(b, s, 2, n, d).unbind(2)

        # compute attention
-        x = attention(q=q, k=k, v=v, attention_type='torch_sdpa')
+        x = attention(q=q, k=k, v=v, attention_type="torch_sdpa")
        x = x.reshape(b, 1, c)

        # output
@@ -207,33 +171,32 @@ class AttentionPool(nn.Module):
class VisionTransformer(nn.Module):
-    def __init__(self,
-                 image_size=224,
-                 patch_size=16,
-                 dim=768,
-                 mlp_ratio=4,
-                 out_dim=512,
-                 num_heads=12,
-                 num_layers=12,
-                 pool_type='token',
-                 pre_norm=True,
-                 post_norm=False,
-                 activation='quick_gelu',
-                 attn_dropout=0.0,
-                 proj_dropout=0.0,
-                 embedding_dropout=0.0,
-                 norm_eps=1e-5):
+    def __init__(
+        self,
+        image_size=224,
+        patch_size=16,
+        dim=768,
+        mlp_ratio=4,
+        out_dim=512,
+        num_heads=12,
+        num_layers=12,
+        pool_type="token",
+        pre_norm=True,
+        post_norm=False,
+        activation="quick_gelu",
+        attn_dropout=0.0,
+        proj_dropout=0.0,
+        embedding_dropout=0.0,
+        norm_eps=1e-5,
+    ):
        if image_size % patch_size != 0:
-            print(
-                '[WARNING] image_size is not divisible by patch_size',
-                flush=True)
-        assert pool_type in ('token', 'token_fc', 'attn_pool')
+            print("[WARNING] image_size is not divisible by patch_size", flush=True)
+        assert pool_type in ("token", "token_fc", "attn_pool")
        out_dim = out_dim or dim
        super().__init__()
        self.image_size = image_size
        self.patch_size = patch_size
-        self.num_patches = (image_size // patch_size)**2
+        self.num_patches = (image_size // patch_size) ** 2
        self.dim = dim
        self.mlp_ratio = mlp_ratio
        self.out_dim = out_dim
@@ -245,43 +208,31 @@ class VisionTransformer(nn.Module):
        # embeddings
        gain = 1.0 / math.sqrt(dim)
-        self.patch_embedding = nn.Conv2d(
-            3,
-            dim,
-            kernel_size=patch_size,
-            stride=patch_size,
-            bias=not pre_norm)
-        if pool_type in ('token', 'token_fc'):
+        self.patch_embedding = nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size, bias=not pre_norm)
+        if pool_type in ("token", "token_fc"):
            self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
-        self.pos_embedding = nn.Parameter(gain * torch.randn(
-            1, self.num_patches +
-            (1 if pool_type in ('token', 'token_fc') else 0), dim))
+        self.pos_embedding = nn.Parameter(gain * torch.randn(1, self.num_patches + (1 if pool_type in ("token", "token_fc") else 0), dim))
        self.dropout = nn.Dropout(embedding_dropout)

        # transformer
        self.pre_norm = LayerNorm(dim, eps=norm_eps) if pre_norm else None
-        self.transformer = nn.Sequential(*[
-            AttentionBlock(dim, mlp_ratio, num_heads, post_norm, False,
-                           activation, attn_dropout, proj_dropout, norm_eps)
-            for _ in range(num_layers)
-        ])
+        self.transformer = nn.Sequential(*[AttentionBlock(dim, mlp_ratio, num_heads, post_norm, False, activation, attn_dropout, proj_dropout, norm_eps) for _ in range(num_layers)])
        self.post_norm = LayerNorm(dim, eps=norm_eps)

        # head
-        if pool_type == 'token':
+        if pool_type == "token":
            self.head = nn.Parameter(gain * torch.randn(dim, out_dim))
-        elif pool_type == 'token_fc':
+        elif pool_type == "token_fc":
            self.head = nn.Linear(dim, out_dim)
-        elif pool_type == 'attn_pool':
-            self.head = AttentionPool(dim, mlp_ratio, num_heads, activation,
-                                      proj_dropout, norm_eps)
+        elif pool_type == "attn_pool":
            self.head = AttentionPool(dim, mlp_ratio, num_heads, activation, proj_dropout, norm_eps)

    def forward(self, x, interpolation=False, use_31_block=False):
        b = x.size(0)

        # embeddings
        x = self.patch_embedding(x).flatten(2).permute(0, 2, 1)
-        if self.pool_type in ('token', 'token_fc'):
+        if self.pool_type in ("token", "token_fc"):
            x = torch.cat([self.cls_embedding.expand(b, -1, -1), x], dim=1)
        if interpolation:
            e = pos_interpolate(self.pos_embedding, x.size(1))
@@ -301,16 +252,13 @@ class VisionTransformer(nn.Module):
class XLMRobertaWithHead(XLMRoberta):
    def __init__(self, **kwargs):
-        self.out_dim = kwargs.pop('out_dim')
+        self.out_dim = kwargs.pop("out_dim")
        super().__init__(**kwargs)

        # head
        mid_dim = (self.dim + self.out_dim) // 2
-        self.head = nn.Sequential(
-            nn.Linear(self.dim, mid_dim, bias=False), nn.GELU(),
-            nn.Linear(mid_dim, self.out_dim, bias=False))
+        self.head = nn.Sequential(nn.Linear(self.dim, mid_dim, bias=False), nn.GELU(), nn.Linear(mid_dim, self.out_dim, bias=False))

    def forward(self, ids):
        # xlm-roberta
@@ -326,32 +274,33 @@ class XLMRobertaWithHead(XLMRoberta):
class XLMRobertaCLIP(nn.Module):
-    def __init__(self,
-                 embed_dim=1024,
-                 image_size=224,
-                 patch_size=14,
-                 vision_dim=1280,
-                 vision_mlp_ratio=4,
-                 vision_heads=16,
-                 vision_layers=32,
-                 vision_pool='token',
-                 vision_pre_norm=True,
-                 vision_post_norm=False,
-                 activation='gelu',
-                 vocab_size=250002,
-                 max_text_len=514,
-                 type_size=1,
-                 pad_id=1,
-                 text_dim=1024,
-                 text_heads=16,
-                 text_layers=24,
-                 text_post_norm=True,
-                 text_dropout=0.1,
-                 attn_dropout=0.0,
-                 proj_dropout=0.0,
-                 embedding_dropout=0.0,
-                 norm_eps=1e-5):
+    def __init__(
+        self,
+        embed_dim=1024,
+        image_size=224,
+        patch_size=14,
+        vision_dim=1280,
+        vision_mlp_ratio=4,
+        vision_heads=16,
+        vision_layers=32,
+        vision_pool="token",
+        vision_pre_norm=True,
+        vision_post_norm=False,
+        activation="gelu",
+        vocab_size=250002,
+        max_text_len=514,
+        type_size=1,
+        pad_id=1,
+        text_dim=1024,
+        text_heads=16,
+        text_layers=24,
+        text_post_norm=True,
+        text_dropout=0.1,
+        attn_dropout=0.0,
+        proj_dropout=0.0,
+        embedding_dropout=0.0,
+        norm_eps=1e-5,
+    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.image_size = image_size
@@ -389,7 +338,8 @@ class XLMRobertaCLIP(nn.Module):
            attn_dropout=attn_dropout,
            proj_dropout=proj_dropout,
            embedding_dropout=embedding_dropout,
-            norm_eps=norm_eps)
+            norm_eps=norm_eps,
+        )
        self.textual = XLMRobertaWithHead(
            vocab_size=vocab_size,
            max_seq_len=max_text_len,
@@ -400,7 +350,8 @@ class XLMRobertaCLIP(nn.Module):
            num_heads=text_heads,
            num_layers=text_layers,
            post_norm=text_post_norm,
-            dropout=text_dropout)
+            dropout=text_dropout,
+        )
        self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([]))

    def forward(self, imgs, txt_ids):
@@ -416,30 +367,14 @@ class XLMRobertaCLIP(nn.Module):
        return xi, xt

    def param_groups(self):
-        groups = [{
-            'params': [
-                p for n, p in self.named_parameters()
-                if 'norm' in n or n.endswith('bias')
-            ],
-            'weight_decay': 0.0
-        }, {
-            'params': [
-                p for n, p in self.named_parameters()
-                if not ('norm' in n or n.endswith('bias'))
-            ]
-        }]
+        groups = [
+            {"params": [p for n, p in self.named_parameters() if "norm" in n or n.endswith("bias")], "weight_decay": 0.0},
+            {"params": [p for n, p in self.named_parameters() if not ("norm" in n or n.endswith("bias"))]},
+        ]
        return groups
-def _clip(pretrained=False,
-          pretrained_name=None,
-          model_cls=XLMRobertaCLIP,
-          return_transforms=False,
-          return_tokenizer=False,
-          tokenizer_padding='eos',
-          dtype=torch.float32,
-          device='cpu',
-          **kwargs):
+def _clip(pretrained=False, pretrained_name=None, model_cls=XLMRobertaCLIP, return_transforms=False, return_tokenizer=False, tokenizer_padding="eos", dtype=torch.float32, device="cpu", **kwargs):
    # init a model on device
    with torch.device(device):
        model = model_cls(**kwargs)
@@ -451,27 +386,19 @@ def _clip(pretrained=False,
    # init transforms
    if return_transforms:
        # mean and std
-        if 'siglip' in pretrained_name.lower():
+        if "siglip" in pretrained_name.lower():
            mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
        else:
            mean = [0.48145466, 0.4578275, 0.40821073]
            std = [0.26862954, 0.26130258, 0.27577711]

        # transforms
-        transforms = T.Compose([
-            T.Resize((model.image_size, model.image_size),
-                     interpolation=T.InterpolationMode.BICUBIC),
-            T.ToTensor(),
-            T.Normalize(mean=mean, std=std)
-        ])
+        transforms = T.Compose([T.Resize((model.image_size, model.image_size), interpolation=T.InterpolationMode.BICUBIC), T.ToTensor(), T.Normalize(mean=mean, std=std)])
        output += (transforms,)
    return output[0] if len(output) == 1 else output


-def clip_xlm_roberta_vit_h_14(
-        pretrained=False,
-        pretrained_name='open-clip-xlm-roberta-large-vit-huge-14',
-        **kwargs):
+def clip_xlm_roberta_vit_h_14(pretrained=False, pretrained_name="open-clip-xlm-roberta-large-vit-huge-14", **kwargs):
    cfg = dict(
        embed_dim=1024,
        image_size=224,
@@ -480,8 +407,8 @@ def clip_xlm_roberta_vit_h_14(
        vision_mlp_ratio=4,
        vision_heads=16,
        vision_layers=32,
-        vision_pool='token',
-        activation='gelu',
+        vision_pool="token",
+        activation="gelu",
        vocab_size=250002,
        max_text_len=514,
        type_size=1,
@@ -493,13 +420,13 @@ def clip_xlm_roberta_vit_h_14(
        text_dropout=0.1,
        attn_dropout=0.0,
        proj_dropout=0.0,
-        embedding_dropout=0.0)
+        embedding_dropout=0.0,
+    )
    cfg.update(**kwargs)
    return _clip(pretrained, pretrained_name, XLMRobertaCLIP, **cfg)
class CLIPModel: class CLIPModel:
def __init__(self, dtype, device, checkpoint_path, tokenizer_path): def __init__(self, dtype, device, checkpoint_path, tokenizer_path):
self.dtype = dtype self.dtype = dtype
self.device = device self.device = device
...@@ -507,36 +434,21 @@ class CLIPModel: ...@@ -507,36 +434,21 @@ class CLIPModel:
self.tokenizer_path = tokenizer_path self.tokenizer_path = tokenizer_path
# init model # init model
self.model, self.transforms = clip_xlm_roberta_vit_h_14( self.model, self.transforms = clip_xlm_roberta_vit_h_14(pretrained=False, return_transforms=True, return_tokenizer=False, dtype=dtype, device=device)
pretrained=False,
return_transforms=True,
return_tokenizer=False,
dtype=dtype,
device=device)
self.model = self.model.eval().requires_grad_(False) self.model = self.model.eval().requires_grad_(False)
logging.info(f'loading {checkpoint_path}') logging.info(f"loading {checkpoint_path}")
self.model.load_state_dict( self.model.load_state_dict(torch.load(checkpoint_path, map_location="cpu", weights_only=True))
torch.load(checkpoint_path, map_location='cpu', weights_only=True))
# init tokenizer # init tokenizer
self.tokenizer = HuggingfaceTokenizer( self.tokenizer = HuggingfaceTokenizer(name=tokenizer_path, seq_len=self.model.max_text_len - 2, clean="whitespace")
name=tokenizer_path,
seq_len=self.model.max_text_len - 2,
clean='whitespace')
def visual(self, videos): def visual(self, videos):
# preprocess # preprocess
size = (self.model.image_size,) * 2 size = (self.model.image_size,) * 2
videos = torch.cat([ videos = torch.cat([F.interpolate(u.transpose(0, 1), size=size, mode="bicubic", align_corners=False) for u in videos])
F.interpolate(
u.transpose(0, 1),
size=size,
mode='bicubic',
align_corners=False) for u in videos
])
videos = self.transforms.transforms[-1](videos.mul_(0.5).add_(0.5)) videos = self.transforms.transforms[-1](videos.mul_(0.5).add_(0.5))
# forward # forward
with torch.amp.autocast('cuda', dtype=self.dtype): with torch.amp.autocast("cuda", dtype=self.dtype):
out = self.model.visual(videos, use_31_block=True) out = self.model.visual(videos, use_31_block=True)
return out return out
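A hedged usage sketch of visual() above (not from this commit): it takes a list of videos shaped [C, T, H, W] with pixel values in [-1, 1], resizes each frame to the CLIP input resolution, re-applies the CLIP normalization, and returns features from the visual encoder under autocast. Paths and shapes below are placeholder assumptions:

import torch

clip = CLIPModel(dtype=torch.bfloat16, device="cuda",
                 checkpoint_path="<open_clip_checkpoint>.pth",   # placeholder path
                 tokenizer_path="<xlm_roberta_tokenizer_dir>")    # placeholder path
video = torch.rand(3, 16, 480, 832, device="cuda") * 2 - 1        # fake pixels in [-1, 1]
features = clip.visual([video])                                   # frame-level CLIP features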
...@@ -4,11 +4,10 @@ import torch ...@@ -4,11 +4,10 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
__all__ = ['XLMRoberta', 'xlm_roberta_large'] __all__ = ["XLMRoberta", "xlm_roberta_large"]
class SelfAttention(nn.Module): class SelfAttention(nn.Module):
def __init__(self, dim, num_heads, dropout=0.1, eps=1e-5): def __init__(self, dim, num_heads, dropout=0.1, eps=1e-5):
assert dim % num_heads == 0 assert dim % num_heads == 0
super().__init__() super().__init__()
...@@ -47,7 +46,6 @@ class SelfAttention(nn.Module): ...@@ -47,7 +46,6 @@ class SelfAttention(nn.Module):
class AttentionBlock(nn.Module): class AttentionBlock(nn.Module):
def __init__(self, dim, num_heads, post_norm, dropout=0.1, eps=1e-5): def __init__(self, dim, num_heads, post_norm, dropout=0.1, eps=1e-5):
super().__init__() super().__init__()
self.dim = dim self.dim = dim
...@@ -58,9 +56,7 @@ class AttentionBlock(nn.Module): ...@@ -58,9 +56,7 @@ class AttentionBlock(nn.Module):
# layers # layers
self.attn = SelfAttention(dim, num_heads, dropout, eps) self.attn = SelfAttention(dim, num_heads, dropout, eps)
self.norm1 = nn.LayerNorm(dim, eps=eps) self.norm1 = nn.LayerNorm(dim, eps=eps)
self.ffn = nn.Sequential( self.ffn = nn.Sequential(nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim), nn.Dropout(dropout))
nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim),
nn.Dropout(dropout))
self.norm2 = nn.LayerNorm(dim, eps=eps) self.norm2 = nn.LayerNorm(dim, eps=eps)
def forward(self, x, mask): def forward(self, x, mask):
...@@ -78,17 +74,7 @@ class XLMRoberta(nn.Module): ...@@ -78,17 +74,7 @@ class XLMRoberta(nn.Module):
XLMRobertaModel with no pooler and no LM head. XLMRobertaModel with no pooler and no LM head.
""" """
def __init__(self, def __init__(self, vocab_size=250002, max_seq_len=514, type_size=1, pad_id=1, dim=1024, num_heads=16, num_layers=24, post_norm=True, dropout=0.1, eps=1e-5):
vocab_size=250002,
max_seq_len=514,
type_size=1,
pad_id=1,
dim=1024,
num_heads=16,
num_layers=24,
post_norm=True,
dropout=0.1,
eps=1e-5):
super().__init__() super().__init__()
self.vocab_size = vocab_size self.vocab_size = vocab_size
self.max_seq_len = max_seq_len self.max_seq_len = max_seq_len
...@@ -107,10 +93,7 @@ class XLMRoberta(nn.Module): ...@@ -107,10 +93,7 @@ class XLMRoberta(nn.Module):
self.dropout = nn.Dropout(dropout) self.dropout = nn.Dropout(dropout)
# blocks # blocks
self.blocks = nn.ModuleList([ self.blocks = nn.ModuleList([AttentionBlock(dim, num_heads, post_norm, dropout, eps) for _ in range(num_layers)])
AttentionBlock(dim, num_heads, post_norm, dropout, eps)
for _ in range(num_layers)
])
# norm layer # norm layer
self.norm = nn.LayerNorm(dim, eps=eps) self.norm = nn.LayerNorm(dim, eps=eps)
...@@ -123,17 +106,13 @@ class XLMRoberta(nn.Module): ...@@ -123,17 +106,13 @@ class XLMRoberta(nn.Module):
mask = ids.ne(self.pad_id).long() mask = ids.ne(self.pad_id).long()
# embeddings # embeddings
x = self.token_embedding(ids) + \ x = self.token_embedding(ids) + self.type_embedding(torch.zeros_like(ids)) + self.pos_embedding(self.pad_id + torch.cumsum(mask, dim=1) * mask)
self.type_embedding(torch.zeros_like(ids)) + \
self.pos_embedding(self.pad_id + torch.cumsum(mask, dim=1) * mask)
if self.post_norm: if self.post_norm:
x = self.norm(x) x = self.norm(x)
x = self.dropout(x) x = self.dropout(x)
# blocks # blocks
mask = torch.where( mask = torch.where(mask.view(b, 1, 1, s).gt(0), 0.0, torch.finfo(x.dtype).min)
mask.view(b, 1, 1, s).gt(0), 0.0,
torch.finfo(x.dtype).min)
for block in self.blocks: for block in self.blocks:
x = block(x, mask) x = block(x, mask)
...@@ -143,25 +122,12 @@ class XLMRoberta(nn.Module): ...@@ -143,25 +122,12 @@ class XLMRoberta(nn.Module):
return x return x
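One note on the embedding line above: RoBERTa-style position ids are rebuilt from the padding mask rather than taken from a fixed arange. A small worked example, assuming the default pad_id = 1:

import torch

pad_id = 1
ids = torch.tensor([[0, 54, 88, 2, 1, 1]])         # <s> ... </s> followed by two pads
mask = ids.ne(pad_id).long()                        # [[1, 1, 1, 1, 0, 0]]
pos = pad_id + torch.cumsum(mask, dim=1) * mask     # [[2, 3, 4, 5, 1, 1]]
# Real tokens get positions pad_id + 1, pad_id + 2, ...; padded slots collapse to pad_id.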
def xlm_roberta_large(pretrained=False, def xlm_roberta_large(pretrained=False, return_tokenizer=False, device="cpu", **kwargs):
return_tokenizer=False,
device='cpu',
**kwargs):
""" """
XLMRobertaLarge adapted from Huggingface. XLMRobertaLarge adapted from Huggingface.
""" """
# params # params
cfg = dict( cfg = dict(vocab_size=250002, max_seq_len=514, type_size=1, pad_id=1, dim=1024, num_heads=16, num_layers=24, post_norm=True, dropout=0.1, eps=1e-5)
vocab_size=250002,
max_seq_len=514,
type_size=1,
pad_id=1,
dim=1024,
num_heads=16,
num_layers=24,
post_norm=True,
dropout=0.1,
eps=1e-5)
cfg.update(**kwargs) cfg.update(**kwargs)
# init a model on device # init a model on device
......
...@@ -6,14 +6,15 @@ from typing import Dict ...@@ -6,14 +6,15 @@ from typing import Dict
import math import math
from ..transformer_infer import HunyuanTransformerInfer from ..transformer_infer import HunyuanTransformerInfer
def taylor_cache_init(cache_dic: Dict, current: Dict): def taylor_cache_init(cache_dic: Dict, current: Dict):
""" """
Initialize Taylor cache, expanding storage areas for Taylor series derivatives Initialize Taylor cache, expanding storage areas for Taylor series derivatives
:param cache_dic: Cache dictionary :param cache_dic: Cache dictionary
:param current: Information of the current step :param current: Information of the current step
""" """
if current['step'] == 0: if current["step"] == 0:
cache_dic['cache'][-1][current['stream']][current['layer']][current['module']] = {} cache_dic["cache"][-1][current["stream"]][current["layer"]][current["module"]] = {}
def derivative_approximation(cache_dic: Dict, current: Dict, feature: torch.Tensor): def derivative_approximation(cache_dic: Dict, current: Dict, feature: torch.Tensor):
...@@ -22,33 +23,34 @@ def derivative_approximation(cache_dic: Dict, current: Dict, feature: torch.Tens ...@@ -22,33 +23,34 @@ def derivative_approximation(cache_dic: Dict, current: Dict, feature: torch.Tens
:param cache_dic: Cache dictionary :param cache_dic: Cache dictionary
:param current: Information of the current step :param current: Information of the current step
""" """
difference_distance = current['activated_steps'][-1] - current['activated_steps'][-2] difference_distance = current["activated_steps"][-1] - current["activated_steps"][-2]
#difference_distance = current['activated_times'][-1] - current['activated_times'][-2] # difference_distance = current['activated_times'][-1] - current['activated_times'][-2]
updated_taylor_factors = {} updated_taylor_factors = {}
updated_taylor_factors[0] = feature updated_taylor_factors[0] = feature
for i in range(cache_dic['max_order']): for i in range(cache_dic["max_order"]):
if (cache_dic['cache'][-1][current['stream']][current['layer']][current['module']].get(i, None) is not None) and (current['step'] > cache_dic['first_enhance'] - 2): if (cache_dic["cache"][-1][current["stream"]][current["layer"]][current["module"]].get(i, None) is not None) and (current["step"] > cache_dic["first_enhance"] - 2):
updated_taylor_factors[i + 1] = (updated_taylor_factors[i] - cache_dic['cache'][-1][current['stream']][current['layer']][current['module']][i]) / difference_distance updated_taylor_factors[i + 1] = (updated_taylor_factors[i] - cache_dic["cache"][-1][current["stream"]][current["layer"]][current["module"]][i]) / difference_distance
else: else:
break break
cache_dic['cache'][-1][current['stream']][current['layer']][current['module']] = updated_taylor_factors cache_dic["cache"][-1][current["stream"]][current["layer"]][current["module"]] = updated_taylor_factors
def taylor_formula(cache_dic: Dict, current: Dict) -> torch.Tensor:
def taylor_formula(cache_dic: Dict, current: Dict) -> torch.Tensor:
""" """
    Compute the Taylor-expansion approximation of a cached feature     Compute the Taylor-expansion approximation of a cached feature
:param cache_dic: Cache dictionary :param cache_dic: Cache dictionary
:param current: Information of the current step :param current: Information of the current step
""" """
x = current['step'] - current['activated_steps'][-1] x = current["step"] - current["activated_steps"][-1]
#x = current['t'] - current['activated_times'][-1] # x = current['t'] - current['activated_times'][-1]
output = 0 output = 0
for i in range(len(cache_dic['cache'][-1][current['stream']][current['layer']][current['module']])): for i in range(len(cache_dic["cache"][-1][current["stream"]][current["layer"]][current["module"]])):
output += (1 / math.factorial(i)) * cache_dic['cache'][-1][current['stream']][current['layer']][current['module']][i] * (x ** i) output += (1 / math.factorial(i)) * cache_dic["cache"][-1][current["stream"]][current["layer"]][current["module"]][i] * (x**i)
return output return output
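To make the two helpers above concrete: derivative_approximation stores the freshly computed feature together with finite-difference estimates of its derivatives at an "activated" step, and taylor_formula extrapolates the feature at later steps from those cached terms. A self-contained first-order illustration (numbers are invented, not taken from the code):

import math
import torch

F0 = torch.tensor([1.0, 2.0])      # cached 0th-order term: the feature itself
F1 = torch.tensor([0.1, -0.2])     # cached 1st-order term: finite difference per step
x = 3                              # steps since the last fully computed ("activated") step

# Same sum as taylor_formula: sum_i F_i * x**i / i!
approx = sum((1 / math.factorial(i)) * F * (x**i) for i, F in enumerate([F0, F1]))
# approx ≈ F0 + 3 * F1 ≈ tensor([1.3, 1.4])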
...@@ -56,27 +58,25 @@ class HunyuanTransformerInferFeatureCaching(HunyuanTransformerInfer): ...@@ -56,27 +58,25 @@ class HunyuanTransformerInferFeatureCaching(HunyuanTransformerInfer):
def __init__(self, config): def __init__(self, config):
super().__init__(config) super().__init__(config)
def infer(self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis): def infer(self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis):
txt_seq_len = txt.shape[0] txt_seq_len = txt.shape[0]
img_seq_len = img.shape[0] img_seq_len = img.shape[0]
self.scheduler.current['stream'] = 'double_stream' self.scheduler.current["stream"] = "double_stream"
for i in range(self.double_blocks_num): for i in range(self.double_blocks_num):
self.scheduler.current['layer'] = i self.scheduler.current["layer"] = i
img, txt = self.infer_double_block(weights.double_blocks_weights[i], img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis) img, txt = self.infer_double_block(weights.double_blocks_weights[i], img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis)
x = torch.cat((img, txt), 0) x = torch.cat((img, txt), 0)
self.scheduler.current['stream'] = 'single_stream' self.scheduler.current["stream"] = "single_stream"
for i in range(self.single_blocks_num): for i in range(self.single_blocks_num):
self.scheduler.current['layer'] = i self.scheduler.current["layer"] = i
x = self.infer_single_block(weights.single_blocks_weights[i], x, vec, txt_seq_len, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis) x = self.infer_single_block(weights.single_blocks_weights[i], x, vec, txt_seq_len, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis)
img = x[:img_seq_len, ...] img = x[:img_seq_len, ...]
return img, vec return img, vec
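For orientation (not part of the diff): the infer() loop above runs image and text tokens through the double-stream blocks separately, concatenates them for the single-stream blocks, and slices the image tokens back out at the end; scheduler.current["stream"] and ["layer"] are used to index the per-module cache slots. A shape sketch with invented sizes:

import torch

hidden_size = 3072                      # assumption, for illustration only
img = torch.zeros(1024, hidden_size)    # [img_seq_len, hidden_size]
txt = torch.zeros(256, hidden_size)     # [txt_seq_len, hidden_size]

x = torch.cat((img, txt), 0)            # [1280, hidden_size] fed to single-stream blocks
img = x[:1024, ...]                     # image tokens recovered after the single-stream pass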
def infer_double_block(self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis): def infer_double_block(self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis):
vec_silu = torch.nn.functional.silu(vec) vec_silu = torch.nn.functional.silu(vec)
...@@ -89,7 +89,7 @@ class HunyuanTransformerInferFeatureCaching(HunyuanTransformerInfer): ...@@ -89,7 +89,7 @@ class HunyuanTransformerInferFeatureCaching(HunyuanTransformerInfer):
img_mod2_scale, img_mod2_scale,
img_mod2_gate, img_mod2_gate,
) = img_mod_out.chunk(6, dim=-1) ) = img_mod_out.chunk(6, dim=-1)
txt_mod_out = weights.txt_mod.apply(vec_silu) txt_mod_out = weights.txt_mod.apply(vec_silu)
( (
txt_mod1_shift, txt_mod1_shift,
...@@ -99,15 +99,15 @@ class HunyuanTransformerInferFeatureCaching(HunyuanTransformerInfer): ...@@ -99,15 +99,15 @@ class HunyuanTransformerInferFeatureCaching(HunyuanTransformerInfer):
txt_mod2_scale, txt_mod2_scale,
txt_mod2_gate, txt_mod2_gate,
) = txt_mod_out.chunk(6, dim=-1) ) = txt_mod_out.chunk(6, dim=-1)
if self.scheduler.current['type'] == 'full': if self.scheduler.current["type"] == "full":
img_q, img_k, img_v = self.infer_double_block_img_pre_atten(weights, img, img_mod1_scale, img_mod1_shift, freqs_cis) img_q, img_k, img_v = self.infer_double_block_img_pre_atten(weights, img, img_mod1_scale, img_mod1_shift, freqs_cis)
txt_q, txt_k, txt_v = self.infer_double_block_txt_pre_atten(weights, txt, txt_mod1_scale, txt_mod1_shift) txt_q, txt_k, txt_v = self.infer_double_block_txt_pre_atten(weights, txt, txt_mod1_scale, txt_mod1_shift)
q = torch.cat((img_q, txt_q), dim=0) q = torch.cat((img_q, txt_q), dim=0)
k = torch.cat((img_k, txt_k), dim=0) k = torch.cat((img_k, txt_k), dim=0)
v = torch.cat((img_v, txt_v), dim=0) v = torch.cat((img_v, txt_v), dim=0)
if not self.parallel_attention: if not self.parallel_attention:
attn = attention( attn = attention(
attention_type=self.attention_type, attention_type=self.attention_type,
...@@ -127,39 +127,39 @@ class HunyuanTransformerInferFeatureCaching(HunyuanTransformerInfer): ...@@ -127,39 +127,39 @@ class HunyuanTransformerInferFeatureCaching(HunyuanTransformerInfer):
k=k, k=k,
v=v, v=v,
img_qkv_len=img_q.shape[0], img_qkv_len=img_q.shape[0],
cu_seqlens_qkv=cu_seqlens_qkv cu_seqlens_qkv=cu_seqlens_qkv,
# cu_seqlens_qkv=cu_seqlens_qkv, # cu_seqlens_qkv=cu_seqlens_qkv,
# max_seqlen_qkv=max_seqlen_qkv, # max_seqlen_qkv=max_seqlen_qkv,
) )
img_attn, txt_attn = attn[: img.shape[0]], attn[img.shape[0] :] img_attn, txt_attn = attn[: img.shape[0]], attn[img.shape[0] :]
img = self.infer_double_block_img_post_atten(weights, img, img_attn, img_mod1_gate, img_mod2_shift, img_mod2_scale, img_mod2_gate) img = self.infer_double_block_img_post_atten(weights, img, img_attn, img_mod1_gate, img_mod2_shift, img_mod2_scale, img_mod2_gate)
txt = self.infer_double_block_txt_post_atten(weights, txt, txt_attn, txt_mod1_gate, txt_mod2_shift, txt_mod2_scale, txt_mod2_gate) txt = self.infer_double_block_txt_post_atten(weights, txt, txt_attn, txt_mod1_gate, txt_mod2_shift, txt_mod2_scale, txt_mod2_gate)
return img, txt return img, txt
elif self.scheduler.current['type'] == 'taylor_cache': elif self.scheduler.current["type"] == "taylor_cache":
self.scheduler.current['module'] = 'img_attn' self.scheduler.current["module"] = "img_attn"
out = taylor_formula(self.scheduler.cache_dic, self.scheduler.current) out = taylor_formula(self.scheduler.cache_dic, self.scheduler.current)
out = out * img_mod1_gate out = out * img_mod1_gate
img = img + out img = img + out
self.scheduler.current['module'] = 'img_mlp' self.scheduler.current["module"] = "img_mlp"
out = taylor_formula(self.scheduler.cache_dic, self.scheduler.current) out = taylor_formula(self.scheduler.cache_dic, self.scheduler.current)
out = out * img_mod2_gate out = out * img_mod2_gate
img = img + out img = img + out
self.scheduler.current['module'] = 'txt_attn' self.scheduler.current["module"] = "txt_attn"
out = taylor_formula(self.scheduler.cache_dic, self.scheduler.current) out = taylor_formula(self.scheduler.cache_dic, self.scheduler.current)
out = out * txt_mod1_gate out = out * txt_mod1_gate
txt = txt + out txt = txt + out
self.scheduler.current['module'] = 'txt_mlp' self.scheduler.current["module"] = "txt_mlp"
out = out * txt_mod2_gate out = out * txt_mod2_gate
txt = txt + out txt = txt + out
...@@ -167,31 +167,31 @@ class HunyuanTransformerInferFeatureCaching(HunyuanTransformerInfer): ...@@ -167,31 +167,31 @@ class HunyuanTransformerInferFeatureCaching(HunyuanTransformerInfer):
return img, txt return img, txt
def infer_double_block_img_post_atten(self, weights, img, img_attn, img_mod1_gate, img_mod2_shift, img_mod2_scale, img_mod2_gate): def infer_double_block_img_post_atten(self, weights, img, img_attn, img_mod1_gate, img_mod2_shift, img_mod2_scale, img_mod2_gate):
self.scheduler.current['module'] = 'img_attn' self.scheduler.current["module"] = "img_attn"
taylor_cache_init(self.scheduler.cache_dic, self.scheduler.current) taylor_cache_init(self.scheduler.cache_dic, self.scheduler.current)
out = weights.img_attn_proj.apply(img_attn) out = weights.img_attn_proj.apply(img_attn)
derivative_approximation(self.scheduler.cache_dic, self.scheduler.current, out) derivative_approximation(self.scheduler.cache_dic, self.scheduler.current, out)
out = out * img_mod1_gate out = out * img_mod1_gate
img = img + out img = img + out
self.scheduler.current['module'] = 'img_mlp' self.scheduler.current["module"] = "img_mlp"
taylor_cache_init(self.scheduler.cache_dic, self.scheduler.current) taylor_cache_init(self.scheduler.cache_dic, self.scheduler.current)
out = torch.nn.functional.layer_norm(img, (img.shape[1],), None, None, 1e-6) out = torch.nn.functional.layer_norm(img, (img.shape[1],), None, None, 1e-6)
out = out * (1 + img_mod2_scale) + img_mod2_shift out = out * (1 + img_mod2_scale) + img_mod2_shift
out = weights.img_mlp_fc1.apply(out) out = weights.img_mlp_fc1.apply(out)
out = torch.nn.functional.gelu(out, approximate='tanh') out = torch.nn.functional.gelu(out, approximate="tanh")
out = weights.img_mlp_fc2.apply(out) out = weights.img_mlp_fc2.apply(out)
derivative_approximation(self.scheduler.cache_dic, self.scheduler.current, out) derivative_approximation(self.scheduler.cache_dic, self.scheduler.current, out)
out = out * img_mod2_gate out = out * img_mod2_gate
img = img + out img = img + out
return img return img
def infer_double_block_txt_post_atten(self, weights, txt, txt_attn, txt_mod1_gate, txt_mod2_shift, txt_mod2_scale, txt_mod2_gate): def infer_double_block_txt_post_atten(self, weights, txt, txt_attn, txt_mod1_gate, txt_mod2_shift, txt_mod2_scale, txt_mod2_gate):
self.scheduler.current['module'] = 'txt_attn' self.scheduler.current["module"] = "txt_attn"
taylor_cache_init(self.scheduler.cache_dic, self.scheduler.current) taylor_cache_init(self.scheduler.cache_dic, self.scheduler.current)
out = weights.txt_attn_proj.apply(txt_attn) out = weights.txt_attn_proj.apply(txt_attn)
...@@ -200,36 +200,36 @@ class HunyuanTransformerInferFeatureCaching(HunyuanTransformerInfer): ...@@ -200,36 +200,36 @@ class HunyuanTransformerInferFeatureCaching(HunyuanTransformerInfer):
out = out * txt_mod1_gate out = out * txt_mod1_gate
txt = txt + out txt = txt + out
self.scheduler.current['module'] = 'txt_mlp' self.scheduler.current["module"] = "txt_mlp"
taylor_cache_init(self.scheduler.cache_dic, self.scheduler.current) taylor_cache_init(self.scheduler.cache_dic, self.scheduler.current)
out = torch.nn.functional.layer_norm(txt, (txt.shape[1],), None, None, 1e-6) out = torch.nn.functional.layer_norm(txt, (txt.shape[1],), None, None, 1e-6)
out = out * (1 + txt_mod2_scale) + txt_mod2_shift out = out * (1 + txt_mod2_scale) + txt_mod2_shift
out = weights.txt_mlp_fc1.apply(out) out = weights.txt_mlp_fc1.apply(out)
out = torch.nn.functional.gelu(out, approximate='tanh') out = torch.nn.functional.gelu(out, approximate="tanh")
out = weights.txt_mlp_fc2.apply(out) out = weights.txt_mlp_fc2.apply(out)
derivative_approximation(self.scheduler.cache_dic, self.scheduler.current, out) derivative_approximation(self.scheduler.cache_dic, self.scheduler.current, out)
out = out * txt_mod2_gate out = out * txt_mod2_gate
txt = txt + out txt = txt + out
return txt return txt
def infer_single_block(self, weights, x, vec, txt_seq_len, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis): def infer_single_block(self, weights, x, vec, txt_seq_len, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis):
out = torch.nn.functional.silu(vec) out = torch.nn.functional.silu(vec)
out = weights.modulation.apply(out) out = weights.modulation.apply(out)
mod_shift, mod_scale, mod_gate = out.chunk(3, dim=-1) mod_shift, mod_scale, mod_gate = out.chunk(3, dim=-1)
if self.scheduler.current['type'] == 'full': if self.scheduler.current["type"] == "full":
out = torch.nn.functional.layer_norm(x, (x.shape[1],), None, None, 1e-6) out = torch.nn.functional.layer_norm(x, (x.shape[1],), None, None, 1e-6)
x_mod = out * (1 + mod_scale) + mod_shift x_mod = out * (1 + mod_scale) + mod_shift
x_mod = weights.linear1.apply(x_mod) x_mod = weights.linear1.apply(x_mod)
qkv, mlp = torch.split(x_mod, [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1) qkv, mlp = torch.split(x_mod, [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
self.scheduler.current['module'] = 'attn' self.scheduler.current["module"] = "attn"
taylor_cache_init(self.scheduler.cache_dic, self.scheduler.current) taylor_cache_init(self.scheduler.cache_dic, self.scheduler.current)
q, k, v = rearrange(qkv, "L (K H D) -> K L H D", K=3, H=self.heads_num) q, k, v = rearrange(qkv, "L (K H D) -> K L H D", K=3, H=self.heads_num)
q = weights.q_norm.apply(q) q = weights.q_norm.apply(q)
k = weights.k_norm.apply(k) k = weights.k_norm.apply(k)
...@@ -258,16 +258,16 @@ class HunyuanTransformerInferFeatureCaching(HunyuanTransformerInfer): ...@@ -258,16 +258,16 @@ class HunyuanTransformerInferFeatureCaching(HunyuanTransformerInfer):
k=k, k=k,
v=v, v=v,
img_qkv_len=img_q.shape[0], img_qkv_len=img_q.shape[0],
cu_seqlens_qkv=cu_seqlens_qkv cu_seqlens_qkv=cu_seqlens_qkv,
# cu_seqlens_qkv=cu_seqlens_qkv, # cu_seqlens_qkv=cu_seqlens_qkv,
# max_seqlen_qkv=max_seqlen_qkv, # max_seqlen_qkv=max_seqlen_qkv,
) )
derivative_approximation(self.scheduler.cache_dic, self.scheduler.current, attn) derivative_approximation(self.scheduler.cache_dic, self.scheduler.current, attn)
self.scheduler.current['module'] = 'total' self.scheduler.current["module"] = "total"
taylor_cache_init(self.scheduler.cache_dic, self.scheduler.current) taylor_cache_init(self.scheduler.cache_dic, self.scheduler.current)
out = torch.nn.functional.gelu(mlp, approximate='tanh') out = torch.nn.functional.gelu(mlp, approximate="tanh")
out = torch.cat((attn, out), 1) out = torch.cat((attn, out), 1)
out = weights.linear2.apply(out) out = weights.linear2.apply(out)
derivative_approximation(self.scheduler.cache_dic, self.scheduler.current, out) derivative_approximation(self.scheduler.cache_dic, self.scheduler.current, out)
...@@ -276,8 +276,8 @@ class HunyuanTransformerInferFeatureCaching(HunyuanTransformerInfer): ...@@ -276,8 +276,8 @@ class HunyuanTransformerInferFeatureCaching(HunyuanTransformerInfer):
x = x + out x = x + out
return x return x
elif self.scheduler.current['type'] == 'taylor_cache': elif self.scheduler.current["type"] == "taylor_cache":
self.scheduler.current['module'] = 'total' self.scheduler.current["module"] = "total"
out = taylor_formula(self.scheduler.cache_dic, self.scheduler.current) out = taylor_formula(self.scheduler.cache_dic, self.scheduler.current)
out = out * mod_gate out = out * mod_gate
x = x + out x = x + out
......
import torch import torch
class HunyuanPostInfer(): class HunyuanPostInfer:
def __init__(self): def __init__(self):
pass pass
...@@ -19,7 +19,7 @@ class HunyuanPostInfer(): ...@@ -19,7 +19,7 @@ class HunyuanPostInfer():
oh // patch_size[1], oh // patch_size[1],
ow // patch_size[2], ow // patch_size[2],
) )
c = 16 c = 16
pt, ph, pw = patch_size pt, ph, pw = patch_size
......
...@@ -4,7 +4,7 @@ from einops import rearrange ...@@ -4,7 +4,7 @@ from einops import rearrange
from lightx2v.attentions import attention from lightx2v.attentions import attention
class HunyuanPreInfer(): class HunyuanPreInfer:
def __init__(self): def __init__(self):
self.heads_num = 24 self.heads_num = 24
...@@ -30,11 +30,10 @@ class HunyuanPreInfer(): ...@@ -30,11 +30,10 @@ class HunyuanPreInfer():
s2 = (i + 1) * max_len s2 = (i + 1) * max_len
cu_seqlens_qkv[2 * i + 1] = s1 cu_seqlens_qkv[2 * i + 1] = s1
cu_seqlens_qkv[2 * i + 2] = s2 cu_seqlens_qkv[2 * i + 2] = s2
max_seqlen_qkv = img_seq_len + txt_seq_len max_seqlen_qkv = img_seq_len + txt_seq_len
return img_out[0], infer_text_out, vec, cu_seqlens_qkv, max_seqlen_qkv, (freqs_cos, freqs_sin) return img_out[0], infer_text_out, vec, cu_seqlens_qkv, max_seqlen_qkv, (freqs_cos, freqs_sin)
def infer_time_in(self, weights, t): def infer_time_in(self, weights, t):
freqs = torch.exp(-math.log(10000) * torch.arange(start=0, end=128, dtype=torch.float32) / 128).to(device=t.device) freqs = torch.exp(-math.log(10000) * torch.arange(start=0, end=128, dtype=torch.float32) / 128).to(device=t.device)
args = t.unsqueeze(0).unsqueeze(0).float() * freqs[None] args = t.unsqueeze(0).unsqueeze(0).float() * freqs[None]
...@@ -56,28 +55,25 @@ class HunyuanPreInfer(): ...@@ -56,28 +55,25 @@ class HunyuanPreInfer():
out = weights.txt_in_t_embedder_mlp_0.apply(embedding) out = weights.txt_in_t_embedder_mlp_0.apply(embedding)
out = torch.nn.functional.silu(out) out = torch.nn.functional.silu(out)
timestep_aware_representations = weights.txt_in_t_embedder_mlp_2.apply(out) timestep_aware_representations = weights.txt_in_t_embedder_mlp_2.apply(out)
mask_float = text_mask.float().unsqueeze(-1).to(torch.bfloat16) # [b, s1, 1] mask_float = text_mask.float().unsqueeze(-1).to(torch.bfloat16) # [b, s1, 1]
context_aware_representations = (text_states * mask_float).sum(dim=1) / mask_float.sum(dim=1) context_aware_representations = (text_states * mask_float).sum(dim=1) / mask_float.sum(dim=1)
context_aware_representations = context_aware_representations context_aware_representations = context_aware_representations
out = weights.txt_in_c_embedder_linear_1.apply(context_aware_representations) out = weights.txt_in_c_embedder_linear_1.apply(context_aware_representations)
out = torch.nn.functional.silu(out) out = torch.nn.functional.silu(out)
context_aware_representations = weights.txt_in_c_embedder_linear_2.apply(out) context_aware_representations = weights.txt_in_c_embedder_linear_2.apply(out)
c = timestep_aware_representations + context_aware_representations c = timestep_aware_representations + context_aware_representations
txt_in_input_embed = weights.txt_in_input_embedder.apply(text_states[0]) txt_in_input_embed = weights.txt_in_input_embedder.apply(text_states[0])
batch_size = text_mask.shape[0] batch_size = text_mask.shape[0]
seq_len = text_mask.shape[1] seq_len = text_mask.shape[1]
self_attn_mask_1 = text_mask.view(batch_size, 1, 1, seq_len).repeat( self_attn_mask_1 = text_mask.view(batch_size, 1, 1, seq_len).repeat(1, 1, seq_len, 1)
1, 1, seq_len, 1
)
self_attn_mask_2 = self_attn_mask_1.transpose(2, 3) self_attn_mask_2 = self_attn_mask_1.transpose(2, 3)
self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool() self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool()
self_attn_mask[:, :, :, 0] = True self_attn_mask[:, :, :, 0] = True
cx = torch.nn.functional.silu(c) cx = torch.nn.functional.silu(c)
cx = weights.txt_in_individual_token_refiner_blocks_0_adaLN_modulation_1.apply(cx) cx = weights.txt_in_individual_token_refiner_blocks_0_adaLN_modulation_1.apply(cx)
gate_msa, gate_mlp = cx.chunk(2, dim=1) gate_msa, gate_mlp = cx.chunk(2, dim=1)
...@@ -94,7 +90,6 @@ class HunyuanPreInfer(): ...@@ -94,7 +90,6 @@ class HunyuanPreInfer():
out = weights.txt_in_individual_token_refiner_blocks_0_mlp_fc2.apply(out) out = weights.txt_in_individual_token_refiner_blocks_0_mlp_fc2.apply(out)
txt_in_input_embed = out_1 + out * gate_mlp txt_in_input_embed = out_1 + out * gate_mlp
cx = torch.nn.functional.silu(c) cx = torch.nn.functional.silu(c)
cx = weights.txt_in_individual_token_refiner_blocks_1_adaLN_modulation_1.apply(cx) cx = weights.txt_in_individual_token_refiner_blocks_1_adaLN_modulation_1.apply(cx)
gate_msa, gate_mlp = cx.chunk(2, dim=1) gate_msa, gate_mlp = cx.chunk(2, dim=1)
......
...@@ -4,7 +4,7 @@ from lightx2v.attentions import attention ...@@ -4,7 +4,7 @@ from lightx2v.attentions import attention
from .utils_bf16 import apply_rotary_emb from .utils_bf16 import apply_rotary_emb
class HunyuanTransformerInfer(): class HunyuanTransformerInfer:
def __init__(self, config): def __init__(self, config):
self.config = config self.config = config
self.attention_type = config.get("attention_type", "flash_attn2") self.attention_type = config.get("attention_type", "flash_attn2")
...@@ -26,14 +26,13 @@ class HunyuanTransformerInfer(): ...@@ -26,14 +26,13 @@ class HunyuanTransformerInfer():
img, txt = self.infer_double_block(weights.double_blocks_weights[i], img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis) img, txt = self.infer_double_block(weights.double_blocks_weights[i], img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis)
x = torch.cat((img, txt), 0) x = torch.cat((img, txt), 0)
for i in range(self.single_blocks_num): for i in range(self.single_blocks_num):
x = self.infer_single_block(weights.single_blocks_weights[i], x, vec, txt_seq_len, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis) x = self.infer_single_block(weights.single_blocks_weights[i], x, vec, txt_seq_len, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis)
img = x[:img_seq_len, ...] img = x[:img_seq_len, ...]
return img, vec return img, vec
def infer_double_block(self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis): def infer_double_block(self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis):
vec_silu = torch.nn.functional.silu(vec) vec_silu = torch.nn.functional.silu(vec)
...@@ -46,7 +45,7 @@ class HunyuanTransformerInfer(): ...@@ -46,7 +45,7 @@ class HunyuanTransformerInfer():
img_mod2_scale, img_mod2_scale,
img_mod2_gate, img_mod2_gate,
) = img_mod_out.chunk(6, dim=-1) ) = img_mod_out.chunk(6, dim=-1)
txt_mod_out = weights.txt_mod.apply(vec_silu) txt_mod_out = weights.txt_mod.apply(vec_silu)
( (
txt_mod1_shift, txt_mod1_shift,
...@@ -56,15 +55,14 @@ class HunyuanTransformerInfer(): ...@@ -56,15 +55,14 @@ class HunyuanTransformerInfer():
txt_mod2_scale, txt_mod2_scale,
txt_mod2_gate, txt_mod2_gate,
) = txt_mod_out.chunk(6, dim=-1) ) = txt_mod_out.chunk(6, dim=-1)
img_q, img_k, img_v = self.infer_double_block_img_pre_atten(weights, img, img_mod1_scale, img_mod1_shift, freqs_cis) img_q, img_k, img_v = self.infer_double_block_img_pre_atten(weights, img, img_mod1_scale, img_mod1_shift, freqs_cis)
txt_q, txt_k, txt_v = self.infer_double_block_txt_pre_atten(weights, txt, txt_mod1_scale, txt_mod1_shift) txt_q, txt_k, txt_v = self.infer_double_block_txt_pre_atten(weights, txt, txt_mod1_scale, txt_mod1_shift)
q = torch.cat((img_q, txt_q), dim=0) q = torch.cat((img_q, txt_q), dim=0)
k = torch.cat((img_k, txt_k), dim=0) k = torch.cat((img_k, txt_k), dim=0)
v = torch.cat((img_v, txt_v), dim=0) v = torch.cat((img_v, txt_v), dim=0)
if not self.parallel_attention: if not self.parallel_attention:
attn = attention( attn = attention(
attention_type=self.attention_type, attention_type=self.attention_type,
...@@ -84,11 +82,11 @@ class HunyuanTransformerInfer(): ...@@ -84,11 +82,11 @@ class HunyuanTransformerInfer():
k=k, k=k,
v=v, v=v,
img_qkv_len=img_q.shape[0], img_qkv_len=img_q.shape[0],
cu_seqlens_qkv=cu_seqlens_qkv cu_seqlens_qkv=cu_seqlens_qkv,
# cu_seqlens_qkv=cu_seqlens_qkv, # cu_seqlens_qkv=cu_seqlens_qkv,
# max_seqlen_qkv=max_seqlen_qkv, # max_seqlen_qkv=max_seqlen_qkv,
) )
img_attn, txt_attn = attn[: img.shape[0]], attn[img.shape[0] :] img_attn, txt_attn = attn[: img.shape[0]], attn[img.shape[0] :]
img = self.infer_double_block_img_post_atten(weights, img, img_attn, img_mod1_gate, img_mod2_shift, img_mod2_scale, img_mod2_gate) img = self.infer_double_block_img_post_atten(weights, img, img_attn, img_mod1_gate, img_mod2_shift, img_mod2_scale, img_mod2_gate)
txt = self.infer_double_block_txt_post_atten(weights, txt, txt_attn, txt_mod1_gate, txt_mod2_shift, txt_mod2_scale, txt_mod2_gate) txt = self.infer_double_block_txt_post_atten(weights, txt, txt_attn, txt_mod1_gate, txt_mod2_shift, txt_mod2_scale, txt_mod2_gate)
...@@ -99,9 +97,7 @@ class HunyuanTransformerInfer(): ...@@ -99,9 +97,7 @@ class HunyuanTransformerInfer():
img_modulated = img_modulated * (1 + img_mod1_scale) + img_mod1_shift img_modulated = img_modulated * (1 + img_mod1_scale) + img_mod1_shift
img_qkv = weights.img_attn_qkv.apply(img_modulated) img_qkv = weights.img_attn_qkv.apply(img_modulated)
img_q, img_k, img_v = rearrange( img_q, img_k, img_v = rearrange(img_qkv, "L (K H D) -> K L H D", K=3, H=self.heads_num)
img_qkv, "L (K H D) -> K L H D", K=3, H=self.heads_num
)
img_q = weights.img_attn_q_norm.apply(img_q) img_q = weights.img_attn_q_norm.apply(img_q)
img_k = weights.img_attn_k_norm.apply(img_k) img_k = weights.img_attn_k_norm.apply(img_k)
...@@ -114,9 +110,7 @@ class HunyuanTransformerInfer(): ...@@ -114,9 +110,7 @@ class HunyuanTransformerInfer():
txt_modulated = txt_modulated * (1 + txt_mod1_scale) + txt_mod1_shift txt_modulated = txt_modulated * (1 + txt_mod1_scale) + txt_mod1_shift
txt_qkv = weights.txt_attn_qkv.apply(txt_modulated) txt_qkv = weights.txt_attn_qkv.apply(txt_modulated)
txt_q, txt_k, txt_v = rearrange( txt_q, txt_k, txt_v = rearrange(txt_qkv, "L (K H D) -> K L H D", K=3, H=self.heads_num)
txt_qkv, "L (K H D) -> K L H D", K=3, H=self.heads_num
)
txt_q = weights.txt_attn_q_norm.apply(txt_q) txt_q = weights.txt_attn_q_norm.apply(txt_q)
txt_k = weights.txt_attn_k_norm.apply(txt_k) txt_k = weights.txt_attn_k_norm.apply(txt_k)
...@@ -126,11 +120,11 @@ class HunyuanTransformerInfer(): ...@@ -126,11 +120,11 @@ class HunyuanTransformerInfer():
out = weights.img_attn_proj.apply(img_attn) out = weights.img_attn_proj.apply(img_attn)
out = out * img_mod1_gate out = out * img_mod1_gate
img = img + out img = img + out
out = torch.nn.functional.layer_norm(img, (img.shape[1],), None, None, 1e-6) out = torch.nn.functional.layer_norm(img, (img.shape[1],), None, None, 1e-6)
out = out * (1 + img_mod2_scale) + img_mod2_shift out = out * (1 + img_mod2_scale) + img_mod2_shift
out = weights.img_mlp_fc1.apply(out) out = weights.img_mlp_fc1.apply(out)
out = torch.nn.functional.gelu(out, approximate='tanh') out = torch.nn.functional.gelu(out, approximate="tanh")
out = weights.img_mlp_fc2.apply(out) out = weights.img_mlp_fc2.apply(out)
out = out * img_mod2_gate out = out * img_mod2_gate
img = img + out img = img + out
...@@ -140,11 +134,11 @@ class HunyuanTransformerInfer(): ...@@ -140,11 +134,11 @@ class HunyuanTransformerInfer():
out = weights.txt_attn_proj.apply(txt_attn) out = weights.txt_attn_proj.apply(txt_attn)
out = out * txt_mod1_gate out = out * txt_mod1_gate
txt = txt + out txt = txt + out
out = torch.nn.functional.layer_norm(txt, (txt.shape[1],), None, None, 1e-6) out = torch.nn.functional.layer_norm(txt, (txt.shape[1],), None, None, 1e-6)
out = out * (1 + txt_mod2_scale) + txt_mod2_shift out = out * (1 + txt_mod2_scale) + txt_mod2_shift
out = weights.txt_mlp_fc1.apply(out) out = weights.txt_mlp_fc1.apply(out)
out = torch.nn.functional.gelu(out, approximate='tanh') out = torch.nn.functional.gelu(out, approximate="tanh")
out = weights.txt_mlp_fc2.apply(out) out = weights.txt_mlp_fc2.apply(out)
out = out * txt_mod2_gate out = out * txt_mod2_gate
txt = txt + out txt = txt + out
...@@ -157,11 +151,11 @@ class HunyuanTransformerInfer(): ...@@ -157,11 +151,11 @@ class HunyuanTransformerInfer():
out = torch.nn.functional.layer_norm(x, (x.shape[1],), None, None, 1e-6) out = torch.nn.functional.layer_norm(x, (x.shape[1],), None, None, 1e-6)
x_mod = out * (1 + mod_scale) + mod_shift x_mod = out * (1 + mod_scale) + mod_shift
x_mod = weights.linear1.apply(x_mod) x_mod = weights.linear1.apply(x_mod)
qkv, mlp = torch.split(x_mod, [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1) qkv, mlp = torch.split(x_mod, [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
q, k, v = rearrange(qkv, "L (K H D) -> K L H D", K=3, H=self.heads_num) q, k, v = rearrange(qkv, "L (K H D) -> K L H D", K=3, H=self.heads_num)
q = weights.q_norm.apply(q) q = weights.q_norm.apply(q)
...@@ -192,12 +186,12 @@ class HunyuanTransformerInfer(): ...@@ -192,12 +186,12 @@ class HunyuanTransformerInfer():
k=k, k=k,
v=v, v=v,
img_qkv_len=img_q.shape[0], img_qkv_len=img_q.shape[0],
cu_seqlens_qkv=cu_seqlens_qkv cu_seqlens_qkv=cu_seqlens_qkv,
# cu_seqlens_qkv=cu_seqlens_qkv, # cu_seqlens_qkv=cu_seqlens_qkv,
# max_seqlen_qkv=max_seqlen_qkv, # max_seqlen_qkv=max_seqlen_qkv,
) )
out = torch.nn.functional.gelu(mlp, approximate='tanh') out = torch.nn.functional.gelu(mlp, approximate="tanh")
out = torch.cat((attn, out), 1) out = torch.cat((attn, out), 1)
out = weights.linear2.apply(out) out = weights.linear2.apply(out)
out = out * mod_gate out = out * mod_gate
......
...@@ -7,23 +7,25 @@ def rms_norm(x, weight, eps): ...@@ -7,23 +7,25 @@ def rms_norm(x, weight, eps):
x = x * weight x = x * weight
return x return x
def rotate_half(x, shape_0, shape_1): def rotate_half(x, shape_0, shape_1):
x_real, x_imag = x.reshape(shape_0, shape_1, -1, 2).unbind(-1) x_real, x_imag = x.reshape(shape_0, shape_1, -1, 2).unbind(-1)
return torch.stack([-x_imag, x_real], dim=-1).flatten(2) return torch.stack([-x_imag, x_real], dim=-1).flatten(2)
def rotary_emb(x, shape_0, shape_1, cos, sin): def rotary_emb(x, shape_0, shape_1, cos, sin):
x_out = (x * cos + rotate_half(x, shape_0, shape_1) * sin) x_out = x * cos + rotate_half(x, shape_0, shape_1) * sin
return x_out return x_out
def apply_rotary_emb( def apply_rotary_emb(
xq: torch.Tensor, xq: torch.Tensor,
xk: torch.Tensor, xk: torch.Tensor,
freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
shape_0, shape_1, shape_2 = xq.shape shape_0, shape_1, shape_2 = xq.shape
cos = freqs_cis[0].view(shape_0, 1, shape_2) cos = freqs_cis[0].view(shape_0, 1, shape_2)
sin = freqs_cis[1].view(shape_0, 1, shape_2) sin = freqs_cis[1].view(shape_0, 1, shape_2)
xq_out = rotary_emb(xq, shape_0, shape_1, cos, sin) xq_out = rotary_emb(xq, shape_0, shape_1, cos, sin)
xk_out = rotary_emb(xk, shape_0, shape_1, cos, sin) xk_out = rotary_emb(xk, shape_0, shape_1, cos, sin)
return xq_out, xk_out return xq_out, xk_out
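A self-contained check of the rotary helpers above (a sketch, not part of the commit): assuming cos/sin are laid out per interleaved (real, imag) pair, rotary_emb rotates each pair by its angle, which the complex-number reference below reproduces.

import torch

def rotate_half(x, shape_0, shape_1):
    x_real, x_imag = x.reshape(shape_0, shape_1, -1, 2).unbind(-1)
    return torch.stack([-x_imag, x_real], dim=-1).flatten(2)

def rotary_emb(x, shape_0, shape_1, cos, sin):
    return x * cos + rotate_half(x, shape_0, shape_1) * sin

shape_0, shape_1, head_dim = 1, 1, 4                 # one token, one head, tiny dim
x = torch.randn(shape_0, shape_1, head_dim)
theta = torch.rand(shape_0, head_dim // 2)           # one angle per (real, imag) pair
cos = theta.cos().repeat_interleave(2, dim=-1).view(shape_0, 1, head_dim)
sin = theta.sin().repeat_interleave(2, dim=-1).view(shape_0, 1, head_dim)

out = rotary_emb(x, shape_0, shape_1, cos, sin)

# Reference: rotate each interleaved (real, imag) pair as a complex number.
xc = torch.view_as_complex(x.reshape(shape_0, shape_1, -1, 2))
rot = torch.polar(torch.ones_like(theta), theta).view(shape_0, 1, -1)
ref = torch.view_as_real(xc * rot).flatten(2)
assert torch.allclose(out, ref, atol=1e-6)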
...@@ -9,23 +9,25 @@ def rms_norm(x, weight, eps): ...@@ -9,23 +9,25 @@ def rms_norm(x, weight, eps):
x = x * weight x = x * weight
return x return x
def rotate_half(x, shape_0, shape_1): def rotate_half(x, shape_0, shape_1):
x_real, x_imag = x.float().reshape(shape_0, shape_1, -1, 2).unbind(-1) x_real, x_imag = x.float().reshape(shape_0, shape_1, -1, 2).unbind(-1)
return torch.stack([-x_imag, x_real], dim=-1).flatten(2) return torch.stack([-x_imag, x_real], dim=-1).flatten(2)
def rotary_emb(x, shape_0, shape_1, cos, sin): def rotary_emb(x, shape_0, shape_1, cos, sin):
x_out = (x * cos + rotate_half(x, shape_0, shape_1) * sin) x_out = x * cos + rotate_half(x, shape_0, shape_1) * sin
return x_out.to(torch.bfloat16) return x_out.to(torch.bfloat16)
def apply_rotary_emb( def apply_rotary_emb(
xq: torch.Tensor, xq: torch.Tensor,
xk: torch.Tensor, xk: torch.Tensor,
freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
shape_0, shape_1, shape_2 = xq.shape shape_0, shape_1, shape_2 = xq.shape
cos = freqs_cis[0].view(shape_0, 1, shape_2) cos = freqs_cis[0].view(shape_0, 1, shape_2)
sin = freqs_cis[1].view(shape_0, 1, shape_2) sin = freqs_cis[1].view(shape_0, 1, shape_2)
xq_out = rotary_emb(xq.float(), shape_0, shape_1, cos, sin) xq_out = rotary_emb(xq.float(), shape_0, shape_1, cos, sin)
xk_out = rotary_emb(xk.float(), shape_0, shape_1, cos, sin) xk_out = rotary_emb(xk.float(), shape_0, shape_1, cos, sin)
return xq_out, xk_out return xq_out, xk_out