Unverified Commit 61b6edab authored by kylasa, Committed by GitHub

[DistDGL][Mem_Optimizations]get_partition_ids, service provided by the distributed lookup service has high memory footprint (#5226)

* get_partition_ids, a service provided by the distributed lookup service, has a high memory footprint

The 'get_partition_ids' function, which retrieves the owner processes for a given list of global node ids, has a high memory footprint: currently on the order of 8x the size of the input list.

For massively large datasets, these memory needs are unrealistic and may result in OOM. In the case of CoreGraph, when retrieving the owners of an edge list with 6 billion edges, the memory needed can be as high as 6 * 8 * 8 = 384 GB (6 billion ids x 8 bytes per id x the ~8x overhead).

To limit the amount of memory used by this function, we split the message sent to the distributed lookup service so that each message carries at most 200 million global node ids. This reduces the peak memory footprint of the entire function to no more than 0.2 * 8 * 8 ≈ 13 GB, which is within reasonable limits (see the sketch after this list).

Since we now send multiple small messages to the distributed lookup service instead of one large message, this may consume more wall-clock time than the earlier implementation.

* lintrunner patch.

* using np.ceil() per suggestion.

* converting the output of np.ceil() to ints.
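
To make the arithmetic above concrete, here is a rough back-of-the-envelope sketch. This is not DistDGL code: it assumes 8-byte int64 node ids and the roughly 8x peak overhead quoted above, and the per-rank request sizes in `all_sizes` are made-up numbers used only to illustrate how the number of alltoallv rounds is derived from the largest request across ranks.

```python
import numpy as np

BYTES_PER_ID = 8                  # int64 global node ids
PEAK_FACTOR = 8                   # observed ~8x footprint of get_partition_ids
CHUNK_SIZE = 200 * 1000 * 1000    # max global node ids per message (this change)

def peak_gb(num_ids):
    """Rough peak-memory estimate for looking up `num_ids` owners at once."""
    return num_ids * BYTES_PER_ID * PEAK_FACTOR / 1e9

# Single-shot lookup of a 6-billion-entry edge list: hundreds of GB.
print(f"single shot : {peak_gb(6_000_000_000):.0f} GB")   # ~384 GB

# With chunking, each round handles at most CHUNK_SIZE ids: ~13 GB peak.
print(f"per chunk   : {peak_gb(CHUNK_SIZE):.1f} GB")      # ~12.8 GB

# The split count is derived from the LARGEST per-rank request size across
# all ranks (gathered via allgather_sizes in the actual change), so every
# rank performs the same number of collective alltoallv rounds.
all_sizes = np.array([6_000_000_000, 4_500_000_000, 5_200_000_000])  # hypothetical
num_splits = int(np.ceil(np.amax(all_sizes) / CHUNK_SIZE))
print(f"splits      : {num_splits}")                      # 30 rounds
```

The numbers use decimal GB; only the order of magnitude matters here.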
parent 99937422
@@ -1000,6 +1000,7 @@ def gen_dist_partitions(rank, world_size, params):
         id_map,
         rank,
         world_size,
+        params.num_parts,
     )
     ntypes_ntypeid_map, ntypes, ntypeid_ntypes_map = get_node_types(schema_map)
@@ -5,10 +5,10 @@ import os
 import numpy as np
 import pyarrow
 import torch
-from gloo_wrapper import alltoallv_cpu
+from gloo_wrapper import allgather_sizes, alltoallv_cpu
 from pyarrow import csv
-from utils import map_partid_rank
+from utils import map_partid_rank, memory_snapshot


 class DistLookupService:
@@ -45,9 +45,13 @@ class DistLookupService:
         integer indicating the rank of a given process
     world_size : integer
         integer indicating the total no. of processes
+    num_parts : integer
+        integer representing the no. of partitions
     """

-    def __init__(self, input_dir, ntype_names, id_map, rank, world_size):
+    def __init__(
+        self, input_dir, ntype_names, id_map, rank, world_size, num_parts
+    ):
         assert os.path.isdir(input_dir)
         assert ntype_names is not None
         assert len(ntype_names) > 0
@@ -113,8 +117,9 @@ class DistLookupService:
         self.ntype_count = np.array(ntype_count, dtype=np.int64)
         self.rank = rank
         self.world_size = world_size
+        self.num_parts = num_parts

-    def get_partition_ids(self, global_nids):
+    def get_partition_ids(self, agg_global_nids):
         """
         This function is used to get the partition-ids for a given set of global node ids
@@ -136,9 +141,10 @@ class DistLookupService:
         -----------
         self : instance of this class
             instance of this class, which is passed by the runtime implicitly
-        global_nids : numpy array
-            an array of global node-ids for which partition-ids are to be retrieved by
-            the distributed lookup service.
+        agg_global_nids : numpy array
+            an array of aggregated global node-ids for which partition-ids are
+            to be retrieved by the distributed lookup service.

         Returns:
         --------
@@ -146,6 +152,30 @@ class DistLookupService:
             list of integers, which are the partition-ids of the global-node-ids (which is the
             function argument)
         """

-        # Find the process where global_nid --> partition-id(owner) is stored.
-        ntype_ids, type_nids = self.id_map(global_nids)
+        CHUNK_SIZE = 200 * 1000 * 1000
+
+        # Determine the no. of times each process has to send alltoall messages.
+        local_rows = agg_global_nids.shape[0]
+        all_sizes = allgather_sizes(
+            [local_rows], self.world_size, self.num_parts, return_sizes=True
+        )
+        max_count = np.amax(all_sizes)
+        num_splits = np.ceil(max_count / CHUNK_SIZE).astype(np.uint16)
+        LOCAL_CHUNK_SIZE = np.ceil(local_rows / num_splits).astype(np.int64)
+        agg_partition_ids = []
+        logging.info(
+            f"[Rank: {self.rank}] BatchSize: {CHUNK_SIZE}, \
+            max_count: {max_count}, \
+            splits: {num_splits}, \
+            rows: {agg_global_nids.shape}, \
+            local batch_size: {LOCAL_CHUNK_SIZE}"
+        )
+
+        for split in range(num_splits):
+            # Compute the global_nids for this iteration
+            global_nids = agg_global_nids[
+                split * LOCAL_CHUNK_SIZE : (split + 1) * LOCAL_CHUNK_SIZE
+            ]
+            # Find the process where global_nid --> partition-id(owner) is stored.
+            ntype_ids, type_nids = self.id_map(global_nids)
@@ -157,7 +187,9 @@ class DistLookupService:
             # The no. of these mappings stored by each process, in the lookup service, are
             # equally split among all the processes in the lookup service, deterministically.
             typeid_counts = self.ntype_count[ntype_ids]
-        chunk_sizes = np.ceil(typeid_counts / self.world_size).astype(np.int64)
+            chunk_sizes = np.ceil(typeid_counts / self.world_size).astype(
+                np.int64
+            )
             service_owners = np.floor_divide(type_nids, chunk_sizes).astype(
                 np.int64
             )
@@ -178,7 +210,8 @@ class DistLookupService:
                 indices_list.append(idxes[0])

             assert len(np.concatenate(indices_list)) == len(global_nids)
             assert np.all(
-            np.sort(np.concatenate(indices_list)) == np.arange(len(global_nids))
+                np.sort(np.concatenate(indices_list))
+                == np.arange(len(global_nids))
             )

             # Send the request to everyone else.
@@ -186,7 +219,9 @@ class DistLookupService:
             # from all the other processes.
             # These lists are global-node-ids whose global-node-ids <-> partition-id mappings
             # are owned/stored by the current process
-        owner_req_list = alltoallv_cpu(self.rank, self.world_size, send_list)
+            owner_req_list = alltoallv_cpu(
+                self.rank, self.world_size, send_list
+            )

             # Create the response list here for each of the request list received in the previous
             # step. Populate the respective partition-ids in this response lists appropriately
@@ -213,12 +248,18 @@ class DistLookupService:
                     if len(global_type_nids) <= 0:
                         continue

-                local_type_nids = global_type_nids - self.type_nid_begin[tid]
+                    local_type_nids = (
+                        global_type_nids - self.type_nid_begin[tid]
+                    )
                     assert np.all(local_type_nids >= 0)
                     assert np.all(
                         local_type_nids
-                    <= (self.type_nid_end[tid] + 1 - self.type_nid_begin[tid])
+                        <= (
+                            self.type_nid_end[tid]
+                            + 1
+                            - self.type_nid_begin[tid]
+                        )
                     )

                     cur_owners = self.partid_list[tid][local_type_nids]
@@ -235,7 +276,9 @@ class DistLookupService:
                 out_list.append(torch.from_numpy(lookups))

             # Send the partition-ids to their respective requesting processes.
-        owner_resp_list = alltoallv_cpu(self.rank, self.world_size, out_list)
+            owner_resp_list = alltoallv_cpu(
+                self.rank, self.world_size, out_list
+            )

             # Owner_resp_list, is a list of lists of numpy arrays where each list
             # is a list of partition-ids which the current process requested
@@ -255,8 +298,15 @@ class DistLookupService:
             global_nids_order = global_nids_order[sort_order_idx]
             assert np.all(np.arange(len(global_nids)) == global_nids_order)

+            # Store the partition-ids for the current split
+            agg_partition_ids.append(owner_ids)
+
+        # Stitch the list of partition-ids and return to the caller
+        agg_partition_ids = np.concatenate(agg_partition_ids)
+        assert agg_global_nids.shape[0] == agg_partition_ids.shape[0]
+
         # Now the owner_ids (partition-ids) which corresponding to the global_nids.
-        return owner_ids
+        return agg_partition_ids

     def get_shuffle_nids(
         self, global_nids, my_global_nids, my_shuffle_global_nids, world_size
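
For intuition, the snippet below is a self-contained toy check, not DistDGL code: the real owner lookup resolves partition ids through the distributed lookup tables and alltoallv exchanges, while here a trivial `nid % NUM_PARTS` stand-in is used, and `CHUNK_SIZE` is shrunk from the 200 million production value. It only demonstrates the invariant the final asserts rely on: slicing the aggregated id array into chunks and concatenating the per-chunk results is equivalent to a single-shot lookup.

```python
import numpy as np

NUM_PARTS = 4
CHUNK_SIZE = 250_000   # toy stand-in for the 200M production value

def lookup_chunk(global_nids):
    # Toy owner function; the real service resolves owners via the
    # distributed lookup tables and alltoallv exchanges.
    return global_nids % NUM_PARTS

rng = np.random.default_rng(0)
agg_global_nids = rng.integers(0, 10**9, size=1_000_003, dtype=np.int64)

num_splits = int(np.ceil(agg_global_nids.shape[0] / CHUNK_SIZE))
agg_partition_ids = []
for split in range(num_splits):
    chunk = agg_global_nids[split * CHUNK_SIZE : (split + 1) * CHUNK_SIZE]
    agg_partition_ids.append(lookup_chunk(chunk))
agg_partition_ids = np.concatenate(agg_partition_ids)

# Chunked processing must return one partition id per input id, in order,
# and match the single-shot result.
assert agg_partition_ids.shape[0] == agg_global_nids.shape[0]
assert np.array_equal(agg_partition_ids, lookup_chunk(agg_global_nids))
```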