Commit daf4c74e authored by helloyongyang, committed by Yang Yong (雍洋)

first commit

parent 6c79160f
import torch
import torch.distributed as dist
from lightx2v.attentions import attention
from lightx2v.utils.utils import seed_all
seed_all(42)
def prepare_tensors():
    cur_rank = dist.get_rank()  # get the rank of the current process
    torch.cuda.set_device(cur_rank)  # bind this process to its CUDA device
q = torch.randn(32656, 24, 128, dtype=torch.bfloat16).cuda()
k = torch.randn(32656, 24, 128, dtype=torch.bfloat16).cuda()
v = torch.randn(32656, 24, 128, dtype=torch.bfloat16).cuda()
cu_seqlens_qkv = torch.tensor(
[0, 32411, 32656], dtype=torch.int32
).cuda()
max_seqlen_qkv = 32656
return q, k, v, cu_seqlens_qkv, max_seqlen_qkv
def test_part_head():
q, k, v, cu_seqlens_qkv, max_seqlen_qkv = prepare_tensors()
    # First compute the full single-GPU result as a reference
single_gpu_output = attention(
q=q,
k=k,
v=v,
cu_seqlens_q=cu_seqlens_qkv,
cu_seqlens_kv=cu_seqlens_qkv,
max_seqlen_q=max_seqlen_qkv,
max_seqlen_kv=max_seqlen_qkv,
)
num_heads = q.shape[-2]
cur_rank = dist.get_rank()
world_size = dist.get_world_size()
    num_chunk_heads = num_heads // world_size
    if cur_rank == world_size - 1:
q = q[:, num_chunk_heads*cur_rank:, :]
k = k[:, num_chunk_heads*cur_rank:, :]
v = v[:, num_chunk_heads*cur_rank:, :]
else:
q = q[:, num_chunk_heads*cur_rank:num_chunk_heads*(cur_rank+1), :]
k = k[:, num_chunk_heads*cur_rank:num_chunk_heads*(cur_rank+1), :]
v = v[:, num_chunk_heads*cur_rank:num_chunk_heads*(cur_rank+1), :]
output = attention(
q=q,
k=k,
v=v,
cu_seqlens_q=cu_seqlens_qkv,
cu_seqlens_kv=cu_seqlens_qkv,
max_seqlen_q=max_seqlen_qkv,
max_seqlen_kv=max_seqlen_qkv,
)
gathered_outputs = [torch.empty_like(output) for _ in range(world_size)]
dist.all_gather(gathered_outputs, output)
combined_output = torch.cat(gathered_outputs, dim=1)
    # Verify that the sharded result matches the single-GPU reference
    if cur_rank == 0:
        print("Outputs match:", torch.allclose(single_gpu_output, combined_output, rtol=1e-3, atol=1e-3))
if __name__ == "__main__":
    # Initialize the distributed environment
dist.init_process_group(backend='nccl')
test_part_head()
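    # Launch sketch (illustrative): this test expects one process per GPU, e.g.
    #   torchrun --nproc_per_node=2 <path_to_this_test>.py
    # (the exact script path depends on where this file lives in the repo)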
from lightx2v.attentions.distributed.partial_heads_attn.attn import partial_heads_attn
def parallelize_hunyuan(hunyuan_model):
hunyuan_model.transformer_infer.parallel_attention = partial_heads_attn
import torch
import torch.distributed as dist
from lightx2v.attentions import attention
def ring_attn(q, k, v, img_qkv_len, cu_seqlens_qkv, attention_type="flash_attn2"):
    '''
    Sequence-parallel attention over image and text queries, keys and values.
    Each rank holds its local image shard, all-gathers the image keys/values
    from the other ranks, and attends over the full sequence.
    Args:
        q (torch.Tensor): query tensor, shape [shard_seqlen, heads, hidden_dims]
        k (torch.Tensor): key tensor, shape [shard_seqlen, heads, hidden_dims]
        v (torch.Tensor): value tensor, shape [shard_seqlen, heads, hidden_dims]
        img_qkv_len (int): length of the image portion of q/k/v
        cu_seqlens_qkv (torch.Tensor): cumulative sequence lengths covering the text and image parts
        attention_type (str): attention backend, defaults to "flash_attn2"
    Returns:
        torch.Tensor: the computed attention output
    '''
    # Rank of the current process and total number of processes
    cur_rank = dist.get_rank()
    world_size = dist.get_world_size()
    # Sequence length and text-related lengths
    seq_len = q.shape[0]
    txt_qkv_len = cu_seqlens_qkv[1] - img_qkv_len  # length of the text q/k/v
    txt_mask_len = cu_seqlens_qkv[2] - img_qkv_len  # length of the text mask
    # Head count and hidden dimension of the query tensor
    _, heads, hidden_dims = q.shape
    shard_heads = heads // world_size  # heads handled by each process
    shard_seqlen = img_qkv_len  # image sequence length handled by each process
    # Split the image and text parts of q, k and v
    img_q, img_k, img_v = q[:img_qkv_len, :, :].contiguous(), k[:img_qkv_len, :, :].contiguous(), v[:img_qkv_len, :, :].contiguous()
    txt_q, txt_k, txt_v = q[img_qkv_len:, :, :].contiguous(), k[img_qkv_len:, :, :].contiguous(), v[img_qkv_len:, :, :].contiguous()
gathered_img_k = [torch.empty_like(img_k) for _ in range(world_size)]
gathered_img_v = [torch.empty_like(img_v) for _ in range(world_size)]
dist.all_gather(gathered_img_k, img_k)
dist.all_gather(gathered_img_v, img_v)
torch.cuda.synchronize()
    k = torch.cat(gathered_img_k + [txt_k], dim=0)
    v = torch.cat(gathered_img_v + [txt_v], dim=0)
    # Cumulative sequence lengths for the queries
    cu_seqlens_q = torch.zeros([3], dtype=torch.int32, device="cuda")
    s = txt_qkv_len + img_q.shape[0]  # total length of text plus local image queries
    s1 = s  # end position of the current sample
    s2 = txt_mask_len + img_q.shape[0]  # end position of the text mask
    cu_seqlens_q[1] = s1
    cu_seqlens_q[2] = s2
    max_seqlen_q = img_q.shape[0] + txt_q.shape[0]  # maximum query sequence length
    # Cumulative sequence lengths for the keys/values (image part is gathered from all ranks)
    cu_seqlens_kv = torch.zeros([3], dtype=torch.int32, device="cuda")
    s = txt_qkv_len + img_k.shape[0] * world_size  # total length of text plus gathered image keys
    s1 = s  # end position of the current sample
    s2 = txt_mask_len + img_k.shape[0] * world_size  # end position of the text mask
    cu_seqlens_kv[1] = s1
    cu_seqlens_kv[2] = s2
    max_seqlen_kv = img_k.shape[0] * world_size + txt_k.shape[0]  # maximum key/value sequence length
attn = attention(
attention_type=attention_type,
q=q,
k=k,
v=v,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv,
max_seqlen_q=max_seqlen_q,
max_seqlen_kv=max_seqlen_kv
)
return attn
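# Usage sketch (illustrative; assumes torch.distributed is initialized and each rank
# holds its image shard plus the full text tokens):
#   out = ring_attn(q_local, k_local, v_local, img_qkv_len=img_len_local,
#                   cu_seqlens_qkv=cu_seqlens, attention_type="flash_attn2")
# q_local/k_local/v_local have shape [img_len_local + txt_len, heads, head_dim]; the image
# keys/values are all-gathered so every rank attends over the full image sequence.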
import functools
from lightx2v.attentions.distributed.ring.attn import ring_attn
from lightx2v.attentions.distributed.utils.process import pre_process, post_process
def parallelize_hunyuan(hunyuan_model):
"""将 Hunyuan 模型的推理过程并行化,使用 Ulysses 注意力机制。
参数:
hunyuan_model: Hunyuan 模型实例,包含推理方法和其他属性。
"""
# 将 Hunyuan 模型的并行注意力机制替换为 Ulysses 注意力
hunyuan_model.transformer_infer.parallel_attention = ring_attn
# 保存原始的推理方法,以便后续调用
original_infer = hunyuan_model.infer
@functools.wraps(hunyuan_model.__class__.infer) # 保留原始推理方法的元信息
def new_infer(self, latent_model_input, t_expand, text_states, text_mask, text_states_2, freqs_cos, freqs_sin, guidance):
"""新的推理方法,处理输入并调用原始推理方法。
参数:
self: Hunyuan 模型实例
latent_model_input: 潜在模型输入
t_expand: 时间扩展参数
text_states: 文本状态
text_mask: 文本掩码
text_states_2: 第二组文本状态
freqs_cos: 余弦频率
freqs_sin: 正弦频率
guidance: 指导参数
返回:
combined_output: 经过后处理的输出结果
"""
        # Shard the inputs for this rank
        latent_model_input, freqs_cos, freqs_sin, split_dim = pre_process(
            latent_model_input, freqs_cos, freqs_sin
        )
        # Run the original inference method on this rank's shard
        output = original_infer(
            latent_model_input, t_expand, text_states, text_mask, text_states_2, freqs_cos, freqs_sin, guidance
        )
        # Gather and merge the per-rank outputs
        combined_output = post_process(output, split_dim)
        return combined_output
    # Bind the new inference method to the Hunyuan model instance
    new_infer = new_infer.__get__(hunyuan_model)
    hunyuan_model.infer = new_infer  # replace the original inference method
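# Usage sketch (illustrative; assumes the distributed process group is already
# initialized with one process per GPU and `model` is a loaded Hunyuan model):
#   parallelize_hunyuan(model)
#   model.infer(latents, t_expand, text_states, text_mask, text_states_2, freqs_cos, freqs_sin, guidance)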
import torch
import torch.distributed as dist
from lightx2v.attentions import attention
from lightx2v.attentions.distributed.comm.all2all import all2all_seq2head, all2all_head2seq
def ulysses_attn(q, k, v, img_qkv_len, cu_seqlens_qkv, attention_type="flash_attn2"):
    '''
    Ulysses attention over image and text queries, keys and values: the image
    sequence is redistributed across ranks with all-to-all so that each rank
    attends over the full image sequence with its shard of the heads.
    Args:
        q (torch.Tensor): query tensor, shape [shard_seqlen, heads, hidden_dims]
        k (torch.Tensor): key tensor, shape [shard_seqlen, heads, hidden_dims]
        v (torch.Tensor): value tensor, shape [shard_seqlen, heads, hidden_dims]
        img_qkv_len (int): length of the image portion of q/k/v
        cu_seqlens_qkv (torch.Tensor): cumulative sequence lengths covering the text and image parts
        attention_type (str): attention backend, defaults to "flash_attn2"
    Returns:
        torch.Tensor: the computed attention output
    '''
    # Rank of the current process and total number of processes
    cur_rank = dist.get_rank()
    world_size = dist.get_world_size()
    # Sequence length and text-related lengths
    seq_len = q.shape[0]
    txt_qkv_len = cu_seqlens_qkv[1] - img_qkv_len  # length of the text q/k/v
    txt_mask_len = cu_seqlens_qkv[2] - img_qkv_len  # length of the text mask
    # Head count and hidden dimension of the query tensor
    _, heads, hidden_dims = q.shape
    shard_heads = heads // world_size  # heads handled by each process
    shard_seqlen = img_qkv_len  # image sequence length handled by each process
    # Split the image and text parts of q, k and v
    img_q, img_k, img_v = q[:img_qkv_len, :, :].contiguous(), k[:img_qkv_len, :, :].contiguous(), v[:img_qkv_len, :, :].contiguous()
    txt_q, txt_k, txt_v = q[img_qkv_len:, :, :].contiguous(), k[img_qkv_len:, :, :].contiguous(), v[img_qkv_len:, :, :].contiguous()
    # Redistribute the image q/k/v from sequence shards to head shards
    img_q = all2all_seq2head(img_q)
    img_k = all2all_seq2head(img_k)
    img_v = all2all_seq2head(img_v)
    torch.cuda.synchronize()  # make sure the all-to-all has completed
    # For the text q/k/v, keep only the heads owned by this rank
    txt_q = txt_q[:, cur_rank * shard_heads:(cur_rank + 1) * shard_heads, :]
    txt_k = txt_k[:, cur_rank * shard_heads:(cur_rank + 1) * shard_heads, :]
    txt_v = txt_v[:, cur_rank * shard_heads:(cur_rank + 1) * shard_heads, :]
    # Concatenate the image and text q/k/v along the sequence dimension
    q = torch.cat((img_q, txt_q), dim=0)
    k = torch.cat((img_k, txt_k), dim=0)
    v = torch.cat((img_v, txt_v), dim=0)
    # Cumulative sequence lengths for the packed sequence
    cu_seqlens_qkv = torch.zeros([3], dtype=torch.int32, device="cuda")
    s = txt_qkv_len + img_q.shape[0]  # total length of text plus the full image sequence
    s1 = s  # end position of the current sample
    s2 = txt_mask_len + img_q.shape[0]  # end position of the text mask
    cu_seqlens_qkv[1] = s1
    cu_seqlens_qkv[2] = s2
    max_seqlen_qkv = img_q.shape[0] + txt_q.shape[0]  # maximum sequence length
    # Run the attention kernel
attn = attention(
attention_type=attention_type,
q=q,
k=k,
v=v,
cu_seqlens_q=cu_seqlens_qkv,
cu_seqlens_kv=cu_seqlens_qkv,
max_seqlen_q=max_seqlen_qkv,
max_seqlen_kv=max_seqlen_qkv
)
    # Split the attention output back into image and text parts
    img_attn, txt_attn = attn[:img_q.shape[0]], attn[img_q.shape[0]:]
    # Gather the text attention results from all processes
    gathered_txt_attn = [torch.empty_like(txt_attn) for _ in range(world_size)]
    dist.all_gather(gathered_txt_attn, txt_attn)
    # Redistribute the image attention output from head shards back to sequence shards
    img_attn = img_attn.reshape(world_size * shard_seqlen, shard_heads, hidden_dims)
    img_attn = all2all_head2seq(img_attn)
    img_attn = img_attn.reshape(shard_seqlen, -1)  # flatten back to [shard_seqlen, heads * hidden_dims]
    torch.cuda.synchronize()  # make sure the all-to-all has completed
    txt_attn = torch.cat(gathered_txt_attn, dim=1)  # merge the text attention from all processes
    # Concatenate the image and text attention outputs
    attn = torch.cat([img_attn, txt_attn], dim=0)
    return attn
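# Shape sketch for a 2-rank run (illustrative): each rank enters with q/k/v of shape
# [img_len_local + txt_len, heads, head_dim]. all2all_seq2head turns the image part into
# [img_len_local * 2, heads // 2, head_dim], so every rank attends over the full image
# sequence with half of the heads; all2all_head2seq reverses the layout afterwards.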
import functools
from lightx2v.attentions.distributed.ulysses.attn import ulysses_attn
from lightx2v.attentions.distributed.utils.process import pre_process, post_process
def parallelize_hunyuan(hunyuan_model):
"""将 Hunyuan 模型的推理过程并行化,使用 Ulysses 注意力机制。
参数:
hunyuan_model: Hunyuan 模型实例,包含推理方法和其他属性。
"""
# 将 Hunyuan 模型的并行注意力机制替换为 Ulysses 注意力
hunyuan_model.transformer_infer.parallel_attention = ulysses_attn
# 保存原始的推理方法,以便后续调用
original_infer = hunyuan_model.infer
@functools.wraps(hunyuan_model.__class__.infer) # 保留原始推理方法的元信息
def new_infer(self, text_encoders_output, args):
"""新的推理方法,处理输入并调用原始推理方法。
参数:
self: Hunyuan 模型实例
text_encoders_output: 文本编码器的输出
args: 其他参数
返回:
None
"""
# 保存原始的潜在模型输入和频率数据
self.scheduler.ori_latents, self.scheduler.ori_freqs_cos, self.scheduler.ori_freqs_sin = (
self.scheduler.latents,
self.scheduler.freqs_cos,
self.scheduler.freqs_sin
)
# 预处理输入数据以适应并行计算
self.scheduler.latents, self.scheduler.freqs_cos, self.scheduler.freqs_sin, split_dim = pre_process(
self.scheduler.latents, self.scheduler.freqs_cos, self.scheduler.freqs_sin
)
# 调用原始推理方法,获取输出
output = original_infer(
text_encoders_output, args
)
# 对输出进行后处理
self.scheduler.noise_pred = post_process(self.scheduler.noise_pred, split_dim)
# 恢复原始的潜在模型输入和频率数据
self.scheduler.latents, self.scheduler.freqs_cos, self.scheduler.freqs_sin = (
self.scheduler.ori_latents,
self.scheduler.ori_freqs_cos,
self.scheduler.ori_freqs_sin
)
# return combined_output # 返回处理后的输出(当前被注释掉)
# 将新的推理方法绑定到 Hunyuan 模型实例
new_infer = new_infer.__get__(hunyuan_model)
hunyuan_model.infer = new_infer # 替换原始推理方法
import torch
import torch.distributed as dist
def pre_process(latent_model_input, freqs_cos, freqs_sin):
    '''
    Shard the latent model input and the rotary frequency tables across ranks for distributed inference.
    Args:
        latent_model_input (torch.Tensor): latent input of shape [batch_size, channels, temporal_size, height, width]
        freqs_cos (torch.Tensor): cosine rotary frequency table, reshaped internally to [temporal_size, height, width, dim]
        freqs_sin (torch.Tensor): sine rotary frequency table, reshaped internally to [temporal_size, height, width, dim]
    Returns:
        tuple: the sharded latent_model_input, freqs_cos, freqs_sin and the split dimension split_dim
    '''
    # World size and rank of the current process
    world_size = dist.get_world_size()
    cur_rank = dist.get_rank()
    # Choose the split dimension based on the input shape
    if latent_model_input.shape[-2] // 2 % world_size == 0:
        split_dim = -2  # split along the height
    elif latent_model_input.shape[-1] // 2 % world_size == 0:
        split_dim = -1  # split along the width
    else:
        raise ValueError(f"Cannot split video sequence into world size ({world_size}) parts evenly")
    # Temporal size and the patchified height and width
    temporal_size, h, w = latent_model_input.shape[2], latent_model_input.shape[3] // 2, latent_model_input.shape[4] // 2
    # Take this rank's chunk of the latent input along the chosen dimension
    latent_model_input = torch.chunk(latent_model_input, world_size, dim=split_dim)[cur_rank]
    # Shard the cosine frequency table
    dim_thw = freqs_cos.shape[-1]  # last dimension of the frequency table
    freqs_cos = freqs_cos.reshape(temporal_size, h, w, dim_thw)  # reshape to [temporal_size, height, width, dim_thw]
    freqs_cos = torch.chunk(freqs_cos, world_size, dim=split_dim - 1)[cur_rank]  # take this rank's chunk
    freqs_cos = freqs_cos.reshape(-1, dim_thw)  # flatten back to [local_seq_len, dim_thw]
    # Shard the sine frequency table
    dim_thw = freqs_sin.shape[-1]  # last dimension of the frequency table
    freqs_sin = freqs_sin.reshape(temporal_size, h, w, dim_thw)  # reshape to [temporal_size, height, width, dim_thw]
    freqs_sin = torch.chunk(freqs_sin, world_size, dim=split_dim - 1)[cur_rank]  # take this rank's chunk
    freqs_sin = freqs_sin.reshape(-1, dim_thw)  # flatten back to [local_seq_len, dim_thw]
    return latent_model_input, freqs_cos, freqs_sin, split_dim
def post_process(output, split_dim):
    """Gather the per-rank outputs and concatenate them along the split dimension.
    Args:
        output (torch.Tensor): this rank's output shard
        split_dim (int): the dimension along which the input was split and the outputs are merged
    Returns:
        torch.Tensor: the merged output
    """
    # World size of the process group
    world_size = dist.get_world_size()
    # Buffers to hold the outputs from all processes
    gathered_outputs = [torch.empty_like(output) for _ in range(world_size)]
    # Collect the outputs from all processes
    dist.all_gather(gathered_outputs, output)
    # Concatenate along the split dimension
    combined_output = torch.cat(gathered_outputs, dim=split_dim)
    return combined_output
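# Example sketch (illustrative; assumes a 2-rank run): for latents of shape
# [1, 16, 33, 60, 104], height // 2 = 30 is divisible by 2, so split_dim = -2 and each
# rank receives a [1, 16, 33, 30, 104] shard; post_process(output, split_dim=-2)
# all-gathers the shards and concatenates them back along the height dimension.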
#
# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import argparse
import os
import tensorrt as trt
from .common_runtime import *
try:
# Sometimes python does not understand FileNotFoundError
FileNotFoundError
except NameError:
FileNotFoundError = IOError
def GiB(val):
return val * 1 << 30
def add_help(description):
parser = argparse.ArgumentParser(
description=description, formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
args, _ = parser.parse_known_args()
def find_sample_data(
description="Runs a TensorRT Python sample", subfolder="", find_files=[], err_msg=""
):
"""
Parses sample arguments.
Args:
description (str): Description of the sample.
subfolder (str): The subfolder containing data relevant to this sample
        find_files (List[str]): A list of filenames to find. Each filename will be replaced with an absolute path.
    Returns:
        tuple: the data directory paths and the absolute paths of the located files.
"""
# Standard command-line arguments for all samples.
kDEFAULT_DATA_ROOT = os.path.join(os.sep, "usr", "src", "tensorrt", "data")
parser = argparse.ArgumentParser(
description=description, formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"-d",
"--datadir",
help="Location of the TensorRT sample data directory, and any additional data directories.",
action="append",
default=[kDEFAULT_DATA_ROOT],
)
args, _ = parser.parse_known_args()
def get_data_path(data_dir):
# If the subfolder exists, append it to the path, otherwise use the provided path as-is.
data_path = os.path.join(data_dir, subfolder)
if not os.path.exists(data_path):
if data_dir != kDEFAULT_DATA_ROOT:
print(
"WARNING: "
+ data_path
+ " does not exist. Trying "
+ data_dir
+ " instead."
)
data_path = data_dir
# Make sure data directory exists.
if not (os.path.exists(data_path)) and data_dir != kDEFAULT_DATA_ROOT:
print(
"WARNING: {:} does not exist. Please provide the correct data path with the -d option.".format(
data_path
)
)
return data_path
data_paths = [get_data_path(data_dir) for data_dir in args.datadir]
return data_paths, locate_files(data_paths, find_files, err_msg)
def locate_files(data_paths, filenames, err_msg=""):
"""
Locates the specified files in the specified data directories.
If a file exists in multiple data directories, the first directory is used.
Args:
data_paths (List[str]): The data directories.
        filenames (List[str]): The names of the files to find.
Returns:
List[str]: The absolute paths of the files.
Raises:
FileNotFoundError if a file could not be located.
"""
found_files = [None] * len(filenames)
for data_path in data_paths:
# Find all requested files.
for index, (found, filename) in enumerate(zip(found_files, filenames)):
if not found:
file_path = os.path.abspath(os.path.join(data_path, filename))
if os.path.exists(file_path):
found_files[index] = file_path
# Check that all files were found
for f, filename in zip(found_files, filenames):
if not f or not os.path.exists(f):
raise FileNotFoundError(
"Could not find {:}. Searched in data paths: {:}\n{:}".format(
filename, data_paths, err_msg
)
)
return found_files
# Sets up the builder to use the timing cache file, and creates it if it does not already exist
def setup_timing_cache(config: trt.IBuilderConfig, timing_cache_path: os.PathLike):
buffer = b""
if os.path.exists(timing_cache_path):
with open(timing_cache_path, mode="rb") as timing_cache_file:
buffer = timing_cache_file.read()
timing_cache: trt.ITimingCache = config.create_timing_cache(buffer)
config.set_timing_cache(timing_cache, True)
# Saves the config's timing cache to file
def save_timing_cache(config: trt.IBuilderConfig, timing_cache_path: os.PathLike):
timing_cache: trt.ITimingCache = config.get_timing_cache()
with open(timing_cache_path, "wb") as timing_cache_file:
timing_cache_file.write(memoryview(timing_cache.serialize()))
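# Build-time sketch (illustrative; assumes `network` was populated via an ONNX parser
# or the TensorRT network API):
#   logger = trt.Logger(trt.Logger.WARNING)
#   builder = trt.Builder(logger)
#   config = builder.create_builder_config()
#   setup_timing_cache(config, "timing.cache")          # reuse tactics from earlier builds
#   engine_bytes = builder.build_serialized_network(network, config)
#   save_timing_cache(config, "timing.cache")           # persist newly measured tactics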
#
# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import ctypes
from typing import Optional, List, Union
import numpy as np
import tensorrt as trt
from cuda import cuda, cudart
def check_cuda_err(err):
if isinstance(err, cuda.CUresult):
if err != cuda.CUresult.CUDA_SUCCESS:
raise RuntimeError("Cuda Error: {}".format(err))
    elif isinstance(err, cudart.cudaError_t):
if err != cudart.cudaError_t.cudaSuccess:
raise RuntimeError("Cuda Runtime Error: {}".format(err))
else:
raise RuntimeError("Unknown error type: {}".format(err))
def cuda_call(call):
err, res = call[0], call[1:]
check_cuda_err(err)
if len(res) == 1:
res = res[0]
return res
class HostDeviceMem:
"""Pair of host and device memory, where the host memory is wrapped in a numpy array"""
def __init__(self, size: int, dtype: Optional[np.dtype] = None):
dtype = dtype or np.dtype(np.uint8)
nbytes = size * dtype.itemsize
host_mem = cuda_call(cudart.cudaMallocHost(nbytes))
pointer_type = ctypes.POINTER(np.ctypeslib.as_ctypes_type(dtype))
self._host = np.ctypeslib.as_array(ctypes.cast(host_mem, pointer_type), (size,))
self._device = cuda_call(cudart.cudaMalloc(nbytes))
self._nbytes = nbytes
@property
def host(self) -> np.ndarray:
return self._host
@host.setter
def host(self, data: Union[np.ndarray, bytes]):
if isinstance(data, np.ndarray):
if data.size > self.host.size:
raise ValueError(
f"Tried to fit an array of size {data.size} into host memory of size {self.host.size}"
)
np.copyto(self.host[:data.size], data.flat, casting='safe')
else:
assert self.host.dtype == np.uint8
self.host[:self.nbytes] = np.frombuffer(data, dtype=np.uint8)
@property
def device(self) -> int:
return self._device
@property
def nbytes(self) -> int:
return self._nbytes
def __str__(self):
return f"Host:\n{self.host}\nDevice:\n{self.device}\nSize:\n{self.nbytes}\n"
def __repr__(self):
return self.__str__()
def free(self):
cuda_call(cudart.cudaFree(self.device))
cuda_call(cudart.cudaFreeHost(self.host.ctypes.data))
# Allocates all buffers required for an engine, i.e. host/device inputs/outputs.
# If engine uses dynamic shapes, specify a profile to find the maximum input & output size.
def allocate_buffers(engine: trt.ICudaEngine, profile_idx: Optional[int] = None):
inputs = []
outputs = []
bindings = []
stream = cuda_call(cudart.cudaStreamCreate())
tensor_names = [engine.get_tensor_name(i) for i in range(engine.num_io_tensors)]
for binding in tensor_names:
# get_tensor_profile_shape returns (min_shape, optimal_shape, max_shape)
# Pick out the max shape to allocate enough memory for the binding.
shape = engine.get_tensor_shape(binding) if profile_idx is None else engine.get_tensor_profile_shape(binding, profile_idx)[-1]
shape_valid = np.all([s >= 0 for s in shape])
if not shape_valid and profile_idx is None:
raise ValueError(f"Binding {binding} has dynamic shape, " +\
"but no profile was specified.")
size = trt.volume(shape)
trt_type = engine.get_tensor_dtype(binding)
# Allocate host and device buffers
try:
dtype = np.dtype(trt.nptype(trt_type))
bindingMemory = HostDeviceMem(size, dtype)
except TypeError: # no numpy support: create a byte array instead (BF16, FP8, INT4)
size = int(size * trt_type.itemsize)
bindingMemory = HostDeviceMem(size)
# Append the device buffer to device bindings.
bindings.append(int(bindingMemory.device))
# Append to the appropriate list.
if engine.get_tensor_mode(binding) == trt.TensorIOMode.INPUT:
inputs.append(bindingMemory)
else:
outputs.append(bindingMemory)
return inputs, outputs, bindings, stream
# Frees the resources allocated in allocate_buffers
def free_buffers(inputs: List[HostDeviceMem], outputs: List[HostDeviceMem], stream: cudart.cudaStream_t):
for mem in inputs + outputs:
mem.free()
cuda_call(cudart.cudaStreamDestroy(stream))
# Wrapper for cudaMemcpy which infers copy size and does error checking
def memcpy_host_to_device(device_ptr: int, host_arr: np.ndarray):
nbytes = host_arr.size * host_arr.itemsize
cuda_call(cudart.cudaMemcpy(device_ptr, host_arr, nbytes, cudart.cudaMemcpyKind.cudaMemcpyHostToDevice))
# Wrapper for cudaMemcpy which infers copy size and does error checking
def memcpy_device_to_host(host_arr: np.ndarray, device_ptr: int):
nbytes = host_arr.size * host_arr.itemsize
cuda_call(cudart.cudaMemcpy(host_arr, device_ptr, nbytes, cudart.cudaMemcpyKind.cudaMemcpyDeviceToHost))
def _do_inference_base(inputs, outputs, stream, execute_async_func):
# Transfer input data to the GPU.
kind = cudart.cudaMemcpyKind.cudaMemcpyHostToDevice
[cuda_call(cudart.cudaMemcpyAsync(inp.device, inp.host, inp.nbytes, kind, stream)) for inp in inputs]
# Run inference.
execute_async_func()
# Transfer predictions back from the GPU.
kind = cudart.cudaMemcpyKind.cudaMemcpyDeviceToHost
[cuda_call(cudart.cudaMemcpyAsync(out.host, out.device, out.nbytes, kind, stream)) for out in outputs]
# Synchronize the stream
cuda_call(cudart.cudaStreamSynchronize(stream))
# Return only the host outputs.
return [out.host for out in outputs]
# This function is generalized for multiple inputs/outputs.
# inputs and outputs are expected to be lists of HostDeviceMem objects.
def do_inference(context, engine, bindings, inputs, outputs, stream):
def execute_async_func():
context.execute_async_v3(stream_handle=stream)
# Setup context tensor address.
num_io = engine.num_io_tensors
for i in range(num_io):
context.set_tensor_address(engine.get_tensor_name(i), bindings[i])
return _do_inference_base(inputs, outputs, stream, execute_async_func)
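# Inference sketch (illustrative; assumes a serialized engine file and a host input
# array that matches the engine's first input binding):
#   runtime = trt.Runtime(trt.Logger(trt.Logger.WARNING))
#   with open("model.engine", "rb") as f:
#       engine = runtime.deserialize_cuda_engine(f.read())
#   context = engine.create_execution_context()
#   inputs, outputs, bindings, stream = allocate_buffers(engine)
#   np.copyto(inputs[0].host, input_array.ravel())
#   results = do_inference(context, engine, bindings, inputs, outputs, stream)
#   free_buffers(inputs, outputs, stream)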
import torch
import torch.nn as nn
class MemoryEfficientBlocks(nn.Module):
def __init__(self, block_class, num_blocks, **block_params):
super().__init__()
self.block_class = block_class
self.num_blocks = num_blocks
self.block_params = block_params
        # Two resident blocks: one computes while the other has its weights prefetched
        self.active_blocks = nn.ModuleList([
            block_class(**block_params) for _ in range(2)
        ])
        # Separate CUDA streams for compute and weight loading, with different priorities
        self.compute_stream = torch.cuda.Stream(priority=-1)  # high priority
        self.load_stream = torch.cuda.Stream(priority=0)  # normal priority
        # Release cached GPU memory and cap per-process usage
        torch.cuda.empty_cache()
        torch.cuda.memory.set_per_process_memory_fraction(0.8)  # limit GPU memory usage
        # CPU-side buffer holding the weights of every block
        self.weight_buffer = []
    def initialize_weights(self, checkpoint, key):
        """Load every block's weights into CPU memory, keyed by block index."""
        for i in range(self.num_blocks):
            block_weights = {
                k.replace(f'{key}.{i}.', ''): v
                for k, v in checkpoint.items()
                if f'{key}.{i}.' in k
            }
            self.weight_buffer.append(block_weights)
    def prefetch_weights(self, block_idx):
        """Prefetch the weights of the given block onto the GPU in the dedicated load stream."""
        with torch.cuda.stream(self.load_stream):
            next_weights = self.weight_buffer[block_idx]
            next_weights = {
                k: v.cuda(non_blocking=True)
                for k, v in next_weights.items()
            }
            self.active_blocks[1].load_state_dict(next_weights)
    def swap_blocks(self):
        """Swap the two resident blocks once compute and prefetch have both finished."""
        # Wait for the compute stream
        self.compute_stream.synchronize()
        # Wait for the load stream
        self.load_stream.synchronize()
        # Swap the blocks
        self.active_blocks[0], self.active_blocks[1] = \
            self.active_blocks[1], self.active_blocks[0]
    def forward(self, *args, **kwargs):
        """Run all blocks sequentially, overlapping each block's compute with the next block's weight loading."""
        for i in range(self.num_blocks):
            if i == 0:
                self.active_blocks[0].load_state_dict(self.weight_buffer[0])
            # Run the current block in the compute stream
            with torch.cuda.stream(self.compute_stream):
                current_block = self.active_blocks[0]
                outputs = current_block(*args, **kwargs)
            # Prefetch the next block's weights in the load stream
            if i < self.num_blocks - 1:
                self.prefetch_weights(i + 1)
            # Swap the blocks once both streams are done
            self.swap_blocks()
            # Feed this block's outputs into the next block's inputs
            args = list(args)
            if isinstance(outputs, torch.Tensor):
                args[0] = outputs
            else:
                for j, out in enumerate(outputs):
                    args[j] = out
            args = tuple(args)
        return outputs
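# Usage sketch (illustrative; `MyBlock`, the checkpoint path and the "blocks" key prefix
# are assumptions about how the surrounding model stores its layers):
#   blocks = MemoryEfficientBlocks(MyBlock, num_blocks=40, dim=3072).cuda()
#   checkpoint = torch.load("model.pt", map_location="cpu")
#   blocks.initialize_weights(checkpoint, key="blocks")   # expects keys like "blocks.<i>.<param>"
#   hidden_states = blocks(hidden_states)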
from .mm import *
from .norm import *
from .conv import *
from .conv2d import *
from .conv3d import *
import torch
from abc import ABCMeta, abstractmethod
from lightx2v.utils.registry_factory import CONV2D_WEIGHT_REGISTER
class Conv2dWeightTemplate(metaclass=ABCMeta):
def __init__(self, weight_name, bias_name, stride, padding, dilation, groups):
self.weight_name = weight_name
self.bias_name = bias_name
self.stride = stride
self.padding = padding
self.dilation = dilation
self.groups = groups
self.config = {}
@abstractmethod
def load(self, weight_dict):
pass
@abstractmethod
def apply(self, input_tensor):
pass
def set_config(self, config=None):
if config is not None:
self.config = config
@CONV2D_WEIGHT_REGISTER('Default')
class Conv2dWeight(Conv2dWeightTemplate):
def __init__(self, weight_name, bias_name, stride=1, padding=0, dilation=1, groups=1):
super().__init__(weight_name, bias_name, stride, padding, dilation, groups)
def load(self, weight_dict):
self.weight = weight_dict[self.weight_name].cuda()
self.bias = weight_dict[self.bias_name].cuda() if self.bias_name is not None else None
def apply(self, input_tensor):
input_tensor = torch.nn.functional.conv2d(
input_tensor,
weight=self.weight,
bias=self.bias,
stride=self.stride,
padding=self.padding,
dilation=self.dilation,
groups=self.groups
)
return input_tensor
def to_cpu(self):
self.weight = self.weight.cpu()
if self.bias is not None:
self.bias = self.bias.cpu()
def to_cuda(self):
self.weight = self.weight.cuda()
if self.bias is not None:
self.bias = self.bias.cuda()
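# Usage sketch (illustrative; the parameter names and shapes are placeholders):
#   conv = CONV2D_WEIGHT_REGISTER['Default']('proj.weight', 'proj.bias', stride=2, padding=1)
#   conv.load({'proj.weight': torch.randn(64, 3, 3, 3), 'proj.bias': torch.randn(64)})
#   y = conv.apply(torch.randn(1, 3, 224, 224).cuda())   # -> [1, 64, 112, 112]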
import torch
from abc import ABCMeta, abstractmethod
from lightx2v.utils.registry_factory import CONV3D_WEIGHT_REGISTER
class Conv3dWeightTemplate(metaclass=ABCMeta):
def __init__(self, weight_name, bias_name, stride=1, padding=0, dilation=1, groups=1):
self.weight_name = weight_name
self.bias_name = bias_name
self.stride = stride
self.padding = padding
self.dilation = dilation
self.groups = groups
self.config = {}
@abstractmethod
def load(self, weight_dict):
pass
@abstractmethod
def apply(self, input_tensor):
pass
def set_config(self, config=None):
if config is not None:
self.config = config
@CONV3D_WEIGHT_REGISTER('Default')
class Conv3dWeight(Conv3dWeightTemplate):
def __init__(self, weight_name, bias_name, stride=1, padding=0, dilation=1, groups=1):
super().__init__(weight_name, bias_name, stride, padding, dilation, groups)
def load(self, weight_dict):
self.weight = weight_dict[self.weight_name].cuda()
self.bias = weight_dict[self.bias_name].cuda() if self.bias_name is not None else None
def apply(self, input_tensor):
input_tensor = torch.nn.functional.conv3d(
input_tensor,
weight=self.weight,
bias=self.bias,
stride=self.stride,
padding=self.padding,
dilation=self.dilation,
groups=self.groups
)
return input_tensor
def to_cpu(self):
self.weight = self.weight.cpu()
if self.bias is not None:
self.bias = self.bias.cpu()
def to_cuda(self):
self.weight = self.weight.cuda()
if self.bias is not None:
self.bias = self.bias.cuda()
from .mm_weight import *
from .mm_weight_calib import *
import torch
from abc import ABCMeta, abstractmethod
from vllm import _custom_ops as ops
from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER
from lightx2v.utils.quant_utils import IntegerQuantizer, FloatQuantizer
class MMWeightTemplate(metaclass=ABCMeta):
def __init__(self, weight_name, bias_name):
self.weight_name = weight_name
self.bias_name = bias_name
self.config = {}
@abstractmethod
def load(self, weight_dict):
pass
@abstractmethod
def apply(self, input_tensor):
pass
def set_config(self, config=None):
if config is not None:
self.config = config
@MM_WEIGHT_REGISTER('Default')
class MMWeight(MMWeightTemplate):
def __init__(self, weight_name, bias_name):
super().__init__(weight_name, bias_name)
def load(self, weight_dict):
self.weight = weight_dict[self.weight_name].t().cuda()
self.bias = weight_dict[self.bias_name].cuda() if self.bias_name is not None else None
def apply(self, input_tensor):
shape = (input_tensor.shape[0], self.weight.shape[1])
dtype = input_tensor.dtype
device = input_tensor.device
output_tensor = torch.empty(shape, dtype=dtype, device=device, requires_grad=False)
if self.bias is None:
return torch.mm(input_tensor, self.weight, out=output_tensor)
return torch.addmm(self.bias, input_tensor, self.weight, out=output_tensor)
def to_cpu(self):
self.weight = self.weight.cpu()
if self.bias is not None:
self.bias = self.bias.cpu()
def to_cuda(self):
self.weight = self.weight.cuda()
if self.bias is not None:
self.bias = self.bias.cuda()
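# Usage sketch (illustrative; tensor names and shapes are placeholders):
#   mm = MM_WEIGHT_REGISTER['Default']('fc.weight', 'fc.bias')
#   mm.load({'fc.weight': torch.randn(8192, 4096, dtype=torch.bfloat16),
#            'fc.bias': torch.randn(8192, dtype=torch.bfloat16)})
#   y = mm.apply(torch.randn(1024, 4096, dtype=torch.bfloat16).cuda())   # -> [1024, 8192]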
@MM_WEIGHT_REGISTER('Default-Force-FP32')
class MMWeightForceFP32(MMWeight):
def __init__(self, weight_name, bias_name):
super().__init__(weight_name, bias_name)
def load(self, weight_dict):
super().load(weight_dict)
self.weight = self.weight.to(torch.float32)
if self.bias is not None:
self.bias = self.bias.to(torch.float32)
@MM_WEIGHT_REGISTER('W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm')
class MMWeightWfp8channelAfp8channeldynamicVllm(MMWeightTemplate):
'''
Name: W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm
Quant MM:
Weight: fp8 perchannel sym
Act: fp8 perchannel dynamic sym
Kernel: vllm
'''
def __init__(self, weight_name, bias_name):
super().__init__(weight_name, bias_name)
def load(self, weight_dict):
if self.config.get('weight_auto_quant', True):
self.weight = weight_dict[self.weight_name].to(torch.float32).cuda()
w_quantizer = FloatQuantizer('e4m3', True, 'channel')
self.weight, self.weight_scale, _ = w_quantizer.real_quant_tensor(self.weight)
self.weight = self.weight.to(torch.float8_e4m3fn).t().cuda()
self.weight_scale = self.weight_scale.to(torch.float32).cuda()
else:
self.weight = weight_dict[self.weight_name].t().cuda()
            self.weight_scale = weight_dict[self.weight_name.removesuffix(".weight") + '.weight_scale'].cuda()  # strip the ".weight" suffix (rstrip removes characters, not a suffix)
self.bias = weight_dict[self.bias_name].cuda() if self.bias_name is not None else None
def apply(self, input_tensor):
shape = (input_tensor.shape[0], self.weight.shape[1])
dtype = input_tensor.dtype
device = input_tensor.device
output_tensor = torch.empty(shape, dtype=dtype, device=device, requires_grad=False)
qinput, x_scale = ops.scaled_fp8_quant(input_tensor, None, scale_ub=None, use_per_token_if_dynamic=True)
torch.ops._C.cutlass_scaled_mm(output_tensor, qinput, self.weight, x_scale, self.weight_scale, self.bias)
return output_tensor
def to_cpu(self):
self.weight = self.weight.cpu()
self.weight_scale = self.weight_scale.cpu()
if self.bias is not None:
self.bias = self.bias.cpu()
def to_cuda(self):
self.weight = self.weight.cuda()
self.weight_scale = self.weight_scale.cuda()
if self.bias is not None:
self.bias = self.bias.cuda()
@MM_WEIGHT_REGISTER('W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm')
class MMWeightWint8channelAint8channeldynamicVllm(MMWeightTemplate):
'''
Name: W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm
Quant MM:
Weight: int8 perchannel sym
Act: int8 perchannel dynamic sym
Kernel: vllm
'''
def __init__(self, weight_name, bias_name):
super().__init__(weight_name, bias_name)
def load(self, weight_dict):
if self.config.get('weight_auto_quant', True):
self.weight = weight_dict[self.weight_name].to(torch.float32).cuda()
w_quantizer = IntegerQuantizer(8, True, 'channel')
self.weight, self.weight_scale, _ = w_quantizer.real_quant_tensor(self.weight)
self.weight = self.weight.to(torch.int8).t().cuda()
self.weight_scale = self.weight_scale.to(torch.float32).cuda()
else:
self.weight = weight_dict[self.weight_name].t().cuda()
            self.weight_scale = weight_dict[self.weight_name.removesuffix(".weight") + '.weight_scale'].cuda()  # strip the ".weight" suffix (rstrip removes characters, not a suffix)
self.bias = weight_dict[self.bias_name].cuda() if self.bias_name is not None else None
def apply(self, input_tensor):
shape = (input_tensor.shape[0], self.weight.shape[1])
dtype = input_tensor.dtype
device = input_tensor.device
output_tensor = torch.empty(shape, dtype=dtype, device=device, requires_grad=False)
qinput, x_scale, _ = ops.scaled_int8_quant(input_tensor, scale=None, azp=None, symmetric=True)
torch.ops._C.cutlass_scaled_mm(output_tensor, qinput, self.weight, x_scale, self.weight_scale, self.bias)
return output_tensor
def to_cpu(self):
self.weight = self.weight.cpu()
self.weight_scale = self.weight_scale.cpu()
if self.bias is not None:
self.bias = self.bias.cpu()
def to_cuda(self):
self.weight = self.weight.cuda()
self.weight_scale = self.weight_scale.cuda()
if self.bias is not None:
self.bias = self.bias.cuda()
if __name__ == '__main__':
weight_dict = {
'xx.weight': torch.randn(8192, 4096).to(torch.float8_e4m3fn),
'xx.bias': torch.randn(8192).to(torch.bfloat16),
'xx.weight_scale': torch.randn(8192, 1).to(torch.float32),
}
mm_weight = MM_WEIGHT_REGISTER['W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm']('xx.weight', 'xx.bias')
mm_weight.set_config({'weight_auto_quant': False})
mm_weight.load(weight_dict)
input_tensor = torch.randn(1024, 4096).to(torch.bfloat16).cuda()
output_tensor = mm_weight.apply(input_tensor)
print(output_tensor.shape)
weight_dict = {
'xx.weight': torch.randn(8192, 4096),
'xx.bias': torch.randn(8192).to(torch.bfloat16),
}
mm_weight = MM_WEIGHT_REGISTER['W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm']('xx.weight', 'xx.bias')
mm_weight.set_config({'weight_auto_quant': True})
mm_weight.load(weight_dict)
input_tensor = torch.randn(1024, 4096).to(torch.bfloat16).cuda()
output_tensor = mm_weight.apply(input_tensor)
print(output_tensor.shape)
weight_dict = {
'xx.weight': torch.randn(8192, 4096),
'xx.bias': torch.randn(8192).to(torch.bfloat16),
}
mm_weight = MM_WEIGHT_REGISTER['W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm']('xx.weight', 'xx.bias')
mm_weight.set_config({'weight_auto_quant': True})
mm_weight.load(weight_dict)
input_tensor = torch.randn(1024, 4096).to(torch.bfloat16).cuda()
output_tensor = mm_weight.apply(input_tensor)
print(output_tensor.shape)