"docs/vscode:/vscode.git/clone" did not exist on "bdb2449ddca40e3d49cbb1a29611d76052f40338"
Unverified commit 4f499c7f authored by Jinjing Zhou, committed by GitHub

[Distributed] DistDataloader (#1901)



* 111

* 111

* fix

* 111

* fix

* 11

* fix

* lint

* Update __init__.py

* lint

* fix

* lint

* fix

* fix

* fix

* fix

* fix

* try fix

* try fix

* fix

* Revert "fix"

This reverts commit a0b954fd4e99b7df92b53db8334dcb583d6e1551.

* fixes.

* fix.

* fix test.

* fix exit.

* fix.

* fix

* fix

* lint

* lint

* lint

* fix

* Update .gitignore

* 111

* fix

* 111

* 111

* fff

* 1111

* 111

* 1325315

* ffff

* f???

* fff

* 1111

* 111

* fix

* 111

* asda

* 1111

* 11

* 123

* ahhhhhhhhhhhhhhhh

* spawn

* 1231231

* up

* 111

* fix

* fix

* Revert "fix"

This reverts commit 7373f95312fdcaa36d2fc330bf242339e89c045d.

* fix

* fix

* 1111

* fix

* fix tests

* start kvclient as early as possible.

* lint

* fix test

* lint

* 1111

* fix

* fix

* 111

* fix

* fix

* 1

* fix

* fix

* lint

* fix

* lint

* lint

* remove quit

* fix

* lint

* fix

* fix several

* lint

* fix minor

* fix

* lint
Co-authored-by: Da Zheng <zhengda1936@gmail.com>
Co-authored-by: xiang song(charlie.song) <classicxsong@gmail.com>
parent c801a164
......@@ -12,6 +12,7 @@ from dgl.data import register_data_args, load_data
from dgl.data.utils import load_graphs
import dgl.function as fn
import dgl.nn.pytorch as dglnn
from dgl.distributed import DistDataLoader
import torch as th
import torch.nn as nn
......@@ -91,7 +92,7 @@ class DistSAGE(nn.Module):
sampler = NeighborSampler(g, [-1], dgl.distributed.sample_neighbors)
print('|V|={}, eval batch size: {}'.format(g.number_of_nodes(), batch_size))
# Create PyTorch DataLoader for constructing blocks
dataloader = DataLoader(
dataloader = DistDataLoader(
dataset=nodes,
batch_size=batch_size,
collate_fn=sampler.sample_blocks,
......@@ -154,14 +155,13 @@ def run(args, device, data):
sampler = NeighborSampler(g, [int(fanout) for fanout in args.fan_out.split(',')],
dgl.distributed.sample_neighbors)
# Create PyTorch DataLoader for constructing blocks
dataloader = DataLoader(
# Create DataLoader for constructing blocks
dataloader = DistDataLoader(
dataset=train_nid.numpy(),
batch_size=args.batch_size,
collate_fn=sampler.sample_blocks,
shuffle=True,
drop_last=False,
num_workers=args.num_workers)
drop_last=False)
# Define model and optimizer
model = DistSAGE(in_feats, args.num_hidden, n_classes, args.num_layers, F.relu, args.dropout)
......@@ -258,13 +258,12 @@ def run(args, device, data):
profiler.stop()
print(profiler.output_text(unicode=True, color=True))
# clean up
if not args.standalone:
g._client.barrier()
def main(args):
if not args.standalone:
th.distributed.init_process_group(backend='gloo')
dgl.distributed.initialize(args.ip_config, num_workers=args.num_workers)
g = dgl.distributed.DistGraph(args.ip_config, args.graph_name, part_config=args.conf_path)
print('rank:', g.rank())
......@@ -309,7 +308,7 @@ if __name__ == '__main__':
parser.add_argument('--eval-every', type=int, default=5)
parser.add_argument('--lr', type=float, default=0.003)
parser.add_argument('--dropout', type=float, default=0.5)
parser.add_argument('--num-workers', type=int, default=0,
parser.add_argument('--num-workers', type=int, default=4,
help="Number of sampling processes. Use 0 for no extra process.")
parser.add_argument('--local_rank', type=int, help='get rank of the process')
parser.add_argument('--standalone', action='store_true', help='run in the standalone mode')
......
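The hunks above replace the example's PyTorch DataLoader with the new DistDataLoader and move connection setup into dgl.distributed.initialize(). Below is a minimal sketch of the resulting setup order; the file names, graph name, and hyper-parameters are placeholders rather than the exact values used by train_dist.py.

# Sketch of the setup order the updated training script follows (placeholder
# names; sampling is delegated to the worker pool created by initialize()).
import numpy as np
import torch as th
import dgl
from dgl.distributed import DistDataLoader

def collate(seeds):
    # Placeholder collate function; train_dist.py uses
    # NeighborSampler.sample_blocks here to build message-passing blocks.
    return seeds

if __name__ == '__main__':
    # 1. Start the RPC client and the sampler worker pool first.
    dgl.distributed.initialize('ip_config.txt', num_workers=4)
    # 2. Set up the PyTorch process group for gradient synchronization.
    th.distributed.init_process_group(backend='gloo')
    # 3. Connect to the partitioned graph.
    g = dgl.distributed.DistGraph('ip_config.txt', 'graph_name',
                                  part_config='graph_name.json')
    # 4. Build the distributed dataloader; note there is no num_workers
    #    argument, since sampling runs in the pool created in step 1.
    dataloader = DistDataLoader(dataset=np.arange(g.number_of_nodes()),
                                batch_size=1000,
                                collate_fn=collate,
                                shuffle=True,
                                drop_last=False)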
......@@ -9,9 +9,11 @@ from .sparse_emb import SparseAdagrad, DistEmbedding
from .rpc import *
from .rpc_server import start_server
from .rpc_client import connect_to_server, exit_client
from .rpc_client import connect_to_server
from .dist_context import initialize, exit_client
from .kvstore import KVServer, KVClient
from .server_state import ServerState
from .dist_dataloader import DistDataLoader
from .graph_services import sample_neighbors, in_subgraph, find_edges
if os.environ.get('DGL_ROLE', 'client') == 'server':
......
"""Initialize the distributed services"""
import multiprocessing as mp
import traceback
import atexit
import time
from . import rpc
from .constants import MAX_QUEUE_SIZE
from .kvstore import init_kvstore, close_kvstore
from .rpc_client import connect_to_server, shutdown_servers
SAMPLER_POOL = None
NUM_SAMPLER_WORKERS = 0
INITIALIZED = False
def set_initialized(value=True):
"""Set the initialized state of rpc"""
global INITIALIZED
INITIALIZED = value
def get_sampler_pool():
"""Return the sampler pool and num_workers"""
return SAMPLER_POOL, NUM_SAMPLER_WORKERS
def _init_rpc(ip_config, max_queue_size, net_type, role):
''' This init function is called in the worker processes.
'''
try:
connect_to_server(ip_config, max_queue_size, net_type)
init_kvstore(ip_config, role)
except Exception as e:
print(e, flush=True)
traceback.print_exc()
raise e
def initialize(ip_config, num_workers=0, max_queue_size=MAX_QUEUE_SIZE, net_type='socket'):
"""Init rpc service
ip_config: str
File path of ip_config file
num_workers: int
Number of worker process to be created
max_queue_size : int
Maximal size (bytes) of client queue buffer (~20 GB on default).
Note that the 20 GB is just an upper-bound and DGL uses zero-copy and
it will not allocate 20GB memory at once.
net_type : str
Networking type. Current options are: 'socket'.
"""
rpc.reset()
ctx = mp.get_context("spawn")
global SAMPLER_POOL
global NUM_SAMPLER_WORKERS
if num_workers > 0:
SAMPLER_POOL = ctx.Pool(
num_workers, initializer=_init_rpc, initargs=(ip_config, max_queue_size,
net_type, 'sampler'))
NUM_SAMPLER_WORKERS = num_workers
connect_to_server(ip_config, max_queue_size, net_type)
init_kvstore(ip_config)
def finalize_client():
"""Release resources of this client."""
rpc.finalize_sender()
rpc.finalize_receiver()
global INITIALIZED
INITIALIZED = False
def _exit():
exit_client()
time.sleep(1)
def finalize_worker():
"""Finalize workers
Python's multiprocessing pool will not call atexit function when close
"""
if SAMPLER_POOL is not None:
for _ in range(NUM_SAMPLER_WORKERS):
SAMPLER_POOL.apply_async(_exit)
time.sleep(0.1) # This is necessary but I don't know why
SAMPLER_POOL.close()
def join_finalize_worker():
"""join the worker close process"""
if SAMPLER_POOL is not None:
SAMPLER_POOL.join()
def is_initialized():
"""Is RPC initialized?
"""
return INITIALIZED
def exit_client():
"""Register exit callback.
"""
# Only client with rank_0 will send shutdown request to servers.
finalize_worker() # finalizing workers must happen before the barrier, and it is non-blocking
rpc.client_barrier()
shutdown_servers()
finalize_client()
join_finalize_worker()
close_kvstore()
atexit.unregister(exit_client)
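The initialize() function above builds the sampler pool with multiprocessing's initializer/initargs hook, so every worker process opens its own RPC connection and KVStore client (via _init_rpc) before any sampling task is dispatched, and exit_client() later drains the pool before the client barrier. The following self-contained sketch illustrates that pool-initializer pattern with generic names; it is not DGL's API, only the underlying technique.

# Generic sketch of the pool-initializer pattern used by initialize() above;
# worker_setup plays the role of _init_rpc, worker_task the role of a
# sampling request. Names are illustrative only.
import multiprocessing as mp

_WORKER_STATE = None

def worker_setup(config):
    # Runs once in every worker process: open connections or allocate
    # per-process state here.
    global _WORKER_STATE
    _WORKER_STATE = {'config': config, 'pid': mp.current_process().pid}

def worker_task(x):
    # Tasks submitted later can rely on the state built in worker_setup.
    return (_WORKER_STATE['pid'], x * 2)

if __name__ == '__main__':
    ctx = mp.get_context('spawn')
    pool = ctx.Pool(4, initializer=worker_setup, initargs=('ip_config.txt',))
    print(pool.map(worker_task, range(8)))
    pool.close()
    pool.join()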
# pylint: disable=global-variable-undefined, invalid-name
"""Multiprocess dataloader for distributed training"""
import multiprocessing as mp
import time
import traceback
from .dist_context import get_sampler_pool
from .. import backend as F
__all__ = ["DistDataLoader"]
def call_collate_fn(next_data):
"""Call collate function"""
try:
result = DGL_GLOBAL_COLLATE_FN(next_data)
DGL_GLOBAL_MP_QUEUE.put(result)
except Exception as e:
traceback.print_exc()
print(e)
raise e
return 1
def init_fn(collate_fn, queue):
"""Initialize setting collate function and mp.Queue in the subprocess"""
global DGL_GLOBAL_COLLATE_FN
global DGL_GLOBAL_MP_QUEUE
DGL_GLOBAL_MP_QUEUE = queue
DGL_GLOBAL_COLLATE_FN = collate_fn
# The sleep ensures this function gets executed in every worker process;
# a better solution is probably needed in the future.
time.sleep(1)
return 1
def enable_mp_debug():
"""Print multiprocessing debug information. This is only
for debug usage"""
import logging
logger = multiprocessing.log_to_stderr()
logger.setLevel(logging.DEBUG)
class DistDataLoader:
"""DGL customized multiprocessing dataloader, which is designed for using with DistGraph."""
def __init__(self, dataset, batch_size, shuffle=False, collate_fn=None, drop_last=False,
queue_size=None):
"""
This class uses the worker processes created by the dgl.distributed.initialize function.
Note that the iteration order is not guaranteed with this class. For example,
if dataset = [1, 2, 3, 4], batch_size = 2 and shuffle = False, the order in which
the batches [1, 2] and [3, 4] are returned is not guaranteed.
dataset (Dataset): dataset from which to load the data.
batch_size (int, optional): how many samples per batch to load
(default: ``1``).
shuffle (bool, optional): set to ``True`` to have the data reshuffled
at every epoch (default: ``False``).
collate_fn (callable, optional): merges a list of samples to form a
mini-batch of Tensor(s). Used for batched loading from a
map-style dataset.
drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
if the dataset size is not divisible by the batch size. If ``False`` and
the size of dataset is not divisible by the batch size, then the last batch
will be smaller. (default: ``False``)
queue_size (int, optional): size of the multiprocessing result queue
(default: ``4 * num_workers``).
"""
self.pool, num_workers = get_sampler_pool()
assert num_workers > 0, "DistDataLoader only supports num_workers > 0 for now. If you \
want a single-process dataloader, please use the PyTorch DataLoader for now."
if queue_size is None:
queue_size = num_workers * 4
self.queue_size = queue_size
self.batch_size = batch_size
self.collate_fn = collate_fn
self.current_pos = 0
self.num_workers = num_workers
self.m = mp.Manager()
self.queue = self.m.Queue(maxsize=queue_size)
self.drop_last = drop_last
self.recv_idxs = 0
self.started = False
self.shuffle = shuffle
self.is_closed = False
if self.pool is None:
ctx = mp.get_context("spawn")
self.pool = ctx.Pool(num_workers)
results = []
for _ in range(num_workers):
results.append(self.pool.apply_async(
init_fn, args=(collate_fn, self.queue)))
for res in results:
res.get()
self.dataset = F.tensor(dataset)
self.expected_idxs = len(dataset) // self.batch_size
if not self.drop_last and len(dataset) % self.batch_size != 0:
self.expected_idxs += 1
def __next__(self):
if not self.started:
    self.started = True
    for _ in range(self.queue_size):
        self._request_next_batch()
self._request_next_batch()
if self.recv_idxs < self.expected_idxs:
result = self.queue.get(timeout=9999)
self.recv_idxs += 1
return result
else:
self.recv_idxs = 0
self.current_pos = 0
raise StopIteration
def __iter__(self):
if self.shuffle:
self.dataset = F.rand_shuffle(self.dataset)
return self
def _request_next_batch(self):
next_data = self._next_data()
if next_data is None:
return None
else:
async_result = self.pool.apply_async(
call_collate_fn, args=(next_data, ))
return async_result
def _next_data(self):
if self.current_pos == len(self.dataset):
return None
end_pos = 0
if self.current_pos + self.batch_size > len(self.dataset):
if self.drop_last:
return None
else:
end_pos = len(self.dataset)
else:
end_pos = self.current_pos + self.batch_size
ret = self.dataset[self.current_pos:end_pos]
self.current_pos = end_pos
return ret
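Because DistDataLoader pushes each batch of indices to the shared sampler pool and collects results from a multiprocessing queue, batches may come back in a different order than they were requested. The sketch below shows typical usage and that caveat; the ip-config path is a placeholder, and the trivial collate function stands in for a real block sampler such as the NeighborSampler used in the tests later in this PR.

# Usage sketch for DistDataLoader (placeholder names; a real collate_fn
# would sample message-passing blocks with dgl.distributed.sample_neighbors).
import numpy as np
import dgl
from dgl.distributed import DistDataLoader

def collate(seeds):
    return seeds                       # stand-in for sampler.sample_blocks

dgl.distributed.initialize('ip_config.txt', num_workers=4)   # must come first
dataloader = DistDataLoader(dataset=np.arange(1000),
                            batch_size=32,
                            collate_fn=collate,
                            shuffle=True,      # reshuffled on every __iter__
                            drop_last=False)   # last short batch is kept
for epoch in range(2):
    for batch in dataloader:
        # batch is whatever collate_fn returned; order across batches
        # is not guaranteed.
        pass
dgl.distributed.exit_client()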
......@@ -8,7 +8,7 @@ from ..heterograph import DGLHeteroGraph
from .. import heterograph_index
from .. import backend as F
from ..base import NID, EID
from .kvstore import KVServer, KVClient
from .kvstore import KVServer, init_kvstore, get_kvstore
from .standalone_kvstore import KVClient as SA_KVClient
from .._ffi.ndarray import empty_shared_mem
from ..frame import infer_scheme
......@@ -17,7 +17,6 @@ from .graph_partition_book import PartitionPolicy, get_shared_mem_partition_book
from .graph_partition_book import NODE_PART_POLICY, EDGE_PART_POLICY
from .shared_mem_utils import _to_shared_mem, _get_ndata_path, _get_edata_path, DTYPE_DICT
from . import rpc
from .rpc_client import connect_to_server
from .server_state import ServerState
from .rpc_server import start_server
from .graph_services import find_edges as dist_find_edges
......@@ -323,6 +322,9 @@ class DistGraph:
The partition config file. It's used in the standalone mode.
'''
def __init__(self, ip_config, graph_name, gpb=None, part_config=None):
self.ip_config = ip_config
self.graph_name = graph_name
self._gpb_input = gpb
if os.environ.get('DGL_DIST_MODE', 'standalone') == 'standalone':
assert part_config is not None, \
'When running in the standalone mode, the partition config file is required'
......@@ -340,20 +342,13 @@ class DistGraph:
self._client.add_data(_get_data_name(name, EDGE_PART_POLICY), edge_feats[name])
rpc.set_num_client(1)
else:
connect_to_server(ip_config=ip_config)
self._client = KVClient(ip_config)
self._g = _get_graph_from_shared_mem(graph_name)
self._gpb = get_shared_mem_partition_book(graph_name, self._g)
if self._gpb is None:
self._gpb = gpb
self._init()
# Tell the backup servers to load the graph structure from shared memory.
for server_id in range(self._client.num_servers):
rpc.send_request(server_id, InitGraphRequest(graph_name))
for server_id in range(self._client.num_servers):
rpc.recv_response()
self._client.barrier()
self._client.map_shared_data(self._gpb)
self._ndata = NodeDataView(self)
self._edata = EdgeDataView(self)
......@@ -364,6 +359,32 @@ class DistGraph:
self._num_nodes += int(part_md['num_nodes'])
self._num_edges += int(part_md['num_edges'])
def _init(self):
# Init KVStore client if it's not initialized yet.
init_kvstore(self.ip_config)
self._client = get_kvstore()
self._g = _get_graph_from_shared_mem(self.graph_name)
self._gpb = get_shared_mem_partition_book(self.graph_name, self._g)
if self._gpb is None:
self._gpb = self._gpb_input
self._client.map_shared_data(self._gpb)
def __getstate__(self):
return self.ip_config, self.graph_name, self._gpb_input
def __setstate__(self, state):
self.ip_config, self.graph_name, self._gpb_input = state
self._init()
self._ndata = NodeDataView(self)
self._edata = EdgeDataView(self)
self._num_nodes = 0
self._num_edges = 0
for part_md in self._gpb.metadata():
self._num_nodes += int(part_md['num_nodes'])
self._num_edges += int(part_md['num_edges'])
@property
def local_partition(self):
''' Return the local partition on the client
......
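The new __getstate__/__setstate__ pair is what lets a DistGraph travel to the sampler worker processes: only the small (ip_config, graph_name, gpb) triple is pickled, and the receiving process rebuilds its shared-memory graph and KVStore connection in _init(). A generic, self-contained sketch of this reconnect-on-unpickle pattern follows; the RemoteHandle class is illustrative and not part of DGL.

# Reconnect-on-unpickle pattern, as used by DistGraph above (generic names).
import pickle

class RemoteHandle:
    def __init__(self, address):
        self.address = address
        self._connect()

    def _connect(self):
        # In DistGraph this is _init(): attach the shared-memory graph,
        # fetch the KVStore client, and rebuild the partition book.
        self.conn = ('connected-to', self.address)

    def __getstate__(self):
        # Ship only the cheap, picklable identity, never the live connection.
        return self.address

    def __setstate__(self, state):
        self.address = state
        self._connect()    # re-establish process-local state after unpickling

handle = RemoteHandle('127.0.0.1:30050')
clone = pickle.loads(pickle.dumps(handle))
assert clone.conn == handle.conn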
......@@ -3,7 +3,7 @@
import os
from .graph_partition_book import PartitionPolicy, NODE_PART_POLICY, EDGE_PART_POLICY
from .rpc_client import is_initialized
from .dist_context import is_initialized
from ..base import DGLError
from .. import utils
from .. import backend as F
......
......@@ -1287,3 +1287,20 @@ class KVClient(object):
"""Used by sort response list
"""
return elem.server_id
KVCLIENT = None
def init_kvstore(ip_config, role='default'):
"""initialize KVStore"""
global KVCLIENT
if KVCLIENT is None:
KVCLIENT = KVClient(ip_config, role)
def close_kvstore():
"""Close the current KVClient"""
global KVCLIENT
KVCLIENT = None
def get_kvstore():
"""get the KVClient"""
return KVCLIENT
......@@ -79,6 +79,12 @@ def read_ip_config(filename):
print("Error: data format on each line should be: [ip] [base_port] [server_count]")
return server_namebook
def reset():
"""Reset the rpc context
"""
_CAPI_DGLRPCReset()
def create_sender(max_queue_size, net_type):
"""Create rpc sender of this process.
......
......@@ -96,7 +96,6 @@ def get_local_usable_addr():
return ip_addr + ':' + str(port)
INITIALIZED = False
def connect_to_server(ip_config, max_queue_size=MAX_QUEUE_SIZE, net_type='socket'):
"""Connect this client to server.
......@@ -175,16 +174,9 @@ def connect_to_server(ip_config, max_queue_size=MAX_QUEUE_SIZE, net_type='socket
rpc.send_request(0, get_client_num_req)
res = rpc.recv_response()
rpc.set_num_client(res.num_client)
from .dist_context import exit_client, set_initialized
atexit.register(exit_client)
global INITIALIZED
INITIALIZED = True
def finalize_client():
"""Release resources of this client."""
rpc.finalize_sender()
rpc.finalize_receiver()
global INITIALIZED
INITIALIZED = False
set_initialized()
def shutdown_servers():
"""Issue commands to remote servers to shut them down.
......@@ -193,21 +185,7 @@ def shutdown_servers():
------
ConnectionError : If anything wrong with the connection.
"""
if rpc.get_rank() == 0: # Only client_0 issues this command
req = rpc.ShutDownRequest(rpc.get_rank())
for server_id in range(rpc.get_num_server()):
rpc.send_request(server_id, req)
def exit_client():
"""Register exit callback.
"""
# Only client with rank_0 will send shutdown request to servers.
rpc.client_barrier()
shutdown_servers()
finalize_client()
atexit.unregister(exit_client)
def is_initialized():
"""Is RPC initialized?
"""
return INITIALIZED
......@@ -77,6 +77,10 @@ RPCStatus RecvRPCMessage(RPCMessage* msg, int32_t timeout) {
}
//////////////////////////// C APIs ////////////////////////////
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCReset")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
RPCContext::Reset();
});
DGL_REGISTER_GLOBAL("distributed.rpc._CAPI_DGLRPCCreateSender")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
......
......@@ -96,6 +96,19 @@ struct RPCContext {
static RPCContext *ThreadLocal() {
return dmlc::ThreadLocalStore<RPCContext>::Get();
}
/*! \brief Reset the RPC context */
static void Reset() {
auto* t = ThreadLocal();
t->rank = -1;
t->machine_id = -1;
t->num_machines = 0;
t->num_clients = 0;
t->barrier_count = 0;
t->num_servers_per_machine = 0;
t->sender = std::shared_ptr<network::Sender>();
t->receiver = std::shared_ptr<network::Receiver>();
}
};
/*! \brief RPC message data structure
......
......@@ -68,6 +68,7 @@ def rand_init(shape, dtype):
def run_client(graph_name, part_id, num_clients, num_nodes, num_edges):
time.sleep(5)
dgl.distributed.initialize("kv_ip_config.txt")
gpb, graph_name = load_partition_book('/tmp/dist_graph/{}.json'.format(graph_name),
part_id, None)
g = DistGraph("kv_ip_config.txt", graph_name, gpb=gpb)
......
......@@ -11,7 +11,7 @@ import backend as F
import time
from utils import get_local_usable_addr
from pathlib import Path
import pytest
from dgl.distributed import DistGraphServer, DistGraph
......@@ -25,6 +25,7 @@ def start_sample_client(rank, tmpdir, disable_shared_mem):
gpb = None
if disable_shared_mem:
_, _, _, gpb, _ = load_partition(tmpdir / 'test_sampling.json', rank)
dgl.distributed.initialize("rpc_ip_config.txt")
dist_graph = DistGraph("rpc_ip_config.txt", "test_sampling", gpb=gpb)
sampled_graph = sample_neighbors(dist_graph, [0, 10, 99, 66, 1024, 2008], 3)
dgl.distributed.exit_client()
......@@ -158,12 +159,12 @@ def check_rpc_sampling_shuffle(tmpdir, num_server):
# Wait non shared memory graph store
@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
@unittest.skipIf(dgl.backend.backend_name == 'tensorflow', reason='Not support tensorflow for now')
def test_rpc_sampling_shuffle():
@pytest.mark.parametrize("num_server", [1, 2])
def test_rpc_sampling_shuffle(num_server):
import tempfile
os.environ['DGL_DIST_MODE'] = 'distributed'
with tempfile.TemporaryDirectory() as tmpdirname:
check_rpc_sampling_shuffle(Path(tmpdirname), 2)
check_rpc_sampling_shuffle(Path(tmpdirname), 1)
check_rpc_sampling_shuffle(Path(tmpdirname), num_server)
def check_standalone_sampling(tmpdir):
g = CitationGraphDataset("cora")[0]
......@@ -171,7 +172,7 @@ def check_standalone_sampling(tmpdir):
num_hops = 1
partition_graph(g, 'test_sampling', num_parts, tmpdir,
num_hops=num_hops, part_method='metis', reshuffle=False)
os.environ['DGL_DIST_MODE'] = 'standalone'
dist_graph = DistGraph(None, "test_sampling", part_config=tmpdir / 'test_sampling.json')
sampled_graph = sample_neighbors(dist_graph, [0, 10, 99, 66, 1024, 2008], 3)
......@@ -192,6 +193,7 @@ def test_standalone_sampling():
def start_in_subgraph_client(rank, tmpdir, disable_shared_mem, nodes):
gpb = None
dgl.distributed.initialize("rpc_ip_config.txt")
if disable_shared_mem:
_, _, _, gpb, _ = load_partition(tmpdir / 'test_in_subgraph.json', rank)
dist_graph = DistGraph("rpc_ip_config.txt", "test_in_subgraph", gpb=gpb)
......
import dgl
import unittest
import os
from dgl.data import CitationGraphDataset
from dgl.distributed import sample_neighbors
from dgl.distributed import partition_graph, load_partition, load_partition_book
import sys
import multiprocessing as mp
import numpy as np
import time
from utils import get_local_usable_addr
from pathlib import Path
from dgl.distributed import DistGraphServer, DistGraph, DistDataLoader
import pytest
import backend as F
class NeighborSampler(object):
def __init__(self, g, fanouts, sample_neighbors):
self.g = g
self.fanouts = fanouts
self.sample_neighbors = sample_neighbors
def sample_blocks(self, seeds):
import torch as th
seeds = th.LongTensor(np.asarray(seeds))
blocks = []
for fanout in self.fanouts:
# For each seed node, sample ``fanout`` neighbors.
frontier = self.sample_neighbors(
self.g, seeds, fanout, replace=True)
# Then we compact the frontier into a bipartite graph for message passing.
block = dgl.to_block(frontier, seeds)
# Obtain the seed nodes for next layer.
seeds = block.srcdata[dgl.NID]
blocks.insert(0, block)
return blocks
def start_server(rank, tmpdir, disable_shared_mem, num_clients):
import dgl
print('server: #clients=' + str(num_clients))
g = DistGraphServer(rank, "mp_ip_config.txt", num_clients,
tmpdir / 'test_sampling.json', disable_shared_mem=disable_shared_mem)
g.start()
def start_client(rank, tmpdir, disable_shared_mem, num_workers, drop_last):
import dgl
import torch as th
os.environ['DGL_DIST_MODE'] = 'distributed'
dgl.distributed.initialize("mp_ip_config.txt", num_workers=4)
gpb = None
if disable_shared_mem:
_, _, _, gpb, _ = load_partition(tmpdir / 'test_sampling.json', rank)
num_nodes_to_sample = 202
batch_size = 32
train_nid = th.arange(num_nodes_to_sample)
dist_graph = DistGraph("mp_ip_config.txt", "test_mp", gpb=gpb)
# Create sampler
sampler = NeighborSampler(dist_graph, [5, 10],
dgl.distributed.sample_neighbors)
# Create DataLoader for constructing blocks
dataloader = DistDataLoader(
dataset=train_nid.numpy(),
batch_size=batch_size,
collate_fn=sampler.sample_blocks,
shuffle=False,
drop_last=drop_last)
groundtruth_g = CitationGraphDataset("cora")[0]
max_nid = []
for epoch in range(2):
for idx, blocks in zip(range(0, num_nodes_to_sample, batch_size), dataloader):
block = blocks[-1]
o_src, o_dst = block.edges()
src_nodes_id = block.srcdata[dgl.NID][o_src]
dst_nodes_id = block.dstdata[dgl.NID][o_dst]
has_edges = groundtruth_g.has_edges_between(src_nodes_id, dst_nodes_id)
assert np.all(F.asnumpy(has_edges))
print(np.unique(np.sort(F.asnumpy(dst_nodes_id))))
max_nid.append(np.max(F.asnumpy(dst_nodes_id)))
# assert np.all(np.unique(np.sort(F.asnumpy(dst_nodes_id))) == np.arange(idx, batch_size))
if drop_last:
assert np.max(max_nid) == num_nodes_to_sample - 1 - num_nodes_to_sample % batch_size
else:
assert np.max(max_nid) == num_nodes_to_sample - 1
dgl.distributed.exit_client() # this is needed since there are two tests here in one process
@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
@unittest.skipIf(dgl.backend.backend_name != 'pytorch', reason='Only support PyTorch for now')
@pytest.mark.parametrize("num_server", [3])
@pytest.mark.parametrize("drop_last", [True, False])
def test_dist_dataloader(tmpdir, num_server, drop_last):
ip_config = open("mp_ip_config.txt", "w")
for _ in range(num_server):
ip_config.write('{} 1\n'.format(get_local_usable_addr()))
ip_config.close()
g = CitationGraphDataset("cora")[0]
print(g.idtype)
num_parts = num_server
num_hops = 1
partition_graph(g, 'test_sampling', num_parts, tmpdir,
num_hops=num_hops, part_method='metis', reshuffle=False)
num_workers = 4
pserver_list = []
ctx = mp.get_context('spawn')
for i in range(num_server):
p = ctx.Process(target=start_server, args=(
i, tmpdir, num_server > 1, num_workers+1))
p.start()
time.sleep(1)
pserver_list.append(p)
time.sleep(3)
start_client(0, tmpdir, num_server > 1, num_workers, drop_last)
if __name__ == "__main__":
import tempfile
with tempfile.TemporaryDirectory() as tmpdirname:
test_dist_dataloader(Path(tmpdirname), 3, True)
\ No newline at end of file