Commit 609f332f authored by Stephan Lee's avatar Stephan Lee Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 460731093
parent caaa39c2
......@@ -14,6 +14,7 @@
"""Provides the `ExportSavedModel` action and associated helper classes."""
import os
import re
from typing import Callable, Optional
......@@ -77,9 +78,9 @@ class ExportFileManager:
One common alternative maybe be to use the current global step count,
for instance passing `next_id_fn=global_step.numpy`.
"""
self._base_name = base_name
self._base_name = os.path.normpath(base_name)
self._max_to_keep = max_to_keep
self._next_id_fn = next_id_fn or _CounterIdFn(base_name)
self._next_id_fn = next_id_fn or _CounterIdFn(self._base_name)
@property
def managed_files(self):
......
......@@ -122,6 +122,26 @@ class ExportSavedModelTest(tf.test.TestCase):
manager.managed_files,
[f'{base_name}-10', f'{base_name}-50', f'{base_name}-1000'])
def test_export_file_manager_managed_files_double_slash(self):
directory = self.create_tempdir('foo//bar')
directory.create_file('basename-5')
directory.create_file('basename-10')
directory.create_file('basename-50')
directory.create_file('basename-1000')
directory.create_file('basename-9')
directory.create_file('basename-10-suffix')
base_name = os.path.join(directory.full_path, 'basename')
expected_base_name = os.path.normpath(base_name)
self.assertNotEqual(base_name, expected_base_name)
manager = actions.ExportFileManager(base_name, max_to_keep=3)
self.assertLen(manager.managed_files, 5)
self.assertEqual(manager.next_name(), f'{expected_base_name}-1001')
manager.clean_up()
self.assertEqual(manager.managed_files, [
f'{expected_base_name}-10', f'{expected_base_name}-50',
f'{expected_base_name}-1000'
])
def test_export_saved_model(self):
directory = self.create_tempdir()
base_name = os.path.join(directory.full_path, 'basename')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment