Commit 23b892ca authored by Yunho Jeon's avatar Yunho Jeon
Browse files

fix wrong order of hidden and cell state (fix #2839)

parent bf8be1e7
......@@ -17,16 +17,16 @@ class StackedLSTMCell(nn.Module):
for _ in range(self.lstm_num_layers)])
def forward(self, inputs, hidden):
prev_c, prev_h = hidden
next_c, next_h = [], []
prev_h, prev_c = hidden
next_h, next_c = [], []
for i, m in enumerate(self.lstm_modules):
curr_c, curr_h = m(inputs, (prev_c[i], prev_h[i]))
curr_h, curr_c = m(inputs, (prev_h[i], prev_c[i]))
next_c.append(curr_c)
next_h.append(curr_h)
# current implementation only supports batch size equals 1,
# but the algorithm does not necessarily have this limitation
inputs = curr_h[-1].view(1, -1)
return next_c, next_h
return next_h, next_c
class EnasMutator(Mutator):
......@@ -136,7 +136,7 @@ class EnasMutator(Mutator):
self.sample_skip_penalty = 0
def _lstm_next_step(self):
self._c, self._h = self.lstm(self._inputs, (self._c, self._h))
self._h, self._c = self.lstm(self._inputs, (self._h, self._c))
def _mark_anchor(self, key):
self._anchors_hid[key] = self._h[-1]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment