# Copyright (c) 2019, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import tensorflow as tf
from vae.models.layers import DenseFromSparse

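# Mode flags for _VAEGraph.call(): Keras layers receive training=True and
# reparametrization noise is injected only in TRAINING mode.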
TRAINING = 0
VALIDATION = 1
QUERY = 2


class _VAEGraph(tf.keras.Model):
    def __init__(self, encoder_dims, decoder_dims, activation='tanh'):
        super(_VAEGraph, self).__init__()
        if encoder_dims[-1] != decoder_dims[0]:
            raise ValueError(
                "encoder/decoder dims mismatch: encoder ends with {} units "
                "but decoder starts with {}".format(encoder_dims[-1], decoder_dims[0]))
        self.input_dim = encoder_dims[0]
        self.output_dim = decoder_dims[-1]
        self.activation = tf.nn.tanh if activation == 'tanh' else tf.nn.relu
        self.encoder = self.encoder_model(encoder_dims[1:])
        self.decoder = self.decoder_model(decoder_dims[1:])

    def call(self, inputs: tf.SparseTensor, mode):
        """Run the VAE and return handles to its outputs.

        :param inputs: batch_size * items_count sparse tensor.
        :param mode: One of TRAINING, VALIDATION or QUERY.
        :return: Tuple of 3 tensors:
            1. decoder output: batch_size * items_count tensor
            2. latent_mean: tensor of latent means produced between encoder
               and decoder, of size batch_size * latent_dim
            3. latent_log_var: tensor of logarithms of the latent variances,
               of size batch_size * latent_dim
        """

        # The encoder emits a (batch, 2, latent_dim) tensor: row 0 holds the
        # latent means, row 1 the log-variances.
        latent_all = self.encoder(inputs, training=(mode == TRAINING))
        latent_mean = latent_all[:, 0]
        latent_log_var = latent_all[:, 1]
        latent_std = tf.exp(0.5 * latent_log_var)

        # Reparametrization trick: z = mu + sigma * epsilon with
        # epsilon ~ N(0, I). The noise term is scaled by 1 in TRAINING mode
        # and by 0 otherwise, so validation and query runs are deterministic.
        batch = tf.shape(latent_mean)[0]
        dim = tf.shape(latent_mean)[1]
        epsilon = tf.random_normal(shape=(batch, dim))
        decoder_input = latent_mean + int(mode == TRAINING) * latent_std * epsilon

        decoder_output = self.decoder(decoder_input, training=(mode == TRAINING))

        return decoder_output, latent_mean, latent_log_var
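
    # Downstream, these three outputs are typically combined into the standard
    # VAE objective (a sketch only: the real loss lives outside this file, so
    # the multinomial likelihood and the unit KL weight are assumptions):
    #
    #   logits, mu, log_var = graph(batch, TRAINING)
    #   log_softmax = tf.nn.log_softmax(logits)
    #   neg_ll = -tf.reduce_sum(log_softmax * tf.sparse.to_dense(batch), axis=-1)
    #   kl = 0.5 * tf.reduce_sum(tf.exp(log_var) + mu ** 2 - 1.0 - log_var, axis=-1)
    #   loss = tf.reduce_mean(neg_ll + kl)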

    def encoder_model(self, dims):
        assert dims
        # The final layer is widened to 2 * last units so that it emits the
        # latent means and log-variances in a single pass; the trailing
        # Reshape splits them into rows 0 and 1.
        last = dims[-1]
        dims[-1] = 2 * last
        layers = tf.keras.layers
        # l2_regularizer is a factory: it must be called with a scale to
        # produce the callable Keras expects. The 0.01 weight is an assumed
        # placeholder, not a value taken from the original configuration.
        l2 = tf.contrib.layers.l2_regularizer(scale=0.01)
        return tf.keras.Sequential(
            [DenseFromSparse(
                    dims[0],
                    activation=self.activation,
                    name="encoder_{}".format(dims[0]),
                    kernel_initializer=tf.contrib.layers.xavier_initializer(),
                    bias_initializer=tf.truncated_normal_initializer(stddev=0.001),
                    kernel_regularizer=l2)
            ] + [
                layers.Dense(
                    d,
                    activation=self.activation,
                    name="encoder_{}".format(d),
                    kernel_initializer=tf.contrib.layers.xavier_initializer(),
                    bias_initializer=tf.truncated_normal_initializer(stddev=0.001),
                    kernel_regularizer=l2)
                for d in dims[1:-1]
            ] + [
                # The last Dense layer is linear: means and log-variances
                # must not be squashed by an activation.
                layers.Dense(
                    dims[-1],
                    name="encoder_{}".format(dims[-1]),
                    kernel_initializer=tf.contrib.layers.xavier_initializer(),
                    bias_initializer=tf.truncated_normal_initializer(stddev=0.001),
                    kernel_regularizer=l2)
            ] + [layers.Reshape(target_shape=(2, last))])

    def decoder_model(self, dims):
        assert dims
        layers = tf.keras.layers
        # Same assumed 0.01 regularization weight as in encoder_model.
        l2 = tf.contrib.layers.l2_regularizer(scale=0.01)
        return tf.keras.Sequential([
            layers.Dense(
                d,
                activation=self.activation,
                name="decoder_{}".format(d),
                kernel_initializer=tf.contrib.layers.xavier_initializer(),
                bias_initializer=tf.truncated_normal_initializer(stddev=0.001),
                kernel_regularizer=l2) for d in dims[:-1]
            ] + [
                # The output layer is linear; its logits are consumed by the
                # loss outside this model.
                layers.Dense(
                    dims[-1],
                    name="decoder_{}".format(dims[-1]),
                    kernel_initializer=tf.contrib.layers.xavier_initializer(),
                    bias_initializer=tf.truncated_normal_initializer(stddev=0.001),
                    kernel_regularizer=l2)
            ])
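

if __name__ == "__main__":
    # Smoke test (an illustrative sketch, not part of the original project;
    # it assumes DenseFromSparse accepts a tf.SparseTensor input, as the
    # call() signature suggests, and uses TF 1.x session APIs to match the
    # tf.contrib usage above).
    items = 8
    graph = _VAEGraph(encoder_dims=[items, 4, 2], decoder_dims=[2, 4, items])
    # A batch of two users with a handful of item interactions.
    batch = tf.SparseTensor(indices=[[0, 1], [0, 5], [1, 2]],
                            values=[1.0, 1.0, 1.0],
                            dense_shape=[2, items])
    logits, mean, log_var = graph(batch, mode=QUERY)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        print("decoder output shape:", sess.run(tf.shape(logits)))  # [2 8]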