Commit 94a2302a authored by Vivek Rathod's avatar Vivek Rathod Committed by TF Object Detection Team
Browse files

Load model in `setup` method to only do it once per thread/worker.

This is the likely reason for jobs that take forever.

PiperOrigin-RevId: 325275789
parent cbbba228
...@@ -78,7 +78,7 @@ class GenerateDetectionDataFn(beam.DoFn): ...@@ -78,7 +78,7 @@ class GenerateDetectionDataFn(beam.DoFn):
self._num_examples_processed = beam.metrics.Metrics.counter( self._num_examples_processed = beam.metrics.Metrics.counter(
'detection_data_generation', 'num_tf_examples_processed') 'detection_data_generation', 'num_tf_examples_processed')
def start_bundle(self): def setup(self):
self._load_inference_model() self._load_inference_model()
def _load_inference_model(self): def _load_inference_model(self):
......
...@@ -212,7 +212,7 @@ class GenerateDetectionDataTest(tf.test.TestCase): ...@@ -212,7 +212,7 @@ class GenerateDetectionDataTest(tf.test.TestCase):
confidence_threshold = 0.8 confidence_threshold = 0.8
inference_fn = generate_detection_data.GenerateDetectionDataFn( inference_fn = generate_detection_data.GenerateDetectionDataFn(
saved_model_path, confidence_threshold) saved_model_path, confidence_threshold)
inference_fn.start_bundle() inference_fn.setup()
generated_example = self._create_tf_example() generated_example = self._create_tf_example()
self.assertAllEqual(tf.train.Example.FromString( self.assertAllEqual(tf.train.Example.FromString(
generated_example).features.feature['image/object/class/label'] generated_example).features.feature['image/object/class/label']
......
...@@ -157,7 +157,7 @@ class GenerateEmbeddingDataFn(beam.DoFn): ...@@ -157,7 +157,7 @@ class GenerateEmbeddingDataFn(beam.DoFn):
self._top_k_embedding_count = top_k_embedding_count self._top_k_embedding_count = top_k_embedding_count
self._bottom_k_embedding_count = bottom_k_embedding_count self._bottom_k_embedding_count = bottom_k_embedding_count
def start_bundle(self): def setup(self):
self._load_inference_model() self._load_inference_model()
def _load_inference_model(self): def _load_inference_model(self):
......
...@@ -250,7 +250,7 @@ class GenerateEmbeddingData(tf.test.TestCase): ...@@ -250,7 +250,7 @@ class GenerateEmbeddingData(tf.test.TestCase):
bottom_k_embedding_count = 0 bottom_k_embedding_count = 0
inference_fn = generate_embedding_data.GenerateEmbeddingDataFn( inference_fn = generate_embedding_data.GenerateEmbeddingDataFn(
saved_model_path, top_k_embedding_count, bottom_k_embedding_count) saved_model_path, top_k_embedding_count, bottom_k_embedding_count)
inference_fn.start_bundle() inference_fn.setup()
generated_example = self._create_tf_example() generated_example = self._create_tf_example()
self.assertAllEqual(tf.train.Example.FromString( self.assertAllEqual(tf.train.Example.FromString(
generated_example).features.feature['image/object/class/label'] generated_example).features.feature['image/object/class/label']
...@@ -268,7 +268,7 @@ class GenerateEmbeddingData(tf.test.TestCase): ...@@ -268,7 +268,7 @@ class GenerateEmbeddingData(tf.test.TestCase):
bottom_k_embedding_count = 0 bottom_k_embedding_count = 0
inference_fn = generate_embedding_data.GenerateEmbeddingDataFn( inference_fn = generate_embedding_data.GenerateEmbeddingDataFn(
saved_model_path, top_k_embedding_count, bottom_k_embedding_count) saved_model_path, top_k_embedding_count, bottom_k_embedding_count)
inference_fn.start_bundle() inference_fn.setup()
generated_example = self._create_tf_example() generated_example = self._create_tf_example()
self.assertAllEqual( self.assertAllEqual(
tf.train.Example.FromString(generated_example).features tf.train.Example.FromString(generated_example).features
...@@ -286,7 +286,7 @@ class GenerateEmbeddingData(tf.test.TestCase): ...@@ -286,7 +286,7 @@ class GenerateEmbeddingData(tf.test.TestCase):
bottom_k_embedding_count = 2 bottom_k_embedding_count = 2
inference_fn = generate_embedding_data.GenerateEmbeddingDataFn( inference_fn = generate_embedding_data.GenerateEmbeddingDataFn(
saved_model_path, top_k_embedding_count, bottom_k_embedding_count) saved_model_path, top_k_embedding_count, bottom_k_embedding_count)
inference_fn.start_bundle() inference_fn.setup()
generated_example = self._create_tf_example() generated_example = self._create_tf_example()
self.assertAllEqual( self.assertAllEqual(
tf.train.Example.FromString(generated_example).features tf.train.Example.FromString(generated_example).features
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment