"vscode:/vscode.git/clone" did not exist on "2a7553ce09cf1ae5a93a290acd5109f2327cd6da"
VariationalAutoencoderRunner.py 1.7 KB
Newer Older
Jiří Vahala's avatar
Jiří Vahala committed
1
2
3
4
5
6
7
8
9
10
11
import numpy as np

import sklearn.preprocessing as prep
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

from autoencoder.autoencoder_models.VariationalAutoencoder import VariationalAutoencoder

# Load MNIST from (or download into) the local 'MNIST_data' directory, with
# labels encoded as one-hot vectors.
mnist = input_data.read_data_sets(train_dir='MNIST_data', one_hot=True)


12
13
def min_max_scale(X_train, X_test):
    """Rescale the train and test splits with statistics learned from the train split only.

    A ``sklearn.preprocessing.MinMaxScaler`` is fitted on ``X_train``. The
    same fitted scaler then transforms both ``X_train`` and ``X_test``, so no
    information from the test split leaks into the scaling.

    Args:
        X_train: Training feature matrix used to fit the scaler.
        X_test: Held-out feature matrix, transformed with the train statistics.

    Returns:
        A ``(scaled_train, scaled_test)`` tuple.
    """
    scaler = prep.MinMaxScaler()
    scaler.fit(X_train)
    # Tuple elements are evaluated left to right: train first, then test.
    return scaler.transform(X_train), scaler.transform(X_test)


def get_random_block_from_data(data, batch_size):
    """Return ``batch_size`` consecutive rows of ``data`` starting at a random offset.

    The start index is drawn uniformly from every position that still leaves
    room for a full block, i.e. ``0 .. len(data) - batch_size`` inclusive.

    Args:
        data: Any sliceable sequence supporting ``len()``, e.g. a NumPy array
            or a list.
        batch_size: Number of consecutive rows to return.

    Returns:
        ``data[start_index:start_index + batch_size]``.

    Raises:
        ValueError: If ``batch_size`` is larger than ``len(data)``.
    """
    n_rows = len(data)
    if batch_size > n_rows:
        raise ValueError(
            "batch_size (%d) exceeds the number of rows in data (%d)"
            % (batch_size, n_rows))
    # np.random.randint excludes its upper bound. The +1 makes the final valid
    # start index reachable, and lets batch_size == len(data) return all rows.
    start_index = np.random.randint(0, n_rows - batch_size + 1)
    return data[start_index:start_index + batch_size]


24
# Scale features with statistics taken from the training split only.
X_train, X_test = min_max_scale(mnist.train.images, mnist.test.images)

# Training hyper-parameters.
n_samples = int(mnist.train.num_examples)
training_epochs = 20
batch_size = 128
display_step = 1  # print progress every `display_step` epochs

autoencoder = VariationalAutoencoder(n_input=784,
                                     n_hidden=200,
                                     optimizer=tf.train.AdamOptimizer(learning_rate=0.001),
                                     gaussian_sample_size=128)

# Mini-batches per epoch. This does not change between epochs, so compute it
# once. Floor division gives the same value as int(n_samples / batch_size)
# on both Python 2 and Python 3.
total_batch = n_samples // batch_size

for epoch in range(training_epochs):
    avg_cost = 0.
    # Loop over all batches; each batch is a random contiguous slice of X_train.
    for _ in range(total_batch):
        batch_xs = get_random_block_from_data(X_train, batch_size)

        # Fit training using batch data
        cost = autoencoder.partial_fit(batch_xs)
        # Accumulate this batch's contribution to the epoch's per-sample average.
        # NOTE(review): assumes partial_fit returns the batch's summed cost;
        # confirm against VariationalAutoencoder.
        avg_cost += cost / n_samples * batch_size

    # Display logs per epoch step. A single-argument print(...) behaves the same
    # on Python 2 and 3; the output matches the original comma-separated print.
    if epoch % display_step == 0:
        print("Epoch: {:04d} cost= {:.9f}".format(epoch + 1, avg_cost))

print("Total cost: " + str(autoencoder.calc_total_cost(X_test)))