# log_uniform_sampler.py
import torch
from torch import nn
import numpy as np

class LogUniformSampler(object):
    def __init__(self, range_max, n_sample):
        """
        Reference : https://github.com/tensorflow/tensorflow/blob/r1.10/tensorflow/python/ops/candidate_sampling_ops.py
            `P(class) = (log(class + 2) - log(class + 1)) / log(range_max + 1)`

        The expected count of a class after num_tries draws can be approximated by
        1 - (1 - p)^num_tries; we use the numerically stable form
        -expm1(num_tries * log1p(-p)).

        This implementation fixes num_tries at 2 * n_sample; because duplicates are
        removed, the actual number of negative samples varies from run to run.
        """
        with torch.no_grad():
            self.range_max = range_max
            log_indices = torch.arange(1., range_max+2., 1.).log_()
            self.dist = (log_indices[1:] - log_indices[:-1]) / log_indices[-1]
            # print('P', self.dist.numpy().tolist()[-30:])

            # log expected count of each class among 2 * n_sample draws:
            # log(1 - (1 - p)^(2 * n_sample)), computed stably as
            # log(-expm1(2 * n_sample * log1p(-p))) per the docstring above
            self.log_q = (2 * n_sample * (-self.dist.double()).log1p_()).expm1_().neg_().log_().float()

        self.n_sample = n_sample
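
        # Illustrative check (not in the original file): exp(self.log_q[i]) should
        # approximate the chance that class i appears at least once in
        # 2 * n_sample draws, i.e. 1 - (1 - self.dist[i]) ** (2 * n_sample), e.g.
        #   torch.allclose(self.log_q.exp().double(),
        #                  1 - (1 - self.dist.double()) ** (2 * n_sample), atol=1e-4)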

    def sample(self, labels):
        """
            labels: [b1, b2]
        Return
            true_log_probs: [b1, b2]
            samp_log_probs: [n_sample]
            neg_samples: [n_sample], unique negative samples (the count varies per call)
        """

        # neg_samples = torch.empty(0).long()
        n_sample = self.n_sample
        n_tries = 2 * n_sample

        with torch.no_grad():
            neg_samples = torch.multinomial(self.dist, n_tries, replacement=True).unique()
            device = labels.device
            neg_samples = neg_samples.to(device)
            true_log_probs = self.log_q[labels].to(device)
            samp_log_probs = self.log_q[neg_samples].to(device)
            return true_log_probs, samp_log_probs, neg_samples

def sample_logits(embedding, bias, labels, inputs, sampler):
    """
        embedding: an nn.Embedding layer
        bias: [n_vocab]
        labels: [b1, b2]
        inputs: [b1, b2, n_emb]
        sampler: you may use a LogUniformSampler
    Return
        logits: [b1, b2, 1 + n_sample]; the true-class logit is at index 0 of the
            last dimension, and n_sample is the number of unique negatives drawn
    """
    true_log_probs, samp_log_probs, neg_samples = sampler.sample(labels)
    n_sample = neg_samples.size(0)
    b1, b2 = labels.size(0), labels.size(1)
    all_ids = torch.cat([labels.view(-1), neg_samples])
    all_w = embedding(all_ids)
    true_w = all_w[: -n_sample].view(b1, b2, -1)
    sample_w = all_w[- n_sample:].view(n_sample, -1)

    all_b = bias[all_ids]
    true_b = all_b[: -n_sample].view(b1, b2)
    sample_b = all_b[- n_sample:]

    # mask out negative samples that collide with the true label
    hit = (labels[:, :, None] == neg_samples).detach()

    true_logits = torch.einsum('ijk,ijk->ij',
        true_w, inputs) + true_b - true_log_probs
    samp_logits = torch.einsum('lk,ijk->ijl',
        sample_w, inputs) + sample_b - samp_log_probs
    samp_logits.masked_fill_(hit, -1e30)
    logits = torch.cat([true_logits[:, :, None], samp_logits], -1)

    return logits
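
# A minimal usage sketch (not part of the original file): since the true-class
# logit is concatenated at index 0 of the last dimension, the sampled-softmax
# loss is just the negative log-softmax of the logits at index 0.  The helper
# name `sampled_softmax_nll` is hypothetical and used here only for illustration.
def sampled_softmax_nll(embedding, bias, labels, inputs, sampler):
    logits = sample_logits(embedding, bias, labels, inputs, sampler)
    # [b1, b2]: per-token negative log-likelihood under the sampled softmax
    return -nn.functional.log_softmax(logits, dim=-1)[..., 0]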


# class LogUniformSampler(object):
#     def __init__(self, range_max, unique=False):
#         """
#         Reference : https://github.com/tensorflow/tensorflow/blob/r1.10/tensorflow/python/ops/candidate_sampling_ops.py
#             `P(class) = (log(class + 2) - log(class + 1)) / log(range_max + 1)`
#         """
#         self.range_max = range_max
#         log_indices = torch.arange(1., range_max+2., 1.).log_()
#         self.dist = (log_indices[1:] - log_indices[:-1]) / log_indices[-1]

#         self.unique = unique

#         if self.unique:
#             self.exclude_mask = torch.ByteTensor(range_max).fill_(0)

#     def sample(self, n_sample, labels):
#         pos_sample, new_labels = labels.unique(return_inverse=True)
#         n_pos_sample = pos_sample.size(0)
#         n_neg_sample = n_sample - n_pos_sample

#         if self.unique:
#             self.exclude_mask.index_fill_(0, pos_sample, 1)
#             sample_dist = self.dist.clone().masked_fill_(self.exclude_mask, 0)
#             self.exclude_mask.index_fill_(0, pos_sample, 0)
#         else:
#             sample_dist = self.dist

#         neg_sample = torch.multinomial(sample_dist, n_neg_sample)

#         sample = torch.cat([pos_sample, neg_sample])
#         sample_prob = self.dist[sample]

#         return new_labels, sample, sample_prob


if __name__ == '__main__':
    S, B = 3, 4
    n_vocab = 10000
    n_sample = 5
    H = 32

    labels = torch.LongTensor(S, B).random_(0, n_vocab)

    # sampler = LogUniformSampler(n_vocab, unique=False)
    # new_labels, sample, sample_prob = sampler.sample(n_sample, labels)

    sampler = LogUniformSampler(n_vocab, n_sample)
    # true_probs, samp_probs, neg_samples = sampler.sample(labels)

    # print('true_probs', true_probs.numpy().tolist())
    # print('samp_probs', samp_probs.numpy().tolist())
    # print('neg_samples', neg_samples.numpy().tolist())

    # print('sum', torch.sum(sampler.dist).item())

    # assert torch.all(torch.sort(sample.unique())[0].eq(torch.sort(sample)[0])).item()

    embedding = nn.Embedding(n_vocab, H)
    bias = torch.zeros(n_vocab)
    inputs = torch.Tensor(S, B, H).normal_()

    logits = sample_logits(embedding, bias, labels, inputs, sampler)
    print('logits', logits.detach().numpy().tolist())
    print('logits shape', logits.size())
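
    # Illustrative usage of the hypothetical helper sketched above: per-token
    # negative log-likelihood under the sampled softmax.
    nll = sampled_softmax_nll(embedding, bias, labels, inputs, sampler)
    print('nll shape', nll.size())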