Unverified commit ea4d9e83, authored by Rhett Ying and committed by GitHub
Browse files

[Dist] fix etype issue in dist part pipeline (#4754)

* [Dist] fix etype issue in dist part pipeline

* add comments
parent ec17136e
...@@ -65,6 +65,7 @@ def create_chunked_dataset( ...@@ -65,6 +65,7 @@ def create_chunked_dataset(
paper_label = np.random.choice(num_classes, num_papers) paper_label = np.random.choice(num_classes, num_papers)
paper_year = np.random.choice(2022, num_papers) paper_year = np.random.choice(2022, num_papers)
paper_orig_ids = np.arange(0, num_papers) paper_orig_ids = np.arange(0, num_papers)
writes_orig_ids = np.arange(0, g.num_edges('writes'))
# masks. # masks.
if include_masks: if include_masks:
...@@ -93,26 +94,38 @@ def create_chunked_dataset( ...@@ -93,26 +94,38 @@ def create_chunked_dataset(
paper_feat_path = os.path.join(input_dir, 'paper/feat.npy') paper_feat_path = os.path.join(input_dir, 'paper/feat.npy')
with open(paper_feat_path, 'wb') as f: with open(paper_feat_path, 'wb') as f:
np.save(f, paper_feat) np.save(f, paper_feat)
g.nodes['paper'].data['feat'] = torch.from_numpy(paper_feat)
paper_label_path = os.path.join(input_dir, 'paper/label.npy') paper_label_path = os.path.join(input_dir, 'paper/label.npy')
with open(paper_label_path, 'wb') as f: with open(paper_label_path, 'wb') as f:
np.save(f, paper_label) np.save(f, paper_label)
g.nodes['paper'].data['label'] = torch.from_numpy(paper_label)
paper_year_path = os.path.join(input_dir, 'paper/year.npy') paper_year_path = os.path.join(input_dir, 'paper/year.npy')
with open(paper_year_path, 'wb') as f: with open(paper_year_path, 'wb') as f:
np.save(f, paper_year) np.save(f, paper_year)
g.nodes['paper'].data['year'] = torch.from_numpy(paper_year)
paper_orig_ids_path = os.path.join(input_dir, 'paper/orig_ids.npy') paper_orig_ids_path = os.path.join(input_dir, 'paper/orig_ids.npy')
with open(paper_orig_ids_path, 'wb') as f: with open(paper_orig_ids_path, 'wb') as f:
np.save(f, paper_orig_ids) np.save(f, paper_orig_ids)
g.nodes['paper'].data['orig_ids'] = torch.from_numpy(paper_orig_ids)
cite_count_path = os.path.join(input_dir, 'cites/count.npy') cite_count_path = os.path.join(input_dir, 'cites/count.npy')
with open(cite_count_path, 'wb') as f: with open(cite_count_path, 'wb') as f:
np.save(f, cite_count) np.save(f, cite_count)
g.edges['cites'].data['count'] = torch.from_numpy(cite_count)
write_year_path = os.path.join(input_dir, 'writes/year.npy') write_year_path = os.path.join(input_dir, 'writes/year.npy')
with open(write_year_path, 'wb') as f: with open(write_year_path, 'wb') as f:
np.save(f, write_year) np.save(f, write_year)
g.edges['writes'].data['year'] = torch.from_numpy(write_year)
g.edges['rev_writes'].data['year'] = torch.from_numpy(write_year)
writes_orig_ids_path = os.path.join(input_dir, 'writes/orig_ids.npy')
with open(writes_orig_ids_path, 'wb') as f:
np.save(f, writes_orig_ids)
g.edges['writes'].data['orig_ids'] = torch.from_numpy(writes_orig_ids)
node_data = None node_data = None
if include_masks: if include_masks:
...@@ -193,7 +206,10 @@ def create_chunked_dataset( ...@@ -193,7 +206,10 @@ def create_chunked_dataset(
edge_data = { edge_data = {
'cites': {'count': cite_count_path}, 'cites': {'count': cite_count_path},
'writes': {'year': write_year_path}, 'writes': {
'year': write_year_path,
'orig_ids': writes_orig_ids_path
},
'rev_writes': {'year': write_year_path}, 'rev_writes': {'year': write_year_path},
} }
......
...@@ -58,7 +58,8 @@ def _verify_graph_feats( ...@@ -58,7 +58,8 @@ def _verify_graph_feats(
ndata = node_feats[ntype + "/" + name][local_nids] ndata = node_feats[ntype + "/" + name][local_nids]
assert torch.equal(ndata, true_feats) assert torch.equal(ndata, true_feats)
for etype in g.etypes: for c_etype in g.canonical_etypes:
etype = c_etype[1]
etype_id = g.get_etype_id(etype) etype_id = g.get_etype_id(etype)
inner_edge_mask = _get_inner_edge_mask(part, etype_id) inner_edge_mask = _get_inner_edge_mask(part, etype_id)
inner_eids = part.edata[dgl.EID][inner_edge_mask] inner_eids = part.edata[dgl.EID][inner_edge_mask]
...@@ -75,7 +76,7 @@ def _verify_graph_feats( ...@@ -75,7 +76,7 @@ def _verify_graph_feats(
continue continue
true_feats = g.edges[etype].data[name][orig_id] true_feats = g.edges[etype].data[name][orig_id]
edata = edge_feats[etype + "/" + name][local_eids] edata = edge_feats[etype + "/" + name][local_eids]
assert torch.equal(edata == true_feats) assert torch.equal(edata, true_feats)
@pytest.mark.parametrize("num_chunks", [1, 8]) @pytest.mark.parametrize("num_chunks", [1, 8])
...@@ -119,13 +120,17 @@ def test_chunk_graph(num_chunks): ...@@ -119,13 +120,17 @@ def test_chunk_graph(num_chunks):
# check node_data # check node_data
output_node_data_dir = os.path.join(output_dir, "node_data", "paper") output_node_data_dir = os.path.join(output_dir, "node_data", "paper")
for feat in ["feat", "label", "year"]: for feat in ["feat", "label", "year", "orig_ids"]:
feat_data = []
for i in range(num_chunks): for i in range(num_chunks):
chunk_f_name = "{}-{}.npy".format(feat, i) chunk_f_name = "{}-{}.npy".format(feat, i)
chunk_f_name = os.path.join(output_node_data_dir, chunk_f_name) chunk_f_name = os.path.join(output_node_data_dir, chunk_f_name)
assert os.path.isfile(chunk_f_name) assert os.path.isfile(chunk_f_name)
feat_array = np.load(chunk_f_name) feat_array = np.load(chunk_f_name)
assert feat_array.shape[0] == num_papers // num_chunks assert feat_array.shape[0] == num_papers // num_chunks
feat_data.append(feat_array)
feat_data = np.concatenate(feat_data, 0)
assert torch.equal(torch.from_numpy(feat_data), g.nodes['paper'].data[feat])
# check edge_data # check edge_data
num_edges = { num_edges = {
...@@ -137,15 +142,21 @@ def test_chunk_graph(num_chunks): ...@@ -137,15 +142,21 @@ def test_chunk_graph(num_chunks):
for etype, feat in [ for etype, feat in [
["paper:cites:paper", "count"], ["paper:cites:paper", "count"],
["author:writes:paper", "year"], ["author:writes:paper", "year"],
["author:writes:paper", "orig_ids"],
["paper:rev_writes:author", "year"], ["paper:rev_writes:author", "year"],
]: ]:
feat_data = []
output_edge_sub_dir = os.path.join(output_edge_data_dir, etype) output_edge_sub_dir = os.path.join(output_edge_data_dir, etype)
for i in range(num_chunks): for i in range(num_chunks):
chunk_f_name = "{}-{}.npy".format(feat, i) chunk_f_name = "{}-{}.npy".format(feat, i)
chunk_f_name = os.path.join(output_edge_sub_dir, chunk_f_name) chunk_f_name = os.path.join(output_edge_sub_dir, chunk_f_name)
assert os.path.isfile(chunk_f_name) assert os.path.isfile(chunk_f_name)
feat_array = np.load(chunk_f_name) feat_array = np.load(chunk_f_name)
assert feat_array.shape[0] == num_edges[etype] // num_chunks assert feat_array.shape[0] == num_edges[etype] // num_chunks
feat_data.append(feat_array)
feat_data = np.concatenate(feat_data, 0)
assert torch.equal(torch.from_numpy(feat_data),
g.edges[etype.split(':')[1]].data[feat])
@pytest.mark.parametrize("num_chunks", [1, 3, 8]) @pytest.mark.parametrize("num_chunks", [1, 3, 8])
......
...@@ -377,6 +377,15 @@ def write_edge_features(edge_features, edge_file): ...@@ -377,6 +377,15 @@ def write_edge_features(edge_features, edge_file):
edge_file : string edge_file : string
File in which the edge information is serialized File in which the edge information is serialized
""" """
# TODO[Rui]: Below is a temporary fix for etype and will be
# further refined in the near future as we'll shift to canonical
# etypes entirely.
def format_etype(etype):
    """Rewrite an edge-feature key so it uses the short edge type name.

    The key has the form ``"<etype>/<feature name>"``. When ``<etype>``
    is colon-separated (e.g. ``"author:writes:paper"``), only the second
    component (``"writes"``) is kept. A key whose etype contains no
    ``:`` is returned unchanged instead of raising ``IndexError``.

    Parameters
    ----------
    etype : str
        Edge-feature key, e.g. ``"author:writes:paper/year"``.

    Returns
    -------
    str
        The key with a short etype, e.g. ``"writes/year"``.
    """
    etype_part, feat_name = etype.split('/')
    # Only colon-separated etypes need shortening; a key that already
    # carries a plain etype has no second component to pick.
    if ':' in etype_part:
        etype_part = etype_part.split(':')[1]
    return etype_part + '/' + feat_name
edge_features = {format_etype(etype):
data for etype, data in edge_features.items()}
dgl.data.utils.save_tensors(edge_file, edge_features) dgl.data.utils.save_tensors(edge_file, edge_features)
def write_graph_dgl(graph_file, graph_obj, formats, sort_etypes): def write_graph_dgl(graph_file, graph_obj, formats, sort_etypes):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment