Commit 40f7870f authored by mashun1

Update util.py

parent 42f6a485
@@ -225,7 +225,7 @@ class MemoryEfficientCrossAttention(nn.Module):
         self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
         self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
-        self.attention_op: Optional[Any] = None
+        self.attention_op: Optional[Any] = xformers.ops.fmha.MemoryEfficientAttentionFlashAttentionOp
 
     def forward(self, x, context=None, mask=None):
         q = self.to_q(x)
@@ -441,7 +441,7 @@ class MemoryEfficientCrossAttention_attemask(nn.Module):
         self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
         self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
-        self.attention_op: Optional[Any] = None
+        self.attention_op: Optional[Any] = xformers.ops.fmha.MemoryEfficientAttentionFlashAttentionOp
 
     def forward(self, x, context=None, mask=None):
         q = self.to_q(x)
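For context: in the standard latent-diffusion MemoryEfficientCrossAttention, attention_op is forwarded as the op argument of xformers.ops.memory_efficient_attention, so this commit pins the kernel to FlashAttention in both classes instead of letting xformers auto-select one. Below is a minimal sketch of how the pinned op is consumed; the forward body is assumed from the usual ldm implementation and is not part of this diff, which changes only the attention_op assignment.

# Sketch of how the pinned attention_op is consumed downstream. The module
# shape follows the standard latent-diffusion MemoryEfficientCrossAttention;
# only the attention_op line appears in the commit above.
from typing import Any, Optional

import torch.nn as nn
import xformers
import xformers.ops


class MemoryEfficientCrossAttention(nn.Module):
    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = context_dim if context_dim is not None else query_dim
        self.heads = heads
        self.dim_head = dim_head

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
        # The change in this commit: force the FlashAttention kernel instead of
        # letting xformers dispatch automatically (the previous value was None).
        self.attention_op: Optional[Any] = xformers.ops.fmha.MemoryEfficientAttentionFlashAttentionOp

    def forward(self, x, context=None, mask=None):
        q = self.to_q(x)
        context = context if context is not None else x
        k = self.to_k(context)
        v = self.to_v(context)

        b, _, _ = q.shape
        # (b, n, h*d) -> (b*h, n, d), the layout memory_efficient_attention expects
        q, k, v = map(
            lambda t: t.unsqueeze(3)
            .reshape(b, t.shape[1], self.heads, self.dim_head)
            .permute(0, 2, 1, 3)
            .reshape(b * self.heads, t.shape[1], self.dim_head)
            .contiguous(),
            (q, k, v),
        )

        # attention_op is forwarded as `op=`; with the Flash op pinned, xformers
        # raises on unsupported inputs instead of falling back to another kernel.
        out = xformers.ops.memory_efficient_attention(
            q, k, v, attn_bias=None, op=self.attention_op
        )

        # (b*h, n, d) -> (b, n, h*d), then project back to query_dim
        out = (
            out.unsqueeze(0)
            .reshape(b, self.heads, out.shape[1], self.dim_head)
            .permute(0, 2, 1, 3)
            .reshape(b, out.shape[1], self.heads * self.dim_head)
        )
        return self.to_out(out)

One trade-off worth noting: with op left as None, xformers picks whichever kernel supports the given dtype, shapes, and device; pinning MemoryEfficientAttentionFlashAttentionOp makes unsupported inputs (e.g. fp32 tensors, which the Flash kernel does not accept) fail outright rather than silently dispatching to a slower fallback.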