Commit daf4c74e authored by helloyongyang, committed by Yang Yong(雍洋)

first commit

parent 6c79160f
FROM nvidia/cuda:12.8.0-cudnn-devel-ubuntu22.04 AS base
WORKDIR /app
ENV DEBIAN_FRONTEND=noninteractive
ENV LANG=C.UTF-8
ENV LC_ALL=C.UTF-8
RUN sed -i 's|http://archive.ubuntu.com/ubuntu/|https://mirrors.tuna.tsinghua.edu.cn/ubuntu/|g' /etc/apt/sources.list && \
sed -i 's|http://security.ubuntu.com/ubuntu/|https://mirrors.tuna.tsinghua.edu.cn/ubuntu/|g' /etc/apt/sources.list
RUN apt-get update && \
apt-get install -y vim tmux zip unzip wget git cmake build-essential software-properties-common curl libibverbs-dev ca-certificates iproute2 ffmpeg libsm6 libxext6 && \
add-apt-repository ppa:deadsnakes/ppa && \
apt-get update && \
apt-get install -y python3.11 python3.11-venv python3.11-dev python3-pip && \
apt-get clean && rm -rf /var/lib/apt/lists/*
RUN update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.11 1 && \
update-alternatives --install /usr/bin/python python /usr/bin/python3.11 1
RUN pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
RUN pip install packaging ninja
RUN pip install vllm
RUN pip install torch torchvision
# FROM tmp-image AS base
WORKDIR /workspace
# download flash-attention source code
# git clone https://github.com/Dao-AILab/flash-attention.git --recursive
COPY flash-attention /workspace/flash-attention
RUN cd flash-attention && pip install --no-cache-dir -v -e .
RUN cd flash-attention/hopper && pip install --no-cache-dir -v -e .
RUN pip install diffusers transformers tokenizers accelerate safetensors opencv-python numpy imageio imageio-ffmpeg einops loguru
RUN pip install sgl-kernel
# FROM registry.cn-sh-01.sensecore.cn/devsft-ccr-2/video-gen:25030702 AS base
RUN pip install qtorch ftfy
# lightx2v

This is a video generation inference framework.

## Runtime environment
```
# internal registry image
docker pull registry.cn-sh-01.sensecore.cn/devsft-ccr-2/video-gen:25031303
docker run --gpus all -itd --ipc=host --name [name] -v /mnt:/mnt --entrypoint /bin/bash [image id]
```
## How to run
```
git clone https://gitlab.bj.sensetime.com/video-gen/lightx2v.git
cd lightx2v
# edit the parameters in the run script first
bash run_hunyuan_t2v.sh
```
#!/bin/bash
export PYTHONPATH="./":$PYTHONPATH
# trtexec \
# --onnx="/mnt/nvme0/wq/project/sd/code/lightx2v/vae_decoder_hf_sim.onnx" \
# --saveEngine="./vae_decoder_hf_sim.engine" \
# --allowWeightStreaming \
# --stronglyTyped \
# --fp16 \
# --weightStreamingBudget=100 \
# --minShapes=inp:1x16x9x18x16 \
# --optShapes=inp:1x16x17x32x16 \
# --maxShapes=inp:1x16x17x32x32
python examples/vae_trt/convert_vae_trt_engine.py --model_path "/mnt/nvme1/yongyang/models/hy/ckpts"
from pathlib import Path
import os
import argparse
import torch
from loguru import logger
from lightx2v.text2v.models.video_encoders.hf.autoencoder_kl_causal_3d.autoencoder_kl_causal_3d import AutoencoderKLCausal3D
from lightx2v.text2v.models.video_encoders.trt.autoencoder_kl_causal_3d.trt_vae_infer import HyVaeTrtModelInfer
def parse_args():
    args = argparse.ArgumentParser()
    args.add_argument("--model_path", help="path to the HunyuanVideo checkpoint directory", type=str)
    args.add_argument("--dtype", default=torch.float16)
    args.add_argument("--device", default="cuda", type=str)
    return args.parse_args()


def convert_vae_trt_engine(args):
    vae_path = os.path.join(args.model_path, 'hunyuan-video-t2v-720p/vae')
    assert Path(vae_path).exists(), f"{vae_path} does not exist."
    config = AutoencoderKLCausal3D.load_config(vae_path)
    model = AutoencoderKLCausal3D.from_config(config)
    ckpt_path = os.path.join(vae_path, 'pytorch_model.pt')
    assert Path(ckpt_path).exists(), f"{ckpt_path} does not exist."
    ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=True)
    model.load_state_dict(ckpt)
    model = model.to(dtype=args.dtype, device=args.device)
    # export the VAE decoder to ONNX, then build the TensorRT engine next to it
    onnx_path = HyVaeTrtModelInfer.export_to_onnx(model.decoder, vae_path)
    del model
    torch.cuda.empty_cache()
    engine_path = onnx_path.replace(".onnx", ".engine")
    HyVaeTrtModelInfer.convert_to_trt_engine(onnx_path, engine_path)
    logger.info(f"ONNX: {onnx_path}")
    logger.info(f"TRT Engine: {engine_path}")


def main():
    args = parse_args()
    convert_vae_trt_engine(args)


if __name__ == "__main__":
    main()
from lightx2v.attentions.common.torch_sdpa import torch_sdpa
from lightx2v.attentions.common.flash_attn2 import flash_attn2
from lightx2v.attentions.common.flash_attn3 import flash_attn3
from lightx2v.attentions.common.sage_attn2 import sage_attn2
def attention(attention_type="flash_attn2", *args, **kwargs):
    if attention_type == "torch_sdpa":
        return torch_sdpa(*args, **kwargs)
    elif attention_type == "flash_attn2":
        return flash_attn2(*args, **kwargs)
    elif attention_type == "flash_attn3":
        return flash_attn3(*args, **kwargs)
    elif attention_type == "sage_attn2":
        return sage_attn2(*args, **kwargs)
    else:
        raise NotImplementedError(f"Unsupported attention mode: {attention_type}")
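A hedged usage sketch of the dispatcher above for a single packed sequence; the shapes follow the flash-attention varlen convention used by the wrappers below (total_tokens, num_heads, head_dim), and the tensor sizes are made-up examples:

```
import torch
from lightx2v.attentions import attention

# example sizes only; q/k/v are packed as (total_tokens, num_heads, head_dim)
seq_len, num_heads, head_dim = 1024, 24, 128
q = torch.randn(seq_len, num_heads, head_dim, dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)
# one sequence, so the cumulative lengths are just [0, seq_len]
cu_seqlens = torch.tensor([0, seq_len], dtype=torch.int32, device="cuda")

out = attention(
    attention_type="flash_attn2",
    q=q, k=k, v=v,
    cu_seqlens_q=cu_seqlens, cu_seqlens_kv=cu_seqlens,
    max_seqlen_q=seq_len, max_seqlen_kv=seq_len,
)
# the flash_attn2 wrapper reshapes the result to (seq_len, num_heads * head_dim)
```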
try:
    from flash_attn.flash_attn_interface import flash_attn_varlen_func
except ImportError:
    flash_attn_varlen_func = None


def flash_attn2(q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None):
    x = flash_attn_varlen_func(
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_kv,
        max_seqlen_q,
        max_seqlen_kv,
    ).reshape(max_seqlen_q, -1)
    return x
try:
    from flash_attn_interface import flash_attn_varlen_func as flash_attn_varlen_func_v3
except ImportError:
    flash_attn_varlen_func_v3 = None


def flash_attn3(q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None):
    # the flash-attention 3 interface returns a tuple, so take the attention output first
    x = flash_attn_varlen_func_v3(
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_kv,
        max_seqlen_q,
        max_seqlen_kv,
    )[0].reshape(max_seqlen_q, -1)
    return x
import torch

try:
    from sageattention import sageattn
except ImportError:
    sageattn = None


def sage_attn2(q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None):
    # reshape to (heads, tokens, dim), split the packed sequence at the first
    # cu_seqlens boundary, and run sageattn on each segment separately
    q, k, v = (
        q.transpose(1, 0).contiguous(),
        k.transpose(1, 0).contiguous(),
        v.transpose(1, 0).contiguous(),
    )
    x1 = sageattn(
        q[:, : cu_seqlens_q[1], :].unsqueeze(0),
        k[:, : cu_seqlens_kv[1], :].unsqueeze(0),
        v[:, : cu_seqlens_kv[1], :].unsqueeze(0),
    )
    x2 = sageattn(
        q[:, cu_seqlens_q[1] :, :].unsqueeze(0),
        k[:, cu_seqlens_kv[1] :, :].unsqueeze(0),
        v[:, cu_seqlens_kv[1] :, :].unsqueeze(0),
    )
    x = torch.cat((x1, x2), dim=-2).transpose(2, 1).contiguous()
    x = x.view(max_seqlen_q, -1)
    return x
import torch
import torch.nn.functional as F
def torch_sdpa(
    q,
    k,
    v,
    drop_rate=0,
    attn_mask=None,
    causal=False,
):
    # (batch, seq_len, heads, dim) -> (batch, heads, seq_len, dim) for SDPA
    q = q.transpose(1, 2)
    k = k.transpose(1, 2)
    v = v.transpose(1, 2)
    if attn_mask is not None and attn_mask.dtype != torch.bool:
        attn_mask = attn_mask.to(q.dtype)
    x = F.scaled_dot_product_attention(
        q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal
    )
    # back to (batch, seq_len, heads, dim), then merge the head and dim axes
    x = x.transpose(1, 2)
    b, s, a, d = x.shape
    out = x.reshape(b, s, -1)
    return out
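For comparison, a minimal sketch of calling the SDPA fallback directly; unlike the varlen wrappers it takes batched tensors, and the sizes here are arbitrary examples:

```
import torch
from lightx2v.attentions.common.torch_sdpa import torch_sdpa

# (batch, seq_len, num_heads, head_dim) in, (batch, seq_len, num_heads * head_dim) out
q = torch.randn(1, 1024, 24, 128, dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)
out = torch_sdpa(q, k, v)
```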
import torch
import torch.distributed as dist
def all2all_seq2head(input):
    '''
    Convert the input tensor from [seq_len/N, heads, hidden_dims] to [seq_len, heads/N, hidden_dims].

    Args:
        input (torch.Tensor): input tensor of shape [seq_len/N, heads, hidden_dims]

    Returns:
        torch.Tensor: output tensor of shape [seq_len, heads/N, hidden_dims]
    '''
    # the input must be a 3D tensor
    assert input.dim() == 3, "input must be 3D tensor"
    # size of the process group
    world_size = dist.get_world_size()
    # shape of the input tensor
    shard_seq_len, heads, hidden_dims = input.shape
    seq_len = shard_seq_len * world_size  # total sequence length
    shard_heads = heads // world_size  # number of heads handled by each rank
    # reshape the input for the all-to-all operation
    input_t = (
        input.reshape(shard_seq_len, world_size, shard_heads, hidden_dims)  # [shard_seq_len, world_size, shard_heads, hidden_dims]
        .transpose(0, 1)  # move the world_size axis first for all-to-all
        .contiguous()  # make the memory layout contiguous
    )
    # allocate an output tensor with the same shape as the input
    output = torch.empty_like(input_t)
    # run all-to-all to exchange chunks across all ranks
    dist.all_to_all_single(output, input_t)
    # reshape the output to [seq_len, heads/N, hidden_dims]
    output = output.reshape(seq_len, shard_heads, hidden_dims).contiguous()
    return output  # return the converted tensor


def all2all_head2seq(input):
    '''
    Convert the input tensor from [seq_len, heads/N, hidden_dims] back to [seq_len/N, heads, hidden_dims].

    Args:
        input (torch.Tensor): input tensor of shape [seq_len, heads/N, hidden_dims]

    Returns:
        torch.Tensor: output tensor of shape [seq_len/N, heads, hidden_dims]
    '''
    # the input must be a 3D tensor
    assert input.dim() == 3, "input must be 3D tensor"
    # size of the process group
    world_size = dist.get_world_size()
    # shape of the input tensor
    seq_len, shard_heads, hidden_dims = input.shape
    heads = shard_heads * world_size  # total number of heads
    shard_seq_len = seq_len // world_size  # sequence length handled by each rank
    # reshape the input for the all-to-all operation
    input_t = (
        input.reshape(world_size, shard_seq_len, shard_heads, hidden_dims)  # [world_size, shard_seq_len, shard_heads, hidden_dims]
        .transpose(1, 2)  # swap the sequence and head axes
        .contiguous()  # make the memory layout contiguous
        .reshape(world_size, shard_heads, shard_seq_len, hidden_dims)  # [world_size, shard_heads, shard_seq_len, hidden_dims]
    )
    # allocate an output tensor with the same shape as the input
    output = torch.empty_like(input_t)
    # run all-to-all to exchange chunks across all ranks
    dist.all_to_all_single(output, input_t)
    # reshape the output to [heads, shard_seq_len, hidden_dims]
    output = output.reshape(heads, shard_seq_len, hidden_dims)
    # transpose and reshape to [shard_seq_len, heads, hidden_dims]
    output = output.transpose(0, 1).contiguous().reshape(shard_seq_len, heads, hidden_dims)
    return output  # return the converted tensor
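A minimal round-trip sanity check for the two helpers above, meant to be launched with torchrun on GPUs with an NCCL process group; the import path is a hypothetical placeholder and the sizes are example values:

```
import torch
import torch.distributed as dist
# hypothetical import path for the two helpers defined above
from lightx2v.attentions.distributed.comm import all2all_seq2head, all2all_head2seq

dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank())

# per-rank shard: [seq_len/N, heads, hidden_dims]; heads must be divisible by world size
x = torch.randn(512, 24, 128, device="cuda")
y = all2all_seq2head(x)   # [seq_len, heads/N, hidden_dims] on every rank
z = all2all_head2seq(y)   # back to [seq_len/N, heads, hidden_dims]
assert torch.allclose(x, z), "the seq2head/head2seq round trip should be the identity"
```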
import torch
import torch.distributed as dist
from lightx2v.attentions import attention
def partial_heads_attn(attention_type, q, k, v, cu_seqlens_qkv, max_seqlen_qkv):
    num_heads = q.shape[-2]
    cur_rank = dist.get_rank()
    world_size = dist.get_world_size()
    num_chunk_heads = num_heads // world_size
    # each rank keeps its own contiguous slice of heads; the last rank takes the tail
    if cur_rank == world_size - 1:
        q = q[:, num_chunk_heads * cur_rank:, :]
        k = k[:, num_chunk_heads * cur_rank:, :]
        v = v[:, num_chunk_heads * cur_rank:, :]
    else:
        q = q[:, num_chunk_heads * cur_rank:num_chunk_heads * (cur_rank + 1), :]
        k = k[:, num_chunk_heads * cur_rank:num_chunk_heads * (cur_rank + 1), :]
        v = v[:, num_chunk_heads * cur_rank:num_chunk_heads * (cur_rank + 1), :]
    output = attention(
        attention_type=attention_type,
        q=q,
        k=k,
        v=v,
        cu_seqlens_q=cu_seqlens_qkv,
        cu_seqlens_kv=cu_seqlens_qkv,
        max_seqlen_q=max_seqlen_qkv,
        max_seqlen_kv=max_seqlen_qkv,
    )
    # gather every rank's head outputs and concatenate them along the feature axis
    gathered_outputs = [torch.empty_like(output) for _ in range(world_size)]
    dist.all_gather(gathered_outputs, output)
    combined_output = torch.cat(gathered_outputs, dim=1)
    return combined_output
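A sketch of how partial_heads_attn might be driven from a torchrun launch (see the test script below); the import path is hypothetical, flash-attention 2 is assumed to be installed, and the tensor sizes are example values:

```
import torch
import torch.distributed as dist
# hypothetical import path for the function defined above
from lightx2v.attentions.distributed.partial_heads import partial_heads_attn

dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank())

# every rank holds the full packed q/k/v: (total_tokens, num_heads, head_dim)
seq_len, num_heads, head_dim = 1024, 24, 128
q = torch.randn(seq_len, num_heads, head_dim, dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)
cu_seqlens = torch.tensor([0, seq_len], dtype=torch.int32, device="cuda")

# each rank attends over its own slice of heads, then the slices are all-gathered,
# so out has shape (seq_len, num_heads * head_dim) on every rank
out = partial_heads_attn("flash_attn2", q, k, v, cu_seqlens, seq_len)
```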
export PYTHONPATH=/workspace/lightx2v:$PYTHONPATH
export CUDA_VISIBLE_DEVICES=0,1
torchrun --nproc_per_node=2 test_acc.py