"...resnet50_tensorflow.git" did not exist on "c2902cfb1afb5370dd7fc7cb55d0b506475efcc2"
Commit bbcfd6ba authored by Kaushik Shivakumar's avatar Kaushik Shivakumar
Browse files

save some changes

parent 745fed5c
......@@ -134,7 +134,7 @@ class BoxPredictor(object):
pass
class KerasBoxPredictor(tf.keras.Model):
class KerasBoxPredictor(tf.keras.layers.Layer):
"""Keras-based BoxPredictor."""
def __init__(self, is_training, num_classes, freeze_batchnorm,
......
......@@ -50,13 +50,16 @@ import io
import itertools
import json
import os
import apache_beam as beam
import numpy as np
import PIL.Image
import six
import tensorflow.compat.v1 as tf
try:
import apache_beam as beam # pylint:disable=g-import-not-at-top
except ModuleNotFoundError:
pass
class ReKeyDataFn(beam.DoFn):
"""Re-keys tfrecords by sequence_key.
......
......@@ -22,7 +22,7 @@ import datetime
import os
import tempfile
import unittest
import apache_beam as beam
import numpy as np
import six
import tensorflow.compat.v1 as tf
......@@ -31,6 +31,12 @@ from object_detection.dataset_tools.context_rcnn import add_context_to_examples
from object_detection.utils import tf_version
try:
import apache_beam as beam # pylint:disable=g-import-not-at-top
except ModuleNotFoundError:
pass
@contextlib.contextmanager
def InMemoryTFRecord(entries):
temp = tempfile.NamedTemporaryFile(delete=False)
......
......@@ -39,12 +39,16 @@ import io
import json
import logging
import os
import apache_beam as beam
import numpy as np
import PIL.Image
import tensorflow.compat.v1 as tf
from object_detection.utils import dataset_util
try:
import apache_beam as beam # pylint:disable=g-import-not-at-top
except ModuleNotFoundError:
pass
class ParseImage(beam.DoFn):
"""A DoFn that parses a COCO-CameraTraps json and emits TFRecords."""
......
......@@ -22,7 +22,6 @@ import os
import tempfile
import unittest
import apache_beam as beam
import numpy as np
from PIL import Image
......@@ -30,6 +29,11 @@ import tensorflow.compat.v1 as tf
from object_detection.dataset_tools.context_rcnn import create_cococameratraps_tfexample_main
from object_detection.utils import tf_version
try:
import apache_beam as beam # pylint:disable=g-import-not-at-top
except ModuleNotFoundError:
pass
@unittest.skipIf(tf_version.is_tf2(), 'Skipping TF1.X only test.')
class CreateCOCOCameraTrapsTfexampleTest(tf.test.TestCase):
......
......@@ -48,8 +48,11 @@ from __future__ import print_function
import argparse
import os
import threading
import apache_beam as beam
import tensorflow.compat.v1 as tf
try:
import apache_beam as beam # pylint:disable=g-import-not-at-top
except ModuleNotFoundError:
pass
class GenerateDetectionDataFn(beam.DoFn):
......
......@@ -22,7 +22,6 @@ import contextlib
import os
import tempfile
import unittest
import apache_beam as beam
import numpy as np
import six
import tensorflow.compat.v1 as tf
......@@ -39,6 +38,11 @@ if six.PY2:
else:
mock = unittest.mock
try:
import apache_beam as beam # pylint:disable=g-import-not-at-top
except ModuleNotFoundError:
pass
class FakeModel(model.DetectionModel):
"""A Fake Detection model with expected output nodes from post-processing."""
......
......@@ -34,7 +34,8 @@ python tensorflow_models/object_detection/export_inference_graph.py \
--input_type tf_example \
--pipeline_config_path path/to/faster_rcnn_model.config \
--trained_checkpoint_prefix path/to/model.ckpt \
--output_directory path/to/exported_model_directory
--output_directory path/to/exported_model_directory \
--additional_output_tensor_names detection_features
python generate_embedding_data.py \
--alsologtostderr \
......@@ -52,11 +53,15 @@ import datetime
import os
import threading
import apache_beam as beam
import numpy as np
import six
import tensorflow.compat.v1 as tf
try:
import apache_beam as beam # pylint:disable=g-import-not-at-top
except ModuleNotFoundError:
pass
class GenerateEmbeddingDataFn(beam.DoFn):
"""Generates embedding data for camera trap images.
......
......@@ -21,7 +21,6 @@ import contextlib
import os
import tempfile
import unittest
import apache_beam as beam
import numpy as np
import six
import tensorflow.compat.v1 as tf
......@@ -38,6 +37,11 @@ if six.PY2:
else:
mock = unittest.mock
try:
import apache_beam as beam # pylint:disable=g-import-not-at-top
except ModuleNotFoundError:
pass
class FakeModel(model.DetectionModel):
"""A Fake Detection model with expected output nodes from post-processing."""
......
......@@ -61,7 +61,7 @@ class Head(object):
pass
class KerasHead(tf.keras.Model):
class KerasHead(tf.keras.layers.Layer):
"""Keras head base class."""
def call(self, features):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment