Commit 12a9dce2 authored by David Sussillo, committed by GitHub

Merge pull request #2050 from djoshea/master

LFADS: Fixing alignment bias bug
parents 582bf927 41d700ef
@@ -281,7 +281,7 @@ class LFADS(object):
"""Create an LFADS model. """Create an LFADS model.
train - a model for training, sampling of posteriors is used train - a model for training, sampling of posteriors is used
posterior_sample_and_average - sample from the posterior, this is used posterior_sample_and_average - sample from the posterior, this is used
for evaluating the expected value of the outputs of LFADS, given a for evaluating the expected value of the outputs of LFADS, given a
specific input, by averaging over multiple samples from the approx specific input, by averaging over multiple samples from the approx
posterior. Also used for the lower bound on the negative posterior. Also used for the lower bound on the negative
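For readers unfamiliar with this mode, posterior_sample_and_average is a Monte Carlo estimate: the network is run repeatedly with fresh samples from the approximate posterior and the outputs are averaged. A minimal sketch of that idea follows; `sample_fn` and `n_samples` are hypothetical names for illustration, not part of the LFADS code.

```python
import numpy as np

def posterior_sample_and_average(sample_fn, n_samples=512):
    """Monte Carlo estimate of E[output | input]: average many outputs,
    each produced with a fresh sample from the approximate posterior.
    `sample_fn` is a hypothetical callable that runs the model once and
    returns one sampled output, e.g. rates of shape [T, D]."""
    samples = np.stack([sample_fn() for _ in range(n_samples)], axis=0)
    return samples.mean(axis=0)  # expected value over posterior samples
```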
@@ -409,6 +409,11 @@ class LFADS(object):
        dataset = datasets[name]
        in_mat_cxf = dataset['alignment_matrix_cxf'].astype(np.float32)
+     if datasets and 'alignment_bias_c' in datasets[name].keys():
+       dataset = datasets[name]
+       align_bias_c = dataset['alignment_bias_c'].astype(np.float32)
+       align_bias_1xc = np.expand_dims(align_bias_c, axis=0)
      out_mat_fxc = None
      out_bias_1xc = None
      if in_mat_cxf is not None:
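The added lines are the core of the fix: the per-channel alignment bias `alignment_bias_c` is now read from the dataset dictionary alongside `alignment_matrix_cxf` and reshaped to `[1, C]` so it can broadcast across time steps. Below is a rough sketch of how such a matrix and bias could map channel data to factors; the array names and sizes are made up for illustration, and the exact sign and placement of the bias inside LFADS may differ.

```python
import numpy as np

# Hypothetical dataset dict mirroring the keys read in the diff above.
dataset = {
    'alignment_matrix_cxf': np.random.randn(10, 4).astype(np.float32),  # C x F
    'alignment_bias_c':     np.random.randn(10).astype(np.float32),     # C
}

in_mat_cxf     = dataset['alignment_matrix_cxf'].astype(np.float32)
align_bias_c   = dataset['alignment_bias_c'].astype(np.float32)
align_bias_1xc = np.expand_dims(align_bias_c, axis=0)   # shape [1, C], broadcasts over time

# One plausible read-in: remove the per-channel bias, then project channels -> factors.
data_txc    = np.random.randn(100, 10).astype(np.float32)   # T x C stand-in data
factors_txf = (data_txc - align_bias_1xc) @ in_mat_cxf      # T x F
```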
@@ -1714,7 +1719,7 @@ class LFADS(object):
        out_dist_params = np.zeros([E_to_process, T, D+D])
      else:
        assert False, "NIY"
      costs = np.zeros(E_to_process)
      nll_bound_vaes = np.zeros(E_to_process)
      nll_bound_iwaes = np.zeros(E_to_process)
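Context for the `D+D` above: with a Gaussian output distribution each data dimension needs two parameters (a mean and a variance-like term), so the buffer is twice as wide as in the Poisson case. A hedged sketch of splitting such a buffer back into its two halves; `means` and `variances` are assumed names, not from the repo.

```python
import numpy as np

E_to_process, T, D = 8, 100, 20                        # hypothetical trial/time/data sizes
out_dist_params = np.zeros([E_to_process, T, D + D])   # Gaussian: two parameters per dimension

# Split the concatenated parameters into per-dimension halves.
means     = out_dist_params[..., :D]   # first D entries per time step
variances = out_dist_params[..., D:]   # remaining D entries per time step
assert means.shape == variances.shape == (E_to_process, T, D)
```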
@@ -1914,7 +1919,7 @@ class LFADS(object):
    for i, (var, var_eval) in enumerate(zip(all_tf_vars, all_tf_vars_eval)):
      if any(s in include_strs for s in var.name):
        if not isinstance(var_eval, np.ndarray): # for H5PY
          print(var.name, """ is not numpy array, saving as numpy array
                with value: """, var_eval, type(var_eval))
          e = np.array(var_eval)
          print(e, type(e))
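The block above guards the HDF5 export: any evaluated variable that is not already an ndarray (for example a Python scalar from a scalar tensor) is wrapped with `np.array` before writing, since h5py expects array-like data. A self-contained sketch of that pattern, using a hypothetical `save_vars_to_h5` helper rather than the repo's own saving routine:

```python
import numpy as np
import h5py

def save_vars_to_h5(fname, names, values):
    """Hedged sketch (not the repo's routine): write evaluated variables to
    HDF5, converting anything that is not already an ndarray."""
    with h5py.File(fname, 'w') as hf:
        for name, val in zip(names, values):
            if not isinstance(val, np.ndarray):
                val = np.array(val)                 # h5py wants array-like data
            # TF variable names contain '/', which h5py would treat as groups.
            hf.create_dataset(name.replace('/', '_'), data=val)
```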
@@ -112,7 +112,7 @@ def init_linear(in_size, out_size, do_bias=True, mat_init_value=None,
      'Provided mat_init_value must have shape [%d, %d].'%(in_size, out_size))
  if bias_init_value is not None and bias_init_value.shape != (1,out_size):
    raise ValueError(
-     'Provided bias_init_value must have shape [1,%d].'%(1,out_size))
+     'Provided bias_init_value must have shape [1,%d].'%(out_size,))
  if mat_init_value is None:
    stddev = alpha/np.sqrt(float(in_size))
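This one-line change is the actual bug fix in `init_linear`: the old message used a single `%d` placeholder but supplied the two-element tuple `(1, out_size)`, so building the error message itself raised `TypeError: not all arguments converted during string formatting` and hid the intended `ValueError`. A quick illustration of the two forms:

```python
out_size = 4

# Old form: one %d placeholder, two arguments -> TypeError while formatting,
# which masks the ValueError the caller was supposed to see.
try:
    msg = 'Provided bias_init_value must have shape [1,%d].' % (1, out_size)
except TypeError as err:
    print(err)  # not all arguments converted during string formatting

# Fixed form: a one-element tuple matches the single placeholder.
msg = 'Provided bias_init_value must have shape [1,%d].' % (out_size,)
print(msg)  # Provided bias_init_value must have shape [1,4].
```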