Commit f1a9a129 authored by mashun1

update

parent d4151fa9
File added
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
@@ -212,7 +212,7 @@ class MemoryEfficientCrossAttention(nn.Module):
         self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

         self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
-        self.attention_op: Optional[Any] = None
+        self.attention_op: Optional[Any] = xformers.ops.fmha.MemoryEfficientAttentionFlashAttentionOp

     def forward(self, x, context=None, mask=None):
         q = self.to_q(x)
@@ -221,25 +221,26 @@ class MemoryEfficientCrossAttention(nn.Module):
         v = self.to_v(context)

         b, _, _ = q.shape
+        print("========================", q.shape, k.shape, v.shape, self.heads)
         q, k, v = map(
-            lambda t: t.unsqueeze(3)
+            lambda t: t
             .reshape(b, t.shape[1], self.heads, self.dim_head)
             .permute(0, 2, 1, 3)
             .reshape(b * self.heads, t.shape[1], self.dim_head)
             .contiguous(),
             (q, k, v),
         )

         # actually compute the attention, what we cannot get enough of
-        out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
+        try:
+            out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
+        except:
+            print("error========================", q.shape, k.shape, v.shape)
+            raise

         if exists(mask):
             raise NotImplementedError
         out = (
-            out.unsqueeze(0)
-            .reshape(b, self.heads, out.shape[1], self.dim_head)
-            .permute(0, 2, 1, 3)
-            .reshape(b, out.shape[1], self.heads * self.dim_head)
+            out.reshape(b, out.shape[1], self.heads * self.dim_head)
         )
         return self.to_out(out)
@@ -254,7 +255,8 @@ class BasicTransformerBlock(nn.Module):
         super().__init__()
         attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax"
         assert attn_mode in self.ATTENTION_MODES
-        attn_cls = self.ATTENTION_MODES[attn_mode]
+        # attn_cls = self.ATTENTION_MODES[attn_mode]
+        attn_cls = self.ATTENTION_MODES["softmax"]
         self.disable_self_attn = disable_self_attn
         self.attn1 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
                               context_dim=context_dim if self.disable_self_attn else None)  # is a self-attention if not self.disable_self_attn
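Note: the hunks above pin self.attention_op to xformers' FlashAttention op and adjust the head split/merge around xformers.ops.memory_efficient_attention. Below is a minimal standalone sketch of that call pattern, not the commit's code: it assumes xformers is installed on a CUDA device, fp16 tensors, and head dim <= 128 (constraints of the FlashAttention kernel); the shapes and the split_heads helper are illustrative.

# Minimal sketch of the call pattern touched above (illustrative, not from this commit).
# Assumes: xformers installed, CUDA device, fp16 tensors, head dim <= 128.
import torch
import xformers.ops

b, n, heads, dim_head = 2, 4096, 8, 64
q = torch.randn(b, n, heads * dim_head, device="cuda", dtype=torch.float16)
k = torch.randn(b, n, heads * dim_head, device="cuda", dtype=torch.float16)
v = torch.randn(b, n, heads * dim_head, device="cuda", dtype=torch.float16)

# (b, n, heads * dim_head) -> (b * heads, n, dim_head), as in the map(...) above
def split_heads(t):
    return (t.reshape(b, t.shape[1], heads, dim_head)
             .permute(0, 2, 1, 3)
             .reshape(b * heads, t.shape[1], dim_head)
             .contiguous())

q, k, v = map(split_heads, (q, k, v))

# op=None would let xformers pick a kernel; the commit pins the FlashAttention op instead.
out = xformers.ops.memory_efficient_attention(
    q, k, v, attn_bias=None,
    op=xformers.ops.fmha.MemoryEfficientAttentionFlashAttentionOp,
)

# merge heads back: (b * heads, n, dim_head) -> (b, n, heads * dim_head)
out = (out.reshape(b, heads, n, dim_head)
          .permute(0, 2, 1, 3)
          .reshape(b, n, heads * dim_head))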
@@ -234,7 +234,7 @@ class MemoryEfficientAttnBlock(nn.Module):
                                         kernel_size=1,
                                         stride=1,
                                         padding=0)
-        self.attention_op: Optional[Any] = None
+        self.attention_op: Optional[Any] = xformers.ops.fmha.MemoryEfficientAttentionFlashAttentionOp

     def forward(self, x):
         h_ = x
@@ -358,7 +358,8 @@ class Model(nn.Module):
                                        out_channels=block_in,
                                        temb_channels=self.temb_ch,
                                        dropout=dropout)
-        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
+        # self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
+        self.mid.attn_1 = AttnBlock(block_in)
         self.mid.block_2 = ResnetBlock(in_channels=block_in,
                                        out_channels=block_in,
                                        temb_channels=self.temb_ch,
@@ -501,7 +502,8 @@ class Encoder(nn.Module):
                                        out_channels=block_in,
                                        temb_channels=self.temb_ch,
                                        dropout=dropout)
-        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
+        # self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
+        self.mid.attn_1 = AttnBlock(block_in)
         self.mid.block_2 = ResnetBlock(in_channels=block_in,
                                        out_channels=block_in,
                                        temb_channels=self.temb_ch,
@@ -580,7 +582,8 @@ class Decoder(nn.Module):
                                        out_channels=block_in,
                                        temb_channels=self.temb_ch,
                                        dropout=dropout)
-        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
+        # self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
+        self.mid.attn_1 = AttnBlock(block_in)
         self.mid.block_2 = ResnetBlock(in_channels=block_in,
                                        out_channels=block_in,
                                        temb_channels=self.temb_ch,
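Note: the three hunks above hardcode the mid-block attention to AttnBlock instead of calling make_attn, which in the upstream model.py typically resolves to the xformers-backed MemoryEfficientAttnBlock when xformers is available. Below is a simplified, self-contained sketch of that vanilla spatial self-attention over (B, C, H, W) feature maps; the class name VanillaSpatialAttn and the shapes are illustrative, not the repository's exact implementation.

# Simplified sketch of an AttnBlock-style vanilla spatial self-attention
# (illustrative names and shapes; not the repository's exact implementation).
import torch
import torch.nn as nn

class VanillaSpatialAttn(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        # in_channels must be divisible by 32 for this GroupNorm configuration
        self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
        self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)

    def forward(self, x):
        h_ = self.norm(x)
        q, k, v = self.q(h_), self.k(h_), self.v(h_)

        b, c, h, w = q.shape
        q = q.reshape(b, c, h * w).permute(0, 2, 1)               # (b, hw, c)
        k = k.reshape(b, c, h * w)                                # (b, c, hw)
        attn = torch.softmax(torch.bmm(q, k) * c ** -0.5, dim=2)  # (b, hw, hw)

        v = v.reshape(b, c, h * w)                                # (b, c, hw)
        out = torch.bmm(v, attn.permute(0, 2, 1))                 # (b, c, hw)
        out = out.reshape(b, c, h, w)
        return x + self.proj_out(out)                             # residual connection

# usage: attn = VanillaSpatialAttn(512); y = attn(torch.randn(1, 512, 32, 32))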
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644