"torchvision/git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "2d55b9d52d7f3441c3dbe82d4af3c41977ea3ff7"
Commit f1efaf83 authored by Taylor Robie's avatar Taylor Robie
Browse files

don't use forkpool to shuffle with TPUs

parent c8be4828
...@@ -378,6 +378,7 @@ class BaseDataConstructor(threading.Thread): ...@@ -378,6 +378,7 @@ class BaseDataConstructor(threading.Thread):
self._current_epoch_order = np.empty(shape=(0,)) self._current_epoch_order = np.empty(shape=(0,))
self._shuffle_iterator = None self._shuffle_iterator = None
self._shuffle_with_forkpool = stream_files
if stream_files: if stream_files:
self._shard_root = tempfile.mkdtemp(prefix="ncf_") self._shard_root = tempfile.mkdtemp(prefix="ncf_")
atexit.register(tf.gfile.DeleteRecursively, dirname=self._shard_root) atexit.register(tf.gfile.DeleteRecursively, dirname=self._shard_root)
...@@ -449,7 +450,10 @@ class BaseDataConstructor(threading.Thread): ...@@ -449,7 +450,10 @@ class BaseDataConstructor(threading.Thread):
raise raise
def _start_shuffle_iterator(self): def _start_shuffle_iterator(self):
pool = popen_helper.get_forkpool(3, closing=False) if self._shuffle_with_forkpool:
pool = popen_helper.get_forkpool(3, closing=False)
else:
pool = popen_helper.get_threadpool(1, closing=False)
atexit.register(pool.close) atexit.register(pool.close)
args = [(self._elements_in_epoch, stat_utils.random_int32()) args = [(self._elements_in_epoch, stat_utils.random_int32())
for _ in range(self._maximum_number_epochs)] for _ in range(self._maximum_number_epochs)]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment