Unverified Commit 2e21c80e authored by George Sterpu's avatar George Sterpu Committed by GitHub
Browse files

Update beam_search_v1.py

parent ef6cb209
...@@ -128,9 +128,13 @@ class SequenceBeamSearch(object): ...@@ -128,9 +128,13 @@ class SequenceBeamSearch(object):
"""Beam search for sequences with highest scores.""" """Beam search for sequences with highest scores."""
state, state_shapes = self._create_initial_state(initial_ids, initial_cache) state, state_shapes = self._create_initial_state(initial_ids, initial_cache)
finished_state = tf.while_loop( finished_state = tf.nest.map_structure(
self._continue_search, self._search_step, loop_vars=[state], tf.stop_gradient,
shape_invariants=[state_shapes], parallel_iterations=1, back_prop=False) tf.while_loop(self._continue_search,
self._search_step,
loop_vars=[state],
shape_invariants=[state_shapes],
parallel_iterations=1))
finished_state = finished_state[0] finished_state = finished_state[0]
alive_seq = finished_state[_StateKeys.ALIVE_SEQ] alive_seq = finished_state[_StateKeys.ALIVE_SEQ]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment