Commit da3124dd authored by yangzhong

Switch the memory-efficient cross-attention computation to use flash_attn

parent c9ce7f39
@@ -7,8 +7,6 @@ from abc import abstractmethod
import torch
import torch.nn as nn
import torch.nn.functional as F
import xformers
import xformers.ops
from einops import rearrange
from fairscale.nn.checkpoint import checkpoint_wrapper
from timm.models.vision_transformer import Mlp
@@ -162,13 +160,20 @@ class MemoryEfficientCrossAttention(nn.Module):
v = self.to_v(context)
b, _, _ = q.shape
'''
q, k, v = map(
lambda t: t.unsqueeze(3).reshape(b, t.shape[
1], self.heads, self.dim_head).permute(0, 2, 1, 3).reshape(
b * self.heads, t.shape[1], self.dim_head).contiguous(),
(q, k, v),
)
'''
q, k, v = map(
lambda t: t.unsqueeze(3).reshape(b, t.shape[
1], self.heads, self.dim_head).contiguous(),
(q, k, v),
)
'''
# actually compute the attention, what we cannot get enough of.
if q.shape[0] > self.max_bs:
q_list = torch.chunk(q, q.shape[0] // self.max_bs, dim=0)
@@ -183,15 +188,20 @@ class MemoryEfficientCrossAttention(nn.Module):
else:
out = xformers.ops.memory_efficient_attention(
q, k, v, attn_bias=None, op=self.attention_op)
'''
from flash_attn import flash_attn_func
out = flash_attn_func(q, k, v)
if exists(mask):
raise NotImplementedError
'''
out = (
out.unsqueeze(0).reshape(
b, self.heads, out.shape[1],
self.dim_head).permute(0, 2, 1,
3).reshape(b, out.shape[1],
self.heads * self.dim_head))
'''
out = out.reshape(b, out.shape[1], self.heads * self.dim_head)
return self.to_out(out)
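For reference, the swap comes down to a tensor-layout change: the removed xformers path folded the heads into the batch dimension, i.e. `(b * heads, seqlen, dim_head)`, whereas `flash_attn_func` consumes `(b, seqlen, heads, dim_head)` directly, which is why the permute/reshape round-trips are commented out. Below is a minimal standalone sketch (not part of the commit) that checks this layout against PyTorch's `scaled_dot_product_attention`; it assumes flash-attn 2.x is installed and a CUDA device with fp16 support, and the shapes are illustrative only.

```python
# Standalone sketch: flash_attn_func on (b, n, heads, dim_head) tensors
# should match a plain SDPA reference computed in the (b, heads, n, dim_head) layout.
import torch
import torch.nn.functional as F
from flash_attn import flash_attn_func

b, heads, n, dim_head = 2, 8, 64, 40
q, k, v = (torch.randn(b, n, heads, dim_head, device="cuda", dtype=torch.float16)
           for _ in range(3))

# flash-attn path: takes (b, n, heads, dim_head), fp16/bf16, CUDA only.
out_flash = flash_attn_func(q, k, v)

# Reference path: SDPA expects (b, heads, n, dim_head), so transpose in and out.
out_ref = F.scaled_dot_product_attention(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
).transpose(1, 2)

print(torch.allclose(out_flash, out_ref, atol=1e-2))  # fp16-level tolerance
```

The loose `atol` only accounts for fp16 accumulation differences between the two kernels; both use the default `1/sqrt(dim_head)` softmax scaling and non-causal attention, matching the `attn_bias=None` call that was removed.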