"examples/git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "66f661df0a102aeebbeaa5336599fbfa0467e5b0"
Commit 1f96d284 authored by Liang Wang's avatar Liang Wang Committed by Facebook Github Bot
Browse files

Fix topp sampling issues (#882)

Summary:
Two issues here:

1. `last_included` should be the last included index `cumsum_mask[:, :, -1:]` instead of `cumsum_mask[:, :, :1]`  (which is either 0 or 1);

2. If `--no-repeat-ngram-size` is set, the sum of `probs` may be less than 1; we need to re-normalize it so that it is a valid probability distribution

The following code reproduces these issues:

```
import torch
import numpy as np

def _sample_topp(probs):
    """Verbatim copy of fairseq's ``_sample_topp`` (fairseq/search.py),
    kept here to demonstrate the two bugs described in the commit message.

    Reads the module-level global ``sampling_topp`` (the nucleus-sampling
    threshold p) rather than taking it as a parameter.

    Args:
        probs: probability tensor indexed as (dim0, dim1, vocab); the code
            slices ``[:, :, ...]`` and reduces over ``dim=2``, so rank 3 is
            assumed — presumably (bsz, beam, vocab), per fairseq convention.

    Returns:
        (trimed_probs, truncated_indices): the sorted probabilities with
        words outside the top-P set zeroed, and their vocab indices.
    """
    # =====  Code from  fairseq/search.py _sample_topp ======

    # sort the last dimension (vocab dimension) in descending order
    sorted_probs, sorted_indices = probs.sort(descending=True)

    # compute a mask to indicate the words to be included in the top-P set.
    cumsum_probs = sorted_probs.cumsum(dim=2)
    mask = cumsum_probs.lt(sampling_topp)

    # note that mask was computed by 'lt'. One more word needs to be included
    # so that the cumulative probability mass can exceed p.
    cumsum_mask = mask.cumsum(dim=2)
    # BUG (issue 1 in the commit message): this takes the FIRST column of the
    # running count — always 0 or 1 — instead of the index of the last
    # included word; the fix is ``cumsum_mask[:, :, -1:]``.
    last_included = cumsum_mask[:, :, :1]
    mask = mask.scatter_(2, last_included, 1)

    # truncate unnecessary dims.
    max_dim = last_included.max()
    truncated_mask = mask[:, :, :max_dim + 1]
    truncated_probs = sorted_probs[:, :, :max_dim + 1]
    truncated_indices = sorted_indices[:, :, :max_dim + 1]

    # trim the words that are not in top-P by setting their probabilities
    # to 0, so that they would not be sampled later.
    trim_mask = 1 - truncated_mask
    # NOTE: ``masked_fill_`` is in-place, and ``truncated_probs`` is a slice
    # view, so this also mutates ``sorted_probs``.
    trimed_probs = truncated_probs.masked_fill_(trim_mask, 0)
    return trimed_probs, truncated_indices

    # ========================================================

    # ========================================================

if __name__ == '__main__':
    # Fixed seeds so the repro is deterministic.
    np.random.seed(1234)
    torch.manual_seed(1234)

    # Nucleus threshold, read as a module-level global by _sample_topp.
    sampling_topp = 0.9

    # A single (1, 1, vocab=10) probability distribution.
    logits = torch.randn(1, 1, 10)
    probs = torch.softmax(logits, dim=-1)
    # probs = tensor([0.0545, 0.0779, 0.0189, 0.0647, 0.0282, 0.0862, 0.0656, 0.1041, 0.0399, 0.4600])
    print('probs =', probs[0][0])

    trimed_probs, truncated_indices = _sample_topp(probs)

    # Cumulative mass of the words kept by top-P; its final entry should
    # reach the threshold p if the top-P set were computed correctly.
    cum_probs = trimed_probs.cumsum(dim=-1)[0][0]
    # cumsum = tensor([0.4600, 0.5641])
    print('cumsum =', cum_probs)
    # Will throw AssertionError with the buggy _sample_topp above.
    assert float(cum_probs[-1]) >= sampling_topp

```
Pull Request resolved: https://github.com/pytorch/fairseq/pull/882

Differential Revision: D16409269

Pulled By: xingz9

fbshipit-source-id: 94b1122eed50c656057b64e22af6f4a6ea7a68af
parent f812e529
...
@@ -202,7 +202,8 @@ class Sampling(Search):
         # note that mask was computed by 'lt'. One more word needs to be included
         # so that the cumulative probability mass can exceed p.
         cumsum_mask = mask.cumsum(dim=2)
-        last_included = cumsum_mask[:, :, :1]
+        last_included = cumsum_mask[:, :, -1:]
+        last_included.clamp_(0, mask.size()[2] - 1)
         mask = mask.scatter_(2, last_included, 1)
         # truncate unnecessary dims.
...
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment