Unverified Commit 9bb7dfe9 authored by Rongjin Yang's avatar Rongjin Yang Committed by GitHub
Browse files

add spas_sage_attn (#402)

https://github.com/Linboyan-trc/SpargeAttn/tree/02_develop



need install

---------
Co-authored-by: default avatarYang Yong (雍洋) <yongyang1030@163.com>
parent e1b16a56
{
"infer_steps": 40,
"target_video_length": 81,
"target_height": 480,
"target_width": 832,
"self_attn_1_type": "sage_attn",
"cross_attn_1_type": "sage_attn",
"cross_attn_2_type": "sage_attn",
"sample_guide_scale": 5,
"sample_shift": 3,
"enable_cfg": true,
"cpu_offload": false
}
{
"infer_steps": 50,
"target_video_length": 81,
"text_len": 512,
"target_height": 480,
"target_width": 832,
"self_attn_1_type": "spas_sage_attn",
"cross_attn_1_type": "spas_sage_attn",
"cross_attn_2_type": "spas_sage_attn",
"sample_guide_scale": 6,
"sample_shift": 8,
"enable_cfg": true,
"cpu_offload": false
}
......@@ -2,6 +2,7 @@ from .flash_attn import FlashAttn2Weight, FlashAttn3Weight
from .radial_attn import RadialAttnWeight
from .ring_attn import RingAttnWeight
from .sage_attn import SageAttn2Weight
from .spassage_attn import SageAttnWeight
from .svg2_attn import Svg2AttnWeight
from .svg_attn import SvgAttnWeight
from .torch_sdpa import TorchSDPAWeight
......
import os
import torch
try:
import spas_sage_attn
except ImportError:
spas_sage_attn = None
from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER
from .template import AttnWeightTemplate
@ATTN_WEIGHT_REGISTER("spas_sage_attn")
class SageAttnWeight(AttnWeightTemplate):
def __init__(self):
self.config = {}
@classmethod
def apply(self, q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, model_cls=None, tensor_layout="HND"):
q = q.unsqueeze(0)
k = k.unsqueeze(0)
v = v.unsqueeze(0)
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
attn_out = spas_sage_attn.core.spas_sage2_attn_meansim_cuda(q, k, v, tensor_layout)
_, H, N, D = attn_out.shape
attn_out = attn_out.permute(2, 1, 3, 0).contiguous().view(N, H * D)
return attn_out
if __name__ == "__main__":
import matplotlib.pyplot as plt
# 1. 构造输入
q = torch.randn(32760, 12, 128, dtype=torch.bfloat16).cuda()
k = torch.randn(32760, 12, 128, dtype=torch.bfloat16).cuda()
v = torch.randn(32760, 12, 128, dtype=torch.bfloat16).cuda()
# 2. 直接用PyTorch计算注意力
q_ = q.float()
k_ = k.float()
v_ = v.float()
attn_weights = torch.matmul(q_, k_.transpose(-2, -1)) / (128**0.5)
attn_weights = torch.softmax(attn_weights, dim=-1)
output_pt = torch.matmul(attn_weights, v_)
# 3. 用spas_sage2_attn_meansim_cuda计算注意力
q = q.unsqueeze(0) # shape: (1, 32760, 12, 128)
k = k.unsqueeze(0)
v = v.unsqueeze(0)
q = q.transpose(1, 2) # shape: (1, 12, 32760, 128)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
output_cuda = spas_sage_attn.core.spas_sage2_attn_meansim_cuda(q, k, v, tensor_layout="HND")
output_cuda = output_cuda.float()
# 4. 取左上角[3000, 3000],只取第一个head
output_pt_crop = output_pt[0, :3000, :3000].cpu().detach().numpy()
output_cuda_crop = output_cuda[0, 0, :3000, :3000].cpu().detach().numpy()
# 5. 保存图片
save_dir = os.path.expanduser("~/Log/10-22/")
os.makedirs(save_dir, exist_ok=True)
plt.imshow(output_pt_crop, aspect="auto")
plt.title("PyTorch Attention (left-top 3000x3000)")
plt.savefig(os.path.join(save_dir, "attn.png"))
plt.close()
plt.imshow(output_cuda_crop, aspect="auto")
plt.title("spas_sage2_attn_meansim_cuda (left-top 3000x3000)")
plt.savefig(os.path.join(save_dir, "spas_attn.png"))
plt.close()
#!/bin/bash
# set path and first
lightx2v_path=
model_path=
export CUDA_VISIBLE_DEVICES=0
# set environment variables
source ${lightx2v_path}/scripts/base/base.sh
python -m lightx2v.infer \
--model_cls wan2.1 \
--task i2v \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/sparse_attn/spas_sage_attn/wan_i2v.json \
--prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." \
--negative_prompt "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \
--image_path ${lightx2v_path}/assets/inputs/imgs/img_0.jpg \
--save_result_path ${lightx2v_path}/save_results/output_lightx2v_wan_i2v_spas_sage_attn.mp4
#!/bin/bash
# set path and first
lightx2v_path=
model_path=
export CUDA_VISIBLE_DEVICES=0
# set environment variables
source ${lightx2v_path}/scripts/base/base.sh
python -m lightx2v.infer \
--model_cls wan2.1 \
--task t2v \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/sparse_attn/spas_sage_attn/wan_t2v.json \
--prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." \
--negative_prompt "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \
--save_result_path ${lightx2v_path}/save_results/output_lightx2v_wan_t2v_spas_sage_attn.mp4
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment