"cacheflow/core/server.py" did not exist on "cfae35b861c5fc0c9f3689f99c7aba2e4501beb8"
Unverified commit 3bde773d authored by Myle Ott, committed by GitHub
Browse files

More fixes for recent PyTorch (incl. topk issue) (#113)

parent 21b8fb5c
......@@ -24,7 +24,7 @@ If you use the code in your paper, then please cite it as:
* Python version 3.6
* A [PyTorch installation](http://pytorch.org/)
Currently fairseq-py requires PyTorch version >= 0.3.0.
Currently fairseq-py requires PyTorch version >= 0.4.0.
Please follow the instructions here: https://github.com/pytorch/pytorch#installation.
If you use Docker make sure to increase the shared memory size either with `--ipc=host` or `--shm-size` as command line
......
......@@ -271,8 +271,8 @@ class SequenceGenerator(object):
if step < maxlen:
if prefix_tokens is not None and step < prefix_tokens.size(1):
probs_slice = probs.view(bsz, -1, probs.size(-1))[:, 0, :]
cand_scores = probs_slice.gather(
dim=1,
cand_scores = torch.gather(
probs_slice, dim=1,
index=prefix_tokens[:, step].view(-1, 1).data
).expand(-1, cand_size)
cand_indices = prefix_tokens[:, step].view(-1, 1).expand(bsz, cand_size).data
......@@ -341,17 +341,17 @@ class SequenceGenerator(object):
# get the top beam_size active hypotheses, which are just the hypos
# with the smallest values in active_mask
active_hypos, _ignore = buffer('active_hypos'), buffer('_ignore')
active_mask.topk(
k=beam_size, dim=1, largest=False,
out=(_ignore, active_hypos),
torch.topk(
active_mask, k=beam_size, dim=1, largest=False,
out=(_ignore, active_hypos)
)
active_bbsz_idx = buffer('active_bbsz_idx')
cand_bbsz_idx.gather(
dim=1, index=active_hypos,
torch.gather(
cand_bbsz_idx, dim=1, index=active_hypos,
out=active_bbsz_idx,
)
active_scores = cand_scores.gather(
dim=1, index=active_hypos,
active_scores = torch.gather(
cand_scores, dim=1, index=active_hypos,
out=scores[:, step].view(bsz, beam_size),
)
active_bbsz_idx = active_bbsz_idx.view(-1)
......
......@@ -127,7 +127,10 @@ def main(args):
hypo['positional_scores'].tolist(),
))
))
print('A-{}\t{}'.format(sample_id, ' '.join(map(str, alignment))))
print('A-{}\t{}'.format(
sample_id,
' '.join(map(lambda x: str(utils.item(x)), alignment))
))
# Score only the top hypothesis
if i == 0:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment