Unverified commit ea48ce7a, authored by Hongzhi (Steve), Chen and committed by GitHub

[Misc] Black auto fix. (#4697)


Co-authored-by: Steve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent bd3fe59e
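A change like this is typically produced by running Black over the whole repository. The short sketch below is not part of the commit; it assumes only that the black package is installed and shows, programmatically, the same normalization Black applies throughout the diff (single quotes turned into double quotes, long calls wrapped with trailing commas):

import black

# Illustrative only (not from this commit): apply Black's formatting to a snippet.
source = "x = {'a': 1,'b': 2}\n"
formatted = black.format_str(source, mode=black.FileMode())
print(formatted)  # x = {"a": 1, "b": 2}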
import enum
import json
import os
import tempfile
from pathlib import Path
from urllib.parse import urljoin, urlparse

import pytest
import requests


class JobStatus(enum.Enum):
@@ -27,8 +28,8 @@ JENKINS_STATUS_MAPPING = {
assert "BUILD_URL" in os.environ, "Are you in the Jenkins environment?"
job_link = os.environ["BUILD_URL"]
response = requests.get("{}wfapi".format(job_link), verify=False).json()
domain = "{uri.scheme}://{uri.netloc}/".format(uri=urlparse(job_link))
stages = response["stages"]
final_dict = {}
@@ -41,37 +42,38 @@ def get_jenkins_json(path):
for stage in stages:
    link = stage["_links"]["self"]["href"]
    stage_name = stage["name"]
    res = requests.get(urljoin(domain, link), verify=False).json()
    nodes = res["stageFlowNodes"]
    for node in nodes:
        nodes_dict[node["id"]] = node
        nodes_dict[node["id"]]["stageName"] = stage_name


def get_node_full_name(node, node_dict):
    name = ""
    while "parentNodes" in node:
        name = name + "/" + node["name"]
        id = node["parentNodes"][0]
        if id in nodes_dict:
            node = node_dict[id]
        else:
            break
    return name


for key, node in nodes_dict.items():
    logs = get_jenkins_json(node["_links"]["log"]["href"]).get("text", "")
    node_name = node["name"]
    if "Post Actions" in node["stageName"]:
        continue
    node_status = node["status"]
    id = node["id"]
    full_name = get_node_full_name(node, nodes_dict)
    final_dict["{}_{}/{}".format(id, node["stageName"], full_name)] = {
        "status": JENKINS_STATUS_MAPPING[node_status],
        "logs": logs,
    }


JOB_NAME = os.getenv("JOB_NAME")
@@ -85,15 +87,18 @@ prefix = f"https://dgl-ci-result.s3.us-west-2.amazonaws.com/{JOB_NAME}/{BUILD_NU
def test_generate_report(test_name):
    os.makedirs("./logs_dir/", exist_ok=True)
    tmp = tempfile.NamedTemporaryFile(
        mode="w", delete=False, suffix=".log", dir="./logs_dir/"
    )
    tmp.write(final_dict[test_name]["logs"])
    filename = Path(tmp.name).name
    # print(final_dict[test_name]["logs"])
    print("Log path: {}".format(prefix + filename))
    if final_dict[test_name]["status"] == JobStatus.FAIL:
        pytest.fail(
            "Test failed. Please see the log at {}".format(prefix + filename)
        )
    elif final_dict[test_name]["status"] == JobStatus.SKIP:
        pytest.skip(
            "Test skipped. Please see the log at {}".format(prefix + filename)
        )
import os
import requests

JOB_NAME = os.getenv("JOB_NAME")
BUILD_NUMBER = os.getenv("BUILD_NUMBER")
BUILD_ID = os.getenv("BUILD_ID")
COMMIT = os.getenv("GIT_COMMIT")
job_link = os.environ["BUILD_URL"]
response = requests.get("{}wfapi".format(job_link), verify=False).json()
status = "✅ CI test succeeded"
for v in response["stages"]:
    if v["status"] in ["FAILED", "ABORTED"]:
        status = "❌ CI test failed in Stage [{}].".format(v["name"])
        break
comment = f""" Commit ID: {COMMIT}\n
......
def test():
    pass


if __name__ == "__main__":
    test()
\ No newline at end of file
from copy import deepcopy

import backend as F
import networkx as nx
import numpy as np
import pytest
import scipy as sp
import tensorflow as tf
from tensorflow.keras import layers
from test_utils import parametrize_idtype
from test_utils.graph_cases import (
    get_cases,
    random_bipartite,
    random_dglgraph,
    random_graph,
)

import dgl
import dgl.function as fn
import dgl.nn.tensorflow as nn
def _AXWb(A, X, W, b):
    X = tf.matmul(X, W)
    Y = tf.reshape(tf.matmul(A, tf.reshape(X, (X.shape[0], -1))), X.shape)
    return Y + b


@pytest.mark.parametrize("out_dim", [1, 2])
def test_graph_conv(out_dim):
    g = dgl.DGLGraph(nx.path_graph(3)).to(F.ctx())
    ctx = F.ctx()
    adj = tf.sparse.to_dense(
        tf.sparse.reorder(g.adjacency_matrix(transpose=True, ctx=ctx))
    )
    conv = nn.GraphConv(5, out_dim, norm="none", bias=True)
    # conv = conv
    print(conv)
    # test#1: basic
@@ -72,12 +82,16 @@ def test_graph_conv(out_dim):
    # new_weight = conv.weight.data
    # assert not F.allclose(old_weight, new_weight)
@parametrize_idtype
@pytest.mark.parametrize(
    "g",
    get_cases(["homo", "block-bipartite"], exclude=["zero-degree", "dglgraph"]),
)
@pytest.mark.parametrize("norm", ["none", "both", "right", "left"])
@pytest.mark.parametrize("weight", [True, False])
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("out_dim", [1, 2])
def test_graph_conv2(idtype, g, norm, weight, bias, out_dim):
    g = g.astype(idtype).to(F.ctx())
    conv = nn.GraphConv(5, out_dim, norm=norm, weight=weight, bias=bias)
@@ -92,12 +106,15 @@ def test_graph_conv2(idtype, g, norm, weight, bias, out_dim):
    h_out = conv(g, h, weight=ext_w)
    assert h_out.shape == (ndst, out_dim)


@parametrize_idtype
@pytest.mark.parametrize(
    "g", get_cases(["bipartite"], exclude=["zero-degree", "dglgraph"])
)
@pytest.mark.parametrize("norm", ["none", "both", "right"])
@pytest.mark.parametrize("weight", [True, False])
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("out_dim", [1, 2])
def test_graph_conv2_bi(idtype, g, norm, weight, bias, out_dim):
    g = g.astype(idtype).to(F.ctx())
    conv = nn.GraphConv(5, out_dim, norm=norm, weight=weight, bias=bias)
@@ -112,6 +129,7 @@ def test_graph_conv2_bi(idtype, g, norm, weight, bias, out_dim):
    h_out = conv(g, (h, h_dst), weight=ext_w)
    assert h_out.shape == (ndst, out_dim)
def test_simple_pool():
    ctx = F.ctx()
    g = dgl.DGLGraph(nx.path_graph(15)).to(F.ctx())
@@ -119,7 +137,7 @@ def test_simple_pool():
    sum_pool = nn.SumPooling()
    avg_pool = nn.AvgPooling()
    max_pool = nn.MaxPooling()
    sort_pool = nn.SortPooling(10)  # k = 10
    print(sum_pool, avg_pool, max_pool, sort_pool)
    # test#1: basic
@@ -138,32 +156,48 @@ def test_simple_pool():
    bg = dgl.batch([g, g_, g, g_, g])
    h0 = F.randn((bg.number_of_nodes(), 5))
    h1 = sum_pool(bg, h0)
    truth = tf.stack(
        [
            F.sum(h0[:15], 0),
            F.sum(h0[15:20], 0),
            F.sum(h0[20:35], 0),
            F.sum(h0[35:40], 0),
            F.sum(h0[40:55], 0),
        ],
        0,
    )
    assert F.allclose(h1, truth)
    h1 = avg_pool(bg, h0)
    truth = tf.stack(
        [
            F.mean(h0[:15], 0),
            F.mean(h0[15:20], 0),
            F.mean(h0[20:35], 0),
            F.mean(h0[35:40], 0),
            F.mean(h0[40:55], 0),
        ],
        0,
    )
    assert F.allclose(h1, truth)
    h1 = max_pool(bg, h0)
    truth = tf.stack(
        [
            F.max(h0[:15], 0),
            F.max(h0[15:20], 0),
            F.max(h0[20:35], 0),
            F.max(h0[35:40], 0),
            F.max(h0[40:55], 0),
        ],
        0,
    )
    assert F.allclose(h1, truth)
    h1 = sort_pool(bg, h0)
    assert h1.shape[0] == 5 and h1.shape[1] == 10 * 5 and h1.ndim == 2
def test_glob_att_pool():
    g = dgl.DGLGraph(nx.path_graph(10)).to(F.ctx())
@@ -182,10 +216,12 @@ def test_glob_att_pool():
    assert h1.shape[0] == 4 and h1.shape[1] == 10 and h1.ndim == 2


@pytest.mark.parametrize("O", [1, 2, 8])
def test_rgcn(O):
    etype = []
    g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True).to(
        F.ctx()
    )
    # 5 etypes
    R = 5
    for i in range(g.number_of_edges()):
@@ -262,10 +298,13 @@ def test_rgcn(O):
    assert list(h_new_low.shape) == [100, O]
    assert F.allclose(h_new, h_new_low)
@parametrize_idtype
@pytest.mark.parametrize(
    "g", get_cases(["homo", "block-bipartite"], exclude=["zero-degree"])
)
@pytest.mark.parametrize("out_dim", [1, 2])
@pytest.mark.parametrize("num_heads", [1, 4])
def test_gat_conv(g, idtype, out_dim, num_heads):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
@@ -280,24 +319,29 @@ def test_gat_conv(g, idtype, out_dim, num_heads):
    gat = nn.GATConv(5, out_dim, num_heads, residual=True)
    h = gat(g, feat)


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["bipartite"], exclude=["zero-degree"]))
@pytest.mark.parametrize("out_dim", [1, 2])
@pytest.mark.parametrize("num_heads", [1, 4])
def test_gat_conv_bi(g, idtype, out_dim, num_heads):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    gat = nn.GATConv(5, out_dim, num_heads)
    feat = (
        F.randn((g.number_of_src_nodes(), 5)),
        F.randn((g.number_of_dst_nodes(), 5)),
    )
    h = gat(g, feat)
    assert h.shape == (g.number_of_dst_nodes(), num_heads, out_dim)
    _, a = gat(g, feat, get_attention=True)
    assert a.shape == (g.number_of_edges(), num_heads, 1)
@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo", "block-bipartite"]))
@pytest.mark.parametrize("aggre_type", ["mean", "pool", "gcn"])
@pytest.mark.parametrize("out_dim", [1, 10])
def test_sage_conv(idtype, g, aggre_type, out_dim):
    g = g.astype(idtype).to(F.ctx())
    sage = nn.SAGEConv(5, out_dim, aggre_type)
@@ -305,41 +349,49 @@ def test_sage_conv(idtype, g, aggre_type, out_dim):
    h = sage(g, feat)
    assert h.shape[-1] == out_dim


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["bipartite"]))
@pytest.mark.parametrize("aggre_type", ["mean", "pool", "gcn"])
@pytest.mark.parametrize("out_dim", [1, 2])
def test_sage_conv_bi(idtype, g, aggre_type, out_dim):
    g = g.astype(idtype).to(F.ctx())
    dst_dim = 5 if aggre_type != "gcn" else 10
    sage = nn.SAGEConv((10, dst_dim), out_dim, aggre_type)
    feat = (
        F.randn((g.number_of_src_nodes(), 10)),
        F.randn((g.number_of_dst_nodes(), dst_dim)),
    )
    h = sage(g, feat)
    assert h.shape[-1] == out_dim
    assert h.shape[0] == g.number_of_dst_nodes()


@parametrize_idtype
@pytest.mark.parametrize("aggre_type", ["mean", "pool", "gcn"])
@pytest.mark.parametrize("out_dim", [1, 2])
def test_sage_conv_bi_empty(idtype, aggre_type, out_dim):
    # Test the case for graphs without edges
    g = dgl.heterograph({("_U", "_E", "_V"): ([], [])}, {"_U": 5, "_V": 3}).to(
        F.ctx()
    )
    g = g.astype(idtype).to(F.ctx())
    sage = nn.SAGEConv((3, 3), out_dim, "gcn")
    feat = (F.randn((5, 3)), F.randn((3, 3)))
    h = sage(g, feat)
    assert h.shape[-1] == out_dim
    assert h.shape[0] == 3
    for aggre_type in ["mean", "pool", "lstm"]:
        sage = nn.SAGEConv((3, 1), out_dim, aggre_type)
        feat = (F.randn((5, 3)), F.randn((3, 1)))
        h = sage(g, feat)
        assert h.shape[-1] == out_dim
        assert h.shape[0] == 3
@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo"], exclude=["zero-degree"]))
@pytest.mark.parametrize("out_dim", [1, 2])
def test_sgc_conv(g, idtype, out_dim):
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
@@ -357,8 +409,9 @@ def test_sgc_conv(g, idtype, out_dim):
    assert F.allclose(h_0, h_1)
    assert h_0.shape[-1] == out_dim


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo"], exclude=["zero-degree"]))
def test_appnp_conv(g, idtype):
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
@@ -368,36 +421,38 @@ def test_appnp_conv(g, idtype):
    h = appnp(g, feat)
    assert h.shape[-1] == 5
@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["homo", "block-bipartite"]))
@pytest.mark.parametrize("aggregator_type", ["mean", "max", "sum"])
def test_gin_conv(g, idtype, aggregator_type):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    gin = nn.GINConv(tf.keras.layers.Dense(12), aggregator_type)
    feat = F.randn((g.number_of_src_nodes(), 5))
    h = gin(g, feat)
    assert h.shape == (g.number_of_dst_nodes(), 12)


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["bipartite"]))
@pytest.mark.parametrize("aggregator_type", ["mean", "max", "sum"])
def test_gin_conv_bi(g, idtype, aggregator_type):
    g = g.astype(idtype).to(F.ctx())
    gin = nn.GINConv(tf.keras.layers.Dense(12), aggregator_type)
    feat = (
        F.randn((g.number_of_src_nodes(), 5)),
        F.randn((g.number_of_dst_nodes(), 5)),
    )
    h = gin(g, feat)
    assert h.shape == (g.number_of_dst_nodes(), 12)
@parametrize_idtype
@pytest.mark.parametrize(
    "g", get_cases(["homo", "block-bipartite"], exclude=["zero-degree"])
)
@pytest.mark.parametrize("out_dim", [1, 2])
def test_edge_conv(g, idtype, out_dim):
    g = g.astype(idtype).to(F.ctx())
    edge_conv = nn.EdgeConv(out_dim)
@@ -406,9 +461,10 @@ def test_edge_conv(g, idtype, out_dim):
    h1 = edge_conv(g, h0)
    assert h1.shape == (g.number_of_dst_nodes(), out_dim)


@parametrize_idtype
@pytest.mark.parametrize("g", get_cases(["bipartite"], exclude=["zero-degree"]))
@pytest.mark.parametrize("out_dim", [1, 2])
def test_edge_conv_bi(g, idtype, out_dim):
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
@@ -419,56 +475,73 @@ def test_edge_conv_bi(g, idtype, out_dim):
    h1 = edge_conv(g, (h0, x0))
    assert h1.shape == (g.number_of_dst_nodes(), out_dim)


def myagg(alist, dsttype):
    rst = alist[0]
    for i in range(1, len(alist)):
        rst = rst + (i + 1) * alist[i]
    return rst
@parametrize_idtype
@pytest.mark.parametrize("agg", ["sum", "max", "min", "mean", "stack", myagg])
def test_hetero_conv(agg, idtype):
    g = dgl.heterograph(
        {
            ("user", "follows", "user"): ([0, 0, 2, 1], [1, 2, 1, 3]),
            ("user", "plays", "game"): ([0, 0, 0, 1, 2], [0, 2, 3, 0, 2]),
            ("store", "sells", "game"): ([0, 0, 1, 1], [0, 3, 1, 2]),
        },
        idtype=idtype,
        device=F.ctx(),
    )
    conv = nn.HeteroGraphConv(
        {
            "follows": nn.GraphConv(2, 3, allow_zero_in_degree=True),
            "plays": nn.GraphConv(2, 4, allow_zero_in_degree=True),
            "sells": nn.GraphConv(3, 4, allow_zero_in_degree=True),
        },
        agg,
    )
    uf = F.randn((4, 2))
    gf = F.randn((4, 4))
    sf = F.randn((2, 3))
    h = conv(g, {"user": uf, "store": sf, "game": gf})
    assert set(h.keys()) == {"user", "game"}
    if agg != "stack":
        assert h["user"].shape == (4, 3)
        assert h["game"].shape == (4, 4)
    else:
        assert h["user"].shape == (4, 1, 3)
        assert h["game"].shape == (4, 2, 4)
    block = dgl.to_block(
        g.to(F.cpu()), {"user": [0, 1, 2, 3], "game": [0, 1, 2, 3], "store": []}
    ).to(F.ctx())
    h = conv(
        block,
        (
            {"user": uf, "game": gf, "store": sf},
            {"user": uf, "game": gf, "store": sf[0:0]},
        ),
    )
    assert set(h.keys()) == {"user", "game"}
    if agg != "stack":
        assert h["user"].shape == (4, 3)
        assert h["game"].shape == (4, 4)
    else:
        assert h["user"].shape == (4, 1, 3)
        assert h["game"].shape == (4, 2, 4)
    h = conv(block, {"user": uf, "game": gf, "store": sf})
    assert set(h.keys()) == {"user", "game"}
    if agg != "stack":
        assert h["user"].shape == (4, 3)
        assert h["game"].shape == (4, 4)
    else:
        assert h["user"].shape == (4, 1, 3)
        assert h["game"].shape == (4, 2, 4)

    # test with mod args
    class MyMod(tf.keras.layers.Layer):
@@ -478,23 +551,28 @@ def test_hetero_conv(agg, idtype):
            self.carg2 = 0
            self.s1 = s1
            self.s2 = s2

        def call(self, g, h, arg1=None, *, arg2=None):
            if arg1 is not None:
                self.carg1 += 1
            if arg2 is not None:
                self.carg2 += 1
            return tf.zeros((g.number_of_dst_nodes(), self.s2))

    mod1 = MyMod(2, 3)
    mod2 = MyMod(2, 4)
    mod3 = MyMod(3, 4)
    conv = nn.HeteroGraphConv(
        {"follows": mod1, "plays": mod2, "sells": mod3}, agg
    )
    mod_args = {"follows": (1,), "plays": (1,)}
    mod_kwargs = {"sells": {"arg2": "abc"}}
    h = conv(
        g,
        {"user": uf, "game": gf, "store": sf},
        mod_args=mod_args,
        mod_kwargs=mod_kwargs,
    )
    assert mod1.carg1 == 1
    assert mod1.carg2 == 0
    assert mod2.carg1 == 1
@@ -502,28 +580,38 @@ def test_hetero_conv(agg, idtype):
    assert mod3.carg1 == 0
    assert mod3.carg2 == 1

    # conv on graph without any edges
    for etype in g.etypes:
        g = dgl.remove_edges(g, g.edges(form="eid", etype=etype), etype=etype)
    assert g.num_edges() == 0
    h = conv(g, {"user": uf, "game": gf, "store": sf})
    assert set(h.keys()) == {"user", "game"}
    block = dgl.to_block(
        g.to(F.cpu()), {"user": [0, 1, 2, 3], "game": [0, 1, 2, 3], "store": []}
    ).to(F.ctx())
    h = conv(
        block,
        (
            {"user": uf, "game": gf, "store": sf},
            {"user": uf, "game": gf, "store": sf[0:0]},
        ),
    )
    assert set(h.keys()) == {"user", "game"}
@pytest.mark.parametrize("out_dim", [1, 2])
def test_dense_cheb_conv(out_dim):
    for k in range(3, 4):
        ctx = F.ctx()
        g = dgl.DGLGraph(
            sp.sparse.random(100, 100, density=0.1, random_state=42)
        )
        g = g.to(ctx)
        adj = tf.sparse.to_dense(
            tf.sparse.reorder(g.adjacency_matrix(transpose=True, ctx=ctx))
        )
        cheb = nn.ChebConv(5, out_dim, k, None, bias=True)
        dense_cheb = nn.DenseChebConv(5, out_dim, k, bias=True)
@@ -540,7 +628,7 @@ def test_dense_cheb_conv(out_dim):
        assert F.allclose(out_cheb, out_dense_cheb)


if __name__ == "__main__":
    test_graph_conv()
    # test_set2set()
    test_glob_att_pool()
......
import backend as F
import pytest

parametrize_idtype = pytest.mark.parametrize("idtype", [F.int32, F.int64])
......
import backend as F

import dgl

__all__ = ["check_graph_equal"]


def check_graph_equal(g1, g2, *, check_idtype=True, check_feature=True):
    assert g1.device == g1.device
    if check_idtype:
        assert g1.idtype == g2.idtype
@@ -26,8 +26,8 @@ def check_graph_equal(g1, g2, *,
    for ety in g1.canonical_etypes:
        assert g1.number_of_edges(ety) == g2.number_of_edges(ety)
        assert F.allclose(g1.batch_num_edges(ety), g2.batch_num_edges(ety))
        src1, dst1, eid1 = g1.edges(etype=ety, form="all")
        src2, dst2, eid2 = g2.edges(etype=ety, form="all")
        if check_idtype:
            assert F.allclose(src1, src2)
            assert F.allclose(dst1, dst2)
@@ -42,9 +42,13 @@ def check_graph_equal(g1, g2, *,
        if g1.number_of_nodes(nty) == 0:
            continue
        for feat_name in g1.nodes[nty].data.keys():
            assert F.allclose(
                g1.nodes[nty].data[feat_name], g2.nodes[nty].data[feat_name]
            )
    for ety in g1.canonical_etypes:
        if g1.number_of_edges(ety) == 0:
            continue
        for feat_name in g2.edges[ety].data.keys():
            assert F.allclose(
                g1.edges[ety].data[feat_name], g2.edges[ety].data[feat_name]
            )
@@ -3,14 +3,15 @@ import os
import tempfile
import unittest

import numpy as np
import pytest
import torch
from chunk_graph import chunk_graph
from create_chunked_dataset import create_chunked_dataset

import dgl
from dgl.data.utils import load_graphs, load_tensors
@pytest.mark.parametrize("num_chunks", [1, 8]) @pytest.mark.parametrize("num_chunks", [1, 8])
def test_chunk_graph(num_chunks): def test_chunk_graph(num_chunks):
...@@ -19,43 +20,43 @@ def test_chunk_graph(num_chunks): ...@@ -19,43 +20,43 @@ def test_chunk_graph(num_chunks):
g = create_chunked_dataset(root_dir, num_chunks, include_edge_data=True) g = create_chunked_dataset(root_dir, num_chunks, include_edge_data=True)
num_cite_edges = g.number_of_edges('cites') num_cite_edges = g.number_of_edges("cites")
num_write_edges = g.number_of_edges('writes') num_write_edges = g.number_of_edges("writes")
num_affiliate_edges = g.number_of_edges('affiliated_with') num_affiliate_edges = g.number_of_edges("affiliated_with")
num_institutions = g.number_of_nodes('institution') num_institutions = g.number_of_nodes("institution")
num_authors = g.number_of_nodes('author') num_authors = g.number_of_nodes("author")
num_papers = g.number_of_nodes('paper') num_papers = g.number_of_nodes("paper")
# check metadata.json # check metadata.json
output_dir = os.path.join(root_dir, 'chunked-data') output_dir = os.path.join(root_dir, "chunked-data")
json_file = os.path.join(output_dir, 'metadata.json') json_file = os.path.join(output_dir, "metadata.json")
assert os.path.isfile(json_file) assert os.path.isfile(json_file)
with open(json_file, 'rb') as f: with open(json_file, "rb") as f:
meta_data = json.load(f) meta_data = json.load(f)
assert meta_data['graph_name'] == 'mag240m' assert meta_data["graph_name"] == "mag240m"
assert len(meta_data['num_nodes_per_chunk'][0]) == num_chunks assert len(meta_data["num_nodes_per_chunk"][0]) == num_chunks
# check edge_index # check edge_index
output_edge_index_dir = os.path.join(output_dir, 'edge_index') output_edge_index_dir = os.path.join(output_dir, "edge_index")
for utype, etype, vtype in g.canonical_etypes: for utype, etype, vtype in g.canonical_etypes:
fname = ':'.join([utype, etype, vtype]) fname = ":".join([utype, etype, vtype])
for i in range(num_chunks): for i in range(num_chunks):
chunk_f_name = os.path.join( chunk_f_name = os.path.join(
output_edge_index_dir, fname + str(i) + '.txt' output_edge_index_dir, fname + str(i) + ".txt"
) )
assert os.path.isfile(chunk_f_name) assert os.path.isfile(chunk_f_name)
with open(chunk_f_name, 'r') as f: with open(chunk_f_name, "r") as f:
header = f.readline() header = f.readline()
num1, num2 = header.rstrip().split(' ') num1, num2 = header.rstrip().split(" ")
assert isinstance(int(num1), int) assert isinstance(int(num1), int)
assert isinstance(int(num2), int) assert isinstance(int(num2), int)
# check node_data # check node_data
output_node_data_dir = os.path.join(output_dir, 'node_data', 'paper') output_node_data_dir = os.path.join(output_dir, "node_data", "paper")
for feat in ['feat', 'label', 'year']: for feat in ["feat", "label", "year"]:
for i in range(num_chunks): for i in range(num_chunks):
chunk_f_name = '{}-{}.npy'.format(feat, i) chunk_f_name = "{}-{}.npy".format(feat, i)
chunk_f_name = os.path.join(output_node_data_dir, chunk_f_name) chunk_f_name = os.path.join(output_node_data_dir, chunk_f_name)
assert os.path.isfile(chunk_f_name) assert os.path.isfile(chunk_f_name)
feat_array = np.load(chunk_f_name) feat_array = np.load(chunk_f_name)
...@@ -63,19 +64,19 @@ def test_chunk_graph(num_chunks): ...@@ -63,19 +64,19 @@ def test_chunk_graph(num_chunks):
# check edge_data # check edge_data
num_edges = { num_edges = {
'paper:cites:paper': num_cite_edges, "paper:cites:paper": num_cite_edges,
'author:writes:paper': num_write_edges, "author:writes:paper": num_write_edges,
'paper:rev_writes:author': num_write_edges, "paper:rev_writes:author": num_write_edges,
} }
output_edge_data_dir = os.path.join(output_dir, 'edge_data') output_edge_data_dir = os.path.join(output_dir, "edge_data")
for etype, feat in [ for etype, feat in [
['paper:cites:paper', 'count'], ["paper:cites:paper", "count"],
['author:writes:paper', 'year'], ["author:writes:paper", "year"],
['paper:rev_writes:author', 'year'], ["paper:rev_writes:author", "year"],
]: ]:
output_edge_sub_dir = os.path.join(output_edge_data_dir, etype) output_edge_sub_dir = os.path.join(output_edge_data_dir, etype)
for i in range(num_chunks): for i in range(num_chunks):
chunk_f_name = '{}-{}.npy'.format(feat, i) chunk_f_name = "{}-{}.npy".format(feat, i)
chunk_f_name = os.path.join(output_edge_sub_dir, chunk_f_name) chunk_f_name = os.path.join(output_edge_sub_dir, chunk_f_name)
assert os.path.isfile(chunk_f_name) assert os.path.isfile(chunk_f_name)
feat_array = np.load(chunk_f_name) feat_array = np.load(chunk_f_name)
@@ -100,63 +101,63 @@ def test_part_pipeline(num_chunks, num_parts):
    all_ntypes = g.ntypes
    all_etypes = g.etypes

    num_cite_edges = g.number_of_edges("cites")
    num_write_edges = g.number_of_edges("writes")
    num_affiliate_edges = g.number_of_edges("affiliated_with")

    num_institutions = g.number_of_nodes("institution")
    num_authors = g.number_of_nodes("author")
    num_papers = g.number_of_nodes("paper")

    # Step1: graph partition
    in_dir = os.path.join(root_dir, "chunked-data")
    output_dir = os.path.join(root_dir, "parted_data")
    os.system(
        "python3 tools/partition_algo/random_partition.py "
        "--in_dir {} --out_dir {} --num_partitions {}".format(
            in_dir, output_dir, num_parts
        )
    )
    for ntype in ["author", "institution", "paper"]:
        fname = os.path.join(output_dir, "{}.txt".format(ntype))
        with open(fname, "r") as f:
            header = f.readline().rstrip()
            assert isinstance(int(header), int)

    # Step2: data dispatch
    partition_dir = os.path.join(root_dir, "parted_data")
    out_dir = os.path.join(root_dir, "partitioned")
    ip_config = os.path.join(root_dir, "ip_config.txt")
    with open(ip_config, "w") as f:
        for i in range(num_parts):
            f.write(f"127.0.0.{i + 1}\n")

    cmd = "python3 tools/dispatch_data.py"
    cmd += f" --in-dir {in_dir}"
    cmd += f" --partitions-dir {partition_dir}"
    cmd += f" --out-dir {out_dir}"
    cmd += f" --ip-config {ip_config}"
    cmd += " --process-group-timeout 60"
    cmd += " --save-orig-nids"
    cmd += " --save-orig-eids"
    os.system(cmd)

    # check metadata.json
    meta_fname = os.path.join(out_dir, "metadata.json")
    with open(meta_fname, "rb") as f:
        meta_data = json.load(f)

    for etype in all_etypes:
        assert len(meta_data["edge_map"][etype]) == num_parts
    assert meta_data["etypes"].keys() == set(all_etypes)
    assert meta_data["graph_name"] == "mag240m"
    for ntype in all_ntypes:
        assert len(meta_data["node_map"][ntype]) == num_parts
    assert meta_data["ntypes"].keys() == set(all_ntypes)
    assert meta_data["num_edges"] == g.num_edges()
    assert meta_data["num_nodes"] == g.num_nodes()
    assert meta_data["num_parts"] == num_parts

    edge_dict = {}
    edge_data_gold = {}
@@ -165,7 +166,7 @@ def test_part_pipeline(num_chunks, num_parts):
    # Create Id Map here.
    num_edges = 0
    for utype, etype, vtype in g.canonical_etypes:
        fname = ":".join([utype, etype, vtype])
        edge_dict[fname] = np.array(
            [num_edges, num_edges + g.number_of_edges(etype)]
        ).reshape(1, 2)
@@ -177,21 +178,21 @@ def test_part_pipeline(num_chunks, num_parts):
    # check edge_data
    num_edges = {
        "paper:cites:paper": num_cite_edges,
        "author:writes:paper": num_write_edges,
        "paper:rev_writes:author": num_write_edges,
    }
    output_dir = os.path.join(root_dir, "chunked-data")
    output_edge_data_dir = os.path.join(output_dir, "edge_data")
    for etype, feat in [
        ["paper:cites:paper", "count"],
        ["author:writes:paper", "year"],
        ["paper:rev_writes:author", "year"],
    ]:
        output_edge_sub_dir = os.path.join(output_edge_data_dir, etype)
        features = []
        for i in range(num_chunks):
            chunk_f_name = "{}-{}.npy".format(feat, i)
            chunk_f_name = os.path.join(
                output_edge_sub_dir, chunk_f_name
            )
@@ -199,54 +200,54 @@ def test_part_pipeline(num_chunks, num_parts):
            feat_array = np.load(chunk_f_name)
            assert feat_array.shape[0] == num_edges[etype] // num_chunks
            features.append(feat_array)
        edge_data_gold[etype + "/" + feat] = np.concatenate(features)

    for i in range(num_parts):
        sub_dir = "part-" + str(i)
        assert meta_data[sub_dir][
            "node_feats"
        ] == "part{}/node_feat.dgl".format(i)
        assert meta_data[sub_dir][
            "edge_feats"
        ] == "part{}/edge_feat.dgl".format(i)
        assert meta_data[sub_dir][
            "part_graph"
        ] == "part{}/graph.dgl".format(i)

        # check data
        sub_dir = os.path.join(out_dir, "part" + str(i))

        # graph.dgl
        fname = os.path.join(sub_dir, "graph.dgl")
        assert os.path.isfile(fname)
        g_list, data_dict = load_graphs(fname)
        part_g = g_list[0]
        assert isinstance(part_g, dgl.DGLGraph)

        # node_feat.dgl
        fname = os.path.join(sub_dir, "node_feat.dgl")
        assert os.path.isfile(fname)
        tensor_dict = load_tensors(fname)
        all_tensors = [
            "paper/feat",
            "paper/label",
            "paper/year",
            "paper/orig_ids",
        ]
        assert tensor_dict.keys() == set(all_tensors)
        for key in all_tensors:
            assert isinstance(tensor_dict[key], torch.Tensor)
        ndata_paper_orig_ids = tensor_dict["paper/orig_ids"]

        # orig_nids.dgl
        fname = os.path.join(sub_dir, "orig_nids.dgl")
        assert os.path.isfile(fname)
        orig_nids = load_tensors(fname)
        assert len(orig_nids.keys()) == 3
        assert torch.equal(ndata_paper_orig_ids, orig_nids["paper"])

        # orig_eids.dgl
        fname = os.path.join(sub_dir, "orig_eids.dgl")
        assert os.path.isfile(fname)
        orig_eids = load_tensors(fname)
        assert len(orig_eids.keys()) == 4
@@ -254,13 +255,13 @@ def test_part_pipeline(num_chunks, num_parts):
        if include_edge_data:
            # Read edge_feat.dgl
            fname = os.path.join(sub_dir, "edge_feat.dgl")
            assert os.path.isfile(fname)
            tensor_dict = load_tensors(fname)
            all_tensors = [
                "paper:cites:paper/count",
                "author:writes:paper/year",
                "paper:rev_writes:author/year",
            ]
            assert tensor_dict.keys() == set(all_tensors)
            for key in all_tensors:
......
import json
import os
import tempfile
import unittest

from launch import *


class TestWrapUdfInTorchDistLauncher(unittest.TestCase):
    """wrap_udf_in_torch_dist_launcher()"""
@@ -18,14 +20,18 @@ class TestWrapUdfInTorchDistLauncher(unittest.TestCase):
            master_addr="127.0.0.1",
            master_port=1234,
        )
        expected = (
            "python3.7 -m torch.distributed.launch "
            "--nproc_per_node=2 --nnodes=2 --node_rank=1 --master_addr=127.0.0.1 "
            "--master_port=1234 path/to/some/trainer.py arg1 arg2"
        )
        self.assertEqual(wrapped_udf_command, expected)

    def test_chained_udf(self):
        # test that a chained udf_command is properly handled
        udf_command = (
            "cd path/to && python3.7 path/to/some/trainer.py arg1 arg2"
        )
        wrapped_udf_command = wrap_udf_in_torch_dist_launcher(
            udf_command=udf_command,
            num_trainers=2,
@@ -34,15 +40,21 @@ class TestWrapUdfInTorchDistLauncher(unittest.TestCase):
            master_addr="127.0.0.1",
            master_port=1234,
        )
        expected = (
            "cd path/to && python3.7 -m torch.distributed.launch "
            "--nproc_per_node=2 --nnodes=2 --node_rank=1 --master_addr=127.0.0.1 "
            "--master_port=1234 path/to/some/trainer.py arg1 arg2"
        )
        self.assertEqual(wrapped_udf_command, expected)

    def test_py_versions(self):
        # test that this correctly handles different py versions/binaries
        py_binaries = (
            "python3.7",
            "python3.8",
            "python3.9",
            "python3",
            "python",
        )
        udf_command = "{python_bin} path/to/some/trainer.py arg1 arg2"
@@ -55,9 +67,13 @@ class TestWrapUdfInTorchDistLauncher(unittest.TestCase):
                master_addr="127.0.0.1",
                master_port=1234,
            )
            expected = (
                "{python_bin} -m torch.distributed.launch ".format(
                    python_bin=py_bin
                )
                + "--nproc_per_node=2 --nnodes=2 --node_rank=1 --master_addr=127.0.0.1 "
                "--master_port=1234 path/to/some/trainer.py arg1 arg2"
            )
            self.assertEqual(wrapped_udf_command, expected)
@@ -67,12 +83,13 @@ class TestWrapCmdWithLocalEnvvars(unittest.TestCase):
    def test_simple(self):
        self.assertEqual(
            wrap_cmd_with_local_envvars("ls && pwd", "VAR1=value1 VAR2=value2"),
            "(export VAR1=value1 VAR2=value2; ls && pwd)",
        )


class TestConstructDglServerEnvVars(unittest.TestCase):
    """construct_dgl_server_env_vars()"""

    def test_simple(self):
        self.assertEqual(
            construct_dgl_server_env_vars(
@@ -83,7 +100,7 @@ class TestConstructDglServerEnvVars(unittest.TestCase):
                ip_config="path/to/ip.config",
                num_servers=5,
                graph_format="csc",
                keep_alive=False,
            ),
            (
                "DGL_ROLE=server "
@@ -95,12 +112,13 @@ class TestConstructDglServerEnvVars(unittest.TestCase):
                "DGL_NUM_SERVER=5 "
                "DGL_GRAPH_FORMAT=csc "
                "DGL_KEEP_ALIVE=0 "
            ),
        )


class TestConstructDglClientEnvVars(unittest.TestCase):
    """construct_dgl_client_env_vars()"""

    def test_simple(self):
        # with pythonpath
        self.assertEqual(
@@ -113,7 +131,7 @@ class TestConstructDglClientEnvVars(unittest.TestCase):
                graph_format="csc",
                num_omp_threads=4,
                group_id=0,
                pythonpath="some/pythonpath/",
            ),
            (
                "DGL_DIST_MODE=distributed "
@@ -127,7 +145,7 @@ class TestConstructDglClientEnvVars(unittest.TestCase):
                "OMP_NUM_THREADS=4 "
                "DGL_GROUP_ID=0 "
                "PYTHONPATH=some/pythonpath/ "
            ),
        )
        # without pythonpath
        self.assertEqual(
@@ -139,7 +157,7 @@ class TestConstructDglClientEnvVars(unittest.TestCase):
                num_servers=3,
                graph_format="csc",
                num_omp_threads=4,
                group_id=0,
            ),
            (
                "DGL_DIST_MODE=distributed "
@@ -152,64 +170,72 @@ class TestConstructDglClientEnvVars(unittest.TestCase):
                "DGL_GRAPH_FORMAT=csc "
                "OMP_NUM_THREADS=4 "
                "DGL_GROUP_ID=0 "
            ),
        )
def test_submit_jobs():
    class Args:
        pass

    args = Args()
    with tempfile.TemporaryDirectory() as test_dir:
        num_machines = 8
        ip_config = os.path.join(test_dir, "ip_config.txt")
        with open(ip_config, "w") as f:
            for i in range(num_machines):
                f.write("{} {}\n".format("127.0.0." + str(i), 30050))
        part_config = os.path.join(test_dir, "ogb-products.json")
        with open(part_config, "w") as f:
            json.dump({"num_parts": num_machines}, f)
        args.num_trainers = 8
        args.num_samplers = 1
        args.num_servers = 4
        args.workspace = test_dir
        args.part_config = "ogb-products.json"
        args.ip_config = "ip_config.txt"
        args.server_name = "ogb-products"
        args.keep_alive = False
        args.num_server_threads = 1
        args.graph_format = "csc"
        args.extra_envs = ["NCCL_DEBUG=INFO"]
        args.num_omp_threads = 1
        udf_command = "python3 train_dist.py --num_epochs 10"
        clients_cmd, servers_cmd = submit_jobs(args, udf_command, dry_run=True)

        def common_checks():
            assert "cd " + test_dir in cmd
            assert "export " + args.extra_envs[0] in cmd
            assert f"DGL_NUM_SAMPLER={args.num_samplers}" in cmd
            assert (
                f"DGL_NUM_CLIENT={args.num_trainers*(args.num_samplers+1)*num_machines}"
                in cmd
            )
            assert f"DGL_CONF_PATH={args.part_config}" in cmd
            assert f"DGL_IP_CONFIG={args.ip_config}" in cmd
            assert f"DGL_NUM_SERVER={args.num_servers}" in cmd
            assert f"DGL_GRAPH_FORMAT={args.graph_format}" in cmd
            assert f"OMP_NUM_THREADS={args.num_omp_threads}" in cmd
            assert udf_command[len("python3 ") :] in cmd

        for cmd in clients_cmd:
            common_checks()
            assert "DGL_DIST_MODE=distributed" in cmd
            assert "DGL_ROLE=client" in cmd
            assert "DGL_GROUP_ID=0" in cmd
            assert (
                f"python3 -m torch.distributed.launch --nproc_per_node={args.num_trainers} --nnodes={num_machines}"
                in cmd
            )
            assert "--master_addr=127.0.0" in cmd
            assert "--master_port=1234" in cmd
        for cmd in servers_cmd:
            common_checks()
            assert "DGL_ROLE=server" in cmd
            assert "DGL_KEEP_ALIVE=0" in cmd
            assert "DGL_SERVER_ID=" in cmd


if __name__ == "__main__":
    unittest.main()
# See the __main__ block for usage of chunk_graph(). # See the __main__ block for usage of chunk_graph().
import pathlib
import json import json
from contextlib import contextmanager
import logging import logging
import os import os
import pathlib
from contextlib import contextmanager
import torch import torch
from utils import array_readwriter, setdir
import dgl import dgl
from utils import setdir
from utils import array_readwriter
def chunk_numpy_array(arr, fmt_meta, chunk_sizes, path_fmt): def chunk_numpy_array(arr, fmt_meta, chunk_sizes, path_fmt):
paths = [] paths = []
@@ -17,31 +17,38 @@ def chunk_numpy_array(arr, fmt_meta, chunk_sizes, path_fmt):
for j, n in enumerate(chunk_sizes): for j, n in enumerate(chunk_sizes):
path = os.path.abspath(path_fmt % j) path = os.path.abspath(path_fmt % j)
arr_chunk = arr[offset:offset + n] arr_chunk = arr[offset : offset + n]
logging.info('Chunking %d-%d' % (offset, offset + n)) logging.info("Chunking %d-%d" % (offset, offset + n))
array_readwriter.get_array_parser(**fmt_meta).write(path, arr_chunk) array_readwriter.get_array_parser(**fmt_meta).write(path, arr_chunk)
offset += n offset += n
paths.append(path) paths.append(path)
return paths return paths
def _chunk_graph(g, name, ndata_paths, edata_paths, num_chunks, output_path): def _chunk_graph(g, name, ndata_paths, edata_paths, num_chunks, output_path):
# First deal with ndata and edata that are homogeneous (i.e. not a dict-of-dict) # First deal with ndata and edata that are homogeneous (i.e. not a dict-of-dict)
if len(g.ntypes) == 1 and not isinstance(next(iter(ndata_paths.values())), dict): if len(g.ntypes) == 1 and not isinstance(
next(iter(ndata_paths.values())), dict
):
ndata_paths = {g.ntypes[0]: ndata_paths} ndata_paths = {g.ntypes[0]: ndata_paths}
if len(g.etypes) == 1 and not isinstance(next(iter(edata_paths.values())), dict): if len(g.etypes) == 1 and not isinstance(
next(iter(edata_paths.values())), dict
):
edata_paths = {g.etypes[0]: ndata_paths} edata_paths = {g.etypes[0]: ndata_paths}
# Then convert all edge types to canonical edge types # Then convert all edge types to canonical edge types
etypestrs = {etype: ':'.join(etype) for etype in g.canonical_etypes} etypestrs = {etype: ":".join(etype) for etype in g.canonical_etypes}
edata_paths = {':'.join(g.to_canonical_etype(k)): v for k, v in edata_paths.items()} edata_paths = {
":".join(g.to_canonical_etype(k)): v for k, v in edata_paths.items()
}
metadata = {} metadata = {}
metadata['graph_name'] = name metadata["graph_name"] = name
metadata['node_type'] = g.ntypes metadata["node_type"] = g.ntypes
# Compute the number of nodes per chunk per node type # Compute the number of nodes per chunk per node type
metadata['num_nodes_per_chunk'] = num_nodes_per_chunk = [] metadata["num_nodes_per_chunk"] = num_nodes_per_chunk = []
for ntype in g.ntypes: for ntype in g.ntypes:
num_nodes = g.num_nodes(ntype) num_nodes = g.num_nodes(ntype)
num_nodes_list = [] num_nodes_list = []
@@ -49,12 +56,14 @@ def _chunk_graph(g, name, ndata_paths, edata_paths, num_chunks, output_path):
n = num_nodes // num_chunks + (i < num_nodes % num_chunks) n = num_nodes // num_chunks + (i < num_nodes % num_chunks)
num_nodes_list.append(n) num_nodes_list.append(n)
num_nodes_per_chunk.append(num_nodes_list) num_nodes_per_chunk.append(num_nodes_list)
num_nodes_per_chunk_dict = {k: v for k, v in zip(g.ntypes, num_nodes_per_chunk)} num_nodes_per_chunk_dict = {
k: v for k, v in zip(g.ntypes, num_nodes_per_chunk)
}
metadata['edge_type'] = [etypestrs[etype] for etype in g.canonical_etypes] metadata["edge_type"] = [etypestrs[etype] for etype in g.canonical_etypes]
# Compute the number of edges per chunk per edge type # Compute the number of edges per chunk per edge type
metadata['num_edges_per_chunk'] = num_edges_per_chunk = [] metadata["num_edges_per_chunk"] = num_edges_per_chunk = []
for etype in g.canonical_etypes: for etype in g.canonical_etypes:
num_edges = g.num_edges(etype) num_edges = g.num_edges(etype)
num_edges_list = [] num_edges_list = []
@@ -62,67 +71,88 @@ def _chunk_graph(g, name, ndata_paths, edata_paths, num_chunks, output_path):
n = num_edges // num_chunks + (i < num_edges % num_chunks) n = num_edges // num_chunks + (i < num_edges % num_chunks)
num_edges_list.append(n) num_edges_list.append(n)
num_edges_per_chunk.append(num_edges_list) num_edges_per_chunk.append(num_edges_list)
num_edges_per_chunk_dict = {k: v for k, v in zip(g.canonical_etypes, num_edges_per_chunk)} num_edges_per_chunk_dict = {
k: v for k, v in zip(g.canonical_etypes, num_edges_per_chunk)
}
# Split edge index # Split edge index
metadata['edges'] = {} metadata["edges"] = {}
with setdir('edge_index'): with setdir("edge_index"):
for etype in g.canonical_etypes: for etype in g.canonical_etypes:
etypestr = etypestrs[etype] etypestr = etypestrs[etype]
logging.info('Chunking edge index for %s' % etypestr) logging.info("Chunking edge index for %s" % etypestr)
edges_meta = {} edges_meta = {}
fmt_meta = {"name": "csv", "delimiter": " "} fmt_meta = {"name": "csv", "delimiter": " "}
edges_meta['format'] = fmt_meta edges_meta["format"] = fmt_meta
srcdst = torch.stack(g.edges(etype=etype), 1) srcdst = torch.stack(g.edges(etype=etype), 1)
edges_meta['data'] = chunk_numpy_array( edges_meta["data"] = chunk_numpy_array(
srcdst.numpy(), fmt_meta, num_edges_per_chunk_dict[etype], srcdst.numpy(),
etypestr + '%d.txt') fmt_meta,
metadata['edges'][etypestr] = edges_meta num_edges_per_chunk_dict[etype],
etypestr + "%d.txt",
)
metadata["edges"][etypestr] = edges_meta
# Chunk node data # Chunk node data
metadata['node_data'] = {} metadata["node_data"] = {}
with setdir('node_data'): with setdir("node_data"):
for ntype, ndata_per_type in ndata_paths.items(): for ntype, ndata_per_type in ndata_paths.items():
ndata_meta = {} ndata_meta = {}
with setdir(ntype): with setdir(ntype):
for key, path in ndata_per_type.items(): for key, path in ndata_per_type.items():
logging.info('Chunking node data for type %s key %s' % (ntype, key)) logging.info(
"Chunking node data for type %s key %s" % (ntype, key)
)
ndata_key_meta = {} ndata_key_meta = {}
reader_fmt_meta = writer_fmt_meta = {"name": "numpy"} reader_fmt_meta = writer_fmt_meta = {"name": "numpy"}
arr = array_readwriter.get_array_parser(**reader_fmt_meta).read(path) arr = array_readwriter.get_array_parser(
ndata_key_meta['format'] = writer_fmt_meta **reader_fmt_meta
ndata_key_meta['data'] = chunk_numpy_array( ).read(path)
arr, writer_fmt_meta, num_nodes_per_chunk_dict[ntype], ndata_key_meta["format"] = writer_fmt_meta
key + '-%d.npy') ndata_key_meta["data"] = chunk_numpy_array(
arr,
writer_fmt_meta,
num_nodes_per_chunk_dict[ntype],
key + "-%d.npy",
)
ndata_meta[key] = ndata_key_meta ndata_meta[key] = ndata_key_meta
metadata['node_data'][ntype] = ndata_meta metadata["node_data"][ntype] = ndata_meta
# Chunk edge data # Chunk edge data
metadata['edge_data'] = {} metadata["edge_data"] = {}
with setdir('edge_data'): with setdir("edge_data"):
for etypestr, edata_per_type in edata_paths.items(): for etypestr, edata_per_type in edata_paths.items():
edata_meta = {} edata_meta = {}
with setdir(etypestr): with setdir(etypestr):
for key, path in edata_per_type.items(): for key, path in edata_per_type.items():
logging.info('Chunking edge data for type %s key %s' % (etypestr, key)) logging.info(
"Chunking edge data for type %s key %s"
% (etypestr, key)
)
edata_key_meta = {} edata_key_meta = {}
reader_fmt_meta = writer_fmt_meta = {"name": "numpy"} reader_fmt_meta = writer_fmt_meta = {"name": "numpy"}
arr = array_readwriter.get_array_parser(**reader_fmt_meta).read(path) arr = array_readwriter.get_array_parser(
edata_key_meta['format'] = writer_fmt_meta **reader_fmt_meta
etype = tuple(etypestr.split(':')) ).read(path)
edata_key_meta['data'] = chunk_numpy_array( edata_key_meta["format"] = writer_fmt_meta
arr, writer_fmt_meta, num_edges_per_chunk_dict[etype], etype = tuple(etypestr.split(":"))
key + '-%d.npy') edata_key_meta["data"] = chunk_numpy_array(
arr,
writer_fmt_meta,
num_edges_per_chunk_dict[etype],
key + "-%d.npy",
)
edata_meta[key] = edata_key_meta edata_meta[key] = edata_key_meta
metadata['edge_data'][etypestr] = edata_meta metadata["edge_data"][etypestr] = edata_meta
metadata_path = 'metadata.json' metadata_path = "metadata.json"
with open(metadata_path, 'w') as f: with open(metadata_path, "w") as f:
json.dump(metadata, f) json.dump(metadata, f)
logging.info('Saved metadata in %s' % os.path.abspath(metadata_path)) logging.info("Saved metadata in %s" % os.path.abspath(metadata_path))
def chunk_graph(g, name, ndata_paths, edata_paths, num_chunks, output_path): def chunk_graph(g, name, ndata_paths, edata_paths, num_chunks, output_path):
""" """
@@ -157,22 +187,29 @@ def chunk_graph(g, name, ndata_paths, edata_paths, num_chunks, output_path):
with setdir(output_path): with setdir(output_path):
_chunk_graph(g, name, ndata_paths, edata_paths, num_chunks, output_path) _chunk_graph(g, name, ndata_paths, edata_paths, num_chunks, output_path)
if __name__ == '__main__':
logging.basicConfig(level='INFO') if __name__ == "__main__":
input_dir = '/data' logging.basicConfig(level="INFO")
output_dir = '/chunked-data' input_dir = "/data"
(g,), _ = dgl.load_graphs(os.path.join(input_dir, 'graph.dgl')) output_dir = "/chunked-data"
(g,), _ = dgl.load_graphs(os.path.join(input_dir, "graph.dgl"))
chunk_graph( chunk_graph(
g, g,
'mag240m', "mag240m",
{'paper': { {
'feat': os.path.join(input_dir, 'paper/feat.npy'), "paper": {
'label': os.path.join(input_dir, 'paper/label.npy'), "feat": os.path.join(input_dir, "paper/feat.npy"),
'year': os.path.join(input_dir, 'paper/year.npy')}}, "label": os.path.join(input_dir, "paper/label.npy"),
{'cites': {'count': os.path.join(input_dir, 'cites/count.npy')}, "year": os.path.join(input_dir, "paper/year.npy"),
'writes': {'year': os.path.join(input_dir, 'writes/year.npy')}, }
# you can put the same data file if they indeed share the features. },
'rev_writes': {'year': os.path.join(input_dir, 'writes/year.npy')}}, {
4, "cites": {"count": os.path.join(input_dir, "cites/count.npy")},
output_dir) "writes": {"year": os.path.join(input_dir, "writes/year.npy")},
# you can put the same data file if they indeed share the features.
"rev_writes": {"year": os.path.join(input_dir, "writes/year.npy")},
},
4,
output_dir,
)
# The generated metadata goes as in tools/sample-config/mag240m-metadata.json. # The generated metadata goes as in tools/sample-config/mag240m-metadata.json.
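# --- Editor's note (not part of the commit): a minimal, self-contained sketch of the
# balanced chunk-size arithmetic used by _chunk_graph() above, where chunk i receives
# num // num_chunks items plus one extra while i < num % num_chunks.
def balanced_chunk_sizes(num_items, num_chunks):
    """Return per-chunk counts that differ by at most one and sum to num_items."""
    return [
        num_items // num_chunks + (i < num_items % num_chunks)
        for i in range(num_chunks)
    ]

# Example: 10 nodes split into 4 chunks -> [3, 3, 2, 2]; the sizes always sum to the total.
assert balanced_chunk_sizes(10, 4) == [3, 3, 2, 2]
assert sum(balanced_chunk_sizes(1234567, 8)) == 1234567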
"""Copy the partitions to a cluster of machines.""" """Copy the partitions to a cluster of machines."""
import argparse
import copy
import json
import logging
import os import os
import signal
import stat import stat
import sys
import subprocess import subprocess
import argparse import sys
import signal
import logging
import json def copy_file(file_name, ip, workspace, param=""):
import copy print("copy {} to {}".format(file_name, ip + ":" + workspace + "/"))
cmd = "scp " + param + " " + file_name + " " + ip + ":" + workspace + "/"
subprocess.check_call(cmd, shell=True)
def copy_file(file_name, ip, workspace, param=''):
print('copy {} to {}'.format(file_name, ip + ':' + workspace + '/'))
cmd = 'scp ' + param + ' ' + file_name + ' ' + ip + ':' + workspace + '/'
subprocess.check_call(cmd, shell = True)
def exec_cmd(ip, cmd): def exec_cmd(ip, cmd):
cmd = 'ssh -o StrictHostKeyChecking=no ' + ip + ' \'' + cmd + '\'' cmd = "ssh -o StrictHostKeyChecking=no " + ip + " '" + cmd + "'"
subprocess.check_call(cmd, shell = True) subprocess.check_call(cmd, shell=True)
def main(): def main():
parser = argparse.ArgumentParser(description='Copy data to the servers.') parser = argparse.ArgumentParser(description="Copy data to the servers.")
parser.add_argument('--workspace', type=str, required=True, parser.add_argument(
help='Path of user directory of distributed tasks. \ "--workspace",
type=str,
required=True,
help="Path of user directory of distributed tasks. \
This is used to specify a destination location where \ This is used to specify a destination location where \
data are copied to on remote machines.') data are copied to on remote machines.",
parser.add_argument('--rel_data_path', type=str, required=True, )
help='Relative path in workspace to store the partition data.') parser.add_argument(
parser.add_argument('--part_config', type=str, required=True, "--rel_data_path",
help='The partition config file. The path is on the local machine.') type=str,
parser.add_argument('--script_folder', type=str, required=True, required=True,
help='The folder contains all the user code scripts.') help="Relative path in workspace to store the partition data.",
parser.add_argument('--ip_config', type=str, required=True, )
help='The file of IP configuration for servers. \ parser.add_argument(
The path is on the local machine.') "--part_config",
type=str,
required=True,
help="The partition config file. The path is on the local machine.",
)
parser.add_argument(
"--script_folder",
type=str,
required=True,
help="The folder contains all the user code scripts.",
)
parser.add_argument(
"--ip_config",
type=str,
required=True,
help="The file of IP configuration for servers. \
The path is on the local machine.",
)
args = parser.parse_args() args = parser.parse_args()
hosts = [] hosts = []
with open(args.ip_config) as f: with open(args.ip_config) as f:
for line in f: for line in f:
res = line.strip().split(' ') res = line.strip().split(" ")
ip = res[0] ip = res[0]
hosts.append(ip) hosts.append(ip)
# We need to update the partition config file so that the paths are relative to # We need to update the partition config file so that the paths are relative to
# the workspace in the remote machines. # the workspace in the remote machines.
with open(args.part_config) as conf_f: with open(args.part_config) as conf_f:
part_metadata = json.load(conf_f) part_metadata = json.load(conf_f)
tmp_part_metadata = copy.deepcopy(part_metadata) tmp_part_metadata = copy.deepcopy(part_metadata)
num_parts = part_metadata['num_parts'] num_parts = part_metadata["num_parts"]
assert num_parts == len(hosts), \ assert num_parts == len(
'The number of partitions needs to be the same as the number of hosts.' hosts
graph_name = part_metadata['graph_name'] ), "The number of partitions needs to be the same as the number of hosts."
node_map = part_metadata['node_map'] graph_name = part_metadata["graph_name"]
edge_map = part_metadata['edge_map'] node_map = part_metadata["node_map"]
edge_map = part_metadata["edge_map"]
if not isinstance(node_map, dict): if not isinstance(node_map, dict):
assert node_map[-4:] == '.npy', 'node map should be stored in a NumPy array.' assert (
tmp_part_metadata['node_map'] = '{}/{}/node_map.npy'.format(args.workspace, node_map[-4:] == ".npy"
args.rel_data_path) ), "node map should be stored in a NumPy array."
tmp_part_metadata["node_map"] = "{}/{}/node_map.npy".format(
args.workspace, args.rel_data_path
)
if not isinstance(edge_map, dict): if not isinstance(edge_map, dict):
assert edge_map[-4:] == '.npy', 'edge map should be stored in a NumPy array.' assert (
tmp_part_metadata['edge_map'] = '{}/{}/edge_map.npy'.format(args.workspace, edge_map[-4:] == ".npy"
args.rel_data_path) ), "edge map should be stored in a NumPy array."
tmp_part_metadata["edge_map"] = "{}/{}/edge_map.npy".format(
args.workspace, args.rel_data_path
)
for part_id in range(num_parts): for part_id in range(num_parts):
part_files = tmp_part_metadata['part-{}'.format(part_id)] part_files = tmp_part_metadata["part-{}".format(part_id)]
part_files['edge_feats'] = '{}/part{}/edge_feat.dgl'.format(args.rel_data_path, part_id) part_files["edge_feats"] = "{}/part{}/edge_feat.dgl".format(
part_files['node_feats'] = '{}/part{}/node_feat.dgl'.format(args.rel_data_path, part_id) args.rel_data_path, part_id
part_files['part_graph'] = '{}/part{}/graph.dgl'.format(args.rel_data_path, part_id) )
tmp_part_config = '/tmp/{}.json'.format(graph_name) part_files["node_feats"] = "{}/part{}/node_feat.dgl".format(
with open(tmp_part_config, 'w') as outfile: args.rel_data_path, part_id
)
part_files["part_graph"] = "{}/part{}/graph.dgl".format(
args.rel_data_path, part_id
)
tmp_part_config = "/tmp/{}.json".format(graph_name)
with open(tmp_part_config, "w") as outfile:
json.dump(tmp_part_metadata, outfile, sort_keys=True, indent=4) json.dump(tmp_part_metadata, outfile, sort_keys=True, indent=4)
# Copy ip config. # Copy ip config.
for part_id, ip in enumerate(hosts): for part_id, ip in enumerate(hosts):
remote_path = '{}/{}'.format(args.workspace, args.rel_data_path) remote_path = "{}/{}".format(args.workspace, args.rel_data_path)
exec_cmd(ip, 'mkdir -p {}'.format(remote_path)) exec_cmd(ip, "mkdir -p {}".format(remote_path))
copy_file(args.ip_config, ip, args.workspace) copy_file(args.ip_config, ip, args.workspace)
copy_file(tmp_part_config, ip, '{}/{}'.format(args.workspace, args.rel_data_path)) copy_file(
node_map = part_metadata['node_map'] tmp_part_config,
edge_map = part_metadata['edge_map'] ip,
"{}/{}".format(args.workspace, args.rel_data_path),
)
node_map = part_metadata["node_map"]
edge_map = part_metadata["edge_map"]
if not isinstance(node_map, dict): if not isinstance(node_map, dict):
copy_file(node_map, ip, tmp_part_metadata['node_map']) copy_file(node_map, ip, tmp_part_metadata["node_map"])
if not isinstance(edge_map, dict): if not isinstance(edge_map, dict):
copy_file(edge_map, ip, tmp_part_metadata['edge_map']) copy_file(edge_map, ip, tmp_part_metadata["edge_map"])
remote_path = '{}/{}/part{}'.format(args.workspace, args.rel_data_path, part_id) remote_path = "{}/{}/part{}".format(
exec_cmd(ip, 'mkdir -p {}'.format(remote_path)) args.workspace, args.rel_data_path, part_id
)
part_files = part_metadata['part-{}'.format(part_id)] exec_cmd(ip, "mkdir -p {}".format(remote_path))
copy_file(part_files['node_feats'], ip, remote_path)
copy_file(part_files['edge_feats'], ip, remote_path) part_files = part_metadata["part-{}".format(part_id)]
copy_file(part_files['part_graph'], ip, remote_path) copy_file(part_files["node_feats"], ip, remote_path)
copy_file(part_files["edge_feats"], ip, remote_path)
copy_file(part_files["part_graph"], ip, remote_path)
# copy script folder # copy script folder
copy_file(args.script_folder, ip, args.workspace, '-r') copy_file(args.script_folder, ip, args.workspace, "-r")
def signal_handler(signal, frame): def signal_handler(signal, frame):
logging.info('Stop copying') logging.info("Stop copying")
sys.exit(0) sys.exit(0)
if __name__ == '__main__':
fmt = '%(asctime)s %(levelname)s %(message)s' if __name__ == "__main__":
fmt = "%(asctime)s %(levelname)s %(message)s"
logging.basicConfig(format=fmt, level=logging.INFO) logging.basicConfig(format=fmt, level=logging.INFO)
signal.signal(signal.SIGINT, signal_handler) signal.signal(signal.SIGINT, signal_handler)
main() main()
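# --- Editor's note (not part of the commit): a small illustration of the shell commands
# copy_file() and exec_cmd() above assemble. The host, workspace and file names are
# hypothetical; the ip_config format ("<ip> <port>" per line) is the one main() parses.
def build_scp_cmd(file_name, ip, workspace, param=""):
    # Same string as copy_file() builds, returned instead of executed.
    return "scp " + param + " " + file_name + " " + ip + ":" + workspace + "/"

def build_ssh_cmd(ip, cmd):
    # Same string as exec_cmd() builds, returned instead of executed.
    return "ssh -o StrictHostKeyChecking=no " + ip + " '" + cmd + "'"

print(build_scp_cmd("/tmp/ogb-products.json", "172.31.0.1", "/home/ubuntu/workspace"))
print(build_ssh_cmd("172.31.0.1", "mkdir -p /home/ubuntu/workspace/data/part0"))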
"""Launching distributed graph partitioning pipeline """ """Launching distributed graph partitioning pipeline """
import os
import sys
import argparse import argparse
import logging
import json import json
import logging
import os
import sys
from partition_algo.base import load_partition_meta from partition_algo.base import load_partition_meta
INSTALL_DIR = os.path.abspath(os.path.join(__file__, '..')) INSTALL_DIR = os.path.abspath(os.path.join(__file__, ".."))
LAUNCH_SCRIPT = "distgraphlaunch.py" LAUNCH_SCRIPT = "distgraphlaunch.py"
PIPELINE_SCRIPT = "distpartitioning/data_proc_pipeline.py" PIPELINE_SCRIPT = "distpartitioning/data_proc_pipeline.py"
@@ -23,6 +24,7 @@ LARG_IPCONF = "ip_config"
LARG_MASTER_PORT = "master_port" LARG_MASTER_PORT = "master_port"
LARG_SSH_PORT = "ssh_port" LARG_SSH_PORT = "ssh_port"
def get_launch_cmd(args) -> str: def get_launch_cmd(args) -> str:
cmd = sys.executable + " " + os.path.join(INSTALL_DIR, LAUNCH_SCRIPT) cmd = sys.executable + " " + os.path.join(INSTALL_DIR, LAUNCH_SCRIPT)
cmd = f"{cmd} --{LARG_SSH_PORT} {args.ssh_port} " cmd = f"{cmd} --{LARG_SSH_PORT} {args.ssh_port} "
@@ -34,7 +36,7 @@ def get_launch_cmd(args) -> str:
def submit_jobs(args) -> str: def submit_jobs(args) -> str:
#read the json file and get the remaining argument here. # read the json file and get the remaining argument here.
schema_path = "metadata.json" schema_path = "metadata.json"
with open(os.path.join(args.in_dir, schema_path)) as schema: with open(os.path.join(args.in_dir, schema_path)) as schema:
schema_map = json.load(schema) schema_map = json.load(schema)
@@ -49,17 +51,22 @@ def submit_jobs(args) -> str:
part_meta = load_partition_meta(partition_path) part_meta = load_partition_meta(partition_path)
num_parts = part_meta.num_parts num_parts = part_meta.num_parts
if num_parts > num_chunks: if num_parts > num_chunks:
raise Exception('Number of partitions should be less/equal than number of chunks.') raise Exception(
"Number of partitions should be less/equal than number of chunks."
)
# verify ip_config # verify ip_config
with open(args.ip_config, 'r') as f: with open(args.ip_config, "r") as f:
num_ips = len(f.readlines()) num_ips = len(f.readlines())
assert num_ips == num_parts, \ assert (
f'The number of lines[{args.ip_config}] should be equal to num_parts[{num_parts}].' num_ips == num_parts
), f"The number of lines[{args.ip_config}] should be equal to num_parts[{num_parts}]."
argslist = "" argslist = ""
argslist += "--world-size {} ".format(num_parts) argslist += "--world-size {} ".format(num_parts)
argslist += "--partitions-dir {} ".format(os.path.abspath(args.partitions_dir)) argslist += "--partitions-dir {} ".format(
os.path.abspath(args.partitions_dir)
)
argslist += "--input-dir {} ".format(os.path.abspath(args.in_dir)) argslist += "--input-dir {} ".format(os.path.abspath(args.in_dir))
argslist += "--graph-name {} ".format(graph_name) argslist += "--graph-name {} ".format(graph_name)
argslist += "--schema {} ".format(schema_path) argslist += "--schema {} ".format(schema_path)
@@ -75,28 +82,73 @@ def submit_jobs(args) -> str:
udf_cmd = f"{args.python_path} {pipeline_cmd} {argslist}" udf_cmd = f"{args.python_path} {pipeline_cmd} {argslist}"
launch_cmd = get_launch_cmd(args) launch_cmd = get_launch_cmd(args)
launch_cmd += '\"'+udf_cmd+'\"' launch_cmd += '"' + udf_cmd + '"'
print(launch_cmd) print(launch_cmd)
os.system(launch_cmd) os.system(launch_cmd)
def main(): def main():
parser = argparse.ArgumentParser(description='Dispatch edge index and data to partitions', formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser = argparse.ArgumentParser(
description="Dispatch edge index and data to partitions",
parser.add_argument('--in-dir', type=str, help='Location of the input directory where the dataset is located') formatter_class=argparse.ArgumentDefaultsHelpFormatter,
parser.add_argument('--partitions-dir', type=str, help='Location of the partition-id mapping files which define node-ids and their respective partition-ids, relative to the input directory') )
parser.add_argument('--out-dir', type=str, help='Location of the output directory where the graph partitions will be created by this pipeline')
parser.add_argument('--ip-config', type=str, help='File location of IP configuration for server processes') parser.add_argument(
parser.add_argument('--master-port', type=int, default=12345, help='port used by gloo group to create randezvous point') "--in-dir",
parser.add_argument('--log-level', type=str, default="info", help='To enable log level for debugging purposes. Available options: (Critical, Error, Warning, Info, Debug, Notset)') type=str,
parser.add_argument('--python-path', type=str, default=sys.executable, help='Path to the Python executable on all workers') help="Location of the input directory where the dataset is located",
parser.add_argument('--ssh-port', type=int, default=22, help='SSH Port.') )
parser.add_argument('--process-group-timeout', type=int, default=1800, parser.add_argument(
help='timeout[seconds] for operations executed against the process group') "--partitions-dir",
parser.add_argument('--save-orig-nids', action='store_true', type=str,
help='Save original node IDs into files') help="Location of the partition-id mapping files which define node-ids and their respective partition-ids, relative to the input directory",
parser.add_argument('--save-orig-eids', action='store_true', )
help='Save original edge IDs into files') parser.add_argument(
"--out-dir",
type=str,
help="Location of the output directory where the graph partitions will be created by this pipeline",
)
parser.add_argument(
"--ip-config",
type=str,
help="File location of IP configuration for server processes",
)
parser.add_argument(
"--master-port",
type=int,
default=12345,
help="port used by gloo group to create randezvous point",
)
parser.add_argument(
"--log-level",
type=str,
default="info",
help="To enable log level for debugging purposes. Available options: (Critical, Error, Warning, Info, Debug, Notset)",
)
parser.add_argument(
"--python-path",
type=str,
default=sys.executable,
help="Path to the Python executable on all workers",
)
parser.add_argument("--ssh-port", type=int, default=22, help="SSH Port.")
parser.add_argument(
"--process-group-timeout",
type=int,
default=1800,
help="timeout[seconds] for operations executed against the process group",
)
parser.add_argument(
"--save-orig-nids",
action="store_true",
help="Save original node IDs into files",
)
parser.add_argument(
"--save-orig-eids",
action="store_true",
help="Save original edge IDs into files",
)
args, udf_command = parser.parse_known_args() args, udf_command = parser.parse_known_args()
@@ -109,7 +161,8 @@ def main():
tokens = sys.executable.split(os.sep) tokens = sys.executable.split(os.sep)
submit_jobs(args) submit_jobs(args)
if __name__ == '__main__':
fmt = '%(asctime)s %(levelname)s %(message)s' if __name__ == "__main__":
fmt = "%(asctime)s %(levelname)s %(message)s"
logging.basicConfig(format=fmt, level=logging.INFO) logging.basicConfig(format=fmt, level=logging.INFO)
main() main()
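# --- Editor's note (not part of the commit): a hypothetical invocation assembled from the
# arguments defined in main() above. The script file name "dispatch_data.py" and every
# path below are placeholders for illustration; only the flag names come from the parser.
import shlex

dispatch_cmd = (
    "python3 dispatch_data.py "
    "--in-dir /data/chunked-data "
    "--partitions-dir partitions "
    "--out-dir /data/dist-partitions "
    "--ip-config /data/ip_config.txt "
    "--master-port 12345 "
    "--ssh-port 22"
)
print(shlex.split(dispatch_cmd))  # pass this list to subprocess.run() to launch the pipeline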
"""Launching tool for DGL distributed training""" """Launching tool for DGL distributed training"""
import os
import stat
import sys
import subprocess
import argparse import argparse
import signal
import logging
import time
import json import json
import logging
import multiprocessing import multiprocessing
import os
import queue import queue
import re import re
import signal
import stat
import subprocess
import sys
import time
from functools import partial from functools import partial
from threading import Thread from threading import Thread
from typing import Optional from typing import Optional
def cleanup_proc(get_all_remote_pids, conn): def cleanup_proc(get_all_remote_pids, conn):
'''This process tries to clean up the remote training tasks. """This process tries to clean up the remote training tasks."""
''' print("cleanupu process runs")
print('cleanupu process runs')
# This process should not handle SIGINT. # This process should not handle SIGINT.
signal.signal(signal.SIGINT, signal.SIG_IGN) signal.signal(signal.SIGINT, signal.SIG_IGN)
data = conn.recv() data = conn.recv()
# If the launch process exits normally, this process doesn't need to do anything. # If the launch process exits normally, this process doesn't need to do anything.
if data == 'exit': if data == "exit":
sys.exit(0) sys.exit(0)
else: else:
remote_pids = get_all_remote_pids() remote_pids = get_all_remote_pids()
# Otherwise, we need to ssh to each machine and kill the training jobs. # Otherwise, we need to ssh to each machine and kill the training jobs.
for (ip, port), pids in remote_pids.items(): for (ip, port), pids in remote_pids.items():
kill_process(ip, port, pids) kill_process(ip, port, pids)
print('cleanup process exits') print("cleanup process exits")
def kill_process(ip, port, pids): def kill_process(ip, port, pids):
'''ssh to a remote machine and kill the specified processes. """ssh to a remote machine and kill the specified processes."""
'''
curr_pid = os.getpid() curr_pid = os.getpid()
killed_pids = [] killed_pids = []
# If we kill child processes first, the parent process may create more again. This happens # If we kill child processes first, the parent process may create more again. This happens
@@ -43,8 +43,14 @@ def kill_process(ip, port, pids):
pids.sort() pids.sort()
for pid in pids: for pid in pids:
assert curr_pid != pid assert curr_pid != pid
print('kill process {} on {}:{}'.format(pid, ip, port), flush=True) print("kill process {} on {}:{}".format(pid, ip, port), flush=True)
kill_cmd = 'ssh -o StrictHostKeyChecking=no -p ' + str(port) + ' ' + ip + ' \'kill {}\''.format(pid) kill_cmd = (
"ssh -o StrictHostKeyChecking=no -p "
+ str(port)
+ " "
+ ip
+ " 'kill {}'".format(pid)
)
subprocess.run(kill_cmd, shell=True) subprocess.run(kill_cmd, shell=True)
killed_pids.append(pid) killed_pids.append(pid)
# It's possible that some of the processes are not killed. Let's try again. # It's possible that some of the processes are not killed. Let's try again.
@@ -55,30 +61,45 @@ def kill_process(ip, port, pids):
else: else:
killed_pids.sort() killed_pids.sort()
for pid in killed_pids: for pid in killed_pids:
print('kill process {} on {}:{}'.format(pid, ip, port), flush=True) print(
kill_cmd = 'ssh -o StrictHostKeyChecking=no -p ' + str(port) + ' ' + ip + ' \'kill -9 {}\''.format(pid) "kill process {} on {}:{}".format(pid, ip, port), flush=True
)
kill_cmd = (
"ssh -o StrictHostKeyChecking=no -p "
+ str(port)
+ " "
+ ip
+ " 'kill -9 {}'".format(pid)
)
subprocess.run(kill_cmd, shell=True) subprocess.run(kill_cmd, shell=True)
def get_killed_pids(ip, port, killed_pids): def get_killed_pids(ip, port, killed_pids):
'''Get the process IDs that we want to kill but are still alive. """Get the process IDs that we want to kill but are still alive."""
'''
killed_pids = [str(pid) for pid in killed_pids] killed_pids = [str(pid) for pid in killed_pids]
killed_pids = ','.join(killed_pids) killed_pids = ",".join(killed_pids)
ps_cmd = 'ssh -o StrictHostKeyChecking=no -p ' + str(port) + ' ' + ip + ' \'ps -p {} -h\''.format(killed_pids) ps_cmd = (
"ssh -o StrictHostKeyChecking=no -p "
+ str(port)
+ " "
+ ip
+ " 'ps -p {} -h'".format(killed_pids)
)
res = subprocess.run(ps_cmd, shell=True, stdout=subprocess.PIPE) res = subprocess.run(ps_cmd, shell=True, stdout=subprocess.PIPE)
pids = [] pids = []
for p in res.stdout.decode('utf-8').split('\n'): for p in res.stdout.decode("utf-8").split("\n"):
l = p.split() l = p.split()
if len(l) > 0: if len(l) > 0:
pids.append(int(l[0])) pids.append(int(l[0]))
return pids return pids
def execute_remote( def execute_remote(
cmd: str, cmd: str,
state_q: queue.Queue, state_q: queue.Queue,
ip: str, ip: str,
port: int, port: int,
username: Optional[str] = "" username: Optional[str] = "",
) -> Thread: ) -> Thread:
"""Execute command line on remote machine via ssh. """Execute command line on remote machine via ssh.
@@ -118,22 +139,34 @@ def execute_remote(
except Exception: except Exception:
state_q.put(-1) state_q.put(-1)
thread = Thread(target=run, args=(ssh_cmd, state_q,)) thread = Thread(
target=run,
args=(
ssh_cmd,
state_q,
),
)
thread.setDaemon(True) thread.setDaemon(True)
thread.start() thread.start()
# sleep for a while in case of ssh is rejected by peer due to busy connection # sleep for a while in case of ssh is rejected by peer due to busy connection
time.sleep(0.2) time.sleep(0.2)
return thread return thread
def get_remote_pids(ip, port, cmd_regex): def get_remote_pids(ip, port, cmd_regex):
"""Get the process IDs that run the command in the remote machine. """Get the process IDs that run the command in the remote machine."""
"""
pids = [] pids = []
curr_pid = os.getpid() curr_pid = os.getpid()
# Here we want to get the python processes. We may get some ssh processes, so we should filter them out. # Here we want to get the python processes. We may get some ssh processes, so we should filter them out.
ps_cmd = 'ssh -o StrictHostKeyChecking=no -p ' + str(port) + ' ' + ip + ' \'ps -aux | grep python | grep -v StrictHostKeyChecking\'' ps_cmd = (
"ssh -o StrictHostKeyChecking=no -p "
+ str(port)
+ " "
+ ip
+ " 'ps -aux | grep python | grep -v StrictHostKeyChecking'"
)
res = subprocess.run(ps_cmd, shell=True, stdout=subprocess.PIPE) res = subprocess.run(ps_cmd, shell=True, stdout=subprocess.PIPE)
for p in res.stdout.decode('utf-8').split('\n'): for p in res.stdout.decode("utf-8").split("\n"):
l = p.split() l = p.split()
if len(l) < 2: if len(l) < 2:
continue continue
@@ -142,28 +175,34 @@ def get_remote_pids(ip, port, cmd_regex):
if res is not None and int(l[1]) != curr_pid: if res is not None and int(l[1]) != curr_pid:
pids.append(l[1]) pids.append(l[1])
pid_str = ','.join([str(pid) for pid in pids]) pid_str = ",".join([str(pid) for pid in pids])
ps_cmd = 'ssh -o StrictHostKeyChecking=no -p ' + str(port) + ' ' + ip + ' \'pgrep -P {}\''.format(pid_str) ps_cmd = (
"ssh -o StrictHostKeyChecking=no -p "
+ str(port)
+ " "
+ ip
+ " 'pgrep -P {}'".format(pid_str)
)
res = subprocess.run(ps_cmd, shell=True, stdout=subprocess.PIPE) res = subprocess.run(ps_cmd, shell=True, stdout=subprocess.PIPE)
pids1 = res.stdout.decode('utf-8').split('\n') pids1 = res.stdout.decode("utf-8").split("\n")
all_pids = [] all_pids = []
for pid in set(pids + pids1): for pid in set(pids + pids1):
if pid == '' or int(pid) == curr_pid: if pid == "" or int(pid) == curr_pid:
continue continue
all_pids.append(int(pid)) all_pids.append(int(pid))
all_pids.sort() all_pids.sort()
return all_pids return all_pids
def get_all_remote_pids(hosts, ssh_port, udf_command): def get_all_remote_pids(hosts, ssh_port, udf_command):
'''Get all remote processes. """Get all remote processes."""
'''
remote_pids = {} remote_pids = {}
for node_id, host in enumerate(hosts): for node_id, host in enumerate(hosts):
ip, _ = host ip, _ = host
# When creating training processes in remote machines, we may insert some arguments # When creating training processes in remote machines, we may insert some arguments
# in the commands. We need to use regular expressions to match the modified command. # in the commands. We need to use regular expressions to match the modified command.
cmds = udf_command.split() cmds = udf_command.split()
new_udf_command = ' .*'.join(cmds) new_udf_command = " .*".join(cmds)
pids = get_remote_pids(ip, ssh_port, new_udf_command) pids = get_remote_pids(ip, ssh_port, new_udf_command)
remote_pids[(ip, ssh_port)] = pids remote_pids[(ip, ssh_port)] = pids
return remote_pids return remote_pids
@@ -174,7 +213,7 @@ def construct_torch_dist_launcher_cmd(
num_nodes: int, num_nodes: int,
node_rank: int, node_rank: int,
master_addr: str, master_addr: str,
master_port: int master_port: int,
) -> str: ) -> str:
"""Constructs the torch distributed launcher command. """Constructs the torch distributed launcher command.
Helper function. Helper function.
@@ -189,18 +228,20 @@ def construct_torch_dist_launcher_cmd(
Returns: Returns:
cmd_str. cmd_str.
""" """
torch_cmd_template = "-m torch.distributed.launch " \ torch_cmd_template = (
"--nproc_per_node={nproc_per_node} " \ "-m torch.distributed.launch "
"--nnodes={nnodes} " \ "--nproc_per_node={nproc_per_node} "
"--node_rank={node_rank} " \ "--nnodes={nnodes} "
"--master_addr={master_addr} " \ "--node_rank={node_rank} "
"--master_port={master_port}" "--master_addr={master_addr} "
"--master_port={master_port}"
)
return torch_cmd_template.format( return torch_cmd_template.format(
nproc_per_node=num_trainers, nproc_per_node=num_trainers,
nnodes=num_nodes, nnodes=num_nodes,
node_rank=node_rank, node_rank=node_rank,
master_addr=master_addr, master_addr=master_addr,
master_port=master_port master_port=master_port,
) )
@@ -243,7 +284,7 @@ def wrap_udf_in_torch_dist_launcher(
num_nodes=num_nodes, num_nodes=num_nodes,
node_rank=node_rank, node_rank=node_rank,
master_addr=master_addr, master_addr=master_addr,
master_port=master_port master_port=master_port,
) )
# Auto-detect the python binary that kicks off the distributed trainer code. # Auto-detect the python binary that kicks off the distributed trainer code.
# Note: This allowlist order matters, this will match with the FIRST matching entry. Thus, please add names to this # Note: This allowlist order matters, this will match with the FIRST matching entry. Thus, please add names to this
@@ -251,9 +292,14 @@ def wrap_udf_in_torch_dist_launcher(
# (python3.7, python3.8) -> (python3) # (python3.7, python3.8) -> (python3)
# The allowed python versions are from this: https://www.dgl.ai/pages/start.html # The allowed python versions are from this: https://www.dgl.ai/pages/start.html
python_bin_allowlist = ( python_bin_allowlist = (
"python3.6", "python3.7", "python3.8", "python3.9", "python3", "python3.6",
"python3.7",
"python3.8",
"python3.9",
"python3",
# for backwards compatibility, accept python2 but technically DGL is a py3 library, so this is not recommended # for backwards compatibility, accept python2 but technically DGL is a py3 library, so this is not recommended
"python2.7", "python2", "python2.7",
"python2",
) )
# If none of the candidate python bins match, then we go with the default `python` # If none of the candidate python bins match, then we go with the default `python`
python_bin = "python" python_bin = "python"
@@ -268,7 +314,9 @@ def wrap_udf_in_torch_dist_launcher(
# python -m torch.distributed.launch [DIST TORCH ARGS] path/to/dist_trainer.py arg0 arg1 # python -m torch.distributed.launch [DIST TORCH ARGS] path/to/dist_trainer.py arg0 arg1
# Note: if there are multiple python commands in `udf_command`, this may do the Wrong Thing, eg launch each # Note: if there are multiple python commands in `udf_command`, this may do the Wrong Thing, eg launch each
# python command within the torch distributed launcher. # python command within the torch distributed launcher.
new_udf_command = udf_command.replace(python_bin, f"{python_bin} {torch_dist_cmd}") new_udf_command = udf_command.replace(
python_bin, f"{python_bin} {torch_dist_cmd}"
)
return new_udf_command return new_udf_command
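# --- Editor's note (not part of the commit): a standalone sketch of the substitution done
# above, where the detected python binary inside the user command is replaced with
# "<python> -m torch.distributed.launch <dist args>". The parameter values are illustrative.
torch_dist_cmd = (
    "-m torch.distributed.launch "
    "--nproc_per_node=8 --nnodes=2 --node_rank=0 "
    "--master_addr=172.31.0.1 --master_port=1234"
)
example_udf_command = "python3 train_dist.py --num_epochs 10"
wrapped = example_udf_command.replace("python3", "python3 " + torch_dist_cmd)
print(wrapped)
# -> python3 -m torch.distributed.launch --nproc_per_node=8 ... train_dist.py --num_epochs 10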
@@ -425,6 +473,7 @@ def wrap_cmd_with_local_envvars(cmd: str, env_vars: str) -> str:
# https://stackoverflow.com/a/45993803 # https://stackoverflow.com/a/45993803
return f"(export {env_vars}; {cmd})" return f"(export {env_vars}; {cmd})"
def wrap_cmd_with_extra_envvars(cmd: str, env_vars: list) -> str: def wrap_cmd_with_extra_envvars(cmd: str, env_vars: list) -> str:
"""Wraps a CLI command with extra env vars """Wraps a CLI command with extra env vars
@@ -448,6 +497,7 @@ def wrap_cmd_with_extra_envvars(cmd: str, env_vars: list) -> str:
g_monitor_file = None g_monitor_file = None
g_group_id = 0 g_group_id = 0
def has_alive_servers(args): def has_alive_servers(args):
"""Check whether there exists alive servers. """Check whether there exists alive servers.
@@ -467,23 +517,32 @@ def has_alive_servers(args):
return False return False
global g_monitor_file global g_monitor_file
global g_group_id global g_group_id
monitor_file = '/tmp/dgl_dist_monitor_' + args.server_name monitor_file = "/tmp/dgl_dist_monitor_" + args.server_name
from filelock import FileLock from filelock import FileLock
lock = FileLock(monitor_file + '.lock')
lock = FileLock(monitor_file + ".lock")
with lock: with lock:
next_group_id = None next_group_id = None
ret = os.path.exists(monitor_file) ret = os.path.exists(monitor_file)
if ret: if ret:
print("Monitor file for alive servers already exist: {}.".format(monitor_file)) print(
lines = [line.rstrip('\n') for line in open(monitor_file)] "Monitor file for alive servers already exist: {}.".format(
monitor_file
)
)
lines = [line.rstrip("\n") for line in open(monitor_file)]
g_group_id = int(lines[0]) g_group_id = int(lines[0])
next_group_id = g_group_id + 1 next_group_id = g_group_id + 1
if not ret and args.keep_alive: if not ret and args.keep_alive:
next_group_id = 1 next_group_id = 1
print("Monitor file for alive servers is created: {}.".format(monitor_file)) print(
"Monitor file for alive servers is created: {}.".format(
monitor_file
)
)
g_monitor_file = monitor_file g_monitor_file = monitor_file
if next_group_id is not None: if next_group_id is not None:
with open(monitor_file, 'w') as f: with open(monitor_file, "w") as f:
f.write(str(next_group_id)) f.write(str(next_group_id))
return ret return ret
@@ -494,14 +553,24 @@ def clean_alive_servers():
try: try:
if g_monitor_file is not None: if g_monitor_file is not None:
os.remove(g_monitor_file) os.remove(g_monitor_file)
os.remove(g_monitor_file + '.lock') os.remove(g_monitor_file + ".lock")
print("Monitor file for alive servers is removed: {}.".format(g_monitor_file)) print(
"Monitor file for alive servers is removed: {}.".format(
g_monitor_file
)
)
except: except:
print("Failed to delete monitor file for alive servers: {}.".format(g_monitor_file)) print(
"Failed to delete monitor file for alive servers: {}.".format(
g_monitor_file
)
)
def get_available_port(ip): def get_available_port(ip):
"""Get available port with specified ip.""" """Get available port with specified ip."""
import socket import socket
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
for port in range(1234, 65535): for port in range(1234, 65535):
try: try:
@@ -510,10 +579,13 @@ def get_available_port(ip):
return port return port
raise RuntimeError("Failed to get available port for ip~{}".format(ip)) raise RuntimeError("Failed to get available port for ip~{}".format(ip))
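# --- Editor's note (not part of the commit): the body of the try block above is elided by
# the diff. As an assumption only (not necessarily DGL's exact implementation), a common
# self-contained way to probe for a free port is to attempt a bind and catch the failure:
import socket

def find_free_port_sketch(ip="127.0.0.1", start=1234, end=65535):
    for port in range(start, end):
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            try:
                sock.bind((ip, port))  # succeeds only if the port is currently unused
                return port
            except OSError:
                continue
    raise RuntimeError("Failed to get available port for ip~{}".format(ip))

print(find_free_port_sketch())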
def submit_jobs(args, udf_command, dry_run=False): def submit_jobs(args, udf_command, dry_run=False):
"""Submit distributed jobs (server and client processes) via ssh""" """Submit distributed jobs (server and client processes) via ssh"""
if dry_run: if dry_run:
print("Currently it's in dry run mode which means no jobs will be launched.") print(
"Currently it's in dry run mode which means no jobs will be launched."
)
servers_cmd = [] servers_cmd = []
clients_cmd = [] clients_cmd = []
hosts = [] hosts = []
@@ -540,10 +612,11 @@ def submit_jobs(args, udf_command, dry_run=False):
part_config = os.path.join(args.workspace, args.part_config) part_config = os.path.join(args.workspace, args.part_config)
with open(part_config) as conf_f: with open(part_config) as conf_f:
part_metadata = json.load(conf_f) part_metadata = json.load(conf_f)
assert 'num_parts' in part_metadata, 'num_parts does not exist.' assert "num_parts" in part_metadata, "num_parts does not exist."
# The number of partitions must match the number of machines in the cluster. # The number of partitions must match the number of machines in the cluster.
assert part_metadata['num_parts'] == len(hosts), \ assert part_metadata["num_parts"] == len(
'The number of graph partitions has to match the number of machines in the cluster.' hosts
), "The number of graph partitions has to match the number of machines in the cluster."
state_q = queue.Queue() state_q = queue.Queue()
tot_num_clients = args.num_trainers * (1 + args.num_samplers) * len(hosts) tot_num_clients = args.num_trainers * (1 + args.num_samplers) * len(hosts)
@@ -564,11 +637,23 @@ def submit_jobs(args, udf_command, dry_run=False):
ip, _ = hosts[int(i / server_count_per_machine)] ip, _ = hosts[int(i / server_count_per_machine)]
server_env_vars_cur = f"{server_env_vars} DGL_SERVER_ID={i}" server_env_vars_cur = f"{server_env_vars} DGL_SERVER_ID={i}"
cmd = wrap_cmd_with_local_envvars(udf_command, server_env_vars_cur) cmd = wrap_cmd_with_local_envvars(udf_command, server_env_vars_cur)
cmd = wrap_cmd_with_extra_envvars(cmd, args.extra_envs) if len(args.extra_envs) > 0 else cmd cmd = (
cmd = 'cd ' + str(args.workspace) + '; ' + cmd wrap_cmd_with_extra_envvars(cmd, args.extra_envs)
if len(args.extra_envs) > 0
else cmd
)
cmd = "cd " + str(args.workspace) + "; " + cmd
servers_cmd.append(cmd) servers_cmd.append(cmd)
if not dry_run: if not dry_run:
thread_list.append(execute_remote(cmd, state_q, ip, args.ssh_port, username=args.ssh_username)) thread_list.append(
execute_remote(
cmd,
state_q,
ip,
args.ssh_port,
username=args.ssh_username,
)
)
else: else:
print(f"Use running server {args.server_name}.") print(f"Use running server {args.server_name}.")
@@ -580,7 +665,9 @@
ip_config=args.ip_config, ip_config=args.ip_config,
num_servers=args.num_servers, num_servers=args.num_servers,
graph_format=args.graph_format, graph_format=args.graph_format,
num_omp_threads=os.environ.get("OMP_NUM_THREADS", str(args.num_omp_threads)), num_omp_threads=os.environ.get(
"OMP_NUM_THREADS", str(args.num_omp_threads)
),
group_id=g_group_id, group_id=g_group_id,
pythonpath=os.environ.get("PYTHONPATH", ""), pythonpath=os.environ.get("PYTHONPATH", ""),
) )
@@ -596,31 +683,42 @@ def submit_jobs(args, udf_command, dry_run=False):
num_nodes=len(hosts), num_nodes=len(hosts),
node_rank=node_id, node_rank=node_id,
master_addr=master_addr, master_addr=master_addr,
master_port=master_port master_port=master_port,
)
cmd = wrap_cmd_with_local_envvars(
torch_dist_udf_command, client_env_vars
) )
cmd = wrap_cmd_with_local_envvars(torch_dist_udf_command, client_env_vars) cmd = (
cmd = wrap_cmd_with_extra_envvars(cmd, args.extra_envs) if len(args.extra_envs) > 0 else cmd wrap_cmd_with_extra_envvars(cmd, args.extra_envs)
cmd = 'cd ' + str(args.workspace) + '; ' + cmd if len(args.extra_envs) > 0
else cmd
)
cmd = "cd " + str(args.workspace) + "; " + cmd
clients_cmd.append(cmd) clients_cmd.append(cmd)
if not dry_run: if not dry_run:
thread_list.append(execute_remote(cmd, state_q, ip, args.ssh_port, username=args.ssh_username)) thread_list.append(
execute_remote(
cmd, state_q, ip, args.ssh_port, username=args.ssh_username
)
)
# return commands of clients/servers directly if in dry run mode # return commands of clients/servers directly if in dry run mode
if dry_run: if dry_run:
return clients_cmd, servers_cmd return clients_cmd, servers_cmd
# Start a cleanup process dedicated for cleaning up remote training jobs. # Start a cleanup process dedicated for cleaning up remote training jobs.
conn1,conn2 = multiprocessing.Pipe() conn1, conn2 = multiprocessing.Pipe()
func = partial(get_all_remote_pids, hosts, args.ssh_port, udf_command) func = partial(get_all_remote_pids, hosts, args.ssh_port, udf_command)
process = multiprocessing.Process(target=cleanup_proc, args=(func, conn1)) process = multiprocessing.Process(target=cleanup_proc, args=(func, conn1))
process.start() process.start()
def signal_handler(signal, frame): def signal_handler(signal, frame):
logging.info('Stop launcher') logging.info("Stop launcher")
# We need to tell the cleanup process to kill remote training jobs. # We need to tell the cleanup process to kill remote training jobs.
conn2.send('cleanup') conn2.send("cleanup")
clean_alive_servers() clean_alive_servers()
sys.exit(0) sys.exit(0)
signal.signal(signal.SIGINT, signal_handler) signal.signal(signal.SIGINT, signal_handler)
err = 0 err = 0
@@ -633,81 +731,144 @@ def submit_jobs(args, udf_command, dry_run=False):
err = err_code err = err_code
# The training processes complete. We should tell the cleanup process to exit. # The training processes complete. We should tell the cleanup process to exit.
conn2.send('exit') conn2.send("exit")
process.join() process.join()
if err != 0: if err != 0:
print("Task failed") print("Task failed")
sys.exit(-1) sys.exit(-1)
def main(): def main():
parser = argparse.ArgumentParser(description='Launch a distributed job') parser = argparse.ArgumentParser(description="Launch a distributed job")
parser.add_argument('--ssh_port', type=int, default=22, help='SSH Port.') parser.add_argument("--ssh_port", type=int, default=22, help="SSH Port.")
parser.add_argument( parser.add_argument(
"--ssh_username", default="", "--ssh_username",
default="",
help="Optional. When issuing commands (via ssh) to cluster, use the provided username in the ssh cmd. " help="Optional. When issuing commands (via ssh) to cluster, use the provided username in the ssh cmd. "
"Example: If you provide --ssh_username=bob, then the ssh command will be like: 'ssh bob@1.2.3.4 CMD' " "Example: If you provide --ssh_username=bob, then the ssh command will be like: 'ssh bob@1.2.3.4 CMD' "
"instead of 'ssh 1.2.3.4 CMD'" "instead of 'ssh 1.2.3.4 CMD'",
) )
parser.add_argument('--workspace', type=str, parser.add_argument(
help='Path of user directory of distributed tasks. \ "--workspace",
type=str,
help="Path of user directory of distributed tasks. \
This is used to specify a destination location where \ This is used to specify a destination location where \
the contents of current directory will be rsyncd') the contents of current directory will be rsyncd",
parser.add_argument('--num_trainers', type=int, )
help='The number of trainer processes per machine') parser.add_argument(
parser.add_argument('--num_omp_threads', type=int, "--num_trainers",
help='The number of OMP threads per trainer') type=int,
parser.add_argument('--num_samplers', type=int, default=0, help="The number of trainer processes per machine",
help='The number of sampler processes per trainer process') )
parser.add_argument('--num_servers', type=int, parser.add_argument(
help='The number of server processes per machine') "--num_omp_threads",
parser.add_argument('--part_config', type=str, type=int,
help='The file (in workspace) of the partition config') help="The number of OMP threads per trainer",
parser.add_argument('--ip_config', type=str, )
help='The file (in workspace) of IP configuration for server processes') parser.add_argument(
parser.add_argument('--num_server_threads', type=int, default=1, "--num_samplers",
help='The number of OMP threads in the server process. \ type=int,
default=0,
help="The number of sampler processes per trainer process",
)
parser.add_argument(
"--num_servers",
type=int,
help="The number of server processes per machine",
)
parser.add_argument(
"--part_config",
type=str,
help="The file (in workspace) of the partition config",
)
parser.add_argument(
"--ip_config",
type=str,
help="The file (in workspace) of IP configuration for server processes",
)
parser.add_argument(
"--num_server_threads",
type=int,
default=1,
help="The number of OMP threads in the server process. \
It should be small if server processes and trainer processes run on \ It should be small if server processes and trainer processes run on \
the same machine. By default, it is 1.') the same machine. By default, it is 1.",
parser.add_argument('--graph_format', type=str, default='csc', )
help='The format of the graph structure of each partition. \ parser.add_argument(
"--graph_format",
type=str,
default="csc",
help='The format of the graph structure of each partition. \
The allowed formats are csr, csc and coo. A user can specify multiple \ The allowed formats are csr, csc and coo. A user can specify multiple \
formats, separated by ",". For example, the graph format is "csr,csc".') formats, separated by ",". For example, the graph format is "csr,csc".',
parser.add_argument('--extra_envs', nargs='+', type=str, default=[], )
help='Extra environment parameters need to be set. For example, \ parser.add_argument(
"--extra_envs",
nargs="+",
type=str,
default=[],
help="Extra environment parameters need to be set. For example, \
you can set the LD_LIBRARY_PATH and NCCL_DEBUG by adding: \ you can set the LD_LIBRARY_PATH and NCCL_DEBUG by adding: \
--extra_envs LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH NCCL_DEBUG=INFO ') --extra_envs LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH NCCL_DEBUG=INFO ",
parser.add_argument('--keep_alive', action='store_true', help='Servers keep alive when clients exit') )
parser.add_argument('--server_name', type=str, parser.add_argument(
help='Used to check whether there exist alive servers') "--keep_alive",
action="store_true",
help="Servers keep alive when clients exit",
)
parser.add_argument(
"--server_name",
type=str,
help="Used to check whether there exist alive servers",
)
args, udf_command = parser.parse_known_args()
if args.keep_alive:
    assert (
        args.server_name is not None
    ), "Server name is required if '--keep_alive' is enabled."
    print("Servers will keep alive even clients exit...")
assert len(udf_command) == 1, "Please provide user command line."
assert (
    args.num_trainers is not None and args.num_trainers > 0
), "--num_trainers must be a positive number."
assert (
    args.num_samplers is not None and args.num_samplers >= 0
), "--num_samplers must be a non-negative number."
assert (
    args.num_servers is not None and args.num_servers > 0
), "--num_servers must be a positive number."
assert (
    args.num_server_threads > 0
), "--num_server_threads must be a positive number."
assert (
    args.workspace is not None
), "A user has to specify a workspace with --workspace."
assert (
    args.part_config is not None
), "A user has to specify a partition configuration file with --part_config."
assert (
    args.ip_config is not None
), "A user has to specify an IP configuration file with --ip_config."
if args.num_omp_threads is None:
    # Here we assume all machines have the same number of CPU cores as the machine
    # where the launch script runs.
    args.num_omp_threads = max(
        multiprocessing.cpu_count() // 2 // args.num_trainers, 1
    )
    print(
        "The number of OMP threads per trainer is set to",
        args.num_omp_threads,
    )
udf_command = str(udf_command[0])
if "python" not in udf_command:
    raise RuntimeError(
        "DGL launching script can only support Python executable file."
    )
submit_jobs(args, udf_command)


if __name__ == "__main__":
    fmt = "%(asctime)s %(levelname)s %(message)s"
    logging.basicConfig(format=fmt, level=logging.INFO)
    main()
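# A hypothetical invocation of this launcher (workspace, config files and the
# training command below are placeholders, not part of this diff); the user
# command is passed through after the launcher's own flags and must invoke a
# Python executable:
#
#   python3 launch.py \
#       --workspace /path/to/workspace \
#       --num_trainers 4 --num_samplers 2 --num_servers 1 \
#       --part_config data/graph.json --ip_config ip_config.txt \
#       "python3 train_dist.py --graph_name graph"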
# Requires setting PYTHONPATH=${GITROOT}/tools
import argparse
import json
import logging
import os
import sys

import numpy as np

from base import PartitionMeta, dump_partition_meta
from utils import array_readwriter, setdir


def _random_partition(metadata, num_parts):
    num_nodes_per_type = [sum(_) for _ in metadata["num_nodes_per_chunk"]]
    ntypes = metadata["node_type"]
    for ntype, n in zip(ntypes, num_nodes_per_type):
        logging.info("Generating partition for node type %s" % ntype)
        parts = np.random.randint(0, num_parts, (n,))
        array_readwriter.get_array_parser(name="csv").write(
            ntype + ".txt", parts
        )
def random_partition(metadata, num_parts, output_path):
    """
@@ -31,22 +34,28 @@ def random_partition(metadata, num_parts, output_path):
    """
    with setdir(output_path):
        _random_partition(metadata, num_parts)
        part_meta = PartitionMeta(
            version="1.0.0", num_parts=num_parts, algo_name="random"
        )
        dump_partition_meta(part_meta, "partition_meta.json")


# Run with PYTHONPATH=${GIT_ROOT_DIR}/tools
# where ${GIT_ROOT_DIR} is the directory to the DGL git repository.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--in_dir",
        type=str,
        help="input directory that contains the metadata file",
    )
    parser.add_argument("--out_dir", type=str, help="output directory")
    parser.add_argument(
        "--num_partitions", type=int, help="number of partitions"
    )
    logging.basicConfig(level="INFO")
    args = parser.parse_args()
    with open(os.path.join(args.in_dir, "metadata.json")) as f:
        metadata = json.load(f)
    num_parts = args.num_partitions
    random_partition(metadata, num_parts, args.out_dir)
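# A hypothetical run of this partitioning script (the script name, paths and
# partition count are placeholders): it reads metadata.json from --in_dir and
# writes one <node_type>.txt file of per-node partition ids plus a
# partition_meta.json into --out_dir.
#
#   PYTHONPATH=${GIT_ROOT_DIR}/tools python3 random_partition.py \
#       --in_dir chunked_graph/ --out_dir partitions/ --num_partitions 4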
from . import array_readwriter
from .files import *
from . import csv, numpy_array
from .registry import get_array_parser, register_array_parser
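# The .registry module itself is not part of this diff. A minimal sketch of
# the decorator-based registry that register_array_parser/get_array_parser
# presumably implement (the real module may differ):
#
#   _PARSERS = {}
#
#   def register_array_parser(name):
#       def decorate(cls):
#           _PARSERS[name] = cls
#           return cls
#       return decorate
#
#   def get_array_parser(name, **kwargs):
#       return _PARSERS[name](**kwargs)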
import logging

import pandas as pd
import pyarrow
import pyarrow.csv

from .registry import register_array_parser


@register_array_parser("csv")
class CSVArrayParser(object):
    def __init__(self, delimiter=","):
        self.delimiter = delimiter

    def read(self, path):
        logging.info(
            "Reading from %s using CSV format with configuration %s"
            % (path, self.__dict__)
        )
        # do not read the first line as header
        read_options = pyarrow.csv.ReadOptions(autogenerate_column_names=True)
        parse_options = pyarrow.csv.ParseOptions(delimiter=self.delimiter)
        arr = pyarrow.csv.read_csv(
            path, read_options=read_options, parse_options=parse_options
        )
        logging.info("Done reading from %s" % path)
        return arr.to_pandas().to_numpy()

    def write(self, path, arr):
        logging.info(
            "Writing to %s using CSV format with configuration %s"
            % (path, self.__dict__)
        )
        write_options = pyarrow.csv.WriteOptions(
            include_header=False, delimiter=self.delimiter
        )
        arr = pyarrow.Table.from_pandas(pd.DataFrame(arr))
        pyarrow.csv.write_csv(arr, path, write_options=write_options)
        logging.info("Done writing to %s" % path)
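# A small usage sketch for the parser registered above (the file name is
# illustrative; only calls shown elsewhere in this diff are used). write()
# stores the array as a headerless CSV, and read() returns the values as a
# 2-D (rows x columns) NumPy array.
#
#   import numpy as np
#   from utils import array_readwriter
#
#   parts = np.random.randint(0, 4, (10,))
#   array_readwriter.get_array_parser(name="csv").write("parts.txt", parts)
#   arr = array_readwriter.get_array_parser(name="csv").read("parts.txt")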