Unverified Commit 499071ef authored by Mark Daoust's avatar Mark Daoust Committed by GitHub
Browse files

Merge pull request #4297 from tensorflow/asimshankar-patch-1

[samples/java]: Explain where tensor names came from.
parents 24347947 fcdbc364
...@@ -42,7 +42,10 @@ public class Train { ...@@ -42,7 +42,10 @@ public class Train {
Tensors.create(Paths.get(checkpointDir, "ckpt").toString())) { Tensors.create(Paths.get(checkpointDir, "ckpt").toString())) {
graph.importGraphDef(graphDef); graph.importGraphDef(graphDef);
// Initialize or restore // Initialize or restore.
// The names of the tensors in the graph are printed out by the program
// that created the graph:
// https://github.com/tensorflow/models/blob/master/samples/languages/java/training/model/create_graph.py
if (checkpointExists) { if (checkpointExists) {
sess.runner().feed("save/Const", checkpointPrefix).addTarget("save/restore_all").run(); sess.runner().feed("save/Const", checkpointPrefix).addTarget("save/restore_all").run();
} else { } else {
...@@ -60,6 +63,8 @@ public class Train { ...@@ -60,6 +63,8 @@ public class Train {
float in = r.nextFloat(); float in = r.nextFloat();
try (Tensor<Float> input = Tensors.create(in); try (Tensor<Float> input = Tensors.create(in);
Tensor<Float> target = Tensors.create(3 * in + 2)) { Tensor<Float> target = Tensors.create(3 * in + 2)) {
// Again the tensor names are from the program that created the graph.
// https://github.com/tensorflow/models/blob/master/samples/languages/java/training/model/create_graph.py
sess.runner().feed("input", input).feed("target", target).addTarget("train").run(); sess.runner().feed("input", input).feed("target", target).addTarget("train").run();
} }
} }
...@@ -67,7 +72,9 @@ public class Train { ...@@ -67,7 +72,9 @@ public class Train {
printVariables(sess); printVariables(sess);
} }
// Checkpoint // Checkpoint.
// The feed and target name are from the program that created the graph.
// https://github.com/tensorflow/models/blob/master/samples/languages/java/training/model/create_graph.py.
sess.runner().feed("save/Const", checkpointPrefix).addTarget("save/control_dependency").run(); sess.runner().feed("save/Const", checkpointPrefix).addTarget("save/control_dependency").run();
// Example of "inference" in the same graph: // Example of "inference" in the same graph:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment