Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
48270232
Commit
48270232
authored
Jan 22, 2020
by
Mark Daoust
Browse files
Move Java examples.
parent
ab361d21
Changes
24
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
0 additions
and
154 deletions
+0
-154
samples/languages/java/training/model/create_graph.py
samples/languages/java/training/model/create_graph.py
+0
-36
samples/languages/java/training/model/graph.pb
samples/languages/java/training/model/graph.pb
+0
-0
samples/languages/java/training/pom.xml
samples/languages/java/training/pom.xml
+0
-20
samples/languages/java/training/src/main/java/Train.java
samples/languages/java/training/src/main/java/Train.java
+0
-98
No files found.
samples/languages/java/training/model/create_graph.py
deleted
100644 → 0
View file @
ab361d21
from
__future__
import
print_function
import
tensorflow
as
tf
x
=
tf
.
placeholder
(
tf
.
float32
,
name
=
'input'
)
y_
=
tf
.
placeholder
(
tf
.
float32
,
name
=
'target'
)
W
=
tf
.
Variable
(
5.
,
name
=
'W'
)
b
=
tf
.
Variable
(
3.
,
name
=
'b'
)
y
=
x
*
W
+
b
y
=
tf
.
identity
(
y
,
name
=
'output'
)
loss
=
tf
.
reduce_mean
(
tf
.
square
(
y
-
y_
))
optimizer
=
tf
.
train
.
GradientDescentOptimizer
(
learning_rate
=
0.01
)
train_op
=
optimizer
.
minimize
(
loss
,
name
=
'train'
)
init
=
tf
.
global_variables_initializer
()
# Creating a tf.train.Saver adds operations to the graph to save and
# restore variables from checkpoints.
saver_def
=
tf
.
train
.
Saver
().
as_saver_def
()
print
(
'Operation to initialize variables: '
,
init
.
name
)
print
(
'Tensor to feed as input data: '
,
x
.
name
)
print
(
'Tensor to feed as training targets: '
,
y_
.
name
)
print
(
'Tensor to fetch as prediction: '
,
y
.
name
)
print
(
'Operation to train one step: '
,
train_op
.
name
)
print
(
'Tensor to be fed for checkpoint filename:'
,
saver_def
.
filename_tensor_name
)
print
(
'Operation to save a checkpoint: '
,
saver_def
.
save_tensor_name
)
print
(
'Operation to restore a checkpoint: '
,
saver_def
.
restore_op_name
)
print
(
'Tensor to read value of W '
,
W
.
value
().
name
)
print
(
'Tensor to read value of b '
,
b
.
value
().
name
)
with
open
(
'graph.pb'
,
'w'
)
as
f
:
f
.
write
(
tf
.
get_default_graph
().
as_graph_def
().
SerializeToString
())
samples/languages/java/training/model/graph.pb
deleted
100644 → 0
View file @
ab361d21
File deleted
samples/languages/java/training/pom.xml
deleted
100644 → 0
View file @
ab361d21
<project>
<modelVersion>
4.0.0
</modelVersion>
<groupId>
org.myorg
</groupId>
<artifactId>
training
</artifactId>
<version>
1.0-SNAPSHOT
</version>
<properties>
<exec.mainClass>
Train
</exec.mainClass>
<!-- The sample code requires at least JDK 1.7. -->
<!-- The maven compiler plugin defaults to a lower version -->
<maven.compiler.source>
1.7
</maven.compiler.source>
<maven.compiler.target>
1.7
</maven.compiler.target>
</properties>
<dependencies>
<dependency>
<groupId>
org.tensorflow
</groupId>
<artifactId>
tensorflow
</artifactId>
<version>
1.4.0
</version>
</dependency>
</dependencies>
</project>
samples/languages/java/training/src/main/java/Train.java
deleted
100644 → 0
View file @
ab361d21
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
import
java.nio.file.Files
;
import
java.nio.file.Paths
;
import
java.util.List
;
import
java.util.Random
;
import
org.tensorflow.Graph
;
import
org.tensorflow.Session
;
import
org.tensorflow.Tensor
;
import
org.tensorflow.Tensors
;
/**
* Training a trivial linear model.
*/
public
class
Train
{
public
static
void
main
(
String
[]
args
)
throws
Exception
{
if
(
args
.
length
!=
2
)
{
System
.
err
.
println
(
"Require two arguments: The GraphDef file and checkpoint directory"
);
System
.
exit
(
1
);
}
final
byte
[]
graphDef
=
Files
.
readAllBytes
(
Paths
.
get
(
args
[
0
]));
final
String
checkpointDir
=
args
[
1
];
final
boolean
checkpointExists
=
Files
.
exists
(
Paths
.
get
(
checkpointDir
));
try
(
Graph
graph
=
new
Graph
();
Session
sess
=
new
Session
(
graph
);
Tensor
<
String
>
checkpointPrefix
=
Tensors
.
create
(
Paths
.
get
(
checkpointDir
,
"ckpt"
).
toString
()))
{
graph
.
importGraphDef
(
graphDef
);
// Initialize or restore.
// The names of the tensors in the graph are printed out by the program
// that created the graph:
// https://github.com/tensorflow/models/blob/master/samples/languages/java/training/model/create_graph.py
if
(
checkpointExists
)
{
sess
.
runner
().
feed
(
"save/Const"
,
checkpointPrefix
).
addTarget
(
"save/restore_all"
).
run
();
}
else
{
sess
.
runner
().
addTarget
(
"init"
).
run
();
}
System
.
out
.
print
(
"Starting from : "
);
printVariables
(
sess
);
// Train a bunch of times.
// (Will be much more efficient if we sent batches instead of individual values).
final
Random
r
=
new
Random
();
final
int
NUM_EXAMPLES
=
500
;
for
(
int
i
=
1
;
i
<=
5
;
i
++)
{
for
(
int
n
=
0
;
n
<
NUM_EXAMPLES
;
n
++)
{
float
in
=
r
.
nextFloat
();
try
(
Tensor
<
Float
>
input
=
Tensors
.
create
(
in
);
Tensor
<
Float
>
target
=
Tensors
.
create
(
3
*
in
+
2
))
{
// Again the tensor names are from the program that created the graph.
// https://github.com/tensorflow/models/blob/master/samples/languages/java/training/model/create_graph.py
sess
.
runner
().
feed
(
"input"
,
input
).
feed
(
"target"
,
target
).
addTarget
(
"train"
).
run
();
}
}
System
.
out
.
printf
(
"After %5d examples: "
,
i
*
NUM_EXAMPLES
);
printVariables
(
sess
);
}
// Checkpoint.
// The feed and target name are from the program that created the graph.
// https://github.com/tensorflow/models/blob/master/samples/languages/java/training/model/create_graph.py.
sess
.
runner
().
feed
(
"save/Const"
,
checkpointPrefix
).
addTarget
(
"save/control_dependency"
).
run
();
// Example of "inference" in the same graph:
try
(
Tensor
<
Float
>
input
=
Tensors
.
create
(
1.0f
);
Tensor
<
Float
>
output
=
sess
.
runner
().
feed
(
"input"
,
input
).
fetch
(
"output"
).
run
().
get
(
0
).
expect
(
Float
.
class
))
{
System
.
out
.
printf
(
"For input %f, produced %f (ideally would produce 3*%f + 2)\n"
,
input
.
floatValue
(),
output
.
floatValue
(),
input
.
floatValue
());
}
}
}
private
static
void
printVariables
(
Session
sess
)
{
List
<
Tensor
<?>>
values
=
sess
.
runner
().
fetch
(
"W/read"
).
fetch
(
"b/read"
).
run
();
System
.
out
.
printf
(
"W = %f\tb = %f\n"
,
values
.
get
(
0
).
floatValue
(),
values
.
get
(
1
).
floatValue
());
for
(
Tensor
<?>
t
:
values
)
{
t
.
close
();
}
}
}
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment