Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
20da056d
Commit
20da056d
authored
Jun 18, 2018
by
Dan O'Shea
Browse files
Cleaning comments for posterior_push_mean
parent
8c5c60ca
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
17 additions
and
13 deletions
+17
-13
research/lfads/lfads.py
research/lfads/lfads.py
+17
-13
No files found.
research/lfads/lfads.py
View file @
20da056d
...
@@ -490,11 +490,6 @@ class LFADS(object):
...
@@ -490,11 +490,6 @@ class LFADS(object):
pf_pairs_out_fac_Ws
=
zip
(
preds
,
fns_out_fac_Ws
)
pf_pairs_out_fac_Ws
=
zip
(
preds
,
fns_out_fac_Ws
)
pf_pairs_out_fac_bs
=
zip
(
preds
,
fns_out_fac_bs
)
pf_pairs_out_fac_bs
=
zip
(
preds
,
fns_out_fac_bs
)
# def _case_with_no_default(pairs):
# def _default_value_fn():
# with tf.control_dependencies([tf.Assert(False, ["Reached default"])]):
# return tf.identity(pairs[0][1]())
# return tf.case(pairs, _default_value_fn, exclusive=True)
this_in_fac_W
=
tf
.
case
(
pf_pairs_in_fac_Ws
,
exclusive
=
True
)
this_in_fac_W
=
tf
.
case
(
pf_pairs_in_fac_Ws
,
exclusive
=
True
)
this_in_fac_b
=
tf
.
case
(
pf_pairs_in_fac_bs
,
exclusive
=
True
)
this_in_fac_b
=
tf
.
case
(
pf_pairs_in_fac_bs
,
exclusive
=
True
)
this_out_fac_W
=
tf
.
case
(
pf_pairs_out_fac_Ws
,
exclusive
=
True
)
this_out_fac_W
=
tf
.
case
(
pf_pairs_out_fac_Ws
,
exclusive
=
True
)
...
@@ -931,7 +926,7 @@ class LFADS(object):
...
@@ -931,7 +926,7 @@ class LFADS(object):
tvars2
=
\
tvars2
=
\
tf
.
get_collection
(
tf
.
GraphKeys
.
TRAINABLE_VARIABLES
,
tf
.
get_collection
(
tf
.
GraphKeys
.
TRAINABLE_VARIABLES
,
scope
=
'LFADS/z/ic_enc_*'
)
scope
=
'LFADS/z/ic_enc_*'
)
self
.
train_vars
=
tvars
=
tvars1
+
tvars2
self
.
train_vars
=
tvars
=
tvars1
+
tvars2
# train all variables
# train all variables
else
:
else
:
...
@@ -1765,7 +1760,6 @@ class LFADS(object):
...
@@ -1765,7 +1760,6 @@ class LFADS(object):
E
,
T
,
D
=
data_extxd
.
shape
E
,
T
,
D
=
data_extxd
.
shape
E_to_process
=
hps
.
ps_nexamples_to_process
E_to_process
=
hps
.
ps_nexamples_to_process
if
E_to_process
>
E
:
if
E_to_process
>
E
:
print
(
"Setting number of posterior samples to process to : "
,
E
)
E_to_process
=
E
E_to_process
=
E
if
hps
.
ic_dim
>
0
:
if
hps
.
ic_dim
>
0
:
...
@@ -1843,12 +1837,16 @@ class LFADS(object):
...
@@ -1843,12 +1837,16 @@ class LFADS(object):
def
eval_model_runs_push_mean
(
self
,
data_name
,
data_extxd
,
def
eval_model_runs_push_mean
(
self
,
data_name
,
data_extxd
,
ext_input_extxi
=
None
):
ext_input_extxi
=
None
):
"""Returns
the
value
for goodi
es for the
entire
model using the means
"""Returns value
s of inter
es
t
for the model
by p
us
h
ing the means
through
The expected value is taken over hidden (z) variables, namely the initial
The mean values for both initial conditions and the control inputs are
conditions and the control inputs, by pushing the mean values for both
pushed through the model instead of sampling (as is done in
through the model rather than by sampling (as in eval_model_runs_avg_epoch)
eval_model_runs_avg_epoch).
A total of batch_size trials are run at a time.
This is a quick and approximate version of estimating these values instead
of sampling from the posterior many times and then averaging those values of
interest.
Internally, a total of batch_size trials are run through the model at once.
Args:
Args:
data_name: The name of the data dict, to select which in/out matrices
data_name: The name of the data dict, to select which in/out matrices
...
@@ -1859,7 +1857,7 @@ class LFADS(object):
...
@@ -1859,7 +1857,7 @@ class LFADS(object):
shape: # examples x # time steps x # external input dims
shape: # examples x # time steps x # external input dims
Returns:
Returns:
A dictionary with the
averag
ed outputs of the model decoder, namely:
A dictionary with the
estimat
ed outputs of the model decoder, namely:
prior g0 mean, prior g0 variance, approx. posterior mean, approx
prior g0 mean, prior g0 variance, approx. posterior mean, approx
posterior mean, the generator initial conditions, the control inputs (if
posterior mean, the generator initial conditions, the control inputs (if
enabled), the state of the generator, the factors, and the output
enabled), the state of the generator, the factors, and the output
...
@@ -1897,6 +1895,9 @@ class LFADS(object):
...
@@ -1897,6 +1895,9 @@ class LFADS(object):
nll_bound_iwaes
=
np
.
zeros
(
E_to_process
)
nll_bound_iwaes
=
np
.
zeros
(
E_to_process
)
train_steps
=
np
.
zeros
(
E_to_process
)
train_steps
=
np
.
zeros
(
E_to_process
)
# generator that will yield 0:N in groups of per items, e.g.
# (0:per-1), (per:2*per-1), ..., with the last group containing <= per items
# this will be used to feed per=batch_size trials into the model at a time
def
trial_batches
(
N
,
per
):
def
trial_batches
(
N
,
per
):
for
i
in
range
(
0
,
N
,
per
):
for
i
in
range
(
0
,
N
,
per
):
yield
np
.
arange
(
i
,
min
(
i
+
per
,
N
),
dtype
=
np
.
int32
)
yield
np
.
arange
(
i
,
min
(
i
+
per
,
N
),
dtype
=
np
.
int32
)
...
@@ -1949,6 +1950,9 @@ class LFADS(object):
...
@@ -1949,6 +1950,9 @@ class LFADS(object):
model_runs
[
'gen_states'
]
=
gen_states
model_runs
[
'gen_states'
]
=
gen_states
model_runs
[
'factors'
]
=
factors
model_runs
[
'factors'
]
=
factors
model_runs
[
'output_dist_params'
]
=
out_dist_params
model_runs
[
'output_dist_params'
]
=
out_dist_params
# You probably do not want the LL associated values when pushing the mean
# instead of sampling.
model_runs
[
'costs'
]
=
costs
model_runs
[
'costs'
]
=
costs
model_runs
[
'nll_bound_vaes'
]
=
nll_bound_vaes
model_runs
[
'nll_bound_vaes'
]
=
nll_bound_vaes
model_runs
[
'nll_bound_iwaes'
]
=
nll_bound_iwaes
model_runs
[
'nll_bound_iwaes'
]
=
nll_bound_iwaes
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment