Commit 5f342527 authored by Myle Ott's avatar Myle Ott Committed by Facebook Github Bot
Browse files

Update beam search code to support torch.bool change

Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/797

Differential Revision: D16617067

Pulled By: myleott

fbshipit-source-id: 52e3aeb98d6e3b55ff9154b784028bf13eabfe38
parent ccb5dea5
...@@ -212,7 +212,7 @@ class Sampling(Search): ...@@ -212,7 +212,7 @@ class Sampling(Search):
# trim the words that are not in top-P by setting their probabilities # trim the words that are not in top-P by setting their probabilities
# to 0, so that they would not be sampled later. # to 0, so that they would not be sampled later.
trim_mask = truncated_mask.bitwise_not() trim_mask = (~truncated_mask)
trimed_probs = truncated_probs.masked_fill_(trim_mask, 0) trimed_probs = truncated_probs.masked_fill_(trim_mask, 0)
return trimed_probs, truncated_indices return trimed_probs, truncated_indices
......
...@@ -163,7 +163,7 @@ class SequenceGenerator(object): ...@@ -163,7 +163,7 @@ class SequenceGenerator(object):
# For example, suppose we're sampling and have already finalized 2/5 # For example, suppose we're sampling and have already finalized 2/5
# samples. Then the blacklist would mark 2 positions as being ignored, # samples. Then the blacklist would mark 2 positions as being ignored,
# so that we only finalize the remaining 3 samples. # so that we only finalize the remaining 3 samples.
blacklist = src_tokens.new(bsz, beam_size).byte().fill_(0) blacklist = src_tokens.new_zeros(bsz, beam_size).eq(-1) # forward and backward-compatible False mask
# list of completed sentences # list of completed sentences
finalized = [[] for i in range(bsz)] finalized = [[] for i in range(bsz)]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment