Unverified commit 888468dd authored by Patrick von Platen, committed by GitHub

Remove nn sequential (#1086)

* Remove nn sequential

* up
parent 17c2c060
@@ -244,7 +244,9 @@ class CrossAttention(nn.Module):
         self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
         self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
 
-        self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
+        self.to_out = nn.ModuleList([])
+        self.to_out.append(nn.Linear(inner_dim, query_dim))
+        self.to_out.append(nn.Dropout(dropout))
 
     def reshape_heads_to_batch_dim(self, tensor):
         batch_size, seq_len, dim = tensor.shape
@@ -283,7 +285,11 @@ class CrossAttention(nn.Module):
         else:
             hidden_states = self._sliced_attention(query, key, value, sequence_length, dim)
 
-        return self.to_out(hidden_states)
+        # linear proj
+        hidden_states = self.to_out[0](hidden_states)
+        # dropout
+        hidden_states = self.to_out[1](hidden_states)
+        return hidden_states
 
     def _attention(self, query, key, value):
         # TODO: use baddbmm for better performance
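
The two CrossAttention hunks above replace the nn.Sequential output projection with an nn.ModuleList plus explicit indexing in forward. A minimal standalone sketch of why this is a drop-in change (the sizes and tensor shapes below are illustrative, not taken from the commit): both containers register their children as "0" and "1", so checkpoint keys like "to_out.0.weight" are unchanged and the computation is identical.

# Minimal sketch, not part of the commit; sizes are illustrative.
import torch
from torch import nn

inner_dim, query_dim, dropout = 64, 32, 0.0

old_to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))

new_to_out = nn.ModuleList([])
new_to_out.append(nn.Linear(inner_dim, query_dim))
new_to_out.append(nn.Dropout(dropout))

# Same parameter names in both containers, so existing checkpoints still load.
assert set(old_to_out.state_dict()) == set(new_to_out.state_dict()) == {"0.weight", "0.bias"}

# Same output once the weights are shared.
new_to_out.load_state_dict(old_to_out.state_dict())
hidden_states = torch.randn(2, 77, inner_dim)
out_old = old_to_out(hidden_states)
out_new = new_to_out[1](new_to_out[0](hidden_states))  # linear proj, then dropout
assert torch.allclose(out_old, out_new)
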
@@ -354,12 +360,19 @@ class FeedForward(nn.Module):
         super().__init__()
         inner_dim = int(dim * mult)
         dim_out = dim_out if dim_out is not None else dim
-        project_in = GEGLU(dim, inner_dim)
 
-        self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))
+        self.net = nn.ModuleList([])
+        # project in
+        self.net.append(GEGLU(dim, inner_dim))
+        # project dropout
+        self.net.append(nn.Dropout(dropout))
+        # project out
+        self.net.append(nn.Linear(inner_dim, dim_out))
 
     def forward(self, hidden_states):
-        return self.net(hidden_states)
+        for module in self.net:
+            hidden_states = module(hidden_states)
+        return hidden_states
 
 
 # feedforward
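
The FeedForward hunk applies the same pattern: the nn.ModuleList is iterated explicitly in forward, which reproduces what nn.Sequential did while keeping the child modules registered under "0", "1", "2". A minimal sketch below, not part of the commit; the GEGLU class is a simplified stand-in and the sizes are illustrative.

# Minimal sketch, not part of the commit.
import torch
from torch import nn

class GEGLU(nn.Module):  # simplified stand-in for diffusers' GEGLU
    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, hidden_states):
        hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
        return hidden_states * nn.functional.gelu(gate)

dim, mult, dropout = 32, 4, 0.0  # illustrative sizes
inner_dim = dim * mult

old_net = nn.Sequential(GEGLU(dim, inner_dim), nn.Dropout(dropout), nn.Linear(inner_dim, dim))

new_net = nn.ModuleList([])
new_net.append(GEGLU(dim, inner_dim))
new_net.append(nn.Dropout(dropout))
new_net.append(nn.Linear(inner_dim, dim))

# Both containers name their children "0", "1", "2", so the keys match.
new_net.load_state_dict(old_net.state_dict())

hidden_states = torch.randn(2, 77, dim)
out_old = old_net(hidden_states)
out_new = hidden_states
for module in new_net:  # explicit loop replaces nn.Sequential's call chain
    out_new = module(out_new)
assert torch.allclose(out_old, out_new)
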