Commit a50bcc53 authored by Dongz, committed by Yang Yong(雍洋)

add lint feature and minor fix (#7)

* [minor]: optimize dockerfile for fewer layers

* [feature]: add pre-commit lint, update readme for contribution guidance

* [minor]: fix run shell privileges

* [auto]: first lint without rule F, fix rule E

* [minor]: fix docker file error
parent 3b460075
# Follow https://verdantfox.com/blog/how-to-use-git-pre-commit-hooks-the-hard-way-and-the-easy-way
repos:
  - repo: https://github.com/astral-sh/ruff-pre-commit
    rev: v0.11.0
    hooks:
      - id: ruff
        args: [--fix, --respect-gitignore, --config=pyproject.toml]
      - id: ruff-format
        args: [--config=pyproject.toml]
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.5.0
    hooks:
      - id: trailing-whitespace
      - id: end-of-file-fixer
      - id: check-yaml
      - id: check-toml
      - id: check-added-large-files
      - id: check-case-conflict
      - id: check-merge-conflict
      - id: debug-statements
FROM nvidia/cuda:12.8.0-cudnn-devel-ubuntu22.04 AS base
WORKDIR /workspace
ENV DEBIAN_FRONTEND=noninteractive
ENV LANG=C.UTF-8
ENV LC_ALL=C.UTF-8
# use tsinghua source
RUN sed -i 's|http://archive.ubuntu.com/ubuntu/|https://mirrors.tuna.tsinghua.edu.cn/ubuntu/|g' /etc/apt/sources.list \
    && sed -i 's|http://security.ubuntu.com/ubuntu/|https://mirrors.tuna.tsinghua.edu.cn/ubuntu/|g' /etc/apt/sources.list
RUN apt-get update && apt install -y software-properties-common \
    && add-apt-repository ppa:deadsnakes/ppa \
    && apt-get update \
    && apt-get install -y vim tmux zip unzip wget git cmake build-essential \
    curl libibverbs-dev ca-certificates iproute2 \
    ffmpeg libsm6 libxext6 \
    && apt-get install -y python3.11 python3.11-venv python3.11-dev python3-pip \
    && apt-get clean && rm -rf /var/lib/apt/lists/*
RUN update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.11 1 \
    && update-alternatives --install /usr/bin/python python /usr/bin/python3.11 1
RUN pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple \
    && pip install packaging ninja vllm torch torchvision diffusers transformers \
    tokenizers accelerate safetensors opencv-python numpy imageio imageio-ffmpeg \
    einops loguru sgl-kernel qtorch ftfy
# please download flash-attention source code first
# git clone https://github.com/Dao-AILab/flash-attention.git --recursive
# todo: add third party repo feature
COPY flash-attention /workspace/flash-attention
# install flash-attention 2
RUN cd flash-attention && pip install --no-cache-dir -v -e .
# install flash-attention 3, only if hopper
RUN cd flash-attention/hopper && pip install --no-cache-dir -v -e .
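For reference, a minimal sketch of building and entering a container from this Dockerfile; the image tag and container name below are illustrative placeholders, not names defined by the project:

```shell
# the COPY step expects the flash-attention sources next to the Dockerfile
git clone https://github.com/Dao-AILab/flash-attention.git --recursive
# build the image (the tag is a placeholder)
docker build -t lightx2v:dev .
# start an interactive container with GPU access
docker run --gpus all -itd --ipc=host --name lightx2v-dev -v /mnt:/mnt --entrypoint /bin/bash lightx2v:dev
```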
# LightX2V: Light Video Generation Inference Framework
<div align="center">
<picture>
...@@ -8,21 +8,37 @@
--------------------------------------------------------------------------------
## Prepare Environment
```shell
docker pull registry.cn-sh-01.sensecore.cn/devsft-ccr-2/video-gen:25033101
docker run --gpus all -itd --ipc=host --name [name] -v /mnt:/mnt --entrypoint /bin/bash [image id]
```
## Fast Start
```shell
git clone https://github.com/ModelTC/lightx2v.git
cd lightx2v
# Modify the parameters of the running script
bash run_hunyuan_t2v.sh
```
## Contribute
We have prepared a `pre-commit` hook to enforce consistent code formatting across the project. You can clean up your code by following the steps below:
1. Install the required dependencies:
```shell
pip install ruff pre-commit
```
2. Then, run the following command:
```shell
pre-commit run --all-files
```
If your code complies with the standards, you should not see any errors.
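As an optional extra (standard `pre-commit` usage, not a project-specific command), you can register the hook so the same checks run automatically before every commit:

```shell
pre-commit install            # registers the hook in .git/hooks
git commit -m "your message"  # ruff and the hygiene checks now run on each commit
```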
...@@ -11,4 +11,4 @@ export PYTHONPATH="./":$PYTHONPATH
# --optShapes=inp:1x16x17x32x16 \
# --maxShapes=inp:1x16x17x32x32
python examples/vae_trt/convert_vae_trt_engine.py --model_path "/mnt/nvme1/yongyang/models/hy/ckpts"
...@@ -18,12 +18,12 @@ def parse_args():
def convert_vae_trt_engine(args):
    vae_path = os.path.join(args.model_path, "hunyuan-video-t2v-720p/vae")
    assert Path(vae_path).exists(), f"{vae_path} not exists."
    config = AutoencoderKLCausal3D.load_config(vae_path)
    model = AutoencoderKLCausal3D.from_config(config)
    assert Path(os.path.join(vae_path, "pytorch_model.pt")).exists(), f"{os.path.join(vae_path, 'pytorch_model.pt')} not exists."
    ckpt = torch.load(os.path.join(vae_path, "pytorch_model.pt"), map_location="cpu", weights_only=True)
    model.load_state_dict(ckpt)
    model = model.to(dtype=args.dtype, device=args.device)
    onnx_path = HyVaeTrtModelInfer.export_to_onnx(model.decoder, vae_path)
...
...@@ -28,7 +28,7 @@ from lightx2v.image2v.models.wan.model import CLIPModel
def load_models(args, model_config):
    if model_config["parallel_attn"]:
        cur_rank = dist.get_rank()  # rank of the current process
        torch.cuda.set_device(cur_rank)  # bind this process to its CUDA device
    image_encoder = None
...@@ -56,13 +56,13 @@ def load_models(args, model_config):
        text_encoders = [text_encoder]
        model = WanModel(args.model_path, model_config)
        vae_model = WanVAE(vae_pth=os.path.join(args.model_path, "Wan2.1_VAE.pth"), device=init_device, parallel=args.parallel_vae)
        if args.task == "i2v":
            image_encoder = CLIPModel(
                dtype=torch.float16,
                device=init_device,
                checkpoint_path=os.path.join(args.model_path, "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"),
                tokenizer_path=os.path.join(args.model_path, "xlm-roberta-large"),
            )
    else:
        raise NotImplementedError(f"Unsupported model class: {args.model_cls}")
...@@ -70,7 +70,7 @@ def load_models(args, model_config):
def set_target_shape(args):
    if args.model_cls == "hunyuan":
        vae_scale_factor = 2 ** (4 - 1)
        args.target_shape = (
            1,
...@@ -79,15 +79,10 @@ def set_target_shape(args):
            int(args.target_height) // vae_scale_factor,
            int(args.target_width) // vae_scale_factor,
        )
    elif args.model_cls == "wan2.1":
        if args.task == "i2v":
            args.target_shape = (16, 21, args.lat_h, args.lat_w)
        elif args.task == "t2v":
            args.target_shape = (
                16,
                (args.target_video_length - 1) // 4 + 1,
...@@ -99,7 +94,7 @@ def set_target_shape(args):
def run_image_encoder(args, image_encoder, vae_model):
    if args.model_cls == "hunyuan":
        return None
    elif args.model_cls == "wan2.1":
        img = Image.open(args.image_path).convert("RGB")
        img = TF.to_tensor(img).sub_(0.5).div_(0.5).cuda()
        clip_encoder_out = image_encoder.visual([img[:, None, :, :]]).squeeze(0).to(torch.bfloat16)
...@@ -107,34 +102,21 @@ def run_image_encoder(args, image_encoder, vae_model):
        h, w = img.shape[1:]
        aspect_ratio = h / w
        max_area = args.target_height * args.target_width
        lat_h = round(np.sqrt(max_area * aspect_ratio) // args.vae_stride[1] // args.patch_size[1] * args.patch_size[1])
        lat_w = round(np.sqrt(max_area / aspect_ratio) // args.vae_stride[2] // args.patch_size[2] * args.patch_size[2])
        h = lat_h * args.vae_stride[1]
        w = lat_w * args.vae_stride[2]
        args.lat_h = lat_h
        args.lat_w = lat_w
        msk = torch.ones(1, 81, lat_h, lat_w, device=torch.device("cuda"))
        msk[:, 1:] = 0
        msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
        msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
        msk = msk.transpose(1, 2)[0]
        vae_encode_out = vae_model.encode([torch.concat([torch.nn.functional.interpolate(img[None].cpu(), size=(h, w), mode="bicubic").transpose(0, 1), torch.zeros(3, 80, h, w)], dim=1).cuda()])[0]
        vae_encode_out = torch.concat([msk, vae_encode_out]).to(torch.bfloat16)
        return {"clip_encoder_out": clip_encoder_out, "vae_encode_out": vae_encode_out}
...@@ -147,8 +129,8 @@ def run_text_encoder(args, text, text_encoders, model_config):
    if args.model_cls == "hunyuan":
        for i, encoder in enumerate(text_encoders):
            text_state, attention_mask = encoder.infer(text, args)
            text_encoder_output[f"text_encoder_{i + 1}_text_states"] = text_state.to(dtype=torch.bfloat16)
            text_encoder_output[f"text_encoder_{i + 1}_attention_mask"] = attention_mask
    elif args.model_cls == "wan2.1":
        n_prompt = model_config.get("sample_neg_prompt", "")
...@@ -186,7 +168,6 @@ def init_scheduler(args):
def run_main_inference(args, model, text_encoder_output, image_encoder_output):
    for step_index in range(model.scheduler.infer_steps):
        torch.cuda.synchronize()
        time1 = time.time()
...@@ -225,7 +206,7 @@ if __name__ == "__main__":
    parser.add_argument("--model_path", type=str, required=True)
    parser.add_argument("--config_path", type=str, default=None)
    parser.add_argument("--image_path", type=str, default=None)
    parser.add_argument("--save_video_path", type=str, default="./output_ligthx2v.mp4")
    parser.add_argument("--prompt", type=str, required=True)
    parser.add_argument("--infer_steps", type=int, required=True)
    parser.add_argument("--target_video_length", type=int, required=True)
...@@ -235,27 +216,27 @@ if __name__ == "__main__":
    parser.add_argument("--sample_neg_prompt", type=str, default="")
    parser.add_argument("--sample_guide_scale", type=float, default=5.0)
    parser.add_argument("--sample_shift", type=float, default=5.0)
    parser.add_argument("--do_mm_calib", action="store_true")
    parser.add_argument("--cpu_offload", action="store_true")
    parser.add_argument("--feature_caching", choices=["NoCaching", "TaylorSeer", "Tea"], default="NoCaching")
    parser.add_argument("--mm_config", default=None)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--parallel_attn", action="store_true")
    parser.add_argument("--parallel_vae", action="store_true")
    parser.add_argument("--max_area", action="store_true")
    parser.add_argument("--vae_stride", default=(4, 8, 8))
    parser.add_argument("--patch_size", default=(1, 2, 2))
    parser.add_argument("--teacache_thresh", type=float, default=0.26)
    parser.add_argument("--use_ret_steps", action="store_true", default=False)
    args = parser.parse_args()
    start_time = time.time()
    print(f"args: {args}")
    seed_all(args.seed)
    if args.parallel_attn:
        dist.init_process_group(backend="nccl")
    if args.mm_config:
        mm_config = json.loads(args.mm_config)
...@@ -271,7 +252,7 @@ if __name__ == "__main__":
        "cpu_offload": args.cpu_offload,
        "feature_caching": args.feature_caching,
        "parallel_attn": args.parallel_attn,
        "parallel_vae": args.parallel_vae,
    }
    if args.config_path is not None:
...@@ -283,7 +264,7 @@ if __name__ == "__main__":
    model, text_encoders, vae_model, image_encoder = load_models(args, model_config)
    if args.task in ["i2v"]:
        image_encoder_output = run_image_encoder(args, image_encoder, vae_model)
    else:
        image_encoder_output = {"clip_encoder_out": None, "vae_encode_out": None}
...
...@@ -3,17 +3,15 @@ from lightx2v.attentions.common.flash_attn2 import flash_attn2
from lightx2v.attentions.common.flash_attn3 import flash_attn3
from lightx2v.attentions.common.sage_attn2 import sage_attn2
def attention(attention_type="flash_attn2", *args, **kwargs):
    if attention_type == "torch_sdpa":
        return torch_sdpa(*args, **kwargs)
    elif attention_type == "flash_attn2":
        return flash_attn2(*args, **kwargs)
    elif attention_type == "flash_attn3":
        return flash_attn3(*args, **kwargs)
    elif attention_type == "sage_attn2":
        return sage_attn2(*args, **kwargs)
    else:
        raise NotImplementedError(f"Unsupported attention mode: {attention_type}")
...@@ -3,15 +3,8 @@ try:
except ImportError:
    flash_attn_varlen_func = None
def flash_attn2(q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None):
    x = flash_attn_varlen_func(
        q,
        k,
...
...@@ -3,15 +3,8 @@ try:
except ImportError:
    flash_attn_varlen_func_v3 = None
def flash_attn3(q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None):
    x = flash_attn_varlen_func_v3(
        q,
        k,
...
import torch
try:
    from sageattention import sageattn
except ImportError:
...
...@@ -15,9 +15,7 @@ def torch_sdpa(
    v = v.transpose(1, 2)
    if attn_mask is not None and attn_mask.dtype != torch.bool:
        attn_mask = attn_mask.to(q.dtype)
    x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal)
    x = x.transpose(1, 2)
    b, s, a, d = x.shape
    out = x.reshape(b, s, -1)
...
...@@ -3,7 +3,7 @@ import torch.distributed as dist
def all2all_seq2head(input):
    """
    Convert the input tensor from [seq_len/N, heads, hidden_dims] to [seq_len, heads/N, hidden_dims].
    Args:
...@@ -11,9 +11,9 @@ def all2all_seq2head(input):
    Returns:
        torch.Tensor: the converted output tensor of shape [seq_len, heads/N, hidden_dims]
    """
    # make sure the input is a 3D tensor
    assert input.dim() == 3, f"input must be 3D tensor"
    # get the world size of the current process group
    world_size = dist.get_world_size()
...@@ -43,7 +43,7 @@ def all2all_seq2head(input):
def all2all_head2seq(input):
    """
    Convert the input tensor from [seq_len, heads/N, hidden_dims] to [seq_len/N, heads, hidden_dims].
    Args:
...@@ -51,9 +51,9 @@ def all2all_head2seq(input):
    Returns:
        torch.Tensor: the converted output tensor of shape [seq_len/N, heads, hidden_dims]
    """
    # make sure the input is a 3D tensor
    assert input.dim() == 3, f"input must be 3D tensor"
    # get the world size of the current process group
    world_size = dist.get_world_size()
...@@ -84,4 +84,3 @@ def all2all_head2seq(input):
    output = output.transpose(0, 1).contiguous().reshape(shard_seq_len, heads, hidden_dims)
    return output  # return the converted output tensor
...@@ -9,14 +9,14 @@ def partial_heads_attn(attention_type, q, k, v, cu_seqlens_qkv, max_seqlen_qkv):
    world_size = dist.get_world_size()
    num_chunk_heads = int(num_heads / dist.get_world_size())
    if cur_rank == world_size - 1:
        q = q[:, num_chunk_heads * cur_rank :, :]
        k = k[:, num_chunk_heads * cur_rank :, :]
        v = v[:, num_chunk_heads * cur_rank :, :]
    else:
        q = q[:, num_chunk_heads * cur_rank : num_chunk_heads * (cur_rank + 1), :]
        k = k[:, num_chunk_heads * cur_rank : num_chunk_heads * (cur_rank + 1), :]
        v = v[:, num_chunk_heads * cur_rank : num_chunk_heads * (cur_rank + 1), :]
    output = attention(
        attention_type=attention_type,
...@@ -34,4 +34,4 @@ def partial_heads_attn(attention_type, q, k, v, cu_seqlens_qkv, max_seqlen_qkv):
    combined_output = torch.cat(gathered_outputs, dim=1)
    return combined_output
export PYTHONPATH=/workspace/lightx2v:$PYTHONPATH
export CUDA_VISIBLE_DEVICES=0,1
torchrun --nproc_per_node=2 test_acc.py
...@@ -14,16 +14,14 @@ def prepare_tensors():
    k = torch.randn(32656, 24, 128, dtype=torch.bfloat16).cuda()
    v = torch.randn(32656, 24, 128, dtype=torch.bfloat16).cuda()
    cu_seqlens_qkv = torch.tensor([0, 32411, 32656], dtype=torch.int32).cuda()
    max_seqlen_qkv = 32656
    return q, k, v, cu_seqlens_qkv, max_seqlen_qkv
def test_part_head():
    q, k, v, cu_seqlens_qkv, max_seqlen_qkv = prepare_tensors()
    # first compute the full result as a reference
    single_gpu_output = attention(
        q=q,
...@@ -39,17 +37,16 @@ def test_part_head():
    cur_rank = dist.get_rank()
    world_size = dist.get_world_size()
    num_chunk_heads = int(num_heads / dist.get_world_size())
    if cur_rank == world_size - 1:
        q = q[:, num_chunk_heads * cur_rank :, :]
        k = k[:, num_chunk_heads * cur_rank :, :]
        v = v[:, num_chunk_heads * cur_rank :, :]
    else:
        q = q[:, num_chunk_heads * cur_rank : num_chunk_heads * (cur_rank + 1), :]
        k = k[:, num_chunk_heads * cur_rank : num_chunk_heads * (cur_rank + 1), :]
        v = v[:, num_chunk_heads * cur_rank : num_chunk_heads * (cur_rank + 1), :]
    output = attention(
        q=q,
        k=k,
...@@ -69,12 +66,12 @@ def test_part_head():
    if cur_rank == 0:
        # import pdb; pdb.set_trace()
        print("Outputs match:", torch.allclose(single_gpu_output, combined_output, rtol=1e-3, atol=1e-3))
        # # verify consistency of the results
        # print("Outputs match:", torch.allclose(single_gpu_output, combined_output, rtol=1e-3, atol=1e-3))
if __name__ == "__main__":
    # initialize the distributed environment
    dist.init_process_group(backend="nccl")
    test_part_head()
...@@ -2,4 +2,4 @@ from lightx2v.attentions.distributed.partial_heads_attn.attn import partial_head
def parallelize_hunyuan(hunyuan_model):
    hunyuan_model.transformer_infer.parallel_attention = partial_heads_attn
...@@ -3,9 +3,8 @@ import torch.distributed as dist
from lightx2v.attentions import attention
def ring_attn(q, k, v, img_qkv_len, cu_seqlens_qkv, attention_type="flash_attn2"):
    """
    Run the Ulysses attention mechanism, combining the image and text queries, keys and values.
    Args:
...@@ -18,7 +17,7 @@ def ring_attn(q, k, v, img_qkv_len, cu_seqlens_qkv, attention_type="flash_attn2"
    Returns:
        torch.Tensor: the computed attention result
    """
    # get the rank of the current process and the total number of processes
    cur_rank = dist.get_rank()
    world_size = dist.get_world_size()
...@@ -27,15 +26,15 @@ def ring_attn(q, k, v, img_qkv_len, cu_seqlens_qkv, attention_type="flash_attn2"
    seq_len = q.shape[0]
    txt_qkv_len = cu_seqlens_qkv[1] - img_qkv_len  # length of the text queries, keys and values
    txt_mask_len = cu_seqlens_qkv[2] - img_qkv_len  # length of the text mask
    # number of heads and hidden dims of the query tensor
    _, heads, hidden_dims = q.shape
    shard_heads = heads // world_size  # heads handled by each process
    shard_seqlen = img_qkv_len  # sequence length handled by each process
    # split the image and text queries, keys and values
    img_q, img_k, img_v = q[:img_qkv_len, :, :].contiguous(), k[:img_qkv_len, :, :].contiguous(), v[:img_qkv_len, :, :].contiguous()
    txt_q, txt_k, txt_v = q[img_qkv_len:, :, :].contiguous(), k[img_qkv_len:, :, :].contiguous(), v[img_qkv_len:, :, :].contiguous()
    gathered_img_k = [torch.empty_like(img_k) for _ in range(world_size)]
    gathered_img_v = [torch.empty_like(img_v) for _ in range(world_size)]
...@@ -45,8 +44,8 @@ def ring_attn(q, k, v, img_qkv_len, cu_seqlens_qkv, attention_type="flash_attn2"
    torch.cuda.synchronize()
    q = q
    k = torch.cat(gathered_img_k + [txt_k], dim=0)
    v = torch.cat(gathered_img_v + [txt_v], dim=0)
    # initialize the cumulative sequence length tensor
    cu_seqlens_q = torch.zeros([3], dtype=torch.int32, device="cuda")
...@@ -59,22 +58,13 @@ def ring_attn(q, k, v, img_qkv_len, cu_seqlens_qkv, attention_type="flash_attn2"
    # initialize the cumulative sequence length tensor
    cu_seqlens_kv = torch.zeros([3], dtype=torch.int32, device="cuda")
    s = txt_qkv_len + img_k.shape[0] * world_size  # total length of text and image
    s1 = s  # end position of the current sample
    s2 = txt_mask_len + img_k.shape[0] * world_size  # end position of the text mask
    cu_seqlens_kv[1] = s1  # set the cumulative sequence length
    cu_seqlens_kv[2] = s2  # set the cumulative sequence length
    max_seqlen_kv = img_k.shape[0] * world_size + txt_q.shape[0]  # maximum sequence length
    attn = attention(attention_type=attention_type, q=q, k=k, v=v, cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_kv, max_seqlen_q=max_seqlen_q, max_seqlen_kv=max_seqlen_kv)
    return attn
...@@ -34,14 +34,10 @@ def parallelize_hunyuan(hunyuan_model):
            combined_output: the post-processed output result
        """
        # pre-process the input data
        latent_model_input, freqs_cos, freqs_sin, split_dim = pre_process(latent_model_input, freqs_cos, freqs_sin)
        # call the original inference method to get the output
        output = original_infer(latent_model_input, t_expand, text_states, text_mask, text_states_2, freqs_cos, freqs_sin, guidance)
        # post-process the output
        combined_output = post_process(output, split_dim)
...@@ -50,4 +46,4 @@ def parallelize_hunyuan(hunyuan_model):
    # bind the new inference method to the Hunyuan model instance
    new_infer = new_infer.__get__(hunyuan_model)
    hunyuan_model.infer = new_infer  # replace the original inference method
import torch
import torch.distributed as dist
from lightx2v.attentions import attention
...@@ -6,7 +5,7 @@ from lightx2v.attentions.distributed.comm.all2all import all2all_seq2head, all2a
def ulysses_attn(q, k, v, img_qkv_len, cu_seqlens_qkv, attention_type="flash_attn2"):
    """
    Run the Ulysses attention mechanism, combining the image and text queries, keys and values.
    Args:
...@@ -19,11 +18,11 @@ def ulysses_attn(q, k, v, img_qkv_len, cu_seqlens_qkv, attention_type="flash_att
    Returns:
        torch.Tensor: the computed attention result
    """
    # get the rank of the current process and the total number of processes
    cur_rank = dist.get_rank()
    world_size = dist.get_world_size()
    # get the sequence length and the text-related lengths
    seq_len = q.shape[0]
    if len(cu_seqlens_qkv) == 3:
...@@ -32,15 +31,15 @@ def ulysses_attn(q, k, v, img_qkv_len, cu_seqlens_qkv, attention_type="flash_att
    elif len(cu_seqlens_qkv) == 2:
        txt_qkv_len = cu_seqlens_qkv[1] - img_qkv_len  # length of the text queries, keys and values
        txt_mask_len = None
    # number of heads and hidden dims of the query tensor
    _, heads, hidden_dims = q.shape
    shard_heads = heads // world_size  # heads handled by each process
    shard_seqlen = img_qkv_len  # sequence length handled by each process
    # split the image and text queries, keys and values
    img_q, img_k, img_v = q[:img_qkv_len, :, :].contiguous(), k[:img_qkv_len, :, :].contiguous(), v[:img_qkv_len, :, :].contiguous()
    txt_q, txt_k, txt_v = q[img_qkv_len:, :, :].contiguous(), k[img_qkv_len:, :, :].contiguous(), v[img_qkv_len:, :, :].contiguous()
    # convert the image queries, keys and values to the head-sharded layout
    img_q = all2all_seq2head(img_q)
...@@ -49,9 +48,9 @@ def ulysses_attn(q, k, v, img_qkv_len, cu_seqlens_qkv, attention_type="flash_att
    torch.cuda.synchronize()  # make sure the CUDA ops have finished
    # for the text queries, keys and values, select the heads owned by the current process
    txt_q = txt_q[:, cur_rank * shard_heads : (cur_rank + 1) * shard_heads, :]
    txt_k = txt_k[:, cur_rank * shard_heads : (cur_rank + 1) * shard_heads, :]
    txt_v = txt_v[:, cur_rank * shard_heads : (cur_rank + 1) * shard_heads, :]
    # concatenate the image and text queries, keys and values
    q = torch.cat((img_q, txt_q), dim=0)
...@@ -69,26 +68,17 @@ def ulysses_attn(q, k, v, img_qkv_len, cu_seqlens_qkv, attention_type="flash_att
    max_seqlen_qkv = img_q.shape[0] + txt_q.shape[0]  # maximum sequence length
    # run the attention function to compute the result
    attn = attention(attention_type=attention_type, q=q, k=k, v=v, cu_seqlens_q=cu_seqlens_qkv, cu_seqlens_kv=cu_seqlens_qkv, max_seqlen_q=max_seqlen_qkv, max_seqlen_kv=max_seqlen_qkv)
    # split the attention result into image and text parts
    img_attn, txt_attn = attn[: img_q.shape[0], :], attn[img_q.shape[0] :,]
    # gather the text attention results from all processes
    gathered_txt_attn = [torch.empty_like(txt_attn) for _ in range(world_size)]
    dist.all_gather(gathered_txt_attn, txt_attn)
    # process the image attention result
    img_attn = img_attn.reshape(world_size * shard_seqlen, shard_heads, hidden_dims)  # reshape the image attention result
    img_attn = all2all_head2seq(img_attn)  # convert back from the head-sharded to the sequence-sharded layout
    img_attn = img_attn.reshape(shard_seqlen, -1)  # reshape to [shard_seqlen, -1]
...@@ -98,4 +88,4 @@ def ulysses_attn(q, k, v, img_qkv_len, cu_seqlens_qkv, attention_type="flash_att
    # concatenate the image and text attention results
    attn = torch.cat([img_attn, txt_attn], dim=0)
    return attn  # return the final attention result
import functools
from lightx2v.attentions.distributed.ulysses.attn import ulysses_attn
def parallelize_hunyuan(hunyuan_model):
    from lightx2v.attentions.distributed.utils.hunyuan.processor import pre_process, post_process
    """Parallelize inference of the Hunyuan model using the Ulysses attention mechanism.
    Args:
...@@ -27,31 +29,19 @@ def parallelize_hunyuan(hunyuan_model):
        None
    """
        # save the original latent model input and frequency data
        self.scheduler.ori_latents, self.scheduler.ori_freqs_cos, self.scheduler.ori_freqs_sin = (self.scheduler.latents, self.scheduler.freqs_cos, self.scheduler.freqs_sin)
        # pre-process the inputs for parallel computation
        self.scheduler.latents, self.scheduler.freqs_cos, self.scheduler.freqs_sin, split_dim = pre_process(self.scheduler.latents, self.scheduler.freqs_cos, self.scheduler.freqs_sin)
        # call the original inference method to get the output
        original_infer(text_encoders_output, image_encoder_output, args)
        # post-process the output
        self.scheduler.noise_pred = post_process(self.scheduler.noise_pred, split_dim)
        # restore the original latent model input and frequency data
        self.scheduler.latents, self.scheduler.freqs_cos, self.scheduler.freqs_sin = (self.scheduler.ori_latents, self.scheduler.ori_freqs_cos, self.scheduler.ori_freqs_sin)
        # return combined_output  # return the processed output (currently commented out)
...@@ -62,24 +52,20 @@ def parallelize_hunyuan(hunyuan_model):
def parallelize_wan(wan_model):
    from lightx2v.attentions.distributed.utils.wan.processor import pre_process, post_process
    wan_model.transformer_infer.parallel_attention = ulysses_attn
    original_infer = wan_model.transformer_infer.infer
    @functools.wraps(wan_model.transformer_infer.__class__.infer)  # preserve the metadata of the original inference method
    def new_infer(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
        x = pre_process(x)
        x = original_infer(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)
        x = post_process(x)
        return x
    new_infer = new_infer.__get__(wan_model.transformer_infer)
    wan_model.transformer_infer.infer = new_infer  # replace the original inference method