Commit 1d610ef9 authored by Taylor Robie's avatar Taylor Robie Committed by A. Unique TensorFlower
Browse files

Switch wide-deep movielens to use a threadpool rather than a forkpool.

PiperOrigin-RevId: 265114251
parent 0d2c2e01
...@@ -20,6 +20,7 @@ from __future__ import print_function ...@@ -20,6 +20,7 @@ from __future__ import print_function
import atexit import atexit
import multiprocessing import multiprocessing
import multiprocessing.dummy
import os import os
import tempfile import tempfile
import uuid import uuid
...@@ -78,8 +79,8 @@ def iter_shard_dataframe(df, rows_per_core=1000): ...@@ -78,8 +79,8 @@ def iter_shard_dataframe(df, rows_per_core=1000):
It yields a list of dataframes with length equal to the number of CPU cores, It yields a list of dataframes with length equal to the number of CPU cores,
with each dataframe having rows_per_core rows. (Except for the last batch with each dataframe having rows_per_core rows. (Except for the last batch
which may have fewer rows in the dataframes.) Passing vectorized inputs to which may have fewer rows in the dataframes.) Passing vectorized inputs to
a multiprocessing pool is much more efficient than iterating through a a pool is more efficient than iterating through a dataframe in serial and
dataframe in serial and passing a list of inputs to the pool. passing a list of inputs to the pool.
Args: Args:
df: Pandas dataframe to be sharded. df: Pandas dataframe to be sharded.
...@@ -134,7 +135,7 @@ def _serialize_shards(df_shards, columns, pool, writer): ...@@ -134,7 +135,7 @@ def _serialize_shards(df_shards, columns, pool, writer):
Args: Args:
df_shards: A list of pandas dataframes. (Should be of similar size) df_shards: A list of pandas dataframes. (Should be of similar size)
columns: The dataframe columns to be serialized. columns: The dataframe columns to be serialized.
pool: A multiprocessing pool to serialize in parallel. pool: A pool to serialize in parallel.
writer: A TFRecordWriter to write the serialized shards. writer: A TFRecordWriter to write the serialized shards.
""" """
# Pandas does not store columns of arrays as nd arrays. stack remedies this. # Pandas does not store columns of arrays as nd arrays. stack remedies this.
...@@ -190,7 +191,7 @@ def write_to_buffer(dataframe, buffer_path, columns, expected_size=None): ...@@ -190,7 +191,7 @@ def write_to_buffer(dataframe, buffer_path, columns, expected_size=None):
.format(buffer_path)) .format(buffer_path))
count = 0 count = 0
pool = multiprocessing.Pool(multiprocessing.cpu_count()) pool = multiprocessing.dummy.Pool(multiprocessing.cpu_count())
try: try:
with tf.io.TFRecordWriter(buffer_path) as writer: with tf.io.TFRecordWriter(buffer_path) as writer:
for df_shards in iter_shard_dataframe(df=dataframe, for df_shards in iter_shard_dataframe(df=dataframe,
......
...@@ -27,7 +27,7 @@ import pandas as pd ...@@ -27,7 +27,7 @@ import pandas as pd
import tensorflow as tf import tensorflow as tf
# pylint: enable=wrong-import-order # pylint: enable=wrong-import-order
from official.utils.data import file_io from official.r1.utils.data import file_io
from official.utils.misc import keras_utils from official.utils.misc import keras_utils
......
...@@ -29,7 +29,7 @@ import tensorflow as tf ...@@ -29,7 +29,7 @@ import tensorflow as tf
# pylint: enable=wrong-import-order # pylint: enable=wrong-import-order
from official.datasets import movielens from official.datasets import movielens
from official.utils.data import file_io from official.r1.utils.data import file_io
from official.utils.flags import core as flags_core from official.utils.flags import core as flags_core
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment