Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
ca15f5d9
Unverified
Commit
ca15f5d9
authored
Jan 16, 2018
by
Asim Shankar
Committed by
GitHub
Jan 16, 2018
Browse files
Merge pull request #3157 from asimshankar/java-samples
[samples]: Samples using the Java API.
parents
99491464
28bd85d1
Changes
24
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
147 additions
and
0 deletions
+147
-0
samples/languages/java/training/model/create_graph.py
samples/languages/java/training/model/create_graph.py
+36
-0
samples/languages/java/training/model/graph.pb
samples/languages/java/training/model/graph.pb
+0
-0
samples/languages/java/training/pom.xml
samples/languages/java/training/pom.xml
+20
-0
samples/languages/java/training/src/main/java/Train.java
samples/languages/java/training/src/main/java/Train.java
+91
-0
No files found.
samples/languages/java/training/model/create_graph.py
0 → 100644
View file @
ca15f5d9
from
__future__
import
print_function
import
tensorflow
as
tf
x
=
tf
.
placeholder
(
tf
.
float32
,
name
=
'input'
)
y_
=
tf
.
placeholder
(
tf
.
float32
,
name
=
'target'
)
W
=
tf
.
Variable
(
5.
,
name
=
'W'
)
b
=
tf
.
Variable
(
3.
,
name
=
'b'
)
y
=
x
*
W
+
b
y
=
tf
.
identity
(
y
,
name
=
'output'
)
loss
=
tf
.
reduce_mean
(
tf
.
square
(
y
-
y_
))
optimizer
=
tf
.
train
.
GradientDescentOptimizer
(
learning_rate
=
0.01
)
train_op
=
optimizer
.
minimize
(
loss
,
name
=
'train'
)
init
=
tf
.
global_variables_initializer
()
# Creating a tf.train.Saver adds operations to the graph to save and
# restore variables from checkpoints.
saver_def
=
tf
.
train
.
Saver
().
as_saver_def
()
print
(
'Operation to initialize variables: '
,
init
.
name
)
print
(
'Tensor to feed as input data: '
,
x
.
name
)
print
(
'Tensor to feed as training targets: '
,
y_
.
name
)
print
(
'Tensor to fetch as prediction: '
,
y
.
name
)
print
(
'Operation to train one step: '
,
train_op
.
name
)
print
(
'Tensor to be fed for checkpoint filename:'
,
saver_def
.
filename_tensor_name
)
print
(
'Operation to save a checkpoint: '
,
saver_def
.
save_tensor_name
)
print
(
'Operation to restore a checkpoint: '
,
saver_def
.
restore_op_name
)
print
(
'Tensor to read value of W '
,
W
.
value
().
name
)
print
(
'Tensor to read value of b '
,
b
.
value
().
name
)
with
open
(
'graph.pb'
,
'w'
)
as
f
:
f
.
write
(
tf
.
get_default_graph
().
as_graph_def
().
SerializeToString
())
samples/languages/java/training/model/graph.pb
0 → 100644
View file @
ca15f5d9
File added
samples/languages/java/training/pom.xml
0 → 100644
View file @
ca15f5d9
<project>
<modelVersion>
4.0.0
</modelVersion>
<groupId>
org.myorg
</groupId>
<artifactId>
training
</artifactId>
<version>
1.0-SNAPSHOT
</version>
<properties>
<exec.mainClass>
Train
</exec.mainClass>
<!-- The sample code requires at least JDK 1.7. -->
<!-- The maven compiler plugin defaults to a lower version -->
<maven.compiler.source>
1.7
</maven.compiler.source>
<maven.compiler.target>
1.7
</maven.compiler.target>
</properties>
<dependencies>
<dependency>
<groupId>
org.tensorflow
</groupId>
<artifactId>
tensorflow
</artifactId>
<version>
1.4.0
</version>
</dependency>
</dependencies>
</project>
samples/languages/java/training/src/main/java/Train.java
0 → 100644
View file @
ca15f5d9
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
import
java.nio.file.Files
;
import
java.nio.file.Paths
;
import
java.util.List
;
import
java.util.Random
;
import
org.tensorflow.Graph
;
import
org.tensorflow.Session
;
import
org.tensorflow.Tensor
;
import
org.tensorflow.Tensors
;
/**
* Training a trivial linear model.
*/
public
class
Train
{
public
static
void
main
(
String
[]
args
)
throws
Exception
{
if
(
args
.
length
!=
2
)
{
System
.
err
.
println
(
"Require two arguments: The GraphDef file and checkpoint directory"
);
System
.
exit
(
1
);
}
final
byte
[]
graphDef
=
Files
.
readAllBytes
(
Paths
.
get
(
args
[
0
]));
final
String
checkpointDir
=
args
[
1
];
final
boolean
checkpointExists
=
Files
.
exists
(
Paths
.
get
(
checkpointDir
));
try
(
Graph
graph
=
new
Graph
();
Session
sess
=
new
Session
(
graph
);
Tensor
<
String
>
checkpointPrefix
=
Tensors
.
create
(
Paths
.
get
(
checkpointDir
,
"ckpt"
).
toString
()))
{
graph
.
importGraphDef
(
graphDef
);
// Initialize or restore
if
(
checkpointExists
)
{
sess
.
runner
().
feed
(
"save/Const"
,
checkpointPrefix
).
addTarget
(
"save/restore_all"
).
run
();
}
else
{
sess
.
runner
().
addTarget
(
"init"
).
run
();
}
System
.
out
.
print
(
"Starting from : "
);
printVariables
(
sess
);
// Train a bunch of times.
// (Will be much more efficient if we sent batches instead of individual values).
final
Random
r
=
new
Random
();
final
int
NUM_EXAMPLES
=
500
;
for
(
int
i
=
1
;
i
<=
5
;
i
++)
{
for
(
int
n
=
0
;
n
<
NUM_EXAMPLES
;
n
++)
{
float
in
=
r
.
nextFloat
();
try
(
Tensor
<
Float
>
input
=
Tensors
.
create
(
in
);
Tensor
<
Float
>
target
=
Tensors
.
create
(
3
*
in
+
2
))
{
sess
.
runner
().
feed
(
"input"
,
input
).
feed
(
"target"
,
target
).
addTarget
(
"train"
).
run
();
}
}
System
.
out
.
printf
(
"After %5d examples: "
,
i
*
NUM_EXAMPLES
);
printVariables
(
sess
);
}
// Checkpoint
sess
.
runner
().
feed
(
"save/Const"
,
checkpointPrefix
).
addTarget
(
"save/control_dependency"
).
run
();
// Example of "inference" in the same graph:
try
(
Tensor
<
Float
>
input
=
Tensors
.
create
(
1.0f
);
Tensor
<
Float
>
output
=
sess
.
runner
().
feed
(
"input"
,
input
).
fetch
(
"output"
).
run
().
get
(
0
).
expect
(
Float
.
class
))
{
System
.
out
.
printf
(
"For input %f, produced %f (ideally would produce 3*%f + 2)\n"
,
input
.
floatValue
(),
output
.
floatValue
(),
input
.
floatValue
());
}
}
}
private
static
void
printVariables
(
Session
sess
)
{
List
<
Tensor
<?>>
values
=
sess
.
runner
().
fetch
(
"W/read"
).
fetch
(
"b/read"
).
run
();
System
.
out
.
printf
(
"W = %f\tb = %f\n"
,
values
.
get
(
0
).
floatValue
(),
values
.
get
(
1
).
floatValue
());
for
(
Tensor
<?>
t
:
values
)
{
t
.
close
();
}
}
}
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment