ModelZoo / ResNet50_tensorflow / Commits / 00fa8b12
"scripts/deprecated/test_httpserver_llava.py" did not exist on "ae7ee01a8e59f755d47426c4b08641053b765a89"
Unverified commit 00fa8b12, authored Jan 26, 2018 by cclauss and committed by GitHub on Jan 26, 2018.

    Merge branch 'master' into patch-13

Parents: 6d257a4f, 1f34fcaf
Changes: 328
Showing 20 changed files with 57 additions and 28 deletions (+57 -28).
  research/neural_programmer/wiki_data.py                              +6  -4
  research/next_frame_prediction/cross_conv/eval.py                    +1  -0
  research/next_frame_prediction/cross_conv/example_gen.py             +1  -0
  research/next_frame_prediction/cross_conv/model.py                   +1  -0
  research/next_frame_prediction/cross_conv/reader.py                  +1  -0
  research/next_frame_prediction/cross_conv/sprites_gen.py             +1  -0
  research/object_detection/dataset_tools/__init__.py                  +1  -0
  research/object_detection/dataset_tools/create_oid_tf_record.py      +1  -1
  research/object_detection/dataset_tools/create_pet_tf_record.py      +3  -2
  research/object_detection/dataset_tools/oid_tfrecord_creation.py     +1  -0
  research/object_detection/eval.py                                    +4  -1
  research/object_detection/exporter.py                                +1  -1
  research/object_detection/g3doc/exporting_models.md                  +1  -1
  research/object_detection/g3doc/running_on_cloud.md                  +5  -1
  research/object_detection/g3doc/running_pets.md                      +5  -1
  research/object_detection/object_detection_tutorial.ipynb            +2  -2
  research/object_detection/utils/np_box_list_ops.py                   +1  -0
  research/object_detection/utils/ops.py                               +2  -2
  research/object_detection/utils/visualization_utils_test.py          +1  -1
  research/pcl_rl/README.md                                            +18 -11
research/neural_programmer/wiki_data.py  (+6 -4)

@@ -22,6 +22,8 @@ columns.
 lookup answer (or matrix) is also split into number and word lookup matrix
 Author: aneelakantan (Arvind Neelakantan)
 """
+from __future__ import print_function
 import math
 import os
 import re

@@ -56,7 +58,7 @@ def correct_unicode(string):
   #string = re.sub("[“”«»]", "\"", string)
   #string = re.sub("[•†‡]", "", string)
   #string = re.sub("[‐‑–—]", "-", string)
-  string = re.sub(ur'[\u2E00-\uFFFF]', "", string)
+  string = re.sub(r'[\u2E00-\uFFFF]', "", string)
   string = re.sub("\\s+", " ", string).strip()
   return string

@@ -78,7 +80,7 @@ def full_normalize(string):
   # Remove trailing info in brackets
   string = re.sub("\[[^\]]*\]", "", string)
   # Remove most unicode characters in other languages
-  string = re.sub(ur'[\u007F-\uFFFF]', "", string.strip())
+  string = re.sub(r'[\u007F-\uFFFF]', "", string.strip())
   # Remove trailing info in parenthesis
   string = re.sub("\([^)]*\)$", "", string.strip())
   string = final_normalize(string)

@@ -207,7 +209,7 @@ class WikiQuestionGenerator(object):
     self.dev_loader = WikiQuestionLoader(dev_name, root_folder)
     self.test_loader = WikiQuestionLoader(test_name, root_folder)
     self.bad_examples = 0
     self.root_folder = root_folder
     self.data_folder = os.path.join(self.root_folder, "annotated/data")
     self.annotated_examples = {}
     self.annotated_tables = {}

@@ -298,7 +300,7 @@ class WikiQuestionGenerator(object):
             question_id, question, target_canon, context)
         self.annotated_tables[context] = []
       counter += 1
-    print "Annotated examples loaded ", len(self.annotated_examples)
+    print("Annotated examples loaded ", len(self.annotated_examples))
     f.close()

   def is_number_column(self, a):
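The wiki_data.py change above (and the matching one in visualization_utils_test.py further down) follows the usual Python 2-to-3 migration pattern: add `from __future__ import print_function` at the top of the module, then call `print()` as a function. A minimal standalone sketch of that pattern, with made-up data, not code from this repository:

    from __future__ import print_function

    annotated_examples = {"nt-0": "example question"}

    # Without the __future__ import, a multi-argument print(...) on Python 2
    # prints a tuple; with it, the call below behaves the same on 2 and 3.
    print("Annotated examples loaded ", len(annotated_examples))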
research/next_frame_prediction/cross_conv/eval.py  (+1 -0)

@@ -20,6 +20,7 @@ import sys
 import time
 import numpy as np
+from six.moves import xrange
 import tensorflow as tf
 import model as cross_conv_model
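The `from six.moves import xrange` line added here (and to example_gen.py, model.py, reader.py, sprites_gen.py, oid_tfrecord_creation.py, and np_box_list_ops.py below) is the standard way to keep `xrange` loops working on Python 3, where that builtin no longer exists. A small illustrative sketch, not taken from the repository:

    from six.moves import xrange  # builtin xrange on Python 2, range on Python 3

    total = 0
    for i in xrange(5):  # lazy iteration on both interpreters
        total += i
    print(total)  # prints 10 under Python 2 and Python 3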
research/next_frame_prediction/cross_conv/example_gen.py  (+1 -0)

@@ -18,6 +18,7 @@ import random
 import sys
 import numpy as np
+from six.moves import xrange
 import tensorflow as tf
research/next_frame_prediction/cross_conv/model.py  (+1 -0)

@@ -20,6 +20,7 @@ https://arxiv.org/pdf/1607.02586v1.pdf
 import math
 import sys
+from six.moves import xrange
 import tensorflow as tf

 slim = tf.contrib.slim
research/next_frame_prediction/cross_conv/reader.py  (+1 -0)

@@ -15,6 +15,7 @@
 """Read image sequence."""
+from six.moves import xrange
 import tensorflow as tf
research/next_frame_prediction/cross_conv/sprites_gen.py  (+1 -0)

@@ -21,6 +21,7 @@ import sys
 import numpy as np
 import scipy.misc
+from six.moves import xrange
 import tensorflow as tf
research/object_detection/dataset_tools/__init__.py  (+1 -0, new file, mode 0 → 100644)
research/object_detection/dataset_tools/create_oid_tf_record.py  (+1 -1)

@@ -96,7 +96,7 @@ def main(_):
     tf_example = oid_tfrecord_creation.tf_example_from_annotations_data_frame(
         image_annotations, label_map, encoded_image)
     if tf_example:
-      shard_idx = long(image_id, 16) % FLAGS.num_shards
+      shard_idx = int(image_id, 16) % FLAGS.num_shards
       output_tfrecords[shard_idx].write(tf_example.SerializeToString())
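Replacing `long(image_id, 16)` with `int(image_id, 16)` matters because `long` no longer exists in Python 3, while `int` already handles arbitrarily large values on both versions. A hedged sketch of the sharding idea, with a hypothetical shard count and hex image id for illustration:

    num_shards = 10                        # hypothetical shard count
    image_id = "000002b66c9c498e"          # hypothetical Open Images-style hex id

    # int() parses arbitrarily large hex strings on Python 2 and 3,
    # so the long() call is unnecessary and would fail on Python 3.
    shard_idx = int(image_id, 16) % num_shards
    print(shard_idx)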
research/object_detection/dataset_tools/create_pet_tf_record.py  (+3 -2)

@@ -160,8 +160,6 @@ def dict_to_tf_example(data,
       if not faces_only:
         mask_remapped = mask_np != 2
         masks.append(mask_remapped)
-  mask_stack = np.stack(masks).astype(np.float32)
-  masks_flattened = np.reshape(mask_stack, [-1])
   feature_dict = {
       'image/height': dataset_util.int64_feature(height),

@@ -184,8 +182,11 @@ def dict_to_tf_example(data,
       'image/object/view': dataset_util.bytes_list_feature(poses),
   }
   if not faces_only:
+    mask_stack = np.stack(masks).astype(np.float32)
+    masks_flattened = np.reshape(mask_stack, [-1])
     feature_dict['image/object/mask'] = (
         dataset_util.float_list_feature(masks_flattened.tolist()))
   example = tf.train.Example(features=tf.train.Features(feature=feature_dict))
   return example
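This change moves the `np.stack` / `np.reshape` calls from before the feature dict into the later `if not faces_only:` branch, presumably so the mask tensors are only built when masks were actually collected (`np.stack` raises on an empty list). A minimal sketch of that guard, with toy masks standing in for the real instance masks:

    import numpy as np

    faces_only = False                               # hypothetical flag value
    masks = [np.zeros((2, 2)), np.ones((2, 2))]      # toy stand-ins for instance masks

    feature_dict = {"image/height": 2}
    if not faces_only:
        # Only build the mask features when masks exist; stacking an empty
        # list would raise a ValueError.
        mask_stack = np.stack(masks).astype(np.float32)
        masks_flattened = np.reshape(mask_stack, [-1])
        feature_dict["image/object/mask"] = masks_flattened.tolist()
    print(sorted(feature_dict))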
research/object_detection/dataset_tools/oid_tfrecord_creation.py  (+1 -0)

@@ -18,6 +18,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
+from six.moves import xrange
 import tensorflow as tf
 from object_detection.core import standard_fields
research/object_detection/eval.py  (+4 -1)

@@ -103,7 +103,10 @@ def main(unused_argv):
   model_config = configs['model']
   eval_config = configs['eval_config']
-  input_config = configs['eval_input_config']
+  if FLAGS.eval_training_data:
+    input_config = configs['train_input_config']
+  else:
+    input_config = configs['eval_input_config']
   model_fn = functools.partial(
       model_builder.build,
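The eval.py hunk adds a flag-controlled switch so the same evaluation binary can score either the eval split or the training split. A small sketch of the dispatch it introduces, using a plain dict and boolean in place of the real configs/FLAGS objects:

    configs = {
        "train_input_config": "train input reader",   # placeholder values
        "eval_input_config": "eval input reader",
    }
    eval_training_data = False   # stands in for FLAGS.eval_training_data

    if eval_training_data:
        input_config = configs["train_input_config"]
    else:
        input_config = configs["eval_input_config"]
    print(input_config)   # "eval input reader"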
research/object_detection/exporter.py  (+1 -1)

@@ -69,7 +69,7 @@ def freeze_graph_with_def_protos(
   if optimize_graph:
     logging.info('Graph Rewriter optimizations enabled')
     rewrite_options = rewriter_config_pb2.RewriterConfig(
-        optimize_tensor_layout=True)
+        layout_optimizer=rewriter_config_pb2.RewriterConfig.ON)
     rewrite_options.optimizers.append('pruning')
     rewrite_options.optimizers.append('constfold')
     rewrite_options.optimizers.append('layout')
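The exporter.py hunk tracks a TensorFlow API change: `RewriterConfig` no longer takes an `optimize_tensor_layout` boolean, so the layout optimizer is now enabled through the `layout_optimizer` enum field. A hedged sketch of constructing that proto, assuming the `tensorflow.core.protobuf.rewriter_config_pb2` import path the exporter uses in TensorFlow of this era:

    from tensorflow.core.protobuf import rewriter_config_pb2

    # Turn the layout optimizer on via the enum field used in the new code,
    # then request the same rewriter passes the exporter appends.
    rewrite_options = rewriter_config_pb2.RewriterConfig(
        layout_optimizer=rewriter_config_pb2.RewriterConfig.ON)
    rewrite_options.optimizers.append('pruning')
    rewrite_options.optimizers.append('constfold')
    rewrite_options.optimizers.append('layout')
    print(rewrite_options)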
research/object_detection/g3doc/exporting_models.md  (+1 -1)

@@ -8,7 +8,7 @@ graph proto. A checkpoint will typically consist of three files:
 * model.ckpt-${CHECKPOINT_NUMBER}.meta

 After you've identified a candidate checkpoint to export, run the following
-command from tensorflow/models/research/object_detection:
+command from tensorflow/models/research/:

 ```bash
 # From tensorflow/models/research/
research/object_detection/g3doc/running_on_cloud.md  (+5 -1)

@@ -42,7 +42,7 @@ job using GPUs. A sample YAML file is given below:
 ```
 trainingInput:
-  runtimeVersion: "1.0"
+  runtimeVersion: "1.2"
   scaleTier: CUSTOM
   masterType: standard_gpu
   workerCount: 9

@@ -71,6 +71,7 @@ following command:
 ```bash
 # From tensorflow/models/research/
 gcloud ml-engine jobs submit training object_detection_`date +%s` \
+    --runtime-version 1.2 \
     --job-dir=gs://${TRAIN_DIR} \
     --packages dist/object_detection-0.1.tar.gz,slim/dist/slim-0.1.tar.gz \
     --module-name object_detection.train \

@@ -90,6 +91,8 @@ Google Cloud Storage.
 Users can monitor the progress of their training job on the [ML Engine
 Dashboard](https://console.cloud.google.com/mlengine/jobs).
+
+Note: This sample is supported for use with 1.2 runtime version.

 ## Running an Evaluation Job on Cloud

 Evaluation jobs run on a single machine, so it is not necessary to write a YAML

@@ -98,6 +101,7 @@ job:
 ```bash
 gcloud ml-engine jobs submit training object_detection_eval_`date +%s` \
+    --runtime-version 1.2 \
     --job-dir=gs://${TRAIN_DIR} \
     --packages dist/object_detection-0.1.tar.gz,slim/dist/slim-0.1.tar.gz \
     --module-name object_detection.eval \
research/object_detection/g3doc/running_pets.md  (+5 -1)

@@ -81,7 +81,7 @@ Oxford-IIIT Pet dataset into TFRecords. Run the following commands from the
 ```bash
 # From tensorflow/models/research/
-python object_detection/create_pet_tf_record.py \
+python object_detection/dataset_tools/create_pet_tf_record.py \
     --label_map_path=object_detection/data/pet_label_map.pbtxt \
     --data_dir=`pwd` \
     --output_dir=`pwd`

@@ -203,12 +203,15 @@ For running the training Cloud ML job, we'll configure the cluster to use 10
 training jobs (1 master + 9 workers) and three parameters servers. The
 configuration file can be found at `object_detection/samples/cloud/cloud.yml`.
+
+Note: This sample is supported for use with 1.2 runtime version.

 To start training, execute the following command from the
 `tensorflow/models/research/` directory:

 ```bash
 # From tensorflow/models/research/
 gcloud ml-engine jobs submit training `whoami`_object_detection_`date +%s` \
+    --runtime-version 1.2 \
     --job-dir=gs://${YOUR_GCS_BUCKET}/train \
     --packages dist/object_detection-0.1.tar.gz,slim/dist/slim-0.1.tar.gz \
     --module-name object_detection.train \

@@ -224,6 +227,7 @@ Once training has started, we can run an evaluation concurrently:
 ```bash
 # From tensorflow/models/research/
 gcloud ml-engine jobs submit training `whoami`_object_detection_eval_`date +%s` \
+    --runtime-version 1.2 \
     --job-dir=gs://${YOUR_GCS_BUCKET}/train \
     --packages dist/object_detection-0.1.tar.gz,slim/dist/slim-0.1.tar.gz \
     --module-name object_detection.eval \
research/object_detection/object_detection_tutorial.ipynb  (+2 -2)

@@ -36,8 +36,8 @@
     "from matplotlib import pyplot as plt\n",
     "from PIL import Image\n",
     "\n",
-    "if tf.__version__ != '1.4.0':\n",
-    "  raise ImportError('Please upgrade your tensorflow installation to v1.4.0!')\n"
+    "if tf.__version__ < '1.4.0':\n",
+    "  raise ImportError('Please upgrade your tensorflow installation to v1.4.* or later!')\n"
    ]
   },
   {
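The notebook now uses `tf.__version__ < '1.4.0'` so any 1.4.x or later release passes the check instead of only the exact '1.4.0' string. This is still a plain string comparison, which is fine for these values but can misorder versions in general (for example, '1.10.0' < '1.4.0' is True lexicographically). As a hedged aside, not what the notebook does, a parsed-version comparison avoids that pitfall:

    from distutils.version import LooseVersion

    tf_version = "1.4.1"   # stands in for tf.__version__
    if LooseVersion(tf_version) < LooseVersion("1.4.0"):
        raise ImportError("Please upgrade your tensorflow installation to v1.4.* or later!")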
research/object_detection/utils/np_box_list_ops.py  (+1 -0)

@@ -21,6 +21,7 @@ Example box operations that are supported:
 """
 import numpy as np
+from six.moves import xrange
 from object_detection.utils import np_box_list
 from object_detection.utils import np_box_ops
research/object_detection/utils/ops.py  (+2 -2)

@@ -203,9 +203,9 @@ def padded_one_hot_encoding(indices, depth, left_pad):
   TODO: add runtime checks for depth and indices.
   """
-  if depth < 0 or not isinstance(depth, (int, long) if six.PY2 else int):
+  if depth < 0 or not isinstance(depth, six.integer_types):
     raise ValueError('`depth` must be a non-negative integer.')
-  if left_pad < 0 or not isinstance(left_pad, (int, long) if six.PY2 else int):
+  if left_pad < 0 or not isinstance(left_pad, six.integer_types):
     raise ValueError('`left_pad` must be a non-negative integer.')
   if depth == 0:
     return None
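The ops.py change replaces the inline `(int, long) if six.PY2 else int` expression with `six.integer_types`, which six defines as `(int, long)` on Python 2 and `(int,)` on Python 3, so the isinstance check reads the same on both. A minimal sketch of the check in isolation (the helper name is made up for illustration):

    import six

    def check_depth(depth):
        # Mirrors the validation style in padded_one_hot_encoding:
        # reject negative values and non-integer types on Python 2 and 3.
        if depth < 0 or not isinstance(depth, six.integer_types):
            raise ValueError('`depth` must be a non-negative integer.')
        return depth

    check_depth(4)        # passes
    # check_depth(4.5)    # would raise ValueError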
research/object_detection/utils/visualization_utils_test.py  (+1 -1)

@@ -145,7 +145,7 @@ class VisualizationUtilsTest(tf.test.TestCase):
       for i in range(images_with_boxes_np.shape[0]):
         img_name = 'image_' + str(i) + '.png'
         output_file = os.path.join(self.get_temp_dir(), img_name)
-        print 'Writing output image %d to %s' % (i, output_file)
+        print('Writing output image %d to %s' % (i, output_file))
         image_pil = Image.fromarray(images_with_boxes_np[i, ...])
         image_pil.save(output_file)
research/pcl_rl/README.md  (+18 -11)

@@ -67,20 +67,27 @@ python trainer.py --logtostderr --batch_size=25 --env=HalfCheetah-v1 \
   --max_divergence=0.05 --value_opt=best_fit --critic_weight=0.0 \
 ```

-Run Mujoco task with Trust-PCL:
+To run Mujoco task using Trust-PCL (off-policy) use the below command.
+It should work well across all environments, given that you
+search sufficiently among
+(1) max_divergence (0.001, 0.0005, 0.002 are good values),
+(2) rollout (1, 5, 10 are good values),
+(3) tf_seed (need to average over enough random seeds).

 ```
 python trainer.py --logtostderr --batch_size=1 --env=HalfCheetah-v1 \
-  --validation_frequency=50 --rollout=10 --critic_weight=0.0 \
-  --gamma=0.995 --clip_norm=40 --learning_rate=0.002 \
-  --replay_buffer_freq=1 --replay_buffer_size=20000 \
-  --replay_buffer_alpha=0.1 --norecurrent --objective=pcl \
-  --max_step=100 --tau=0.0 --eviction=fifo --max_divergence=0.001 \
-  --internal_dim=64 --cutoff_agent=1000 \
-  --replay_batch_size=25 --nouse_online_batch --batch_by_steps \
-  --sample_from=target --value_opt=grad --value_hidden_layers=2 \
-  --update_eps_lambda --unify_episodes --clip_adv=1.0 \
-  --target_network_lag=0.99 --prioritize_by=step
+  --validation_frequency=250 --rollout=1 --critic_weight=1.0 --gamma=0.995 \
+  --clip_norm=40 --learning_rate=0.0001 --replay_buffer_freq=1 \
+  --replay_buffer_size=5000 --replay_buffer_alpha=0.001 --norecurrent \
+  --objective=pcl --max_step=10 --cutoff_agent=1000 --tau=0.0 --eviction=fifo \
+  --max_divergence=0.001 --internal_dim=256 --replay_batch_size=64 \
+  --nouse_online_batch --batch_by_steps --value_hidden_layers=2 \
+  --update_eps_lambda --nounify_episodes --target_network_lag=0.99 \
+  --sample_from=online --clip_adv=1 --prioritize_by=step --num_steps=1000000 \
+  --noinput_prev_actions --use_target_values --tf_seed=57
 ```

 Run Mujoco task with PCL constraint trust region: