Unverified Commit ef6cb209 authored by George Sterpu's avatar George Sterpu Committed by GitHub
Browse files

Update beam_search.py

parent 6a3bcef8
...@@ -30,9 +30,13 @@ class SequenceBeamSearchV2(v1.SequenceBeamSearch): ...@@ -30,9 +30,13 @@ class SequenceBeamSearchV2(v1.SequenceBeamSearch):
"""Beam search for sequences with highest scores.""" """Beam search for sequences with highest scores."""
state, state_shapes = self._create_initial_state(initial_ids, initial_cache) state, state_shapes = self._create_initial_state(initial_ids, initial_cache)
finished_state = tf.while_loop( finished_state = tf.nest.map_structure(
self._continue_search, self._search_step, loop_vars=[state], tf.stop_gradient,
shape_invariants=[state_shapes], parallel_iterations=1, back_prop=False) tf.while_loop(self._continue_search,
self._search_step,
loop_vars=[state],
shape_invariants=[state_shapes],
parallel_iterations=1))
finished_state = finished_state[0] finished_state = finished_state[0]
alive_seq = finished_state[_StateKeys.ALIVE_SEQ] alive_seq = finished_state[_StateKeys.ALIVE_SEQ]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment