Commit bbcfd6ba authored by Kaushik Shivakumar's avatar Kaushik Shivakumar
Browse files

save some changes

parent 745fed5c
...@@ -134,7 +134,7 @@ class BoxPredictor(object): ...@@ -134,7 +134,7 @@ class BoxPredictor(object):
pass pass
class KerasBoxPredictor(tf.keras.Model): class KerasBoxPredictor(tf.keras.layers.Layer):
"""Keras-based BoxPredictor.""" """Keras-based BoxPredictor."""
def __init__(self, is_training, num_classes, freeze_batchnorm, def __init__(self, is_training, num_classes, freeze_batchnorm,
......
...@@ -50,13 +50,16 @@ import io ...@@ -50,13 +50,16 @@ import io
import itertools import itertools
import json import json
import os import os
import apache_beam as beam
import numpy as np import numpy as np
import PIL.Image import PIL.Image
import six import six
import tensorflow.compat.v1 as tf import tensorflow.compat.v1 as tf
try:
import apache_beam as beam # pylint:disable=g-import-not-at-top
except ModuleNotFoundError:
pass
class ReKeyDataFn(beam.DoFn): class ReKeyDataFn(beam.DoFn):
"""Re-keys tfrecords by sequence_key. """Re-keys tfrecords by sequence_key.
......
...@@ -22,7 +22,7 @@ import datetime ...@@ -22,7 +22,7 @@ import datetime
import os import os
import tempfile import tempfile
import unittest import unittest
import apache_beam as beam
import numpy as np import numpy as np
import six import six
import tensorflow.compat.v1 as tf import tensorflow.compat.v1 as tf
...@@ -31,6 +31,12 @@ from object_detection.dataset_tools.context_rcnn import add_context_to_examples ...@@ -31,6 +31,12 @@ from object_detection.dataset_tools.context_rcnn import add_context_to_examples
from object_detection.utils import tf_version from object_detection.utils import tf_version
try:
import apache_beam as beam # pylint:disable=g-import-not-at-top
except ModuleNotFoundError:
pass
@contextlib.contextmanager @contextlib.contextmanager
def InMemoryTFRecord(entries): def InMemoryTFRecord(entries):
temp = tempfile.NamedTemporaryFile(delete=False) temp = tempfile.NamedTemporaryFile(delete=False)
......
...@@ -39,12 +39,16 @@ import io ...@@ -39,12 +39,16 @@ import io
import json import json
import logging import logging
import os import os
import apache_beam as beam
import numpy as np import numpy as np
import PIL.Image import PIL.Image
import tensorflow.compat.v1 as tf import tensorflow.compat.v1 as tf
from object_detection.utils import dataset_util from object_detection.utils import dataset_util
try:
import apache_beam as beam # pylint:disable=g-import-not-at-top
except ModuleNotFoundError:
pass
class ParseImage(beam.DoFn): class ParseImage(beam.DoFn):
"""A DoFn that parses a COCO-CameraTraps json and emits TFRecords.""" """A DoFn that parses a COCO-CameraTraps json and emits TFRecords."""
......
...@@ -22,7 +22,6 @@ import os ...@@ -22,7 +22,6 @@ import os
import tempfile import tempfile
import unittest import unittest
import apache_beam as beam
import numpy as np import numpy as np
from PIL import Image from PIL import Image
...@@ -30,6 +29,11 @@ import tensorflow.compat.v1 as tf ...@@ -30,6 +29,11 @@ import tensorflow.compat.v1 as tf
from object_detection.dataset_tools.context_rcnn import create_cococameratraps_tfexample_main from object_detection.dataset_tools.context_rcnn import create_cococameratraps_tfexample_main
from object_detection.utils import tf_version from object_detection.utils import tf_version
try:
import apache_beam as beam # pylint:disable=g-import-not-at-top
except ModuleNotFoundError:
pass
@unittest.skipIf(tf_version.is_tf2(), 'Skipping TF1.X only test.') @unittest.skipIf(tf_version.is_tf2(), 'Skipping TF1.X only test.')
class CreateCOCOCameraTrapsTfexampleTest(tf.test.TestCase): class CreateCOCOCameraTrapsTfexampleTest(tf.test.TestCase):
......
...@@ -48,8 +48,11 @@ from __future__ import print_function ...@@ -48,8 +48,11 @@ from __future__ import print_function
import argparse import argparse
import os import os
import threading import threading
import apache_beam as beam
import tensorflow.compat.v1 as tf import tensorflow.compat.v1 as tf
try:
import apache_beam as beam # pylint:disable=g-import-not-at-top
except ModuleNotFoundError:
pass
class GenerateDetectionDataFn(beam.DoFn): class GenerateDetectionDataFn(beam.DoFn):
......
...@@ -22,7 +22,6 @@ import contextlib ...@@ -22,7 +22,6 @@ import contextlib
import os import os
import tempfile import tempfile
import unittest import unittest
import apache_beam as beam
import numpy as np import numpy as np
import six import six
import tensorflow.compat.v1 as tf import tensorflow.compat.v1 as tf
...@@ -39,6 +38,11 @@ if six.PY2: ...@@ -39,6 +38,11 @@ if six.PY2:
else: else:
mock = unittest.mock mock = unittest.mock
try:
import apache_beam as beam # pylint:disable=g-import-not-at-top
except ModuleNotFoundError:
pass
class FakeModel(model.DetectionModel): class FakeModel(model.DetectionModel):
"""A Fake Detection model with expected output nodes from post-processing.""" """A Fake Detection model with expected output nodes from post-processing."""
......
...@@ -34,7 +34,8 @@ python tensorflow_models/object_detection/export_inference_graph.py \ ...@@ -34,7 +34,8 @@ python tensorflow_models/object_detection/export_inference_graph.py \
--input_type tf_example \ --input_type tf_example \
--pipeline_config_path path/to/faster_rcnn_model.config \ --pipeline_config_path path/to/faster_rcnn_model.config \
--trained_checkpoint_prefix path/to/model.ckpt \ --trained_checkpoint_prefix path/to/model.ckpt \
--output_directory path/to/exported_model_directory --output_directory path/to/exported_model_directory \
--additional_output_tensor_names detection_features
python generate_embedding_data.py \ python generate_embedding_data.py \
--alsologtostderr \ --alsologtostderr \
...@@ -52,11 +53,15 @@ import datetime ...@@ -52,11 +53,15 @@ import datetime
import os import os
import threading import threading
import apache_beam as beam
import numpy as np import numpy as np
import six import six
import tensorflow.compat.v1 as tf import tensorflow.compat.v1 as tf
try:
import apache_beam as beam # pylint:disable=g-import-not-at-top
except ModuleNotFoundError:
pass
class GenerateEmbeddingDataFn(beam.DoFn): class GenerateEmbeddingDataFn(beam.DoFn):
"""Generates embedding data for camera trap images. """Generates embedding data for camera trap images.
......
...@@ -21,7 +21,6 @@ import contextlib ...@@ -21,7 +21,6 @@ import contextlib
import os import os
import tempfile import tempfile
import unittest import unittest
import apache_beam as beam
import numpy as np import numpy as np
import six import six
import tensorflow.compat.v1 as tf import tensorflow.compat.v1 as tf
...@@ -38,6 +37,11 @@ if six.PY2: ...@@ -38,6 +37,11 @@ if six.PY2:
else: else:
mock = unittest.mock mock = unittest.mock
try:
import apache_beam as beam # pylint:disable=g-import-not-at-top
except ModuleNotFoundError:
pass
class FakeModel(model.DetectionModel): class FakeModel(model.DetectionModel):
"""A Fake Detection model with expected output nodes from post-processing.""" """A Fake Detection model with expected output nodes from post-processing."""
......
...@@ -61,7 +61,7 @@ class Head(object): ...@@ -61,7 +61,7 @@ class Head(object):
pass pass
class KerasHead(tf.keras.Model): class KerasHead(tf.keras.layers.Layer):
"""Keras head base class.""" """Keras head base class."""
def call(self, features): def call(self, features):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment