Unverified Commit 71cf7865 authored by Jinjing Zhou, committed by GitHub

Revert "[Distributed] DistDataloader (#1870)" (#1876)

This reverts commit 6557291f.
parent 6557291f
@@ -9,10 +9,9 @@ from .sparse_emb import SparseAdagrad, DistEmbedding
from .rpc import *
from .rpc_server import start_server
from .rpc_client import connect_to_server, exit_client, init_rpc
from .rpc_client import connect_to_server, exit_client
from .kvstore import KVServer, KVClient
from .server_state import ServerState
from .dist_dataloader import DistDataLoader
from .graph_services import sample_neighbors, in_subgraph
if os.environ.get('DGL_ROLE', 'client') == 'server':
......
# pylint: disable=global-variable-undefined, invalid-name
"""Multiprocess dataloader for distributed training"""
import multiprocessing as mp
import time
import traceback
from . import exit_client
from .rpc_client import get_sampler_pool
from .. import backend as F
__all__ = ["DistDataLoader"]
def call_collate_fn(next_data):
"""Call collate function"""
try:
result = DGL_GLOBAL_COLLATE_FN(next_data)
DGL_GLOBAL_MP_QUEUE.put(result)
except Exception as e:
traceback.print_exc()
print(e)
raise e
return 1
def init_fn(collate_fn, queue, sig_queue):
"""Initialize setting collate function and mp.Queue in the subprocess"""
global DGL_GLOBAL_COLLATE_FN
global DGL_GLOBAL_MP_QUEUE
global DGL_SIG_QUEUE
DGL_SIG_QUEUE = sig_queue
DGL_GLOBAL_MP_QUEUE = queue
DGL_GLOBAL_COLLATE_FN = collate_fn
time.sleep(1)
return 1
def _exit():
exit_client()
time.sleep(1)
class DistDataLoader:
"""DGL customized multiprocessing dataloader"""
def __init__(self, dataset, batch_size, shuffle=False,
num_workers=1, collate_fn=None, drop_last=False,
queue_size=None):
"""
dataset (Dataset): dataset from which to load the data.
batch_size (int, optional): how many samples per batch to load
(default: ``1``).
num_workers (int, optional): how many subprocesses to use for data
loading. ``0`` means that the data will be loaded in the main process.
(default: ``0``)
collate_fn (callable, optional): merges a list of samples to form a
mini-batch of Tensor(s). Used when using batched loading from a
map-style dataset.
drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
if the dataset size is not divisible by the batch size. If ``False`` and
the size of dataset is not divisible by the batch size, then the last batch
will be smaller. (default: ``False``)
queue_size (int): Size of multiprocessing queue
"""
assert num_workers > 0
if queue_size is None:
queue_size = num_workers * 4
        self.queue_size = queue_size
        self.batch_size = batch_size
self.collate_fn = collate_fn
self.current_pos = 0
self.num_workers = num_workers
self.m = mp.Manager()
self.queue = self.m.Queue(maxsize=queue_size)
self.sig_queue = self.m.Queue(maxsize=num_workers)
self.drop_last = drop_last
self.send_idxs = 0
self.recv_idxs = 0
self.started = False
self.shuffle = shuffle
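        # Reuse the sampler process pool created by init_rpc(), if one exists;
        # otherwise spawn a private pool with num_workers processes.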
self.pool, num_sampler_workers = get_sampler_pool()
if self.pool is None:
ctx = mp.get_context("spawn")
self.pool = ctx.Pool(num_workers)
else:
assert num_sampler_workers == num_workers, "Num workers should be the same"
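        # Run init_fn once per worker: each call holds its worker with a short sleep,
        # so the num_workers apply_async calls below tend to land on distinct worker
        # processes and every worker stores the collate_fn and the queues as globals.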
results = []
for _ in range(num_workers):
results.append(self.pool.apply_async(
init_fn, args=(collate_fn, self.queue, self.sig_queue)))
time.sleep(0.1)
for res in results:
res.get()
self.dataset = F.tensor(dataset)
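        # Number of batches per epoch; round up unless drop_last drops the tail batch.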
self.expected_idxs = len(dataset) // self.batch_size
if not self.drop_last and len(dataset) % self.batch_size != 0:
self.expected_idxs += 1
def __next__(self):
        if not self.started:
            # Fill the pipeline with up to queue_size outstanding batch requests.
            self.started = True
            for _ in range(self.queue_size):
                self._request_next_batch()
        self._request_next_batch()
if self.recv_idxs < self.expected_idxs:
result = self.queue.get(timeout=9999)
self.recv_idxs += 1
return result
else:
            self.recv_idxs = 0
            self.current_pos = 0
            self.started = False  # re-arm the prefetch for the next epoch
raise StopIteration
def __iter__(self):
if self.shuffle:
self.dataset = F.rand_shuffle(self.dataset)
return self
def _request_next_batch(self):
next_data = self._next_data()
if next_data is None:
return None
else:
async_result = self.pool.apply_async(
call_collate_fn, args=(next_data, ))
self.send_idxs += 1
return async_result
def _next_data(self):
if self.current_pos == len(self.dataset):
return None
end_pos = 0
if self.current_pos + self.batch_size > len(self.dataset):
if self.drop_last:
return None
else:
end_pos = len(self.dataset)
else:
end_pos = self.current_pos + self.batch_size
ret = self.dataset[self.current_pos:end_pos]
self.current_pos = end_pos
return ret
def close(self):
"""Finalize the connection with server and close pool"""
for _ in range(self.num_workers):
self.pool.apply_async(_exit)
time.sleep(0.1)
self.pool.close()
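
For reference, the reverted loader is exercised by the test_dist_dataloader unit test removed further down in this diff. Condensed from that test, typical usage looks as follows; it assumes the distributed setup the test creates (running DistGraphServer processes, dgl.distributed.init_rpc, a DistGraph called dist_graph, and the NeighborSampler helper), so it is a usage sketch rather than a standalone script.

# Condensed from the removed test below; dist_graph and NeighborSampler come from
# the setup in that test.
import torch as th
import dgl

train_nid = th.arange(202)
sampler = NeighborSampler(dist_graph, [5, 10], dgl.distributed.sample_neighbors)
dataloader = DistDataLoader(
    dataset=train_nid.numpy(),          # seed node IDs to batch over
    batch_size=32,
    collate_fn=sampler.sample_blocks,   # runs in the sampler worker processes
    shuffle=True,
    drop_last=False,
    num_workers=4)
for epoch in range(3):
    for idx, blocks in enumerate(dataloader):
        pass                            # train on the sampled blocks here
dataloader.close()                      # shut down the worker pool
dgl.distributed.exit_client()
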
@@ -17,6 +17,7 @@ from .graph_partition_book import PartitionPolicy, get_shared_mem_partition_book
from .graph_partition_book import NODE_PART_POLICY, EDGE_PART_POLICY
from .shared_mem_utils import _to_shared_mem, _get_ndata_path, _get_edata_path, DTYPE_DICT
from . import rpc
from .rpc_client import connect_to_server
from .server_state import ServerState
from .rpc_server import start_server
from .dist_tensor import DistTensor, _get_data_name
@@ -295,9 +296,6 @@ class DistGraph:
The partition config file. It's used in the standalone mode.
'''
def __init__(self, ip_config, graph_name, gpb=None, conf_file=None):
self.ip_config = ip_config
self.graph_name = graph_name
self._gpb_input = gpb
if os.environ.get('DGL_DIST_MODE', 'standalone') == 'standalone':
assert conf_file is not None, \
                'When running in the standalone mode, the partition config file is required'
@@ -315,19 +313,7 @@ class DistGraph:
self._client.add_data(_get_data_name(name, EDGE_PART_POLICY), edge_feats[name])
rpc.set_num_client(1)
else:
self._init()
self._ndata = NodeDataView(self)
self._edata = EdgeDataView(self)
self._num_nodes = 0
self._num_edges = 0
for part_md in self._gpb.metadata():
self._num_nodes += int(part_md['num_nodes'])
self._num_edges += int(part_md['num_edges'])
def _init(self):
ip_config, graph_name, gpb = self.ip_config, self.graph_name, self._gpb_input
connect_to_server(ip_config=ip_config)
self._client = KVClient(ip_config)
g = _get_graph_from_shared_mem(graph_name)
if g is not None:
@@ -337,17 +323,12 @@ class DistGraph:
self._gpb = get_shared_mem_partition_book(graph_name, self._g)
if self._gpb is None:
self._gpb = gpb
self._client.barrier()
self._client.map_shared_data(self._gpb)
def __getstate__(self):
return self.ip_config, self.graph_name, self._gpb_input
def __setstate__(self, state):
self.ip_config, self.graph_name, self._gpb_input = state
self._init()
self._ndata = NodeDataView(self)
self._edata = EdgeDataView(self)
self._num_nodes = 0
self._num_edges = 0
for part_md in self._gpb.metadata():
......
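
The __getstate__/__setstate__ pair being removed above is what lets a DistGraph be pickled into sampler worker processes: only the constructor arguments are shipped and _init() re-establishes the connection on arrival. A minimal standalone sketch of that pattern (toy LazyClient class and address, not DGL code):

import pickle

class LazyClient:
    """Toy illustration: pickle only constructor arguments, reconnect on unpickle."""
    def __init__(self, address):
        self.address = address
        self._connect()

    def _connect(self):
        # Stand-in for the connect_to_server()/KVClient setup in DistGraph._init().
        self.connection = "connected to " + self.address

    def __getstate__(self):
        return self.address              # drop the unpicklable connection state

    def __setstate__(self, state):
        self.address = state
        self._connect()                  # reconnect inside the receiving process

c = pickle.loads(pickle.dumps(LazyClient("127.0.0.1:30050")))
print(c.connection)                      # -> connected to 127.0.0.1:30050
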
@@ -132,7 +132,8 @@ def merge_graphs(res_list, num_nodes):
src_tensor = res_list[0].global_src
dst_tensor = res_list[0].global_dst
eid_tensor = res_list[0].global_eids
g = graph((src_tensor, dst_tensor), num_nodes=num_nodes)
g = graph((src_tensor, dst_tensor),
restrict_format='coo', num_nodes=num_nodes)
g.edata[EID] = eid_tensor
return g
......
@@ -1001,6 +1001,7 @@ class KVClient(object):
Store the partition information
"""
# Get shared data from server side
self.barrier()
request = GetSharedDataRequest(GET_SHARED_MSG)
rpc.send_request(self._main_server_id, request)
response = rpc.recv_response()
@@ -1042,6 +1043,7 @@ class KVClient(object):
response = rpc.recv_response()
assert response.msg == SEND_META_TO_BACKUP_MSG
self._data_name_list.add(name)
self.barrier()
def data_name_list(self):
"""Get all the data name"""
......
@@ -1024,6 +1024,7 @@ class ShutDownRequest(Request):
self.client_id = state
def process_request(self, server_state):
assert self.client_id == 0
finalize_server()
return 'exit'
......
@@ -2,7 +2,6 @@
import os
import socket
import multiprocessing as mp
import atexit
from . import rpc
@@ -181,14 +180,12 @@ def finalize_client():
"""Release resources of this client."""
rpc.finalize_sender()
rpc.finalize_receiver()
if SAMPLER_POOL is not None:
SAMPLER_POOL.close()
SAMPLER_POOL.join()
global INITIALIZED
INITIALIZED = False
def shutdown_servers():
"""Issue commands to remote servers to shut them down.
Raises
------
    ConnectionError : If anything goes wrong with the connection.
@@ -198,31 +195,6 @@ def shutdown_servers():
for server_id in range(rpc.get_num_server()):
rpc.send_request(server_id, req)
SAMPLER_POOL = None
NUM_SAMPLER_WORKERS = 0
def _close():
"""Finalize client and close servers when finished"""
rpc.finalize_sender()
rpc.finalize_receiver()
def _init_rpc(ip_config, max_queue_size, net_type):
connect_to_server(ip_config, max_queue_size, net_type)
def get_sampler_pool():
"""Return the sampler pool and num_workers"""
return SAMPLER_POOL, NUM_SAMPLER_WORKERS
def init_rpc(ip_config, num_workers, max_queue_size=MAX_QUEUE_SIZE, net_type='socket'):
"""Init rpc service"""
ctx = mp.get_context("spawn")
global SAMPLER_POOL
global NUM_SAMPLER_WORKERS
SAMPLER_POOL = ctx.Pool(
num_workers, initializer=_init_rpc, initargs=(ip_config, max_queue_size, net_type))
NUM_SAMPLER_WORKERS = num_workers
connect_to_server(ip_config, max_queue_size, net_type)
def exit_client():
"""Register exit callback.
"""
......
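
init_rpc, removed above, builds the sampler pool with _init_rpc as the Pool initializer so that every sampler worker opens its own RPC connection before any task runs. A minimal standalone sketch of that initializer pattern (toy _connect/_do_work functions and address, not DGL code):

import multiprocessing as mp

def _connect(address):
    # Runs once per worker at pool startup, like _init_rpc above; keeps the
    # per-worker "connection" in a module-level global for all later tasks.
    global CONNECTION
    CONNECTION = "connected to " + address

def _do_work(x):
    return CONNECTION, x * x             # every task can use the per-worker connection

if __name__ == "__main__":
    ctx = mp.get_context("spawn")
    with ctx.Pool(2, initializer=_connect, initargs=("127.0.0.1:30050",)) as pool:
        print(pool.map(_do_work, range(4)))
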
import dgl
import unittest
import os
from dgl.data import CitationGraphDataset
from dgl.distributed import sample_neighbors
from dgl.distributed import partition_graph, load_partition, load_partition_book
import sys
import multiprocessing as mp
import numpy as np
import time
from utils import get_local_usable_addr
from pathlib import Path
from dgl.distributed import DistGraphServer, DistGraph, DistDataLoader
import pytest
class NeighborSampler(object):
def __init__(self, g, fanouts, sample_neighbors):
self.g = g
self.fanouts = fanouts
self.sample_neighbors = sample_neighbors
def sample_blocks(self, seeds):
import torch as th
seeds = th.LongTensor(np.asarray(seeds))
blocks = []
for fanout in self.fanouts:
# For each seed node, sample ``fanout`` neighbors.
frontier = self.sample_neighbors(self.g, seeds, fanout, replace=True)
# Then we compact the frontier into a bipartite graph for message passing.
block = dgl.to_block(frontier, seeds)
# Obtain the seed nodes for next layer.
seeds = block.srcdata[dgl.NID]
blocks.insert(0, block)
return blocks
def start_server(rank, tmpdir, disable_shared_mem, num_clients):
import dgl
print('server: #clients=' + str(num_clients))
g = DistGraphServer(rank, "mp_ip_config.txt", num_clients,
tmpdir / 'test_sampling.json', disable_shared_mem=disable_shared_mem)
g.start()
def start_client(rank, tmpdir, disable_shared_mem, num_workers):
import dgl
import torch as th
os.environ['DGL_DIST_MODE'] = 'distributed'
dgl.distributed.init_rpc("mp_ip_config.txt", num_workers=4)
gpb = None
if disable_shared_mem:
_, _, _, gpb, _ = load_partition(tmpdir / 'test_sampling.json', rank)
train_nid = th.arange(202)
dist_graph = DistGraph("mp_ip_config.txt", "test_mp", gpb=gpb)
# Create sampler
sampler = NeighborSampler(dist_graph, [5, 10],
dgl.distributed.sample_neighbors)
# Create PyTorch DataLoader for constructing blocks
dataloader = DistDataLoader(
dataset=train_nid.numpy(),
batch_size=32,
collate_fn=sampler.sample_blocks,
shuffle=True,
drop_last=False,
num_workers=4)
dist_graph._init()
for epoch in range(3):
for idx, blocks in enumerate(dataloader):
print(blocks)
print(blocks[1].edges())
print(idx)
dataloader.close()
dgl.distributed.exit_client()
def main(tmpdir, num_server):
ip_config = open("mp_ip_config.txt", "w")
for _ in range(num_server):
ip_config.write('{} 1\n'.format(get_local_usable_addr()))
ip_config.close()
g = CitationGraphDataset("cora")[0]
g.readonly()
print(g.idtype)
num_parts = num_server
num_hops = 1
partition_graph(g, 'test_sampling', num_parts, tmpdir,
num_hops=num_hops, part_method='metis', reshuffle=False)
num_workers = 4
pserver_list = []
ctx = mp.get_context('spawn')
for i in range(num_server):
p = ctx.Process(target=start_server, args=(
i, tmpdir, num_server > 1, num_workers+1))
p.start()
time.sleep(1)
pserver_list.append(p)
time.sleep(3)
sampled_graph = start_client(0, tmpdir, num_server > 1, num_workers)
for p in pserver_list:
p.join()
# Wait for the non-shared-memory graph store
@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
@unittest.skipIf(dgl.backend.backend_name == 'tensorflow', reason='Not support tensorflow for now')
def test_dist_dataloader(tmpdir):
main(Path(tmpdir), 3)
if __name__ == "__main__":
import tempfile
with tempfile.TemporaryDirectory() as tmpdirname:
main(Path(tmpdirname), 3)