Unverified Commit 8b1e4f94 authored by Yang Yong (雍洋), committed by GitHub

Support SVG2 (#384)

parent b20ec092
{
"infer_steps": 40,
"target_video_length": 81,
"target_height": 480,
"target_width": 832,
"self_attn_1_type": "svg2_attn",
"cross_attn_1_type": "flash_attn3",
"cross_attn_2_type": "flash_attn3",
"sample_guide_scale": 5,
"sample_shift": 3,
"enable_cfg": true,
"cpu_offload": false
}
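The config above enables SVG2 by setting "self_attn_1_type" to "svg2_attn", the key under which the new Svg2AttnWeight class below is registered via @ATTN_WEIGHT_REGISTER("svg2_attn"). A minimal sketch of reading such a config (the file name here is hypothetical; lightx2v's own runner handles config loading):

import json

with open("wan_svg2_config.json") as f:  # hypothetical file name
    cfg = json.load(f)
print(cfg["self_attn_1_type"])  # -> "svg2_attn"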
@@ -2,6 +2,7 @@ from .flash_attn import FlashAttn2Weight, FlashAttn3Weight
from .radial_attn import RadialAttnWeight
from .ring_attn import RingAttnWeight
from .sage_attn import SageAttn2Weight
from .svg2_attn import Svg2AttnWeight
from .svg_attn import SvgAttnWeight
from .torch_sdpa import TorchSDPAWeight
from .ulysses_attn import UlyssesAttnWeight
from typing import Optional
# If this import fails, (re)install flashinfer by referring to https://github.com/svg-project/Sparse-VideoGen
try:
import flashinfer
except ImportError:
flashinfer = None
import torch
import triton
import triton.language as tl
from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER
from .svg2_attn_utils import (
batch_kmeans_Euclid,
identify_dynamic_map,
)
from .template import AttnWeightTemplate
@triton.jit
def _permute_kernel(
X_ptr,
IDX_ptr,
Y_ptr,
S: tl.constexpr,
D: tl.constexpr,
BLOCK_S: tl.constexpr,
):
"""Each program permutes BLOCK_S tokens *all* hidden features (D). No inner python loop."""
pid_bh = tl.program_id(0)
tile_s = tl.program_id(1)
# Offsets along sequence
s_offsets = tile_s * BLOCK_S + tl.arange(0, BLOCK_S)
token_mask = s_offsets < S
# Gather source indices for these tokens
idx_ptrs = IDX_ptr + pid_bh * S + s_offsets
src_row_idx = tl.load(idx_ptrs, mask=token_mask, other=0).to(tl.int32)
# Broadcast to create 2-D pointer matrix (BLOCK_S, D)
d_offsets = tl.arange(0, D)
src_ptrs = X_ptr + (pid_bh * S + src_row_idx[:, None]) * D + d_offsets[None, :]
dst_ptrs = Y_ptr + (pid_bh * S + s_offsets[:, None]) * D + d_offsets[None, :]
full_mask = token_mask[:, None]
values = tl.load(src_ptrs, mask=full_mask, other=0.0)
tl.store(dst_ptrs, values, mask=full_mask)
def permute_tensor_by_labels_triton(
tensor: torch.Tensor,
labels: Optional[torch.Tensor],
dim: int,
*,
sorted_indices: Optional[torch.Tensor] = None,
):
"""
Permute `tensor` along `dim` according to ascending order of `labels`.
This is a Triton-accelerated replacement for the original implementation.
It currently supports 4-D tensors of shape [B, H, S, D] and `dim == 2`.
If these conditions are not met or the tensors reside on CPU, we fall back
to the reference PyTorch implementation.
"""
# Assertions – we only support the optimized CUDA path.
assert dim == 2, "permute_tensor_by_labels currently only supports dim==2 (sequence dimension)"
assert tensor.dim() == 4, "Expected tensor shape [B,H,S,D]"
assert tensor.is_cuda, "permute_tensor_by_labels requires CUDA tensors"
B, H, S, D = tensor.shape
BH = B * H
# Determine sorted indices
if sorted_indices is not None:
sorted_indices = sorted_indices.to(torch.int32).contiguous()
else:
assert labels is not None, "Either `labels` or `sorted_indices` must be provided."
labels = labels.to(tensor.device)
sorted_indices = torch.argsort(labels, dim=-1).to(torch.int32).contiguous()
# Flatten tensor and allocate output
inp_flat = tensor.reshape(BH, S, D).contiguous()
out_flat = torch.empty_like(inp_flat)
# Triton kernel tile size
BLOCK_S = 64 # number of tokens per program, tunable
n_s_tiles = triton.cdiv(S, BLOCK_S)
grid = (BH, n_s_tiles)
_permute_kernel[grid](inp_flat, sorted_indices, out_flat, S, D, BLOCK_S, num_warps=4)
permuted_tensor = out_flat.reshape(B, H, S, D)
return permuted_tensor, sorted_indices
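# Illustrative reference (a sketch, not called by the SVG2 path): the plain PyTorch
# equivalent of the Triton permutation above, for a tensor of shape [B, H, S, D] and
# labels of shape [B, H, S]. Useful for understanding or validating the kernel.
def _permute_by_labels_reference(tensor: torch.Tensor, labels: torch.Tensor):
    sorted_indices = torch.argsort(labels, dim=-1)  # ascending label order per (batch, head)
    gather_idx = sorted_indices.unsqueeze(-1).expand_as(tensor)  # broadcast indices to [B, H, S, D]
    return torch.gather(tensor, 2, gather_idx), sorted_indices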
@triton.jit
def _inverse_permute_kernel(
X_ptr,
IDX_ptr,
Y_ptr,
S: tl.constexpr,
D: tl.constexpr,
BLOCK_S: tl.constexpr,
):
"""Inverse permutation: scatter BLOCK_S tokens back in one shot."""
pid_bh = tl.program_id(0)
tile_s = tl.program_id(1)
s_offsets = tile_s * BLOCK_S + tl.arange(0, BLOCK_S)
token_mask = s_offsets < S
idx_ptrs = IDX_ptr + pid_bh * S + s_offsets
src_pos_idx = s_offsets.to(tl.int32)
dst_pos_idx = tl.load(idx_ptrs, mask=token_mask, other=0).to(tl.int32)
d_offsets = tl.arange(0, D)
src_ptrs = X_ptr + (pid_bh * S + src_pos_idx[:, None]) * D + d_offsets[None, :]
dst_ptrs = Y_ptr + (pid_bh * S + dst_pos_idx[:, None]) * D + d_offsets[None, :]
full_mask = token_mask[:, None]
values = tl.load(src_ptrs, mask=full_mask, other=0.0)
tl.store(dst_ptrs, values, mask=full_mask)
def apply_inverse_permutation_triton(
permuted_tensor: torch.Tensor,
sorted_indices: torch.Tensor,
dim: int,
):
"""
Triton implementation of inverse permutation. Inverse the permutation applied by `permute_tensor_by_labels`.
Args:
permuted_tensor: (B, H, S, D).
sorted_indices: (B, H, S).
dim: Dimension along which to apply inverse permutation. Typically 2, meaning the sequence lengthdimension.
Returns:
Tensor of shape (B, H, S, D).
"""
assert dim == 2, "apply_inverse_permutation currently only supports dim==2"
assert permuted_tensor.dim() == 4, "Expected tensor shape [B,H,S,D]"
assert permuted_tensor.is_cuda, "apply_inverse_permutation requires CUDA tensors"
B, H, S, D = permuted_tensor.shape
BH = B * H
# Ensure index dtype
sorted_indices = sorted_indices.to(torch.int32).contiguous()
# Flatten inputs
inp_flat = permuted_tensor.reshape(BH, S, D).contiguous()
out_flat = torch.empty_like(inp_flat)
BLOCK_S = 64
n_s_tiles = triton.cdiv(S, BLOCK_S)
grid = (BH, n_s_tiles)
_inverse_permute_kernel[grid](inp_flat, sorted_indices, out_flat, S, D, BLOCK_S, num_warps=4)
original_tensor = out_flat.reshape(B, H, S, D)
return original_tensor
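# Illustrative reference (a sketch, not called by the SVG2 path): the plain PyTorch
# equivalent of the inverse kernel above; scatters each permuted row back to its
# original position, i.e. out[b, h, sorted_indices[b, h, s], :] = permuted[b, h, s, :].
def _inverse_permutation_reference(permuted: torch.Tensor, sorted_indices: torch.Tensor):
    out = torch.empty_like(permuted)
    scatter_idx = sorted_indices.to(torch.int64).unsqueeze(-1).expand_as(permuted)
    out.scatter_(2, scatter_idx, permuted)
    return out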
@ATTN_WEIGHT_REGISTER("svg2_attn")
class Svg2AttnWeight(AttnWeightTemplate):
centroids_init = False  # set to True after the first full k-means initialization
num_q_centroids = 300  # number of query clusters
num_k_centroids = 1000  # number of key clusters
kmeans_iter_init = 50  # k-means iterations for the initial clustering
top_p_kmeans = 0.9  # top-p threshold passed to identify_dynamic_map
min_kc_ratio = 0.10  # minimum key-cluster ratio passed to identify_dynamic_map
kmeans_iter_step = 2  # k-means iterations for later, warm-started steps
def __init__(self):
self.config = {}
def apply(
self,
q,
k,
v,
cu_seqlens_q=None,
cu_seqlens_kv=None,
max_seqlen_q=None,
max_seqlen_kv=None,
model_cls=None,
):
# q, k, v arrive as [seq_len, num_heads, dim]; add a batch dim and move heads ahead of the sequence -> [1, H, S, D]
q = q.unsqueeze(0).transpose(1, 2)
k = k.unsqueeze(0).transpose(1, 2)
v = v.unsqueeze(0).transpose(1, 2)
bs, num_heads, seq_len, dim = q.size()
# Cluster-permute q/k/v, run block-sparse attention on the permuted layout, then undo the query permutation
q_perm, k_perm, v_perm, dyn_map, qc_sz_s, kc_sz_s, q_sorted_indices = self.semantic_aware_permutation(q, k, v)
output_permuted = self.dynamic_block_sparse_fwd_flashinfer(q_perm, k_perm, v_perm, dyn_map, qc_sz_s, kc_sz_s, is_cpu=False)
attn_output = apply_inverse_permutation_triton(output_permuted, q_sorted_indices, dim=2)
# Restore the original [seq_len, num_heads * dim] layout expected by the caller
return attn_output.reshape(bs, num_heads, seq_len, dim).transpose(1, 2).reshape(bs * seq_len, -1)
def dynamic_block_sparse_fwd_flashinfer(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
block_mask_map: torch.Tensor,
block_row_sz: torch.Tensor,
block_col_sz: torch.Tensor,
is_cpu: bool = True,
):
"""
Launcher for the Flashinfer dynamic block sparse attention kernel.
Args:
q (torch.Tensor): Query tensor, shape [B, H, S, D].
k (torch.Tensor): Key tensor, shape [B, H, S, D].
v (torch.Tensor): Value tensor, shape [B, H, S, D].
block_mask_map (torch.Tensor): Boolean mask, shape [B, H, qc_num, kc_num]. Must currently be on CPU.
block_row_sz (torch.Tensor): Query block sizes, shape [B, H, qc_num]. Must currently be on CPU.
block_col_sz (torch.Tensor): Key block sizes, shape [B, H, kc_num]. Must currently be on CPU.
is_cpu (bool): Whether the block metadata tensors reside on CPU. Flashinfer defaults to CPU metadata; we pass GPU tensors (is_cpu=False) for faster planning. Default is True.
An illustrative example of this metadata layout follows this method.
"""
# Input shape check
B, H, S, D = q.shape
qc_num = block_row_sz.shape[-1]
kc_num = block_col_sz.shape[-1]
assert block_mask_map.shape == (B, H, qc_num, kc_num)
assert (not is_cpu) or all(t.device == torch.device("cpu") for t in [block_mask_map, block_row_sz, block_col_sz])
# Check that block_row_sz and block_col_sz each sum to the same sequence length for every (batch, head)
assert torch.all(block_col_sz.sum(dim=2) == block_col_sz.sum(dim=2)[0, 0])
assert torch.all(block_row_sz.sum(dim=2) == block_row_sz.sum(dim=2)[0, 0])
# Prepare flashinfer wrapper
float_workspace_buffer = torch.empty(128 * 1024 * 1024, device=q.device)
vector_sparse_indices_buffer = torch.empty(1024 * 1024 * 1024, device=q.device)
wrapper = flashinfer.sparse.VariableBlockSparseAttentionWrapper(float_workspace_buffer, backend="auto")
wrapper.reset_workspace_buffer(
float_workspace_buffer=wrapper._float_workspace_buffer,
int_workspace_buffer=wrapper._int_workspace_buffer,
vector_sparse_indices_buffer=vector_sparse_indices_buffer, # Only reset this buffer size
vector_sparse_indptr_buffer=wrapper._vector_sparse_indptr_buffer,
)
block_mask_map = block_mask_map.reshape(B * H, qc_num, kc_num)
block_row_sz = block_row_sz.reshape(B * H, qc_num)
block_col_sz = block_col_sz.reshape(B * H, kc_num)
wrapper.plan(
block_mask_map=block_mask_map,
block_row_sz=block_row_sz,
block_col_sz=block_col_sz,
num_qo_heads=B * H,
num_kv_heads=B * H,
head_dim=D,
q_data_type=q.dtype,
kv_data_type=k.dtype,
)
# print_memory_usage("After plan")
q = q.reshape(B * H, S, D)
k = k.reshape(B * H, S, D)
v = v.reshape(B * H, S, D)
o = wrapper.run(q, k, v) # [num_qo_heads, qo_len, head_dim]
o = o.reshape(B, H, S, D)
return o
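# Illustrative only (hypothetical tiny sizes, not used by the code above): a sketch of the
# block metadata layout consumed by the flashinfer wrapper in the method above. Row/column
# sizes partition the cluster-permuted sequence into variable-sized blocks, and the boolean
# map selects which (query-block, key-block) pairs are attended.
@staticmethod
def _example_block_metadata():
    block_row_sz = torch.tensor([[[3, 2]]])  # [B=1, H=1, qc_num=2]; 3 + 2 = S_q = 5
    block_col_sz = torch.tensor([[[2, 3]]])  # [B=1, H=1, kc_num=2]; 2 + 3 = S_k = 5
    block_mask_map = torch.tensor([[[[True, False], [True, True]]]])  # [1, 1, 2, 2] block-pair mask
    return block_mask_map, block_row_sz, block_col_sz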
def semantic_aware_permutation(self, query, key, value):
cfg, num_heads, seq_len, dim = query.size()
# 1. Kmeans clustering
qlabels, qcentroids, qcluster_sizes, qiter, klabels, kcentroids, kcluster_sizes, kiter = self.kmeans_clustering(query, key)
# 2. Identify dynamic map
q_cluster_sizes = qcluster_sizes.view(cfg, num_heads, self.num_q_centroids)
k_cluster_sizes = kcluster_sizes.view(cfg, num_heads, self.num_k_centroids)
dynamic_map = identify_dynamic_map(
qcentroids.view(cfg, num_heads, self.num_q_centroids, dim),
kcentroids.view(cfg, num_heads, self.num_k_centroids, dim),
q_cluster_sizes,
k_cluster_sizes,
self.top_p_kmeans,
self.min_kc_ratio,
)
# 3. Permute the query, key, value
q_permuted, q_sorted_indices = permute_tensor_by_labels_triton(query, qlabels, dim=2)
k_permuted, k_sorted_indices = permute_tensor_by_labels_triton(key, klabels, dim=2)
v_permuted, v_sorted_indices = permute_tensor_by_labels_triton(value, klabels, dim=2, sorted_indices=k_sorted_indices)
return q_permuted, k_permuted, v_permuted, dynamic_map, q_cluster_sizes, k_cluster_sizes, q_sorted_indices
def kmeans_clustering(self, query, key):
if not self.centroids_init:
qlabels, qcentroids, qcluster_sizes, qiter, klabels, kcentroids, kcluster_sizes, kiter = self.kmeans_init(query, key)
self.centroids_init = True
else:
qlabels, qcentroids, qcluster_sizes, qiter, klabels, kcentroids, kcluster_sizes, kiter = self.kmeans_step(query, key)
return qlabels, qcentroids, qcluster_sizes, qiter, klabels, kcentroids, kcluster_sizes, kiter
def kmeans_init(self, query, key):
cfg, num_heads, seq_len, dim = query.size()
qlabels, qcentroids, qcluster_sizes, qiter = batch_kmeans_Euclid(query.view(cfg * num_heads, seq_len, dim), n_clusters=self.num_q_centroids, max_iters=self.kmeans_iter_init)
klabels, kcentroids, kcluster_sizes, kiter = batch_kmeans_Euclid(key.view(cfg * num_heads, seq_len, dim), n_clusters=self.num_k_centroids, max_iters=self.kmeans_iter_init)
self.q_centroids = qcentroids
self.k_centroids = kcentroids
return qlabels, qcentroids, qcluster_sizes, qiter, klabels, kcentroids, kcluster_sizes, kiter
def kmeans_step(self, query, key):
cfg, num_heads, seq_len, dim = query.size()
qlabels, qcentroids, qcluster_sizes, qiter = batch_kmeans_Euclid(
query.view(cfg * num_heads, seq_len, dim),
n_clusters=self.num_q_centroids,
max_iters=self.kmeans_iter_step,
init_centroids=self.q_centroids,
)
klabels, kcentroids, kcluster_sizes, kiter = batch_kmeans_Euclid(
key.view(cfg * num_heads, seq_len, dim),
n_clusters=self.num_k_centroids,
max_iters=self.kmeans_iter_step,
init_centroids=self.k_centroids,
)
self.q_centroids = qcentroids
self.k_centroids = kcentroids
return qlabels, qcentroids, qcluster_sizes, qiter, klabels, kcentroids, kcluster_sizes, kiter
if __name__ == "__main__":
    q = torch.randn(32130, 40, 128, dtype=torch.bfloat16).cuda()
    k = torch.randn(32130, 40, 128, dtype=torch.bfloat16).cuda()
    v = torch.randn(32130, 40, 128, dtype=torch.bfloat16).cuda()
    svg2_attn = Svg2AttnWeight()
    print("Svg2AttnWeight initialized.")
    out = svg2_attn.apply(q, k, v)
    print(f"out: {out.shape}, {out.dtype}, {out.device}")