Commit 18e2b23a authored by wangshankun

Wan-series models can use radial attention; the Hunyuan series keeps the old method

parent e75d0db7
......@@ -3,7 +3,7 @@
"target_video_length": 81,
"target_height": 480,
"target_width": 832,
"self_attn_1_type": "radial_attn",
"self_attn_1_type": "flash_attn3",
"cross_attn_1_type": "flash_attn3",
"cross_attn_2_type": "flash_attn3",
"seed": 42,
......
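Per the commit message, Wan-series models can still opt into radial attention; presumably that just means flipping the per-layer key back to the value removed above. A hypothetical Wan config variant (keys and values all taken from the hunk above, only recombined):

```json
{
  "target_video_length": 81,
  "target_height": 480,
  "target_width": 832,
  "self_attn_1_type": "radial_attn",
  "cross_attn_1_type": "flash_attn3",
  "cross_attn_2_type": "flash_attn3",
  "seed": 42
}
```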
......@@ -3,8 +3,6 @@
"target_video_length": 33,
"target_height": 720,
"target_width": 1280,
"self_attn_1_type": "flash_attn3",
"cross_attn_1_type": "flash_attn3",
"cross_attn_2_type": "flash_attn3",
"attention_type": "flash_attn3",
"seed": 0
}
......@@ -3,8 +3,6 @@
"target_video_length": 33,
"target_height": 720,
"target_width": 1280,
"self_attn_1_type": "flash_attn3",
"cross_attn_1_type": "flash_attn3",
"cross_attn_2_type": "flash_attn3",
"attention_type": "flash_attn3",
"seed": 42
}
......@@ -3,9 +3,7 @@
"target_video_length": 33,
"target_height": 720,
"target_width": 1280,
"self_attn_1_type": "flash_attn3",
"cross_attn_1_type": "flash_attn3",
"cross_attn_2_type": "flash_attn3",
"attention_type": "flash_attn3",
"seed": 42,
"parallel_attn_type": "ring"
}
......@@ -3,9 +3,7 @@
"target_video_length": 33,
"target_height": 720,
"target_width": 1280,
"self_attn_1_type": "flash_attn3",
"cross_attn_1_type": "flash_attn3",
"cross_attn_2_type": "flash_attn3",
"attention_type": "flash_attn3",
"seed": 42,
"parallel_attn_type": "ulysses"
}
......@@ -2,9 +2,7 @@
"infer_steps": 20,
"target_video_length": 33,
"i2v_resolution": "720p",
"self_attn_1_type": "flash_attn3",
"cross_attn_1_type": "flash_attn3",
"cross_attn_2_type": "flash_attn3",
"attention_type": "flash_attn3",
"seed": 0,
"dit_quantized_ckpt": "/path/to/int8/model",
"mm_config": {
......
......@@ -8,19 +8,11 @@ from lightx2v.models.networks.wan.weights.post_weights import WanPostWeights
from lightx2v.models.networks.wan.weights.transformer_weights import (
    WanTransformerWeights,
)
-from lightx2v.models.networks.wan.infer.pre_infer import WanPreInfer
from lightx2v.models.networks.wan.infer.post_infer import WanPostInfer
+from lightx2v.models.networks.wan.infer.pre_infer import WanPreInfer
-from lightx2v.models.networks.wan.infer.audio.pre_wan_audio_infer import WanAudioPreInfer
-from lightx2v.models.networks.wan.infer.audio.post_wan_audio_infer import WanAudioPostInfer
-from lightx2v.models.networks.wan.infer.feature_caching.transformer_infer import WanTransformerInferTeaCaching
-from safetensors import safe_open
-import lightx2v.attentions.distributed.ulysses.wrap as ulysses_dist_wrap
-import lightx2v.attentions.distributed.ring.wrap as ring_dist_wrap
+from lightx2v.attentions.common.radial_attn import MaskMap
-from lightx2v.models.networks.wan.infer.transformer_infer import (
-    WanTransformerInfer,
-)
......
import os
import torch
+from lightx2v.attentions.common.radial_attn import MaskMap
from lightx2v.models.networks.wan.model import WanModel
from lightx2v.models.networks.wan.weights.pre_weights import WanPreWeights
from lightx2v.models.networks.wan.weights.post_weights import WanPostWeights
......@@ -48,6 +49,11 @@ class WanCausVidModel(WanModel):
    @torch.no_grad()
    def infer(self, inputs, kv_start, kv_end):
+        if self.transformer_infer.mask_map is None:
+            _, c, h, w = self.scheduler.latents.shape
+            video_token_num = c * (h // 2) * (w // 2)
+            self.transformer_infer.mask_map = MaskMap(video_token_num, c)
        if self.config["cpu_offload"]:
            self.pre_weight.to_cuda()
            self.post_weight.to_cuda()
......
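The lazy initialization added to `WanCausVidModel.infer` above (and, identically, to `WanModel.infer` further down) builds the radial-attention MaskMap from the latent shape on the first call. A minimal standalone sketch of the same pattern, assuming the `(channels, frames, latent_h, latent_w)` layout implied by the unpacking in the hunk:

```python
import torch
from lightx2v.attentions.common.radial_attn import MaskMap

def build_mask_map(latents: torch.Tensor) -> MaskMap:
    # Assumed layout: (channels, frames, latent_h, latent_w); channels unused.
    _, num_frames, h, w = latents.shape
    # The 2x2 patchify step halves each spatial side, so every latent frame
    # contributes (h // 2) * (w // 2) tokens to the video sequence.
    video_token_num = num_frames * (h // 2) * (w // 2)
    return MaskMap(video_token_num, num_frames)
```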
......@@ -124,6 +124,7 @@ class WanTransformerInferCausVid(WanTransformerInfer):
                max_seqlen_q=q.size(0),
                max_seqlen_kv=k.size(0),
                model_cls=self.config["model_cls"],
+                mask_map=self.mask_map,
            )
        else:
            # TODO: Implement parallel attention for causvid inference
......
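The hunk above threads `mask_map` into the attention call so the radial backend can tell which query/key pairs sit in temporally nearby frames. As a toy illustration of that idea only (this is not the repository's radial_attn implementation, and a real backend would use block-sparse structures rather than a dense token-level matrix):

```python
import torch

def radial_keep_mask(num_frames: int, tokens_per_frame: int, window: int = 1) -> torch.Tensor:
    # True where attention is kept. Real radial attention varies the band
    # width with frame distance; a fixed window is used here for brevity.
    idx = torch.arange(num_frames)
    frame_keep = (idx[:, None] - idx[None, :]).abs() <= window
    # Expand the frame-level mask to token granularity.
    return frame_keep.repeat_interleave(tokens_per_frame, dim=0).repeat_interleave(tokens_per_frame, dim=1)
```

For instance, `radial_keep_mask(4, 6)` yields a 24x24 boolean mask.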
......@@ -2,6 +2,7 @@ import os
import torch
import glob
import json
+from lightx2v.attentions.common.radial_attn import MaskMap
from lightx2v.models.networks.wan.weights.pre_weights import WanPreWeights
from lightx2v.models.networks.wan.weights.post_weights import WanPostWeights
from lightx2v.models.networks.wan.weights.transformer_weights import (
......@@ -201,6 +202,11 @@ class WanModel:
    @torch.no_grad()
    def infer(self, inputs):
+        if self.transformer_infer.mask_map is None:
+            _, c, h, w = self.scheduler.latents.shape
+            video_token_num = c * (h // 2) * (w // 2)
+            self.transformer_infer.mask_map = MaskMap(video_token_num, c)
        if self.config.get("cpu_offload", False):
            self.pre_weight.to_cuda()
            self.post_weight.to_cuda()
......
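As a sanity check of the token arithmetic against the 480x832, 81-frame Wan config earlier in this commit (the 8x spatial / 4x temporal VAE strides are assumptions about the Wan pipeline, not read from the diff):

```python
num_frames = (81 - 1) // 4 + 1                       # 21 latent frames
h, w = 480 // 8, 832 // 8                            # 60 x 104 latent grid
video_token_num = num_frames * (h // 2) * (w // 2)   # 21 * 30 * 52 = 32760
```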
......@@ -4,7 +4,6 @@
lightx2v_path=
model_path=
lora_path=
# check section
if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
......