"official/modeling/optimization/slide_optimizer.py" did not exist on "70702f79c7817b09fb87fef7729478af58532870"
Unverified Commit 503c3abc authored by gushiqiao's avatar gushiqiao Committed by GitHub
Browse files

[Feat] add sageattn3 (#416)


Co-authored-by: gushiqiao <975033167@qq.ocm>
parent e106ff67
from .flash_attn import FlashAttn2Weight, FlashAttn3Weight
from .radial_attn import RadialAttnWeight
from .ring_attn import RingAttnWeight
from .sage_attn import SageAttn2Weight
from .sage_attn import SageAttn2Weight, SageAttn3Weight
from .spassage_attn import SageAttnWeight
from .svg2_attn import Svg2AttnWeight
from .svg_attn import SvgAttnWeight
......
......@@ -5,7 +5,7 @@ from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER
from .template import AttnWeightTemplate
if torch.cuda.get_device_capability(0) == (8, 9):
if torch.cuda.get_device_capability(0) in [(8, 9), (12, 0)]:
try:
from sageattention import sageattn_qk_int8_pv_fp16_triton as sageattn
except ImportError:
......@@ -18,6 +18,12 @@ else:
logger.info("sageattn not found, please install sageattention first")
sageattn = None
# Optional dependency: the name stays bound to None when the import fails, so
# code further down can check for availability before calling the kernel.
sageattn3_blackwell = None
try:
    from sageattn3 import sageattn3_blackwell
except ImportError:
    logger.info("sageattn3 not found, please install sageattention first")
@ATTN_WEIGHT_REGISTER("sage_attn2")
class SageAttn2Weight(AttnWeightTemplate):
......@@ -48,3 +54,30 @@ class SageAttn2Weight(AttnWeightTemplate):
tensor_layout="NHD",
).view(bs * max_seqlen_q, -1)
return x
@ATTN_WEIGHT_REGISTER("sage_attn3")
class SageAttn3Weight(AttnWeightTemplate):
    """Attention backend registered as ``"sage_attn3"``.

    Forwards q/k/v to ``sageattn3_blackwell`` and flattens the result to 2-D.
    """

    def __init__(self):
        self.config = {}

    def apply(
        self,
        q,
        k,
        v,
        cu_seqlens_q=None,
        cu_seqlens_kv=None,
        max_seqlen_q=None,
        max_seqlen_kv=None,
        model_cls=None,
    ):
        """Run the sageattn3 kernel on ``q``, ``k``, ``v``.

        Args:
            q, k, v: 3-D or 4-D tensors. A 3-D input gets a leading dimension
                of size 1 added. Dimensions 1 and 2 are swapped before the
                kernel call and swapped back afterwards.
                NOTE(review): presumably (batch, seq, heads, dim) -- confirm
                against the callers.
            cu_seqlens_q, cu_seqlens_kv, max_seqlen_kv, model_cls: accepted
                for interface compatibility; not used by this implementation.
            max_seqlen_q: number of query rows per batch element in the
                flattened output. When ``None``, ``q.shape[1]`` (after the
                optional unsqueeze) is used.

        Returns:
            The kernel output reshaped to ``(bs * max_seqlen_q, -1)``.

        Raises:
            RuntimeError: if the ``sageattn3`` package could not be imported.
            ValueError: if ``q`` is neither 3-D nor 4-D.
        """
        # The module-level import falls back to None; fail with a clear message
        # instead of "'NoneType' object is not callable".
        if sageattn3_blackwell is None:
            raise RuntimeError("sageattn3 is not installed; the 'sage_attn3' attention backend is unavailable")

        q, k, v = q.contiguous(), k.contiguous(), v.contiguous()

        rank = len(q.shape)
        if rank == 3:
            q, k, v = q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0)
        elif rank != 4:
            # Previously this path left ``bs`` unbound and crashed later.
            raise ValueError(f"SageAttn3Weight.apply expects 3-D or 4-D q/k/v, got {rank}-D")
        bs = q.shape[0]

        # ``max_seqlen_q`` defaults to None in the signature; fall back to the
        # size of the dimension that is swapped with dim 2 around the kernel.
        if max_seqlen_q is None:
            max_seqlen_q = q.shape[1]

        out = sageattn3_blackwell(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2))
        return out.transpose(1, 2).reshape(bs * max_seqlen_q, -1)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment