"git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "524e656dd696e91f3bf15054d964ba3e5716f226"
Unverified commit dccf1f16, authored by peizhou001, committed by GitHub

[Dist] remove dependency of load_partition_book in change tool (#4802)



* remove dependency of load_partition_book in change tool

* fix issue

* fix issue

Co-authored-by: Ubuntu <ubuntu@ip-172-31-16-19.ap-northeast-1.compute.internal>
parent 8ae50c42
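For context, the change tool rewrites an old-style partition configuration, whose "etypes" and "edge_map" entries are keyed by plain edge-type names, into the newer format keyed by canonical edge types ("src_ntype:etype:dst_ntype"). A minimal sketch of the single-node-type shortcut introduced in this commit, using made-up type names and IDs purely for illustration:

# Hypothetical old-style partition config with a single node type;
# type names and IDs are made up for illustration.
old_config = {
    "num_parts": 2,
    "ntypes": {"user": 0},
    "etypes": {"follows": 0, "likes": 1},
}

DELIMITER = ":"
ntype = list(old_config["ntypes"].keys())[0]
# With exactly one node type, every canonical etype must be
# (ntype, etype, ntype), so no graph data needs to be loaded at all.
canonical_etypes = {
    DELIMITER.join((ntype, etype, ntype)): eid
    for etype, eid in old_config["etypes"].items()
}
print(canonical_etypes)
# {'user:follows:user': 0, 'user:likes:user': 1}

For the multi-node-type case, the diff below removes the load_partition_book dependency by reading everything it needs from the configuration file itself plus the per-partition graph files.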
@@ -9,11 +9,13 @@ import torch
 import dgl
 from dgl._ffi.base import DGLError
 from dgl.data.utils import load_graphs
-from dgl.distributed import load_partition_book
+from dgl.utils import toindex
 
-etypes_key = "etypes"
-edge_map_key = "edge_map"
-canonical_etypes_delimiter = ":"
+ETYPES_KEY = "etypes"
+EDGE_MAP_KEY = "edge_map"
+NTYPES_KEY = "ntypes"
+NUM_PARTS_KEY = "num_parts"
+CANONICAL_ETYPE_DELIMITER = ":"
 
 
 def convert_conf(part_config):
@@ -22,95 +24,105 @@ def convert_conf(part_config):
         logging.info("Checking if the provided json file need to be changed.")
         if is_old_version(config):
             logging.info("Changing the partition configuration file.")
-            canonical_etypes = etype2canonical_etype(part_config)
-            # convert edge_map key from etype -> c_etype
+            canonical_etypes = {}
+            if len(config[NTYPES_KEY]) == 1:
+                ntype = list(config[NTYPES_KEY].keys())[0]
+                canonical_etypes = {
+                    CANONICAL_ETYPE_DELIMITER.join((ntype, etype, ntype)): eid
+                    for etype, eid in config[ETYPES_KEY].items()
+                }
+            else:
+                canonical_etypes = etype2canonical_etype(part_config, config)
+            reverse_c_etypes = {v: k for k, v in canonical_etypes.items()}
+            # Convert edge_map keys from etype -> c_etype.
             new_edge_map = {}
-            for e_type, range in config[edge_map_key].items():
-                eid = config[etypes_key][e_type]
-                c_etype = [
-                    key
-                    for key in canonical_etypes
-                    if canonical_etypes[key] == eid
-                ][0]
+            for e_type, range in config[EDGE_MAP_KEY].items():
+                eid = config[ETYPES_KEY][e_type]
+                c_etype = reverse_c_etypes[eid]
                 new_edge_map[c_etype] = range
-            config[edge_map_key] = new_edge_map
-            config[etypes_key] = canonical_etypes
+            config[EDGE_MAP_KEY] = new_edge_map
+            config[ETYPES_KEY] = canonical_etypes
            logging.info("Dumping the content to disk.")
             f.seek(0)
             json.dump(config, f, indent=4)
             f.truncate()
 
 
-def etype2canonical_etype(part_config):
-    gpb, _, _, etypes = load_partition_book(part_config=part_config, part_id=0)
-    eid = []
-    etype_id = []
-    for etype in etypes:
-        type_eid = torch.zeros((1,), dtype=torch.int64)
-        eid.append(gpb.map_to_homo_eid(type_eid, etype))
-        etype_id.append(etypes[etype])
-    eid = torch.cat(eid, 0)
-    etype_id = torch.IntTensor(etype_id)
-    partition_id = gpb.eid2partid(eid)
+def etype2canonical_etype(part_config, config):
+    num_parts = config[NUM_PARTS_KEY]
+    edge_map = config[EDGE_MAP_KEY]
+    etypes = list(edge_map.keys())
+    # Get part id of each seed edge.
+    partition_ids = []
+    for _, bound in edge_map.items():
+        for i in range(num_parts):
+            if bound[i][1] > bound[i][0]:
+                partition_ids.append(i)
+                break
+    partition_ids = torch.tensor(partition_ids)
+    # Get starting index of each partition.
+    shifts = []
+    for i in range(num_parts):
+        shifts.append(edge_map[etypes[0]][i][0])
+    shifts = torch.tensor(shifts)
     canonical_etypes = {}
     part_ids = [
-        part_id
-        for part_id in range(gpb.num_partitions())
-        if part_id in partition_id
+        part_id for part_id in range(num_parts) if part_id in partition_ids
     ]
     for part_id in part_ids:
-        seed_edges = torch.masked_select(eid, partition_id == part_id)
-        seed_edge_tids = torch.masked_select(etype_id, partition_id == part_id)
+        seed_etypes = [
+            etypes[i] for i in range(len(etypes)) if partition_ids[i] == part_id
+        ]
         c_etype = _find_c_etypes_in_partition(
-            seed_edges, seed_edge_tids, part_id, part_config
+            part_id,
+            seed_etypes,
+            config[ETYPES_KEY],
+            config[NTYPES_KEY],
+            edge_map,
+            shifts,
+            part_config,
         )
         canonical_etypes.update(c_etype)
     return canonical_etypes
 
 
 def _find_c_etypes_in_partition(
-    seed_edges, seed_edge_tids, part_id, part_config
+    part_id, seed_etypes, etypes, ntypes, edge_map, shifts, config_path
 ):
-    folder = os.path.dirname(os.path.realpath(part_config))
-    partition_book = {}
-    local_g = dgl.DGLGraph()
     try:
+        folder = os.path.dirname(os.path.realpath(config_path))
         local_g = load_graphs(f"{folder}/part{part_id}/graph.dgl")[0][0]
-        partition_book = load_partition_book(
-            part_config=part_config, part_id=part_id
-        )[0]
+        local_eids = [
+            edge_map[etype][part_id][0] - shifts[part_id]
+            for etype in seed_etypes
+        ]
+        local_eids = toindex(torch.tensor(local_eids))
+        local_eids = local_eids.tousertensor()
+        local_src, local_dst = local_g.find_edges(local_eids)
+        src_ntids, dst_ntids = (
+            local_g.ndata[dgl.NTYPE][local_src],
+            local_g.ndata[dgl.NTYPE][local_dst],
+        )
+        ntypes = {v: k for k, v in ntypes.items()}
+        src_ntypes = [ntypes[ntid.item()] for ntid in src_ntids]
+        dst_ntypes = [ntypes[ntid.item()] for ntid in dst_ntids]
+        c_etypes = list(zip(src_ntypes, seed_etypes, dst_ntypes))
+        c_etypes = [
+            CANONICAL_ETYPE_DELIMITER.join(c_etype) for c_etype in c_etypes
+        ]
+        return {k: etypes[v] for (k, v) in zip(c_etypes, seed_etypes)}
     except DGLError as e:
+        print(e)
         logging.fatal(
             f"Graph data of partition {part_id} is requested but not found."
         )
-    ntypes, etypes = partition_book.ntypes, partition_book.etypes
-    src, dst = _find_edges(local_g, partition_book, seed_edges)
-    src_tids, _ = partition_book.map_to_per_ntype(src)
-    dst_tids, _ = partition_book.map_to_per_ntype(dst)
-    canonical_etypes = {}
-    for src_tid, etype_id, dst_tid in zip(src_tids, seed_edge_tids, dst_tids):
-        src_tid = src_tid.item()
-        etype_id = etype_id.item()
-        dst_tid = dst_tid.item()
-        c_etype = (ntypes[src_tid], etypes[etype_id], ntypes[dst_tid])
-        canonical_etypes[canonical_etypes_delimiter.join(c_etype)] = etype_id
-    return canonical_etypes
-
-
-def _find_edges(local_g, partition_book, seed_edges):
-    local_eids = partition_book.eid2localeid(seed_edges, partition_book.partid)
-    local_src, local_dst = local_g.find_edges(local_eids)
-    global_nid_mapping = local_g.ndata[dgl.NID]
-    global_src = global_nid_mapping[local_src]
-    global_dst = global_nid_mapping[local_dst]
-    return global_src, global_dst
+        raise e
 
 
 def is_old_version(config):
-    first_etype = list(config[etypes_key].keys())[0]
-    etype_tuple = first_etype.split(canonical_etypes_delimiter)
+    first_etype = list(config[ETYPES_KEY].keys())[0]
+    etype_tuple = first_etype.split(CANONICAL_ETYPE_DELIMITER)
     return len(etype_tuple) == 1
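For readers following the new etype2canonical_etype / _find_c_etypes_in_partition path: instead of asking the partition book for eid2partid and eid2localeid, the tool now derives both from the per-partition [start, end) ranges in the config's edge_map. The sketch below reproduces that arithmetic on a made-up two-partition edge_map; the names and numbers are illustrative only, and it assumes, as the tool's code does, that edges within a partition are laid out edge type by edge type, so the first etype's range start is the partition's global offset.

import torch

# Illustrative edge_map: for each edge type, one [start, end) range of
# global edge IDs per partition (two partitions here); numbers are made up.
edge_map = {
    "follows": [[0, 10], [20, 25]],
    "likes": [[10, 20], [25, 40]],
}
num_parts = 2
etypes = list(edge_map.keys())

# Partition owning each etype's first ("seed") edge: the first partition
# whose range for that etype is non-empty.
partition_ids = []
for bound in edge_map.values():
    for i in range(num_parts):
        if bound[i][1] > bound[i][0]:
            partition_ids.append(i)
            break
partition_ids = torch.tensor(partition_ids)

# Starting global edge ID of each partition, taken from the first etype's
# range start in that partition.
shifts = torch.tensor([edge_map[etypes[0]][i][0] for i in range(num_parts)])

# Local edge ID of each seed edge inside its owning partition; these are
# the IDs passed to local_g.find_edges() to recover src/dst node types.
local_eids = {
    etype: edge_map[etype][pid][0] - shifts[pid].item()
    for etype, pid in zip(etypes, partition_ids.tolist())
}
print(partition_ids.tolist())  # [0, 0]
print(local_eids)              # {'follows': 0, 'likes': 10}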