xuwx1 / LightX2V — Commit 8b1e4f94 (unverified), authored Oct 21, 2025 by Yang Yong (雍洋), committed by GitHub on Oct 21, 2025

Support SVG2 (#384)

parent b20ec092
Showing 4 changed files with 1728 additions and 0 deletions (+1728, -0):

configs/attentions/wan_i2v_svg2.json         +13    -0
lightx2v/common/ops/attn/__init__.py         +1     -0
lightx2v/common/ops/attn/svg2_attn.py        +355   -0
lightx2v/common/ops/attn/svg2_attn_utils.py  +1359  -0
configs/attentions/wan_i2v_svg2.json — new file (0 → 100755) @ 8b1e4f94
{
    "infer_steps": 40,
    "target_video_length": 81,
    "target_height": 480,
    "target_width": 832,
    "self_attn_1_type": "svg2_attn",
    "cross_attn_1_type": "flash_attn3",
    "cross_attn_2_type": "flash_attn3",
    "sample_guide_scale": 5,
    "sample_shift": 3,
    "enable_cfg": true,
    "cpu_offload": false
}
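For reference, a config like this is consumed by loading the JSON and looking up the attention implementation by its registered name. A minimal sketch, assuming plain `json` loading; the loader path and print statements here are illustrative, not the repository's actual entry point:

import json

# Hypothetical loader sketch: read the config and pick attention types by name.
with open("configs/attentions/wan_i2v_svg2.json") as f:
    cfg = json.load(f)

# "svg2_attn" routes self-attention through Svg2AttnWeight (registered below),
# while both cross-attention stages stay on FlashAttention-3.
print(cfg["self_attn_1_type"])   # svg2_attn
print(cfg["cross_attn_1_type"])  # flash_attn3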
lightx2v/common/ops/attn/__init__.py @ 8b1e4f94
@@ -2,6 +2,7 @@ from .flash_attn import FlashAttn2Weight, FlashAttn3Weight
 from .radial_attn import RadialAttnWeight
 from .ring_attn import RingAttnWeight
 from .sage_attn import SageAttn2Weight
+from .svg2_attn import Svg2AttnWeight
 from .svg_attn import SvgAttnWeight
 from .torch_sdpa import TorchSDPAWeight
 from .ulysses_attn import UlyssesAttnWeight
lightx2v/common/ops/attn/svg2_attn.py — new file (0 → 100644) @ 8b1e4f94
from typing import Optional

# Please reinstall flashinfer by referring to https://github.com/svg-project/Sparse-VideoGen
try:
    import flashinfer
except ImportError:
    flashinfer = None

import torch
import triton
import triton.language as tl

from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER

from .svg2_attn_utils import (
    batch_kmeans_Euclid,
    identify_dynamic_map,
)
from .template import AttnWeightTemplate
@triton.jit
def _permute_kernel(
    X_ptr,
    IDX_ptr,
    Y_ptr,
    S: tl.constexpr,
    D: tl.constexpr,
    BLOCK_S: tl.constexpr,
):
    """Each program permutes all hidden features (D) of BLOCK_S tokens. No inner Python loop."""
    pid_bh = tl.program_id(0)
    tile_s = tl.program_id(1)

    # Offsets along sequence
    s_offsets = tile_s * BLOCK_S + tl.arange(0, BLOCK_S)
    token_mask = s_offsets < S

    # Gather source indices for these tokens
    idx_ptrs = IDX_ptr + pid_bh * S + s_offsets
    src_row_idx = tl.load(idx_ptrs, mask=token_mask, other=0).to(tl.int32)

    # Broadcast to create 2-D pointer matrix (BLOCK_S, D)
    d_offsets = tl.arange(0, D)
    src_ptrs = X_ptr + (pid_bh * S + src_row_idx[:, None]) * D + d_offsets[None, :]
    dst_ptrs = Y_ptr + (pid_bh * S + s_offsets[:, None]) * D + d_offsets[None, :]

    full_mask = token_mask[:, None]
    values = tl.load(src_ptrs, mask=full_mask, other=0.0)
    tl.store(dst_ptrs, values, mask=full_mask)
def permute_tensor_by_labels_triton(
    tensor: torch.Tensor,
    labels: Optional[torch.Tensor],
    dim: int,
    *,
    sorted_indices: Optional[torch.Tensor] = None,
):
    """
    Permute `tensor` along `dim` according to ascending order of `labels`.

    This is a Triton-accelerated replacement for the original implementation.
    It currently supports 4-D tensors of shape [B, H, S, D] and `dim == 2`.
    If these conditions are not met or the tensors reside on CPU, we fall back
    to the reference PyTorch implementation.
    """
    # Assertions – we only support the optimized CUDA path.
    assert dim == 2, "permute_tensor_by_labels currently only supports dim==2 (sequence dimension)"
    assert tensor.dim() == 4, "Expected tensor shape [B,H,S,D]"
    assert tensor.is_cuda, "permute_tensor_by_labels requires CUDA tensors"

    B, H, S, D = tensor.shape
    BH = B * H

    # Determine sorted indices
    if sorted_indices is not None:
        sorted_indices = sorted_indices.to(torch.int32).contiguous()
    else:
        assert labels is not None, "Either `labels` or `sorted_indices` must be provided."
        labels = labels.to(tensor.device)
        sorted_indices = torch.argsort(labels, dim=-1).to(torch.int32).contiguous()

    # Flatten tensor and allocate output
    inp_flat = tensor.reshape(BH, S, D).contiguous()
    out_flat = torch.empty_like(inp_flat)

    # Triton kernel tile size
    BLOCK_S = 64  # number of tokens per program, tunable
    n_s_tiles = triton.cdiv(S, BLOCK_S)
    grid = (BH, n_s_tiles)

    _permute_kernel[grid](inp_flat, sorted_indices, out_flat, S, D, BLOCK_S, num_warps=4)

    permuted_tensor = out_flat.reshape(B, H, S, D)
    return permuted_tensor, sorted_indices
@triton.jit
def _inverse_permute_kernel(
    X_ptr,
    IDX_ptr,
    Y_ptr,
    S: tl.constexpr,
    D: tl.constexpr,
    BLOCK_S: tl.constexpr,
):
    """Inverse permutation: scatter BLOCK_S tokens back in one shot."""
    pid_bh = tl.program_id(0)
    tile_s = tl.program_id(1)

    s_offsets = tile_s * BLOCK_S + tl.arange(0, BLOCK_S)
    token_mask = s_offsets < S

    idx_ptrs = IDX_ptr + pid_bh * S + s_offsets
    src_pos_idx = s_offsets.to(tl.int32)
    dst_pos_idx = tl.load(idx_ptrs, mask=token_mask, other=0).to(tl.int32)

    d_offsets = tl.arange(0, D)
    src_ptrs = X_ptr + (pid_bh * S + src_pos_idx[:, None]) * D + d_offsets[None, :]
    dst_ptrs = Y_ptr + (pid_bh * S + dst_pos_idx[:, None]) * D + d_offsets[None, :]

    full_mask = token_mask[:, None]
    values = tl.load(src_ptrs, mask=full_mask, other=0.0)
    tl.store(dst_ptrs, values, mask=full_mask)
def apply_inverse_permutation_triton(
    permuted_tensor: torch.Tensor,
    sorted_indices: torch.Tensor,
    dim: int,
):
    """
    Triton implementation of inverse permutation. Inverts the permutation applied by `permute_tensor_by_labels`.

    Args:
        permuted_tensor: (B, H, S, D).
        sorted_indices: (B, H, S).
        dim: Dimension along which to apply the inverse permutation. Typically 2, i.e. the sequence-length dimension.

    Returns:
        Tensor of shape (B, H, S, D).
    """
    assert dim == 2, "apply_inverse_permutation currently only supports dim==2"
    assert permuted_tensor.dim() == 4, "Expected tensor shape [B,H,S,D]"
    assert permuted_tensor.is_cuda, "apply_inverse_permutation requires CUDA tensors"

    B, H, S, D = permuted_tensor.shape
    BH = B * H

    # Ensure index dtype
    sorted_indices = sorted_indices.to(torch.int32).contiguous()

    # Flatten inputs
    inp_flat = permuted_tensor.reshape(BH, S, D).contiguous()
    out_flat = torch.empty_like(inp_flat)

    BLOCK_S = 64
    n_s_tiles = triton.cdiv(S, BLOCK_S)
    grid = (BH, n_s_tiles)

    _inverse_permute_kernel[grid](inp_flat, sorted_indices, out_flat, S, D, BLOCK_S, num_warps=4)

    original_tensor = out_flat.reshape(B, H, S, D)
    return original_tensor
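The two Triton paths above are exact inverses: permuting by labels and then applying the inverse permutation with the returned indices must reproduce the input bit-for-bit, since the kernels only move data. A small round-trip check, a sketch assuming a CUDA device with Triton available:

# Round-trip sanity check for the permutation pair (hypothetical shapes).
x = torch.randn(1, 2, 512, 64, device="cuda", dtype=torch.float16)
labels = torch.randint(0, 10, (2, 512), device="cuda")  # one label row per (b, h)

x_perm, idx = permute_tensor_by_labels_triton(x, labels, dim=2)
x_back = apply_inverse_permutation_triton(x_perm, idx, dim=2)
assert torch.equal(x_back, x)  # exact equality: no arithmetic, only data movement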
@ATTN_WEIGHT_REGISTER("svg2_attn")
class Svg2AttnWeight(AttnWeightTemplate):
    centroids_init = False
    num_q_centroids = 300
    num_k_centroids = 1000
    kmeans_iter_init = 50
    top_p_kmeans = 0.9
    min_kc_ratio = 0.10
    kmeans_iter_step = 2

    def __init__(self):
        self.config = {}

    def apply(
        self,
        q,
        k,
        v,
        cu_seqlens_q=None,
        cu_seqlens_kv=None,
        max_seqlen_q=None,
        max_seqlen_kv=None,
        model_cls=None,
    ):
        q = q.unsqueeze(0).transpose(1, 2)
        k = k.unsqueeze(0).transpose(1, 2)
        v = v.unsqueeze(0).transpose(1, 2)
        bs, num_heads, seq_len, dim = q.size()
        q_perm, k_perm, v_perm, dyn_map, qc_sz_s, kc_sz_s, q_sorted_indices = self.semantic_aware_permutation(q, k, v)
        output_permuted = self.dynamic_block_sparse_fwd_flashinfer(q_perm, k_perm, v_perm, dyn_map, qc_sz_s, kc_sz_s, is_cpu=False)
        attn_output = apply_inverse_permutation_triton(output_permuted, q_sorted_indices, dim=2)
        return attn_output.reshape(bs, num_heads, seq_len, dim).transpose(1, 2).reshape(bs * seq_len, -1)
    def dynamic_block_sparse_fwd_flashinfer(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        block_mask_map: torch.Tensor,
        block_row_sz: torch.Tensor,
        block_col_sz: torch.Tensor,
        is_cpu: bool = True,
    ):
        """
        Launcher for the FlashInfer dynamic block sparse attention kernel.

        Args:
            q (torch.Tensor): Query tensor, shape [B, H, S, D].
            k (torch.Tensor): Key tensor, shape [B, H, S, D].
            v (torch.Tensor): Value tensor, shape [B, H, S, D].
            block_mask_map (torch.Tensor): Boolean mask, shape [B, H, qc_num, kc_num]. Currently must be on CPU.
            block_row_sz (torch.Tensor): Query block sizes, shape [B, H, qc_num]. Currently must be on CPU.
            block_col_sz (torch.Tensor): Key block sizes, shape [B, H, kc_num]. Currently must be on CPU.
            is_cpu (bool): Whether the block metadata lives on CPU. FlashInfer's default is CPU; we switch to GPU for faster planning. Default is True.
        """
        # Input shape check
        B, H, S, D = q.shape
        qc_num = block_row_sz.shape[-1]
        kc_num = block_col_sz.shape[-1]
        assert block_mask_map.shape == (B, H, qc_num, kc_num)
        assert all(t.device == torch.device("cpu") for t in [block_mask_map, block_row_sz, block_col_sz]) if is_cpu else True
        # Check that block_col_sz and block_row_sz sum to the same total for every head
        assert torch.all(block_col_sz.sum(dim=2) == block_col_sz.sum(dim=2)[0, 0])
        assert torch.all(block_row_sz.sum(dim=2) == block_row_sz.sum(dim=2)[0, 0])

        # Prepare flashinfer wrapper
        float_workspace_buffer = torch.empty(128 * 1024 * 1024, device=q.device)
        vector_sparse_indices_buffer = torch.empty(1024 * 1024 * 1024, device=q.device)
        wrapper = flashinfer.sparse.VariableBlockSparseAttentionWrapper(float_workspace_buffer, backend="auto")
        wrapper.reset_workspace_buffer(
            float_workspace_buffer=wrapper._float_workspace_buffer,
            int_workspace_buffer=wrapper._int_workspace_buffer,
            vector_sparse_indices_buffer=vector_sparse_indices_buffer,  # Only reset this buffer size
            vector_sparse_indptr_buffer=wrapper._vector_sparse_indptr_buffer,
        )

        block_mask_map = block_mask_map.reshape(B * H, qc_num, kc_num)
        block_row_sz = block_row_sz.reshape(B * H, qc_num)
        block_col_sz = block_col_sz.reshape(B * H, kc_num)

        wrapper.plan(
            block_mask_map=block_mask_map,
            block_row_sz=block_row_sz,
            block_col_sz=block_col_sz,
            num_qo_heads=B * H,
            num_kv_heads=B * H,
            head_dim=D,
            q_data_type=q.dtype,
            kv_data_type=k.dtype,
        )
        # print_memory_usage("After plan")

        q = q.reshape(B * H, S, D)
        k = k.reshape(B * H, S, D)
        v = v.reshape(B * H, S, D)
        o = wrapper.run(q, k, v)  # [num_qo_heads, qo_len, head_dim]
        o = o.reshape(B, H, S, D)
        return o
    def semantic_aware_permutation(self, query, key, value):
        cfg, num_heads, seq_len, dim = query.size()
        # 1. Kmeans clustering
        qlabels, qcentroids, qcluster_sizes, qiter, klabels, kcentroids, kcluster_sizes, kiter = self.kmeans_clustering(query, key)
        # 2. Identify dynamic map
        q_cluster_sizes = qcluster_sizes.view(cfg, num_heads, self.num_q_centroids)
        k_cluster_sizes = kcluster_sizes.view(cfg, num_heads, self.num_k_centroids)
        dynamic_map = identify_dynamic_map(
            qcentroids.view(cfg, num_heads, self.num_q_centroids, dim),
            kcentroids.view(cfg, num_heads, self.num_k_centroids, dim),
            q_cluster_sizes,
            k_cluster_sizes,
            self.top_p_kmeans,
            self.min_kc_ratio,
        )
        # 3. Permute the query, key, value
        q_permuted, q_sorted_indices = permute_tensor_by_labels_triton(query, qlabels, dim=2)
        k_permuted, k_sorted_indices = permute_tensor_by_labels_triton(key, klabels, dim=2)
        v_permuted, v_sorted_indices = permute_tensor_by_labels_triton(value, klabels, dim=2, sorted_indices=k_sorted_indices)
        return q_permuted, k_permuted, v_permuted, dynamic_map, q_cluster_sizes, k_cluster_sizes, q_sorted_indices

    def kmeans_clustering(self, query, key):
        if not self.centroids_init:
            qlabels, qcentroids, qcluster_sizes, qiter, klabels, kcentroids, kcluster_sizes, kiter = self.kmeans_init(query, key)
            self.centroids_init = True
        else:
            qlabels, qcentroids, qcluster_sizes, qiter, klabels, kcentroids, kcluster_sizes, kiter = self.kmeans_step(query, key)
        return qlabels, qcentroids, qcluster_sizes, qiter, klabels, kcentroids, kcluster_sizes, kiter

    def kmeans_init(self, query, key):
        cfg, num_heads, seq_len, dim = query.size()
        qlabels, qcentroids, qcluster_sizes, qiter = batch_kmeans_Euclid(query.view(cfg * num_heads, seq_len, dim), n_clusters=self.num_q_centroids, max_iters=self.kmeans_iter_init)
        klabels, kcentroids, kcluster_sizes, kiter = batch_kmeans_Euclid(key.view(cfg * num_heads, seq_len, dim), n_clusters=self.num_k_centroids, max_iters=self.kmeans_iter_init)
        self.q_centroids = qcentroids
        self.k_centroids = kcentroids
        return qlabels, qcentroids, qcluster_sizes, qiter, klabels, kcentroids, kcluster_sizes, kiter

    def kmeans_step(self, query, key):
        cfg, num_heads, seq_len, dim = query.size()
        qlabels, qcentroids, qcluster_sizes, qiter = batch_kmeans_Euclid(
            query.view(cfg * num_heads, seq_len, dim),
            n_clusters=self.num_q_centroids,
            max_iters=self.kmeans_iter_step,
            init_centroids=self.q_centroids,
        )
        klabels, kcentroids, kcluster_sizes, kiter = batch_kmeans_Euclid(
            key.view(cfg * num_heads, seq_len, dim),
            n_clusters=self.num_k_centroids,
            max_iters=self.kmeans_iter_step,
            init_centroids=self.k_centroids,
        )
        self.q_centroids = qcentroids
        self.k_centroids = kcentroids
        return qlabels, qcentroids, qcluster_sizes, qiter, klabels, kcentroids, kcluster_sizes, kiter
if __name__ == "__main__":
    q, k, v = (
        torch.randn(32130, 40, 128, dtype=torch.bfloat16).cuda(),
        torch.randn(32130, 40, 128, dtype=torch.bfloat16).cuda(),
        torch.randn(32130, 40, 128, dtype=torch.bfloat16).cuda(),
    )
    svg2_attn = Svg2AttnWeight()
    print("Svg2AttnWeight initialized.")
    out = svg2_attn.apply(q, k, v)
    print(f"out: {out.shape}, {out.dtype}, {out.device}")
lightx2v/common/ops/attn/svg2_attn_utils.py — new file (0 → 100755) @ 8b1e4f94
import torch
import torch.nn.functional as F
import triton
import triton.language as tl

try:
    from cuvs.cluster.kmeans import KMeansParams, fit
except ImportError:
    KMeansParams = None
    fit = None
# --- New functions ---
def density_calculation(dynamic_map, q_cluster_sizes, k_cluster_sizes):
    """
    Calculate the density of the dynamic map. Currently only batch size = 1 and head size = 1 are supported.

    Input:
        dynamic_map: [cfg, num_heads, qc_num, kc_num]
        q_cluster_sizes: [cfg, num_heads, qc_num]
        k_cluster_sizes: [cfg, num_heads, kc_num]
    """
    cfg, num_heads, qc_num, kc_num = dynamic_map.shape
    # Calculate the area (in tokens) of each block
    clustered_block_size = q_cluster_sizes[:, :, :, None] * k_cluster_sizes[:, :, None, :]
    masked_block_size = clustered_block_size * dynamic_map
    # Overall density = kept block area / total block area
    density = torch.sum(masked_block_size, dim=(2, 3)) / torch.sum(clustered_block_size, dim=(2, 3))
    return density
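As a quick check of what `density_calculation` reports: with query clusters of sizes [2, 3], key clusters of sizes [1, 4], and only the (0, 1) block active, the kept area is 2 * 4 = 8 out of (2 + 3) * (1 + 4) = 25 total, i.e. density 0.32. A toy example (runs on CPU):

# Toy density check: one active block out of a 2x2 cluster grid.
dmap = torch.tensor([[[[False, True], [False, False]]]])  # [1, 1, 2, 2]
q_sz = torch.tensor([[[2, 3]]])                           # [1, 1, 2]
k_sz = torch.tensor([[[1, 4]]])                           # [1, 1, 2]
print(density_calculation(dmap, q_sz, k_sz))              # tensor([[0.3200]])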
# --- Functions from analyze/kmeans_rapidai.py ---
def pairwise_distance(x, y):
    """
    Computes pairwise squared Euclidean distance between two sets of points.
    """
    x_norm = (x**2).sum(1).view(-1, 1)
    y_norm = (y**2).sum(1).view(1, -1)
    dist = torch.clamp(x_norm + y_norm - 2.0 * torch.mm(x, torch.transpose(y, 0, 1)), min=0.0)
    return dist


def kmeans_predict(centroids, input_tensor):  # Removed unused params argument
    """
    Predict the labels for the input tensor using the centroids.
    """
    input_tensor = input_tensor.to(torch.float32)
    dist = pairwise_distance(input_tensor, centroids)
    labels = torch.argmin(dist, dim=1)
    return labels
def kmeans_rapidai(tensor, k, max_iter=5, tol=1e-4, init_method="Array", centroids_init=None):  # Renamed centroids to centroids_init
    """
    Performs K-means clustering using cuVS.
    """
    assert tensor.dtype == torch.float32, "Tensor must be float32 for cuVS KMeans"
    assert tensor.ndim == 2, f"Tensor must be 2D, but got {tensor.shape}"
    # assert init_method == "Array", "init_method must be 'Array' for now"
    L, D = tensor.shape

    # cuVS KMeans in RAPIDS >=23.10 uses 'centroids_init' for initial centroids
    current_centroids = centroids_init
    if current_centroids is None:
        # Default init: cuVS handles KMeansPlusPlus if centroids_init is None and init_method is KMeansPlusPlus
        # If you need to pass an empty tensor for cuVS to initialize:
        current_centroids = torch.empty(k, D, device=tensor.device, dtype=torch.float32)  # Or pass None
    else:
        assert current_centroids.dtype == torch.float32, "Initial centroids must be float32"
        assert current_centroids.shape == (k, D), f"Initial centroids shape mismatch, got {current_centroids.shape}, expected ({k}, {D})"
        # cuVS uses 'init_method="Array"' when 'centroids_init' is provided.

    # import IPython; IPython.embed()
    params = KMeansParams(n_clusters=k, max_iter=max_iter, tol=tol, init_method=init_method)  # Changed init_method to init
    # Call fit with centroids_init (can be None)
    new_centroids, inertia, n_iter_ = fit(params, tensor, current_centroids)  # Added handle=None
    labels = kmeans_predict(new_centroids, tensor)
    return labels, new_centroids, n_iter_
@triton.jit
def _centroid_update_kernel(
    x_ptr,  # *f16 [B, N, D]
    cluster_ptr,  # *i32 [B, N]
    sum_ptr,  # *f32 [B, K, D]
    count_ptr,  # *i32 [B, K]
    B: tl.constexpr,
    N: tl.constexpr,
    D: tl.constexpr,
    K: tl.constexpr,
    BLOCK_D: tl.constexpr,  # number of dims processed per program
):
    """Each program processes 1 point (token) across BLOCK_D dimensions with atomics."""
    pid = tl.program_id(axis=0)
    token_idx = pid  # range: [0, B * N)

    # Derive (b, n) indices
    b = token_idx // N
    n = token_idx % N

    # Pointers to the token features and its cluster id
    x_offset = (b * N + n) * D
    x_ptr = x_ptr + x_offset
    cluster_idx = tl.load(cluster_ptr + b * N + n)  # int32

    # Guard for invalid cluster ids (should not happen)
    cluster_idx = tl.where(cluster_idx < K, cluster_idx, 0)

    # Base pointer for this centroid in the output sum tensor
    centroid_base = (b * K + cluster_idx) * D

    # Process feature vector in chunks of BLOCK_D
    offs = tl.arange(0, BLOCK_D)
    for d_start in range(0, D, BLOCK_D):
        mask = offs + d_start < D
        feats = tl.load(x_ptr + d_start + offs, mask=mask, other=0.0)
        feats = feats.to(tl.float32)
        dest_ptr = sum_ptr + centroid_base + d_start + offs
        tl.atomic_add(dest_ptr, feats, mask=mask)

    # Update counts (only once per point)
    tl.atomic_add(count_ptr + b * K + cluster_idx, 1)
def triton_centroid_update_cosine(x_norm: torch.Tensor, cluster_ids: torch.Tensor, old_centroids: torch.Tensor):
    """Compute centroids using a custom Triton kernel.

    Args:
        x_norm (Tensor): (B, N, D) normalized input vectors (float16/float32)
        cluster_ids (LongTensor): (B, N) cluster assignment per point
        old_centroids (Tensor): (B, K, D) previous centroids (same dtype as x_norm)

    Returns:
        Tensor: (B, K, D) updated and L2-normalized centroids (dtype == x_norm.dtype)
    """
    assert x_norm.is_cuda and cluster_ids.is_cuda, "Input tensors must be on CUDA device"
    B, N, D = x_norm.shape
    K = old_centroids.shape[1]
    assert cluster_ids.shape == (B, N)

    # Allocate accumulation buffers
    centroid_sums = torch.zeros((B, K, D), device=x_norm.device, dtype=torch.float32)
    centroid_counts = torch.zeros((B, K), device=x_norm.device, dtype=torch.int32)

    # Launch Triton kernel – one program per token
    total_tokens = B * N
    BLOCK_D = 128  # tuneable
    grid = (total_tokens,)
    _centroid_update_kernel[grid](
        x_norm,
        cluster_ids.to(torch.int32),
        centroid_sums,
        centroid_counts,
        B,
        N,
        D,
        K,
        BLOCK_D=BLOCK_D,
    )

    # Compute means; keep old centroid if empty cluster
    counts_f = centroid_counts.to(torch.float32).unsqueeze(-1).clamp(min=1.0)
    centroids = centroid_sums / counts_f
    # For clusters with zero count, revert to old centroids
    zero_mask = (centroid_counts == 0).unsqueeze(-1)
    centroids = torch.where(zero_mask, old_centroids.to(torch.float32), centroids)
    centroids = centroids.to(x_norm.dtype)
    centroids = F.normalize(centroids, p=2, dim=-1)
    return centroids
def torch_loop_centroid_update_cosine(x_norm: torch.Tensor, cluster_ids: torch.Tensor, old_centroids: torch.Tensor):
    """Reference Python implementation (double for-loop)."""
    B, N, D = x_norm.shape
    K = old_centroids.shape[1]
    new_centroids = torch.zeros_like(old_centroids)
    for b in range(B):
        for k in range(K):
            mask = cluster_ids[b] == k
            if mask.any():
                new_centroids[b, k] = F.normalize(x_norm[b][mask].mean(dim=0, dtype=x_norm.dtype), p=2, dim=0)
            else:
                new_centroids[b, k] = old_centroids[b, k]
    return new_centroids
def triton_centroid_update_euclid(x: torch.Tensor, cluster_ids: torch.Tensor, old_centroids: torch.Tensor):
    """Compute centroids for Euclidean KMeans using Triton.

    Args:
        x (Tensor): (B, N, D) input vectors (float16/float32)
        cluster_ids (LongTensor): (B, N) cluster assignment per point
        old_centroids (Tensor): (B, K, D) previous centroids (same dtype as x)

    Returns:
        Tensor: (B, K, D) updated centroids (dtype == x.dtype)
    """
    assert x.is_cuda and cluster_ids.is_cuda, "Input tensors must be on CUDA device"
    B, N, D = x.shape
    K = old_centroids.shape[1]
    assert cluster_ids.shape == (B, N)

    # Allocate accumulation buffers
    centroid_sums = torch.zeros((B, K, D), device=x.device, dtype=torch.float32)
    centroid_counts = torch.zeros((B, K), device=x.device, dtype=torch.int32)

    total_tokens = B * N
    BLOCK_D = 128  # tuneable
    grid = (total_tokens,)
    _centroid_update_kernel[grid](
        x,
        cluster_ids.to(torch.int32),
        centroid_sums,
        centroid_counts,
        B,
        N,
        D,
        K,
        BLOCK_D=BLOCK_D,
    )

    # Compute means; keep old centroid if empty cluster
    counts_f = centroid_counts.to(torch.float32).unsqueeze(-1).clamp(min=1.0)
    centroids = centroid_sums / counts_f
    # For clusters with zero count, revert to old centroids
    zero_mask = (centroid_counts == 0).unsqueeze(-1)
    centroids = torch.where(zero_mask, old_centroids.to(torch.float32), centroids)
    return centroids.to(x.dtype)
# ------------------------------ NEW: chunk-wise centroid update (sorted ids) ------------------------------
@triton.jit
def _centroid_update_chunk_kernel(
    x_ptr,  # *f16 / *f32 [B, N, D] – ORIGINAL ORDER
    sorted_idx_ptr,  # *i32 [B, N] – indices after sort
    sorted_cluster_ptr,  # *i32 [B, N] – cluster ids in sorted order
    sum_ptr,  # *f32 [B, K, D]
    count_ptr,  # *i32 [B, K]
    B: tl.constexpr,
    N: tl.constexpr,
    D: tl.constexpr,
    K: tl.constexpr,
    BLOCK_N: tl.constexpr,  # how many tokens (points) each program processes
):
    """Each program processes **BLOCK_N consecutive, already-sorted tokens**.

    Because the tokens are sorted by cluster id, identical ids appear in
    contiguous runs. We therefore accumulate a local sum/count for the
    current run and perform **a single atomic update per run**, instead of
    per-token.
    """
    # program indices – 2-D launch grid: (chunk_id, batch_id)
    pid_chunk = tl.program_id(axis=0)
    pid_b = tl.program_id(axis=1)

    b = pid_b
    chunk_start = pid_chunk * BLOCK_N  # position of the first token handled by this program

    # Nothing to do – out of range
    if chunk_start >= N:
        return

    # base pointers for this batch
    idx_batch_base = sorted_idx_ptr + b * N
    cid_batch_base = sorted_cluster_ptr + b * N
    x_batch_base = x_ptr + b * N * D  # for pointer arithmetic

    # helper aranges
    offs_token = tl.arange(0, BLOCK_N)
    offs_dim = tl.arange(0, D)

    # first token index & validity mask
    token_idx = chunk_start + offs_token
    valid_tok = token_idx < N

    first_token_idx = chunk_start
    last_token_idx = tl.minimum(chunk_start + BLOCK_N, N) - 1

    # Load first cluster id to initialise the running accumulator
    first_id = tl.load(cid_batch_base + first_token_idx)
    last_id = tl.load(cid_batch_base + last_token_idx)
    all_ids = tl.load(cid_batch_base + token_idx, mask=valid_tok, other=-1)
    all_tokens_idxs = tl.load(idx_batch_base + token_idx, mask=valid_tok, other=-1)  # [BLOCK_N]
    load_mask = all_tokens_idxs[:, None] * D + offs_dim[None, :]

    for cid in range(first_id, last_id + 1):
        cluster_mask = all_ids == cid
        cluster_size = tl.sum(cluster_mask.to(tl.int32))
        if cluster_size != 0:
            cluster_feats = tl.load(x_batch_base + load_mask, mask=cluster_mask[:, None], other=0.0)  # [BLOCK_N, D]
            cluster_feats = cluster_feats.to(tl.float32)
            sum_feats = tl.sum(cluster_feats, axis=0)
            dest_ptr = sum_ptr + (b * K + cid) * D + offs_dim
            tl.atomic_add(dest_ptr, sum_feats)
            tl.atomic_add(count_ptr + b * K + cid, cluster_size)
# ---------------------------------------------------------------------------------------------
def triton_centroid_update_sorted_cosine(x_norm: torch.Tensor, cluster_ids: torch.Tensor, old_centroids: torch.Tensor, *, BLOCK_N: int = 256):
    """Fast centroid update assuming **cluster_ids are sorted along N**.

    This helper will sort the assignments (together with `x_norm`) and launch the
    chunk kernel above. Compared to the naive per-token kernel it performs *one
    atomic add per run of identical ids* instead of per token, providing large
    speed-ups when clusters are reasonably sized.
    """
    assert x_norm.is_cuda and cluster_ids.is_cuda, "Inputs must be on CUDA"
    B, N, D = x_norm.shape
    K = old_centroids.shape[1]
    assert cluster_ids.shape == (B, N)

    # -------- sort per-batch --------
    sorted_cluster_ids, sorted_idx = torch.sort(cluster_ids, dim=-1)
    sorted_idx_int = sorted_idx.to(torch.int32)

    # accumulation buffers
    centroid_sums = torch.zeros((B, K, D), device=x_norm.device, dtype=torch.float32)
    centroid_cnts = torch.zeros((B, K), device=x_norm.device, dtype=torch.int32)

    grid = (triton.cdiv(N, BLOCK_N), B)
    _centroid_update_chunk_kernel[grid](
        x_norm,
        sorted_idx_int,
        sorted_cluster_ids.to(torch.int32),
        centroid_sums,
        centroid_cnts,
        B,
        N,
        D,
        K,
        BLOCK_N=BLOCK_N,
    )

    # finalise – convert to means, handle empty clusters, renormalise
    counts_f = centroid_cnts.to(torch.float32).unsqueeze(-1).clamp(min=1.0)
    centroids = centroid_sums / counts_f
    empty_mask = (centroid_cnts == 0).unsqueeze(-1)
    centroids = torch.where(empty_mask, old_centroids.to(torch.float32), centroids)
    centroids = centroids.to(x_norm.dtype)
    centroids = F.normalize(centroids, p=2, dim=-1)
    return centroids
def triton_centroid_update_sorted_euclid(x: torch.Tensor, cluster_ids: torch.Tensor, old_centroids: torch.Tensor, *, BLOCK_N: int = 256):
    """Fast centroid update for *Euclidean* KMeans assuming cluster IDs are pre-sorted.

    Parameters
    ----------
    x : Tensor [B, N, D]
        Input feature vectors (no normalization assumed).
    cluster_ids : LongTensor [B, N]
        Cluster assignment for each point.
    old_centroids : Tensor [B, K, D]
        Previous centroids (used to fill empty clusters).
    BLOCK_N : int, optional
        Tokens per Triton program (affects occupancy/perf).
    """
    assert x.is_cuda and cluster_ids.is_cuda, "Inputs must be on CUDA device"
    B, N, D = x.shape
    K = old_centroids.shape[1]

    # Batch-wise sort of cluster assignments
    sorted_cluster_ids, sorted_idx = torch.sort(cluster_ids, dim=-1)
    sorted_idx_int = sorted_idx.to(torch.int32)

    centroid_sums = torch.zeros((B, K, D), device=x.device, dtype=torch.float32)
    centroid_cnts = torch.zeros((B, K), device=x.device, dtype=torch.int32)

    grid = (triton.cdiv(N, BLOCK_N), B)
    _centroid_update_chunk_kernel[grid](
        x,  # original features
        sorted_idx_int,  # gather indices
        sorted_cluster_ids.to(torch.int32),
        centroid_sums,
        centroid_cnts,
        B,
        N,
        D,
        K,
        BLOCK_N=BLOCK_N,
    )

    # Convert sums to means; replace empty clusters with old centroids
    counts_f = centroid_cnts.to(torch.float32).unsqueeze(-1).clamp(min=1.0)
    centroids = centroid_sums / counts_f
    empty_mask = (centroid_cnts == 0).unsqueeze(-1)
    centroids = torch.where(empty_mask, old_centroids.to(torch.float32), centroids)
    return centroids.to(x.dtype), centroid_cnts
# ===============================================================
# Triton kernel: compute nearest-centroid IDs (Euclidean distance)
# Inputs:
#   x         : (B, N, D) float16 / float32
#   centroids : (B, K, D) same dtype as x
#   x_sq      : (B, N) float32 – pre-computed ||x||^2 per point
# Output:
#   cluster_ids : (B, N) int32 – nearest centroid index per point
# ===============================================================
def _ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b


# -----------------------------------------------------------------------------
# Auto-tuning setup – explore various tile sizes / warp counts
# -----------------------------------------------------------------------------
_TUNE_CONFIGS = [triton.Config({"BLOCK_N": BN, "BLOCK_K": BK}, num_stages=4, num_warps=wp) for BN in [32, 64, 128] for BK in [32, 64, 128] for wp in [4, 8]]


def _cfg_keep(conf):
    """Basic heuristic to prune unbalanced configs."""
    BN = conf.kwargs["BLOCK_N"]
    BK = conf.kwargs["BLOCK_K"]
    # Avoid tiny tiles on many warps
    if BN * BK < 32 * 32 and conf.num_warps > 4:
        return False
    return True


_TUNE_CONFIGS = list(filter(_cfg_keep, _TUNE_CONFIGS))
@triton.autotune(_TUNE_CONFIGS, key=["N", "K"])
@triton.jit
def _euclid_assign_kernel(
    x_ptr,  # *f16 / *f32 [B, N, D]
    c_ptr,  # *f16 / *f32 [B, K, D]
    x_sq_ptr,  # *f32 [B, N]
    out_ptr,  # *i32 [B, N]
    B: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    D: tl.constexpr,
    stride_x_b: tl.constexpr,
    stride_x_n: tl.constexpr,
    stride_x_d: tl.constexpr,
    stride_c_b: tl.constexpr,
    stride_c_k: tl.constexpr,
    stride_c_d: tl.constexpr,
    stride_xsq_b: tl.constexpr,
    stride_xsq_n: tl.constexpr,
    stride_out_b: tl.constexpr,
    stride_out_n: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    """Each program handles a tile of BLOCK_N points for a given batch element.

    The kernel iterates over the centroid dimension K in chunks of BLOCK_K and
    maintains the running minimum distance as well as the corresponding index
    for every point in the tile.
    """
    pid_n = tl.program_id(0)  # tile index along N dimension
    pid_b = tl.program_id(1)  # batch index

    n_start = pid_n * BLOCK_N
    n_offsets = n_start + tl.arange(0, BLOCK_N)
    n_mask = n_offsets < N

    # ------------------------------------------------------------------
    # Load x tile (BLOCK_N, D)
    # ------------------------------------------------------------------
    offs_d = tl.arange(0, D)
    # Compute pointer for x block: base + b*stride_x_b + n*stride_x_n + d*stride_x_d
    x_ptrs = x_ptr + pid_b * stride_x_b + n_offsets[:, None] * stride_x_n + offs_d[None, :] * stride_x_d
    x_tile = tl.load(x_ptrs, mask=n_mask[:, None], other=0.0)
    x_tile = x_tile  # compute in f32

    # Pre-load x_sq for the tile (BLOCK_N,)
    xsq_ptrs = x_sq_ptr + pid_b * stride_xsq_b + n_offsets * stride_xsq_n
    x_sq_tile = tl.load(xsq_ptrs, mask=n_mask, other=0.0).to(tl.float32)

    # Init best distance / index
    best_dist = tl.full((BLOCK_N,), 3.4e38, tl.float32)  # large number
    best_idx = tl.zeros((BLOCK_N,), tl.int32)

    # ------------------------------------------------------------------
    # Iterate over centroids in chunks of BLOCK_K
    # ------------------------------------------------------------------
    for k_start in range(0, K, BLOCK_K):
        k_offsets = k_start + tl.arange(0, BLOCK_K)
        k_mask = k_offsets < K

        # Load centroid tile (D, BLOCK_K)
        c_ptrs = c_ptr + pid_b * stride_c_b + k_offsets[None, :] * stride_c_k + offs_d[:, None] * stride_c_d
        c_tile = tl.load(c_ptrs, mask=k_mask[None, :], other=0.0)
        c_tile = c_tile

        # Compute centroid squared norms (BLOCK_K,)
        cent_sq = tl.sum(c_tile * c_tile, axis=0).to(tl.float32)

        # Compute cross term (BLOCK_N, BLOCK_K) = x_tile @ c_tile
        cross = tl.dot(x_tile, c_tile).to(tl.float32)  # float32

        # Squared Euclidean distance
        dist = x_sq_tile[:, None] + cent_sq[None, :] - 2.0 * cross
        dist = tl.maximum(dist, 0.0)

        # Mask out invalid centroid columns before reduction
        dist = tl.where(k_mask[None, :], dist, 3.4e38)

        curr_min = tl.min(dist, axis=1)
        curr_idx = tl.argmin(dist, axis=1)
        update = curr_min < best_dist
        best_dist = tl.where(update, curr_min, best_dist)
        best_idx = tl.where(update, k_start + curr_idx, best_idx)

    # ------------------------------------------------------------------
    # Write results
    # ------------------------------------------------------------------
    out_ptrs = out_ptr + pid_b * stride_out_b + n_offsets * stride_out_n
    tl.store(out_ptrs, best_idx, mask=n_mask)
# ---------------------------------------------------------------
# Python wrapper
# ---------------------------------------------------------------
def euclid_assign_triton(
    x: torch.Tensor,
    centroids: torch.Tensor,
    x_sq: torch.Tensor,
    out: torch.Tensor = None,
    *,
    BLOCK_N: int = 128,
    BLOCK_K: int = 128,
) -> torch.Tensor:
    """Return nearest-centroid indices using the Triton kernel.

    Args:
        x         : (B, N, D) float16 / float32 (on CUDA)
        centroids : (B, K, D) same dtype/device as x
        x_sq      : (B, N) float32 – ||x||^2 per point (on CUDA)

    Returns:
        cluster_ids (B, N) – allocated as int64 by default; the kernel itself computes int32 indices
    """
    assert x.is_cuda and centroids.is_cuda and x_sq.is_cuda, "All tensors must be on CUDA"
    # assert x.dtype in (torch.float16, torch.float32), "x must be fp16/fp32"
    assert centroids.dtype == x.dtype, "centroids dtype mismatch"

    B, N, D = x.shape
    K = centroids.shape[1]
    assert centroids.shape == (B, K, D), "centroids shape mismatch"
    assert x_sq.shape == (B, N), "x_sq shape mismatch"

    # x = x.contiguous()
    # centroids = centroids.contiguous()
    # x_sq = x_sq.contiguous()

    if out is None:
        out = torch.empty((B, N), device=x.device, dtype=torch.int64)

    # Strides (in elements)
    stride_x_b, stride_x_n, stride_x_d = x.stride()
    stride_c_b, stride_c_k, stride_c_d = centroids.stride()
    stride_xsq_b, stride_xsq_n = x_sq.stride()
    stride_out_b, stride_out_n = out.stride()

    grid = lambda META: (triton.cdiv(N, META["BLOCK_N"]), B)  # noqa
    _euclid_assign_kernel[grid](
        x,
        centroids,
        x_sq,
        out,
        B,
        N,
        K,
        D,
        stride_x_b,
        stride_x_n,
        stride_x_d,
        stride_c_b,
        stride_c_k,
        stride_c_d,
        stride_xsq_b,
        stride_xsq_n,
        stride_out_b,
        stride_out_n,
    )
    return out
# 1. Euclidean
def _euclid_iter(x, x_sq, centroids):
    # cent_sq = (centroids ** 2).sum(dim=-1)
    # cross = torch.einsum('bnd,bkd->bnk', x, centroids)
    # dist_sq = (x_sq[:,:,None] + cent_sq[:,None,:] - 2.0 * cross).clamp_min_(0.0)
    # cluster_ids = dist_sq.argmin(dim=-1)
    cluster_ids = euclid_assign_triton(x, centroids, x_sq)
    centroids_new, cluster_sizes = triton_centroid_update_sorted_euclid(x, cluster_ids, centroids)
    # centroids_new = triton_centroid_update_euclid(x, cluster_ids, centroids)
    # centroids_new = centroids_new.clone()  # avoid CUDA graphs aliasing
    shift = (centroids_new - centroids).norm(dim=-1).max()
    return centroids_new, shift, cluster_ids, cluster_sizes


# 2. Cosine
def _cosine_iter(x_norm, centroids):
    cos_sim = torch.einsum("bnd,bkd->bnk", x_norm, centroids)
    cluster_ids = cos_sim.argmax(dim=-1)
    centroids_new = triton_centroid_update_cosine(x_norm, cluster_ids, centroids)
    # centroids_new = centroids_new.clone()
    shift = (centroids_new - centroids).norm(dim=-1).max()
    return centroids_new, shift, cluster_ids


# 3. Dot-product
def _dot_iter(x, centroids):
    sim = torch.einsum("bnd,bkd->bnk", x, centroids)
    cluster_ids = sim.argmax(dim=-1)
    centroids_new = triton_centroid_update_cosine(x, cluster_ids, centroids)
    # centroids_new = centroids_new.clone()
    shift = (centroids_new - centroids).norm(dim=-1).max()
    return centroids_new, shift, cluster_ids
COMPILE_FLAG = False
# Try to compile; if PyTorch < 2.0 or compile is not available, fall back to the original functions
try:
    if COMPILE_FLAG:
        _euclid_iter_compiled = torch.compile(_euclid_iter, dynamic=True, mode="reduce-overhead")
        _cosine_iter_compiled = torch.compile(_cosine_iter, dynamic=True, mode="reduce-overhead")
        _dot_iter_compiled = torch.compile(_dot_iter, dynamic=True, mode="reduce-overhead")
    else:
        _euclid_iter_compiled = _euclid_iter
        _cosine_iter_compiled = _cosine_iter
        _dot_iter_compiled = _dot_iter
except Exception:  # pragma: no cover
    _euclid_iter_compiled = _euclid_iter
    _cosine_iter_compiled = _cosine_iter
    _dot_iter_compiled = _dot_iter
def batch_kmeans_Euclid(x, n_clusters, max_iters=100, tol=1e-4, init_centroids=None, verbose=False):
    """
    Batched KMeans clustering in PyTorch using Euclidean distance.

    Args:
        x: Tensor of shape (B, N, D), batch_size B, N points per batch, D dims.
        n_clusters: Number of clusters.
        max_iters: Max number of iterations.
        tol: Relative tolerance for center movement.
        verbose: Print loss for each iter.

    Returns:
        cluster_ids: (B, N) LongTensor, cluster assignment for each point.
        centroids: (B, n_clusters, D) final cluster centers.
        cluster_sizes: (B, n_clusters) LongTensor, number of points per cluster.
        n_iters: actual number of iterations executed (int)
    """
    B, N, D = x.shape

    # Pre-compute squared L2 norm of all points (constant during iterations)
    x_sq = (x**2).sum(dim=-1)  # (B, N)

    if init_centroids is None:
        # Randomly select initial centers from x
        indices = torch.randint(0, N, (B, n_clusters), device=x.device)
        centroids = torch.gather(x, dim=1, index=indices[..., None].expand(-1, -1, D))  # (B, n_clusters, D)
    else:
        # centroids = init_centroids.clone()
        centroids = init_centroids
    centroids = centroids.view(B, n_clusters, D)

    for it in range(max_iters):
        # ---- compiled single iteration ----
        centroids_new, center_shift, cluster_ids, cluster_sizes = _euclid_iter_compiled(x, x_sq, centroids)
        # 4. Check for convergence
        if verbose:
            print(f"Iter {it}, center shift: {center_shift.item():.6f}")
        if center_shift < tol:
            break
        # centroids = centroids_new.clone()
        centroids = centroids_new

    # # --- compute cluster sizes ---
    # ones = torch.ones_like(cluster_ids, dtype=torch.int64)
    # cluster_sizes = torch.zeros(B, n_clusters, dtype=torch.int64, device=x.device)
    # cluster_sizes.scatter_add_(1, cluster_ids, ones)
    return cluster_ids, centroids, cluster_sizes, it + 1
    # return cluster_ids.clone(), centroids.clone(), cluster_sizes.clone(), it + 1


# batch_kmeans_Euclid = torch.compile(batch_kmeans_Euclid, dynamic=True, mode="reduce-overhead")
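A quick usage sketch for the batched Euclidean K-means; this assumes a CUDA device, since the assignment and centroid-update steps both launch Triton kernels, and the shapes here are illustrative:

# Cluster 4096 random 64-d points per batch into 32 groups (assumes CUDA).
x = torch.randn(2, 4096, 64, device="cuda", dtype=torch.float16)
ids, cents, sizes, n_iters = batch_kmeans_Euclid(x, n_clusters=32, max_iters=10, verbose=True)
print(ids.shape, cents.shape, sizes.shape, n_iters)  # (2, 4096) (2, 32, 64) (2, 32) <= 10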
def batch_kmeans_Cosine(x, n_clusters, max_iters=100, tol=1e-4, init_centroids=None, verbose=False):
    """
    Batched KMeans clustering in PyTorch using cosine similarity.

    Args:
        x: Tensor of shape (B, N, D), batch_size B, N points per batch, D dims.
        n_clusters: Number of clusters.
        max_iters: Max number of iterations.
        tol: Relative tolerance for center movement.
        verbose: Print loss for each iter.

    Returns:
        cluster_ids: (B, N) LongTensor, cluster assignment for each point.
        centroids: (B, n_clusters, D) final cluster centers.
        cluster_sizes: (B, n_clusters) LongTensor, number of points per cluster.
        n_iters: actual number of iterations executed (int)
    """
    B, N, D = x.shape

    # Normalize input vectors for cosine similarity
    x_norm = F.normalize(x, p=2, dim=-1)  # (B, N, D)

    if init_centroids is None:
        # Randomly select initial centers from x_norm
        indices = torch.randint(0, N, (B, n_clusters), device=x.device)
        centroids = torch.gather(x_norm, dim=1, index=indices[..., None].expand(-1, -1, D))  # (B, n_clusters, D)
    else:
        centroids = init_centroids
    centroids = centroids.view(B, n_clusters, D)
    centroids = F.normalize(centroids, p=2, dim=-1)  # Ensure centroids are normalized

    for it in range(max_iters):
        # ---- compiled single iteration ----
        centroids_new, center_shift, cluster_ids = _cosine_iter_compiled(x_norm, centroids)
        # 4. Check for convergence
        if verbose:
            print(f"Iter {it}, center shift: {center_shift.item():.6f}")
        if center_shift < tol:
            break
        centroids = centroids_new.clone()

    # --- compute cluster sizes ---
    ones = torch.ones_like(cluster_ids, dtype=torch.int64)
    cluster_sizes = torch.zeros(B, n_clusters, dtype=torch.int64, device=x.device)
    cluster_sizes.scatter_add_(1, cluster_ids, ones)
    return cluster_ids, centroids, cluster_sizes, it + 1
def batch_kmeans_Dot(x, n_clusters, max_iters=100, tol=1e-4, init_centroids=None, verbose=False):
    """
    Batched KMeans clustering in PyTorch using raw dot-product as similarity.
    """
    B, N, D = x.shape
    if init_centroids is None:
        # Randomly initialize centroids
        indices = torch.randint(0, N, (B, n_clusters), device=x.device)
        centroids = torch.gather(x, dim=1, index=indices[..., None].expand(-1, -1, D))
    else:
        centroids = init_centroids
    centroids = centroids.view(B, n_clusters, D)

    for it in range(max_iters):
        # ---- compiled single iteration ----
        centroids_new, center_shift, cluster_ids = _dot_iter_compiled(x, centroids)
        # 4. Check for convergence
        if verbose:
            print(f"Iter {it} (dot), center shift: {center_shift.item():.6f}")
        if center_shift < tol:
            break
        centroids = centroids_new.clone()

    # --- compute cluster sizes ---
    ones = torch.ones_like(cluster_ids, dtype=torch.int64)
    cluster_sizes = torch.zeros(B, n_clusters, dtype=torch.int64, device=x.device)
    cluster_sizes.scatter_add_(1, cluster_ids, ones)
    return cluster_ids, centroids, cluster_sizes, it + 1
# --- Functions from analyze/kmeans_block_sparse_attention.py (helpers) ---
def permute_tensor_by_labels(tensor, labels, dim):
    labels = labels.to(tensor.device)
    sorted_indices = torch.argsort(labels, dim=-1)
    gather_indices = sorted_indices
    for i in range(dim + 1, tensor.dim()):
        gather_indices = gather_indices.unsqueeze(-1)
    expand_shape = list(tensor.shape)
    gather_indices = gather_indices.expand(expand_shape)
    permuted_tensor = torch.gather(tensor, dim, gather_indices)
    return permuted_tensor, sorted_indices


def apply_inverse_permutation(permuted_tensor, sorted_indices, dim):
    inverse_indices = torch.argsort(sorted_indices, dim=-1)
    gather_indices = inverse_indices
    for i in range(dim + 1, permuted_tensor.dim()):
        gather_indices = gather_indices.unsqueeze(-1)
    gather_indices = gather_indices.expand(permuted_tensor.shape)
    original_tensor = torch.gather(permuted_tensor, dim, gather_indices)
    return original_tensor
def weighted_softmax(scores, weights):
    input_dtype = scores.dtype
    scores = scores.float()
    weights = weights.float()
    max_score = torch.max(scores, dim=-1, keepdim=True)[0]
    exp_scores = torch.exp(scores - max_score)
    weighted_exp = weights * exp_scores
    softmax_out = weighted_exp / torch.sum(weighted_exp, dim=-1, keepdim=True).clamp(min=1e-12)
    return softmax_out.to(input_dtype)
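For strictly positive weights this is algebraically just a softmax with a log-weight bias: w * exp(s) / sum(w * exp(s)) = exp(s + log w) / sum(exp(s + log w)), which is how the key-cluster sizes act as priors on the centroid attention scores. A small equivalence check against the direct formulation:

# weighted_softmax(s, w) == softmax(s + log w) for strictly positive weights.
s = torch.randn(2, 4, 8)
w = torch.rand(2, 4, 8) + 0.1  # strictly positive
ref = torch.softmax(s + w.log(), dim=-1)
assert torch.allclose(weighted_softmax(s, w), ref, atol=1e-6)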
def identify_dynamic_map(
    query_centroids,
    key_centroids,
    q_cluster_sizes,
    k_cluster_sizes,
    p,
    min_kc_ratio=0,
):
    B, H, qc_num, D = query_centroids.shape
    kc_num = key_centroids.shape[2]
    device = query_centroids.device

    attn_scores = torch.matmul(query_centroids, key_centroids.transpose(-2, -1)) / (D**0.5)
    k_weights = k_cluster_sizes.unsqueeze(-2).float()
    weighted_attn_probs = weighted_softmax(attn_scores, k_weights)
    sorted_probs, sorted_indices = torch.sort(weighted_attn_probs, dim=-1, descending=True)
    cumsum_probs = torch.cumsum(sorted_probs, dim=-1)
    remove_indices = cumsum_probs > p
    remove_indices[..., 1:] = remove_indices[..., :-1].clone()
    remove_indices[..., 0] = False
    if min_kc_ratio > 0:
        preserve_length = int(min_kc_ratio * kc_num)
        remove_indices[..., :preserve_length] = False
    sorted_clusters_to_keep = ~remove_indices
    dynamic_map = torch.zeros(B, H, qc_num, kc_num, dtype=torch.bool, device=device)
    dynamic_map.scatter_(-1, sorted_indices, sorted_clusters_to_keep)
    return dynamic_map
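The top-p selection here mirrors nucleus sampling: per query cluster, key clusters are ranked by size-weighted probability and kept until the cumulative mass crosses `p`; the shift of `remove_indices` ensures the boundary cluster that crosses `p` is itself retained, and `min_kc_ratio` forces a floor of always-kept clusters. A single-row demonstration of the keep/drop logic (values are illustrative):

# Top-p keep/drop on one row: the cluster that crosses p = 0.9 is still kept.
probs = torch.tensor([0.6, 0.25, 0.1, 0.05])
cum = probs.cumsum(-1)            # [0.60, 0.85, 0.95, 1.00]
remove = cum > 0.9                # [F, F, T, T]
remove[1:] = remove[:-1].clone()  # shift right: [F, F, F, T]
remove[0] = False
print(~remove)                    # keep mask: [True, True, True, False]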
# --- Functions from analyze/dynamic_block_sparse_attention.py ---
def dynamic_block_sparse_fwd_torch(q, k, v, dynamic_map, qc_size, kc_size):
    """
    Computes dynamic block sparse attention using pure PyTorch.

    Args:
        q (torch.Tensor): Query tensor, shape [B, H, S, D].
        k (torch.Tensor): Key tensor, shape [B, H, S, D].
        v (torch.Tensor): Value tensor, shape [B, H, S, D].
        dynamic_map (torch.Tensor): Boolean mask, shape [B, H, qc_num, kc_num].
        qc_size (torch.Tensor): Query block sizes, shape [B, H, qc_num].
        kc_size (torch.Tensor): Key block sizes, shape [B, H, kc_num].

    Returns:
        torch.Tensor: Output tensor, shape [B, H, S, D].
    """
    B, H, S, D = q.shape
    qc_num = qc_size.shape[-1]
    kc_num = kc_size.shape[-1]
    device = q.device
    dtype = q.dtype

    # Ensure sequence lengths match sum of block sizes
    assert S == torch.sum(qc_size[0, 0, :]), "Sum of qc_size must equal S"
    assert S == torch.sum(kc_size[0, 0, :]), "Sum of kc_size must equal S"

    # Precompute cumulative sizes for block indexing
    # Add a 0 at the beginning for easier slicing
    qc_cum_size = torch.cumsum(torch.cat([torch.zeros_like(qc_size[..., :1]), qc_size], dim=-1), dim=-1)
    kc_cum_size = torch.cumsum(torch.cat([torch.zeros_like(kc_size[..., :1]), kc_size], dim=-1), dim=-1)

    out = torch.zeros_like(q)
    scale = D**-0.5

    # Naive implementation: iterate through batch, head, and blocks
    for b in range(B):
        for h in range(H):
            # Precompute start/end indices for this batch/head
            q_starts = qc_cum_size[b, h, :-1]
            q_ends = qc_cum_size[b, h, 1:]
            k_starts = kc_cum_size[b, h, :-1]
            k_ends = kc_cum_size[b, h, 1:]

            # Iterate through query blocks
            for i in range(qc_num):
                q_start, q_end = q_starts[i], q_ends[i]
                q_block = q[b, h, q_start:q_end, :]  # Shape: [qc_i, D]
                if q_block.shape[0] == 0:
                    continue  # Skip empty blocks

                m_i = torch.full((q_block.shape[0], 1), -float("inf"), device=device, dtype=dtype)
                l_i = torch.zeros((q_block.shape[0], 1), device=device, dtype=dtype)
                acc_o_i = torch.zeros_like(q_block)  # Shape: [qc_i, D]

                # Iterate through key/value blocks for the current query block
                for j in range(kc_num):
                    # Check if this block needs computation
                    if dynamic_map[b, h, i, j]:
                        k_start, k_end = k_starts[j], k_ends[j]
                        k_block = k[b, h, k_start:k_end, :]  # Shape: [kc_j, D]
                        v_block = v[b, h, k_start:k_end, :]  # Shape: [kc_j, D]
                        if k_block.shape[0] == 0:
                            continue  # Skip empty blocks

                        # Compute attention scores for the block
                        # QK^T: [qc_i, D] @ [D, kc_j] -> [qc_i, kc_j]
                        s_ij = (q_block @ k_block.transpose(-1, -2)) * scale

                        # --- Online Softmax ---
                        # Find max score per query token in this block
                        m_ij = torch.max(s_ij, dim=-1, keepdim=True)[0]  # Shape: [qc_i, 1]
                        # Update overall max score (m_i)
                        m_new = torch.maximum(m_i, m_ij)  # Shape: [qc_i, 1]
                        # Calculate scaling factors for previous accumulator and current block
                        p_ij = torch.exp(s_ij - m_new)  # Shape: [qc_i, kc_j]
                        exp_m_diff = torch.exp(m_i - m_new)  # Shape: [qc_i, 1]
                        # Update softmax denominator (l_i)
                        l_i = (l_i * exp_m_diff) + torch.sum(p_ij, dim=-1, keepdim=True)  # Shape: [qc_i, 1]
                        # Update output accumulator (acc_o_i)
                        # P_ij @ V_j: [qc_i, kc_j] @ [kc_j, D] -> [qc_i, D]
                        acc_o_i = (acc_o_i * exp_m_diff) + (p_ij @ v_block)  # Shape: [qc_i, D]
                        # Update max score for next iteration
                        m_i = m_new

                # Normalize the accumulated output
                out[b, h, q_start:q_end, :] = acc_o_i / l_i.clamp(min=1e-12)  # Avoid division by zero

    return out
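Since the reference implementation only skips blocks the mask disables, running it with an all-True `dynamic_map` must reproduce dense softmax attention, which makes a convenient correctness probe for both this function and the Triton kernel below. A minimal CPU check, a sketch with illustrative shapes:

# With every block active, block-sparse attention equals dense attention.
B, H, S, D = 1, 2, 64, 32
q = torch.randn(B, H, S, D)
k = torch.randn(B, H, S, D)
v = torch.randn(B, H, S, D)
qc = torch.full((B, H, 4), 16)  # 4 query blocks of 16 tokens
kc = torch.full((B, H, 4), 16)  # 4 key blocks of 16 tokens
full_map = torch.ones(B, H, 4, 4, dtype=torch.bool)

out = dynamic_block_sparse_fwd_torch(q, k, v, full_map, qc, kc)
ref = torch.softmax((q @ k.transpose(-1, -2)) * D**-0.5, dim=-1) @ v
assert torch.allclose(out, ref, atol=1e-4)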
# --- Triton Implementation ---
@triton.jit
def _dynamic_block_sparse_fwd_kernel(
    Q,
    K,
    V,
    Out,
    dynamic_map,
    qc_cum_size,
    kc_cum_size,
    stride_qb, stride_qh, stride_qs, stride_qd,
    stride_kb, stride_kh, stride_ks, stride_kd,
    stride_vb, stride_vh, stride_vs, stride_vd,
    stride_ob, stride_oh, stride_os, stride_od,
    stride_dmap_b, stride_dmap_h, stride_dmap_qc, stride_dmap_kc,
    stride_qcs_b, stride_qcs_h, stride_qcs_qc,
    stride_kcs_b, stride_kcs_h, stride_kcs_kc,
    B,
    H,
    S,
    D,
    scale,
    QC_NUM: tl.constexpr,
    KC_NUM: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    """
    Triton kernel for dynamic block sparse attention.

    Each program computes attention for one query block within a batch/head.
    Processes the query block in chunks of BLOCK_M.
    Iterates through key blocks, checking dynamic_map.
    Processes key/value blocks in chunks of BLOCK_N.
    Uses online softmax.
    """
    # --- Grid Calculation ---
    # Each program instance handles one query block for a specific batch and head
    pid = tl.program_id(axis=0)
    B * H * QC_NUM

    # Calculate batch, head, and query block index
    pid_q_block_global = pid  # 0 to B*H*QC_NUM - 1
    # pid_bh = pid // QC_NUM  # Deprecated: causes issues if QC_NUM is not constant across BH
    # pid_q_block_idx = pid % QC_NUM
    # Need to map pid (0 .. B*H*QC_NUM-1) back to (b, h, q_block_idx);
    # q_block_idx changes fastest, then h, then b
    q_block_idx = pid_q_block_global % QC_NUM
    pid_h_temp = pid_q_block_global // QC_NUM
    h = pid_h_temp % H
    b = pid_h_temp // H

    # --- Load Q block info (start/end offsets) ---
    qcs_offset = b * stride_qcs_b + h * stride_qcs_h
    q_start_offset = tl.load(qc_cum_size + qcs_offset + q_block_idx * stride_qcs_qc)
    q_end_offset = tl.load(qc_cum_size + qcs_offset + (q_block_idx + 1) * stride_qcs_qc)
    q_block_size = q_end_offset - q_start_offset

    # Early exit if the query block is empty
    if q_block_size == 0:
        return

    # --- Pointers setup ---
    q_ptr_base = Q + b * stride_qb + h * stride_qh + q_start_offset * stride_qs
    k_ptr_base = K + b * stride_kb + h * stride_kh
    v_ptr_base = V + b * stride_vb + h * stride_vh
    out_ptr_base = Out + b * stride_ob + h * stride_oh + q_start_offset * stride_os
    dmap_ptr = dynamic_map + b * stride_dmap_b + h * stride_dmap_h + q_block_idx * stride_dmap_qc
    kcs_ptr = kc_cum_size + b * stride_kcs_b + h * stride_kcs_h

    # --- Iterate over the query block rows in chunks of BLOCK_M ---
    offs_qm = tl.arange(0, BLOCK_M)  # Query block row offsets [0, 1, ..., BLOCK_M-1]
    offs_d = tl.arange(0, BLOCK_D)  # Dimension offsets [0, 1, ..., BLOCK_D-1]

    for q_chunk_start in range(0, q_block_size, BLOCK_M):
        q_chunk_rows = offs_qm + q_chunk_start
        q_rows_mask = q_chunk_rows < q_block_size  # Mask for valid rows in this Q chunk [BLOCK_M]

        # --- Initialize accumulators for this Q chunk ---
        m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")  # Max score
        l_i = tl.zeros([BLOCK_M], dtype=tl.float32)  # Sum of exp(scores - max)
        acc_o = tl.zeros([BLOCK_M, BLOCK_D], dtype=tl.float32)  # Accumulated output

        # --- Load Q chunk ---
        q_ptr = q_ptr_base + q_chunk_rows[:, None] * stride_qs + offs_d[None, :]
        # Mask ensures we don't read out of bounds for the query block or dimension D
        mask_q = q_rows_mask[:, None] & (offs_d[None, :] < D)
        q_chunk = tl.load(q_ptr, mask=mask_q, other=0.0)  # Shape: [BLOCK_M, BLOCK_D]

        # --- Inner loop over K blocks (columns in the block sparse map) ---
        for k_block_idx in range(KC_NUM):
            # --- Check dynamic_map: is this block active? ---
            is_active = tl.load(dmap_ptr + k_block_idx * stride_dmap_kc)
            if is_active:
                # Process block only if it's active
                # --- Load K block info (start/end offsets) ---
                k_start_offset = tl.load(kcs_ptr + k_block_idx * stride_kcs_kc)
                k_end_offset = tl.load(kcs_ptr + (k_block_idx + 1) * stride_kcs_kc)
                k_block_size = k_end_offset - k_start_offset

                # Skip if the key block is empty (inside the active block check)
                if k_block_size > 0:
                    k_block_ptr_base = k_ptr_base + k_start_offset * stride_ks
                    v_block_ptr_base = v_ptr_base + k_start_offset * stride_vs

                    # --- Loop over K block chunks (size BLOCK_N) ---
                    offs_kn = tl.arange(0, BLOCK_N)  # Key block row offsets [0, ..., BLOCK_N-1]
                    for k_chunk_start in range(0, k_block_size, BLOCK_N):
                        k_chunk_rows = offs_kn + k_chunk_start
                        k_rows_mask = k_chunk_rows < k_block_size  # Mask for valid rows in this K/V chunk [BLOCK_N]

                        # --- Load K, V chunks ---
                        k_ptr = k_block_ptr_base + k_chunk_rows[:, None] * stride_ks + offs_d[None, :]
                        v_ptr = v_block_ptr_base + k_chunk_rows[:, None] * stride_vs + offs_d[None, :]
                        # Mask ensures we don't read out of bounds for the key block or dimension D
                        mask_kv = k_rows_mask[:, None] & (offs_d[None, :] < D)
                        k_chunk = tl.load(k_ptr, mask=mask_kv, other=0.0)  # Shape: [BLOCK_N, BLOCK_D]
                        v_chunk = tl.load(v_ptr, mask=mask_kv, other=0.0)  # Shape: [BLOCK_N, BLOCK_D]

                        # --- Compute Scores (Attention) ---
                        # QK^T: [BLOCK_M, BLOCK_D] @ [BLOCK_D, BLOCK_N] -> [BLOCK_M, BLOCK_N]
                        s_ij_chunk = tl.dot(q_chunk, k_chunk.T) * scale

                        # IMPORTANT: mask out scores corresponding to padding in K before max/softmax.
                        # Set scores for invalid K elements to -inf
                        s_ij_chunk = tl.where(k_rows_mask[None, :], s_ij_chunk, -float("inf"))
                        # Mask out scores for invalid Q elements as well (although q_chunk elements are 0, avoid potential issues)
                        s_ij_chunk = tl.where(q_rows_mask[:, None], s_ij_chunk, -float("inf"))

                        # --- Online Softmax Update ---
                        # Current max for this Q-K chunk interaction
                        m_ij_chunk = tl.max(s_ij_chunk, axis=1)  # Shape: [BLOCK_M]
                        # Update overall max (across K chunks seen so far for this Q chunk)
                        m_new = tl.maximum(m_i, m_ij_chunk)  # Shape: [BLOCK_M]
                        # Calculate scaled probabilities P_ij = exp(S_ij - m_new)
                        p_ij_chunk = tl.exp(s_ij_chunk - m_new[:, None])  # Shape: [BLOCK_M, BLOCK_N]
                        # Zero out probabilities for masked K elements before summing
                        p_ij_chunk = tl.where(k_rows_mask[None, :], p_ij_chunk, 0.0)
                        # Calculate scaling factor for previous accumulator state
                        exp_m_diff = tl.exp(m_i - m_new)  # Shape: [BLOCK_M]
                        # Update sum accumulator (denominator L)
                        l_i_chunk = tl.sum(p_ij_chunk, axis=1)  # Sum probabilities for this chunk, shape [BLOCK_M]
                        l_i = (l_i * exp_m_diff) + l_i_chunk  # Shape: [BLOCK_M]
                        # Update output accumulator O
                        # P_ij @ V_j: [BLOCK_M, BLOCK_N] @ [BLOCK_N, BLOCK_D] -> [BLOCK_M, BLOCK_D]
                        # Ensure p_ij_chunk is the correct dtype for the dot product
                        p_ij_chunk_casted = p_ij_chunk.to(V.dtype.element_ty)
                        o_chunk = tl.dot(p_ij_chunk_casted, v_chunk)  # Shape: [BLOCK_M, BLOCK_D]
                        acc_o = (acc_o * exp_m_diff[:, None]) + o_chunk  # Shape: [BLOCK_M, BLOCK_D]
                        # Update max for the next K chunk/block
                        m_i = m_new
            # End of 'if is_active:' block
        # --- End of loop over K blocks ---

        # --- Finalize output for this Q chunk ---
        # Normalize the accumulated output: O = acc_o / l_i
        # Guard l_i to avoid division by zero
        l_i_safe = tl.where(l_i == 0, 1.0, l_i)  # Avoid 0/0 -> NaN
        o_final_chunk = acc_o / (l_i_safe[:, None])
        o_final_chunk = tl.where(l_i[:, None] == 0, 0.0, o_final_chunk)  # Ensure output is 0 if l_i was 0

        # --- Write output chunk to global memory ---
        out_ptr = out_ptr_base + q_chunk_rows[:, None] * stride_os + offs_d[None, :]
        # Mask ensures we don't write out of bounds for the query block or dimension D
        mask_out = q_rows_mask[:, None] & (offs_d[None, :] < D)
        tl.store(out_ptr, o_final_chunk.to(Out.dtype.element_ty), mask=mask_out)

        # --- (Optional: write L and M stats if needed) ---
        # Example:
        # l_ptr = L + b * stride_lb + h * stride_lh + (q_start_offset + q_chunk_rows) * stride_ls
        # tl.store(l_ptr, l_i, mask=q_rows_mask)
        # m_ptr = M + ...
        # tl.store(m_ptr, m_i, mask=q_rows_mask)
    # --- End of loop over Q chunks ---
def dynamic_block_sparse_fwd_triton(q, k, v, dynamic_map, qc_size, kc_size):
    """
    Launcher for the Triton dynamic block sparse attention kernel.

    Args:
        q (torch.Tensor): Query tensor, shape [B, H, S, D].
        k (torch.Tensor): Key tensor, shape [B, H, S, D].
        v (torch.Tensor): Value tensor, shape [B, H, S, D].
        dynamic_map (torch.Tensor): Boolean mask, shape [B, H, qc_num, kc_num].
        qc_size (torch.Tensor): Query block sizes, shape [B, H, qc_num].
        kc_size (torch.Tensor): Key block sizes, shape [B, H, kc_num].

    Returns:
        torch.Tensor: Output tensor, shape [B, H, S, D].
    """
    B, H, S, D = q.shape
    qc_num = qc_size.shape[-1]
    kc_num = kc_size.shape[-1]
    dtype = q.dtype

    # Assertions and checks
    assert q.is_cuda and k.is_cuda and v.is_cuda, "Inputs must be CUDA tensors"
    assert dynamic_map.is_cuda and qc_size.is_cuda and kc_size.is_cuda
    assert q.dtype == k.dtype == v.dtype, "Input dtypes must match"
    assert dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported dtype"
    assert D in [16, 32, 64, 128], "Head dimension D must be 16, 32, 64, or 128 for efficient Triton dot"
    # Ensure sequence lengths match sum of block sizes (check on one batch/head for simplicity)
    assert S == torch.sum(qc_size[0, 0, :]), "Sum of qc_size must equal S"
    assert S == torch.sum(kc_size[0, 0, :]), "Sum of kc_size must equal S"
    # Ensure dynamic_map is boolean
    assert dynamic_map.dtype == torch.bool

    # Calculate scale factor (using float32 for stability)
    scale = D**-0.5

    # Precompute cumulative sizes (on CPU/GPU, keep on device)
    qc_cum_size = torch.cumsum(torch.cat([torch.zeros_like(qc_size[..., :1]), qc_size], dim=-1), dim=-1).int()
    kc_cum_size = torch.cumsum(torch.cat([torch.zeros_like(kc_size[..., :1]), kc_size], dim=-1), dim=-1).int()

    # Output tensor
    out = torch.empty_like(q)

    # Triton kernel config
    # BLOCK_M/N can be tuned. Larger blocks may increase occupancy but need more shared memory.
    # Start with reasonably sized blocks.
    BLOCK_D = D
    if S <= 512:
        # Smaller sequence, smaller blocks might be ok
        BLOCK_M = 64
        BLOCK_N = 64
    elif S <= 1024:
        BLOCK_M = 64
        BLOCK_N = 64
    else:
        # Larger sequence, potentially larger blocks
        BLOCK_M = 128  # Or keep 64? Test
        BLOCK_N = 64

    # Adjust block size if sequence length is smaller
    BLOCK_M = min(BLOCK_M, S)
    BLOCK_N = min(BLOCK_N, S)

    # Launch grid: one program per query block per batch/head
    grid = (B * H * qc_num,)

    # Call the kernel
    _dynamic_block_sparse_fwd_kernel[grid](
        q, k, v, out,
        dynamic_map, qc_cum_size, kc_cum_size,
        q.stride(0), q.stride(1), q.stride(2), q.stride(3),
        k.stride(0), k.stride(1), k.stride(2), k.stride(3),
        v.stride(0), v.stride(1), v.stride(2), v.stride(3),
        out.stride(0), out.stride(1), out.stride(2), out.stride(3),
        dynamic_map.stride(0), dynamic_map.stride(1), dynamic_map.stride(2), dynamic_map.stride(3),
        qc_cum_size.stride(0), qc_cum_size.stride(1), qc_cum_size.stride(2),
        kc_cum_size.stride(0), kc_cum_size.stride(1), kc_cum_size.stride(2),
        B, H, S, D,
        scale,
        QC_NUM=qc_num,
        KC_NUM=kc_num,
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
        BLOCK_D=BLOCK_D,
        # num_warps=4  # Can tune this
    )
    return out
# ---------------- Batch wrapper for cuVS KMeans -----------------
def batch_kmeans_rapidai(x, n_clusters, max_iters=100, tol=1e-4, init_centroids=None, verbose=False):
    """Batched K-Means using the RAPIDS cuVS implementation.

    Args:
        x (Tensor): (B, N, D) float32 tensor on CUDA.
        n_clusters (int): K.
        max_iters (int): maximum iterations.
        tol (float): tolerance.
        init_centroids (Tensor|None): optional initial centroids (B, K, D) float32.
        verbose (bool): print per-batch info.

    Returns:
        cluster_ids (B, N) LongTensor
        centroids (B, K, D) float32
        cluster_sizes (B, K) LongTensor
        n_iters_list (List[int]) iterations per batch
    """
    B, N, D = x.shape
    if init_centroids is not None:
        assert init_centroids.shape == (B, n_clusters, D)

    cluster_ids_list = []
    centroids_list = []
    # cluster_sizes_list = []
    n_iters_list = []

    x_float = x.float()
    if init_centroids is not None:
        init_centroids_float = init_centroids.float()

    for b in range(B):
        xb = x_float[b]
        if init_centroids is None:
            centroids_init_b = None
            init_method = "KMeansPlusPlus"
        else:
            centroids_init_b = init_centroids_float[b]
            init_method = "Array"
        labels_b, centroids_b, n_iter_b = kmeans_rapidai(xb, n_clusters, max_iter=max_iters, tol=tol, init_method=init_method, centroids_init=centroids_init_b)
        cluster_ids_list.append(labels_b.to(torch.int64))  # (N,)
        centroids_list.append(centroids_b)
        # cluster_sizes_b = torch.bincount(labels_b, minlength=n_clusters).to(torch.int64)
        # cluster_sizes_list.append(cluster_sizes_b)
        # n_iters_list.append(n_iter_b)
        # if verbose:
        #     print(f"Batch {b}: iters={n_iter_b}, cluster sizes min={cluster_sizes_b.min().item()} max={cluster_sizes_b.max().item()}")

    cluster_ids = torch.stack(cluster_ids_list, dim=0)  # (B, N)
    centroids = torch.stack(centroids_list, dim=0).to(x.dtype)  # (B, K, D)
    # cluster_sizes = torch.stack(cluster_sizes_list, dim=0)  # (B, K)
    # --- compute cluster sizes ---
    ones = torch.ones_like(cluster_ids, dtype=torch.int64)
    cluster_sizes = torch.zeros(B, n_clusters, dtype=torch.int64, device=x.device)
    cluster_sizes.scatter_add_(1, cluster_ids, ones)
    return cluster_ids, centroids, cluster_sizes, n_iters_list