xuwx1 / LightX2V, commit 1a881d63

Refactor the parallel modules

Authored Jul 28, 2025 by helloyongyang
Parent commit: 18e2b23a
Changes: 47 files in this commit. This page shows 20 changed files, with 0 additions and 699 deletions (+0 -699).
lightx2v/attentions/__init__.py  +0 -20
lightx2v/attentions/common/flash_attn2.py  +0 -17
lightx2v/attentions/common/flash_attn3.py  +0 -17
lightx2v/attentions/common/sage_attn2.py  +0 -40
lightx2v/attentions/common/torch_sdpa.py  +0 -22
lightx2v/attentions/distributed/__init__.py  +0 -0
lightx2v/attentions/distributed/comm/__init__.py  +0 -0
lightx2v/attentions/distributed/partial_heads_attn/__init__.py  +0 -0
lightx2v/attentions/distributed/partial_heads_attn/attn.py  +0 -37
lightx2v/attentions/distributed/partial_heads_attn/tests/__init__.py  +0 -0
lightx2v/attentions/distributed/partial_heads_attn/tests/test.sh  +0 -3
lightx2v/attentions/distributed/partial_heads_attn/tests/test_acc.py  +0 -78
lightx2v/attentions/distributed/partial_heads_attn/wrap.py  +0 -5
lightx2v/attentions/distributed/ring/__init__.py  +0 -0
lightx2v/attentions/distributed/ring/attn.py  +0 -193
lightx2v/attentions/distributed/ring/tests/test.py  +0 -102
lightx2v/attentions/distributed/ring/tests/test.sh  +0 -3
lightx2v/attentions/distributed/ring/wrap.py  +0 -71
lightx2v/attentions/distributed/ulysses/__init__.py  +0 -0
lightx2v/attentions/distributed/ulysses/attn.py  +0 -91
lightx2v/attentions/__init__.py (deleted, mode 100755 → 0)
from lightx2v.attentions.common.torch_sdpa import torch_sdpa
from lightx2v.attentions.common.flash_attn2 import flash_attn2
from lightx2v.attentions.common.flash_attn3 import flash_attn3
from lightx2v.attentions.common.sage_attn2 import sage_attn2
from lightx2v.attentions.common.radial_attn import radial_attn


def attention(attention_type="flash_attn2", *args, **kwargs):
    if attention_type == "torch_sdpa":
        return torch_sdpa(*args, **kwargs)
    elif attention_type == "flash_attn2":
        return flash_attn2(*args, **kwargs)
    elif attention_type == "flash_attn3":
        return flash_attn3(*args, **kwargs)
    elif attention_type == "sage_attn2":
        return sage_attn2(*args, **kwargs)
    elif attention_type == "radial_attn":
        return radial_attn(*args, **kwargs)
    else:
        raise NotImplementedError(f"Unsupported attention mode: {attention_type}")
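For reference, a minimal usage sketch of this dispatcher for a single packed sequence; the shapes below mirror the ones used in the tests later in this commit, but are otherwise illustrative assumptions rather than repository code.

import torch

from lightx2v.attentions import attention

# Illustrative shapes (similar to tests/test_acc.py below): packed tokens, heads, head_dim.
seq_len, heads, head_dim = 1024, 24, 128
q = torch.randn(seq_len, heads, head_dim, dtype=torch.bfloat16, device="cuda")
k = torch.randn(seq_len, heads, head_dim, dtype=torch.bfloat16, device="cuda")
v = torch.randn(seq_len, heads, head_dim, dtype=torch.bfloat16, device="cuda")
cu_seqlens = torch.tensor([0, seq_len], dtype=torch.int32, device="cuda")

out = attention(
    attention_type="flash_attn2",
    q=q,
    k=k,
    v=v,
    cu_seqlens_q=cu_seqlens,
    cu_seqlens_kv=cu_seqlens,
    max_seqlen_q=seq_len,
    max_seqlen_kv=seq_len,
)
# With the flash_attn2 backend the result is reshaped to [seq_len, heads * head_dim].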
lightx2v/attentions/common/flash_attn2.py (deleted, mode 100644 → 0)
try:
    from flash_attn.flash_attn_interface import flash_attn_varlen_func
except ImportError:
    flash_attn_varlen_func = None


def flash_attn2(q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, model_cls=None):
    x = flash_attn_varlen_func(
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_kv,
        max_seqlen_q,
        max_seqlen_kv,
    ).reshape(max_seqlen_q, -1)
    return x
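For context on the varlen arguments this wrapper forwards, here is how the repository's own test (tests/test_acc.py, shown later in this diff) builds them: cumulative offsets marking the image/text boundary, and a max_seqlen equal to the total packed length, which is also the first dimension of the reshaped output.

import torch

# Values taken from tests/test_acc.py below: 32411 image tokens followed by 245 text tokens.
img_len, txt_len = 32411, 245
cu_seqlens_qkv = torch.tensor([0, img_len, img_len + txt_len], dtype=torch.int32, device="cuda")
max_seqlen_qkv = img_len + txt_len  # 32656, used for both max_seqlen_q and max_seqlen_kv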
lightx2v/attentions/common/flash_attn3.py (deleted, mode 100644 → 0)
try:
    from flash_attn_interface import flash_attn_varlen_func as flash_attn_varlen_func_v3
except ImportError:
    flash_attn_varlen_func_v3 = None


def flash_attn3(q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, model_cls=None):
    x = flash_attn_varlen_func_v3(
        q,
        k,
        v,
        cu_seqlens_q,
        cu_seqlens_kv,
        max_seqlen_q,
        max_seqlen_kv,
    )[0].reshape(max_seqlen_q, -1)
    return x
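The only difference from the FlashAttention-2 wrapper is that the v3 interface imported here returns a tuple, hence the `[0]` before the reshape. A hedged sketch of how a caller could pick a backend based on which of these optional imports succeeded (the helper name is hypothetical, not part of the repository):

from lightx2v.attentions.common import flash_attn2 as fa2
from lightx2v.attentions.common import flash_attn3 as fa3


def pick_attention_type() -> str:
    # Both wrappers bind their varlen function to None when the import fails,
    # so the module attributes double as availability flags.
    if fa3.flash_attn_varlen_func_v3 is not None:
        return "flash_attn3"
    if fa2.flash_attn_varlen_func is not None:
        return "flash_attn2"
    return "torch_sdpa"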
lightx2v/attentions/common/sage_attn2.py (deleted, mode 100644 → 0)
import torch

if torch.cuda.get_device_capability(0) == (8, 9):
    try:
        from sageattention import sageattn_qk_int8_pv_fp16_triton as sageattn
    except ImportError:
        sageattn = None
else:
    try:
        from sageattention import sageattn
    except ImportError:
        sageattn = None


def sage_attn2(q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, model_cls="hunyuan"):
    q, k, v = q.contiguous(), k.contiguous(), v.contiguous()
    if model_cls == "hunyuan":
        x1 = sageattn(
            q[: cu_seqlens_q[1]].unsqueeze(0),
            k[: cu_seqlens_kv[1]].unsqueeze(0),
            v[: cu_seqlens_kv[1]].unsqueeze(0),
            tensor_layout="NHD",
        )
        x2 = sageattn(
            q[cu_seqlens_q[1] :].unsqueeze(0),
            k[cu_seqlens_kv[1] :].unsqueeze(0),
            v[cu_seqlens_kv[1] :].unsqueeze(0),
            tensor_layout="NHD",
        )
        x = torch.cat((x1, x2), dim=1)
        x = x.view(max_seqlen_q, -1)
    elif model_cls in ["wan2.1", "wan2.1_distill", "wan2.1_causvid", "wan2.1_df"]:
        x = sageattn(
            q.unsqueeze(0),
            k.unsqueeze(0),
            v.unsqueeze(0),
            tensor_layout="NHD",
        )
        x = x.view(max_seqlen_q, -1)
    return x
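A small sketch of the segment handling in the "hunyuan" branch above, with illustrative lengths: the packed sequence is split at cu_seqlens[1], each slice is attended as its own batch of one, and the results are concatenated back along the sequence dimension.

import torch

# Illustrative: 10 tokens in the first segment, 4 in the second, 2 heads, head_dim 3.
cu_seqlens = torch.tensor([0, 10, 14], dtype=torch.int32)
q = torch.randn(14, 2, 3)

seg1 = q[: cu_seqlens[1]].unsqueeze(0)   # [1, 10, 2, 3], "NHD" layout (batch, seq, heads, dim)
seg2 = q[cu_seqlens[1] :].unsqueeze(0)   # [1, 4, 2, 3]
merged = torch.cat((seg1, seg2), dim=1)  # back to [1, 14, 2, 3]
assert merged.shape == (1, 14, 2, 3)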
lightx2v/attentions/common/torch_sdpa.py (deleted, mode 100644 → 0)
import torch
import torch.nn.functional as F


def torch_sdpa(q, k, v, drop_rate=0, attn_mask=None, causal=False):
    q = q.transpose(1, 2)
    k = k.transpose(1, 2)
    v = v.transpose(1, 2)
    if attn_mask is not None and attn_mask.dtype != torch.bool:
        attn_mask = attn_mask.to(q.dtype)
    x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal)
    x = x.transpose(1, 2)
    b, s, a, d = x.shape
    out = x.reshape(b, s, -1)
    return out
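A minimal usage sketch, assuming batched [batch, seq, heads, head_dim] inputs as the transposes above imply; the shapes are illustrative.

import torch

from lightx2v.attentions.common.torch_sdpa import torch_sdpa

b, s, heads, d = 1, 16, 4, 8
q = torch.randn(b, s, heads, d)
k = torch.randn(b, s, heads, d)
v = torch.randn(b, s, heads, d)

out = torch_sdpa(q, k, v)
assert out.shape == (b, s, heads * d)  # heads are flattened into the hidden dimension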
lightx2v/attentions/distributed/__init__.py (deleted empty file, mode 100644 → 0)
lightx2v/attentions/distributed/comm/__init__.py (deleted empty file, mode 100644 → 0)
lightx2v/attentions/distributed/partial_heads_attn/__init__.py (deleted empty file, mode 100644 → 0)
lightx2v/attentions/distributed/partial_heads_attn/attn.py (deleted, mode 100644 → 0)
import torch
import torch.distributed as dist

from lightx2v.attentions import attention


def partial_heads_attn(attention_type, q, k, v, cu_seqlens_qkv, max_seqlen_qkv):
    num_heads = q.shape[-2]
    cur_rank = dist.get_rank()
    world_size = dist.get_world_size()
    num_chunk_heads = int(num_heads / dist.get_world_size())

    if cur_rank == world_size - 1:
        q = q[:, num_chunk_heads * cur_rank :, :]
        k = k[:, num_chunk_heads * cur_rank :, :]
        v = v[:, num_chunk_heads * cur_rank :, :]
    else:
        q = q[:, num_chunk_heads * cur_rank : num_chunk_heads * (cur_rank + 1), :]
        k = k[:, num_chunk_heads * cur_rank : num_chunk_heads * (cur_rank + 1), :]
        v = v[:, num_chunk_heads * cur_rank : num_chunk_heads * (cur_rank + 1), :]

    output = attention(
        attention_type=attention_type,
        q=q,
        k=k,
        v=v,
        cu_seqlens_q=cu_seqlens_qkv,
        cu_seqlens_kv=cu_seqlens_qkv,
        max_seqlen_q=max_seqlen_qkv,
        max_seqlen_kv=max_seqlen_qkv,
    )

    gathered_outputs = [torch.empty_like(output) for _ in range(world_size)]
    dist.all_gather(gathered_outputs, output)
    combined_output = torch.cat(gathered_outputs, dim=1)
    return combined_output
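A standalone sketch of the head partitioning above, using the 24 heads from the tests below and a hypothetical world size of 2: every rank takes a contiguous block of heads, and the last rank also absorbs any remainder when the head count does not divide evenly.

num_heads, world_size = 24, 2                 # head count from the tests; world size is illustrative
num_chunk_heads = num_heads // world_size     # 12 heads per rank here

for cur_rank in range(world_size):
    if cur_rank == world_size - 1:
        head_slice = slice(num_chunk_heads * cur_rank, None)  # last rank: heads 12..23 (plus any remainder)
    else:
        head_slice = slice(num_chunk_heads * cur_rank, num_chunk_heads * (cur_rank + 1))  # rank 0: heads 0..11
    print(cur_rank, head_slice)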
lightx2v/attentions/distributed/partial_heads_attn/tests/__init__.py (deleted empty file, mode 100644 → 0)
lightx2v/attentions/distributed/partial_heads_attn/tests/test.sh (deleted, mode 100644 → 0)
export PYTHONPATH=/workspace/lightx2v:$PYTHONPATH
export CUDA_VISIBLE_DEVICES=0,1

torchrun --nproc_per_node=2 test_acc.py
lightx2v/attentions/distributed/partial_heads_attn/tests/test_acc.py (deleted, mode 100644 → 0)
import torch
import torch.distributed as dist

from lightx2v.attentions import attention
from lightx2v.utils.utils import seed_all
from loguru import logger

seed_all(42)


def prepare_tensors():
    cur_rank = dist.get_rank()  # get the rank of the current process
    torch.cuda.set_device(cur_rank)  # set the CUDA device of the current process

    q = torch.randn(32656, 24, 128, dtype=torch.bfloat16).cuda()
    k = torch.randn(32656, 24, 128, dtype=torch.bfloat16).cuda()
    v = torch.randn(32656, 24, 128, dtype=torch.bfloat16).cuda()
    cu_seqlens_qkv = torch.tensor([0, 32411, 32656], dtype=torch.int32).cuda()
    max_seqlen_qkv = 32656
    return q, k, v, cu_seqlens_qkv, max_seqlen_qkv


def test_part_head():
    q, k, v, cu_seqlens_qkv, max_seqlen_qkv = prepare_tensors()

    # First compute the full result as a reference
    single_gpu_output = attention(
        q=q,
        k=k,
        v=v,
        cu_seqlens_q=cu_seqlens_qkv,
        cu_seqlens_kv=cu_seqlens_qkv,
        max_seqlen_q=max_seqlen_qkv,
        max_seqlen_kv=max_seqlen_qkv,
    )

    num_heads = q.shape[-2]
    cur_rank = dist.get_rank()
    world_size = dist.get_world_size()
    num_chunk_heads = int(num_heads / dist.get_world_size())

    if cur_rank == world_size - 1:
        q = q[:, num_chunk_heads * cur_rank :, :]
        k = k[:, num_chunk_heads * cur_rank :, :]
        v = v[:, num_chunk_heads * cur_rank :, :]
    else:
        q = q[:, num_chunk_heads * cur_rank : num_chunk_heads * (cur_rank + 1), :]
        k = k[:, num_chunk_heads * cur_rank : num_chunk_heads * (cur_rank + 1), :]
        v = v[:, num_chunk_heads * cur_rank : num_chunk_heads * (cur_rank + 1), :]

    output = attention(
        q=q,
        k=k,
        v=v,
        cu_seqlens_q=cu_seqlens_qkv,
        cu_seqlens_kv=cu_seqlens_qkv,
        max_seqlen_q=max_seqlen_qkv,
        max_seqlen_kv=max_seqlen_qkv,
    )

    gathered_outputs = [torch.empty_like(output) for _ in range(world_size)]
    dist.all_gather(gathered_outputs, output)
    combined_output = torch.cat(gathered_outputs, dim=1)

    # Verify that the results match
    if cur_rank == 0:
        # import pdb; pdb.set_trace()
        logger.info("Outputs match:", torch.allclose(single_gpu_output, combined_output, rtol=1e-3, atol=1e-3))

    # # Verify that the results match
    # logger.info("Outputs match:", torch.allclose(single_gpu_output, combined_output, rtol=1e-3, atol=1e-3))


if __name__ == "__main__":
    # Initialize the distributed environment
    dist.init_process_group(backend="nccl")
    test_part_head()
lightx2v/attentions/distributed/partial_heads_attn/wrap.py (deleted, mode 100644 → 0)
from lightx2v.attentions.distributed.partial_heads_attn.attn import partial_heads_attn


def parallelize_hunyuan(hunyuan_model):
    hunyuan_model.transformer_infer.parallel_attention = partial_heads_attn
lightx2v/attentions/distributed/ring/__init__.py (deleted empty file, mode 100644 → 0)
lightx2v/attentions/distributed/ring/attn.py (deleted, mode 100644 → 0)
import torch
import torch.distributed as dist
import torch.nn.functional as F

# from lightx2v.attentions import attention
from lightx2v.attentions.distributed.comm.ring_comm import RingComm

try:
    import flash_attn
    from flash_attn.flash_attn_interface import _flash_attn_forward
except ImportError:
    flash_attn = None
    _flash_attn_forward = None

from typing import Optional, Tuple

# RING_COMM = None
# def init_ring_comm():
#     global RING_COMM
#     RING_COMM = RingComm()


@torch.jit.script
def _update_out_and_lse(
    out: torch.Tensor,
    lse: torch.Tensor,
    block_out: torch.Tensor,
    block_lse: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    block_out = block_out.to(torch.float32)
    block_lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1)

    # new_lse = lse + torch.log(1 + torch.exp(block_lse - lse))
    # torch.exp(lse - new_lse) * out + torch.exp(block_lse - new_lse) * block_out
    # For additional context and discussion, please refer to:
    # https://github.com/zhuzilin/ring-flash-attention/pull/34#issuecomment-2076126795
    out = out - F.sigmoid(block_lse - lse) * (out - block_out)
    lse = lse - F.logsigmoid(lse - block_lse)
    return out, lse


def update_out_and_lse(
    out: Optional[torch.Tensor],
    lse: Optional[torch.Tensor],
    block_out: torch.Tensor,
    block_lse: torch.Tensor,
    slice_=None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    if out is None:
        if slice_ is not None:
            raise RuntimeError("first update_out_and_lse should not pass slice_ args")
        out = block_out.to(torch.float32)
        lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1)
    elif slice_ is not None:
        slice_out, slice_lse = out[slice_], lse[slice_]
        slice_out, slice_lse = _update_out_and_lse(slice_out, slice_lse, block_out, block_lse)
        out[slice_], lse[slice_] = slice_out, slice_lse
    else:
        out, lse = _update_out_and_lse(out, lse, block_out, block_lse)
    return out, lse


def ring_attn_sub(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False, window_size=(-1, -1), softcap=0.0, alibi_slopes=None, return_softmax=False):
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** (-0.5)
    if flash_attn.__version__ < "2.6.3":
        block_out, _, _, _, _, block_lse, _, _ = _flash_attn_forward(
            q,
            k,
            v,
            dropout_p=dropout_p,
            softmax_scale=softmax_scale,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            return_softmax=return_softmax,
        )
    else:
        block_out, block_lse, _, _ = _flash_attn_forward(
            q,
            k,
            v,
            dropout_p=dropout_p,
            softmax_scale=softmax_scale,
            causal=causal,
            window_size_left=window_size[0],
            window_size_right=window_size[1],
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            return_softmax=return_softmax,
        )
    return block_out, block_lse


def ring_attn(q, k, v, img_qkv_len, cu_seqlens_qkv, attention_type="flash_attn2"):
    """
    Run ring attention over the combined image and text queries, keys and values.

    Args:
        q (torch.Tensor): query tensor of shape [shard_seqlen, heads, hidden_dims]
        k (torch.Tensor): key tensor of shape [shard_seqlen, heads, hidden_dims]
        v (torch.Tensor): value tensor of shape [shard_seqlen, heads, hidden_dims]
        img_qkv_len (int): length of the image part of q/k/v
        cu_seqlens_qkv (torch.Tensor): cumulative sequence lengths covering the text and image segments
        attention_type (str): attention backend, defaults to "flash_attn2"

    Returns:
        torch.Tensor: the computed attention output
    """
    # Get the rank of the current process and the world size
    cur_rank = dist.get_rank()
    world_size = dist.get_world_size()

    if len(cu_seqlens_qkv) == 3:
        txt_qkv_len = cu_seqlens_qkv[1] - img_qkv_len  # length of the text q/k/v
        txt_mask_len = cu_seqlens_qkv[2] - img_qkv_len  # length of the text mask
    elif len(cu_seqlens_qkv) == 2:
        txt_qkv_len = cu_seqlens_qkv[1] - img_qkv_len  # length of the text q/k/v
        txt_mask_len = 0

    # if RING_COMM is None:
    #     init_ring_comm()
    RING_COMM = RingComm()

    # if len(cu_seqlens_qkv) == 3:
    #     txt_qkv_len = cu_seqlens_qkv[1] - img_qkv_len  # length of the text q/k/v
    #     txt_mask_len = cu_seqlens_qkv[2] - img_qkv_len  # length of the text mask
    # elif len(cu_seqlens_qkv) == 2:
    #     txt_qkv_len = cu_seqlens_qkv[1] - img_qkv_len  # length of the text q/k/v
    #     txt_mask_len = None

    q = q.unsqueeze(0)
    k = k.unsqueeze(0)
    v = v.unsqueeze(0)

    img_q, img_k, img_v = q[:, :img_qkv_len, :, :].contiguous(), k[:, :img_qkv_len, :, :].contiguous(), v[:, :img_qkv_len, :, :].contiguous()
    txt_q, txt_k, txt_v = (
        q[:, img_qkv_len : img_qkv_len + txt_qkv_len, :, :].contiguous(),
        k[:, img_qkv_len : img_qkv_len + txt_qkv_len, :, :].contiguous(),
        v[:, img_qkv_len : img_qkv_len + txt_qkv_len, :, :].contiguous(),
    )

    out, lse, next_k, next_v = None, None, None, None

    if len(cu_seqlens_qkv) == 3:
        q = torch.cat((img_q, txt_q), dim=1)
    k = img_k
    v = img_v

    for step in range(world_size):
        if step + 1 != world_size:
            next_k = RING_COMM.send_recv(k)
            next_v = RING_COMM.send_recv(v)
            RING_COMM.commit()

        if step + 1 == world_size:
            k = torch.cat((k, txt_k), dim=1)
            v = torch.cat((v, txt_v), dim=1)

        block_out, block_lse = ring_attn_sub(q, k, v)
        out, lse = update_out_and_lse(out, lse, block_out, block_lse)

        if step + 1 != world_size:
            RING_COMM.wait()
            k = next_k
            v = next_v

    attn1 = out.to(torch.bfloat16).squeeze(0).reshape(img_qkv_len + txt_qkv_len, -1)

    if txt_mask_len > 0:
        attn2, *_ = _flash_attn_forward(
            q[:, -(txt_mask_len - txt_qkv_len) :, :, :].contiguous(),
            k[:, -(txt_mask_len - txt_qkv_len) :, :, :].contiguous(),
            v[:, -(txt_mask_len - txt_qkv_len) :, :, :].contiguous(),
            dropout_p=0.0,
            softmax_scale=q.shape[-1] ** (-0.5),
            causal=False,
            window_size_left=-1,
            window_size_right=-1,
            softcap=0.0,
            alibi_slopes=None,
            return_softmax=False,
        )
        attn2 = attn2.to(torch.bfloat16).squeeze(0).reshape((txt_mask_len - txt_qkv_len), -1)
        attn1 = torch.cat([attn1, attn2], dim=0)

    return attn1
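The `_update_out_and_lse` step above is the usual two-block log-sum-exp merge of partial attention outputs, rewritten in the sigmoid/logsigmoid form discussed in the linked ring-flash-attention thread. A small self-contained check of that identity against the naive formulation:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
out, block_out = torch.randn(8, 4), torch.randn(8, 4)   # two partial outputs (illustrative shapes)
lse, block_lse = torch.randn(8, 1), torch.randn(8, 1)   # their log-sum-exp normalizers

# Naive merge: new_lse = log(exp(lse) + exp(block_lse)), partial outputs reweighted accordingly.
new_lse = torch.logaddexp(lse, block_lse)
ref_out = torch.exp(lse - new_lse) * out + torch.exp(block_lse - new_lse) * block_out

# Form used above.
merged_out = out - F.sigmoid(block_lse - lse) * (out - block_out)
merged_lse = lse - F.logsigmoid(lse - block_lse)

assert torch.allclose(merged_out, ref_out, atol=1e-6)
assert torch.allclose(merged_lse, new_lse, atol=1e-6)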
lightx2v/attentions/distributed/ring/tests/test.py (deleted, mode 100644 → 0)
import torch
import torch.distributed as dist

from lightx2v.attentions import attention
from flash_attn.flash_attn_interface import _flash_attn_varlen_forward
from lightx2v.attentions.distributed.ring.attn import ring_attn_sub, update_out_and_lse
from lightx2v.attentions.distributed.comm.ring_comm import RingComm

RING_COMM = None


def init_ring_comm():
    global RING_COMM
    RING_COMM = RingComm()


def base_attention(q, k, v, cu_seqlens_q, cu_seqlens_k, lq, lk):
    attn_out = attention(
        q=q,
        k=k,
        v=v,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_kv=cu_seqlens_k,
        max_seqlen_q=lq,
        max_seqlen_kv=lk,
    )
    return attn_out


def ring_attention(q, k, v, cu_seqlens_q, cu_seqlens_k, lq, lk, ring_size):
    out, lse = None, None
    # q = torch.chunk(q, ring_size)
    q = q.unsqueeze(0)
    k = k.unsqueeze(0)
    v = v.unsqueeze(0)
    k = torch.chunk(k, ring_size, dim=1)
    v = torch.chunk(v, ring_size, dim=1)
    for i in range(ring_size):
        k_block, v_block = k[i], v[i]
        block_out, block_lse = ring_attn_sub(q, k_block, v_block)
        out, lse = update_out_and_lse(out, lse, block_out, block_lse)
    attn_out = out.to(torch.bfloat16).squeeze(0).reshape(lq, -1)
    return attn_out


def ring_attention_dist(q, k, v, cu_seqlens_q, cu_seqlens_k, lq, lk):
    if RING_COMM is None:
        init_ring_comm()
    out, lse = None, None
    # q = torch.chunk(q, ring_size)
    cur_rank = dist.get_rank()
    world_size = dist.get_world_size()
    out, lse, next_k, next_v = None, None, None, None
    q = q.unsqueeze(0)
    k = k.unsqueeze(0)
    v = v.unsqueeze(0)
    k = torch.chunk(k, world_size, dim=1)[cur_rank]
    v = torch.chunk(v, world_size, dim=1)[cur_rank]
    for step in range(world_size):
        if step + 1 != world_size:
            next_k = RING_COMM.send_recv(k)
            next_v = RING_COMM.send_recv(v)
            RING_COMM.commit()
        block_out, block_lse = ring_attn_sub(q, k, v)
        out, lse = update_out_and_lse(out, lse, block_out, block_lse)
        if step + 1 != world_size:
            RING_COMM.wait()
            k = next_k
            v = next_v
    attn_out = out.to(torch.bfloat16).squeeze(0).reshape(lq, -1)
    return attn_out


def test():
    q = torch.randn((32760, 12, 128), dtype=torch.bfloat16, device="cuda")
    k = torch.randn((32760, 12, 128), dtype=torch.bfloat16, device="cuda")
    v = torch.randn((32760, 12, 128), dtype=torch.bfloat16, device="cuda")
    cu_seqlens_q = torch.tensor([0, 32760], dtype=torch.int32, device="cuda")
    cu_seqlens_k = torch.tensor([0, 32760], dtype=torch.int32, device="cuda")
    lq = 32760
    lk = 32760

    base_attn = base_attention(q=q, k=k, v=v, cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, lq=lq, lk=lk)
    ring_attn = ring_attention(q=q, k=k, v=v, cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, lq=lq, lk=lk, ring_size=4)

    # import pdb; pdb.set_trace()
    # Assert that the two results are numerically identical
    assert torch.allclose(base_attn, ring_attn, rtol=1e-3, atol=1e-3), "base_attn and ring_attn do not match!"


if __name__ == "__main__":
    # dist.init_process_group(backend="nccl")
    test()
lightx2v/attentions/distributed/ring/tests/test.sh (deleted, mode 100644 → 0)
lightx2v_path=""
export PYTHONPATH=${lightx2v_path}:$PYTHONPATH

python3 test.py
lightx2v/attentions/distributed/ring/wrap.py (deleted, mode 100644 → 0)
import functools

from lightx2v.attentions.distributed.ring.attn import ring_attn


def parallelize_hunyuan(hunyuan_model):
    from lightx2v.attentions.distributed.utils.hunyuan.processor import pre_process, post_process

    """Parallelize the Hunyuan model's inference using ring attention.

    Args:
        hunyuan_model: the Hunyuan model instance, including its inference method and other attributes.
    """
    # Replace the Hunyuan model's parallel attention with ring attention
    hunyuan_model.transformer_infer.parallel_attention = ring_attn

    # Keep a reference to the original inference method so it can still be called
    original_infer = hunyuan_model.infer

    @functools.wraps(hunyuan_model.__class__.infer)
    # preserve the metadata of the original inference method
    def new_infer(self, text_encoders_output, image_encoder_output, args):
        """New inference method that pre- and post-processes the inputs around the original one.

        Args:
            self: the Hunyuan model instance
            text_encoders_output: output of the text encoders
            args: other arguments

        Returns:
            None
        """
        # Save the original latents and frequency tensors
        self.scheduler.ori_latents, self.scheduler.ori_freqs_cos, self.scheduler.ori_freqs_sin = (
            self.scheduler.latents,
            self.scheduler.freqs_cos,
            self.scheduler.freqs_sin,
        )
        # Pre-process the inputs for parallel computation
        self.scheduler.latents, self.scheduler.freqs_cos, self.scheduler.freqs_sin, split_dim = pre_process(
            self.scheduler.latents, self.scheduler.freqs_cos, self.scheduler.freqs_sin
        )
        # Call the original inference method
        original_infer(text_encoders_output, image_encoder_output, args)
        # Post-process the output
        self.scheduler.noise_pred = post_process(self.scheduler.noise_pred, split_dim)
        # Restore the original latents and frequency tensors
        self.scheduler.latents, self.scheduler.freqs_cos, self.scheduler.freqs_sin = (
            self.scheduler.ori_latents,
            self.scheduler.ori_freqs_cos,
            self.scheduler.ori_freqs_sin,
        )
        # return combined_output  # return the processed output (currently commented out)

    # Bind the new inference method to the Hunyuan model instance
    new_infer = new_infer.__get__(hunyuan_model)
    hunyuan_model.infer = new_infer  # replace the original inference method


def parallelize_wan(wan_model):
    from lightx2v.attentions.distributed.utils.wan.processor import pre_process, post_process

    wan_model.transformer_infer.parallel_attention = ring_attn

    original_infer = wan_model.transformer_infer.infer

    @functools.wraps(wan_model.transformer_infer.__class__.infer)
    # preserve the metadata of the original inference method
    def new_infer(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
        x = pre_process(x)
        x = original_infer(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)
        x = post_process(x)
        return x

    new_infer = new_infer.__get__(wan_model.transformer_infer)
    wan_model.transformer_infer.infer = new_infer  # replace the original inference method
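Both wrappers above rely on `function.__get__(instance)` to turn a plain function into a bound method before swapping it onto the model object. A minimal standalone illustration of that pattern (toy class, not from the repository):

import functools


class Toy:
    def infer(self, x):
        return x + 1


toy = Toy()
original_infer = toy.infer  # bound method, still pointing at the original implementation


@functools.wraps(Toy.infer)  # keep the original method's metadata
def new_infer(self, x):
    # pre-process, delegate to the original, post-process
    return original_infer(x * 2) - 1


toy.infer = new_infer.__get__(toy)  # bind to the instance, then replace the attribute
print(toy.infer(3))  # ((3 * 2) + 1) - 1 = 6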
lightx2v/attentions/distributed/ulysses/__init__.py (deleted empty file, mode 100644 → 0)
lightx2v/attentions/distributed/ulysses/attn.py (deleted, mode 100644 → 0)
import torch
import torch.distributed as dist

from lightx2v.attentions import attention
from lightx2v.attentions.distributed.comm.all2all import all2all_seq2head, all2all_head2seq


def ulysses_attn(q, k, v, img_qkv_len, cu_seqlens_qkv, attention_type="flash_attn2"):
    """
    Run Ulysses attention over the combined image and text queries, keys and values.

    Args:
        q (torch.Tensor): query tensor of shape [shard_seqlen, heads, hidden_dims]
        k (torch.Tensor): key tensor of shape [shard_seqlen, heads, hidden_dims]
        v (torch.Tensor): value tensor of shape [shard_seqlen, heads, hidden_dims]
        img_qkv_len (int): length of the image part of q/k/v
        cu_seqlens_qkv (torch.Tensor): cumulative sequence lengths covering the text and image segments
        attention_type (str): attention backend, defaults to "flash_attn2"

    Returns:
        torch.Tensor: the computed attention output
    """
    # Get the rank of the current process and the world size
    cur_rank = dist.get_rank()
    world_size = dist.get_world_size()

    # Sequence length and the text-related lengths
    seq_len = q.shape[0]
    if len(cu_seqlens_qkv) == 3:
        txt_qkv_len = cu_seqlens_qkv[1] - img_qkv_len  # length of the text q/k/v
        txt_mask_len = cu_seqlens_qkv[2] - img_qkv_len  # length of the text mask
    elif len(cu_seqlens_qkv) == 2:
        txt_qkv_len = cu_seqlens_qkv[1] - img_qkv_len  # length of the text q/k/v
        txt_mask_len = None

    # Number of heads and hidden dimension of the query tensor
    _, heads, hidden_dims = q.shape
    shard_heads = heads // world_size  # heads handled by each process
    shard_seqlen = img_qkv_len  # sequence length handled by each process

    # Split the image and text parts of q/k/v
    img_q, img_k, img_v = q[:img_qkv_len, :, :].contiguous(), k[:img_qkv_len, :, :].contiguous(), v[:img_qkv_len, :, :].contiguous()
    txt_q, txt_k, txt_v = q[img_qkv_len:, :, :].contiguous(), k[img_qkv_len:, :, :].contiguous(), v[img_qkv_len:, :, :].contiguous()

    # All-to-all the image q/k/v from the sequence-sharded to the head-sharded layout
    img_q = all2all_seq2head(img_q)
    img_k = all2all_seq2head(img_k)
    img_v = all2all_seq2head(img_v)
    torch.cuda.synchronize()  # make sure the CUDA ops have finished

    # For the text q/k/v, select the heads owned by the current process
    txt_q = txt_q[:, cur_rank * shard_heads : (cur_rank + 1) * shard_heads, :]
    txt_k = txt_k[:, cur_rank * shard_heads : (cur_rank + 1) * shard_heads, :]
    txt_v = txt_v[:, cur_rank * shard_heads : (cur_rank + 1) * shard_heads, :]

    # Concatenate the image and text q/k/v
    q = torch.cat((img_q, txt_q), dim=0)
    k = torch.cat((img_k, txt_k), dim=0)
    v = torch.cat((img_v, txt_v), dim=0)

    # Build the cumulative sequence length tensor
    cu_seqlens_qkv = torch.zeros([2], dtype=torch.int32, device="cuda")
    s = txt_qkv_len + img_q.shape[0]  # total length of text plus image
    s1 = s  # end position of the current sample
    cu_seqlens_qkv[1] = s1  # set the cumulative sequence length
    if txt_mask_len:
        s2 = txt_mask_len + img_q.shape[0]  # end position of the text mask
        cu_seqlens_qkv = torch.cat(cu_seqlens_qkv, s2)
    max_seqlen_qkv = img_q.shape[0] + txt_q.shape[0]  # maximum sequence length

    # Compute the attention result
    attn = attention(
        attention_type=attention_type,
        q=q,
        k=k,
        v=v,
        cu_seqlens_q=cu_seqlens_qkv,
        cu_seqlens_kv=cu_seqlens_qkv,
        max_seqlen_q=max_seqlen_qkv,
        max_seqlen_kv=max_seqlen_qkv,
    )

    # Split the attention output back into its image and text parts
    img_attn, txt_attn = attn[: img_q.shape[0], :], attn[img_q.shape[0] :,]

    # Gather the text attention results from all processes
    gathered_txt_attn = [torch.empty_like(txt_attn) for _ in range(world_size)]
    dist.all_gather(gathered_txt_attn, txt_attn)

    # Process the image attention results
    img_attn = img_attn.reshape(world_size * shard_seqlen, shard_heads, hidden_dims)  # reshape the image attention output
    img_attn = all2all_head2seq(img_attn)  # convert from the head-sharded back to the sequence-sharded layout
    img_attn = img_attn.reshape(shard_seqlen, -1)  # reshape to [shard_seqlen, -1]
    torch.cuda.synchronize()  # make sure the CUDA ops have finished

    txt_attn = torch.cat(gathered_txt_attn, dim=1)  # concatenate the text attention results from all processes

    # Concatenate the image and text attention results
    attn = torch.cat([img_attn, txt_attn], dim=0)
    return attn  # return the final attention output
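The `all2all_seq2head` / `all2all_head2seq` helpers come from `lightx2v/attentions/distributed/comm/all2all.py`, which is not among the files shown on this page. Based purely on how they are used above, a single-process sketch of the layout change they imply (an assumption about their semantics, not their implementation): each rank trades its sequence shard of all heads for the full sequence of its own head shard, and back.

import torch

# Illustrative shapes: 2 ranks, each holding 8 tokens of a 16-token image sequence, 4 heads, dim 3.
world_size, shard_seqlen, heads, dim = 2, 8, 4, 3
shard_heads = heads // world_size

# Per-rank input to all2all_seq2head: [shard_seqlen, heads, dim].
per_rank_in = [torch.randn(shard_seqlen, heads, dim) for _ in range(world_size)]

# Emulated result of the all-to-all: each rank ends up with the full sequence but only its
# own slice of heads, i.e. [world_size * shard_seqlen, shard_heads, dim]; head2seq reverses this.
full_seq = torch.cat(per_rank_in, dim=0)  # [16, 4, 3]
per_rank_out = [full_seq[:, r * shard_heads : (r + 1) * shard_heads, :] for r in range(world_size)]
assert per_rank_out[0].shape == (world_size * shard_seqlen, shard_heads, dim)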