Commit 9a686a73 authored by gushiqiao, committed by GitHub

Support wan2.1 sageattn and fix oom for 720P. (#12)

parent a951c882
@@ -328,6 +328,7 @@ if __name__ == "__main__":
     mm_config = None
     model_config = {
         "model_cls": args.model_cls,
+        "task": args.task,
         "attention_type": args.attention_type,
         "sample_neg_prompt": args.sample_neg_prompt,
@@ -348,6 +349,9 @@ if __name__ == "__main__":
     model, text_encoders, vae_model, image_encoder = load_models(args, model_config)
     load_models_time = time.time()
     print(f"Load models cost: {load_models_time - start_time}")
+    if args.task in ["i2v"]:
+        image_encoder_output = run_image_encoder(args, image_encoder, vae_model)
+    else:
@@ -362,19 +366,23 @@ if __name__ == "__main__":
     gc.collect()
     torch.cuda.empty_cache()
     latents, generator = run_main_inference(args, model, text_encoder_output, image_encoder_output)
     gc.collect()
     torch.cuda.empty_cache()
+    if args.cpu_offload:
+        scheduler.clear()
+        del text_encoder_output, image_encoder_output, model, text_encoders, scheduler
+        torch.cuda.empty_cache()
     images = run_vae(latents, generator, args)
     if not args.parallel_attn_type or (args.parallel_attn_type and dist.get_rank() == 0):
         save_video_st = time.time()
         if args.model_cls == "wan2.1":
             cache_video(tensor=images, save_file=args.save_video_path, fps=16, nrow=1, normalize=True, value_range=(-1, 1))
         else:
             save_videos_grid(images, args.save_video_path, fps=24)
         save_video_et = time.time()
         print(f"Save video cost: {save_video_et - save_video_st}")
     end_time = time.time()
-    print(f"Total time: {end_time - start_time}")
+    print(f"Total cost: {end_time - start_time}")
 import torch

-try:
-    from sageattention import sageattn
-except ImportError:
-    sageattn = None
+if torch.cuda.get_device_capability(0) == (8, 9):
+    try:
+        from sageattention import sageattn_qk_int8_pv_fp16_triton as sageattn
+    except ImportError:
+        sageattn = None
+else:
+    try:
+        from sageattention import sageattn
+    except ImportError:
+        sageattn = None
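
Background for the device check: `(8, 9)` is the CUDA compute capability of Ada Lovelace GPUs (e.g. RTX 4090), where this commit selects SageAttention's int8-QK/fp16-PV Triton kernel. A hedged sketch of the same selection with a plain-PyTorch fallback when SageAttention is not installed (the fallback is an assumption, not part of this commit):

```python
import torch
import torch.nn.functional as F

sageattn = None
if torch.cuda.is_available():
    try:
        if torch.cuda.get_device_capability(0) == (8, 9):  # Ada Lovelace (sm_89)
            from sageattention import sageattn_qk_int8_pv_fp16_triton as sageattn
        else:
            from sageattention import sageattn
    except ImportError:
        pass  # keep sageattn = None and fall back below

def attn(q, k, v):
    # q, k, v: (batch, heads, seq_len, head_dim)
    if sageattn is not None:
        return sageattn(q, k, v)  # SageAttention's default "HND" layout
    return F.scaled_dot_product_attention(q, k, v)
```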
-def sage_attn2(
-    q,
-    k,
-    v,
-    cu_seqlens_q=None,
-    cu_seqlens_kv=None,
-    max_seqlen_q=None,
-    max_seqlen_kv=None,
-):
+def sage_attn2(q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, model_cls="hunyuan"):
     q, k, v = (
         q.transpose(1, 0).contiguous(),
         k.transpose(1, 0).contiguous(),
         v.transpose(1, 0).contiguous(),
     )
-    x1 = sageattn(
-        q[:, : cu_seqlens_q[1], :].unsqueeze(0),
-        k[:, : cu_seqlens_q[1], :].unsqueeze(0),
-        v[:, : cu_seqlens_kv[1], :].unsqueeze(0),
-    )
-    x2 = sageattn(
-        q[:, cu_seqlens_q[1] :, :].unsqueeze(0),
-        k[:, cu_seqlens_kv[1] :, :].unsqueeze(0),
-        v[:, cu_seqlens_kv[1] :, :].unsqueeze(0),
-    )
-    x = torch.cat((x1, x2), dim=-2).transpose(2, 1).contiguous()
-    x = x.view(max_seqlen_q, -1)
+    if model_cls == "hunyuan":
+        x1 = sageattn(
+            q[:, : cu_seqlens_q[1], :].unsqueeze(0),
+            k[:, : cu_seqlens_q[1], :].unsqueeze(0),
+            v[:, : cu_seqlens_kv[1], :].unsqueeze(0),
+        )
+        x2 = sageattn(
+            q[:, cu_seqlens_q[1] :, :].unsqueeze(0),
+            k[:, cu_seqlens_kv[1] :, :].unsqueeze(0),
+            v[:, cu_seqlens_kv[1] :, :].unsqueeze(0),
+        )
+        x = torch.cat((x1, x2), dim=-2).transpose(2, 1).contiguous()
+        x = x.view(max_seqlen_q, -1)
+    elif model_cls == "wan2.1":
+        x = (
+            sageattn(
+                q[:, : cu_seqlens_q[1], :].unsqueeze(0),
+                k[:, : cu_seqlens_q[1], :].unsqueeze(0),
+                v[:, : cu_seqlens_kv[1], :].unsqueeze(0),
+            )
+            .transpose(2, 1)
+            .contiguous()
+        )
+        x = x.view(max_seqlen_q, -1)
     return x
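
For readers new to the varlen convention used here: `cu_seqlens_q` holds cumulative token offsets (as in FlashAttention's varlen API), so `q[:, : cu_seqlens_q[1], :]` selects the first packed segment. Hunyuan packs two segments per sequence and attends to each separately before concatenating; wan2.1 carries a single valid segment, hence the single kernel call. A small runnable illustration of the offset convention (segment lengths are illustrative):

```python
import torch

# Two packed segments, e.g. 300 video tokens followed by 20 text tokens.
lens = torch.tensor([300, 20])
cu_seqlens = torch.cat([torch.zeros(1, dtype=torch.long), lens.cumsum(0)])
print(cu_seqlens)                      # tensor([  0, 300, 320])

tokens = torch.randn(int(cu_seqlens[-1]), 64)   # (total_tokens, dim)
seg0 = tokens[: cu_seqlens[1]]         # first segment: 300 tokens
seg1 = tokens[cu_seqlens[1] :]         # second segment: 20 tokens
print(seg0.shape, seg1.shape)          # torch.Size([300, 64]) torch.Size([20, 64])
```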
 import numpy as np
 from ..transformer_infer import WanTransformerInfer
 from lightx2v.attentions import attention
+import torch

 class WanTransformerInferFeatureCaching(WanTransformerInfer):
@@ -61,6 +61,10 @@ class WanTransformerInferFeatureCaching(WanTransformerInfer):
                 context,
             )
             self.scheduler.previous_residual_even = x - ori_x
+            if self.config["cpu_offload"]:
+                ori_x = ori_x.to("cpu")
+                del ori_x
+                torch.cuda.empty_cache()
         else:
             if not should_calc_odd:
                 x += self.scheduler.previous_residual_odd
@@ -77,5 +81,8 @@ class WanTransformerInferFeatureCaching(WanTransformerInfer):
                 context,
             )
             self.scheduler.previous_residual_odd = x - ori_x
+            if self.config["cpu_offload"]:
+                ori_x = ori_x.to("cpu")
+                del ori_x
+                torch.cuda.empty_cache()
         return x
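
The caching class above uses the TeaCache-style residual trick: on a fully computed step it stores `x - ori_x`, and on a skipped step it adds the stored residual back instead of re-running the blocks; the new `cpu_offload` branch releases `ori_x` as soon as the residual is formed. A toy sketch of the caching idea (not the repo's classes):

```python
import torch

class ResidualCache:
    """Cache what a block adds to its input; reuse it on skipped steps."""

    def __init__(self):
        self.residual = None

    def full_step(self, x, block):
        ori_x = x                    # keep the block input around
        x = block(x)
        self.residual = x - ori_x    # what the block contributed
        return x

    def skip_step(self, x):
        return x + self.residual     # cheap approximation of the block

block = torch.nn.Linear(8, 8)
cache = ResidualCache()
x = torch.randn(2, 8)
y_full = cache.full_step(x, block)
y_skip = cache.skip_step(x)          # reuses the cached residual
print(torch.allclose(y_full, y_skip))  # True for identical input
```

In the diff, once the residual is stored the block input is dead weight; under `cpu_offload` it is rebound to a CPU copy and deleted so its GPU buffer can be reclaimed immediately rather than at the end of the step.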
@@ -98,14 +98,7 @@ class WanTransformerInfer:
         if not self.parallel_attention:
             attn_out = attention(
-                attention_type=self.attention_type,
-                q=q,
-                k=k,
-                v=v,
-                cu_seqlens_q=cu_seqlens_q,
-                cu_seqlens_kv=cu_seqlens_k,
-                max_seqlen_q=lq,
-                max_seqlen_kv=lk,
+                attention_type=self.attention_type, q=q, k=k, v=v, cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_k, max_seqlen_q=lq, max_seqlen_kv=lk, model_cls=self.config["model_cls"]
             )
         else:
             attn_out = self.parallel_attention(
@@ -136,14 +129,7 @@ class WanTransformerInfer:
         cu_seqlens_q, cu_seqlens_k, lq, lk = self._calculate_q_k_len(q, k, k_lens=torch.tensor([k.size(0)], dtype=torch.int32, device=k.device))
         attn_out = attention(
-            attention_type=self.attention_type,
-            q=q,
-            k=k,
-            v=v,
-            cu_seqlens_q=cu_seqlens_q,
-            cu_seqlens_kv=cu_seqlens_k,
-            max_seqlen_q=lq,
-            max_seqlen_kv=lk,
+            attention_type=self.attention_type, q=q, k=k, v=v, cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_k, max_seqlen_q=lq, max_seqlen_kv=lk, model_cls=self.config["model_cls"]
         )
         if self.task == "i2v":
@@ -157,14 +143,7 @@ class WanTransformerInfer:
             )
             img_attn_out = attention(
-                attention_type=self.attention_type,
-                q=q,
-                k=k_img,
-                v=v_img,
-                cu_seqlens_q=cu_seqlens_q,
-                cu_seqlens_kv=cu_seqlens_k,
-                max_seqlen_q=lq,
-                max_seqlen_kv=lk,
+                attention_type=self.attention_type, q=q, k=k_img, v=v_img, cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_k, max_seqlen_q=lq, max_seqlen_kv=lk, model_cls=self.config["model_cls"]
             )
             attn_out = attn_out + img_attn_out
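
The three call-site changes above are the same edit: thread `model_cls` from the config through the `attention()` dispatcher so `sage_attn2` knows which packing layout it is handling. A hedged sketch of such a dispatcher (the signature of the repo's `lightx2v.attentions.attention` is assumed here, and `sage_attn2_stub` stands in for the function defined earlier in this diff):

```python
import torch
import torch.nn.functional as F

def sage_attn2_stub(q, k, v, model_cls="hunyuan", **kwargs):
    # Stand-in for sage_attn2 from this commit; the real function picks a
    # segment layout based on model_cls before calling sageattn.
    return F.scaled_dot_product_attention(q, k, v)

def attention(attention_type, q, k, v, model_cls="hunyuan", **kwargs):
    # Route to the requested backend; only the sage path needs model_cls.
    if attention_type == "sage_attn2":
        return sage_attn2_stub(q, k, v, model_cls=model_cls, **kwargs)
    return F.scaled_dot_product_attention(q, k, v)

q = k = v = torch.randn(1, 4, 16, 8)   # (batch, heads, seq, head_dim)
out = attention("sage_attn2", q, k, v, model_cls="wan2.1")
print(out.shape)                       # torch.Size([1, 4, 16, 8])
```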
@@ -71,3 +71,18 @@ class WanSchedulerFeatureCaching(WanScheduler):
         self.coefficients = [-5784.54975374, 5449.50911966, -1811.16591783, 256.27178429, -13.02252404]
         self.ret_steps = 1 * 2
         self.cutoff_steps = self.args.infer_steps * 2 - 2
+
+    def clear(self):
+        if self.previous_e0_even is not None:
+            self.previous_e0_even = self.previous_e0_even.cpu()
+        if self.previous_e0_odd is not None:
+            self.previous_e0_odd = self.previous_e0_odd.cpu()
+        if self.previous_residual_even is not None:
+            self.previous_residual_even = self.previous_residual_even.cpu()
+        if self.previous_residual_odd is not None:
+            self.previous_residual_odd = self.previous_residual_odd.cpu()
+        self.previous_e0_even = None
+        self.previous_e0_odd = None
+        self.previous_residual_even = None
+        self.previous_residual_odd = None
+        torch.cuda.empty_cache()
@@ -341,3 +341,6 @@ class WanScheduler(BaseScheduler):
             self.lower_order_nums += 1
         self.latents = prev_sample
+
+    def clear(self):
+        pass
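
Giving the base scheduler a no-op `clear()` lets the runner call `scheduler.clear()` unconditionally, while the feature-caching subclass overrides it to move its cached tensors to CPU, drop the references, and return the freed blocks. A minimal sketch of that pattern (class names here are illustrative):

```python
import torch

class BaseScheduler:
    def clear(self) -> None:
        pass  # no cached state; safe to call on any scheduler

class CachingScheduler(BaseScheduler):
    def __init__(self):
        # Stand-in for the previous_e0_* / previous_residual_* caches.
        self.residual = torch.randn(4, 4)

    def clear(self) -> None:
        if self.residual is not None:
            self.residual = self.residual.cpu()  # detach from GPU memory
        self.residual = None                     # drop the reference
        if torch.cuda.is_available():
            torch.cuda.empty_cache()             # hand blocks back to the driver

for sched in (BaseScheduler(), CachingScheduler()):
    sched.clear()  # works on both; no isinstance checks in the runner
```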