Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
997f7871
Commit
997f7871
authored
Jan 12, 2018
by
Asim Shankar
Browse files
[samples]: Samples using the Java API.
parent
364f96dd
Changes
22
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2428 additions
and
0 deletions
+2428
-0
samples/java/README.md
samples/java/README.md
+3
-0
samples/java/docker/Dockerfile
samples/java/docker/Dockerfile
+7
-0
samples/java/docker/README.md
samples/java/docker/README.md
+12
-0
samples/java/label_image/.gitignore
samples/java/label_image/.gitignore
+3
-0
samples/java/label_image/README.md
samples/java/label_image/README.md
+23
-0
samples/java/label_image/download.py
samples/java/label_image/download.py
+93
-0
samples/java/label_image/download.sh
samples/java/label_image/download.sh
+4
-0
samples/java/label_image/download_sample_images.sh
samples/java/label_image/download_sample_images.sh
+10
-0
samples/java/label_image/pom.xml
samples/java/label_image/pom.xml
+26
-0
samples/java/label_image/src/main/java/LabelImage.java
samples/java/label_image/src/main/java/LabelImage.java
+98
-0
samples/java/object_detection/.gitignore
samples/java/object_detection/.gitignore
+5
-0
samples/java/object_detection/README.md
samples/java/object_detection/README.md
+57
-0
samples/java/object_detection/download.sh
samples/java/object_detection/download.sh
+18
-0
samples/java/object_detection/pom.xml
samples/java/object_detection/pom.xml
+25
-0
samples/java/object_detection/src/main/java/DetectObjects.java
...es/java/object_detection/src/main/java/DetectObjects.java
+184
-0
samples/java/object_detection/src/main/java/object_detection/protos/StringIntLabelMapOuterClass.java
.../object_detection/protos/StringIntLabelMapOuterClass.java
+1785
-0
samples/java/training/.gitignore
samples/java/training/.gitignore
+2
-0
samples/java/training/README.md
samples/java/training/README.md
+37
-0
samples/java/training/model/create_graph.py
samples/java/training/model/create_graph.py
+36
-0
samples/java/training/model/graph.pb
samples/java/training/model/graph.pb
+0
-0
No files found.
samples/java/README.md
0 → 100644
View file @
997f7871
# TensorFlow for Java: Examples
Examples using the TensorFlow Java API.
samples/java/docker/Dockerfile
0 → 100644
View file @
997f7871
FROM
tensorflow/tensorflow:1.4.0
WORKDIR
/
RUN
apt-get update
RUN
apt-get
-y
install
maven openjdk-8-jdk
RUN
mvn dependency:get
-Dartifact
=
org.tensorflow:tensorflow:1.4.0
RUN
mvn dependency:get
-Dartifact
=
org.tensorflow:proto:1.4.0
CMD
["/bin/bash", "-l"]
samples/java/docker/README.md
0 → 100644
View file @
997f7871
Dockerfile for building an image suitable for running the Java examples.
Typical usage:
```
docker build -t java-tensorflow .
docker run -it --rm -v ${PWD}/..:/examples java-tensorflow
```
That second command will pop you into a shell which has all
the dependencies required to execute the scripts and Java
examples.
samples/java/label_image/.gitignore
0 → 100644
View file @
997f7871
images
src/main/resources
target
samples/java/label_image/README.md
0 → 100644
View file @
997f7871
# Image Classification Example
1.
Download the model:
-
If you have
[
TensorFlow 1.4+ for Python installed
](
https://www.tensorflow.org/install/
)
,
run
`python ./download.py`
-
If not, but you have
[
docker
](
https://www.docker.com/get-docker
)
installed,
run
`download.sh`
.
2.
Compile
[
`LabelImage.java`
](
src/main/java/LabelImage.java
)
:
```
mvn compile
```
3.
Download some sample images:
If you already have some images, great. Otherwise
`download_sample_images.sh`
gets a few.
3.
Classify!
```
mvn -q exec:java -Dexec.args="<path to image file>"
```
samples/java/label_image/download.py
0 → 100644
View file @
997f7871
"""Create an image classification graph.
Script to download a pre-trained image classifier and tweak it so that
the model accepts raw bytes of an encoded image.
Doing so involves some model-specific normalization of an image.
Ideally, this would have been part of the image classifier model,
but the particular model being used didn't include this normalization,
so this script does the necessary tweaking.
"""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
from
six.moves
import
urllib
import
os
import
zipfile
import
tensorflow
as
tf
URL
=
'https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip'
LABELS_FILE
=
'imagenet_comp_graph_label_strings.txt'
GRAPH_FILE
=
'tensorflow_inception_graph.pb'
GRAPH_INPUT_TENSOR
=
'input:0'
GRAPH_PROBABILITIES_TENSOR
=
'output:0'
IMAGE_HEIGHT
=
224
IMAGE_WIDTH
=
224
MEAN
=
117
SCALE
=
1
LOCAL_DIR
=
'src/main/resources'
def
download
():
print
(
'Downloading %s'
%
URL
)
zip_filename
,
_
=
urllib
.
request
.
urlretrieve
(
URL
)
with
zipfile
.
ZipFile
(
zip_filename
)
as
zip
:
zip
.
extract
(
LABELS_FILE
)
zip
.
extract
(
GRAPH_FILE
)
os
.
rename
(
LABELS_FILE
,
os
.
path
.
join
(
LOCAL_DIR
,
'labels.txt'
))
os
.
rename
(
GRAPH_FILE
,
os
.
path
.
join
(
LOCAL_DIR
,
'graph.pb'
))
def
create_graph_to_decode_and_normalize_image
():
"""See file docstring.
Returns:
input: The placeholder to feed the raw bytes of an encoded image.
y: A Tensor (the decoded, normalized image) to be fed to the graph.
"""
image
=
tf
.
placeholder
(
tf
.
string
,
shape
=
(),
name
=
'encoded_image_bytes'
)
with
tf
.
name_scope
(
"preprocess"
):
y
=
tf
.
image
.
decode_image
(
image
,
channels
=
3
)
y
=
tf
.
cast
(
y
,
tf
.
float32
)
y
=
tf
.
expand_dims
(
y
,
axis
=
0
)
y
=
tf
.
image
.
resize_bilinear
(
y
,
(
IMAGE_HEIGHT
,
IMAGE_WIDTH
))
y
=
(
y
-
MEAN
)
/
SCALE
return
(
image
,
y
)
def
patch_graph
():
"""Create graph.pb that applies the model in URL to raw image bytes."""
with
tf
.
Graph
().
as_default
()
as
g
:
input_image
,
image_normalized
=
create_graph_to_decode_and_normalize_image
()
original_graph_def
=
tf
.
GraphDef
()
with
open
(
os
.
path
.
join
(
LOCAL_DIR
,
'graph.pb'
))
as
f
:
original_graph_def
.
ParseFromString
(
f
.
read
())
softmax
=
tf
.
import_graph_def
(
original_graph_def
,
name
=
'inception'
,
input_map
=
{
GRAPH_INPUT_TENSOR
:
image_normalized
},
return_elements
=
[
GRAPH_PROBABILITIES_TENSOR
])
# We're constructing a graph that accepts a single image (as opposed to a
# batch of images), so might as well make the output be a vector of
# probabilities, instead of a batch of vectors with batch size 1.
output_probabilities
=
tf
.
squeeze
(
softmax
,
name
=
'probabilities'
)
# Overwrite the graph.
with
open
(
os
.
path
.
join
(
LOCAL_DIR
,
'graph.pb'
),
'w'
)
as
f
:
f
.
write
(
g
.
as_graph_def
().
SerializeToString
())
print
(
'------------------------------------------------------------'
)
print
(
'MODEL GRAPH : graph.pb'
)
print
(
'LABELS : labels.txt'
)
print
(
'INPUT TENSOR : %s'
%
input_image
.
op
.
name
)
print
(
'OUTPUT TENSOR: %s'
%
output_probabilities
.
op
.
name
)
if
__name__
==
'__main__'
:
if
not
os
.
path
.
exists
(
LOCAL_DIR
):
os
.
makedirs
(
LOCAL_DIR
)
download
()
patch_graph
()
samples/java/label_image/download.sh
0 → 100755
View file @
997f7871
#!/bin/bash
DIR
=
"
$(
cd
"
$(
dirname
"
$0
"
)
"
&&
pwd
-P
)
"
docker run
-it
-v
${
DIR
}
:/x
-w
/x
--rm
tensorflow/tensorflow:1.4.0 python download.py
samples/java/label_image/download_sample_images.sh
0 → 100755
View file @
997f7871
#!/bin/bash
DIR
=
$(
dirname
$0
)
mkdir
-p
${
DIR
}
/images
cd
${
DIR
}
/images
# Some random images
curl
-o
"porcupine.jpg"
-L
"https://cdn.pixabay.com/photo/2014/11/06/12/46/porcupines-519145_960_720.jpg"
curl
-o
"whale.jpg"
-L
"https://static.pexels.com/photos/417196/pexels-photo-417196.jpeg"
curl
-o
"terrier1u.jpg"
-L
"https://upload.wikimedia.org/wikipedia/commons/3/34/Australian_Terrier_Melly_%282%29.JPG"
curl
-o
"terrier2.jpg"
-L
"https://cdn.pixabay.com/photo/2014/05/13/07/44/yorkshire-terrier-343198_960_720.jpg"
samples/java/label_image/pom.xml
0 → 100644
View file @
997f7871
<project>
<modelVersion>
4.0.0
</modelVersion>
<groupId>
org.myorg
</groupId>
<artifactId>
label-image
</artifactId>
<version>
1.0-SNAPSHOT
</version>
<properties>
<exec.mainClass>
LabelImage
</exec.mainClass>
<!-- The sample code requires at least JDK 1.7. -->
<!-- The maven compiler plugin defaults to a lower version -->
<maven.compiler.source>
1.7
</maven.compiler.source>
<maven.compiler.target>
1.7
</maven.compiler.target>
</properties>
<dependencies>
<dependency>
<groupId>
org.tensorflow
</groupId>
<artifactId>
tensorflow
</artifactId>
<version>
1.4.0
</version>
</dependency>
<!-- For ByteStreams.toByteArray: https://google.github.io/guava/releases/23.0/api/docs/com/google/common/io/ByteStreams.html -->
<dependency>
<groupId>
com.google.guava
</groupId>
<artifactId>
guava
</artifactId>
<version>
23.6-jre
</version>
</dependency>
</dependencies>
</project>
samples/java/label_image/src/main/java/LabelImage.java
0 → 100644
View file @
997f7871
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
import
com.google.common.io.ByteStreams
;
import
java.io.BufferedReader
;
import
java.io.IOException
;
import
java.io.InputStream
;
import
java.io.InputStreamReader
;
import
java.nio.file.Files
;
import
java.nio.file.Path
;
import
java.nio.file.Paths
;
import
java.util.ArrayList
;
import
java.util.List
;
import
org.tensorflow.Graph
;
import
org.tensorflow.Session
;
import
org.tensorflow.Tensor
;
import
org.tensorflow.Tensors
;
/**
* Simplified version of
* https://github.com/tensorflow/tensorflow/blob/r1.4/tensorflow/java/src/main/java/org/tensorflow/examples/LabelImage.java
*/
public
class
LabelImage
{
public
static
void
main
(
String
[]
args
)
throws
Exception
{
if
(
args
.
length
<
1
)
{
System
.
err
.
println
(
"USAGE: Provide a list of image filenames"
);
System
.
exit
(
1
);
}
final
List
<
String
>
labels
=
loadLabels
();
try
(
Graph
graph
=
new
Graph
();
Session
session
=
new
Session
(
graph
))
{
graph
.
importGraphDef
(
loadGraphDef
());
float
[]
probabilities
=
null
;
for
(
String
filename
:
args
)
{
byte
[]
bytes
=
Files
.
readAllBytes
(
Paths
.
get
(
filename
));
try
(
Tensor
<
String
>
input
=
Tensors
.
create
(
bytes
);
Tensor
<
Float
>
output
=
session
.
runner
()
.
feed
(
"encoded_image_bytes"
,
input
)
.
fetch
(
"probabilities"
)
.
run
()
.
get
(
0
)
.
expect
(
Float
.
class
))
{
if
(
probabilities
==
null
)
{
probabilities
=
new
float
[(
int
)
output
.
shape
()[
0
]];
}
output
.
copyTo
(
probabilities
);
int
label
=
argmax
(
probabilities
);
System
.
out
.
printf
(
"%-30s --> %-15s (%.2f%% likely)\n"
,
filename
,
labels
.
get
(
label
),
probabilities
[
label
]
*
100.0
);
}
}
}
}
private
static
byte
[]
loadGraphDef
()
throws
IOException
{
try
(
InputStream
is
=
LabelImage
.
class
.
getClassLoader
().
getResourceAsStream
(
"graph.pb"
))
{
return
ByteStreams
.
toByteArray
(
is
);
}
}
private
static
ArrayList
<
String
>
loadLabels
()
throws
IOException
{
ArrayList
<
String
>
labels
=
new
ArrayList
<
String
>();
String
line
;
final
InputStream
is
=
LabelImage
.
class
.
getClassLoader
().
getResourceAsStream
(
"labels.txt"
);
try
(
BufferedReader
reader
=
new
BufferedReader
(
new
InputStreamReader
(
is
)))
{
while
((
line
=
reader
.
readLine
())
!=
null
)
{
labels
.
add
(
line
);
}
}
return
labels
;
}
private
static
int
argmax
(
float
[]
probabilities
)
{
int
best
=
0
;
for
(
int
i
=
1
;
i
<
probabilities
.
length
;
++
i
)
{
if
(
probabilities
[
i
]
>
probabilities
[
best
])
{
best
=
i
;
}
}
return
best
;
}
}
samples/java/object_detection/.gitignore
0 → 100644
View file @
997f7871
images
labels
models
src/main/protobuf
target
samples/java/object_detection/README.md
0 → 100644
View file @
997f7871
# Object Detection in Java
Example of using pre-trained models of the
[
TensorFlow Object Detection
API
](
https://github.com/tensorflow/models/tree/master/research/object_detection
)
in Java.
## Quickstart
1.
Download some metadata files:
```
./download.sh
```
2.
Download a model from the
[
object detection API model
zoo
](
https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md
)
.
For example:
```
mkdir -p models
curl -L \
http://download.tensorflow.org/models/object_detection/ssd_inception_v2_coco_2017_11_17.tar.gz \
| tar -xz -C models/
```
3.
Locate the corresponding labels file in the
`data/`
directory.
3.
Have some test images handy. For example:
```
mkdir -p images
curl -L -o images/test.jpg \
https://pixnio.com/free-images/people/mother-father-and-children-washing-dog-labrador-retriever-outside-in-the-fresh-air-725x483.jpg
```
4.
Compile and run!
```
mvn -q compile exec:java \
-Dexec.args="models/ssd_inception_v2_coco_2017_11_17/saved_model labels/mscoco_label_map.pbtxt images/test.jpg"
```
## Notes
-
This example demonstrates the use of the TensorFlow
[
SavedModel
format
](
https://www.tensorflow.org/programmers_guide/saved_model
)
. If you have
TensorFlow for Python installed, you could explore the model to get the names
of the tensors using
`saved_model_cli`
command. For example:
```
saved_model_cli show --dir models/ssd_inception_v2_coco_2017_11_17/saved_model/ --all
```
-
The file in
`src/main/object_detection/protos/`
was generated using:
```
./download.sh
protoc -Isrc/main/protobuf --java_out=src/main/java src/main/protobuf/string_int_label_map.proto
```
Where
`protoc`
was downloaded from
https://github.com/google/protobuf/releases/tag/v3.5.1
samples/java/object_detection/download.sh
0 → 100755
View file @
997f7871
#!/bin/bash
set
-ex
DIR
=
"
$(
cd
"
$(
dirname
"
$0
"
)
"
&&
pwd
-P
)
"
cd
"
${
DIR
}
"
# The protobuf file needed for mapping labels to human readable names.
# From:
# https://github.com/tensorflow/models/blob/f87a58c/research/object_detection/protos/string_int_label_map.proto
mkdir
-p
src/main/protobuf
curl
-L
-o
src/main/protobuf/string_int_label_map.proto
"https://raw.githubusercontent.com/tensorflow/models/f87a58cd96d45de73c9a8330a06b2ab56749a7fa/research/object_detection/protos/string_int_label_map.proto"
# Labels from:
# https://github.com/tensorflow/models/tree/865c14c/research/object_detection/data
mkdir
-p
labels
curl
-L
-o
labels/mscoco_label_map.pbtxt
"https://raw.githubusercontent.com/tensorflow/models/865c14c1209cb9ae188b2a1b5f0883c72e050d4c/research/object_detection/data/mscoco_label_map.pbtxt"
curl
-L
-o
labels/oid_bbox_trainable_label_map.pbtxt
"https://raw.githubusercontent.com/tensorflow/models/865c14c1209cb9ae188b2a1b5f0883c72e050d4c/research/object_detection/data/oid_bbox_trainable_label_map.pbtxt"
samples/java/object_detection/pom.xml
0 → 100644
View file @
997f7871
<project>
<modelVersion>
4.0.0
</modelVersion>
<groupId>
org.myorg
</groupId>
<artifactId>
detect-objects
</artifactId>
<version>
1.0-SNAPSHOT
</version>
<properties>
<exec.mainClass>
DetectObjects
</exec.mainClass>
<!-- The sample code requires at least JDK 1.7. -->
<!-- The maven compiler plugin defaults to a lower version -->
<maven.compiler.source>
1.7
</maven.compiler.source>
<maven.compiler.target>
1.7
</maven.compiler.target>
</properties>
<dependencies>
<dependency>
<groupId>
org.tensorflow
</groupId>
<artifactId>
tensorflow
</artifactId>
<version>
1.4.0
</version>
</dependency>
<dependency>
<groupId>
org.tensorflow
</groupId>
<artifactId>
proto
</artifactId>
<version>
1.4.0
</version>
</dependency>
</dependencies>
</project>
samples/java/object_detection/src/main/java/DetectObjects.java
0 → 100644
View file @
997f7871
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
import
static
object_detection
.
protos
.
StringIntLabelMapOuterClass
.
StringIntLabelMap
;
import
static
object_detection
.
protos
.
StringIntLabelMapOuterClass
.
StringIntLabelMapItem
;
import
com.google.protobuf.TextFormat
;
import
java.awt.image.BufferedImage
;
import
java.awt.image.DataBufferByte
;
import
java.io.File
;
import
java.io.IOException
;
import
java.io.PrintStream
;
import
java.nio.ByteBuffer
;
import
java.nio.charset.StandardCharsets
;
import
java.nio.file.Files
;
import
java.nio.file.Paths
;
import
java.util.List
;
import
java.util.Map
;
import
javax.imageio.ImageIO
;
import
org.tensorflow.SavedModelBundle
;
import
org.tensorflow.Tensor
;
import
org.tensorflow.framework.MetaGraphDef
;
import
org.tensorflow.framework.SignatureDef
;
import
org.tensorflow.framework.TensorInfo
;
import
org.tensorflow.types.UInt8
;
/**
* Java inference for the Object Detection API at:
* https://github.com/tensorflow/models/blob/master/research/object_detection/
*/
public
class
DetectObjects
{
public
static
void
main
(
String
[]
args
)
throws
Exception
{
if
(
args
.
length
<
3
)
{
printUsage
(
System
.
err
);
System
.
exit
(
1
);
}
final
String
[]
labels
=
loadLabels
(
args
[
1
]);
try
(
SavedModelBundle
model
=
SavedModelBundle
.
load
(
args
[
0
],
"serve"
))
{
printSignature
(
model
);
for
(
int
arg
=
2
;
arg
<
args
.
length
;
arg
++)
{
final
String
filename
=
args
[
arg
];
List
<
Tensor
<?>>
outputs
=
null
;
try
(
Tensor
<
UInt8
>
input
=
makeImageTensor
(
filename
))
{
outputs
=
model
.
session
()
.
runner
()
.
feed
(
"image_tensor"
,
input
)
.
fetch
(
"detection_scores"
)
.
fetch
(
"detection_classes"
)
.
fetch
(
"detection_boxes"
)
.
run
();
}
try
(
Tensor
<
Float
>
scoresT
=
outputs
.
get
(
0
).
expect
(
Float
.
class
);
Tensor
<
Float
>
classesT
=
outputs
.
get
(
1
).
expect
(
Float
.
class
);
Tensor
<
Float
>
boxesT
=
outputs
.
get
(
2
).
expect
(
Float
.
class
))
{
// All these tensors have:
// - 1 as the first dimension
// - maxObjects as the second dimension
// While boxesT will have 4 as the third dimension (2 sets of (x, y) coordinates).
// This can be verified by looking at scoresT.shape() etc.
int
maxObjects
=
(
int
)
scoresT
.
shape
()[
1
];
float
[]
scores
=
scoresT
.
copyTo
(
new
float
[
1
][
maxObjects
])[
0
];
float
[]
classes
=
classesT
.
copyTo
(
new
float
[
1
][
maxObjects
])[
0
];
float
[][]
boxes
=
boxesT
.
copyTo
(
new
float
[
1
][
maxObjects
][
4
])[
0
];
// Print all objects whose score is at least 0.5.
System
.
out
.
printf
(
"* %s\n"
,
filename
);
boolean
foundSomething
=
false
;
for
(
int
i
=
0
;
i
<
scores
.
length
;
++
i
)
{
if
(
scores
[
i
]
<
0.5
)
{
continue
;
}
foundSomething
=
true
;
System
.
out
.
printf
(
"\tFound %-20s (score: %.4f)\n"
,
labels
[(
int
)
classes
[
i
]],
scores
[
i
]);
}
if
(!
foundSomething
)
{
System
.
out
.
println
(
"No objects detected with a high enough score."
);
}
}
}
}
}
private
static
void
printSignature
(
SavedModelBundle
model
)
throws
Exception
{
MetaGraphDef
m
=
MetaGraphDef
.
parseFrom
(
model
.
metaGraphDef
());
SignatureDef
sig
=
m
.
getSignatureDefOrThrow
(
"serving_default"
);
int
numInputs
=
sig
.
getInputsCount
();
int
i
=
1
;
System
.
out
.
println
(
"MODEL SIGNATURE"
);
System
.
out
.
println
(
"Inputs:"
);
for
(
Map
.
Entry
<
String
,
TensorInfo
>
entry
:
sig
.
getInputsMap
().
entrySet
())
{
TensorInfo
t
=
entry
.
getValue
();
System
.
out
.
printf
(
"%d of %d: %-20s (Node name in graph: %-20s, type: %s)\n"
,
i
++,
numInputs
,
entry
.
getKey
(),
t
.
getName
(),
t
.
getDtype
());
}
int
numOutputs
=
sig
.
getOutputsCount
();
i
=
1
;
System
.
out
.
println
(
"Outputs:"
);
for
(
Map
.
Entry
<
String
,
TensorInfo
>
entry
:
sig
.
getOutputsMap
().
entrySet
())
{
TensorInfo
t
=
entry
.
getValue
();
System
.
out
.
printf
(
"%d of %d: %-20s (Node name in graph: %-20s, type: %s)\n"
,
i
++,
numOutputs
,
entry
.
getKey
(),
t
.
getName
(),
t
.
getDtype
());
}
System
.
out
.
println
(
"-----------------------------------------------"
);
}
private
static
String
[]
loadLabels
(
String
filename
)
throws
Exception
{
String
text
=
new
String
(
Files
.
readAllBytes
(
Paths
.
get
(
filename
)),
StandardCharsets
.
UTF_8
);
StringIntLabelMap
.
Builder
builder
=
StringIntLabelMap
.
newBuilder
();
TextFormat
.
merge
(
text
,
builder
);
StringIntLabelMap
proto
=
builder
.
build
();
int
maxId
=
0
;
for
(
StringIntLabelMapItem
item
:
proto
.
getItemList
())
{
if
(
item
.
getId
()
>
maxId
)
{
maxId
=
item
.
getId
();
}
}
String
[]
ret
=
new
String
[
maxId
+
1
];
for
(
StringIntLabelMapItem
item
:
proto
.
getItemList
())
{
ret
[
item
.
getId
()]
=
item
.
getDisplayName
();
}
return
ret
;
}
private
static
void
bgr2rgb
(
byte
[]
data
)
{
for
(
int
i
=
0
;
i
<
data
.
length
;
i
+=
3
)
{
byte
tmp
=
data
[
i
];
data
[
i
]
=
data
[
i
+
2
];
data
[
i
+
2
]
=
tmp
;
}
}
private
static
Tensor
<
UInt8
>
makeImageTensor
(
String
filename
)
throws
IOException
{
BufferedImage
img
=
ImageIO
.
read
(
new
File
(
filename
));
if
(
img
.
getType
()
!=
BufferedImage
.
TYPE_3BYTE_BGR
)
{
throw
new
IOException
(
String
.
format
(
"Expected 3-byte BGR encoding in BufferedImage, found %d (file: %s). This code could be made more robust"
,
img
.
getType
(),
filename
));
}
byte
[]
data
=
((
DataBufferByte
)
img
.
getData
().
getDataBuffer
()).
getData
();
// ImageIO.read seems to produce BGR-encoded images, but the model expects RGB.
bgr2rgb
(
data
);
final
long
BATCH_SIZE
=
1
;
final
long
CHANNELS
=
3
;
long
[]
shape
=
new
long
[]
{
BATCH_SIZE
,
img
.
getHeight
(),
img
.
getWidth
(),
CHANNELS
};
return
Tensor
.
create
(
UInt8
.
class
,
shape
,
ByteBuffer
.
wrap
(
data
));
}
private
static
void
printUsage
(
PrintStream
s
)
{
s
.
println
(
"USAGE: <model> <label_map> <image> [<image>] [<image>]"
);
s
.
println
(
""
);
s
.
println
(
"Where"
);
s
.
println
(
"<model> is the path to the SavedModel directory of the model to use."
);
s
.
println
(
" For example, the saved_model directory in tarballs from "
);
s
.
println
(
" https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md)"
);
s
.
println
(
""
);
s
.
println
(
"<label_map> is the path to a file containing information about the labels detected by the model."
);
s
.
println
(
" For example, one of the .pbtxt files from "
);
s
.
println
(
" https://github.com/tensorflow/models/tree/master/research/object_detection/data"
);
s
.
println
(
""
);
s
.
println
(
"<image> is the path to an image file."
);
s
.
println
(
" Sample images can be found from the COCO, Kitti, or Open Images dataset."
);
s
.
println
(
" See: https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md"
);
}
}
samples/java/object_detection/src/main/java/object_detection/protos/StringIntLabelMapOuterClass.java
0 → 100644
View file @
997f7871
This diff is collapsed.
Click to expand it.
samples/java/training/.gitignore
0 → 100644
View file @
997f7871
target
checkpoint
samples/java/training/README.md
0 → 100644
View file @
997f7871
# Training models in Java
Example of training a model (and saving and restoring checkpoints) using the
TensorFlow Java API.
## Quickstart
1.
Train for a few steps:
```
mvn -q compile exec:java -Dexec.args="model/graph.pb checkpoint"
```
2.
Resume training from previous checkpoint and train some more:
```
mvn -q exec:java -Dexec.args="model/graph.pb checkpoint"
```
3.
Delete checkpoint:
```
rm -rf checkpoint
```
## Details
The model in
`model/graph.pb`
represents a very simple linear model:
```
y = x * W + b
```
The
`graph.pb`
file is generated by executing
`create_graph.py`
in Python.
The training is orchestrated by
`src/main/java/Train.java`
, which generates
training data of the form
`y = 3.0 * x + 2.0`
and over time, using gradient
descent, the model should "learn" and the value of
`W`
should converge to 3.0,
and
`b`
to 2.0.
samples/java/training/model/create_graph.py
0 → 100644
View file @
997f7871
from
__future__
import
print_function
import
tensorflow
as
tf
x
=
tf
.
placeholder
(
tf
.
float32
,
name
=
'input'
)
y_
=
tf
.
placeholder
(
tf
.
float32
,
name
=
'target'
)
W
=
tf
.
Variable
(
5.
,
name
=
'W'
)
b
=
tf
.
Variable
(
3.
,
name
=
'b'
)
y
=
x
*
W
+
b
y
=
tf
.
identity
(
y
,
name
=
'output'
)
loss
=
tf
.
reduce_mean
(
tf
.
square
(
y
-
y_
))
optimizer
=
tf
.
train
.
GradientDescentOptimizer
(
learning_rate
=
0.01
)
train_op
=
optimizer
.
minimize
(
loss
,
name
=
'train'
)
init
=
tf
.
global_variables_initializer
()
# Creating a tf.train.Saver adds operations to the graph to save and
# restore variables from checkpoints.
saver_def
=
tf
.
train
.
Saver
().
as_saver_def
()
print
(
'Operation to initialize variables: '
,
init
.
name
)
print
(
'Tensor to feed as input data: '
,
x
.
name
)
print
(
'Tensor to feed as training targets: '
,
y_
.
name
)
print
(
'Tensor to fetch as prediction: '
,
y
.
name
)
print
(
'Operation to train one step: '
,
train_op
.
name
)
print
(
'Tensor to be fed for checkpoint filename:'
,
saver_def
.
filename_tensor_name
)
print
(
'Operation to save a checkpoint: '
,
saver_def
.
save_tensor_name
)
print
(
'Operation to restore a checkpoint: '
,
saver_def
.
restore_op_name
)
print
(
'Tensor to read value of W '
,
W
.
value
().
name
)
print
(
'Tensor to read value of b '
,
b
.
value
().
name
)
with
open
(
'graph.pb'
,
'w'
)
as
f
:
f
.
write
(
tf
.
get_default_graph
().
as_graph_def
().
SerializeToString
())
samples/java/training/model/graph.pb
0 → 100644
View file @
997f7871
File added
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment