Commit 206f267e authored by Rick Ho's avatar Rick Ho
Browse files

fix swipe eval

parent 57bdfe88
...@@ -26,11 +26,11 @@ class SwipeGate(NaiveGate): ...@@ -26,11 +26,11 @@ class SwipeGate(NaiveGate):
def forward(self, inp): def forward(self, inp):
score = self.gate(inp) score = self.gate(inp)
_, orig_idx = torch.topk(score, k=self.top_k, dim=-1) orig_score, orig_idx = torch.topk(score, k=self.top_k, dim=-1)
if not self.training: if not self.training:
topk_val = F.softmax(topk_val, dim=-1) topk_val = F.softmax(orig_score, dim=-1)
return topk_idx, topk_val return orig_idx, topk_val
capacity = torch.scalar_tensor(inp.shape[0] * self.top_k, capacity = torch.scalar_tensor(inp.shape[0] * self.top_k,
dtype=torch.long) dtype=torch.long)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment