Commit 1d610ef9 authored by Taylor Robie, committed by A. Unique TensorFlower

Switch wide-deep movielens to use a threadpool rather than a forkpool.

PiperOrigin-RevId: 265114251
parent 0d2c2e01
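
For context on the change below: multiprocessing.dummy.Pool exposes the same Pool interface as multiprocessing.Pool but runs its workers as threads in the current process instead of forked child processes, which is why the swap in write_to_buffer is a one-line change. A minimal sketch (not part of this commit; the function name square is purely illustrative):

import multiprocessing         # process-based pool (the "forkpool" in the commit message)
import multiprocessing.dummy   # thread-based pool with the same Pool API

def square(x):
    return x * x

if __name__ == "__main__":
    with multiprocessing.Pool(4) as processes:
        print(processes.map(square, range(8)))   # work runs in child worker processes

    with multiprocessing.dummy.Pool(4) as threads:
        print(threads.map(square, range(8)))     # work runs in threads; arguments are not pickled
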
......@@ -20,6 +20,7 @@ from __future__ import print_function
import atexit
import multiprocessing
+import multiprocessing.dummy
import os
import tempfile
import uuid
......@@ -78,8 +79,8 @@ def iter_shard_dataframe(df, rows_per_core=1000):
It yields a list of dataframes with length equal to the number of CPU cores,
with each dataframe having rows_per_core rows. (Except for the last batch
which may have fewer rows in the dataframes.) Passing vectorized inputs to
-a multiprocessing pool is much more effecient than iterating through a
-dataframe in serial and passing a list of inputs to the pool.
+a pool is more efficient than iterating through a dataframe in serial and
+passing a list of inputs to the pool.
Args:
df: Pandas dataframe to be sharded.
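
The docstring above describes sharding a dataframe into one contiguous block per CPU core so that each worker receives a vectorized input. A rough, hypothetical reimplementation of the idea (the helper name iter_shard_frames and the toy dataframe are illustrative, not the repo's code):

import multiprocessing
import multiprocessing.dummy

import numpy as np
import pandas as pd

def iter_shard_frames(df, rows_per_core=1000):
    """Yields lists of per-core dataframe shards (illustrative sketch)."""
    num_cores = multiprocessing.cpu_count()
    rows_per_batch = rows_per_core * num_cores
    for start in range(0, len(df), rows_per_batch):
        batch = df.iloc[start:start + rows_per_batch]
        # Split the batch into roughly equal, contiguous shards, one per core.
        bounds = np.linspace(0, len(batch), num_cores + 1, dtype=int)
        yield [batch.iloc[bounds[i]:bounds[i + 1]] for i in range(num_cores)]

df = pd.DataFrame({"x": np.arange(5000)})
pool = multiprocessing.dummy.Pool(multiprocessing.cpu_count())
for shards in iter_shard_frames(df, rows_per_core=500):
    # Each worker receives a whole shard (vectorized input) rather than single rows.
    partial_sums = pool.map(lambda shard: int(shard["x"].sum()), shards)
pool.close()
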
......@@ -134,7 +135,7 @@ def _serialize_shards(df_shards, columns, pool, writer):
Args:
df_shards: A list of pandas dataframes. (Should be of similar size)
columns: The dataframe columns to be serialized.
-pool: A multiprocessing pool to serialize in parallel.
+pool: A pool to serialize in parallel.
writer: A TFRecordWriter to write the serialized shards.
"""
# Pandas does not store columns of arrays as nd arrays. stack remedies this.
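
The comment about stacking refers to how pandas stores array-valued columns; a small illustration of the point (the column name genres and the toy data are assumptions, not taken from the diff):

import numpy as np
import pandas as pd

# A column whose cells are arrays is stored as a plain object column, not a 2-D array.
df = pd.DataFrame({"genres": [np.array([1, 0, 1]), np.array([0, 1, 1])]})
print(df["genres"].dtype)          # object

# np.stack rebuilds a contiguous (num_rows, 3) ndarray, which is what a
# vectorized serializer wants to receive.
genres_matrix = np.stack(df["genres"].values)
print(genres_matrix.shape)         # (2, 3)
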
......@@ -190,7 +191,7 @@ def write_to_buffer(dataframe, buffer_path, columns, expected_size=None):
.format(buffer_path))
count = 0
-pool = multiprocessing.Pool(multiprocessing.cpu_count())
+pool = multiprocessing.dummy.Pool(multiprocessing.cpu_count())
try:
with tf.io.TFRecordWriter(buffer_path) as writer:
for df_shards in iter_shard_dataframe(df=dataframe,
......
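
Putting the pieces together, the pattern this file ends up with after the commit is a thread pool sized to the CPU count feeding a tf.io.TFRecordWriter. A self-contained sketch of that pattern (the _serialize_shard body, the feature name item, and the buffer path are placeholders, not the repo's code):

import atexit
import multiprocessing
import multiprocessing.dummy
import os
import tempfile

import numpy as np
import tensorflow as tf

def _serialize_shard(values):
    # Stand-in for the real per-shard serialization into tf.train.Example bytes.
    return [
        tf.train.Example(features=tf.train.Features(feature={
            "item": tf.train.Feature(int64_list=tf.train.Int64List(value=[int(v)])),
        })).SerializeToString()
        for v in values
    ]

num_cores = multiprocessing.cpu_count()
pool = multiprocessing.dummy.Pool(num_cores)   # thread pool: same API, no forked workers
atexit.register(pool.close)                    # clean up the pool at exit (the module imports atexit for similar cleanup)

buffer_path = os.path.join(tempfile.gettempdir(), "example_buffer.tfrecords")
with tf.io.TFRecordWriter(buffer_path) as writer:
    shards = np.array_split(np.arange(64), num_cores)   # one vectorized shard per worker
    for shard_records in pool.map(_serialize_shard, shards):
        for record in shard_records:
            writer.write(record)
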
......@@ -27,7 +27,7 @@ import pandas as pd
import tensorflow as tf
# pylint: enable=wrong-import-order
-from official.utils.data import file_io
+from official.r1.utils.data import file_io
from official.utils.misc import keras_utils
......
......@@ -29,7 +29,7 @@ import tensorflow as tf
# pylint: enable=wrong-import-order
from official.datasets import movielens
-from official.utils.data import file_io
+from official.r1.utils.data import file_io
from official.utils.flags import core as flags_core
......