"tools/imglab/vscode:/vscode.git/clone" did not exist on "b5511d2cb381435d7dfda27f3fd04b1ff377f974"
Commit 03269887 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 405767455
parent 8e00ce42
...@@ -30,6 +30,19 @@ EXAMPLE_IMAGE = ('third_party/tensorflow_models/official/vision/' ...@@ -30,6 +30,19 @@ EXAMPLE_IMAGE = ('third_party/tensorflow_models/official/vision/'
CKPTS = 'gs://**/efficientnets' CKPTS = 'gs://**/efficientnets'
def _copy_recursively(src: str, dst: str) -> None:
"""Recursively copy directory."""
for src_dir, _, src_files in tf.io.gfile.walk(src):
dst_dir = os.path.join(dst, os.path.relpath(src_dir, src))
if not tf.io.gfile.exists(dst_dir):
tf.io.gfile.makedirs(dst_dir)
for src_file in src_files:
tf.io.gfile.copy(
os.path.join(src_dir, src_file),
os.path.join(dst_dir, src_file),
overwrite=True)
class MobilenetEdgeTPUBlocksTest(tf.test.TestCase): class MobilenetEdgeTPUBlocksTest(tf.test.TestCase):
def setUp(self): def setUp(self):
...@@ -200,7 +213,7 @@ class MobilenetEdgeTPUPredictTest(tf.test.TestCase): ...@@ -200,7 +213,7 @@ class MobilenetEdgeTPUPredictTest(tf.test.TestCase):
def _copy_saved_model_to_local(self, model_ckpt): def _copy_saved_model_to_local(self, model_ckpt):
# Copy saved model to local first for speed # Copy saved model to local first for speed
tmp_path = '/tmp/saved_model' tmp_path = '/tmp/saved_model'
tf.io.gfile.RecursivelyCopyDir(model_ckpt, tmp_path, overwrite=True) _copy_recursively(model_ckpt, tmp_path)
return tmp_path return tmp_path
def _test_prediction(self, model_name, image_size): def _test_prediction(self, model_name, image_size):
......
...@@ -13,11 +13,13 @@ ...@@ -13,11 +13,13 @@
# limitations under the License. # limitations under the License.
"""Image classification task definition.""" """Image classification task definition."""
import os
import tempfile import tempfile
from typing import Any, List, Mapping, Optional, Tuple from typing import Any, List, Mapping, Optional, Tuple
from absl import logging from absl import logging
import tensorflow as tf import tensorflow as tf
from official.common import dataset_fn from official.common import dataset_fn
from official.core import base_task from official.core import base_task
from official.core import task_factory from official.core import task_factory
...@@ -30,6 +32,19 @@ from official.vision.beta.configs import image_classification as base_cfg ...@@ -30,6 +32,19 @@ from official.vision.beta.configs import image_classification as base_cfg
from official.vision.beta.dataloaders import input_reader_factory from official.vision.beta.dataloaders import input_reader_factory
def _copy_recursively(src: str, dst: str) -> None:
"""Recursively copy directory."""
for src_dir, _, src_files in tf.io.gfile.walk(src):
dst_dir = os.path.join(dst, os.path.relpath(src_dir, src))
if not tf.io.gfile.exists(dst_dir):
tf.io.gfile.makedirs(dst_dir)
for src_file in src_files:
tf.io.gfile.copy(
os.path.join(src_dir, src_file),
os.path.join(dst_dir, src_file),
overwrite=True)
def get_models() -> Mapping[str, tf.keras.Model]: def get_models() -> Mapping[str, tf.keras.Model]:
"""Returns the mapping from model type name to Keras model.""" """Returns the mapping from model type name to Keras model."""
model_mapping = {} model_mapping = {}
...@@ -61,8 +76,8 @@ def load_searched_model(saved_model_path: str) -> tf.keras.Model: ...@@ -61,8 +76,8 @@ def load_searched_model(saved_model_path: str) -> tf.keras.Model:
Loaded keras model. Loaded keras model.
""" """
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
if tf.io.gfile.IsDirectory(saved_model_path): if tf.io.gfile.isdir(saved_model_path):
tf.io.gfile.RecursivelyCopyDir(saved_model_path, tmp_dir, overwrite=True) _copy_recursively(saved_model_path, tmp_dir)
load_path = tmp_dir load_path = tmp_dir
else: else:
raise ValueError('Saved model path is invalid.') raise ValueError('Saved model path is invalid.')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment