"git@developer.sourcefind.cn:wangsen/mineru.git" did not exist on "6c9b23c3d39c265f639163e66746aa90c8d65ede"
Commit 4a91d110 authored by Zhichao Lu's avatar Zhichao Lu Committed by pkulzc
Browse files

Use get_data_files_path to access test data.

PiperOrigin-RevId: 190505306
parent 93b8168a
...@@ -32,20 +32,19 @@ from object_detection.builders import model_builder ...@@ -32,20 +32,19 @@ from object_detection.builders import model_builder
from object_detection.core import standard_fields as fields from object_detection.core import standard_fields as fields
from object_detection.utils import config_util from object_detection.utils import config_util
FLAGS = tf.flags.FLAGS
MODEL_NAME_FOR_TEST = model_test_util.SSD_INCEPTION_MODEL_NAME MODEL_NAME_FOR_TEST = model_test_util.SSD_INCEPTION_MODEL_NAME
def _get_data_path(): def _get_data_path():
"""Returns an absolute path to TFRecord file.""" """Returns an absolute path to TFRecord file."""
return os.path.join(FLAGS.test_srcdir, model_test_util.PATH_BASE, 'test_data', return os.path.join(tf.resource_loader.get_data_files_path(), 'test_data',
'pets_examples.record') 'pets_examples.record')
def _get_labelmap_path(): def _get_labelmap_path():
"""Returns an absolute path to label map file.""" """Returns an absolute path to label map file."""
return os.path.join(FLAGS.test_srcdir, model_test_util.PATH_BASE, 'data', return os.path.join(tf.resource_loader.get_data_files_path(), 'data',
'pet_label_map.pbtxt') 'pet_label_map.pbtxt')
......
...@@ -28,13 +28,12 @@ FLAGS = tf.flags.FLAGS ...@@ -28,13 +28,12 @@ FLAGS = tf.flags.FLAGS
FASTER_RCNN_MODEL_NAME = 'faster_rcnn_resnet50_pets' FASTER_RCNN_MODEL_NAME = 'faster_rcnn_resnet50_pets'
SSD_INCEPTION_MODEL_NAME = 'ssd_inception_v2_pets' SSD_INCEPTION_MODEL_NAME = 'ssd_inception_v2_pets'
PATH_BASE = 'google3/third_party/tensorflow_models/object_detection/'
def GetPipelineConfigPath(model_name): def GetPipelineConfigPath(model_name):
"""Returns path to the local pipeline config file.""" """Returns path to the local pipeline config file."""
return os.path.join(FLAGS.test_srcdir, PATH_BASE, 'samples', 'configs', return os.path.join(tf.resource_loader.get_data_files_path(), 'samples',
model_name + '.config') 'configs', model_name + '.config')
def InitializeFlags(model_name_for_test): def InitializeFlags(model_name_for_test):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment