xuwx1 / LightX2V · Commits

Commit 1a881d63
Authored Jul 28, 2025 by helloyongyang
Parent: 18e2b23a

Commit message: Refactor parallel modules (重构并行模块)

47 files changed in this commit; this page shows 20 of them, with 587 additions and 570 deletions (+587 −570).
lightx2v/attentions/distributed/ulysses/wrap.py  +0 −71
lightx2v/attentions/distributed/utils/__init__.py  +0 −0
lightx2v/attentions/distributed/utils/hunyuan/processor.py  +0 −72
lightx2v/attentions/distributed/utils/process.py  +0 −72
lightx2v/attentions/distributed/utils/wan/processor.py  +0 −37
lightx2v/common/ops/attn/__init__.py  +7 −1
lightx2v/common/ops/attn/attn_weight.py  +0 −292
lightx2v/common/ops/attn/flash_attn.py  +73 −0
lightx2v/common/ops/attn/radial_attn.py  +38 −0
lightx2v/common/ops/attn/ring_attn.py  +173 −0
lightx2v/common/ops/attn/sage_attn.py  +61 −0
lightx2v/common/ops/attn/sparge_attn.py  +64 −0
lightx2v/common/ops/attn/template.py  +29 −0
lightx2v/common/ops/attn/torch_sdpa.py  +38 −0
lightx2v/common/ops/attn/ulysses_attn.py  +98 −0
lightx2v/common/ops/attn/utils/all2all.py  +0 −0
lightx2v/common/ops/attn/utils/ring_comm.py  +0 −0
lightx2v/models/input_encoders/hf/xlm_roberta/model.py  +3 −2
lightx2v/models/networks/hunyuan/model.py  +0 −10
lightx2v/models/networks/wan/audio_model.py  +3 −13
lightx2v/attentions/distributed/ulysses/wrap.py (deleted, mode 100644 → 0)
import functools

from lightx2v.attentions.distributed.ulysses.attn import ulysses_attn


def parallelize_hunyuan(hunyuan_model):
    """Parallelize the Hunyuan model's inference with the Ulysses attention mechanism.

    Args:
        hunyuan_model: a Hunyuan model instance exposing the inference method and related attributes.
    """
    from lightx2v.attentions.distributed.utils.hunyuan.processor import pre_process, post_process

    # Swap the model's parallel attention for Ulysses attention
    hunyuan_model.transformer_infer.parallel_attention = ulysses_attn
    # Keep a reference to the original inference method so it can be called later
    original_infer = hunyuan_model.infer

    @functools.wraps(hunyuan_model.__class__.infer)  # preserve the original method's metadata
    def new_infer(self, text_encoders_output, image_encoder_output, args):
        """Wrapped inference: shard the inputs, run the original infer, then gather the outputs.

        Args:
            self: the Hunyuan model instance
            text_encoders_output: output of the text encoders
            image_encoder_output: output of the image encoder
            args: remaining arguments

        Returns:
            None
        """
        # Save the original latents and rotary frequencies
        self.scheduler.ori_latents, self.scheduler.ori_freqs_cos, self.scheduler.ori_freqs_sin = (
            self.scheduler.latents,
            self.scheduler.freqs_cos,
            self.scheduler.freqs_sin,
        )
        # Pre-process (shard) the inputs for parallel computation
        self.scheduler.latents, self.scheduler.freqs_cos, self.scheduler.freqs_sin, split_dim = pre_process(
            self.scheduler.latents, self.scheduler.freqs_cos, self.scheduler.freqs_sin
        )
        # Run the original inference method
        original_infer(text_encoders_output, image_encoder_output, args)
        # Post-process (gather) the output
        self.scheduler.noise_pred = post_process(self.scheduler.noise_pred, split_dim)
        # Restore the original latents and rotary frequencies
        self.scheduler.latents, self.scheduler.freqs_cos, self.scheduler.freqs_sin = (
            self.scheduler.ori_latents,
            self.scheduler.ori_freqs_cos,
            self.scheduler.ori_freqs_sin,
        )
        # return combined_output  # returning the processed output (currently commented out)

    # Bind the wrapper to the model instance and install it as the inference method
    new_infer = new_infer.__get__(hunyuan_model)
    hunyuan_model.infer = new_infer


def parallelize_wan(wan_model):
    from lightx2v.attentions.distributed.utils.wan.processor import pre_process, post_process

    wan_model.transformer_infer.parallel_attention = ulysses_attn
    original_infer = wan_model.transformer_infer.infer

    @functools.wraps(wan_model.transformer_infer.__class__.infer)  # preserve the original method's metadata
    def new_infer(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
        x = pre_process(x)
        x = original_infer(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)
        x = post_process(x)
        return x

    # Bind the wrapper to the transformer_infer instance and install it
    new_infer = new_infer.__get__(wan_model.transformer_infer)
    wan_model.transformer_infer.infer = new_infer
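For readers unfamiliar with the binding trick both wrappers use: `functools.wraps` copies the original method's metadata onto the wrapper, and `function.__get__(instance)` turns a plain function into a method bound to that one instance, so the override affects only that object. A minimal self-contained sketch of the same pattern, with a hypothetical `Model` class:

```python
import functools

class Model:
    def infer(self, x):
        return x * 2

model = Model()
original_infer = model.infer  # bound method; keeps a handle on the original

@functools.wraps(Model.infer)  # the wrapper keeps the original name and docstring
def new_infer(self, x):
    x = x + 1                   # stands in for pre_process
    out = original_infer(x)     # call the saved original
    return out - 1              # stands in for post_process

# __get__ binds the plain function to this instance, so `self` is passed automatically
model.infer = new_infer.__get__(model)

assert model.infer(3) == 7  # (3 + 1) * 2 - 1
assert model.infer.__name__ == "infer"
```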
lightx2v/attentions/distributed/utils/__init__.py (deleted, mode 100644 → 0; the file was empty)
lightx2v/attentions/distributed/utils/hunyuan/processor.py (deleted, mode 100644 → 0)
import torch
import torch.distributed as dist


def pre_process(latent_model_input, freqs_cos, freqs_sin):
    """Shard the latent input and the rotary frequencies for distributed computation.

    Args:
        latent_model_input (torch.Tensor): latents of shape [batch_size, channels, temporal_size, height, width]
        freqs_cos (torch.Tensor): cosine frequency table
        freqs_sin (torch.Tensor): sine frequency table

    Returns:
        tuple: the sharded latent_model_input, freqs_cos and freqs_sin, plus the split dimension split_dim
    """
    # World size and rank of the current process
    world_size = dist.get_world_size()
    cur_rank = dist.get_rank()

    # Choose a split dimension that divides evenly across ranks
    if latent_model_input.shape[-2] // 2 % world_size == 0:
        split_dim = -2  # split along height
    elif latent_model_input.shape[-1] // 2 % world_size == 0:
        split_dim = -1  # split along width
    else:
        raise ValueError(f"Cannot split video sequence into world size ({world_size}) parts evenly")

    # Temporal size and the (patchified) height and width
    temporal_size, h, w = latent_model_input.shape[2], latent_model_input.shape[3] // 2, latent_model_input.shape[4] // 2

    # Keep only this rank's chunk of the latents
    latent_model_input = torch.chunk(latent_model_input, world_size, dim=split_dim)[cur_rank]

    # Shard the cosine frequencies to match
    dim_thw = freqs_cos.shape[-1]  # last dimension of the frequency table
    freqs_cos = freqs_cos.reshape(temporal_size, h, w, dim_thw)  # reshape to [temporal_size, height, width, dim_thw]
    freqs_cos = torch.chunk(freqs_cos, world_size, dim=split_dim - 1)[cur_rank]  # take this rank's chunk
    freqs_cos = freqs_cos.reshape(-1, dim_thw)  # flatten back to [seq_len, dim_thw]

    # Shard the sine frequencies the same way
    dim_thw = freqs_sin.shape[-1]
    freqs_sin = freqs_sin.reshape(temporal_size, h, w, dim_thw)
    freqs_sin = torch.chunk(freqs_sin, world_size, dim=split_dim - 1)[cur_rank]
    freqs_sin = freqs_sin.reshape(-1, dim_thw)

    return latent_model_input, freqs_cos, freqs_sin, split_dim


def post_process(output, split_dim):
    """Gather the outputs of all ranks and concatenate them back together.

    Args:
        output (torch.Tensor): this rank's output
        split_dim (int): the dimension along which the input was split, used to re-assemble the output

    Returns:
        torch.Tensor: the combined output
    """
    world_size = dist.get_world_size()
    # Buffers for the outputs of all ranks
    gathered_outputs = [torch.empty_like(output) for _ in range(world_size)]
    # Collect every rank's output
    dist.all_gather(gathered_outputs, output)
    # Concatenate along the split dimension
    combined_output = torch.cat(gathered_outputs, dim=split_dim)
    return combined_output
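A single-process sketch of the shard/gather round trip this pair of functions implements, with an assumed world size of 2; `torch.chunk` and `torch.cat` stand in for the per-rank selection and the `all_gather`:

```python
import torch

world_size = 2
latents = torch.randn(1, 16, 4, 8, 6)  # [batch, channels, temporal, height, width]

# height // 2 % world_size == 0, so the height axis would be chosen
assert latents.shape[-2] // 2 % world_size == 0
split_dim = -2

# pre_process: each "rank" keeps one chunk along the chosen dimension
shards = torch.chunk(latents, world_size, dim=split_dim)
assert shards[0].shape == (1, 16, 4, 4, 6)

# post_process: concatenating the gathered chunks restores the full tensor
restored = torch.cat(shards, dim=split_dim)
assert torch.equal(restored, latents)
```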
lightx2v/attentions/distributed/utils/process.py (deleted, mode 100644 → 0)
(Contents identical to lightx2v/attentions/distributed/utils/hunyuan/processor.py above: the same pre_process and post_process functions, deleted here as well.)
lightx2v/attentions/distributed/utils/wan/processor.py (deleted, mode 100644 → 0)
import torch
import torch.distributed as dist
import torch.nn.functional as F

PADDING_SIZE = None


def pre_process(x):
    world_size = dist.get_world_size()
    cur_rank = dist.get_rank()
    # Amount of padding needed so the sequence length divides evenly across ranks
    padding_size = (world_size - (x.shape[0] % world_size)) % world_size
    if padding_size > 0:
        # Pad the first (sequence) dimension; F.pad's tuple runs from the last dimension backwards
        x = F.pad(x, (0, 0, 0, padding_size))
    # Keep only this rank's chunk
    x = torch.chunk(x, world_size, dim=0)[cur_rank]
    return x


def post_process(x):
    world_size = dist.get_world_size()
    # Buffers for the outputs of all ranks
    gathered_x = [torch.empty_like(x) for _ in range(world_size)]
    # Collect every rank's output
    dist.all_gather(gathered_x, x)
    # Concatenate along the sequence dimension
    combined_output = torch.cat(gathered_x, dim=0)
    return combined_output
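The padding expression `(world_size - (x.shape[0] % world_size)) % world_size` rounds the sequence length up to the next multiple of `world_size`; the outer modulo makes it add nothing when the length already divides evenly. A quick check:

```python
world_size = 4
for n in range(1, 9):
    padding_size = (world_size - (n % world_size)) % world_size
    assert (n + padding_size) % world_size == 0
    print(n, "->", n + padding_size)
# 1..4 are padded up to 4, 5..8 up to 8; 4 and 8 get no padding
```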
lightx2v/common/ops/attn/__init__.py
-from .attn_weight import *
+from .flash_attn import *
+from .radial_attn import *
+from .ring_attn import *
+from .sage_attn import *
+from .torch_sdpa import *
+from .ulysses_attn import *
+from .sparge_attn import *
lightx2v/common/ops/attn/attn_weight.py (deleted, mode 100755 → 0)
import torch
import torch.nn as nn
from abc import ABCMeta, abstractmethod
from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER
import torch.nn.functional as F
from loguru import logger

try:
    from spas_sage_attn.autotune import SparseAttentionMeansim
except ImportError:
    logger.info("SparseAttentionMeansim not found, please install sparge first")
    SparseAttentionMeansim = None

try:
    from flash_attn.flash_attn_interface import flash_attn_varlen_func
except ImportError:
    logger.info("flash_attn_varlen_func not found, please install flash_attn2 first")
    flash_attn_varlen_func = None

try:
    from flash_attn_interface import flash_attn_varlen_func as flash_attn_varlen_func_v3
except ImportError:
    logger.info("flash_attn_varlen_func_v3 not found, please install flash_attn3 first")
    flash_attn_varlen_func_v3 = None

if torch.cuda.get_device_capability(0) == (8, 9):
    try:
        from sageattention import sageattn_qk_int8_pv_fp16_triton as sageattn
    except ImportError:
        logger.info("sageattn not found, please install sageattention first")
        sageattn = None
else:
    try:
        from sageattention import sageattn
    except ImportError:
        logger.info("sageattn not found, please install sageattention first")
        sageattn = None

from lightx2v.attentions.common.radial_attn import radial_attn


class AttnWeightTemplate(metaclass=ABCMeta):
    def __init__(self, weight_name):
        self.weight_name = weight_name
        self.config = {}

    def load(self, weight_dict):
        pass

    @abstractmethod
    def apply(self, input_tensor):
        pass

    def set_config(self, config=None):
        if config is not None:
            self.config = config

    def to_cpu(self, non_blocking=False):
        pass

    def to_cuda(self, non_blocking=False):
        pass

    def state_dict(self, destination=None):
        if destination is None:
            destination = {}
        return destination


@ATTN_WEIGHT_REGISTER("flash_attn2")
class FlashAttn2Weight(AttnWeightTemplate):
    def __init__(self):
        self.config = {}

    def apply(self, q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, model_cls=None, mask_map=None):
        x = flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv).reshape(max_seqlen_q, -1)
        return x


@ATTN_WEIGHT_REGISTER("flash_attn3")
class FlashAttn3Weight(AttnWeightTemplate):
    def __init__(self):
        self.config = {}

    def apply(self, q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, model_cls=None, mask_map=None):
        # flash_attn3 returns a tuple; the attention output is the first element
        x = flash_attn_varlen_func_v3(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)[0].reshape(max_seqlen_q, -1)
        return x


@ATTN_WEIGHT_REGISTER("radial_attn")
class RadialAttnWeight(AttnWeightTemplate):
    def __init__(self):
        self.config = {}

    def apply(self, q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, mask_map=None, sparsity_type="radial", block_size=128, decay_factor=1, model_cls="wan"):
        assert len(q.shape) == 3
        x = radial_attn(
            q, k, v,
            mask_map=mask_map,
            sparsity_type=sparsity_type,
            block_size=block_size,
            model_cls=model_cls[:3],  # use the first 3 characters to match "wan", "wan2", etc.
            decay_factor=decay_factor,
        )
        x = x.view(max_seqlen_q, -1)
        return x


@ATTN_WEIGHT_REGISTER("sage_attn2")
class SageAttn2Weight(AttnWeightTemplate):
    def __init__(self):
        self.config = {}

    def apply(self, q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, model_cls=None, mask_map=None):
        q, k, v = q.contiguous(), k.contiguous(), v.contiguous()
        if model_cls == "hunyuan":
            # Run the two packed segments on either side of the cu_seqlens boundary separately
            x1 = sageattn(
                q[: cu_seqlens_q[1]].unsqueeze(0),
                k[: cu_seqlens_kv[1]].unsqueeze(0),
                v[: cu_seqlens_kv[1]].unsqueeze(0),
                tensor_layout="NHD",
            )
            x2 = sageattn(
                q[cu_seqlens_q[1] :].unsqueeze(0),
                k[cu_seqlens_kv[1] :].unsqueeze(0),
                v[cu_seqlens_kv[1] :].unsqueeze(0),
                tensor_layout="NHD",
            )
            x = torch.cat((x1, x2), dim=1)
            x = x.view(max_seqlen_q, -1)
        elif model_cls in ["wan2.1", "wan2.1_distill", "wan2.1_causvid", "wan2.1_df"]:
            x = sageattn(q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0), tensor_layout="NHD")
            x = x.view(max_seqlen_q, -1)
        return x


@ATTN_WEIGHT_REGISTER("torch_sdpa")
class TorchSDPAWeight(AttnWeightTemplate):
    def __init__(self):
        self.config = {}

    def apply(self, q, k, v, drop_rate=0, attn_mask=None, causal=False, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, model_cls=None, mask_map=None):
        q, k, v = q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        if attn_mask is not None and attn_mask.dtype != torch.bool:
            attn_mask = attn_mask.to(q.dtype)
        x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal)
        x = x.transpose(1, 2)
        b, s, a, d = x.shape
        out = x.reshape(b, s, -1)
        return out.squeeze(0)


@ATTN_WEIGHT_REGISTER("Sparge")
class SpargeAttnWeight(AttnWeightTemplate):
    def __init__(self, weight_name, verbose=False, l1=0.07, pv_l1=0.08, tune_pv=True, inner_attn_type="flash_attn3"):
        self.verbose = (verbose,)
        self.l1 = (l1,)
        self.pv_l1 = (pv_l1,)
        self.tune_pv = (tune_pv,)
        self.inner_attn_type = inner_attn_type
        self.inner_cls = SparseAttentionMeansim(l1=l1, pv_l1=pv_l1, tune_pv=tune_pv)
        super().__init__(weight_name)

    def load(self, weight_dict):
        # match all keys with prefix weight_name
        for key in weight_dict.keys():
            if key.startswith(self.weight_name):
                sub_name = key.split(".")[-1]
                setattr(
                    self.inner_cls,
                    sub_name,
                    nn.Parameter(weight_dict[key], requires_grad=False),
                )

    def apply(self, q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, model_cls=None):
        if len(q.shape) == 3:
            q = q.unsqueeze(0)
            k = k.unsqueeze(0)
            v = v.unsqueeze(0)
        x = self.inner_cls(q, k, v, tensor_layout="NHD")
        x = x.flatten(2)
        x = x.squeeze(0)
        return x
lightx2v/common/ops/attn/flash_attn.py (new file, mode 100644)
from loguru import logger

try:
    import flash_attn
    from flash_attn.flash_attn_interface import flash_attn_varlen_func
except ImportError:
    logger.info("flash_attn_varlen_func not found, please install flash_attn2 first")
    flash_attn_varlen_func = None

try:
    from flash_attn_interface import flash_attn_varlen_func as flash_attn_varlen_func_v3
except ImportError:
    logger.info("flash_attn_varlen_func_v3 not found, please install flash_attn3 first")
    flash_attn_varlen_func_v3 = None

from .template import AttnWeightTemplate
from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER


@ATTN_WEIGHT_REGISTER("flash_attn2")
class FlashAttn2Weight(AttnWeightTemplate):
    def __init__(self):
        self.config = {}

    def apply(self, q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, model_cls=None, mask_map=None):
        x = flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv).reshape(max_seqlen_q, -1)
        return x


@ATTN_WEIGHT_REGISTER("flash_attn3")
class FlashAttn3Weight(AttnWeightTemplate):
    def __init__(self):
        self.config = {}

    def apply(self, q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, model_cls=None, mask_map=None):
        # flash_attn3 returns a tuple; the attention output is the first element
        x = flash_attn_varlen_func_v3(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)[0].reshape(max_seqlen_q, -1)
        return x
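Both backends use the varlen calling convention: variable-length sequences are packed into a single `[total_tokens, heads, head_dim]` tensor, and `cu_seqlens_q`/`cu_seqlens_kv` hold the cumulative token offsets that mark the sample boundaries. A small sketch of how such offsets are typically built (the lengths below are made up):

```python
import torch

seq_lens = torch.tensor([37, 80, 11], dtype=torch.int32)  # per-sample token counts
cu_seqlens = torch.zeros(len(seq_lens) + 1, dtype=torch.int32)
cu_seqlens[1:] = torch.cumsum(seq_lens, dim=0)
print(cu_seqlens)                 # tensor([  0,  37, 117, 128], dtype=torch.int32)
max_seqlen = int(seq_lens.max())  # 80; passed as max_seqlen_q / max_seqlen_kv
```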
lightx2v/attentions/common/radial_attn.py → lightx2v/common/ops/attn/radial_attn.py (moved, mode 100755 → 100644)
 import torch
+from .template import AttnWeightTemplate
+from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER

 try:
     import flashinfer
...
@@ -15,6 +17,42 @@ except ImportError:
 ###

+@ATTN_WEIGHT_REGISTER("radial_attn")
+class RadialAttnWeight(AttnWeightTemplate):
+    def __init__(self):
+        self.config = {}
+
+    def apply(self, q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, mask_map=None, sparsity_type="radial", block_size=128, decay_factor=1, model_cls="wan"):
+        assert len(q.shape) == 3
+        x = radial_attn(
+            q, k, v,
+            mask_map=mask_map,
+            sparsity_type=sparsity_type,
+            block_size=block_size,
+            model_cls=model_cls[:3],  # use the first 3 characters to match "wan", "wan2", etc.
+            decay_factor=decay_factor,
+        )
+        x = x.view(max_seqlen_q, -1)
+        return x

 def radial_attn(query, key, value, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, mask_map=None, sparsity_type="radial", block_size=128, decay_factor=1, model_cls="wan"):
...
lightx2v/common/ops/attn/ring_attn.py (new file, mode 100644)
import torch
import torch.distributed as dist
import torch.nn.functional as F
from loguru import logger

from .template import AttnWeightTemplate
from .utils.ring_comm import RingComm
from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER

try:
    import flash_attn
    from flash_attn.flash_attn_interface import flash_attn_varlen_func
except ImportError:
    logger.info("flash_attn_varlen_func not found, please install flash_attn2 first")
    flash_attn_varlen_func = None


@torch.jit.script
def _update_out_and_lse(out, lse, block_out, block_lse):
    block_out = block_out.to(torch.float32)
    block_lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1)
    # Numerically stable form of:
    #   new_lse = lse + torch.log(1 + torch.exp(block_lse - lse))
    #   out = torch.exp(lse - new_lse) * out + torch.exp(block_lse - new_lse) * block_out
    # For additional context and discussion, please refer to:
    # https://github.com/zhuzilin/ring-flash-attention/pull/34#issuecomment-2076126795
    out = out - F.sigmoid(block_lse - lse) * (out - block_out)
    lse = lse - F.logsigmoid(lse - block_lse)
    return out, lse


@ATTN_WEIGHT_REGISTER("ring")
class RingAttnWeight(AttnWeightTemplate):
    def __init__(self):
        self.config = {}

    def apply(self, q, k, v, img_qkv_len, cu_seqlens_qkv, attention_module=None):
        """Run ring attention over the combined image and text queries, keys and values.

        Args:
            q (torch.Tensor): queries of shape [shard_seqlen, heads, hidden_dims]
            k (torch.Tensor): keys of shape [shard_seqlen, heads, hidden_dims]
            v (torch.Tensor): values of shape [shard_seqlen, heads, hidden_dims]
            img_qkv_len (int): length of the image portion of q/k/v
            cu_seqlens_qkv (torch.Tensor): cumulative sequence lengths covering the text and image portions

        Returns:
            torch.Tensor: the attention output
        """
        # Rank of the current process and total number of processes
        cur_rank = dist.get_rank()
        world_size = dist.get_world_size()

        if len(cu_seqlens_qkv) == 3:
            txt_qkv_len = cu_seqlens_qkv[1] - img_qkv_len  # length of the text q/k/v
            txt_mask_len = cu_seqlens_qkv[2] - img_qkv_len  # length of the text mask
        elif len(cu_seqlens_qkv) == 2:
            txt_qkv_len = cu_seqlens_qkv[1] - img_qkv_len  # length of the text q/k/v
            txt_mask_len = 0

        # if RING_COMM is None:
        #     init_ring_comm()
        RING_COMM = RingComm()

        q = q.unsqueeze(0)
        k = k.unsqueeze(0)
        v = v.unsqueeze(0)
        # Split the image and text portions
        img_q, img_k, img_v = q[:, :img_qkv_len, :, :].contiguous(), k[:, :img_qkv_len, :, :].contiguous(), v[:, :img_qkv_len, :, :].contiguous()
        txt_q, txt_k, txt_v = (
            q[:, img_qkv_len : img_qkv_len + txt_qkv_len, :, :].contiguous(),
            k[:, img_qkv_len : img_qkv_len + txt_qkv_len, :, :].contiguous(),
            v[:, img_qkv_len : img_qkv_len + txt_qkv_len, :, :].contiguous(),
        )
        out, lse, next_k, next_v = None, None, None, None

        if len(cu_seqlens_qkv) == 3:
            q = torch.cat((img_q, txt_q), dim=1)
        k = img_k
        v = img_v
        # Ring loop: attend to one rank's K/V shard per step while the next shard is in flight
        for step in range(world_size):
            if step + 1 != world_size:
                next_k = RING_COMM.send_recv(k)
                next_v = RING_COMM.send_recv(v)
                RING_COMM.commit()
            if step + 1 == world_size:
                # Last step: append the text keys and values
                k = torch.cat((k, txt_k), dim=1)
                v = torch.cat((v, txt_v), dim=1)
            block_out, block_lse = self.ring_attn_sub(q, k, v)
            out, lse = self.update_out_and_lse(out, lse, block_out, block_lse)
            if step + 1 != world_size:
                RING_COMM.wait()
                k = next_k
                v = next_v

        attn1 = out.to(torch.bfloat16).squeeze(0).reshape(img_qkv_len + txt_qkv_len, -1)
        if txt_mask_len > 0:
            # Attend separately over the masked text tail
            attn2, *_ = flash_attn.flash_attn_interface._flash_attn_forward(
                q[:, -(txt_mask_len - txt_qkv_len) :, :, :].contiguous(),
                k[:, -(txt_mask_len - txt_qkv_len) :, :, :].contiguous(),
                v[:, -(txt_mask_len - txt_qkv_len) :, :, :].contiguous(),
                dropout_p=0.0,
                softmax_scale=q.shape[-1] ** (-0.5),
                causal=False,
                window_size_left=-1,
                window_size_right=-1,
                softcap=0.0,
                alibi_slopes=None,
                return_softmax=False,
            )
            attn2 = attn2.to(torch.bfloat16).squeeze(0).reshape((txt_mask_len - txt_qkv_len), -1)
            attn1 = torch.cat([attn1, attn2], dim=0)
        return attn1

    def ring_attn_sub(self, q, k, v, dropout_p=0.0, softmax_scale=None, causal=False, window_size=(-1, -1), softcap=0.0, alibi_slopes=None, return_softmax=False):
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)
        block_out, block_lse, _, _ = flash_attn.flash_attn_interface._flash_attn_forward(
            q,
            k,
            v,
            dropout_p=dropout_p,
            softmax_scale=softmax_scale,
            causal=causal,
            window_size_left=window_size[0],
            window_size_right=window_size[1],
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            return_softmax=return_softmax,
        )
        return block_out, block_lse

    def update_out_and_lse(self, out, lse, block_out, block_lse, slice_=None):
        if out is None:
            if slice_ is not None:
                raise RuntimeError("first update_out_and_lse should not pass slice_ args")
            out = block_out.to(torch.float32)
            lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1)
        elif slice_ is not None:
            slice_out, slice_lse = out[slice_], lse[slice_]
            slice_out, slice_lse = _update_out_and_lse(slice_out, slice_lse, block_out, block_lse)
            out[slice_], lse[slice_] = slice_out, slice_lse
        else:
            out, lse = _update_out_and_lse(out, lse, block_out, block_lse)
        return out, lse
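The `_update_out_and_lse` identity is what lets ring attention fold in one K/V block at a time: `sigmoid(block_lse - lse)` is exactly the softmax weight the new block's scores deserve relative to the running ones, and `lse - logsigmoid(lse - block_lse)` is the log-sum-exp of all scores seen so far. A single-query, single-process check that merging two blocks this way reproduces full softmax attention:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
d = 8
q = torch.randn(1, d)
k1, v1 = torch.randn(5, d), torch.randn(5, 2)
k2, v2 = torch.randn(7, d), torch.randn(7, 2)
scale = d ** -0.5

def block_attn(q, k, v):
    scores = q @ k.T * scale
    return F.softmax(scores, dim=-1) @ v, torch.logsumexp(scores, dim=-1, keepdim=True)

out, lse = block_attn(q, k1, v1)               # first block initializes out/lse
block_out, block_lse = block_attn(q, k2, v2)   # second block arrives
out = out - torch.sigmoid(block_lse - lse) * (out - block_out)
lse = lse - F.logsigmoid(lse - block_lse)

full_out, _ = block_attn(q, torch.cat([k1, k2]), torch.cat([v1, v2]))
assert torch.allclose(out, full_out, atol=1e-6)
```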
lightx2v/common/ops/attn/sage_attn.py (new file, mode 100644)
import torch
from loguru import logger

from .template import AttnWeightTemplate
from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER

if torch.cuda.get_device_capability(0) == (8, 9):
    try:
        from sageattention import sageattn_qk_int8_pv_fp16_triton as sageattn
    except ImportError:
        logger.info("sageattn not found, please install sageattention first")
        sageattn = None
else:
    try:
        from sageattention import sageattn
    except ImportError:
        logger.info("sageattn not found, please install sageattention first")
        sageattn = None


@ATTN_WEIGHT_REGISTER("sage_attn2")
class SageAttn2Weight(AttnWeightTemplate):
    def __init__(self):
        self.config = {}

    def apply(self, q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, model_cls=None, mask_map=None):
        q, k, v = q.contiguous(), k.contiguous(), v.contiguous()
        if model_cls == "hunyuan":
            # Run the two packed segments on either side of the cu_seqlens boundary separately
            x1 = sageattn(
                q[: cu_seqlens_q[1]].unsqueeze(0),
                k[: cu_seqlens_kv[1]].unsqueeze(0),
                v[: cu_seqlens_kv[1]].unsqueeze(0),
                tensor_layout="NHD",
            )
            x2 = sageattn(
                q[cu_seqlens_q[1] :].unsqueeze(0),
                k[cu_seqlens_kv[1] :].unsqueeze(0),
                v[cu_seqlens_kv[1] :].unsqueeze(0),
                tensor_layout="NHD",
            )
            x = torch.cat((x1, x2), dim=1)
            x = x.view(max_seqlen_q, -1)
        elif model_cls in ["wan2.1", "wan2.1_distill", "wan2.1_causvid", "wan2.1_df"]:
            x = sageattn(q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0), tensor_layout="NHD")
            x = x.view(max_seqlen_q, -1)
        return x
lightx2v/common/ops/attn/sparge_attn.py (new file, mode 100644)
import torch
from .template import AttnWeightTemplate
from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER
from loguru import logger
import torch.nn as nn

try:
    from spas_sage_attn.autotune import SparseAttentionMeansim
except ImportError:
    logger.info("SparseAttentionMeansim not found, please install sparge first")
    SparseAttentionMeansim = None


@ATTN_WEIGHT_REGISTER("Sparge")
class SpargeAttnWeight(AttnWeightTemplate):
    def __init__(self, weight_name, verbose=False, l1=0.07, pv_l1=0.08, tune_pv=True, inner_attn_type="flash_attn3"):
        self.verbose = (verbose,)
        self.l1 = (l1,)
        self.pv_l1 = (pv_l1,)
        self.tune_pv = (tune_pv,)
        self.inner_attn_type = inner_attn_type
        self.inner_cls = SparseAttentionMeansim(l1=l1, pv_l1=pv_l1, tune_pv=tune_pv)
        super().__init__(weight_name)

    def load(self, weight_dict):
        # match all keys with prefix weight_name
        for key in weight_dict.keys():
            if key.startswith(self.weight_name):
                sub_name = key.split(".")[-1]
                setattr(
                    self.inner_cls,
                    sub_name,
                    nn.Parameter(weight_dict[key], requires_grad=False),
                )

    def apply(self, q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, model_cls=None):
        if len(q.shape) == 3:
            q = q.unsqueeze(0)
            k = k.unsqueeze(0)
            v = v.unsqueeze(0)
        x = self.inner_cls(q, k, v, tensor_layout="NHD")
        x = x.flatten(2)
        x = x.squeeze(0)
        return x
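A hedged sketch of how `load` is meant to be used: every entry of `weight_dict` whose key starts with `weight_name` is copied onto `inner_cls` under its last dotted component, as a frozen parameter. The key and attribute names below are purely illustrative, not taken from real LightX2V checkpoints, and running this requires the spas_sage_attn package:

```python
import torch
from lightx2v.common.ops.attn.sparge_attn import SpargeAttnWeight

# Hypothetical tuned-threshold entries for two blocks; only blocks.0.attn.* match
weight_dict = {
    "blocks.0.attn.simthreshd1": torch.tensor(0.5),
    "blocks.0.attn.cdfthreshd": torch.tensor(0.9),
    "blocks.1.attn.simthreshd1": torch.tensor(0.4),
}
attn = SpargeAttnWeight(weight_name="blocks.0.attn")
attn.load(weight_dict)
# inner_cls now carries .simthreshd1 and .cdfthreshd as frozen nn.Parameters;
# the blocks.1.* keys are ignored because they do not match the prefix.
```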
lightx2v/common/ops/attn/template.py (new file, mode 100644)
from abc import ABCMeta, abstractmethod


class AttnWeightTemplate(metaclass=ABCMeta):
    def __init__(self, weight_name):
        self.weight_name = weight_name
        self.config = {}

    def load(self, weight_dict):
        pass

    @abstractmethod
    def apply(self, input_tensor):
        pass

    def set_config(self, config=None):
        if config is not None:
            self.config = config

    def to_cpu(self, non_blocking=False):
        pass

    def to_cuda(self, non_blocking=False):
        pass

    def state_dict(self, destination=None):
        if destination is None:
            destination = {}
        return destination
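All the backends in this directory register themselves against string names through `ATTN_WEIGHT_REGISTER`, so callers can select an attention implementation from a config value. The registry's actual implementation lives in `lightx2v.utils.registry_factory` and is not part of this diff; a minimal sketch of the decorator pattern it appears to follow:

```python
# Minimal sketch, assuming ATTN_WEIGHT_REGISTER behaves like a name -> class
# mapping (the real implementation is not shown in this commit).
class Register(dict):
    def __call__(self, name):
        def decorator(cls):
            self[name] = cls  # record the class under its string name
            return cls
        return decorator

ATTN_WEIGHT_REGISTER = Register()

@ATTN_WEIGHT_REGISTER("torch_sdpa")
class TorchSDPAWeight:
    pass  # stand-in for the real backend class

attn_cls = ATTN_WEIGHT_REGISTER["torch_sdpa"]  # look up a backend by config string
attn = attn_cls()
```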
lightx2v/common/ops/attn/torch_sdpa.py (new file, mode 100644)
import torch
import torch.nn.functional as F

from .template import AttnWeightTemplate
from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER


@ATTN_WEIGHT_REGISTER("torch_sdpa")
class TorchSDPAWeight(AttnWeightTemplate):
    def __init__(self):
        self.config = {}

    def apply(self, q, k, v, drop_rate=0, attn_mask=None, causal=False, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, model_cls=None, mask_map=None):
        if q.ndim == 3:
            # Add a batch dimension for unbatched [seq, heads, dim] inputs
            q, k, v = q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0)
        # SDPA expects [batch, heads, seq, dim]
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        if attn_mask is not None and attn_mask.dtype != torch.bool:
            attn_mask = attn_mask.to(q.dtype)
        x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal)
        x = x.transpose(1, 2)
        b, s, a, d = x.shape
        # Merge heads: [batch, seq, heads, dim] -> [batch, seq, heads * dim]
        out = x.reshape(b, s, -1)
        return out.squeeze(0)
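A usage sketch under assumed shapes: `apply` accepts unbatched `[seq, heads, head_dim]` tensors and returns `[seq, heads * head_dim]`:

```python
import torch
from lightx2v.common.ops.attn.torch_sdpa import TorchSDPAWeight

attn = TorchSDPAWeight()
q = torch.randn(128, 8, 64)
k = torch.randn(128, 8, 64)
v = torch.randn(128, 8, 64)
out = attn.apply(q=q, k=k, v=v)
print(out.shape)  # torch.Size([128, 512])
```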
lightx2v/common/ops/attn/ulysses_attn.py (new file, mode 100644)
import torch
import torch.distributed as dist

from .template import AttnWeightTemplate
from .utils.all2all import all2all_seq2head, all2all_head2seq
from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER


@ATTN_WEIGHT_REGISTER("ulysses")
class UlyssesAttnWeight(AttnWeightTemplate):
    def __init__(self):
        self.config = {}

    def apply(self, q, k, v, img_qkv_len, cu_seqlens_qkv, attention_module=None):
        """Run Ulysses attention over the combined image and text queries, keys and values.

        Args:
            q (torch.Tensor): queries of shape [shard_seqlen, heads, hidden_dims]
            k (torch.Tensor): keys of shape [shard_seqlen, heads, hidden_dims]
            v (torch.Tensor): values of shape [shard_seqlen, heads, hidden_dims]
            img_qkv_len (int): length of the image portion of q/k/v
            cu_seqlens_qkv (torch.Tensor): cumulative sequence lengths covering the text and image portions
            attention_module: the inner attention backend used for the actual computation

        Returns:
            torch.Tensor: the attention output
        """
        # Rank of the current process and total number of processes
        cur_rank = dist.get_rank()
        world_size = dist.get_world_size()

        # Sequence length and the text-related lengths
        seq_len = q.shape[0]
        if len(cu_seqlens_qkv) == 3:
            txt_qkv_len = cu_seqlens_qkv[1] - img_qkv_len  # length of the text q/k/v
            txt_mask_len = cu_seqlens_qkv[2] - img_qkv_len  # length of the text mask
        elif len(cu_seqlens_qkv) == 2:
            txt_qkv_len = cu_seqlens_qkv[1] - img_qkv_len  # length of the text q/k/v
            txt_mask_len = None

        # Head count and hidden dimension of the queries
        _, heads, hidden_dims = q.shape
        shard_heads = heads // world_size  # heads handled by each process
        shard_seqlen = img_qkv_len  # image sequence length handled by each process

        # Split the image and text portions
        img_q, img_k, img_v = q[:img_qkv_len, :, :].contiguous(), k[:img_qkv_len, :, :].contiguous(), v[:img_qkv_len, :, :].contiguous()
        txt_q, txt_k, txt_v = q[img_qkv_len:, :, :].contiguous(), k[img_qkv_len:, :, :].contiguous(), v[img_qkv_len:, :, :].contiguous()

        # All-to-all: image q/k/v go from sequence-sharded to head-sharded layout
        img_q = all2all_seq2head(img_q)
        img_k = all2all_seq2head(img_k)
        img_v = all2all_seq2head(img_v)
        torch.cuda.synchronize()  # make sure the communication has finished

        # The text tokens are replicated on every rank; keep only this rank's heads
        txt_q = txt_q[:, cur_rank * shard_heads : (cur_rank + 1) * shard_heads, :]
        txt_k = txt_k[:, cur_rank * shard_heads : (cur_rank + 1) * shard_heads, :]
        txt_v = txt_v[:, cur_rank * shard_heads : (cur_rank + 1) * shard_heads, :]

        # Concatenate the image and text portions along the sequence dimension
        q = torch.cat((img_q, txt_q), dim=0)
        k = torch.cat((img_k, txt_k), dim=0)
        v = torch.cat((img_v, txt_v), dim=0)

        # Cumulative sequence lengths for the inner attention call
        cu_seqlens_qkv = torch.zeros([2], dtype=torch.int32, device="cuda")
        s = txt_qkv_len + img_q.shape[0]  # combined text and image length
        s1 = s  # end position of the current sample
        cu_seqlens_qkv[1] = s1
        if txt_mask_len:
            s2 = txt_mask_len + img_q.shape[0]  # end position of the text mask
            cu_seqlens_qkv = torch.cat((cu_seqlens_qkv, s2.reshape(1).to(torch.int32)))  # append the mask end position
        max_seqlen_qkv = img_q.shape[0] + txt_q.shape[0]  # maximum sequence length

        # Run the inner attention backend
        attn = attention_module.apply(q=q, k=k, v=v, cu_seqlens_q=cu_seqlens_qkv, cu_seqlens_kv=cu_seqlens_qkv, max_seqlen_q=max_seqlen_qkv, max_seqlen_kv=max_seqlen_qkv)

        # Split the image and text attention results
        img_attn, txt_attn = attn[: img_q.shape[0], :], attn[img_q.shape[0] :, :]
        # Gather the text attention results from every rank
        gathered_txt_attn = [torch.empty_like(txt_attn) for _ in range(world_size)]
        dist.all_gather(gathered_txt_attn, txt_attn)

        # All-to-all back: image attention goes from head-sharded to sequence-sharded layout
        img_attn = img_attn.reshape(world_size * shard_seqlen, shard_heads, hidden_dims)
        img_attn = all2all_head2seq(img_attn)
        img_attn = img_attn.reshape(shard_seqlen, -1)
        torch.cuda.synchronize()  # make sure the communication has finished

        txt_attn = torch.cat(gathered_txt_attn, dim=1)  # merge the text heads from all ranks
        # Concatenate the image and text attention results
        attn = torch.cat([img_attn, txt_attn], dim=0)
        return attn
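Ulysses attention trades sequence parallelism for head parallelism around the inner attention call: before attention, each rank holds all heads for a slice of the image tokens; after the all-to-all, it holds a slice of the heads for the full image sequence. A single-process illustration of just the layout change (the real `all2all_seq2head`/`all2all_head2seq` in `utils/all2all.py` exchange data across ranks; this only mimics the resulting shapes):

```python
import torch

world_size, shard_seqlen, heads, dim = 4, 6, 8, 16
shard_heads = heads // world_size

# Before: one rank holds all 8 heads for its 6-token slice of the sequence
seq_sharded = torch.randn(shard_seqlen, heads, dim)

# After seq2head: the same rank holds 2 heads for all 24 tokens
head_sharded = (
    seq_sharded.reshape(shard_seqlen, world_size, shard_heads, dim)
    .permute(1, 0, 2, 3)
    .reshape(world_size * shard_seqlen, shard_heads, dim)
)
print(head_sharded.shape)  # torch.Size([24, 2, 16])
```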
lightx2v/attentions/distributed/comm/all2all.py → lightx2v/common/ops/attn/utils/all2all.py (file moved)
lightx2v/attentions/distributed/comm/ring_comm.py → lightx2v/common/ops/attn/utils/ring_comm.py (file moved)
lightx2v/models/input_encoders/hf/xlm_roberta/model.py
...
@@ -8,7 +8,8 @@ import torch.nn as nn
 import torch.nn.functional as F
 import torchvision.transforms as T
-from lightx2v.attentions import attention
+# from lightx2v.attentions import attention
+from lightx2v.common.ops.attn import TorchSDPAWeight
 from loguru import logger
 from lightx2v.models.input_encoders.hf.q_linear import VllmQuantLinearInt8, VllmQuantLinearFp8, TorchaoQuantLinearInt8, Q8FQuantLinearInt8, Q8FQuantLinearFp8
...
@@ -84,7 +85,7 @@ class SelfAttention(nn.Module):
         q, k, v = self.to_qkv(x).view(b, s, 3, n, d).unbind(2)

         # compute attention
-        x = attention(q=q, k=k, v=v, attention_type="torch_sdpa")
+        x = TorchSDPAWeight().apply(q=q, k=k, v=v)
         x = x.reshape(b, s, c)
         # output
...
lightx2v/models/networks/hunyuan/model.py
...
@@ -13,8 +13,6 @@ from lightx2v.models.networks.hunyuan.infer.feature_caching.transformer_infer im
     HunyuanTransformerInferAdaCaching,
     HunyuanTransformerInferCustomCaching,
 )
-import lightx2v.attentions.distributed.ulysses.wrap as ulysses_dist_wrap
-import lightx2v.attentions.distributed.ring.wrap as ring_dist_wrap
 from lightx2v.utils.envs import *
 from loguru import logger
 from safetensors import safe_open
...
@@ -41,14 +39,6 @@ class HunyuanModel:
         self._init_weights()
         self._init_infer()
-        if config["parallel_attn_type"]:
-            if config["parallel_attn_type"] == "ulysses":
-                ulysses_dist_wrap.parallelize_hunyuan(self)
-            elif config["parallel_attn_type"] == "ring":
-                ring_dist_wrap.parallelize_hunyuan(self)
-            else:
-                raise Exception(f"Unsupported parallel_attn_type")
         if self.config["cpu_offload"]:
             self.to_cpu()
...
lightx2v/models/networks/wan/audio_model.py
...
@@ -12,13 +12,8 @@ from lightx2v.models.networks.wan.infer.audio.pre_wan_audio_infer import WanAudi
 from lightx2v.models.networks.wan.infer.audio.post_wan_audio_infer import WanAudioPostInfer
 from lightx2v.models.networks.wan.infer.feature_caching.transformer_infer import WanTransformerInferTeaCaching
-from lightx2v.attentions.common.radial_attn import MaskMap
-from lightx2v.models.networks.wan.infer.transformer_infer import (
-    WanTransformerInfer,
-)
-from lightx2v.models.networks.wan.infer.feature_caching.transformer_infer import (
-    WanTransformerInferTeaCaching,
-)
 from safetensors import safe_open
+from lightx2v.common.ops.attn.radial_attn import MaskMap


 class WanAudioModel(WanModel):
...
@@ -30,14 +25,9 @@ class WanAudioModel(WanModel):
         super().__init__(model_path, config, device)

     def _init_infer_class(self):
         super()._init_infer_class()
         self.pre_infer_class = WanAudioPreInfer
         self.post_infer_class = WanAudioPostInfer
-        if self.config["feature_caching"] == "NoCaching":
-            self.transformer_infer_class = WanTransformerInfer
-        elif self.config["feature_caching"] == "Tea":
-            self.transformer_infer_class = WanTransformerInferTeaCaching
-        else:
-            raise NotImplementedError(f"Unsupported feature_caching type: {self.config['feature_caching']}")

     @torch.no_grad()
     def infer(self, inputs):
...