Commit 9a686a73 authored by gushiqiao, committed by GitHub

Support wan2.1 sageattn and fix OOM for 720P (#12)

parent a951c882
@@ -328,6 +328,7 @@ if __name__ == "__main__":
     mm_config = None
     model_config = {
+        "model_cls": args.model_cls,
         "task": args.task,
         "attention_type": args.attention_type,
         "sample_neg_prompt": args.sample_neg_prompt,
@@ -348,6 +349,9 @@ if __name__ == "__main__":
     model, text_encoders, vae_model, image_encoder = load_models(args, model_config)
+    load_models_time = time.time()
+    print(f"Load models cost: {load_models_time - start_time}")
     if args.task in ["i2v"]:
         image_encoder_output = run_image_encoder(args, image_encoder, vae_model)
     else:
@@ -362,19 +366,23 @@ if __name__ == "__main__":
     gc.collect()
     torch.cuda.empty_cache()
     latents, generator = run_main_inference(args, model, text_encoder_output, image_encoder_output)
-    gc.collect()
-    torch.cuda.empty_cache()
+    if args.cpu_offload:
+        scheduler.clear()
+        del text_encoder_output, image_encoder_output, model, text_encoders, scheduler
+        torch.cuda.empty_cache()
     images = run_vae(latents, generator, args)
     if not args.parallel_attn_type or (args.parallel_attn_type and dist.get_rank() == 0):
+        save_video_st = time.time()
         if args.model_cls == "wan2.1":
             cache_video(tensor=images, save_file=args.save_video_path, fps=16, nrow=1, normalize=True, value_range=(-1, 1))
         else:
             save_videos_grid(images, args.save_video_path, fps=24)
+        save_video_et = time.time()
+        print(f"Save video cost: {save_video_et - save_video_st}")
     end_time = time.time()
-    print(f"Total time: {end_time - start_time}")
+    print(f"Total cost: {end_time - start_time}")
 import torch
-try:
-    from sageattention import sageattn
-except ImportError:
-    sageattn = None
+if torch.cuda.get_device_capability(0) == (8, 9):
+    try:
+        from sageattention import sageattn_qk_int8_pv_fp16_triton as sageattn
+    except ImportError:
+        sageattn = None
+else:
+    try:
+        from sageattention import sageattn
+    except ImportError:
+        sageattn = None
-def sage_attn2(
-    q,
-    k,
-    v,
-    cu_seqlens_q=None,
-    cu_seqlens_kv=None,
-    max_seqlen_q=None,
-    max_seqlen_kv=None,
-):
+def sage_attn2(q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, model_cls="hunyuan"):
     q, k, v = (
         q.transpose(1, 0).contiguous(),
         k.transpose(1, 0).contiguous(),
         v.transpose(1, 0).contiguous(),
     )
-    x1 = sageattn(
-        q[:, : cu_seqlens_q[1], :].unsqueeze(0),
-        k[:, : cu_seqlens_q[1], :].unsqueeze(0),
-        v[:, : cu_seqlens_kv[1], :].unsqueeze(0),
-    )
-    x2 = sageattn(
-        q[:, cu_seqlens_q[1] :, :].unsqueeze(0),
-        k[:, cu_seqlens_kv[1] :, :].unsqueeze(0),
-        v[:, cu_seqlens_kv[1] :, :].unsqueeze(0),
-    )
-    x = torch.cat((x1, x2), dim=-2).transpose(2, 1).contiguous()
-    x = x.view(max_seqlen_q, -1)
+    if model_cls == "hunyuan":
+        x1 = sageattn(
+            q[:, : cu_seqlens_q[1], :].unsqueeze(0),
+            k[:, : cu_seqlens_q[1], :].unsqueeze(0),
+            v[:, : cu_seqlens_kv[1], :].unsqueeze(0),
+        )
+        x2 = sageattn(
+            q[:, cu_seqlens_q[1] :, :].unsqueeze(0),
+            k[:, cu_seqlens_kv[1] :, :].unsqueeze(0),
+            v[:, cu_seqlens_kv[1] :, :].unsqueeze(0),
+        )
+        x = torch.cat((x1, x2), dim=-2).transpose(2, 1).contiguous()
+        x = x.view(max_seqlen_q, -1)
+    elif model_cls == "wan2.1":
+        x = (
+            sageattn(
+                q[:, : cu_seqlens_q[1], :].unsqueeze(0),
+                k[:, : cu_seqlens_q[1], :].unsqueeze(0),
+                v[:, : cu_seqlens_kv[1], :].unsqueeze(0),
+            )
+            .transpose(2, 1)
+            .contiguous()
+        )
+        x = x.view(max_seqlen_q, -1)
     return x
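With the change above, sage_attn2 dispatches on model_cls: the "hunyuan" branch keeps the original two-segment split at cu_seqlens_q[1] and runs two SageAttention calls, while the new "wan2.1" branch runs a single call over the first segment. A hedged usage sketch (tensor shapes and sizes are assumptions, not taken from the repo; requires the sageattention package so that sageattn is not None):

    import torch

    seq, heads, dim = 512, 8, 64
    q = torch.randn(seq, heads, dim, dtype=torch.float16, device="cuda")
    k = torch.randn(seq, heads, dim, dtype=torch.float16, device="cuda")
    v = torch.randn(seq, heads, dim, dtype=torch.float16, device="cuda")
    cu = torch.tensor([0, seq], dtype=torch.int32, device="cuda")  # single sequence

    out = sage_attn2(q, k, v, cu_seqlens_q=cu, cu_seqlens_kv=cu,
                     max_seqlen_q=seq, max_seqlen_kv=seq, model_cls="wan2.1")
    # out: [seq, heads * dim] -- the heads are flattened by the final view()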
 import numpy as np
-from lightx2v.attentions import attention
+import torch
 from ..transformer_infer import WanTransformerInfer

 class WanTransformerInferFeatureCaching(WanTransformerInfer):
@@ -61,6 +61,10 @@ class WanTransformerInferFeatureCaching(WanTransformerInfer):
                 context,
             )
             self.scheduler.previous_residual_even = x - ori_x
+            if self.config["cpu_offload"]:
+                ori_x = ori_x.to("cpu")
+                del ori_x
+                torch.cuda.empty_cache()
         else:
             if not should_calc_odd:
                 x += self.scheduler.previous_residual_odd
@@ -77,5 +81,8 @@ class WanTransformerInferFeatureCaching(WanTransformerInfer):
                 context,
             )
             self.scheduler.previous_residual_odd = x - ori_x
+            if self.config["cpu_offload"]:
+                ori_x = ori_x.to("cpu")
+                del ori_x
+                torch.cuda.empty_cache()
         return x
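The added branches release the block's cached input once the residual has been stored: rebinding ori_x to a CPU copy drops the last reference to the CUDA tensor, del then removes the CPU copy as well, and empty_cache() returns the freed block to the driver. The same idiom in isolation (variable names are illustrative):

    import torch

    t = torch.randn(1024, 1024, device="cuda")  # stands in for ori_x
    t = t.to("cpu")             # the CUDA storage becomes unreferenced here
    del t                       # the CPU copy is not needed either
    torch.cuda.empty_cache()    # shrink the caching allocator so the next big allocation fits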
@@ -98,14 +98,7 @@ class WanTransformerInfer:
         if not self.parallel_attention:
             attn_out = attention(
-                attention_type=self.attention_type,
-                q=q,
-                k=k,
-                v=v,
-                cu_seqlens_q=cu_seqlens_q,
-                cu_seqlens_kv=cu_seqlens_k,
-                max_seqlen_q=lq,
-                max_seqlen_kv=lk,
+                attention_type=self.attention_type, q=q, k=k, v=v, cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_k, max_seqlen_q=lq, max_seqlen_kv=lk, model_cls=self.config["model_cls"]
             )
         else:
             attn_out = self.parallel_attention(
@@ -136,14 +129,7 @@ class WanTransformerInfer:
         cu_seqlens_q, cu_seqlens_k, lq, lk = self._calculate_q_k_len(q, k, k_lens=torch.tensor([k.size(0)], dtype=torch.int32, device=k.device))
         attn_out = attention(
-            attention_type=self.attention_type,
-            q=q,
-            k=k,
-            v=v,
-            cu_seqlens_q=cu_seqlens_q,
-            cu_seqlens_kv=cu_seqlens_k,
-            max_seqlen_q=lq,
-            max_seqlen_kv=lk,
+            attention_type=self.attention_type, q=q, k=k, v=v, cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_k, max_seqlen_q=lq, max_seqlen_kv=lk, model_cls=self.config["model_cls"]
         )

         if self.task == "i2v":
@@ -157,14 +143,7 @@ class WanTransformerInfer:
             )
             img_attn_out = attention(
-                attention_type=self.attention_type,
-                q=q,
-                k=k_img,
-                v=v_img,
-                cu_seqlens_q=cu_seqlens_q,
-                cu_seqlens_kv=cu_seqlens_k,
-                max_seqlen_q=lq,
-                max_seqlen_kv=lk,
+                attention_type=self.attention_type, q=q, k=k_img, v=v_img, cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_k, max_seqlen_q=lq, max_seqlen_kv=lk, model_cls=self.config["model_cls"]
             )
             attn_out = attn_out + img_attn_out
......
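The three call sites above only change by collapsing the keyword arguments onto one line and forwarding model_cls=self.config["model_cls"], so the attention backend can pick the wan2.1 branch of sage_attn2. A hypothetical sketch of such a dispatcher (the real lightx2v.attentions.attention may differ; the "sage_attn2" key is an assumption, and sage_attn2 is the wrapper shown earlier):

    def attention(attention_type, q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None,
                  max_seqlen_q=None, max_seqlen_kv=None, model_cls="hunyuan"):
        if attention_type == "sage_attn2":
            # only the SageAttention backend needs to know which model is running
            return sage_attn2(q, k, v, cu_seqlens_q, cu_seqlens_kv,
                              max_seqlen_q, max_seqlen_kv, model_cls=model_cls)
        raise NotImplementedError(f"unknown attention_type: {attention_type}")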
@@ -71,3 +71,18 @@ class WanSchedulerFeatureCaching(WanScheduler):
         self.coefficients = [-5784.54975374, 5449.50911966, -1811.16591783, 256.27178429, -13.02252404]
         self.ret_steps = 1 * 2
         self.cutoff_steps = self.args.infer_steps * 2 - 2
+
+    def clear(self):
+        if self.previous_e0_even is not None:
+            self.previous_e0_even = self.previous_e0_even.cpu()
+        if self.previous_e0_odd is not None:
+            self.previous_e0_odd = self.previous_e0_odd.cpu()
+        if self.previous_residual_even is not None:
+            self.previous_residual_even = self.previous_residual_even.cpu()
+        if self.previous_residual_odd is not None:
+            self.previous_residual_odd = self.previous_residual_odd.cpu()
+        self.previous_e0_even = None
+        self.previous_e0_odd = None
+        self.previous_residual_even = None
+        self.previous_residual_odd = None
+        torch.cuda.empty_cache()
@@ -341,3 +341,6 @@ class WanScheduler(BaseScheduler):
             self.lower_order_nums += 1

         self.latents = prev_sample
+
+    def clear(self):
+        pass