Commit b903d3d3 authored by patil-suraj

fix einsum

parent a9374a02
@@ -6,7 +6,7 @@ import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from einops import repeat, rearrange, einsum
+from einops import repeat, rearrange
 from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
@@ -180,7 +180,7 @@ class CrossAttention(nn.Module):
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
-        sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
+        sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale
         if exists(mask):
             mask = rearrange(mask, 'b ... -> b (...)')
@@ -191,7 +191,7 @@ class CrossAttention(nn.Module):
         # attention, what we cannot get enough of
         attn = sim.softmax(dim=-1)
-        out = einsum('b i j, b j d -> b i d', attn, v)
+        out = torch.einsum('b i j, b j d -> b i d', attn, v)
         out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
         return self.to_out(out)
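
Note on the fix: the diff replaces the unqualified einsum (previously imported from einops) with torch.einsum. This is likely necessary because einops.einsum takes the operand tensors first and the pattern string last (and does not exist at all in older einops releases), whereas torch.einsum takes the equation string first, which is the argument order these call sites use. Below is a minimal standalone sketch of the attention math being fixed; the shapes and variable names are illustrative, not taken from the commit.

# Minimal sketch of the attention computation fixed in this commit.
# Shapes are illustrative assumptions; q, k, v, scale mirror the diff
# but this is not the full CrossAttention module.
import torch

b_h, n, d = 2 * 4, 16, 32          # (batch * heads), sequence length, head dim
q = torch.randn(b_h, n, d)
k = torch.randn(b_h, n, d)
v = torch.randn(b_h, n, d)
scale = d ** -0.5

# torch.einsum takes the equation string first, then the operands --
# the argument order the original einops-based calls assumed.
sim = torch.einsum('b i d, b j d -> b i j', q, k) * scale   # (b*h, n, n) similarity
attn = sim.softmax(dim=-1)                                  # attention weights
out = torch.einsum('b i j, b j d -> b i d', attn, v)        # (b*h, n, d) output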