OpenDAS / LightX2V · Commits · a1ebc651

Commit a1ebc651, authored Dec 11, 2025 by xuwx1
Commit message: updata lightx2v
Parent: 5a4db490
Pipeline #3149: canceled with stages
Changes: 428 · Pipelines: 1

Showing 20 changed files with 5022 additions and 0 deletions (+5022 / -0)
lightx2v/common/ops/attn/nbhd_attn.py               +196   -0
lightx2v/common/ops/attn/radial_attn.py             +185   -0
lightx2v/common/ops/attn/ring_attn.py               +179   -0
lightx2v/common/ops/attn/sage_attn.py                +83   -0
lightx2v/common/ops/attn/spassage_attn.py            +76   -0
lightx2v/common/ops/attn/svg2_attn.py               +355   -0
lightx2v/common/ops/attn/svg2_attn_utils.py        +1359   -0
lightx2v/common/ops/attn/svg_attn.py                +409   -0
lightx2v/common/ops/attn/template.py                 +35   -0
lightx2v/common/ops/attn/torch_sdpa.py               +39   -0
lightx2v/common/ops/attn/ulysses_attn.py            +415   -0
lightx2v/common/ops/attn/utils/all2all.py            +89   -0
lightx2v/common/ops/attn/utils/ring_comm.py          +46   -0
lightx2v/common/ops/conv/__init__.py                  +2   -0
lightx2v/common/ops/conv/conv2d.py                   +61   -0
lightx2v/common/ops/conv/conv3d.py                   +94   -0
lightx2v/common/ops/embedding/__init__.py             +1   -0
lightx2v/common/ops/embedding/embedding_weight.py    +72   -0
lightx2v/common/ops/mm/__init__.py                    +1   -0
lightx2v/common/ops/mm/mm_weight.py                +1325   -0
lightx2v/common/ops/attn/nbhd_attn.py (new file, mode 100644) @ a1ebc651
import torch
from loguru import logger

try:
    from magi_attention.functional import flex_flash_attn_func as magi_ffa_func
except ImportError:
    magi_ffa_func = None

try:
    import flashinfer
except ImportError:
    flashinfer = None

from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER

from .template import AttnWeightTemplate


def generate_nbhd_mask(a, block_num, attnmap_frame_num, coefficient=[1.0, 0.5, 0.056], min_width=1.0, device="cpu"):
    """
    a : block num per frame
    block_num : block num per col/row
    attnmap_frame_num : total frame num
    """
    i_indices = torch.arange(block_num, device=device).unsqueeze(1)  # [block_num, 1]
    j_indices = torch.arange(block_num, device=device).unsqueeze(0)  # [1, block_num]

    assert len(coefficient) <= attnmap_frame_num, f"coefficient length {len(coefficient)} should <= attnmap_frame_num {attnmap_frame_num}"
    width_list = [max(min_width, coefficient[i] * a) for i in range(len(coefficient))] + [min_width] * (attnmap_frame_num - len(coefficient))
    logger.info(f"nbhd_attn width_list: {width_list}, len={len(width_list)}")

    # attention sink frame: j <= a
    mask_sink = j_indices <= a

    mask_sparse = torch.zeros((block_num, block_num), dtype=torch.bool, device=device)
    for interval in range(0, attnmap_frame_num):
        n = i_indices // a
        mask_sparse_base_1 = (j_indices >= (n + interval) * a) & (j_indices <= (n + interval + 1) * a)
        n = j_indices // a
        mask_sparse_base_2 = (i_indices >= (n + interval) * a) & (i_indices <= (n + interval + 1) * a)

        width = width_list[interval]
        mask_1 = mask_sparse_base_1 & (i_indices - j_indices + (interval * a + width) >= 0) & (i_indices - j_indices + (interval * a - width) <= 0)
        mask_2 = mask_sparse_base_2 & (i_indices - j_indices - (interval * a - width) >= 0) & (i_indices - j_indices - (interval * a + width) <= 0)

        mask_sparse = mask_sparse | mask_1 | mask_2

    mask = mask_sink | mask_sparse
    return mask


def generate_qk_ranges(mask, block_size, seqlen):
    indices = torch.nonzero(mask, as_tuple=False)  # shape: [N, 2]
    i_indices = indices[:, 0]  # [N]
    j_indices = indices[:, 1]  # [N]

    q_start = i_indices * block_size  # [N]
    q_end = torch.clamp((i_indices + 1) * block_size, max=seqlen)  # [N]
    k_start = j_indices * block_size  # [N]
    k_end = torch.clamp((j_indices + 1) * block_size, max=seqlen)  # [N]

    q_ranges = torch.stack([q_start, q_end], dim=1)  # [N, 2]
    k_ranges = torch.stack([k_start, k_end], dim=1)  # [N, 2]

    return q_ranges, k_ranges


@ATTN_WEIGHT_REGISTER("nbhd_attn")
class NbhdAttnWeight(AttnWeightTemplate):
    block_size = 128
    seqlen = None
    attnmap_frame_num = None
    q_ranges = None
    k_ranges = None
    attn_type_map = None
    coefficient = [1.0, 0.5, 0.056]
    min_width = 1.0

    def __init__(self):
        self.config = {}

    @classmethod
    @torch.compiler.disable
    def prepare_mask(cls, seqlen):
        if seqlen == cls.seqlen:
            return
        block_num = (seqlen + cls.block_size - 1) // cls.block_size
        block_num_per_frame = seqlen / cls.attnmap_frame_num / cls.block_size
        mask = generate_nbhd_mask(block_num_per_frame, block_num, cls.attnmap_frame_num, coefficient=cls.coefficient, min_width=cls.min_width, device="cpu")
        q_ranges, k_ranges = generate_qk_ranges(mask, cls.block_size, seqlen)
        attn_type_map = torch.zeros(len(q_ranges), dtype=torch.int32, device="cuda")
        q_ranges = q_ranges.to(torch.int32).to("cuda")
        k_ranges = k_ranges.to(torch.int32).to("cuda")
        cls.seqlen = seqlen
        cls.q_ranges = q_ranges
        cls.k_ranges = k_ranges
        cls.attn_type_map = attn_type_map
        logger.info(f"NbhdAttnWeight Update: seqlen={seqlen}")
        sparsity = 1 - mask.sum().item() / mask.numel()
        logger.info(f"Attention sparsity: {sparsity}")

    def apply(
        self,
        q,
        k,
        v,
        cu_seqlens_q=None,
        cu_seqlens_kv=None,
        max_seqlen_q=None,
        max_seqlen_kv=None,
        model_cls=None,
    ):
        """
        q: [seqlen, head_num, head_dim]
        k: [seqlen, head_num, head_dim]
        v: [seqlen, head_num, head_dim]
        """
        self.prepare_mask(seqlen=q.shape[0])
        out = magi_ffa_func(
            q,
            k,
            v,
            q_ranges=self.q_ranges,
            k_ranges=self.k_ranges,
            attn_type_map=self.attn_type_map,
            auto_range_merge=True,
        )[0]
        return out.reshape(out.shape[0], -1)


@ATTN_WEIGHT_REGISTER("nbhd_attn_flashinfer")
class NbhdAttnWeightFlashInfer(AttnWeightTemplate):
    block_size = 128
    seqlen = None
    attnmap_frame_num = None
    coefficient = [1.0, 0.5, 0.056]
    min_width = 1.0
    sparse_wrapper = None

    def __init__(self):
        self.config = {}

    @classmethod
    @torch.compiler.disable
    def prepare_mask(cls, seqlen, head_num, head_dim):
        if seqlen == cls.seqlen:
            return
        block_num = (seqlen + cls.block_size - 1) // cls.block_size
        block_num_per_frame = seqlen / cls.attnmap_frame_num / cls.block_size
        mask = generate_nbhd_mask(block_num_per_frame, block_num, cls.attnmap_frame_num, coefficient=cls.coefficient, min_width=cls.min_width, device="cpu")
        mask = mask.unsqueeze(0).repeat(head_num, 1, 1)
        block_rowcol_size = torch.ones(block_num, dtype=torch.int32) * cls.block_size
        block_rowcol_size[-1] = seqlen - cls.block_size * (block_num - 1)
        block_rowcol_size = block_rowcol_size.unsqueeze(0).repeat(head_num, 1)
        float_workspace_buffer = torch.empty(1024 * 1024 * 1024, dtype=torch.uint8, device="cuda:0")
        cls.sparse_wrapper = flashinfer.sparse.VariableBlockSparseAttentionWrapper(float_workspace_buffer, backend="fa2")
        cls.sparse_wrapper.plan(
            block_mask_map=mask,
            block_row_sz=block_rowcol_size,
            block_col_sz=block_rowcol_size,
            num_qo_heads=head_num,
            num_kv_heads=head_num,
            head_dim=head_dim,
            q_data_type=torch.bfloat16,
        )
        cls.seqlen = seqlen
        logger.info(f"NbhdAttnWeight Update: seqlen={seqlen}")
        sparsity = 1 - mask.sum().item() / mask.numel()
        logger.info(f"Attention sparsity: {sparsity}")

    def apply(
        self,
        q,
        k,
        v,
        cu_seqlens_q=None,
        cu_seqlens_kv=None,
        max_seqlen_q=None,
        max_seqlen_kv=None,
        model_cls=None,
    ):
        """
        q: [seqlen, head_num, head_dim]
        k: [seqlen, head_num, head_dim]
        v: [seqlen, head_num, head_dim]
        """
        self.prepare_mask(seqlen=q.shape[0], head_num=q.shape[1], head_dim=q.shape[2])
        q = q.transpose(0, 1)
        k = k.transpose(0, 1)
        v = v.transpose(0, 1)
        out = self.sparse_wrapper.run(q, k, v)
        out = out.transpose(0, 1)
        return out.reshape(out.shape[0], -1)
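The mask construction is easiest to see on a toy size. Below is a minimal CPU-only sketch of how generate_nbhd_mask and generate_qk_ranges fit together; the import path is taken from this diff, the sizes are arbitrary, and neither a GPU nor magi_attention is needed for this part.

# Minimal sketch: build the neighborhood block mask and the (q, k) ranges on CPU.
import torch
from lightx2v.common.ops.attn.nbhd_attn import generate_nbhd_mask, generate_qk_ranges

seqlen = 2048          # toy sequence length
block_size = 128
frames = 4             # plays the role of attnmap_frame_num
block_num = (seqlen + block_size - 1) // block_size
blocks_per_frame = seqlen / frames / block_size

mask = generate_nbhd_mask(blocks_per_frame, block_num, frames, device="cpu")
q_ranges, k_ranges = generate_qk_ranges(mask, block_size, seqlen)

# Each True block of the mask becomes one [start, end) pair that is later fed
# to flex_flash_attn_func as q_ranges / k_ranges.
print(mask.shape, q_ranges.shape, k_ranges.shape)
print("sparsity:", 1 - mask.sum().item() / mask.numel())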
lightx2v/common/ops/attn/radial_attn.py (new file, mode 100644) @ a1ebc651
import torch
from loguru import logger

try:
    from magi_attention.functional import flex_flash_attn_func as magi_ffa_func
except ImportError:
    magi_ffa_func = None

from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER

from .template import AttnWeightTemplate


def shrinkMaskStrict(mask, block_size=128):
    seqlen = mask.shape[0]
    block_num = seqlen // block_size
    mask = mask[: block_num * block_size, : block_num * block_size].view(block_num, block_size, block_num, block_size)
    col_densities = mask.sum(dim=1) / block_size
    # we want the minimum non-zero column density in the block
    non_zero_densities = col_densities > 0
    high_density_cols = col_densities > 1 / 3
    frac_high_density_cols = high_density_cols.sum(dim=-1) / (non_zero_densities.sum(dim=-1) + 1e-9)
    block_mask = frac_high_density_cols > 0.6
    block_mask[0:0] = True
    block_mask[-1:-1] = True
    return block_mask


def get_window_width(i, j, token_per_frame, sparse_type, num_frame, decay_factor=1, block_size=128, model_type=None):
    assert sparse_type in ["radial"]
    dist = abs(i - j)
    if model_type == "wan":
        if dist < 1:
            return token_per_frame
        if dist == 1:
            return token_per_frame // 2
    elif model_type == "hunyuan":
        if dist <= 1:
            return token_per_frame
    else:
        raise ValueError(f"Unknown model type: {model_type}")
    group = dist.bit_length()
    decay_length = 2 ** token_per_frame.bit_length() / 2**group * decay_factor
    threshold = block_size
    if decay_length >= threshold:
        return decay_length
    else:
        return threshold


def get_diagonal_split_mask(i, j, token_per_frame, sparse_type, device):
    assert sparse_type in ["radial"]
    dist = abs(i - j)
    group = dist.bit_length()
    threshold = 128  # hardcoded threshold for now, which is equal to block-size
    decay_length = 2 ** token_per_frame.bit_length() / 2**group
    if decay_length >= threshold:
        return torch.ones((token_per_frame, token_per_frame), device=device, dtype=torch.bool)

    split_factor = int(threshold / decay_length)
    modular = dist % split_factor
    if modular == 0:
        return torch.ones((token_per_frame, token_per_frame), device=device, dtype=torch.bool)
    else:
        return torch.zeros((token_per_frame, token_per_frame), device=device, dtype=torch.bool)


def gen_log_mask_shrinked(device, s, video_token_num, num_frame, block_size=128, sparse_type="log", decay_factor=0.5, model_type=None):
    """
    A more memory friendly version: we generate the attention mask of each frame pair at a time,
    shrink it, and store it into the final result.
    """
    final_log_mask = torch.zeros(((s + block_size - 1) // block_size, (s + block_size - 1) // block_size), device=device, dtype=torch.bool)
    token_per_frame = video_token_num // num_frame
    video_text_border = video_token_num // block_size

    col_indices = torch.arange(0, token_per_frame, device=device).view(1, -1)
    row_indices = torch.arange(0, token_per_frame, device=device).view(-1, 1)
    final_log_mask[video_text_border:] = True
    final_log_mask[:, video_text_border:] = True
    for i in range(num_frame):
        for j in range(num_frame):
            local_mask = torch.zeros((token_per_frame, token_per_frame), device=device, dtype=torch.bool)
            if j == 0 and model_type == "wan":  # this is attention sink
                local_mask = torch.ones((token_per_frame, token_per_frame), device=device, dtype=torch.bool)
            else:
                window_width = get_window_width(i, j, token_per_frame, sparse_type, num_frame, decay_factor=decay_factor, block_size=block_size, model_type=model_type)
                local_mask = torch.abs(col_indices - row_indices) <= window_width
                split_mask = get_diagonal_split_mask(i, j, token_per_frame, sparse_type, device)
                local_mask = torch.logical_and(local_mask, split_mask)

            remainder_row = (i * token_per_frame) % block_size
            remainder_col = (j * token_per_frame) % block_size
            # get the padded size
            all_length_row = remainder_row + ((token_per_frame - 1) // block_size + 1) * block_size
            all_length_col = remainder_col + ((token_per_frame - 1) // block_size + 1) * block_size
            padded_local_mask = torch.zeros((all_length_row, all_length_col), device=device, dtype=torch.bool)
            padded_local_mask[remainder_row : remainder_row + token_per_frame, remainder_col : remainder_col + token_per_frame] = local_mask
            # shrink the mask
            block_mask = shrinkMaskStrict(padded_local_mask, block_size=block_size)
            # set the block mask to the final log mask
            block_row_start = (i * token_per_frame) // block_size
            block_col_start = (j * token_per_frame) // block_size
            block_row_end = block_row_start + block_mask.shape[0]
            block_col_end = block_col_start + block_mask.shape[1]
            final_log_mask[block_row_start:block_row_end, block_col_start:block_col_end] = torch.logical_or(
                final_log_mask[block_row_start:block_row_end, block_col_start:block_col_end], block_mask
            )
    return final_log_mask


def generate_qk_ranges(mask, block_size, seqlen):
    indices = torch.nonzero(mask, as_tuple=False)  # shape: [N, 2]
    i_indices = indices[:, 0]  # [N]
    j_indices = indices[:, 1]  # [N]

    q_start = i_indices * block_size  # [N]
    q_end = torch.clamp((i_indices + 1) * block_size, max=seqlen)  # [N]
    k_start = j_indices * block_size  # [N]
    k_end = torch.clamp((j_indices + 1) * block_size, max=seqlen)  # [N]

    q_ranges = torch.stack([q_start, q_end], dim=1)  # [N, 2]
    k_ranges = torch.stack([k_start, k_end], dim=1)  # [N, 2]

    return q_ranges, k_ranges


@ATTN_WEIGHT_REGISTER("radial_attn")
class RadialAttnWeight(AttnWeightTemplate):
    block_size = 128
    seqlen = None
    attnmap_frame_num = None
    q_ranges = None
    k_ranges = None
    attn_type_map = None

    def __init__(self):
        self.config = {}

    @classmethod
    def prepare_mask(cls, seqlen):
        if seqlen == cls.seqlen:
            return
        mask = gen_log_mask_shrinked(
            device="cuda",
            s=seqlen,
            video_token_num=seqlen,
            num_frame=cls.attnmap_frame_num,
            block_size=cls.block_size,
            sparse_type="radial",
            decay_factor=0.2,
            model_type="wan",
        )
        q_ranges, k_ranges = generate_qk_ranges(mask, cls.block_size, seqlen)
        attn_type_map = torch.zeros(len(q_ranges), dtype=torch.int32, device="cuda")
        q_ranges = q_ranges.to(torch.int32).to("cuda")
        k_ranges = k_ranges.to(torch.int32).to("cuda")
        cls.seqlen = seqlen
        cls.q_ranges = q_ranges
        cls.k_ranges = k_ranges
        cls.attn_type_map = attn_type_map
        logger.info(f"RadialAttnWeight Update: seqlen={seqlen}")
        sparsity = 1 - mask.sum().item() / mask.numel()
        logger.info(f"Attention sparsity: {sparsity}")

    def apply(
        self,
        q,
        k,
        v,
        cu_seqlens_q=None,
        cu_seqlens_kv=None,
        max_seqlen_q=None,
        max_seqlen_kv=None,
        model_cls=None,
    ):
        """
        q: [seqlen, head_num, head_dim]
        k: [seqlen, head_num, head_dim]
        v: [seqlen, head_num, head_dim]
        """
        self.prepare_mask(seqlen=q.shape[0])
        out = magi_ffa_func(
            q,
            k,
            v,
            q_ranges=self.q_ranges,
            k_ranges=self.k_ranges,
            attn_type_map=self.attn_type_map,
            auto_range_merge=True,
        )[0]
        return out.reshape(out.shape[0], -1)
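As with the neighborhood variant, the radial block mask can be inspected on its own. The following is a small sketch, assuming the module path from this diff; shapes are toy-sized and the mask is built on CPU here, whereas RadialAttnWeight builds it on CUDA and converts it to (q, k) ranges.

# Minimal sketch: build the radial block mask on CPU and report its density.
import torch
from lightx2v.common.ops.attn.radial_attn import gen_log_mask_shrinked

seqlen = 4096
num_frame = 8
mask = gen_log_mask_shrinked(
    device="cpu",
    s=seqlen,
    video_token_num=seqlen,
    num_frame=num_frame,
    block_size=128,
    sparse_type="radial",
    decay_factor=0.2,
    model_type="wan",
)
print(mask.shape)                                 # [ceil(seqlen/128), ceil(seqlen/128)] boolean block mask
print("kept fraction:", mask.float().mean().item())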
lightx2v/common/ops/attn/ring_attn.py (new file, mode 100644) @ a1ebc651
import torch
import torch.distributed as dist
import torch.nn.functional as F
from loguru import logger

from lightx2v.utils.envs import *
from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER

from .template import AttnWeightTemplate
from .utils.ring_comm import RingComm

try:
    import flash_attn
    from flash_attn.flash_attn_interface import flash_attn_varlen_func
except ImportError:
    logger.info("flash_attn_varlen_func not found, please install flash_attn2 first")
    flash_attn_varlen_func = None


@torch.jit.script
def _update_out_and_lse(
    out,
    lse,
    block_out,
    block_lse,
):
    block_out = block_out.to(torch.float32)
    block_lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1)
    # new_lse = lse + torch.log(1 + torch.exp(block_lse - lse))
    # torch.exp(lse - new_lse) * out + torch.exp(block_lse - new_lse) * block_out
    # For additional context and discussion, please refer to:
    # https://github.com/zhuzilin/ring-flash-attention/pull/34#issuecomment-2076126795
    out = out - F.sigmoid(block_lse - lse) * (out - block_out)
    lse = lse - F.logsigmoid(lse - block_lse)
    return out, lse


@ATTN_WEIGHT_REGISTER("ring")
class RingAttnWeight(AttnWeightTemplate):
    def __init__(self):
        self.config = {}

    def apply(self, q, k, v, img_qkv_len, cu_seqlens_qkv, attention_module=None, seq_p_group=None, model_cls=None, use_fp8_comm=False):
        """
        Run ring attention over the combined image and text queries, keys, and values.

        Args:
            q (torch.Tensor): query tensor of shape [shard_seqlen, heads, hidden_dims]
            k (torch.Tensor): key tensor of shape [shard_seqlen, heads, hidden_dims]
            v (torch.Tensor): value tensor of shape [shard_seqlen, heads, hidden_dims]
            img_qkv_len (int): length of the image part of q/k/v
            cu_seqlens_qkv (torch.Tensor): cumulative sequence lengths covering the text and image segments
            attention_type (str): attention backend, defaults to "flash_attn2"

        Returns:
            torch.Tensor: the computed attention output
        """
        assert not use_fp8_comm, "RingAttn can't support fp8 comm now."
        # Rank of the current process and size of the sequence-parallel group
        cur_rank = dist.get_rank(seq_p_group)
        world_size = dist.get_world_size(seq_p_group)
        if len(cu_seqlens_qkv) == 3:
            txt_qkv_len = cu_seqlens_qkv[1] - img_qkv_len  # length of the text q/k/v
            txt_mask_len = cu_seqlens_qkv[2] - img_qkv_len  # length of the text mask
        elif len(cu_seqlens_qkv) == 2:
            txt_qkv_len = cu_seqlens_qkv[1] - img_qkv_len  # length of the text q/k/v
            txt_mask_len = 0

        # if RING_COMM is None:
        #     init_ring_comm()
        RING_COMM = RingComm(seq_p_group)

        # if len(cu_seqlens_qkv) == 3:
        #     txt_qkv_len = cu_seqlens_qkv[1] - img_qkv_len  # length of the text q/k/v
        #     txt_mask_len = cu_seqlens_qkv[2] - img_qkv_len  # length of the text mask
        # elif len(cu_seqlens_qkv) == 2:
        #     txt_qkv_len = cu_seqlens_qkv[1] - img_qkv_len  # length of the text q/k/v
        #     txt_mask_len = None

        q = q.unsqueeze(0)
        k = k.unsqueeze(0)
        v = v.unsqueeze(0)

        img_q, img_k, img_v = q[:, :img_qkv_len, :, :].contiguous(), k[:, :img_qkv_len, :, :].contiguous(), v[:, :img_qkv_len, :, :].contiguous()
        txt_q, txt_k, txt_v = (
            q[:, img_qkv_len : img_qkv_len + txt_qkv_len, :, :].contiguous(),
            k[:, img_qkv_len : img_qkv_len + txt_qkv_len, :, :].contiguous(),
            v[:, img_qkv_len : img_qkv_len + txt_qkv_len, :, :].contiguous(),
        )
        out, lse, next_k, next_v = None, None, None, None
        if len(cu_seqlens_qkv) == 3:
            q = torch.cat((img_q, txt_q), dim=1)
        k = img_k
        v = img_v

        for step in range(world_size):
            if step + 1 != world_size:
                next_k = RING_COMM.send_recv(k)
                next_v = RING_COMM.send_recv(v)
                RING_COMM.commit()
            if step + 1 == world_size:
                k = torch.cat((k, txt_k), dim=1)
                v = torch.cat((v, txt_v), dim=1)
            block_out, block_lse = self.ring_attn_sub(q, k, v)
            out, lse = self.update_out_and_lse(out, lse, block_out, block_lse)
            if step + 1 != world_size:
                RING_COMM.wait()
                k = next_k
                v = next_v

        attn1 = out.to(GET_DTYPE()).squeeze(0).reshape(img_qkv_len + txt_qkv_len, -1)
        if txt_mask_len > 0:
            attn2, *_ = flash_attn.flash_attn_interface._flash_attn_forward(
                q[:, -(txt_mask_len - txt_qkv_len) :, :, :].contiguous(),
                k[:, -(txt_mask_len - txt_qkv_len) :, :, :].contiguous(),
                v[:, -(txt_mask_len - txt_qkv_len) :, :, :].contiguous(),
                dropout_p=0.0,
                softmax_scale=q.shape[-1] ** (-0.5),
                causal=False,
                window_size_left=-1,
                window_size_right=-1,
                softcap=0.0,
                alibi_slopes=None,
                return_softmax=False,
            )
            attn2 = attn2.to(GET_DTYPE()).squeeze(0).reshape((txt_mask_len - txt_qkv_len), -1)
            attn1 = torch.cat([attn1, attn2], dim=0)
        return attn1

    def ring_attn_sub(self, q, k, v, dropout_p=0.0, softmax_scale=None, causal=False, window_size=(-1, -1), softcap=0.0, alibi_slopes=None, return_softmax=False):
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)
        block_out, block_lse, _, _ = flash_attn.flash_attn_interface._flash_attn_forward(
            q,
            k,
            v,
            dropout_p=dropout_p,
            softmax_scale=softmax_scale,
            causal=causal,
            window_size_left=window_size[0],
            window_size_right=window_size[1],
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            return_softmax=return_softmax,
        )
        return block_out, block_lse

    def update_out_and_lse(
        self,
        out,
        lse,
        block_out,
        block_lse,
        slice_=None,
    ):
        if out is None:
            if slice_ is not None:
                raise RuntimeError("first update_out_and_lse should not pass slice_ args")
            out = block_out.to(torch.float32)
            lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1)
        elif slice_ is not None:
            slice_out, slice_lse = out[slice_], lse[slice_]
            slice_out, slice_lse = _update_out_and_lse(slice_out, slice_lse, block_out, block_lse)
            out[slice_], lse[slice_] = slice_out, slice_lse
        else:
            out, lse = _update_out_and_lse(out, lse, block_out, block_lse)
        return out, lse
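The merge rule in _update_out_and_lse is the standard log-sum-exp combination of per-block attention results, written in sigmoid/logsigmoid form. A small plain-PyTorch sketch (no flash_attn, no distributed setup) showing that merging two key blocks this way matches attention over the full key set:

# Sketch: combine two attention blocks via the (out, lse) merge rule and compare
# against attention over all keys at once.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
S, D = 8, 16
q = torch.randn(S, D)
k = torch.randn(2 * S, D)
v = torch.randn(2 * S, D)
scale = D ** -0.5

def block_attn(q, k, v):
    scores = (q @ k.T) * scale               # [S, S_kv]
    lse = torch.logsumexp(scores, dim=-1)    # [S] log-sum-exp per query
    out = F.softmax(scores, dim=-1) @ v      # [S, D]
    return out, lse

out, lse = block_attn(q, k[:S], v[:S])           # first ring step (local keys)
blk_out, blk_lse = block_attn(q, k[S:], v[S:])   # second ring step (received keys)

# Merge exactly as _update_out_and_lse does (sigmoid / logsigmoid form).
merged_out = out - torch.sigmoid(blk_lse - lse)[:, None] * (out - blk_out)
merged_lse = lse - F.logsigmoid(lse - blk_lse)

ref_out, ref_lse = block_attn(q, k, v)           # attention over all keys at once
print(torch.allclose(merged_out, ref_out, atol=1e-5),
      torch.allclose(merged_lse, ref_lse, atol=1e-5))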
lightx2v/common/ops/attn/sage_attn.py (new file, mode 100644) @ a1ebc651
import torch
from loguru import logger

from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER

from .template import AttnWeightTemplate

if torch.cuda.is_available() and torch.cuda.get_device_capability(0) in [(8, 9), (12, 0)]:
    try:
        from sageattention import sageattn_qk_int8_pv_fp16_triton as sageattn
    except ImportError:
        logger.info("sageattn not found, please install sageattention first")
        sageattn = None
else:
    try:
        from sageattention import sageattn
    except ImportError:
        logger.info("sageattn not found, please install sageattention first")
        sageattn = None

try:
    from sageattn3 import sageattn3_blackwell
except ImportError:
    logger.info("sageattn3 not found, please install sageattention first")
    sageattn3_blackwell = None


@ATTN_WEIGHT_REGISTER("sage_attn2")
class SageAttn2Weight(AttnWeightTemplate):
    def __init__(self):
        self.config = {}

    def apply(
        self,
        q,
        k,
        v,
        cu_seqlens_q=None,
        cu_seqlens_kv=None,
        max_seqlen_q=None,
        max_seqlen_kv=None,
        model_cls=None,
    ):
        q, k, v = q.contiguous(), k.contiguous(), v.contiguous()
        if len(q.shape) == 3:
            bs = 1
            q, k, v = q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0)
        elif len(q.shape) == 4:
            bs = q.shape[0]
        x = sageattn(
            q,
            k,
            v,
            tensor_layout="NHD",
        ).view(bs * max_seqlen_q, -1)
        return x


@ATTN_WEIGHT_REGISTER("sage_attn3")
class SageAttn3Weight(AttnWeightTemplate):
    def __init__(self):
        self.config = {}

    def apply(
        self,
        q,
        k,
        v,
        cu_seqlens_q=None,
        cu_seqlens_kv=None,
        max_seqlen_q=None,
        max_seqlen_kv=None,
        model_cls=None,
    ):
        q, k, v = q.contiguous(), k.contiguous(), v.contiguous()
        if len(q.shape) == 3:
            bs = 1
            q, k, v = q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0)
        elif len(q.shape) == 4:
            bs = q.shape[0]
        x = sageattn3_blackwell(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)).transpose(1, 2).reshape(bs * max_seqlen_q, -1)
        return x
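A minimal usage sketch for the SageAttention wrappers. It assumes a CUDA device with the sageattention package installed and instantiates the class directly rather than going through the registry; the tensor sizes are illustrative only.

# Sketch: call SageAttn2Weight.apply on [seqlen, heads, head_dim] inputs.
import torch
from lightx2v.common.ops.attn.sage_attn import SageAttn2Weight

attn = SageAttn2Weight()
seqlen, heads, head_dim = 1024, 12, 128
q = torch.randn(seqlen, heads, head_dim, dtype=torch.bfloat16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

out = attn.apply(q, k, v, max_seqlen_q=seqlen)
print(out.shape)  # [seqlen, heads * head_dim]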
lightx2v/common/ops/attn/spassage_attn.py (new file, mode 100644) @ a1ebc651
import os

import torch

try:
    import spas_sage_attn
except ImportError:
    spas_sage_attn = None

from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER

from .template import AttnWeightTemplate


@ATTN_WEIGHT_REGISTER("spas_sage_attn")
class SageAttnWeight(AttnWeightTemplate):
    def __init__(self):
        self.config = {}

    @classmethod
    def apply(self, q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, model_cls=None, tensor_layout="HND"):
        q = q.unsqueeze(0)
        k = k.unsqueeze(0)
        v = v.unsqueeze(0)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        attn_out = spas_sage_attn.core.spas_sage2_attn_meansim_cuda(q, k, v, tensor_layout)
        _, H, N, D = attn_out.shape
        attn_out = attn_out.permute(2, 1, 3, 0).contiguous().view(N, H * D)
        return attn_out


if __name__ == "__main__":
    import matplotlib.pyplot as plt

    # 1. Build the inputs
    q = torch.randn(32760, 12, 128, dtype=torch.bfloat16).cuda()
    k = torch.randn(32760, 12, 128, dtype=torch.bfloat16).cuda()
    v = torch.randn(32760, 12, 128, dtype=torch.bfloat16).cuda()

    # 2. Compute attention directly with PyTorch
    q_ = q.float()
    k_ = k.float()
    v_ = v.float()
    attn_weights = torch.matmul(q_, k_.transpose(-2, -1)) / (128**0.5)
    attn_weights = torch.softmax(attn_weights, dim=-1)
    output_pt = torch.matmul(attn_weights, v_)

    # 3. Compute attention with spas_sage2_attn_meansim_cuda
    q = q.unsqueeze(0)  # shape: (1, 32760, 12, 128)
    k = k.unsqueeze(0)
    v = v.unsqueeze(0)
    q = q.transpose(1, 2)  # shape: (1, 12, 32760, 128)
    k = k.transpose(1, 2)
    v = v.transpose(1, 2)
    output_cuda = spas_sage_attn.core.spas_sage2_attn_meansim_cuda(q, k, v, tensor_layout="HND")
    output_cuda = output_cuda.float()

    # 4. Take the top-left [3000, 3000] crop, first head only
    output_pt_crop = output_pt[0, :3000, :3000].cpu().detach().numpy()
    output_cuda_crop = output_cuda[0, 0, :3000, :3000].cpu().detach().numpy()

    # 5. Save the images
    save_dir = os.path.expanduser("~/Log/10-22/")
    os.makedirs(save_dir, exist_ok=True)

    plt.imshow(output_pt_crop, aspect="auto")
    plt.title("PyTorch Attention (left-top 3000x3000)")
    plt.savefig(os.path.join(save_dir, "attn.png"))
    plt.close()

    plt.imshow(output_cuda_crop, aspect="auto")
    plt.title("spas_sage2_attn_meansim_cuda (left-top 3000x3000)")
    plt.savefig(os.path.join(save_dir, "spas_attn.png"))
    plt.close()
lightx2v/common/ops/attn/svg2_attn.py (new file, mode 100644) @ a1ebc651
from typing import Optional

# Please reinstall flashinfer by referring to https://github.com/svg-project/Sparse-VideoGen
try:
    import flashinfer
except ImportError:
    flashinfer = None

import torch
import triton
import triton.language as tl

from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER

from .svg2_attn_utils import (
    batch_kmeans_Euclid,
    identify_dynamic_map,
)
from .template import AttnWeightTemplate


@triton.jit
def _permute_kernel(
    X_ptr,
    IDX_ptr,
    Y_ptr,
    S: tl.constexpr,
    D: tl.constexpr,
    BLOCK_S: tl.constexpr,
):
    """Each program permutes BLOCK_S tokens across *all* hidden features (D). No inner python loop."""
    pid_bh = tl.program_id(0)
    tile_s = tl.program_id(1)

    # Offsets along sequence
    s_offsets = tile_s * BLOCK_S + tl.arange(0, BLOCK_S)
    token_mask = s_offsets < S

    # Gather source indices for these tokens
    idx_ptrs = IDX_ptr + pid_bh * S + s_offsets
    src_row_idx = tl.load(idx_ptrs, mask=token_mask, other=0).to(tl.int32)

    # Broadcast to create 2-D pointer matrix (BLOCK_S, D)
    d_offsets = tl.arange(0, D)
    src_ptrs = X_ptr + (pid_bh * S + src_row_idx[:, None]) * D + d_offsets[None, :]
    dst_ptrs = Y_ptr + (pid_bh * S + s_offsets[:, None]) * D + d_offsets[None, :]

    full_mask = token_mask[:, None]
    values = tl.load(src_ptrs, mask=full_mask, other=0.0)
    tl.store(dst_ptrs, values, mask=full_mask)


def permute_tensor_by_labels_triton(
    tensor: torch.Tensor,
    labels: Optional[torch.Tensor],
    dim: int,
    *,
    sorted_indices: Optional[torch.Tensor] = None,
):
    """
    Permute `tensor` along `dim` according to ascending order of `labels`.

    This is a Triton-accelerated replacement for the original implementation.
    It currently supports 4-D tensors of shape [B, H, S, D] and `dim == 2`.
    If these conditions are not met or the tensors reside on CPU, we fall back
    to the reference PyTorch implementation.
    """
    # Assertions – we only support the optimized CUDA path.
    assert dim == 2, "permute_tensor_by_labels currently only supports dim==2 (sequence dimension)"
    assert tensor.dim() == 4, "Expected tensor shape [B,H,S,D]"
    assert tensor.is_cuda, "permute_tensor_by_labels requires CUDA tensors"

    B, H, S, D = tensor.shape
    BH = B * H

    # Determine sorted indices
    if sorted_indices is not None:
        sorted_indices = sorted_indices.to(torch.int32).contiguous()
    else:
        assert labels is not None, "Either `labels` or `sorted_indices` must be provided."
        labels = labels.to(tensor.device)
        sorted_indices = torch.argsort(labels, dim=-1).to(torch.int32).contiguous()

    # Flatten tensor and allocate output
    inp_flat = tensor.reshape(BH, S, D).contiguous()
    out_flat = torch.empty_like(inp_flat)

    # Triton kernel tile size
    BLOCK_S = 64  # number of tokens per program, tunable
    n_s_tiles = triton.cdiv(S, BLOCK_S)
    grid = (BH, n_s_tiles)

    _permute_kernel[grid](inp_flat, sorted_indices, out_flat, S, D, BLOCK_S, num_warps=4)

    permuted_tensor = out_flat.reshape(B, H, S, D)
    return permuted_tensor, sorted_indices


@triton.jit
def _inverse_permute_kernel(
    X_ptr,
    IDX_ptr,
    Y_ptr,
    S: tl.constexpr,
    D: tl.constexpr,
    BLOCK_S: tl.constexpr,
):
    """Inverse permutation: scatter BLOCK_S tokens back in one shot."""
    pid_bh = tl.program_id(0)
    tile_s = tl.program_id(1)

    s_offsets = tile_s * BLOCK_S + tl.arange(0, BLOCK_S)
    token_mask = s_offsets < S

    idx_ptrs = IDX_ptr + pid_bh * S + s_offsets
    src_pos_idx = s_offsets.to(tl.int32)
    dst_pos_idx = tl.load(idx_ptrs, mask=token_mask, other=0).to(tl.int32)

    d_offsets = tl.arange(0, D)
    src_ptrs = X_ptr + (pid_bh * S + src_pos_idx[:, None]) * D + d_offsets[None, :]
    dst_ptrs = Y_ptr + (pid_bh * S + dst_pos_idx[:, None]) * D + d_offsets[None, :]

    full_mask = token_mask[:, None]
    values = tl.load(src_ptrs, mask=full_mask, other=0.0)
    tl.store(dst_ptrs, values, mask=full_mask)


def apply_inverse_permutation_triton(
    permuted_tensor: torch.Tensor,
    sorted_indices: torch.Tensor,
    dim: int,
):
    """
    Triton implementation of inverse permutation. Inverts the permutation applied by `permute_tensor_by_labels`.

    Args:
        permuted_tensor: (B, H, S, D).
        sorted_indices: (B, H, S).
        dim: Dimension along which to apply inverse permutation. Typically 2, meaning the sequence length dimension.

    Returns:
        Tensor of shape (B, H, S, D).
    """
    assert dim == 2, "apply_inverse_permutation currently only supports dim==2"
    assert permuted_tensor.dim() == 4, "Expected tensor shape [B,H,S,D]"
    assert permuted_tensor.is_cuda, "apply_inverse_permutation requires CUDA tensors"

    B, H, S, D = permuted_tensor.shape
    BH = B * H

    # Ensure index dtype
    sorted_indices = sorted_indices.to(torch.int32).contiguous()

    # Flatten inputs
    inp_flat = permuted_tensor.reshape(BH, S, D).contiguous()
    out_flat = torch.empty_like(inp_flat)

    BLOCK_S = 64
    n_s_tiles = triton.cdiv(S, BLOCK_S)
    grid = (BH, n_s_tiles)

    _inverse_permute_kernel[grid](inp_flat, sorted_indices, out_flat, S, D, BLOCK_S, num_warps=4)

    original_tensor = out_flat.reshape(B, H, S, D)
    return original_tensor


@ATTN_WEIGHT_REGISTER("svg2_attn")
class Svg2AttnWeight(AttnWeightTemplate):
    centroids_init = False
    num_q_centroids = 300
    num_k_centroids = 1000
    kmeans_iter_init = 50
    top_p_kmeans = 0.9
    min_kc_ratio = 0.10
    kmeans_iter_step = 2

    def __init__(self):
        self.config = {}

    def apply(
        self,
        q,
        k,
        v,
        cu_seqlens_q=None,
        cu_seqlens_kv=None,
        max_seqlen_q=None,
        max_seqlen_kv=None,
        model_cls=None,
    ):
        q = q.unsqueeze(0).transpose(1, 2)
        k = k.unsqueeze(0).transpose(1, 2)
        v = v.unsqueeze(0).transpose(1, 2)
        bs, num_heads, seq_len, dim = q.size()
        q_perm, k_perm, v_perm, dyn_map, qc_sz_s, kc_sz_s, q_sorted_indices = self.semantic_aware_permutation(q, k, v)
        output_permuted = self.dynamic_block_sparse_fwd_flashinfer(q_perm, k_perm, v_perm, dyn_map, qc_sz_s, kc_sz_s, is_cpu=False)
        attn_output = apply_inverse_permutation_triton(output_permuted, q_sorted_indices, dim=2)
        return attn_output.reshape(bs, num_heads, seq_len, dim).transpose(1, 2).reshape(bs * seq_len, -1)

    def dynamic_block_sparse_fwd_flashinfer(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        block_mask_map: torch.Tensor,
        block_row_sz: torch.Tensor,
        block_col_sz: torch.Tensor,
        is_cpu: bool = True,
    ):
        """
        Launcher for the Flashinfer dynamic block sparse attention kernel.

        Args:
            q (torch.Tensor): Query tensor, shape [B, H, S, D].
            k (torch.Tensor): Key tensor, shape [B, H, S, D].
            v (torch.Tensor): Value tensor, shape [B, H, S, D].
            block_mask_map (torch.Tensor): Boolean mask, shape [B, H, qc_num, kc_num]. Currently must be on CPU.
            block_row_sz (torch.Tensor): Query block sizes, shape [B, H, qc_num]. Currently must be on CPU.
            block_col_sz (torch.Tensor): Key block sizes, shape [B, H, kc_num]. Currently must be on CPU.
            is_cpu (bool): Whether to run on CPU. Flashinfer default is to run on CPU. We switch to GPU for faster planning. Default is True.
        """
        # Input shape check
        B, H, S, D = q.shape
        qc_num = block_row_sz.shape[-1]
        kc_num = block_col_sz.shape[-1]
        assert block_mask_map.shape == (B, H, qc_num, kc_num)
        assert all(t.device == torch.device("cpu") for t in [block_mask_map, block_row_sz, block_col_sz]) if is_cpu else True

        # Check if block_col_sz and block_row_sz are the same for each head
        assert torch.all(block_col_sz.sum(dim=2) == block_col_sz.sum(dim=2)[0, 0])
        assert torch.all(block_row_sz.sum(dim=2) == block_row_sz.sum(dim=2)[0, 0])

        # Prepare flashinfer wrapper
        float_workspace_buffer = torch.empty(128 * 1024 * 1024, device=q.device)
        vector_sparse_indices_buffer = torch.empty(1024 * 1024 * 1024, device=q.device)
        wrapper = flashinfer.sparse.VariableBlockSparseAttentionWrapper(float_workspace_buffer, backend="auto")
        wrapper.reset_workspace_buffer(
            float_workspace_buffer=wrapper._float_workspace_buffer,
            int_workspace_buffer=wrapper._int_workspace_buffer,
            vector_sparse_indices_buffer=vector_sparse_indices_buffer,  # Only reset this buffer size
            vector_sparse_indptr_buffer=wrapper._vector_sparse_indptr_buffer,
        )

        block_mask_map = block_mask_map.reshape(B * H, qc_num, kc_num)
        block_row_sz = block_row_sz.reshape(B * H, qc_num)
        block_col_sz = block_col_sz.reshape(B * H, kc_num)

        wrapper.plan(
            block_mask_map=block_mask_map,
            block_row_sz=block_row_sz,
            block_col_sz=block_col_sz,
            num_qo_heads=B * H,
            num_kv_heads=B * H,
            head_dim=D,
            q_data_type=q.dtype,
            kv_data_type=k.dtype,
        )
        # print_memory_usage("After plan")

        q = q.reshape(B * H, S, D)
        k = k.reshape(B * H, S, D)
        v = v.reshape(B * H, S, D)
        o = wrapper.run(q, k, v)  # [num_qo_heads, qo_len, head_dim]
        o = o.reshape(B, H, S, D)
        return o

    def semantic_aware_permutation(self, query, key, value):
        cfg, num_heads, seq_len, dim = query.size()

        # 1. Kmeans clustering
        qlabels, qcentroids, qcluster_sizes, qiter, klabels, kcentroids, kcluster_sizes, kiter = self.kmeans_clustering(query, key)

        # 2. Identify dynamic map
        q_cluster_sizes = qcluster_sizes.view(cfg, num_heads, self.num_q_centroids)
        k_cluster_sizes = kcluster_sizes.view(cfg, num_heads, self.num_k_centroids)
        dynamic_map = identify_dynamic_map(
            qcentroids.view(cfg, num_heads, self.num_q_centroids, dim),
            kcentroids.view(cfg, num_heads, self.num_k_centroids, dim),
            q_cluster_sizes,
            k_cluster_sizes,
            self.top_p_kmeans,
            self.min_kc_ratio,
        )

        # 3. Permute the query, key, value
        q_permuted, q_sorted_indices = permute_tensor_by_labels_triton(query, qlabels, dim=2)
        k_permuted, k_sorted_indices = permute_tensor_by_labels_triton(key, klabels, dim=2)
        v_permuted, v_sorted_indices = permute_tensor_by_labels_triton(value, klabels, dim=2, sorted_indices=k_sorted_indices)

        return q_permuted, k_permuted, v_permuted, dynamic_map, q_cluster_sizes, k_cluster_sizes, q_sorted_indices

    def kmeans_clustering(self, query, key):
        if not self.centroids_init:
            qlabels, qcentroids, qcluster_sizes, qiter, klabels, kcentroids, kcluster_sizes, kiter = self.kmeans_init(query, key)
            self.centroids_init = True
        else:
            qlabels, qcentroids, qcluster_sizes, qiter, klabels, kcentroids, kcluster_sizes, kiter = self.kmeans_step(query, key)
        return qlabels, qcentroids, qcluster_sizes, qiter, klabels, kcentroids, kcluster_sizes, kiter

    def kmeans_init(self, query, key):
        cfg, num_heads, seq_len, dim = query.size()
        qlabels, qcentroids, qcluster_sizes, qiter = batch_kmeans_Euclid(query.view(cfg * num_heads, seq_len, dim), n_clusters=self.num_q_centroids, max_iters=self.kmeans_iter_init)
        klabels, kcentroids, kcluster_sizes, kiter = batch_kmeans_Euclid(key.view(cfg * num_heads, seq_len, dim), n_clusters=self.num_k_centroids, max_iters=self.kmeans_iter_init)
        self.q_centroids = qcentroids
        self.k_centroids = kcentroids
        return qlabels, qcentroids, qcluster_sizes, qiter, klabels, kcentroids, kcluster_sizes, kiter

    def kmeans_step(self, query, key):
        cfg, num_heads, seq_len, dim = query.size()
        qlabels, qcentroids, qcluster_sizes, qiter = batch_kmeans_Euclid(
            query.view(cfg * num_heads, seq_len, dim),
            n_clusters=self.num_q_centroids,
            max_iters=self.kmeans_iter_step,
            init_centroids=self.q_centroids,
        )
        klabels, kcentroids, kcluster_sizes, kiter = batch_kmeans_Euclid(
            key.view(cfg * num_heads, seq_len, dim),
            n_clusters=self.num_k_centroids,
            max_iters=self.kmeans_iter_step,
            init_centroids=self.k_centroids,
        )
        self.q_centroids = qcentroids
        self.k_centroids = kcentroids
        return qlabels, qcentroids, qcluster_sizes, qiter, klabels, kcentroids, kcluster_sizes, kiter


if __name__ == "__main__":
    q, k, v = (
        torch.randn(32130, 40, 128, dtype=torch.bfloat16).cuda(),
        torch.randn(32130, 40, 128, dtype=torch.bfloat16).cuda(),
        torch.randn(32130, 40, 128, dtype=torch.bfloat16).cuda(),
    )
    svg2_attn = Svg2AttnWeight()
    print("Svg2AttnWeight initialized.")
    out = svg2_attn.apply(q, k, v)
    print(f"out: {out.shape}, {out.dtype}, {out.device}")
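The Triton permute/inverse-permute pair above is pure index bookkeeping: gather tokens into label-sorted order, run the block-sparse kernel on the permuted layout, then scatter the outputs back. A plain-PyTorch CPU sketch of that round trip, written as a reference for the index logic rather than the kernels themselves:

# Sketch: label-sorted gather followed by the inverse scatter recovers the input.
import torch

B, H, S, D = 1, 2, 16, 4
x = torch.randn(B, H, S, D)
labels = torch.randint(0, 5, (B * H, S))             # one cluster id per token

sorted_indices = torch.argsort(labels, dim=-1)       # gather order, as in _permute_kernel
x_flat = x.reshape(B * H, S, D)
x_perm = torch.gather(x_flat, 1, sorted_indices[..., None].expand(-1, -1, D))

# Inverse permutation, as in _inverse_permute_kernel: scatter rows back to their slots.
x_back = torch.empty_like(x_perm)
x_back.scatter_(1, sorted_indices[..., None].expand(-1, -1, D), x_perm)

print(torch.equal(x_back.reshape(B, H, S, D), x))    # True: the round trip is exact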
lightx2v/common/ops/attn/svg2_attn_utils.py (new file, mode 100644) @ a1ebc651
import torch
import torch.nn.functional as F
import triton
import triton.language as tl

try:
    from cuvs.cluster.kmeans import KMeansParams, fit
except ImportError:
    KMeansParams = None
    fit = None


# --- New functions ---
def density_calculation(dynamic_map, q_cluster_sizes, k_cluster_sizes):
    """
    Calculate the density of the dynamic map. Currently only batch size = 1 and head size = 1 are supported.

    Input:
        dynamic_map: [cfg, num_heads, qc_num, kc_num]
        q_cluster_sizes: [cfg, num_heads, qc_num]
        k_cluster_sizes: [cfg, num_heads, kc_num]
    """
    cfg, num_heads, qc_num, kc_num = dynamic_map.shape

    # Calculate the block size of each block
    clustered_block_size = q_cluster_sizes[:, :, :, None] * k_cluster_sizes[:, :, None, :]
    masked_block_size = clustered_block_size * dynamic_map

    # Calculate the density of each block
    density = torch.sum(masked_block_size, dim=(2, 3)) / torch.sum(clustered_block_size, dim=(2, 3))
    return density


# --- Functions from analyze/kmeans_rapidai.py ---
def pairwise_distance(x, y):
    """
    Computes pairwise squared Euclidean distance between two sets of points.
    """
    x_norm = (x**2).sum(1).view(-1, 1)
    y_norm = (y**2).sum(1).view(1, -1)
    dist = torch.clamp(x_norm + y_norm - 2.0 * torch.mm(x, torch.transpose(y, 0, 1)), min=0.0)
    return dist


def kmeans_predict(centroids, input_tensor):  # Removed unused params argument
    """
    Predict the labels for the input tensor using the centroids.
    """
    input_tensor = input_tensor.to(torch.float32)
    dist = pairwise_distance(input_tensor, centroids)
    labels = torch.argmin(dist, dim=1)
    return labels


def kmeans_rapidai(tensor, k, max_iter=5, tol=1e-4, init_method="Array", centroids_init=None):  # Renamed centroids to centroids_init
    """
    Performs K-means clustering using cuVS.
    """
    assert tensor.dtype == torch.float32, "Tensor must be float32 for cuVS KMeans"
    assert tensor.ndim == 2, f"Tensor must be 2D, but got {tensor.shape}"
    # assert init_method == "Array", "init_method must be 'Array' for now"

    L, D = tensor.shape

    # cuVS KMeans in RAPIDS >=23.10 uses 'centroids_init' for initial centroids
    current_centroids = centroids_init
    if current_centroids is None:
        # Default init: cuVS handles KMeansPlusPlus if centroids_init is None and init_method is KMeansPlusPlus
        # If you need to pass an empty tensor for cuVS to initialize:
        current_centroids = torch.empty(k, D, device=tensor.device, dtype=torch.float32)  # Or pass None
    else:
        assert current_centroids.dtype == torch.float32, "Initial centroids must be float32"
        assert current_centroids.shape == (
            k,
            D,
        ), f"Initial centroids shape mismatch, got {current_centroids.shape}, expected ({k}, {D})"
        # cuVS uses 'init_method="Array"' when 'centroids_init' is provided.

    # import IPython; IPython.embed()
    params = KMeansParams(n_clusters=k, max_iter=max_iter, tol=tol, init_method=init_method)  # Changed init_method to init

    # Call fit with centroids_init (can be None)
    new_centroids, inertia, n_iter_ = fit(params, tensor, current_centroids)  # Added handle=None

    labels = kmeans_predict(new_centroids, tensor)
    return labels, new_centroids, n_iter_


@triton.jit
def _centroid_update_kernel(
    x_ptr,  # *f16 [B, N, D]
    cluster_ptr,  # *i32 [B, N]
    sum_ptr,  # *f32 [B, K, D]
    count_ptr,  # *i32 [B, K]
    B: tl.constexpr,
    N: tl.constexpr,
    D: tl.constexpr,
    K: tl.constexpr,
    BLOCK_D: tl.constexpr,  # number of dims processed per program
):
    """Each program processes 1 point (token) across BLOCK_D dimensions with atomics."""
    pid = tl.program_id(axis=0)
    token_idx = pid  # range: [0, B * N)

    # Derive (b, n) indices
    b = token_idx // N
    n = token_idx % N

    # Pointers to the token features and its cluster id
    x_offset = (b * N + n) * D
    x_ptr = x_ptr + x_offset
    cluster_idx = tl.load(cluster_ptr + b * N + n)  # int32

    # Guard for invalid cluster ids (should not happen)
    cluster_idx = tl.where(cluster_idx < K, cluster_idx, 0)

    # Base pointer for this centroid in the output sum tensor
    centroid_base = (b * K + cluster_idx) * D

    # Process feature vector in chunks of BLOCK_D
    offs = tl.arange(0, BLOCK_D)
    for d_start in range(0, D, BLOCK_D):
        mask = offs + d_start < D
        feats = tl.load(x_ptr + d_start + offs, mask=mask, other=0.0)
        feats = feats.to(tl.float32)
        dest_ptr = sum_ptr + centroid_base + d_start + offs
        tl.atomic_add(dest_ptr, feats, mask=mask)

    # Update counts (only once per point)
    tl.atomic_add(count_ptr + b * K + cluster_idx, 1)


def triton_centroid_update_cosine(x_norm: torch.Tensor, cluster_ids: torch.Tensor, old_centroids: torch.Tensor):
    """Compute centroids using custom Triton kernel.

    Args:
        x_norm (Tensor): (B, N, D) normalized input vectors (float16/float32)
        cluster_ids (LongTensor): (B, N) cluster assignment per point
        old_centroids (Tensor): (B, K, D) previous centroids (same dtype as x_norm)

    Returns:
        Tensor: (B, K, D) updated and L2-normalized centroids (dtype == x_norm.dtype)
    """
    assert x_norm.is_cuda and cluster_ids.is_cuda, "Input tensors must be on CUDA device"
    B, N, D = x_norm.shape
    K = old_centroids.shape[1]
    assert cluster_ids.shape == (B, N)

    # Allocate accumulation buffers
    centroid_sums = torch.zeros((B, K, D), device=x_norm.device, dtype=torch.float32)
    centroid_counts = torch.zeros((B, K), device=x_norm.device, dtype=torch.int32)

    # Launch Triton kernel – one program per token
    total_tokens = B * N
    BLOCK_D = 128  # tuneable
    grid = (total_tokens,)

    _centroid_update_kernel[grid](
        x_norm,
        cluster_ids.to(torch.int32),
        centroid_sums,
        centroid_counts,
        B,
        N,
        D,
        K,
        BLOCK_D=BLOCK_D,
    )

    # Compute means; keep old centroid if empty cluster
    counts_f = centroid_counts.to(torch.float32).unsqueeze(-1).clamp(min=1.0)
    centroids = centroid_sums / counts_f

    # For clusters with zero count, revert to old centroids
    zero_mask = (centroid_counts == 0).unsqueeze(-1)
    centroids = torch.where(zero_mask, old_centroids.to(torch.float32), centroids)

    centroids = centroids.to(x_norm.dtype)
    centroids = F.normalize(centroids, p=2, dim=-1)
    return centroids


def torch_loop_centroid_update_cosine(x_norm: torch.Tensor, cluster_ids: torch.Tensor, old_centroids: torch.Tensor):
    """Reference Python implementation (double for-loop)"""
    B, N, D = x_norm.shape
    K = old_centroids.shape[1]
    new_centroids = torch.zeros_like(old_centroids)
    for b in range(B):
        for k in range(K):
            mask = cluster_ids[b] == k
            if mask.any():
                new_centroids[b, k] = F.normalize(x_norm[b][mask].mean(dim=0, dtype=x_norm.dtype), p=2, dim=0)
            else:
                new_centroids[b, k] = old_centroids[b, k]
    return new_centroids


def triton_centroid_update_euclid(x: torch.Tensor, cluster_ids: torch.Tensor, old_centroids: torch.Tensor):
    """Compute centroids for Euclidean KMeans using Triton.

    Args:
        x (Tensor): (B, N, D) input vectors (float16/float32)
        cluster_ids (LongTensor): (B, N) cluster assignment per point
        old_centroids (Tensor): (B, K, D) previous centroids (same dtype as x)

    Returns:
        Tensor: (B, K, D) updated centroids (dtype == x.dtype)
    """
    assert x.is_cuda and cluster_ids.is_cuda, "Input tensors must be on CUDA device"
    B, N, D = x.shape
    K = old_centroids.shape[1]
    assert cluster_ids.shape == (B, N)

    # Allocate accumulation buffers
    centroid_sums = torch.zeros((B, K, D), device=x.device, dtype=torch.float32)
    centroid_counts = torch.zeros((B, K), device=x.device, dtype=torch.int32)

    total_tokens = B * N
    BLOCK_D = 128  # tuneable
    grid = (total_tokens,)

    _centroid_update_kernel[grid](
        x,
        cluster_ids.to(torch.int32),
        centroid_sums,
        centroid_counts,
        B,
        N,
        D,
        K,
        BLOCK_D=BLOCK_D,
    )

    # Compute means; keep old centroid if empty cluster
    counts_f = centroid_counts.to(torch.float32).unsqueeze(-1).clamp(min=1.0)
    centroids = centroid_sums / counts_f

    # For clusters with zero count, revert to old centroids
    zero_mask = (centroid_counts == 0).unsqueeze(-1)
    centroids = torch.where(zero_mask, old_centroids.to(torch.float32), centroids)

    return centroids.to(x.dtype)


# ------------------------------ NEW: chunk-wise centroid update (sorted ids) ------------------------------
@triton.jit
def _centroid_update_chunk_kernel(
    x_ptr,  # *f16 / *f32 [B, N, D] – ORIGINAL ORDER
    sorted_idx_ptr,  # *i32 [B, N] – indices after sort
    sorted_cluster_ptr,  # *i32 [B, N] – cluster ids in sorted order
    sum_ptr,  # *f32 [B, K, D]
    count_ptr,  # *i32 [B, K]
    B: tl.constexpr,
    N: tl.constexpr,
    D: tl.constexpr,
    K: tl.constexpr,
    BLOCK_N: tl.constexpr,  # how many tokens (points) each program processes
):
    """Each program processes **BLOCK_N consecutive, already-sorted tokens**.

    Because the tokens are sorted by cluster id, identical ids appear in
    contiguous runs. We therefore accumulate a local sum/count for the
    current run and perform **a single atomic update per run**, instead of
    per-token.
    """
    # program indices – 2-D launch grid: (chunk_id, batch_id)
    pid_chunk = tl.program_id(axis=0)
    pid_b = tl.program_id(axis=1)

    b = pid_b
    chunk_start = pid_chunk * BLOCK_N  # position of the first token handled by this program

    # Nothing to do – out of range
    if chunk_start >= N:
        return

    # base pointers for this batch
    idx_batch_base = sorted_idx_ptr + b * N
    cid_batch_base = sorted_cluster_ptr + b * N
    x_batch_base = x_ptr + b * N * D  # for pointer arithmetic

    # helper aranges
    offs_token = tl.arange(0, BLOCK_N)
    offs_dim = tl.arange(0, D)

    # first token index & validity mask
    token_idx = chunk_start + offs_token
    valid_tok = token_idx < N

    first_token_idx = chunk_start
    last_token_idx = tl.minimum(chunk_start + BLOCK_N, N) - 1

    # Load first cluster id to initialise the running accumulator
    first_id = tl.load(cid_batch_base + first_token_idx)
    last_id = tl.load(cid_batch_base + last_token_idx)

    all_ids = tl.load(cid_batch_base + token_idx, mask=valid_tok, other=-1)
    all_tokens_idxs = tl.load(idx_batch_base + token_idx, mask=valid_tok, other=-1)  # [BLOCK_N]
    load_mask = all_tokens_idxs[:, None] * D + offs_dim[None, :]

    for cid in range(first_id, last_id + 1):
        cluster_mask = all_ids == cid
        cluster_size = tl.sum(cluster_mask.to(tl.int32))
        if cluster_size != 0:
            cluster_feats = tl.load(x_batch_base + load_mask, mask=cluster_mask[:, None], other=0.0)  # [BLOCK_N, D]
            cluster_feats = cluster_feats.to(tl.float32)
            sum_feats = tl.sum(cluster_feats, axis=0)
            dest_ptr = sum_ptr + (b * K + cid) * D + offs_dim
            tl.atomic_add(dest_ptr, sum_feats)
            tl.atomic_add(count_ptr + b * K + cid, cluster_size)


# ---------------------------------------------------------------------------------------------
def triton_centroid_update_sorted_cosine(x_norm: torch.Tensor, cluster_ids: torch.Tensor, old_centroids: torch.Tensor, *, BLOCK_N: int = 256):
    """Fast centroid update assuming **cluster_ids are sorted along N**.

    This helper will sort the assignments (together with `x_norm`) and launch the
    chunk kernel above. Compared to the naive per-token kernel it performs *one
    atomic add per run of identical ids* instead of per token, providing large
    speed-ups when clusters are reasonably sized.
    """
    assert x_norm.is_cuda and cluster_ids.is_cuda, "Inputs must be on CUDA"
    B, N, D = x_norm.shape
    K = old_centroids.shape[1]
    assert cluster_ids.shape == (B, N)

    # -------- sort per-batch --------
    sorted_cluster_ids, sorted_idx = torch.sort(cluster_ids, dim=-1)
    sorted_idx_int = sorted_idx.to(torch.int32)

    # accumulation buffers
    centroid_sums = torch.zeros((B, K, D), device=x_norm.device, dtype=torch.float32)
    centroid_cnts = torch.zeros((B, K), device=x_norm.device, dtype=torch.int32)

    grid = (triton.cdiv(N, BLOCK_N), B)

    _centroid_update_chunk_kernel[grid](
        x_norm,
        sorted_idx_int,
        sorted_cluster_ids.to(torch.int32),
        centroid_sums,
        centroid_cnts,
        B,
        N,
        D,
        K,
        BLOCK_N=BLOCK_N,
    )

    # finalise – convert to means, handle empty clusters, renormalise
    counts_f = centroid_cnts.to(torch.float32).unsqueeze(-1).clamp(min=1.0)
    centroids = centroid_sums / counts_f
    empty_mask = (centroid_cnts == 0).unsqueeze(-1)
    centroids = torch.where(empty_mask, old_centroids.to(torch.float32), centroids)
    centroids = centroids.to(x_norm.dtype)
    centroids = F.normalize(centroids, p=2, dim=-1)
    return centroids


def triton_centroid_update_sorted_euclid(x: torch.Tensor, cluster_ids: torch.Tensor, old_centroids: torch.Tensor, *, BLOCK_N: int = 256):
    """Fast centroid update for *Euclidean* KMeans assuming cluster IDs are pre-sorted.

    Parameters
    ----------
    x : Tensor [B, N, D]
        Input feature vectors (no normalization assumed).
    cluster_ids : LongTensor [B, N]
        Cluster assignment for each point.
    old_centroids : Tensor [B, K, D]
        Previous centroids (used to fill empty clusters).
    BLOCK_N : int, optional
        Tokens per Triton program (affects occupancy/perf).
    """
    assert x.is_cuda and cluster_ids.is_cuda, "Inputs must be on CUDA device"
    B, N, D = x.shape
    K = old_centroids.shape[1]

    # Batch-wise sort of cluster assignments
    sorted_cluster_ids, sorted_idx = torch.sort(cluster_ids, dim=-1)
    sorted_idx_int = sorted_idx.to(torch.int32)

    centroid_sums = torch.zeros((B, K, D), device=x.device, dtype=torch.float32)
    centroid_cnts = torch.zeros((B, K), device=x.device, dtype=torch.int32)

    grid = (triton.cdiv(N, BLOCK_N), B)

    _centroid_update_chunk_kernel[grid](
        x,  # original features
        sorted_idx_int,  # gather indices
        sorted_cluster_ids.to(torch.int32),
        centroid_sums,
        centroid_cnts,
        B,
        N,
        D,
        K,
        BLOCK_N=BLOCK_N,
    )

    # Convert sums to means; replace empty clusters with old centroids
    counts_f = centroid_cnts.to(torch.float32).unsqueeze(-1).clamp(min=1.0)
    centroids = centroid_sums / counts_f
    empty_mask = (centroid_cnts == 0).unsqueeze(-1)
    centroids = torch.where(empty_mask, old_centroids.to(torch.float32), centroids)
    return centroids.to(x.dtype), centroid_cnts
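A small GPU sketch (CUDA and Triton required) comparing triton_centroid_update_euclid above against a plain index_add_ reference; the import path follows this diff and the sizes are arbitrary.

# Sketch: the atomic-add centroid update should match a per-batch index_add_ mean.
import torch
from lightx2v.common.ops.attn.svg2_attn_utils import triton_centroid_update_euclid

B, N, D, K = 2, 1024, 64, 16
x = torch.randn(B, N, D, device="cuda", dtype=torch.float32)
ids = torch.randint(0, K, (B, N), device="cuda")
old = torch.randn(B, K, D, device="cuda", dtype=torch.float32)

got = triton_centroid_update_euclid(x, ids, old)

# Reference: per-batch sums and counts; empty clusters keep the old centroid.
ref = old.clone()
for b in range(B):
    sums = torch.zeros(K, D, device="cuda").index_add_(0, ids[b], x[b])
    counts = torch.bincount(ids[b], minlength=K)
    means = sums / counts.clamp(min=1).unsqueeze(-1).float()
    nonempty = counts > 0
    ref[b][nonempty] = means[nonempty]

print(torch.allclose(got, ref, atol=1e-4))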
# ===============================================================
# Triton kernel: compute nearest-centroid IDs (Euclidean distance)
# Inputs:
# x : (B, N, D) float16 / float32
# centroids : (B, K, D) same dtype as x
# x_sq : (B, N) float32 – pre-computed ||x||^2 per point
# Output:
# cluster_ids : (B, N) int32 – nearest centroid index per point
# ===============================================================
def
_ceil_div
(
a
:
int
,
b
:
int
)
->
int
:
return
(
a
+
b
-
1
)
//
b
# -----------------------------------------------------------------------------
# Auto-tuning setup – explore various tile sizes / warp counts
# -----------------------------------------------------------------------------
_TUNE_CONFIGS
=
[
triton
.
Config
({
"BLOCK_N"
:
BN
,
"BLOCK_K"
:
BK
},
num_stages
=
4
,
num_warps
=
wp
)
for
BN
in
[
32
,
64
,
128
]
for
BK
in
[
32
,
64
,
128
]
for
wp
in
[
4
,
8
]]
def
_cfg_keep
(
conf
):
"""Basic heuristic to prune unbalanced configs."""
BN
=
conf
.
kwargs
[
"BLOCK_N"
]
BK
=
conf
.
kwargs
[
"BLOCK_K"
]
# Avoid tiny tiles on many warps
if
BN
*
BK
<
32
*
32
and
conf
.
num_warps
>
4
:
return
False
return
True
_TUNE_CONFIGS
=
list
(
filter
(
_cfg_keep
,
_TUNE_CONFIGS
))
@
triton
.
autotune
(
_TUNE_CONFIGS
,
key
=
[
"N"
,
"K"
])
@
triton
.
jit
def
_euclid_assign_kernel
(
x_ptr
,
# *f16 / *f32 [B, N, D]
c_ptr
,
# *f16 / *f32 [B, K, D]
x_sq_ptr
,
# *f32 [B, N]
out_ptr
,
# *i32 [B, N]
B
:
tl
.
constexpr
,
N
:
tl
.
constexpr
,
K
:
tl
.
constexpr
,
D
:
tl
.
constexpr
,
stride_x_b
:
tl
.
constexpr
,
stride_x_n
:
tl
.
constexpr
,
stride_x_d
:
tl
.
constexpr
,
stride_c_b
:
tl
.
constexpr
,
stride_c_k
:
tl
.
constexpr
,
stride_c_d
:
tl
.
constexpr
,
stride_xsq_b
:
tl
.
constexpr
,
stride_xsq_n
:
tl
.
constexpr
,
stride_out_b
:
tl
.
constexpr
,
stride_out_n
:
tl
.
constexpr
,
BLOCK_N
:
tl
.
constexpr
,
BLOCK_K
:
tl
.
constexpr
,
):
"""Each program handles a tile of BLOCK_N points for a given batch element.
The kernel iterates over the centroid dimension K in chunks of BLOCK_K and
maintains the running minimum distance as well as the corresponding index
for every point in the tile.
"""
pid_n
=
tl
.
program_id
(
0
)
# tile index along N dimension
pid_b
=
tl
.
program_id
(
1
)
# batch index
n_start
=
pid_n
*
BLOCK_N
n_offsets
=
n_start
+
tl
.
arange
(
0
,
BLOCK_N
)
n_mask
=
n_offsets
<
N
# ------------------------------------------------------------------
# Load x tile (BLOCK_N, D)
# ------------------------------------------------------------------
offs_d
=
tl
.
arange
(
0
,
D
)
# Compute pointer for x block: base + b*stride_x_b + n*stride_x_n + d*stride_x_d
x_ptrs
=
x_ptr
+
pid_b
*
stride_x_b
+
n_offsets
[:,
None
]
*
stride_x_n
+
offs_d
[
None
,
:]
*
stride_x_d
x_tile
=
tl
.
load
(
x_ptrs
,
mask
=
n_mask
[:,
None
],
other
=
0.0
)
x_tile
=
x_tile
# compute in f32
# Pre-load x_sq for the tile (BLOCK_N,)
xsq_ptrs
=
x_sq_ptr
+
pid_b
*
stride_xsq_b
+
n_offsets
*
stride_xsq_n
x_sq_tile
=
tl
.
load
(
xsq_ptrs
,
mask
=
n_mask
,
other
=
0.0
).
to
(
tl
.
float32
)
# Init best distance / index
best_dist
=
tl
.
full
((
BLOCK_N
,),
3.4e38
,
tl
.
float32
)
# large number
best_idx
=
tl
.
zeros
((
BLOCK_N
,),
tl
.
int32
)
# ------------------------------------------------------------------
# Iterate over centroids in chunks of BLOCK_K
# ------------------------------------------------------------------
for
k_start
in
range
(
0
,
K
,
BLOCK_K
):
k_offsets
=
k_start
+
tl
.
arange
(
0
,
BLOCK_K
)
k_mask
=
k_offsets
<
K
# Load centroid tile (D, BLOCK_K)
c_ptrs
=
c_ptr
+
pid_b
*
stride_c_b
+
k_offsets
[
None
,
:]
*
stride_c_k
+
offs_d
[:,
None
]
*
stride_c_d
c_tile
=
tl
.
load
(
c_ptrs
,
mask
=
k_mask
[
None
,
:],
other
=
0.0
)
c_tile
=
c_tile
# Compute centroid squared norms (BLOCK_K,)
cent_sq
=
tl
.
sum
(
c_tile
*
c_tile
,
axis
=
0
).
to
(
tl
.
float32
)
# Compute cross term (BLOCK_N, BLOCK_K) = x_tile @ c_tile
cross
=
tl
.
dot
(
x_tile
,
c_tile
).
to
(
tl
.
float32
)
# float32
# Squared Euclidean distance
dist
=
x_sq_tile
[:,
None
]
+
cent_sq
[
None
,
:]
-
2.0
*
cross
dist
=
tl
.
maximum
(
dist
,
0.0
)
# Mask out invalid centroid columns before reduction
dist
=
tl
.
where
(
k_mask
[
None
,
:],
dist
,
3.4e38
)
curr_min
=
tl
.
min
(
dist
,
axis
=
1
)
curr_idx
=
tl
.
argmin
(
dist
,
axis
=
1
)
update
=
curr_min
<
best_dist
best_dist
=
tl
.
where
(
update
,
curr_min
,
best_dist
)
best_idx
=
tl
.
where
(
update
,
k_start
+
curr_idx
,
best_idx
)
# ------------------------------------------------------------------
# Write results
# ------------------------------------------------------------------
out_ptrs
=
out_ptr
+
pid_b
*
stride_out_b
+
n_offsets
*
stride_out_n
tl
.
store
(
out_ptrs
,
best_idx
,
mask
=
n_mask
)
# ---------------------------------------------------------------
# Python wrapper
# ---------------------------------------------------------------
def euclid_assign_triton(
    x: torch.Tensor,
    centroids: torch.Tensor,
    x_sq: torch.Tensor,
    out: torch.Tensor = None,
    *,
    BLOCK_N: int = 128,
    BLOCK_K: int = 128,
) -> torch.Tensor:
    """Return nearest-centroid indices using Triton kernel.
    Args:
        x         : (B, N, D) float16 / float32 (on CUDA)
        centroids : (B, K, D) same dtype/device as x
        x_sq      : (B, N) float32 – ||x||^2 per point (on CUDA)
    Returns:
        cluster_ids (B, N) int32 (callers can cast to int64 if desired)
    """
    assert x.is_cuda and centroids.is_cuda and x_sq.is_cuda, "All tensors must be on CUDA"
    # assert x.dtype in (torch.float16, torch.float32), "x must be fp16/fp32"
    assert centroids.dtype == x.dtype, "centroids dtype mismatch"
    B, N, D = x.shape
    K = centroids.shape[1]
    assert centroids.shape == (B, K, D), "centroids shape mismatch"
    assert x_sq.shape == (B, N), "x_sq shape mismatch"
    # x = x.contiguous()
    # centroids = centroids.contiguous()
    # x_sq = x_sq.contiguous()
    if out is None:
        out = torch.empty((B, N), device=x.device, dtype=torch.int64)
    # Strides (in elements)
    stride_x_b, stride_x_n, stride_x_d = x.stride()
    stride_c_b, stride_c_k, stride_c_d = centroids.stride()
    stride_xsq_b, stride_xsq_n = x_sq.stride()
    stride_out_b, stride_out_n = out.stride()
    grid = lambda META: (triton.cdiv(N, META["BLOCK_N"]), B)  # noqa
    _euclid_assign_kernel[grid](
        x,
        centroids,
        x_sq,
        out,
        B,
        N,
        K,
        D,
        stride_x_b,
        stride_x_n,
        stride_x_d,
        stride_c_b,
        stride_c_k,
        stride_c_d,
        stride_xsq_b,
        stride_xsq_n,
        stride_out_b,
        stride_out_n,
    )
    return out
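

# --- Editor's sketch (not part of the original file): a minimal sanity check of
# euclid_assign_triton against a plain PyTorch nearest-centroid assignment. The
# shapes, dtypes and the helper name below are illustrative assumptions; a small
# fraction of mismatches is expected because the kernel and the float32 reference
# accumulate in different precisions.
def _euclid_assign_sanity_sketch():
    B, N, K, D = 2, 1024, 64, 128
    x = torch.randn(B, N, D, device="cuda", dtype=torch.float16)
    centroids = torch.randn(B, K, D, device="cuda", dtype=torch.float16)
    x_sq = (x.float() ** 2).sum(dim=-1)  # (B, N) float32, as the wrapper's docstring expects
    ids_triton = euclid_assign_triton(x, centroids, x_sq)
    # Reference: dense (B, N, K) squared distances, argmin over the centroid axis
    ids_ref = (torch.cdist(x.float(), centroids.float()) ** 2).argmin(dim=-1)
    return (ids_triton == ids_ref).float().mean()  # fraction of agreeing assignments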


# 1. Euclidean
def _euclid_iter(x, x_sq, centroids):
    # cent_sq = (centroids ** 2).sum(dim=-1)
    # cross = torch.einsum('bnd,bkd->bnk', x, centroids)
    # dist_sq = (x_sq[:,:,None] + cent_sq[:,None,:] - 2.0 * cross).clamp_min_(0.0)
    # cluster_ids = dist_sq.argmin(dim=-1)
    cluster_ids = euclid_assign_triton(x, centroids, x_sq)
    centroids_new, cluster_sizes = triton_centroid_update_sorted_euclid(x, cluster_ids, centroids)
    # centroids_new = triton_centroid_update_euclid(x, cluster_ids, centroids)
    # centroids_new = centroids_new.clone()  # avoid CUDA graphs aliasing
    shift = (centroids_new - centroids).norm(dim=-1).max()
    return centroids_new, shift, cluster_ids, cluster_sizes


# 2. Cosine
def _cosine_iter(x_norm, centroids):
    cos_sim = torch.einsum("bnd,bkd->bnk", x_norm, centroids)
    cluster_ids = cos_sim.argmax(dim=-1)
    centroids_new = triton_centroid_update_cosine(x_norm, cluster_ids, centroids)
    # centroids_new = centroids_new.clone()
    shift = (centroids_new - centroids).norm(dim=-1).max()
    return centroids_new, shift, cluster_ids


# 3. Dot-product
def _dot_iter(x, centroids):
    sim = torch.einsum("bnd,bkd->bnk", x, centroids)
    cluster_ids = sim.argmax(dim=-1)
    centroids_new = triton_centroid_update_cosine(x, cluster_ids, centroids)
    # centroids_new = centroids_new.clone()
    shift = (centroids_new - centroids).norm(dim=-1).max()
    return centroids_new, shift, cluster_ids


COMPILE_FLAG = False

# Try to compile; if PyTorch < 2.0 or compile is not available, fallback to original function
try:
    if COMPILE_FLAG:
        _euclid_iter_compiled = torch.compile(_euclid_iter, dynamic=True, mode="reduce-overhead")
        _cosine_iter_compiled = torch.compile(_cosine_iter, dynamic=True, mode="reduce-overhead")
        _dot_iter_compiled = torch.compile(_dot_iter, dynamic=True, mode="reduce-overhead")
    else:
        _euclid_iter_compiled = _euclid_iter
        _cosine_iter_compiled = _cosine_iter
        _dot_iter_compiled = _dot_iter
except Exception:  # pragma: no cover
    _euclid_iter_compiled = _euclid_iter
    _cosine_iter_compiled = _cosine_iter
    _dot_iter_compiled = _dot_iter


def batch_kmeans_Euclid(x, n_clusters, max_iters=100, tol=1e-4, init_centroids=None, verbose=False):
    """
    Batched KMeans clustering in PyTorch using Euclidean distance.
    Args:
        x: Tensor of shape (B, N, D), batch_size B, N points per batch, D dims.
        n_clusters: Number of clusters.
        max_iters: Max number of iterations.
        tol: Relative tolerance for center movement.
        verbose: Print loss for each iter.
    Returns:
        cluster_ids: (B, N) LongTensor, cluster assignment for each point.
        centroids: (B, n_clusters, D) final cluster centers.
        cluster_sizes: (B, n_clusters) LongTensor, number of points per cluster.
        n_iters: actual number of iterations executed (int)
    """
    B, N, D = x.shape
    # Pre-compute squared L2 norm of all points (constant during iterations)
    x_sq = (x**2).sum(dim=-1)  # (B, N)
    if init_centroids is None:
        # Randomly select initial centers from x
        indices = torch.randint(0, N, (B, n_clusters), device=x.device)
        centroids = torch.gather(x, dim=1, index=indices[..., None].expand(-1, -1, D))  # (B, n_clusters, D)
    else:
        # centroids = init_centroids.clone()
        centroids = init_centroids
        centroids = centroids.view(B, n_clusters, D)
    for it in range(max_iters):
        # ---- compiled single iteration ----
        centroids_new, center_shift, cluster_ids, cluster_sizes = _euclid_iter_compiled(x, x_sq, centroids)
        # 4. Check for convergence
        if verbose:
            print(f"Iter {it}, center shift: {center_shift.item():.6f}")
        if center_shift < tol:
            break
        # centroids = centroids_new.clone()
        centroids = centroids_new
    # # --- compute cluster sizes ---
    # ones = torch.ones_like(cluster_ids, dtype=torch.int64)
    # cluster_sizes = torch.zeros(B, n_clusters, dtype=torch.int64, device=x.device)
    # cluster_sizes.scatter_add_(1, cluster_ids, ones)
    return cluster_ids, centroids, cluster_sizes, it + 1
    # return cluster_ids.clone(), centroids.clone(), cluster_sizes.clone(), it + 1


# batch_kmeans_Euclid = torch.compile(batch_kmeans_Euclid, dynamic=True, mode="reduce-overhead")
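

# --- Editor's sketch (not part of the original file): example calling convention for
# batch_kmeans_Euclid. The tensor sizes are placeholders chosen only for illustration.
def _batch_kmeans_euclid_usage_sketch():
    x = torch.randn(2, 4096, 64, device="cuda", dtype=torch.float16)
    cluster_ids, centroids, cluster_sizes, n_iters = batch_kmeans_Euclid(x, n_clusters=32, max_iters=20, tol=1e-4)
    # cluster_ids: (2, 4096) int64, centroids: (2, 32, 64), cluster_sizes: (2, 32),
    # and cluster_sizes.sum(dim=-1) equals the number of points per batch (4096).
    return cluster_ids.shape, centroids.shape, cluster_sizes.sum(dim=-1), n_iters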
def
batch_kmeans_Cosine
(
x
,
n_clusters
,
max_iters
=
100
,
tol
=
1e-4
,
init_centroids
=
None
,
verbose
=
False
):
"""
Batched KMeans clustering in PyTorch using Cosine similarity.
Args:
x: Tensor of shape (B, N, D), batch_size B, N points per batch, D dims.
n_clusters: Number of clusters.
max_iters: Max number of iterations.
tol: Relative tolerance for center movement.
verbose: Print loss for each iter.
Returns:
cluster_ids: (B, N) LongTensor, cluster assignment for each point.
centroids: (B, n_clusters, D) final cluster centers.
cluster_sizes: (B, n_clusters) LongTensor, number of points per cluster.
n_iters: actual number of iterations executed (int)
"""
B
,
N
,
D
=
x
.
shape
# Normalize input vectors for cosine similarity
x_norm
=
F
.
normalize
(
x
,
p
=
2
,
dim
=-
1
)
# (B, N, D)
if
init_centroids
is
None
:
# Randomly select initial centers from x_norm
indices
=
torch
.
randint
(
0
,
N
,
(
B
,
n_clusters
),
device
=
x
.
device
)
centroids
=
torch
.
gather
(
x_norm
,
dim
=
1
,
index
=
indices
[...,
None
].
expand
(
-
1
,
-
1
,
D
))
# (B, n_clusters, D)
else
:
centroids
=
init_centroids
centroids
=
centroids
.
view
(
B
,
n_clusters
,
D
)
centroids
=
F
.
normalize
(
centroids
,
p
=
2
,
dim
=-
1
)
# Ensure centroids are normalized
for
it
in
range
(
max_iters
):
# ---- compiled single iteration ----
centroids_new
,
center_shift
,
cluster_ids
=
_cosine_iter_compiled
(
x_norm
,
centroids
)
# 4. Check for convergence
if
verbose
:
print
(
f
"Iter
{
it
}
, center shift:
{
center_shift
.
item
():.
6
f
}
"
)
if
center_shift
<
tol
:
break
centroids
=
centroids_new
.
clone
()
# --- compute cluster sizes ---
ones
=
torch
.
ones_like
(
cluster_ids
,
dtype
=
torch
.
int64
)
cluster_sizes
=
torch
.
zeros
(
B
,
n_clusters
,
dtype
=
torch
.
int64
,
device
=
x
.
device
)
cluster_sizes
.
scatter_add_
(
1
,
cluster_ids
,
ones
)
return
cluster_ids
,
centroids
,
cluster_sizes
,
it
+
1
def
batch_kmeans_Dot
(
x
,
n_clusters
,
max_iters
=
100
,
tol
=
1e-4
,
init_centroids
=
None
,
verbose
=
False
):
"""
Batched KMeans clustering in PyTorch using raw dot-product as similarity.
"""
B
,
N
,
D
=
x
.
shape
if
init_centroids
is
None
:
# Randomly initialize centroids
indices
=
torch
.
randint
(
0
,
N
,
(
B
,
n_clusters
),
device
=
x
.
device
)
centroids
=
torch
.
gather
(
x
,
dim
=
1
,
index
=
indices
[...,
None
].
expand
(
-
1
,
-
1
,
D
))
else
:
centroids
=
init_centroids
centroids
=
centroids
.
view
(
B
,
n_clusters
,
D
)
for
it
in
range
(
max_iters
):
# ---- compiled single iteration ----
centroids_new
,
center_shift
,
cluster_ids
=
_dot_iter_compiled
(
x
,
centroids
)
# 4. Check for convergence
if
verbose
:
print
(
f
"Iter
{
it
}
(dot), center shift:
{
center_shift
.
item
():.
6
f
}
"
)
if
center_shift
<
tol
:
break
centroids
=
centroids_new
.
clone
()
# --- compute cluster sizes ---
ones
=
torch
.
ones_like
(
cluster_ids
,
dtype
=
torch
.
int64
)
cluster_sizes
=
torch
.
zeros
(
B
,
n_clusters
,
dtype
=
torch
.
int64
,
device
=
x
.
device
)
cluster_sizes
.
scatter_add_
(
1
,
cluster_ids
,
ones
)
return
cluster_ids
,
centroids
,
cluster_sizes
,
it
+
1
# --- Functions from analyze/kmeans_block_sparse_attention.py (helpers) ---
def permute_tensor_by_labels(tensor, labels, dim):
    labels = labels.to(tensor.device)
    sorted_indices = torch.argsort(labels, dim=-1)
    gather_indices = sorted_indices
    for i in range(dim + 1, tensor.dim()):
        gather_indices = gather_indices.unsqueeze(-1)
    expand_shape = list(tensor.shape)
    gather_indices = gather_indices.expand(expand_shape)
    permuted_tensor = torch.gather(tensor, dim, gather_indices)
    return permuted_tensor, sorted_indices


def apply_inverse_permutation(permuted_tensor, sorted_indices, dim):
    inverse_indices = torch.argsort(sorted_indices, dim=-1)
    gather_indices = inverse_indices
    for i in range(dim + 1, permuted_tensor.dim()):
        gather_indices = gather_indices.unsqueeze(-1)
    gather_indices = gather_indices.expand(permuted_tensor.shape)
    original_tensor = torch.gather(permuted_tensor, dim, gather_indices)
    return original_tensor


def weighted_softmax(scores, weights):
    input_dtype = scores.dtype
    scores = scores.float()
    weights = weights.float()
    max_score = torch.max(scores, dim=-1, keepdim=True)[0]
    exp_scores = torch.exp(scores - max_score)
    weighted_exp = weights * exp_scores
    softmax_out = weighted_exp / torch.sum(weighted_exp, dim=-1, keepdim=True).clamp(min=1e-12)
    return softmax_out.to(input_dtype)
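

# --- Editor's note (not part of the original file): for strictly positive weights,
# weighted_softmax(scores, w) equals softmax(scores + log(w)) along the last dim,
# since w * exp(s - m) = exp(s + log(w) - m). A quick numerical check of that
# identity, using made-up broadcastable shapes:
def _weighted_softmax_identity_sketch():
    scores = torch.randn(2, 4, 8, 16)
    weights = torch.rand(2, 4, 1, 16) + 0.1  # strictly positive, broadcasts over query rows
    a = weighted_softmax(scores, weights)
    b = torch.softmax(scores + weights.log(), dim=-1)
    return torch.allclose(a, b, atol=1e-5)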


def identify_dynamic_map(
    query_centroids,
    key_centroids,
    q_cluster_sizes,
    k_cluster_sizes,
    p,
    min_kc_ratio=0,
):
    B, H, qc_num, D = query_centroids.shape
    kc_num = key_centroids.shape[2]
    device = query_centroids.device
    attn_scores = torch.matmul(query_centroids, key_centroids.transpose(-2, -1)) / (D**0.5)
    k_weights = k_cluster_sizes.unsqueeze(-2).float()
    weighted_attn_probs = weighted_softmax(attn_scores, k_weights)
    sorted_probs, sorted_indices = torch.sort(weighted_attn_probs, dim=-1, descending=True)
    cumsum_probs = torch.cumsum(sorted_probs, dim=-1)
    remove_indices = cumsum_probs > p
    remove_indices[..., 1:] = remove_indices[..., :-1].clone()
    remove_indices[..., 0] = False
    if min_kc_ratio > 0:
        preserve_length = int(min_kc_ratio * kc_num)
        remove_indices[..., :preserve_length] = False
    sorted_clusters_to_keep = ~remove_indices
    dynamic_map = torch.zeros(B, H, qc_num, kc_num, dtype=torch.bool, device=device)
    dynamic_map.scatter_(-1, sorted_indices, sorted_clusters_to_keep)
    return dynamic_map
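

# --- Editor's sketch (not part of the original file): building a cluster-level
# keep/drop map from centroids. The shapes and the top-p threshold are illustrative
# assumptions, not values used elsewhere in this module.
def _identify_dynamic_map_sketch():
    B, H, C, D = 1, 2, 16, 64
    q_cent = torch.randn(B, H, C, D)
    k_cent = torch.randn(B, H, C, D)
    sizes = torch.full((B, H, C), 32)
    dmap = identify_dynamic_map(q_cent, k_cent, sizes, sizes, p=0.9, min_kc_ratio=0.1)
    # dmap is a boolean (B, H, C, C) mask: each query cluster keeps the smallest set of
    # key clusters whose size-weighted attention mass reaches p (plus a minimum fraction).
    return dmap.shape, dmap.float().mean()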
# --- Functions from analyze/dynamic_block_sparse_attention.py ---
def
dynamic_block_sparse_fwd_torch
(
q
,
k
,
v
,
dynamic_map
,
qc_size
,
kc_size
):
"""
Computes dynamic block sparse attention using pure PyTorch.
Args:
q (torch.Tensor): Query tensor, shape [B, H, S, D].
k (torch.Tensor): Key tensor, shape [B, H, S, D].
v (torch.Tensor): Value tensor, shape [B, H, S, D].
dynamic_map (torch.Tensor): Boolean mask, shape [B, H, qc_num, kc_num].
qc_size (torch.Tensor): Query block sizes, shape [B, H, qc_num].
kc_size (torch.Tensor): Key block sizes, shape [B, H, kc_num].
Returns:
torch.Tensor: Output tensor, shape [B, H, S, D].
"""
B
,
H
,
S
,
D
=
q
.
shape
qc_num
=
qc_size
.
shape
[
-
1
]
kc_num
=
kc_size
.
shape
[
-
1
]
device
=
q
.
device
dtype
=
q
.
dtype
# Ensure sequence lengths match sum of block sizes
assert
S
==
torch
.
sum
(
qc_size
[
0
,
0
,
:]),
"Sum of qc_size must equal S"
assert
S
==
torch
.
sum
(
kc_size
[
0
,
0
,
:]),
"Sum of kc_size must equal S"
# Precompute cumulative sizes for block indexing
# Add a 0 at the beginning for easier slicing
qc_cum_size
=
torch
.
cumsum
(
torch
.
cat
([
torch
.
zeros_like
(
qc_size
[...,
:
1
]),
qc_size
],
dim
=-
1
),
dim
=-
1
)
kc_cum_size
=
torch
.
cumsum
(
torch
.
cat
([
torch
.
zeros_like
(
kc_size
[...,
:
1
]),
kc_size
],
dim
=-
1
),
dim
=-
1
)
out
=
torch
.
zeros_like
(
q
)
scale
=
D
**-
0.5
# Naive implementation: Iterate through batch, head, and blocks
for
b
in
range
(
B
):
for
h
in
range
(
H
):
# Precompute start/end indices for this batch/head
q_starts
=
qc_cum_size
[
b
,
h
,
:
-
1
]
q_ends
=
qc_cum_size
[
b
,
h
,
1
:]
k_starts
=
kc_cum_size
[
b
,
h
,
:
-
1
]
k_ends
=
kc_cum_size
[
b
,
h
,
1
:]
# Iterate through query blocks
for
i
in
range
(
qc_num
):
q_start
,
q_end
=
q_starts
[
i
],
q_ends
[
i
]
q_block
=
q
[
b
,
h
,
q_start
:
q_end
,
:]
# Shape: [qc_i, D]
if
q_block
.
shape
[
0
]
==
0
:
continue
# Skip empty blocks
m_i
=
torch
.
full
((
q_block
.
shape
[
0
],
1
),
-
float
(
"inf"
),
device
=
device
,
dtype
=
dtype
)
l_i
=
torch
.
zeros
((
q_block
.
shape
[
0
],
1
),
device
=
device
,
dtype
=
dtype
)
acc_o_i
=
torch
.
zeros_like
(
q_block
)
# Shape: [qc_i, D]
# Iterate through key/value blocks for the current query block
for
j
in
range
(
kc_num
):
# Check if this block needs computation
if
dynamic_map
[
b
,
h
,
i
,
j
]:
k_start
,
k_end
=
k_starts
[
j
],
k_ends
[
j
]
k_block
=
k
[
b
,
h
,
k_start
:
k_end
,
:]
# Shape: [kc_j, D]
v_block
=
v
[
b
,
h
,
k_start
:
k_end
,
:]
# Shape: [kc_j, D]
if
k_block
.
shape
[
0
]
==
0
:
continue
# Skip empty blocks
# Compute attention scores for the block
# QK^T: [qc_i, D] @ [D, kc_j] -> [qc_i, kc_j]
s_ij
=
(
q_block
@
k_block
.
transpose
(
-
1
,
-
2
))
*
scale
# --- Online Softmax ---
# Find max score per query token in this block
m_ij
=
torch
.
max
(
s_ij
,
dim
=-
1
,
keepdim
=
True
)[
0
]
# Shape: [qc_i, 1]
# Update overall max score (m_i)
m_new
=
torch
.
maximum
(
m_i
,
m_ij
)
# Shape: [qc_i, 1]
# Calculate scaling factors for previous accumulator and current block
p_ij
=
torch
.
exp
(
s_ij
-
m_new
)
# Shape: [qc_i, kc_j]
exp_m_diff
=
torch
.
exp
(
m_i
-
m_new
)
# Shape: [qc_i, 1]
# Update softmax denominator (l_i)
l_i
=
(
l_i
*
exp_m_diff
)
+
torch
.
sum
(
p_ij
,
dim
=-
1
,
keepdim
=
True
)
# Shape: [qc_i, 1]
# Update output accumulator (acc_o_i)
# P_ij @ V_j: [qc_i, kc_j] @ [kc_j, D] -> [qc_i, D]
acc_o_i
=
(
acc_o_i
*
exp_m_diff
)
+
(
p_ij
@
v_block
)
# Shape: [qc_i, D]
# Update max score for next iteration
m_i
=
m_new
# Normalize the accumulated output
out
[
b
,
h
,
q_start
:
q_end
,
:]
=
acc_o_i
/
l_i
.
clamp
(
min
=
1e-12
)
# Avoid division by zero
return
out
# --- Triton Implementation ---
@
triton
.
jit
def
_dynamic_block_sparse_fwd_kernel
(
Q
,
K
,
V
,
Out
,
dynamic_map
,
qc_cum_size
,
kc_cum_size
,
stride_qb
,
stride_qh
,
stride_qs
,
stride_qd
,
stride_kb
,
stride_kh
,
stride_ks
,
stride_kd
,
stride_vb
,
stride_vh
,
stride_vs
,
stride_vd
,
stride_ob
,
stride_oh
,
stride_os
,
stride_od
,
stride_dmap_b
,
stride_dmap_h
,
stride_dmap_qc
,
stride_dmap_kc
,
stride_qcs_b
,
stride_qcs_h
,
stride_qcs_qc
,
stride_kcs_b
,
stride_kcs_h
,
stride_kcs_kc
,
B
,
H
,
S
,
D
,
scale
,
QC_NUM
:
tl
.
constexpr
,
KC_NUM
:
tl
.
constexpr
,
BLOCK_M
:
tl
.
constexpr
,
BLOCK_N
:
tl
.
constexpr
,
BLOCK_D
:
tl
.
constexpr
,
):
"""
Triton kernel for dynamic block sparse attention.
Each program computes attention for one query block within a batch/head.
Processes query block in chunks of BLOCK_M.
Iterates through key blocks, checking dynamic_map.
Processes key/value blocks in chunks of BLOCK_N.
Uses online softmax.
"""
# --- Grid Calculation ---
# Each program instance handles one query block for a specific batch and head
pid
=
tl
.
program_id
(
axis
=
0
)
B
*
H
*
QC_NUM
# Calculate batch, head, and query block index
pid_q_block_global
=
pid
# 0 to B*H*QC_NUM - 1
# pid_bh = pid // QC_NUM # Deprecated: Causes issues if QC_NUM is not constant across BH
# pid_q_block_idx = pid % QC_NUM
# Need to map pid (0.. B*H*QC_NUM-1) back to (b, h, q_block_idx)
# q_block_idx changes fastest, then h, then b
q_block_idx
=
pid_q_block_global
%
QC_NUM
pid_h_temp
=
pid_q_block_global
//
QC_NUM
h
=
pid_h_temp
%
H
b
=
pid_h_temp
//
H
# --- Load Q block info (start/end offsets) ---
qcs_offset
=
b
*
stride_qcs_b
+
h
*
stride_qcs_h
q_start_offset
=
tl
.
load
(
qc_cum_size
+
qcs_offset
+
q_block_idx
*
stride_qcs_qc
)
q_end_offset
=
tl
.
load
(
qc_cum_size
+
qcs_offset
+
(
q_block_idx
+
1
)
*
stride_qcs_qc
)
q_block_size
=
q_end_offset
-
q_start_offset
# Early exit if the query block is empty
if
q_block_size
==
0
:
return
# --- Pointers setup ---
q_ptr_base
=
Q
+
b
*
stride_qb
+
h
*
stride_qh
+
q_start_offset
*
stride_qs
k_ptr_base
=
K
+
b
*
stride_kb
+
h
*
stride_kh
v_ptr_base
=
V
+
b
*
stride_vb
+
h
*
stride_vh
out_ptr_base
=
Out
+
b
*
stride_ob
+
h
*
stride_oh
+
q_start_offset
*
stride_os
dmap_ptr
=
dynamic_map
+
b
*
stride_dmap_b
+
h
*
stride_dmap_h
+
q_block_idx
*
stride_dmap_qc
kcs_ptr
=
kc_cum_size
+
b
*
stride_kcs_b
+
h
*
stride_kcs_h
# --- Iterate over the query block rows in chunks of BLOCK_M ---
offs_qm
=
tl
.
arange
(
0
,
BLOCK_M
)
# Query block row offsets [0, 1, ..., BLOCK_M-1]
offs_d
=
tl
.
arange
(
0
,
BLOCK_D
)
# Dimension offsets [0, 1, ..., BLOCK_D-1]
for
q_chunk_start
in
range
(
0
,
q_block_size
,
BLOCK_M
):
q_chunk_rows
=
offs_qm
+
q_chunk_start
q_rows_mask
=
q_chunk_rows
<
q_block_size
# Mask for valid rows in this Q chunk [BLOCK_M]
# --- Initialize accumulators for this Q chunk ---
m_i
=
tl
.
zeros
([
BLOCK_M
],
dtype
=
tl
.
float32
)
-
float
(
"inf"
)
# Max score
l_i
=
tl
.
zeros
([
BLOCK_M
],
dtype
=
tl
.
float32
)
# Sum of exp(scores - max)
acc_o
=
tl
.
zeros
([
BLOCK_M
,
BLOCK_D
],
dtype
=
tl
.
float32
)
# Accumulated output
# --- Load Q chunk ---
q_ptr
=
q_ptr_base
+
q_chunk_rows
[:,
None
]
*
stride_qs
+
offs_d
[
None
,
:]
# Mask ensures we don't read out of bounds for the query block or dimension D
mask_q
=
q_rows_mask
[:,
None
]
&
(
offs_d
[
None
,
:]
<
D
)
q_chunk
=
tl
.
load
(
q_ptr
,
mask
=
mask_q
,
other
=
0.0
)
# Shape: [BLOCK_M, BLOCK_D]
# --- Inner loop over K blocks (columns in the block sparse map) ---
for
k_block_idx
in
range
(
KC_NUM
):
# --- Check dynamic_map: Is this block active? ---
is_active
=
tl
.
load
(
dmap_ptr
+
k_block_idx
*
stride_dmap_kc
)
if
is_active
:
# Process block only if it's active
# --- Load K block info (start/end offsets) ---
k_start_offset
=
tl
.
load
(
kcs_ptr
+
k_block_idx
*
stride_kcs_kc
)
k_end_offset
=
tl
.
load
(
kcs_ptr
+
(
k_block_idx
+
1
)
*
stride_kcs_kc
)
k_block_size
=
k_end_offset
-
k_start_offset
# Skip if the key block is empty (inside the active block check)
if
k_block_size
>
0
:
k_block_ptr_base
=
k_ptr_base
+
k_start_offset
*
stride_ks
v_block_ptr_base
=
v_ptr_base
+
k_start_offset
*
stride_vs
# --- Loop over K block chunks (size BLOCK_N) ---
offs_kn
=
tl
.
arange
(
0
,
BLOCK_N
)
# Key block row offsets [0, ..., BLOCK_N-1]
for
k_chunk_start
in
range
(
0
,
k_block_size
,
BLOCK_N
):
k_chunk_rows
=
offs_kn
+
k_chunk_start
k_rows_mask
=
k_chunk_rows
<
k_block_size
# Mask for valid rows in this K/V chunk [BLOCK_N]
# --- Load K, V chunks ---
k_ptr
=
k_block_ptr_base
+
k_chunk_rows
[:,
None
]
*
stride_ks
+
offs_d
[
None
,
:]
v_ptr
=
v_block_ptr_base
+
k_chunk_rows
[:,
None
]
*
stride_vs
+
offs_d
[
None
,
:]
# Mask ensures we don't read out of bounds for the key block or dimension D
mask_kv
=
k_rows_mask
[:,
None
]
&
(
offs_d
[
None
,
:]
<
D
)
k_chunk
=
tl
.
load
(
k_ptr
,
mask
=
mask_kv
,
other
=
0.0
)
# Shape: [BLOCK_N, BLOCK_D]
v_chunk
=
tl
.
load
(
v_ptr
,
mask
=
mask_kv
,
other
=
0.0
)
# Shape: [BLOCK_N, BLOCK_D]
# --- Compute Scores (Attention) ---
# QK^T: [BLOCK_M, BLOCK_D] @ [BLOCK_D, BLOCK_N] -> [BLOCK_M, BLOCK_N]
s_ij_chunk
=
tl
.
dot
(
q_chunk
,
k_chunk
.
T
)
*
scale
# IMPORTANT: Mask out scores corresponding to padding in K before max/softmax
# Set scores for invalid K elements to -inf
s_ij_chunk
=
tl
.
where
(
k_rows_mask
[
None
,
:],
s_ij_chunk
,
-
float
(
"inf"
))
# Mask out scores for invalid Q elements as well (although q_chunk elements are 0, avoid potential issues)
s_ij_chunk
=
tl
.
where
(
q_rows_mask
[:,
None
],
s_ij_chunk
,
-
float
(
"inf"
))
# --- Online Softmax Update ---
# Current max for this Q-K chunk interaction
m_ij_chunk
=
tl
.
max
(
s_ij_chunk
,
axis
=
1
)
# Shape: [BLOCK_M]
# Update overall max (across K chunks seen so far for this Q chunk)
m_new
=
tl
.
maximum
(
m_i
,
m_ij_chunk
)
# Shape: [BLOCK_M]
# Calculate scaled probabilities P_ij = exp(S_ij - m_new)
p_ij_chunk
=
tl
.
exp
(
s_ij_chunk
-
m_new
[:,
None
])
# Shape: [BLOCK_M, BLOCK_N]
# Zero out probabilities for masked K elements before summing
p_ij_chunk
=
tl
.
where
(
k_rows_mask
[
None
,
:],
p_ij_chunk
,
0.0
)
# Calculate scaling factor for previous accumulator state
exp_m_diff
=
tl
.
exp
(
m_i
-
m_new
)
# Shape: [BLOCK_M]
# Update sum accumulator (denominator L)
l_i_chunk
=
tl
.
sum
(
p_ij_chunk
,
axis
=
1
)
# Sum probabilities for this chunk, shape [BLOCK_M]
l_i
=
(
l_i
*
exp_m_diff
)
+
l_i_chunk
# Shape: [BLOCK_M]
# Update output accumulator O
# P_ij @ V_j: [BLOCK_M, BLOCK_N] @ [BLOCK_N, BLOCK_D] -> [BLOCK_M, BLOCK_D]
# Ensure p_ij_chunk is the correct dtype for dot product
p_ij_chunk_casted
=
p_ij_chunk
.
to
(
V
.
dtype
.
element_ty
)
o_chunk
=
tl
.
dot
(
p_ij_chunk_casted
,
v_chunk
)
# Shape: [BLOCK_M, BLOCK_D]
acc_o
=
(
acc_o
*
exp_m_diff
[:,
None
])
+
o_chunk
# Shape: [BLOCK_M, BLOCK_D]
# Update max for the next K chunk/block
m_i
=
m_new
# End of 'if is_active:' block
# --- End of loop over K blocks ---
# --- Finalize output for this Q chunk ---
# Normalize the accumulated output: O = acc_o / l_i
# Add epsilon to l_i to avoid division by zero
l_i_safe
=
tl
.
where
(
l_i
==
0
,
1.0
,
l_i
)
# Avoid 0/0 -> NaN
o_final_chunk
=
acc_o
/
(
l_i_safe
[:,
None
])
o_final_chunk
=
tl
.
where
(
l_i
[:,
None
]
==
0
,
0.0
,
o_final_chunk
)
# Ensure output is 0 if l_i was 0
# --- Write output chunk to global memory ---
out_ptr
=
out_ptr_base
+
q_chunk_rows
[:,
None
]
*
stride_os
+
offs_d
[
None
,
:]
# Mask ensures we don't write out of bounds for the query block or dimension D
mask_out
=
q_rows_mask
[:,
None
]
&
(
offs_d
[
None
,
:]
<
D
)
tl
.
store
(
out_ptr
,
o_final_chunk
.
to
(
Out
.
dtype
.
element_ty
),
mask
=
mask_out
)
# --- (Optional: Write L and M stats if needed) ---
# Example:
# l_ptr = L + b * stride_lb + h * stride_lh + (q_start_offset + q_chunk_rows) * stride_ls
# tl.store(l_ptr, l_i, mask=q_rows_mask)
# m_ptr = M + ...
# tl.store(m_ptr, m_i, mask=q_rows_mask)
# --- End of loop over Q chunks ---
def
dynamic_block_sparse_fwd_triton
(
q
,
k
,
v
,
dynamic_map
,
qc_size
,
kc_size
):
"""
Launcher for the Triton dynamic block sparse attention kernel.
Args:
q (torch.Tensor): Query tensor, shape [B, H, S, D].
k (torch.Tensor): Key tensor, shape [B, H, S, D].
v (torch.Tensor): Value tensor, shape [B, H, S, D].
dynamic_map (torch.Tensor): Boolean mask, shape [B, H, qc_num, kc_num].
qc_size (torch.Tensor): Query block sizes, shape [B, H, qc_num].
kc_size (torch.Tensor): Key block sizes, shape [B, H, kc_num].
Returns:
torch.Tensor: Output tensor, shape [B, H, S, D].
"""
B
,
H
,
S
,
D
=
q
.
shape
qc_num
=
qc_size
.
shape
[
-
1
]
kc_num
=
kc_size
.
shape
[
-
1
]
dtype
=
q
.
dtype
# Assertions and checks
assert
q
.
is_cuda
and
k
.
is_cuda
and
v
.
is_cuda
,
"Inputs must be CUDA tensors"
assert
dynamic_map
.
is_cuda
and
qc_size
.
is_cuda
and
kc_size
.
is_cuda
assert
q
.
dtype
==
k
.
dtype
==
v
.
dtype
,
"Input dtypes must match"
assert
dtype
in
[
torch
.
float16
,
torch
.
bfloat16
,
torch
.
float32
],
"Unsupported dtype"
assert
D
in
[
16
,
32
,
64
,
128
],
"Head dimension D must be 16, 32, 64, or 128 for efficient Triton dot"
# Ensure sequence lengths match sum of block sizes (check on one batch/head for simplicity)
assert
S
==
torch
.
sum
(
qc_size
[
0
,
0
,
:]),
"Sum of qc_size must equal S"
assert
S
==
torch
.
sum
(
kc_size
[
0
,
0
,
:]),
"Sum of kc_size must equal S"
# Ensure dynamic_map is boolean
assert
dynamic_map
.
dtype
==
torch
.
bool
# Calculate scale factor (using float32 for stability)
scale
=
D
**-
0.5
# Precompute cumulative sizes (on CPU/GPU, keep on device)
qc_cum_size
=
torch
.
cumsum
(
torch
.
cat
([
torch
.
zeros_like
(
qc_size
[...,
:
1
]),
qc_size
],
dim
=-
1
),
dim
=-
1
).
int
()
kc_cum_size
=
torch
.
cumsum
(
torch
.
cat
([
torch
.
zeros_like
(
kc_size
[...,
:
1
]),
kc_size
],
dim
=-
1
),
dim
=-
1
).
int
()
# Output tensor
out
=
torch
.
empty_like
(
q
)
# Triton kernel config
# BLOCK_M/N can be tuned. Larger blocks may increase occupancy but need more shared memory.
# Let's start with reasonably sized blocks.
BLOCK_D
=
D
if
S
<=
512
:
# Smaller sequence, smaller blocks might be ok
BLOCK_M
=
64
BLOCK_N
=
64
elif
S
<=
1024
:
BLOCK_M
=
64
BLOCK_N
=
64
else
:
# Larger sequence, potentially larger blocks
BLOCK_M
=
128
# Or keep 64? Test
BLOCK_N
=
64
# Adjust block size if sequence length is smaller
BLOCK_M
=
min
(
BLOCK_M
,
S
)
BLOCK_N
=
min
(
BLOCK_N
,
S
)
# Launch grid: One program per query block per batch/head
grid
=
(
B
*
H
*
qc_num
,)
# Call the kernel
_dynamic_block_sparse_fwd_kernel
[
grid
](
q
,
k
,
v
,
out
,
dynamic_map
,
qc_cum_size
,
kc_cum_size
,
q
.
stride
(
0
),
q
.
stride
(
1
),
q
.
stride
(
2
),
q
.
stride
(
3
),
k
.
stride
(
0
),
k
.
stride
(
1
),
k
.
stride
(
2
),
k
.
stride
(
3
),
v
.
stride
(
0
),
v
.
stride
(
1
),
v
.
stride
(
2
),
v
.
stride
(
3
),
out
.
stride
(
0
),
out
.
stride
(
1
),
out
.
stride
(
2
),
out
.
stride
(
3
),
dynamic_map
.
stride
(
0
),
dynamic_map
.
stride
(
1
),
dynamic_map
.
stride
(
2
),
dynamic_map
.
stride
(
3
),
qc_cum_size
.
stride
(
0
),
qc_cum_size
.
stride
(
1
),
qc_cum_size
.
stride
(
2
),
kc_cum_size
.
stride
(
0
),
kc_cum_size
.
stride
(
1
),
kc_cum_size
.
stride
(
2
),
B
,
H
,
S
,
D
,
scale
,
QC_NUM
=
qc_num
,
KC_NUM
=
kc_num
,
BLOCK_M
=
BLOCK_M
,
BLOCK_N
=
BLOCK_N
,
BLOCK_D
=
BLOCK_D
,
# num_warps=4 # Can tune this
)
return
out
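

# --- Editor's sketch (not part of the original file): a small parity check between
# the PyTorch reference above and the Triton launcher, assuming a CUDA device.
# Uniform block sizes are used purely for convenience; both paths accept ragged blocks.
def _dynamic_block_sparse_parity_sketch():
    B, H, S, D = 1, 2, 256, 64
    qc_num = kc_num = 4  # four query blocks and four key blocks of 64 tokens each
    q, k, v = (torch.randn(B, H, S, D, device="cuda", dtype=torch.float16) for _ in range(3))
    qc_size = torch.full((B, H, qc_num), S // qc_num, device="cuda", dtype=torch.int32)
    kc_size = torch.full((B, H, kc_num), S // kc_num, device="cuda", dtype=torch.int32)
    dynamic_map = torch.rand(B, H, qc_num, kc_num, device="cuda") > 0.5
    dynamic_map[..., 0] = True  # ensure every query block attends to at least one key block
    ref = dynamic_block_sparse_fwd_torch(q, k, v, dynamic_map, qc_size, kc_size)
    out = dynamic_block_sparse_fwd_triton(q, k, v, dynamic_map, qc_size, kc_size)
    return (ref - out).abs().max()  # expected to be small (fp16 vs fp32 accumulation)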


# ---------------- Batch wrapper for cuVS KMeans -----------------
def batch_kmeans_rapidai(x, n_clusters, max_iters=100, tol=1e-4, init_centroids=None, verbose=False):
    """Batched K-Means using RAPIDS cuVS implementation.
    Args:
        x (Tensor): (B, N, D) float32 tensor on CUDA.
        n_clusters (int): K.
        max_iters (int): maximum iterations.
        tol (float): tolerance.
        init_centroids (Tensor|None): optional initial centroids (B,K,D) float32.
        verbose (bool): print per-batch info.
    Returns:
        cluster_ids (B, N) LongTensor
        centroids (B, K, D) float32
        cluster_sizes (B, K) LongTensor
        n_iters_list (List[int]) iterations per batch
    """
    B, N, D = x.shape
    if init_centroids is not None:
        assert init_centroids.shape == (B, n_clusters, D)
    cluster_ids_list = []
    centroids_list = []
    # cluster_sizes_list = []
    n_iters_list = []
    x_float = x.float()
    if init_centroids is not None:
        init_centroids_float = init_centroids.float()
    for b in range(B):
        xb = x_float[b]
        if init_centroids is None:
            centroids_init_b = None
            init_method = "KMeansPlusPlus"
        else:
            centroids_init_b = init_centroids_float[b]
            init_method = "Array"
        labels_b, centroids_b, n_iter_b = kmeans_rapidai(xb, n_clusters, max_iter=max_iters, tol=tol, init_method=init_method, centroids_init=centroids_init_b)
        cluster_ids_list.append(labels_b.to(torch.int64))  # (N,)
        centroids_list.append(centroids_b)
        # cluster_sizes_b = torch.bincount(labels_b, minlength=n_clusters).to(torch.int64)
        # cluster_sizes_list.append(cluster_sizes_b)
        # n_iters_list.append(n_iter_b)
        # if verbose:
        #     print(f"Batch {b}: iters={n_iter_b}, cluster sizes min={cluster_sizes_b.min().item()} max={cluster_sizes_b.max().item()}")
    cluster_ids = torch.stack(cluster_ids_list, dim=0)  # (B,N)
    centroids = torch.stack(centroids_list, dim=0).to(x.dtype)  # (B,K,D)
    # cluster_sizes = torch.stack(cluster_sizes_list, dim=0)  # (B,K)
    # --- compute cluster sizes ---
    ones = torch.ones_like(cluster_ids, dtype=torch.int64)
    cluster_sizes = torch.zeros(B, n_clusters, dtype=torch.int64, device=x.device)
    cluster_sizes.scatter_add_(1, cluster_ids, ones)
    return cluster_ids, centroids, cluster_sizes, n_iters_list
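

# --- Editor's sketch (not part of the original file): batch_kmeans_rapidai follows the
# same calling convention as batch_kmeans_Euclid but requires the RAPIDS cuVS backed
# kmeans_rapidai to be importable. Warm-starting it from previously computed centroids
# could look like this (sizes are placeholders):
def _batch_kmeans_rapidai_usage_sketch():
    x = torch.randn(2, 4096, 64, device="cuda")  # float32, as the docstring expects
    _, warm_centroids, _, _ = batch_kmeans_Euclid(x, n_clusters=32, max_iters=5)
    ids, centroids, sizes, iters_per_batch = batch_kmeans_rapidai(
        x, n_clusters=32, max_iters=20, tol=1e-4, init_centroids=warm_centroids
    )
    return ids.shape, centroids.shape, sizes.shape, iters_per_batch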
lightx2v/common/ops/attn/svg_attn.py
0 → 100644
View file @
a1ebc651
import
math
from
functools
import
lru_cache
from
math
import
ceil
import
torch
import
torch.nn.functional
as
F
import
triton
import
triton.language
as
tl
from
loguru
import
logger
from
torch.nn.attention.flex_attention
import
create_block_mask
,
flex_attention
from
lightx2v.utils.registry_factory
import
ATTN_WEIGHT_REGISTER
from
.template
import
AttnWeightTemplate
@
triton
.
jit
def
wan_hidden_states_placement_kernel
(
hidden_states_ptr
,
# [cfg, num_heads, seq_len, head_dim] seq_len = context_length + num_frame * frame_size
hidden_states_out_ptr
,
# [cfg, num_heads, seq_len, head_dim]
best_mask_idx_ptr
,
# [cfg, num_heads]
hidden_states_stride_b
,
hidden_states_stride_h
,
hidden_states_stride_s
,
hidden_states_stride_d
,
mask_idx_stride_b
,
mask_idx_stride_h
,
seq_len
:
tl
.
constexpr
,
head_dim
:
tl
.
constexpr
,
context_length
:
tl
.
constexpr
,
num_frame
:
tl
.
constexpr
,
frame_size
:
tl
.
constexpr
,
BLOCK_SIZE
:
tl
.
constexpr
,
):
# Copy hidden_states to output
# range: [b, h, block_id * block_size: block_id * block_size + block_size, :]
cfg
=
tl
.
program_id
(
0
)
head
=
tl
.
program_id
(
1
)
block_id
=
tl
.
program_id
(
2
)
start_id
=
block_id
*
BLOCK_SIZE
end_id
=
start_id
+
BLOCK_SIZE
end_id
=
tl
.
where
(
end_id
>
seq_len
,
seq_len
,
end_id
)
# Load best mask idx (0 is spatial, 1 is temporal)
is_temporal
=
tl
.
load
(
best_mask_idx_ptr
+
cfg
*
mask_idx_stride_b
+
head
*
mask_idx_stride_h
)
offset_token
=
tl
.
arange
(
0
,
BLOCK_SIZE
)
+
start_id
offset_mask
=
offset_token
<
seq_len
offset_d
=
tl
.
arange
(
0
,
head_dim
)
if
is_temporal
:
patch_id
=
offset_token
//
num_frame
frame_id
=
offset_token
-
patch_id
*
num_frame
offset_store_token
=
tl
.
where
(
offset_token
>=
seq_len
-
context_length
,
offset_token
,
frame_id
*
frame_size
+
patch_id
)
offset_load
=
(
cfg
*
hidden_states_stride_b
+
head
*
hidden_states_stride_h
+
offset_token
[:,
None
]
*
hidden_states_stride_s
)
+
offset_d
[
None
,
:]
*
hidden_states_stride_d
offset_hidden_states
=
hidden_states_ptr
+
offset_load
offset_store
=
(
cfg
*
hidden_states_stride_b
+
head
*
hidden_states_stride_h
+
offset_store_token
[:,
None
]
*
hidden_states_stride_s
)
+
offset_d
[
None
,
:]
*
hidden_states_stride_d
offset_hidden_states_out
=
hidden_states_out_ptr
+
offset_store
# Maybe tune the pipeline here
hidden_states
=
tl
.
load
(
offset_hidden_states
,
mask
=
offset_mask
[:,
None
])
tl
.
store
(
offset_hidden_states_out
,
hidden_states
,
mask
=
offset_mask
[:,
None
])
else
:
offset_load
=
(
cfg
*
hidden_states_stride_b
+
head
*
hidden_states_stride_h
+
offset_token
[:,
None
]
*
hidden_states_stride_s
)
+
offset_d
[
None
,
:]
*
hidden_states_stride_d
offset_hidden_states
=
hidden_states_ptr
+
offset_load
offset_store
=
offset_load
offset_hidden_states_out
=
hidden_states_out_ptr
+
offset_store
# Maybe tune the pipeline here
hidden_states
=
tl
.
load
(
offset_hidden_states
,
mask
=
offset_mask
[:,
None
])
tl
.
store
(
offset_hidden_states_out
,
hidden_states
,
mask
=
offset_mask
[:,
None
])
def
wan_hidden_states_placement
(
hidden_states
,
hidden_states_out
,
best_mask_idx
,
context_length
,
num_frame
,
frame_size
):
cfg
,
num_heads
,
seq_len
,
head_dim
=
hidden_states
.
shape
BLOCK_SIZE
=
128
assert
seq_len
==
context_length
+
num_frame
*
frame_size
grid
=
(
cfg
,
num_heads
,
(
seq_len
+
BLOCK_SIZE
-
1
)
//
BLOCK_SIZE
)
wan_hidden_states_placement_kernel
[
grid
](
hidden_states
,
hidden_states_out
,
best_mask_idx
,
hidden_states
.
stride
(
0
),
hidden_states
.
stride
(
1
),
hidden_states
.
stride
(
2
),
hidden_states
.
stride
(
3
),
best_mask_idx
.
stride
(
0
),
best_mask_idx
.
stride
(
1
),
seq_len
,
head_dim
,
context_length
,
num_frame
,
frame_size
,
BLOCK_SIZE
,
)
return
hidden_states_out
@
triton
.
jit
def
wan_sparse_head_placement_kernel
(
query_ptr
,
key_ptr
,
value_ptr
,
# [cfg, num_heads, seq_len, head_dim] seq_len = context_length + num_frame * frame_size
query_out_ptr
,
key_out_ptr
,
value_out_ptr
,
# [cfg, num_heads, seq_len, head_dim]
best_mask_idx_ptr
,
# [cfg, num_heads]
query_stride_b
,
query_stride_h
,
query_stride_s
,
query_stride_d
,
mask_idx_stride_b
,
mask_idx_stride_h
,
seq_len
:
tl
.
constexpr
,
head_dim
:
tl
.
constexpr
,
context_length
:
tl
.
constexpr
,
num_frame
:
tl
.
constexpr
,
frame_size
:
tl
.
constexpr
,
BLOCK_SIZE
:
tl
.
constexpr
,
):
# Copy query, key, value to output
# range: [b, h, block_id * block_size: block_id * block_size + block_size, :]
cfg
=
tl
.
program_id
(
0
)
head
=
tl
.
program_id
(
1
)
block_id
=
tl
.
program_id
(
2
)
start_id
=
block_id
*
BLOCK_SIZE
end_id
=
start_id
+
BLOCK_SIZE
end_id
=
tl
.
where
(
end_id
>
seq_len
,
seq_len
,
end_id
)
# Load best mask idx (0 is spatial, 1 is temporal)
is_temporal
=
tl
.
load
(
best_mask_idx_ptr
+
cfg
*
mask_idx_stride_b
+
head
*
mask_idx_stride_h
)
offset_token
=
tl
.
arange
(
0
,
BLOCK_SIZE
)
+
start_id
offset_mask
=
offset_token
<
seq_len
offset_d
=
tl
.
arange
(
0
,
head_dim
)
if
is_temporal
:
frame_id
=
offset_token
//
frame_size
patch_id
=
offset_token
-
frame_id
*
frame_size
offset_store_token
=
tl
.
where
(
offset_token
>=
seq_len
-
context_length
,
offset_token
,
patch_id
*
num_frame
+
frame_id
)
offset_load
=
(
cfg
*
query_stride_b
+
head
*
query_stride_h
+
offset_token
[:,
None
]
*
query_stride_s
)
+
offset_d
[
None
,
:]
*
query_stride_d
offset_query
=
query_ptr
+
offset_load
offset_key
=
key_ptr
+
offset_load
offset_value
=
value_ptr
+
offset_load
offset_store
=
(
cfg
*
query_stride_b
+
head
*
query_stride_h
+
offset_store_token
[:,
None
]
*
query_stride_s
)
+
offset_d
[
None
,
:]
*
query_stride_d
offset_query_out
=
query_out_ptr
+
offset_store
offset_key_out
=
key_out_ptr
+
offset_store
offset_value_out
=
value_out_ptr
+
offset_store
# Maybe tune the pipeline here
query
=
tl
.
load
(
offset_query
,
mask
=
offset_mask
[:,
None
])
tl
.
store
(
offset_query_out
,
query
,
mask
=
offset_mask
[:,
None
])
key
=
tl
.
load
(
offset_key
,
mask
=
offset_mask
[:,
None
])
tl
.
store
(
offset_key_out
,
key
,
mask
=
offset_mask
[:,
None
])
value
=
tl
.
load
(
offset_value
,
mask
=
offset_mask
[:,
None
])
tl
.
store
(
offset_value_out
,
value
,
mask
=
offset_mask
[:,
None
])
else
:
offset_load
=
(
cfg
*
query_stride_b
+
head
*
query_stride_h
+
offset_token
[:,
None
]
*
query_stride_s
)
+
offset_d
[
None
,
:]
*
query_stride_d
offset_query
=
query_ptr
+
offset_load
offset_key
=
key_ptr
+
offset_load
offset_value
=
value_ptr
+
offset_load
offset_store
=
offset_load
offset_query_out
=
query_out_ptr
+
offset_store
offset_key_out
=
key_out_ptr
+
offset_store
offset_value_out
=
value_out_ptr
+
offset_store
# Maybe tune the pipeline here
query
=
tl
.
load
(
offset_query
,
mask
=
offset_mask
[:,
None
])
tl
.
store
(
offset_query_out
,
query
,
mask
=
offset_mask
[:,
None
])
key
=
tl
.
load
(
offset_key
,
mask
=
offset_mask
[:,
None
])
tl
.
store
(
offset_key_out
,
key
,
mask
=
offset_mask
[:,
None
])
value
=
tl
.
load
(
offset_value
,
mask
=
offset_mask
[:,
None
])
tl
.
store
(
offset_value_out
,
value
,
mask
=
offset_mask
[:,
None
])
def
wan_sparse_head_placement
(
query
,
key
,
value
,
query_out
,
key_out
,
value_out
,
best_mask_idx
,
context_length
,
num_frame
,
frame_size
):
cfg
,
num_heads
,
seq_len
,
head_dim
=
query
.
shape
BLOCK_SIZE
=
128
assert
seq_len
==
context_length
+
num_frame
*
frame_size
grid
=
(
cfg
,
num_heads
,
(
seq_len
+
BLOCK_SIZE
-
1
)
//
BLOCK_SIZE
)
wan_sparse_head_placement_kernel
[
grid
](
query
,
key
,
value
,
query_out
,
key_out
,
value_out
,
best_mask_idx
,
query
.
stride
(
0
),
query
.
stride
(
1
),
query
.
stride
(
2
),
query
.
stride
(
3
),
best_mask_idx
.
stride
(
0
),
best_mask_idx
.
stride
(
1
),
seq_len
,
head_dim
,
context_length
,
num_frame
,
frame_size
,
BLOCK_SIZE
,
)
def generate_temporal_head_mask_mod(context_length: int = 226, prompt_length: int = 226, num_frames: int = 13, token_per_frame: int = 1350, mul: int = 2):
    def round_to_multiple(idx):
        return ceil(idx / 128) * 128

    def temporal_mask_mod(b, h, q_idx, kv_idx):
        two_frame = round_to_multiple(mul * token_per_frame)
        temporal_head_mask = torch.abs(q_idx - kv_idx) <= two_frame
        # return temporal_head_mask
        first_frame_mask = kv_idx < token_per_frame
        video_mask = first_frame_mask | temporal_head_mask
        return video_mask

    return temporal_mask_mod


@lru_cache
def create_block_mask_cached(score_mod, B, H, M, N, device="cuda", _compile=False):
    block_mask = create_block_mask(score_mod, B, H, M, N, device=device, _compile=_compile)
    return block_mask


def prepare_flexattention(cfg_size, num_head, head_dim, dtype, device, context_length, prompt_length, num_frame, frame_size, diag_width=1, multiplier=2):
    assert diag_width == multiplier, f"{diag_width} is not equivalent to {multiplier}"
    seq_len = context_length + num_frame * frame_size
    mask_mod = generate_temporal_head_mask_mod(context_length, prompt_length, num_frame, frame_size, mul=multiplier)
    block_mask = create_block_mask_cached(mask_mod, None, None, seq_len, seq_len, device=device, _compile=True)
    return block_mask


def sparsity_to_width(sparsity, context_length, num_frame, frame_size):
    seq_len = context_length + num_frame * frame_size
    total_elements = seq_len**2
    sparsity = (sparsity * total_elements - 2 * seq_len * context_length) / total_elements
    width = seq_len * (1 - math.sqrt(1 - sparsity))
    width_frame = width / frame_size
    return width_frame
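

# --- Editor's note (not part of the original file): sparsity_to_width inverts the area
# of a banded attention mask. After discounting the first-frame / context sink term
# (2 * seq_len * context_length), it solves 1 - (1 - width / seq_len)^2 = adjusted_sparsity
# for width and reports it in units of frames. A worked example: with context_length = 0,
# num_frame = 13, frame_size = 1350 and sparsity = 0.25, seq_len = 17550 and
# width_frame = 13 * (1 - sqrt(0.75)) ≈ 1.74 frames.
def _sparsity_to_width_example():
    return sparsity_to_width(0.25, context_length=0, num_frame=13, frame_size=1350)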
def
get_attention_mask
(
mask_name
,
sample_mse_max_row
,
context_length
,
num_frame
,
frame_size
):
attention_mask
=
torch
.
zeros
((
context_length
+
num_frame
*
frame_size
,
context_length
+
num_frame
*
frame_size
),
device
=
"cpu"
)
# TODO: fix hard coded mask
if
mask_name
==
"spatial"
:
pixel_attn_mask
=
torch
.
zeros_like
(
attention_mask
,
dtype
=
torch
.
bool
,
device
=
"cpu"
)
pixel_attn_mask
[:,
:
frame_size
]
=
1
# First Frame Sink
block_size
,
block_thres
=
128
,
frame_size
*
2
num_block
=
math
.
ceil
(
num_frame
*
frame_size
/
block_size
)
for
i
in
range
(
num_block
):
for
j
in
range
(
num_block
):
if
abs
(
i
-
j
)
<
block_thres
//
block_size
:
pixel_attn_mask
[
i
*
block_size
:
(
i
+
1
)
*
block_size
,
j
*
block_size
:
(
j
+
1
)
*
block_size
]
=
1
attention_mask
=
pixel_attn_mask
else
:
pixel_attn_mask
=
torch
.
zeros_like
(
attention_mask
,
dtype
=
torch
.
bool
,
device
=
"cpu"
)
pixel_attn_mask
[:,
:
frame_size
]
=
1
# First Frame Sink
block_size
,
block_thres
=
128
,
frame_size
*
2
num_block
=
math
.
ceil
(
num_frame
*
frame_size
/
block_size
)
for
i
in
range
(
num_block
):
for
j
in
range
(
num_block
):
if
abs
(
i
-
j
)
<
block_thres
//
block_size
:
pixel_attn_mask
[
i
*
block_size
:
(
i
+
1
)
*
block_size
,
j
*
block_size
:
(
j
+
1
)
*
block_size
]
=
1
pixel_attn_mask
=
pixel_attn_mask
.
reshape
(
frame_size
,
num_frame
,
frame_size
,
num_frame
).
permute
(
1
,
0
,
3
,
2
).
reshape
(
frame_size
*
num_frame
,
frame_size
*
num_frame
)
attention_mask
=
pixel_attn_mask
attention_mask
=
attention_mask
[:
sample_mse_max_row
].
cuda
()
return
attention_mask
@
ATTN_WEIGHT_REGISTER
(
"svg_attn"
)
class
SvgAttnWeight
(
AttnWeightTemplate
):
head_num
=
None
head_dim
=
None
sample_mse_max_row
=
None
num_sampled_rows
=
None
context_length
=
None
attnmap_frame_num
=
None
seqlen
=
None
sparsity
=
None
mask_name_list
=
[
"spatial"
,
"temporal"
]
attention_masks
=
None
block_mask
=
None
@
classmethod
def
prepare
(
cls
,
head_num
,
head_dim
,
sample_mse_max_row
,
num_sampled_rows
,
context_length
,
sparsity
):
cls
.
head_num
=
head_num
cls
.
head_dim
=
head_dim
cls
.
sample_mse_max_row
=
sample_mse_max_row
cls
.
num_sampled_rows
=
num_sampled_rows
cls
.
context_length
=
context_length
cls
.
sparsity
=
sparsity
torch
.
_dynamo
.
config
.
cache_size_limit
=
192
*
3
torch
.
_dynamo
.
config
.
accumulated_cache_size_limit
=
192
*
3
logger
.
info
(
f
"SvgAttnWeight Prepare: head_num=
{
head_num
}
, head_dim=
{
head_dim
}
, sample_mse_max_row=
{
sample_mse_max_row
}
, num_sampled_rows=
{
num_sampled_rows
}
, context_length=
{
context_length
}
, sparsity=
{
sparsity
}
"
)
def
__init__
(
self
):
self
.
config
=
{}
self
.
sparse_attention
=
torch
.
compile
(
flex_attention
,
dynamic
=
False
,
mode
=
"max-autotune-no-cudagraphs"
)
@
classmethod
def
prepare_mask
(
cls
,
seqlen
):
# Use class attributes so updates affect all instances of this class
if
seqlen
==
cls
.
seqlen
:
return
frame_size
=
seqlen
//
cls
.
attnmap_frame_num
cls
.
attention_masks
=
[
get_attention_mask
(
mask_name
,
cls
.
sample_mse_max_row
,
cls
.
context_length
,
cls
.
attnmap_frame_num
,
frame_size
)
for
mask_name
in
cls
.
mask_name_list
]
multiplier
=
diag_width
=
sparsity_to_width
(
cls
.
sparsity
,
cls
.
context_length
,
cls
.
attnmap_frame_num
,
frame_size
)
cls
.
block_mask
=
prepare_flexattention
(
1
,
cls
.
head_num
,
cls
.
head_dim
,
torch
.
bfloat16
,
"cuda"
,
cls
.
context_length
,
cls
.
context_length
,
cls
.
attnmap_frame_num
,
frame_size
,
diag_width
=
diag_width
,
multiplier
=
multiplier
)
cls
.
seqlen
=
seqlen
logger
.
info
(
f
"SvgAttnWeight Update: seqlen=
{
seqlen
}
"
)
def
apply
(
self
,
q
,
k
,
v
,
cu_seqlens_q
=
None
,
cu_seqlens_kv
=
None
,
max_seqlen_q
=
None
,
max_seqlen_kv
=
None
,
model_cls
=
None
,
):
q
=
q
.
unsqueeze
(
0
).
transpose
(
1
,
2
)
k
=
k
.
unsqueeze
(
0
).
transpose
(
1
,
2
)
v
=
v
.
unsqueeze
(
0
).
transpose
(
1
,
2
)
bs
,
num_heads
,
seq_len
,
dim
=
q
.
size
()
self
.
prepare_mask
(
seq_len
)
sampled_mses
=
self
.
sample_mse
(
q
,
k
,
v
)
best_mask_idx
=
torch
.
argmin
(
sampled_mses
,
dim
=
0
)
output_hidden_states
=
torch
.
zeros_like
(
q
)
query_out
,
key_out
,
value_out
=
torch
.
zeros_like
(
q
),
torch
.
zeros_like
(
k
),
torch
.
zeros_like
(
v
)
query_out
,
key_out
,
value_out
=
self
.
fast_sparse_head_placement
(
q
,
k
,
v
,
query_out
,
key_out
,
value_out
,
best_mask_idx
,
self
.
context_length
,
self
.
attnmap_frame_num
,
seq_len
//
self
.
attnmap_frame_num
)
hidden_states
=
self
.
sparse_attention
(
query_out
,
key_out
,
value_out
)
wan_hidden_states_placement
(
hidden_states
,
output_hidden_states
,
best_mask_idx
,
self
.
context_length
,
self
.
attnmap_frame_num
,
seq_len
//
self
.
attnmap_frame_num
)
return
output_hidden_states
.
reshape
(
bs
,
num_heads
,
seq_len
,
dim
).
transpose
(
1
,
2
).
reshape
(
bs
*
seq_len
,
-
1
)
def
fast_sparse_head_placement
(
self
,
query
,
key
,
value
,
query_out
,
key_out
,
value_out
,
best_mask_idx
,
context_length
,
num_frame
,
frame_size
):
wan_sparse_head_placement
(
query
,
key
,
value
,
query_out
,
key_out
,
value_out
,
best_mask_idx
,
context_length
,
num_frame
,
frame_size
)
return
query_out
,
key_out
,
value_out
def
sample_mse
(
self
,
query
,
key
,
value
):
cfg
,
num_heads
,
seq_len
,
dim
=
query
.
size
()
num_sampled_rows
=
min
(
self
.
num_sampled_rows
,
seq_len
)
sampled_rows
=
torch
.
randint
(
low
=
0
,
high
=
self
.
sample_mse_max_row
,
size
=
(
num_sampled_rows
,))
sampled_q
=
query
[:,
:,
sampled_rows
,
:]
sampled_qk_scores
=
torch
.
matmul
(
sampled_q
,
key
.
transpose
(
-
2
,
-
1
))
/
(
dim
**
0.5
)
sampled_attn_weights
=
F
.
softmax
(
sampled_qk_scores
,
dim
=-
1
)
sampled_golden_hidden_states
=
torch
.
matmul
(
sampled_attn_weights
,
value
)
# (1, seq_len, dim)
sampled_mses
=
torch
.
zeros
(
len
(
self
.
attention_masks
),
cfg
,
num_heads
,
device
=
query
.
device
,
dtype
=
query
.
dtype
)
# Only have Tri-diagonal and Striped
for
mask_idx
,
attn_mask
in
enumerate
(
self
.
attention_masks
):
sampled_attention_mask
=
attn_mask
[
sampled_rows
,
:]
sampled_attention_scores
=
sampled_qk_scores
.
masked_fill
(
sampled_attention_mask
==
0
,
float
(
"-inf"
))
sampled_attn_weights
=
F
.
softmax
(
sampled_attention_scores
,
dim
=-
1
)
sampled_hidden_states
=
torch
.
matmul
(
sampled_attn_weights
,
value
)
mse
=
torch
.
mean
((
sampled_hidden_states
-
sampled_golden_hidden_states
)
**
2
,
dim
=
(
2
,
3
))
sampled_mses
[
mask_idx
]
=
mse
return
sampled_mses
if
__name__
==
"__main__"
:
q
,
k
,
v
=
torch
.
randn
(
32130
,
40
,
128
,
dtype
=
torch
.
bfloat16
).
cuda
(),
torch
.
randn
(
32130
,
40
,
128
,
dtype
=
torch
.
bfloat16
).
cuda
(),
torch
.
randn
(
32130
,
40
,
128
,
dtype
=
torch
.
bfloat16
).
cuda
()
SvgAttnWeight
.
prepare
(
head_num
=
40
,
head_dim
=
128
,
sample_mse_max_row
=
10000
,
num_sampled_rows
=
64
,
context_length
=
0
,
sparsity
=
0.25
)
svg_attn
=
SvgAttnWeight
()
print
(
"SvgAttnWeight initialized."
)
out
=
svg_attn
.
apply
(
q
,
k
,
v
)
print
(
f
"out:
{
out
.
shape
}
,
{
out
.
dtype
}
,
{
out
.
device
}
"
)
lightx2v/common/ops/attn/template.py
0 → 100644
View file @
a1ebc651
from abc import ABCMeta, abstractmethod


class AttnWeightTemplate(metaclass=ABCMeta):
    def __init__(self, weight_name):
        self.weight_name = weight_name
        self.config = {}

    def load(self, weight_dict):
        pass

    @abstractmethod
    def apply(self, input_tensor):
        pass

    def set_config(self, config=None):
        if config is not None:
            self.config = config

    def to_cpu(self, non_blocking=False):
        pass

    def to_cuda(self, non_blocking=False):
        pass

    def state_dict(self, destination=None):
        if destination is None:
            destination = {}
        return destination

    def load_state_dict(self, destination, block_index, adapter_block_inde=None):
        return {}

    def load_state_dict_from_disk(self, block_index, adapter_block_inde=None):
        pass
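

# --- Editor's sketch (not part of the original file): AttnWeightTemplate is an ABC
# whose only required override is apply(); the load/offload/state_dict hooks default
# to no-ops so stateless attention backends stay small. A hypothetical minimal subclass:
class _IdentityAttnExample(AttnWeightTemplate):
    def __init__(self):
        self.config = {}

    def apply(self, input_tensor):
        # trivially satisfies the abstract interface; real backends compute attention here
        return input_tensor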
lightx2v/common/ops/attn/torch_sdpa.py
0 → 100644
View file @
a1ebc651
import torch
import torch.nn.functional as F

from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER

from .template import AttnWeightTemplate


@ATTN_WEIGHT_REGISTER("torch_sdpa")
class TorchSDPAWeight(AttnWeightTemplate):
    def __init__(self):
        self.config = {}

    def apply(
        self,
        q,
        k,
        v,
        drop_rate=0,
        attn_mask=None,
        causal=False,
        cu_seqlens_q=None,
        cu_seqlens_kv=None,
        max_seqlen_q=None,
        max_seqlen_kv=None,
        model_cls=None,
    ):
        if q.ndim == 3:
            q, k, v = q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        if attn_mask is not None and attn_mask.dtype != torch.bool:
            attn_mask = attn_mask.to(q.dtype)
        x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal)
        x = x.transpose(1, 2)
        b, s, a, d = x.shape
        out = x.reshape(b, s, -1)
        return out.squeeze(0)
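

# --- Editor's sketch (not part of the original file): expected calling convention for
# this backend, inferred from the reshapes above. q/k/v arrive as [seq_len, heads, head_dim]
# and the result comes back as [seq_len, heads * head_dim]. The sizes below are placeholders
# and assume a CUDA device is available.
def _torch_sdpa_usage_sketch():
    attn = TorchSDPAWeight()
    q = torch.randn(4096, 40, 128, dtype=torch.bfloat16, device="cuda")
    k = torch.randn(4096, 40, 128, dtype=torch.bfloat16, device="cuda")
    v = torch.randn(4096, 40, 128, dtype=torch.bfloat16, device="cuda")
    out = attn.apply(q, k, v)
    return out.shape  # expected: torch.Size([4096, 5120])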
lightx2v/common/ops/attn/ulysses_attn.py
0 → 100644
View file @
a1ebc651
import
torch
import
torch.distributed
as
dist
from
loguru
import
logger
from
lightx2v.utils.quant_utils
import
dequant_fp8_vllm
,
quant_fp8_vllm
from
lightx2v.utils.registry_factory
import
ATTN_WEIGHT_REGISTER
from
lightx2v_platform.base.global_var
import
AI_DEVICE
from
.template
import
AttnWeightTemplate
from
.utils.all2all
import
all2all_head2seq
,
all2all_seq2head
@
ATTN_WEIGHT_REGISTER
(
"ulysses"
)
class
UlyssesAttnWeight
(
AttnWeightTemplate
):
def
__init__
(
self
):
self
.
config
=
{}
def
apply
(
self
,
q
,
k
,
v
,
img_qkv_len
,
cu_seqlens_qkv
,
attention_module
=
None
,
seq_p_group
=
None
,
model_cls
=
None
,
use_fp8_comm
=
False
):
        """
        Run Ulysses attention, combining the image and text queries, keys and values.

        Args:
            q (torch.Tensor): query tensor of shape [shard_seqlen, heads, hidden_dims]
            k (torch.Tensor): key tensor of shape [shard_seqlen, heads, hidden_dims]
            v (torch.Tensor): value tensor of shape [shard_seqlen, heads, hidden_dims]
            img_qkv_len (int): length of the image part of the query/key/value sequence
            cu_seqlens_qkv (torch.Tensor): cumulative sequence lengths covering the text and image segments
            attention_type (str): attention backend, defaults to "flash_attn2"

        Returns:
            torch.Tensor: the computed attention result
        """
if
len
(
q
.
shape
)
==
4
:
q
=
q
.
reshape
(
-
1
,
q
.
shape
[
-
2
],
q
.
shape
[
-
1
])
k
=
k
.
reshape
(
-
1
,
k
.
shape
[
-
2
],
k
.
shape
[
-
1
])
v
=
v
.
reshape
(
-
1
,
v
.
shape
[
-
2
],
v
.
shape
[
-
1
])
# 获取当前进程的排名和全局进程数
world_size
=
dist
.
get_world_size
(
seq_p_group
)
cur_rank
=
dist
.
get_rank
(
seq_p_group
)
# 获取序列长度和文本相关的长度
seq_len
=
q
.
shape
[
0
]
if
len
(
cu_seqlens_qkv
)
==
3
:
txt_qkv_len
=
cu_seqlens_qkv
[
1
]
-
img_qkv_len
# 文本查询、键和值的长度
txt_mask_len
=
cu_seqlens_qkv
[
2
]
-
img_qkv_len
# 文本掩码长度
elif
len
(
cu_seqlens_qkv
)
==
2
:
txt_qkv_len
=
cu_seqlens_qkv
[
1
]
-
img_qkv_len
# 文本查询、键和值的长度
txt_mask_len
=
None
# 获取查询张量的头数和隐藏维度
_
,
heads
,
hidden_dims
=
q
.
shape
shard_heads
=
heads
//
world_size
# 每个进程处理的头数
shard_seqlen
=
img_qkv_len
# 每个进程处理的序列长度
# 分割图像和文本的查询、键和值
img_q
,
img_k
,
img_v
=
q
[:
img_qkv_len
,
:,
:].
contiguous
(),
k
[:
img_qkv_len
,
:,
:].
contiguous
(),
v
[:
img_qkv_len
,
:,
:].
contiguous
()
txt_q
,
txt_k
,
txt_v
=
q
[
img_qkv_len
:,
:,
:].
contiguous
(),
k
[
img_qkv_len
:,
:,
:].
contiguous
(),
v
[
img_qkv_len
:,
:,
:].
contiguous
()
# 将图像的查询、键和值转换为头的格式
if
use_fp8_comm
:
original_dtype
=
img_q
.
dtype
original_shape
=
img_q
.
shape
img_q_fp8
,
q_scale
=
quant_fp8_vllm
(
img_q
.
reshape
(
-
1
,
original_shape
[
-
1
]))
img_k_fp8
,
k_scale
=
quant_fp8_vllm
(
img_k
.
reshape
(
-
1
,
original_shape
[
-
1
]))
img_v_fp8
,
v_scale
=
quant_fp8_vllm
(
img_v
.
reshape
(
-
1
,
original_shape
[
-
1
]))
img_q_fp8
=
all2all_seq2head
(
img_q_fp8
.
reshape
(
original_shape
),
group
=
seq_p_group
)
img_k_fp8
=
all2all_seq2head
(
img_k_fp8
.
reshape
(
original_shape
),
group
=
seq_p_group
)
img_v_fp8
=
all2all_seq2head
(
img_v_fp8
.
reshape
(
original_shape
),
group
=
seq_p_group
)
q_scale
=
all2all_seq2head
(
q_scale
.
reshape
(
original_shape
[
0
],
original_shape
[
1
],
1
),
group
=
seq_p_group
)
k_scale
=
all2all_seq2head
(
k_scale
.
reshape
(
original_shape
[
0
],
original_shape
[
1
],
1
),
group
=
seq_p_group
)
v_scale
=
all2all_seq2head
(
v_scale
.
reshape
(
original_shape
[
0
],
original_shape
[
1
],
1
),
group
=
seq_p_group
)
img_q
=
dequant_fp8_vllm
(
img_q_fp8
,
q_scale
,
original_dtype
)
img_k
=
dequant_fp8_vllm
(
img_k_fp8
,
k_scale
,
original_dtype
)
img_v
=
dequant_fp8_vllm
(
img_v_fp8
,
v_scale
,
original_dtype
)
else
:
img_q
=
all2all_seq2head
(
img_q
,
group
=
seq_p_group
)
img_k
=
all2all_seq2head
(
img_k
,
group
=
seq_p_group
)
img_v
=
all2all_seq2head
(
img_v
,
group
=
seq_p_group
)
# 处理文本的查询、键和值,选择当前进程的头
txt_q
=
txt_q
[:,
cur_rank
*
shard_heads
:
(
cur_rank
+
1
)
*
shard_heads
,
:]
txt_k
=
txt_k
[:,
cur_rank
*
shard_heads
:
(
cur_rank
+
1
)
*
shard_heads
,
:]
txt_v
=
txt_v
[:,
cur_rank
*
shard_heads
:
(
cur_rank
+
1
)
*
shard_heads
,
:]
# 合并图像和文本的查询、键和值
q
=
torch
.
cat
((
img_q
,
txt_q
),
dim
=
0
)
k
=
torch
.
cat
((
img_k
,
txt_k
),
dim
=
0
)
v
=
torch
.
cat
((
img_v
,
txt_v
),
dim
=
0
)
# 初始化累积序列长度张量
cu_seqlens_qkv
=
torch
.
zeros
([
2
],
dtype
=
torch
.
int32
,
device
=
AI_DEVICE
)
s
=
txt_qkv_len
+
img_q
.
shape
[
0
]
# 计算文本和图像的总长度
s1
=
s
# 当前样本的结束位置
cu_seqlens_qkv
[
1
]
=
s1
# 设置累积序列长度
if
txt_mask_len
:
s2
=
txt_mask_len
+
img_q
.
shape
[
0
]
# 文本掩码的结束位置
cu_seqlens_qkv
=
torch
.
cat
(
cu_seqlens_qkv
,
s2
)
max_seqlen_qkv
=
img_q
.
shape
[
0
]
+
txt_q
.
shape
[
0
]
# 最大序列长度
# 调用注意力函数计算注意力结果
# attn = attention(attention_type=attention_type, q=q, k=k, v=v, cu_seqlens_q=cu_seqlens_qkv, cu_seqlens_kv=cu_seqlens_qkv, max_seqlen_q=max_seqlen_qkv, max_seqlen_kv=max_seqlen_qkv)
attn
=
attention_module
.
apply
(
q
=
q
,
k
=
k
,
v
=
v
,
cu_seqlens_q
=
cu_seqlens_qkv
,
cu_seqlens_kv
=
cu_seqlens_qkv
,
max_seqlen_q
=
max_seqlen_qkv
,
max_seqlen_kv
=
max_seqlen_qkv
,
model_cls
=
model_cls
)
# 分割图像和文本的注意力结果
img_attn
,
txt_attn
=
attn
[:
img_q
.
shape
[
0
],
:],
attn
[
img_q
.
shape
[
0
]
:,]
# 收集所有进程的文本注意力结果
gathered_txt_attn
=
[
torch
.
empty_like
(
txt_attn
)
for
_
in
range
(
world_size
)]
dist
.
all_gather
(
gathered_txt_attn
,
txt_attn
,
group
=
seq_p_group
)
img_attn
=
self
.
_reshape_img_attn
(
img_attn
,
world_size
,
shard_seqlen
,
shard_heads
,
hidden_dims
,
seq_p_group
,
use_fp8_comm
)
txt_attn
=
torch
.
cat
(
gathered_txt_attn
,
dim
=
1
)
# 合并所有进程的文本注意力结果
# 合并图像和文本的注意力结果
attn
=
torch
.
cat
([
img_attn
,
txt_attn
],
dim
=
0
)
return
attn
# 返回最终的注意力结果
@
torch
.
compiler
.
disable
def
_reshape_img_attn
(
self
,
img_attn
,
world_size
,
shard_seqlen
,
shard_heads
,
hidden_dims
,
seq_p_group
,
use_fp8_comm
):
img_attn
=
img_attn
.
reshape
(
world_size
*
shard_seqlen
,
shard_heads
,
hidden_dims
)
# 重塑图像注意力结果
# 将头的格式转换回序列格式
if
use_fp8_comm
:
original_dtype
=
img_attn
.
dtype
original_shape
=
img_attn
.
shape
img_attn_fp8
,
attn_scale
=
quant_fp8_vllm
(
img_attn
.
reshape
(
-
1
,
original_shape
[
-
1
]))
img_attn_fp8
=
all2all_head2seq
(
img_attn_fp8
.
reshape
(
original_shape
),
group
=
seq_p_group
)
attn_scale
=
all2all_head2seq
(
attn_scale
.
reshape
(
original_shape
[
0
],
original_shape
[
1
],
1
),
group
=
seq_p_group
)
img_attn
=
dequant_fp8_vllm
(
img_attn_fp8
,
attn_scale
,
original_dtype
)
else
:
img_attn
=
all2all_head2seq
(
img_attn
,
group
=
seq_p_group
)
img_attn
=
img_attn
.
reshape
(
shard_seqlen
,
-
1
)
# 重塑为 [shard_seqlen, -1] 形状
return
img_attn
@
ATTN_WEIGHT_REGISTER
(
"ulysses-4090"
)
class
Ulysses4090AttnWeight
(
AttnWeightTemplate
):
def
__init__
(
self
):
self
.
config
=
{}
self
.
rounds
=
[]
def
generate_round_robin_pairs
(
self
,
seq_p_group
=
None
):
        """
        Build the round-robin pairing schedule, making sure the first element of each pair
        is the smaller rank, so the communication order can be decided by a simple rule.
        """
cur_rank
=
dist
.
get_rank
(
seq_p_group
)
world_size
=
dist
.
get_world_size
(
seq_p_group
)
if
world_size
%
2
!=
0
:
raise
ValueError
(
"world_size必须是偶数,奇数情况需要特殊处理"
)
teams
=
list
(
range
(
world_size
))
for
_
in
range
(
world_size
-
1
):
round_schedule
=
{}
for
i
in
range
(
world_size
//
2
):
team1
,
team2
=
teams
[
i
],
teams
[
world_size
-
1
-
i
]
smaller
,
larger
=
min
(
team1
,
team2
),
max
(
team1
,
team2
)
round_schedule
[
smaller
]
=
(
larger
,
True
)
round_schedule
[
larger
]
=
(
smaller
,
False
)
self
.
rounds
.
append
(
round_schedule
)
# 旋转列表(固定第一个元素)
teams
=
[
teams
[
0
]]
+
[
teams
[
-
1
]]
+
teams
[
1
:
-
1
]
# if cur_rank == 0:
# self.print_pairing_schedule(seq_p_group)
def
print_pairing_schedule
(
self
,
seq_p_group
):
"""打印通信调度表"""
world_size
=
dist
.
get_world_size
(
seq_p_group
)
logger
.
info
(
"循环赛通信调度表:"
)
logger
.
info
(
"="
*
50
)
for
i
,
round_schedule
in
enumerate
(
self
.
rounds
):
logger
.
info
(
f
"第
{
i
+
1
}
轮:"
)
for
cur_rank
in
range
(
world_size
):
partner
,
is_smaller_in_pair
=
round_schedule
[
cur_rank
]
logger
.
info
(
f
" 进程
{
cur_rank
}
←→ 进程
{
partner
}
"
)
logger
.
info
(
"="
*
50
)
def
load_balanced_all_to_all
(
self
,
shards
,
seq_p_group
=
None
):
"""
负载均衡all-to-all通信实现
"""
world_size
=
dist
.
get_world_size
(
seq_p_group
)
cur_rank
=
dist
.
get_rank
(
seq_p_group
)
global_rank
=
dist
.
get_global_rank
(
seq_p_group
,
cur_rank
)
cfg_p_group_index
=
global_rank
//
world_size
# 准备接收缓冲区
gathered_shards
=
[
None
]
*
world_size
for
target_rank
in
range
(
world_size
):
if
target_rank
!=
cur_rank
:
gathered_shards
[
target_rank
]
=
torch
.
empty_like
(
shards
[
target_rank
])
else
:
gathered_shards
[
cur_rank
]
=
shards
[
cur_rank
]
for
i
,
round_schedule
in
enumerate
(
self
.
rounds
):
# 查找当前进程在本轮的配对
partner
=
None
is_smaller_in_pair
=
False
if
cur_rank
in
round_schedule
:
partner
,
is_smaller_in_pair
=
round_schedule
[
cur_rank
]
# 如果没有找到配对,说明本轮当前进程空闲
if
partner
is
None
:
continue
# 计算全局rank
partner_global_rank
=
cfg_p_group_index
*
world_size
+
partner
if
is_smaller_in_pair
:
# 当前进程是配对中的较小者,先发送后接收
send_req
=
dist
.
isend
(
shards
[
partner
],
dst
=
partner_global_rank
,
group
=
seq_p_group
)
recv_req
=
dist
.
irecv
(
gathered_shards
[
partner
],
src
=
partner_global_rank
,
group
=
seq_p_group
)
send_req
.
wait
()
recv_req
.
wait
()
else
:
# 当前进程是配对中的较大者,先接收后发送
recv_req
=
dist
.
irecv
(
gathered_shards
[
partner
],
src
=
partner_global_rank
,
group
=
seq_p_group
)
send_req
=
dist
.
isend
(
shards
[
partner
],
dst
=
partner_global_rank
,
group
=
seq_p_group
)
recv_req
.
wait
()
send_req
.
wait
()
return
gathered_shards
def
apply
(
self
,
q
,
k
,
v
,
img_qkv_len
,
cu_seqlens_qkv
,
attention_module
=
None
,
seq_p_group
=
None
,
model_cls
=
None
,
use_fp8_comm
=
False
):
"""
执行 Ulysses 注意力机制,结合图像和文本的查询、键和值。
参数:
q (torch.Tensor): 查询张量,形状为 [shard_seqlen, heads, hidden_dims]
k (torch.Tensor): 键张量,形状为 [shard_seqlen, heads, hidden_dims]
v (torch.Tensor): 值张量,形状为 [shard_seqlen, heads, hidden_dims]
img_qkv_len (int): 图像查询、键和值的长度
cu_seqlens_qkv (torch.Tensor): 累积序列长度,包含文本和图像的长度信息
attention_type (str): 注意力类型,默认为 "flash_attn2"
返回:
torch.Tensor: 计算得到的注意力结果
"""
if
len
(
self
.
rounds
)
==
0
:
self
.
generate_round_robin_pairs
(
seq_p_group
)
if
len
(
q
.
shape
)
==
4
:
q
=
q
.
reshape
(
-
1
,
q
.
shape
[
-
2
],
q
.
shape
[
-
1
])
k
=
k
.
reshape
(
-
1
,
k
.
shape
[
-
2
],
k
.
shape
[
-
1
])
v
=
v
.
reshape
(
-
1
,
v
.
shape
[
-
2
],
v
.
shape
[
-
1
])
# 获取当前进程的排名和全局进程数
world_size
=
dist
.
get_world_size
(
seq_p_group
)
cur_rank
=
dist
.
get_rank
(
seq_p_group
)
global_world_size
=
dist
.
get_world_size
()
global_rank
=
dist
.
get_global_rank
(
seq_p_group
,
cur_rank
)
cfg_p_group_index
=
global_rank
//
world_size
# 获取序列长度和文本相关的长度
seq_len
=
q
.
shape
[
0
]
if
len
(
cu_seqlens_qkv
)
==
3
:
txt_qkv_len
=
cu_seqlens_qkv
[
1
]
-
img_qkv_len
# 文本查询、键和值的长度
txt_mask_len
=
cu_seqlens_qkv
[
2
]
-
img_qkv_len
# 文本掩码长度
elif
len
(
cu_seqlens_qkv
)
==
2
:
txt_qkv_len
=
cu_seqlens_qkv
[
1
]
-
img_qkv_len
# 文本查询、键和值的长度
txt_mask_len
=
None
# 获取查询张量的头数和隐藏维度
_
,
heads
,
hidden_dims
=
q
.
shape
shard_heads
=
heads
//
world_size
# 每个进程处理的头数
shard_seqlen
=
img_qkv_len
# 每个进程处理的序列长度
# 分割图像和文本的查询、键和值
img_q
,
img_k
,
img_v
=
q
[:
img_qkv_len
,
:,
:].
contiguous
(),
k
[:
img_qkv_len
,
:,
:].
contiguous
(),
v
[:
img_qkv_len
,
:,
:].
contiguous
()
txt_q
,
txt_k
,
txt_v
=
q
[
img_qkv_len
:,
:,
:].
contiguous
(),
k
[
img_qkv_len
:,
:,
:].
contiguous
(),
v
[
img_qkv_len
:,
:,
:].
contiguous
()
# 计算每个进程应该持有的头数分片
num_heads
=
img_q
.
shape
[
1
]
shard_heads
=
num_heads
//
world_size
# 将 image QKV 拼接后,按头维度切分成 N 份,每份大小为 D/N
img_qkv
=
torch
.
stack
([
img_q
,
img_k
,
img_v
],
dim
=
0
)
qkv_shards
=
[
img_qkv
[:,
:,
i
*
shard_heads
:
(
i
+
1
)
*
shard_heads
,
:].
contiguous
()
for
i
in
range
(
world_size
)]
qkv_dtype
=
img_qkv
.
dtype
if
use_fp8_comm
:
qkv_fp8_byte_tensors
=
[]
qkv_fp8_bytes
=
0
qkv_fp8_dtype
=
None
qkv_scale_dtype
=
None
for
i
in
range
(
world_size
):
qkv_fp8
,
qkv_scale
=
quant_fp8_vllm
(
qkv_shards
[
i
].
reshape
(
-
1
,
hidden_dims
))
if
i
==
0
:
qkv_fp8_bytes
=
qkv_fp8
.
numel
()
*
qkv_fp8
.
element_size
()
qkv_fp8_dtype
=
qkv_fp8
.
dtype
qkv_scale_dtype
=
qkv_scale
.
dtype
qkv_fp8_byte_tensors
.
append
(
torch
.
cat
([
qkv_fp8
.
contiguous
().
reshape
(
-
1
).
view
(
torch
.
uint8
),
qkv_scale
.
contiguous
().
reshape
(
-
1
).
view
(
torch
.
uint8
)],
dim
=
0
))
gathered_qkv_fp8_byte_tensors
=
self
.
load_balanced_all_to_all
(
qkv_fp8_byte_tensors
,
seq_p_group
)
gathered_q_shards
=
[]
gathered_k_shards
=
[]
gathered_v_shards
=
[]
for
i
in
range
(
world_size
):
qkv_fp8_byte_tensor
=
gathered_qkv_fp8_byte_tensors
[
i
]
qkv_fp8
=
qkv_fp8_byte_tensor
[:
qkv_fp8_bytes
].
view
(
qkv_fp8_dtype
).
reshape
(
3
,
-
1
,
hidden_dims
)
qkv_scale
=
qkv_fp8_byte_tensor
[
qkv_fp8_bytes
:].
view
(
qkv_scale_dtype
).
reshape
(
3
,
-
1
,
1
)
q_shards_new
=
dequant_fp8_vllm
(
qkv_fp8
[
0
],
qkv_scale
[
0
],
qkv_dtype
).
reshape
(
-
1
,
shard_heads
,
hidden_dims
)
k_shards_new
=
dequant_fp8_vllm
(
qkv_fp8
[
1
],
qkv_scale
[
1
],
qkv_dtype
).
reshape
(
-
1
,
shard_heads
,
hidden_dims
)
v_shards_new
=
dequant_fp8_vllm
(
qkv_fp8
[
2
],
qkv_scale
[
2
],
qkv_dtype
).
reshape
(
-
1
,
shard_heads
,
hidden_dims
)
gathered_q_shards
.
append
(
q_shards_new
)
gathered_k_shards
.
append
(
k_shards_new
)
gathered_v_shards
.
append
(
v_shards_new
)
else
:
gathered_qkv_byte_tensors
=
self
.
load_balanced_all_to_all
(
qkv_shards
,
seq_p_group
)
gathered_q_shards
=
[]
gathered_k_shards
=
[]
gathered_v_shards
=
[]
for
i
in
range
(
world_size
):
qkv_tensor
=
gathered_qkv_byte_tensors
[
i
].
view
(
qkv_dtype
).
reshape
(
3
,
-
1
,
shard_heads
,
hidden_dims
)
gathered_q_shards
.
append
(
qkv_tensor
[
0
])
gathered_k_shards
.
append
(
qkv_tensor
[
1
])
gathered_v_shards
.
append
(
qkv_tensor
[
2
])
# 拼接所有分片 (在序列维度上)
# 每个 gathered_*_shards[i] 的形状是 (seq_len/N, num_heads/N, head_dim)
# 拼接后形状是 (seq_len, num_heads/N, head_dim)
img_q
=
torch
.
cat
(
gathered_q_shards
,
dim
=
0
)
img_k
=
torch
.
cat
(
gathered_k_shards
,
dim
=
0
)
img_v
=
torch
.
cat
(
gathered_v_shards
,
dim
=
0
)
# 处理文本的查询、键和值,选择当前进程的头
txt_q
=
txt_q
[:,
cur_rank
*
shard_heads
:
(
cur_rank
+
1
)
*
shard_heads
,
:]
txt_k
=
txt_k
[:,
cur_rank
*
shard_heads
:
(
cur_rank
+
1
)
*
shard_heads
,
:]
txt_v
=
txt_v
[:,
cur_rank
*
shard_heads
:
(
cur_rank
+
1
)
*
shard_heads
,
:]
# 合并图像和文本的查询、键和值
q
=
torch
.
cat
((
img_q
,
txt_q
),
dim
=
0
)
k
=
torch
.
cat
((
img_k
,
txt_k
),
dim
=
0
)
v
=
torch
.
cat
((
img_v
,
txt_v
),
dim
=
0
)
# 初始化累积序列长度张量
cu_seqlens_qkv
=
torch
.
zeros
([
2
],
dtype
=
torch
.
int32
,
device
=
"cuda"
)
s
=
txt_qkv_len
+
img_q
.
shape
[
0
]
# 计算文本和图像的总长度
s1
=
s
# 当前样本的结束位置
cu_seqlens_qkv
[
1
]
=
s1
# 设置累积序列长度
if
txt_mask_len
:
s2
=
txt_mask_len
+
img_q
.
shape
[
0
]
# 文本掩码的结束位置
cu_seqlens_qkv
=
torch
.
cat
(
cu_seqlens_qkv
,
s2
)
max_seqlen_qkv
=
img_q
.
shape
[
0
]
+
txt_q
.
shape
[
0
]
# 最大序列长度
# 调用注意力函数计算注意力结果
# attn = attention(attention_type=attention_type, q=q, k=k, v=v, cu_seqlens_q=cu_seqlens_qkv, cu_seqlens_kv=cu_seqlens_qkv, max_seqlen_q=max_seqlen_qkv, max_seqlen_kv=max_seqlen_qkv)
attn
=
attention_module
.
apply
(
q
=
q
,
k
=
k
,
v
=
v
,
cu_seqlens_q
=
cu_seqlens_qkv
,
cu_seqlens_kv
=
cu_seqlens_qkv
,
max_seqlen_q
=
max_seqlen_qkv
,
max_seqlen_kv
=
max_seqlen_qkv
,
model_cls
=
model_cls
)
# 分割图像和文本的注意力结果
img_attn
,
txt_attn
=
attn
[:
img_q
.
shape
[
0
],
:],
attn
[
img_q
.
shape
[
0
]
:,]
# 收集所有进程的文本注意力结果
gathered_txt_attn
=
[
torch
.
empty_like
(
txt_attn
)
for
_
in
range
(
world_size
)]
dist
.
all_gather
(
gathered_txt_attn
,
txt_attn
,
group
=
seq_p_group
)
img_attn
=
self
.
_reshape_img_attn
(
img_attn
,
world_size
,
shard_seqlen
,
shard_heads
,
hidden_dims
,
seq_p_group
,
use_fp8_comm
)
txt_attn
=
torch
.
cat
(
gathered_txt_attn
,
dim
=
1
)
# 合并所有进程的文本注意力结果
# 合并图像和文本的注意力结果
attn
=
torch
.
cat
([
img_attn
,
txt_attn
],
dim
=
0
)
return
attn
# 返回最终的注意力结果
@
torch
.
compiler
.
disable
def
_reshape_img_attn
(
self
,
img_attn
,
world_size
,
shard_seqlen
,
shard_heads
,
hidden_dims
,
seq_p_group
,
use_fp8_comm
):
cur_rank
=
dist
.
get_rank
(
seq_p_group
)
global_world_size
=
dist
.
get_world_size
()
global_rank
=
dist
.
get_global_rank
(
seq_p_group
,
cur_rank
)
cfg_p_group_index
=
global_rank
//
world_size
img_attn
=
img_attn
.
reshape
(
world_size
*
shard_seqlen
,
shard_heads
,
hidden_dims
)
# 重塑图像注意力结果
attn_dtype
=
img_attn
.
dtype
# 按序列维度切分成 N 份
attn_shards
=
[
img_attn
[
i
*
shard_seqlen
:
(
i
+
1
)
*
shard_seqlen
,
:,
:].
contiguous
()
for
i
in
range
(
world_size
)]
if
use_fp8_comm
:
attn_fp8_byte_tensors
=
[]
attn_fp8_bytes
=
0
attn_fp8_dtype
=
None
attn_scale_dtype
=
None
for
i
in
range
(
world_size
):
attn_fp8
,
attn_scale
=
quant_fp8_vllm
(
attn_shards
[
i
].
reshape
(
-
1
,
hidden_dims
))
if
i
==
0
:
attn_fp8_bytes
=
attn_fp8
.
numel
()
*
attn_fp8
.
element_size
()
attn_fp8_dtype
=
attn_fp8
.
dtype
attn_scale_dtype
=
attn_scale
.
dtype
attn_fp8_byte_tensors
.
append
(
torch
.
cat
([
attn_fp8
.
contiguous
().
reshape
(
-
1
).
view
(
torch
.
uint8
),
attn_scale
.
contiguous
().
reshape
(
-
1
).
view
(
torch
.
uint8
)],
dim
=
0
))
gathered_attn_fp8_byte_tensors
=
self
.
load_balanced_all_to_all
(
attn_fp8_byte_tensors
,
seq_p_group
)
gathered_attn_shards
=
[]
for
i
in
range
(
world_size
):
attn_fp8_byte_tensor
=
gathered_attn_fp8_byte_tensors
[
i
]
attn_fp8
=
attn_fp8_byte_tensor
[:
attn_fp8_bytes
].
view
(
attn_fp8_dtype
).
reshape
(
-
1
,
hidden_dims
)
attn_scale
=
attn_fp8_byte_tensor
[
attn_fp8_bytes
:].
view
(
attn_scale_dtype
).
reshape
(
-
1
,
1
)
attn_shards_new
=
dequant_fp8_vllm
(
attn_fp8
,
attn_scale
,
attn_dtype
).
reshape
(
-
1
,
shard_heads
,
hidden_dims
)
gathered_attn_shards
.
append
(
attn_shards_new
)
else
:
gathered_attn_shards
=
self
.
load_balanced_all_to_all
(
attn_shards
,
seq_p_group
)
# 拼接所有分片 (在头维度上)
img_attn
=
torch
.
cat
(
gathered_attn_shards
,
dim
=
1
)
img_attn
=
img_attn
.
reshape
(
shard_seqlen
,
-
1
)
# 重塑为 [shard_seqlen, -1] 形状
return
img_attn
lightx2v/common/ops/attn/utils/all2all.py
0 → 100644
View file @
a1ebc651
import
torch
import
torch._dynamo
as
dynamo
import
torch.distributed
as
dist
@
dynamo
.
disable
def
all2all_seq2head
(
input
,
group
=
None
):
"""
将输入张量从 [seq_len/N, heads, hidden_dims] 转换为 [seq_len, heads/N, hidden_dims] 的格式。
参数:
input (torch.Tensor): 输入张量,形状为 [seq_len/N, heads, hidden_dims]
返回:
torch.Tensor: 转换后的输出张量,形状为 [seq_len, heads/N, hidden_dims]
"""
# 确保输入是一个3D张量
assert
input
.
dim
()
==
3
,
f
"input must be 3D tensor"
# 获取当前进程的世界大小
world_size
=
dist
.
get_world_size
(
group
=
group
)
# 获取输入张量的形状
shard_seq_len
,
heads
,
hidden_dims
=
input
.
shape
seq_len
=
shard_seq_len
*
world_size
# 计算总序列长度
shard_heads
=
heads
//
world_size
# 计算每个进程处理的头数
# 重塑输入张量以便进行 all-to-all 操作
input_t
=
(
input
.
reshape
(
shard_seq_len
,
world_size
,
shard_heads
,
hidden_dims
)
# 重塑为 [shard_seq_len, world_size, shard_heads, hidden_dims]
.
transpose
(
0
,
1
)
# 转置以便进行 all-to-all 操作
.
contiguous
()
# 确保内存连续
)
# 创建一个与输入张量相同形状的输出张量
output
=
torch
.
empty_like
(
input_t
)
# 执行 all-to-all 操作,将输入张量的内容分发到所有进程
dist
.
all_to_all_single
(
output
,
input_t
,
group
=
group
)
# 重塑输出张量为 [seq_len, heads/N, hidden_dims] 形状
output
=
output
.
reshape
(
seq_len
,
shard_heads
,
hidden_dims
).
contiguous
()
return
output
# 返回转换后的输出张量
@
dynamo
.
disable
def
all2all_head2seq
(
input
,
group
=
None
):
"""
将输入张量从 [seq_len, heads/N, hidden_dims] 转换为 [seq_len/N, heads, hidden_dims] 的格式。
参数:
input (torch.Tensor): 输入张量,形状为 [seq_len, heads/N, hidden_dims]
返回:
torch.Tensor: 转换后的输出张量,形状为 [seq_len/N, heads, hidden_dims]
"""
# 确保输入是一个3D张量
assert
input
.
dim
()
==
3
,
f
"input must be 3D tensor"
# 获取当前进程的世界大小
world_size
=
dist
.
get_world_size
(
group
=
group
)
# 获取输入张量的形状
seq_len
,
shard_heads
,
hidden_dims
=
input
.
shape
heads
=
shard_heads
*
world_size
# 计算总头数
shard_seq_len
=
seq_len
//
world_size
# 计算每个进程处理的序列长度
# 重塑输入张量以便进行 all-to-all 操作
input_t
=
(
input
.
reshape
(
world_size
,
shard_seq_len
,
shard_heads
,
hidden_dims
)
# 重塑为 [world_size, shard_seq_len, shard_heads, hidden_dims]
.
transpose
(
1
,
2
)
# 转置以便进行 all-to-all 操作
.
contiguous
()
# 确保内存连续
.
reshape
(
world_size
,
shard_heads
,
shard_seq_len
,
hidden_dims
)
# 再次重塑为 [world_size, shard_heads, shard_seq_len, hidden_dims]
)
# 创建一个与输入张量相同形状的输出张量
output
=
torch
.
empty_like
(
input_t
)
# 执行 all-to-all 操作,将输入张量的内容分发到所有进程
dist
.
all_to_all_single
(
output
,
input_t
,
group
=
group
)
# 重塑输出张量为 [heads, shard_seq_len, hidden_dims] 形状
output
=
output
.
reshape
(
heads
,
shard_seq_len
,
hidden_dims
)
# 转置输出张量并重塑为 [shard_seq_len, heads, hidden_dims] 形状
output
=
output
.
transpose
(
0
,
1
).
contiguous
().
reshape
(
shard_seq_len
,
heads
,
hidden_dims
)
return
output
# 返回转换后的输出张量
lightx2v/common/ops/attn/utils/ring_comm.py
0 → 100644
View file @
a1ebc651
from
typing
import
Optional
import
torch
import
torch.distributed
as
dist
class
RingComm
:
def
__init__
(
self
,
process_group
:
dist
.
ProcessGroup
=
None
):
self
.
_process_group
=
process_group
self
.
_ops
=
[]
self
.
rank
=
dist
.
get_rank
(
self
.
_process_group
)
self
.
world_size
=
dist
.
get_world_size
(
self
.
_process_group
)
self
.
_reqs
=
None
self
.
send_rank
=
(
self
.
rank
+
1
)
%
self
.
world_size
self
.
recv_rank
=
(
self
.
rank
-
1
)
%
self
.
world_size
if
process_group
is
not
None
:
self
.
send_rank
=
dist
.
get_global_rank
(
self
.
_process_group
,
self
.
send_rank
)
self
.
recv_rank
=
dist
.
get_global_rank
(
self
.
_process_group
,
self
.
recv_rank
)
def
send_recv
(
self
,
to_send
:
torch
.
Tensor
,
recv_tensor
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
if
recv_tensor
is
None
:
res
=
torch
.
empty_like
(
to_send
)
# logger.info(f"send_recv: empty_like {to_send.shape}")
else
:
res
=
recv_tensor
send_op
=
dist
.
P2POp
(
dist
.
isend
,
to_send
,
self
.
send_rank
,
group
=
self
.
_process_group
)
recv_op
=
dist
.
P2POp
(
dist
.
irecv
,
res
,
self
.
recv_rank
,
group
=
self
.
_process_group
)
self
.
_ops
.
append
(
send_op
)
self
.
_ops
.
append
(
recv_op
)
return
res
def
commit
(
self
):
if
self
.
_reqs
is
not
None
:
raise
RuntimeError
(
"commit called twice"
)
self
.
_reqs
=
dist
.
batch_isend_irecv
(
self
.
_ops
)
def
wait
(
self
):
if
self
.
_reqs
is
None
:
raise
RuntimeError
(
"wait called before commit"
)
for
req
in
self
.
_reqs
:
req
.
wait
()
self
.
_reqs
=
None
self
.
_ops
=
[]
lightx2v/common/ops/conv/__init__.py
0 → 100644
View file @
a1ebc651
from
.conv2d
import
*
from
.conv3d
import
*
lightx2v/common/ops/conv/conv2d.py
0 → 100644
View file @
a1ebc651
from
abc
import
ABCMeta
,
abstractmethod
import
torch
from
lightx2v.utils.registry_factory
import
CONV2D_WEIGHT_REGISTER
from
lightx2v_platform.base.global_var
import
AI_DEVICE
class
Conv2dWeightTemplate
(
metaclass
=
ABCMeta
):
def
__init__
(
self
,
weight_name
,
bias_name
,
stride
,
padding
,
dilation
,
groups
):
self
.
weight_name
=
weight_name
self
.
bias_name
=
bias_name
self
.
stride
=
stride
self
.
padding
=
padding
self
.
dilation
=
dilation
self
.
groups
=
groups
self
.
config
=
{}
@
abstractmethod
def
load
(
self
,
weight_dict
):
pass
@
abstractmethod
def
apply
(
self
,
input_tensor
):
pass
def
set_config
(
self
,
config
=
None
):
if
config
is
not
None
:
self
.
config
=
config
@
CONV2D_WEIGHT_REGISTER
(
"Default"
)
class
Conv2dWeight
(
Conv2dWeightTemplate
):
def
__init__
(
self
,
weight_name
,
bias_name
,
stride
=
1
,
padding
=
0
,
dilation
=
1
,
groups
=
1
):
super
().
__init__
(
weight_name
,
bias_name
,
stride
,
padding
,
dilation
,
groups
)
def
load
(
self
,
weight_dict
):
self
.
weight
=
weight_dict
[
self
.
weight_name
].
to
(
AI_DEVICE
)
self
.
bias
=
weight_dict
[
self
.
bias_name
].
to
(
AI_DEVICE
)
if
self
.
bias_name
is
not
None
else
None
def
apply
(
self
,
input_tensor
):
input_tensor
=
torch
.
nn
.
functional
.
conv2d
(
input_tensor
,
weight
=
self
.
weight
,
bias
=
self
.
bias
,
stride
=
self
.
stride
,
padding
=
self
.
padding
,
dilation
=
self
.
dilation
,
groups
=
self
.
groups
)
return
input_tensor
def
to_cpu
(
self
,
non_blocking
=
False
):
self
.
weight
=
self
.
weight
.
cpu
(
non_blocking
=
non_blocking
)
if
self
.
bias
is
not
None
:
self
.
bias
=
self
.
bias
.
cpu
(
non_blocking
=
non_blocking
)
def
to_cuda
(
self
,
non_blocking
=
False
):
self
.
weight
=
self
.
weight
.
to
(
AI_DEVICE
,
non_blocking
=
non_blocking
)
if
self
.
bias
is
not
None
:
self
.
bias
=
self
.
bias
.
to
(
AI_DEVICE
,
non_blocking
=
non_blocking
)
def
state_dict
(
self
,
destination
=
None
):
if
destination
is
None
:
destination
=
{}
destination
[
self
.
weight_name
]
=
self
.
weight
.
cpu
().
detach
().
clone
()
if
self
.
bias
is
not
None
:
destination
[
self
.
bias_name
]
=
self
.
bias
.
cpu
().
detach
().
clone
()
return
destination
lightx2v/common/ops/conv/conv3d.py
0 → 100644
View file @
a1ebc651
from
abc
import
ABCMeta
,
abstractmethod
import
torch
from
lightx2v.utils.registry_factory
import
CONV3D_WEIGHT_REGISTER
from
lightx2v_platform.base.global_var
import
AI_DEVICE
class
Conv3dWeightTemplate
(
metaclass
=
ABCMeta
):
def
__init__
(
self
,
weight_name
,
bias_name
,
stride
=
1
,
padding
=
0
,
dilation
=
1
,
groups
=
1
):
self
.
weight_name
=
weight_name
self
.
bias_name
=
bias_name
self
.
stride
=
stride
self
.
padding
=
padding
self
.
dilation
=
dilation
self
.
groups
=
groups
self
.
config
=
{}
@
abstractmethod
def
load
(
self
,
weight_dict
):
pass
@
abstractmethod
def
apply
(
self
,
input_tensor
):
pass
def
set_config
(
self
,
config
=
None
):
if
config
is
not
None
:
self
.
config
=
config
@
CONV3D_WEIGHT_REGISTER
(
"Default"
)
class
Conv3dWeight
(
Conv3dWeightTemplate
):
def
__init__
(
self
,
weight_name
,
bias_name
,
stride
=
1
,
padding
=
0
,
dilation
=
1
,
groups
=
1
):
super
().
__init__
(
weight_name
,
bias_name
,
stride
,
padding
,
dilation
,
groups
)
def
load
(
self
,
weight_dict
):
device
=
weight_dict
[
self
.
weight_name
].
device
if
device
.
type
==
"cpu"
:
weight_shape
=
weight_dict
[
self
.
weight_name
].
shape
weight_dtype
=
weight_dict
[
self
.
weight_name
].
dtype
self
.
pin_weight
=
torch
.
empty
(
weight_shape
,
pin_memory
=
True
,
dtype
=
weight_dtype
)
self
.
pin_weight
.
copy_
(
weight_dict
[
self
.
weight_name
])
if
self
.
bias_name
is
not
None
:
bias_shape
=
weight_dict
[
self
.
bias_name
].
shape
bias_dtype
=
weight_dict
[
self
.
bias_name
].
dtype
self
.
pin_bias
=
torch
.
empty
(
bias_shape
,
pin_memory
=
True
,
dtype
=
bias_dtype
)
self
.
pin_bias
.
copy_
(
weight_dict
[
self
.
bias_name
])
else
:
self
.
bias
=
None
self
.
pin_bias
=
None
del
weight_dict
[
self
.
weight_name
]
else
:
self
.
weight
=
weight_dict
[
self
.
weight_name
]
if
self
.
bias_name
is
not
None
:
self
.
bias
=
weight_dict
[
self
.
bias_name
]
else
:
self
.
bias
=
None
def
apply
(
self
,
input_tensor
):
input_tensor
=
torch
.
nn
.
functional
.
conv3d
(
input_tensor
,
weight
=
self
.
weight
,
bias
=
self
.
bias
,
stride
=
self
.
stride
,
padding
=
self
.
padding
,
dilation
=
self
.
dilation
,
groups
=
self
.
groups
,
)
return
input_tensor
def
to_cuda
(
self
,
non_blocking
=
False
):
self
.
weight
=
self
.
pin_weight
.
to
(
AI_DEVICE
,
non_blocking
=
non_blocking
)
if
hasattr
(
self
,
"pin_bias"
)
and
self
.
pin_bias
is
not
None
:
self
.
bias
=
self
.
pin_bias
.
to
(
AI_DEVICE
,
non_blocking
=
non_blocking
)
def
to_cpu
(
self
,
non_blocking
=
False
):
if
hasattr
(
self
,
"pin_weight"
):
self
.
weight
=
self
.
pin_weight
.
copy_
(
self
.
weight
,
non_blocking
=
non_blocking
).
cpu
()
if
self
.
bias
is
not
None
:
self
.
bias
=
self
.
pin_bias
.
copy_
(
self
.
bias
,
non_blocking
=
non_blocking
).
cpu
()
else
:
self
.
weight
=
self
.
weight
.
to
(
"cpu"
,
non_blocking
=
non_blocking
)
if
hasattr
(
self
,
"bias"
)
and
self
.
bias
is
not
None
:
self
.
bias
=
self
.
bias
.
to
(
"cpu"
,
non_blocking
=
non_blocking
)
def
state_dict
(
self
,
destination
=
None
):
if
destination
is
None
:
destination
=
{}
destination
[
self
.
weight_name
]
=
self
.
pin_weight
if
hasattr
(
self
,
"pin_weight"
)
else
self
.
weight
# .cpu().detach().clone().contiguous()
if
self
.
bias_name
is
not
None
:
destination
[
self
.
bias_name
]
=
self
.
pin_bias
if
hasattr
(
self
,
"pin_bias"
)
else
self
.
bias
# .cpu().detach().clone()
return
destination
lightx2v/common/ops/embedding/__init__.py
0 → 100644
View file @
a1ebc651
from
.embedding_weight
import
*
lightx2v/common/ops/embedding/embedding_weight.py
0 → 100644
View file @
a1ebc651
import
re
from
abc
import
ABCMeta
import
torch
import
torch.nn.functional
as
F
from
lightx2v.utils.registry_factory
import
EMBEDDING_WEIGHT_REGISTER
from
lightx2v_platform.base.global_var
import
AI_DEVICE
class
EmbeddingWeightTemplate
(
metaclass
=
ABCMeta
):
def
__init__
(
self
,
weight_name
,
create_cuda_buffer
=
False
,
create_cpu_buffer
=
False
,
lazy_load
=
False
,
lazy_load_file
=
None
,
is_post_adapter
=
False
):
self
.
weight_name
=
weight_name
self
.
create_cuda_buffer
=
create_cuda_buffer
self
.
create_cpu_buffer
=
create_cpu_buffer
self
.
lazy_load
=
lazy_load
self
.
lazy_load_file
=
lazy_load_file
self
.
is_post_adapter
=
is_post_adapter
self
.
config
=
{}
def
load
(
self
,
weight_dict
):
if
not
self
.
lazy_load
:
if
self
.
create_cuda_buffer
:
self
.
weight_cuda_buffer
=
weight_dict
[
self
.
weight_name
].
to
(
AI_DEVICE
)
else
:
device
=
weight_dict
[
self
.
weight_name
].
device
if
device
.
type
==
"cpu"
:
weight_shape
=
weight_dict
[
self
.
weight_name
].
shape
weight_dtype
=
weight_dict
[
self
.
weight_name
].
dtype
self
.
pin_weight
=
torch
.
empty
(
weight_shape
,
pin_memory
=
True
,
dtype
=
weight_dtype
)
self
.
pin_weight
.
copy_
(
weight_dict
[
self
.
weight_name
])
del
weight_dict
[
self
.
weight_name
]
else
:
self
.
weight
=
weight_dict
[
self
.
weight_name
]
def
to_cuda
(
self
,
non_blocking
=
False
):
self
.
weight
=
self
.
pin_weight
.
to
(
AI_DEVICE
,
non_blocking
=
non_blocking
)
def
to_cpu
(
self
,
non_blocking
=
False
):
if
hasattr
(
self
,
"pin_weight"
):
self
.
weight
=
self
.
pin_weight
.
copy_
(
self
.
weight
,
non_blocking
=
non_blocking
).
cpu
()
else
:
self
.
weight
=
self
.
weight
.
to
(
"cpu"
,
non_blocking
=
non_blocking
)
def
state_dict
(
self
,
destination
=
None
):
if
destination
is
None
:
destination
=
{}
destination
[
self
.
weight_name
]
=
self
.
pin_weight
if
hasattr
(
self
,
"pin_weight"
)
else
self
.
weight
return
destination
def
load_state_dict
(
self
,
destination
,
block_index
,
adapter_block_index
=
None
):
if
self
.
is_post_adapter
:
assert
adapter_block_index
is
not
None
weight_name
=
re
.
sub
(
r
"\.\d+"
,
lambda
m
:
f
".
{
adapter_block_index
}
"
,
self
.
weight_name
,
count
=
1
)
else
:
weight_name
=
re
.
sub
(
r
"\.\d+"
,
lambda
m
:
f
".
{
block_index
}
"
,
self
.
weight_name
,
count
=
1
)
if
weight_name
not
in
destination
:
self
.
weight
=
None
return
self
.
weight
=
self
.
weight_cuda_buffer
.
copy_
(
destination
[
weight_name
],
non_blocking
=
True
)
@
EMBEDDING_WEIGHT_REGISTER
(
"Default"
)
class
EmbeddingWeight
(
EmbeddingWeightTemplate
):
def
__init__
(
self
,
weight_name
=
None
,
lazy_load
=
False
,
lazy_load_file
=
None
):
super
().
__init__
(
weight_name
,
lazy_load
,
lazy_load_file
)
def
apply
(
self
,
input_indices
):
output
=
F
.
embedding
(
input
=
input_indices
,
weight
=
self
.
weight
,
padding_idx
=
None
,
max_norm
=
None
,
norm_type
=
2.0
,
scale_grad_by_freq
=
False
,
sparse
=
False
)
return
output
lightx2v/common/ops/mm/__init__.py
0 → 100644
View file @
a1ebc651
from
.mm_weight
import
*
lightx2v/common/ops/mm/mm_weight.py
0 → 100644
View file @
a1ebc651
import
os
import
re
from
abc
import
ABCMeta
,
abstractmethod
from
pathlib
import
Path
import
torch
from
safetensors
import
safe_open
from
lightx2v.utils.envs
import
*
from
lightx2v.utils.ggml_tensor
import
GGMLTensor
from
lightx2v.utils.ggml_tensor
import
dequantize_tensor
as
gguf_dequantize_tensor
from
lightx2v.utils.global_paras
import
CALIB
from
lightx2v.utils.quant_utils
import
FloatQuantizer
,
IntegerQuantizer
from
lightx2v.utils.registry_factory
import
MM_WEIGHT_REGISTER
from
lightx2v_platform.base.global_var
import
AI_DEVICE
try
:
from
lightx2v_kernel.gemm
import
(
cutlass_scaled_mxfp4_mm
,
cutlass_scaled_mxfp6_mxfp8_mm
,
cutlass_scaled_mxfp8_mm
,
cutlass_scaled_nvfp4_mm
,
scaled_mxfp4_quant
,
scaled_mxfp6_quant
,
scaled_mxfp8_quant
,
scaled_nvfp4_quant
,
)
except
ImportError
:
scaled_nvfp4_quant
,
cutlass_scaled_nvfp4_mm
=
None
,
None
scaled_mxfp4_quant
,
cutlass_scaled_mxfp4_mm
=
None
,
None
scaled_mxfp6_quant
,
cutlass_scaled_mxfp6_mxfp8_mm
=
None
,
None
scaled_mxfp8_quant
,
cutlass_scaled_mxfp8_mm
=
None
,
None
try
:
from
vllm
import
_custom_ops
as
ops
except
ImportError
:
ops
=
None
try
:
import
sgl_kernel
except
ImportError
:
sgl_kernel
=
None
try
:
from
q8_kernels.functional.linear
import
q8_linear
except
ImportError
:
q8_linear
=
None
try
:
from
q8_kernels.functional.linear
import
fp8_linear
except
ImportError
:
fp8_linear
=
None
try
:
import
deep_gemm
except
ImportError
:
deep_gemm
=
None
try
:
from
torchao.quantization.utils
import
quant_int8_per_token_matmul
,
quantize_activation_per_token_absmax
except
ImportError
:
quant_int8_per_token_matmul
,
quantize_activation_per_token_absmax
=
None
,
None
try
:
import
gguf
except
ImportError
:
gguf
=
None
try
:
import
marlin_cuda_quant
except
ImportError
:
marlin_cuda_quant
=
None
class
MMWeightTemplate
(
metaclass
=
ABCMeta
):
def
__init__
(
self
,
weight_name
,
bias_name
,
create_cuda_buffer
=
False
,
create_cpu_buffer
=
False
,
lazy_load
=
False
,
lazy_load_file
=
None
,
is_post_adapter
=
False
):
self
.
weight_name
=
weight_name
self
.
bias_name
=
bias_name
self
.
create_cuda_buffer
=
create_cuda_buffer
self
.
create_cpu_buffer
=
create_cpu_buffer
self
.
lazy_load
=
lazy_load
self
.
lazy_load_file
=
lazy_load_file
self
.
is_post_adapter
=
is_post_adapter
self
.
config
=
{}
@
abstractmethod
def
load
(
self
,
weight_dict
):
pass
@
abstractmethod
def
apply
(
self
):
pass
def
set_config
(
self
,
config
=
{}):
self
.
config
=
config
def
to_cuda
(
self
,
non_blocking
=
False
):
self
.
weight
=
self
.
pin_weight
.
to
(
AI_DEVICE
,
non_blocking
=
non_blocking
)
if
hasattr
(
self
,
"pin_weight_scale"
):
self
.
weight_scale
=
self
.
pin_weight_scale
.
to
(
AI_DEVICE
,
non_blocking
=
non_blocking
)
if
hasattr
(
self
,
"pin_bias"
)
and
self
.
pin_bias
is
not
None
:
self
.
bias
=
self
.
pin_bias
.
to
(
AI_DEVICE
,
non_blocking
=
non_blocking
)
def
to_cpu
(
self
,
non_blocking
=
False
):
if
hasattr
(
self
,
"pin_weight"
):
self
.
weight
=
self
.
pin_weight
.
copy_
(
self
.
weight
,
non_blocking
=
non_blocking
).
cpu
()
if
hasattr
(
self
,
"weight_scale_name"
):
self
.
weight_scale
=
self
.
pin_weight_scale
.
copy_
(
self
.
weight_scale
,
non_blocking
=
non_blocking
).
cpu
()
if
self
.
bias
is
not
None
:
self
.
bias
=
self
.
pin_bias
.
copy_
(
self
.
bias
,
non_blocking
=
non_blocking
).
cpu
()
else
:
self
.
weight
=
self
.
weight
.
to
(
"cpu"
,
non_blocking
=
non_blocking
)
if
hasattr
(
self
,
"weight_scale"
):
self
.
weight_scale
=
self
.
weight_scale
.
to
(
"cpu"
,
non_blocking
=
non_blocking
)
if
hasattr
(
self
,
"bias"
)
and
self
.
bias
is
not
None
:
self
.
bias
=
self
.
bias
.
to
(
"cpu"
,
non_blocking
=
non_blocking
)
@
MM_WEIGHT_REGISTER
(
"Default"
)
class
MMWeight
(
MMWeightTemplate
):
def
__init__
(
self
,
weight_name
,
bias_name
,
create_cuda_buffer
=
False
,
create_cpu_buffer
=
False
,
lazy_load
=
False
,
lazy_load_file
=
None
,
is_post_adapter
=
False
):
super
().
__init__
(
weight_name
,
bias_name
,
create_cuda_buffer
,
create_cpu_buffer
,
lazy_load
,
lazy_load_file
,
is_post_adapter
)
def
load
(
self
,
weight_dict
):
if
self
.
create_cuda_buffer
:
self
.
_load_cuda_buffers
(
weight_dict
)
elif
self
.
create_cpu_buffer
:
self
.
_load_cpu_pin_buffers
()
else
:
self
.
_load_default_tensors
(
weight_dict
)
def
_get_source_tensor
(
self
,
source_name
,
weight_dict
=
None
):
if
self
.
lazy_load
:
if
Path
(
self
.
lazy_load_file
).
is_file
():
lazy_load_file_path
=
self
.
lazy_load_file
else
:
lazy_load_file_path
=
os
.
path
.
join
(
self
.
lazy_load_file
,
f
"block_
{
source_name
.
split
(
'.'
)[
1
]
}
.safetensors"
)
with
safe_open
(
lazy_load_file_path
,
framework
=
"pt"
,
device
=
"cpu"
)
as
lazy_load_file
:
return
lazy_load_file
.
get_tensor
(
source_name
)
return
weight_dict
[
source_name
]
def
_create_pin_tensor
(
self
,
tensor
,
transpose
=
False
):
pin_tensor
=
torch
.
empty
(
tensor
.
shape
,
pin_memory
=
True
,
dtype
=
tensor
.
dtype
)
pin_tensor
=
pin_tensor
.
copy_
(
tensor
)
if
transpose
:
pin_tensor
=
pin_tensor
.
t
()
del
tensor
return
pin_tensor
def
_load_cuda_buffers
(
self
,
weight_dict
):
self
.
weight_cuda_buffer
=
self
.
_get_source_tensor
(
self
.
weight_name
,
weight_dict
).
t
().
to
(
AI_DEVICE
)
if
self
.
bias_name
is
not
None
:
self
.
bias_cuda_buffer
=
self
.
_get_source_tensor
(
self
.
bias_name
,
weight_dict
).
to
(
AI_DEVICE
)
def
_load_cpu_pin_buffers
(
self
):
if
self
.
lazy_load
:
if
Path
(
self
.
lazy_load_file
).
is_file
():
lazy_load_file_path
=
self
.
lazy_load_file
else
:
lazy_load_file_path
=
os
.
path
.
join
(
self
.
lazy_load_file
,
f
"block_
{
self
.
weight_name
.
split
(
'.'
)[
1
]
}
.safetensors"
)
with
safe_open
(
lazy_load_file_path
,
framework
=
"pt"
,
device
=
"cpu"
)
as
lazy_load_file
:
weight_tensor
=
lazy_load_file
.
get_tensor
(
self
.
weight_name
)
self
.
pin_weight
=
self
.
_create_pin_tensor
(
weight_tensor
,
transpose
=
True
)
if
self
.
bias_name
is
not
None
:
bias_tensor
=
lazy_load_file
.
get_tensor
(
self
.
bias_name
)
self
.
pin_bias
=
self
.
_create_pin_tensor
(
bias_tensor
)
else
:
self
.
bias
=
None
self
.
pin_bias
=
None
def
_load_default_tensors
(
self
,
weight_dict
):
if
not
self
.
lazy_load
:
device
=
weight_dict
[
self
.
weight_name
].
device
if
device
.
type
==
"cpu"
:
weight_tensor
=
weight_dict
[
self
.
weight_name
]
self
.
pin_weight
=
self
.
_create_pin_tensor
(
weight_tensor
,
transpose
=
True
)
if
self
.
bias_name
is
not
None
:
bias_tensor
=
weight_dict
[
self
.
bias_name
]
self
.
pin_bias
=
self
.
_create_pin_tensor
(
bias_tensor
)
else
:
self
.
bias
=
None
self
.
pin_bias
=
None
del
weight_dict
[
self
.
weight_name
]
else
:
self
.
weight
=
weight_dict
[
self
.
weight_name
].
t
()
self
.
bias
=
weight_dict
[
self
.
bias_name
]
if
self
.
bias_name
is
not
None
else
None
def
apply
(
self
,
input_tensor
):
shape
=
(
input_tensor
.
shape
[
0
],
self
.
weight
.
shape
[
1
])
dtype
=
input_tensor
.
dtype
device
=
input_tensor
.
device
output_tensor
=
torch
.
empty
(
shape
,
dtype
=
dtype
,
device
=
device
,
requires_grad
=
False
)
if
self
.
bias
is
None
:
return
torch
.
mm
(
input_tensor
,
self
.
weight
,
out
=
output_tensor
)
return
torch
.
addmm
(
self
.
bias
,
input_tensor
,
self
.
weight
,
out
=
output_tensor
)
def
state_dict
(
self
,
destination
=
None
):
if
destination
is
None
:
destination
=
{}
destination
[
self
.
weight_name
]
=
self
.
pin_weight
if
hasattr
(
self
,
"pin_weight"
)
else
self
.
weight
if
self
.
bias_name
is
not
None
:
destination
[
self
.
bias_name
]
=
self
.
pin_bias
if
hasattr
(
self
,
"pin_bias"
)
else
self
.
bias
return
destination
def
load_state_dict_from_disk
(
self
,
block_index
,
adapter_block_index
=
None
):
if
self
.
is_post_adapter
:
assert
adapter_block_index
is
not
None
self
.
weight_name
=
re
.
sub
(
r
"\.\d+"
,
lambda
m
:
f
".
{
adapter_block_index
}
"
,
self
.
weight_name
,
count
=
1
)
else
:
self
.
weight_name
=
re
.
sub
(
r
"\.\d+"
,
lambda
m
:
f
".
{
block_index
}
"
,
self
.
weight_name
,
count
=
1
)
if
self
.
bias_name
is
not
None
:
if
self
.
is_post_adapter
:
assert
adapter_block_index
is
not
None
self
.
bias_name
=
re
.
sub
(
r
"\.\d+"
,
lambda
m
:
f
".
{
adapter_block_index
}
"
,
self
.
bias_name
,
count
=
1
)
else
:
self
.
bias_name
=
re
.
sub
(
r
"\.\d+"
,
lambda
m
:
f
".
{
block_index
}
"
,
self
.
bias_name
,
count
=
1
)
if
Path
(
self
.
lazy_load_file
).
is_file
():
lazy_load_file_path
=
self
.
lazy_load_file
else
:
lazy_load_file_path
=
os
.
path
.
join
(
self
.
lazy_load_file
,
f
"block_
{
block_index
}
.safetensors"
)
with
safe_open
(
lazy_load_file_path
,
framework
=
"pt"
,
device
=
"cpu"
)
as
lazy_load_file
:
weight_tensor
=
lazy_load_file
.
get_tensor
(
self
.
weight_name
).
t
()
self
.
pin_weight
=
self
.
pin_weight
.
copy_
(
weight_tensor
)
del
weight_tensor
if
self
.
bias_name
is
not
None
:
bias_tensor
=
lazy_load_file
.
get_tensor
(
self
.
bias_name
)
self
.
pin_bias
.
copy_
(
bias_tensor
)
del
bias_tensor
def
load_state_dict
(
self
,
destination
,
block_index
,
adapter_block_index
=
None
):
if
self
.
is_post_adapter
:
assert
adapter_block_index
is
not
None
weight_name
=
re
.
sub
(
r
"\.\d+"
,
lambda
m
:
f
".
{
adapter_block_index
}
"
,
self
.
weight_name
,
count
=
1
)
else
:
weight_name
=
re
.
sub
(
r
"\.\d+"
,
lambda
m
:
f
".
{
block_index
}
"
,
self
.
weight_name
,
count
=
1
)
if
weight_name
not
in
destination
:
self
.
weight
=
None
return
self
.
weight
=
self
.
weight_cuda_buffer
.
copy_
(
destination
[
weight_name
],
non_blocking
=
True
)
if
self
.
bias_name
is
not
None
:
if
self
.
is_post_adapter
:
assert
adapter_block_index
is
not
None
bias_name
=
re
.
sub
(
r
"\.\d+"
,
lambda
m
:
f
".
{
adapter_block_index
}
"
,
self
.
bias_name
,
count
=
1
)
else
:
bias_name
=
re
.
sub
(
r
"\.\d+"
,
lambda
m
:
f
".
{
block_index
}
"
,
self
.
bias_name
,
count
=
1
)
self
.
bias
=
self
.
bias_cuda_buffer
.
copy_
(
destination
[
bias_name
],
non_blocking
=
True
)
else
:
self
.
bias
=
None
@
MM_WEIGHT_REGISTER
(
"Default-Force-FP32"
)
class
MMWeightForceFP32
(
MMWeight
):
def
__init__
(
self
,
weight_name
,
bias_name
,
create_cuda_buffer
=
False
,
create_cpu_buffer
=
False
,
lazy_load
=
False
,
lazy_load_file
=
None
,
is_post_adapter
=
False
):
super
().
__init__
(
weight_name
,
bias_name
,
create_cuda_buffer
,
create_cpu_buffer
,
lazy_load
,
lazy_load_file
,
is_post_adapter
)
def
load
(
self
,
weight_dict
):
if
not
self
.
lazy_load
:
super
().
load
(
weight_dict
)
self
.
weight
=
self
.
weight
.
to
(
torch
.
float32
)
if
hasattr
(
self
,
"bias"
)
and
self
.
bias
is
not
None
:
self
.
bias
=
self
.
bias
.
to
(
torch
.
float32
)
class
MMWeightQuantTemplate
(
MMWeightTemplate
):
def
__init__
(
self
,
weight_name
,
bias_name
,
create_cuda_buffer
=
False
,
create_cpu_buffer
=
False
,
lazy_load
=
False
,
lazy_load_file
=
None
,
is_post_adapter
=
False
):
super
().
__init__
(
weight_name
,
bias_name
,
create_cuda_buffer
,
create_cpu_buffer
,
lazy_load
,
lazy_load_file
,
is_post_adapter
)
self
.
weight_scale_name
=
self
.
weight_name
.
removesuffix
(
".weight"
)
+
".weight_scale"
self
.
load_func
=
None
self
.
weight_need_transpose
=
True
self
.
act_quant_func
=
None
self
.
lazy_load
=
lazy_load
self
.
lazy_load_file
=
lazy_load_file
self
.
infer_dtype
=
GET_DTYPE
()
self
.
bias_force_fp32
=
False
# =========================
# weight load functions
# =========================
def
load
(
self
,
weight_dict
):
self
.
load_quantized
(
weight_dict
)
if
self
.
weight_need_transpose
:
if
hasattr
(
self
,
"weight"
)
and
self
.
weight
is
not
None
:
self
.
weight
=
self
.
weight
.
t
()
if
hasattr
(
self
,
"pin_weight"
)
and
self
.
pin_weight
is
not
None
:
self
.
pin_weight
=
self
.
pin_weight
.
t
()
if
hasattr
(
self
,
"weight_cuda_buffer"
)
and
self
.
weight_cuda_buffer
is
not
None
:
self
.
weight_cuda_buffer
=
self
.
weight_cuda_buffer
.
t
()
def
load_quantized
(
self
,
weight_dict
):
if
self
.
create_cuda_buffer
:
self
.
_load_cuda_buffers
(
weight_dict
)
elif
self
.
create_cpu_buffer
:
self
.
_load_cpu_pin_buffers
()
else
:
self
.
_load_default_tensors
(
weight_dict
)
def
_load_cuda_buffers
(
self
,
weight_dict
):
if
self
.
lazy_load
:
if
Path
(
self
.
lazy_load_file
).
is_file
():
lazy_load_file_path
=
self
.
lazy_load_file
else
:
lazy_load_file_path
=
os
.
path
.
join
(
self
.
lazy_load_file
,
f
"block_
{
self
.
weight_name
.
split
(
'.'
)[
1
]
}
.safetensors"
)
with
safe_open
(
lazy_load_file_path
,
framework
=
"pt"
,
device
=
"cpu"
)
as
source
:
self
.
weight_cuda_buffer
,
self
.
weight_scale_cuda_buffer
=
self
.
_get_cuda_tensor_pair
(
source
,
self
.
lazy_load
)
self
.
bias_cuda_buffer
=
self
.
_get_cuda_bias_tensor
(
source
,
self
.
lazy_load
)
else
:
source
=
weight_dict
self
.
weight_cuda_buffer
,
self
.
weight_scale_cuda_buffer
=
self
.
_get_cuda_tensor_pair
(
source
,
self
.
lazy_load
)
self
.
bias_cuda_buffer
=
self
.
_get_cuda_bias_tensor
(
source
,
self
.
lazy_load
)
def
_get_cuda_tensor_pair
(
self
,
source
,
is_lazy
):
if
is_lazy
:
weight
=
source
.
get_tensor
(
self
.
weight_name
).
to
(
AI_DEVICE
)
scale
=
source
.
get_tensor
(
self
.
weight_scale_name
).
float
().
to
(
AI_DEVICE
)
else
:
weight
=
source
[
self
.
weight_name
].
to
(
AI_DEVICE
)
scale
=
source
[
self
.
weight_scale_name
].
float
().
to
(
AI_DEVICE
)
return
weight
,
scale
def
_get_cuda_bias_tensor
(
self
,
source
,
is_lazy
):
if
self
.
bias_name
is
None
:
return
None
if
is_lazy
:
bias
=
source
.
get_tensor
(
self
.
bias_name
)
dtype
=
self
.
infer_dtype
else
:
bias
=
source
[
self
.
bias_name
]
dtype
=
bias
.
dtype
if
self
.
bias_force_fp32
:
bias
=
bias
.
to
(
torch
.
float32
)
else
:
bias
=
bias
.
to
(
dtype
)
return
bias
.
to
(
AI_DEVICE
)
def
_load_cpu_pin_buffers
(
self
):
self
.
pin_weight
,
self
.
pin_weight_scale
=
self
.
_get_cpu_pin_tensor_pair
(
self
.
lazy_load_file
,
is_lazy
=
True
)
self
.
pin_bias
=
self
.
_get_cpu_pin_bias_tensor
(
self
.
lazy_load_file
,
is_lazy
=
True
)
self
.
bias
=
None
def
_get_cpu_pin_tensor_pair
(
self
,
source
,
is_lazy
):
if
is_lazy
:
if
Path
(
self
.
lazy_load_file
).
is_file
():
lazy_load_file_path
=
self
.
lazy_load_file
else
:
lazy_load_file_path
=
os
.
path
.
join
(
self
.
lazy_load_file
,
f
"block_
{
self
.
weight_name
.
split
(
'.'
)[
1
]
}
.safetensors"
)
with
safe_open
(
lazy_load_file_path
,
framework
=
"pt"
,
device
=
"cpu"
)
as
source
:
weight_tensor
=
source
.
get_tensor
(
self
.
weight_name
)
scale_tensor
=
source
.
get_tensor
(
self
.
weight_scale_name
)
scale_dtype
=
torch
.
float
pin_weight
=
self
.
_create_pin_tensor
(
weight_tensor
)
pin_scale
=
self
.
_create_pin_tensor
(
scale_tensor
,
scale_dtype
)
else
:
weight_tensor
=
source
[
self
.
weight_name
]
scale_tensor
=
source
[
self
.
weight_scale_name
]
scale_dtype
=
torch
.
float
pin_weight
=
self
.
_create_pin_tensor
(
weight_tensor
)
pin_scale
=
self
.
_create_pin_tensor
(
scale_tensor
,
scale_dtype
)
return
pin_weight
,
pin_scale
def
_get_cpu_pin_bias_tensor
(
self
,
source
,
is_lazy
):
if
self
.
bias_name
is
None
:
return
None
if
is_lazy
:
if
Path
(
self
.
lazy_load_file
).
is_file
():
lazy_load_file_path
=
self
.
lazy_load_file
else
:
lazy_load_file_path
=
os
.
path
.
join
(
self
.
lazy_load_file
,
f
"block_
{
self
.
weight_name
.
split
(
'.'
)[
1
]
}
.safetensors"
)
with
safe_open
(
lazy_load_file_path
,
framework
=
"pt"
,
device
=
"cpu"
)
as
source
:
bias_tensor
=
source
.
get_tensor
(
self
.
bias_name
)
if
not
self
.
bias_force_fp32
:
bias_tensor
=
bias_tensor
.
to
(
self
.
infer_dtype
)
if
self
.
bias_force_fp32
:
bias_tensor
=
bias_tensor
.
to
(
torch
.
float32
)
return
self
.
_create_pin_tensor
(
bias_tensor
)
else
:
bias_tensor
=
source
[
self
.
bias_name
]
if
self
.
bias_force_fp32
:
bias_tensor
=
bias_tensor
.
to
(
torch
.
float32
)
return
self
.
_create_pin_tensor
(
bias_tensor
)
def
_create_pin_tensor
(
self
,
tensor
,
dtype
=
None
):
dtype
=
dtype
or
tensor
.
dtype
pin_tensor
=
torch
.
empty
(
tensor
.
shape
,
pin_memory
=
True
,
dtype
=
dtype
)
pin_tensor
.
copy_
(
tensor
)
del
tensor
return
pin_tensor
def
_load_default_tensors
(
self
,
weight_dict
):
if
not
self
.
lazy_load
:
self
.
weight
,
self
.
weight_scale
,
self
.
pin_weight
,
self
.
pin_weight_scale
=
self
.
_get_device_tensor_pair
(
weight_dict
)
self
.
_load_default_bias
(
weight_dict
)
else
:
self
.
bias
=
None
self
.
pin_bias
=
None
def
_get_device_tensor_pair
(
self
,
source
):
device
=
source
[
self
.
weight_name
].
device
if
device
.
type
==
"cpu"
:
pin_weight
,
pin_scale
=
self
.
_get_cpu_pin_tensor_pair
(
source
,
is_lazy
=
False
)
return
None
,
None
,
pin_weight
,
pin_scale
else
:
return
source
[
self
.
weight_name
],
source
[
self
.
weight_scale_name
].
float
(),
None
,
None
def
_load_default_bias
(
self
,
source
):
if
self
.
bias_name
is
None
:
self
.
bias
=
None
self
.
pin_bias
=
None
self
.
bias_cuda_buffer
=
None
return
if
self
.
create_cuda_buffer
:
self
.
bias_cuda_buffer
=
self
.
_get_cuda_bias_tensor
(
source
,
is_lazy
=
False
)
self
.
bias
=
None
self
.
pin_bias
=
None
else
:
bias_tensor
=
source
[
self
.
bias_name
].
float
()
if
self
.
bias_force_fp32
else
source
[
self
.
bias_name
]
device
=
bias_tensor
.
device
if
device
.
type
==
"cpu"
:
self
.
pin_bias
=
self
.
_get_cpu_pin_bias_tensor
(
source
,
is_lazy
=
False
)
self
.
bias
=
None
else
:
self
.
bias
=
bias_tensor
self
.
pin_bias
=
None
def
load_fp8_perchannel_sym
(
self
,
weight_dict
):
if
self
.
config
.
get
(
"weight_auto_quant"
,
False
):
self
.
weight
=
weight_dict
[
self
.
weight_name
].
to
(
torch
.
float32
)
w_quantizer
=
FloatQuantizer
(
"e4m3"
,
True
,
"per_channel"
)
self
.
weight
,
self
.
weight_scale
,
_
=
w_quantizer
.
real_quant_tensor
(
self
.
weight
)
self
.
weight
=
self
.
weight
.
to
(
torch
.
float8_e4m3fn
)
self
.
weight_scale
=
self
.
weight_scale
.
to
(
torch
.
float32
)
else
:
self
.
load_quantized
(
weight_dict
)
def
load_int8_perchannel_sym
(
self
,
weight_dict
):
if
self
.
config
.
get
(
"weight_auto_quant"
,
False
):
self
.
weight
=
weight_dict
[
self
.
weight_name
].
to
(
torch
.
float32
)
w_quantizer
=
IntegerQuantizer
(
8
,
True
,
"per_channel"
)
self
.
weight
,
self
.
weight_scale
,
_
=
w_quantizer
.
real_quant_tensor
(
self
.
weight
)
self
.
weight
=
self
.
weight
.
to
(
torch
.
int8
)
self
.
weight_scale
=
self
.
weight_scale
.
to
(
torch
.
float32
)
else
:
self
.
load_quantized
(
weight_dict
)
def
load_mxfp4
(
self
,
weight_dict
):
if
self
.
config
.
get
(
"weight_auto_quant"
,
False
):
device
=
weight_dict
[
self
.
weight_name
].
device
self
.
weight
=
weight_dict
[
self
.
weight_name
].
to
(
AI_DEVICE
).
to
(
torch
.
bfloat16
)
self
.
weight
,
self
.
weight_scale
=
scaled_mxfp4_quant
(
self
.
weight
)
self
.
weight
,
self
.
weight_scale
=
self
.
weight
.
to
(
device
),
self
.
weight_scale
.
to
(
device
)
else
:
device
=
weight_dict
[
self
.
weight_name
].
device
if
device
.
type
==
"cpu"
:
weight_shape
=
weight_dict
[
self
.
weight_name
].
shape
weight_dtype
=
weight_dict
[
self
.
weight_name
].
dtype
self
.
pin_weight
=
torch
.
empty
(
weight_shape
,
pin_memory
=
True
,
dtype
=
weight_dtype
)
self
.
pin_weight
.
copy_
(
weight_dict
[
self
.
weight_name
])
weight_scale_shape
=
weight_dict
[
self
.
weight_scale_name
].
shape
weight_scale_dtype
=
weight_dict
[
self
.
weight_scale_name
].
dtype
self
.
pin_weight_scale
=
torch
.
empty
(
weight_scale_shape
,
pin_memory
=
True
,
dtype
=
weight_scale_dtype
)
self
.
pin_weight_scale
.
copy_
(
weight_dict
[
self
.
weight_scale_name
])
del
weight_dict
[
self
.
weight_name
]
else
:
self
.
weight
=
weight_dict
[
self
.
weight_name
]
self
.
weight_scale
=
weight_dict
[
self
.
weight_scale_name
]
def
load_mxfp6
(
self
,
weight_dict
):
if
self
.
config
.
get
(
"weight_auto_quant"
,
False
):
device
=
weight_dict
[
self
.
weight_name
].
device
self
.
weight
=
weight_dict
[
self
.
weight_name
].
to
(
AI_DEVICE
).
to
(
torch
.
bfloat16
)
self
.
weight
,
self
.
weight_scale
=
scaled_mxfp6_quant
(
self
.
weight
)
self
.
weight
,
self
.
weight_scale
=
self
.
weight
.
to
(
device
),
self
.
weight_scale
.
to
(
device
)
else
:
device
=
weight_dict
[
self
.
weight_name
].
device
if
device
.
type
==
"cpu"
:
weight_shape
=
weight_dict
[
self
.
weight_name
].
shape
weight_dtype
=
weight_dict
[
self
.
weight_name
].
dtype
self
.
pin_weight
=
torch
.
empty
(
weight_shape
,
pin_memory
=
True
,
dtype
=
weight_dtype
)
self
.
pin_weight
.
copy_
(
weight_dict
[
self
.
weight_name
])
weight_scale_shape
=
weight_dict
[
self
.
weight_scale_name
].
shape
weight_scale_dtype
=
weight_dict
[
self
.
weight_scale_name
].
dtype
self
.
pin_weight_scale
=
torch
.
empty
(
weight_scale_shape
,
pin_memory
=
True
,
dtype
=
weight_scale_dtype
)
self
.
pin_weight_scale
.
copy_
(
weight_dict
[
self
.
weight_scale_name
])
del
weight_dict
[
self
.
weight_name
]
else
:
self
.
weight
=
weight_dict
[
self
.
weight_name
]
self
.
weight_scale
=
weight_dict
[
self
.
weight_scale_name
]
def
load_mxfp8
(
self
,
weight_dict
):
if
self
.
config
.
get
(
"weight_auto_quant"
,
False
):
device
=
weight_dict
[
self
.
weight_name
].
device
self
.
weight
=
weight_dict
[
self
.
weight_name
].
to
(
AI_DEVICE
).
to
(
torch
.
bfloat16
)
self
.
weight
,
self
.
weight_scale
=
scaled_mxfp8_quant
(
self
.
weight
)
self
.
weight
,
self
.
weight_scale
=
self
.
weight
.
to
(
device
),
self
.
weight_scale
.
to
(
device
)
else
:
device
=
weight_dict
[
self
.
weight_name
].
device
if
device
.
type
==
"cpu"
:
weight_shape
=
weight_dict
[
self
.
weight_name
].
shape
weight_dtype
=
weight_dict
[
self
.
weight_name
].
dtype
self
.
pin_weight
=
torch
.
empty
(
weight_shape
,
pin_memory
=
True
,
dtype
=
weight_dtype
)
self
.
pin_weight
.
copy_
(
weight_dict
[
self
.
weight_name
])
weight_scale_shape
=
weight_dict
[
self
.
weight_scale_name
].
shape
weight_scale_dtype
=
weight_dict
[
self
.
weight_scale_name
].
dtype
self
.
pin_weight_scale
=
torch
.
empty
(
weight_scale_shape
,
pin_memory
=
True
,
dtype
=
weight_scale_dtype
)
self
.
pin_weight_scale
.
copy_
(
weight_dict
[
self
.
weight_scale_name
])
del
weight_dict
[
self
.
weight_name
]
else
:
self
.
weight
=
weight_dict
[
self
.
weight_name
]
self
.
weight_scale
=
weight_dict
[
self
.
weight_scale_name
]
def
load_nvfp4
(
self
,
weight_dict
):
device
=
weight_dict
[
self
.
weight_name
].
device
input_absmax
=
weight_dict
[
self
.
weight_name
.
replace
(
".weight"
,
".input_absmax"
)]
input_global_scale
=
(
2688.0
/
input_absmax
).
to
(
torch
.
float32
)
weight_global_scale
=
weight_dict
[
f
"
{
self
.
weight_name
}
_global_scale"
]
alpha
=
1.0
/
(
input_global_scale
*
weight_global_scale
)
if
device
.
type
==
"cpu"
:
weight_shape
=
weight_dict
[
self
.
weight_name
].
shape
weight_dtype
=
weight_dict
[
self
.
weight_name
].
dtype
self
.
pin_weight
=
torch
.
empty
(
weight_shape
,
pin_memory
=
True
,
dtype
=
weight_dtype
)
self
.
pin_weight
.
copy_
(
weight_dict
[
self
.
weight_name
])
weight_scale_shape
=
weight_dict
[
self
.
weight_scale_name
].
shape
weight_scale_dtype
=
weight_dict
[
self
.
weight_scale_name
].
dtype
self
.
pin_weight_scale
=
torch
.
empty
(
weight_scale_shape
,
pin_memory
=
True
,
dtype
=
weight_scale_dtype
)
self
.
pin_weight_scale
.
copy_
(
weight_dict
[
self
.
weight_scale_name
])
input_global_scale_shape
=
input_global_scale
.
shape
input_global_scale_dtype
=
input_global_scale
.
dtype
self
.
pin_input_global_scale
=
torch
.
empty
(
input_global_scale_shape
,
pin_memory
=
True
,
dtype
=
input_global_scale_dtype
)
self
.
pin_input_global_scale
.
copy_
(
input_global_scale
)
alpha_shape
=
alpha
.
shape
alpha_dtype
=
alpha
.
dtype
self
.
pin_alpha
=
torch
.
empty
(
alpha_shape
,
pin_memory
=
True
,
dtype
=
alpha_dtype
)
self
.
pin_alpha
.
copy_
(
alpha
)
del
weight_dict
[
self
.
weight_name
]
else
:
self
.
weight
=
weight_dict
[
self
.
weight_name
]
self
.
weight_scale
=
weight_dict
[
self
.
weight_scale_name
]
self
.
input_global_scale
=
input_global_scale
self
.
alpha
=
alpha
if
self
.
bias_name
is
not
None
:
if
self
.
create_cuda_buffer
:
self
.
bias_cuda_buffer
=
weight_dict
[
self
.
bias_name
].
to
(
AI_DEVICE
)
else
:
device
=
weight_dict
[
self
.
bias_name
].
device
if
device
.
type
==
"cpu"
:
bias_shape
=
weight_dict
[
self
.
bias_name
].
shape
bias_dtype
=
weight_dict
[
self
.
bias_name
].
dtype
self
.
pin_bias
=
torch
.
empty
(
bias_shape
,
pin_memory
=
True
,
dtype
=
bias_dtype
)
self
.
pin_bias
.
copy_
(
weight_dict
[
self
.
bias_name
])
else
:
self
.
bias
=
weight_dict
[
self
.
bias_name
]
else
:
self
.
bias
=
None
self
.
pin_bias
=
None
def
load_fp8_perblock128_sym
(
self
,
weight_dict
):
if
self
.
config
.
get
(
"weight_auto_quant"
,
False
):
self
.
weight
=
weight_dict
[
self
.
weight_name
]
self
.
weight
,
self
.
weight_scale
=
self
.
per_block_cast_to_fp8
(
self
.
weight
)
else
:
self
.
load_quantized
(
weight_dict
)
def
per_block_cast_to_fp8
(
self
,
x
):
assert
x
.
dim
()
==
2
m
,
n
=
x
.
shape
x_padded
=
torch
.
zeros
(
(
deep_gemm
.
ceil_div
(
m
,
128
)
*
128
,
deep_gemm
.
ceil_div
(
n
,
128
)
*
128
),
dtype
=
x
.
dtype
,
device
=
x
.
device
,
)
x_padded
[:
m
,
:
n
]
=
x
x_view
=
x_padded
.
view
(
-
1
,
128
,
x_padded
.
size
(
1
)
//
128
,
128
)
x_amax
=
x_view
.
abs
().
float
().
amax
(
dim
=
(
1
,
3
),
keepdim
=
True
).
clamp
(
1e-4
)
x_scaled
=
(
x_view
*
(
448.0
/
x_amax
)).
to
(
torch
.
float8_e4m3fn
)
return
x_scaled
.
view_as
(
x_padded
)[:
m
,
:
n
].
contiguous
(),
(
x_amax
/
448.0
).
view
(
x_view
.
size
(
0
),
x_view
.
size
(
2
))
# =========================
# act quant kernels
# =========================
def
act_quant_int8_perchannel_sym_torchao
(
self
,
x
):
input_tensor_quant
,
input_tensor_scale
=
quantize_activation_per_token_absmax
(
x
)
return
input_tensor_quant
,
input_tensor_scale
def
act_quant_fp8_perchannel_sym_vllm
(
self
,
x
):
input_tensor_quant
,
input_tensor_scale
=
ops
.
scaled_fp8_quant
(
x
,
None
,
scale_ub
=
None
,
use_per_token_if_dynamic
=
True
)
return
input_tensor_quant
,
input_tensor_scale
def
act_quant_fp8_perchannel_sym_sgl
(
self
,
x
):
m
,
k
=
x
.
shape
input_tensor_quant
=
torch
.
empty
((
m
,
k
),
dtype
=
torch
.
float8_e4m3fn
,
device
=
"cuda"
,
requires_grad
=
False
)
input_tensor_scale
=
torch
.
empty
((
m
,
1
),
dtype
=
torch
.
float32
,
device
=
"cuda"
,
requires_grad
=
False
)
sgl_kernel
.
sgl_per_token_quant_fp8
(
x
,
input_tensor_quant
,
input_tensor_scale
)
return
input_tensor_quant
,
input_tensor_scale
def
act_quant_int8_perchannel_sym_vllm
(
self
,
x
):
input_tensor_quant
,
input_tensor_scale
,
_
=
ops
.
scaled_int8_quant
(
x
,
scale
=
None
,
azp
=
None
,
symmetric
=
True
)
return
input_tensor_quant
,
input_tensor_scale
def
act_quant_nvfp4
(
self
,
x
):
input_tensor_quant
,
input_tensor_scale
=
scaled_nvfp4_quant
(
x
,
self
.
input_global_scale
)
return
input_tensor_quant
,
input_tensor_scale
def
act_quant_mxfp4
(
self
,
x
):
input_tensor_quant
,
input_tensor_scale
=
scaled_mxfp4_quant
(
x
)
return
input_tensor_quant
,
input_tensor_scale
def
act_quant_mxfp8
(
self
,
x
):
input_tensor_quant
,
input_tensor_scale
=
scaled_mxfp8_quant
(
x
)
return
input_tensor_quant
,
input_tensor_scale
def
act_quant_fp8_perchannelgroup128_sym_deepgemm
(
self
,
x
):
assert
x
.
dim
()
==
2
and
x
.
size
(
1
)
%
128
==
0
m
,
n
=
x
.
shape
x_view
=
x
.
view
(
m
,
-
1
,
128
)
x_amax
=
x_view
.
abs
().
float
().
amax
(
dim
=
2
).
view
(
m
,
-
1
).
clamp
(
1e-4
)
return
(
x_view
*
(
448.0
/
x_amax
.
unsqueeze
(
2
))).
to
(
torch
.
float8_e4m3fn
).
view
(
m
,
n
),
(
x_amax
/
448.0
).
view
(
m
,
-
1
)
def
act_quant_fp8_perchannelgroup128_sym_sgl
(
self
,
x
):
m
,
k
=
x
.
shape
input_tensor_quant
=
torch
.
empty
((
m
,
k
),
dtype
=
torch
.
float8_e4m3fn
,
device
=
"cuda"
,
requires_grad
=
False
)
input_tensor_scale
=
torch
.
empty
((
m
,
k
//
128
),
dtype
=
torch
.
float32
,
device
=
"cuda"
,
requires_grad
=
False
)
sgl_kernel
.
sgl_per_token_group_quant_fp8
(
x
,
input_tensor_quant
,
input_tensor_scale
,
group_size
=
128
,
eps
=
1e-10
,
fp8_min
=-
448.0
,
fp8_max
=
448.0
,
)
return
input_tensor_quant
,
input_tensor_scale
def
state_dict
(
self
,
destination
=
None
):
if
destination
is
None
:
destination
=
{}
destination
[
self
.
weight_name
]
=
self
.
pin_weight
if
hasattr
(
self
,
"pin_weight"
)
else
self
.
weight
if
self
.
bias_name
is
not
None
:
destination
[
self
.
bias_name
]
=
self
.
pin_bias
if
hasattr
(
self
,
"pin_bias"
)
else
self
.
bias
destination
[
self
.
weight_scale_name
]
=
self
.
pin_weight_scale
if
hasattr
(
self
,
"pin_weight_scale"
)
else
self
.
weight_scale
return
destination
def
load_state_dict
(
self
,
destination
,
block_index
,
adapter_block_index
=
None
):
if
self
.
is_post_adapter
:
weight_name
=
re
.
sub
(
r
"\.\d+"
,
lambda
m
:
f
".
{
adapter_block_index
}
"
,
self
.
weight_name
,
count
=
1
)
weight_scale_name
=
re
.
sub
(
r
"\.\d+"
,
lambda
m
:
f
".
{
adapter_block_index
}
"
,
self
.
weight_scale_name
,
count
=
1
)
else
:
weight_name
=
re
.
sub
(
r
"\.\d+"
,
lambda
m
:
f
".
{
block_index
}
"
,
self
.
weight_name
,
count
=
1
)
weight_scale_name
=
re
.
sub
(
r
"\.\d+"
,
lambda
m
:
f
".
{
block_index
}
"
,
self
.
weight_scale_name
,
count
=
1
)
if
weight_name
not
in
destination
:
self
.
weight
=
None
return
self
.
weight
=
self
.
weight_cuda_buffer
.
copy_
(
destination
[
weight_name
],
non_blocking
=
True
)
self
.
weight_scale
=
self
.
weight_scale_cuda_buffer
.
copy_
(
destination
[
weight_scale_name
],
non_blocking
=
True
)
if
self
.
bias_name
is
not
None
:
bias_name
=
re
.
sub
(
r
"\.\d+"
,
lambda
m
:
f
".
{
block_index
}
"
,
self
.
bias_name
,
count
=
1
)
self
.
bias
=
self
.
bias_cuda_buffer
.
copy_
(
destination
[
bias_name
],
non_blocking
=
True
)
else
:
self
.
bias
=
None
def
load_state_dict_from_disk
(
self
,
block_index
,
adapter_block_index
=
None
):
if
self
.
is_post_adapter
:
self
.
weight_name
=
re
.
sub
(
r
"\.\d+"
,
lambda
m
:
f
".
{
adapter_block_index
}
"
,
self
.
weight_name
,
count
=
1
)
self
.
weight_scale_name
=
re
.
sub
(
r
"\.\d+"
,
lambda
m
:
f
".
{
adapter_block_index
}
"
,
self
.
weight_scale_name
,
count
=
1
)
else
:
self
.
weight_name
=
re
.
sub
(
r
"\.\d+"
,
lambda
m
:
f
".
{
block_index
}
"
,
self
.
weight_name
,
count
=
1
)
self
.
weight_scale_name
=
re
.
sub
(
r
"\.\d+"
,
lambda
m
:
f
".
{
block_index
}
"
,
self
.
weight_scale_name
,
count
=
1
)
if
self
.
bias_name
is
not
None
:
if
self
.
is_post_adapter
:
assert
adapter_block_index
is
not
None
self
.
bias_name
=
re
.
sub
(
r
"\.\d+"
,
lambda
m
:
f
".
{
adapter_block_index
}
"
,
self
.
bias_name
,
count
=
1
)
else
:
self
.
bias_name
=
re
.
sub
(
r
"\.\d+"
,
lambda
m
:
f
".
{
block_index
}
"
,
self
.
bias_name
,
count
=
1
)
if
Path
(
self
.
lazy_load_file
).
is_file
():
lazy_load_file_path
=
self
.
lazy_load_file
else
:
lazy_load_file_path
=
os
.
path
.
join
(
self
.
lazy_load_file
,
f
"block_
{
block_index
}
.safetensors"
)
with
safe_open
(
lazy_load_file_path
,
framework
=
"pt"
,
device
=
"cpu"
)
as
lazy_load_file
:
if
self
.
weight_need_transpose
:
weight_tensor
=
lazy_load_file
.
get_tensor
(
self
.
weight_name
).
t
()
else
:
weight_tensor
=
lazy_load_file
.
get_tensor
(
self
.
weight_name
)
self
.
pin_weight
=
self
.
pin_weight
.
copy_
(
weight_tensor
)
del
weight_tensor
weight_scale_tensor
=
lazy_load_file
.
get_tensor
(
self
.
weight_scale_name
)
self
.
pin_weight_scale
=
self
.
pin_weight_scale
.
copy_
(
weight_scale_tensor
)
del
weight_scale_tensor
if
self
.
bias_name
is
not
None
:
bias_tensor
=
lazy_load_file
.
get_tensor
(
self
.
bias_name
)
self
.
pin_bias
.
copy_
(
bias_tensor
)
del
bias_tensor
@
MM_WEIGHT_REGISTER
(
"fp8-vllm"
)
class
MMWeightWfp8channelAfp8channeldynamicVllm
(
MMWeightQuantTemplate
):
"""
Name: W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm
Quant MM:
Weight: fp8 perchannel sym
Act: fp8 perchannel dynamic sym
Kernel: vllm
"""
def
__init__
(
self
,
weight_name
,
bias_name
,
create_cuda_buffer
=
False
,
create_cpu_buffer
=
False
,
lazy_load
=
False
,
lazy_load_file
=
None
,
is_post_adapter
=
False
):
super
().
__init__
(
weight_name
,
bias_name
,
create_cuda_buffer
,
create_cpu_buffer
,
lazy_load
,
lazy_load_file
,
is_post_adapter
)
self
.
load_func
=
self
.
load_fp8_perchannel_sym
self
.
weight_need_transpose
=
True
self
.
act_quant_func
=
self
.
act_quant_fp8_perchannel_sym_vllm
def
apply
(
self
,
input_tensor
):
shape
=
(
input_tensor
.
shape
[
0
],
self
.
weight
.
shape
[
1
])
dtype
=
input_tensor
.
dtype
device
=
input_tensor
.
device
output_tensor
=
torch
.
empty
(
shape
,
dtype
=
dtype
,
device
=
device
,
requires_grad
=
False
)
input_tensor_quant
,
input_tensor_scale
=
self
.
act_quant_func
(
input_tensor
)
torch
.
ops
.
_C
.
cutlass_scaled_mm
(
output_tensor
,
input_tensor_quant
,
self
.
weight
,
input_tensor_scale
,
self
.
weight_scale
,
self
.
bias
if
self
.
bias
is
not
None
else
None
,
)
return
output_tensor
@
MM_WEIGHT_REGISTER
(
"int8-vllm"
)
class
MMWeightWint8channelAint8channeldynamicVllm
(
MMWeightQuantTemplate
):
"""
Name: W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm
Quant MM:
Weight: int8 perchannel sym
Act: int8 perchannel dynamic sym
Kernel: vllm
"""
def
__init__
(
self
,
weight_name
,
bias_name
,
create_cuda_buffer
=
False
,
create_cpu_buffer
=
False
,
lazy_load
=
False
,
lazy_load_file
=
None
,
is_post_adapter
=
False
):
super
().
__init__
(
weight_name
,
bias_name
,
create_cuda_buffer
,
create_cpu_buffer
,
lazy_load
,
lazy_load_file
,
is_post_adapter
)
self
.
load_func
=
self
.
load_int8_perchannel_sym
self
.
weight_need_transpose
=
True
self
.
act_quant_func
=
self
.
act_quant_int8_perchannel_sym_vllm
def
apply
(
self
,
input_tensor
):
shape
=
(
input_tensor
.
shape
[
0
],
self
.
weight
.
shape
[
1
])
dtype
=
input_tensor
.
dtype
device
=
input_tensor
.
device
output_tensor
=
torch
.
empty
(
shape
,
dtype
=
dtype
,
device
=
device
,
requires_grad
=
False
)
input_tensor_quant
,
input_tensor_scale
=
self
.
act_quant_func
(
input_tensor
)
torch
.
ops
.
_C
.
cutlass_scaled_mm
(
output_tensor
,
input_tensor_quant
,
self
.
weight
,
input_tensor_scale
,
self
.
weight_scale
,
self
.
bias
if
self
.
bias
is
not
None
else
None
,
)
return
output_tensor


@MM_WEIGHT_REGISTER("mxfp4")
class MMWeightWmxfp4Amxfp4dynamic(MMWeightQuantTemplate):
    """
    Name: W-mxfp4-A-mxfp4-dynamic
    Quant MM:
        Weight: mxfp4
        Act: mxfp4
    """

    def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
        super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter)
        self.load_func = self.load_mxfp4
        self.weight_need_transpose = False
        self.act_quant_func = self.act_quant_mxfp4
        self.set_alpha()

    def set_alpha(self):
        self.alpha = torch.tensor(1.0, dtype=torch.float32)

    def apply(self, input_tensor):
        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
        self.alpha = self.alpha.to(self.weight.device)
        output_tensor = cutlass_scaled_mxfp4_mm(input_tensor_quant, self.weight, input_tensor_scale, self.weight_scale, alpha=self.alpha, bias=self.bias)
        return output_tensor
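

# Sketch of the microscaling (MX) layout shared by the mxfp4 entry above and the
# mxfp6 / mxfp8 entries below: values are grouped in blocks of 32 along the
# reduction dimension and each block carries a power-of-two (E8M0) scale.
# Illustrative only; the real act_quant_mxfp4 / cutlass_scaled_mxfp4_mm kernels
# also pack the low-bit elements, which this sketch approximates with a clamp.
def _reference_mx_block_scales(x, block_size=32, elem_max=6.0):
    # elem_max=6.0 is the largest magnitude representable by an FP4 (E2M1) element.
    tokens, features = x.shape
    assert features % block_size == 0
    blocks = x.float().reshape(tokens, features // block_size, block_size)
    # Shared per-block scale, rounded up to a power of two so it fits an E8M0 exponent.
    amax = blocks.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
    scales = torch.exp2(torch.ceil(torch.log2(amax / elem_max)))
    # These scaled elements are what the 4/6/8-bit element cast would then encode.
    scaled = (blocks / scales).clamp(-elem_max, elem_max)
    return scaled.reshape(tokens, features), scales.squeeze(-1)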


@MM_WEIGHT_REGISTER("mxfp6-mxfp8")
class MMWeightWmxfp6Amxfp8dynamic(MMWeightQuantTemplate):
    """
    Name: W-mxfp6-A-mxfp8-dynamic
    Quant MM:
        Weight: mxfp6
        Act: mxfp8
    """

    def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
        super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter)
        self.load_func = self.load_mxfp6
        self.weight_need_transpose = False
        self.act_quant_func = self.act_quant_mxfp8
        self.set_alpha()

    def set_alpha(self):
        self.alpha = torch.tensor(1.0, dtype=torch.float32)

    def apply(self, input_tensor):
        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
        self.alpha = self.alpha.to(self.weight.device)
        output_tensor = cutlass_scaled_mxfp6_mxfp8_mm(input_tensor_quant, self.weight, input_tensor_scale, self.weight_scale, alpha=self.alpha, bias=self.bias)
        return output_tensor


@MM_WEIGHT_REGISTER("mxfp8")
class MMWeightWmxfp8Amxfp8dynamic(MMWeightQuantTemplate):
    """
    Name: W-mxfp8-A-mxfp8-dynamic
    Quant MM:
        Weight: mxfp8
        Act: mxfp8
    """

    def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
        super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter)
        self.load_func = self.load_mxfp8
        self.weight_need_transpose = False
        self.act_quant_func = self.act_quant_mxfp8
        self.set_alpha()

    def set_alpha(self):
        self.alpha = torch.tensor(1.0, dtype=torch.float32)

    def apply(self, input_tensor):
        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
        self.alpha = self.alpha.to(self.weight.device)
        output_tensor = cutlass_scaled_mxfp8_mm(input_tensor_quant, self.weight, input_tensor_scale, self.weight_scale, alpha=self.alpha, bias=self.bias)
        return output_tensor


@MM_WEIGHT_REGISTER("nvfp4")
class MMWeightWnvfp4Anvfp4dynamic(MMWeightQuantTemplate):
    """
    Name: W-nvfp4-A-nvfp4-dynamic
    Quant MM:
        Weight: nvfp4
        Act: nvfp4
    """

    def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
        super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter)
        self.load_func = self.load_nvfp4
        self.weight_need_transpose = False
        self.act_quant_func = self.act_quant_nvfp4

    def apply(self, input_tensor):
        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
        output_tensor = cutlass_scaled_nvfp4_mm(input_tensor_quant, self.weight, input_tensor_scale, self.weight_scale, alpha=self.alpha, bias=self.bias)
        return output_tensor

    def to_cuda(self, non_blocking=False):
        self.weight = self.pin_weight.to(AI_DEVICE, non_blocking=non_blocking)
        if hasattr(self, "pin_weight_scale"):
            self.weight_scale = self.pin_weight_scale.to(AI_DEVICE, non_blocking=non_blocking)
            self.input_global_scale = self.pin_input_global_scale.to(AI_DEVICE, non_blocking=non_blocking)
            self.alpha = self.pin_alpha.to(AI_DEVICE, non_blocking=non_blocking)
        if hasattr(self, "pin_bias") and self.pin_bias is not None:
            self.bias = self.pin_bias.to(AI_DEVICE, non_blocking=non_blocking)

    def to_cpu(self, non_blocking=False):
        if hasattr(self, "pin_weight"):
            self.weight = self.pin_weight.copy_(self.weight, non_blocking=non_blocking).cpu()
            if hasattr(self, "weight_scale_name"):
                self.weight_scale = self.pin_weight_scale.copy_(self.weight_scale, non_blocking=non_blocking).cpu()
                self.input_global_scale = self.pin_input_global_scale.copy_(self.input_global_scale, non_blocking=non_blocking).cpu()
                self.alpha = self.pin_alpha.copy_(self.alpha, non_blocking=non_blocking).cpu()
            if self.bias is not None:
                self.bias = self.pin_bias.copy_(self.bias, non_blocking=non_blocking).cpu()
        else:
            self.weight = self.weight.to("cpu", non_blocking=non_blocking)
            if hasattr(self, "weight_scale"):
                self.weight_scale = self.weight_scale.to("cpu", non_blocking=non_blocking)
                self.input_global_scale = self.input_global_scale.to("cpu", non_blocking=non_blocking)
                self.alpha = self.alpha.to("cpu", non_blocking=non_blocking)
            if hasattr(self, "bias") and self.bias is not None:
                self.bias = self.bias.to("cpu", non_blocking=non_blocking)
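

# Minimal sketch of the pinned-buffer round trip that to_cuda / to_cpu above
# implement: the host copy lives in page-locked memory so transfers issued with
# non_blocking=True can overlap with compute. Names here are hypothetical and
# not tied to the registry.
def _reference_pinned_roundtrip(weight_cpu, device="cuda", non_blocking=True):
    pin = torch.empty_like(weight_cpu, pin_memory=True)
    pin.copy_(weight_cpu)                                   # stage once into pinned host memory
    on_device = pin.to(device, non_blocking=non_blocking)   # async H2D copy
    # ... run kernels against on_device ...
    pin.copy_(on_device, non_blocking=non_blocking)         # async D2H back into the same buffer
    return pin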


@MM_WEIGHT_REGISTER("Calib")
class MMCalibNvfp4(MMWeight):
    """
    Name: calib
    Calib:
        absmax: torch.max(torch.abs(input_tensor))
    """

    def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
        super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter)
        self.running_absmax = None
        self.count = 0
        self.decay = 0.9

    def apply(self, input_tensor):
        shape = (input_tensor.shape[0], self.weight.shape[1])
        dtype, device = input_tensor.dtype, input_tensor.device
        current_absmax = torch.max(torch.abs(input_tensor)).to("cpu")
        if self.count % 2 == 0:
            if self.running_absmax is None:
                self.running_absmax = current_absmax
            else:
                self.running_absmax = self.decay * self.running_absmax + (1 - self.decay) * current_absmax
            CALIB["absmax"][self.weight_name] = self.running_absmax
        self.count = self.count + 1
        output_tensor = torch.empty(shape, dtype=dtype, device=device, requires_grad=False)
        if self.bias is None:
            return torch.mm(input_tensor, self.weight, out=output_tensor)
        return torch.addmm(self.bias, input_tensor, self.weight, out=output_tensor)
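

# The calibration above keeps an exponential moving average of the activation
# absmax, running <- decay * running + (1 - decay) * current, refreshed on every
# other call (count % 2 == 0). A quick worked example with decay = 0.9:
# running = 10.0 and current = 20.0 give 0.9 * 10.0 + 0.1 * 20.0 = 11.0, so a
# single outlier moves the recorded absmax by only 10% of its excess.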


@MM_WEIGHT_REGISTER("fp8-q8f")
class MMWeightWfp8channelAfp8channeldynamicQ8F(MMWeightQuantTemplate):
    """
    Name: W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Q8F
    Quant MM:
        Weight: fp8 perchannel sym
        Act: fp8 perchannel dynamic sym
        Kernel: Q8F
    """

    def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
        super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter)
        self.load_func = self.load_fp8_perchannel_sym
        self.weight_need_transpose = False
        self.act_quant_func = self.act_quant_fp8_perchannel_sym_vllm
        self.bias_force_fp32 = True

    def apply(self, input_tensor):
        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
        output_tensor = fp8_linear(
            input_tensor_quant,
            self.weight,
            self.bias.float() if self.bias is not None else None,
            input_tensor_scale,
            self.weight_scale,
            out_dtype=self.infer_dtype,
        )
        return output_tensor.squeeze(0) if len(output_tensor.shape) == 3 else output_tensor


@MM_WEIGHT_REGISTER("int8-q8f")
class MMWeightWint8channelAint8channeldynamicQ8F(MMWeightQuantTemplate):
    """
    Name: W-int8-channel-sym-A-int8-channel-sym-dynamic-Q8F
    Quant MM:
        Weight: int8 perchannel sym
        Act: int8 perchannel dynamic sym
        Kernel: Q8F
    """

    def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
        super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter)
        self.load_func = self.load_int8_perchannel_sym
        self.weight_need_transpose = False
        self.act_quant_func = self.act_quant_int8_perchannel_sym_vllm

    def apply(self, input_tensor):
        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
        output_tensor = q8_linear(
            input_tensor_quant,
            self.weight,
            self.bias.float() if self.bias is not None else None,
            input_tensor_scale,
            self.weight_scale,
            fuse_gelu=False,
            out_dtype=self.infer_dtype,
        )
        return output_tensor.squeeze(0) if len(output_tensor.shape) == 3 else output_tensor


@MM_WEIGHT_REGISTER("fp8-b128-deepgemm")
class MMWeightWfp8block128Afp8channelgroup128dynamicDeepgemmActSgl(MMWeightQuantTemplate):
    """
    Name: W-fp8-block128-sym-A-fp8-channel-group128-sym-dynamic-Deepgemm-ActSgl
    Quant MM:
        Weight: fp8 perblock 128x128 sym
        Act: fp8 pertoken-pergroup group=128 dynamic sym
        Kernel: quant-mm using Deepgemm, act dynamic quant using Sgl-kernel
    """

    def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
        super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter)
        self.load_func = self.load_fp8_perblock128_sym
        self.weight_need_transpose = False
        self.act_quant_func = self.act_quant_fp8_perchannelgroup128_sym_sgl

    def apply(self, input_tensor):
        shape = (input_tensor.shape[0], self.weight.shape[0])
        dtype = input_tensor.dtype
        device = input_tensor.device
        output_tensor = torch.empty(shape, dtype=dtype, device=device, requires_grad=False)
        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
        deep_gemm.gemm_fp8_fp8_bf16_nt(
            (input_tensor_quant, input_tensor_scale),
            (self.weight, self.weight_scale),
            output_tensor,
        )
        if hasattr(self, "bias") and self.bias is not None:
            output_tensor.add_(self.bias)
        return output_tensor
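

# Sketch of the per-token, per-group(128) activation scaling the Deepgemm path
# above relies on: each row is split into groups of 128 features and every group
# gets its own dynamic fp8 scale, matching the 128x128 block scales on the
# weight side. Illustrative only; the registered class uses the sgl-kernel
# quantization op and the fused deep_gemm GEMM.
def _reference_fp8_pergroup128_scales(x, group_size=128):
    finfo = torch.finfo(torch.float8_e4m3fn)
    tokens, features = x.shape
    assert features % group_size == 0
    groups = x.float().reshape(tokens, features // group_size, group_size)
    scales = groups.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / finfo.max
    x_fp8 = (groups / scales).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
    # Returns the quantized activation and one scale per (token, group) pair.
    return x_fp8.reshape(tokens, features), scales.squeeze(-1)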


@MM_WEIGHT_REGISTER("fp8-sgl")
class MMWeightWfp8channelAfp8channeldynamicSgl(MMWeightQuantTemplate):
    """
    Name: W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Sgl
    Quant MM:
        Weight: fp8 perchannel sym
        Act: fp8 perchannel dynamic sym
        Kernel: Sgl-kernel
    """

    def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
        super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter)
        self.load_func = self.load_fp8_perchannel_sym
        self.weight_need_transpose = True
        self.act_quant_func = self.act_quant_fp8_perchannel_sym_sgl

    def apply(self, input_tensor):
        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
        output_tensor = sgl_kernel.fp8_scaled_mm(
            input_tensor_quant,
            self.weight,
            input_tensor_scale,
            self.weight_scale,
            self.infer_dtype,
            self.bias if self.bias is not None else None,
        )
        return output_tensor


@MM_WEIGHT_REGISTER("int8-sgl")
class MMWeightWint8channelAint8channeldynamicSglActVllm(MMWeightQuantTemplate):
    """
    Name: W-int8-channel-sym-A-int8-channel-sym-dynamic-Sgl-ActVllm
    Quant MM:
        Weight: int8 perchannel sym
        Act: int8 perchannel dynamic sym
        Kernel: quant-mm using Sgl-kernel, act dynamic quant using vllm
    """

    def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
        super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter)
        self.load_func = self.load_int8_perchannel_sym
        self.weight_need_transpose = True
        self.act_quant_func = self.act_quant_int8_perchannel_sym_vllm

    def apply(self, input_tensor):
        shape = (input_tensor.shape[0], self.weight.shape[1])
        dtype = input_tensor.dtype
        device = input_tensor.device
        output_tensor = torch.empty(shape, dtype=dtype, device=device, requires_grad=False)
        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
        output_tensor = sgl_kernel.int8_scaled_mm(
            input_tensor_quant,
            self.weight,
            input_tensor_scale,
            self.weight_scale,
            self.infer_dtype,
            self.bias if self.bias is not None else None,
        )
        return output_tensor


@MM_WEIGHT_REGISTER("int8-torchao")
class MMWeightWint8channelAint8channeldynamicTorchao(MMWeightQuantTemplate):
    """
    Name: W-int8-channel-sym-A-int8-channel-sym-dynamic-Torchao
    Quant MM:
        Weight: int8 perchannel sym
        Act: int8 perchannel dynamic sym
        Kernel: Torchao
    """

    def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
        super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter)
        self.load_func = self.load_int8_perchannel_sym
        self.weight_need_transpose = True
        self.act_quant_func = self.act_quant_int8_perchannel_sym_torchao

    def apply(self, input_tensor):
        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
        output_tensor = quant_int8_per_token_matmul(input_tensor_quant, input_tensor_scale, self.weight, self.weight_scale.t().float(), output_dtype=self.infer_dtype)
        if self.bias is not None:
            output_tensor = output_tensor + self.bias
        return output_tensor


class MMWeightGGUFTemplate(MMWeightTemplate):
    def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
        super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter)

    def load(self, weight_dict):
        if not self.lazy_load:
            assert not self.create_cuda_buffer, "GGUF Unsupported offload block"
            self.weight = weight_dict[self.weight_name]
            weight_shape = self.weight.shape
            weight_dtype = self.weight.dtype
            if isinstance(self.weight, GGMLTensor):
                self.pin_weight = GGMLTensor.empty_pinned(weight_shape, orig_shape=self.weight.orig_shape, dtype=weight_dtype, gguf_type=self.weight.gguf_type)
                self.pin_weight.copy_from(self.weight)
            else:
                self.pin_weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype)
                self.pin_weight.copy_(weight_dict[self.weight_name])
            if self.bias_name is not None:
                self.bias = weight_dict[self.bias_name]
                if isinstance(self.bias, GGMLTensor):
                    self.pin_bias = GGMLTensor.empty_pinned(self.bias.shape, orig_shape=self.bias.orig_shape, dtype=self.bias.dtype, gguf_type=self.bias.gguf_type)
                    self.pin_bias.copy_from(self.bias)
                else:
                    self.pin_bias = torch.empty(self.bias.shape, pin_memory=True, dtype=self.bias.dtype)
                    self.pin_bias.copy_(weight_dict[self.bias_name])
            else:
                self.bias = None

    def load_state_dict(self, destination, block_index, adapter_block_index=None):
        if self.is_post_adapter:
            assert adapter_block_index is not None
            weight_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.weight_name, count=1)
        else:
            weight_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.weight_name, count=1)
        if weight_name not in destination:
            self.weight = None
            return
        self.weight = self.weight_cuda_buffer.copy_(destination[weight_name], non_blocking=True)
        if self.bias_name is not None:
            if self.is_post_adapter:
                assert adapter_block_index is not None
                bias_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.bias_name, count=1)
            else:
                bias_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.bias_name, count=1)
            self.bias = self.bias_cuda_buffer.copy_(destination[bias_name], non_blocking=True)
        else:
            self.bias = None

    def state_dict(self, destination=None):
        if destination is None:
            destination = {}
        destination[self.weight_name] = self.pin_weight if hasattr(self, "pin_weight") else self.weight
        if self.bias_name is not None:
            destination[self.bias_name] = self.pin_bias if hasattr(self, "pin_bias") else self.bias
        return destination

    def get_weight(self, tensor, dtype):
        if tensor is None:
            return
        weight = gguf_dequantize_tensor(tensor, dtype)
        if isinstance(weight, GGMLTensor):
            weight = torch.Tensor(weight)
        return weight

    def cast_bias_weight(self, input_tensor=None, dtype=None, device=None, bias_dtype=None):
        if input_tensor is not None:
            if dtype is None:
                dtype = getattr(input_tensor, "dtype", torch.float32)
        bias = None
        if self.bias is not None:
            bias = self.get_weight(self.bias, dtype)
        weight = self.get_weight(self.weight, dtype)
        return weight, bias

    def apply(self, input_tensor):
        weight, bias = self.cast_bias_weight(input_tensor)
        return torch.nn.functional.linear(input_tensor, weight, bias)


@MM_WEIGHT_REGISTER("gguf-BF16")
class MMWeightGGUFBF16(MMWeightGGUFTemplate):
    qtype = gguf.GGMLQuantizationType.BF16


@MM_WEIGHT_REGISTER("gguf-Q8_0")
class MMWeightGGUFQ80(MMWeightGGUFTemplate):
    qtype = gguf.GGMLQuantizationType.Q8_0


@MM_WEIGHT_REGISTER("gguf-Q6_K")
class MMWeightGGUFQ6K(MMWeightGGUFTemplate):
    qtype = gguf.GGMLQuantizationType.Q6_K


@MM_WEIGHT_REGISTER("gguf-Q5_K_S")
class MMWeightGGUFQ5KS(MMWeightGGUFTemplate):
    qtype = gguf.GGMLQuantizationType.Q6_K


@MM_WEIGHT_REGISTER("gguf-Q5_K_M")
class MMWeightGGUFQ5KM(MMWeightGGUFTemplate):
    qtype = gguf.GGMLQuantizationType.Q6_K


@MM_WEIGHT_REGISTER("gguf-Q5_1")
class MMWeightGGUFQ51(MMWeightGGUFTemplate):
    qtype = gguf.GGMLQuantizationType.Q5_1


@MM_WEIGHT_REGISTER("gguf-Q5_0")
class MMWeightGGUFQ50(MMWeightGGUFTemplate):
    qtype = gguf.GGMLQuantizationType.Q5_0


@MM_WEIGHT_REGISTER("gguf-Q4_K_M")
class MMWeightGGUFQ4KM(MMWeightGGUFTemplate):
    qtype = gguf.GGMLQuantizationType.Q5_0


@MM_WEIGHT_REGISTER("gguf-Q4_K_S")
class MMWeightGGUFQ4KS(MMWeightGGUFTemplate):
    qtype = gguf.GGMLQuantizationType.Q4_K


@MM_WEIGHT_REGISTER("gguf-Q4_1")
class MMWeightGGUFQ41(MMWeightGGUFTemplate):
    qtype = gguf.GGMLQuantizationType.Q4_1


@MM_WEIGHT_REGISTER("gguf-Q4_0")
class MMWeightGGUFQ40(MMWeightGGUFTemplate):
    qtype = gguf.GGMLQuantizationType.Q4_0


@MM_WEIGHT_REGISTER("gguf-Q3_K_M")
class MMWeightGGUFQ3KM(MMWeightGGUFTemplate):
    qtype = gguf.GGMLQuantizationType.Q3_K


@MM_WEIGHT_REGISTER("gguf-Q3_K_S")
class MMWeightGGUFQ3KS(MMWeightGGUFTemplate):
    qtype = gguf.GGMLQuantizationType.Q2_K
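

# Rough sketch of what gguf_dequantize_tensor does for the simplest type above,
# Q8_0, where the raw buffer is a sequence of 32-element blocks holding one fp16
# scale plus 32 int8 values and dequantization is just scale * q. The arguments
# here assume the scales and quants have already been split out of the packed
# blocks; the K-quants (Q6_K, Q5_K, ...) use more involved layouts with
# sub-block scales.
def _reference_dequant_q8_0(scales_fp16, quants_int8, out_dtype=torch.float32):
    # scales_fp16: [num_blocks], quants_int8: [num_blocks, 32]
    return (scales_fp16.to(torch.float32).unsqueeze(-1) * quants_int8.to(torch.float32)).to(out_dtype)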


@MM_WEIGHT_REGISTER("int4-g128-marlin")
class MMWeightWint4group128Marlin(MMWeightQuantTemplate):
    """
    Name: W-int4-group128-sym-Marlin
    Quant int4 x FP16:
        Weight: int4 pergroup sym
        Kernel: Marlin
    """

    def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
        super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter)
        self.load_func = self.load_quantized

    def load(self, weight_dict):
        assert not self.lazy_load
        self.load_func(weight_dict)
        self.workspace = weight_dict[f"{self.weight_name}_workspace"]
        if self.bias_name is not None:
            bias_shape = weight_dict[self.bias_name].shape
            bias_dtype = weight_dict[self.bias_name].dtype
            self.bias = torch.empty(bias_shape, pin_memory=True, dtype=bias_dtype)
            self.bias.copy_(weight_dict[self.bias_name])
        else:
            self.bias = None

    def apply(self, input_tensor):
        output_tensor = torch.empty(input_tensor.shape[:-1] + (self.weight_scale.shape[1],), dtype=input_tensor.dtype, device=input_tensor.device)
        marlin_cuda_quant.mul(input_tensor, self.weight, output_tensor, self.weight_scale.half(), self.workspace, -1, -1, -1, -1)
        if hasattr(self, "bias") and self.bias is not None:
            output_tensor.add_(self.bias)
        return output_tensor
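

# Sketch of the numeric format behind the Marlin path above: int4 weights with a
# per-group fp16 scale (group size 128 along the input dimension), dequantized
# as (q - 8) * scale under the usual GPTQ-style symmetric convention where the
# 4-bit codes live in [0, 15] with an implicit zero point of 8. Illustrative
# only; Marlin additionally applies an interleaved bit packing that this sketch
# sidesteps by taking already-unpacked integer codes.
def _reference_dequant_int4_group128(q_unpacked, scales, group_size=128):
    # q_unpacked: [in_features, out_features] integer codes in [0, 15]
    # scales:     [in_features // group_size, out_features]
    in_features, out_features = q_unpacked.shape
    w = (q_unpacked.float() - 8.0).reshape(in_features // group_size, group_size, out_features)
    return (w * scales.float().unsqueeze(1)).reshape(in_features, out_features)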