Unverified Commit 4f499c7f authored by Jinjing Zhou, committed by GitHub

[Distributed] DistDataloader (#1901)



* 111

* 111

* fix

* 111

* fix

* 11

* fix

* lint

* Update __init__.py

* lint

* fix

* lint

* fix

* fix

* fix

* fix

* fix

* try fix

* try fix

* fix

* Revert "fix"

This reverts commit a0b954fd4e99b7df92b53db8334dcb583d6e1551.

* fixes.

* fix.

* fix test.

* fix exit.

* fix.

* fix

* fix

* lint

* lint

* lint

* fix

* Update .gitignore

* 111

* fix

* 111

* 111

* fff

* 1111

* 111

* 1325315

* ffff

* f???

* fff

* 1111

* 111

* fix

* 111

* asda

* 1111

* 11

* 123

* ahhhhhhhh

* spawn

* 1231231

* up

* 111

* fix

* fix

* Revert "fix"

This reverts commit 7373f95312fdcaa36d2fc330bf242339e89c045d.

* fix

* fix

* 1111

* fix

* fix tests

* start kvclient as early as possible.

* lint

* fix test

* lint

* 1111

* fix

* fix

* 111

* fix

* fix

* 1

* fix

* fix

* lint

* fix

* lint

* lint

* remove quit

* fix

* lint

* fix

* fix several

* lint

* fix minor

* fix

* lint
Co-authored-by: Da Zheng <zhengda1936@gmail.com>
Co-authored-by: xiang song (charlie.song) <classicxsong@gmail.com>
parent c801a164
@@ -12,6 +12,7 @@ from dgl.data import register_data_args, load_data
 from dgl.data.utils import load_graphs
 import dgl.function as fn
 import dgl.nn.pytorch as dglnn
+from dgl.distributed import DistDataLoader
 import torch as th
 import torch.nn as nn

@@ -91,7 +92,7 @@ class DistSAGE(nn.Module):
     sampler = NeighborSampler(g, [-1], dgl.distributed.sample_neighbors)
     print('|V|={}, eval batch size: {}'.format(g.number_of_nodes(), batch_size))
     # Create PyTorch DataLoader for constructing blocks
-    dataloader = DataLoader(
+    dataloader = DistDataLoader(
         dataset=nodes,
         batch_size=batch_size,
         collate_fn=sampler.sample_blocks,

@@ -154,14 +155,13 @@ def run(args, device, data):
     sampler = NeighborSampler(g, [int(fanout) for fanout in args.fan_out.split(',')],
                               dgl.distributed.sample_neighbors)
-    # Create PyTorch DataLoader for constructing blocks
-    dataloader = DataLoader(
+    # Create DataLoader for constructing blocks
+    dataloader = DistDataLoader(
         dataset=train_nid.numpy(),
         batch_size=args.batch_size,
         collate_fn=sampler.sample_blocks,
         shuffle=True,
-        drop_last=False,
-        num_workers=args.num_workers)
+        drop_last=False)
     # Define model and optimizer
     model = DistSAGE(in_feats, args.num_hidden, n_classes, args.num_layers, F.relu, args.dropout)

@@ -258,13 +258,12 @@ def run(args, device, data):
         profiler.stop()
         print(profiler.output_text(unicode=True, color=True))
-    # clean up
-    if not args.standalone:
-        g._client.barrier()

 def main(args):
     if not args.standalone:
         th.distributed.init_process_group(backend='gloo')
+    dgl.distributed.initialize(args.ip_config, num_workers=args.num_workers)
     g = dgl.distributed.DistGraph(args.ip_config, args.graph_name, part_config=args.conf_path)
     print('rank:', g.rank())

@@ -309,7 +308,7 @@ if __name__ == '__main__':
     parser.add_argument('--eval-every', type=int, default=5)
     parser.add_argument('--lr', type=float, default=0.003)
     parser.add_argument('--dropout', type=float, default=0.5)
-    parser.add_argument('--num-workers', type=int, default=0,
+    parser.add_argument('--num-workers', type=int, default=4,
                         help="Number of sampling processes. Use 0 for no extra process.")
     parser.add_argument('--local_rank', type=int, help='get rank of the process')
     parser.add_argument('--standalone', action='store_true', help='run in the standalone mode')
...
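Taken together, the example change boils down to a two-line pattern: call dgl.distributed.initialize() before constructing the DistGraph, and swap the PyTorch DataLoader for DistDataLoader. A minimal sketch (not part of this PR; the ip config path, graph name and collate function are placeholders):

# Minimal sketch of the new pattern. 'ip_config.txt', 'graph_name' and
# sample_blocks are placeholders, not code from this PR.
import dgl
import torch as th
from dgl.distributed import DistDataLoader

def sample_blocks(seeds):
    # A real collate_fn samples neighbors and builds blocks, as the
    # NeighborSampler in the example above does.
    return seeds

dgl.distributed.initialize('ip_config.txt', num_workers=4)   # spawns the sampler workers
g = dgl.distributed.DistGraph('ip_config.txt', 'graph_name')
train_nid = th.arange(g.number_of_nodes())

dataloader = DistDataLoader(dataset=train_nid.numpy(),
                            batch_size=1000,
                            collate_fn=sample_blocks,
                            shuffle=True,
                            drop_last=False)
for blocks in dataloader:
    pass  # forward/backward on the sampled blocks goes here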
@@ -9,9 +9,11 @@ from .sparse_emb import SparseAdagrad, DistEmbedding
 from .rpc import *
 from .rpc_server import start_server
-from .rpc_client import connect_to_server, exit_client
+from .rpc_client import connect_to_server
+from .dist_context import initialize, exit_client
 from .kvstore import KVServer, KVClient
 from .server_state import ServerState
+from .dist_dataloader import DistDataLoader
 from .graph_services import sample_neighbors, in_subgraph, find_edges

 if os.environ.get('DGL_ROLE', 'client') == 'server':
...
"""Initialize the distributed services"""
import multiprocessing as mp
import traceback
import atexit
import time
from . import rpc
from .constants import MAX_QUEUE_SIZE
from .kvstore import init_kvstore, close_kvstore
from .rpc_client import connect_to_server, shutdown_servers
SAMPLER_POOL = None
NUM_SAMPLER_WORKERS = 0
INITIALIZED = False
def set_initialized(value=True):
"""Set the initialized state of rpc"""
global INITIALIZED
INITIALIZED = value
def get_sampler_pool():
"""Return the sampler pool and num_workers"""
return SAMPLER_POOL, NUM_SAMPLER_WORKERS
def _init_rpc(ip_config, max_queue_size, net_type, role):
''' This init function is called in the worker processes.
'''
try:
connect_to_server(ip_config, max_queue_size, net_type)
init_kvstore(ip_config, role)
except Exception as e:
print(e, flush=True)
traceback.print_exc()
raise e
def initialize(ip_config, num_workers=0, max_queue_size=MAX_QUEUE_SIZE, net_type='socket'):
"""Init rpc service
ip_config: str
File path of ip_config file
num_workers: int
Number of worker process to be created
max_queue_size : int
Maximal size (bytes) of client queue buffer (~20 GB on default).
Note that the 20 GB is just an upper-bound and DGL uses zero-copy and
it will not allocate 20GB memory at once.
net_type : str
Networking type. Current options are: 'socket'.
"""
rpc.reset()
ctx = mp.get_context("spawn")
global SAMPLER_POOL
global NUM_SAMPLER_WORKERS
if num_workers > 0:
SAMPLER_POOL = ctx.Pool(
num_workers, initializer=_init_rpc, initargs=(ip_config, max_queue_size,
net_type, 'sampler'))
NUM_SAMPLER_WORKERS = num_workers
connect_to_server(ip_config, max_queue_size, net_type)
init_kvstore(ip_config)
def finalize_client():
"""Release resources of this client."""
rpc.finalize_sender()
rpc.finalize_receiver()
global INITIALIZED
INITIALIZED = False
def _exit():
exit_client()
time.sleep(1)
def finalize_worker():
"""Finalize workers
Python's multiprocessing pool will not call atexit function when close
"""
if SAMPLER_POOL is not None:
for _ in range(NUM_SAMPLER_WORKERS):
SAMPLER_POOL.apply_async(_exit)
time.sleep(0.1) # This is necessary but I don't know why
SAMPLER_POOL.close()
def join_finalize_worker():
"""join the worker close process"""
if SAMPLER_POOL is not None:
SAMPLER_POOL.join()
def is_initialized():
"""Is RPC initialized?
"""
return INITIALIZED
def exit_client():
"""Register exit callback.
"""
# Only client with rank_0 will send shutdown request to servers.
finalize_worker() # finalize workers should be earilier than barrier, and non-blocking
rpc.client_barrier()
shutdown_servers()
finalize_client()
join_finalize_worker()
close_kvstore()
atexit.unregister(exit_client)
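The module above centralizes the client life-cycle that previously lived in rpc_client. A rough sketch of the intended call order, assuming 'ip_config.txt' is a placeholder path and the servers are already running:

# Sketch of the client life-cycle managed by dist_context ('ip_config.txt' is a placeholder).
import dgl
from dgl.distributed.dist_context import get_sampler_pool

dgl.distributed.initialize('ip_config.txt', num_workers=2)  # rpc.reset(), spawn the sampler pool,
                                                            # connect_to_server(), init_kvstore()
pool, num_workers = get_sampler_pool()                      # DistDataLoader picks this pool up later
# ... create DistGraph / DistDataLoader and train ...
dgl.distributed.exit_client()   # also registered via atexit inside connect_to_server()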
# pylint: disable=global-variable-undefined, invalid-name
"""Multiprocess dataloader for distributed training"""
import multiprocessing as mp
import time
import traceback

from .dist_context import get_sampler_pool
from .. import backend as F

__all__ = ["DistDataLoader"]

def call_collate_fn(next_data):
    """Call the collate function in a worker process and put the result into the shared queue."""
    try:
        result = DGL_GLOBAL_COLLATE_FN(next_data)
        DGL_GLOBAL_MP_QUEUE.put(result)
    except Exception as e:
        traceback.print_exc()
        print(e)
        raise e
    return 1

def init_fn(collate_fn, queue):
    """Set the collate function and the mp.Queue as globals in the subprocess."""
    global DGL_GLOBAL_COLLATE_FN
    global DGL_GLOBAL_MP_QUEUE
    DGL_GLOBAL_MP_QUEUE = queue
    DGL_GLOBAL_COLLATE_FN = collate_fn
    # The sleep here is to ensure this function is executed in all worker processes.
    # We probably need a better solution in the future.
    time.sleep(1)
    return 1

def enable_mp_debug():
    """Print multiprocessing debug information. This is only for debugging."""
    import logging
    logger = mp.log_to_stderr()
    logger.setLevel(logging.DEBUG)

class DistDataLoader:
    """DGL's customized multiprocessing dataloader, designed to be used with DistGraph."""

    def __init__(self, dataset, batch_size, shuffle=False, collate_fn=None, drop_last=False,
                 queue_size=None):
        """
        This class uses the worker processes created by dgl.distributed.initialize.

        Note that the iteration order is not guaranteed. For example, if
        dataset = [1, 2, 3, 4], batch_size = 2 and shuffle = False, the order of the
        batches [1, 2] and [3, 4] is not guaranteed.

        dataset (Dataset): dataset from which to load the data.
        batch_size (int, optional): how many samples per batch to load
            (default: ``1``).
        shuffle (bool, optional): set to ``True`` to have the data reshuffled
            at every epoch (default: ``False``).
        collate_fn (callable, optional): merges a list of samples to form a
            mini-batch of Tensor(s). Used when using batched loading from a
            map-style dataset.
        drop_last (bool, optional): set to ``True`` to drop the last incomplete batch
            if the dataset size is not divisible by the batch size. If ``False`` and
            the size of the dataset is not divisible by the batch size, the last batch
            will be smaller (default: ``False``).
        queue_size (int, optional): size of the multiprocessing queue
            (default: ``4 * num_workers``).
        """
        self.pool, num_workers = get_sampler_pool()
        assert num_workers > 0, "DistDataLoader only supports num_workers > 0 for now. " \
            "If you want a single-process dataloader, please use the PyTorch DataLoader instead."
        if queue_size is None:
            queue_size = num_workers * 4
        self.queue_size = queue_size
        self.batch_size = batch_size
        self.collate_fn = collate_fn
        self.current_pos = 0
        self.num_workers = num_workers
        self.m = mp.Manager()
        self.queue = self.m.Queue(maxsize=queue_size)
        self.drop_last = drop_last
        self.recv_idxs = 0
        self.started = False
        self.shuffle = shuffle
        self.is_closed = False

        if self.pool is None:
            ctx = mp.get_context("spawn")
            self.pool = ctx.Pool(num_workers)

        # Install the collate function and the shared queue in every worker process.
        results = []
        for _ in range(num_workers):
            results.append(self.pool.apply_async(
                init_fn, args=(collate_fn, self.queue)))
        for res in results:
            res.get()

        self.dataset = F.tensor(dataset)
        self.expected_idxs = len(dataset) // self.batch_size
        if not self.drop_last and len(dataset) % self.batch_size != 0:
            self.expected_idxs += 1

    def __next__(self):
        if not self.started:
            # Pre-fill the request queue before the first read.
            for _ in range(self.queue_size):
                self._request_next_batch()
            self.started = True
        self._request_next_batch()
        if self.recv_idxs < self.expected_idxs:
            result = self.queue.get(timeout=9999)
            self.recv_idxs += 1
            return result
        else:
            self.recv_idxs = 0
            self.current_pos = 0
            raise StopIteration

    def __iter__(self):
        if self.shuffle:
            self.dataset = F.rand_shuffle(self.dataset)
        return self

    def _request_next_batch(self):
        next_data = self._next_data()
        if next_data is None:
            return None
        async_result = self.pool.apply_async(
            call_collate_fn, args=(next_data, ))
        return async_result

    def _next_data(self):
        if self.current_pos == len(self.dataset):
            return None
        end_pos = 0
        if self.current_pos + self.batch_size > len(self.dataset):
            if self.drop_last:
                return None
            end_pos = len(self.dataset)
        else:
            end_pos = self.current_pos + self.batch_size
        ret = self.dataset[self.current_pos:end_pos]
        self.current_pos = end_pos
        return ret
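A usage sketch of the dataloader above, under the assumption that 'ip_config.txt' exists and that the collate function stands in for a real block sampler: each batch of ids is shipped to a sampler worker, which runs collate_fn and puts the result on the shared queue that __next__ drains.

# Usage sketch. 'ip_config.txt' and collate() are placeholders.
import numpy as np
import dgl
from dgl.distributed import DistDataLoader

def collate(ids):
    # Stand-in for sampler.sample_blocks in the example and tests.
    return ids

dgl.distributed.initialize('ip_config.txt', num_workers=2)
loader = DistDataLoader(dataset=np.arange(10), batch_size=4,
                        collate_fn=collate, shuffle=True, drop_last=False)
for batch in loader:   # three batches of sizes 4, 4 and 2 (order not guaranteed)
    print(batch)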
@@ -8,7 +8,7 @@ from ..heterograph import DGLHeteroGraph
 from .. import heterograph_index
 from .. import backend as F
 from ..base import NID, EID
-from .kvstore import KVServer, KVClient
+from .kvstore import KVServer, init_kvstore, get_kvstore
 from .standalone_kvstore import KVClient as SA_KVClient
 from .._ffi.ndarray import empty_shared_mem
 from ..frame import infer_scheme

@@ -17,7 +17,6 @@ from .graph_partition_book import PartitionPolicy, get_shared_mem_partition_book
 from .graph_partition_book import NODE_PART_POLICY, EDGE_PART_POLICY
 from .shared_mem_utils import _to_shared_mem, _get_ndata_path, _get_edata_path, DTYPE_DICT
 from . import rpc
-from .rpc_client import connect_to_server
 from .server_state import ServerState
 from .rpc_server import start_server
 from .graph_services import find_edges as dist_find_edges

@@ -323,6 +322,9 @@ class DistGraph:
         The partition config file. It's used in the standalone mode.
     '''
     def __init__(self, ip_config, graph_name, gpb=None, part_config=None):
+        self.ip_config = ip_config
+        self.graph_name = graph_name
+        self._gpb_input = gpb
         if os.environ.get('DGL_DIST_MODE', 'standalone') == 'standalone':
             assert part_config is not None, \
                 'When running in the standalone model, the partition config file is required'

@@ -340,20 +342,13 @@ class DistGraph:
                 self._client.add_data(_get_data_name(name, EDGE_PART_POLICY), edge_feats[name])
             rpc.set_num_client(1)
         else:
-            connect_to_server(ip_config=ip_config)
-            self._client = KVClient(ip_config)
-            self._g = _get_graph_from_shared_mem(graph_name)
-            self._gpb = get_shared_mem_partition_book(graph_name, self._g)
-            if self._gpb is None:
-                self._gpb = gpb
+            self._init()
             # Tell the backup servers to load the graph structure from shared memory.
             for server_id in range(self._client.num_servers):
                 rpc.send_request(server_id, InitGraphRequest(graph_name))
             for server_id in range(self._client.num_servers):
                 rpc.recv_response()
             self._client.barrier()
-            self._client.map_shared_data(self._gpb)

         self._ndata = NodeDataView(self)
         self._edata = EdgeDataView(self)

@@ -364,6 +359,32 @@ class DistGraph:
             self._num_nodes += int(part_md['num_nodes'])
             self._num_edges += int(part_md['num_edges'])

+    def _init(self):
+        # Init the KVStore client if it's not initialized yet.
+        init_kvstore(self.ip_config)
+        self._client = get_kvstore()
+        self._g = _get_graph_from_shared_mem(self.graph_name)
+        self._gpb = get_shared_mem_partition_book(self.graph_name, self._g)
+        if self._gpb is None:
+            self._gpb = self._gpb_input
+        self._client.map_shared_data(self._gpb)
+
+    def __getstate__(self):
+        return self.ip_config, self.graph_name, self._gpb_input
+
+    def __setstate__(self, state):
+        self.ip_config, self.graph_name, self._gpb_input = state
+        self._init()
+
+        self._ndata = NodeDataView(self)
+        self._edata = EdgeDataView(self)
+        self._num_nodes = 0
+        self._num_edges = 0
+        for part_md in self._gpb.metadata():
+            self._num_nodes += int(part_md['num_nodes'])
+            self._num_edges += int(part_md['num_edges'])
+
     @property
     def local_partition(self):
         ''' Return the local partition on the client
...
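The new __getstate__/__setstate__ pair is what lets a DistGraph cross process boundaries: only (ip_config, graph_name, gpb) are pickled, and the receiving process reconnects through _init(). A rough illustration under placeholder names:

# Rough sketch of DistGraph pickling; 'ip_config.txt' and 'graph_name' are placeholders.
import pickle
import dgl

g = dgl.distributed.DistGraph('ip_config.txt', 'graph_name')
blob = pickle.dumps(g)     # serializes just ip_config, graph_name and the gpb argument
g2 = pickle.loads(blob)    # __setstate__ -> _init(): reconnect kvstore, remap shared memory
assert g2.number_of_nodes() == g.number_of_nodes()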
@@ -3,7 +3,7 @@
 import os
 from .graph_partition_book import PartitionPolicy, NODE_PART_POLICY, EDGE_PART_POLICY
-from .rpc_client import is_initialized
+from .dist_context import is_initialized
 from ..base import DGLError
 from .. import utils
 from .. import backend as F
...
@@ -1287,3 +1287,20 @@ class KVClient(object):
         """Used by sort response list
         """
         return elem.server_id
+
+KVCLIENT = None
+
+def init_kvstore(ip_config, role='default'):
+    """Initialize the KVStore client singleton."""
+    global KVCLIENT
+    if KVCLIENT is None:
+        KVCLIENT = KVClient(ip_config, role)
+
+def close_kvstore():
+    """Close the current KVClient."""
+    global KVCLIENT
+    KVCLIENT = None
+
+def get_kvstore():
+    """Get the KVClient singleton."""
+    return KVCLIENT
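These helpers make the KVClient a per-process singleton, so DistGraph and the sampler workers reuse one connection instead of each constructing their own. A sketch of the intended behaviour, assuming a placeholder ip config path:

# Sketch of the per-process KVClient singleton added above ('ip_config.txt' is a placeholder).
from dgl.distributed.kvstore import init_kvstore, get_kvstore, close_kvstore

init_kvstore('ip_config.txt', role='sampler')   # creates the KVClient once
client = get_kvstore()
init_kvstore('ip_config.txt', role='sampler')   # no-op: KVCLIENT already exists
assert get_kvstore() is client
close_kvstore()                                 # drop the reference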
@@ -79,6 +79,12 @@ def read_ip_config(filename):
         print("Error: data format on each line should be: [ip] [base_port] [server_count]")
     return server_namebook

+def reset():
+    """Reset the rpc context
+    """
+    _CAPI_DGLRPCReset()
+
 def create_sender(max_queue_size, net_type):
     """Create rpc sender of this process.
...
@@ -96,7 +96,6 @@ def get_local_usable_addr():
     return ip_addr + ':' + str(port)

-INITIALIZED = False

 def connect_to_server(ip_config, max_queue_size=MAX_QUEUE_SIZE, net_type='socket'):
     """Connect this client to server.

@@ -175,16 +174,9 @@ def connect_to_server(ip_config, max_queue_size=MAX_QUEUE_SIZE, net_type='socket'):
     rpc.send_request(0, get_client_num_req)
     res = rpc.recv_response()
     rpc.set_num_client(res.num_client)
+    from .dist_context import exit_client, set_initialized
     atexit.register(exit_client)
-    global INITIALIZED
-    INITIALIZED = True
-
-def finalize_client():
-    """Release resources of this client."""
-    rpc.finalize_sender()
-    rpc.finalize_receiver()
-    global INITIALIZED
-    INITIALIZED = False
+    set_initialized()

 def shutdown_servers():
     """Issue commands to remote servers to shut them down.

@@ -197,17 +189,3 @@ def shutdown_servers():
     req = rpc.ShutDownRequest(rpc.get_rank())
     for server_id in range(rpc.get_num_server()):
         rpc.send_request(server_id, req)
-
-def exit_client():
-    """Register exit callback.
-    """
-    # Only client with rank_0 will send shutdown request to servers.
-    rpc.client_barrier()
-    shutdown_servers()
-    finalize_client()
-    atexit.unregister(exit_client)
-
-def is_initialized():
-    """Is RPC initialized?
-    """
-    return INITIALIZED
@@ -77,6 +77,10 @@ RPCStatus RecvRPCMessage(RPCMessage* msg, int32_t timeout) {
 }

 //////////////////////////// C APIs ////////////////////////////

+DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCReset")
+.set_body([] (DGLArgs args, DGLRetValue* rv) {
+  RPCContext::Reset();
+});
+
 DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCCreateSender")
 .set_body([] (DGLArgs args, DGLRetValue* rv) {
...
@@ -96,6 +96,19 @@ struct RPCContext {
   static RPCContext *ThreadLocal() {
     return dmlc::ThreadLocalStore<RPCContext>::Get();
   }
+
+  /*! \brief Reset the RPC context */
+  static void Reset() {
+    auto* t = ThreadLocal();
+    t->rank = -1;
+    t->machine_id = -1;
+    t->num_machines = 0;
+    t->num_clients = 0;
+    t->barrier_count = 0;
+    t->num_servers_per_machine = 0;
+    t->sender = std::shared_ptr<network::Sender>();
+    t->receiver = std::shared_ptr<network::Receiver>();
+  }
 };

 /*! \brief RPC message data structure
...
@@ -68,6 +68,7 @@ def rand_init(shape, dtype):
 def run_client(graph_name, part_id, num_clients, num_nodes, num_edges):
     time.sleep(5)
+    dgl.distributed.initialize("kv_ip_config.txt")
     gpb, graph_name = load_partition_book('/tmp/dist_graph/{}.json'.format(graph_name),
                                           part_id, None)
     g = DistGraph("kv_ip_config.txt", graph_name, gpb=gpb)
...
@@ -11,7 +11,7 @@ import backend as F
 import time
 from utils import get_local_usable_addr
 from pathlib import Path
+import pytest

 from dgl.distributed import DistGraphServer, DistGraph

@@ -25,6 +25,7 @@ def start_sample_client(rank, tmpdir, disable_shared_mem):
     gpb = None
     if disable_shared_mem:
         _, _, _, gpb, _ = load_partition(tmpdir / 'test_sampling.json', rank)
+    dgl.distributed.initialize("rpc_ip_config.txt")
     dist_graph = DistGraph("rpc_ip_config.txt", "test_sampling", gpb=gpb)
     sampled_graph = sample_neighbors(dist_graph, [0, 10, 99, 66, 1024, 2008], 3)
     dgl.distributed.exit_client()

@@ -158,12 +159,12 @@ def check_rpc_sampling_shuffle(tmpdir, num_server):
 # Wait non shared memory graph store
 @unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
 @unittest.skipIf(dgl.backend.backend_name == 'tensorflow', reason='Not support tensorflow for now')
-def test_rpc_sampling_shuffle():
+@pytest.mark.parametrize("num_server", [1, 2])
+def test_rpc_sampling_shuffle(num_server):
     import tempfile
     os.environ['DGL_DIST_MODE'] = 'distributed'
     with tempfile.TemporaryDirectory() as tmpdirname:
-        check_rpc_sampling_shuffle(Path(tmpdirname), 2)
-        check_rpc_sampling_shuffle(Path(tmpdirname), 1)
+        check_rpc_sampling_shuffle(Path(tmpdirname), num_server)

 def check_standalone_sampling(tmpdir):
     g = CitationGraphDataset("cora")[0]

@@ -171,7 +172,7 @@ def check_standalone_sampling(tmpdir):
     num_hops = 1
     partition_graph(g, 'test_sampling', num_parts, tmpdir,
                     num_hops=num_hops, part_method='metis', reshuffle=False)
+    os.environ['DGL_DIST_MODE'] = 'standalone'
     dist_graph = DistGraph(None, "test_sampling", part_config=tmpdir / 'test_sampling.json')
     sampled_graph = sample_neighbors(dist_graph, [0, 10, 99, 66, 1024, 2008], 3)

@@ -192,6 +193,7 @@ def test_standalone_sampling():
 def start_in_subgraph_client(rank, tmpdir, disable_shared_mem, nodes):
     gpb = None
+    dgl.distributed.initialize("rpc_ip_config.txt")
     if disable_shared_mem:
         _, _, _, gpb, _ = load_partition(tmpdir / 'test_in_subgraph.json', rank)
     dist_graph = DistGraph("rpc_ip_config.txt", "test_in_subgraph", gpb=gpb)
...
import dgl
import unittest
import os
from dgl.data import CitationGraphDataset
from dgl.distributed import sample_neighbors
from dgl.distributed import partition_graph, load_partition, load_partition_book
import sys
import multiprocessing as mp
import numpy as np
import time
from utils import get_local_usable_addr
from pathlib import Path
from dgl.distributed import DistGraphServer, DistGraph, DistDataLoader
import pytest
import backend as F


class NeighborSampler(object):
    def __init__(self, g, fanouts, sample_neighbors):
        self.g = g
        self.fanouts = fanouts
        self.sample_neighbors = sample_neighbors

    def sample_blocks(self, seeds):
        import torch as th
        seeds = th.LongTensor(np.asarray(seeds))
        blocks = []
        for fanout in self.fanouts:
            # For each seed node, sample ``fanout`` neighbors.
            frontier = self.sample_neighbors(
                self.g, seeds, fanout, replace=True)
            # Then we compact the frontier into a bipartite graph for message passing.
            block = dgl.to_block(frontier, seeds)
            # Obtain the seed nodes for the next layer.
            seeds = block.srcdata[dgl.NID]
            blocks.insert(0, block)
        return blocks


def start_server(rank, tmpdir, disable_shared_mem, num_clients):
    import dgl
    print('server: #clients=' + str(num_clients))
    g = DistGraphServer(rank, "mp_ip_config.txt", num_clients,
                        tmpdir / 'test_sampling.json', disable_shared_mem=disable_shared_mem)
    g.start()


def start_client(rank, tmpdir, disable_shared_mem, num_workers, drop_last):
    import dgl
    import torch as th
    os.environ['DGL_DIST_MODE'] = 'distributed'
    dgl.distributed.initialize("mp_ip_config.txt", num_workers=4)
    gpb = None
    if disable_shared_mem:
        _, _, _, gpb, _ = load_partition(tmpdir / 'test_sampling.json', rank)
    num_nodes_to_sample = 202
    batch_size = 32
    train_nid = th.arange(num_nodes_to_sample)
    dist_graph = DistGraph("mp_ip_config.txt", "test_mp", gpb=gpb)

    # Create sampler
    sampler = NeighborSampler(dist_graph, [5, 10],
                              dgl.distributed.sample_neighbors)

    # Create DataLoader for constructing blocks
    dataloader = DistDataLoader(
        dataset=train_nid.numpy(),
        batch_size=batch_size,
        collate_fn=sampler.sample_blocks,
        shuffle=False,
        drop_last=drop_last)

    groundtruth_g = CitationGraphDataset("cora")[0]
    max_nid = []

    for epoch in range(2):
        for idx, blocks in zip(range(0, num_nodes_to_sample, batch_size), dataloader):
            block = blocks[-1]
            o_src, o_dst = block.edges()
            src_nodes_id = block.srcdata[dgl.NID][o_src]
            dst_nodes_id = block.dstdata[dgl.NID][o_dst]
            has_edges = groundtruth_g.has_edges_between(src_nodes_id, dst_nodes_id)
            assert np.all(F.asnumpy(has_edges))
            print(np.unique(np.sort(F.asnumpy(dst_nodes_id))))
            max_nid.append(np.max(F.asnumpy(dst_nodes_id)))
            # assert np.all(np.unique(np.sort(F.asnumpy(dst_nodes_id))) == np.arange(idx, batch_size))
    if drop_last:
        assert np.max(max_nid) == num_nodes_to_sample - 1 - num_nodes_to_sample % batch_size
    else:
        assert np.max(max_nid) == num_nodes_to_sample - 1
    dgl.distributed.exit_client()  # this is needed since there are two tests in one process


@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
@unittest.skipIf(dgl.backend.backend_name != 'pytorch', reason='Only support PyTorch for now')
@pytest.mark.parametrize("num_server", [3])
@pytest.mark.parametrize("drop_last", [True, False])
def test_dist_dataloader(tmpdir, num_server, drop_last):
    ip_config = open("mp_ip_config.txt", "w")
    for _ in range(num_server):
        ip_config.write('{} 1\n'.format(get_local_usable_addr()))
    ip_config.close()

    g = CitationGraphDataset("cora")[0]
    print(g.idtype)
    num_parts = num_server
    num_hops = 1
    partition_graph(g, 'test_sampling', num_parts, tmpdir,
                    num_hops=num_hops, part_method='metis', reshuffle=False)

    num_workers = 4
    pserver_list = []
    ctx = mp.get_context('spawn')
    for i in range(num_server):
        p = ctx.Process(target=start_server, args=(
            i, tmpdir, num_server > 1, num_workers + 1))
        p.start()
        time.sleep(1)
        pserver_list.append(p)

    time.sleep(3)
    start_client(0, tmpdir, num_server > 1, num_workers, drop_last)


if __name__ == "__main__":
    import tempfile
    with tempfile.TemporaryDirectory() as tmpdirname:
        test_dist_dataloader(Path(tmpdirname), 3, True)
\ No newline at end of file