Unverified Commit 5318dd46 authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub
Browse files

Merge pull request #2842 from jyh2986/master

fix wrong order of hidden and cell state
parents e51aca42 23b892ca
...@@ -17,16 +17,16 @@ class StackedLSTMCell(nn.Module): ...@@ -17,16 +17,16 @@ class StackedLSTMCell(nn.Module):
for _ in range(self.lstm_num_layers)]) for _ in range(self.lstm_num_layers)])
def forward(self, inputs, hidden): def forward(self, inputs, hidden):
prev_c, prev_h = hidden prev_h, prev_c = hidden
next_c, next_h = [], [] next_h, next_c = [], []
for i, m in enumerate(self.lstm_modules): for i, m in enumerate(self.lstm_modules):
curr_c, curr_h = m(inputs, (prev_c[i], prev_h[i])) curr_h, curr_c = m(inputs, (prev_h[i], prev_c[i]))
next_c.append(curr_c) next_c.append(curr_c)
next_h.append(curr_h) next_h.append(curr_h)
# current implementation only supports batch size equals 1, # current implementation only supports batch size equals 1,
# but the algorithm does not necessarily have this limitation # but the algorithm does not necessarily have this limitation
inputs = curr_h[-1].view(1, -1) inputs = curr_h[-1].view(1, -1)
return next_c, next_h return next_h, next_c
class EnasMutator(Mutator): class EnasMutator(Mutator):
...@@ -136,7 +136,7 @@ class EnasMutator(Mutator): ...@@ -136,7 +136,7 @@ class EnasMutator(Mutator):
self.sample_skip_penalty = 0 self.sample_skip_penalty = 0
def _lstm_next_step(self): def _lstm_next_step(self):
self._c, self._h = self.lstm(self._inputs, (self._c, self._h)) self._h, self._c = self.lstm(self._inputs, (self._h, self._c))
def _mark_anchor(self, key): def _mark_anchor(self, key):
self._anchors_hid[key] = self._h[-1] self._anchors_hid[key] = self._h[-1]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment