Commit 430905d7 authored by Iurii Zdebskyi's avatar Iurii Zdebskyi Committed by Facebook Github Bot
Browse files

Changed tensor comparison return type from uint8 to bool (#21113)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/21113
ghimport-source-id: 9c4ba63457a72bfc41894387e0b01be3fd9a9baf

Test Plan: Imported from OSS

Differential Revision: D15552204

Pulled By: izdeby

fbshipit-source-id: a608213668649d058e22b510d7755cb99e7d0037
parent 4abadbdf
......@@ -39,7 +39,7 @@ class MeanPoolGatingNetwork(torch.nn.Module):
if encoder_padding_mask is not None:
encoder_out = encoder_out.clone() # required because of transpose above
encoder_out[encoder_padding_mask] = 0
ntokens = torch.sum(1 - encoder_padding_mask, dim=1, keepdim=True)
ntokens = torch.sum(~encoder_padding_mask, dim=1, keepdim=True)
x = torch.sum(encoder_out, dim=1) / ntokens.type_as(encoder_out)
else:
x = torch.mean(encoder_out, dim=1)
......
......@@ -212,7 +212,7 @@ class Sampling(Search):
# trim the words that are not in top-P by setting their probabilities
# to 0, so that they would not be sampled later.
trim_mask = 1 - truncated_mask
trim_mask = truncated_mask.bitwise_not()
trimed_probs = truncated_probs.masked_fill_(trim_mask, 0)
return trimed_probs, truncated_indices
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment