"vscode:/vscode.git/clone" did not exist on "298ce679992e439a9869632579a2f853ea2b095e"
Commit da3124dd authored by yangzhong's avatar yangzhong
Browse files

Switch the memory-efficient cross-attention computation to flash_attn

parent c9ce7f39
@@ -7,8 +7,6 @@ from abc import abstractmethod
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import xformers
-import xformers.ops
 from einops import rearrange
 from fairscale.nn.checkpoint import checkpoint_wrapper
 from timm.models.vision_transformer import Mlp
@@ -162,13 +160,20 @@ class MemoryEfficientCrossAttention(nn.Module):
         v = self.to_v(context)
         b, _, _ = q.shape
+        '''
         q, k, v = map(
             lambda t: t.unsqueeze(3).reshape(b, t.shape[
                 1], self.heads, self.dim_head).permute(0, 2, 1, 3).reshape(
                     b * self.heads, t.shape[1], self.dim_head).contiguous(),
             (q, k, v),
         )
+        '''
+        q, k, v = map(
+            lambda t: t.unsqueeze(3).reshape(b, t.shape[
+                1], self.heads, self.dim_head).contiguous(),
+            (q, k, v),
+        )
+        '''
         # actually compute the attention, what we cannot get enough of.
         if q.shape[0] > self.max_bs:
             q_list = torch.chunk(q, q.shape[0] // self.max_bs, dim=0)
@@ -183,15 +188,20 @@ class MemoryEfficientCrossAttention(nn.Module):
         else:
             out = xformers.ops.memory_efficient_attention(
                 q, k, v, attn_bias=None, op=self.attention_op)
+        '''
+        from flash_attn import flash_attn_func
+        out = flash_attn_func(q, k, v)
         if exists(mask):
             raise NotImplementedError
+        '''
         out = (
             out.unsqueeze(0).reshape(
                 b, self.heads, out.shape[1],
                 self.dim_head).permute(0, 2, 1,
                                        3).reshape(b, out.shape[1],
                                                   self.heads * self.dim_head))
+        '''
+        out = out.reshape(b, out.shape[1], self.heads * self.dim_head)
         return self.to_out(out)
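
Note on the layout change behind this diff: xformers.ops.memory_efficient_attention consumes tensors flattened to (batch * heads, seq_len, dim_head), while flash_attn_func consumes and returns (batch, seq_len, heads, dim_head), which is why the head-dimension permute/merge and the matching un-merge of the output are dropped. The standalone snippet below is a minimal sketch of that equivalence, not part of the commit; the shapes and the PyTorch scaled_dot_product_attention reference are assumptions for illustration, and flash_attn requires fp16/bf16 tensors on a CUDA device.

# Minimal sketch (not from this repo): flash_attn_func on (b, seq, heads, dim_head)
# should match a plain PyTorch attention reference up to fp16 tolerance.
import torch
import torch.nn.functional as F
from flash_attn import flash_attn_func

b, seq, heads, dim_head = 2, 77, 8, 64  # illustrative shapes, not from the commit
q = torch.randn(b, seq, heads, dim_head, device='cuda', dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# flash_attn path: no head permutation needed, output keeps the same layout
out_flash = flash_attn_func(q, k, v)                     # (b, seq, heads, dim_head)
out_flash = out_flash.reshape(b, seq, heads * dim_head)  # same final reshape as the new code

# reference path: scaled_dot_product_attention expects (b, heads, seq, dim_head)
out_ref = F.scaled_dot_product_attention(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2))
out_ref = out_ref.transpose(1, 2).reshape(b, seq, heads * dim_head)

print(torch.allclose(out_flash, out_ref, atol=1e-2))  # expected: True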