Unverified Commit ca15f5d9 authored by Asim Shankar's avatar Asim Shankar Committed by GitHub
Browse files

Merge pull request #3157 from asimshankar/java-samples

[samples]: Samples using the Java API.
parents 99491464 28bd85d1
from __future__ import print_function
import tensorflow as tf
x = tf.placeholder(tf.float32, name='input')
y_ = tf.placeholder(tf.float32, name='target')
W = tf.Variable(5., name='W')
b = tf.Variable(3., name='b')
y = x * W + b
y = tf.identity(y, name='output')
loss = tf.reduce_mean(tf.square(y - y_))
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
train_op = optimizer.minimize(loss, name='train')
init = tf.global_variables_initializer()
# Creating a tf.train.Saver adds operations to the graph to save and
# restore variables from checkpoints.
saver_def = tf.train.Saver().as_saver_def()
print('Operation to initialize variables: ', init.name)
print('Tensor to feed as input data: ', x.name)
print('Tensor to feed as training targets: ', y_.name)
print('Tensor to fetch as prediction: ', y.name)
print('Operation to train one step: ', train_op.name)
print('Tensor to be fed for checkpoint filename:', saver_def.filename_tensor_name)
print('Operation to save a checkpoint: ', saver_def.save_tensor_name)
print('Operation to restore a checkpoint: ', saver_def.restore_op_name)
print('Tensor to read value of W ', W.value().name)
print('Tensor to read value of b ', b.value().name)
with open('graph.pb', 'w') as f:
f.write(tf.get_default_graph().as_graph_def().SerializeToString())
<project>
<modelVersion>4.0.0</modelVersion>
<groupId>org.myorg</groupId>
<artifactId>training</artifactId>
<version>1.0-SNAPSHOT</version>
<properties>
<exec.mainClass>Train</exec.mainClass>
<!-- The sample code requires at least JDK 1.7. -->
<!-- The maven compiler plugin defaults to a lower version -->
<maven.compiler.source>1.7</maven.compiler.source>
<maven.compiler.target>1.7</maven.compiler.target>
</properties>
<dependencies>
<dependency>
<groupId>org.tensorflow</groupId>
<artifactId>tensorflow</artifactId>
<version>1.4.0</version>
</dependency>
</dependencies>
</project>
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.List;
import java.util.Random;
import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.Tensors;
/**
* Training a trivial linear model.
*/
public class Train {
public static void main(String[] args) throws Exception {
if (args.length != 2) {
System.err.println("Require two arguments: The GraphDef file and checkpoint directory");
System.exit(1);
}
final byte[] graphDef = Files.readAllBytes(Paths.get(args[0]));
final String checkpointDir = args[1];
final boolean checkpointExists = Files.exists(Paths.get(checkpointDir));
try (Graph graph = new Graph();
Session sess = new Session(graph);
Tensor<String> checkpointPrefix =
Tensors.create(Paths.get(checkpointDir, "ckpt").toString())) {
graph.importGraphDef(graphDef);
// Initialize or restore
if (checkpointExists) {
sess.runner().feed("save/Const", checkpointPrefix).addTarget("save/restore_all").run();
} else {
sess.runner().addTarget("init").run();
}
System.out.print("Starting from : ");
printVariables(sess);
// Train a bunch of times.
// (Will be much more efficient if we sent batches instead of individual values).
final Random r = new Random();
final int NUM_EXAMPLES = 500;
for (int i = 1; i <= 5; i++) {
for (int n = 0; n < NUM_EXAMPLES; n++) {
float in = r.nextFloat();
try (Tensor<Float> input = Tensors.create(in);
Tensor<Float> target = Tensors.create(3 * in + 2)) {
sess.runner().feed("input", input).feed("target", target).addTarget("train").run();
}
}
System.out.printf("After %5d examples: ", i*NUM_EXAMPLES);
printVariables(sess);
}
// Checkpoint
sess.runner().feed("save/Const", checkpointPrefix).addTarget("save/control_dependency").run();
// Example of "inference" in the same graph:
try (Tensor<Float> input = Tensors.create(1.0f);
Tensor<Float> output =
sess.runner().feed("input", input).fetch("output").run().get(0).expect(Float.class)) {
System.out.printf(
"For input %f, produced %f (ideally would produce 3*%f + 2)\n",
input.floatValue(), output.floatValue(), input.floatValue());
}
}
}
private static void printVariables(Session sess) {
List<Tensor<?>> values = sess.runner().fetch("W/read").fetch("b/read").run();
System.out.printf("W = %f\tb = %f\n", values.get(0).floatValue(), values.get(1).floatValue());
for (Tensor<?> t : values) {
t.close();
}
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment