Commit e2778d0d authored by litzh

Initial commit
from loguru import logger
from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER, SPARSE_MASK_GENERATOR_REGISTER, SPARSE_OPERATOR_REGISTER
from .template import AttnWeightTemplate
@ATTN_WEIGHT_REGISTER("general_sparse_attn")
class GeneralSparseAttnWeight(AttnWeightTemplate):
sparse_mask_generator = None
sparse_operator = None
sparse_setting = {}
attnmap_frame_num = None
def __init__(self):
self.config = {}
self._setup_operator()
self._setup_mask_generator()
logger.info(
f"GeneralSparseAttnWeight: sparse_setting={self.sparse_setting}, operator={self.sparse_operator}, mask_generator={self.sparse_mask_generator}, attnmap_frame_num={self.attnmap_frame_num}"
)
def _setup_operator(self):
self.operator = SPARSE_OPERATOR_REGISTER[self.sparse_operator]()
def _setup_mask_generator(self):
self.mask_generator = SPARSE_MASK_GENERATOR_REGISTER[self.sparse_mask_generator](self.operator.q_block_size, self.operator.k_block_size, self.sparse_setting, self.attnmap_frame_num)
def apply(
self,
q,
k,
v,
cu_seqlens_q=None,
cu_seqlens_kv=None,
max_seqlen_q=None,
max_seqlen_kv=None,
**kwargs,
):
# Generate sparse mask
mask = self.mask_generator(q, k)
# reorg
q, k, v = self.mask_generator.reorg(q, k, v)
# Apply sparse operator
out = self.operator(q, k, v, mask, cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_kv, max_seqlen_q=max_seqlen_q, max_seqlen_kv=max_seqlen_kv, **kwargs)
# restore
out = self.mask_generator.restore(out)
return out
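# Minimal usage sketch for the class above (hedged: the class attributes are normally filled in
# by the surrounding framework; the registry keys used here, "sla_mask_generator" and
# "sla_triton_operator", are the ones registered elsewhere in this commit, and the frame count
# 21 is an arbitrary illustrative value):
#
#   GeneralSparseAttnWeight.sparse_mask_generator = "sla_mask_generator"
#   GeneralSparseAttnWeight.sparse_operator = "sla_triton_operator"
#   GeneralSparseAttnWeight.sparse_setting = {"sla_sparsity_ratio": 0.8}
#   GeneralSparseAttnWeight.attnmap_frame_num = 21
#   attn = GeneralSparseAttnWeight()
#   out = attn.apply(q, k, v, max_seqlen_q=q.shape[0])  # q/k/v: [seqlen, heads, head_dim]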
import torch
import triton
import triton.language as tl
@triton.jit
def _attn_fwd(
Q,
K,
V,
qk_scale: tl.constexpr,
topk: tl.constexpr,
LUT,
LSE,
OS,
L: tl.constexpr,
M_BLOCKS: tl.constexpr,
D: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
):
idx_m = tl.program_id(0).to(tl.int64)
idx_bh = tl.program_id(1).to(tl.int64)
qkv_offset = idx_bh * L * D
lut_offset = (idx_bh * M_BLOCKS + idx_m) * topk
lse_offset = idx_bh * L
offs_m = idx_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, D)
Q_ptrs = Q + qkv_offset + offs_m[:, None] * D + offs_d[None, :]
K_ptrs = K + qkv_offset + offs_n[None, :] * D + offs_d[:, None]
V_ptrs = V + qkv_offset + offs_n[:, None] * D + offs_d[None, :]
OS_ptrs = OS + qkv_offset + offs_m[:, None] * D + offs_d[None, :]
LUT_ptr = LUT + lut_offset
LSE_ptrs = LSE + lse_offset + offs_m
m_i = tl.full([BLOCK_M], -float("inf"), dtype=tl.float32)
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
o_s = tl.zeros([BLOCK_M, D], dtype=tl.float32)
q = tl.load(Q_ptrs, mask=offs_m[:, None] < L)
for block_idx in tl.range(topk):
idx_n = tl.load(LUT_ptr + block_idx)
n_mask = offs_n < L - idx_n * BLOCK_N
k = tl.load(K_ptrs + idx_n * BLOCK_N * D, mask=n_mask[None, :])
qk = tl.dot(q, k) * (qk_scale * 1.4426950408889634) # = 1 / ln(2)
if L - idx_n * BLOCK_N < BLOCK_N:
qk = tl.where(n_mask[None, :], qk, float("-inf"))
v = tl.load(V_ptrs + idx_n * BLOCK_N * D, mask=n_mask[:, None])
local_m = tl.max(qk, 1)
new_m = tl.maximum(m_i, local_m)
qk = qk - new_m[:, None]
p = tl.math.exp2(qk)
l_ij = tl.sum(p, 1)
alpha = tl.math.exp2(m_i - new_m)
o_s = o_s * alpha[:, None]
o_s += tl.dot(p.to(v.dtype), v)
l_i = l_i * alpha + l_ij
m_i = new_m
o_s = o_s / l_i[:, None]
tl.store(OS_ptrs, o_s.to(OS.type.element_ty), mask=offs_m[:, None] < L)
m_i += tl.math.log2(l_i)
tl.store(LSE_ptrs, m_i, mask=offs_m < L)
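# Notes on the forward kernel above: m_i and l_i track the running row maximum and softmax
# denominator in base-2 (hence qk is pre-scaled by 1/ln(2)), o_s accumulates the un-normalized
# output, and LSE stores the per-row log2-sum-exp so the backward kernels can recompute p with
# a single exp2.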
@triton.jit
def _attn_bwd_preprocess(
OS,
DOS,
DELTAS,
L,
D: tl.constexpr,
BLOCK_M: tl.constexpr,
):
idx_m = tl.program_id(0).to(tl.int64)
idx_bh = tl.program_id(1).to(tl.int64)
OS += idx_bh * L * D
DOS += idx_bh * L * D
DELTAS += idx_bh * L
offs_m = idx_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_d = tl.arange(0, D)
o_s = tl.load(OS + offs_m[:, None] * D + offs_d[None, :], mask=offs_m[:, None] < L)
do_s = tl.load(DOS + offs_m[:, None] * D + offs_d[None, :], mask=offs_m[:, None] < L)
delta_s = tl.sum(o_s * do_s, axis=1).to(DELTAS.type.element_ty)
tl.store(DELTAS + offs_m, delta_s, mask=offs_m < L)
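# The preprocess kernel computes DELTA = rowsum(O * dO), the per-row correction term that
# appears in dS = P * (dP - DELTA) inside the dQ and dK/dV kernels below.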
# the main inner-loop logic for computing dQ
@triton.jit
def _attn_bwd_dq(
Q,
K,
V,
LSE,
DELTAS,
DOS,
DQ,
LUT,
qk_scale: tl.constexpr,
topk: tl.constexpr,
L: tl.constexpr,
M_BLOCKS: tl.constexpr,
D: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
):
idx_m = tl.program_id(0).to(tl.int64)
idx_bh = tl.program_id(1).to(tl.int64)
offs_m = idx_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, D)
qkv_offset = idx_bh * L * D
lse_offset = idx_bh * L
lut_offset = (idx_bh * M_BLOCKS + idx_m) * topk
Q_ptrs = Q + qkv_offset + offs_m[:, None] * D + offs_d[None, :]
K_ptrs = K + qkv_offset + offs_n[:, None] * D + offs_d[None, :]
V_ptrs = V + qkv_offset + offs_n[:, None] * D + offs_d[None, :]
DQ_ptrs = DQ + qkv_offset + offs_m[:, None] * D + offs_d[None, :]
DOS_ptrs = DOS + qkv_offset + offs_m[:, None] * D + offs_d[None, :]
LSE_ptrs = LSE + lse_offset + offs_m
DELTAS_ptrs = DELTAS + lse_offset + offs_m
LUT_ptr = LUT + lut_offset
# Load Q, dO_S, DELTA_S and LSE: they stay in SRAM throughout the inner loop.
q = tl.load(Q_ptrs, mask=offs_m[:, None] < L)
do_s = tl.load(DOS_ptrs, mask=offs_m[:, None] < L)
delta_s = tl.load(DELTAS_ptrs, mask=offs_m < L)
lse = tl.load(LSE_ptrs, mask=offs_m < L, other=float("inf"))
dq = tl.zeros([BLOCK_M, D], dtype=tl.float32)
for block_idx in tl.range(topk, num_stages=2):
idx_n = tl.load(LUT_ptr + block_idx)
n_mask = offs_n < L - idx_n * BLOCK_N
k = tl.load(K_ptrs + idx_n * BLOCK_N * D, mask=n_mask[:, None])
v = tl.load(V_ptrs + idx_n * BLOCK_N * D, mask=n_mask[:, None])
qk = tl.dot(q, k.T) * (qk_scale * 1.4426950408889634) # = 1 / ln(2)
p = tl.math.exp2(qk - lse[:, None])
p = tl.where(n_mask[None, :], p, 0.0)
# Compute dP and dS.
dp = tl.dot(do_s, v.T).to(tl.float32)
ds = p * (dp - delta_s[:, None])
# Compute dQ.
dq += tl.dot(ds.to(k.dtype), k)
tl.store(DQ_ptrs, dq * qk_scale, mask=offs_m[:, None] < L)
@triton.jit
def _attn_bwd_dkdv(
Q,
K,
V,
DOS,
DK,
DV,
qk_scale,
KBID,
LSE,
DELTAS,
L: tl.constexpr,
M_BLOCKS: tl.constexpr,
N_BLOCKS: tl.constexpr,
D: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_SLICE_FACTOR: tl.constexpr,
):
BLOCK_M2: tl.constexpr = BLOCK_M // BLOCK_SLICE_FACTOR
idx_n = tl.program_id(0).to(tl.int64)
idx_bh = tl.program_id(1).to(tl.int64)
offs_n = idx_n * BLOCK_N + tl.arange(0, BLOCK_N)
offs_m = tl.arange(0, BLOCK_M2)
offs_d = tl.arange(0, D)
qkv_offset = idx_bh * L * D
kbid_offset = idx_bh * M_BLOCKS * N_BLOCKS
lse_offset = idx_bh * L
Q_ptrs = Q + qkv_offset + offs_m[:, None] * D + offs_d[None, :]
K_ptrs = K + qkv_offset + offs_n[:, None] * D + offs_d[None, :]
V_ptrs = V + qkv_offset + offs_n[:, None] * D + offs_d[None, :]
DOS_ptrs = DOS + qkv_offset + offs_m[:, None] * D + offs_d[None, :]
DK_ptrs = DK + qkv_offset + offs_n[:, None] * D + offs_d[None, :]
DV_ptrs = DV + qkv_offset + offs_n[:, None] * D + offs_d[None, :]
LSE_ptrs = LSE + lse_offset + offs_m
DELTAS_ptrs = DELTAS + lse_offset + offs_m
KBID_ptr = KBID + kbid_offset + idx_n
# Load K and V: they stay in SRAM throughout the inner loop.
k = tl.load(K_ptrs, mask=offs_n[:, None] < L)
v = tl.load(V_ptrs, mask=offs_n[:, None] < L)
dk = tl.zeros([BLOCK_N, D], dtype=tl.float32)
dv = tl.zeros([BLOCK_N, D], dtype=tl.float32)
for idx_m in tl.range(0, L, BLOCK_M2):
kbid = tl.load(KBID_ptr)
if kbid == 1:
m_mask = offs_m < L - idx_m
q = tl.load(Q_ptrs, mask=m_mask[:, None])
lse = tl.load(LSE_ptrs, mask=m_mask, other=float("inf"))
qkT = tl.dot(k, q.T) * (qk_scale * 1.4426950408889634) # = 1 / ln(2)
pT = tl.math.exp2(qkT - lse[None, :])
pT = tl.where(offs_n[:, None] < L, pT, 0.0)
do = tl.load(DOS_ptrs, mask=m_mask[:, None])
# Compute dV.
dv += tl.dot(pT.to(do.dtype), do)
delta = tl.load(DELTAS_ptrs, mask=m_mask)
# Compute dP and dS.
dpT = tl.dot(v, tl.trans(do))
dsT = pT * (dpT - delta[None, :])
dk += tl.dot(dsT.to(q.dtype), q)
# Increment pointers
Q_ptrs += BLOCK_M2 * D
DOS_ptrs += BLOCK_M2 * D
LSE_ptrs += BLOCK_M2
DELTAS_ptrs += BLOCK_M2
if (idx_m + BLOCK_M2) % BLOCK_M == 0:
KBID_ptr += N_BLOCKS
# Write back dK and dV
tl.store(DK_ptrs, dk * qk_scale, mask=offs_n[:, None] < L)
tl.store(DV_ptrs, dv, mask=offs_n[:, None] < L)
class _attention(torch.autograd.Function):
@staticmethod
def forward(ctx, q, k, v, k_block_id, lut, topk, BLOCK_M, BLOCK_N, qk_scale=None):
assert q.is_contiguous() and k.is_contiguous() and v.is_contiguous()
assert k_block_id.is_contiguous() and lut.is_contiguous()
# Only the following two block sizes are supported
assert BLOCK_M == 64 or BLOCK_M == 128
assert BLOCK_N == 64 or BLOCK_N == 128
B, H, L, D = q.shape
if qk_scale is None:
qk_scale = D**-0.5
M_BLOCKS = triton.cdiv(L, BLOCK_M)
o_s = torch.empty_like(v)
lse = torch.empty(q.shape[:-1], device=q.device, dtype=torch.float32)
grid = (M_BLOCKS, B * H)
_attn_fwd[grid](q, k, v, qk_scale, topk, lut, lse, o_s, L, M_BLOCKS, D, BLOCK_M, BLOCK_N, num_warps=4 if q.shape[-1] == 64 else 8, num_stages=3)
ctx.save_for_backward(q, k, v, k_block_id, lut, lse, o_s)
ctx.qk_scale = qk_scale
ctx.topk = topk
ctx.BLOCK_M = BLOCK_M
ctx.BLOCK_N = BLOCK_N
return o_s
@staticmethod
def backward(ctx, do_s):
q, k, v, k_block_id, lut, lse, o_s = ctx.saved_tensors
do_s = do_s.contiguous()
BLOCK_M, BLOCK_N = ctx.BLOCK_M, ctx.BLOCK_N
B, H, L, D = q.shape
M_BLOCKS = triton.cdiv(L, BLOCK_M)
N_BLOCKS = triton.cdiv(L, BLOCK_N)
dq = torch.empty_like(q)
dk = torch.empty_like(k)
dv = torch.empty_like(v)
delta_s = torch.empty_like(lse)
grid = (M_BLOCKS, B * H)
_attn_bwd_preprocess[grid](
o_s,
do_s,
delta_s,
L,
D,
BLOCK_M,
)
grid = (M_BLOCKS, B * H)
_attn_bwd_dq[grid](
q, k, v, lse, delta_s, do_s, dq, lut, ctx.qk_scale, ctx.topk, L, M_BLOCKS, D, BLOCK_M, BLOCK_N, num_warps=4 if q.shape[-1] == 64 else 8, num_stages=4 if q.shape[-1] == 64 else 5
)
grid = (N_BLOCKS, B * H)
_attn_bwd_dkdv[grid](
q,
k,
v,
do_s,
dk,
dv,
ctx.qk_scale,
k_block_id,
lse,
delta_s,
L,
M_BLOCKS,
N_BLOCKS,
D,
BLOCK_M,
BLOCK_N,
BLOCK_SLICE_FACTOR=BLOCK_M // 64,
num_warps=4 if q.shape[-1] == 64 else 8,
num_stages=4 if q.shape[-1] == 64 else 5,
)
return dq, dk, dv, None, None, None, None, None, None
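# Hedged usage sketch for the autograd op above (illustrative shapes; in this commit the block
# map, LUT and top-k normally come from get_block_map in .utils.sla_util):
#
#   q = torch.randn(1, 12, 32760, 128, device="cuda", dtype=torch.bfloat16)
#   k, v = torch.randn_like(q), torch.randn_like(q)
#   sparse_map, lut, topk = get_block_map(q, k, topk_ratio=0.2, BLKQ=128, BLKK=128)
#   out = _attention.apply(q, k, v, sparse_map, lut, topk, 128, 128)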
import torch
from loguru import logger
try:
from magi_attention.functional import flex_flash_attn_func as magi_ffa_func
except ImportError:
magi_ffa_func = None
try:
import flashinfer
except ImportError:
flashinfer = None
from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER
from .template import AttnWeightTemplate
def generate_nbhd_mask(a, block_num, attnmap_frame_num, coefficient=[1.0, 0.5, 0.056], min_width=1.0, device="cpu"):
"""
a : number of blocks per frame (may be fractional)
block_num : number of blocks per row/column of the block map
attnmap_frame_num : total number of frames
"""
i_indices = torch.arange(block_num, device=device).unsqueeze(1) # [block_num, 1]
j_indices = torch.arange(block_num, device=device).unsqueeze(0) # [1, block_num]
assert len(coefficient) <= attnmap_frame_num, f"coefficient length {len(coefficient)} should be <= attnmap_frame_num {attnmap_frame_num}"
width_list = [max(min_width, coefficient[i] * a) for i in range(len(coefficient))] + [min_width] * (attnmap_frame_num - len(coefficient))
logger.info(f"nbhd_attn width_list: {width_list}, len={len(width_list)}")
# attention sink frame: j <= a
mask_sink = j_indices <= a
mask_sparse = torch.zeros((block_num, block_num), dtype=torch.bool, device=device)
for interval in range(0, attnmap_frame_num):
n = i_indices // a
mask_sparse_base_1 = (j_indices >= (n + interval) * a) & (j_indices <= (n + interval + 1) * a)
n = j_indices // a
mask_sparse_base_2 = (i_indices >= (n + interval) * a) & (i_indices <= (n + interval + 1) * a)
width = width_list[interval]
mask_1 = mask_sparse_base_1 & (i_indices - j_indices + (interval * a + width) >= 0) & (i_indices - j_indices + (interval * a - width) <= 0)
mask_2 = mask_sparse_base_2 & (i_indices - j_indices - (interval * a - width) >= 0) & (i_indices - j_indices - (interval * a + width) <= 0)
mask_sparse = mask_sparse | mask_1 | mask_2
mask = mask_sink | mask_sparse
return mask
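# Worked example (illustrative numbers): with attnmap_frame_num=21 and block_num=256,
# a = 256 / 21 ≈ 12.19 blocks per frame, so width_list starts as
# [max(1.0, 1.0 * 12.19), max(1.0, 0.5 * 12.19), max(1.0, 0.056 * 12.19)] ≈ [12.19, 6.10, 1.0]
# and is padded with min_width=1.0 for the remaining 18 frame intervals. Every query block
# additionally keeps the attention-sink columns j <= a.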
def generate_qk_ranges(mask, q_block_size, k_block_size, seqlen):
# mask: [H, Q_block_num, K_block_num]
h_indices, i_indices, j_indices = torch.nonzero(mask, as_tuple=True)
base_offset = h_indices * seqlen
q_start = base_offset + i_indices * q_block_size
q_end = base_offset + torch.clamp((i_indices + 1) * q_block_size, max=seqlen)
k_start = base_offset + j_indices * k_block_size
k_end = base_offset + torch.clamp((j_indices + 1) * k_block_size, max=seqlen)
q_ranges = torch.stack([q_start, q_end], dim=1)
k_ranges = torch.stack([k_start, k_end], dim=1)
return q_ranges, k_ranges
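# Each True entry (h, i, j) of the block mask becomes one (q_range, k_range) pair; the
# h * seqlen base offset addresses head h once the heads are flattened into the sequence
# dimension (see NbhdAttnWeight.apply below).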
@ATTN_WEIGHT_REGISTER("nbhd_attn")
class NbhdAttnWeight(AttnWeightTemplate):
block_size = 128
seqlen = None
attnmap_frame_num = None
q_ranges = None
k_ranges = None
attn_type_map = None
coefficient = [1.0, 0.5, 0.056]
min_width = 1.0
def __init__(self):
self.config = {}
@classmethod
@torch.compiler.disable
def prepare_mask(cls, seqlen, head_num):
if seqlen == cls.seqlen:
return
block_num = (seqlen + cls.block_size - 1) // cls.block_size
block_num_per_frame = seqlen / cls.attnmap_frame_num / cls.block_size
mask = generate_nbhd_mask(block_num_per_frame, block_num, cls.attnmap_frame_num, coefficient=cls.coefficient, min_width=cls.min_width, device="cpu")
repeat_mask = mask.unsqueeze(0).repeat(head_num, 1, 1)
q_ranges, k_ranges = generate_qk_ranges(repeat_mask, cls.block_size, cls.block_size, seqlen)
attn_type_map = torch.zeros(len(q_ranges), dtype=torch.int32, device="cuda")
q_ranges = q_ranges.to(torch.int32).to("cuda")
k_ranges = k_ranges.to(torch.int32).to("cuda")
cls.seqlen = seqlen
cls.q_ranges = q_ranges
cls.k_ranges = k_ranges
cls.attn_type_map = attn_type_map
logger.info(f"NbhdAttnWeight Update: seqlen={seqlen}")
sparsity = 1 - mask.sum().item() / mask.numel()
logger.info(f"Attention sparsity: {sparsity}")
def apply(
self,
q,
k,
v,
cu_seqlens_q=None,
cu_seqlens_kv=None,
max_seqlen_q=None,
max_seqlen_kv=None,
**kwargs,
):
"""
q: [seqlen, head_num, head_dim]
k: [seqlen, head_num, head_dim]
v: [seqlen, head_num, head_dim]
"""
seqlen, head_num, head_dim = q.shape
self.prepare_mask(seqlen=seqlen, head_num=head_num)
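# Flatten heads into the sequence axis ([L, H, D] -> [H*L, 1, D]) so that the per-head
# offsets baked into q_ranges/k_ranges index the matching head's tokens.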
q = q.permute(1, 0, 2).reshape(head_num * seqlen, 1, head_dim)
k = k.permute(1, 0, 2).reshape(head_num * seqlen, 1, head_dim)
v = v.permute(1, 0, 2).reshape(head_num * seqlen, 1, head_dim)
out = magi_ffa_func(
q,
k,
v,
q_ranges=self.q_ranges,
k_ranges=self.k_ranges,
attn_type_map=self.attn_type_map,
auto_range_merge=True,
)[0]
out = out.reshape(head_num, seqlen, head_dim).permute(1, 0, 2)
return out.reshape(out.shape[0], -1)
@ATTN_WEIGHT_REGISTER("nbhd_attn_flashinfer")
class NbhdAttnWeightFlashInfer(AttnWeightTemplate):
block_size = 128
seqlen = None
attnmap_frame_num = None
coefficient = [1.0, 0.5, 0.056]
min_width = 1.0
sparse_wrapper = None
def __init__(self):
self.config = {}
@classmethod
@torch.compiler.disable
def prepare_mask(cls, seqlen, head_num, head_dim):
if seqlen == cls.seqlen:
return
block_num = (seqlen + cls.block_size - 1) // cls.block_size
block_num_per_frame = seqlen / cls.attnmap_frame_num / cls.block_size
mask = generate_nbhd_mask(block_num_per_frame, block_num, cls.attnmap_frame_num, coefficient=cls.coefficient, min_width=cls.min_width, device="cpu")
mask = mask.unsqueeze(0).repeat(head_num, 1, 1)
block_rowcol_size = torch.ones(block_num, dtype=torch.int32) * cls.block_size
block_rowcol_size[-1] = seqlen - cls.block_size * (block_num - 1)
block_rowcol_size = block_rowcol_size.unsqueeze(0).repeat(head_num, 1)
float_workspace_buffer = torch.empty(1024 * 1024 * 1024, dtype=torch.uint8, device="cuda:0")
cls.sparse_wrapper = flashinfer.sparse.VariableBlockSparseAttentionWrapper(float_workspace_buffer, backend="fa2")
cls.sparse_wrapper.plan(
block_mask_map=mask,
block_row_sz=block_rowcol_size,
block_col_sz=block_rowcol_size,
num_qo_heads=head_num,
num_kv_heads=head_num,
head_dim=head_dim,
q_data_type=torch.bfloat16,
)
cls.seqlen = seqlen
logger.info(f"NbhdAttnWeightFlashInfer Update: seqlen={seqlen}")
sparsity = 1 - mask.sum().item() / mask.numel()
logger.info(f"Attention sparsity: {sparsity}")
def apply(
self,
q,
k,
v,
cu_seqlens_q=None,
cu_seqlens_kv=None,
max_seqlen_q=None,
max_seqlen_kv=None,
**kwargs,
):
"""
q: [seqlen, head_num, head_dim]
k: [seqlen, head_num, head_dim]
v: [seqlen, head_num, head_dim]
"""
self.prepare_mask(seqlen=q.shape[0], head_num=q.shape[1], head_dim=q.shape[2])
q = q.transpose(0, 1)
k = k.transpose(0, 1)
v = v.transpose(0, 1)
out = self.sparse_wrapper.run(q, k, v)
out = out.transpose(0, 1)
return out.reshape(out.shape[0], -1)
import torch
from loguru import logger
try:
from magi_attention.functional import flex_flash_attn_func as magi_ffa_func
except ImportError:
magi_ffa_func = None
from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER
from .template import AttnWeightTemplate
def shrinkMaskStrict(mask, block_size=128):
seqlen = mask.shape[0]
block_num = seqlen // block_size
mask = mask[: block_num * block_size, : block_num * block_size].view(block_num, block_size, block_num, block_size)
col_densities = mask.sum(dim=1) / block_size
# we want the minimum non-zero column density in the block
non_zero_densities = col_densities > 0
high_density_cols = col_densities > 1 / 3
frac_high_density_cols = high_density_cols.sum(dim=-1) / (non_zero_densities.sum(dim=-1) + 1e-9)
block_mask = frac_high_density_cols > 0.6
block_mask[0:0] = True
block_mask[-1:-1] = True
return block_mask
def get_window_width(i, j, token_per_frame, sparse_type, num_frame, decay_factor=1, block_size=128, model_type=None):
assert sparse_type in ["radial"]
dist = abs(i - j)
if model_type == "wan":
if dist < 1:
return token_per_frame
if dist == 1:
return token_per_frame // 2
elif model_type == "hunyuan":
if dist <= 1:
return token_per_frame
else:
raise ValueError(f"Unknown model type: {model_type}")
group = dist.bit_length()
decay_length = 2 ** token_per_frame.bit_length() / 2**group * decay_factor
threshold = block_size
if decay_length >= threshold:
return decay_length
else:
return threshold
def get_diagonal_split_mask(i, j, token_per_frame, sparse_type, device):
assert sparse_type in ["radial"]
dist = abs(i - j)
group = dist.bit_length()
threshold = 128 # hardcoded threshold for now, which is equal to block-size
decay_length = 2 ** token_per_frame.bit_length() / 2**group
if decay_length >= threshold:
return torch.ones((token_per_frame, token_per_frame), device=device, dtype=torch.bool)
split_factor = int(threshold / decay_length)
modular = dist % split_factor
if modular == 0:
return torch.ones((token_per_frame, token_per_frame), device=device, dtype=torch.bool)
else:
return torch.zeros((token_per_frame, token_per_frame), device=device, dtype=torch.bool)
def gen_log_mask_shrinked(device, s, video_token_num, num_frame, block_size=128, sparse_type="log", decay_factor=0.5, model_type=None):
"""
A more memory-friendly version: we generate the attention mask one frame pair at a time,
shrink it, and store it into the final result.
"""
final_log_mask = torch.zeros(((s + block_size - 1) // block_size, (s + block_size - 1) // block_size), device=device, dtype=torch.bool)
token_per_frame = video_token_num // num_frame
video_text_border = video_token_num // block_size
col_indices = torch.arange(0, token_per_frame, device=device).view(1, -1)
row_indices = torch.arange(0, token_per_frame, device=device).view(-1, 1)
final_log_mask[video_text_border:] = True
final_log_mask[:, video_text_border:] = True
for i in range(num_frame):
for j in range(num_frame):
local_mask = torch.zeros((token_per_frame, token_per_frame), device=device, dtype=torch.bool)
if j == 0 and model_type == "wan": # this is attention sink
local_mask = torch.ones((token_per_frame, token_per_frame), device=device, dtype=torch.bool)
else:
window_width = get_window_width(i, j, token_per_frame, sparse_type, num_frame, decay_factor=decay_factor, block_size=block_size, model_type=model_type)
local_mask = torch.abs(col_indices - row_indices) <= window_width
split_mask = get_diagonal_split_mask(i, j, token_per_frame, sparse_type, device)
local_mask = torch.logical_and(local_mask, split_mask)
remainder_row = (i * token_per_frame) % block_size
remainder_col = (j * token_per_frame) % block_size
# get the padded size
all_length_row = remainder_row + ((token_per_frame - 1) // block_size + 1) * block_size
all_length_col = remainder_col + ((token_per_frame - 1) // block_size + 1) * block_size
padded_local_mask = torch.zeros((all_length_row, all_length_col), device=device, dtype=torch.bool)
padded_local_mask[remainder_row : remainder_row + token_per_frame, remainder_col : remainder_col + token_per_frame] = local_mask
# shrink the mask
block_mask = shrinkMaskStrict(padded_local_mask, block_size=block_size)
# set the block mask to the final log mask
block_row_start = (i * token_per_frame) // block_size
block_col_start = (j * token_per_frame) // block_size
block_row_end = block_row_start + block_mask.shape[0]
block_col_end = block_col_start + block_mask.shape[1]
final_log_mask[block_row_start:block_row_end, block_col_start:block_col_end] = torch.logical_or(final_log_mask[block_row_start:block_row_end, block_col_start:block_col_end], block_mask)
return final_log_mask
def generate_qk_ranges(mask, block_size, seqlen):
indices = torch.nonzero(mask, as_tuple=False) # shape: [N, 2]
i_indices = indices[:, 0] # [N]
j_indices = indices[:, 1] # [N]
q_start = i_indices * block_size # [N]
q_end = torch.clamp((i_indices + 1) * block_size, max=seqlen) # [N]
k_start = j_indices * block_size # [N]
k_end = torch.clamp((j_indices + 1) * block_size, max=seqlen) # [N]
q_ranges = torch.stack([q_start, q_end], dim=1) # [N, 2]
k_ranges = torch.stack([k_start, k_end], dim=1) # [N, 2]
return q_ranges, k_ranges
@ATTN_WEIGHT_REGISTER("radial_attn")
class RadialAttnWeight(AttnWeightTemplate):
block_size = 128
seqlen = None
attnmap_frame_num = None
q_ranges = None
k_ranges = None
attn_type_map = None
def __init__(self):
self.config = {}
@classmethod
def prepare_mask(cls, seqlen):
if seqlen == cls.seqlen:
return
mask = gen_log_mask_shrinked(
device="cuda", s=seqlen, video_token_num=seqlen, num_frame=cls.attnmap_frame_num, block_size=cls.block_size, sparse_type="radial", decay_factor=0.2, model_type="wan"
)
q_ranges, k_ranges = generate_qk_ranges(mask, cls.block_size, seqlen)
attn_type_map = torch.zeros(len(q_ranges), dtype=torch.int32, device="cuda")
q_ranges = q_ranges.to(torch.int32).to("cuda")
k_ranges = k_ranges.to(torch.int32).to("cuda")
cls.seqlen = seqlen
cls.q_ranges = q_ranges
cls.k_ranges = k_ranges
cls.attn_type_map = attn_type_map
logger.info(f"RadialAttnWeight Update: seqlen={seqlen}")
sparsity = 1 - mask.sum().item() / mask.numel()
logger.info(f"Attention sparsity: {sparsity}")
def apply(
self,
q,
k,
v,
cu_seqlens_q=None,
cu_seqlens_kv=None,
max_seqlen_q=None,
max_seqlen_kv=None,
**kwargs,
):
"""
q: [seqlen, head_num, head_dim]
k: [seqlen, head_num, head_dim]
v: [seqlen, head_num, head_dim]
"""
self.prepare_mask(seqlen=q.shape[0])
out = magi_ffa_func(
q,
k,
v,
q_ranges=self.q_ranges,
k_ranges=self.k_ranges,
attn_type_map=self.attn_type_map,
auto_range_merge=True,
)[0]
return out.reshape(out.shape[0], -1)
import torch
import torch.distributed as dist
import torch.nn.functional as F
from loguru import logger
from lightx2v.utils.envs import *
from lightx2v.utils.quant_utils import dequant_fp8_vllm, quant_fp8_vllm
from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER
from .template import AttnWeightTemplate
from .utils.ring_comm import RingComm
try:
import flash_attn
from flash_attn.flash_attn_interface import flash_attn_varlen_func
except ImportError:
logger.info("flash_attn_varlen_func not found, please install flash_attn2 first")
flash_attn_varlen_func = None
@torch.jit.script
def _update_out_and_lse(
out,
lse,
block_out,
block_lse,
):
block_out = block_out.to(torch.float32)
block_lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1)
# new_lse = lse + torch.log(1 + torch.exp(block_lse - lse))
# torch.exp(lse - new_lse) * out + torch.exp(block_lse - new_lse) * block_out
# For additional context and discussion, please refer to:
# https://github.com/zhuzilin/ring-flash-attention/pull/34#issuecomment-2076126795
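# With new_lse = log(exp(lse) + exp(block_lse)), exp(lse - new_lse) = sigmoid(lse - block_lse)
# and exp(block_lse - new_lse) = sigmoid(block_lse - lse), so the two lines below are the
# numerically stable form of the commented-out update above.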
out = out - F.sigmoid(block_lse - lse) * (out - block_out)
lse = lse - F.logsigmoid(lse - block_lse)
return out, lse
@ATTN_WEIGHT_REGISTER("ring")
class RingAttnWeight(AttnWeightTemplate):
def __init__(self):
self.config = {}
self.helper = RingAttnHelper()
def apply(
self,
q,
k,
v,
slice_qkv_len,
cu_seqlens_qkv,
attention_module=None,
attention_type="flash_attn2",
seq_p_group=None,
use_fp8_comm=False,
use_tensor_fusion=False,
enable_head_parallel=False,
**kwargs,
):
"""
Run ring attention over the combined image and text queries, keys, and values.
Args:
q (torch.Tensor): query tensor of shape [shard_seqlen, heads, hidden_dims]
k (torch.Tensor): key tensor of shape [shard_seqlen, heads, hidden_dims]
v (torch.Tensor): value tensor of shape [shard_seqlen, heads, hidden_dims]
slice_qkv_len (int): length of the image portion of q/k/v
cu_seqlens_qkv (torch.Tensor): cumulative sequence lengths covering the text and image parts
attention_type (str): attention backend, defaults to "flash_attn2"
Returns:
torch.Tensor: the computed attention output
"""
assert not enable_head_parallel, "RingAttn can't support head parallel mode."
use_kv_fusion = use_tensor_fusion
# Get the rank of the current process and the world size of the sequence-parallel group
cur_rank = dist.get_rank(seq_p_group)
world_size = dist.get_world_size(seq_p_group)
img_qkv_len = slice_qkv_len
txt_qkv_len, txt_mask_len = self.helper._get_text_lengths(cu_seqlens_qkv, img_qkv_len)
# if RING_COMM is None:
# init_ring_comm()
RING_COMM = RingComm(seq_p_group)
# if len(cu_seqlens_qkv) == 3:
#     txt_qkv_len = cu_seqlens_qkv[1] - img_qkv_len  # text q/k/v length
#     txt_mask_len = cu_seqlens_qkv[2] - img_qkv_len  # text mask length
# elif len(cu_seqlens_qkv) == 2:
#     txt_qkv_len = cu_seqlens_qkv[1] - img_qkv_len  # text q/k/v length
#     txt_mask_len = None
q = q.unsqueeze(0)
k = k.unsqueeze(0)
v = v.unsqueeze(0)
heads, hidden_dims = k.shape[-2], k.shape[-1]
img_q, img_k, img_v = q[:, :img_qkv_len, :, :].contiguous(), k[:, :img_qkv_len, :, :].contiguous(), v[:, :img_qkv_len, :, :].contiguous()
txt_q, txt_k, txt_v = (
q[:, img_qkv_len : img_qkv_len + txt_qkv_len, :, :].contiguous(),
k[:, img_qkv_len : img_qkv_len + txt_qkv_len, :, :].contiguous(),
v[:, img_qkv_len : img_qkv_len + txt_qkv_len, :, :].contiguous(),
)
out, lse, next_k, next_v = None, None, None, None
if len(cu_seqlens_qkv) == 3:
q = torch.cat((img_q, txt_q), dim=1)
k = img_k
v = img_v
if use_kv_fusion:
txt_kv = torch.stack([txt_k, txt_v], dim=0).reshape(2, txt_qkv_len, heads, hidden_dims).contiguous()
kv, original_dtype, original_shape = self.helper._prepare_kv_tensors(k, v, use_kv_fusion)
else:
original_dtype = k.dtype
original_shape = k.shape
for step in range(world_size):
if step + 1 != world_size:
if use_fp8_comm:
if use_kv_fusion:
next_kv_fp8, next_kv_scale = self.helper._send_recv_tensor(kv, hidden_dims, RING_COMM, use_fp8_comm, original_shape)
else:
next_k_fp8, next_k_scale = self.helper._send_recv_tensor(k, hidden_dims, RING_COMM, use_fp8_comm, original_shape)
next_v_fp8, next_v_scale = self.helper._send_recv_tensor(v, hidden_dims, RING_COMM, use_fp8_comm, original_shape)
else:
if use_kv_fusion:
next_kv = self.helper._send_recv_tensor(kv, hidden_dims, RING_COMM, use_fp8_comm, original_shape)[0]
else:
next_k = self.helper._send_recv_tensor(k, hidden_dims, RING_COMM, use_fp8_comm, original_shape)[0]
next_v = self.helper._send_recv_tensor(v, hidden_dims, RING_COMM, use_fp8_comm, original_shape)[0]
RING_COMM.commit()
if step + 1 == world_size:
if use_kv_fusion:
kv = torch.cat((kv, txt_kv), dim=1)
else:
k = torch.cat((k, txt_k), dim=1)
v = torch.cat((v, txt_v), dim=1)
if use_kv_fusion:
block_out, block_lse = self.ring_attn_sub_kv_fusion(q, kv)
else:
block_out, block_lse = self.ring_attn_sub(q, k, v)
out, lse = self.update_out_and_lse(out, lse, block_out, block_lse)
if step + 1 != world_size:
RING_COMM.wait()
if use_fp8_comm:
if use_kv_fusion:
kv = self.helper._dequantize_received(next_kv_fp8, next_kv_scale, original_dtype, original_shape, use_kv_fusion=True, is_kv_fusion=True)
else:
k, v = self.helper._dequantize_received(
next_k_fp8, next_k_scale, original_dtype, original_shape, use_kv_fusion=False, is_kv_fusion=False, v_fp8=next_v_fp8, v_scale=next_v_scale
)
else:
if use_kv_fusion:
kv = next_kv
else:
k, v = next_k, next_v
attn1 = out.to(GET_DTYPE()).squeeze(0).reshape(img_qkv_len + txt_qkv_len, -1)
if txt_mask_len > 0:
attn2, *_ = flash_attn.flash_attn_interface._flash_attn_forward(
q[:, -(txt_mask_len - txt_qkv_len) :, :, :].contiguous(),
k[:, -(txt_mask_len - txt_qkv_len) :, :, :].contiguous(),
v[:, -(txt_mask_len - txt_qkv_len) :, :, :].contiguous(),
dropout_p=0.0,
softmax_scale=q.shape[-1] ** (-0.5),
causal=False,
window_size_left=-1,
window_size_right=-1,
softcap=0.0,
alibi_slopes=None,
return_softmax=False,
)
attn2 = attn2.to(GET_DTYPE()).squeeze(0).reshape((txt_mask_len - txt_qkv_len), -1)
attn1 = torch.cat([attn1, attn2], dim=0)
return attn1
def ring_attn_sub_kv_fusion(self, q, kv, dropout_p=0.0, softmax_scale=None, causal=False, window_size=(-1, -1), softcap=0.0, alibi_slopes=None, return_softmax=False):
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
block_out, block_lse, _, _ = flash_attn.flash_attn_interface._flash_attn_forward(
q,
kv[:1, :, :, :],
kv[1:, :, :, :],
dropout_p=dropout_p,
softmax_scale=softmax_scale,
causal=causal,
window_size_left=window_size[0],
window_size_right=window_size[1],
softcap=softcap,
alibi_slopes=alibi_slopes,
return_softmax=return_softmax,
)
return block_out, block_lse
def ring_attn_sub(self, q, k, v, dropout_p=0.0, softmax_scale=None, causal=False, window_size=(-1, -1), softcap=0.0, alibi_slopes=None, return_softmax=False):
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
block_out, block_lse, _, _ = flash_attn.flash_attn_interface._flash_attn_forward(
q,
k,
v,
dropout_p=dropout_p,
softmax_scale=softmax_scale,
causal=causal,
window_size_left=window_size[0],
window_size_right=window_size[1],
softcap=softcap,
alibi_slopes=alibi_slopes,
return_softmax=return_softmax,
)
return block_out, block_lse
def update_out_and_lse(
self,
out,
lse,
block_out,
block_lse,
slice_=None,
):
if out is None:
if slice_ is not None:
raise RuntimeError("first update_out_and_lse should not pass slice_ args")
out = block_out.to(torch.float32)
lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1)
elif slice_ is not None:
slice_out, slice_lse = out[slice_], lse[slice_]
slice_out, slice_lse = _update_out_and_lse(slice_out, slice_lse, block_out, block_lse)
out[slice_], lse[slice_] = slice_out, slice_lse
else:
out, lse = _update_out_and_lse(out, lse, block_out, block_lse)
return out, lse
class RingAttnHelper:
"""辅助函数类,处理 Ring Attention 中的量化、通信和反量化逻辑"""
@staticmethod
def _quant_and_send(tensor, hidden_dims, comm, original_shape=None):
"""
Quantize a tensor to FP8 and send/receive it through the communicator.
Args:
tensor: tensor to quantize and send
hidden_dims: hidden dimension size
comm: communicator object
original_shape: original shape (used to reshape back after quantization)
Returns:
tuple: (received quantized tensor, received scale tensor)
"""
if original_shape is None:
original_shape = tensor.shape
# Quantize to FP8
tensor_fp8, tensor_scale = quant_fp8_vllm(tensor.reshape(-1, hidden_dims))
# Reshape back to the original shape
tensor_fp8 = tensor_fp8.reshape(original_shape)
tensor_scale = tensor_scale.reshape(original_shape[0], original_shape[1], original_shape[2], 1)
# Send/receive the quantized tensor and its scale
next_tensor_fp8 = comm.send_recv(tensor_fp8)
next_tensor_scale = comm.send_recv(tensor_scale)
return next_tensor_fp8, next_tensor_scale
@staticmethod
def _prepare_kv_tensors(k, v, use_kv_fusion):
"""
Prepare the K and V tensors, returning the appropriate tensor depending on whether KV fusion is used.
Args:
k: key tensor
v: value tensor
use_kv_fusion: whether KV fusion is enabled
Returns:
tuple: (main tensor, original dtype, original shape)
"""
original_dtype = k.dtype
original_shape = k.shape
if use_kv_fusion:
# Fuse K and V into a single tensor
kv = torch.stack([k, v], dim=0).reshape(2, k.shape[1], k.shape[2], k.shape[3]).contiguous()
return kv, original_dtype, kv.shape
else:
return k, original_dtype, original_shape
@staticmethod
def _dequantize_received(next_tensor_fp8, next_tensor_scale, original_dtype, original_shape, use_kv_fusion=False, is_kv_fusion=False, v_fp8=None, v_scale=None):
"""
Dequantize the received FP8 tensor(s).
Args:
next_tensor_fp8: received quantized tensor
next_tensor_scale: received scale tensor
original_dtype: original dtype
original_shape: original shape
use_kv_fusion: whether KV-fusion mode is enabled
is_kv_fusion: whether the current tensor is a fused KV tensor
v_fp8, v_scale: V tensor and scale in the non-fused mode
Returns:
tuple: dequantized (k, v) tensors, or a single kv tensor
"""
if use_kv_fusion and is_kv_fusion:
# KV-fusion mode
return dequant_fp8_vllm(next_tensor_fp8, next_tensor_scale, original_dtype)
elif not use_kv_fusion:
# Separate (non-fused) mode
k = dequant_fp8_vllm(next_tensor_fp8, next_tensor_scale, original_dtype)
v = dequant_fp8_vllm(v_fp8, v_scale, original_dtype)
return k, v
else:
# Default: return a single dequantized tensor
return dequant_fp8_vllm(next_tensor_fp8, next_tensor_scale, original_dtype)
@staticmethod
def _send_recv_tensor(tensor, hidden_dims, comm, use_fp8_comm, original_shape=None):
"""
Send/receive a tensor, choosing the communication path based on whether FP8 is used.
Args:
tensor: tensor to send
hidden_dims: hidden dimension size
comm: communicator object
use_fp8_comm: whether to use FP8 communication
original_shape: original shape
Returns:
tuple: (received tensor, and the scale tensor or None)
"""
if use_fp8_comm:
return RingAttnHelper._quant_and_send(tensor, hidden_dims, comm, original_shape)
else:
next_tensor = comm.send_recv(tensor)
return next_tensor, None
@staticmethod
def _get_text_lengths(cu_seqlens_qkv, img_qkv_len):
"""
Derive the text lengths from the cumulative sequence lengths.
Args:
cu_seqlens_qkv: cumulative sequence lengths
img_qkv_len: image sequence length
Returns:
tuple: (text q/k/v length, text mask length)
"""
if len(cu_seqlens_qkv) == 3:
txt_qkv_len = cu_seqlens_qkv[1] - img_qkv_len
txt_mask_len = cu_seqlens_qkv[2] - img_qkv_len
elif len(cu_seqlens_qkv) == 2:
txt_qkv_len = cu_seqlens_qkv[1] - img_qkv_len
txt_mask_len = 0
else:
raise ValueError(f"Invalid cu_seqlens_qkv length: {len(cu_seqlens_qkv)}")
return txt_qkv_len, txt_mask_len
import torch
from loguru import logger
from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER
from .template import AttnWeightTemplate
capability = torch.cuda.get_device_capability(0) if torch.cuda.is_available() else None
if capability in [(8, 9), (12, 0)]:
try:
from sageattention import sageattn_qk_int8_pv_fp16_triton as sageattn
except ImportError:
logger.info("sageattn not found, please install sageattention first")
sageattn = None
else:
try:
from sageattention import sageattn
except ImportError:
logger.info("sageattn not found, please install sageattention first")
sageattn = None
try:
from sageattn3 import sageattn3_blackwell
except ImportError:
logger.info("sageattn3 not found, please install sageattention first")
sageattn3_blackwell = None
@ATTN_WEIGHT_REGISTER("sage_attn2")
class SageAttn2Weight(AttnWeightTemplate):
def __init__(self):
self.config = {}
def apply(
self,
q,
k,
v,
cu_seqlens_q=None,
cu_seqlens_kv=None,
max_seqlen_q=None,
max_seqlen_kv=None,
**kwargs,
):
q, k, v = q.contiguous(), k.contiguous(), v.contiguous()
if len(q.shape) == 3:
bs = 1
q, k, v = q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0)
elif len(q.shape) == 4:
bs = q.shape[0]
x = sageattn(
q,
k,
v,
tensor_layout="NHD",
).view(bs * max_seqlen_q, -1)
return x
@ATTN_WEIGHT_REGISTER("sage_attn3")
class SageAttn3Weight(AttnWeightTemplate):
def __init__(self):
self.config = {}
def apply(
self,
q,
k,
v,
cu_seqlens_q=None,
cu_seqlens_kv=None,
max_seqlen_q=None,
max_seqlen_kv=None,
**kwargs,
):
q, k, v = q.contiguous(), k.contiguous(), v.contiguous()
if len(q.shape) == 3:
bs = 1
q, k, v = q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0)
elif len(q.shape) == 4:
bs = q.shape[0]
x = sageattn3_blackwell(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)).transpose(1, 2).reshape(bs * max_seqlen_q, -1)
return x
import torch
from loguru import logger
from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER
from .kernels.sla_kernel import _attention
from .template import AttnWeightTemplate
from .utils.sla_util import get_block_map, get_cuda_arch
try:
import spas_sage_attn._fused as fused
import spas_sage_attn._qattn as qattn
from spas_sage_attn.utils import block_map_lut_triton, get_vanilla_qk_quant
except ImportError:
logger.warning("spas_sage_attn is not installed. SageSparseLinearAttention will not be available.")
SAGE2PP_ENABLED = True
try:
from spas_sage_attn._qattn import qk_int8_sv_f8_accum_f16_block_sparse_attn_inst_buf_fuse_v_scale_with_pv_threshold
except ImportError:
SAGE2PP_ENABLED = False
try:
from magi_attention.functional import flex_flash_attn_func as magi_ffa_func
except ImportError:
magi_ffa_func = None
@ATTN_WEIGHT_REGISTER("sla_attn")
class SlaAttnWeight(AttnWeightTemplate):
sparsity_ratio = 0.8
operator = "triton"
def __init__(self):
self.config = {}
self.arch = get_cuda_arch(torch.cuda.current_device())
self.topk = 1 - self.sparsity_ratio
if self.operator == "triton":
self.BLKQ, self.BLKK = 128, 128
self.apply_func = self.apply_triton
elif self.operator == "sage":
if self.arch == "sm90":
self.BLKQ, self.BLKK = 64, 128
else:
self.BLKQ, self.BLKK = 128, 64
self.apply_func = self.apply_sage
elif self.operator == "magi":
self.BLKQ, self.BLKK = 128, 128
self.apply_func = self.apply_magi
else:
raise NotImplementedError(f"Not supported SLA operator: {self.operator}.")
logger.info(f"SlaAttnWeight: sparsity_ratio={self.sparsity_ratio}, operator={self.operator}, topk={self.topk}, BLKQ={self.BLKQ}, BLKK={self.BLKK}")
def apply(
self,
q,
k,
v,
cu_seqlens_q=None,
cu_seqlens_kv=None,
max_seqlen_q=None,
max_seqlen_kv=None,
**kwargs,
):
return self.apply_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, **kwargs)
def apply_triton(
self,
q,
k,
v,
cu_seqlens_q=None,
cu_seqlens_kv=None,
max_seqlen_q=None,
max_seqlen_kv=None,
**kwargs,
):
# (L, H, D) -> (B, H, L, D)
q = q.unsqueeze(0).transpose(1, 2).contiguous()
k = k.unsqueeze(0).transpose(1, 2).contiguous()
v = v.unsqueeze(0).transpose(1, 2).contiguous()
sparse_map, lut, real_topk = get_block_map(q, k, topk_ratio=self.topk, BLKQ=self.BLKQ, BLKK=self.BLKK)
out = _attention.apply(q, k, v, sparse_map, lut, real_topk, self.BLKQ, self.BLKK)
out = out.transpose(1, 2).reshape(max_seqlen_q, -1)
return out
def apply_sage(
self,
q,
k,
v,
cu_seqlens_q=None,
cu_seqlens_kv=None,
max_seqlen_q=None,
max_seqlen_kv=None,
**kwargs,
):
# (L, H, D) -> (B, H, L, D)
q = q.unsqueeze(0).transpose(1, 2).contiguous()
k = k.unsqueeze(0).transpose(1, 2).contiguous()
v = v.unsqueeze(0).transpose(1, 2).contiguous()
sparse_map, lut, real_topk = get_block_map(q, k, topk_ratio=self.topk, BLKQ=self.BLKQ, BLKK=self.BLKK)
km = k.mean(dim=-2, keepdim=True)
headdim = q.size(-1)
q_int8, q_scale, k_int8, k_scale = get_vanilla_qk_quant(q, k, km, self.BLKQ, self.BLKK)
lut, valid_block_num = block_map_lut_triton(sparse_map)
scale = 1.0 / (headdim**0.5)
assert headdim in [64, 128], "headdim should be in [64, 128]. For other headdim, you can use padding and specify the softmax scale."
o_s = torch.empty_like(q)
if self.arch in ("sm80", "sm86", "sm87"):
pvthreshold = torch.full((q.shape[-3],), 1e6, dtype=torch.float32, device=q.device)
v_fp16 = v.to(torch.float16)
qattn.qk_int8_sv_f16_accum_f16_block_sparse_attn_inst_buf_with_pv_threshold(q_int8, k_int8, v_fp16, o_s, lut, valid_block_num, pvthreshold, q_scale, k_scale, 1, False, 1, scale, 0)
else:
b, h_kv, kv_len, head_dim = v.shape
padded_len = (kv_len + 127) // 128 * 128
v_transposed_permutted = torch.empty((b, h_kv, head_dim, padded_len), dtype=v.dtype, device=v.device)
fused.transpose_pad_permute_cuda(v, v_transposed_permutted, 1)
v_fp8 = torch.empty(v_transposed_permutted.shape, dtype=torch.float8_e4m3fn, device=v.device)
v_scale = torch.empty((b, h_kv, head_dim), dtype=torch.float32, device=v.device)
fused.scale_fuse_quant_cuda(v_transposed_permutted, v_fp8, v_scale, kv_len, 2.25, 1)
if self.arch == "sm90":
qattn.qk_int8_sv_f8_accum_f32_block_sparse_attn_inst_buf_fuse_v_scale_sm90(q_int8, k_int8, v_fp8, o_s, lut, valid_block_num, q_scale, k_scale, v_scale, 1, False, 1, scale)
else:
pvthreshold = torch.full((q.shape[-3],), 1e6, dtype=torch.float32, device=q.device)
if SAGE2PP_ENABLED:
qk_int8_sv_f8_accum_f16_block_sparse_attn_inst_buf_fuse_v_scale_with_pv_threshold(
q_int8, k_int8, v_fp8, o_s, lut, valid_block_num, pvthreshold, q_scale, k_scale, v_scale, 1, False, 1, scale, 0
)
else:
qattn.qk_int8_sv_f8_accum_f32_block_sparse_attn_inst_buf_fuse_v_scale_with_pv_threshold(
q_int8, k_int8, v_fp8, o_s, lut, valid_block_num, pvthreshold, q_scale, k_scale, v_scale, 1, False, 1, scale, 0
)
out = o_s.transpose(1, 2).reshape(max_seqlen_q, -1)
return out
def apply_magi(
self,
q,
k,
v,
cu_seqlens_q=None,
cu_seqlens_kv=None,
max_seqlen_q=None,
max_seqlen_kv=None,
**kwargs,
):
# (L, H, D) -> (B, H, L, D)
q_block_map, k_block_map = q.unsqueeze(0).transpose(1, 2), k.unsqueeze(0).transpose(1, 2)
q_block_map = q_block_map.contiguous()
k_block_map = k_block_map.contiguous()
sparse_map, lut, real_topk = get_block_map(q_block_map, k_block_map, topk_ratio=self.topk, BLKQ=self.BLKQ, BLKK=self.BLKK)
seqlen, head_num, head_dim = q.shape
q_ranges, k_ranges = self.generate_qk_ranges(sparse_map[0], self.BLKQ, self.BLKK, seqlen)
attn_type_map = torch.zeros(len(q_ranges), dtype=torch.int32, device="cpu").to(q.device, non_blocking=True)
q = q.permute(1, 0, 2).reshape(head_num * seqlen, 1, head_dim)
k = k.permute(1, 0, 2).reshape(head_num * seqlen, 1, head_dim)
v = v.permute(1, 0, 2).reshape(head_num * seqlen, 1, head_dim)
out = magi_ffa_func(
q,
k,
v,
q_ranges=q_ranges,
k_ranges=k_ranges,
attn_type_map=attn_type_map,
auto_range_merge=True,
)[0]
out = out.reshape(head_num, seqlen, head_dim).permute(1, 0, 2)
return out.reshape(out.shape[0], -1)
def generate_qk_ranges(self, mask, q_block_size, k_block_size, seqlen):
# mask: [H, Q_block_num, K_block_num]
h_indices, i_indices, j_indices = torch.nonzero(mask, as_tuple=True)
base_offset = h_indices * seqlen
q_start = base_offset + i_indices * q_block_size
q_end = base_offset + torch.clamp((i_indices + 1) * q_block_size, max=seqlen)
k_start = base_offset + j_indices * k_block_size
k_end = base_offset + torch.clamp((j_indices + 1) * k_block_size, max=seqlen)
q_ranges = torch.stack([q_start, q_end], dim=1).to(dtype=torch.int32)
k_ranges = torch.stack([k_start, k_end], dim=1).to(dtype=torch.int32)
return q_ranges, k_ranges
from abc import ABC, abstractmethod
import torch
import torch.nn.functional as F
from loguru import logger
from lightx2v.utils.registry_factory import SPARSE_MASK_GENERATOR_REGISTER
from .nbhd_attn import generate_nbhd_mask
from .svg_attn import diagonal_band_mask_from_sparsity, get_attention_mask, wan_hidden_states_placement, wan_sparse_head_placement
from .utils.sla_util import get_block_map
class GeneralMaskGenerator(ABC):
def __init__(self, q_block_size=128, k_block_size=128, sparse_setting={}, attnmap_frame_num=None):
self.sparse_setting = sparse_setting
self.q_block_size = q_block_size
self.k_block_size = k_block_size
self.attnmap_frame_num = attnmap_frame_num
@abstractmethod
def __call__(self, q, k):
pass
def reorg(self, q, k, v):
return q, k, v
def restore(self, out):
return out
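# Contract for mask generators: __call__(q, k) returns a block mask of shape
# [B, H, Q_block_num, K_block_num]; reorg/restore are identity hooks that subclasses
# (e.g. SvgMaskGenerator below) override to permute tokens before/after the sparse operator.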
@SPARSE_MASK_GENERATOR_REGISTER("sla_mask_generator")
class SlaMaskGenerator(GeneralMaskGenerator):
def __init__(self, q_block_size=128, k_block_size=128, sparse_setting={}, attnmap_frame_num=None):
super().__init__(q_block_size, k_block_size, sparse_setting, attnmap_frame_num)
sparsity_ratio = self.sparse_setting.get("sla_sparsity_ratio", 0.8)
self.topk_ratio = 1 - sparsity_ratio
def __call__(self, q, k):
# (L, H, D) -> (B, H, L, D)
q = q.unsqueeze(0).transpose(1, 2).contiguous()
k = k.unsqueeze(0).transpose(1, 2).contiguous()
sparse_map, lut, topk = get_block_map(q, k, topk_ratio=self.topk_ratio, BLKQ=self.q_block_size, BLKK=self.k_block_size)
# return: [B, H, Q_block_num, K_block_num]
return sparse_map
@SPARSE_MASK_GENERATOR_REGISTER("nbhd_mask_generator")
class NbhdMaskGenerator(GeneralMaskGenerator):
seqlen = None
mask = None
def __init__(self, q_block_size=128, k_block_size=128, sparse_setting={}, attnmap_frame_num=None):
super().__init__(q_block_size, k_block_size, sparse_setting, attnmap_frame_num)
self.coefficient = self.sparse_setting.get("nbhd_coefficient", [1.0, 0.5, 0.056])
self.min_width = self.sparse_setting.get("nbhd_min_width", 1.0)
self.block_size = self.q_block_size
def __call__(self, q, k):
seqlen, head_num, head_dim = q.shape
if seqlen == NbhdMaskGenerator.seqlen:
return NbhdMaskGenerator.mask
block_num = (seqlen + self.block_size - 1) // self.block_size
block_num_per_frame = seqlen / self.attnmap_frame_num / self.block_size
mask = generate_nbhd_mask(block_num_per_frame, block_num, self.attnmap_frame_num, coefficient=self.coefficient, min_width=self.min_width, device=q.device)
mask = mask[None, None, :, :].repeat(1, head_num, 1, 1)
# return: [B, H, Q_block_num, K_block_num]
NbhdMaskGenerator.seqlen = seqlen
NbhdMaskGenerator.mask = mask
return mask
@SPARSE_MASK_GENERATOR_REGISTER("svg_mask_generator")
class SvgMaskGenerator(GeneralMaskGenerator):
seqlen = None
attention_masks = None
mask = None
def __init__(self, q_block_size=128, k_block_size=128, sparse_setting={}, attnmap_frame_num=None):
super().__init__(q_block_size, k_block_size, sparse_setting, attnmap_frame_num)
self.sample_mse_max_row = self.sparse_setting.get("svg_sample_mse_max_row", 10000)
self.num_sampled_rows = self.sparse_setting.get("svg_num_sampled_rows", 64)
self.context_length = self.sparse_setting.get("svg_context_length", 0)
self.sparsity = self.sparse_setting.get("svg_sparsity", 0.75)
self.block_size = self.k_block_size
self.best_model_idx = None
self.head_num = None
self.head_dim = None
def prepare_mask(self, q):
seqlen, head_num, head_dim = q.shape
if seqlen == SvgMaskGenerator.seqlen:
return
logger.info(f"SvgMaskGenerator: Preparing mask for seqlen={seqlen}, head_num={head_num}, head_dim={head_dim}")
frame_size = seqlen // self.attnmap_frame_num
SvgMaskGenerator.attention_masks = [get_attention_mask(mask_name, self.sample_mse_max_row, self.context_length, self.attnmap_frame_num, frame_size) for mask_name in ["spatial", "temporal"]]
block_num = (seqlen + self.block_size - 1) // self.block_size
block_num_per_frame = block_num // self.attnmap_frame_num
mask = diagonal_band_mask_from_sparsity(block_num, block_num_per_frame, self.sparsity, device=q.device)
SvgMaskGenerator.mask = mask[None, None, :, :].repeat(1, head_num, 1, 1)
SvgMaskGenerator.seqlen = seqlen
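# sample_mse estimates, for each (cfg, head), how closely the attention output under each
# candidate mask ("spatial" vs "temporal") matches full attention on a random subset of rows;
# reorg then keeps the mask with the smallest MSE per head.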
def sample_mse(self, query, key, value):
cfg, num_heads, seq_len, dim = query.size()
num_sampled_rows = min(self.num_sampled_rows, seq_len)
sampled_rows = torch.randint(low=0, high=self.sample_mse_max_row, size=(num_sampled_rows,))
sampled_q = query[:, :, sampled_rows, :]
sampled_qk_scores = torch.matmul(sampled_q, key.transpose(-2, -1)) / (dim**0.5)
sampled_attn_weights = F.softmax(sampled_qk_scores, dim=-1)
sampled_golden_hidden_states = torch.matmul(sampled_attn_weights, value) # (1, seq_len, dim)
sampled_mses = torch.zeros(len(self.attention_masks), cfg, num_heads, device=query.device, dtype=query.dtype)
# Only two candidate masks are evaluated: "spatial" (tri-diagonal) and "temporal" (striped)
for mask_idx, attn_mask in enumerate(self.attention_masks):
sampled_attention_mask = attn_mask[sampled_rows, :]
sampled_attention_scores = sampled_qk_scores.masked_fill(sampled_attention_mask == 0, float("-inf"))
sampled_attn_weights = F.softmax(sampled_attention_scores, dim=-1)
sampled_hidden_states = torch.matmul(sampled_attn_weights, value)
mse = torch.mean((sampled_hidden_states - sampled_golden_hidden_states) ** 2, dim=(2, 3))
sampled_mses[mask_idx] = mse
return sampled_mses
def reorg(self, q, k, v):
seqlen, head_num, head_dim = q.shape
q = q.unsqueeze(0).transpose(1, 2)
k = k.unsqueeze(0).transpose(1, 2)
v = v.unsqueeze(0).transpose(1, 2)
sampled_mses = self.sample_mse(q, k, v)
self.best_mask_idx = torch.argmin(sampled_mses, dim=0)
self.head_num = head_num
self.head_dim = head_dim
q_out, k_out, v_out = torch.zeros_like(q), torch.zeros_like(k), torch.zeros_like(v)
wan_sparse_head_placement(q, k, v, q_out, k_out, v_out, self.best_mask_idx, self.context_length, self.attnmap_frame_num, seqlen // self.attnmap_frame_num)
q_out = q_out.transpose(1, 2).squeeze(0)
k_out = k_out.transpose(1, 2).squeeze(0)
v_out = v_out.transpose(1, 2).squeeze(0)
return q_out, k_out, v_out
def restore(self, out):
# out: (L, H*D)
out = out.reshape(-1, self.head_num, self.head_dim)
seqlen = out.shape[0]
# (L, H, D) -> (B, H, L, D)
out = out.unsqueeze(0).transpose(1, 2)
restore_out = torch.zeros_like(out)
wan_hidden_states_placement(out, restore_out, self.best_mask_idx, self.context_length, self.attnmap_frame_num, seqlen // self.attnmap_frame_num)
restore_out = restore_out.transpose(1, 2).reshape(seqlen, -1)
return restore_out
def __call__(self, q, k):
self.prepare_mask(q)
# return: [B, H, Q_block_num, K_block_num]
return self.mask
import torch
from lightx2v.utils.registry_factory import SPARSE_OPERATOR_REGISTER
from lightx2v_platform.base.global_var import AI_DEVICE
from .kernels.sla_kernel import _attention
try:
from magi_attention.functional import flex_flash_attn_func as magi_ffa_func
except ImportError:
magi_ffa_func = None
try:
from flex_block_attn import flex_block_attn_func
except ImportError:
flex_block_attn_func = None
try:
import flashinfer
except ImportError:
flashinfer = None
@SPARSE_OPERATOR_REGISTER("sla_triton_operator")
class SlaTritonOperator:
def __init__(self):
self.q_block_size = 128
self.k_block_size = 128
def __call__(
self,
q,
k,
v,
mask,
cu_seqlens_q=None,
cu_seqlens_kv=None,
max_seqlen_q=None,
max_seqlen_kv=None,
**kwargs,
):
# (L, H, D) -> (B, H, L, D)
q = q.unsqueeze(0).transpose(1, 2).contiguous()
k = k.unsqueeze(0).transpose(1, 2).contiguous()
v = v.unsqueeze(0).transpose(1, 2).contiguous()
# (B, H, Q_block_num, K_block_num)
mask = mask.int()
topk = int(mask.sum(dim=-1).max().item())
lut = torch.topk(mask, topk, dim=-1, sorted=False).indices
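# topk is the maximum number of active K blocks in any mask row; lut lists, per query block,
# the column indices of topk selected K blocks. Rows with fewer active blocks than topk also
# pull in indices of unselected blocks, so those rows attend to a few extra blocks.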
out = _attention.apply(q, k, v, mask, lut, topk, self.q_block_size, self.k_block_size)
out = out.transpose(1, 2).reshape(max_seqlen_q, -1)
return out
@SPARSE_OPERATOR_REGISTER("magi_operator")
class MagiOperator:
def __init__(self):
self.q_block_size = 128
self.k_block_size = 128
def generate_qk_ranges(self, mask, q_block_size, k_block_size, seqlen):
# mask: [H, Q_block_num, K_block_num]
h_indices, i_indices, j_indices = torch.nonzero(mask, as_tuple=True)
base_offset = h_indices * seqlen
q_start = base_offset + i_indices * q_block_size
q_end = base_offset + torch.clamp((i_indices + 1) * q_block_size, max=seqlen)
k_start = base_offset + j_indices * k_block_size
k_end = base_offset + torch.clamp((j_indices + 1) * k_block_size, max=seqlen)
q_ranges = torch.stack([q_start, q_end], dim=1).to(dtype=torch.int32)
k_ranges = torch.stack([k_start, k_end], dim=1).to(dtype=torch.int32)
return q_ranges, k_ranges
def __call__(
self,
q,
k,
v,
mask,
cu_seqlens_q=None,
cu_seqlens_kv=None,
max_seqlen_q=None,
max_seqlen_kv=None,
**kwargs,
):
seqlen, head_num, head_dim = q.shape
# (B, H, Q_block_num, K_block_num) -> (H, Q_block_num, K_block_num)
mask = mask.squeeze(0)
q_ranges, k_ranges = self.generate_qk_ranges(mask, self.q_block_size, self.k_block_size, seqlen)
attn_type_map = torch.zeros(len(q_ranges), dtype=torch.int32, device="cpu").to(q.device, non_blocking=True)
q = q.permute(1, 0, 2).reshape(head_num * seqlen, 1, head_dim)
k = k.permute(1, 0, 2).reshape(head_num * seqlen, 1, head_dim)
v = v.permute(1, 0, 2).reshape(head_num * seqlen, 1, head_dim)
out = magi_ffa_func(
q,
k,
v,
q_ranges=q_ranges,
k_ranges=k_ranges,
attn_type_map=attn_type_map,
auto_range_merge=True,
)[0]
out = out.reshape(head_num, seqlen, head_dim).permute(1, 0, 2)
return out.reshape(out.shape[0], -1)
@SPARSE_OPERATOR_REGISTER("flex_block_operator")
class FlexBlockOperator:
def __init__(self):
self.q_block_size = 128
self.k_block_size = 128
def __call__(
self,
q,
k,
v,
mask,
cu_seqlens_q=None,
cu_seqlens_kv=None,
max_seqlen_q=None,
max_seqlen_kv=None,
**kwargs,
):
q = q.unsqueeze(0).transpose(1, 2)
k = k.unsqueeze(0).transpose(1, 2)
v = v.unsqueeze(0).transpose(1, 2)
pad_len = (self.q_block_size - q.shape[2] % self.q_block_size) % self.q_block_size
if pad_len > 0:
q = torch.nn.functional.pad(q, (0, 0, 0, pad_len))
k = torch.nn.functional.pad(k, (0, 0, 0, pad_len))
v = torch.nn.functional.pad(v, (0, 0, 0, pad_len))
# (B, H, Q_block_num, K_block_num)
mask = mask.bool()
out = flex_block_attn_func(q, k, v, self.q_block_size, self.k_block_size, mask)
if pad_len > 0:
out = out[:, :, :-pad_len, :]
out = out.transpose(1, 2)
return out.reshape(max_seqlen_q, -1)
@SPARSE_OPERATOR_REGISTER("flashinfer_operator")
class FlashinferOperator:
sparse_wrapper = None
mask = None
def __init__(self):
self.q_block_size = 128
self.k_block_size = 128
if FlashinferOperator.sparse_wrapper is None:
float_workspace_buffer = torch.empty(1024 * 1024 * 1024, dtype=torch.uint8, device=AI_DEVICE)
FlashinferOperator.sparse_wrapper = flashinfer.sparse.VariableBlockSparseAttentionWrapper(float_workspace_buffer, backend="fa2")
def __call__(
self,
q,
k,
v,
mask,
cu_seqlens_q=None,
cu_seqlens_kv=None,
max_seqlen_q=None,
max_seqlen_kv=None,
**kwargs,
):
seqlen, head_num, head_dim = q.shape
# (B, H, Q_block_num, K_block_num) -> (H, Q_block_num, K_block_num)
mask = mask.squeeze(0)
if FlashinferOperator.mask is None or not torch.equal(mask, FlashinferOperator.mask):
_, q_block_num, k_block_num = mask.shape
block_row_sz = torch.ones(q_block_num, dtype=torch.int32, device=q.device) * self.q_block_size
block_row_sz[-1] = seqlen - self.q_block_size * (q_block_num - 1)
block_row_sz = block_row_sz.unsqueeze(0).repeat(head_num, 1)
block_col_sz = torch.ones(k_block_num, dtype=torch.int32, device=k.device) * self.k_block_size
block_col_sz[-1] = seqlen - self.k_block_size * (k_block_num - 1)
block_col_sz = block_col_sz.unsqueeze(0).repeat(head_num, 1)
FlashinferOperator.sparse_wrapper.plan(
block_mask_map=mask,
block_row_sz=block_row_sz,
block_col_sz=block_col_sz,
num_qo_heads=head_num,
num_kv_heads=head_num,
head_dim=head_dim,
q_data_type=q.dtype,
)
FlashinferOperator.mask = mask
q = q.transpose(0, 1)
k = k.transpose(0, 1)
v = v.transpose(0, 1)
out = FlashinferOperator.sparse_wrapper.run(q, k, v)
out = out.transpose(0, 1)
return out.reshape(max_seqlen_q, -1)
import os
import torch
try:
import spas_sage_attn
except ImportError:
spas_sage_attn = None
from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER
from .template import AttnWeightTemplate
@ATTN_WEIGHT_REGISTER("spas_sage_attn")
class SageAttnWeight(AttnWeightTemplate):
def __init__(self):
self.config = {}
@classmethod
def apply(self, q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, tensor_layout="HND", **kwargs):
q = q.unsqueeze(0)
k = k.unsqueeze(0)
v = v.unsqueeze(0)
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
attn_out = spas_sage_attn.core.spas_sage2_attn_meansim_cuda(q, k, v, tensor_layout)
_, H, N, D = attn_out.shape
attn_out = attn_out.permute(2, 1, 3, 0).contiguous().view(N, H * D)
return attn_out
if __name__ == "__main__":
import matplotlib.pyplot as plt
# 1. Build inputs
q = torch.randn(32760, 12, 128, dtype=torch.bfloat16).cuda()
k = torch.randn(32760, 12, 128, dtype=torch.bfloat16).cuda()
v = torch.randn(32760, 12, 128, dtype=torch.bfloat16).cuda()
# 2. Reference attention computed directly with PyTorch (heads moved to the batch dim;
#    only the first head is kept to bound the 32760 x 32760 score matrix)
q_ = q.float().permute(1, 0, 2)[:1]  # [1, 32760, 128]
k_ = k.float().permute(1, 0, 2)[:1]
v_ = v.float().permute(1, 0, 2)[:1]
attn_weights = torch.matmul(q_, k_.transpose(-2, -1)) / (128**0.5)
attn_weights = torch.softmax(attn_weights, dim=-1)
output_pt = torch.matmul(attn_weights, v_)  # [1, 32760, 128]
# 3. Attention via spas_sage2_attn_meansim_cuda
q = q.unsqueeze(0) # shape: (1, 32760, 12, 128)
k = k.unsqueeze(0)
v = v.unsqueeze(0)
q = q.transpose(1, 2) # shape: (1, 12, 32760, 128)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
output_cuda = spas_sage_attn.core.spas_sage2_attn_meansim_cuda(q, k, v, tensor_layout="HND")
output_cuda = output_cuda.float()
# 4. Take the top-left crop of the first head's output
output_pt_crop = output_pt[0, :3000, :3000].cpu().detach().numpy()
output_cuda_crop = output_cuda[0, 0, :3000, :3000].cpu().detach().numpy()
# 5. Save the figures
save_dir = os.path.expanduser("~/Log/10-22/")
os.makedirs(save_dir, exist_ok=True)
plt.imshow(output_pt_crop, aspect="auto")
plt.title("PyTorch Attention (left-top 3000x3000)")
plt.savefig(os.path.join(save_dir, "attn.png"))
plt.close()
plt.imshow(output_cuda_crop, aspect="auto")
plt.title("spas_sage2_attn_meansim_cuda (left-top 3000x3000)")
plt.savefig(os.path.join(save_dir, "spas_attn.png"))
plt.close()
from typing import Optional
# Please reinstall flashinfer by referring to https://github.com/svg-project/Sparse-VideoGen
try:
import flashinfer
except ImportError:
flashinfer = None
import torch
import triton
import triton.language as tl
from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER
from .svg2_attn_utils import (
batch_kmeans_Euclid,
identify_dynamic_map,
)
from .template import AttnWeightTemplate
@triton.jit
def _permute_kernel(
X_ptr,
IDX_ptr,
Y_ptr,
S: tl.constexpr,
D: tl.constexpr,
BLOCK_S: tl.constexpr,
):
"""Each program permutes BLOCK_S tokens *all* hidden features (D). No inner python loop."""
pid_bh = tl.program_id(0)
tile_s = tl.program_id(1)
# Offsets along sequence
s_offsets = tile_s * BLOCK_S + tl.arange(0, BLOCK_S)
token_mask = s_offsets < S
# Gather source indices for these tokens
idx_ptrs = IDX_ptr + pid_bh * S + s_offsets
src_row_idx = tl.load(idx_ptrs, mask=token_mask, other=0).to(tl.int32)
# Broadcast to create 2-D pointer matrix (BLOCK_S, D)
d_offsets = tl.arange(0, D)
src_ptrs = X_ptr + (pid_bh * S + src_row_idx[:, None]) * D + d_offsets[None, :]
dst_ptrs = Y_ptr + (pid_bh * S + s_offsets[:, None]) * D + d_offsets[None, :]
full_mask = token_mask[:, None]
values = tl.load(src_ptrs, mask=full_mask, other=0.0)
tl.store(dst_ptrs, values, mask=full_mask)
def permute_tensor_by_labels_triton(
tensor: torch.Tensor,
labels: Optional[torch.Tensor],
dim: int,
*,
sorted_indices: Optional[torch.Tensor] = None,
):
"""
Permute `tensor` along `dim` according to ascending order of `labels`.
This is a Triton-accelerated replacement for the original implementation.
It currently supports 4-D tensors of shape [B, H, S, D] and `dim == 2`.
    These conditions are enforced with assertions; for other layouts or CPU
    tensors, use the reference PyTorch implementation `permute_tensor_by_labels`.
"""
# Assertions – we only support the optimized CUDA path.
assert dim == 2, "permute_tensor_by_labels currently only supports dim==2 (sequence dimension)"
assert tensor.dim() == 4, "Expected tensor shape [B,H,S,D]"
assert tensor.is_cuda, "permute_tensor_by_labels requires CUDA tensors"
B, H, S, D = tensor.shape
BH = B * H
# Determine sorted indices
if sorted_indices is not None:
sorted_indices = sorted_indices.to(torch.int32).contiguous()
else:
assert labels is not None, "Either `labels` or `sorted_indices` must be provided."
labels = labels.to(tensor.device)
sorted_indices = torch.argsort(labels, dim=-1).to(torch.int32).contiguous()
# Flatten tensor and allocate output
inp_flat = tensor.reshape(BH, S, D).contiguous()
out_flat = torch.empty_like(inp_flat)
# Triton kernel tile size
BLOCK_S = 64 # number of tokens per program, tunable
n_s_tiles = triton.cdiv(S, BLOCK_S)
grid = (BH, n_s_tiles)
_permute_kernel[grid](inp_flat, sorted_indices, out_flat, S, D, BLOCK_S, num_warps=4)
permuted_tensor = out_flat.reshape(B, H, S, D)
return permuted_tensor, sorted_indices
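# Hedged usage sketch (the helper name and shapes below are illustrative, not
# part of the module API): permuting a random [B, H, S, D] tensor by random
# cluster labels should match a plain torch.gather along the sequence dim.
def _permute_tensor_by_labels_triton_example():
    B, H, S, D = 1, 2, 1024, 128
    x = torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16)
    labels = torch.randint(0, 16, (B, H, S), device="cuda")
    y, sorted_idx = permute_tensor_by_labels_triton(x, labels, dim=2)
    # Reference gather along dim=2 with the returned sorted indices
    ref = torch.gather(x, 2, sorted_idx.long().unsqueeze(-1).expand(-1, -1, -1, D))
    assert torch.equal(y, ref)
    return y, sorted_idx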
@triton.jit
def _inverse_permute_kernel(
X_ptr,
IDX_ptr,
Y_ptr,
S: tl.constexpr,
D: tl.constexpr,
BLOCK_S: tl.constexpr,
):
"""Inverse permutation: scatter BLOCK_S tokens back in one shot."""
pid_bh = tl.program_id(0)
tile_s = tl.program_id(1)
s_offsets = tile_s * BLOCK_S + tl.arange(0, BLOCK_S)
token_mask = s_offsets < S
idx_ptrs = IDX_ptr + pid_bh * S + s_offsets
src_pos_idx = s_offsets.to(tl.int32)
dst_pos_idx = tl.load(idx_ptrs, mask=token_mask, other=0).to(tl.int32)
d_offsets = tl.arange(0, D)
src_ptrs = X_ptr + (pid_bh * S + src_pos_idx[:, None]) * D + d_offsets[None, :]
dst_ptrs = Y_ptr + (pid_bh * S + dst_pos_idx[:, None]) * D + d_offsets[None, :]
full_mask = token_mask[:, None]
values = tl.load(src_ptrs, mask=full_mask, other=0.0)
tl.store(dst_ptrs, values, mask=full_mask)
def apply_inverse_permutation_triton(
permuted_tensor: torch.Tensor,
sorted_indices: torch.Tensor,
dim: int,
):
"""
Triton implementation of inverse permutation. Inverse the permutation applied by `permute_tensor_by_labels`.
Args:
permuted_tensor: (B, H, S, D).
sorted_indices: (B, H, S).
        dim: Dimension along which to apply inverse permutation. Typically 2, meaning the sequence length dimension.
Returns:
Tensor of shape (B, H, S, D).
"""
assert dim == 2, "apply_inverse_permutation currently only supports dim==2"
assert permuted_tensor.dim() == 4, "Expected tensor shape [B,H,S,D]"
assert permuted_tensor.is_cuda, "apply_inverse_permutation requires CUDA tensors"
B, H, S, D = permuted_tensor.shape
BH = B * H
# Ensure index dtype
sorted_indices = sorted_indices.to(torch.int32).contiguous()
# Flatten inputs
inp_flat = permuted_tensor.reshape(BH, S, D).contiguous()
out_flat = torch.empty_like(inp_flat)
BLOCK_S = 64
n_s_tiles = triton.cdiv(S, BLOCK_S)
grid = (BH, n_s_tiles)
_inverse_permute_kernel[grid](inp_flat, sorted_indices, out_flat, S, D, BLOCK_S, num_warps=4)
original_tensor = out_flat.reshape(B, H, S, D)
return original_tensor
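# Hedged round-trip sketch (illustrative helper, arbitrary shapes): applying
# the inverse permutation to a permuted tensor should reproduce the original
# ordering exactly.
def _inverse_permutation_roundtrip_example():
    B, H, S, D = 1, 2, 2048, 64
    x = torch.randn(B, H, S, D, device="cuda", dtype=torch.float16)
    labels = torch.randint(0, 32, (B, H, S), device="cuda")
    permuted, sorted_idx = permute_tensor_by_labels_triton(x, labels, dim=2)
    restored = apply_inverse_permutation_triton(permuted, sorted_idx, dim=2)
    assert torch.equal(restored, x)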
@ATTN_WEIGHT_REGISTER("svg2_attn")
class Svg2AttnWeight(AttnWeightTemplate):
centroids_init = False
num_q_centroids = 300
num_k_centroids = 1000
kmeans_iter_init = 50
top_p_kmeans = 0.9
min_kc_ratio = 0.10
kmeans_iter_step = 2
def __init__(self):
self.config = {}
def apply(
self,
q,
k,
v,
cu_seqlens_q=None,
cu_seqlens_kv=None,
max_seqlen_q=None,
max_seqlen_kv=None,
**kwargs,
):
q = q.unsqueeze(0).transpose(1, 2)
k = k.unsqueeze(0).transpose(1, 2)
v = v.unsqueeze(0).transpose(1, 2)
bs, num_heads, seq_len, dim = q.size()
q_perm, k_perm, v_perm, dyn_map, qc_sz_s, kc_sz_s, q_sorted_indices = self.semantic_aware_permutation(q, k, v)
output_permuted = self.dynamic_block_sparse_fwd_flashinfer(q_perm, k_perm, v_perm, dyn_map, qc_sz_s, kc_sz_s, is_cpu=False)
attn_output = apply_inverse_permutation_triton(output_permuted, q_sorted_indices, dim=2)
return attn_output.reshape(bs, num_heads, seq_len, dim).transpose(1, 2).reshape(bs * seq_len, -1)
def dynamic_block_sparse_fwd_flashinfer(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
block_mask_map: torch.Tensor,
block_row_sz: torch.Tensor,
block_col_sz: torch.Tensor,
is_cpu: bool = True,
):
"""
Launcher for the Flashinfer dynamic block sparse attention kernel.
Args:
q (torch.Tensor): Query tensor, shape [B, H, S, D].
k (torch.Tensor): Key tensor, shape [B, H, S, D].
v (torch.Tensor): Value tensor, shape [B, H, S, D].
            block_mask_map (torch.Tensor): Boolean mask, shape [B, H, qc_num, kc_num]. Must currently be on CPU.
            block_row_sz (torch.Tensor): Query block sizes, shape [B, H, qc_num]. Must currently be on CPU.
            block_col_sz (torch.Tensor): Key block sizes, shape [B, H, kc_num]. Must currently be on CPU.
            is_cpu (bool): Whether the block metadata lives on CPU. Flashinfer plans on CPU by default; we switch to GPU for faster planning. Default is True.
"""
# Input shape check
B, H, S, D = q.shape
qc_num = block_row_sz.shape[-1]
kc_num = block_col_sz.shape[-1]
assert block_mask_map.shape == (B, H, qc_num, kc_num)
        if is_cpu:
            assert all(t.device == torch.device("cpu") for t in [block_mask_map, block_row_sz, block_col_sz])
        # Check that the block sizes sum to the same total sequence length for every head
assert torch.all(block_col_sz.sum(dim=2) == block_col_sz.sum(dim=2)[0, 0])
assert torch.all(block_row_sz.sum(dim=2) == block_row_sz.sum(dim=2)[0, 0])
# Prepare flashinfer wrapper
float_workspace_buffer = torch.empty(128 * 1024 * 1024, device=q.device)
vector_sparse_indices_buffer = torch.empty(1024 * 1024 * 1024, device=q.device)
wrapper = flashinfer.sparse.VariableBlockSparseAttentionWrapper(float_workspace_buffer, backend="auto")
wrapper.reset_workspace_buffer(
float_workspace_buffer=wrapper._float_workspace_buffer,
int_workspace_buffer=wrapper._int_workspace_buffer,
vector_sparse_indices_buffer=vector_sparse_indices_buffer, # Only reset this buffer size
vector_sparse_indptr_buffer=wrapper._vector_sparse_indptr_buffer,
)
block_mask_map = block_mask_map.reshape(B * H, qc_num, kc_num)
block_row_sz = block_row_sz.reshape(B * H, qc_num)
block_col_sz = block_col_sz.reshape(B * H, kc_num)
wrapper.plan(
block_mask_map=block_mask_map,
block_row_sz=block_row_sz,
block_col_sz=block_col_sz,
num_qo_heads=B * H,
num_kv_heads=B * H,
head_dim=D,
q_data_type=q.dtype,
kv_data_type=k.dtype,
)
# print_memory_usage("After plan")
q = q.reshape(B * H, S, D)
k = k.reshape(B * H, S, D)
v = v.reshape(B * H, S, D)
o = wrapper.run(q, k, v) # [num_qo_heads, qo_len, head_dim]
o = o.reshape(B, H, S, D)
return o
def semantic_aware_permutation(self, query, key, value):
cfg, num_heads, seq_len, dim = query.size()
# 1. Kmeans clustering
qlabels, qcentroids, qcluster_sizes, qiter, klabels, kcentroids, kcluster_sizes, kiter = self.kmeans_clustering(query, key)
# 2. Identify dynamic map
q_cluster_sizes = qcluster_sizes.view(cfg, num_heads, self.num_q_centroids)
k_cluster_sizes = kcluster_sizes.view(cfg, num_heads, self.num_k_centroids)
dynamic_map = identify_dynamic_map(
qcentroids.view(cfg, num_heads, self.num_q_centroids, dim),
kcentroids.view(cfg, num_heads, self.num_k_centroids, dim),
q_cluster_sizes,
k_cluster_sizes,
self.top_p_kmeans,
self.min_kc_ratio,
)
# 3. Permute the query, key, value
q_permuted, q_sorted_indices = permute_tensor_by_labels_triton(query, qlabels, dim=2)
k_permuted, k_sorted_indices = permute_tensor_by_labels_triton(key, klabels, dim=2)
v_permuted, v_sorted_indices = permute_tensor_by_labels_triton(value, klabels, dim=2, sorted_indices=k_sorted_indices)
return q_permuted, k_permuted, v_permuted, dynamic_map, q_cluster_sizes, k_cluster_sizes, q_sorted_indices
def kmeans_clustering(self, query, key):
if not self.centroids_init:
qlabels, qcentroids, qcluster_sizes, qiter, klabels, kcentroids, kcluster_sizes, kiter = self.kmeans_init(query, key)
self.centroids_init = True
else:
qlabels, qcentroids, qcluster_sizes, qiter, klabels, kcentroids, kcluster_sizes, kiter = self.kmeans_step(query, key)
return qlabels, qcentroids, qcluster_sizes, qiter, klabels, kcentroids, kcluster_sizes, kiter
def kmeans_init(self, query, key):
cfg, num_heads, seq_len, dim = query.size()
qlabels, qcentroids, qcluster_sizes, qiter = batch_kmeans_Euclid(query.view(cfg * num_heads, seq_len, dim), n_clusters=self.num_q_centroids, max_iters=self.kmeans_iter_init)
klabels, kcentroids, kcluster_sizes, kiter = batch_kmeans_Euclid(key.view(cfg * num_heads, seq_len, dim), n_clusters=self.num_k_centroids, max_iters=self.kmeans_iter_init)
self.q_centroids = qcentroids
self.k_centroids = kcentroids
return qlabels, qcentroids, qcluster_sizes, qiter, klabels, kcentroids, kcluster_sizes, kiter
def kmeans_step(self, query, key):
cfg, num_heads, seq_len, dim = query.size()
qlabels, qcentroids, qcluster_sizes, qiter = batch_kmeans_Euclid(
query.view(cfg * num_heads, seq_len, dim),
n_clusters=self.num_q_centroids,
max_iters=self.kmeans_iter_step,
init_centroids=self.q_centroids,
)
klabels, kcentroids, kcluster_sizes, kiter = batch_kmeans_Euclid(
key.view(cfg * num_heads, seq_len, dim),
n_clusters=self.num_k_centroids,
max_iters=self.kmeans_iter_step,
init_centroids=self.k_centroids,
)
self.q_centroids = qcentroids
self.k_centroids = kcentroids
return qlabels, qcentroids, qcluster_sizes, qiter, klabels, kcentroids, kcluster_sizes, kiter
if __name__ == "__main__":
q, k, v = torch.randn(32130, 40, 128, dtype=torch.bfloat16).cuda(), torch.randn(32130, 40, 128, dtype=torch.bfloat16).cuda(), torch.randn(32130, 40, 128, dtype=torch.bfloat16).cuda()
svg2_attn = Svg2AttnWeight()
print("Svg2AttnWeight initialized.")
out = svg2_attn.apply(q, k, v)
print(f"out: {out.shape}, {out.dtype}, {out.device}")
import torch
import torch.nn.functional as F
import triton
import triton.language as tl
try:
from cuvs.cluster.kmeans import KMeansParams, fit
except ImportError:
KMeansParams = None
fit = None
# --- New functions ---
def density_calculation(dynamic_map, q_cluster_sizes, k_cluster_sizes):
"""
Calculate the density of the dynamic map. Currently only batch size = 1 and head size = 1 are supported.
Input:
dynamic_map: [cfg, num_heads, qc_num, kc_num]
q_cluster_sizes: [cfg, num_heads, qc_num]
k_cluster_sizes: [cfg, num_heads, kc_num]
"""
cfg, num_heads, qc_num, kc_num = dynamic_map.shape
# Calculate the block size of each block
clustered_block_size = q_cluster_sizes[:, :, :, None] * k_cluster_sizes[:, :, None, :]
masked_block_size = clustered_block_size * dynamic_map
# Calculate the density of each block
density = torch.sum(masked_block_size, dim=(2, 3)) / torch.sum(clustered_block_size, dim=(2, 3))
return density
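# Worked example (illustrative values, helper not part of the module API):
# with query clusters of sizes (2, 2), key clusters of sizes (1, 3) and only
# block (0, 0) kept, the density is (2 * 1) / (4 * 4) = 0.125.
def _density_calculation_example():
    dynamic_map = torch.zeros(1, 1, 2, 2, dtype=torch.bool)
    dynamic_map[0, 0, 0, 0] = True
    q_sizes = torch.tensor([[[2.0, 2.0]]])
    k_sizes = torch.tensor([[[1.0, 3.0]]])
    density = density_calculation(dynamic_map, q_sizes, k_sizes)
    assert torch.allclose(density, torch.tensor([[0.125]]))
    return density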
# --- Functions from analyze/kmeans_rapidai.py ---
def pairwise_distance(x, y):
"""
Computes pairwise squared Euclidean distance between two sets of points.
"""
x_norm = (x**2).sum(1).view(-1, 1)
y_norm = (y**2).sum(1).view(1, -1)
dist = torch.clamp(x_norm + y_norm - 2.0 * torch.mm(x, torch.transpose(y, 0, 1)), min=0.0)
return dist
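# Sanity sketch (illustrative helper): pairwise_distance returns *squared*
# Euclidean distances, so it should agree with torch.cdist(...) ** 2.
def _pairwise_distance_example():
    x = torch.randn(8, 16)
    y = torch.randn(4, 16)
    d_sq = pairwise_distance(x, y)
    ref = torch.cdist(x, y) ** 2
    assert torch.allclose(d_sq, ref, atol=1e-4)
    return d_sq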
def kmeans_predict(centroids, input_tensor): # Removed unused params argument
"""
Predict the labels for the input tensor using the centroids.
"""
input_tensor = input_tensor.to(torch.float32)
dist = pairwise_distance(input_tensor, centroids)
labels = torch.argmin(dist, dim=1)
return labels
def kmeans_rapidai(tensor, k, max_iter=5, tol=1e-4, init_method="Array", centroids_init=None): # Renamed centroids to centroids_init
"""
Performs K-means clustering using cuVS.
"""
assert tensor.dtype == torch.float32, "Tensor must be float32 for cuVS KMeans"
assert tensor.ndim == 2, f"Tensor must be 2D, but got {tensor.shape}"
# assert init_method == "Array", "init_method must be 'Array' for now"
L, D = tensor.shape
# cuVS KMeans in RAPIDS >=23.10 uses 'centroids_init' for initial centroids
current_centroids = centroids_init
if current_centroids is None:
# Default init: cuVS handles KMeansPlusPlus if centroids_init is None and init_method is KMeansPlusPlus
# If you need to pass an empty tensor for cuVS to initialize:
current_centroids = torch.empty(k, D, device=tensor.device, dtype=torch.float32) # Or pass None
else:
assert current_centroids.dtype == torch.float32, "Initial centroids must be float32"
assert current_centroids.shape == (
k,
D,
), f"Initial centroids shape mismatch, got {current_centroids.shape}, expected ({k}, {D})"
# cuVS uses 'init_method="Array"' when 'centroids_init' is provided.
    params = KMeansParams(n_clusters=k, max_iter=max_iter, tol=tol, init_method=init_method)
# Call fit with centroids_init (can be None)
    new_centroids, inertia, n_iter_ = fit(params, tensor, current_centroids)
labels = kmeans_predict(new_centroids, tensor)
return labels, new_centroids, n_iter_
@triton.jit
def _centroid_update_kernel(
x_ptr, # *f16 [B, N, D]
cluster_ptr, # *i32 [B, N]
sum_ptr, # *f32 [B, K, D]
count_ptr, # *i32 [B, K]
B: tl.constexpr,
N: tl.constexpr,
D: tl.constexpr,
K: tl.constexpr,
BLOCK_D: tl.constexpr, # number of dims processed per program
):
"""Each program processes 1 point (token) across BLOCK_D dimensions with atomics."""
pid = tl.program_id(axis=0)
token_idx = pid # range: [0, B * N)
# Derive (b, n) indices
b = token_idx // N
n = token_idx % N
# Pointers to the token features and its cluster id
x_offset = (b * N + n) * D
x_ptr = x_ptr + x_offset
cluster_idx = tl.load(cluster_ptr + b * N + n) # int32
# Guard for invalid cluster ids (should not happen)
cluster_idx = tl.where(cluster_idx < K, cluster_idx, 0)
# Base pointer for this centroid in the output sum tensor
centroid_base = (b * K + cluster_idx) * D
# Process feature vector in chunks of BLOCK_D
offs = tl.arange(0, BLOCK_D)
for d_start in range(0, D, BLOCK_D):
mask = offs + d_start < D
feats = tl.load(x_ptr + d_start + offs, mask=mask, other=0.0)
feats = feats.to(tl.float32)
dest_ptr = sum_ptr + centroid_base + d_start + offs
tl.atomic_add(dest_ptr, feats, mask=mask)
# Update counts (only once per point)
tl.atomic_add(count_ptr + b * K + cluster_idx, 1)
def triton_centroid_update_cosine(x_norm: torch.Tensor, cluster_ids: torch.Tensor, old_centroids: torch.Tensor):
"""Compute centroids using custom Triton kernel.
Args:
x_norm (Tensor): (B, N, D) normalized input vectors (float16/float32)
cluster_ids (LongTensor): (B, N) cluster assignment per point
old_centroids (Tensor): (B, K, D) previous centroids (same dtype as x_norm)
Returns:
Tensor: (B, K, D) updated and L2-normalized centroids (dtype == x_norm.dtype)
"""
assert x_norm.is_cuda and cluster_ids.is_cuda, "Input tensors must be on CUDA device"
B, N, D = x_norm.shape
K = old_centroids.shape[1]
assert cluster_ids.shape == (B, N)
# Allocate accumulation buffers
centroid_sums = torch.zeros((B, K, D), device=x_norm.device, dtype=torch.float32)
centroid_counts = torch.zeros((B, K), device=x_norm.device, dtype=torch.int32)
# Launch Triton kernel – one program per token
total_tokens = B * N
BLOCK_D = 128 # tuneable
grid = (total_tokens,)
_centroid_update_kernel[grid](
x_norm,
cluster_ids.to(torch.int32),
centroid_sums,
centroid_counts,
B,
N,
D,
K,
BLOCK_D=BLOCK_D,
)
# Compute means; keep old centroid if empty cluster
counts_f = centroid_counts.to(torch.float32).unsqueeze(-1).clamp(min=1.0)
centroids = centroid_sums / counts_f
# For clusters with zero count, revert to old centroids
zero_mask = (centroid_counts == 0).unsqueeze(-1)
centroids = torch.where(zero_mask, old_centroids.to(torch.float32), centroids)
centroids = centroids.to(x_norm.dtype)
centroids = F.normalize(centroids, p=2, dim=-1)
return centroids
def torch_loop_centroid_update_cosine(x_norm: torch.Tensor, cluster_ids: torch.Tensor, old_centroids: torch.Tensor):
"""Reference Python implementation (double for-loop)"""
B, N, D = x_norm.shape
K = old_centroids.shape[1]
new_centroids = torch.zeros_like(old_centroids)
for b in range(B):
for k in range(K):
mask = cluster_ids[b] == k
if mask.any():
new_centroids[b, k] = F.normalize(x_norm[b][mask].mean(dim=0, dtype=x_norm.dtype), p=2, dim=0)
else:
new_centroids[b, k] = old_centroids[b, k]
return new_centroids
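# Consistency sketch (illustrative shapes, CUDA assumed): the Triton cosine
# centroid update should agree with the double-loop reference above up to
# floating-point accumulation order.
def _centroid_update_cosine_consistency_example():
    B, N, D, K = 2, 512, 64, 8
    x = F.normalize(torch.randn(B, N, D, device="cuda", dtype=torch.float32), dim=-1)
    ids = torch.randint(0, K, (B, N), device="cuda")
    old = F.normalize(torch.randn(B, K, D, device="cuda", dtype=torch.float32), dim=-1)
    fast = triton_centroid_update_cosine(x, ids, old)
    ref = torch_loop_centroid_update_cosine(x, ids, old)
    assert torch.allclose(fast, ref, atol=1e-3)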
def triton_centroid_update_euclid(x: torch.Tensor, cluster_ids: torch.Tensor, old_centroids: torch.Tensor):
"""Compute centroids for Euclidean KMeans using Triton.
Args:
x (Tensor): (B, N, D) input vectors (float16/float32)
cluster_ids (LongTensor): (B, N) cluster assignment per point
old_centroids (Tensor): (B, K, D) previous centroids (same dtype as x)
Returns:
Tensor: (B, K, D) updated centroids (dtype == x.dtype)
"""
assert x.is_cuda and cluster_ids.is_cuda, "Input tensors must be on CUDA device"
B, N, D = x.shape
K = old_centroids.shape[1]
assert cluster_ids.shape == (B, N)
# Allocate accumulation buffers
centroid_sums = torch.zeros((B, K, D), device=x.device, dtype=torch.float32)
centroid_counts = torch.zeros((B, K), device=x.device, dtype=torch.int32)
total_tokens = B * N
BLOCK_D = 128 # tuneable
grid = (total_tokens,)
_centroid_update_kernel[grid](
x,
cluster_ids.to(torch.int32),
centroid_sums,
centroid_counts,
B,
N,
D,
K,
BLOCK_D=BLOCK_D,
)
# Compute means; keep old centroid if empty cluster
counts_f = centroid_counts.to(torch.float32).unsqueeze(-1).clamp(min=1.0)
centroids = centroid_sums / counts_f
# For clusters with zero count, revert to old centroids
zero_mask = (centroid_counts == 0).unsqueeze(-1)
centroids = torch.where(zero_mask, old_centroids.to(torch.float32), centroids)
return centroids.to(x.dtype)
# ------------------------------ NEW: chunk-wise centroid update (sorted ids) ------------------------------
@triton.jit
def _centroid_update_chunk_kernel(
x_ptr, # *f16 / *f32 [B, N, D] – ORIGINAL ORDER
sorted_idx_ptr, # *i32 [B, N] – indices after sort
sorted_cluster_ptr, # *i32 [B, N] – cluster ids in sorted order
sum_ptr, # *f32 [B, K, D]
count_ptr, # *i32 [B, K]
B: tl.constexpr,
N: tl.constexpr,
D: tl.constexpr,
K: tl.constexpr,
BLOCK_N: tl.constexpr, # how many tokens (points) each program processes
):
"""Each program processes **BLOCK_N consecutive, already-sorted tokens**.
Because the tokens are sorted by cluster id, identical ids appear in
contiguous runs. We therefore accumulate a local sum/count for the
current run and perform **a single atomic update per run**, instead of
per-token.
"""
# program indices – 2-D launch grid: (chunk_id, batch_id)
pid_chunk = tl.program_id(axis=0)
pid_b = tl.program_id(axis=1)
b = pid_b
chunk_start = pid_chunk * BLOCK_N # position of the first token handled by this program
# Nothing to do – out of range
if chunk_start >= N:
return
# base pointers for this batch
idx_batch_base = sorted_idx_ptr + b * N
cid_batch_base = sorted_cluster_ptr + b * N
x_batch_base = x_ptr + b * N * D # for pointer arithmetic
# helper aranges
offs_token = tl.arange(0, BLOCK_N)
offs_dim = tl.arange(0, D)
# first token index & validity mask
token_idx = chunk_start + offs_token
valid_tok = token_idx < N
first_token_idx = chunk_start
last_token_idx = tl.minimum(chunk_start + BLOCK_N, N) - 1
# Load first cluster id to initialise the running accumulator
first_id = tl.load(cid_batch_base + first_token_idx)
last_id = tl.load(cid_batch_base + last_token_idx)
all_ids = tl.load(cid_batch_base + token_idx, mask=valid_tok, other=-1)
all_tokens_idxs = tl.load(idx_batch_base + token_idx, mask=valid_tok, other=-1) # [BLOCK_N]
load_mask = all_tokens_idxs[:, None] * D + offs_dim[None, :]
for cid in range(first_id, last_id + 1):
cluster_mask = all_ids == cid
cluster_size = tl.sum(cluster_mask.to(tl.int32))
if cluster_size != 0:
cluster_feats = tl.load(x_batch_base + load_mask, mask=cluster_mask[:, None], other=0.0) # [BLOCK_N, D]
cluster_feats = cluster_feats.to(tl.float32)
sum_feats = tl.sum(cluster_feats, axis=0)
dest_ptr = sum_ptr + (b * K + cid) * D + offs_dim
tl.atomic_add(dest_ptr, sum_feats)
tl.atomic_add(count_ptr + b * K + cid, cluster_size)
# ---------------------------------------------------------------------------------------------
def triton_centroid_update_sorted_cosine(x_norm: torch.Tensor, cluster_ids: torch.Tensor, old_centroids: torch.Tensor, *, BLOCK_N: int = 256):
"""Fast centroid update assuming **cluster_ids are sorted along N**.
This helper will sort the assignments (together with `x_norm`) and launch the
chunk kernel above. Compared to the naive per-token kernel it performs *one
atomic add per run of identical ids* instead of per token, providing large
speed-ups when clusters are reasonably sized.
"""
assert x_norm.is_cuda and cluster_ids.is_cuda, "Inputs must be on CUDA"
B, N, D = x_norm.shape
K = old_centroids.shape[1]
assert cluster_ids.shape == (B, N)
# -------- sort per-batch --------
sorted_cluster_ids, sorted_idx = torch.sort(cluster_ids, dim=-1)
sorted_idx_int = sorted_idx.to(torch.int32)
# accumulation buffers
centroid_sums = torch.zeros((B, K, D), device=x_norm.device, dtype=torch.float32)
centroid_cnts = torch.zeros((B, K), device=x_norm.device, dtype=torch.int32)
grid = (triton.cdiv(N, BLOCK_N), B)
_centroid_update_chunk_kernel[grid](
x_norm,
sorted_idx_int,
sorted_cluster_ids.to(torch.int32),
centroid_sums,
centroid_cnts,
B,
N,
D,
K,
BLOCK_N=BLOCK_N,
)
# finalise – convert to means, handle empty clusters, renormalise
counts_f = centroid_cnts.to(torch.float32).unsqueeze(-1).clamp(min=1.0)
centroids = centroid_sums / counts_f
empty_mask = (centroid_cnts == 0).unsqueeze(-1)
centroids = torch.where(empty_mask, old_centroids.to(torch.float32), centroids)
centroids = centroids.to(x_norm.dtype)
centroids = F.normalize(centroids, p=2, dim=-1)
return centroids
def triton_centroid_update_sorted_euclid(x: torch.Tensor, cluster_ids: torch.Tensor, old_centroids: torch.Tensor, *, BLOCK_N: int = 256):
"""Fast centroid update for *Euclidean* KMeans assuming cluster IDs are pre-sorted.
Parameters
----------
x : Tensor [B, N, D]
Input feature vectors (no normalization assumed).
cluster_ids : LongTensor [B, N]
Cluster assignment for each point.
old_centroids : Tensor [B, K, D]
Previous centroids (used to fill empty clusters).
BLOCK_N : int, optional
Tokens per Triton program (affects occupancy/perf).
"""
assert x.is_cuda and cluster_ids.is_cuda, "Inputs must be on CUDA device"
B, N, D = x.shape
K = old_centroids.shape[1]
# Batch-wise sort of cluster assignments
sorted_cluster_ids, sorted_idx = torch.sort(cluster_ids, dim=-1)
sorted_idx_int = sorted_idx.to(torch.int32)
centroid_sums = torch.zeros((B, K, D), device=x.device, dtype=torch.float32)
centroid_cnts = torch.zeros((B, K), device=x.device, dtype=torch.int32)
grid = (triton.cdiv(N, BLOCK_N), B)
_centroid_update_chunk_kernel[grid](
x, # original features
sorted_idx_int, # gather indices
sorted_cluster_ids.to(torch.int32),
centroid_sums,
centroid_cnts,
B,
N,
D,
K,
BLOCK_N=BLOCK_N,
)
# Convert sums to means; replace empty clusters with old centroids
counts_f = centroid_cnts.to(torch.float32).unsqueeze(-1).clamp(min=1.0)
centroids = centroid_sums / counts_f
empty_mask = (centroid_cnts == 0).unsqueeze(-1)
centroids = torch.where(empty_mask, old_centroids.to(torch.float32), centroids)
return centroids.to(x.dtype), centroid_cnts
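# Consistency sketch (illustrative shapes, CUDA assumed): the chunked update
# over sorted ids should match the per-token atomic version above, and every
# point should be counted exactly once.
def _centroid_update_sorted_euclid_consistency_example():
    B, N, D, K = 2, 1024, 64, 16
    x = torch.randn(B, N, D, device="cuda", dtype=torch.float32)
    ids = torch.randint(0, K, (B, N), device="cuda")
    old = torch.randn(B, K, D, device="cuda", dtype=torch.float32)
    fast, counts = triton_centroid_update_sorted_euclid(x, ids, old)
    ref = triton_centroid_update_euclid(x, ids, old)
    assert torch.allclose(fast, ref, atol=1e-3)
    assert int(counts.sum()) == B * N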
# ===============================================================
# Triton kernel: compute nearest-centroid IDs (Euclidean distance)
# Inputs:
# x : (B, N, D) float16 / float32
# centroids : (B, K, D) same dtype as x
# x_sq : (B, N) float32 – pre-computed ||x||^2 per point
# Output:
# cluster_ids : (B, N) int32 – nearest centroid index per point
# ===============================================================
def _ceil_div(a: int, b: int) -> int:
return (a + b - 1) // b
# -----------------------------------------------------------------------------
# Auto-tuning setup – explore various tile sizes / warp counts
# -----------------------------------------------------------------------------
_TUNE_CONFIGS = [triton.Config({"BLOCK_N": BN, "BLOCK_K": BK}, num_stages=4, num_warps=wp) for BN in [32, 64, 128] for BK in [32, 64, 128] for wp in [4, 8]]
def _cfg_keep(conf):
"""Basic heuristic to prune unbalanced configs."""
BN = conf.kwargs["BLOCK_N"]
BK = conf.kwargs["BLOCK_K"]
# Avoid tiny tiles on many warps
if BN * BK < 32 * 32 and conf.num_warps > 4:
return False
return True
_TUNE_CONFIGS = list(filter(_cfg_keep, _TUNE_CONFIGS))
@triton.autotune(_TUNE_CONFIGS, key=["N", "K"])
@triton.jit
def _euclid_assign_kernel(
x_ptr, # *f16 / *f32 [B, N, D]
c_ptr, # *f16 / *f32 [B, K, D]
x_sq_ptr, # *f32 [B, N]
out_ptr, # *i32 [B, N]
B: tl.constexpr,
N: tl.constexpr,
K: tl.constexpr,
D: tl.constexpr,
stride_x_b: tl.constexpr,
stride_x_n: tl.constexpr,
stride_x_d: tl.constexpr,
stride_c_b: tl.constexpr,
stride_c_k: tl.constexpr,
stride_c_d: tl.constexpr,
stride_xsq_b: tl.constexpr,
stride_xsq_n: tl.constexpr,
stride_out_b: tl.constexpr,
stride_out_n: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
):
"""Each program handles a tile of BLOCK_N points for a given batch element.
The kernel iterates over the centroid dimension K in chunks of BLOCK_K and
maintains the running minimum distance as well as the corresponding index
for every point in the tile.
"""
pid_n = tl.program_id(0) # tile index along N dimension
pid_b = tl.program_id(1) # batch index
n_start = pid_n * BLOCK_N
n_offsets = n_start + tl.arange(0, BLOCK_N)
n_mask = n_offsets < N
# ------------------------------------------------------------------
# Load x tile (BLOCK_N, D)
# ------------------------------------------------------------------
offs_d = tl.arange(0, D)
# Compute pointer for x block: base + b*stride_x_b + n*stride_x_n + d*stride_x_d
x_ptrs = x_ptr + pid_b * stride_x_b + n_offsets[:, None] * stride_x_n + offs_d[None, :] * stride_x_d
x_tile = tl.load(x_ptrs, mask=n_mask[:, None], other=0.0)
# Pre-load x_sq for the tile (BLOCK_N,)
xsq_ptrs = x_sq_ptr + pid_b * stride_xsq_b + n_offsets * stride_xsq_n
x_sq_tile = tl.load(xsq_ptrs, mask=n_mask, other=0.0).to(tl.float32)
# Init best distance / index
best_dist = tl.full((BLOCK_N,), 3.4e38, tl.float32) # large number
best_idx = tl.zeros((BLOCK_N,), tl.int32)
# ------------------------------------------------------------------
# Iterate over centroids in chunks of BLOCK_K
# ------------------------------------------------------------------
for k_start in range(0, K, BLOCK_K):
k_offsets = k_start + tl.arange(0, BLOCK_K)
k_mask = k_offsets < K
# Load centroid tile (D, BLOCK_K)
c_ptrs = c_ptr + pid_b * stride_c_b + k_offsets[None, :] * stride_c_k + offs_d[:, None] * stride_c_d
c_tile = tl.load(c_ptrs, mask=k_mask[None, :], other=0.0)
# Compute centroid squared norms (BLOCK_K,)
cent_sq = tl.sum(c_tile * c_tile, axis=0).to(tl.float32)
# Compute cross term (BLOCK_N, BLOCK_K) = x_tile @ c_tile
cross = tl.dot(x_tile, c_tile).to(tl.float32) # float32
# Squared Euclidean distance
dist = x_sq_tile[:, None] + cent_sq[None, :] - 2.0 * cross
dist = tl.maximum(dist, 0.0)
# Mask out invalid centroid columns before reduction
dist = tl.where(k_mask[None, :], dist, 3.4e38)
curr_min = tl.min(dist, axis=1)
curr_idx = tl.argmin(dist, axis=1)
update = curr_min < best_dist
best_dist = tl.where(update, curr_min, best_dist)
best_idx = tl.where(update, k_start + curr_idx, best_idx)
# ------------------------------------------------------------------
# Write results
# ------------------------------------------------------------------
out_ptrs = out_ptr + pid_b * stride_out_b + n_offsets * stride_out_n
tl.store(out_ptrs, best_idx, mask=n_mask)
# ---------------------------------------------------------------
# Python wrapper
# ---------------------------------------------------------------
def euclid_assign_triton(
x: torch.Tensor,
centroids: torch.Tensor,
x_sq: torch.Tensor,
out: torch.Tensor = None,
*,
BLOCK_N: int = 128,
BLOCK_K: int = 128,
) -> torch.Tensor:
"""Return nearest-centroid indices using Triton kernel.
Args:
x : (B, N, D) float16 / float32 (on CUDA)
centroids : (B, K, D) same dtype/device as x
x_sq : (B, N) float32 – ||x||^2 per point (on CUDA)
Returns:
        cluster_ids (B, N) int64 tensor of nearest-centroid indices (the default `out` buffer is allocated as int64)
"""
assert x.is_cuda and centroids.is_cuda and x_sq.is_cuda, "All tensors must be on CUDA"
# assert x.dtype in (torch.float16, torch.float32), "x must be fp16/fp32"
assert centroids.dtype == x.dtype, "centroids dtype mismatch"
B, N, D = x.shape
K = centroids.shape[1]
assert centroids.shape == (B, K, D), "centroids shape mismatch"
assert x_sq.shape == (B, N), "x_sq shape mismatch"
# x = x.contiguous()
# centroids = centroids.contiguous()
# x_sq = x_sq.contiguous()
if out is None:
out = torch.empty((B, N), device=x.device, dtype=torch.int64)
# Strides (in elements)
stride_x_b, stride_x_n, stride_x_d = x.stride()
stride_c_b, stride_c_k, stride_c_d = centroids.stride()
stride_xsq_b, stride_xsq_n = x_sq.stride()
stride_out_b, stride_out_n = out.stride()
grid = lambda META: (triton.cdiv(N, META["BLOCK_N"]), B) # noqa
_euclid_assign_kernel[grid](
x,
centroids,
x_sq,
out,
B,
N,
K,
D,
stride_x_b,
stride_x_n,
stride_x_d,
stride_c_b,
stride_c_k,
stride_c_d,
stride_xsq_b,
stride_xsq_n,
stride_out_b,
stride_out_n,
)
return out
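# Verification sketch (illustrative shapes, CUDA assumed): the Triton
# assignment should match an argmin over squared Euclidean distances computed
# in PyTorch; a handful of near-tie mismatches can occur in fp16.
def _euclid_assign_example():
    B, N, K, D = 1, 1024, 32, 64
    x = torch.randn(B, N, D, device="cuda", dtype=torch.float16)
    c = torch.randn(B, K, D, device="cuda", dtype=torch.float16)
    x_sq = (x.float() ** 2).sum(dim=-1)
    ids = euclid_assign_triton(x, c, x_sq)
    ref = torch.cdist(x.float(), c.float()).argmin(dim=-1)
    assert (ids.long() == ref).float().mean() > 0.99
    return ids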
# 1. Euclidean
def _euclid_iter(x, x_sq, centroids):
# cent_sq = (centroids ** 2).sum(dim=-1)
# cross = torch.einsum('bnd,bkd->bnk', x, centroids)
# dist_sq = (x_sq[:,:,None] + cent_sq[:,None,:] - 2.0 * cross).clamp_min_(0.0)
# cluster_ids = dist_sq.argmin(dim=-1)
cluster_ids = euclid_assign_triton(x, centroids, x_sq)
centroids_new, cluster_sizes = triton_centroid_update_sorted_euclid(x, cluster_ids, centroids)
# centroids_new = triton_centroid_update_euclid(x, cluster_ids, centroids)
# centroids_new = centroids_new.clone() # avoid CUDA graphs aliasing
shift = (centroids_new - centroids).norm(dim=-1).max()
return centroids_new, shift, cluster_ids, cluster_sizes
# 2. Cosine
def _cosine_iter(x_norm, centroids):
cos_sim = torch.einsum("bnd,bkd->bnk", x_norm, centroids)
cluster_ids = cos_sim.argmax(dim=-1)
centroids_new = triton_centroid_update_cosine(x_norm, cluster_ids, centroids)
# centroids_new = centroids_new.clone()
shift = (centroids_new - centroids).norm(dim=-1).max()
return centroids_new, shift, cluster_ids
# 3. Dot-product
def _dot_iter(x, centroids):
sim = torch.einsum("bnd,bkd->bnk", x, centroids)
cluster_ids = sim.argmax(dim=-1)
centroids_new = triton_centroid_update_cosine(x, cluster_ids, centroids)
# centroids_new = centroids_new.clone()
shift = (centroids_new - centroids).norm(dim=-1).max()
return centroids_new, shift, cluster_ids
COMPILE_FLAG = False
# Try to compile; if PyTorch < 2.0 or compile is not available, fallback to original function
try:
if COMPILE_FLAG:
_euclid_iter_compiled = torch.compile(_euclid_iter, dynamic=True, mode="reduce-overhead")
_cosine_iter_compiled = torch.compile(_cosine_iter, dynamic=True, mode="reduce-overhead")
_dot_iter_compiled = torch.compile(_dot_iter, dynamic=True, mode="reduce-overhead")
else:
_euclid_iter_compiled = _euclid_iter
_cosine_iter_compiled = _cosine_iter
_dot_iter_compiled = _dot_iter
except Exception: # pragma: no cover
_euclid_iter_compiled = _euclid_iter
_cosine_iter_compiled = _cosine_iter
_dot_iter_compiled = _dot_iter
def batch_kmeans_Euclid(x, n_clusters, max_iters=100, tol=1e-4, init_centroids=None, verbose=False):
"""
Batched KMeans clustering in PyTorch using Euclidean distance.
Args:
x: Tensor of shape (B, N, D), batch_size B, N points per batch, D dims.
n_clusters: Number of clusters.
max_iters: Max number of iterations.
tol: Relative tolerance for center movement.
verbose: Print loss for each iter.
Returns:
cluster_ids: (B, N) LongTensor, cluster assignment for each point.
centroids: (B, n_clusters, D) final cluster centers.
cluster_sizes: (B, n_clusters) LongTensor, number of points per cluster.
n_iters: actual number of iterations executed (int)
"""
B, N, D = x.shape
# Pre-compute squared L2 norm of all points (constant during iterations)
x_sq = (x**2).sum(dim=-1) # (B, N)
if init_centroids is None:
# Randomly select initial centers from x
indices = torch.randint(0, N, (B, n_clusters), device=x.device)
centroids = torch.gather(x, dim=1, index=indices[..., None].expand(-1, -1, D)) # (B, n_clusters, D)
else:
# centroids = init_centroids.clone()
centroids = init_centroids
centroids = centroids.view(B, n_clusters, D)
for it in range(max_iters):
# ---- compiled single iteration ----
centroids_new, center_shift, cluster_ids, cluster_sizes = _euclid_iter_compiled(x, x_sq, centroids)
# 4. Check for convergence
if verbose:
print(f"Iter {it}, center shift: {center_shift.item():.6f}")
if center_shift < tol:
break
# centroids = centroids_new.clone()
centroids = centroids_new
# # --- compute cluster sizes ---
# ones = torch.ones_like(cluster_ids, dtype=torch.int64)
# cluster_sizes = torch.zeros(B, n_clusters, dtype=torch.int64, device=x.device)
# cluster_sizes.scatter_add_(1, cluster_ids, ones)
return cluster_ids, centroids, cluster_sizes, it + 1
# return cluster_ids.clone(), centroids.clone(), cluster_sizes.clone(), it + 1
# batch_kmeans_Euclid = torch.compile(batch_kmeans_Euclid, dynamic=True, mode="reduce-overhead")
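# Usage sketch for the batched Euclidean KMeans above (shapes are made up;
# in the attention path the batch dim is the flattened B*H of the heads).
def _batch_kmeans_euclid_example():
    x = torch.randn(4, 4096, 64, device="cuda", dtype=torch.float16)
    cluster_ids, centroids, cluster_sizes, n_iters = batch_kmeans_Euclid(x, n_clusters=16, max_iters=10, tol=1e-4)
    # One assignment per point, one centroid per cluster, sizes summing to N.
    assert cluster_ids.shape == (4, 4096)
    assert centroids.shape == (4, 16, 64)
    assert int(cluster_sizes.sum(dim=-1)[0]) == 4096
    return cluster_ids, centroids, cluster_sizes, n_iters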
def batch_kmeans_Cosine(x, n_clusters, max_iters=100, tol=1e-4, init_centroids=None, verbose=False):
"""
Batched KMeans clustering in PyTorch using Cosine similarity.
Args:
x: Tensor of shape (B, N, D), batch_size B, N points per batch, D dims.
n_clusters: Number of clusters.
max_iters: Max number of iterations.
tol: Relative tolerance for center movement.
verbose: Print loss for each iter.
Returns:
cluster_ids: (B, N) LongTensor, cluster assignment for each point.
centroids: (B, n_clusters, D) final cluster centers.
cluster_sizes: (B, n_clusters) LongTensor, number of points per cluster.
n_iters: actual number of iterations executed (int)
"""
B, N, D = x.shape
# Normalize input vectors for cosine similarity
x_norm = F.normalize(x, p=2, dim=-1) # (B, N, D)
if init_centroids is None:
# Randomly select initial centers from x_norm
indices = torch.randint(0, N, (B, n_clusters), device=x.device)
centroids = torch.gather(x_norm, dim=1, index=indices[..., None].expand(-1, -1, D)) # (B, n_clusters, D)
else:
centroids = init_centroids
centroids = centroids.view(B, n_clusters, D)
centroids = F.normalize(centroids, p=2, dim=-1) # Ensure centroids are normalized
for it in range(max_iters):
# ---- compiled single iteration ----
centroids_new, center_shift, cluster_ids = _cosine_iter_compiled(x_norm, centroids)
# 4. Check for convergence
if verbose:
print(f"Iter {it}, center shift: {center_shift.item():.6f}")
if center_shift < tol:
break
centroids = centroids_new.clone()
# --- compute cluster sizes ---
ones = torch.ones_like(cluster_ids, dtype=torch.int64)
cluster_sizes = torch.zeros(B, n_clusters, dtype=torch.int64, device=x.device)
cluster_sizes.scatter_add_(1, cluster_ids, ones)
return cluster_ids, centroids, cluster_sizes, it + 1
def batch_kmeans_Dot(x, n_clusters, max_iters=100, tol=1e-4, init_centroids=None, verbose=False):
"""
Batched KMeans clustering in PyTorch using raw dot-product as similarity.
"""
B, N, D = x.shape
if init_centroids is None:
# Randomly initialize centroids
indices = torch.randint(0, N, (B, n_clusters), device=x.device)
centroids = torch.gather(x, dim=1, index=indices[..., None].expand(-1, -1, D))
else:
centroids = init_centroids
centroids = centroids.view(B, n_clusters, D)
for it in range(max_iters):
# ---- compiled single iteration ----
centroids_new, center_shift, cluster_ids = _dot_iter_compiled(x, centroids)
# 4. Check for convergence
if verbose:
print(f"Iter {it} (dot), center shift: {center_shift.item():.6f}")
if center_shift < tol:
break
centroids = centroids_new.clone()
# --- compute cluster sizes ---
ones = torch.ones_like(cluster_ids, dtype=torch.int64)
cluster_sizes = torch.zeros(B, n_clusters, dtype=torch.int64, device=x.device)
cluster_sizes.scatter_add_(1, cluster_ids, ones)
return cluster_ids, centroids, cluster_sizes, it + 1
# --- Functions from analyze/kmeans_block_sparse_attention.py (helpers) ---
def permute_tensor_by_labels(tensor, labels, dim):
labels = labels.to(tensor.device)
sorted_indices = torch.argsort(labels, dim=-1)
gather_indices = sorted_indices
for i in range(dim + 1, tensor.dim()):
gather_indices = gather_indices.unsqueeze(-1)
expand_shape = list(tensor.shape)
gather_indices = gather_indices.expand(expand_shape)
permuted_tensor = torch.gather(tensor, dim, gather_indices)
return permuted_tensor, sorted_indices
def apply_inverse_permutation(permuted_tensor, sorted_indices, dim):
inverse_indices = torch.argsort(sorted_indices, dim=-1)
gather_indices = inverse_indices
for i in range(dim + 1, permuted_tensor.dim()):
gather_indices = gather_indices.unsqueeze(-1)
gather_indices = gather_indices.expand(permuted_tensor.shape)
original_tensor = torch.gather(permuted_tensor, dim, gather_indices)
return original_tensor
def weighted_softmax(scores, weights):
input_dtype = scores.dtype
scores = scores.float()
weights = weights.float()
max_score = torch.max(scores, dim=-1, keepdim=True)[0]
exp_scores = torch.exp(scores - max_score)
weighted_exp = weights * exp_scores
softmax_out = weighted_exp / torch.sum(weighted_exp, dim=-1, keepdim=True).clamp(min=1e-12)
return softmax_out.to(input_dtype)
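# Property sketch (illustrative helper): with all-ones weights the weighted
# softmax reduces to a plain softmax, and in general it equals
# softmax(scores + log(weights)) along the last dimension.
def _weighted_softmax_example():
    scores = torch.randn(2, 3, 5)
    ones = torch.ones(2, 3, 5)
    assert torch.allclose(weighted_softmax(scores, ones), torch.softmax(scores, dim=-1), atol=1e-6)
    weights = torch.rand(2, 3, 5) + 0.1
    ref = torch.softmax(scores + weights.log(), dim=-1)
    assert torch.allclose(weighted_softmax(scores, weights), ref, atol=1e-5)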
def identify_dynamic_map(
query_centroids,
key_centroids,
q_cluster_sizes,
k_cluster_sizes,
p,
min_kc_ratio=0,
):
B, H, qc_num, D = query_centroids.shape
kc_num = key_centroids.shape[2]
device = query_centroids.device
attn_scores = torch.matmul(query_centroids, key_centroids.transpose(-2, -1)) / (D**0.5)
k_weights = k_cluster_sizes.unsqueeze(-2).float()
weighted_attn_probs = weighted_softmax(attn_scores, k_weights)
sorted_probs, sorted_indices = torch.sort(weighted_attn_probs, dim=-1, descending=True)
cumsum_probs = torch.cumsum(sorted_probs, dim=-1)
remove_indices = cumsum_probs > p
remove_indices[..., 1:] = remove_indices[..., :-1].clone()
remove_indices[..., 0] = False
if min_kc_ratio > 0:
preserve_length = int(min_kc_ratio * kc_num)
remove_indices[..., :preserve_length] = False
sorted_clusters_to_keep = ~remove_indices
dynamic_map = torch.zeros(B, H, qc_num, kc_num, dtype=torch.bool, device=device)
dynamic_map.scatter_(-1, sorted_indices, sorted_clusters_to_keep)
return dynamic_map
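# Behavioural sketch (illustrative shapes and values): each query cluster
# keeps the smallest set of key clusters whose size-weighted probability mass
# reaches `p`; `min_kc_ratio` additionally forces a minimum number of key
# clusters to stay active.
def _identify_dynamic_map_example():
    B, H, qc, kc, D = 1, 1, 4, 8, 32
    qc_cent = torch.randn(B, H, qc, D)
    kc_cent = torch.randn(B, H, kc, D)
    q_sizes = torch.randint(1, 100, (B, H, qc))
    k_sizes = torch.randint(1, 100, (B, H, kc))
    dyn_map = identify_dynamic_map(qc_cent, kc_cent, q_sizes, k_sizes, p=0.9, min_kc_ratio=0.25)
    assert dyn_map.shape == (B, H, qc, kc)
    # int(0.25 * kc) = 2 key clusters are always preserved per query cluster.
    assert int(dyn_map.sum(dim=-1).min()) >= 2
    return dyn_map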
# --- Functions from analyze/dynamic_block_sparse_attention.py ---
def dynamic_block_sparse_fwd_torch(q, k, v, dynamic_map, qc_size, kc_size):
"""
Computes dynamic block sparse attention using pure PyTorch.
Args:
q (torch.Tensor): Query tensor, shape [B, H, S, D].
k (torch.Tensor): Key tensor, shape [B, H, S, D].
v (torch.Tensor): Value tensor, shape [B, H, S, D].
dynamic_map (torch.Tensor): Boolean mask, shape [B, H, qc_num, kc_num].
qc_size (torch.Tensor): Query block sizes, shape [B, H, qc_num].
kc_size (torch.Tensor): Key block sizes, shape [B, H, kc_num].
Returns:
torch.Tensor: Output tensor, shape [B, H, S, D].
"""
B, H, S, D = q.shape
qc_num = qc_size.shape[-1]
kc_num = kc_size.shape[-1]
device = q.device
dtype = q.dtype
# Ensure sequence lengths match sum of block sizes
assert S == torch.sum(qc_size[0, 0, :]), "Sum of qc_size must equal S"
assert S == torch.sum(kc_size[0, 0, :]), "Sum of kc_size must equal S"
# Precompute cumulative sizes for block indexing
# Add a 0 at the beginning for easier slicing
qc_cum_size = torch.cumsum(torch.cat([torch.zeros_like(qc_size[..., :1]), qc_size], dim=-1), dim=-1)
kc_cum_size = torch.cumsum(torch.cat([torch.zeros_like(kc_size[..., :1]), kc_size], dim=-1), dim=-1)
out = torch.zeros_like(q)
scale = D**-0.5
# Naive implementation: Iterate through batch, head, and blocks
for b in range(B):
for h in range(H):
# Precompute start/end indices for this batch/head
q_starts = qc_cum_size[b, h, :-1]
q_ends = qc_cum_size[b, h, 1:]
k_starts = kc_cum_size[b, h, :-1]
k_ends = kc_cum_size[b, h, 1:]
# Iterate through query blocks
for i in range(qc_num):
q_start, q_end = q_starts[i], q_ends[i]
q_block = q[b, h, q_start:q_end, :] # Shape: [qc_i, D]
if q_block.shape[0] == 0:
continue # Skip empty blocks
m_i = torch.full((q_block.shape[0], 1), -float("inf"), device=device, dtype=dtype)
l_i = torch.zeros((q_block.shape[0], 1), device=device, dtype=dtype)
acc_o_i = torch.zeros_like(q_block) # Shape: [qc_i, D]
# Iterate through key/value blocks for the current query block
for j in range(kc_num):
# Check if this block needs computation
if dynamic_map[b, h, i, j]:
k_start, k_end = k_starts[j], k_ends[j]
k_block = k[b, h, k_start:k_end, :] # Shape: [kc_j, D]
v_block = v[b, h, k_start:k_end, :] # Shape: [kc_j, D]
if k_block.shape[0] == 0:
continue # Skip empty blocks
# Compute attention scores for the block
# QK^T: [qc_i, D] @ [D, kc_j] -> [qc_i, kc_j]
s_ij = (q_block @ k_block.transpose(-1, -2)) * scale
# --- Online Softmax ---
# Find max score per query token in this block
m_ij = torch.max(s_ij, dim=-1, keepdim=True)[0] # Shape: [qc_i, 1]
# Update overall max score (m_i)
m_new = torch.maximum(m_i, m_ij) # Shape: [qc_i, 1]
# Calculate scaling factors for previous accumulator and current block
p_ij = torch.exp(s_ij - m_new) # Shape: [qc_i, kc_j]
exp_m_diff = torch.exp(m_i - m_new) # Shape: [qc_i, 1]
# Update softmax denominator (l_i)
l_i = (l_i * exp_m_diff) + torch.sum(p_ij, dim=-1, keepdim=True) # Shape: [qc_i, 1]
# Update output accumulator (acc_o_i)
# P_ij @ V_j: [qc_i, kc_j] @ [kc_j, D] -> [qc_i, D]
acc_o_i = (acc_o_i * exp_m_diff) + (p_ij @ v_block) # Shape: [qc_i, D]
# Update max score for next iteration
m_i = m_new
# Normalize the accumulated output
out[b, h, q_start:q_end, :] = acc_o_i / l_i.clamp(min=1e-12) # Avoid division by zero
return out
# --- Triton Implementation ---
@triton.jit
def _dynamic_block_sparse_fwd_kernel(
Q,
K,
V,
Out,
dynamic_map,
qc_cum_size,
kc_cum_size,
stride_qb,
stride_qh,
stride_qs,
stride_qd,
stride_kb,
stride_kh,
stride_ks,
stride_kd,
stride_vb,
stride_vh,
stride_vs,
stride_vd,
stride_ob,
stride_oh,
stride_os,
stride_od,
stride_dmap_b,
stride_dmap_h,
stride_dmap_qc,
stride_dmap_kc,
stride_qcs_b,
stride_qcs_h,
stride_qcs_qc,
stride_kcs_b,
stride_kcs_h,
stride_kcs_kc,
B,
H,
S,
D,
scale,
QC_NUM: tl.constexpr,
KC_NUM: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_D: tl.constexpr,
):
"""
Triton kernel for dynamic block sparse attention.
Each program computes attention for one query block within a batch/head.
Processes query block in chunks of BLOCK_M.
Iterates through key blocks, checking dynamic_map.
Processes key/value blocks in chunks of BLOCK_N.
Uses online softmax.
"""
# --- Grid Calculation ---
# Each program instance handles one query block for a specific batch and head
pid = tl.program_id(axis=0)
    # The launch grid has B * H * QC_NUM programs in total.
# Calculate batch, head, and query block index
pid_q_block_global = pid # 0 to B*H*QC_NUM - 1
# pid_bh = pid // QC_NUM # Deprecated: Causes issues if QC_NUM is not constant across BH
# pid_q_block_idx = pid % QC_NUM
# Need to map pid (0.. B*H*QC_NUM-1) back to (b, h, q_block_idx)
# q_block_idx changes fastest, then h, then b
q_block_idx = pid_q_block_global % QC_NUM
pid_h_temp = pid_q_block_global // QC_NUM
h = pid_h_temp % H
b = pid_h_temp // H
# --- Load Q block info (start/end offsets) ---
qcs_offset = b * stride_qcs_b + h * stride_qcs_h
q_start_offset = tl.load(qc_cum_size + qcs_offset + q_block_idx * stride_qcs_qc)
q_end_offset = tl.load(qc_cum_size + qcs_offset + (q_block_idx + 1) * stride_qcs_qc)
q_block_size = q_end_offset - q_start_offset
# Early exit if the query block is empty
if q_block_size == 0:
return
# --- Pointers setup ---
q_ptr_base = Q + b * stride_qb + h * stride_qh + q_start_offset * stride_qs
k_ptr_base = K + b * stride_kb + h * stride_kh
v_ptr_base = V + b * stride_vb + h * stride_vh
out_ptr_base = Out + b * stride_ob + h * stride_oh + q_start_offset * stride_os
dmap_ptr = dynamic_map + b * stride_dmap_b + h * stride_dmap_h + q_block_idx * stride_dmap_qc
kcs_ptr = kc_cum_size + b * stride_kcs_b + h * stride_kcs_h
# --- Iterate over the query block rows in chunks of BLOCK_M ---
offs_qm = tl.arange(0, BLOCK_M) # Query block row offsets [0, 1, ..., BLOCK_M-1]
offs_d = tl.arange(0, BLOCK_D) # Dimension offsets [0, 1, ..., BLOCK_D-1]
for q_chunk_start in range(0, q_block_size, BLOCK_M):
q_chunk_rows = offs_qm + q_chunk_start
q_rows_mask = q_chunk_rows < q_block_size # Mask for valid rows in this Q chunk [BLOCK_M]
# --- Initialize accumulators for this Q chunk ---
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") # Max score
l_i = tl.zeros([BLOCK_M], dtype=tl.float32) # Sum of exp(scores - max)
acc_o = tl.zeros([BLOCK_M, BLOCK_D], dtype=tl.float32) # Accumulated output
# --- Load Q chunk ---
q_ptr = q_ptr_base + q_chunk_rows[:, None] * stride_qs + offs_d[None, :]
# Mask ensures we don't read out of bounds for the query block or dimension D
mask_q = q_rows_mask[:, None] & (offs_d[None, :] < D)
q_chunk = tl.load(q_ptr, mask=mask_q, other=0.0) # Shape: [BLOCK_M, BLOCK_D]
# --- Inner loop over K blocks (columns in the block sparse map) ---
for k_block_idx in range(KC_NUM):
# --- Check dynamic_map: Is this block active? ---
is_active = tl.load(dmap_ptr + k_block_idx * stride_dmap_kc)
if is_active: # Process block only if it's active
# --- Load K block info (start/end offsets) ---
k_start_offset = tl.load(kcs_ptr + k_block_idx * stride_kcs_kc)
k_end_offset = tl.load(kcs_ptr + (k_block_idx + 1) * stride_kcs_kc)
k_block_size = k_end_offset - k_start_offset
# Skip if the key block is empty (inside the active block check)
if k_block_size > 0:
k_block_ptr_base = k_ptr_base + k_start_offset * stride_ks
v_block_ptr_base = v_ptr_base + k_start_offset * stride_vs
# --- Loop over K block chunks (size BLOCK_N) ---
offs_kn = tl.arange(0, BLOCK_N) # Key block row offsets [0, ..., BLOCK_N-1]
for k_chunk_start in range(0, k_block_size, BLOCK_N):
k_chunk_rows = offs_kn + k_chunk_start
k_rows_mask = k_chunk_rows < k_block_size # Mask for valid rows in this K/V chunk [BLOCK_N]
# --- Load K, V chunks ---
k_ptr = k_block_ptr_base + k_chunk_rows[:, None] * stride_ks + offs_d[None, :]
v_ptr = v_block_ptr_base + k_chunk_rows[:, None] * stride_vs + offs_d[None, :]
# Mask ensures we don't read out of bounds for the key block or dimension D
mask_kv = k_rows_mask[:, None] & (offs_d[None, :] < D)
k_chunk = tl.load(k_ptr, mask=mask_kv, other=0.0) # Shape: [BLOCK_N, BLOCK_D]
v_chunk = tl.load(v_ptr, mask=mask_kv, other=0.0) # Shape: [BLOCK_N, BLOCK_D]
# --- Compute Scores (Attention) ---
# QK^T: [BLOCK_M, BLOCK_D] @ [BLOCK_D, BLOCK_N] -> [BLOCK_M, BLOCK_N]
s_ij_chunk = tl.dot(q_chunk, k_chunk.T) * scale
# IMPORTANT: Mask out scores corresponding to padding in K before max/softmax
# Set scores for invalid K elements to -inf
s_ij_chunk = tl.where(k_rows_mask[None, :], s_ij_chunk, -float("inf"))
# Mask out scores for invalid Q elements as well (although q_chunk elements are 0, avoid potential issues)
s_ij_chunk = tl.where(q_rows_mask[:, None], s_ij_chunk, -float("inf"))
# --- Online Softmax Update ---
# Current max for this Q-K chunk interaction
m_ij_chunk = tl.max(s_ij_chunk, axis=1) # Shape: [BLOCK_M]
# Update overall max (across K chunks seen so far for this Q chunk)
m_new = tl.maximum(m_i, m_ij_chunk) # Shape: [BLOCK_M]
# Calculate scaled probabilities P_ij = exp(S_ij - m_new)
p_ij_chunk = tl.exp(s_ij_chunk - m_new[:, None]) # Shape: [BLOCK_M, BLOCK_N]
# Zero out probabilities for masked K elements before summing
p_ij_chunk = tl.where(k_rows_mask[None, :], p_ij_chunk, 0.0)
# Calculate scaling factor for previous accumulator state
exp_m_diff = tl.exp(m_i - m_new) # Shape: [BLOCK_M]
# Update sum accumulator (denominator L)
l_i_chunk = tl.sum(p_ij_chunk, axis=1) # Sum probabilities for this chunk, shape [BLOCK_M]
l_i = (l_i * exp_m_diff) + l_i_chunk # Shape: [BLOCK_M]
# Update output accumulator O
# P_ij @ V_j: [BLOCK_M, BLOCK_N] @ [BLOCK_N, BLOCK_D] -> [BLOCK_M, BLOCK_D]
# Ensure p_ij_chunk is the correct dtype for dot product
p_ij_chunk_casted = p_ij_chunk.to(V.dtype.element_ty)
o_chunk = tl.dot(p_ij_chunk_casted, v_chunk) # Shape: [BLOCK_M, BLOCK_D]
acc_o = (acc_o * exp_m_diff[:, None]) + o_chunk # Shape: [BLOCK_M, BLOCK_D]
# Update max for the next K chunk/block
m_i = m_new
# End of 'if is_active:' block
# --- End of loop over K blocks ---
# --- Finalize output for this Q chunk ---
# Normalize the accumulated output: O = acc_o / l_i
# Add epsilon to l_i to avoid division by zero
l_i_safe = tl.where(l_i == 0, 1.0, l_i) # Avoid 0/0 -> NaN
o_final_chunk = acc_o / (l_i_safe[:, None])
o_final_chunk = tl.where(l_i[:, None] == 0, 0.0, o_final_chunk) # Ensure output is 0 if l_i was 0
# --- Write output chunk to global memory ---
out_ptr = out_ptr_base + q_chunk_rows[:, None] * stride_os + offs_d[None, :]
# Mask ensures we don't write out of bounds for the query block or dimension D
mask_out = q_rows_mask[:, None] & (offs_d[None, :] < D)
tl.store(out_ptr, o_final_chunk.to(Out.dtype.element_ty), mask=mask_out)
# --- (Optional: Write L and M stats if needed) ---
# Example:
# l_ptr = L + b * stride_lb + h * stride_lh + (q_start_offset + q_chunk_rows) * stride_ls
# tl.store(l_ptr, l_i, mask=q_rows_mask)
# m_ptr = M + ...
# tl.store(m_ptr, m_i, mask=q_rows_mask)
# --- End of loop over Q chunks ---
def dynamic_block_sparse_fwd_triton(q, k, v, dynamic_map, qc_size, kc_size):
"""
Launcher for the Triton dynamic block sparse attention kernel.
Args:
q (torch.Tensor): Query tensor, shape [B, H, S, D].
k (torch.Tensor): Key tensor, shape [B, H, S, D].
v (torch.Tensor): Value tensor, shape [B, H, S, D].
dynamic_map (torch.Tensor): Boolean mask, shape [B, H, qc_num, kc_num].
qc_size (torch.Tensor): Query block sizes, shape [B, H, qc_num].
kc_size (torch.Tensor): Key block sizes, shape [B, H, kc_num].
Returns:
torch.Tensor: Output tensor, shape [B, H, S, D].
"""
B, H, S, D = q.shape
qc_num = qc_size.shape[-1]
kc_num = kc_size.shape[-1]
dtype = q.dtype
# Assertions and checks
assert q.is_cuda and k.is_cuda and v.is_cuda, "Inputs must be CUDA tensors"
assert dynamic_map.is_cuda and qc_size.is_cuda and kc_size.is_cuda
assert q.dtype == k.dtype == v.dtype, "Input dtypes must match"
assert dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported dtype"
assert D in [16, 32, 64, 128], "Head dimension D must be 16, 32, 64, or 128 for efficient Triton dot"
# Ensure sequence lengths match sum of block sizes (check on one batch/head for simplicity)
assert S == torch.sum(qc_size[0, 0, :]), "Sum of qc_size must equal S"
assert S == torch.sum(kc_size[0, 0, :]), "Sum of kc_size must equal S"
# Ensure dynamic_map is boolean
assert dynamic_map.dtype == torch.bool
# Calculate scale factor (using float32 for stability)
scale = D**-0.5
# Precompute cumulative sizes (on CPU/GPU, keep on device)
qc_cum_size = torch.cumsum(torch.cat([torch.zeros_like(qc_size[..., :1]), qc_size], dim=-1), dim=-1).int()
kc_cum_size = torch.cumsum(torch.cat([torch.zeros_like(kc_size[..., :1]), kc_size], dim=-1), dim=-1).int()
# Output tensor
out = torch.empty_like(q)
# Triton kernel config
# BLOCK_M/N can be tuned. Larger blocks may increase occupancy but need more shared memory.
# Let's start with reasonably sized blocks.
BLOCK_D = D
if S <= 512: # Smaller sequence, smaller blocks might be ok
BLOCK_M = 64
BLOCK_N = 64
elif S <= 1024:
BLOCK_M = 64
BLOCK_N = 64
else: # Larger sequence, potentially larger blocks
BLOCK_M = 128 # Or keep 64? Test
BLOCK_N = 64
# Adjust block size if sequence length is smaller
BLOCK_M = min(BLOCK_M, S)
BLOCK_N = min(BLOCK_N, S)
# Launch grid: One program per query block per batch/head
grid = (B * H * qc_num,)
# Call the kernel
_dynamic_block_sparse_fwd_kernel[grid](
q,
k,
v,
out,
dynamic_map,
qc_cum_size,
kc_cum_size,
q.stride(0),
q.stride(1),
q.stride(2),
q.stride(3),
k.stride(0),
k.stride(1),
k.stride(2),
k.stride(3),
v.stride(0),
v.stride(1),
v.stride(2),
v.stride(3),
out.stride(0),
out.stride(1),
out.stride(2),
out.stride(3),
dynamic_map.stride(0),
dynamic_map.stride(1),
dynamic_map.stride(2),
dynamic_map.stride(3),
qc_cum_size.stride(0),
qc_cum_size.stride(1),
qc_cum_size.stride(2),
kc_cum_size.stride(0),
kc_cum_size.stride(1),
kc_cum_size.stride(2),
B,
H,
S,
D,
scale,
QC_NUM=qc_num,
KC_NUM=kc_num,
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N,
BLOCK_D=BLOCK_D,
# num_warps=4 # Can tune this
)
return out
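# Cross-check sketch (made-up block sizes and sparsity pattern, CUDA assumed):
# on small inputs the Triton launcher should agree with the pure PyTorch
# reference `dynamic_block_sparse_fwd_torch`; the tolerance is loose because
# the reference accumulates in fp16.
def _dynamic_block_sparse_consistency_example():
    B, H, S, D = 1, 2, 256, 64
    q = torch.randn(B, H, S, D, device="cuda", dtype=torch.float16)
    k = torch.randn(B, H, S, D, device="cuda", dtype=torch.float16)
    v = torch.randn(B, H, S, D, device="cuda", dtype=torch.float16)
    qc_size = torch.full((B, H, 4), S // 4, device="cuda", dtype=torch.int32)
    kc_size = torch.full((B, H, 4), S // 4, device="cuda", dtype=torch.int32)
    dyn_map = torch.rand(B, H, 4, 4, device="cuda") > 0.5
    dyn_map[..., 0] = True  # keep at least one key block per query block
    out_ref = dynamic_block_sparse_fwd_torch(q, k, v, dyn_map, qc_size, kc_size)
    out_tri = dynamic_block_sparse_fwd_triton(q, k, v, dyn_map, qc_size, kc_size)
    assert torch.allclose(out_ref, out_tri, atol=1e-2, rtol=1e-2)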
# ---------------- Batch wrapper for cuVS KMeans -----------------
def batch_kmeans_rapidai(x, n_clusters, max_iters=100, tol=1e-4, init_centroids=None, verbose=False):
"""Batched K-Means using RAPIDS cuVS implementation.
Args:
x (Tensor): (B, N, D) float32 tensor on CUDA.
n_clusters (int): K.
max_iters (int): maximum iterations.
tol (float): tolerance.
init_centroids (Tensor|None): optional initial centroids (B,K,D) float32.
verbose (bool): print per-batch info.
Returns:
cluster_ids (B, N) LongTensor
centroids (B, K, D) float32
cluster_sizes (B, K) LongTensor
n_iters_list (List[int]) iterations per batch
"""
B, N, D = x.shape
if init_centroids is not None:
assert init_centroids.shape == (B, n_clusters, D)
cluster_ids_list = []
centroids_list = []
# cluster_sizes_list = []
n_iters_list = []
x_float = x.float()
if init_centroids is not None:
init_centroids_float = init_centroids.float()
for b in range(B):
xb = x_float[b]
if init_centroids is None:
centroids_init_b = None
init_method = "KMeansPlusPlus"
else:
centroids_init_b = init_centroids_float[b]
init_method = "Array"
labels_b, centroids_b, n_iter_b = kmeans_rapidai(xb, n_clusters, max_iter=max_iters, tol=tol, init_method=init_method, centroids_init=centroids_init_b)
cluster_ids_list.append(labels_b.to(torch.int64)) # (N,)
centroids_list.append(centroids_b)
# cluster_sizes_b = torch.bincount(labels_b, minlength=n_clusters).to(torch.int64)
# cluster_sizes_list.append(cluster_sizes_b)
        n_iters_list.append(n_iter_b)
# if verbose:
# print(f"Batch {b}: iters={n_iter_b}, cluster sizes min={cluster_sizes_b.min().item()} max={cluster_sizes_b.max().item()}")
cluster_ids = torch.stack(cluster_ids_list, dim=0) # (B,N)
centroids = torch.stack(centroids_list, dim=0).to(x.dtype) # (B,K,D)
# cluster_sizes = torch.stack(cluster_sizes_list, dim=0) # (B,K)
# --- compute cluster sizes ---
ones = torch.ones_like(cluster_ids, dtype=torch.int64)
cluster_sizes = torch.zeros(B, n_clusters, dtype=torch.int64, device=x.device)
cluster_sizes.scatter_add_(1, cluster_ids, ones)
return cluster_ids, centroids, cluster_sizes, n_iters_list
import math
from functools import lru_cache
from math import ceil
import torch
import torch.nn.functional as F
import triton
import triton.language as tl
from loguru import logger
from torch.nn.attention.flex_attention import create_block_mask, flex_attention
from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER
from .template import AttnWeightTemplate
@triton.jit
def wan_hidden_states_placement_kernel(
hidden_states_ptr, # [cfg, num_heads, seq_len, head_dim] seq_len = context_length + num_frame * frame_size
hidden_states_out_ptr, # [cfg, num_heads, seq_len, head_dim]
best_mask_idx_ptr, # [cfg, num_heads]
hidden_states_stride_b,
hidden_states_stride_h,
hidden_states_stride_s,
hidden_states_stride_d,
mask_idx_stride_b,
mask_idx_stride_h,
seq_len: tl.constexpr,
head_dim: tl.constexpr,
context_length: tl.constexpr,
num_frame: tl.constexpr,
frame_size: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
# Copy hidden_states to output
# range: [b, h, block_id * block_size: block_id * block_size + block_size, :]
cfg = tl.program_id(0)
head = tl.program_id(1)
block_id = tl.program_id(2)
start_id = block_id * BLOCK_SIZE
end_id = start_id + BLOCK_SIZE
end_id = tl.where(end_id > seq_len, seq_len, end_id)
# Load best mask idx (0 is spatial, 1 is temporal)
is_temporal = tl.load(best_mask_idx_ptr + cfg * mask_idx_stride_b + head * mask_idx_stride_h)
offset_token = tl.arange(0, BLOCK_SIZE) + start_id
offset_mask = offset_token < seq_len
offset_d = tl.arange(0, head_dim)
if is_temporal:
patch_id = offset_token // num_frame
frame_id = offset_token - patch_id * num_frame
offset_store_token = tl.where(offset_token >= seq_len - context_length, offset_token, frame_id * frame_size + patch_id)
offset_load = (cfg * hidden_states_stride_b + head * hidden_states_stride_h + offset_token[:, None] * hidden_states_stride_s) + offset_d[None, :] * hidden_states_stride_d
offset_hidden_states = hidden_states_ptr + offset_load
offset_store = (cfg * hidden_states_stride_b + head * hidden_states_stride_h + offset_store_token[:, None] * hidden_states_stride_s) + offset_d[None, :] * hidden_states_stride_d
offset_hidden_states_out = hidden_states_out_ptr + offset_store
# Maybe tune the pipeline here
hidden_states = tl.load(offset_hidden_states, mask=offset_mask[:, None])
tl.store(offset_hidden_states_out, hidden_states, mask=offset_mask[:, None])
else:
offset_load = (cfg * hidden_states_stride_b + head * hidden_states_stride_h + offset_token[:, None] * hidden_states_stride_s) + offset_d[None, :] * hidden_states_stride_d
offset_hidden_states = hidden_states_ptr + offset_load
offset_store = offset_load
offset_hidden_states_out = hidden_states_out_ptr + offset_store
# Maybe tune the pipeline here
hidden_states = tl.load(offset_hidden_states, mask=offset_mask[:, None])
tl.store(offset_hidden_states_out, hidden_states, mask=offset_mask[:, None])
def wan_hidden_states_placement(hidden_states, hidden_states_out, best_mask_idx, context_length, num_frame, frame_size):
cfg, num_heads, seq_len, head_dim = hidden_states.shape
BLOCK_SIZE = 128
assert seq_len == context_length + num_frame * frame_size
grid = (cfg, num_heads, (seq_len + BLOCK_SIZE - 1) // BLOCK_SIZE)
wan_hidden_states_placement_kernel[grid](
hidden_states,
hidden_states_out,
best_mask_idx,
hidden_states.stride(0),
hidden_states.stride(1),
hidden_states.stride(2),
hidden_states.stride(3),
best_mask_idx.stride(0),
best_mask_idx.stride(1),
seq_len,
head_dim,
context_length,
num_frame,
frame_size,
BLOCK_SIZE,
)
return hidden_states_out
@triton.jit
def wan_sparse_head_placement_kernel(
query_ptr,
key_ptr,
value_ptr, # [cfg, num_heads, seq_len, head_dim] seq_len = context_length + num_frame * frame_size
query_out_ptr,
key_out_ptr,
value_out_ptr, # [cfg, num_heads, seq_len, head_dim]
best_mask_idx_ptr, # [cfg, num_heads]
query_stride_b,
query_stride_h,
query_stride_s,
query_stride_d,
mask_idx_stride_b,
mask_idx_stride_h,
seq_len: tl.constexpr,
head_dim: tl.constexpr,
context_length: tl.constexpr,
num_frame: tl.constexpr,
frame_size: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
# Copy query, key, value to output
# range: [b, h, block_id * block_size: block_id * block_size + block_size, :]
cfg = tl.program_id(0)
head = tl.program_id(1)
block_id = tl.program_id(2)
start_id = block_id * BLOCK_SIZE
end_id = start_id + BLOCK_SIZE
end_id = tl.where(end_id > seq_len, seq_len, end_id)
# Load best mask idx (0 is spatial, 1 is temporal)
is_temporal = tl.load(best_mask_idx_ptr + cfg * mask_idx_stride_b + head * mask_idx_stride_h)
offset_token = tl.arange(0, BLOCK_SIZE) + start_id
offset_mask = offset_token < seq_len
offset_d = tl.arange(0, head_dim)
if is_temporal:
frame_id = offset_token // frame_size
patch_id = offset_token - frame_id * frame_size
offset_store_token = tl.where(offset_token >= seq_len - context_length, offset_token, patch_id * num_frame + frame_id)
offset_load = (cfg * query_stride_b + head * query_stride_h + offset_token[:, None] * query_stride_s) + offset_d[None, :] * query_stride_d
offset_query = query_ptr + offset_load
offset_key = key_ptr + offset_load
offset_value = value_ptr + offset_load
offset_store = (cfg * query_stride_b + head * query_stride_h + offset_store_token[:, None] * query_stride_s) + offset_d[None, :] * query_stride_d
offset_query_out = query_out_ptr + offset_store
offset_key_out = key_out_ptr + offset_store
offset_value_out = value_out_ptr + offset_store
# Maybe tune the pipeline here
query = tl.load(offset_query, mask=offset_mask[:, None])
tl.store(offset_query_out, query, mask=offset_mask[:, None])
key = tl.load(offset_key, mask=offset_mask[:, None])
tl.store(offset_key_out, key, mask=offset_mask[:, None])
value = tl.load(offset_value, mask=offset_mask[:, None])
tl.store(offset_value_out, value, mask=offset_mask[:, None])
else:
offset_load = (cfg * query_stride_b + head * query_stride_h + offset_token[:, None] * query_stride_s) + offset_d[None, :] * query_stride_d
offset_query = query_ptr + offset_load
offset_key = key_ptr + offset_load
offset_value = value_ptr + offset_load
offset_store = offset_load
offset_query_out = query_out_ptr + offset_store
offset_key_out = key_out_ptr + offset_store
offset_value_out = value_out_ptr + offset_store
# Maybe tune the pipeline here
query = tl.load(offset_query, mask=offset_mask[:, None])
tl.store(offset_query_out, query, mask=offset_mask[:, None])
key = tl.load(offset_key, mask=offset_mask[:, None])
tl.store(offset_key_out, key, mask=offset_mask[:, None])
value = tl.load(offset_value, mask=offset_mask[:, None])
tl.store(offset_value_out, value, mask=offset_mask[:, None])
def wan_sparse_head_placement(query, key, value, query_out, key_out, value_out, best_mask_idx, context_length, num_frame, frame_size):
cfg, num_heads, seq_len, head_dim = query.shape
BLOCK_SIZE = 128
assert seq_len == context_length + num_frame * frame_size
grid = (cfg, num_heads, (seq_len + BLOCK_SIZE - 1) // BLOCK_SIZE)
wan_sparse_head_placement_kernel[grid](
query,
key,
value,
query_out,
key_out,
value_out,
best_mask_idx,
query.stride(0),
query.stride(1),
query.stride(2),
query.stride(3),
best_mask_idx.stride(0),
best_mask_idx.stride(1),
seq_len,
head_dim,
context_length,
num_frame,
frame_size,
BLOCK_SIZE,
)
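# A small CPU-only check (illustrative sizes) of the index math used by the two
# placement kernels above: wan_sparse_head_placement moves video tokens from
# frame-major order (frame_id * frame_size + patch_id) to patch-major order
# (patch_id * num_frame + frame_id), and wan_hidden_states_placement applies the
# inverse map, so the two together round-trip back to the original layout.
if __name__ == "__main__":
    num_frame, frame_size = 3, 4
    tokens = torch.arange(num_frame * frame_size)
    frame_id, patch_id = tokens // frame_size, tokens % frame_size
    reordered = patch_id * num_frame + frame_id  # temporal head placement
    inv_patch, inv_frame = reordered // num_frame, reordered % num_frame
    restored = inv_frame * frame_size + inv_patch  # hidden-states placement
    assert torch.equal(restored, tokens)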
def generate_temporal_head_mask_mod(context_length: int = 226, prompt_length: int = 226, num_frames: int = 13, token_per_frame: int = 1350, mul: int = 2):
def round_to_multiple(idx):
return ceil(idx / 128) * 128
def temporal_mask_mod(b, h, q_idx, kv_idx):
two_frame = round_to_multiple(mul * token_per_frame)
temporal_head_mask = torch.abs(q_idx - kv_idx) <= two_frame
# return temporal_head_mask
first_frame_mask = kv_idx < token_per_frame
video_mask = first_frame_mask | temporal_head_mask
return video_mask
return temporal_mask_mod
@lru_cache
def create_block_mask_cached(score_mod, B, H, M, N, device="cuda", _compile=False):
block_mask = create_block_mask(score_mod, B, H, M, N, device=device, _compile=_compile)
return block_mask
def prepare_flexattention(cfg_size, num_head, head_dim, dtype, device, context_length, prompt_length, num_frame, frame_size, diag_width=1, multiplier=2):
    assert diag_width == multiplier, f"diag_width ({diag_width}) must equal multiplier ({multiplier})"
seq_len = context_length + num_frame * frame_size
mask_mod = generate_temporal_head_mask_mod(context_length, prompt_length, num_frame, frame_size, mul=multiplier)
block_mask = create_block_mask_cached(mask_mod, None, None, seq_len, seq_len, device=device, _compile=True)
return block_mask
def sparsity_to_width(sparsity, context_length, num_frame, frame_size):
seq_len = context_length + num_frame * frame_size
total_elements = seq_len**2
sparsity = (sparsity * total_elements - 2 * seq_len * context_length) / total_elements
width = seq_len * (1 - math.sqrt(1 - sparsity))
width_frame = width / frame_size
return width_frame
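# Worked example (illustrative numbers) for sparsity_to_width: with no text
# context, a target sparsity of 0.75 over a 21-frame, 1560-token-per-frame video
# gives seq_len * (1 - sqrt(1 - 0.75)) / frame_size = 0.5 * 21 = 10.5 frames of
# band width, which is then passed as diag_width/multiplier to prepare_flexattention.
if __name__ == "__main__":
    print(sparsity_to_width(sparsity=0.75, context_length=0, num_frame=21, frame_size=1560))  # 10.5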
def get_attention_mask(mask_name, sample_mse_max_row, context_length, num_frame, frame_size):
attention_mask = torch.zeros((context_length + num_frame * frame_size, context_length + num_frame * frame_size), device="cpu")
# TODO: fix hard coded mask
if mask_name == "spatial":
pixel_attn_mask = torch.zeros_like(attention_mask, dtype=torch.bool, device="cpu")
pixel_attn_mask[:, :frame_size] = 1 # First Frame Sink
block_size, block_thres = 128, frame_size * 2
num_block = math.ceil(num_frame * frame_size / block_size)
for i in range(num_block):
for j in range(num_block):
if abs(i - j) < block_thres // block_size:
pixel_attn_mask[i * block_size : (i + 1) * block_size, j * block_size : (j + 1) * block_size] = 1
attention_mask = pixel_attn_mask
else:
pixel_attn_mask = torch.zeros_like(attention_mask, dtype=torch.bool, device="cpu")
pixel_attn_mask[:, :frame_size] = 1 # First Frame Sink
block_size, block_thres = 128, frame_size * 2
num_block = math.ceil(num_frame * frame_size / block_size)
for i in range(num_block):
for j in range(num_block):
if abs(i - j) < block_thres // block_size:
pixel_attn_mask[i * block_size : (i + 1) * block_size, j * block_size : (j + 1) * block_size] = 1
pixel_attn_mask = pixel_attn_mask.reshape(frame_size, num_frame, frame_size, num_frame).permute(1, 0, 3, 2).reshape(frame_size * num_frame, frame_size * num_frame)
attention_mask = pixel_attn_mask
attention_mask = attention_mask[:sample_mse_max_row].cuda()
return attention_mask
def diagonal_band_mask_from_sparsity(
block_num: int,
block_num_per_frame: int,
sparsity: float,
device="cpu",
):
k = int(round(block_num * (1 - sparsity) / 2))
k = max(0, min(k, block_num - 1))
idx = torch.arange(block_num, device=device)
mask = torch.abs(idx[:, None] - idx[None, :]) <= k
sink = idx[None, :] <= block_num_per_frame
mask = mask | sink
actual_sparsity = 1 - mask.float().mean().item()
logger.info(f"Diagonal Band Mask: block_num={block_num}, block_num_per_frame={block_num_per_frame}, sparsity={sparsity}, actual_sparsity={actual_sparsity}")
return mask
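# Small CPU sketch (illustrative numbers) of diagonal_band_mask_from_sparsity:
# keeps a band of +/- k block diagonals plus a column "sink" for the first
# frame's blocks, and logs the sparsity that was actually achieved.
if __name__ == "__main__":
    band_mask = diagonal_band_mask_from_sparsity(block_num=16, block_num_per_frame=2, sparsity=0.75, device="cpu")
    print(band_mask.shape, 1 - band_mask.float().mean().item())  # (16, 16) and the achieved sparsity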
@ATTN_WEIGHT_REGISTER("svg_attn")
class SvgAttnWeight(AttnWeightTemplate):
head_num = None
head_dim = None
sample_mse_max_row = None
num_sampled_rows = None
context_length = None
attnmap_frame_num = None
seqlen = None
sparsity = None
mask_name_list = ["spatial", "temporal"]
attention_masks = None
block_mask = None
@classmethod
def prepare(cls, head_num, head_dim, sample_mse_max_row, num_sampled_rows, context_length, sparsity):
cls.head_num = head_num
cls.head_dim = head_dim
cls.sample_mse_max_row = sample_mse_max_row
cls.num_sampled_rows = num_sampled_rows
cls.context_length = context_length
cls.sparsity = sparsity
torch._dynamo.config.cache_size_limit = 192 * 3
torch._dynamo.config.accumulated_cache_size_limit = 192 * 3
logger.info(
f"SvgAttnWeight Prepare: head_num={head_num}, head_dim={head_dim}, sample_mse_max_row={sample_mse_max_row}, num_sampled_rows={num_sampled_rows}, context_length={context_length}, sparsity={sparsity}"
)
def __init__(self):
self.config = {}
self.sparse_attention = torch.compile(flex_attention, dynamic=False, mode="max-autotune-no-cudagraphs")
@classmethod
def prepare_mask(cls, seqlen):
# Use class attributes so updates affect all instances of this class
if seqlen == cls.seqlen:
return
frame_size = seqlen // cls.attnmap_frame_num
cls.attention_masks = [get_attention_mask(mask_name, cls.sample_mse_max_row, cls.context_length, cls.attnmap_frame_num, frame_size) for mask_name in cls.mask_name_list]
multiplier = diag_width = sparsity_to_width(cls.sparsity, cls.context_length, cls.attnmap_frame_num, frame_size)
cls.block_mask = prepare_flexattention(
1, cls.head_num, cls.head_dim, torch.bfloat16, "cuda", cls.context_length, cls.context_length, cls.attnmap_frame_num, frame_size, diag_width=diag_width, multiplier=multiplier
)
cls.seqlen = seqlen
logger.info(f"SvgAttnWeight Update: seqlen={seqlen}")
def apply(
self,
q,
k,
v,
cu_seqlens_q=None,
cu_seqlens_kv=None,
max_seqlen_q=None,
max_seqlen_kv=None,
**kwargs,
):
q = q.unsqueeze(0).transpose(1, 2)
k = k.unsqueeze(0).transpose(1, 2)
v = v.unsqueeze(0).transpose(1, 2)
bs, num_heads, seq_len, dim = q.size()
self.prepare_mask(seq_len)
sampled_mses = self.sample_mse(q, k, v)
best_mask_idx = torch.argmin(sampled_mses, dim=0)
output_hidden_states = torch.zeros_like(q)
query_out, key_out, value_out = torch.zeros_like(q), torch.zeros_like(k), torch.zeros_like(v)
query_out, key_out, value_out = self.fast_sparse_head_placement(
q, k, v, query_out, key_out, value_out, best_mask_idx, self.context_length, self.attnmap_frame_num, seq_len // self.attnmap_frame_num
)
hidden_states = self.sparse_attention(query_out, key_out, value_out, block_mask=self.block_mask)
wan_hidden_states_placement(hidden_states, output_hidden_states, best_mask_idx, self.context_length, self.attnmap_frame_num, seq_len // self.attnmap_frame_num)
return output_hidden_states.reshape(bs, num_heads, seq_len, dim).transpose(1, 2).reshape(bs * seq_len, -1)
def fast_sparse_head_placement(self, query, key, value, query_out, key_out, value_out, best_mask_idx, context_length, num_frame, frame_size):
wan_sparse_head_placement(query, key, value, query_out, key_out, value_out, best_mask_idx, context_length, num_frame, frame_size)
return query_out, key_out, value_out
def sample_mse(self, query, key, value):
cfg, num_heads, seq_len, dim = query.size()
num_sampled_rows = min(self.num_sampled_rows, seq_len)
sampled_rows = torch.randint(low=0, high=self.sample_mse_max_row, size=(num_sampled_rows,))
sampled_q = query[:, :, sampled_rows, :]
sampled_qk_scores = torch.matmul(sampled_q, key.transpose(-2, -1)) / (dim**0.5)
sampled_attn_weights = F.softmax(sampled_qk_scores, dim=-1)
sampled_golden_hidden_states = torch.matmul(sampled_attn_weights, value) # (1, seq_len, dim)
sampled_mses = torch.zeros(len(self.attention_masks), cfg, num_heads, device=query.device, dtype=query.dtype)
# Only have Tri-diagonal and Striped
for mask_idx, attn_mask in enumerate(self.attention_masks):
sampled_attention_mask = attn_mask[sampled_rows, :]
sampled_attention_scores = sampled_qk_scores.masked_fill(sampled_attention_mask == 0, float("-inf"))
sampled_attn_weights = F.softmax(sampled_attention_scores, dim=-1)
sampled_hidden_states = torch.matmul(sampled_attn_weights, value)
mse = torch.mean((sampled_hidden_states - sampled_golden_hidden_states) ** 2, dim=(2, 3))
sampled_mses[mask_idx] = mse
return sampled_mses
if __name__ == "__main__":
q, k, v = torch.randn(32130, 40, 128, dtype=torch.bfloat16).cuda(), torch.randn(32130, 40, 128, dtype=torch.bfloat16).cuda(), torch.randn(32130, 40, 128, dtype=torch.bfloat16).cuda()
    SvgAttnWeight.prepare(head_num=40, head_dim=128, sample_mse_max_row=10000, num_sampled_rows=64, context_length=0, sparsity=0.25)
    SvgAttnWeight.attnmap_frame_num = 21  # attnmap_frame_num must be set before apply(); 21 frames x 1530 tokens matches seq_len 32130 here
    svg_attn = SvgAttnWeight()
print("SvgAttnWeight initialized.")
out = svg_attn.apply(q, k, v)
print(f"out: {out.shape}, {out.dtype}, {out.device}")
from abc import ABCMeta, abstractmethod
class AttnWeightTemplate(metaclass=ABCMeta):
def __init__(self, weight_name):
self.weight_name = weight_name
self.config = {}
def load(self, weight_dict):
pass
@abstractmethod
def apply(self, input_tensor):
pass
def set_config(self, config=None):
if config is not None:
self.config = config
def to_cpu(self, non_blocking=False):
pass
def to_cuda(self, non_blocking=False):
pass
def state_dict(self, destination=None):
if destination is None:
destination = {}
return destination
    def load_state_dict(self, destination, block_index, adapter_block_index=None):
        return {}
    def load_state_dict_from_disk(self, block_index, adapter_block_index=None):
pass
import torch
import torch.nn.functional as F
from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER
from .template import AttnWeightTemplate
@ATTN_WEIGHT_REGISTER("torch_sdpa")
class TorchSDPAWeight(AttnWeightTemplate):
def __init__(self):
self.config = {}
def apply(
self,
q,
k,
v,
drop_rate=0,
attn_mask=None,
causal=False,
cu_seqlens_q=None,
cu_seqlens_kv=None,
max_seqlen_q=None,
max_seqlen_kv=None,
**kwargs,
):
if q.ndim == 3:
q, k, v = q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0)
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
if attn_mask is not None and attn_mask.dtype != torch.bool:
attn_mask = attn_mask.to(q.dtype)
x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal)
x = x.transpose(1, 2)
b, s, a, d = x.shape
out = x.reshape(b, s, -1)
return out.squeeze(0)
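# Minimal usage sketch (CPU, illustrative shapes): TorchSDPAWeight accepts
# unbatched (seq_len, heads, head_dim) tensors and returns (seq_len, heads * head_dim).
if __name__ == "__main__":
    attn = TorchSDPAWeight()
    q = torch.randn(16, 4, 8)
    k = torch.randn(16, 4, 8)
    v = torch.randn(16, 4, 8)
    print(attn.apply(q, k, v).shape)  # torch.Size([16, 32])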
import torch
import torch.distributed as dist
from loguru import logger
from lightx2v.utils.quant_utils import dequant_fp8_vllm, quant_fp8_vllm
from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER
from lightx2v_platform.base.global_var import AI_DEVICE
from .template import AttnWeightTemplate
from .utils.all2all import all2all_head2seq
@ATTN_WEIGHT_REGISTER("ulysses")
class UlyssesAttnWeight(AttnWeightTemplate):
def __init__(self):
self.config = {}
def apply(
self,
q,
k,
v,
slice_qkv_len,
cu_seqlens_qkv,
attention_module=None,
attention_type="flash_attn2",
seq_p_group=None,
use_fp8_comm=False,
enable_head_parallel=False,
img_first=True,
**kwargs,
):
"""
执行 Ulysses 注意力机制,结合图像和文本的查询、键和值。
参数:
q (torch.Tensor): 查询张量,形状为 [shard_seqlen, heads, hidden_dims]
k (torch.Tensor): 键张量,形状为 [shard_seqlen, heads, hidden_dims]
v (torch.Tensor): 值张量,形状为 [shard_seqlen, heads, hidden_dims]
slice_qkv_len (int): 图像或者文本查询、键和值的长度,根据img_first确定谁在前半部分
cu_seqlens_qkv (torch.Tensor): 累积序列长度,包含文本和图像的长度信息
attention_type (str): 注意力类型,默认为 "flash_attn2"
返回:
torch.Tensor: 计算得到的注意力结果
"""
if len(q.shape) == 4:
q = q.reshape(-1, q.shape[-2], q.shape[-1])
k = k.reshape(-1, k.shape[-2], k.shape[-1])
v = v.reshape(-1, v.shape[-2], v.shape[-1])
        # Get the current process rank and the world size
world_size = dist.get_world_size(seq_p_group)
cur_rank = dist.get_rank(seq_p_group)
        # Get the sequence lengths for the image and text parts
if img_first:
img_qkv_len = slice_qkv_len
if len(cu_seqlens_qkv) == 3:
                txt_qkv_len = cu_seqlens_qkv[1] - slice_qkv_len  # length of the text q/k/v
                txt_mask_len = cu_seqlens_qkv[2] - slice_qkv_len  # length of the text mask
elif len(cu_seqlens_qkv) == 2:
                txt_qkv_len = cu_seqlens_qkv[1] - slice_qkv_len  # length of the text q/k/v
txt_mask_len = None
else:
# assert len(cu_seqlens_qkv) == 2
txt_qkv_len = slice_qkv_len
img_qkv_len = cu_seqlens_qkv[1] - slice_qkv_len
txt_mask_len = None
        # Get the number of heads and the hidden dim of the query tensor
        _, heads, hidden_dims = q.shape
        shard_heads = heads // world_size  # number of heads handled by each process
        shard_seqlen = img_qkv_len  # sequence length handled by each process
        global_img_seqlen = shard_seqlen * world_size  # global image sequence length
        # Build the cumulative sequence length tensor
        cu_seqlens_qkv = torch.zeros([2], dtype=torch.int32)
        s = txt_qkv_len + global_img_seqlen  # total length of text plus image
        s1 = s  # end position of the current sample
        cu_seqlens_qkv[1] = s1  # cumulative sequence length
if txt_mask_len:
            s2 = txt_mask_len + global_img_seqlen  # end position of the text mask
cu_seqlens_qkv = torch.cat((cu_seqlens_qkv, torch.tensor([s2], dtype=torch.int32)))
if attention_type == "flash_attn2" or attention_type == "flash_attn3":
cu_seqlens_qkv = cu_seqlens_qkv.to(AI_DEVICE, non_blocking=True)
        max_seqlen_qkv = global_img_seqlen + txt_qkv_len  # maximum sequence length
        # Split the image and text queries, keys and values
if img_first:
img_q, img_k, img_v = q[:img_qkv_len, :, :].contiguous(), k[:img_qkv_len, :, :].contiguous(), v[:img_qkv_len, :, :].contiguous()
txt_q, txt_k, txt_v = q[img_qkv_len:, :, :].contiguous(), k[img_qkv_len:, :, :].contiguous(), v[img_qkv_len:, :, :].contiguous()
else:
txt_q, txt_k, txt_v = q[:txt_qkv_len, :, :].contiguous(), k[:txt_qkv_len, :, :].contiguous(), v[:txt_qkv_len, :, :].contiguous()
img_q, img_k, img_v = q[txt_qkv_len:, :, :].contiguous(), k[txt_qkv_len:, :, :].contiguous(), v[txt_qkv_len:, :, :].contiguous()
img_qkv = torch.stack([img_q, img_k, img_v], dim=0).reshape(3, img_qkv_len, world_size, shard_heads, hidden_dims)
original_dtype = img_qkv.dtype
if enable_head_parallel:
img_qkv = img_qkv.permute(3, 2, 1, 0, 4).contiguous() # (shard_heads, world_size, img_qkv_len, 3, hidden_dims)
output_qkv = torch.empty_like(img_qkv)
            # Exchange the image q/k/v across the sequence-parallel group
if use_fp8_comm:
img_qkv_fp8, img_qkv_scale = quant_fp8_vllm(img_qkv.reshape(-1, hidden_dims))
img_qkv_fp8 = img_qkv_fp8.reshape(shard_heads, world_size, img_qkv_len, 3, hidden_dims)
img_qkv_scale = img_qkv_scale.reshape(shard_heads, world_size, img_qkv_len, 3, 1)
output_qkv_fp8 = torch.empty_like(img_qkv_fp8)
output_qkv_scale = torch.empty_like(img_qkv_scale)
comm_fp8_works = []
comm_scale_works = []
for h in range(shard_heads):
work_fp8 = dist.all_to_all_single(output_qkv_fp8[h], img_qkv_fp8[h], group=seq_p_group, async_op=True)
work_scale = dist.all_to_all_single(output_qkv_scale[h], img_qkv_scale[h], group=seq_p_group, async_op=True)
comm_fp8_works.append(work_fp8)
comm_scale_works.append(work_scale)
else:
comm_works = []
for h in range(shard_heads):
work = dist.all_to_all_single(output_qkv[h], img_qkv[h], group=seq_p_group, async_op=True)
comm_works.append(work)
            # Run attention head by head
single_head = 1
head_attns = []
for h in range(shard_heads):
if use_fp8_comm:
comm_fp8_works[h].wait()
comm_scale_works[h].wait()
output_qkv[h] = dequant_fp8_vllm(output_qkv_fp8[h], output_qkv_scale[h], original_dtype)
else:
comm_works[h].wait()
qkv = output_qkv[h].reshape(global_img_seqlen, 3, single_head, hidden_dims).transpose(0, 1)
shard_img_q = qkv[0] # (global_img_seqlen, single_head, hidden_dims)
shard_img_k = qkv[1]
shard_img_v = qkv[2]
                # Select this process's current head from the text q/k/v
shard_txt_q = txt_q[:, (cur_rank * shard_heads + h) : (cur_rank * shard_heads + h + 1), :]
shard_txt_k = txt_k[:, (cur_rank * shard_heads + h) : (cur_rank * shard_heads + h + 1), :]
shard_txt_v = txt_v[:, (cur_rank * shard_heads + h) : (cur_rank * shard_heads + h + 1), :]
                # Concatenate the image and text q/k/v
if img_first:
q = torch.cat((shard_img_q, shard_txt_q), dim=0)
k = torch.cat((shard_img_k, shard_txt_k), dim=0)
v = torch.cat((shard_img_v, shard_txt_v), dim=0)
else:
q = torch.cat((shard_txt_q, shard_img_q), dim=0)
k = torch.cat((shard_txt_k, shard_img_k), dim=0)
v = torch.cat((shard_txt_v, shard_img_v), dim=0)
                # Compute attention for this head
head_attn = attention_module.apply(q=q, k=k, v=v, cu_seqlens_q=cu_seqlens_qkv, cu_seqlens_kv=cu_seqlens_qkv, max_seqlen_q=max_seqlen_qkv, max_seqlen_kv=max_seqlen_qkv, **kwargs)
head_attns.append(head_attn)
            # Concatenate the attention outputs of all heads on this process
attn = torch.cat(head_attns, dim=1)
else:
img_qkv = img_qkv.permute(2, 1, 0, 3, 4).contiguous() # (world_size, img_qkv_len, 3, shard_heads, hidden_dims)
            # Exchange the image q/k/v across the sequence-parallel group
if use_fp8_comm:
img_qkv_fp8, img_qkv_scale = quant_fp8_vllm(img_qkv.reshape(-1, hidden_dims))
img_qkv_fp8 = img_qkv_fp8.reshape(world_size, img_qkv_len, shard_heads, 3, hidden_dims)
img_qkv_scale = img_qkv_scale.reshape(world_size, img_qkv_len, shard_heads, 3, 1)
output_qkv_fp8 = torch.empty_like(img_qkv_fp8)
output_qkv_scale = torch.empty_like(img_qkv_scale)
dist.all_to_all_single(output_qkv_fp8, img_qkv_fp8, group=seq_p_group)
dist.all_to_all_single(output_qkv_scale, img_qkv_scale, group=seq_p_group)
output_qkv = dequant_fp8_vllm(output_qkv_fp8, output_qkv_scale, original_dtype)
else:
output_qkv = torch.empty_like(img_qkv)
dist.all_to_all_single(output_qkv, img_qkv, group=seq_p_group)
            # Compute attention
qkv = output_qkv.reshape(global_img_seqlen, 3, shard_heads, hidden_dims).transpose(0, 1)
shard_img_q = qkv[0] # (global_img_seqlen, shard_head, hidden_dims)
shard_img_k = qkv[1]
shard_img_v = qkv[2]
            # Select this process's heads from the text q/k/v
shard_txt_q = txt_q[:, cur_rank * shard_heads : (cur_rank + 1) * shard_heads, :]
shard_txt_k = txt_k[:, cur_rank * shard_heads : (cur_rank + 1) * shard_heads, :]
shard_txt_v = txt_v[:, cur_rank * shard_heads : (cur_rank + 1) * shard_heads, :]
            # Concatenate the image and text q/k/v
if img_first:
q = torch.cat((shard_img_q, shard_txt_q), dim=0)
k = torch.cat((shard_img_k, shard_txt_k), dim=0)
v = torch.cat((shard_img_v, shard_txt_v), dim=0)
else:
q = torch.cat((shard_txt_q, shard_img_q), dim=0)
k = torch.cat((shard_txt_k, shard_img_k), dim=0)
v = torch.cat((shard_txt_v, shard_img_v), dim=0)
            # Compute the attention result
attn = attention_module.apply(q=q, k=k, v=v, cu_seqlens_q=cu_seqlens_qkv, cu_seqlens_kv=cu_seqlens_qkv, max_seqlen_q=max_seqlen_qkv, max_seqlen_kv=max_seqlen_qkv, **kwargs)
        # Split the attention output into image and text parts
if img_first:
img_attn, txt_attn = attn[:global_img_seqlen, :], attn[global_img_seqlen:]
else:
txt_attn, img_attn = attn[:txt_qkv_len, :], attn[txt_qkv_len:]
        # Redistribute the image attention output across the sequence-parallel group
        img_attn = self._reshape_img_attn(img_attn, world_size, shard_seqlen, shard_heads, hidden_dims, seq_p_group, use_fp8_comm)
        # Gather the text attention output from all processes
        gathered_txt_attn = [torch.empty_like(txt_attn) for _ in range(world_size)]
        dist.all_gather(gathered_txt_attn, txt_attn, group=seq_p_group)
        txt_attn = torch.cat(gathered_txt_attn, dim=1)  # concatenate the text attention from all processes
        # Concatenate the image and text attention outputs
if img_first:
attn = torch.cat([img_attn, txt_attn], dim=0)
else:
attn = torch.cat([txt_attn, img_attn], dim=0)
        return attn  # final attention output
@torch.compiler.disable
def _reshape_img_attn(self, img_attn, world_size, shard_seqlen, shard_heads, hidden_dims, seq_p_group, use_fp8_comm):
        img_attn = img_attn.reshape(world_size * shard_seqlen, shard_heads, hidden_dims)  # reshape the image attention output
        # Convert from head-sharded back to sequence-sharded layout
if use_fp8_comm:
original_dtype = img_attn.dtype
original_shape = img_attn.shape
img_attn_fp8, attn_scale = quant_fp8_vllm(img_attn.reshape(-1, original_shape[-1]))
img_attn_fp8 = all2all_head2seq(img_attn_fp8.reshape(original_shape), group=seq_p_group)
attn_scale = all2all_head2seq(attn_scale.reshape(original_shape[0], original_shape[1], 1), group=seq_p_group)
img_attn = dequant_fp8_vllm(img_attn_fp8, attn_scale, original_dtype)
else:
img_attn = all2all_head2seq(img_attn, group=seq_p_group)
        img_attn = img_attn.reshape(shard_seqlen, -1)  # reshape to [shard_seqlen, -1]
return img_attn
@ATTN_WEIGHT_REGISTER("ulysses-4090")
class Ulysses4090AttnWeight(AttnWeightTemplate):
def __init__(self):
self.config = {}
self.rounds = []
def generate_round_robin_pairs(self, seq_p_group=None):
"""
生成循环赛配对表,并确保每个配对中的第一个元素小于第二个
这样我们可以用简单的规则确定通信顺序
"""
cur_rank = dist.get_rank(seq_p_group)
world_size = dist.get_world_size(seq_p_group)
if world_size % 2 != 0:
            raise ValueError("world_size must be even; odd world sizes need special handling")
teams = list(range(world_size))
for _ in range(world_size - 1):
round_schedule = {}
for i in range(world_size // 2):
team1, team2 = teams[i], teams[world_size - 1 - i]
smaller, larger = min(team1, team2), max(team1, team2)
round_schedule[smaller] = (larger, True)
round_schedule[larger] = (smaller, False)
self.rounds.append(round_schedule)
            # Rotate the list (keep the first element fixed)
teams = [teams[0]] + [teams[-1]] + teams[1:-1]
# if cur_rank == 0:
# self.print_pairing_schedule(seq_p_group)
def print_pairing_schedule(self, seq_p_group):
"""打印通信调度表"""
world_size = dist.get_world_size(seq_p_group)
logger.info("循环赛通信调度表:")
logger.info("=" * 50)
for i, round_schedule in enumerate(self.rounds):
logger.info(f"第 {i + 1} 轮:")
for cur_rank in range(world_size):
partner, is_smaller_in_pair = round_schedule[cur_rank]
logger.info(f" 进程 {cur_rank} ←→ 进程 {partner}")
logger.info("=" * 50)
def load_balanced_all_to_all(self, shards, seq_p_group=None):
"""
负载均衡all-to-all通信实现
"""
world_size = dist.get_world_size(seq_p_group)
cur_rank = dist.get_rank(seq_p_group)
global_rank = dist.get_global_rank(seq_p_group, cur_rank)
cfg_p_group_index = global_rank // world_size
        # Prepare the receive buffers
gathered_shards = [None] * world_size
for target_rank in range(world_size):
if target_rank != cur_rank:
gathered_shards[target_rank] = torch.empty_like(shards[target_rank])
else:
gathered_shards[cur_rank] = shards[cur_rank]
for i, round_schedule in enumerate(self.rounds):
            # Find this process's partner in this round
partner = None
is_smaller_in_pair = False
if cur_rank in round_schedule:
partner, is_smaller_in_pair = round_schedule[cur_rank]
            # No partner found means this process is idle this round
if partner is None:
continue
            # Compute the partner's global rank
partner_global_rank = cfg_p_group_index * world_size + partner
if is_smaller_in_pair:
                # This process is the smaller rank in the pair: send first, then receive
send_req = dist.isend(shards[partner], dst=partner_global_rank, group=seq_p_group)
recv_req = dist.irecv(gathered_shards[partner], src=partner_global_rank, group=seq_p_group)
send_req.wait()
recv_req.wait()
else:
                # This process is the larger rank in the pair: receive first, then send
recv_req = dist.irecv(gathered_shards[partner], src=partner_global_rank, group=seq_p_group)
send_req = dist.isend(shards[partner], dst=partner_global_rank, group=seq_p_group)
recv_req.wait()
send_req.wait()
return gathered_shards
def apply(
self,
q,
k,
v,
slice_qkv_len,
cu_seqlens_qkv,
attention_module=None,
attention_type="flash_attn2",
seq_p_group=None,
use_fp8_comm=False,
enable_head_parallel=False,
img_first=True,
**kwargs,
):
"""
执行 Ulysses 注意力机制,结合图像和文本的查询、键和值。
参数:
q (torch.Tensor): 查询张量,形状为 [shard_seqlen, heads, hidden_dims]
k (torch.Tensor): 键张量,形状为 [shard_seqlen, heads, hidden_dims]
v (torch.Tensor): 值张量,形状为 [shard_seqlen, heads, hidden_dims]
slice_qkv_len (int): 图像或者文本查询、键和值的长度,根据img_first确定谁在前半部分
cu_seqlens_qkv (torch.Tensor): 累积序列长度,包含文本和图像的长度信息
attention_type (str): 注意力类型,默认为 "flash_attn2"
返回:
torch.Tensor: 计算得到的注意力结果
"""
assert not enable_head_parallel, "Ulysses-4090 can't support head parallel mode."
if len(self.rounds) == 0:
self.generate_round_robin_pairs(seq_p_group)
if len(q.shape) == 4:
q = q.reshape(-1, q.shape[-2], q.shape[-1])
k = k.reshape(-1, k.shape[-2], k.shape[-1])
v = v.reshape(-1, v.shape[-2], v.shape[-1])
        # Get the current process rank and the world size
world_size = dist.get_world_size(seq_p_group)
cur_rank = dist.get_rank(seq_p_group)
global_world_size = dist.get_world_size()
global_rank = dist.get_global_rank(seq_p_group, cur_rank)
cfg_p_group_index = global_rank // world_size
        # Get the sequence lengths for the image and text parts
if img_first:
img_qkv_len = slice_qkv_len
if len(cu_seqlens_qkv) == 3:
                txt_qkv_len = cu_seqlens_qkv[1] - slice_qkv_len  # length of the text q/k/v
                txt_mask_len = cu_seqlens_qkv[2] - slice_qkv_len  # length of the text mask
elif len(cu_seqlens_qkv) == 2:
                txt_qkv_len = cu_seqlens_qkv[1] - slice_qkv_len  # length of the text q/k/v
txt_mask_len = None
else:
# assert len(cu_seqlens_qkv) == 2
txt_qkv_len = slice_qkv_len
img_qkv_len = cu_seqlens_qkv[1] - slice_qkv_len
txt_mask_len = None
        # Get the number of heads and the hidden dim of the query tensor
        _, heads, hidden_dims = q.shape
        shard_heads = heads // world_size  # number of heads handled by each process
        shard_seqlen = img_qkv_len  # sequence length handled by each process
        # Split the image and text queries, keys and values
if img_first:
img_q, img_k, img_v = q[:img_qkv_len, :, :].contiguous(), k[:img_qkv_len, :, :].contiguous(), v[:img_qkv_len, :, :].contiguous()
txt_q, txt_k, txt_v = q[img_qkv_len:, :, :].contiguous(), k[img_qkv_len:, :, :].contiguous(), v[img_qkv_len:, :, :].contiguous()
else:
txt_q, txt_k, txt_v = q[:txt_qkv_len, :, :].contiguous(), k[:txt_qkv_len, :, :].contiguous(), v[:txt_qkv_len, :, :].contiguous()
img_q, img_k, img_v = q[txt_qkv_len:, :, :].contiguous(), k[txt_qkv_len:, :, :].contiguous(), v[txt_qkv_len:, :, :].contiguous()
        # Number of heads each process should hold
num_heads = img_q.shape[1]
shard_heads = num_heads // world_size
        # Stack the image QKV and split it into N shards along the head dimension, each holding heads/N heads
img_qkv = torch.stack([img_q, img_k, img_v], dim=0)
qkv_shards = [img_qkv[:, :, i * shard_heads : (i + 1) * shard_heads, :].contiguous() for i in range(world_size)]
qkv_dtype = img_qkv.dtype
if use_fp8_comm:
qkv_fp8_byte_tensors = []
qkv_fp8_bytes = 0
qkv_fp8_dtype = None
qkv_scale_dtype = None
for i in range(world_size):
qkv_fp8, qkv_scale = quant_fp8_vllm(qkv_shards[i].reshape(-1, hidden_dims))
if i == 0:
qkv_fp8_bytes = qkv_fp8.numel() * qkv_fp8.element_size()
qkv_fp8_dtype = qkv_fp8.dtype
qkv_scale_dtype = qkv_scale.dtype
qkv_fp8_byte_tensors.append(torch.cat([qkv_fp8.contiguous().reshape(-1).view(torch.uint8), qkv_scale.contiguous().reshape(-1).view(torch.uint8)], dim=0))
gathered_qkv_fp8_byte_tensors = self.load_balanced_all_to_all(qkv_fp8_byte_tensors, seq_p_group)
gathered_q_shards = []
gathered_k_shards = []
gathered_v_shards = []
for i in range(world_size):
qkv_fp8_byte_tensor = gathered_qkv_fp8_byte_tensors[i]
qkv_fp8 = qkv_fp8_byte_tensor[:qkv_fp8_bytes].view(qkv_fp8_dtype).reshape(3, -1, hidden_dims)
qkv_scale = qkv_fp8_byte_tensor[qkv_fp8_bytes:].view(qkv_scale_dtype).reshape(3, -1, 1)
q_shards_new = dequant_fp8_vllm(qkv_fp8[0], qkv_scale[0], qkv_dtype).reshape(-1, shard_heads, hidden_dims)
k_shards_new = dequant_fp8_vllm(qkv_fp8[1], qkv_scale[1], qkv_dtype).reshape(-1, shard_heads, hidden_dims)
v_shards_new = dequant_fp8_vllm(qkv_fp8[2], qkv_scale[2], qkv_dtype).reshape(-1, shard_heads, hidden_dims)
gathered_q_shards.append(q_shards_new)
gathered_k_shards.append(k_shards_new)
gathered_v_shards.append(v_shards_new)
else:
gathered_qkv_byte_tensors = self.load_balanced_all_to_all(qkv_shards, seq_p_group)
gathered_q_shards = []
gathered_k_shards = []
gathered_v_shards = []
for i in range(world_size):
qkv_tensor = gathered_qkv_byte_tensors[i].view(qkv_dtype).reshape(3, -1, shard_heads, hidden_dims)
gathered_q_shards.append(qkv_tensor[0])
gathered_k_shards.append(qkv_tensor[1])
gathered_v_shards.append(qkv_tensor[2])
        # Concatenate all shards along the sequence dimension
        # each gathered_*_shards[i] has shape (seq_len/N, num_heads/N, head_dim)
        # after concatenation the shape is (seq_len, num_heads/N, head_dim)
img_q = torch.cat(gathered_q_shards, dim=0)
img_k = torch.cat(gathered_k_shards, dim=0)
img_v = torch.cat(gathered_v_shards, dim=0)
        # Select this process's heads from the text q/k/v
txt_q = txt_q[:, cur_rank * shard_heads : (cur_rank + 1) * shard_heads, :]
txt_k = txt_k[:, cur_rank * shard_heads : (cur_rank + 1) * shard_heads, :]
txt_v = txt_v[:, cur_rank * shard_heads : (cur_rank + 1) * shard_heads, :]
        # Concatenate the image and text q/k/v
if img_first:
q = torch.cat((img_q, txt_q), dim=0)
k = torch.cat((img_k, txt_k), dim=0)
v = torch.cat((img_v, txt_v), dim=0)
else:
q = torch.cat((txt_q, img_q), dim=0)
k = torch.cat((txt_k, img_k), dim=0)
v = torch.cat((txt_v, img_v), dim=0)
        # Build the cumulative sequence length tensor
        cu_seqlens_qkv = torch.zeros([2], dtype=torch.int32, device="cuda")
        s = txt_qkv_len + img_q.shape[0]  # total length of text plus image
        s1 = s  # end position of the current sample
        cu_seqlens_qkv[1] = s1  # cumulative sequence length
        if txt_mask_len:
            s2 = txt_mask_len + img_q.shape[0]  # end position of the text mask
            cu_seqlens_qkv = torch.cat((cu_seqlens_qkv, torch.tensor([int(s2)], dtype=torch.int32, device="cuda")))
        max_seqlen_qkv = img_q.shape[0] + txt_q.shape[0]  # maximum sequence length
        # Compute the attention result
# attn = attention(attention_type=attention_type, q=q, k=k, v=v, cu_seqlens_q=cu_seqlens_qkv, cu_seqlens_kv=cu_seqlens_qkv, max_seqlen_q=max_seqlen_qkv, max_seqlen_kv=max_seqlen_qkv)
attn = attention_module.apply(q=q, k=k, v=v, cu_seqlens_q=cu_seqlens_qkv, cu_seqlens_kv=cu_seqlens_qkv, max_seqlen_q=max_seqlen_qkv, max_seqlen_kv=max_seqlen_qkv, **kwargs)
        # Split the attention output into image and text parts
        if img_first:
            img_attn, txt_attn = attn[: img_q.shape[0], :], attn[img_q.shape[0] :, :]
        else:
            txt_attn, img_attn = attn[: txt_q.shape[0], :], attn[txt_q.shape[0] :, :]
        # Gather the text attention output from all processes
gathered_txt_attn = [torch.empty_like(txt_attn) for _ in range(world_size)]
dist.all_gather(gathered_txt_attn, txt_attn, group=seq_p_group)
img_attn = self._reshape_img_attn(img_attn, world_size, shard_seqlen, shard_heads, hidden_dims, seq_p_group, use_fp8_comm)
        txt_attn = torch.cat(gathered_txt_attn, dim=1)  # concatenate the text attention from all processes
        # Concatenate the image and text attention outputs
if img_first:
attn = torch.cat([img_attn, txt_attn], dim=0)
else:
attn = torch.cat([txt_attn, img_attn], dim=0)
        return attn  # final attention output
@torch.compiler.disable
def _reshape_img_attn(self, img_attn, world_size, shard_seqlen, shard_heads, hidden_dims, seq_p_group, use_fp8_comm):
cur_rank = dist.get_rank(seq_p_group)
global_world_size = dist.get_world_size()
global_rank = dist.get_global_rank(seq_p_group, cur_rank)
cfg_p_group_index = global_rank // world_size
        img_attn = img_attn.reshape(world_size * shard_seqlen, shard_heads, hidden_dims)  # reshape the image attention output
attn_dtype = img_attn.dtype
        # Split into N shards along the sequence dimension
attn_shards = [img_attn[i * shard_seqlen : (i + 1) * shard_seqlen, :, :].contiguous() for i in range(world_size)]
if use_fp8_comm:
attn_fp8_byte_tensors = []
attn_fp8_bytes = 0
attn_fp8_dtype = None
attn_scale_dtype = None
for i in range(world_size):
attn_fp8, attn_scale = quant_fp8_vllm(attn_shards[i].reshape(-1, hidden_dims))
if i == 0:
attn_fp8_bytes = attn_fp8.numel() * attn_fp8.element_size()
attn_fp8_dtype = attn_fp8.dtype
attn_scale_dtype = attn_scale.dtype
attn_fp8_byte_tensors.append(torch.cat([attn_fp8.contiguous().reshape(-1).view(torch.uint8), attn_scale.contiguous().reshape(-1).view(torch.uint8)], dim=0))
gathered_attn_fp8_byte_tensors = self.load_balanced_all_to_all(attn_fp8_byte_tensors, seq_p_group)
gathered_attn_shards = []
for i in range(world_size):
attn_fp8_byte_tensor = gathered_attn_fp8_byte_tensors[i]
attn_fp8 = attn_fp8_byte_tensor[:attn_fp8_bytes].view(attn_fp8_dtype).reshape(-1, hidden_dims)
attn_scale = attn_fp8_byte_tensor[attn_fp8_bytes:].view(attn_scale_dtype).reshape(-1, 1)
attn_shards_new = dequant_fp8_vllm(attn_fp8, attn_scale, attn_dtype).reshape(-1, shard_heads, hidden_dims)
gathered_attn_shards.append(attn_shards_new)
else:
gathered_attn_shards = self.load_balanced_all_to_all(attn_shards, seq_p_group)
        # Concatenate all shards along the head dimension
img_attn = torch.cat(gathered_attn_shards, dim=1)
        img_attn = img_attn.reshape(shard_seqlen, -1)  # reshape to [shard_seqlen, -1]
return img_attn
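# Standalone sketch (no process group needed) of the "circle method" round-robin
# schedule that generate_round_robin_pairs builds: world_size - 1 rounds in which
# every rank is paired exactly once with every other rank.
if __name__ == "__main__":
    world_size = 4
    teams = list(range(world_size))
    rounds = []
    for _ in range(world_size - 1):
        rounds.append([(min(teams[i], teams[world_size - 1 - i]), max(teams[i], teams[world_size - 1 - i])) for i in range(world_size // 2)])
        teams = [teams[0]] + [teams[-1]] + teams[1:-1]  # rotate, keeping the first element fixed
    for r, pairs in enumerate(rounds):
        print(f"round {r}: {pairs}")  # round 0: [(0, 3), (1, 2)], round 1: [(0, 2), (1, 3)], round 2: [(0, 1), (2, 3)]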
import torch
import torch._dynamo as dynamo
import torch.distributed as dist
@dynamo.disable
def all2all_seq2head(input, group=None):
"""
将输入张量从 [seq_len/N, heads, hidden_dims] 转换为 [seq_len, heads/N, hidden_dims] 的格式。
参数:
input (torch.Tensor): 输入张量,形状为 [seq_len/N, heads, hidden_dims]
返回:
torch.Tensor: 转换后的输出张量,形状为 [seq_len, heads/N, hidden_dims]
"""
# 确保输入是一个3D张量
assert input.dim() == 3, f"input must be 3D tensor"
    # Get the world size of the process group
    world_size = dist.get_world_size(group=group)
    # Get the shape of the input tensor
    shard_seq_len, heads, hidden_dims = input.shape
    seq_len = shard_seq_len * world_size  # total sequence length
    shard_heads = heads // world_size  # number of heads handled by each process
    # Reshape the input tensor for the all-to-all operation
    input_t = (
        input.reshape(shard_seq_len, world_size, shard_heads, hidden_dims)  # reshape to [shard_seq_len, world_size, shard_heads, hidden_dims]
        .transpose(0, 1)  # transpose for the all-to-all operation
        .contiguous()  # make the memory contiguous
    )
    # Allocate an output tensor with the same shape as the input
    output = torch.empty_like(input_t)
    # Run all-to-all to exchange the shards across processes
    dist.all_to_all_single(output, input_t, group=group)
    # Reshape the output tensor to [seq_len, heads/N, hidden_dims]
    output = output.reshape(seq_len, shard_heads, hidden_dims).contiguous()
    return output  # return the transformed tensor
@dynamo.disable
def all2all_head2seq(input, group=None):
"""
将输入张量从 [seq_len, heads/N, hidden_dims] 转换为 [seq_len/N, heads, hidden_dims] 的格式。
参数:
input (torch.Tensor): 输入张量,形状为 [seq_len, heads/N, hidden_dims]
返回:
torch.Tensor: 转换后的输出张量,形状为 [seq_len/N, heads, hidden_dims]
"""
# 确保输入是一个3D张量
assert input.dim() == 3, f"input must be 3D tensor"
    # Get the world size of the process group
    world_size = dist.get_world_size(group=group)
    # Get the shape of the input tensor
    seq_len, shard_heads, hidden_dims = input.shape
    heads = shard_heads * world_size  # total number of heads
    shard_seq_len = seq_len // world_size  # sequence length handled by each process
    # Reshape the input tensor for the all-to-all operation
    input_t = (
        input.reshape(world_size, shard_seq_len, shard_heads, hidden_dims)  # reshape to [world_size, shard_seq_len, shard_heads, hidden_dims]
        .transpose(1, 2)  # transpose for the all-to-all operation
        .contiguous()  # make the memory contiguous
        .reshape(world_size, shard_heads, shard_seq_len, hidden_dims)  # reshape again to [world_size, shard_heads, shard_seq_len, hidden_dims]
    )
    # Allocate an output tensor with the same shape as the input
    output = torch.empty_like(input_t)
    # Run all-to-all to exchange the shards across processes
    dist.all_to_all_single(output, input_t, group=group)
    # Reshape the output tensor to [heads, shard_seq_len, hidden_dims]
    output = output.reshape(heads, shard_seq_len, hidden_dims)
    # Transpose and reshape to [shard_seq_len, heads, hidden_dims]
    output = output.transpose(0, 1).contiguous().reshape(shard_seq_len, heads, hidden_dims)
    return output  # return the transformed tensor
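# Single-process sanity check (no communication, illustrative sizes) of the index
# math in all2all_seq2head: the chunk rank r prepares for rank t, concatenated over
# all r, equals the full-sequence slice of heads [t*shard_heads:(t+1)*shard_heads]
# that rank t expects to hold after the all-to-all.
if __name__ == "__main__":
    world_size, shard_seq_len, heads, hidden_dims = 2, 3, 4, 2
    shard_heads = heads // world_size
    full = torch.arange(world_size * shard_seq_len * heads * hidden_dims, dtype=torch.float32).reshape(world_size, shard_seq_len, heads, hidden_dims)
    prepared = [full[r].reshape(shard_seq_len, world_size, shard_heads, hidden_dims).transpose(0, 1).contiguous() for r in range(world_size)]
    for t in range(world_size):
        received = torch.cat([prepared[r][t] for r in range(world_size)], dim=0)
        expected = full.reshape(world_size * shard_seq_len, heads, hidden_dims)[:, t * shard_heads : (t + 1) * shard_heads, :]
        assert torch.equal(received, expected)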
from typing import Optional
import torch
import torch.distributed as dist
class RingComm:
def __init__(self, process_group: dist.ProcessGroup = None):
self._process_group = process_group
self._ops = []
self.rank = dist.get_rank(self._process_group)
self.world_size = dist.get_world_size(self._process_group)
self._reqs = None
self.send_rank = (self.rank + 1) % self.world_size
self.recv_rank = (self.rank - 1) % self.world_size
if process_group is not None:
self.send_rank = dist.get_global_rank(self._process_group, self.send_rank)
self.recv_rank = dist.get_global_rank(self._process_group, self.recv_rank)
def send_recv(self, to_send: torch.Tensor, recv_tensor: Optional[torch.Tensor] = None) -> torch.Tensor:
if recv_tensor is None:
res = torch.empty_like(to_send)
# logger.info(f"send_recv: empty_like {to_send.shape}")
else:
res = recv_tensor
send_op = dist.P2POp(dist.isend, to_send, self.send_rank, group=self._process_group)
recv_op = dist.P2POp(dist.irecv, res, self.recv_rank, group=self._process_group)
self._ops.append(send_op)
self._ops.append(recv_op)
return res
def commit(self):
if self._reqs is not None:
raise RuntimeError("commit called twice")
self._reqs = dist.batch_isend_irecv(self._ops)
def wait(self):
if self._reqs is None:
raise RuntimeError("wait called before commit")
for req in self._reqs:
req.wait()
self._reqs = None
self._ops = []
import torch
import triton
import triton.language as tl
@triton.jit
def compress_kernel(
X,
XM,
L: tl.constexpr,
D: tl.constexpr,
BLOCK_L: tl.constexpr,
):
idx_l = tl.program_id(0)
idx_bh = tl.program_id(1)
offs_l = idx_l * BLOCK_L + tl.arange(0, BLOCK_L)
offs_d = tl.arange(0, D)
x_offset = idx_bh * L * D
xm_offset = idx_bh * ((L + BLOCK_L - 1) // BLOCK_L) * D
x = tl.load(X + x_offset + offs_l[:, None] * D + offs_d[None, :], mask=offs_l[:, None] < L)
nx = min(BLOCK_L, L - idx_l * BLOCK_L)
x_mean = tl.sum(x, axis=0, dtype=tl.float32) / nx
tl.store(XM + xm_offset + idx_l * D + offs_d, x_mean.to(XM.dtype.element_ty))
def mean_pool(x, BLK):
assert x.is_contiguous()
B, H, L, D = x.shape
L_BLOCKS = (L + BLK - 1) // BLK
x_mean = torch.empty((B, H, L_BLOCKS, D), device=x.device, dtype=x.dtype)
grid = (L_BLOCKS, B * H)
compress_kernel[grid](x, x_mean, L, D, BLK)
return x_mean
def get_block_map(q, k, topk_ratio, BLKQ=64, BLKK=64):
arg_k = k - torch.mean(k, dim=-2, keepdim=True) # smooth-k technique in SageAttention
pooled_qblocks = mean_pool(q, BLKQ)
pooled_kblocks = mean_pool(arg_k, BLKK)
pooled_score = pooled_qblocks @ pooled_kblocks.transpose(-1, -2)
K = pooled_score.shape[-1]
topk = min(K, int(topk_ratio * K))
lut = torch.topk(pooled_score, topk, dim=-1, sorted=False).indices
sparse_map = torch.zeros_like(pooled_score, dtype=torch.int8)
sparse_map.scatter_(-1, lut, 1)
return sparse_map, lut, topk
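# A hedged usage sketch of mean_pool/get_block_map (needs a CUDA device, since
# compress_kernel is a Triton kernel); shapes and the topk_ratio are illustrative.
if __name__ == "__main__" and torch.cuda.is_available():
    B, H, L, D = 1, 2, 512, 64
    q = torch.randn(B, H, L, D, device="cuda", dtype=torch.bfloat16)
    k = torch.randn(B, H, L, D, device="cuda", dtype=torch.bfloat16)
    sparse_map, lut, topk = get_block_map(q, k, topk_ratio=0.25, BLKQ=64, BLKK=64)
    print(sparse_map.shape, lut.shape, topk)  # (1, 2, 8, 8) int8 block mask, (1, 2, 8, 2) block indices, 2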
def get_cuda_arch(device_index):
major, minor = torch.cuda.get_device_capability(device_index)
return f"sm{major}{minor}"
from .conv2d import *
from .conv3d import *