"torchvision/vscode:/vscode.git/clone" did not exist on "ab0b9a436bd64c4d0309f1b700868c2fe73c0f3e"
Commit 03269887 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 405767455
parent 8e00ce42
......@@ -30,6 +30,19 @@ EXAMPLE_IMAGE = ('third_party/tensorflow_models/official/vision/'
CKPTS = 'gs://**/efficientnets'
def _copy_recursively(src: str, dst: str) -> None:
"""Recursively copy directory."""
for src_dir, _, src_files in tf.io.gfile.walk(src):
dst_dir = os.path.join(dst, os.path.relpath(src_dir, src))
if not tf.io.gfile.exists(dst_dir):
tf.io.gfile.makedirs(dst_dir)
for src_file in src_files:
tf.io.gfile.copy(
os.path.join(src_dir, src_file),
os.path.join(dst_dir, src_file),
overwrite=True)
class MobilenetEdgeTPUBlocksTest(tf.test.TestCase):
def setUp(self):
......@@ -200,7 +213,7 @@ class MobilenetEdgeTPUPredictTest(tf.test.TestCase):
def _copy_saved_model_to_local(self, model_ckpt):
# Copy saved model to local first for speed
tmp_path = '/tmp/saved_model'
tf.io.gfile.RecursivelyCopyDir(model_ckpt, tmp_path, overwrite=True)
_copy_recursively(model_ckpt, tmp_path)
return tmp_path
def _test_prediction(self, model_name, image_size):
......
......@@ -13,11 +13,13 @@
# limitations under the License.
"""Image classification task definition."""
import os
import tempfile
from typing import Any, List, Mapping, Optional, Tuple
from absl import logging
import tensorflow as tf
from official.common import dataset_fn
from official.core import base_task
from official.core import task_factory
......@@ -30,6 +32,19 @@ from official.vision.beta.configs import image_classification as base_cfg
from official.vision.beta.dataloaders import input_reader_factory
def _copy_recursively(src: str, dst: str) -> None:
"""Recursively copy directory."""
for src_dir, _, src_files in tf.io.gfile.walk(src):
dst_dir = os.path.join(dst, os.path.relpath(src_dir, src))
if not tf.io.gfile.exists(dst_dir):
tf.io.gfile.makedirs(dst_dir)
for src_file in src_files:
tf.io.gfile.copy(
os.path.join(src_dir, src_file),
os.path.join(dst_dir, src_file),
overwrite=True)
def get_models() -> Mapping[str, tf.keras.Model]:
"""Returns the mapping from model type name to Keras model."""
model_mapping = {}
......@@ -61,8 +76,8 @@ def load_searched_model(saved_model_path: str) -> tf.keras.Model:
Loaded keras model.
"""
with tempfile.TemporaryDirectory() as tmp_dir:
if tf.io.gfile.IsDirectory(saved_model_path):
tf.io.gfile.RecursivelyCopyDir(saved_model_path, tmp_dir, overwrite=True)
if tf.io.gfile.isdir(saved_model_path):
_copy_recursively(saved_model_path, tmp_dir)
load_path = tmp_dir
else:
raise ValueError('Saved model path is invalid.')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment