Unverified Commit 6b9d5fba authored by Toby Boyd's avatar Toby Boyd Committed by GitHub
Browse files

Merge branch 'master' into patch-1

parents 5fd687c5 5fa2a4e6
# Training models in Java
Example of training a model (and saving and restoring checkpoints) using the
TensorFlow Java API.
## Quickstart
1. Train for a few steps:
```
mvn -q compile exec:java -Dexec.args="model/graph.pb checkpoint"
```
2. Resume training from previous checkpoint and train some more:
```
mvn -q exec:java -Dexec.args="model/graph.pb checkpoint"
```
3. Delete checkpoint:
```
rm -rf checkpoint
```
## Details
The model in `model/graph.pb` represents a very simple linear model:
```
y = x * W + b
```
The `graph.pb` file is generated by executing `create_graph.py` in Python; the script writes `graph.pb` to the directory it is run from.
The training is orchestrated by `src/main/java/Train.java`, which generates
training data of the form `y = 3.0 * x + 2.0`. Over time, using gradient
descent, the model should "learn": the value of `W` should converge to 3.0
and `b` to 2.0.
# Builds the graph for a trivial linear model (y = x * W + b), adds training,
# initialization and checkpoint save/restore operations to it, prints the names
# of the nodes a client needs to drive the graph, and writes the serialized
# GraphDef to 'graph.pb' in the current working directory.
from __future__ import print_function
import tensorflow as tf

# Placeholders fed at run time: the model input and the training target.
x = tf.placeholder(tf.float32, name='input')
y_ = tf.placeholder(tf.float32, name='target')

# Trainable parameters, with initial values 5.0 and 3.0.
W = tf.Variable(5., name='W')
b = tf.Variable(3., name='b')

y = x * W + b
# Give the prediction a stable name ('output') so it can be fetched by name.
y = tf.identity(y, name='output')

# Mean squared error, minimized with plain gradient descent.
loss = tf.reduce_mean(tf.square(y - y_))
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
train_op = optimizer.minimize(loss, name='train')

init = tf.global_variables_initializer()

# Creating a tf.train.Saver adds operations to the graph to save and
# restore variables from checkpoints.
saver_def = tf.train.Saver().as_saver_def()

print('Operation to initialize variables: ', init.name)
print('Tensor to feed as input data: ', x.name)
print('Tensor to feed as training targets: ', y_.name)
print('Tensor to fetch as prediction: ', y.name)
print('Operation to train one step: ', train_op.name)
print('Tensor to be fed for checkpoint filename:', saver_def.filename_tensor_name)
print('Operation to save a checkpoint: ', saver_def.save_tensor_name)
print('Operation to restore a checkpoint: ', saver_def.restore_op_name)
print('Tensor to read value of W ', W.value().name)
print('Tensor to read value of b ', b.value().name)

# SerializeToString() yields raw bytes, so the file must be opened in binary
# mode: text mode raises TypeError on Python 3 and may translate newline bytes
# (corrupting the protobuf) on platforms that distinguish text from binary.
with open('graph.pb', 'wb') as f:
    f.write(tf.get_default_graph().as_graph_def().SerializeToString())
<project>
<!-- Maven build for the Java training example. -->
<modelVersion>4.0.0</modelVersion>
<groupId>org.myorg</groupId>
<artifactId>training</artifactId>
<version>1.0-SNAPSHOT</version>
<properties>
<!-- Main class started by the exec:java goal. -->
<exec.mainClass>Train</exec.mainClass>
<!-- The sample code requires at least JDK 1.7. -->
<!-- The maven compiler plugin defaults to a lower version -->
<maven.compiler.source>1.7</maven.compiler.source>
<maven.compiler.target>1.7</maven.compiler.target>
</properties>
<dependencies>
<!-- TensorFlow Java API; provides the org.tensorflow classes the sample imports. -->
<dependency>
<groupId>org.tensorflow</groupId>
<artifactId>tensorflow</artifactId>
<version>1.4.0</version>
</dependency>
</dependencies>
</project>
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.List;
import java.util.Random;
import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.Tensors;
/**
* Training a trivial linear model.
*/
/**
 * Trains a trivial linear model ({@code y = x * W + b}) from a pre-built GraphDef, saving a
 * checkpoint at the end and restoring from it on the next run.
 *
 * <p>The graph nodes are addressed by the hard-coded names used below ({@code init},
 * {@code train}, {@code input}, {@code target}, {@code output}, {@code W/read}, {@code b/read},
 * {@code save/Const}, {@code save/restore_all}, {@code save/control_dependency}); presumably
 * these are the names printed by the script that generated the GraphDef.
 */
public class Train {
  /** Number of rounds of training; variables are printed after each round. */
  private static final int NUM_ROUNDS = 5;

  /** Number of single-example training steps executed per round. */
  private static final int NUM_EXAMPLES = 500;

  /**
   * Entry point.
   *
   * @param args exactly two values: the path of the GraphDef file and the checkpoint directory.
   *     Any other argument count prints a usage message and exits with status 1.
   * @throws Exception if the GraphDef file cannot be read or a graph operation fails
   */
  public static void main(String[] args) throws Exception {
    if (args.length != 2) {
      System.err.println("Require two arguments: The GraphDef file and checkpoint directory");
      System.exit(1);
    }

    final byte[] graphDef = Files.readAllBytes(Paths.get(args[0]));
    final String checkpointDir = args[1];
    // Decided before training starts: an existing directory means "resume", otherwise "init".
    final boolean checkpointExists = Files.exists(Paths.get(checkpointDir));

    try (Graph graph = new Graph();
        Session sess = new Session(graph);
        Tensor<String> checkpointPrefix =
            Tensors.create(Paths.get(checkpointDir, "ckpt").toString())) {
      graph.importGraphDef(graphDef);

      // Initialize or restore
      if (checkpointExists) {
        sess.runner().feed("save/Const", checkpointPrefix).addTarget("save/restore_all").run();
      } else {
        sess.runner().addTarget("init").run();
      }
      System.out.print("Starting from : ");
      printVariables(sess);

      // Train a bunch of times.
      // (Will be much more efficient if we sent batches instead of individual values).
      final Random r = new Random();
      for (int i = 1; i <= NUM_ROUNDS; i++) {
        for (int n = 0; n < NUM_EXAMPLES; n++) {
          float in = r.nextFloat();
          // Training data follows y = 3 * x + 2; tensors are closed after every step.
          try (Tensor<Float> input = Tensors.create(in);
              Tensor<Float> target = Tensors.create(3 * in + 2)) {
            sess.runner().feed("input", input).feed("target", target).addTarget("train").run();
          }
        }
        System.out.printf("After %5d examples: ", i * NUM_EXAMPLES);
        printVariables(sess);
      }

      // Checkpoint
      sess.runner().feed("save/Const", checkpointPrefix).addTarget("save/control_dependency").run();

      // Example of "inference" in the same graph:
      try (Tensor<Float> input = Tensors.create(1.0f);
          Tensor<Float> output =
              sess.runner().feed("input", input).fetch("output").run().get(0).expect(Float.class)) {
        System.out.printf(
            "For input %f, produced %f (ideally would produce 3*%f + 2)\n",
            input.floatValue(), output.floatValue(), input.floatValue());
      }
    }
  }

  /**
   * Fetches the current values of {@code W} and {@code b} and prints them on one line.
   *
   * @param sess session in which the graph's variables live
   */
  private static void printVariables(Session sess) {
    List<Tensor<?>> values = sess.runner().fetch("W/read").fetch("b/read").run();
    try {
      System.out.printf("W = %f\tb = %f\n", values.get(0).floatValue(), values.get(1).floatValue());
    } finally {
      // Release the fetched tensors even if reading or printing them throws.
      for (Tensor<?> t : values) {
        t.close();
      }
    }
  }
}
...@@ -19,17 +19,19 @@ unzip -p source-archive.zip word2vec/trunk/questions-words.txt > questions-word ...@@ -19,17 +19,19 @@ unzip -p source-archive.zip word2vec/trunk/questions-words.txt > questions-word
rm text8.zip source-archive.zip rm text8.zip source-archive.zip
``` ```
You will need to compile the ops as follows: You will need to compile the ops as follows (see
[Adding a New Op to TensorFlow](https://www.tensorflow.org/how_tos/adding_an_op/#building_the_op_library)
for more details):
```shell ```shell
TF_INC=$(python -c 'import tensorflow as tf; print(tf.sysconfig.get_include())') TF_CFLAGS=( $(python -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_compile_flags()))') )
g++ -std=c++11 -shared word2vec_ops.cc word2vec_kernels.cc -o word2vec_ops.so -fPIC -I $TF_INC -O2 -D_GLIBCXX_USE_CXX11_ABI=0 TF_LFLAGS=( $(python -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_link_flags()))') )
g++ -std=c++11 -shared word2vec_ops.cc word2vec_kernels.cc -o word2vec_ops.so -fPIC ${TF_CFLAGS[@]} ${TF_LFLAGS[@]} -O2 -D_GLIBCXX_USE_CXX11_ABI=0
``` ```
On Mac, add `-undefined dynamic_lookup` to the g++ command. On Mac, add `-undefined dynamic_lookup` to the g++ command. The flag `-D_GLIBCXX_USE_CXX11_ABI=0` is included to support newer versions of gcc. However, if you compiled TensorFlow from source using gcc 5 or later, you may need to exclude the flag. Specifically, if you get an error similar to the following: `word2vec_ops.so: undefined symbol: _ZN10tensorflow7strings6StrCatERKNS0_8AlphaNumES3_S3_S3_` then you likely need to exclude the flag.
(For an explanation of what this is doing, see the tutorial on [Adding a New Op to TensorFlow](https://www.tensorflow.org/how_tos/adding_an_op/#building_the_op_library). The flag `-D_GLIBCXX_USE_CXX11_ABI=0` is included to support newer versions of gcc. However, if you compiled TensorFlow from source using gcc 5 or later, you may need to exclude the flag.) Once you've successfully compiled the ops, run the model as follows:
Then run using:
```shell ```shell
python word2vec_optimized.py \ python word2vec_optimized.py \
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment