ModelZoo / STAR

Commit da3124dd
authored Dec 10, 2025 by yangzhong
parent c9ce7f39

Route the memory-efficient cross-attention computation through flash_attn
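The gist of the change: the old path flattened q, k and v to (batch * heads, seq_len, dim_head) for xformers' memory_efficient_attention, whereas flash_attn_func consumes (batch, seq_len, heads, dim_head) directly, so the permute/reshape round-trip on the inputs and the output can be dropped. A minimal sketch of the two layouts, with illustrative shapes that are not taken from the repo:

    import torch

    # Illustrative shapes only (not from the repo).
    b, seq, heads, dim_head = 2, 77, 8, 64
    x = torch.randn(b, seq, heads * dim_head)

    # Old layout, as fed to xformers.ops.memory_efficient_attention:
    # (b * heads, seq, dim_head), built with an extra permute + reshape.
    x_xformers = (x.reshape(b, seq, heads, dim_head)
                   .permute(0, 2, 1, 3)
                   .reshape(b * heads, seq, dim_head))

    # New layout, as expected by flash_attn.flash_attn_func:
    # (b, seq, heads, dim_head) -- only a head split, no permute.
    x_flash = x.reshape(b, seq, heads, dim_head)

    print(x_xformers.shape)  # torch.Size([16, 77, 64])
    print(x_flash.shape)     # torch.Size([2, 77, 8, 64])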
Showing 1 changed file with 15 additions and 5 deletions

video_to_video/modules/unet_v2v.py  +15 -5
@@ -7,8 +7,6 @@ from abc import abstractmethod
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import xformers
-import xformers.ops
 from einops import rearrange
 from fairscale.nn.checkpoint import checkpoint_wrapper
 from timm.models.vision_transformer import Mlp
@@ -162,13 +160,20 @@ class MemoryEfficientCrossAttention(nn.Module):
         v = self.to_v(context)
         b, _, _ = q.shape
+        '''
         q, k, v = map(
             lambda t: t.unsqueeze(3).reshape(b, t.shape[
                 1], self.heads, self.dim_head).permute(0, 2, 1, 3).reshape(
                 b * self.heads, t.shape[1], self.dim_head).contiguous(),
             (q, k, v),
         )
+        '''
+        q, k, v = map(
+            lambda t: t.unsqueeze(3).reshape(
+                b, t.shape[1], self.heads, self.dim_head).contiguous(),
+            (q, k, v),
+        )
+        '''
         # actually compute the attention, what we cannot get enough of.
         if q.shape[0] > self.max_bs:
             q_list = torch.chunk(q, q.shape[0] // self.max_bs, dim=0)
@@ -183,15 +188,20 @@ class MemoryEfficientCrossAttention(nn.Module):
         else:
             out = xformers.ops.memory_efficient_attention(
                 q, k, v, attn_bias=None, op=self.attention_op)
+        '''
+        from flash_attn import flash_attn_func
+        out = flash_attn_func(q, k, v)
         if exists(mask):
             raise NotImplementedError
+        '''
         out = (
             out.unsqueeze(0).reshape(
                 b, self.heads, out.shape[1],
                 self.dim_head).permute(0, 2, 1,
                                        3).reshape(b, out.shape[1],
                                                   self.heads * self.dim_head))
+        '''
+        out = out.reshape(b, out.shape[1], self.heads * self.dim_head)
         return self.to_out(out)
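Two things worth noting on the new path: flash_attn_func requires CUDA tensors in fp16 or bf16 and returns its output in the same (batch, seq_len, heads, dim_head) layout it consumes, which is why the head merge collapses to a single reshape; and because the xformers branch with its max_bs chunking is now commented out, attention runs over the full batch in one call, with masks still unsupported (exists(mask) still raises NotImplementedError). A minimal end-to-end sketch under those assumptions (flash-attn installed, CUDA available, shapes illustrative):

    import torch
    from flash_attn import flash_attn_func

    b, seq, heads, dim_head = 2, 77, 8, 64
    q, k, v = (torch.randn(b, seq, heads, dim_head,
                           device='cuda', dtype=torch.float16)
               for _ in range(3))

    # Non-causal attention over all heads; softmax scaling defaults to
    # 1/sqrt(dim_head), matching standard scaled dot-product attention.
    out = flash_attn_func(q, k, v)  # (b, seq, heads, dim_head)

    # Merge heads back with a single reshape, as the commit does.
    out = out.reshape(b, seq, heads * dim_head)
    assert out.shape == (b, seq, heads * dim_head)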