# Training models in Java

An example of training a model, including saving and restoring checkpoints, using
the TensorFlow Java API.

## Quickstart

1.  Train for a few steps:
    ```
    mvn -q compile exec:java -Dexec.args="model/graph.pb checkpoint"
    ```

2.  Resume training from the previous checkpoint and train some more:
    ```
    mvn -q exec:java -Dexec.args="model/graph.pb checkpoint"
    ```

3.  Delete the checkpoint to start over from scratch:
    ```
    rm -rf checkpoint
    ```
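
Steps 1 and 2 pass the same arguments, which suggests the program checks at startup whether the `checkpoint` directory already holds saved state and either restores it or starts fresh. The plain-Java sketch below illustrates that resume-or-initialize pattern for the two parameters `W` and `b`. It is hypothetical: it does not use TensorFlow's checkpoint format or API, and the `CheckpointSketch` class and `params.txt` file name are invented for illustration.

```java
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.List;

// Hypothetical sketch of resume-or-initialize logic. This is NOT TensorFlow's
// checkpoint format; "params.txt" is an invented file name for illustration.
class CheckpointSketch {

    // Restore {W, b} from <dir>/params.txt if it exists, otherwise return fresh values.
    static double[] restoreOrInit(Path dir) throws IOException {
        Path file = dir.resolve("params.txt");
        if (Files.exists(file)) {
            List<String> lines = Files.readAllLines(file);
            return new double[] {Double.parseDouble(lines.get(0)), Double.parseDouble(lines.get(1))};
        }
        return new double[] {0.0, 0.0};
    }

    // Save {W, b} to <dir>/params.txt, creating the directory if needed.
    static void save(Path dir, double[] params) throws IOException {
        Files.createDirectories(dir);
        Files.write(dir.resolve("params.txt"),
                Arrays.asList(Double.toString(params[0]), Double.toString(params[1])));
    }

    public static void main(String[] args) throws IOException {
        Path dir = Paths.get(args.length > 0 ? args[0] : "checkpoint");
        boolean resumed = Files.exists(dir.resolve("params.txt"));
        double[] params = restoreOrInit(dir);
        // ... training steps would update params here ...
        save(dir, params);
        System.out.println(resumed ? "Resumed from " + dir : "Initialized fresh parameters");
    }
}
```

Running it twice with the same directory argument mirrors steps 1 and 2; deleting the directory, as in step 3, makes the next run start from fresh values again.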


## Details

The model in `model/graph.pb` represents a very simple linear model:

```
y = x * W + b
```

The `graph.pb` file is generated by running the Python script `create_graph.py`.

The training is orchestrated by `src/main/java/Train.java`. It generates
training data of the form `y = 3.0 * x + 2.0` and fits the model to it using
gradient descent, so over time the model should "learn" the underlying
coefficients: `W` should converge to 3.0 and `b` to 2.0.
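
To make the expected convergence concrete, here is a minimal plain-Java sketch of the same idea without TensorFlow. It assumes a mean-squared-error loss, a learning rate of 0.1, batches of 100 inputs drawn uniformly from [0, 1), and noise-free targets; none of these choices are taken from `Train.java`, and the `LinearRegressionSketch` class is hypothetical. The gradients come from differentiating the squared error (x·W + b − y)² with respect to `W` and `b`, giving 2(x·W + b − y)·x and 2(x·W + b − y).

```java
import java.util.Random;

// Hypothetical plain-Java sketch: fit y = x * W + b by gradient descent.
// Assumptions (not from Train.java): MSE loss, uniform inputs in [0, 1),
// batch size 100, noise-free targets y = 3.0 * x + 2.0.
class LinearRegressionSketch {

    static double[] train(int steps, double learningRate, long seed) {
        Random rng = new Random(seed);
        double w = 0.0;
        double b = 0.0;
        int batch = 100;
        for (int step = 0; step < steps; step++) {
            double gradW = 0.0;
            double gradB = 0.0;
            for (int i = 0; i < batch; i++) {
                double x = rng.nextDouble();
                double y = 3.0 * x + 2.0;      // training data of the form described above
                double err = (x * w + b) - y;  // prediction minus target
                gradW += 2.0 * err * x;        // d(err^2)/dW
                gradB += 2.0 * err;            // d(err^2)/db
            }
            // Step against the batch-averaged gradient.
            w -= learningRate * gradW / batch;
            b -= learningRate * gradB / batch;
        }
        return new double[] {w, b};
    }

    public static void main(String[] args) {
        double[] wb = train(2000, 0.1, 42L);
        System.out.printf("W = %.4f, b = %.4f%n", wb[0], wb[1]);
    }
}
```

Because the generated data lies exactly on the line, every batch has the same minimizer, so `W` and `b` approach 3.0 and 2.0 closely after a couple of thousand steps. With noisy data or a different learning rate, the result would only hover near those values.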