Commit b13d6e51 authored by Dan O'Shea's avatar Dan O'Shea
Browse files

LFADS: Fixing alignment bias bug

parent 7dde2e2f
...@@ -281,7 +281,7 @@ class LFADS(object): ...@@ -281,7 +281,7 @@ class LFADS(object):
"""Create an LFADS model. """Create an LFADS model.
train - a model for training, sampling of posteriors is used train - a model for training, sampling of posteriors is used
posterior_sample_and_average - sample from the posterior, this is used posterior_sample_and_average - sample from the posterior, this is used
for evaluating the expected value of the outputs of LFADS, given a for evaluating the expected value of the outputs of LFADS, given a
specific input, by averaging over multiple samples from the approx specific input, by averaging over multiple samples from the approx
posterior. Also used for the lower bound on the negative posterior. Also used for the lower bound on the negative
...@@ -409,6 +409,12 @@ class LFADS(object): ...@@ -409,6 +409,12 @@ class LFADS(object):
dataset = datasets[name] dataset = datasets[name]
in_mat_cxf = dataset['alignment_matrix_cxf'].astype(np.float32) in_mat_cxf = dataset['alignment_matrix_cxf'].astype(np.float32)
if datasets and 'alignment_bias_c' in datasets[name].keys():
dataset = datasets[name]
print("Using alignment bias provided for dataset:", name)
align_bias_c = dataset['alignment_bias_c'].astype(np.float32)
align_bias_1xc = np.expand_dims(align_bias_c, axis=0)
out_mat_fxc = None out_mat_fxc = None
out_bias_1xc = None out_bias_1xc = None
if in_mat_cxf is not None: if in_mat_cxf is not None:
...@@ -1714,7 +1720,7 @@ class LFADS(object): ...@@ -1714,7 +1720,7 @@ class LFADS(object):
out_dist_params = np.zeros([E_to_process, T, D+D]) out_dist_params = np.zeros([E_to_process, T, D+D])
else: else:
assert False, "NIY" assert False, "NIY"
costs = np.zeros(E_to_process) costs = np.zeros(E_to_process)
nll_bound_vaes = np.zeros(E_to_process) nll_bound_vaes = np.zeros(E_to_process)
nll_bound_iwaes = np.zeros(E_to_process) nll_bound_iwaes = np.zeros(E_to_process)
...@@ -1914,7 +1920,7 @@ class LFADS(object): ...@@ -1914,7 +1920,7 @@ class LFADS(object):
for i, (var, var_eval) in enumerate(zip(all_tf_vars, all_tf_vars_eval)): for i, (var, var_eval) in enumerate(zip(all_tf_vars, all_tf_vars_eval)):
if any(s in include_strs for s in var.name): if any(s in include_strs for s in var.name):
if not isinstance(var_eval, np.ndarray): # for H5PY if not isinstance(var_eval, np.ndarray): # for H5PY
print(var.name, """ is not numpy array, saving as numpy array print(var.name, """ is not numpy array, saving as numpy array
with value: """, var_eval, type(var_eval)) with value: """, var_eval, type(var_eval))
e = np.array(var_eval) e = np.array(var_eval)
print(e, type(e)) print(e, type(e))
......
...@@ -112,7 +112,7 @@ def init_linear(in_size, out_size, do_bias=True, mat_init_value=None, ...@@ -112,7 +112,7 @@ def init_linear(in_size, out_size, do_bias=True, mat_init_value=None,
'Provided mat_init_value must have shape [%d, %d].'%(in_size, out_size)) 'Provided mat_init_value must have shape [%d, %d].'%(in_size, out_size))
if bias_init_value is not None and bias_init_value.shape != (1,out_size): if bias_init_value is not None and bias_init_value.shape != (1,out_size):
raise ValueError( raise ValueError(
'Provided bias_init_value must have shape [1,%d].'%(1,out_size)) 'Provided bias_init_value must have shape [1,%d].'%(out_size,))
if mat_init_value is None: if mat_init_value is None:
stddev = alpha/np.sqrt(float(in_size)) stddev = alpha/np.sqrt(float(in_size))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment