"git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "3301bd0bcc0fb1cfb9f502691241f2bc675c6462"
Unverified commit ee30b2aa authored by Da Zheng, committed by GitHub

[Distributed] Fix standalone (#1974)



* fix tests.

* fix.

* remove a test.

* make code work in the standalone mode.

* fix example.

* more fix.

* make DistDataloader work with num_workers=0

* fix DistDataloader tests.

* fix.

* fix lint.

* fix cleanup.

* fix test

* remove unnecessary code.

* remove tests.

* fix.

* fix.

* fix.

* fix example

* fix.

* fix.

* fix launch script.
Co-authored-by: Jinjing Zhou <VoVAllen@users.noreply.github.com>
Co-authored-by: Ubuntu <ubuntu@ip-172-31-19-1.us-west-2.compute.internal>
parent c2cd6eb2
@@ -43,10 +43,11 @@ specify relative paths to the path of the workspace.
 ```bash
 python3 ~/dgl/tools/launch.py \
 --workspace ~/graphsage/ \
---num_client 4 \
+--num_trainers 1 \
+--num_samplers 4 \
 --part_config ogb-product/ogb-product.json \
 --ip_config ip_config.txt \
-"python3 train_dist.py --graph-name ogb-product --ip_config ip_config.txt --num-epochs 30 --batch-size 1000"
+"python3 train_dist.py --graph-name ogb-product --ip_config ip_config.txt --num-epochs 30 --batch-size 1000 --num-workers 4"
 ```

 To run unsupervised training:
@@ -54,10 +55,10 @@ To run unsupervised training:
 ```bash
 python3 ~/dgl/tools/launch.py \
 --workspace ~/dgl/examples/pytorch/graphsage/experimental \
---num_client 4 \
---part_config data/ogb-product.json \
+--num_trainers 1 \
+--part_config ogb-product/ogb-product.json \
 --ip_config ip_config.txt \
-"python3 train_dist_unsupervised.py --graph-name ogb-product --ip_config ip_config.txt --num-epochs 3 --batch-size 1000 --num-client 4"
+"python3 train_dist_unsupervised.py --graph-name ogb-product --ip_config ip_config.txt --num-epochs 3 --batch-size 1000"
 ```

 ## Distributed code runs in the standalone mode
......
@@ -264,7 +264,7 @@ def main(args):
     th.distributed.init_process_group(backend='gloo')
     dgl.distributed.initialize(args.ip_config, num_workers=args.num_workers)
-    g = dgl.distributed.DistGraph(args.ip_config, args.graph_name, part_config=args.conf_path)
+    g = dgl.distributed.DistGraph(args.ip_config, args.graph_name, part_config=args.part_config)
     print('rank:', g.rank())
     pb = g.get_partition_book()
@@ -293,7 +293,7 @@ if __name__ == '__main__':
     parser.add_argument('--graph-name', type=str, help='graph name')
     parser.add_argument('--id', type=int, help='the partition id')
     parser.add_argument('--ip_config', type=str, help='The file for IP configuration')
-    parser.add_argument('--conf_path', type=str, help='The path to the partition config file')
+    parser.add_argument('--part_config', type=str, help='The path to the partition config file')
     parser.add_argument('--num-client', type=int, help='The number of clients')
     parser.add_argument('--n-classes', type=int, help='the number of classes')
     parser.add_argument('--gpu', type=int, default=0,
......
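The supervised example now calls `dgl.distributed.initialize()` before constructing `DistGraph` and takes the partition JSON via `--part_config` instead of `--conf_path`. Below is a minimal sketch of the updated trainer-side bootstrap; the argument handling is a placeholder and only the flags visible in this diff are assumed to exist.

```python
# Sketch of the trainer bootstrap after this commit (placeholders, not the full example).
import argparse

import dgl
import torch as th

parser = argparse.ArgumentParser()
parser.add_argument('--graph-name', type=str, help='graph name')
parser.add_argument('--ip_config', type=str, help='The file for IP configuration')
parser.add_argument('--part_config', type=str, help='The path to the partition config file')
parser.add_argument('--num-workers', type=int, default=4, help='sampler processes per trainer')
args = parser.parse_args()

th.distributed.init_process_group(backend='gloo')
# Must come before DistGraph: spawns the sampler pool and, unless
# DGL_DIST_MODE=standalone, connects this process to the servers.
dgl.distributed.initialize(args.ip_config, num_workers=args.num_workers)
g = dgl.distributed.DistGraph(args.ip_config, args.graph_name, part_config=args.part_config)
print('rank:', g.rank())
```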
@@ -349,7 +349,8 @@ def run(args, device, data):
 def main(args):
     if not args.standalone:
         th.distributed.init_process_group(backend='gloo')
-    g = dgl.distributed.DistGraph(args.ip_config, args.graph_name, part_config=args.conf_path)
+    dgl.distributed.initialize(args.ip_config, num_workers=args.num_workers)
+    g = dgl.distributed.DistGraph(args.ip_config, args.graph_name, part_config=args.part_config)
     print('rank:', g.rank())
     print('number of edges', g.number_of_edges())
@@ -379,8 +380,7 @@ if __name__ == '__main__':
     parser.add_argument('--graph-name', type=str, help='graph name')
     parser.add_argument('--id', type=int, help='the partition id')
     parser.add_argument('--ip_config', type=str, help='The file for IP configuration')
-    parser.add_argument('--conf_path', type=str, help='The path to the partition config file')
-    parser.add_argument('--num-client', type=int, help='The number of clients')
+    parser.add_argument('--part_config', type=str, help='The path to the partition config file')
     parser.add_argument('--n-classes', type=int, help='the number of classes')
     parser.add_argument('--gpu', type=int, default=0,
                         help="GPU device ID. Use -1 for CPU training")
......
@@ -4,6 +4,8 @@ import multiprocessing as mp
 import traceback
 import atexit
 import time
+import os
+
 from . import rpc
 from .constants import MAX_QUEUE_SIZE
 from .kvstore import init_kvstore, close_kvstore
@@ -31,7 +33,8 @@ def _init_rpc(ip_config, max_queue_size, net_type, role, num_threads):
     '''
     try:
         utils.set_num_threads(num_threads)
-        connect_to_server(ip_config, max_queue_size, net_type)
+        if os.environ.get('DGL_DIST_MODE', 'standalone') != 'standalone':
+            connect_to_server(ip_config, max_queue_size, net_type)
         init_role(role)
         init_kvstore(ip_config, role)
     except Exception as e:
@@ -60,20 +63,25 @@ def initialize(ip_config, num_workers=0, max_queue_size=MAX_QUEUE_SIZE, net_type
     ctx = mp.get_context("spawn")
     global SAMPLER_POOL
     global NUM_SAMPLER_WORKERS
-    if num_workers > 0:
-        SAMPLER_POOL = ctx.Pool(
-            num_workers, initializer=_init_rpc, initargs=(ip_config, max_queue_size,
-                                                          net_type, 'sampler', num_worker_threads))
+    is_standalone = os.environ.get('DGL_DIST_MODE', 'standalone') == 'standalone'
+    if num_workers > 0 and not is_standalone:
+        SAMPLER_POOL = ctx.Pool(num_workers, initializer=_init_rpc,
+                                initargs=(ip_config, max_queue_size,
+                                          net_type, 'sampler', num_worker_threads))
+    else:
+        SAMPLER_POOL = None
     NUM_SAMPLER_WORKERS = num_workers
-    connect_to_server(ip_config, max_queue_size, net_type)
+    if not is_standalone:
+        connect_to_server(ip_config, max_queue_size, net_type)
     init_role('default')
     init_kvstore(ip_config, 'default')

 def finalize_client():
     """Release resources of this client."""
-    rpc.finalize_sender()
-    rpc.finalize_receiver()
+    if os.environ.get('DGL_DIST_MODE', 'standalone') != 'standalone':
+        rpc.finalize_sender()
+        rpc.finalize_receiver()
     global INITIALIZED
     INITIALIZED = False
@@ -95,8 +103,10 @@ def finalize_worker():
 def join_finalize_worker():
     """join the worker close process"""
+    global SAMPLER_POOL
     if SAMPLER_POOL is not None:
         SAMPLER_POOL.join()
+        SAMPLER_POOL = None

 def is_initialized():
     """Is RPC initialized?
@@ -109,8 +119,9 @@ def exit_client():
     """
     # Only client with rank_0 will send shutdown request to servers.
     finalize_worker() # finalize workers should be earilier than barrier, and non-blocking
-    rpc.client_barrier()
-    shutdown_servers()
+    if os.environ.get('DGL_DIST_MODE', 'standalone') != 'standalone':
+        rpc.client_barrier()
+        shutdown_servers()
     finalize_client()
     join_finalize_worker()
     close_kvstore()
......
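Taken together, these branches let the same `initialize()`/`exit_client()` lifecycle run without any servers when `DGL_DIST_MODE` is `standalone` (also the value assumed when the variable is unset). A minimal standalone sketch following the pattern used in the updated tests; the graph name and partition path are placeholders:

```python
# Minimal standalone-mode sketch (graph name and partition path are placeholders).
import os

import dgl

os.environ['DGL_DIST_MODE'] = 'standalone'
# No sampler pool is spawned and connect_to_server() is skipped in this mode;
# init_kvstore() falls back to the in-process standalone KVClient.
dgl.distributed.initialize('ip_config.txt')
g = dgl.distributed.DistGraph('ip_config.txt', 'my_graph',
                              part_config='/tmp/dist_graph/my_graph.json')
print('number of nodes:', g.number_of_nodes())
dgl.distributed.exit_client()
```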
 # pylint: disable=global-variable-undefined, invalid-name
 """Multiprocess dataloader for distributed training"""
 import multiprocessing as mp
+from queue import Queue
 import time
 import traceback
@@ -82,29 +83,24 @@ class DistDataLoader:
             will be smaller. (default: ``False``)
         queue_size (int, optional): Size of multiprocessing queue
         """
-        self.pool, num_workers = get_sampler_pool()
-        assert num_workers > 0, "DistDataloader only supports num_workers>0 for now. if you \
-            want to use single process dataloader, please use PyTorch dataloader for now"
+        self.pool, self.num_workers = get_sampler_pool()
         if queue_size is None:
-            queue_size = num_workers * 4
+            queue_size = self.num_workers * 4 if self.num_workers > 0 else 4
         self.queue_size = queue_size
         self.batch_size = batch_size
-        self.queue_size = queue_size
+        self.num_pending = 0
         self.collate_fn = collate_fn
         self.current_pos = 0
-        self.num_workers = num_workers
-        self.m = mp.Manager()
-        self.queue = self.m.Queue(maxsize=queue_size)
+        if self.pool is not None:
+            self.m = mp.Manager()
+            self.queue = self.m.Queue(maxsize=queue_size)
+        else:
+            self.queue = Queue(maxsize=queue_size)
         self.drop_last = drop_last
         self.recv_idxs = 0
-        self.started = False
         self.shuffle = shuffle
         self.is_closed = False
-        if self.pool is None:
-            ctx = mp.get_context("spawn")
-            self.pool = ctx.Pool(num_workers)
         self.dataset = F.tensor(dataset)
         self.expected_idxs = len(dataset) // self.batch_size
         if not self.drop_last and len(dataset) % self.batch_size != 0:
@@ -116,30 +112,33 @@ class DistDataLoader:
         self.name = "dataloader-" + str(DATALOADER_ID)
         DATALOADER_ID += 1

-        results = []
-        for _ in range(self.num_workers):
-            results.append(self.pool.apply_async(
-                init_fn, args=(self.name, self.collate_fn, self.queue)))
-        for res in results:
-            res.get()
+        if self.pool is not None:
+            results = []
+            for _ in range(self.num_workers):
+                results.append(self.pool.apply_async(
+                    init_fn, args=(self.name, self.collate_fn, self.queue)))
+            for res in results:
+                res.get()

     def __del__(self):
-        results = []
-        for _ in range(self.num_workers):
-            results.append(self.pool.apply_async(cleanup_fn, args=(self.name,)))
-        for res in results:
-            res.get()
+        if self.pool is not None:
+            results = []
+            for _ in range(self.num_workers):
+                results.append(self.pool.apply_async(cleanup_fn, args=(self.name,)))
+            for res in results:
+                res.get()

     def __next__(self):
-        if not self.started:
-            for _ in range(self.queue_size):
-                self._request_next_batch()
-        self._request_next_batch()
+        num_reqs = self.queue_size - self.num_pending
+        for _ in range(num_reqs):
+            self._request_next_batch()
         if self.recv_idxs < self.expected_idxs:
             result = self.queue.get(timeout=9999)
             self.recv_idxs += 1
+            self.num_pending -= 1
             return result
         else:
+            assert self.num_pending == 0
             raise StopIteration

     def __iter__(self):
@@ -147,16 +146,19 @@ class DistDataLoader:
             self.dataset = F.rand_shuffle(self.dataset)
         self.recv_idxs = 0
         self.current_pos = 0
+        self.num_pending = 0
         return self

     def _request_next_batch(self):
         next_data = self._next_data()
         if next_data is None:
-            return None
+            return
+        elif self.pool is not None:
+            self.pool.apply_async(call_collate_fn, args=(self.name, next_data, ))
         else:
-            async_result = self.pool.apply_async(
-                call_collate_fn, args=(self.name, next_data, ))
-            return async_result
+            result = self.collate_fn(next_data)
+            self.queue.put(result)
+        self.num_pending += 1

     def _next_data(self):
         if self.current_pos == len(self.dataset):
......
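With the pool made optional, the loader also covers the `num_workers=0` case: batches are collated synchronously in the calling process and passed through a local `queue.Queue` instead of a manager queue. A rough usage sketch of that path, assuming `DistDataLoader` is importable from `dgl.distributed`; the collate function and seed list are placeholders:

```python
# Zero-worker usage sketch; the collate function and seed-node list are placeholders.
import os

import dgl
from dgl.distributed import DistDataLoader

os.environ['DGL_DIST_MODE'] = 'standalone'
dgl.distributed.initialize('ip_config.txt', num_workers=0)   # no sampler pool is created

def collate(seeds):
    # A real collate_fn would sample blocks for these seed nodes via a DistGraph.
    return seeds

dataloader = DistDataLoader(dataset=list(range(1000)), batch_size=32,
                            collate_fn=collate, shuffle=True, drop_last=False)
for batch in dataloader:
    # Each batch was produced by collate() in this process, not by a sampler worker.
    pass
dgl.distributed.exit_client()
```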
@@ -369,7 +369,7 @@ class DistGraph:
         self._client.map_shared_data(self._gpb)

     def __getstate__(self):
-        return self.ip_config, self.graph_name, self._gpb_input
+        return self.ip_config, self.graph_name, self._gpb

     def __setstate__(self, state):
         self.ip_config, self.graph_name, self._gpb_input = state
......
@@ -5,6 +5,7 @@ import numpy as np
 from . import rpc
 from .graph_partition_book import PartitionPolicy
+from .standalone_kvstore import KVClient as SA_KVClient
 from .. import backend as F
 from .. import utils
@@ -776,8 +777,8 @@ class KVClient(object):
     We can set different role for kvstore.
     """
     def __init__(self, ip_config, role='default'):
-        assert rpc.get_rank() != -1, 'Please invoke rpc.connect_to_server() \
-            before creating KVClient.'
+        assert rpc.get_rank() != -1, \
+            'Please invoke rpc.connect_to_server() before creating KVClient.'
         assert os.path.exists(ip_config), 'Cannot open file: %s' % ip_config
         # Register services on client
         rpc.register_service(KVSTORE_PULL,
@@ -1233,7 +1234,10 @@ def init_kvstore(ip_config, role):
     """initialize KVStore"""
     global KVCLIENT
     if KVCLIENT is None:
-        KVCLIENT = KVClient(ip_config, role)
+        if os.environ.get('DGL_DIST_MODE', 'standalone') == 'standalone':
+            KVCLIENT = SA_KVClient()
+        else:
+            KVCLIENT = KVClient(ip_config, role)

 def close_kvstore():
     """Close the current KVClient"""
......
@@ -104,6 +104,8 @@ GLOBAL_RANK = {}
 # The role of the current process
 CUR_ROLE = None

+IS_STANDALONE = False
+
 def init_role(role):
     """Initialize the role of the current process.
@@ -121,11 +123,17 @@ def init_role(role):
     global PER_ROLE_RANK
     global GLOBAL_RANK
+    global IS_STANDALONE
     if os.environ.get('DGL_DIST_MODE', 'standalone') == 'standalone':
-        assert role == 'default'
-        GLOBAL_RANK[0] = 0
-        PER_ROLE_RANK['default'] = {0:0}
+        if role == 'default':
+            GLOBAL_RANK[0] = 0
+            PER_ROLE_RANK['default'] = {0:0}
+        IS_STANDALONE = True
+        return
+
+    PER_ROLE_RANK = {}
+    GLOBAL_RANK = {}
     # Register the current role. This blocks until all clients register themselves.
     client_id = rpc.get_rank()
@@ -180,11 +188,17 @@ def get_global_rank():
     The rank can globally identify the client process. For the client processes
     of the same role, their ranks are in a contiguous range.
     """
-    return GLOBAL_RANK[rpc.get_rank()]
+    if IS_STANDALONE:
+        return 0
+    else:
+        return GLOBAL_RANK[rpc.get_rank()]

 def get_rank(role):
     """Get the role-specific rank"""
-    return PER_ROLE_RANK[role][rpc.get_rank()]
+    if IS_STANDALONE:
+        return 0
+    else:
+        return PER_ROLE_RANK[role][rpc.get_rank()]

 def get_trainer_rank():
     """Get the rank of the current trainer process.
@@ -193,7 +207,10 @@ def get_trainer_rank():
     an error if it's called in the process of other roles.
     """
     assert CUR_ROLE == 'default'
-    return PER_ROLE_RANK['default'][rpc.get_rank()]
+    if IS_STANDALONE:
+        return 0
+    else:
+        return PER_ROLE_RANK['default'][rpc.get_rank()]

 def get_role():
     """Get the role of the current process"""
......
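The effect of the `IS_STANDALONE` flag is that rank queries no longer touch the RPC layer when no servers are running. A small sketch, accessing the role module the same way the tests do; this snippet is illustrative and not part of the diff:

```python
# Sketch: rank helpers short-circuit to 0 in standalone mode.
import os

import dgl
from dgl.distributed import role

os.environ['DGL_DIST_MODE'] = 'standalone'
dgl.distributed.initialize('ip_config.txt')   # init_role('default') sets IS_STANDALONE
assert role.get_global_rank() == 0            # no GLOBAL_RANK / rpc.get_rank() lookup
assert role.get_rank('default') == 0
dgl.distributed.exit_client()
```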
@@ -60,3 +60,6 @@ class KVClient(object):
             return self._pull_handlers[name](self._data, name, id_tensor)
         else:
             return F.gather_row(self._data[name], id_tensor)
+
+    def map_shared_data(self, partition_book):
+        '''Mapping shared-memory tensor from server to client.'''
@@ -139,7 +139,8 @@ def check_dist_graph(g, num_clients, num_nodes, num_edges):
     policy = dgl.distributed.PartitionPolicy('node', g.get_partition_book())
     grad_sum = dgl.distributed.DistTensor(g, (g.number_of_nodes(),), F.float32,
                                           'emb1_sum', policy)
-    assert np.all(F.asnumpy(grad_sum[nids]) == np.ones((len(nids), 1)) * num_clients)
+    if num_clients == 1:
+        assert np.all(F.asnumpy(grad_sum[nids]) == np.ones((len(nids), 1)) * num_clients)
     assert np.all(F.asnumpy(grad_sum[rest]) == np.zeros((len(rest), 1)))

     emb = DistEmbedding(g, g.number_of_nodes(), 1, 'emb2', emb_init)
@@ -240,11 +241,6 @@ def test_server_client():
 @unittest.skipIf(dgl.backend.backend_name == "tensorflow", reason="TF doesn't support some of operations in DistGraph")
 def test_standalone():
     os.environ['DGL_DIST_MODE'] = 'standalone'
-    # TODO(zhengda) this is a temporary fix. We need to make initialize work
-    # for standalone mode as well.
-    dgl.distributed.role.CUR_ROLE = 'default'
-    dgl.distributed.role.GLOBAL_RANK = {-1:0}
-    dgl.distributed.role.PER_ROLE_RANK['default'] = {-1:0}

     g = create_random_graph(10000)
     # Partition the graph
@@ -253,9 +249,12 @@ def test_standalone():
     g.ndata['features'] = F.unsqueeze(F.arange(0, g.number_of_nodes()), 1)
     g.edata['features'] = F.unsqueeze(F.arange(0, g.number_of_edges()), 1)
     partition_graph(g, graph_name, num_parts, '/tmp/dist_graph')
+
+    dgl.distributed.initialize("kv_ip_config.txt")
     dist_g = DistGraph("kv_ip_config.txt", graph_name,
                        part_config='/tmp/dist_graph/{}.json'.format(graph_name))
     check_dist_graph(dist_g, 1, g.number_of_nodes(), g.number_of_edges())
+    dgl.distributed.exit_client() # this is needed since there's two test here in one process

 @unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
 def test_split():
......
@@ -35,6 +35,7 @@ def start_find_edges_client(rank, tmpdir, disable_shared_mem, eids):
     gpb = None
     if disable_shared_mem:
         _, _, _, gpb, _ = load_partition(tmpdir / 'test_find_edges.json', rank)
+    dgl.distributed.initialize("rpc_ip_config.txt")
     dist_graph = DistGraph("rpc_ip_config.txt", "test_find_edges", gpb=gpb)
     u, v = find_edges(dist_graph, eids)
     dgl.distributed.exit_client()
@@ -172,7 +173,9 @@ def check_standalone_sampling(tmpdir):
     num_hops = 1
     partition_graph(g, 'test_sampling', num_parts, tmpdir,
                     num_hops=num_hops, part_method='metis', reshuffle=False)

     os.environ['DGL_DIST_MODE'] = 'standalone'
+    dgl.distributed.initialize("rpc_ip_config.txt")
     dist_graph = DistGraph(None, "test_sampling", part_config=tmpdir / 'test_sampling.json')
     sampled_graph = sample_neighbors(dist_graph, [0, 10, 99, 66, 1024, 2008], 3)
@@ -182,6 +185,7 @@ def check_standalone_sampling(tmpdir):
     eids = g.edge_ids(src, dst)
     assert np.array_equal(
         F.asnumpy(sampled_graph.edata[dgl.EID]), F.asnumpy(eids))
+    dgl.distributed.exit_client()

 @unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
 @unittest.skipIf(dgl.backend.backend_name == 'tensorflow', reason='Not support tensorflow for now')
@@ -259,4 +263,4 @@ if __name__ == "__main__":
         check_rpc_sampling(Path(tmpdirname), 2)
         check_rpc_sampling(Path(tmpdirname), 1)
         check_rpc_find_edges(Path(tmpdirname), 2)
-        check_rpc_find_edges(Path(tmpdirname), 1)
\ No newline at end of file
+        check_rpc_find_edges(Path(tmpdirname), 1)
@@ -48,15 +48,15 @@ def start_server(rank, tmpdir, disable_shared_mem, num_clients):
 def start_client(rank, tmpdir, disable_shared_mem, num_workers, drop_last):
     import dgl
     import torch as th
-    os.environ['DGL_DIST_MODE'] = 'distributed'
-    dgl.distributed.initialize("mp_ip_config.txt", num_workers=4)
+    dgl.distributed.initialize("mp_ip_config.txt", num_workers=num_workers)
     gpb = None
     if disable_shared_mem:
         _, _, _, gpb, _ = load_partition(tmpdir / 'test_sampling.json', rank)
     num_nodes_to_sample = 202
     batch_size = 32
     train_nid = th.arange(num_nodes_to_sample)
-    dist_graph = DistGraph("mp_ip_config.txt", "test_mp", gpb=gpb)
+    dist_graph = DistGraph("mp_ip_config.txt", "test_mp", gpb=gpb,
+                           part_config=tmpdir / 'test_sampling.json')

     # Create sampler
     sampler = NeighborSampler(dist_graph, [5, 10],
@@ -83,22 +83,40 @@ def start_client(rank, tmpdir, disable_shared_mem, num_workers, drop_last):
         dst_nodes_id = block.dstdata[dgl.NID][o_dst]
         has_edges = groundtruth_g.has_edges_between(src_nodes_id, dst_nodes_id)
         assert np.all(F.asnumpy(has_edges))
-        print(np.unique(np.sort(F.asnumpy(dst_nodes_id))))
         max_nid.append(np.max(F.asnumpy(dst_nodes_id)))
         # assert np.all(np.unique(np.sort(F.asnumpy(dst_nodes_id))) == np.arange(idx, batch_size))
     if drop_last:
         assert np.max(max_nid) == num_nodes_to_sample - 1 - num_nodes_to_sample % batch_size
     else:
         assert np.max(max_nid) == num_nodes_to_sample - 1

+@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
+@unittest.skipIf(dgl.backend.backend_name != 'pytorch', reason='Only support PyTorch for now')
+def test_standalone(tmpdir):
+    ip_config = open("mp_ip_config.txt", "w")
+    for _ in range(1):
+        ip_config.write('{} 1\n'.format(get_local_usable_addr()))
+    ip_config.close()
+
+    g = CitationGraphDataset("cora")[0]
+    print(g.idtype)
+    num_parts = 1
+    num_hops = 1
+
+    partition_graph(g, 'test_sampling', num_parts, tmpdir,
+                    num_hops=num_hops, part_method='metis', reshuffle=False)
+    os.environ['DGL_DIST_MODE'] = 'standalone'
+    start_client(0, tmpdir, False, 2, True)
     dgl.distributed.exit_client() # this is needed since there's two test here in one process

 @unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
 @unittest.skipIf(dgl.backend.backend_name != 'pytorch', reason='Only support PyTorch for now')
 @pytest.mark.parametrize("num_server", [3])
+@pytest.mark.parametrize("num_workers", [0, 4])
 @pytest.mark.parametrize("drop_last", [True, False])
-def test_dist_dataloader(tmpdir, num_server, drop_last):
+def test_dist_dataloader(tmpdir, num_server, num_workers, drop_last):
     ip_config = open("mp_ip_config.txt", "w")
     for _ in range(num_server):
         ip_config.write('{} 1\n'.format(get_local_usable_addr()))
@@ -112,7 +130,6 @@ def test_dist_dataloader(tmpdir, num_server, drop_last):
     partition_graph(g, 'test_sampling', num_parts, tmpdir,
                     num_hops=num_hops, part_method='metis', reshuffle=False)

-    num_workers = 4
     pserver_list = []
     ctx = mp.get_context('spawn')
     for i in range(num_server):
@@ -123,9 +140,13 @@ def test_dist_dataloader(tmpdir, num_server, drop_last):
         pserver_list.append(p)

     time.sleep(3)
-    sampled_graph = start_client(0, tmpdir, num_server > 1, num_workers, drop_last)
+    os.environ['DGL_DIST_MODE'] = 'distributed'
+    start_client(0, tmpdir, num_server > 1, num_workers, drop_last)
+    dgl.distributed.exit_client() # this is needed since there's two test here in one process

 if __name__ == "__main__":
     import tempfile
     with tempfile.TemporaryDirectory() as tmpdirname:
-        test_dist_dataloader(Path(tmpdirname), 3, True)
+        test_standalone(Path(tmpdirname))
+        test_dist_dataloader(Path(tmpdirname), 3, 0, True)
+        test_dist_dataloader(Path(tmpdirname), 3, 4, True)
@@ -153,6 +153,7 @@ def start_client(ip_config):
         assert_array_equal(F.asnumpy(res.tensor), F.asnumpy(TENSOR))

 def test_serialize():
+    os.environ['DGL_DIST_MODE'] = 'distributed'
     from dgl.distributed.rpc import serialize_to_payload, deserialize_from_payload
     SERVICE_ID = 12345
     dgl.distributed.register_service(SERVICE_ID, MyRequest, MyResponse)
@@ -170,6 +171,7 @@ def test_serialize():
     assert res.x == res1.x

 def test_rpc_msg():
+    os.environ['DGL_DIST_MODE'] = 'distributed'
     from dgl.distributed.rpc import serialize_to_payload, deserialize_from_payload, RPCMessage
     SERVICE_ID = 32452
     dgl.distributed.register_service(SERVICE_ID, MyRequest, MyResponse)
@@ -186,6 +188,7 @@ def test_rpc_msg():
 @unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
 def test_rpc():
+    os.environ['DGL_DIST_MODE'] = 'distributed'
     ip_config = open("rpc_ip_config.txt", "w")
     ip_addr = get_local_usable_addr()
     ip_config.write('%s 1\n' % ip_addr)
@@ -201,6 +204,7 @@ def test_rpc():
 @unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
 def test_multi_client():
+    os.environ['DGL_DIST_MODE'] = 'distributed'
     ip_config = open("rpc_ip_config_mul_client.txt", "w")
     ip_addr = get_local_usable_addr()
     ip_config.write('%s 1\n' % ip_addr)
......
@@ -100,7 +100,7 @@ def main():
                         the contents of current directory will be rsyncd')
     parser.add_argument('--num_trainers', type=int,
                         help='The number of trainer processes per machine')
-    parser.add_argument('--num_samplers', type=int,
+    parser.add_argument('--num_samplers', type=int, default=0,
                         help='The number of sampler processes per trainer process')
     parser.add_argument('--part_config', type=str,
                         help='The file (in workspace) of the partition config')
......