"src/vscode:/vscode.git/clone" did not exist on "43346adc1ffa9051fc71be9af33fd982ee14c383"
Commit 206f267e authored by Rick Ho's avatar Rick Ho
Browse files

fix swipe eval

parent 57bdfe88
......@@ -26,15 +26,15 @@ class SwipeGate(NaiveGate):
def forward(self, inp):
score = self.gate(inp)
_, orig_idx = torch.topk(score, k=self.top_k, dim=-1)
orig_score, orig_idx = torch.topk(score, k=self.top_k, dim=-1)
if not self.training:
topk_val = F.softmax(topk_val, dim=-1)
return topk_idx, topk_val
topk_val = F.softmax(orig_score, dim=-1)
return orig_idx, topk_val
capacity = torch.scalar_tensor(inp.shape[0] * self.top_k,
dtype=torch.long)
topk_idxs = []
topk_vals = []
idx_x = torch.arange(inp.shape[0], device=inp.device)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment