"docs/source/vscode:/vscode.git/clone" did not exist on "f00cd6efbd00b0273f58c393a617415b5d1d410e"
Unverified commit 61b6edab authored by kylasa, committed by GitHub

[DistDGL][Mem_Optimizations]get_partition_ids, service provided by the distributed lookup service has high memory footprint (#5226)

* get_partition_ids, service provided by the distributed lookup service has high memory footprint

The 'get_partition_ids' function, which retrieves the owner processes for a given list of global node ids, has a high memory footprint: currently on the order of 8x the size of the input list.

For massively large datasets, these memory requirements are unrealistic and may result in OOM. In the case of CoreGraph, when retrieving the owners of an edge list with 6 billion edges, the memory requirement can be as high as 256 GB.

To limit the amount of memory used by this function, we split the message sent to the distributed lookup service so that each message carries at most 200 million global node ids. This reduces the peak memory footprint of the function to no more than 0.2 * 8 * 8 ≈ 13 GB, which is within reasonable limits.
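
As a rough check of that figure (a back-of-the-envelope sketch, assuming 8 bytes per int64 node id and the ~8x working-set overhead mentioned above):

```python
CHUNK_SIZE = 200 * 1000 * 1000  # global node ids per lookup round (from the patch)
BYTES_PER_ID = 8                # int64
OVERHEAD = 8                    # observed ~8x footprint relative to the input

peak_bytes = CHUNK_SIZE * BYTES_PER_ID * OVERHEAD
print(peak_bytes / 1e9)         # 12.8, i.e. the ~13 GB quoted above
```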

Since we now send multiple smaller messages to the distributed lookup service instead of one large message, this may cost more wall-clock time than the earlier implementation.
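
The patch below implements this inside DistLookupService.get_partition_ids. As a standalone illustration, here is a minimal sketch of the splitting logic, assuming NumPy arrays of int64 node ids; `all_rank_sizes` stands in for the result of the allgather of per-rank request sizes, and the per-chunk lookup itself is omitted:

```python
import numpy as np

CHUNK_SIZE = 200 * 1000 * 1000  # max global node ids per alltoall round


def split_into_rounds(agg_global_nids, all_rank_sizes):
    """Yield bounded-size slices of the local request, one per alltoall round.

    all_rank_sizes holds the request length of every rank (an allgather
    result); the number of rounds is driven by the largest request so that
    all ranks participate in the same number of collective calls.
    """
    local_rows = agg_global_nids.shape[0]
    # At least one round, even if every rank's request happens to be empty.
    num_splits = max(int(np.ceil(np.amax(all_rank_sizes) / CHUNK_SIZE)), 1)
    local_chunk = int(np.ceil(local_rows / num_splits))
    for split in range(num_splits):
        yield agg_global_nids[split * local_chunk : (split + 1) * local_chunk]
```

Driving the split count off the global maximum, rather than each rank's own request size, keeps the ranks in lock-step: a rank with a small (or empty) request still takes part in every alltoall round, which the collective requires.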

* lintrunner patch.

* using np.ceil() per suggestion.

* converting the output of np.ceil() to ints (see the note below).
parent 99937422
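
For reference, np.ceil() returns floating-point values even for integer inputs, which is why the patch casts its results before using them as counts and slice bounds:

```python
import numpy as np

np.ceil(7 / 2)                                   # 4.0 (float64), not usable as a slice bound
int(np.ceil(7 / 2))                              # 4
np.ceil(np.array([10, 7]) / 4).astype(np.int64)  # array([3, 2])
```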
@@ -1000,6 +1000,7 @@ def gen_dist_partitions(rank, world_size, params):
id_map,
rank,
world_size,
params.num_parts,
)
ntypes_ntypeid_map, ntypes, ntypeid_ntypes_map = get_node_types(schema_map)
......
@@ -5,10 +5,10 @@ import os
import numpy as np
import pyarrow
import torch
from gloo_wrapper import alltoallv_cpu
from gloo_wrapper import allgather_sizes, alltoallv_cpu
from pyarrow import csv
from utils import map_partid_rank
from utils import map_partid_rank, memory_snapshot
class DistLookupService:
@@ -45,9 +45,13 @@ class DistLookupService:
integer indicating the rank of a given process
world_size : integer
integer indicating the total no. of processes
num_parts : integer
integer representing the no. of partitions
"""
def __init__(self, input_dir, ntype_names, id_map, rank, world_size):
def __init__(
self, input_dir, ntype_names, id_map, rank, world_size, num_parts
):
assert os.path.isdir(input_dir)
assert ntype_names is not None
assert len(ntype_names) > 0
@@ -113,8 +117,9 @@ class DistLookupService:
self.ntype_count = np.array(ntype_count, dtype=np.int64)
self.rank = rank
self.world_size = world_size
self.num_parts = num_parts
def get_partition_ids(self, global_nids):
def get_partition_ids(self, agg_global_nids):
"""
This function is used to get the partition-ids for a given set of global node ids
@@ -136,9 +141,10 @@ class DistLookupService:
-----------
self : instance of this class
instance of this class, which is passed by the runtime implicitly
global_nids : numpy array
an array of global node-ids for which partition-ids are to be retrieved by
the distributed lookup service.
agg_global_nids : numpy array
an array of aggregated global node-ids for which partition-ids are
to be retrieved by the distributed lookup service.
Returns:
--------
@@ -146,6 +152,30 @@ class DistLookupService:
list of integers, which are the partition-ids of the global-node-ids (which is the
function argument)
"""
CHUNK_SIZE = 200 * 1000 * 1000
# Determine the no. of times each process has to send alltoall messages.
local_rows = agg_global_nids.shape[0]
all_sizes = allgather_sizes(
[local_rows], self.world_size, self.num_parts, return_sizes=True
)
max_count = np.amax(all_sizes)
num_splits = np.ceil(max_count / CHUNK_SIZE).astype(np.uint16)
LOCAL_CHUNK_SIZE = np.ceil(local_rows / num_splits).astype(np.int64)
agg_partition_ids = []
logging.info(
f"[Rank: {self.rank}] BatchSize: {CHUNK_SIZE}, \
max_count: {max_count}, \
splits: {num_splits}, \
rows: {agg_global_nids.shape}, \
local batch_size: {LOCAL_CHUNK_SIZE}"
)
for split in range(num_splits):
# Compute the global_nids for this iteration
global_nids = agg_global_nids[
split * LOCAL_CHUNK_SIZE : (split + 1) * LOCAL_CHUNK_SIZE
]
# Find the process where global_nid --> partition-id(owner) is stored.
ntype_ids, type_nids = self.id_map(global_nids)
@@ -157,7 +187,9 @@ class DistLookupService:
# The no. of these mappings stored by each process, in the lookup service, are
# equally split among all the processes in the lookup service, deterministically.
typeid_counts = self.ntype_count[ntype_ids]
chunk_sizes = np.ceil(typeid_counts / self.world_size).astype(np.int64)
chunk_sizes = np.ceil(typeid_counts / self.world_size).astype(
np.int64
)
service_owners = np.floor_divide(type_nids, chunk_sizes).astype(
np.int64
)
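
The hunk above determines, for each requested node, which rank of the lookup service stores its global-nid -> partition-id entry. A minimal standalone sketch of that deterministic block partitioning (function and variable names here are illustrative, not the module's API), assuming per-type node counts are known:

```python
import numpy as np


def find_service_owners(ntype_ids, type_nids, ntype_count, world_size):
    """Locate the rank that stores each node's partition-id mapping, assuming
    each node type's mappings are split into equal contiguous blocks."""
    typeid_counts = ntype_count[ntype_ids]  # total node count of each node's type
    chunk_sizes = np.ceil(typeid_counts / world_size).astype(np.int64)
    return np.floor_divide(type_nids, chunk_sizes).astype(np.int64)


# Example: 10 nodes of type 0 spread over 4 ranks -> blocks of ceil(10/4) = 3:
# per-type ids 0-2 live on rank 0, 3-5 on rank 1, 6-8 on rank 2, 9 on rank 3.
print(find_service_owners(np.array([0, 0, 0]), np.array([0, 4, 9]),
                          np.array([10]), world_size=4))  # [0 1 3]
```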
@@ -178,7 +210,8 @@ class DistLookupService:
indices_list.append(idxes[0])
assert len(np.concatenate(indices_list)) == len(global_nids)
assert np.all(
np.sort(np.concatenate(indices_list)) == np.arange(len(global_nids))
np.sort(np.concatenate(indices_list))
== np.arange(len(global_nids))
)
# Send the request to everyone else.
@@ -186,7 +219,9 @@ class DistLookupService:
# from all the other processes.
# These lists are global-node-ids whose global-node-ids <-> partition-id mappings
# are owned/stored by the current process
owner_req_list = alltoallv_cpu(self.rank, self.world_size, send_list)
owner_req_list = alltoallv_cpu(
self.rank, self.world_size, send_list
)
# Create the response list here for each of the request list received in the previous
# step. Populate the respective partition-ids in this response lists appropriately
@@ -213,12 +248,18 @@ class DistLookupService:
if len(global_type_nids) <= 0:
continue
local_type_nids = global_type_nids - self.type_nid_begin[tid]
local_type_nids = (
global_type_nids - self.type_nid_begin[tid]
)
assert np.all(local_type_nids >= 0)
assert np.all(
local_type_nids
<= (self.type_nid_end[tid] + 1 - self.type_nid_begin[tid])
<= (
self.type_nid_end[tid]
+ 1
- self.type_nid_begin[tid]
)
)
cur_owners = self.partid_list[tid][local_type_nids]
@@ -235,7 +276,9 @@ class DistLookupService:
out_list.append(torch.from_numpy(lookups))
# Send the partition-ids to their respective requesting processes.
owner_resp_list = alltoallv_cpu(self.rank, self.world_size, out_list)
owner_resp_list = alltoallv_cpu(
self.rank, self.world_size, out_list
)
# Owner_resp_list, is a list of lists of numpy arrays where each list
# is a list of partition-ids which the current process requested
@@ -255,8 +298,15 @@ class DistLookupService:
global_nids_order = global_nids_order[sort_order_idx]
assert np.all(np.arange(len(global_nids)) == global_nids_order)
# Store the partition-ids for the current split
agg_partition_ids.append(owner_ids)
# Stitch the list of partition-ids and return to the caller
agg_partition_ids = np.concatenate(agg_partition_ids)
assert agg_global_nids.shape[0] == agg_partition_ids.shape[0]
# Now the owner_ids (partition-ids) correspond to the global_nids.
return owner_ids
return agg_partition_ids
def get_shuffle_nids(
self, global_nids, my_global_nids, my_shuffle_global_nids, world_size
......
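
For reference, the final hunks above reassemble the per-rank responses into the original request order (the code concatenates the saved index lists and verifies the ordering with an argsort). A simplified, equivalent scatter-based sketch, with illustrative names and assuming NumPy arrays:

```python
import numpy as np


def restore_request_order(owner_resp_list, indices_list, num_requests):
    """Scatter per-rank responses back to their original request positions.

    owner_resp_list[i]: partition-ids returned by rank i, in the order of the
    ids we sent to rank i; indices_list[i]: the positions those ids occupied
    in the original (chunk-local) request array.
    """
    owner_ids = np.empty(num_requests, dtype=np.int64)
    for resp, idx in zip(owner_resp_list, indices_list):
        owner_ids[idx] = resp
    return owner_ids


# Example: 5 requested node ids answered by 2 service ranks.
resp = [np.array([10, 11]), np.array([20, 21, 22])]
idx = [np.array([0, 3]), np.array([1, 2, 4])]
print(restore_request_order(resp, idx, 5))  # [10 20 21 11 22]
```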