xuwx1 / LightX2V · Commits · 8b1e4f94

Commit 8b1e4f94 (unverified) · authored Oct 21, 2025 by Yang Yong (雍洋), committed by GitHub on Oct 21, 2025

Support SVG2 (#384)

parent b20ec092
Showing 4 changed files with 1728 additions and 0 deletions:

configs/attentions/wan_i2v_svg2.json          +13    -0
lightx2v/common/ops/attn/__init__.py           +1    -0
lightx2v/common/ops/attn/svg2_attn.py        +355    -0
lightx2v/common/ops/attn/svg2_attn_utils.py  +1359   -0
configs/attentions/wan_i2v_svg2.json (new file, mode 100755)
{
    "infer_steps": 40,
    "target_video_length": 81,
    "target_height": 480,
    "target_width": 832,
    "self_attn_1_type": "svg2_attn",
    "cross_attn_1_type": "flash_attn3",
    "cross_attn_2_type": "flash_attn3",
    "sample_guide_scale": 5,
    "sample_shift": 3,
    "enable_cfg": true,
    "cpu_offload": false
}
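This config routes the model's first self-attention to the new SVG2 kernel while keeping both cross-attentions on FlashAttention-3. A minimal sketch of how such a type string could be resolved at load time (the lookup interface of ATTN_WEIGHT_REGISTER is not shown in this commit, so the dict-style access below is an assumption):

# Hypothetical resolution of an attention type from this config; dict-style
# registry access is assumed, not shown in this commit.
import json

from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER

with open("configs/attentions/wan_i2v_svg2.json") as f:
    cfg = json.load(f)

attn_cls = ATTN_WEIGHT_REGISTER[cfg["self_attn_1_type"]]  # -> Svg2AttnWeight for "svg2_attn"
self_attn_1 = attn_cls()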
lightx2v/common/ops/attn/__init__.py
@@ -2,6 +2,7 @@ from .flash_attn import FlashAttn2Weight, FlashAttn3Weight
 from .radial_attn import RadialAttnWeight
 from .ring_attn import RingAttnWeight
 from .sage_attn import SageAttn2Weight
+from .svg2_attn import Svg2AttnWeight
 from .svg_attn import SvgAttnWeight
 from .torch_sdpa import TorchSDPAWeight
 from .ulysses_attn import UlyssesAttnWeight
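Importing Svg2AttnWeight here has a side effect: module import runs the @ATTN_WEIGHT_REGISTER("svg2_attn") decorator in svg2_attn.py below, which is what makes the "svg2_attn" key from the JSON config resolvable. A minimal sketch of a string-keyed registry of this style (illustrative only; the real implementation in lightx2v.utils.registry_factory is not part of this diff):

# Illustrative registry sketch, not the actual ATTN_WEIGHT_REGISTER.
class Register(dict):
    def __call__(self, key):
        def _decorator(cls):
            self[key] = cls  # map the string key to the decorated class
            return cls
        return _decorator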
lightx2v/common/ops/attn/svg2_attn.py (new file, mode 100644)
from typing import Optional

# Please reinstall flashinfer by referring to https://github.com/svg-project/Sparse-VideoGen
try:
    import flashinfer
except ImportError:
    flashinfer = None

import torch
import triton
import triton.language as tl

from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER

from .svg2_attn_utils import (
    batch_kmeans_Euclid,
    identify_dynamic_map,
)
from .template import AttnWeightTemplate
@triton.jit
def _permute_kernel(
    X_ptr,
    IDX_ptr,
    Y_ptr,
    S: tl.constexpr,
    D: tl.constexpr,
    BLOCK_S: tl.constexpr,
):
    """Each program permutes all hidden features (D) of BLOCK_S tokens. No inner Python loop."""
    pid_bh = tl.program_id(0)
    tile_s = tl.program_id(1)
    # Offsets along sequence
    s_offsets = tile_s * BLOCK_S + tl.arange(0, BLOCK_S)
    token_mask = s_offsets < S
    # Gather source indices for these tokens
    idx_ptrs = IDX_ptr + pid_bh * S + s_offsets
    src_row_idx = tl.load(idx_ptrs, mask=token_mask, other=0).to(tl.int32)
    # Broadcast to create 2-D pointer matrix (BLOCK_S, D)
    d_offsets = tl.arange(0, D)
    src_ptrs = X_ptr + (pid_bh * S + src_row_idx[:, None]) * D + d_offsets[None, :]
    dst_ptrs = Y_ptr + (pid_bh * S + s_offsets[:, None]) * D + d_offsets[None, :]
    full_mask = token_mask[:, None]
    values = tl.load(src_ptrs, mask=full_mask, other=0.0)
    tl.store(dst_ptrs, values, mask=full_mask)
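# [Editor's sketch, not in the original commit] The kernel above is a batched
# row gather: out[bh, s, :] = x[bh, idx[bh, s], :]. An equivalent pure-PyTorch
# reference, handy for unit-testing the Triton path:
def _permute_reference_torch(x_flat: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
    # x_flat: [BH, S, D]; idx: [BH, S] integer row indices along the S dimension.
    return torch.gather(x_flat, 1, idx.long().unsqueeze(-1).expand(-1, -1, x_flat.shape[-1]))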
def permute_tensor_by_labels_triton(
    tensor: torch.Tensor,
    labels: Optional[torch.Tensor],
    dim: int,
    *,
    sorted_indices: Optional[torch.Tensor] = None,
):
    """
    Permute `tensor` along `dim` according to the ascending order of `labels`.

    This is a Triton-accelerated replacement for the original implementation.
    It currently supports 4-D CUDA tensors of shape [B, H, S, D] with `dim == 2`;
    these requirements are enforced by the assertions below.
    """
    # Assertions – we only support the optimized CUDA path.
    assert dim == 2, "permute_tensor_by_labels currently only supports dim==2 (sequence dimension)"
    assert tensor.dim() == 4, "Expected tensor shape [B,H,S,D]"
    assert tensor.is_cuda, "permute_tensor_by_labels requires CUDA tensors"

    B, H, S, D = tensor.shape
    BH = B * H
    # Determine sorted indices
    if sorted_indices is not None:
        sorted_indices = sorted_indices.to(torch.int32).contiguous()
    else:
        assert labels is not None, "Either `labels` or `sorted_indices` must be provided."
        labels = labels.to(tensor.device)
        sorted_indices = torch.argsort(labels, dim=-1).to(torch.int32).contiguous()

    # Flatten tensor and allocate output
    inp_flat = tensor.reshape(BH, S, D).contiguous()
    out_flat = torch.empty_like(inp_flat)

    # Triton kernel tile size
    BLOCK_S = 64  # number of tokens per program, tunable
    n_s_tiles = triton.cdiv(S, BLOCK_S)
    grid = (BH, n_s_tiles)
    _permute_kernel[grid](inp_flat, sorted_indices, out_flat, S, D, BLOCK_S, num_warps=4)

    permuted_tensor = out_flat.reshape(B, H, S, D)
    return permuted_tensor, sorted_indices
@triton.jit
def _inverse_permute_kernel(
    X_ptr,
    IDX_ptr,
    Y_ptr,
    S: tl.constexpr,
    D: tl.constexpr,
    BLOCK_S: tl.constexpr,
):
    """Inverse permutation: scatter BLOCK_S tokens back in one shot."""
    pid_bh = tl.program_id(0)
    tile_s = tl.program_id(1)
    s_offsets = tile_s * BLOCK_S + tl.arange(0, BLOCK_S)
    token_mask = s_offsets < S
    idx_ptrs = IDX_ptr + pid_bh * S + s_offsets
    src_pos_idx = s_offsets.to(tl.int32)
    dst_pos_idx = tl.load(idx_ptrs, mask=token_mask, other=0).to(tl.int32)
    d_offsets = tl.arange(0, D)
    src_ptrs = X_ptr + (pid_bh * S + src_pos_idx[:, None]) * D + d_offsets[None, :]
    dst_ptrs = Y_ptr + (pid_bh * S + dst_pos_idx[:, None]) * D + d_offsets[None, :]
    full_mask = token_mask[:, None]
    values = tl.load(src_ptrs, mask=full_mask, other=0.0)
    tl.store(dst_ptrs, values, mask=full_mask)
def apply_inverse_permutation_triton(
    permuted_tensor: torch.Tensor,
    sorted_indices: torch.Tensor,
    dim: int,
):
    """
    Triton implementation of inverse permutation. Inverts the permutation applied by `permute_tensor_by_labels`.

    Args:
        permuted_tensor: (B, H, S, D).
        sorted_indices: (B, H, S).
        dim: Dimension along which to apply the inverse permutation. Typically 2, i.e. the sequence-length dimension.

    Returns:
        Tensor of shape (B, H, S, D).
    """
    assert dim == 2, "apply_inverse_permutation currently only supports dim==2"
    assert permuted_tensor.dim() == 4, "Expected tensor shape [B,H,S,D]"
    assert permuted_tensor.is_cuda, "apply_inverse_permutation requires CUDA tensors"

    B, H, S, D = permuted_tensor.shape
    BH = B * H
    # Ensure index dtype
    sorted_indices = sorted_indices.to(torch.int32).contiguous()
    # Flatten inputs
    inp_flat = permuted_tensor.reshape(BH, S, D).contiguous()
    out_flat = torch.empty_like(inp_flat)

    BLOCK_S = 64
    n_s_tiles = triton.cdiv(S, BLOCK_S)
    grid = (BH, n_s_tiles)
    _inverse_permute_kernel[grid](inp_flat, sorted_indices, out_flat, S, D, BLOCK_S, num_warps=4)

    original_tensor = out_flat.reshape(B, H, S, D)
    return original_tensor
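# [Editor's sketch, not in the original commit] Round-trip property of the two
# Triton paths above: inverse-permuting a permuted tensor recovers the input.
def _check_permutation_round_trip(x: torch.Tensor, labels: torch.Tensor) -> bool:
    # x: [B, H, S, D] CUDA tensor; labels: [B, H, S] cluster labels.
    permuted, sorted_idx = permute_tensor_by_labels_triton(x, labels, dim=2)
    restored = apply_inverse_permutation_triton(permuted, sorted_idx, dim=2)
    return torch.equal(restored, x)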
@ATTN_WEIGHT_REGISTER("svg2_attn")
class Svg2AttnWeight(AttnWeightTemplate):
    centroids_init = False
    num_q_centroids = 300
    num_k_centroids = 1000
    kmeans_iter_init = 50
    top_p_kmeans = 0.9
    min_kc_ratio = 0.10
    kmeans_iter_step = 2

    def __init__(self):
        self.config = {}
    def apply(
        self,
        q,
        k,
        v,
        cu_seqlens_q=None,
        cu_seqlens_kv=None,
        max_seqlen_q=None,
        max_seqlen_kv=None,
        model_cls=None,
    ):
        q = q.unsqueeze(0).transpose(1, 2)
        k = k.unsqueeze(0).transpose(1, 2)
        v = v.unsqueeze(0).transpose(1, 2)
        bs, num_heads, seq_len, dim = q.size()
        q_perm, k_perm, v_perm, dyn_map, qc_sz_s, kc_sz_s, q_sorted_indices = self.semantic_aware_permutation(q, k, v)
        output_permuted = self.dynamic_block_sparse_fwd_flashinfer(q_perm, k_perm, v_perm, dyn_map, qc_sz_s, kc_sz_s, is_cpu=False)
        attn_output = apply_inverse_permutation_triton(output_permuted, q_sorted_indices, dim=2)
        return attn_output.reshape(bs, num_heads, seq_len, dim).transpose(1, 2).reshape(bs * seq_len, -1)
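    # [Editor's note] Shape flow through apply(): inputs arrive as [S, H, D]
    # (token-major, no batch dim), are lifted to [1, H, S, D] for clustering and
    # the block-sparse kernel, and the output is flattened back to [S, H * D].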
    def dynamic_block_sparse_fwd_flashinfer(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        block_mask_map: torch.Tensor,
        block_row_sz: torch.Tensor,
        block_col_sz: torch.Tensor,
        is_cpu: bool = True,
    ):
        """
        Launcher for the Flashinfer dynamic block-sparse attention kernel.

        Args:
            q (torch.Tensor): Query tensor, shape [B, H, S, D].
            k (torch.Tensor): Key tensor, shape [B, H, S, D].
            v (torch.Tensor): Value tensor, shape [B, H, S, D].
            block_mask_map (torch.Tensor): Boolean mask, shape [B, H, qc_num, kc_num]. Currently must be on CPU.
            block_row_sz (torch.Tensor): Query block sizes, shape [B, H, qc_num]. Currently must be on CPU.
            block_col_sz (torch.Tensor): Key block sizes, shape [B, H, kc_num]. Currently must be on CPU.
            is_cpu (bool): Whether the block metadata is on CPU. Flashinfer's default is to plan on CPU; we switch to GPU for faster planning. Default is True.
        """
        # Input shape check
        B, H, S, D = q.shape
        qc_num = block_row_sz.shape[-1]
        kc_num = block_col_sz.shape[-1]
        assert block_mask_map.shape == (B, H, qc_num, kc_num)
        if is_cpu:
            assert all(t.device == torch.device("cpu") for t in [block_mask_map, block_row_sz, block_col_sz])
        # Check that block_row_sz / block_col_sz sum to the same sequence length for every (batch, head)
        assert torch.all(block_col_sz.sum(dim=2) == block_col_sz.sum(dim=2)[0, 0])
        assert torch.all(block_row_sz.sum(dim=2) == block_row_sz.sum(dim=2)[0, 0])

        # Prepare flashinfer wrapper
        float_workspace_buffer = torch.empty(128 * 1024 * 1024, device=q.device)
        vector_sparse_indices_buffer = torch.empty(1024 * 1024 * 1024, device=q.device)
        wrapper = flashinfer.sparse.VariableBlockSparseAttentionWrapper(float_workspace_buffer, backend="auto")
        wrapper.reset_workspace_buffer(
            float_workspace_buffer=wrapper._float_workspace_buffer,
            int_workspace_buffer=wrapper._int_workspace_buffer,
            vector_sparse_indices_buffer=vector_sparse_indices_buffer,  # Only reset this buffer size
            vector_sparse_indptr_buffer=wrapper._vector_sparse_indptr_buffer,
        )

        block_mask_map = block_mask_map.reshape(B * H, qc_num, kc_num)
        block_row_sz = block_row_sz.reshape(B * H, qc_num)
        block_col_sz = block_col_sz.reshape(B * H, kc_num)
        wrapper.plan(
            block_mask_map=block_mask_map,
            block_row_sz=block_row_sz,
            block_col_sz=block_col_sz,
            num_qo_heads=B * H,
            num_kv_heads=B * H,
            head_dim=D,
            q_data_type=q.dtype,
            kv_data_type=k.dtype,
        )
        # print_memory_usage("After plan")

        q = q.reshape(B * H, S, D)
        k = k.reshape(B * H, S, D)
        v = v.reshape(B * H, S, D)
        o = wrapper.run(q, k, v)  # [num_qo_heads, qo_len, head_dim]
        o = o.reshape(B, H, S, D)
        return o
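    # [Editor's note] The torch.empty sizes above are element counts, not bytes:
    # with the default float32 dtype, 128 * 1024 * 1024 elements is 512 MiB and
    # 1024 * 1024 * 1024 elements is 4 GiB of workspace on the query's device.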
    def semantic_aware_permutation(self, query, key, value):
        cfg, num_heads, seq_len, dim = query.size()
        # 1. K-means clustering
        qlabels, qcentroids, qcluster_sizes, qiter, klabels, kcentroids, kcluster_sizes, kiter = self.kmeans_clustering(query, key)
        # 2. Identify dynamic map
        q_cluster_sizes = qcluster_sizes.view(cfg, num_heads, self.num_q_centroids)
        k_cluster_sizes = kcluster_sizes.view(cfg, num_heads, self.num_k_centroids)
        dynamic_map = identify_dynamic_map(
            qcentroids.view(cfg, num_heads, self.num_q_centroids, dim),
            kcentroids.view(cfg, num_heads, self.num_k_centroids, dim),
            q_cluster_sizes,
            k_cluster_sizes,
            self.top_p_kmeans,
            self.min_kc_ratio,
        )
        # 3. Permute the query, key, value
        q_permuted, q_sorted_indices = permute_tensor_by_labels_triton(query, qlabels, dim=2)
        k_permuted, k_sorted_indices = permute_tensor_by_labels_triton(key, klabels, dim=2)
        v_permuted, v_sorted_indices = permute_tensor_by_labels_triton(value, klabels, dim=2, sorted_indices=k_sorted_indices)
        return q_permuted, k_permuted, v_permuted, dynamic_map, q_cluster_sizes, k_cluster_sizes, q_sorted_indices
    def kmeans_clustering(self, query, key):
        if not self.centroids_init:
            qlabels, qcentroids, qcluster_sizes, qiter, klabels, kcentroids, kcluster_sizes, kiter = self.kmeans_init(query, key)
            self.centroids_init = True
        else:
            qlabels, qcentroids, qcluster_sizes, qiter, klabels, kcentroids, kcluster_sizes, kiter = self.kmeans_step(query, key)
        return qlabels, qcentroids, qcluster_sizes, qiter, klabels, kcentroids, kcluster_sizes, kiter
    def kmeans_init(self, query, key):
        cfg, num_heads, seq_len, dim = query.size()
        qlabels, qcentroids, qcluster_sizes, qiter = batch_kmeans_Euclid(query.view(cfg * num_heads, seq_len, dim), n_clusters=self.num_q_centroids, max_iters=self.kmeans_iter_init)
        klabels, kcentroids, kcluster_sizes, kiter = batch_kmeans_Euclid(key.view(cfg * num_heads, seq_len, dim), n_clusters=self.num_k_centroids, max_iters=self.kmeans_iter_init)
        self.q_centroids = qcentroids
        self.k_centroids = kcentroids
        return qlabels, qcentroids, qcluster_sizes, qiter, klabels, kcentroids, kcluster_sizes, kiter
    def kmeans_step(self, query, key):
        cfg, num_heads, seq_len, dim = query.size()
        qlabels, qcentroids, qcluster_sizes, qiter = batch_kmeans_Euclid(
            query.view(cfg * num_heads, seq_len, dim),
            n_clusters=self.num_q_centroids,
            max_iters=self.kmeans_iter_step,
            init_centroids=self.q_centroids,
        )
        klabels, kcentroids, kcluster_sizes, kiter = batch_kmeans_Euclid(
            key.view(cfg * num_heads, seq_len, dim),
            n_clusters=self.num_k_centroids,
            max_iters=self.kmeans_iter_step,
            init_centroids=self.k_centroids,
        )
        self.q_centroids = qcentroids
        self.k_centroids = kcentroids
        return qlabels, qcentroids, qcluster_sizes, qiter, klabels, kcentroids, kcluster_sizes, kiter
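    # [Editor's note] Clustering cost is amortized across diffusion steps: the
    # first call runs up to kmeans_iter_init (50) iterations from scratch, while
    # later calls warm-start from the cached centroids and refine for only
    # kmeans_iter_step (2) iterations.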
if __name__ == "__main__":
    q, k, v = (
        torch.randn(32130, 40, 128, dtype=torch.bfloat16).cuda(),
        torch.randn(32130, 40, 128, dtype=torch.bfloat16).cuda(),
        torch.randn(32130, 40, 128, dtype=torch.bfloat16).cuda(),
    )
    svg2_attn = Svg2AttnWeight()
    print("Svg2AttnWeight initialized.")
    out = svg2_attn.apply(q, k, v)
    print(f"out: {out.shape}, {out.dtype}, {out.device}")
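The __main__ block above is a smoke test of the full pipeline (k-means clustering → semantic permutation → block-sparse attention → inverse permutation); it requires a CUDA device and a flashinfer build that provides VariableBlockSparseAttentionWrapper. For intuition, the variable block-sparse layout that plan() consumes can be expanded into an equivalent dense mask; a small sketch (names are illustrative, not part of the commit):

import torch

# Expand one head's variable block-sparse mask into a dense [S, S] boolean mask:
# block (i, j) covers a block_row_sz[i] x block_col_sz[j] tile.
def dense_mask_from_blocks(block_mask: torch.Tensor, row_sz: torch.Tensor, col_sz: torch.Tensor) -> torch.Tensor:
    # block_mask: [qc_num, kc_num] bool; row_sz: [qc_num]; col_sz: [kc_num].
    return block_mask.repeat_interleave(row_sz, dim=0).repeat_interleave(col_sz, dim=1)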
lightx2v/common/ops/attn/svg2_attn_utils.py (new file, mode 100755)

(Diff collapsed: 1,359 additions.)
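The collapsed file supplies the two helpers imported by svg2_attn.py: batch_kmeans_Euclid and identify_dynamic_map. Its contents are not visible here, but the call sites above pin down the k-means interface: input [BH, S, D], keyword arguments n_clusters, max_iters, and optional init_centroids, returning (labels, centroids, cluster_sizes, n_iter). A minimal batched Euclidean k-means matching that interface might look like the following sketch (an illustration under those assumptions, not the file's actual implementation):

import torch
import torch.nn.functional as F

def batch_kmeans_euclid_sketch(x, n_clusters, max_iters, init_centroids=None):
    """x: [BH, S, D]. Returns (labels [BH, S], centroids [BH, K, D],
    cluster_sizes [BH, K], n_iter)."""
    BH, S, D = x.shape
    xf = x.float()
    if init_centroids is None:
        # Seed centroids from randomly chosen tokens of each batch entry.
        idx = torch.stack([torch.randperm(S, device=x.device)[:n_clusters] for _ in range(BH)])
        centroids = torch.gather(xf, 1, idx.unsqueeze(-1).expand(-1, -1, D))
    else:
        centroids = init_centroids.float()
    for _ in range(max_iters):
        # Assignment step: nearest centroid by Euclidean distance.
        labels = torch.cdist(xf, centroids).argmin(dim=-1)       # [BH, S]
        # Update step: mean of the tokens assigned to each cluster.
        one_hot = F.one_hot(labels, n_clusters).float()          # [BH, S, K]
        cluster_sizes = one_hot.sum(dim=1)                       # [BH, K]
        sums = one_hot.transpose(1, 2) @ xf                      # [BH, K, D]
        new_centroids = sums / cluster_sizes.clamp(min=1.0).unsqueeze(-1)
        # Empty clusters keep their previous centroid.
        centroids = torch.where((cluster_sizes == 0).unsqueeze(-1), centroids, new_centroids)
    return labels, centroids.to(x.dtype), cluster_sizes.long(), max_iters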