Commit 18e2b23a authored by wangshankun

Wan-series models can use radial attention; the Hunyuan series keeps using the old method.

parent e75d0db7
@@ -3,7 +3,7 @@
     "target_video_length": 81,
     "target_height": 480,
     "target_width": 832,
-    "self_attn_1_type": "radial_attn",
+    "self_attn_1_type": "flash_attn3",
     "cross_attn_1_type": "flash_attn3",
     "cross_attn_2_type": "flash_attn3",
     "seed": 42,
......
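This Wan sample config (81 frames, 480x832) keeps flash_attn3 as its shipped default; per the commit message, Wan-series models can opt into radial attention through the same key. A minimal sketch of such a config, with the keys taken from the hunk above and the values purely illustrative:

```python
# Illustrative Wan-style config opting into radial attention (example values,
# not the repository's shipped defaults). Radial attention applies to
# self-attention; cross-attention stays on flash_attn3, as in the diff above.
wan_radial_config = {
    "target_video_length": 81,
    "target_height": 480,
    "target_width": 832,
    "self_attn_1_type": "radial_attn",   # Wan-series models can use radial attention
    "cross_attn_1_type": "flash_attn3",
    "cross_attn_2_type": "flash_attn3",
    "seed": 42,
}
```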
@@ -3,8 +3,6 @@
     "target_video_length": 33,
     "target_height": 720,
     "target_width": 1280,
-    "self_attn_1_type": "flash_attn3",
-    "cross_attn_1_type": "flash_attn3",
-    "cross_attn_2_type": "flash_attn3",
+    "attention_type": "flash_attn3",
     "seed": 0
 }
@@ -3,8 +3,6 @@
     "target_video_length": 33,
     "target_height": 720,
     "target_width": 1280,
-    "self_attn_1_type": "flash_attn3",
-    "cross_attn_1_type": "flash_attn3",
-    "cross_attn_2_type": "flash_attn3",
+    "attention_type": "flash_attn3",
     "seed": 42
 }
@@ -3,9 +3,7 @@
     "target_video_length": 33,
     "target_height": 720,
     "target_width": 1280,
-    "self_attn_1_type": "flash_attn3",
-    "cross_attn_1_type": "flash_attn3",
-    "cross_attn_2_type": "flash_attn3",
+    "attention_type": "flash_attn3",
     "seed": 42,
     "parallel_attn_type": "ring"
 }
@@ -3,9 +3,7 @@
     "target_video_length": 33,
     "target_height": 720,
     "target_width": 1280,
-    "self_attn_1_type": "flash_attn3",
-    "cross_attn_1_type": "flash_attn3",
-    "cross_attn_2_type": "flash_attn3",
+    "attention_type": "flash_attn3",
     "seed": 42,
     "parallel_attn_type": "ulysses"
 }
@@ -2,9 +2,7 @@
     "infer_steps": 20,
     "target_video_length": 33,
     "i2v_resolution": "720p",
-    "self_attn_1_type": "flash_attn3",
-    "cross_attn_1_type": "flash_attn3",
-    "cross_attn_2_type": "flash_attn3",
+    "attention_type": "flash_attn3",
     "seed": 0,
     "dit_quantized_ckpt": "/path/to/int8/model",
     "mm_config": {
......
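The 33-frame / 720p configs above are the Hunyuan examples: they drop the per-attention keys and return to a single attention_type entry, the "old method" from the commit message, while Wan configs keep separate self/cross keys so self-attention alone can be switched to radial_attn. A hedged sketch of how a loader could normalize the two styles; resolve_attn_types is a hypothetical helper for illustration, not part of lightx2v:

```python
# Hypothetical helper (not in lightx2v) showing the two config styles in this commit:
# Hunyuan configs use one "attention_type" for every attention op (the old method),
# while Wan configs name a backend per op, which is what lets self-attention run
# radial_attn while cross-attention keeps flash_attn3.
def resolve_attn_types(config: dict) -> dict:
    if "attention_type" in config:
        backend = config["attention_type"]          # old style: one backend for all ops
        return {
            "self_attn_1_type": backend,
            "cross_attn_1_type": backend,
            "cross_attn_2_type": backend,
        }
    return {                                        # Wan style: per-op backends
        "self_attn_1_type": config.get("self_attn_1_type", "flash_attn3"),
        "cross_attn_1_type": config.get("cross_attn_1_type", "flash_attn3"),
        "cross_attn_2_type": config.get("cross_attn_2_type", "flash_attn3"),
    }

print(resolve_attn_types({"attention_type": "flash_attn3"}))
print(resolve_attn_types({"self_attn_1_type": "radial_attn"}))
```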
@@ -8,19 +8,11 @@ from lightx2v.models.networks.wan.weights.post_weights import WanPostWeights
 from lightx2v.models.networks.wan.weights.transformer_weights import (
     WanTransformerWeights,
 )
-from lightx2v.models.networks.wan.infer.pre_infer import WanPreInfer
-from lightx2v.models.networks.wan.infer.post_infer import WanPostInfer
-from lightx2v.models.networks.wan.infer.pre_infer import WanPreInfer
 from lightx2v.models.networks.wan.infer.audio.pre_wan_audio_infer import WanAudioPreInfer
 from lightx2v.models.networks.wan.infer.audio.post_wan_audio_infer import WanAudioPostInfer
 from lightx2v.models.networks.wan.infer.feature_caching.transformer_infer import WanTransformerInferTeaCaching
-from safetensors import safe_open
-import lightx2v.attentions.distributed.ulysses.wrap as ulysses_dist_wrap
-import lightx2v.attentions.distributed.ring.wrap as ring_dist_wrap
 from lightx2v.attentions.common.radial_attn import MaskMap
 from lightx2v.models.networks.wan.infer.transformer_infer import (
     WanTransformerInfer,
 )
......
 import os
 import torch
+from lightx2v.attentions.common.radial_attn import MaskMap
 from lightx2v.models.networks.wan.model import WanModel
 from lightx2v.models.networks.wan.weights.pre_weights import WanPreWeights
 from lightx2v.models.networks.wan.weights.post_weights import WanPostWeights
@@ -48,6 +49,11 @@ class WanCausVidModel(WanModel):
     @torch.no_grad()
     def infer(self, inputs, kv_start, kv_end):
+        if self.transformer_infer.mask_map is None:
+            _, c, h, w = self.scheduler.latents.shape
+            video_token_num = c * (h // 2) * (w // 2)
+            self.transformer_infer.mask_map = MaskMap(video_token_num, c)
         if self.config["cpu_offload"]:
             self.pre_weight.to_cuda()
             self.post_weight.to_cuda()
......
@@ -124,6 +124,7 @@ class WanTransformerInferCausVid(WanTransformerInfer):
                 max_seqlen_q=q.size(0),
                 max_seqlen_kv=k.size(0),
                 model_cls=self.config["model_cls"],
+                mask_map=self.mask_map,
             )
         else:
             # TODO: Implement parallel attention for causvid inference
......
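The only change in that hunk is the extra mask_map argument on the wrapped attention call, so the radial backend receives the video token layout built in infer(). A hedged, self-contained sketch of that call shape; the attention function below is a stand-in using dense SDPA, not lightx2v's real dispatcher, and the model_cls value and tensor shapes are assumptions for illustration:

```python
import torch
import torch.nn.functional as F

# Stand-in for the wrapped attention call in the hunk above. Tensors are
# (seq_len, num_heads, head_dim), matching the q.size(0) usage in the diff.
# Only a radial backend would consume mask_map; this dense placeholder ignores it.
def attention(q, k, v, *, max_seqlen_q, max_seqlen_kv, model_cls=None, mask_map=None):
    qh, kh, vh = (t.permute(1, 0, 2) for t in (q, k, v))    # -> (heads, seq, dim)
    out = F.scaled_dot_product_attention(qh, kh, vh)
    return out.permute(1, 0, 2)                              # back to (seq, heads, dim)

q = k = v = torch.randn(128, 8, 64)
out = attention(
    q, k, v,
    max_seqlen_q=q.size(0),
    max_seqlen_kv=k.size(0),
    model_cls="wan2.1",        # illustrative value
    mask_map=None,             # would be the MaskMap built lazily in infer()
)
```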
@@ -2,6 +2,7 @@ import os
 import torch
 import glob
 import json
+from lightx2v.attentions.common.radial_attn import MaskMap
 from lightx2v.models.networks.wan.weights.pre_weights import WanPreWeights
 from lightx2v.models.networks.wan.weights.post_weights import WanPostWeights
 from lightx2v.models.networks.wan.weights.transformer_weights import (
@@ -201,6 +202,11 @@ class WanModel:
     @torch.no_grad()
     def infer(self, inputs):
+        if self.transformer_infer.mask_map is None:
+            _, c, h, w = self.scheduler.latents.shape
+            video_token_num = c * (h // 2) * (w // 2)
+            self.transformer_infer.mask_map = MaskMap(video_token_num, c)
         if self.config.get("cpu_offload", False):
             self.pre_weight.to_cuda()
             self.post_weight.to_cuda()
......
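For context on the lazy initialization added to infer() above: the scheduler latents are unpacked as (_, c, h, w), and with Wan latents laid out as (channels, latent_frames, latent_height, latent_width) the second dimension, named c in the diff, is the latent frame count; each frame then contributes (h // 2) * (w // 2) tokens after the 2x2 spatial patchify. A small sketch of the same arithmetic, using an example shape that roughly matches a Wan2.1 run at 480x832 / 81 frames (the shape and the layout reading are assumptions for illustration):

```python
import torch

# Illustrative latent tensor: (channels, latent_frames, latent_h, latent_w).
# 81 video frames -> 21 latent frames (4x temporal compression), 480x832 -> 60x104.
latents = torch.zeros(16, 21, 60, 104)

_, num_frames, h, w = latents.shape                  # same unpack as the diff ("c" there)
video_token_num = num_frames * (h // 2) * (w // 2)   # 2x2 patchify -> tokens per frame

print(num_frames, video_token_num)                   # 21 32760
# WanModel.infer() then builds the radial-attention mask map once:
#   self.transformer_infer.mask_map = MaskMap(video_token_num, num_frames)
```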
@@ -4,7 +4,6 @@
 lightx2v_path=
 model_path=
-lora_path=
 # check section
 if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
......