LightX2V (xuwx1) · Commit 954df466 (unverified)

Support SVG Attention (#374)

Authored Oct 17, 2025 by Yang Yong (雍洋); committed by GitHub on Oct 17, 2025. Parent: 51e102fe
Showing 6 changed files with 514 additions and 20 deletions:

- configs/attentions/wan_i2v_svg.json (+13, -0)
- lightx2v/common/ops/attn/__init__.py (+7, -6)
- lightx2v/common/ops/attn/svg_attn.py (+408, -0)
- lightx2v/models/networks/qwen_image/weights/transformer_weights.py (+0, -2)
- lightx2v/models/networks/wan/weights/transformer_weights.py (+10, -12)
- lightx2v/utils/print_atten_score.py (+76, -0)
configs/attentions/wan_i2v_svg.json (new file, mode 100755)

{
    "infer_steps": 40,
    "target_video_length": 81,
    "target_height": 480,
    "target_width": 832,
    "self_attn_1_type": "svg_attn",
    "cross_attn_1_type": "flash_attn3",
    "cross_attn_2_type": "flash_attn3",
    "sample_guide_scale": 5,
    "sample_shift": 3,
    "enable_cfg": true,
    "cpu_offload": false
}
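
The *_type strings in this config are resolved through ATTN_WEIGHT_REGISTER when the Wan transformer weights are built (see the transformer_weights.py hunk later in this commit): self-attention is routed to the new SVG implementation while both cross-attentions stay on FlashAttention-3. A minimal sketch of how such a file is consumed; the loader below is illustrative and is not the project's actual entry point:

```python
import json

# Illustrative only: resolve the attention backends named in the config.
# ATTN_WEIGHT_REGISTER is the real registry used by this commit; the rest is a sketch.
from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER

with open("configs/attentions/wan_i2v_svg.json") as f:
    cfg = json.load(f)

self_attn_cls = ATTN_WEIGHT_REGISTER[cfg["self_attn_1_type"]]    # "svg_attn" -> SvgAttnWeight
cross_attn_cls = ATTN_WEIGHT_REGISTER[cfg["cross_attn_1_type"]]  # "flash_attn3" -> presumably FlashAttn3Weight
```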
lightx2v/common/ops/attn/__init__.py

from .flash_attn import *
from .radial_attn import *
from .ring_attn import *
from .sage_attn import *
from .torch_sdpa import *
from .ulysses_attn import *
from .flash_attn import FlashAttn2Weight, FlashAttn3Weight
from .radial_attn import RadialAttnWeight
from .ring_attn import RingAttnWeight
from .sage_attn import SageAttn2Weight
from .svg_attn import SvgAttnWeight
from .torch_sdpa import TorchSDPAWeight
from .ulysses_attn import UlyssesAttnWeight
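
The new SvgAttnWeight export is picked up through ATTN_WEIGHT_REGISTER. The registry itself lives in lightx2v/utils/registry_factory.py and is not part of this diff; judging from how it is used here (decorator registration plus dict-style lookup), it behaves roughly like the sketch below, which is an approximation rather than the real implementation:

```python
# Illustrative approximation of the registry pattern used in this commit.
class Register:
    def __init__(self):
        self._classes = {}

    def __call__(self, name):
        # Used as a decorator factory: @ATTN_WEIGHT_REGISTER("svg_attn")
        def decorator(cls):
            self._classes[name] = cls
            return cls
        return decorator

    def __getitem__(self, name):
        # Used for lookup: ATTN_WEIGHT_REGISTER["svg_attn"]
        return self._classes[name]


ATTN_WEIGHT_REGISTER = Register()


@ATTN_WEIGHT_REGISTER("svg_attn")
class SvgAttnWeight:  # stand-in for the real class defined in svg_attn.py below
    pass


assert ATTN_WEIGHT_REGISTER["svg_attn"] is SvgAttnWeight
```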
lightx2v/common/ops/attn/svg_attn.py (new file, mode 100644)

import math
from functools import lru_cache
from math import ceil

import torch
import torch.nn.functional as F
import triton
import triton.language as tl
from loguru import logger
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER

from .template import AttnWeightTemplate

@triton.jit
def wan_hidden_states_placement_kernel(
    hidden_states_ptr,  # [cfg, num_heads, seq_len, head_dim] seq_len = context_length + num_frame * frame_size
    hidden_states_out_ptr,  # [cfg, num_heads, seq_len, head_dim]
    best_mask_idx_ptr,  # [cfg, num_heads]
    hidden_states_stride_b,
    hidden_states_stride_h,
    hidden_states_stride_s,
    hidden_states_stride_d,
    mask_idx_stride_b,
    mask_idx_stride_h,
    seq_len: tl.constexpr,
    head_dim: tl.constexpr,
    context_length: tl.constexpr,
    num_frame: tl.constexpr,
    frame_size: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    # Copy hidden_states to output
    # range: [b, h, block_id * block_size: block_id * block_size + block_size, :]
    cfg = tl.program_id(0)
    head = tl.program_id(1)
    block_id = tl.program_id(2)
    start_id = block_id * BLOCK_SIZE
    end_id = start_id + BLOCK_SIZE
    end_id = tl.where(end_id > seq_len, seq_len, end_id)

    # Load best mask idx (0 is spatial, 1 is temporal)
    is_temporal = tl.load(best_mask_idx_ptr + cfg * mask_idx_stride_b + head * mask_idx_stride_h)

    offset_token = tl.arange(0, BLOCK_SIZE) + start_id
    offset_mask = offset_token < seq_len
    offset_d = tl.arange(0, head_dim)

    if is_temporal:
        patch_id = offset_token // num_frame
        frame_id = offset_token - patch_id * num_frame
        offset_store_token = tl.where(offset_token >= seq_len - context_length, offset_token, frame_id * frame_size + patch_id)

        offset_load = (cfg * hidden_states_stride_b + head * hidden_states_stride_h + offset_token[:, None] * hidden_states_stride_s) + offset_d[None, :] * hidden_states_stride_d
        offset_hidden_states = hidden_states_ptr + offset_load

        offset_store = (cfg * hidden_states_stride_b + head * hidden_states_stride_h + offset_store_token[:, None] * hidden_states_stride_s) + offset_d[None, :] * hidden_states_stride_d
        offset_hidden_states_out = hidden_states_out_ptr + offset_store

        # Maybe tune the pipeline here
        hidden_states = tl.load(offset_hidden_states, mask=offset_mask[:, None])
        tl.store(offset_hidden_states_out, hidden_states, mask=offset_mask[:, None])
    else:
        offset_load = (cfg * hidden_states_stride_b + head * hidden_states_stride_h + offset_token[:, None] * hidden_states_stride_s) + offset_d[None, :] * hidden_states_stride_d
        offset_hidden_states = hidden_states_ptr + offset_load

        offset_store = offset_load
        offset_hidden_states_out = hidden_states_out_ptr + offset_store

        # Maybe tune the pipeline here
        hidden_states = tl.load(offset_hidden_states, mask=offset_mask[:, None])
        tl.store(offset_hidden_states_out, hidden_states, mask=offset_mask[:, None])


def wan_hidden_states_placement(hidden_states, hidden_states_out, best_mask_idx, context_length, num_frame, frame_size):
    cfg, num_heads, seq_len, head_dim = hidden_states.shape
    BLOCK_SIZE = 128
    assert seq_len == context_length + num_frame * frame_size

    grid = (cfg, num_heads, (seq_len + BLOCK_SIZE - 1) // BLOCK_SIZE)

    wan_hidden_states_placement_kernel[grid](
        hidden_states,
        hidden_states_out,
        best_mask_idx,
        hidden_states.stride(0),
        hidden_states.stride(1),
        hidden_states.stride(2),
        hidden_states.stride(3),
        best_mask_idx.stride(0),
        best_mask_idx.stride(1),
        seq_len,
        head_dim,
        context_length,
        num_frame,
        frame_size,
        BLOCK_SIZE,
    )

    return hidden_states_out

@triton.jit
def wan_sparse_head_placement_kernel(
    query_ptr,
    key_ptr,
    value_ptr,  # [cfg, num_heads, seq_len, head_dim] seq_len = context_length + num_frame * frame_size
    query_out_ptr,
    key_out_ptr,
    value_out_ptr,  # [cfg, num_heads, seq_len, head_dim]
    best_mask_idx_ptr,  # [cfg, num_heads]
    query_stride_b,
    query_stride_h,
    query_stride_s,
    query_stride_d,
    mask_idx_stride_b,
    mask_idx_stride_h,
    seq_len: tl.constexpr,
    head_dim: tl.constexpr,
    context_length: tl.constexpr,
    num_frame: tl.constexpr,
    frame_size: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    # Copy query, key, value to output
    # range: [b, h, block_id * block_size: block_id * block_size + block_size, :]
    cfg = tl.program_id(0)
    head = tl.program_id(1)
    block_id = tl.program_id(2)
    start_id = block_id * BLOCK_SIZE
    end_id = start_id + BLOCK_SIZE
    end_id = tl.where(end_id > seq_len, seq_len, end_id)

    # Load best mask idx (0 is spatial, 1 is temporal)
    is_temporal = tl.load(best_mask_idx_ptr + cfg * mask_idx_stride_b + head * mask_idx_stride_h)

    offset_token = tl.arange(0, BLOCK_SIZE) + start_id
    offset_mask = offset_token < seq_len
    offset_d = tl.arange(0, head_dim)

    if is_temporal:
        frame_id = offset_token // frame_size
        patch_id = offset_token - frame_id * frame_size
        offset_store_token = tl.where(offset_token >= seq_len - context_length, offset_token, patch_id * num_frame + frame_id)

        offset_load = (cfg * query_stride_b + head * query_stride_h + offset_token[:, None] * query_stride_s) + offset_d[None, :] * query_stride_d
        offset_query = query_ptr + offset_load
        offset_key = key_ptr + offset_load
        offset_value = value_ptr + offset_load

        offset_store = (cfg * query_stride_b + head * query_stride_h + offset_store_token[:, None] * query_stride_s) + offset_d[None, :] * query_stride_d
        offset_query_out = query_out_ptr + offset_store
        offset_key_out = key_out_ptr + offset_store
        offset_value_out = value_out_ptr + offset_store

        # Maybe tune the pipeline here
        query = tl.load(offset_query, mask=offset_mask[:, None])
        tl.store(offset_query_out, query, mask=offset_mask[:, None])
        key = tl.load(offset_key, mask=offset_mask[:, None])
        tl.store(offset_key_out, key, mask=offset_mask[:, None])
        value = tl.load(offset_value, mask=offset_mask[:, None])
        tl.store(offset_value_out, value, mask=offset_mask[:, None])
    else:
        offset_load = (cfg * query_stride_b + head * query_stride_h + offset_token[:, None] * query_stride_s) + offset_d[None, :] * query_stride_d
        offset_query = query_ptr + offset_load
        offset_key = key_ptr + offset_load
        offset_value = value_ptr + offset_load

        offset_store = offset_load
        offset_query_out = query_out_ptr + offset_store
        offset_key_out = key_out_ptr + offset_store
        offset_value_out = value_out_ptr + offset_store

        # Maybe tune the pipeline here
        query = tl.load(offset_query, mask=offset_mask[:, None])
        tl.store(offset_query_out, query, mask=offset_mask[:, None])
        key = tl.load(offset_key, mask=offset_mask[:, None])
        tl.store(offset_key_out, key, mask=offset_mask[:, None])
        value = tl.load(offset_value, mask=offset_mask[:, None])
        tl.store(offset_value_out, value, mask=offset_mask[:, None])


def wan_sparse_head_placement(query, key, value, query_out, key_out, value_out, best_mask_idx, context_length, num_frame, frame_size):
    cfg, num_heads, seq_len, head_dim = query.shape
    BLOCK_SIZE = 128
    assert seq_len == context_length + num_frame * frame_size

    grid = (cfg, num_heads, (seq_len + BLOCK_SIZE - 1) // BLOCK_SIZE)

    wan_sparse_head_placement_kernel[grid](
        query,
        key,
        value,
        query_out,
        key_out,
        value_out,
        best_mask_idx,
        query.stride(0),
        query.stride(1),
        query.stride(2),
        query.stride(3),
        best_mask_idx.stride(0),
        best_mask_idx.stride(1),
        seq_len,
        head_dim,
        context_length,
        num_frame,
        frame_size,
        BLOCK_SIZE,
    )

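
# Reviewer note (not part of this patch): for heads classified as temporal, the index math in the
# two kernels above amounts to transposing the video tokens viewed as a [num_frame, frame_size]
# grid, while the context tokens at the tail of the sequence keep their positions. In plain PyTorch,
# the forward reorder for one (cfg, head) slice x of shape [seq_len, head_dim] would look roughly like:
#     video, ctx = x[: num_frame * frame_size], x[num_frame * frame_size :]
#     reordered = video.reshape(num_frame, frame_size, -1).transpose(0, 1).reshape(-1, x.shape[-1])
#     x_patch_major = torch.cat([reordered, ctx], dim=0)
# wan_hidden_states_placement then applies the inverse mapping to scatter the attention output
# back into frame-major order.
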
def generate_temporal_head_mask_mod(context_length: int = 226, prompt_length: int = 226, num_frames: int = 13, token_per_frame: int = 1350, mul: int = 2):
    def round_to_multiple(idx):
        return ceil(idx / 128) * 128

    def temporal_mask_mod(b, h, q_idx, kv_idx):
        two_frame = round_to_multiple(mul * token_per_frame)
        temporal_head_mask = torch.abs(q_idx - kv_idx) <= two_frame
        # return temporal_head_mask

        first_frame_mask = kv_idx < token_per_frame
        video_mask = first_frame_mask | temporal_head_mask
        return video_mask

    return temporal_mask_mod

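
# Reviewer note (not part of this patch): the mask keeps a kv token if it belongs to the first
# frame (an attention sink) or if it lies within roughly `mul` frames of the query, with the band
# rounded up to a multiple of 128 to align with flex_attention's block granularity. For example,
# with token_per_frame=1350 and mul=2 the band is ceil(2700 / 128) * 128 = 2816 tokens:
#     mask_mod = generate_temporal_head_mask_mod(num_frames=21, token_per_frame=1350, mul=2)
#     mask_mod(0, 0, torch.tensor(10000), torch.tensor(500))   # kv in first frame      -> True
#     mask_mod(0, 0, torch.tensor(10000), torch.tensor(8000))  # |10000 - 8000| <= 2816 -> True
#     mask_mod(0, 0, torch.tensor(10000), torch.tensor(6000))  # outside the band       -> False
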
@lru_cache
def create_block_mask_cached(score_mod, B, H, M, N, device="cuda", _compile=False):
    block_mask = create_block_mask(score_mod, B, H, M, N, device=device, _compile=_compile)
    return block_mask


def prepare_flexattention(cfg_size, num_head, head_dim, dtype, device, context_length, prompt_length, num_frame, frame_size, diag_width=1, multiplier=2):
    assert diag_width == multiplier, f"{diag_width} is not equivalent to {multiplier}"
    seq_len = context_length + num_frame * frame_size
    mask_mod = generate_temporal_head_mask_mod(context_length, prompt_length, num_frame, frame_size, mul=multiplier)
    block_mask = create_block_mask_cached(mask_mod, None, None, seq_len, seq_len, device=device, _compile=True)
    return block_mask

def sparsity_to_width(sparsity, context_length, num_frame, frame_size):
    seq_len = context_length + num_frame * frame_size
    total_elements = seq_len**2

    sparsity = (sparsity * total_elements - 2 * seq_len * context_length) / total_elements

    width = seq_len * (1 - math.sqrt(1 - sparsity))
    width_frame = width / frame_size

    return width_frame

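
# Reviewer note (not part of this patch): the width formula appears to come from sizing a diagonal
# band. A band of width w over an L x L attention map covers roughly L^2 - (L - w)^2 entries, so
# requiring the band to cover a fraction s' of the map gives (L - w)^2 = (1 - s') * L^2, i.e.
# w = L * (1 - sqrt(1 - s')). The adjustment s' = s - 2 * L * context_length / L^2 first discounts
# the rows and columns attributed to the context tokens, and the final division by frame_size
# converts the band width from tokens to frames (it is then passed as `multiplier`/`diag_width`
# to prepare_flexattention in SvgAttnWeight.prepare_mask below).
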
def get_attention_mask(mask_name, sample_mse_max_row, context_length, num_frame, frame_size):
    attention_mask = torch.zeros((context_length + num_frame * frame_size, context_length + num_frame * frame_size), device="cpu")

    # TODO: fix hard coded mask
    if mask_name == "spatial":
        pixel_attn_mask = torch.zeros_like(attention_mask, dtype=torch.bool, device="cpu")
        pixel_attn_mask[:, :frame_size] = 1  # First Frame Sink
        block_size, block_thres = 128, frame_size * 2
        num_block = math.ceil(num_frame * frame_size / block_size)
        for i in range(num_block):
            for j in range(num_block):
                if abs(i - j) < block_thres // block_size:
                    pixel_attn_mask[i * block_size : (i + 1) * block_size, j * block_size : (j + 1) * block_size] = 1
        attention_mask = pixel_attn_mask
    else:
        pixel_attn_mask = torch.zeros_like(attention_mask, dtype=torch.bool, device="cpu")
        pixel_attn_mask[:, :frame_size] = 1  # First Frame Sink

        block_size, block_thres = 128, frame_size * 2
        num_block = math.ceil(num_frame * frame_size / block_size)
        for i in range(num_block):
            for j in range(num_block):
                if abs(i - j) < block_thres // block_size:
                    pixel_attn_mask[i * block_size : (i + 1) * block_size, j * block_size : (j + 1) * block_size] = 1

        pixel_attn_mask = pixel_attn_mask.reshape(frame_size, num_frame, frame_size, num_frame).permute(1, 0, 3, 2).reshape(frame_size * num_frame, frame_size * num_frame)
        attention_mask = pixel_attn_mask

    attention_mask = attention_mask[:sample_mse_max_row].cuda()
    return attention_mask

@ATTN_WEIGHT_REGISTER("svg_attn")
class SvgAttnWeight(AttnWeightTemplate):
    head_num = None
    head_dim = None
    sample_mse_max_row = None
    num_sampled_rows = None
    context_length = None
    num_frame = None
    frame_size = None
    sparsity = None
    mask_name_list = ["spatial", "temporal"]
    attention_masks = None
    block_mask = None

    @classmethod
    def prepare(cls, head_num, head_dim, sample_mse_max_row, num_sampled_rows, context_length, sparsity):
        cls.head_num = head_num
        cls.head_dim = head_dim
        cls.sample_mse_max_row = sample_mse_max_row
        cls.num_sampled_rows = num_sampled_rows
        cls.context_length = context_length
        cls.sparsity = sparsity
        torch._dynamo.config.cache_size_limit = 192 * 3
        torch._dynamo.config.accumulated_cache_size_limit = 192 * 3
        logger.info(
            f"SvgAttnWeight Prepare: head_num={head_num}, head_dim={head_dim}, sample_mse_max_row={sample_mse_max_row}, "
            f"num_sampled_rows={num_sampled_rows}, context_length={context_length}, sparsity={sparsity}"
        )

    def __init__(self):
        self.config = {}
        self.sparse_attention = torch.compile(flex_attention, dynamic=False, mode="max-autotune-no-cudagraphs")

    @classmethod
    def prepare_mask(cls, num_frame, frame_size):
        # Use class attributes so updates affect all instances of this class
        if num_frame == cls.num_frame and frame_size == cls.frame_size:
            return
        cls.num_frame = num_frame
        cls.frame_size = frame_size
        cls.attention_masks = [get_attention_mask(mask_name, cls.sample_mse_max_row, cls.context_length, num_frame, frame_size) for mask_name in cls.mask_name_list]
        multiplier = diag_width = sparsity_to_width(cls.sparsity, cls.context_length, num_frame, frame_size)
        cls.block_mask = prepare_flexattention(
            1, cls.head_num, cls.head_dim, torch.bfloat16, "cuda", cls.context_length, cls.context_length, num_frame, frame_size, diag_width=diag_width, multiplier=multiplier
        )
        logger.info(f"SvgAttnWeight Update: num_frame={num_frame}, frame_size={frame_size}")

    def apply(
        self,
        q,
        k,
        v,
        cu_seqlens_q=None,
        cu_seqlens_kv=None,
        max_seqlen_q=None,
        max_seqlen_kv=None,
        model_cls=None,
    ):
        q = q.unsqueeze(0).transpose(1, 2)
        k = k.unsqueeze(0).transpose(1, 2)
        v = v.unsqueeze(0).transpose(1, 2)
        bs, num_heads, seq_len, dim = q.size()

        num_frame = 21
        self.prepare_mask(num_frame=num_frame, frame_size=seq_len // num_frame)

        sampled_mses = self.sample_mse(q, k, v)
        best_mask_idx = torch.argmin(sampled_mses, dim=0)

        output_hidden_states = torch.zeros_like(q)
        query_out, key_out, value_out = torch.zeros_like(q), torch.zeros_like(k), torch.zeros_like(v)

        query_out, key_out, value_out = self.fast_sparse_head_placement(q, k, v, query_out, key_out, value_out, best_mask_idx, self.context_length, self.num_frame, self.frame_size)

        hidden_states = self.sparse_attention(query_out, key_out, value_out)

        wan_hidden_states_placement(hidden_states, output_hidden_states, best_mask_idx, self.context_length, self.num_frame, self.frame_size)

        return output_hidden_states.reshape(bs, num_heads, seq_len, dim).transpose(1, 2).reshape(bs * seq_len, -1)

    def fast_sparse_head_placement(self, query, key, value, query_out, key_out, value_out, best_mask_idx, context_length, num_frame, frame_size):
        wan_sparse_head_placement(query, key, value, query_out, key_out, value_out, best_mask_idx, context_length, num_frame, frame_size)
        return query_out, key_out, value_out

    def sample_mse(self, query, key, value):
        cfg, num_heads, seq_len, dim = query.size()

        num_sampled_rows = min(self.num_sampled_rows, seq_len)
        sampled_rows = torch.randint(low=0, high=self.sample_mse_max_row, size=(num_sampled_rows,))
        sampled_q = query[:, :, sampled_rows, :]
        sampled_qk_scores = torch.matmul(sampled_q, key.transpose(-2, -1)) / (dim**0.5)

        sampled_attn_weights = F.softmax(sampled_qk_scores, dim=-1)
        sampled_golden_hidden_states = torch.matmul(sampled_attn_weights, value)  # (1, seq_len, dim)

        sampled_mses = torch.zeros(len(self.attention_masks), cfg, num_heads, device=query.device, dtype=query.dtype)

        # Only have Tri-diagonal and Striped
        for mask_idx, attn_mask in enumerate(self.attention_masks):
            sampled_attention_mask = attn_mask[sampled_rows, :]
            sampled_attention_scores = sampled_qk_scores.masked_fill(sampled_attention_mask == 0, float("-inf"))
            sampled_attn_weights = F.softmax(sampled_attention_scores, dim=-1)
            sampled_hidden_states = torch.matmul(sampled_attn_weights, value)
            mse = torch.mean((sampled_hidden_states - sampled_golden_hidden_states) ** 2, dim=(2, 3))
            sampled_mses[mask_idx] = mse

        return sampled_mses

if __name__ == "__main__":
    q, k, v = (
        torch.randn(32130, 40, 128, dtype=torch.bfloat16).cuda(),
        torch.randn(32130, 40, 128, dtype=torch.bfloat16).cuda(),
        torch.randn(32130, 40, 128, dtype=torch.bfloat16).cuda(),
    )
    SvgAttnWeight.prepare(head_num=40, head_dim=128, sample_mse_max_row=10000, num_sampled_rows=64, context_length=0, sparsity=0.25)
    svg_attn = SvgAttnWeight()
    print("SvgAttnWeight initialized.")
    out = svg_attn.apply(q, k, v)
    print(f"out: {out.shape}, {out.dtype}, {out.device}")
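
A quick shape check on the smoke test above (editor annotation): apply() hard-codes num_frame = 21, and the test's sequence length 32130 = 21 × 1530 is consistent with that (1530 tokens per frame, context_length = 0). The inputs have shape [seq_len, num_heads, head_dim] = [32130, 40, 128], and the return value is reshaped to [bs * seq_len, num_heads * head_dim] = [32130, 5120], which is presumably the flattened token-by-hidden layout expected by the calling attention code.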
lightx2v/models/networks/qwen_image/weights/transformer_weights.py

...
@@ -29,7 +29,6 @@ class QwenImageTransformerAttentionBlock(WeightModule):
        self.task = task
        self.config = config
        self.quant_method = config.get("quant_method", None)
        self.sparge = config.get("sparge", False)
        self.lazy_load = self.config.get("lazy_load", False)
        if self.lazy_load:
...
@@ -139,7 +138,6 @@ class QwenImageCrossAttention(WeightModule):
        self.task = task
        self.config = config
        self.quant_method = config.get("quant_method", None)
        self.sparge = config.get("sparge", False)
        self.attn_type = config.get("attn_type", "flash_attn3")
        self.heads = config["attention_out_dim"] // config["attention_dim_head"]
...
lightx2v/models/networks/wan/weights/transformer_weights.py

import os

import torch
from safetensors import safe_open

from lightx2v.common.modules.weight_module import WeightModule, WeightModuleList
...
@@ -56,7 +55,6 @@ class WanTransformerAttentionBlock(WeightModule):
        self.task = task
        self.config = config
        self.quant_method = config.get("quant_method", None)
        self.sparge = config.get("sparge", False)
        self.lazy_load = self.config.get("lazy_load", False)
        if self.lazy_load:
...
@@ -108,7 +106,6 @@ class WanSelfAttention(WeightModule):
        self.task = task
        self.config = config
        self.quant_method = config.get("quant_method", None)
        self.sparge = config.get("sparge", False)
        self.lazy_load = lazy_load
        self.lazy_load_file = lazy_load_file
...
@@ -185,16 +182,17 @@ class WanSelfAttention(WeightModule):
                    self.lazy_load_file,
                ),
            )
        if self.sparge:
            assert self.config["sparge_ckpt"], "sparge_ckpt must be set when sparge is True"
            self.add_module("self_attn_1", ATTN_WEIGHT_REGISTER["Sparge"](f"{block_prefix}.{self.block_index}"))
        attention_weights_cls = ATTN_WEIGHT_REGISTER[self.config["self_attn_1_type"]]
        if self.config["self_attn_1_type"] == "svg_attn":
            attention_weights_cls.prepare(
                head_num=self.config["num_heads"],
                head_dim=self.config["dim"] // self.config["num_heads"],
                sample_mse_max_row=self.config.get("svg_sample_mse_max_row", 10000),
                num_sampled_rows=self.config.get("svg_num_sampled_rows", 64),
                context_length=self.config.get("svg_context_length", 0),
                sparsity=self.config.get("svg_sparsity", 0.25),
            )
            sparge_ckpt = torch.load(self.config["sparge_ckpt"])
            self.self_attn_1.load(sparge_ckpt)
        else:
            self.add_module("self_attn_1", ATTN_WEIGHT_REGISTER[self.config["self_attn_1_type"]]())
        self.add_module("self_attn_1", attention_weights_cls())
        if self.config["seq_parallel"]:
            self.add_module("self_attn_1_parallel", ATTN_WEIGHT_REGISTER[self.config["parallel"].get("seq_p_attn_type", "ulysses")]())
...
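
The prepare() call in the hunk above reads four optional config keys, each with a default. Whether a deployment overrides them is up to the user; the snippet below simply restates the key names and defaults from the diff as a Python dict for reference (it is not part of the patch):

```python
# Optional SVG tuning keys read by WanSelfAttention when self_attn_1_type == "svg_attn"
# (names and defaults taken from the hunk above); collected here for reference only.
svg_defaults = {
    "svg_sample_mse_max_row": 10000,  # cap on mask rows / sampled row indices used by sample_mse
    "svg_num_sampled_rows": 64,       # query rows sampled when classifying heads as spatial vs temporal
    "svg_context_length": 0,          # extra context tokens assumed to follow the video tokens
    "svg_sparsity": 0.25,             # sparsity budget fed to sparsity_to_width
}
```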
lightx2v/utils/print_atten_score.py (new file, mode 100644)

import math

import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F


def scaled_dot_product_attention(Q, K, V, mask=None):
    """
    Scaled dot-product attention

    Args:
        Q: Query tensor [batch_size, num_heads, seq_len, d_k]
        K: Key tensor [batch_size, num_heads, seq_len, d_k]
        V: Value tensor [batch_size, num_heads, seq_len, d_k]
        mask: Attention mask (0 indicates positions to mask, 1 indicates positions to keep)

    Returns:
        output: Attention output
        attention_weights: Attention weights
    """
    d_k = Q.size(-1)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)

    if mask is not None:
        mask_value = torch.where(mask == 0, torch.tensor(-float("inf")), torch.tensor(0.0))
        scores = scores + mask_value

    attention_weights = F.softmax(scores, dim=-1)
    output = torch.matmul(attention_weights, V)

    return output, scores, attention_weights


def draw_matrix(weights, save_path):
    plt.imshow(weights, aspect="auto", cmap="viridis")
    plt.colorbar()
    plt.savefig(save_path)
    plt.close()


def get_qkv_subset(x, head_index, token_start, token_end):
    """
    x : [seq_len, num_heads, head_dim]
    return: [batch_size, num_heads, seq_len, head_dim]
    batch_size = 1, num_heads = 1, seq_len = token_end - token_start
    """
    x = x[token_start:token_end, head_index, :]  # [seq_len, head_dim]
    x = x.unsqueeze(0).unsqueeze(0)  # [1, 1, seq_len, head_dim]
    return x


def draw_attention_weights(q, k, v, head_index, token_start, token_end, save_path):
    """
    q k v : [seq_len, num_heads, head_dim]
    """
    q_vis = get_qkv_subset(q, head_index=head_index, token_start=token_start, token_end=token_end)
    k_vis = get_qkv_subset(k, head_index=head_index, token_start=token_start, token_end=token_end)
    v_vis = get_qkv_subset(v, head_index=head_index, token_start=token_start, token_end=token_end)
    output, scores, attention_weights = scaled_dot_product_attention(q_vis, k_vis, v_vis, mask=None)
    draw_matrix(scores[0][0].float().cpu().numpy(), save_path)
    print(f"Saved to {save_path}")


if __name__ == "__main__":
    seq_len = 10
    num_heads = 4
    head_dim = 8
    q = torch.randn(seq_len, num_heads, head_dim)
    k = torch.randn(seq_len, num_heads, head_dim)
    v = torch.randn(seq_len, num_heads, head_dim)
    draw_attention_weights(q, k, v, head_index=0, token_start=0, token_end=10, save_path="scores.png")