Commit 0c1d2961 authored by Dan O'Shea's avatar Dan O'Shea
Browse files

Fixing init_linear with do_bias false, formatting

parent b1354256
...@@ -366,11 +366,11 @@ class LFADS(object): ...@@ -366,11 +366,11 @@ class LFADS(object):
if datasets and 'alignment_matrix_cxf' in datasets[name].keys(): if datasets and 'alignment_matrix_cxf' in datasets[name].keys():
dataset = datasets[name] dataset = datasets[name]
if hps.do_train_readin: if hps.do_train_readin:
print("Initializing trainable readin matrix with alignment matrix \ print("Initializing trainable readin matrix with alignment matrix" \
provided for dataset:", name) " provided for dataset:", name)
else: else:
print("Setting non-trainable readin matrix to alignment matrix \ print("Setting non-trainable readin matrix to alignment matrix" \
provided for dataset:", name) " provided for dataset:", name)
in_mat_cxf = dataset['alignment_matrix_cxf'].astype(np.float32) in_mat_cxf = dataset['alignment_matrix_cxf'].astype(np.float32)
if in_mat_cxf.shape != (data_dim, factors_dim): if in_mat_cxf.shape != (data_dim, factors_dim):
raise ValueError("""Alignment matrix must have dimensions %d x %d raise ValueError("""Alignment matrix must have dimensions %d x %d
...@@ -380,11 +380,11 @@ class LFADS(object): ...@@ -380,11 +380,11 @@ class LFADS(object):
if datasets and 'alignment_bias_c' in datasets[name].keys(): if datasets and 'alignment_bias_c' in datasets[name].keys():
dataset = datasets[name] dataset = datasets[name]
if hps.do_train_readin: if hps.do_train_readin:
print("Initializing trainable readin bias with alignment bias \ print("Initializing trainable readin bias with alignment bias " \
provided for dataset:", name) "provided for dataset:", name)
else: else:
print("Setting non-trainable readin bias to alignment bias \ print("Setting non-trainable readin bias to alignment bias " \
provided for dataset:", name) "provided for dataset:", name)
align_bias_c = dataset['alignment_bias_c'].astype(np.float32) align_bias_c = dataset['alignment_bias_c'].astype(np.float32)
align_bias_1xc = np.expand_dims(align_bias_c, axis=0) align_bias_1xc = np.expand_dims(align_bias_c, axis=0)
if align_bias_1xc.shape[1] != data_dim: if align_bias_1xc.shape[1] != data_dim:
...@@ -398,7 +398,9 @@ class LFADS(object): ...@@ -398,7 +398,9 @@ class LFADS(object):
in_bias_1xf = -np.dot(align_bias_1xc, in_mat_cxf) in_bias_1xf = -np.dot(align_bias_1xc, in_mat_cxf)
if hps.do_train_readin: if hps.do_train_readin:
# only add to IO transformations collection only if we want it to be learnable, because IO_transformations collection will be trained when do_train_io_only # only add to IO transformations collection only if we want it to be
# learnable, because IO_transformations collection will be trained
# when do_train_io_only
collections_readin=['IO_transformations'] collections_readin=['IO_transformations']
else: else:
collections_readin=None collections_readin=None
......
...@@ -164,7 +164,8 @@ def init_linear(in_size, out_size, do_bias=True, mat_init_value=None, ...@@ -164,7 +164,8 @@ def init_linear(in_size, out_size, do_bias=True, mat_init_value=None,
else: else:
# construct a non-learnable vector of zeros as the bias # construct a non-learnable vector of zeros as the bias
b = tf.get_variable(bname, [1, out_size], b = tf.get_variable(bname, [1, out_size],
initializer=tf.zeros_initializer(), trainable=False) initializer=tf.zeros_initializer(),
collections=b_collections, trainable=False)
return (w, b) return (w, b)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment