operators.py 3.89 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import tensorflow as tf
import math


def weight_variable(shape):
    """Create a weight variable of the requested shape.

    The initial values are drawn from a truncated normal distribution with a
    standard deviation of 0.1 and wrapped in a ``tf.Variable``.
    """
    return tf.Variable(tf.truncated_normal(shape, stddev=0.1))

def bias_variable(shape):
    """Create a bias variable of the requested shape.

    Every element starts at the constant 0.1; the result is a ``tf.Variable``.
    """
    return tf.Variable(tf.constant(0.1, shape=shape))

def sum_op(inputs):
    """Merge the first optional input into the fixed input by addition.

    ``inputs[0][0]`` is the fixed tensor and ``inputs[1][0]`` the optional one.
    The optional tensor is average-pooled with a square window whose side is
    ``ceil(optional_side / fixed_side)``, passed through a 1x1 convolution
    that maps its channel count (axis 3) to that of the fixed tensor, and the
    result is added to the fixed tensor.

    Both tensors are asserted to have equal sizes along axes 1 and 2.
    NOTE(review): this presumes the NHWC layout that is TensorFlow's default
    ``data_format`` -- confirm against callers.
    """
    base = inputs[0][0]
    extra = inputs[1][0]
    base_dims = base.get_shape().as_list()
    extra_dims = extra.get_shape().as_list()
    assert base_dims[1] == base_dims[2]
    assert extra_dims[1] == extra_dims[2]

    # The same square window is used both as kernel size and as stride.
    window = math.ceil(extra_dims[1] / base_dims[1])
    pool_spec = [1, window, window, 1]
    pooled = tf.nn.avg_pool(extra, ksize=pool_spec, strides=pool_spec, padding='SAME')

    # A 1x1 convolution changes the channel count to that of the fixed tensor.
    projection = weight_variable([1, 1, extra_dims[3], base_dims[3]])
    projected = tf.nn.conv2d(pooled, projection, strides=[1, 1, 1, 1], padding='SAME')
    return base + projected


def conv2d(inputs, size=-1, in_ch=-1, out_ch=-1):
    """Apply a stride-1, 'SAME'-padded 2-D convolution with a square kernel.

    Args:
        inputs: Two-element sequence. ``inputs[0][0]`` is the fixed input
            tensor; ``inputs[1]`` is a possibly empty sequence of optional
            inputs. When it is non-empty, the inputs are merged with
            ``sum_op`` before the convolution.
        size: Kernel side length; must be 1 or 3.
        in_ch: Input-channel dimension of the kernel.
        out_ch: Output-channel dimension of the kernel.

    Returns:
        The result of ``tf.nn.conv2d`` on the selected input.

    Raises:
        ValueError: If ``size`` is not a supported kernel size.
    """
    # Validate first so an unsupported size cannot leave the ops and variable
    # created by sum_op behind in the graph.
    if size not in (1, 3):
        # %s (not %d) so the message also formats for non-numeric sizes.
        raise ValueError("Unknown filter size: %s." % (size,))
    x_input = sum_op(inputs) if inputs[1] else inputs[0][0]
    w_matrix = weight_variable([size, size, in_ch, out_ch])
    return tf.nn.conv2d(x_input, w_matrix, strides=[1, 1, 1, 1], padding='SAME')

def twice_conv2d(inputs, size=-1, in_ch=-1, out_ch=-1):
    """Apply two consecutive stride-1, 'SAME'-padded 2-D convolutions.

    The first kernel has shape ``[1, size, in_ch, int(out_ch / 2)]`` and the
    second ``[size, 1, int(out_ch / 2), out_ch]``.

    Args:
        inputs: Two-element sequence. ``inputs[0][0]`` is the fixed input
            tensor; ``inputs[1]`` is a possibly empty sequence of optional
            inputs. When it is non-empty, the inputs are merged with
            ``sum_op`` before the convolutions.
        size: Length of the long kernel side; must be 3 or 7.
        in_ch: Input-channel dimension of the first kernel.
        out_ch: Output-channel dimension of the second kernel.

    Returns:
        The result of the second ``tf.nn.conv2d``.

    Raises:
        ValueError: If ``size`` is not a supported kernel size.
    """
    # Validate first so an unsupported size cannot leave the ops and variable
    # created by sum_op behind in the graph.
    if size not in (3, 7):
        # %s (not %d) so the message also formats for non-numeric sizes.
        raise ValueError("Unknown filter size: %s." % (size,))
    x_input = sum_op(inputs) if inputs[1] else inputs[0][0]
    # Channel count between the two convolutions, computed once.
    mid_ch = int(out_ch / 2)
    w_matrix1 = weight_variable([1, size, in_ch, mid_ch])
    out = tf.nn.conv2d(x_input, w_matrix1, strides=[1, 1, 1, 1], padding='SAME')
    w_matrix2 = weight_variable([size, 1, mid_ch, out_ch])
    return tf.nn.conv2d(out, w_matrix2, strides=[1, 1, 1, 1], padding='SAME')

def dilated_conv(inputs, size=3, in_ch=-1, out_ch=-1):
    """Apply a 'SAME'-padded atrous (dilated) 2-D convolution with rate 2.

    Args:
        inputs: Two-element sequence. ``inputs[0][0]`` is the fixed input
            tensor; ``inputs[1]`` is a possibly empty sequence of optional
            inputs. When it is non-empty, the inputs are merged with
            ``sum_op`` before the convolution.
        size: Kernel side length; only 3 is supported.
        in_ch: Input-channel dimension of the kernel.
        out_ch: Output-channel dimension of the kernel.

    Returns:
        The result of ``tf.nn.atrous_conv2d`` on the selected input.

    Raises:
        ValueError: If ``size`` is not 3.
    """
    # Validate first so an unsupported size cannot leave the ops and variable
    # created by sum_op behind in the graph.
    if size != 3:
        # %s (not %d) so the message also formats for non-numeric sizes.
        raise ValueError("Unknown filter size: %s." % (size,))
    x_input = sum_op(inputs) if inputs[1] else inputs[0][0]
    w_matrix = weight_variable([size, size, in_ch, out_ch])
    return tf.nn.atrous_conv2d(x_input, w_matrix, rate=2, padding='SAME')

def separable_conv(inputs, size=-1, in_ch=-1, out_ch=-1):
    """Apply a stride-1, 'SAME'-padded depthwise-separable 2-D convolution.

    The depthwise kernel has shape ``[size, size, in_ch, 1]`` (channel
    multiplier 1) and the pointwise kernel ``[1, 1, in_ch, out_ch]``.

    Args:
        inputs: Two-element sequence. ``inputs[0][0]`` is the fixed input
            tensor; ``inputs[1]`` is a possibly empty sequence of optional
            inputs. When it is non-empty, the inputs are merged with
            ``sum_op`` before the convolution.
        size: Depthwise kernel side length; must be 3, 5 or 7.
        in_ch: Input-channel dimension of the depthwise kernel.
        out_ch: Output-channel dimension of the pointwise kernel.

    Returns:
        The result of ``tf.nn.separable_conv2d`` on the selected input.

    Raises:
        ValueError: If ``size`` is not a supported kernel size.
    """
    # Validate first so an unsupported size cannot leave the ops and variable
    # created by sum_op behind in the graph.
    if size not in (3, 5, 7):
        # %s (not %d) so the message also formats for non-numeric sizes.
        raise ValueError("Unknown filter size: %s." % (size,))
    x_input = sum_op(inputs) if inputs[1] else inputs[0][0]
    channel_multiplier = 1
    depth_matrix = weight_variable([size, size, in_ch, channel_multiplier])
    # The pointwise kernel consumes channel_multiplier * in_ch channels.
    point_matrix = weight_variable([1, 1, channel_multiplier * in_ch, out_ch])
    return tf.nn.separable_conv2d(x_input, depth_matrix, point_matrix, strides=[1, 1, 1, 1], padding='SAME')


def avg_pool(inputs, size=-1):
    """Downsample a feature map with 'SAME'-padded average pooling.

    The pooling window and the stride are both ``size`` along axes 1 and 2.

    Args:
        inputs: Two-element sequence. ``inputs[0][0]`` is the fixed input
            tensor; ``inputs[1]`` is a possibly empty sequence of optional
            inputs. When it is non-empty, the inputs are merged with
            ``sum_op`` before pooling.
        size: Window side length and stride; must be 3, 5 or 7.

    Returns:
        The result of ``tf.nn.avg_pool`` on the selected input.

    Raises:
        ValueError: If ``size`` is not a supported window size.
    """
    # Validate first so an unsupported size cannot leave the ops and variable
    # created by sum_op behind in the graph.
    if size not in (3, 5, 7):
        # %s (not %d) so the message also formats for non-numeric sizes.
        raise ValueError("Unknown filter size: %s." % (size,))
    x_input = sum_op(inputs) if inputs[1] else inputs[0][0]
    return tf.nn.avg_pool(x_input, ksize=[1, size, size, 1], strides=[1, size, size, 1], padding='SAME')

def max_pool(inputs, size=-1):
    """Downsample a feature map with 'SAME'-padded max pooling.

    The pooling window and the stride are both ``size`` along axes 1 and 2.

    Args:
        inputs: Two-element sequence. ``inputs[0][0]`` is the fixed input
            tensor; ``inputs[1]`` is a possibly empty sequence of optional
            inputs. When it is non-empty, the inputs are merged with
            ``sum_op`` before pooling.
        size: Window side length and stride; must be 3, 5 or 7.

    Returns:
        The result of ``tf.nn.max_pool`` on the selected input.

    Raises:
        ValueError: If ``size`` is not a supported window size.
    """
    # Validate first so an unsupported size cannot leave the ops and variable
    # created by sum_op behind in the graph.
    if size not in (3, 5, 7):
        # %s (not %d) so the message also formats for non-numeric sizes.
        raise ValueError("Unknown filter size: %s." % (size,))
    x_input = sum_op(inputs) if inputs[1] else inputs[0][0]
    return tf.nn.max_pool(x_input, ksize=[1, size, size, 1], strides=[1, size, size, 1], padding='SAME')


def post_process(inputs, ch_size=-1):
    """Add a fresh bias to the fixed input and apply ReLU.

    Only ``inputs[0][0]`` is used. The bias is a new variable of shape
    ``[ch_size]`` created by ``bias_variable``.
    """
    biased = inputs[0][0] + bias_variable([ch_size])
    return tf.nn.relu(biased)