"tests/vscode:/vscode.git/clone" did not exist on "88fa6b7d68d77b2531462ebe5a339b8c5b034ce4"
Commit fda7e6dc authored by Andrew M. Dai's avatar Andrew M. Dai
Browse files

Fix frequent deadlocks caused by summaries not depending on the save_state ops.

PiperOrigin-RevId: 172665358
parent 4b5a0801
...@@ -189,11 +189,10 @@ class VatxtModel(object): ...@@ -189,11 +189,10 @@ class VatxtModel(object):
tf.summary.scalar('adversarial_loss', adv_loss) tf.summary.scalar('adversarial_loss', adv_loss)
total_loss = loss + adv_loss total_loss = loss + adv_loss
tf.summary.scalar('total_classification_loss', total_loss)
with tf.control_dependencies([inputs.save_state(next_state)]): with tf.control_dependencies([inputs.save_state(next_state)]):
total_loss = tf.identity(total_loss) total_loss = tf.identity(total_loss)
tf.summary.scalar('total_classification_loss', total_loss)
return total_loss return total_loss
def language_model_graph(self, compute_loss=True): def language_model_graph(self, compute_loss=True):
...@@ -419,12 +418,12 @@ class VatxtBidirModel(VatxtModel): ...@@ -419,12 +418,12 @@ class VatxtBidirModel(VatxtModel):
tf.summary.scalar('adversarial_loss', adv_loss) tf.summary.scalar('adversarial_loss', adv_loss)
total_loss = loss + adv_loss total_loss = loss + adv_loss
tf.summary.scalar('total_classification_loss', total_loss)
saves = [inp.save_state(state) for (inp, state) in zip(inputs, next_states)] saves = [inp.save_state(state) for (inp, state) in zip(inputs, next_states)]
with tf.control_dependencies(saves): with tf.control_dependencies(saves):
total_loss = tf.identity(total_loss) total_loss = tf.identity(total_loss)
tf.summary.scalar('total_classification_loss', total_loss)
return total_loss return total_loss
def language_model_graph(self, compute_loss=True): def language_model_graph(self, compute_loss=True):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment