Unverified Commit 7f927939 authored by nv-dlasalle, committed by GitHub

[Feature][DistDGL] Add NCCL support for range based partitions (#3213)

* Implement range based NDArrayPartition

* Finish implement range based partition support

* Add unit test

* Fix whitespace

* Add Kernel suffix

* Fix argument passing

* Add doxygen docs and improve variable naming

* Add unit test

* Add function for converting a partition book

* Add example to partition_op docs

* Fix dtype conversion for mxnet and tensorflow
parent 75fca8e4
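As a quick illustration of the new range-based mode, here is a minimal single-rank sketch mirroring the unit tests added in this change (the `dgl.cuda.nccl` import path and the PyTorch backend are assumptions for the example, not part of this diff):

    import torch
    from dgl.cuda import nccl                        # assumed import path
    from dgl.partition import NDArrayPartition

    comm = nccl.Communicator(1, 0, nccl.UniqueId())  # single-rank communicator

    num_rows = 10000
    value = torch.rand((num_rows, 100), device='cuda')
    index = torch.randint(0, num_rows, (1000,), device='cuda')

    # exclusive prefix-sum of the rows owned by each partition (one part here)
    part_ranges = torch.tensor([0, num_rows], dtype=torch.int64, device='cuda')
    part = NDArrayPartition(num_rows, 1, mode='range', part_ranges=part_ranges)

    # with a single rank this is an identity exchange
    recv_index, recv_value = comm.sparse_all_to_all_push(index, value, part)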
...@@ -10,6 +10,7 @@ from .. import utils
from .shared_mem_utils import _to_shared_mem, _get_ndata_path, _get_edata_path, DTYPE_DICT
from .._ffi.ndarray import empty_shared_mem
from ..ndarray import exist_shared_mem_array
from ..partition import NDArrayPartition
from .id_map import IdMap

def _move_metadata_to_shared_mem(graph_name, num_nodes, num_edges, part_id,
...@@ -144,6 +145,38 @@ def get_shared_mem_partition_book(graph_name, graph_part):
    else:
        return BasicPartitionBook(part_id, num_parts, node_map_data, edge_map_data, graph_part)
def get_node_partition_from_book(book, device):
""" Get an NDArrayPartition of the nodes from a RangePartitionBook.
Parameters
----------
book : RangePartitionBook
The partition book to extract the node partition from.
device : Device context object.
        The device on which the node partition is to be used.
Returns
-------
    NDArrayPartition
The NDArrayPartition object for the nodes in the graph.
"""
assert isinstance(book, RangePartitionBook), "Can only convert " \
"RangePartitionBook to NDArrayPartition."
# create prefix-sum array on host
max_node_ids = F.zerocopy_from_numpy(book._max_node_ids)
cpu_range = F.cat([F.tensor([0], dtype=F.dtype(max_node_ids)),
max_node_ids+1], dim=0)
gpu_range = F.copy_to(cpu_range, ctx=device)
# convert from numpy
array_size = int(F.as_scalar(cpu_range[-1]))
num_parts = book.num_partitions()
return NDArrayPartition(array_size,
num_parts,
mode='range',
part_ranges=gpu_range)
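# Illustrative usage (hypothetical names, not part of this module): given a
# DistGraph 'g' backed by a RangePartitionBook and a GPU device 'dev',
#
#   book = g.get_partition_book()
#   node_part = get_node_partition_from_book(book, dev)
#
# 'node_part' can then be handed to the NCCL sparse all-to-all operations.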
class GraphPartitionBook(ABC):
    """ The base class of the graph partition book.
......
...@@ -6,6 +6,7 @@ from ._ffi.function import _init_api
from .heterograph import DGLHeteroGraph
from . import backend as F
from . import utils
from .ndarray import NDArray
from .base import EID, NID, NTYPE, ETYPE
from .subgraph import edge_subgraph
...@@ -395,11 +396,21 @@ class NDArrayPartition(object):
    num_parts : int
        The number of parts to divide the array into.
    mode : String
        The type of partition. Currently, the only valid values are
        'remainder' and 'range'.
        'remainder' assigns rows based on the remainder when dividing the row
        id by the number of parts (e.g., i % num_parts).
        'range' assigns rows based on which part of the range 'part_ranges'
        they fall into.
    part_ranges : Tensor or dgl.NDArray, Optional
        Should only be specified when the mode is 'range'. Should have length
        `num_parts + 1` and be the exclusive prefix-sum of the number of rows
        in each partition. That is, for 3 partitions, we could have the list
        [0, a, b, 'array_size'], where all rows with index less than 'a' are
        assigned to partition 0, all rows with index greater than or equal to
        'a' and less than 'b' are in partition 1, and all rows with index
        greater than or equal to 'b' are in partition 2. Should have the same
        context as the partitioned NDArray (i.e., be on the same GPU).
    Examples
    --------
...@@ -409,14 +420,40 @@ class NDArrayPartition(object):
    >>> from dgl.partition import NDArrayPartition
    >>> part = NDArrayPartition(g.num_nodes(), num_parts, mode='remainder')
    A range-based partition of a homogeneous graph `g`'s nodes, where
    the nodes are stored in contiguous memory. This converts an existing
    range-based partitioning (e.g., from a
    dgl.distributed.graph_partition_book.RangePartitionBook) 'part_book'
    to an NDArrayPartition 'part'.

    >>> part_range = [0]
    >>> for part_meta in part_book.metadata():
    >>>     part_range.append(part_range[-1] + part_meta['num_nodes'])
    >>> part = NDArrayPartition(g.num_nodes(), num_parts, mode='range',
    ...                         part_ranges=part_range)
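    A range-based partition given directly as a prefix-sum tensor on the GPU.
    This is an illustrative sketch assuming the PyTorch backend and a graph
    whose nodes are split into two contiguous halves.

    >>> import torch
    >>> half = g.num_nodes() // 2
    >>> part_ranges = torch.tensor([0, half, g.num_nodes()], device='cuda')
    >>> part = NDArrayPartition(g.num_nodes(), 2, mode='range',
    ...                         part_ranges=part_ranges)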
""" """
def __init__(self, array_size, num_parts, mode='remainder', part_ranges=None): def __init__(self, array_size, num_parts, mode='remainder', part_ranges=None):
assert num_parts > 0, 'Invalid "num_parts", must be > 0.' assert num_parts > 0, 'Invalid "num_parts", must be > 0.'
if mode == 'remainder': if mode == 'remainder':
assert part_ranges is None, 'When using remainder-based ' \ assert part_ranges is None, 'When using remainder-based ' \
'partitioning, "part_ranges" should not be specified.' 'partitioning, "part_ranges" should not be specified.'
self._partition = _CAPI_DGLNDArrayPartitionCreateRemainderBased( self._partition = _CAPI_DGLNDArrayPartitionCreateRemainderBased(
array_size, num_parts) array_size, num_parts)
elif mode == 'range':
assert part_ranges is not None, 'When using range-based ' \
'partitioning, "part_ranges" must not be None.'
assert part_ranges[0] == 0 and part_ranges[-1] == array_size, \
'part_ranges[0] must be 0, and part_ranges[-1] must be ' \
'"array_size".'
if F.is_tensor(part_ranges):
part_ranges = F.zerocopy_to_dgl_ndarray(part_ranges)
assert isinstance(part_ranges, NDArray), '"part_ranges" must ' \
'be Tensor or dgl.NDArray.'
self._partition = _CAPI_DGLNDArrayPartitionCreateRangeBased(
array_size,
num_parts,
part_ranges)
        else:
            assert False, 'Unknown partition mode "{}"'.format(mode)
        self._array_size = array_size
......
...@@ -18,63 +18,238 @@ namespace dgl {
namespace partition {
namespace impl {
namespace {

/**
 * @brief Kernel to map global element IDs to partition IDs by remainder.
 *
 * @tparam IdType The type of ID.
 * @param global The global element IDs.
 * @param num_elements The number of element IDs.
 * @param num_parts The number of partitions.
 * @param part_id The mapped partition ID (output).
 */
template<typename IdType>
__global__ void _MapProcByRemainderKernel(
    const IdType * const global,
    const int64_t num_elements,
    const int64_t num_parts,
    IdType * const part_id) {
  assert(num_elements <= gridDim.x*blockDim.x);
  const int64_t idx = blockDim.x*static_cast<int64_t>(blockIdx.x)+threadIdx.x;
  if (idx < num_elements) {
    part_id[idx] = global[idx] % num_parts;
  }
}
/**
 * @brief Kernel to map global element IDs to partition IDs, using a bit-mask.
 * The number of partitions must be a power of two.
 *
 * @tparam IdType The type of ID.
 * @param global The global element IDs.
 * @param num_elements The number of element IDs.
 * @param mask The bit-mask with 1's for each bit to keep from the element ID to
 * extract the partition ID (e.g., an 8 partition mask would be 0x07).
 * @param part_id The mapped partition ID (output).
 */
template<typename IdType>
__global__ void _MapProcByMaskRemainderKernel(
    const IdType * const global,
    const int64_t num_elements,
    const IdType mask,
    IdType * const part_id) {
  assert(num_elements <= gridDim.x*blockDim.x);
  const int64_t idx = blockDim.x*static_cast<int64_t>(blockIdx.x)+threadIdx.x;
  if (idx < num_elements) {
    part_id[idx] = global[idx] & mask;
  }
}
/**
 * @brief Kernel to map global element IDs to local element IDs.
 *
 * @tparam IdType The type of ID.
 * @param global The global element IDs.
 * @param num_elements The number of IDs.
 * @param num_parts The number of partitions.
 * @param local The local element IDs (output).
 */
template<typename IdType>
__global__ void _MapLocalIndexByRemainderKernel(
    const IdType * const global,
    const int64_t num_elements,
    const int num_parts,
    IdType * const local) {
  assert(num_elements <= gridDim.x*blockDim.x);
  const int64_t idx = threadIdx.x+blockDim.x*blockIdx.x;
  if (idx < num_elements) {
    local[idx] = global[idx] / num_parts;
  }
}
/**
* @brief Kernel to map local element IDs within a partition to their global
* IDs, using the remainder over the number of partitions.
*
* @tparam IdType The type of ID.
* @param local The local element IDs.
* @param part_id The partition to map local elements from.
* @param num_elements The number of elements to map.
* @param num_parts The number of partitions.
* @param global The global element IDs (output).
*/
template<typename IdType>
__global__ void _MapGlobalIndexByRemainderKernel(
    const IdType * const local,
    const int part_id,
    const int64_t num_elements,
    const int num_parts,
    IdType * const global) {
assert(num_elements <= gridDim.x*blockDim.x);
const int64_t idx = threadIdx.x+blockDim.x*blockIdx.x;
assert(part_id < num_parts);
if (idx < num_elements) {
global[idx] = (local[idx] * num_parts) + part_id;
}
}
/**
* @brief Device function to perform a binary search to find to which partition a
* given ID belongs.
*
* @tparam RangeType The type of range.
* @param range The prefix-sum of IDs assigned to partitions.
* @param num_parts The number of partitions.
* @param target The element ID to find the partition of.
*
* @return The partition.
*/
template<typename RangeType>
__device__ RangeType _SearchRange(
const RangeType * const range,
const int num_parts,
const RangeType target) {
int start = 0;
int end = num_parts;
int cur = (end+start)/2;
assert(range[0] == 0);
assert(target < range[num_parts]);
while (start+1 < end) {
if (target < range[cur]) {
end = cur;
} else {
start = cur;
}
cur = (start+end)/2;
}
return cur;
}
/**
* @brief Kernel to map element IDs to partition IDs.
*
* @tparam IdType The type of element ID.
 * @tparam RangeType The type of the range.
* @param range The prefix-sum of IDs assigned to partitions.
* @param global The global element IDs.
* @param num_elements The number of element IDs.
* @param num_parts The number of partitions.
* @param part_id The partition ID assigned to each element (output).
*/
template<typename IdType, typename RangeType>
__global__ void _MapProcByRangeKernel(
const RangeType * const range,
const IdType * const global,
const int64_t num_elements,
const int64_t num_parts,
IdType * const part_id) {
assert(num_elements <= gridDim.x*blockDim.x);
const int64_t idx = blockDim.x*static_cast<int64_t>(blockIdx.x)+threadIdx.x;
// rely on caching to load the range into L1 cache
if (idx < num_elements) {
part_id[idx] = static_cast<IdType>(_SearchRange(
range,
static_cast<int>(num_parts),
static_cast<RangeType>(global[idx])));
}
}
/**
* @brief Kernel to map global element IDs to their ID within their respective
* partition.
*
* @tparam IdType The type of element ID.
* @tparam RangeType The type of the range.
* @param range The prefix-sum of IDs assigned to partitions.
* @param global The global element IDs.
* @param num_elements The number of elements.
* @param num_parts The number of partitions.
* @param local The local element IDs (output).
*/
template<typename IdType, typename RangeType>
__global__ void _MapLocalIndexByRangeKernel(
const RangeType * const range,
const IdType * const global,
const int64_t num_elements,
const int num_parts,
IdType * const local) {
assert(num_elements <= gridDim.x*blockDim.x);
const int64_t idx = threadIdx.x+blockDim.x*blockIdx.x;
// rely on caching to load the range into L1 cache
if (idx < num_elements) {
const int proc = _SearchRange(
range,
static_cast<int>(num_parts),
static_cast<RangeType>(global[idx]));
local[idx] = global[idx] - range[proc];
}
}
/**
* @brief Kernel to map local element IDs within a partition to their global
* IDs.
*
* @tparam IdType The type of ID.
* @tparam RangeType The type of the range.
 * @param range The prefix-sum of IDs assigned to partitions.
* @param local The local element IDs.
* @param part_id The partition to map local elements from.
* @param num_elements The number of elements to map.
* @param num_parts The number of partitions.
* @param global The global element IDs (output).
*/
template<typename IdType, typename RangeType>
__global__ void _MapGlobalIndexByRangeKernel(
const RangeType * const range,
const IdType * const local,
const int part_id,
const int64_t num_elements,
const int num_parts,
IdType * const global) {
assert(num_elements <= gridDim.x*blockDim.x);
  const int64_t idx = threadIdx.x+blockDim.x*blockIdx.x;
  assert(part_id < num_parts);
  // rely on caching to load the range into L1 cache
  if (idx < num_elements) {
    global[idx] = local[idx] + range[part_id];
  }
}
} // namespace
// Remainder Based Partition Operations
template <DLDeviceType XPU, typename IdType>
std::pair<IdArray, NDArray>
...@@ -120,14 +295,14 @@ GeneratePermutationFromRemainder(
  if (num_parts < (1 << part_bits)) {
    // num_parts is not a power of 2
    CUDA_KERNEL_CALL(_MapProcByRemainderKernel, grid, block, 0, stream,
        static_cast<const IdType*>(in_idx->data),
        num_in,
        num_parts,
        proc_id_in.get());
  } else {
    // num_parts is a power of 2
    CUDA_KERNEL_CALL(_MapProcByMaskRemainderKernel, grid, block, 0, stream,
        static_cast<const IdType*>(in_idx->data),
        num_in,
        static_cast<IdType>(num_parts-1),  // bit mask
...@@ -227,15 +402,15 @@ IdArray MapToLocalFromRemainder(
    const dim3 grid((global_idx->shape[0] +block.x-1)/block.x);
    CUDA_KERNEL_CALL(
        _MapLocalIndexByRemainderKernel,
        grid,
        block,
        0,
        stream,
        static_cast<const IdType*>(global_idx->data),
        global_idx->shape[0],
        num_parts,
        static_cast<IdType*>(local_idx->data));
    return local_idx;
  } else {
...@@ -274,16 +449,16 @@ IdArray MapToGlobalFromRemainder(
    const dim3 grid((local_idx->shape[0] +block.x-1)/block.x);
    CUDA_KERNEL_CALL(
        _MapGlobalIndexByRemainderKernel,
        grid,
        block,
        0,
        stream,
        static_cast<const IdType*>(local_idx->data),
        part_id,
        global_idx->shape[0],
        num_parts,
        static_cast<IdType*>(global_idx->data));
    return global_idx;
  } else {
...@@ -304,6 +479,271 @@ MapToGlobalFromRemainder<kDLGPU, int64_t>(
    int part_id);
// Range Based Partition Operations
template <DLDeviceType XPU, typename IdType, typename RangeType>
std::pair<IdArray, NDArray>
GeneratePermutationFromRange(
int64_t array_size,
int num_parts,
IdArray range,
IdArray in_idx) {
std::pair<IdArray, NDArray> result;
const auto& ctx = in_idx->ctx;
auto device = DeviceAPI::Get(ctx);
cudaStream_t stream = CUDAThreadEntry::ThreadLocal()->stream;
const int64_t num_in = in_idx->shape[0];
CHECK_GE(num_parts, 1) << "The number of partitions (" << num_parts <<
") must be at least 1.";
if (num_parts == 1) {
// no permutation
result.first = aten::Range(0, num_in, sizeof(IdType)*8, ctx);
result.second = aten::Full(num_in, num_parts, sizeof(int64_t)*8, ctx);
return result;
}
result.first = aten::NewIdArray(num_in, ctx, sizeof(IdType)*8);
result.second = aten::Full(0, num_parts, sizeof(int64_t)*8, ctx);
int64_t * out_counts = static_cast<int64_t*>(result.second->data);
if (num_in == 0) {
// now that we've zero'd out_counts, nothing left to do for an empty
// mapping
return result;
}
const int64_t part_bits =
static_cast<int64_t>(std::ceil(std::log2(num_parts)));
// First, generate a mapping of indexes to processors
Workspace<IdType> proc_id_in(device, ctx, num_in);
{
const dim3 block(256);
const dim3 grid((num_in+block.x-1)/block.x);
CUDA_KERNEL_CALL(_MapProcByRangeKernel, grid, block, 0, stream,
static_cast<const RangeType*>(range->data),
static_cast<const IdType*>(in_idx->data),
num_in,
num_parts,
proc_id_in.get());
}
// then create a permutation array that groups processors together by
// performing a radix sort
Workspace<IdType> proc_id_out(device, ctx, num_in);
IdType * perm_out = static_cast<IdType*>(result.first->data);
{
IdArray perm_in = aten::Range(0, num_in, sizeof(IdType)*8, ctx);
size_t sort_workspace_size;
CUDA_CALL(cub::DeviceRadixSort::SortPairs(nullptr, sort_workspace_size,
proc_id_in.get(), proc_id_out.get(), static_cast<IdType*>(perm_in->data), perm_out,
num_in, 0, part_bits, stream));
Workspace<void> sort_workspace(device, ctx, sort_workspace_size);
CUDA_CALL(cub::DeviceRadixSort::SortPairs(sort_workspace.get(), sort_workspace_size,
proc_id_in.get(), proc_id_out.get(), static_cast<IdType*>(perm_in->data), perm_out,
num_in, 0, part_bits, stream));
}
// explicitly free so workspace can be re-used
proc_id_in.free();
// perform a histogram and then prefixsum on the sorted proc_id vector
// Count the number of values to be sent to each processor
{
using AtomicCount = unsigned long long; // NOLINT
static_assert(sizeof(AtomicCount) == sizeof(*out_counts),
"AtomicCount must be the same width as int64_t for atomicAdd "
"in cub::DeviceHistogram::HistogramEven() to work");
// TODO(dlasalle): Once https://github.com/NVIDIA/cub/pull/287 is merged,
// add a compile time check against the cub version to allow
// num_in > (2 << 31).
CHECK(num_in < static_cast<int64_t>(std::numeric_limits<int>::max())) <<
"number of values to insert into histogram must be less than max "
"value of int.";
size_t hist_workspace_size;
CUDA_CALL(cub::DeviceHistogram::HistogramEven(
nullptr,
hist_workspace_size,
proc_id_out.get(),
reinterpret_cast<AtomicCount*>(out_counts),
num_parts+1,
static_cast<IdType>(0),
static_cast<IdType>(num_parts+1),
static_cast<int>(num_in),
stream));
Workspace<void> hist_workspace(device, ctx, hist_workspace_size);
CUDA_CALL(cub::DeviceHistogram::HistogramEven(
hist_workspace.get(),
hist_workspace_size,
proc_id_out.get(),
reinterpret_cast<AtomicCount*>(out_counts),
num_parts+1,
static_cast<IdType>(0),
static_cast<IdType>(num_parts+1),
static_cast<int>(num_in),
stream));
}
return result;
}
template std::pair<IdArray, IdArray>
GeneratePermutationFromRange<kDLGPU, int32_t, int32_t>(
int64_t array_size,
int num_parts,
IdArray range,
IdArray in_idx);
template std::pair<IdArray, IdArray>
GeneratePermutationFromRange<kDLGPU, int64_t, int32_t>(
int64_t array_size,
int num_parts,
IdArray range,
IdArray in_idx);
template std::pair<IdArray, IdArray>
GeneratePermutationFromRange<kDLGPU, int32_t, int64_t>(
int64_t array_size,
int num_parts,
IdArray range,
IdArray in_idx);
template std::pair<IdArray, IdArray>
GeneratePermutationFromRange<kDLGPU, int64_t, int64_t>(
int64_t array_size,
int num_parts,
IdArray range,
IdArray in_idx);
template <DLDeviceType XPU, typename IdType, typename RangeType>
IdArray MapToLocalFromRange(
const int num_parts,
IdArray range,
IdArray global_idx) {
const auto& ctx = global_idx->ctx;
cudaStream_t stream = CUDAThreadEntry::ThreadLocal()->stream;
if (num_parts > 1 && global_idx->shape[0] > 0) {
IdArray local_idx = aten::NewIdArray(global_idx->shape[0], ctx,
sizeof(IdType)*8);
const dim3 block(128);
const dim3 grid((global_idx->shape[0] +block.x-1)/block.x);
CUDA_KERNEL_CALL(
_MapLocalIndexByRangeKernel,
grid,
block,
0,
stream,
static_cast<const RangeType*>(range->data),
static_cast<const IdType*>(global_idx->data),
global_idx->shape[0],
num_parts,
static_cast<IdType*>(local_idx->data));
return local_idx;
} else {
// no mapping to be done
return global_idx;
}
}
template IdArray
MapToLocalFromRange<kDLGPU, int32_t, int32_t>(
int num_parts,
IdArray range,
IdArray in_idx);
template IdArray
MapToLocalFromRange<kDLGPU, int64_t, int32_t>(
int num_parts,
IdArray range,
IdArray in_idx);
template IdArray
MapToLocalFromRange<kDLGPU, int32_t, int64_t>(
int num_parts,
IdArray range,
IdArray in_idx);
template IdArray
MapToLocalFromRange<kDLGPU, int64_t, int64_t>(
int num_parts,
IdArray range,
IdArray in_idx);
template <DLDeviceType XPU, typename IdType, typename RangeType>
IdArray MapToGlobalFromRange(
const int num_parts,
IdArray range,
IdArray local_idx,
const int part_id) {
CHECK_LT(part_id, num_parts) << "Invalid partition id " << part_id <<
"/" << num_parts;
CHECK_GE(part_id, 0) << "Invalid partition id " << part_id <<
"/" << num_parts;
const auto& ctx = local_idx->ctx;
cudaStream_t stream = CUDAThreadEntry::ThreadLocal()->stream;
if (num_parts > 1 && local_idx->shape[0] > 0) {
IdArray global_idx = aten::NewIdArray(local_idx->shape[0], ctx,
sizeof(IdType)*8);
const dim3 block(128);
const dim3 grid((local_idx->shape[0] +block.x-1)/block.x);
CUDA_KERNEL_CALL(
_MapGlobalIndexByRangeKernel,
grid,
block,
0,
stream,
static_cast<const RangeType*>(range->data),
static_cast<const IdType*>(local_idx->data),
part_id,
global_idx->shape[0],
num_parts,
static_cast<IdType*>(global_idx->data));
return global_idx;
} else {
// no mapping to be done
return local_idx;
}
}
template IdArray
MapToGlobalFromRange<kDLGPU, int32_t, int32_t>(
int num_parts,
IdArray range,
IdArray in_idx,
int part_id);
template IdArray
MapToGlobalFromRange<kDLGPU, int64_t, int32_t>(
int num_parts,
IdArray range,
IdArray in_idx,
int part_id);
template IdArray
MapToGlobalFromRange<kDLGPU, int32_t, int64_t>(
int num_parts,
IdArray range,
IdArray in_idx,
int part_id);
template IdArray
MapToGlobalFromRange<kDLGPU, int64_t, int64_t>(
int num_parts,
IdArray range,
IdArray in_idx,
int part_id);
}  // namespace impl
}  // namespace partition
......
...@@ -105,6 +105,109 @@ class RemainderPartition : public NDArrayPartition {
  }
};
class RangePartition : public NDArrayPartition {
public:
RangePartition(
const int64_t array_size,
const int num_parts,
IdArray range) :
NDArrayPartition(array_size, num_parts),
range_(range),
      // We also need a copy of the range on the CPU to compute partition
      // sizes. We require the input range to be on the GPU: with multiple
      // GPUs we cannot know which device a CPU array should be copied to,
      // but there is only one CPU context, so it is always safe to copy the
      // range there.
      range_cpu_(range.CopyTo(DGLContext{kDLCPU, 0})) {
auto ctx = range->ctx;
if (ctx.device_type != kDLGPU) {
LOG(FATAL) << "The range for an NDArrayPartition is only supported "
" on GPUs. Transfer the range to the target device before "
"creating the partition.";
}
}
std::pair<IdArray, NDArray>
GeneratePermutation(
IdArray in_idx) const override {
auto ctx = in_idx->ctx;
#ifdef DGL_USE_CUDA
if (ctx.device_type == kDLGPU) {
if (ctx.device_type != range_->ctx.device_type ||
ctx.device_id != range_->ctx.device_id) {
LOG(FATAL) << "The range for the NDArrayPartition and the input "
"array must be on the same device: " << ctx << " vs. " << range_->ctx;
}
ATEN_ID_TYPE_SWITCH(in_idx->dtype, IdType, {
ATEN_ID_TYPE_SWITCH(range_->dtype, RangeType, {
return impl::GeneratePermutationFromRange<kDLGPU, IdType, RangeType>(
ArraySize(), NumParts(), range_, in_idx);
});
});
}
#endif
LOG(FATAL) << "Remainder based partitioning for the CPU is not yet "
"implemented.";
// should be unreachable
return std::pair<IdArray, NDArray>{};
}
IdArray MapToLocal(
IdArray in_idx) const override {
auto ctx = in_idx->ctx;
#ifdef DGL_USE_CUDA
if (ctx.device_type == kDLGPU) {
ATEN_ID_TYPE_SWITCH(in_idx->dtype, IdType, {
ATEN_ID_TYPE_SWITCH(range_->dtype, RangeType, {
return impl::MapToLocalFromRange<kDLGPU, IdType, RangeType>(
NumParts(), range_, in_idx);
});
});
}
#endif
LOG(FATAL) << "Remainder based partitioning for the CPU is not yet "
"implemented.";
// should be unreachable
return IdArray{};
}
IdArray MapToGlobal(
IdArray in_idx,
const int part_id) const override {
auto ctx = in_idx->ctx;
#ifdef DGL_USE_CUDA
if (ctx.device_type == kDLGPU) {
ATEN_ID_TYPE_SWITCH(in_idx->dtype, IdType, {
ATEN_ID_TYPE_SWITCH(range_->dtype, RangeType, {
return impl::MapToGlobalFromRange<kDLGPU, IdType, RangeType>(
NumParts(), range_, in_idx, part_id);
});
});
}
#endif
LOG(FATAL) << "Remainder based partitioning for the CPU is not yet "
"implemented.";
// should be unreachable
return IdArray{};
}
int64_t PartSize(const int part_id) const override {
CHECK_LT(part_id, NumParts()) << "Invalid part ID (" << part_id << ") for "
"partition of size " << NumParts() << ".";
ATEN_ID_TYPE_SWITCH(range_cpu_->dtype, RangeType, {
const RangeType * const ptr = static_cast<const RangeType*>(range_cpu_->data);
return ptr[part_id+1]-ptr[part_id];
});
}
private:
IdArray range_;
IdArray range_cpu_;
};
NDArrayPartitionRef CreatePartitionRemainderBased(
    const int64_t array_size,
    const int num_parts) {
...@@ -112,6 +215,16 @@ NDArrayPartitionRef CreatePartitionRemainderBased(
      array_size, num_parts));
}
NDArrayPartitionRef CreatePartitionRangeBased(
const int64_t array_size,
const int num_parts,
IdArray range) {
return NDArrayPartitionRef(std::make_shared<RangePartition>(
array_size,
num_parts,
range));
}
DGL_REGISTER_GLOBAL("partition._CAPI_DGLNDArrayPartitionCreateRemainderBased") DGL_REGISTER_GLOBAL("partition._CAPI_DGLNDArrayPartitionCreateRemainderBased")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
int64_t array_size = args[0]; int64_t array_size = args[0];
...@@ -120,6 +233,17 @@ DGL_REGISTER_GLOBAL("partition._CAPI_DGLNDArrayPartitionCreateRemainderBased") ...@@ -120,6 +233,17 @@ DGL_REGISTER_GLOBAL("partition._CAPI_DGLNDArrayPartitionCreateRemainderBased")
*rv = CreatePartitionRemainderBased(array_size, num_parts); *rv = CreatePartitionRemainderBased(array_size, num_parts);
}); });
DGL_REGISTER_GLOBAL("partition._CAPI_DGLNDArrayPartitionCreateRangeBased")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
const int64_t array_size = args[0];
const int num_parts = args[1];
IdArray range = args[2];
*rv = CreatePartitionRangeBased(array_size, num_parts, range);
});
DGL_REGISTER_GLOBAL("partition._CAPI_DGLNDArrayPartitionGetPartSize") DGL_REGISTER_GLOBAL("partition._CAPI_DGLNDArrayPartitionGetPartSize")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([] (DGLArgs args, DGLRetValue* rv) {
NDArrayPartitionRef part = args[0]; NDArrayPartitionRef part = args[0];
......
...@@ -122,6 +122,24 @@ NDArrayPartitionRef CreatePartitionRemainderBased(
    int64_t array_size,
    int num_parts);
/**
* @brief Create a new partition object, using the range (exclusive prefix-sum)
* provided to identify which rows belong to which partitions.
*
* @param array_size The size of the partitioned array.
* @param num_parts The number of parts the array is partitioned into.
* @param range The exclusive prefix-sum of the number of rows owned by each
* partition. The first value must be zero, and the last value must be the
* total number of rows. It should be of length `num_parts+1`.
*
* @return The partition object.
*/
NDArrayPartitionRef CreatePartitionRangeBased(
int64_t array_size,
int num_parts,
IdArray range);
}  // namespace partition
}  // namespace dgl
......
...@@ -16,7 +16,12 @@ namespace partition {
namespace impl {
/**
 * @brief Create a permutation that groups indices by the part id when used for
 * slicing, via the remainder. That is, for the input indices A, find I
 * such that A[I] is grouped by part ID.
 *
 * For example, if we have the set of indices [3, 9, 2, 4, 1, 7] and two
 * partitions, the permutation vector would be [2, 3, 0, 1, 4, 5].
 *
 * @tparam XPU The type of device to run on.
 * @tparam IdType The type of the index.
...@@ -70,6 +75,79 @@ IdArray MapToGlobalFromRemainder(
    IdArray local_idx,
    int part_id);
/**
* @brief Create a permutation that groups indices by the part id when used for
* slicing. That is, for the input indices A, find I such that A[I] is grouped
* by part ID.
*
* For example, if we have a range of [0, 5, 10] and the set of indices
* [3, 9, 2, 4, 1, 7], the permutation vector would be [0, 2, 3, 4, 1, 5].
*
* @tparam XPU The type of device to run on.
* @tparam IdType The type of the index.
 * @tparam RangeType The type of the range.
 * @param array_size The total size of the partitioned array.
 * @param num_parts The number of parts the array is divided into.
* @param range The exclusive prefix-sum, representing the range of rows
* assigned to each partition. Must be on the same context as `in_idx`.
* @param in_idx The array of indices to group by part id.
*
* @return The permutation to group the indices by part id, and the number of
* indices in each part.
*/
template <DLDeviceType XPU, typename IdType, typename RangeType>
std::pair<IdArray, IdArray>
GeneratePermutationFromRange(
int64_t array_size,
int num_parts,
IdArray range,
IdArray in_idx);
/**
 * @brief Generate the set of local indices from the global indices, using
 * the range. That is, for each index `i` in `global_idx` that falls in
 * partition `p`, the local index is computed as `global_idx[i] - range[p]`.
*
* @tparam XPU The type of device to run on.
* @tparam IdType The type of the index.
 * @tparam RangeType The type of the range.
 * @param num_parts The number of parts the array is divided into.
* @param range The exclusive prefix-sum, representing the range of rows
* assigned to each partition. Must be on the same context as `global_idx`.
* @param global_idx The array of global indices to map.
*
* @return The array of local indices.
*/
template <DLDeviceType XPU, typename IdType, typename RangeType>
IdArray MapToLocalFromRange(
int num_parts,
IdArray range,
IdArray global_idx);
/**
 * @brief Generate the set of global indices from the local indices, using
 * the range. That is, for each index `i` in `local_idx`, the global index
 * is computed as `local_idx[i] + range[part_id]`.
*
* @tparam XPU The type of device to run on.
* @tparam IdType The type of the index.
 * @tparam RangeType The type of the range.
 * @param num_parts The number of parts the array is divided into.
* @param range The exclusive prefix-sum, representing the range of rows
* assigned to each partition. Must be on the same context as `local_idx`.
* @param local_idx The array of local indices to map.
* @param part_id The id of the current part.
*
* @return The array of global indices.
*/
template <DLDeviceType XPU, typename IdType, typename RangeType>
IdArray MapToGlobalFromRange(
int num_parts,
IdArray range,
IdArray local_idx,
int part_id);
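// Worked example for the two range-based mappings above: with
// range = [0, 5, 10] and num_parts = 2, MapToLocalFromRange maps the global
// indices [3, 7, 9] to [3, 2, 4] (each global index minus the start of its
// owning range), and MapToGlobalFromRange with part_id = 1 maps the local
// indices [2, 4] back to [7, 9] (each local index plus range[part_id]).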
}  // namespace impl
}  // namespace partition
......
...@@ -25,7 +25,7 @@ def test_nccl_id():
@unittest.skipIf(F._default_context_str == 'cpu', reason="NCCL only runs on GPU.")
def test_nccl_sparse_push_single_remainder():
    nccl_id = nccl.UniqueId()
    comm = nccl.Communicator(1, 0, nccl_id)
...@@ -39,7 +39,7 @@ def test_nccl_sparse_push_single():
    assert F.array_equal(rv, value)

@unittest.skipIf(F._default_context_str == 'cpu', reason="NCCL only runs on GPU.")
def test_nccl_sparse_pull_single_remainder():
    nccl_id = nccl.UniqueId()
    comm = nccl.Communicator(1, 0, nccl_id)
...@@ -52,6 +52,36 @@ def test_nccl_sparse_pull_single():
    exp_rv = F.gather_row(value, req_index)
    assert F.array_equal(rv, exp_rv)
@unittest.skipIf(F._default_context_str == 'cpu', reason="NCCL only runs on GPU.")
def test_nccl_sparse_push_single_range():
nccl_id = nccl.UniqueId()
comm = nccl.Communicator(1, 0, nccl_id)
index = F.randint([10000], F.int32, F.ctx(), 0, 10000)
value = F.uniform([10000, 100], F.float32, F.ctx(), -1.0, 1.0)
part_ranges = F.copy_to(F.tensor([0, value.shape[0]], dtype=F.int64), F.ctx())
part = NDArrayPartition(10000, 1, 'range', part_ranges=part_ranges)
ri, rv = comm.sparse_all_to_all_push(index, value, part)
assert F.array_equal(ri, index)
assert F.array_equal(rv, value)
@unittest.skipIf(F._default_context_str == 'cpu', reason="NCCL only runs on GPU.")
def test_nccl_sparse_pull_single_range():
nccl_id = nccl.UniqueId()
comm = nccl.Communicator(1, 0, nccl_id)
req_index = F.randint([10000], F.int64, F.ctx(), 0, 100000)
value = F.uniform([100000, 100], F.float32, F.ctx(), -1.0, 1.0)
part_ranges = F.copy_to(F.tensor([0, value.shape[0]], dtype=F.int64), F.ctx())
part = NDArrayPartition(100000, 1, 'range', part_ranges=part_ranges)
rv = comm.sparse_all_to_all_pull(req_index, value, part)
exp_rv = F.gather_row(value, req_index)
assert F.array_equal(rv, exp_rv)
if __name__ == '__main__':
    test_nccl_id()
......
from dgl.partition import NDArrayPartition
from dgl.distributed import graph_partition_book as gpb
import unittest
import backend as F
from test_utils import parametrize_dtype
@unittest.skipIf(F._default_context_str == 'cpu', reason="NDArrayPartition only works on GPU.")
@parametrize_dtype
def test_get_node_partition_from_book(idtype):
node_map = {
"type_n": F.tensor([
[0,3],
[4,5],
[6,10]
], dtype=idtype)}
edge_map = {
"type_e": F.tensor([
[0,9],
[10,15],
[16,25]
], dtype=idtype)}
book = gpb.RangePartitionBook(0, 3, node_map, edge_map,
{"type_n": 0}, {"type_e": 0})
partition = gpb.get_node_partition_from_book(book, F.ctx())
assert partition.num_parts() == 3
assert partition.array_size() == 11
test_ids = F.copy_to(F.tensor([0, 2, 6, 7, 10], dtype=idtype), F.ctx())
act_ids = partition.map_to_local(test_ids)
exp_ids = F.copy_to(F.tensor([0, 2, 0, 1, 4], dtype=idtype), F.ctx())
assert F.array_equal(act_ids, exp_ids)
test_ids = F.copy_to(F.tensor([0, 2], dtype=idtype), F.ctx())
act_ids = partition.map_to_global(test_ids, 0)
exp_ids = F.copy_to(F.tensor([0, 2], dtype=idtype), F.ctx())
assert F.array_equal(act_ids, exp_ids)
test_ids = F.copy_to(F.tensor([0, 1], dtype=idtype), F.ctx())
act_ids = partition.map_to_global(test_ids, 1)
exp_ids = F.copy_to(F.tensor([4, 5], dtype=idtype), F.ctx())
assert F.array_equal(act_ids, exp_ids)
test_ids = F.copy_to(F.tensor([0, 1, 4], dtype=idtype), F.ctx())
act_ids = partition.map_to_global(test_ids, 2)
exp_ids = F.copy_to(F.tensor([6, 7, 10], dtype=idtype), F.ctx())
assert F.array_equal(act_ids, exp_ids)
...@@ -6,7 +6,6 @@
using namespace dgl;
using namespace dgl::partition;

template<DLDeviceType XPU, typename IdType>
void _TestRemainder_GeneratePermutation() {
  const int64_t size = 160000;
...@@ -78,6 +77,7 @@ void _TestRemainder_MapToX() {
  }
}

TEST(PartitionTest, TestRemainderPartition) {
#ifdef DGL_USE_CUDA
  _TestRemainder_GeneratePermutation<kDLGPU, int32_t>();
...@@ -86,9 +86,114 @@ TEST(PartitionTest, TestRemainderPartition) {
  _TestRemainder_MapToX<kDLGPU, int32_t>();
  _TestRemainder_MapToX<kDLGPU, int64_t>();
#endif

  // CPU is not implemented
}
template<typename INDEX, typename RANGE>
int _FindPart(
const INDEX idx,
const RANGE * const range,
const int num_parts)
{
for (int i = 0; i < num_parts; ++i) {
if (range[i+1] > idx) {
return i;
}
}
return -1;
}
template<DLDeviceType XPU, typename IdType>
void _TestRange_GeneratePermutation() {
const int64_t size = 160000;
const int num_parts = 7;
IdArray range = aten::NewIdArray(num_parts+1, DGLContext{kDLCPU, 0},
sizeof(IdType)*8);
for (int i = 0; i < num_parts; ++i) {
range.Ptr<IdType>()[i] = (size/num_parts)*i;
}
range.Ptr<IdType>()[num_parts] = size;
NDArrayPartitionRef part = CreatePartitionRangeBased(
size, num_parts, range.CopyTo(DGLContext{XPU, 0}));
IdArray idxs = aten::Range(0, size/10, sizeof(IdType)*8,
DGLContext{XPU, 0});
std::pair<IdArray, IdArray> result = part->GeneratePermutation(idxs);
// first part of result should be the permutation
IdArray perm = result.first.CopyTo(DGLContext{kDLCPU, 0});
ASSERT_TRUE(perm.Ptr<IdType>() != nullptr);
ASSERT_EQ(perm->shape[0], idxs->shape[0]);
const IdType * const perm_cpu = static_cast<const IdType*>(perm->data);
// second part of result should be the counts
IdArray counts = result.second.CopyTo(DGLContext{kDLCPU, 0});
ASSERT_TRUE(counts.Ptr<int64_t>() != nullptr);
ASSERT_EQ(counts->shape[0], num_parts);
const int64_t * const counts_cpu = static_cast<const int64_t*>(counts->data);
std::vector<int64_t> prefix(num_parts+1, 0);
for (int p = 0; p < num_parts; ++p) {
prefix[p+1] = prefix[p] + counts_cpu[p];
}
ASSERT_EQ(prefix.back(), idxs->shape[0]);
// copy original indexes to cpu
idxs = idxs.CopyTo(DGLContext{kDLCPU, 0});
const IdType * const idxs_cpu = static_cast<const IdType*>(idxs->data);
for (int p = 0; p < num_parts; ++p) {
for (int64_t i = prefix[p]; i < prefix[p+1]; ++i) {
EXPECT_EQ(_FindPart(idxs_cpu[perm_cpu[i]], range.Ptr<IdType>(), num_parts), p);
}
}
}
template<DLDeviceType XPU, typename IdType>
void _TestRange_MapToX() {
const int64_t size = 160000;
const int num_parts = 7;
IdArray range = aten::NewIdArray(num_parts+1, DGLContext{kDLCPU, 0},
sizeof(IdType)*8);
for (int i = 0; i < num_parts; ++i) {
    range.Ptr<IdType>()[i] = (size/num_parts)*i;
}
range.Ptr<IdType>()[num_parts] = size;
NDArrayPartitionRef part = CreatePartitionRangeBased(
size, num_parts, range.CopyTo(DGLContext{XPU, 0}));
for (int part_id = 0; part_id < num_parts; ++part_id) {
IdArray local = aten::Range(0, part->PartSize(part_id), sizeof(IdType)*8,
DGLContext{XPU, 0});
IdArray global = part->MapToGlobal(local, part_id);
IdArray act_local = part->MapToLocal(global).CopyTo(CPU);
ASSERT_EQ(global->shape[0], local->shape[0]);
global = global.CopyTo(CPU);
for (size_t i = 0; i < global->shape[0]; ++i) {
      EXPECT_EQ(_FindPart(global.Ptr<IdType>()[i], range.Ptr<IdType>(), num_parts), part_id) << "i=" << i <<
", num_parts=" << num_parts << ", part_id=" << part_id << ", shape=" << global->shape[0];
}
    // the remapped local indices should match the original
local = local.CopyTo(CPU);
ASSERT_EQ(local->shape[0], act_local->shape[0]);
for (size_t i = 0; i < act_local->shape[0]; ++i) {
      EXPECT_EQ(local.Ptr<IdType>()[i], act_local.Ptr<IdType>()[i]);
}
}
}
TEST(PartitionTest, TestRangePartition) {
#ifdef DGL_USE_CUDA
_TestRange_GeneratePermutation<kDLGPU, int32_t>();
_TestRange_GeneratePermutation<kDLGPU, int64_t>();
_TestRange_MapToX<kDLGPU, int32_t>();
_TestRange_MapToX<kDLGPU, int64_t>();
#endif
// CPU is not implemented
}